From 1f5f0ec1b4e5dc66d41106aedb03e041f3b702d3 Mon Sep 17 00:00:00 2001 From: zhengchenyu Date: Sat, 6 Dec 2025 10:12:16 +0800 Subject: [PATCH] Keep the training data continuous and the total batch size constant regardless of changes in the replica world size. --- torchft/data.py | 261 +++++++++++++++++++++++- torchft/data_test.py | 422 ++++++++++++++++++++++++++++++++++++++- torchft/manager.py | 142 ++++++++++++- torchft/manager_test.py | 154 +++++++++++++- torchft/optim.py | 4 +- train_ddp_fix_batch.py | 259 ++++++++++++++++++++++++ train_fsdp2_fix_batch.py | 317 +++++++++++++++++++++++++++++ 7 files changed, 1548 insertions(+), 11 deletions(-) create mode 100644 train_ddp_fix_batch.py create mode 100644 train_fsdp2_fix_batch.py diff --git a/torchft/data.py b/torchft/data.py index 02e5b3be..4f5cb90b 100644 --- a/torchft/data.py +++ b/torchft/data.py @@ -14,10 +14,269 @@ dataloader frequently to avoid duplicate batches. """ -from typing import Optional +import math +from collections.abc import Iterator +from typing import Iterable, Optional, TypeVar, Union +import torch import torch.distributed as dist from torch.utils import data +from torch.utils.data.dataset import Dataset +from torch.utils.data.sampler import BatchSampler, Sampler + +_T_co = TypeVar("_T_co", covariant=True) + + +class SkipDistributedSampler(Sampler[_T_co]): + def __init__( + self, + dataset: Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + skip_samples: int = 0, + ) -> None: + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + if rank >= num_replicas or rank < 0: + raise ValueError( + f"Invalid rank {rank}, rank should be in the interval [0, 
{num_replicas - 1}]" + ) + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.drop_last = drop_last + self.skip_samples = skip_samples + # If the dataset length is evenly divisible by # of replicas, then there + # is no need to drop any data, since the dataset will be split equally. + if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type] + # Split to nearest available length that is evenly divisible. + # This is to ensure each rank receives the same amount of data when + # using this Sampler. + self.num_samples = math.ceil( + (len(self.dataset) - self.skip_samples - self.num_replicas) + / self.num_replicas # type: ignore[arg-type] + ) + else: + self.num_samples = math.ceil( + (len(self.dataset) - self.skip_samples) / self.num_replicas + ) # type: ignore[arg-type] + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + self.seed = seed + + def __iter__(self) -> Iterator[_T_co]: + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + indices = indices[self.skip_samples : len(indices)] + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[ + :padding_size + ] + else: + # remove tail of data to make it evenly divisible. 
+ indices = indices[self.skip_samples : self.skip_samples + self.total_size] + if len(indices) != self.total_size: + raise AssertionError( + f"Number of indices ({len(indices)}) does not match total_size ({self.total_size})" + ) + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + if len(indices) != self.num_samples: + raise AssertionError( + f"Number of subsampled indices ({len(indices)}) does not match num_samples ({self.num_samples})" + ) + + # pyrefly: ignore # bad-return + return iter(indices) + + def __len__(self) -> int: + return self.num_samples + + def set_epoch(self, epoch: int) -> None: + r""" + Set the epoch for this sampler. + + When :attr:`shuffle=True`, this ensures all replicas + use a different random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. + """ + self.epoch = epoch + + +class DistributedBatchSampler(Sampler[list[int]]): + r"""Wraps a BatchSampler to distribute batches across multiple processes in distributed training. + + Each process gets a subset of batches based on its rank and the total number of replicas. + This is useful for distributed training where each process should work on different batches + to avoid data duplication. + + Args: + sampler (Sampler or Iterable): Base sampler. Can be any iterable object + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size`` + num_replicas (int): Number of processes participating in distributed training. + rank (int): Rank of the current process within num_replicas. + Should be in range [0, num_replicas - 1]. + even_batches (bool): If ``True``, ensures all ranks get exactly the same number + of batches by potentially dropping some batches. If ``False``, some ranks + may get one extra batch. Default: ``True``. 
+ + Example: + >>> # For a dataset with indices 0-20, batch_size=2, num_replicas=2 + >>> # All batches would be: [[0,1], [2,3], [4,5], [6,7], [8,9], [10,11], ...] + >>> + >>> # With even_batches=False (original behavior): + >>> # rank=0 gets batches: [[0,1], [4,5], [8,9], [12,13], [16,17], [20]] (6 batches) + >>> # rank=1 gets batches: [[2,3], [6,7], [10,11], [14,15], [18,19]] (5 batches) + >>> sampler_rank0 = DistributedBatchSampler( + ... SequentialSampler(range(21)), batch_size=2, drop_last=False, + ... num_replicas=2, rank=0, even_batches=False + ... ) + >>> list(sampler_rank0) + [[0, 1], [4, 5], [8, 9], [12, 13], [16, 17], [20]] + >>> + >>> # With even_batches=True (default behavior): + >>> # Both ranks get exactly 5 batches (drops the last batch [20]) + >>> # rank=0 gets batches: [[0,1], [4,5], [8,9], [12,13], [16,17]] (5 batches) + >>> # rank=1 gets batches: [[2,3], [6,7], [10,11], [14,15], [18,19]] (5 batches) + >>> sampler_rank0_even = DistributedBatchSampler( + ... SequentialSampler(range(21)), batch_size=2, drop_last=False, + ... num_replicas=2, rank=0, even_batches=True + ... 
) + >>> list(sampler_rank0_even) + [[0, 1], [4, 5], [8, 9], [12, 13], [16, 17]] + """ + + def __init__( + self, + sampler: Union[Sampler[int], Iterable[int]], + batch_size: int, + drop_last: bool, + num_replicas: int = 1, + rank: int = 0, + even_batches: bool = True, + ) -> None: + # Validate batch_size + if ( + not isinstance(batch_size, int) + or isinstance(batch_size, bool) + or batch_size <= 0 + ): + raise ValueError( + f"batch_size should be a positive integer value, but got batch_size={batch_size}" + ) + + # Validate drop_last + if not isinstance(drop_last, bool): + raise ValueError( + f"drop_last should be a boolean value, but got drop_last={drop_last}" + ) + + # Validate num_replicas + if not isinstance(num_replicas, int) or num_replicas <= 0: + raise ValueError( + f"num_replicas should be a positive integer value, but got num_replicas={num_replicas}" + ) + + # Validate rank + if not isinstance(rank, int) or rank < 0 or rank >= num_replicas: + raise ValueError( + f"rank should be an integer in range [0, {num_replicas - 1}], but got rank={rank}" + ) + + # Validate even_batches + if not isinstance(even_batches, bool): + raise ValueError( + f"even_batches should be a boolean value, but got even_batches={even_batches}" + ) + + self.sampler = sampler + self.batch_size = batch_size + self.drop_last = drop_last + self.num_replicas = num_replicas + self.rank = rank + self.even_batches = even_batches + + # Create a BatchSampler to generate all batches + self.batch_sampler = BatchSampler(sampler, batch_size, drop_last) + + def __iter__(self) -> Iterator[list[int]]: + if self.even_batches: + # When even_batches=True, ensure all ranks get the same number of batches + # by potentially dropping some batches + all_batches = list(self.batch_sampler) + total_batches = len(all_batches) + + # Calculate how many batches each rank should get to make them even + batches_per_rank = total_batches // self.num_replicas + + # Only consider the first batches_per_rank * num_replicas 
batches + # This ensures even distribution + total_even_batches = batches_per_rank * self.num_replicas + + batch_idx = 0 + for batch in all_batches: + if batch_idx >= total_even_batches: + # Stop yielding once we've exhausted the even batches + break + # Only yield batches that belong to current rank + if batch_idx % self.num_replicas == self.rank: + yield batch + batch_idx += 1 + else: + # Original behavior when even_batches=False + batch_idx = 0 + for batch in self.batch_sampler: + # Only yield batches that belong to current rank + if batch_idx % self.num_replicas == self.rank: + yield batch + batch_idx += 1 + + def __len__(self) -> int: + # Calculate total number of batches from BatchSampler + total_batches = len(self.batch_sampler) # type: ignore[arg-type] + + if self.even_batches: + # When even_batches=True, all ranks get exactly the same number of batches + return total_batches // self.num_replicas + else: + # Original behavior when even_batches=False + # Each rank gets approximately total_batches // num_replicas batches + # The remaining batches are distributed among the first few ranks + batches_per_rank = total_batches // self.num_replicas + remaining_batches = total_batches % self.num_replicas + + # Current rank gets one extra batch if it's among the first 'remaining_batches' ranks + if self.rank < remaining_batches: + return batches_per_rank + 1 + else: + return batches_per_rank # pyre-fixme[24]: expected generic parameter diff --git a/torchft/data_test.py b/torchft/data_test.py index 8dae190e..0cf1ffc8 100644 --- a/torchft/data_test.py +++ b/torchft/data_test.py @@ -7,8 +7,13 @@ from unittest import TestCase from torch.utils.data import Dataset +from torch.utils.data.sampler import BatchSampler, SequentialSampler -from torchft.data import DistributedSampler +from torchft.data import ( + DistributedBatchSampler, + DistributedSampler, + SkipDistributedSampler, +) class DummyDataset(Dataset): @@ -37,3 +42,418 @@ def test_distributed_sampler(self) -> None: 
sampler_iter = iter(sampler) self.assertEqual(next(sampler_iter), 500) + + def test_skip_distributed_sampler(self): + dataset_length = 100 + dataset = DummyDataset(dataset_length) + + # Case 1: sample is not skipped + for drop_last in [True, False]: + num_replicas = 7 + for rank in range(num_replicas): + # print(f"---- sample is not skipped, drop_last={drop_last}, rank={rank} ----") + sampler = SkipDistributedSampler( + dataset=dataset, + num_replicas=num_replicas, + rank=rank, + shuffle=False, + drop_last=drop_last, + ) + cur = rank + for idx in sampler: + # print("idx = ", idx) + self.assertEqual( + idx, (cur % dataset_length), f"idx={idx}, cur={cur}" + ) + cur += num_replicas + # If drop_last is True, read ceil((100-7)/7)*7=98 samples totally. + # If drop_last is False, read ceil(100/7)*7=105 samples totally. + if drop_last: + self.assertEqual(cur, 98 + rank, f"rank={rank}, cur={cur}") + else: + self.assertEqual(cur, 105 + rank, f"rank={rank}, cur={cur}") + + # Case 2: sample is skipped + for drop_last in [True, False]: + num_replicas = 7 + skip_samples = 10 + for rank in range(num_replicas): + # print(f"---- sample is skipped, drop_last={drop_last}, rank={rank} ----") + sampler = SkipDistributedSampler( + dataset=dataset, + num_replicas=num_replicas, + rank=rank, + shuffle=False, + drop_last=drop_last, + skip_samples=skip_samples, + ) + cur = rank + for idx in sampler: + # print("idx = ", idx) + expected = ( + ((cur + skip_samples) % dataset_length + skip_samples) + if (cur + skip_samples) >= dataset_length + else (cur + skip_samples) + ) + self.assertEqual(idx, expected, f"idx={idx}, expected={expected}") + cur += num_replicas + # If drop_last is True, read ceil((100-10-7)/7)*7=84 samples totally. + # If drop_last is False, read ceil((100-10)/7)*7=91 samples totally. 
+ if drop_last: + self.assertEqual(cur, 84 + rank, f"rank={rank}, cur={cur}") + else: + self.assertEqual(cur, 91 + rank, f"rank={rank}, cur={cur}") + + # Case 3: drop last is False and padding size is larger than number of indices + # If skip_samples is 90, and num_replicas is 31, then the indices is [90, 92, ..., 99]. + # It means only 10 samples are left, so padding size is 21 which is larger than 10. + num_replicas = 31 + skip_samples = 90 + expected = list(range(90, 100)) + expected = (expected * 4)[:31] + for rank in range(num_replicas): + # print(f"---- sample is skipped, drop_last={drop_last}, rank={rank} ----") + sampler = SkipDistributedSampler( + dataset=dataset, + num_replicas=num_replicas, + rank=rank, + shuffle=False, + drop_last=False, + skip_samples=skip_samples, + ) + cnt = 0 + for idx in sampler: + # print("idx = ", idx) + self.assertEqual( + idx, expected[rank], f"idx={idx}, rank={rank}, expected={expected}" + ) + cnt += 1 + self.assertTrue(cnt, 1) + + def test_distributed_batch_sampler(self): + # Test 1: Basic functionality with dataset 0-20, batch_size=2, num_replicas=2 + dataset_size = 21 + batch_size = 2 + num_replicas = 2 + + # Test with even_batches=True (default behavior) - all ranks get same number of batches + sampler_rank0 = DistributedBatchSampler( + SequentialSampler(DummyDataset(dataset_size)), + batch_size=batch_size, + drop_last=False, + num_replicas=num_replicas, + rank=0, + even_batches=True, + ) + + sampler_rank1 = DistributedBatchSampler( + SequentialSampler(DummyDataset(dataset_size)), + batch_size=batch_size, + drop_last=False, + num_replicas=num_replicas, + rank=1, + even_batches=True, + ) + + batches_rank0 = list(sampler_rank0) + batches_rank1 = list(sampler_rank1) + + # With even_batches=True, both ranks get exactly 5 batches (drops the last batch [20]) + expected_rank0_even = [[0, 1], [4, 5], [8, 9], [12, 13], [16, 17]] + expected_rank1_even = [[2, 3], [6, 7], [10, 11], [14, 15], [18, 19]] + + assert ( + batches_rank0 == 
expected_rank0_even + ), f"Expected {expected_rank0_even}, got {batches_rank0}" + assert ( + batches_rank1 == expected_rank1_even + ), f"Expected {expected_rank1_even}, got {batches_rank1}" + assert len(sampler_rank0) == 5, f"Expected length 5, got {len(sampler_rank0)}" + assert len(sampler_rank1) == 5, f"Expected length 5, got {len(sampler_rank1)}" + + # Test with even_batches=False - some ranks may get extra batches + sampler_rank0_uneven = DistributedBatchSampler( + SequentialSampler(DummyDataset(dataset_size)), + batch_size=batch_size, + drop_last=False, + num_replicas=num_replicas, + rank=0, + even_batches=False, + ) + + sampler_rank1_uneven = DistributedBatchSampler( + SequentialSampler(DummyDataset(dataset_size)), + batch_size=batch_size, + drop_last=False, + num_replicas=num_replicas, + rank=1, + even_batches=False, + ) + + batches_rank0_uneven = list(sampler_rank0_uneven) + batches_rank1_uneven = list(sampler_rank1_uneven) + + # With even_batches=False, rank0 gets 6 batches, rank1 gets 5 batches + expected_rank0_uneven = [[0, 1], [4, 5], [8, 9], [12, 13], [16, 17], [20]] + expected_rank1_uneven = [[2, 3], [6, 7], [10, 11], [14, 15], [18, 19]] + + assert ( + batches_rank0_uneven == expected_rank0_uneven + ), f"Expected {expected_rank0_uneven}, got {batches_rank0_uneven}" + assert ( + batches_rank1_uneven == expected_rank1_uneven + ), f"Expected {expected_rank1_uneven}, got {batches_rank1_uneven}" + assert ( + len(sampler_rank0_uneven) == 6 + ), f"Expected length 6, got {len(sampler_rank0_uneven)}" + assert ( + len(sampler_rank1_uneven) == 5 + ), f"Expected length 5, got {len(sampler_rank1_uneven)}" + + # Test 2: Verify no data loss and no overlap (using even_batches=False for completeness) + all_indices_distributed = [] + for batch in batches_rank0_uneven + batches_rank1_uneven: + all_indices_distributed.extend(batch) + + normal_sampler = BatchSampler( + SequentialSampler(DummyDataset(dataset_size)), batch_size, False + ) + all_indices_normal = [] + for 
batch in normal_sampler: + all_indices_normal.extend(batch) + + assert sorted(all_indices_distributed) == sorted( + all_indices_normal + ), "Data completeness check failed" + assert len(set(all_indices_distributed)) == len( + all_indices_distributed + ), "Overlap detected" + + # Test 3: drop_last=True + sampler_rank0_drop = DistributedBatchSampler( + SequentialSampler(DummyDataset(dataset_size)), + batch_size=batch_size, + drop_last=True, + num_replicas=num_replicas, + rank=0, + ) + + sampler_rank1_drop = DistributedBatchSampler( + SequentialSampler(DummyDataset(dataset_size)), + batch_size=batch_size, + drop_last=True, + num_replicas=num_replicas, + rank=1, + ) + + batches_rank0_drop = list(sampler_rank0_drop) + batches_rank1_drop = list(sampler_rank1_drop) + + # With drop_last=True, we should get 10 total batches (dropping the last incomplete batch) + # rank0 should get batches 0,2,4,6,8 -> [[0,1], [4,5], [8,9], [12,13], [16,17]] + # rank1 should get batches 1,3,5,7,9 -> [[2,3], [6,7], [10,11], [14,15], [18,19]] + expected_rank0_drop = [[0, 1], [4, 5], [8, 9], [12, 13], [16, 17]] + expected_rank1_drop = [[2, 3], [6, 7], [10, 11], [14, 15], [18, 19]] + + assert ( + batches_rank0_drop == expected_rank0_drop + ), f"Expected {expected_rank0_drop}, got {batches_rank0_drop}" + assert ( + batches_rank1_drop == expected_rank1_drop + ), f"Expected {expected_rank1_drop}, got {batches_rank1_drop}" + + # Test 4: num_replicas=3 + dataset_size = 20 + num_replicas = 3 + + samplers = [] + batches = [] + for rank in range(num_replicas): + sampler = DistributedBatchSampler( + SequentialSampler(DummyDataset(dataset_size)), + batch_size=2, + drop_last=False, + num_replicas=num_replicas, + rank=rank, + even_batches=False, + ) + samplers.append(sampler) + batches.append(list(sampler)) + + # Total batches should be 10: [[0,1], [2,3], ..., [18,19]] + # rank0 gets: [0,3,6,9] -> [[0,1], [6,7], [12,13], [18,19]] (4 batches) + # rank1 gets: [1,4,7] -> [[2,3], [8,9], [14,15]] (3 batches) + # 
rank2 gets: [2,5,8] -> [[4,5], [10,11], [16,17]] (3 batches) + expected_batches = [ + [[0, 1], [6, 7], [12, 13], [18, 19]], # rank0 + [[2, 3], [8, 9], [14, 15]], # rank1 + [[4, 5], [10, 11], [16, 17]], # rank2 + ] + + for rank, (expected, actual) in enumerate(zip(expected_batches, batches)): + assert actual == expected, f"Rank {rank}: Expected {expected}, got {actual}" + + # Verify lengths + assert ( + len(samplers[0]) == 4 + ), f"Rank 0 length: expected 4, got {len(samplers[0])}" + assert ( + len(samplers[1]) == 3 + ), f"Rank 1 length: expected 3, got {len(samplers[1])}" + assert ( + len(samplers[2]) == 3 + ), f"Rank 2 length: expected 3, got {len(samplers[2])}" + + # Test 5: even_batches functionality + # Test even_batches=True with different scenarios + dataset_size = 23 # This will create 12 total batches with batch_size=2 + batch_size = 2 + num_replicas = 3 + + samplers_even = [] + batches_even = [] + for rank in range(num_replicas): + sampler = DistributedBatchSampler( + SequentialSampler(DummyDataset(dataset_size)), + batch_size=batch_size, + drop_last=False, + num_replicas=num_replicas, + rank=rank, + even_batches=True, + ) + samplers_even.append(sampler) + batches_even.append(list(sampler)) + + # With 12 total batches and 3 ranks, each rank should get exactly 4 batches + for rank in range(num_replicas): + assert ( + len(batches_even[rank]) == 4 + ), f"Rank {rank} should get 4 batches, got {len(batches_even[rank])}" + assert ( + len(samplers_even[rank]) == 4 + ), f"Rank {rank} __len__ should return 4, got {len(samplers_even[rank])}" + + # Test even_batches=False with same scenario + samplers_uneven = [] + batches_uneven = [] + for rank in range(num_replicas): + sampler = DistributedBatchSampler( + SequentialSampler(DummyDataset(dataset_size)), + batch_size=batch_size, + drop_last=False, + num_replicas=num_replicas, + rank=rank, + even_batches=False, + ) + samplers_uneven.append(sampler) + batches_uneven.append(list(sampler)) + + # With 12 total batches and 
3 ranks, each rank gets exactly 4 batches (evenly divisible) + for rank in range(num_replicas): + assert ( + len(batches_uneven[rank]) == 4 + ), f"Rank {rank} should get 4 batches, got {len(batches_uneven[rank])}" + + # Test with 13 total batches (not evenly divisible) + dataset_size = 25 # This will create 13 total batches with batch_size=2 + + samplers_even_13 = [] + batches_even_13 = [] + for rank in range(num_replicas): + sampler = DistributedBatchSampler( + SequentialSampler(DummyDataset(dataset_size)), + batch_size=batch_size, + drop_last=False, + num_replicas=num_replicas, + rank=rank, + even_batches=True, + ) + samplers_even_13.append(sampler) + batches_even_13.append(list(sampler)) + + # With 13 total batches and 3 ranks, even_batches=True gives each rank 4 batches (drops 1 batch) + for rank in range(num_replicas): + assert ( + len(batches_even_13[rank]) == 4 + ), f"Rank {rank} should get 4 batches with even_batches=True, got {len(batches_even_13[rank])}" + + samplers_uneven_13 = [] + batches_uneven_13 = [] + for rank in range(num_replicas): + sampler = DistributedBatchSampler( + SequentialSampler(DummyDataset(dataset_size)), + batch_size=batch_size, + drop_last=False, + num_replicas=num_replicas, + rank=rank, + even_batches=False, + ) + samplers_uneven_13.append(sampler) + batches_uneven_13.append(list(sampler)) + + # With 13 total batches and 3 ranks, even_batches=False: rank0 gets 5, rank1 gets 4, rank2 gets 4 + assert ( + len(batches_uneven_13[0]) == 5 + ), f"Rank 0 should get 5 batches with even_batches=False, got {len(batches_uneven_13[0])}" + assert ( + len(batches_uneven_13[1]) == 4 + ), f"Rank 1 should get 4 batches with even_batches=False, got {len(batches_uneven_13[1])}" + assert ( + len(batches_uneven_13[2]) == 4 + ), f"Rank 2 should get 4 batches with even_batches=False, got {len(batches_uneven_13[2])}" + + # Test 6: Parameter validation + base_sampler = SequentialSampler(DummyDataset(10)) + + # Test invalid batch_size + try: + 
DistributedBatchSampler(base_sampler, -1, False, 2, 0) + assert False, "Should raise ValueError for negative batch_size" + except ValueError: + pass + + try: + DistributedBatchSampler(base_sampler, 0, False, 2, 0) + assert False, "Should raise ValueError for zero batch_size" + except ValueError: + pass + + # Test invalid drop_last + try: + DistributedBatchSampler(base_sampler, 2, "false", 2, 0) + assert False, "Should raise ValueError for non-bool drop_last" + except ValueError: + pass + + # Test invalid num_replicas + try: + DistributedBatchSampler(base_sampler, 2, False, 0, 0) + assert False, "Should raise ValueError for zero num_replicas" + except ValueError: + pass + + try: + DistributedBatchSampler(base_sampler, 2, False, -1, 0) + assert False, "Should raise ValueError for negative num_replicas" + except ValueError: + pass + + # Test invalid rank + try: + DistributedBatchSampler(base_sampler, 2, False, 2, -1) + assert False, "Should raise ValueError for negative rank" + except ValueError: + pass + + try: + DistributedBatchSampler(base_sampler, 2, False, 2, 2) + assert False, "Should raise ValueError for rank >= num_replicas" + except ValueError: + pass + + # Test invalid even_batches + try: + DistributedBatchSampler(base_sampler, 2, False, 2, 0, "true") + assert False, "Should raise ValueError for non-bool even_batches" + except ValueError: + pass diff --git a/torchft/manager.py b/torchft/manager.py index 7e785846..baeda8b3 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -26,6 +26,7 @@ """ import concurrent.futures +import gc import logging import os import socket @@ -33,7 +34,7 @@ import uuid import weakref from concurrent.futures import ThreadPoolExecutor -from contextlib import nullcontext +from contextlib import contextmanager from datetime import timedelta from enum import Enum from typing import ( @@ -53,6 +54,7 @@ import torch.distributed as dist from torch.distributed import ReduceOp, TCPStore from torch.distributed.distributed_c10d import 
AllreduceOptions, ReduceOp, Work +from torch.futures import Future from torchft._torchft import ManagerClient, ManagerServer from torchft.checkpointing import CheckpointTransport, HTTPTransport @@ -185,6 +187,8 @@ def __init__( init_sync: bool = True, max_retries: Optional[int] = None, quorum_retries: int = 0, + dataloader_fn: Optional[Callable[[int, int, int], None]] = None, + accumulation_grad: bool = False, ) -> None: """ Args: @@ -365,6 +369,19 @@ def __init__( self._update_fr_path() + # The number of batches committed in the current epoch.Compare to _batches_committed, + # _current_batches_committed will reset to 0 when next epoch starts. + self._current_batches_committed = 0 + self._epoch = 0 + self._loaded_epoch = 0 + self._loaded_current_batches_committed = 0 + self.require_backward_grad_sync = True + self._dataloader_fn = dataloader_fn + self._dataloader_dirty = False + self._dataloader_iter = None + self._accumulation_steps = 1 + self._accumulation_grad = accumulation_grad + def allow_state_dict_read(self) -> None: if self._is_state_dict_read_allowed: return @@ -438,10 +455,21 @@ def allreduce( return _DummyWork(tensor) self.wait_quorum() + + # If dirty, the result will not be committed, so return empty tensor. 
+ if self._dataloader_dirty: + tensor.zero_() + return _ManagedWork(self, _DummyWork(tensor), tensor) + + if not self.require_backward_grad_sync: + return _ManagedWork(self, _DummyWork(tensor), tensor) + num_participants: int = self.num_participants() if not self.is_participating(): tensor.zero_() + elif self._accumulation_grad: + tensor /= self._accumulation_steps # special logic for average pg_reduce_op = reduce_op @@ -494,6 +522,15 @@ def callback( return _DummyWork(tensor) + @contextmanager + def no_sync(self): + old_require_backward_grad_sync = self.require_backward_grad_sync + self.require_backward_grad_sync = False + try: + yield + finally: + self.require_backward_grad_sync = old_require_backward_grad_sync + def report_error(self, e: Exception) -> None: """ Report an error to the manager. @@ -678,6 +715,8 @@ def _async_quorum( if self._use_async_quorum or not allow_heal else (replica_rank, replica_world_size) ) + self._replica_rank = replica_rank + self._replica_world_size = replica_world_size # For fixed with spares we need to ensure that we don't have more # participating replicas than the min replica size. @@ -691,6 +730,7 @@ def _async_quorum( ): self._participating_replica_rank = None + quorum_changed = False if quorum_id != self._quorum_id: self.quorum_logger.info( "", @@ -706,7 +746,11 @@ def _async_quorum( f"{store_address}/torchft/{quorum_id}/{self._group_rank}" ) - self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}") + self._logger.info( + f"reconfiguring for {quorum_id=} {store_prefixed_addr=}, " + f"{self._participating_replica_world_size=}, " + f"{self._participating_replica_rank=}" + ) # We use the replica rank and world as we want all replicas in the PG. 
try: self._quorum_id = quorum_id @@ -737,6 +781,7 @@ def _async_quorum( self._logger.exception(f"got exception in pg configure: {e}") self.report_error(e) return + quorum_changed = True if allow_heal: # run recovery on the recovery stream if available @@ -807,6 +852,42 @@ def _async_quorum( else None ) + # reconfigure dataloader after healing so that we can get offset from other replica group + if quorum_changed and self._dataloader_fn: + self.reconfigure_dataloader() + self._dataloader_dirty = True + + def get_batch_samples( + self, epoch=0, num_batches=None, batch_size=None, total_batch_size=None + ): + # In general, `start_quorum` might not have been called during the first loop, + # and the dataloader might not have been initialized yet. In this case, we should + # return immediately and set the dirty flag to avoid computation and commit. + if not self._dataloader_iter: + self._dataloader_dirty = True + return [] + # If the recovery worker is behind the current epoch, we should skip computation and commit. + if epoch < self._loaded_epoch: + return None + + if total_batch_size != None and batch_size != None: + num_batches = total_batch_size // (batch_size * self._replica_world_size) + + assert num_batches is not None, ( + "num_batches must be specified or " + "total_batch_size and batch_size must be specified" + ) + + batch_samples = [] + for _ in range(num_batches): + try: + batch_samples.append(next(self._dataloader_iter)) + except StopIteration: + break + self._dataloader_dirty = False + self._accumulation_steps = len(batch_samples) + return batch_samples if batch_samples else None + def _update_fr_path(self) -> None: """ Update the path that flight recorder will dump the traces to. 
def load_state_dict(self, state_dict: Dict[str, int]) -> None:
    """
    Load the state dict from a previous checkpoint.

    Restores the step count, the total batches committed, and the
    dataloader bookkeeping (epoch / batches committed within the current
    epoch) so training data stays continuous across restarts.

    Args:
        state_dict: the state dict to load
    """
    self._step = state_dict["step"]
    self._batches_committed = state_dict["batches_committed"]
    # Checkpoints written before dataloader tracking was added do not
    # contain these keys; default to 0 so old checkpoints still load.
    self._loaded_epoch = state_dict.get("epoch", 0)
    self._loaded_current_batches_committed = state_dict.get(
        "current_batches_committed", 0
    )
    if self._loaded_epoch == 0:
        # Epoch 0 is already in progress, so apply the committed-batch
        # offset immediately; later epochs are applied in next_epoch().
        self._epoch = 0
        self._current_batches_committed = self._loaded_current_batches_committed
def state_dict(self) -> Dict[str, int]:
    """
    Get the state dict for this manager.

    This can be used to checkpoint the state of the manager so it can be
    restored on recovery, including the dataloader position (epoch and
    batches committed within the current epoch).

    Returns:
        the state dict for this manager
    """
    # NOTE: removed a leftover debug print that wrote to stdout on every
    # checkpoint; library code should not print.
    return {
        "step": self._step,
        "batches_committed": self._batches_committed,
        "epoch": self._epoch,
        "current_batches_committed": self._current_batches_committed,
    }
@patch("torchft.manager.ManagerClient", autospec=True)
def test_dataloader_after_quorum(self, client_mock: MagicMock) -> None:
    """End-to-end check that the managed dataloader keeps the sample
    stream continuous when the replica world size changes at a quorum.

    Fixes vs. the original version:
    - ``get_batch_samples(1)`` bound the positional arg to ``epoch``,
      not ``num_batches`` (and the bare call in 3.3 would trip the
      num_batches assertion) — pass ``num_batches=1`` by keyword.
    - ``assertTrue(len(inputs), 4)`` never fails (the second argument is
      the failure message) — use ``assertEqual``/``assertLessEqual``.
    """
    # 1 Initial
    dataset_len = 1000
    batch_size = 4
    dataset = DummyDataset(dataset_len)
    committed_batches = 0
    store = TCPStore(
        host_name="localhost", port=0, is_master=True, wait_for_workers=False
    )

    def dataloader_fn(replica_world_size, replica_rank, batches_committed):
        sampler = SkipDistributedSampler(
            dataset=dataset,
            num_replicas=replica_world_size,
            rank=replica_rank,
            shuffle=False,
            seed=0,
            drop_last=False,
            skip_samples=batches_committed * batch_size,
        )
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=replica_world_size,
            sampler=sampler,
        )
        return dataloader

    def expected_samples(world_size, rank, committed_batches, expected_len=None):
        # Sampler is unshuffled, so element i of this rank's next batch is
        # global index committed*batch_size + rank + i*world_size.
        expected_len = expected_len if expected_len is not None else batch_size
        expected = [
            committed_batches * batch_size + rank + i * world_size
            for i in range(expected_len)
        ]
        return [x % dataset_len for x in expected]

    # Create manager
    manager = self._create_manager(dataloader_fn=dataloader_fn)
    manager.set_epoch(0)

    # mock for should_commit
    client_mock().should_commit = mock_should_commit

    # mock for quorum
    quorum = QuorumResult()
    quorum.quorum_id = 123
    quorum.replica_rank = 1
    quorum.replica_world_size = 2
    quorum.recover_src_manager_address = "manager address"
    quorum.store_address = f"localhost:{store.port}"
    quorum.max_step = 1
    quorum.max_replica_rank = 1
    quorum.max_world_size = 2
    quorum.heal = False

    # 2 The initial state has 2 replicas
    quorum.replica_world_size = 2
    client_mock()._quorum.return_value = quorum

    # 2.1 Get sampler first time without quorum, then we get empty batches
    batches = manager.get_batch_samples(num_batches=1)
    self.assertNotEqual(batches, None)
    self.assertEqual(len(batches), 0)

    # 2.2 Start quorum, then reinit dataloader
    manager.start_quorum()
    manager.wait_quorum()
    batches = manager.get_batch_samples(num_batches=1)
    self.assertEqual(len(batches), 1)
    for inputs in batches:
        self.assertEqual(len(inputs), batch_size)
        self.assertEqual(
            inputs.tolist(),
            expected_samples(
                quorum.replica_world_size, quorum.replica_rank, committed_batches
            ),
        )

    # 2.3 Call should commit to increment committed batches, then get samples
    manager.should_commit()
    committed_batches += quorum.replica_world_size
    batches = manager.get_batch_samples(num_batches=1)
    self.assertEqual(len(batches), 1)
    for inputs in batches:
        self.assertEqual(len(inputs), batch_size)
        self.assertEqual(
            inputs.tolist(),
            expected_samples(
                quorum.replica_world_size, quorum.replica_rank, committed_batches
            ),
        )

    # 3 Start quorum to increment step and replica world size to 3
    quorum.quorum_id = 124
    quorum.replica_world_size = 3
    client_mock()._quorum.return_value = quorum
    manager.start_quorum()
    manager.wait_quorum()

    # 3.1 Get sample after quorum with 3 replicas, and set dirty flag to
    # mock the dataloader being reinitialized.
    batches = manager.get_batch_samples(num_batches=1)
    manager._dataloader_dirty = True
    self.assertEqual(len(batches), 1)
    for inputs in batches:
        self.assertEqual(len(inputs), batch_size)
        self.assertEqual(
            inputs.tolist(),
            expected_samples(
                quorum.replica_world_size, quorum.replica_rank, committed_batches
            ),
        )
    # When the dataloader is dirty, should not commit
    self.assertFalse(manager.should_commit())
    # reset the dirty flag
    manager._dataloader_dirty = False

    # 3.2 Call should commit to increment committed batches
    manager.should_commit()
    committed_batches += quorum.replica_world_size
    batches = manager.get_batch_samples(num_batches=1)
    self.assertEqual(len(batches), 1)
    for inputs in batches:
        self.assertEqual(len(inputs), batch_size)
        self.assertEqual(
            inputs.tolist(),
            expected_samples(
                quorum.replica_world_size, quorum.replica_rank, committed_batches
            ),
        )

    # 3.3 Continue to get samples until the dataloader is exhausted
    while (batches := manager.get_batch_samples(num_batches=1)) is not None:
        committed_batches += quorum.replica_world_size
        self.assertEqual(len(batches), 1)
        for inputs in batches:
            # the final batch may be smaller than batch_size
            self.assertLessEqual(len(inputs), batch_size)
            self.assertEqual(
                inputs.tolist(),
                expected_samples(
                    quorum.replica_world_size,
                    quorum.replica_rank,
                    committed_batches,
                    expected_len=len(inputs.tolist()),
                ),
            )
class Net(nn.Module):
    """Small LeNet-style CNN for 32x32 RGB inputs (e.g. CIFAR-10)."""

    def __init__(self):
        super().__init__()
        # Two conv+pool stages: 3x32x32 -> 6x14x14 -> 16x5x5.
        conv_stages = [
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        ]
        self.cnn = nn.Sequential(*conv_stages)
        num_classes = 10
        head_layers = [
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        features = self.cnn(x)
        flat = features.flatten(1)  # flatten all dimensions except batch
        return self.classifier(flat)
+ class UnbufferedFileHandler(logging.FileHandler): + def emit(self, record): + super().emit(record) + self.flush() + os.fsync(self.stream.fileno()) + + loss_logger = logging.getLogger("loss") + loss_logger.setLevel(logging.INFO) + loss_logger.propagate = False + file_handler = UnbufferedFileHandler( + "./tmp/train_ddp_fix_batch/loss.txt", encoding="utf-8" + ) + loss_logger.addHandler(file_handler) + return loss_logger + + +def main() -> None: + loss_logger = setup_logger() + + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] + ) + trainset = torchvision.datasets.CIFAR10( + root="./cifar", train=True, download=True, transform=transform + ) + + def load_state_dict(state_dict): + print("Received checkpoint!") + m.load_state_dict(state_dict["model"]) + optimizer.load_state_dict(state_dict["optim"]) + + def state_dict(): + ret = { + "model": m.state_dict(), + "optim": optimizer.state_dict(), + } + print("Setup checkpoint to send!") + return ret + + if torch.cuda.is_available(): + device = "cuda" + pg = ProcessGroupNCCL(timeout=timedelta(seconds=30)) + elif torch.xpu.is_available(): + device = "xpu" + pg = ProcessGroupXCCL(timeout=timedelta(seconds=30)) + else: + device = "cpu" + pg = ProcessGroupGloo(timeout=timedelta(seconds=5)) + + transport = PGTransport( + pg, + timeout=timedelta(seconds=10), + device=( + "cuda" + if torch.cuda.is_available() + else "xpu" + if torch.xpu.is_available() + else "cpu" + ), + ) + + def dataloader_fn(replica_world_size, replica_rank, current_batches_committed): + sampler = SkipDistributedSampler( + dataset=trainset, + num_replicas=1, + rank=0, + shuffle=True, + seed=0, + drop_last=True, + skip_samples=current_batches_committed * BATCH_SIZE, + ) + batch_sampler = DistributedBatchSampler( + sampler=sampler, + batch_size=BATCH_SIZE, + drop_last=True, + num_replicas=replica_world_size, + rank=replica_rank, + even_batches=True, + ) + + dataloader = DataLoader(trainset, 
num_workers=0, batch_sampler=batch_sampler) + print( + f"num_batches remaining: {len(dataloader)}, dataset length: {len(trainset)}," + f"sampler length: {len(sampler)}, replica_world_size: {replica_world_size}," + f"replica_rank: {replica_rank}, batches_committed: {current_batches_committed}" + ) + + return dataloader + + manager = Manager( + pg=pg, + min_replica_size=1, + load_state_dict=load_state_dict, + state_dict=state_dict, + replica_id=f"train_ddp_{REPLICA_GROUP_ID}", + timeout=timedelta(seconds=30), + checkpoint_transport=transport, + dataloader_fn=dataloader_fn, + accumulation_grad=True, + ) + + m = Net().to(device) + if os.path.exists(INIT_CHECKPOINT_PATH): + # Load from the same model to ensure that each experiment has the same initial state. + print(f"Loading initial model from {INIT_CHECKPOINT_PATH}") + init_state_dict = torch.load(INIT_CHECKPOINT_PATH) + m.load_state_dict(init_state_dict["model"]) + else: + print("No initial model found, training from random.") + + m = DistributedDataParallel(manager, m) + optimizer = Optimizer(manager, optim.AdamW(m.parameters())) + criterion = nn.CrossEntropyLoss() + + print(m) + num_params = sum(p.numel() for p in m.parameters()) + print(f"Total number of parameters: {num_params}") + + for epoch in range(NUM_EPOCHS): + while ( + batches := manager.get_batch_samples( + epoch=epoch, batch_size=BATCH_SIZE, total_batch_size=TOTAL_BATCH_SIZE + ) + ) is not None: + optimizer.zero_grad() + total_loss = 0.0 + accumulation_steps = len(batches) + for i in range(accumulation_steps): + inputs, labels = batches[i] + inputs = inputs.to(device) + labels = labels.to(device) + out = m(inputs) + loss = criterion(out, labels) + if i == accumulation_steps - 1: + loss.backward() + else: + with manager.no_sync(): + loss.backward() + total_loss += loss.item() + + # If errored, the optimizer step will be a no-op, and the parameter will not be updated. + # Although it is possible to use new pg to compute old batches, it is still safe. 
def save_init_model():
    """Write a randomly initialized Net checkpoint to INIT_CHECKPOINT_PATH.

    The checkpoint is only written once so that repeated experiment runs
    all start from the same initial weights.
    """
    m = Net()
    state_dict_to_save = {
        "model": m.state_dict(),
    }
    if not os.path.exists(INIT_CHECKPOINT_PATH):
        # torch.save does not create parent directories; without this a
        # fresh checkout fails with FileNotFoundError.
        os.makedirs(os.path.dirname(INIT_CHECKPOINT_PATH), exist_ok=True)
        torch.save(state_dict_to_save, INIT_CHECKPOINT_PATH)
        print("Initial model saved.")
    else:
        print("Init model already exists.")
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import hashlib +import logging +import os +import sys +from datetime import timedelta +from itertools import chain + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0)) +os.environ["NCCL_HOSTID"] = str(REPLICA_GROUP_ID) + +import torch +import torch.distributed as dist +import torchvision +import torchvision.transforms as transforms +from torch import nn, optim +from torch.distributed.checkpoint.state_dict import ( + get_model_state_dict, + get_optimizer_state_dict, + set_model_state_dict, + set_optimizer_state_dict, +) +from torch.distributed.distributed_c10d import ReduceOp +from torch.distributed.elastic.multiprocessing.errors import record +from torch.distributed.fsdp import FSDPModule, fully_shard +from torch.distributed.tensor import DTensor +from torch.utils.data import DataLoader + +from torchft import ( + Manager, + Optimizer, + process_group, + ProcessGroupGloo, + ProcessGroupNCCL, + ProcessGroupXCCL, +) +from torchft.checkpointing.pg_transport import PGTransport +from torchft.data import DistributedBatchSampler, SkipDistributedSampler + +logging.basicConfig(level=logging.INFO) + +NUM_EPOCHS = 1 +BATCH_SIZE = 4 +MODEL_SHARDING_SIZE = 2 +TOTAL_BATCH_SIZE = BATCH_SIZE * 6 * MODEL_SHARDING_SIZE +INIT_CHECKPOINT_PATH = "./tmp/train_fsdp2_fix_batch/ckpt-init" + + +def maybe_set_all_reduce_hook(model_parts: list[torch.nn.Module], replicate_pg) -> None: + def all_reduce_hook(output): + dist.all_reduce(output, group=replicate_pg, op=ReduceOp.AVG) + + def apply_set_all_reduce_hook(m): + if isinstance(m, FSDPModule): + m.set_all_reduce_hook(all_reduce_hook) + + for model_part in model_parts: + model_part.apply(apply_set_all_reduce_hook) + + +def is_first_dp(manager): + return manager.participating_rank() == 0 and dist.get_rank() == 0 + + +class Net(nn.Module): + def 
def setup_logger():
    """Create the non-propagating "loss" logger backed by an unbuffered
    file handler so loss history survives an abrupt process failure.

    Returns:
        the configured ``logging.Logger`` named "loss"
    """

    # Use UnbufferedFileHandler to avoid losing logs in case of failure.
    class UnbufferedFileHandler(logging.FileHandler):
        def emit(self, record):
            super().emit(record)
            self.flush()
            os.fsync(self.stream.fileno())

    log_path = "./tmp/train_fsdp2_fix_batch/loss.txt"
    # FileHandler does not create parent directories; without this a fresh
    # checkout crashes with FileNotFoundError when opening the log file.
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    loss_logger = logging.getLogger("loss")
    loss_logger.setLevel(logging.INFO)
    loss_logger.propagate = False
    file_handler = UnbufferedFileHandler(log_path, encoding="utf-8")
    loss_logger.addHandler(file_handler)
    return loss_logger
