Bringing Back MLPs

It's no secret that Transformers have dominated deep learning ever since 2017. But in their recent work, 'Pay Attention to MLPs,' Hanxiao Liu et al. propose a new architecture that performs as well as Transformers in key language and vision applications. Let's dig in.
Saurav Maheshkar

Link to the Paper →

GitHub Repository / Python Package with Flax Implementation →

Introduction

Ever since "Attention Is All You Need" (Vaswani et al. [1]) came out, Transformers have replaced LSTMs and other RNNs as the default architecture for sequence modeling, and with architectures like Vision Transformer (Dosovitskiy et al. [2]) and DeiT (Touvron et al. [3]), they have even become an alternative to Convolutional Neural Networks.
The question, though, is whether Transformers really are the way forward, and that's what we're going to dig into in this post.
Let's first jump back to the good old days of 2014, when Sutskever et al. [4] proposed Sequence to Sequence Learning. Back then, Encoder-Decoder networks were all the rage. This type of architecture evolved to handle variable-length sequences for machine translation. In the general end-to-end 'seq2seq' approach introduced by Sutskever et al., an RNN Encoder encodes the input sequence into a fixed-length hidden state, and a Decoder then uses this fixed-length state to generate the output sequence.
Figure 1: Schematic Representation of Sequence to Sequence Learning from (Sutskever et al. [4]). The relation between distant tokens gets lost as the sequence length increases.
As shown in Figure 1 above, part of the information from each token is carried forward in the hidden state while the next token is processed. As you can imagine, as the sequence grows longer, only remnants of the initial tokens remain, which hurts the long-range memory of such systems. Seq2seq models capture relationships between nearby words fairly well, but when it comes to long-range semantic relationships, it's fair to say they struggle. That said, gated architectures such as Gated Recurrent Units were able to mitigate some of these problems, along with others like vanishing gradients.
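To make the fixed-length bottleneck concrete, here is a minimal NumPy sketch of a vanilla tanh RNN encoder. It is an illustration under stated assumptions, not the LSTM that Sutskever et al. actually used: the weights are random, and the recurrent matrix is deliberately scaled so that its largest singular value is 0.9, which makes the fading easy to see. The sketch perturbs only the first token and measures how much the hidden state still differs at each later step.

```python
import numpy as np


def rnn_encode(xs, W_x, W_h):
    """Vanilla tanh RNN encoder; returns the hidden state after every step, shape (T, d_h)."""
    h = np.zeros(W_h.shape[0])
    states = []
    for x in xs:                            # tokens must be consumed one after another
        h = np.tanh(W_x @ x + W_h @ h)      # the state stays d_h-sized, whatever T is
        states.append(h)
    return np.stack(states)


rng = np.random.default_rng(0)
d_in, d_h, T = 8, 16, 50                    # illustrative sizes
W_x = rng.normal(size=(d_h, d_in)) / np.sqrt(d_in)
W_h = rng.normal(size=(d_h, d_h))
# Largest singular value 0.9; since tanh is 1-Lipschitz, every step
# shrinks the difference between two hidden-state trajectories by >= 10%.
W_h *= 0.9 / np.linalg.norm(W_h, 2)

xs = rng.normal(size=(T, d_in))
xs_perturbed = xs.copy()
xs_perturbed[0] += 1.0                      # change only the FIRST token

h_a = rnn_encode(xs, W_x, W_h)
h_b = rnn_encode(xs_perturbed, W_x, W_h)
influence = np.linalg.norm(h_a - h_b, axis=1)  # how much token 0 still matters at step t

print(h_a[-1].shape)                        # (16,)
print(influence[0], influence[4], influence[-1])
```

However long the sequence, the encoder's summary is a single 16-dimensional vector, and the first token's footprint in it fades step after step.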
With the advent of self-attention in Transformers, the fixed-length hidden state gave way to weighted context vectors: each token's representation becomes a weighted combination of all tokens in the sequence, and these can be computed for every position in parallel. But let's not forget that, at their core, Transformers are still classical sequence-to-sequence models.
Figure 2: An Encoder-Decoder view of Transformers from Jay Alammar's blogpost
As shown in Figure 2, a Transformer consists of a stack of encoders and decoders, but the brilliance of Transformers lies in the structure of the encoders: each consists of a Self-Attention Module and a Feed-Forward Neural Network. This "Attention Module" allows the model to capture the spatial (cross-token) relationships between tokens as the sequence passes through the encoder stack, and it really is the core of Transformers. Now, I won't try to be the million-and-oneth guy to explain attention, keys and queries all over again, so for the sake of simplicity let's borrow Ashish Vaswani's definition of self-attention:
“Self-attention, sometimes called intra-attention, is an attention mechanism relating different positions of a single sequence in order to compute a representation of the sequence.”
This attention mechanism allows the model to learn long-range dependencies between any two positions in the input, in the form of an attention map. However, this additional freedom and reduced inductive bias mean that effectively training Transformer-based architectures requires huge amounts of data.
There are a few key differences between Transformers and convolutional/recurrent networks (from the notes of the NYU Deep Learning course).
In short, Transformers have two main aspects:
1. A recurrence-free architecture, which computes the representation of each token in parallel, and
2. Multi-head self-attention blocks which aggregate spatial information across tokens.
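The two aspects above can be illustrated with a minimal single-head self-attention sketch in NumPy (the multi-head version runs several of these side by side and concatenates and projects the results). The projection matrices `W_q`, `W_k` and `W_v` are random placeholders rather than trained weights, and the sizes are arbitrary. Note that there is no loop over tokens: every position is updated at once through matrix products, and the resulting `n × n` attention map lets any token draw on any other token in a single step.

```python
import numpy as np


def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    ex = np.exp(a)
    return ex / ex.sum(axis=axis, keepdims=True)


def self_attention(X, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention over X of shape (n, d)."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v        # (n, d_k) each, all tokens at once
    scores = Q @ K.T / np.sqrt(K.shape[-1])    # (n, n): one score per pair of positions
    A = softmax(scores, axis=-1)               # attention map; each row sums to 1
    return A @ V, A                            # each output row mixes every token


rng = np.random.default_rng(0)
n, d, d_k = 6, 8, 4                            # illustrative sizes
X = rng.normal(size=(n, d))
W_q, W_k, W_v = (rng.normal(size=(d, d_k)) / np.sqrt(d) for _ in range(3))

out, A = self_attention(X, W_q, W_k, W_v)
print(out.shape, A.shape)                      # (6, 4) (6, 6)
```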

The Newly Awoken Interest in MLPs

2021 so far seems to be the year of MLPs, with interesting new methods like RepMLP, ResMLP (CVPR 2021), Can Attention Enable MLPs To Catch Up With CNNs? (CVM 2021) and Are Pre-trained Convolutions Better than Pre-trained Transformers? (ACL 2021). This newly awakened interest seems widely shared, with giants like Google and Facebook and universities like Tsinghua and Monash submitting multiple papers on the topic to top-tier conferences.
Figure 3: A comparison of the recent MLP based Architectures
Although most of these papers don't necessarily outperform Transformers, their aim seems clear enough: to take full advantage of linear layers. But how do these architectures account for the long-range memory that attention is responsible for?
All these architectures apply linear layers along the token dimension so that different patches can communicate with each other. These layers process patches, thereby extracting local features and introducing inductive bias.
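The difference between mixing along channels and mixing along tokens comes down to which side the weight matrix multiplies from. In this minimal NumPy sketch (random placeholder weights, tiny arbitrary sizes), `X` holds one row per patch; perturbing only patch 0 changes a single output row under channel mixing but every output row under token mixing.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 4, 3                               # n patches (tokens), d channels
X = rng.normal(size=(n, d))

W_channel = rng.normal(size=(d, d))       # acts along channels, inside each patch
W_token = rng.normal(size=(n, n))         # acts along tokens, across patches


def changed_rows(f, X, X_new):
    """Which output rows (patches) change when the input changes?"""
    return ~np.isclose(f(X), f(X_new)).all(axis=1)


X_new = X.copy()
X_new[0] += 1.0                           # perturb patch 0 only

ch = changed_rows(lambda Z: Z @ W_channel, X, X_new)  # channel mixing: right-multiply
tk = changed_rows(lambda Z: W_token @ Z, X, X_new)    # token mixing: left-multiply
print(ch)                                 # [ True False False False]
print(tk)                                 # [ True  True  True  True]
```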

Do linear layers seem too weak to you? 🔫

Well, recent work by Tay et al. [5] shows that CNN-based pre-trained models are competitive with, and in certain scenarios outperform, their Transformer counterparts, albeit with caveats. There are a number of reasons why one might pick convolutions over attention.

Key Takeaways 🔑 from the Paper (tl;dr)

The Model Architecture

Figure 4: An overview of the gMLP architecture
The architecture consists of a stack of L blocks with identical size and structure. Let X \in \mathbb{R}^{n \times d} be the token representations with sequence length n and dimension d. Each block is then defined as
Z = \sigma(XU)
\tilde{Z} = s(Z)
Y = \tilde{Z}V
where \sigma is the activation function and U and V define linear projections along the channel dimension. s(.) is a layer that captures spatial (cross-token) interactions. When s is an identity mapping, the above transformation degenerates to a regular Feed-Forward Network, where individual tokens are processed independently without any cross-token communication. The overall block layout is inspired by inverted bottlenecks, which define s(.) as a spatial depth-wise convolution; in gMLP, s(.) is instead the Spatial Gating Unit. NOTE: Unlike Transformers, gMLP does not require position embeddings, because positional information is captured by s(.).
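A minimal NumPy sketch of the block equations above, with normalization and the residual shortcut left out just as they are in the equations. The GELU activation (tanh approximation), the sizes, and the random `U` and `V` are assumptions for illustration. Plugging in the identity for `s` confirms the degenerate case: the block gives the same result as applying an ordinary feed-forward network to each token separately.

```python
import numpy as np


def gelu(x):
    # GELU, tanh approximation
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def gmlp_block(X, U, V, s):
    """Z = gelu(XU), Z~ = s(Z), Y = Z~ V (normalization and shortcut omitted)."""
    Z = gelu(X @ U)          # (n, e): channel projection, applied token by token
    Z_tilde = s(Z)           # the only step where tokens could interact
    return Z_tilde @ V       # (n, d): project the channels back down


def identity(Z):
    return Z


rng = np.random.default_rng(0)
n, d, e = 5, 8, 32                         # illustrative sizes
X = rng.normal(size=(n, d))
U = rng.normal(size=(d, e)) / np.sqrt(d)
V = rng.normal(size=(e, d)) / np.sqrt(e)

Y = gmlp_block(X, U, V, identity)

# With s = identity, the block is an ordinary FFN applied to each token on its own.
Y_per_token = np.stack([gelu(x @ U) @ V for x in X])
print(np.allclose(Y, Y_per_token))         # True
```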
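Here's a Flax implementation of a gMLP block (for the full code, refer to the GitHub repository):

```python
from typing import Any

import flax.linen as nn
import jax.numpy as jnp

# Type aliases used in the annotations below.
Array = Any
Dtype = Any

# `Attention` and `SpatialGatingUnit` are defined elsewhere in the
# repository and are not reproduced here.


class gMLPBlock(nn.Module):
    """
    Flax Module to create a gMLP Block with optional Attention Module

    Attributes:
        dim: No of output dimensions for the block
        dim_ff: No of output dimensions for the input projection
        attn_dim: No of dimensions for the Attention head (default: None)
        dtype: the dtype of the computation (default: float32)
    """

    dim: int
    dim_ff: int
    attn_dim: Any = None
    dtype: Dtype = jnp.float32

    def setup(self):
        # Channel projection U: d -> dim_ff
        self.proj_in = nn.Dense(features=self.dim_ff, dtype=self.dtype)
        # Optional tiny attention (aMLP); its output is added to the spatial gate
        self.attn = (
            Attention(
                dim_head=self.attn_dim, dim_out=self.dim_ff // 2, dtype=self.dtype
            )
            if self.attn_dim is not None
            else None
        )
        # The SGU splits its input in half along channels, so it outputs dim_ff // 2
        self.sgu = SpatialGatingUnit(dim_out=self.dim_ff // 2, dtype=self.dtype)
        # Channel projection V: dim_ff // 2 -> dim
        self.proj_out = nn.Dense(features=self.dim, dtype=self.dtype)

    # Submodules are created in setup(), so __call__ needs no @nn.compact.
    def __call__(self, x) -> Array:
        gate_res = self.attn(x) if self.attn is not None else None
        x = self.proj_in(x)                  # XU
        x = nn.gelu(x)                       # Z = sigma(XU)
        x = self.sgu(x, gate_res=gate_res)   # Z~ = s(Z)
        x = self.proj_out(x)                 # Y = Z~ V
        return x
```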

The Spatial Gating Unit

To enable cross-token interactions, the layer s(.) must contain a contraction operation over the spatial dimension. The simplest option is a linear projection:
f_{W, b}(Z) = WZ + b
where W \in \mathbb{R}^{n \times n} acts along the sequence (token) dimension. The spatial gating unit is then defined as the element-wise product of its input and the spatially transformed input:
s(Z) = Z \odot f_{W, b}(Z) \qquad \text{(I)}
NOTE: For training stability, the authors initialize W with near-zero values and b with ones, so that f_{W, b}(Z) \approx 1 and equation (I) starts out as (approximately) an identity mapping. This initialization ensures each gMLP block behaves like a regular Feed-Forward Network in the early stage of training, where each token is processed independently, and only gradually 'injects' spatial information across tokens as W is learned.
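Here is a minimal NumPy sketch of equation (I) with the near-identity initialization. The exact scale of "near-zero" is not given here, so `eps = 1e-3 / n` is an assumption, as is giving `b` one entry per token. Because `W` multiplies `Z` from the left, it mixes along the token axis, yet at initialization the output barely differs from `Z`.

```python
import numpy as np


def spatial_proj(Z, W, b):
    """f_{W,b}(Z) = WZ + b; W is (n, n), so it mixes along the token axis."""
    return W @ Z + b


def sgu_basic(Z, W, b):
    """Equation (I): s(Z) = Z * f_{W,b}(Z), an element-wise product."""
    return Z * spatial_proj(Z, W, b)


rng = np.random.default_rng(0)
n, e = 8, 16                               # illustrative sizes
Z = rng.normal(size=(n, e))

eps = 1e-3 / n                             # assumed scale for "near-zero" W
W = rng.uniform(-eps, eps, size=(n, n))
b = np.ones((n, 1))                        # assumed: one bias per token, broadcast over channels

out = sgu_basic(Z, W, b)
print(np.max(np.abs(out - Z)))             # small: s(Z) starts out close to Z
```

Since every entry of `W` is at most `eps` in magnitude, each entry of s(Z) - Z = Z ⊙ WZ is bounded by 1e-3 · max|Z|², so the block starts out as (approximately) a token-wise FFN.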
This multiplicative gating can be viewed as a mechanism to "modulate" individual token representations using the spatial signal. The authors also found it effective to split Z into two independent parts (Z_1, Z_2) along the channel dimension, which turns equation (I) into:
s(Z) = Z_1 \odot f_{W, b} (Z_2)
The authors also normalize the input to f_{W, b} to improve stability. This variant is closely related to Gated Linear Units (GLUs); the key difference is that the SGU's gate is computed from a projection along the spatial (cross-token) dimension rather than the channel (hidden) dimension.
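The split-and-normalize variant can be sketched the same way. This minimal NumPy version assumes a plain layer norm over channels (without learned scale and shift) for the normalization step; the sizes and the near-zero init scale are again illustrative. The comment marks where a GLU would differ: it would gate with a projection of Z_2 along channels instead of across tokens.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # Normalize each token over its channels (learned scale/shift omitted).
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)


def sgu(Z, W, b):
    """s(Z) = Z_1 * f_{W,b}(norm(Z_2)), with Z split in half along channels."""
    Z1, Z2 = np.split(Z, 2, axis=-1)        # (n, e/2) each
    gate = W @ layer_norm(Z2) + b           # W is (n, n): mixes across tokens
    # A GLU would instead gate with a channel projection,
    # e.g. sigmoid(Z2 @ W_c) with W_c of shape (e/2, e/2).
    return Z1 * gate                        # (n, e/2)


rng = np.random.default_rng(0)
n, e = 8, 16                                # illustrative sizes
Z = rng.normal(size=(n, e))
W = rng.uniform(-1e-3 / n, 1e-3 / n, size=(n, n))  # near-zero init (assumed scale)
b = np.ones((n, 1))

out = sgu(Z, W, b)
print(out.shape)                            # (8, 8)
```

With `W` set exactly to zero and `b` to ones, the gate is all ones and the unit returns Z_1 unchanged, the split-variant counterpart of the identity initialization.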
In terms of computational cost, the SGU takes n^2 e/2 multiply-adds (where e is the channel dimension of Z = \sigma(XU)), which is comparable to the 2n^2 d of dot-product self-attention. Both are linear in the channel size and quadratic in the sequence length n.
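To see how the two counts compare, here is the arithmetic in plain Python. The sizes are illustrative, and tying e to d with a 4× expansion is an assumption rather than a value from the text; with that choice, n^2 e/2 = n^2 (4d)/2 = 2n^2 d, so the two costs coincide exactly.

```python
def sgu_macs(n: int, e: int) -> int:
    """Multiply-adds of the SGU's spatial projection: (n x n) @ (n x e/2) -> n^2 * e / 2."""
    return n * n * e // 2


def attention_macs(n: int, d: int) -> int:
    """Multiply-adds of dot-product self-attention (QK^T plus AV): 2 * n^2 * d."""
    return 2 * n * n * d


n, d = 512, 768      # illustrative sizes
e = 4 * d            # assumed expansion ratio for U (d -> e)

print(sgu_macs(n, e), attention_macs(n, d))   # 402653184 402653184
print(sgu_macs(2 * n, e) // sgu_macs(n, e))   # 4: quadratic in n
print(sgu_macs(n, 2 * e) // sgu_macs(n, e))   # 2: linear in channels
```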

Results 🧐

Image 🌄 Classification

The authors evaluate gMLP on ImageNet image classification without using extra data. The input and output protocols follow ViT/B16, where the raw image is converted into 16×16 patches at the stem. Depth and width are chosen so that the models are comparable in capacity with ViT/DeiT. They find that, like Transformers, gMLPs tend to drastically overfit the training data. They therefore apply a regularization recipe similar to the one used in DeiT and, to avoid extensive tuning, adjust only the strength of stochastic depth as they move from smaller to larger models.
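Here is a minimal NumPy sketch of the patch stem described above. It assumes a 224×224 RGB input and a hypothetical token width of 256 for the linear projection, with random placeholder weights. Each 16×16×3 tile is flattened into a 768-dimensional vector, giving a 14×14 grid of n = 196 tokens that a gMLP can then treat like a sequence.

```python
import numpy as np


def patchify(image, patch=16):
    """Split an (H, W, C) image into non-overlapping patch x patch tiles,
    each flattened into one token: returns (num_patches, patch * patch * C)."""
    H, W, C = image.shape
    assert H % patch == 0 and W % patch == 0
    x = image.reshape(H // patch, patch, W // patch, patch, C)
    x = x.transpose(0, 2, 1, 3, 4)             # (rows, cols, patch, patch, C)
    return x.reshape(-1, patch * patch * C)    # row-major order over the patch grid


rng = np.random.default_rng(0)
image = rng.random((224, 224, 3))              # assumed 224x224 RGB input
tokens = patchify(image)                       # (196, 768): 14 x 14 patches
W_stem = rng.normal(size=(tokens.shape[1], 256)) * 0.02  # hypothetical width d = 256
X = tokens @ W_stem                            # (n, d) token representations
print(tokens.shape, X.shape)                   # (196, 768) (196, 256)
```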
| Model | ImageNet Top-1 (%) | Params (M) |
|---|---|---|
| DeiT-Ti (ViT+reg) | 72.2 | 5 |
| DeiT-S (ViT+reg) | 79.8 | 22 |
| DeiT-B (ViT+reg) | 81.8 | 86 |
| gMLP-Ti | 72.3 | 6 |
| gMLP-S | 79.6 | 20 |
| gMLP-B | 81.6 | 73 |
According to the paper's results, gMLPs are comparable with DeiT, i.e. ViT trained with improved regularization. More importantly, as the models scale up, gMLP keeps pace with fewer parameters: gMLP-B reaches 81.6% top-1 versus 81.8% for DeiT-B while using about 13M fewer parameters (73M vs. 86M).

Masked Language Modeling (MLM) 💬

Best of both worlds❔

Upon experimenting with MLM, the authors identified NLP finetuning tasks where gMLPs transfer less well than Transformers. In fact, the MLP-like model is advantageous on SST-2 but worse on MNLI. This is particularly informative: the former is a single-sentence task, whereas the latter involves sentence pairs. The authors suspect that the role of self-attention during finetuning is related to cross-sentence alignment. To isolate the effect of self-attention, they experiment with a hybrid model in which a tiny self-attention block is attached to the gating function of gMLP. This hybrid model is referred to as aMLP (“a” for attention).
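A minimal NumPy sketch of the aMLP idea follows. It mirrors the `gate_res` argument in the `gMLPBlock` code above: a tiny single-head attention reads the block input `X`, and its output is added to the spatial gate before the element-wise product. The head size, the random placeholder weights, the parameter-free layer norm, and computing `Z` as `X @ U` without an activation are all simplifying assumptions.

```python
import numpy as np


def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)   # numerical stability
    ex = np.exp(a)
    return ex / ex.sum(axis=axis, keepdims=True)


def layer_norm(x, eps=1e-5):
    # Normalize each token over its channels (no learned scale/shift).
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)


def tiny_attention(X, W_qkv, W_out):
    """Single-head self-attention over the block input X: (n, d) -> (n, e/2)."""
    q, k, v = np.split(X @ W_qkv, 3, axis=-1)
    A = softmax(q @ k.T / np.sqrt(q.shape[-1]))   # (n, n) attention map
    return (A @ v) @ W_out


def amlp_sgu(X, Z, W, b, attn=None):
    """SGU gate W @ norm(Z2) + b; with `attn` given, add tiny attention (aMLP)."""
    Z1, Z2 = np.split(Z, 2, axis=-1)
    gate = W @ layer_norm(Z2) + b                 # spatial (token-mixing) gate
    if attn is not None:
        gate = gate + tiny_attention(X, *attn)    # plays the role of gate_res
    return Z1 * gate


rng = np.random.default_rng(0)
n, d, e, d_head = 6, 8, 16, 4                     # illustrative sizes
X = rng.normal(size=(n, d))
U = rng.normal(size=(d, e)) / np.sqrt(d)
Z = X @ U                                         # activation omitted for brevity
W = rng.uniform(-1e-3 / n, 1e-3 / n, size=(n, n))
b = np.ones((n, 1))
W_qkv = rng.normal(size=(d, 3 * d_head)) / np.sqrt(d)
W_out = rng.normal(size=(d_head, e // 2)) / np.sqrt(d_head)

out_gmlp = amlp_sgu(X, Z, W, b)                   # plain gMLP gating
out_amlp = amlp_sgu(X, Z, W, b, attn=(W_qkv, W_out))
print(out_gmlp.shape, out_amlp.shape)             # (6, 8) (6, 8)
```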
The authors use the full English C4 dataset and adopt a common MLM setup with batch size 256, max length 512 and 1M training steps. As with vision, they adjust the depth and width of gMLPs to ensure comparable model capacity with the Transformer baselines.
| Model | Perplexity | SST-2 | MNLI (m/mm) | Params (M) |
|---|---|---|---|---|
| BERT-base | 4.17 | 93.8 | 85.6/85.7 | 110 |
| gMLP-base | 4.28 | 94.2 | 83.7/84.1 | 130 |
| aMLP-base | 3.95 | 93.4 | 85.9/85.8 | 109 |
| BERT-large | 3.35 | 94.3 | 87.0/87.4 | 336 |
| gMLP-large | 3.32 | 94.8 | 86.2/86.5 | 365 |
| aMLP-large | 3.19 | 94.8 | 88.4/88.4 | 316 |
As in vision, gMLPs are comparable to Transformers in terms of perplexity, especially as we scale up. The authors show that blending in a tiny single-head self-attention is sufficient to make gMLPs outperform Transformers of similar capacity, sometimes by a significant margin (aMLP-large reaches 88.4/88.4 on MNLI versus 87.0/87.4 for BERT-large). The authors suggest that:
“the capacity in the multi-head self-attention of Transformers can be largely redundant, and that the majority of its functionalities can be captured by the spatial gating unit in gMLPs”

Conclusion

Now, calm down! I'm not a non-believer in Transformers or attention, but let's slow down and think about whether attention is necessary in the first place. Nobody can deny the sheer impact Transformers have had since June 2017. Beyond overtaking RNNs in machine translation and almost replacing convolutions in 2020, Transformers have deeply impacted Biological Sequence Analysis (facebookresearch/esm) and Video Understanding (facebookresearch/TimeSformer).
But it remains an open question whether the inductive bias of self-attention is essential to the effectiveness of Transformers. It is still unclear what drives their success: is it the feed-forward nature of Transformers, or their multi-head self-attention layers?
Hopefully, such questions will be answered as we take a more careful look at attention and whether it really is behind the success of Transformers.