Image Classification Using Vision Transformer and KerasCV
In this article, we'll learn how to use KerasCV to fine-tune a vision transformer (ViT) on our custom dataset. We also provide code so that you can follow along.
Created on February 9|Last edited on March 7
KerasCV is a new industry-strength library of computer vision workflows from the Keras team. Its goal is to bridge the gap between academia and industry by bringing the best of academic research into easy-to-use Keras APIs.
KerasCV now offers 10 variants of ViT that can be imported directly. There are 5 ViT sizes: Tiny, Small (S), Base (B), Large (L), and Huge (H). Each size is available in 2 patch sizes, 16 and 32. Pretrained ImageNet weights can be used if the input image shape is (224, 224, 3).
You can also build and experiment with custom ViT models using the API and are not limited to the 10 variants mentioned above. David contributed the ViT family of models in this GitHub pull request.
This report will teach you how to fine-tune a Vision Transformer (ViT) using KerasCV. We won't go into the anatomy of vision transformers and will focus solely on how to use them. And, if you'd like to follow in code, we've got you covered:
Table of Contents:
- Installation and Imports
- Dataset
- Dataloader and Data Augmentation
- Vision Transformer (ViT)
- Model Prediction Visualization (Optional)
- Training ViT
- GPU and CPU Metrics
- Comparing ViT Variants
- Effect of Patch Size
- Model Prediction
- Conclusion
Installation and Imports
Before we start, let's install and import the required libraries for this tutorial. We will install KerasCV to access ViTs and for data augmentations. We'll use Weights & Biases for experiment tracking and model prediction visualization.
pip install keras-cv
pip install wandb
Note that KerasCV requires TensorFlow v2.11.0 or above.
Dataset
We will be using the Stanford Dogs dataset for this tutorial. Luckily for us, the dataset is readily available as TensorFlow Datasets:
# Import TensorFlow Datasets
import tensorflow_datasets as tfds

# Download the dataset into train and test split
ds_train, ds_test = tfds.load('stanford_dogs', split=['train', 'test'])
The dataset contains images of 120 breeds of dogs from around the world. There are 20,580 images, of which 12,000 are used for training and 8,580 for testing.
We can use any other dataset as long as it has images and ground truth labels.
You can check out a subset of the dataset logged as W&B Tables below (certainly beats looking at MNIST in our opinion):
Dataloader and Data Augmentation
We will use the tf.data.Dataset API to build a dataloader for our classification pipeline. Since we're using TensorFlow Datasets to download the dataset, it's already exposed as a tf.data.Dataset. We will use the .map(...) method to parse the dataset and get it into the desired format:
def parse_data(example):
    "Apply preprocessing to one data sample at a time."
    # Get image
    image = example["image"]
    # Rescale pixels from [0, 255] to [0, 1]
    image = tf.image.convert_image_dtype(image, tf.float32)
    # We will resize the images to (224, 224)
    image = tf.image.resize(image, (configs.image_size, configs.image_size))

    # Get label and one hot encode it
    label = example["label"]
    label = tf.one_hot(label, configs.num_classes)

    return image, label
Since there are only 12K images for training, it's recommended to use data augmentation. We will be using two different ways to build our augmentation pipeline:
- We will use native Keras augmentation layers like RandomFlip, RandomRotation, etc. These layers will be stacked sequentially, as shown below:
base_augmentations = tf.keras.Sequential(
    [
        tf.keras.layers.RandomFlip("horizontal"),
        tf.keras.layers.RandomRotation(factor=0.02),
        tf.keras.layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="base_augmentation",
)
- KerasCV has advanced augmentation techniques like MixUp, CutMix, RandAugment, etc. In this tutorial, we will use the MixUp augmentation technique, as shown below:
# Import KerasCV preprocessing module
from keras_cv.layers import preprocessing

# Get MixUp augmentation
mixup = preprocessing.MixUp(alpha=0.8)
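To make the MixUp step more concrete, here is a minimal pure-Python sketch of the standard MixUp formulation: two samples are blended with a weight drawn from a Beta(alpha, alpha) distribution, and their one-hot labels are blended with the same weight. The function name `mixup_pair` and the toy data are hypothetical and only for illustration; this is not KerasCV's implementation, which operates on whole batches of tensors and may differ in detail.

```python
import random


def mixup_pair(image_a, label_a, image_b, label_b, alpha=0.8, rng=random):
    """Blend two samples with a weight drawn from Beta(alpha, alpha).

    Hypothetical helper for illustration only, not a KerasCV API.
    """
    lam = rng.betavariate(alpha, alpha)
    # Pixels and labels are mixed with the same weight.
    image = [lam * a + (1.0 - lam) * b for a, b in zip(image_a, image_b)]
    label = [lam * a + (1.0 - lam) * b for a, b in zip(label_a, label_b)]
    return image, label, lam


# Two toy "images" (flattened pixels in [0, 1]) with one-hot labels.
rng = random.Random(0)
image, label, lam = mixup_pair(
    [0.0, 0.2, 0.4], [1.0, 0.0, 0.0],
    [1.0, 0.8, 0.6], [0.0, 1.0, 0.0],
    alpha=0.8, rng=rng,
)
print("mixing weight:", lam)
print("mixed label:", label)
```

Because the mixed labels are soft (they sum to 1 but are no longer one-hot), a loss that accepts probability targets, such as categorical cross-entropy, is the natural fit.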
Let's tie all of them together using convenient tf.data APIs as shown below:
def get_dataloader(ds, type="train"):
    dataloader = (
        ds
        .map(parse_data, num_parallel_calls=AUTOTUNE)
        .batch(configs.batch_size)
    )

    if type == "train":
        dataloader = (
            dataloader
            .map(apply_base_augmentations, num_parallel_calls=AUTOTUNE)
            .map(lambda images, labels: mixup({"images": images, "labels": labels}), num_parallel_calls=AUTOTUNE)
            .map(lambda x: (x["images"], x["labels"]), num_parallel_calls=AUTOTUNE)
            .shuffle(1024)
        )

    dataloader = dataloader.prefetch(AUTOTUNE)

    return dataloader
Vision Transformer (ViT)
Try out the colab notebook here
As mentioned above, there are 10 variants of the ViT readily available in KerasCV. The API design is intuitive to use and aligns with the Keras design principles. If you have used tf.keras.applications, you will find KerasCV model APIs to be similar.
The code snippet below builds an image classifier with ViT Tiny (5.5M parameters) as the pre-trained backbone:
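# ViT variants are exposed under keras_cv.models in the KerasCV release used here
from keras_cv.models import ViTTiny16

def get_model():
    inputs = tf.keras.layers.Input(shape=(configs.image_size, configs.image_size, 3))

    # Pre-trained ViT Tiny backbone without the ImageNet classification head
    vit = ViTTiny16(
        include_rescaling=False,
        include_top=False,
        name="ViTTiny16",
        weights="imagenet",
        input_tensor=inputs,
        pooling="token_pooling",
        activation=tf.keras.activations.gelu,
    )
    # Fine-tune the whole backbone
    vit.trainable = True

    # New classification head for our dataset
    outputs = tf.keras.layers.Dense(configs.num_classes, activation="softmax")(vit.output)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)

    return model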
Here are a few caveats to keep in mind while using the ViT APIs:
- If the input images are in the range of [0, 255], use include_rescaling=True. It will divide the input image tensors by 255 to rescale the pixel values to [0, 1]. Since we are already rescaling our image pixels in the dataloader, we set this argument to False.
- Since we have a custom dataset for our classification task, we will not use the pre-trained head of the ViT classifier with 1000 output neurons. Thus include_top is False.
- The ImageNet pre-trained weights are only available when the input image shape is (224, 224, 3).
- When using pre-trained weights, use token_pooling as the pooling strategy. We can also use avg, which applies GlobalAveragePooling; however, the model fails to learn in this configuration.
- If you have a GPU memory constraint and want to use larger ViT models, you can try freezing the backbone or using mixed-precision training (requires a GPU with compute capability 7.0 or higher).
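The difference between the two pooling strategies mentioned above can be sketched in a few lines of plain Python. This sketch assumes that token pooling selects the class token at position 0 of the encoder output, as in the original ViT design, and that average pooling takes the mean over all tokens. The function names and toy values are hypothetical, and KerasCV's internals may differ.

```python
def token_pooling(tokens):
    """Return the embedding of the first (class) token."""
    return list(tokens[0])


def average_pooling(tokens):
    """Return the element-wise mean over all token embeddings."""
    n = len(tokens)
    return [sum(column) / n for column in zip(*tokens)]


# Toy encoder output: one class token followed by three patch tokens,
# each with an embedding dimension of 2.
tokens = [
    [1.0, 0.0],  # class token
    [0.0, 2.0],
    [2.0, 2.0],
    [1.0, 4.0],
]

cls_vec = token_pooling(tokens)
avg_vec = average_pooling(tokens)
print("token pooling:", cls_vec)
print("average pooling:", avg_vec)
```

The two strategies hand different vectors to the classification head, which may be why weights pre-trained with one strategy do not transfer cleanly to the other.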
Model Prediction Visualization (Optional)
The dataset can be visualized using W&B Tables, as we teased above. Tables can also be used to visualize model predictions, and they can be a powerful tool in your arsenal for gaining insight into your model's performance and debugging it quickly. Luckily, if you are using Keras, there's an easy-to-use utility class called WandbEvalCallback that you can subclass to build a model prediction visualization callback.
The code snippet below subclasses WandbEvalCallback to log model predictions for an image classification task:
import numpy as np
import tensorflow as tf
import wandb
from wandb.keras import WandbEvalCallback

class WandbClfEvalCallback(WandbEvalCallback):
    def __init__(self, validloader, data_table_columns, pred_table_columns, num_samples=100):
        super().__init__(data_table_columns, pred_table_columns)

        # Prepare the dataloader for visualization
        self.val_data = validloader.unbatch().take(num_samples)

    def add_ground_truth(self, logs=None):
        # Write the logic to add ground truth data to the `data_table`.
        for idx, (image, label) in enumerate(self.val_data):
            self.data_table.add_data(
                idx,
                wandb.Image(image),
                np.argmax(label, axis=-1)
            )

    def add_model_predictions(self, epoch, logs=None):
        # Write the logic to add model predictions to the `pred_table`.
        # Get predictions
        preds = self._inference()
        table_idxs = self.data_table_ref.get_index()

        for idx in table_idxs:
            pred = preds[idx]
            self.pred_table.add_data(
                epoch,
                self.data_table_ref.data[idx][0],
                self.data_table_ref.data[idx][1],
                self.data_table_ref.data[idx][2],
                pred
            )

    def _inference(self):
        # Predict one image at a time and keep the most likely class index
        preds = []
        for image, label in self.val_data:
            pred = self.model(tf.expand_dims(image, axis=0))
            argmax_pred = tf.argmax(pred, axis=-1).numpy()[0]
            preds.append(argmax_pred)

        return preds
Training ViT
Now that we have our classification model ready, let's compile it with the Adam optimizer and a CosineDecay learning rate scheduler. We will use CategoricalCrossentropy as the loss function since we are one-hot encoding the labels. We will monitor the Accuracy metric.
The initial learning rate should be low. 1e-4 is a good default to use while fine-tuning ViT. With a learning rate of 1e-3 (usually a good default), the model fails to train.
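To show what a cosine decay schedule does to the learning rate, here is a small pure-Python sketch of the formula that tf.keras.optimizers.schedules.CosineDecay documents. The initial learning rate of 1e-4 comes from the note above; the `decay_steps=1000` value is an assumption chosen only for illustration, and in practice it would be set from the number of epochs and batches per epoch.

```python
import math


def cosine_decay(step, initial_lr=1e-4, decay_steps=1000, alpha=0.0):
    """Learning rate at `step` under cosine decay.

    `alpha` is the minimum learning rate as a fraction of `initial_lr`.
    `decay_steps=1000` is an illustrative assumption, not a value
    taken from the experiments in this report.
    """
    step = min(step, decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))
    return initial_lr * ((1.0 - alpha) * cosine + alpha)


# Sample the schedule at a few points during training.
for step in (0, 250, 500, 750, 1000):
    print(step, cosine_decay(step))
```

The learning rate starts at its initial value, falls slowly at first, drops fastest around the midpoint, and flattens out again as it approaches the minimum.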
GPU and CPU Metrics
Let's look at how well the dataloader and the model utilized our hardware resources. We used a Tesla P100 GPU to train all the models, along with an 8-core CPU and 30 GB of RAM.
- The training job uses ~90% of the allocated GPU memory.
- There's a cyclic drop in GPU utilization and an out-of-phase cyclic rise in CPU utilization. This suggests that our dataloader is CPU bound, and that we can improve the augmentation pipeline to reduce GPU idle time.
Comparing ViT Variants
Now for the real fun: our comparative study. In the panel below, we compare the ViT Tiny variant against ViT Small. I have also trained ViT Base, but in that run the backbone is frozen and only the head is trained.
- Fine-tuning a larger model (ViT Small) improves the eval accuracy by ~11%.
- Note that even though we only train the head when ViT Base is used as a frozen backbone, the eval accuracy is comparable to that of ViT Small.
Effect of Patch Size
Theoretically, a larger patch size should trade lower accuracy for faster training. This is because a larger patch size splits the image into fewer patches, which reduces the information that can be used to learn relevant features. Let's see if this actually holds in the panels shown below, where ViT Small with a patch size of 16 is compared with ViT Small with a patch size of 32.
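A quick calculation makes the patch-size tradeoff concrete. The sketch below counts how many patches a 224 x 224 image is split into for each patch size, and how many raw pixel values each patch covers. The function name `num_patches` is hypothetical and used only for illustration.

```python
def num_patches(image_size, patch_size):
    """Number of non-overlapping square patches in a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    return per_side * per_side


image_size = 224

for patch_size in (16, 32):
    patches = num_patches(image_size, patch_size)
    # Each patch is flattened into patch_size * patch_size * 3 values (RGB).
    values_per_patch = patch_size * patch_size * 3
    print(f"patch size {patch_size}: {patches} patches, "
          f"{values_per_patch} values per patch")

patches_16 = num_patches(image_size, 16)
patches_32 = num_patches(image_size, 32)
```

Doubling the patch size cuts the number of patches (and so the length of the token sequence the transformer processes) by a factor of four, while each token has to summarize four times as many pixels.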
The experimental results clearly showcase the tradeoff:
- ViTS32 is ~12% lower in eval accuracy than ViTS16.
- However, it only took ~27 minutes to train the ViTS32 model, almost half the time it took to train the ViTS16 model. In this experiment, doubling the patch size (16 x 2 = 32) roughly halved the training time.
Model Prediction
If you checked out the optional section above on model prediction visualization, you might have seen how WandbEvalCallback can be subclassed to build custom Keras callbacks for model prediction visualization. Let's see how well our model is doing.
We'll visualize the model predictions of ViTS16 (which is the best model in our set of experiments) at the 20th epoch. You can find examples where the model is failing and so much more.
Conclusion
In this report, we saw how easy it is to fine-tune a vision transformer on a custom dataset using KerasCV. We also saw how WandbMetricsLogger can be used to automatically capture metrics, while WandbEvalCallback can be used to visualize model predictions. ViT introduced transformers to computer vision, and since then, many papers have pushed the boundary. KerasCV is a community-driven open source repository where you can contribute to make TensorFlow and Keras even more useful.