
Marin 32B Spike Fixing

Created on June 25 | Last edited on June 27
So about a month ago, Percy posted a version of this plot of our Marin 32B pretraining run. We got a lot of feedback, both public and private, that the spikes were bad.
(If your attention span has been irreparably damaged by chronic scrolling, the answer is QK Norm.)
Some told us we were already doomed; some were a bit more hedged. Some trusted people privately told us that things were actually probably fine.

[Plot: Marin 32B pretraining run, x-axis throughput/total_tokens (1T to 2T)]


Ideally, of course, there wouldn't be any spikes. But many of the people we talked to (and our own experience) suggested that if the model recovered quickly and the spike didn't really change the trajectory, it was fine.
And things did recover pretty quickly, in terms of step count.

[Plot: run set, 1 run]

We also thought we were okay because the model was massively outperforming our other test models on a flop-for-flop basis. Recall that the 8B run (the gray lines at the top) ended up on par with Llama 3.1. Orange is the 32B.


[Plot: run set, 8 runs]
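Comparing models "flop for flop" means lining them up by training compute rather than by tokens or steps. A common rule of thumb, assumed here and not necessarily what our dashboards use, is C ≈ 6·N·D for a model with N parameters trained on D tokens. A minimal sketch with nominal 8B and 32B sizes (illustrative round numbers, not the exact parameter counts):

```python
def train_flops(n_params: float, n_tokens: float) -> float:
    """Approximate training compute with the 6 * N * D rule of thumb
    (forward + backward pass; ignores attention FLOPs)."""
    return 6.0 * n_params * n_tokens


def tokens_at_matched_compute(flops: float, n_params: float) -> float:
    """Tokens a model with n_params must see to spend `flops` of compute."""
    return flops / (6.0 * n_params)


# Nominal sizes only: how far does a 32B model get on an 8B model's budget?
budget = train_flops(8e9, 1e12)  # compute an 8B model spends on 1T tokens
tokens_32b = tokens_at_matched_compute(budget, 32e9)
print(f"budget: {budget:.2e} FLOPs, 32B tokens at same budget: {tokens_32b:.2e}")
```

Under this approximation, at equal compute the 32B model has seen a quarter as many tokens as the 8B model.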

So, pretty good, right? It's running way ahead. To be clear, the training data is different (mostly Nemotron-CC instead of mostly DCLM), and the batch size is much, much larger: 32Mi tokens instead of 4Mi to 12Mi.
Nevertheless, we tried some interventions: tightening the grad norm clip, skipping steps with outlier loss or grad norm, update clipping, etc. Nothing seemed to make a huge difference.
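For concreteness, here is a hypothetical sketch of two of those interventions: global gradient-norm clipping, and skipping a step when its loss or gradient norm is an outlier relative to recent history. The z-score rule, window size, and threshold are assumptions for illustration, not Marin's actual settings; gradients are flat Python lists to keep it dependency-free.

```python
import math
from collections import deque


def clip_by_global_norm(grads, max_norm):
    """Scale every gradient tensor (a flat list here) so their joint L2 norm is <= max_norm.

    Returns the clipped gradients and the pre-clip global norm.
    """
    total = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    scale = min(1.0, max_norm / (total + 1e-6))
    return [[g * scale for g in tensor] for tensor in grads], total


class OutlierDetector:
    """Flags values more than z_thresh standard deviations above the mean of the
    last `window` accepted values. All thresholds are hypothetical placeholders."""

    def __init__(self, window=128, z_thresh=6.0, min_history=20):
        self.history = deque(maxlen=window)
        self.z_thresh = z_thresh
        self.min_history = min_history

    def is_outlier(self, value):
        if len(self.history) < self.min_history:
            return False  # not enough history to judge yet
        mean = sum(self.history) / len(self.history)
        var = sum((v - mean) ** 2 for v in self.history) / len(self.history)
        return value > mean + self.z_thresh * math.sqrt(var)

    def record(self, value):
        self.history.append(value)


loss_detector = OutlierDetector()
grad_norm_detector = OutlierDetector()


def should_skip_step(loss, grad_norm):
    """Skip the optimizer update if either the loss or the grad norm is an outlier.
    Only accepted steps are added to the running statistics."""
    skip = loss_detector.is_outlier(loss) or grad_norm_detector.is_outlier(grad_norm)
    if not skip:
        loss_detector.record(loss)
        grad_norm_detector.record(grad_norm)
    return skip
```

Keeping skipped steps out of the history stops a run of bad steps from inflating the baseline they are judged against.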
We did notice that update spikes always preceded loss spikes, similar to what https://arxiv.org/abs/2304.13013 found, so we were pretty hopeful about update clipping. (Updates can go to 0 because of our OLMo 2-style update skipping.)


[Plot: run set, 1 run]
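Update clipping caps how large a single optimizer step can be, independently of the gradient-norm clip. The paper linked above proposes a specific per-tensor variant for AdamW (StableAdamW); the sketch below is a simpler, hypothetical version that computes one Adam step for a flattened tensor and rescales the update so its RMS is at most `max_update_rms`. With a large β2 (0.999 here, illustrative), the second-moment estimate lags a sudden jump in gradient scale, so the raw update RMS can spike well above 1. The sketch also assumes that a skipped step returns a zero update and leaves the moments untouched, which is one way updates can drop to 0.

```python
import math


def adam_step_with_update_clip(grad, m, v, step, lr=1e-3, b1=0.9, b2=0.999,
                               eps=1e-8, max_update_rms=1.0, skip=False):
    """One Adam step for a single flattened tensor, with the update RMS capped.

    Returns (update, new_m, new_v, raw_update_rms). All hyperparameters are illustrative.
    """
    if skip:
        # Skipped step (assumed behavior): no parameter change, moments unchanged.
        return [0.0] * len(grad), list(m), list(v), 0.0
    new_m = [b1 * mi + (1 - b1) * g for mi, g in zip(m, grad)]
    new_v = [b2 * vi + (1 - b2) * g * g for vi, g in zip(v, grad)]
    # Standard Adam bias correction.
    m_hat = [mi / (1 - b1 ** step) for mi in new_m]
    v_hat = [vi / (1 - b2 ** step) for vi in new_v]
    raw = [mh / (math.sqrt(vh) + eps) for mh, vh in zip(m_hat, v_hat)]
    rms = math.sqrt(sum(u * u for u in raw) / len(raw))
    # Shrink the whole update uniformly if its RMS exceeds the cap.
    scale = min(1.0, max_update_rms / (rms + 1e-12))
    update = [-lr * scale * u for u in raw]
    return update, new_m, new_v, rms
```

For example, a tensor whose gradients were near zero for a long time and then suddenly jump to 1.0 produces a raw update RMS of about 3.2; the clip scales that step back to an RMS of 1.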


The Bad Spike

But then it happened: a Bad Spike, one where the loss didn't recover to the same plateau. Everyone has told us this is Bad News. In absolute terms, the loss spike was nothing, but the loss just didn't settle back. I dunno why. (Same y-axis as the previous "fine" spike.)


[Plot: run set, 1 run]
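One way to make "didn't recover to the same plateau" concrete is to compare a trailing-window mean of the loss after a spike with the window just before it. This is a hypothetical sketch with arbitrary window and tolerance: it reports how many steps the loss took to get back within tolerance (or None if it never did), plus the final plateau shift. It is crude, since pretraining loss keeps trending down; a careful comparison would use the pre-spike trend or a reference run rather than a flat baseline.

```python
def mean(xs):
    return sum(xs) / len(xs)


def spike_recovery(losses, spike_step, window=50, tol=0.01):
    """Compare the loss plateau before a spike at `spike_step` with what follows.

    Returns (steps_to_recover, plateau_shift):
      steps_to_recover: steps after the spike until a full trailing window (excluding
        the spike step) has mean loss within `tol` of the pre-spike mean, else None.
      plateau_shift: mean of the last `window` losses minus the pre-spike mean.
    """
    baseline = mean(losses[spike_step - window:spike_step])
    steps_to_recover = None
    for end in range(spike_step + 1 + window, len(losses) + 1):
        if mean(losses[end - window:end]) <= baseline + tol:
            steps_to_recover = end - spike_step
            break
    plateau_shift = mean(losses[-window:]) - baseline
    return steps_to_recover, plateau_shift
```

On a toy series, a "fine" spike (back to 2.0 right after a one-step jump to 3.0) recovers within one window, while a "bad" one that settles at 2.1 never recovers and shows a plateau shift of 0.1.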


So we tried some stuff. We tried skipping the problematic step. We tried Muon (which looked great until it didn't; we need to spend more time with it at small scales). We could have tried other things, but it was time to take drastic action to end the spikes.

[Plot: run set, 3 runs]

Aside: the Muon run was still warming up its Adam params here, so its loss was lower. It also seems to have made its bad shift a little later, which might be worth investigating.

QK Norm to the rescue

It was time to do what everyone else had already learned but we were too proud, too foolish, to try. (After all, the 22B and 70B trials were buttery smooth! Eval losses were ahead of schedule!) It was time to add QK Norm.
Now, look, we knew QK Norm was a good idea. We just thought it wasn't a **necessary** idea, not for us. We were different.
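For readers who haven't met it: QK Norm normalizes each attention head's query and key vectors before their dot product, which bounds the attention logits no matter how large the raw activations grow. The sketch below assumes RMSNorm with a learnable per-channel gain, a single head, and no positional encoding; real implementations differ in the choice of norm and in where it sits relative to RoPE.

```python
import math


def rms_norm(x, gain, eps=1e-6):
    """RMSNorm over one head's channels, with a learnable per-channel gain."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v / rms for g, v in zip(gain, x)]


def softmax(logits):
    top = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(q, keys, qk_norm=True, q_gain=None, k_gain=None):
    """Softmax attention weights of one query over a list of keys (single head).

    With qk_norm=True, q and every k pass through RMSNorm before the scaled dot product.
    """
    d = len(q)
    if qk_norm:
        q_gain = q_gain if q_gain is not None else [1.0] * d
        k_gain = k_gain if k_gain is not None else [1.0] * d
        q = rms_norm(q, q_gain)
        keys = [rms_norm(k, k_gain) for k in keys]
    logits = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    return softmax(logits)


q = [100.0, 0.0, 0.0, 0.0]
keys = [[100.0, 0.0, 0.0, 0.0], [0.0, 100.0, 0.0, 0.0]]
print(attention_weights(q, keys, qk_norm=False))  # logits 5000 vs 0: saturated
print(attention_weights(q, keys, qk_norm=True))   # logits ~2 vs 0: peaked, not saturated
```

With unit gains and head dimension d, each normalized vector has an L2 norm of about √d, so every logit lies within about ±√d (±2 in the toy example), whereas the unnormalized logit here is 5000.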
Anyway, let's fix it. **BUT** we aren't going to start from scratch. That is not the patented Marin Tootsie Process™️ way. No flop left behind. Add QK Norm, warm-start from the current checkpoint, keep the optimizer states, and just redo the learning-rate warmup. (Worst case, it blows up and we eventually throw it out.)
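The warm-start recipe boils down to checkpoint surgery plus a fresh learning-rate warmup. Here is a hypothetical sketch on a plain-dict checkpoint: the key names (`params`, `opt_state`, `layers.{i}.attn.q_norm_gain` and so on) and the linear warmup shape are illustrative assumptions, not Marin's actual checkpoint format or schedule. New gains start at 1.0 with zeroed Adam moments; everything else, including the existing moments, carries over unchanged.

```python
import copy


def add_qk_norm_to_checkpoint(ckpt, num_layers, head_dim):
    """Return a copy of a (hypothetical) checkpoint with QK-norm gains added.

    Existing params and Adam moments are kept as-is; the new gains start at 1.0
    and their moments start at zero, as if those params had just been created.
    """
    new = copy.deepcopy(ckpt)
    for layer in range(num_layers):
        for name in ("q_norm_gain", "k_norm_gain"):
            key = f"layers.{layer}.attn.{name}"  # hypothetical naming scheme
            new["params"][key] = [1.0] * head_dim
            new["opt_state"]["m"][key] = [0.0] * head_dim
            new["opt_state"]["v"][key] = [0.0] * head_dim
    return new


def rewarmup_lr(step, resume_step, peak_lr, warmup_steps):
    """Linear re-warmup from near 0 to peak_lr starting at resume_step, then constant.

    Whatever decay schedule the run used afterwards would take over from here; omitted.
    """
    since = step - resume_step
    if since < 0:
        raise ValueError("step precedes resume_step")
    return peak_lr * min(1.0, (since + 1) / warmup_steps)
```

Starting the gains at 1.0 does not make the warm-started model compute the same function as before, since normalizing q and k changes the attention logits.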

TL;DR: It caught up real quick, in about 6.5B tokens, overshooting (due to the warmup) before settling in at just a bit better.

[Plot: run set, 2 runs]
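To put "about 6.5B tokens" in step terms: assuming the batch stayed at 32Mi tokens per step (2^25 = 33,554,432 tokens) after the warm start, the catch-up took a little under 200 optimizer steps. The arithmetic:

```python
MI = 2 ** 20  # binary prefix: 1Mi = 1,048,576
batch_tokens = 32 * MI  # assumed unchanged after the warm start
catch_up_tokens = 6.5e9

catch_up_steps = catch_up_tokens / batch_tokens
print(f"{batch_tokens:,} tokens/step -> {catch_up_steps:.0f} steps to catch up")
```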

And now it's looking great! So, norms good. Also, you can totally just change something major mid-run and it'll be okay. Or, as the meme goes, you can just do things.

[Plot: run set, 2 runs]



[Plot: run set, 2 runs]