Experiment 950: How Does the Learning Rate Schedule in Pretraining Impact SFT?
An exploration of whether models with the same or similar loss can be harder to SFT
In our initial cooldowns, we had a version of the 8B model that looked like a pretty good language model: it scored well on evaluations like MMLU and generally had decent evaluation loss. However, when Ahmed went to SFT the model, it was really bad at instruction following, despite the same SFT recipe working great for OLMo Base and Llama 3 Base!
This freaked us out a bit - was the model somehow fundamentally broken in ways that made it bad for SFT? Fortunately, Ahmed found that sweeping the SFT learning rate much higher fixed many of these problems, and we also fixed a few bugs related to chat templates along the way. Ultimately this made the model pretty decent after SFT, but we still wanted to understand why we had the trouble to begin with.
One hypothesis was that the learning rate schedule was to blame: we used WSD and WSD-S throughout the Marin 8B training run, while Llama and OLMo 2 used cosine learning rate schedules. We also knew that the weight norms of the Marin model were really big compared with Llama and OLMo 2. Naturally, the things to test are: (1) does WSD cause these bigger norms, and (2) are these bigger norms the cause of the SFT difficulty?
Experiment Setup
To test this, we trained three 1.4B parameter models to 1 trillion tokens each. All models were trained on the baseline DCLM data mix. This was experiment 950:
Experiment File: https://marin.community/data-browser/experiment/?path=gs%3A%2F%2Fmarin-us-central2%2Fexperiments%2Fexp950_sft_amenability-050465.json
Three pre-training configurations are tested:
1. Linear schedule with high learning rate (1e-3) and z_loss
2. Cosine schedule with high learning rate (1e-3) and z_loss
3. Cosine schedule with lower learning rate (3e-4) and z_loss
Each resulting model is then fine-tuned using supervised fine-tuning (SFT) with the Tulu SFT configuration. Finally, we ran AlpacaEval on each SFT'd model to see whether there were any significant differences.
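For concreteness, here is a rough optax sketch of the two schedule shapes (Marin's training stack is JAX/optax-based). We read run 1's "linear" schedule as the WSD-style shape (warmup, long constant phase, linear cooldown) that the results below call WSD. The step counts, warmup length, and decay fraction are illustrative placeholders, not the actual experiment 950 settings, and the z_loss term used in all three runs isn't shown.

```python
# Hedged sketch of the three pretraining LR schedules; all numbers are illustrative.
import optax

TOTAL_STEPS = 250_000    # hypothetical: ~1T tokens at some fixed batch/sequence size
WARMUP_STEPS = 2_000     # hypothetical warmup length
HIGH_LR, LOW_LR = 1e-3, 3e-4

def wsd_schedule(peak_lr, decay_frac=0.2):
    """Warmup -> long constant phase -> linear cooldown (a WSD-style shape)."""
    decay_steps = int(TOTAL_STEPS * decay_frac)
    stable_steps = TOTAL_STEPS - WARMUP_STEPS - decay_steps
    return optax.join_schedules(
        schedules=[
            optax.linear_schedule(0.0, peak_lr, WARMUP_STEPS),  # warmup
            optax.constant_schedule(peak_lr),                   # stable phase
            optax.linear_schedule(peak_lr, 0.0, decay_steps),   # cooldown
        ],
        boundaries=[WARMUP_STEPS, WARMUP_STEPS + stable_steps],
    )

def cosine_schedule(peak_lr):
    """Standard warmup + cosine decay over the full run."""
    return optax.warmup_cosine_decay_schedule(
        init_value=0.0, peak_value=peak_lr,
        warmup_steps=WARMUP_STEPS, decay_steps=TOTAL_STEPS,
    )

schedules = {
    "wsd_high": wsd_schedule(HIGH_LR),     # run 1: WSD-style, peak LR 1e-3
    "cos_high": cosine_schedule(HIGH_LR),  # run 2: cosine, peak LR 1e-3
    "cos_low":  cosine_schedule(LOW_LR),   # run 3: cosine, peak LR 3e-4
}
```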
RQ 1: Does a higher-LR schedule lead to higher weight norms?
Yes! Perhaps unsurprisingly, the area under the learning rate schedule (its AUC) and the final parameter weight norms are highly correlated, and WSD leads to the highest weight norm. Since the Transformer has lots of scale-invariant layers in the middle, this doesn't hurt final loss: all three models reach very similar loss. However, this experiment confirms that the learning rate schedule we used explains at least part of our very large weight magnitudes.
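For reference, the "weight norm" here can be measured as the L2 norm over a checkpoint's parameters. A minimal JAX sketch, assuming the checkpoint has already been loaded as a pytree of arrays (the loading step depends on the checkpoint format and is omitted):

```python
# Minimal sketch: per-tensor and global parameter L2 norms of a checkpoint
# that has already been loaded as a JAX pytree of arrays.
import jax
import jax.numpy as jnp
import optax

def weight_norms(params):
    per_tensor = jax.tree_util.tree_map(jnp.linalg.norm, params)  # Frobenius norm per tensor
    global_norm = optax.global_norm(params)                       # single L2 norm over all params
    return per_tensor, global_norm
```

Running this on each of the three final checkpoints gives the per-run norms being compared here.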
[W&B panel: run set 6773]
RQ 2: Do bigger norms make the model worse at SFT?
[W&B panel: run set 6772]
It seems like the answer is also yes! The SFT train loss of the WSD run is consistently slightly higher than that of either cosine learning rate run, suggesting that the larger norms do in fact make this recipe work worse.
More importantly, across several runs the WSD checkpoint ends up with significantly lower AlpacaEval scores, suggesting the model is indeed less amenable to SFT:
| Run | WSD | COS-High | COS-Low |
|-----|----|-----------|---------|
| Run 1 | 0.31 | 1.21 | 1.28 |
| Run 2 | 0.8 | 0.73 | 0.83 |
| Run 3 | 0.52 | 0.61 | 1.02 |
| Run 4 | 0.64 | 0.98 | 1.26 |
| **Average** | 0.57 | 0.883 | 1.098 |
| **StdDev** | 0.210 | 0.267 | 0.214 |
Conclusions
Based on Ahmed's empirical finding that increasing the learning rate fixed a lot of these discrepancies, the following seems likely:
- Higher LR for longer -> higher weight norms.
- Higher weight norms mean you need more/larger updates to actually change model behavior (see the toy sketch after this list).
- Therefore, a model with higher norms can be hard to SFT with a recipe designed for a model with lower norms.
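To make the middle bullet concrete, here is a toy numeric sketch (illustrative dimensions and learning rate, not real model weights): an Adam-like optimizer produces updates of roughly the same size regardless of the weights, so the *relative* change to a high-norm weight matrix is proportionally smaller.

```python
# Toy illustration: the same SFT-scale update changes a high-norm matrix
# by a much smaller relative amount.
import numpy as np

rng = np.random.default_rng(0)
d = 1024        # hypothetical hidden size
lr = 2e-5       # a typical SFT-scale learning rate

# Adam-like updates have roughly +/- lr coordinates after normalization,
# so the update norm is about lr * d for a d x d matrix, independent of the weights.
update = lr * np.sign(rng.standard_normal((d, d)))

for scale in (1.0, 4.0):  # pretend the high-norm model's weights are ~4x larger
    W = scale * rng.standard_normal((d, d)) / np.sqrt(d)
    rel_change = np.linalg.norm(update) / np.linalg.norm(W)
    print(f"weight norm {np.linalg.norm(W):7.1f} -> relative update {rel_change:.1e}")
```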
All of this gets hidden during pretraining, since the weight norms don't seem to matter thanks to all the scale invariance! However, they start to matter a lot during SFT. In future work, we should explore whether tricks like this, which make the LR adapt to the weight norms, are useful and reduce the need to re-tune hyperparameters for our models during SFT.
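One concrete form such a trick could take (this is a hedged sketch of a LARS/LAMB-style trust ratio, not necessarily the specific technique we would adopt) is to rescale each tensor's update by its own weight norm, e.g. with optax's built-in transform:

```python
# Hedged sketch: a norm-adaptive AdamW variant built from optax primitives.
# optax.scale_by_trust_ratio() rescales each tensor's update by
# ||param|| / ||update|| (LARS/LAMB-style), so the effective step size
# scales with the weight norm instead of being swamped by it.
import optax

def norm_adaptive_adamw(learning_rate: float, weight_decay: float = 0.0):
    return optax.chain(
        optax.scale_by_adam(),                    # standard Adam moment estimates
        optax.add_decayed_weights(weight_decay),  # decoupled weight decay
        optax.scale_by_trust_ratio(),             # layer-wise rescale by weight norm
        optax.scale(-learning_rate),              # apply the (norm-relative) LR
    )
```

This is essentially LAMB; whether something like it actually removes the need to re-tune the SFT learning rate for high-norm checkpoints is exactly the open question.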