“Model ensembles are a pretty much guaranteed way to gain 2% of accuracy on anything.” - Andrej Karpathy.

I absolutely agree! However, deploying an ensemble of heavyweight models is not always feasible. Sometimes even a single model can be so large (GPT-3, for example) that deploying it in resource-constrained environments is simply not possible. This is why we have been going over some model optimization recipes - Quantization and Pruning. This report, the last one in the series, discusses another compelling model optimization technique - knowledge distillation. I have structured the report into the following sections -

Check out the code on GitHub →

Run the experiments using the Google Colab Notebook →

What is Softmax Telling Us?

When working with a classification problem, it is very typical to use softmax as the last activation unit in your neural network. Why is that? Because a softmax function takes a set of logits and spits out a probability distribution over the discrete classes your network is being trained on. Figure 1 presents an example of this.

-> Figure 1: Predictions of a neural network on an input image. <-

In Figure 1, our imaginary neural network is highly confident that the given image is a $1$. However, it also thinks there is a slight chance the image could be a $7$. That is quite reasonable, isn’t it? The given image does have subtle seven-ish characteristics. This information would not have been available if we were only dealing with hard one-hot encoded labels like [1, 0] (where 1 and 0 are the probabilities of the image being a one and a seven, respectively).

Humans are well equipped to exploit this kind of relative similarity - think of a cat-ish dog, a brownish red, a cat-ish tiger, and so on. These are still valid comparisons, as Hinton et al. opine in [1] -

An image of a BMW, for example, may only have a minimal chance of being mistaken for a garbage truck, but that mistake is still many times more probable than mistaking it for a carrot.

This very knowledge helps us generalize remarkably well out there in the wild.

This thought process helps us dig deeper into what our models might be thinking about the input data, which should be somewhat consistent with the way we would think about it. Figure 1 establishes this again - to our eyes, that image looks like a one, but it has some traits of a seven.

So, what now? An immediate question that may come to mind is - what is the best way to use this knowledge in neural networks? Let us find out in the next section.

Using the Softmax Information for Teaching - Knowledge Distillation

The softmax information is far more useful than plain hard one-hot encoded labels. At this stage, we have access to a trained network, and we are now interested in using the output probabilities it produces.

Consider teaching someone the English digits with the MNIST dataset. It is highly likely that a student would ask - doesn’t that one look like a seven? If so, that is definitely good news, because your student certainly knows what a one and a seven look like. As a teacher, you have been able to transfer your knowledge of English digits to your student. It is possible to extend this idea to neural networks as well.

High-Level Mechanics of Knowledge Distillation

So, here is the deal at a high level - you first train a large, heavyweight teacher model on your dataset, and you then train a smaller student model to match the output distribution produced by the teacher rather than only the hard labels. This workflow briefly formulates the idea of knowledge distillation.

Why smaller? Isn’t this what we want - to deploy a lightweight model to production that is performant enough?

An Image Classification Case Study

Disclaimer: For the sake of brevity and simplicity, I will demonstrate the following sections with a computer vision example. Note that these ideas are independent of the domain.

For an image classification example, we can extend the earlier high-level idea -

-> Figure 2: A high-level overview of knowledge distillation. <-

Why are we training the student model on soft-labels?

Remember that our student model is smaller than the teacher model in terms of capacity. So, if your dataset is complex enough, then the smaller student model may not be well suited to capture the hidden representations required for the training objective. We train the student model on soft-labels to compensate for this, which provides more meaningful information than the one-hot encoded labels. In a sense, we are training the student model to imitate the teacher model’s outputs by giving a little bit of exposure to the training dataset.

Hopefully, this provided you with an intuitive understanding of knowledge distillation. In the next section, we will be taking a more detailed look at the student model's training mechanics.

Loss Functions in Knowledge Distillation

In order to train the student model, we can still use the regular cross-entropy loss between the soft-labels from the teacher and the predicted labels from the student. However, a well-trained teacher model is highly likely to be very confident about many of the input data points and to predict probability distributions like the following -

-> Figure 3: Highly confident predictions. <-

Extended Softmax

The problem with these weak probabilities (marked in red in Figure 3) is that they do not capture the information the student model needs to learn effectively. For example, it is almost impossible to transfer the knowledge that the image has seven-ish traits if the probability distribution is like [0.99, 0.01].

Hinton et al. address this problem by scaling the raw logits of the teacher model by some temperature ($\tau$) before they get passed to softmax [1] (known as extended softmax or temperature-scaled softmax). That way, the distribution gets spread more evenly across the available class labels. The same temperature is then used to train the student model. I have presented this idea in Figure 4.

-> Figure 4: Softened predictions. <-
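
To make the effect of the temperature concrete, here is a tiny sketch (the logits below are made up purely for illustration) -

import tensorflow as tf

# Made-up logits for an image that is a "1" with some seven-ish traits.
logits = tf.constant([8.0, 2.0])

print(tf.nn.softmax(logits).numpy())        # ~[0.998, 0.002] - very peaky
print(tf.nn.softmax(logits / 5.0).numpy())  # ~[0.769, 0.231] - much softer

The higher the temperature, the flatter the distribution - and the more of those relative, seven-ish traits survive in the soft-labels.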

We can write the student model’s modified loss function in the form of this equation -

-> $\mathcal{L}_{CE}^{KD} = -\sum_{i} p_{i} \log s_{i}$, <-

where $p_i$ is the softened probability distribution of the teacher model and $s_i$ is the student model’s temperature-scaled softmax output - $s_{i} = \frac{\exp \left(z_{i} / \tau\right)}{\sum_{j} \exp \left(z_{j} / \tau\right)}$, with $z_i$ denoting the student model’s raw logits.

def get_kd_loss(student_logits, teacher_logits, 
                true_labels, temperature,
                alpha, beta):
    # Soften the teacher's predictions with the temperature.
    teacher_probs = tf.nn.softmax(teacher_logits / temperature)
    # Cross-entropy between the softened teacher distribution and the
    # temperature-scaled student logits. (true_labels, alpha, and beta are
    # unused in this variant but kept for a consistent signature.)
    kd_loss = tf.keras.losses.categorical_crossentropy(
        teacher_probs, student_logits / temperature, 
        from_logits=True)
    
    return kd_loss

Incorporating the Hard-Labels With Extended Softmax

In [1], Hinton et al. also explore the idea of using the conventional cross-entropy loss between the true target labels (typically one-hot encoded) and the student model’s predictions. This especially helps when the training dataset is small and there isn’t enough signal in the soft-labels for the student model to pick up.

This approach works significantly better when it is combined with the extended softmax, and the overall loss function becomes a weighted average of the two -

-> $\mathcal{L} = \frac{\alpha \cdot \mathcal{L}_{CE}^{KD} + \beta \cdot \mathcal{L}_{CE}}{\alpha + \beta}$ <-

def get_kd_loss(student_logits, teacher_logits, 
                true_labels, temperature,
                alpha, beta):
    # Distillation term: softened teacher distribution vs. the
    # temperature-scaled student logits.
    teacher_probs = tf.nn.softmax(teacher_logits / temperature)
    kd_loss = tf.keras.losses.categorical_crossentropy(
        teacher_probs, student_logits / temperature, 
        from_logits=True)
    
    # Hard-label term: regular cross-entropy against the ground-truth labels.
    ce_loss = tf.keras.losses.sparse_categorical_crossentropy(
        true_labels, student_logits, from_logits=True)
    
    # Weighted average of the two terms.
    total_loss = (alpha * kd_loss) + (beta * ce_loss)
    return total_loss / (alpha + beta)

It is recommended to weight $\beta$ considerably lower than $\alpha$.

Operating on the Raw Logits

Caruana and his collaborators operate on the raw logits instead of the softmax values [2]. The student is trained to regress the (fixed) logits of the teacher with a mean squared error loss -

-> $\mathcal{L}_{MSE}^{KD} = \sum_{i}\left\| z_{i}^{\text{student}} - z_{i}^{\text{teacher}} \right\|^{2}$ <-

mse = tf.keras.losses.MeanSquaredError()

def mse_kd_loss(teacher_logits, student_logits):
    # Regress the student's raw logits onto the teacher's raw logits.
    return mse(teacher_logits, student_logits)

One potential disadvantage of this loss function is its unconstrained nature - the raw logits can capture noise that a small model may not be able to fit properly. This is why, for this loss function to work well in the distillation regime, the student model needs to be a bit bigger.

Tang et al. explore the idea of interpolating between the two losses - the extended softmax and the MSE loss [3]. Mathematically, it would look like the following -

-> $\mathcal{L} = (1-\alpha) \cdot \mathcal{L}_{MSE}^{KD} + \alpha \cdot \mathcal{L}_{CE}^{KD}$ <-

Empirically, they found that the best performance (on NLP tasks) is achieved when $\alpha$ is equal to 0.
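
Here is a rough sketch of what this interpolation could look like in code, following the conventions of the earlier loss functions (the function name and the exact reductions are mine, not from [3]) -

def interpolated_kd_loss(student_logits, teacher_logits, temperature, alpha):
    # MSE between the raw logits of the teacher and the student.
    mse_loss = tf.keras.losses.MeanSquaredError()(teacher_logits, student_logits)
    # Extended softmax cross-entropy, averaged over the batch.
    teacher_probs = tf.nn.softmax(teacher_logits / temperature)
    ce_loss = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(
        teacher_probs, student_logits / temperature, from_logits=True))
    # Interpolate between the two; alpha = 0 recovers the pure MSE loss.
    return (1.0 - alpha) * mse_loss + alpha * ce_loss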

If you’re feeling a bit overwhelmed at this point, don’t sweat it. Hopefully, things will become clearer with the code.

A Few Training Recipes

In this section, I will provide you with a few training recipes that you can consider while working with knowledge distillation.

Using Data Augmentation

This idea is explored by Tang et al. in [3]. They demonstrate it on NLP datasets, but it is applicable to other domains as well. Using data augmentation can help to better guide the student model’s training, especially when you are dealing with less data. Since we typically keep the student model much smaller than the teacher model, the hope is that, with more diverse data, the student model gets to capture the domain better.
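
Here is a minimal sketch of how light augmentation could be applied to the training pipeline with tf.image; the specific augmentations are my own picks for illustration, not necessarily the ones used in the experiments -

def augment(image, label):
    # Light, label-preserving augmentations.
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.1)
    return image, label

augmented_train_ds = train_ds.map(
    augment, num_parallel_calls=tf.data.experimental.AUTOTUNE)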

Using Labeled and Unlabeled Data to Train the Student Model

In works like Noisy Student Training [4] and SimCLRv2 [5], the authors use additional unlabeled data when training the student model. You use your teacher model to generate the ground-truth distribution on the unlabeled dataset, which helps to increase the generalizability of the student to a great extent. This approach is only feasible when unlabeled data is available in the domain of the dataset you’re dealing with, which may not always be the case (healthcare, for example). In [4], Xie et al. explore techniques like data balancing and data filtering to mitigate the issues that may arise when incorporating unlabeled data into the student model’s training.
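
As a rough sketch of the idea, the trained teacher can be used to produce soft pseudo-labels for an unlabeled dataset before mixing it into the student’s training data. Here, unlabeled_ds is a hypothetical tf.data.Dataset yielding image batches only -

def pseudo_label(images):
    # Soft pseudo-labels from the (frozen) teacher model.
    teacher_probs = tf.nn.softmax(teacher_model(images, training=False))
    return images, teacher_probs

pseudo_labeled_ds = unlabeled_ds.map(pseudo_label)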

Don’t use Label-Smoothing When Training the Teacher Model

Label-smoothing is a technique used to relax the high-confidence predictions produced by models. It helps to reduce overfitting, but since the teacher’s logits are anyway scaled by a temperature during distillation, label-smoothing is typically not recommended when training the teacher model. You can check out this article to learn more about label-smoothing.
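
In Keras, label-smoothing is usually switched on through the loss; the point of this recipe is simply to leave it at its default of 0 when training the teacher -

# Keep label_smoothing at its default (0.0) for the teacher; a non-zero value
# would already flatten the teacher's distribution before temperature scaling.
teacher_loss = tf.keras.losses.CategoricalCrossentropy(
    from_logits=True, label_smoothing=0.0)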

Using Higher Temperature Values

Hinton et al. recommend using higher temperature values to soften the distributions predicted by the teacher model, so that the soft-labels contain even more information for the student model. This is especially useful when dealing with small datasets; for larger datasets, the information becomes available by means of the sheer number of training examples. Refer to the extended softmax section again if it is not clear why a higher temperature softens the predicted distributions.

We will be exploring these recipes shortly in the next section.

Experimental Results

Let’s first review the experimental setup. I used the Flowers dataset for my experiments. Unless otherwise specified, I used the following configurations -

Here’s the Colab Notebook you can follow along with →

Baseline Student Model

To make the performance comparisons fair, let's also train the shallow CNN from scratch and observe its performance. Note that in this case, I used Adam as the optimizer with a learning rate of 1e-3.
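
For reference, here is a rough sketch of what such a shallow CNN student could look like; the exact get_student_model() used in the experiments is defined in the Colab Notebook, so the layer sizes and the 224x224 input resolution below are assumptions -

def get_student_model(num_classes=5):  # the Flowers dataset has 5 classes
    return tf.keras.Sequential([
        tf.keras.layers.Conv2D(16, 3, activation="relu",
                               input_shape=(224, 224, 3)),  # assumed input size
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(32, 3, activation="relu"),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(num_classes)  # raw logits, no softmax
    ])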

The Training Loop

Before we see the results, I wanted to shed some light on the training loop and how I was able to wrap it inside the classic model.fit() call. This is what the training loop looks like -

def train_step(self, data):
    images, labels = data
    # Run the (frozen) teacher in inference mode to get its logits.
    teacher_logits = self.trained_teacher(images, training=False)

    with tf.GradientTape() as tape:
        student_logits = self.student(images)
        # get_kd_loss() can be any of the loss functions discussed above; here
        # I assume the temperature, alpha, and beta hyperparameters are
        # defined globally.
        loss = get_kd_loss(student_logits, teacher_logits,
                           labels, temperature, alpha, beta)
    gradients = tape.gradient(loss, self.student.trainable_variables)
    self.optimizer.apply_gradients(zip(gradients, self.student.trainable_variables))

    # train_loss and train_acc are tf.keras.metrics objects defined outside
    # this snippet.
    train_loss.update_state(loss)
    train_acc.update_state(labels, tf.nn.softmax(student_logits))
    t_loss, t_acc = train_loss.result(), train_acc.result()
    train_loss.reset_states()
    train_acc.reset_states()
    return {"loss": t_loss, "accuracy": t_acc}

The train_step() function should be an easy read if you are already familiar with how to customize a training loop in TensorFlow 2. Notice the get_kd_loss() function - this can be any of the loss functions we have discussed so far. We are using a trained teacher model here, the model we fine-tuned earlier. With this training loop, we can create an entire model that can be trained with a .fit() call.

First, create a class extending tf.keras.Model -

class Student(tf.keras.Model):
    def __init__(self, trained_teacher, student):
        super(Student, self).__init__()
        self.trained_teacher = trained_teacher
        self.student = student

When you extend the tf.keras.Model class, you can put your custom training logic inside the train_step() method (which is provided by the class). So, in its entirety, the Student class would look like this -

class Student(tf.keras.Model):
    def __init__(self, trained_teacher, student):
        super(Student, self).__init__()
        self.trained_teacher = trained_teacher
        self.student = student

    def train_step(self, data):
        images, labels = data
        # Run the (frozen) teacher in inference mode to get its logits.
        teacher_logits = self.trained_teacher(images, training=False)

        with tf.GradientTape() as tape:
            student_logits = self.student(images)
            # temperature, alpha, and beta are assumed to be defined globally.
            loss = get_kd_loss(student_logits, teacher_logits,
                               labels, temperature, alpha, beta)
        gradients = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.student.trainable_variables))

        train_loss.update_state(loss)
        train_acc.update_state(labels, tf.nn.softmax(student_logits))
        t_loss, t_acc = train_loss.result(), train_acc.result()
        train_loss.reset_states()
        train_acc.reset_states()
        return {"train_loss": t_loss, "train_accuracy": t_acc}

You can even write a test_step() to customize the evaluation behavior of the model; if you are interested in that, along with the complete train_step() utility, check out this Colab Notebook. Our model can now be trained in the following manner -

student = Student(teacher_model, get_student_model())
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
student.compile(optimizer)

student.fit(train_ds, 
            validation_data=validation_ds,
            epochs=10)
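
As mentioned above, you can also override test_step() inside the Student class to customize evaluation; the exact version is in the Colab Notebook, but a minimal sketch that evaluates the student on the hard labels could look like this -

    def test_step(self, data):
        images, labels = data
        student_logits = self.student(images, training=False)
        # Evaluate against the ground-truth labels.
        loss = tf.reduce_mean(tf.keras.losses.sparse_categorical_crossentropy(
            labels, student_logits, from_logits=True))
        acc = tf.reduce_mean(tf.keras.metrics.sparse_categorical_accuracy(
            labels, tf.nn.softmax(student_logits)))
        return {"loss": loss, "accuracy": acc}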

One potential advantage of this method is that one can easily incorporate other capabilities like distributed training, custom callbacks, mixed precision, and so on.

Training the Student Model with $\mathcal{L}_{C E}^{K D}$

Upon training our shallow student model with this loss function we get ~74% validation accuracy. We see that the losses start to increase after epoch 8. This suggests stronger regularization might have helped. Also, note that the hyperparameter tuning process has a significant impact here. In my experiments, I did not do rigorous hyperparameter tuning. In order to do faster experimentation, I kept the training schedules short.

Training the Student Model with $\frac{\alpha \cdot \mathcal{L}_{CE}^{KD} + \beta \cdot \mathcal{L}_{CE}}{\alpha + \beta}$

Let's now see if incorporating the ground-truth labels in the distillation training objective helps. With $\beta$ = 0.1 and $\alpha$ = 0.1, we get ~71% validation accuracy. The training dynamics again suggest that stronger regularization with a longer training schedule would have helped.

Training the Student Model with $\mathcal{L}_{M S E}^{K D}$

With the MSE loss, the validation accuracy drops sharply to ~56%. The same kind of loss behavior is present in this setting as well, again suggesting the need for regularization.

Note that this loss function is completely unconstrained, and our shallow student model may not be capable of handling the noise that comes with it. Let's try a deeper student model.

Using Data Augmentation While Training the Student Model

As mentioned earlier, the student models are of smaller capacity than the teacher model. When dealing with less data, data augmentation can be helpful to train the student model. Let's verify.

Effect of Temperature ($\tau$)

In this experiment, let's study the effect of temperature on the student model. In this setting, I used the same shallow CNN.
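
The sweep itself is straightforward; here is a rough sketch of how it could be run (the temperature values are arbitrary, and the loss hyperparameters are assumed to be global, as before) -

for temperature in [1, 5, 10]:
    # Train a fresh student per temperature value.
    student = Student(teacher_model, get_student_model())
    student.compile(tf.keras.optimizers.Adam(learning_rate=0.01))
    student.fit(train_ds, validation_data=validation_ds, epochs=10)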

Finally, I wanted to study whether the choice of the base model used for fine-tuning the teacher has a significant effect on the student model.

Lastly, if you are wondering what kind of improvement one could get out of knowledge distillation for production purposes, the table below shows that for us. Without any hyperparameter tuning, we are able to get a decent model that is significantly more lightweight than the other models shown in the table.

The first row corresponds to the default student model trained with the weighted average loss, while the other rows correspond to EfficientNet B0 and MobileNetV2, respectively. Note that I did not include the results obtained with data augmentation during student training.

Conclusion and Further Thoughts

This concludes the report and also the series I have been developing on model optimization. Knowledge distillation is a very promising technique, specifically suited for deployment purposes. A very good thing about it is that it can be combined with quantization and pruning quite seamlessly to further reduce the size of your production models without having to compromise on accuracy.

We studied the idea of knowledge distillation in an image classification setting. If you are wondering whether it extends to other areas like NLP or even GANs, I recommend going over the following resources -

You can also check out some of the CVPR 2020 papers on knowledge distillation here.

Another trend you will see is the use of a student model that is as large as, or larger than, the teacher. This has been studied very systematically in [4].

References

  1. Hinton, Geoffrey, et al. “Distilling the Knowledge in a Neural Network.” ArXiv:1503.02531 [Cs, Stat], Mar. 2015. arXiv.org, http://arxiv.org/abs/1503.02531.
  2. Buciluǎ, Cristian, et al. “Model Compression.” Proceedings of the 12th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, Association for Computing Machinery, 2006, pp. 535–541. ACM Digital Library, doi:10.1145/1150402.1150464.
  3. Tang, Raphael, et al. “Distilling Task-Specific Knowledge from BERT into Simple Neural Networks.” ArXiv:1903.12136 [Cs], Mar. 2019. arXiv.org, http://arxiv.org/abs/1903.12136.
  4. Xie, Qizhe, et al. “Self-Training with Noisy Student Improves ImageNet Classification.” ArXiv:1911.04252 [Cs, Stat], June 2020. arXiv.org, http://arxiv.org/abs/1911.04252.
  5. Chen, Ting, et al. “Big Self-Supervised Models Are Strong Semi-Supervised Learners.” ArXiv:2006.10029 [Cs, Stat], June 2020. arXiv.org, http://arxiv.org/abs/2006.10029.

Acknowledgements

I am grateful to Aakash Kumar Nain for providing valuable feedback on the code.