Supervised Contrastive Learning
Introduction
In this article, we are going to look at a new training methodology called Supervised Contrastive Learning. This methodology is an adaptation of contrastive learning to the fully supervised setting. It uses the label information to cluster samples belonging to the same class in the embedding space. On top of these embeddings, a linear classifier can then be trained to separate the images. This method is reported to outperform training with the Cross-Entropy loss.

Here in this article, we are going to discuss the proposed methodology in detail and then run 3 experiments to check how it performs compared to Cross-Entropy.
Here is the list of content covered in this article.
- Revisiting Cross-Entropy
- Contrastive Learning
- Supervised Contrastive Learning
- Supervised Contrastive Loss
- Experiments
- Conclusion
Revisiting Cross-Entropy
Cross-Entropy. Something we are all well aware of and have used earlier. Cross-entropy is used as a loss function for classification problems to measure the error. It measures the difference between two probability distributions: the distribution predicted by the machine, i.e. what the machine believes the values should be, and the actual distribution, i.e. what the values really are. This error is then used to optimize the model and find the optimum values of the parameters so that the model's predictions are close enough to the actual values, aka the ground-truth values.
Having said this, one can ask: then what is the difference between KL Divergence and Cross-Entropy? Intuitively, KL divergence also measures the difference between probability distributions, in other words how similar one probability distribution is to another. The difference lies in their formulas, so let's look at them. The formula for the cross-entropy loss for multiclass classification is
$-\sum_{i=1}^{N} y_i \log(\hat{y}_i)$
Where,
$N$ is the number of classes,
$y_i$ is the actual value,
$\hat{y}_i$ is the predicted value.
Whereas the formula for KL Divergence for the same model will be
$\sum_{i=1}^{N} y_i \log\left(y_i / \hat{y}_i\right)$
Let's break down the formula.
KL Divergence $= \sum_{i=1}^{N} y_i \log\left(y_i / \hat{y}_i\right)$
$= \sum_{i=1}^{N} y_i \log(y_i) - \sum_{i=1}^{N} y_i \log(\hat{y}_i)$
$=$ Cross-Entropy $-$ Entropy
So you can see, KL Divergence is actually the Cross-Entropy minus the Entropy of the true distribution. The Entropy term does not depend on the model at all, so it is effectively a constant. This means that if you try to optimize KL Divergence, you are actually optimizing Cross-Entropy, because the Entropy is going to stay constant anyway. In practice, Cross-Entropy is preferred as the loss function for the simple reason that optimizing it avoids computing the redundant (constant) term.
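A quick numeric check of this identity, in plain NumPy with made-up values for a single one-hot-labelled sample:

```python
import numpy as np

# One-hot ground truth and a model's predicted distribution (illustrative values only).
y = np.array([0.0, 1.0, 0.0])
y_hat = np.array([0.1, 0.7, 0.2])

eps = 1e-12  # avoid log(0) for the zero entries of y

cross_entropy = -np.sum(y * np.log(y_hat + eps))              # H(y, y_hat)
entropy = -np.sum(y * np.log(y + eps))                        # H(y), 0 for a one-hot target
kl_divergence = np.sum(y * np.log((y + eps) / (y_hat + eps)))

print(cross_entropy)                                          # ~0.357
print(kl_divergence)                                          # ~0.357
print(np.isclose(kl_divergence, cross_entropy - entropy))     # True
```

Because the target is one-hot, its entropy is zero and the two losses coincide exactly; for soft targets they would differ only by that constant entropy term.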
So coming back to Cross-Entropy, what it actually tries to do is:
- not compute any loss for the negative classes (for which y = 0), and
- for the positive class, minimize the negative of the log of the predicted value. To do this, the function will try to maximize the predicted value, hence pushing it closer to 1.
Finding the loss is an important step because from here backpropagation starts in the network and pushes the values of the parameters in the right direction, eventually giving the best working network. Having said this, if we choose a loss function that is not suitable or isn't working well, everything can collapse. Cross-entropy is everyone's go-to choice for classification problems, but it also suffers from some major problems:
- It is sensitive to noise and adversarial examples, and gives poor results if these are present.
- Cross-entropy leads to poor margins, because of which the model gives false results if the inputs differ from the training data even a bit.
To overcome these problems, a new approach named Supervised Contrastive Learning has been proposed. Let's see what it offers. But first, we need to know what is meant by contrastive learning.
Contrastive Learning
The literal meaning of contrasting is to compare in order to show differences. In contrastive learning, we contrast between similar and dissimilar things. In other words, we train our machine so that it can tell the difference between similar and dissimilar things. For example, suppose we have a binary classification problem where the machine needs to classify between cats and dogs after looking at an input image. In contrastive learning, we will make our machine understand that all the cat images are similar to one another and different from all the dog images, and vice versa.

Taking this forward, let's try to understand Supervised Contrastive Learning.
Supervised Contrastive Learning
The idea proposed in Supervised Contrastive Learning is pretty simple: make the model learn to map the normalized embeddings of samples belonging to the same class closer together, and the embeddings of samples belonging to the other class/classes farther apart. In the cats-vs-dogs example, the embeddings of all the cat images should be close to each other and distant from all the dog images, and vice-versa.
We know that the neural network first converts the image into a representation and then this representation is used to predict the result. So if the representations are formed keeping the idea given above into consideration, it will be easier for the classifier to give accurate results.
This whole idea is achieved in two stages.
Stage 1: In this stage, the network is trained using a contrastive loss. Here the images are encoded in such a way that the embeddings of one class are close together and those of the other classes are far apart, and the labels of the images are used to do this. This stage has three components, namely the data augmentation module, the encoder network and the projection network. These components are explained below separately.
Stage 2: Here, the encoder network trained in Stage 1 is frozen and the projector network is discarded. The representation learnt by the encoder network is then used to learn a classifier, which is nothing but a linear layer. At this stage, the Cross-Entropy loss is used to train this classifier to predict the labels.
Let's have a look at the components of Stage 1 of the training.
1). Data augmentation module.
This transforms the input image into augmented images. For each image, two augmented images are generated with different augmentation policies.
- The first augmented image is obtained by randomly cropping the input image and then resizing it back to its original size.
- To get the second augmented image, three different options were evaluated:
  i). AutoAugment
  ii). RandAugment
  iii). SimAugment (the augmentation scheme proposed in SimCLR).
The authors used exactly the same data augmentation policies in Stage 2 as well, when training the linear classifier, and they found the best results with this practice.
For a single input image, these two policies result in two different augmented images. That means if there were N sample images, this stage will return 2N images. A minimal sketch of such a module is given below.
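As a rough illustration, a minimal version of this module could look like the following in TensorFlow (assuming 128×128×3 inputs; `second_policy` stands in for whichever of the three schemes is chosen, and the crop size is an arbitrary choice):

```python
import tensorflow as tf

def random_crop_resize(image, out_size=128):
    # First view: random crop, then resize back to the input resolution.
    image = tf.image.random_crop(image, size=(out_size - 28, out_size - 28, 3))
    return tf.image.resize(image, (out_size, out_size))

def augment_pair(image, second_policy):
    # One input image -> two differently augmented views, so a dataset of
    # N images becomes 2N images. `second_policy` stands in for
    # AutoAugment / RandAugment / SimAugment, whichever is chosen.
    return random_crop_resize(image), second_policy(image)
```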
2). Encoder Network.
It simply converts the image into a representation vector. The authors used headless ResNet-50 and ResNet-200 as the base models for the encoder network and got some really fantastic results with them. Both augmented images of an input image, which we got from the data augmentation module, are sent separately to the same encoder, which outputs a pair of representation vectors. These output vectors are normalized. This means 1 input image will have two representations.
3). Projection network.
It converts the representation vectors into vectors suitable for the contrastive loss calculation. The authors used a multi-layer perceptron with a single hidden layer of size 2048 and an output vector of size $D_P = 128$. The encoded vectors which we get as output from the encoder network are fed into this network. The output of this network, i.e. the projection vectors, are normalized and then used in the loss function.
The output vectors of this projection network are then sent to the supervised contrastive loss function (explained below), and the resulting loss is minimized.
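For concreteness, here is one way the encoder and projection networks could be defined in Keras, following the sizes described above (a 2048-d representation and a 128-d projection). The input size, `weights=None` and the ReLU activations are our own choices for this sketch, not taken from the paper:

```python
import tensorflow as tf
from tensorflow.keras import layers

def build_encoder(input_shape=(128, 128, 3)):
    # Headless ResNet-50, global average pooling, then a 2048-unit Dense layer.
    # The representation is L2-normalized onto the unit hypersphere.
    base = tf.keras.applications.ResNet50(include_top=False, weights=None,
                                          input_shape=input_shape, pooling="avg")
    inputs = layers.Input(shape=input_shape)
    x = base(inputs)                                    # (batch, 2048)
    x = layers.Dense(2048, activation="relu")(x)        # activation is our assumption
    outputs = layers.Lambda(lambda t: tf.math.l2_normalize(t, axis=1))(x)
    return tf.keras.Model(inputs, outputs, name="encoder")

def build_projector(input_dim=2048, hidden_dim=2048, proj_dim=128):
    # MLP with one hidden layer; the projection is again L2-normalized
    # before being passed to the contrastive loss.
    inputs = layers.Input(shape=(input_dim,))
    x = layers.Dense(hidden_dim, activation="relu")(inputs)
    x = layers.Dense(proj_dim)(x)
    outputs = layers.Lambda(lambda t: tf.math.l2_normalize(t, axis=1))(x)
    return tf.keras.Model(inputs, outputs, name="projector")
```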
This whole process is depicted diagrammatically below.
Where,
$x$ is the input image,
$da_1$ & $da_2$ are the two different augmentation policies used in the Data Augmentation Module,
$x_1'$ & $x_2'$ are the outputs of the Data Augmentation Module,
$E(x_1')$ & $E(x_2')$ are the output vectors of the Encoder Network,
$P(E(x_1'))$ & $P(E(x_2'))$ are the output vectors of the Projector Network,
$y$ is the output.
The projected vector is sent to the Supervised Contrastive Loss which we are going to see in the next section.
Supervised Contrastive Loss
The Supervised Contrastive Loss function is given by the following formula:
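Writing it out from the definitions below (this is the form used in the SupCon paper):

$L^{sup} = \sum_{i=1}^{2N} L^{sup}_{i}$

$L^{sup}_{i} = \frac{-1}{2N_{y_i} - 1} \sum_{j=1}^{2N} 1_{i \neq j} \cdot 1_{y_i = y_j} \cdot \log \frac{\exp(z_i \cdot z_j / \tau)}{\sum_{k=1}^{2N} 1_{i \neq k} \exp(z_i \cdot z_k / \tau)}$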
Where,
$N$ is the number of randomly sampled images in a mini-batch. After passing these $N$ images through the model in Stage 1, we get $2N$ images.
$i$ is the index of an arbitrary augmented image within the mini-batch.
$j$ is the index of the other augmented image originating from the same input image.
$k$ is the index of the other images, apart from $x_i$ and $x_j$.
$z_i$ is the projected vector of an input image, i.e. $z_i = P(E(x_i))$.
This means $z_i$ and $z_j$ are the projected vectors of the two augmented views of the same image, and $z_k$ is the projected vector of any other image.
$\tau$ is a scalar temperature parameter, which is always positive.
$1_B$ is 1 if the condition $B$ is true, and 0 otherwise.
$N_y$ is the total number of images in the mini-batch that have the same label $y$.
$z_i \cdot z_j$ denotes the inner (dot) product between the normalized vectors.
Now, to actually understand how this function does what it is expected to do, I would like to draw your attention to another topic: the inner product, aka the dot product. Look at the image below.
Here a & b are two vectors.
On the left, you can see that if we increase the angle (theta) between two vectors, they separate, whereas on the right, if we decrease the angle between them, they come closer to each other. Keeping this in mind, let's come to dot products. The dot product between two vectors, say $a$ and $b$, is given by
$a \cdot b = \|a\| \, \|b\| \cos\theta$
Now, if I talk about the cosine of an angle, we all know that if the angle is small then the cosine is larger, and vice versa. Putting all these things together: if we want to place two vectors away from each other in space, we have to increase the angle between them, which means making the cosine of the angle small, and ultimately making the dot product of the two vectors small. Similarly, when we want to place two vectors close to each other, we have to make the dot product between them large. So, as a conclusion, we can say that the greater the dot product, the closer the vectors are, and vice-versa.
But wait!! Is the dot product only about the cosine of the angle between the vectors? No, right? It also has two more terms: the magnitudes of the vectors. So the dot product also depends on the magnitudes of the vectors, and we surely don't want this. We want to use the dot product to measure the closeness of two vectors in space, and for this the dot product should be independent of the vector magnitudes. To achieve this, the authors came up with an idea: instead of using the projected vectors coming directly from the projector network, they normalize the vectors. This is done so that the vectors lie on a unit hypersphere; simply speaking, each vector has unit distance from the centre, i.e. a magnitude of 1. When these normalized vectors are used to find the dot product, the dot product gives a clear view of their closeness. This may raise a question in your mind: why not use cosine similarity instead of the dot product of normalized vectors? The answer is computational cost. The cost of computing the cosine similarity between every pair of vectors would be much greater than with this proposed method.
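A tiny check of this point (NumPy, arbitrary example vectors): once the vectors are L2-normalized, their plain dot product equals the cosine of the angle between them.

```python
import numpy as np

a = np.array([3.0, 4.0])
b = np.array([1.0, 2.0])

# Cosine of the angle between a and b.
cosine = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

# Normalize first, then take a plain dot product.
a_unit, b_unit = a / np.linalg.norm(a), b / np.linalg.norm(b)

print(np.isclose(a_unit @ b_unit, cosine))  # True: normalizing makes the dot
                                            # product a pure closeness measure
```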
Now coming back to the loss function given above: when we try to minimize the loss function, we actually try to maximize the log term. Notice that in the numerator of the log term we take the exponential of the dot product between projections of images belonging to the same class, whereas the denominator sums the exponentials of the dot products with all the other images in the batch, which is dominated by images belonging to different classes. In order to maximize the log term, the numerator inside the log function has to be increased and the denominator decreased, i.e. the exponential of the dot product for images of the same class is maximized, whereas the exponentials of the dot products for images of different classes are minimized. So ultimately, when we try to minimize the loss, we try to bring the vectors belonging to the same class close together and push those belonging to different classes far apart.
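To make the mechanics concrete, here is a possible TensorFlow implementation of this loss. It is our own sketch based on the formula above, not the authors' reference code; `projections` is the batch of 2N projected vectors and `labels` holds the corresponding class labels (each original image contributes two rows with the same label).

```python
import tensorflow as tf

def supervised_contrastive_loss(projections, labels, temperature=0.1):
    # projections: (2N, D) projected vectors; labels: (2N,) integer class labels.
    projections = tf.math.l2_normalize(projections, axis=1)      # unit hypersphere
    n = tf.shape(projections)[0]

    # Pairwise dot products z_i . z_k, scaled by the temperature tau.
    logits = tf.matmul(projections, projections, transpose_b=True) / temperature
    # Subtracting the per-row max leaves the loss unchanged but avoids overflow.
    logits = logits - tf.reduce_max(logits, axis=1, keepdims=True)

    not_self = 1.0 - tf.eye(n)                                    # the 1_{i != k} mask
    labels = tf.reshape(labels, (-1, 1))
    same_class = tf.cast(tf.equal(labels, tf.transpose(labels)), tf.float32)
    positives = same_class * not_self                             # 1_{i != j} * 1_{y_i = y_j}

    # log( exp(z_i.z_j / tau) / sum_{k != i} exp(z_i.z_k / tau) ) for every pair (i, j).
    denom = tf.reduce_sum(tf.exp(logits) * not_self, axis=1, keepdims=True)
    log_prob = logits - tf.math.log(denom)

    # Average over the positives of each anchor i (the 1 / (2 N_{y_i} - 1) factor),
    # then average over the batch (the paper sums; averaging only rescales).
    loss_per_anchor = -tf.reduce_sum(positives * log_prob, axis=1) / tf.reduce_sum(positives, axis=1)
    return tf.reduce_mean(loss_per_anchor)
```

Note that `labels` must repeat each original image's label once per augmented view, so the two views of the same image count as positives for each other.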
Putting the pieces together
In this section, I will take you on a tour of the proposed methodology which will help you to understand the complete working of it.
Let's look at all the components step by step.
Step 0: This is the preprocessing step for the dataset. Before starting with the images, we resize them to a fixed size, 128×128×3 (for example). We also normalize the images.
(From here Step i.j means the jth step of the ith stage.)
Step 1.1: The dataset is sent to the Data Augmentation Module, which applies different data augmentations (explained above) to the images. Suppose the dataset had 2K images before entering this module; at the end of this step it will have 4K images, because for each image in the dataset this module produces two augmented images. Now the dataset is finally ready to be sent to the encoder network.
Step 1.2: The encoder network has ResNet-50 or ResNet-200 (without the top) as the base network, whose output is then sent to a Dense layer with 2048 neurons. Let's say the size of a mini-batch is 64. In this mini-batch there are 32 pairs of images, where the two images of each pair are the two augmented views of the same original image. So the input matrix, of shape (64, 128, 128, 3), is sent to the encoder network, and the final output we get from it has a shape of (64, 2048). This output is nothing but the encoded vectors for the images. So each encoded vector has a size of 2048 and there are 2 encoded vectors for each original image. But wait, 2 encoded vectors for the same image, isn't that useless? No, because we get the 2 vectors from the two augmented images of that very image. Each of these two augmented images represents a different view of the data and thus contains some subset of the information in the original input image. Ultimately, the 2 different encoded vectors for the same image each give us some subset of the original information.
These vectors are finally normalized. In the section above, we already discussed why we normalize the projected vectors before sending them to the loss function, but normalization is done at this stage as well: the authors found after various experiments that this normalization always improved performance. This gives an output of shape (64, 2048), which is the final output of the encoder network. This normalized encoded vector is sent to the projector network.
Step 1.3: The projector network is an MLP with one hidden layer of size 2048 and one output layer of size 128 (as suggested in the paper; for our experiments we used a different architecture for the projection network, which will be explained in the next section). This network gives an output of shape (64, 128), which is then normalized, so the final normalized projected vector also has shape (64, 128). From here onwards the projected vector will be called z. The output of the projector network is now sent to the loss function.
Step 1.4: At this step, the Supervised Contrastive Loss function is used to find the loss. We have already discussed this loss function in the above section. The output that we receive from the projector network is fed into this loss function. Let's look at the loss function again.
For each of the 64 values, we compute $L^{sup}_{i}$ and then add all of them together.
Let's see what is happening in $L^{sup}_{i}$.
Inside $L^{sup}_{i}$ we find the inner product of $z_i$ with every other vector in the batch, but with some restrictions of course. These restrictions are applied with the help of indicator terms. Let's look at them:
- The term $1_{i \neq j}$ prevents the loss from taking the inner product of a vector with itself, i.e. it never lets the loss compute the dot product between $z_{100}$ and $z_{100}$, because that would not take us anywhere.
- The term $1_{y_i = y_j}$ ensures that $z_i$ and $z_j$ are vectors belonging to the same class. These pairs appear in the numerator of the log term.
- The term $1_{i \neq k}$ only ensures that $z_k$ and $z_i$ are different vectors, i.e. the anchor itself is excluded; the sum in the denominator of the log term runs over every other image in the batch, whichever class it belongs to.
Now coming to the log term of $L^{sup}_{i}$: the numerator contains $\exp(z_i \cdot z_j / \tau)$ and the denominator sums $\exp(z_i \cdot z_k / \tau)$ over all $k \neq i$. Since that sum also contains the numerator's own term, the argument of the log can never go higher than 1.
Once the loss is calculated, the optimizer comes into action and tries to minimize the loss. The loss is minimized by maximizing the numerator and minimizing the denominator of the log term in the loss function. After backpropagation, the model learns its parameters in such a way that it places images belonging to the same class closer and those belonging to different classes farther apart.
Once the training of this model is over, we discard the projector network and use the trained encoder network for the second stage of the training.
Step 2.1: From here the second stage of the training begins, where we try to train a classifier on top of the encoder network. At this stage, the projector network is discarded and only the encoder network is used, which is frozen. One more dense layer is added on top of the frozen encoder network, with size equal to the number of classes in the dataset.
The input to this new network is the same dataset, preprocessed as before. The same image augmentation policies can be used here for data augmentation; the authors report in their paper that they got the best results when they used the same augmentations in both stages of training.
Let's say the mini-batch size is 64. Then the output of this network, which is the final output, has a shape of (64, #classes). Once we get the output, it is sent to the loss function.
Step 2.2: At this step, the loss is calculated for the second stage of the training. Here the standard Cross-Entropy loss function is used. The loss is then backpropagated through the network and the parameters are learnt. Notice that at this stage of training, the only trainable parameters are those of the final layer.
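A sketch of Stage 2 in Keras, assuming the hypothetical `build_encoder` helper from the earlier sketch: the pretrained encoder is frozen and only the new dense layer is trained with cross-entropy.

```python
import tensorflow as tf
from tensorflow.keras import layers

def build_stage2_classifier(encoder, num_classes):
    # Freeze the Stage-1 encoder; only the new dense layer will learn.
    encoder.trainable = False
    inputs = layers.Input(shape=(128, 128, 3))
    features = encoder(inputs, training=False)          # (batch, 2048) representations
    outputs = layers.Dense(num_classes, activation="softmax")(features)

    model = tf.keras.Model(inputs, outputs, name="stage2_classifier")
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model

# Example usage (dataset objects are placeholders):
# classifier = build_stage2_classifier(encoder, num_classes=5)
# classifier.fit(train_ds, validation_data=val_ds, epochs=10)
```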
So after this whole discussion, we are good to move on to the experiments we ran. In the next few sections, we will talk about the 3 experiments we did to check the efficiency of this methodology.
Experiment 3
Dataset
For this experiment, a subset of ImageNet is used which can be found here. This dataset has 1250 training images and 250 test images belonging to five classes.
Dataset Preprocessing
All the images are normalized and resized to (128,128).
Data Augmentation
This experiment was done both with and without augmentation, and the results for both cases are discussed below. We didn't use AutoAugment as proposed in the paper; we used the following few policies (a short sketch of how they can be chained follows the list).
- Applying random brightness.
- Applying random saturation.
- Applying random contrast.
- Applying random hue.
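These can be chained with standard `tf.image` ops, roughly as below (the delta and range values are placeholders, not our tuned settings). Such a function could also be plugged in as the `second_policy` of the earlier augmentation sketch.

```python
import tensorflow as tf

def color_jitter(image):
    # The four random photometric augmentations listed above.
    image = tf.image.random_brightness(image, max_delta=0.2)
    image = tf.image.random_saturation(image, lower=0.8, upper=1.2)
    image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
    image = tf.image.random_hue(image, max_delta=0.1)
    return tf.clip_by_value(image, 0.0, 1.0)  # keep pixel values in a valid range
```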
Model Architecture
For the first stage of training, the architecture of the encoder network was similar to the encoder network used in the other two experiments. However, the projector network had 256 neurons instead of 128 neurons.
For the second stage of training, the Dense layer had 5 neurons as the dataset had 5 classes. We used the softmax activation function in this Dense layer.
Training Details
We tested this methodology with
- different optimizers such as SGD, RMSprop & Adam,
- a fixed learning rate as well as a learning-rate decay function,
- and with and without augmentations.
We ran the model for different numbers of epochs. One example optimizer setup is sketched below.
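As an illustration, one such configuration (the decay schedule and hyperparameter values here are placeholders, not the values we logged) could be set up as:

```python
import tensorflow as tf

# One example configuration: Adam with an exponentially decaying learning rate.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-3, decay_steps=1000, decay_rate=0.9)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

# Fixed-learning-rate variants simply pass a constant instead of a schedule,
# e.g. tf.keras.optimizers.SGD(learning_rate=1e-3) or RMSprop(learning_rate=1e-3).
```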
Results
The table below shows the Supervised Contrastive Loss (obtained in Stage 1) and the final training and validation accuracy (obtained in Stage 2) for different optimizers, learning-rate strategies and augmentation settings.
| Optimizer + learning-rate strategy + with or without augmentation | SCL (Stage 1 loss) | Training Accuracy | Validation Accuracy |
|---|---|---|---|
| SGD + lr decay function + without augmentation | 0.00306 | 0.5832 | 0.4160 |
| SGD + fixed lr + without augmentation | 0.1572 | 0.1976 | 0.2000 |
| SGD + lr decay function + with augmentation | 0.159 | 0.172 | 0.184 |
| Adam + lr decay function + without augmentation | 0.0104 | 0.984 | 0.6240 |
| Adam + fixed lr + without augmentation | 0.0094 | 0.9808 | 0.6400 |
| Adam + lr decay function + with augmentation | 0.00464 | 0.7544 | 0.6560 |
| RMSprop + lr decay function + without augmentation | 00447 | 0.992 | 0.6920 |
| RMSprop + fixed lr + without augmentation | 0.0100 | 0.9664 | 0.6360 |
| RMSprop + lr decay function + with augmentation | 0.02736 | 0.657 | 0.6120 |
Let's look at the graphs of all these results, which we logged using wandb, in the observation section.
Visualizing the learnt embeddings
Having seen the brilliant performance of SCL, let's look at the clusters formed at stage 1.
a). Pets Dataset:
Here we are visualizing the encoded and projected vectors.