diff --git a/.github/workflows/config/.secrets.baseline b/.github/workflows/config/.secrets.baseline index 83fa8afc..cbd1e1e8 100644 --- a/.github/workflows/config/.secrets.baseline +++ b/.github/workflows/config/.secrets.baseline @@ -90,10 +90,6 @@ { "path": "detect_secrets.filters.allowlist.is_line_allowlisted" }, - { - "path": "detect_secrets.filters.common.is_baseline_file", - "filename": ".github/workflows/config/.secrets.baseline" - }, { "path": "detect_secrets.filters.common.is_ignored_due_to_verification_policies", "min_level": 2 @@ -139,10 +135,10 @@ "filename": "examples/automodel/pretrain/cicd/wan21_cicd_nightly_video.yaml", "hashed_secret": "c70f071570ba65f9c4079d6051e955ff4f802eea", "is_verified": false, - "line_number": 67, + "line_number": 72, "is_secret": false } ] }, - "generated_at": "2026-01-30T18:50:34Z" + "generated_at": "2026-02-12T07:45:24Z" } diff --git a/dfm/src/automodel/recipes/train.py b/dfm/src/automodel/recipes/train.py index c67032b2..71818f5f 100644 --- a/dfm/src/automodel/recipes/train.py +++ b/dfm/src/automodel/recipes/train.py @@ -25,6 +25,7 @@ from nemo_automodel.components.checkpoint.checkpointing import Checkpointer, CheckpointingConfig from nemo_automodel.components.loggers.log_utils import setup_logging from nemo_automodel.components.loggers.wandb_utils import suppress_wandb_log_messages +from nemo_automodel.components.optim.scheduler import OptimizerParamScheduler from nemo_automodel.components.training.rng import StatefulRNG from nemo_automodel.components.training.step_scheduler import StepScheduler from nemo_automodel.recipes.base_recipe import BaseRecipe @@ -195,20 +196,93 @@ def build_model_and_optimizer( def build_lr_scheduler( + cfg, optimizer: torch.optim.Optimizer, - *, - num_epochs: int, - steps_per_epoch: int, - eta_min: float = 1e-6, -) -> torch.optim.lr_scheduler.CosineAnnealingLR: - """Build the cosine annealing learning rate scheduler.""" - - total_steps = max(1, num_epochs * max(1, steps_per_epoch)) - logging.info(f"[INFO] Scheduler configured for {total_steps} total steps") - return torch.optim.lr_scheduler.CosineAnnealingLR( - optimizer, - T_max=total_steps, - eta_min=eta_min, + total_steps: int, +) -> Optional[OptimizerParamScheduler]: + """Build the learning rate scheduler. + + Args: + cfg: Configuration for the OptimizerParamScheduler from YAML. If None, no scheduler + is created and constant LR is used. Supports: + - lr_decay_style: constant, linear, cosine, inverse-square-root, WSD + - lr_warmup_steps: Number of warmup steps (or fraction < 1 for percentage) + - min_lr: Minimum LR after decay + - init_lr: Initial LR for warmup (defaults to 10% of max_lr if warmup enabled) + - wd_incr_style: constant, linear, cosine (for weight decay scheduling) + - wsd_decay_steps: WSD-specific decay steps + - lr_wsd_decay_style: WSD-specific decay style (cosine, linear, exponential, minus_sqrt) + optimizer: The optimizer to be scheduled. + total_steps: Total number of optimizer steps for the training run. + + Returns: + OptimizerParamScheduler instance, or None if cfg is None. + """ + if cfg is None: + return None + + user_cfg = cfg.to_dict() if hasattr(cfg, "to_dict") else dict(cfg) + + base_lr = optimizer.param_groups[0]["lr"] + base_wd = optimizer.param_groups[0].get("weight_decay", 0.0) + + # Compute defaults from runtime values + default_cfg: Dict[str, Any] = { + "optimizer": optimizer, + "lr_warmup_steps": min(1000, total_steps // 10), + "lr_decay_steps": total_steps, + "lr_decay_style": "cosine", + "init_lr": base_lr * 0.1, + "max_lr": base_lr, + "min_lr": base_lr * 0.01, + "start_wd": base_wd, + "end_wd": base_wd, + "wd_incr_steps": total_steps, + "wd_incr_style": "constant", + } + + # Handle warmup as fraction before merging + if "lr_warmup_steps" in user_cfg: + warmup = user_cfg["lr_warmup_steps"] + if isinstance(warmup, float) and 0 < warmup < 1: + user_cfg["lr_warmup_steps"] = int(warmup * total_steps) + + # WSD defaults if user specifies WSD style + if user_cfg.get("lr_decay_style") == "WSD": + default_cfg["wsd_decay_steps"] = max(1, total_steps // 10) + default_cfg["lr_wsd_decay_style"] = "cosine" + + # User config overrides defaults + default_cfg.update(user_cfg) + + # If user disabled warmup, set init_lr = max_lr + if default_cfg["lr_warmup_steps"] == 0: + default_cfg["init_lr"] = default_cfg["max_lr"] + + # Ensure warmup < decay steps + if default_cfg["lr_warmup_steps"] >= default_cfg["lr_decay_steps"]: + default_cfg["lr_warmup_steps"] = max(0, default_cfg["lr_decay_steps"] - 1) + + logging.info( + f"[INFO] LR Scheduler: style={default_cfg['lr_decay_style']}, " + f"warmup={default_cfg['lr_warmup_steps']}, total={default_cfg['lr_decay_steps']}, " + f"max_lr={default_cfg['max_lr']}, min_lr={default_cfg['min_lr']}" + ) + + return OptimizerParamScheduler( + optimizer=default_cfg["optimizer"], + init_lr=default_cfg["init_lr"], + max_lr=default_cfg["max_lr"], + min_lr=default_cfg["min_lr"], + lr_warmup_steps=default_cfg["lr_warmup_steps"], + lr_decay_steps=default_cfg["lr_decay_steps"], + lr_decay_style=default_cfg["lr_decay_style"], + start_wd=default_cfg["start_wd"], + end_wd=default_cfg["end_wd"], + wd_incr_steps=default_cfg["wd_incr_steps"], + wd_incr_style=default_cfg["wd_incr_style"], + wsd_decay_steps=default_cfg.get("wsd_decay_steps"), + lr_wsd_decay_style=default_cfg.get("lr_wsd_decay_style"), ) @@ -390,11 +464,17 @@ def setup(self): grad_acc_steps = max(1, self.global_batch_size // max(1, self.local_batch_size * self.dp_size)) self.steps_per_epoch = ceil(self.raw_steps_per_epoch / grad_acc_steps) - self.lr_scheduler = build_lr_scheduler( + # Calculate total optimizer steps for LR scheduler + total_steps = self.num_epochs * self.steps_per_epoch + + # Build LR scheduler (returns None if lr_scheduler not in config) + # Wrap in list for compatibility with checkpointing (OptimizerState expects list) + lr_scheduler = build_lr_scheduler( + self.cfg.get("lr_scheduler", None), self.optimizer, - num_epochs=self.num_epochs, - steps_per_epoch=self.steps_per_epoch, + total_steps, ) + self.lr_scheduler = [lr_scheduler] if lr_scheduler is not None else None self.global_step = 0 self.start_epoch = 0 @@ -490,7 +570,8 @@ def run_train_validation_loop(self): grad_norm = float(grad_norm) if torch.is_tensor(grad_norm) else grad_norm self.optimizer.step() - self.lr_scheduler.step() + if self.lr_scheduler is not None: + self.lr_scheduler[0].step(1) group_loss_mean = float(sum(micro_losses) / len(micro_losses)) epoch_loss += group_loss_mean diff --git a/dfm/src/automodel/utils/preprocessing_multiprocess.py b/dfm/src/automodel/utils/preprocessing_multiprocess.py index 6b7ef2bf..c3a98cb6 100644 --- a/dfm/src/automodel/utils/preprocessing_multiprocess.py +++ b/dfm/src/automodel/utils/preprocessing_multiprocess.py @@ -12,229 +12,177 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +Unified preprocessing tool for images and videos. + +Supports: +- Images: FLUX (and other image models) +- Videos: Wan2.1, HunyuanVideo-1.5 + +Usage: + # Image preprocessing + python -m dfm.src.automodel.utils.preprocessing_multiprocess image \\ + --image_dir /path/to/images \\ + --output_dir /path/to/cache \\ + --processor flux + + # Video preprocessing + python -m dfm.src.automodel.utils.preprocessing_multiprocess video \\ + --video_dir /path/to/videos \\ + --output_dir /path/to/cache \\ + --processor wan \\ + --resolution_preset 512p + + # List available processors + python -m dfm.src.automodel.utils.preprocessing_multiprocess --list_processors +""" + import argparse import hashlib import json +import logging import os +import pickle import traceback from multiprocessing import Pool from pathlib import Path from typing import Any, Dict, List, Optional, Tuple + +logger = logging.getLogger(__name__) + +import cv2 +import numpy as np import torch from PIL import Image from tqdm import tqdm from dfm.src.automodel.datasets.multiresolutionDataloader.multi_tier_bucketing import MultiTierBucketCalculator -from dfm.src.automodel.utils.processors import BaseModelProcessor, ProcessorRegistry +from dfm.src.automodel.utils.processors import ( + BaseModelProcessor, + BaseVideoProcessor, + ProcessorRegistry, + get_caption_loader, +) +# ============================================================================= +# Constants +# ============================================================================= +IMAGE_EXTENSIONS = {"jpg", "jpeg", "png", "webp", "bmp"} +VIDEO_EXTENSIONS = {"mp4", "avi", "mov", "mkv", "webm"} + +# ============================================================================= # Global worker state (initialized once per process) +# ============================================================================= _worker_models: Optional[Dict[str, Any]] = None _worker_processor: Optional[BaseModelProcessor] = None _worker_calculator: Optional[MultiTierBucketCalculator] = None _worker_device: Optional[str] = None +_worker_config: Optional[Dict[str, Any]] = None -def _init_worker(processor_name: str, model_name: str, gpu_id: int, max_pixels: int): - """Initialize worker process with models on assigned GPU.""" - global _worker_models, _worker_processor, _worker_calculator, _worker_device - - # Set CUDA_VISIBLE_DEVICES to isolate this GPU for the worker process. - # After this, the selected GPU becomes cuda:0 (not cuda:{gpu_id}). - os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) - _worker_device = "cuda:0" - - _worker_processor = ProcessorRegistry.get(processor_name) - _worker_models = _worker_processor.load_models(model_name, _worker_device) - _worker_calculator = MultiTierBucketCalculator(quantization=64, max_pixels=max_pixels) - - print(f"Worker initialized on GPU {gpu_id}") - +# ============================================================================= +# Common Utility Functions +# ============================================================================= -def _load_caption(image_path: Path, caption_field: str = "internvl") -> Optional[str]: - """ - Load caption from JSON file for an image. - DEPRECATED: Use _load_all_captions() instead for better performance. - This function is kept for backward compatibility only. - """ - image_name = image_path.name +def _get_media_files(media_dir: Path, extensions: set) -> List[Path]: + """Recursively get all media files with given extensions using os.walk().""" + media_files = [] + for root, dirs, files in os.walk(media_dir): + root_path = Path(root) + for file in files: + if "." in file: + ext = file.lower().rsplit(".", 1)[-1] + if ext in extensions: + media_files.append(root_path / file) + return sorted(media_files) - # Extract prefix: everything before '_sample' - if "_sample" in image_name: - prefix = image_name.rsplit("_sample", 1)[0] - else: - prefix = image_path.stem - json_path = image_path.parent / f"{prefix}_internvl.json" +def _save_metadata_shards( + all_metadata: List[Dict], + output_dir: Path, + processor_name: str, + model_name: str, + model_type: str, + shard_size: int, + extra_fields: Dict[str, Any], +) -> None: + """Save metadata in shards and write config file.""" + shard_files = [] + for shard_idx in range(0, len(all_metadata), shard_size): + shard_data = all_metadata[shard_idx : shard_idx + shard_size] + shard_file = output_dir / f"metadata_shard_{shard_idx // shard_size:04d}.json" + with open(shard_file, "w") as f: + json.dump(shard_data, f, indent=2) + shard_files.append(shard_file.name) - if not json_path.exists(): - return None + metadata = { + "processor": processor_name, + "model_name": model_name, + "model_type": model_type, + "total_items": len(all_metadata), + "num_shards": len(shard_files), + "shard_size": shard_size, + "shards": shard_files, + **extra_fields, + } - try: - with open(json_path, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - if not line: - continue - try: - entry = json.loads(line) - if entry.get("file_name") == image_name: - return entry.get(caption_field, "") - except json.JSONDecodeError: - continue - except Exception: - pass - - return None + with open(output_dir / "metadata.json", "w") as f: + json.dump(metadata, f, indent=2) -def _load_all_captions( - image_files: List[Path], caption_field: str = "internvl", verbose: bool = True -) -> Dict[str, str]: - """ - Pre-load all captions from JSONL files into memory. +def _print_bucket_distribution(all_metadata: List[Dict]) -> None: + """Print bucket resolution distribution.""" + bucket_counts: Dict[str, int] = {} + for item in all_metadata: + res = f"{item['bucket_resolution'][0]}x{item['bucket_resolution'][1]}" + bucket_counts[res] = bucket_counts.get(res, 0) + 1 - This function eliminates the performance bottleneck of repeatedly opening - and parsing the same JSONL files by loading all captions once upfront. + logger.info("Bucket distribution:") + for res in sorted(bucket_counts.keys()): + logger.info(" %s: %d", res, bucket_counts[res]) - Args: - image_files: List of image file paths - caption_field: Field name in JSONL to use ('internvl' or 'usr') - verbose: Print progress information - Returns: - Dictionary mapping image filename to caption text - """ - from collections import defaultdict +# ============================================================================= +# Image Preprocessing Functions +# ============================================================================= - if verbose: - print("\nPre-loading captions from JSONL files...") - # Group images by their JSONL file - jsonl_to_images = defaultdict(list) +def _init_worker(processor_name: str, model_name: str, gpu_id: int, max_pixels: int): + """Initialize worker process with models on assigned GPU.""" + global _worker_models, _worker_processor, _worker_calculator, _worker_device - for image_path in image_files: - image_name = image_path.name + # Set CUDA_VISIBLE_DEVICES to isolate this GPU for the worker process. + # After this, the selected GPU becomes cuda:0 (not cuda:{gpu_id}). + os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) + _worker_device = "cuda:0" - # Extract prefix: everything before '_sample' - if "_sample" in image_name: - prefix = image_name.rsplit("_sample", 1)[0] - else: - prefix = image_path.stem + _worker_processor = ProcessorRegistry.get(processor_name) + _worker_models = _worker_processor.load_models(model_name, _worker_device) + _worker_calculator = MultiTierBucketCalculator(quantization=64, max_pixels=max_pixels) - json_path = image_path.parent / f"{prefix}_internvl.json" - jsonl_to_images[json_path].append(image_name) + logger.info("Worker initialized on GPU %d", gpu_id) - # Load each JSONL file once and build caption dictionary - caption_cache = {} - loaded_files = 0 - missing_files = 0 - total_captions = 0 - for json_path, image_names in tqdm(jsonl_to_images.items(), desc="Loading JSONL files", disable=not verbose): - if not json_path.exists(): - missing_files += 1 - # Images with missing JSONL will use filename fallback - continue +def _load_all_captions( + image_files: List[Path], caption_field: str = "internvl", verbose: bool = True +) -> Dict[str, str]: + """Pre-load all captions from JSONL files. Returns filename->caption dict.""" + from dfm.src.automodel.utils.processors.caption_loaders import JSONLCaptionLoader - try: - with open(json_path, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - if not line: - continue - try: - entry = json.loads(line) - file_name = entry.get("file_name") - if file_name and file_name in image_names: - caption = entry.get(caption_field, "") - if caption: - caption_cache[file_name] = caption - total_captions += 1 - except json.JSONDecodeError: - continue - loaded_files += 1 - except Exception as e: - if verbose: - print(f"Warning: Failed to load {json_path}: {e}") - continue + loader = JSONLCaptionLoader(jsonl_suffix="_internvl.json") + captions, stats = loader.load_captions_with_stats(image_files, caption_field, verbose=verbose) if verbose: - print(f"Loaded {total_captions} captions from {loaded_files} JSONL files") - if missing_files > 0: - print(f" {missing_files} JSONL files not found (will use filename fallback)") - missing_captions = len(image_files) - total_captions - if missing_captions > 0: - print(f" {missing_captions} images will use filename as caption") + logger.info("Loaded %d captions from %d JSONL files", stats.loaded_count, stats.files_parsed) + if stats.files_missing > 0: + logger.info(" %d JSONL files not found (will use filename fallback)", stats.files_missing) + if stats.captions_missing > 0: + logger.info(" %d images will use filename as caption", stats.captions_missing) - return caption_cache - - -def _validate_caption_files(image_files: List[Path], caption_field: str) -> Tuple[int, int, List[str]]: - """ - Validate that caption files exist and are parseable. - - Args: - image_files: List of image file paths - caption_field: Field name to check in JSONL files - - Returns: - (num_valid_files, num_missing_files, error_messages) - """ - - # Group images by their JSONL file - jsonl_files = set() - - for image_path in image_files: - image_name = image_path.name - - # Extract prefix: everything before '_sample' - if "_sample" in image_name: - prefix = image_name.rsplit("_sample", 1)[0] - else: - prefix = image_path.stem - - json_path = image_path.parent / f"{prefix}_internvl.json" - jsonl_files.add(json_path) - - # Validate each JSONL file - valid_files = 0 - missing_files = 0 - errors = [] - - for json_path in jsonl_files: - if not json_path.exists(): - missing_files += 1 - errors.append(f"Missing: {json_path}") - continue - - try: - with open(json_path, "r", encoding="utf-8") as f: - line_count = 0 - for line in f: - line = line.strip() - if not line: - continue - line_count += 1 - try: - entry = json.loads(line) - # Basic validation: check structure - if "file_name" not in entry: - errors.append(f"Invalid format in {json_path}: missing 'file_name' field") - break - except json.JSONDecodeError as e: - errors.append(f"JSON error in {json_path} line {line_count}: {e}") - break - else: - # File parsed successfully - valid_files += 1 - except Exception as e: - errors.append(f"Failed to read {json_path}: {e}") - continue - - return valid_files, missing_files, errors + return captions def _process_image(args: Tuple) -> Optional[Dict]: @@ -256,7 +204,7 @@ def _process_image(args: Tuple) -> Optional[Dict]: latent = _worker_processor.encode_image(image_tensor, _worker_models, _worker_device) if verify and not _worker_processor.verify_latent(latent, _worker_models, _worker_device): - print(f"Verification failed: {image_path}") + logger.warning("Verification failed: %s", image_path) return None # Use pre-loaded caption with fallback to filename @@ -275,7 +223,7 @@ def _process_image(args: Tuple) -> Optional[Dict]: metadata = { "original_resolution": (orig_width, orig_height), - "crop_resolution": (target_width, target_height), + "bucket_resolution": (target_width, target_height), "crop_offset": crop_offset, "prompt": caption, "image_path": str(Path(image_path).absolute()), @@ -289,7 +237,7 @@ def _process_image(args: Tuple) -> Optional[Dict]: return { "cache_file": str(cache_file), "image_path": str(Path(image_path).absolute()), - "crop_resolution": [target_width, target_height], + "bucket_resolution": [target_width, target_height], "original_resolution": [orig_width, orig_height], "prompt": caption, "bucket_id": bucket["id"], @@ -299,33 +247,11 @@ def _process_image(args: Tuple) -> Optional[Dict]: } except Exception as e: - print(f"Error processing {image_path}: {e}") - traceback.print_exc() + logger.error("Error processing %s: %s", image_path, e) + logger.debug(traceback.format_exc()) return None -def _get_image_files(image_dir: Path) -> List[Path]: - """ - Recursively get all image files efficiently. - - Uses os.walk() for better performance on large directories compared to rglob(). - """ - image_files = [] - valid_extensions = {"jpg", "jpeg", "png", "webp", "bmp"} - - # Use os.walk for better performance on large directories - for root, dirs, files in os.walk(image_dir): - root_path = Path(root) - for file in files: - # Extract extension and check if it's a valid image file - if "." in file: - ext = file.lower().rsplit(".", 1)[-1] - if ext in valid_extensions: - image_files.append(root_path / file) - - return sorted(image_files) - - def _process_shard_on_gpu( gpu_id: int, image_files: List[Path], @@ -362,7 +288,7 @@ def preprocess_dataset( max_pixels: int = 256 * 256, ): """ - Preprocess dataset with one process per GPU. + Preprocess image dataset with one process per GPU. Args: image_dir: Directory containing images @@ -388,39 +314,23 @@ def preprocess_dataset( if num_gpus == 0: raise RuntimeError("No GPUs available") - print(f"Processor: {processor_name} ({processor.model_type})") - print(f"Model: {model_name}") - print(f"GPUs: {num_gpus}") - print(f"Max pixels: {max_pixels}") + logger.info("Processor: %s (%s)", processor_name, processor.model_type) + logger.info("Model: %s", model_name) + logger.info("GPUs: %d", num_gpus) + logger.info("Max pixels: %d", max_pixels) # Get all image files - print("\nScanning for images...") - image_files = _get_image_files(image_dir) + logger.info("Scanning for images...") + image_files = _get_media_files(image_dir, IMAGE_EXTENSIONS) if max_images is not None: image_files = image_files[:max_images] - print(f"Processing {len(image_files)} images") + logger.info("Processing %d images", len(image_files)) if not image_files: return - # Validate caption files before processing - print("\nValidating caption files...") - num_valid, num_missing, errors = _validate_caption_files(image_files, caption_field) - print(f" Valid JSONL files: {num_valid}") - print(f" Missing JSONL files: {num_missing}") - - if errors and num_missing > len(set([img.parent / f"{img.stem}_internvl.json" for img in image_files])) * 0.5: - print("\nWARNING: Many caption files missing or invalid. First 10 errors:") - for err in errors[:10]: - print(f" {err}") - elif errors and len(errors) <= 5: - print("\nCaption file issues:") - for err in errors: - print(f" {err}") - - # Pre-load all captions (PERFORMANCE OPTIMIZATION) caption_cache = _load_all_captions(image_files, caption_field, verbose=True) # Split images across GPUs @@ -440,99 +350,831 @@ def preprocess_dataset( for gpu_results in results: all_metadata.extend(gpu_results) - # Save metadata in shards - shard_files = [] - for shard_idx in range(0, len(all_metadata), shard_size): - shard_data = all_metadata[shard_idx : shard_idx + shard_size] - shard_file = output_dir / f"metadata_shard_{shard_idx // shard_size:04d}.json" - with open(shard_file, "w") as f: - json.dump(shard_data, f, indent=2) - shard_files.append(shard_file.name) + # Save metadata + _save_metadata_shards( + all_metadata, + output_dir, + processor_name, + model_name, + processor.model_type, + shard_size, + {"caption_field": caption_field, "max_pixels": max_pixels}, + ) + + # Print summary + logger.info("=" * 50) + logger.info("COMPLETE: %d/%d images", len(all_metadata), len(image_files)) + logger.info("Output: %s", output_dir) + _print_bucket_distribution(all_metadata) + + +# ============================================================================= +# Video Preprocessing Functions +# ============================================================================= + + +def _init_video_worker( + processor_name: str, + model_name: str, + gpu_id: int, + max_pixels: int, + video_config: Dict[str, Any], +): + """Initialize video worker process with models on assigned GPU.""" + global _worker_models, _worker_processor, _worker_calculator, _worker_device, _worker_config + + # Set CUDA_VISIBLE_DEVICES to isolate this GPU for the worker process. + os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) + _worker_device = "cuda:0" + _worker_config = video_config + + _worker_processor = ProcessorRegistry.get(processor_name) + _worker_models = _worker_processor.load_models(model_name, _worker_device) + + # Create bucket calculator with processor's quantization (8 for video, 64 for image) + quantization = getattr(_worker_processor, "quantization", 8) + _worker_calculator = MultiTierBucketCalculator(quantization=quantization, max_pixels=max_pixels) + + logger.info("Video worker initialized on GPU %d (quantization=%d)", gpu_id, quantization) + + +def _get_video_dimensions(video_path: str) -> Tuple[int, int, int]: + """Get video dimensions and frame count using OpenCV.""" + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise ValueError(f"Failed to open video: {video_path}") + + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + cap.release() + + return width, height, frame_count + + +def _extract_evenly_spaced_frames( + video_path: str, + num_frames: int, + target_size: Tuple[int, int], + resize_mode: str = "bilinear", + center_crop: bool = True, +) -> Tuple[List[np.ndarray], List[int]]: + """Extract evenly-spaced frames. Returns (frames, source_indices).""" + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise ValueError(f"Failed to open video: {video_path}") + + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + + # Calculate evenly-spaced frame indices + if num_frames >= total_frames: + frame_indices = list(range(total_frames)) + else: + frame_indices = np.linspace(0, total_frames - 1, num_frames).astype(int).tolist() + + target_height, target_width = target_size + + # Map resize modes to OpenCV interpolation + interp_map = { + "bilinear": cv2.INTER_LINEAR, + "bicubic": cv2.INTER_CUBIC, + "nearest": cv2.INTER_NEAREST, + "area": cv2.INTER_AREA, + "lanczos": cv2.INTER_LANCZOS4, + } + interpolation = interp_map.get(resize_mode, cv2.INTER_LINEAR) + + frames = [] + actual_indices = [] + + for target_idx in frame_indices: + cap.set(cv2.CAP_PROP_POS_FRAMES, target_idx) + ret, frame = cap.read() + if not ret: + continue + + # Convert BGR to RGB + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + # Resize and optionally center crop + if center_crop: + # Calculate scale to cover target area + scale = max(target_width / orig_width, target_height / orig_height) + new_width = int(orig_width * scale) + new_height = int(orig_height * scale) + + frame = cv2.resize(frame, (new_width, new_height), interpolation=interpolation) + + # Center crop + start_x = (new_width - target_width) // 2 + start_y = (new_height - target_height) // 2 + frame = frame[start_y : start_y + target_height, start_x : start_x + target_width] + else: + # Direct resize (may change aspect ratio) + frame = cv2.resize(frame, (target_width, target_height), interpolation=interpolation) + + frames.append(frame) + actual_indices.append(target_idx) + + cap.release() + return frames, actual_indices + + +def _frame_to_video_tensor(frame: np.ndarray, dtype: torch.dtype = torch.float16) -> torch.Tensor: + """Convert frame (H,W,C) to video tensor (1,C,1,H,W) normalized to [-1,1].""" + # (H, W, C) -> (C, H, W) + tensor = torch.from_numpy(frame).float().permute(2, 0, 1) + + # Normalize to [-1, 1] + tensor = tensor / 255.0 + tensor = (tensor - 0.5) / 0.5 + + # Add batch and temporal dimensions: (C, H, W) -> (1, C, 1, H, W) + tensor = tensor.unsqueeze(0).unsqueeze(2) + + return tensor.to(dtype) + - # Save config metadata (references shards instead of duplicating items) - metadata_file = output_dir / "metadata.json" - with open(metadata_file, "w") as f: - json.dump( - { - "processor": processor_name, - "model_name": model_name, - "model_type": processor.model_type, - "caption_field": caption_field, - "max_pixels": max_pixels, - "total_images": len(all_metadata), - "num_shards": len(shard_files), - "shard_size": shard_size, - "shards": shard_files, - }, - f, - indent=2, +# ============================================================================= +# Video Processing Helper Functions +# ============================================================================= + + +def _resolve_video_resolution( + orig_width: int, + orig_height: int, + config: Dict[str, Any], +) -> Tuple[int, int, Optional[str], float]: + """Resolve target resolution. Returns (width, height, bucket_id, aspect_ratio).""" + target_height = config.get("target_height") + target_width = config.get("target_width") + + if target_height is not None and target_width is not None: + # Explicit size: no bucketing + return target_width, target_height, None, target_width / target_height + else: + # Use bucket calculator to find best resolution + bucket = _worker_calculator.get_bucket_for_image(orig_width, orig_height) + return bucket["resolution"][0], bucket["resolution"][1], bucket["id"], bucket["aspect_ratio"] + + +def _save_cache_file( + cache_data: Dict[str, Any], + output_dir: str, + resolution: str, + cache_hash: str, + output_format: str, +) -> Path: + """Save cache data to file. Returns path to saved file.""" + cache_subdir = Path(output_dir) / resolution + cache_subdir.mkdir(parents=True, exist_ok=True) + + if output_format == "meta": + cache_file = cache_subdir / f"{cache_hash}.meta" + with open(cache_file, "wb") as f: + pickle.dump(cache_data, f) + else: # pt format + cache_file = cache_subdir / f"{cache_hash}.pt" + torch.save(cache_data, cache_file) + + return cache_file + + +def _build_result_dict( + cache_file: Path, + video_path: str, + target_width: int, + target_height: int, + orig_width: int, + orig_height: int, + caption: str, + bucket_id: Optional[str], + aspect_ratio: float, + num_frames: int = 1, + frame_index: Optional[int] = None, + total_frames_extracted: Optional[int] = None, + source_frame_index: Optional[int] = None, +) -> Dict[str, Any]: + """Build a result dictionary for a processed video/frame.""" + result = { + "cache_file": str(cache_file), + "video_path": str(Path(video_path).absolute()), + "bucket_resolution": [target_width, target_height], + "original_resolution": [orig_width, orig_height], + "num_frames": num_frames, + "prompt": caption, + "bucket_id": bucket_id, + "aspect_ratio": aspect_ratio, + "pixels": target_width * target_height, + "model_type": _worker_processor.model_type, + } + + # Add frame-specific fields if provided + if frame_index is not None: + result["frame_index"] = frame_index + if total_frames_extracted is not None: + result["total_frames_extracted"] = total_frames_extracted + if source_frame_index is not None: + result["source_frame_index"] = source_frame_index + + return result + + +def _process_video_frames_mode(args: Tuple) -> List[Dict]: + """Process video in frames mode - each frame becomes a separate sample.""" + video_path, output_dir, caption, config = args + + try: + # Get video dimensions + orig_width, orig_height, total_frames = _get_video_dimensions(video_path) + + # Resolve target resolution (handles bucketing vs explicit size) + target_width, target_height, bucket_id, aspect_ratio = _resolve_video_resolution( + orig_width, orig_height, config + ) + + # Extract evenly-spaced frames + num_frames = config.get("num_frames", 10) + frames, source_frame_indices = _extract_evenly_spaced_frames( + video_path, + num_frames=num_frames, + target_size=(target_height, target_width), + resize_mode=config.get("resize_mode", "bilinear"), + center_crop=config.get("center_crop", True), + ) + + if not frames: + logger.warning("No frames extracted from %s", video_path) + return [] + + total_frames_extracted = len(frames) + + # Use caption with fallback to filename + if not caption: + caption = Path(video_path).stem.replace("_", " ") + + # Encode text ONCE (reuse for all frames) + text_encodings = _worker_processor.encode_text(caption, _worker_models, _worker_device) + + # Process each frame individually + results = [] + deterministic = config.get("deterministic", True) + output_format = config.get("output_format", "meta") + resolution = f"{target_width}x{target_height}" + + for frame_idx, (frame, source_idx) in enumerate(zip(frames, source_frame_indices)): + # Convert single frame to 1-frame video tensor + video_tensor = _frame_to_video_tensor(frame) + + # Encode with VAE + latent = _worker_processor.encode_video( + video_tensor, + _worker_models, + _worker_device, + deterministic=deterministic, + ) + + # Prepare metadata for this frame + # Note: first_frame and image_embeds are omitted in frames mode + # (frames mode is intended for t2v training, not i2v conditioning) + metadata = { + "original_resolution": (orig_width, orig_height), + "bucket_resolution": (target_width, target_height), + "bucket_id": bucket_id, + "aspect_ratio": aspect_ratio, + "num_frames": 1, # Always 1 for frame mode + "total_original_frames": total_frames, + "prompt": caption, + "video_path": str(Path(video_path).absolute()), + "deterministic": deterministic, + "mode": "frames", + "frame_index": frame_idx + 1, # 1-based index + "total_frames_extracted": total_frames_extracted, + "source_frame_index": source_idx, # 0-based index in source video + } + + # Get cache data from processor + cache_data = _worker_processor.get_cache_data(latent, text_encodings, metadata) + + # Include frame index in hash to ensure unique filenames + cache_hash = hashlib.md5( + f"{Path(video_path).absolute()}_{resolution}_frame{frame_idx}".encode() + ).hexdigest() + + # Save cache file using helper + cache_file = _save_cache_file(cache_data, output_dir, resolution, cache_hash, output_format) + + # Build result dict using helper + results.append( + _build_result_dict( + cache_file=cache_file, + video_path=video_path, + target_width=target_width, + target_height=target_height, + orig_width=orig_width, + orig_height=orig_height, + caption=caption, + bucket_id=bucket_id, + aspect_ratio=aspect_ratio, + num_frames=1, + frame_index=frame_idx + 1, + total_frames_extracted=total_frames_extracted, + source_frame_index=source_idx, + ) + ) + + return results + + except Exception as e: + logger.error("Error processing %s in frames mode: %s", video_path, e) + logger.debug(traceback.format_exc()) + return [] + + +def _process_video_video_mode(args: Tuple) -> Optional[Dict]: + """Process video in video mode - multi-frame encoding as single sample.""" + video_path, output_dir, caption, config = args + + try: + # Get video dimensions + orig_width, orig_height, total_frames = _get_video_dimensions(video_path) + + # Resolve target resolution (handles bucketing vs explicit size) + target_width, target_height, bucket_id, aspect_ratio = _resolve_video_resolution( + orig_width, orig_height, config + ) + + # Load video with target resolution + num_frames = config.get("num_frames") + target_frames = config.get("target_frames") + + video_tensor, first_frame = _worker_processor.load_video( + video_path, + target_size=(target_height, target_width), + num_frames=target_frames or num_frames, + resize_mode=config.get("resize_mode", "bilinear"), + center_crop=config.get("center_crop", True), ) + actual_frames = video_tensor.shape[2] # (1, C, T, H, W) + + # Use caption with fallback to filename + if not caption: + caption = Path(video_path).stem.replace("_", " ") + + # Encode video + deterministic = config.get("deterministic", True) + latent = _worker_processor.encode_video( + video_tensor, + _worker_models, + _worker_device, + deterministic=deterministic, + ) + + # Encode text + text_encodings = _worker_processor.encode_text(caption, _worker_models, _worker_device) + + # Encode first frame for i2v (if processor supports it) + image_embeds = None + if hasattr(_worker_processor, "encode_first_frame"): + image_embeds = _worker_processor.encode_first_frame(first_frame, _worker_models, _worker_device) + + # Prepare metadata + metadata = { + "original_resolution": (orig_width, orig_height), + "bucket_resolution": (target_width, target_height), + "bucket_id": bucket_id, + "aspect_ratio": aspect_ratio, + "num_frames": actual_frames, + "total_original_frames": total_frames, + "prompt": caption, + "video_path": str(Path(video_path).absolute()), + "first_frame": first_frame, + "image_embeds": image_embeds, + "deterministic": deterministic, + "mode": config.get("mode", "video"), + } + + # Get cache data from processor + cache_data = _worker_processor.get_cache_data(latent, text_encodings, metadata) + + # Save cache file using helper + output_format = config.get("output_format", "meta") + resolution = f"{target_width}x{target_height}" + cache_hash = hashlib.md5(f"{Path(video_path).absolute()}_{resolution}_{actual_frames}".encode()).hexdigest() + cache_file = _save_cache_file(cache_data, output_dir, resolution, cache_hash, output_format) + + # Build result dict using helper + return _build_result_dict( + cache_file=cache_file, + video_path=video_path, + target_width=target_width, + target_height=target_height, + orig_width=orig_width, + orig_height=orig_height, + caption=caption, + bucket_id=bucket_id, + aspect_ratio=aspect_ratio, + num_frames=actual_frames, + ) + + except Exception as e: + logger.error("Error processing %s: %s", video_path, e) + logger.debug(traceback.format_exc()) + return None + + +def _process_video(args: Tuple) -> List[Dict]: + """Process a single video. Dispatches to frames or video mode based on config.""" + video_path, output_dir, caption, config = args + mode = config.get("mode", "video") + + if mode == "frames": + return _process_video_frames_mode(args) + else: + # Wrap single result in a list for consistent return type + result = _process_video_video_mode(args) + return [result] if result is not None else [] + + +def _process_video_shard_on_gpu( + gpu_id: int, + video_files: List[Path], + output_dir: str, + processor_name: str, + model_name: str, + caption_cache: Dict[str, str], + max_pixels: int, + video_config: Dict[str, Any], +) -> List[Dict]: + """Process a shard of videos on a specific GPU.""" + _init_video_worker(processor_name, model_name, gpu_id, max_pixels, video_config) + + results = [] + for video_path in tqdm(video_files, desc=f"GPU {gpu_id}", position=gpu_id): + caption = caption_cache.get(video_path.name) + # _process_video now always returns List[Dict] for consistent handling + results.extend(_process_video((str(video_path), output_dir, caption, video_config))) + + return results + + +def preprocess_video_dataset( + video_dir: str, + output_dir: str, + processor_name: str, + model_name: Optional[str] = None, + mode: str = "video", + num_frames: int = 10, + target_frames: Optional[int] = None, + resolution_preset: Optional[str] = None, + max_pixels: Optional[int] = None, + target_height: Optional[int] = None, + target_width: Optional[int] = None, + resize_mode: str = "bilinear", + center_crop: bool = True, + deterministic: bool = True, + output_format: str = "meta", + caption_format: str = "sidecar", + caption_field: str = "caption", + shard_size: int = 10000, + max_videos: Optional[int] = None, +): + """ + Preprocess video dataset with one process per GPU. + + Args: + video_dir: Directory containing videos + output_dir: Output directory for cache + processor_name: Name of processor ('wan', 'hunyuan') + model_name: HuggingFace model name (uses processor default if None) + mode: Processing mode ('video' or 'frames') + num_frames: Number of frames for 'frames' mode + target_frames: Target frame count (for HunyuanVideo 4n+1) + resolution_preset: Resolution preset ('256p', '512p', '768p', '1024p', '1536p') + max_pixels: Custom pixel budget (mutually exclusive with resolution_preset) + target_height: Explicit target height (disables bucketing) + target_width: Explicit target width (disables bucketing) + resize_mode: Interpolation mode for resizing + center_crop: Whether to center crop + deterministic: Use deterministic latent encoding + output_format: Output format ('meta' or 'pt') + caption_format: Caption format ('sidecar', 'meta_json', 'jsonl') + caption_field: Field name for captions + shard_size: Number of videos per metadata shard + max_videos: Maximum number of videos to process + """ + video_dir = Path(video_dir) + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Get processor and resolve model name + processor = ProcessorRegistry.get(processor_name) + if model_name is None: + model_name = processor.default_model_name + + # Determine max_pixels + if resolution_preset: + if resolution_preset not in MultiTierBucketCalculator.RESOLUTION_PRESETS: + raise ValueError( + f"Unknown preset '{resolution_preset}'. " + f"Available: {list(MultiTierBucketCalculator.RESOLUTION_PRESETS.keys())}" + ) + max_pixels = MultiTierBucketCalculator.RESOLUTION_PRESETS[resolution_preset] + elif max_pixels is None and target_height is None: + # Default to 512p for videos + max_pixels = 512 * 512 + + # If explicit size given, disable bucketing + use_bucketing = target_height is None or target_width is None + if not use_bucketing and max_pixels is None: + max_pixels = target_height * target_width # Use explicit size as pixel budget + + num_gpus = torch.cuda.device_count() + if num_gpus == 0: + raise RuntimeError("No GPUs available") + + logger.info("Processor: %s (%s)", processor_name, processor.model_type) + logger.info("Model: %s", model_name) + logger.info("GPUs: %d", num_gpus) + logger.info("Mode: %s", mode) + if use_bucketing: + logger.info("Max pixels: %d (bucketing enabled)", max_pixels) + logger.info("Quantization: %d", getattr(processor, "quantization", 8)) + else: + logger.info("Target size: %dx%d (bucketing disabled)", target_width, target_height) + + if hasattr(processor, "frame_constraint") and processor.frame_constraint: + logger.info("Frame constraint: %s", processor.frame_constraint) + + # Get all video files + logger.info("Scanning for videos...") + video_files = _get_media_files(video_dir, VIDEO_EXTENSIONS) + + if max_videos is not None: + video_files = video_files[:max_videos] + + logger.info("Found %d videos", len(video_files)) + + if not video_files: + return + + # Load captions using appropriate loader + logger.info("Loading captions (format: %s, field: %s)...", caption_format, caption_field) + caption_loader = get_caption_loader(caption_format) + caption_cache = caption_loader.load_captions(video_files, caption_field) + logger.info(" Loaded %d captions", len(caption_cache)) + + # Video config for workers + video_config = { + "mode": mode, + "num_frames": num_frames, + "target_frames": target_frames, + "target_height": target_height if not use_bucketing else None, + "target_width": target_width if not use_bucketing else None, + "resize_mode": resize_mode, + "center_crop": center_crop, + "deterministic": deterministic, + "output_format": output_format, + } + + # Split videos across GPUs + chunks = [video_files[i::num_gpus] for i in range(num_gpus)] + + # Process with one worker per GPU + all_metadata = [] + + with Pool(processes=num_gpus) as pool: + args = [ + ( + gpu_id, + chunks[gpu_id], + str(output_dir), + processor_name, + model_name, + caption_cache, + max_pixels, + video_config, + ) + for gpu_id in range(num_gpus) + ] + + results = pool.starmap(_process_video_shard_on_gpu, args) + + for gpu_results in results: + all_metadata.extend(gpu_results) + + # Save metadata + _save_metadata_shards( + all_metadata, + output_dir, + processor_name, + model_name, + processor.model_type, + shard_size, + { + "caption_format": caption_format, + "caption_field": caption_field, + "max_pixels": max_pixels, + "mode": mode, + "target_frames": target_frames, + }, + ) + # Print summary - print(f"\n{'=' * 50}") - print(f"COMPLETE: {len(all_metadata)}/{len(image_files)} images") - print(f"Output: {output_dir}") + logger.info("=" * 50) + logger.info("COMPLETE: %d/%d videos", len(all_metadata), len(video_files)) + logger.info("Output: %s", output_dir) + _print_bucket_distribution(all_metadata) - bucket_counts: Dict[str, int] = {} - for item in all_metadata: - res = f"{item['crop_resolution'][0]}x{item['crop_resolution'][1]}" - bucket_counts[res] = bucket_counts.get(res, 0) + 1 - print("\nBucket distribution:") - for res in sorted(bucket_counts.keys()): - print(f" {res}: {bucket_counts[res]}") +# ============================================================================= +# CLI Entry Point +# ============================================================================= def main(): - parser = argparse.ArgumentParser(description="Preprocess images (one process per GPU)") - - parser.add_argument("--list_processors", action="store_true", help="List available processors") - parser.add_argument("--image_dir", type=str, help="Input image directory") - parser.add_argument("--output_dir", type=str, help="Output cache directory") - parser.add_argument("--processor", type=str, default="flux", help="Processor name") - parser.add_argument("--model_name", type=str, default=None, help="Model name") - parser.add_argument("--shard_size", type=int, default=10000, help="Metadata shard size") - parser.add_argument("--verify", action="store_true", help="Verify latents") - parser.add_argument("--caption_field", type=str, default="internvl", choices=["internvl", "usr"]) - parser.add_argument("--max_images", type=int, default=None, help="Max images to process") - parser.add_argument("--max_pixels", type=int, default=None, help="Max pixels per image") - parser.add_argument( - "--resolution_preset", type=str, default=None, choices=["256p", "512p", "768p", "1024p", "1536p"] + parser = argparse.ArgumentParser( + description="Unified preprocessing tool for images and videos", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Image preprocessing with FLUX + python -m dfm.src.automodel.utils.preprocessing_multiprocess image \\ + --image_dir /data/images --output_dir /cache --processor flux + + # Video preprocessing with Wan2.1 + python -m dfm.src.automodel.utils.preprocessing_multiprocess video \\ + --video_dir /data/videos --output_dir /cache --processor wan \\ + --resolution_preset 512p --caption_format sidecar + + # Video preprocessing with HunyuanVideo + python -m dfm.src.automodel.utils.preprocessing_multiprocess video \\ + --video_dir /data/videos --output_dir /cache --processor hunyuan \\ + --target_frames 121 --caption_format meta_json + """, ) + parser.add_argument("--list_processors", action="store_true", help="List available processors and exit") + + subparsers = parser.add_subparsers(dest="command", help="Preprocessing type") + + # =================== + # Image subcommand + # =================== + image_parser = subparsers.add_parser("image", help="Preprocess images") + image_parser.add_argument("--image_dir", type=str, required=True, help="Input image directory") + image_parser.add_argument("--output_dir", type=str, required=True, help="Output cache directory") + image_parser.add_argument("--processor", type=str, default="flux", help="Processor name (default: flux)") + image_parser.add_argument("--model_name", type=str, default=None, help="Model name (uses processor default)") + image_parser.add_argument("--shard_size", type=int, default=10000, help="Metadata shard size") + image_parser.add_argument("--verify", action="store_true", help="Verify latents can be decoded") + image_parser.add_argument( + "--caption_field", type=str, default="internvl", choices=["internvl", "usr"], help="Caption field in JSONL" + ) + image_parser.add_argument("--max_images", type=int, default=None, help="Max images to process") + + # Resolution options (mutually exclusive) + image_res_group = image_parser.add_mutually_exclusive_group() + image_res_group.add_argument( + "--resolution_preset", + type=str, + choices=["256p", "512p", "768p", "1024p", "1536p"], + help="Resolution preset for bucketing", + ) + image_res_group.add_argument("--max_pixels", type=int, help="Custom max pixel budget") + + # =================== + # Video subcommand + # =================== + video_parser = subparsers.add_parser("video", help="Preprocess videos") + video_parser.add_argument("--video_dir", type=str, required=True, help="Input video directory") + video_parser.add_argument("--output_dir", type=str, required=True, help="Output cache directory") + video_parser.add_argument( + "--processor", + type=str, + required=True, + choices=["wan", "wan2.1", "hunyuan", "hunyuanvideo", "hunyuanvideo-1.5"], + ) + video_parser.add_argument("--model_name", type=str, default=None, help="Model name (uses processor default)") + video_parser.add_argument("--mode", type=str, default="video", choices=["video", "frames"], help="Processing mode") + video_parser.add_argument("--num_frames", type=int, default=10, help="Frames to extract in 'frames' mode") + video_parser.add_argument( + "--target_frames", type=int, default=None, help="Target frame count (e.g., 121 for HunyuanVideo)" + ) + + # Resolution options + video_res_group = video_parser.add_mutually_exclusive_group() + video_res_group.add_argument( + "--resolution_preset", + type=str, + choices=["256p", "512p", "768p", "1024p", "1536p"], + help="Resolution preset (videos bucketed by aspect ratio)", + ) + video_res_group.add_argument("--max_pixels", type=int, help="Custom pixel budget for bucketing") + + # Explicit size options (disables bucketing) + video_parser.add_argument("--height", type=int, default=None, help="Explicit height (disables bucketing)") + video_parser.add_argument("--width", type=int, default=None, help="Explicit width (disables bucketing)") + + video_parser.add_argument( + "--resize_mode", + type=str, + default="bilinear", + choices=["bilinear", "bicubic", "nearest", "area", "lanczos"], + help="Interpolation mode", + ) + video_parser.add_argument("--center_crop", action="store_true", default=True, help="Center crop (default: True)") + video_parser.add_argument("--no_center_crop", dest="center_crop", action="store_false", help="Disable center crop") + video_parser.add_argument( + "--deterministic", action="store_true", default=True, help="Use deterministic encoding (default: True)" + ) + video_parser.add_argument( + "--stochastic", dest="deterministic", action="store_false", help="Use stochastic (sampled) encoding" + ) + video_parser.add_argument( + "--caption_format", + type=str, + default="sidecar", + choices=["sidecar", "meta_json", "jsonl"], + help="Caption format", + ) + video_parser.add_argument("--caption_field", type=str, default="caption", help="Caption field name") + video_parser.add_argument( + "--output_format", type=str, default="meta", choices=["meta", "pt"], help="Output file format" + ) + video_parser.add_argument("--shard_size", type=int, default=10000, help="Metadata shard size") + video_parser.add_argument("--max_videos", type=int, default=None, help="Max videos to process") + args = parser.parse_args() + # Handle --list_processors if args.list_processors: - print("Available processors:") + logger.info("Available processors:") for name in ProcessorRegistry.list_available(): proc = ProcessorRegistry.get(name) - print(f" {name}: {proc.model_type}") + media_type = "video" if isinstance(proc, BaseVideoProcessor) else "image" + quantization = getattr(proc, "quantization", 64) + frame_constraint = getattr(proc, "frame_constraint", None) or "none" + logger.info(" %s:", name) + logger.info(" type: %s", proc.model_type) + logger.info(" media: %s", media_type) + logger.info(" quantization: %d", quantization) + if media_type == "video": + logger.info(" frame_constraint: %s", frame_constraint) return - if not args.image_dir or not args.output_dir: - parser.error("--image_dir and --output_dir are required") - - if args.resolution_preset and args.max_pixels: - parser.error("Cannot specify both --resolution_preset and --max_pixels") + # Handle subcommands + if args.command == "image": + if args.resolution_preset: + max_pixels = MultiTierBucketCalculator.RESOLUTION_PRESETS[args.resolution_preset] + elif args.max_pixels: + max_pixels = args.max_pixels + else: + max_pixels = 256 * 256 + + preprocess_dataset( + args.image_dir, + args.output_dir, + args.processor, + args.model_name, + args.shard_size, + args.verify, + args.caption_field, + args.max_images, + max_pixels, + ) - if args.resolution_preset: - max_pixels = MultiTierBucketCalculator.RESOLUTION_PRESETS[args.resolution_preset] - elif args.max_pixels: - max_pixels = args.max_pixels + elif args.command == "video": + # Validate explicit size args + if (args.height is None) != (args.width is None): + parser.error("Both --height and --width must be specified together") + + preprocess_video_dataset( + video_dir=args.video_dir, + output_dir=args.output_dir, + processor_name=args.processor, + model_name=args.model_name, + mode=args.mode, + num_frames=args.num_frames, + target_frames=args.target_frames, + resolution_preset=args.resolution_preset, + max_pixels=args.max_pixels, + target_height=args.height, + target_width=args.width, + resize_mode=args.resize_mode, + center_crop=args.center_crop, + deterministic=args.deterministic, + output_format=args.output_format, + caption_format=args.caption_format, + caption_field=args.caption_field, + shard_size=args.shard_size, + max_videos=args.max_videos, + ) else: - max_pixels = 256 * 256 - - preprocess_dataset( - args.image_dir, - args.output_dir, - args.processor, - args.model_name, - args.shard_size, - args.verify, - args.caption_field, - args.max_images, - max_pixels, - ) + parser.print_help() if __name__ == "__main__": diff --git a/dfm/src/automodel/utils/processors/__init__.py b/dfm/src/automodel/utils/processors/__init__.py index 5991f3fc..50d37467 100644 --- a/dfm/src/automodel/utils/processors/__init__.py +++ b/dfm/src/automodel/utils/processors/__init__.py @@ -13,12 +13,37 @@ # limitations under the License. from .base import BaseModelProcessor +from .base_video import BaseVideoProcessor +from .caption_loaders import ( + CaptionLoader, + CaptionLoadingStats, + JSONLCaptionLoader, + JSONSidecarCaptionLoader, + MetaJSONCaptionLoader, + get_caption_loader, +) from .flux import FluxProcessor +from .hunyuan import HunyuanVideoProcessor from .registry import ProcessorRegistry +from .wan import WanProcessor __all__ = [ + # Base classes "BaseModelProcessor", + "BaseVideoProcessor", + # Registry "ProcessorRegistry", + # Image processors "FluxProcessor", + # Video processors + "WanProcessor", + "HunyuanVideoProcessor", + # Caption loaders + "CaptionLoader", + "CaptionLoadingStats", + "JSONSidecarCaptionLoader", + "MetaJSONCaptionLoader", + "JSONLCaptionLoader", + "get_caption_loader", ] diff --git a/dfm/src/automodel/utils/processors/base.py b/dfm/src/automodel/utils/processors/base.py index 29ea979d..55a17c3f 100644 --- a/dfm/src/automodel/utils/processors/base.py +++ b/dfm/src/automodel/utils/processors/base.py @@ -142,7 +142,7 @@ def get_cache_data( text_encodings: Dict of text embeddings from encode_text() metadata: Dict containing: - original_resolution: Tuple[int, int] - - crop_resolution: Tuple[int, int] + - bucket_resolution: Tuple[int, int] - crop_offset: Tuple[int, int] - prompt: str - image_path: str diff --git a/dfm/src/automodel/utils/processors/base_video.py b/dfm/src/automodel/utils/processors/base_video.py new file mode 100644 index 00000000..f9fb03d5 --- /dev/null +++ b/dfm/src/automodel/utils/processors/base_video.py @@ -0,0 +1,382 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Base class for video model preprocessing. + +Extends BaseModelProcessor with video-specific functionality for models like +Wan2.1 and HunyuanVideo. +""" + +from abc import abstractmethod +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch + +from .base import BaseModelProcessor + + +class BaseVideoProcessor(BaseModelProcessor): + """ + Abstract base class for video model preprocessing. + + Extends BaseModelProcessor with video-specific methods for: + - Video loading and frame extraction + - Video VAE encoding + - Frame count constraints (e.g., 4n+1 for HunyuanVideo) + - First frame handling for image-to-video models + """ + + @property + @abstractmethod + def supported_modes(self) -> List[str]: + """ + Return supported input modes. + + Returns: + List of supported modes: 'video' for video files, 'frames' for image sequences + """ + pass + + @property + def frame_constraint(self) -> Optional[str]: + """ + Return frame count constraint. + + Returns: + Frame constraint string (e.g., '4n+1') or None if no constraint + """ + return None + + @property + def quantization(self) -> int: + """ + VAE quantization requirement. + + Video models typically use 8 due to 3D VAE temporal compression. + Override in subclasses if different. + + Returns: + Resolution quantization factor (default 8 for video models) + """ + return 8 + + @abstractmethod + def encode_video( + self, + video_tensor: torch.Tensor, + models: Dict[str, Any], + device: str, + deterministic: bool = True, + **kwargs, + ) -> torch.Tensor: + """ + Encode video tensor to latent space. + + Args: + video_tensor: Video tensor of shape (1, C, T, H, W), normalized to [-1, 1] + models: Dict of loaded models from load_models() + device: Device to use for encoding + deterministic: If True, use mean instead of sampling from latent distribution + **kwargs: Additional model-specific arguments + + Returns: + Latent tensor (shape varies by model, typically (1, C, T', H', W')) + """ + pass + + @abstractmethod + def load_video( + self, + video_path: str, + target_size: Tuple[int, int], + num_frames: Optional[int] = None, + **kwargs, + ) -> Tuple[torch.Tensor, np.ndarray]: + """ + Load video from file and preprocess. + + Args: + video_path: Path to video file + target_size: Target (height, width) + num_frames: Number of frames to extract (None = all frames) + **kwargs: Additional loading options + + Returns: + Tuple of: + - video_tensor: Tensor of shape (1, C, T, H, W), normalized to [-1, 1] + - first_frame: First frame as numpy array (H, W, C) in uint8 for caching + """ + pass + + def adjust_frame_count(self, frames: np.ndarray, target_frames: int) -> np.ndarray: + """ + Adjust frame count to meet model constraints. + + Override in subclasses that have specific frame count requirements + (e.g., HunyuanVideo requires 4n+1 frames). + + Args: + frames: Array of frames (T, H, W, C) + target_frames: Target number of frames + + Returns: + Adjusted frames array with target_frames frames + """ + current_frames = len(frames) + if current_frames == target_frames: + return frames + + # Default: uniform sampling to reach target frame count + indices = np.linspace(0, current_frames - 1, target_frames).astype(int) + return frames[indices] + + def encode_image( + self, + image_tensor: torch.Tensor, + models: Dict[str, Any], + device: str, + ) -> torch.Tensor: + """ + Encode single image by treating it as a 1-frame video. + + Default implementation wraps image as video and delegates to encode_video. + + Args: + image_tensor: Image tensor of shape (1, C, H, W), normalized to [-1, 1] + models: Dict of loaded models from load_models() + device: Device to use for encoding + + Returns: + Latent tensor + """ + # Add temporal dimension: (1, C, H, W) -> (1, C, 1, H, W) + video_tensor = image_tensor.unsqueeze(2) + return self.encode_video(video_tensor, models, device) + + def verify_latent( + self, + latent: torch.Tensor, + models: Dict[str, Any], + device: str, + ) -> bool: + """ + Verify that a latent can be decoded. + + Default implementation checks for NaN/Inf values. + Override for model-specific verification. + + Args: + latent: Encoded latent tensor + models: Dict of loaded models from load_models() + device: Device to use for verification + + Returns: + True if verification passes, False otherwise + """ + try: + # Basic sanity checks + if torch.isnan(latent).any(): + return False + if torch.isinf(latent).any(): + return False + return True + except Exception: + return False + + def validate_latent_shape( + self, + latent: torch.Tensor, + expected_channels: int, + spatial_downscale: int = 8, + temporal_downscale: int = 4, + input_shape: Optional[Tuple[int, int, int, int, int]] = None, + ) -> Tuple[bool, Optional[str]]: + """ + Validate latent tensor shape based on expected dimensions. + + This helper validates that the encoded latent has the expected shape + given the input dimensions and model-specific downscale factors. + + Args: + latent: Encoded latent tensor, expected shape (B, C, T', H', W') + expected_channels: Expected number of latent channels + spatial_downscale: Spatial downscale factor (default 8) + temporal_downscale: Temporal downscale factor (default 4) + input_shape: Optional input shape (B, C, T, H, W) for dimension validation + + Returns: + Tuple of (is_valid, error_message) + - is_valid: True if shape is valid + - error_message: Description of issue if invalid, None if valid + """ + # Check basic tensor properties + if latent.ndim != 5: + return False, f"Expected 5D tensor (B, C, T, H, W), got {latent.ndim}D" + + B, C, T, H, W = latent.shape + + # Check channel count + if C != expected_channels: + return False, f"Expected {expected_channels} channels, got {C}" + + # Check for invalid values + if torch.isnan(latent).any(): + return False, "Latent contains NaN values" + if torch.isinf(latent).any(): + return False, "Latent contains Inf values" + + # Validate dimensions against input shape if provided + if input_shape is not None: + _, _, in_T, in_H, in_W = input_shape + expected_T = max(1, (in_T + temporal_downscale - 1) // temporal_downscale) + expected_H = in_H // spatial_downscale + expected_W = in_W // spatial_downscale + + if T != expected_T: + return False, f"Expected temporal dim {expected_T}, got {T}" + if H != expected_H: + return False, f"Expected height {expected_H}, got {H}" + if W != expected_W: + return False, f"Expected width {expected_W}, got {W}" + + return True, None + + def load_video_frames( + self, + video_path: str, + target_size: Tuple[int, int], + num_frames: Optional[int] = None, + resize_mode: str = "bilinear", + center_crop: bool = True, + ) -> Tuple[np.ndarray, Dict[str, Any]]: + """ + Load video frames using OpenCV with resizing and optional center crop. + + This is a utility method that can be used by subclass implementations. + + Args: + video_path: Path to video file + target_size: Target (height, width) + num_frames: Number of frames to extract (None = all) + resize_mode: Interpolation mode for resizing + center_crop: Whether to center crop to target aspect ratio + + Returns: + Tuple of: + - frames: numpy array (T, H, W, C) in uint8 + - info: Dict with video metadata (fps, original_size, etc.) + """ + import cv2 + + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise ValueError(f"Failed to open video: {video_path}") + + # Get video properties + fps = cap.get(cv2.CAP_PROP_FPS) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + + # Determine which frames to extract + if num_frames is not None and num_frames < total_frames: + # Uniform sampling + frame_indices = np.linspace(0, total_frames - 1, num_frames).astype(int) + else: + frame_indices = np.arange(total_frames) + + target_height, target_width = target_size + + # Map resize modes to OpenCV interpolation + interp_map = { + "bilinear": cv2.INTER_LINEAR, + "bicubic": cv2.INTER_CUBIC, + "nearest": cv2.INTER_NEAREST, + "area": cv2.INTER_AREA, + "lanczos": cv2.INTER_LANCZOS4, + } + interpolation = interp_map.get(resize_mode, cv2.INTER_LINEAR) + + frames = [] + current_idx = 0 + + for target_idx in frame_indices: + # Seek to frame if needed + if current_idx != target_idx: + cap.set(cv2.CAP_PROP_POS_FRAMES, target_idx) + + ret, frame = cap.read() + if not ret: + break + + # Convert BGR to RGB + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + # Resize and optionally center crop + if center_crop: + # Calculate scale to cover target area + scale = max(target_width / orig_width, target_height / orig_height) + new_width = int(orig_width * scale) + new_height = int(orig_height * scale) + + frame = cv2.resize(frame, (new_width, new_height), interpolation=interpolation) + + # Center crop + start_x = (new_width - target_width) // 2 + start_y = (new_height - target_height) // 2 + frame = frame[start_y : start_y + target_height, start_x : start_x + target_width] + else: + # Direct resize (may change aspect ratio) + frame = cv2.resize(frame, (target_width, target_height), interpolation=interpolation) + + frames.append(frame) + current_idx = target_idx + 1 + + cap.release() + + frames = np.array(frames, dtype=np.uint8) + + info = { + "fps": fps, + "total_frames": total_frames, + "extracted_frames": len(frames), + "original_size": (orig_width, orig_height), + "target_size": (target_width, target_height), + } + + return frames, info + + def frames_to_tensor(self, frames: np.ndarray) -> torch.Tensor: + """ + Convert numpy frames array to normalized tensor. + + Args: + frames: numpy array (T, H, W, C) in uint8 + + Returns: + Tensor of shape (1, C, T, H, W) normalized to [-1, 1] + """ + # (T, H, W, C) -> (T, C, H, W) + tensor = torch.from_numpy(frames).float().permute(0, 3, 1, 2) + + # Normalize to [-1, 1] + tensor = tensor / 255.0 + tensor = (tensor - 0.5) / 0.5 + + # Add batch dimension: (T, C, H, W) -> (1, C, T, H, W) + tensor = tensor.permute(1, 0, 2, 3).unsqueeze(0) + + return tensor diff --git a/dfm/src/automodel/utils/processors/caption_loaders.py b/dfm/src/automodel/utils/processors/caption_loaders.py new file mode 100644 index 00000000..d8ecfe7d --- /dev/null +++ b/dfm/src/automodel/utils/processors/caption_loaders.py @@ -0,0 +1,485 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Caption loading strategies for preprocessing. + +Provides multiple ways to load captions for media files: +- JSONSidecarCaptionLoader: video.mp4 -> video.json with {"caption": "..."} +- MetaJSONCaptionLoader: meta.json with [{"file_name": "...", "caption": "..."}] +- JSONLCaptionLoader: Existing JSONL format for images +""" + +import json +import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, List, Tuple + + +logger = logging.getLogger(__name__) + + +@dataclass +class CaptionLoadingStats: + """ + Statistics from caption loading operations. + + Provides detailed information about caption loading for debugging + and progress reporting. + """ + + # Number of captions successfully loaded + loaded_count: int = 0 + + # Number of caption files found and parsed + files_parsed: int = 0 + + # Number of expected caption files that were missing + files_missing: int = 0 + + # Number of media files without captions (will use fallback) + captions_missing: int = 0 + + # Error messages encountered during loading + errors: List[str] = field(default_factory=list) + + def __str__(self) -> str: + """Human-readable summary.""" + return ( + f"Loaded {self.loaded_count} captions from {self.files_parsed} files " + f"({self.files_missing} missing, {self.captions_missing} using fallback)" + ) + + +class CaptionLoader(ABC): + """ + Abstract base class for caption loading strategies. + + Different datasets organize captions in different ways: + - Sidecar files (one JSON per media file) + - Single metadata file (meta.json with all captions) + - JSONL files (line-delimited JSON entries) + """ + + @abstractmethod + def load_captions( + self, + media_files: List[Path], + caption_field: str = "caption", + verbose: bool = False, + ) -> Dict[str, str]: + """ + Load captions for a list of media files. + + Args: + media_files: List of media file paths + caption_field: Field name containing the caption text + verbose: If True, print progress information + + Returns: + Dict mapping filename (not full path) to caption text + """ + pass + + def load_captions_with_stats( + self, + media_files: List[Path], + caption_field: str = "caption", + verbose: bool = False, + ) -> Tuple[Dict[str, str], CaptionLoadingStats]: + """ + Load captions and return statistics. + + Args: + media_files: List of media file paths + caption_field: Field name containing the caption text + verbose: If True, print progress information + + Returns: + Tuple of (captions dict, loading statistics) + """ + # Default implementation - subclasses can override for efficiency + captions = self.load_captions(media_files, caption_field, verbose) + stats = CaptionLoadingStats( + loaded_count=len(captions), + captions_missing=len(media_files) - len(captions), + ) + return captions, stats + + @staticmethod + def get_loader(format_name: str) -> "CaptionLoader": + """ + Factory method to get the appropriate caption loader. + + Args: + format_name: One of 'sidecar', 'meta_json', 'jsonl' + + Returns: + CaptionLoader instance + + Raises: + ValueError: If format_name is unknown + """ + loaders = { + "sidecar": JSONSidecarCaptionLoader, + "meta_json": MetaJSONCaptionLoader, + "jsonl": JSONLCaptionLoader, + } + if format_name not in loaders: + available = ", ".join(sorted(loaders.keys())) + raise ValueError(f"Unknown caption format: '{format_name}'. Available: {available}") + return loaders[format_name]() + + +class JSONSidecarCaptionLoader(CaptionLoader): + """ + Load captions from JSON sidecar files. + + Expects: video.mp4 -> video.json with content like: + {"caption": "A video of..."} + + This is common for video datasets where each video has its own metadata file. + """ + + def load_captions( + self, + media_files: List[Path], + caption_field: str = "caption", + verbose: bool = False, + ) -> Dict[str, str]: + """ + Load captions from sidecar JSON files. + + For each media file (e.g., video.mp4), looks for a corresponding + JSON file (video.json) in the same directory. + + Args: + media_files: List of media file paths + caption_field: Field name containing the caption text + verbose: If True, print progress information + + Returns: + Dict mapping filename to caption text + """ + captions, _ = self.load_captions_with_stats(media_files, caption_field, verbose) + return captions + + def load_captions_with_stats( + self, + media_files: List[Path], + caption_field: str = "caption", + verbose: bool = False, + ) -> Tuple[Dict[str, str], CaptionLoadingStats]: + """ + Load captions from sidecar JSON files with statistics. + + Args: + media_files: List of media file paths + caption_field: Field name containing the caption text + verbose: If True, print progress information + + Returns: + Tuple of (captions dict, loading statistics) + """ + captions = {} + stats = CaptionLoadingStats() + + if verbose: + logger.info("Loading captions from sidecar JSON files...") + + for media_path in media_files: + # Look for sidecar JSON: video.mp4 -> video.json + json_path = media_path.with_suffix(".json") + + if not json_path.exists(): + stats.files_missing += 1 + continue + + try: + with open(json_path, "r", encoding="utf-8") as f: + data = json.load(f) + + stats.files_parsed += 1 + caption = data.get(caption_field) + if caption: + captions[media_path.name] = caption + stats.loaded_count += 1 + + except json.JSONDecodeError as e: + stats.errors.append(f"JSON error in {json_path}: {e}") + except IOError as e: + stats.errors.append(f"IO error reading {json_path}: {e}") + + stats.captions_missing = len(media_files) - stats.loaded_count + + if verbose: + logger.info(" %s", stats) + if stats.errors and len(stats.errors) <= 5: + for err in stats.errors: + logger.warning(" %s", err) + + return captions, stats + + +class MetaJSONCaptionLoader(CaptionLoader): + """ + Load captions from a centralized meta.json file. + + Expects: meta.json with content like: + [ + {"file_name": "video1.mp4", "caption": "..."}, + {"file_name": "video2.mp4", "caption": "..."} + ] + or: + { + "items": [ + {"file_name": "video1.mp4", "caption": "..."}, + ... + ] + } + + This is common for curated datasets with a single metadata file. + """ + + def load_captions( + self, + media_files: List[Path], + caption_field: str = "caption", + verbose: bool = False, + ) -> Dict[str, str]: + """ + Load captions from meta.json files. + + Looks for meta.json in each unique directory containing media files. + + Args: + media_files: List of media file paths + caption_field: Field name containing the caption text + verbose: If True, print progress information + + Returns: + Dict mapping filename to caption text + """ + captions, _ = self.load_captions_with_stats(media_files, caption_field, verbose) + return captions + + def load_captions_with_stats( + self, + media_files: List[Path], + caption_field: str = "caption", + verbose: bool = False, + ) -> Tuple[Dict[str, str], CaptionLoadingStats]: + """ + Load captions from meta.json files with statistics. + + Args: + media_files: List of media file paths + caption_field: Field name containing the caption text + verbose: If True, print progress information + + Returns: + Tuple of (captions dict, loading statistics) + """ + captions = {} + stats = CaptionLoadingStats() + + if verbose: + logger.info("Loading captions from meta.json files...") + + # Group media files by directory to find meta.json files + dirs = set(p.parent for p in media_files) + + for directory in dirs: + meta_path = directory / "meta.json" + if not meta_path.exists(): + stats.files_missing += 1 + continue + + try: + with open(meta_path, "r", encoding="utf-8") as f: + data = json.load(f) + + stats.files_parsed += 1 + + # Handle both list format and dict with 'items' key + if isinstance(data, dict): + items = data.get("items", data.get("data", [])) + else: + items = data + + for item in items: + if not isinstance(item, dict): + continue + + file_name = item.get("file_name") or item.get("filename") + caption = item.get(caption_field) + + if file_name and caption: + captions[file_name] = caption + stats.loaded_count += 1 + + except json.JSONDecodeError as e: + stats.errors.append(f"JSON error in {meta_path}: {e}") + except IOError as e: + stats.errors.append(f"IO error reading {meta_path}: {e}") + + stats.captions_missing = len(media_files) - stats.loaded_count + + if verbose: + logger.info(" %s", stats) + if stats.errors and len(stats.errors) <= 5: + for err in stats.errors: + logger.warning(" %s", err) + + return captions, stats + + +class JSONLCaptionLoader(CaptionLoader): + """ + Load captions from JSONL files. + + Expects: _internvl.json (JSONL format) with content like: + {"file_name": "image1.jpg", "internvl": "..."} + {"file_name": "image2.jpg", "internvl": "..."} + + This is the existing format used for image preprocessing. + """ + + def __init__(self, jsonl_suffix: str = "_internvl.json"): + """ + Args: + jsonl_suffix: Suffix for JSONL files (default: '_internvl.json') + """ + self.jsonl_suffix = jsonl_suffix + + def load_captions( + self, + media_files: List[Path], + caption_field: str = "internvl", + verbose: bool = False, + ) -> Dict[str, str]: + """ + Load captions from JSONL files. + + For each media file, determines the associated JSONL file based on + the filename pattern (prefix before '_sample' + suffix). + + Args: + media_files: List of media file paths + caption_field: Field name containing the caption text + verbose: If True, print progress information + + Returns: + Dict mapping filename to caption text + """ + captions, _ = self.load_captions_with_stats(media_files, caption_field, verbose) + return captions + + def load_captions_with_stats( + self, + media_files: List[Path], + caption_field: str = "internvl", + verbose: bool = False, + ) -> Tuple[Dict[str, str], CaptionLoadingStats]: + """ + Load captions from JSONL files with statistics. + + Args: + media_files: List of media file paths + caption_field: Field name containing the caption text + verbose: If True, print progress information + + Returns: + Tuple of (captions dict, loading statistics) + """ + from collections import defaultdict + + captions = {} + stats = CaptionLoadingStats() + + if verbose: + logger.info("Loading captions from JSONL files...") + + # Group files by their JSONL file + jsonl_to_files: Dict[Path, List[str]] = defaultdict(list) + + for media_path in media_files: + media_name = media_path.name + + # Extract prefix: everything before '_sample' + if "_sample" in media_name: + prefix = media_name.rsplit("_sample", 1)[0] + else: + prefix = media_path.stem + + json_path = media_path.parent / f"{prefix}{self.jsonl_suffix}" + jsonl_to_files[json_path].append(media_name) + + # Load each JSONL file once + for json_path, file_names in jsonl_to_files.items(): + if not json_path.exists(): + stats.files_missing += 1 + continue + + try: + with open(json_path, "r", encoding="utf-8") as f: + stats.files_parsed += 1 + line_num = 0 + for line in f: + line_num += 1 + line = line.strip() + if not line: + continue + + try: + entry = json.loads(line) + file_name = entry.get("file_name") + caption = entry.get(caption_field) + + if file_name and caption and file_name in file_names: + captions[file_name] = caption + stats.loaded_count += 1 + + except json.JSONDecodeError as e: + stats.errors.append(f"JSON error in {json_path} line {line_num}: {e}") + + except IOError as e: + stats.errors.append(f"IO error reading {json_path}: {e}") + + stats.captions_missing = len(media_files) - stats.loaded_count + + if verbose: + logger.info(" %s", stats) + if stats.files_missing > 0: + logger.info(" %d JSONL files not found (will use filename fallback)", stats.files_missing) + if stats.errors and len(stats.errors) <= 5: + for err in stats.errors: + logger.warning(" %s", err) + + return captions, stats + + +def get_caption_loader(format_name: str) -> CaptionLoader: + """ + Convenience function to get a caption loader by format name. + + Args: + format_name: One of 'sidecar', 'meta_json', 'jsonl' + + Returns: + CaptionLoader instance + """ + return CaptionLoader.get_loader(format_name) diff --git a/dfm/src/automodel/utils/processors/flux.py b/dfm/src/automodel/utils/processors/flux.py index c189f9d9..3117da81 100644 --- a/dfm/src/automodel/utils/processors/flux.py +++ b/dfm/src/automodel/utils/processors/flux.py @@ -21,6 +21,7 @@ - T5 text encoder """ +import logging from typing import Any, Dict import torch @@ -30,6 +31,9 @@ from .registry import ProcessorRegistry +logger = logging.getLogger(__name__) + + @ProcessorRegistry.register("flux") class FluxProcessor(BaseModelProcessor): """ @@ -65,7 +69,7 @@ def load_models(self, model_name: str, device: str) -> Dict[str, Any]: """ from diffusers import FluxPipeline - print(f"[FLUX] Loading models from {model_name} via FluxPipeline...") + logger.info("[FLUX] Loading models from %s via FluxPipeline...", model_name) # Load pipeline without transformer (not needed for preprocessing) pipeline = FluxPipeline.from_pretrained( @@ -76,21 +80,21 @@ def load_models(self, model_name: str, device: str) -> Dict[str, Any]: models = {} - print(" Configuring VAE...") + logger.info(" Configuring VAE...") models["vae"] = pipeline.vae.to(device=device, dtype=torch.bfloat16) models["vae"].eval() - print(f"!!! VAE config: {models['vae'].config}") - print(f"!!! VAE shift_factor: {models['vae'].config.shift_factor}") - print(f"!!! VAE scaling_factor: {models['vae'].config.scaling_factor}") + logger.debug("VAE config: %s", models["vae"].config) + logger.debug("VAE shift_factor: %s", models["vae"].config.shift_factor) + logger.debug("VAE scaling_factor: %s", models["vae"].config.scaling_factor) # Extract CLIP components - print(" Configuring CLIP...") + logger.info(" Configuring CLIP...") models["clip_tokenizer"] = pipeline.tokenizer models["clip_encoder"] = pipeline.text_encoder.to(device) models["clip_encoder"].eval() # Extract T5 components - print(" Configuring T5...") + logger.info(" Configuring T5...") models["t5_tokenizer"] = pipeline.tokenizer_2 models["t5_encoder"] = pipeline.text_encoder_2.to(device) models["t5_encoder"].eval() @@ -99,7 +103,7 @@ def load_models(self, model_name: str, device: str) -> Dict[str, Any]: del pipeline torch.cuda.empty_cache() - print("[FLUX] Models loaded successfully!") + logger.info("[FLUX] Models loaded successfully!") return models def encode_image( @@ -231,7 +235,7 @@ def verify_latent( return True except Exception as e: - print(f"[FLUX] Verification failed: {e}") + logger.warning("[FLUX] Verification failed: %s", e) return False def get_cache_data( @@ -263,7 +267,7 @@ def get_cache_data( "prompt_embeds": text_encodings["prompt_embeds"], # Metadata "original_resolution": metadata["original_resolution"], - "crop_resolution": metadata["crop_resolution"], + "bucket_resolution": metadata["bucket_resolution"], "crop_offset": metadata["crop_offset"], "prompt": metadata["prompt"], "image_path": metadata["image_path"], diff --git a/dfm/src/automodel/utils/processors/hunyuan.py b/dfm/src/automodel/utils/processors/hunyuan.py new file mode 100644 index 00000000..63c73117 --- /dev/null +++ b/dfm/src/automodel/utils/processors/hunyuan.py @@ -0,0 +1,419 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +HunyuanVideo-1.5 model processor for preprocessing. + +Handles HunyuanVideo-1.5 video models with: +- HunyuanVideo VAE for video encoding +- Dual text encoders (CLIP-like + LLaMA) +- Image encoder for first frame (i2v conditioning) +- 4n+1 frame constraint +""" + +import logging +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +from PIL import Image + +from .base_video import BaseVideoProcessor +from .registry import ProcessorRegistry + + +logger = logging.getLogger(__name__) + + +@ProcessorRegistry.register("hunyuan") +@ProcessorRegistry.register("hunyuanvideo") +@ProcessorRegistry.register("hunyuanvideo-1.5") +class HunyuanVideoProcessor(BaseVideoProcessor): + """ + Processor for HunyuanVideo-1.5 video models. + + HunyuanVideo uses: + - HunyuanVideo VAE with shift_factor/scaling_factor normalization + - Dual text encoders (CLIP-like + LLaMA) via pipeline.encode_prompt() + - Image encoder for first frame embeddings (i2v conditioning) + - 4n+1 frame constraint (1, 5, 9, 13, 17, ... 121) + + Default image embedding shape is (729, 1152). + """ + + # Default image embedding shape for HunyuanVideo + DEFAULT_IMAGE_EMBED_SHAPE = (729, 1152) + + @property + def model_type(self) -> str: + return "hunyuanvideo" + + @property + def default_model_name(self) -> str: + return "hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-720p_i2v" + + @property + def supported_modes(self) -> List[str]: + return ["video"] + + @property + def frame_constraint(self) -> str: + return "4n+1" + + @property + def quantization(self) -> int: + # HunyuanVideo VAE requires 8-pixel aligned dimensions + return 8 + + def load_models(self, model_name: str, device: str) -> Dict[str, Any]: + """ + Load HunyuanVideo-1.5 models via pipeline. + + Args: + model_name: HuggingFace model path + device: Device to load models on + + Returns: + Dict containing pipeline and individual components + """ + from diffusers import HunyuanVideo15ImageToVideoPipeline + + dtype = torch.float16 if "cuda" in device else torch.float32 + + logger.info("[HunyuanVideo] Loading pipeline from %s...", model_name) + + # Load pipeline without transformer to save memory + # cpu_offload=True helps manage VRAM + pipeline = HunyuanVideo15ImageToVideoPipeline.from_pretrained( + model_name, + torch_dtype=dtype, + transformer=None, # Don't load transformer for preprocessing + ) + + logger.info(" Configuring VAE...") + vae = pipeline.vae + vae.to(device) + vae.eval() + + # Enable memory optimizations + if hasattr(vae, "enable_tiling"): + vae.enable_tiling( + tile_sample_min_height=64, + tile_sample_min_width=64, + tile_overlap_factor=0.25, + ) + logger.info(" VAE tiling enabled") + + if hasattr(vae, "enable_slicing"): + vae.enable_slicing() + logger.info(" VAE slicing enabled") + + logger.info("[HunyuanVideo] Models loaded successfully!") + + return { + "pipeline": pipeline, + "vae": vae, + "text_encoder": pipeline.text_encoder, + "tokenizer": pipeline.tokenizer, + "image_encoder": pipeline.image_encoder, + "dtype": dtype, + "device": device, + } + + def adjust_frame_count(self, frames: np.ndarray, target_frames: int) -> np.ndarray: + """ + Adjust frame count to meet 4n+1 constraint. + + Args: + frames: Array of frames (T, H, W, C) + target_frames: Target number of frames (must be 4n+1) + + Returns: + Adjusted frames array with target_frames frames + """ + # Validate target_frames is 4n+1 + if (target_frames - 1) % 4 != 0: + raise ValueError(f"target_frames must be 4n+1 (e.g., 1, 5, 9, 13, ..., 121), got {target_frames}") + + num_frames = len(frames) + + if num_frames == target_frames: + return frames + + # Sample frames uniformly to reach target + indices = np.linspace(0, num_frames - 1, target_frames).astype(int) + return frames[indices] + + def validate_frame_count(self, num_frames: int) -> bool: + """ + Check if frame count satisfies 4n+1 constraint. + + Args: + num_frames: Number of frames + + Returns: + True if valid, False otherwise + """ + return (num_frames - 1) % 4 == 0 + + def get_closest_valid_frame_count(self, num_frames: int) -> int: + """ + Get the closest valid 4n+1 frame count. + + Args: + num_frames: Current number of frames + + Returns: + Closest 4n+1 value + """ + n = (num_frames - 1) // 4 + lower = 4 * n + 1 + upper = 4 * (n + 1) + 1 + + if num_frames - lower <= upper - num_frames: + return max(1, lower) + else: + return upper + + def load_video( + self, + video_path: str, + target_size: Tuple[int, int], + num_frames: Optional[int] = None, + resize_mode: str = "bilinear", + center_crop: bool = True, + **kwargs, + ) -> Tuple[torch.Tensor, np.ndarray]: + """ + Load video from file and preprocess with 4n+1 frame handling. + + Args: + video_path: Path to video file + target_size: Target (height, width) + num_frames: Target number of frames (should be 4n+1) + resize_mode: Interpolation mode for resizing + center_crop: Whether to center crop + + Returns: + Tuple of: + - video_tensor: Tensor of shape (1, C, T, H, W), normalized to [-1, 1] + - first_frame: First frame as numpy array (H, W, C) in uint8 + """ + # Use base class utility to load frames + frames, info = self.load_video_frames( + video_path, + target_size, + num_frames=None, # Load all frames first + resize_mode=resize_mode, + center_crop=center_crop, + ) + + # Adjust to 4n+1 if target specified + if num_frames is not None: + frames = self.adjust_frame_count(frames, num_frames) + else: + # Auto-adjust to closest 4n+1 + target = self.get_closest_valid_frame_count(len(frames)) + if target != len(frames): + frames = self.adjust_frame_count(frames, target) + + # Save first frame before converting to tensor + first_frame = frames[0].copy() + + # Convert to tensor + video_tensor = self.frames_to_tensor(frames) + + return video_tensor, first_frame + + def encode_video( + self, + video_tensor: torch.Tensor, + models: Dict[str, Any], + device: str, + deterministic: bool = True, + **kwargs, + ) -> torch.Tensor: + """ + Encode video tensor to latent space using HunyuanVideo VAE. + + Uses shift_factor and scaling_factor normalization. + + Args: + video_tensor: Video tensor (1, C, T, H, W), normalized to [-1, 1] + models: Dict containing 'vae' + device: Device to use + deterministic: If True, use mean instead of sampling from latent distribution + + Returns: + Latent tensor (1, C, T', H', W'), FP16 + """ + vae = models["vae"] + dtype = models.get("dtype", torch.float16) + + video_tensor = video_tensor.to(device=device, dtype=dtype) + + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=dtype, enabled=(device != "cpu")): + latent_dist = vae.encode(video_tensor) + + # Use mean for deterministic encoding, sample otherwise + if deterministic: + latents = latent_dist.latent_dist.mean + else: + latents = latent_dist.latent_dist.sample() + + # Apply HunyuanVideo-specific latent normalization + if hasattr(vae.config, "shift_factor") and vae.config.shift_factor: + latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor + else: + latents = latents * vae.config.scaling_factor + + return latents.detach().cpu().to(torch.float16) + + def encode_text( + self, + prompt: str, + models: Dict[str, Any], + device: str, + ) -> Dict[str, torch.Tensor]: + """ + Encode text using dual encoders via pipeline.encode_prompt(). + + Args: + prompt: Text prompt + models: Dict containing pipeline + device: Device to use + + Returns: + Dict containing: + - text_embeddings: Primary text encoder output + - text_mask: Primary attention mask + - text_embeddings_2: Secondary text encoder output + - text_mask_2: Secondary attention mask + """ + pipeline = models["pipeline"] + dtype = models.get("dtype", torch.float16) + + # Move text encoder to device + pipeline.text_encoder.to(device) + pipeline.text_encoder.eval() + + with torch.no_grad(): + ( + prompt_embeds, + prompt_embeds_mask, + prompt_embeds_2, + prompt_embeds_mask_2, + ) = pipeline.encode_prompt( + prompt=prompt, + device=device, + dtype=dtype, + batch_size=1, + num_videos_per_prompt=1, + ) + + # Move back to CPU to free VRAM + pipeline.text_encoder.to("cpu") + + return { + "text_embeddings": prompt_embeds.detach().cpu(), + "text_mask": prompt_embeds_mask.detach().cpu(), + "text_embeddings_2": prompt_embeds_2.detach().cpu(), + "text_mask_2": prompt_embeds_mask_2.detach().cpu(), + } + + def encode_first_frame( + self, + first_frame: np.ndarray, + models: Dict[str, Any], + device: str, + ) -> torch.Tensor: + """ + Encode first frame using image encoder for i2v conditioning. + + Args: + first_frame: First frame as numpy array (H, W, C) in uint8 + models: Dict containing pipeline with image_encoder + device: Device to use + + Returns: + Image embeddings tensor (1, 729, 1152) + """ + pipeline = models["pipeline"] + dtype = models.get("dtype", torch.float16) + + # Move image encoder to device + pipeline.image_encoder.to(device) + + # Convert numpy to PIL Image if needed + if isinstance(first_frame, np.ndarray): + first_frame_pil = Image.fromarray(first_frame) + else: + first_frame_pil = first_frame + + with torch.no_grad(): + image_embeds = pipeline.encode_image( + image=first_frame_pil, + batch_size=1, + device=device, + dtype=dtype, + ) + + # Move back to CPU + pipeline.image_encoder.to("cpu") + + return image_embeds.detach().cpu() + + def get_cache_data( + self, + latent: torch.Tensor, + text_encodings: Dict[str, torch.Tensor], + metadata: Dict[str, Any], + ) -> Dict[str, Any]: + """ + Construct cache dictionary for HunyuanVideo. + + Args: + latent: Encoded latent tensor (1, C, T, H, W) + text_encodings: Dict from encode_text() + metadata: Additional metadata including image_embeds + + Returns: + Dict to save with torch.save() or pickle + """ + return { + # Video latent + "video_latents": latent, + # Dual text embeddings + "text_embeddings": text_encodings["text_embeddings"], + "text_mask": text_encodings["text_mask"], + "text_embeddings_2": text_encodings["text_embeddings_2"], + "text_mask_2": text_encodings["text_mask_2"], + # Image embeddings for i2v + "image_embeds": metadata.get("image_embeds"), + # Resolution and bucketing info + "original_resolution": metadata.get("original_resolution"), + "bucket_resolution": metadata.get("bucket_resolution"), + "bucket_id": metadata.get("bucket_id"), + "aspect_ratio": metadata.get("aspect_ratio"), + # Video info + "num_frames": metadata.get("num_frames"), + "prompt": metadata.get("prompt"), + "video_path": metadata.get("video_path"), + # Processing settings + "deterministic_latents": metadata.get("deterministic", True), + "model_version": "hunyuanvideo-1.5", + "processing_mode": metadata.get("mode", "video"), + "model_type": self.model_type, + } diff --git a/dfm/src/automodel/utils/processors/wan.py b/dfm/src/automodel/utils/processors/wan.py new file mode 100644 index 00000000..87619538 --- /dev/null +++ b/dfm/src/automodel/utils/processors/wan.py @@ -0,0 +1,344 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Wan2.1 video model processor for preprocessing. + +Handles Wan2.1-T2V models (1.3B and 14B variants) with: +- AutoencoderKLWan for video encoding +- UMT5 text encoder for text conditioning +- Latent normalization using latents_mean and latents_std +""" + +import html +import logging +import re +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch + +from .base_video import BaseVideoProcessor +from .registry import ProcessorRegistry + + +logger = logging.getLogger(__name__) + + +def _basic_clean(text: str) -> str: + """Fix text encoding issues and unescape HTML entities.""" + try: + from diffusers.utils import is_ftfy_available + + if is_ftfy_available(): + import ftfy + + text = ftfy.fix_text(text) + except ImportError: + pass + text = html.unescape(html.unescape(text)) + return text.strip() + + +def _whitespace_clean(text: str) -> str: + """Normalize whitespace by replacing multiple spaces with single space.""" + text = re.sub(r"\s+", " ", text) + return text.strip() + + +def _prompt_clean(text: str) -> str: + """Clean prompt text exactly as done in WanPipeline.""" + return _whitespace_clean(_basic_clean(text)) + + +@ProcessorRegistry.register("wan") +@ProcessorRegistry.register("wan2.1") +class WanProcessor(BaseVideoProcessor): + """ + Processor for Wan2.1 T2V video models. + + Wan2.1 uses: + - AutoencoderKLWan for video/image encoding with latents_mean/latents_std normalization + - UMT5 text encoder with specific padding behavior (trim and re-pad to 226 tokens) + """ + + # Maximum sequence length for UMT5 text encoder + MAX_SEQUENCE_LENGTH = 226 + + @property + def model_type(self) -> str: + return "wan" + + @property + def default_model_name(self) -> str: + return "Wan-AI/Wan2.1-T2V-14B-Diffusers" + + @property + def supported_modes(self) -> List[str]: + return ["video", "frames"] + + @property + def quantization(self) -> int: + # Wan VAE downsamples by 8x and transformer has patch_size=2 in latent space + # Therefore, pixel dimensions must be divisible by 8 * 2 = 16 + return 16 + + def load_models(self, model_name: str, device: str) -> Dict[str, Any]: + """ + Load Wan2.1 models. + + Args: + model_name: HuggingFace model path (e.g., 'Wan-AI/Wan2.1-T2V-14B-Diffusers') + device: Device to load models on + + Returns: + Dict containing: + - vae: AutoencoderKLWan + - text_encoder: UMT5EncoderModel + - tokenizer: AutoTokenizer + """ + from diffusers import AutoencoderKLWan + from transformers import AutoTokenizer, UMT5EncoderModel + + dtype = torch.float16 if "cuda" in device else torch.float32 + + logger.info("[Wan] Loading models from %s...", model_name) + + # Load text encoder + logger.info(" Loading UMT5 text encoder...") + text_encoder = UMT5EncoderModel.from_pretrained( + model_name, + subfolder="text_encoder", + torch_dtype=dtype, + ) + text_encoder.to(device) + text_encoder.eval() + + # Load VAE + logger.info(" Loading AutoencoderKLWan...") + vae = AutoencoderKLWan.from_pretrained( + model_name, + subfolder="vae", + torch_dtype=dtype, + ) + vae.to(device) + vae.eval() + + # Enable memory optimizations + vae.enable_slicing() + vae.enable_tiling() + + # Load tokenizer + logger.info(" Loading tokenizer...") + tokenizer = AutoTokenizer.from_pretrained(model_name, subfolder="tokenizer") + + logger.info("[Wan] Models loaded successfully!") + logger.debug(" VAE latents_mean: %s", vae.config.latents_mean) + logger.debug(" VAE latents_std: %s", vae.config.latents_std) + + return { + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "dtype": dtype, + } + + def load_video( + self, + video_path: str, + target_size: Tuple[int, int], + num_frames: Optional[int] = None, + resize_mode: str = "bilinear", + center_crop: bool = True, + **kwargs, + ) -> Tuple[torch.Tensor, np.ndarray]: + """ + Load video from file and preprocess. + + Args: + video_path: Path to video file + target_size: Target (height, width) + num_frames: Number of frames to extract (None = all frames) + resize_mode: Interpolation mode for resizing + center_crop: Whether to center crop + + Returns: + Tuple of: + - video_tensor: Tensor of shape (1, C, T, H, W), normalized to [-1, 1] + - first_frame: First frame as numpy array (H, W, C) in uint8 + """ + # Use base class utility to load frames + frames, info = self.load_video_frames( + video_path, + target_size, + num_frames=num_frames, + resize_mode=resize_mode, + center_crop=center_crop, + ) + + # Save first frame before converting to tensor + first_frame = frames[0].copy() + + # Convert to tensor + video_tensor = self.frames_to_tensor(frames) + + return video_tensor, first_frame + + def encode_video( + self, + video_tensor: torch.Tensor, + models: Dict[str, Any], + device: str, + deterministic: bool = True, + **kwargs, + ) -> torch.Tensor: + """ + Encode video tensor to latent space using Wan VAE. + + Uses latents_mean and latents_std normalization as per Wan2.1 specification. + + Args: + video_tensor: Video tensor (1, C, T, H, W), normalized to [-1, 1] + models: Dict containing 'vae' + device: Device to use + deterministic: If True, use mean instead of sampling + + Returns: + Latent tensor (1, C, T', H', W'), FP16 + """ + vae = models["vae"] + dtype = models.get("dtype", torch.float16) + + video_tensor = video_tensor.to(device=device, dtype=dtype) + + with torch.no_grad(): + latent_dist = vae.encode(video_tensor) + + if deterministic: + video_latents = latent_dist.latent_dist.mean + else: + video_latents = latent_dist.latent_dist.sample() + + # Apply Wan-specific latent normalization + if not hasattr(vae.config, "latents_mean") or not hasattr(vae.config, "latents_std"): + raise ValueError("Wan2.1 VAE requires latents_mean and latents_std in config") + + latents_mean = torch.tensor(vae.config.latents_mean, device=device, dtype=dtype).view(1, -1, 1, 1, 1) + latents_std = torch.tensor(vae.config.latents_std, device=device, dtype=dtype).view(1, -1, 1, 1, 1) + + latents = (video_latents - latents_mean) / latents_std + + return latents.detach().cpu().to(torch.float16) + + def encode_text( + self, + prompt: str, + models: Dict[str, Any], + device: str, + ) -> Dict[str, torch.Tensor]: + """ + Encode text using UMT5. + + Implements the specific padding behavior for Wan: + 1. Tokenize with padding to max_length + 2. Encode with attention mask + 3. Trim embeddings to actual sequence length + 4. Re-pad with zeros to max_sequence_length (226) + + Args: + prompt: Text prompt + models: Dict containing tokenizer and text_encoder + device: Device to use + + Returns: + Dict containing: + - text_embeddings: UMT5 embeddings (1, 226, hidden_dim) + """ + tokenizer = models["tokenizer"] + text_encoder = models["text_encoder"] + + # Clean prompt + prompt = _prompt_clean(prompt) + + # Tokenize + inputs = tokenizer( + prompt, + max_length=self.MAX_SEQUENCE_LENGTH, + padding="max_length", + truncation=True, + return_tensors="pt", + return_attention_mask=True, + ) + inputs = {k: v.to(device) for k, v in inputs.items()} + + # Calculate actual sequence length (excluding padding) + seq_lens = inputs["attention_mask"].gt(0).sum(dim=1).long() + + with torch.no_grad(): + prompt_embeds = text_encoder( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + ).last_hidden_state + + # CRITICAL: Trim to actual length and re-pad with zeros + # This matches the exact behavior in WanPipeline + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(self.MAX_SEQUENCE_LENGTH - u.size(0), u.size(1))]) for u in prompt_embeds], + dim=0, + ) + + return { + "text_embeddings": prompt_embeds.detach().cpu(), + } + + def get_cache_data( + self, + latent: torch.Tensor, + text_encodings: Dict[str, torch.Tensor], + metadata: Dict[str, Any], + ) -> Dict[str, Any]: + """ + Construct cache dictionary for Wan. + + Args: + latent: Encoded latent tensor (1, C, T, H, W) + text_encodings: Dict from encode_text() + metadata: Additional metadata including first_frame + + Returns: + Dict to save with torch.save() or pickle + """ + return { + # Video latent + "video_latents": latent, + # Text embeddings + "text_embeddings": text_encodings["text_embeddings"], + # First frame for image-to-video conditioning + "first_frame": metadata.get("first_frame"), + # Resolution and bucketing info + "original_resolution": metadata.get("original_resolution"), + "bucket_resolution": metadata.get("bucket_resolution"), + "bucket_id": metadata.get("bucket_id"), + "aspect_ratio": metadata.get("aspect_ratio"), + # Video info + "num_frames": metadata.get("num_frames"), + "prompt": metadata.get("prompt"), + "video_path": metadata.get("video_path"), + # Processing settings + "deterministic_latents": metadata.get("deterministic", True), + "model_version": "wan2.1", + "processing_mode": metadata.get("mode", "video"), + "model_type": self.model_type, + } diff --git a/examples/automodel/finetune/hunyuan_t2v_flow.yaml b/examples/automodel/finetune/hunyuan_t2v_flow.yaml index 70da74fc..8382d683 100644 --- a/examples/automodel/finetune/hunyuan_t2v_flow.yaml +++ b/examples/automodel/finetune/hunyuan_t2v_flow.yaml @@ -11,6 +11,11 @@ optim: weight_decay: 0.01 betas: [0.9, 0.999] +lr_scheduler: + lr_decay_style: cosine + lr_warmup_steps: 0 + min_lr: 1e-6 + fsdp: dp_size: 8 dp_replicate_size: 1 diff --git a/examples/automodel/finetune/wan2_1_t2v_flow.yaml b/examples/automodel/finetune/wan2_1_t2v_flow.yaml index fa3ca082..b8856658 100644 --- a/examples/automodel/finetune/wan2_1_t2v_flow.yaml +++ b/examples/automodel/finetune/wan2_1_t2v_flow.yaml @@ -33,6 +33,11 @@ optim: weight_decay: 0.01 betas: [0.9, 0.999] +lr_scheduler: + lr_decay_style: cosine + lr_warmup_steps: 0 + min_lr: 1e-6 + # Flow matching V2 configuration flow_matching: adapter_type: "simple" # Options: "hunyuan", "simple" diff --git a/examples/automodel/finetune/wan2_1_t2v_flow_multinode.yaml b/examples/automodel/finetune/wan2_1_t2v_flow_multinode.yaml index 76c88bfd..866ce354 100644 --- a/examples/automodel/finetune/wan2_1_t2v_flow_multinode.yaml +++ b/examples/automodel/finetune/wan2_1_t2v_flow_multinode.yaml @@ -34,6 +34,11 @@ optim: weight_decay: 0.01 betas: [0.9, 0.999] +lr_scheduler: + lr_decay_style: cosine + lr_warmup_steps: 0 + min_lr: 1e-6 + # Flow matching V2 configuration flow_matching: adapter_type: "simple" # Options: "hunyuan", "simple" diff --git a/examples/automodel/pretrain/cicd/wan21_cicd_nightly_image.yaml b/examples/automodel/pretrain/cicd/wan21_cicd_nightly_image.yaml index 80c66fcf..8c1fb34a 100644 --- a/examples/automodel/pretrain/cicd/wan21_cicd_nightly_image.yaml +++ b/examples/automodel/pretrain/cicd/wan21_cicd_nightly_image.yaml @@ -40,6 +40,11 @@ optim: weight_decay: 0.1 betas: [0.9, 0.95] +lr_scheduler: + lr_decay_style: cosine + lr_warmup_steps: 0 + min_lr: 1e-6 + fsdp: tp_size: 1 cp_size: 1 diff --git a/examples/automodel/pretrain/cicd/wan21_cicd_nightly_video.yaml b/examples/automodel/pretrain/cicd/wan21_cicd_nightly_video.yaml index 0238752d..248fdbd4 100644 --- a/examples/automodel/pretrain/cicd/wan21_cicd_nightly_video.yaml +++ b/examples/automodel/pretrain/cicd/wan21_cicd_nightly_video.yaml @@ -40,6 +40,11 @@ optim: weight_decay: 0.1 betas: [0.9, 0.95] +lr_scheduler: + lr_decay_style: cosine + lr_warmup_steps: 0 + min_lr: 1e-6 + fsdp: tp_size: 1 cp_size: 1 diff --git a/examples/automodel/pretrain/flux_t2i_flow.yaml b/examples/automodel/pretrain/flux_t2i_flow.yaml index f444b7a2..ea5be3f9 100644 --- a/examples/automodel/pretrain/flux_t2i_flow.yaml +++ b/examples/automodel/pretrain/flux_t2i_flow.yaml @@ -17,6 +17,11 @@ optim: weight_decay: 0.01 betas: [0.9, 0.999] +lr_scheduler: + lr_decay_style: cosine + lr_warmup_steps: 0 + min_lr: 1e-6 + fsdp: dp_size: 8 tp_size: 1 diff --git a/examples/automodel/pretrain/wan2_1_t2v_flow.yaml b/examples/automodel/pretrain/wan2_1_t2v_flow.yaml index 2f9ff18c..1d9923f3 100644 --- a/examples/automodel/pretrain/wan2_1_t2v_flow.yaml +++ b/examples/automodel/pretrain/wan2_1_t2v_flow.yaml @@ -37,8 +37,11 @@ optim: optimizer: weight_decay: 0.1 betas: [0.9, 0.95] - # "warmup_steps": 1000, - # "lr_min": 1e-5, + +lr_scheduler: + lr_decay_style: cosine + lr_warmup_steps: 0 + min_lr: 1e-6 flow_matching: