Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 86 additions & 2 deletions megatron/core/dist_checkpointing/exchange_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def distribute_shards_to_ranks(
shard_to_size: Dict[T, int],
num_ranks: int,
cross_parallelization_group_loads: Set[T],
shard_to_source_rank: Optional[Dict[T, int]] = None,
) -> Dict[T, int]:
"""Computes uniform distribution of workload across ranks, based on sizes.

Expand All @@ -135,11 +136,22 @@ def distribute_shards_to_ranks(
Last step is added because we rely on the fact that
the assignment is deterministic on all ranks.

When ``shard_to_source_rank`` is provided (load path), shards are first grouped
by their source rank (the local rank whose checkpoint file they originate from).
Each file-group is assigned as a unit to the source rank itself (if it is a valid
candidate), which ensures each rank loads only from its own file. Only shards that
cannot be grouped (no source info, or the source rank is not valid) fall through
to the per-shard greedy algorithm.

Args:
shard_to_ranks (Dict[T, List[int]]): mapping of rank access to shards
shard_to_size (Dict[T, int]): sizes of each shard
num_ranks (int): number of ranks in the parallelization group
cross_parallelization_group_loads (Set[T]): Shards to load that are not in the main replica
shard_to_source_rank (Dict[T, int], optional): source-rank hint for each shard
(local rank in the parallelization group whose checkpoint file contains
this shard). When provided, the algorithm groups shards by source rank
and preferentially assigns each group to that rank.

Returns (Dict[T, int]): assignment of shard to rank (which rank should do the work
to achieve maximal uniformity)
Expand All @@ -148,9 +160,53 @@ def distribute_shards_to_ranks(
shard_to_saving_rank = {}
rank_sizes = [(0, rank) for rank in range(num_ranks)]

# --- File-locality-aware grouping (load path) ---
# When source-rank hints are available, assign whole file-groups at once so
# each loading rank opens as few checkpoint files as possible.
ungrouped_shards: Dict[T, tuple] = {}
if shard_to_source_rank:
source_groups: Dict[int, List[T]] = defaultdict(list)
for shard_id in shard_to_ranks:
source = shard_to_source_rank.get(shard_id)
if source is not None:
source_groups[source].append(shard_id)
else:
ungrouped_shards[shard_id] = shard_to_ranks[shard_id]

# Process largest file-groups first for better balance.
for source, shards in sorted(
source_groups.items(),
key=lambda kv: (-sum(shard_to_size.get(s, 0) for s in kv[1]), kv[0]),
):
group_total = sum(shard_to_size.get(s, 0) for s in shards)
# Valid ranks = intersection of all shards' rank lists in this group.
valid_ranks: Optional[set] = None
for shard_id in shards:
ranks_set = set(shard_to_ranks[shard_id])
valid_ranks = ranks_set if valid_ranks is None else valid_ranks & ranks_set
if valid_ranks:
# Prefer the source rank (so this rank reads from its own file).
if source in valid_ranks:
rank = source
else:
# Fall back to least-loaded valid rank.
_, rank = min(
(sz, r) for sz, r in rank_sizes if r in valid_ranks
)
for shard_id in shards:
shard_to_saving_rank[shard_id] = rank
rank_sizes[rank] = (rank_sizes[rank][0] + group_total, rank)
else:
# Cannot assign whole group to one rank -- fall through to per-shard greedy.
for shard_id in shards:
ungrouped_shards[shard_id] = shard_to_ranks[shard_id]
else:
ungrouped_shards = dict(shard_to_ranks)

# --- Per-shard greedy assignment (original algorithm) for ungrouped shards ---
# start from tensors of lowest coverage, then go by tensor size from largest (hence minus size)
for shard_id, shard_ranks in sorted(
shard_to_ranks.items(),
ungrouped_shards.items(),
key=lambda sh_id_ranks: (
# 0 if rank is not in cross_parallelization_group_loads
# which means it has higher priority
Expand All @@ -175,6 +231,7 @@ def determine_main_replica_uniform_distribution(
sharded_state_dict: ShardedStateDict,
parallelization_group: torch.distributed.ProcessGroup,
ignore_groups: bool = False,
key_to_source_ranks: Optional[Dict[str, List[int]]] = None,
) -> Optional[ShardDistribution]:
"""Computes the save distribution.

