A Recipe for Training Large Models
Practical advice and tips for training large machine learning models
Created on March 1 | Last edited on November 15
Introduction
When you need an AI model, always start by checking whether an existing one already satisfies your needs, and aim for the smallest possible model (lower inference cost, more portability, etc.). In some instances, however, you will need to train your own model from scratch without being able to run many experiments.
The objective of this guide is to:
- help you train large models (>1B parameters)
- avoid instabilities
- rescue experiments that have started to fail without restarting from step 0
I only intend to give practical recommendations, using a simple recipe that I have been updating over time and which I follow.
While the goal is to have a systematic way of successfully training models (until it can be completely automated), some of my conclusions may be incorrect, incomplete, or not applicable to every type of problem. This guide is a living document that will be updated with future research. Contributions are also welcome.
Table of Contents:
- Introduction
- When do these training guidelines apply?
- The Recipe To Train Large Models
- Experimentation phase
- Always start with a small version to ensure your entire pipeline works
- Scale your model in 3-10x increments before training the final version
- Stable Training Of Large Models
- Your Model Size
- The Batch Size
- Model Sharding
- The Learning Rate
- When is Training Complete?
- Large Model Infrastructure
- Conclusion
- Resources
- Acknowledgements
When do these training guidelines apply?
You should try to follow this guide when:
- Your model is over 1 billion parameters
- Your model is going to train for at least multiple days
- You don't think that you are limited by your dataset
- Admittedly, this is much harder to estimate. As an example, dalle-mini started overfitting with 400M parameters on a dataset of 17M images after 3 days on a TPU v3-8 (50 epochs).
- It is easier to think in terms of the number of epochs you expect to run over the dataset.
- You are training from scratch, or on a very different task
- If not, you may be able to freeze your model and just fine-tune a few layers, which is way easier!
- You don't have the budget to perform 50-100 experiments
If any of the above does not apply, you should definitely perform a hyperparameter search and experiment with different model architectures.
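To make the "number of epochs" check above concrete, here is a minimal sketch. The function name `expected_epochs` and the example numbers are assumptions chosen for illustration; only the 17M dataset size echoes the dalle-mini example.

```python
def expected_epochs(num_steps: int, batch_size: int, dataset_size: int) -> float:
    """Number of passes over the dataset for a planned run."""
    if dataset_size <= 0:
        raise ValueError("dataset_size must be positive")
    return num_steps * batch_size / dataset_size


# Hypothetical plan: 100k steps at batch size 512 on a 17M-sample dataset.
epochs = expected_epochs(num_steps=100_000, batch_size=512, dataset_size=17_000_000)
print(f"{epochs:.2f} epochs")
```

If the planned run covers many epochs, the dataset rather than the model size may become the limiting factor.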
The Recipe To Train Large Models
Experimentation phase
A few things you should keep in mind as you get started:
Always start with a small version to ensure your entire pipeline works
You'll scale your model later, so upfront, you want to ensure things like dataset loading, training/validation, model saving/restoring, etc. are working how you expect.
Scale your model in 3-10x increments before training the final version
These intermediate runs cost a fraction of the final one, and your hyperparameters port easily when you scale up slowly, typically needing only small adjustments of at most one order of magnitude. It is a cheap way to tune your learning rate, for example.
The experimentation phase is the only moment to be creative and experiment with ideas! You should:
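One way to plan the intermediate sizes is sketched below. The helper `scaling_ladder`, the 4x factor, and the parameter counts are illustrative assumptions, not part of the original recipe.

```python
def scaling_ladder(start_params: float, target_params: float, factor: float = 4.0) -> list[float]:
    """Model sizes from start_params up to target_params, growing by `factor` per rung."""
    if not 0 < start_params <= target_params:
        raise ValueError("need 0 < start_params <= target_params")
    if factor <= 1:
        raise ValueError("factor must be greater than 1")
    sizes = []
    size = start_params
    while size < target_params:
        sizes.append(size)
        size *= factor
    # Always finish on the target size; the last jump may be smaller than `factor`.
    sizes.append(target_params)
    return sizes


# Hypothetical ladder from a 10M-parameter prototype to a 1B-parameter model.
for n in scaling_ladder(10e6, 1e9, factor=4.0):
    print(f"{n / 1e6:.0f}M parameters")
```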
- Never try new things on a large model unless they at least seem to work on your own smaller model (and don't blindly trust a paper! Anyone can make mistakes or inflate results, and their tricks may not apply to your own problem). When training your large model you will mainly be busy fighting instabilities, so you don't want to add another source of issues.
- Start with a simple baseline that is known to work so you have a good reference model
- Avoid drawing conclusions too fast: some runs may initially look promising but end up diverging later, or just reach a plateau earlier than others.
- Experiment fast, try quick ideas, and make sure to keep notes as you go: you'll revisit your reports and conclusions frequently.
- Add new configuration parameters for your experiments (such as norm_type, use_alibi, and even use_idea_1 is better than nothing) and log them all so you can quickly compare experiments and see what differs between them.
- If you log your runs with Weights & Biases, it's easy to select a graph that compares a few runs, pull it into a report, and add a quick note of what insight it brings.
- See An Evaluation of Transformer Variants as an example: I was just adding tiny sections over time as I ran experiments to record my conclusions, and I keep coming back to it.
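As a rough illustration of the configuration-flag habit above, the sketch below keeps every experimental switch in one dataclass and reports what differs between two runs. The field names norm_type, use_alibi, and use_idea_1 come from the text. The `learning_rate` field, the default values, and the helper `config_diff` are assumptions made for the example.

```python
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class ExperimentConfig:
    # Defaults are placeholders, not recommendations.
    learning_rate: float = 1e-3
    norm_type: str = "layernorm"
    use_alibi: bool = False
    use_idea_1: bool = False


def config_diff(a: ExperimentConfig, b: ExperimentConfig) -> dict:
    """Map each differing field to its (a, b) pair of values."""
    da, db = asdict(a), asdict(b)
    return {key: (da[key], db[key]) for key in da if da[key] != db[key]}


baseline = ExperimentConfig()
variant = ExperimentConfig(norm_type="rmsnorm", use_alibi=True)
print(config_diff(baseline, variant))
```

The flat dictionary returned by `asdict` is the kind of mapping an experiment tracker can store as the run configuration, which is what makes runs comparable later.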
Stable Training Of Large Models
The most important aspect of training large models is to ensure stable training (as opposed to over-optimizing every single aspect):
- Optimizer
- I recommend Distributed Shampoo as a very stable optimizer. The memory overhead (vs. Adam) disappears when using multiple nodes (due to sharding)
- Precision
- Some instabilities are just solved by increasing precision: float32 > bfloat16 > float16
- You will typically want to keep your model parameters in full precision and only decrease the precision of computations. Be mindful of certain operations that may require higher precision, such as attention layers and losses (softmax, contrastive loss especially with large batch sizes, etc.).
- If you don't get significant advantages with lower precision (speed or memory constraints), you should just use a higher precision.
- Weight decay
- You don't need weight decay with Shampoo
- Weight decay may help at the start of training but you should decrease it to 0 during training (if your optimizer allows it)

Figure: Weight decay helps at the start of training but leads to a plateau.
- Dropout
- I never use dropout for large models with enough data.
- Model architecture
- I have been having success with NormFormer + GLU. Another good option, if you want fewer LayerNorms, is the cosine similarity from Swin v2.
- I use no bias.
- Weight initialization
- Larger layers often require smaller weights.
- Initializing residual connection weights to 0 can improve stability if needed.
- Other stabilization methods: if you don't use Shampoo you may benefit from gradient clipping (in that case I'd suggest clipping by global norm).
- EMA parameters
- If you have enough memory, you can compute EMA of parameters during training.
- Otherwise, you can keep a copy of model checkpoints and average over them.
- It is mainly beneficial for diffusion models and has a minimal effect on other models.
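Two of the items above are easy to show in a few lines: clipping by global norm and an EMA of parameters. The sketch below uses plain Python lists as stand-ins for tensors. The function names and the decay value are assumptions, and a real training loop would rely on the equivalent utilities of its framework. "Global norm" here means the L2 norm of all gradients taken together, as if they were concatenated into a single vector, so every gradient is rescaled by the same factor.

```python
import math


def global_norm(grads: dict[str, list[float]]) -> float:
    """L2 norm over every gradient value of every parameter."""
    return math.sqrt(sum(g * g for values in grads.values() for g in values))


def clip_by_global_norm(grads: dict[str, list[float]], max_norm: float) -> dict[str, list[float]]:
    """Rescale all gradients by one shared factor so their global norm is at most max_norm."""
    norm = global_norm(grads)
    scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return {name: [g * scale for g in values] for name, values in grads.items()}


def ema_update(ema: dict[str, list[float]], params: dict[str, list[float]],
               decay: float = 0.999) -> dict[str, list[float]]:
    """One exponential-moving-average step: ema = decay * ema + (1 - decay) * params."""
    return {
        name: [decay * e + (1.0 - decay) * p for e, p in zip(ema[name], params[name])]
        for name in params
    }


grads = {"dense": [3.0, 0.0], "embed": [0.0, 4.0]}  # global norm is 5
clipped = clip_by_global_norm(grads, max_norm=1.0)
print(global_norm(grads), global_norm(clipped))
```

Because the scale factor is shared, the direction of the overall update is preserved, unlike clipping each value or each layer independently.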
Your Model Size
Your target model size should be defined early on:
- Typically the bigger you can handle the better.
- Limitations may be based on your compute budget. Additionally:
- See what model size your hardware can handle.
- Find the most similar model known to have been trained and note their model size, dataset, and training budget.
- Refer to the closest applicable scaling laws available.
- Finally confirm your training budget based on batch size, number of steps/samples to be seen, and your actual training time per step (measured on a short experiment).
- Size your model based on the intended application:
- What are the hardware limitations of the target host device?
- What is the inference speed requirement?
- Test actual inference performance on a dummy untrained model.
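The budget confirmation described above (batch size, number of samples to be seen, and measured time per step) reduces to simple arithmetic. The sketch below is one way to write it. The function name, the returned keys, and the example figures are all hypothetical.

```python
import math


def training_budget(total_samples: int, batch_size: int,
                    seconds_per_step: float, num_devices: int) -> dict[str, float]:
    """Estimate steps, wall-clock days, and device-hours for a planned run."""
    steps = math.ceil(total_samples / batch_size)
    wall_clock_hours = steps * seconds_per_step / 3600
    return {
        "steps": steps,
        "wall_clock_days": wall_clock_hours / 24,
        "device_hours": wall_clock_hours * num_devices,
    }


# Hypothetical run: 1B samples, batch size 2048, 1.5 s/step measured on 64 devices.
budget = training_budget(1_000_000_000, 2048, seconds_per_step=1.5, num_devices=64)
print(f"{budget['steps']} steps, {budget['wall_clock_days']:.1f} days, "
      f"{budget['device_hours']:.0f} device-hours")
```

The time per step should come from a short real experiment on the target hardware rather than from a guess.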
The Batch Size
Make sure to use a large enough batch size:
- Always use the biggest batch size you can.
- Find a minimum requirement for batch size based on similar models that have been trained successfully.
To increase your batch size:
- Use gradient accumulation to increase your effective batch size.
- Warning: not applicable on certain models such as CLIP where the loss is based on calculations over the entire batch at a time.
- Use gradient checkpointing (evaluate the impact on training speed).
- Reduce your model size if you have to: a batch size that is too low can hurt the performance of your model.
You may modify your base batch size in some instances:
- At the start of training, a lower batch size can help you train faster. However, it adds noise to the gradient, so it can potentially lead to instability, and it can be hard to detect at which point you should switch back to the full batch size.
- Later during training, an increase in batch size has a similar effect to a decrease in the learning rate, with an immediate decrease in loss. Depending on the impact it has on training speed, it can be preferable to increase the batch size (through gradient accumulation) rather than lower the learning rate.
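The sketch below illustrates why gradient accumulation gives the same update as a larger batch when the loss is a mean of per-sample terms. It uses a toy one-parameter model y = w * x with a mean squared error. The model, data, and function names are assumptions chosen to keep the example self-contained.

```python
def mse_grad(w: float, xs: list[float], ys: list[float]) -> float:
    """Gradient with respect to w of mean((w * x - y)^2) over the given samples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w: float, xs: list[float], ys: list[float], micro_batch_size: int) -> float:
    """Average of the gradients computed on equally sized micro-batches."""
    if len(xs) % micro_batch_size != 0:
        raise ValueError("batch must split evenly into micro-batches")
    num_micro_batches = len(xs) // micro_batch_size
    total = 0.0
    for start in range(0, len(xs), micro_batch_size):
        end = start + micro_batch_size
        total += mse_grad(w, xs[start:end], ys[start:end])
    return total / num_micro_batches


xs, ys = [1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0]
full = mse_grad(1.5, xs, ys)
accumulated = accumulated_grad(1.5, xs, ys, micro_batch_size=2)
print(full, accumulated)
```

The equality relies on each sample contributing independently to the loss. A contrastive loss such as the one used by CLIP compares every sample with every other sample in the batch, so splitting the batch changes the loss itself, which is the reason for the warning above.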
Model Sharding
Ideally, you can just train your large model using data parallelism, meaning your model is just replicated on each device and your batch is split over different devices.
When training a large model, you may have to shard its parameters or activations over multiple devices.
There are multiple strategies for sharding each layer and its activations, and choosing between them is still outside my skill set. When using JAX, I typically just refer to the partitioning strategies from T5X and try them all with different combinations of batch size and number of devices over which to shard the model.
My typical workflow includes:
- Finding the smallest number of devices over which to shard the model (the device mesh can be defined as an array sharded over model and data), experimenting with various configurations of partitioning strategies.
- Making the batch size as large as possible.
- Experimenting with model size adjustments to evaluate the impact on training speed.
- Testing an FSDP strategy that fully shards data and model across all devices.
You want to try to allow a large dimension over data → more data parallelism → potentially bigger total batch size (not always true as it depends both on data dimension and batch size per individual data slice) → fewer gradient accumulation steps needed for your target effective batch size → faster time per training step.
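The chain of reasoning above can be explored numerically. The sketch below enumerates the possible (data, model) mesh shapes for a device count and the gradient accumulation steps each one needs. It assumes, for simplicity, that the batch size per data slice stays fixed across mesh shapes, which the text notes is not always true. All names and numbers are hypothetical.

```python
import math


def mesh_shapes(num_devices: int) -> list[tuple[int, int]]:
    """All (data, model) factorizations of num_devices, least model sharding first."""
    return [(num_devices // model, model)
            for model in range(1, num_devices + 1)
            if num_devices % model == 0]


def accumulation_steps(target_batch_size: int, batch_per_data_slice: int, data_size: int) -> int:
    """Gradient accumulation steps needed to reach the target effective batch size."""
    per_step = batch_per_data_slice * data_size
    return math.ceil(target_batch_size / per_step)


# Hypothetical setup: 8 devices, 8 samples per data slice, target batch size 512.
for data, model in mesh_shapes(8):
    steps = accumulation_steps(512, batch_per_data_slice=8, data_size=data)
    print(f"data={data} model={model} -> {steps} accumulation steps")
```

In practice the batch that fits on each data slice changes with the amount of model sharding, so each candidate mesh still has to be measured on real hardware.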
The Learning Rate
- Start on a smaller model: the learning rate ports pretty well when your model size increases in 5-10x increments, with minor adjustments.
- Perform a quick search:
- Search for the best learning rate in about 3x increments (1e-3, 3e-3, 1e-2…).
- Start with x10 increments initially to go faster.
- Run the experiments long enough to get some confidence.
- It is essential to optimize at least the learning rate before an expensive training run as it has a significant impact on training speed (and even training success).
- At the start of training, use a warmup
- It can only be beneficial
- Set the warmup duration based on your training budget (it should take less than 5-10% of it)
- During training, keep the learning rate constant:
- It lets you decouple its effect on metrics (lowering the learning rate typically induces a sharp decrease of loss)
- Try from time to time to lower/increase the learning rate (I use 3x increments) and observe the effect it has on the training speed (slope). You can always readjust it without having to go back in time
- If you are really lazy, you can use a cosine or linear decay but keep in mind that it's only because you're really lazy 😬
- At the end of training, use a final cosine or linear decay to 0 for a little boost
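The search grid and the warmup / constant / final decay pattern above can be sketched as follows. The function names, the linear warmup shape, and the step counts are assumptions. In the workflow described here, `decay_start` would only be decided once you choose to wrap up training, and the constant phase may be adjusted by hand along the way.

```python
import math


def lr_grid(min_exp: int, max_exp: int) -> list[float]:
    """Candidate learning rates in roughly 3x increments: 1e-3, 3e-3, 1e-2, ..."""
    return [float(f"{mantissa}e{exp}")
            for exp in range(min_exp, max_exp + 1)
            for mantissa in (1, 3)]


def lr_at(step: int, base_lr: float, warmup_steps: int,
          decay_start: int, total_steps: int) -> float:
    """Linear warmup, constant plateau, then a final cosine decay to 0."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    if step < decay_start:
        return base_lr
    progress = min(1.0, (step - decay_start) / max(1, total_steps - decay_start))
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


print(lr_grid(-4, -3))
for step in (0, 99, 5_000, 9_500, 10_000):
    print(step, lr_at(step, base_lr=1e-3, warmup_steps=100,
                      decay_start=9_000, total_steps=10_000))
```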
When is Training Complete?
Nobody knows for certain but if you're happy with the current results you should just finish your training with a final decay to 0.
If results are not improving anymore it's probably time to try something else (or just scale up).
Large Model Infrastructure
You should implement the following features:
- Logging of training/validation loss + relevant metrics (accuracy…).
- Logging of parameters and gradients:
- The norms of parameters and gradients should be logged regularly so you can refer to them during instabilities
- Histograms can be logged optionally (it will impact training speed but is nice to enable when debugging instabilities)
- At the start of training, manually verify that all gradients are flowing properly in your model and that parameters are being updated
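A minimal sketch of the norm logging and the start-of-training sanity checks described above, using dictionaries of plain lists as stand-ins for parameter and gradient tensors. The metric naming scheme and the helper names are assumptions. The resulting dictionary is meant to be handed to whatever logger you use.

```python
import math


def l2_norm(values: list[float]) -> float:
    return math.sqrt(sum(v * v for v in values))


def norm_metrics(params: dict[str, list[float]], grads: dict[str, list[float]]) -> dict[str, float]:
    """Per-parameter norms, keyed so they group nicely in a dashboard."""
    metrics = {}
    for name in params:
        metrics[f"param_norm/{name}"] = l2_norm(params[name])
        metrics[f"grad_norm/{name}"] = l2_norm(grads[name])
    return metrics


def dead_gradients(grads: dict[str, list[float]]) -> list[str]:
    """Names of parameters whose gradient is exactly zero (nothing is flowing)."""
    return sorted(name for name, g in grads.items() if l2_norm(g) == 0.0)


def unchanged_params(before: dict[str, list[float]], after: dict[str, list[float]]) -> list[str]:
    """Names of parameters that an optimizer step left untouched."""
    return sorted(name for name in before if before[name] == after[name])


params = {"attn": [3.0, 4.0], "mlp": [1.0]}
grads = {"attn": [0.0, 0.0], "mlp": [2.0]}
print(norm_metrics(params, grads))
print("no gradient:", dead_gradients(grads))
```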

You should also implement:
- Saving to and resuming from a checkpoint.
- You may want to handle model and optimizer states separately.
- Logging of prediction samples.
- There is nothing better than visualizing actual sample predictions.
- If it requires too much compute, you should create a script that computes them against your latest model checkpoint and start it on a different instance.
- A small demo so you can quickly experiment with your model:
- It is essential to be able to test your model during training.
- A simple notebook or a Gradio demo (better for sharing with other users) is very fast to implement and will let you detect potential issues early on.
- Keep a training log of your model with:
- Your training configuration, hardware, and hyperparameters.
- Your loss, metrics, etc. over time with all relevant runs.
- A small note on every single change you made between runs so you can understand your graphs later.
- Sample predictions over time if possible.
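As an illustration of the checkpointing item above, the sketch below stores model and optimizer states in separate files so that either can be restored on its own, for example when resuming with a different optimizer. JSON and the file names are assumptions made to keep the example dependency-free. Real checkpoints would use your framework's serialization.

```python
import json
import tempfile
from pathlib import Path


def save_checkpoint(directory, step: int, model_state: dict, optimizer_state: dict) -> None:
    """Write model and optimizer states to separate files."""
    directory = Path(directory)
    directory.mkdir(parents=True, exist_ok=True)
    (directory / "model.json").write_text(json.dumps(model_state))
    (directory / "optimizer.json").write_text(json.dumps(optimizer_state))
    # Written last, so its presence marks a complete checkpoint.
    (directory / "meta.json").write_text(json.dumps({"step": step}))


def load_checkpoint(directory, with_optimizer: bool = True):
    """Return (step, model_state, optimizer_state); optimizer_state is None if skipped."""
    directory = Path(directory)
    step = json.loads((directory / "meta.json").read_text())["step"]
    model_state = json.loads((directory / "model.json").read_text())
    optimizer_state = None
    if with_optimizer:
        optimizer_state = json.loads((directory / "optimizer.json").read_text())
    return step, model_state, optimizer_state


with tempfile.TemporaryDirectory() as tmp:
    save_checkpoint(tmp, 1000, {"w": [0.1, 0.2]}, {"momentum": [0.0, 0.0]})
    restored = load_checkpoint(tmp)
    model_only = load_checkpoint(tmp, with_optimizer=False)
print(restored, model_only)
```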
Conclusion
Large models and large model training are, of course, evolving as we speak. But the tips above have proved really valuable for me while training large models of my own. If you have anything else you've discovered or think is worth adding to our piece above, please leave a comment below. We'd love to hear what's worked for you.
Resources
- NeurIPS 2022, Workshop "Has it Trained Yet", "The Road to Craiyon" related to the scaling of dalle-mini
- Another cool guide: "Deep Learning Tuning Playbook" by Varun Godbole, George E. Dahl, Justin Gilmer, Christopher J. Shallue, and Zachary Nado
Acknowledgements