DALL·E Mega - Training Journal
Created on April 12 | Last edited on June 9
Table of contents
- Resources
- Journal log
- Learning Rate search
- DALL·E Mega - Training
- Training parameters
- Model configuration
- DALL·E Mega vs DALL·E Mini
- Sample Predictions
- Evolution of Predictions over time
- DrawBench predictions
Resources
Journal log
Learning Rate search
DALL·E Mega - Training
Several runs are present:
- when training seems to plateau, we lower the learning rate (constant schedule with warmup)
- we enabled dropout at around 500M train/samples and removed it at around 750M train/samples (it did not seem useful)
- at 500M train/samples, we have a higher eval/loss because we switched the validation set to a shard of the training set, due to concerns about a different data distribution
- closer to the end of training, we started using exponential learning-rate decay
- at 800M train/samples, we switched to full-precision training (only a 20% slowdown)
- since 1.2B train/samples, we use two validation sets: the original one and the training shard we had switched to
- at 1.8B train/samples, we started the exponential decay (see the schedule sketch below)
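For reference, the overall learning-rate schedule described above (linear warmup to a constant value, followed by exponential decay late in training) can be sketched with optax. This is a minimal illustration, not the actual training code; the decay rate, decay interval, and the step at which decay starts are placeholder values.

```python
import optax

# Sketch of the schedule described above (placeholder decay values):
# linear warmup to 1e-4 over 10,000 steps, held constant until a plateau,
# then exponential decay from a manually chosen step.
decay_start_step = 300_000  # hypothetical step at which decay was enabled

schedule = optax.join_schedules(
    schedules=[
        optax.linear_schedule(init_value=0.0, end_value=1e-4, transition_steps=10_000),
        optax.constant_schedule(1e-4),
        optax.exponential_decay(init_value=1e-4, transition_steps=50_000, decay_rate=0.5),
    ],
    boundaries=[10_000, decay_start_step],
)

# Learning rate at selected steps: 0.0, 1e-4, 5e-5
print(schedule(0), schedule(10_000), schedule(decay_start_step + 50_000))
```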
Training parameters
- Hardware: 1 TPU v3-256 pod = 32 TPU VM v3-8 nodes (8 TPU cores per node) = 256 TPU v3 cores
- Optimizer: Distributed Shampoo with a block size of 2048, beta1 = 0.9, beta2 = 0.99, preconditioning every 10 steps, and no preconditioning of the embedding layers
- Model partition spec: 8-way model parallel x 32-way data parallel (see the mesh sketch after this list)
- Batch: 44 samples per model x 32 data parallel x 3 gradient accumulation steps = 4224 samples per update
- Learning rate: warmed up to 0.0001 over 10,000 steps, then kept constant until plateau
- Gradient checkpointing used on each Encoder/Decoder layer (i.e., MHA + FFN)
- No dropout
- No weight decay
- No gradient clipping
- Weights initialized from a normal distribution with a standard deviation of 0.01
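To illustrate the "8 model parallel x 32 data parallel" partition over the 256 TPU cores, here is a minimal JAX device-mesh sketch. The axis names and example sharding specs are assumptions for illustration only, not the exact specs used in the training code.

```python
import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec

# Hypothetical sketch: arrange the pod's 256 TPU v3 cores into a
# 32 (data parallel) x 8 (model parallel) mesh, matching the spec above.
devices = np.array(jax.devices()).reshape(32, 8)
mesh = Mesh(devices, axis_names=("dp", "mp"))

# Example sharding specs: batches split over the data-parallel axis,
# large weight matrices split over the model-parallel axis.
batch_spec = PartitionSpec("dp", None)    # (batch, sequence)
weight_spec = PartitionSpec(None, "mp")   # (in_features, out_features)
```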
Model configuration
- BART architecture following the NormFormer variant, without learnt residual scale and without head scale
- Text tokens fed to the Encoder with a max length of 64
- Image tokens fed to the Decoder (causal) with a length of 256
- 24 layers in the Encoder and 24 layers in the Decoder
- Embedding dimension: 2048
- Attention layers: 16 attention heads per layer
- Feed-forward layers: GLU variant with a dimension of 4096 for each dense layer (see the sketch after this list)
- Activation function: GeLU (changed from Swish)
- Position embeddings: absolute positions fed at start of encoder/decoder
- Using standard LayerNorm
- No bias used
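For concreteness, here is a minimal Flax sketch of the GLU feed-forward variant listed above: two parallel bias-free 4096-dimensional projections, one gated through GeLU, then projected back to the 2048-dimensional embedding. It is an illustrative approximation of the block, not the dalle-mini implementation, and it omits the extra NormFormer LayerNorms.

```python
import flax.linen as nn

class GLUFeedForward(nn.Module):
    """Sketch of a GeLU-gated GLU feed-forward block without bias terms."""
    embed_dim: int = 2048
    ffn_dim: int = 4096

    @nn.compact
    def __call__(self, x):
        # Two parallel projections: one gated through GeLU, one linear.
        gate = nn.gelu(nn.Dense(self.ffn_dim, use_bias=False)(x))
        value = nn.Dense(self.ffn_dim, use_bias=False)(x)
        # Element-wise gating, then project back to the embedding dimension.
        return nn.Dense(self.embed_dim, use_bias=False)(gate * value)
```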
DALL·E Mega vs DALL·E Mini
Sample Predictions
Prompts from OpenAI, Pedro Cuenca, Rohan Anil, Rivers Have Wings, Kianne Luy, multimodal ai art, Annas, genai, Kyle Kastner, and many more members of the community.
I didn't keep track of who each prompt came from, so please reach out and I'll add your name.
Evolution of Predictions over time
Select and compare different runs to see the evolution of predictions (1 line = 1 checkpoint).
DrawBench predictions