
Can LLMs Learn from a Single Example?

Investigating some curious phenomena around LLM Fine-Tuning.
Created on August 28 | Last edited on November 16
This is a companion piece to Jeremy Howard's article over on Fast.ai.
I've been working on the Kaggle "LLM Science Exam" competition a bit recently (playlist of my videos). At some point my teammate Jeremy noticed a curious pattern in our training loss curve: a kind of stair-stepping descent that had also popped up in other people's LLM experiments and seemed to warrant further investigation.
We've written a post about it over on the fastai blog. In this report I'll show some more logs from a simple replication of this effect and add a few extra notes.
One of several examples, shared by Anton, of training loss curves with this step pattern. Cue the jokes about sudden loss drops and AGI!
The drops happen at each epoch boundary. The first explanation that came to mind was memorization/overfitting, but surely that can't be happening after just one epoch!? Conventional wisdom says it takes many epochs. And in most cases the validation metrics keep improving even as the validation loss flattens out and even starts to worsen, so why worry about it further, right?
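Why can a validation metric keep improving while validation loss gets worse? One common explanation is overconfidence: cross-entropy punishes a confidently wrong answer far more than it rewards a confidently right one. The toy calculation below illustrates this. It is not measured from these runs: the probabilities are invented, the two "checkpoints" are hypothetical, and accuracy is computed for a hypothetical two-way choice.

```python
import math

def cross_entropy(p_correct):
    """Mean negative log-probability assigned to the correct answer."""
    return sum(-math.log(p) for p in p_correct) / len(p_correct)

def accuracy(p_correct):
    """Two-way choice: a prediction is right when the correct answer gets > 0.5."""
    return sum(p > 0.5 for p in p_correct) / len(p_correct)

# Hypothetical probabilities each checkpoint assigns to the correct answer
# on the same four validation questions (made-up numbers, not real logs).
earlier_checkpoint = [0.80, 0.75, 0.45, 0.40]  # moderately confident
later_checkpoint = [0.99, 0.98, 0.60, 0.02]    # very confident, right or wrong

for name, probs in [("earlier", earlier_checkpoint), ("later", later_checkpoint)]:
    print(f"{name}: accuracy={accuracy(probs):.2f}, loss={cross_entropy(probs):.3f}")
```

With these made-up numbers, accuracy rises from 0.50 to 0.75 while the mean loss roughly doubles (about 0.556 to 1.113). Almost all of the increase comes from the single confidently wrong answer, since -ln(0.02) ≈ 3.9.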

Exploring this:

Here's a notebook where you can quickly see this effect in action. Following the Hugging Face docs, I trained a Llama-2-7B model on a 'CodeAlpaca' dataset using the SFTTrainer from TRL, with LoRA rank set to 16. This is the training loss graph for 5000 samples over 3 epochs, comparing two different learning rates (LRs):


You can see the stair-step pattern we're talking about. With the lower learning rate, the validation loss doesn't suffer much until the third epoch, while in the high-LR run it clearly gets worse almost as soon as the data starts repeating. Note that at the end of the first epoch, validation loss ≈ train loss.
You can explore how the learning rate affects how well a sample is memorized by playing with the LR schedule. Let's do a cosine schedule with 50% warmup for each epoch:

[W&B run set 1: loss curves with a per-epoch cosine schedule (50% warmup)]
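A per-epoch cosine schedule with 50% warmup can be sketched as an LR multiplier that restarts at every epoch boundary. This is a minimal sketch under assumptions: linear warmup over the first half of each epoch, then cosine decay towards zero. The exact warmup shape used for these runs may differ. The function name per_epoch_cosine_multiplier and the peak LR of 2e-4 are hypothetical. A multiplier of this form could be passed to PyTorch's torch.optim.lr_scheduler.LambdaLR, which scales the optimizer's base LR by the returned value.

```python
import math

def per_epoch_cosine_multiplier(step, steps_per_epoch, warmup_frac=0.5):
    """Hypothetical LR multiplier in [0, 1] that restarts at every epoch.

    Within each epoch: linear warmup from 0 to 1 over the first
    warmup_frac of the epoch, then cosine decay back towards 0.
    """
    pos = (step % steps_per_epoch) / steps_per_epoch  # position within the epoch, in [0, 1)
    if pos < warmup_frac:
        return pos / warmup_frac
    progress = (pos - warmup_frac) / (1 - warmup_frac)  # decay progress, in [0, 1)
    return 0.5 * (1 + math.cos(math.pi * progress))

# Assumed setup: 5000 samples at an effective batch size of 8 -> 625 steps per epoch.
steps_per_epoch = 625
max_lr = 2e-4  # hypothetical peak LR, not the value used in these runs
for step in (0, 156, 312, 469, 624, 625):
    print(step, max_lr * per_epoch_cosine_multiplier(step, steps_per_epoch))
```

Because the position is taken modulo steps_per_epoch, step 625 (the first step of epoch 2) drops back to an LR of 0, which is what makes each epoch start and end with small updates.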


What are we looking at?

The learning rate starts out super low, so the first few samples don't change the model much (are not 'memorized'). As it increases, the model learns more from each batch.
Towards the end of the first epoch the learning rate decreases again. If the max LR is super high, the loss might get a little worse after the peak. Then the second epoch starts. Unlike the first graph, where there is a sudden drop, this one gets gradually better as it moves from the early (unmemorized) samples to the middle ones, then gets worse again as it reaches the late samples, which were also seen at a low learning rate!
And in the third epoch this trend is even more pronounced: the middle samples have now had two high-learning-rate updates, while the very early and late samples still haven't been memorized, thanks to the LR schedule.
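The reasoning above comes down to one quantity: how much learning rate a sample was trained with the previous time(s) it was seen. The toy model below makes that explicit. It is a rough proxy, not a measurement. It assumes the samples are seen in the same order every epoch (which the explanation above implies) and treats the summed LR from earlier epochs as a stand-in for how well a sample has been memorized. Positions are relative (0 = start of the epoch, 1 = end), and all function names are hypothetical.

```python
import math

def constant_lr(pos, max_lr=1.0):
    """Constant schedule: every position in the epoch gets the same LR."""
    return max_lr

def per_epoch_cosine_lr(pos, max_lr=1.0, warmup_frac=0.5):
    """Linear warmup over the first half of the epoch, cosine decay to 0 after."""
    if pos < warmup_frac:
        return max_lr * pos / warmup_frac
    progress = (pos - warmup_frac) / (1 - warmup_frac)
    return max_lr * 0.5 * (1 + math.cos(math.pi * progress))

def prior_exposure(pos, epoch, schedule):
    """Total LR a sample at relative position pos was trained with during the
    epochs before `epoch` (1-indexed), assuming a fixed data order."""
    return (epoch - 1) * schedule(pos)

positions = [0.05, 0.25, 0.5, 0.75, 0.95]
for epoch in (2, 3):
    for schedule in (constant_lr, per_epoch_cosine_lr):
        row = [round(prior_exposure(p, epoch, schedule), 2) for p in positions]
        print(f"epoch {epoch}, {schedule.__name__}: {row}")
```

In this toy model a constant LR gives every position the same prior exposure, which fits a uniform drop right at the epoch boundary. The per-epoch cosine schedule instead gives the mid-epoch samples the most exposure (and twice as much by epoch 3), which fits the gradual better-then-worse shape described above.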
Check out our blog post for more examples, along with our theories about why exactly this is happening.

Additional Notes

  • The default schedule for HF's Trainer is "linear", which has an optional warmup (set warmup_ratio=0.2, for example) and then decays linearly to 0 by the end of training. This can mask the effect by smoothing the transitions between epochs.
  • Some people have suggested that this is particularly prevalent at low batch sizes, but I haven't tested this myself. If someone with lots of compute wants to share some experiments, that would be grand :) All the experiments here were done with an effective batch size of 8.
  • The effect is stronger for some tasks/datasets than for others. Classification tasks seem to be one where it is particularly strong.
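For comparison, the default "linear" schedule can be sketched as below. It follows the shape described in the first note (linear warmup for warmup_ratio of the total steps, then linear decay to 0). It is a re-implementation for illustration, not code taken from the transformers library. The step counts assume 5000 samples, 3 epochs and an effective batch size of 8 with one sample per batch slot (no packing). Effective batch size is the per-device batch size times the gradient accumulation steps times the number of GPUs, and the 2 × 4 × 1 split used here is hypothetical.

```python
import math

def linear_with_warmup(step, total_steps, warmup_steps):
    """LR multiplier: linear warmup to 1, then linear decay to 0 at total_steps."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

# Hypothetical split of the effective batch size of 8.
per_device_batch, grad_accum, n_gpus = 2, 4, 1
effective_batch = per_device_batch * grad_accum * n_gpus  # 8

n_samples, n_epochs, warmup_ratio = 5000, 3, 0.2
steps_per_epoch = math.ceil(n_samples / effective_batch)  # 625
total_steps = steps_per_epoch * n_epochs                  # 1875
warmup_steps = math.ceil(total_steps * warmup_ratio)      # 375

# LR multiplier at the first step of each epoch.
for epoch in range(n_epochs):
    step = epoch * steps_per_epoch
    print(f"epoch {epoch + 1} starts at step {step}: "
          f"multiplier={linear_with_warmup(step, total_steps, warmup_steps):.3f}")
```

With these numbers the multiplier is 0 at the first step, about 0.83 when epoch 2 begins and about 0.42 when epoch 3 begins. The repeated passes therefore run at a lower LR than a constant schedule would use.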
PS: Here's a comparison of 3 epochs on 5k samples vs. 1 epoch on 15k samples, both with a constant learning rate, to show the difference that repeating the data makes. The sudden drops are absent in the single-epoch run.

[W&B run set 2: 3 epochs of 5k samples vs 1 epoch of 15k samples, constant LR]
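The two runs in this comparison see the same total number of training examples, so they take the same number of optimizer steps. The only difference is whether examples repeat. The quick sanity check below makes that explicit. It assumes no packing and an effective batch size of 8, and it uses a fixed (unshuffled) order purely to keep the sketch simple. Shuffling would not change any of the counts. The helper sample_stream is hypothetical.

```python
from collections import Counter

def sample_stream(n_unique, n_epochs):
    """Indices of training examples in the order they are seen (no shuffling)."""
    return [i for _ in range(n_epochs) for i in range(n_unique)]

effective_batch = 8
for label, stream in [("3 epochs x 5k", sample_stream(5000, 3)),
                      ("1 epoch x 15k", sample_stream(15000, 1))]:
    counts = Counter(stream)
    repeats = len(stream) - len(counts)  # samples shown after their first appearance
    print(f"{label}: {len(stream)} samples, {len(stream) // effective_batch} steps, "
          f"{repeats / len(stream):.0%} of samples are repeats")
```

Both runs come out at 15,000 samples and 1,875 steps. In the 3-epoch run two thirds of those samples are repeats, while the single-epoch run has none.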

PPS: All these experiments (plus a few not shown) cost me $2 to run on vast.ai: a 4090 instance for 5.5 hours, since I wanted my own GPU free for other stuff. GPU poor FTW.
geronimo
Thank you for this insightful analysis! However, it's unclear to me how high confidence during training spills over to validation. Referring to the blog post:
> Towards the end of that first 10% of the epoch, the training loss plummets, because the LR was high when these batches were seen during the first epoch, and the model has learned what they look like. The model quickly learns that it can very confidently guess the correct answer.
> But during this time, validation loss suffers. That's because although the model is getting very confident, it's not actually getting any better at making predictions.
I understand how the training loss plummets due to batches that have been seen before and the model making high-confidence predictions on those. But how does this influence the predictions on the validation set? It's not like there is some internal memory of how confident the previous 10 predictions were that affects the confidence of the next 10 predictions, right? What am I missing?
Thomas Capelle
This is the way, great work as always!