Self-Distillation and Knowledge Distillation Experiments
Introduction
There are three mysteries in deep learning, as highlighted in a recent blog post by Zeyuan Allen-Zhu and Yuanzhi Li [1]: ensemble, knowledge distillation, and self-distillation. In the post, they discuss the puzzles surrounding these techniques and analyze why they actually work. Inspired by the fact that these ideas work, I set out to validate them.
I ran experiments on the CIFAR-10 dataset, using ResNet-18 as the convolutional neural network. In this report, I describe the training process. In the end, an ensemble of five ResNet-18 models achieved a 3.12% improvement over a single model, while self-distillation and knowledge distillation increased accuracy over a single model by 1.29% and 1.05%, respectively. The GitHub repo can be found here: github.com/mrpositron/distillation.
First, let me briefly describe what these techniques mean. The first technique is the ensemble: we train multiple classifiers independently and then average their outputs. The second technique is knowledge distillation: we take the output of the ensemble and use it as the target for another model, where the target is a probability distribution rather than a hard label. The final technique is self-distillation: we train one classifier and use its softmax probability distribution as the training target for another classifier.
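Below is a rough PyTorch sketch of how the targets differ between the three techniques; the function names and structure are mine for illustration, not taken from the repo.

```python
import torch
import torch.nn.functional as F

def ensemble_predict(models, x):
    # Ensemble: average the softmax outputs of independently trained classifiers.
    probs = torch.stack([F.softmax(m(x), dim=1) for m in models])
    return probs.mean(dim=0)

def self_distillation_target(teacher, x):
    # Self-distillation: a single trained teacher's softmax output becomes the
    # soft target for a freshly initialized student.
    with torch.no_grad():
        return F.softmax(teacher(x), dim=1)

def knowledge_distillation_target(models, x):
    # Knowledge distillation: the averaged ensemble output is the soft target.
    with torch.no_grad():
        return ensemble_predict(models, x)
```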
Method
Due to hardware constraints, ResNet-18 was chosen as the architecture for all experiments.
Step 1. Train the teacher network.
First, we should train the teacher networks. During self-distillation and knowledge distillation, new networks will try to learn the output distribution of the teacher network. To train the teacher models, we set the hyperparameters as in the table below; a minimal training sketch follows the table.
| Hyperparameters | Values |
|---|---|
| Loss | Cross Entropy |
| Learning Rate | 0.001 |
| Optimizer | Adam |
| Number of epochs | 100 |
| Training size | 45000 |
| Validation size | 5000 |
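The sketch below illustrates this training setup in PyTorch. The batch size, the plain ToTensor transform, the checkpoint file name, and the use of torchvision's resnet18 are assumptions for illustration; they are not taken from the repo.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models

device = "cuda" if torch.cuda.is_available() else "cpu"

# 45,000 training images and 5,000 validation images, as in the table above.
transform = transforms.ToTensor()
full_train = datasets.CIFAR10("data", train=True, download=True, transform=transform)
train_set, val_set = random_split(full_train, [45000, 5000])
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
val_loader = DataLoader(val_set, batch_size=128)

model = models.resnet18(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

best_val_loss = float("inf")
for epoch in range(100):
    model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()

    # Keep the checkpoint with the lowest validation loss, as described in Step 1.
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            val_loss += criterion(model(x), y).item() * x.size(0)
    val_loss /= len(val_set)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "teacher_best.pt")
```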
The figures below show the training graphs.
The following figures show the validation graphs.
Models with the minimal validation losses were saved for further experiments. Validation accuracy and test accuracy are shown in the table below.
| Model # | Validation Accuracy | Test Accuracy |
|---|---|---|
| Model 0 | 86.66 | 86.01 |
| Model 1 | 86.06 | 85.46 |
| Model 2 | 86.08 | 85.71 |
| Model 3 | 86.38 | 86.17 |
| Model 4 | 87.00 | 86.39 |
Mean validation accuracy is 86.44%, and mean test accuracy is 85.95%.
Step 2. The ensemble.
The ensemble of five models achieves 89.07% accuracy on the test set, a 3.12% improvement over the mean single-model test accuracy.
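For reference, here is a short sketch of how such an ensemble can be evaluated: average the five teachers' softmax outputs and take the argmax. The checkpoint file names are placeholders, not the names used in the repo.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

device = "cuda" if torch.cuda.is_available() else "cpu"
test_set = datasets.CIFAR10("data", train=False, download=True,
                            transform=transforms.ToTensor())
test_loader = DataLoader(test_set, batch_size=256)

# Load the five saved teacher checkpoints (placeholder file names).
teachers = []
for i in range(5):
    m = models.resnet18(num_classes=10).to(device)
    m.load_state_dict(torch.load(f"teacher_{i}.pt", map_location=device))
    m.eval()
    teachers.append(m)

correct = 0
with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        # Average the softmax outputs of all teachers, then predict the argmax.
        probs = torch.stack([F.softmax(m(x), dim=1) for m in teachers]).mean(dim=0)
        correct += (probs.argmax(dim=1) == y).sum().item()

print(f"Ensemble test accuracy: {100.0 * correct / len(test_set):.2f}%")
```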
Step 3. Self-Distillation and Knowledge Distillation
The graphs below show the training procedure for five ResNet-18 models trained with self-distillation and one ResNet-18 model trained with knowledge distillation. During self-distillation, the probability distribution from a single teacher network is used as the target for the student network; during knowledge distillation, the probability distribution from the ensemble of five models acts as the teacher.
The figures below show the training graphs of the student networks. There are two loss curves: one is the cross-entropy loss between the model output and the target distribution, and the other is the cross-entropy loss between the model output and the target label. The former is used during training, while the latter serves only as a supporting metric. A minimal sketch of this distillation step follows.
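The sketch below shows one way to implement the distillation step just described; the function name and its signature are assumptions for illustration, not the repo's API.

```python
import torch
import torch.nn.functional as F

def distillation_step(student, teacher_probs, x, y, optimizer):
    """One training step: optimize the soft-target cross entropy and
    log the hard-label cross entropy as a supporting metric only."""
    optimizer.zero_grad()
    logits = student(x)
    log_probs = F.log_softmax(logits, dim=1)
    # Cross entropy between the student's output and the teacher's (or
    # ensemble's) distribution: -sum_c p_teacher(c) * log p_student(c).
    soft_loss = -(teacher_probs * log_probs).sum(dim=1).mean()
    soft_loss.backward()
    optimizer.step()
    # The hard-label loss is only logged for monitoring, not optimized.
    with torch.no_grad():
        hard_loss = F.cross_entropy(logits, y)
    return soft_loss.item(), hard_loss.item()
```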
The following figures show the validation graphs.
| (TN) Validation Accuracy (%) | (TN) Test Accuracy (%) | (SN) Validation Accuracy (%) | (SN) Test Accuracy (%) |
|---|---|---|---|
| 86.66 | 86.01 | 87.7 | 86.88 |
| 86.06 | 85.46 | 87.54 | 86.48 |
| 86.08 | 85.71 | 87.14 | 86.13 |
| 86.38 | 86.17 | 86.46 | 85.32 |
| 87.00 | 86.39 | 87.38 | 87.08 |
The table shows the results of using the model on the left as the teacher network (TN) and the model on the right as the student network (SN). The teacher models are identical to those trained in Step 1.
Discussion
It is clear that the ensemble, knowledge distillation, and self-distillation all work. Test accuracy goes up by about 1% with self-distillation, while the ensemble gives a boost of roughly 2.5-3%.
Knowledge distillation also works, as the results show, although I expected a more significant boost. One interesting thing to note: if we create an ensemble from the models produced by self-distillation, the accuracy reaches 89.38%.