In the previous report, we built an image classification pipeline using PyTorch Lightning. In this report, we will extend the pipeline to perform transfer learning with PyTorch Lightning.
Transfer learning is a technique where the knowledge learned while training a model for task A is reused for task B. Here A and B can be the same kind of deep learning task, just on different datasets.
The first few hidden layers of a deep neural network model learn general abstract features about the dataset. The later layers have task-specific knowledge. These learned features make for a good weight initialization.
Training neural networks from scratch can be expensive. Transfer learning has worked wonderfully so far, reducing the number of training hours while increasing the model's accuracy.
Often the target dataset does not have many samples to learn from, and even data augmentation can only push the accuracy so far. Transfer learning can come to the rescue.
The most common workflow to use transfer learning in the context of deep learning is:
Take layers from a previously trained model. Usually, these models are trained on a large dataset.
Freeze them to avoid destroying any of the information they contain during future training rounds.
Add some new, trainable layers on top of the frozen layers. These new layers will learn task-specific features. Train the new layers on your dataset.
An optional step is fine-tuning, which consists of unfreezing the entire model you obtained above and re-training it on the new data with a very low learning rate. The model can also be unfrozen gradually (unfreeze a few layers, train, unfreeze a few more, and so on).
This report requires some familiarity with PyTorch Lightning for the image classification task. You can check out my previous post on Image Classification using PyTorch Lightning to get started. Let us train a model with and without transfer learning on the Caltech-101 dataset and compare the results using Weights and Biases.
We will be using the Caltech-101 dataset to train our image classifier. It consists of pictures of objects belonging to 101 classes, plus one background clutter class. Each image is labeled with a single object. Each class contains roughly 40 to 800 images, totaling around 9k images. Images are of variable sizes, with typical edge lengths of 200-300 pixels.
With PyTorch Lightning's DataModule, one can define the download logic, preprocessing steps, augmentation policies, etc., in one class. It organizes the data pipeline into one shareable and reusable class. Learn more about DataModule here.
```python
import patoolib
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import download_url


class Caltech101DataModule(pl.LightningDataModule):
    def __init__(self, batch_size, data_dir: str = './'):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        # Augmentation policy for the training set
        self.augmentation = transforms.Compose([
            transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
            transforms.RandomRotation(degrees=15),
            transforms.RandomHorizontalFlip(),
            transforms.CenterCrop(size=224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        # Preprocessing steps applied to the validation and test sets
        self.transform = transforms.Compose([
            transforms.Resize(size=256),
            transforms.CenterCrop(size=224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        self.num_classes = 102

    def prepare_data(self):
        # source: https://figshare.com/articles/dataset/Caltech101_Image_Dataset/7007090
        url = 'https://s3-eu-west-1.amazonaws.com/pfigshare-u-files/12855005/Caltech101ImageDataset.rar'
        # download
        download_url(url, self.data_dir)
        # extract
        patoolib.extract_archive("Caltech101ImageDataset.rar", outdir=self.data_dir)

    def setup(self, stage=None):
        # Build two views of the dataset so the training split gets the
        # augmentation policy while val/test get plain preprocessing.
        # (The Subsets returned by random_split share one underlying
        # dataset, so assigning .transform per split would overwrite
        # the others.)
        train_view = ImageFolder('Caltech101', transform=self.augmentation)
        eval_view = ImageFolder('Caltech101', transform=self.transform)
        # Use the same seed so both splits line up on the same indices.
        self.train, _, _ = random_split(
            train_view, [6500, 1000, 1645],
            generator=torch.Generator().manual_seed(42))
        _, self.val, self.test = random_split(
            eval_view, [6500, 1000, 1645],
            generator=torch.Generator().manual_seed(42))

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size)
```
LightningModule is another class that organizes model definition, training, validation, and testing code in one place. Learn more about this here.
**Let us look at the model definition to see how transfer learning can be used with PyTorch Lightning.**
In the `LitModel` class, we use a pre-trained model provided by Torchvision as the feature extractor for our classification model. Here we are using ResNet-18. A list of pre-trained models provided by Torchvision can be found here.

With `pretrained=True`, we use the pre-trained ImageNet weights; otherwise, the weights are initialized randomly.

`.eval()` puts the feature extractor in evaluation mode (fixing its batch-norm statistics and disabling dropout), and setting `requires_grad=False` on its parameters freezes them so they are not updated during training.

A `Linear` layer is used as the output layer. We can have multiple trainable layers stacked over the frozen feature extractor.
```python
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class LitModel(pl.LightningModule):
    def __init__(self, input_shape, num_classes, learning_rate=2e-4):
        super().__init__()
        # log hyperparameters
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.dim = input_shape
        self.num_classes = num_classes

        # transfer learning if pretrained=True
        self.feature_extractor = models.resnet18(pretrained=True)
        # put the backbone in evaluation mode
        self.feature_extractor.eval()
        # freeze params
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

        n_sizes = self._get_conv_output(input_shape)
        self.classifier = nn.Linear(n_sizes, num_classes)

    # returns the size of the output tensor going into the Linear layer
    # from the conv block
    def _get_conv_output(self, shape):
        batch_size = 1
        tmp_input = torch.rand(batch_size, *shape)
        output_feat = self._forward_features(tmp_input)
        n_size = output_feat.view(batch_size, -1).size(1)
        return n_size

    # returns the feature tensor from the conv block
    def _forward_features(self, x):
        x = self.feature_extractor(x)
        return x

    # will be used during inference
    def forward(self, x):
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        x = F.log_softmax(self.classifier(x), dim=1)
        return x

    # ... (training/validation/test steps and optimizer configuration follow)
```
I have trained the defined model from scratch (ResNet-18 backbone trained from scratch) and with transfer learning (ResNet-18 pre-trained on ImageNet). Every other hyperparameter remains the same. Check out the Colab notebook to reproduce the results.
Note that to train from scratch, you will have to pass the `pretrained` argument as `False` and comment out the lines that freeze the feature extractor.
Let us compare the metrics to see the magic of transfer learning.
The model trained with transfer learning (`with-tl`) performs far better than the model trained from scratch. The test accuracy of `with-tl` is ~93%, while that of the from-scratch model is far lower.
Let's look at the predictions of the model trained with transfer learning. Since the test accuracy is ~93%, we expect the predictions to be mostly spot on.
I have implemented a custom callback for PyTorch Lightning to log the predictions. Learn more about the callback here.
Finally, let's compare both the models using the precision-recall curve. You can learn more about it here.
We can clearly see that the precision-recall curve of the model trained with transfer learning sits closer to the top-right corner (high precision and high recall), indicating a higher average precision.
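For reference, a precision-recall curve and its average-precision summary can be computed per class with scikit-learn. The scores below are hypothetical, purely to illustrate the API, not the report's actual data:

```python
import numpy as np
from sklearn.metrics import average_precision_score, precision_recall_curve

# Hypothetical binary labels and model scores for a single class.
y_true = np.array([0, 0, 1, 1, 1, 0, 1, 0])
y_score = np.array([0.1, 0.4, 0.35, 0.8, 0.7, 0.2, 0.9, 0.3])

# Precision/recall pairs over decreasing score thresholds.
precision, recall, thresholds = precision_recall_curve(y_true, y_score)
# Average precision summarizes the curve as a single number in [0, 1].
ap = average_precision_score(y_true, y_score)
```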
I hope you find this report helpful. I encourage you to play with the code and train an image classifier on a dataset of your choice, both from scratch and using transfer learning.
To learn more about transfer learning check out these resources: