diff --git a/dfm/src/automodel/flow_matching/adapters/flux.py b/dfm/src/automodel/flow_matching/adapters/flux.py
index 4d05f464..cdac6afa 100644
--- a/dfm/src/automodel/flow_matching/adapters/flux.py
+++ b/dfm/src/automodel/flow_matching/adapters/flux.py
@@ -174,7 +174,9 @@ def prepare_inputs(self, context: FlowMatchingContext) -> Dict[str, Any]:
 
         # The pipeline provides timesteps in [0, num_train_timesteps]
         timesteps = context.timesteps.to(dtype) / 1000.0
-        guidance = torch.full((batch_size,), 3.5, device=device, dtype=torch.float32)
+        # TODO: the guidance scale differs between pretraining and finetuning; we need to pass it in as a hyperparameter.
+        # TODO(Pranav): verify the guidance values used for each mode.
+        guidance = torch.full((batch_size,), self.guidance_scale, device=device, dtype=torch.float32)
 
         inputs = {
             "hidden_states": packed_latents,
diff --git a/examples/automodel/finetune/flux_t2i_flow.yaml b/examples/automodel/finetune/flux_t2i_flow.yaml
new file mode 100644
index 00000000..6618eeec
--- /dev/null
+++ b/examples/automodel/finetune/flux_t2i_flow.yaml
@@ -0,0 +1,85 @@
+model:
+  pretrained_model_name_or_path: "black-forest-labs/FLUX.1-dev"
+  mode: "finetune"
+  cache_dir: null
+  attention_backend: "_flash_3_hub"
+
+  pipeline_spec:
+    transformer_cls: "FluxTransformer2DModel"
+    subfolder: "transformer"
+    load_full_pipeline: false
+    enable_gradient_checkpointing: false
+
+optim:
+  learning_rate: 1e-5
+
+  optimizer:
+    weight_decay: 0.01
+    betas: [0.9, 0.999]
+
+# Adjust dp_size to the total number of GPUs.
+fsdp:
+  dp_size: 8
+  tp_size: 1
+  cp_size: 1
+  pp_size: 1
+  activation_checkpointing: false
+  cpu_offload: false
+
+flow_matching:
+  adapter_type: "flux"
+  adapter_kwargs:
+    # Critical: use a 3.5 guidance scale for finetuning.
+    guidance_scale: 3.5
+    use_guidance_embeds: true
+    timestep_sampling: "logit_normal"
+    logit_mean: 0.0
+    logit_std: 1.0
+    flow_shift: 3.0
+    mix_uniform_ratio: 0.1
+    sigma_min: 0.0
+    sigma_max: 1.0
+    num_train_timesteps: 1000
+    i2v_prob: 0.0
+    use_loss_weighting: true
log_interval: 100 + summary_log_interval: 10 + +step_scheduler: + num_epochs: 5000 + local_batch_size: 1 + global_batch_size: 8 + ckpt_every_steps: 2000 + log_every: 1 + +data: + dataloader: + _target_: dfm.src.automodel.datasets.multiresolutionDataloader.build_flux_multiresolution_dataloader + cache_dir: PATH_TO_YOUR_DATA + train_text_encoder: false + num_workers: 10 + # Supported resolutions include [256×256], [512×512], and [1024×1024]. + # While a 1:1 aspect ratio is currently used as a proxy for the closest image size, + # the implementation is designed to support multiple aspect ratios. + base_resolution: [512, 512] + dynamic_batch_size: false + shuffle: true + drop_last: false + +checkpoint: + enabled: true + checkpoint_dir: PATH_TO_YOUR_CKPT_DIR + model_save_format: torch_save + save_consolidated: false + restore_from: null + +wandb: + project: flux-finetuning + mode: online + name: flux_pretrain_ddp_test_run_1 + +dist_env: + backend: "nccl" + init_method: "env://" + +seed: 42 diff --git a/examples/automodel/generate/flux_generate.py b/examples/automodel/generate/flux_generate.py new file mode 100644 index 00000000..187aebdf --- /dev/null +++ b/examples/automodel/generate/flux_generate.py @@ -0,0 +1,277 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +FLUX Inference Script with Multi-Resolution Dataloader (Embedding Injection) + +This script loads a FLUX transformer and runs inference by extracting +pre-computed text embeddings directly from the multiresolution dataloader. +""" + +import argparse +import logging +import os +import random +from pathlib import Path + +import numpy as np +import torch +from diffusers import FluxPipeline + +# Import the provided dataloader builder +from dfm.src.automodel.datasets.multiresolutionDataloader import build_flux_multiresolution_dataloader + + +def parse_args(): + parser = argparse.ArgumentParser(description="FLUX Inference with pre-computed embeddings") + + parser.add_argument( + "--model_id", + type=str, + default="black-forest-labs/FLUX.1-dev", + help="Base FLUX model ID from HuggingFace", + ) + parser.add_argument( + "--checkpoint", + type=str, + default=None, + help="Path to checkpoint directory containing model/ subfolder or consolidated weights", + ) + parser.add_argument( + "--use-original", + action="store_true", + help="Use original FLUX model without loading custom checkpoint", + ) + parser.add_argument( + "--data_path", + type=str, + required=True, + help="Path to the dataset cache directory", + ) + parser.add_argument( + "--max_samples", + type=int, + default=5, + help="Maximum number of images to generate", + ) + parser.add_argument( + "--output_dir", + type=str, + default="./inference_outputs", + help="Directory to save generated images", + ) + parser.add_argument( + "--num_inference_steps", + type=int, + default=28, + help="Number of inference steps", + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="Guidance scale", + ) + parser.add_argument( + "--height", + type=int, + default=512, + help="Image height", + ) + parser.add_argument( + "--width", + type=int, + default=512, + help="Image width", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed", + ) + parser.add_argument( + 
"--dtype", + type=str, + default="bfloat16", + choices=["float16", "bfloat16", "float32"], + help="Data type for model", + ) + parser.add_argument( + "--num_workers", + type=int, + default=2, + help="Number of workers for the dataloader", + ) + + return parser.parse_args() + + +def load_sharded_checkpoint(transformer, checkpoint_dir, device="cuda"): + import torch.distributed as dist + from torch.distributed.checkpoint import FileSystemReader + from torch.distributed.checkpoint import load as dist_load + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp import StateDictType + from torch.distributed.fsdp.api import ShardedStateDictConfig + + sharded_dir = os.path.join(checkpoint_dir, "model") + if not os.path.isdir(sharded_dir): + raise FileNotFoundError(f"Model directory not found: {sharded_dir}") + + # Initialize a single-process group if needed + init_dist = False + if not dist.is_initialized(): + os.environ.setdefault("MASTER_ADDR", "localhost") + os.environ.setdefault("MASTER_PORT", "29500") + dist.init_process_group(backend="gloo", rank=0, world_size=1) + init_dist = True + + try: + transformer = transformer.to(device=device, dtype=torch.bfloat16) + fsdp_transformer = FSDP(transformer, use_orig_params=True, device_id=torch.device(device)) + + FSDP.set_state_dict_type( + fsdp_transformer, + StateDictType.SHARDED_STATE_DICT, + state_dict_config=ShardedStateDictConfig(offload_to_cpu=True), + ) + + model_state = fsdp_transformer.state_dict() + dist_load(state_dict=model_state, storage_reader=FileSystemReader(sharded_dir)) + fsdp_transformer.load_state_dict(model_state) + transformer = fsdp_transformer.module + print("[INFO] ✅ Successfully loaded sharded FSDP checkpoint") + finally: + if init_dist: + dist.destroy_process_group() + return transformer + + +def load_consolidated_checkpoint(transformer, checkpoint_path): + print(f"[INFO] Loading consolidated checkpoint from {checkpoint_path}") + state_dict = 
torch.load(checkpoint_path, map_location="cpu") + if "model_state_dict" in state_dict: + state_dict = state_dict["model_state_dict"] + transformer.load_state_dict(state_dict, strict=True) + print("[INFO] ✅ Loaded consolidated checkpoint") + return transformer + + +def main(): + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + args = parse_args() + + if args.seed is not None: + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32} + torch_dtype = dtype_map[args.dtype] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # --- 1. Initialize Dataloader --- + print("=" * 80) + print(f"Initializing Multiresolution Dataloader: {args.data_path}") + + dataloader, _ = build_flux_multiresolution_dataloader( + cache_dir=args.data_path, batch_size=1, num_workers=args.num_workers, dynamic_batch_size=True, shuffle=False + ) + print(f"[INFO] Dataloader ready. Batches: {len(dataloader)}") + + # --- 2. 
Initialize Model --- + use_original = args.use_original or args.checkpoint is None + + print(f"\n[Pipeline] Loading FLUX pipeline from: {args.model_id}") + pipe = FluxPipeline.from_pretrained(args.model_id, torch_dtype=torch_dtype) + + if not use_original: + checkpoint_dir = Path(args.checkpoint) + model_name = checkpoint_dir.name + sharded_dir = checkpoint_dir / "model" + consolidated_path = checkpoint_dir / "consolidated_model.bin" + ema_path = checkpoint_dir / "ema_shadow.pt" + + if ema_path.exists(): + print("[INFO] Loading EMA checkpoint...") + pipe.transformer.load_state_dict(torch.load(ema_path, map_location="cpu")) + elif consolidated_path.exists(): + pipe.transformer = load_consolidated_checkpoint(pipe.transformer, str(consolidated_path)) + elif sharded_dir.exists(): + pipe.transformer = load_sharded_checkpoint(pipe.transformer, str(checkpoint_dir), device=device) + else: + model_name = "original" + + pipe.enable_model_cpu_offload() + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # --- 3. Inference Loop (Injecting Embeddings) --- + print(f"\n[Inference] Generating {args.max_samples} samples using pre-computed embeddings...") + generator = torch.Generator(device="cpu").manual_seed(args.seed) + + count = 0 + for batch_idx, batch in enumerate(dataloader): + if count >= args.max_samples: + break + + try: + # Extract metadata for logging/filenames + current_prompt = batch["metadata"]["prompts"][0] + source_path = batch["metadata"]["image_paths"][0] + + # Extract and move embeddings to device/dtype + # batch['text_embeddings'] corresponds to T5 output + # batch['pooled_prompt_embeds'] corresponds to CLIP pooled output + prompt_embeds = batch["text_embeddings"].to(device=device, dtype=torch_dtype) + pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(device=device, dtype=torch_dtype) + + except (KeyError, IndexError) as e: + print(f"[WARN] Batch {batch_idx} missing required data. Skipping. 
Error: {e}") + continue + + print(f"\n--- Sample {count + 1}/{args.max_samples} ---") + print(f" Source: {os.path.basename(source_path)}") + print(f" Prompt: {current_prompt[:80]}...") + + with torch.no_grad(): + # Pass embeddings directly to bypass internal encoders + output = pipe( + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + height=args.height, + width=args.width, + generator=generator, + ) + + # Save output + image = output.images[0] + safe_prompt = ( + "".join(c if c.isalnum() or c in " _-" else "" for c in current_prompt)[:50].strip().replace(" ", "_") + ) + output_filename = f"flux_{model_name}_sample{count:03d}_{safe_prompt}.png" + image.save(output_dir / output_filename) + print(f" ✅ Saved to: {output_filename}") + + count += 1 + + print("\n" + "=" * 80 + "\nInference complete!\n" + "=" * 80) + + +if __name__ == "__main__": + main() diff --git a/examples/automodel/pretrain/flux_t2i_flow.yaml b/examples/automodel/pretrain/flux_t2i_flow.yaml index ea5be3f9..582bfdea 100644 --- a/examples/automodel/pretrain/flux_t2i_flow.yaml +++ b/examples/automodel/pretrain/flux_t2i_flow.yaml @@ -2,7 +2,6 @@ model: pretrained_model_name_or_path: "black-forest-labs/FLUX.1-dev" mode: "pretrain" cache_dir: null - attention_backend: "_flash_3_hub" pipeline_spec: transformer_cls: "FluxTransformer2DModel" @@ -11,7 +10,7 @@ model: enable_gradient_checkpointing: false optim: - learning_rate: 1e-5 + learning_rate: 2e-5 optimizer: weight_decay: 0.01 @@ -23,7 +22,7 @@ lr_scheduler: min_lr: 1e-6 fsdp: - dp_size: 8 + dp_size: 32 tp_size: 1 cp_size: 1 pp_size: 1 @@ -33,45 +32,50 @@ fsdp: flow_matching: adapter_type: "flux" adapter_kwargs: - guidance_scale: 3.5 - use_guidance_embeds: true - timestep_sampling: "logit_normal" + #Critical: use 1 guidance scale for pretraining + guidance_scale: 1 + use_guidance_embeds: false + timestep_sampling: "uniform" logit_mean: 0.0 
logit_std: 1.0 - flow_shift: 3.0 + flow_shift: 1.0 mix_uniform_ratio: 0.1 sigma_min: 0.0 sigma_max: 1.0 num_train_timesteps: 1000 i2v_prob: 0.0 - use_loss_weighting: true + #Critical: use_loss_weighting needs to be false for pretraining + use_loss_weighting: false log_interval: 100 summary_log_interval: 10 step_scheduler: - num_epochs: 5000 - local_batch_size: 1 - global_batch_size: 8 - ckpt_every_steps: 2000 + num_epochs: 500000 + local_batch_size: 2 + global_batch_size: 64 + ckpt_every_steps: 1000 log_every: 1 data: dataloader: _target_: dfm.src.automodel.datasets.multiresolutionDataloader.build_flux_multiresolution_dataloader - cache_dir: /lustre/fsw/portfolios/coreai/users/pthombre/Automodel/FluxTraining/DFM/FluxData512Full/ + cache_dir: PATH_TO_YOUR_DATA train_text_encoder: false - num_workers: 10 - base_resolution: [512, 512] + num_workers: 1 + # Supported resolutions include [256×256], [512×512], and [1024×1024]. + # While a 1:1 aspect ratio is currently used as a proxy for the closest image size, + # the implementation is designed to support multiple aspect ratios. + base_resolution: [256, 256] dynamic_batch_size: false shuffle: true drop_last: false checkpoint: enabled: true - checkpoint_dir: /lustre/fsw/portfolios/coreai/users/pthombre/Automodel/FluxTraining/DFM/flux_ddp_test/ + checkpoint_dir: PATH_TO_YOUR_CKPT_DIR model_save_format: torch_save save_consolidated: false - restore_from: null + restore_from: wandb: project: flux-pretraining diff --git a/tests/functional_tests/automodel/flux/test_flux_multiresolution_dataloader.py b/tests/functional_tests/automodel/flux/test_flux_multiresolution_dataloader.py new file mode 100644 index 00000000..2a6474a0 --- /dev/null +++ b/tests/functional_tests/automodel/flux/test_flux_multiresolution_dataloader.py @@ -0,0 +1,82 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import os + +from dfm.src.automodel.datasets.multiresolutionDataloader import build_flux_multiresolution_dataloader + + +def test_real_dataloader(cache_path: str): + # Configure logging to see the initialization details + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + if not os.path.exists(cache_path): + print(f"ERROR: Cache directory not found at {cache_path}") + return + + try: + # 1. Initialize the real dataloader + dataloader, sampler = build_flux_multiresolution_dataloader( + cache_dir=cache_path, + batch_size=2, # Small batch for printing + num_workers=2, # Use a couple of workers to test multi-processing + dynamic_batch_size=True, # Test the bucket logic + shuffle=True, + ) + + print("\n" + "=" * 50) + print("DATALOADER LOADED SUCCESSFULLY") + print(f"Total Batches: {len(dataloader)}") + print("=" * 50 + "\n") + + # 2. 
Iterate through the first 2 batches + pathes = [] + for batch_idx, batch in enumerate(dataloader): + # if batch_idx >= 2: # Stop after 2 batches to avoid flooding the console + # break + + print(f"--- Batch {batch_idx} ---") + print(f"Keys in batch: {list(batch.keys())}") + + # Print Tensor Shapes + print(f"Image Latents Shape: {batch['image_latents'].shape} (B, C, H, W)") + + if "text_embeddings" in batch: + print(f"Text Embeds Shape: {batch['text_embeddings']}") + print(f"Pooled Embeds Shape: {batch['pooled_prompt_embeds']}") + + # Print Metadata for the first sample in the batch + metadata = batch.get("metadata", {}) + print("\nSample Metadata (First item in batch):") + print(f" - Prompt: {metadata['prompts'][0][:100]}...") # Truncated + print(f" - Path: {metadata['image_paths'][0]}") + print(f" - Res: {metadata['original_resolution'][0]} -> {metadata['crop_resolution'][0]}") + print(f" - Aspect: {metadata['aspect_ratios'][0]}") + print("-" * 30 + "\n") + pathes.append(metadata["image_paths"][0]) + unique_paths = list(set(pathes)) + print(f"Total paths: {len(pathes)}") + print(f"Unique paths: {len(unique_paths)}") + + except Exception as e: + logging.error(f"Failed to run dataloader: {e}", exc_info=True) + + +if __name__ == "__main__": + # SET YOUR ACTUAL PATH HERE + MY_CACHE_DIR = "/linnanw/Diffuser/FLUX/DATA" + + test_real_dataloader(MY_CACHE_DIR)