From 2b9e811e30e9e786c055fd1488802a3c3ce61ee7 Mon Sep 17 00:00:00 2001
From: "Peter St. John"
Date: Wed, 4 Mar 2026 13:42:06 -0700
Subject: [PATCH] add stop+go tests to llama3 recipe, turn off async checkpointing for fp8

Signed-off-by: Peter St. John
---
 .../recipes/llama3_native_te/checkpoint.py    |   27 +-
 .../recipes/llama3_native_te/perf_logger.py   |    4 +-
 .../llama3_native_te/tests/conftest.py        |   51 +
 .../tests/test_distributed_checkpointing.py   | 1240 +++++------------
 .../tests/test_perf_logger.py                 |    2 +-
 .../recipes/llama3_native_te/train_ddp.py     |    3 +-
 .../recipes/llama3_native_te/train_fsdp2.py   |    2 +-
 .../llama3_native_te/train_fsdp2_cp.py        |    2 +-
 8 files changed, 459 insertions(+), 872 deletions(-)

diff --git a/bionemo-recipes/recipes/llama3_native_te/checkpoint.py b/bionemo-recipes/recipes/llama3_native_te/checkpoint.py
index bd63390350..2dc5d10dcf 100644
--- a/bionemo-recipes/recipes/llama3_native_te/checkpoint.py
+++ b/bionemo-recipes/recipes/llama3_native_te/checkpoint.py
@@ -34,7 +34,9 @@
 from torch.distributed.checkpoint.state_dict_saver import async_save as dcp_async_save
 from torch.distributed.checkpoint.state_dict_saver import save as dcp_save
 from torch.distributed.checkpoint.stateful import Stateful
+from torch.distributed.tensor import DTensor
 from torchdata.stateful_dataloader import StatefulDataLoader
+from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
 
 from distributed_config import DistributedConfig
 
@@ -115,8 +117,20 @@ def load_checkpoint_ddp(
     ckpt_path: str | os.PathLike,
     dist_config: DistributedConfig,
     dataloader: StatefulDataLoader | None = None,
+    weights_only: bool = True,
 ) -> CheckpointOutput:
-    """Load DDP checkpoint."""
+    """Load DDP checkpoint.
+
+    Args:
+        model: The model to load.
+        optimizer: The optimizer to load.
+        scheduler: The LR scheduler to load.
+        ckpt_path: The path to the checkpoint.
+        dist_config: The distributed configuration.
+        dataloader: The dataloader to load.
+        weights_only: Whether to load the checkpoint weights only. We have to set this to False when loading FP8
+            checkpoints, since they contain quantized-tensor objects that the weights-only unpickler rejects.
+    """
     checkpoint_path, _ = get_latest_checkpoint(ckpt_path)
 
     if not checkpoint_path:
@@ -126,7 +140,7 @@ def load_checkpoint_ddp(
     checkpoint = torch.load(
         checkpoint_path / "checkpoint.pt",
         map_location=f"cuda:{dist_config.local_rank}",
-        weights_only=True,
+        weights_only=weights_only,
     )
 
     model.load_state_dict(checkpoint["model"])
@@ -221,6 +235,7 @@ class AppState(Stateful):
     def state_dict(self):
         """Get the state dict for the model, optimizer, scheduler, and step."""
         model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
+        model_state_dict = {k: v for k, v in model_state_dict.items() if not k.endswith("_extra_state")}
         return {
             "model": model_state_dict,
             "optim": optimizer_state_dict,
@@ -236,6 +251,7 @@ def load_state_dict(self, state_dict: dict):
             self.optimizer,
             model_state_dict=state_dict["model"],
             optim_state_dict=state_dict["optim"],
+            options=StateDictOptions(strict=False),
         )
         self.scheduler.load_state_dict(state_dict["scheduler"])
         self.step = state_dict["step"]
@@ -322,6 +338,13 @@ def save_checkpoint_fsdp2(
     checkpoint_path = ckpt_path / f"step_{step}"
     checkpoint_path.mkdir(parents=True, exist_ok=True)
 
+    model_params = (p.to_local() if isinstance(p, DTensor) else p for p in model.parameters())
+    if async_save and any((isinstance(p, QuantizedTensor) for p in model_params)):
+        logger.warning(
+            "Async checkpointing is not supported for FP8 models, falling back to synchronous checkpointing."
+        )
+        async_save = False
+
     if dataloader is not None:
         save_dataloader(
             dataloader=dataloader,
diff --git a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py
index 3d6f63f256..726eb19e8e 100644
--- a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py
+++ b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py
@@ -44,7 +44,7 @@ class PerfLogger:
         min_loss: The minimum loss seen so far.
     """
 
-    def __init__(self, dist_config: DistributedConfig, args: DictConfig):
+    def __init__(self, dist_config: DistributedConfig, args: DictConfig, start_step: int):
        """Initialize the logger."""
         self._dist_config = dist_config
         self._run_config = OmegaConf.to_container(args, resolve=True, throw_on_missing=True)
@@ -75,7 +75,7 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig):
         if self._dist_config.is_main_process():
             # Log the entire args object to wandb for experiment tracking and reproducibility.
             self._wandb_run = wandb.init(**args.wandb, config=self._run_config)
-            self._progress_bar = tqdm(total=args.num_train_steps, desc="Training")
+            self._progress_bar = tqdm(initial=start_step, total=args.num_train_steps, desc="Training")
 
         if args.profiler.enabled:
             self._profiler = NsightProfiler(
diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py b/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py
index 87ba309ad7..08330b12f7 100644
--- a/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py
+++ b/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py
@@ -19,6 +19,7 @@
 
 import pytest
 import torch
+from transformer_engine.pytorch import fp8 as te_fp8
 
 sys.path.append(Path(__file__).parent.parent.as_posix())
 
@@ -61,6 +62,56 @@ def pytest_collection_modifyitems(items):
     items[:] = stats_tests + other_tests
 
 
+# ---------------------------------------------------------------------------
+# FP8 recipe parametrization
+# ---------------------------------------------------------------------------
+
+# Each entry: (recipe_class_name, hydra_overrides, check_fn)
+_FP8_RECIPE_CONFIGS = [
+    (
+        "DelayedScaling",
+        ["fp8_config.fp8_recipe=transformer_engine.common.recipe.DelayedScaling"],
+        te_fp8.check_fp8_support,
+    ),
+    (
+        "Float8CurrentScaling",
+        ["fp8_config.fp8_recipe=transformer_engine.common.recipe.Float8CurrentScaling"],
+        te_fp8.check_fp8_support,
+    ),
+    (
+        "Float8BlockScaling",
+        ["fp8_config.fp8_recipe=transformer_engine.common.recipe.Float8BlockScaling"],
+        te_fp8.check_fp8_block_scaling_support,
+    ),
+    (
+        "MXFP8BlockScaling",
+        ["fp8_config.fp8_recipe=transformer_engine.common.recipe.MXFP8BlockScaling"],
+        te_fp8.check_mxfp8_support,
+    ),
+]
+
+
+def _parametrize_fp8_recipes():
+    """Generate pytest.param objects with xfail marks for unsupported FP8 recipes."""
+    params = []
+    for name, overrides, check_fn in _FP8_RECIPE_CONFIGS:
+        supported, reason = check_fn()
+        params.append(
+            pytest.param(
+                overrides,
+                id=name,
+                marks=pytest.mark.xfail(condition=not supported, reason=reason),
+            )
+        )
+    return params
+
+
+@pytest.fixture(params=_parametrize_fp8_recipes())
+def fp_recipe(request):
+    """Parametrized fixture providing FP8 recipe Hydra overrides for each supported TE recipe."""
+    return request.param
+
+
 @pytest.fixture(scope="session", autouse=True)
 def device_mesh():
     """Create a re-usable torch process group for testing.
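The conditional-xfail pattern used by the fixture above can also be exercised in isolation. The following is a minimal, hypothetical sketch (not part of the recipe), assuming only that TE's check_fp8_support() returns a (supported, reason) tuple as it does in the conftest changes above:

# sketch_fp8_xfail.py -- hypothetical standalone illustration of the
# xfail-parametrization pattern from conftest.py above.
import pytest
from transformer_engine.pytorch import fp8 as te_fp8

# check_fp8_support() reports whether the local GPU supports FP8 and, if not, why.
_supported, _reason = te_fp8.check_fp8_support()


@pytest.mark.parametrize(
    "override",
    [
        pytest.param(
            "fp8_config.fp8_recipe=transformer_engine.common.recipe.DelayedScaling",
            id="DelayedScaling",
            # Mark the case as expected-to-fail (rather than skipping) on unsupported hardware.
            marks=pytest.mark.xfail(condition=not _supported, reason=_reason),
        )
    ],
)
def test_override_is_wellformed(override):
    # Stand-in assertion; a real test would feed the override into Hydra's compose().
    assert override.startswith("fp8_config.fp8_recipe=")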
diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py index 098223291d..94504095ee 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py @@ -19,11 +19,13 @@ distributed training configurations: - DDP (Distributed Data Parallel) with 1 and 2 processes - FSDP2 (PyTorch native Fully Sharded Data Parallel v2) with 1 and 2 processes +- FSDP2 with context parallelism +- FP8 quantized model init with checkpoint save/resume Test Strategy: 1. Phase 1: Train for N steps and save checkpoint 2. Phase 2: Resume training from checkpoint and continue -3. Validate: Checkpoints created, resuming works, training continues seamlessly +3. Validate: Checkpoints created, resuming works, training continues seamlessly, losses are valid Each test uses temporary directories and disables wandb logging for isolation. """ @@ -50,965 +52,390 @@ ) -def test_checkpoint_save_and_load_single_process_ddp(recipe_path, tmp_path): - """Test checkpoint save/resume functionality for DDP with single process. - - This test validates: - - DDP creates single-file checkpoints (checkpoint.pt files) - - Standard PyTorch checkpoint format (model + optimizer state) - - Single-process DDP training and resuming works correctly - - Checkpoint files contain complete model state - - Process: - 1. Train 10 steps (0-9), save checkpoint file at step 5 - 2. Resume training from checkpoint, continue to step 15 - 3. Verify checkpoint files exist at steps 5 and 10 - """ - temp_dir = str(tmp_path / "test_ckpt_ddp") - - # Phase 1: Train for 10 steps, saving a checkpoint at step 5 - with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): - phase1_config = compose( - config_name="L0_sanity", - overrides=[ - f"checkpoint.ckpt_dir={temp_dir}", - f"+wandb.dir={tmp_path}", - "num_train_steps=10", - "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=false", # Start fresh - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - ], - ) +# --------------------------------------------------------------------------- +# Test Utilities +# --------------------------------------------------------------------------- - main_ddp(phase1_config) - gc.collect() - torch.cuda.empty_cache() - # Phase 1 creates this directory structure: - # ckpt_subdir/ - # └── step_5/ - # ├── checkpoint.pt - # └── dataloader_step_5_rank_0_num_workers_1.pt +def _compose_config(recipe_path, tmp_path, config_name, overrides): + """Compose a Hydra config with standard checkpoint-test settings. - # Checkpoints are saved in a subdirectory named after the script - ckpt_subdir = os.path.join(temp_dir, "train_ddp") - assert os.path.exists(ckpt_subdir), f"Checkpoint subdirectory {ckpt_subdir} not created" + Every config gets ``checkpoint.ckpt_dir``, ``+wandb.dir``, and + ``dataset.use_stateful_dataloader`` set automatically so that callers + only need to supply test-specific overrides. 
+ """ + ckpt_dir = str(tmp_path / "ckpt") + base = [ + f"checkpoint.ckpt_dir={ckpt_dir}", + f"+wandb.dir={tmp_path}", + "dataset.use_stateful_dataloader=true", + ] + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + return compose(config_name=config_name, overrides=base + list(overrides or [])) - # Verify step_5 checkpoint was created - step_5_dir = os.path.join(ckpt_subdir, "step_5") - # Check step_5 directory exists and contains expected files - assert os.path.isdir(step_5_dir), f"Step 5 directory not found: {step_5_dir}" - step_5_files = os.listdir(step_5_dir) - assert len(step_5_files) == 2, f"Expected 2 files in step_5 directory, found {len(step_5_files)}: {step_5_files}" - assert "checkpoint.pt" in step_5_files, f"checkpoint.pt not found in step_5 directory. Files found: {step_5_files}" - assert any("dataloader" in f for f in step_5_files), ( - f"No dataloader file found in step_5 directory. Files found: {step_5_files}" - ) +def _assert_loss_valid(loss, label=""): + """Assert that a training loss is finite and not NaN.""" + tag = f" ({label})" if label else "" + assert loss is not None, f"Loss is None{tag}" + loss_val = float(loss) + assert not torch.isnan(torch.tensor(loss_val)), f"Loss is NaN{tag}" + assert torch.isfinite(torch.tensor(loss_val)), f"Loss is not finite: {loss_val}{tag}" - # Verify the actual checkpoint files are valid files - assert os.path.isfile(os.path.join(step_5_dir, "checkpoint.pt")), "step_5/checkpoint.pt is not a valid file" - # Check that only step_5 exists at this point (no step_10 yet) - all_step_dirs = [d for d in os.listdir(ckpt_subdir) if d.startswith("step_")] - assert len(all_step_dirs) == 1, ( - f"Expected only 1 checkpoint directory after phase 1, found {len(all_step_dirs)}: {all_step_dirs}" - ) - assert all_step_dirs[0] == "step_5", f"Expected only step_5 after phase 1, found: {all_step_dirs}" +def _assert_checkpoint_step(ckpt_subdir, step, num_ranks=1, is_ddp=True): + """Assert that a checkpoint step directory has the expected files. - # Phase 2: Resume training (should start from step 5, continue to step 15) - with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): - phase2_config = compose( - config_name="L0_sanity", - overrides=[ - f"checkpoint.ckpt_dir={temp_dir}", - f"+wandb.dir={tmp_path}", - "num_train_steps=15", - "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=true", # Resume from checkpoint - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - ], + For DDP checks for ``checkpoint.pt`` and exact file counts. + For FSDP2 (DCP format) only checks for per-rank dataloader files. 
+ """ + step_dir = os.path.join(ckpt_subdir, f"step_{step}") + assert os.path.isdir(step_dir), f"Step {step} directory not found: {step_dir}" + files = os.listdir(step_dir) + + if is_ddp: + expected_count = 1 + num_ranks # checkpoint.pt + one dataloader per rank + assert len(files) == expected_count, ( + f"Expected {expected_count} files in step_{step}, found {len(files)}: {files}" ) + assert "checkpoint.pt" in files, f"checkpoint.pt not in step_{step}: {files}" + assert os.path.isfile(os.path.join(step_dir, "checkpoint.pt")) - main_ddp(phase2_config) - gc.collect() - torch.cuda.empty_cache() - - # Phase 2 adds to the directory structure: - # ckpt_subdir/ - # ├── step_5/ - # │ ├── checkpoint.pt - # │ └── dataloader_step_5_rank_0_num_workers_1.pt - # └── step_10/ - # ├── checkpoint.pt - # └── dataloader_step_10_rank_0_num_workers_1.pt - - # Verify the checkpoint files exist in the correct directories - step_5_dir = os.path.join(ckpt_subdir, "step_5") - step_10_dir = os.path.join(ckpt_subdir, "step_10") - - # Check step_5 directory and files - assert os.path.isdir(step_5_dir), f"Step 5 directory not found: {step_5_dir}" - step_5_files = os.listdir(step_5_dir) - assert "checkpoint.pt" in step_5_files, f"checkpoint.pt not found in step_5 directory. Files found: {step_5_files}" - assert any("dataloader" in f for f in step_5_files), ( - f"No dataloader file found in step_5 directory. Files found: {step_5_files}" - ) - - # Check step_10 directory and files - assert os.path.isdir(step_10_dir), f"Step 10 directory not found: {step_10_dir}" - step_10_files = os.listdir(step_10_dir) - assert "checkpoint.pt" in step_10_files, ( - f"checkpoint.pt not found in step_10 directory. Files found: {step_10_files}" - ) - assert any("dataloader" in f for f in step_10_files), ( - f"No dataloader file found in step_10 directory. Files found: {step_10_files}" + dataloader_files = [f for f in files if "dataloader" in f] + assert len(dataloader_files) >= num_ranks, ( + f"Expected >= {num_ranks} dataloader files in step_{step}, found {len(dataloader_files)}: {dataloader_files}" ) + for rank in range(num_ranks): + assert any(f"rank_{rank}" in f for f in dataloader_files), ( + f"No dataloader file for rank {rank} in step_{step}: {dataloader_files}" + ) - # Verify the actual checkpoint files are valid files - assert os.path.isfile(os.path.join(step_5_dir, "checkpoint.pt")), "step_5/checkpoint.pt is not a valid file" - assert os.path.isfile(os.path.join(step_10_dir, "checkpoint.pt")), "step_10/checkpoint.pt is not a valid file" - # Final check: we should have exactly 2 checkpoint directories (step_5 and step_10) - all_step_dirs = [d for d in os.listdir(ckpt_subdir) if d.startswith("step_")] - assert len(all_step_dirs) == 2, f"Expected 2 checkpoint directories, found {len(all_step_dirs)}: {all_step_dirs}" - assert set(all_step_dirs) == {"step_5", "step_10"}, f"Expected step_5 and step_10, found: {all_step_dirs}" +def _run_single_process_checkpoint_test( + recipe_path, + tmp_path, + main_fn, + ckpt_subdir_name, + config_name="L0_sanity", + extra_overrides=None, + is_ddp=True, +): + """Run a two-phase checkpoint save/resume test in a single process. + Phase 1 trains for 10 steps (saving at step 5), phase 2 resumes and + continues to step 15 (saving at step 10). Both phases validate that + checkpoints are created correctly and that losses are finite. -@requires_multi_gpu -def test_checkpoint_save_and_load_two_processes_ddp(recipe_path, tmp_path): - """Test checkpoint save/resume functionality for DDP with two processes. 
- - This test validates: - - Multi-process DDP checkpoint behavior (main process saves only) - - Checkpoint files can be loaded by all DDP processes - - Process synchronization during resume (all processes load same checkpoint) - - DDP training continues correctly after resume across processes - - Process: - 1. Train 10 steps (0-9) across 2 processes, main process saves checkpoint at step 5 - 2. Resume training with 2 processes, all load same checkpoint file, continue to step 15 - 3. Verify checkpoint files exist at steps 5 and 10 + Returns: + Tuple of (phase1_loss, phase2_loss). """ - temp_dir = str(tmp_path / "test_ckpt_ddp_2p") - - # Set environment for subprocess - env = os.environ.copy() - env["WANDB_MODE"] = "disabled" - - # Get the full path to train_ddp.py - train_script = recipe_path / "train_ddp.py" - - # Phase 1: Train for 10 steps with 2 processes - cmd_phase1 = [ - "torchrun", - "--standalone", - "--nproc_per_node=2", - str(train_script), - f"checkpoint.ckpt_dir={temp_dir}", - "num_train_steps=10", + ckpt_dir = str(tmp_path / "ckpt") + common = [ "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=false", # Start fresh - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing + "checkpoint.async_save=false", + *(extra_overrides or []), ] - result1 = subprocess.run(cmd_phase1, check=False, capture_output=True, text=True, env=env) - assert result1.returncode == 0, f"Phase 1 failed: {result1.stderr}" - - # Phase 1 creates this directory structure with 2 processes: - # ckpt_subdir/ - # └── step_5/ - # ├── checkpoint.pt - # ├── dataloader_step_5_rank_0_num_workers_1.pt - # └── dataloader_step_5_rank_1_num_workers_1.pt - - # Checkpoints are saved in a subdirectory named after the script - ckpt_subdir = os.path.join(temp_dir, "train_ddp") - assert os.path.exists(ckpt_subdir), f"Checkpoint subdirectory {ckpt_subdir} not created" - - # Verify step_5 checkpoint was created - step_5_dir = os.path.join(ckpt_subdir, "step_5") - - # Check step_5 directory exists and contains expected files - assert os.path.isdir(step_5_dir), f"Step 5 directory not found: {step_5_dir}" - step_5_files = os.listdir(step_5_dir) - - # With 2 processes, we expect: 1 checkpoint.pt + 2 dataloader files (one per rank) - assert len(step_5_files) == 3, ( - f"Expected 3 files in step_5 directory (1 checkpoint + 2 dataloaders), found {len(step_5_files)}: {step_5_files}" + # Phase 1: train 10 steps, checkpoint at step 5 + cfg1 = _compose_config( + recipe_path, + tmp_path, + config_name, + [ + "num_train_steps=10", + "checkpoint.resume_from_checkpoint=false", + *common, + ], ) - assert "checkpoint.pt" in step_5_files, f"checkpoint.pt not found in step_5 directory. Files found: {step_5_files}" - # Check for dataloader files for both ranks - dataloader_files = [f for f in step_5_files if "dataloader" in f] - assert len(dataloader_files) == 2, ( - f"Expected 2 dataloader files (rank 0 and 1), found {len(dataloader_files)}: {dataloader_files}" - ) - - # Verify we have dataloader files for both rank 0 and rank 1 - assert any("rank_0" in f for f in dataloader_files), ( - f"No dataloader file for rank 0 found. Files: {dataloader_files}" - ) - assert any("rank_1" in f for f in dataloader_files), ( - f"No dataloader file for rank 1 found. 
Files: {dataloader_files}" - ) - - # Verify the actual checkpoint file is valid - assert os.path.isfile(os.path.join(step_5_dir, "checkpoint.pt")), "step_5/checkpoint.pt is not a valid file" - - # Check that only step_5 exists at this point (no step_10 yet) - all_step_dirs = [d for d in os.listdir(ckpt_subdir) if d.startswith("step_")] - assert len(all_step_dirs) == 1, ( - f"Expected only 1 checkpoint directory after phase 1, found {len(all_step_dirs)}: {all_step_dirs}" - ) - assert all_step_dirs[0] == "step_5", f"Expected only step_5 after phase 1, found: {all_step_dirs}" - - # Phase 2: Resume training with 2 processes - cmd_phase2 = [ - "torchrun", - "--standalone", - "--nproc_per_node=2", - str(train_script), - f"checkpoint.ckpt_dir={temp_dir}", - "num_train_steps=15", - "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=true", # Resume from checkpoint - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - ] - - result2 = subprocess.run(cmd_phase2, check=False, capture_output=True, text=True, env=env) - assert result2.returncode == 0, f"Phase 2 failed: {result2.stderr}" + loss1 = main_fn(cfg1) + gc.collect() + torch.cuda.empty_cache() - # Phase 2 adds to the directory structure: - # ckpt_subdir/ - # ├── step_5/ - # │ ├── checkpoint.pt - # │ ├── dataloader_step_5_rank_0_num_workers_1.pt - # │ └── dataloader_step_5_rank_1_num_workers_1.pt - # └── step_10/ - # ├── checkpoint.pt - # ├── dataloader_step_10_rank_0_num_workers_1.pt - # └── dataloader_step_10_rank_1_num_workers_1.pt - - # Verify step_10 checkpoint was created - step_10_dir = os.path.join(ckpt_subdir, "step_10") - - # Check step_10 directory exists and contains expected files - assert os.path.isdir(step_10_dir), f"Step 10 directory not found: {step_10_dir}" - step_10_files = os.listdir(step_10_dir) - - # With 2 processes, we expect: 1 checkpoint.pt + 2 dataloader files (one per rank) - assert len(step_10_files) == 3, ( - f"Expected 3 files in step_10 directory (1 checkpoint + 2 dataloaders), found {len(step_10_files)}: {step_10_files}" - ) - assert "checkpoint.pt" in step_10_files, ( - f"checkpoint.pt not found in step_10 directory. Files found: {step_10_files}" - ) + ckpt_subdir = os.path.join(ckpt_dir, ckpt_subdir_name) + assert os.path.exists(ckpt_subdir), f"Checkpoint subdirectory {ckpt_subdir} not created" + _assert_checkpoint_step(ckpt_subdir, 5, num_ranks=1, is_ddp=is_ddp) - # Check for dataloader files for both ranks - dataloader_files_10 = [f for f in step_10_files if "dataloader" in f] - assert len(dataloader_files_10) == 2, ( - f"Expected 2 dataloader files (rank 0 and 1), found {len(dataloader_files_10)}: {dataloader_files_10}" - ) + step_dirs = sorted(d for d in os.listdir(ckpt_subdir) if d.startswith("step_")) + assert step_dirs == ["step_5"], f"Expected only step_5 after phase 1, found: {step_dirs}" - # Verify we have dataloader files for both rank 0 and rank 1 - assert any("rank_0" in f for f in dataloader_files_10), ( - f"No dataloader file for rank 0 found in step_10. Files: {dataloader_files_10}" + # Phase 2: resume and continue to step 15, checkpoint at step 10 + cfg2 = _compose_config( + recipe_path, + tmp_path, + config_name, + [ + "num_train_steps=15", + "checkpoint.resume_from_checkpoint=true", + *common, + ], ) - assert any("rank_1" in f for f in dataloader_files_10), ( - f"No dataloader file for rank 1 found in step_10. 
Files: {dataloader_files_10}" - ) - - # Verify the actual checkpoint file is valid - assert os.path.isfile(os.path.join(step_10_dir, "checkpoint.pt")), "step_10/checkpoint.pt is not a valid file" - # Final check: we should have exactly 2 checkpoint directories (step_5 and step_10) - all_step_dirs = [d for d in os.listdir(ckpt_subdir) if d.startswith("step_")] - assert len(all_step_dirs) == 2, f"Expected 2 checkpoint directories, found {len(all_step_dirs)}: {all_step_dirs}" - assert set(all_step_dirs) == {"step_5", "step_10"}, f"Expected step_5 and step_10, found: {all_step_dirs}" - - -def test_checkpoint_save_and_load_single_process_fsdp2(recipe_path, tmp_path): - """Test checkpoint save/resume functionality for FSDP2 with single process. - - This test validates: - - FSDP2 creates distributed checkpoints (step_X directories by default) - - Each rank saves its shard (even with single process) - - Dataloader state is saved alongside model checkpoint - - Training can resume from latest checkpoint and continue - - Resume starts from correct step count - - Process: - 1. Train 10 steps (0-9), save checkpoint at step 5 - 2. Resume training from step 5, continue to step 15 - 3. Verify checkpoints exist at steps 5 and 10 - """ - temp_dir = str(tmp_path / "test_ckpt_fsdp2") - - # Phase 1: Train for 10 steps (using distributed checkpoint by default) - with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): - phase1_config = compose( - config_name="L0_sanity", - overrides=[ - f"checkpoint.ckpt_dir={temp_dir}", - f"+wandb.dir={tmp_path}", - "num_train_steps=10", - "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=false", # Start fresh - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - "checkpoint.async_save=false", - ], - ) - - main_fsdp2(phase1_config) + loss2 = main_fn(cfg2) gc.collect() torch.cuda.empty_cache() - # Checkpoints are saved in a subdirectory named after the script - ckpt_subdir = os.path.join(temp_dir, "train_fsdp2") - assert os.path.exists(ckpt_subdir), f"Checkpoint subdirectory {ckpt_subdir} not created" + _assert_checkpoint_step(ckpt_subdir, 5, num_ranks=1, is_ddp=is_ddp) + _assert_checkpoint_step(ckpt_subdir, 10, num_ranks=1, is_ddp=is_ddp) - # Verify checkpoint was created (FSDP2 creates directories by default) - checkpoint_dirs = [ - d for d in os.listdir(ckpt_subdir) if d.startswith("step_") and os.path.isdir(os.path.join(ckpt_subdir, d)) - ] - assert len(checkpoint_dirs) > 0, "No checkpoint directories created in phase 1" + step_dirs = sorted(d for d in os.listdir(ckpt_subdir) if d.startswith("step_")) + assert set(step_dirs) == {"step_5", "step_10"}, f"Expected step_5 and step_10, found: {step_dirs}" - # Check that checkpoint at step 5 exists - expected_checkpoint = "step_5" - assert expected_checkpoint in checkpoint_dirs, f"Expected {expected_checkpoint} not found" + # Validate losses are finite and not NaN + _assert_loss_valid(loss1, "phase 1") + _assert_loss_valid(loss2, "phase 2") - # Check dataloader file exists in step_5 directory - step_5_dir = os.path.join(ckpt_subdir, "step_5") - assert os.path.isdir(step_5_dir), f"Step 5 directory not found: {step_5_dir}" - step_5_files = os.listdir(step_5_dir) + return loss1, loss2 - # With single process, we expect dataloader file for rank 0 - dataloader_files_5 = [f for f in step_5_files if "dataloader" in f] - assert len(dataloader_files_5) >= 1, ( - f"Expected at least 1 dataloader file, found {len(dataloader_files_5)}: {dataloader_files_5}" - ) 
- assert any("rank_0" in f for f in dataloader_files_5), ( - f"No dataloader file for rank 0 found in step_5. Files: {dataloader_files_5}" - ) - # Phase 2: Resume training - with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): - phase2_config = compose( - config_name="L0_sanity", - overrides=[ - f"checkpoint.ckpt_dir={temp_dir}", - f"+wandb.dir={tmp_path}", - "num_train_steps=15", - "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=true", # Resume from checkpoint - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - # Sometimes the checkpoint hasn't finished saving by the time we resume training, so we disable async - # save for this test. - "checkpoint.async_save=false", - ], - ) +def _run_multi_process_checkpoint_test( + recipe_path, + tmp_path, + train_script_name, + ckpt_subdir_name, + nproc=2, + extra_overrides=None, + is_ddp=True, +): + """Run a two-phase checkpoint save/resume test using ``torchrun``. - main_fsdp2(phase2_config) - gc.collect() - torch.cuda.empty_cache() - - # Verify phase 2 completed and created additional checkpoints - final_checkpoint_dirs = [ - d for d in os.listdir(ckpt_subdir) if d.startswith("step_") and os.path.isdir(os.path.join(ckpt_subdir, d)) - ] - expected_checkpoints = ["step_5", "step_10"] - for expected in expected_checkpoints: - assert expected in final_checkpoint_dirs, f"Missing checkpoint: {expected}" - - # Check dataloader file exists in step_10 directory - step_10_dir = os.path.join(ckpt_subdir, "step_10") - assert os.path.isdir(step_10_dir), f"Step 10 directory not found: {step_10_dir}" - step_10_files = os.listdir(step_10_dir) - - # With single process, we expect dataloader file for rank 0 - dataloader_files_10 = [f for f in step_10_files if "dataloader" in f] - assert len(dataloader_files_10) >= 1, ( - f"Expected at least 1 dataloader file in step_10, found {len(dataloader_files_10)}: {dataloader_files_10}" - ) - assert any("rank_0" in f for f in dataloader_files_10), ( - f"No dataloader file for rank 0 found in step_10. Files: {dataloader_files_10}" - ) - - -@requires_multi_gpu -def test_checkpoint_save_and_load_two_processes_fsdp2(recipe_path, tmp_path): - """Test checkpoint save/resume functionality for FSDP2 with two processes. - - This test validates: - - Multi-process FSDP2 distributed checkpointing (each rank saves its shard) - - Dataloader state is saved for each rank alongside model checkpoint - - All ranks participate in saving and loading - - Training resumes correctly with proper process synchronization - - Process: - 1. Train 10 steps (0-9) across 2 processes, save checkpoint at step 5 - 2. Resume training with 2 processes from step 5, continue to step 15 - 3. Verify checkpoints exist at steps 5 and 10 with dataloader files for both ranks + Same two-phase strategy as :func:`_run_single_process_checkpoint_test` + but spawns *nproc* processes via ``torchrun``. 
""" - temp_dir = str(tmp_path / "test_ckpt_fsdp2_2p") - - # Set environment for subprocess + ckpt_dir = str(tmp_path / "ckpt") env = os.environ.copy() env["WANDB_MODE"] = "disabled" - # Get the full path to train_fsdp2.py - train_script = recipe_path / "train_fsdp2.py" - - # Phase 1: Train for 10 steps with 2 processes - cmd_phase1 = [ - "torchrun", - "--standalone", - "--nproc_per_node=2", - str(train_script), - f"checkpoint.ckpt_dir={temp_dir}", - "num_train_steps=10", + train_script = recipe_path / train_script_name + common = [ + f"checkpoint.ckpt_dir={ckpt_dir}", "checkpoint.save_every_n_steps=5", - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing + "dataset.use_stateful_dataloader=true", + *(extra_overrides or []), ] - result1 = subprocess.run(cmd_phase1, check=False, capture_output=True, text=True, env=env) + base_cmd = ["torchrun", "--standalone", f"--nproc_per_node={nproc}", str(train_script)] + + # Phase 1 + result1 = subprocess.run( + [*base_cmd, "num_train_steps=10", "checkpoint.resume_from_checkpoint=false", *common], + check=False, + capture_output=True, + text=True, + env=env, + ) assert result1.returncode == 0, f"Phase 1 failed: {result1.stderr}" - # Checkpoints are saved in a subdirectory named after the script - ckpt_subdir = os.path.join(temp_dir, "train_fsdp2") + ckpt_subdir = os.path.join(ckpt_dir, ckpt_subdir_name) assert os.path.exists(ckpt_subdir), f"Checkpoint subdirectory {ckpt_subdir} not created" + _assert_checkpoint_step(ckpt_subdir, 5, num_ranks=nproc, is_ddp=is_ddp) - # Verify checkpoint was created (FSDP2 creates directories by default) - checkpoint_dirs = [ - d for d in os.listdir(ckpt_subdir) if d.startswith("step_") and os.path.isdir(os.path.join(ckpt_subdir, d)) - ] - assert len(checkpoint_dirs) > 0, "No checkpoint directories created in phase 1" - - # Check that checkpoint at step 5 exists - expected_checkpoint = "step_5" - assert expected_checkpoint in checkpoint_dirs, f"Expected {expected_checkpoint} not found" + step_dirs = [d for d in os.listdir(ckpt_subdir) if d.startswith("step_")] + assert len(step_dirs) == 1, f"Expected 1 checkpoint dir after phase 1, found: {step_dirs}" - # Check dataloader files exist in step_5 directory for both ranks - step_5_dir = os.path.join(ckpt_subdir, "step_5") - assert os.path.isdir(step_5_dir), f"Step 5 directory not found: {step_5_dir}" - step_5_files = os.listdir(step_5_dir) - - # With 2 processes, we expect dataloader files for rank 0 and rank 1 - dataloader_files_5 = [f for f in step_5_files if "dataloader" in f] - assert len(dataloader_files_5) == 2, ( - f"Expected 2 dataloader files (rank 0 and 1), found {len(dataloader_files_5)}: {dataloader_files_5}" - ) - assert any("rank_0" in f for f in dataloader_files_5), ( - f"No dataloader file for rank 0 found in step_5. Files: {dataloader_files_5}" - ) - assert any("rank_1" in f for f in dataloader_files_5), ( - f"No dataloader file for rank 1 found in step_5. 
Files: {dataloader_files_5}" + # Phase 2 + result2 = subprocess.run( + [*base_cmd, "num_train_steps=15", "checkpoint.resume_from_checkpoint=true", *common], + check=False, + capture_output=True, + text=True, + env=env, ) - - # Phase 2: Resume training with 2 processes - cmd_phase2 = [ - "torchrun", - "--standalone", - "--nproc_per_node=2", - str(train_script), - f"checkpoint.ckpt_dir={temp_dir}", - "num_train_steps=15", - "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=true", # Resume from checkpoint - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - ] - - result2 = subprocess.run(cmd_phase2, check=False, capture_output=True, text=True, env=env) assert result2.returncode == 0, f"Phase 2 failed: {result2.stderr}" - # Verify phase 2 completed and created additional checkpoints - final_checkpoint_dirs = [ - d for d in os.listdir(ckpt_subdir) if d.startswith("step_") and os.path.isdir(os.path.join(ckpt_subdir, d)) - ] - expected_checkpoints = ["step_5", "step_10"] - for expected in expected_checkpoints: - assert expected in final_checkpoint_dirs, f"Missing checkpoint: {expected}" - - # Check dataloader files exist in step_10 directory for both ranks - step_10_dir = os.path.join(ckpt_subdir, "step_10") - assert os.path.isdir(step_10_dir), f"Step 10 directory not found: {step_10_dir}" - step_10_files = os.listdir(step_10_dir) - - # With 2 processes, we expect dataloader files for rank 0 and rank 1 - dataloader_files_10 = [f for f in step_10_files if "dataloader" in f] - assert len(dataloader_files_10) == 2, ( - f"Expected 2 dataloader files (rank 0 and 1) in step_10, found {len(dataloader_files_10)}: {dataloader_files_10}" - ) - assert any("rank_0" in f for f in dataloader_files_10), ( - f"No dataloader file for rank 0 found in step_10. Files: {dataloader_files_10}" - ) - assert any("rank_1" in f for f in dataloader_files_10), ( - f"No dataloader file for rank 1 found in step_10. Files: {dataloader_files_10}" - ) - - -def test_checkpoint_save_and_load_single_process_fsdp2_with_context_parallelism(recipe_path, tmp_path): - """Test checkpoint save/resume functionality for FSDP2 with single process and context parallelism. - - This test validates: - - FSDP2 creates distributed checkpoints (step_X directories by default) - - Each rank saves its shard (even with single process) - - Dataloader state is saved alongside model checkpoint - - Training can resume from latest checkpoint and continue - - Resume starts from correct step count - - Process: - 1. Train 10 steps (0-9), save checkpoint at step 5 - 2. Resume training from step 5, continue to step 15 - 3. 
Verify checkpoints exist at steps 5 and 10 - """ - temp_dir = str(tmp_path / "test_ckpt_fsdp2_cp") - - # Phase 1: Train for 10 steps (using distributed checkpoint by default) - with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): - phase1_config = compose( - config_name="L0_sanity_cp", - overrides=[ - f"checkpoint.ckpt_dir={temp_dir}", - f"+wandb.dir={tmp_path}", - "num_train_steps=10", - "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=false", # Start fresh - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - "checkpoint.async_save=false", - ], - ) - - main_fsdp2_cp(phase1_config) - gc.collect() - torch.cuda.empty_cache() + _assert_checkpoint_step(ckpt_subdir, 5, num_ranks=nproc, is_ddp=is_ddp) + _assert_checkpoint_step(ckpt_subdir, 10, num_ranks=nproc, is_ddp=is_ddp) - # Checkpoints are saved in a subdirectory named after the script - ckpt_subdir = os.path.join(temp_dir, "train_fsdp2") - assert os.path.exists(ckpt_subdir), f"Checkpoint subdirectory {ckpt_subdir} not created" + step_dirs = sorted(d for d in os.listdir(ckpt_subdir) if d.startswith("step_")) + assert set(step_dirs) == {"step_5", "step_10"}, f"Expected step_5 and step_10, found: {step_dirs}" - # Verify checkpoint was created (FSDP2 creates directories by default) - checkpoint_dirs = [ - d for d in os.listdir(ckpt_subdir) if d.startswith("step_") and os.path.isdir(os.path.join(ckpt_subdir, d)) - ] - assert len(checkpoint_dirs) > 0, "No checkpoint directories created in phase 1" - # Check that checkpoint at step 5 exists - expected_checkpoint = "step_5" - assert expected_checkpoint in checkpoint_dirs, f"Expected {expected_checkpoint} not found" +# --------------------------------------------------------------------------- +# DDP Checkpoint Tests +# --------------------------------------------------------------------------- - # Check dataloader file exists in step_5 directory - step_5_dir = os.path.join(ckpt_subdir, "step_5") - assert os.path.isdir(step_5_dir), f"Step 5 directory not found: {step_5_dir}" - step_5_files = os.listdir(step_5_dir) - # With single process, we expect dataloader file for rank 0 - dataloader_files_5 = [f for f in step_5_files if "dataloader" in f] - assert len(dataloader_files_5) >= 1, ( - f"Expected at least 1 dataloader file, found {len(dataloader_files_5)}: {dataloader_files_5}" - ) - assert any("rank_0" in f for f in dataloader_files_5), ( - f"No dataloader file for rank 0 found in step_5. Files: {dataloader_files_5}" +def test_checkpoint_save_and_load_single_process_ddp(recipe_path, tmp_path): + """Test checkpoint save/resume for DDP with a single process.""" + _run_single_process_checkpoint_test( + recipe_path, + tmp_path, + main_ddp, + ckpt_subdir_name="train_ddp", + is_ddp=True, ) - # Phase 2: Resume training - with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): - phase2_config = compose( - config_name="L0_sanity_cp", - overrides=[ - f"checkpoint.ckpt_dir={temp_dir}", - f"+wandb.dir={tmp_path}", - "num_train_steps=15", - "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=true", # Resume from checkpoint - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - # Sometimes the checkpoint hasn't finished saving by the time we resume training, so we disable async - # save for this test. 
- "checkpoint.async_save=false", - ], - ) - main_fsdp2_cp(phase2_config) - gc.collect() - torch.cuda.empty_cache() - - # Verify phase 2 completed and created additional checkpoints - final_checkpoint_dirs = [ - d for d in os.listdir(ckpt_subdir) if d.startswith("step_") and os.path.isdir(os.path.join(ckpt_subdir, d)) - ] - expected_checkpoints = ["step_5", "step_10"] - for expected in expected_checkpoints: - assert expected in final_checkpoint_dirs, f"Missing checkpoint: {expected}" - - # Check dataloader file exists in step_10 directory - step_10_dir = os.path.join(ckpt_subdir, "step_10") - assert os.path.isdir(step_10_dir), f"Step 10 directory not found: {step_10_dir}" - step_10_files = os.listdir(step_10_dir) - - # With single process, we expect dataloader file for rank 0 - dataloader_files_10 = [f for f in step_10_files if "dataloader" in f] - assert len(dataloader_files_10) >= 1, ( - f"Expected at least 1 dataloader file in step_10, found {len(dataloader_files_10)}: {dataloader_files_10}" - ) - assert any("rank_0" in f for f in dataloader_files_10), ( - f"No dataloader file for rank 0 found in step_10. Files: {dataloader_files_10}" +@requires_multi_gpu +def test_checkpoint_save_and_load_two_processes_ddp(recipe_path, tmp_path): + """Test checkpoint save/resume for DDP with two processes.""" + _run_multi_process_checkpoint_test( + recipe_path, + tmp_path, + "train_ddp.py", + ckpt_subdir_name="train_ddp", + is_ddp=True, ) -@requires_multi_gpu -def test_checkpoint_save_and_load_two_processes_fsdp2_with_context_parallelism(recipe_path, tmp_path): - """Test checkpoint save/resume functionality for FSDP2 with two processes. - - This test validates: - - Multi-process FSDP2 distributed checkpointing (each rank saves its shard) - - Dataloader state is saved for each rank alongside model checkpoint - - All ranks participate in saving and loading - - Training resumes correctly with proper process synchronization - - Process: - 1. Train 10 steps (0-9) across 2 processes with context parallelism, save checkpoint at step 5 - 2. Resume training with 2 processes from step 5, continue to step 15 - 3. 
Verify checkpoints exist at steps 5 and 10 with dataloader files for both ranks - """ - temp_dir = str(tmp_path / "test_ckpt_fsdp2_cp_2p") +# --------------------------------------------------------------------------- +# FSDP2 Checkpoint Tests +# --------------------------------------------------------------------------- - # Set environment for subprocess - env = os.environ.copy() - env["WANDB_MODE"] = "disabled" - # Get the full path to train_fsdp2.py - train_script = recipe_path / "train_fsdp2_cp.py" - - # Phase 1: Train for 10 steps with 2 processes - cmd_phase1 = [ - "torchrun", - "--standalone", - "--nproc_per_node=2", - str(train_script), - f"checkpoint.ckpt_dir={temp_dir}", - "num_train_steps=10", - "checkpoint.save_every_n_steps=5", - "checkpoint.async_save=false", - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - "cp_size=2", - ] +def test_checkpoint_save_and_load_single_process_fsdp2(recipe_path, tmp_path): + """Test checkpoint save/resume for FSDP2 with a single process.""" + _run_single_process_checkpoint_test( + recipe_path, + tmp_path, + main_fsdp2, + ckpt_subdir_name="train_fsdp2", + is_ddp=False, + ) - result1 = subprocess.run(cmd_phase1, check=False, capture_output=True, text=True, env=env) - assert result1.returncode == 0, f"Phase 1 failed: {result1.stderr}" - # Checkpoints are saved in a subdirectory named after the script - ckpt_subdir = os.path.join(temp_dir, "train_fsdp2") - assert os.path.exists(ckpt_subdir), f"Checkpoint subdirectory {ckpt_subdir} not created" +@requires_multi_gpu +def test_checkpoint_save_and_load_two_processes_fsdp2(recipe_path, tmp_path): + """Test checkpoint save/resume for FSDP2 with two processes.""" + _run_multi_process_checkpoint_test( + recipe_path, + tmp_path, + "train_fsdp2.py", + ckpt_subdir_name="train_fsdp2", + is_ddp=False, + ) - # Verify checkpoint was created (FSDP2 creates directories by default) - checkpoint_dirs = [ - d for d in os.listdir(ckpt_subdir) if d.startswith("step_") and os.path.isdir(os.path.join(ckpt_subdir, d)) - ] - assert len(checkpoint_dirs) > 0, "No checkpoint directories created in phase 1" - # Check that checkpoint at step 5 exists - expected_checkpoint = "step_5" - assert expected_checkpoint in checkpoint_dirs, f"Expected {expected_checkpoint} not found" +# --------------------------------------------------------------------------- +# FSDP2 + Context Parallelism Checkpoint Tests +# --------------------------------------------------------------------------- - # Check dataloader files exist in step_5 directory for both ranks - step_5_dir = os.path.join(ckpt_subdir, "step_5") - assert os.path.isdir(step_5_dir), f"Step 5 directory not found: {step_5_dir}" - step_5_files = os.listdir(step_5_dir) - # With 2 processes, we expect dataloader files for rank 0 and rank 1 - dataloader_files_5 = [f for f in step_5_files if "dataloader" in f] - assert len(dataloader_files_5) == 2, ( - f"Expected 2 dataloader files (rank 0 and 1), found {len(dataloader_files_5)}: {dataloader_files_5}" - ) - assert any("rank_0" in f for f in dataloader_files_5), ( - f"No dataloader file for rank 0 found in step_5. Files: {dataloader_files_5}" - ) - assert any("rank_1" in f for f in dataloader_files_5), ( - f"No dataloader file for rank 1 found in step_5. 
Files: {dataloader_files_5}" +def test_checkpoint_save_and_load_single_process_fsdp2_with_context_parallelism(recipe_path, tmp_path): + """Test checkpoint save/resume for FSDP2 with context parallelism (single process).""" + _run_single_process_checkpoint_test( + recipe_path, + tmp_path, + main_fsdp2_cp, + ckpt_subdir_name="train_fsdp2", + config_name="L0_sanity_cp", + is_ddp=False, ) - # Phase 2: Resume training with 2 processes - cmd_phase2 = [ - "torchrun", - "--standalone", - "--nproc_per_node=2", - str(train_script), - f"checkpoint.ckpt_dir={temp_dir}", - "num_train_steps=15", - "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=true", # Resume from checkpoint - "checkpoint.async_save=false", - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - "cp_size=2", - ] - result2 = subprocess.run(cmd_phase2, check=False, capture_output=True, text=True, env=env) - assert result2.returncode == 0, f"Phase 2 failed: {result2.stderr}" - - # Verify phase 2 completed and created additional checkpoints - final_checkpoint_dirs = [ - d for d in os.listdir(ckpt_subdir) if d.startswith("step_") and os.path.isdir(os.path.join(ckpt_subdir, d)) - ] - expected_checkpoints = ["step_5", "step_10"] - for expected in expected_checkpoints: - assert expected in final_checkpoint_dirs, f"Missing checkpoint: {expected}" - - # Check dataloader files exist in step_10 directory for both ranks - step_10_dir = os.path.join(ckpt_subdir, "step_10") - assert os.path.isdir(step_10_dir), f"Step 10 directory not found: {step_10_dir}" - step_10_files = os.listdir(step_10_dir) - - # With 2 processes, we expect dataloader files for rank 0 and rank 1 - dataloader_files_10 = [f for f in step_10_files if "dataloader" in f] - assert len(dataloader_files_10) == 2, ( - f"Expected 2 dataloader files (rank 0 and 1) in step_10, found {len(dataloader_files_10)}: {dataloader_files_10}" - ) - assert any("rank_0" in f for f in dataloader_files_10), ( - f"No dataloader file for rank 0 found in step_10. Files: {dataloader_files_10}" - ) - assert any("rank_1" in f for f in dataloader_files_10), ( - f"No dataloader file for rank 1 found in step_10. Files: {dataloader_files_10}" +@requires_multi_gpu +def test_checkpoint_save_and_load_two_processes_fsdp2_with_context_parallelism(recipe_path, tmp_path): + """Test checkpoint save/resume for FSDP2 with context parallelism (two processes).""" + _run_multi_process_checkpoint_test( + recipe_path, + tmp_path, + "train_fsdp2_cp.py", + ckpt_subdir_name="train_fsdp2", + extra_overrides=["checkpoint.async_save=false", "cp_size=2"], + is_ddp=False, ) -def test_scheduler_resume_single_gpu(recipe_path, tmp_path): - """Test that learning rate scheduler resumes from correct state after checkpoint load. - - This test validates: - - Scheduler state is saved in checkpoint - - Scheduler resumes with correct step count - - Learning rate continues from where it left off (not reset) - - Warmup and decay continue correctly after resume - - Process: - 1. Train for 10 steps, save checkpoint with scheduler state at step 5 - 2. Resume training, verify scheduler continues from step 6 (not step 0) - 3. 
Check that learning rate progression is continuous across resume - """ - temp_dir = str(tmp_path / "test_scheduler_resume") +# --------------------------------------------------------------------------- +# Scheduler Resume Tests +# --------------------------------------------------------------------------- - # Phase 1: Train for 10 steps with warmup - with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): - phase1_config = compose( - config_name="L0_sanity", - overrides=[ - f"checkpoint.ckpt_dir={temp_dir}", - f"+wandb.dir={tmp_path}", - "num_train_steps=10", - "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=false", # Start fresh - "lr_scheduler_kwargs.num_warmup_steps=20", - "lr_scheduler_kwargs.num_decay_steps=100", - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - ], - ) - main_ddp(phase1_config) - gc.collect() - torch.cuda.empty_cache() +def test_scheduler_resume_single_gpu(recipe_path, tmp_path): + """Test that the LR scheduler resumes from the correct state after checkpoint load.""" + _run_single_process_checkpoint_test( + recipe_path, + tmp_path, + main_ddp, + ckpt_subdir_name="train_ddp", + extra_overrides=[ + "lr_scheduler_kwargs.num_warmup_steps=20", + "lr_scheduler_kwargs.num_decay_steps=100", + ], + is_ddp=True, + ) - # Phase 2: Resume training for 5 more steps - with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): - phase2_config = compose( - config_name="L0_sanity", - overrides=[ - f"checkpoint.ckpt_dir={temp_dir}", - f"+wandb.dir={tmp_path}", - "num_train_steps=15", - "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=true", # Resume from checkpoint - "lr_scheduler_kwargs.num_warmup_steps=20", - "lr_scheduler_kwargs.num_decay_steps=100", - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - ], - ) - main_ddp(phase2_config) - gc.collect() - torch.cuda.empty_cache() - - # Verify checkpoints were created - ckpt_subdir = os.path.join(temp_dir, "train_ddp") - assert os.path.exists(ckpt_subdir), f"Checkpoint subdirectory {ckpt_subdir} not created" +@requires_multi_gpu +def test_scheduler_resume_two_gpu(recipe_path, tmp_path): + """Test that the LR scheduler resumes correctly with multi-GPU FSDP2 training.""" + _run_multi_process_checkpoint_test( + recipe_path, + tmp_path, + "train_fsdp2.py", + ckpt_subdir_name="train_fsdp2", + extra_overrides=[ + "lr_scheduler_kwargs.num_warmup_steps=20", + "lr_scheduler_kwargs.num_decay_steps=100", + ], + is_ddp=False, + ) - # Check that checkpoint directories exist - checkpoint_dirs = [ - d for d in os.listdir(ckpt_subdir) if d.startswith("step_") and os.path.isdir(os.path.join(ckpt_subdir, d)) - ] - expected_checkpoint_dirs = ["step_5", "step_10"] - for expected_dir in expected_checkpoint_dirs: - assert expected_dir in checkpoint_dirs, f"Missing checkpoint directory: {expected_dir}" - # Verify each checkpoint directory contains the checkpoint file - checkpoint_file = os.path.join(ckpt_subdir, expected_dir, "checkpoint.pt") - assert os.path.isfile(checkpoint_file), f"Missing checkpoint file: {checkpoint_file}" +# --------------------------------------------------------------------------- +# Final Model Save Tests +# --------------------------------------------------------------------------- def test_final_model_save_ddp(recipe_path, tmp_path): - """Test final model saving for DDP. 
- - Validates that DDP saves the final model correctly with: - - model.safetensors containing weights - - config.json with model configuration - - Can be loaded for inference - - This is important for: - - Exporting trained models - - HuggingFace model hub compatibility - - Inference deployment - """ - temp_dir = str(tmp_path / "test_final_ddp") - - with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): - config = compose( - config_name="L0_sanity", - overrides=[ - f"checkpoint.ckpt_dir={temp_dir}", - f"+wandb.dir={tmp_path}", - "checkpoint.save_final_model=true", - "num_train_steps=3", - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - ], - ) - - main_ddp(config) + """Test that DDP saves a final model with model.safetensors and config.json.""" + cfg = _compose_config( + recipe_path, + tmp_path, + "L0_sanity", + [ + "checkpoint.save_final_model=true", + "num_train_steps=3", + ], + ) + + loss = main_ddp(cfg) gc.collect() torch.cuda.empty_cache() - # Check final model directory - final_model_dir = os.path.join(temp_dir, "train_ddp", "final_model") - assert os.path.exists(final_model_dir), "Final model directory not created" + _assert_loss_valid(loss, "final model ddp") - # Check required files - required_files = ["model.safetensors", "config.json"] - for file in required_files: - file_path = os.path.join(final_model_dir, file) - assert os.path.exists(file_path), f"Missing required file: {file}" - assert os.path.getsize(file_path) > 0, f"File {file} is empty" + final_model_dir = os.path.join(str(tmp_path / "ckpt"), "train_ddp", "final_model") + assert os.path.exists(final_model_dir), "Final model directory not created" + for fname in ("model.safetensors", "config.json"): + fpath = os.path.join(final_model_dir, fname) + assert os.path.exists(fpath), f"Missing: {fname}" + assert os.path.getsize(fpath) > 0, f"{fname} is empty" def test_final_model_save_fsdp2(recipe_path, tmp_path): - """Test final model saving for FSDP2. 
- - Validates that FSDP2 gathers full state dict and saves the final model with: - - model.safetensors containing gathered weights - - config.json with model configuration - - This tests that FSDP2's parameter gathering works correctly: - - All shards are gathered - - Full model state is consolidated - - Model can be loaded for inference - """ - temp_dir = str(tmp_path / "test_final_fsdp2") - - with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): - config = compose( - config_name="L0_sanity", - overrides=[ - f"checkpoint.ckpt_dir={temp_dir}", - f"+wandb.dir={tmp_path}", - "checkpoint.save_final_model=true", - "num_train_steps=3", - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - ], - ) - - main_fsdp2(config) + """Test that FSDP2 gathers weights and saves a final model.""" + cfg = _compose_config( + recipe_path, + tmp_path, + "L0_sanity", + [ + "checkpoint.save_final_model=true", + "num_train_steps=3", + ], + ) + + loss = main_fsdp2(cfg) gc.collect() torch.cuda.empty_cache() - # Check final model directory - final_model_dir = os.path.join(temp_dir, "train_fsdp2", "final_model") - assert os.path.exists(final_model_dir), "Final model directory not created" - - # Check required files - required_files = ["model.safetensors", "config.json"] - for file in required_files: - file_path = os.path.join(final_model_dir, file) - assert os.path.exists(file_path), f"Missing required file: {file}" - assert os.path.getsize(file_path) > 0, f"File {file} is empty" - - -@requires_multi_gpu -def test_scheduler_resume_two_gpu(recipe_path, tmp_path): - """Test that learning rate scheduler resumes correctly with multi-GPU training. - - This test validates: - - Scheduler state is synchronized across GPUs during save - - All GPUs resume with same scheduler state - - Learning rate is consistent across all processes after resume - - No divergence in LR between ranks - - Process: - 1. Train for 10 steps across 2 GPUs, save checkpoint at step 5 - 2. Resume training on 2 GPUs, verify scheduler continues correctly - 3. 
Ensure both GPUs have same learning rate progression - """ - temp_dir = str(tmp_path / "test_scheduler_resume_2gpu") + _assert_loss_valid(loss, "final model fsdp2") - env = os.environ.copy() - env["WANDB_MODE"] = "disabled" - - # Test with FSDP2 as it's most complex for scheduler state - train_script = recipe_path / "train_fsdp2.py" - - # Phase 1: Train for 10 steps with 2 GPUs - cmd_phase1 = [ - "torchrun", - "--standalone", - "--nproc_per_node=2", - str(train_script), - f"checkpoint.ckpt_dir={temp_dir}", - "num_train_steps=10", - "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=false", # Start fresh - "lr_scheduler_kwargs.num_warmup_steps=20", - "lr_scheduler_kwargs.num_decay_steps=100", - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - ] - - result1 = subprocess.run(cmd_phase1, check=False, capture_output=True, text=True, env=env) - assert result1.returncode == 0, f"Phase 1 failed: {result1.stderr}" - - # Check that checkpoint was created (FSDP2 uses distributed format by default) - ckpt_subdir = os.path.join(temp_dir, "train_fsdp2") - checkpoint_dirs = [ - d for d in os.listdir(ckpt_subdir) if d.startswith("step_") and os.path.isdir(os.path.join(ckpt_subdir, d)) - ] - assert "step_5" in checkpoint_dirs, "Checkpoint at step 5 not found" - - # Phase 2: Resume training with 2 GPUs - cmd_phase2 = [ - "torchrun", - "--standalone", - "--nproc_per_node=2", - str(train_script), - f"checkpoint.ckpt_dir={temp_dir}", - "num_train_steps=15", - "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=true", # Resume from checkpoint - "lr_scheduler_kwargs.num_warmup_steps=20", - "lr_scheduler_kwargs.num_decay_steps=100", - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - ] + final_model_dir = os.path.join(str(tmp_path / "ckpt"), "train_fsdp2", "final_model") + assert os.path.exists(final_model_dir), "Final model directory not created" + for fname in ("model.safetensors", "config.json"): + fpath = os.path.join(final_model_dir, fname) + assert os.path.exists(fpath), f"Missing: {fname}" + assert os.path.getsize(fpath) > 0, f"{fname} is empty" - result2 = subprocess.run(cmd_phase2, check=False, capture_output=True, text=True, env=env) - assert result2.returncode == 0, f"Phase 2 failed: {result2.stderr}" - # Verify training continued successfully - # The fact that it completed without errors means scheduler state was properly synchronized - - # Check that final checkpoint was created (distributed format) - final_checkpoint_dirs = [ - d for d in os.listdir(ckpt_subdir) if d.startswith("step_") and os.path.isdir(os.path.join(ckpt_subdir, d)) - ] - assert "step_10" in final_checkpoint_dirs, "Checkpoint at step 10 not found" +# --------------------------------------------------------------------------- +# Checkpoint Pruning Tests +# --------------------------------------------------------------------------- def test_checkpoint_pruning(tmp_path): - """Test checkpoint pruning functionality.""" - + """Test checkpoint pruning keeps only the latest N checkpoints.""" from checkpoint import prune_checkpoints temp_dir = str(tmp_path / "test_checkpoint_pruning") @@ -1026,8 +453,7 @@ def test_checkpoint_pruning(tmp_path): def test_checkpoint_pruning_not_enough_checkpoints(tmp_path): - """Test checkpoint pruning functionality.""" - + """Test checkpoint pruning when fewer checkpoints than max exist.""" from checkpoint import prune_checkpoints temp_dir = str(tmp_path / "test_checkpoint_pruning") @@ -1040,8 +466,7 @@ def 
 
 
 def test_checkpoint_pruning_with_files(tmp_path):
-    """Test checkpoint pruning functionality."""
-
+    """Test checkpoint pruning with file-based checkpoints."""
     from checkpoint import prune_checkpoints
 
     for i in range(11):
@@ -1054,3 +479,90 @@ def test_checkpoint_pruning_with_files(tmp_path):
     assert (tmp_path / "step_8.pt").exists()
     assert (tmp_path / "step_9.pt").exists()
     assert (tmp_path / "step_10.pt").exists()
+
+
+# ---------------------------------------------------------------------------
+# FP8 Checkpoint Tests (with quantized_model_init)
+# ---------------------------------------------------------------------------
+
+_FP8_QUANTIZED_OVERRIDES = [
+    "fp8_config.enabled=true",
+    "fp8_config.quantized_model_init_kwargs.enabled=true",
+    "+dataset.pad_sequences_to_be_divisible_by=16",
+]
+
+
+def test_checkpoint_save_and_load_single_process_ddp_fp8_quantized(recipe_path, tmp_path, fp_recipe):
+    """Test checkpoint save/resume for DDP with FP8 quantized model init."""
+
+    if fp_recipe[0].endswith("Float8BlockScaling"):
+        pytest.xfail(reason="Float8BlockScaling currently does not support quantized model init + dcp")
+
+    _run_single_process_checkpoint_test(
+        recipe_path,
+        tmp_path,
+        main_ddp,
+        ckpt_subdir_name="train_ddp",
+        config_name="L0_sanity_cp",
+        extra_overrides=[*_FP8_QUANTIZED_OVERRIDES, *fp_recipe],
+        is_ddp=True,
+    )
+
+
+def test_checkpoint_save_and_load_single_process_fsdp2_fp8_quantized(recipe_path, tmp_path, fp_recipe):
+    """Test checkpoint save/resume for FSDP2 with FP8 quantized model init."""
+
+    if fp_recipe[0].endswith("Float8BlockScaling"):
+        pytest.xfail(reason="Float8BlockScaling currently does not support quantized model init + dcp")
+
+    _run_single_process_checkpoint_test(
+        recipe_path,
+        tmp_path,
+        main_fsdp2,
+        ckpt_subdir_name="train_fsdp2",
+        config_name="L0_sanity_cp",
+        extra_overrides=[*_FP8_QUANTIZED_OVERRIDES, *fp_recipe],
+        is_ddp=False,
+    )
+
+
+def test_checkpoint_save_and_load_single_process_fsdp2_cp_fp8_quantized(recipe_path, tmp_path, fp_recipe):
+    """Test checkpoint save/resume for FSDP2 with context parallelism and FP8 quantized model init."""
+
+    if fp_recipe[0].endswith("Float8BlockScaling"):
+        pytest.xfail(reason="Float8BlockScaling currently does not support quantized model init + dcp")
+
+    _run_single_process_checkpoint_test(
+        recipe_path,
+        tmp_path,
+        main_fsdp2_cp,
+        ckpt_subdir_name="train_fsdp2",
+        config_name="L0_sanity_cp",
+        extra_overrides=[*_FP8_QUANTIZED_OVERRIDES, *fp_recipe],
+        is_ddp=False,
+    )
+
+
+def test_checkpoint_save_and_load_single_process_fsdp2_cp_fp8_quantized_async(recipe_path, tmp_path, fp_recipe):
+    """Test checkpoint save/resume for FSDP2+CP with FP8 quantized model init and async save.
+
+    This reproduces the corys_config scenario where async_save=true (the default)
+    is used with FP8 quantized model init.
+    """
+
+    if fp_recipe[0].endswith("Float8BlockScaling"):
+        pytest.xfail(reason="Float8BlockScaling currently does not support quantized model init + dcp")
+
+    _run_single_process_checkpoint_test(
+        recipe_path,
+        tmp_path,
+        main_fsdp2_cp,
+        ckpt_subdir_name="train_fsdp2",
+        config_name="L0_sanity_cp",
+        extra_overrides=[
+            *_FP8_QUANTIZED_OVERRIDES,
+            *fp_recipe,
+            "checkpoint.async_save=true",
+        ],
+        is_ddp=False,
+    )
diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py
index 370e174c6f..aebdfe17ef 100644
--- a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py
+++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py
@@ -71,7 +71,7 @@ def _create_perf_logger(logging_frequency, mock_wandb, mock_tqdm):
     """Create a PerfLogger with the given logging_frequency."""
     dist_config = DistributedConfig()
     args = _make_args(logging_frequency=logging_frequency)
-    return PerfLogger(dist_config, args)
+    return PerfLogger(dist_config, args, start_step=0)
 
 
 def _run_steps(perf_logger, losses, grad_acc_steps=1):
diff --git a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py
index 7aed3ff6f5..0a25c02940 100644
--- a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py
+++ b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py
@@ -137,6 +137,7 @@ def main(args: DictConfig) -> float | None:
             ckpt_path=ckpt_path,
             dist_config=dist_config,
             dataloader=train_dataloader,
+            weights_only=not args.fp8_config.quantized_model_init_kwargs.enabled,
         )
         logger.info("Checkpoint loaded, resuming from step %s, epoch %s", start_step, epoch)
     else:
@@ -144,7 +145,7 @@ def main(args: DictConfig) -> float | None:
         start_step = 0
         epoch = 0
 
-    perf_logger = PerfLogger(dist_config, args)
+    perf_logger = PerfLogger(dist_config, args, start_step=start_step)
 
     gc.collect()
     torch.cuda.empty_cache()
diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
index 558d27366d..4d88f2e0c0 100644
--- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
+++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
@@ -160,7 +160,7 @@ def main(args: DictConfig) -> float | None:
         start_step = 0
         epoch = 0
 
-    perf_logger = PerfLogger(dist_config, args)
+    perf_logger = PerfLogger(dist_config, args, start_step=start_step)
 
     gc.collect()
     torch.cuda.empty_cache()
diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py
index 9ad3d0e297..06fb6630ba 100644
--- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py
+++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py
@@ -178,7 +178,7 @@ def main(args: DictConfig) -> float | None:
         start_step = 0
         epoch = 0
 
-    perf_logger = PerfLogger(dist_config, args)
+    perf_logger = PerfLogger(dist_config, args, start_step=start_step)
 
     gc.collect()
     torch.cuda.empty_cache()