Exploring Diffusion Model Training
Collectively exploring settings for training small diffusion models on a few datasets
The intro lesson on diffusion models from our course (The Generative Landscape https://johnowhitaker.github.io/tglcourse/dm1.html) shows a quick demo of training a diffusion model from scratch. The settings chosen are fairly arbitrary - in this report we use the accompanying script to explore the hyperparameter space together and see what we can learn.
Goals
The main goals of this exercise:
- Build an intuition for how long diffusion models take to train
- Explore the performance of different UNet configurations
- Find some 'training recipes' that work well on small datasets
This is research, so let's also keep an eye out for anything else that looks interesting!
What's in the Script?
I've started the script out fairly minimal. Using fastcore.script, the arguments of the main function can be set as command-line options:

TODO re-do screenshot after adding more info
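In the meantime, the pattern looks roughly like the sketch below. The argument names and defaults here are illustrative only - check the script itself for the real signature.

```python
# Minimal sketch of how fastcore.script exposes function arguments as CLI options.
# The argument names and defaults are illustrative, not the script's actual ones.
from fastcore.script import call_parse

@call_parse
def main(
    dataset: str = "butterflies",  # which dataset to train on
    img_size: int = 32,            # training resolution
    batch_size: int = 64,          # batch size
    lr: float = 1e-4,              # peak learning rate
    n_steps: int = 10000,          # total training steps
    comments: str = "",            # free-text notes logged to W&B
):
    "Train a small diffusion model with the given settings."
    print(dataset, img_size, batch_size, lr, n_steps, comments)
```

With this in place, each argument can be overridden on the command line (e.g. --img_size 64 --lr 3e-4).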
The training loop is fairly similar to the one shown in the lesson notebook, borrowing some ideas from the huggingface example script (todo link) and adding:
- A learning rate schedule that warms up to the specified learning rate over 500 steps and then decays following a cosine schedule (a rough sketch of this schedule appears after this list)
- Some minimal data augmentation, including horizontal flips
- W&B logging for model config, loss, learning rate, etc.
- Occasional logging of a grid of sample images (with a fixed seed for easy comparison)
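To make the schedule concrete, here's a rough equivalent of that warmup + cosine decay using a plain PyTorch LambdaLR - the step counts and optimizer below are placeholders, and the script may implement it slightly differently (diffusers also ships a get_cosine_schedule_with_warmup helper that does essentially the same thing).

```python
import math
import torch

# Linear warmup to the peak learning rate over `warmup_steps`, then cosine decay.
def make_lr_lambda(warmup_steps=500, total_steps=10_000):
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)               # linear warmup to 1.0
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1 + math.cos(math.pi * progress))      # cosine decay to 0
    return lr_lambda

model = torch.nn.Linear(4, 4)  # stand-in for the UNet
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, make_lr_lambda())
# In the training loop: loss.backward(); optimizer.step(); scheduler.step()
```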
You are encouraged to modify this script. For example, the UNet configuration is currently hard-coded, but it is definitely worth tweaking. Use the comments argument to summarize the changes you've made - the code will also be logged for more detailed reference. (TODO check save_code is on by default)
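Since the UNet is where most of the interesting changes live, here's a hedged example of what a configuration might look like using diffusers' UNet2DModel. The channel counts and block types below are just a plausible starting point, not the values hard-coded in the script.

```python
from diffusers import UNet2DModel

# Example UNet configuration to tweak - channel counts and block types are a
# plausible starting point for 32px images, not the script's exact settings.
model = UNet2DModel(
    sample_size=32,                     # training resolution
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(64, 128, 256),  # width at each resolution
    down_block_types=(
        "DownBlock2D",
        "AttnDownBlock2D",              # attention only at the lower resolutions
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
    ),
)
```

Swapping the attention blocks for their plain counterparts, or changing block_out_channels, is an easy way to explore the questions in the 'What To Tweak' section below.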
Baselines
I picked two test cases to start with: Butterflies (a very small dataset) at 32px and Flowers at 64px. We can discuss on Discord what other datasets and scales are worth testing.
The images above show the results from a quick run on each dataset. These are the images to beat!
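If you want to reproduce something close to these baselines, the data preparation looks roughly like the sketch below. The dataset name here is a guess and the exact transforms are assumptions - check the script for the real pipeline - but it includes the horizontal-flip augmentation mentioned above.

```python
import torchvision.transforms as T
from datasets import load_dataset

# Hypothetical data pipeline for the 32px butterflies baseline; swap in whichever
# dataset and resolution you're actually testing.
dataset = load_dataset("huggan/smithsonian_butterflies_subset", split="train")

img_size = 32
preprocess = T.Compose([
    T.Resize((img_size, img_size)),
    T.RandomHorizontalFlip(),   # minimal augmentation
    T.ToTensor(),
    T.Normalize([0.5], [0.5]),  # scale images to [-1, 1]
])

def transform(batch):
    batch["images"] = [preprocess(img.convert("RGB")) for img in batch["image"]]
    return batch

dataset.set_transform(transform)
```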
What To Tweak
Besides just messing with the batch size and learning rate, I recommend starting by exploring:
- How many layers are used in the UNet
- Which layers have attention
- What tradeoffs there are in terms of model size vs training speed and accuracy (a quick parameter-count helper follows this list)
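A quick way to compare model size across UNet configurations is to count parameters; a rough helper (assuming a standard torch.nn.Module like the UNet2DModel above) might look like this:

```python
import torch

def count_params(model: torch.nn.Module) -> str:
    "Report model size in millions of parameters."
    n = sum(p.numel() for p in model.parameters())
    return f"{n / 1e6:.1f}M parameters"

# e.g. print(count_params(model)) after constructing the UNet
```

Logging this alongside samples-per-second makes the size vs speed tradeoff easy to chart in W&B.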
You are of course welcome to train for much longer than the demo examples, to see how good things get.
Some additional things that may be interesting to add:
- Automated evaluation using FID (rather than just eyeballing images)
- Gradient accumulation or multi-GPU training (a gradient accumulation sketch follows this list)
- Different objectives or loss weighting
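For reference, gradient accumulation boils down to the pattern below - the model, data, and loss are stand-ins, and only the accumulation logic itself is the point.

```python
import torch

# Sketch of gradient accumulation with a dummy model and data; this effectively
# trains with a batch size of 16 * accum_steps while only holding 16 samples in memory.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
data = [torch.randn(16, 8) for _ in range(32)]  # fake batches

accum_steps = 4

optimizer.zero_grad()
for step, batch in enumerate(data):
    loss = model(batch).pow(2).mean() / accum_steps  # scale so gradients average
    loss.backward()                                  # gradients accumulate across steps
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```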
Keep an eye on the GPU utilization (logged automatically) to make sure you're using your GPU(s) to their full potential! I can't wait to see what else you all come up with.
Results
coming soon!