
Group Normalization in Pytorch (With Examples)

A quick introduction to group normalization in Pytorch, complete with code and an example to get you started
Created on December 17 | Last edited on January 13

In this report, we will look into yet another widely used normalization technique in deep learning: group normalization. First introduced by Wu et al. [1], group normalization serves as an alternative to layer normalization and instance normalization for tackling the statistical instabilities posed by batch normalization's dependence on batch size.
Here, we'll use the W&B toolset to visually demonstrate the effectiveness of group norm against batch normalization and against using no normalization at all. The notebook used for generating the results is linked below.


What is Group Normalization?

Much like batch normalization, group normalization is used for statistical stabilization of the training process. This is done by constraining the intermediate activations to follow a simple unit normal distribution. However, as discussed in the layer normalization report, batch normalization relies heavily on large batch sizes, a property which is difficult to obtain for larger models and datasets.
Instance normalization solves this by normalizing the activations of each channel per sample, while layer norm does it by normalizing across all channels per sample. Group normalization serves as a trade-off between these two methods.
Simply put, group norm divides the channels into multiple fixed-size groups of size G and normalizes across each group per sample. Formally speaking, for a given activation of shape [N, C, H, W] and a group size of G, group norm calculates the mean μ_G and standard deviation σ_G per group by reshaping the activation as [N, C/G, G, H, W]. These statistical moments (i.e., μ_G and σ_G) are then used to normalize the activations along each group using a formula similar to the one used in batch normalization.
Group normalization is particularly useful, as it allows an intuitive way to interpolate between layer norm (G = C) and instance norm (G = 1), where G serves as an extra hyperparameter to optimize for.
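To make the computation above concrete, here is a minimal sketch of the reshape-and-normalize step in plain Pytorch. The function name group_norm and the group_size argument are illustrative, and the sketch omits the learnable scale and shift parameters that the full formulation (and torch.nn.GroupNorm, shown later) includes.

import torch

def group_norm(x, group_size, eps=1e-5):
    # x has shape [N, C, H, W]; group_size is G, the number of channels per group
    N, C, H, W = x.shape
    # Reshape to [N, C/G, G, H, W] so each group can be treated as one unit
    x = x.view(N, C // group_size, group_size, H, W)
    # Mean and standard deviation are computed per sample and per group,
    # i.e. over the group's channels and the spatial dimensions
    mean = x.mean(dim=(2, 3, 4), keepdim=True)
    var = x.var(dim=(2, 3, 4), keepdim=True, unbiased=False)
    x = (x - mean) / torch.sqrt(var + eps)
    return x.view(N, C, H, W)

# Example: 32 channels split into groups of 8 channels each
activations = torch.randn(16, 32, 28, 28)
normalized = group_norm(activations, group_size=8)

Setting group_size = C puts every channel into a single group and recovers layer-norm-style behavior, while group_size = 1 normalizes each channel on its own, as instance norm does.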




Code for Group Norm in Pytorch

Implementing group normalization in any framework is simple. However, Pytorch makes it even simpler by providing a plug-and-play module: torch.nn.GroupNorm.
import torch

num_groups = 4

# MNIST Classifier
net = torch.nn.Sequential(
    torch.nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
    # GroupNorm takes the number of groups to divide the
    # channels into and the number of channels to expect
    # in the input
    torch.nn.GroupNorm(num_groups, 32),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),

    torch.nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
    torch.nn.GroupNorm(num_groups, 64),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),

    torch.nn.Flatten(),
    torch.nn.Linear(7 * 7 * 64, 10),
)
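As a quick sanity check (a hypothetical snippet, assuming 28×28 single-channel MNIST inputs), you can push a dummy batch through the network and confirm the output shape:

dummy = torch.randn(16, 1, 28, 28)  # a fake batch of 16 MNIST-sized images
logits = net(dummy)
print(logits.shape)  # torch.Size([16, 10]): one logit per digit class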


Experiments

In these experiments, we use the same architecture as defined above and use W&B Sweeps to compare batch norm, group norm, and the absence of any normalization technique. As you can see in the chart below, group normalization (with num_groups = 8) outperforms batch normalization.

Run set: 3 runs
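For reference, a sweep like the one behind these results can be defined directly in Python. The configuration below is only a sketch: the parameter names and value ranges (norm_type, num_groups, batch_size) and the train function are assumptions for illustration, not the exact sweep used here.

import wandb

# Hypothetical sweep definition: grid-search over normalization type,
# number of groups, and batch size, maximizing validation accuracy
sweep_config = {
    "method": "grid",
    "metric": {"name": "val_accuracy", "goal": "maximize"},
    "parameters": {
        "norm_type": {"values": ["none", "batch_norm", "group_norm"]},
        "num_groups": {"values": [2, 4, 8, 16]},
        "batch_size": {"values": [32, 128, 512]},
    },
}

sweep_id = wandb.sweep(sweep_config, project="normalization-report")
# wandb.agent(sweep_id, function=train)  # train() builds the model and logs metrics to W&B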

The following W&B Parallel Plot demonstrates the importance of each variable (num_groups, norm_type, and batch size) for accuracy. As you can see here, group norm leveraged the higher number of iterations that come as a byproduct of a small batch size to obtain the highest accuracy.

Run set: 30 runs


Conclusion

Thanks for joining us for this quick introduction to group normalization. Over the coming weeks, we'll be showcasing a few other popular normalization techniques and then ending our series with a meta-report containing tips and tricks on when to use which. In the meantime, if you missed our prior reports, you can find them below:
