This repository contains the implementation of hierarchical video prediction models using Recurrent State Space Models (RSSMs) and various temporal processors including Mamba, xLSTM, and LSTM architectures. The work extends Simon's MSVA implementation for video prediction tasks.
Thesis
The full MSc thesis is included in this repository:
thesis/Thesis_DeepFutureFrames.pdf
The codebase implements a modular architecture for video prediction with the following key components:
- Encoders: Video frame encoders
- Decoders: Video frame decoders
- Temporal Processors: Various recurrent architectures for temporal modeling
- Hierarchical Modules: Multi-level state representations and POMDP-based modeling
- Single Level RSSM: Basic RSSM with single latent state
- Two Level RSSM: Hierarchical RSSM with two levels of abstraction
- Deterministic Models: Mamba1, Mamba2, xLSTM, LSTM implementations
- Multi-step Prediction: Multi-step future frame generation
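At a high level these components compose into a single encode → process → decode pipeline. Below is a minimal sketch of that flow; all class names and signatures are illustrative placeholders, not the actual modules under core/.

```python
import torch
import torch.nn as nn

class VideoPredictor(nn.Module):
    """Illustrative composition of the component types listed above."""
    def __init__(self, encoder: nn.Module, temporal: nn.Module, decoder: nn.Module):
        super().__init__()
        self.encoder = encoder    # frames -> per-frame latent features
        self.temporal = temporal  # latent sequence -> processed latent sequence
        self.decoder = decoder    # latents -> reconstructed/predicted frames

    def forward(self, frames: torch.Tensor) -> torch.Tensor:
        # frames: (batch, time, channels, height, width)
        b, t = frames.shape[:2]
        z = self.encoder(frames.flatten(0, 1)).unflatten(0, (b, t))
        h = self.temporal(z)  # e.g. a Mamba/xLSTM/LSTM block
        return self.decoder(h.flatten(0, 1)).unflatten(0, (b, t))
```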
Create the conda environment from the provided specification:
conda env create -f future-frames.yaml
conda activate future-frames
After activating the environment, install the remaining packages in sequence:
# Install Mamba SSM
pip install mamba-ssm==2.2.4 --no-build-isolation --no-cache-dir
# Install Causal Conv1D
pip install causal-conv1d==1.5.0.post8
# Install additional dependencies
pip install opencv-python imageio tqdm requests
The codebase supports the OGBench dataset collection. To download and process data:
cd data
python load_dataset.py
Available Datasets:
- visual-antmaze-large-navigate-v0: action dim 8, 1000 episodes, 1000 steps each
- visual-scene-play-v0: action dim 5, 1000 episodes, 1000 steps each
- visual-puzzle-4x4-play-v0: action dim 5, 1000 episodes, 1000 steps each
- Other visual datasets are available at https://rail.eecs.berkeley.edu/datasets/ogbench/
Data Processing:
- Downloads from Berkeley RAIL dataset repository
- Automatically resizes frames to 64x64 resolution
- Splits validation data into 80% val / 20% test
- Saves processed episodes as compressed NPZ files
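Each processed episode can be sanity-checked directly with NumPy. The path and array keys below are assumptions; print data.files to see what your copy actually contains:

```python
import numpy as np

# Illustrative path -- the directory layout depends on the dataset you downloaded.
data = np.load("data/visual-antmaze-large-navigate-v0/train/episode_0000.npz")
print(data.files)  # names of the arrays stored in this episode
for key in data.files:
    print(key, data[key].shape, data[key].dtype)
```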
python -m experiments.train
python -m experiments.train_KL_Annealing
The system uses Hydra for configuration management. Main configuration files are located in experiments/conf/:
- Single Level RSSM: single_level_RSSM.yaml
- Two Level RSSM: two_level_RSSM.yaml
- Deterministic: various temporal processor options (Mamba1, Mamba2, xLSTM, LSTM)
- Training: Learning rate, batch size, beta values
- Data: Dataset selection and preprocessing
model:
type: single_level_RSSM # or two_level_RSSM or self_supervised or multistep
encoder_type: "dreamerv2" # cnn, resnet, dreamerv2, clip, or 3d
decoder_type: "dreamerv2" # cnn, pixelshuffle, dreamerv2, clip, or 3d
latent_dim: 256 # Latent dimension of the processor.
temporal_block:
type: 'Mamba1' # Mamba1, Mamba2, xLSTM, LSTM
d_state: 256 # For Mamba1, Mamba2, xLSTM
hidden_size: 512 # For LSTM only
d_conv: 4 # For Mamba1, Mamba2, xLSTM
expand: 2 # For Mamba1, Mamba2, xLSTM
chunk_size: 512 # Mamba2 and xLSTM
use_mem_eff_path: True # For Mamba2
dropout: 0.1 # For Mamba1, xLSTM, LSTM
bidirectional: False # For LSTM
training:
lr: 0.001
batch_size: 32
rssm_beta: 1.0 # KL divergence weight
epochs: 100
data:
dataset_name: visual-antmaze-large-navigate-v0
frame_size: 64
sequence_length: 50
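Because Hydra configs are plain OmegaConf YAML, they can also be loaded and tweaked programmatically. A minimal sketch, assuming the file path matches the config names listed above:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("experiments/conf/single_level_RSSM.yaml")
print(cfg.model.type, cfg.temporal_block.type)

# Override values in memory, e.g. to try a smaller LSTM run.
cfg.temporal_block.type = "LSTM"
cfg.training.batch_size = 16
print(OmegaConf.to_yaml(cfg))
```

Hydra also accepts the same dotted overrides on the command line (e.g. python -m experiments.train training.lr=0.0005), assuming the training scripts are decorated with @hydra.main.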
The core/components/hierarchical_modules/ directory contains the MSVA implementation:
- POMDP Modules: Base classes for Partially Observable Markov Decision Processes
- RSSM Implementation: Recurrent State Space Models with stochastic and deterministic components
- Recurrent Modules: Various temporal processing architectures
- Distribution Modules: Probabilistic state representations
- Forward Pass: Encode frames → Process temporally → Decode predictions
- Loss Computation: Reconstruction + KL divergence + optional action prediction (see the sketch after this list)
- Optimization: AdamW optimizer with optional learning rate scheduling
- Validation: Regular evaluation on held-out data
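Put together, a single training step looks roughly like the sketch below. This assumes the model returns predicted frames plus posterior/prior distributions over the stochastic latent; the actual loss code lives in utils/loss.py and may differ in detail.

```python
import torch
import torch.nn.functional as F
from torch.distributions import kl_divergence

def training_step(model, optimizer, frames, beta=1.0):
    # Illustrative: reconstruction + beta-weighted KL, as described above.
    pred, posterior, prior = model(frames)
    recon_loss = F.mse_loss(pred, frames)
    kl_loss = kl_divergence(posterior, prior).mean()
    loss = recon_loss + beta * kl_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
```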
- Reconstruction Loss: Pixel-wise reconstruction quality
- KL Divergence: Latent space regularization
- PSNR/SSIM: Image quality metrics (PSNR sketched below)
- Action Prediction: Accuracy of predicted actions (if applicable)
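PSNR, for instance, follows directly from the pixel-wise MSE. A minimal version for frames scaled to [0, 1] (the project's own implementation in utils/evaluation.py may normalize differently):

```python
import torch

def psnr(pred: torch.Tensor, target: torch.Tensor, max_val: float = 1.0) -> torch.Tensor:
    # PSNR = 10 * log10(max_val^2 / MSE); higher is better.
    mse = torch.mean((pred - target) ** 2)
    return 10.0 * torch.log10(max_val ** 2 / mse)
```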
- Checkpointing: Automatic model saving based on validation metrics
- Early Stopping: Prevents overfitting with patience-based stopping
- WandB Integration: Experiment tracking and visualization
To add a new recurrent module (see the sketch below):
- Implement it in core/components/hierarchical_modules/recurrent.py
- Register it with the @register decorator
- Add configuration in the appropriate YAML files
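A new module might look like the following sketch. The decorator name comes from the list above, but its exact signature and import path are assumptions; check recurrent.py for the real registry:

```python
import torch.nn as nn
from core.components.hierarchical_modules.recurrent import register  # assumed import

@register
class GRUProcessor(nn.Module):
    """Illustrative recurrent module; not part of the repository."""
    def __init__(self, latent_dim: int = 256, hidden_size: int = 512):
        super().__init__()
        self.gru = nn.GRU(latent_dim, hidden_size, batch_first=True)

    def forward(self, x):
        out, _ = self.gru(x)  # (batch, time, hidden_size)
        return out
```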
To add a new temporal block:
- Implement it in core/components/temporal_blocks.py
- Add configuration in the appropriate YAML files
To add a new encoder or decoder (see the sketch below):
- Extend the BaseEncoder/BaseDecoder classes
- Implement the forward methods
- Update the model factory functions
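For example, a new encoder could subclass BaseEncoder along these lines (the import path and the base class's expected interface are assumptions; check core/components/encoders.py):

```python
import torch
import torch.nn as nn
from core.components.encoders import BaseEncoder  # assumed import path

class TinyConvEncoder(BaseEncoder):
    """Illustrative 64x64 RGB frame encoder; not part of the repository."""
    def __init__(self, latent_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),   # 64 -> 32
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),  # 32 -> 16
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 16 * 16, latent_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)  # (batch, latent_dim)
```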
To add a new dataset (see the sketch below):
- Modify data/load_dataset.py for new data sources
- Implement appropriate preprocessing
- Update the dataset configuration
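A matching Dataset over the processed NPZ episodes might look like this sketch; the array key and directory layout are guesses, so mirror whatever data/datasets.py already does:

```python
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import Dataset

class NPZEpisodeDataset(Dataset):
    """Illustrative dataset over processed NPZ episodes; key names are guesses."""
    def __init__(self, root: str, sequence_length: int = 50):
        self.files = sorted(Path(root).glob("*.npz"))
        self.sequence_length = sequence_length

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        episode = np.load(self.files[idx])
        frames = episode["frames"][: self.sequence_length]  # assumed key name
        # (time, H, W, C) uint8 -> (time, C, H, W) float in [0, 1]
        return torch.from_numpy(frames).permute(0, 3, 1, 2).float() / 255.0
```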
DeepFutureFrames/
├── core/                          # Core model implementations
│   ├── components/                # Modular components
│   │   ├── encoders.py            # Video frame encoders
│   │   ├── decoders.py            # Video frame decoders
│   │   ├── hierarchical_modules/  # Simon's MSVA extensions
│   │   │   ├── pomdp/             # POMDP-based modeling
│   │   │   ├── recurrent.py       # Temporal processors
│   │   │   └── distributions.py   # Probabilistic modules
│   │   └── temporal_blocks.py     # Temporal processing blocks
│   └── models/                    # Complete model implementations
├── data/                          # Data loading and processing
│   ├── datasets.py                # Dataset classes
│   └── load_dataset.py            # Data download and preprocessing
├── experiments/                   # Training and evaluation
│   ├── conf/                      # Configuration files
│   ├── train.py                   # Fixed beta training
│   └── train_KL_Annealing.py      # KL annealing training
├── utils/                         # Utility functions
│   ├── loss.py                    # Loss functions
│   ├── evaluation.py              # Evaluation metrics
│   └── trajectory.py              # Trajectory handling
└── future-frames.yaml             # Conda environment specification
- CUDA Memory: Reduce batch size or sequence length
- Installation Errors: Ensure correct Python version (3.11) and CUDA compatibility
- Data Loading: Check internet connection for dataset downloads
- Model Convergence: Adjust learning rate and beta values
This project is part of a thesis submission. Please respect academic usage guidelines and cite appropriately when using this code.