Exploring Adaptive Gradient Clipping and NFNets
A minimal ablation study of the proposed contributions in the latest High-Performance Large-Scale Image Recognition Without Normalization paper.
Created on March 5|Last edited on March 18
Introduction
Over the past month or so, many in the computer vision community have been digesting a paper on NFNets that recently achieved state-of-the-art performance on ImageNet classification, "High-Performance Large-Scale Image Recognition Without Normalization". We wanted to take a moment to dig into the research by Brock, De, Smith, and Simonyan (2021) here, talk about the advantages (and disadvantages) of batch normalization, and mostly just roll up our sleeves and talk about this really fascinating research for a few hundred words.
Of course, if you'd like to dig into the source material (and we do recommend that!), you can find those links below. Without further ado:
Paper | GitHub | Official Repo
The Saga of Batch Normalization
Batch normalization (BN) was partially responsible for the growth of deep learning by enabling the training of deeper neural networks. Originally proposed in 2015, batch normalization normalizes the hidden representations learned during training (i.e., the outputs of hidden layers), with the stated goal of addressing internal covariate shift.
Note, however, that BN's effectiveness appears to have little to do with internal covariate shift. In "How Does Batch Normalization Help Optimization?", Santurkar et al. (2018) found that BN has a more fundamental impact on training: "it makes the optimization landscape significantly smoother. This smoothness induces a more predictive and stable behavior of the gradients, allowing for faster training."
📌 Note: The experimental results shown below were produced by training a custom convolutional neural network on the CIFAR-10 dataset. Check out the linked Colab notebooks for implementation details.
Whatever the actual reason BN works, there are many practical benefits to training a model with it. For starters:
- Models trained with BN converge faster and reach better test accuracy.
Try out the Colab Notebook
- BN also enables model training with a large learning rate.
Try out the Colab Notebook
There are a few additional benefits to batch normalization. Namely:
- BN allows efficient large-batch training.
- BN eliminates mean shift. Activation functions like ReLU and GELU are non-symmetric and thus have non-zero mean activations, which introduces a mean shift. Batch normalization ensures the mean activation on each channel is zero across the current batch, eliminating the mean shift (a small sketch of this effect follows the list below).
- BN has a regularization effect. (Source: Towards Understanding Regularization in Batch Normalization by Luo et al.)
- BN smooths loss landscape. (Source: "How Does Batch Normalization Help Optimization?" by Santurkar et al.)
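To make the mean-shift point concrete, here is a minimal, self-contained PyTorch sketch (my own illustration, not from the report's notebooks) that passes zero-mean inputs through a linear layer and a ReLU, then shows how a BatchNorm layer re-centers the activations:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Zero-mean Gaussian inputs passed through a linear layer + ReLU:
# the ReLU outputs have a positive mean ("mean shift").
x = torch.randn(512, 256)
layer = nn.Linear(256, 256)
relu_out = torch.relu(layer(x))
print(f"Mean after ReLU alone:       {relu_out.mean().item():.4f}")  # noticeably > 0

# Adding BatchNorm after the ReLU re-centers each channel to
# (approximately) zero mean across the batch, removing the shift.
bn = nn.BatchNorm1d(256)
bn_out = bn(relu_out)
print(f"Mean after ReLU + BatchNorm: {bn_out.mean().item():.4f}")    # ~0
```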
Now, while batch normalization is a key component of most image classification models, it does come with some undesirable properties. The research community keeps finding ways around these, but in the long run it might be better to find an alternative to batch normalization instead of dealing with its idiosyncrasies and downsides.
In fact, let's cover a few of the disadvantages of batch normalization:
- It incurs memory overhead. (Source: In-Place Activated BatchNorm for Memory-Optimized Training of DNNs)
- Batch normalization increases the time needed to evaluate the gradient in some networks.
- It can introduce discrepancies between training and inference behavior if not used carefully, since inference relies on running statistics rather than batch statistics.
- BN can break the independence between training examples in the minibatch. Additionally, because of this particular issue:
- It's hard to reproduce the results on different hardware.
- You can run into subtle implementation errors especially in distributed training. For this reason, Synchronized Batch Normalization was proposed by Zhang et al. in Context Encoding for Semantic Segmentation.
- And since batch statistics are computed while training, which can be seen as an interaction between training examples, networks can "cheat" certain loss functions. This is a major concern for sequence modeling tasks, which has driven language models to adopt alternative normalizers.
- Moreover, networks can also degrade if the batch statistics have a large variance during training.
📌 Note: The experimental results shown below were produced by training a ResNet-20 model on the CIFAR-10 dataset. Check out the linked Kaggle kernel for implementation details.
- Lastly, the performance of batch normalization is sensitive to the batch size. The parallel coordinates plot below shows a negative correlation between batch size and final test accuracy (in other words: larger batch sizes generally lead to lower test accuracy).
Try out the Kaggle kernel
📌 Note: "Large-batch training does not achieve higher test accuracies within a fixed epoch budget (Smith et al.,2020), it does achieve a given test accuracy in fewer parameter updates, significantly improving training speed when parallelized across multiple devices."
Towards Normalization-Free Networks
The benefits of BN tell us which ingredients are required for a high-performing neural network. A workable alternative to BN should bring us most (if not all) of these benefits while also mitigating the disadvantages spelled out above.
Now, previous works have attempted to train deep ResNets to competitive accuracies without normalization by recovering just one or two benefits of BN. The key idea used in those works is to suppress the scale of the activations on the residual branch at initialization by introducing a small constant or learnable scalars.
Normalizer-Free ResNets (NF-ResNets) were first proposed in a paper titled "Characterizing Signal Propagation to close the performance gap in Unnormalized ResNets" by Brock et al. (2021). NF-ResNets are a class of pre-activation ResNets that can be trained to competitive training and test accuracies without normalization layers.
If you want to get up to speed with the ResNet architecture, here's a nice video summary of the paper by Yannic Kilcher. But how is an NF-ResNet different from the good old ResNet?
1. NF-ResNet employs a residual block of the form $h_{i+1} = h_i + \alpha f_i(h_i / \beta_i)$, where:
- $h_i$ is the input to the $i$-th residual block and $f_i(h_i / \beta_i)$ is the output of the $i$-th residual branch; $h_{i+1}$ is the input to the next residual block.
- $f_i$ is parameterized to be a variance-preserving function at initialization, such that $\mathrm{Var}(f_i(z)) = \mathrm{Var}(z)$.
- $\alpha$ is a scalar that specifies the rate at which the variance of the activations increases after each residual block.
- $\beta_i = \sqrt{\mathrm{Var}(h_i)}$ is the standard deviation of the inputs to the $i$-th residual block.
2. NF-ResNet uses Scaled Weight Standardization. Weight Standardization reparameterizes a convolutional layer ($W \rightarrow \hat{W}$) such that
$\hat{W}_{ij} = \frac{W_{ij} - \mu_i}{\sigma_i}$, where:
- $\hat{W}$ and $W$ are the reparameterized and original weights respectively,
- $\mu_i = \frac{1}{N} \sum_j W_{ij}$, and
- $\sigma_i^2 = \frac{1}{N} \sum_j (W_{ij} - \mu_i)^2$, with $N$ the fan-in of the layer.
Scaled Weight Standardization is a minor modification of Weight Standardization, where $\hat{W}_{ij} = \frac{W_{ij} - \mu_i}{\sqrt{N}\,\sigma_i}$.
3. The activation functions are also scaled by a non-linearity-specific scalar gain $\gamma$, chosen so that the combination of the $\gamma$-scaled activation and a Scaled Weight Standardized layer is variance preserving.
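To make these pieces concrete, here is a minimal PyTorch sketch of a Scaled Weight Standardized convolution and the normalizer-free residual update $h_{i+1} = h_i + \alpha f_i(h_i / \beta_i)$. This is my own simplification, not the timm or official DeepMind implementation, and the block layout is reduced to two convolutions for clarity:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledWSConv2d(nn.Conv2d):
    """Conv2d with Scaled Weight Standardization: W_hat = (W - mean) / (sqrt(N) * std)."""
    def forward(self, x):
        w = self.weight
        fan_in = w[0].numel()  # N = (in_channels / groups) * kernel_h * kernel_w
        mean = w.mean(dim=(1, 2, 3), keepdim=True)
        var = w.var(dim=(1, 2, 3), keepdim=True, unbiased=False)
        # Dividing by sqrt(N * var) implements the 1 / (sqrt(N) * sigma) scaling.
        w_hat = (w - mean) / (var * fan_in + 1e-4).sqrt()
        return F.conv2d(x, w_hat, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

class NFBlock(nn.Module):
    """Simplified normalizer-free residual block: h_{i+1} = h_i + alpha * f(h_i / beta)."""
    def __init__(self, channels, alpha=0.2, beta=1.0):
        super().__init__()
        self.alpha, self.beta = alpha, beta
        self.conv1 = ScaledWSConv2d(channels, channels, 3, padding=1)
        self.conv2 = ScaledWSConv2d(channels, channels, 3, padding=1)
        self.gamma = 1.7139  # non-linearity-specific gain for ReLU

    def forward(self, h):
        out = self.gamma * F.relu(self.conv1(h / self.beta))
        out = self.conv2(out)
        return h + self.alpha * out

x = torch.randn(8, 64, 32, 32)
print(NFBlock(64)(x).shape)  # torch.Size([8, 64, 32, 32])
```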
📌 Fact: "With additional regularization (Dropout and Stochastic Depth), NF-ResNets match the test accuracies achieved by batch normalized pre-activation ResNets on ImageNet at batch size 1024. They also significantly outperform their batch normalized counterparts when the batch size is very small, but they perform worse than batch normalized networks for large batch sizes (4096 or higher)."
📌 Note: "NF-ResNets do not match the performance of state-of-the-art networks like EfficientNets which use Batch Normalization."
ResNet vs. NF-ResNet on CIFAR-10 with Batch Size 32
In the end, NF-ResNet-50 outperforms the good old ResNet-50 by a margin of roughly 15%. (I used Ross Wightman's timm package for both model definitions.)
Train NF-ResNet on Colab Notebook
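If you want to reproduce this comparison, both architectures can be instantiated directly from timm. This is a minimal sketch; the model names below are the ones the package registers, and you can pass pretrained=True for ImageNet weights instead of training from scratch:

```python
import timm
import torch

# Standard ResNet-50 and its normalizer-free counterpart from timm.
resnet50 = timm.create_model("resnet50", pretrained=False, num_classes=10)
nf_resnet50 = timm.create_model("nf_resnet50", pretrained=False, num_classes=10)

x = torch.randn(2, 3, 224, 224)
print(resnet50(x).shape, nf_resnet50(x).shape)  # both: torch.Size([2, 10])
```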
Adaptive Gradient Clipping
One remaining problem is that NF-ResNets could not be scaled to large batch sizes (4096 or higher) during training. The authors of the paper we're concerned with today (High-Performance Large-Scale Image Recognition Without Normalization) hypothesized that gradient clipping should help scale NF-ResNets to larger batches. To this end, they proposed Adaptive Gradient Clipping (AGC).
A standard clipping algorithm clips the gradient $G$ before updating the parameters $\theta$ such that:
$$G \rightarrow \begin{cases} \lambda \frac{G}{\|G\|} & \text{if } \|G\| > \lambda, \\ G & \text{otherwise.} \end{cases}$$
Here, $\lambda$ (the clipping threshold) is the hyperparameter to be tuned.
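In PyTorch this corresponds to a single call to torch.nn.utils.clip_grad_norm_ between the backward pass and the optimizer step. Here is a minimal sketch of a training step; loader, model, loss_fn, and optimizer are assumed to be defined elsewhere:

```python
import torch

clip_threshold = 1.0  # the lambda hyperparameter; sensitive to the training setup

for images, labels in loader:  # `loader`, `model`, `loss_fn`, `optimizer` assumed defined
    optimizer.zero_grad()
    loss = loss_fn(model(images), labels)
    loss.backward()
    # Rescale the global gradient norm to at most `clip_threshold` before the update.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_threshold)
    optimizer.step()
```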
Gradient clipping can help train at a higher learning rate but is quite sensitive to the clipping threshold (evident from the media panel below).
Try out the Colab Notebook
AGC tries to overcome this issue by introducing adaptive clipping instead of "hard" clipping.
📌 Note: "The AGC algorithm is motivated by the observation that the ratio of the norm of the gradients to the norm of the weights of layer , , provides a simple measure of how much a single gradient descent step will change the original weights ."
The authors use unit-wise ratios of gradient norms to parameter norms instead of layer-wise norm ratios, computed with the Frobenius norm ($\|\cdot\|_F$). AGC clips the gradient $G_i^\ell$ of the $i$-th unit (e.g., the $i$-th row of a weight matrix) of layer $\ell$ such that:
$$G_i^\ell \rightarrow \begin{cases} \lambda \frac{\|W_i^\ell\|_F^\star}{\|G_i^\ell\|_F} G_i^\ell & \text{if } \frac{\|G_i^\ell\|_F}{\|W_i^\ell\|_F^\star} > \lambda, \\ G_i^\ell & \text{otherwise,} \end{cases}$$
where $\lambda$ is the clipping hyperparameter and $\|W_i\|_F^\star = \max(\|W_i\|_F, \epsilon)$ with $\epsilon = 10^{-3}$. The small value of $\epsilon$ prevents zero-initialized parameters from always having their gradients clipped to zero.
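The following is a minimal PyTorch sketch of unit-wise AGC. It is my own simplification; the official JAX implementation and timm's version differ in details such as how norms are computed per layer type:

```python
import torch

def unitwise_norm(x):
    """Frobenius norm per 'unit' (per output row/filter); 1-D tensors get a plain norm."""
    if x.ndim <= 1:  # biases, gains
        return x.norm()
    dims = tuple(range(1, x.ndim))  # reduce over everything except the output dimension
    return x.norm(dim=dims, keepdim=True)

def adaptive_clip_grad_(parameters, clip_factor=0.01, eps=1e-3):
    """In-place AGC: clip each unit's gradient so ||G_i|| / max(||W_i||, eps) <= clip_factor."""
    for p in parameters:
        if p.grad is None:
            continue
        w_norm = unitwise_norm(p.detach()).clamp_(min=eps)
        g_norm = unitwise_norm(p.grad.detach())
        max_norm = clip_factor * w_norm
        # Rescale only the units whose gradient norm exceeds the adaptive threshold.
        clipped = p.grad * (max_norm / g_norm.clamp(min=1e-6))
        p.grad.detach().copy_(torch.where(g_norm > max_norm, clipped, p.grad))

# Usage (model/optimizer assumed defined): call between loss.backward() and optimizer.step().
# adaptive_clip_grad_(model.parameters(), clip_factor=0.01)
```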
📌 Fact: "Using AGC, we can train NF-ResNets stably with larger batch sizes (up to 4096), as well as with very strong data augmentations like RandAugment."
In order to understand AGC and see its effect on model training, I set up a bunch of experiments. The intent of these experiments is not to conclude anything definitive but rather to explore; they cover only a tiny fraction of the possible experiment configurations.
In the paper, AGC stabilizes NF-ResNets in larger-batch configurations, although a larger batch size does not by itself lead to higher test accuracy. I wanted to investigate the effect of AGC and whether it should be considered a viable alternative to good old gradient clipping. So how does AGC do?
Effect of AGC on Normalizer Free Networks
First, a note about the three models I trained:
- A baseline ResNet-20 architecture with Batch Normalization.
- The same ResNet-20 architecture without the Batch Normalization layers, trained with the same configuration (Colab Notebook).
- The same normalization-free architecture trained with Adaptive Gradient Clipping.
📌 Note: The results are from training with a batch size of 1024 and a clipping factor of 0.01.
- AGC could not produce a model comparable to the baseline.
- The test accuracy with AGC is approximately the same as that of one trained without Batch Normalization.
- The result is tied to the choice of clipping factor. That said, we probably shouldn't conclude much from this experiment, and I highly suggest investigating a wider configuration space before making any broad claims here.
In my opinion, the clipping factor of 0.01 might be too tight. I encourage readers to experiment with larger clipping factors.
Relation Between Batch Size and Clipping Factor
The clipping factor for regular gradient clipping is sensitive to batch size, model depth, learning rate, etc. I wanted to investigate the relationship between batch size and clipping factor and their correlation with the final test accuracy.
Using a Weights & Biases Sweep, I was able to quickly set up this ablation study (a minimal sweep configuration is sketched below).
📌 Note: The experimental results shown below were produced by training a ResNet-20 model on the CIFAR-10 dataset. Check out the linked Kaggle kernel for implementation details.
Try out the Kaggle Kernel
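For reference, a sweep over batch size and clipping factor takes only a few lines to configure. This is a minimal sketch; the metric name and value ranges are illustrative, not the exact ones from the kernel:

```python
import wandb

# Grid search over batch size and AGC clipping factor, tracking final test accuracy.
sweep_config = {
    "method": "grid",
    "metric": {"name": "test_accuracy", "goal": "maximize"},
    "parameters": {
        "batch_size": {"values": [64, 128, 256, 512, 1024]},
        "clip_factor": {"values": [0.01, 0.02, 0.04, 0.08, 0.16]},
    },
}

sweep_id = wandb.sweep(sweep_config, project="agc-ablation")
# wandb.agent(sweep_id, function=train)  # `train` reads wandb.config.batch_size / clip_factor
```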
- Batch size has a negative correlation with test accuracy. This is consistent with the earlier observation that increasing the batch size might not lead to better test accuracy.
- On the other hand, the clipping factor has a positive correlation. Thus in this configuration space, increasing the clipping factor should lead to higher test accuracy.
- However, it's hard to know exactly what to make of the relationship between the two parameters. Some trends I noticed in this experiment, which might not hold in all cases:
- Small batch size requires a bigger clipping factor.
- Bigger batch sizes tend to work better with smaller clipping factors.
- Of the two hyperparameters, batch size matters more for reaching higher test accuracy.
The authors performed many ablation studies of their own to showcase the effectiveness of AGC on NF-ResNets. These are the key findings:
- The benefits of using AGC are smaller when the batch size is small.
- Smaller clipping thresholds are necessary for stability at higher batch sizes. This matches what I saw in my experiments as well.
- It's better to not clip the final linear layer or the output layer.
- It's possible to train NF-ResNet without clipping the gradients for the initial convolutional layers.
NF-Net: The Current SOTA On ImageNet

🎉 SOTA Warning: Our NFNet-F1 model achieves comparable accuracy to an EffNet-B7 while being 8.7X faster to train. Our NFNet-F5 model has similar training latency to EffNet-B7, but achieves a state-of-the-art 86.0% top-1 accuracy on ImageNet.
📌 Note: NF-ResNets and NF-Nets are two different architectures.
Neural network architecture design depends on the choice of metric to optimize. These metrics can be:
- FLOPs count,
- Inference latency on a target device like edge/mobile devices,
- Training latency on an accelerator.
The authors of this paper decided to optimize training latency on existing accelerators. They explored the model space manually, searching for design changes that improved top-1 accuracy on ImageNet against actual training latency on the device. The aim was to do well on both fronts.
The verdict? The authors achieved the state-of-the-art result of 86.5% top-1 test accuracy on ImageNet by training NFNet-F6 with the recently proposed Sharpness Aware Minimization (SAM) technique. The official implementation of the architecture is in the JAX framework. You can find the official GitHub repo here.
Additionally, Yannic Kilcher did an amazing explanation of NFNets in his YouTube video. Check it out here:
I have provided a Colab notebook to train an NF-Net model using PyTorch Lightning on the Caltech-101 dataset. The NF-Net implementation is based on the timm package. Feel free to change the model variant and other hyperparameters.
Train NFNet on Colab Notebook using PyTorch Lightning
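As a starting point, the model itself can be pulled from timm and wrapped in a small LightningModule. This is a minimal sketch; the variant name, learning rate, optimizer, and number of classes are illustrative, not the notebook's exact settings:

```python
import timm
import torch
import torch.nn.functional as F
import pytorch_lightning as pl

class NFNetClassifier(pl.LightningModule):
    def __init__(self, variant="nfnet_f0", num_classes=101, lr=1e-3):
        super().__init__()
        # timm registers the NFNet family under names like "nfnet_f0" ... "nfnet_f5".
        self.model = timm.create_model(variant, pretrained=False, num_classes=num_classes)
        self.lr = lr

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        loss = F.cross_entropy(self(images), labels)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr, momentum=0.9)

# trainer = pl.Trainer(max_epochs=10)
# trainer.fit(NFNetClassifier(), train_dataloaders=caltech101_loader)  # loader assumed defined
```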
Conclusion and Acknowledgements
The experiments were conducted on the CIFAR-10 dataset, which might not be enough for a technique to show its full effect, but I limited my experimental setup to Colab notebooks and Kaggle kernels for ease of use.
I want to give a shoutout to the works that enabled me to compile this report with its various code examples and experimental results:
Finally, congrats to the authors on this amazing work.
I would also like to thank Sayak Paul and Morgan McGuire for their feedback, which helped me improve this report. Additional thanks to Justin Tenuto for his editorial magic.