Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 42 additions & 86 deletions megatron/core/dist_checkpointing/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from . import ShardedTensor
from .core import CheckpointingConfig, save_config
from .dict_utils import extract_matching_values, merge
from .dict_utils import merge
from .mapping import (
CheckpointingException,
CommonStateDict,
Expand All @@ -30,23 +30,16 @@
)
from .state_dict_utils import load_preprocess, save_preprocess
from .strategies.async_utils import AsyncRequest
from .strategies.base import (
AsyncSaveShardedStrategy,
LoadCommonStrategy,
LoadShardedStrategy,
SaveCommonStrategy,
SaveShardedStrategy,
StrategyAction,
get_default_strategy,
)
from .strategies.base import AsyncSaveShardedStrategy
from .strategies.common import load_common, save_common
from .strategies.torch import TorchDistLoadShardedStrategy, TorchDistSaveShardedStrategy
from .utils import extract_sharded_base, force_all_tensors_to_non_fp8
from .validation import (
StrictHandling,
determine_global_metadata,
parse_strict_flag,
validate_integrity_and_strict_load,
validate_sharded_objects_handling,
verify_checkpoint_and_load_strategy,
verify_checkpoint,
)

logger = logging.getLogger(__name__)
Expand All @@ -61,8 +54,8 @@
def load(
sharded_state_dict: ShardedStateDict,
checkpoint_dir: str,
sharded_strategy: Union[LoadShardedStrategy, Tuple[str, int], None] = None,
common_strategy: Union[LoadCommonStrategy, Tuple[str, int], None] = None,
sharded_strategy: TorchDistLoadShardedStrategy = None,
common_strategy: None = None,
validate_access_integrity: bool = True,
strict: Union[str, StrictHandling] = StrictHandling.ASSUME_OK_UNEXPECTED,
) -> Union[StateDict, Tuple[StateDict, Set[str], Set[str]]]:
Expand Down Expand Up @@ -103,9 +96,11 @@ def load(
StateDict or Tuple[StateDict, Set[str], Set[str]]: in most cases only
the loaded state dict is returned. If `strict` flag was set to
"""
sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(
checkpoint_dir, sharded_strategy, common_strategy
)
assert common_strategy is None

verify_checkpoint(checkpoint_dir)
if sharded_strategy is None:
sharded_strategy = TorchDistLoadShardedStrategy()

# Dequantize all FP8 tensors in the state dict into their corresponding high-precision tensors.
# Retaining FP8 tensors in the state dict can cause issues in the following two cases:
Expand All @@ -118,7 +113,7 @@ def load(
# amax_history buffer of Transformer Engine, which is undesirable.
force_all_tensors_to_non_fp8(sharded_state_dict)

common_state_dict = common_strategy.load_common(checkpoint_dir)
common_state_dict = load_common(checkpoint_dir)

sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess(
sharded_state_dict
Expand All @@ -133,9 +128,7 @@ def load(
local_metadata, global_metadata = None, None
strict = parse_strict_flag(strict)
if StrictHandling.requires_explicit_ckpt_mismatch_check(strict):
ckpt_sharded_metadata = load_sharded_metadata(
checkpoint_dir, sharded_strategy, common_strategy # type: ignore[arg-type]
)
ckpt_sharded_metadata = load_sharded_metadata(str(checkpoint_dir), sharded_strategy)
if validate_access_integrity or StrictHandling.requires_global_app_metadata(strict):
local_metadata, global_metadata = determine_global_metadata(sharded_state_dict)

Expand All @@ -148,17 +141,6 @@ def load(
ckpt_sharded_metadata,
)

# ShardedBase loading
if not sharded_strategy.can_handle_sharded_objects:
validate_sharded_objects_handling(sharded_strategy, common_strategy)
sharded_objects_state_dict, sharded_state_dict = extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, ShardedObject)
)
sharded_objects = common_strategy.load_sharded_objects(
sharded_objects_state_dict, checkpoint_dir
)
merge(common_state_dict, sharded_objects)

loaded_state_dict = sharded_strategy.load(sharded_state_dict, checkpoint_dir)

merge(common_state_dict, loaded_state_dict)
Expand Down Expand Up @@ -189,12 +171,12 @@ def load_common_state_dict(checkpoint_dir: Union[str, Path]) -> StateDict:
"load_common_state_dict will no longer be supported in a future release. "
"Please pass it as a string instead.",
)
sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(checkpoint_dir)
return common_strategy.load_common(checkpoint_dir)
verify_checkpoint(str(checkpoint_dir))
return load_common(checkpoint_dir)


def load_tensors_metadata(
checkpoint_dir: str, sharded_strategy: Union[LoadShardedStrategy, None] = None
checkpoint_dir: str, sharded_strategy: TorchDistLoadShardedStrategy = None
) -> CkptShardedMetadata:
"""Load tensors metadata from the checkpoint.

Expand All @@ -218,16 +200,14 @@ def load_tensors_metadata(
CkptShardedMetadata: flat state dict without data describing ShardedTensors
in the checkpoint
"""
sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(
checkpoint_dir, sharded_strategy
)
verify_checkpoint(checkpoint_dir)
if sharded_strategy is None:
sharded_strategy = TorchDistLoadShardedStrategy()
return sharded_strategy.load_tensors_metadata(Path(checkpoint_dir))


def load_sharded_metadata(
checkpoint_dir: str,
sharded_strategy: Union[LoadShardedStrategy, None] = None,
common_strategy: Union[LoadCommonStrategy, None] = None,
checkpoint_dir: str, sharded_strategy: TorchDistLoadShardedStrategy = None
) -> CkptShardedMetadata:
"""Load sharded metadata from the checkpoint.

Expand All @@ -248,22 +228,15 @@ def load_sharded_metadata(
sharded_strategy (LoadShardedStrategy, optional): sharded strategy to load metadata.
Defaults to None - in this case a default load strategy for a given checkpoint type
is used.
common_strategy (LoadCommonStrategy, optional): common strategy to load metadata.
Defaults to None - in this case a default load strategy for a given checkpoint type is
used. This strategy won't be used unless `sharded_strategy` can't handle ShardedObjects

Returns:
CkptShardedMetadata: flat state dict without data describing ShardedTensors
and ShardedObjects in the checkpoint
"""
sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(
checkpoint_dir, sharded_strategy, common_strategy
)
verify_checkpoint(checkpoint_dir)
if sharded_strategy is None:
sharded_strategy = TorchDistLoadShardedStrategy()
sharded_metadata = sharded_strategy.load_sharded_metadata(checkpoint_dir)
if not sharded_strategy.can_handle_sharded_objects:
validate_sharded_objects_handling(sharded_strategy, common_strategy)
common_metadata = common_strategy.load_sharded_metadata(checkpoint_dir)
sharded_metadata = merge(sharded_metadata, common_metadata)
return sharded_metadata


Expand Down Expand Up @@ -307,15 +280,15 @@ def load_content_metadata(

def remove_sharded_tensors(checkpoint_dir: str, key_prefix: str):
"""determine the appropriate sharding strategy and delegate removal to the sharded strategy"""
sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(checkpoint_dir)
sharded_strategy.remove_sharded_tensors(checkpoint_dir, key_prefix)
verify_checkpoint(checkpoint_dir)
TorchDistSaveShardedStrategy.remove_sharded_tensors(checkpoint_dir, key_prefix)


def save(
sharded_state_dict: ShardedStateDict,
checkpoint_dir: str,
sharded_strategy: Union[SaveShardedStrategy, Tuple[str, int], None] = None,
common_strategy: Union[SaveCommonStrategy, Tuple[str, int], None] = None,
sharded_strategy: TorchDistSaveShardedStrategy = None,
common_strategy: None = None,
validate_access_integrity: bool = True,
async_sharded_save: bool = False,
preprocess_common_before_consistancy_check: Optional[
Expand Down Expand Up @@ -373,6 +346,8 @@ def save(
async request that should be scheduled by the caller of this function.
None otherwise.
"""
from .strategies.fully_parallel import FullyParallelSaveStrategyWrapper

if torch.distributed.get_rank() == 0:
if MultiStorageClientFeature.is_enabled():
msc = MultiStorageClientFeature.import_package()
Expand All @@ -386,20 +361,13 @@ def save(
if torch.distributed.get_rank() == 0:
logger.warning("Overwriting old incomplete / corrupted checkpoint...")

if common_strategy is not None:
raise NotImplementedError('The only supported common strategy is torch')
assert common_strategy is None

if sharded_strategy is None:
sharded_strategy = get_default_save_sharded_strategy()
if not isinstance(sharded_strategy, SaveShardedStrategy):
assert isinstance(sharded_strategy, tuple), type(sharded_strategy)
sharded_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, *sharded_strategy)

if common_strategy is None:
common_strategy = get_default_save_common_strategy()
if not isinstance(common_strategy, SaveCommonStrategy):
assert isinstance(common_strategy, tuple), type(common_strategy)
common_strategy = get_default_strategy(StrategyAction.SAVE_COMMON, *common_strategy)
sharded_strategy = TorchDistSaveShardedStrategy()
assert isinstance(sharded_strategy, TorchDistSaveShardedStrategy) or isinstance(
sharded_strategy, FullyParallelSaveStrategyWrapper
), f"Unknown sharded strategy type: {type(sharded_strategy)}"

if content_metadata is not None:
sharded_state_dict[_CONTENT_METADATA_KEY] = content_metadata
Expand All @@ -408,14 +376,7 @@ def save(
sharded_state_dict, validate_access_integrity, preprocess_common_before_consistancy_check
)

common_strategy.save_common(state_dict, checkpoint_dir)

if not sharded_strategy.can_handle_sharded_objects:
validate_sharded_objects_handling(sharded_strategy, common_strategy)
sharded_objects_state_dict, sharded_state_dict = extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, ShardedObject)
)
common_strategy.save_sharded_objects(sharded_objects_state_dict, checkpoint_dir)
save_common(state_dict, checkpoint_dir)

def metadata_finalize_fn():
if torch.distributed.get_rank() == 0:
Expand All @@ -441,18 +402,13 @@ def metadata_finalize_fn():

def get_default_save_sharded_strategy(
backend: str = 'torch_dist', version: int = 1
) -> SaveShardedStrategy:
) -> TorchDistSaveShardedStrategy:
"""Get default save sharded strategy."""
return get_default_strategy(StrategyAction.SAVE_SHARDED, backend, version)


