Is MLP-Mixer a CNN in Disguise?

In this blog post, we look at the MLP-Mixer architecture in detail and try to understand why some consider it not to be "conv-free."
Aman Arora

Introduction

Recently, a new kind of architecture, MLP-Mixer (MLP-Mixer: An all-MLP Architecture for Vision, Tolstikhin et al., 2021), was proposed. It claims competitive performance with SOTA models on ImageNet without using convolutions or attention. But is this really true? Are the token-mixing and channel-mixing layers in the MLP-Mixer architecture (Figure-1) actually "conv-free"?
The deep learning community is split on this idea.
On one side, Yann LeCun tweeted that the architecture is not exactly conv-free.
On the other side, Lucas Beyer wonderfully defended the idea in his discussion of the "Mixer is a CNN problem."
In this report, Dr. Habib Bukhari and I got together to look into the MLP-Mixer architecture in detail and to explain why part of the community thinks the architecture is not "conv-free."
But first, let's understand everything that goes on inside the MLP Mixer architecture.

MLP Mixer

The overall MLP Mixer architecture is really simple to understand and can be implemented in a few lines of code using popular frameworks such as PyTorch or JAX.
Figure-1: MLP Architecture
Figure-1 above shows the overall MLP Mixer architecture. Thanks to DrHB, the same architecture can also be represented in the simplified form shown in Figure-2 below.
Figure-2: MLP Architecture (simplified)
The first part, converting an input image into patch embeddings, is similar to the Vision Transformer (ViT) and has been explained in detail in a previous blog post.
The main idea is to split an image into non-overlapping 16x16 patches and get a vector representation for each patch. For a 224x224 input image, this gives (224/16) x (224/16) = 14 x 14 = 196 patches. Assuming an embedding size of 512 per patch, we get a patch embedding matrix of shape 196x512, where 196 is the number of patches and 512 is the number of channels per patch.
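To make the shape arithmetic concrete, here is a minimal NumPy sketch. It is an illustration only, not the paper's or timm's code. It cuts a random 224x224x3 image into 14 x 14 = 196 non-overlapping 16x16 patches of 16 x 16 x 3 = 768 values each, then projects every patch with one shared weight matrix. The random `W_embed` and `b_embed` are hypothetical stand-ins for learned parameters.

```python
import numpy as np


def patchify(image, patch_size):
    """Split an (H, W, C) image into non-overlapping patches, one flattened row per patch."""
    H, W, C = image.shape
    gh, gw = H // patch_size, W // patch_size               # number of patches along each axis
    patches = image.reshape(gh, patch_size, gw, patch_size, C)
    patches = patches.transpose(0, 2, 1, 3, 4)              # (gh, gw, p, p, C)
    return patches.reshape(gh * gw, patch_size * patch_size * C)


rng = np.random.default_rng(0)
image = rng.standard_normal((224, 224, 3))
patches = patchify(image, 16)                      # 14 x 14 = 196 patches of 16*16*3 = 768 values
W_embed = rng.standard_normal((768, 512)) * 0.02   # hypothetical shared per-patch projection
b_embed = np.zeros(512)
embeddings = patches @ W_embed + b_embed           # the 196x512 patch embedding matrix
print(patches.shape, embeddings.shape)             # (196, 768) (196, 512)
```

Patches are numbered row by row, so patch 1 is the second patch in the top row and patch 14 starts the second row.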
Next, these patch embeddings of shape 196x512 are passed through multiple mixer layers before being fed to the MLP Head for classification. The number of mixer layers varies between 8 and 32 depending on the MLP Mixer architecture. We look at the mixer layer in detail next.

Mixer Layer

In this section, we look at the mixer layer. A single mixer layer consists of a token-mixing MLP and a channel-mixing MLP.
Figure-3: A single Mixer Layer in the MLP Mixer architecture.
Figure-3 above is a detailed representation of the Mixer Layer from Figure-1. As can be seen, every Mixer Layer consists of token-mixing and channel-mixing MLPs.
The Mixer Layer accepts a Patch Embedding matrix of shape 196x512 (if input image size is 224x224).

Token-mixing MLP

The token-mixing MLP accepts a transposed version of the Patch Embedding matrix. As Figure-3 shows, its input therefore has shape 512x196.
👉: This means that each row holds a single channel for all 196 tokens. Remember that each patch is represented by a vector of length 512, i.e., 512 channels. Each column holds a single token with its 512 channels.
Next, we pass this 512x196 matrix through an MLP that is applied to each row independently, with the same weights shared across all rows. Each row holds one channel across all tokens, so the MLP mixes information amongst the tokens. That is why this layer is called the token-mixing MLP! The output is also of shape 512x196.
👉: As shown in Figure-1, the MLP consists of two fully-connected layers separated by a GELU nonlinearity.
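Here is a minimal NumPy sketch of token mixing. It assumes random weights and a hidden width of 256 (0.5 x 512). Transposing the 196x512 patch matrix gives one row per channel. One shared MLP is then applied to every row, so it mixes information across the 196 tokens while keeping channels separate. The `gelu` and `mlp` helpers are written for this illustration; `gelu` uses the common tanh approximation.

```python
import numpy as np


def gelu(x):
    # tanh approximation of the GELU non-linearity
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))


def mlp(x, W1, b1, W2, b2):
    """Two fully-connected layers with a GELU in between, applied along the last axis."""
    return gelu(x @ W1 + b1) @ W2 + b2


num_tokens, channels, hidden = 196, 512, 256
rng = np.random.default_rng(0)
X = rng.standard_normal((num_tokens, channels))        # patch embeddings, 196x512
W1 = rng.standard_normal((num_tokens, hidden)) * 0.02  # hypothetical token-mixing weights
b1 = np.zeros(hidden)
W2 = rng.standard_normal((hidden, num_tokens)) * 0.02
b2 = np.zeros(num_tokens)

Xt = X.T                            # 512x196: one row per channel
mixed = mlp(Xt, W1, b1, W2, b2)     # the same MLP is applied to every row (channel)
print(Xt.shape, mixed.shape)        # (512, 196) (512, 196)
```

Because the MLP acts row by row, output row `c` depends only on input row `c`. Tokens are mixed, channels are not.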

Channel-mixing MLP

We know that the output of the token-mixing MLP has shape 512x196. We transpose this matrix again before feeding it to the channel-mixing MLP. As shown in Figure-3, the input to the channel-mixing MLP therefore has shape 196x512.
👉: This means that each column holds a single channel for all 196 tokens, and each row holds a single token with its 512 channels.
Next, we pass this matrix to the channel-mixing MLP. The MLP is again applied row by row, so the mixing happens per token: for each token, the MLP mixes its 512 channels. Therefore, this layer is referred to as the channel-mixing MLP!
And the output of the channel-mixing layer has the same shape as the input 196x512.
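Channel mixing can be sketched the same way in NumPy. This sketch assumes random weights and a hidden width of 2048 (4 x 512); the helpers are again written for illustration. The check at the end shows that shuffling the tokens simply shuffles the output rows. This confirms that the channel-mixing MLP moves no information between tokens.

```python
import numpy as np


def gelu(x):
    # tanh approximation of the GELU non-linearity
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))


def mlp(x, W1, b1, W2, b2):
    """Two fully-connected layers with a GELU in between, applied along the last axis."""
    return gelu(x @ W1 + b1) @ W2 + b2


num_tokens, channels, hidden = 196, 512, 2048
rng = np.random.default_rng(0)
X = rng.standard_normal((num_tokens, channels))      # input to channel mixing, 196x512
W1 = rng.standard_normal((channels, hidden)) * 0.02  # hypothetical channel-mixing weights
b1 = np.zeros(hidden)
W2 = rng.standard_normal((hidden, channels)) * 0.02
b2 = np.zeros(channels)

out = mlp(X, W1, b1, W2, b2)   # each row (token) is processed on its own
print(out.shape)               # (196, 512)

# Permuting the tokens only permutes the output rows: no information flows between tokens.
perm = rng.permutation(num_tokens)
print(np.allclose(mlp(X[perm], W1, b1, W2, b2), out[perm]))  # True
```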

Various MLP Mixer Architectures

Figure-4: MLP Mixer architecture with multiple MLP Mixer layers
As you can see from figure-4 above, the MLP mixer architecture essentially consists of multiple MLP Mixer layers that do not change the dimensions of the input.
So it is possible to create multiple MLP Mixer architectures by varying the number of MLP Mixer layers.
So far, every example has used an input image of size 224x224, an embedding size of 512 per patch and a patch size of 16x16. Together, these give patch embeddings of shape 196x512.
But it's also possible to have a different patch size or a different embedding size.
Therefore, multiple MLP Mixer architectures can be created as shown in table-1 below.
Table-1: Specifications of the Mixer architectures. A brief notation “B/16” means the model of base scale with patches of resolution 16×16. The number of parameters is reported for an input resolution of 224 and does not include the weights of the classifier head.
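As a rough sanity check on how these knobs affect model size, the sketch below counts parameters for the layer structure of the timm code shown later in this post. That structure is a patch-embedding stem, token- and channel-mixing MLPs with hidden sizes 0.5x and 4x the embedding size (timm's default `mlp_ratio=(0.5, 4.0)`), two LayerNorms per block and a final LayerNorm. The classifier head is excluded. `mixer_param_count` and `linear_params` are hypothetical helpers written for this illustration.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of a fully-connected layer (a patch-sized conv kernel counts the same)."""
    return n_in * n_out + n_out


def mixer_param_count(img_size=224, patch_size=16, in_chans=3, dim=512, depth=8, mlp_ratio=(0.5, 4.0)):
    """Parameter count of an MLP-Mixer without its classifier head (hypothetical helper)."""
    seq_len = (img_size // patch_size) ** 2         # number of patches (tokens)
    tokens_dim = int(mlp_ratio[0] * dim)            # hidden size of the token-mixing MLP
    channels_dim = int(mlp_ratio[1] * dim)          # hidden size of the channel-mixing MLP

    stem = linear_params(in_chans * patch_size * patch_size, dim)  # per-patch projection
    token_mlp = linear_params(seq_len, tokens_dim) + linear_params(tokens_dim, seq_len)
    channel_mlp = linear_params(dim, channels_dim) + linear_params(channels_dim, dim)
    block_norms = 2 * (2 * dim)                     # two LayerNorms (weight + bias) per block
    block = token_mlp + channel_mlp + block_norms
    final_norm = 2 * dim
    return stem + depth * block + final_norm


print(mixer_param_count())               # 18015264  (224x224 input, 16x16 patches, dim 512, 8 layers)
print(mixer_param_count(patch_size=32))  # 18591624  (same model with 32x32 patches)
```

The token-mixing MLP's size depends on the number of patches. So, unlike a typical CNN, the parameter count itself depends on the input resolution and patch size. With 32x32 patches, the token-mixing MLPs shrink, but the stem grows by more, so the total goes up slightly.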
And that's really it! As I mentioned briefly before, the MLP Mixer architecture is conceptually very easy to understand and implement. And the fact that such a simple architecture can get competitive performance on ImageNet with SOTA architectures is mind-blowing!

Summary

We now know what the MLP Mixer architecture looks like. It consists of a stack of mixer layers, each made of a token-mixing MLP and a channel-mixing MLP. Transpose operations switch between the two views of the patch embedding matrix.
IMHO, the MLP Mixer can be broken down into three parts:
  1. (Stem) The first part converts an input image into a Patch Embedding matrix.
  2. (Body) Multiple Mixer layers perform the token and channel mixing to get information from the input images and process them.
  3. (Head) MLP Head uses the output from the mixer layers to perform the classification.
That's really all there is to the MLP mixer architecture. :)

MLP Mixer in PyTorch

Implementing the MLP Mixer architecture in PyTorch is really easy! Here, we reference the implementation from timm by Ross Wightman.
First, let's implement the MLP, which consists of two fully connected layers separated by a GELU non-linearity.
Figure-5: MLP layer
```python
import torch.nn as nn


class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)   # first fully-connected layer
        x = self.act(x)   # GELU non-linearity (by default)
        x = self.drop(x)
        x = self.fc2(x)   # second fully-connected layer
        x = self.drop(x)
        return x
```
As can be seen, the MLP implementation above follows Figure-5 exactly. It has two fully connected (nn.Linear) layers separated by an activation layer (GELU by default), with optional dropout.
Next, let's see how the mixer layer can be implemented assuming we already have the input image converted to a 196x512 patch embedding matrix.
Figure-6: Mixer Layer
As you can see, the first step is to normalize the input using nn.LayerNorm. Next, we transpose the input matrix, pass it through the MLP and transpose it back. This is the token-mixing operation. In the implementation below, self.mlp_tokens represents the token-mixing MLP. In the forward method, x = x + self.drop_path(self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2)) performs the token-mixing operation, including both transpositions and the skip connection.
Next, we again normalize using nn.LayerNorm and feed the outputs to the channel-mixing MLP, again with a skip connection, to get the final outputs of the mixer layer.
In the implementation below, self.mlp_channels represents the channel-mixing MLP. In the forward method, x = x + self.drop_path(self.mlp_channels(self.norm2(x))) performs the channel-mixing operation.
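```python
from functools import partial

import torch.nn as nn
from timm.models.layers import DropPath, to_2tuple  # timm helpers (also exposed as timm.layers in newer releases)

# Mlp is the class defined above


class MixerBlock(nn.Module):
    """ Residual Block w/ token mixing and channel MLPs
    Based on: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
    """
    def __init__(
            self, dim, seq_len, mlp_ratio=(0.5, 4.0), mlp_layer=Mlp,
            norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.):
        super().__init__()
        # hidden sizes of the token- and channel-mixing MLPs (0.5 * dim and 4 * dim by default)
        tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)]
        self.norm1 = norm_layer(dim)
        self.mlp_tokens = mlp_layer(seq_len, tokens_dim, act_layer=act_layer, drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        # x: (batch, seq_len, dim), e.g. (B, 196, 512)
        # token mixing: transpose to (B, dim, seq_len), mix along the tokens, transpose back
        x = x + self.drop_path(self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2))
        # channel mixing: applied to each token independently
        x = x + self.drop_path(self.mlp_channels(self.norm2(x)))
        return x
```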
Next, let's look at the overall MLP Mixer architecture, shown in Figure-7 below.
Figure-7: MLP Mixer architecture
First, we need to convert an input image into patch embeddings. We explained how to do this in code in a previous blog post. In timm, the PatchEmbed class takes care of this operation.
```python
import torch.nn as nn
from timm.models.layers import to_2tuple  # timm helper


class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])  # (14, 14) for 224/16
        self.num_patches = self.grid_size[0] * self.grid_size[1]                      # 196 for 224/16
        self.flatten = flatten

        # kernel_size == stride == patch_size: each output position sees exactly one patch
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)  # (B, embed_dim, H / patch, W / patch)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x
```
As can be seen, the PatchEmbed class above applies an nn.Conv2d whose kernel size and stride both equal the patch size. This splits the input image into non-overlapping patches and represents each patch by an embed_dim-long embedding vector.
Finally, the overall MLP Mixer architecture can be implemented as below:
```python
import math
from functools import partial

import torch.nn as nn
# timm helpers; their module paths may differ across timm versions
from timm.models.helpers import named_apply
from timm.models.mlp_mixer import _init_weights

# Mlp, MixerBlock and PatchEmbed are the classes defined above


class MlpMixer(nn.Module):

    def __init__(
            self,
            num_classes=1000,
            img_size=224,
            in_chans=3,
            patch_size=16,
            num_blocks=8,
            embed_dim=512,
            mlp_ratio=(0.5, 4.0),
            block_layer=MixerBlock,
            mlp_layer=Mlp,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            act_layer=nn.GELU,
            drop_rate=0.,
            drop_path_rate=0.,
            nlhb=False,
            stem_norm=False,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        # (Stem) image -> patch embeddings of shape (B, num_patches, embed_dim)
        self.stem = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans,
            embed_dim=embed_dim, norm_layer=norm_layer if stem_norm else None)
        # FIXME drop_path (stochastic depth scaling rule or all the same?)
        # (Body) a stack of identical mixer blocks
        self.blocks = nn.Sequential(*[
            block_layer(
                embed_dim, self.stem.num_patches, mlp_ratio, mlp_layer=mlp_layer, norm_layer=norm_layer,
                act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate)
            for _ in range(num_blocks)])
        self.norm = norm_layer(embed_dim)
        # (Head) linear classifier on the pooled features
        self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        self.init_weights(nlhb=nlhb)

    def init_weights(self, nlhb=False):
        head_bias = -math.log(self.num_classes) if nlhb else 0.
        named_apply(partial(_init_weights, head_bias=head_bias), module=self)  # depth-first

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.stem(x)
        x = self.blocks(x)
        x = self.norm(x)
        x = x.mean(dim=1)  # global average pooling over the patch (token) dimension
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
```
The main body (the mixer layers) lives in self.blocks, an nn.Sequential of MixerBlock instances. As can be seen, self.stem is an instance of the PatchEmbed class, which converts the input image into a Patch Embedding matrix. self.head is the linear classifier applied after the final LayerNorm and average pooling over the patches. The implementation above should be fairly intuitive for the reader to understand.

MLP Architecture using Convolutions

Now coming to the main question: "Is MLP-Mixer a CNN in disguise?" Let's try to understand why the deep learning community is split on this idea.
Note that the paper itself acknowledges this in its Introduction section:
In the extreme case, our architecture can be seen as a very special CNN, which uses 1×1 convolutions for channel mixing, and single-channel depth-wise convolutions of a full receptive field and parameter sharing for token mixing. However, the converse is not true as typical CNNs are not special cases of Mixer.
Let's look at this in parts, starting with the stem.

(Stem) Patch Embedding

The stem of the MLP Mixer architecture converts an input image into 16x16 patches, where each patch is represented by a vector of length 512. Thus, for a 224x224 input image, you get 14 x 14 = 196 patches, each of length 512, i.e., a 196x512 patch embedding matrix.
Conceptually, the paper describes this operation as a per-patch fully-connected layer (Figure-1). In practice, however, both the PyTorch implementation above and the official JAX implementation (in the appendix of the MLP Mixer paper) perform it with a convolution whose kernel size and stride equal the patch size. The two are equivalent. Because the stride equals the kernel size, each convolution window covers exactly one patch. The same kernel weights are also applied to every patch, just like a shared per-patch fully-connected layer. This brings us to the first part of Yann's tweet:
1st layer: "Per-patch fully-connected" == "conv layer with 16x16 kernels and 16x16 stride"
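The equivalence is easy to check numerically. The NumPy sketch below uses random weights; `conv2d_valid` is a naive helper written for this illustration, not a library function. It compares a 16x16 convolution with stride 16 against a per-patch fully-connected layer. The FC layer's weight matrix is simply the conv kernels, flattened.

```python
import numpy as np


def conv2d_valid(image, kernels, bias, stride):
    """Naive 2D convolution (cross-correlation, as in deep learning frameworks), no padding.
    image: (H, W, C_in); kernels: (k, k, C_in, C_out); returns (H_out, W_out, C_out)."""
    k = kernels.shape[0]
    H, W, _ = image.shape
    h_out, w_out = (H - k) // stride + 1, (W - k) // stride + 1
    out = np.empty((h_out, w_out, kernels.shape[-1]))
    for i in range(h_out):
        for j in range(w_out):
            window = image[i * stride:i * stride + k, j * stride:j * stride + k, :]  # (k, k, C_in)
            out[i, j] = np.tensordot(window, kernels, axes=3) + bias               # (C_out,)
    return out


p, dim = 16, 512
g = 224 // p                                        # 14 patches per side
rng = np.random.default_rng(0)
image = rng.standard_normal((224, 224, 3))
kernels = rng.standard_normal((p, p, 3, dim)) * 0.02
bias = rng.standard_normal(dim) * 0.02

# Convolution with kernel size == stride == 16, flattened to (196, 512)
conv_out = conv2d_valid(image, kernels, bias, stride=p).reshape(-1, dim)

# Per-patch fully-connected layer: flatten each patch, multiply by one shared (768 x 512) matrix
patches = image.reshape(g, p, g, p, 3).transpose(0, 2, 1, 3, 4).reshape(g * g, -1)
W = kernels.reshape(p * p * 3, dim)                 # the flattened conv kernels are the FC weights
fc_out = patches @ W + bias

print(np.allclose(conv_out, fc_out))  # True
```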

(Body) Mixer Layer

Now coming to the Mixer Layer. As I've shown in another blog post, "Are fully connected and convolution layers equivalent? If so, how?", a convolution with a kernel size of 1 is equivalent to a fully connected (nn.Linear) layer applied independently at every position!
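A minimal NumPy check of this claim, assuming random weights and a 14x14 feature map with 512 input and 64 output channels (sizes chosen only for illustration). The 1x1 convolution is written as an explicit loop over pixels. It is compared with a linear layer applied to the (196, 512) token matrix, using the same (out, in) weight layout as nn.Linear.

```python
import numpy as np

rng = np.random.default_rng(0)
C_in, C_out, H, W = 512, 64, 14, 14
x = rng.standard_normal((C_in, H, W))                     # channels-first feature map (one image)
kernel = rng.standard_normal((C_out, C_in, 1, 1)) * 0.02  # 1x1 conv weights, PyTorch Conv2d layout
bias = rng.standard_normal(C_out) * 0.02

# 1x1 convolution: the window is a single pixel, so each output pixel is a
# weighted sum over the input channels at that same pixel.
conv_out = np.empty((C_out, H, W))
for h in range(H):
    for w in range(W):
        conv_out[:, h, w] = kernel[:, :, 0, 0] @ x[:, h, w] + bias

# Linear layer on tokens: flatten pixels into a (H*W, C_in) matrix and apply y = x W^T + b
tokens = x.reshape(C_in, H * W).T                  # (196, 512)
linear_out = tokens @ kernel[:, :, 0, 0].T + bias  # (196, 64), like nn.Linear(512, 64)

print(np.allclose(conv_out.reshape(C_out, H * W).T, linear_out))  # True
```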
Thus, it is also possible to replace every linear layer in the MLP implementation with a Conv operation. And so, you could call the MLP Mixer an extreme case of a CNN! In fact, an implementation of such a mixer layer using just convolutions has been shared by Pieter-Jan Hoedt as below (with minor updates):
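```python
import torch
import torch.nn as nn


class MixerConvLayer(nn.Module):
    """
    Single block for the convolutional mixer implementation.
    """

    # Hidden widths default to 256 (token mixing) and 2048 (channel mixing),
    # matching timm's mlp_ratio=(0.5, 4.0) for 512 channels.
    def __init__(self, num_patches, num_channels, token_mixing=256, channel_mixing=2048):
        super().__init__()
        self.c_norm1 = ChannelNormalization2d(num_channels)
        # token-mixing MLP as two 1x1 Conv1d layers over the patch "channels"
        self.conv1 = nn.Sequential(
            nn.Conv1d(num_patches, token_mixing, 1),
            nn.GELU(),
            nn.Conv1d(token_mixing, num_patches, 1),
        )
        self.c_norm2 = ChannelNormalization2d(num_channels)
        # channel-mixing MLP as two 1x1 Conv1d layers over the embedding channels
        self.conv2 = nn.Sequential(
            nn.Conv1d(num_channels, channel_mixing, 1),
            nn.GELU(),
            nn.Conv1d(channel_mixing, num_channels, 1),
        )

    def forward(self, x):
        # x: (B, C, H, W) feature map, e.g. the Conv2d stem output before flattening
        centred1 = self.c_norm1(x)
        to_mlp = torch.flatten(centred1, start_dim=2)                            # (B, C, N)
        # token mixing: after the transpose, patches are the Conv1d channels
        mix1 = self.conv1(to_mlp.transpose(1, 2)).transpose(1, 2).view(x.shape)
        skip1 = x + mix1
        centred2 = self.c_norm2(skip1)
        centred2_flat = torch.flatten(centred2, start_dim=2)                     # (B, C, N)
        # channel mixing: a 1x1 conv over the C channels at every token position
        conv_mix2 = self.conv2(centred2_flat).view(x.shape)
        skip2 = conv_mix2 + skip1
        return skip2
```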
The ChannelNormalization2d layer used above can be implemented as below:
```python
import torch
import torch.nn as nn


class ChannelNormalization2d(nn.Module):
    """
    Channel normalisation layer.

    The only reason for this module is because `nn.LayerNorm((512, 1, 1))`
    is not implemented pythonically. There is a check somewhere that keeps this from working.
    """

    def __init__(self, num_channels: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_channels, 1, 1))
        self.bias = nn.Parameter(torch.zeros(num_channels, 1, 1))

    def forward(self, x):
        # Normalise over the channel dimension (dim=1) separately at every spatial position,
        # i.e. LayerNorm applied to each token of a (B, C, H, W) feature map.
        mean = torch.mean(x, dim=1, keepdim=True)
        centred = x - mean
        var = torch.mean(centred ** 2, dim=1, keepdim=True)
        normalised = centred / torch.sqrt(var + self.eps)
        return self.weight * normalised + self.bias
```
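In other words, ChannelNormalization2d is LayerNorm applied per token, just in the (B, C, H, W) layout. The NumPy sketch below leaves the affine weight and bias at their initial values of 1 and 0, and both helper functions are written for this illustration. It checks that normalising over the channel axis of a feature map gives the same numbers as LayerNorm over the last axis of the matching (B, N, C) token matrix used by timm's MixerBlock.

```python
import numpy as np


def channel_norm_2d(x, eps=1e-5):
    """Mirror of ChannelNormalization2d with weight=1, bias=0: normalise over axis 1 of (B, C, H, W)."""
    mean = x.mean(axis=1, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def layer_norm(x, eps=1e-5):
    """LayerNorm over the last axis with weight=1, bias=0, as applied to (B, N, C) tokens."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)   # biased variance, like nn.LayerNorm
    return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 512, 14, 14))            # conv layout: (B, C, H, W)
tokens = x.reshape(2, 512, 196).transpose(0, 2, 1)   # token layout: (B, N, C)

via_channel_norm = channel_norm_2d(x).reshape(2, 512, 196).transpose(0, 2, 1)
via_layer_norm = layer_norm(tokens)
print(np.allclose(via_channel_norm, via_layer_norm))  # True
```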
This brings us to the second part of Yann's tweet.
"MLP-Mixer" == "conv layer with 1x1 kernels"
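The paper's own framing, quoted earlier, goes one step further for token mixing. It describes the operation as a "single-channel depth-wise convolution of a full receptive field and parameter sharing." The NumPy sketch below shows what that means for the first token-mixing layer. It omits the bias and GELU, and its sizes and random weights are assumptions. Each channel's 14x14 map is convolved with kernels as large as the map itself, and the same kernels are reused for every channel. The result matches the token-mixing linear layer applied to the transposed patch matrix. `depthwise_full_conv` is a helper written for this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
C, H, W, D_S = 512, 14, 14, 256
x = rng.standard_normal((C, H, W))            # one 14x14 map per channel (196 tokens)
K = rng.standard_normal((D_S, H, W)) * 0.02   # D_S kernels, each covering the whole map


def depthwise_full_conv(x, K):
    """Depth-wise conv where every channel is convolved on its own with the SAME kernels
    (parameter sharing), and each kernel spans the full H x W map, so with no padding
    each kernel produces a single output value per channel."""
    out = np.empty((x.shape[0], K.shape[0]))
    for c in range(x.shape[0]):
        out[c] = np.tensordot(K, x[c], axes=2)   # (D_S,): one value per shared kernel
    return out


conv_out = depthwise_full_conv(x, K)          # (512, 256)

# First token-mixing layer: X^T (C x S) times a weight matrix of shape (S x D_S)
W1 = K.reshape(D_S, H * W).T                  # (196, 256): the kernels, flattened
mlp_out = x.reshape(C, H * W) @ W1            # (512, 256)

print(np.allclose(conv_out, mlp_out))  # True
```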

Conclusion

So is the MLP Mixer a CNN in disguise?
I leave the answer to the astute reader. My aim in this blog post was only to provide a holistic view of the picture.