Expand All @@ -192,6 +249,13 @@ def determine_main_replica_uniform_distribution(
This option is primarily used during loading, as it ensures that all replicas,
including non-main ones, are loaded by this parallelization group
Defaults to False.
key_to_source_ranks (Dict[str, List[int]], optional): mapping from FQN
(ShardedTensor key) to the list of global ranks that saved this tensor.
When provided during loading, shards are grouped by source file and
preferentially assigned to the rank corresponding to the source file
so that each rank reads from a single file. For TP-sharded tensors
the same FQN may have entries from multiple TP groups; only the source
rank belonging to the current parallelization group is used.

Returns (ShardDistribution, optional): distribution that can be used to apply the
parallelization. Returns None if the process_group is trivial (1 rank)
Expand Down Expand Up @@ -243,8 +307,28 @@ def determine_main_replica_uniform_distribution(
# Filter out shards that don't belong to this group
shard_to_ranks = {k: v for k, v in shard_to_ranks.items() if k in shards_in_this_group}

# Build shard-level source-rank mapping (global → local) for file-locality assignment.
shard_to_source_rank: Optional[Dict[_ShardId, int]] = None
if key_to_source_ranks:
group_global_ranks = torch.distributed.get_process_group_ranks(parallelization_group)
group_global_ranks_set = set(group_global_ranks)
global_to_local = {g: l for l, g in enumerate(group_global_ranks)}
shard_to_source_rank = {}
for shard_id in shard_to_ranks:
fqn = shard_id[0] # first element of _ShardId is the key (FQN)
source_ranks = key_to_source_ranks.get(fqn, [])
# Pick the source rank that belongs to the current parallelization group.
for global_src in source_ranks:
if global_src in group_global_ranks_set:
shard_to_source_rank[shard_id] = global_to_local[global_src]
break

shard_to_saving_rank = distribute_shards_to_ranks(
shard_to_ranks, shard_to_size, len(all_shards), cross_parallelization_group_loads
shard_to_ranks,
shard_to_size,
len(all_shards),
cross_parallelization_group_loads,
shard_to_source_rank=shard_to_source_rank,
)

return ShardDistribution(
Expand Down
127 changes: 123 additions & 4 deletions megatron/core/dist_checkpointing/strategies/fully_parallel.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
import re
from collections import defaultdict
from pathlib import Path
from time import time
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -44,6 +46,89 @@

T = TypeVar('T', ShardedObject, ShardedTensor)

# Regex to extract the global rank from a distcp file name, e.g. "__17_0.distcp" → 17
_DISTCP_RANK_RE = re.compile(r'^__(\d+)_')


def _build_key_to_source_ranks(checkpoint_dir: Path) -> Dict[str, List[int]]:
    """Map each FQN to the ordered list of global ranks that saved data for it.

    The source rank is parsed out of each checkpoint file name
    (e.g. ``__17_0.distcp`` → rank 17). A TP-sharded tensor may be stored
    under the same FQN by several TP groups, so every distinct source rank is
    retained; the caller picks the one belonging to its own parallelization
    group.

    Args:
        checkpoint_dir: path to the checkpoint directory.

    Returns:
        Mapping from FQN to the deduplicated, discovery-ordered list of global
        ranks that saved data for it.
    """
    from .torch import _get_filesystem_reader

    metadata = _get_filesystem_reader(checkpoint_dir).read_metadata()

    storage_data = metadata.storage_data
    if not storage_data:
        return {}

    result: Dict[str, List[int]] = {}
    for meta_idx, storage_info in storage_data.items():
        match = _DISTCP_RANK_RE.match(storage_info.relative_path)
        if match is None:
            # File name does not follow the "__<rank>_<idx>.distcp" pattern.
            continue
        source_rank = int(match.group(1))
        ranks = result.setdefault(meta_idx.fqn, [])
        if source_rank not in ranks:
            ranks.append(source_rank)

    return result


# Regex matching the TP-replica postfix appended by make_sharded_tensor_for_checkpoint.
_TP_REPLICA_POSTFIX_RE = re.compile(r'\.__tp_replica_\d+$')


def _checkpoint_has_tp_replica_keys(checkpoint_dir: Path) -> bool:
    """Tell whether the checkpoint was saved with TP-replica postfix keys.

    Checkpoints written before the postfix change store replicated tensors
    under their plain FQN, while newer ones append ``.__tp_replica_{tp_rank}``
    so every TP rank owns a distinct entry. The answer drives backward
    compatibility: for an old checkpoint the postfix must be stripped from the
    load-side ShardedTensors so the keys line up with what is stored on disk.

    Args:
        checkpoint_dir: path to the checkpoint directory.

    Returns:
        True iff any metadata key contains the TP-replica marker.
    """
    from .torch import _get_filesystem_reader

    metadata = _get_filesystem_reader(checkpoint_dir).read_metadata()
    return any('.__tp_replica_' in key for key in metadata.state_dict_metadata)


def _strip_tp_replica_postfix(sharded_state_dict: ShardedStateDict) -> None:
    """Remove the ``.__tp_replica_X`` postfix from ShardedTensor keys in place.

    Backward-compatibility shim for checkpoints saved before the TP-replica
    postfix change: the ShardedTensor objects inside *sharded_state_dict* are
    mutated directly (nothing is copied). Logs a summary when at least one key
    was rewritten.
    """
    stripped = 0
    for value in nested_values(sharded_state_dict):
        if not isinstance(value, ShardedTensor):
            continue
        if _TP_REPLICA_POSTFIX_RE.search(value.key) is None:
            continue
        value.key = _TP_REPLICA_POSTFIX_RE.sub('', value.key)
        stripped += 1

    if stripped:
        logger.info(
            f'Stripped .__tp_replica_* postfix from {stripped} ShardedTensor keys '
            f'(backward compat with old checkpoint format)'
        )


class FullyParallelSaveStrategyWrapper(AsyncSaveShardedStrategy):
"""Wraps arbitrary strategy and distributes the save during `save`.
Expand Down Expand Up @@ -217,13 +302,23 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St

loaded_state_dict = {}

# Step 0: Backward compatibility — detect old checkpoints that were saved
# without TP-replica postfix keys and strip the postfix from the load-side
# ShardedTensors so that keys match the checkpoint metadata.
try:
if not _checkpoint_has_tp_replica_keys(checkpoint_dir):
logger.info('Old checkpoint format detected (no TP-replica postfix keys)')
_strip_tp_replica_postfix(sharded_state_dict)
except Exception as e:
logger.warning(f'Failed to check for TP-replica keys, skipping compat: {e}')

if get_pg_size(self.parallelization_group) <= 1:
return self.base_strategy.load(sharded_state_dict, checkpoint_dir)

# Step 1 and 2: exchange load metadata and distribute the load
with debug_time("self.apply_loading_parallelization", logger):
precomputed_distribution: ShardDistribution | None = self.apply_loading_parallelization(
sharded_state_dict
sharded_state_dict, checkpoint_dir=checkpoint_dir
)
assert (
precomputed_distribution is not None
Expand Down Expand Up @@ -352,7 +447,9 @@ def fill_in_deferred_sharded_tensors(
)

def apply_loading_parallelization(
self, sharded_state_dict: ShardedStateDict
self,
sharded_state_dict: ShardedStateDict,
checkpoint_dir: Optional[Path] = None,
) -> Optional[ShardDistribution]:
"""Distributes the load across ranks by exchanging metadata.

Expand All @@ -365,8 +462,14 @@ def apply_loading_parallelization(
the calls and subsequent distributions happen without any inter-rank
communication.

When ``checkpoint_dir`` is provided, reads checkpoint metadata to build a
source-rank mapping so that each rank preferentially loads from its own
checkpoint file, minimising the number of distinct files opened per rank.

Args:
sharded_state_dict (ShardedStateDict): state dict to distribute the loading
checkpoint_dir (Path, optional): checkpoint directory used to extract
file-locality information for the load distribution.

Returns:
ShardDistribution (optional): the computed loading distribution
Expand All @@ -376,8 +479,24 @@ def apply_loading_parallelization(
precomputed_distribution = self.cached_distribution
else:
logger.debug(f'Apply load parallelization')
# Build FQN → source-rank mapping for file-locality-aware assignment.
key_to_source_ranks = None
if checkpoint_dir is not None:
try:
key_to_source_ranks = _build_key_to_source_ranks(checkpoint_dir)
logger.debug(
f'Built key_to_source_ranks with {len(key_to_source_ranks)} entries'
)
except Exception as e:
logger.warning(
f'Failed to build key_to_source_ranks, '
f'falling back to default distribution: {e}'
)
precomputed_distribution = determine_main_replica_uniform_distribution(
sharded_state_dict, self.parallelization_group, True
sharded_state_dict,
self.parallelization_group,
True,
key_to_source_ranks=key_to_source_ranks,
)

distribute_main_replicas_with_precomputed_distribution(
Expand Down
11 changes: 10 additions & 1 deletion megatron/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,8 +1041,17 @@ def make_sharded_tensor_for_checkpoint(tensor, key, prepend_offsets=(), replica_
tensor = get_full_tensor_if_necessary(tensor)
new_offsets.append((prepend_axis_num, dp_rank, dp_size))

tp_rank = parallel_state.get_tensor_model_parallel_rank()
tp_size = parallel_state.get_tensor_model_parallel_world_size()

# Give each TP rank a unique checkpoint key for replicated tensors.
# This ensures each TP rank saves/loads from its own file, eliminating
# cross-TP-group file reads on distributed filesystems.
if tp_size > 1:
key = f'{key}.__tp_replica_{tp_rank}'

if replica_id is None:
replica_id = (0, get_pg_rank(tp_group), dp_replica_id)
replica_id = (0, 0, dp_replica_id)

return ShardedTensor.from_rank_offsets(
key,
Expand Down