DALL·E mini

Generate images from a text prompt in this interactive report: DALL·E on a smaller architecture. Made by Boris Dayma using Weights & Biases
Generated with DALL·E mini as "logo of an armchair in the shape of an avocado"

Introduction

As part of the FLAX/JAX community week organized by 🤗 Hugging Face and the Google Cloud team, we worked on reproducing the results of OpenAI's DALL·E with a smaller architecture. DALL·E can generate new images from any text prompt.
We show that we can achieve impressive results (albeit of lower quality) while being limited to much smaller hardware resources. Our model is 27 times smaller than the original DALL·E and was trained on a single TPU v3-8 for only 3 days.
By simplifying the architecture and reducing the model's memory requirements, and by leveraging available open-source code and pre-trained models, we were able to meet a tight timeline.
DALL·E mini project timeline
Our code and our interactive demo are available for experimenting with any prompt!
Demo of DALL·E mini

Datasets

We used 3 datasets for our model:
For fine-tuning our image encoder, we only used a subset of 2 million images.
We used all the images we had (about 15 million) for training our Seq2Seq model.

Model Architecture

Overview

During training, images and descriptions are both available and pass through the system as follows:
Training pipeline of DALL·E mini
At inference time, we only have captions available and want to generate images:
Inference pipeline of DALL·E mini
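To make this flow concrete, here is a minimal sketch of the inference pipeline in Python. The `tokenizer`, `bart`, `vqgan`, and `clip` objects and their methods are hypothetical stand-ins for the actual models, not our exact code.

```python
import numpy as np

def generate_images(prompt, tokenizer, bart, vqgan, clip, n_candidates=128, n_best=8):
    """Hypothetical sketch of the DALL·E mini inference pipeline."""
    # 1. Tokenize the caption and sample candidate image-token sequences
    #    (256 discrete tokens each) from the BART-like seq2seq model.
    text_tokens = tokenizer(prompt)
    candidates = [bart.sample(text_tokens) for _ in range(n_candidates)]

    # 2. Decode each token sequence into a 256 x 256 image with the VQGAN decoder.
    images = [vqgan.decode(tokens) for tokens in candidates]

    # 3. Score every image against the prompt with CLIP and keep the best ones.
    scores = np.array([clip.score(image, prompt) for image in images])
    best = np.argsort(scores)[::-1][:n_best]
    return [images[i] for i in best]
```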

Image Encoder/Decoder

For encoding & decoding images, we use a VQGAN.
The goal of the VQGAN is to encode an image into a sequence of discrete tokens that can be used in transformer models, which have proven very effective in NLP.
Source: Taming Transformers for High-Resolution Image Synthesis
If we used raw pixel values directly, the sequence for a 256 x 256 RGB image would contain nearly 200,000 values, making it extremely difficult to train a model and to satisfy the memory requirements of self-attention layers.
The VQGAN instead learns a codebook representing image patches, using a combination of a perceptual loss and a GAN discriminator loss. The encoder outputs the indices of the codebook entries that best match each patch of the image.
Once the image is encoded into a sequence of tokens, it can then be used in any transformer model.
In our model, we encode images into 16 x 16 = 256 discrete tokens from a vocabulary of size 16,384, using a reduction factor f=16 (4 blocks, each dividing width and height by 2). Decoded images are then 256 x 256 pixels (16 tokens x 16 pixels per side).
For more details and better understanding of the VQGAN, please refer to Taming Transformers for High-Resolution Image Synthesis.
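As a toy illustration of the quantization step (not our actual VQGAN code), the sketch below assigns each spatial feature produced by the encoder to its nearest codebook entry, yielding the 16 x 16 grid of token indices described above; the shapes and names are illustrative.

```python
import jax.numpy as jnp

def quantize(features, codebook):
    """Nearest-codebook-entry lookup, as in a VQGAN encoder.

    features: (16, 16, d) feature map from the convolutional encoder
              (a 256 x 256 image downscaled by a factor f=16).
    codebook: (16384, d) learned codebook.
    Returns a (16, 16) grid of integer token indices (256 tokens in total).
    """
    flat = features.reshape(-1, features.shape[-1])            # (256, d)
    # Squared L2 distance between each feature vector and every codebook entry.
    distances = (
        jnp.sum(flat ** 2, axis=1, keepdims=True)
        - 2.0 * flat @ codebook.T
        + jnp.sum(codebook ** 2, axis=1)
    )                                                          # (256, 16384)
    indices = jnp.argmin(distances, axis=1)                    # one token per spatial position
    return indices.reshape(16, 16)
```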

Seq2Seq model

A seq2seq model transforms a sequence of tokens into another sequence of tokens and is typically used in NLP for tasks such as translation, summarization or conversational modeling.
The same idea can be transferred to computer vision once images have been encoded into discrete tokens.
Source: BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension
Our model uses BART, where the input corresponds to the description and the output is the corresponding image encoded by the VQGAN.
We only had to make a few adjustments to the original architecture:
For more understanding of BART, refer to BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension.
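To illustrate how image tokens are produced at inference time, here is a toy sketch of autoregressive sampling: the decoder predicts the 256 image tokens one by one over the 16,384-entry VQGAN vocabulary, conditioned on the encoded caption. The `encode_text` and `decoder_logits` callables and the start token are hypothetical stand-ins, not our actual implementation.

```python
import jax
import jax.numpy as jnp

def sample_image_tokens(encode_text, decoder_logits, caption_ids, rng,
                        seq_len=256, start_token=0):
    """Toy autoregressive sampling of VQGAN tokens from a BART-like seq2seq model.

    encode_text and decoder_logits are hypothetical callables standing in for
    the BART encoder and decoder; start_token is an illustrative BOS token.
    """
    memory = encode_text(caption_ids)                  # encoder states for the caption
    tokens = jnp.array([start_token])                  # decoding starts from a start token
    for _ in range(seq_len):
        logits = decoder_logits(tokens, memory)[-1]    # logits over the 16,384 image tokens
        rng, key = jax.random.split(rng)
        next_token = jax.random.categorical(key, logits)   # random sampling (see CLIP section)
        tokens = jnp.append(tokens, next_token)
    return tokens[1:]                                  # 256 image tokens for the VQGAN decoder
```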

CLIP

CLIP is a neural network that learns to associate images with text.
It is trained with contrastive learning, which consists of maximizing the cosine similarity between the embeddings of matching image-text pairs and minimizing it for non-matching pairs.
Source: Learning Transferable Visual Models From Natural Language Supervision
When generating images, we perform random sampling of the image tokens based on the model's logit distribution, which leads to diverse samples, albeit of unequal quality.
CLIP lets us select the best generated samples by giving a score to the generated images against their input description. We directly use the pre-trained version from OpenAI in our inference pipeline.
For more understanding of CLIP, refer to Learning Transferable Visual Models From Natural Language Supervision.
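In practice, a rough sketch of this reranking step looks like the following (assuming the Hugging Face Transformers CLIP API and checkpoint name; our actual inference code may differ): every candidate image is scored against the prompt and only the top ones are kept.

```python
import numpy as np
from transformers import CLIPModel, CLIPProcessor

# Assumed checkpoint; our demo may load CLIP differently.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def rank_images(images, prompt, n_best=8):
    """Score generated images against the prompt with CLIP and keep the best ones."""
    inputs = clip_processor(text=[prompt], images=images, return_tensors="pt", padding=True)
    logits_per_image = clip_model(**inputs).logits_per_image    # (n_images, 1) similarity scores
    scores = logits_per_image.squeeze(-1).detach().numpy()
    best = np.argsort(scores)[::-1][:n_best]
    return [images[i] for i in best]
```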

How does it compare to OpenAI DALL·E?

We are grateful for the research and pre-trained models published by OpenAI which were essential in building our model.
Not all the details of DALL·E are public knowledge, but here are what we consider to be the main differences:
Those differences have led to an efficient training procedure that can be performed on a single TPU v3-8 in 3 days.
Since we automatically checkpoint and log our model every hour, we could use preemptible TPU instances ($2.40/h at the time of this report), bringing the training cost to under $200 (3 days x 24 h x $2.40/h ≈ $173). This does not include our experimentation on TPUs and hyperparameter search, which would add about $1,000 in our case (TPU resources were actually provided for free as part of this project).
Images generated by DALL·E are still of much higher quality than our model's, but it is interesting to see that a reasonably good model can be trained with so few resources.

Training the model

Training the VQGAN

We started with a pre-trained checkpoint fine-tuned on ImageNet with a reduction factor f=16 and a vocabulary size of 16,384.
While extremely efficient at encoding a wide range of images, the pre-trained checkpoint was not good at encoding people and faces (which are not frequent in ImageNet), so we decided to fine-tune it for about 20 hours on a cloud instance with 2 x RTX A6000 GPUs.
The quality of generated faces didn't improve much, probably due to mode collapse. It would be worthwhile to retrain the VQGAN from scratch in the future.
Once the model was trained, we converted our PyTorch model to JAX for the next phase.
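The conversion itself essentially amounts to turning every PyTorch tensor into a JAX array. The sketch below is only a simplified illustration: a real VQGAN conversion also has to rename parameters and transpose convolution kernels to match the Flax layout.

```python
import jax.numpy as jnp
import torch

def torch_state_dict_to_jax(checkpoint_path):
    """Load a PyTorch checkpoint and convert every tensor to a JAX array.

    Only a sketch: a real VQGAN conversion also renames parameters and
    transposes convolution kernels to match the Flax parameter layout.
    """
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    return {name: jnp.asarray(tensor.numpy()) for name, tensor in state_dict.items()}
```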

Training DALL·E mini

The model is programmed in JAX to take full advantage of the TPUs.
We pre-encoded all our images with the image encoder for faster data loading.
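A sketch of this pre-encoding step is shown below; the batching, output format, and `vqgan_encode` function are illustrative, not our actual data pipeline. Each image goes through the VQGAN encoder once, and only its 256 token ids are stored.

```python
import numpy as np

def preencode_dataset(images, captions, vqgan_encode, output_path, batch_size=64):
    """Encode every image once into 256 VQGAN token ids and store them on disk.

    vqgan_encode is a stand-in for the JAX VQGAN encoder (batch of images in,
    integer token grids out). Storing tokens instead of pixels makes data
    loading during Seq2Seq training much faster.
    """
    all_tokens = []
    for start in range(0, len(images), batch_size):
        batch = np.stack(images[start:start + batch_size])
        tokens = np.asarray(vqgan_encode(batch))              # (batch, 16, 16) indices
        all_tokens.append(tokens.reshape(len(batch), -1))     # (batch, 256)
    np.savez(output_path, tokens=np.concatenate(all_tokens), captions=np.array(captions))
```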
We quickly settled on a few parameters that seemed to work well:
We dedicated half a day to finding a good learning rate for our model by launching a hyperparameter search.
After this preliminary search, we experimented with a few different learning rates over longer runs and finally settled on 0.005.
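One way to express such a learning-rate search is a W&B sweep. The configuration below is only a hypothetical example: the metric name, bounds, project name, and `train` function are illustrative, not the exact sweep we ran.

```python
import wandb

# Hypothetical sweep configuration for a learning-rate search.
sweep_config = {
    "method": "random",
    "metric": {"name": "eval/loss", "goal": "minimize"},
    "parameters": {
        "learning_rate": {"distribution": "log_uniform_values", "min": 1e-4, "max": 1e-2},
    },
}

def train():
    run = wandb.init()
    learning_rate = run.config.learning_rate
    # ... build the model with this learning rate, train for a fixed number
    # of steps, and log the evaluation loss with wandb.log(...) ...

sweep_id = wandb.sweep(sweep_config, project="dalle-mini")
wandb.agent(sweep_id, function=train, count=20)
```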
Training could have continued longer, as the evaluation loss was still improving steadily, but the project was ending (as was the availability of the TPU VM).

Results

Sample predictions

For each prompt, we generate 128 images and select the best 8 images with CLIP.
Note: the Unreal Engine trick does not seem to affect our model. It is possible that our dataset did not have such image-text pairs and that including them would affect the predictions.

Evolution of predictions during training

We can clearly see how the quality of generated images improved as the model trained.
Visualize different examples by clicking on ⚙️ at the top left of the panel and changing index.

How do our results compare with OpenAI's DALL·E

The model fails on several prompts published for OpenAI's DALL·E.
It is interesting to note that OpenAI often uses very long repeating prompts such as:
a storefront that has the word 'openai' written on it. a storefront that has the word 'openai' written on it. a storefront that has the word 'openai' written on it. openai storefront.
This may be because their prompts are built by concatenating image titles and descriptions.
We did not notice any significant impact from using longer prompts with our model.

How do our results compare to DALLE-pytorch

The best open-source version of DALL·E that we were aware of when developing our model was lucidrains/DALLE-pytorch.
It offers many different models, allows for plenty of customization (model size, image encoder, custom attention heads, etc.), and seems to have been trained on datasets similar to ours. It has shown impressive results, especially when trained on smaller datasets.
For this comparison, we use the checkpoint 16L_64HD_8H_512I_128T_cc12m_cc3m_3E.pt, which is the currently recommended model, and select the top 8 out of 128 predictions with CLIP to match our inference pipeline.
Both models can generate impressive results, especially on landscapes.
Overall, DALL·E mini seems to be able to produce more relevant images, of a slightly better quality, and with more details.
Visualize different examples by clicking on ⚙️ at the top left of the panel and changing index.

How do our results compare to "Generator + CLIP"

There are several models available which consist of a generator coupled with CLIP to create images (such as "VQGAN + CLIP").
These models take a completely different approach: each image prediction is the result of an optimization process in which we iterate over the generator's latent space (the image encoding space) to directly maximize the CLIP score between the generated image and the description.
An interesting aspect of this method is that we can start the iteration either from a random image or from a pre-selected image. It can also be used at any image resolution, constrained only by GPU memory and generation time.
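To make the contrast concrete, here is a heavily simplified sketch of such an optimization loop; the `vqgan_decode` and `clip_similarity` functions are hypothetical differentiable stand-ins, and a plain gradient step replaces the optimizers typically used in practice.

```python
import jax

def optimize_latent(vqgan_decode, clip_similarity, text_embedding, latent,
                    steps=300, step_size=0.1):
    """Toy 'VQGAN + CLIP' loop: gradient ascent on the CLIP score of the decoded image.

    vqgan_decode and clip_similarity are hypothetical differentiable stand-ins
    for the generator decoder and the CLIP image/text similarity.
    """
    def loss_fn(z):
        image = vqgan_decode(z)                           # decode the latent into an image
        return -clip_similarity(image, text_embedding)    # maximize similarity

    grad_fn = jax.value_and_grad(loss_fn)
    for _ in range(steps):
        _, grads = grad_fn(latent)
        latent = latent - step_size * grads               # plain gradient step (real code uses Adam)
    return vqgan_decode(latent)
```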
Sample predictions using "VQGAN + CLIP"
This technique is slower and mostly used for generating artistic images, which may be less realistic but of higher resolution.

Limitations and biases

During our experiments, we observed several limitations:
Overall, it is difficult to investigate the model's biases in much detail due to the low quality of generated people and faces, but it is nevertheless clear that biases are present:
According to the Conceptual 12M paper:
We study the context in which several sensitive terms related to gender, age, race, ethnicity appear such as “black”, “white”, “asian”, “african”, “american”, “indian”, “man/men”, “woman/women”, “boy”, “girl”, “young”, “old”, etc. We do not observe any large biases in the distribution of these terms, either in terms of co-occurrence between sensitive term pairs or co-occurrence with other tokens. Furthermore, we check the distribution of web domains and, similar to visual concepts, we find this to be diverse and long-tail: >100K with >40K contributing >10 samples. We take our preliminary study as a positive indication of no severe biases stemming from particular domains or communities.
Since this dataset represents only 70% of all the data we used, it is possible that bias was introduced by:
Since we are releasing a public demo, we will be able to collect feedback from users and better understand our model's limitations and biases. The next step will be to find ways to mitigate them.

Looking forward

Some improvements can be made to the model:

Resources

References

Authors

Acknowledgements