
Checking Out the New fastai/timm Integration

In this article, we'll explore how ML practitioners can leverage the full timm backbone catalog in their deep learning pipelines with the new fastai integration.
Created on May 2|Last edited on February 15
A long-time ask from the fastai community has been the integration of timm (PyTorch Image Models) into the fastai vision learner. Why? Because it lets users leverage the full catalog of timm backbones in their deep learning pipelines.
Currently, the fastai timm integration only supports creating models through the newly created vision_learner (replacing the old cnn_learner for classification and regression tasks). I hope to see this integration also come to segmentation tasks in the future.
Previously, the integration was possible by using a plain fastai Learner object and manually passing the DataLoaders, loss_func, and metrics. But with this integration, everything is handled by fastai.
In this article, we'll explore the timm/fastai integration in detail. Here's what we'll cover:

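Table of Contents

  • How Do We Use This New fastai Integration?
  • Trying Out Different timm Backbones
  • Training with timm backbones
  • Training a torchvision Backbone
  • Results
  • Higher Learning Rates
  • More Runs, Better Results
  • Experiments With ImageWoof
  • Conclusion
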
How Do We Use This New fastai Integration?

So, how do we utilize these amazing pre-trained timm backbones? Thankfully, it is very straightforward!
First, find the backbone you'd like to use. You can browse the timm documentation or query the catalog from Python. There are two handy ways to list the available models:
import timm
timm.list_models()
>>
['adv_inception_v3',
'bat_resnext26ts',
'beit_base_patch16_224',
'beit_base_patch16_224_in22k',
'beit_base_patch16_384',
'beit_large_patch16_224',
'beit_large_patch16_224_in22k',
'beit_large_patch16_384',
'beit_large_patch16_512',
'botnet26t_256',
'botnet50ts_256',
...]
Or filter by expressions:
timm.list_models('gluon_resnet*') # returns all models starting with 'gluon_resnet'
timm.list_models('*resnext*', 'resnet') # returns all models containing 'resnext' that are defined in timm's 'resnet' module
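The filter strings behave like shell-style wildcard patterns. As a rough illustration of the matching rules (not timm's actual implementation), the sketch below applies Python's standard-library fnmatch to a small hand-copied sample of model names. The filter_names helper and the sample list are hypothetical and only there for the example.

```python
from fnmatch import fnmatch

# Hand-picked sample of names for illustration only;
# the real catalog comes from timm.list_models().
SAMPLE_NAMES = [
    "adv_inception_v3",
    "beit_base_patch16_224",
    "botnet26t_256",
    "gluon_resnet18_v1b",
    "gluon_resnet50_v1c",
    "resnext50_32x4d",
    "seresnext26d_32x4d",
]

def filter_names(names, pattern):
    """Return the names matching a shell-style wildcard pattern, sorted."""
    return sorted(n for n in names if fnmatch(n, pattern))

# A trailing * matches names that start with the prefix.
starts_with_gluon = filter_names(SAMPLE_NAMES, "gluon_resnet*")
# A * on both sides matches names that contain the substring anywhere.
contains_resnext = filter_names(SAMPLE_NAMES, "*resnext*")

print(starts_with_gluon)
print(contains_resnext)
```

Note that a pattern without any wildcard only matches a name exactly, which is why the second example needs a * on both sides.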
Knowing this, we can grab our preferred model, create a fastai learner object, and train a model 😃 on Imagenette.
from fastai.vision.all import *

timm_arch = 'regnetx_004'

imagenette_path = untar_data(URLs.IMAGENETTE)
dls = ImageDataLoaders.from_folder(imagenette_path, train="train", valid="val", item_tfms=Resize(256))
learn = vision_learner(dls, timm_arch, metrics=accuracy, pretrained=True)
learn.fit_one_cycle(5)
Under the hood, when the arch is given as a string, vision_learner calls create_timm_model (the timm counterpart of create_vision_model), which stacks a simple head on top of the timm backbone. The default head looks like this:
Sequential(
  (0): AdaptiveConcatPool2d(
    (ap): AdaptiveAvgPool2d(output_size=1)
    (mp): AdaptiveMaxPool2d(output_size=1)
  )
  (1): Flatten(full=False)
  (2): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Dropout(p=0.25, inplace=False)
  (4): Linear(in_features=768, out_features=512, bias=False)
  (5): ReLU(inplace=True)
  (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (7): Dropout(p=0.5, inplace=False)
  (8): Linear(in_features=512, out_features=10, bias=False)
)
To use your own head instead, pass it to vision_learner through the custom_head argument:
custom_head = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),  # pool each feature map down to 1x1
    nn.Flatten(),
    nn.Linear(384, 10))  # the 384 value comes from the last layer of the backbone

learn = vision_learner(dls, timm_arch, metrics=accuracy, pretrained=True, custom_head=custom_head)
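The default head expects 768 input features, while the custom head above uses 384. The difference comes from the concat pooling layer, which concatenates a max pool and an average pool and so doubles the channel count. Here is a minimal pure-Python sketch of that idea on nested lists. The concat_pool helper is hypothetical (it is not fastai's AdaptiveConcatPool2d), and the max-then-average ordering is an assumption.

```python
def concat_pool(feature_map):
    """Pool a [channels][height][width] nested list down to one flat vector.

    Each channel contributes its global max and its global average,
    so the output has 2 * channels entries.
    """
    maxes, means = [], []
    for channel in feature_map:
        values = [v for row in channel for v in row]
        maxes.append(max(values))
        means.append(sum(values) / len(values))
    return maxes + means  # assumed ordering: max features first, then averages

# Toy feature map: 384 channels of size 2x2, mirroring the backbone's final width.
n_channels = 384
feature_map = [[[float(c), 0.0], [0.0, 0.0]] for c in range(n_channels)]

pooled = concat_pool(feature_map)
print(len(pooled))  # 768, twice the number of channels
```

With a plain average pool, as in the custom head, the vector keeps its 384 entries, which is why the Linear layer there takes 384 inputs.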

Trying Out Different timm Backbones

Okay! Let's play with different timm backbones and assess the model performance.
Recently, Kevin Bird found that the pretrained resnet18 from torchvision gave better results than the pretrained resnet18 from timm when trained for a very short time (5 epochs on Imagenette). Let's try to reproduce this using fastai and Weights & Biases.
We can refactor our training code a little bit and include the WandbCallback for the fastai learner:
import wandb
from fastai.callback.wandb import WandbCallback

PROJECT = "imagenette_timm"

# log into your wandb account
wandb.login()

def train(arch, epochs, group="timm", **learn_kwargs):
    "Performs training on `arch` and group runs on the workspace"
    with wandb.init(project=PROJECT, group=group):
        cbs = [MixedPrecision(), WandbCallback(log_preds=False)]
        learn = vision_learner(dls,
                               arch,
                               metrics=[accuracy],
                               cbs=cbs,
                               pretrained=False,
                               **learn_kwargs)
        learn.fit_one_cycle(epochs)
We can then call the train function with either a timm arch or a torchvision one.

Training with timm backbones

Kevin's experiments were done with the resnet18 model available in both timm and torchvision. To use the timm variant of resnet18, we have to pass the name of the arch as a string "resnet18". We will:
  • Train the model 5 times to get an average
  • Train for 5 epochs
  • Pass a group param, so we can easily identify the runs in the W&B workspace.
As below:
N_EXPERIMENTS = 5
for _ in range(N_EXPERIMENTS):
    train(arch="resnet18", epochs=5, group="timm")
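Repeating each configuration gives a mean and a spread instead of a single noisy number. Here is a minimal sketch of that aggregation, assuming the final accuracies of a group have been collected into a plain list. The summarize helper is hypothetical, and the values below are placeholders for illustration, not results from these experiments.

```python
from statistics import mean, stdev

def summarize(accuracies):
    """Return (mean, sample standard deviation) for a list of final accuracies."""
    if len(accuracies) < 2:
        raise ValueError("need at least two runs to estimate the spread")
    return mean(accuracies), stdev(accuracies)

# Placeholder values only, NOT measured results.
placeholder_accuracies = [0.5, 0.75, 1.0]

avg, spread = summarize(placeholder_accuracies)
print(f"mean={avg:.3f} std={spread:.3f}")
```

Comparing two groups by their mean and spread makes it easier to tell a real gap from run-to-run noise.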
timm initializes the parameters differently from torchvision, so we have to re-run this with the extra argument zero_init_last_bn=False to get a similar initialization and similar results. Kevin goes into detail about this in his blog post.
N_EXPERIMENTS = 5
for _ in range(N_EXPERIMENTS):
    train(arch="resnet18", epochs=5, group="timm", zero_init_last_bn=False)
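To see why this flag can matter early in training, recall that a residual block computes x + BN(F(x)). If the scale (gamma) of the last batch norm in the block starts at zero, the block is initially an identity mapping. If it starts at one, the residual branch contributes from the first step, which is the case zero_init_last_bn=False selects. The sketch below is a simplified pure-Python illustration of that effect. The batch_norm and residual_block helpers are hypothetical stand-ins, not timm's or torchvision's code.

```python
def batch_norm(values, gamma, beta=0.0, eps=1e-5):
    """Normalize a list of floats to zero mean and unit variance, then scale and shift."""
    mu = sum(values) / len(values)
    var = sum((v - mu) ** 2 for v in values) / len(values)
    return [gamma * (v - mu) / (var + eps) ** 0.5 + beta for v in values]

def residual_block(x, gamma):
    """Compute x + BN(F(x)), where F is a stand-in transformation (here: doubling)."""
    branch = batch_norm([2.0 * v for v in x], gamma=gamma)
    return [xi + bi for xi, bi in zip(x, branch)]

x = [1.0, 2.0, 3.0, 4.0]
zero_init = residual_block(x, gamma=0.0)  # last BN scale starts at zero
unit_init = residual_block(x, gamma=1.0)  # last BN scale starts at one

print(zero_init)  # same values as x: the block starts as an identity
print(unit_init)  # differs from x: the branch is active from the start
```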

Training a torchvision Backbone

When we call from fastai.vision.all import *, fastai imports all of torchvision.models under the hood, so they are already available in the namespace. Conveniently (or not), we can pass the resnet18 constructor to vision_learner to use the torchvision version of the arch.
Note that here we pass the resnet18 Python function itself, without quotes.
for _ in range(N_EXPERIMENTS):
    train(arch=resnet18, epochs=5, group="torchvision")

Results


[W&B panel: 5 epochs]


Higher Learning Rates

Also, the effect of initialization is less important when we train with higher learning rates. Choosing lr=0.005 closes the gap. If you follow fastai best practices, the learning rate suggested by learn.lr_find() is close to this value.
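For context, the learning rate given to fit_one_cycle is the peak of the schedule rather than a constant. Here is a rough pure-Python sketch of a one-cycle schedule with cosine annealing. It assumes the values div=25, div_final=1e5 and pct_start=0.25, which should match fastai's defaults but are best treated as assumptions. The one_cycle_lr helper is hypothetical and is not fastai's scheduler.

```python
import math

def cos_anneal(start, end, pos):
    """Cosine interpolation from start (pos=0) to end (pos=1)."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

def one_cycle_lr(pct, lr_max, div=25.0, div_final=1e5, pct_start=0.25):
    """Learning rate at training progress pct in [0, 1] (assumed defaults)."""
    if pct < pct_start:
        # Warm-up phase: rise from lr_max / div up to lr_max.
        return cos_anneal(lr_max / div, lr_max, pct / pct_start)
    # Annealing phase: decay from lr_max down to lr_max / div_final.
    return cos_anneal(lr_max, lr_max / div_final, (pct - pct_start) / (1 - pct_start))

lr_max = 0.005
schedule = [one_cycle_lr(p / 10, lr_max) for p in range(11)]
print([f"{lr:.6f}" for lr in schedule])
```

Under these assumptions, lr=0.005 means the run starts at 0.0002, peaks at 0.005 a quarter of the way through, and then decays towards zero.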


[W&B panel: 5 epochs, lr=0.005]


More Runs, Better Results

The effect of the initialization also becomes less important as we train for more epochs. As you can see, when training for 10 epochs there is almost no difference between the two:


[W&B panels: 10 epochs; 20 epochs]



Experiments With ImageWoof

Let's look at a new dataset, ImageWoof. It's a bit harder, as all 10 classes are dog breeds.

[W&B panels: lr=0.005; lr=0.001; lr=0.0001]


Conclusion

When running multiple experiments, keeping track of your parameters and results can become complicated. fastai plays nicely with Weights & Biases, and personal accounts are free to use!
Also, if you liked this post, try creating one with your own data. It's a powerful and elegant way of sharing results with coworkers and friends!