Layer Normalization in PyTorch (With Examples)
A quick and dirty introduction to Layer Normalization in PyTorch, complete with code and interactive panels.
Created on December 3 | Last edited on July 9
Training machine learning algorithms can be a challenging task, especially with real-world datasets. Among the numerous pitfalls that one can fall into, statistical stabilization of the intermediate activations is often pretty high on the list.
In this report, we'll have a quick discussion of one of the common methods used for statistical stabilization: Layer Norm. This Report is a continuation of our series on Normalizations in Machine Learning, which started with Batch Norm. We hope to have the last couple out before the end of the year.
What Is Layer Normalization, Anyway?
The Problem
As you may already know, training a machine learning model is a stochastic (random) process. This stems from the fact that weight initializations, and even the most common optimizers (SGD, Adam, and so on), are stochastic in nature.
Because of this, ML optimization runs the risk of converging to sharp (poorly generalizing) minima on the loss landscape and suffering from large, unstable gradients. Simply put: the activations (i.e., the outputs of the non-linear layers) have a tendency to blow up to large values. This is not ideal, to say the least, and the most common fix is Batch Normalization.
However, there's a catch here. Batch Normalization quickly breaks down as the batch size shrinks, because the per-batch statistics become too noisy to be reliable. As modern ML models work with ever higher-resolution data, this becomes a big problem: the batch size often has to be small just to fit the data in memory. Furthermore, Batch Normalization requires keeping a running mean/variance of the activations at each layer. This doesn't work for iterative models (like RNNs), where the statistics of a layer depend on the length of the sequence (i.e., on how many times the same hidden layer is called).
The Solution
LayerNorm offers a simple solution to both these problems by calculating the statistics (i.e., mean and variance) for each item in a batch of activations and normalizing each item with these statistical estimates.
Specifically, given a sample of shape (N, C, H, W), LayerNorm calculates a mean and variance over all the elements of shape (C, H, W) for each of the N items in the batch (see the figure below). This not only solves both problems mentioned above, but also removes the need to store running means and variances for inference (something Batch Normalization layers have to do during training).

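To make this concrete, here's a minimal sketch (the tensor sizes are arbitrary, chosen only for illustration) that checks torch.nn.LayerNorm against the same per-sample statistics computed by hand:
import torch

torch.manual_seed(0)

# A toy batch of activations: N=4 samples, C=5 channels, H=W=8
x = torch.randn(4, 5, 8, 8)

# LayerNorm over the last three dimensions (C, H, W): one mean/variance per sample
layer_norm = torch.nn.LayerNorm([5, 8, 8])
out = layer_norm(x)

# The same statistics computed by hand, one pair per item in the batch
mean = x.mean(dim=(1, 2, 3), keepdim=True)
var = x.var(dim=(1, 2, 3), unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + layer_norm.eps)

# Matches at initialization, since the learnable scale/shift start at 1 and 0
print(torch.allclose(out, manual, atol=1e-5))  # True
Note that the reduction never crosses the batch dimension, which is exactly why the result does not depend on the batch size.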
Let's see some code
Implementing Layer Normalization in PyTorch is a relatively simple task. To do so, you can use torch.nn.LayerNorm().
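As a minimal usage sketch (the sizes below are just placeholders), you pass LayerNorm the shape of the trailing dimensions you want to normalize over:
import torch

# Normalize over the last (feature) dimension, as is typical for sequence models
norm = torch.nn.LayerNorm(256)

x = torch.randn(32, 100, 256)  # (batch, sequence length, features)
y = norm(x)                    # same shape; each feature vector is normalized independently
print(y.shape)                 # torch.Size([32, 100, 256])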
For convolutional neural networks, however, one also needs to calculate the shape of the output activation map given the parameters used for the convolution. A simple implementation is provided in the calc_activation_shape() function below. (Feel free to reuse it in your projects.)
import torch
import torch.nn.functional as F


class Network(torch.nn.Module):
    @staticmethod
    def calc_activation_shape(dim, ksize, dilation=(1, 1), stride=(1, 1), padding=(0, 0)):
        # Standard formula for the spatial size of a convolution's output
        if isinstance(ksize, int):
            ksize = (ksize, ksize)

        def shape_each_dim(i):
            odim_i = dim[i] + 2 * padding[i] - dilation[i] * (ksize[i] - 1) - 1
            return odim_i // stride[i] + 1

        return shape_each_dim(0), shape_each_dim(1)

    def __init__(self, idim, num_classes=10):
        super().__init__()
        self.layer1 = torch.nn.Conv2d(3, 5, 3)
        ln_shape = Network.calc_activation_shape(idim, 3)  # <--- Calculate the shape of the convolution output
        self.norm1 = torch.nn.LayerNorm([5, *ln_shape])    # <--- Normalize activations over C, H, and W (see fig. above)
        self.layer2 = torch.nn.Conv2d(5, 10, 3)
        ln_shape = Network.calc_activation_shape(ln_shape, 3)
        self.norm2 = torch.nn.LayerNorm([10, *ln_shape])
        self.layer3 = torch.nn.Linear(10 * ln_shape[0] * ln_shape[1], num_classes)

    def forward(self, inputs):
        x = F.relu(self.norm1(self.layer1(inputs)))
        x = F.relu(self.norm2(self.layer2(x)))
        x = torch.flatten(x, start_dim=1)  # flatten to (N, features) before the fully connected layer
        x = torch.sigmoid(self.layer3(x))
        return x
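As a quick sanity check of the implementation above (assuming 32x32 RGB inputs, e.g. CIFAR-10-sized images), you can run a dummy batch through the network:
import torch

net = Network(idim=(32, 32), num_classes=10)
images = torch.randn(8, 3, 32, 32)  # (N, C, H, W)
out = net(images)
print(out.shape)                    # torch.Size([8, 10])
One caveat of this design: because the LayerNorm shapes are baked in from idim at construction time, the network only accepts inputs of that exact spatial size.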
We benchmarked the model from our Colab notebook with and without Layer Normalization, as shown in the chart below. Layer Norm does quite well here. (Note: we average over 4 runs; the solid line denotes the mean of those runs, and the lighter band the standard deviation.)
Run set (8 runs)
Conclusion
Thanks for joining us for this quick introduction to Layer Normalization. Over the coming weeks, we'll be showcasing a few other popular normalization techniques, and we'll end the series with a meta-report offering some tips and tricks on when to use which (though, of course, context matters a ton here).
If you have any requests for other foundational techniques you'd like us to cover, please leave them in the comments below. Till next time!
Thanks for the explanation. Nevertheless, using CNNs with LayerNorm might not be the best practice as suggested by Layer Normalization paper authors. https://arxiv.org/abs/1607.06450
Its performance is analyzed and tested on fully connected and recurrent neural networks, as shown in the paper.