Distilling Knowledge in Neural Networks With Weights & Biases
This article discusses knowledge distillation, a compelling model optimization technique, using W&B for experiment tracking, with code walkthroughs in TensorFlow.
Created on September 5 | Last edited on July 30
“Model ensembles are a pretty much guaranteed way to gain 2% of accuracy on anything.” - Andrej Karpathy.
I absolutely agree! However, deploying an ensemble of heavyweight models is not always feasible. Sometimes, your single model could be so large (GPT-3, for example) that deploying it in resource-constrained environments is simply not possible. This is why we have been going over model optimization recipes such as Quantization and Pruning.
This article is the last one in the series. In it, we will discuss a compelling model optimization technique: knowledge distillation.
Here's what we'll go over:
Table of Contents
- What Is Softmax Telling Us?
- Using the Softmax Information for Teaching - Knowledge Distillation
- Loss Functions in Knowledge Distillation
- A Few Training Recipes
- Experimental Results
- Conclusion and Further Thoughts
- References
- Acknowledgements
What Is Softmax Telling Us?
When working with a classification problem, it is very typical to use softmax as the last activation unit in your neural network. Why is that? Because a softmax function takes a set of logits and spits out a probability distribution over the discrete classes your network is being trained on. Figure 1 presents an example of this:

Figure 1: Predictions of a neural network on an input image.
In Figure 1, our imaginary neural network is highly confident that the given image is a one. However, it also thinks that there is a slight chance the image could be a seven. That reasoning is quite sensible, isn’t it? The given image does have subtle seven-ish characteristics. This information would not have been available if we were only dealing with hard, one-hot encoded labels like [1, 0] (where 1 and 0 are the probabilities of the image being a one and a seven, respectively).
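As a quick illustration of how softmax produces such a distribution, here is a minimal sketch (the logit values are made up for demonstration):

```python
import tensorflow as tf

# Hypothetical raw logits for the classes [one, seven].
logits = tf.constant([4.0, 1.0])

# Softmax turns the logits into a probability distribution.
probs = tf.nn.softmax(logits)
print(probs.numpy())  # ~[0.95, 0.05]: confident about "one", a little mass on "seven"
```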
Humans are well equipped to exploit this kind of relative judgment. More examples include a cat-ish dog, a brownish red, a cat-ish tiger, and so on. These are still valid comparisons, as Hinton et al. observe in [1]:
An image of a BMW, for example, may only have a minimal chance of being mistaken for a garbage truck, but that mistake is still many times more probable than mistaking it for a carrot.
This very knowledge helps us generalize remarkably well out in the wild.
This thought process helps us dig deeper into what our models might be thinking about the input data, which should be somewhat consistent with how we would think about it. Figure 1 establishes this again: to our eyes, that image looks like a one, but it has some traits of a seven.
So, what now? An immediate question that may come to mind is: what is the best way for us to use this knowledge with neural networks? Let us find out in the next section.
Using the Softmax Information for Teaching - Knowledge Distillation
The softmax information is way more useful than plain hard one-hot encoded labels. So, at this stage, we may have access to:
- The training data
- A trained neural network that performs well on the test data
We are now interested in using the output probabilities produced by our trained network.
Consider teaching someone the handwritten digits using the MNIST dataset. It is highly likely that a student would ask: doesn’t that one look like a seven? If that is the case, it’s definitely good news, because your student, for sure, knows what a one and a seven look like. As a teacher, you have been able to transfer your knowledge of the digits to your student. It is possible to extend this idea to neural networks as well.
High-Level Mechanics of Knowledge Distillation
So, here is the deal at a high level:
- Train a neural network that performs well on your dataset. This network is going to act like a teacher model.
- Use the teacher model to train a student model on the same dataset. The catch here is that the student model should be significantly smaller than the teacher in terms of capacity.
This workflow briefly formulates the idea of knowledge distillation.
Why smaller?
Isn’t this what we want? To deploy a lightweight model to production that is performant enough?
An Image Classification Case Study
Disclaimer: For the sake of brevity and simplicity, I will demonstrate the following sections with a computer vision example. Note that these ideas are independent of the domain.
- Train a teacher model that performs well on your image dataset. Here the cross-entropy loss would be calculated with respect to the true labels from your dataset.
- Train a smaller student model on the same dataset but use the predictions from the teacher model (the softmax output) as the ground-truth labels. These softmax outputs are referred to as soft labels. More on this in a moment.

Figure 2: A high-level overview of knowledge distillation.
Why are we training the student model on soft-labels?
Remember that our student model is smaller than the teacher model in terms of capacity. So, if your dataset is complex enough, the smaller student model may not be well suited to capture the hidden representations required for the training objective. We train the student model on soft labels to compensate for this, since they provide more meaningful information than the one-hot encoded labels. In a sense, we are training the student model to imitate the teacher model’s outputs while giving it only a little bit of exposure to the training dataset.
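To make this concrete, here is a minimal sketch of how the soft labels could be produced, assuming a trained teacher_model that outputs raw logits:

```python
import tensorflow as tf

def get_soft_labels(teacher_model, images):
    # Run the trained teacher on a batch of images and return its
    # softmax outputs; these act as the soft labels for the student.
    teacher_logits = teacher_model(images, training=False)
    return tf.nn.softmax(teacher_logits)
```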
Hopefully, this provided you with an intuitive understanding of knowledge distillation. In the next section, we will be taking a more detailed look at the student model's training mechanics.
Loss Functions in Knowledge Distillation
In order to train the student model, we can still use our regular cross-entropy loss between the soft-labels from the teacher and predicted labels from the student. It is highly likely that the student model would be confident about many of the input data points, and it would predict probability distributions like the following -

Figure 3: Highly confident predictions.
Extended Softmax
The problem with these weak probabilities (marked in red in Figure 3) is that they do not capture enough information for the student model to learn effectively. For example, it is almost impossible to transfer the knowledge that the image has seven-ish traits if the probability distribution is something like [0.99, 0.01].
Hinton et al. address this problem by scaling the raw logits of the teacher model by some temperature T (dividing the logits by T) before they get passed to softmax [1] (known as extended softmax or temperature-scaled softmax). That way, the distribution gets spread more evenly across the available class labels. The same temperature is used to train the student model. I have presented this idea in Figure 4.

Figure 4: Softened predictions.
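The effect is easy to see numerically. Here is a small sketch (with made-up logits) comparing the regular softmax with its temperature-scaled counterpart:

```python
import tensorflow as tf

logits = tf.constant([5.0, 0.5])  # hypothetical teacher logits for [one, seven]

print(tf.nn.softmax(logits).numpy())        # ~[0.99, 0.01]: almost no signal about "seven"
print(tf.nn.softmax(logits / 5.0).numpy())  # ~[0.71, 0.29]: the seven-ish trait is now visible
```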
We can write the student model’s modified loss function in the form of this equation:

$$\mathcal{L}_{KD} = -\sum_{i} p_i \log q_i,$$

where $p_i$ is the softened probability distribution of the teacher model and is expressed as

$$p_i = \frac{\exp(z_i / T)}{\sum_{j} \exp(z_j / T)},$$

with $z_i$ denoting the teacher’s raw logits, $T$ the temperature, and $q_i$ the student model’s temperature-scaled softmax output.
```python
def get_kd_loss(student_logits, teacher_logits,
                true_labels, temperature,
                alpha, beta):
    teacher_probs = tf.nn.softmax(teacher_logits / temperature)
    kd_loss = tf.keras.losses.categorical_crossentropy(
        teacher_probs, student_logits / temperature,
        from_logits=True)
    return kd_loss
```
Incorporating the Hard-Labels With Extended Softmax
In [1], Hinton et al. also explore the idea of using the conventional cross-entropy loss between the true target labels (typically one-hot encoded) and the student model’s predictions. This especially helps when the training dataset is small and there isn’t enough signal in the soft labels for the student model to pick up.
This approach works significantly better when it is combined with the extended softmax, and the overall loss function becomes a weighted average of the two:

$$\mathcal{L} = \frac{\alpha \, \mathcal{L}_{KD} + \beta \, \mathcal{L}_{CE}}{\alpha + \beta},$$

where $\mathcal{L}_{CE}$ denotes the regular cross-entropy loss with the true labels.
```python
def get_kd_loss(student_logits, teacher_logits,
                true_labels, temperature,
                alpha, beta):
    teacher_probs = tf.nn.softmax(teacher_logits / temperature)
    kd_loss = tf.keras.losses.categorical_crossentropy(
        teacher_probs, student_logits / temperature,
        from_logits=True)

    ce_loss = tf.keras.losses.sparse_categorical_crossentropy(
        true_labels, student_logits, from_logits=True)

    total_loss = (alpha * kd_loss) + (beta * ce_loss)
    return total_loss / (alpha + beta)
```
It’s recommended to weight β (the hard-label term) considerably smaller than α (the soft-label term).
Operating on the Raw Logits
Caruana et al. operate on the raw logits instead of the softmax values [2]. The workflow is as follows:
- The first part remains the same: train a teacher model that performs well on your image dataset, with the cross-entropy loss calculated with respect to the true labels from your dataset.
- Now, in order to train the student model, the training objective becomes minimizing the mean squared error between the raw logits of the teacher and the student models.
```python
mse = tf.keras.losses.MeanSquaredError()

def mse_kd_loss(teacher_logits, student_logits):
    return mse(teacher_logits, student_logits)
```
One potential disadvantage of using this loss function is its unconstrained nature: the raw logits can capture noise that a small model may not be able to fit properly. This is why, for this loss function to work well in the distillation regime, the student model needs to be a bit bigger.
Tang et al. explore the idea of interpolating between the two losses, the extended softmax and the MSE loss [3]. Mathematically, it would look like the following:

$$\mathcal{L} = \alpha \, \mathcal{L}_{KD} + (1 - \alpha) \, \mathcal{L}_{MSE}$$

Empirically, they found that the best performance (on NLP tasks) is achieved when $\alpha$ is equal to 0, i.e., when only the logit-matching term is used.
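Here is a sketch of that interpolation, written in the same style as the earlier snippets. The helper name and the way the mixing weight alpha is exposed are my own choices, not the exact formulation from [3]:

```python
import tensorflow as tf

mse = tf.keras.losses.MeanSquaredError()

def interpolated_kd_loss(student_logits, teacher_logits, temperature, alpha):
    # Extended-softmax (temperature-scaled) cross-entropy term.
    teacher_probs = tf.nn.softmax(teacher_logits / temperature)
    kd_loss = tf.reduce_mean(
        tf.keras.losses.categorical_crossentropy(
            teacher_probs, student_logits / temperature, from_logits=True))

    # Logit-matching (MSE) term.
    mse_loss = mse(teacher_logits, student_logits)

    # alpha interpolates between the two; alpha = 0 reduces to pure logit matching.
    return alpha * kd_loss + (1.0 - alpha) * mse_loss
```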
If you’re feeling a bit overwhelmed at this point, don’t sweat it. Hopefully, things will become clearer with the code.
A Few Training Recipes
In this section, I will provide you with a few training recipes that you can consider while working with knowledge distillation.
Using Data Augmentation
This idea is explored by Tang et al. in [3]. They demonstrate it on NLP datasets, but it is applicable to other domains as well. Training with data augmentation can help better guide the student model, especially when you are dealing with less data. Since we typically keep the student model much smaller than the teacher model, the hope is that with more diverse data, the student model gets to capture the domain better.
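For an image dataset, a light augmentation pipeline on the student’s training data could look like the following sketch (the specific transformations and their parameters are my own choices, not prescriptions from [3]):

```python
import tensorflow as tf

def augment(image, label):
    # Simple, label-preserving augmentations for the student's training set.
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.image.random_contrast(image, lower=0.9, upper=1.1)
    return image, label

# Assuming train_ds is a tf.data.Dataset of (image, label) pairs.
augmented_train_ds = train_ds.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
```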
Using Labeled and Unlabeled Data to Train the Student Model
In works like Noisy Student Training [4] and SimCLRv2 [5], the authors use additional unlabeled data when training the student model. You would use your teacher model to generate (pseudo) ground-truth distributions on the unlabeled dataset. This helps to increase the generalizability of the student model to a great extent.
This approach is only feasible when unlabeled data is available in the domain of the dataset you’re dealing with. Sometimes, that may not be the case (healthcare, for example). In [4], Xie et al. explore techniques like data balancing and data filtering to mitigate the issues that may arise when incorporating unlabeled data into the student model’s training.
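A minimal sketch of the pseudo-labeling step, assuming a trained teacher_model, a chosen temperature, and a batch of unlabeled images (the helper name is mine):

```python
import tensorflow as tf

def generate_pseudo_labels(teacher_model, images, temperature=5.0):
    # Use the trained teacher to produce soft pseudo-labels for unlabeled images.
    teacher_logits = teacher_model(images, training=False)
    return tf.nn.softmax(teacher_logits / temperature)

# Example: attach teacher-generated soft targets to a batch of unlabeled images,
# then mix such batches with the labeled data when training the student.
# soft_targets = generate_pseudo_labels(teacher_model, unlabeled_image_batch)
```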
Don’t Use Label-Smoothing When Training the Teacher Model
Label-smoothing is a technique used to relax the highly confident predictions produced by models. It helps reduce overfitting, but it is not recommended when training the teacher model, since the teacher’s logits are scaled by a temperature anyway. Hence, using label-smoothing in a knowledge distillation setup is not typically recommended. You can check out this article to learn more about label-smoothing.
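For reference, label-smoothing in Keras is just a parameter on the cross-entropy loss; the point here is to leave it at its default of 0.0 when training the teacher:

```python
import tensorflow as tf

# When training the teacher for distillation, keep label_smoothing at 0.0 (the
# default); the temperature scaling will take care of softening the distribution.
teacher_loss = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.0)
```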
Using Higher Temperature Values
Hinton et al. recommend using higher temperature values to soften the distributions predicted by the teacher model, so that the soft labels contain even more information for the student model. This is especially useful when dealing with small datasets. For larger datasets, that information becomes available through the sheer number of training examples. Refer to the extended softmax section again if it is not clear why a higher temperature softens the predicted distributions.
We will explore these recipes in the next section.
Experimental Results
Let’s first review the experimental setup. I used the Flowers dataset for my experiments. Unless otherwise specified, I used the following configurations:
- I used MobileNetV2 as the base model for fine-tuning the teacher, with the learning rate set to 1e-5 and Adam as the optimizer. (A sketch of how the teacher and student models could be assembled appears right after this list.)
- I set T (the temperature) to 5 when using the extended softmax (KD) loss and the weighted average of the KD and traditional cross-entropy losses.
- α = 0.9 and β = 0.1 when using a weighted average of the KD and traditional cross-entropy losses.
- For the student model, I followed this shallow architecture:

```
Layer (type)                 Output Shape              Param #
=================================================================
conv2d (Conv2D)              (None, 222, 222, 64)      1792
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 55, 55, 64)        0
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 53, 53, 128)       73856
_________________________________________________________________
global_average_pooling2d_3 ( (None, 128)               0
_________________________________________________________________
dense_3 (Dense)              (None, 512)               66048
_________________________________________________________________
dense_4 (Dense)              (None, 5)                 2565
=================================================================
```
- When training the student model, I used Adam as the optimizer with a learning rate of 1e-2.
- When training the student model with data augmentation, I used the weighted average loss with the same default hyperparameters mentioned above.
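Here is a rough sketch of how the two models could be put together with tf.keras. The teacher’s classification head and a few details (such as the pooling size) are my own reconstruction from the summary above, not the exact code behind the runs:

```python
import tensorflow as tf

def get_teacher_model(num_classes=5):
    # MobileNetV2 backbone fine-tuned end to end; inputs are assumed to be
    # preprocessed appropriately for MobileNetV2.
    base = tf.keras.applications.MobileNetV2(
        input_shape=(224, 224, 3), include_top=False, weights="imagenet")
    base.trainable = True
    return tf.keras.Sequential([
        base,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(num_classes),  # raw logits
    ])

def get_student_model(num_classes=5):
    # Mirrors the shallow architecture summarized above.
    return tf.keras.Sequential([
        tf.keras.layers.Conv2D(64, (3, 3), activation="relu",
                               input_shape=(224, 224, 3)),
        tf.keras.layers.MaxPooling2D((4, 4)),
        tf.keras.layers.Conv2D(128, (3, 3), activation="relu"),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(512, activation="relu"),
        tf.keras.layers.Dense(num_classes),  # raw logits
    ])
```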
[W&B panel: run set (1 run)]
Baseline Student Model
To make the performance comparisons fair, let's also train the shallow CNN from scratch and observe its performance. Note that in this case, I used Adam as the optimizer with a learning rate of 1e-3.
[W&B panel: run set (1 run)]
The Training Loop
Before we see the results, I want to shed some light on the training loop and how I wrapped it inside the classic model.fit() call. This is what the training loop looks like:
```python
def train_step(self, data):
    images, labels = data

    # Get the raw predictions (logits) from the trained teacher model.
    teacher_logits = self.trained_teacher(images)

    # Forward pass through the student and compute the distillation loss.
    with tf.GradientTape() as tape:
        student_logits = self.student(images)
        loss = get_kd_loss(teacher_logits, student_logits)

    # Backpropagate through the student only; the teacher is not updated.
    gradients = tape.gradient(loss, self.student.trainable_variables)
    self.optimizer.apply_gradients(zip(gradients, self.student.trainable_variables))

    # Track the metrics.
    train_loss.update_state(loss)
    train_acc.update_state(labels, tf.nn.softmax(student_logits))
    t_loss, t_acc = train_loss.result(), train_acc.result()
    train_loss.reset_states(), train_acc.reset_states()
    return {"loss": t_loss, "accuracy": t_acc}
```
The train_step() function should be an easy read if you are already familiar with how to customize a training loop in TensorFlow 2. Notice the get_kd_loss() function; this can be any of the loss functions we have discussed so far. We are using a trained teacher model here, the model we fine-tuned earlier. With this training loop, we can create an entire model that can be trained with a .fit() call.
First, create a class extending tf.keras.Model -
```python
class Student(tf.keras.Model):
    def __init__(self, trained_teacher, student):
        super(Student, self).__init__()
        self.trained_teacher = trained_teacher
        self.student = student
```
When you extend the tf.keras.Model class, you can put your custom training logic inside the train_step() method (a hook provided by the class that you override). So, in its entirety, the Student class would look like this -
```python
class Student(tf.keras.Model):
    def __init__(self, trained_teacher, student):
        super(Student, self).__init__()
        self.trained_teacher = trained_teacher
        self.student = student

    def train_step(self, data):
        images, labels = data
        teacher_logits = self.trained_teacher(images)

        with tf.GradientTape() as tape:
            student_logits = self.student(images)
            loss = get_kd_loss(teacher_logits, student_logits)
        gradients = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.student.trainable_variables))

        train_loss.update_state(loss)
        train_acc.update_state(labels, tf.nn.softmax(student_logits))
        t_loss, t_acc = train_loss.result(), train_acc.result()
        train_loss.reset_states(), train_acc.reset_states()
        return {"train_loss": t_loss, "train_accuracy": t_acc}
```
You can even write a test_step() to customize the evaluation behavior of the model. If you are interested in checking that out, along with the full train_step() utility, see this Colab Notebook. Our model can now be trained in the following manner -
```python
student = Student(teacher_model, get_student_model())
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
student.compile(optimizer)
student.fit(train_ds,
            validation_data=validation_ds,
            epochs=10)
```
One potential advantage of this method is that one can easily incorporate other capabilities like distributed training, custom callbacks, mixed precision, and so on.
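For example, under a sketch like the following (standard tf.distribute and tf.keras APIs; the callback choice is arbitrary), the same Student class trains across multiple GPUs with checkpointing, without touching train_step():

```python
import tensorflow as tf

# Optionally enable mixed precision (TF 2.4+):
# tf.keras.mixed_precision.set_global_policy("mixed_float16")

strategy = tf.distribute.MirroredStrategy()  # data-parallel training on all visible GPUs

with strategy.scope():
    student = Student(teacher_model, get_student_model())
    student.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01))

student.fit(
    train_ds,
    validation_data=validation_ds,
    epochs=10,
    callbacks=[
        tf.keras.callbacks.ModelCheckpoint("student_ckpt", save_weights_only=True),
    ],
)
```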
Training the Student Model With the Extended Softmax (KD) Loss
Upon training our shallow student model with this loss function, we get ~74% validation accuracy. We see that the loss starts to increase after epoch 8, which suggests that stronger regularization might have helped. Also, note that the hyperparameter tuning process has a significant impact here. In my experiments, I did not do rigorous hyperparameter tuning, and I kept the training schedules short in order to experiment faster.
[W&B panel: run set (1 run)]
Training the Student Model With the Weighted Average Loss
Let's now see if incorporating the ground-truth labels in the distillation training objective helps. With α = 0.1 and β = 0.1, we get ~71% validation accuracy. The training dynamics again suggest that stronger regularization and a longer training schedule would have helped.
[W&B panel: run set (1 run)]
Training the Student Model With the MSE Loss
With the MSE loss, the validation accuracy drops significantly, to ~56%. The same kind of loss behavior is present in this setting as well, suggesting the need for regularization.
[W&B panel: run set (1 run)]
Note that this loss function is completely unconstrained, and our shallow student model may not be capable of handling the noise that comes with it. Let's try a deeper student model.
[W&B panel: run set (1 run)]
Using Data Augmentation While Training the Student Model
As mentioned earlier, the student models are of smaller capacity than the teacher model. When dealing with less data, data augmentation can be helpful to train the student model. Let's verify.
[W&B panel: run set (1 run)]
Effect of Temperature (T)
In this experiment, let's study the effect of temperature on the student model. In this setting, I used the same shallow CNN.
[W&B panel: run set (11 runs)]
Finally, I wanted to study whether the choice of the base model used for fine-tuning the teacher has a significant effect on the student model.
[W&B panel: run set (2 runs)]
Finally, if you are wondering what kind of improvement one can get out of knowledge distillation for production purposes, the table below shows that for us. Without any hyperparameter tuning, we are able to get a decent model that is significantly more lightweight than the other models shown in the table.
The first row corresponds to the default student model trained with the weighted average loss, while the other rows correspond to EfficientNet B0 and MobileNetV2, respectively. Note that I did not include the results I got from using data augmentation while training the student model.
[W&B panel: run set (3 runs)]
Conclusion and Further Thoughts
This concludes the report and also the series I have been developing on model optimization. Knowledge distillation is a very promising technique specifically suited for deployment purposes. A very good thing about it is that it can be combined with quantization and pruning pretty seamlessly to further reduce the size of your production models without having to compromise on accuracy.
We studied the idea of knowledge distillation in an image classification setting. If you are wondering whether it extends to other areas like NLP or even GANs, I recommend going over the following resources:
- “Distilling Knowledge from Neural Networks to Build Smaller and Faster Models.” FloydHub Blog, 11 Nov. 2019, https://blog.floydhub.com/knowledge-distillation/.
- Sanh, Victor, et al. “DistilBERT, a Distilled Version of BERT: Smaller, Faster, Cheaper and Lighter.” ArXiv:1910.01108 [Cs], Feb. 2020. arXiv.org, http://arxiv.org/abs/1910.01108.
- Aguinaldo, Angeline, et al. “Compressing GANs Using Knowledge Distillation.” ArXiv:1902.00159 [Cs, Stat], Jan. 2019. arXiv.org, http://arxiv.org/abs/1902.00159.
- In [1], Hinton et al. also demonstrate knowledge distillation on a speech recognition task.
Another trend you will see is the use of a student model that is larger than or of the same capacity as the teacher. This has been studied very systematically in [4].
References
- Hinton, Geoffrey, et al. “Distilling the Knowledge in a Neural Network.” ArXiv:1503.02531 [Cs, Stat], Mar. 2015. arXiv.org, http://arxiv.org/abs/1503.02531.
- Buciluǎ, Cristian, et al. “Model Compression.” Proceedings of the 12th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, Association for Computing Machinery, 2006, pp. 535–541. ACM Digital Library, doi:10.1145/1150402.1150464.
- Tang, Raphael, et al. “Distilling Task-Specific Knowledge from BERT into Simple Neural Networks.” ArXiv:1903.12136 [Cs], Mar. 2019. arXiv.org, http://arxiv.org/abs/1903.12136.
- Xie, Qizhe, et al. “Self-Training with Noisy Student Improves ImageNet Classification.” ArXiv:1911.04252 [Cs, Stat], June 2020. arXiv.org, http://arxiv.org/abs/1911.04252.
- Chen, Ting, et al. “Big Self-Supervised Models Are Strong Semi-Supervised Learners.” ArXiv:2006.10029 [Cs, Stat], June 2020. arXiv.org, http://arxiv.org/abs/2006.10029.
Acknowledgements
In the case of "EfficientNet B0," the efficient network represents the teacher, while the shallow network represents the student?
Hello,
Thanks for this detailed tutorial. I could not find a better one about knowledge distillation in tensorflow.
I have a question though: Is there a way to save the student model weights after each epoch? I tried adding it into the callbacks in distiller.fit function. But this will not work since it takes distiller as the model and I need to save the weights for just the distiller.student.
Thanks in advance :)
Hi, another question: In case the teacher has been trained with categorical crossentropy and one-hot encoded labels, should we use the loss function described in the section "Extended Softmax"?
Hi, thank you for this remarkably well-written article!
I've tried to reproduce the training steps involved but encountered an issue when using a custom data generator. In these cases, the subclass of tf.keras.Model throws "NotImplementedError('When subclassing the Model class, you should 'implement a call method.')" upon .fit().
I know it expects a 'call' method, but I haven't been able to figure out the expected return value.
Could you briefly describe what needs to be implemented within 'call'?
Related issues are:
https://stackoverflow.com/questions/64933424/error-using-keras-imagedatagenerator-with-custom-train-step-in-subclassed-model
and
https://github.com/tensorflow/tensorflow/issues/43173
Any help is greatly appreciated.
Very clean and concise exposition. A piece that seems missing though is, if certain neural networks have the right biases for the task at hand (say CNNs for image classification) that enables them to see the "seven-ish" and "one-ish" traits in Fig 1, why we need to incorporate these "existing" biases back into labels for training a student network in order to see the performance boost? Perhaps the question is more clear if we limit the number of moving parts, e.g. assuming teacher and student are identical in their architecture and optimizer. Empirical observations suggest that even in such setting the student can outperform its teacher, but why? They have identical architectures/optimizers and hence same inductive biases; so if the teacher has the ability to see "seven-ish" and "one-ish" traits, so should the student. But why does this only work when we put this knowledge back into labels and train again (i.e. KD) ? I think this is part of the KD magic that is not well understood. I conjecture that the distillation feedback loop can amplify these inductive biases. I am very curious to hear about other theories and thoughts on this.