The codebase is rough in some places, but this was our CS8803 VLM project.
A Fused Vision-Language-World Model for Temporal Reasoning
TheWorld combines Google Gemma 3 (vision-language understanding) with NVIDIA Cosmos (world dynamics modeling). It experiments with early fusion of Cosmos VAE world-model embeddings into a language model.
| Traditional VLMs | TheWorld |
|---|---|
| "What is this?" | "What is this and what happens next?" |
| Static visual understanding | Static + temporal dynamics |
| Sees current frame | Predicts future states |
Example: Given an image of a ball mid-air, TheWorld can reason about:
- Static: "A red ball in the air"
- Temporal: "The ball is falling and will hit the ground"
```bash
# 1. Clone the repository
git clone https://github.com/yourusername/theworld.git
cd theworld

# 2. One-command setup (installs dependencies + creates directories)
bash scripts/setup.sh

# 3. Set your HuggingFace token (for model downloads)
export HF_TOKEN=hf_your_token_here
```

```python
from theworld import TheWorld
from PIL import Image
import torch
# Load model with world reasoning enabled
model = TheWorld.from_pretrained(
    "google/gemma-3-4b-it",
    enable_world=True,
    dtype=torch.bfloat16,
    device_map="auto"
)

# Prepare your image and question
image = Image.open("example.jpg")
messages = [{
    "role": "user",
    "content": [
        {"type": "image", "image": image},
        {"type": "text", "text": "What will happen next in this scene?"}
    ]
}]

# Generate response (standard HuggingFace interface)
inputs = model.processor.apply_chat_template(
    messages, tokenize=True, return_dict=True, return_tensors="pt"
).to("cuda")
outputs = model.generate(**inputs, max_new_tokens=100)
response = model.processor.decode(outputs[0], skip_special_tokens=True)
print(response)
```

That's it! TheWorld uses the standard HuggingFace API: if you know how to use Gemma3, you know how to use TheWorld.
Inference example:
```bash
export HF_TOKEN=hf_your_token_here
PYTHONPATH=python:$PYTHONPATH uv run python examples/inference.py
```

Test baseline equivalence (TheWorld with `enable_world=False` == pure Gemma3):

```bash
PYTHONPATH=python:$PYTHONPATH uv run pytest tests/test_baseline_equivalence.py -v
```

Automated Setup (Recommended)
```bash
bash scripts/setup.sh            # Full setup with dev tools
bash scripts/setup.sh --skip-dev # Skip dev dependencies (faster)
```

The setup script:
- ✅ Installs core dependencies with uv
- ✅ Installs Cosmos safety checker
- ✅ Creates necessary directories
- ✅ Verifies installation
Manual Setup
```bash
# Install with uv (recommended package manager)
uv sync --dev

# Install Cosmos safety checker
uv pip install cosmos_guardrail

# Verify installation
make smoke-test
```

- Fused Design: Gemma 3 (4B) + Cosmos (2B) connected via learnable projection layers
- Standard HF Interface: Drop-in replacement for Gemma3 with the same `from_pretrained()` and `generate()` API
- Validated: Logits are numerically identical to pure Gemma3 when `enable_world=False`
- Parameter Efficient: Train only 0.07% of parameters by default (2.9M out of 4.3B); see the parameter-count sketch after this list
- Flexible Unfreezing: Choose which components to train (projection, vision, language, world)
- Multi-Stage Training: Start fast with projections, progressively unfreeze as needed
- Enable/Disable World Model: Compare with and without temporal reasoning
- Component Freezing: Fine-grained control over which parts train
- Baseline Comparison: Perfect Gemma3 baseline for ablation studies
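To verify the parameter-efficiency claim yourself, count trainable parameters after loading the model with your chosen freeze flags. This is a minimal sketch that only assumes TheWorld exposes standard `torch.nn.Module` parameters; the freeze flags are the same ones used in the examples below:

```python
import torch
from theworld import TheWorld

# Load with only the projection layers trainable (same freeze flags
# as the "train only projection layers" example below).
model = TheWorld.from_pretrained(
    "google/gemma-3-4b-it",
    enable_world=True,
    freeze_gemma_vision=True,
    freeze_gemma_language=True,
    freeze_cosmos_vae=True,
    dtype=torch.bfloat16,
    device_map="auto",
)

# Count trainable vs. total parameters via requires_grad.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable / 1e6:.1f}M of {total / 1e9:.2f}B "
      f"({100 * trainable / total:.2f}%)")
```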
```python
from theworld import TheWorld
import torch

# Load with world model enabled
model = TheWorld.from_pretrained(
    "google/gemma-3-4b-it",
    enable_world=True,
    dtype=torch.bfloat16,
    device_map="auto"
)
```

```python
# Perfect Gemma3 baseline for comparison
model = TheWorld.from_pretrained(
    "google/gemma-3-4b-it",
    enable_world=False,  # No world model
    dtype=torch.bfloat16,
    device_map="auto"
)
# This produces identical outputs to pure Gemma3
```

```python
# Train only projection layers (fastest, 0.07% params)
model = TheWorld.from_pretrained(
    "google/gemma-3-4b-it",
    enable_world=True,
    freeze_gemma_vision=True,    # Freeze SigLIP (346M)
    freeze_gemma_language=True,  # Freeze Gemma LLM (3.95B)
    freeze_cosmos_vae=True,      # Freeze Cosmos VAE
    dtype=torch.bfloat16,
    device_map="auto"
)

# Train vision + projection (domain adaptation)
model = TheWorld.from_pretrained(
    "google/gemma-3-4b-it",
    enable_world=True,
    freeze_gemma_vision=False,   # Train SigLIP
    freeze_gemma_language=True,
    freeze_cosmos_vae=True,
    dtype=torch.bfloat16,
    device_map="auto"
)

# Train language + projection (task-specific generation)
model = TheWorld.from_pretrained(
    "google/gemma-3-4b-it",
    enable_world=True,
    freeze_gemma_vision=True,
    freeze_gemma_language=False,  # Train Gemma LLM
    freeze_cosmos_vae=True,
    dtype=torch.bfloat16,
    device_map="auto"
)
```

```python
# Use different Cosmos variant
model = TheWorld.from_pretrained(
    "google/gemma-3-4b-it",
    enable_world=True,
    cosmos_model_name="nvidia/Cosmos-Predict2-2B-Video2World",
    dtype=torch.bfloat16,
    device_map="auto"
)
```
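Training is a standard PyTorch loop in which the optimizer only ever sees the unfrozen parameters; the real entry point is scripts/train_hf.py, launched as shown below. The following is a conceptual sketch only, and assumes the model's forward pass accepts a tokenized batch with labels and returns a loss, as standard HF causal-LM models do:

```python
import torch
from theworld import TheWorld

model = TheWorld.from_pretrained(
    "google/gemma-3-4b-it",
    enable_world=True,
    dtype=torch.bfloat16,
    device_map="auto",
)

# Only parameters left unfrozen (by default, the projection layers)
# receive gradient updates.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)

def train_step(batch):
    """One illustrative step; `batch` is assumed to contain tokenized
    inputs (input_ids, attention_mask, pixel values, labels)."""
    outputs = model(**batch)  # assumes an HF-style forward that returns .loss
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
```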
Local/Interactive:

```bash
# Multi-GPU with Accelerate
accelerate launch --config_file configs/accelerate/multi_gpu_ddp.yaml \
    scripts/train_hf.py --config configs/spatial_rgpt_training.json
```

SLURM (HPC Clusters):
See the detailed guides linked in the documentation list below.
```
Input Image ──→ [Gemma Vision (SigLIP)] ──→ Vision Tokens (256)
     │
     └──→ [Cosmos VAE Encoder] ──→ World Latents (16-dim)
                                          ↓
                  [Projection 16→2304] ──→ World Tokens (784)
                                          ↓
        Combined: [BOS, SOW, WORLD×784, EOW, TEXT, IMAGE×256]
                                          ↓
                [Gemma Language Model] ──→ Output Logits
```
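For intuition about the fusion step, here is a minimal, self-contained sketch of how the projection and token assembly could look. It uses only the shapes from the diagram above (16-dim world latents, a 16→2304 projection, 784 world tokens, 256 vision tokens); the module and variable names are illustrative, not the actual implementation.

```python
import torch
import torch.nn as nn

# Illustrative shapes taken from the diagram above.
WORLD_LATENT_DIM = 16    # Cosmos VAE latent channels
GEMMA_HIDDEN_DIM = 2304  # Gemma embedding width (projection target)
NUM_WORLD_TOKENS = 784   # world latent positions flattened into tokens
NUM_VISION_TOKENS = 256  # SigLIP tokens per image

class WorldProjection(nn.Module):
    """Hypothetical learnable projection: Cosmos latents -> Gemma embedding space."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(WORLD_LATENT_DIM, GEMMA_HIDDEN_DIM)

    def forward(self, world_latents: torch.Tensor) -> torch.Tensor:
        # world_latents: (batch, NUM_WORLD_TOKENS, WORLD_LATENT_DIM)
        # returns:       (batch, NUM_WORLD_TOKENS, GEMMA_HIDDEN_DIM)
        return self.proj(world_latents)

# Early fusion: concatenate world tokens with the text/image embeddings
# in the order shown above ([BOS, SOW, WORLD, EOW, TEXT, IMAGE]).
batch = 1
world_latents = torch.randn(batch, NUM_WORLD_TOKENS, WORLD_LATENT_DIM)
world_tokens = WorldProjection()(world_latents)

text_embeds = torch.randn(batch, 10, GEMMA_HIDDEN_DIM)  # placeholder text embeddings
image_tokens = torch.randn(batch, NUM_VISION_TOKENS, GEMMA_HIDDEN_DIM)
bos = torch.randn(batch, 1, GEMMA_HIDDEN_DIM)            # BOS placeholder
sow, eow = torch.randn_like(bos), torch.randn_like(bos)  # SOW / EOW placeholders

fused = torch.cat([bos, sow, world_tokens, eow, text_embeds, image_tokens], dim=1)
print(fused.shape)  # (1, 2 + 784 + 1 + 10 + 256, 2304)
```

In the real model the special-token embeddings (BOS, SOW, EOW) would come from the tokenizer's embedding table rather than random vectors; they are random here only to keep the sketch self-contained.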
- CLAUDE.md - Complete development guide and API reference
- Training Guide - Training configuration and best practices
- Logit Validation - Initialization investigation and solution
- Hub Upload Guide - Publishing models to HuggingFace Hub
- Multi-Stage Training - Progressive training workflow
- Evaluation Guide - Evaluation on BLINK benchmark
- Google Gemma 3: Vision-language foundation
- NVIDIA Cosmos: World model foundation
- HuggingFace Transformers: Model infrastructure