Fourier Transform in Neural Networks ??!!
This report continues the recent series analyzing newly proposed pure MLP-based architectures. Here I break down "FNet: Mixing Tokens with Fourier Transforms" by James Lee-Thorp, Joshua Ainslie, Ilya Eckstein and Santiago Ontañón.
Created on July 7|Last edited on July 10
This report is a part of a series of reports covering pure MLP based architectures. Other reports in this series are:-
Paper | Github Repository + Python Package | Model Checkpoints
👋🏻 Introduction
Let's be honest: the past few years have been all about Attention. As the authors say, Self-Attention is an
"inductive bias that connects each token in the input through a relevance weighted basis of every other token"
Using Attention, each hidden unit is represented in the basis of the hidden units of the previous layer. There have been tons of variants of Attention proposed in the last 4 years, each one allowing for the capture of diverse syntactic and semantic relationships.
But we still have no answer to the question of whether the inductive bias in self-attention is essential to the effectiveness of Transformers. It is still unclear what drives their success: is it the feed-forward nature of Transformers, or is it the multi-head self-attention layers?
Also, the quadratic time complexity of attention makes it very hard to apply Transformers successfully in various domains. This was one of the primary reasons why Transformers couldn't easily be applied to vision. Many papers, such as Linformer and Longformer, have since attempted to chip away at this time complexity. But at last, some people have started to find a way out of this infinite well. Recently a lot of research has been published around MLPs without attention or convolutions (Collection of recent MLP based work), with interesting new methods like RepMLP, ResMLP (CVPR 2021), Can Attention Enable MLPs To Catch Up With CNNs? (CVM 2021) and Are Pre-trained Convolutions Better than Pre-trained Transformers? (ACL 2021). This newly awakened interest seems to be widely shared, with giants like Google and Facebook and universities like Tsinghua and Monash submitting multiple papers on this topic to top-tier conferences.
Model Architecture

Figure 1: An overview of the FNet architecture.
The FNet architecture is still a standard Transformer-based architecture, but with a Fourier sublayer instead of Attention. Let's break down the architecture into its various subparts.
1️⃣ The FeedForward Module

Figure 2: Comparison between FeedForward blocks from MLP-Mixer and FNet
The FeedForward block in this paper is very similar to that of MLP-Mixer: it consists of two fully connected layers with a non-linearity (GELU) in between. The only difference is that FNet adds a dropout layer at the end.
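Below is an implementation of the FeedForward module in Flax. (For the entire code see the associated GitHub repository.)

```python
import jax.numpy as jnp
from flax import linen as nn

Array = jnp.ndarray  # type alias assumed here; the repository may define its own


class FeedForward(nn.Module):
    dim: int
    expansion_factor: int
    dropout_rate: float

    def setup(self):
        # Expand to expansion_factor * dim, then project back to dim.
        self.fc1 = nn.Dense(features=self.expansion_factor * self.dim)
        self.fc2 = nn.Dense(features=self.dim)
        self.drop = nn.Dropout(rate=self.dropout_rate)

    def __call__(self, x, deterministic=False) -> Array:
        out = self.fc1(x)
        out = nn.gelu(out)
        out = self.fc2(out)
        # Dropout is active unless deterministic=True (it then needs a "dropout" RNG).
        output = self.drop(out, deterministic=deterministic)
        return output
```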
2️⃣ The Encoder Block
This is the main part of the FNet architecture. It is similar to the Mixer block from the MLP-Mixer architecture, except that the authors replace the token-mixing block with Fourier transforms.
A Fourier transform is a wonderful and, quite honestly, beautiful concept from mathematics which essentially allows us to decompose a function into its constituent frequencies. For a brief introduction to the Fourier transform, check out this amazing video from 3Blue1Brown.
Each Fourier sublayer applies a 2D Discrete Fourier Transform (DFT) to its (sequence length, hidden dimension) embedding input: one 1D DFT along the sequence dimension and one 1D DFT along the hidden dimension. To avoid making structural changes to the other layers to accommodate complex numbers, only the real part of the result is kept.
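To make the sublayer concrete, here is a minimal sketch in plain JAX. The function names `fourier_mixing` and `fourier_mixing_two_step` and the toy input sizes are my own choices for illustration, not names from the paper or the repository. The sketch checks that a single `fft2` call over the last two axes gives the same result as two separate 1D DFTs, and that the output is real with the same shape as the input.

```python
import jax.numpy as jnp


def fourier_mixing(x):
    """Parameter-free mixing: 2D DFT over (sequence, hidden), real part only."""
    return jnp.real(jnp.fft.fft2(x, axes=(-1, -2)))


def fourier_mixing_two_step(x):
    """Same computation written as two 1D DFTs: hidden axis, then sequence axis."""
    hidden_mixed = jnp.fft.fft(x, axis=-1)
    return jnp.real(jnp.fft.fft(hidden_mixed, axis=-2))


# Toy embedding: 8 tokens with 4 hidden features (sizes chosen only for illustration).
x = jnp.arange(32, dtype=jnp.float32).reshape(8, 4) / 10.0

mixed = fourier_mixing(x)
mixed_two_step = fourier_mixing_two_step(x)

# Same shape as the input, and real-valued, so the following layers need no changes.
print(mixed.shape, mixed.dtype)
print(bool(jnp.allclose(mixed, mixed_two_step, atol=1e-3)))
```

Note that there is nothing to train here: the sublayer has no parameters at all.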
What's the intuition, you ask?
Well, according to the authors, Fourier transforms are an effective mechanism for mixing tokens, giving the feed-forward sublayers access to all tokens. Also, because of the duality of the Fourier transform, each alternating encoder block can be viewed as applying alternating Fourier and inverse Fourier transforms, moving the input back and forth between the time and frequency domains.
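Both claims can be checked on a toy input. The sketch below is my own illustration (the helper name `fourier_mix` and all numbers are assumptions, not taken from the paper): it perturbs a single token and confirms that every output position changes, then checks the duality property that applying the DFT twice along the sequence axis returns the input with its order reversed and scaled by the sequence length. The duality check uses the full complex DFT; since FNet keeps only the real part, it holds only approximately as an intuition inside the real model.

```python
import jax.numpy as jnp


def fourier_mix(x):
    """2D DFT over (sequence, hidden), keeping the real part."""
    return jnp.real(jnp.fft.fft2(x, axes=(-1, -2)))


# Toy input: 8 tokens, 4 hidden features.
x = jnp.arange(32, dtype=jnp.float32).reshape(8, 4) / 10.0

# 1) Token mixing: change only token 3 and look at how much each output row moves.
delta = jnp.zeros_like(x).at[3].set(jnp.array([1.0, 0.5, -0.25, 2.0]))
change = jnp.abs(fourier_mix(x + delta) - fourier_mix(x))
rows_touched = change.max(axis=-1)  # largest change per output token
print(rows_touched)  # every entry is clearly non-zero

# 2) Duality: DFT applied twice gives n * x[(-m) mod n] along the transformed axis.
n = x.shape[0]
twice = jnp.fft.fft(jnp.fft.fft(x, axis=0), axis=0)
reversed_x = jnp.roll(x[::-1], shift=1, axis=0)  # row m holds x[(-m) mod n]
print(bool(jnp.allclose(jnp.real(twice), n * reversed_x, atol=1e-3)))
```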
Multiplying by the feed-forward sublayer coefficients in the frequency domain is equivalent to convolving (by a related set of coefficients) in the time domain. Thus, FNet can be thought of as alternating between matrix multiplications and (large-kernel) convolutions.
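The underlying fact is the convolution theorem. Here is a small, self-contained check in JAX for the 1D circular case; the helper names `circular_conv` and `conv_via_fft` and the example vectors are mine, picked purely for illustration. It is a sketch of the theorem itself, not of how FNet is implemented.

```python
import jax.numpy as jnp


def circular_conv(x, w):
    """Direct circular convolution: y[m] = sum_k x[k] * w[(m - k) mod n]."""
    n = x.shape[0]
    idx = (jnp.arange(n)[:, None] - jnp.arange(n)[None, :]) % n  # idx[m, k] = (m - k) mod n
    return (x[None, :] * w[idx]).sum(axis=1)


def conv_via_fft(x, w):
    """Same convolution, computed as an element-wise product in the frequency domain."""
    return jnp.real(jnp.fft.ifft(jnp.fft.fft(x) * jnp.fft.fft(w)))


x = jnp.array([1.0, 2.0, 3.0, 4.0])
w = jnp.array([0.5, 0.25, 0.0, 0.0])

direct = circular_conv(x, w)   # [1.5, 1.25, 2.0, 2.75]
via_fft = conv_via_fft(x, w)   # matches `direct` up to floating-point error
print(direct, via_fft)
```

Because the product in the frequency domain touches every frequency, the equivalent kernel in the time domain can span the whole sequence, which is where the "large kernel" description comes from.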
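Below is an implementation of the Encoder Block in Flax. (For the entire code see the associated GitHub repository.)

```python
import jax.numpy as jnp
from flax import linen as nn
from jax import lax


class EncoderBlock(nn.Module):
    d_hidden: int = 512

    def setup(self):
        self.ff = FeedForward(dim=self.d_hidden, expansion_factor=4, dropout_rate=0.1)

    @nn.compact
    def __call__(self, x):
        # Fourier sublayer: 2D DFT over the last two axes, keeping only the real part.
        x_fft = lax.real(jnp.fft.fft2(x, axes=(-1, -2)))
        x = nn.LayerNorm(name="LN1")(x + x_fft)
        # Feed-forward sublayer, again with a residual connection and LayerNorm.
        x_ff = self.ff(x)
        x = nn.LayerNorm(name="LN2")(x + x_ff)
        return x
```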
🏠 Complete Model
Around the stack of encoder blocks, the architecture has an Embedding Block (identical to BERT's) that pre-processes the tokens, and an Output Head.
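Below is an implementation of the FNet architecture in Flax. (For the entire code see the associated GitHub repository.)

```python
import jax.numpy as jnp
from flax import linen as nn

Array = jnp.ndarray  # type alias assumed here; the repository may define its own


class FNet(nn.Module):
    depth: int
    dim: int

    def setup(self):
        # A stack of `depth` encoder blocks followed by a dense output layer.
        self.layers = [EncoderBlock(d_hidden=self.dim) for _ in range(self.depth)]
        self.dense = nn.Dense(features=self.dim)

    def __call__(self, x) -> Array:
        for layer in self.layers:
            x = layer(x)
        output = self.dense(x)
        return output
```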
🧐 Some Interesting Results
The authors performed quite an extensive study comparing FNet with:
- Linear encoder: where the self-attention sublayers are replaced with two learnable, dense, linear sublayers, one applied to the hidden dimension and one applied to the sequence dimension.
- Random encoder: where the self-attention sublayers are replaced with two constant random matrices, one applied to the hidden dimension and one applied to the sequence dimension.
- FF-only encoder: where the self-attention sublayers are removed from the Transformer layers [ No Token Mixing !! ]
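One way to see how these baselines relate to FNet is to write the mixing step as a pair of matrix multiplications, one on each side of the input. The sketch below is my own simplification (no biases, no training loop, and the names `two_sided_mix` and `dft_matrix` are hypothetical). It shows that the Fourier sublayer is the same operation with the two matrices fixed to DFT matrices.

```python
import jax
import jax.numpy as jnp


def two_sided_mix(x, w_seq, w_hidden):
    """Mix tokens with w_seq (seq x seq) and features with w_hidden (hidden x hidden)."""
    return w_seq @ x @ w_hidden


def dft_matrix(n):
    """n x n DFT matrix with entries exp(-2*pi*i*j*k / n)."""
    k = jnp.arange(n)
    return jnp.exp(-2j * jnp.pi * jnp.outer(k, k) / n)


seq_len, hidden = 8, 4
key_x, key_seq, key_hidden = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key_x, (seq_len, hidden))

# Random encoder: the matrices are sampled once and then kept constant.
# Linear encoder: the same two matrices would be trainable parameters instead.
w_seq = jax.random.normal(key_seq, (seq_len, seq_len))
w_hidden = jax.random.normal(key_hidden, (hidden, hidden))
random_mixed = two_sided_mix(x, w_seq, w_hidden)

# FNet: the "weights" are fixed DFT matrices, and only the real part is kept.
fourier_mixed = jnp.real(two_sided_mix(x, dft_matrix(seq_len), dft_matrix(hidden)))
print(bool(jnp.allclose(fourier_mixed, jnp.real(jnp.fft.fft2(x)), atol=1e-3)))

# FF-only encoder: no mixing sublayer, so tokens never exchange information.
ff_only_mixed = x
```

In practice the FFT computes the Fourier case without ever building these matrices, which is part of why it is cheap.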
GLUE Metrics
| Model | Average GLUE score |
|---|---|
| BERT-Base | 83.3 |
| Linear-Base | 77.0 |
| FNet-Base | 76.7 |
| Random-Base | 56.6 |
| FF-only-Base | 49.3 |
Let's put FNet aside for a while and appreciate the performance of the "Linear" encoder model. In the authors' experiments, adding a softmax or multiple head projections (such as those in Synthesizer) to this model brought no major accuracy benefit and slightly slowed down training.
They further show that the Linear model can outperform FNet on certain tasks but has several drawbacks relative to FNet: it is slower on GPUs, has a much larger memory footprint, and is prone to instability during training.

Figure 3: Speed-accuracy trade-offs for GPU pre-training. The dotted line shows the Pareto efficiency frontier.
The speed vs MLM accuracy curve for GPU pre-training is shown in Figure 3. For larger, slower models, BERT defines the Pareto efficiency frontier. But for smaller, faster models, FNet and the Linear model define the efficiency frontier, indicating a better trade-off between speed and accuracy than BERT. So although a “Mini” FNet is larger than a “Micro” BERT, it is both more accurate and faster to train.
Conclusion
In this paper, the authors introduce FNet, a new architecture inspired by Transformers in which the learnable Self-Attention layer is replaced by a non-learnable Fourier Transform. This work highlights the potential of linear units as a drop-in replacement for the attention mechanism. The authors found Fourier Transforms to be a particularly effective mixing mechanism, in part due to the highly efficient FFT. Remarkably, this unparameterized mixing mechanism can yield relatively competitive models.
This work, among many others, is particularly important for deploying Transformer-like models in low-compute scenarios.