Normalization Series: What is Batch Normalization?
An in-depth blogpost covering Batch Normalization, complete with code and interactive visualizations. Part of a bigger series on Normalization.
Created on December 3 | Last edited on January 19
Link to the Colab
Table of Contents
- 👋 What is Batch Normalization and BatchNorm?
- 🧑🏻‍🏫 Theory
- 📈 Code + Results
- ✌️ Conclusion
👋 What is Batch Normalization and BatchNorm?
First introduced in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift by Sergey Ioffe and Christian Szegedy, Batch Normalization has become one of the most widely used normalization methods in deep learning architectures.
In any machine learning pipeline, normalizing or standardizing the training dataset is generally considered good practice, as it maps the input data to a known, well-behaved distribution. If this step is skipped, the data distribution can be highly skewed or irregular, which makes the resulting features hard to learn and often leads to training instability caused by exploding (imbalanced) gradients.
It follows that it's usually a good idea to normalize the dataset. But what about the values flowing through the network itself? The same problem can happen with the intermediate activations: due to a particular random seed or an especially hard example, the outputs of one neuron can end up on a drastically different scale from the others in the same layer, and when these activations are propagated to the next layer the imbalance cascades, ultimately leading to instability. Thus, we should normalize these intermediate activations as well!
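For context, here is what the dataset-standardization step mentioned above looks like in practice. This is a minimal sketch on a synthetic tensor; the feature scales and the small epsilon constant are illustrative choices, not values from this post:

import torch

# A toy "dataset": 1,000 samples whose 20 features live on wildly different scales.
x = torch.randn(1000, 20) * torch.linspace(0.1, 100.0, 20)

# Standardize each feature to zero mean and unit variance.
mean = x.mean(dim=0, keepdim=True)
std = x.std(dim=0, keepdim=True)
x_standardized = (x - mean) / (std + 1e-8)  # small constant guards against division by zero

print(x_standardized.mean(dim=0))  # ~0 for every feature
print(x_standardized.std(dim=0))   # ~1 for every feature

Batch Normalization applies essentially the same idea, but to the activations inside the network and on a per-mini-batch basis, as described in the next section.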
🧑🏻‍🏫 Theory
The authors set out to solve the cascading effect of small updates within the network: each layer's update shifts the distribution of its outputs, so every subsequent layer must continuously adapt to an ever-changing input distribution. The model should be robust to these small changes.
When the distribution of a layer's inputs changes in this way, it is said to experience "covariate shift" (inside a network, "internal covariate shift").
The algorithm can be summarized as follows:
1. First, we calculate the mean of the distribution (denoted by $\mu_B$) for a single mini-batch of size $m$, using
$$\mu_B = \frac{1}{m} \sum_{i=1}^{m} x_i$$
and the corresponding variance (denoted by $\sigma_B^2$) using
$$\sigma_B^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_B)^2$$
Now that we have the mean and variance, we normalize the distribution using
$$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}$$
where $\epsilon$ is a small constant added for numerical stability.
2. Then, using this normalized distribution, we transform the data once more by multiplying it by a scale parameter $\gamma$ and adding a shift parameter $\beta$:
$$y_i = \gamma \hat{x}_i + \beta$$
We finally feed this transformed data into the next layer.
NOTE:
1. $\gamma$ and $\beta$ are learned parameters.
2. This happens before passing the activations to the next layer, and it happens on a per-batch basis.
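To make the two steps above concrete, here is a minimal from-scratch sketch of the Batch Normalization forward pass for a (batch, features) tensor. The function name and the epsilon value are our own illustrative choices; in practice you would simply use PyTorch's built-in layers, as shown in the next section.

import torch

def batch_norm_forward(x, gamma, beta, eps=1e-5):
    # x has shape (batch_size, num_features); statistics are computed per feature.
    mu = x.mean(dim=0)                          # step 1: mini-batch mean
    var = x.var(dim=0, unbiased=False)          # step 1: mini-batch variance
    x_hat = (x - mu) / torch.sqrt(var + eps)    # step 1: normalize
    return gamma * x_hat + beta                 # step 2: scale and shift

x = torch.randn(32, 10)    # a mini-batch of 32 examples with 10 features each
gamma = torch.ones(10)     # scale, initialized to 1 and learned during training
beta = torch.zeros(10)     # shift, initialized to 0 and learned during training
y = batch_norm_forward(x, gamma, beta)
print(y.mean(dim=0))                   # ~0, since gamma = 1 and beta = 0 here
print(y.var(dim=0, unbiased=False))    # ~1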
📈 Code + Results
It's very easy to implement Batch Normalization using the PyTorch nn API. For example, let's look at this simple neural network architecture, which uses the nn.BatchNorm2d class to create Batch Normalization layers:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Expects 3-channel 32x32 inputs (fc1's 16 * 5 * 5 assumes that spatial size).
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.bn1 = nn.BatchNorm2d(6)   # <-------------- Batch Normalization
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.bn2 = nn.BatchNorm2d(16)  # <-------------- Batch Normalization
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.bn1(x)
        x = self.pool(F.relu(self.conv2(x)))
        x = self.bn2(x)
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
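One practical detail worth calling out: nn.BatchNorm2d keeps running estimates of the mean and variance during training and uses those stored statistics (instead of the current batch's) at inference time, so remember to switch modes. A quick sketch, assuming the Net class and imports from the snippet above (the input shape is just an example):

model = Net()
x = torch.randn(8, 3, 32, 32)  # example mini-batch: 8 three-channel 32x32 images

model.train()                  # BatchNorm uses per-batch statistics and updates its running estimates
train_out = model(x)

model.eval()                   # BatchNorm switches to the stored running mean and variance
with torch.no_grad():
    eval_out = model(x)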
The following graphs compare the aforementioned architecture trained on the MNIST dataset for multi-class classification, using the stochastic gradient descent (SGD) optimizer and the cross-entropy loss for 10 epochs. The Colab used for training these models can be found here.
As we can clearly see, the model trained with Batch Normalization performs better in terms of both cross-entropy loss and multi-class accuracy.
[W&B panel: Run set (4 runs)]
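For reference, a training setup along these lines might look like the following. This is a hedged sketch, not the exact Colab code: the project name, learning rate, metric key, and train_loader are placeholders you would replace with your own.

import torch
import torch.nn as nn
import wandb

wandb.init(project="batchnorm-demo")  # hypothetical project name

model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # illustrative learning rate
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    model.train()
    for images, labels in train_loader:  # train_loader: your own DataLoader
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        wandb.log({"train/loss": loss.item()})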
Because of how awesome Weights & Biases is, we can also visualize our model's predictions using W&B Tables. For example, the table below can help us spot the places where our model makes mistakes.
[W&B Table panel: Run set (4 runs)]
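If you'd like to log a similar table yourself, a minimal sketch looks something like this. The column names and test_loader are illustrative, not the exact ones from the Colab:

import torch
import wandb

table = wandb.Table(columns=["image", "prediction", "label"])

model.eval()
with torch.no_grad():
    for images, labels in test_loader:  # test_loader: your evaluation DataLoader
        preds = model(images).argmax(dim=1)
        for img, pred, label in zip(images, preds, labels):
            table.add_data(wandb.Image(img), pred.item(), label.item())
        break  # log only one batch to keep the table small

wandb.log({"predictions": table})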
✌️ Conclusion
And that wraps up our post on Batch Normalization, the motivation behind it, and its benefits. To see the full suite of W&B features, please check out this short 5-minute guide. If you want more reports covering the math and "from-scratch" code implementations, let us know in the comments down below or on our forum ✨!
Check out these other reports on Fully Connected covering other fundamental concepts like Linear Regression and Decision Trees.
An Introduction to Linear Regression For Machine Learning (With Examples)
In this article, we provide an overview of, and a tutorial on, linear regression using scikit-learn, with code and interactive visualizations so you can follow along.
Decision Trees: A Guide with Examples
A tutorial covering Decision Trees, complete with code and interactive visualizations
What Is Cross Entropy Loss? A Tutorial With Code
A tutorial covering Cross Entropy Loss, with code samples to implement the cross entropy loss function in PyTorch and Tensorflow with interactive visualizations.
Introduction to Cross Validation Techniques
A tutorial covering Cross Validation techniques, complete with code and interactive visualizations.
Introduction to K-Means Clustering (With Examples)
A tutorial covering K-Means Clustering, complete with code and interactive visualizations.