Normalization Series: What is Batch Norm?

An in-depth blog post covering Batch Normalization, complete with code and interactive visualizations. Part of a larger series on normalization.
Saurav Maheshkar

Link to the Colab →


👋 What is BatchNorm?

First introduced in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift by Sergey Ioffe and Christian Szegedy, Batch Normalization has become a well-accepted normalization method in the world of deep learning architectures.
In any machine learning pipeline, normalizing or standardizing the training dataset is generally considered good practice, as it maps the input features onto a known, standard distribution (for example, zero mean and unit variance). If this step is skipped, the data distribution can be highly skewed or irregular, with features on wildly different scales, which makes them hard to learn from. This often leads to unstable training, because features on large scales produce correspondingly large (imbalanced) gradients that can explode.
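To make the dataset-level step concrete, here is a minimal NumPy sketch of standardization (z-scoring), assuming a tabular dataset `X` with one feature per column. The `standardize` helper and the toy data are hypothetical, for illustration only: each column is shifted by its mean and divided by its standard deviation.

```python
import numpy as np

def standardize(X: np.ndarray, eps: float = 1e-8):
    """Hypothetical helper: z-score each feature (column) of X."""
    mean = X.mean(axis=0)            # per-feature mean over the dataset
    std = X.std(axis=0)              # per-feature standard deviation
    return (X - mean) / (std + eps), mean, std

# Hypothetical toy dataset: two features on very different scales.
rng = np.random.default_rng(0)
X = np.column_stack([
    rng.normal(loc=50_000, scale=15_000, size=1_000),  # e.g. an income-like feature
    rng.normal(loc=0.5, scale=0.1, size=1_000),        # e.g. a ratio-like feature
])

X_std, mean, std = standardize(X)
print(X.std(axis=0))        # very different scales before
print(X_std.mean(axis=0))   # ~[0, 0]
print(X_std.std(axis=0))    # ~[1, 1]
```

In practice the `mean` and `std` computed on the training set would be reused to transform validation and test data, so that all splits share the same scale.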
It follows that it's often preferred to normalize the dataset. But what about the inputs to the hidden layers? We can imagine the same problem happening inside the network: due to some random seed or a particularly hard example, the weights of a particular neuron can become drastically different from those of the other neurons in the same layer. Its activations then sit on a very different scale, and when they are propagated to the next layer (and used to compute the next updates), this imbalance can compound and ultimately lead to instability. Thus, we should normalize the activations of the hidden layers as well!
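A small NumPy sketch of the imbalance described above, with made-up layer sizes: one neuron's weights are scaled up by 100x, so its activations end up on a very different scale from its neighbours'. Normalizing each neuron's activations over the batch, which is what Batch Normalization does, puts them back on a common scale. All names and numbers here are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical hidden layer: batch of 64 examples, 8 inputs, 4 neurons.
x = rng.normal(size=(64, 8))
W = rng.normal(scale=0.1, size=(8, 4))
W[:, 0] *= 100.0            # pretend neuron 0's weights drifted far from the others

h = x @ W                   # pre-activations, shape (batch, neurons)
print(h.std(axis=0))        # neuron 0's spread dwarfs the others'

# Normalize each neuron (column) over the batch dimension.
h_norm = (h - h.mean(axis=0)) / np.sqrt(h.var(axis=0) + 1e-5)
print(h_norm.std(axis=0))   # all close to 1
```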

🧑🏻‍🏫 Theory

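The authors propose a solution to the cascading effect of small updates: a small change to the parameters of one layer changes the distribution of the inputs to every layer after it, so the model must be robust to these small changes and continuously adapt to the ever-changing distribution produced by each layer.
When the input distribution of a learning system changes, it is said to experience "covariate shift"; the authors call this change in the distribution of a network's internal activations during training "Internal Covariate Shift".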
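A stylized NumPy sketch of this shift inside a network, under simplifying assumptions: the "update" to layer 1 is modelled as scaling its weights `W1` by 1.5 (a hypothetical stand-in for a gradient step), so the effect is easy to read. The batch of inputs stays fixed, yet the distribution reaching layer 2 widens. Per-feature normalization over the batch removes this particular shift, which is the intuition behind Batch Normalization.

```python
import numpy as np

def batch_normalize(h, eps=1e-5):
    # Per-feature normalization over the batch dimension (no gamma/beta).
    return (h - h.mean(axis=0)) / np.sqrt(h.var(axis=0) + eps)

rng = np.random.default_rng(1)
x = rng.normal(size=(256, 16))                 # a fixed batch of network inputs
W1 = rng.normal(scale=0.25, size=(16, 32))     # hypothetical layer-1 weights

def layer1(x, W1):
    return np.maximum(x @ W1, 0.0)             # ReLU(x @ W1): what layer 2 receives

before = layer1(x, W1)
after = layer1(x, 1.5 * W1)                    # stylized "update": weights grow by 50%

# Same inputs, but layer 2 now sees a wider distribution.
print(before.std(), after.std())               # after is 1.5x wider

# After per-feature normalization the two are (almost) identical.
print(np.abs(batch_normalize(before) - batch_normalize(after)).max())
```

A real gradient step changes more than the scale, so normalization does not undo it completely; the sketch only isolates the scale component.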
The algorithm can be summarized as follows:
  1. First, we calculate the mean of the distribution (denoted by \mu_b) over a single mini-batch of size m, using
\large \mu_b = \frac{1}{m} \sum_{i=1}^{m} x_i
and the corresponding variance (denoted by \sigma^2_b) using
\large \sigma^2_b = \frac{1}{m} \sum_{i=1}^{m} \, (x_i - \mu_b)^2
Now that we have the mean and variance, we normalize each input using
\huge \hat{x}_i = \frac{x_i - \mu_b}{\sqrt{\sigma^2_b + \epsilon}}
where \epsilon is a small constant added for numerical stability.
  2. Then, we scale the normalized values by a parameter gamma \gamma, shift them by a parameter beta \beta, and pass the result on to the next layer:
\huge y_i = \gamma \hat{x}_i \,\,+ \,\, \beta
NOTE: 1. \gamma and \beta are learned parameters; because they can take the values \gamma = \sqrt{\sigma^2_b + \epsilon} and \beta = \mu_b, the layer can still represent the identity transform if that is what the network needs. 2. The statistics are computed separately for each feature (each channel, for convolutional layers) over every mini-batch, and the transformed activations are what get passed on to the next layer.
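The two steps can be sketched from scratch in NumPy for a `(batch, features)` array. This is a minimal training-mode sketch, not the PyTorch implementation: it computes the mini-batch mean \mu_b and variance \sigma^2_b per feature, normalizes with the small constant \epsilon (`eps`) inside the square root as in the original paper, then scales by \gamma and shifts by \beta. The values of `gamma` and `beta` below are hypothetical stand-ins for learned parameters.

```python
import numpy as np

def batch_norm_forward(x, gamma, beta, eps=1e-5):
    """Training-mode Batch Normalization for a (batch, features) array."""
    mu_b = x.mean(axis=0)                        # mini-batch mean, per feature
    var_b = ((x - mu_b) ** 2).mean(axis=0)       # mini-batch variance (divides by m)
    x_hat = (x - mu_b) / np.sqrt(var_b + eps)    # step 1: normalize
    return gamma * x_hat + beta                  # step 2: scale and shift

rng = np.random.default_rng(0)
m, d = 32, 4                                     # mini-batch size m, d features
x = rng.normal(loc=3.0, scale=5.0, size=(m, d))

gamma = np.full(d, 2.0)                          # hypothetical "learned" values
beta = np.full(d, 0.5)
y = batch_norm_forward(x, gamma, beta)
print(y.mean(axis=0))   # ≈ beta = 0.5 for every feature
print(y.std(axis=0))    # ≈ gamma = 2.0 (a hair less because of eps)

# With gamma = sqrt(var_b + eps) and beta = mu_b, the layer recovers its input.
x_back = batch_norm_forward(x, np.sqrt(x.var(axis=0) + 1e-5), x.mean(axis=0))
print(np.allclose(x_back, x))  # True
```

The last check illustrates why \gamma and \beta matter: normalization alone would force every feature to zero mean and unit variance, while the learned scale and shift let the network recover the original activations when that is useful.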

📈 Code + Results

Batch Normalization is very easy to add using PyTorch's nn API. For example, let's look at this simple neural network architecture, which uses the nn.BatchNorm2d class to create Batch Normalization layers:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    # Expects 3x32x32 inputs: 32 -> conv(5) -> 28 -> pool -> 14 -> conv(5) -> 10 -> pool -> 5,
    # which is where 16 * 5 * 5 in fc1 comes from.
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.bn1 = nn.BatchNorm2d(6)  # <-------------- Batch Normalization
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.bn2 = nn.BatchNorm2d(16)  # <-------------- Batch Normalization
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.bn1(x)  # normalize each of the 6 channels over the batch
        x = self.pool(F.relu(self.conv2(x)))
        x = self.bn2(x)  # normalize each of the 16 channels over the batch
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
```
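One behaviour of nn.BatchNorm2d worth knowing when using it in a model like this: in training mode it normalizes with the current mini-batch statistics and updates running estimates of the mean and variance; in eval mode it normalizes with those running estimates instead. The sketch below assumes PyTorch's default `momentum=0.1` and a hypothetical random input batch with 6 channels.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(6)                        # 6 channels, like bn1 in Net
x = torch.randn(8, 6, 14, 14) * 3 + 2         # hypothetical batch: N=8, C=6, H=W=14

bn.train()                                    # training mode: use batch statistics
y_train = bn(x)
# Per-channel mean over (N, H, W) is ~0 after normalization.
print(y_train.mean(dim=(0, 2, 3)))

# Running estimates start at 0 and move 10% of the way toward the batch mean.
print(bn.running_mean)                        # ≈ 0.1 * x.mean(dim=(0, 2, 3))

bn.eval()                                     # inference mode: use running estimates
y_eval = bn(x)
print(torch.allclose(y_train, y_eval))        # False: different statistics are used
```

This is why `model.train()` and `model.eval()` should be called before training and evaluation respectively.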
The following graphs compare the architecture above trained with and without Batch Normalization on the MNIST dataset for multiclass classification, using the Stochastic Gradient Descent optimizer and the Cross Entropy Loss for 10 epochs. The Colab used to train these models can be found here.
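For reference, here is a minimal sketch of a training setup of this kind (SGD, Cross Entropy Loss, 10 epochs). It is not the code behind the graphs: the data is synthetic (labels come from a hypothetical random linear rule), a small MLP with nn.BatchNorm1d stands in for Net, and the learning rate and batch size are assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in data: 512 examples, 64 features, 10 learnable classes.
X = torch.randn(512, 64)
W_true = torch.randn(64, 10)
y = (X @ W_true).argmax(dim=1)

model = nn.Sequential(
    nn.Linear(64, 32),
    nn.BatchNorm1d(32),      # drop this line for a "without BN" baseline
    nn.ReLU(),
    nn.Linear(32, 10),
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # lr is an assumption

def evaluate():
    model.eval()                          # BN uses its running statistics
    with torch.no_grad():
        logits = model(X)
        acc = (logits.argmax(dim=1) == y).float().mean().item()
        return criterion(logits, y).item(), acc

initial_loss, _ = evaluate()
for epoch in range(10):
    model.train()                         # BN uses mini-batch statistics
    for i in range(0, len(X), 64):        # mini-batches of size 64
        xb, yb = X[i:i + 64], y[i:i + 64]
        optimizer.zero_grad()
        batch_loss = criterion(model(xb), yb)
        batch_loss.backward()
        optimizer.step()
    loss, acc = evaluate()
    print(f"epoch {epoch + 1}: loss={loss:.3f} acc={acc:.3f}")
```

Running the same loop with and without the BatchNorm1d line, and logging `loss` and `acc` per epoch, gives the kind of side-by-side comparison shown in the graphs.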
As we can clearly see, the model trained with Batch Normalization performs better on both metrics: it reaches a lower Cross Entropy Loss and a higher MultiClassAccuracy.
Because of how awesome Weights & Biases is, we can also visualize our model's predictions using W&B Tables. For example, this table can help us spot the places where our model makes mistakes.

✌️ Conclusion

And that wraps up our post on Batch Normalization, its motivation, and its benefits. To see the full suite of W&B features, please check out this short 5-minute guide. If you want more reports covering the math and "from-scratch" code implementations, let us know in the comments below or on our forum ✨!
Check out these other reports on Fully Connected covering other fundamental concepts like Linear Regression and Decision Trees.