> [!WARNING]
> This repository is still in development. For now there are no instructions on how to run the code because it is subject to change.
Oniris is a universal and novel way of training diffusion models for video generation and world modelling. It generalizes and improves upon all of the previously known methods for video diffusion models and diffusion-based world modelling:
- Each frame is generated sequentially and can attend to all of its context frames (just like LLMs generate tokens).
- The training is sample-efficient (just like LLM training).
- Each frame is generated via a reverse-diffusion process (just like diffusion for image generation).
- The architecture is a generalization of what can be found in the paper "Analyzing and Improving the Training Dynamics of Diffusion Models".
> [!IMPORTANT]
> We are working on training it on Counter-Strike. Stay tuned! For now, here are the results on Lunar Lander.
These are the results of a ~4-hour training run on an Nvidia RTX 4090. The model was trained on the Lunar Lander Gymnasium environment.

In the image above, the first three rows are given as context; the next three rows are generated by Oniris.

To this day, three main techniques have been used to generate a sequence of frames (image taken from the DIAMOND paper).

All of them have deal-breaking problems:
- The diffusion-for-video-generation approach can only effectively generate videos of a fixed duration, and it's of no use for world modelling.
- The frame-stacking architecture can't attend to previous frames in an effective way, so it suffers from severe amnesia.
- The cross-attention architecture is the one that makes the most sense. However, it's extremely inefficient during training because the cost per sample increases (super-)linearly with the number of context frames.
Oniris combines the strengths of all three approaches:

- It is sample-efficient like diffusion video generation. On top of that, it can generate videos of any length and can be used for world modelling. (In the future the model can be extended to generate multiple frames at the same time.)
- It implicitly employs frame-stacking because it uses 3D convolutional layers: they can be thought of as stacking frames channel-wise and then doing 2D convolutions. On top of that, it doesn't suffer from amnesia because the attention mechanism lets it attend to all of the previous frames.
- It can attend to all of the previous frames like the cross-attention architecture. On top of that, during training the computational cost per sample stays (roughly) constant as the context length increases.
> [!WARNING]
> The information provided below is just a brief overview of what's going on under the hood.
One way to train a diffusion model is to learn to predict the score function of the noised data distribution, $\nabla_x \log p_\sigma(x)$.
The most commonly used architectures are UNets and image transformers.
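For reference, here is the standard denoising score matching setup (a generic formulation, not necessarily the exact one used in this repository). For data $x$ noised as $\tilde x = x + \sigma\epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$, the network $s_\theta$ is trained to approximate the score of the noised data distribution:

$$
s_\theta(\tilde x, \sigma) \approx \nabla_{\tilde x} \log p_\sigma(\tilde x), \qquad \mathcal{L}(\theta) = \mathbb{E}_{x,\,\epsilon,\,\sigma}\Big[\lambda(\sigma)\,\big\| s_\theta(x + \sigma\epsilon,\ \sigma) + \tfrac{\epsilon}{\sigma} \big\|_2^2\Big],
$$

where $\lambda(\sigma)$ is a noise-level-dependent weighting. The minimizer of this objective is exactly the score $\nabla_{\tilde x}\log p_\sigma(\tilde x)$.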
Language models work by estimating the probability distribution of the next token given all of the previous ones.
The transformer architecture allows training such a model in a sample-efficient way. This is very important because it multiplies the effective batch size by the sequence length.
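For reference, this is the standard autoregressive factorization: with tokens $y_1, \dots, y_N$, a causal transformer produces a prediction, and therefore a loss term, at every position of the sequence in a single forward pass:

$$
p(y_1, \dots, y_N) = \prod_{i=1}^{N} p(y_i \mid y_{<i}), \qquad \mathcal{L}(\theta) = -\sum_{i=1}^{N} \log p_\theta(y_i \mid y_{<i}).
$$

This is what makes a single sequence worth $N$ training examples.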
To generate a video where each frame is generated autoregressively, we need to unite the two paradigms by estimating the score of each frame given all of the previous ones. Last-frame generation can then be thought of as image generation conditioned on the previous frames:

$$
\nabla_{x_i} \log p\big(x_i \mid x_1, \dots, x_{i-1}\big)
$$

where $x_i$ is the frame currently being generated and $x_1, \dots, x_{i-1}$ are the previous frames.
Here is how the training is made sample-efficient. Let $x = (x_1, \dots, x_n)$ be a sequence of frames. From it we build two parts:

- The first part is not noised: $x_c = (x_1, \dots, x_n)$.
- The second part is noised: $x_\epsilon = (\tilde x_{\epsilon,1}, \dots, \tilde x_{\epsilon,n})$, where each frame is noised as $\tilde x_{\epsilon,i} = x_i + \sigma_i\epsilon$.

The input sequence fed to the model is the concatenation of the two parts, so it has length $2n$; each noised frame $\tilde x_{\epsilon,i}$ is denoised using the clean frames that precede it as context.
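As an illustration, here is a minimal sketch of how such a training input could be assembled. The function name, the noise-level sampling distribution, and the tensor layout are assumptions made for this example, not the repository's actual API.

```python
import math
import torch

def build_training_input(frames, sigma_min=0.002, sigma_max=80.0):
    """frames: (B, n, C, H, W) clean video frames (pixel or latent space)."""
    B, n = frames.shape[:2]
    # Sample an independent noise level for every frame of every sample
    # (log-uniform here, purely for illustration).
    log_sigma = torch.empty(B, n, 1, 1, 1).uniform_(math.log(sigma_min),
                                                    math.log(sigma_max))
    sigma = log_sigma.exp()
    # Noise each frame: x~_{eps,i} = x_i + sigma_i * eps
    noised = frames + sigma * torch.randn_like(frames)
    # The model input is the clean context part followed by the noised part,
    # so the sequence fed to the network has length 2n.
    model_input = torch.cat([frames, noised], dim=1)
    return model_input, sigma
```

During training, every one of the $n$ noised frames produces a loss term, which is where the sample efficiency comes from.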
Here is an animation that shows how the training and inference logic are connected:
attention_animation.mp4
In this model there are two modules that can transfer information between frames:
- VideoAttention
- 3DCausalConvolution
The attention module implements causal attention masking in a manner similar to how it's done in autoregressive language modelling. The only difference is that the mask is block-sparse. During inference the model uses KV-caching to make generation fast.
Here is an illustrative image that shows how the information moves:
Here is a schematic representation of how the inputs and outputs interact. TODO: make this better!
This can be achieved by doing block-sparse masking using FlexAttention. Thanks to it, no computation is wasted.
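Here is a minimal sketch of how such a block-sparse mask can be expressed with PyTorch's FlexAttention (available since PyTorch 2.5). The visibility rules below (a noised frame attends to its own tokens and to the clean frames before it, while clean frames attend causally to other clean frames) are an assumption based on the description above, not necessarily the repository's exact mask.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Illustrative sizes: n frames per half, each frame flattened to T spatial tokens.
n, T = 8, 64
seq_len = 2 * n * T                # clean half followed by noised half

def frame_of(tok):                 # frame index within its half (0..n-1)
    return (tok % (n * T)) // T

def is_noised(tok):                # True for tokens in the noised half
    return tok >= n * T

def mask_mod(b, h, q_idx, kv_idx):
    qf, kf = frame_of(q_idx), frame_of(kv_idx)
    qn, kn = is_noised(q_idx), is_noised(kv_idx)
    # Clean tokens: causal attention over clean frames only.
    clean_rule = (~qn) & (~kn) & (kf <= qf)
    # Noised frame i: attends to its own tokens and to clean frames before it.
    noised_rule = qn & ((kn & (kf == qf)) | ((~kn) & (kf < qf)))
    return clean_rule | noised_rule

block_mask = create_block_mask(mask_mod, B=None, H=None,
                               Q_LEN=seq_len, KV_LEN=seq_len, device="cuda")
q = k = v = torch.randn(1, 8, seq_len, 64, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)
```

Because FlexAttention builds a block-sparse kernel from this mask, fully masked blocks are skipped entirely, which is why no computation is wasted on them.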
Weirdly enough, the convolution layer is the hardest to explain, because it relies on a couple of tricks to make sure that the code runs as fast and efficiently as possible during training.
I'll write up exactly how it works later. For now you can read the code.
During inference the model caches only the activations that fall inside the convolutional filter's temporal window. This leads to yet another big improvement in speed, making the per-frame inference cost roughly O(1).
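To make the caching idea concrete, here is a rough sketch of a temporally causal 3D convolution with an inference-time cache of the last few frames. It illustrates the principle only; the repository's actual layer may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalConv3d(nn.Module):
    """Temporally causal 3D conv with a small activation cache for inference."""

    def __init__(self, in_ch, out_ch, kernel=(3, 3, 3)):
        super().__init__()
        self.kt = kernel[0]
        # Spatial padding is "same"; temporal padding is handled manually below.
        self.conv = nn.Conv3d(in_ch, out_ch, kernel,
                              padding=(0, kernel[1] // 2, kernel[2] // 2))
        self.cache = None          # last (kt - 1) input frames seen at inference

    def forward(self, x, use_cache=False):
        # x: (B, C, T, H, W)
        if use_cache:
            if self.cache is None:
                self.cache = x.new_zeros(x.shape[0], x.shape[1], self.kt - 1,
                                         x.shape[3], x.shape[4])
            x = torch.cat([self.cache, x], dim=2)
            # Only the last (kt - 1) frames are needed for the next step.
            self.cache = x[:, :, -(self.kt - 1):].detach()
        else:
            # Training: left-pad in time so frame t never sees frames > t.
            x = F.pad(x, (0, 0, 0, 0, self.kt - 1, 0))
        return self.conv(x)
```

With a temporal kernel of size $k_t$, only the last $k_t - 1$ frames of activations need to be kept per layer, which is why the per-frame cost stays roughly constant.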
The loss is computed independently for each frame. The equations are adapted from [1, 2].
Here are the (slightly oversimplified) equations. Let $D_\theta$ be the denoiser network; the loss for frame $i$ is

$$
\mathcal{L}_i(\theta) = \mathbb{E}_{\epsilon}\Big[\, \lambda(\sigma_i)\, \big\| D_\theta\big(\tilde x_{\epsilon,i};\ \sigma_i,\ x_1,\dots,x_{i-1}\big) - x_i \big\|_2^2 \,\Big]
$$

where $\sigma_i$ is the noise level applied to frame $i$ and $\lambda(\sigma_i)$ is a noise-dependent weighting.
And this is what you see in the graph below.
In the image above, the left panel shows how the average loss goes down as training progresses (~12 h on an RTX 4090). The right panel shows the relationship between the loss, the applied noise level, and the position along the sequence.
However, the losses are normalized by their expected value as a function of the noise level and of the position in the sequence. This ensures that the loss contributes with a comparable magnitude at every noise level and sequence position, so no part of the training signal dominates.
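As an illustration of what normalizing by the expected loss could look like, here is a simple sketch that keeps a running estimate of the mean loss per noise-level bin and divides by it. This is a stand-in for the actual scheme used in the repository (which could instead use, e.g., a learned weighting), and a real implementation might also condition on the frame position.

```python
import math
import torch

class NormalizedFrameLoss:
    """Per-frame MSE divided by a running estimate of its expected value
    as a function of the noise level (illustrative, not the repo's code)."""

    def __init__(self, n_bins=32, sigma_min=0.002, sigma_max=80.0, momentum=0.99):
        self.edges = torch.logspace(math.log10(sigma_min), math.log10(sigma_max),
                                    n_bins + 1)
        self.expected = torch.ones(n_bins)   # running E[loss] per sigma bin
        self.momentum = momentum

    def __call__(self, denoised, target, sigma):
        # denoised, target: (B, n, C, H, W); sigma: (B, n, 1, 1, 1)
        per_frame = ((denoised - target) ** 2).mean(dim=(2, 3, 4))       # (B, n)
        bins = (torch.bucketize(sigma.flatten(), self.edges) - 1
                ).clamp(0, len(self.expected) - 1)
        with torch.no_grad():                 # update the running estimates
            flat = per_frame.detach().flatten()
            for b in bins.unique():
                m = bins == b
                self.expected[b] = (self.momentum * self.expected[b]
                                    + (1 - self.momentum) * flat[m].mean())
        norm = self.expected[bins].reshape(per_frame.shape)
        return (per_frame / norm).mean()
```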
We also developed a Group-Causal VAE that can compress arbitrarily long sequences very cheaply, without compromising on performance, on a single consumer GPU. We didn't want to build it in-house, but we had to because we couldn't find anything like it.
The architecture is similar to a ResNet, but instead of standard 2D convolutions we used Group-Causal 3D convolutional layers inspired by Improved Video VAE for Latent Video Diffusion Model; we simplified, streamlined and generalized their implementation.
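To give a rough idea of what "group-causal" means here, below is a hypothetical sketch of a 3D convolution that compresses frames in groups: the output for a group of frames may look at that group and at earlier frames, but never at future groups. This is one reading of the term; the actual layer in the repository (and in the referenced paper) may be implemented differently.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupCausalConv3d(nn.Module):
    """Downsamples time by `group_size`; each output only sees its own group
    of frames and earlier ones (illustrative sketch, assumptions mine)."""

    def __init__(self, in_ch, out_ch, group_size=4, kt=8, ks=3):
        super().__init__()
        assert kt >= group_size
        self.kt, self.g = kt, group_size
        # Temporal stride = group size -> one output (latent) per group of frames.
        self.conv = nn.Conv3d(in_ch, out_ch, (kt, ks, ks),
                              stride=(group_size, 1, 1),
                              padding=(0, ks // 2, ks // 2))

    def forward(self, x):
        # x: (B, C, T, H, W), with T a multiple of group_size.
        # Left-pad in time so the temporal window of the k-th output ends exactly
        # at the last frame of the k-th group and never reaches into the future.
        x = F.pad(x, (0, 0, 0, 0, self.kt - self.g, 0))
        return self.conv(x)

# Example: 16 frames -> 4 temporal latents, none of which depends on later frames.
video = torch.randn(1, 3, 16, 64, 64)
latents = GroupCausalConv3d(3, 8)(video)   # shape (1, 8, 4, 64, 64)
```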
We did not use any form of attention in the VAE:
- Temporal attention makes it impossible to encode arbitrarily long sequences, as the RAM usage increases linearly with the context size.
- Spatial attention does not significantly improve the performance of VAEs; most of the information moves locally.
We also added a group-wise KL-divergence loss (following the "How to train your VAE" paper).
Below you can see the results of training the VAE.