A brief study on the effect of batch size on test accuracy. Made by Ayush Thakur using Weights & Biases

While questions like "what's the optimal batch size?" almost always have the same answer ("it depends"), our goal today is to look at how different batch sizes affect accuracy, training time, and compute resources. Then, we'll look into some hypotheses that explain those differences.

We first need to establish the effect of batch size on the test accuracy and training time.

To do so, let's do an ablation study. We will be using an image classification task and testing how accuracy changes with different batch sizes. A few things we'll focus on for these tests:

- We will fix the random SEED wherever possible. This removes the noise that random initialization adds, so differences between runs can be attributed to the batch size. Here's an interesting read on Meaning and Noise in Hyperparameter Search.
- We will use a simple architecture without Batch Normalization, since BN computes its statistics over the batch and would confound the comparison. Check out the effect of batch size on model performance here.
- We won't use an over-parameterized network; this helps avoid overfitting.
- We will train our model with each batch size for 25 epochs.
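To make the first point concrete, here is a minimal sketch of what fixing the seed can look like; the `set_seed` helper is illustrative, not the exact code used for these runs:

```python
import os
import random

import numpy as np

def set_seed(seed: int = 42) -> None:
    """Fix every RNG we control, so runs that differ only in
    batch size are directly comparable."""
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    # If a framework is in play, seed it too, e.g.
    # tf.random.set_seed(seed) or torch.manual_seed(seed)

set_seed(42)
a = np.random.rand(3)
set_seed(42)
b = np.random.rand(3)
print(np.allclose(a, b))  # re-seeding reproduces the same draws
```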

We will use a Weights and Biases sweep to run our ablation study. Let's dig into the results. 👇
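(First, for readers who want to reproduce the setup: a grid sweep over doubling batch sizes can be sketched roughly as below. The metric name, project name, and `train` entry point are illustrative assumptions, not the exact configuration used here.)

```python
# Hypothetical W&B sweep configuration: grid over batch sizes 8..2048,
# doubling each step, with everything else held fixed.
sweep_config = {
    "method": "grid",
    "metric": {"name": "val_accuracy", "goal": "maximize"},
    "parameters": {
        "batch_size": {"values": [2 ** k for k in range(3, 12)]},  # 8 .. 2048
        "epochs": {"value": 25},
    },
}

print(sweep_config["parameters"]["batch_size"]["values"])

# With wandb installed and a train() function defined, the sweep would be
# launched roughly like so:
#   import wandb
#   sweep_id = wandb.sweep(sweep_config, project="batch-size-study")
#   wandb.agent(sweep_id, function=train)
```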

- First off, the validation metrics show that models trained with smaller batch sizes generalize better.
- For our study, we trained with batch sizes ranging from 8 to 2048, doubling each time. The batch size of 32 gave us the best result; the batch size of 2048 gave us the worst.
- Our parallel coordinate plot also makes a key tradeoff very evident: larger batch sizes take less time to train but are less accurate.

- Digging in further, we clearly see a roughly exponential decrease in the test error rate as we move from larger to smaller batch sizes, with batch size 32 achieving the lowest error rate.
- We see a roughly exponential increase in training time as we move from larger to smaller batch sizes. And this is expected! We don't stop early when the model starts to overfit but let every run train for the full 25 epochs, and with epochs fixed, a smaller batch size means many more gradient updates, each with its own overhead.
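The training-time trend is easy to quantify with back-of-the-envelope arithmetic; the 50,000-example training-set size below is an assumption for illustration:

```python
import math

# With epochs fixed, each halving of the batch size roughly doubles
# the number of gradient updates performed during training.
n_train = 50_000  # assumed training-set size
epochs = 25

for batch_size in [8, 32, 256, 2048]:
    steps_per_epoch = math.ceil(n_train / batch_size)
    total_updates = steps_per_epoch * epochs
    print(f"batch_size={batch_size:>4}: {steps_per_epoch:>5} steps/epoch, "
          f"{total_updates:>7} updates total")
```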

What might be the reason(s) to explain this strange behavior? This Stat Exchange thread has a few great hypotheses. Some of my favorites:

- This paper claims that large-batch methods tend to converge to sharp minimizers of the training and testing functions, and that sharp minima lead to poorer generalization. In contrast, small-batch methods consistently converge to flat minimizers.
- Gradient descent-based optimization makes a linear approximation to the cost function. However, if the cost function is highly non-linear (highly curved), the approximation will not be very good, hence small batch sizes are safer.
- When you put m examples in a minibatch, you need to do O(m) computation and use O(m) memory, but you reduce the amount of uncertainty in the gradient by a factor of only O(sqrt(m)). In other words, there are diminishing marginal returns to putting more examples in the minibatch.
- Even using the entire training set doesn’t really give you the true gradient. Using the entire training set is just using a very large minibatch size.
- The gradient computed from a small batch oscillates much more than one computed from a large batch. This oscillation can be considered noise, but for a non-convex loss landscape (which is often the case) that noise helps the optimizer escape local minima. Larger batches thus take fewer, coarser search steps toward the optimal solution, and so by construction are less likely to converge on it.
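The diminishing-returns hypothesis can be illustrated on synthetic numbers (a toy model, not real gradients): the standard deviation of a mean of m unit-variance samples falls like 1/sqrt(m), while the compute cost grows like m.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for per-example gradients: i.i.d. samples with unit variance.
grads = rng.normal(loc=1.0, scale=1.0, size=100_000)

for m in [8, 64, 512]:
    # Averaging m samples shrinks the estimator's std like 1/sqrt(m),
    # even though computing the mean costs O(m).
    batch_means = grads[: (len(grads) // m) * m].reshape(-1, m).mean(axis=1)
    print(f"m={m:>3}: empirical std {batch_means.std():.3f}, "
          f"1/sqrt(m) = {1 / np.sqrt(m):.3f}")
```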

Weights & Biases helps you keep track of your machine learning experiments. Use our tool to log hyperparameters and output metrics from your runs, then visualize and compare results and quickly share findings with your colleagues.

Get started in 5 minutes.