Let's Talk ConvMixer Architecture: What Is It, Why Should You Use It, and How Can You Use It?

In this post you will learn what the novel ConvMixer model is and how to implement it in TensorFlow (with a little about Vision Transformers for good measure).
Burhanuddin Rangwala

Introduction

For the longest time, convolution-based models (CNNs) have dominated the computer vision landscape. In classification, segmentation, object detection, and other sub-domains, architectures such as VGGNet, ResNet, U-Net, YOLO, and their variants have outperformed nearly everything else. Even in Kaggle competitions, most top solutions have been some variation or ensemble of these models.
But things have changed recently with the introduction of Vision Transformers (ViTs).

What are Vision Transformers?

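Vision Transformers were introduced in late 2020 by the Google Research Brain Team in the paper "An Image is Worth 16x16 Words" (published at ICLR 2021) [3]. They bring to vision the transformer architecture that had become the de facto choice for NLP tasks. The paper reports that, when pre-trained on large datasets, ViT matches or exceeds state-of-the-art CNNs while requiring substantially fewer computational resources to train.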
CNNs operate on pixel arrays, whereas ViT splits an image into fixed-size patches, projects each patch into an embedding (a token), adds positional embeddings, and feeds the resulting sequence to a transformer encoder. In effect, it represents an image as a sequence of words (patches).
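To make the patch-splitting step concrete, here is a minimal NumPy sketch. It covers only the split into flattened patches (no learned projection or positional embeddings), and the function name `image_to_patches` is hypothetical, not part of any library.

```python
import numpy as np

def image_to_patches(image: np.ndarray, patch_size: int) -> np.ndarray:
    """Split an (H, W, C) image into a sequence of flattened patches.

    Returns an array of shape (num_patches, patch_size * patch_size * C),
    where num_patches = (H // patch_size) * (W // patch_size).
    Assumes H and W are divisible by patch_size.
    """
    h, w, c = image.shape
    p = patch_size
    assert h % p == 0 and w % p == 0, "image size must be divisible by patch size"
    # (H, W, C) -> (H/p, p, W/p, p, C)
    x = image.reshape(h // p, p, w // p, p, c)
    # Bring the two grid axes together: (H/p, W/p, p, p, C)
    x = x.transpose(0, 2, 1, 3, 4)
    # Flatten each patch into one vector ("token"), row by row over the grid
    return x.reshape(-1, p * p * c)

# A 32x32 RGB image cut into 4x4 patches gives 8 * 8 = 64 tokens of length 4*4*3 = 48
image = np.random.rand(32, 32, 3)
patches = image_to_patches(image, patch_size=4)
print(patches.shape)  # (64, 48)

# The first token is exactly the top-left 4x4 patch, flattened
assert np.array_equal(patches[0], image[:4, :4, :].reshape(-1))
```

In ViT, each of these vectors would then be multiplied by a learned projection matrix and combined with a positional embedding.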
But why split the image into patches instead of treating each pixel as a token, the way a CNN works on pixels? Because the cost of self-attention grows quadratically with the number of tokens, so a sequence with one token per pixel quickly becomes too expensive.
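Some rough arithmetic shows the scale of the problem. The sketch below assumes a 224x224 input (a common ImageNet resolution) and counts the pairwise attention scores that one self-attention head computes in one layer, ignoring constant factors. The helper names are hypothetical.

```python
def num_tokens(image_size: int, patch_size: int) -> int:
    """Number of tokens when a square image is split into square patches."""
    return (image_size // patch_size) ** 2

def attention_pairs(n_tokens: int) -> int:
    """Self-attention scores every token against every token: n^2 scores per head."""
    return n_tokens ** 2

pixel_tokens = num_tokens(224, 1)   # every pixel is its own token
patch_tokens = num_tokens(224, 16)  # 16x16 patches, as in ViT-B/16

print(pixel_tokens, attention_pairs(pixel_tokens))  # 50176 2517630976
print(patch_tokens, attention_pairs(patch_tokens))  # 196 38416
```

Using 16x16 patches shrinks the sequence by a factor of 256 and the number of attention scores by a factor of 256^2 = 65,536.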
This raises the question of whether ViT's strength comes from the transformer architecture itself or from using patches instead of pixel arrays as input. The ConvMixer paper argues for the latter.

ConvMixer

The paper "Patches Are All You Need?" argues that the strength of ViT is at least partly due to its use of image patches, and introduces a novel model, ConvMixer, that is similar in spirit to both ViT and MLP-Mixer.
Like them, it operates directly on patches as input, separates the mixing of the spatial and channel dimensions, and maintains the same size and resolution throughout the network. Unlike them, it uses standard convolutions for the mixing steps, and batch normalization instead of layer normalization.
Despite using only standard convolutions, it outperforms ViT and MLP-Mixer at similar parameter counts, as well as classic CNNs such as ResNet, and its code fits in a tweet.
Image from the paper showing an extremely dense PyTorch implementation in under 280 characters
You can find their code in their GitHub repository [4].

ConvMixer Model Architecture

Figure: the ConvMixer model architecture
The first part of the ConvMixer model is a patch embedding layer, which is simply a convolution whose kernel size and stride both equal the patch size p, with c input channels and h output channels (h is the hidden dimension). It is followed by an activation function and a post-activation batch normalization.
$$z_0 = \mathrm{BN}\left(\sigma\{\mathrm{Conv}_{c \to h}(X,\ \text{stride}=p,\ \text{kernel\_size}=p)\}\right)$$
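Why is a convolution with kernel size and stride both equal to p a "patch embedding"? Because the windows never overlap, each output position sees exactly one patch, so the convolution is the same as flattening every patch and applying one shared linear projection. The NumPy sketch below checks this for arbitrary illustrative sizes (p=2, c=3, h=8 on a 28x28 input); it leaves out σ and BN, and `patch_embedding_conv` is a hypothetical helper.

```python
import numpy as np

def patch_embedding_conv(x, kernel, bias, p):
    """Naive convolution with kernel_size = stride = p and no padding.

    x: (H, W, c) image, kernel: (p, p, c, h), bias: (h,)
    Returns z of shape (H // p, W // p, h).
    """
    H, W, c = x.shape
    h = kernel.shape[-1]
    out = np.zeros((H // p, W // p, h))
    for i in range(H // p):
        for j in range(W // p):
            window = x[i * p:(i + 1) * p, j * p:(j + 1) * p, :]  # one non-overlapping patch
            out[i, j] = np.tensordot(window, kernel, axes=([0, 1, 2], [0, 1, 2])) + bias
    return out

rng = np.random.default_rng(0)
p, c, h = 2, 3, 8
x = rng.normal(size=(28, 28, c))
kernel = rng.normal(size=(p, p, c, h))
bias = rng.normal(size=h)

z = patch_embedding_conv(x, kernel, bias, p)
print(z.shape)  # (14, 14, 8)

# Same result: flatten each patch and multiply by the flattened kernel (a linear projection)
patches = x.reshape(28 // p, p, 28 // p, p, c).transpose(0, 2, 1, 3, 4).reshape(14, 14, p * p * c)
z_linear = patches @ kernel.reshape(p * p * c, h) + bias
assert np.allclose(z, z_linear)
```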
The second part of the model is the ConvMixer layer, which is repeated depth times. Each layer starts with a residual block containing a depthwise convolution. A residual block adds the block's input to its output through a skip connection: here, the layer input z_{l−1} is added to the output of the depthwise convolution after its activation and batch normalization. The result then goes through a pointwise convolution, followed by another activation and batch normalization.
$$z'_l = \mathrm{BN}\left(\sigma\{\mathrm{ConvDepthwise}(z_{l-1})\}\right) + z_{l-1}$$
$$z_l = \mathrm{BN}\left(\sigma\{\mathrm{ConvPointwise}(z'_l)\}\right)$$
The third part of the model is a global average pooling layer, which produces a feature vector of size h that can be passed to a softmax classifier or any other head, depending on the task.
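As a small illustration of this final stage, the NumPy sketch below averages an (N, H, W, h) feature map over its spatial axes and applies a linear layer with softmax. The Keras code in this post does the same with GlobalAveragePooling2D followed by a Dense layer with softmax activation. The shapes and the helper name `classifier_head` are illustrative assumptions.

```python
import numpy as np

def classifier_head(z, weight, bias):
    """z: (N, H, W, h) feature map -> (N, num_classes) class probabilities."""
    features = z.mean(axis=(1, 2))               # global average pooling -> (N, h)
    logits = features @ weight + bias            # linear layer -> (N, num_classes)
    logits -= logits.max(axis=1, keepdims=True)  # subtract the max for numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)  # softmax

rng = np.random.default_rng(0)
z = rng.normal(size=(2, 14, 14, 256))
probs = classifier_head(z, rng.normal(size=(256, 10)) * 0.01, np.zeros(10))
print(probs.shape)  # (2, 10)
```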

GELU

The activation function used throughout the model is GELU (Gaussian Error Linear Unit). Unlike ReLU, which gates inputs by their sign, GELU weights inputs by their value:
$$\mathrm{GELU}(x) = x\,\Phi(x)$$
where Φ(x) is the cumulative distribution function of the standard normal distribution.
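The difference from ReLU is easiest to see numerically. The sketch below computes the exact GELU using Φ(x) = 0.5(1 + erf(x/√2)), the widely used tanh approximation, and ReLU for a few inputs. Small negative inputs produce small negative outputs under GELU instead of being zeroed. As far as I know, Keras's 'gelu' activation, used in the code below, computes the exact form by default (approximate=False).

```python
import math

def gelu(x: float) -> float:
    # Exact GELU: x * Phi(x), where Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x: float) -> float:
    # Widely used tanh approximation of GELU
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def relu(x: float) -> float:
    # ReLU gates by sign: negative inputs become exactly zero
    return max(0.0, x)

xs = [-3.0, -1.0, -0.5, 0.0, 0.5, 1.0, 3.0]
for x in xs:
    print(f"x={x:+.1f}  relu={relu(x):.4f}  gelu={gelu(x):+.4f}  gelu_tanh={gelu_tanh(x):+.4f}")
```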

Depthwise Convolution

In a depthwise convolution, a single filter is applied to each input channel separately. In contrast, a normal convolution filter spans all input channels (it is as deep as the input), which lets it freely mix channels to produce each output. Depthwise convolution keeps the channels separate and, with one filter per channel as in ConvMixer, does not change the number of channels. It is mainly used to mix information across the spatial dimensions of the image.
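A short NumPy sketch can check both properties: each output channel depends only on its own input channel, and the layer needs far fewer weights than a standard convolution. It assumes one filter per channel and no padding, ignores biases in the parameter counts, and uses hypothetical helper names. The 5x5, 256-channel sizes match the defaults of the model built later in this post.

```python
import numpy as np

def depthwise_conv_valid(x, kernel):
    """x: (H, W, C), kernel: (k, k, C). One k x k filter per channel, no padding."""
    k = kernel.shape[0]
    H, W, C = x.shape
    out = np.zeros((H - k + 1, W - k + 1, C))
    for di in range(k):
        for dj in range(k):
            out += x[di:di + H - k + 1, dj:dj + W - k + 1, :] * kernel[di, dj, :]
    return out

def standard_conv_params(k, c_in, c_out):
    return k * k * c_in * c_out  # each filter spans all input channels (bias ignored)

def depthwise_conv_params(k, c):
    return k * k * c             # one k x k filter per channel (bias ignored)

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 8, 4))
kernel = rng.uniform(0.5, 1.5, size=(3, 3, 4))  # positive weights, so a shift always shows up
y = depthwise_conv_valid(x, kernel)

# Perturb only input channel 1: output channel 0 is unaffected, channel 1 changes
x2 = x.copy()
x2[:, :, 1] += 1.0
y2 = depthwise_conv_valid(x2, kernel)
print(np.allclose(y[..., 0], y2[..., 0]), np.allclose(y[..., 1], y2[..., 1]))  # True False

print(standard_conv_params(5, 256, 256), depthwise_conv_params(5, 256))  # 1638400 6400
```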

Pointwise Convolution

Pointwise convolution uses a 1x1 kernel that is applied at every single pixel (point) of the image. Like a normal convolution, this kernel has a depth equal to the number of input channels, so it combines all channels at a given location, and the number of filters sets the number of output channels, which means it can also be used to increase the channel count. In ConvMixer it is used to mix information across channels at each patch location, while the depthwise convolution mixes information across patches.
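A 1x1 convolution is just a matrix multiply applied to the channel vector at every pixel. The NumPy sketch below shows this, along with the fact that each output pixel depends only on the input pixel at the same location. Biases are set to zero, and the sizes are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
H, W, c_in, c_out = 4, 4, 3, 8
x = rng.normal(size=(H, W, c_in))
weight = rng.normal(size=(c_in, c_out))  # a (1, 1, c_in, c_out) kernel with the 1x1 axes squeezed
bias = np.zeros(c_out)

# A 1x1 convolution applies the same linear map to the channel vector at every pixel
y = np.einsum("hwc,cd->hwd", x, weight) + bias
print(y.shape)  # (4, 4, 8): channels go from 3 to 8, spatial size unchanged

# Equivalent to reshaping the pixels into rows and doing one matrix multiply
assert np.allclose(y, (x.reshape(-1, c_in) @ weight + bias).reshape(H, W, c_out))

# Each output pixel depends only on the input pixel at the same location
x2 = x.copy()
x2[0, 0] += 1.0
y2 = np.einsum("hwc,cd->hwd", x2, weight) + bias
changed = ~np.all(np.isclose(y, y2), axis=-1)
print(np.argwhere(changed))  # only position (0, 0) changes
```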

Code

To get and prepare the dataset:
```python
import tensorflow as tf
import tensorflow_datasets as tfds

# Load CIFAR-10 as (image, label) pairs
dataset, info = tfds.load('cifar10', with_info=True, as_supervised=True)
train_dataset, test_dataset = dataset['train'], dataset['test']

# Normalize pixel values to [0, 1] and resize the images to 28x28
def normalize_and_resize(image, label):
    image = tf.cast(image, tf.float32)
    image = tf.divide(image, 255)
    image = tf.image.resize(image, [28, 28])
    return image, label

# Randomly augment the training images
def augment(image, label):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    image = tf.image.random_brightness(image, max_delta=0.5)
    image = tf.image.random_contrast(image, lower=0.2, upper=1.8)
    # Brightness/contrast can push values outside [0, 1]; clip before the HSV-based ops
    image = tf.clip_by_value(image, 0.0, 1.0)
    image = tf.image.random_hue(image, max_delta=0.2)
    image = tf.image.random_saturation(image, lower=0.2, upper=1.8)
    return image, label

# Cache the normalized images, then augment so each epoch sees fresh augmentations.
# .repeat() makes the training set infinite, so pass steps_per_epoch to model.fit().
train_dataset = (train_dataset.map(normalize_and_resize).cache()
                 .map(augment).shuffle(1000).batch(64).repeat())
test_dataset = test_dataset.map(normalize_and_resize).cache().batch(64)
```
To create the model:
```python
from tensorflow import keras

# Applies the GELU activation followed by post-activation batch normalization
def activation_normalization_layer(x):
    """
    x: input tensor
    """
    x = keras.layers.Activation('gelu')(x)
    x = keras.layers.BatchNormalization()(x)
    return x

# Creates the patch embeddings
def patch_conv_layer(x, filters, patch_size):
    """
    x: input tensor
    filters: number of filters or hidden dimension
    patch_size: the patch size, which here determines both the kernel size and the stride
    """
    x = keras.layers.Conv2D(filters=filters, kernel_size=patch_size, strides=patch_size)(x)
    x = activation_normalization_layer(x)
    return x

# The main ConvMixer layer, which is repeated "depth" times
def conv_mixer_layer(x, filters, kernel_size):
    """
    x: input tensor
    filters: number of filters or hidden dimension
    kernel_size: the kernel size of the depthwise convolution
    """
    # Residual depthwise convolution (mixes spatial locations)
    initial = x
    x = keras.layers.DepthwiseConv2D(kernel_size=kernel_size, padding="same")(x)
    x = activation_normalization_layer(x)
    # Residual connection: add the input back, as in the paper's z'_l equation.
    # (Concatenate would instead double the channels fed to the pointwise conv.)
    x = keras.layers.Add()([x, initial])
    # Pointwise 1x1 convolution (mixes channels)
    x = keras.layers.Conv2D(filters=filters, kernel_size=1, padding="same")(x)
    x = activation_normalization_layer(x)
    return x

def conv_mixer_model(image_size=28, filters=256, depth=8, kernel_size=5, patch_size=2, num_classes=10):
    """
    image_size: the size of the image
    filters: number of filters or hidden dimension
    depth: the number of times the conv_mixer_layer is repeated
    kernel_size: the kernel size
    patch_size: the patch size
    num_classes: the number of classes in the output
    """
    inputs = keras.Input(shape=(image_size, image_size, 3))
    # Get the patch embeddings
    x = patch_conv_layer(inputs, filters, patch_size)
    # ConvMixer layer repeated `depth` times
    for _ in range(depth):
        x = conv_mixer_layer(x, filters, kernel_size)
    # Global average pooling and softmax classifier
    x = keras.layers.GlobalAveragePooling2D()(x)
    output = keras.layers.Dense(num_classes, activation="softmax")(x)
    model = keras.Model(inputs=inputs, outputs=output)
    return model
```

Results

The above ConvMixer model was trained on the CIFAR-10 dataset for 32 epochs with a batch size of 64. The optimizer used was AdamW as in the paper.
Below are the results of the ConvMixer model alongside those of a basic CNN model.
Even with ~4 times fewer parameters than the basic CNN model, the ConvMixer model performs just as well. It has higher validation accuracy and lower validation loss, but higher training loss and lower training accuracy. This smaller gap between training and validation metrics suggests that ConvMixer overfits less.
The results obtained here could likely be improved with the training techniques described in the original paper. Still, for a model with only ~1.1 million parameters, it performs well.
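As a sanity check on parameter counts, the sketch below tallies parameters layer by layer for the conv_mixer_model configuration above (filters=256, depth=8, kernel_size=5, patch_size=2, num_classes=10). It assumes Keras's default biases and counts BatchNormalization's four per-channel values (gamma, beta, moving mean, moving variance), since Keras includes the non-trainable moving statistics in its total parameter count. The function name and its `residual` argument are hypothetical. With the residual added, as in the paper, the total comes to about 0.6 million. If the residual is concatenated (keras.layers.Concatenate) instead, the pointwise convolution sees twice as many input channels and the total rises to about 1.13 million.

```python
def conv_mixer_param_count(image_channels=3, filters=256, depth=8, kernel_size=5,
                           patch_size=2, num_classes=10, residual="add"):
    """Count parameters of conv_mixer_model, assuming default biases and
    4 BatchNormalization values per channel (gamma, beta, moving mean/variance)."""
    h = filters
    bn = 4 * h
    # Patch embedding: Conv2D(c -> h, kernel = stride = patch_size) + BN
    total = patch_size * patch_size * image_channels * h + h + bn
    # Channels entering the pointwise conv: h if the residual is added, 2h if concatenated
    pw_in = h if residual == "add" else 2 * h
    per_layer = (kernel_size * kernel_size * h + h) + bn  # depthwise conv + BN
    per_layer += (pw_in * h + h) + bn                     # pointwise conv + BN
    total += depth * per_layer
    # Global average pooling has no parameters; Dense(h -> num_classes)
    total += h * num_classes + num_classes
    return total

print(conv_mixer_param_count(residual="add"))          # 602890
print(conv_mixer_param_count(residual="concatenate"))  # 1127178
```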

Conclusion

ConvMixers show strong accuracy and computational efficiency, giving better results here than a basic CNN architecture with ~4 times fewer parameters. The paper also reports that ConvMixer outperforms ViT and MLP-Mixer and is competitive with ResNets, DeiTs, and ResMLPs.
It will be interesting to see where this work leads and what new models emerge that use image patches in creative ways.

Main Points

  1. ConvMixers combine image patches with a convolution-based (rather than transformer-based) architecture to show that the strength of Vision Transformers is partly due to image patches.
  2. Their main ConvMixer layer is a residual depthwise convolution followed by a pointwise convolution, repeated depth times.
  3. They use the GELU activation function.
  4. They perform better with larger kernel sizes.
  5. In this experiment, ConvMixer achieves better validation accuracy than a basic CNN model while using ~4 times fewer parameters.

References

[1] https://keras.io/examples/vision/convmixer/
[2] https://openreview.net/pdf?id=TVHS5Y4dNvM
[3] https://arxiv.org/abs/2010.11929v2
[4] https://github.com/tmp-iclr/convmixer