Knowledge Distillation (KD) With(out) MixUp on the Student
In this report we examine the effect of training the student network on batches that combine the original CIFAR-100 images with a MixUp version of the same batch.
The teacher network is trained on pure images only, and the targets for the MixUp images are created on the fly from the individual images used to build each mix.
Training setup
Given a batch B, we create a MixUp version of the batch, B'.
First we sample the λ parameter, which dictates the "opacity" of each image used in the mix; in the standard MixUp recipe, λ is drawn from a Beta(α, α) distribution. After that we create a random permutation of the batch indices to decide the pairs on which the mixing will occur:
```python
# lam is the MixUp mixing coefficient; the standard MixUp recipe samples it
# from a Beta(alpha, alpha) distribution (alpha is a hyper-parameter, e.g. 1.0)
lam = np.random.beta(alpha, alpha)
# Random permutation of the batch indices to pick each image's mixing partner
index = torch.randperm(batch_size)
# Convex combination of each image with its partner
mixed_x = lam * x + (1 - lam) * x[index, :]
```
Now all we need to do is set up the criterion for our loss function. Each batch contains batch_size original images followed by batch_size mixed images, so we apply plain cross-entropy to the first batch_size images and a λ-weighted cross-entropy to the remaining batch_size mixed images (wrapped below in a helper function, `mixup_criterion`, for clarity):
```python
import torch.nn.functional as F

def mixup_criterion(pred, y_a, y_b, lam, n_original):
    # y_a: labels of the original images; y_b = y_a[index]: labels of the mixing partners
    # Cross entropy on the original images
    original_input_loss = F.cross_entropy(pred[:n_original], y_a)
    # Lambda-weighted cross entropy on the mixed images
    mixed_input_loss = lam * F.cross_entropy(pred[n_original:], y_a) \
        + (1 - lam) * F.cross_entropy(pred[n_original:], y_b)
    # Average of both losses (over the original and the mixed halves)
    return (original_input_loss + mixed_input_loss) / 2
```
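Putting the pieces together, a student training step could look like the sketch below. This is an illustration under stated assumptions, not the report's actual training loop: `student`, `optimizer`, and the combined-batch assembly are hypothetical names, and the student is assumed to return `(logits, features)` like the teacher does.

```python
# Assemble the combined batch: original images first, then their MixUp versions
combined_x = torch.cat([x, mixed_x], 0)
y_a, y_b = y, y[index]  # labels of the originals and of their mixing partners

pred, _ = student(combined_x)  # assumed to return (logits, features)
loss = mixup_criterion(pred, y_a, y_b, lam, n_original=x.size(0))

optimizer.zero_grad()
loss.backward()
optimizer.step()
```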
In this experiment we examine the following setup:
- Teacher network: trained only on "pure" images.
- Student network: trained on both "pure" and mixed images.
In order to create ground-truth targets for the mixed images, we do the following:
- For each image in the batch, compute the teacher logits.
- For each secondary image used to create a MixUp pair, compute the teacher logits.
- Combine the two sets of logits in a λ-weighted manner, as in the snippet below.
```python
# x is the batch of images
logits_teacher, _ = self.teacher(x)
# x[index, :] are the secondary images used for the MixUp pairs
logits_teacher_mixed, _ = self.teacher(x[index, :])
# Concatenate the teacher logits with the lambda-weighted teacher logits
# for the mixed inputs
expanded_teach = torch.cat(
    [logits_teacher, lam * logits_teacher + (1 - lam) * logits_teacher_mixed], 0
)
```
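The expanded teacher logits can then serve as the soft targets for the distillation term. A minimal sketch of how the full student loss could be assembled, assuming a standard Hinton-style KL-divergence KD loss with temperature `T` and a weight `kd_weight` (both hypothetical names, not taken from this report):

```python
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, T=4.0):
    # Standard KD term: KL divergence between temperature-softened
    # student and teacher distributions, scaled by T^2
    log_p_student = F.log_softmax(student_logits / T, dim=1)
    p_teacher = F.softmax(teacher_logits / T, dim=1)
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * (T ** 2)

# pred holds the student logits for the combined (original + mixed) batch
total_loss = mixup_criterion(pred, y_a, y_b, lam, n_original) \
    + kd_weight * kd_loss(pred, expanded_teach)
```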
[Run set: 2 runs]
In the following section we experiment with applying a weight between the mixed-image loss and the original-image loss. This run uses a weight of 0.3 for the mixed images and 0.7 for the original "pure" images.
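In code, this amounts to replacing the plain average at the end of the criterion with a weighted sum; `mixed_weight` below is a hypothetical name for the 0.3 weight mentioned above:

```python
mixed_weight = 0.3  # 0.3 on the mixed-image loss, 0.7 on the original-image loss
loss = (1 - mixed_weight) * original_input_loss + mixed_weight * mixed_input_loss
```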
[Run set: 4 runs]