
Leveraging Pre-Trained Models for Image Classification

In this article, we fine-tune a pre-trained model on a new classification dataset, to understand how well transfer learning helps the model train on new data.
Created on May 10|Last edited on May 11

In this article, we will explore how to fine-tune an image classifier on a new dataset. We will use pre-trained image models from PyTorch Image Models (timm) and fine-tune them using fastai's recent integration with timm.
The main idea is to understand how well transfer learning helps the model train on a new dataset.

First, though, let's level-set on a couple of things:
  • Pre-trained model: A pre-trained model is a model that has already been trained on a large dataset. For computer vision, this is frequently ImageNet, and bigger versions such as ImageNet-21k are now available. For PyTorch users, the default torchvision catalog of pre-trained models is quite limited, and users often want to try the latest backbones. This is where timm comes to the rescue. This little library, created and maintained by Ross Wightman, contains more than 500 pre-trained models built with the latest state-of-the-art techniques! He has improved the baseline for our trusty ResNet-50 and dramatically improved the training methods for Vision Transformers, saving compute and time for future researchers.
  • Fine-tuning a model (aka transfer learning): The main idea here is to train a model without starting from scratch. Instead of initializing the model with random weights, you start from a model pre-trained on another task (probably ImageNet). If your dataset is not very different from the one used to pre-train the model, you will benefit greatly from all the features the model has already learned. Even if your data is completely different (for instance, medical imaging like X-rays or MRIs), pre-training on ImageNet helps the model converge faster and reach better results (ref).
💡 Check Lesson 2 from the fastai course for more details on transfer learning and fine-tuning.

A Simple Example

Let's grab a copy of the Oxford Pets dataset. For one, if we are going to look at images for the rest of this report, they might as well be pets. For another, this is actually a challenging task: the dataset covers 37 different breeds, and the differences between them can be hard to spot:
The visual problem is very challenging as these animals, particularly cats, are very deformable, and there can be pretty subtle differences between the breeds.
Additionally, this dataset is similar to ImageNet, which already contains some of these cat and dog breeds, so fine-tuning should work well.


Fine-Tuning our Image Model

We can directly grab this dataset from fastai. The code:
from fastai.vision.all import *

def get_pets(batch_size, img_size, seed):
    "The Oxford-IIIT Pets dataset (cat and dog breeds)"
    dataset_path = untar_data(URLs.PETS)
    files = get_image_files(dataset_path/"images")
    # The class label is parsed from the start of each file name
    dls = ImageDataLoaders.from_name_re(dataset_path, files,
                                        pat=r'(^[a-zA-Z]+_*[a-zA-Z]+)',
                                        valid_pct=0.2,  # hold out 20% for validation
                                        seed=seed,
                                        bs=batch_size,
                                        item_tfms=Resize(img_size))
    return dls
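To see what the `pat` argument does, the sketch below applies the same regular expression to a few file names with Python's `re` module. It assumes fastai searches each file name with the pattern and uses the first capture group as the label. The file names are illustrative examples in the dataset's `breed_index.jpg` naming style, and `label_from_name` is a hypothetical helper, not a fastai function.

```python
import re

# Same pattern as the one passed to ImageDataLoaders.from_name_re.
PAT = re.compile(r'(^[a-zA-Z]+_*[a-zA-Z]+)')

def label_from_name(fname: str) -> str:
    """Hypothetical helper: return the first capture group of PAT for a file name."""
    match = PAT.search(fname)
    if match is None:
        raise ValueError(f"no label found in {fname!r}")
    return match.group(1)

# Illustrative file names (not read from disk).
names = ["Abyssinian_1.jpg", "great_pyrenees_173.jpg", "american_pit_bull_terrier_12.jpg"]
labels = [label_from_name(n) for n in names]
print(labels)
```

Note that the pattern captures at most two words, so a longer breed name such as `american_pit_bull_terrier` is shortened to `american_pit`. As long as the shortened prefixes stay distinct across breeds, this affects only the label text, not the number of classes.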
Next, let's compare the model's performance after training for 5 epochs, once starting from ImageNet pre-trained weights and once from random initialization. The difference in error rate is enormous 😱
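from types import SimpleNamespace

import wandb
from fastai.callback.wandb import WandbCallback

# Placeholder names for the W&B run; the report does not show the original values
PROJECT, GROUP, JOB_TYPE = "my-project", "my-group", "train"

config = SimpleNamespace(
    batch_size=64,
    img_size=224,
    seed=42,
    pretrained=False,  # set to True to start from ImageNet weights
    model_name="regnetx_040",
    epochs=5)

def train(config):
    "Train the model using the supplied config"
    dls = get_pets(config.batch_size, config.img_size, config.seed)
    with wandb.init(project=PROJECT, group=GROUP, job_type=JOB_TYPE, config=config):
        cbs = [MixedPrecision(), WandbCallback(log_preds=False)]
        learn = vision_learner(dls, config.model_name, metrics=error_rate,
                               cbs=cbs, pretrained=config.pretrained)
        learn.fine_tune(config.epochs)

train(config)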
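The comparison amounts to calling `train` once per initialization strategy. A minimal sketch of how the two configurations could be generated is below. `make_configs` is a hypothetical helper, not part of the original report, and the sketch only builds the configurations without launching any training.

```python
from types import SimpleNamespace

def make_configs(model_name="regnetx_040", epochs=5):
    """Hypothetical helper: one config per initialization strategy."""
    return [
        SimpleNamespace(batch_size=64, img_size=224, seed=42,
                        pretrained=pretrained, model_name=model_name, epochs=epochs)
        for pretrained in (True, False)
    ]

configs = make_configs()
for cfg in configs:
    # In the real experiment, each config would be passed to train(cfg).
    print(cfg.model_name, "pretrained" if cfg.pretrained else "from scratch")
```

Keeping the seed, batch size, and image size fixed across both runs means the only difference between them is the starting weights.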
The pre-trained model performs the best 🥳. Why is that? Well, one big reason: this dataset is very similar to ImageNet, so the pre-training of the model helps a lot 🚀. To reach this level of performance when training from scratch, we would need to train for much longer and use a bunch of tricks like augmentation, mix-up, etc.

[Interactive W&B panel: run sets "regnetx" (10) and "convnext" (7)]

💡 You can click on the different run sets to compare both models' performance. In particular, the ConvNeXt architecture is hard to train from scratch and needs a lot of data to perform well. However, it fine-tunes extremely well!

A Very Different Dataset (The Planet Competition Dataset)

Now, let's look at what happens when we choose a very different dataset. Instead of cats & dogs, the Planet Competition Dataset consists of satellite images from the Amazonian region. The task here consists of classifying which types of land cover are present in each image. We can have multiple land cover types present in one image.
Here's the relevant bit from the Kaggle competition:
In this competition, Planet and its Brazilian partner SCCON are challenging Kagglers to label satellite image chips with atmospheric conditions and various land cover/land use classes. The resulting algorithms will help the global community better understand where, how, and why deforestation happens worldwide - and ultimately how to respond.
Here the task is a multi-label classification problem, where each image can belong to multiple classes. We will load a sample of this dataset from fastai.
def get_planets(batch_size=64, img_size=224, seed=42):
    "A sample of the Planet dataset"
    dataset_path = untar_data(URLs.PLANET_SAMPLE)
    dls = ImageDataLoaders.from_csv(dataset_path,
                                    folder="train",
                                    csv_fname="labels.csv",
                                    label_delim=" ",  # tags are space-separated, so an image can have several labels
                                    suff=".jpg",
                                    bs=batch_size,
                                    seed=seed,
                                    item_tfms=Resize(img_size))
    return dls
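Each row of `labels.csv` lists an image's tags separated by spaces, which is why `label_delim=" "` is passed to the loader. A minimal sketch of how such tag strings turn into one 0/1 indicator per class is below. The three rows are made-up examples in the style of the file, and `multi_hot` is a hypothetical helper rather than part of fastai.

```python
def multi_hot(tag_strings, delim=" "):
    """Hypothetical helper: turn delimited tag strings into 0/1 indicator rows."""
    rows = [s.split(delim) for s in tag_strings]
    # The class vocabulary is the sorted set of all tags seen.
    classes = sorted({tag for row in rows for tag in row})
    encoded = [[1 if c in row else 0 for c in classes] for row in rows]
    return classes, encoded

# Made-up rows in the style of labels.csv (tags separated by spaces).
tags = ["clear primary", "haze primary water", "agriculture clear primary road"]
classes, encoded = multi_hot(tags)
print(classes)
for row in encoded:
    print(row)
```

Unlike the Pets labels, several entries in a row can be 1 at once, which is what makes this a multi-label problem.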
A sample of the dataset can be seen in the wandb.Table below. For each image, we have a 1 if the corresponding class is present and a 0 otherwise.



This is a multi-label classification dataset (each image can belong to multiple classes simultaneously), so a suitable metric is the F-beta score. As we can see, the pre-trained model fine-tunes slightly better, achieving a higher F-beta score (but also a higher validation loss). The images are so different from ImageNet that the impact of pre-training is not as noticeable as for the Pets dataset.

[Interactive W&B panel: run sets "regnetx" (40) and "convnext" (12)]
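For reference, the F-beta score used as the metric above combines precision P and recall R as (1 + β²)·P·R / (β²·P + R). The sketch below computes a micro-averaged version over multi-hot labels in plain Python. The choice of β = 2 and of micro averaging are assumptions made for illustration, since the report does not state which variant was logged, and `fbeta_micro` is a hypothetical helper.

```python
def fbeta_micro(y_true, y_pred, beta=2.0):
    """Hypothetical helper: micro-averaged F-beta over multi-hot rows."""
    tp = fp = fn = 0
    for true_row, pred_row in zip(y_true, y_pred):
        for t, p in zip(true_row, pred_row):
            if t and p:
                tp += 1      # label present and predicted
            elif p:
                fp += 1      # predicted but not present
            elif t:
                fn += 1      # present but missed
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    b2 = beta ** 2
    return (1 + b2) * precision * recall / (b2 * precision + recall)

# Toy example: one label is missed in the first image.
y_true = [[1, 0, 1], [0, 1, 1]]
y_pred = [[1, 0, 0], [0, 1, 1]]
score = fbeta_micro(y_true, y_pred)
print(round(score, 4))
```

In this toy example precision is 1.0 and recall is 0.75. With β = 2, recall is weighted more heavily than precision, so a missed label costs more than a spurious one would.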



Conclusions and future research

Using pre-trained models is a good practice in general. You should absolutely consider using them when possible. Even if your task is very different from the original dataset used to train the model, there is no harm in trying it 😎!
For image models, pre-training enables the model to detect patterns and shapes present in the image. The paper Visualizing and Understanding Convolutional Networks shows how the layers of the model learn to identify these patterns and shapes. Jeremy explains this very well in lesson 2 of fastai's "Deep Learning for Coders".
Transfer learning also applies to other domains. In NLP, it has been standard practice since the ULMFiT paper. All major deep learning frameworks support it and have corresponding repositories where you can get pre-trained models.
