MLPs are All You Need: Back to Square One?

In the past few months, various papers have proposed MLP-based architectures without attention or convolutions. This report analyses the paper 'MLP-Mixer: An all-MLP Architecture for Vision' by Ilya Tolstikhin, Neil Houlsby, Alexander Kolesnikov, Lucas Beyer and others.
Saurav Maheshkar

Link to the paper

GitHub Repository / Python Package with Flax Implementation

Recently, a lot of research has been published on MLPs without attention or convolutions, with interesting new methods like RepMLP, ResMLP (CVPR 2021), Can Attention Enable MLPs To Catch Up With CNNs? (CVM 2021) and Are Pre-trained Convolutions Better than Pre-trained Transformers? (ACL 2021). This newly awakened interest seems to be widely shared, with giants like Google and Facebook and universities like Tsinghua and Monash submitting multiple papers on this topic to top-tier conferences. (Collection of recent MLP-based work)
Figure 1: A comparison of the recent MLP based Architectures
The strong performance of recent vision architectures is often attributed to attention or convolutions. Multi-Layer Perceptrons (MLPs), however, have always been better at capturing long-range dependencies and positional patterns, though they admittedly fall behind when it comes to learning local features, which is where CNNs shine. An interesting new perspective of viewing convolutions as a "sparse FC with shared parameters" was proposed in Ding et al. This perspective opens up a new way of looking at architectures. In this report we'll look at one such paper, which explores the idea of using convolutions with an extremely small kernel size of (1,1), essentially turning convolutions into standard matrix multiplications applied independently to each spatial location. This modification alone doesn't allow for the aggregation of spatial information. To compensate for this, the authors propose dense matrix multiplications that are applied to every feature across all spatial locations.
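To make the (1,1) kernel remark concrete, here is a minimal pure-Python sketch (not taken from the paper or the repository). It compares a naive convolution with a per-pixel matrix multiplication on a tiny image. The helper names `conv2d_valid` and `per_pixel_matmul` and the toy values are assumptions made for illustration only.

```python
def conv2d_valid(image, kernel):
    """Naive stride-1 'valid' convolution (cross-correlation).

    image:  H x W x C_in nested lists
    kernel: KH x KW x C_in x C_out nested lists
    """
    h, w = len(image), len(image[0])
    kh, kw = len(kernel), len(kernel[0])
    c_in, c_out = len(kernel[0][0]), len(kernel[0][0][0])
    out = []
    for i in range(h - kh + 1):
        row = []
        for j in range(w - kw + 1):
            pixel = [
                sum(
                    image[i + di][j + dj][c] * kernel[di][dj][c][o]
                    for di in range(kh)
                    for dj in range(kw)
                    for c in range(c_in)
                )
                for o in range(c_out)
            ]
            row.append(pixel)
        out.append(row)
    return out


def per_pixel_matmul(image, weights):
    """Multiply every pixel's channel vector by the same C_in x C_out matrix."""
    c_out = len(weights[0])
    return [
        [
            [sum(x * weights[c][o] for c, x in enumerate(pixel)) for o in range(c_out)]
            for pixel in row
        ]
        for row in image
    ]


# Toy 2x2 image with 2 channels and a 2 -> 3 channel projection (illustrative values).
image = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]
weights = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]
kernel_1x1 = [[weights]]  # KH = KW = 1, so the kernel holds a single matrix

same = conv2d_valid(image, kernel_1x1) == per_pixel_matmul(image, weights)
print(same)
```

With a 1x1 kernel no neighbouring pixel ever enters the sum, which is why a separate operation is needed to mix information across spatial locations.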

Model Architecture

Figure 2: An overview of the MLP-Mixer architecture

1️⃣ MLP Block

Each MLP Block consists of two fully-connected layers with a non-linearity (GELU in this case) in between, applied independently to each row of the input data tensor. It's important to note that the hidden widths of the MLP Blocks are chosen independently of the number of input patches. The computational complexity of the network is therefore linear in the number of input patches, unlike Vision Transformers (ViT), whose complexity is quadratic. Since the number of patches grows linearly with the number of pixels for a fixed patch size, the overall complexity is also linear in the number of pixels, similar to CNNs. Below is an implementation of the MLP Block in Flax. (For the entire code see the associated GitHub repository)
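from typing import Any

import flax.linen as nn
import jax.numpy as jnp

# Assumption: the repository defines its own Array / Dtype aliases;
# generic placeholders are used here so the snippet stands alone.
Array = Any
Dtype = Any


class MlpBlock(nn.Module):
    """
    A MLPBlock Wrapper Flax Module consisting of two Fully connected layers
    with a GELU layer in between

    Attributes:
        mlp_dim: No of output dimensions for the first FC
        approximate: If True, uses the approximate formulation of GELU
        dtype: the dtype of the computation (default: float32)
    """

    mlp_dim: int
    approximate: bool = True
    dtype: Dtype = jnp.float32

    @nn.compact
    def __call__(self, x) -> Array:
        # First FC expands the last axis to mlp_dim
        y = nn.Dense(features=self.mlp_dim, dtype=self.dtype)(x)
        y = nn.gelu(y, approximate=self.approximate)
        # Second FC projects back to the width of the input
        out = nn.Dense(features=x.shape[-1], dtype=self.dtype)(y)
        return out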
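The linear-versus-quadratic claim can be checked with a rough back-of-the-envelope count. The sketch below is an assumption-laden simplification, not a profiler. It counts multiply-accumulates for the token-mixing MLP and for the two S x S products in self-attention, ignoring biases, projections and attention heads. The function names and the sizes C = 768 and D_S = 384 are illustrative choices.

```python
def token_mixing_macs(num_patches: int, channels: int, tokens_mlp_dim: int) -> int:
    """Two dense layers (S -> D_S -> S), applied to each of the C columns."""
    return 2 * num_patches * tokens_mlp_dim * channels


def attention_mixing_macs(num_patches: int, channels: int) -> int:
    """The two S x S products in self-attention: scores and weighted values."""
    return 2 * num_patches * num_patches * channels


C, D_S = 768, 384  # illustrative sizes, fixed while S varies

for S in (196, 392, 784):
    mixer = token_mixing_macs(S, C, D_S)
    attention = attention_mixing_macs(S, C)
    print(f"S={S:4d}  token-mixing MACs={mixer:>12,}  attention MACs={attention:>14,}")
```

Doubling S doubles the token-mixing count but quadruples the attention count, because D_S stays fixed while the attention matrix grows with S in both dimensions.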

2️⃣ Mixer Block

Figure 3: A Pictographic View of the Mixer Block Architecture with some of its salient features.
A Mixer Block takes as input a sequence of S non-overlapping image patches, each projected to a hidden dimension C. The authors refer to this as a real-valued input table X \in \mathbb{R}^{S \times C}. If the input image has a resolution of (H, W) and if we use square patches of resolution (P,P) then, one can easily figure out the number of patches by conserving the area as
S = HW / P^2
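As a quick worked example of this formula, here is a small sketch. The helper name `num_patches` is hypothetical, and the 224 x 224 resolution is just a common choice used for illustration.

```python
def num_patches(height: int, width: int, patch_size: int) -> int:
    """Number of non-overlapping P x P patches in an H x W image: S = HW / P^2."""
    if height % patch_size or width % patch_size:
        raise ValueError("image sides must be divisible by the patch size")
    return (height // patch_size) * (width // patch_size)


print(num_patches(224, 224, 16))  # 14 * 14 = 196
print(num_patches(224, 224, 32))  # 7 * 7 = 49
```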
A Mixer Block consists of two MLP Blocks :-
  1. Token Mixing Block: This block acts along the columns of the input table X (i.e. on the rows of X^T). It maps \mathbb{R}^S \mapsto \mathbb{R}^S, where S is the input sequence length, and is shared across all columns.
  2. Channel Mixing Block: This block acts along the rows of the input table X. It maps \mathbb{R}^C \mapsto \mathbb{R}^C and is shared across all rows.
Sharing a single MLP Block (the same kernel / parameters) across all rows, and likewise a single one across all columns, is a key design choice of this architecture. Tying the channel-mixing parameters across rows gives the model positional invariance (a key feature of convolutions). This "parameter-tying" also prevents the model from growing too fast as C or S increases and, as reported in the paper, leads to memory savings.
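To get a feel for how much parameter-tying saves, here is a rough count. It compares a shared MLP with a hypothetical untied variant that would keep a separate MLP per column (or per row). The helper name and the sizes S = 196, C = 768, D_S = 384 and D_C = 3072 are illustrative assumptions, and the untied variant is a thought experiment rather than something evaluated in this report.

```python
def mlp_block_params(in_dim: int, hidden_dim: int) -> int:
    """Weights and biases of a two-layer MLP: in_dim -> hidden_dim -> in_dim."""
    first = in_dim * hidden_dim + hidden_dim
    second = hidden_dim * in_dim + in_dim
    return first + second


S, C = 196, 768        # sequence length and hidden dimension (illustrative)
D_S, D_C = 384, 3072   # token-mixing and channel-mixing MLP widths (illustrative)

# Tied: one token-mixing MLP shared by all C columns.
tied_token_params = mlp_block_params(S, D_S)
# Hypothetical untied variant: a separate token-mixing MLP per column.
untied_token_params = C * tied_token_params

# Tied: one channel-mixing MLP shared by all S rows.
tied_channel_params = mlp_block_params(C, D_C)
# Hypothetical untied variant: a separate channel-mixing MLP per row.
untied_channel_params = S * tied_channel_params

print(f"token mixing:   tied={tied_token_params:,}  untied={untied_token_params:,}")
print(f"channel mixing: tied={tied_channel_params:,}  untied={untied_channel_params:,}")
```

In the tied case the token-mixing count does not depend on C and the channel-mixing count does not depend on S, which is what keeps the model from growing too fast.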
Another thing to note about the Mixer Block architecture is that it has an "isotropic" design, meaning all layers of the block take an input of the same size (width). This is a common choice for Transformers and RNNs; CNNs, on the other hand, usually have a "pyramidal" structure where the deeper layers have lower resolution but a higher number of channels.
Unlike Vision Transformers, MLP-Mixer doesn't use positional embeddings, because the Token Mixing Block is sensitive to the order of the tokens, which allows it to learn "locations". MLP-Mixer also uses standard components like Layer Normalization and Skip Connections, along with a Classifier head.
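The order sensitivity can be seen in a toy setting. The sketch below reduces token mixing to a single linear layer acting on one channel, which is a simplification of the actual two-layer MLP. The weights and token values are made up for illustration.

```python
def token_mix(weights, tokens):
    """Linear token mixing for one channel: out[i] = sum_j weights[i][j] * tokens[j]."""
    return [sum(w * t for w, t in zip(row, tokens)) for row in weights]


# Each output position has its own row of weights (illustrative values).
weights = [
    [1.0, 2.0, 0.0],
    [0.0, 1.0, 0.0],
    [3.0, 0.0, 1.0],
]
tokens = [1.0, 2.0, 3.0]

mixed = token_mix(weights, tokens)
mixed_reversed = token_mix(weights, tokens[::-1])

print(mixed)           # [5.0, 2.0, 6.0]
print(mixed_reversed)  # [7.0, 2.0, 10.0]
print(mixed_reversed == mixed[::-1])  # False
```

Reversing the tokens does not simply reverse the output, because every position is tied to its own weights. That position-specific weighting is what lets the layer pick up on "locations" without explicit positional embeddings.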
Below is an implementation of the Mixer Block in Flax. (For the entire code see the associated GitHub repository)

class MixerBlock(nn.Module):
    """
    A Flax Module to act as the mixer block layer for the MLP-Mixer Architecture.

    Attributes:
        tokens_mlp_dim: No of dimensions for the MLP Block 1
        channels_mlp_dim: No of dimensions for the MLP Block 2
        approximate: If True, uses the approximate formulation of GELU in each MLP Block
        dtype: the dtype of the computation (default: float32)
    """

    tokens_mlp_dim: int
    channels_mlp_dim: int
    approximate: bool = True
    dtype: Dtype = jnp.float32

    @nn.compact
    def __call__(self, x) -> Array:
        # Layer Normalization
        y = nn.LayerNorm(dtype=self.dtype)(x)
        # Transpose
        y = jnp.swapaxes(y, 1, 2)
        # MLP 1
        y = MlpBlock(
            mlp_dim=self.tokens_mlp_dim,
            approximate=self.approximate,
            dtype=self.dtype,
            name="token_mixing",
        )(y)
        # Transpose
        y = jnp.swapaxes(y, 1, 2)
        # Skip Connection
        x = x + y
        # Layer Normalization
        y = nn.LayerNorm(dtype=self.dtype)(x)
        # MLP 2 with Skip Connection
        out = x + MlpBlock(
            mlp_dim=self.channels_mlp_dim,
            approximate=self.approximate,
            dtype=self.dtype,
            name="channel_mixing",
        )(y)
        return out

🏠 Complete Model

The model also includes a "Per-Patch Fully Connected layer" which converts the input image patches to fixed-length vectors. This forms our input table (X \in \mathbb{R}^{S \times C}), which is then passed through a number of Mixer Blocks. Finally, a Classification head is added to get the desired output.
Below is an implementation of the MLPMixer in Flax. (For the entire code see the associated GitHub repository)

from einops import rearrange


class MlpMixer(nn.Module):
    """
    Flax Module for the MLP-Mixer Architecture.

    Attributes:
        patches: Patch configuration
        num_classes: No of classes for the output head
        num_blocks: No of Blocks of Mixers to use
        hidden_dim: No of Hidden Dimension for the Patch-Wise Convolution Layer
        tokens_mlp_dim: No of dimensions for the MLP Block 1
        channels_mlp_dim: No of dimensions for the MLP Block 2
        approximate: If True, uses the approximate formulation of GELU in each MLP Block
        dtype: the dtype of the computation (default: float32)
    """

    patches: Any
    num_classes: int
    num_blocks: int
    hidden_dim: int
    tokens_mlp_dim: int
    channels_mlp_dim: int
    approximate: bool = True
    dtype: Dtype = jnp.float32

    @nn.compact
    def __call__(self, inputs, *, train) -> Array:
        del train
        # Per-Patch Fully Connected Layer
        # (patches.size is used as both kernel size and stride)
        x = nn.Conv(
            features=self.hidden_dim,
            kernel_size=self.patches.size,
            strides=self.patches.size,
            dtype=self.dtype,
            name="stem",
        )(inputs)
        # Flatten the spatial grid into a sequence of patches
        x = rearrange(x, "n h w c -> n (h w) c")
        # Num Blocks x Mixer Blocks
        for _ in range(self.num_blocks):
            x = MixerBlock(
                tokens_mlp_dim=self.tokens_mlp_dim,
                channels_mlp_dim=self.channels_mlp_dim,
                approximate=self.approximate,
                dtype=self.dtype,
            )(x)
        # Output Head
        x = nn.LayerNorm(dtype=self.dtype, name="pre_head_layer_norm")(x)
        x = jnp.mean(x, axis=1, dtype=self.dtype)
        return nn.Dense(
            self.num_classes,
            kernel_init=nn.initializers.zeros,
            dtype=self.dtype,
            name="head",
        )(x)

Results

The following graph shows the fine-tuning performance of a Mixer B16 trained with 224 as the patch size, w.r.t. a ViT B32 trained on patch sizes of 224, 128, 64 and 32. As evident from the graph, Mixer performs about as well as ViT when fine-tuned. Let's not forget that Mixer lacks convolutions and attention, so providing reasonably similar performance is impressive, to say the least.

Some Interesting Results from the Paper 🧐

Effect of Scale 🪜

There are 2 ways to scale an MLP-Mixer model that are outlined in the paper :-
  1. Increase the model size (viz. number of layers, hidden dimensions, MLP widths) when pre-training.
  2. Increase the input image resolution when fine-tuning.
Figure 4: Role of Model Scale
The authors report that when trained on ImageNet from scratch, Mixer lies around 3% behind ViT. But as the pre-training dataset grows in size, Mixer's performance steadily increases. When pre-trained on JFT-300M, Mixer lies around 0.3% behind ViT while being ~2x faster.

PreTraining Dataset Size 📦

The authors report that pre-training on larger datasets improves Mixer's performance. When pre-trained on a small subset of JFT-300M, all Mixer models overfit. BiT overfits less, possibly because of the inductive biases associated with convolutions. But as the size of the pre-training dataset increases, the performance of Mixer keeps growing while BiT plateaus.
Figure 5: PreTrained Dataset size effect on performance
When compared with ViT, the trend is more distinct: Mixer models improve more with dataset size than ViT does. The explanation the authors give is the difference in inductive biases.
"Self-Attention layers in ViT lead to certain properties of the learned functions that are less compatible with the true underlying distribution than those discovered with the Mixer architecture"

Visualizations 👀

The first few layers of Convolutional Neural Networks are known to learn detectors that act on pixels in a particular local region of the image. In contrast, Mixer allows for global information exchange in the Token-Mixing blocks (MLP No 1). These Token-Mixing MLPs allow for communication between different spatial locations.
Figure 6: Some weights of the MLP blocks. It's important to note that, in contrast to convolutional kernels where each weight corresponds to a pixel, in the case of the MLP each weight corresponds to a particular 16x16 patch.
Figure 6 shows the weights of the first few Token Mixing blocks of Mixer trained on JFT-300M. Some of the learned features act on the entire image whereas others operate on smaller portions of the image. The first few blocks contain local interactions whereas the "deeper" blocks learn features across larger regions of the image.

Conclusion

In this paper the authors propose a very simple architecture for vision. Although the model doesn't improve upon the current SOTA performance, it performs comparably well, especially when scaled. Continuing recent work in the field, this paper is one of many that raise the question "Is Attention necessary?", and it hopefully encourages others to think of new, interesting architectures beyond the infinite well of Convolutions and Attention.