Feedforward Networks for Image Classification

In the past few months there have been various papers proposing MLP-based architectures without Attention or Convolutions. This report analyses the paper 'ResMLP: Feedforward networks for image classification with data-efficient training' by Touvron et al.
Saurav Maheshkar
This report is part of a series of reports covering pure MLP-based architectures. The other reports in the series can be found here:
Report Gallery

Paper | Github Repository + Python Package | Model Checkpoints

Introduction

The strong performance of recent vision architectures is often attributed to Attention or Convolutions. Multi-Layer Perceptrons (MLPs) are well suited to capturing long-range dependencies and positional patterns, but they fall behind when it comes to learning local features, which is where CNNs shine.
An interesting new perspective, proposed by Ding et al., views a convolution as a "sparse FC with shared parameters", which opens up a new way of looking at these architectures. And the idea is catching on: groups at Google and Facebook and universities such as Tsinghua and Monash have submitted multiple papers on this topic to top-tier conferences (here's a collection of recent MLP-based work).
In this report we'll look at one such paper, ResMLP: Feedforward networks for image classification with data-efficient training, which explores a simple residual architecture whose residual blocks consist of a one-hidden-layer feed-forward network and a linear patch-interaction layer, and which achieves surprisingly strong results on various benchmarks.

Model Architecture

Figure 1: An overview of the ResMLP Architecture
The main components of the ResMLP architecture are the ResMLP layers. These consist of a linear sublayer applied across patches followed by a feed-forward sublayer applied across channels. As in Transformers, each sublayer (cross-patch and cross-channel) is wrapped with a skip connection to improve training stability.
The authors report that the absence of self-attention layers makes it possible to replace Layer Normalization with a much simpler Affine transformation:
Aff_{\alpha, \beta}(x) = Diag(\alpha)\,x + \beta
Here, \alpha and \beta are learnable weight vectors. Like normalization, this operation rescales and shifts the input element-wise. As opposed to other normalization layers, it has no cost at inference time, since it can be absorbed into the adjacent linear layer, and, unlike BatchNorm and Layer Normalization, it does not depend on batch statistics.
This Affine operation is applied at the beginning and the end of each residual sublayer. Below is an implementation of the Affine Layer in Flax. (For the entire code see the associated github repository)
```python
import jax.numpy as jnp
from flax import linen as nn
from jax import Array


class Affine(nn.Module):
    """A Flax linen Module that performs an Affine transformation.

    Attributes:
        dim: channel dimension, needed to create parameters of the right shape
    """

    dim: int = 512

    def setup(self):
        # Learnable per-channel scale and shift, broadcast over batch and patches
        self.alpha = self.param("alpha", nn.initializers.ones, (1, 1, self.dim))
        self.beta = self.param("beta", nn.initializers.zeros, (1, 1, self.dim))

    def __call__(self, x) -> Array:
        return x * self.alpha + self.beta
```
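As a quick sanity check, here's a minimal sketch of how this layer can be initialised and applied (the batch size, patch count and input values below are purely illustrative):

```python
import jax

key = jax.random.PRNGKey(0)
x = jnp.ones((1, 196, 512))         # (batch, num_patches, dim)

affine = Affine(dim=512)
params = affine.init(key, x)        # creates alpha (ones) and beta (zeros)
y = affine.apply(params, x)         # element-wise rescale and shift

print(y.shape)                      # (1, 196, 512)
```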

1️⃣ CrossPatch SubLayer

Figure 2: Visual Representation of the "CrossPatch SubLayer"
Similar to the token-mixing block of MLP-Mixer, the CrossPatchSublayer of ResMLP mixes information across patches, but it replaces the MLP block with a single Dense layer and the Layer Normalization with Affine transformation blocks.
Below is an implementation of the CrossPatch SubLayer in Flax. (For the entire code see the associated github repository)
```python
class CrossPatchSublayer(nn.Module):
    """A Flax linen Module consisting of two Affine element-wise
    transformations, a linear layer across patches and a skip connection.

    Attributes:
        dim: channel dimension, used for the Affine layers
        num_patches: number of patches N^2, i.e. the dimensionality of the
            linear patch-interaction layer
        layerscale: initial value of the LayerScale parameter
    """

    dim: int = 512
    num_patches: int = 196
    layerscale: float = 0.1

    def setup(self):
        self.aff1 = Affine(dim=self.dim)
        self.linear = nn.Dense(features=self.num_patches)
        self.aff2 = Affine(dim=self.dim)
        self.scale = self.param(
            "layerscale_crosspatch",
            nn.initializers.constant(self.layerscale),
            (self.dim,),
        )

    def __call__(self, x) -> Array:
        # Output from Affine Layer 1
        transform = self.aff1(x)
        # Swap the patch and channel axes: (B, N, D) -> (B, D, N)
        transposed_transform = jnp.swapaxes(transform, 1, 2)
        # Linear layer mixes information across patches
        linear_transform = self.linear(transposed_transform)
        # Swap the axes back: (B, D, N) -> (B, N, D)
        transposed_linear = jnp.swapaxes(linear_transform, 1, 2)
        # Output from Affine Layer 2
        affine_output = self.aff2(transposed_linear)
        # Skip connection with LayerScale
        return x + affine_output * self.scale
```
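Continuing with the same illustrative shapes, we can check that the sublayer preserves the input shape and that the Dense kernel is a square num_patches × num_patches matrix mixing the patch axis (the matrix A in the equations later in this report):

```python
key = jax.random.PRNGKey(0)
x = jnp.ones((1, 196, 512))                        # (batch, num_patches, dim)

crosspatch = CrossPatchSublayer(dim=512, num_patches=196)
params = crosspatch.init(key, x)

print(crosspatch.apply(params, x).shape)           # (1, 196, 512)
print(params["params"]["linear"]["kernel"].shape)  # (196, 196), the patch-mixing matrix
```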

2️⃣ CrossChannel SubLayer

Figure 3: Visual Representation of the "CrossChannel SubLayer"
The CrossChannelSublayer is a two-layer feed-forward network with a GELU non-linearity, applied independently to each patch, and plays the same role as the feed-forward sublayer of a Transformer. In addition, the paper proposes a variant of the class-attention tokens introduced in the Class-Attention in Image Transformers (CaiT) architecture: instead of two layers in which only the class token is updated on the basis of the patch embeddings, the authors introduce "class-MLP", a pooling variant where the interactions between the class and patch embeddings are captured by a simple linear layer, while still keeping the patch embeddings frozen.
Below is an implementation of the CrossChannel SubLayer in Flax. (For the entire code see the associated github repository)
```python
class CrossChannelSublayer(nn.Module):
    """A Flax linen Module consisting of two Affine element-wise
    transformations, a two-layer MLP and a skip connection.

    Attributes:
        dim: channel dimension, used for the Affine layers and the MLP
        layerscale: initial value of the LayerScale parameter
        expansion_factor: hidden-layer expansion of the MLP
    """

    dim: int = 512
    layerscale: float = 0.1
    expansion_factor: int = 4

    def setup(self):
        self.aff1 = Affine(dim=self.dim)
        self.mlp = nn.Sequential(
            [
                nn.Dense(features=self.expansion_factor * self.dim),
                nn.gelu,
                nn.Dense(features=self.dim),
            ]
        )
        self.aff2 = Affine(dim=self.dim)
        self.scale = self.param(
            "layerscale_crosschannel",
            nn.initializers.constant(self.layerscale),
            (self.dim,),
        )

    def __call__(self, x) -> Array:
        # Output from Affine Layer 1
        transform = self.aff1(x)
        # Two-layer MLP applied independently to each patch
        mlp_output = self.mlp(transform)
        # Output from Affine Layer 2
        affine_output = self.aff2(mlp_output)
        # Skip connection with LayerScale
        return x + affine_output * self.scale
```
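Again with purely illustrative shapes, the sublayer preserves the input shape, and inspecting the MLP's parameters shows the two Dense kernels expanding d to 4d and projecting back to d (these correspond to the matrices B and C in the next section):

```python
key = jax.random.PRNGKey(0)
x = jnp.ones((1, 196, 512))                  # (batch, num_patches, dim)

crosschannel = CrossChannelSublayer(dim=512)
params = crosschannel.init(key, x)

print(crosschannel.apply(params, x).shape)   # (1, 196, 512)
# The two Dense kernels have shapes (512, 2048) and (2048, 512)
print(jax.tree_util.tree_map(jnp.shape, params["params"]["mlp"]))
```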

3️⃣ ResMLP Layer

Overall, the ResMLP layer takes a set of N^2 \, d-dimensional input features stacked in a d \times N^2 matrix X, and outputs a set of N^2 \, d-dimensional output features stacked in a matrix Y. The transformations can be summarized as:
Z = X + Aff\big((A \, Aff(X)^{T})^{T}\big)
Y = Z + Aff\big(C \, GELU(B \, Aff(Z))\big)
Here, A \, (N^2 \times N^2) is the learnable weight matrix of the linear layer in the CrossPatch sublayer, while B \, (4d \times d) and C \, (d \times 4d) are the learnable weight matrices of the two linear layers in the CrossChannel sublayer.
Below is an implementation of the ResMLPLayer in Flax. (For the entire code see the associated github repository)
```python
class ResMLPLayer(nn.Module):
    """A Flax linen Module consisting of a CrossPatchSublayer followed by a
    CrossChannelSublayer.

    Attributes:
        dim: channel dimension, used for the Affine and MLP layers
        depth: total number of ResMLP layers, needed to pick the LayerScale value
        num_patches: number of patches N^2, used by the CrossPatchSublayer
    """

    dim: int = 512
    depth: int = 12
    num_patches: int = 196

    def setup(self):
        # Determine the LayerScale initialisation based on the network depth
        if self.depth <= 18:
            layerscale = 0.1
        elif self.depth <= 24:
            layerscale = 1e-5
        else:
            layerscale = 1e-6

        self.crosspatch = CrossPatchSublayer(
            dim=self.dim, num_patches=self.num_patches, layerscale=layerscale
        )
        self.crosschannel = CrossChannelSublayer(dim=self.dim, layerscale=layerscale)

    def __call__(self, x) -> Array:
        crosspatch_output = self.crosspatch(x)
        crosschannel_output = self.crosschannel(crosspatch_output)
        return crosschannel_output
```
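As a small illustrative check of the depth rule in the code above, we can initialise a layer for a few different depths and inspect the value that the LayerScale parameter starts from:

```python
key = jax.random.PRNGKey(0)
x = jnp.ones((1, 196, 512))

for depth in (12, 24, 36):
    layer = ResMLPLayer(dim=512, depth=depth, num_patches=196)
    params = layer.init(key, x)
    scale = params["params"]["crosspatch"]["layerscale_crosspatch"][0]
    print(depth, float(scale))   # 0.1, 1e-05 and 1e-06 respectively
```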

🏠 The Complete Model Architecture

Similar to the MLP-Mixer architecture, the ResMLP model takes as input a grid of N \times N non-overlapping patches. These patches are independently passed through a linear layer to form a set of N^2 \, d-dimensional embeddings. After feeding these embeddings through a sequence of ResMLP layers, they are averaged and fed into a linear classifier to get the output.
Below is an implementation of the ResMLP in Flax. (For the entire code see the associated github repository)
```python
class ResMLP(nn.Module):
    """A Flax linen Module for creating the ResMLP architecture.

    Attributes:
        dim: channel dimension, used for the Affine and MLP layers
        depth: number of ResMLP layers
        patch_size: side length of each (square) patch
        num_patches: number of patches N^2, i.e. (image_size // patch_size) ** 2
        num_classes: number of output classes
    """

    dim: int = 512
    depth: int = 12
    patch_size: int = 16
    num_patches: int = 196
    num_classes: int = 10

    def setup(self):
        # Patch embedding: a strided convolution acts as a per-patch linear layer
        self.patch_projector = nn.Conv(
            features=self.dim,
            kernel_size=(self.patch_size, self.patch_size),
            strides=(self.patch_size, self.patch_size),
        )
        self.blocks = nn.Sequential(
            [
                ResMLPLayer(
                    dim=self.dim, depth=self.depth, num_patches=self.num_patches
                )
                for _ in range(self.depth)
            ]
        )
        self.fc = nn.Dense(features=self.num_classes)

    def __call__(self, x) -> Array:
        # (B, H, W, C) -> (B, H / P, W / P, dim)
        x = self.patch_projector(x)
        # Flatten the spatial grid into a sequence of patch embeddings
        x = x.reshape(x.shape[0], -1, self.dim)
        x = self.blocks(x)
        # Average pool over the patches and classify
        output = jnp.mean(x, axis=1)
        return self.fc(output)
```
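Putting it all together, here's an illustrative forward pass on a random CIFAR-10-sized batch (32 × 32 images with 16 × 16 patches give N^2 = 4 patches); the hyper-parameters below are just an example configuration, not the exact one from the paper:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
images = jax.random.normal(key, (8, 32, 32, 3))   # (batch, H, W, C)

model = ResMLP(dim=384, depth=12, patch_size=16, num_patches=4, num_classes=10)
params = model.init(key, images)
logits = model.apply(params, images)

print(logits.shape)                               # (8, 10)
num_params = sum(p.size for p in jax.tree_util.tree_leaves(params))
print(f"Number of parameters: {num_params:,}")
```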
The key differences between ResMLP and ViT, as listed in the paper, are:
- No self-attention blocks: the self-attention layer is replaced by a linear layer with no non-linearity.
- No positional embeddings: the linear patch-interaction layer implicitly encodes information about patch positions.
- No extra "class" token: average pooling over the patch embeddings (or the class-MLP variant described above) is used instead.
- No normalization based on batch statistics: the learnable Affine operator is used instead.

Demo + Results

The following panel shows a comparison of various ResMLP variants trained on CIFAR-10 from scratch.
I've also provided an interactive demo of the ResMLP Big variant (depth = 24) trained on ImageNet, built using Gradio. I was able to build this demo application by adding just a couple of lines to the DeiT inference colab. For more details, have a look at the notebook used to upload this app and this report by Abubakar Abid.
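For reference, the heart of such a Gradio demo is only a few lines. The sketch below is not the actual demo code; classify_image is a placeholder for a function that would preprocess the image, run the ResMLP forward pass and return class probabilities:

```python
import gradio as gr

def classify_image(image):
    # Placeholder: preprocess `image`, run the ResMLP forward pass and return
    # a {class_name: probability} dictionary. A dummy value is returned here.
    return {"placeholder": 1.0}

gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    title="ResMLP demo",
).launch()
```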

Conclusion

The authors introduce a new architecture, ResMLP, for image classification built using only multi-layer perceptrons. The entire architecture is simply a residual network that alternates between a linear layer for cross-patch interactions and a two-layer feed-forward network for cross-channel interactions. This simple architecture, when trained using modern training techniques, performs surprisingly well on common benchmarks.
This work, among others, is particularly relevant for deploying transformer-like models in low-compute scenarios, and it calls into question how necessary Attention and Convolutions really are.