
The Infrastructure Behind Serving DALL·E Mini

In this article, we explore the technology underpinning DALL·E Mini and set up a high-load backend infrastructure on Google TPUs.
Please note: This is not a model training post.
We started working on DALL·E Mini almost a year ago, as part of the HuggingFace JAX/Flax community event. It was our attempt to replicate OpenAI's DALL·E project but using more modest resources. We won first place and wrote a detailed post about the process.
Boris has kept working on it since then, vastly improving our initial effort, getting better data, training larger models (DALL·E Mega), and discovering all sorts of tweaks, workarounds, and practical tips to deal with large models. For all the ongoing training details and everything he's learning along the way, you should follow him and read his reports.
One of the great things about his work is that he's doing everything in the open. Not only is the model open source and the weights available as W&B Artifacts, but he also shares a good amount of his findings in Twitter threads and engages with other researchers to discuss ideas that may or may not work for the project. This includes the ability for anyone to try out the model at its current training stage and see for themselves what it can do or discover any limitations it has.
As part of our participation in that HuggingFace event, we created a live demo that you can use to create images using any text prompt you like. We've kept the demo alive and improved it to cope with increasingly higher load levels.
This article is about how we are managing to serve a sustained load of more than 10 requests per second (on a large, complicated model!), as the following plot attests. These are the figures I got when I started writing this post:


The blue line represents the total number of requests received per minute, and the orange/red lines show the portion of requests that errored for whatever reason (high load, usually). As you can see, our demo is holding up despite high load levels!

Resources

We are using Google Cloud TPU VMs for several reasons:
  • Google sponsored the community event and allowed us to use a TPU v3-8 for training. After the event, they provided us with free access to several TPU instances to keep working on the project. We are immensely grateful for their support.
  • The project was initially coded in JAX, and we found that JAX is great for parallelism (more on that later).
  • TPUs are inherently parallel machines, and JAX running on TPUs is a very powerful combination.
(As a side note, I think it's worth mentioning that Google offers the TPU Research Cloud program, in which any researcher can pitch their project and potentially get free TPU access for their training!)
TPUs come in two families: v3 and v2. The difference is that v3s have 16 GB of RAM per tensor device (think: a GPU card), while v2s have 8 GB per device. The instances we use have eight tensor devices each, hence Google names them v3-8 and v2-8. The tensor devices inside a TPU instance are connected by very high-speed, low-latency links.
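If you open a Python session on one of these TPU VMs, JAX reports the eight tensor devices directly. This is just an illustrative check, not output captured from our servers:

# Quick sanity check on a TPU VM: JAX should report the eight tensor devices.
import jax

print(jax.device_count())    # 8 on a v3-8 or v2-8
print(jax.devices()[0])      # e.g. TpuDevice(id=0, ...)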

Architecture

Our app is hosted in the HuggingFace hub as a Space. It was initially written using Streamlit, but we recently migrated it to Gradio because the new Blocks API provides a lot of flexibility to configure the UI the way you like. We also found Gradio to be very stable and resilient under heavy loads.
Many Spaces in the Hub run inside virtual computing environments provided by HuggingFace, and they perform inference using the resources the VM provides. Our situation is slightly different: our app runs in a Space, but we offload all requests to our own backend infrastructure. This allows us to take advantage of the TPUs Google lent us and to add or remove instances to adapt to the current load. For example, we've noticed huge spikes when OpenAI released their impressive next iteration of DALL·E (version 2), when Google released Imagen shortly after, or when Vinesauce tested our demo.
To add or remove instances on the fly, we use a small web server that runs nginx in a load-balancing configuration. This is a small, cheap machine with just 4 GB of system RAM, whose only mission is to forward requests to the TPUs doing the actual inference. To add a new instance, we simply update the configuration file and reload the service, with no downtime at all.

Dashboard

The load-balancing server is also in charge of gathering stats for the backend and logging them as W&B runs using a custom Python script. This is probably a little outside the way W&B was meant to be used, but all the results from our training and evaluation code are already logged to W&B. Why wouldn't we do the same with server metrics?
The data we log comes from two sources:
  • The nginx log file. This is where the plot shown in the introduction comes from.
  • A special endpoint running on the TPU servers. Every few seconds, we poll all the load-balanced TPU servers and gather information about response times, queue sizes, and the number of prompts being served per batch (a sketch of this polling loop follows the list).
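The following is a minimal sketch of that polling script. The server addresses, the /stats endpoint, and the metric and project names are assumptions for illustration, not the real configuration; the real script also parses the nginx log:

# Hedged sketch of the backend-monitoring script. Addresses, the /stats endpoint,
# and the metric names are illustrative assumptions, not our actual setup.
import time
import requests
import wandb

TPU_SERVERS = ["http://10.0.0.11:8000", "http://10.0.0.12:8000"]  # made-up addresses

wandb.init(project="dalle-mini-backend", job_type="monitoring")   # project name is illustrative

while True:
    for i, server in enumerate(TPU_SERVERS):
        try:
            stats = requests.get(f"{server}/stats", timeout=2).json()
            wandb.log({
                f"server_{i}/response_time": stats["response_time"],
                f"server_{i}/queue_size": stats["queue_size"],
                f"server_{i}/prompts_per_batch": stats["prompts_per_batch"],
            })
        except requests.RequestException:
            pass  # the instance may be down or restarting; skip this round
    time.sleep(10)  # poll every few seconds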
The following plot, for example, shows the evolution of response times in some of our TPU instances. The solid blue line is the average, and the shaded area indicates the dispersion across instances:


Correspondingly, this is the accumulated queue size in those machines:


Interestingly, the queue size for the less powerful v2-8 instances is about half that of their v3-8 counterparts, but the response time is about the same. nginx provides lots of knobs and options to configure load balancing in a way that makes sense for your service.


This simple dashboard allows us to react to events such as changes in the workload or TPUs becoming unavailable. It also helps uncover potential mistakes, as the following plot shows. Instead of showing the aggregated response times for the v2-8 instances, I chose to show them all and saw that one of the computers was responding faster than the others.
They are all identical and should take the same time to respond, so what is going on? Initially, I thought I must have made some mistake in the configuration of sr1, causing it to generate fewer images per batch, but I verified that it was fine. I then noticed that the version of the jax library was older than the one installed on its siblings, so we could be using inefficient or deprecated functions that were replaced in newer versions. This is something we need to test and verify!



JAX Parallelism: Optimizing Capacity

Our first inference implementation was a straightforward loop, similar to the one demonstrated in the project's inference notebook (which can run in Colab, by the way). We took a maximum of 8 prompts from different users and sent one prompt to each TPU device (remember that each TPU instance has 8 devices).
We looped a number of times to get several predictions per prompt, scored them with CLIP, and returned the best 9 of the lot. If the load was low (which was usually the case when we started, but not anymore), we just processed a single prompt on one of the devices.
There are three models involved in inference. The details are not important (you can read about them in the project report), but some idea of what they do is useful (a conceptual sketch follows the list):
  • A BART-like model to perform the sequence-to-sequence transformation from text tokens to image tokens.
  • A VQGAN model to decode those token sequences into images.
  • A CLIP model to select the best images, scoring how similar each one is to the supplied prompt.
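Conceptually, a single prompt flows through the three models like this. The function bodies below are stand-ins rather than the project's real code (the inference notebook shows the actual calls):

# Conceptual sketch only: the model functions are stubs standing in for the real ones.
import numpy as np

def bart_generate(prompt: str) -> np.ndarray:
    return np.zeros(256, dtype=np.int32)             # stand-in: text -> image token codes

def vqgan_decode(codes: np.ndarray) -> np.ndarray:
    return np.zeros((256, 256, 3), dtype=np.uint8)   # stand-in: image codes -> pixels

def clip_score(prompt: str, image: np.ndarray) -> float:
    return 0.0                                       # stand-in: prompt/image similarity

def generate(prompt: str, n_predictions: int = 16, top_k: int = 9):
    candidates = [vqgan_decode(bart_generate(prompt)) for _ in range(n_predictions)]
    scores = [clip_score(prompt, image) for image in candidates]
    best = np.argsort(scores)[::-1][:top_k]          # keep the k best-scoring images
    return [candidates[i] for i in best]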

Data Parallelization

The next step was to parallelize requests as much as we could. Using JAX's pmap function to distribute data across devices, and leveraging the einops package to expressively describe our data transformations, we prepared our inference to work like this (a minimal sketch follows the list):
  • We found that each TPU v3-8 was capable of generating 64 images at a time (8 per device) using half-precision bfloat16.
  • We accepted a maximum of 4 parallel prompts at a time, so we generated 16 images per prompt.
  • The first prompt is replicated in the first two devices, the second prompt takes the next two, and so on.
  • When the batch runs, each device is in charge of generating images for the prompt it was assigned.
  • After the images are generated, we extract them to a list, unpack them by prompt, and score the results with CLIP.
  • A final trick was to optimize for low-load cases. If we only received a single prompt, we generated all 64 images for that prompt so the quality of the selection would be higher! Similarly, we generated 32 images per prompt if just two requests arrived in parallel.
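Here is a minimal sketch of that layout, assuming 8 TPU devices and a stand-in generation function; the real implementation lives in the dalle-mini repository:

# Sketch of the data-parallel layout: 4 prompts, each replicated over 2 of the 8 devices.
import jax
import jax.numpy as jnp
from einops import repeat, rearrange

n_devices = jax.local_device_count()     # 8 on a v3-8
images_per_device = 8                    # 8 devices x 8 images = 64 images per batch in bfloat16

def generate_on_device(prompt_tokens, rng):
    # Stand-in for BART generation followed by VQGAN decoding on a single device.
    return jnp.zeros((images_per_device, 256, 256, 3), dtype=jnp.bfloat16)

p_generate = jax.pmap(generate_on_device)

prompts = jnp.ones((4, 64), dtype=jnp.int32)               # 4 tokenized prompts (dummy values)
per_device = repeat(prompts, "p s -> (p d) s", d=2)        # shape (8, 64): 2 devices per prompt
rngs = jax.random.split(jax.random.PRNGKey(0), n_devices)  # one RNG per device

images = p_generate(per_device, rngs)                      # shape (8, 8, 256, 256, 3)
images = rearrange(images, "(p d) i h w c -> p (d i) h w c", d=2)  # 16 images per prompt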
By this time, we had decided to replace the smaller version of the model (DALL·E Mini) with the larger one being trained (DALL·E Mega), so people would get the best predictions possible as model training progressed. Unfortunately, this model no longer fits in the smaller v2-8 instances.
Note, also, that each TPU device is doing exactly the same work, just on different prompts. Every device holds its own replica of the three models involved in the system.

Model Parallelization

The problem with the approach we just described is that the models are very large, leaving relatively little free TPU RAM for data. This happens on all 8 devices of each instance, because every device has to keep its own copy of the models.
The next trick Boris came up with was to apply Model Parallelization (MP), in addition to the Data Parallelization we had already put in place!
The basic idea is to split the model weights across devices so that the forward pass is distributed among them. Instead of each device running the same model on different data, each device contributes a different portion of the computation. Since we no longer need to replicate all the weights on every device, there's more room for data.
I had never used MP before. I understand there are specialized and very complex libraries to achieve the same on GPUs, so I was really amazed to see that JAX's native libraries already support this use case. In fact, JAX was designed with DP and MP in mind. Using pjit, we could distribute model weights across devices in much the same way that pmap allows us to distribute data.
After applying this technique, a single v3-8 instance was able to generate 128 images at once, using full precision. We had essentially doubled our capacity and fixed a problem where the VQGAN portion of the model sometimes generated black images when using half precision.
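The following is a minimal sketch of the idea, assuming 8 TPU devices and using a single weight matrix as a stand-in for a model layer; the mesh axis name and the shapes are illustrative, not the project's actual sharding setup:

# Sketch of model parallelism with pjit: shard one weight matrix across the devices.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.pjit import pjit

# One mesh axis ("mp") spanning the 8 tensor devices of a v3-8 or v2-8.
mesh = Mesh(np.array(jax.devices()), axis_names=("mp",))

def layer(x, w):
    return jnp.dot(x, w)   # stand-in for one transformer layer

# Shard the weight's output dimension across the devices; replicate the activations.
sharded_layer = pjit(
    layer,
    in_shardings=(P(None, None), P(None, "mp")),   # older JAX versions call this in_axis_resources
    out_shardings=P(None, "mp"),
)

with mesh:
    x = jnp.ones((16, 1024))    # activations (replicated on every device)
    w = jnp.ones((1024, 4096))  # weights: each device stores only its 1/8 slice
    y = sharded_layer(x, w)     # every device computes its slice of the output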

Making it work efficiently in v2-8 instances

The Data Parallel solution was not able to run in v2-8 instances because there was not enough RAM. Using MP, however, we could run batches of 32 images at a time. Our final trick was to:
  • Use half-precision for BART and CLIP.
  • Keep using full precision for VQGAN so we prevent the black image generation problem.
This allowed us to generate 128 images per TPU v2-8 instance so that we could add a few to our backend!
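A minimal sketch of that precision split, assuming Flax-style parameter pytrees with illustrative contents (not the project's actual parameter structure):

# Cast BART and CLIP parameters to bfloat16; keep VQGAN in float32.
import jax
import jax.numpy as jnp

def to_bf16(params):
    # Cast every floating-point leaf of a parameter pytree to bfloat16.
    return jax.tree_util.tree_map(
        lambda p: p.astype(jnp.bfloat16) if jnp.issubdtype(p.dtype, jnp.floating) else p,
        params,
    )

# Illustrative stand-ins for the real parameter trees.
bart_params = {"encoder": {"kernel": jnp.ones((4, 4))}}
clip_params = {"text_projection": jnp.ones((4, 4))}
vqgan_params = {"decoder": {"kernel": jnp.ones((4, 4))}}

bart_params = to_bf16(bart_params)   # BART runs in half precision
clip_params = to_bf16(clip_params)   # CLIP runs in half precision
# vqgan_params stays in float32 to avoid the black-image issue.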
This is what happened when we added the first TPU v2-8 instances on Saturday evening:
Introducing the first MP-enabled TPU v2-8 at about 7:40pm.
We were already at max capacity: a queue size of 100 is the maximum we configured, to keep response times at about 2 minutes at most. When we reach that queue size, we start rejecting requests. The new, smaller instances allowed us to get through a very heavy period that went on for a few more hours.
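The throttling behaviour itself boils down to a bounded queue. The sketch below shows the idea; it is not the project's actual implementation, and only the constant comes from the paragraph above:

# Sketch of bounded-queue throttling: accept up to MAX_QUEUE pending prompts, reject the rest.
import queue

MAX_QUEUE = 100                      # roughly two minutes of work at our generation speed
pending = queue.Queue(maxsize=MAX_QUEUE)

def accept_request(prompt: str) -> bool:
    try:
        pending.put_nowait(prompt)   # accepted: an inference worker will pick it up
        return True
    except queue.Full:
        return False                 # rejected: the client gets an error and can retry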
These gains could also be applied to our existing v3-8 instances: we could essentially double their capacity for free. Before we migrated them, however, Boris set up a test to examine the difference between running all models in full precision and moving to half-precision (except for VQGAN).
The following table shows the results: rows with a red number were generated by the full-precision model, while rows with a purple number were generated by bfloat16 + float32 VQGAN. If you click the arrows below the table, you can examine all the outputs yourself. In our opinion, the bfloat16 version yields perfectly fine results, so we migrated all our TPUs to take advantage of the increased capacity.



Beware of System Limits

The first figure in this post shows the live traffic we were getting when I started writing yesterday. Today, before wrapping up, traffic was already higher. But, more suspiciously, it was flat at about 1,200 requests per minute:
We get anxious whenever we see a flat pattern. It usually means we've reached a limit somewhere.
Even though I had already raised nginx's connection limit beyond the default value of 768, I had not thought about the total number of open files allowed per Unix process. As it turns out, we were hitting the default limit of 1024. In Unix, the file abstraction covers network connections too, so nginx was being prevented from even considering the excess requests. After raising the limit, we immediately saw this:
Releasing the kraken. Increasing the open files limit.
The huge and sudden spike is probably caused by a large number of pending requests that suddenly made it through to the server. Traffic appears to be stabilizing, but it's clear we are currently at about 2,000 requests per minute, or more than 30 per second. This is huge for a model that takes several seconds to run!
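If you want to check the per-process file-descriptor limit discussed above, it is easy to inspect. The snippet is just an illustration; raising the limit for nginx is done in the system or nginx configuration, not from Python:

# Inspect the per-process open-file limit (sockets count against it too).
import resource

soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
print(f"open files soft limit: {soft}, hard limit: {hard}")  # the soft limit often defaults to 1024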

Next Steps

  • We have a single load balancer, which is a single point of failure. In addition, we have servers in both the European and US regions of Google Cloud, and retrieving data from a different region incurs monetary costs (Google provides the TPUs, but we pay for everything else). My idea was to set up a load balancer per region, but Suvash recommended that I look into the Google Cloud Load Balancing service, which is designed exactly for that. Thanks, Suvash; I'll do that!
  • We could create automations to assign more TPUs as load increases and remove them when they are no longer necessary. We still have a few v2-8s that we could use!
  • We want to evaluate performance on GPU devices to see how they compare.

Final Thoughts

This is a project that grew organically from a single inference server and a simple demo into a system that supports thousands of requests every day. I don't know whether there are off-the-shelf solutions that achieve the same, but it seems hard to cover every aspect of the problem, including model tuning and optimization. Even so, I think this is an area with great potential for growth. I enjoyed crafting custom solutions, code, and visualizations, but this should be doable at scale. Who's up for the challenge?