MoE vs Dense 1b
Training setup
- TPU: v5e-128
- Data: 2T tokens
- Models:
- OLMoE 1b-7b (8 activated experts out of 64) -> ~1B activated parameters (see the routing sketch after this list)
- OLMo 1b (dense)
- Note that these models do not exactly copy the OLMo architectures (e.g. I didn't use QK-Norm); the goal was mainly to match parameter counts.
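For reference, here is a minimal sketch of the kind of top-k expert routing described above, assuming a softmax router and dense (all-experts) dispatch for clarity. The layer sizes, parameter names, and the simple dispatch strategy are illustrative assumptions, not the actual training code.

```python
# Minimal top-k MoE routing sketch (illustrative, not the actual training code).
# Assumes a softmax router; computes every expert and gathers the chosen ones,
# which is simple but wasteful -- a real kernel only runs the selected experts.
import jax
import jax.numpy as jnp

NUM_EXPERTS = 64   # total experts, as in the OLMoE-style config above
TOP_K = 8          # activated experts per token
D_MODEL = 2048     # illustrative hidden size
D_FF = 1024        # illustrative per-expert FFN width (fine-grained experts)


def moe_layer(params, x):
    """x: [tokens, d_model] -> [tokens, d_model]."""
    # Router: per-token scores over all experts, keep the top-k.
    logits = x @ params["router"]                       # [tokens, num_experts]
    topk_vals, topk_idx = jax.lax.top_k(logits, TOP_K)  # [tokens, k]
    gates = jax.nn.softmax(topk_vals, axis=-1)          # renormalize over the chosen k

    # Run each expert's FFN on all tokens, then gather the k chosen per token.
    def expert_fn(w_in, w_out):
        return jax.nn.gelu(x @ w_in) @ w_out            # [tokens, d_model]

    all_out = jax.vmap(expert_fn)(params["w_in"], params["w_out"])  # [experts, tokens, d_model]
    chosen = jnp.take_along_axis(
        jnp.swapaxes(all_out, 0, 1),                    # [tokens, experts, d_model]
        topk_idx[..., None], axis=1)                    # [tokens, k, d_model]
    return jnp.sum(gates[..., None] * chosen, axis=1)   # gate-weighted combine


key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)
params = {
    "router": jax.random.normal(k1, (D_MODEL, NUM_EXPERTS)) * 0.02,
    "w_in": jax.random.normal(k2, (NUM_EXPERTS, D_MODEL, D_FF)) * 0.02,
    "w_out": jax.random.normal(k3, (NUM_EXPERTS, D_FF, D_MODEL)) * 0.02,
}
tokens = jax.random.normal(k4, (16, D_MODEL))
print(moe_layer(params, tokens).shape)  # (16, 2048)
```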
Results
- The MoE significantly outperforms its dense counterpart at the same number of training tokens.
- Small, fine-grained experts lead to even worse MFU :( However, by my calculation we're still a bit faster than OLMoE's reported per-device speed, likely because their implementation is in PyTorch; it's just that our dense model in JAX is much faster than the PyTorch one :) (See the MFU back-of-envelope sketch after this list.)
- When the x-axis is wall-clock time, there appears to be a crossover at ~20 hours where the MoE overtakes the dense model, even though dense training is much faster.
- The performance difference also shows up in evaluation: see the bottom three plots for three classic LM eval tasks, where the MoE significantly outperforms the dense model.
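For context, this is the kind of back-of-envelope MFU estimate behind the throughput comparison above. The throughput number is a placeholder, not a measured value from this run; the v5e peak-FLOPs figure and the 6-FLOPs-per-parameter-per-token rule of thumb are standard approximations.

```python
# Back-of-envelope MFU estimate (placeholder numbers, not measured values).
# MFU = achieved model FLOPs per second / hardware peak FLOPs per second.

ACTIVATED_PARAMS = 1.0e9      # ~1B activated params (MoE) or total params (dense)
PEAK_FLOPS_PER_CHIP = 197e12  # TPU v5e peak bf16 FLOPs/s per chip (published spec)
NUM_CHIPS = 128               # v5e-128 slice
TOKENS_PER_SEC = 1.0e6        # placeholder observed training throughput

# Rule of thumb: ~6 FLOPs per activated parameter per token (forward + backward),
# ignoring attention FLOPs for simplicity.
achieved_flops = 6 * ACTIVATED_PARAMS * TOKENS_PER_SEC
peak_flops = PEAK_FLOPS_PER_CHIP * NUM_CHIPS
mfu = achieved_flops / peak_flops
print(f"MFU ~= {mfu:.1%}")
```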