Introduction

In the past few years, we have seen a tremendous amount of development in self-supervised learning, specifically for computer vision problems. While the field of natural language processing has been benefitting from the virtues of self-supervised learning for a long time, it was not until recently that computer vision systems started to see its real impact. Works like MoCo and PIRL demonstrated the kind of benefits self-supervised systems can bring to the table for computer vision problems.

This year, Chen et al. published their paper A Simple Framework for Contrastive Learning of Visual Representations (SimCLR for short), which presented a simpler yet effective framework for training computer vision models in a self-supervised way.

In this report, I will present some findings from a minimal implementation of SimCLR that I developed with a subset of the ImageNet dataset.

Check out the code on GitHub → | Read The Paper →


The Idea of Self-Supervision in Computer Vision

It was when I read this amazing introductory post on self-supervised learning by Jeremy Howard that I could really appreciate the kind of benefits one could get out of self-supervised learning systems for computer vision problems. If you are already familiar with the idea of word embeddings, then extending that understanding to computer vision scenarios might be helpful.

For training models in a self-supervised way with completely unlabeled data, one first needs to frame a supervised learning task (also known as a pretext task) with that unlabeled data. For example, for training word embeddings on a large corpus like Wikipedia, you might go with the task of predicting the next word given a sequence of words. So, how might you extend this idea when it comes to dealing with unlabeled images?

It's important to note that the contrastive prediction task should be neither too easy nor too difficult, and it should help the model develop an understanding of the given data (think of how embeddings capture the semantic relationship between related words). For vision problems, the following tasks have been known to help a model develop that sort of understanding:

You can learn more about these tasks here.
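To make this concrete, one well-known pretext task for images is rotation prediction: rotate each image by a random multiple of 90 degrees and train a model to predict which rotation was applied. Here is a minimal sketch in TensorFlow (the function name and setup are illustrative, not from any particular implementation):

import tensorflow as tf

def make_rotation_batch(images):
    # For each of the four possible rotations (0, 90, 180, 270 degrees),
    # rotate the whole batch and use the rotation index as the label.
    rotated, labels = [], []
    for k in range(4):
        rotated.append(tf.image.rot90(images, k=k))
        labels.append(tf.fill([tf.shape(images)[0]], k))
    # A classifier trained on (rotated, labels) must understand object
    # structure in order to predict the rotations correctly.
    return tf.concat(rotated, axis=0), tf.concat(labels, axis=0)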


So, the premise here is that once effective representations, ones that reflect a genuine understanding of the images, are learned from the unlabeled images, they can be used effectively for downstream tasks such as image classification. Once a model has been trained on a given pretext task, it can be used without the upper layers that were specific to the pretext task.

So, an obvious idea here would be to take image datasets like ImageNet or OpenImages (without their labels), train a model using the above-mentioned pretext tasks, unplug the upper layers of the model, and use the learned representations for downstream tasks.

Now that we have a brief idea of how self-supervised learning paradigms can be applied to computer vision, let us discuss what loss functions can be used to help a model learn effective representations in self-supervised ways.

Measuring Loss in Self-Supervised Training Methods

While training a model for the pretext tasks mentioned above, we often have to perturb the images, e.g., colorize them, omit certain pixels, change their orientation, and so on. While doing so, to ensure the model is learning consistent representations, we need to devise a way to tell the model that the following kinds of image pairs are similar to each other (figure from the SimCLR paper) -

Contrastive loss is used to tell a model that the above kinds of pairs of images are similar to each other. Why might this be useful?

Our minds can tell that all of the above images are similar to each other because we have an outstanding understanding of the original image in the first place. We trust that no matter how small or large a distortion is imposed upon the original image, our minds will still be able to tell that essentially all the above images are similar to each other.

So, the contrastive loss helps a model remain consistent when learning representations even when the supplied data is perturbed. Now, of course, it will also need to be able to distinguish when images are actually different; consider the following figure, for example (figure from here).

Note that this contrastive loss can be added as a regularization term to the final loss function, or it can be used as a stand-alone loss function; it depends on the pretext task. Studies have shown that adding a contrastive loss helps a model learn representations more effectively.

So, at this point, two things might really help us in developing a good self-supervised learning system -

SimCLR

In the SimCLR paper, the authors first show that data augmentation has a significant role to play when devising the pretext task. They first randomly sample a mini-batch of unlabeled images, apply a stochastic data augmentation policy (more on this in a bit), and train a model to bring similar images together using the so-called Normalized Temperature-Scaled Cross-Entropy Loss (NT-Xent loss).

The following figure summarizes the SimCLR framework (taken from the paper) -

where,

And at the same time, the following ones are dissimilar (figure from here) -

In the next section, we will discuss one of the major components of the SimCLR framework - data augmentation!

Data Augmentation

Data augmentation is extremely important for the SimCLR framework to work well. In the paper, the authors show how data augmentation helps a model learn contrastive representations of the given images. The data augmentation policy proposed in the paper consists of random cropping, random flipping, color distortions, and Gaussian blur, each applied randomly with different probabilities. Furthermore, the color distortions have a strength hyperparameter, which denotes the degree to which the distortions should be applied.
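As a rough sketch, such a stochastic policy could look like the following in TensorFlow (the probabilities and crop size here are illustrative; the color distortion recipe is modeled on the one given in the paper's appendix, and Gaussian blur is omitted for brevity):

import tensorflow as tf

def color_distortion(image, strength=0.5):
    # Jitter brightness, contrast, saturation, and hue,
    # each scaled by the strength hyperparameter.
    image = tf.image.random_brightness(image, max_delta=0.8 * strength)
    image = tf.image.random_contrast(image, 1 - 0.8 * strength, 1 + 0.8 * strength)
    image = tf.image.random_saturation(image, 1 - 0.8 * strength, 1 + 0.8 * strength)
    image = tf.image.random_hue(image, max_delta=0.2 * strength)
    return tf.clip_by_value(image, 0.0, 1.0)

def augment_image(image):
    # Random crop-and-resize plus a random horizontal flip.
    # Assumes float32 images in [0, 1] of shape (224, 224, 3).
    image = tf.image.random_crop(image, size=[200, 200, 3])
    image = tf.image.resize(image, (224, 224))
    image = tf.image.random_flip_left_right(image)
    # Apply color distortion with high probability, as in the paper.
    image = tf.cond(tf.random.uniform([]) < 0.8,
                    lambda: color_distortion(image),
                    lambda: image)
    return image

def data_aug(images):
    # Map the per-image augmentation over a batch; each call produces
    # a different random view of the same images.
    return tf.map_fn(augment_image, images)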

The authors presented a detailed ablation study in the paper about these different augmentation policies and the effect they have on the pretext task. They state the following points about the role of data augmentation in self-supervised learning (for computer vision) in general -

It's important to note that the augmented images derived from a particular image are considered positive examples with respect to that image. Any image taken from that set of augmented images, when paired with the original image, makes a positive pair. The following is an example of a positive pair (figure from here) -

While the following examples denote negative examples (figure from here) -

Returning to the Loss Function

The authors used the NT-Xent loss pretty cleverly. As opposed to using memory queues or memory banks (as used in MoCo and PIRL, respectively), they let the loss function take care of segregating the positive pairs from the negative pairs. The loss function looks like so -
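From the paper, for a positive pair $(i, j)$ the per-pair loss is

$$ \ell_{i,j} = -\log \frac{\exp\left(\mathrm{sim}(z_i, z_j)/\tau\right)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\mathrm{sim}(z_i, z_k)/\tau\right)} $$

where $\mathrm{sim}(u, v) = u^\top v / (\lVert u \rVert \, \lVert v \rVert)$ is the cosine similarity, $\tau$ is the temperature hyperparameter, and a mini-batch of $N$ images yields $2N$ augmented views. The final loss is averaged over all the positive pairs in the batch.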

A detailed explanation of the loss function is available here. For measuring the similarity between the non-linear projections ($z$) of two images, the authors used cosine similarity, though one can use the dot product as well.

The Framework Until Now

The following code listing shows the SimCLR framework roughly -

def train_step(images):
    with tf.GradientTape() as tape:
        # Apply the stochastic data augmentation policy twice to the same
        # batch to obtain two correlated views of each image
        xis = data_aug(images)
        xjs = data_aug(images)

        # Run forward passes through the encoder network (with non-linear projection)
        zis = model(xis)
        zjs = model(xjs)

        # Normalize the projection feature vectors
        zis = tf.math.l2_normalize(zis, axis=1)
        zjs = tf.math.l2_normalize(zjs, axis=1)

        # Calculate the NT-Xent loss
        loss = nt_xent(zis, zjs)

    # Calculate the gradients and backprop
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    return loss
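For reference, here is one way the nt_xent function above could be implemented (a minimal sketch; the temperature value is illustrative, and the actual implementation in the repository may differ):

def nt_xent(zis, zjs, temperature=0.1):
    # zis and zjs are the l2-normalized projections of the two augmented
    # views, each of shape (batch_size, dim).
    batch_size = tf.shape(zis)[0]
    z = tf.concat([zis, zjs], axis=0)  # (2N, dim)

    # Pairwise cosine similarities (the vectors are already normalized).
    sim = tf.matmul(z, z, transpose_b=True) / temperature  # (2N, 2N)

    # Mask out self-similarities so they are never treated as candidates.
    sim = sim - 1e9 * tf.eye(2 * batch_size)

    # For row i, the positive example sits at index i + N (and vice versa).
    positives = tf.concat(
        [tf.range(batch_size) + batch_size, tf.range(batch_size)], axis=0)

    # Cross-entropy of the positive against all other 2N - 2 candidates.
    loss = tf.keras.losses.sparse_categorical_crossentropy(
        positives, sim, from_logits=True)
    return tf.reduce_mean(loss)

Training then reduces to repeatedly calling train_step on batches of unlabeled images:

for epoch in range(epochs):
    for images in train_ds:  # batches of unlabeled images
        loss = train_step(images)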

It's important to note that the augmentation operations from the whole data augmentation policy are applied at random. I used the following augmentation policy for the dataset:

While using custom datasets, it's important to experiment with the different augmentation operators to create an augmentation policy that makes the prediction task neither too easy nor too difficult for a model to learn.

Here's what the encoder network looks like with non-linear projections -

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_2 (InputLayer)         [(None, 224, 224, 3)]     0
_________________________________________________________________
resnet50 (Model)             (None, 7, 7, 2048)        23587712
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0
_________________________________________________________________
dense (Dense)                (None, 256)               524544
_________________________________________________________________
activation (Activation)      (None, 256)               0
_________________________________________________________________
dense_1 (Dense)              (None, 128)               32896
_________________________________________________________________
activation_1 (Activation)    (None, 128)               0
_________________________________________________________________
dense_2 (Dense)              (None, 50)                6450
=================================================================
Total params: 24,151,602
Trainable params: 24,098,482
Non-trainable params: 53,120
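For reference, here is a sketch of how an encoder with this exact architecture could be put together in Keras (the function name is my own; the hidden sizes match the summary above):

from tensorflow.keras import layers, Model
from tensorflow.keras.applications import ResNet50

def get_resnet_simclr(hidden_1=256, hidden_2=128, hidden_3=50):
    # ResNet50 backbone without the classification head, trained from scratch.
    base_model = ResNet50(include_top=False, weights=None,
                          input_shape=(224, 224, 3))
    inputs = layers.Input(shape=(224, 224, 3))
    h = base_model(inputs)
    h = layers.GlobalAveragePooling2D()(h)

    # Non-linear projection head (the g(.) function in the paper).
    x = layers.Dense(hidden_1)(h)
    x = layers.Activation("relu")(x)
    x = layers.Dense(hidden_2)(x)
    x = layers.Activation("relu")(x)
    x = layers.Dense(hidden_3)(x)
    return Model(inputs, x)

resnet_simclr = get_resnet_simclr()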

Enough talking, show me the results!

Talking About Results


Linear Evaluation of the Learned Representations

In the case of linear evaluation, we train a linear classifier on top of the learned representations with very little labeled data (10% is typical). The linear classifier here looks like so -

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

def get_linear_model(features):
    return Sequential([Dense(5, input_shape=(features,), activation="softmax")])

Here's what the network looks like without any non-linear projections -

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
resnet50 (Model)             (None, 7, 7, 2048)        23587712  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
=================================================================
Total params: 23,587,712
Trainable params: 23,534,592
Non-trainable params: 53,120

Here's the general workflow in code for linear evaluation -

# Encoder model with the projection head removed
encoder = Model(resnet_simclr.input, resnet_simclr.layers[-6].output)

# Extract train and test features
train_features = encoder.predict(X_train)
test_features = encoder.predict(X_test)

# Linear evaluation
linear_model = get_linear_model(2048)
linear_model.compile(loss="sparse_categorical_crossentropy", metrics=["accuracy"],
                     optimizer="adam")
history = linear_model.fit(train_features, y_train_enc,
                 validation_data=(test_features, y_test_enc),
                 batch_size=64,
                 epochs=35)

The following plot shows the linear evaluation performance obtained with different levels of projections -

Linear evaluation of the learned representations

From the plots, we see that the representations from the network without any non-linear projections yielded the best results, and in our case, they also converged faster. You see the uneven numbers of steps across different runs because I set up an EarlyStopping callback to prevent overfitting. In Section 4.2 of the paper, the authors present more commentary on the use of non-linear projections for improving the quality of the representations.

Below you can see the lower-dimensional versions of the learned representations taken at different levels of non-linear projections.

Visualization of the Learned Representations

Visualization of the learned representations
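For reference, such lower-dimensional views can be produced with a standard dimensionality reduction technique; here is a minimal sketch using PCA (the report's actual plots may have been generated differently, e.g., with t-SNE or UMAP):

from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

# Project the extracted 2048-d features down to 2-D for visualization.
pca = PCA(n_components=2)
features_2d = pca.fit_transform(train_features)

# Color each point by its class label to inspect the cluster structure.
plt.scatter(features_2d[:, 0], features_2d[:, 1], c=y_train_enc, cmap="tab10", s=5)
plt.title("Learned representations in 2-D")
plt.show()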

We can already see some sort of clustered formation in the above plots, which is an indication that our model is indeed learning to group similar images together. Finally, we compare the performance of this framework with a supervised classifier trained on the full training dataset.

Training With the Full Training Dataset in a Fully Supervised Manner

In this setup, I did not do any data augmentation, and I followed the traditional image classification pipeline. Here's what the image classification model looks like in this case -

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_4 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
resnet50 (Model)             (None, 7, 7, 2048)        23587712  
_________________________________________________________________
global_average_pooling2d_1 ( (None, 2048)              0         
_________________________________________________________________
dense_3 (Dense)              (None, 256)               524544    
_________________________________________________________________
activation_2 (Activation)    (None, 256)               0         
_________________________________________________________________
dense_4 (Dense)              (None, 5)                 1285      
=================================================================
Total params: 24,113,541
Trainable params: 24,060,421
Non-trainable params: 53,120

Below, I present the performance obtained from this model -

Training with the full training dataset in fully supervised manner

We see that the linear model trained using the representations learned with the SimCLR framework performs quite close to this model, even with only 10% of the (labeled) training data. I used early stopping to prevent overfitting, which is why we see different numbers of epochs for the two runs (supervised-training and linear-eval-no-projections).

Further Notes and Conclusion

As mentioned in the paper, SimCLR benefits from larger amounts of data and longer training. This is why it achieves great performance on transfer learning and fine-tuning tasks when trained on the ImageNet dataset.

When plugging in your custom dataset, here are a couple of things to keep in mind -

Thanks to the following resources that I studied to strengthen my understanding of the framework -

Thanks to the ML-GDE program for granting me GCP credits, which were used for running a number of different experiments for this report. If you have any feedback to share, don't forget to reach out via Twitter (@RisingSayak).