Layer Normalization in PyTorch (With Examples)

A quick and dirty introduction to Layer Normalization in PyTorch, complete with code and interactive panels. Made by Adrish Dey using Weights & Biases
Adrish Dey


Training machine learning algorithms can be a challenging task, especially with real-world datasets. Among the numerous pitfalls one can fall into, keeping the statistics of the intermediate activations stable is often pretty high on the list.
In this report, we'll have a quick discussion of one of the common methods used for statistical stabilization: Layer Norm. This report continues our series on Normalizations in Machine Learning, which started with Batch Norm. We hope to have the last couple out before the end of the year.
Check out the Colab notebook for these experiments.

What is Layer Normalization anyway?

The Problem

As you may know already, training a machine learning model is a stochastic (random) process. This stems from the fact that the initializations, and even the most common optimizers (SGD, Adam and so on), are stochastic in nature.
Because of this, ML optimization runs the risk of converging to sharp (poorly generalizing) minima in the loss landscape, which come with large gradients. Simply put: the activations (a.k.a. the outputs of a non-linear layer) tend to shoot up to large values. This is not ideal, to say the least, and the most common fix is Batch Normalization.
However, there's a catch. Batch Normalization quickly breaks down as the batch size shrinks, because its mean and variance are estimated from the items in the current batch. As modern ML models work with ever higher-resolution data, this becomes a big problem: the batch size has to be small for the data to fit in memory. Furthermore, Batch Normalization requires keeping running estimates of the mean and variance of the activations at each layer. This does not carry over to iterative models (like RNNs), where these statistics depend on the length of the sequence (i.e., the number of times the same hidden layer is called).
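The batch-size dependence can be seen in a few lines. This is a minimal illustration, assuming a fully connected layer with 4 features (an arbitrary choice) and PyTorch's built-in nn.BatchNorm1d and nn.LayerNorm in training mode: LayerNorm gives an item the same output regardless of its batch mates, BatchNorm does not, and BatchNorm cannot compute batch statistics at all from a single item.

```python
import torch

torch.manual_seed(0)
features = 4
bn = torch.nn.BatchNorm1d(features)  # statistics over the batch dimension
ln = torch.nn.LayerNorm(features)    # statistics within each item

x = torch.randn(8, features)  # a batch of 8 items

# LayerNorm: the first item's output is the same whether or not the
# other 7 items are in the batch.
ln_independent = torch.allclose(ln(x)[:1], ln(x[:1]), atol=1e-6)

# BatchNorm (training mode): the same two items are normalized
# differently depending on which other items share their batch.
bn_independent = torch.allclose(bn(x)[:2], bn(x[:2]), atol=1e-6)

# With a single item there is no batch variance to compute.
try:
    bn(x[:1])
    bn_single_item_ok = True
except ValueError:
    bn_single_item_ok = False

print(ln_independent, bn_independent, bn_single_item_ok)  # True False False
```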

The Solution

LayerNorm offers a simple solution to both these problems by calculating the statistics (i.e., mean and variance) for each item in a batch of activations, and normalizing each item with these statistical estimates.
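In symbols, for a single item x with D elements (for convolutional activations, D = C·H·W), the computation can be sketched as below, following the PyTorch documentation for nn.LayerNorm. Here ε is a small constant for numerical stability (1e-5 by default in PyTorch), and γ, β are the learnable elementwise scale and shift that nn.LayerNorm applies when elementwise_affine=True (its default).

```latex
\mu = \frac{1}{D}\sum_{i=1}^{D} x_i, \qquad
\sigma^2 = \frac{1}{D}\sum_{i=1}^{D} (x_i - \mu)^2, \qquad
y_i = \gamma_i \, \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta_i
```

Note that the variance is the biased (divide-by-D) estimate, and μ and σ² are computed separately for every item in the batch.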
Specifically, given activations of shape [N, C, H, W], LayerNorm calculates one mean and one variance over the [C, H, W] elements of each of the N items in the batch (see the figure below). This not only solves both problems mentioned above, but also removes the need to store running means and variances for inference (something Batch Normalization layers have to do during training).
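As a sanity check of the description above, here is a small sketch (the sizes N, C, H, W are arbitrary) that computes the per-item statistics by hand and compares them with nn.LayerNorm([C, H, W]) at initialization, where the learnable scale is 1 and the shift is 0. It also lists the module's buffers: LayerNorm keeps no running statistics, so its output is identical in training and evaluation mode, whereas nn.BatchNorm2d registers running estimates.

```python
import torch

torch.manual_seed(0)
N, C, H, W = 4, 5, 6, 6  # arbitrary example sizes
x = torch.randn(N, C, H, W)

# One mean and one (biased) variance per item, over its C, H, W elements.
mean = x.mean(dim=(1, 2, 3), keepdim=True)                  # shape [N, 1, 1, 1]
var = ((x - mean) ** 2).mean(dim=(1, 2, 3), keepdim=True)   # shape [N, 1, 1, 1]
manual = (x - mean) / torch.sqrt(var + 1e-5)                # 1e-5 = default eps

ln = torch.nn.LayerNorm([C, H, W])  # weight = 1, bias = 0 at initialization
matches = torch.allclose(ln(x), manual, atol=1e-5)

# Nothing is accumulated during training, so train and eval modes agree.
buffers = [name for name, _ in ln.named_buffers()]
train_out = ln.train()(x)
eval_out = ln.eval()(x)
same_in_eval = torch.allclose(train_out, eval_out)

# BatchNorm, by contrast, stores running statistics for inference.
bn_buffers = [name for name, _ in torch.nn.BatchNorm2d(C).named_buffers()]

print(matches, buffers, same_in_eval)  # True [] True
print(bn_buffers)  # ['running_mean', 'running_var', 'num_batches_tracked']
```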

Let's see some code

Implementing Layer Normalization in PyTorch is a relatively simple task. To do so, you can use torch.nn.LayerNorm().
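A minimal sketch of the API, assuming an arbitrary input of shape [2, 7, 16] (for example a batch of sequences of feature vectors): the normalized_shape argument names the trailing dimensions to normalize over, and the module's learnable weight and bias have exactly that shape. The functional form torch.nn.functional.layer_norm does the same computation without learnable parameters.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 7, 16)  # arbitrary example: [batch, sequence length, features]

# normalized_shape lists the trailing dimensions to normalize over.
ln_last = torch.nn.LayerNorm(16)         # statistics over the last dimension
ln_last2 = torch.nn.LayerNorm([7, 16])   # statistics over the last two dimensions

print(ln_last.weight.shape, ln_last2.weight.shape)  # torch.Size([16]) torch.Size([7, 16])

y = ln_last(x)  # same shape as x
# Each of the 2 * 7 feature vectors now has (numerically) zero mean.
print(y.mean(dim=-1).abs().max().item() < 1e-5)  # True

# The functional form matches the freshly initialized module
# (weight = 1, bias = 0).
print(torch.allclose(F.layer_norm(x, (16,)), y, atol=1e-6))  # True
```

The trailing dimensions of the input must match normalized_shape exactly, which is why convolutional networks need to know the activation-map size when the layer is constructed.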
For convolutional neural networks, however, one also needs the shape of the output activation map, which depends on the parameters of the convolution (kernel size, stride, padding and dilation), because torch.nn.LayerNorm takes the normalized shape as a constructor argument. A simple implementation is provided in the calc_activation_shape() function below (feel free to reuse it in your project).
```python
import torch
import torch.nn.functional as F


class Network(torch.nn.Module):
    @staticmethod
    def calc_activation_shape(
        dim, ksize, dilation=(1, 1), stride=(1, 1), padding=(0, 0)
    ):
        # Output size of a convolution along each spatial dimension
        # (same formula as in the torch.nn.Conv2d documentation).
        def shape_each_dim(i):
            odim_i = dim[i] + 2 * padding[i] - dilation[i] * (ksize[i] - 1) - 1
            return odim_i // stride[i] + 1  # floor division: sizes must be ints

        return shape_each_dim(0), shape_each_dim(1)

    def __init__(self, idim, num_classes=10):
        # idim: (H, W) of the 3-channel input images
        super().__init__()
        self.layer1 = torch.nn.Conv2d(3, 5, 3)
        ln_shape = Network.calc_activation_shape(idim, (3, 3))  # <--- Calculate the shape of output of Convolution
        self.norm1 = torch.nn.LayerNorm([5, *ln_shape])  # <--- Normalize activations over C, H, and W (see fig. above)
        self.layer2 = torch.nn.Conv2d(5, 10, 3)
        ln_shape = Network.calc_activation_shape(ln_shape, (3, 3))
        self.norm2 = torch.nn.LayerNorm([10, *ln_shape])
        # Fully connected classifier on the flattened [10, H', W'] activations
        self.layer3 = torch.nn.Linear(10 * ln_shape[0] * ln_shape[1], num_classes)

    def forward(self, inputs):
        x = F.relu(self.norm1(self.layer1(inputs)))
        x = F.relu(self.norm2(self.layer2(x)))
        x = torch.sigmoid(self.layer3(torch.flatten(x, 1)))
        return x
```
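To sanity-check the output-size formula, one can compare it against a real torch.nn.Conv2d. The sketch below uses a hypothetical standalone helper, conv2d_output_shape, that mirrors calc_activation_shape with floor division; the input sizes and convolution settings are arbitrary choices for illustration.

```python
import torch


def conv2d_output_shape(dim, ksize, dilation=(1, 1), stride=(1, 1), padding=(0, 0)):
    # Hypothetical standalone copy of calc_activation_shape, using floor
    # division as in the torch.nn.Conv2d output-size formula.
    return tuple(
        (dim[i] + 2 * padding[i] - dilation[i] * (ksize[i] - 1) - 1) // stride[i] + 1
        for i in range(2)
    )


def check(dim, ksize, dilation, stride, padding):
    # Run a real convolution on a dummy input and compare spatial sizes.
    conv = torch.nn.Conv2d(3, 5, ksize, stride=stride, padding=padding, dilation=dilation)
    out = conv(torch.zeros(1, 3, *dim))
    return tuple(out.shape[-2:]) == conv2d_output_shape(dim, ksize, dilation, stride, padding)


# (input size, kernel, dilation, stride, padding)
configs = [
    ((32, 32), (3, 3), (1, 1), (1, 1), (0, 0)),
    ((28, 28), (5, 5), (1, 1), (2, 2), (2, 2)),
    ((33, 17), (3, 3), (2, 2), (2, 1), (1, 0)),
]
for cfg in configs:
    print(cfg, check(*cfg))  # True for each configuration
```

With stride 2, the floor division matters: in the second configuration, true division would give 14.5 instead of 14, and a fractional size is not a valid shape for nn.LayerNorm.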
We benchmark the model provided in our Colab notebook with and without Layer Normalization, as shown in the following chart. Layer Norm does quite well here. (Note: each curve averages 4 runs; the solid line is the mean of these runs and the lighter shaded band is the standard deviation.)


Thanks for joining us for this quick introduction to Layer Normalization. Over the coming weeks, we'll be showcasing a few other popular normalization techniques and then ending our series with a meta-report with some tips and tricks on when to use which (though, of course, context matters a ton here).
If you have any requests for other foundational techniques you'd like us to cover, please leave them in the comments below. Till next time!