
ZLoss vs Not 1.4B

The goal of this experiment was to decide if we should use z-loss by default for future Marin runs.


Z-Loss

In this context, `Z` is the normalizer of the final softmax over the vocabulary for an LLM:

\log Z = \log\left( \sum_{i=1}^{V} \exp z_i \right)

where $z_i$ is the logit for the $i$-th token type (for that position):

z_i = \left[\, \mathbf{W} \cdot \mathrm{Transformer}(\mathbf{x}) \,\right]_i

where $\mathbf{W}$ (also called the `lm_head` in code) converts the final embeddings into vectors over the vocabulary. Adding z-loss amounts to adding this term to the loss:

\alpha \left( \log Z \right)^2

which pushes $\log Z$ toward 0. ($\alpha$ is a constant that controls the strength.) Because $Z$ is dominated by the largest logits, z-loss basically penalizes the biggest entries of $\mathbf{z}$ if they get big.
Z-loss is often added for stability reasons: to prevent training runs from diverging (e.g. [PaLM paper, page 10], [OLMo 2, page 6]). (Big numbers are often bad for training stability.)
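
For reference, here is a minimal JAX sketch of the combined objective. This is not Marin's actual training code; the function name, tensor shapes, and the `alpha=1e-4` default are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def cross_entropy_with_zloss(logits, targets, alpha=1e-4):
    """Token-level cross entropy plus the z-loss penalty alpha * (log Z)^2.

    logits:  [seq_len, vocab] array of unnormalized scores z_i
    targets: [seq_len] int array of gold token ids
    """
    log_z = jax.nn.logsumexp(logits, axis=-1)        # log Z at each position
    log_probs = logits - log_z[..., None]            # log softmax
    nll = -jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    return jnp.mean(nll + alpha * log_z**2)          # average over positions
```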

Z-Loss in Marin

We didn't have significant stability problems in the 8B run, but we did see strange, slow increases in the training loss at very low LR.
So we ran two different configs at 1.4B parameters for 42B tokens, with and without z-loss. We used 1e-4 for the z-loss coefficient, which is the same as OLMo 2 7B. We used what we call the DCLM mix, which is a mix of DCLM Baseline, Proofpile 2, and StarCoder.
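
For concreteness, the setup roughly corresponds to a config like the sketch below. The dataclass and its field names (`z_loss_weight`, `data_mix`, etc.) are hypothetical and only illustrate the knobs we varied; this is not Marin's actual config schema.

```python
from dataclasses import dataclass

@dataclass
class TrainConfig:
    # Hypothetical sketch of the experiment knobs -- not Marin's real config schema.
    model_size: str = "1.4b"
    num_tokens: int = 42_000_000_000     # 42B training tokens
    lr_schedule: str = "cosine"          # or "wsd"
    weight_decay: float = 0.1            # 0.05 for the wsd runs
    z_loss_weight: float = 1e-4          # 0.0 disables z-loss; 1e-4 matches OLMo 2 7B
    data_mix: tuple = ("dclm_baseline", "proofpile_2", "starcoder")
```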

Results

The two configs were:
  • cosine, wd=0.1 (dclm_mix-cosine-1.4b [gray] vs dclm_mix-1.4b-zloss [purple])
  • wsd, wd=0.05 (dclm_mix-wsd-1.4b [green] vs dclm_mix-1.4b-wsd-zloss [red])
Adding z-loss has no discernible impact on either config, in terms of token loss or eval scores. (There was a difference between cosine and wsd, but that wasn't what we were looking at here.)
However, surprisingly, z-loss increases the norm of the lm_head by quite a lot! (This is counterintuitive because a bigger lm_head means the $z_i$'s are likely to get bigger.)

However, it decreases the final layer norm weight of the transformer, which makes sense if you think about it: layer norms are typically not weight decayed, but the z-loss acts somewhat like a penalty on the operator norm of the product $\mathbf{W}_{LM}\,\mathrm{diag}(\mathbf{w}_{LN})$, where $\mathbf{W}_{LM}$ is the language modeling head weight matrix and $\mathbf{w}_{LN}$ is the final layer norm weight of the transformer, and not on $\mathbf{W}_{LM}$ alone.
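
One way to see why the lm_head norm is free to grow, assuming the final norm is an RMS/LayerNorm with a learned scale $\mathbf{w}_{LN}$ and no bias: the logits depend only on the product of the two weights, since for any $c > 0$

\mathbf{W}_{LM}\,\mathrm{diag}(\mathbf{w}_{LN})\,\tilde{\mathbf{h}} = (c\,\mathbf{W}_{LM})\,\mathrm{diag}(\mathbf{w}_{LN}/c)\,\tilde{\mathbf{h}}

where $\tilde{\mathbf{h}}$ is the normalized final hidden state. The z-loss only sees the logits, so it constrains the product and is free to trade a smaller final layer norm weight for a larger lm_head.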

Conclusion

Defaulting z-loss on is probably fine, but its behavior is pretty surprising.
