
How Do Vision Transformers Work?

An in-depth breakdown of 'How Do Vision Transformers Work?' by Namuk Park and Songkuk Kim.
Created on February 28 | Last edited on March 1


🔑 Key Takeaways (TL;DR)

  1. Multihead Self-Attentions (MSAs) improve not only accuracy but also generalization by flattening the loss landscape. This improvement is primarily attributable to their data specificity, not to long-range dependency.
  2. Vision Transformers (ViTs) suffer from non-convex losses. However, large datasets and loss landscape smoothing methods alleviate this problem.
  3. MSAs and Convolutions exhibit opposite behaviours: MSAs are low-pass filters, whereas Convolutions are high-pass filters.
  4. Multi-stage neural networks behave like a series connection of small individual models.
  5. The authors propose a new Conv+MSA hybrid model in which the Conv blocks at the end of a stage are replaced with MSA blocks.

👋 Motivation

Although Multihead Self Attention (MSA) has become ubiquitous in vision, the reasons behind its success are poorly understood. The widely accepted explanations are its weak inductive bias and its ability to capture long-range dependencies, thanks to the flexibility of attention. However, Vision Transformers have been observed to need large datasets and meticulous training recipes to work well, and as a consequence they perform poorly in small-data regimes.
Self Attention is a simple operation in which each token (a patch of an image, a cube of a video, a segment of a spectrogram or a word representation) can attend to all other tokens in the sequence. Most Attention variants can be viewed as a specialization of the following form:
$$
\begin{aligned}
\alpha_j &= \mathrm{softmax}(x_1 \cdot x_j, x_2 \cdot x_j, \ldots, x_n \cdot x_j) \\
y_j &= \sum_{i=1}^{n} \alpha_{ji} \, x_i
\end{aligned}
$$
The most common form of attention is scaled dot-product attention:

$$
\begin{aligned}
\alpha_j &= \mathrm{softmax}\left(\frac{K x_1 \cdot Q x_j}{\sqrt{d_K}}, \ldots, \frac{K x_n \cdot Q x_j}{\sqrt{d_K}}\right) \\
y_j &= \sum_{i=1}^{n} \alpha_{ji} \, V x_i
\end{aligned}
$$

where K, Q and V are learned projection matrices and d_K is the dimension of the keys K x_i. Some key things to notice here:
  • Attention is a set operation (permutation-equivariant), so it has no built-in notion of spatial structure; most transformers therefore add some form of positional embedding to inject positional information. Convolutions, in contrast, slide the same kernel over the input, which hard-codes translation equivariance (often loosely called translation invariance): every patch of the image is processed by the same weights.
  • Attention is a global operation, i.e. it aggregates information from all tokens, whereas convolutions are local operations. The local connectivity of CNNs can lead to a loss of global context, for instance by encouraging a bias towards classifying images on the basis of texture rather than shape (Tuli et al. 2021).
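The scaled dot-product formula above can be sketched in plain Python. This is a minimal single-head sketch, assuming tokens are short lists of floats and that `K`, `Q` and `V` are given as small matrices (lists of rows); the toy tokens and identity projections are hypothetical values chosen for illustration, not anything from the paper. It also checks the permutation-equivariance property from the first bullet: shuffling the input tokens only shuffles the outputs.

```python
import math


def matvec(M, x):
    """Multiply a matrix M (list of rows) by a vector x."""
    return [sum(m * v for m, v in zip(row, x)) for row in M]


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention(xs, K, Q, V):
    """Single-head scaled dot-product attention: y_j = sum_i alpha_ji * V x_i."""
    d_k = len(K)  # dimension of the keys K x_i
    keys = [matvec(K, x) for x in xs]
    values = [matvec(V, x) for x in xs]
    ys = []
    for x_j in xs:
        q_j = matvec(Q, x_j)
        alpha_j = softmax([dot(k_i, q_j) / math.sqrt(d_k) for k_i in keys])
        ys.append([sum(a * v[c] for a, v in zip(alpha_j, values))
                   for c in range(len(values[0]))])
    return ys


# Hypothetical toy inputs: three 2-d tokens and identity projections.
I2 = [[1.0, 0.0], [0.0, 1.0]]
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out = attention(tokens, I2, I2, I2)

# Permutation equivariance: permuting the tokens permutes the outputs the same way.
perm = [2, 0, 1]
out_perm = attention([tokens[i] for i in perm], I2, I2, I2)
```

Multi-head attention runs several such heads with separate projections and concatenates their outputs; that step is omitted here to keep the sketch close to the formula.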

Thus Vision Transformers lack the inductive bias of convolutions, which can be extremely helpful in the small-data regime. To counteract this, several local MSAs, such as CSAN, Swin and Twins, calculate self-attention only within small windows.
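The idea behind these local MSAs can be sketched by restricting the parameter-free attention of the first formula to non-overlapping windows of tokens. This is a simplified, hypothetical 1-D illustration: real Swin and Twins blocks use 2-D windows and learned projections, and Swin additionally shifts its windows, none of which is modelled here. The check shows that a token outside a window cannot influence the outputs inside it, while global attention has no such barrier.

```python
import math


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def simple_attention(xs):
    """Parameter-free attention (first formula): y_j = sum_i alpha_ji * x_i."""
    ys = []
    for x_j in xs:
        alpha_j = softmax([sum(a * b for a, b in zip(x_i, x_j)) for x_i in xs])
        ys.append([sum(w * x_i[c] for w, x_i in zip(alpha_j, xs))
                   for c in range(len(x_j))])
    return ys


def window_attention(xs, window):
    """Run attention independently inside consecutive, non-overlapping windows."""
    out = []
    for start in range(0, len(xs), window):
        out.extend(simple_attention(xs[start:start + window]))
    return out


# Hypothetical 8-token sequence split into two windows of 4.
tokens = [[float(i), 1.0] for i in range(8)]
changed = [list(t) for t in tokens]
changed[7] = [100.0, -3.0]  # perturb a token in the second window only

local_a, local_b = window_attention(tokens, 4), window_attention(changed, 4)
global_a, global_b = simple_attention(tokens), simple_attention(changed)
```

Besides limiting what each token can see, windowing reduces the number of pairwise scores from n² to roughly n × window.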
Now let's discuss the key questions raised in this paper.

1️⃣ What properties of MSAs do we need to improve optimization?

It's well known that ViTs relax the translation-invariance constraint of CNNs and therefore have a weaker set of inductive biases. But does this weak inductive bias make ViTs overfit small training datasets?
To test this claim, the authors compared the error on the test set with the negative log-likelihood (NLL) on the training set. Surprisingly, as the experiments highlighted below show, the stronger the inductive bias, the lower both the test error and the training NLL.
<Insert W&B Panel>
Moreover, analysing the test error and training NLL on subsampled datasets suggests that ViT's poor performance in small-data regimes is not due to overfitting: an overfitting model would reach a low training NLL, yet ViT's training NLL is higher than that of models with stronger inductive biases.
Another perspective is the loss landscape. The loss landscape of ViT is non-convex, whereas that of ResNet is strongly (arguably nearly) convex. This non-convexity disrupts training, especially in its early phase. The authors show that ViT has a number of negative Hessian eigenvalues, while ResNet has only a few.
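To make the sign of Hessian eigenvalues concrete, here is a toy sketch on two hypothetical 2-D "loss" functions rather than on a neural network: the Hessian is estimated with central finite differences and its eigenvalues are computed in closed form. A convex bowl has only positive eigenvalues, while a saddle has a negative one, which is the kind of non-convex direction counted in ViT's Hessian spectrum. For a real network the Hessian is far too large to build explicitly, and its spectrum is usually estimated with Hessian-vector products instead.

```python
import math


def hessian_2d(f, w, h=1e-3):
    """Central finite-difference Hessian of a function f: R^2 -> R at point w."""
    def shifted(i, di, j, dj):
        p = list(w)
        p[i] += di
        p[j] += dj
        return f(p)

    H = [[0.0, 0.0], [0.0, 0.0]]
    for i in range(2):
        for j in range(2):
            H[i][j] = (shifted(i, h, j, h) - shifted(i, h, j, -h)
                       - shifted(i, -h, j, h) + shifted(i, -h, j, -h)) / (4 * h * h)
    return H


def sym_eigenvalues_2x2(H):
    """Eigenvalues of a symmetric 2x2 matrix, returned in ascending order."""
    a, b, d = H[0][0], H[0][1], H[1][1]
    mean = (a + d) / 2
    radius = math.sqrt(((a - d) / 2) ** 2 + b ** 2)
    return mean - radius, mean + radius


def bowl(w):    # convex: curvature is positive in every direction
    return w[0] ** 2 + 2 * w[1] ** 2


def saddle(w):  # non-convex: curves downward along w[1]
    return w[0] ** 2 - w[1] ** 2


point = [0.5, -0.3]
bowl_eigs = sym_eigenvalues_2x2(hessian_2d(bowl, point))      # approx. (2, 4)
saddle_eigs = sym_eigenvalues_2x2(hessian_2d(saddle, point))  # approx. (-2, 2)
```

Larger positive eigenvalues mean sharper curvature; any negative eigenvalue means the point is not a local minimum of a convex bowl.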
Friendly Math Reminder: the magnitude of the Hessian eigenvalues reflects the sharpness (curvature) of the loss landscape, and negative eigenvalues indicate non-convex directions.
💡
The authors also show that large datasets suppress negative Hessian eigenvalues in the early phase of training. Therefore, large datasets tend to help ViT learn strong representations by convexifying the loss. ResNet enjoys little benefit from large datasets because its loss is convex even on small datasets. Landscape-smoothing techniques such as Global Average Pooling (GAP), Global Mean Pooling (GMP) and Multihead Attention Pooling (MAP) smooth the landscape by ensembling feature map points, thereby convexifying the loss.
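The ensembling view of GAP can be illustrated with a small Monte Carlo sketch, assuming a hypothetical single-channel 4×4 feature map whose points are a shared signal plus independent Gaussian noise. Averaging the 16 points acts like averaging 16 ensemble members, so the variance drops by roughly a factor of 16. The sketch shows only this variance reduction, not the loss-landscape smoothing itself, and because real feature-map points are correlated the reduction is typically smaller in practice.

```python
import random
import statistics


def global_average_pool(feature_map):
    """Global average pooling: mean over all spatial positions of a single-channel map."""
    values = [v for row in feature_map for v in row]
    return sum(values) / len(values)


rng = random.Random(0)  # fixed seed so the sketch is reproducible
size, noise_std, trials = 4, 1.0, 2000

single_point, pooled = [], []
for _ in range(trials):
    # Hypothetical feature map: a constant signal of 1.0 plus independent noise.
    fmap = [[1.0 + rng.gauss(0.0, noise_std) for _ in range(size)] for _ in range(size)]
    single_point.append(fmap[0][0])           # prediction from one feature-map point
    pooled.append(global_average_pool(fmap))  # "ensemble" of all 16 points

var_point = statistics.pvariance(single_point)  # close to noise_std ** 2
var_pooled = statistics.pvariance(pooled)       # close to noise_std ** 2 / 16
```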
The authors claim that, contrary to popular belief, long-range dependency hinders neural network optimization. To demonstrate this, they evaluate ConvViT, which calculates self-attention only between feature map points within convolutional receptive fields, obtained by unfolding the feature maps in the same way as convolutions. 5×5 kernels outperform both 3×3 and 8×8 kernels, showing that a strong locality inductive bias not only reduces computational complexity but also aids optimization by convexifying the loss landscape. The following charts reinforce this claim:
<Insert W&B Panel>
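The computational side of the locality argument is easy to quantify. The sketch below counts query-key dot products on an H×W feature map for global attention versus ConvViT-style local attention over a k×k neighborhood. It is a simplified illustration that assumes an odd k and windows clipped at the borders (an actual implementation may pad instead, giving exactly k² scores per query), and it does not model the optimization benefit. On a 32×32 map, global attention needs 1,048,576 scores, while 3×3 and 5×5 neighborhoods need 8,836 and 23,716.

```python
def local_neighbors(height, width, row, col, k):
    """Positions in the k x k window centred on (row, col), clipped at the borders."""
    if k % 2 == 0:
        raise ValueError("this sketch assumes an odd kernel size")
    r = k // 2
    return [(i, j)
            for i in range(max(0, row - r), min(height, row + r + 1))
            for j in range(max(0, col - r), min(width, col + r + 1))]


def count_attention_scores(height, width, k=None):
    """Query-key dot products for one attention map: global if k is None, else local k x k."""
    n = height * width
    if k is None:
        return n * n  # every position attends to every position
    return sum(len(local_neighbors(height, width, i, j, k))
               for i in range(height) for j in range(width))


global_scores = count_attention_scores(32, 32)   # 1,048,576
local_3x3 = count_attention_scores(32, 32, k=3)  # 8,836
local_5x5 = count_attention_scores(32, 32, k=5)  # 23,716
```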

2️⃣ Does Multihead Self Attention act like Convolutions?

Convolutions are data-agnostic and channel-specific, whereas MSAs are data-specific and channel-agnostic. MSA blocks are low-pass filters, but Convolutions are high-pass filters. MSAs spatially smooth feature maps with self-attention importances, so the authors expect (and then empirically confirm) that MSAs reduce high-frequency signals. They show that MSAs almost always decrease the relative log amplitude of the high-frequency components of ViT's Fourier-transformed feature maps, while MLPs, which correspond to Convolutions, increase it. The only exception is in the early stages of the model, where MSAs behave like Convolutions and increase the amplitude. This is further evidence for a hybrid model that uses Convolutions in early stages and MSAs in later stages.
Thus we can infer that low-frequency signals are informative for MSAs and high-frequency signals are informative for Convolutions.
💡
Accordingly, ViT and ResNet are vulnerable to low-frequency noise and high-frequency noise, respectively. Low-frequency and high-frequency signals roughly correspond to the shape and the texture of images, respectively, further suggesting that MSAs are shape-biased, whereas Convolutions are texture-biased.
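The low-pass versus high-pass behaviour can be reproduced in miniature. This sketch assumes a hypothetical 1-D signal in place of a 2-D feature map, uses a fixed 3-tap average as a stand-in for MSA's spatial smoothing and a sharpening filter as a stand-in for a high-pass operation, and defines the relative log amplitude as the log amplitude at the highest frequency minus that at frequency zero, a simplification of the paper's measurement. Because the 3-tap average scales the highest frequency by 1/3 and leaves frequency zero unchanged, smoothing lowers the relative log amplitude by exactly log(1/3), while sharpening raises it by log(7/3).

```python
import cmath
import math


def dft(x):
    """Naive discrete Fourier transform of a real sequence."""
    n = len(x)
    return [sum(x[t] * cmath.exp(-2j * math.pi * k * t / n) for t in range(n))
            for k in range(n)]


def smooth(x):
    """Circular 3-tap average: a low-pass filter, loosely analogous to MSA smoothing."""
    n = len(x)
    return [(x[i - 1] + x[i] + x[(i + 1) % n]) / 3 for i in range(n)]


def sharpen(x):
    """x + (x - smooth(x)): boosts high frequencies, i.e. a high-pass-style filter."""
    return [2 * a - b for a, b in zip(x, smooth(x))]


def relative_log_amplitude(x):
    """Log amplitude at the highest frequency (pi) minus log amplitude at frequency 0."""
    spectrum = dft(x)
    return math.log(abs(spectrum[len(x) // 2])) - math.log(abs(spectrum[0]))


signal = [1.0, 3.0, 0.0, 2.0, 5.0, 1.0, 4.0, 2.0]  # hypothetical 1-D "feature map"
base = relative_log_amplitude(signal)
after_smooth = relative_log_amplitude(smooth(signal))
after_sharpen = relative_log_amplitude(sharpen(signal))
```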
Since MSAs average feature maps, they reduce the variance of feature map points, which suggests that MSAs ensemble feature maps. The following figure shows that MSAs in ViT tend to reduce the variance; conversely, Convolutions in ResNet and MLPs in ViT increase it.
In conclusion, MSAs ensemble feature map predictions, but Convolutions do not.
💡
As Park & Kim (2021) showed, reducing feature map uncertainty helps optimization by ensembling and stabilizing the transformed feature maps. Furthermore, the variance accumulates in every layer (more precisely, every block) and tends to increase with depth, and the feature map variance in ResNet peaks at the end of each stage. This suggests that we can improve the predictive performance of ResNets by inserting MSAs at the end of each stage.
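The variance argument can be made concrete with a small sketch. Because attention weights are non-negative and sum to one, an attention output is a weighted average of feature-map points. Under the simplifying assumption that those points carry independent noise of equal variance (real feature maps are correlated), the output's noise variance is that variance times the sum of squared weights, and the reciprocal of that sum acts as an "effective ensemble size". The attention scores below are hypothetical.

```python
import math


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def averaged_noise_variance(weights, noise_var=1.0):
    """Variance of sum_i w_i * noise_i for independent noise terms with variance noise_var."""
    return noise_var * sum(w * w for w in weights)


def effective_ensemble_size(weights):
    """Number of equally weighted, independent members the weighted average is worth."""
    return 1.0 / sum(w * w for w in weights)


uniform = softmax([0.0, 0.0, 0.0, 0.0])  # attention spread evenly over 4 points
peaked = softmax([8.0, 0.0, 0.0, 0.0])   # attention concentrated on one point
```

Spread-out attention behaves like an ensemble of four members and cuts the noise variance to a quarter, while sharply peaked attention behaves like a single member and barely reduces it.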

3️⃣ How can we harmonize Multihead Self Attention with Convolutions?

We've established that MSAs and Convolutions are complementary, so let's look at some design rules that combine the best parts of both.
The authors observe that the pattern of feature map variance repeats itself in every stage, and that this behavior also appears in feature map similarities and lesion studies. The feature map similarities of CNNs and multi-stage ViTs, such as PiT and Swin, have a block structure. Since vanilla ViT does not have this structure, it is reasonable to conclude that the structure is an intrinsic characteristic of multi-stage architectures. In the lesion studies, the key observation is that removing a layer at the beginning of a stage impairs accuracy more than removing a layer at the end of a stage. Thus one would expect MSAs placed near the end of a stage to improve performance significantly.
Following the discussion above, the authors propose a new model architecture called AlterNet, based on the following rules:
  1. Alternately replace Conv blocks with MSA blocks from the end of a baseline CNN model.
  2. If the added MSA block does not improve predictive performance, replace a Conv block located at the end of an earlier stage with an MSA block.
  3. Use more heads and higher hidden dimensions for MSA blocks in late stages.
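Rule 1 can be sketched as a function that turns per-stage block counts into a block layout. The stage depths `[3, 4, 6, 3]` (those of ResNet-50) and the per-stage MSA counts in the example are illustrative assumptions, not the configuration reported in the paper, and rules 2 and 3, which require training runs and choosing head counts, are not modelled.

```python
def alternet_layout(stage_depths, msa_counts):
    """Hypothetical sketch of rule 1: within each stage, replace every other Conv block
    with an MSA block, starting from the last block of the stage."""
    layout = []
    for depth, n_msa in zip(stage_depths, msa_counts):
        if n_msa > (depth + 1) // 2:
            raise ValueError("cannot alternate that many MSA blocks in this stage")
        stage = ["Conv"] * depth
        for k in range(n_msa):
            stage[depth - 1 - 2 * k] = "MSA"  # last block, then every second one before it
        layout.append(stage)
    return layout


layout = alternet_layout([3, 4, 6, 3], [0, 1, 3, 2])
for stage_index, stage in enumerate(layout, start=1):
    print(f"stage {stage_index}: {' -> '.join(stage)}")
```

With these assumed counts, the third stage becomes `Conv -> MSA -> Conv -> MSA -> Conv -> MSA`, so every MSA block is preceded by a Conv block, and each stage ends with an MSA block whenever it has one.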

👋 Conclusion

In this paper, the authors demonstrate that MSAs are not merely generalized Convolutions, but rather generalized spatial smoothings that complement Convolutions. MSAs also help neural networks learn strong representations by ensembling feature map points and flattening the loss landscape. Based on these findings, they propose AlterNet, which preserves the original architectures of the Convolution and MSA blocks. As Park & Kim (2021) pointed out, global average pooling (GAP) in simple classification tasks has a strong tendency to ensemble feature maps, but neural networks for dense prediction do not use GAP. Therefore, the authors believe that MSAs can significantly improve results in dense prediction tasks by ensembling feature maps. Lastly, they note that strong data augmentation for MSA training harms uncertainty calibration.
The code and all details of the study are open-sourced in a great GitHub repository ⭐️, which we forked and open-sourced with a W&B implementation.
To cite the paper, kindly use the following BibTeX:
@inproceedings{
park2022how,
title={How Do Vision Transformers Work?},
author={Namuk Park and Songkuk Kim},
booktitle={International Conference on Learning Representations},
year={2022},
url={https://openreview.net/forum?id=D78Go4hVcxO}
}

📚 References


Papers

Videos
