Evaluating Image Representations for Video Prediction

A comprehensive implementation of video prediction models using Transformer-based and hybrid Transformer + CNN architectures for both holistic and object-centric scene representations. This project explores different approaches to learning and predicting future video frames on the MOVi-C dataset.

[Reconstruction GIF]

🎯 Overview

This project implements a two-stage video prediction pipeline:

  1. Stage 1 - Autoencoder Training: Learn compressed representations of video frames
  2. Stage 2 - Predictor Training: Predict future frame representations in latent space

The framework supports two distinct scene representation approaches:

  • Holistic Representation: Treats the entire scene as a unified entity
  • Object-Centric (OC) Representation: Decomposes scenes into individual objects using masks/bounding boxes
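
In code, the two stages look roughly like the sketch below. The names and the loss are illustrative, not the exact classes in src/model/ocvp.py (the project actually mixes MSE and L1 for reconstruction, and the predictor uses a sliding-window rollout described later):

import torch
import torch.nn.functional as F

def train_stage1(autoencoder, loader, optimizer):
    # Stage 1: learn compressed frame representations via reconstruction.
    for frames in loader:                      # frames: [B, T, 3, H, W]
        recon = autoencoder(frames)
        loss = F.mse_loss(recon, frames)       # simplified; the project mixes MSE + L1
        optimizer.zero_grad(); loss.backward(); optimizer.step()

def train_stage2(encoder, predictor, loader, optimizer):
    # Stage 2: predict future latents; the stage-1 encoder only provides targets here.
    for frames in loader:
        with torch.no_grad():
            latents = encoder(frames)          # [B, T, ..., D]
        preds = predictor(latents[:, :-1])     # one-step-ahead, simplified
        loss = F.mse_loss(preds, latents[:, 1:])
        optimizer.zero_grad(); loss.backward(); optimizer.step()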

✨ Features

  • 🔄 Two-Stage Training Pipeline: Separate autoencoder and predictor training phases
  • 🎭 Dual Scene Representations: Support for both holistic and object-centric approaches
  • 🧠 Transformer-Based Architecture: Modern attention-based encoders and decoders
  • 🎯 Flexible Configuration: Easy-to-modify configuration system
  • 📊 Comprehensive Logging: TensorBoard integration with visualization support
  • ⚡ Mixed Precision Training: Efficient GPU utilization with AMP support
  • 🔍 Early Stopping & Scheduling: Automatic training optimization
  • 💾 Checkpoint Management: Automatic model saving and loading

πŸ“ Project Structure

CourseProject_2/
β”œβ”€β”€ src/
β”‚   β”œβ”€β”€ base/                    # Base classes
β”‚   β”‚   β”œβ”€β”€ baseTrainer.py       # Base trainer implementation
β”‚   β”‚   └── baseTransformer.py   # Base transformer blocks
β”‚   β”œβ”€β”€ datalib/                 # Data loading and processing
β”‚   β”‚   β”œβ”€β”€ MoviC.py            # MOVi-C dataset class
β”‚   β”‚   β”œβ”€β”€ load_data.py        # Data loading utilities
β”‚   β”‚   └── transforms.py        # Data augmentation
β”‚   β”œβ”€β”€ model/                   # Model architectures
β”‚   β”‚   β”œβ”€β”€ ocvp.py             # Main model definitions (TransformerAutoEncoder, TransformerPredictor, OCVP)
β”‚   β”‚   β”œβ”€β”€ holistic_encoder.py # Holistic encoder (patch-based)
β”‚   β”‚   β”œβ”€β”€ holistic_decoder.py # Holistic decoder
β”‚   β”‚   β”œβ”€β”€ holistic_predictor.py # Holistic predictor
β”‚   β”‚   β”œβ”€β”€ oc_encoder.py       # Object-centric encoder (CNN + Transformer)
β”‚   β”‚   β”œβ”€β”€ oc_decoder.py       # Object-centric decoder (Transformer + CNN)
β”‚   β”‚   β”œβ”€β”€ oc_predictor.py     # Object-centric predictor
β”‚   β”‚   β”œβ”€β”€ predictor_wrapper.py # Autoregressive wrapper with sliding window
β”‚   β”‚   └── model_utils.py      # Model utilities (TransformerBlock, Patchifier, etc.)
β”‚   β”œβ”€β”€ utils/                   # Utility functions
β”‚   β”‚   β”œβ”€β”€ logger.py           # Logging utilities
β”‚   β”‚   β”œβ”€β”€ metrics.py          # Evaluation metrics
β”‚   β”‚   β”œβ”€β”€ utils.py            # General utilities
β”‚   β”‚   └── visualization.py    # Visualization tools
β”‚   β”œβ”€β”€ experiments/             # Experiment outputs
β”‚   β”‚   └── [experiment_name]/
β”‚   β”‚       β”œβ”€β”€ checkpoints/    # Model checkpoints
β”‚   β”‚       β”œβ”€β”€ config/         # Experiment config
β”‚   β”‚       └── tboard_logs/    # TensorBoard logs
β”‚   β”œβ”€β”€ CONFIG.py               # Global configuration
β”‚   β”œβ”€β”€ trainer.py              # Training entry point
β”‚   └── ocvp.ipynb             # Analysis notebook
β”œβ”€β”€ docs/                       # Documentation and reports
β”œβ”€β”€ requirements.txt            # Python dependencies
└── README.md                   # This file

Why Transformer + CNN Hybrid?

The object-centric model uses a hybrid Transformer + CNN architecture, combining the complementary strengths of both:

CNN Advantages:

  • ✅ Inductive Bias: Built-in understanding of spatial locality and translation invariance
  • ✅ Efficient Downsampling: Reduces 64×64 images to compact 256D vectors
  • ✅ Parameter Efficiency: Fewer parameters than fully linear projections
  • ✅ Better Image Reconstruction: ConvTranspose layers naturally upsample spatial features

Transformer Advantages:

  • ✅ Temporal Modeling: Captures long-range dependencies across time
  • ✅ Object Relationships: Models interactions between multiple objects
  • ✅ Attention Mechanism: Learns which objects/features are important
  • ✅ Flexible Context: Handles a variable number of objects and temporal sequences

Combined Benefits:

  • 🎯 CNNs handle spatial features (what objects look like)
  • 🎯 Transformers handle temporal dynamics (how objects move and interact)
  • 🎯 Best of both worlds: local spatial structure + global temporal reasoning
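
A minimal PyTorch sketch of this split (module and parameter names are illustrative, not the exact classes in oc_encoder.py): a small CNN turns each 64×64 object crop into a 256D token, and a Transformer then attends over the object tokens across the sequence:

import torch
import torch.nn as nn

class ObjectCNN(nn.Module):
    """Per-object spatial encoder: 64x64 RGB crop -> 256D token."""
    def __init__(self, embed_dim=256):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),    # 64x64 -> 32x32
            nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 32x32 -> 16x16
            nn.BatchNorm2d(128), nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), # 16x16 -> 8x8
            nn.BatchNorm2d(256), nn.ReLU(),
        )
        self.proj = nn.Linear(256 * 8 * 8, embed_dim)

    def forward(self, x):                              # x: [N, 3, 64, 64]
        return self.proj(self.conv(x).flatten(1))      # [N, embed_dim]

class HybridEncoder(nn.Module):
    """CNN handles spatial features, Transformer handles object/temporal reasoning."""
    def __init__(self, embed_dim=256, depth=12, heads=8, mlp=1024):
        super().__init__()
        self.cnn = ObjectCNN(embed_dim)
        layer = nn.TransformerEncoderLayer(embed_dim, heads, mlp, batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, depth)

    def forward(self, objs):                           # objs: [B, T, K, 3, 64, 64]
        B, T, K = objs.shape[:3]
        tokens = self.cnn(objs.flatten(0, 2)).view(B, T * K, -1)
        return self.transformer(tokens).view(B, T, K, -1)   # [B, T, K, 256]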

Key Components

  1. Encoder (HolisticEncoder / ObjectCentricEncoder)

    • Holistic: Patchifies input images (16×16 patches) → Linear projection → Transformer
    • Object-Centric: CNN encoder + Transformer hybrid architecture
      • CNN Feature Extraction: 3-layer ConvNet downsampler
        • Conv2d(3→64): 64×64 → 32×32
        • Conv2d(64→128): 32×32 → 16×16
        • Conv2d(128→256): 16×16 → 8×8
        • Linear: Flatten → 256D embedding
      • Extracts per-object features from masks/bboxes (up to 11 objects)
      • Transformer processes object tokens across time
    • Configurable depth (12 layers default)
    • Embedding dimension: 256
    • Multi-head attention (8 heads)
    • MLP size: 1024
  2. Decoder (HolisticDecoder / ObjectCentricDecoder)

    • Holistic: Transformer → Linear projection → Unpatchify to image
    • Object-Centric: Transformer + CNN hybrid architecture
      • Transformer processes latent object representations
      • CNN Upsampling Decoder: 3-layer ConvTranspose
        • Linear: 192D → 128×8×8 feature map
        • ConvTranspose2d(128→64): 8×8 → 16×16
        • ConvTranspose2d(64→32): 16×16 → 32×32
        • ConvTranspose2d(32→3): 32×32 → 64×64 RGB
        • Tanh activation for [-1, 1] output range
      • Combines per-object frames back to full scene
    • Configurable depth (8 layers default)
    • Embedding dimension: 192
    • Mixed loss: MSE (0.8) + L1 (0.2)
  3. Predictor (HolisticTransformerPredictor / ObjectCentricTransformerPredictor)

    • Predicts future latent representations autoregressively
    • Transformer-based temporal modeling
    • Configurable depth (8 layers default)
    • Embedding dimension: 192
    • Optional residual connections
  4. Predictor Wrapper (PredictorWrapper)

    • Autoregressive Prediction: Iteratively predicts future frames
    • Sliding Window Mechanism: Maintains a buffer of size 5
      • Concatenates new predictions to input buffer
      • Drops oldest frames when buffer exceeds window size
    • Training Strategy:
      • Random temporal slicing for data augmentation
      • Per-step loss computation with temporal consistency
    • Advanced Loss Function:
      • MSE loss (0.6): Overall structure
      • L1 loss (0.2): Sharpness and sparsity
      • Cosine similarity loss (0.2): Feature alignment
    • Generates 5 future frame predictions per forward pass
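
The sliding-window rollout and the combined loss used by the PredictorWrapper can be sketched as follows (function and weight names are illustrative; the actual implementation lives in predictor_wrapper.py):

import torch
import torch.nn.functional as F

def rollout_and_loss(predictor, latents, num_preds=5, window=5,
                     w_mse=0.6, w_l1=0.2, w_cos=0.2):
    """latents: [B, T, K, D] encoder outputs, with T >= window + num_preds.
    The predictor is assumed to map a window of latent frames to the next frame."""
    buffer, preds, loss = latents[:, :window], [], 0.0
    for step in range(num_preds):
        target = latents[:, window + step]             # ground-truth next latent [B, K, D]
        pred = predictor(buffer)                       # predicted next latent    [B, K, D]
        mse = F.mse_loss(pred, target)                 # overall structure
        l1 = F.l1_loss(pred, target)                   # sharpness / sparsity
        cos = 1 - F.cosine_similarity(pred.flatten(1), target.flatten(1), dim=-1).mean()
        loss = loss + w_mse * mse + w_l1 * l1 + w_cos * cos
        # Slide the window: append the prediction, drop the oldest frame.
        buffer = torch.cat([buffer[:, 1:], pred.unsqueeze(1)], dim=1)
        preds.append(pred)
    return torch.stack(preds, dim=1), loss             # [B, num_preds, K, D], scalar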

πŸ—οΈ Architecture

Overall Pipeline

Input Video Frames → Encoder → Latent Representation → Predictor → Future Latent → Decoder → Predicted Frames

Detailed Architecture: Object-Centric Model (Transformer + CNN)

┌──────────────────────────────────────────────────────────────────────────────
│                          OBJECT-CENTRIC ENCODER
├──────────────────────────────────────────────────────────────────────────────
│ Input: Video [B, T, 3, 64, 64] + Masks [B, T, 64, 64]
│   ↓
│ Object Extraction (11 objects max)
│   → Object Frames: [B, T, 11, 3, 64, 64]
│   ↓
│ CNN Feature Extractor (Per Object):
│   • Conv2d(3→64, k=4, s=2) + BatchNorm + ReLU    [64x64 → 32x32]
│   • Conv2d(64→128, k=4, s=2) + BatchNorm + ReLU  [32x32 → 16x16]
│   • Conv2d(128→256, k=4, s=2) + BatchNorm + ReLU [16x16 → 8x8]
│   • Flatten + Linear(256·8·8 → 256)
│   → Object Tokens: [B, T, 11, 256]
│   ↓
│ Transformer Encoder (12 layers):
│   • Positional Encoding
│   • Multi-Head Attention (8 heads, dim=256)
│   • MLP (dim=1024)
│   • Layer Normalization
│   → Latent: [B, T, 11, 256]
└──────────────────────────────────────────────────────────────────────────────
┌──────────────────────────────────────────────────────────────────────────────
│                            PREDICTOR + WRAPPER
├──────────────────────────────────────────────────────────────────────────────
│ Input Latent: [B, T=24, 11, 256]
│   ↓
│ PredictorWrapper (Autoregressive):
│   • Random temporal slice (5 frames)
│   • Sliding window buffer (size=5)
│   ↓
│ Transformer Predictor (8 layers):
│   • Linear(256 → 192)
│   • Transformer blocks (depth=8)
│   • Linear(192 → 256)
│   • Optional residual connections
│   ↓
│ Autoregressive Loop (5 predictions):
│   For t in 1..5:
│     • Predict next frame
│     • Append to buffer, shift window
│     • Compute loss (MSE + L1 + Cosine)
│   → Future Latent: [B, 5, 11, 256]
└──────────────────────────────────────────────────────────────────────────────
┌──────────────────────────────────────────────────────────────────────────────
│                          OBJECT-CENTRIC DECODER
├──────────────────────────────────────────────────────────────────────────────
│ Input Latent: [B, T, 11, 256]
│   ↓
│ Transformer Decoder (8 layers):
│   • Linear(256 → 192)
│   • Positional Encoding
│   • Transformer blocks (depth=8)
│   • Layer Normalization
│   → [B, T, 11, 192]
│   ↓
│ CNN Upsampling Decoder (Per Object):
│   • Linear(192 → 128·8·8) + Reshape to [128, 8, 8]
│   • ConvTranspose2d(128→64, k=4, s=2) + BatchNorm + ReLU [8x8 → 16x16]
│   • ConvTranspose2d(64→32, k=4, s=2) + BatchNorm + ReLU  [16x16 → 32x32]
│   • ConvTranspose2d(32→3, k=4, s=2) + Tanh               [32x32 → 64x64]
│   → Per-Object Frames: [B, T, 11, 3, 64, 64]
│   ↓
│ Object Composition:
│   • Sum all object frames: Σ(objects)
│   • Normalize: (x + 1) / 2  (from [-1,1] to [0,1])
│   • Clamp to [0, 1]
│   → Reconstructed Video: [B, T, 3, 64, 64]
└──────────────────────────────────────────────────────────────────────────────
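
The composition step at the bottom of the decoder diagram amounts to a few tensor operations (a sketch, assuming the per-object Tanh outputs arrive as a tensor of shape [B, T, K, 3, H, W]):

def compose_scene(object_frames):
    # Sum the K per-object frames, map the Tanh range [-1, 1] to [0, 1], then clamp.
    scene = object_frames.sum(dim=2)               # [B, T, 3, H, W]
    return ((scene + 1.0) / 2.0).clamp(0.0, 1.0)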

Setup

  1. Clone the repository:
git clone <repository-url>
cd CourseProject_2
  2. Create and activate a virtual environment:
python -m venv venv
source venv/bin/activate  # On Linux/Mac
# or
venv\Scripts\activate  # On Windows
  3. Install dependencies:
pip install -r requirements.txt

📦 Dataset

This project uses the MOVi-C dataset (Multi-Object Video Dataset).

Dataset Setup

  1. Download MOVi-C dataset from the official source
  2. Extract to your preferred location
  3. Update the dataset path in src/CONFIG.py:
config = {
    'data': {
        'dataset_path': '/path/to/movi_c/',
        ...
    }
}

Dataset Structure

The MOVi-C dataset should have the following structure:

movi_c/
├── train/
├── validation/
└── test/

💻 Usage

Training Autoencoder

Train the autoencoder with holistic representation:

cd src
python trainer.py --ae --scene_rep holistic

Train with object-centric representation:

python trainer.py --ae --scene_rep oc

Training Predictor

After training the autoencoder, train the predictor:

python trainer.py --predictor --scene_rep holistic \
    --ackpt experiments/01_Holistic_AE_XL/checkpoints/best_01_Holistic_AE_XL.pth

For object-centric:

python trainer.py --predictor --scene_rep oc \
    --ackpt experiments/01_OC_AE_XL_64_Full_CNN/checkpoints/best_01_OC_AE_XL_64_Full_CNN.pth

Inference

Run end-to-end video prediction:

python trainer.py --inference --scene_rep holistic \
    --ackpt path/to/autoencoder.pth \
    --pckpt path/to/predictor.pth

Command-Line Arguments

Argument      Short  Description
--ae          -a     Enable autoencoder training mode
--predictor   -p     Enable predictor training mode
--inference   -i     Enable end-to-end inference mode
--ackpt       -ac    Path to pretrained autoencoder checkpoint
--pckpt       -pc    Path to pretrained predictor checkpoint
--scene_rep   -s     Scene representation type: holistic or oc
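
For reference, these flags could be declared with argparse roughly as follows (a sketch only; the actual parser in src/trainer.py may differ):

import argparse

parser = argparse.ArgumentParser(description="Autoencoder/predictor training and inference")
parser.add_argument("-a", "--ae", action="store_true", help="enable autoencoder training mode")
parser.add_argument("-p", "--predictor", action="store_true", help="enable predictor training mode")
parser.add_argument("-i", "--inference", action="store_true", help="enable end-to-end inference mode")
parser.add_argument("-ac", "--ackpt", type=str, help="path to pretrained autoencoder checkpoint")
parser.add_argument("-pc", "--pckpt", type=str, help="path to pretrained predictor checkpoint")
parser.add_argument("-s", "--scene_rep", choices=["holistic", "oc"], default="holistic",
                    help="scene representation type")
args = parser.parse_args()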

🔬 Experiments

The project includes several experimental configurations:

Autoencoder Experiments

  1. Holistic Autoencoders:

    • 01_Holistic_AE_Base: Baseline holistic autoencoder
    • 02_Holistic_AE_XL: Extra-large holistic autoencoder
  2. Object-Centric Autoencoders:

    • 01_OC_AE_XL_64_Full_CNN: Full CNN-based OC autoencoder
    • 01_OC_AE_XL_64_Mixed_CNN_Decoder_Linear_ENCODER: Mixed architecture
    • Various linear and advanced configurations

Predictor Experiments

  1. Holistic Predictors:

    • 02_Holistic_Predictor_XL: Standard predictor
    • 03_Holistic_Predictor_XL: Improved version
    • 05_Holistic_Predictor_XL_NoResidual: Without residual connections
  2. Object-Centric Predictors:

    • 01_OC_Predictor_XL: Standard OC predictor

Experiment Outputs

Each experiment generates:

  • Checkpoints: Best and periodic model saves
  • TensorBoard Logs: Training curves, visualizations
  • Configuration Snapshots: Reproducible experiment configs

💾 Model Checkpoints

Pre-trained model checkpoints are available for download:

🔗 Download Model Checkpoints

Available Checkpoints

  • Holistic Autoencoder (Base & XL)
  • Object-Centric Autoencoder (Various configurations)
  • Holistic Predictor (Multiple versions)
  • Object-Centric Predictor

βš™οΈ Configuration

The main configuration file is src/CONFIG.py. Key parameters:

Data Configuration

'data': {
    'dataset_path': '/path/to/movi_c/',
    'batch_size': 32,
    'patch_size': 16,
    'max_objects': 11,
    'num_workers': 8,
    'image_height': 64,
    'image_width': 64,
}

Training Configuration

'training': {
    'num_epochs': 300,
    'warmup_epochs': 15,
    'early_stopping_patience': 15,
    'model_name': '01_OC_AE_XL_64_Full_CNN',
    'lr': 4e-4,
    'save_frequency': 25,
    'use_scheduler': True,
    'use_early_stopping': True,
    'use_transforms': False,
    'use_amp': True,  # Mixed precision training
}
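
The use_amp flag corresponds to standard PyTorch automatic mixed precision. A minimal sketch of such a training step (illustrative only, not the exact loop in baseTrainer.py):

import torch
import torch.nn as nn

model = nn.Linear(256, 256).cuda()                 # stand-in for the real network
optimizer = torch.optim.Adam(model.parameters(), lr=4e-4)
scaler = torch.cuda.amp.GradScaler()

def train_step(inputs, targets):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():                # run the forward pass in mixed precision
        loss = nn.functional.mse_loss(model(inputs), targets)
    scaler.scale(loss).backward()                  # scale the loss to avoid gradient underflow
    scaler.step(optimizer)                         # unscale gradients and apply the update
    scaler.update()
    return loss.item()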

Model Configuration

'vit_cfg': {
    'encoder_embed_dim': 256,
    'decoder_embed_dim': 192,
    'num_heads': 8,
    'mlp_size': 1024,
    'encoder_depth': 12,
    'decoder_depth': 8,
    'predictor_depth': 8,
    'num_preds': 5,
    'predictor_window_size': 5,
    'use_masks': True,
    'use_bboxes': False,
    'residual': True,
}

📊 Results

Reconstruction Quality

The models achieve high-quality video frame reconstruction:

  • Holistic Models: Capture global scene structure effectively
  • Object-Centric Models: Better at preserving individual object details

Visualization

View results in the Jupyter notebook:

cd src
jupyter lab ocvp.ipynb

The notebook includes:

  • Training/validation loss curves
  • Reconstruction visualizations
  • Prediction quality analysis
  • Comparison between holistic and object-centric approaches

TensorBoard

Monitor training progress:

tensorboard --logdir src/experiments/[experiment_name]/tboard_logs

🎓 Citation

If you use this code in your research, please cite:

@misc{video_prediction_ocvp,
  title={Evaluating Image Representations for Video Prediction},
  author={Your Name},
  year={2025},
  howpublished={\url{https://github.com/your-repo}}
}

Note: This is a course project on Video Prediction with Object Representations. See the docs/ folder for project reports and lab notebook examples.