
Mini-Sora on Flying MNIST

Created on April 15 | Last edited on October 8
Training Sora-like models appears to demand amounts of data and compute that are out of reach for most organizations. We want to make experimenting with training such models more accessible.
To scale down Sora training, we aim to train models that generate videos as if they were collected from a tiny world: flying hand-written digits that follow simple physics.
As a baseline, 250 A10 hours ($200 on Lambda Cloud) gives us a model that can generate an 8-second video of a toy 2D world (flying MNIST hand-written digits). The quality is decent enough to fool me for about 1 second. This gives me sufficient confidence to apply the same recipe to real-world videos.
Two training recipes are used to establish this baseline: GPT (adapted from nanoGPT) with a small custom spatio-temporal VAE, and a Diffusion Transformer (specifically ST-DiT) with the pre-trained Stable Diffusion VAE.
The good: without any labels of which digits appear in a video, or even any signal that these are videos of digits at all, the model was able to 1) generate localized patterns that look like hand-written digits, 2) simulate object permanence for 5+ seconds, and 3) simulate simple constant-velocity motion and bouncing at the boundary.
The bad: 1) temporal consistency is weak in GPT-generated videos, and digits tend to duplicate themselves, likely due to the limited context window size; 2) for STDiT, while the physics/movements look smoother and temporal consistency lasts longer, the digits are not stable, possibly because the weaker attention layers do not attend jointly over space and time.
The next focus is reducing cost from this baseline: aiming for a 5x to 10x reduction before running medium-sized experiments. If real-world videos require 100x more compute, we would need about 1,000 A10 days, or roughly 60 days on an 8xA100 node. But if we increase training speed by 10x, then one week would be enough to train a decent model on real data.
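A quick back-of-the-envelope check of these numbers (the A100-vs-A10 speedup below is implied by the 60-day figure above, not measured):

```python
# Back-of-the-envelope check of the compute estimates above.
baseline_a10_hours = 250          # toy-world baseline
scale_up = 100                    # assumed extra compute for real-world videos
a100_speedup_vs_a10 = 2.2         # implied by the 60-day figure; not measured

real_world_a10_hours = baseline_a10_hours * scale_up       # 25,000 A10 hours
print(real_world_a10_hours / 24)                            # ~1,042 A10 days

days_on_8xa100 = real_world_a10_hours / (8 * a100_speedup_vs_a10 * 24)
print(days_on_8xa100)                                       # ~59 days

# With a further 10x training-speed improvement:
print(days_on_8xa100 / 10)                                  # ~6 days, i.e. about a week
```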
There are 11 cost-reduction recipes in my backlog. Two of them (using my own implementations) failed, one is ongoing and looking promising, and seven are rather new research papers that require more experiments.
Here are the details of the baseline runs.

Baselines

  • GPT: 1.5B ~ 2.5B tokens, 100~200 A10 hours, 61M parameters, 6000-token context window, 110k unique video scenes, quality 3 out of 5 (digits look good, 8 frames of temporal consistency for motion, tend to duplicate digits)
  • STDiT (with SD VAE): 25B tokens, 250 A10 hours, 160M parameters, 32K-token context window, 110 unique video scenes, quality 3 out of 5 (32 frames of consistency for motion, digits are less recognizable and morph often)

Run set (2 runs)


Cost-reduction experiment ideas

The following are directions to improve the video generation quality and reduce training cost. I will add experiment runs to them.

1. Joint image+video training or fine-tune from a single-frame (image) generator

It appears the PixArt training pipeline reached decent image-generation quality in about 10 hours (roughly 1B tokens in that time).

Run set (1 run)

Using STDiT (from OpenSora), the 1-frame training run also looks good. Subsequent video training then resumes from this checkpoint.

Run set (3 runs)


2. Mamba

3. Masked transformer

4. Infini-Transformer

5. Auto-regressive next scale prediction

6. MagViT

7. Flow matching

8. Attention mask

9. HLB and llm.c GPT training recipe

Add an additional positional encoding for the 3D position (frame, height, width). It should hopefully improve quality at the same cost.
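A minimal sketch of one way to do this, assuming a factorized sinusoidal encoding over each token's (frame, row, column) position; the function names and channel split are illustrative, not the implementation used in the runs here.

```python
import torch

def sincos_1d(positions: torch.Tensor, dim: int) -> torch.Tensor:
    """Standard 1D sinusoidal encoding for integer positions, shape (L, dim)."""
    freqs = torch.exp(-torch.arange(0, dim, 2) * (torch.log(torch.tensor(10000.0)) / dim))
    angles = positions[:, None].float() * freqs[None, :]
    return torch.cat([angles.sin(), angles.cos()], dim=-1)

def positional_encoding_3d(t: int, h: int, w: int, dim: int) -> torch.Tensor:
    """Factorized 3D encoding: one third of the channels each for frame, row, column.

    Returns a (t*h*w, dim) tensor matching a flattened (frame, row, col) token order.
    Assumes dim is divisible by 6 so each axis gets an even number of channels.
    """
    d = dim // 3
    grid_t, grid_h, grid_w = torch.meshgrid(
        torch.arange(t), torch.arange(h), torch.arange(w), indexing="ij"
    )
    enc = torch.cat([
        sincos_1d(grid_t.reshape(-1), d),
        sincos_1d(grid_h.reshape(-1), d),
        sincos_1d(grid_w.reshape(-1), d),
    ], dim=-1)
    return enc  # add to token embeddings: x = tok_emb + positional_encoding_3d(...)
```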

10. Unclip/Kandinsky

This is an idea from Ethan. The basic premise is that diffusion models typically go:
text -> image/video
whereas unCLIP goes:
text -> image embedding -> decode to image
Training in stages like this is not too common these days, I think because there was never a huge benefit to it. But for long video it could be very powerful: by first modeling in token space, achieving context lengths similar to what language models reach becomes computationally very reasonable.
As shown in the graphic, image embeddings, even ones like CLIP that were not trained for reconstruction and have relatively small dimensions (1280, compared to the common 4096 for LLMs), manage to retain much of the important information about the images they encode.
Hence perhaps we can generate a very long video in embedding space, where the embeddings act as a coarse representation, the "skeleton" of the final video. The decoder can then fill in (hallucinate) the remaining details chunk by chunk. In other words, the embedding content should explain maybe 95% of the variance in the outputs, while the random seed and the biases of the decoder account for the rest.
The naive chunk-wise approach with temporal inpainting models typically does not perform very well; at best you get looping of the same content. This is likely because the model has not seen extended content, scene changes, and other things that happen over long time frames. But if the model learns to rely on the embeddings strongly enough, perhaps it is sufficient for only the embedding-generator model to have been trained on such long sequences.
My interpretation: train prior.unet and decoder_pipe.unet, while freezing prior.image_encoder, decoder_pipe.movq.encoder, and decoder_pipe.movq.decoder? During generation, use prior.unet to sample an image embedding per frame, then use decoder_pipe.unet to sample latents per frame conditioned on the image embedding, and finally use decoder_pipe.movq.decoder to map the latents back to pixel space.
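A minimal sketch of that interpretation, reusing the attribute names above; sample_prior and sample_decoder are hypothetical helpers standing in for the two diffusion sampling loops, not a real library API.

```python
import torch

def freeze_for_training(prior, decoder_pipe):
    """Train only the two denoisers; keep the pre-trained encoders/decoders frozen."""
    for module in (prior.image_encoder, decoder_pipe.movq.encoder, decoder_pipe.movq.decoder):
        module.requires_grad_(False)
    for module in (prior.unet, decoder_pipe.unet):
        module.requires_grad_(True)

@torch.no_grad()
def generate_video(prior, decoder_pipe, prompt_embedding, num_frames,
                   sample_prior, sample_decoder, chunk_size=16):
    """Generation as interpreted above: embeddings first, then chunk-wise decoding.

    sample_prior / sample_decoder are hypothetical wrappers around the prior and
    decoder diffusion sampling loops.
    """
    # 1) One image embedding per frame from the (long-context) prior.
    frame_embeddings = sample_prior(prior.unet, prompt_embedding, num_frames)  # (T, D)

    frames = []
    for start in range(0, num_frames, chunk_size):
        chunk = frame_embeddings[start:start + chunk_size]
        # 2) Latents for this chunk, conditioned on its frame embeddings.
        latents = sample_decoder(decoder_pipe.unet, image_embeds=chunk)
        # 3) The frozen MoVQ decoder maps latents back to pixels.
        frames.append(decoder_pipe.movq.decoder(latents))
    return torch.cat(frames, dim=0)
```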
Here are the prior model and the decoder, respectively:

11. Sparse context window

nanoGPT packs block_size (e.g., C=512) training examples into one sequence, where the context length grows from 1 to C and the model auto-regressively predicts the next token. This works on 1D sequences. For video data supported on a 3D grid (num_frames, height, width), I am interested in using the neighborhood structure to make the context window sparse: instead of each token attending to all previous tokens up to C steps back, it only attends to tokens that are close to the target token.
The model signature becomes: given $P^{(X)}_{L\times 3}$, $X_{L\times 1}$, $P^{(Y)}_{1\times 3}$, predict $Y_{1\times 1}$.
While the idea is intuitive, I have not yet made it work. The example packing no longer applies, so training becomes much less efficient, and the generation quality is worse.
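For reference, a minimal sketch of the sparse mask I have in mind, with illustrative window sizes; each token attends only to earlier tokens within a small spatio-temporal neighborhood.

```python
import torch

def sparse_neighborhood_mask(positions: torch.Tensor,
                             window: tuple = (4, 8, 8)) -> torch.Tensor:
    """Build an (L, L) boolean attention mask over tokens on a 3D grid.

    positions: (L, 3) integer tensor of (frame, row, col) per token, in generation order.
    mask[i, j] is True when token i may attend to token j: j must come earlier
    (causal) and lie within `window` along every axis.
    """
    L = positions.shape[0]
    causal = torch.tril(torch.ones(L, L, dtype=torch.bool, device=positions.device))
    # |pos_i - pos_j| per axis, shape (L, L, 3)
    diff = (positions[:, None, :] - positions[None, :, :]).abs()
    near = (diff <= torch.tensor(window, device=positions.device)).all(dim=-1)
    return causal & near

# Usage with PyTorch's scaled_dot_product_attention (True means "may attend"):
# mask = sparse_neighborhood_mask(positions)
# out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```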

Run set (1 run)


12. Cross attention on the same sequence

Another idea is to break the context window into two parts: one part handled by self-attention with O(M²) cost, and the other handled by cross-attention with O(MN) cost. We can keep M relatively small (e.g., 512) to encourage spatial consistency among the tokens actively being generated by the "cursor". The more stale tokens are handled by cross-attention, giving weaker but longer-range consistency. N can be quite large, such as 50,000.
My current implementation of cross-attention makes the generation quality worse (likely a bug in the generation code), and the loss is no better.
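For reference, a minimal sketch of the split with illustrative sizes (not my current, still-buggy implementation): the most recent M tokens attend to each other causally, and the same M tokens cross-attend to the up-to-N older tokens.

```python
import torch
import torch.nn as nn

class LocalPlusStaleAttention(nn.Module):
    """Self-attention over the last M tokens plus cross-attention to older tokens.

    The M most recent tokens get causal self-attention for strong local/spatial
    consistency; the up-to-N stale tokens are only reached via cross-attention,
    giving weaker but longer-range consistency at O(M*N) cost.
    """

    def __init__(self, dim: int, num_heads: int, local_window: int = 512):
        super().__init__()
        self.local_window = local_window
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, dim); the last `local_window` tokens are the "cursor".
        stale, recent = x[:, :-self.local_window], x[:, -self.local_window:]
        M = recent.shape[1]
        # True above the diagonal = position is masked out (causal self-attention).
        causal = torch.triu(torch.ones(M, M, dtype=torch.bool, device=x.device), diagonal=1)
        out, _ = self.self_attn(recent, recent, recent, attn_mask=causal)
        if stale.shape[1] > 0:
            cross, _ = self.cross_attn(recent, stale, stale)
            out = out + cross
        return out  # only the recent tokens are updated in this sketch
```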

Run set (2 runs)