9 changes: 7 additions & 2 deletions xtuner/v1/engine/train_engine.py
@@ -32,6 +32,9 @@
from xtuner.v1.utils import get_device, get_logger, get_torch_device_module, profile_time_and_memory
from xtuner.v1.utils.grad_norm import cal_grad_norm

from xtuner.v1.engine.xtuner_storage import XtunnerWriter
from xtuner.v1.engine.xtuner_cache_planner import XtunerCacheSavePlanner


logger = get_logger()
DEVICE = get_device()
@@ -412,15 +415,17 @@ def save_dcp(
model_state = get_model_state_dict(self.model, options=_options)
dcp.save(
model_state,
checkpoint_id=model_dir,
storage_writer=XtunnerWriter(model_dir, enable_write_result_caching=True, cache_key_prefix = "model"),
planner=XtunerCacheSavePlanner(enable_plan_caching=True, cache_key_prefix="model"),
)

with profile_time_and_memory(f"[DCP Checkpoint to {optimizer_dir}]"):
if optimizer_dir is not None:
shard_optimizer_state_dict = get_optimizer_state_dict(self.model, self.optimizer, options=_options)
dcp.save(
shard_optimizer_state_dict,
checkpoint_id=optimizer_dir,
storage_writer=XtunnerWriter(optimizer_dir, enable_write_result_caching=True, cache_key_prefix = "optimizer"),
planner=XtunerCacheSavePlanner(enable_plan_caching=True, cache_key_prefix="optimizer"),
)

def load_dcp(
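For review context, below is a minimal sketch of how the cache-aware writer and planner plug into a plain `dcp.save` call, mirroring the `save_dcp` change above. The toy state dict and the `work_dir/ckpt-0/model` directory are illustrative assumptions, not part of this PR; the engine passes the real sharded model state and `model_dir`.

import torch
import torch.distributed.checkpoint as dcp

from xtuner.v1.engine.xtuner_cache_planner import XtunerCacheSavePlanner
from xtuner.v1.engine.xtuner_storage import XtunnerWriter

# Toy state dict standing in for the sharded model state (illustrative only).
state_dict = {"weight": torch.zeros(4, 4)}

dcp.save(
    state_dict,
    storage_writer=XtunnerWriter("work_dir/ckpt-0/model", enable_write_result_caching=True, cache_key_prefix="model"),
    planner=XtunerCacheSavePlanner(enable_plan_caching=True, cache_key_prefix="model"),
)

On the next save (e.g. to a later checkpoint directory), both objects look up their class-level caches under the "model" prefix, so unchanged save plans and write results are not re-shipped through the collectives.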
100 changes: 100 additions & 0 deletions xtuner/v1/engine/xtuner_cache_planner.py
@@ -0,0 +1,100 @@
from typing import Optional

from torch.distributed.checkpoint import SavePlanner
from torch.distributed.checkpoint import DefaultSavePlanner, SavePlan, Metadata
from torch.distributed.checkpoint.planner_helpers import (
_compare_save_plans,
_merge_delta_local_plans,
)


# Copied from torch 2.8.0 planner_helpers.py.
def _contains_usable_plan(delta_plans: list[SavePlan]) -> bool:
"""
Check if any delta plan is usable, indicating the plan has changed.

Args:
delta_plans (List[SavePlan]): A list of delta plans to check.
Returns:
True if any delta plan is usable, False otherwise.
"""
return any(delta_plan and delta_plan.usable for delta_plan in delta_plans)


class XtunerCacheSavePlanner(DefaultSavePlanner):
# Metadata for the global checkpoint plan as computed by `create_global_plan` API.
# Cached on the coordinator rank.
_cached_metadata: dict[str, Metadata] = {}

def __init__(
self,
flatten_state_dict: bool = True,
flatten_sharded_tensors: bool = True,
dedup_replicated_tensors: Optional[bool] = None,
dedup_save_to_lowest_rank: bool = False,
enable_plan_caching: bool = False,
cache_key_prefix: str = ""
) -> None:
super().__init__(flatten_state_dict, flatten_sharded_tensors, dedup_replicated_tensors, dedup_save_to_lowest_rank, enable_plan_caching)
self._cached_plans_key: str = cache_key_prefix + self.__class__.__name__

def _create_global_plan_with_caching(
self, all_plans: list[SavePlan]
) -> tuple[list[SavePlan], list[SavePlan], Metadata]:

if hasattr(SavePlanner, "_cached_metadata"):
            # Adapter for torch >= 2.8.0: SavePlanner already caches the metadata,
            # so defer to the parent implementation.
return super()._create_global_plan_with_caching(all_plans)

        # On older torch versions, only the metadata cache lives on XtunerCacheSavePlanner;
        # the plan caches below still live on SavePlanner.
global_plan_delta: list[SavePlan] = []

if self._cached_plans_key not in SavePlanner._cached_all_plans:
# Case 1: If the plans are not cached, the cache will be hydrated with the
# all_plans, global_plans (Deduped), and metadata.

# Cache the original all_plans
SavePlanner._cached_all_plans[self._cached_plans_key] = all_plans
global_plan, metadata = self._create_global_plan(all_plans)
# Cache the deduped and validated global_plan
SavePlanner._cached_global_plan[self._cached_plans_key] = global_plan
# Cache the metadata
XtunerCacheSavePlanner._cached_metadata[self._cached_plans_key] = metadata
# If plans are not cached, global_plan delta will be the same as global plan.
return global_plan, global_plan, metadata

# Case 2: Plans are cached
if not _contains_usable_plan(all_plans):
# Case 2.1: Plans are cached and the local plans have NOT changed (No usable plans).
# Global plan delta will be empty plans to avoid the collective overhead.
# We can reuse the deduped global plan and metadata from the cache directly.
global_plan_delta = [SavePlan([], usable=False)] * len(all_plans)
global_plan = SavePlanner._cached_global_plan[self._cached_plans_key]
metadata = XtunerCacheSavePlanner._cached_metadata[self._cached_plans_key]
else:
# Case 2.2: Plans are cached but the local plans have changed.
# We will merge the changed local plans with the cached local plans.
# Updated plans will overwrite the cached plans. New global plan and metadata will be created and cached.
# Global plan delta will be created by comparing the new global plan with the cached global plan.
# Only the global plan delta (updated ones) will be sent to the coordinator to avoid the collective overhead.
merged_plans = _merge_delta_local_plans(
SavePlanner._cached_all_plans[self._cached_plans_key], all_plans
)
# Cache the updated local plans
SavePlanner._cached_all_plans[self._cached_plans_key] = merged_plans
global_plan, metadata = self._create_global_plan(merged_plans)

if self._cached_plans_key in self._cached_global_plan:
for cached_plan, new_plan in zip(
SavePlanner._cached_global_plan[self._cached_plans_key], global_plan
):
if _compare_save_plans(cached_plan, new_plan):
global_plan_delta.append(SavePlan([], usable=False))
else:
global_plan_delta.append(new_plan)

# Cache the new global plan and the metadata
SavePlanner._cached_global_plan[self._cached_plans_key] = global_plan
XtunerCacheSavePlanner._cached_metadata[self._cached_plans_key] = metadata

return global_plan_delta, global_plan, metadata
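A small sketch of the cache keying this planner relies on; it only exercises the key construction (no distributed save) and assumes a torch build that provides the `planner_helpers` functions imported above and the `enable_plan_caching` constructor argument.

from xtuner.v1.engine.xtuner_cache_planner import XtunerCacheSavePlanner

# The class-level caches (_cached_all_plans, _cached_global_plan, _cached_metadata) are keyed by
# cache_key_prefix + class name, so the "model" and "optimizer" saves in train_engine.py keep
# separate cache entries even though they share the same planner class.
model_planner = XtunerCacheSavePlanner(enable_plan_caching=True, cache_key_prefix="model")
optim_planner = XtunerCacheSavePlanner(enable_plan_caching=True, cache_key_prefix="optimizer")

assert model_planner._cached_plans_key == "modelXtunerCacheSavePlanner"
assert model_planner._cached_plans_key != optim_planner._cached_plans_key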
133 changes: 133 additions & 0 deletions xtuner/v1/engine/xtuner_storage.py
@@ -0,0 +1,133 @@
import os
from typing import Optional, Union
from collections.abc import Sequence

from torch.distributed.checkpoint import SavePlan, SavePlanner
from torch.distributed.checkpoint import FileSystemWriter, Metadata
from torch.distributed.checkpoint._extension import (
StreamTransformExtension,
)
from torch.distributed.checkpoint.storage import (
WriteResult,
)
from torch.futures import Future



def _compare_write_results(write_results: list[WriteResult], other_write_results: list[WriteResult]) -> bool:
    """
    Compare two lists of WriteResult and return True if they are equal.

    Args:
        write_results (list[WriteResult]): First list of write results to compare.
        other_write_results (list[WriteResult]): Second list of write results to compare.

    Returns:
        True if the two lists are equal, False otherwise.
    """

    # Both lists should contain the same number of write results.
    if len(write_results) != len(other_write_results):
        return False

    # Every write result should match its counterpart.
    for write_result, other_write_result in zip(write_results, other_write_results):
        if write_result != other_write_result:
            return False

    return True


def _contains_new_write_results(results: list[list[WriteResult]]) -> bool:
    # True if at least one rank reported a non-empty list of write results, i.e. something changed.
    return any(delta_result for delta_result in results)


class XtunnerWriter(FileSystemWriter):
# Save write results for the current rank as computed by `write_data` API
# Cached on the local rank.
_cache_write_results: dict[str, list[WriteResult]] = {}

# Collection of all the write results from all the ranks.
# This is the ``results`` input to the `finish` API.
# Cached on the coordinator rank.
_cached_all_write_results: dict[str, list[list[WriteResult]]] = {}

def __init__(
self,
path: Union[str, os.PathLike],
single_file_per_rank: bool = True,
sync_files: bool = True,
thread_count: int = 1,
per_thread_copy_ahead: int = 10_000_000,
cache_staged_state_dict: bool = False,
overwrite: bool = True,
_extensions: Optional[Sequence[StreamTransformExtension]] = None,
enable_write_result_caching: bool = False,
cache_key_prefix: str = "",
) -> None:
super().__init__(
path,
single_file_per_rank=single_file_per_rank,
sync_files=sync_files,
thread_count=thread_count,
per_thread_copy_ahead=per_thread_copy_ahead,
cache_staged_state_dict=cache_staged_state_dict,
overwrite=overwrite,
_extensions=_extensions,
)
self._enable_write_result_caching = enable_write_result_caching
self._cached_write_results_key = cache_key_prefix + self.__class__.__name__


def write_data(
self,
plan: SavePlan,
planner: SavePlanner,
) -> Future[list[WriteResult]]:
all_writes_fut = super().write_data(plan, planner)

if self._enable_write_result_caching:
all_writes_fut = self._get_write_future_with_caching(all_writes_fut)
return all_writes_fut

    def _get_write_future_with_caching(self, all_writes_fut: Future[list[WriteResult]]) -> Future[list[WriteResult]]:
        new_fut: Future[list[WriteResult]] = Future()
        all_writes_fut.wait()

        if self._cached_write_results_key not in XtunnerWriter._cache_write_results:
            # Case 1: Write results are not cached yet. Hydrate the cache and pass the results through.
            XtunnerWriter._cache_write_results[self._cached_write_results_key] = all_writes_fut.value()
            new_fut.set_result(all_writes_fut.value())
        elif _compare_write_results(all_writes_fut.value(), XtunnerWriter._cache_write_results[self._cached_write_results_key]):
            # Case 2: Write results are identical to the cached ones. Return an empty list so the
            # coordinator can reuse its cached results and the gather payload stays small.
            new_fut.set_result([])
        else:
            # Case 3: Write results have changed. Refresh the cache and pass the new results through.
            XtunnerWriter._cache_write_results[self._cached_write_results_key] = all_writes_fut.value()
            new_fut.set_result(all_writes_fut.value())

        return new_fut


def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
if self._enable_write_result_caching:
results = self._get_results_from_caching(results)

super().finish(metadata, results)


    def _get_results_from_caching(self, results: list[list[WriteResult]]) -> list[list[WriteResult]]:
        if self._cached_write_results_key not in XtunnerWriter._cached_all_write_results:
            # Case 1: Nothing cached on the coordinator yet. Hydrate the cache with the gathered results.
            XtunnerWriter._cached_all_write_results[self._cached_write_results_key] = results
        elif not _contains_new_write_results(results):
            # Case 2: Every rank reported an empty delta, so reuse the cached results as-is.
            results = XtunnerWriter._cached_all_write_results[self._cached_write_results_key]
        else:
            # Case 3: Some ranks reported new results. TODO: merge the deltas into the cached
            # results instead of overwriting the whole cache.
            XtunnerWriter._cached_all_write_results[self._cached_write_results_key] = results

        return results
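For reference, a self-contained sketch of the equality check that drives the three cases in `_get_write_future_with_caching`; the `WriteResult` and `MetadataIndex` constructor arguments below are illustrative and assume the torch 2.x dataclass definitions.

from torch.distributed.checkpoint.metadata import MetadataIndex
from torch.distributed.checkpoint.storage import WriteResult

from xtuner.v1.engine.xtuner_storage import _compare_write_results

idx = MetadataIndex(fqn="model.layer.weight")
first = [WriteResult(index=idx, size_in_bytes=1024, storage_data=None)]
second = [WriteResult(index=idx, size_in_bytes=1024, storage_data=None)]
changed = [WriteResult(index=idx, size_in_bytes=2048, storage_data=None)]

assert _compare_write_results(first, second)       # identical results -> Case 2, empty delta is returned
assert not _compare_write_results(first, changed)  # changed results -> Case 3, the cache is refreshed
assert not _compare_write_results(first, [])       # length mismatch -> also treated as changed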
4 changes: 2 additions & 2 deletions xtuner/v1/patch/__init__.py
@@ -1,5 +1,5 @@
from . import torch_shape_env_simplify_pt28
from .torch_dcp_planner import patch_default_save_plan
from .torch_dcp_planner import patch_default_save_plan, patch_dcp_save_state_dict


__all__ = ["patch_default_save_plan", "torch_shape_env_simplify_pt28"]
__all__ = ["patch_default_save_plan", "torch_shape_env_simplify_pt28","patch_dcp_save_state_dict"]
92 changes: 92 additions & 0 deletions xtuner/v1/patch/torch_dcp_planner.py
@@ -1,3 +1,17 @@
import inspect
import warnings
from typing import Optional


import torch
import torch.distributed as dist
from torch.distributed.checkpoint import state_dict_saver
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.logger import _dcp_method_logger
from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE
from torch.distributed.checkpoint.planner import SavePlan, SavePlanner
from torch.distributed.checkpoint.storage import StorageWriter
from torch.distributed.checkpoint.utils import _DistWrapper
import torch.distributed.checkpoint.default_planner as torch_default_runner


@@ -7,3 +21,81 @@ def fake_validate_global_plan(*args, **kwargs):

def patch_default_save_plan():
torch_default_runner._validate_global_plan = fake_validate_global_plan

def _xtunner_save_state_dict(
state_dict: STATE_DICT_TYPE,
storage_writer: StorageWriter,
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
no_dist: bool = False,
planner: Optional[SavePlanner] = None,
) -> Metadata:
torch._C._log_api_usage_once("torch.distributed.checkpoint.save_state_dict")

distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
if planner is None:
planner = DefaultSavePlanner()
assert planner is not None

global_metadata = None

ckpt_kwargs = {}
if (ckpt_id := getattr(storage_writer, "checkpoint_id", None)) is not None:
ckpt_kwargs["checkpoint_id"] = ckpt_id
ckpt_kwargs["process_group"] = distW.group

@_dcp_method_logger(**ckpt_kwargs)
def local_step():
assert planner is not None
storage_meta = storage_writer.storage_meta()
if "storage_meta" not in inspect.signature(planner.set_up_planner).parameters:
warnings.warn(
"The function definition for SavePlanner.set_up_planner has been updated"
" to include the storage_meta argument. Please update your implementation"
" to include this parameter."
)
planner.set_up_planner(state_dict, distW.is_coordinator) # type: ignore[call-arg, arg-type]
else:
planner.set_up_planner(
state_dict=state_dict,
storage_meta=storage_meta,
is_coordinator=distW.is_coordinator,
)
storage_writer.set_up_storage_writer(distW.is_coordinator)

local_plan = planner.create_local_plan()
local_plan = storage_writer.prepare_local_plan(local_plan)
return local_plan

@_dcp_method_logger(**ckpt_kwargs)
def global_step(all_local_plans):
nonlocal global_metadata

assert planner is not None
all_local_plans, global_metadata = planner.create_global_plan(all_local_plans)
all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
return all_local_plans

central_plan: SavePlan = distW.reduce_scatter("plan", local_step, global_step)

@_dcp_method_logger(**ckpt_kwargs)
def write_data():
assert planner is not None
final_local_plan = planner.finish_plan(central_plan)
all_writes = storage_writer.write_data(final_local_plan, planner)

all_writes.wait()
return all_writes.value()

@_dcp_method_logger(**ckpt_kwargs)
def finish_checkpoint(all_results):
assert global_metadata is not None
storage_writer.finish(metadata=global_metadata, results=all_results)
        # return global_metadata
        # Patch: return an empty Metadata instead of global_metadata so the full
        # metadata does not get broadcast back to every rank.
        return Metadata(state_dict_metadata={})

return distW.all_reduce("write", write_data, finish_checkpoint)


def patch_dcp_save_state_dict():
state_dict_saver._save_state_dict = _xtunner_save_state_dict
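A minimal usage sketch for the patches, assuming they are meant to be applied once per process before the first checkpoint save; the actual call site is not shown in this diff.

from xtuner.v1.patch import patch_default_save_plan, patch_dcp_save_state_dict

# Apply once, before the first dcp.save / engine.save_dcp call.
patch_default_save_plan()     # replaces torch's _validate_global_plan with a no-op
patch_dcp_save_state_dict()   # installs _xtunner_save_state_dict, whose finish_checkpoint step
                              # returns an empty Metadata to avoid broadcasting the full metadata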