
DALL·E Mega - Training Journal

Created on April 12 | Last edited on June 9

DALL·E Mega - Training

Several runs are present:
  • when training seems to plateau, we lower the learning rate (the schedule is constant with warmup)
  • we enabled dropout at around 500M train/samples and removed it at around 750M train/samples (it did not seem useful)
  • at 500M train/samples, eval/loss becomes higher because we switched the validation set to a shard of the training set, due to concerns about a different data distribution
  • at 800M train/samples, we switched to full precision training (only a ~20% slowdown; see the sketch after this list)
  • since 1.2B train/samples, we use 2 validation sets: the original one and the training shard we had switched to
  • at 1.8B train/samples (closer to the end of training), we started exponential decay of the learning rate
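
For clarity, "full precision" above means running compute in float32 rather than a lower-precision format (presumably bfloat16 on TPU; the exact mixed-precision setup before the switch is not stated in this report). A minimal, hypothetical JAX sketch of what such a switch amounts to:

```python
import jax.numpy as jnp

# Illustrative only: the real training code threads a `dtype` argument through
# the model; the names and the bfloat16 assumption here are hypothetical.
USE_FULL_PRECISION = True  # switched on at ~800M train/samples

compute_dtype = jnp.float32 if USE_FULL_PRECISION else jnp.bfloat16

def dense_forward(kernel, x):
    # Cast weights and activations to the chosen compute dtype before the matmul;
    # with full precision this is a no-op for float32 parameters.
    return jnp.dot(x.astype(compute_dtype), kernel.astype(compute_dtype))
```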


See the training runs to visualize more details (gradients, parameter norms, etc.).

Training parameters

  • Hardware: 1 TPU v3-256 pod = 32 nodes of TPU VM v3-8 (8 TPU cores per node) = 256 TPU v3 cores
  • Optimizer: Distributed Shampoo with a block size of 2048, beta1 = 0.9, beta2 = 0.99, preconditioning every 10 steps, no preconditioning of embedding layers
  • Model partition spec: 8-way model parallel x 32-way data parallel
  • Batch: 44 samples per model replica x 32 data parallel x 3 gradient accumulation steps = 4224 samples per update
  • Learning rate: warmup to 0.0001 over 10,000 steps, then kept constant until plateau (see the sketch after this list)
  • Gradient checkpointing used on each encoder/decoder layer (i.e., MHA + FFN)
  • No dropout
  • No weight decay
  • No gradient clipping
  • Weights initialized from a normal distribution with standard deviation 0.01
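
As a concrete illustration, here is a sketch of the learning-rate schedule described above using optax. The decay parameters are placeholders, since the report only states that exponential decay was enabled near the end of training, and the actual training code may construct the schedule differently.

```python
import optax

# Sketch of the learning-rate schedule: linear warmup, constant until plateau,
# exponential decay near the end of training. Placeholder values are marked.
warmup_steps = 10_000
peak_lr = 1e-4
decay_start_step = 400_000              # hypothetical: step at which decay was enabled
decay_steps, decay_rate = 100_000, 0.5  # hypothetical decay parameters

schedule = optax.join_schedules(
    schedules=[
        optax.linear_schedule(0.0, peak_lr, warmup_steps),          # linear warmup
        optax.constant_schedule(peak_lr),                           # constant until plateau
        optax.exponential_decay(peak_lr, decay_steps, decay_rate),  # end-of-training decay
    ],
    boundaries=[warmup_steps, decay_start_step],
)

# The schedule is a function of the step count and can be passed to the
# optimizer's learning_rate argument.
print(schedule(0), schedule(warmup_steps), schedule(decay_start_step))
```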

Model configuration

  • BART architecture following the NormFormer variant, without learnt residual scale and without head scale
  • Text tokens fed to the encoder, with a max length of 64
  • Image tokens fed to the decoder (causal), with a length of 256
  • 24 layers in the encoder and 24 layers in the decoder
  • Embedding dimension: 2048
  • Attention layers: 16 attention heads per layer
  • Feed-forward layers: GLU variant with a dimension of 4096 for each dense layer (see the sketch after this list)
  • Activation function: GeLU (changed from Swish)
  • Position embeddings: absolute positions fed at the start of the encoder/decoder
  • Standard LayerNorm
  • No bias terms
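
To make the feed-forward description concrete, here is a minimal Flax sketch of a GLU feed-forward block with GeLU, no biases, and a pre-LayerNorm. Module and field names are hypothetical, and the actual dalle-mini implementation (including the extra NormFormer normalization) may differ.

```python
import flax.linen as nn

class GLUFeedForward(nn.Module):
    """Hypothetical sketch of the GLU feed-forward variant described above."""
    embed_dim: int = 2048
    ffn_dim: int = 4096

    @nn.compact
    def __call__(self, x):
        h = nn.LayerNorm()(x)                                       # standard LayerNorm
        gate = nn.gelu(nn.Dense(self.ffn_dim, use_bias=False)(h))   # GeLU branch
        value = nn.Dense(self.ffn_dim, use_bias=False)(h)           # linear branch
        h = gate * value                                            # GLU: elementwise product
        return nn.Dense(self.embed_dim, use_bias=False)(h)          # project back to embed_dim
```

In the full model, a block of this shape sits inside each of the 24 encoder and 24 decoder layers, alongside 16-head attention.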

DALL·E Mega vs DALL·E Mini




Sample Predictions

Prompts from OpenAI, Pedro Cuenca, Rohan Anil, Rivers Have Wings, Kianne Luy, multimodal ai art, Annas, genai, Kyle Kastner, and many more members of the community.
I didn't keep track of who the prompts were from, so please reach out and I'll add your name.





Evolution of Predictions over time

Select and compare different runs to see the evolution of predictions (1 line = 1 checkpoint).



DrawBench predictions






Comments

Romain Beaumont:
Nice! Is the list of prompts available somewhere? Seems like a great way to evaluate text-to-image models in general. It could be great if it was on GitHub somewhere (I tried to download it from wandb unsuccessfully).

Saravana Rathinam:
If you are on Twitter, you can sample from the latest public dalle-mega model by tweeting @max_elbo https://twitter.com/max_elbo/status/1524175998668865537

Nicholas Bardy:
Wow, these results are getting so good. You've blown past the original DALL·E for sure.

Kieron George:
You should compare this to Skeb.jp. It's a site where commissioners give artists one prompt and leave the artist to it. It's the most direct human equivalent of what DALL·E is doing.

Annas:
Very good.

Nicholas Bardy:
Eagerly awaiting the results.