
How To Sample From a Generative AI Model During Training

In this article, we explore how to monitor your generative image model during training by sampling from it continuously, using W&B to track the results.
When training a generative model, the premise is that the more you train, the better the model gets. In our case, the model samples should improve in quality over time. How can we monitor this while training an image model?
In this short article, we'll explore how to monitor your generative image model during training by sampling from it continuously.
Here's what we'll cover:

Table of Contents

Getting Started
Save Your Samples Along With Your Training Metrics
Sampling with W&B Tables
There Is a Better Way
How Do We Want To Stack the Tables Together?
Conclusion

Getting Started

Ideally, you have already implemented a sampling routine that runs every X epochs or steps. Something that looks like this:
for epoch in range(EPOCHS):
    # train steps and shenanigans...
    # During validation 👇
    if epoch % 10 == 0:
        samples = model.generate_samples(num_samples=16)
        save_samples(samples, name=f"samples_{epoch}")
As your training progresses, your samples improve in quality, as shown in the image below:
A sprite generation model, from our deeplearning.ai course on generative AI

Save Your Samples Along With Your Training Metrics

The simplest way to do this is to log the samples during training as a media panel in Weights & Biases. As an added bonus, this gives you a nice slider!
if epoch % 4 == 0:
    samples = sample_ddpm_model(num_samples=30)
    # cast the samples to wandb.Image so they render in a media panel
    wandb.log({"train_samples": [wandb.Image(img) for img in samples]})


Drag the "Step" slider above to see how our generated sprites improved over time

Sampling with W&B Tables

Let's create a dummy pipeline that "trains" a model that generates images. We'll call this model bright-pikachu; it progressively adds brightness to a Pikachu image. On epoch 0 we add zero brightness, and on the last epoch we add 1 and recover the original image.
Making a Pikachu chair look bright, courtesy of Dall-E-mini
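Here's a minimal sketch of what such a dummy "model" could look like. The file path and function name are placeholders, not the exact code behind the panels in this report; the only behavior we rely on is that the brightness factor goes from 0 at epoch 0 to 1 at the final epoch.
from PIL import Image, ImageEnhance

pikachu = Image.open("pikachu.png")  # placeholder path to the source image

def generate_sample(epoch: int, total_epochs: int) -> Image.Image:
    # Brightness factor ramps from 0 (all black) at epoch 0
    # to 1 (the original image) at the last epoch.
    factor = epoch / (total_epochs - 1)
    return ImageEnhance.Brightness(pikachu).enhance(factor)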

What you'll need to do during training is log the same Table over and over again, exactly as you did with the media panel.
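A hedged sketch of that loop, building on the brightening function above (the project name, column names, and epoch count are assumptions for illustration):
import wandb

EPOCHS = 7  # one Table version logged per epoch

run = wandb.init(project="bright-pikachu")  # placeholder project name
for epoch in range(EPOCHS):
    table = wandb.Table(columns=["epoch", "image"])
    table.add_data(epoch, wandb.Image(generate_sample(epoch, EPOCHS)))
    # Logging under the same key every epoch creates a new version
    # of the same Table (v0, v1, ...) in the run's Artifacts.
    wandb.log({"table_pikachu": table})
run.finish()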



There Is a Better Way

What if we want to visualize all the Tables logged during training at the same time?
Simple! First, go to the Artifacts tab. Here you'll find all versions of the Table with the Pikachus; the workspace only shows you the latest one (v6).
We can visualize any individual table by clicking on it and opening the Files tab:


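If you'd rather pull a specific Table version programmatically instead of through the UI, something like the sketch below should work. The entity, project, and run id are placeholders, and the artifact name assumes W&B's run-table naming convention (run-<run_id>-<key>), so copy the exact name shown in your own Artifacts tab.
import wandb

api = wandb.Api()
# Placeholder artifact path; use the name shown in your Artifacts tab.
artifact = api.artifact("my-entity/bright-pikachu/run-abc123-table_pikachu:v3")
table = artifact.get("table_pikachu")  # returns the logged wandb.Table
print(table.columns)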
How Do We Want To Stack the Tables Together?

This will depend on your training type and what you want to visualize.


We can use a weave expression to grab all the Tables inside the Artifacts. By default, the concatenation happens vertically, by rows.
runs[0].loggedArtifactVersions.map((row, index) => row.file("table_pikachu.table.json"))



The type of concatenation can be changed by clicking the cogwheel, selecting Joining Rows instead, and using the key name.
Try it on the panel above ☝️



Conclusion

Having good telemetry during model training is very important for assessing model performance. If you are already logging a bunch of metrics, saving your checkpoints and optimizer states, and launching a big cluster full of GPUs, then storing your model samples in the same place is a good idea.
Also, during training the model is already loaded in GPU memory, so it is the best moment to sample from it! If this is not possible, you could set up an external process that grabs the latest checkpoint and samples from the model once new checkpoints become available; you could use W&B Automation to orchestrate this automatically.
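As a rough sketch of that external process, assuming your checkpoints are logged as a model artifact (the entity, project, and artifact names are placeholders, and the model-loading call is a stand-in for your own code):
import wandb

api = wandb.Api()
# Placeholder artifact path; point this at your own checkpoint artifact.
checkpoint = api.artifact("my-entity/my-project/model-weights:latest", type="model")
checkpoint_dir = checkpoint.download()

# Load your model from checkpoint_dir and sample from it (model-specific code).
model = load_model(checkpoint_dir)  # stand-in for your own loading code
samples = model.generate_samples(num_samples=16)

with wandb.init(project="my-project", job_type="sampling") as run:
    run.log({"offline_samples": [wandb.Image(img) for img in samples]})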
