ZLoss vs Not 1.4B
Created on April 4 | Last edited on May 18
The goal of this experiment was to decide if we should use z-loss by default for future Marin runs.
Z-Loss
In this context, $Z$ is the normalizer of the final softmax over the vocabulary for an LLM:

$$Z = \sum_i e^{z_i}$$

where $z_i$ is the logit for the $i$-th token type (for that position):

$$z = W h$$

where $W$ (also called the `lm_head` in code) converts the final embedding $h$ into a vector of logits over the vocabulary. Adding z-loss amounts to adding this term to the loss:

$$\mathcal{L}_Z = \alpha \log^2 Z$$

which shrinks $\log^2 Z$ toward 0. ($\alpha$ is a constant that controls the strength.) Because $Z$ is a sum of exponentials, $\log Z$ is dominated by the largest logits, so z-loss basically penalizes the biggest entries of $z$ if they get big.
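For reference, here is a minimal sketch of a next-token loss with the z-loss term added, written in JAX. It is meant to illustrate the formula above rather than reproduce the actual Marin/Levanter implementation; `z_loss_weight` plays the role of $\alpha$.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def loss_with_z_loss(logits, targets, z_loss_weight=1e-4):
    """Cross-entropy plus z-loss for one sequence.

    logits:  [seq_len, vocab_size] floats (the z's above)
    targets: [seq_len] int token ids
    """
    # log Z at each position: log-sum-exp over the vocabulary
    log_z = logsumexp(logits, axis=-1)
    # standard cross-entropy: log Z - z_target
    target_logits = jnp.take_along_axis(logits, targets[:, None], axis=-1)[:, 0]
    cross_entropy = log_z - target_logits
    # z-loss: alpha * (log Z)^2, pulling log Z toward 0
    z_loss = z_loss_weight * jnp.square(log_z)
    return jnp.mean(cross_entropy + z_loss)
```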
Z-loss is often added for stability reasons: to keep training runs from diverging (e.g. [PaLM paper, page 10], [OLMo 2, page 6]). (Big numbers are often bad for training stability.)
Z-Loss in Marin
We didn't have significant problems with stability in the 8B run, but we did see strange, slow increases in the training loss at very low LR.
So, we tried two different configs at 1.4B parameters for 42B tokens, with and without z-loss. We used 1e-4 for the z-loss penalty, which is the same value as OLMo 2 7B. We used what we call the DCLM mix, which is a mix of DCLM Baseline, Proofpile 2, and StarCoder.
Results
The two configs were:
- cosine, wd=0.1 (dclm_mix-cosine-1.4b [gray] vs dclm_mix-1.4b-zloss [purple])
- wsd, wd=0.05 (dclm_mix-wsd-1.4b [green] vs dclm_mix-1.4b-wsd-zloss [red])
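For concreteness, here is a hypothetical sketch of the four runs listed above as plain Python dicts. The field names (`lr_schedule`, `weight_decay`, `z_loss_weight`, and so on) are illustrative and not the actual Marin config schema; the values come from the descriptions above.

```python
# Illustrative only: these field names are NOT the real Marin config schema.
base = dict(
    model_size="1.4b",
    train_tokens=42_000_000_000,   # 42B tokens
    data_mix="dclm_mix",           # DCLM Baseline + Proofpile 2 + StarCoder
)

runs = [
    dict(base, name="dclm_mix-cosine-1.4b",    lr_schedule="cosine", weight_decay=0.1,  z_loss_weight=0.0),
    dict(base, name="dclm_mix-1.4b-zloss",     lr_schedule="cosine", weight_decay=0.1,  z_loss_weight=1e-4),
    dict(base, name="dclm_mix-wsd-1.4b",       lr_schedule="wsd",    weight_decay=0.05, z_loss_weight=0.0),
    dict(base, name="dclm_mix-1.4b-wsd-zloss", lr_schedule="wsd",    weight_decay=0.05, z_loss_weight=1e-4),
]
```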
Adding z-loss has no discernible impact on either config, in terms of either token loss or eval scores. (There was a difference between cosine and WSD, but that wasn't our concern here.)
However, surprisingly, z-loss increases the norm of the `lm_head` by quite a lot?!? (This is counterintuitive because a bigger `lm_head` means the z's are likely to get bigger.)
However, it decreases the norm of the transformer's final layer norm weights, which makes sense if you think about it: layer norms are typically not weight decayed, but z-loss acts somewhat like a penalty on the operator norm of the combined matrix $W \,\mathrm{diag}(\gamma)$, where $W$ is the `lm_head` weight matrix and $\gamma$ is the final layer norm weight, not on $W$ alone.
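To see why those two norms can move in opposite directions, note that the logits depend only on the product of the `lm_head` weights and the final norm's scale, so z-loss constrains that product rather than either factor alone. Here is a small JAX sketch (hypothetical shapes and an RMSNorm-style final norm, not Marin code) demonstrating the invariance:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

d_model, vocab = 8, 32
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

h = jax.random.normal(k1, (d_model,))                # final hidden state
gamma = jnp.abs(jax.random.normal(k2, (d_model,)))   # final layer-norm scale
W = jax.random.normal(k3, (vocab, d_model))          # lm_head weight

def rms_norm(x, scale):
    return scale * x / jnp.sqrt(jnp.mean(x**2) + 1e-6)

def log_z(W, gamma, h):
    return logsumexp(W @ rms_norm(h, gamma))

# Scaling lm_head up and the norm scale down by the same factor leaves the
# logits -- and therefore log Z and the z-loss -- unchanged:
c = 10.0
print(log_z(W, gamma, h), log_z(c * W, gamma / c, h))  # identical values
```

So under the z-loss penalty the optimizer is free to trade norm between `lm_head` and the final layer norm weights, which is consistent with the behavior we observed.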
Conclusion
Turning z-loss on by default is probably fine, but its behavior is pretty surprising.