
Next Frame Prediction Using Diffusion: The fastai Approach

In this article, we look at how to use diffusion models to predict the next frame in a sequence of images, iterating quickly on the MovingMNIST dataset.
In this article, you'll learn how to create the MovingMNIST dataset and build a diffusion model in PyTorch to predict the next frame in a sequence of moving digits.

This step-by-step guide – inspired by Jeremy Howard's method in the latest fastai course – will show you how to train a diffusion model in an autoregressive manner with very little code.
Each sequence is an item in the dataset.
Show me the code
💡
Here's what we'll be covering:

Table of Contents
  • Motivation: The fastai Course Part 2
  • From Clouds to MovingMNIST
  • Creating MovingMNIST From Scratch
  • Training a Diffusion Model On MovingMNIST
  • Conclusions

Let's get going.

Motivation: The fastai Course Part 2

I've been following the fastai part 2 course, where Jeremy has shown us how to train diffusion models from scratch. The course focuses on iterating fast, so training a big model like Stable Diffusion on LAION or other humongous datasets is not an option. Instead, Jeremy builds intuition by training smaller diffusion models on the FashionMNIST dataset.
This approach of iterating fast is the core of the fastai experience, and I have adopted it in my research. What's really interesting about the course is that all the code is built on top of miniai, a very thin wrapper around PyTorch that Jeremy writes from scratch during the lessons. In a previous version of the course, he built the actual fastai v2, but this time the library is simpler and more accessible, so hacking it to train a custom model is something I wanted to try.
And if you want to learn how to train a Denoising Diffusion Probabilistic Model (DDPM), you can check out my previous article.

From Clouds to MovingMNIST

At my previous job, I worked on short-term forecasting for solar energy. Without getting into too many details, we tried to use algorithms to predict cloud movement to forecast when the sun would be behind a cloud. The task is basically next-frame prediction or advection/movement prediction.


This task is very challenging and has not yet been solved completely. So when I wanted to try a new paper/algorithm that looked promising, I had to find a smaller dataset that was easy to understand visually.
Enter MovingMNIST. This dataset is exactly what you think it is: moving digits from the MNIST dataset.
As a bonus, you can construct this dataset dynamically! Doing so gives you an almost infinite amount of data to train your new shiny model 😎.

Creating MovingMNIST From Scratch

This dataset consists of digits moving over a canvas. To build it, we will:
  • Grab the dataset from torchvision.datasets
  • Learn how to move one single digit on the canvas following some random trajectory
  • Aggregate multiple digits on the canvas and format everything into a PyTorch dataset.
  • Discuss how to make this happen quickly
There are multiple ways of doing this. Initially, I used a hard-coded approach, where the trajectories and the overlapping of digits had to be computed by hand directly on the arrays. We'll take a different and more general approach, as we want a more diverse dataset where digits can scale, rotate, shear, and translate. If you remember your linear algebra course, all of these are affine transforms, and an affine transform can be represented by a matrix and a vector 🤓, so we can compute them all at once!
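To make that claim concrete, here is a minimal sketch (purely illustrative, not part of the dataset code) showing how a rotation, a scaling, and a shear collapse into a single matrix A, with the translation as a vector b, so every pixel coordinate p maps to A @ p + b:

import math
import torch

angle = math.radians(20)          # rotation
scale = 1.3                       # isotropic scaling
shear = math.radians(15)          # shear along the x-axis
b = torch.tensor([2., 3.])        # translation in pixels

R = torch.tensor([[math.cos(angle), -math.sin(angle)],
                  [math.sin(angle),  math.cos(angle)]])
S = scale * torch.eye(2)
H = torch.tensor([[1., math.tan(shear)],
                  [0., 1.]])

A = R @ S @ H                     # rotation + scale + shear composed into one matrix
p = torch.tensor([10., 5.])       # an example pixel coordinate
p_new = A @ p + b                 # the whole affine transform in one go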
The complete code to create MovingMNIST samples is here
💡

Affine Transforms in PyTorch

As our dataset will work with PyTorch models, let's use the available built-in functionality for affine transformations! 
Below, we'll apply an affine transformation on the image keeping the image center invariant. If the image is a torch Tensor, it is expected to have […, H, W] shape, where … means an arbitrary number of leading dimensions. (For more detail, see the docs).
import torchvision.transforms.functional as TF  # I know, I know...

TF.affine(
    input_image,       # the image to apply the transform to
    angle=20,          # rotation in degrees
    scale=1.3,         # scaling factor (1.0 = no scaling)
    translate=(2, 3),  # translation in pixels (x, y)
    shear=15,          # shear angle in degrees
)
This is a deterministic transform. You will obtain the same result each time.
Applying the transform above to the number 5.

Creating a Trajectory

Now we need to create a sequence of incrementally moving digits. Thankfully, that's as simple as applying the same transform multiple times 😎
from functools import partial

tf = partial(TF.affine, angle=angle, translate=translate, scale=scale, shear=shear)

def apply_n_times(tf, x, n=1):
    "Apply `tf` to `x` `n` times, return all values"
    sequence = [x]
    for i in range(n):
        sequence.append(tf(sequence[i]))
    return sequence
Here f^5 represents f applied five times to the input image.

Random Trajectories

Now we need to create varied trajectories so we get a diverse dataset. My solution was to sample different transform parameters each time I grabbed a new digit.
import random
from types import SimpleNamespace

affine_params = SimpleNamespace(
    angle=(-4, 4),
    translate=((-5, 5), (-5, 5)),
    scale=(.8, 1.2),
    shear=(-3, 3),
)

# then we sample in those intervals
angle = random.uniform(*affine_params.angle)
translate = (random.uniform(*affine_params.translate[0]),
             random.uniform(*affine_params.translate[1]))
scale = random.uniform(*affine_params.scale)
shear = random.uniform(*affine_params.shear)

tf = partial(TF.affine, angle=angle, translate=translate, scale=scale, shear=shear)
The parameter ranges are small because the transforms compound: a 4-degree rotation applied 5 times is equivalent to a single 20-degree rotation. So if you want the digit to stay on the canvas, it's a good idea to keep them bounded.
Jupyter is excellent for adjusting the parameters, as you can visualize the output of the transform in real time by running the cell repeatedly.
Hitting Ctrl+Return repeatedly to re-run the same cell.
We put this new trajectory inside a class:
class RandomTrajectory:
    def __init__(self, affine_params, n=5, **kwargs):
        self.angle = random.uniform(*affine_params.angle)
        self.translate = (random.uniform(*affine_params.translate[0]),
                          random.uniform(*affine_params.translate[1]))
        self.scale = random.uniform(*affine_params.scale)
        self.shear = random.uniform(*affine_params.shear)
        self.n = n
        self.tf = partial(TF.affine,
                          angle=self.angle,
                          translate=self.translate,
                          scale=self.scale,
                          shear=self.shear, **kwargs)

    def __call__(self, img):
        return apply_n_times(self.tf, img, n=self.n)
And we can now create random trajectories using this object!
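For example, a quick usage sketch (assuming digit is a padded [1, 64, 64] digit tensor like the ones we build below, and that torch is imported):

traj = RandomTrajectory(affine_params, n=4)  # 4 transform steps -> 5 frames in total
frames = traj(digit)                         # list of 5 tensors, each [1, 64, 64]
frames = torch.stack(frames)                 # [5, 1, 64, 64], ready to plot or feed downstream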

Putting Everything Together

We refactor everything into a single class, and we're ready to use it to feed our neural network:
import torch
import torchvision.transforms as T
from torchvision.datasets import MNIST

class MovingMNIST:
    def __init__(self, path=".",                     # path to store the MNIST dataset
                 affine_params: dict=affine_params,  # affine transform parameters, refer to torchvision.transforms.functional.affine
                 num_digits: list[int]=[1, 2],       # how many digits to move, random choice between the values provided
                 num_frames: int=4,                  # how many frames to create
                 img_size=64,                        # the canvas size; the actual digits are always 28x28
                 concat=True,                        # if True, concat the frames into one (num_frames, 1, img_size, img_size) tensor, else return a list of frames
                 normalize=False                     # scale images to [0,1] and normalize them with MNIST stats. Applied at batch level. Careful: the padded canvas messes up the stats!
                 ):
        self.mnist = MNIST(path, download=True).data
        self.affine_params = affine_params
        self.num_digits = num_digits
        self.num_frames = num_frames
        self.img_size = img_size
        self.pad = padding(img_size)  # how much to pad the 28x28 digits to fill the canvas (helper defined in the linked code)
        self.concat = concat
        # we could add normalization here 👇
        self.batch_tfms = [T.ConvertImageDtype(torch.float32)]

    def random_place(self, img):
        "Randomly place the digit inside the canvas"
        x = random.uniform(-self.pad, self.pad)
        y = random.uniform(-self.pad, self.pad)
        return TF.affine(img, translate=(x, y), angle=0, scale=1, shear=(0, 0))

    def random_digit(self):
        "Get a random MNIST digit randomly placed on the canvas"
        img = self.mnist[[random.randrange(0, len(self.mnist))]]
        pimg = TF.pad(img, padding=self.pad)
        return self.random_place(pimg)

    def _one_moving_digit(self):
        digit = self.random_digit()
        traj = RandomTrajectory(self.affine_params, n=self.num_frames - 1)
        return torch.stack(traj(digit))

    def __getitem__(self, i):
        moving_digits = [self._one_moving_digit() for _ in range(random.choice(self.num_digits))]
        moving_digits = torch.stack(moving_digits)
        combined_digits = moving_digits.max(dim=0)[0]  # overlap the digits by taking the pixel-wise max
        return combined_digits if self.concat else [t.squeeze(dim=0) for t in combined_digits.split(1)]

    def get_batch(self, bs=32):
        "Grab a batch of data"
        batch = torch.stack([self[0] for _ in range(bs)])
        for tfm in self.batch_tfms:  # batch_tfms is a list, so apply each transform in turn
            batch = tfm(batch)
        return batch
Some things you'll notice:
  • The items are randomly sampled each time, so when you call __getitem__ you are not indexing into the dataset but sampling new random digits and trajectories.
  • This dataset has no __len__, as we can generate as many items as we want.
  • We can trivially create a DataLoader from this infinite source of samples by just sampling bs items and stacking them together (see the sketch below).
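A minimal sketch of that last point, assuming the MovingMNIST class above is in scope; wrapping it in an IterableDataset is just one way to plug an infinite sampler into a standard PyTorch DataLoader:

from torch.utils.data import DataLoader, IterableDataset

class MovingMNISTStream(IterableDataset):
    "Yield MovingMNIST samples forever"
    def __init__(self, ds):
        self.ds = ds
    def __iter__(self):
        while True:
            yield self.ds[0]  # the index is ignored: every call returns a fresh random sample

ds = MovingMNIST(num_frames=4)
dl = DataLoader(MovingMNISTStream(ds), batch_size=32)
batch = next(iter(dl))  # [32, 4, 1, 64, 64]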

Usability

This dataset, as it is, is usable but pretty slow. Thankfully, there's a simple solution: generate, say, 100K samples and dump them to disk. Then you load the saved tensor and iterate over that instead of generating samples on the fly.
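A minimal sketch of that idea (the file name and sample count are arbitrary):

import torch
from torch.utils.data import DataLoader

ds = MovingMNIST(num_frames=4)

# generate once and dump to disk (slow, but you only pay the cost once)
data = torch.stack([ds[0] for _ in range(100_000)])  # [100000, 4, 1, 64, 64], uint8
torch.save(data, "moving_mnist_100k.pt")

# ...later, load it back and iterate over the pre-computed tensor instead
data = torch.load("moving_mnist_100k.pt").float() / 255.0
train_dl = DataLoader(data, batch_size=128, shuffle=True)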



The complete code to create MovingMNIST samples is here
💡

Training a Diffusion Model On MovingMNIST



We have the data. Now we need a model.
We will use the same model we used in my previous article. But instead of feeding one image at a time, we'll feed a stack of frames.
But how can a model built to process one image at a time be used to process multiple frames? There are multiple solutions to this issue; we will opt for the simplest one that works well.
The neural network we'll train is a UNet with self-attention layers between blocks, the same network we used to generate CIFAR-10 images and Fonts in the previous article.
This network can process images with an arbitrary number of channels (RGB, grayscale, or more!). We'll exploit this feature to feed the network an image with num_frames channels.
We stack the images into a fat tensor that has as many channels as frames, in this case 5.
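In other words, from the network's point of view the frame dimension simply plays the role of the channel dimension. A toy sketch (not the article's UNet, just a plain convolution) to make the idea concrete:

import torch
import torch.nn as nn

frames = torch.randn(8, 5, 64, 64)  # a batch of sequences: 5 frames treated as 5 channels
conv = nn.Conv2d(in_channels=5, out_channels=16, kernel_size=3, padding=1)
out = conv(frames)                  # works exactly like a 5-channel image: [8, 16, 64, 64]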

MiniAI for DDPM

Without spoiling the course (which I encourage you to take ASAP!), miniai is a thin wrapper around PyTorch. This means it gives you some building blocks to create a customizable training loop and hook in your optimizers, metrics, and callbacks.
To train a DDPM model, you train a UNet neural network to denoise images. If you do this sufficiently long, your model can generate new images from pure noise. We want exactly that for our next frame.
We'll need to create a custom noisify function to only add noise to the last frame – the one we want to predict. For each sequence of frames, we will split the batch into two parts, the past frames (all but the last) and the last frame. We'll randomly add noise following the DDPM noise schedule.
We are going to use the same noisify function as the one from the fastai course, but we will apply the noise only to the last frame of the sequence.
NOTE: We are using the same defaults as the course notebook.
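The constants n_steps and alphabar used below come from the standard DDPM linear noise schedule. In case you want the snippet to be self-contained, here is a minimal sketch with the usual defaults (β from 0.0001 to 0.02 over 1,000 steps):

import torch

n_steps = 1000
beta = torch.linspace(0.0001, 0.02, n_steps)  # linear beta schedule
alpha = 1. - beta
alphabar = alpha.cumprod(dim=0)               # ᾱ_t, the cumulative product of (1 - β)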
def noisify_ddpm(x0):
    "Noise by DDPM"
    device = x0.device
    n = len(x0)
    t = torch.randint(0, n_steps, (n,), dtype=torch.long)
    ε = torch.randn(x0.shape, device=device)
    ᾱ_t = alphabar[t].reshape(-1, 1, 1, 1).to(device)
    xt = ᾱ_t.sqrt()*x0 + (1-ᾱ_t).sqrt()*ε
    return xt, t.to(device), ε

def noisify_last_frame(frames, noise_func):
    "Noisify the last frame of a sequence"
    past_frames = frames[:, :-1]
    last_frame = frames[:, -1:]
    noise, t, e = noise_func(last_frame)
    return torch.cat([past_frames, noise], dim=1), t, e
If we consider sequences of four frames and we look at a batch of size 2, we get the following:
We plot the time step used to noise the last frame (values in [0,1000])
Constructing a DataLoader is as simple as modifying the collate function to apply the noisify functions defined above:
from torch.utils.data import DataLoader, default_collate

def noisify_collate(noise_func):
    def _inner(b):
        "Collate function that noisifies the last frame"
        return noisify_last_frame(default_collate(b), noise_func)
    return _inner

class NoisifyDataloader(DataLoader):
    """Noisify the last frame of a dataloader by applying a noise function,
    after collating the batch"""
    def __init__(self, dataset, *args, noise_func=noisify_ddpm, **kwargs):
        super().__init__(dataset, *args, collate_fn=noisify_collate(noise_func), **kwargs)
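To give a sense of how this plugs in, here is a sketch under a few assumptions: the pre-generated tensor from the Usability section, miniai's DataLoaders container (which simply holds a train and a valid dataloader), and arbitrary batch sizes; the validation slice is illustrative only.

from miniai.datasets import DataLoaders  # miniai's simple train/valid container

data = torch.load("moving_mnist_100k.pt").float() / 255.0
train_data, valid_data = data[:-1000], data[-1000:]  # hold out a small slice for logging

dls = DataLoaders(
    NoisifyDataloader(train_data, batch_size=128, shuffle=True, noise_func=noisify_ddpm),
    NoisifyDataloader(valid_data, batch_size=128, noise_func=noisify_ddpm),
)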
And then we create the Learner object in miniai:
import torch.nn as nn
from torch import optim
from torch.optim import lr_scheduler
from diffusers import UNet2DModel

# optimization params (config holds the hyperparameters, e.g. a wandb.config or SimpleNamespace)
tmax = config.epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=config.lr, total_steps=tmax)
opt_func = partial(optim.Adam, eps=1e-5)

# create the model
model = UNet2DModel(in_channels=4, out_channels=1, block_out_channels=(16, 32, 64, 128), norm_num_groups=8)
init_ddpm(model)  # the DDPM-friendly initialization from the course notebooks
cbs = [DDPMCB2(), ProgressCB(plot=True), MetricsCB(),
       BatchSchedCB(sched), AccelerateCB(n_inp=2)]
learn = Learner(model, dls, nn.MSELoss(reduction="sum"),
                lr=config.lr, cbs=cbs, opt_func=opt_func)
We'll modify the training loop by passing a list of very thin and modular callbacks. I won't go into all the details, as Jeremy explains them way better than I do, but the only callback that is relevant to us here is DDPMCB2:
class DDPMCB2(Callback):
    "Get model samples from the output object"
    def after_predict(self, learn): learn.preds = learn.preds.sample

Training and monitoring the model

Of course, we need to monitor how our models are training. To do so, we will log the metrics to Weights & Biases and sample model predictions regularly.
I created two callbacks to do this: one to monitor training and the other to sample from the model. The monitoring callback is very hacky, as miniai is still under active development. It overrides ProgressCB, which itself overrides MetricsCB, but Johno gave me an idea on how to streamline it.
import wandb

def to_wandb_image(img):
    "Stack the images horizontally"
    return wandb.Image(torch.cat(img.split(1), dim=-1).cpu().numpy())

def log_images(model, xt):
    "Sample and log images to W&B"
    samples = ddim_sampler()(model, xt)  # ddim_sampler comes from the course notebooks
    frames = torch.cat([xt.to(samples[-1].device), samples], dim=1)
    wandb.log({"sampled_images": [to_wandb_image(img) for img in frames]})

class LogPreds(Callback):
    "Log samples to W&B"
    def __init__(self, n_preds=10, log_every=1):
        self.n_preds = n_preds
        self.log_every = log_every

    def before_fit(self, learn):
        dt = learn.dls.valid
        xt, t, ε = next(iter(dt))
        self.xt = xt[:self.n_preds, :-1, ...]

    def after_epoch(self, learn):
        if not learn.training and (learn.epoch % self.log_every == 0):
            log_images(learn.model, self.xt)



This callback will log the same batch of images every log_every epochs. As the model trains, the generated frames become less noisy. You will see the metrics and the logged images in the Weights & Biases workspace.


As the model is autoregressive, we can sample as far as we want into the future. Here we generate five future frames by feeding the model the last four frames iteratively.
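A rough sketch of that rollout logic; sample_next_frame is a hypothetical helper standing in for the DDIM sampling call used above, returning one denoised frame of shape [bs, 1, H, W]:

def rollout(model, past_frames, sample_next_frame, n_future=5):
    "Autoregressively predict n_future frames, feeding each prediction back in"
    n_past = past_frames.shape[1]                       # e.g. 4 conditioning frames as channels
    frames = past_frames                                # [bs, n_past, H, W]
    preds = []
    for _ in range(n_future):
        context = frames[:, -n_past:]                   # keep only the most recent frames
        new_frame = sample_next_frame(model, context)   # [bs, 1, H, W]
        frames = torch.cat([frames, new_frame], dim=1)  # append the prediction to the history
        preds.append(new_frame)
    return torch.cat(preds, dim=1)                      # [bs, n_future, H, W]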

The trained model is not perfect, and the images degrade over time. We should probably train longer and generate more data, but this is an impressive starting point, showing that diffusion models can handle next-frame prediction remarkably well!
Most of the code I used to train a DDPM model for next-frame prediction comes from this notebook of the latest version of the course.

Conclusions

Since the release of Stable Diffusion, I have been interested in applying the generative capabilities of these models to real-life problems. But generating future video frames is a complex problem. Recent papers such as MCVD: Masked Conditional Video Diffusion for Prediction, Generation, and Interpolation push diffusion models even further and produce insanely good results on datasets far more complex than MovingMNIST. But as you can see in this article, you can get very far with a simple DDPM model trained for just a couple of hours!
My colleague Soumik and I are currently working on a real weather dataset to predict cloud movement! Come to GTC to see what we've cooked up!
