The Power of Random Features of a CNN

This report presents a number of experiments based on the ideas presented in https://arxiv.org/abs/2003.00152 by Frankle et al. Made by Sayak Paul using W&B.

Introduction

BatchNorm has been a favorite topic in the ML research community ever since it was proposed, and it is still often misunderstood, or at best only partially understood. So far, the community has mostly focused on its normalization component. It's also important to note that a BatchNorm layer has two learnable parameters: a coefficient responsible for scaling and a bias responsible for shifting. Not much work has been done to study the effect of these two parameters systematically.

Earlier this year, Jonathan Frankle and his team published a paper, Training BatchNorm and Only BatchNorm: On the Expressive Power of Random Features in CNNs. They studied how well the scale and shift parameters of the BatchNorm layers adjust themselves when all the other parameters of a CNN are kept at their random initializations. In this report, I am going to present my experiments based on the ideas presented in that paper.

Check out the code on GitHub →

Configuring the experiment

For these experiments, I used the CIFAR10 dataset and a ResNet20-based architecture tailored to the dimensions of the CIFAR10 images. Thanks to the Keras Idiomatic Programmer for the implementation of the ResNet architecture. Additionally, I used the Adam optimizer for all the experiments.

I used Colab TPUs to experiment quickly, but for consistency, I ran the same experiments on a GPU instance and the results were largely identical. The central goal of the experiments is to compare the performance of a CNN in which only the BatchNorm layers are trainable against one in which all the layers are trainable.
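The snippet below is a minimal sketch of how that freezing step can be expressed with tf.keras; the helper name freeze_all_but_batchnorm is my own and not taken from the accompanying codebase.

```python
import tensorflow as tf

def freeze_all_but_batchnorm(model):
    # Keep only the BatchNorm layers trainable; every other layer
    # stays frozen at its random initialization.
    for layer in model.layers:
        layer.trainable = isinstance(layer, tf.keras.layers.BatchNormalization)

    # Adam matches the optimizer used throughout the experiments.
    model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model
```

Compiling after toggling the trainable flags ensures the change actually takes effect before training starts.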

Thanks to the ML-GDE program for the GCP credits which were used to spin up notebook instances and to save some model weights to GCS Buckets.

What's BatchNorm anyway?

Normalization is a common technique applied to input data to stabilize the training of deep neural networks. In general, the ranges of output values of the neurons of a deep network drift apart over the course of training, thereby introducing unstable training behavior. BatchNorm helps mitigate this problem by normalizing the neuron outputs, i.e. by subtracting the mean and dividing by the standard deviation across a mini-batch.

To reintroduce some variance in the outputs and allow a deep network to adapt to variations, BatchNorm is parameterized by a scale and a shift parameter. These two parameters tell a BatchNorm layer how much scaling and shifting to apply, and they are learned as part of the model training process.
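To make this concrete, here is a minimal NumPy sketch of the BatchNorm transform during training; the function name and the eps value are illustrative, not taken from any particular implementation.

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    # x: a mini-batch of activations with shape (batch_size, num_features).
    # Normalize each feature to zero mean and unit variance across the batch.
    mu = x.mean(axis=0)
    var = x.var(axis=0)
    x_hat = (x - mu) / np.sqrt(var + eps)
    # gamma (scale) and beta (shift) are the two learnable parameters
    # that the "Training BatchNorm and Only BatchNorm" paper focuses on.
    return gamma * x_hat + beta
```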


You can find a more concrete overview of BatchNorm here.

Next up, let's dive into a comparison of performance for different flavors of ResNet20.

ResNet20 with all the layers set to trainable


ResNet20 with an LR schedule and all the layers set to trainable

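The report's charts don't spell out the exact schedule, so the snippet below is only an illustrative sketch of how a learning rate schedule can be attached in tf.keras, assuming a simple step decay.

```python
import tensorflow as tf

def step_decay(epoch, lr):
    # Illustrative step decay: cut the learning rate by 10x every 30 epochs.
    # The actual schedule used in the experiments may differ.
    return lr * 0.1 if epoch > 0 and epoch % 30 == 0 else lr

lr_callback = tf.keras.callbacks.LearningRateScheduler(step_decay)
# Passed to model.fit(..., callbacks=[lr_callback]).
```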

The Promise of BatchNorm

In this section, I'll present the results we've been waiting for: what happens if we train only the BatchNorm layers and keep all the other trainable parameters at their random initial values?

Note that in this case, the number of trainable parameters in the network is 4000 as can be seen in this notebook.
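As a quick sanity check, here is one way to arrive at that count with tf.keras; the helper name is mine, while tf.keras.backend.count_params is a standard Keras utility.

```python
import tensorflow as tf

def count_trainable_params(model):
    # Sum the number of scalars across every trainable weight tensor.
    # With only the BatchNorm layers trainable, this counts just the
    # gamma (scale) and beta (shift) vectors.
    return sum(
        tf.keras.backend.count_params(w) for w in model.trainable_weights
    )
```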


What if none of the layers are trainable?


How important is batch size?

This section demonstrates the effect of changing the batch size when we train only the BatchNorm layers.

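Such a sweep can be set up along the lines of the sketch below; the batch sizes and epoch budget are illustrative, get_resnet20() stands in for the ResNet20 builder, and freeze_all_but_batchnorm() is the helper sketched earlier in this report.

```python
import tensorflow as tf

# CIFAR10, as used throughout the report.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

for batch_size in [128, 256, 512, 1024]:
    # get_resnet20() is a placeholder for the ResNet20 builder.
    model = freeze_all_but_batchnorm(get_resnet20())
    model.fit(
        x_train, y_train,
        validation_data=(x_test, y_test),
        epochs=10,  # illustrative epoch budget
        batch_size=batch_size,
    )
```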

Keeping all layers trainable vs only the BatchNorm layers trainable

Taking the comparison to one extreme, the following three plots compare two models (one with all layers trainable vs. one with only the BatchNorm layers trainable) in terms of their training time and training progress -


Are the learned convolutional filters consistent?

Let's validate the performance of the model where just the BatchNorm layers were trained by investigating the learned convolutional filters. Specifically, we want to see whether they learn anything useful.

First we'll take a look at the 10th convolutional filter from the model where all the layers were trained.
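The snippet below is a rough sketch of how such filters can be pulled out and displayed with tf.keras and Matplotlib; selecting the layer by indexing into the model's Conv2D layers is my own convention, not necessarily what the accompanying notebook does.

```python
import matplotlib.pyplot as plt
import tensorflow as tf

def show_conv_filters(model, conv_index=10, num_filters=16):
    # Collect the Conv2D layers and grab the kernel of the chosen one.
    conv_layers = [l for l in model.layers if isinstance(l, tf.keras.layers.Conv2D)]
    kernels = conv_layers[conv_index].get_weights()[0]  # (h, w, in_channels, out_channels)

    # Min-max normalize so the filters are visible as grayscale images.
    kernels = (kernels - kernels.min()) / (kernels.max() - kernels.min() + 1e-8)

    for i in range(min(num_filters, kernels.shape[-1])):
        plt.subplot(4, 4, i + 1)
        plt.imshow(kernels[:, :, 0, i], cmap="gray")
        plt.axis("off")
    plt.show()
```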


Convolutional filters where only the BatchNorm layers were trained

The results are quite promising in this case as well -


Next steps

I invite you, the reader, to take this research forward and explore the following next steps -

Check out the code on GitHub →