Transformer Deep Dive

Diving into the breakthroughs, scientific basis, formulas and code for the transformer architecture. Made by Carlo Lepelaars using Weights & Biases

Introduction

After language model breakthroughs like ELMo and ULMFiT, transformers have taken the Natural Language Processing (NLP) field by storm. They are the basis of popular models like GPT-3 and DALL-E. Tools like the HuggingFace Transformers library have also made it easy for machine learning engineers to solve a wide range of NLP tasks and have facilitated many subsequent breakthroughs in NLP.

In this report, we will take a deep dive into the transformer architecture as described in the paper "Attention Is All You Need" (2017) and code it up using PyTorch.


The transformer is a sequence to sequence (seq2seq) model. This means that it is suited for every problem where there is some ordering in the data and the output is itself a sequence. Example applications are machine translation, abstractive summarization, and speech recognition. Recently, Vision Transformers (ViT) even improved the state-of-the-art in computer vision.

Below you see a visualization of the complete transformer architecture. We will explain what every component does, why it is there and how everything fits together. For now, recognize that there is an encoder (on the left) and a decoder (on the right) which both have several neural network layers inside them.

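[Figure: transform.png, the complete transformer architecture with the encoder on the left and the decoder on the right]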

Code examples are refactored from The Annotated Transformer by Harvard NLP Group and PyTorch documentation on transformers.

Tokenization

First of all, we need a way to represent text numerically in order to perform calculations on it. Tokenization is the process of parsing a string of text into a sequence of symbols (tokens). This process results in a vector of integers where each integer represents part of the text. The transformer paper uses Byte-Pair Encoding (BPE) as the tokenization method. BPE is a form of compression in which the most common pair of consecutive bytes (i.e. characters) is repeatedly replaced by a single new symbol (i.e. integer).
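To make the merge idea concrete, here is a minimal sketch of a single BPE merge step in plain Python. The helper names (`most_common_pair`, `merge_pair`) and the toy string are made up for illustration. A real tokenizer learns its merges from a large corpus and repeats this step many thousands of times.

```python
from collections import Counter


def most_common_pair(symbols):
    """Return the most frequent pair of adjacent symbols."""
    pairs = Counter(zip(symbols, symbols[1:]))
    return pairs.most_common(1)[0][0]


def merge_pair(symbols, pair):
    """Replace every non-overlapping occurrence of `pair` with one merged symbol."""
    merged, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
            merged.append(symbols[i] + symbols[i + 1])
            i += 2
        else:
            merged.append(symbols[i])
            i += 1
    return merged


# Start from individual characters and apply one merge.
symbols = list("aaabdaaabac")
pair = most_common_pair(symbols)
symbols = merge_pair(symbols, pair)
print(pair, symbols)
```

After the merge the sequence is shorter, and the new symbol would get its own integer id in the vocabulary.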

Recent research suggests BPE is suboptimal and recent language models like BERT use a WordPiece tokenizer instead. WordPiece seems easier to decode and more intuitive as it often tokenizes entire words into one token. In contrast, BPE often tokenizes slices of words. This can lead to weird tokenization symbols and individual characters (i.e. letters from the alphabet) being tokenized as unknown tokens. Therefore, WordPiece is the tokenizer we use in an example. We could train our own tokenizer on a corpus of text, but in reality practitioners almost always use a pre-trained tokenizer. The HuggingFace transformers library makes it easy to work with pre-trained tokenizers.

>>> from transformers import BertTokenizer
>>> tok = BertTokenizer.from_pretrained("bert-base-uncased")
Downloading: 100%|█████████████████████████| 1.04M/1.04M [00:00<00:00, 1.55MB/s]
Downloading: 100%|████████████████████████████| 456k/456k [00:00<00:00, 848kB/s]

>>> tok("Hello, how are you doing?")['input_ids']
[101, 7592, 1010, 2129, 2024, 2017, 2725, 1029, 102]

>>> tok("The Frenchman spoke in the [MASK] language and ate 🥖")['input_ids']
[101, 1996, 26529, 3764, 1999, 1996, 103, 2653, 1998, 8823, 100, 102]

>>> tok("[CLS] [SEP] [MASK] [UNK]")['input_ids']
[101, 101, 102, 103, 100, 102]

Note that the tokenizer automatically includes a token for the start of an encoding ([CLS] == 101) and the end of an encoding (i.e. SEParation) ([SEP] == 102). Other special tokens include masking ([MASK] == 103) and an unknown symbol ([UNK] == 100, e.g. for the 🥖 emoji).

Embeddings

In order to learn proper representations of text, each individual token in the sequence is converted to a vector through an embedding. It can be seen as a type of neural network layer, because the weights for the embeddings are learned along with the rest of the transformer model. It contains a vector for each token in the vocabulary and these weights are initialized from a normal distribution \mathcal{N}(0, 1). Note that it requires us to specify the size of the vocabulary (|\text{vocab}|) and the dimension of the model (d_\text{model} = 512) when initializing the model (E \isin \mathbb{R}^{|\text{vocab}| \times d_\text{model}}). Lastly, the embedding vectors are multiplied by \sqrt{d_\text{model}} as a scaling step.

import math

import torch
from torch import nn


class Embed(nn.Module):
    def __init__(self, vocab: int, d_model: int = 512):
        super(Embed, self).__init__()
        self.d_model = d_model
        self.vocab = vocab
        self.emb = nn.Embedding(self.vocab, self.d_model)
        # d_model is a plain int, so use math.sqrt (torch.sqrt expects a tensor)
        self.scaling = math.sqrt(self.d_model)

    def forward(self, x):
        # Look up the vector for each token id and scale it
        return self.emb(x) * self.scaling
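
Conceptually, an embedding is just a lookup table indexed by token id. The sketch below mimics that with plain Python lists. The table values, the toy vocabulary of 3 tokens and the `embed` helper are invented for illustration, whereas the real layer learns its table during training.

```python
import math

d_model = 4
# Toy embedding table E with |vocab| = 3 rows and d_model = 4 columns.
table = [[0.1 * (row + col) for col in range(d_model)] for row in range(3)]


def embed(token_ids):
    """Select one row per token id and scale it by sqrt(d_model)."""
    scale = math.sqrt(d_model)
    return [[w * scale for w in table[t]] for t in token_ids]


out = embed([2, 0, 2])
print(len(out), len(out[0]))  # sequence length, d_model
```

The same token id always maps to the same vector, which is why the model needs positional information from somewhere else.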

Positional Encoding

In contrast to recurrent and convolutional networks, the model in itself has no information about the relative position of the embedded tokens in a sequence. Therefore, we have to inject this information by adding an encoding to the input embeddings for the encoder and decoder. This information can be added in many different ways and can be static or learned. The transformer uses sine and cosine transformations for each position (\text{pos}). Sine is used for the even dimensions (2i) and cosine for the odd dimensions (2i+1).
PE_{\text{pos}, 2i} = \sin(\frac{\text{pos}}{10000^{2i / d_\text{model}}})
PE_{\text{pos}, 2i+1} = \cos(\frac{\text{pos}}{10000^{2i / d_\text{model}}})
In the code, positional encodings are computed in log space to avoid numerical overflow.

import math

import torch
from torch import nn


class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int = 512, dropout: float = .1, max_len: int = 5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # Compute the positional encodings in log space
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float)
                             * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Add a batch dimension: (1, max_len, d_model)
        pe = pe.unsqueeze(0)
        # A buffer is saved with the model but is not a trainable parameter
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Buffers do not require gradients, so no Variable wrapper is needed
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
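
The log-space trick rewrites 1 / 10000^{2i / d_\text{model}} as \exp(-2i \cdot \ln(10000) / d_\text{model}). The following sketch checks in plain Python that both forms give the same encoding. The function names `pe_direct` and `pe_log_space` are made up for this example.

```python
import math


def pe_direct(pos, i, d_model=512):
    """Sine/cosine pair for dimensions (2i, 2i+1), straight from the formulas."""
    angle = pos / (10000 ** (2 * i / d_model))
    return math.sin(angle), math.cos(angle)


def pe_log_space(pos, i, d_model=512):
    """Same pair, with the division replaced by exp of a negative log."""
    div_term = math.exp(2 * i * -(math.log(10000.0) / d_model))
    return math.sin(pos * div_term), math.cos(pos * div_term)


# Compare both versions for a few positions and dimension pairs.
checks = []
for pos, i in [(0, 0), (1, 0), (7, 3), (4999, 255)]:
    direct = pe_direct(pos, i)
    log_space = pe_log_space(pos, i)
    checks.append(all(math.isclose(a, b, abs_tol=1e-9)
                      for a, b in zip(direct, log_space)))
print(checks)
```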

Multi-Head Attention

Before transformers, the paradigm in AI research to learn from sequences was to use either convolutions (WaveNet, ByteNet) or recurrence (RNN, LSTM). Attention already allowed for some NLP breakthroughs before transformers (Luong et al., 2015), but it was not obvious back then that you could build effective models without convolutions or recurrence. Therefore, proposing that "Attention is All You Need" was quite a bold statement.
The attention layer can learn a mapping between a query (Q) and a set of key (K) value (V) pairs. The meaning of these names can be confusing, because they depend on the particular NLP application. For developing our transformers you can just think of them as linear projections of the input. The names query, key and value come from traditional information retrieval theory and we will briefly explain where these terms come from with an example.
When you are searching for a video on YouTube, you will type in a phrase in the search bar (i.e. a query Q). The search engine will use this to map against YouTube video titles, descriptions etc. (i.e. keys K). Using this mapping it will suggest the most relevant videos to you (i.e. values V). (Example source)
One innovation which boosted the performance of attention in NLP is what the authors call "Scaled dot-product attention". It is the same as multiplicative attention, but additionally the dot product of Q and K is scaled by \frac{1}{\sqrt{d_k}}, where d_k is the key dimension. Without this scaling, large values of d_k push the softmax into regions with very small gradients, so the scaling makes multiplicative attention perform better with larger dimensions. The result is put through a softmax activation (\text{softmax}(x_i) = \frac{\exp(x_i)}{\sum_{j} \exp(x_j)}) and multiplied by V.
\text{Attention}(Q, K, V) = \text{softmax}(\frac{Q K^T}{\sqrt{d_k}} )V

import math

import torch
from torch import nn


class Attention(nn.Module):
    def __init__(self, dropout: float = 0.):
        super(Attention, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query, key, value, mask=None):
        d_k = query.size(-1)
        # Scaled dot product between every query and every key
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Masked positions get a very large negative score, so softmax maps them to ~0
            scores = scores.masked_fill(mask == 0, -1e9)
        p_attn = self.dropout(self.softmax(scores))
        return torch.matmul(p_attn, value)

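To see what the attention formula does with actual numbers, here is a small worked example in plain Python with one query and two key-value pairs. It is only an illustration of the formula, not a replacement for the PyTorch module, and the helper names and toy values are made up.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for lists of row vectors."""
    d_k = len(K[0])
    out = []
    for q in Q:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in K]
        weights = softmax(scores)
        out.append([sum(w * v[j] for w, v in zip(weights, V))
                    for j in range(len(V[0]))])
    return out


Q = [[1.0, 0.0]]                  # one query
K = [[1.0, 0.0], [0.0, 1.0]]      # two keys
V = [[10.0, 0.0], [0.0, 10.0]]    # two values
result = attention(Q, K, V)
print(result)
```

The query matches the first key better, so the output is a weighted average of the values that leans towards the first value.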
In the decoder, an attention sub-layer is masked by filling certain positions with a very large negative number (-1\text{e}9 or -\infty). This prevents the model from cheating by attending to subsequent positions, so it can only attend to words at previous positions when it tries to predict the next token.
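The effect of such a mask can be sketched in a few lines of plain Python. The helpers `causal_mask` and `masked_softmax` are hypothetical names for this example. They build a lower-triangular mask and show that masked positions end up with an attention weight of zero.

```python
import math


def causal_mask(size):
    """mask[i][j] is 1 when position i may attend to position j (j <= i)."""
    return [[1 if j <= i else 0 for j in range(size)] for i in range(size)]


def masked_softmax(scores, mask_row, fill=-1e9):
    """Softmax after replacing masked scores with a very large negative number."""
    filled = [s if keep else fill for s, keep in zip(scores, mask_row)]
    m = max(filled)
    exps = [math.exp(x - m) for x in filled]
    total = sum(exps)
    return [e / total for e in exps]


mask = causal_mask(4)
scores = [0.5, 1.0, 2.0, 3.0]  # the same raw scores for every query position
weights = [masked_softmax(scores, row) for row in mask]
for row in weights:
    print([round(w, 3) for w in row])
```

The first position can only attend to itself, while the last position sees the whole sequence.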
The mechanism of attention in itself is already very powerful and can be calculated efficiently on modern hardware, like GPUs and TPUs which are optimized for matrix multiplication. However, a single attention layer only allows for one representation. Therefore, in the transformer multiple attention heads are used. This allows the model to learn multiple patterns and representations. The paper uses h = 8 attention heads whose outputs are concatenated. The final formula becomes:
\text{MultiHead}(Q, K, V) = \text{concat}(\text{head}_1, \dots, \text{head}_h) W^O
where \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
The projection weights (W^O \isin \mathbb{R}^{hd_{v} \times d_\text{model}}, W_i^Q \isin \mathbb{R}^{d_\text{model} \times d_k}, W_i^K \isin \mathbb{R}^{d_\text{model} \times d_k}, W_i^V \isin \mathbb{R}^{d_\text{model} \times d_v}) are the weights of fully-connected (Linear) layers. The authors of the transformer paper use d_k = d_v = \frac{d_\text{model}}{h} = 64.

from torch import nn


class MultiHeadAttention(nn.Module):
    def __init__(self, h: int = 8, d_model: int = 512, dropout: float = 0.1):
        super(MultiHeadAttention, self).__init__()
        self.d_k = d_model // h
        self.h = h
        self.attn = Attention(dropout)
        self.lindim = (d_model, d_model)
        # One projection each for query, key and value (covers all h heads at once)
        self.linears = nn.ModuleList([nn.Linear(*self.lindim) for _ in range(3)])
        # Output projection W^O
        self.final_linear = nn.Linear(*self.lindim, bias=False)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            # The same mask is applied to every head
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        # Project, then split d_model into h heads: (batch, h, seq_len, d_k)
        query, key, value = [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
                             for l, x in zip(self.linears, (query, key, value))]
        x = self.attn(query, key, value, mask=mask)
        # Concatenate and multiply by W^O
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.final_linear(x)

Technical note: The .contiguous method is added after .transpose, because .transpose returns a view that shares its underlying memory storage with the original tensor and is generally no longer contiguous. Calling .view after that requires a contiguous tensor (documentation). The .view method reshapes a tensor without copying its data (documentation).
Because the dimension of each head is divided by h, the total computation is similar to using one attention head with full dimensionality (d_\text{model}). However, with this approach the calculation can be parallelized along heads which leads to massive speedups on modern hardware. This is one of the innovations that allowed for training effective language models without convolutions or recurrence.
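A quick back-of-the-envelope check of this claim, using the dimensions from the paper and an arbitrary sequence length of n = 100 tokens (the sequence length is an assumption for this example, and biases are ignored):

```python
d_model, h = 512, 8
d_k = d_v = d_model // h  # 64 per head

# Parameters of the Q, K and V projections summed over all h heads.
multi_head_qkv = h * (d_model * d_k + d_model * d_k + d_model * d_v)
# One head at full dimensionality: three d_model x d_model projections.
single_head_qkv = 3 * d_model * d_model

# Multiplications needed for the Q K^T score matrices of n tokens.
n = 100
multi_head_scores = h * (n * n * d_k)
single_head_scores = n * n * d_model

print(multi_head_qkv, single_head_qkv)
print(multi_head_scores, single_head_scores)
```

Both pairs of numbers come out identical, which is why splitting into heads costs roughly the same as one full-size head.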

Residuals and Layer Normalization

The AI research community discovered that concepts like residual connections and (batch) normalization improve performance, reduce training time and allow training of deeper networks. Therefore, the transformer is equipped with residual connections and normalization after every attention layer and every feed forward layer. Additionally, dropout is added in each layer for better generalization.

Normalization

Modern deep learning based computer vision models often feature batch normalization. However, this type of normalization is dependent on a large batch size and does not lend itself naturally to recurrence. The traditional transformer architecture has layer normalization instead. Layer normalization is stable even with small batch sizes (\text{batch size} < 8).
In order to calculate layer normalization, we first calculate the mean \mu_i and standard deviation \sigma_i separately for each sample i in the minibatch, over its K features.
\mu_i = \frac{1}{K} \sum_{k=1}^{K} x_{i, k}
\sigma_i = \sqrt{\frac{1}{K} \sum_{k=1}^{K} (x_{i, k} - \mu_i)^2}
Then, the normalization step is defined as:
LN_{\gamma, \beta}(x_i) \equiv \gamma \frac{x_i - \mu_i}{\sigma_i + \epsilon} + \beta
where \gamma and \beta are learnable parameters. A small number \epsilon is added for numerical stability in case the standard deviation \sigma_i is 0.

import torch
from torch import nn


class LayerNorm(nn.Module):
    def __init__(self, features: int, eps: float = 1e-6):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        # Statistics are computed over the feature dimension of each token
        mean = x.mean(-1, keepdim=True)
        # Note: Tensor.std applies Bessel's correction (divides by K - 1) by default
        std = x.std(-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta
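
The formulas can also be worked through by hand for a single vector. The sketch below follows them literally in plain Python, dividing by K for the standard deviation, with \gamma = 1 and \beta = 0 as untrained defaults. The `layer_norm` helper is an illustrative name, not part of the model code.

```python
import math


def layer_norm(x, gamma=1.0, beta=0.0, eps=1e-6):
    """Normalize one feature vector to zero mean and (almost) unit deviation."""
    K = len(x)
    mu = sum(x) / K
    sigma = math.sqrt(sum((v - mu) ** 2 for v in x) / K)
    return [gamma * (v - mu) / (sigma + eps) + beta for v in x]


x = [1.0, 2.0, 3.0, 4.0]
y = layer_norm(x)
print([round(v, 4) for v in y])

# A constant vector has sigma = 0; epsilon keeps the division well defined.
constant = layer_norm([5.0, 5.0, 5.0])
print(constant)
```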

Residual

A residual connection means that you add the output of a previous layer in the network (i.e. sublayer) to the output of the current layer. This allows for very deep networks because the network can essentially 'skip' certain layers.
The final output of each layer will then be \text{ResidualConnection}(x) = x + \text{Dropout}(\text{SubLayer}(\text{LayerNorm}(x)))

from torch import nn


class ResidualConnection(nn.Module):
    def __init__(self, size: int = 512, dropout: float = .1):
        super(ResidualConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # Normalize, apply the sublayer, then add the original input back
        return x + self.dropout(sublayer(self.norm(x)))

Feed Forward

On top of every attention layer a feed forward network is added. This consists of two fully-connected layers with a \text{ReLU} activation (\text{ReLU}(x) = \max(0, x)) and dropout for the inner layer. The standard dimensions used in the transformer paper are d_\text{model} = 512 for the input layer and d_{ff} = 2048 for the inner layer.
The full calculation becomes \text{FeedForward}(x) = \max(0, xW_1 + B_1) W_2 + B_2.
Note that PyTorch Linear already includes the biases (B_1 and B_2) by default.

from torch import nn


class FeedForward(nn.Module):
    def __init__(self, d_model: int = 512, d_ff: int = 2048, dropout: float = .1):
        super(FeedForward, self).__init__()
        self.l1 = nn.Linear(d_model, d_ff)
        self.l2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Expand to d_ff, apply ReLU and dropout, then project back to d_model
        return self.l2(self.dropout(self.relu(self.l1(x))))
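
One detail that is easy to miss is that the feed forward network is applied to every position separately, with the same weights. A toy version with d_\text{model} = 2 and d_{ff} = 3 and hand-picked weights (all values here are invented for the example, and dropout is left out) makes this visible:

```python
def feed_forward(x, W1, b1, W2, b2):
    """Apply max(0, x W1 + b1) W2 + b2 to a single token vector x."""
    hidden = [
        max(0.0, sum(x[i] * W1[i][j] for i in range(len(x))) + b1[j])
        for j in range(len(b1))
    ]
    return [
        sum(hidden[j] * W2[j][k] for j in range(len(hidden))) + b2[k]
        for k in range(len(b2))
    ]


# Toy sizes: d_model = 2, d_ff = 3 (the paper uses 512 and 2048).
W1 = [[1.0, 0.0, -1.0], [0.0, 1.0, 1.0]]
b1 = [0.0, 0.0, 0.0]
W2 = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
b2 = [0.5, -0.5]

# Two token vectors, each transformed independently of the other.
sequence = [[1.0, -1.0], [0.0, 2.0]]
outputs = [feed_forward(token, W1, b1, W2, b2) for token in sequence]
print(outputs)
```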

Encoder - Decoder

Encoding

Now we have all the components to build the model encoder and decoder. A single encoder layer consists of a multi-head attention layer followed by a feed-forward network. As mentioned earlier, we also include residual connections and layer normalization.
\text{Encoding}(x, \text{mask}) = \text{FeedForward}(\text{MultiHeadAttention}(x, \text{mask}))

from torch import nn
from copy import deepcopy


class EncoderLayer(nn.Module):
    def __init__(self, size: int, self_attn: MultiHeadAttention,
                 feed_forward: FeedForward, dropout: float = .1):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sub1 = ResidualConnection(size, dropout)
        self.sub2 = ResidualConnection(size, dropout)
        self.size = size

    def forward(self, x, mask):
        # Self-attention: query, key and value all come from x
        x = self.sub1(x, lambda x: self.self_attn(x, x, x, mask))
        return self.sub2(x, self.feed_forward)

The final transformer encoder from the paper consists of 6 identical encoder layers followed by layer normalization.

class Encoder(nn.Module):
    def __init__(self, layer, n: int = 6):
        super(Encoder, self).__init__()
        # n independent copies of the same layer
        self.layers = nn.ModuleList([deepcopy(layer) for _ in range(n)])
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

Decoding

The decoding layer is a masked multi-head attention layer followed by a multi-head attention layer that includes memory. Memory is the output of the encoder. Lastly, it goes through a feed-forward network. Again, all these components include residual connections and layer normalization.
\text{Decoding}(x, \text{memory}, \text{mask}_1, \text{mask}_2) = \text{FeedForward}(\text{MultiHeadAttention}(\text{MultiHeadAttention}(x, \text{mask}_1), \text{memory}, \text{mask}_2))

from torch import nn
from copy import deepcopy


class DecoderLayer(nn.Module):
    def __init__(self, size: int, self_attn: MultiHeadAttention, src_attn: MultiHeadAttention,
                 feed_forward: FeedForward, dropout: float = .1):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sub1 = ResidualConnection(size, dropout)
        self.sub2 = ResidualConnection(size, dropout)
        self.sub3 = ResidualConnection(size, dropout)

    def forward(self, x, memory, src_mask, tgt_mask):
        # Masked self-attention over the target sequence
        x = self.sub1(x, lambda x: self.self_attn(x, x, x, tgt_mask))
        # Attention over the encoder output (memory)
        x = self.sub2(x, lambda x: self.src_attn(x, memory, memory, src_mask))
        return self.sub3(x, self.feed_forward)

As with the final encoder, the decoder in the paper also has 6 identical layers followed by layer normalization.

class Decoder(nn.Module):
    def __init__(self, layer: DecoderLayer, n: int = 6):
        super(Decoder, self).__init__()
        # n independent copies of the same layer
        self.layers = nn.ModuleList([deepcopy(layer) for _ in range(n)])
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return self.norm(x)

With this higher level representation of the encoder and decoder we can easily formulate the final encoder-decoder block.

from torch import nn
class EncoderDecoder(nn.Module):
    def __init__(self, encoder: Encoder, decoder: Decoder, 
                 src_embed: Embed, tgt_embed: Embed, final_layer: Output):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.final_layer = final_layer
        
    def forward(self, src, tgt, src_mask, tgt_mask):
        return self.final_layer(self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask))
    
    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)
    
    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

Final Output

Lastly, the vector output from the decoder has to be transformed to a final output. For sequence to sequence problems like language translation this is a probability distribution over the total vocabulary for each position. One fully-connected layer transforms the output of the decoder to a matrix of logits, which have the dimension of the target vocabulary. These numbers are transformed to a probability distribution over the vocabulary through a softmax activation function. In the code we use \text{LogSoftmax}(x_i) = \log(\frac{\exp(x_i)}{\sum_{j} \exp(x_j)}), because it is faster and more numerically stable.
For example, let's say the translated sentence has 20 tokens and the total vocabulary is 30000 tokens. The resulting output will then be a matrix M \isin \mathbb{R}^{20 \times 30000}. We can then take the \arg\max over the last dimension to get a vector of output tokens T \isin \mathbb{N}^{20} that can be decoded to a string of text through a tokenizer.
\text{Output}(x) = \text{LogSoftmax}(xW_1 + B_1)

import torch
from torch import nn


class Output(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super(Output, self).__init__()
        self.l1 = nn.Linear(input_dim, output_dim)
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # One logit per vocabulary entry for every position
        logits = self.l1(x)
        return self.log_softmax(logits)
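
To illustrate the last step, here is a plain Python sketch of log-softmax followed by an argmax over the vocabulary axis. The logits and the toy vocabulary of 3 tokens are invented. Subtracting the maximum before exponentiating is a common way to get the numerical stability mentioned above, though it is not necessarily how PyTorch implements it internally.

```python
import math


def log_softmax(logits):
    """log(softmax(x)) computed without exponentiating large numbers."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum for v in logits]


# 2 output positions, vocabulary of 3 tokens.
logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
log_probs = [log_softmax(row) for row in logits]

# Greedy decoding: pick the most likely token id at every position.
tokens = [max(range(len(row)), key=row.__getitem__) for row in log_probs]
print(tokens)

# Very large logits would overflow a naive exp(x) but are fine here.
stable = log_softmax([1000.0, 1001.0])
print(stable)
```

Because the logarithm is monotonic, the argmax of the log-probabilities is the same as the argmax of the probabilities.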

Model Initialization

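We build the transformer model with the same dimensions as in the paper. The initialization strategy is Xavier/Glorot uniform initialization, which draws each weight from a uniform distribution whose range depends on the number of inputs n_\text{in} and outputs n_\text{out} of the layer. In the code below this is only applied to parameters with more than one dimension (the weight matrices). Biases and layer normalization parameters keep their default initialization.
\text{Xavier}(W) \sim U\left[-\sqrt{\frac{6}{n_\text{in} + n_\text{out}}}, \sqrt{\frac{6}{n_\text{in} + n_\text{out}}}\right]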

from torch import nn


def make_model(input_vocab: int, output_vocab: int, d_model: int = 512):
    # Pass d_model through so that every component uses the same dimension
    encoder = Encoder(EncoderLayer(d_model, MultiHeadAttention(d_model=d_model),
                                   FeedForward(d_model=d_model)))
    decoder = Decoder(DecoderLayer(d_model, MultiHeadAttention(d_model=d_model),
                                   MultiHeadAttention(d_model=d_model),
                                   FeedForward(d_model=d_model)))
    input_embed = nn.Sequential(Embed(vocab=input_vocab, d_model=d_model),
                                PositionalEncoding(d_model=d_model))
    output_embed = nn.Sequential(Embed(vocab=output_vocab, d_model=d_model),
                                 PositionalEncoding(d_model=d_model))
    output = Output(input_dim=d_model, output_dim=output_vocab)
    model = EncoderDecoder(encoder, decoder, input_embed, output_embed, output)

    # Initialize every weight matrix with Xavier uniform.
    # Parameters with a single dimension (biases, LayerNorm gamma/beta) are skipped.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model

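To get a feeling for the scale of the initial weights: to my knowledge, nn.init.xavier_uniform_ samples from U(-a, a) with a = \text{gain} \cdot \sqrt{6 / (\text{fan\_in} + \text{fan\_out})}. The helper below (a made-up name, plain Python) evaluates that bound for the two weight shapes that dominate the model.

```python
import math


def xavier_uniform_bound(fan_in, fan_out, gain=1.0):
    """Half-width a of the uniform range U(-a, a) used by Xavier uniform."""
    return gain * math.sqrt(6.0 / (fan_in + fan_out))


attn_bound = xavier_uniform_bound(512, 512)   # attention projections
ff_bound = xavier_uniform_bound(512, 2048)    # first feed forward layer
print(round(attn_bound, 4), round(ff_bound, 4))

# For a square layer with n inputs and outputs the variance a^2 / 3 equals 1 / n.
variance = attn_bound ** 2 / 3
print(variance, 1 / 512)
```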
This function returns a PyTorch model that can be trained for sequence to sequence problems. Below you find a dummy example of how to use it with a tokenized input and output. Let's say we have a vocabulary of just 10 words for the input and output.

# Tokenized symbols for source and target.
>>> src = torch.tensor([[1, 2, 3, 4, 5]])
>>> src_mask = torch.tensor([[1, 1, 1, 1, 1]])
>>> tgt = torch.tensor([[6, 7, 8, 0, 0]])
>>> tgt_mask = torch.tensor([[1, 1, 1, 0, 0]])

# Create PyTorch model
>>> model = make_model(input_vocab=10, output_vocab=10)

# Do inference and take tokens with highest probability through argmax along the vocabulary axis (-1)
>>> result = model(src, tgt, src_mask, tgt_mask)
>>> result.argmax(dim=-1)
tensor([[6, 6, 4, 3, 6]])

The output is very far off from the target, because the model has randomly initialized, untrained weights at this point. Training these transformer models from scratch requires quite some computation. To train the base model, the authors of the original paper trained for 12 hours on 8 NVIDIA P100 GPUs. Their larger models took 3.5 days to train on 8 GPUs! I would advise using pre-trained transformer models and fine-tuning them for your application. The HuggingFace Transformers library already has many pre-trained models for fine-tuning.
If you do want to learn more about coding up the training procedure from scratch I would suggest checking out the training section of The Annotated Transformer.
That's all! Hope you enjoyed this deep dive into the transformer architecture!
If you have any questions or feedback, feel free to comment below. You can also contact me on Twitter @carlolepelaars.

Sources and More Learning Resources

Suggested Papers

Attention Is All You Need (2017)
BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (2018)
Reformer: The Efficient Transformer (2020)
Linformer: Self-Attention with Linear Complexity (2020)
Longformer: The Long-Document Transformer (2020)
Language Models are Few-Shot Learners (GPT-3 paper) (2020)

Online Resources

The Annotated Transformer with PyTorch Code
The Illustrated Transformer
The Narrated Transformer video
Łukasz Kaiser's masterclass on transformers
HuggingFace Transformers course