Exploring Adaptive Gradient Clipping and NFNets

A minimal ablation study of the contributions proposed in the recent paper "High-Performance Large-Scale Image Recognition Without Normalization".
Ayush Thakur

Introduction

Over the past month or so, many in the computer vision community have been digesting "High-Performance Large-Scale Image Recognition Without Normalization", the NFNets paper that recently achieved state-of-the-art performance on ImageNet classification. We wanted to take a moment to dig into the research by Brock, De, Smith, and Simonyan (2021) here, talk about the advantages (and disadvantages) of batch normalization, and mostly just roll up our sleeves and walk through this really fascinating research for a few hundred words.
Of course, if you'd like to dig into the source material (and we do recommend that!), you can find those links below. Without further ado:

Paper | GitHub | Official Repo

The Saga of Batch Normalization

Batch normalization (BN) was partially responsible for the growth of deep learning, as it enabled the training of deeper neural networks. Originally proposed in 2015, batch normalization normalizes the hidden representations learned during training (i.e., the outputs of hidden layers) in order to address internal covariate shift.
Note, however, that BN's effectiveness has little to do with internal covariate shift. In the paper titled "How Does Batch Normalization Help Optimization?", Santurkar et al. (2018) uncovered that BN has a more fundamental impact on training: "it makes the optimization landscape significantly smoother. This smoothness induces a more predictive and stable behavior of the gradients, allowing for faster training."
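To make the mechanics concrete, here is a minimal sketch (my own illustration, not code from either paper) of what a batch normalization layer computes for a convolutional feature map in PyTorch:

```python
import torch
import torch.nn as nn

# Per-channel batch normalization of a convolutional feature map.
x = torch.randn(32, 64, 16, 16)       # (batch, channels, height, width)
bn = nn.BatchNorm2d(num_features=64)  # learnable per-channel scale (gamma) and shift (beta)

# In training mode, BN normalizes each channel with the statistics of the current batch...
y = bn(x)

# ...which, at default initialization (gamma=1, beta=0), matches this manual computation
# (ignoring the running statistics BN also tracks for use at inference time).
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), keepdim=True, unbiased=False)
y_manual = (x - mean) / torch.sqrt(var + bn.eps)
```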
📌 Note: The experimental results shown below were produced by training a custom convolutional-based neural network on the CIFAR-10 dataset. Check out the linked colab notebooks for implementation details.
Whatever the actual reason BN works, there are many practical benefits to training a model with it. For starters:

Try out the Colab Notebook \rightarrow

Try out the Colab Notebook \rightarrow

There are a few additional benefits to batch normalization. Namely:
Now, while batch normalization is a key component of most image classification models, it does come with some undesirable properties. The research community keeps finding ways around these, but in the long run it might be better to find an alternative to batch normalization rather than keep dealing with its idiosyncrasies and downsides.
In fact, let's cover a few of the disadvantages of batch normalization:
📌 Note: The experimental results shown below were produced by training a ResNet-20 model on the CIFAR-10 dataset. Check out the linked Kaggle kernel for implementation details.

Try out the Kaggle kernel \rightarrow

📌 Note: "Large-batch training does not achieve higher test accuracies within a fixed epoch budget (Smith et al.,2020), it does achieve a given test accuracy in fewer parameter updates, significantly improving training speed when parallelized across multiple devices."

Towards Normalization-Free Networks

The benefits of BN tell us which ingredients a high-performing neural network needs. A workable alternative to BN should bring us most (if not all) of these benefits while also mitigating the disadvantages spelled out above.
Now, previous works have attempted to train deep ResNets to competitive accuracies without normalization, but they recover only one or two of BN's benefits. The key idea in those works is to suppress the scale of the activations on the residual branch at initialization by introducing small constants or learnable scalars.
Normalizer-Free ResNets (NF-ResNets) were first proposed in the paper "Characterizing Signal Propagation to Close the Performance Gap in Unnormalized ResNets" by Brock et al. (2021). NF-ResNets are a class of pre-activation ResNets that can be trained to competitive training and test accuracies without normalization layers.
If you want to get up to speed on the ResNet architecture, here's a nice video summary of the paper by Yannic Kilcher. But how is NF-ResNet different from the good old ResNet?
1. NF-ResNet employs a residual block of the form h_{i+1} = h_i + αf_i(h_i/β_i), where:
2. NF-ResNet uses Scaled Weight Standardization. Weight Standardization reparameterizes the convolutional layers (W_{i,j} \rightarrow \hat W_{i,j}) such that
\hat W_{i,j} = (W_{i,j} - μ_i) / σ_i, where:
Scaled Weight Standardization is a minor modification of Weight Standardization in which σ_i^2 = \frac{1}{N} \sum_{j=1}^N (W_{i,j} - μ_i)^2, with N the fan-in (see the code sketch after this list).
3. The activation functions are also scaled by a non-linearity specific scalar gain γ.
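To make points 1-3 concrete, here is a minimal PyTorch sketch using my own class names (ScaledWSConv2d, NFBlock) and placeholder defaults; it illustrates the idea rather than reproducing the authors' official implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledWSConv2d(nn.Conv2d):
    """Conv2d with Scaled Weight Standardization: W_hat = gamma * (W - mu) / (sigma * sqrt(N))."""
    def __init__(self, *args, gamma: float = 1.0, eps: float = 1e-4, **kwargs):
        super().__init__(*args, **kwargs)
        self.gamma = gamma  # non-linearity specific scalar gain
        self.eps = eps

    def standardized_weight(self):
        w = self.weight
        fan_in = w[0].numel()                                      # N = in_channels * k_h * k_w
        mean = w.mean(dim=(1, 2, 3), keepdim=True)                 # mu_i, per output channel
        var = w.var(dim=(1, 2, 3), keepdim=True, unbiased=False)   # sigma_i^2, per output channel
        return self.gamma * (w - mean) / torch.sqrt(var * fan_in + self.eps)

    def forward(self, x):
        return F.conv2d(x, self.standardized_weight(), self.bias,
                        self.stride, self.padding, self.dilation, self.groups)

class NFBlock(nn.Module):
    """Normalizer-free residual form: h_{i+1} = h_i + alpha * f_i(h_i / beta_i)."""
    def __init__(self, f: nn.Module, alpha: float = 0.2, beta: float = 1.0):
        super().__init__()
        self.f, self.alpha, self.beta = f, alpha, beta

    def forward(self, h):
        return h + self.alpha * self.f(h / self.beta)
```

In the actual NF-ResNets, β_i is computed analytically from the expected variance of the activations entering each block and α controls how quickly that variance grows; the constants above are placeholders for illustration only.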
📌 Fact: "With additional regularization (Dropout and Stochastic Depth), NF-ResNets match the test accuracies achieved by batch normalized pre-activation ResNets on ImageNet at batch size 1024. They also significantly outperform their batch normalized counterparts when the batch size is very small, but they perform worse than batch normalized networks for large batch sizes (4096 or higher)."
📌 Note: "NF-ResNets do not match the performance of state-of-the-art networks like EfficientNets which use Batch Normalization."

ResNet vs. NF-ResNet on CIFAR-10 with Batch Size 32

In the end, NF-ResNet-50 outperforms the good old ResNet-50 by a margin of approximately 15%. (I used Ross Wightman's timm package for both model definitions.)
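For reference, both model definitions can be pulled from timm roughly as in the sketch below (the exact model names may differ across timm versions; timm.list_models("*nf_resnet*") shows what your install provides):

```python
import timm

# Baseline ResNet-50 and its normalizer-free counterpart, each with a 10-class head for CIFAR-10.
resnet50 = timm.create_model("resnet50", pretrained=False, num_classes=10)
nf_resnet50 = timm.create_model("nf_resnet50", pretrained=False, num_classes=10)
```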

Train NF-ResNet on Colab Notebook \rightarrow

Adaptive Gradient Clipping

One problem is that NF-ResNets could not scale to large batch sizes (4096 or higher) during training. The authors of the paper we're concerned with today (High-Performance Large-Scale Image Recognition Without Normalization) hypothesized that gradient clipping should help scale NF-ResNets to larger batches. To this end, they proposed Adaptive Gradient Clipping (AGC).
A standard clipping algorithm clips the gradient before updating the parameter θ such that:
G \rightarrow\left\{\begin{array}{ll} \lambda \frac{G}{\|G\|} & \text { if }\|G\|>\lambda \\ G & \text { otherwise } \end{array}\right.
Here, \lambda (the clipping threshold) is the hyperparameter to be tuned.
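As a quick illustration (a minimal sketch, not code from the paper), standard global-norm clipping looks like this in PyTorch; in practice you can simply call torch.nn.utils.clip_grad_norm_(model.parameters(), clip_threshold):

```python
import torch

def clip_gradients(parameters, clip_threshold: float = 1.0):
    """Standard gradient clipping with a fixed threshold lambda, applied after loss.backward()."""
    grads = [p.grad for p in parameters if p.grad is not None]
    # Global gradient norm ||G|| over all parameters.
    total_norm = torch.norm(torch.stack([g.norm() for g in grads]))
    if total_norm > clip_threshold:
        # Rescale the gradients so that their global norm equals lambda.
        for g in grads:
            g.mul_(clip_threshold / total_norm)
```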
Gradient clipping can help train at a higher learning rate but is quite sensitive to the clipping threshold (evident from the media panel below).

Try out the Colab Notebook \rightarrow

AGC tries to overcome this issue by introducing adaptive clipping instead of "hard" clipping.
📌 Note: "The AGC algorithm is motivated by the observation that the ratio of the norm of the gradients G^l to the norm of the weights W^l of layer l, \frac{\left\|G^{\ell}\right\|_{F}}{\left\|W^{\ell}\right\|_{F}} , provides a simple measure of how much a single gradient descent step will change the original weights W^l."
The authors of this paper use unit-wise ratios of gradient norms to parameter norms instead of layer-wise norm ratios, with \|\cdot\|_F denoting the Frobenius norm. AGC clips the gradients such that:
G_{i}^{\ell} \rightarrow\left\{\begin{array}{ll} \lambda \frac{\left\|W_{i}^{\ell}\right\|_{F}^{\star}}{\left\|G_{i}^{\ell}\right\|_{F}} G_{i}^{\ell} & \text { if } \frac{\left\|G_{i}^{\ell}\right\|_{F}}{\left\|W_{i}^{\ell}\right\|_{F}^{\star}}>\lambda \\ G_{i}^{\ell} & \text { otherwise. } \end{array}\right.
where \lambda is the clipping threshold hyperparameter and ||W_i||_F^* = max(||W_i||_F, ε). A small value of ε = 10^{-3} prevents zero-initialized parameters from always having their gradients clipped to zero.
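Here is a minimal PyTorch sketch of unit-wise AGC, using my own function name and defaults rather than the authors' official JAX implementation (which, per the paper, also leaves the final linear layer unclipped):

```python
import torch

def adaptive_grad_clip(parameters, clip_factor: float = 0.01, eps: float = 1e-3):
    """Sketch of unit-wise Adaptive Gradient Clipping. Call after loss.backward()
    and before optimizer.step()."""
    for p in parameters:
        if p.grad is None:
            continue
        if p.ndim > 1:
            # Unit-wise norms: one norm per output unit (row / filter) of the weight tensor.
            dims = tuple(range(1, p.ndim))
            w_norm = p.detach().norm(dim=dims, keepdim=True)
            g_norm = p.grad.detach().norm(dim=dims, keepdim=True)
        else:
            # Biases and other 1-D parameters: fall back to a single norm.
            w_norm = p.detach().norm()
            g_norm = p.grad.detach().norm()
        # ||W||_F^* = max(||W||_F, eps) keeps zero-initialized weights trainable.
        max_norm = clip_factor * w_norm.clamp(min=eps)
        # Rescale only the units whose gradient-to-weight norm ratio exceeds lambda.
        scale = torch.where(g_norm > max_norm,
                            max_norm / g_norm.clamp(min=1e-6),
                            torch.ones_like(g_norm))
        p.grad.detach().mul_(scale)
```

In a training loop, you would call adaptive_grad_clip(model.parameters(), clip_factor=0.01) between loss.backward() and optimizer.step().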
📌 Fact: "Using AGC, we can train NF-ResNets stably with larger batch sizes (up to 4096), as well as with very strong data augmentations like RandAugment."
In order to understand AGC well and see its effect on model training, I set up a bunch of experiments. The intent of these experiments is not to conclude anything serious but rather to explore; they cover only a tiny fraction of the possible experimental configurations.
Here, AGC stabilizes NF-ResNets in the larger-batch configuration, but a larger batch size does not lead to higher test accuracies. With that in mind, I wanted to investigate the effect of AGC and whether it should be considered a viable alternative to good old gradient clipping. So how does AGC do?

Effect of AGC on Normalizer Free Networks

First, a note about the three models I trained:
📌 Note: The results are from training with a batch size of 1024 and a clipping factor of 0.01.
In my opinion, the clipping factor of 0.01 might be too tight. I encourage readers to experiment with larger clipping factors.

Relation Between Batch Size and Clipping Factor

The clipping factor for regular gradient clipping is sensitive to batch size, model depth, learning rate, etc. I wanted to investigate the relationship between batch size and clipping factor and their correlation with the final test accuracy.
Using a Weights & Biases Sweep, I was able to quickly set up my ablation study.
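For reference, a sweep over batch size and clipping factor can be set up along these lines (a minimal sketch with hypothetical values and a stub train() function, not my exact sweep configuration):

```python
import wandb

# Sweep over batch size and clipping factor, maximizing test accuracy.
sweep_config = {
    "method": "grid",
    "metric": {"name": "test_accuracy", "goal": "maximize"},
    "parameters": {
        "batch_size": {"values": [128, 256, 512, 1024]},
        "clip_factor": {"values": [0.01, 0.04, 0.16, 0.64]},
    },
}

def train():
    # Each agent run reads its hyperparameters from the run config and logs the sweep metric.
    with wandb.init() as run:
        batch_size = run.config.batch_size
        clip_factor = run.config.clip_factor
        # ... build data loaders and the model, train with gradient clipping at clip_factor ...
        run.log({"test_accuracy": 0.0})  # placeholder for the real metric

sweep_id = wandb.sweep(sweep_config, project="agc-ablation")
wandb.agent(sweep_id, function=train)
```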
📌 Note: The experimental results shown below were produced by training a ResNet-20 model on the CIFAR-10 dataset. Check out the linked Kaggle kernel for implementation details.

Try out the Kaggle Kernel \rightarrow

The authors performed many ablation studies of their own to showcase the effectiveness of AGC on NF-ResNets. These are the key findings:

NFNet: The Current SOTA on ImageNet

🎉 SOTA Warning: Our NFNet-F1 model achieves comparable accuracy to an EffNet-B7 while being 8.7X faster to train. Our NFNet-F5 model has similar training latency to EffNet-B7, but achieves a state-of-the-art 86.0% top-1 accuracy on ImageNet.
📌 Note: NF-ResNets and NFNets are two different architectures.
Neural network architecture design depends on the choice of metric to optimize. These metrics can be:
The authors of this paper decided to optimize for training latency on existing accelerators. They explored the model space by manually searching for design changes that improved top-1 accuracy on ImageNet at a given training latency on the device, aiming to do well on both fronts.
The verdict? The authors achieved the state-of-the-art result of 86.5% top-1 test accuracy on ImageNet by training NFNet-F6 with the recently proposed Sharpness Aware Minimization (SAM) technique. The official implementation of the architecture is in the JAX framework. You can find the official GitHub repo here.
Additionally, Yannic Kilcher did an amazing explanation of NFNets in his YouTube video. Check it out here:
I have provided a Colab notebook to train an NFNet model using PyTorch Lightning on the Caltech-101 dataset. The NFNet implementation is based on the timm package. Feel free to change model variants and other hyperparameters.
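If you want to play with the architecture outside the notebook, an NFNet variant can be instantiated from timm roughly as follows (model names depend on your timm version; timm.list_models("*nfnet*") lists the available ones):

```python
import timm
import torch

# Instantiate an NFNet variant with a 101-class head for Caltech-101.
model = timm.create_model("dm_nfnet_f0", pretrained=True, num_classes=101)

x = torch.randn(1, 3, 224, 224)  # dummy image batch
logits = model(x)
print(logits.shape)              # torch.Size([1, 101])
```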

Train NFNet on Colab Notebook using PyTorch Lightning \rightarrow

Conclusion and Acknowledgements

The experiments were conducted on the CIFAR-10 dataset, which might not be enough for a technique to show its full effect, but I limited my experimental setup to Colab Notebooks and Kaggle Kernels for ease of use.
I want to give shoutouts to these works which enabled me to compile this report with various code examples and experimental results:
Finally, congratulations to the authors on this amazing work.
I would also like to thank Sayak Paul and Morgan McGuire for their feedback, which helped me improve this report. Additional thanks to Justin Tenuto for his editorial magic.