
Image Classification Using Vision Transformer and KerasCV

In this article, we'll learn how to use KerasCV to fine-tune a vision transformer (ViT) on a custom dataset. We also provide code so that you can follow along.
Created on February 9 | Last edited on March 7
KerasCV is a new industry-strength library for computer vision workflows from the Keras team. The vision of this library is to bridge the gap between academia and industry by bringing the best of academic research into easy-to-use Keras APIs.
KerasCV now offers 10 variants of ViT that you can import directly. There are five ViT sizes: Tiny, Small (S), Base (B), Large (L), and Huge (H), each available for two patch sizes, 16 and 32. Pre-trained ImageNet weights can be used when the input image size is (224, 224, 3).
You can also build and experiment with custom ViT models using the API and are not limited to the 10 variants mentioned above. David contributed the ViT family of models in this GitHub pull request.
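For instance, the three variants used later in this report can be imported directly from the models module (a sketch; the class names follow the <size><patch size> pattern described above):
# Import three of the ten available ViT variants
from keras_cv.models import ViTTiny16, ViTS16, ViTS32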
This report will teach you how to fine-tune a Vision Transformer (ViT) using KerasCV. We won't go into the anatomy of vision transformers and will instead focus solely on how to use them. And if you'd like to follow along in code, we've got you covered:


Table of Contents:

Installation and Imports
Dataset
Dataloader and Data Augmentation
Vision Transformer (ViT)
Model Prediction Visualization (Optional)
Training ViT
GPU and CPU Metrics
Comparing ViT Variants
Effect of Patch Size
Model Prediction
Conclusion

Installation and Imports

Before we start, let's install and import the libraries required for this tutorial. We'll install KerasCV to access the ViT models and data augmentations, and Weights & Biases for experiment tracking and model prediction visualization.
pip install keras-cv
pip install wandb
Note that KerasCV requires TensorFlow v2.11.0 or above.
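With both packages installed, the imports used throughout the rest of this tutorial look roughly like this (a sketch):
# Core imports for the rest of the tutorial
import numpy as np
import tensorflow as tf

import keras_cv
import wandb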

Dataset

We will be using the Stanford Dogs dataset for this tutorial. Luckily for us, the dataset is readily available through TensorFlow Datasets:
# Import TensorFlow Datasets
import tensorflow_datasets as tfds

# Download the dataset into train and test split
ds_train, ds_test = tfds.load('stanford_dogs', split=['train', 'test'])
The dataset contains images of 120 dog breeds from around the world. There are 20,580 images in total, of which 12,000 are used for training and 8,580 for testing.
We can use any other dataset as long as it has images and ground truth labels.
You can check out a subset of the dataset logged as W&B Tables below (certainly beats looking at MNIST in our opinion):
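In case you're curious, here is a minimal sketch of how such a table can be logged (the project and column names are illustrative assumptions):
# Log a small sample of the dataset as a W&B Table
import wandb

wandb.init(project="keras-cv-vit")  # hypothetical project name
table = wandb.Table(columns=["image", "label"])
for example in ds_train.take(100):
    table.add_data(wandb.Image(example["image"].numpy()), int(example["label"]))
wandb.log({"dataset_sample": table})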

[W&B panel: a sample of the Stanford Dogs dataset logged as a W&B Table]


Dataloader and Data Augmentation

We will use the tf.data.Dataset API to build the dataloader for our classification pipeline. Since we downloaded the dataset with TensorFlow Datasets, it's already exposed as a tf.data.Dataset. We will use the .map(...) method to parse the dataset and get it into the desired format:
def parse_data(example):
    "Apply preprocessing to one data sample at a time."
    # Get the image
    image = example["image"]
    # Rescale pixels from [0, 255] to [0, 1]
    image = tf.image.convert_image_dtype(image, tf.float32)
    # Resize the image to (224, 224)
    image = tf.image.resize(image, (configs.image_size, configs.image_size))

    # Get the label and one-hot encode it
    label = example["label"]
    label = tf.one_hot(label, configs.num_classes)

    return image, label
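Note that parse_data relies on a configs object holding the hyperparameters. One possible definition (a sketch; apart from the image size and class count, the values are assumptions):
# A simple namespace of hyperparameters used throughout the tutorial
from types import SimpleNamespace

configs = SimpleNamespace(
    image_size=224,   # ImageNet-pretrained ViTs expect (224, 224, 3) inputs
    num_classes=120,  # Stanford Dogs has 120 breeds
    batch_size=64,    # an assumption; tune to fit your GPU
)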
Since there are only 12K images for training, it's recommended to use data augmentation. We will use two different approaches to build our augmentation pipeline:
  • We will use native Keras augmentation layers like RandomFlip, RandomRotation, etc. These layers will be stacked sequentially, as shown below:
base_augmentations = tf.keras.Sequential(
    [
        tf.keras.layers.RandomFlip("horizontal"),
        tf.keras.layers.RandomRotation(factor=0.02),
        tf.keras.layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="base_augmentation",
)
  • KerasCV also provides advanced augmentation techniques like MixUp, CutMix, RandAugment, etc. In this tutorial, we will use the MixUp augmentation technique, as shown below:
# Import KerasCV preprocessing module
from keras_cv.layers import preprocessing

# Get MixUp augmentation
mixup = preprocessing.MixUp(alpha=0.8)
Let's tie all of them together using convenient tf.data APIs as shown below:
AUTOTUNE = tf.data.AUTOTUNE

def apply_base_augmentations(images, labels):
    # Helper assumed by the pipeline below: apply the sequential Keras
    # augmentation layers defined above (training=True keeps them stochastic)
    return base_augmentations(images, training=True), labels

def get_dataloader(ds, type="train"):
    dataloader = (
        ds
        .map(parse_data, num_parallel_calls=AUTOTUNE)
        .batch(configs.batch_size)
    )

    if type == "train":
        dataloader = (
            dataloader
            .map(apply_base_augmentations, num_parallel_calls=AUTOTUNE)
            .map(lambda images, labels: mixup({"images": images, "labels": labels}), num_parallel_calls=AUTOTUNE)
            .map(lambda x: (x["images"], x["labels"]), num_parallel_calls=AUTOTUNE)
            .shuffle(1024)
        )

    dataloader = dataloader.prefetch(AUTOTUNE)

    return dataloader
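With these pieces in place, building the train and test dataloaders is a one-liner each (a usage sketch):
# Build the dataloaders for training and evaluation
trainloader = get_dataloader(ds_train, type="train")
testloader = get_dataloader(ds_test, type="test")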

Vision Transformer (ViT)

Try out the Colab notebook here →

As mentioned above, there are 10 variants of the ViT readily available in KerasCV. The API design is intuitive to use and aligns with the Keras design principles. If you have used tf.keras.applications, you will find KerasCV model APIs to be similar.
The code snippet below builds an image classifier with ViT Tiny (5.5M parameters) as the pre-trained backbone:
# Import the ViT Tiny backbone from KerasCV
from keras_cv.models import ViTTiny16

def get_model():
    inputs = tf.keras.layers.Input(shape=(configs.image_size, configs.image_size, 3))

    # Pre-trained ViT Tiny backbone (patch size 16) without the classification head
    vit = ViTTiny16(
        include_rescaling=False,
        include_top=False,
        name="ViTTiny16",
        weights="imagenet",
        input_tensor=inputs,
        pooling="token_pooling",
        activation=tf.keras.activations.gelu,
    )
    vit.trainable = True

    # Attach a classification head for our 120 classes
    outputs = tf.keras.layers.Dense(configs.num_classes, activation="softmax")(vit.output)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)

    return model
Here are a few caveats to keep in mind while using the ViT APIs:
  • If the input images are in the range [0, 255], use include_rescaling=True. This divides the input image tensors by 255 to rescale the pixel values to [0, 1]. Since we already rescale the pixels in the dataloader, we set this argument to False.
  • Since we have a custom dataset for our classification task, we will not use the pre-trained head of the ViT classifier with 1000 output neurons. Thus include_top is False.
  • The imagenet pre-trained weights are only available when the input image size is (224, 224, 3).
  • When using pre-trained weights, use token_pooling as the pooling strategy. We can also use avg, which applies global average pooling; however, the model fails to learn in that configuration.
  • If you have a GPU memory constraint and want to use larger ViT models, you can try freezing the backbone or training with mixed precision (which requires a compute capability of 7.0 or more), as sketched after this list.
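For example, a minimal sketch of both options (assuming TensorFlow's mixed-precision API; Keras recommends keeping the final softmax layer in float32 for numerical stability when using mixed precision):
# Enable mixed precision globally *before* building the model
# (requires a GPU with compute capability of 7.0 or more)
tf.keras.mixed_precision.set_global_policy("mixed_float16")

# To freeze the backbone instead, set `vit.trainable = False` in get_model()
# before attaching the classification head
model = get_model()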

Model Prediction Visualization (Optional)

The dataset can be visualized using W&B Tables, as we teased above. Tables can also be used to visualize model predictions, and they can be a powerful tool in your arsenal for gaining insight into your model's performance and quickly debugging it. Luckily, if you are using Keras, there's an easy-to-use utility class called WandbEvalCallback that you can subclass to build a model prediction visualization callback.
The code snippet below subclasses WandbEvalCallback to log model predictions for our image classification task:
import numpy as np
import wandb
from wandb.keras import WandbEvalCallback

class WandbClfEvalCallback(WandbEvalCallback):
    def __init__(
        self, validloader, data_table_columns, pred_table_columns, num_samples=100
    ):
        super().__init__(data_table_columns, pred_table_columns)
        # Prepare the dataloader for visualization
        self.val_data = validloader.unbatch().take(num_samples)

    def add_ground_truth(self, logs=None):
        # Write the logic to add ground truth data to the `data_table`.
        for idx, (image, label) in enumerate(self.val_data):
            self.data_table.add_data(
                idx,
                wandb.Image(image),
                np.argmax(label, axis=-1),
            )

    def add_model_predictions(self, epoch, logs=None):
        # Write the logic to add model predictions to the `pred_table`.

        # Get predictions
        preds = self._inference()
        table_idxs = self.data_table_ref.get_index()
        for idx in table_idxs:
            pred = preds[idx]
            self.pred_table.add_data(
                epoch,
                self.data_table_ref.data[idx][0],
                self.data_table_ref.data[idx][1],
                self.data_table_ref.data[idx][2],
                pred,
            )

    def _inference(self):
        preds = []
        for image, label in self.val_data:
            pred = self.model(tf.expand_dims(image, axis=0))
            argmax_pred = tf.argmax(pred, axis=-1).numpy()[0]
            preds.append(argmax_pred)

        return preds
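For reference, a hypothetical instantiation might look like this (the column names are illustrative assumptions):
# Instantiate the callback with the validation dataloader
eval_callback = WandbClfEvalCallback(
    validloader=testloader,
    data_table_columns=["idx", "image", "ground_truth"],
    pred_table_columns=["epoch", "idx", "image", "ground_truth", "prediction"],
)
You can then pass it to model.fit(...) via the callbacks argument, alongside WandbMetricsLogger.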
Check out our documentation to learn more about how you can use WandbEvalCallback.

Training ViT

Try out the Colab notebook here →

Now that we have our classification model ready, let's compile it with the Adam optimizer and a CosineDecay learning rate schedule. We will use CategoricalCrossentropy as the loss function since we are one-hot encoding the labels, and we will monitor the Accuracy metric.
The initial learning rate should be low; 1e-4 is a good default when fine-tuning a ViT. A learning rate of 1e-3 (usually a good default otherwise) doesn't allow the model to train.
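Putting this together, a minimal sketch of the compile-and-fit step (the epoch count is an assumption for illustration; the train-split size comes from the dataset section above):
# Compile and train the model with W&B callbacks
from wandb.keras import WandbMetricsLogger

num_train_samples = 12_000  # Stanford Dogs train split
epochs = 20                 # an assumption for illustration
total_steps = (num_train_samples // configs.batch_size) * epochs

lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=1e-4, decay_steps=total_steps
)

model = get_model()
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule),
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=["accuracy"],
)

model.fit(
    trainloader,
    epochs=epochs,
    validation_data=testloader,
    callbacks=[WandbMetricsLogger()],
)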
Shown below are the metrics tracked using the WandbMetricsLogger callback.

[W&B panel: training and validation metrics logged by WandbMetricsLogger]


GPU and CPU Metrics

Let's look at how well the dataloader and the model utilized our hardware resources. We trained all the models on a Tesla P100 GPU with an 8-core CPU and 30 GB of RAM.
  • The training job uses ~90% of the allocated GPU memory.
  • There's a cyclic drop in GPU utilization and an out-of-phase cyclic rise in CPU utilization. Our dataloader is clearly CPU bound, and we can improve the augmentation pipeline to reduce GPU idle time.

[W&B panel: GPU and CPU utilization during training]


Comparing ViT Variants

Now for the real fun: our comparative study. In the panel below, we compare the ViT Tiny variant against ViT Small. I have also trained ViT Base, but with the backbone frozen, so only the head is trained in that instance.
  • Clearly, fine-tuning a larger model (ViT Small) improves the eval accuracy by ~11%.
  • Note that even though we only train the head when ViT Base is used as a frozen backbone, the eval accuracy is comparable to that of ViT Small.

[W&B panel: comparison of ViT Tiny, ViT Small, and frozen ViT Base runs]


Effect of Patch Size

Theoretically, a larger patch size should reduce accuracy in exchange for faster training: larger patches mean fewer, coarser tokens, which reduces the information available for learning relevant features. Let's see if this actually holds in the panels shown below, where ViT Small with a patch size of 16 (ViTS16) is compared against ViT Small with a patch size of 32 (ViTS32).
The experimental results clearly showcase the tradeoff:
  • ViTS32's eval accuracy is ~12% lower than ViTS16's.
  • However, it only took ~27 minutes to train the ViTS32 model, almost half the time it took to train ViTS16. Doubling the patch size (16 × 2 = 32) roughly halves the training time, as the quick calculation below suggests.
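The speedup is easy to rationalize: the number of patch tokens scales with the inverse square of the patch size, and self-attention cost grows with sequence length. A quick back-of-the-envelope check:
# Token counts for a 224 x 224 input at the two patch sizes compared above
for patch_size in (16, 32):
    num_tokens = (224 // patch_size) ** 2
    print(f"patch size {patch_size}: {num_tokens} patch tokens")
# patch size 16: 196 patch tokens
# patch size 32: 49 patch tokens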

[W&B panel: ViTS16 vs. ViTS32 accuracy and training time]


Model Prediction

If you checked out the optional section on model prediction visualization above, you saw how WandbEvalCallback can be subclassed to build a custom Keras callback for visualizing model predictions. Let's see how well our model is doing.
We'll visualize the predictions of ViTS16 (the best model in our set of experiments) at the 20th epoch. You can find examples where the model fails, and much more.

[W&B panel: model prediction table for ViTS16]


Conclusion

In this report, we saw how easy it is to fine-tune a Vision Transformer on a custom dataset using KerasCV. We also saw how WandbMetricsLogger can automatically capture metrics during training, while WandbEvalCallback can be used to visualize model predictions. ViT introduced transformers to computer vision, and many papers have pushed the boundary since. KerasCV is a community-driven open source repository to which you can contribute to extend the usefulness of TensorFlow and Keras.
Have questions? Post them below as a comment, or contact me at @ayushthakur0.