Transfer Learning Using PyTorch Lightning

A brief introduction to transfer learning using PyTorch Lightning "style guide". Made by Ayush Thakur using Weights & Biases
Ayush Thakur

In the previous report, we built an image classification pipeline using PyTorch Lightning. In this report, we will extend the pipeline to perform transfer learning with PyTorch Lightning.

⚡ Introduction

Transfer learning is a technique where the knowledge learned while training a model for "task" A can be reused for "task" B. Here, A and B can even be the same deep learning task applied to different datasets.



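The most common workflow for transfer learning in deep learning is:

1. Take a model pre-trained on a large dataset (here, ResNet-18 trained on ImageNet).
2. Freeze its layers so their weights are not updated during training.
3. Add a new, trainable classification layer on top of the frozen layers.
4. Train the new layer on your own dataset.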
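
A minimal sketch of this freeze-then-train-a-new-head workflow in plain PyTorch. To keep it self-contained and download-free, a small randomly initialized `nn.Sequential` (the hypothetical `backbone`) stands in for a pre-trained network such as ResNet-18; the layer sizes, learning rate, and dummy batch are illustrative assumptions. After one optimizer step, only the new head has changed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in for a pre-trained backbone (in practice, e.g. a torchvision ResNet-18).
backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),  # -> (batch, 8) feature vector
)

# Freeze the backbone: its weights receive no gradients and are never updated.
for param in backbone.parameters():
    param.requires_grad = False
backbone.eval()  # matters for BatchNorm/Dropout layers; the stand-in has none

# New, trainable classification head for the target task (102 classes, as in Caltech-101).
head = nn.Linear(8, 102)

# Only the trainable parameters are handed to the optimizer.
optimizer = torch.optim.Adam(head.parameters(), lr=2e-4)

frozen_before = [p.clone() for p in backbone.parameters()]
head_before = head.weight.clone()

# One training step on a dummy batch.
images = torch.rand(4, 3, 32, 32)
targets = torch.randint(0, 102, (4,))
loss = F.cross_entropy(head(backbone(images)), targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()

backbone_unchanged = all(torch.equal(a, b) for a, b in zip(frozen_before, backbone.parameters()))
head_changed = not torch.equal(head_before, head.weight)
print(backbone_unchanged, head_changed)  # True True
```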

This report requires some familiarity with PyTorch Lightning for the image classification task. You can check out my previous post on Image Classification using PyTorch Lightning to get started. Let us train a model with and without transfer learning on the Caltech-101 dataset and compare the results using Weights and Biases.

🔧 The Dataset

We will be using the Caltech-101 dataset to train our image classifier. It consists of pictures of objects belonging to 101 classes, plus one background clutter class. Each image is labeled with a single object, and each class contains roughly 40 to 800 images, for a total of around 9k images. Images are of variable sizes, with typical edge lengths of 200-300 pixels.
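
The DataModule below holds out 6500 images for training and 1000 for validation, leaving the rest for testing (1645 in the original split). `torch.utils.data.random_split` requires integer lengths that add up to the dataset size exactly, so a sketch that derives the test size from `len(dataset)` is less brittle if the archive holds a slightly different number of images. The `split_lengths` helper and the seed are hypothetical, and a dummy `TensorDataset` of 9145 items (the total implied by the original split) stands in for the image folder.

```python
from typing import List

import torch
from torch.utils.data import TensorDataset, random_split


def split_lengths(total: int, n_train: int = 6500, n_val: int = 1000) -> List[int]:
    """Hypothetical helper: return [train, val, test] lengths that always sum to `total`."""
    n_test = total - n_train - n_val
    if n_test <= 0:
        raise ValueError("train + val must leave at least one sample for testing")
    return [n_train, n_val, n_test]


# Stand-in for the ImageFolder: 9145 dummy samples, the total implied by 6500 + 1000 + 1645.
dataset = TensorDataset(torch.arange(9145))
lengths = split_lengths(len(dataset))
# A seeded generator makes the random split reproducible across runs.
train, val, test = random_split(dataset, lengths, generator=torch.Generator().manual_seed(42))
print(lengths)  # [6500, 1000, 1645]
```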

With PyTorch Lightning's DataModule, one can define the download logic, preprocessing steps, augmentation policies, etc., in one class. It organizes the data pipeline into one shareable and reusable class. Learn more about DataModule here.

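import os

import patoolib
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import download_url

class Caltech101DataModule(pl.LightningDataModule):
    def __init__(self, batch_size, data_dir: str = './'):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

        # Augmentation policy for training set
        self.augmentation = transforms.Compose([
              transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
              transforms.ToTensor(),  # Normalize expects a tensor, not a PIL image
              transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])
        ])
        # Preprocessing steps applied to validation and test set.
        # Images vary in size, so resize and crop to a fixed 256x256 so they can be batched.
        self.transform = transforms.Compose([
              transforms.Resize(256),
              transforms.CenterCrop(256),
              transforms.ToTensor(),
              transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])
        ])
        self.num_classes = 102

    def prepare_data(self):
        # source:
        url = ''
        # download
        download_url(url, self.data_dir)
        # extract
        patoolib.extract_archive(os.path.join(self.data_dir, "Caltech101ImageDataset.rar"), outdir=self.data_dir)

    def setup(self, stage=None):
        root = os.path.join(self.data_dir, 'Caltech101')
        # build two views of the same images, one per transform: subsets created by
        # random_split share a single dataset object, so setting .dataset.transform
        # per split would make the last assignment win for every split
        train_view = ImageFolder(root, transform=self.augmentation)
        eval_view = ImageFolder(root, transform=self.transform)
        # split dataset: shuffle the indices once (the seed is an arbitrary choice),
        # then take 6500 for training, 1000 for validation and the rest for testing
        indices = torch.randperm(len(train_view), generator=torch.Generator().manual_seed(42)).tolist()
        self.train = Subset(train_view, indices[:6500])
        self.val = Subset(eval_view, indices[6500:7500])
        self.test = Subset(eval_view, indices[7500:])

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size)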
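
A note on per-split transforms: the `Subset` objects returned by `random_split` all wrap the same underlying dataset, so assigning `subset.dataset.transform` on one split changes it for every split, and the training augmentations end up replaced by the evaluation transform. This minimal sketch uses a hypothetical `ToyDataset` (a stand-in for `ImageFolder`) to demonstrate the pitfall, followed by one common workaround: build one dataset instance per transform and index both with the same shuffled indices.

```python
import torch
from torch.utils.data import Dataset, Subset, random_split


class ToyDataset(Dataset):
    """Hypothetical stand-in for ImageFolder: returns (transform(x), label)."""

    def __init__(self, n: int, transform=None):
        self.n = n
        self.transform = transform

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        x = float(idx)
        return (self.transform(x) if self.transform else x), 0


# The pitfall: all three subsets share one dataset object.
full = ToyDataset(10)
train, val, test = random_split(full, [6, 2, 2])
train.dataset.transform = lambda x: "augmented"
val.dataset.transform = lambda x: "plain"  # silently overwrites the training transform
print(train.dataset is val.dataset, train[0][0])  # True plain

# A workaround: one dataset instance per transform, indexed by the same shuffled indices.
indices = torch.randperm(10, generator=torch.Generator().manual_seed(0)).tolist()
train_fixed = Subset(ToyDataset(10, transform=lambda x: "augmented"), indices[:6])
val_fixed = Subset(ToyDataset(10, transform=lambda x: "plain"), indices[6:8])
print(train_fixed[0][0], val_fixed[0][0])  # augmented plain
```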

🎺 LightningModule - Define the System

LightningModule is another class that organizes model definition, training, validation, and testing code in one place. Learn more about this here.

Let us look at the model definition to see how transfer learning can be used with PyTorch Lightning.

In the LitModel class, we can use a pre-trained model provided by Torchvision as a feature extractor for our classification model. Here we are using ResNet-18. A list of pre-trained models provided by Torchvision can be found here.

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class LitModel(pl.LightningModule):
    def __init__(self, input_shape, num_classes, learning_rate=2e-4):
        super().__init__()
        # log hyperparameters
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.dim = input_shape
        self.num_classes = num_classes
        # transfer learning if pretrained=True
        self.feature_extractor = models.resnet18(pretrained=True)
        # eval() puts BatchNorm/Dropout layers in inference mode (Lightning may switch
        # the model back to train mode when fitting starts); the loop below is what
        # actually freezes the weights
        self.feature_extractor.eval()
        # freeze params
        for param in self.feature_extractor.parameters():
            param.requires_grad = False
        n_sizes = self._get_conv_output(input_shape)

        self.classifier = nn.Linear(n_sizes, num_classes)

    # returns the size of the output tensor going into the Linear layer from the conv block.
    # For torchvision's resnet18 this is 1000, since its forward includes the ImageNet fc layer.
    def _get_conv_output(self, shape):
        batch_size = 1
        tmp_input = torch.rand(batch_size, *shape)

        # a dummy forward pass only to read the output size; no gradients needed
        with torch.no_grad():
            output_feat = self._forward_features(tmp_input)
        n_size = output_feat.view(batch_size, -1).size(1)
        return n_size

    # returns the feature tensor from the conv block
    def _forward_features(self, x):
        x = self.feature_extractor(x)
        return x

    # will be used during inference
    def forward(self, x):
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        x = F.log_softmax(self.classifier(x), dim=1)
        return x
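
Because `forward` already applies `log_softmax`, the matching loss for a training step is `F.nll_loss` on its output, which equals `F.cross_entropy` on the raw classifier logits, and the predicted class is simply the `argmax`. The training step itself is not shown in this excerpt, so this is only a minimal check of that relationship using random tensors (batch size and class count are arbitrary assumptions).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 102)  # raw classifier outputs for a batch of 8
targets = torch.randint(0, 102, (8,))

log_probs = F.log_softmax(logits, dim=1)  # what LitModel.forward returns
loss_nll = F.nll_loss(log_probs, targets)
loss_ce = F.cross_entropy(logits, targets)
print(torch.allclose(loss_nll, loss_ce))  # True

# log_softmax is monotonic within each row, so the predicted class is unchanged.
preds = log_probs.argmax(dim=1)
accuracy = (preds == targets).float().mean()
print(torch.equal(preds, logits.argmax(dim=1)))  # True
```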

🎨 Results

I have trained the defined model from scratch (ResNet-18 backbone trained from scratch) and with transfer learning (ResNet-18 pre-trained on ImageNet). Every other hyperparameter remains the same. Check out the colab notebook to reproduce the results.

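Note that to train from scratch, you will have to pass pretrained=False to models.resnet18 and comment out the line self.feature_extractor.eval(), as well as the loop that sets param.requires_grad = False. Otherwise the randomly initialized backbone stays frozen and only the final linear layer is trained.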
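
A minimal sketch of switching a backbone between the two setups: frozen (transfer learning) versus fully trainable (training from scratch). The helper `set_backbone_trainable` is hypothetical, and the small `nn.Sequential` stands in for ResNet-18. Counting the parameters with `requires_grad=True` is a quick way to confirm which setup a model is actually in.

```python
import torch.nn as nn


def set_backbone_trainable(backbone: nn.Module, trainable: bool) -> None:
    """Hypothetical helper: unfreeze (train from scratch) or freeze (transfer learning)."""
    for param in backbone.parameters():
        param.requires_grad = trainable
    # train(False) is the same as eval(): BatchNorm uses stored statistics, Dropout is off.
    backbone.train(trainable)


def count_trainable(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


# Stand-in backbone: Conv2d(3, 8, 3) has 8*3*3*3 + 8 = 224 params; BatchNorm2d(8) has 16.
backbone = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
head = nn.Linear(8, 102)  # 8*102 + 102 = 918 params

set_backbone_trainable(backbone, trainable=False)  # transfer learning
frozen_count = count_trainable(backbone) + count_trainable(head)
print(frozen_count, backbone.training)  # 918 False

set_backbone_trainable(backbone, trainable=True)  # training from scratch
scratch_count = count_trainable(backbone) + count_trainable(head)
print(scratch_count, backbone.training)  # 1158 True
```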

Let us compare the metrics to see the magic of transfer learning.

Let's look at the predictions of the model trained with transfer learning. Since the test accuracy is ~93%, we can expect most predictions to be spot on.

I have implemented a custom callback for PyTorch Lightning to log the predictions. Learn more about the callback here.

Finally, let's compare both models using the precision-recall curve. You can learn more about it here.

We can clearly see that the curve of the model trained with transfer learning lies closer to the top-right corner of the precision-recall plot.
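
Lying closer to the top-right corner means higher precision at each level of recall. Average precision (AP) summarizes this as the area under a step-wise precision-recall curve, AP = sum over thresholds n of (R_n - R_(n-1)) * P_n. Below is a minimal, dependency-free sketch for a single class with binary labels; the function names are hypothetical. For a multi-class classifier this is typically computed one-vs-rest per class and then averaged, though how the report's plot aggregates the classes is not stated.

```python
from typing import List, Sequence, Tuple


def precision_recall_points(scores: Sequence[float], labels: Sequence[int]) -> List[Tuple[float, float]]:
    """(recall, precision) at every distinct score threshold, highest threshold first."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    total_pos = sum(labels)
    points = []
    tp = fp = 0
    for rank, i in enumerate(order):
        if labels[i] == 1:
            tp += 1
        else:
            fp += 1
        # Emit a point only once every sample sharing this score has been counted.
        is_last = rank == len(order) - 1
        if is_last or scores[order[rank + 1]] != scores[i]:
            points.append((tp / total_pos, tp / (tp + fp)))
    return points


def average_precision(scores: Sequence[float], labels: Sequence[int]) -> float:
    """Sum of (R_n - R_{n-1}) * P_n: the area under the step-wise PR curve."""
    ap, prev_recall = 0.0, 0.0
    for recall, precision in precision_recall_points(scores, labels):
        ap += (recall - prev_recall) * precision
        prev_recall = recall
    return ap


good = average_precision([0.9, 0.8, 0.7, 0.6], [1, 1, 0, 0])  # positives ranked first
weak = average_precision([0.9, 0.8, 0.7, 0.6], [1, 0, 1, 0])  # a negative outranks a positive
print(good, round(weak, 4))  # 1.0 0.8333
```

A model whose curve sits further toward the top-right ranks its positives above its negatives more consistently, which shows up as a higher AP.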

📡 Conclusion and Resources

I hope you find this report helpful. I encourage you to play with the code and train an image classifier on a dataset of your choice, both from scratch and with transfer learning.

To learn more about transfer learning, check out these resources: