
DeepFutureFrames: Hierarchical Video Prediction with RSSMs and Temporal Processors

This repository contains the implementation of hierarchical video prediction models using Recurrent State Space Models (RSSMs) and various temporal processors including Mamba, xLSTM, and LSTM architectures. The work extends Simon's MSVA implementation for video prediction tasks.

📄 Thesis
The full MSc thesis is included in this repository:
thesis/Thesis_DeepFutureFrames.pdf

Architecture Overview

The codebase implements a modular architecture for video prediction with the following key components:

Core Components

  • Encoders: Video frame encoders
  • Decoders: Video frame decoders
  • Temporal Processors: Various recurrent architectures for temporal modeling
  • Hierarchical Modules: Multi-level state representations and POMDP-based modeling

Model Variants

  • Single Level RSSM: Basic RSSM with single latent state
  • Two Level RSSM: Hierarchical RSSM with two levels of abstraction
  • Deterministic Models: Mamba1, Mamba2, xLSTM, LSTM implementations
  • Multi-step Prediction: Multi-step future frame generation

Quick Start

1. Environment Setup

Create the conda environment from the provided specification:

conda env create -f future-frames.yaml
conda activate future-frames

2. Install Additional Dependencies

After activating the environment, install the remaining packages in sequence:

# Install Mamba SSM
pip install mamba-ssm==2.2.4 --no-build-isolation --no-cache-dir

# Install Causal Conv1D
pip install causal-conv1d==1.5.0.post8

# Install additional dependencies
pip install opencv-python imageio tqdm requests

3. Data Download and Processing

The codebase supports the OGBench dataset collection. To download and process data:

cd data
python load_dataset.py

Available Datasets:

  • visual-antmaze-large-navigate-v0: Action dim 8, 1000 episodes, 1000 steps each
  • visual-scene-play-v0: Action dim 5, 1000 episodes, 1000 steps each
  • visual-puzzle-4x4-play-v0: Action dim 5, 1000 episodes, 1000 steps each
  • Other visual datasets are available from https://rail.eecs.berkeley.edu/datasets/ogbench/

Data Processing:

  • Downloads from Berkeley RAIL dataset repository
  • Automatically resizes frames to 64x64 resolution
  • Splits validation data into 80% val / 20% test
  • Saves processed episodes as compressed NPZ files (a loading sketch follows below)
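
For reference, here is a minimal sketch of loading one processed episode. The file name and array keys are hypothetical; inspect the output of load_dataset.py for the actual names.

import numpy as np

# Hypothetical file name and keys; check the NPZ files written by load_dataset.py.
episode = np.load("data/processed/episode_0000.npz")
frames = episode["frames"]    # expected shape: (steps, 64, 64, 3)
actions = episode["actions"]  # expected shape: (steps, action_dim)
print(frames.shape, actions.shape)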

4. Running Experiments

Basic Training (Fixed Beta)

python -m experiments.train

KL Annealing Training (Variable Beta)

python -m experiments.train_KL_Annealing
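
The exact annealing schedule is set in the configuration; as a rough illustration only (not necessarily the schedule train_KL_Annealing.py implements), a linear warm-up of the KL weight might look like:

def linear_beta(step, warmup_steps=10000, beta_max=1.0):
    """Linearly anneal the KL weight from 0 to beta_max over warmup_steps."""
    return beta_max * min(1.0, step / warmup_steps)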

Configuration

The system uses Hydra for configuration management. Main configuration files are located in experiments/conf/:

Model Configuration

  • Single Level RSSM: single_level_RSSM.yaml
  • Two Level RSSM: two_level_RSSM.yaml
  • Deterministic: temporal processor options (Mamba1, Mamba2, xLSTM, LSTM)
  • Training: Learning rate, batch size, beta values
  • Data: Dataset selection and preprocessing

Key Configuration Options

model:
  type: single_level_RSSM  # or two_level_RSSM, self_supervised, multistep
  encoder_type: "dreamerv2"  # cnn, resnet, dreamerv2, clip, 3d
  decoder_type: "dreamerv2"  # cnn, pixelshuffle, dreamerv2, clip, 3d
  latent_dim: 256  # latent dimension of the processor
  temporal_block:
    type: 'Mamba1'  # Mamba1, Mamba2, xLSTM, LSTM
    d_state: 256        # For Mamba1, Mamba2, xLSTM
    hidden_size: 512    # For LSTM only
    d_conv: 4         # For Mamba1, Mamba2, xLSTM
    expand: 2         # For Mamba1, Mamba2, xLSTM
    chunk_size: 512    # For Mamba2, xLSTM
    use_mem_eff_path: True # For Mamba2
    dropout: 0.1  # For Mamba1, xLSTM, LSTM
    bidirectional: False # For LSTM

training:
  lr: 0.001
  batch_size: 32
  rssm_beta: 1.0  # KL divergence weight
  epochs: 100

data:
  dataset_name: visual-antmaze-large-navigate-v0
  frame_size: 64
  sequence_length: 50
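
Because the system uses Hydra, any of these values can also be overridden from the command line without editing the YAML files, for example:

python -m experiments.train model.temporal_block.type=Mamba2 training.batch_size=16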

Model Architecture Details

Hierarchical Modules

The core/components/hierarchical_modules/ directory contains the extended MSVA implementation:

  • POMDP Modules: Base classes for Partially Observable Markov Decision Processes
  • RSSM Implementation: Recurrent State Space Models with stochastic and deterministic components (sketched below)
  • Recurrent Modules: Various temporal processing architectures
  • Distribution Modules: Probabilistic state representations
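
To make the RSSM structure concrete, here is a minimal single-level transition step. This is an independent sketch, not the repository's implementation: a deterministic recurrent state h is updated first, then a stochastic latent z is sampled from a learned prior.

import torch
import torch.nn as nn

class MinimalRSSMStep(nn.Module):
    """Illustrative RSSM transition with deterministic h and stochastic z."""
    def __init__(self, z_dim=32, h_dim=256, a_dim=8):
        super().__init__()
        self.rnn = nn.GRUCell(z_dim + a_dim, h_dim)  # deterministic path
        self.prior = nn.Linear(h_dim, 2 * z_dim)     # mean and log-std of z

    def forward(self, h, z, a):
        h_next = self.rnn(torch.cat([z, a], dim=-1), h)
        mean, log_std = self.prior(h_next).chunk(2, dim=-1)
        z_next = mean + log_std.exp() * torch.randn_like(mean)  # reparameterized sample
        return h_next, z_next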

Training and Evaluation

Training Process

  1. Forward Pass: Encode frames → Process temporally → Decode predictions
  2. Loss Computation: Reconstruction + KL divergence + optional action prediction (see the loss sketch below)
  3. Optimization: AdamW optimizer with optional learning rate scheduling
  4. Validation: Regular evaluation on held-out data
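
As a sketch of step 2, assuming the standard beta-weighted ELBO form (the repository's actual loss functions live in utils/loss.py):

import torch.nn.functional as F
from torch.distributions import kl_divergence

def rssm_loss(pred_frames, target_frames, prior, posterior, beta=1.0):
    """Illustrative beta-weighted objective: reconstruction plus scaled KL."""
    recon = F.mse_loss(pred_frames, target_frames)
    kl = kl_divergence(posterior, prior).mean()
    return recon + beta * kl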

Evaluation Metrics

  • Reconstruction Loss: Pixel-wise reconstruction quality
  • KL Divergence: Latent space regularization
  • PSNR/SSIM: Image quality metrics (a PSNR sketch follows below)
  • Action Prediction: Accuracy of predicted actions (if applicable)
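
PSNR, for instance, follows directly from the pixel-wise MSE; a minimal NumPy version for frames scaled to [0, 1]:

import numpy as np

def psnr(pred, target, max_val=1.0):
    """Peak signal-to-noise ratio in dB for frames in [0, max_val]."""
    mse = np.mean((pred - target) ** 2)
    return 10.0 * np.log10(max_val ** 2 / mse)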

Model Selection

  • Checkpointing: Automatic model saving based on validation metrics
  • Early Stopping: Prevents overfitting with patience-based stopping
  • WandB Integration: Experiment tracking and visualization

Customization

Adding New Temporal Processors

  1. Implement in core/components/hierarchical_modules/recurrent.py
  2. Register it with the @register decorator (see the sketch below)
  3. Add configuration in appropriate YAML files
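
A hypothetical sketch of this pattern; the decorator name, import path, and required interface should be checked against recurrent.py:

import torch.nn as nn
# Hypothetical import; verify the real decorator in recurrent.py.
from core.components.hierarchical_modules.recurrent import register

@register("my_processor")
class MyProcessor(nn.Module):
    """Illustrative processor; match the interface of the existing classes."""
    def __init__(self, latent_dim=256):
        super().__init__()
        self.rnn = nn.GRU(latent_dim, latent_dim, batch_first=True)

    def forward(self, x):  # x: (batch, time, latent_dim)
        out, _ = self.rnn(x)
        return out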

Adding New Deterministic Temporal Blocks

  1. Implement in core/components/temporal_blocks.py (see the sketch below)
  2. Add configuration in the appropriate YAML files.
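
For example, a minimal LSTM-style block (illustrative only; match the interface expected by core/components/temporal_blocks.py):

import torch.nn as nn

class SimpleLSTMBlock(nn.Module):
    """Illustrative deterministic temporal block built on nn.LSTM."""
    def __init__(self, latent_dim=256, hidden_size=512, dropout=0.1):
        super().__init__()
        self.lstm = nn.LSTM(latent_dim, hidden_size, batch_first=True)
        self.drop = nn.Dropout(dropout)
        self.proj = nn.Linear(hidden_size, latent_dim)

    def forward(self, x):  # x: (batch, time, latent_dim)
        out, _ = self.lstm(x)
        return self.proj(self.drop(out))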

New Encoder/Decoder Architectures

  1. Extend the BaseEncoder/BaseDecoder classes (see the sketch below)
  2. Implement forward methods
  3. Update model factory functions
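
A minimal illustration of steps 1 and 2; the real BaseEncoder interface is defined in core/components/encoders.py, so this stand-in only shows the general shape:

import torch.nn as nn

class TinyEncoder(nn.Module):  # stand-in for a BaseEncoder subclass
    """Illustrative 64x64 RGB encoder producing a flat latent vector."""
    def __init__(self, latent_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),   # 64x64 -> 32x32
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),  # 32x32 -> 16x16
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 16 * 16, latent_dim),
        )

    def forward(self, x):  # x: (batch, 3, 64, 64)
        return self.net(x)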

Custom Datasets

  1. Modify data/load_dataset.py for new data sources
  2. Implement appropriate preprocessing (see the resizing sketch below)
  3. Update dataset configuration
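
Step 2 usually amounts to resizing and normalizing frames; a minimal OpenCV sketch consistent with the 64x64 pipeline described above:

import cv2
import numpy as np

def preprocess_frames(frames, size=64):
    """Resize raw frames to size x size and scale pixels to [0, 1]."""
    resized = [cv2.resize(f, (size, size), interpolation=cv2.INTER_AREA) for f in frames]
    return np.stack(resized).astype(np.float32) / 255.0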

Project Structure

DeepFutureFrames/
├── core/                           # Core model implementations
│   ├── components/                 # Modular components
│   │   ├── encoders.py            # Video frame encoders
│   │   ├── decoders.py            # Video frame decoders
│   │   ├── hierarchical_modules/  # Simon's MSVA extensions
│   │   │   ├── pomdp/            # POMDP-based modeling
│   │   │   ├── recurrent.py      # Temporal processors
│   │   │   └── distributions.py  # Probabilistic modules
│   │   └── temporal_blocks.py    # Temporal processing blocks
│   └── models/                    # Complete model implementations
├── data/                          # Data loading and processing
│   ├── datasets.py               # Dataset classes
│   └── load_dataset.py           # Data download and preprocessing
├── experiments/                   # Training and evaluation
│   ├── conf/                     # Configuration files
│   ├── train.py                  # Fixed beta training
│   └── train_KL_Annealing.py    # KL annealing training
├── utils/                         # Utility functions
│   ├── loss.py                   # Loss functions
│   ├── evaluation.py             # Evaluation metrics
│   └── trajectory.py             # Trajectory handling
└── future-frames.yaml            # Conda environment specification

Troubleshooting

Common Issues

  1. CUDA Memory: Reduce batch size or sequence length
  2. Installation Errors: Ensure correct Python version (3.11) and CUDA compatibility
  3. Data Loading: Check internet connection for dataset downloads
  4. Model Convergence: Adjust learning rate and beta values

License

This project is part of a thesis submission. Please respect academic usage guidelines and cite appropriately when using this code.

About

Long Horizon Video Predictions with Hierarchical Deep Probabilistic Models
