A comprehensive implementation of video prediction models using hybrid Transformer and CNN architectures for both holistic and object-centric scene representations. This project explores different approaches to learning and predicting future video frames on the MOVi-C dataset.
- Overview
- Features
- Project Structure
- Architecture
- Dataset
- Usage
- Experiments
- Model Checkpoints
- Configuration
- Results
- Citation
## Overview

This project implements a two-stage video prediction pipeline:
- Stage 1 - Autoencoder Training: Learn compressed representations of video frames
- Stage 2 - Predictor Training: Predict future frame representations in latent space
The framework supports two distinct scene representation approaches:
- Holistic Representation: Treats the entire scene as a unified entity
- Object-Centric (OC) Representation: Decomposes scenes into individual objects using masks/bounding boxes
## Features

- Two-Stage Training Pipeline: Separate autoencoder and predictor training phases
- Dual Scene Representations: Support for both holistic and object-centric approaches
- Transformer-Based Architecture: Modern attention-based encoders and decoders
- Flexible Configuration: Easy-to-modify configuration system
- Comprehensive Logging: TensorBoard integration with visualization support
- Mixed Precision Training: Efficient GPU utilization with AMP support
- Early Stopping & Scheduling: Automatic training optimization
- Checkpoint Management: Automatic model saving and loading
## Project Structure

```
CourseProject_2/
├── src/
│   ├── base/                      # Base classes
│   │   ├── baseTrainer.py         # Base trainer implementation
│   │   └── baseTransformer.py     # Base transformer blocks
│   ├── datalib/                   # Data loading and processing
│   │   ├── MoviC.py               # MOVi-C dataset class
│   │   ├── load_data.py           # Data loading utilities
│   │   └── transforms.py          # Data augmentation
│   ├── model/                     # Model architectures
│   │   ├── ocvp.py                # Main model definitions (TransformerAutoEncoder, TransformerPredictor, OCVP)
│   │   ├── holistic_encoder.py    # Holistic encoder (patch-based)
│   │   ├── holistic_decoder.py    # Holistic decoder
│   │   ├── holistic_predictor.py  # Holistic predictor
│   │   ├── oc_encoder.py          # Object-centric encoder (CNN + Transformer)
│   │   ├── oc_decoder.py          # Object-centric decoder (Transformer + CNN)
│   │   ├── oc_predictor.py        # Object-centric predictor
│   │   ├── predictor_wrapper.py   # Autoregressive wrapper with sliding window
│   │   └── model_utils.py         # Model utilities (TransformerBlock, Patchifier, etc.)
│   ├── utils/                     # Utility functions
│   │   ├── logger.py              # Logging utilities
│   │   ├── metrics.py             # Evaluation metrics
│   │   ├── utils.py               # General utilities
│   │   └── visualization.py       # Visualization tools
│   ├── experiments/               # Experiment outputs
│   │   └── [experiment_name]/
│   │       ├── checkpoints/       # Model checkpoints
│   │       ├── config/            # Experiment config
│   │       └── tboard_logs/       # TensorBoard logs
│   ├── CONFIG.py                  # Global configuration
│   ├── trainer.py                 # Training entry point
│   └── ocvp.ipynb                 # Analysis notebook
├── docs/                          # Documentation and reports
├── requirements.txt               # Python dependencies
└── README.md                      # This file
```
## Architecture

The object-centric model uses a hybrid Transformer + CNN architecture:

CNN Advantages:

- Inductive Bias: Built-in understanding of spatial locality and translation invariance
- Efficient Downsampling: Reduces 64×64 images to compact 256-D vectors
- Parameter Efficiency: Fewer parameters than fully linear projections
- Better Image Reconstruction: ConvTranspose layers naturally upsample spatial features

Transformer Advantages:

- Temporal Modeling: Captures long-range dependencies across time
- Object Relationships: Models interactions between multiple objects
- Attention Mechanism: Learns which objects/features are important
- Flexible Context: Handles a variable number of objects and temporal sequences

Combined Benefits:

- CNNs handle spatial features (what objects look like)
- Transformers handle temporal dynamics (how objects move and interact)
- Best of both worlds: local spatial structure + global temporal reasoning
1. Encoder (`HolisticEncoder`/`ObjectCentricEncoder`)
   - Holistic: Patchifies input images (16×16 patches) → linear projection → Transformer
   - Object-Centric: CNN encoder + Transformer hybrid architecture
     - CNN Feature Extraction: 3-layer ConvNet downsampler
       - Conv2d(3→64): 64×64 → 32×32
       - Conv2d(64→128): 32×32 → 16×16
       - Conv2d(128→256): 16×16 → 8×8
       - Linear: flatten → 256-D embedding
     - Extracts per-object features from masks/bboxes (up to 11 objects)
     - Transformer processes object tokens across time
   - Configurable depth (12 layers default)
   - Embedding dimension: 256
   - Multi-head attention (8 heads)
   - MLP size: 1024
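The CNN downsampler listed above can be sketched in PyTorch. This is a minimal sketch, not the project's actual module: the class name is hypothetical, and kernel size 4 / stride 2 / padding 1 are assumptions chosen to match the halving of spatial resolution shown above.

```python
import torch
import torch.nn as nn

class CNNFeatureExtractor(nn.Module):
    """Downsamples a 64x64 RGB object crop to a 256-D token.
    Each Conv2d(k=4, s=2, p=1) halves the resolution: 64 -> 32 -> 16 -> 8."""
    def __init__(self, embed_dim: int = 256):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),    # 64x64 -> 32x32
            nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 32x32 -> 16x16
            nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), # 16x16 -> 8x8
            nn.BatchNorm2d(256), nn.ReLU(inplace=True),
        )
        self.proj = nn.Linear(256 * 8 * 8, embed_dim)  # flatten -> 256-D token

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [N, 3, 64, 64], where N folds the batch, time, and object axes
        feats = self.conv(x).flatten(1)  # [N, 256*8*8]
        return self.proj(feats)          # [N, embed_dim]

tokens = CNNFeatureExtractor()(torch.randn(2, 3, 64, 64))
print(tokens.shape)  # torch.Size([2, 256])
```

In the full model each of the (up to 11) per-object crops would pass through this extractor before the resulting tokens are handed to the Transformer.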
2. Decoder (`HolisticDecoder`/`ObjectCentricDecoder`)
   - Holistic: Transformer → linear projection → unpatchify to image
   - Object-Centric: Transformer + CNN hybrid architecture
     - Transformer processes latent object representations
     - CNN Upsampling Decoder: 3-layer ConvTranspose
       - Linear: 192-D → 128×8×8 feature map
       - ConvTranspose2d(128→64): 8×8 → 16×16
       - ConvTranspose2d(64→32): 16×16 → 32×32
       - ConvTranspose2d(32→3): 32×32 → 64×64 RGB
       - Tanh activation for [-1, 1] output range
     - Combines per-object frames back into the full scene
   - Configurable depth (8 layers default)
   - Embedding dimension: 192
   - Mixed loss: MSE (0.8) + L1 (0.2)
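The upsampling path and the mixed reconstruction loss can be sketched as follows. This is a minimal sketch under the shapes and weights listed above; the class and function names are hypothetical, and the ConvTranspose hyperparameters (k=4, s=2, p=1) are assumptions that reproduce the doubling of resolution at each stage.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNNUpsamplingDecoder(nn.Module):
    """Maps a 192-D latent back to a 64x64 RGB frame.
    Each ConvTranspose2d(k=4, s=2, p=1) doubles the resolution: 8 -> 16 -> 32 -> 64."""
    def __init__(self, embed_dim: int = 192):
        super().__init__()
        self.proj = nn.Linear(embed_dim, 128 * 8 * 8)  # latent -> 128x8x8 feature map
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),  # 8x8 -> 16x16
            nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),   # 16x16 -> 32x32
            nn.BatchNorm2d(32), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),    # 32x32 -> 64x64 RGB
            nn.Tanh(),                             # output in [-1, 1]
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        x = self.proj(z).view(-1, 128, 8, 8)
        return self.deconv(x)  # [N, 3, 64, 64]

def mixed_reconstruction_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Weighted mix of MSE (overall structure) and L1 (sharpness), 0.8 / 0.2
    return 0.8 * F.mse_loss(pred, target) + 0.2 * F.l1_loss(pred, target)
```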
3. Predictor (`HolisticTransformerPredictor`/`ObjectCentricTransformerPredictor`)
   - Predicts future latent representations autoregressively
   - Transformer-based temporal modeling
   - Configurable depth (8 layers default)
   - Embedding dimension: 192
   - Optional residual connections
4. Predictor Wrapper (`PredictorWrapper`)
   - Autoregressive Prediction: Iteratively predicts future frames
   - Sliding Window Mechanism: Maintains a buffer of size 5
     - Concatenates new predictions to the input buffer
     - Drops the oldest frames when the buffer exceeds the window size
   - Training Strategy:
     - Random temporal slicing for data augmentation
     - Per-step loss computation with temporal consistency
   - Combined Loss Function:
     - MSE loss (0.6): Overall structure
     - L1 loss (0.2): Sharpness and sparsity
     - Cosine similarity loss (0.2): Feature alignment
   - Generates 5 future frame predictions per forward pass
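The sliding-window rollout and combined loss described above can be sketched as below. The predictor interface (buffer in, one next-step latent out) and the function names are assumptions made for illustration, not the repository's actual API.

```python
import torch
import torch.nn.functional as F

def predictor_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Combined latent-space loss: 0.6*MSE + 0.2*L1 + 0.2*(1 - cosine similarity)."""
    cos = 1.0 - F.cosine_similarity(pred.flatten(1), target.flatten(1), dim=-1).mean()
    return 0.6 * F.mse_loss(pred, target) + 0.2 * F.l1_loss(pred, target) + 0.2 * cos

def rollout(predictor, seed_latents: torch.Tensor, num_preds: int = 5, window: int = 5):
    """Autoregressive rollout with a sliding window.
    `predictor` is assumed to map a [B, W, ...] buffer to the next latent [B, 1, ...]."""
    buffer = seed_latents[:, -window:]  # keep only the most recent `window` steps
    preds = []
    for _ in range(num_preds):
        nxt = predictor(buffer)                                # predict next frame latent
        preds.append(nxt)
        buffer = torch.cat([buffer, nxt], dim=1)[:, -window:]  # append, drop oldest
    return torch.cat(preds, dim=1)  # [B, num_preds, ...]

# Toy predictor: repeat the last latent (stand-in for the transformer predictor)
toy = lambda buf: buf[:, -1:].clone()
future = rollout(toy, torch.randn(2, 24, 11, 256))
print(future.shape)  # torch.Size([2, 5, 11, 256])
```

During training, `predictor_loss` would be applied per step against the ground-truth latents of the sliced clip.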
```
Input Video Frames → Encoder → Latent Representation → Predictor → Future Latent → Decoder → Predicted Frames
```

```
┌──────────────────────────────────────────────────────────────────┐
│                      OBJECT-CENTRIC ENCODER                      │
└──────────────────────────────────────────────────────────────────┘
  Input: Video [B, T, 3, 64, 64] + Masks [B, T, 64, 64]
      ↓
  Object Extraction (11 objects max)
      → Object Frames: [B, T, 11, 3, 64, 64]
      ↓
  CNN Feature Extractor (per object):
    • Conv2d(3→64, k=4, s=2) + BatchNorm + ReLU     [64×64 → 32×32]
    • Conv2d(64→128, k=4, s=2) + BatchNorm + ReLU   [32×32 → 16×16]
    • Conv2d(128→256, k=4, s=2) + BatchNorm + ReLU  [16×16 → 8×8]
    • Flatten + Linear(256·8·8 → 256)
      → Object Tokens: [B, T, 11, 256]
      ↓
  Transformer Encoder (12 layers):
    • Positional Encoding
    • Multi-Head Attention (8 heads, dim=256)
    • MLP (dim=1024)
    • Layer Normalization
      → Latent: [B, T, 11, 256]
```

```
┌──────────────────────────────────────────────────────────────────┐
│                       PREDICTOR + WRAPPER                        │
└──────────────────────────────────────────────────────────────────┘
  Input Latent: [B, T=24, 11, 256]
      ↓
  PredictorWrapper (autoregressive):
    • Random temporal slice (5 frames)
    • Sliding window buffer (size=5)
      ↓
  Transformer Predictor (8 layers):
    • Linear(256 → 192)
    • Transformer blocks (depth=8)
    • Linear(192 → 256)
    • Optional residual connections
      ↓
  Autoregressive Loop (5 predictions):
    For t in 1..5:
      • Predict next frame
      • Append to buffer, shift window
      • Compute loss (MSE + L1 + Cosine)
      → Future Latent: [B, 5, 11, 256]
```

```
┌──────────────────────────────────────────────────────────────────┐
│                      OBJECT-CENTRIC DECODER                      │
└──────────────────────────────────────────────────────────────────┘
  Input Latent: [B, T, 11, 256]
      ↓
  Transformer Decoder (8 layers):
    • Linear(256 → 192)
    • Positional Encoding
    • Transformer blocks (depth=8)
    • Layer Normalization
      → [B, T, 11, 192]
      ↓
  CNN Upsampling Decoder (per object):
    • Linear(192 → 128·8·8) + Reshape to [128, 8, 8]
    • ConvTranspose2d(128→64, k=4, s=2) + BatchNorm + ReLU  [8×8 → 16×16]
    • ConvTranspose2d(64→32, k=4, s=2) + BatchNorm + ReLU   [16×16 → 32×32]
    • ConvTranspose2d(32→3, k=4, s=2) + Tanh                [32×32 → 64×64]
      → Per-Object Frames: [B, T, 11, 3, 64, 64]
      ↓
  Object Composition:
    • Sum all object frames: Σ(objects)
    • Normalize: (x + 1) / 2 (from [-1, 1] to [0, 1])
    • Clamp to [0, 1]
      → Reconstructed Video: [B, T, 3, 64, 64]
```
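The object-composition step at the end of the decoder reduces the per-object renderings to a single video. A minimal sketch (the function name is an assumption):

```python
import torch

def compose_objects(object_frames: torch.Tensor) -> torch.Tensor:
    """object_frames: [B, T, num_objects, 3, H, W] in [-1, 1].
    Sums the per-object renderings, rescales to [0, 1], and clamps."""
    scene = object_frames.sum(dim=2)  # sum over the object axis
    scene = (scene + 1.0) / 2.0       # [-1, 1] -> [0, 1]
    return scene.clamp(0.0, 1.0)

video = compose_objects(torch.rand(2, 4, 11, 3, 64, 64) * 2 - 1)
print(video.shape)  # torch.Size([2, 4, 3, 64, 64])
```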
## Installation

1. Clone the repository:

```bash
git clone <repository-url>
cd CourseProject_2
```

2. Create and activate a virtual environment:

```bash
python -m venv venv
source venv/bin/activate   # On Linux/Mac
# or
venv\Scripts\activate      # On Windows
```

3. Install dependencies:

```bash
pip install -r requirements.txt
```

## Dataset

This project uses the MOVi-C dataset (Multi-Object Video Dataset).
1. Download the MOVi-C dataset from the official source
2. Extract it to your preferred location
3. Update the dataset path in `src/CONFIG.py`:

```python
config = {
    'data': {
        'dataset_path': '/path/to/movi_c/',
        ...
    }
}
```

The MOVi-C dataset should have the following structure:

```
movi_c/
├── train/
├── validation/
└── test/
```
## Usage

Train the autoencoder with holistic representation:

```bash
cd src
python trainer.py --ae --scene_rep holistic
```

Train with object-centric representation:

```bash
python trainer.py --ae --scene_rep oc
```

After training the autoencoder, train the predictor:

```bash
python trainer.py --predictor --scene_rep holistic \
    --ackpt experiments/01_Holistic_AE_XL/checkpoints/best_01_Holistic_AE_XL.pth
```

For object-centric:

```bash
python trainer.py --predictor --scene_rep oc \
    --ackpt experiments/01_OC_AE_XL_64_Full_CNN/checkpoints/best_01_OC_AE_XL_64_Full_CNN.pth
```

Run end-to-end video prediction:

```bash
python trainer.py --inference --scene_rep holistic \
    --ackpt path/to/autoencoder.pth \
    --pckpt path/to/predictor.pth
```

| Argument | Short | Description |
|---|---|---|
| `--ae` | `-a` | Enable autoencoder training mode |
| `--predictor` | `-p` | Enable predictor training mode |
| `--inference` | `-i` | Enable end-to-end inference mode |
| `--ackpt` | `-ac` | Path to pretrained autoencoder checkpoint |
| `--pckpt` | `-pc` | Path to pretrained predictor checkpoint |
| `--scene_rep` | `-s` | Scene representation type: `holistic` or `oc` |
## Experiments

The project includes several experimental configurations:

- Holistic Autoencoders:
  - `01_Holistic_AE_Base`: Baseline holistic autoencoder
  - `02_Holistic_AE_XL`: Extra-large holistic autoencoder
- Object-Centric Autoencoders:
  - `01_OC_AE_XL_64_Full_CNN`: Full CNN-based OC autoencoder
  - `01_OC_AE_XL_64_Mixed_CNN_Decoder_Linear_ENCODER`: Mixed architecture
  - Various linear and advanced configurations
- Holistic Predictors:
  - `02_Holistic_Predictor_XL`: Standard predictor
  - `03_Holistic_Predictor_XL`: Improved version
  - `05_Holistic_Predictor_XL_NoResidual`: Without residual connections
- Object-Centric Predictors:
  - `01_OC_Predictor_XL`: Standard OC predictor
Each experiment generates:
- Checkpoints: Best and periodic model saves
- TensorBoard Logs: Training curves, visualizations
- Configuration Snapshots: Reproducible experiment configs
## Model Checkpoints

Pre-trained model checkpoints are available for download:

Download Model Checkpoints
- Holistic Autoencoder (Base & XL)
- Object-Centric Autoencoder (Various configurations)
- Holistic Predictor (Multiple versions)
- Object-Centric Predictor
## Configuration

The main configuration file is `src/CONFIG.py`. Key parameters:

```python
'data': {
    'dataset_path': '/path/to/movi_c/',
    'batch_size': 32,
    'patch_size': 16,
    'max_objects': 11,
    'num_workers': 8,
    'image_height': 64,
    'image_width': 64,
}
```

```python
'training': {
    'num_epochs': 300,
    'warmup_epochs': 15,
    'early_stopping_patience': 15,
    'model_name': '01_OC_AE_XL_64_Full_CNN',
    'lr': 4e-4,
    'save_frequency': 25,
    'use_scheduler': True,
    'use_early_stopping': True,
    'use_transforms': False,
    'use_amp': True,  # Mixed precision training
}
```

```python
'vit_cfg': {
    'encoder_embed_dim': 256,
    'decoder_embed_dim': 192,
    'num_heads': 8,
    'mlp_size': 1024,
    'encoder_depth': 12,
    'decoder_depth': 8,
    'predictor_depth': 8,
    'num_preds': 5,
    'predictor_window_size': 5,
    'use_masks': True,
    'use_bboxes': False,
    'residual': True,
}
```

## Results

The models achieve high-quality video frame reconstruction:
- Holistic Models: Capture global scene structure effectively
- Object-Centric Models: Better at preserving individual object details
View results in the Jupyter notebook:
```bash
cd src
jupyter lab ocvp.ipynb
```

The notebook includes:
- Training/validation loss curves
- Reconstruction visualizations
- Prediction quality analysis
- Comparison between holistic and object-centric approaches
Monitor training progress:
```bash
tensorboard --logdir src/experiments/[experiment_name]/tboard_logs
```

## Citation

If you use this code in your research, please cite:
```bibtex
@misc{video_prediction_ocvp,
  title={Evaluating Image Representations for Video Prediction},
  author={Your Name},
  year={2025},
  howpublished={\url{https://github.com/your-repo}}
}
```

Note: This is a course project on Video Prediction with Object Representations. See the `docs/` folder for project reports and lab notebook examples.
