diff --git a/megatron/core/dist_checkpointing/serialization.py b/megatron/core/dist_checkpointing/serialization.py index 94c7a6cf663..02d436e1dde 100644 --- a/megatron/core/dist_checkpointing/serialization.py +++ b/megatron/core/dist_checkpointing/serialization.py @@ -19,7 +19,7 @@ from . import ShardedTensor from .core import CheckpointingConfig, save_config -from .dict_utils import extract_matching_values, merge +from .dict_utils import merge from .mapping import ( CheckpointingException, CommonStateDict, @@ -30,23 +30,16 @@ ) from .state_dict_utils import load_preprocess, save_preprocess from .strategies.async_utils import AsyncRequest -from .strategies.base import ( - AsyncSaveShardedStrategy, - LoadCommonStrategy, - LoadShardedStrategy, - SaveCommonStrategy, - SaveShardedStrategy, - StrategyAction, - get_default_strategy, -) +from .strategies.base import AsyncSaveShardedStrategy +from .strategies.common import load_common, save_common +from .strategies.torch import TorchDistLoadShardedStrategy, TorchDistSaveShardedStrategy from .utils import extract_sharded_base, force_all_tensors_to_non_fp8 from .validation import ( StrictHandling, determine_global_metadata, parse_strict_flag, validate_integrity_and_strict_load, - validate_sharded_objects_handling, - verify_checkpoint_and_load_strategy, + verify_checkpoint, ) logger = logging.getLogger(__name__) @@ -61,8 +54,8 @@ def load( sharded_state_dict: ShardedStateDict, checkpoint_dir: str, - sharded_strategy: Union[LoadShardedStrategy, Tuple[str, int], None] = None, - common_strategy: Union[LoadCommonStrategy, Tuple[str, int], None] = None, + sharded_strategy: TorchDistLoadShardedStrategy = None, + common_strategy: None = None, validate_access_integrity: bool = True, strict: Union[str, StrictHandling] = StrictHandling.ASSUME_OK_UNEXPECTED, ) -> Union[StateDict, Tuple[StateDict, Set[str], Set[str]]]: @@ -103,9 +96,11 @@ def load( StateDict or Tuple[StateDict, Set[str], Set[str]]: in most cases only the loaded state dict is returned. If `strict` flag was set to """ - sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy( - checkpoint_dir, sharded_strategy, common_strategy - ) + assert common_strategy is None + + verify_checkpoint(checkpoint_dir) + if sharded_strategy is None: + sharded_strategy = TorchDistLoadShardedStrategy() # Dequantize all FP8 tensors in the state dict into their corresponding high-precision tensors. # Retaining FP8 tensors in the state dict can cause issues in the following two cases: @@ -118,7 +113,7 @@ def load( # amax_history buffer of Transformer Engine, which is undesirable. force_all_tensors_to_non_fp8(sharded_state_dict) - common_state_dict = common_strategy.load_common(checkpoint_dir) + common_state_dict = load_common(checkpoint_dir) sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess( sharded_state_dict @@ -133,9 +128,7 @@ def load( local_metadata, global_metadata = None, None strict = parse_strict_flag(strict) if StrictHandling.requires_explicit_ckpt_mismatch_check(strict): - ckpt_sharded_metadata = load_sharded_metadata( - checkpoint_dir, sharded_strategy, common_strategy # type: ignore[arg-type] - ) + ckpt_sharded_metadata = load_sharded_metadata(str(checkpoint_dir), sharded_strategy) if validate_access_integrity or StrictHandling.requires_global_app_metadata(strict): local_metadata, global_metadata = determine_global_metadata(sharded_state_dict) @@ -148,17 +141,6 @@ def load( ckpt_sharded_metadata, ) - # ShardedBase loading - if not sharded_strategy.can_handle_sharded_objects: - validate_sharded_objects_handling(sharded_strategy, common_strategy) - sharded_objects_state_dict, sharded_state_dict = extract_matching_values( - sharded_state_dict, lambda v: isinstance(v, ShardedObject) - ) - sharded_objects = common_strategy.load_sharded_objects( - sharded_objects_state_dict, checkpoint_dir - ) - merge(common_state_dict, sharded_objects) - loaded_state_dict = sharded_strategy.load(sharded_state_dict, checkpoint_dir) merge(common_state_dict, loaded_state_dict) @@ -189,12 +171,12 @@ def load_common_state_dict(checkpoint_dir: Union[str, Path]) -> StateDict: "load_common_state_dict will no longer be supported in a future release. " "Please pass it as a string instead.", ) - sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(checkpoint_dir) - return common_strategy.load_common(checkpoint_dir) + verify_checkpoint(str(checkpoint_dir)) + return load_common(checkpoint_dir) def load_tensors_metadata( - checkpoint_dir: str, sharded_strategy: Union[LoadShardedStrategy, None] = None + checkpoint_dir: str, sharded_strategy: TorchDistLoadShardedStrategy = None ) -> CkptShardedMetadata: """Load tensors metadata from the checkpoint. @@ -218,16 +200,14 @@ def load_tensors_metadata( CkptShardedMetadata: flat state dict without data describing ShardedTensors in the checkpoint """ - sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy( - checkpoint_dir, sharded_strategy - ) + verify_checkpoint(checkpoint_dir) + if sharded_strategy is None: + sharded_strategy = TorchDistLoadShardedStrategy() return sharded_strategy.load_tensors_metadata(Path(checkpoint_dir)) def load_sharded_metadata( - checkpoint_dir: str, - sharded_strategy: Union[LoadShardedStrategy, None] = None, - common_strategy: Union[LoadCommonStrategy, None] = None, + checkpoint_dir: str, sharded_strategy: TorchDistLoadShardedStrategy = None ) -> CkptShardedMetadata: """Load sharded metadata from the checkpoint. @@ -248,22 +228,15 @@ def load_sharded_metadata( sharded_strategy (LoadShardedStrategy, optional): sharded strategy to load metadata. Defaults to None - in this case a default load strategy for a given checkpoint type is used. - common_strategy (LoadCommonStrategy, optional): common strategy to load metadata. - Defaults to None - in this case a default load strategy for a given checkpoint type is - used. This strategy won't be used unless `sharded_strategy` can't handle ShardedObjects Returns: CkptShardedMetadata: flat state dict without data describing ShardedTensors and ShardedObjects in the checkpoint """ - sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy( - checkpoint_dir, sharded_strategy, common_strategy - ) + verify_checkpoint(checkpoint_dir) + if sharded_strategy is None: + sharded_strategy = TorchDistLoadShardedStrategy() sharded_metadata = sharded_strategy.load_sharded_metadata(checkpoint_dir) - if not sharded_strategy.can_handle_sharded_objects: - validate_sharded_objects_handling(sharded_strategy, common_strategy) - common_metadata = common_strategy.load_sharded_metadata(checkpoint_dir) - sharded_metadata = merge(sharded_metadata, common_metadata) return sharded_metadata @@ -307,15 +280,15 @@ def load_content_metadata( def remove_sharded_tensors(checkpoint_dir: str, key_prefix: str): """determine the appropriate sharding strategy and delegate removal to the sharded strategy""" - sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(checkpoint_dir) - sharded_strategy.remove_sharded_tensors(checkpoint_dir, key_prefix) + verify_checkpoint(checkpoint_dir) + TorchDistSaveShardedStrategy.remove_sharded_tensors(checkpoint_dir, key_prefix) def save( sharded_state_dict: ShardedStateDict, checkpoint_dir: str, - sharded_strategy: Union[SaveShardedStrategy, Tuple[str, int], None] = None, - common_strategy: Union[SaveCommonStrategy, Tuple[str, int], None] = None, + sharded_strategy: TorchDistSaveShardedStrategy = None, + common_strategy: None = None, validate_access_integrity: bool = True, async_sharded_save: bool = False, preprocess_common_before_consistancy_check: Optional[ @@ -373,6 +346,8 @@ def save( async request that should be scheduled by the caller of this function. None otherwise. """ + from .strategies.fully_parallel import FullyParallelSaveStrategyWrapper + if torch.distributed.get_rank() == 0: if MultiStorageClientFeature.is_enabled(): msc = MultiStorageClientFeature.import_package() @@ -386,20 +361,13 @@ def save( if torch.distributed.get_rank() == 0: logger.warning("Overwriting old incomplete / corrupted checkpoint...") - if common_strategy is not None: - raise NotImplementedError('The only supported common strategy is torch') + assert common_strategy is None if sharded_strategy is None: - sharded_strategy = get_default_save_sharded_strategy() - if not isinstance(sharded_strategy, SaveShardedStrategy): - assert isinstance(sharded_strategy, tuple), type(sharded_strategy) - sharded_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, *sharded_strategy) - - if common_strategy is None: - common_strategy = get_default_save_common_strategy() - if not isinstance(common_strategy, SaveCommonStrategy): - assert isinstance(common_strategy, tuple), type(common_strategy) - common_strategy = get_default_strategy(StrategyAction.SAVE_COMMON, *common_strategy) + sharded_strategy = TorchDistSaveShardedStrategy() + assert isinstance(sharded_strategy, TorchDistSaveShardedStrategy) or isinstance( + sharded_strategy, FullyParallelSaveStrategyWrapper + ), f"Unknown sharded strategy type: {type(sharded_strategy)}" if content_metadata is not None: sharded_state_dict[_CONTENT_METADATA_KEY] = content_metadata @@ -408,14 +376,7 @@ def save( sharded_state_dict, validate_access_integrity, preprocess_common_before_consistancy_check ) - common_strategy.save_common(state_dict, checkpoint_dir) - - if not sharded_strategy.can_handle_sharded_objects: - validate_sharded_objects_handling(sharded_strategy, common_strategy) - sharded_objects_state_dict, sharded_state_dict = extract_matching_values( - sharded_state_dict, lambda v: isinstance(v, ShardedObject) - ) - common_strategy.save_sharded_objects(sharded_objects_state_dict, checkpoint_dir) + save_common(state_dict, checkpoint_dir) def metadata_finalize_fn(): if torch.distributed.get_rank() == 0: @@ -441,18 +402,13 @@ def metadata_finalize_fn(): def get_default_save_sharded_strategy( backend: str = 'torch_dist', version: int = 1 -) -> SaveShardedStrategy: +) -> TorchDistSaveShardedStrategy: """Get default save sharded strategy.""" - return get_default_strategy(StrategyAction.SAVE_SHARDED, backend, version) - - -def get_default_save_common_strategy( - backend: str = 'torch', version: int = 1 -) -> SaveCommonStrategy: - """Get default save common strategy.""" - return get_default_strategy(StrategyAction.SAVE_COMMON, backend, version) + if backend != 'torch_dist' or version != 1: + raise ValueError(f"Unsupported backend: {backend} or version: {version}") + return TorchDistSaveShardedStrategy() -def get_default_load_sharded_strategy(checkpoint_dir: str) -> LoadShardedStrategy: +def get_default_load_sharded_strategy(checkpoint_dir: str) -> TorchDistLoadShardedStrategy: """Get default load sharded strategy.""" - return verify_checkpoint_and_load_strategy(checkpoint_dir)[0] + return TorchDistLoadShardedStrategy() diff --git a/megatron/core/dist_checkpointing/strategies/__init__.py b/megatron/core/dist_checkpointing/strategies/__init__.py index a786b8e84a6..d0a42ae92f8 100644 --- a/megatron/core/dist_checkpointing/strategies/__init__.py +++ b/megatron/core/dist_checkpointing/strategies/__init__.py @@ -1,7 +1 @@ # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - -""" Various loading and saving strategies """ -from megatron.core.dist_checkpointing.strategies.common import register_default_common_strategies - -# We load "common" strategies by default to be always available -register_default_common_strategies() diff --git a/megatron/core/dist_checkpointing/strategies/base.py b/megatron/core/dist_checkpointing/strategies/base.py index 53422b362f6..a7fe29bd618 100644 --- a/megatron/core/dist_checkpointing/strategies/base.py +++ b/megatron/core/dist_checkpointing/strategies/base.py @@ -8,16 +8,14 @@ from pathlib import Path from typing import Any, DefaultDict, Union -from ..mapping import CheckpointingException, ShardedStateDict, StateDict +from ..mapping import CheckpointingException, ShardedStateDict from .async_utils import AsyncCallsQueue, AsyncRequest class StrategyAction(Enum): - """Specifies save vs load and sharded vs common action.""" + """Specifies save vs load action.""" - LOAD_COMMON = 'load_common' LOAD_SHARDED = 'load_sharded' - SAVE_COMMON = 'save_common' SAVE_SHARDED = 'save_sharded' @@ -56,7 +54,7 @@ def register_default_strategy( """Adds a given strategy to the registry of default strategies. Args: - action (StrategyAction): specifies save/load and sharded/common + action (StrategyAction): specifies save/load and sharded backend (str): backend that the strategy becomes a default for version (int): version that the strategy becomes a default for strategy (SaveStrategyBase, LoadStrategyBase): strategy to register @@ -101,28 +99,6 @@ def __str__(self): return f'{self.__class__.__name__}({self.backend}, {self.version})' -class LoadCommonStrategy(LoadStrategyBase): - """Load strategy for common (non-sharded) objects""" - - @abstractmethod - def load_common(self, checkpoint_dir: Union[str, Path]): - """Load common part of the checkpoint.""" - raise NotImplementedError - - @abstractmethod - def load_sharded_objects( - self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Union[str, Path] - ): - """Load sharded objects from the checkpoint.""" - raise NotImplementedError - - def load_sharded_metadata(self, checkpoint_dir: Union[str, Path]) -> ShardedStateDict: - """Load just the metadata from the checkpoint.""" - if not self.can_handle_sharded_objects: - return {} - raise NotImplementedError - - class LoadShardedStrategy(LoadStrategyBase): """Load strategy for sharded tensors""" @@ -166,21 +142,6 @@ def remove_sharded_tensors(self, checkpoint_dir: Union[str, Path], key_prefix: s raise NotImplementedError -class SaveCommonStrategy(SaveStrategyBase): - """Save strategy for common (non-sharded) objects""" - - @abstractmethod - def save_common(self, common_state_dict: StateDict, checkpoint_dir: Union[str, Path]): - """Save common part of the state dict.""" - raise NotImplementedError - - def save_sharded_objects( - self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Union[str, Path] - ): - """Save sharded objects from the state dict.""" - raise NotImplementedError - - class SaveShardedStrategy(SaveStrategyBase): """Save strategy for sharded tensors""" diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py index 41c21d93d7d..7ba0b498267 100644 --- a/megatron/core/dist_checkpointing/strategies/common.py +++ b/megatron/core/dist_checkpointing/strategies/common.py @@ -3,191 +3,39 @@ """ Common strategies. """ import logging -import os from pathlib import Path -from typing import Union import torch -from megatron.core.dist_checkpointing.mapping import ShardedStateDict, StateDict -from megatron.core.dist_checkpointing.strategies.base import ( - SaveCommonStrategy, - StrategyAction, - register_default_strategy, -) -from megatron.core.msc_utils import MultiStorageClientFeature +from megatron.core.dist_checkpointing.mapping import StateDict -from ..dict_utils import dict_list_map_inplace, nested_values -from ..mapping import CheckpointingException, ShardedObject, is_main_replica -from ..strategies.base import LoadCommonStrategy +from ..mapping import CheckpointingException COMMON_STATE_FNAME = 'common.pt' logger = logging.getLogger(__name__) -def register_default_common_strategies(): - """Register default common strategies.""" - register_default_strategy(StrategyAction.LOAD_COMMON, 'torch', 1, TorchCommonLoadStrategy()) - register_default_strategy( - StrategyAction.SAVE_COMMON, 'torch', 1, TorchCommonSaveStrategy('torch', 1) - ) +def save_common(common_state_dict: StateDict, checkpoint_dir: Path): + """Save common part of the state dict.""" + if torch.distributed.get_rank() == 0: + torch.save(common_state_dict, checkpoint_dir / COMMON_STATE_FNAME) -class TorchCommonSaveStrategy(SaveCommonStrategy): - """Common save strategy leveraging native torch save/load.""" +def load_common(checkpoint_dir: Path): + """Load common (non-sharded) objects state dict from the checkpoint. - def save_common(self, common_state_dict: StateDict, checkpoint_dir: Union[str, Path]): - """Save common part of the state dict.""" - if torch.distributed.get_rank() == 0: - path = os.path.join(checkpoint_dir, COMMON_STATE_FNAME) - if MultiStorageClientFeature.is_enabled(): - msc = MultiStorageClientFeature.import_package() - msc.torch.save(common_state_dict, path) - else: - torch.save(common_state_dict, path) + Args: + checkpoint_dir (Path): checkpoint directory - def save_sharded_objects( - self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Union[str, Path] - ): - """Save sharded objects from the state dict.""" - for sh_obj in nested_values(sharded_objects_state_dict): - if is_main_replica(sh_obj.replica_id): - save_path = os.path.join(checkpoint_dir, f"{sh_obj.unique_key}.pt") - parent_dir = os.path.dirname(save_path) - if MultiStorageClientFeature.is_enabled(): - msc = MultiStorageClientFeature.import_package() - msc.os.makedirs(parent_dir, exist_ok=True) - msc.torch.save(sh_obj.data, save_path) - else: - os.makedirs(parent_dir, exist_ok=True) - torch.save(sh_obj.data, save_path) - - def can_handle_sharded_objects(self): - """This strategy can handle ShardedObjects.""" - return True - - -class TorchCommonLoadStrategy(LoadCommonStrategy): - """Common load strategy leveraging native torch save/load.""" - - def load_common(self, checkpoint_dir: Union[str, Path]): - """Load common (non-sharded) objects state dict from the checkpoint. - - Args: - checkpoint_dir (Union[str, Path]): checkpoint directory - - Returns: - StateDict: state dict with non-sharded objects from the checkpoint - """ - load_path = os.path.join(checkpoint_dir, COMMON_STATE_FNAME) - try: - if MultiStorageClientFeature.is_enabled(): - msc = MultiStorageClientFeature.import_package() - return msc.torch.load(load_path, map_location='cpu') - else: - return torch.load(load_path, map_location='cpu') - except FileNotFoundError as e: - err_msg = f'Common file {load_path} does not exist' - if MultiStorageClientFeature.is_enabled(): - msc = MultiStorageClientFeature.import_package() - ckpt_files = [f.name for f in msc.Path(checkpoint_dir).iterdir()] - else: - ckpt_files = [f.name for f in checkpoint_dir.iterdir()] - logger.debug(f'{err_msg}. Checkpoint directory content: {ckpt_files}') - raise CheckpointingException(err_msg) from e - - def load_sharded_objects( - self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Union[str, Path] - ): - """Replaces all ShardedObject from a given state dict with values loaded from the - checkpoint. - - Args: - sharded_objects_state_dict (ShardedStateDict): - sharded state dict defining what objects should be loaded. - checkpoint_dir (Union[str, Path]): checkpoint directory - - Returns: - None: sharded state dict is modified in place - """ - - def load_sharded_object(sh_obj: ShardedObject): - sh_obj.data = None - load_path = os.path.join(checkpoint_dir, f'{sh_obj.unique_key}.pt') - try: - if MultiStorageClientFeature.is_enabled(): - msc = MultiStorageClientFeature.import_package() - loaded_obj = msc.torch.load(load_path) - else: - loaded_obj = torch.load(load_path) - except FileNotFoundError as e: - # Backward compatible logic: previously the save format was incorrect - base, _ = os.path.splitext(sh_obj.unique_key) - old_load_path = os.path.join(checkpoint_dir, f"{base}.pt") - try: - if MultiStorageClientFeature.is_enabled(): - msc = MultiStorageClientFeature.import_package() - loaded_obj = msc.torch.load(old_load_path) - else: - loaded_obj = torch.load(old_load_path) - except FileNotFoundError: - err_msg = f'Object shard {load_path} not found' - obj_subdir = os.path.join(checkpoint_dir, sh_obj.key) - if os.path.exists(obj_subdir): - obj_files = os.listdir(obj_subdir) - logger.debug( - f'{err_msg}. Object {sh_obj.key} directory content: {obj_files}' - ) - else: - ckpt_files = os.listdir(checkpoint_dir) - logger.debug( - f'{err_msg}. Object {sh_obj.key} directory does not exist. Checkpoint' - f' directory content: {ckpt_files}' - ) - raise CheckpointingException(err_msg) from e - return loaded_obj - - return dict_list_map_inplace(load_sharded_object, sharded_objects_state_dict) - - def load_sharded_metadata(self, checkpoint_dir: Union[str, Path]) -> ShardedStateDict: - if MultiStorageClientFeature.is_enabled(): - msc = MultiStorageClientFeature.import_package() - checkpoint_dir = msc.Path(checkpoint_dir) - else: - checkpoint_dir = Path(checkpoint_dir) - - sharded_metadata = {} - for subdir in checkpoint_dir.iterdir(): - if not subdir.is_dir(): - continue - shard_files = list(subdir.glob('shard_*.pt')) - if not shard_files: - continue - sh_objs = [] - for shard_file in shard_files: - full_key = f'{subdir.name}/{shard_file.stem}' - sh_objs.append(ShardedObject.empty_from_unique_key(full_key)) - - # This is a backward-compatibility fix, where the last global shape is missing in the - # name - if sh_objs[0].global_shape[-1] < 0: - max_last_offset = max(map(lambda sh_obj: sh_obj.global_offset[-1], sh_objs)) - for sh_obj in sh_objs: - sh_obj.global_shape = (*sh_obj.global_shape[:-1], max_last_offset + 1) - - # Update the sharded state dict - for sh_obj in sh_objs: - sharded_metadata[sh_obj.unique_key] = sh_obj - return sharded_metadata - - @property - def can_handle_sharded_objects(self): - """This strategy can handle ShardedObjects.""" - return True - - def check_backend_compatibility(self, loaded_version): - pass - - def check_version_compatibility(self, loaded_version): - pass + Returns: + StateDict: state dict with non-sharded objects from the checkpoint + """ + load_path = Path(checkpoint_dir) / COMMON_STATE_FNAME + try: + return torch.load(load_path, map_location='cpu', weights_only=False) + except FileNotFoundError as e: + err_msg = f'Common file {load_path} does not exist' + ckpt_files = [f.name for f in checkpoint_dir.iterdir()] + logger.debug(f'{err_msg}. Checkpoint directory content: {ckpt_files}') + raise CheckpointingException(err_msg) from e diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py index a5b6c009ba4..4230d233278 100644 --- a/megatron/core/dist_checkpointing/strategies/torch.py +++ b/megatron/core/dist_checkpointing/strategies/torch.py @@ -37,7 +37,6 @@ from torch.distributed.checkpoint.metadata import Metadata from torch.distributed.checkpoint.planner_helpers import _create_write_items -from ...utils import get_torch_version, is_torch_min_version from ..core import CheckpointingException from ..dict_utils import nested_values from ..mapping import ( @@ -278,6 +277,8 @@ def _mcore_to_dcp_compatible_tensor(sh_tens: List[ShardedTensor]) -> TorchSharde - if `allow_shape_mismatch` is True, the data is initialized with zeros prior to loading (not all parts of the tensor will be read from the checkpoint) """ + from ...utils import is_torch_min_version + assert all(isinstance(sh_ten, ShardedTensor) for sh_ten in sh_tens), sh_tens for sh_ten in sh_tens: if sh_ten.data is None: @@ -427,6 +428,8 @@ def __init__( ) -> None: # `dedup_replicated_tensors` was deprecated in 2.3; this check avoids warnings # during saving. + from ...utils import get_torch_version + if get_torch_version() <= PkgVersion("2.2"): kwargs['dedup_replicated_tensors'] = dedup_replicated_tensors super().__init__(*args, **kwargs) @@ -596,8 +599,8 @@ class TorchDistSaveShardedStrategy(AsyncSaveShardedStrategy): def __init__( self, - backend: str, - version: int, + backend: str = "torch_dist", + version: int = 1, keep_only_main_replica: bool = True, thread_count: int = 2, cached_metadata: bool = False, @@ -863,6 +866,7 @@ def remove_sharded_tensors(self, checkpoint_dir: str, key_prefix: str): 4. resaves the new metadata and removes the old metadata 5. removes the relevant files """ + from ...utils import is_torch_min_version assert is_torch_min_version( "2.3.0" diff --git a/megatron/core/dist_checkpointing/validation.py b/megatron/core/dist_checkpointing/validation.py index 0e5d6a011a8..89ecba1a968 100644 --- a/megatron/core/dist_checkpointing/validation.py +++ b/megatron/core/dist_checkpointing/validation.py @@ -1,16 +1,19 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. import logging -import os from collections import Counter, defaultdict from enum import Enum +from pathlib import Path from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union import numpy as np import torch from megatron.core.dist_checkpointing import ShardedTensor -from megatron.core.dist_checkpointing.core import CheckpointingException, maybe_load_config +from megatron.core.dist_checkpointing.core import ( + CheckpointingException, + check_is_distributed_checkpoint, +) from megatron.core.dist_checkpointing.dict_utils import diff, extract_matching_values, nested_values from megatron.core.dist_checkpointing.mapping import ( CommonStateDict, @@ -19,15 +22,6 @@ ShardedStateDict, is_main_replica, ) -from megatron.core.dist_checkpointing.strategies.base import ( - LoadCommonStrategy, - LoadShardedStrategy, - SaveCommonStrategy, - SaveShardedStrategy, - StrategyAction, - get_default_strategy, -) -from megatron.core.msc_utils import MultiStorageClientFeature if TYPE_CHECKING: from megatron.core.dist_checkpointing.serialization import CkptShardedMetadata @@ -199,60 +193,17 @@ def validate_integrity_and_strict_load( return sharded_state_dict, missing_keys, unexpected_keys -def verify_checkpoint_and_load_strategy( - checkpoint_dir: str, - sharded_strategy: Union[LoadShardedStrategy, Tuple[str, int], None] = None, - common_strategy: Union[LoadCommonStrategy, Tuple[str, int], None] = None, -) -> Tuple[LoadShardedStrategy, LoadCommonStrategy]: - """Verifies if checkpoint metadata exists and matches given strategies. - - If no strategies are passed, they are determined based on the checkpoint metadata. +def verify_checkpoint(checkpoint_dir: str): + """Verifies if checkpoint exists. Args: checkpoint_dir (str): checkpoint directory - sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional): sharded load strategy to be verified - if compatible with the checkpoint content. If None, the default sharded load strategy - for the checkpoint backend will be returned. - common_strategy (LoadCommonStrategy, Tuple[str, int], optional): common load strategy to be verified - if compatible with the checkpoint content. If None, the default common load strategy - for the checkpoint backend will be returned. """ - isdir = True - if MultiStorageClientFeature.is_enabled(): - msc = MultiStorageClientFeature.import_package() - isdir = msc.os.path.isdir(str(checkpoint_dir), strict=False) - else: - isdir = os.path.isdir(checkpoint_dir) - if not isdir: - raise CheckpointingException(f"Checkpoint directory {checkpoint_dir} does not exist") - - saved_config = maybe_load_config(checkpoint_dir) - if saved_config is None: - raise CheckpointingException(f"{checkpoint_dir} is not a distributed checkpoint") - - if sharded_strategy is None: - sharded_strategy = get_default_strategy( - StrategyAction.LOAD_SHARDED, - saved_config.sharded_backend, - saved_config.sharded_backend_version, - ) - elif isinstance(sharded_strategy, tuple): - sharded_strategy = get_default_strategy(StrategyAction.LOAD_SHARDED, *sharded_strategy) - - if common_strategy is None: - common_strategy = get_default_strategy( - StrategyAction.LOAD_COMMON, - saved_config.common_backend, - saved_config.common_backend_version, - ) - elif isinstance(common_strategy, tuple): - sharded_strategy = get_default_strategy(StrategyAction.LOAD_COMMON, *common_strategy) + if not Path(checkpoint_dir).exists(): + raise CheckpointingException(f'Checkpoint directory {checkpoint_dir} does not exist') - sharded_strategy.check_backend_compatibility(saved_config.sharded_backend) - sharded_strategy.check_version_compatibility(saved_config.sharded_backend_version) - common_strategy.check_backend_compatibility(saved_config.common_backend) - common_strategy.check_version_compatibility(saved_config.common_backend_version) - return sharded_strategy, common_strategy + if not check_is_distributed_checkpoint(checkpoint_dir): + raise CheckpointingException(f'{checkpoint_dir} is not a distributed checkpoint') def adjust_non_strict_load( @@ -532,29 +483,3 @@ def determine_global_metadata( global_metadata = [None] * torch.distributed.get_world_size() torch.distributed.all_gather_object(global_metadata, local_metadata) return local_metadata, global_metadata # type: ignore[return-value] - - -def validate_sharded_objects_handling( - sharded_strategy: Union[SaveShardedStrategy, LoadShardedStrategy], - common_strategy: Union[SaveCommonStrategy, LoadCommonStrategy], -) -> None: - """Checks if either of the passed strategies can handle sharded objects. - - Args: - sharded_strategy (Union[SaveShardedStrategy, LoadShardedStrategy]): sharded strategy used for saving/loading - common_strategy (Union[SaveCommonStrategy, LoadCommonStrategy]): common strategy used for saving/loading - - Returns: - None - - Raises: - CheckpointingException: if both strategies can't handle ShardedObjects - """ - if ( - not sharded_strategy.can_handle_sharded_objects - and not common_strategy.can_handle_sharded_objects - ): - raise CheckpointingException( - f"Either sharded strategy or common strategy must implement ShardedObjects handling." - f" Both {sharded_strategy} and {common_strategy} specify can_handle_sharded_objects=False" - ) diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py index f964b8dd32e..8b116bc1140 100644 --- a/megatron/training/checkpointing.py +++ b/megatron/training/checkpointing.py @@ -23,7 +23,7 @@ from megatron.core import dist_checkpointing, mpu, tensor_parallel from megatron.core.dist_checkpointing.mapping import ShardedObject -from megatron.core.dist_checkpointing.serialization import get_default_load_sharded_strategy +from megatron.core.dist_checkpointing.strategies.torch import TorchDistLoadShardedStrategy from megatron.core.dist_checkpointing.strategies.fully_parallel import ( FullyParallelLoadStrategyWrapper, FullyParallelSaveStrategyWrapper, @@ -1119,7 +1119,7 @@ def _load_global_dist_base_checkpoint( ) checkpoint_name = get_checkpoint_name(load_dir, iteration, release, return_base_dir=True) - load_strategy = get_default_load_sharded_strategy(checkpoint_name) + load_strategy = TorchDistLoadShardedStrategy() # NOTE: `args.ckpt_fully_parallel_load` applies to both persistent and non-persistent checkpoints. if args.ckpt_fully_parallel_load: load_strategy = FullyParallelLoadStrategyWrapper( diff --git a/tests/unit_tests/dist_checkpointing/models/common.py b/tests/unit_tests/dist_checkpointing/models/common.py index 8cb1dc4df65..0e97b5f230a 100644 --- a/tests/unit_tests/dist_checkpointing/models/common.py +++ b/tests/unit_tests/dist_checkpointing/models/common.py @@ -6,14 +6,15 @@ from megatron.core import parallel_state from megatron.core.dist_checkpointing import load, load_plain_tensors, save from megatron.core.dist_checkpointing.dict_utils import diff -from megatron.core.dist_checkpointing.serialization import ( - get_default_load_sharded_strategy, - get_default_save_sharded_strategy, -) +from megatron.core.dist_checkpointing.serialization import get_default_save_sharded_strategy from megatron.core.dist_checkpointing.strategies.fully_parallel import ( FullyParallelLoadStrategyWrapper, FullyParallelSaveStrategyWrapper, ) +from megatron.core.dist_checkpointing.strategies.torch import ( + TorchDistLoadShardedStrategy, + TorchDistSaveShardedStrategy, +) from megatron.core.dist_checkpointing.validation import StrictHandling from tests.unit_tests.dist_checkpointing import TempNamedDir from tests.unit_tests.test_utilities import Utils @@ -81,7 +82,7 @@ def common_test_parallel_reconfiguration_e2e( pipeline_model_parallel_size=src_tp_pp[1], **src_model_init_kwargs, ) - save_strategy = get_default_save_sharded_strategy() + save_strategy = TorchDistSaveShardedStrategy() if use_fpsl: save_strategy = FullyParallelSaveStrategyWrapper( save_strategy, @@ -104,7 +105,7 @@ def common_test_parallel_reconfiguration_e2e( **dst_model_init_kwargs, ) if use_fpsl: - load_strategy = get_default_load_sharded_strategy(ckpt_dir_A) + load_strategy = TorchDistLoadShardedStrategy() load_strategy = FullyParallelLoadStrategyWrapper(load_strategy) else: load_strategy = None diff --git a/tests/unit_tests/dist_checkpointing/models/test_mamba.py b/tests/unit_tests/dist_checkpointing/models/test_mamba.py index 85fbe5dd045..eb493d43655 100644 --- a/tests/unit_tests/dist_checkpointing/models/test_mamba.py +++ b/tests/unit_tests/dist_checkpointing/models/test_mamba.py @@ -6,14 +6,12 @@ from megatron.core import parallel_state from megatron.core.dist_checkpointing import load, load_plain_tensors, save from megatron.core.dist_checkpointing.dict_utils import diff -from megatron.core.dist_checkpointing.serialization import ( - get_default_load_sharded_strategy, - get_default_save_sharded_strategy, -) +from megatron.core.dist_checkpointing.serialization import get_default_save_sharded_strategy from megatron.core.dist_checkpointing.strategies.fully_parallel import ( FullyParallelLoadStrategyWrapper, FullyParallelSaveStrategyWrapper, ) +from megatron.core.dist_checkpointing.strategies.torch import TorchDistLoadShardedStrategy from megatron.core.extensions.transformer_engine import ( TELayerNormColumnParallelLinear, TERowParallelLinear, @@ -149,7 +147,7 @@ def test_parallel_reconfiguration_e2e( sequence_parallel=(dest_exp > 1 and dest_pp > 1), ) if use_fpsl: - load_strategy = get_default_load_sharded_strategy(ckpt_dir_A) + load_strategy = TorchDistLoadShardedStrategy() load_strategy = FullyParallelLoadStrategyWrapper( load_strategy, parallel_state.get_data_parallel_group(with_context_parallel=True), diff --git a/tests/unit_tests/dist_checkpointing/models/test_moe_experts.py b/tests/unit_tests/dist_checkpointing/models/test_moe_experts.py index ca546d746af..903d45608b8 100644 --- a/tests/unit_tests/dist_checkpointing/models/test_moe_experts.py +++ b/tests/unit_tests/dist_checkpointing/models/test_moe_experts.py @@ -8,14 +8,12 @@ from megatron.core import parallel_state from megatron.core.dist_checkpointing import load, load_plain_tensors, save from megatron.core.dist_checkpointing.dict_utils import diff -from megatron.core.dist_checkpointing.serialization import ( - get_default_load_sharded_strategy, - get_default_save_sharded_strategy, -) +from megatron.core.dist_checkpointing.serialization import get_default_save_sharded_strategy from megatron.core.dist_checkpointing.strategies.fully_parallel import ( FullyParallelLoadStrategyWrapper, FullyParallelSaveStrategyWrapper, ) +from megatron.core.dist_checkpointing.strategies.torch import TorchDistLoadShardedStrategy from megatron.core.models.gpt.gpt_layer_specs import ( get_gpt_layer_local_spec, get_gpt_layer_with_transformer_engine_spec, @@ -204,7 +202,7 @@ def test_parallel_reconfiguration_e2e( ) model_B = initialize_expert_layer(1, use_glu, expert_type) if use_fpsl: - load_strategy = get_default_load_sharded_strategy(ckpt_dir_A) + load_strategy = TorchDistLoadShardedStrategy() load_strategy = FullyParallelLoadStrategyWrapper( load_strategy, parallel_state.get_data_parallel_group(with_context_parallel=True), diff --git a/tests/unit_tests/dist_checkpointing/test_fp8.py b/tests/unit_tests/dist_checkpointing/test_fp8.py index 4fb89a8265a..a8a4d467822 100644 --- a/tests/unit_tests/dist_checkpointing/test_fp8.py +++ b/tests/unit_tests/dist_checkpointing/test_fp8.py @@ -5,14 +5,12 @@ from transformer_engine.pytorch.float8_tensor import Float8Tensor from megatron.core.dist_checkpointing import ShardedTensor, load, save -from megatron.core.dist_checkpointing.serialization import ( - get_default_load_sharded_strategy, - get_default_save_sharded_strategy, -) +from megatron.core.dist_checkpointing.serialization import get_default_save_sharded_strategy from megatron.core.dist_checkpointing.strategies.fully_parallel import ( FullyParallelLoadStrategyWrapper, FullyParallelSaveStrategyWrapper, ) +from megatron.core.dist_checkpointing.strategies.torch import TorchDistLoadShardedStrategy from tests.unit_tests.dist_checkpointing import TempNamedDir from tests.unit_tests.test_utilities import Utils @@ -104,7 +102,7 @@ def get_state_dict(fill_val=1): Utils.initialize_model_parallel(*dest_tp_pp) if use_fpsl: - load_strategy = get_default_load_sharded_strategy(ckpt_dir) + load_strategy = TorchDistLoadShardedStrategy() load_strategy = FullyParallelLoadStrategyWrapper( load_strategy, None, False, load_exchange_algo ) diff --git a/tests/unit_tests/dist_checkpointing/test_fully_parallel.py b/tests/unit_tests/dist_checkpointing/test_fully_parallel.py index 494eaefb44c..69bebd9f08c 100644 --- a/tests/unit_tests/dist_checkpointing/test_fully_parallel.py +++ b/tests/unit_tests/dist_checkpointing/test_fully_parallel.py @@ -27,10 +27,7 @@ ShardedTensorFactory, is_main_replica, ) -from megatron.core.dist_checkpointing.serialization import ( - get_default_load_sharded_strategy, - get_default_save_sharded_strategy, -) +from megatron.core.dist_checkpointing.serialization import get_default_save_sharded_strategy from megatron.core.dist_checkpointing.strategies.base import ( LoadShardedStrategy, SaveShardedStrategy, diff --git a/tests/unit_tests/dist_checkpointing/test_serialization.py b/tests/unit_tests/dist_checkpointing/test_serialization.py index 0815633f9b5..7dff08f3b3d 100644 --- a/tests/unit_tests/dist_checkpointing/test_serialization.py +++ b/tests/unit_tests/dist_checkpointing/test_serialization.py @@ -831,16 +831,15 @@ def _get_base_state_dict(self): ), } - @pytest.mark.parametrize('save_format', ['torch_dist']) @pytest.mark.parametrize('validate_integrity', [True, False]) def test_unexpected_keys_handling_during_validation( - self, caplog, tmp_path_dist_ckpt, validate_integrity, save_format + self, caplog, tmp_path_dist_ckpt, validate_integrity ): sharded_state_dict = self._get_base_state_dict() with TempNamedDir( tmp_path_dist_ckpt / 'test_unexpected_keys_raises_error_during_validation' ) as ckpt_dir: - save_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, save_format, 1) + save_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, 'torch_dist', 1) save(sharded_state_dict, ckpt_dir, save_strategy) def load_with_flag(strict): @@ -865,9 +864,7 @@ def test_error(error_msg): assert 'Missing keys' not in error_msg # ASSUME_OK_UNEXPECTED results in an exception raised by the underlying strategy - with pytest.raises( - PyTCheckpointingException if save_format == 'torch_dist' else CheckpointingException - ) as exc_info: + with pytest.raises(PyTCheckpointingException) as exc_info: load_with_flag(StrictHandling.ASSUME_OK_UNEXPECTED) # Informative exceptions with `RAISE_*` options: with pytest.raises(CheckpointingException) as exc_info: @@ -905,16 +902,15 @@ def test_error(error_msg): loaded_state_dict = load_with_flag(StrictHandling.IGNORE_ALL) assert 'TenA' in loaded_state_dict - @pytest.mark.parametrize('save_format', ['torch_dist']) @pytest.mark.parametrize('validate_integrity', [True, False]) def test_missing_keys_raises_error_during_validation( - self, caplog, tmp_path_dist_ckpt, validate_integrity, save_format + self, caplog, tmp_path_dist_ckpt, validate_integrity ): sharded_state_dict = self._get_base_state_dict() with TempNamedDir( tmp_path_dist_ckpt / 'test_missing_keys_raises_error_during_validation' ) as ckpt_dir: - save_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, save_format, 1) + save_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, 'torch_dist', 1) save(sharded_state_dict, ckpt_dir, save_strategy) def load_with_flag(strict): @@ -975,12 +971,11 @@ def test_error(error_msg): assert unexpected_keys == set() assert missing_keys == {'TenA', 'ObjB'} - @pytest.mark.parametrize('save_format', ['torch_dist']) @pytest.mark.parametrize('validate_integrity', [True, False]) - def test_exact_load_handling(self, caplog, tmp_path_dist_ckpt, validate_integrity, save_format): + def test_exact_load_handling(self, caplog, tmp_path_dist_ckpt, validate_integrity): sharded_state_dict = self._get_base_state_dict() with TempNamedDir(tmp_path_dist_ckpt / 'test_exact_load_handling') as ckpt_dir: - save_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, save_format, 1) + save_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, 'torch_dist', 1) save(sharded_state_dict, ckpt_dir, save_strategy) def load_with_flag(strict): @@ -1015,12 +1010,11 @@ def load_with_flag(strict): assert missing_keys == set() assert unexpected_keys == set() - @pytest.mark.parametrize('save_format', ['torch_dist']) - def test_sharded_metadata(self, tmp_path_dist_ckpt, save_format): + def test_sharded_metadata(self, tmp_path_dist_ckpt): sharded_state_dict = self._get_base_state_dict() with TempNamedDir(tmp_path_dist_ckpt / 'test_exact_load_handling') as ckpt_dir: - save_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, save_format, 1) + save_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, 'torch_dist', 1) save(sharded_state_dict, ckpt_dir, save_strategy) torch.distributed.barrier() sharded_metadata = load_sharded_metadata(ckpt_dir)