Next Frame Prediction Using Diffusion: The fastai Approach
In this article, we look at how to use diffusion models to predict the next frame in a sequence of images, iterating fast on the MovingMNIST dataset.
In this article, you'll learn how to create the MovingMNIST dataset and build a diffusion model in PyTorch to predict the next frame in a sequence of moving digits.
This step-by-step guide – inspired by Jeremy Howard's method in the latest fastai course – will show you how to train a diffusion model in an autoregressive manner with very little code.

Each sequence is an item in the dataset.
Here's what we'll be covering:
Table of Contents
- Motivation: The fastai Course Part 2
- From Clouds to MovingMNIST
- Creating MovingMNIST From Scratch
- Training a Diffusion Model On MovingMNIST
- MiniAI for DDPM
- Training and monitoring the model
- Conclusions
Let's get going.
Motivation: The fastai Course Part 2
I've been following the fastai part 2 course, where Jeremy has shown us how to train diffusion models from scratch. The course focuses on iterating fast, so training a big model like Stable Diffusion on LAION or other humongous datasets is not an option. Instead, Jeremy builds intuition by training smaller diffusion models on the FashionMNIST dataset.
This approach of iterating fast is the core of the fastai experience, and I have adopted it in my own research. What's really interesting about the course is that all the code is built on top of a very thin wrapper around PyTorch called miniai, which Jeremy builds from scratch during the lessons. In a previous version of the course, he built the actual fastai v2, but this time the library is simpler and more accessible, so hacking it to train a custom model was something I wanted to try.
And if you want to see how to train a Denoising Diffusion Probabilistic Model (DDPM), you can check out my previous article.
From Clouds to MovingMNIST
At my previous job, I worked on short-term forecasting for solar energy. Without getting into too many details, we tried to use algorithms to predict cloud movement to forecast when the sun would be behind a cloud. The task is basically next-frame prediction or advection/movement prediction.
This task is very challenging and has not yet been solved completely. So when I wanted to try a new paper/algorithm that looked promising, I had to find a smaller dataset that was easy to understand visually.
Enter MovingMNIST. This dataset is basically what you think it is: moving digits from the MNIST dataset.
As a bonus, you can construct this dataset dynamically! Doing so gives you an almost infinite amount of data to train your new shiny model 😎.
Creating MovingMNIST From Scratch
This dataset consists of digits moving over a canvas. To build it, we will:
- Grab the dataset from torchvision.datasets
- Learn how to move one single digit on the canvas following some random trajectory
- Aggregate multiple digits on the canvas and format everything into a PyTorch dataset.
- Discuss how to make this happen quickly
There are multiple ways of doing this. Initially, I used a hard-coded approach, where the trajectories and the overlapping of digits had to be computed by hand directly on the arrays. We'll take a different and more general approach, as we want a more diverse dataset where digits can expand, rotate, shear, and translate. If you remember your Linear Algebra course, all of these are affine transforms; each one can be represented by a matrix and a vector 🤓, so we can compute them all at once!
Affine Transforms in PyTorch
As our dataset will work with PyTorch models, let's use the available built-in functionality for affine transformations!
Below, we'll apply an affine transformation to the image, keeping the image center invariant. If the image is a torch Tensor, it is expected to have […, H, W] shape, where … means an arbitrary number of leading dimensions. (For more detail, see the docs).
import torchvision.transforms.functional as TF  # I know, I know...

TF.affine(input_image,       # the image to apply the transform to
          angle=20,          # rotation in degrees
          scale=1.3,         # scaling factor (1.0 = no scaling)
          translate=(2, 3),  # translation in pixels (x, y)
          shear=15)          # shear angle in degrees
This is a deterministic transform. You will obtain the same result each time.

Applying the transform above to the number 5.
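As a quick sanity check (input_image here is the digit tensor used above), running the same call twice produces identical outputs:

import torch

out1 = TF.affine(input_image, angle=20, scale=1.3, translate=(2, 3), shear=15)
out2 = TF.affine(input_image, angle=20, scale=1.3, translate=(2, 3), shear=15)
assert torch.equal(out1, out2)  # same parameters, same result every time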
Creating a Trajectory
Now we need to create a sequence of incrementally moving digits. Thankfully, that's as simple as applying the same transform multiple times 😎
from functools import partial

tf = partial(TF.affine, angle=angle, translate=translate, scale=scale, shear=shear)

def apply_n_times(tf, x, n=1):
    "Apply `tf` to `x` `n` times, return all values"
    sequence = [x]
    for i in range(n):
        sequence.append(tf(sequence[i]))
    return sequence

Here f^5 represents f applied five times to the input image.
Random Trajectories
Now we need to create varied trajectories, so we get a diverse dataset. My solution was sampling different parameters for the transform each time I grabbed a different digit.
import random
from types import SimpleNamespace

affine_params = SimpleNamespace(
    angle=(-4, 4),
    translate=((-5, 5), (-5, 5)),
    scale=(.8, 1.2),
    shear=(-3, 3),
)

# then we sample in those intervals
angle = random.uniform(*affine_params.angle)
translate = (random.uniform(*affine_params.translate[0]),
             random.uniform(*affine_params.translate[1]))
scale = random.uniform(*affine_params.scale)
shear = random.uniform(*affine_params.shear)

tf = partial(TF.affine, angle=angle, translate=translate, scale=scale, shear=shear)
The parameter ranges are small because a rotation of 4 degrees applied 5 times is equivalent to a single 20-degree rotation, so if you want the digit to stay on the canvas, it is a good idea to keep them bounded.
Jupyter is excellent for adjusting the parameters, as you can visualize the output of the transform in real time by running the cell repeatedly.

Hitting Ctrl+Return repeatedly to re-run the same cell.
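If you want to do the same in your own notebook, here is a minimal plotting cell, assuming the tf partial and the input_image tensor from the snippets above; re-sample the transform parameters in the same cell and every run gives you a new trajectory.

import matplotlib.pyplot as plt

seq = apply_n_times(tf, input_image, n=5)  # re-run the cell (Ctrl+Return) after re-sampling `tf`
fig, axes = plt.subplots(1, len(seq), figsize=(2 * len(seq), 2))
for ax, frame in zip(axes, seq):
    ax.imshow(frame.squeeze(), cmap="gray")
    ax.axis("off")
plt.show()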
We put this new trajectory inside a class:
class RandomTrajectory:
    def __init__(self, affine_params, n=5, **kwargs):
        self.angle = random.uniform(*affine_params.angle)
        self.translate = (random.uniform(*affine_params.translate[0]),
                          random.uniform(*affine_params.translate[1]))
        self.scale = random.uniform(*affine_params.scale)
        self.shear = random.uniform(*affine_params.shear)
        self.n = n
        self.tf = partial(TF.affine,
                          angle=self.angle,
                          translate=self.translate,
                          scale=self.scale,
                          shear=self.shear, **kwargs)

    def __call__(self, img):
        return apply_n_times(self.tf, img, n=self.n)
And we can now create random trajectories using this object!
Putting Everything Together
We refactor everything into a single class, and we're ready to use it to feed our neural network:
import torch
import torchvision.transforms as T
from torchvision.datasets import MNIST

class MovingMNIST:
    def __init__(self,
                 path=".",                           # path to store the MNIST dataset
                 affine_params: dict=affine_params,  # affine transform parameters, refer to torchvision.transforms.functional.affine
                 num_digits: list[int]=[1, 2],       # how many digits to move, random choice between the values provided
                 num_frames: int=4,                  # how many frames to create
                 img_size=64,                        # the canvas size, the actual digits are always 28x28
                 concat=True,                        # if True, concat the final result (frames, 1, img_size, img_size), else return a list of frames
                 normalize=False                     # scale images in [0,1] and normalize them with MNIST stats. Applied at batch level. Have to take care of the canvas size that messes up the stats!
                 ):
        self.mnist = MNIST(path, download=True).data
        self.affine_params = affine_params
        self.num_digits = num_digits
        self.num_frames = num_frames
        self.img_size = img_size
        self.pad = padding(img_size)  # helper (not shown) that computes the padding from 28x28 up to img_size
        self.concat = concat
        # we could add normalization here 👇
        self.batch_tfms = T.Compose([T.ConvertImageDtype(torch.float32)])

    def random_place(self, img):
        "Randomly place the digit inside the canvas"
        x = random.uniform(-self.pad, self.pad)
        y = random.uniform(-self.pad, self.pad)
        return TF.affine(img, translate=(x, y), angle=0, scale=1, shear=(0, 0))

    def random_digit(self):
        "Get a random MNIST digit randomly placed on the canvas"
        img = self.mnist[[random.randrange(0, len(self.mnist))]]
        pimg = TF.pad(img, padding=self.pad)
        return self.random_place(pimg)

    def _one_moving_digit(self):
        digit = self.random_digit()
        traj = RandomTrajectory(self.affine_params, n=self.num_frames - 1)
        return torch.stack(traj(digit))

    def __getitem__(self, i):
        moving_digits = [self._one_moving_digit() for _ in range(random.choice(self.num_digits))]
        moving_digits = torch.stack(moving_digits)
        combined_digits = moving_digits.max(dim=0)[0]  # overlay the digits by taking the pixel-wise max
        return combined_digits if self.concat else [t.squeeze(dim=0) for t in combined_digits.split(1)]

    def get_batch(self, bs=32):
        "Grab a batch of data"
        batch = torch.stack([self[0] for _ in range(bs)])
        return self.batch_tfms(batch) if self.batch_tfms is not None else batch
Some things you'll notice:
- The items are randomly sampled each time, so when you call __getitem__ you are not indexing into the dataset but sampling new random digits and trajectories.
- This dataset has no __len__, as we can generate as many items as we want.
- We can trivially create a DataLoader from this infinite source of samples by just sampling bs items and stacking them together (see the sketch below).
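Here is a rough sketch of that last point; the infinite_batches helper and its parameters are mine, not part of the article's code.

import torch

def infinite_batches(ds, bs=32, n_batches=None):
    "Yield batches by stacking `bs` freshly sampled items from the dataset"
    i = 0
    while n_batches is None or i < n_batches:
        yield torch.stack([ds[0] for _ in range(bs)])  # __getitem__ ignores the index
        i += 1

# usage:
# ds = MovingMNIST(num_frames=4)
# for xb in infinite_batches(ds, bs=8, n_batches=2):
#     print(xb.shape)  # roughly (8, 4, 1, 64, 64) with the defaults above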
Usability
This dataset is usable as it is, but it's pretty slow. Thankfully, there's a simple solution: generate, say, 100K samples and dump them to disk. Then load the dataset back as a PyTorch tensor and iterate over that instead of generating samples on the fly.
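A minimal sketch of that idea, assuming the MovingMNIST class above (the file name and the smaller sample count are just for illustration):

import torch

ds = MovingMNIST(num_frames=4)
samples = torch.stack([ds[0] for _ in range(10_000)])  # bump this up to ~100K for real training
torch.save(samples, "moving_mnist_10k.pt")

# later: load once and index into the tensor instead of generating on the fly
data = torch.load("moving_mnist_10k.pt")  # remember to convert to float before training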
Training a Diffusion Model On MovingMNIST
We have the data. Now we need a model.
We will use the same model as in my previous article, but instead of feeding it one image at a time, we'll feed it a stack of frames.
But how can a model built to process one image at a time be used to process multiple frames? There are multiple solutions to this issue; we will opt for the simplest one that works well.
The neural network we'll train is a UNet with Self Attention layers between blocks, the same network we used to generate Cifar10 and Fonts in the previous article.
This network can process images with an arbitrary number of channels (RGB, grayscale, or more!). We'll exploit this feature to feed the network an image with num_frames channels.

We stack the images into a fat tensor that has as many channels as frames, in this case 5.
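To make this concrete, here is a tiny, illustrative sketch (the shapes follow the dataset above, with five frames as in the figure):

import torch

batch = torch.randn(8, 5, 1, 64, 64)  # (bs, num_frames, 1, H, W), as the dataset produces
stacked = batch.squeeze(2)            # (bs, num_frames, H, W): the frame axis now acts as channels
print(stacked.shape)                  # torch.Size([8, 5, 64, 64])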
MiniAI for DDPM
Without spoiling the course (which I encourage you to take ASAP!), miniai is a thin wrapper around PyTorch. This means it gives you the building blocks to create a customizable training loop and hook in your own optimizers, metrics, and callbacks.
To train a DDPM model, you train a UNet neural network to denoise images. If you do this for long enough, your model can generate new images from pure noise. That is exactly what we want for our next frame.
We'll need a custom noisify function that only adds noise to the last frame – the one we want to predict. For each sequence, we split the frames into two parts: the past frames (all but the last) and the last frame, to which we randomly add noise following the DDPM noise schedule.
We'll use the same noisify function as the fastai course, but apply the noise only to the last frame of the sequence.
NOTE: We are using the same defaults as the course notebook.
def noisify_ddpm(x0):
    "Noise by ddpm"
    device = x0.device
    n = len(x0)
    t = torch.randint(0, n_steps, (n,), dtype=torch.long)
    ε = torch.randn(x0.shape, device=device)
    ᾱ_t = alphabar[t].reshape(-1, 1, 1, 1).to(device)
    xt = ᾱ_t.sqrt()*x0 + (1-ᾱ_t).sqrt()*ε
    return xt, t.to(device), ε

def noisify_last_frame(frames, noise_func):
    "Noisify the last frame of a sequence"
    past_frames = frames[:, :-1]
    last_frame  = frames[:, -1:]
    noise, t, e = noise_func(last_frame)
    return torch.cat([past_frames, noise], dim=1), t, e
If we consider sequences of four frames and we look at a batch of size 2, we get the following:

We plot the time step used to noise the last frame (values in [0,1000])
Constructing a DataLoader is as simple as modifying the collate function to noise the inputs, as shown above:
from torch.utils.data import DataLoader, default_collate

def noisify_collate(noise_func):
    def _inner(b):
        "Collate function that noisifies the last frame"
        return noisify_last_frame(default_collate(b), noise_func)
    return _inner

class NoisifyDataloader(DataLoader):
    "Noisify the last frame of a dataloader by applying a noise function, after collating the batch"
    def __init__(self, dataset, *args, noise_func=noisify_ddpm, **kwargs):
        super().__init__(dataset, *args, collate_fn=noisify_collate(noise_func), **kwargs)
And then we create the Learner object in miniai:
from diffusers import UNet2DModel

# optimization params
tmax = config.epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=config.lr, total_steps=tmax)
opt_func = partial(optim.Adam, eps=1e-5)

# create the model
model = UNet2DModel(in_channels=4, out_channels=1,
                    block_out_channels=(16, 32, 64, 128), norm_num_groups=8)
init_ddpm(model)

cbs = [DDPMCB2(), ProgressCB(plot=True), MetricsCB(),
       BatchSchedCB(sched), AccelerateCB(n_inp=2)]

learn = Learner(model, dls, nn.MSELoss(reduction="sum"),
                lr=config.lr, cbs=cbs, opt_func=opt_func)
We'll modify the training loop by passing a list of very thin and modular callbacks. I won't go into all the details, as Jeremy explains them way better than I do, but the only callback relevant to us here is the DDPMCB2 callback:
class DDPMCB2(Callback):
    "Get model samples from the output object"
    def after_predict(self, learn): learn.preds = learn.preds.sample
Training and monitoring the model
Of course, we need to monitor how our models are training. To do so, we will log the metrics to Weights & Biases and sample model predictions regularly.
I created two callbacks to do this: one to monitor and the other to sample from the model. The monitoring callback is very hacky, as miniai is still under active development. It overrides the ProgressCB, which itself overrides the MetricsCB, but Johno gave me an idea on how to streamline mine.
import wandb

def to_wandb_image(img):
    "Stack the images horizontally"
    return wandb.Image(torch.cat(img.split(1), dim=-1).cpu().numpy())

def log_images(model, xt):
    "Sample and log images to W&B"
    samples = ddim_sampler()(model, xt)
    frames = torch.cat([xt.to(samples[-1].device), samples], dim=1)
    wandb.log({"sampled_images": [to_wandb_image(img) for img in frames]})

class LogPreds(Callback):
    "Log samples to W&B"
    def __init__(self, n_preds=10, log_every=1):
        self.n_preds = n_preds
        self.log_every = log_every

    def before_fit(self, learn):
        dt = learn.dls.valid
        xt, t, ε = next(iter(dt))
        self.xt = xt[:self.n_preds, :-1, ...]

    def after_epoch(self, learn):
        if not learn.training and (learn.epoch % self.log_every == 0):
            log_images(learn.model, self.xt)
This callback will log the same batch of images every log_every epochs. As the model trains, the generated frames become less noisy. You will see the metrics and the logged images in the Weights & Biases workspace.
As the model is autoregressive, we can sample as far as we want into the future. Here we generate five future frames by feeding the model the last four frames iteratively.
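Here is a minimal sketch of that rollout loop. The sample_next_frame helper is hypothetical: it stands in for the sampler used above and is assumed to return one generated frame of shape (bs, 1, H, W); n_cond should match the number of past frames the model was trained on.

import torch

@torch.no_grad()
def rollout(model, past, sample_next_frame, n_future=5, n_cond=3):
    "Autoregressively generate `n_future` frames from `past` of shape (bs, T, H, W)"
    frames = past
    for _ in range(n_future):
        new = sample_next_frame(model, frames[:, -n_cond:])  # condition on the most recent frames
        frames = torch.cat([frames, new], dim=1)             # append and slide the window forward
    return frames[:, past.shape[1]:]                         # keep only the generated frames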

The trained model is not perfect, and the images degrade over time. We should probably train longer and generate more data, but this is an impressive starting point, showing that diffusion models can handle next-frame prediction remarkably well!
Most of the code I used to train a DDPM model for next-frame prediction comes from this notebook of the latest version of the course.
Conclusions
Since the release of Stable Diffusion, I have been interested in applying the generative capabilities of these models to real-life problems. But generating future video frames is a complex problem. Recent papers such as MCVD: Masked Conditional Video Diffusion for Prediction, Generation, and Interpolation push diffusion models even further and produce insanely good results on datasets far more complex than MovingMNIST. But as you can see in this article, you can get very far with a simple DDPM model trained for just a couple of hours!
My colleague Soumik and I are currently working on a real weather dataset to predict cloud movement! Come to GTC to see what we have cooked up!
Related reports:
- Diffusion on the Clouds: Short-term solar energy forecasting with Diffusion Models – using diffusion models to predict cloud movement on satellite imagery to forecast solar energy production.
- How To Train a Conditional Diffusion Model From Scratch – how to train a conditional diffusion model, using W&B to log and track the experiments.