def get_default_save_common_strategy(
backend: str = 'torch', version: int = 1
) -> SaveCommonStrategy:
"""Get default save common strategy."""
return get_default_strategy(StrategyAction.SAVE_COMMON, backend, version)
if backend != 'torch_dist' or version != 1:
raise ValueError(f"Unsupported backend: {backend} or version: {version}")
return TorchDistSaveShardedStrategy()


def get_default_load_sharded_strategy(checkpoint_dir: str) -> LoadShardedStrategy:
def get_default_load_sharded_strategy(checkpoint_dir: str) -> TorchDistLoadShardedStrategy:
"""Get default load sharded strategy."""
return verify_checkpoint_and_load_strategy(checkpoint_dir)[0]
return TorchDistLoadShardedStrategy()
6 changes: 0 additions & 6 deletions megatron/core/dist_checkpointing/strategies/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

""" Various loading and saving strategies """
from megatron.core.dist_checkpointing.strategies.common import register_default_common_strategies

# We load "common" strategies by default to be always available
register_default_common_strategies()
45 changes: 3 additions & 42 deletions megatron/core/dist_checkpointing/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,14 @@
from pathlib import Path
from typing import Any, DefaultDict, Union

from ..mapping import CheckpointingException, ShardedStateDict, StateDict
from ..mapping import CheckpointingException, ShardedStateDict
from .async_utils import AsyncCallsQueue, AsyncRequest


