diff --git a/README.md b/README.md index da52d8f..374e012 100644 --- a/README.md +++ b/README.md @@ -192,6 +192,48 @@ For GPUs with more than 40GB of GPU memory, **e.g., H100, please use the unquant Interactive inference via the terminal is available at `turbodiffusion/serve/`. This allows multi-turn video generation without reloading the model. +### Memory Optimization: Pre-caching T5 Embeddings + +The umT5-XXL text encoder requires ~11GB VRAM, which can cause OOM on 32GB GPUs when combined with the DiT models. To avoid this, you can pre-cache text embeddings in a separate pass: + +**Memory Comparison:** +| Approach | Peak VRAM | Notes | +|----------|-----------|-------| +| Standard (T5 + DiT) | ~30GB+ | May OOM on 32GB GPUs | +| Cached embeddings | ~18GB | T5 never loaded during inference | + +**Step 1: Cache the embedding (loads T5, encodes prompt, saves to file, unloads T5)** +```bash +python scripts/cache_t5.py \ + --prompt "slow head turn, cinematic" \ + --output cached_embeddings.pt +``` + +**Step 2: Run inference with cached embedding (T5 never loaded)** +```bash +python turbodiffusion/inference/wan2.2_i2v_infer.py \ + --cached_embedding cached_embeddings.pt \ + --skip_t5 \ + --model Wan2.2-A14B \ + --low_noise_model_path checkpoints/TurboWan2.2-I2V-A14B-low-720P-quant.pth \ + --high_noise_model_path checkpoints/TurboWan2.2-I2V-A14B-high-720P-quant.pth \ + --image_path your_image.jpg \ + --prompt "slow head turn, cinematic" \ + --quant_linear --attention_type sagesla --ode +``` + +You can cache multiple prompts at once: +```bash +# Create a prompts file +echo "slow head turn, cinematic" > prompts.txt +echo "walking forward, dramatic lighting" >> prompts.txt + +# Cache all prompts +python scripts/cache_t5.py --prompts_file prompts.txt --output my_prompts.pt +``` + +The cached file is only ~4MB per prompt, compared to the 11GB T5 model. + ## Evaluation diff --git a/install.sh b/install.sh new file mode 100755 index 0000000..dfaac0f --- /dev/null +++ b/install.sh @@ -0,0 +1,310 @@ +#!/bin/bash +# TurboDiffusion Installation Script +# For RTX 5090 (Blackwell) with CUDA 13.0 + +set -e # Exit on error + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +echo "==============================================" +echo "TurboDiffusion Installation Script" +echo "==============================================" +echo "" + +# ============================================================================= +# Check for Miniconda +# ============================================================================= +check_conda() { + if command -v conda &> /dev/null; then + echo "✅ Conda found: $(conda --version)" + return 0 + fi + + # Check common install locations + for conda_path in ~/miniconda3/bin/conda ~/anaconda3/bin/conda /opt/conda/bin/conda; do + if [ -f "$conda_path" ]; then + echo "✅ Found conda at: $conda_path" + eval "$($conda_path shell.bash hook)" + return 0 + fi + done + + return 1 +} + +install_miniconda() { + echo "" + echo "❌ Conda/Miniconda not found!" + echo "" + echo "Please install Miniconda first:" + echo "" + echo " # Download Miniconda (Linux x86_64)" + echo " wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" + echo "" + echo " # Install (follow prompts)" + echo " bash Miniconda3-latest-Linux-x86_64.sh" + echo "" + echo " # Restart shell or run:" + echo " source ~/.bashrc" + echo "" + echo " # Then re-run this script" + echo "" + + read -p "Would you like to download and install Miniconda now? 
[y/N] " -n 1 -r + echo + if [[ $REPLY =~ ^[Yy]$ ]]; then + echo "Downloading Miniconda..." + wget -q --show-progress https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh + + echo "Installing Miniconda to ~/miniconda3..." + bash /tmp/miniconda.sh -b -p ~/miniconda3 + + echo "Initializing conda..." + ~/miniconda3/bin/conda init bash + eval "$(~/miniconda3/bin/conda shell.bash hook)" + + rm /tmp/miniconda.sh + echo "✅ Miniconda installed!" + return 0 + else + exit 1 + fi +} + +if ! check_conda; then + install_miniconda +fi + +# Source conda for current shell +if [ -f ~/miniconda3/etc/profile.d/conda.sh ]; then + source ~/miniconda3/etc/profile.d/conda.sh +elif [ -f ~/anaconda3/etc/profile.d/conda.sh ]; then + source ~/anaconda3/etc/profile.d/conda.sh +fi + +# ============================================================================= +# Check for CUDA +# ============================================================================= +echo "" +echo "Checking CUDA..." + +if ! command -v nvcc &> /dev/null; then + echo "⚠️ nvcc not found in PATH" + # Check common locations + for cuda_path in /usr/local/cuda-13.0 /usr/local/cuda-12.9 /usr/local/cuda; do + if [ -f "$cuda_path/bin/nvcc" ]; then + echo " Found CUDA at: $cuda_path" + export PATH="$cuda_path/bin:$PATH" + export LD_LIBRARY_PATH="$cuda_path/lib64:$LD_LIBRARY_PATH" + break + fi + done +fi + +if command -v nvcc &> /dev/null; then + CUDA_VERSION=$(nvcc --version | grep "release" | sed 's/.*release \([0-9]*\.[0-9]*\).*/\1/') + echo "✅ CUDA version: $CUDA_VERSION" +else + echo "❌ CUDA not found. Please install CUDA 13.0 for RTX 5090 support." + exit 1 +fi + +# Check GPU +if command -v nvidia-smi &> /dev/null; then + GPU_NAME=$(nvidia-smi --query-gpu=name --format=csv,noheader 2>/dev/null | head -1) + GPU_MEMORY=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader 2>/dev/null | head -1) + echo "✅ GPU: $GPU_NAME ($GPU_MEMORY)" +fi + +# ============================================================================= +# Create/Activate Conda Environment +# ============================================================================= +ENV_NAME="turbodiffusion" + +echo "" +echo "Setting up conda environment: $ENV_NAME" + +if conda env list | grep -q "^$ENV_NAME "; then + echo " Environment '$ENV_NAME' already exists" + read -p " Recreate environment? [y/N] " -n 1 -r + echo + if [[ $REPLY =~ ^[Yy]$ ]]; then + echo " Removing existing environment..." + conda env remove -n $ENV_NAME -y + echo " Creating fresh environment..." + conda create -n $ENV_NAME python=3.12 -y + fi +else + echo " Creating new environment with Python 3.12..." + conda create -n $ENV_NAME python=3.12 -y +fi + +echo " Activating environment..." +conda activate $ENV_NAME + +echo "✅ Python: $(python --version)" + +# ============================================================================= +# Install PyTorch with CUDA 13.0 (Nightly for Blackwell support) +# ============================================================================= +echo "" +echo "Installing PyTorch with CUDA 13.0 support..." 
+echo " (Nightly build required for RTX 5090/Blackwell)" + +pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu130 + +# Verify PyTorch installation +python -c "import torch; print(f'✅ PyTorch {torch.__version__}'); print(f' CUDA available: {torch.cuda.is_available()}'); print(f' CUDA version: {torch.version.cuda}')" || { + echo "❌ PyTorch installation failed" + exit 1 +} + +# ============================================================================= +# Install Dependencies +# ============================================================================= +echo "" +echo "Installing dependencies..." + +pip install psutil + +# ============================================================================= +# Initialize Git Submodules (CUTLASS) +# ============================================================================= +echo "" +echo "Initializing git submodules (CUTLASS)..." + +if [ -d ".git" ]; then + git submodule update --init --recursive + echo "✅ Submodules initialized" +else + echo "⚠️ Not a git repository, checking if CUTLASS exists..." + if [ ! -f "turbodiffusion/ops/cutlass/include/cutlass/cutlass.h" ]; then + echo "❌ CUTLASS not found. Please clone with: git clone --recursive " + exit 1 + fi +fi + +# Verify CUTLASS headers +if [ ! -f "turbodiffusion/ops/cutlass/include/cutlass/cutlass.h" ]; then + echo "❌ CUTLASS headers not found after submodule init" + exit 1 +fi +echo "✅ CUTLASS headers verified" + +# ============================================================================= +# Build and Install TurboDiffusion +# ============================================================================= +echo "" +echo "Building TurboDiffusion..." +echo " Compiling CUDA kernels for: sm_80, sm_89, sm_90, sm_120a (Blackwell)" +echo " This may take several minutes..." +echo "" + +# Clean previous builds if requested +if [ "$1" == "--clean" ]; then + echo "Cleaning previous builds..." + rm -rf build/ dist/ *.egg-info/ + find . -name "*.so" -path "*/turbodiffusion/*" -delete 2>/dev/null || true +fi + +pip install -e . --no-build-isolation 2>&1 | tee build.log + +# ============================================================================= +# Create Module Symlinks (for inference scripts) +# ============================================================================= +echo "" +echo "Creating module symlinks..." + +# The inference scripts import from top-level (e.g., 'from imaginaire.utils.io') +# but modules are inside turbodiffusion/. Create symlinks at repo root. +cd "$SCRIPT_DIR" + +for module in imaginaire rcm ops SLA; do + if [ -d "turbodiffusion/$module" ]; then + if [ ! -L "$module" ]; then + ln -sf "turbodiffusion/$module" "$module" + echo " Created symlink: $module -> turbodiffusion/$module" + else + echo " Symlink exists: $module" + fi + fi +done + +# Verify symlinks work +python -c " +import sys +sys.path.insert(0, '.') +from imaginaire.utils.io import save_image_or_video +from rcm.datasets.utils import VIDEO_RES_SIZE_INFO +from ops import FastLayerNorm, FastRMSNorm, Int8Linear +from SLA import SparseLinearAttention, SageSparseLinearAttention +print('✅ All module imports working') +" || echo "⚠️ Some imports failed - check symlinks" + +# ============================================================================= +# Install SpargeAttn (Sparse Attention for efficiency) +# ============================================================================= +echo "" +echo "Installing SpargeAttn..." 
+ +# Get GPU compute capability +GPU_ARCH=$(python -c "import torch; cc = torch.cuda.get_device_capability(); print(f'{cc[0]}.{cc[1]}')" 2>/dev/null || echo "8.0") +echo " Detected GPU compute capability: $GPU_ARCH" + +# Clone, patch for Blackwell (sm_120) if needed, and install +SPARGE_TMP="/tmp/SpargeAttn_build_$$" +rm -rf "$SPARGE_TMP" +git clone --depth 1 https://github.com/thu-ml/SpargeAttn.git "$SPARGE_TMP" + +# Add sm_120 (Blackwell) support if not already present +if grep -q '"12.0"' "$SPARGE_TMP/setup.py"; then + echo " SpargeAttn already supports sm_120" +else + echo " Patching SpargeAttn for Blackwell (sm_120) support..." + sed -i 's/SUPPORTED_ARCHS = {"8.0", "8.6", "8.7", "8.9", "9.0"}/SUPPORTED_ARCHS = {"8.0", "8.6", "8.7", "8.9", "9.0", "12.0"}/' "$SPARGE_TMP/setup.py" +fi + +cd "$SPARGE_TMP" +TORCH_CUDA_ARCH_LIST="$GPU_ARCH" pip install -e . --no-build-isolation +cd "$SCRIPT_DIR" +rm -rf "$SPARGE_TMP" + +# ============================================================================= +# Verify Installation +# ============================================================================= +echo "" +echo "Verifying installation..." + +python -c " +import torch +import turbo_diffusion_ops +print('✅ turbo_diffusion_ops loaded') +print(' Available ops:', [x for x in dir(turbo_diffusion_ops) if not x.startswith('_')]) + +try: + import spas_sage_attn + print('✅ SpargeAttn (spas_sage_attn) loaded') +except ImportError: + print('⚠️ SpargeAttn not available (optional)') + +print() +print('GPU Info:') +if torch.cuda.is_available(): + print(f' Device: {torch.cuda.get_device_name(0)}') + print(f' VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB') + print(f' Compute Capability: {torch.cuda.get_device_capability(0)}') +" + +echo "" +echo "==============================================" +echo "✅ Installation complete!" +echo "==============================================" +echo "" +echo "Usage:" +echo " conda activate $ENV_NAME" +echo " python -c 'import turbodiffusion'" +echo "" +echo "To run the TUI server:" +echo " python -m turbodiffusion.tui_serve" +echo "" diff --git a/scripts/cache_t5.py b/scripts/cache_t5.py new file mode 100644 index 0000000..1bf8c2e --- /dev/null +++ b/scripts/cache_t5.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python +""" +Pre-cache T5 text embeddings to avoid loading the 11GB model during inference. + +Usage: + # Cache a single prompt + python scripts/cache_t5.py --prompt "slow head turn, cinematic" --output cached_embeddings.pt + + # Cache multiple prompts from file + python scripts/cache_t5.py --prompts_file prompts.txt --output cached_embeddings.pt + +Then use with inference: + python turbodiffusion/inference/wan2.2_i2v_infer.py \ + --cached_embedding cached_embeddings.pt \ + --skip_t5 \ + ... 
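+
+The output is written with torch.save() and holds a dict with the keys 'prompts',
+'embeddings' (one {'prompt', 'embedding', 'shape'} entry per prompt, with the
+embedding tensor stored on CPU), and 'text_encoder_path'.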
+""" +import os +import sys +import argparse +import torch + +# Add repo root to path for imports +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +REPO_ROOT = os.path.dirname(SCRIPT_DIR) +sys.path.insert(0, REPO_ROOT) + +def main(): + parser = argparse.ArgumentParser(description="Pre-cache T5 text embeddings") + parser.add_argument("--prompt", type=str, default=None, help="Single prompt to cache") + parser.add_argument("--prompts_file", type=str, default=None, help="File with prompts (one per line)") + parser.add_argument("--text_encoder_path", type=str, + default="/media/2TB/ComfyUI/models/text_encoders/models_t5_umt5-xxl-enc-bf16.pth", + help="Path to the umT5 text encoder") + parser.add_argument("--output", type=str, default="cached_t5_embeddings.pt", + help="Output path for cached embeddings") + parser.add_argument("--device", type=str, default="cuda", + help="Device to use for encoding (cuda is faster, memory freed after)") + args = parser.parse_args() + + # Collect prompts + prompts = [] + if args.prompt: + prompts.append(args.prompt) + if args.prompts_file and os.path.exists(args.prompts_file): + with open(args.prompts_file, 'r') as f: + prompts.extend([line.strip() for line in f if line.strip()]) + + if not prompts: + print("Error: Provide --prompt or --prompts_file") + sys.exit(1) + + print(f"Caching embeddings for {len(prompts)} prompt(s)") + print(f"Text encoder: {args.text_encoder_path}") + print(f"Device: {args.device}") + print() + + # Import after path setup + from rcm.utils.umt5 import get_umt5_embedding, clear_umt5_memory + + cache_data = { + 'prompts': prompts, + 'embeddings': [], + 'text_encoder_path': args.text_encoder_path, + } + + with torch.no_grad(): + for i, prompt in enumerate(prompts): + print(f"[{i+1}/{len(prompts)}] Encoding: '{prompt[:60]}...' " if len(prompt) > 60 else f"[{i+1}/{len(prompts)}] Encoding: '{prompt}'") + + # Get embedding (loads T5 if not already loaded) + embedding = get_umt5_embedding( + checkpoint_path=args.text_encoder_path, + prompts=prompt + ) + + # Move to CPU for storage + cache_data['embeddings'].append({ + 'prompt': prompt, + 'embedding': embedding.cpu(), + 'shape': list(embedding.shape), + }) + + print(f" Shape: {embedding.shape}, dtype: {embedding.dtype}") + + # Clear T5 from memory + print("\nClearing T5 from memory...") + clear_umt5_memory() + torch.cuda.empty_cache() + + # Save cache + print(f"\nSaving to: {args.output}") + torch.save(cache_data, args.output) + + # Summary + file_size = os.path.getsize(args.output) / (1024 * 1024) + print(f"Done! Cache file size: {file_size:.2f} MB") + print() + print("Usage:") + print(f" python turbodiffusion/inference/wan2.2_i2v_infer.py \\") + print(f" --cached_embedding {args.output} \\") + print(f" --skip_t5 \\") + print(f" ... 
(other args)") + + +if __name__ == "__main__": + main() diff --git a/turbodiffusion/inference/wan2.2_i2v_infer.py b/turbodiffusion/inference/wan2.2_i2v_infer.py index e57e509..a1ee28e 100644 --- a/turbodiffusion/inference/wan2.2_i2v_infer.py +++ b/turbodiffusion/inference/wan2.2_i2v_infer.py @@ -15,6 +15,7 @@ import argparse import math +import os import torch from einops import rearrange, repeat @@ -47,6 +48,8 @@ def parse_arguments() -> argparse.Namespace: parser.add_argument("--sigma_max", type=float, default=200, help="Initial sigma for rCM") parser.add_argument("--vae_path", type=str, default="checkpoints/Wan2.1_VAE.pth", help="Path to the Wan2.1 VAE") parser.add_argument("--text_encoder_path", type=str, default="checkpoints/models_t5_umt5-xxl-enc-bf16.pth", help="Path to the umT5 text encoder") + parser.add_argument("--cached_embedding", type=str, default=None, help="Path to pre-cached T5 embeddings (from scripts/cache_t5.py)") + parser.add_argument("--skip_t5", action="store_true", help="Skip loading T5 model (requires --cached_embedding)") parser.add_argument("--num_frames", type=int, default=81, help="Number of frames to generate") parser.add_argument("--prompt", type=str, default=None, help="Text prompt for video generation (required unless --serve)") parser.add_argument("--resolution", default="720p", type=str, help="Resolution of the generated output") @@ -60,6 +63,7 @@ def parse_arguments() -> argparse.Namespace: parser.add_argument("--quant_linear", action="store_true", help="Whether to replace Linear layers with quantized versions") parser.add_argument("--default_norm", action="store_true", help="Whether to replace LayerNorm/RMSNorm layers with faster versions") parser.add_argument("--serve", action="store_true", help="Launch interactive TUI server mode (keeps model loaded)") + parser.add_argument("--offload_dit", action="store_true", help="Offload DiT models before VAE decode (saves VRAM for high-res/long videos)") return parser.parse_args() @@ -82,10 +86,34 @@ def parse_arguments() -> argparse.Namespace: log.error("--image_path is required (unless using --serve mode)") exit(1) - log.info(f"Computing embedding for prompt: {args.prompt}") - with torch.no_grad(): - text_emb = get_umt5_embedding(checkpoint_path=args.text_encoder_path, prompts=args.prompt).to(**tensor_kwargs) - clear_umt5_memory() + # Get text embedding - either from cache or by running T5 + if args.cached_embedding and os.path.exists(args.cached_embedding): + log.info(f"Loading cached embedding from: {args.cached_embedding}") + cache_data = torch.load(args.cached_embedding, map_location='cpu') + + # Find matching prompt or use first embedding + text_emb = None + for emb_data in cache_data.get('embeddings', []): + if emb_data['prompt'] == args.prompt: + text_emb = emb_data['embedding'] + log.info(f"Found exact prompt match in cache") + break + + if text_emb is None: + # Use first embedding if no exact match + text_emb = cache_data['embeddings'][0]['embedding'] + log.warning(f"No exact prompt match, using cached embedding for: '{cache_data['embeddings'][0]['prompt'][:50]}...'") + + text_emb = text_emb.to(**tensor_kwargs) + log.success(f"Loaded cached embedding, shape: {text_emb.shape}") + elif args.skip_t5: + log.error("--skip_t5 requires --cached_embedding with a valid path") + exit(1) + else: + log.info(f"Computing embedding for prompt: {args.prompt}") + with torch.no_grad(): + text_emb = get_umt5_embedding(checkpoint_path=args.text_encoder_path, prompts=args.prompt).to(**tensor_kwargs) + clear_umt5_memory() 
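+
+    # Whichever branch ran above, text_emb now holds the prompt embedding and the
+    # umT5 encoder is not resident on the GPU (it was never loaded, or was cleared).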
     log.info(f"Loading DiT models.")
     high_noise_model = create_model(dit_path=args.high_noise_model_path, args=args).cpu()
@@ -212,6 +240,18 @@ def parse_arguments() -> argparse.Namespace:
     low_noise_model.cpu()
     torch.cuda.empty_cache()
 
+    # Offload DiT models completely before VAE decode if requested
+    if args.offload_dit:
+        log.info("Offloading DiT models to free VRAM for VAE decode...")
+        del high_noise_model
+        del low_noise_model
+        del net
+        torch.cuda.empty_cache()
+        import gc
+        gc.collect()
+        torch.cuda.empty_cache()
+        log.success(f"VRAM freed. Available: {torch.cuda.mem_get_info()[0] / 1024**3:.1f}GB")
+
     with torch.no_grad():
         video = tokenizer.decode(samples)