An Evaluation of Transformer Variants
Training of different transformer variants for text-to-image generation with DALL-E-mini
Introduction
Our goal today is simple: we're going to train and compare a few transformer variants for text-to-image generation with DALL-E-mini.
When training a standard BART model, we observe instability: the evaluation loss can suddenly spike during the first epoch, whether we use a "Pre-LayerNorm" or "Post-LayerNorm" architecture.
We can recover by resuming from the latest checkpoint and decreasing the learning rate, but this ends up limiting our training, and the spikes can happen more frequently at larger scales. We therefore look for alternative transformer architectures that are more stable during training.
Our model uses the Distributed Shampoo optimizer, as it was shown to converge faster and yield better results on our model (see Evaluation of Distributed Shampoo).
Experimental Setup
- 1 model per device: each model is about 400M parameters
- batch size: the maximum that fits in memory (increased until OOM)
- gradient accumulation steps: 2 or 3 depending on the model, so that each update uses at least 800 samples
- learning rate: constant schedule with 2,000 warmup steps
- hardware: TPU v3-8
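To make the setup concrete, here is an illustrative sketch of the run configuration as a plain Python dict. The key names are hypothetical and do not correspond to actual DALL-E-mini training-script flags.

```python
# Illustrative sketch only: these key names are hypothetical and are not
# actual dalle-mini training-script flags.
run_config = {
    "model_size_params": 400_000_000,   # ~400M parameters, one model replica per device
    "per_device_batch_size": None,      # set to the largest value that fits before OOM
    "gradient_accumulation_steps": 3,   # 2 or 3, so each update sees at least 800 samples
    "learning_rate_schedule": "constant",
    "warmup_steps": 2000,
    "optimizer": "distributed_shampoo",
    "hardware": "TPU v3-8",
}
```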
Transformer Variants
The different variants we're looking at today implement a combination of:
- Post-LN: standard BART model with post-LN, like the original transformer (see the block sketches below)
- Pre-LN: the LayerNorms are moved to the pre-LN position
- Sandwich-LN: pre-LN with an additional LayerNorm at the output of each sub-layer
- DeepNet: post-LN with a scaling factor applied at weight initialization and in the residual branches
- Swin Transformer v2: post-LN in the non-residual branch + cosine attention with a temperature tau + relative position embeddings (the ones from Swin v1, as we don't need the continuous ones)
- NormFormer: pre-LN + LN after attention + LN after the activation (middle of the FFN block) + head scale (no residual scaling, as the paper mentions it does not work on large models)

In addition, we evaluate the following modifications:
- GLU variant instead of the standard FFN (about 2/3 of the parameters per dense layer to keep the same total number of parameters)
- RMSNorm instead of the standard LayerNorm
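As a reference for how the LayerNorm placements above differ, here is a minimal Flax-style sketch of a single residual sub-layer in Post-LN and Pre-LN form. It is a simplified illustration, with a single dense projection standing in for the attention/FFN sub-layer, and is not the actual DalleBart code.

```python
import flax.linen as nn


class PostLNBlock(nn.Module):
    """Post-LN (original transformer): LayerNorm after the residual addition."""
    d_model: int = 1024

    @nn.compact
    def __call__(self, x):
        # a single dense projection stands in for the attention/FFN sub-layer
        sublayer_out = nn.Dense(self.d_model, use_bias=False)(x)
        return nn.LayerNorm()(x + sublayer_out)


class PreLNBlock(nn.Module):
    """Pre-LN: LayerNorm inside the residual branch, before the sub-layer."""
    d_model: int = 1024

    @nn.compact
    def __call__(self, x):
        sublayer_out = nn.Dense(self.d_model, use_bias=False)(nn.LayerNorm()(x))
        return x + sublayer_out
```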

Results
Summary
TL;DR, with links to the relevant sections below:
These results apply specifically to the DalleBart model from DALL-E-mini and use the Distributed Shampoo optimizer, which already improves model stability considerably.
Note: I don't think there's a magic recipe that works all the time. I'm mainly reporting what worked in my specific case.
Please comment in this report, raise an issue or submit a PR if you notice any bugs!
Detailed Comparisons
Final LN
In Pre-LN-type architectures (all except Post-LN and DeepNet), the model does not converge unless there is a final LayerNorm in the decoder.
Using a final LayerNorm in the encoder also helps convergence.
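As an illustration of where that final LayerNorm sits, here is a minimal pre-LN stack sketch. Layer internals are reduced to a single dense sub-layer; this is not the actual DalleBart encoder or decoder.

```python
import flax.linen as nn


class PreLNStack(nn.Module):
    """Minimal pre-LN stack showing where the final LayerNorm goes.
    Layer internals are reduced to a single dense sub-layer for brevity."""
    num_layers: int = 4
    d_model: int = 1024

    @nn.compact
    def __call__(self, x):
        for _ in range(self.num_layers):
            h = nn.LayerNorm()(x)                              # pre-LN
            x = x + nn.Dense(self.d_model, use_bias=False)(h)  # residual add
        # without this final LayerNorm, pre-LN models did not converge
        return nn.LayerNorm()(x)
```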
Bias in Dense layers
Using biases in dense layers increases training time per step by 15% and hurts convergence.
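For reference, dropping the bias is a one-line change in Flax; the feature size below is illustrative, not the actual DalleBart hidden size.

```python
import flax.linen as nn

# Dense projection without a bias term, as used in these runs.
# The feature size is illustrative, not the actual DalleBart hidden size.
proj = nn.Dense(features=1024, use_bias=False)
```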
DeepNet vs Post-LN
DeepNet only slightly improves stability over vanilla Post-LN by adding a scaling factor at weight initialization and in the residual branches.
If you use Post-LN, you may as well implement it.
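Here is a minimal sketch of the residual scaling, with a single dense projection standing in for the sub-layer. The alpha value is illustrative (DeepNet derives it from depth), and DeepNet's companion scaling of the weight initialization is omitted.

```python
import flax.linen as nn


class DeepNetBlock(nn.Module):
    """Post-LN block with DeepNet-style residual scaling.
    alpha scales the residual stream; DeepNet also scales selected weight
    initializations by a companion factor (omitted here)."""
    d_model: int = 1024
    alpha: float = 2.0  # derived from depth in the paper; value here is illustrative

    @nn.compact
    def __call__(self, x):
        # a single dense projection stands in for the attention/FFN sub-layer
        sublayer_out = nn.Dense(self.d_model, use_bias=False)(x)
        return nn.LayerNorm()(self.alpha * x + sublayer_out)
```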
NormFormer vs Sandwich-LN
NormFormer is more stable than Sandwich-LN, which has the same LayerNorm positions except in the MLP block.
NormFormer variants
NormFormer's head scale hurts convergence, probably because it is followed by a LayerNorm. It also slows down training.
The learnt residual connections were not used, as the paper mentions they are not beneficial at larger scale.
We also test a force_scale variant, which keeps the LayerNorm's learnt scale even when it is directly followed by a dense layer. It can slightly improve results, most likely because it indirectly acts as a lower learning rate on parameter updates: (learning_rate * scale_1) * (learning_rate * scale_2). Overall, we do not recommend it: we should be able to get the same improvement through the learning rate schedule, and it adds about 5% to training time per step.
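Below is a minimal sketch of a NormFormer-style FFN block with the extra mid-FFN LayerNorm, plus a hypothetical force_scale flag matching the variant described above. Dimensions are illustrative and this is not the actual DalleBart module.

```python
import flax.linen as nn


class NormFormerFFN(nn.Module):
    """NormFormer-style FFN block: pre-LN plus an extra LayerNorm after the
    activation in the middle of the block. `force_scale` keeps the LayerNorm's
    learnt scale even though a Dense layer follows directly."""
    d_model: int = 1024
    d_ff: int = 4096
    force_scale: bool = False

    @nn.compact
    def __call__(self, x):
        residual = x
        x = nn.LayerNorm(use_scale=self.force_scale)(x)   # pre-LN
        x = nn.Dense(self.d_ff, use_bias=False)(x)
        x = nn.gelu(x)
        x = nn.LayerNorm(use_scale=self.force_scale)(x)   # mid-FFN LN (NormFormer)
        x = nn.Dense(self.d_model, use_bias=False)(x)
        return residual + x
```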
GLU variants
GLU variants are always beneficial, even though they require extra memory for the same number of parameters (using 2/3 per dense layer vs a standard FFN) and therefore a lower batch size.
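A minimal sketch of the GLU-variant FFN, with the hidden size scaled to roughly 2/3 of a standard FFN to keep the parameter count constant. Module and parameter names are illustrative.

```python
from typing import Callable

import flax.linen as nn


class GLUFeedForward(nn.Module):
    """GLU-variant FFN: the hidden projection is split into a gated branch and
    a linear branch, with the hidden size scaled to ~2/3 of a standard FFN so
    the total parameter count stays the same."""
    d_model: int = 1024
    d_ff: int = 4096
    activation: Callable = nn.gelu  # GeGLU; use nn.swish for SwiGLU

    @nn.compact
    def __call__(self, x):
        hidden = int(self.d_ff * 2 / 3)
        gate = self.activation(nn.Dense(hidden, use_bias=False)(x))
        value = nn.Dense(hidden, use_bias=False)(x)
        return nn.Dense(self.d_model, use_bias=False)(gate * value)
```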
GLU activation functions
We try different activation functions in combination with the GLU variant: GeLU, Swish, and SmeLU.
GeGLU (with GeLU) and SwiGLU (with Swish) perform the best. SwiGLU is slightly ahead up to a certain point because it is faster to compute, but GeGLU eventually appears more stable.
RMSNorm
RMSNorm performs better than LayerNorm for most of training.
However, on the best runs, RMSNorm eventually plateaus earlier than LayerNorm.
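For reference, a minimal RMSNorm sketch in Flax (not the exact implementation used in DALL-E-mini): it rescales by the root mean square of the features with a learnt scale, but has no mean-centering and no bias, unlike LayerNorm.

```python
import jax.numpy as jnp
import flax.linen as nn


class RMSNorm(nn.Module):
    """RMSNorm: rescales by the root mean square of the features with a learnt
    scale; unlike LayerNorm there is no mean-centering and no bias."""
    eps: float = 1e-6

    @nn.compact
    def __call__(self, x):
        scale = self.param("scale", nn.initializers.ones, (x.shape[-1],))
        rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + self.eps)
        return x / rms * scale
```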
Swin v2
Swin v2 trains as well as NormFormer. We used a fixed temperature tau in this test.
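A minimal sketch of cosine attention with a fixed tau, as used in this test; the tau value, epsilon, and function name are illustrative (Swin v2 itself learns tau).

```python
import jax.numpy as jnp


def cosine_attention_logits(q, k, tau=0.1):
    """Swin v2-style cosine attention logits: cosine similarity between queries
    and keys divided by a temperature tau (fixed here, learnt in Swin v2)."""
    q = q / (jnp.linalg.norm(q, axis=-1, keepdims=True) + 1e-6)
    k = k / (jnp.linalg.norm(k, axis=-1, keepdims=True) + 1e-6)
    return jnp.einsum("...qd,...kd->...qk", q, k) / tau
```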
Swin Relative Positions
We encode relative positions as learnt parameters in the attention layers, following Swin Transformers, and compare them with regular absolute position embeddings.
Swin relative positions are slower to train.
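A sketch of how a learnt relative-position bias could be added to the attention logits for 1D text sequences, in the spirit of Swin v1's relative position embeddings. Shapes and names are illustrative, not the DalleBart implementation.

```python
import jax.numpy as jnp
import flax.linen as nn


class RelativePositionBias(nn.Module):
    """Learnt relative-position bias added to attention logits, in the spirit
    of Swin v1's relative position embeddings, adapted to 1D text sequences."""
    num_heads: int = 16
    max_len: int = 256

    @nn.compact
    def __call__(self, q_len: int, k_len: int):
        # one learnt bias per (relative distance, head)
        table = self.param(
            "rel_bias", nn.initializers.zeros, (2 * self.max_len - 1, self.num_heads)
        )
        rel = jnp.arange(q_len)[:, None] - jnp.arange(k_len)[None, :] + self.max_len - 1
        return table[rel].transpose(2, 0, 1)  # (heads, q_len, k_len), added to logits
```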
SinkFormers
We can only use the SinkFormer in the encoder, not in the decoder, because of causal attention; however, it does not help the model.
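For intuition, here is a minimal sketch of SinkFormer-style attention weights: a softmax followed by a few Sinkhorn iterations that normalize rows and columns alternately, which is incompatible with a causal mask and hence encoder-only here. The function name and iteration count are illustrative.

```python
import jax
import jax.numpy as jnp


def sinkformer_attention_weights(logits, n_iters=3):
    """SinkFormer-style attention weights: start from a row-wise softmax, then
    run a few Sinkhorn iterations so rows and columns both sum to ~1.
    Incompatible with a causal mask, hence encoder-only."""
    weights = jax.nn.softmax(logits, axis=-1)
    for _ in range(n_iters):
        weights = weights / (weights.sum(axis=-2, keepdims=True) + 1e-6)  # normalize columns
        weights = weights / (weights.sum(axis=-1, keepdims=True) + 1e-6)  # normalize rows
    return weights
```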
All experiments
[Interactive W&B panel: run set of 67 experiments]
Acknowledgements
- Rohan Anil for setting up the Distributed Shampoo optimizer and for his continuous feedback
- Phil Wang, who has also implemented many variants in PyTorch and shares his insights through x-transformers
- the Google TPU Research Cloud (TRC) program for providing computing resources, and Pedro Cuenca for running some of these experiments