class StrategyAction(Enum):
"""Specifies save vs load and sharded vs common action."""
"""Specifies save vs load action."""

LOAD_COMMON = 'load_common'
LOAD_SHARDED = 'load_sharded'
SAVE_COMMON = 'save_common'
SAVE_SHARDED = 'save_sharded'


Expand Down Expand Up @@ -56,7 +54,7 @@ def register_default_strategy(
"""Adds a given strategy to the registry of default strategies.

Args:
action (StrategyAction): specifies save/load and sharded/common
action (StrategyAction): specifies save/load and sharded
backend (str): backend that the strategy becomes a default for
version (int): version that the strategy becomes a default for
strategy (SaveStrategyBase, LoadStrategyBase): strategy to register
Expand Down Expand Up @@ -101,28 +99,6 @@ def __str__(self):
return f'{self.__class__.__name__}({self.backend}, {self.version})'


class LoadCommonStrategy(LoadStrategyBase):
"""Load strategy for common (non-sharded) objects"""

@abstractmethod
def load_common(self, checkpoint_dir: Union[str, Path]):
"""Load common part of the checkpoint."""
raise NotImplementedError

@abstractmethod
def load_sharded_objects(
self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Union[str, Path]
):
"""Load sharded objects from the checkpoint."""
raise NotImplementedError

def load_sharded_metadata(self, checkpoint_dir: Union[str, Path]) -> ShardedStateDict:
"""Load just the metadata from the checkpoint."""
if not self.can_handle_sharded_objects:
return {}
raise NotImplementedError


class LoadShardedStrategy(LoadStrategyBase):
"""Load strategy for sharded tensors"""

Expand Down Expand Up @@ -166,21 +142,6 @@ def remove_sharded_tensors(self, checkpoint_dir: Union[str, Path], key_prefix: s
raise NotImplementedError


class SaveCommonStrategy(SaveStrategyBase):
"""Save strategy for common (non-sharded) objects"""

@abstractmethod
def save_common(self, common_state_dict: StateDict, checkpoint_dir: Union[str, Path]):
"""Save common part of the state dict."""
raise NotImplementedError

def save_sharded_objects(
self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Union[str, Path]
):
"""Save sharded objects from the state dict."""
raise NotImplementedError


class SaveShardedStrategy(SaveStrategyBase):
"""Save strategy for sharded tensors"""

Expand Down
Loading
Loading