[Overview] Taming Transformers for High-Resolution Image Synthesis

Combining the efficiency of convolutional approaches with the expressivity of transformers. Made by Ayush Thakur using Weights & Biases
Ayush Thakur
This report presents an overview of the paper "Taming Transformers for High-Resolution Image Synthesis". The paper addresses the fundamental challenge of using the expressivity of transformers for high-resolution image synthesis. The proposed approach represents images as a composition of perceptually rich image constituents and, in turn, uses transformers to efficiently model their composition within high-resolution images.
Here's a nice Twitter thread by Sander Dieleman summarizing this paper.

Paper | Project Page | GitHub

Check out my Colab Notebook to generate stunning images like this:

Introduction

Convolutional Neural Networks (CNNs) are the go-to models for vision tasks. This is because CNNs exhibit a strong locality bias (due to the use of kernels) and a bias towards spatial invariance through the use of shared weights (a kernel sweeps the entire image).
Compared to CNNs, transformers contain no inductive bias that prioritizes local interactions, which allows them to learn complex relationships between their inputs; in other words, they are more expressive. However, this comes at a price: the increased expressivity of transformers brings computational costs that grow quadratically with sequence length.
Esser et al. (2021) demonstrate high-resolution image synthesis by combining the effectiveness of CNNs with the expressivity of the transformers.

Two Words on Generative Models

Given x \sim p_{data}(x), where p_{data}(x) is the true distribution describing the data, the dataset X consists of finite samples from this distribution,
X = \{x \mid x \sim p_{data}(x)\}
The generative task is to find a model such that p_{model}(x; \theta) \approx p_{data}(x), where \theta denotes the model parameters.
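For likelihood-based models, this is typically done via maximum likelihood estimation, i.e., by choosing the parameters that maximize the expected log-likelihood of the data under the model:
\theta^* = \underset{\theta}{\operatorname{argmax}} \space E_{x \sim p_{data}(x)} [\log p_{model}(x; \theta)]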
Figure 1: Generative Model Taxonomy (Source)
We can distinguish two main types of generative models:
Here are the key ideas for a better understanding of this paper:

Overview of the Proposed Method

Previous works that applied transformers to image generation demonstrated promising results for images up to 64x64 pixels, but could not be scaled to higher resolutions because the cost grows quadratically with sequence length. Thus, to use transformers to synthesize higher-resolution images, we need a clever way to represent the semantics of an image. A raw pixel representation will not work: the number of pixels grows quadratically with image resolution, so doubling the resolution quadruples the sequence length.
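As a rough back-of-the-envelope illustration (not from the paper's code), the snippet below counts tokens and pairwise attention interactions if every pixel were treated as one token:
for size in (64, 128, 256, 1024):
    tokens = size * size              # one token per pixel
    attention_pairs = tokens ** 2     # self-attention scales as O(n^2) in sequence length
    print(f"{size}x{size} image -> {tokens:,} tokens, {attention_pairs:,} attention pairs")
Going from 64x64 to 256x256 already multiplies the number of attention pairs by a factor of 256, which is why a more compact image representation is needed.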

Vector Quantized Variational Autoencoder (VQ-VAE)

This work is inspired by the Vector Quantized Variational Autoencoder (VQ-VAE), which differs from traditional VAEs in two key ways:
You can learn more about VAEs in my report Towards Deep Generative Modeling with W&B. This work is also a natural evolution of VQ-VAE-2, which showed how powerful representation learning can be in the context of autoregressive generative modeling.
Figure 2: Overview of VQ-VAE architecture. (Source)
VQ-VAE consists of an encoder E(.) that maps observations (images) onto a sequence of discrete latent variables, and a decoder G(.) that reconstructs the observations from these discrete variables. Both share a codebook, given by e \in R^{K \times D}, where K is the number of discrete code vectors in the codebook and D is the dimensionality of each code e_i, i \in \{1, 2, ..., K\}.
As shown in figure 2, an image x is passed through E, producing E(x). This output is then quantized based on its distance to the code vectors e_i, such that each vector in E(x) is replaced by the index of the nearest code vector in the codebook. The corresponding code vectors are then used by the decoder for reconstruction. The quantization is given by,
Quantization(E(x)) = e_k where k = \underset{j}{\operatorname{argmin}} || E(x) - e_j ||
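Here is a minimal PyTorch sketch of this nearest-codebook lookup (illustrative, not the authors' implementation), assuming an encoder output of shape (B, h, w, D) and a codebook of shape (K, D); the gradient copy at the end is the standard VQ-VAE straight-through trick:
import torch

def quantize(z_e, codebook):
    # z_e: encoder output of shape (B, h, w, D); codebook: matrix e of shape (K, D)
    B, h, w, D = z_e.shape
    flat = z_e.reshape(-1, D)                                  # (B*h*w, D)
    # squared L2 distance between every latent vector and every code e_j
    dist = (flat.pow(2).sum(1, keepdim=True)
            - 2 * flat @ codebook.t()
            + codebook.pow(2).sum(1))                          # (B*h*w, K)
    indices = dist.argmin(dim=1)                               # k = argmin_j ||E(x) - e_j||
    z_q = codebook[indices].reshape(B, h, w, D)
    # straight-through estimator: copy gradients from z_q to z_e during training
    z_q = z_e + (z_q - z_e).detach()
    return z_q, indices.reshape(B, h, w)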
Two quick final pointers:

Why Does a Discrete Latent Representation Work?

The work Generating Diverse High-Fidelity Images with VQ-VAE-2, which resembles Taming Transformers to an extent, uses VQ-VAE-2 (a modification of VQ-VAE). But why does a discrete representation work at all?
The idea is inspired by JPEG lossy compression: JPEG encoding removes more than 80% of the data without noticeably changing the perceived image quality. Similarly, compressing away imperceptible detail leaves less noise in the representation, and training a generative model on this less noisy representation tends to work better.

Taming Transformers

Figure 3: Architecture overview of the Taming Transformers approach.
To keep the sequence length small and still harness the expressivity of the transformer, the authors use a discrete codebook of learned representations (inspired by VQ-VAE), such that an image x \in R^{H \times W \times 3} can be represented by a spatial collection of codebook entries z_q \in R^{h \times w \times n_z}, where n_z is the dimensionality of the codes.
As shown in figure 3, the authors use VQ-GAN, a variant of the original VQ-VAE that adds a discriminator and a perceptual loss to keep the perceptual quality high at an increased compression rate.
This is a two-stage training design: first the VQ-GAN is trained to learn the codebook, and then the transformer is trained on the resulting sequences of codebook indices.

Training VQ-GAN

The VQ-GAN is trained using the standard adversarial training procedure with a patch-based discriminator D. The complete objective for finding the optimal compression model Q^* = \{E^*, G^*, Z^*\} is given by,
Q^* = \underset{E, G, Z}{\operatorname{argmin}} \space \underset{D}{\operatorname{max}} \space E_{x \sim p(x)} [L_{VQ}(E, G, Z) + \lambda L_{GAN}(\{E, G, Z\}, D)]
where \lambda is an adaptive weight.
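In the paper, \lambda is set from the ratio of the gradients of the reconstruction (perceptual) loss and the GAN loss with respect to the last layer of the decoder. Below is a hedged PyTorch sketch of that computation; the names rec_loss, gan_loss, and last_layer are illustrative placeholders, not the official API:
import torch

def adaptive_weight(rec_loss, gan_loss, last_layer, delta=1e-6):
    # lambda = ||grad of reconstruction loss|| / (||grad of GAN loss|| + delta),
    # both gradients taken with respect to the last layer of the decoder G
    rec_grad = torch.autograd.grad(rec_loss, last_layer, retain_graph=True)[0]
    gan_grad = torch.autograd.grad(gan_loss, last_layer, retain_graph=True)[0]
    lam = rec_grad.norm() / (gan_grad.norm() + delta)
    return lam.detach()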
This training procedure significantly reduces the sequence length when unrolling the latent code and thereby enables the application of powerful transformer models.
In order to train your own VQ-GAN, clone the official GitHub repo for this paper and install the dependencies. The repository is already instrumented with Weights & Biases, so you automatically get all the necessary metrics in your W&B dashboard. The data preparation steps are clearly described by the authors. To train the VQ-GAN:
python main.py --base configs/faceshq_vqgan.yaml -t True --gpus 0,
You can select a different .yaml file depending on the dataset of your choice.

Training Transformer

With the trained encoder and decoder, we can now represent an image in terms of the codebook indices of its encoding. The quantized encoding of an image x is given by z_q = q(E(x)) \in R^{h \times w \times n_z}. This encoding can be expressed as a sequence s \in \{0, ..., |Z|-1\}^{h \times w} of indices from the codebook, obtained by replacing each code by its index in the codebook Z, i.e., s_{ij} = k such that (z_q)_{ij} = z_k. One can then map the sequence s back to the corresponding codebook entries and decode it to an image \tilde{x} = G(z_q).
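A small sketch of this round trip, reusing the illustrative quantize()/codebook names from the earlier snippet (again a sketch, not the repository's API):
def indices_to_sequence(index_map):            # (B, h, w) index map -> flat sequence s of length h*w
    return index_map.reshape(index_map.shape[0], -1)

def sequence_to_zq(s, codebook, h, w):         # sequence s -> spatial codebook entries z_q
    return codebook[s].reshape(s.shape[0], h, w, -1)

# decoded image: x_tilde = G(sequence_to_zq(s, codebook, h, w))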
With this sequence, image generation can be formulated as autoregressive next-index prediction. Given indices s_{<i}, the transformer learns to predict the distribution of possible next indices, i.e., p(s_i | s_{<i}), which gives the likelihood of the full representation as p(s) = \prod_i p(s_i | s_{<i}). The objective is thus to maximize the log-likelihood of the data representations. Here's a really nice article on likelihood.
Any image generation system is more useful if the user can control the generation process. Image generation can be conditioned on additional information such as class labels (as shown above) or partial images. The task is then to learn the likelihood of the sequence given this information c:
p(s|c) = \prod_i p(s_i | s_{<i}, c)
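Conceptually, sampling then proceeds index by index. The sketch below shows what such conditional, autoregressive sampling could look like; the transformer interface (returning logits over the |Z| codebook indices for every position) is a hypothetical simplification of the official sampling code:
import torch

@torch.no_grad()
def sample_sequence(transformer, c, seq_len, temperature=1.0):
    # Autoregressively sample s_i ~ p(s_i | s_<i, c), starting from the conditioning tokens c
    s = c
    for _ in range(seq_len):
        logits = transformer(s)[:, -1, :] / temperature   # distribution over the next index
        probs = torch.softmax(logits, dim=-1)
        next_idx = torch.multinomial(probs, num_samples=1)
        s = torch.cat([s, next_idx], dim=1)
    return s[:, c.shape[1]:]                              # drop the conditioning prefix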
With the trained VQ-GAN, adjust the checkpoint path of the config key model.params.first_stage_config.params.ckpt_path in configs/faceshq_transformer.yaml, then run:
python main.py --base configs/faceshq_transformer.yaml -t True --gpus 0,

Generating High-Resolution Images

The computational cost of training a transformer increases quadratically with sequence length, which puts a hard limit on the feasible sequence length. Thus, to generate images in the megapixel regime, the authors work patch-wise and crop images to restrict the length of s to a maximally feasible size during training.
To sample images, the transformer is used in a sliding-window manner.
Figure 4: The sliding attention window
The VQ-GAN ensures that the available context is still sufficient to faithfully model images, as long as either the statistics of the dataset are approximately spatially invariant or spatial conditioning information is available.
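A very rough sketch of the sliding-window idea (purely illustrative; sample_next is a hypothetical function that samples one index from the transformer given a local window of already generated codes):
import torch

def sample_large_grid(sample_next, h, w, window=16):
    # Generate an h x w grid of codebook indices, letting the transformer attend
    # only to a local sliding window of previously generated indices.
    grid = torch.zeros(1, h, w, dtype=torch.long)
    for i in range(h):
        for j in range(w):
            i0 = max(0, i - window + 1)
            j0 = min(max(0, j - window // 2), max(0, w - window))
            context = grid[:, i0:i + 1, j0:j0 + window]   # local window around position (i, j)
            grid[0, i, j] = sample_next(context, i - i0, j - j0)
    return grid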
In the media panel shown below, click on the ⚙️ and use the slider to visualize how a high-resolution image was synthesized using the sliding window.

Results

Further Reading and Conclusion

The goal of this report is to summarize the paper and make it more accessible to readers. In places I have used lines from the paper directly because that was the best way to convey the information.
Here is some further reading that you might find exciting.