
Taming Transformers for High-Resolution Image Synthesis

The efficiency of convolutional approaches with the expressivity of transformers.
Created on February 10 | Last edited on September 14

This article presents an overview of the paper "Taming Transformers for High-Resolution Image Synthesis." This paper addresses the fundamental challenges of using the expressivity of transformers for high-resolution image synthesis.
The proposed approach represents images as a composition of perceptually rich image constituents and, in turn, utilizes transformers to efficiently model their composition within high-resolution images.
Here's a nice Twitter thread by Sander Dieleman summarizing this paper.

Paper | Project Page | GitHub

Check out my Colab Notebook to generate stunning images like this:

Run set 1


Introduction

Convolutional Neural Networks (CNNs) are the go-to models for vision tasks. This is because CNNs exhibit a strong locality bias (due to the use of kernels) and a bias towards spatial invariance through the use of shared weights (a kernel sweeps the entire image).
Compared to CNNs, transformers contain no inductive bias that prioritizes local interactions, allowing them to learn complex relationships between their inputs; in other words, they are more expressive. However, this is computationally expensive for long sequences: the increased expressivity of transformers comes with computational costs that grow quadratically with sequence length.
Esser et al. (2021) demonstrate high-resolution image synthesis by combining the effectiveness of CNNs with the expressivity of transformers.
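To see where the quadratic cost comes from, here is a minimal, unoptimized single-head self-attention sketch (projections omitted; not the paper's implementation). The point is the $n \times n$ attention matrix that must be materialized for a sequence of length $n$:
```python
import torch

def self_attention(x: torch.Tensor) -> torch.Tensor:
    """Minimal single-head self-attention (no learned projections, for illustration).

    x: (n, d) sequence of n tokens with d features.
    The scores matrix is (n, n), so memory and compute grow as O(n^2).
    """
    d = x.shape[-1]
    scores = x @ x.transpose(-1, -2) / d ** 0.5  # (n, n) pairwise interactions
    weights = torch.softmax(scores, dim=-1)
    return weights @ x                           # (n, d)

# A 256x256 image treated as a raw pixel sequence has n = 65,536 tokens,
# i.e. an attention matrix with roughly 4.3 billion entries per head and layer.
```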

Two Words on Generative Models

Given $x \sim p_{data}(x)$, where $p_{data}(x)$ is the true distribution describing the dataset $X$, the dataset consists of finite samples from this distribution:
$$X = \{x \mid x \sim p_{data}(x)\}$$
The generative task is to find a model such that $p_{model}(x; \theta) \approx p_{data}(x)$, where $\theta$ denotes the model parameters.
Figure 1: Generative Model Taxonomy (Source)
We can distinguish two main types of generative models:
  • Likelihood-based (explicit) models: These models provide an "explicit" parametric specification of the data distribution and have a tractable likelihood function. Examples include Variational Autoencoders (VAEs) and autoregressive models.
  • Implicit models: These models do not specify the distribution of the data itself, but instead define a stochastic process that, after training, aims to draw samples from the underlying data distribution. An example is Generative Adversarial Networks (GANs).
Here are the key ideas for a better understanding of this paper:
  • GANs, being implicit models, are hard to evaluate and often fail to cover all the modes of the data, making them more susceptible to mode collapse.
  • Larger-scale GANs can now generate high-quality and high-resolution images. However, it is well known that samples from these models do not fully capture the diversity of the true distribution. (Source)
  • Likelihood-based methods optimize the negative log-likelihood (NLL) of the training data, which allows for easier model comparison and better generalization to unseen data. However, maximizing likelihood in pixel space is challenging and computationally expensive.
  • In an autoregressive model, we assume that an example $x \in X$ can be represented as a sequence of components $x_i$. The distribution is factorized into a product of conditionals using the chain rule of probability: $p_{data}(x) = q(x_1)\,q(x_2|x_1) \cdots q(x_d|x_{d-1}, \ldots, x_1)$. The model predicts the next $x_i$ based on the previous $x_{<i}$ (see the short sketch after this list). Image generation has been successfully cast as an autoregressive sequence generation or transformation problem. (Source)
  • Conditional GANs (cGANs) are a simple yet effective modification to a regular GAN that allows for image generation based on conditions like a class label or another image (a partial image, a segmentation map, etc.). You can learn more about cGANs from this excellent blog post, and use this Colab to try out a simple cGAN.
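Here is a minimal sketch of the autoregressive factorization in code. The toy `next_token_logits` function stands in for any trained autoregressive model (for example, a transformer); it is an illustrative placeholder, not anything from the paper:
```python
import torch

vocab_size = 8

def next_token_logits(prefix: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for a trained model: uniform logits regardless of the prefix.
    return torch.zeros(vocab_size)

def sample_sequence(length: int) -> torch.Tensor:
    """Sample x_1, ..., x_length one element at a time, each from p(x_i | x_<i)."""
    seq = torch.empty(0, dtype=torch.long)
    for _ in range(length):
        probs = torch.softmax(next_token_logits(seq), dim=-1)  # p(x_i | x_<i)
        nxt = torch.multinomial(probs, num_samples=1)          # draw x_i
        seq = torch.cat([seq, nxt])                            # extend the prefix x_<i
    return seq

print(sample_sequence(5))
```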

Overview of the Proposed Method

Previous works that applied transformers to image generation demonstrated promising results for images up to a size of 64x64 pixels but could not be scaled to higher resolutions due to the quadratically increasing cost with sequence length.
To use transformers to synthesize higher-resolution images, we need a clever way to represent the semantics of an image. A raw pixel representation will not work: the number of pixels grows quadratically with resolution, so doubling the image resolution quadruples the sequence length.
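To make the scaling concrete, here is a quick back-of-the-envelope comparison of sequence lengths; the downsampling factor of 16 is an assumption based on the factors commonly used in the paper's VQ-GAN configs:
```python
# Sequence lengths: raw pixel tokens vs. VQ-GAN codebook-index tokens.
# f = 16 is an assumed downsampling factor, typical of the published configs.
f = 16
for res in (64, 256, 1024):
    pixel_tokens = res * res               # modeling raw pixels (ignoring channels)
    code_tokens = (res // f) * (res // f)  # modeling codebook indices
    print(f"{res}x{res}: {pixel_tokens:>9,} pixel tokens vs {code_tokens:>5,} code tokens")
```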

Vector Quantized Variational Autoencoder (VQ-VAE)

This work is inspired by the Vector Quantized Variational Autoencoder (VQ-VAE), which differs from traditional VAEs in two key ways:
  • The encoder network outputs discrete rather than continuous codes (the latent representation of the image).
  • The prior is learned instead of being a static multivariate normal distribution.
You can learn more about VAEs in my report Towards Deep Generative Modeling with W&B. This work, however, is more naturally an evolution of VQ-VAE-2, which showed how powerful representation learning can be in the context of autoregressive generative modeling.
Figure 2: Overview of VQ-VAE architecture. (Source)
VQ-VAE consists of an encoder $E(\cdot)$ that maps observations (images) onto a sequence of discrete latent variables, and a decoder $G(\cdot)$ that reconstructs the observations from these discrete variables. They share a codebook $e \in R^{K \times D}$, where $K$ is the number of discrete code vectors in the codebook and $D$ is the dimensionality of each code $e_i$, $i \in 1, 2, \ldots, K$.
As shown in Figure 2, an image $x$ is passed through $E$, producing $E(x)$. This output is then quantized based on its distance to the code vectors $e_i$, such that each vector in $E(x)$ is replaced by the nearest code vector in the codebook; the decoder reconstructs the image from these quantized vectors. The quantization is given by
$$\text{Quantization}(E(x)) = e_k \quad \text{where} \quad k = \underset{j}{\operatorname{argmin}} \, \| E(x) - e_j \|$$
Two final quick pointers:
  • Quantization is a non-differentiable step; thus, to enable end-to-end training, the gradient of the reconstruction error is back-propagated from the decoder to the encoder using the straight-through gradient estimator (the gradient is simply copied from the decoder's input to the encoder's output).
  • Besides the reconstruction loss, a codebook loss and a commitment loss are used. The codebook loss brings the selected code $e$ close to the output of the encoder, while the commitment loss ensures that the output of the encoder stays close to the chosen $e$. Both appear in the sketch below.
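Here is a minimal sketch of the quantization step with the straight-through estimator and the two auxiliary losses. It is illustrative only, not the official repository's implementation; tensor shapes and names are assumptions:
```python
import torch
import torch.nn.functional as F

def vector_quantize(z_e: torch.Tensor, codebook: torch.Tensor):
    """Nearest-codebook lookup with a straight-through gradient.

    z_e:      encoder output of shape (B, H, W, D)
    codebook: code vectors of shape (K, D)
    """
    flat = z_e.reshape(-1, z_e.shape[-1])                  # (B*H*W, D)
    dists = torch.cdist(flat, codebook)                    # distances to all K codes
    indices = dists.argmin(dim=1)                          # k = argmin_j ||E(x) - e_j||
    z_q = codebook[indices].view_as(z_e)                   # replace each vector by e_k

    # Straight-through estimator: forward pass uses z_q,
    # backward pass copies the gradient straight to z_e.
    z_q_st = z_e + (z_q - z_e).detach()

    # Codebook loss pulls the selected codes toward the encoder output;
    # commitment loss keeps the encoder output close to the chosen codes.
    codebook_loss = F.mse_loss(z_q, z_e.detach())
    commitment_loss = F.mse_loss(z_e, z_q.detach())
    return z_q_st, indices.reshape(z_e.shape[:-1]), codebook_loss, commitment_loss
```
In the official repository, a dedicated quantizer module plays this role inside the autoencoder.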

Why Does a Discrete Latent Representation Work?

The work Generating Diverse High-Fidelity Images with VQ-VAE-2, which resembles Taming Transformers to an extent, uses VQ-VAE-2 (a modification of VQ-VAE). But why does a discrete representation work?
The intuition is similar to JPEG lossy compression: JPEG encoding removes more than 80% of the data without noticeably changing the perceived image quality. In addition, training a generative model on a compressed, less noisy representation tends to work better.

Taming Transformers

Figure 3: Architecture design of Taming Transformers.
To keep the sequence length small and harness the expressivity of transformers, the authors use a discrete codebook of learned representations (inspired by VQ-VAE), such that an image $x \in R^{H \times W \times 3}$ can be represented by a spatial collection of codebook entries $z_q \in R^{h \times w \times n_z}$, where $n_z$ is the dimensionality of the codes.
As shown in Figure 3, the authors use VQ-GAN, a variant of the original VQ-VAE that adds a discriminator and a perceptual loss to keep good perceptual quality at an increased compression rate.
Training proceeds in two stages:
  • Training the VQ-GAN and learning the quantized codebook.
  • Training an autoregressive transformer that models sequences of the learned codebook indices.

Training VQ-GAN

The VQ-GAN is trained using a standard adversarial training procedure with a patch-based discriminator $D$. The complete objective for finding the optimal compression model $Q^* = \{E^*, G^*, Z^*\}$ is given by
$$Q^* = \underset{E, G, Z}{\operatorname{arg\,min}} \; \underset{D}{\operatorname{max}} \; E_{x \sim p(x)} \left[ L_{VQ}(E, G, Z) + \lambda L_{GAN}(\{E, G, Z\}, D) \right]$$
where $\lambda$ is an adaptive weight.
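The paper computes $\lambda$ adaptively from the gradients of the reconstruction and GAN losses with respect to the last layer of the decoder. Below is a rough sketch of that computation; the clamping constant and the exact call pattern are assumptions based on the public implementation, not a verbatim copy of it:
```python
import torch

def adaptive_weight(rec_loss, gan_loss, last_layer_weight, delta=1e-6):
    """Sketch of lambda = ||grad_GL(L_rec)|| / (||grad_GL(L_GAN)|| + delta),
    where G_L is the last layer of the decoder."""
    rec_grads = torch.autograd.grad(rec_loss, last_layer_weight, retain_graph=True)[0]
    gan_grads = torch.autograd.grad(gan_loss, last_layer_weight, retain_graph=True)[0]
    lam = rec_grads.norm() / (gan_grads.norm() + delta)
    return torch.clamp(lam, 0.0, 1e4).detach()  # clamp range is an assumption
```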
This training procedure significantly reduces the sequence length when unrolling the latent code and thereby enables the application of powerful transformer models.
In order to train your own VQ-GAN, you can clone the official GitHub repo for this paper and install the dependencies. The repository is already instrumented with Weights & Biases, so you automatically get all the necessary metrics in your W&B dashboard. The data preparation steps are clearly described by the authors. To train the VQ-GAN:
python main.py --base configs/faceshq_vqgan.yaml -t True --gpus 0,
You can select a different .yaml file depending on the dataset of your choice.

Training the Transformer

With the trained encoder and decoder, we can now represent an image in terms of the codebook indices of its encoding. The quantized encoding of an image $x$ is given by $z_q = q(E(x)) \in R^{h \times w \times n_z}$. This encoding can be expressed as a sequence $s \in \{0, \ldots, |Z|-1\}^{h \times w}$ of indices from the codebook, obtained by replacing each code with its index in the codebook $Z$: $s_{ij} = k$ such that $(z_q)_{ij} = z_k$. One can map the sequence $s$ back to the corresponding codebook entries and decode it to an image $\tilde{x} = G(z_q)$.
With this sequence, image generation can be formulated as autoregressive next-index prediction. Given indices $s_{<i}$, the transformer learns to predict the distribution of possible next indices, i.e., $p(s_i | s_{<i})$, and the likelihood of the full representation is $p(s) = \prod_i p(s_i | s_{<i})$. The objective is therefore to maximize the log-likelihood of the data representations. Here's a really nice article on likelihood.
An image generation system is most useful when the user can control the generation process. Image generation can be conditioned on additional information $c$, such as a class label (as shown above) or a partial image. The task then is to learn the likelihood of the sequence given this information:
$$p(s|c) = \prod_i p(s_i | s_{<i}, c)$$
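Putting the two stages together, sampling could look roughly like the sketch below: autoregressively draw codebook indices conditioned on $c$, look the indices back up in the codebook, and decode with $G$. The `transformer`, `codebook`, and `decoder` arguments and their interfaces are placeholders, not the repo's API:
```python
import torch

@torch.no_grad()
def sample_image(transformer, codebook, decoder, cond_tokens, h, w, temperature=1.0):
    """Sample indices s_i ~ p(s_i | s_<i, c) and decode them to an image.

    Assumes transformer(seq) returns logits over the |Z| codebook entries for the
    next index, and cond_tokens (B, T_c) encodes the conditioning information c.
    """
    seq = cond_tokens.clone()                            # start from the conditioning c
    for _ in range(h * w):
        logits = transformer(seq)[:, -1, :] / temperature  # p(s_i | s_<i, c)
        probs = torch.softmax(logits, dim=-1)
        nxt = torch.multinomial(probs, num_samples=1)
        seq = torch.cat([seq, nxt], dim=1)
    s = seq[:, cond_tokens.shape[1]:]                    # drop the conditioning part
    z_q = codebook[s].reshape(s.shape[0], h, w, -1)      # indices -> code vectors
    z_q = z_q.permute(0, 3, 1, 2)                        # (B, n_z, h, w) for a conv decoder
    return decoder(z_q)                                  # x~ = G(z_q)
```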
With the trained VQ-GAN, adjust the checkpoint path of the config key model.params.first_stage_config.params.ckpt_path in configs/faceshq_transformer.yaml, then run:
python main.py --base configs/faceshq_transformer.yaml -t True --gpus 0,



Generating High-Resolution Images

The quadratic growth in computational cost with sequence length limits how long the sequence $s$ can be when training a transformer. Thus, to generate images in the megapixel regime, the authors work patch-wise and crop images to restrict the length of $s$ to a maximally feasible size during training.
To sample images, the transformer is used in a sliding-window manner.
Figure 4: The sliding attention window
The VQ-GAN ensures that the available context is still sufficient to faithfully model images, as long as either the statistics of the dataset are approximately spatially invariant or spatial conditioning information is available.
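Here is a simplified sketch of the sliding-window idea over the latent index grid: each position is predicted from only a local window of already-generated indices. The window size, the `transformer(prefix)` interface, and the handling of the very first position are assumptions, not the authors' implementation:
```python
import torch

@torch.no_grad()
def sliding_window_sample(transformer, grid_h, grid_w, win=16, temperature=1.0):
    """Sample a latent index grid larger than the transformer's training context.

    Assumes grid_h, grid_w >= win and that the transformer handles an empty
    prefix (e.g. via an internal start/conditioning token).
    """
    grid = torch.zeros(grid_h, grid_w, dtype=torch.long)
    for i in range(grid_h):
        for j in range(grid_w):
            # Choose a win x win crop containing (i, j), pushed toward the
            # bottom-right so most of the crop is already generated.
            top = max(0, min(i - win + 1, grid_h - win))
            left = max(0, min(j - win + 1, grid_w - win))
            patch = grid[top:top + win, left:left + win].reshape(1, -1)
            pos = (i - top) * win + (j - left)   # index of (i, j) inside the crop
            logits = transformer(patch[:, :pos])[:, -1, :] / temperature
            probs = torch.softmax(logits, dim=-1)
            grid[i, j] = torch.multinomial(probs, num_samples=1)
    return grid
```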
In the media panel shown below, click on the ⚙️ and use the slider to visualize how a high-resolution image was synthesized using the sliding window.

Run set 1


Results


Run set 1

Run set 1


Further Reading and Conclusion

The goal of this report is to summarize the paper and make it more accessible to readers. I have used lines from the paper in places because that was the best way to convey the information.
Here are some further reads that you might find exciting.