num_replicas=1, + rank=0, + shuffle=True, + seed=0, + drop_last=True, + skip_samples=current_batches_committed * BATCH_SIZE * MODEL_SHARDING_SIZE, + ) + batch_sampler = DistributedBatchSampler( + sampler=sampler, + batch_size=BATCH_SIZE, + drop_last=True, + num_replicas=replica_world_size * MODEL_SHARDING_SIZE, + rank=replica_rank * MODEL_SHARDING_SIZE + + dist.get_rank() % MODEL_SHARDING_SIZE, + even_batches=True, + ) + + dataloader = DataLoader(trainset, num_workers=0, batch_sampler=batch_sampler) + print( + f"num_batches remaining: {len(dataloader)}, dataset length: {len(trainset)}," + f"sampler length: {len(sampler)}, replica_world_size: {replica_world_size}," + f"replica_rank: {replica_rank}, batches_committed: {current_batches_committed}" + ) + + return dataloader + + manager = Manager( + pg=pg, + min_replica_size=1, + load_state_dict=None, + state_dict=None, + replica_id=f"train_ddp_{REPLICA_GROUP_ID}", + timeout=timedelta(seconds=30), + checkpoint_transport=transport, + dataloader_fn=dataloader_fn, + ) + + m = Net() + criterion = nn.CrossEntropyLoss() + if os.path.exists(INIT_CHECKPOINT_PATH): + print(f"Loading initial model from {INIT_CHECKPOINT_PATH}") + init_state_dict = torch.load(INIT_CHECKPOINT_PATH) + m.load_state_dict(init_state_dict["model"]) + else: + print("No initial model found, training from random.") + torch.cuda.set_device(int(local_rank)) + + # Apply FSDP sharding + for layer in chain(m.cnn, m.classifier): + fully_shard(layer, reshard_after_forward=True) + m = fully_shard(m, reshard_after_forward=True) + + # Create optimizer by sharding model parameters + base_optimizer = optim.AdamW(m.parameters()) + optimizer = Optimizer(manager, base_optimizer) + + replicate_pg = process_group.ManagedProcessGroup(manager) + maybe_set_all_reduce_hook(model_parts=[m], replicate_pg=replicate_pg) + + def load_state_dict(state_dict): + # It's necessary to ensure that `set_model_state_dict` does not trigger `optim.step`, + # as this operation may occur 
concurrently with both forward and backward iterations. + print("Received checkpoint!") + set_model_state_dict(m, state_dict["model"]) + set_optimizer_state_dict(m, base_optimizer, state_dict["optim"]) + + def state_dict(): + # It's necessary to ensure that `get_model_state_dict` does not trigger `optim.step`, + # as this operation may occur concurrently with both forward and backward iterations. + ret = { + "model": get_model_state_dict(m), + "optim": get_optimizer_state_dict(m, base_optimizer), + } + print("Setup checkpoint to send!") + return ret + + manager.register_state_dict_fn("set_state_dict_fns", load_state_dict, state_dict) + + print(m) + num_params = sum(p.numel() for p in m.parameters()) + print(f"Total number of parameters: {num_params}") + + for epoch in range(NUM_EPOCHS): + while ( + batches := manager.get_batch_samples( + epoch=epoch, + batch_size=BATCH_SIZE, + total_batch_size=TOTAL_BATCH_SIZE // MODEL_SHARDING_SIZE, + ) + ) is not None: + optimizer.zero_grad() + total_loss = 0.0 + accumulation_steps = len(batches) + for i in range(accumulation_steps): + inputs, labels = batches[i] + inputs = inputs.to(device) + labels = labels.to(device) + out = m(inputs) + loss = criterion(out, labels) + # For fsdp2, synchronization must be performed every time; `no_sync` cannot be used. + # This is because if `all_reduce_hook` is executed in `_fsdp_collectives.py`, it applies + # to the temporary `reduce_output` variable. If synchronization is only performed in the + # last step, each shard will lose the gradients of the previous `accumulation_steps - 1` + # steps from other shards. 
+ loss.backward() + total_loss += loss.item() + + if accumulation_steps > 1: + for group in base_optimizer.param_groups: + for param in group["params"]: + if param.grad is not None: + if isinstance(param.grad, DTensor): + param.grad.data._local_tensor.div_(accumulation_steps) + else: + param.grad.data.div_(accumulation_steps) + + # If errored, the optimizer step will be a no-op, and the parameter will not be updated. + # Although it is possible to use new pg to compute old batches, it is still safe. + if not optimizer.step(): + # For fsdp2, the model may be updated in should_commit. We must wait for all model shard + # to finish loading before proceeding; otherwise, inconsistencies may occur. + dist.barrier() + # The first call to `get_batch_samples` will return empty and mark the dataloader as dirty. + # The manager server will force synchronization for `_step` being 0. If `_step` doesn't + # increment here, it will cause synchronization checkpoints twice because `_step` was 0 in + # the first two rounds. The second checkpoint will run in parallel with the computation, + # leading to pollution. Therefore, it's necessary to avoid having `_step` 0 in two + # consecutive training rounds. + if manager._step == 0: + manager._step += 1 + continue + + loss_tensor = torch.tensor(total_loss, device=device) + # Perform allreduce within replica group. Then perform allreduce across replica groups. + dist.all_reduce(loss_tensor, op=ReduceOp.AVG) + manager.allreduce(loss_tensor).wait() + avg_loss = loss_tensor.item() + avg_loss /= accumulation_steps + if is_first_dp(manager): + loss_logger.info(f"{manager.current_step() - 1} {avg_loss}") + if manager.current_step() % 10 == 0: + print( + f"Epoch {epoch + 1}, step = {manager.current_step() - 1}, " + f"batch_committed: {manager.batches_committed()}, Loss: {avg_loss:.4f}" + ) + print( + f"Epoch {epoch + 1} completed, batches_committed {manager.batches_committed()}." 
def save_init_model():
    """Write a randomly initialized Net checkpoint to INIT_CHECKPOINT_PATH.

    The checkpoint is only written once so that repeated experiment runs
    all start from the same initial weights.
    """
    m = Net()
    state_dict_to_save = {
        "model": m.state_dict(),
    }
    if not os.path.exists(INIT_CHECKPOINT_PATH):
        # torch.save does not create parent directories; without this a
        # fresh checkout fails with FileNotFoundError.
        os.makedirs(os.path.dirname(INIT_CHECKPOINT_PATH), exist_ok=True)
        torch.save(state_dict_to_save, INIT_CHECKPOINT_PATH)
        print("Initial model saved.")
    else:
        print("Init model already exists.")