
A Deep Dive Into the Transformer Architecture

This article takes a deep dive into the breakthroughs, scientific basis, formulas, and code for the transformer architecture, as outlined in "Attention Is All You Need".
After language model breakthroughs like ELMo and ULMFiT, transformers have taken the Natural Language Processing (NLP) field by storm. They are the basis of popular models like GPT-3 and DALL-E. Moreover, tools like the HuggingFace Transformers library have made it easy for machine learning engineers to solve a wide range of NLP tasks and have facilitated many subsequent breakthroughs in NLP.
In this article, we'll take a deep dive into the transformer architecture as described in the paper "Attention Is All You Need" (2017) and code it up using PyTorch.


Tokenization

First of all, we need a way to represent text numerically in order to perform calculations on it. Tokenization is the process of parsing a string of text into a compressed sequence of symbols. This process results in a vector of integers where each integer represents part of the text. The transformer paper uses Byte-Pair Encoding (BPE) as the tokenization method. BPE is a form of compression in which the most common pair of consecutive bytes (i.e. characters) is repeatedly merged into a single new symbol (i.e. integer).
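To make the idea concrete, here is a minimal sketch of a single BPE merge step (the function names are mine, not from any library): count the most frequent pair of adjacent symbols and merge it into one new symbol. Real BPE repeats this until a target vocabulary size is reached.

from collections import Counter

def most_frequent_pair(symbols):
    # Count all pairs of adjacent symbols and return the most common one
    return Counter(zip(symbols, symbols[1:])).most_common(1)[0][0]

def merge(symbols, pair):
    # Replace every occurrence of `pair` with a single merged symbol
    merged, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
            merged.append(symbols[i] + symbols[i + 1])
            i += 2
        else:
            merged.append(symbols[i])
            i += 1
    return merged

symbols = list("low lower lowest")
pair = most_frequent_pair(symbols)   # e.g. ('l', 'o')
symbols = merge(symbols, pair)       # 'l' and 'o' are now one symbol: 'lo'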
Recent research suggests BPE is suboptimal, and recent language models like BERT use a WordPiece tokenizer instead. WordPiece is easier to decode and more intuitive, as it often tokenizes entire words into one token, whereas BPE often tokenizes slices of words. This can lead to strange tokenization symbols and to individual characters (i.e. letters from the alphabet) being tokenized as unknown tokens. Therefore, WordPiece is the tokenizer we use in the example below. We could train our own tokenizer on a corpus of text, but in practice practitioners almost always use a pre-trained tokenizer. The HuggingFace Transformers library makes it easy to work with pre-trained tokenizers.

>>> from transformers import BertTokenizer
>>> tok = BertTokenizer.from_pretrained("bert-base-uncased")
Downloading: 100%|█████████████████████████| 1.04M/1.04M [00:00<00:00, 1.55MB/s]
Downloading: 100%|████████████████████████████| 456k/456k [00:00<00:00, 848kB/s]

>>> tok("Hello, how are you doing?")['inputs_ids']
{'input_ids': [101, 7592, 2129, 2024, 2017, 2725, 1029, 102]}

>>> tok("The Frenchman spoke in the [MASK] language and ate 🥖")['input_ids']
{'input_ids': [101, 1996, 26529, 3764, 1999, 1996, 103, 2653, 1998, 8823, 100, 102]}

>>> tok("[CLS] [SEP] [MASK] [UNK]")['input_ids']
{'input_ids': [101, 101, 102, 103, 100, 102]}
Note that the tokenizer automatically includes a token for the start of an encoding ([CLS] == 101) and the end of an encoding (i.e. SEParation) ([SEP] == 102). Other special tokens include masking ([MASK] == 103) and an unknown symbol ([UNK] == 100, e.g. for the 🥖 emoji).
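To sanity-check a tokenization, you can map the IDs back to text with the tokenizer's decode method (continuing with the tok object from above); the exact output formatting may vary slightly between library versions.

# Should give back something like '[CLS] hello how are you doing? [SEP]'.
# Pass skip_special_tokens=True to drop [CLS] and [SEP].
>>> tok.decode([101, 7592, 2129, 2024, 2017, 2725, 1029, 102])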

Embeddings

In order to learn proper representations of text, each individual token in the sequence is converted to a vector through an embedding. It can be seen as a type of neural network layer, because the weights for the embeddings are learned along with the rest of the transformer model. It contains a vector for each word in the vocabulary, and these weights are initialized from a normal distribution $\mathcal{N}(0, 1)$. Note that we have to specify the size of the vocabulary ($|\text{vocab}|$) and the dimension of the model ($d_\text{model} = 512$) when initializing the embedding ($E \in \mathbb{R}^{|\text{vocab}| \times d_\text{model}}$). Lastly, the weights are multiplied by $\sqrt{d_\text{model}}$ as a normalization step.
import math

import torch
from torch import nn

class Embed(nn.Module):
    def __init__(self, vocab: int, d_model: int = 512):
        super(Embed, self).__init__()
        self.d_model = d_model
        self.vocab = vocab
        self.emb = nn.Embedding(self.vocab, self.d_model)
        # Scale the embeddings by sqrt(d_model)
        self.scaling = math.sqrt(self.d_model)

    def forward(self, x):
        return self.emb(x) * self.scaling

Positional Encoding

In contrast to recurrent and convolutional networks, the model in itself has no information about the relative position of the embedded tokens in a sequence. Therefore, we have to inject this information by adding an encoding to the input embeddings of the encoder and decoder. This information can be added in many different ways and can be static or learned. The transformer uses sine and cosine transformations for each position ($\text{pos}$). Sine is used for the even dimensions ($2i$) and cosine for the odd dimensions ($2i+1$).
$PE_{\text{pos}, 2i} = \sin\left(\frac{\text{pos}}{10000^{2i / d_\text{model}}}\right)$
$PE_{\text{pos}, 2i+1} = \cos\left(\frac{\text{pos}}{10000^{2i / d_\text{model}}}\right)$
In the code, the positional encodings are computed in log space to avoid numerical overflow: the term $\frac{1}{10000^{2i / d_\text{model}}}$ is computed as $\exp\left(-\frac{2i}{d_\text{model}} \log 10000\right)$.
import math

import torch
from torch import nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int = 512, dropout: float = .1, max_len: int = 5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # Compute the positional encodings in log space
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        # Register as a buffer: part of the module's state, but not a trainable parameter
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Add the (fixed) positional encodings to the input embeddings
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

Multi-Head Attention

Before transformers, the dominant paradigm for learning from sequences was to use either convolutions (WaveNet, ByteNet) or recurrence (RNN, LSTM). Attention had already enabled some NLP breakthroughs before transformers (Luong et al., 2015), but it was not obvious back then that you could build effective models without convolutions or recurrence. Therefore, proposing that "Attention Is All You Need" was quite a bold statement.
The attention layer can learn a mapping between a query ($Q$) and a set of key ($K$) value ($V$) pairs. The meaning of these names can be confusing, because they depend on the particular NLP application. For developing our transformer you can just think of them as linear projections of the input. The names query, key, and value come from traditional information retrieval, and the following example briefly illustrates where they come from.
When you search for a video on YouTube, you type a phrase into the search bar (i.e. a query $Q$). The search engine maps this against YouTube video titles, descriptions, etc. (i.e. keys $K$). Using this mapping, it suggests the most relevant videos to you (i.e. values $V$). (Example source)
One innovation that boosted the performance of attention in NLP is what the authors call "scaled dot-product attention". It is the same as multiplicative attention, except that the dot product of $Q$ and $K$ is scaled by $\frac{1}{\sqrt{d_k}}$, where $d_k$ is the key dimension. This makes multiplicative attention perform better for larger dimensions. The result is put through a softmax activation ($\text{softmax}(x_i) = \frac{\exp(x_i)}{\sum_{j} \exp(x_j)}$) and multiplied by $V$.
$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$
import math

import torch
from torch import nn

class Attention(nn.Module):
    def __init__(self, dropout: float = 0.):
        super(Attention, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query, key, value, mask=None):
        d_k = query.size(-1)
        # Scaled dot-product attention
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Fill masked-out positions with a very large negative number
            scores = scores.masked_fill(mask == 0, -1e9)
        p_attn = self.dropout(self.softmax(scores))
        return torch.matmul(p_attn, value)
In the decoder, an attention sub-layer is masked by filling certain positions with a very large negative number ($-1\mathrm{e}{9}$ or $-\infty$). This is to prevent the model from cheating by attending to subsequent positions. It ensures that the model can only attend to words at previous positions when it tries to predict the next token.
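The paper doesn't spell out how this mask is constructed. A minimal sketch using torch.tril, matching the mask == 0 convention in the Attention layer above, could look like this (the helper name is my own; a similar subsequent_mask helper appears in The Annotated Transformer):

import torch

def subsequent_mask(size: int) -> torch.Tensor:
    # Lower-triangular matrix of shape (1, size, size):
    # position i is only allowed to attend to positions <= i.
    return torch.tril(torch.ones(1, size, size)).bool()

# subsequent_mask(3) gives:
# tensor([[[ True, False, False],
#          [ True,  True, False],
#          [ True,  True,  True]]])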
The mechanism of attention in itself is already very powerful and can be calculated efficiently on modern hardware like GPUs and TPUs, which are optimized for matrix multiplication. However, a single attention layer only allows for one representation. Therefore, the transformer uses multiple attention heads, which allows the model to learn multiple patterns and representations. The paper uses $h = 8$ attention heads, whose outputs are concatenated. The final formula becomes:
$\text{MultiHead}(Q, K, V) = \text{concat}(\text{head}_1, \dots, \text{head}_h)W^O$
where $\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$
The projection weights ($W^O \in \mathbb{R}^{hd_v \times d_\text{model}}$, $W_i^Q \in \mathbb{R}^{d_\text{model} \times d_k}$, $W_i^K \in \mathbb{R}^{d_\text{model} \times d_k}$, $W_i^V \in \mathbb{R}^{d_\text{model} \times d_v}$) are the parameter matrices of fully-connected (Linear) layers. The authors of the transformer paper use $d_k = d_v = \frac{d_\text{model}}{h} = 64$.
from torch import nn

class MultiHeadAttention(nn.Module):
    def __init__(self, h: int = 8, d_model: int = 512, dropout: float = 0.1):
        super(MultiHeadAttention, self).__init__()
        self.d_k = d_model // h
        self.h = h
        self.attn = Attention(dropout)
        self.lindim = (d_model, d_model)
        # One projection each for the queries, keys and values
        self.linears = nn.ModuleList([nn.Linear(*self.lindim) for _ in range(3)])
        self.final_linear = nn.Linear(*self.lindim, bias=False)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            # The same mask is applied to all h heads
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        # Project the inputs and split them into h heads: (nbatches, h, seq_len, d_k)
        query, key, value = [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
                             for l, x in zip(self.linears, (query, key, value))]
        x = self.attn(query, key, value, mask=mask)
        # Concatenate the heads and multiply by W^O
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.final_linear(x)
Technical note: the .contiguous method is added after .transpose, because .transpose shares its underlying memory storage with the original tensor, and calling .view afterwards requires a contiguous tensor (documentation). The .view method allows for efficient reshaping, slicing, and element-wise operations (documentation).
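As a tiny illustration of this behavior (toy tensors, nothing transformer-specific):

import torch

t = torch.arange(6).view(2, 3)
tt = t.transpose(0, 1)        # shares storage with t and is non-contiguous
# tt.view(6) would raise a RuntimeError here,
# so we first copy the data into a contiguous block of memory:
flat = tt.contiguous().view(6)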
Because the dimension of each head is divided by $h$, the total computation is similar to using one attention head with full dimensionality ($d_\text{model}$). However, with this approach the calculation can be parallelized across heads, which leads to massive speedups on modern hardware. This is one of the innovations that allowed for training effective language models without convolutions or recurrence.

Residuals and Layer Normalization

The AI research community discovered that concepts like residual connections and (batch) normalization improve performance, reduce training time and allow training of deeper networks. Therefore, the transformer is equipped with residual connections and normalization after every attention layer and every feed forward layer. Additionally, dropout is added in each layer for better generalization.

Normalization

Modern deep learning based computer vision models often feature batch normalization. However, this type of normalization is dependent on a large batch size and does not lend itself naturally to recurrence. The traditional transformer architecture has layer normalization instead. Layer normalization is stable even with small batch sizes ($\text{batch size} < 8$).
In order to calculate layer normalization, we first calculate the mean $\mu_i$ and standard deviation $\sigma_i$ separately for each sample $i$ in the minibatch, over its $K$ features.
$\mu_i = \frac{1}{K} \sum_{k=1}^{K} x_{i, k}$
$\sigma_i = \sqrt{\frac{1}{K} \sum_{k=1}^{K} (x_{i, k} - \mu_i)^2}$
Then, the normalization step is defined as:
$LN_{\gamma, \beta}(x_i) \equiv \gamma \frac{x_i - \mu_i}{\sigma_i + \epsilon} + \beta$
where $\gamma$ and $\beta$ are learnable parameters. A small number $\epsilon$ is added for numerical stability in case the standard deviation $\sigma_i$ is $0$.
import torch
from torch import nn

class LayerNorm(nn.Module):
    def __init__(self, features: int, eps: float = 1e-6):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        # Normalize over the feature (last) dimension
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

Residual

A residual connection adds the input of a layer (i.e. sub-layer) to its output. This allows for very deep networks, because the network can essentially 'skip' certain layers.
The final output of each layer will then be $\text{ResidualConnection}(x) = x + \text{Dropout}(\text{SubLayer}(\text{LayerNorm}(x)))$.
from torch import nn

class ResidualConnection(nn.Module):
    def __init__(self, size: int = 512, dropout: float = .1):
        super(ResidualConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))


Feed Forward

On top of every attention layer, a feed-forward network is added. It consists of two fully-connected layers with a $\text{ReLU}$ activation ($\text{ReLU}(x) = \max(0, x)$) and dropout on the inner layer. The standard dimensions used in the transformer paper are $d_\text{model} = 512$ for the input and output layers and $d_{ff} = 2048$ for the inner layer.
The full calculation becomes $\text{FeedForward}(x) = \max(0, xW_1 + B_1)W_2 + B_2$.
Note that PyTorch's Linear layer already includes the biases ($B_1$ and $B_2$) by default.
from torch import nn

class FeedForward(nn.Module):
    def __init__(self, d_model: int = 512, d_ff: int = 2048, dropout: float = .1):
        super(FeedForward, self).__init__()
        self.l1 = nn.Linear(d_model, d_ff)
        self.l2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.l2(self.dropout(self.relu(self.l1(x))))

Encoder - Decoder

Encoding

Now we have all the components to build the model encoder and decoder. A single encoder layer consists of a multi-head attention layer followed by a feed-forward network. As mentioned earlier, we also include residual connections and layer normalization.
$\text{Encoding}(x, \text{mask}) = \text{FeedForward}(\text{MultiHeadAttention}(x, \text{mask}))$
from torch import nn
from copy import deepcopy

class EncoderLayer(nn.Module):
    def __init__(self, size: int, self_attn: MultiHeadAttention, feed_forward: FeedForward, dropout: float = .1):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sub1 = ResidualConnection(size, dropout)
        self.sub2 = ResidualConnection(size, dropout)
        self.size = size

    def forward(self, x, mask):
        x = self.sub1(x, lambda x: self.self_attn(x, x, x, mask))
        return self.sub2(x, self.feed_forward)
The final transformer encoder from the paper consists of 6 identical encoder layers followed by layer normalization.
class Encoder(nn.Module):
    def __init__(self, layer, n: int = 6):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList([deepcopy(layer) for _ in range(n)])
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

Decoding

The decoder layer consists of a masked multi-head attention layer, followed by a multi-head attention layer that attends over the memory, which is the output of the encoder. Lastly, the result goes through a feed-forward network. Again, all these components include residual connections and layer normalization.
$\text{Decoding}(x, \text{memory}, \text{mask}_1, \text{mask}_2) = \text{FeedForward}(\text{MultiHeadAttention}(\text{MultiHeadAttention}(x, \text{mask}_1), \text{memory}, \text{mask}_2))$
from torch import nn
from copy import deepcopy

class DecoderLayer(nn.Module):
    def __init__(self, size: int, self_attn: MultiHeadAttention, src_attn: MultiHeadAttention,
                 feed_forward: FeedForward, dropout: float = .1):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sub1 = ResidualConnection(size, dropout)
        self.sub2 = ResidualConnection(size, dropout)
        self.sub3 = ResidualConnection(size, dropout)

    def forward(self, x, memory, src_mask, tgt_mask):
        x = self.sub1(x, lambda x: self.self_attn(x, x, x, tgt_mask))
        x = self.sub2(x, lambda x: self.src_attn(x, memory, memory, src_mask))
        return self.sub3(x, self.feed_forward)
As with the final encoder, the decoder in the paper also has 6 identical layers followed by layer normalization.
class Decoder(nn.Module):
    def __init__(self, layer: DecoderLayer, n: int = 6):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList([deepcopy(layer) for _ in range(n)])
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return self.norm(x)

With this higher level representation of the encoder and decoder we can easily formulate the final encoder-decoder block.
from torch import nn

class EncoderDecoder(nn.Module):
    def __init__(self, encoder: Encoder, decoder: Decoder,
                 src_embed: Embed, tgt_embed: Embed, final_layer: "Output"):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.final_layer = final_layer  # The Output module defined in the next section

    def forward(self, src, tgt, src_mask, tgt_mask):
        return self.final_layer(self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask))

    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)


Final Output

Lastly, the vector output of the decoder has to be transformed into the final output. For sequence-to-sequence problems like language translation, this is a probability distribution over the total vocabulary for each position. One fully-connected layer transforms the output of the decoder into a matrix of logits with the dimension of the target vocabulary. These numbers are transformed into a probability distribution over the vocabulary through a softmax activation function. In the code we use $\text{LogSoftmax}(x_i) = \log\left(\frac{\exp(x_i)}{\sum_{j} \exp(x_j)}\right)$, because it is faster and more numerically stable.
For example, let's say the translated sentence has $20$ tokens and the total vocabulary is $30000$ tokens. The resulting output will then be a matrix $M \in \mathbb{R}^{20 \times 30000}$. We can then take the $\arg\max$ over the last dimension to get a vector of output tokens $T \in \mathbb{N}^{20}$ that can be decoded into a string of text through a tokenizer.
$\text{Output}(x) = \text{LogSoftmax}(xW_1 + B_1)$
import torch
from torch import nn

class Output(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super(Output, self).__init__()
        self.l1 = nn.Linear(input_dim, output_dim)
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.l1(x)
        return self.log_softmax(logits)


Model Initialization

We build the transformer model with the same dimensions as in the paper. The initialization strategy is Xavier/Glorot initialization, which samples from a uniform distribution in the range $[-\frac{1}{\sqrt{n}}, \frac{1}{\sqrt{n}}]$. All parameters with more than one dimension (i.e. the weight matrices) are initialized this way; the remaining parameters keep PyTorch's defaults.
$\text{Xavier}(W) \sim U\left[-\frac{1}{\sqrt{n}}, \frac{1}{\sqrt{n}}\right]$
from torch import nn

def make_model(input_vocab: int, output_vocab: int, d_model: int = 512):
    encoder = Encoder(EncoderLayer(d_model, MultiHeadAttention(), FeedForward()))
    decoder = Decoder(DecoderLayer(d_model, MultiHeadAttention(), MultiHeadAttention(), FeedForward()))
    input_embed = nn.Sequential(Embed(vocab=input_vocab), PositionalEncoding())
    output_embed = nn.Sequential(Embed(vocab=output_vocab), PositionalEncoding())
    output = Output(input_dim=d_model, output_dim=output_vocab)
    model = EncoderDecoder(encoder, decoder, input_embed, output_embed, output)
    # Initialize parameters with Xavier uniform
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model
This function returns a PyTorch model that can be trained for sequence-to-sequence problems. Below is a dummy example of how to use it with a tokenized input and output. Let's say we have a vocabulary of just $10$ tokens for the input and output.
# Tokenized symbols for source and target.
>>> src = torch.tensor([[1, 2, 3, 4, 5]])
>>> src_mask = torch.tensor([[1, 1, 1, 1, 1]])
>>> tgt = torch.tensor([[6, 7, 8, 0, 0]])
>>> tgt_mask = torch.tensor([[1, 1, 1, 0, 0]])

# Create PyTorch model
>>> model = make_model(input_vocab=10, output_vocab=10)
# Do inference and take tokens with highest probability through argmax along the vocabulary axis (-1)
>>> result = model(src, tgt, src_mask, tgt_mask)
>>> result.argmax(dim=-1)
tensor([[6, 6, 4, 3, 6]])
The output is still very far off from the target, because the model's weights are randomly initialized at this point. Training these transformer models from scratch requires quite some computation: to train the base model, the authors of the original paper trained for 12 hours on 8 NVIDIA P100 GPUs, and their larger models took 3.5 days to train on 8 GPUs! I would advise using pre-trained transformer models and fine-tuning them for your application. The HuggingFace Transformers library already has many pre-trained models available for fine-tuning.
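As a sketch of what that can look like (the checkpoint name, task head, and example sentence below are placeholder assumptions, not part of this article's model), loading a pre-trained model with a classification head only takes a few lines:

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# A single forward pass; for fine-tuning you would wrap this in a
# regular PyTorch training loop (or use the Trainer API).
inputs = tokenizer("Transformers are great for NLP!", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
prediction = logits.argmax(dim=-1)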
If you do want to learn more about coding up the training procedure from scratch I would suggest checking out the training section of The Annotated Transformer.

That's all! Hope you enjoyed this deep dive into the transformer architecture!
If you have any questions or feedback, feel free to comment below. You can also contact me on Twitter @carlolepelaars.


Comments

Charles Frye: "Lastly, the weights are multiplied by $\sqrt{d_\text{model}}$ as a normalization step." Interestingly, this is exactly the input-layer scaling needed to make an infinite-width network train properly, see: https://arxiv.org/pdf/2011.14522.pdf