
The Problem with Quadratic Attention in Transformer Architectures

This report provides a brief overview of the problem with vanilla self-attention and explains its quadratic nature.
Attention forms the basis of modern-day NLP and is the key ingredient behind the success of transformer architectures. Not only is the notion of attention simple and intuitive to understand, but it also provides great interpretability, something perhaps overlooked in this day and age. However, when we try to use these models for inference, a barrier arises: the quadratic nature of the all-important attention layers.
In this article, we'll look at why this quadratic nature emerges and explore the various ways in which academics and companies have tried to solve the problem.

Why Quadratic?

NOTE: I won't reinvent the wheel and explain attention again; instead, we'll refer to other articles that cover the paradigm in detail.

In the vanilla ("standard") attention mechanism, every token computes an attention weight for every other token in the sequence. That means there are $n$ such weights per token, where $n$ is the number of tokens in the sequence. In total, we therefore have to compute $n \times n$ weights, leading to $\mathcal{O}(n^2)$ space and time complexity.
From basic complexity theory, we can see why this is troublesome, especially now that sequence lengths have grown to tens and even hundreds of thousands of tokens. So how do we fix this quadratic complexity?
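To make the $n \times n$ cost concrete, here is a minimal NumPy sketch of full self-attention for a single head. It is illustrative only: there is no batching, no learned projections, and the sequence length and head dimension are arbitrary choices rather than values from any particular model.

```python
# Minimal single-head self-attention sketch (NumPy), illustrating where
# the O(n^2) cost comes from: the (n, n) score matrix.
import numpy as np

def vanilla_attention(Q, K, V):
    """Q, K, V: (n, d) arrays for a sequence of n tokens with head dimension d."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)                 # (n, n) matrix -> O(n^2) memory and time
    scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V                            # (n, d) output

n, d = 4096, 64                                   # illustrative sizes
Q = K = V = np.random.randn(n, d).astype(np.float32)
out = vanilla_attention(Q, K, V)
print(out.shape)  # (4096, 64) -- but the intermediate (4096, 4096) score matrix dominates memory
```

Doubling the sequence length quadruples the size of the score matrix, which is exactly the scaling problem the methods below try to avoid.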

Proposed Solutions

In recent years, many "linear attention" solutions have been proposed. Let's look at some of them in a bit of detail:
  • Reformer (The Efficient Transformer): This method uses locality-sensitive hashing (LSH) to reduce complexity. Reformer uses a hash function to bucket/chunk related tokens together, matching similar vectors with each other and thereby avoiding a redundant search over the entire sequence. Attention is then applied within these much smaller chunks, reducing the quadratic attention to almost linear complexity!
Figure 1: Locality Sensitive Hashing as used in Reformer.
  • Sparse Transformer: Perhaps the simplest way to reduce quadratic attention, this implementation by OpenAI limits the set of tokens that a particular token can attend to, reducing the complexity to $\mathcal{O}(n \sqrt{n})$.
Figure 2: Comparison of Vanilla Attention, Strided (Sparse) Attention, and Fixed Attention.
  • BigBird: Another formulation, from Google Research, which aims to provide a linear implementation of attention by combining several attention patterns: a few tokens learn representations globally by attending across the whole sequence, while the rest learn representations in local neighborhoods.
Figure 3: Comparison of Random Attention, Windowed Attention, Global Attention and BigBird Attention.
  • Longformer: This formulation uses the classic sliding-window technique to make attention linear in the sequence length for a fixed window width. Every token only has visibility over a fixed-width window around it, which reduces the complexity; other variants, such as dilated sliding-window attention, have also been proposed. A small sketch of a sliding-window mask follows this list.
Figure 4: Comparison of Vanilla, Sliding, Dilated Sliding and Global + Sliding Attention.
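To see how these sparse patterns cut the cost, here is a rough sketch of a Longformer-style sliding-window mask. The sequence length and window width are arbitrary illustrative values, and `sliding_window_mask` is a hypothetical helper written for this article, not an API from any of the papers or libraries above.

```python
# Sketch of a sliding-window attention mask: each token may only attend to
# tokens within w/2 positions of itself, so the number of allowed pairs grows
# roughly as n * w instead of n * n.
import numpy as np

def sliding_window_mask(n, w):
    """Boolean (n, n) mask: True where token i is allowed to attend to token j."""
    idx = np.arange(n)
    return np.abs(idx[:, None] - idx[None, :]) <= w // 2

n, w = 16, 4                 # illustrative sizes
mask = sliding_window_mask(n, w)
print(mask.sum())            # roughly n * (w + 1) allowed positions, rather than n * n
print(mask.sum() / n**2)     # fraction of the full attention matrix actually used
```

Combining such a local mask with a handful of globally-attending tokens, as BigBird and Longformer's global attention do, keeps the overall cost linear in the sequence length while still letting information propagate across the whole sequence.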

Summary

In this article, we learnt why vanilla attention is quadratic in nature and how various papers over the years have tried to fix this problem.
If you want more reports covering the math behind attention, let us know in the comments down below or on our forum ✨!
Check out these other reports on Fully Connected covering other fundamental topics like Attention and Quantisation.
