Feedforward Networks for Image Classification
In the past few months there have been various papers proposing MLP-based architectures without attention or convolutions. This report analyses the paper 'ResMLP: Feedforward networks for image classification with data-efficient training' by Touvron et al.
This report is part of a series of reports covering pure MLP-based architectures. The other reports in this series are:
- "MLP-Mixer: An all-MLP Architecture for Vision", which analyses the paper of the same name by Ilya Tolstikhin, Neil Houlsby, Alexander Kolesnikov, Lucas Beyer and others.
- "Bringing Back MLPs", which covers 'Pay Attention to MLPs' by Hanxiao Liu et al., a new architecture that performs as well as Transformers in key language and vision applications.
- "Fourier Transform in Neural Networks ??!!", which breaks down 'FNet: Mixing Tokens with Fourier Transforms' by James Lee-Thorp, Joshua Ainslie, Ilya Eckstein and Santiago Ontañón.
Paper | Flax Implementation | Model Checkpoints
Introduction
The strong performance of recent vision architectures is often attributed to attention or convolutions. Multi-layer perceptrons (MLPs), on the other hand, have always been good at capturing long-range dependencies and positional patterns, though, admittedly, they fall behind when it comes to learning local features, which is where CNNs shine.
An interesting new perspective, which views a convolution as a "sparse FC with shared parameters", was proposed by Ding et al. This opens up a new way of looking at these architectures, and the idea is catching on: giants like Google and Facebook, and universities like Tsinghua and Monash, have submitted multiple papers on this topic to top-tier conferences (here's a collection of recent MLP-based work).
In this report we'll look at one such paper, 'ResMLP: Feedforward networks for image classification with data-efficient training', which explores a simple residual architecture whose blocks consist of a one-hidden-layer feed-forward network and a linear patch-interaction layer, and which achieves surprising results on various benchmarks.
Model Architecture

Figure 1: An overview of the ResMLP Architecture
The main building block of the ResMLP architecture is the ResMLP layer. It consists of a linear sublayer applied across patches, followed by a feed-forward sublayer applied across channels. As in Transformers, each sublayer (cross-patch and cross-channel) is paralleled with a skip connection to achieve training stability.
The authors report that the lack of self-attention layers makes it possible to replace Layer Normalization with a much simpler Affine transformation:

$\mathrm{Aff}_{\alpha,\beta}(x) = \mathrm{Diag}(\alpha)\,x + \beta$

Here, $\alpha$ and $\beta$ are learnable weight vectors. Like normalization, this operation rescales and shifts the input element-wise. As opposed to other normalization operations:
- Affine transformations have no cost at inference time, since they can be absorbed into the adjacent linear layer (a minimal folding sketch follows this list).
- Affine transformations don't depend on batch statistics.
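To make the first point concrete, below is a minimal sketch of how an Affine transformation followed by a Dense layer can be folded into a single linear map at inference time. This is not code from the paper or the associated repository; the shapes, weights and names are purely illustrative.

import jax
import jax.numpy as jnp

dim, out_dim = 8, 4

# Affine parameters (learned during training)
alpha = jax.random.normal(jax.random.PRNGKey(0), (dim,))
beta = jax.random.normal(jax.random.PRNGKey(1), (dim,))

# Weights of the linear layer that follows the Affine transformation
W = jax.random.normal(jax.random.PRNGKey(2), (dim, out_dim))
b = jnp.zeros((out_dim,))

# A dummy batch of inputs
x = jax.random.normal(jax.random.PRNGKey(3), (2, dim))

# Training-time computation: Aff(x) = alpha * x + beta, followed by the linear layer
y_train = (alpha * x + beta) @ W + b

# Inference-time computation: fold the Affine into the linear layer once, up front
W_folded = W * alpha[:, None]  # scale the rows of W by alpha
b_folded = b + beta @ W        # absorb the shift into the bias
y_infer = x @ W_folded + b_folded

assert jnp.allclose(y_train, y_infer, atol=1e-5)

The folded weights produce exactly the same outputs, so the Affine layers add no extra work once training is done.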
This Affine operation is applied at the beginning and the end of each residual sublayer. Below is an implementation of the Affine Layer in Flax. (For the entire code see the associated github repository)
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import ones, zeros


class Affine(nn.Module):
    """A Flax linen Module to perform an Affine Transformation.

    References:
        - https://arxiv.org/abs/2105.03404v2

    Attributes:
        dim (int): Needed to generate parameters of the appropriate shape
    """

    dim: int = 512

    def setup(self) -> None:
        """Set up the learnable scale and shift parameters."""
        self.alpha = self.param("alpha", ones, (1, 1, self.dim))
        self.beta = self.param("beta", zeros, (1, 1, self.dim))

    def __call__(self, x, *args, **kwargs):
        """Compute a forward pass through the Affine Transformation Layer."""
        return x * self.alpha + self.beta
1️⃣ CrossPatch SubLayer

Figure 2: Visual Representation of the "CrossPatch SubLayer"
The CrossPatchSubLayer of ResMLP is similar to the token-mixing block of MLP-Mixer, but it replaces the MLP block with a single Dense layer and the Layer Normalization with Affine transformation blocks.
Below is an implementation of the CrossPatch SubLayer in Flax. (For the entire code see the associated github repository)
class CrossPatchSubLayer(nn.Module):
    """A Flax linen Module consisting of two Affine element-wise transformations,
    a Linear Layer and a Skip Connection.

    References:
        - https://arxiv.org/abs/2105.03404v2

    Attributes:
        dim (int): number of dimensions for the Affine and Linear layers
        num_patches (int): number of patches
        layerscale (float): initial value for the LayerScale parameter
    """

    dim: int = 512
    num_patches: int = 16
    layerscale: float = 0.1

    def setup(self) -> None:
        # Affine Layers
        self.affine_1 = Affine(dim=self.dim)
        self.affine_2 = Affine(dim=self.dim)

        # Linear Layer applied across patches
        self.linear = nn.Dense(features=self.num_patches)

        # LayerScale parameter (`full` is a constant-value initializer
        # defined in the associated repository)
        self.layerscale_val = self.param(
            "layerscale_crosspatch", full, self.dim, self.layerscale
        )

    def __call__(self, x, *args, **kwargs):
        """Forward pass for CrossPatchSubLayer."""
        # Output from Affine Layer 1
        transform = self.affine_1(x)

        # Transpose to (batch, dim, num_patches) so the Dense layer mixes patches
        transposed_transform = jnp.transpose(transform, axes=(0, 2, 1))

        # Feed into the Linear Layer
        linear_transform = self.linear(transposed_transform)

        # Transpose back to (batch, num_patches, dim)
        transposed_linear = jnp.transpose(linear_transform, axes=(0, 2, 1))

        # Feed into Affine Layer 2
        affine_output = self.affine_2(transposed_linear)

        # Skip-Connection with LayerScale
        return x + affine_output * self.layerscale_val
2️⃣ CrossChannel SubLayer

Figure 3: Visual Representation of the "CrossChannel SubLayer"
The cross-channel sublayer is a two-layer feed-forward network with a GELU non-linearity, applied to each patch independently; like the cross-patch sublayer, it is sandwiched between two Affine transformations and wrapped in a skip connection. The paper also proposes a variant of the class-attention tokens initially introduced in the Class-Attention in Image Transformers (CaiT) architecture: instead of having two layers in which only the class token is updated on the basis of the patch embeddings, the authors introduce "class-MLP", a pooling variant where the interactions between the class and patch embeddings are captured by a simple linear layer, while still keeping the patch embeddings frozen.
Below is an implementation of the CrossChannel SubLayer in Flax. (For the entire code see the associated github repository)
class CrossChannelSubLayer(nn.Module):
    """A Flax linen Module consisting of two Affine element-wise transformations,
    an MLP and a Skip Connection.

    References:
        - https://arxiv.org/abs/2105.03404v2

    Attributes:
        dim (int): number of dimensions for the Affine and MLP layers
        layerscale (float): initial value for the LayerScale parameter
        expansion_factor (int): expansion factor of the MLP block
    """

    dim: int = 512
    layerscale: float = 0.1
    expansion_factor: int = 4

    def setup(self) -> None:
        # Affine Layers
        self.affine_1 = Affine(dim=self.dim)
        self.affine_2 = Affine(dim=self.dim)

        # MLP Block (`GeLU` is a thin module wrapper around the GELU activation
        # defined in the associated repository)
        self.mlp = nn.Sequential(
            [
                nn.Dense(features=self.expansion_factor * self.dim),
                GeLU(),
                nn.Dense(features=self.dim),
            ]
        )

        # LayerScale parameter (`full` is a constant-value initializer
        # defined in the associated repository)
        self.layerscale_val = self.param(
            "layerscale_crosschannel", full, self.dim, self.layerscale
        )

    def __call__(self, x, *args, **kwargs):
        """Forward pass for CrossChannelSubLayer."""
        # Output from Affine Layer 1
        transform = self.affine_1(x)

        # Feed into the MLP Block
        mlp_output = self.mlp(transform)

        # Output from Affine Layer 2
        affine_output = self.affine_2(mlp_output)

        # Skip-Connection with LayerScale
        return x + affine_output * self.layerscale_val
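As a quick sanity check, both sublayers can be initialised and applied to a dummy batch of patch embeddings; each preserves the (batch, num_patches, dim) shape. This is a minimal usage sketch that assumes the class definitions above, together with the helpers from the associated repository (such as full and GeLU), are in scope.

import jax
import jax.numpy as jnp

# Dummy batch of patch embeddings: (batch, num_patches, dim)
x = jnp.ones((2, 16, 512))

crosspatch = CrossPatchSubLayer(dim=512, num_patches=16, layerscale=0.1)
crosschannel = CrossChannelSubLayer(dim=512, layerscale=0.1, expansion_factor=4)

key_patch, key_channel = jax.random.split(jax.random.PRNGKey(0))
patch_params = crosspatch.init(key_patch, x)
channel_params = crosschannel.init(key_channel, x)

y = crosspatch.apply(patch_params, x)
z = crosschannel.apply(channel_params, y)

print(y.shape, z.shape)  # (2, 16, 512) (2, 16, 512)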
3️⃣ ResMLP Layer
Overall, this multi-layer perceptron layer takes a set of $N^2$ $d$-dimensional input features stacked in a $d \times N^2$ matrix $\mathbf{X}$, and outputs a set of $N^2$ $d$-dimensional output features, stacked in a matrix $\mathbf{Y}$. The transformations can be summarized as:

$\mathbf{Z} = \mathbf{X} + \mathrm{Aff}\left(\left(\mathbf{A}\,\mathrm{Aff}(\mathbf{X})^\top\right)^\top\right)$

$\mathbf{Y} = \mathbf{Z} + \mathrm{Aff}\left(\mathbf{C}\,\mathrm{GELU}(\mathbf{B}\,\mathrm{Aff}(\mathbf{Z}))\right)$

Here, $\mathbf{A}$ is the learnable weight matrix of the linear layer in the CrossPatch sublayer, and $\mathbf{B}$ and $\mathbf{C}$ are the learnable weight matrices of the two linear layers in the CrossChannel sublayer.
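To connect these equations to code, here is a rough sketch written directly with jax.numpy. The weight matrices are random, LayerScale is omitted, and a single pair of Affine parameters is reused for brevity (in the actual layer, every Aff has its own parameters):

import jax
import jax.numpy as jnp

d, n = 8, 4  # embedding dimension d and N^2 = n patches
k_x, k_a, k_b, k_c = jax.random.split(jax.random.PRNGKey(0), 4)

X = jax.random.normal(k_x, (d, n))      # input features, one column per patch
A = jax.random.normal(k_a, (n, n))      # cross-patch linear layer
B = jax.random.normal(k_b, (4 * d, d))  # first cross-channel linear layer
C = jax.random.normal(k_c, (d, 4 * d))  # second cross-channel linear layer

alpha = jnp.ones((d,))
beta = jnp.zeros((d,))

def aff(x):
    """Element-wise affine transformation Aff(x) = Diag(alpha) x + beta."""
    return alpha[:, None] * x + beta[:, None]

# Cross-patch sublayer: Z = X + Aff((A Aff(X)^T)^T)
Z = X + aff((A @ aff(X).T).T)

# Cross-channel sublayer: Y = Z + Aff(C GELU(B Aff(Z)))
Y = Z + aff(C @ jax.nn.gelu(B @ aff(Z)))

print(Z.shape, Y.shape)  # (8, 4) (8, 4)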
Below is an implementation of the ResMLPLayer in Flax. (For the entire code see the associated github repository)
class ResMLPLayer(nn.Module):
    """A Flax linen Module consisting of the CrossPatchSubLayer and CrossChannelSubLayer.

    References:
        - https://arxiv.org/abs/2105.03404v2

    Attributes:
        num_patches (int): number of patches
        dim (int): dimensionality for the Affine and MLP layers
        depth (int): total number of ResMLP Layers in the model (used to pick the LayerScale value)
        expansion_factor (int): expansion factor of the MLP block
    """

    num_patches: int
    dim: int = 512
    depth: int = 12
    expansion_factor: int = 4

    def setup(self) -> None:
        # Determine the value of LayerScale based on the depth
        if self.depth <= 18:
            self.layerscale = 0.1
        elif self.depth > 18 and self.depth <= 24:
            self.layerscale = 1e-5
        else:
            self.layerscale = 1e-6

        self.crosspatch = CrossPatchSubLayer(
            dim=self.dim, num_patches=self.num_patches, layerscale=self.layerscale
        )
        self.crosschannel = CrossChannelSubLayer(
            dim=self.dim,
            layerscale=self.layerscale,
            expansion_factor=self.expansion_factor,
        )

    def __call__(self, x, *args, **kwargs):
        """Forward pass for ResMLPLayer."""
        # Cross-Patch Sublayer
        crosspatch_output = self.crosspatch(x)

        # Cross-Channel Sublayer
        return self.crosschannel(crosspatch_output)
🏠 The Complete Model Architecture
Similar to the MLP-Mixer architecture, the ResMLP model takes as input a grid of non-overlapping patches. These patches are independently passed through a linear layer to form a set of d-dimensional embeddings. After feeding these embeddings through a sequence of ResMLP layers, they're averaged over the patch dimension and fed into a linear classifier layer to get the output.
Below is an implementation of the ResMLP in Flax. (For the entire code see the associated github repository)
class ResMLP(nn.Module):
    """A Flax linen Module for creating the ResMLP architecture.

    References:
        - https://arxiv.org/abs/2105.03404v2

    Attributes:
        dim (int): dimensionality for the Affine and MLP layers
        depth (int): number of ResMLP Layers
        in_channels (int): number of channels of the input images
        patch_size (int): side length of the square patches
        num_classes (int): number of classes
        image_size (int): side length of the square input images
        expansion_factor (int): expansion factor of the MLP block
    """

    dim: int = 512
    depth: int = 12
    in_channels: int = 3
    patch_size: int = 16
    num_classes: int = 10
    image_size: int = 224
    expansion_factor: int = 4

    def setup(self) -> None:
        assert (
            self.image_size % self.patch_size == 0
        ), "Image dimensions must be divisible by the patch size."
        self.num_patches = (self.image_size // self.patch_size) ** 2

        # Patch Projector: a strided convolution that embeds non-overlapping patches
        self.patch_projector = nn.Conv(
            features=self.dim,
            kernel_size=(self.patch_size, self.patch_size),
            strides=(self.patch_size, self.patch_size),
        )

        # ResMLP Layers
        self.blocks = nn.Sequential(
            [
                ResMLPLayer(
                    dim=self.dim,
                    depth=self.depth,
                    num_patches=self.num_patches,
                    expansion_factor=self.expansion_factor,
                )
                for _ in range(self.depth)
            ]
        )

        # Classification Head
        self.fully_connected = nn.Dense(features=self.num_classes)

    def __call__(self, x, *args, **kwargs):
        """Forward pass for ResMLP."""
        # Get the Patch Embeddings
        x = self.patch_projector(x)
        x = x.reshape(-1, self.num_patches, self.dim)

        # Feed into the ResMLP Layers
        x = self.blocks(x)

        # Average the patch embeddings
        output = jnp.mean(x, axis=1)

        # Feed into the Classification Head
        return self.fully_connected(output)
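Assuming the modules above and the helpers from the associated repository are in scope, the full model can be initialised and run on a dummy batch of images as a quick check:

import jax
import jax.numpy as jnp

model = ResMLP(dim=512, depth=12, patch_size=16, num_classes=10, image_size=224)

# Dummy batch of 2 RGB images in NHWC layout (the layout expected by nn.Conv)
images = jnp.ones((2, 224, 224, 3))

params = model.init(jax.random.PRNGKey(0), images)
logits = model.apply(params, images)
print(logits.shape)  # (2, 10)

# Rough parameter count
num_params = sum(p.size for p in jax.tree_util.tree_leaves(params))
print(f"{num_params / 1e6:.1f}M parameters")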
The key differences between ResMLP and ViT, as listed in the paper, are:
- No self-attention blocks: linear layers are used instead
- No positional embeddings: the linear layers encode patch position implicitly
- No extra 'class' token: average pooling over the patch embeddings is used instead
- No normalization based on batch statistics: a learnable Affine operator is used
Demo + Results
The following panel shows a comparison of various ResMLP variants trained on CIFAR-10 from scratch.
Conclusion
The authors introduce a new architecture (ResMLP) for image classification built using only multi-layer perceptrons. The entire architecture is simply a residual network that alternates between a linear sublayer for cross-patch interactions and a two-layer feed-forward network for cross-channel interactions. When trained with modern training techniques, this simple architecture performs surprisingly well on common benchmarks.
This line of work is particularly relevant for deploying transformer-like models in low-compute scenarios, and it calls into question how necessary attention and convolutions really are.