
How To Train a Conditional Diffusion Model From Scratch

In this article, we look at how to train a conditional diffusion model and find out what you can learn by doing so, using W&B to log and track our experiments.
Created on September 30|Last edited on February 14
From DALL-E to Stable Diffusion, image generation is perhaps the most exciting thing in deep learning right now. But what use cases can we think about other than generating funny images to post on Twitter?
Recently, I found this excellent video on YouTube about programming a conditional diffusion model in PyTorch. I recommend watching this before continuing, as we will dig deeper into this code and train some models of our own!


The takeaway from this video is that you can train a model to generate data from a supervised dataset.
If you have a labeled dataset, you can generate synthetic data! This can be extremely useful for data that is complex or expensive to label.
💡
Here's what we'll be covering:

Table of Contents

  • Initial Checkup: CIFAR-10
  • Technical Details
  • Code
  • Sampling Images
  • Training a Model to Generate Fonts
  • Using fastai to Train your Diffusion Model
  • Conclusions

Let's get started!

Initial Checkup: CIFAR-10

The original code uses CIFAR-10 to train a conditional diffusion model, so let's do that first. Before we begin, here are a few improvements we made to the code base:
  • Added validation metrics (computed on the test data)
  • Enabled mixed-precision training and multi-threaded dataloaders
  • Added a OneCycleLR learning-rate scheduler
  • Added Weights & Biases (W&B) logging
You can find the code for these experiments here: https://github.com/tcapelle/Diffusion-Models-pytorch
💡
This telemetry enabled us to follow how the model was training, but scalar metrics alone are not enough for these models:

[W&B panel: training metrics for 3 runs]

As you can see, the loss (`train_mse`) is not very smooth, so you might think that the model is not learning anything. But if we plot sampled images (we run diffusion inference every 10 epochs and log the images to W&B), we can see that the model keeps improving. Moving the slider below, you can see how the model improves over time. We sample from both the regular model and the EMA copy.
The slider values correspond to the number of batches seen by the model; images were logged every 10 epochs.


NOTE: In the video, the author claims that the EMA model produces better output, but I am not sure.
💡

Technical Details

There are a bunch of little details that make this training possible. For starters, we did all the training on GCP machines with the latest PyTorch. We used a V100 (16 GB) for the initial CIFAR training before moving on to an A100 (40 GB). A few additional points:
  • The UNet that denoises the images is big: it has several self-attention layers and is very compute-heavy. Even if the images have a small resolution, self-attention is computed between every pair of pixels, so its cost is quadratic in the number of pixels (see the sketch after this list).
  • To get sharper results, we used an upscaled version of CIFAR-10 that has 64x64 images (instead of 32x32). You can find this dataset on Kaggle.
  • With mixed precision enabled, I could fit a batch size of 10 on the V100, up from 4 without it.
  • I also removed one of the deeper bottleneck (512 x 512) convolutional layers from the network to make it a little faster.
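To make that quadratic cost concrete, here is a minimal sketch of a self-attention layer over an image feature map. This is an illustration of the idea, not the repo's exact SelfAttention module (which lives in modules.py):
import torch
import torch.nn as nn

class SelfAttention2d(nn.Module):
    "Illustrative self-attention over an image feature map (not the repo's exact module)."
    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.mha = nn.MultiheadAttention(channels, num_heads, batch_first=True)
        self.ln = nn.LayerNorm(channels)

    def forward(self, x):
        b, c, h, w = x.shape
        # Every pixel becomes a token, so the attention matrix is (h*w) x (h*w):
        # doubling the image side multiplies the attention cost by roughly 16.
        tokens = x.view(b, c, h * w).permute(0, 2, 1)   # (b, h*w, c)
        normed = self.ln(tokens)
        attn, _ = self.mha(normed, normed, normed)
        tokens = tokens + attn                          # residual connection
        return tokens.permute(0, 2, 1).view(b, c, h, w)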

Code

The full code of this article can be found here. We will limit the discussion to the relevant parts of the codebase; if you have any questions, please contact me (or drop them in the comments below!).

The Model

The default unconditional diffusion model is a UNet with self-attention layers. We have the classic U structure with downsampling and upsampling paths. The main difference from a traditional UNet is that the up and down blocks take an extra timestep argument in their forward pass. The timestep is embedded and injected into the convolutions; for more details, check the modules.py file.
class UNet(nn.Module):
    def __init__(self, c_in=3, c_out=3, time_dim=256):
        super().__init__()
        self.time_dim = time_dim
        self.inc = DoubleConv(c_in, 64)
        self.down1 = Down(64, 128)
        self.sa1 = SelfAttention(128)
        self.down2 = Down(128, 256)
        self.sa2 = SelfAttention(256)
        self.down3 = Down(256, 256)
        self.sa3 = SelfAttention(256)

        self.bot1 = DoubleConv(256, 256)
        self.bot2 = DoubleConv(256, 256)

        self.up1 = Up(512, 128)
        self.sa4 = SelfAttention(128)
        self.up2 = Up(256, 64)
        self.sa5 = SelfAttention(64)
        self.up3 = Up(128, 64)
        self.sa6 = SelfAttention(64)
        self.outc = nn.Conv2d(64, c_out, kernel_size=1)

    def unet_forwad(self, x, t):
        "Classic UNet structure with down and up branches, self attention in between convs"
        x1 = self.inc(x)
        x2 = self.down1(x1, t)
        x2 = self.sa1(x2)
        x3 = self.down2(x2, t)
        x3 = self.sa2(x3)
        x4 = self.down3(x3, t)
        x4 = self.sa3(x4)

        x4 = self.bot1(x4)
        x4 = self.bot2(x4)

        x = self.up1(x4, x3, t)
        x = self.sa4(x)
        x = self.up2(x, x2, t)
        x = self.sa5(x)
        x = self.up3(x, x1, t)
        x = self.sa6(x)
        output = self.outc(x)
        return output

    def forward(self, x, t):
        "Positional encoding of the timestep before the blocks"
        t = t.unsqueeze(-1)
        t = self.pos_encoding(t, self.time_dim)
        return self.unet_forwad(x, t)
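The pos_encoding used above is defined in modules.py and not reproduced here. As a rough sketch (the exact implementation may differ), a standard sinusoidal timestep embedding looks like this:
def pos_encoding(self, t, channels):
    "Sinusoidal embedding of the timestep t into a vector of size channels (sketch)."
    inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2, device=t.device).float() / channels))
    pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
    pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
    return torch.cat([pos_enc_a, pos_enc_b], dim=-1)
Inside the Down and Up blocks, this embedding is typically passed through a small linear layer and added to the convolution output channel-wise, which is how the timestep conditions every resolution of the UNet.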
The conditional model is almost identical but adds the encoding of the class label into the timestep by passing the label through an Embedding layer. It is a very simple and elegant solution.
class UNet_conditional(UNet):
def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes=None):
super().__init__(c_in, c_out, time_dim)
if num_classes is not None:
self.label_emb = nn.Embedding(num_classes, time_dim)

def forward(self, x, t, y=None):
t = t.unsqueeze(-1)
t = self.pos_encoding(t, self.time_dim)

if y is not None:
t += self.label_emb(y)

return self.unet_forwad(x, t)

EMA Code
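The EMA code is collapsed in the original report. Assuming it follows the standard exponential-moving-average recipe from the video, a minimal sketch of the EMA class used below would be:
class EMA:
    "Keep an exponential moving average of the model weights (sketch)."
    def __init__(self, beta):
        self.beta = beta
        self.step = 0

    def update_model_average(self, ema_model, model):
        for ema_p, p in zip(ema_model.parameters(), model.parameters()):
            ema_p.data = self.beta * ema_p.data + (1 - self.beta) * p.data

    def step_ema(self, ema_model, model, step_start_ema=2000):
        # copy the weights verbatim for the first steps, then start averaging
        if self.step < step_start_ema:
            ema_model.load_state_dict(model.state_dict())
        else:
            self.update_model_average(ema_model, model)
        self.step += 1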

Training

We have refactored the code into small, focused functions. The training loop lives in the one_epoch function:
def train_step(self, loss):
    self.optimizer.zero_grad()
    self.scaler.scale(loss).backward()
    self.scaler.step(self.optimizer)
    self.scaler.update()
    self.ema.step_ema(self.ema_model, self.model)
    self.scheduler.step()

def one_epoch(self, train=True, use_wandb=False):
    avg_loss = 0.
    if train: self.model.train()
    else: self.model.eval()
    pbar = progress_bar(self.train_dataloader, leave=False)
    for i, (images, labels) in enumerate(pbar):
        # mixed precision, plus no-grad when running validation
        with torch.autocast("cuda"), (torch.inference_mode() if not train else torch.enable_grad()):
            images = images.to(self.device)
            labels = labels.to(self.device)
            t = self.sample_timesteps(images.shape[0]).to(self.device)
            x_t, noise = self.noise_images(images, t)
            if np.random.random() < 0.1:
                labels = None  # drop the label 10% of the time (classifier-free guidance)
            predicted_noise = self.model(x_t, t, labels)
            loss = self.mse(noise, predicted_noise)
            avg_loss += loss
        if train:
            self.train_step(loss)
            if use_wandb:
                wandb.log({"train_mse": loss.item(),
                           "learning_rate": self.scheduler.get_last_lr()[0]})
        pbar.comment = f"MSE={loss.item():2.3f}"
    return avg_loss.mean().item()
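The helpers sample_timesteps and noise_images are defined elsewhere in the repo. As a sketch of what they do (assuming self.alpha_hat holds the cumulative product of 1 - beta over the noise schedule), the forward-diffusion step noisifies x_0 into x_t in a single shot:
def sample_timesteps(self, n):
    "Draw a random timestep for each image in the batch (sketch)."
    return torch.randint(low=1, high=self.noise_steps, size=(n,))

def noise_images(self, x, t):
    "Forward diffusion: x_t = sqrt(alpha_hat)*x_0 + sqrt(1 - alpha_hat)*eps (sketch)."
    sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
    sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
    eps = torch.randn_like(x)
    return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * eps, eps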
Here you can see the first part of our W&B instrumentation: we log the training loss and the learning rate, so we can follow the scheduler we are using. To log the samples, we define a custom function that performs model inference:
@torch.inference_mode()
def log_images(self):
    "Log images to wandb and save them to disk"
    labels = torch.arange(self.num_classes).long().to(self.device)
    sampled_images = self.sample(use_ema=False, n=len(labels), labels=labels)
    ema_sampled_images = self.sample(use_ema=True, n=len(labels), labels=labels)
    plot_images(sampled_images)  # to display on jupyter if available
    # log images to wandb
    wandb.log({"sampled_images": [wandb.Image(img.permute(1,2,0).squeeze().cpu().numpy()) for img in sampled_images]})
    wandb.log({"ema_sampled_images": [wandb.Image(img.permute(1,2,0).squeeze().cpu().numpy()) for img in ema_sampled_images]})
And also a function to save the model checkpoints:
def save_model(self, run_name, epoch=-1):
    "Save model locally and to wandb"
    torch.save(self.model.state_dict(), os.path.join("models", run_name, "ckpt.pt"))
    torch.save(self.ema_model.state_dict(), os.path.join("models", run_name, "ema_ckpt.pt"))
    torch.save(self.optimizer.state_dict(), os.path.join("models", run_name, "optim.pt"))
    at = wandb.Artifact("model", type="model", description="Model weights for DDPM conditional", metadata={"epoch": epoch})
    at.add_dir(os.path.join("models", run_name))
    wandb.log_artifact(at)
Everything fits nicely into the fit function 😎
def prepare(self, args):
    "Prepare the model for training"
    setup_logging(args.run_name)
    device = args.device
    self.train_dataloader, self.val_dataloader = get_data(args)
    self.optimizer = optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=0.001)
    self.scheduler = optim.lr_scheduler.OneCycleLR(self.optimizer, max_lr=args.lr,
                                                   steps_per_epoch=len(self.train_dataloader), epochs=args.epochs)
    self.mse = nn.MSELoss()
    self.ema = EMA(0.995)
    self.scaler = torch.cuda.amp.GradScaler()

def fit(self, args):
    self.prepare(args)
    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")
        self.one_epoch(train=True)

        ## validation
        if args.do_validation:
            self.one_epoch(train=False)

        # log predictions
        if epoch % args.log_every_epoch == 0:
            self.log_images(use_wandb=args.use_wandb)

    # save model
    self.save_model(run_name=args.run_name, use_wandb=args.use_wandb, epoch=epoch)

Sampling Images

To sample images, we need to start from random noise and iteratively denoise it to obtain the final image. The procedure is very well described in "The Illustrated Stable Diffusion." Our case is much simpler, but the sampling is the same. It's a good read anyway for anyone interested in understanding the bits and bytes of the Stable Diffusion architecture.
In our case, we don't need a decoder, as our images are already full resolution after the UNet's output.
[Figure from "The Illustrated Stable Diffusion": denoising steps one at a time. We do exactly this, but without the image decoder part.]
The sampling code progressively removes noise from the image following a noise schedule. We start from pure random noise and end up with a sampled image. The code can be hard to read because the parameters are named after the equations in the DDPM paper.
We are using Algorithm 2 (sampling) from the paper:
x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
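For context, here is a stripped-down sketch of the full sampling loop around that line. It assumes precomputed beta, alpha = 1 - beta and alpha_hat = cumprod(alpha) tensors and a cfg_scale for classifier-free guidance; the real sample method in the repo may differ in its details:
@torch.inference_mode()
def sample(self, use_ema=False, n=1, labels=None, cfg_scale=3):
    "Start from pure noise and iteratively denoise it into n images (sketch)."
    model = self.ema_model if use_ema else self.model
    model.eval()
    x = torch.randn((n, 3, self.img_size, self.img_size), device=self.device)
    for i in reversed(range(1, self.noise_steps)):
        t = torch.full((n,), i, device=self.device, dtype=torch.long)
        predicted_noise = model(x, t, labels)
        if cfg_scale > 0:
            # classifier-free guidance: interpolate between unconditional and conditional predictions
            uncond_noise = model(x, t, None)
            predicted_noise = torch.lerp(uncond_noise, predicted_noise, cfg_scale)
        alpha = self.alpha[t][:, None, None, None]
        alpha_hat = self.alpha_hat[t][:, None, None, None]
        beta = self.beta[t][:, None, None, None]
        noise = torch.randn_like(x) if i > 1 else torch.zeros_like(x)
        x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
    # map from [-1, 1] back to [0, 255] uint8 images
    x = ((x.clamp(-1, 1) + 1) / 2 * 255).to(torch.uint8)
    return x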

Training a Model to Generate Fonts

Let's try another labeled dataset. For this simple example, we want a dataset with small images. After a search on Kaggle, the Alphabet characters fonts dataset seems like a good candidate. Let's take a look at what this dataset looks like.
Each row of the dataset is a font family, with the letters from A to Z rendered as 32x32 pixel black-and-white images. We can quickly log this to W&B and explore it:

[W&B panel: exploring the font dataset across 10 runs]

The idea here is to train a diffusion model to generate new fonts! We can apply the same idea as with CIFAR-10 and condition the generation on the actual letter we want to produce. As before, let's take a look at the training process by sampling outputs regularly during training:


The downside of this model is that there is no way to generate font families (or styles): we trained it on each letter separately, so we can only generate letters one at a time. Even if we pass the labels A-Z (or W&B 🤣), we get independent random letters each time.


You can check how I created these wandb.Tables in this notebook.
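As a rough idea of what that notebook does (the column names and data structure below are assumptions for illustration), logging the fonts as a wandb.Table can look like this:
import string
import wandb

def log_font_table(fonts):
    "Log one row per font family, with one image column per letter (sketch)."
    letters = list(string.ascii_uppercase)
    table = wandb.Table(columns=["font"] + letters)
    for name, letter_images in fonts:  # letter_images: 26 arrays of shape (32, 32)
        table.add_data(name, *[wandb.Image(img) for img in letter_images])
    wandb.log({"fonts_table": table})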

Using fastai to Train your Diffusion Model

We saw how little code is needed to train a conditional diffusion model from scratch, but with a library such as fastai it can be reduced even further! I encourage you to follow "Deep Learning for Coders Part 2," which is going on right now, and learn more about generative models with Jeremy.
In the meantime: the same CIFAR code I showed you before is just a callback in fastai's language. Check the fastdiffusion repository, where a bunch of fastai devs are adding their implementations of diffusion models.
To train on CIFAR-10, you only need to pass this callback to an image-to-image fastai pipeline:
class ConditionalDDPMCallback(Callback):
    def __init__(self, n_steps, beta_min, beta_max, tensor_type=TensorImage):
        store_attr()

    def before_fit(self):
        self.beta = torch.linspace(self.beta_min, self.beta_max, self.n_steps).to(self.dls.device)  # variance schedule, linearly increased with timestep
        self.alpha = 1. - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.sigma = torch.sqrt(self.beta)

    def before_batch_training(self):
        x0 = self.xb[0]  # original images, x_0
        eps = self.tensor_type(torch.randn(x0.shape, device=x0.device))  # noise, x_T
        batch_size = x0.shape[0]
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)  # select random timesteps
        alpha_bar_t = self.alpha_bar[t].reshape(-1, 1, 1, 1)
        xt = torch.sqrt(alpha_bar_t)*x0 + torch.sqrt(1-alpha_bar_t)*eps  # noisify the image
        self.learn.xb = (xt, t, self.yb[0])  # input to our model is the noisy image, the timestep and the label
        self.learn.yb = (eps,)  # ground truth is the noise

    def before_batch_sampling(self):
        xt = self.tensor_type(self.xb[0])  # a full batch at once!
        batch_size = xt.shape[0]
        label = torch.arange(10, dtype=torch.long, device=xt.device).repeat(batch_size//10 + 1).flatten()[0:batch_size]
        for t in progress_bar(reversed(range(self.n_steps)), total=self.n_steps, leave=False):
            t_batch = torch.full((batch_size,), t, device=xt.device, dtype=torch.long)
            z = torch.randn(xt.shape, device=xt.device) if t > 0 else torch.zeros(xt.shape, device=xt.device)
            alpha_t = self.alpha[t]  # get noise level at current timestep
            alpha_bar_t = self.alpha_bar[t]
            sigma_t = self.sigma[t]
            xt = 1/torch.sqrt(alpha_t) * (xt - (1-alpha_t)/torch.sqrt(1-alpha_bar_t) * self.model(xt, t_batch, label=label)) + sigma_t*z  # predict x_(t-1) following Algorithm 2 in the paper
        self.learn.pred = (xt,)
        raise CancelBatchException

    def before_batch(self):
        if not hasattr(self, 'gather_preds'): self.before_batch_training()
        else: self.before_batch_sampling()
and then pass this callback to the corresponding Learner:
model = ConditionalUnet(dim=32, channels=1, num_classes=10).cuda()
ddpm_learner = Learner(dls, model,
                       cbs=[ConditionalDDPMCallback(n_steps=1000, beta_min=0.0001, beta_max=0.02, tensor_type=TensorImageBW)],
                       loss_func=nn.MSELoss()).to_fp16()
As fastai supports W&B out of the box, it is as simple as passing the WandbCallback:
with wandb.init(project="sd_from_scratch"):
    ddpm_learner.fit_one_cycle(10, 1e-4, cbs=WandbCallback(log_preds=False))
For a more comprehensive example, check this notebook.

Conclusions

I mostly did this to help myself understand diffusion models from scratch. I was never very fond of GANs, as they are complex and need a lot of engineering tricks to train, so when this novel technique for generating images with the old trusty UNet came to conquer the open-source world, I knew I had to try it!
This new and powerful tool opens the door to a new way of generating labeled data. It could become a new way of pretraining or augmenting datasets to get more powerful pre-trained models. I have a couple of ideas on how to test this hypothesis, so stay tuned.