Plunging Into Model Pruning in Deep Learning

This report discusses pruning techniques in the context of deep learning. Made by Sayak Paul using Weights & Biases
Sayak Paul

Introduction

This report is the successor to my previous report on Quantization. In this report, we go over the mechanics of model pruning in the context of deep learning. Model pruning is the art of discarding the weights that do not improve a model's performance. Careful pruning enables us to compress and deploy our workhorse neural networks onto mobile phones and other resource-constrained devices.

This report is structured into the following sections:

(The code snippets that we'll be discussing will be based on TensorFlow (2) and the TensorFlow Model Optimization Toolkit).

Run the code in Colab →

Notion of "Non-Significance" in Functions

Neural networks are function approximators. We train them to learn functions that capture the underlying representations of the input data points. The weights and the biases of a neural network are referred to as its (learnable) parameters. Often, the weights are referred to as the coefficients of the function being learned.

Consider the following function -

$f(x) = x + 5x^2$

In the above function, we have two terms on the RHS: $x$ and $x^2$. The coefficients are 1 and 5 respectively. In the following figure, we can see that the behavior of the function does not change much when the first coefficient is nudged.

Here, the coefficients that were nudged in the different variants of the original function can be referred to as non-significant. Discarding those coefficients won't really change the behavior of the function.
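To make this concrete, here is a minimal sketch that compares how far the function drifts when each coefficient is changed by the same amount. The sampling range of [-5, 5] and the size of the nudges are assumptions made purely for illustration.

```python
def f(x, a=1.0, b=5.0):
    """Evaluate f(x) = a*x + b*x^2 (a=1, b=5 gives the original function)."""
    return a * x + b * x ** 2

# Sample points in [-5, 5] (an assumed range, chosen for illustration).
xs = [i / 10 for i in range(-50, 51)]

def max_abs_diff(a, b):
    """Largest absolute deviation from the original function over xs."""
    return max(abs(f(x, a, b) - f(x)) for x in xs)

# Change each coefficient by 1 and measure the worst-case deviation.
diff_first = max_abs_diff(a=0.0, b=5.0)   # discard the coefficient of x
diff_second = max_abs_diff(a=1.0, b=4.0)  # reduce the coefficient of x^2

print(diff_first, diff_second)
```

Over this range, changing the coefficient of $x$ moves the function far less than an equally sized change to the coefficient of $x^2$, which is the sense in which the first coefficient is less significant.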

Extension to Neural Networks

The above concept can be applied to neural networks as well, although it takes a bit more detail to unfold. Consider the weights of a trained network. How could we make sense of the weights that are non-significant? What's the premise here?

For this to be answered, consider the optimization process with gradient descent. Not all the weights are updated using the same gradient magnitudes. The gradients of a given loss function are taken with respect to the weights (and biases). During the optimization process, some of the weights are updated with larger gradient magnitudes (both positive and negative) than the others. These weights are considered to be significant by the optimizer to minimize the training objective. The weights that receive relatively smaller gradients can be considered as non-significant.

After the training is complete, we can inspect the weight magnitudes of a network layer by layer and figure out the weights that are significant. This decision can be made using several heuristics -

If all of these are becoming hard to comprehend, don't worry. In the next section, things will become clearer.

Pruning a Trained Neural Network

Now that we have a fair understanding of what could be called significant weights, we can discuss magnitude-based pruning. In magnitude-based pruning, we consider weight magnitude to be the criterion for pruning. By pruning, what we really mean is zeroing out the non-significant weights. The following code snippet might help in understanding this -

import numpy as np

# k_weights is assumed to be the 2D kernel weight matrix of a trained layer.
# Copy the kernel weights and get ranked indices of the
# column-wise L2 Norms
kernel_weights = np.copy(k_weights)
ind = np.argsort(np.linalg.norm(kernel_weights, axis=0))

# Number of indices to be set to 0
sparsity_percentage = 0.7
cutoff = int(len(ind)*sparsity_percentage)

# The indices in the 2D kernel weight matrix to be set to 0
sparse_cutoff_inds = ind[0:cutoff]
kernel_weights[:,sparse_cutoff_inds] = 0.

(This code snippet comes from here)

Here's a pictorial representation of the transformation that would be happening to the weights after they have been learned -

It can be applied to the biases also. It's important to note that here we consider an entire layer receiving an input of shape (1,2) and containing 3 neurons. It's often advisable to retrain the network after it is pruned to compensate for any drop in its performance. When doing such retraining, it's important to note that the weights that were pruned won't be updated during the retraining.
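The idea that pruned weights stay frozen during retraining can be sketched with a binary mask. This is a minimal illustration, not the snippet's column-wise approach: it prunes individual weights by absolute magnitude, and the weights, gradients, sparsity level, and learning rate are all made-up values.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical kernel of a dense layer: 2 inputs, 3 neurons.
weights = rng.normal(size=(2, 3))

# Build a binary mask that zeroes out the 50% smallest-magnitude weights.
sparsity = 0.5
num_pruned = int(weights.size * sparsity)
flat_order = np.argsort(np.abs(weights), axis=None)  # ascending by |w|
mask = np.ones(weights.size)
mask[flat_order[:num_pruned]] = 0.0
mask = mask.reshape(weights.shape)

pruned = weights * mask

# One masked gradient-descent step with stand-in gradients (hypothetical).
grads = rng.normal(size=weights.shape)
learning_rate = 0.1
updated = (pruned - learning_rate * grads) * mask

print(mask)
print(updated)
```

Multiplying by the mask after the update is what keeps the pruned positions at zero, no matter what gradient they receive.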

Seeing Things in Action

Enough jibber-jabber! Let's see these things in action. To keep things simple, we'll be testing these concepts on the MNIST dataset, but you should be able to extend them to more complex datasets as well. We'll be using a shallow convolutional network (a single convolutional layer followed by max pooling and a dense classifier) with the following topology -

The network has a total of 20,410 trainable parameters. Training this network for 10 epochs gets us a good baseline -

Recipe 1: Take a trained network, prune it with more training

From this point on, you are encouraged to follow along with the Colab Notebook mentioned at the top.

We are going to take the network we trained earlier and prune it from there. We will apply a pruning schedule that will keep the sparsity level constant (to be specified by the developer) throughout the training. The code to express this is as follows:

import tensorflow_model_optimization as tfmot

pruning_schedule = tfmot.sparsity.keras.ConstantSparsity(
    target_sparsity=target_sparsity,
    begin_step=begin_step,
    end_step=end_step,
    frequency=frequency
)

pruned_model = tfmot.sparsity.keras.prune_low_magnitude(
    trained_model, pruning_schedule=pruning_schedule
)
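The schedule above leaves its four arguments unspecified. One plausible way to pick them is sketched below: the step counts are derived from the training setup, and every number here (batch size, epochs, frequency, sparsity) is an assumption, not the configuration used for the experiments in this report.

```python
import math

# Hypothetical training setup; adjust to your own data and schedule.
num_train_samples = 60000   # size of the MNIST training set
batch_size = 128
epochs = 10

steps_per_epoch = math.ceil(num_train_samples / batch_size)

begin_step = 0                        # start pruning right away
end_step = steps_per_epoch * epochs   # stop when training ends
frequency = 100                       # update the masks every 100 steps
target_sparsity = 0.5                 # half of the weights set to zero

print(steps_per_epoch, end_step)
```

Note that training a wrapped model with `fit` also requires passing the `tfmot.sparsity.keras.UpdatePruningStep()` callback, so that the step counter used by the schedule advances.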

A pruned model needs a re-compile before we can begin training it. We compile it in the same way and we print its summary -

pruned_model.compile(loss='sparse_categorical_crossentropy',
                     optimizer='adam',
                     metrics=['accuracy'])

pruned_model.summary()
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_conv2d ( (None, 26, 26, 12)        230       
_________________________________________________________________
prune_low_magnitude_max_pool (None, 13, 13, 12)        1         
_________________________________________________________________
prune_low_magnitude_flatten  (None, 2028)              1         
_________________________________________________________________
prune_low_magnitude_dense (P (None, 10)                40572     
=================================================================
Total params: 40,804
Trainable params: 20,410
Non-trainable params: 20,394
_________________________________________________________________

We see that the number of parameters has changed. This is because tfmot adds a non-trainable mask for each of the weights in the network to denote whether a given weight should be pruned. The masks are either 0 or 1.
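The numbers in the summary can be reproduced with a little arithmetic. The sketch below infers the layer shapes from the summary (a 3x3 convolution with 12 filters on a single-channel 28x28 input) and assumes that, besides the masks, each wrapper carries a scalar step counter and each masked kernel carries a scalar threshold. That breakdown is an assumption that happens to match the reported counts, and it also suggests that only the kernels, not the biases, receive masks.

```python
# Shapes inferred from the summary, not read from the model itself.
conv_kernel = 3 * 3 * 1 * 12   # 108 kernel weights
conv_bias = 12
dense_kernel = 2028 * 10       # 20,280 kernel weights
dense_bias = 10

trainable = conv_kernel + conv_bias + dense_kernel + dense_bias

# Assumption: one mask entry per kernel weight, plus a scalar threshold
# and a scalar step counter for each wrapped layer that has a kernel.
conv_extra = conv_kernel + 1 + 1
dense_extra = dense_kernel + 1 + 1
# Assumption: wrapped layers without weights only carry the step counter.
pool_extra = 1
flatten_extra = 1

non_trainable = conv_extra + dense_extra + pool_extra + flatten_extra

print(trainable, non_trainable, trainable + non_trainable)
```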

Let's train it.

We can see that pruning the model does not hurt the performance. The red lines correspond to the pruning experiments.

Note:

We can also verify whether tfmot reached the target sparsity by writing tests like the following:

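import numpy as np
import tensorflow as tf
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper

def get_sparsity(weights):
    # Fraction of entries that are exactly zero
    return 1.0 - np.count_nonzero(weights) / float(weights.size)

# model: the pruned model, with its pruning wrappers still attached
for layer in model.layers:
    if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
        for weight in layer.layer.get_prunable_weights():
            print(np.allclose(
                target_sparsity, get_sparsity(tf.keras.backend.get_value(weight)),
                rtol=1e-6, atol=1e-6)
            )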

Running it on the pruned model should produce True for all the layers that were pruned.

Recipe 2: Randomly initialize a network, prune it by training from scratch

Everything remains the same in this case, except that instead of starting with an already-trained network, we start with a randomly initialized one.

Performance Evaluation

We will be using the standard zipfile library to compress the models to .zip format. We need to use tfmot.sparsity.keras.strip_pruning when serializing the pruned models. It removes the pruning wrappers that tfmot added to the models. Otherwise, we won't be able to see any compression benefits in the pruned models.

Compressing the regular Keras models remains the same, however.

import os
import tempfile
import zipfile

def get_gzipped_model_size(file):
    _, zipped_file = tempfile.mkstemp('.zip')
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
        f.write(file)

    return os.path.getsize(zipped_file)

file should be a path to an already serialized Keras model (both pruned and regular).
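To see why zeroed-out weights translate into smaller archives, here is a self-contained sketch that compresses a dense and a pruned weight buffer in memory. The random weights, the layer size, and the 70% sparsity level are hypothetical stand-ins for a real serialized model.

```python
import array
import io
import random
import zipfile

def zipped_size(payload: bytes) -> int:
    """Return the size in bytes of a DEFLATE-compressed zip holding payload."""
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, "w", compression=zipfile.ZIP_DEFLATED) as f:
        f.writestr("weights.bin", payload)
    return len(buffer.getvalue())

random.seed(0)
# Hypothetical layer with 20,280 weights drawn from a normal distribution.
dense = [random.gauss(0.0, 1.0) for _ in range(20280)]

# Zero out roughly the 70% smallest-magnitude weights.
cutoff = sorted(abs(w) for w in dense)[int(len(dense) * 0.7)]
sparse = [w if abs(w) >= cutoff else 0.0 for w in dense]

# Serialize both as 32-bit floats and compare the compressed sizes.
dense_size = zipped_size(array.array("f", dense).tobytes())
sparse_size = zipped_size(array.array("f", sparse).tobytes())

print(dense_size, sparse_size)
```

Both buffers have the same uncompressed size. Only the compressed size differs, because long runs of zero bytes are cheap to encode.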

In the figure below, we can see that the compressed models weigh less than the regular Keras model and still yield pretty good performance.

Apart from accuracy, the compression ratio is another widely used metric for measuring the efficacy of a particular pruning algorithm. The compression ratio is the inverse of the fraction of the parameters remaining in a pruned network.
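The definition above can be written as a small helper. The function name and the example numbers are hypothetical and only illustrate the arithmetic.

```python
def compression_ratio(total_params: int, remaining_params: int) -> float:
    """Inverse of the fraction of parameters remaining after pruning."""
    if remaining_params <= 0:
        raise ValueError("at least one parameter must remain")
    return total_params / remaining_params

# Hypothetical example: 20,280 prunable weights at 50% sparsity.
total = 20280
remaining = total // 2
print(compression_ratio(total, remaining))  # 2.0
```

So a network pruned to 50% sparsity has a compression ratio of 2, and one that keeps only a quarter of its parameters has a ratio of 4.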

This flavor of quantization is also known as post-training quantization. So, here's a simple recipe for you to follow for optimizing your models for deployment:

In the next section, we will be going through a few modern recipes for pruning. If you want to pursue the field of model optimization more, these ideas will be worth exploring further.

Some Modern Pruning Recipes

Let's start this section with the following motivating questions:

  1. When we are retraining a pruned network, what if the unpruned weights are initialized to their original initial magnitudes? If you obtained a pruned network from an already-trained network (let's say network A), consider the initial magnitudes of network A.
  2. When performing magnitude-based pruning in transfer learning regimes with a pre-trained network, how are we deciding the significance of the weights?

Of Winning Tickets

The first question has been explored extensively by Frankle et al. in their seminal paper on the Lottery Ticket Hypothesis. After pruning an already-trained network, the subnetworks that have the initialization just described above are referred to as the winning tickets.

Source: Original Paper

As a rationale behind this methodology, you could reason that during the initial training of a network, a particular initialization of the parameters guided the optimization process. Now, the weights that responded well in the optimization landscape (meaning that they traveled further than the other weights) actually end up in the winning tickets. So, for the pruned network to (re)train well, we initialize the surviving weights to their very first initial magnitudes, and the optimization process aligns very nicely with them. Thanks to Yannic Kilcher for this beautiful explanation.

The paper presents a plethora of experiments to support this hypothesis, and it's an absolutely recommended read.
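A single round of this procedure can be sketched in a few lines of NumPy. No real training happens here: the "trained" weights are simulated by adding noise to the initial ones, and the layer shape and the number of surviving weights are assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical stand-ins: the values a layer had at initialization and the
# values it reached after training (simulated with random noise).
initial_weights = rng.normal(scale=0.1, size=(4, 4))
trained_weights = initial_weights + rng.normal(scale=0.5, size=(4, 4))

# Step 1: prune by magnitude of the *trained* weights (keep the largest 4).
keep = 4
threshold = np.sort(np.abs(trained_weights), axis=None)[-keep]
mask = (np.abs(trained_weights) >= threshold).astype(initial_weights.dtype)

# Step 2: reset the surviving weights to their values at initialization.
ticket = initial_weights * mask

print(mask)
print(ticket)
```

The resulting `ticket` is the sparse subnetwork that would then be retrained, with the mask held fixed.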

Of Systematic Exploration of Lottery Ticket Hypothesis

In the original Lottery Ticket Hypothesis paper, Frankle et al. only explored how a pruned network performed if the surviving weights were reinitialized to their very first initial magnitudes before retraining. Just after the Lottery Ticket Hypothesis was presented at ICLR 2019, Zhou et al. published a paper on Deconstructing Lottery Tickets, studying different ways to handle both the weights that survived pruning and those that did not. They also proposed supermasks, which are basically learnable masks.

Source: Original Paper

Generalization of Lottery Ticket Hypothesis

To be able to scale the Lottery Ticket Hypothesis to datasets like ImageNet, Frankle et al. published a paper on Linear Mode Connectivity that is sort of a generalization of the Lottery Ticket Hypothesis. It proposes weight rewinding as a potential way to initialize the surviving weights of a pruned network. Earlier, we were initializing them with their very first initial magnitudes. Weight rewinding instead rewinds the surviving weights to a point somewhat later in the training of the original network. In other words, the surviving weights get initialized to their magnitudes from, say, epoch 5 of the training of the original network.

Source: Original Paper
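The difference between resetting to initialization and rewinding to a later epoch can be sketched as follows. The checkpoints are simulated with a random walk rather than real training, and the helper `rewind`, the layer shape, and the choice of epoch 5 are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(7)

# Hypothetical checkpoints: index 0 is the initialization, index k holds
# the weights after epoch k. A random walk stands in for real training.
checkpoints = [rng.normal(scale=0.1, size=(3, 3))]
for _ in range(10):
    checkpoints.append(checkpoints[-1] + rng.normal(scale=0.2, size=(3, 3)))

final_weights = checkpoints[-1]

# Magnitude-based mask from the fully trained weights (keep the largest 3).
threshold = np.sort(np.abs(final_weights), axis=None)[-3]
mask = (np.abs(final_weights) >= threshold).astype(final_weights.dtype)

def rewind(epoch):
    """Surviving weights take their values from the given epoch."""
    return checkpoints[epoch] * mask

lottery_ticket_init = rewind(0)  # original hypothesis: back to initialization
rewound_init = rewind(5)         # weight rewinding: back to epoch 5

print(rewound_init)
```

In both cases the mask comes from the fully trained network. Only the checkpoint that supplies the surviving values changes.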

Extending this idea, Renda et al. published a paper on Learning Rate Rewinding, which rewinds the learning rate schedule (rather than the weights) while retraining a pruned network. The authors also propose this as an alternative to fine-tuning.

So, these were some exciting ideas evolving primarily around magnitude-based pruning. In the final section, we will see a pruning method that performs better than magnitude-based pruning, especially for transfer learning regimes.

Pruning based on Weight Movements

In their paper on Movement Pruning, Sanh et al. propose an alternative to magnitude-based pruning that is specifically geared towards handling the pruning of pre-trained models for transfer learning tasks.

Magnitude-based pruning is closely tied to the notion of significance that we discussed earlier. Here, significance simply denotes the absolute magnitudes of the weights: the lower the magnitude, the lower the significance. Now, this significance can actually change when we try to do transfer learning with a model pre-trained on a different dataset. The weights that were significant while optimizing for the source dataset might not be significant for the target dataset.

Source: Original Paper

So, during transfer learning the pre-trained weights that move toward zero can be actually considered as non-significant with respect to the target task, and the weights that move further away can be considered as significant. This is where the method derives its name from - movement pruning. Thanks to Yannic again for his excellent explanation.
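The intuition can be illustrated by contrasting the two criteria on a handful of made-up weights. This is a simplified proxy and not the scoring used by Sanh et al., who derive their importance scores during fine-tuning rather than from the two endpoints. The weight values and the helper `top_k_mask` are hypothetical.

```python
import numpy as np

# Hypothetical weights before and after fine-tuning on the target task.
pretrained = np.array([0.90, -0.80, 0.05, -0.10, 0.60, -0.02])
fine_tuned = np.array([0.50, -0.85, 0.40, -0.45, 0.10, -0.01])

# Magnitude criterion: keep the weights that are largest after fine-tuning.
magnitude_score = np.abs(fine_tuned)

# Movement proxy: positive when a weight moved away from zero during
# fine-tuning, negative when it moved toward zero.
movement_score = np.abs(fine_tuned) - np.abs(pretrained)

def top_k_mask(scores, k):
    """Binary mask keeping the k highest-scoring weights."""
    mask = np.zeros_like(scores)
    mask[np.argsort(scores)[-k:]] = 1.0
    return mask

magnitude_mask = top_k_mask(magnitude_score, k=3)
movement_mask = top_k_mask(movement_score, k=3)

print(magnitude_mask)
print(movement_mask)
```

The first weight is still fairly large after fine-tuning, so the magnitude criterion keeps it, but it shrank considerably, so the movement proxy drops it. The third weight shows the opposite pattern: small in absolute terms, yet clearly moving away from zero.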

Conclusion, acknowledgments, and final thoughts

If you stuck around until the end, great! I hope this report gave you a fair idea of what pruning is in the context of deep learning. I would like to acknowledge Raziel and Yunlu (of Google), who provided me with important information about tfmot and some additional thoughts about pruning itself.

Some further ideas I would like to explore in this area are:

At the time of writing this report (June 2020), one of the most recent approaches to pruning was SynFlow. SynFlow does not require any data for pruning a network and it uses Synaptic Saliency Scores to determine the significance of parameters in a network.

I am open to hearing your feedback via Twitter (@RisingSayak).
