From c4734e976ad500623f48885d6e051360e8c3a063 Mon Sep 17 00:00:00 2001
From: wentiange
Date: Wed, 10 Dec 2025 14:46:44 +0000
Subject: [PATCH] [FEATURE] Optimize dcp.save time cost on torch 2.7.1 and
 enable incremental checkpoint saving
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 xtuner/v1/engine/train_engine.py         |   9 +-
 xtuner/v1/engine/xtuner_cache_planner.py | 100 +++++++++++++++++
 xtuner/v1/engine/xtuner_storage.py       | 133 +++++++++++++++++++++++
 xtuner/v1/patch/__init__.py              |   4 +-
 xtuner/v1/patch/torch_dcp_planner.py     |  92 ++++++++++++++++
 xtuner/v1/train/trainer.py               |   8 +-
 6 files changed, 341 insertions(+), 5 deletions(-)
 create mode 100644 xtuner/v1/engine/xtuner_cache_planner.py
 create mode 100644 xtuner/v1/engine/xtuner_storage.py

diff --git a/xtuner/v1/engine/train_engine.py b/xtuner/v1/engine/train_engine.py
index 4b3e99873..de1bb2f71 100644
--- a/xtuner/v1/engine/train_engine.py
+++ b/xtuner/v1/engine/train_engine.py
@@ -32,6 +32,9 @@
 from xtuner.v1.utils import get_device, get_logger, get_torch_device_module, profile_time_and_memory
 from xtuner.v1.utils.grad_norm import cal_grad_norm
 
+from xtuner.v1.engine.xtuner_storage import XtunnerWriter
+from xtuner.v1.engine.xtuner_cache_planner import XtunerCacheSavePlanner
+
 
 logger = get_logger()
 DEVICE = get_device()
@@ -412,7 +415,8 @@ def save_dcp(
             model_state = get_model_state_dict(self.model, options=_options)
             dcp.save(
                 model_state,
-                checkpoint_id=model_dir,
+                storage_writer=XtunnerWriter(model_dir, enable_write_result_caching=True, cache_key_prefix="model"),
+                planner=XtunerCacheSavePlanner(enable_plan_caching=True, cache_key_prefix="model"),
             )
 
         with profile_time_and_memory(f"[DCP Checkpoint to {optimizer_dir}]"):
@@ -420,7 +424,8 @@
             shard_optimizer_state_dict = get_optimizer_state_dict(self.model, self.optimizer, options=_options)
             dcp.save(
                 shard_optimizer_state_dict,
-                checkpoint_id=optimizer_dir,
+                storage_writer=XtunnerWriter(optimizer_dir, enable_write_result_caching=True, cache_key_prefix="optimizer"),
+                planner=XtunerCacheSavePlanner(enable_plan_caching=True, cache_key_prefix="optimizer"),
             )
 
     def load_dcp(
diff --git a/xtuner/v1/engine/xtuner_cache_planner.py b/xtuner/v1/engine/xtuner_cache_planner.py
new file mode 100644
index 000000000..f8bd9ab59
--- /dev/null
+++ b/xtuner/v1/engine/xtuner_cache_planner.py
@@ -0,0 +1,100 @@
+from typing import Optional
+
+from torch.distributed.checkpoint import SavePlanner
+from torch.distributed.checkpoint import DefaultSavePlanner, SavePlan, Metadata
+from torch.distributed.checkpoint.planner_helpers import (
+    _compare_save_plans,
+    _merge_delta_local_plans,
+)
+
+
+# Copied from torch 2.8.0 planner_helpers.py
+def _contains_usable_plan(delta_plans: list[SavePlan]) -> bool:
+    """
+    Check if any delta plan is usable, indicating the plan has changed.
+
+    Args:
+        delta_plans (List[SavePlan]): A list of delta plans to check.
+
+    Returns:
+        True if any delta plan is usable, False otherwise.
+    """
+    return any(delta_plan and delta_plan.usable for delta_plan in delta_plans)
+
+
+class XtunerCacheSavePlanner(DefaultSavePlanner):
+    # Metadata for the global checkpoint plan as computed by the `create_global_plan` API.
+    # Cached on the coordinator rank.
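+    # The cache key used below is ``cache_key_prefix`` plus the class name, so the "model" and
+    # "optimizer" checkpoints keep separate entries in these class-level caches across save calls.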
+    _cached_metadata: dict[str, Metadata] = {}
+
+    def __init__(
+        self,
+        flatten_state_dict: bool = True,
+        flatten_sharded_tensors: bool = True,
+        dedup_replicated_tensors: Optional[bool] = None,
+        dedup_save_to_lowest_rank: bool = False,
+        enable_plan_caching: bool = False,
+        cache_key_prefix: str = "",
+    ) -> None:
+        super().__init__(flatten_state_dict, flatten_sharded_tensors, dedup_replicated_tensors, dedup_save_to_lowest_rank, enable_plan_caching)
+        self._cached_plans_key: str = cache_key_prefix + self.__class__.__name__
+
+    def _create_global_plan_with_caching(
+        self, all_plans: list[SavePlan]
+    ) -> tuple[list[SavePlan], list[SavePlan], Metadata]:
+
+        if hasattr(SavePlanner, "_cached_metadata"):
+            # adaptor for torch >= 2.8.0
+            return super()._create_global_plan_with_caching(all_plans)
+
+        # ONLY cache ``_cached_metadata`` in XtunerCacheSavePlanner
+        global_plan_delta: list[SavePlan] = []
+
+        if self._cached_plans_key not in SavePlanner._cached_all_plans:
+            # Case 1: If the plans are not cached, the cache will be hydrated with the
+            # all_plans, global_plans (deduped), and metadata.
+
+            # Cache the original all_plans
+            SavePlanner._cached_all_plans[self._cached_plans_key] = all_plans
+            global_plan, metadata = self._create_global_plan(all_plans)
+            # Cache the deduped and validated global_plan
+            SavePlanner._cached_global_plan[self._cached_plans_key] = global_plan
+            # Cache the metadata
+            XtunerCacheSavePlanner._cached_metadata[self._cached_plans_key] = metadata
+            # If plans are not cached, the global plan delta will be the same as the global plan.
+            return global_plan, global_plan, metadata
+
+        # Case 2: Plans are cached
+        if not _contains_usable_plan(all_plans):
+            # Case 2.1: Plans are cached and the local plans have NOT changed (no usable plans).
+            # The global plan delta will be empty plans to avoid the collective overhead.
+            # We can reuse the deduped global plan and metadata from the cache directly.
+            global_plan_delta = [SavePlan([], usable=False)] * len(all_plans)
+            global_plan = SavePlanner._cached_global_plan[self._cached_plans_key]
+            metadata = XtunerCacheSavePlanner._cached_metadata[self._cached_plans_key]
+        else:
+            # Case 2.2: Plans are cached but the local plans have changed.
+            # We will merge the changed local plans with the cached local plans.
+            # Updated plans will overwrite the cached plans. A new global plan and metadata will be created and cached.
+            # The global plan delta will be created by comparing the new global plan with the cached global plan.
+            # Only the global plan delta (updated ones) will be sent to the coordinator to avoid the collective overhead.
+            merged_plans = _merge_delta_local_plans(
+                SavePlanner._cached_all_plans[self._cached_plans_key], all_plans
+            )
+            # Cache the updated local plans
+            SavePlanner._cached_all_plans[self._cached_plans_key] = merged_plans
+            global_plan, metadata = self._create_global_plan(merged_plans)
+
+            if self._cached_plans_key in self._cached_global_plan:
+                for cached_plan, new_plan in zip(
+                    SavePlanner._cached_global_plan[self._cached_plans_key], global_plan
+                ):
+                    if _compare_save_plans(cached_plan, new_plan):
+                        global_plan_delta.append(SavePlan([], usable=False))
+                    else:
+                        global_plan_delta.append(new_plan)
+
+            # Cache the new global plan and the metadata
+            SavePlanner._cached_global_plan[self._cached_plans_key] = global_plan
+            XtunerCacheSavePlanner._cached_metadata[self._cached_plans_key] = metadata
+
+        return global_plan_delta, global_plan, metadata
\ No newline at end of file
diff --git a/xtuner/v1/engine/xtuner_storage.py b/xtuner/v1/engine/xtuner_storage.py
new file mode 100644
index 000000000..d01829c60
--- /dev/null
+++ b/xtuner/v1/engine/xtuner_storage.py
@@ -0,0 +1,133 @@
+import os
+from typing import Optional, Union
+from collections.abc import Sequence
+
+from torch.distributed.checkpoint import SavePlan, SavePlanner
+from torch.distributed.checkpoint import FileSystemWriter, Metadata
+from torch.distributed.checkpoint._extension import (
+    StreamTransformExtension,
+)
+from torch.distributed.checkpoint.storage import (
+    WriteResult,
+)
+from torch.futures import Future
+
+import torch.distributed as dist
+
+
+def _compare_write_results(write_results: list[WriteResult], other_write_results: list[WriteResult]) -> bool:
+    """
+    Compare two lists of WriteResult and return True if they are equal.
+
+    Args:
+        write_results (list[WriteResult]): First list of write results to compare.
+        other_write_results (list[WriteResult]): Second list of write results to compare.
+
+    Returns:
+        True if the two lists are equal, False otherwise.
+    """
+
+    # Both lists should have the same number of items
+    if len(write_results) != len(other_write_results):
+        return False
+
+    # Both lists should contain the same write results, in the same order.
+    for write_item, other_write_item in zip(write_results, other_write_results):
+        if write_item != other_write_item:
+            return False
+
+    return True
+
+
+def _contains_new_write_results(results: list[list[WriteResult]]) -> bool:
+    return any(delta_result for delta_result in results)
+
+
+class XtunnerWriter(FileSystemWriter):
+    # Write results for the current rank as computed by the `write_data` API.
+    # Cached on the local rank.
+    _cache_write_results: dict[str, list[WriteResult]] = {}
+
+    # Collection of all the write results from all the ranks.
+    # This is the ``results`` input to the `finish` API.
+    # Cached on the coordinator rank.
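+    # When a rank's new write results match its cached ones, ``write_data`` reports an empty
+    # list so only changed results reach the coordinator; ``finish`` then falls back to this
+    # cache whenever no rank reported anything new.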
+    _cached_all_write_results: dict[str, list[list[WriteResult]]] = {}
+
+    def __init__(
+        self,
+        path: Union[str, os.PathLike],
+        single_file_per_rank: bool = True,
+        sync_files: bool = True,
+        thread_count: int = 1,
+        per_thread_copy_ahead: int = 10_000_000,
+        cache_staged_state_dict: bool = False,
+        overwrite: bool = True,
+        _extensions: Optional[Sequence[StreamTransformExtension]] = None,
+        enable_write_result_caching: bool = False,
+        cache_key_prefix: str = "",
+    ) -> None:
+        super().__init__(
+            path,
+            single_file_per_rank=single_file_per_rank,
+            sync_files=sync_files,
+            thread_count=thread_count,
+            per_thread_copy_ahead=per_thread_copy_ahead,
+            cache_staged_state_dict=cache_staged_state_dict,
+            overwrite=overwrite,
+            _extensions=_extensions,
+        )
+        self._enable_write_result_caching = enable_write_result_caching
+        self._cached_write_results_key = cache_key_prefix + self.__class__.__name__
+
+    def write_data(
+        self,
+        plan: SavePlan,
+        planner: SavePlanner,
+    ) -> Future[list[WriteResult]]:
+        all_writes_fut = super().write_data(plan, planner)
+
+        if self._enable_write_result_caching:
+            all_writes_fut = self._get_write_future_with_caching(all_writes_fut)
+        return all_writes_fut
+
+    def _get_write_future_with_caching(self, all_writes_fut):
+        new_fut: Future[list[WriteResult]] = Future()
+        all_writes_fut.wait()
+
+        if self._cached_write_results_key not in XtunnerWriter._cache_write_results:
+            # Case 1: The write results are not cached yet. Cache them and pass them through unchanged.
+            XtunnerWriter._cache_write_results[self._cached_write_results_key] = all_writes_fut.value()
+            new_fut.set_result(all_writes_fut.value())
+        elif _compare_write_results(all_writes_fut.value(), XtunnerWriter._cache_write_results[self._cached_write_results_key]):
+            # Case 2: The write results are equal to the cached ones. Report an empty delta.
+            new_fut.set_result([])
+        else:
+            # Case 3: The write results differ from the cached ones. Refresh the cache and report the new results.
+            XtunnerWriter._cache_write_results[self._cached_write_results_key] = all_writes_fut.value()
+            new_fut.set_result(all_writes_fut.value())
+
+        return new_fut
+
+    def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
+        if self._enable_write_result_caching:
+            results = self._get_results_from_caching(results)
+
+        super().finish(metadata, results)
+
+    def _get_results_from_caching(self, results: list[list[WriteResult]]):
+        if self._cached_write_results_key not in XtunnerWriter._cached_all_write_results:
+            # Case 1: The gathered write results are not cached yet. Cache them.
+            XtunnerWriter._cached_all_write_results[self._cached_write_results_key] = results
+        elif not _contains_new_write_results(results):
+            # Case 2: No rank reported new write results. Reuse the cached results.
+            results = XtunnerWriter._cached_all_write_results[self._cached_write_results_key]
+        else:
+            # Case 3: Some ranks reported new write results. Overwrite the cache.
+            # TODO: merge the per-rank deltas into the cached results instead of overwriting.
+            XtunnerWriter._cached_all_write_results[self._cached_write_results_key] = results
+
+        return results
\ No newline at end of file
diff --git a/xtuner/v1/patch/__init__.py b/xtuner/v1/patch/__init__.py
index 8458df843..703d06ff1 100644
--- a/xtuner/v1/patch/__init__.py
+++ b/xtuner/v1/patch/__init__.py
@@ -1,5 +1,5 @@
 from . import torch_shape_env_simplify_pt28
-from .torch_dcp_planner import patch_default_save_plan
+from .torch_dcp_planner import patch_default_save_plan, patch_dcp_save_state_dict
 
-__all__ = ["patch_default_save_plan", "torch_shape_env_simplify_pt28"]
+__all__ = ["patch_default_save_plan", "torch_shape_env_simplify_pt28", "patch_dcp_save_state_dict"]
 
diff --git a/xtuner/v1/patch/torch_dcp_planner.py b/xtuner/v1/patch/torch_dcp_planner.py
index e03e63d05..f9d08818e 100644
--- a/xtuner/v1/patch/torch_dcp_planner.py
+++ b/xtuner/v1/patch/torch_dcp_planner.py
@@ -1,3 +1,17 @@
+import inspect
+import warnings
+from typing import Optional
+
+
+import torch
+import torch.distributed as dist
+from torch.distributed.checkpoint import state_dict_saver
+from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
+from torch.distributed.checkpoint.logger import _dcp_method_logger
+from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE
+from torch.distributed.checkpoint.planner import SavePlan, SavePlanner
+from torch.distributed.checkpoint.storage import StorageWriter
+from torch.distributed.checkpoint.utils import _DistWrapper
 import torch.distributed.checkpoint.default_planner as torch_default_runner
@@ -7,3 +21,81 @@ def fake_validate_global_plan(*args, **kwargs):
 
 def patch_default_save_plan():
     torch_default_runner._validate_global_plan = fake_validate_global_plan
+
+
+def _xtunner_save_state_dict(
+    state_dict: STATE_DICT_TYPE,
+    storage_writer: StorageWriter,
+    process_group: Optional[dist.ProcessGroup] = None,
+    coordinator_rank: int = 0,
+    no_dist: bool = False,
+    planner: Optional[SavePlanner] = None,
+) -> Metadata:
+    torch._C._log_api_usage_once("torch.distributed.checkpoint.save_state_dict")
+
+    distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
+    if planner is None:
+        planner = DefaultSavePlanner()
+    assert planner is not None
+
+    global_metadata = None
+
+    ckpt_kwargs = {}
+    if (ckpt_id := getattr(storage_writer, "checkpoint_id", None)) is not None:
+        ckpt_kwargs["checkpoint_id"] = ckpt_id
+        ckpt_kwargs["process_group"] = distW.group
+
+    @_dcp_method_logger(**ckpt_kwargs)
+    def local_step():
+        assert planner is not None
+        storage_meta = storage_writer.storage_meta()
+        if "storage_meta" not in inspect.signature(planner.set_up_planner).parameters:
+            warnings.warn(
+                "The function definition for SavePlanner.set_up_planner has been updated"
+                " to include the storage_meta argument. Please update your implementation"
+                " to include this parameter."
+            )
+            planner.set_up_planner(state_dict, distW.is_coordinator)  # type: ignore[call-arg, arg-type]
+        else:
+            planner.set_up_planner(
+                state_dict=state_dict,
+                storage_meta=storage_meta,
+                is_coordinator=distW.is_coordinator,
+            )
+        storage_writer.set_up_storage_writer(distW.is_coordinator)
+
+        local_plan = planner.create_local_plan()
+        local_plan = storage_writer.prepare_local_plan(local_plan)
+        return local_plan
+
+    @_dcp_method_logger(**ckpt_kwargs)
+    def global_step(all_local_plans):
+        nonlocal global_metadata
+
+        assert planner is not None
+        all_local_plans, global_metadata = planner.create_global_plan(all_local_plans)
+        all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
+        return all_local_plans
+
+    central_plan: SavePlan = distW.reduce_scatter("plan", local_step, global_step)
+
+    @_dcp_method_logger(**ckpt_kwargs)
+    def write_data():
+        assert planner is not None
+        final_local_plan = planner.finish_plan(central_plan)
+        all_writes = storage_writer.write_data(final_local_plan, planner)
+
+        all_writes.wait()
+        return all_writes.value()
+
+    @_dcp_method_logger(**ckpt_kwargs)
+    def finish_checkpoint(all_results):
+        assert global_metadata is not None
+        storage_writer.finish(metadata=global_metadata, results=all_results)
+        # return global_metadata
+        return Metadata(state_dict_metadata={})  # This is a patch to avoid broadcast overhead.
+
+    return distW.all_reduce("write", write_data, finish_checkpoint)
+
+
+def patch_dcp_save_state_dict():
+    state_dict_saver._save_state_dict = _xtunner_save_state_dict
\ No newline at end of file
diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py
index 5d3ba020b..005c8fa36 100644
--- a/xtuner/v1/train/trainer.py
+++ b/xtuner/v1/train/trainer.py
@@ -35,7 +35,7 @@
 from xtuner.v1.loss.ce_loss import CELossContextInputItem
 from xtuner.v1.model.base import ModelItem, TransformerConfig
 from xtuner.v1.model.utils import ModelForwardExtraLogInfo
-from xtuner.v1.patch import patch_default_save_plan
+from xtuner.v1.patch import patch_default_save_plan, patch_dcp_save_state_dict
 from xtuner.v1.profiler import profiling_memory, profiling_time
 from xtuner.v1.profiler.prober import ProberList
 from xtuner.v1.profiler.prober_utils import setup_prober_list
@@ -171,6 +171,7 @@ class TrainerConfig(BaseModel):
     checkpoint_interval: int | None = -1
     checkpoint_maxkeep: int | None = -1
     skip_checkpoint_validation: bool = False  # Suggest enabled if fsdp_size is larger than 512
+    patch_for_dcp_finish: bool = False
     snapshot_interval: int | None = None
     hf_interval: int | None = None
     hf_max_keep: int | None = None
@@ -288,6 +289,7 @@ def __init__(
         checkpoint_interval: int | None = -1,
         checkpoint_maxkeep: int | None = -1,
         skip_checkpoint_validation: bool = False,  # Suggest enabled if fsdp_size is larger than 512
+        patch_for_dcp_finish: bool = False,
         snapshot_interval: int | None = None,
         hf_interval: int | None = None,
         hf_max_keep: int | None = None,
@@ -320,6 +322,9 @@ def __init__(
         if skip_checkpoint_validation:
             patch_default_save_plan()
 
+        if patch_for_dcp_finish:
+            patch_dcp_save_state_dict()
+
         if isinstance(profile_step, int):
             profile_step = [profile_step]
         self._profile_step = profile_step
@@ -480,6 +485,7 @@ def from_config(cls, config: TrainerConfig) -> Self:
             checkpoint_interval=config.checkpoint_interval,
             checkpoint_maxkeep=config.checkpoint_maxkeep,
             skip_checkpoint_validation=config.skip_checkpoint_validation,
+            patch_for_dcp_finish=config.patch_for_dcp_finish,
             snapshot_interval=config.snapshot_interval,
             hf_interval=config.hf_interval,
             hf_max_keep=config.hf_max_keep,
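
Usage sketch (not part of the patch): a minimal, single-process illustration of how the new writer and planner plug into dcp.save, mirroring the train_engine.py change above. The output directory, the toy state dict, the save_twice helper, and the single-process gloo group are assumptions for illustration only; it presumes torch 2.7.1, the version this patch targets.

# Sketch only: shows how XtunnerWriter / XtunerCacheSavePlanner are wired into dcp.save.
# The path, toy state dict, and single-process gloo setup are illustrative assumptions.
import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp

from xtuner.v1.engine.xtuner_cache_planner import XtunerCacheSavePlanner
from xtuner.v1.engine.xtuner_storage import XtunnerWriter
from xtuner.v1.patch import patch_dcp_save_state_dict


def save_twice(model_dir: str) -> None:
    # Optional: swap in the patched _save_state_dict so finish() returns an empty
    # Metadata instead of broadcasting the full one (what patch_for_dcp_finish enables).
    patch_dcp_save_state_dict()

    state_dict = {"weight": torch.randn(4, 4)}  # stand-in for a real model state dict

    for _ in range(2):
        dcp.save(
            state_dict,
            storage_writer=XtunnerWriter(model_dir, enable_write_result_caching=True, cache_key_prefix="model"),
            # The second call reuses the plan / write-result caches keyed by the "model" prefix,
            # since the caches live on class attributes rather than on the instances.
            planner=XtunerCacheSavePlanner(enable_plan_caching=True, cache_key_prefix="model"),
        )


if __name__ == "__main__":
    # Single-process process group, since dcp.save is normally driven under torch.distributed.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    save_twice("/tmp/xtuner_dcp_demo")
    dist.destroy_process_group()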