“Model ensembles are a pretty much-guaranteed way to gain 2% of accuracy on anything.” - Andrej Karpathy.

I absolutely agree! However, deploying an ensemble of heavyweight models is not always feasible. Sometimes even a single model is so large (GPT-3, for example) that deploying it in resource-constrained environments is not possible. This is why we have been going over some model optimization recipes - Quantization and Pruning. This report is the last one in this series. In this report, we will discuss a compelling model optimization technique - knowledge distillation. I have structured the report into the following sections -

• What is softmax telling us?
• Using the softmax information for teaching - Knowledge distillation
• Loss functions in knowledge distillation
• A few training recipes
• Experimental results
• Conclusion

# What is Softmax Telling us?

When working with a classification problem, it is very typical to use softmax as the last activation unit in your neural network. Why is that? Because a softmax function takes a set of logits and spits out a probability distribution over the discrete classes your network is being trained on. Figure 1 presents an example of this.

-> Figure 1: Predictions of a neural network on an input image. <-

In Figure 1, our imaginary neural network is highly confident that the given image is $1$. However, it also thinks that there is a slight chance it could be $7$ as well. It is thinking quite right, isn’t it? The given image does have subtle seven-ish characteristics. This information would not have been available if we were only dealing with hard one-hot encoded labels like [1, 0] (where 1 and 0 are probabilities of the image being a one and a seven respectively).
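To make this concrete, here is a minimal sketch of softmax in plain Python. The logits are made-up values for a two-class (one vs. seven) problem, not the outputs of any real network, and `softmax` is a small helper written just for this illustration.

```python
import math

def softmax(logits):
    """Turn a list of raw scores (logits) into probabilities that sum to 1."""
    # Subtracting the largest logit keeps exp() numerically stable
    # and does not change the result.
    largest = max(logits)
    exps = [math.exp(z - largest) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for the classes "one" and "seven".
logits = [4.0, 1.0]
probs = softmax(logits)
print([round(p, 3) for p in probs])  # [0.953, 0.047]
```

Unlike the one-hot label [1, 0], the output keeps a small but non-zero probability for the seven class.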

Humans are well equipped to exploit this sort of relativeness. More examples include - a cat-ish dog, a brownish red, a cat-ish tiger, and so on. These are still valid comparisons, as Hinton et al. note in [1] -

An image of a BMW, for example, may only have a minimal chance of being mistaken for a garbage truck, but that mistake is still many times more probable than mistaking it for a carrot.

This very knowledge helps us to generalize exceptionally well out there in the wild.

This thought process helps us to dig deeper into what our models might be thinking about the input data. It should be somewhat consistent with the way we would think about the input data. Figure 1 again establishes this - to our eyes, that image looks like a one, but it has some traits of a seven.

So, what now? An immediate question that may strike the mind - what is the best way for us to use this knowledge in neural networks? Let us find out in the next section.

# Using the Softmax Information for Teaching - Knowledge Distillation

The softmax information is way more useful than plain hard one-hot encoded labels. So, at this stage, we may have access to -

• The training data
• A trained neural network that performs well on the test data

We are now interested in using the output probabilities produced by our trained network.

Consider teaching someone about the English digits with the MNIST dataset. It is highly likely that a student would ask - doesn’t that one look like a seven? If that is the case, it’s definitely good news because your student, for sure, knows what a one and a seven look like. As a teacher, you have been able to transfer your knowledge of English digits to your student. It is possible to extend this idea to neural networks as well.

## High-Level Mechanics of Knowledge Distillation

So, here is the deal at a high level -

• Train a neural network that performs well on your dataset. This network is going to act like a teacher model.
• Use the teacher model to train a student model on the same dataset. The catch here is that the student model should be significantly smaller than the teacher in terms of capacity.

This workflow briefly formulates the idea of knowledge distillation.

Why smaller? Isn’t this what we want? To deploy a lightweight model to production that is performant enough?

## An Image Classification Case Study

Disclaimer: For the sake of brevity and simplicity, I am going to demonstrate the further sections on a computer vision-based example. Note: These ideas are independent of domains.

For an image classification example, we can extend the earlier high-level idea -

• Train a teacher model that performs well on your image dataset. Here the cross-entropy loss would be calculated with respect to the true labels from your dataset.
• Train a smaller student model on the same dataset but use the predictions from the teacher model (the softmax output) as the ground-truth labels. These softmax outputs are referred to as soft-labels. More on this in a moment.

-> Figure 2: A high-level overview of knowledge distillation. <-

Why are we training the student model on soft-labels?

Remember that our student model is smaller than the teacher model in terms of capacity. So, if your dataset is complex enough, the smaller student model may not be well suited to capture the hidden representations required for the training objective. To compensate for this, we train the student model on soft-labels, which provide more meaningful information than the one-hot encoded labels. In a sense, we are training the student model to imitate the teacher model’s outputs by giving it a little bit of exposure to the training dataset.

Hopefully, this provided you with an intuitive understanding of knowledge distillation. In the next section, we will be taking a more detailed look at the student model's training mechanics.

# Loss Functions in Knowledge Distillation

In order to train the student model, we can still use our regular cross-entropy loss between the soft-labels from the teacher and predicted labels from the student. It is highly likely that the student model would be confident about many of the input data points, and it would predict probability distributions like the following -

-> Figure 3: Highly confident predictions. <-

## Extended Softmax

The problem with these weak probabilities (marked in red in Figure 3) is they do not capture desirable information for the student model to learn effectively. For example, it is almost impossible to transfer the knowledge that the image has seven-ish traits if the probability distribution is like [0.99, 0.01].

Hinton et al. address this problem by dividing the raw logits of the teacher model by some temperature ($\tau$) before they get passed to softmax [1] (known as extended softmax or temperature-scaled softmax). That way, the distribution gets more spread across the available class labels. The same temperature is used when training the student model. I have presented this idea in Figure 4.

-> Figure 4: Softened predictions. <-
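The softening effect can be sketched in a few lines of plain Python. The teacher logits below are hypothetical values chosen so that the unscaled prediction is close to [0.99, 0.01], and `softmax_with_temperature` is an illustrative helper, not a library function.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Softmax over logits divided by a temperature (tau)."""
    scaled = [z / temperature for z in logits]
    largest = max(scaled)  # subtracted for numerical stability
    exps = [math.exp(z - largest) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for the classes "one" and "seven".
teacher_logits = [6.0, 1.5]

sharp = softmax_with_temperature(teacher_logits, temperature=1.0)
soft = softmax_with_temperature(teacher_logits, temperature=5.0)

print([round(p, 2) for p in sharp])  # [0.99, 0.01]
print([round(p, 2) for p in soft])   # [0.71, 0.29]
```

With $\tau$ = 5 the seven class receives a much larger share of the probability mass, which is exactly the signal we want the student to see.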

We can write the student model’s modified loss function in the form of this equation -

-> $\mathcal{L}_{C E}^{K D} = -\sum_{i} p_{i} \log s_{i}$, <-

where $p_i$ is the softened probability distribution of the teacher model and $s_i$ is the softened prediction of the student model, computed from the student's logits $z_i$ as $\frac{\exp \left(z_{i} / \tau\right)}{\sum_{j} \exp \left(z_{j} / \tau\right)}$.

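    import tensorflow as tf

    def get_kd_loss(student_logits, teacher_logits,
                    true_labels, temperature,
                    alpha, beta):
        # true_labels, alpha, and beta are not used in this variant
        # Soften the teacher's logits to obtain the soft-labels
        teacher_probs = tf.nn.softmax(teacher_logits / temperature)
        # Cross-entropy between the soft-labels and the student's
        # temperature-scaled predictions
        kd_loss = tf.keras.losses.categorical_crossentropy(
            teacher_probs, student_logits / temperature,
            from_logits=True)

        return kd_loss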


## Incorporating the Hard-Labels With Extended Softmax

In [1] Hinton et al. also explore the idea of using the conventional cross-entropy loss between the true target labels (typically one-hot encoded) and student model’s predictions. This especially helps when the training dataset is small and there isn’t enough signal in the soft-labels for the student model to pick up.

This approach works significantly better when it is combined with the extended softmax and the overall loss function becomes a weighted average between the two -

-> $\mathcal{L} = \frac{ (\alpha * \mathcal{L}_{C E}^{K D} + \beta * \mathcal{L}_{CE})} {(\alpha + \beta)}$ <-

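    import tensorflow as tf

    def get_kd_loss(student_logits, teacher_logits,
                    true_labels, temperature,
                    alpha, beta):
        # Distillation term: soft-labels vs. temperature-scaled student predictions
        teacher_probs = tf.nn.softmax(teacher_logits / temperature)
        kd_loss = tf.keras.losses.categorical_crossentropy(
            teacher_probs, student_logits / temperature,
            from_logits=True)

        # Conventional term: integer-encoded true labels vs. student predictions
        ce_loss = tf.keras.losses.sparse_categorical_crossentropy(
            true_labels, student_logits, from_logits=True)

        # Weighted average of the two terms, as in the equation above
        total_loss = (alpha * kd_loss) + (beta * ce_loss)
        return total_loss / (alpha + beta)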


It’s recommended to weigh $\beta$ considerably smaller than $\alpha$.
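For a single training example, the weighted loss can be worked through by hand. The sketch below is framework-free and uses hypothetical logits; the helper names (`softmax`, `cross_entropy`, `weighted_kd_loss`) are mine, and the values of $\tau$, $\alpha$, and $\beta$ are only illustrative.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    largest = max(scaled)
    exps = [math.exp(z - largest) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(target_probs, predicted_probs):
    """-sum_i p_i * log(s_i) for two probability lists."""
    return -sum(p * math.log(s) for p, s in zip(target_probs, predicted_probs))

def weighted_kd_loss(student_logits, teacher_logits, true_label,
                     temperature, alpha, beta):
    # Distillation term on the softened distributions.
    kd = cross_entropy(softmax(teacher_logits, temperature),
                       softmax(student_logits, temperature))
    # Conventional term on the unscaled student prediction;
    # true_label is the integer index of the correct class.
    ce = -math.log(softmax(student_logits)[true_label])
    return (alpha * kd + beta * ce) / (alpha + beta)

# Hypothetical logits for one image whose true class is index 0.
loss = weighted_kd_loss(student_logits=[2.0, 0.5],
                        teacher_logits=[6.0, 1.5],
                        true_label=0, temperature=5.0,
                        alpha=0.9, beta=0.1)
print(loss)
```

Setting `beta` to 0 recovers the pure distillation loss, and setting `alpha` to 0 recovers the conventional cross-entropy loss.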

## Operating on the Raw Logits

Caruana et al. operate on the raw logits instead of the softmax values [2]. This workflow is as follows -

• This part remains the same -

Train a teacher model that performs well on your image dataset. Here the cross-entropy loss would be calculated with respect to the true labels from your dataset.

• Now, in order to train the student model, the training objective becomes minimizing the mean squared error between the raw logits from the teacher and the student models respectively.

-> $\mathcal{L}_{M S E}^{K D} = \sum_{i}\left|z_{i}^{\theta_{\text{student}}}-z_{i(\text{teacher})}^{\text{fixed}}\right|^{2}$ <-

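    import tensorflow as tf

    mse = tf.keras.losses.MeanSquaredError()

    def mse_kd_loss(teacher_logits, student_logits):
        # Mean squared error between the teacher's (fixed) and the student's raw logits
        return mse(teacher_logits, student_logits)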


One potential disadvantage of using this loss function could be its unconstrained nature. The raw logits can capture noise which a small model may not be able to properly fit. This is why, in order for this loss function to fit well in the distillation regime, the student model needs to be a bit bigger.

Tang et al. explore the idea of interpolating between the two losses - the extended softmax and the MSE loss [3]. Mathematically, it would look like the following -

-> $\mathcal{L}=(1-\alpha) \cdot \mathcal{L}_{M S E}^{K D}+\alpha \cdot \mathcal{L}_{C E}^{K D}$ <-

Empirically, they found that the best performance is achieved when $\alpha$ is equal to 0, that is, when only the MSE term is used (on NLP tasks).
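The interpolation is easy to sketch without any framework. In the snippet below the logits and the cross-entropy value are hypothetical, the function names are mine, and the MSE is averaged over the classes (as `tf.keras.losses.MeanSquaredError` does), which differs from the plain sum in the equation only by a constant factor.

```python
def mse_on_logits(student_logits, teacher_logits):
    """Mean squared error between two equally long lists of raw logits."""
    squared = [(zs - zt) ** 2
               for zs, zt in zip(student_logits, teacher_logits)]
    return sum(squared) / len(squared)

def interpolated_loss(mse_loss, ce_kd_loss, alpha):
    """(1 - alpha) * MSE term + alpha * temperature-scaled cross-entropy term."""
    return (1.0 - alpha) * mse_loss + alpha * ce_kd_loss

# Hypothetical per-example values.
mse_term = mse_on_logits([2.0, 0.5], [6.0, 1.5])
ce_term = 0.7  # placeholder for a precomputed distillation cross-entropy

for alpha in (0.0, 0.5, 1.0):
    print(alpha, interpolated_loss(mse_term, ce_term, alpha))
```

At `alpha = 0.0` the result is the MSE term alone, and at `alpha = 1.0` it is the cross-entropy term alone.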

If you’re feeling a bit overwhelmed at this point, don’t sweat it. Hopefully, with the code, things will start to shine.

# A few Training Recipes

In this section, I will provide you with a few training recipes that you can consider while working with knowledge distillation.

## Using Data Augmentation

This idea is explored in [3] by Tang et al. They demonstrate this idea on NLP datasets, but it is applicable to other domains as well. Data augmentation can help guide the training of the student model, especially when you are dealing with less data. As we typically keep the student model much smaller than the teacher model, the hope is that with more diverse data the student model gets to capture the domain better.
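As a toy illustration for the image setting, the sketch below creates several randomly flipped copies of one image. Horizontal flipping is my own choice of augmentation here, not the technique used in [3], and the image is a tiny made-up grid of pixel values.

```python
import random

def random_horizontal_flip(image, rng, p=0.5):
    """Flip a 2D image (a list of rows) left-to-right with probability p."""
    if rng.random() < p:
        return [list(reversed(row)) for row in image]
    return [list(row) for row in image]

def augmented_copies(image, n_copies, seed=0):
    """Return n_copies randomly augmented versions of the same image."""
    rng = random.Random(seed)  # seeded so the sketch is reproducible
    return [random_horizontal_flip(image, rng) for _ in range(n_copies)]

# A made-up 2x3 "image".
image = [[0, 1, 2],
         [3, 4, 5]]

views = augmented_copies(image, n_copies=4)
for view in views:
    print(view)
```

Each augmented view would then be passed through the teacher to obtain its soft-labels before being shown to the student.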

## Using Labeled and Unlabeled Data to Train the Student Model

In works like Noisy Student Training [4] and SimCLRV2 [5] the authors use additional unlabeled data when training the student model. So, you would use your teacher model to generate the ground-truth distribution on the unlabeled dataset. This helps to increase the generalizability of the model to a great extent. This approach is only feasible when unlabeled data is available in the domain of the dataset you’re dealing with. Sometimes, it may not be the case (healthcare, for example). In [4], Xie et al. explore techniques like data balancing and data filtering in order to mitigate the issues that may arise when incorporating unlabeled data when training the student model.
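The labeling step can be sketched as follows. Everything here is hypothetical: `toy_teacher` is a stand-in for a trained teacher model, and the confidence threshold is a simple stand-in for data filtering, so the exact procedure in [4] may well differ.

```python
def pseudo_label(unlabeled_inputs, teacher_predict, min_confidence=0.0):
    """Pair each unlabeled input with the teacher's predicted distribution.

    Inputs on which the teacher's top probability is below
    min_confidence are dropped.
    """
    labeled = []
    for x in unlabeled_inputs:
        probs = teacher_predict(x)
        if max(probs) >= min_confidence:
            labeled.append((x, probs))
    return labeled

def toy_teacher(x):
    """Hypothetical two-class teacher: maps a score in [0, 1] to probabilities."""
    p = min(max(x, 0.0), 1.0)
    return [p, 1.0 - p]

unlabeled = [0.97, 0.5, 0.02]
student_training_set = pseudo_label(unlabeled, toy_teacher, min_confidence=0.9)
print(student_training_set)
```

The middle input is discarded because the teacher is undecided about it, while the other two are kept together with their soft-labels.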

## Don’t use Label-Smoothing When Training the Teacher Model

Label-smoothing is a technique used to relax the high confidence predictions produced by models. It helps to reduce overfitting, but it is not typically recommended when training the teacher model in a knowledge distillation situation, since the teacher's logits are scaled by some temperature anyway. You can check out this article to know more about label-smoothing.
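For reference, this is what label-smoothing does to a one-hot target. The sketch uses the common formulation in which a fraction `epsilon` of the probability mass is spread uniformly over all classes; the function name and the value of `epsilon` are illustrative.

```python
def smooth_labels(one_hot, epsilon=0.1):
    """Mix a one-hot target with a uniform distribution over the classes."""
    num_classes = len(one_hot)
    return [(1.0 - epsilon) * y + epsilon / num_classes for y in one_hot]

# A one-hot label for a five-class problem.
hard_label = [1.0, 0.0, 0.0, 0.0, 0.0]
smoothed = smooth_labels(hard_label, epsilon=0.1)
print([round(v, 2) for v in smoothed])  # [0.92, 0.02, 0.02, 0.02, 0.02]
```

Note that every wrong class receives the same share, so the smoothed target says nothing about which classes are similar to each other.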

## Using Higher Temperature Values

Hinton et al. recommend using higher temperature values to soften the distributions predicted by the teacher model so that the soft-labels can contain even more information for the student model. This is especially useful when dealing with small datasets. For larger datasets, the information becomes available by means of the number of training examples. Refer to the extended softmax section again if it is not clear why a higher temperature softens the predicted distributions.

We will be exploring these recipes shortly in the next section.

# Experimental Results

Let’s first review the experimental setup. I used the Flowers dataset for my experiments. Unless otherwise specified, I used the following configurations -

• I used MobileNetV2 as the base model for fine-tuning with learning rate set to 1e-5 with Adam as the optimizer.

• I set the temperature ($\tau$) to 5 when using $\mathcal{L}_{C E}^{K D}$ and when using the weighted average of $\mathcal{L}_{C E}^{K D}$ and the traditional cross-entropy loss.

• $\alpha$ = 0.9 and $\beta$ = 0.1 when using a weighted average of $\mathcal{L}_{C E}^{K D}$ and traditional cross-entropy losses.

• For the student model, I followed this shallow architecture -

Layer (type)                 Output Shape              Param #
=================================================================
conv2d (Conv2D)              (None, 222, 222, 64)      1792
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 55, 55, 64)        0
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 53, 53, 128)       73856
_________________________________________________________________
global_average_pooling2d_3 ( (None, 128)               0
_________________________________________________________________
dense_3 (Dense)              (None, 512)               66048
_________________________________________________________________
dense_4 (Dense)              (None, 5)                 2565
=================================================================

• While training the student model, I used Adam as the optimizer with a learning rate of 1e-2.

• While training the student model with data augmentation, I used the weighted average loss with the same default hyperparameters mentioned above.
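The parameter counts in the student model summary above can be reproduced with a little arithmetic. The sketch assumes 3x3 convolution kernels and a 3-channel (RGB) input; neither is stated in the summary, but both are consistent with the reported counts.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    """Weights plus one bias per filter."""
    return (kernel_h * kernel_w * in_channels + 1) * filters

def dense_params(in_units, out_units):
    """Weights plus one bias per output unit."""
    return (in_units + 1) * out_units

# Layer names follow the summary; pooling layers have no parameters.
layer_params = {
    "conv2d": conv2d_params(3, 3, 3, 64),
    "conv2d_1": conv2d_params(3, 3, 64, 128),
    "dense_3": dense_params(128, 512),
    "dense_4": dense_params(512, 5),
}

for name, count in layer_params.items():
    print(name, count)

total_params = sum(layer_params.values())
print("total", total_params)  # total 144261
```

At roughly 144K parameters, the student is tiny compared to the fine-tuned teacher.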


## Baseline Student Model

To make the performance comparisons fair, let's also train the shallow CNN from scratch and observe its performance. Note that in this case, I used Adam as the optimizer with a learning rate of 1e-3.


## The Training Loop

Before we see the results, I wanted to shed some light on the training loop and how I was able to wrap it inside the classic model.fit() call. This is what the training loop looks like -

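    def train_step(self, data):
        images, labels = data
        # The teacher is already trained; it only provides the targets
        teacher_logits = self.trained_teacher(images)

        # Record the student's forward pass so that gradients can be computed
        with tf.GradientTape() as tape:
            student_logits = self.student(images)
            # get_kd_loss stands for any of the losses discussed so far
            loss = get_kd_loss(teacher_logits, student_logits)
        gradients = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(
            zip(gradients, self.student.trainable_variables))

        # train_loss and train_acc are Keras metric objects defined outside the class
        train_loss.update_state(loss)
        train_acc.update_state(labels, tf.nn.softmax(student_logits))
        t_loss, t_acc = train_loss.result(), train_acc.result()
        train_loss.reset_states()
        train_acc.reset_states()
        return {"train_loss": t_loss, "train_accuracy": t_acc}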


The train_step() function should be an easy read if you are already familiar with how to customize a training loop in TensorFlow 2. Notice the get_kd_loss() function. This can be any of the loss functions we have discussed so far. We are using a trained teacher model here, the model we fine-tuned earlier. With this training loop, we can create an entire model that can be trained with a .fit() call.

First, create a class extending tf.keras.Model -

    class Student(tf.keras.Model):
        def __init__(self, trained_teacher, student):
            super(Student, self).__init__()
            self.trained_teacher = trained_teacher
            self.student = student


When you extend the tf.keras.Model class, you can put your custom training logic inside the train_step() function (it's provided by the class). So, in its entirety, the Student class would look like this -

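    class Student(tf.keras.Model):
        def __init__(self, trained_teacher, student):
            super(Student, self).__init__()
            self.trained_teacher = trained_teacher
            self.student = student

        def train_step(self, data):
            images, labels = data
            # The teacher is already trained; it only provides the targets
            teacher_logits = self.trained_teacher(images)

            # Record the student's forward pass so that gradients can be computed
            with tf.GradientTape() as tape:
                student_logits = self.student(images)
                # get_kd_loss stands for any of the losses discussed so far
                loss = get_kd_loss(teacher_logits, student_logits)
            gradients = tape.gradient(loss, self.student.trainable_variables)
            self.optimizer.apply_gradients(
                zip(gradients, self.student.trainable_variables))

            # train_loss and train_acc are Keras metric objects defined outside the class
            train_loss.update_state(loss)
            train_acc.update_state(labels, tf.nn.softmax(student_logits))
            t_loss, t_acc = train_loss.result(), train_acc.result()
            train_loss.reset_states()
            train_acc.reset_states()
            return {"train_loss": t_loss, "train_accuracy": t_acc}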


You can even write a test_step to customize the evaluation behavior of the model. If you are interested in seeing that, along with the train_step() utility, check out this Colab Notebook. Our model can now be trained in the following manner -

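    # Adam with the student learning rate from the experimental setup
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-2)

    student = Student(teacher_model, get_student_model())
    student.compile(optimizer)

    # Validation relies on the custom test_step mentioned above
    student.fit(train_ds,
                validation_data=validation_ds,
                epochs=10)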


One potential advantage of this method is one can easily incorporate other capabilities like distributed training, custom callbacks, mixed precision, and so on.

## Training the Student Model with $\mathcal{L}_{C E}^{K D}$

Upon training our shallow student model with this loss function we get ~74% validation accuracy. We see that the losses start to increase after epoch 8. This suggests stronger regularization might have helped. Also, note that the hyperparameter tuning process has a significant impact here. In my experiments, I did not do rigorous hyperparameter tuning. In order to do faster experimentation, I kept the training schedules short.


## Training the Student Model with $\frac{ (\alpha * \mathcal{L}_{C E}^{K D} + \beta * \mathcal{L}_{CE})} {(\alpha + \beta)}$

Let's now see if incorporating the ground-truth labels in the distillation training objective helps. With $\beta$ = 0.1 and $\alpha$ = 0.1, we get ~71% validation accuracy. The training dynamics again suggest that stronger regularization with a longer training schedule would have helped.


## Training the Student Model with $\mathcal{L}_{M S E}^{K D}$

With the MSE loss, we see that the validation accuracy drops sharply to ~56%. The same kind of loss behavior is present in this setting as well, suggesting the need for regularization.


Note that this loss function is absolutely unconstrained and our shallow student model may not be capable of handling the noise that comes with it. Let's try out with a deeper student model.


## Using Data Augmentation While Training the Student Model

As mentioned earlier, the student models are of smaller capacity than the teacher model. When dealing with less data, data augmentation can be helpful to train the student model. Let's verify.


## Effect of Temperature ($\tau$)

In this experiment, let's study the effect of temperature on the student model. In this setting, I used the same shallow CNN.


Next, I wanted to study whether the choice of base model for fine-tuning had a significant effect on the student model.


Finally, you may be wondering what kind of improvement one could get out of knowledge distillation for production purposes. The table below shows that for us. Without any hyperparameter tuning, we are able to get a decent model that is significantly more lightweight than the other models shown in the table.

The first row corresponds to the default student model trained with the weighted average loss, while the other rows correspond to EfficientNet B0 and MobileNetV2 respectively. Note that I did not include the results I got from using data augmentation while training the student model.

# Conclusion and Further Thoughts

This concludes the report and also the series I have been developing on model optimization. Knowledge distillation is a very promising technique specifically suited for deployment purposes. A very good thing about it is that it can be combined with quantization and pruning pretty seamlessly in order to further reduce the size of your production models without having to compromise on accuracy.

We studied the idea of knowledge distillation in an image classification setting. If you are wondering if it’s extensible to other areas like NLP or even GANs I recommend going over the following resources -

• “Distilling Knowledge from Neural Networks to Build Smaller and Faster Models.” FloydHub Blog, 11 Nov. 2019, https://blog.floydhub.com/knowledge-distillation/.
• Sanh, Victor, et al. “DistilBERT, a Distilled Version of BERT: Smaller, Faster, Cheaper and Lighter.” ArXiv:1910.01108 [Cs], Feb. 2020. arXiv.org, http://arxiv.org/abs/1910.01108.
• Aguinaldo, Angeline, et al. “Compressing GANs Using Knowledge Distillation.” ArXiv:1902.00159 [Cs, Stat], Jan. 2019. arXiv.org, http://arxiv.org/abs/1902.00159.
• In [1] Hinton et al. also demonstrate knowledge distillation on a speech recognition task.

You can also check out some of the CVPR 2020 papers on knowledge distillation here.

Another trend that you would see is having a larger or equivalent student model. This has been very systematically studied in [4].

# References

1. Hinton, Geoffrey, et al. “Distilling the Knowledge in a Neural Network.” ArXiv:1503.02531 [Cs, Stat], Mar. 2015. arXiv.org, http://arxiv.org/abs/1503.02531.
2. Buciluǎ, Cristian, et al. “Model Compression.” Proceedings of the 12th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, Association for Computing Machinery, 2006, pp. 535–541. ACM Digital Library, doi:10.1145/1150402.1150464.
3. Tang, Raphael, et al. “Distilling Task-Specific Knowledge from BERT into Simple Neural Networks.” ArXiv:1903.12136 [Cs], Mar. 2019. arXiv.org, http://arxiv.org/abs/1903.12136.
4. Xie, Qizhe, et al. “Self-Training with Noisy Student Improves ImageNet Classification.” ArXiv:1911.04252 [Cs, Stat], June 2020. arXiv.org, http://arxiv.org/abs/1911.04252.
5. Chen, Ting, et al. “Big Self-Supervised Models Are Strong Semi-Supervised Learners.” ArXiv:2006.10029 [Cs, Stat], June 2020. arXiv.org, http://arxiv.org/abs/2006.10029.

# Acknowledgements

I am grateful to Aakash Kumar Nain for providing valuable feedback on the code.