What's the Optimal Batch Size to Train a Neural Network?
We look at the effect of batch size on test accuracy when training a neural network. We'll pit large batch sizes vs small batch sizes and provide a Colab you can use.
Created on August 19|Last edited on July 9
While questions like "What's the optimal batch size to train a neural network?" almost always have the same answer ("It depends"), our goal today is to look at how different batch sizes affect accuracy, training time, and compute resources.
We'll then look into some hypotheses that explain the results of our tests of large batch sizes vs. small batch sizes and everything in between.
Here's what we'll be covering:
- Small Batch Sizes vs. Large Batch Sizes (and everything in between)
- Results Of Small vs. Large Batch Sizes On Neural Network Training
- Why Do Large Batch Sizes Lead To Poorer Generalization?
- Weights & Biases
Let's get going!
Small Batch Sizes vs. Large Batch Sizes (and everything in between)
We first need to establish the effect of batch size on the test accuracy and training time. To do this, we'll do an ablation study.
We will be using an image classification task and testing how accuracy changes with different batch sizes. Here are a few things we'll focus on for these tests:
- We will use a fixed SEED wherever possible. This removes run-to-run noise from random initialization, which makes our comparisons across batch sizes more reliable. Here's an interesting read on Meaning and Noise in Hyperparameter Search.
- We want to use a simple architecture without Batch Normalization, since Batch Normalization's behavior itself depends on the batch size. Check out the effect of batch size on model performance here.
- We won't use an over-parameterized network, which helps avoid overfitting.
- We will be training our model with different batch sizes for 25 epochs.
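The setup above can be sketched as a small sweep configuration. This is a minimal illustration, not the code from the Colab: the names `SEED`, `make_sweep`, and `init_weights` are hypothetical, the seed value 42 is an arbitrary assumption, and the batch sizes follow the 8 to 2048 doubling schedule used in this study. It relies only on Python's standard library to show why a fixed seed matters: every run starts from identical initial weights, so differences between runs can be attributed to the batch size.

```python
import random

SEED = 42    # assumption: any fixed integer would do
EPOCHS = 25  # from the study: every run trains for 25 epochs


def make_sweep(smallest=8, largest=2048):
    """Return batch sizes from smallest to largest, doubling each time."""
    sizes = []
    size = smallest
    while size <= largest:
        sizes.append(size)
        size *= 2
    return sizes


def init_weights(n_weights, seed=SEED):
    """Toy stand-in for model initialization, driven by a seeded generator."""
    rng = random.Random(seed)
    return [rng.gauss(0.0, 0.05) for _ in range(n_weights)]


batch_sizes = make_sweep()
configs = [{"batch_size": b, "epochs": EPOCHS, "seed": SEED} for b in batch_sizes]

# Same seed -> identical starting weights for every run in the sweep.
assert init_weights(5) == init_weights(5)
print(batch_sizes)
```

In a real experiment you would also seed the deep learning framework and any data shuffling, which this toy example does not cover.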
Let's dig into the results. 👇
Results Of Small vs. Large Batch Sizes On Neural Network Training
- From the validation metrics, the models trained with small batch sizes generalize well on the validation set.
- For our study, we train our model with batch sizes ranging from 8 to 2048, each twice the size of the previous one. The batch size of 32 gave us the best result, and the batch size of 2048 gave us the worst.
- Our parallel coordinate plot also makes a key tradeoff very evident: larger batch sizes take less time to train but are less accurate.
- Digging in further, the test error rate drops steeply as we move from larger batch sizes to smaller ones. That said, the lowest error rate belongs to batch size 32, not to the smallest batch sizes.
- Training time rises steeply as we move from larger batch sizes to smaller ones. And this is expected! Every run trains for the full 25 epochs (we are not using early stopping when the model starts to overfit), and each halving of the batch size roughly doubles the number of weight updates per epoch, so smaller batches are bound to take longer.
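The training-time trend follows largely from arithmetic: with a fixed number of epochs, a smaller batch size means more weight updates. The sketch below works this out. `NUM_TRAIN_EXAMPLES = 50_000` is an assumed dataset size chosen only for illustration (the article does not state one), and `total_updates` is a hypothetical helper.

```python
import math

NUM_TRAIN_EXAMPLES = 50_000  # assumption: illustrative dataset size
EPOCHS = 25                  # from the study


def total_updates(batch_size, num_examples=NUM_TRAIN_EXAMPLES, epochs=EPOCHS):
    """Number of gradient updates performed over the whole training run."""
    # The last, possibly partial, batch of each epoch still counts as one update.
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch * epochs


for batch_size in [8, 32, 2048]:
    print(batch_size, total_updates(batch_size))
```

Under these assumptions, the run with batch size 8 performs 250 times as many updates as the run with batch size 2048. Wall-clock time will not scale by exactly that factor, since each large-batch step does more work, but the update count explains most of the gap.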
Why Do Large Batch Sizes Lead To Poorer Generalization?
What might explain this behavior? This Stack Exchange thread has a few great hypotheses.
Some of my favorites:
- This paper claims that large-batch methods tend to converge to sharp minimizers of the training and testing functions, and that sharp minima lead to poorer generalization. In contrast, small-batch methods consistently converge to flat minimizers.
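- Gradient descent-based optimization makes a linear approximation to the cost function. However, if the cost function is highly non-linear (highly curved), that approximation only holds close to the current weights. In that sense small batch sizes are safe: they re-estimate the gradient many more times per epoch instead of committing to a few updates.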
- When you put m examples in a minibatch, you need to do O(m) computation and use O(m) memory, but you reduce the amount of uncertainty in the gradient by a factor of only O(sqrt(m)). In other words, there are diminishing marginal returns to putting more examples in the minibatch.
- Even using the entire training set doesn’t really give you the true gradient. Using the entire training set is just using a very large minibatch size.
- The gradient computed from a small batch oscillates much more than one computed from a larger batch. This oscillation can be considered noise. However, for a non-convex loss landscape (which is often the case), this noise helps the optimizer escape local minima. Larger batches, by contrast, take fewer and coarser search steps toward the optimal solution, and so by construction are less likely to converge on it.
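The O(sqrt(m)) point in the list above is easy to check numerically. The sketch below is a toy model, not the experiment from this article: it assumes each per-example "gradient" is a single number drawn from a Gaussian with mean 1 and standard deviation 1, and `minibatch_mean_std` is a hypothetical helper that measures how much the minibatch average fluctuates across many random minibatches of size m.

```python
import random
import statistics


def minibatch_mean_std(m, trials=2000, seed=0):
    """Spread of the averaged 'gradient' across many minibatches of size m."""
    rng = random.Random(seed)
    means = []
    for _ in range(trials):
        # Assumption: per-example gradients are i.i.d. Gaussian(mean=1, std=1).
        batch = [rng.gauss(1.0, 1.0) for _ in range(m)]
        means.append(statistics.fmean(batch))
    return statistics.stdev(means)


for m in [1, 4, 16, 64]:
    measured = minibatch_mean_std(m)
    print(f"m={m:3d}  measured std={measured:.3f}  1/sqrt(m)={m ** -0.5:.3f}")
```

Under this assumption the measured spread tracks 1/sqrt(m): a 16 times larger batch costs 16 times the compute but only cuts the gradient noise by about a factor of 4.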
Weights & Biases
Weights & Biases helps you keep track of your machine learning experiments. Use our tool to log hyperparameters and output metrics from your runs, then visualize and compare results and quickly share findings with your colleagues.
Normalization Series: What is Batch Normalization?
An in-depth blogpost covering Batch Normalization, complete with code and interactive visualizations. Part of a bigger series on Normalization.
Meaning and Noise in Hyperparameter Search with Weights & Biases
How do we distinguish signal from pareidolia (imaginary patterns)? This article showcases what is possible with W&B and aims to inspire further exploration.
Nice points you bring up in the end of the article. However, I'm not sure I understand the second one. Could you explain why a highly non-linear cost function makes small batch sizes "safe"?
Also, one more thing to consider is that batch normalization inherently adds noise during training, since its behavior is based on the batch statistics, which vary between batches. This noise acts as regularization and makes the network less prone to overfitting the training data, and the smaller the batch size is, the more the statistics are going to vary, hence adding more noise during training, which increases the strength of the regularization. So if the training is poorly regularized and you use batch normalization, using a smaller batch size can be helpful for this reason.
Interesting article!
Given that you train for the same number of epochs on all networks, the models with larger batch size will actually run way fewer gradient descent updates.
Also, this doesn't look like low generalization because there are no plots for training loss, it could just mean that models that saw fewer training iterations are actually slightly underfit when compared to those that saw more training updates.
The point of training with large batch sizes is that you're able to see more data on the *same* time budget (assuming you're training only one epoch on a very large dataset).
PS: the google colab link is dead on my end.