Group Normalization in PyTorch (With Examples)

A quick introduction to group normalization in PyTorch, complete with code and an example to get you started. Made by Adrish Dey using Weights & Biases
In this report, we will look into yet another widely used normalization technique in deep learning: group normalization. First introduced by Wu et al. [1], group normalization serves as an alternative to layer normalization and instance normalization for tackling the same statistical instabilities posed by batch normalization.
Here, we'll use the W&B toolset to visually demonstrate the effectiveness of group norm against batch normalization and no normalization function at all. The notebook used for generating the results is linked below.
Link to Colab Notebook →


What is Group Normalization?

Much like batch normalization, group normalization is used to statistically stabilize the training process. This is done by constraining the intermediate activations to approximately follow a unit normal distribution. However, as discussed in the layer normalization report, batch normalization relies heavily on large batch sizes, which are difficult to obtain for larger models and datasets.
Instance normalization addresses this by normalizing the activations of each channel separately for every sample, while layer norm normalizes across all channels of every sample. Group normalization serves as a trade-off between these two methods.
Simply put, group norm divides the channels into multiple fixed-size groups of G channels each and normalizes across each group, separately for every sample. Formally speaking, for a given activation of shape [N, C, H, W] and a group size of G, group norm calculates the mean \mu_G and standard deviation \sigma_G per group by reshaping the activation to [N, \frac{C}{G}, G, H, W]. These statistical moments (i.e., \mu_G and \sigma_G) are then used to normalize the activations within each group, using a formula similar to the one used in batch normalization, written out below.
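Spelling that formula out (a sketch in the report's notation; the small constant \epsilon and the learnable per-channel scale \gamma and shift \beta are the usual batch-norm-style details, not stated explicitly above), the computation for one group \mathcal{G} of G channels within a single sample is:

\mu_G = \frac{1}{G H W} \sum_{c \in \mathcal{G}} \sum_{h=1}^{H} \sum_{w=1}^{W} x_{chw}, \qquad \sigma_G^2 = \frac{1}{G H W} \sum_{c \in \mathcal{G}} \sum_{h=1}^{H} \sum_{w=1}^{W} \left( x_{chw} - \mu_G \right)^2

\hat{x}_{chw} = \frac{x_{chw} - \mu_G}{\sqrt{\sigma_G^2 + \epsilon}}, \qquad y_{chw} = \gamma_c \, \hat{x}_{chw} + \beta_c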
Group normalization is particularly useful, as it provides an intuitive way to interpolate between layer norm (G = C) and instance norm (G = 1), with G serving as an extra hyperparameter to tune. A small sketch of these two extremes follows.
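Note that torch.nn.GroupNorm takes the number of groups (not the group size) as its first argument, so the report's group size G corresponds to num_groups = C / G. Below is a minimal sketch of the interpolation; the tensor shapes are illustrative, and the equivalences to layer norm and instance norm hold up to how the learnable affine parameters are shaped.

import torch

# Illustrative shapes; any [N, C, H, W] activation works.
N, C, H, W = 8, 32, 14, 14
x = torch.randn(N, C, H, W)

layer_norm_like = torch.nn.GroupNorm(1, C)     # G = C: one group containing every channel
instance_norm_like = torch.nn.GroupNorm(C, C)  # G = 1: every channel is its own group
group_norm = torch.nn.GroupNorm(4, C)          # in between: 4 groups of 8 channels each

for name, norm in [("layer-norm-like", layer_norm_like),
                   ("instance-norm-like", instance_norm_like),
                   ("group norm", group_norm)]:
    print(name, norm(x).shape)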

Code for Group Norm in PyTorch

Implementing group normalization in any framework is simple. However, PyTorch makes it even simpler by providing a plug-and-play module: torch.nn.GroupNorm.
import torch

num_groups = 4

# MNIST classifier
net = torch.nn.Sequential(
    torch.nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
    # GroupNorm takes the number of groups to divide the
    # channels into and the number of channels to expect
    # in the input
    torch.nn.GroupNorm(num_groups, 32),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),
    torch.nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
    torch.nn.GroupNorm(num_groups, 64),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),
    torch.nn.Flatten(),
    torch.nn.Linear(7 * 7 * 64, 10),
)
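As a quick sanity check (assuming MNIST-shaped inputs, i.e. 28 × 28 grayscale images; the batch size below is arbitrary), the network can be run on a dummy batch:

x = torch.randn(16, 1, 28, 28)   # dummy batch of 16 MNIST-sized images
logits = net(x)
print(logits.shape)              # torch.Size([16, 10]), one logit per digit class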

Experiments

In the experiments, we use the same architecture as defined above and use W&B Sweeps to compare batch norm, group norm, and the absence of any normalization technique. As you can see in the chart below, group normalization (with num_groups = 8) outperforms batch normalization.
The following W&B parallel coordinates plot demonstrates how much each variable (num_groups, norm_type, and batch size) matters for accuracy. As you can see here, group norm takes advantage of the larger number of iterations that comes as a byproduct of a small batch size to obtain the highest accuracy.
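For readers who want to reproduce a comparison like this, here is a minimal sketch of how such a sweep could be configured with the wandb Python API. The parameter names mirror the ones in the plot above, but the project name, value grids, and the training stub are illustrative assumptions rather than the exact configuration used for this report.

import wandb

# Sweep over the same variables shown in the parallel coordinates plot.
sweep_config = {
    "method": "grid",
    "metric": {"name": "accuracy", "goal": "maximize"},
    "parameters": {
        "norm_type": {"values": ["none", "batch_norm", "group_norm"]},
        "num_groups": {"values": [2, 4, 8]},     # illustrative grid
        "batch_size": {"values": [8, 32, 128]},  # illustrative grid
    },
}

def train():
    run = wandb.init()
    cfg = wandb.config
    # Build the model with cfg.norm_type / cfg.num_groups, train on MNIST
    # with cfg.batch_size, then report the final metric, e.g.:
    # wandb.log({"accuracy": accuracy})
    run.finish()

sweep_id = wandb.sweep(sweep_config, project="group-norm-report")  # hypothetical project name
wandb.agent(sweep_id, function=train)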

Conclusion

Thanks for joining us for this quick introduction to Group Normalization. Over the coming weeks, we'll be showcasing a few other popular normalization techniques and then ending our series with a meta-report with some tips and tricks on when to use which. In the meantime, if you missed our prior reports, you can find those below:
Report Gallery