Instance Normalization in PyTorch (With Examples)
A quick introduction to Instance Normalization in PyTorch, complete with code and an example to get you started. Part of a bigger series covering the various types of widely used normalization techniques.
Check out the other posts in our Normalization Series.
👋 What is Instance Norm?
Before we dive into Instance Norm, let's have a quick primer on why we use Normalization techniques in the first place.
Internal covariate shift, the change in the distribution of the data as it passes through the network, is a common problem when you're training a deep learning network. Even when the inputs belong to the same class, covariate shift can change the distribution of intermediate activations drastically, which makes the network much harder to train. We employ Normalization techniques to address this, reparameterizing the data distribution at each step so that it stays consistent. In most cases, this also speeds up training and makes gradient-based optimization more stable.
But why can't we just always use Batch Normalization? Doesn't that take care of the covariate shift and boost performance too? Well, yes, but it also has some disadvantages:
The performance of the model is highly dependent on the batch size: it degrades with smaller batch sizes and is only acceptable with bigger ones.
To perform Batch Normalization, you have to wait for the entire batch to be processed, because the mean and standard deviation are computed across all points in the batch. This makes Batch Normalization unsuitable for sample-based Stochastic Gradient Descent (as opposed to the widely used mini-batch SGD) and for Recurrent Neural Networks (RNNs).
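To make this batch dependence concrete, here is a minimal PyTorch sketch (our own illustration on random tensors, not code from the original Colab). The same example is placed in two different mini-batches: with `nn.BatchNorm2d` in training mode, its normalized output changes with its batch-mates, while `nn.InstanceNorm2d` produces the same output both times.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Two mini-batches of shape (N, C, H, W) that share the same first example.
x_first = torch.randn(1, 3, 8, 8)
batch_a = torch.cat([x_first, torch.randn(3, 3, 8, 8)])
batch_b = torch.cat([x_first, 5 * torch.randn(3, 3, 8, 8) + 2])  # very different batch-mates

bn = nn.BatchNorm2d(3)        # training mode by default: normalizes with batch statistics
inorm = nn.InstanceNorm2d(3)  # normalizes each example's channels with their own statistics

# BatchNorm: the output for the shared example depends on the rest of the batch.
bn_a, bn_b = bn(batch_a)[0], bn(batch_b)[0]
print(torch.allclose(bn_a, bn_b, atol=1e-5))  # expected: False

# InstanceNorm: the output for the shared example is identical in both batches.
in_a, in_b = inorm(batch_a)[0], inorm(batch_b)[0]
print(torch.allclose(in_a, in_b, atol=1e-5))  # expected: True
```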
Let's see how Instance Normalization works and how it is different from the other techniques.
In Batch Normalization, we compute the mean and standard deviation of each channel across the entire mini-batch (and all spatial positions).
In Layer Normalization, we compute the mean and standard deviation across all the channels (and spatial positions) of a single example.
In Instance Normalization, we compute the mean and standard deviation of each individual channel of a single example, over its spatial positions.
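Written as equations (a sketch in our own notation, assuming an image input $x$ of shape $(N, C, H, W)$), Instance Normalization computes one mean and one variance per example $n$ and channel $c$, over the $H \times W$ spatial positions only:

```latex
\mu_{nc} = \frac{1}{HW}\sum_{h=1}^{H}\sum_{w=1}^{W} x_{nchw},
\qquad
\sigma^2_{nc} = \frac{1}{HW}\sum_{h=1}^{H}\sum_{w=1}^{W}\left(x_{nchw}-\mu_{nc}\right)^2

y_{nchw} = \gamma_c\,\frac{x_{nchw}-\mu_{nc}}{\sqrt{\sigma^2_{nc}+\epsilon}} + \beta_c
```

Here $\epsilon$ is a small constant for numerical stability, and $\gamma_c, \beta_c$ are optional learnable per-channel scale and shift parameters (PyTorch's `InstanceNorm` layers leave them off by default, `affine=False`). Batch Normalization would instead average over $n$, $h$ and $w$ (one $\mu_c$ per channel), and Layer Normalization over $c$, $h$ and $w$ (one $\mu_n$ per example).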
Using the figure above as a reference, we can see that Layer Normalization normalizes across all the channels of a single example. In Instance Normalization, by contrast, we calculate a separate mean and standard deviation for each channel of each example in our mini-batch.
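The difference comes down to which axes of an `(N, C, H, W)` tensor the statistics are averaged over. The minimal sketch below (our own illustration; `instance_norm_manual` is a hypothetical helper, not a PyTorch function) prints the shape of the mean each method would use, then checks a hand-written Instance Norm against PyTorch's `torch.nn.functional.instance_norm`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 3, 5, 5)  # (N, C, H, W)

# The axes each method averages over determine the shape of its statistics.
bn_mean = x.mean(dim=(0, 2, 3), keepdim=True)  # Batch Norm: one mean per channel, shape (1, C, 1, 1)
ln_mean = x.mean(dim=(1, 2, 3), keepdim=True)  # Layer Norm: one mean per example, shape (N, 1, 1, 1)
in_mean = x.mean(dim=(2, 3), keepdim=True)     # Instance Norm: one mean per (example, channel), shape (N, C, 1, 1)
print(bn_mean.shape, ln_mean.shape, in_mean.shape)


def instance_norm_manual(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Hypothetical helper: normalize each channel of each example over its spatial positions."""
    mean = x.mean(dim=(2, 3), keepdim=True)
    var = ((x - mean) ** 2).mean(dim=(2, 3), keepdim=True)  # biased (population) variance
    return (x - mean) / torch.sqrt(var + eps)


manual = instance_norm_manual(x)
builtin = F.instance_norm(x, eps=1e-5)  # no affine weight/bias, no running statistics
print(torch.allclose(manual, builtin, atol=1e-5))
```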
In fact, one can see Instance and Layer Normalization as special cases of Group Normalization. For more details, read our post on Group Normalization here.
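A quick way to check this relationship in PyTorch, as a sketch on random data: `nn.GroupNorm` with one group per channel behaves like `nn.InstanceNorm2d`, and with a single group it behaves like `nn.LayerNorm` over `(C, H, W)`. The match below holds at initialization, when the learnable scale and shift of `GroupNorm` and `LayerNorm` are still ones and zeros; after training, their affine parameters (per-channel vs. per-element) would generally differ.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, C, H, W = 2, 6, 4, 4
x = torch.randn(N, C, H, W)

# One group per channel: each (example, channel) pair is normalized on its own,
# which is what Instance Normalization does.
gn_as_instance = nn.GroupNorm(num_groups=C, num_channels=C)
inorm = nn.InstanceNorm2d(C)

# A single group spanning all channels: each example is normalized over (C, H, W),
# which is what Layer Normalization over the last three dimensions does.
gn_as_layer = nn.GroupNorm(num_groups=1, num_channels=C)
lnorm = nn.LayerNorm([C, H, W])

same_as_instance = torch.allclose(gn_as_instance(x), inorm(x), atol=1e-5)
same_as_layer = torch.allclose(gn_as_layer(x), lnorm(x), atol=1e-5)
print(same_as_instance, same_as_layer)
```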
📈 Code + Results
One can easily use Instance Normalization from the torch.nn API, using either InstanceNorm1d, InstanceNorm2d or InstanceNorm3d (for 3D, 4D or 5D inputs, respectively) depending on the use case. The following graphs compare the aforementioned architecture trained on the MNIST dataset for multi-class classification using the Stochastic Gradient Descent optimizer and the Cross Entropy Loss for 10 epochs. The Colab used for training these models can be found here.
As we can clearly see, the model performs better when trained using Instance Normalization in terms of both the Cross Entropy Loss and MultiClassAccuracy.
Because of how awesome Weights & Biases is, we can also visualize our model's predictions using W&B Tables. For example, let's look at this table, which can help us spot the places where our model makes mistakes.
The following W&B Parallel Plot demonstrates how the Accuracy and Cross Entropy Loss vary with the type of Normalization used.
And that wraps up our post on using Instance Normalization, the motivation behind it, and its benefits. To see the full suite of W&B features, please check out this short 5-minute guide. If you want more reports covering the math and "from-scratch" code implementations, let us know in the comments below or on our forum.
Check out these other reports on Fully Connected covering other fundamental concepts like Linear Regression and Decision Trees.