Revisiting ResNets: Improved Training and Scaling Strategies

Do training methods matter more than model architectural changes? In this post we take a dive into the ResNet-RS paper and try to answer this question.
Aman Arora

Paper | GitHub | Model checkpoints

Introduction

With over 63,000 citations, ResNets remain at the forefront of Computer Vision (CV) research even today. Most recent CV papers compare their results to ResNets to showcase improvements in accuracy, speed, or both.
As an example, the EfficientNet-B4 architecture, with FLOPs similar to ResNet-50, has been able to improve the top-1 ImageNet accuracy from 76.3% to 83.0%! [Reference]
❓: But, do such improvements on ImageNet top-1 accuracy come from model architectures or improved training and scaling strategies?
This is precisely the question that Bello et al. (2021) try to answer in their recent paper Revisiting ResNets: Improved Training and Scaling Strategies. While doing so, they also introduce a new family of re-scaled ResNet architectures called ResNet-RS!

Why should you read this blog?

Why not just download the original paper, perhaps print it, and read that instead? Here are a few reasons why:
  1. As part of this blog, we introduce the new ResNet-RS architecture to our readers, providing all context and history for concepts that might have been skipped in the original research paper - making this blog post more beginner-friendly.
  2. Links to corresponding papers, explanations, and blogs have been provided for each of the training and regularization strategies mentioned in the paper such as Model EMA, Label Smoothing, Stochastic depth, RandAugment, Dropout, and more!
  3. We showcase how to implement ResNet-RS architecture in PyTorch using TIMM.
👉: ResNet-RS models are also available in TIMM with pre-trained weights from TensorFlow.

Prerequisite

To completely understand the architectural changes proposed in the ResNet-RS architecture, it is recommended that the readers have a good understanding of the ResNet architecture.
Chapter-14 of Deep Learning for Coders with Fastai and PyTorch: AI Applications Without a PhD by Jeremy Howard and Sylvain Gugger is an excellent resource to learn about ResNets!

ResNet-RS

With introductions now out of the way, it is time to get our hands dirty. In this section, we will try to answer some key questions like:
  1. What are the main contributions of the paper?
  2. How is the ResNet-RS architecture different from the traditional ResNet architecture?
  3. What are the improved training strategies?
  4. Which scaling strategy did the authors use and recommend? Is compound scaling strategy from the EfficientNet research paper the most effective?
  5. How much do training & scaling strategies improve performance when compared to model architecture improvements?
  6. Which data augmentations did the authors use during training of the ResNet-RS architecture? And how do these contribute towards ImageNet top-1 performance?
Figure-1: Improving ResNets to state-of-the-art performance. ResNet-RS architecture outperforms EfficientNets on the speed-accuracy Pareto curve with speed-ups ranging from 1.7x-2.7x on TPUs and 2.1x-3.3x on GPUs. ResNet* represents ResNet-200 architecture trained on ImageNet at 256x256 image resolution.
As can be seen from Figure-1 [reference], ResNet-RS architectures can be up to 3.3x faster than EfficientNets on GPUs!
Personally, the way I like to analyze this figure is by looking at the ResNet architecture (green dot). With improved training & regularization strategies (discussed later in this post), the authors were able to improve the top-1 accuracy from 79.0% to 82.2%, without any architectural changes! With additional minor architectural changes (also discussed later in this post), the authors were able to further increase the top-1 accuracy to 83.4%!
❓: So what are these training and architectural changes?
Well, let's look at them next!

Improved training & regularization methods

In this section, we discuss the updated training & regularization methods that led to an increase in the ImageNet top-1 from 79.0% to 82.2% without any architectural changes.
Table-1: Additive study of the ResNet-RS training recipe. The purple, green, and yellow colors above refer to training methods, regularization methods, and architectural improvements.
The training and regularization methods have been mentioned in purple and green respectively in Table-1 above along with their contributions to top-1 accuracy.
✅: One key thing to note from Table-1 is that increasing training epochs on its own actually leads to a -0.5% change in top-1 accuracy; the paper reports that longer training only becomes useful once the regularization methods are in place, otherwise it leads to over-fitting.
Let's now look at each of the training and regularization strategies in detail below.

Regularization Strategies

EMA of weights

What does EMA of weights here mean? EMA stands for Exponential Moving Average.
❓: What is Exponential Moving Average?
Let's say, Adam sells lemonade. On Day-1, he sold 5 lemonades & on Day-2, he sold 3 lemonades.
In that case, the average number of lemonades he sold across the two days is \frac{5+3}{2} = 4.
While, at the same time, if we put a weight of 0.99 on the most recent day, the Exponential Moving Average is 0.99 \times 3 + 0.01 \times 5 = 3.02.
✅: Exponential Moving Average essentially gives more importance to the recent data.
❓: But what does EMA mean in the context of improved regularization strategy for ResNet-RS?
Well, basically, it just means that we maintain an Exponential Moving Average of the model weights during training: the averaged weights give high importance to the most recent model weights while, at the same time, keeping some (exponentially decaying) contribution from the weights at earlier steps.
❓: How do we implement EMA in PyTorch?
Below, I've copied the implementation of EMA from TIMM.
from copy import deepcopy

import torch
import torch.nn as nn


class ModelEMA(nn.Module):
    def __init__(self, model, decay=0.9999, device=None):
        super(ModelEMA, self).__init__()
        # make a copy of the model for accumulating the moving average of weights
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        # perform EMA on a different device from the model if set
        self.device = device
        if self.device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        with torch.no_grad():
            for ema_v, model_v in zip(self.module.state_dict().values(),
                                      model.state_dict().values()):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model):
        # ema_weights = decay * ema_weights + (1 - decay) * model_weights
        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def set(self, model):
        self._update(model, update_fn=lambda e, m: m)
To apply EMA to model weights in PyTorch, we can wrap any given model inside the ModelEMA class shared above and, during training, call ema_model.update(model) to update the EMA of the model weights like so:
train_dataset, eval_dataset = create_datasets()
train_dataloader, eval_dataloader = create_dataloaders()
model = create_model()

# wrap model in the ModelEMA class
ema_model = ModelEMA(model)

for epoch in range(EPOCHS):
    for batch in train_dataloader:
        loss = model(**batch)
        loss.backward()
        # ... optimizer step and zero_grad go here ...
        # apply EMA to model weights after each update
        ema_model.update(model)
👉: I now leave it as an exercise to the reader to map the implementation to the definition of Exponential Moving Average that I shared above. If you have any questions or are not able to understand EMA in the context of regularization strategy in CV, please feel free to reach out to me or comment at the end of this report.

Label Smoothing

❓: What is Label Smoothing?
✅: Label Smoothing Explained using Microsoft Excel is an excellent blog post that explains Label Smoothing using Microsoft Excel! I would refer my dear reader to the blog post for a complete introduction to Label Smoothing.
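👉: If you simply want to use Label Smoothing in your own training loop, recent versions of PyTorch (1.10+) expose it directly on the cross-entropy loss. Below is a minimal sketch; the smoothing value of 0.1 is a commonly used default and is only illustrative here:

import torch
import torch.nn as nn

# cross-entropy with label smoothing (available in PyTorch 1.10+);
# each hard target is mixed with a small uniform distribution over all classes
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

logits = torch.randn(8, 1000)            # a batch of 8 predictions over 1000 classes
targets = torch.randint(0, 1000, (8,))   # the corresponding hard labels
loss = criterion(logits, targets)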

Stochastic Depth

So, as per Table-1, stochastic depth leads to around a +0.2% increase in top-1 accuracy. But what is stochastic depth? And how do you implement it? As part of this section, we try to answer these questions.
❓: What is stochastic depth?
It was introduced in Deep Networks with Stochastic Depth paper by Huang et al. (2016).
Figure-2: The linear decay of survival probability p_l on a ResNet with stochastic depth, with p_0 = 1 and p_L = 0.5.
The idea is presented in Figure-2. At every stage, there is a residual path (in yellow) and the identity path. Stochastic depth randomly drops the entire residual path output based on the "survival" probability of the block. Drop path can be thought of as a gate on the residual branch that is sometimes open and sometimes closed. If the gate is open, it lets the outputs pass and the block behaves the same as the original block in the ResNet research paper.
I would refer the reader to the paper for more in-depth information.
❓: How to implement Stochastic Depth in PyTorch?
Again, below, I've copied the implementation of Stochastic Depth from TIMM.
import torch
import torch.nn as nn


def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # work with tensors of different dimensions, not just 2D ConvNets
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
👉: While I do leave it as an exercise to the reader to map the implementation to the definition of Stochastic Depth, there is one quick question that I'd like to answer below.
❓: Why do we divide the input x by keep_prob in the drop_path function above?
This scales the activations during training (rather than at test time), similar to how dropout is implemented in practice. This "inverted dropout" trick has been explained beautifully here in the CS231n course by Stanford.
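👉: To see where DropPath sits in practice, here is a minimal sketch of a residual block that wraps its residual branch with the DropPath module shown above. The ResidualBlockWithDropPath class, the toy branch module, and the drop probability of 0.1 are all hypothetical and only for illustration:

import torch
import torch.nn as nn

class ResidualBlockWithDropPath(nn.Module):
    def __init__(self, branch, drop_prob=0.1):
        super().__init__()
        self.branch = branch                 # the residual (non-identity) path
        self.drop_path = DropPath(drop_prob)

    def forward(self, x):
        # the residual branch is randomly dropped during training;
        # the identity path is always kept
        return x + self.drop_path(self.branch(x))

# example: a toy residual branch that preserves the input shape
block = ResidualBlockWithDropPath(
    branch=nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU()),
    drop_prob=0.1,
)
out = block(torch.randn(2, 64, 56, 56))      # same shape as the input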

RandAugment

❓: What is RandAugment?
✅: RandAugment was first introduced in RandAugment: Practical automated data augmentation with a reduced search space and has been explained as part of the timmdocs here.
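👉: If you just want to try RandAugment in your own training pipeline, TIMM can build a training transform that applies it. Below is a minimal sketch; the magnitude and magnitude-std encoded in the auto_augment string are illustrative values, not the exact ones used in the paper:

from timm.data import create_transform

# 'rand-m7-mstd0.5' selects RandAugment with magnitude 7 and magnitude std 0.5
train_transform = create_transform(
    input_size=224,
    is_training=True,
    auto_augment='rand-m7-mstd0.5',
)
# train_transform can now be passed to any ImageNet-style dataset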

Dropout

❓: What Is Dropout?
✅: I would refer my dear readers to the Regularization section of the CS231n: Convolutional Neural Networks for Visual Recognition course by Stanford for more information on Dropout.
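👉: For completeness, here is a minimal sketch of dropout applied just before the final classifier layer, which is roughly where image classifiers typically apply it; the dropout rate of 0.25 and the feature/class sizes are illustrative:

import torch
import torch.nn as nn

# dropout randomly zeroes activations during training (and rescales the rest);
# in eval() mode it is a no-op
head = nn.Sequential(
    nn.Dropout(p=0.25),
    nn.Linear(2048, 1000),   # 2048 features out of a ResNet-50/101 backbone
)

logits = head(torch.randn(8, 2048))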

Decrease Weight Decay

One of the last changes that the authors mentioned was to decrease weight decay.
Table-2: Decrease weight decay when using more regularization. Top-1 ImageNet accuracy when combining regularization methods such as dropout (DO), stochastic depth (SD), label smoothing (LS), and RandAugment (RA). Image resolution is 224×224 for ResNet-50 and 256×256 for ResNet-200.
As shown in Table-2 above, the authors found that when using RandAugment and label smoothing, there is no need to change the default weight decay of 1e-4. But, when we further add dropout and/or stochastic depth, the performance can decrease unless weight decay is further decreased. The intuition is that since weight decay acts as a regularizer, its value must be decreased in order to not overly regularize the model when combining many techniques.
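👉: In PyTorch, weight decay is simply an argument to the optimizer, so "decreasing weight decay" is a one-line change. Below is a minimal sketch; the learning rate and momentum values are illustrative:

import timm
import torch

model = timm.create_model('resnet50')   # any model works here; ResNet-50 is just an example

optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-4,   # the default from Table-2; decrease this value when
                         # stacking dropout and/or stochastic depth on top of RA + LS
)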

Training Strategies

The two main training strategies that the authors found to be useful to help increase the top-1 accuracy are -
  1. Use SGDR: Stochastic Gradient Descent with Warm Restarts learning rate schedule
  2. Train for more epochs (350 epochs)

Cosine LR Decay

❓: What is Cosine LR Decay?
✅: I would recommend "Improving the way we work with learning rate" by Vitaly Bushaev to my dear readers as an introduction to SGDR: Stochastic Gradient Descent with Warm Restarts.
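👉: PyTorch ships a cosine decay schedule out of the box. Below is a minimal sketch that decays the learning rate over the full 350-epoch schedule; the stand-in model and the base learning rate are purely illustrative:

import torch
import torch.nn as nn

EPOCHS = 350
model = nn.Linear(10, 2)    # stand-in model, just to make the sketch runnable
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# decay the learning rate from its initial value towards zero along a cosine curve
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

for epoch in range(EPOCHS):
    ...                     # one epoch of training goes here
    scheduler.step()        # step the schedule once per epoch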

ResNet-RS Architecture

So far we have discussed the training and regularization strategies that helped the authors get the top-1 accuracy from 79.0% for a Vanilla ResNet to 82.2% without any architectural changes. But, the authors also introduced small architectural changes that further helped them get the top-1 accuracy up to 83.4%! As part of this section, we are going to be looking at those architectural changes in detail.
Essentially, the authors introduced the ResNet-D modifications and Squeeze-and-Excitation (SE) in all bottleneck blocks.

Squeeze-and-Excitation

The authors added Squeeze-and-Excitation blocks to the vanilla ResNet to further boost the accuracy!
❓: So what do we mean by Squeeze-and-Excitation?
✅: For a complete understanding of what Squeeze-and-excitation Networks are along with code implementation in PyTorch, please refer to Squeeze and Excitation Networks Explained with PyTorch Implementation.
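👉: To give a flavor of what an SE block looks like in code, here is a minimal sketch. The SEBlock class name and the reduction ratio of 16 (the default from the original SE paper) are assumptions for illustration; the blog post linked above covers the full details:

import torch
import torch.nn as nn

class SEBlock(nn.Module):
    # Squeeze: global average pool; Excite: two FC layers + sigmoid; then rescale channels
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        scale = self.fc(self.pool(x).view(b, c)).view(b, c, 1, 1)
        return x * scale    # per-channel re-weighting of the input feature map

se = SEBlock(channels=256)
out = se(torch.randn(2, 256, 56, 56))    # same shape as the input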

ResNet-D

The authors also applied the ResNet-D modifications (described in section 4.1 of the ResNet-RS paper) to the vanilla ResNet network. They are listed below, and a small code sketch of modifications (1) and (3) follows after the architecture diagram:
  1. The 7×7 convolution in the stem is replaced by three smaller 3×3 convolutions, as first proposed in InceptionV3.
  2. The stride sizes are switched for the first two convolutions in the residual path of the downsampling blocks.
  3. The stride-2 1×1 convolution in the skip connection path of the downsampling blocks is replaced by stride-2 2×2 average pooling and then a non-strided 1×1 convolution.
  4. The stride-2 3×3 max pool layer is removed and the downsampling occurs in the first 3×3 convolution in the next bottleneck block.
Based on these modifications, the updated architecture now looks like this:
Figure-3: ResNet-RS Architecture Diagram. Output Size assumes a 224×224 input image resolution. The × symbol refers to how many times the blocks are repeated in the ResNet-101 architecture.
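👉: To make modifications (1) and (3) from the list above concrete, here is a minimal sketch of the ResNet-D stem and downsampling shortcut. The conv_bn_act helper and the channel widths are assumptions made for illustration, roughly following the ResNet-50/101 stem:

import torch.nn as nn

def conv_bn_act(in_ch, out_ch, stride=1):
    # hypothetical helper: 3x3 conv -> BatchNorm -> ReLU
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )

# (1) ResNet-D stem: three 3x3 convolutions instead of a single 7x7 convolution
stem = nn.Sequential(
    conv_bn_act(3, 32, stride=2),
    conv_bn_act(32, 32),
    conv_bn_act(32, 64),
)

# (3) ResNet-D downsampling shortcut: stride-2 2x2 average pooling followed by a
#     non-strided 1x1 convolution, instead of a strided 1x1 convolution
downsample_shortcut = nn.Sequential(
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(64, 256, kernel_size=1, stride=1, bias=False),
    nn.BatchNorm2d(256),
)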
💭: I hope that these architectural changes mentioned make sense to my dear readers because we still have to implement them in PyTorch using TIMM next. Please feel free to reach out to me or comment towards the end of this report should you have any questions.

ResNet-RS in PyTorch

A complete training notebook with code to implement all variants of ResNet-RS models in PyTorch using TIMM and train on the Imagenette dataset has been provided below:

Colab Notebook with ResNet-RS model implementation
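👉: In the meantime, here is a minimal sketch of creating a pre-trained ResNet-RS model directly with TIMM (assuming a TIMM version recent enough to include the resnetrs model family):

import timm
import torch

# other variants such as 'resnetrs101', 'resnetrs152', ... are also available
model = timm.create_model('resnetrs50', pretrained=True)
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)    # torch.Size([1, 1000])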

Conclusion

I hope that as part of this blog, I have been able to introduce ResNet-RS architecture to my dear readers. Please feel free to comment below - would love to get your feedback or answer any questions related to this research paper. Thanks for reading!