
Feedforward Networks for Image Classification

In the past few months there have been various papers proposing MLP-based architectures without attention or convolutions. This report analyses the paper 'ResMLP: Feedforward networks for image classification with data-efficient training' by Touvron et al.
This report is part of a series of reports covering pure MLP-based architectures.


Paper | Flax Implementation | Model Checkpoints

Introduction

The strong performance of recent vision architectures is often attributed to attention or convolutions. Yet Multi-Layer Perceptrons (MLPs) are perfectly capable of capturing long-range dependencies and positional patterns; where they admittedly fall behind is in learning local features, which is where CNNs shine.
An interesting new perspective, proposed by Ding et al., views a convolution as a "sparse FC with shared parameters". This opens up a new way of looking at architectures, and it is catching on: industry labs such as Google and Facebook, as well as universities such as Tsinghua and Monash, have submitted multiple papers on the topic to top-tier conferences (here's a collection of recent MLP-based work).
In this report we'll look at one such paper (ResMLP: Feedforward networks for image classification with data-efficient training), which explores a simple residual architecture whose residual blocks consist of a one-hidden-layer feed-forward network and a linear patch-interaction layer, and which achieves surprisingly strong results on various benchmarks.

Model Architecture

Figure 1: An overview of the ResMLP Architecture
The main building block of the ResMLP architecture is the ResMLP layer. It consists of a linear sublayer applied across patches followed by a feedforward sublayer applied across channels. As in Transformers, each sublayer (cross-patch and cross-channel) is wrapped with a skip connection to stabilize training.
The authors report that the absence of self-attention layers makes it possible to replace Layer Normalization with a much simpler Affine transformation:
$$\mathrm{Aff}_{\alpha, \beta}(x) = \mathrm{Diag}(\alpha)\, x + \beta$$

Here, $\alpha$ and $\beta$ are learnable weight vectors. Like normalization, this operation rescales and shifts the input element-wise. In contrast to other normalization layers:
  • The Affine transformation has no cost at inference time, since it can be absorbed into the adjacent linear layer.
  • The Affine transformation does not depend on batch statistics.
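Since multiplying by $\mathrm{Diag}(\alpha)$ is just an element-wise product with $\alpha$, the whole operation reduces to $\alpha \odot x + \beta$, which is exactly what the implementation below computes. As a quick illustration, here is a minimal sketch with arbitrary, made-up values:

import jax.numpy as jnp

# Made-up values, purely to illustrate the element-wise behaviour
alpha = jnp.array([1.0, 2.0, 0.5])   # learnable per-channel scale
beta = jnp.array([0.0, -1.0, 3.0])   # learnable per-channel shift
x = jnp.array([[4.0, 4.0, 4.0],      # two "tokens" with three channels each
               [8.0, 8.0, 8.0]])

# Diag(alpha) x + beta is just a broadcasted scale-and-shift per channel
out = x * alpha + beta
print(out)  # [[ 4.  7.  5.]
            #  [ 8. 15.  7.]] -- no batch statistics are involved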
This Affine operation is applied at the beginning and the end of each residual sublayer. Below is an implementation of the Affine Layer in Flax. (For the entire code see the associated github repository)
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import ones, zeros


class Affine(nn.Module):
    """
    A Flax linen Module to perform an Affine Transformation

    References:
        - https://arxiv.org/abs/2105.03404v2

    Attributes:
        dim (int): needed to generate parameters of the appropriate shape
    """

    dim: int = 512

    def setup(self) -> None:
        """Set up the learnable scale and shift vectors"""
        self.alpha = self.param("alpha", ones, (1, 1, self.dim))
        self.beta = self.param("beta", zeros, (1, 1, self.dim))

    def __call__(self, x, *args, **kwargs):
        """Compute a forward pass through the Affine Transformation layer"""
        # Element-wise rescale and shift: Diag(alpha) x + beta
        return x * self.alpha + self.beta
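As a quick sanity check, here is a hedged usage sketch (assuming the Affine class above is in scope; shapes are arbitrary) showing how the module could be initialized and applied:

import jax
import jax.numpy as jnp

# Dummy input: a batch of 4 samples, each with 16 patches of 512 channels
x = jnp.ones((4, 16, 512))

affine = Affine(dim=512)
params = affine.init(jax.random.PRNGKey(0), x)

# alpha is initialized to ones and beta to zeros, so initially Aff(x) == x
y = affine.apply(params, x)
print(jax.tree_util.tree_map(jnp.shape, params))  # alpha and beta of shape (1, 1, 512)
print(y.shape)  # (4, 16, 512)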

1️⃣ CrossPatch SubLayer

Figure 2: Visual Representation of the "CrossPatch SubLayer"
Similar to the Token Mixing block of MLP-Mixer, the CrossPatchSubLayer of ResMLP mixes information across patches, but it replaces the MLP block with a single Dense layer and the Layer Normalization with Affine transformation blocks.
Below is an implementation of the CrossPatch SubLayer in Flax. (For the entire code see the associated github repository)
class CrossPatchSubLayer(nn.Module):
    """
    A Flax linen Module consisting of two Affine element-wise transformations,
    a Linear layer, and a skip connection

    References:
        - https://arxiv.org/abs/2105.03404v2

    Attributes:
        dim (int): number of dimensions for the Affine and MLP layers
        num_patches (int): number of patches
        layerscale (float): initial value for the LayerScale parameter
    """

    dim: int = 512
    num_patches: int = 16
    layerscale: float = 0.1

    def setup(self) -> None:
        # Affine Layers
        self.affine_1 = Affine(dim=self.dim)
        self.affine_2 = Affine(dim=self.dim)

        # Linear Layer applied across the patch dimension
        self.linear = nn.Dense(features=self.num_patches)

        # LayerScale parameter; `full` (a helper defined in the repository)
        # initializes a vector of size `dim` filled with the layerscale value
        self.layerscale_val = self.param(
            "layerscale_crosspatch", full, self.dim, self.layerscale
        )

    def __call__(self, x, *args, **kwargs):
        """Forward pass for CrossPatchSubLayer"""

        # Output from Affine Layer 1
        transform = self.affine_1(x)

        # Transpose from (batch, patches, dim) to (batch, dim, patches)
        transposed_transform = jnp.transpose(transform, axes=(0, 2, 1))

        # Feed into the Linear Layer, mixing information across patches
        linear_transform = self.linear(transposed_transform)

        # Transpose back to (batch, patches, dim)
        transposed_linear = jnp.transpose(linear_transform, axes=(0, 2, 1))

        # Feed into Affine Layer 2
        affine_output = self.affine_2(transposed_linear)

        # Skip-connection with LayerScale
        return x + affine_output * self.layerscale_val
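The transpose trick above is what makes this a patch-mixing layer: the input is flipped to (batch, dim, patches) so that nn.Dense(features=num_patches) acts along the patch axis rather than the channel axis. A hedged usage sketch (assuming the class above and the repository's full initializer are importable):

import jax
import jax.numpy as jnp

# 2 samples, 16 patches, 512 channels per patch
x = jnp.ones((2, 16, 512))

layer = CrossPatchSubLayer(dim=512, num_patches=16, layerscale=0.1)
params = layer.init(jax.random.PRNGKey(0), x)

y = layer.apply(params, x)
print(y.shape)  # (2, 16, 512) -- same shape, but each patch now mixes information from the others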

2️⃣ CrossChannel SubLayer

Figure 3: Visual Representation of the "CrossChannel SubLayer"
The cross-channel sublayer is a two-layer feed-forward network (with a 4× hidden expansion and a GELU activation) applied independently to each patch, again wrapped between two Affine transformations and combined with the input through a LayerScale-weighted skip connection. Separately, the paper also proposes a variant of the class-attention tokens initially introduced in the Class-Attention in Image Transformers (CaiT) architecture: instead of having two layers in which only the class token is updated on the basis of the patch embeddings, the authors introduce "class-MLP", a pooling variant where the interactions between the class and patch embeddings are captured by a simple linear layer, still keeping the patch embeddings frozen.
Below is an implementation of the CrossChannel SubLayer in Flax. (For the entire code see the associated github repository)
class CrossChannelSubLayer(nn.Module):
    """
    A Flax linen Module consisting of two Affine element-wise transformations,
    an MLP block, and a skip connection

    References:
        - https://arxiv.org/abs/2105.03404v2

    Attributes:
        dim (int): number of dimensions for the Affine and MLP layers
        layerscale (float): initial value for the LayerScale parameter
        expansion_factor (int): expansion factor of the MLP block
    """

    dim: int = 512
    layerscale: float = 0.1
    expansion_factor: int = 4

    def setup(self) -> None:
        # Affine Layers
        self.affine_1 = Affine(dim=self.dim)
        self.affine_2 = Affine(dim=self.dim)

        # MLP Block; `GeLU` is a thin module wrapper around the GELU
        # activation defined in the repository
        self.mlp = nn.Sequential(
            [
                nn.Dense(features=self.expansion_factor * self.dim),
                GeLU(),
                nn.Dense(features=self.dim),
            ]
        )

        # LayerScale parameter (a `dim`-sized vector filled with `layerscale`)
        self.layerscale_val = self.param(
            "layerscale_crosschannel", full, self.dim, self.layerscale
        )

    def __call__(self, x, *args, **kwargs):
        """Forward pass for CrossChannelSubLayer"""
        # Output from Affine Layer 1
        transform = self.affine_1(x)

        # Feed into the MLP Block (applied independently to each patch)
        mlp_output = self.mlp(transform)

        # Output from Affine Layer 2
        affine_output = self.affine_2(mlp_output)

        # Skip-connection with LayerScale
        return x + affine_output * self.layerscale_val

3️⃣ ResMLP Layer

Overall, the ResMLP layer takes a set of $N^2$ $d$-dimensional input features stacked in a $d \times N^2$ matrix $X$, and outputs a set of $N^2$ $d$-dimensional output features, stacked in a matrix $Y$. The transformations can be summarized as:

$$Z = X + \mathrm{Aff}\left(\left(A \, \mathrm{Aff}(X)^{T}\right)^{T}\right)$$

$$Y = Z + \mathrm{Aff}\left(C \, \mathrm{GELU}\left(B \, \mathrm{Aff}(Z)\right)\right)$$

Here, $A$ (of size $N^2 \times N^2$) is the learnable weight matrix of the linear layer in the CrossPatch sublayer, and $B$ (of size $4d \times d$) and $C$ (of size $d \times 4d$) are the learnable weight matrices of the two linear layers in the CrossChannel sublayer.
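To get a feel for where the parameters live, here is a hedged back-of-the-envelope count of the weights in a single ResMLP layer, derived directly from the matrices above (using the defaults from the code in this report: $d = 512$ and $224 \times 224$ images with $16 \times 16$ patches, i.e. $N^2 = 196$; biases and the small Affine/LayerScale vectors are ignored):

# Rough parameter count for one ResMLP layer, following the equations for Z and Y
d = 512                       # channel dimension (default in the code below)
n_patches = (224 // 16) ** 2  # N^2 = 196 patches

A = n_patches * n_patches     # cross-patch linear layer: N^2 x N^2
B = (4 * d) * d               # first cross-channel weight matrix: 4d x d
C = d * (4 * d)               # second cross-channel weight matrix: d x 4d

print(A)          # 38,416 cross-patch weights
print(B + C)      # 2,097,152 cross-channel weights
print(A + B + C)  # ~2.1M weights per layer: the cross-channel MLP dominates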
Below is an implementation of the ResMLPLayer in Flax. (For the entire code see the associated github repository)
class ResMLPLayer(nn.Module):
    """
    A Flax linen Module consisting of the CrossPatchSubLayer and CrossChannelSubLayer

    References:
        - https://arxiv.org/abs/2105.03404v2

    Attributes:
        num_patches (int): number of patches
        dim (int): dimensionality for the Affine and MLP layers
        depth (int): number of ResMLP layers in the overall model
        expansion_factor (int): expansion factor of the MLP block
    """

    num_patches: int
    dim: int = 512
    depth: int = 12
    expansion_factor: int = 4

    def setup(self) -> None:
        # Determine the initial value of LayerScale based on the depth
        if self.depth <= 18:
            self.layerscale = 0.1
        elif self.depth <= 24:
            self.layerscale = 1e-5
        else:
            self.layerscale = 1e-6

        self.crosspatch = CrossPatchSubLayer(
            dim=self.dim, num_patches=self.num_patches, layerscale=self.layerscale
        )
        self.crosschannel = CrossChannelSubLayer(
            dim=self.dim,
            layerscale=self.layerscale,
            expansion_factor=self.expansion_factor,
        )

    def __call__(self, x, *args, **kwargs):
        """Forward pass for ResMLPLayer"""

        # Cross-Patch Sublayer
        crosspatch_output = self.crosspatch(x)

        # Cross-Channel Sublayer
        return self.crosschannel(crosspatch_output)

🏠 The Complete Model Architecture

Similar to the MLP-Mixer architecture, the ResMLP model takes as input a grid of $N \times N$ non-overlapping patches. These patches are independently passed through a linear layer to form a set of $N^2$ $d$-dimensional embeddings. After feeding these embeddings through a sequence of ResMLP layers, they are averaged and fed into a linear classifier to produce the output.
Below is an implementation of the ResMLP in Flax. (For the entire code see the associated github repository)
class ResMLP(nn.Module):
    """
    A Flax linen Module for creating the ResMLP architecture

    References:
        - https://arxiv.org/abs/2105.03404v2

    Attributes:
        dim (int): dimensionality for the Affine and MLP layers
        depth (int): number of ResMLP layers
        in_channels (int): number of input image channels
        patch_size (int): height/width of the square patches
        num_classes (int): number of classes
        image_size (int): height/width of the square input images
        expansion_factor (int): expansion factor of the MLP block
    """

    dim: int = 512
    depth: int = 12
    in_channels: int = 3
    patch_size: int = 16
    num_classes: int = 10
    image_size: int = 224
    expansion_factor: int = 4

    def setup(self) -> None:
        # Attributes
        assert (
            self.image_size % self.patch_size == 0
        ), "Image dimensions must be divisible by the patch size."
        self.num_patches = (self.image_size // self.patch_size) ** 2

        # Patch Projector: a strided convolution that is equivalent to
        # applying a shared linear layer to each non-overlapping patch
        self.patch_projector = nn.Conv(
            features=self.dim,
            kernel_size=(self.patch_size, self.patch_size),
            strides=(self.patch_size, self.patch_size),
        )

        # ResMLP Layers
        self.blocks = nn.Sequential(
            [
                ResMLPLayer(
                    dim=self.dim,
                    depth=self.depth,
                    num_patches=self.num_patches,
                    expansion_factor=self.expansion_factor,
                )
                for _ in range(self.depth)
            ]
        )

        # Fully Connected Layer (classification head)
        self.fully_connected = nn.Dense(features=self.num_classes)

    def __call__(self, x, *args, **kwargs):
        """Forward pass for ResMLP"""

        # Get the patch embeddings: (batch, H/ps, W/ps, dim) -> (batch, patches, dim)
        x = self.patch_projector(x)
        x = x.reshape(-1, self.num_patches, self.dim)

        # Feed into ResMLP Layers
        x = self.blocks(x)

        # Average pool over the patch dimension
        output = jnp.mean(x, axis=1)

        # Feed into Classification Head
        return self.fully_connected(output)
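Finally, here is a hedged end-to-end sketch (again assuming the classes and repository helpers above are importable) that builds a ResMLP with the default configuration and runs a dummy batch through it:

import jax
import jax.numpy as jnp

# A ResMLP for 10-class classification of 224 x 224 RGB images
model = ResMLP(dim=512, depth=12, patch_size=16, num_classes=10, image_size=224)

# Dummy batch of 2 images in (batch, height, width, channels) layout
images = jnp.ones((2, 224, 224, 3))

params = model.init(jax.random.PRNGKey(0), images)
logits = model.apply(params, images)
print(logits.shape)  # (2, 10) -- one logit per class for each image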
The key differences between ResMLP and ViT, as listed in the paper, are:
  • No self-attention blocks: linear layers are used instead
  • No positional embeddings: the linear patch-interaction layers already encode patch position
  • No extra "class" token: average pooling over the patch embeddings is used instead
  • No normalization based on batch statistics: a learnable Affine operator is used

Demo + Results

The following panel shows a comparison of various ResMLP variants trained on CIFAR-10 from scratch.




Conclusion

The authors introduce a new architecture (ResMLP) for image classification built entirely from multi-layer perceptrons. The architecture is simply a residual network that alternates between a linear layer for cross-patch interactions and a two-layer feedforward network for cross-channel interactions. When trained with modern training techniques, this simple architecture performs surprisingly well on common benchmarks.
This work, among others in the same vein, is particularly relevant for deploying Transformer-like models in low-compute scenarios, and it calls into question how essential attention and convolutions really are.