diff --git a/megatron/core/dist_checkpointing/exchange_utils.py b/megatron/core/dist_checkpointing/exchange_utils.py index 79f906b237a..7b0defdfcc6 100644 --- a/megatron/core/dist_checkpointing/exchange_utils.py +++ b/megatron/core/dist_checkpointing/exchange_utils.py @@ -120,6 +120,7 @@ def distribute_shards_to_ranks( shard_to_size: Dict[T, int], num_ranks: int, cross_parallelization_group_loads: Set[T], + shard_to_source_rank: Optional[Dict[T, int]] = None, ) -> Dict[T, int]: """Computes uniform distribution of workload across ranks, based on sizes. @@ -135,11 +136,22 @@ def distribute_shards_to_ranks( Last step is added because we rely on the fact that the assignment is deterministic on all ranks. + When ``shard_to_source_rank`` is provided (load path), shards are first grouped + by their source rank (the local rank whose checkpoint file they originate from). + Each file-group is assigned as a unit to the source rank itself (if it is a valid + candidate), which ensures each rank loads only from its own file. Only shards that + cannot be grouped (no source info, or the source rank is not valid) fall through + to the per-shard greedy algorithm. + Args: shard_to_ranks (Dict[T, List[int]]): mapping of rank access to shards shard_to_size (Dict[T, int]): sizes of each shard num_ranks (int): number of ranks in the parallelization group cross_parallelization_group_loads (Set[T]): Shards to load that are not in the main replica + shard_to_source_rank (Dict[T, int], optional): source-rank hint for each shard + (local rank in the parallelization group whose checkpoint file contains + this shard). When provided, the algorithm groups shards by source rank + and preferentially assigns each group to that rank. 
    Returns (Dict[T, int]): assignment of shard to rank
        (which rank should do the work to achieve maximal uniformity)
@@ -148,9 +160,53 @@ def distribute_shards_to_ranks(
     shard_to_saving_rank = {}
     rank_sizes = [(0, rank) for rank in range(num_ranks)]
+    # --- File-locality-aware grouping (load path) ---
+    # When source-rank hints are available, assign whole file-groups at once so
+    # each loading rank opens as few checkpoint files as possible.
+    ungrouped_shards: Dict[T, List[int]] = {}
+    if shard_to_source_rank:
+        source_groups: Dict[int, List[T]] = defaultdict(list)
+        for shard_id in shard_to_ranks:
+            source = shard_to_source_rank.get(shard_id)
+            if source is not None:
+                source_groups[source].append(shard_id)
+            else:
+                ungrouped_shards[shard_id] = shard_to_ranks[shard_id]
+
+        # Process largest file-groups first for better balance.
+        for source, shards in sorted(
+            source_groups.items(),
+            key=lambda kv: (-sum(shard_to_size.get(s, 0) for s in kv[1]), kv[0]),
+        ):
+            group_total = sum(shard_to_size.get(s, 0) for s in shards)
+            # Valid ranks = intersection of all shards' rank lists in this group.
+            valid_ranks: Optional[set] = None
+            for shard_id in shards:
+                ranks_set = set(shard_to_ranks[shard_id])
+                valid_ranks = ranks_set if valid_ranks is None else valid_ranks & ranks_set
+            if valid_ranks:
+                # Prefer the source rank (so this rank reads from its own file).
+                if source in valid_ranks:
+                    rank = source
+                else:
+                    # Fall back to least-loaded valid rank.
+                    _, rank = min(
+                        (sz, r) for sz, r in rank_sizes if r in valid_ranks
+                    )
+                for shard_id in shards:
+                    shard_to_saving_rank[shard_id] = rank
+                rank_sizes[rank] = (rank_sizes[rank][0] + group_total, rank)
+            else:
+                # Cannot assign whole group to one rank -- fall through to per-shard greedy.
+ for shard_id in shards: + ungrouped_shards[shard_id] = shard_to_ranks[shard_id] + else: + ungrouped_shards = dict(shard_to_ranks) + + # --- Per-shard greedy assignment (original algorithm) for ungrouped shards --- # start from tensors of lowest coverage, then go by tensor size from largest (hence minus size) for shard_id, shard_ranks in sorted( - shard_to_ranks.items(), + ungrouped_shards.items(), key=lambda sh_id_ranks: ( # 0 if rank is not in cross_parallelization_group_loads # which means it has higher priority @@ -175,6 +231,7 @@ def determine_main_replica_uniform_distribution( sharded_state_dict: ShardedStateDict, parallelization_group: torch.distributed.ProcessGroup, ignore_groups: bool = False, + key_to_source_ranks: Optional[Dict[str, List[int]]] = None, ) -> Optional[ShardDistribution]: """Computes the save distribution. @@ -192,6 +249,13 @@ def determine_main_replica_uniform_distribution( This option is primarily used during loading, as it ensures that all replicas, including non-main ones, are loaded by this parallelization group Defaults to False. + key_to_source_ranks (Dict[str, List[int]], optional): mapping from FQN + (ShardedTensor key) to the list of global ranks that saved this tensor. + When provided during loading, shards are grouped by source file and + preferentially assigned to the rank corresponding to the source file + so that each rank reads from a single file. For TP-sharded tensors + the same FQN may have entries from multiple TP groups; only the source + rank belonging to the current parallelization group is used. Returns (ShardDistribution, optional): distribution that can be used to apply the parallelization. 
Returns None if the process_group is trivial (1 rank) @@ -243,8 +307,28 @@ def determine_main_replica_uniform_distribution( # Filter out shards that don't belong to this group shard_to_ranks = {k: v for k, v in shard_to_ranks.items() if k in shards_in_this_group} + # Build shard-level source-rank mapping (global → local) for file-locality assignment. + shard_to_source_rank: Optional[Dict[_ShardId, int]] = None + if key_to_source_ranks: + group_global_ranks = torch.distributed.get_process_group_ranks(parallelization_group) + group_global_ranks_set = set(group_global_ranks) + global_to_local = {g: l for l, g in enumerate(group_global_ranks)} + shard_to_source_rank = {} + for shard_id in shard_to_ranks: + fqn = shard_id[0] # first element of _ShardId is the key (FQN) + source_ranks = key_to_source_ranks.get(fqn, []) + # Pick the source rank that belongs to the current parallelization group. + for global_src in source_ranks: + if global_src in group_global_ranks_set: + shard_to_source_rank[shard_id] = global_to_local[global_src] + break + shard_to_saving_rank = distribute_shards_to_ranks( - shard_to_ranks, shard_to_size, len(all_shards), cross_parallelization_group_loads + shard_to_ranks, + shard_to_size, + len(all_shards), + cross_parallelization_group_loads, + shard_to_source_rank=shard_to_source_rank, ) return ShardDistribution( diff --git a/megatron/core/dist_checkpointing/strategies/fully_parallel.py b/megatron/core/dist_checkpointing/strategies/fully_parallel.py index be3b941c07e..499c41242c1 100644 --- a/megatron/core/dist_checkpointing/strategies/fully_parallel.py +++ b/megatron/core/dist_checkpointing/strategies/fully_parallel.py @@ -1,8 +1,10 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
import logging +import re +from collections import defaultdict from pathlib import Path from time import time -from typing import Any, Callable, Dict, Optional, Tuple, TypeVar +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar import torch import torch.distributed as dist @@ -44,6 +46,89 @@ T = TypeVar('T', ShardedObject, ShardedTensor) +# Regex to extract the global rank from a distcp file name, e.g. "__17_0.distcp" → 17 +_DISTCP_RANK_RE = re.compile(r'^__(\d+)_') + + +def _build_key_to_source_ranks(checkpoint_dir: Path) -> Dict[str, List[int]]: + """Build mapping from FQN to the list of global ranks that saved it. + + Reads the checkpoint metadata and extracts the source rank from file names + (e.g., ``__17_0.distcp`` → rank 17). For TP-sharded tensors the same FQN + may appear with different offsets from different TP groups; all source ranks + are collected so that the caller can pick the one belonging to its own + parallelization group. + + Args: + checkpoint_dir: path to the checkpoint directory. + + Returns: + Mapping from FQN to list of global ranks that saved data for it. + """ + from .torch import _get_filesystem_reader + + fs_reader = _get_filesystem_reader(checkpoint_dir) + metadata = fs_reader.read_metadata() + + key_to_source_ranks: Dict[str, List[int]] = defaultdict(list) + seen: Dict[str, set] = defaultdict(set) + + if metadata.storage_data: + for meta_idx, storage_info in metadata.storage_data.items(): + fqn = meta_idx.fqn + m = _DISTCP_RANK_RE.match(storage_info.relative_path) + if m: + rank = int(m.group(1)) + if rank not in seen[fqn]: + seen[fqn].add(rank) + key_to_source_ranks[fqn].append(rank) + + return dict(key_to_source_ranks) + + +# Regex matching the TP-replica postfix appended by make_sharded_tensor_for_checkpoint. 
+_TP_REPLICA_POSTFIX_RE = re.compile(r'\.__tp_replica_\d+$') + + +def _checkpoint_has_tp_replica_keys(checkpoint_dir: Path) -> bool: + """Return True if the checkpoint was saved with TP-replica postfix keys. + + Old checkpoints (saved before the postfix change) store replicated tensors + under their original FQN. New checkpoints append ``.__tp_replica_{tp_rank}`` + to the FQN so each TP rank has its own entry. + + The result is used for backward compatibility: when loading an old + checkpoint we strip the postfix from the load-side ShardedTensors so that + the keys match what is stored in the checkpoint. + """ + from .torch import _get_filesystem_reader + + fs_reader = _get_filesystem_reader(checkpoint_dir) + metadata = fs_reader.read_metadata() + for key in metadata.state_dict_metadata: + if '.__tp_replica_' in key: + return True + return False + + +def _strip_tp_replica_postfix(sharded_state_dict: ShardedStateDict) -> None: + """Strip ``.__tp_replica_X`` postfix from all ShardedTensor keys in place. + + Used for backward compatibility when loading from a checkpoint that was + saved before the TP-replica postfix change. Modifies the ShardedTensor + objects in *sharded_state_dict* directly (no copy). + """ + count = 0 + for item in nested_values(sharded_state_dict): + if isinstance(item, ShardedTensor) and _TP_REPLICA_POSTFIX_RE.search(item.key): + item.key = _TP_REPLICA_POSTFIX_RE.sub('', item.key) + count += 1 + if count: + logger.info( + f'Stripped .__tp_replica_* postfix from {count} ShardedTensor keys ' + f'(backward compat with old checkpoint format)' + ) + class FullyParallelSaveStrategyWrapper(AsyncSaveShardedStrategy): """Wraps arbitrary strategy and distributes the save during `save`. 
@@ -217,13 +302,23 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St loaded_state_dict = {} + # Step 0: Backward compatibility — detect old checkpoints that were saved + # without TP-replica postfix keys and strip the postfix from the load-side + # ShardedTensors so that keys match the checkpoint metadata. + try: + if not _checkpoint_has_tp_replica_keys(checkpoint_dir): + logger.info('Old checkpoint format detected (no TP-replica postfix keys)') + _strip_tp_replica_postfix(sharded_state_dict) + except Exception as e: + logger.warning(f'Failed to check for TP-replica keys, skipping compat: {e}') + if get_pg_size(self.parallelization_group) <= 1: return self.base_strategy.load(sharded_state_dict, checkpoint_dir) # Step 1 and 2: exchange load metadata and distribute the load with debug_time("self.apply_loading_parallelization", logger): precomputed_distribution: ShardDistribution | None = self.apply_loading_parallelization( - sharded_state_dict + sharded_state_dict, checkpoint_dir=checkpoint_dir ) assert ( precomputed_distribution is not None @@ -352,7 +447,9 @@ def fill_in_deferred_sharded_tensors( ) def apply_loading_parallelization( - self, sharded_state_dict: ShardedStateDict + self, + sharded_state_dict: ShardedStateDict, + checkpoint_dir: Optional[Path] = None, ) -> Optional[ShardDistribution]: """Distributes the load across ranks by exchanging metadata. @@ -365,8 +462,14 @@ def apply_loading_parallelization( the calls and subsequent distributions happen without any inter-rank communication. + When ``checkpoint_dir`` is provided, reads checkpoint metadata to build a + source-rank mapping so that each rank preferentially loads from its own + checkpoint file, minimising the number of distinct files opened per rank. + Args: sharded_state_dict (ShardedStateDict): state dict to distribute the loading + checkpoint_dir (Path, optional): checkpoint directory used to extract + file-locality information for the load distribution. 
        Returns: ShardDistribution (optional): the computed loading distribution
@@ -376,8 +479,24 @@ def apply_loading_parallelization(
             precomputed_distribution = self.cached_distribution
         else:
             logger.debug(f'Apply load parallelization')
+            # Build FQN → source-rank mapping for file-locality-aware assignment.
+            key_to_source_ranks = None
+            if checkpoint_dir is not None:
+                try:
+                    key_to_source_ranks = _build_key_to_source_ranks(checkpoint_dir)
+                    logger.debug(
+                        f'Built key_to_source_ranks with {len(key_to_source_ranks)} entries'
+                    )
+                except Exception as e:
+                    logger.warning(
+                        f'Failed to build key_to_source_ranks, '
+                        f'falling back to default distribution: {e}'
+                    )
             precomputed_distribution = determine_main_replica_uniform_distribution(
-                sharded_state_dict, self.parallelization_group, True
+                sharded_state_dict,
+                self.parallelization_group,
+                True,
+                key_to_source_ranks=key_to_source_ranks,
             )
 
         distribute_main_replicas_with_precomputed_distribution(
diff --git a/megatron/core/utils.py b/megatron/core/utils.py
index 636c76f2a84..751210fb5f7 100644
--- a/megatron/core/utils.py
+++ b/megatron/core/utils.py
@@ -1041,8 +1041,17 @@ def make_sharded_tensor_for_checkpoint(tensor, key, prepend_offsets=(), replica_
         tensor = get_full_tensor_if_necessary(tensor)
         new_offsets.append((prepend_axis_num, dp_rank, dp_size))
 
+    tp_rank = get_pg_rank(tp_group)
+    tp_size = get_pg_size(tp_group)
+
+    # Give each TP rank a unique checkpoint key for replicated tensors.
+    # This ensures each TP rank saves/loads from its own file, eliminating
+    # cross-TP-group file reads on distributed filesystems.
+    if tp_size > 1:
+        key = f'{key}.__tp_replica_{tp_rank}'
+
     if replica_id is None:
-        replica_id = (0, get_pg_rank(tp_group), dp_replica_id)
+        replica_id = (0, 0, dp_replica_id)
 
     return ShardedTensor.from_rank_offsets(
         key,