Generate images from a text prompt in this interactive report: DALL·E on a smaller architecture. Made by Boris Dayma using Weights & Biases
Generated with DALL·E mini as "logo of an armchair in the shape of an avocado"
As part of the FLAX/JAX community week
organized by 🤗 Hugging Face and the Google Cloud team, we worked on reproducing the results of OpenAI's DALL·E
with a smaller architecture. DALL·E can generate new images from any text prompt.
We show we can achieve impressive results (albeit of a lower quality) while being limited to much smaller hardware resources. Our model is 27 times smaller than the original DALL·E and was trained on a single TPU v3-8 for only 3 days.
By simplifying the architecture and the model's memory requirements, as well as leveraging available open-source code and pre-trained models, we were able to satisfy a tight timeline.
DALL·E mini project timeline
Demo of DALL·E mini
We used 3 datasets for our model:
The OpenAI subset
which contains about 15 million images and which we further sub-sampled to 2 million images due to limitations in storage space. We used both the title and the description as the caption and removed HTML tags, newlines, and extra spaces.
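A minimal sketch of this caption preprocessing (a hypothetical helper, not our actual pipeline code):

```python
import re

def clean_caption(title: str, description: str) -> str:
    """Join title and description into one caption, then strip HTML tags,
    newlines and extra spaces, as described above."""
    text = f"{title} {description}"
    text = re.sub(r"<[^>]+>", "", text)       # remove HTML tags
    text = text.replace("\n", " ")            # remove newlines
    text = re.sub(r"\s+", " ", text).strip()  # collapse extra spaces
    return text

print(clean_caption("An <b>avocado</b> chair", "a chair shaped\nlike an  avocado"))
# → An avocado chair a chair shaped like an avocado
```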
For fine-tuning our image encoder, we only used a subset of 2 million images.
We used all the images we had (about 15 million) for training our Seq2Seq model.
During training, images and descriptions are both available and pass through the system as follows:
- Images are encoded through a VQGAN encoder, which turns images into a sequence of tokens.
- Descriptions are encoded through a BART encoder.
- The output of the BART encoder and the encoded images are fed into the BART decoder, an auto-regressive model whose goal is to predict the next token.
- The loss is the softmax cross-entropy between the model's prediction logits and the actual image encodings from the VQGAN.
Training pipeline of DALL·E mini
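As a toy illustration of this loss (not the actual training code), the softmax cross-entropy at a single decoding position can be computed as:

```python
import math

def softmax_cross_entropy(logits, target_token):
    """-log p(target) under the softmax of the decoder's logits over the
    image vocabulary, where `target_token` is the VQGAN token actually
    present at this position."""
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_z - logits[target_token]

# One decoding position with a tiny 4-token "vocabulary":
logits = [2.0, 0.5, -1.0, 0.1]  # decoder prediction
target = 0                      # token from the VQGAN encoding of the image
loss = softmax_cross_entropy(logits, target)
```

In training, this quantity is averaged over all 256 image-token positions in the batch.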
At inference time, we only have captions available and want to generate images:
- The caption is encoded through the BART encoder.
- A special "Beginning Of Sequence" token is fed through the BART decoder.
- Image tokens are sampled sequentially based on the decoder's predicted distribution over the next token.
- Sequences of image tokens are decoded through the VQGAN decoder.
- CLIP is used to select the best generated images.
Inference pipeline of DALL·E mini
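The sequential sampling step can be sketched in plain Python (the `decoder` and `vqgan_decode` calls in the trailing comment are hypothetical placeholders, not real APIs):

```python
import math
import random

def sample_next_token(logits, temperature=1.0, rng=random):
    """Sample one image token from the decoder's predicted distribution
    (softmax over the logits), as in the inference loop described above."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract max before exp for numerical stability
    weights = [math.exp(l - m) for l in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]

# Hypothetical autoregressive loop:
#   tokens = [BOS]
#   for _ in range(256):
#       tokens.append(sample_next_token(decoder(encoded_caption, tokens)))
#   image = vqgan_decode(tokens[1:])
```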
For encoding & decoding images, we use a VQGAN.
The goal of the VQGAN is to encode an image into a sequence of discrete tokens that can be used in transformer models, which have proved to be very efficient in NLP.
Source: Taming Transformers for High-Resolution Image Synthesis
If we instead used a raw sequence of pixel values, the space of discrete values would be far too large, making it extremely difficult to train a model and to satisfy the memory requirements of self-attention layers.
The VQGAN learns a codebook of visual patterns by using a combination of a perceptual loss and a GAN discriminator loss. The encoder outputs the indices of the corresponding codebook entries.
Once the image is encoded into a sequence of tokens, it can then be used in any transformer model.
In our model, we encode images into 16 x 16 = 256 discrete tokens from a vocabulary of size 16,384, using a reduction factor f = 16 (4 downsampling blocks, each dividing width and height by 2). Decoded images are then 256 x 256 pixels (16 tokens x 16 pixels per side).
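The arithmetic behind these numbers:

```python
# Reduction factor f comes from 4 downsampling blocks, each halving
# width and height:
f = 2 ** 4                         # 16
image_side = 256                   # decoded image resolution in pixels
tokens_per_side = image_side // f  # 16 tokens per side
num_tokens = tokens_per_side ** 2  # 256 tokens per image
vocab_size = 16384                 # size of the VQGAN codebook
```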
A seq2seq model transforms a sequence of tokens into another sequence of tokens and is typically used in NLP for tasks such as translation, summarization or conversational modeling.
The same idea can be transferred to computer vision once images have been encoded into discrete tokens.
Source: BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension
Our model uses BART, where the input corresponds to the description and the output is the corresponding image encoded by the VQGAN.
We only had to make a few adjustments to the original architecture:
- Create independent embedding layers for the encoder and decoder (they can often be shared when inputs and outputs are of the same type).
- Adjust the decoder's input and output shapes to the VQGAN vocabulary size (not needed for the intermediate embedding layers).
- Force the generated sequence to 256 tokens (not counting the special tokens that identify the beginning and end of sequences).
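These adjustments can be summarized as configuration values. The names below are illustrative, not the project's actual configuration:

```python
vqgan_vocab_size = 16384  # image codebook size (decoder output space)
text_vocab_size = 50264   # BART tokenizer vocabulary (encoder input space)
image_seq_len = 256       # 16 x 16 image tokens, always generated in full

decoder_config = {
    # independent embeddings: text tokens in, image tokens out
    "encoder_vocab_size": text_vocab_size,
    "decoder_vocab_size": vqgan_vocab_size,  # plus special tokens in practice
    # generation is forced to exactly 256 image tokens
    "min_length": image_seq_len,
    "max_length": image_seq_len,
}
```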
CLIP is a neural network that scores how well images and text match each other.
It is trained using contrastive learning, which consists of maximizing the similarity between the embeddings of matching image-text pairs (their cosine similarity) and minimizing it for non-matching pairs.
Source: Learning Transferable Visual Models From Natural Language Supervision
When generating images, we perform random sampling of image tokens based on the model logits distribution, which leads to diverse samples but of unequal quality.
CLIP lets us select the best generated samples by giving a score to the generated images against their input description. We directly use the pre-trained version from OpenAI in our inference pipeline.
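A minimal sketch of this scoring step, assuming the embeddings have already been produced by the pre-trained CLIP image and text encoders:

```python
import math

def cosine_similarity(u, v):
    """Cosine similarity between two embedding vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def select_best(image_embeddings, text_embedding, k=8):
    """Rank candidate images by their CLIP score against the caption
    embedding and return the indices of the top k."""
    ranked = sorted(
        range(len(image_embeddings)),
        key=lambda i: cosine_similarity(image_embeddings[i], text_embedding),
        reverse=True,
    )
    return ranked[:k]

# e.g. pick the 2 best of 3 candidates for one caption embedding:
best = select_best([[1.0, 0.0], [0.7, 0.7], [0.0, 1.0]], [1.0, 0.1], k=2)
# → [0, 1]
```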
How does it compare to OpenAI DALL·E?
We are grateful for the research and pre-trained models published by OpenAI which were essential in building our model.
Not all the details of DALL·E are public knowledge, but here are what we consider to be the main differences:
- DALL·E uses a 12-billion-parameter version of GPT-3. In comparison, our model is 27 times smaller, with about 0.4 billion parameters.
- We heavily leverage pre-trained models (VQGAN, BART encoder and CLIP) while OpenAI had to train all of their models from scratch. Our model architecture takes into account the available pre-trained models and their efficiency.
- DALL·E encodes images using a larger number of tokens (1024 vs 256) from a smaller vocabulary (8192 vs 16,384). DALL·E uses a VQVAE while we use a VQGAN.
- DALL·E encodes text using fewer tokens (at most 256 vs 1024) and a smaller vocabulary (16,384 vs 50,264).
- DALL·E reads text and images as a single stream of data while we split them between the Seq2Seq encoder and decoder. This also lets us use independent vocabularies for text and images.
- DALL·E reads the text through an auto-regressive model while we use a bidirectional encoder.
- DALL·E was trained on 250 million pairs of images and text while we used only 15 million pairs.
Those differences have led to efficient training that can be performed on a single TPU v3-8 in 3 days.
Since we automatically checkpoint and log our model every hour, we could use preemptible TPU instances ($2.40/h at the time of this report), meaning a training cost of less than $200 for our model. This does not include our experimentation on TPUs and hyperparameter search, which would add about $1,000 in our case (TPU resources were actually provided for free as part of this project).
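The back-of-the-envelope arithmetic:

```python
# Approximate training cost on a preemptible TPU v3-8:
hours = 3 * 24      # ~3 days of training
rate = 2.40         # $/hour at the time of this report
cost = hours * rate # ~ $173, under the $200 mentioned above
```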
Images generated by DALL·E are still of a much higher quality than our model's but it's interesting to observe we can train a reasonably good model with few resources.
Training the model
Training the VQGAN
While extremely efficient at encoding a wide range of images, the pre-trained checkpoint was not good at encoding people and faces (they are not frequent in ImageNet), so we decided to fine-tune it for about 20 hours on a cloud instance with 2 x RTX A6000 GPUs.
The quality of generated images didn't improve much on faces, probably due to mode collapse. It would be worthwhile to retrain it from scratch in the future.
Once the model was trained, we converted our PyTorch model to JAX for the next phase.
Training DALL·E mini
The model is programmed in JAX to take full advantage of the TPUs.
We pre-encoded all our images with the image encoder for faster data loading.
We quickly settled on a few parameters that seemed to work well:
- Batch size per TPU chip per step: 56, to max out the memory available per TPU.
- Gradient accumulation: 8 steps, for an effective batch size of 56 x 8 TPU chips x 8 steps = 3,584 images per update.
- An optimizer chosen for its memory efficiency, which let us use a higher batch size.
- Learning rate with 2,000 warmup steps followed by a linear decay.
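A minimal sketch of such a warmup-plus-linear-decay schedule (the peak value matches the 0.005 we settled on; the total step count here is illustrative):

```python
def learning_rate(step, peak_lr=0.005, warmup_steps=2000, total_steps=100_000):
    """Linear warmup to peak_lr over warmup_steps, then linear decay to 0."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    decay = (total_steps - step) / (total_steps - warmup_steps)
    return peak_lr * max(decay, 0.0)
```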
We dedicated half a day to finding a good learning rate for our model by launching a hyper-parameter search.
After our preliminary search, we experimented with a few different learning rates for a longer period until we finally settled on 0.005.
Training could have continued longer as the evaluation loss was still improving well but the project was ending (as was the availability of the TPU VM).
For each prompt, we generate 128 images and select the best 8 images with CLIP.
Note: the Unreal Engine trick does not seem to affect our model. It is possible that our dataset did not have such image-text pairs and that including them would affect the predictions.
Evolution of predictions during training
We can clearly see how the quality of generated images improved as the model trained.
Visualize different examples by clicking on ⚙️ at the top left of the panel and changing index.
How do our results compare with OpenAI's DALL·E?
It is interesting to note that OpenAI often uses very long repeating prompts such as:
a storefront that has the word 'openai' written on it. a storefront that has the word 'openai' written on it. a storefront that has the word 'openai' written on it. openai storefront.
This may be due to having their prompts defined as the concatenation between image titles and descriptions.
We did not notice any significant impact in using longer prompts in our model.
How do our results compare to DALLE-pytorch?
It offers many different models, allows for plenty of customization (model size, image encoder, custom attention heads, etc), and seems to have been trained on similar datasets as ours. It has shown impressive results especially when trained on smaller datasets.
Both models can generate impressive results, especially on landscapes.
Overall, DALL·E mini seems to be able to produce more relevant images, of a slightly better quality, and with more details.
Visualize different examples by clicking on ⚙️ at the top left of the panel and changing index.
How do our results compare to "Generator + CLIP"?
There are several models available which consist of a generator coupled with CLIP to create images (such as "VQGAN + CLIP").
These models have a completely different approach. Each image prediction is actually the result of an optimization process where we iterate over the latent space of the generator (image encoding space) to directly maximize the CLIP score between generated image and description.
An interesting aspect of this method is that we can start the iteration either from a random image or from a pre-selected image. It can also be used at any image resolution, constrained only by GPU RAM and optimization time.
Sample predictions using "VQGAN + CLIP"
This technique is slower and is mostly used for generating artistic images, which can be unrealistic but of a higher resolution.
Limitations and biases
During our experiments, we observed several limitations:
- Watermarks are often present on generated samples.
- Faces and people in general are not generated properly.
- Animals are usually unrealistic.
It is hard to predict where the model excels or falls short. For example the model is great at generating "a logo of an armchair in the shape of an avocado" but cannot produce anything relevant for "a logo of a computer" (in this case we need to adjust to "an illustration of a computer"). Reformulating matters! The goal is to write a description similar to what could have been seen during training. Good prompt engineering will lead to the best results.
The model has only been trained with English descriptions and will not perform well in other languages. This could potentially be fixed by adding a translation service or model to our inference pipeline, but needs to be evaluated in more detail.
Overall it is difficult to investigate in much detail the model biases due to the low quality of generated people and faces, but it is nevertheless clear that biases are present:
Occupations demonstrating higher levels of education (such as engineers, doctors or scientists) or high physical labor (such as in the construction industry) are mostly represented by white men. In contrast, nurses, secretaries or assistants are typically women, often white as well.
Most of the people generated are white. It's only in specific examples, such as athletes, that we see different races, though most of them are still under-represented.
The dataset is limited to pictures with English descriptions, preventing text and images from non-English-speaking cultures from being represented.
We study the context in which several sensitive terms related to gender, age, race, ethnicity appear such as “black”, “white”, “asian”, “african”, “american”, “indian”, “man/men”, “woman/women”, “boy”, “girl”, “young”, “old”, etc. We do not observe any large biases in the distribution of these terms, either in terms of co-occurrence between sensitive term pairs or co-occurrence with other tokens. Furthermore, we check the distribution of web domains and, similar to visual concepts, we find this to be diverse and long-tail: >100K with >40K contributing >10 samples. We take our preliminary study as a positive indication of no severe biases stemming from particular domains or communities.
Since this dataset represents only 70% of all the data we used, it is possible that bias was introduced by:
- the other datasets
- the model itself
- our training pipeline
- our inference pipeline
- the pre-trained models we used (mainly the BART encoder and CLIP during scoring)
- a combination of all the above, including potentially undetected bias from the preliminary study done with Conceptual 12M
Since we are releasing a public demo, we will be able to collect feedback from users and get more understanding of our model's limitations and biases. The next step will be to find ways to mitigate them.
Some improvements can be made on the model:
- We can use a larger dataset; we didn't use all the images we had available.
- We need to better filter the dataset: duplicates, low-quality images, watermarks, bad descriptions, etc. Neural networks could be helpful for these tasks.
- We can improve how we pre-process the titles and descriptions of images and concatenate them based on their quality.
- We can test different types of tokenizers and encoders.
- We can try to normalize the text: all lower case (though case probably helps identify names and places), no punctuation, filtering allowed characters.
Our model is limited by the quality of the Image Encoder/Decoder.
- We can train the VQGAN from scratch, which could limit some of the mode collapse acquired by a pre-trained model (though not necessarily avoid it completely).
- We can scale up the model.
- We can train longer and leverage more hardware resources.
- We can try generating the image in a different sequence (for example, starting from the center).
Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever. "Learning Transferable Visual Models From Natural Language Supervision"
Tom B. Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, Sandhini Agarwal, Ariel Herbert-Voss, Gretchen Krueger, Tom Henighan, Rewon Child, Aditya Ramesh, Daniel M. Ziegler, Jeffrey Wu, Clemens Winter, Christopher Hesse, Mark Chen, Eric Sigler, Mateusz Litwin, Scott Gray, Benjamin Chess, Jack Clark, Christopher Berner, Sam McCandlish, Alec Radford, Ilya Sutskever, Dario Amodei. "Language Models are Few-Shot Learners"