As part of this report, we are going to be looking at "Characterizing signal propagation to close the performance gap in unnormalized ResNets" paper in detail.

This report has been divided into two main sections:

- Introduction - High level introduction to the paper's key ideas and contributions
- Code Implementation - In-depth understanding of the paper with code implementation in PyTorch (code referenced from timm)

This report aims to break down the research paper into simple, consumable parts and follows a top-down approach. We start out with a high-level understanding before looking into the details of the paper. This report also aims to explain the research paper from a beginner's perspective and therefore is somewhat longer than the research paper itself in some sections.

In most sections of this report, we reference text directly from the research paper and then paraphrase it to explain some ideas in simpler language, based on my understanding of the paper.

💭: Throughout this report, you will find sidebars like this. These are my own personal comments that might help you along the way. These sidebars could also include some funny bits or failed experimentation stories.

💭: This report is long, longer than an average blog post, but it is complete. I have made a conscious effort to not leave any concept out of this report. IMHO, this report is best read in parts and might require re-reading for readers who are being introduced to NF-ResNets for the first time.

💭: Note that this paper is built on past research, and therefore some prior knowledge about various concepts mentioned in the Prerequisites section below will really help. If you feel like this report on its own wasn't sufficient to help explain this research paper, then I am also open to hosting a paper reading session where we could discuss past research to help fill the gap. Let me know and I'll make it happen :)

💭: My personal aim is to create a report that distills the research paper into a simple consumable format. If some parts of the report don't make sense or are confusing, please feel free to provide constructive feedback towards the end of this report.

To get the best out of this report, I would recommend that the reader has a general understanding of ResNets, Batch Normalization, ReLU activation & Weight Standardization. Here are some resources that might quickly get you up to speed with a general introduction to these:

- The Drawbacks of Batch Normalization and introduction to other Normalization techniques sections in Group Normalization blog post by Aman Arora

💭: I am a bit biased towards the second resource mentioned above. ;)

From the Introduction section in the paper:

BatchNorm has become a core computational primitive in deep learning, and it is used in almost all state-of-the-art image classifiers. A number of different benefits of BatchNorm have been identified. It smoothens the loss landscape, which allows training with larger learning rates, and the noise arising from the mini-batch estimates of the batch statistics introduces implicit regularization. However, BatchNorm also has many disadvantages. Its behavior is strongly dependent on the batchsize, performing poorly when the per device batch size is too small or too large, and it introduces a discrepancy between the behaviour of the model during training and at inference time. A number of alternative normalization layers have been proposed, but typically these alternatives generalize poorly or introduce their own drawbacks, such as added compute costs at inference. Another line of work has sought to eliminate layers which normalize hidden activations entirely.

In this paper, the authors seek to establish a general recipe for training deep ResNets without normalization layers which achieve test accuracies competitive with the state of the art! Batch Normalization (BatchNorm) has been key in advancing deep learning research in computer vision, but in the past few years, a new line of research has emerged that seeks to eliminate layers which normalize activations entirely.

❓: Why do we want to remove BatchNorm? This has been answered in Why do we need Normalizer-Free networks? What's wrong with BatchNorm? section of this report.

This research paper follows this line of research and its key contributions are:

- Signal Propagation Plots: The authors propose a simple set of visualizations which helps practitioners inspect signal propagation at initialization on the forward pass in deep residual networks.
- Scaled Weight Standardization: The authors identify a key failure mode in past unnormalized ResNets with ReLU or Swish activations and Gaussian weights. Because the mean output of these non-linearities is positive, the squared mean of the hidden activations on each channel grows rapidly as the network depth increases. To resolve this, the authors propose Scaled Weight Standardization, which is an extension of Weight Standardization. Essentially, weight standardization normalizes weights in convolution layers, i.e., making the weights have zero mean and unit variance.
- Comparable performance with BatchNorm Counterparts: The authors apply the normalization-free network structure in conjunction with Scaled Weight Standardization to ResNets on ImageNet, where, for the first time, they attain performance which is comparable to or better than batch-normalized ResNets on networks as deep as 288 layers.

❓: What does "mean output of these non-linearities is positive" in the second bullet point above mean? Remember that ReLU is nothing but a max(0, x) operation for some input x. So the output of a ReLU activation is never negative, and for a zero-mean input its mean is strictly positive, i.e., there is a mean shift towards mean(ReLU(x)) > 0. To counter this, the authors introduced Scaled Weight Standardization, which is discussed later in the Scaled Weight Standardization section of the report.
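To see this mean shift numerically, here is a small sketch that feeds zero-mean, unit-variance Gaussian samples through a ReLU and compares the empirical output mean with the known value for a standard Gaussian, 1/sqrt(2π) ≈ 0.399. The sample size and seed are arbitrary choices for illustration.

```python
import math
import torch

torch.manual_seed(0)
z = torch.randn(1_000_000)  # zero-mean, unit-variance Gaussian input
out = torch.relu(z)         # ReLU(z) = max(0, z)

print(f"input mean:  {z.mean().item():+.4f}")
print(f"output mean: {out.mean().item():+.4f}")
# For a standard Gaussian input, the expected ReLU output is 1/sqrt(2*pi)
print(f"analytic:    {1 / math.sqrt(2 * math.pi):+.4f}")
```

The input mean sits near zero while the output mean sits near 0.4. Every ReLU in the network therefore pushes the activations towards a positive mean.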

💭: Until this paper, networks without BatchNorm weren't able to match SOTA performance, so you can imagine why this is pretty exciting! In fact, in their follow-up paper, the authors even manage to attain a new SOTA! But let's not get ahead of ourselves for now.

If you're asking "What's wrong with BatchNorm? I have seen it in almost all networks so far...", then this section provides the answer.

Basically, BatchNorm has some really nice properties but some drawbacks too.

From the paper, some key advantages that have been mentioned are:

A number of benefits of BatchNorm have been identified. It smoothens the loss landscape (Santurkar et al, 2018), which allows training with larger learning rates (Bjorck et al, 2018) and the noise arising from the minibatch estimates of the batch statistics introduces implicit regularization (Luo et al, 2019). It also maintains good signal propagation at initialization in deep residual networks with identity skip connections (De & Smith, 2020).

But BatchNorm also has many disadvantages. Also, from the paper:

Its behavior is strongly dependent on the batchsize, performing poorly when the per device batch size is too small or too large (Hoffer et al., 2017), and it introduces a discrepancy between the behaviour of the model during training and at inference time. BatchNorm also adds memory overhead (Rota Bulo et al., 2018), and is a common source of implementation errors (Pham et al., 2019). In addition, it is often difficult to replicate batch normalized models trained on different hardware.

Therefore, this line of research follows this argument - "If we can find normalizer-free networks, which keep the good properties of BatchNorm and get rid of the disadvantages, then we can potentially train with smaller batch sizes, have faster training and inference times and also reduce the memory overhead!"

We also want the normalizer-free networks to have good signal propagation throughout the network. But is there a way to measure signal propagation? How do we compare the normalizer-free networks with their BatchNormalized counterparts? Enter Signal Propagation Plots.

💭: This is one of my favorite sections of the paper. Basically, signal propagation plots are plots that help measure "signal propagation" inside the network. How? We calculate some statistics at different points inside the network (during a single forward pass) and plot them.

From the paper:

Although papers have recently theoretically analyzed signal propagation in ResNets, practitioners rarely empirically evaluate the scales of the hidden activations at different depths inside a specific deep network when designing new models or proposing modifications to existing architectures. By contrast, we have found that plotting the statistics of the hidden activations at different points inside a network, when conditioned on a batch of either random Gaussian inputs or real training examples, can be extremely beneficial.

The authors found plotting statistics of the hidden activations at different points inside a network to be helpful and termed these plots Signal Propagation Plots.

💭: We ideally want hidden activations to have zero mean and unit variance throughout the network. This is a good measure of "good" signal propagation.

The authors consider 4-dimensional input and output tensors with dimensions denoted by NHWC, where N denotes the batch dimension, C denotes the channels, and H and W denote the two spatial dimensions height & width.

The authors also assume identity residual blocks of the form:

x_{l+1} = f_{l}(x_{l}) + x_{l}

where x_l denotes the input to the l_{th} block and f_l denotes the function computed by the l_{th} residual branch.

😕: Do equations confuse you? This equation represents the Pre-Activation ResNet Block shown in Figure 4(a) of the research paper. Here f(.) represents the function on the residual branch, which consists of BatchNorm, ReLU and Conv operations as in Figure 4(a). This is also referred to as the Residual Block in the coming sections.

Then, to generate Signal Propagation Plots (SPPs), the authors initialize the network with an initialization scheme (He initialization, Glorot initialization, or any other), and provide the network with a batch of input examples sampled from a unit Gaussian distribution.

💭: In simple words, create a network of your choice, initialize it with a suitable initialization scheme and make a forward pass on Gaussian inputs with zero mean and unit variance.

As this input propagates through the network, the authors plot the following statistics at the end of each residual block:

- Average Channel Squared Mean, computed as the square of the mean across the NHW axes, and then averaged across the C axis.
- Average Channel Variance, computed by taking the channel variance across the NHW axes, and then averaging across the C axis.
- Average Channel Variance on the end of the residual branch, before merging with the skip path.
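The three statistics above are just a couple of tensor reductions. Below is a minimal sketch of them in plain PyTorch. The function names are my own choice for illustration. PyTorch stores activations as NCHW rather than the paper's NHWC, so the "NHW axes" become dims 0, 2 and 3. The third statistic is the same average channel variance, measured at the end of the residual branch instead of the end of the block.

```python
import torch


def avg_sq_ch_mean(x: torch.Tensor) -> float:
    """Average Channel Squared Mean: per-channel mean over N, H, W, squared, then averaged over C."""
    return x.mean(dim=(0, 2, 3)).pow(2).mean().item()


def avg_ch_var(x: torch.Tensor) -> float:
    """Average Channel Variance: per-channel variance over N, H, W, then averaged over C."""
    return x.var(dim=(0, 2, 3)).mean().item()


torch.manual_seed(0)
x = torch.randn(8, 64, 32, 32)  # unit Gaussian batch in NCHW layout
print(avg_sq_ch_mean(x), avg_ch_var(x))
```

For a unit Gaussian input, the squared mean is close to 0 and the variance close to 1, which is the "good" signal propagation we hope to keep throughout the network.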

💭: This might be a good time to pause, take a second, maybe even close your eyes and get a feel for SPPs. What do you think they look like? What's your expectation?

Below in Figure 1, we see an example of Signal Propagation Plots that measure signal propagation inside a ResNetV2-600 network. They plot the three statistics mentioned above for two networks: one with BN-ReLU-Conv layer ordering inside the Residual Block and another with ReLU-BN-Conv ordering.

❓: What's the difference between BN-ReLU-Conv and ReLU-BN-Conv ordering? Just swap the positions of the BN and ReLU activation layers in Figure 4(a) of the research paper.

🤔: Do the plots match your expectation? If not, why not?

Figure 1: Signal Propagation Plot for a ResNetV2-600 at initialization with BatchNorm, ReLU activations and He init, in response to a Gaussian input (with zero mean and unit variance) at 512px resolution. Black dots indicate the end of a stage. Blue plots use the BN-ReLU-Conv ordering while red plots use ReLU-BN-Conv.

If you remember from the ResNet paper, each ResNet model is divided into four stages, and the number of blocks per stage depends on the configuration. For ResNetV2-600, there are a total of 200 residual blocks, with each stage having 50 blocks.

💭: This model has been defined in the code provided by the authors here.

Since there are 200 residual blocks in ResNetV2-600, we get Figure 1 by calculating the three statistics for every residual block and plotting them.

💭: Figure 1 really confused me for quite some time. I somehow thought that ResNetV2-600 has 600 residual blocks and was therefore confused by the X-axis of Figure 1, whose maximum value is 200. Having looked at the code, this makes sense: each bottleneck block contains three convolutional layers, so 200 blocks account for roughly 600 layers. I hope this is not as confusing to you as it was for me.

⚠️: Pay extra attention here.

Key pattern from Figure 1 that will really help us when we look at normalizer free neural networks:

- From Figure 1(b), the Average Channel Variance grows linearly with depth within a given stage, and resets at each transition block.
- From Figure 1(a), for BN-ReLU-Conv ordering, the Average Channel Squared Mean displays similar behavior, growing linearly with depth.
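The linear growth in Figure 1(b) can be reproduced with a toy simulation of a BN-ReLU-Conv residual stack, x_{l+1} = x_l + conv(ReLU(BN(x_l))). This is a sketch under simplifying assumptions of my own: the three-convolution bottleneck branch is replaced by a single He-initialized 1x1 convolution, the input is small, and there are no stages or transition blocks. BatchNorm uses batch statistics without learnable affine parameters, as at initialization.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def avg_ch_var(x: torch.Tensor) -> float:
    # per-channel variance over N, H, W, averaged over channels
    return x.var(dim=(0, 2, 3)).mean().item()


def simulate_bn_relu_conv(depth: int = 10, channels: int = 64, seed: int = 0):
    """Toy residual stack x_{l+1} = x_l + conv(relu(BN(x_l))) at initialization."""
    torch.manual_seed(seed)
    x = torch.randn(32, channels, 8, 8)  # unit Gaussian input, NCHW
    variances = [avg_ch_var(x)]
    with torch.no_grad():
        for _ in range(depth):
            conv = nn.Conv2d(channels, channels, kernel_size=1, bias=False)
            nn.init.kaiming_normal_(conv.weight, nonlinearity="relu")  # He init
            h = F.batch_norm(x, None, None, training=True)  # normalize with batch statistics
            x = x + conv(F.relu(h))
            variances.append(avg_ch_var(x))
    return variances


print([round(v, 2) for v in simulate_bn_relu_conv()])
```

Because BatchNorm rescales each block's input to unit variance, every residual branch contributes roughly the same amount of variance regardless of depth, so the variance of the skip path climbs by an approximately constant step per block.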

💭: I direct the curious reader to the SkipInit research paper by the same authors where they replace BatchNorm with a scalar that downscales the signal at the residual branch at initialization.

💭: If we can mimic these patterns in Normalizer-Free ResNets, then we can make sure that the new networks without normalization also train well and are competitive with their normalized counterparts.

Okay, that's it - that's Signal Propagation Plots in theory. Wouldn't it be fun if we could create these plots for ourselves and reproduce them? In fact, we can use one of my favorite libraries - timm - to do just this!

```python
import matplotlib.pyplot as plt
import torchvision
from timm.utils.model import extract_spp_stats, avg_ch_var, avg_ch_var_residual, avg_sq_ch_mean

model = torchvision.models.resnet50()

# The hook locations are glob-style patterns over module names: 'layer?.?' matches the output
# of every residual block, and 'layer?.?.bn3' matches the end of each residual branch.
stats = extract_spp_stats(
    model,
    hook_fn_locs=['layer?.?', 'layer?.?', 'layer?.?.bn3'],
    hook_fns=[avg_sq_ch_mean, avg_ch_var, avg_ch_var_residual],
)

# plot stats
fig, ax = plt.subplots(1, 3, figsize=(18, 3), sharey=True)
for a, key in zip(ax, ['avg_sq_ch_mean', 'avg_ch_var', 'avg_ch_var_residual']):
    a.plot(stats[key], label=key)
    a.legend()
    a.grid()
plt.show()
```

Figure 2: Signal Propagation Plots for ResNet-50 BN-ReLU-Conv ordering using timm

💭: I was lucky enough to have contributed SPPs to timm. For a complete notebook to replicate Figure 1 for a ResNetV2-600, refer here.

💭: You might find this section below a little more complicated than the ones above but it is also the most important as this is where Normalizer-Free ResNets are introduced. Feel free to re-read this section or reach out to me should you have any questions.

With a good solid understanding of BatchNorm normalized networks now, and with SPPs to help our analysis, we are ready to look at variants of ResNet that don't have normalization, but have good signal propagation, are stable during training, and reach test accuracies comparable to their batch-normalized counterparts.

💭: Really, all we needed was a good understanding of BatchNorm normalized networks and SPPs to develop normalizer-free networks. Essentially, the authors create normalizer-free networks whose SPPs mimic the SPPs of normalized ResNets. Why? This is answered in the Q&A section below.

There are two Key Observations that we must remember, and whose effects we must reproduce when designing normalizer-free networks:

- BatchNorm downscales the input to residual block by a factor proportional to the standard deviation of the input signal.
- Each residual block increases the variance of the signal by an approximately constant factor. (There is linear growth in Average Channel Variance)

The authors proposed to mimic these effects by designing new networks of the form:

x_{l+1} = x_{l} + α f_{l}(\frac{x_l}{β_l})

where x_l denotes the input to the l_{th} residual block and f_l(.) denotes the l_{th} residual branch.

💭: Note that this is in contrast to Residual Blocks, where x_{l+1} = f_{l}(x_{l}) + x_{l}.

The authors designed these new normalizer-free networks such that:

⚠️: Pay extra attention here.

- f_l(.), the function computed by the residual branch, is parameterized to be variance preserving at initialization. That is, Var(f_l(x_l)) = Var(x_l).
- β_l is a fixed scalar chosen as \sqrt{Var(x_l)}, the expected standard deviation of x_l at initialization. This ensures that the input to f_l(.) has unit variance.
- α is a scalar hyperparameter which controls the rate of variance growth between blocks.

💭: Yes, I know that that's a lot of information and a lot less explanation. Explanation comes next in the form of Q&A.

💭: In this section I try to answer some questions that were very confusing to me when I first read the paper.

❓ What does it mean when we say that f_l(.) is variance-preserving?

Basically this means that f_l(.) doesn't change the variance of the input. That is,

Var(f_l(x_l)) = Var(x_l)

❓ Why do we want f_l(.) to be variance-preserving?

From the paper:

This constraint enables us to reason about the signal growth in the network, and estimate the variances analytically.

And since we can estimate the variances analytically, we can also compute the values of β_l analytically.

❓ Why should a network without normalization mimic the SPP trend of ResNet?

From OpenReview, this is answered by the authors below:

The choice of how signals should propagate in unnormalized networks is largely one of design. In Appendix G.2, we note that we initially explored designing networks to have constant variance, which without prior knowledge one might assume to be a superior choice. We found that such networks were not as performant, and reasoned that mimicking a signal propagation template which we know to work well was a good design choice, as is supported by our experiments.

❓ Why is β_l chosen to be \sqrt{var(x_l)}?

\sqrt{Var(x_l)} is the standard deviation of the input signal x_l to the l_{th} residual block. By dividing the signal by its standard deviation, we make sure that the input to the residual branch, \frac{x_l}{β_l}, has unit variance, which is desirable for stable training!

❓ What is α?

As we know from the design of the new normalizer-free block:

x_{l+1} = x_{l} + α f_{l}(\frac{x_l}{β_l})

Since β_l is chosen to be \sqrt{Var(x_l)}, \frac{x_l}{β_l} has unit variance. Since f_l(.) is variance-preserving, f_l(\frac{x_l}{β_l}) also has unit variance.

Now, taking the variance of both sides of x_{l+1} = x_{l} + α f_{l}(\frac{x_l}{β_l}), and assuming that the residual branch output is uncorrelated with x_l at initialization (so there is no covariance term), gives us:

Var(x_{l+1}) = Var(x_l) + Var(α f_l(\frac{x_l}{β_l})) = Var(x_l) + α^2 Var(f_l(\frac{x_l}{β_l}))

Therefore,

Var(x_{l+1}) = Var(x_l) + α^2
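We can check Var(x_{l+1}) = Var(x_l) + α^2 empirically, so after L blocks we expect Var(x_L) ≈ 1 + L·α^2. This is a sketch under stated assumptions, not the paper's network:
- Activations are flattened to a (samples, channels) matrix.
- Each residual branch is replaced by a random linear map with weight variance 1/fan-in, which is variance preserving only in expectation and stands in for a real f_l.
- β_l is measured on the fly from the current activations rather than computed analytically.

```python
import torch


def avg_ch_var(x: torch.Tensor) -> float:
    # x has shape (samples, channels): per-channel variance, averaged over channels
    return x.var(dim=0).mean().item()


torch.manual_seed(0)
alpha, depth, channels = 0.5, 10, 64
x = torch.randn(4096, channels)  # x_0 with unit variance
variances = [avg_ch_var(x)]
for _ in range(depth):
    beta = avg_ch_var(x) ** 0.5  # beta_l = sqrt(Var(x_l)), measured empirically
    # stand-in residual branch: random linear map with Var(W) = 1/fan_in,
    # which preserves variance in expectation
    W = torch.randn(channels, channels) / channels ** 0.5
    x = x + alpha * (x / beta) @ W.T  # x_{l+1} = x_l + alpha * f_l(x_l / beta_l)
    variances.append(avg_ch_var(x))

predicted = [1 + l * alpha ** 2 for l in range(depth + 1)]
print([round(v, 2) for v in variances])
print(predicted)
```

The measured variances should track the straight line 1 + l·α^2 closely. The small deviations come from finite sampling and from the branch output not being exactly uncorrelated with x_l.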

💭: It took me two weeks to get to this point and understand that Var(x_{l+1}) = Var(x_l) + α^2, so it's okay if you don't get it the first time. Let me know if you don't get this part! It's important for a complete understanding of NF-ResNets but not so much from an implementation perspective.

❓ What about the two Key Observations? How do normalizer-free ResNets take care of these?

💭: This is based on my understanding of the paper and hasn't been explicitly explained elsewhere.

Well, two things that BatchNorm does, which we want to reproduce without normalization, are:

- Downscale the input to the residual branch
- Increase the variance of the signal by an approximately constant factor for each residual block

Since Var(x_{l+1}) = Var(x_l) + α^2, the scalar hyperparameter α controls the rate of variance growth between blocks: each residual block adds the same amount, α^2, to the variance of the signal. This reproduces the approximately constant per-block increase (the linear growth in Average Channel Variance) that we saw with BatchNorm.

Also, the input to the residual branch is \frac{x_l}{β_l}, i.e., the input downscaled by a factor of β_l, which mimics the downscaling performed by BatchNorm.
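Putting these two pieces together, the β_l values can be computed before training starts. Below is a hypothetical helper, `nf_betas`, that sketches this bookkeeping for a single stage. It assumes each f_l is variance preserving and uncorrelated with x_l, as in the derivation above, and it deliberately ignores transition blocks.

```python
import math


def nf_betas(num_blocks: int, alpha: float, initial_var: float = 1.0):
    """Analytic variance bookkeeping for x_{l+1} = x_l + alpha * f_l(x_l / beta_l).

    Assumes Var(x_{l+1}) = Var(x_l) + alpha**2 and beta_l = sqrt(Var(x_l)).
    Returns the per-block betas and the expected variance after the last block.
    """
    expected_var = initial_var
    betas = []
    for _ in range(num_blocks):
        betas.append(math.sqrt(expected_var))  # beta_l = sqrt(Var(x_l))
        expected_var += alpha ** 2             # each block adds alpha^2 to the variance
    return betas, expected_var


betas, final_var = nf_betas(num_blocks=4, alpha=0.2)
print([round(b, 4) for b in betas], round(final_var, 4))
```

With a small α, the expected variance, and therefore β_l, grows slowly from block to block. That growth is exactly the downscaling factor each residual branch needs.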

💭: I have tried my best to explain all the magic and thinking behind Normalizer Free networks above. If there are parts of this section that didn't make sense, please feel free to let me know towards the end of this report. :)

Ross Wightman has done the good work for us of implementing NF-ResNets already in one of my favorite libraries - timm! Creating NF-ResNets therefore is now as simple as:

```python
import timm
import torch

m = timm.create_model('nf_resnet50')
x = torch.randn(1, 3, 224, 224)
m(x).shape
# >> torch.Size([1, 1000])
```

Above, we simply create an NF-ResNet-50 model using timm and pass a random input to get a classification output. We can also fine-tune this network on our own custom datasets, following the same training script and fine-tuning steps as explained here in the timm documentation.

💭: I absolutely love timm. It is one of the fastest growing libraries and is kept up to date by Ross. Latest research papers make their way to timm really really quickly! I have also been lucky enough to work on the timmdocs project for more in-depth documentation of timm.

💭: If you want to experiment with NF-ResNets in PyTorch, timm is one of the easiest ways to get started. This is also a great resource by Ayush Thakur that uses PyTorch Lightning and timm for some quick experimentation on batch sizes.

Having looked at a quick and easy way to get started with NF-ResNets in code, let's now look at the source code of these networks in timm to understand how one could create these networks from scratch.

💭: We could have also looked at the official JAX implementation by DeepMind, but for simplicity, let's stick to PyTorch and the source code from timm.

Now, to create a Normalizer Free Block, we have to recreate the equation:

x_{l+1} = x_{l} + α f_{l}(\frac{x_l}{β_l})

Let's do that below in PyTorch:

```python
# adapted (simplified) from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/nfnet.py
import torch
import torch.nn as nn
from timm.models.nfnet import DownsampleAvg


class NormFreeBlock(nn.Module):
    """Normalization-Free pre-activation block."""

    def __init__(
            self, in_chs, out_chs=None, stride=1, dilation=1, first_dilation=None,
            alpha=1.0, beta=1.0, bottle_ratio=0.25, group_size=None, skipinit=False,
            act_layer=None, conv_layer=None):
        # act_layer (e.g. nn.ReLU) and conv_layer (e.g. ScaledStdConv2d) must be supplied by the caller
        super().__init__()
        first_dilation = first_dilation or dilation
        out_chs = out_chs or in_chs
        mid_chs = int(in_chs * bottle_ratio)
        groups = 1 if not group_size else mid_chs // group_size
        self.alpha = alpha  # α: controls the rate of variance growth between blocks
        # timm passes beta = 1 / sqrt(expected variance), i.e. the reciprocal of the paper's β_l,
        # so the input is multiplied by it rather than divided
        self.beta = beta

        if in_chs != out_chs or stride != 1 or dilation != first_dilation:
            # transition block: the shortcut is projected (avg-pool + 1x1 conv)
            self.downsample = DownsampleAvg(
                in_chs, out_chs, stride=stride, dilation=dilation,
                first_dilation=first_dilation, conv_layer=conv_layer)
        else:
            self.downsample = None

        self.act1 = act_layer()
        self.conv1 = conv_layer(in_chs, mid_chs, 1)
        self.act2 = act_layer(inplace=True)
        self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
        self.act3 = act_layer()
        # gain of the final conv starts at 0, or at 1 when a separate skipinit scalar is used
        self.conv3 = conv_layer(mid_chs, out_chs, 1, gain_init=1. if skipinit else 0.)
        self.skipinit_gain = nn.Parameter(torch.tensor(0.)) if skipinit else None

    def forward(self, x):
        out = self.act1(x) * self.beta  # pre-activation, scaled by 1/β_l

        # shortcut branch
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(out)

        # residual branch f_l(.)
        out = self.conv1(out)
        out = self.conv2(self.act2(out))
        out = self.conv3(self.act3(out))
        if self.skipinit_gain is not None:
            out = out * self.skipinit_gain

        # x_{l+1} = shortcut + α * f_l(x_l / β_l)
        out = out * self.alpha + shortcut
        return out
```

The above code recreates a Normalizer Free Bottleneck block similar to Pre-Activation ResNet Block as in the image below, but, without BatchNorm:

Figure 3: Residual Blocks for pre-activation ResNets (He et al., 2016a).

💭: I leave it as an exercise to the reader to map Figure 3(a) and 3(b) to the timm implementation of the Normalizer Free Block shared above. The snippet is capable of implementing both the Pre-Activation ResNet Block and the Pre-Activation ResNet Transition Block.
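For readers who want to experiment without timm, here is a hypothetical, dependency-free sketch of the same idea, x_{l+1} = x_l + α f_l(x_l / β_l). The class name `MinimalNFBlock` is my own. It only covers the identity (non-transition) case and uses plain `nn.Conv2d` layers, whereas the paper pairs this structure with Scaled Weight Standardization.

```python
import torch
import torch.nn as nn


class MinimalNFBlock(nn.Module):
    """Hypothetical sketch of an identity NF block: x + alpha * f(x / beta).

    f is a pre-activation bottleneck with no normalization layers:
    ReLU -> 1x1 conv -> ReLU -> 3x3 conv -> ReLU -> 1x1 conv.
    """

    def __init__(self, channels: int, alpha: float, beta: float, bottle_ratio: float = 0.25):
        super().__init__()
        mid = max(1, int(channels * bottle_ratio))
        self.alpha, self.beta = alpha, beta
        self.branch = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(channels, mid, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(mid, mid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(mid, channels, kernel_size=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # identity shortcut; only the residual branch sees the downscaled input x / beta
        return x + self.alpha * self.branch(x / self.beta)


block = MinimalNFBlock(channels=64, alpha=0.2, beta=1.0)
y = block(torch.randn(2, 64, 16, 16))
print(y.shape)  # torch.Size([2, 64, 16, 16])
```

In a full network, each block would receive its own β_l from the variance bookkeeping, and transition blocks would also need a projection on the shortcut, as `DownsampleAvg` provides in timm.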

💭: Creating an NF-ResNet is then simply a case of repeating these blocks as per the various ResNet configurations mentioned here. I leave this as an exercise too, because recreating the networks with the timm source code as a guide is a really good project for understanding NF-ResNets, but I'm happy to dig deeper into this if needed.

💭: All the pieces required for understanding the complete implementation have been shared as part of this report already.

💭: There is still a mystery that's left to be solved, but, it's the last one.

Figure 4: SPPs for three different variants of the ResNetV2-600 network (with ReLU activations). In red, a batch normalized network with ReLU-BN-Conv ordering. In green, a normalizer-free network with He-init and α = 1. In cyan, the same normalizer-free network but with Scaled Weight Standardization.

The authors implemented the NF-ResNet network introduced above and compared the Signal Propagation Plots with a normalized network with ReLU-BN-Conv ordering (one that we've already seen in Figure 1). The NF-ResNet was initialized with He initialization and α = 1. Both SPPs are shown in Figure 4 above.

As can be seen in Figure 4, the two SPPs aren't the same. In fact, there are two unexpected patterns that can be observed:

- For NF-ResNets, the average value of the squared channel mean grows rapidly with depth, achieving large values which exceed the average channel variance (compare Figure 4(a) and Figure 4(b)).
- For NF-ResNets, the scale of the empirical variances on the residual branch is consistently smaller than one (Figure 4(c), in green).

💭: As I've mentioned before, it is considered good signal propagation if the activations have zero mean and unit variance throughout the network. Thus, having an average value of the squared channel mean that grows rapidly with depth, and variances that are consistently smaller than one, represents instability in the network.

To prevent the emergence of a mean shift and to ensure that the residual branch f_l(.) is variance-preserving, the authors proposed Scaled Weight Standardization.

💭: Scaled Weight Standardization is an extension of Weight Standardization. Essentially, weight standardization normalizes weights in convolution layers, i.e., making the weights have zero mean and unit variance. Note that this is different from Batch Normalization (BatchNorm) which standardizes the hidden activations instead. For the curious reader, here is a wonderful video by Yannic Kilcher, that explains Weight Standardization.

Scaled Weight Standardization has been formulated as:

\hat{W}_{i,j} = γ \cdot \frac{W_{i,j} - μ_{W_i}}{σ_{W_i} \sqrt{N}}

where the mean μ and standard deviation σ are computed across the fan-in extent of the convolutional filters, and N is the fan-in. The authors initialize the underlying parameters W from Gaussian weights, while γ is a fixed constant.
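To make the formula concrete, here is a small sketch of Scaled Weight Standardization as a plain function on a 4-D convolution weight. The function name is my own, and the γ value used is the one listed for ReLU in the γ constants of this report. It checks two properties that follow directly from the formula: every standardized filter has zero mean, and its squared norm is exactly γ^2.

```python
import torch


def scaled_weight_standardization(W: torch.Tensor, gamma: float) -> torch.Tensor:
    """Return gamma * (W - mu) / (sigma * sqrt(N)), computed separately for each output filter.

    W has shape (out_channels, in_channels, kH, kW). mu and sigma are the mean and the
    (biased) standard deviation over each filter's fan-in, N = in_channels * kH * kW.
    """
    fan_in = W[0].numel()
    mu = W.mean(dim=(1, 2, 3), keepdim=True)
    sigma = (W - mu).pow(2).mean(dim=(1, 2, 3), keepdim=True).sqrt()
    return gamma * (W - mu) / (sigma * fan_in ** 0.5)


torch.manual_seed(0)
gamma = 1.7139588594436646  # the value listed for ReLU
W = torch.randn(16, 8, 3, 3)  # underlying Gaussian parameters
W_hat = scaled_weight_standardization(W, gamma)

print(W_hat.mean(dim=(1, 2, 3))[:4])        # ~0 for every filter
print(W_hat.pow(2).sum(dim=(1, 2, 3))[:4])  # gamma**2 ≈ 2.94 for every filter
```

Every filter sums to zero, so a mean shared by all of its inputs contributes nothing to the output mean. The positive mean coming out of a ReLU applied to unit-Gaussian activations is such a shared mean, and this is how the mean shift is removed. Because each filter's squared norm is γ^2, uncorrelated inputs with variance 1/γ^2 are mapped to outputs with unit variance.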

💭: Don't worry if this seems confusing; things become clearer with the code implementation shared below. Essentially, we standardize the weights of every convolutional layer by subtracting the mean and dividing by the standard deviation of each filter's weights, and then scale them by γ/\sqrt{N}.

💭: The only difference between Scaled Weight Standardization and Weight Standardization is the introduction of a fixed constant γ which is non-linearity dependent. I refer the curious reader to section 4.2 of the paper that explains why.

When the authors applied Scaled Weight Standardization to the Normalizer-Free ResNet, it eliminated the growth of the average channel squared mean at initialization, as shown in Figure 4. Indeed, the SPPs of the network with Scaled Weight Standardization (cyan) are almost identical to the SPPs of a batch-normalized network employing the ReLU-BN-Conv ordering (red).

💭: I have an implementation of Weight Standardization in PyTorch here.

❓ What is γ?

γ is a fixed constant that depends on the non-linearity g, and takes the following values for various non-linearities. Its value is chosen such that the layer will be variance-preserving.

```python
# from https://github.com/deepmind/deepmind-research/tree/master/nfnets
_nonlin_gamma = dict(
    identity=1.0,
    celu=1.270926833152771,
    elu=1.2716004848480225,
    gelu=1.7015043497085571,
    leaky_relu=1.70590341091156,
    log_sigmoid=1.9193484783172607,
    log_softmax=1.0002083778381348,
    relu=1.7139588594436646,
    relu6=1.7131484746932983,
    selu=1.0008515119552612,
    sigmoid=4.803835391998291,
    silu=1.7881293296813965,
    softsign=2.338853120803833,
    softplus=1.9203323125839233,
    tanh=1.5939117670059204,
)
```
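Values like these can be approximated with a quick Monte Carlo estimate. The sketch below assumes γ = 1/sqrt(Var[g(z)]) for z ~ N(0, 1), which is what makes γ·g(z) unit-variance. The helper `estimate_gamma` is hypothetical, and the sample size and seed are arbitrary. For ReLU, the variance has the closed form 1/2 - 1/(2π), giving γ = sqrt(2 / (1 - 1/π)) ≈ 1.713.

```python
import math
import torch


def estimate_gamma(act, num_samples: int = 1_000_000, seed: int = 0) -> float:
    """Monte Carlo estimate of 1 / sqrt(Var[act(z)]) for z ~ N(0, 1)."""
    torch.manual_seed(seed)
    z = torch.randn(num_samples, dtype=torch.float64)
    return act(z).var().rsqrt().item()


relu_gamma = estimate_gamma(torch.relu)
tanh_gamma = estimate_gamma(torch.tanh)

# Closed form for ReLU: Var[relu(z)] = 1/2 - 1/(2*pi), hence gamma = sqrt(2 / (1 - 1/pi))
relu_closed_form = math.sqrt(2 / (1 - 1 / math.pi))

print(f"relu: {relu_gamma:.4f} (closed form {relu_closed_form:.4f})")
print(f"tanh: {tanh_gamma:.4f}")
```

The estimates land close to, though not bit-for-bit equal to, the listed constants, which are themselves empirical estimates.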

And that's it! Scaled Weight Standardization was the last piece of the puzzle that had to be put into place for normalizer-free networks to get competitive performance with their normalized counterparts!

We reference the implementation from timm again for Scaled Weight Standardization as below:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledStdConv2d(nn.Conv2d):
    """Conv2d layer with Scaled Weight Standardization.

    Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
        - https://arxiv.org/abs/2101.08692
    """

    def __init__(
            self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
            bias=True, gamma=1.0, eps=1e-5, gain_init=1.0):
        if padding is None:
            # symmetric "same"-style padding for an integer kernel_size (what timm's get_padding computes)
            padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
        super().__init__(
            in_channels, out_channels, kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))  # learnable per-filter gain
        self.scale = gamma * self.weight[0].numel() ** -0.5  # gamma * 1 / sqrt(fan-in)
        self.eps = eps

    def get_weight(self):
        # mean and biased standard deviation over each filter's fan-in
        std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
        weight = self.scale * (self.weight - mean) / (std + self.eps)
        return self.gain * weight

    def forward(self, x):
        # convolve with the standardized weights instead of the raw parameters
        return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
```

💭: Again, I leave it as an exercise to match the implementation above with the formula. But of course, feel free to reach out to me should there be any confusion.

I hope that through this report, I've been able to explain NF-ResNets to the reader and also showcase how the networks can be implemented in code. I have spent countless hours over the past few weeks reading and understanding NF-ResNets myself and then tried my best to distill that knowledge into a simple readable report. This report was re-written from scratch multiple times over many iterations until I was happy with a final version. I hope that the effort I put into this report actually helps my dear readers.

I do realize this report is a bit "math heavy", but, I have tried my best to explain the math by breaking the equations into parts. The math was important to best explain the normalizer-free networks.

As is usual for many of my blog posts, feel free to reach out to me on Twitter or provide constructive feedback as comments below in case I have missed anything.

Thanks for reading and happy experimentation!