Transfer learning using PyTorch
Today, we'll fine-tune a VGG model with PyTorch on the Caltech-101 dataset, showcasing how leveraging a pre-trained model can boost performance on new tasks.
Introduction
Transfer learning is a technique where a pre-trained model is repurposed for a new task. By way of example, say you already know how to cook Italian food well and want to branch out into Chinese cooking. You don't need to relearn how to dice an onion; the skills you already have transfer to Chinese cuisine. You'll need a few new techniques, but you've already got a solid baseline to start with.
That's transfer learning in a nutshell.

With transfer learning, you take the pre-trained network (Italian cooking) and fine-tune it to learn Chinese dishes. You essentially leverage the knowledge it gained from its previous tasks to help it learn the new task. This saves a ton of time and computational resources compared to training a network from scratch.
PyTorch is the toolkit we'll use to implement this. Think of it as your kitchen where you have all the tools and ingredients to cook up your machine learning models. PyTorch is super popular because it's beginner-friendly, flexible, and has great community support. Plus, it's got tons of handy features to make building and training neural networks a breeze.
Fundamentals of transfer learning
Transfer learning is a technique where knowledge gained from solving one problem is applied to a different, but related, problem.
Having a pre-trained model is crucial to implement transfer learning. This model leverages learned representations of features from large datasets, and can be fine-tuned for a new task. This approach is particularly useful when the new task has limited labeled data or computational resources.
By adapting the learned features rather than starting from scratch, transfer learning accelerates model development and enhances performance, especially when tasks share similarities in data characteristics or patterns.
How transfer learning works
Here's a simplified explanation of how transfer learning typically operates:

1. Pre-trained model
We start with a pre-trained neural network model that has been trained on a large dataset for a specific task, such as image classification using millions of labeled images.
2. Feature extraction
Next, we extract the learned features from the pre-trained model, which represent high-level patterns in the data relevant to the original task. These features are typically captured in the parameters of the model's layers.
3. Transfer
Next, we repurpose the pre-trained model for the new task by removing the original output layer (which was specific to the source task) and replacing it with a new output layer tailored to the target task. The new layer's weights are randomly initialized.
4. Fine-tuning
After repurposing the pre-trained model for the new task, we proceed to train the adapted model on the target task's dataset. This involves fine-tuning the parameters of the pre-trained model and training any additional layers added for the target task simultaneously.
Fine-tuning adjusts the model's parameters, either across all layers or only in certain layers, to better adapt the learned features to the characteristics of the new task. During training, the model is exposed to the target task's dataset, and its parameters are adjusted through backpropagation to minimize the loss between predicted outputs and ground truth labels.
5. Evaluation
Once the model has been trained, it is evaluated on a separate validation or test dataset to assess its performance on the target task. This evaluation step helps determine how well the model generalizes to unseen data and provides insights into its effectiveness in solving the specific task at hand.
Performance metrics such as accuracy, precision, recall, or F1 score are typically used to quantify the model's performance and guide further refinement or deployment decisions.
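These metrics are straightforward to compute directly. A minimal sketch using scikit-learn (an assumed dependency, not used elsewhere in this post) with hypothetical labels:

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

# Hypothetical predictions and ground-truth labels for a 3-class task
y_true = [0, 1, 2, 2, 1, 0]
y_pred = [0, 2, 2, 2, 1, 0]

print("Accuracy :", accuracy_score(y_true, y_pred))
print("Precision:", precision_score(y_true, y_pred, average="macro"))
print("Recall   :", recall_score(y_true, y_pred, average="macro"))
print("F1 score :", f1_score(y_true, y_pred, average="macro"))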
By following these steps, transfer learning enables practitioners to leverage the knowledge encoded in pre-trained models to accelerate model development and improve performance on new tasks.
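To make steps 1 through 4 concrete before the full walkthrough below, here is a minimal generic sketch. It assumes torchvision's ResNet-18 and a hypothetical 10-class target task; the tutorial itself uses VGG-16 and Caltech-101:

import torch.nn as nn
import torch.optim as optim
from torchvision import models

# 1. Pre-trained model: weights learned on ImageNet
model = models.resnet18(pretrained=True)

# 2-3. Feature extraction and transfer: freeze the learned features,
# then swap in a new, randomly initialized output layer for the target task
for param in model.parameters():
    param.requires_grad = False
model.fc = nn.Linear(model.fc.in_features, 10)  # new head for 10 classes

# 4. Fine-tuning: optimize only the parameters that still require gradients
# (here, just the new head; unfreeze more layers to fine-tune deeper)
optimizer = optim.SGD(
    (p for p in model.parameters() if p.requires_grad), lr=0.001, momentum=0.9
)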
Types of transfer learning: Inductive, Transductive, and Unsupervised
Transfer learning can be categorized into various types based on the source and target domains, as well as the availability of labeled data. Here are three common types of transfer learning:
Inductive transfer learning
In inductive transfer learning, knowledge from a source domain with a labeled dataset is transferred to a similar but different target task, for which at least some labeled data is available.
The key characteristic of inductive transfer learning is that the source and target tasks differ, but they share an underlying structure or set of features that makes the transfer possible. The goal is to learn representations that generalize across tasks, allowing the model to perform well on the target task despite the differences.
Transductive Transfer Learning
Transductive transfer learning involves transferring knowledge from a labeled source domain to a target domain with unlabeled data, where the source and target tasks are the same.
Unlike inductive transfer learning, the task itself does not change; instead, the source and target domains differ in their data distributions. The goal is to leverage the labeled data from the source domain to improve performance on the target domain by exploiting the shared structure or relationships within the data.
Unsupervised Transfer Learning
Unsupervised transfer learning addresses the setting where labeled data is available in neither the source nor the target domain.
This type of transfer learning is particularly useful when labels are scarce or unavailable altogether. The model learns to extract useful features or representations from unlabeled source data (for example, through clustering or dimensionality reduction) that can then be applied to the target domain without the need for labeled data.
When to use transfer learning
Transfer learning is a powerful machine learning approach that can be considered in several scenarios:
Insufficient labeled data
When you have a target task but lack a large enough labeled dataset to train a model effectively, transfer learning allows you to leverage knowledge from a related task with ample data.
Saving time and resources
Training a model from scratch can be computationally expensive and time-consuming. By using pre-trained models and adapting them to your specific needs, you can significantly reduce both training time and computational costs.
Improving performance
In cases where your target dataset is small, a model trained from scratch may not perform well due to overfitting. Transfer learning can improve generalization by introducing knowledge from a related, well-understood task.
Cross-domain applications
When you want to apply insights from one domain to another, such as applying image recognition techniques in medical imaging, transfer learning can bridge the gap between different data distributions and feature spaces.
Benchmarking
When introducing a new problem or dataset, starting with transfer learning can provide a benchmark for performance. This helps in understanding the difficulty of the task and setting realistic expectations for model improvement.
Understanding PyTorch
PyTorch, developed by Meta AI, is a machine learning framework renowned for its flexibility, ease of use, and native support for GPU acceleration, making it an excellent choice for deep learning projects.
Compared to its main rival, TensorFlow, PyTorch's dynamic graphs are more flexible and intuitive than TensorFlow's traditional static graphs, making it a better fit for research. And while TensorFlow might offer better production scalability, PyTorch is often preferred for its ease of use and dynamic nature, making it ideal for experimental projects and prototyping. A few advantages of PyTorch include:
- User-friendly: PyTorch is intuitive for Python users, simplifying the deep learning model development process.
- Dynamic computation graphs: Allows for flexibility in model design by modifying graphs on the fly (see the sketch after this list).
- Strong GPU acceleration: Efficiently utilizes GPU hardware for faster computation.
- Rich ecosystem: Offers a wide range of pre-built models and training tools.
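The dynamic-graph point is easiest to see in code. The forward pass below is plain Python, so it can branch on runtime values, and the graph is rebuilt on every call. This is a toy sketch, not part of the tutorial:

import torch
import torch.nn as nn

class DynamicNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        # Ordinary Python control flow inside the model:
        # the computation graph is traced anew on each forward pass
        if x.mean() > 0:
            x = self.linear(x)
        return x.relu()

net = DynamicNet()
out = net(torch.randn(4, 8))  # graph built on the fly for this input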
Overview of Weights & Biases
Weights & Biases (W&B) is a machine learning platform designed to help track experiments, visualize data, and share insights within the ML community. It provides tools for logging hyperparameters, outputs, and results from ML models, making it easier to monitor and compare different experiments and models. W&B is designed to integrate seamlessly into existing ML workflows and supports a wide range of ML frameworks, including PyTorch, TensorFlow, and Keras.
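At its core, the workflow comes down to a handful of calls, which we'll use later in this post. A minimal sketch (the project name here is hypothetical):

import wandb

# Start a run and record its hyperparameters
run = wandb.init(project="my-project", config={"learning_rate": 0.001})

# Log metrics as training progresses
for epoch in range(3):
    wandb.log({"epoch": epoch, "loss": 1.0 / (epoch + 1)})

run.finish()  # mark the run as complete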
Getting started with PyTorch and transfer learning
1. Import the libraries
!pip install wandb

import pandas as pd
import os
import zipfile

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, models

import wandb

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
2. Download and extract the dataset
The Caltech-101 dataset is a widely used benchmark in computer vision research. It consists of 101 object categories, each containing around 40 to 800 images.
The images vary significantly in object viewpoint, pose, scale, lighting conditions, background clutter, and intra-class variation. The dataset was originally compiled by researchers at the California Institute of Technology (Caltech) and is commonly used for tasks such as object recognition, classification, and detection. We access it through Kaggle as a zip file, which we then extract.
from google.colab import drive
drive.mount('/content/drive')

!pip install kaggle
os.environ['KAGGLE_CONFIG_DIR'] = '/content/drive/MyDrive/kaggle'
!kaggle datasets download -d imbikramsaha/caltech-101

file_path = '/content/caltech-101.zip'
with zipfile.ZipFile(file_path, 'r') as zip_ref:
    zip_ref.extractall('/content/kaggle/')
3. Preprocess the dataset
Since the Caltech-101 images are roughly 300 x 200 pixels, we resize them to 224 x 224, the input size the VGG model expects.
The data is normalized, split into train (80%) and test (20%) sets, and loaded into PyTorch's DataLoader, which provides efficient, flexible batching for input to the model.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

dataset = datasets.ImageFolder(root='/content/kaggle/caltech-101', transform=transform)

total_size = len(dataset)
train_size = int(0.8 * total_size)
test_size = total_size - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
4. Set up Weights & Biases
wandb.init(project="transfer_learn_pytorch", name="caltech101_classification_2")
5. Load and modify the pre-trained VGG model
We load a pre-trained VGG model by setting pretrained=True and freeze its feature layers so they are not updated during training. We then replace the final classifier layer with a new one that has 102 output neurons, one per class in this version of the dataset (Caltech's 101 object categories plus a background category).
vgg = models.vgg16(pretrained=True)

# Freeze the convolutional feature layers
for param in vgg.features.parameters():
    param.requires_grad = False

# Change the last layer to output the 102 Caltech classes
num_features = vgg.classifier[6].in_features
vgg.classifier[6] = nn.Linear(num_features, 102)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
vgg.to(device)
6. Transfer learning
To track the fine-tuning run, we set W&B's config object, which holds and saves the run's hyperparameters and inputs.
config = wandb.config
config.learning_rate = 0.001
config.momentum = 0.9
config.epochs = 5
config.batch_size = 32
Finally, we set the loss and optimizer, define the training function, and log the performance metrics to W&B, which will help us evaluate the model's performance for the given hyperparameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(vgg.classifier.parameters(), lr=0.001, momentum=0.9)

def train_model(model, criterion, optimizer, num_epochs):
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        epoch_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch+1}, Loss: {epoch_loss}")
        wandb.log({"epoch": epoch + 1, "loss": epoch_loss})

train_model(vgg, criterion, optimizer, num_epochs=5)
7. Evaluating performance metrics
To evaluate the model, we define a function that disables gradient calculation for inference, sets the model to evaluation mode, and reports the average loss over the dataset.
def evaluate_model(model, criterion, loader):
    model.eval()
    total_loss = 0.0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.item() * inputs.size(0)
            total += inputs.size(0)
    average_loss = total_loss / total
    return average_loss
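Loss alone can be hard to interpret, so as an optional extension (not part of the original walkthrough), the same loop structure can also report top-1 accuracy:

def evaluate_accuracy(model, loader):
    # Fraction of images whose highest-scoring prediction matches the label
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            preds = model(inputs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total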
a. Pre-trained Model
The code below evaluates the pre-trained model by first loading it and then calling the evaluate_model function. Note that this model keeps its original 1,000-class ImageNet output layer, so its predictions don't line up with the Caltech labels, which explains the much higher loss reported below.
pretrained_vgg = models.vgg16(pretrained=True).to(device)
pretrained_test_loss = evaluate_model(pretrained_vgg, criterion, test_loader)
wandb.log({"pretrained_test_loss": pretrained_test_loss})
b. Fine-tuned Model
The code below evaluates the fine-tuned model by calling the evaluate_model function.
test_loss = evaluate_model(vgg, criterion, test_loader)
wandb.log({"fine_tuned_test_loss": test_loss})
c. Results
The graphs below show the results logged in W&B. The first graph shows the training loss decreasing as the number of epochs increases, which suggests the model could be fine-tuned for more epochs to improve accuracy further. The next two graphs show the fine-tuned and pre-trained model losses: the fine-tuned model clearly outperforms, with a test loss of 0.277 compared to 16.07 for the pre-trained model.

8. Evaluating predicted images
We also evaluate the model's performance by visualizing its predictions, comparing the actual class against the predicted class. This helps us better understand how the model's learning is distributed across the classes. For simplicity, we log a single batch of 32 images.
def log_predictions(model, loader, prefix=""):
    model.eval()
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            wandb.log({
                f"{prefix}examples": [
                    wandb.Image(x.cpu(), caption=f"Pred:{pred.item()}, Label:{y.item()}")
                    for x, pred, y in zip(inputs, preds, labels)
                ]
            })
            break  # Log only the first batch
a. Pre-trained Model
We log predictions from the pre-trained model.
log_predictions(pretrained_vgg, test_loader, prefix="pretrained_")
b. Fine-Tuned Model
We log predictions from the fine-tuned model.
log_predictions(vgg, test_loader, prefix="fine_tuned_")
c. Results
The images below show the classification results on the Caltech dataset for the pre-trained and fine-tuned models. 'Pred' is the class index predicted by the model, while 'Label' is the actual class index; when the two match, the model classified the image correctly.
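Note that the captions show integer class indices. If you'd rather see human-readable names, the ImageFolder dataset created earlier exposes the index-to-name mapping; a small optional tweak to the caption line:

class_names = dataset.classes  # folder names, ordered by class index

# Inside log_predictions, build captions from names instead of indices:
# caption = f"Pred:{class_names[pred.item()]}, Label:{class_names[y.item()]}"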
Pre-Trained Model Examples

Fine-Tuned Model Examples

As the examples demonstrate, the fine-tuned model correctly classifies most of the images that the pre-trained model gets wrong.
Conclusion
Transfer learning using PyTorch has proven to be a game-changer in the field of machine learning and deep learning. By leveraging pre-trained models and adapting them to new tasks, developers and researchers can save valuable time and resources, overcome data scarcity issues, and achieve state-of-the-art performance across a variety of domains.
In this exploration of transfer learning, we used PyTorch to fine-tune a VGG model, pre-trained on ImageNet, on the Caltech-101 dataset. The results logged in W&B made it easy to conclude that fine-tuning outperforms an off-the-shelf pre-trained model on task-specific challenges, while leveraging the pre-trained features still saves valuable time and resources compared to training from scratch.