From fd98ff54727cb83deac95eb5101e64e990bf7a95 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Mon, 9 Feb 2026 08:46:05 +0800 Subject: [PATCH 1/3] wip --- cookbook/sft/ep_fsdp_qwen3_moe.py | 6 +- .../model/transformers/moe/expert_parallel.py | 18 +++- .../transformers/strategy/native_fsdp.py | 89 ++++++++++++++++++- .../model/transformers/transformers.py | 14 +++ src/twinkle/utils/platform.py | 35 ++++++-- tests/DeviceMesh/test_device_mesh.py | 18 ++++ 6 files changed, 168 insertions(+), 12 deletions(-) diff --git a/cookbook/sft/ep_fsdp_qwen3_moe.py b/cookbook/sft/ep_fsdp_qwen3_moe.py index 68c6d15a..468bfd5d 100644 --- a/cookbook/sft/ep_fsdp_qwen3_moe.py +++ b/cookbook/sft/ep_fsdp_qwen3_moe.py @@ -20,11 +20,12 @@ PROCESSOR_ID = os.environ.get('QWEN3_PROCESSOR_ID', 'AlpacaProcessor') NUM_LAYERS = int(os.environ.get('QWEN3_NUM_LAYERS', '1')) -# 4 GPUs: dp=2, ep=2 +# 4 GPUs: ep=2, ep_fsdp=2 device_mesh = DeviceMesh.from_sizes( device_type=Platform.get_platform().device_prefix(), - dp_size=2, + dp_size=None, ep_size=2, + ep_fsdp_size=2, ) twinkle.initialize( @@ -80,6 +81,7 @@ def train(): "router_dtype": "fp32", "all_to_all": "torch", "keep_router_logits": False, + "ep_fsdp": True, } }, ) diff --git a/src/twinkle/model/transformers/moe/expert_parallel.py b/src/twinkle/model/transformers/moe/expert_parallel.py index 99108bb7..68466b72 100644 --- a/src/twinkle/model/transformers/moe/expert_parallel.py +++ b/src/twinkle/model/transformers/moe/expert_parallel.py @@ -21,6 +21,7 @@ class ExpertParallelConfig: keep_router_logits: bool = True pad_to_max: bool = False ignore_shared_experts: bool = False + ep_fsdp: bool = False def apply_expert_parallel(model: nn.Module, device_mesh: DeviceMesh, config: Optional[Dict[str, Any]] = None): @@ -32,6 +33,13 @@ def apply_expert_parallel(model: nn.Module, device_mesh: DeviceMesh, config: Opt if ep_world_size <= 1: return model + # Only explicit ep_fsdp config enables this mode. + ep_fsdp_enabled = bool( + cfg.ep_fsdp + and device_mesh.has_dim("ep_fsdp") + and (device_mesh.ep_fsdp_world_size or 1) > 1 + ) + if cfg.pad_to_max: raise NotImplementedError("pad_to_max is not implemented.") if cfg.all_to_all != "torch": @@ -45,7 +53,7 @@ def apply_expert_parallel(model: nn.Module, device_mesh: DeviceMesh, config: Opt raise RuntimeError("EP process group is not available in device_mesh.") for block in find_moe_blocks(model): - shard_experts(block, device_mesh, cfg) + shard_experts(block, device_mesh, cfg, ep_fsdp_enabled=ep_fsdp_enabled) patch_forward(block, device_mesh, cfg) return model @@ -76,7 +84,7 @@ def find_moe_blocks(model: nn.Module) -> Iterable[nn.Module]: return blocks -def shard_experts(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallelConfig) -> None: +def shard_experts(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallelConfig, *, ep_fsdp_enabled: bool) -> None: num_experts = _get_num_experts(block) ep_world_size = device_mesh.ep_world_size ep_rank = device_mesh.ep_rank @@ -91,6 +99,11 @@ def shard_experts(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallel local_end = local_start + experts_per_rank if isinstance(block.experts, nn.ModuleList): + if ep_fsdp_enabled: + raise NotImplementedError( + "EP+EP_FSDP currently does not support MoE experts stored as nn.ModuleList. " + "Only tensor experts (gate_up_proj/down_proj) are supported." 
+ ) local_experts = nn.ModuleList(block.experts[local_start:local_end]) block.experts = local_experts block._ep_tensor_experts = False @@ -105,6 +118,7 @@ def shard_experts(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallel block._ep_rank = ep_rank block._ep_world_size = ep_world_size block._ep_ignore_shared_experts = cfg.ignore_shared_experts + block._ep_fsdp_enabled = ep_fsdp_enabled def patch_forward(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallelConfig) -> None: diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index f2f34534..6abae266 100644 --- a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -16,11 +16,13 @@ def __init__(self, device_mesh: Optional[DeviceMesh] = None, mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16', fsdp_config: Dict[str, Any] = None, - enable_ep: bool = True): + enable_ep: bool = True, + enable_ep_fsdp: bool = False): self.device_mesh = device_mesh self.mixed_precision = mixed_precision self.fsdp_config = fsdp_config or {} self.enable_ep = enable_ep + self.enable_ep_fsdp = enable_ep_fsdp def wrap_model(self, model, optimizer=None): if self.device_mesh is None: @@ -28,6 +30,11 @@ def wrap_model(self, model, optimizer=None): fsdp_mesh = _build_fsdp_mesh(self.device_mesh) if fsdp_mesh is not None: + ep_fsdp_mode = _is_ep_fsdp_mode_enabled( + self.device_mesh, + self.enable_ep, + self.enable_ep_fsdp, + ) if self.enable_ep: _ensure_moe_patched_if_needed(model, self.device_mesh) _place_ep_experts_on_local_device(model, self.device_mesh) @@ -35,6 +42,20 @@ def wrap_model(self, model, optimizer=None): reshard_after_forward = self.fsdp_config.get("reshard_after_forward", True) ignored_params = _collect_expert_params(model) if self.enable_ep else None + if ep_fsdp_mode: + _ensure_ep_fsdp_supported(model) + ep_fsdp_mesh = _build_ep_fsdp_mesh(self.device_mesh) + if ep_fsdp_mesh is None: + raise RuntimeError( + "expert_parallel.ep_fsdp is enabled but could not build an ep_fsdp device mesh." 
+ ) + _maybe_shard_ep_expert_blocks( + model, + mesh=ep_fsdp_mesh, + reshard_after_forward=reshard_after_forward, + mp_policy=mp_policy, + ) + _maybe_shard_layers( model, mesh=fsdp_mesh, @@ -83,6 +104,23 @@ def _build_fsdp_mesh(device_mesh: DeviceMesh) -> Optional[TorchDeviceMesh]: return TorchDeviceMesh(device_mesh.device_type, flat_mesh, mesh_dim_names=("fsdp",)) +def _build_ep_fsdp_mesh(device_mesh: DeviceMesh) -> Optional[TorchDeviceMesh]: + if device_mesh is None or not device_mesh.has_dim("ep_fsdp"): + return None + ranks = device_mesh.get_ranks_for_dims("ep_fsdp") + if len(ranks) <= 1: + return None + return TorchDeviceMesh(device_mesh.device_type, ranks, mesh_dim_names=("ep_fsdp",)) + + +def _is_ep_fsdp_mode_enabled(device_mesh: Optional[DeviceMesh], enable_ep: bool, enable_ep_fsdp: bool) -> bool: + if not enable_ep or not enable_ep_fsdp or device_mesh is None: + return False + if not device_mesh.has_dim("ep_fsdp"): + return False + return (device_mesh.ep_fsdp_world_size or 1) > 1 + + def _collect_expert_params(model: nn.Module) -> Optional[Set[nn.Parameter]]: ignored: Set[nn.Parameter] = set() ep_patched = False @@ -107,6 +145,18 @@ def _collect_expert_params(model: nn.Module) -> Optional[Set[nn.Parameter]]: return ignored or None +def _collect_block_expert_params(block: nn.Module) -> Set[nn.Parameter]: + experts = getattr(block, "experts", None) + if experts is None: + return set() + if isinstance(experts, nn.ModuleList): + params: Set[nn.Parameter] = set() + for expert in experts: + params.update(expert.parameters()) + return params + return set(experts.parameters()) + + def _place_ep_experts_on_local_device(model: nn.Module, device_mesh: DeviceMesh) -> None: ep_world_size = device_mesh.ep_world_size or 1 if ep_world_size <= 1: @@ -137,6 +187,43 @@ def _ensure_moe_patched_if_needed(model: nn.Module, device_mesh: DeviceMesh) -> ) +def _ensure_ep_fsdp_supported(model: nn.Module) -> None: + for module in model.modules(): + if not getattr(module, "_ep_patched", False): + continue + experts = getattr(module, "experts", None) + if isinstance(experts, nn.ModuleList): + raise NotImplementedError( + "EP+EP_FSDP currently does not support MoE experts stored as nn.ModuleList. " + "Only tensor experts (gate_up_proj/down_proj) are supported." 
+ ) + + +def _maybe_shard_ep_expert_blocks(model: nn.Module, + *, + mesh: TorchDeviceMesh, + reshard_after_forward: Optional[bool], + mp_policy: MixedPrecisionPolicy) -> None: + for module in model.modules(): + if not getattr(module, "_ep_patched", False): + continue + experts = getattr(module, "experts", None) + if experts is None: + continue + expert_params = _collect_block_expert_params(module) + if not expert_params: + continue + block_params = set(module.parameters()) + non_expert_params = block_params - expert_params + fully_shard( + module, + mesh=mesh, + reshard_after_forward=reshard_after_forward, + mp_policy=mp_policy, + ignored_params=non_expert_params or None, + ) + + def _maybe_shard_layers(model: nn.Module, *, mesh: TorchDeviceMesh, diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 3960d33d..2462bc72 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -202,6 +202,8 @@ def _decide_strategy(self, strategy: Literal['accelerate', 'native_fsdp']): self._expert_parallel_config = self._fsdp_config.pop("expert_parallel", None) self._enable_expert_parallel = self._should_enable_expert_parallel( self._expert_parallel_config, self.device_mesh) + self._enable_expert_ep_fsdp = self._should_enable_expert_ep_fsdp( + self._expert_parallel_config, self.device_mesh) self._expert_parallel_applied = False use_native_fsdp = self._enable_expert_parallel or strategy == 'native_fsdp' if use_native_fsdp: @@ -210,6 +212,7 @@ def _decide_strategy(self, strategy: Literal['accelerate', 'native_fsdp']): fsdp_config=self._fsdp_config, device_mesh=self.device_mesh, enable_ep=self._enable_expert_parallel, + enable_ep_fsdp=self._enable_expert_ep_fsdp, ) else: self.strategy = AccelerateStrategy(mixed_precision=self.mixed_precision, ddp_config=self._ddp_config, @@ -292,6 +295,17 @@ def _maybe_apply_expert_parallel(self): ) self._expert_parallel_applied = True + @staticmethod + def _should_enable_expert_ep_fsdp(expert_parallel_config: Optional[Dict[str, Any]], + device_mesh: Optional[DeviceMesh]) -> bool: + if expert_parallel_config is None or device_mesh is None: + return False + if not expert_parallel_config.get("ep_fsdp", False): + return False + if not device_mesh.has_dim("ep_fsdp"): + return False + return (device_mesh.ep_fsdp_world_size or 1) > 1 + def _ensure_optimizer_dp_groups(self): for optimizer_group in self.optimizer_group.values(): if not isinstance(optimizer_group, OptimizerGroup): diff --git a/src/twinkle/utils/platform.py b/src/twinkle/utils/platform.py index 986c4efc..d16d448d 100644 --- a/src/twinkle/utils/platform.py +++ b/src/twinkle/utils/platform.py @@ -22,6 +22,7 @@ class DeviceMesh: - ulysses: ulysses sequence parallel - cp: Context Parallel - ep: Expert Parallel + - ep_fsdp: Expert FSDP Parallel - vpp: Virtual Pipeline Parallel Examples: @@ -34,6 +35,7 @@ class DeviceMesh: mesh: np.ndarray mesh_dim_names: Optional[tuple[str, ...]] ep_size: Optional[int] = None + ep_fsdp_size: Optional[int] = None etp_size: Optional[int] = None # megatron only vpp_size: Optional[int] = None @@ -46,7 +48,8 @@ class DeviceMesh: @staticmethod def from_sizes(*, world_size: int = 1, dp_size: int = 1, fsdp_size: int = None, tp_size: int = None, pp_size: int = None, ulysses_size: int = None, cp_size: int = None, ep_size: int = None, - etp_size: int = None,vpp_size: int = None, device_type: str = 'cuda', sequence_parallel: bool = False) -> "DeviceMesh": + ep_fsdp_size: int = None, 
etp_size: int = None, vpp_size: int = None, device_type: str = 'cuda', + sequence_parallel: bool = False) -> "DeviceMesh": """Create a default device mesh from the given sizes. Args: @@ -58,6 +61,7 @@ def from_sizes(*, world_size: int = 1, dp_size: int = 1, fsdp_size: int = None, ulysses_size: The ulysses parallel size cp_size: The context parallel size ep_size: The expert parallel size + ep_fsdp_size: The expert fsdp parallel size etp_size: The expert tensor parallel size vpp_size: The virtual pipeline parallel size device_type: The device type @@ -84,13 +88,16 @@ def from_sizes(*, world_size: int = 1, dp_size: int = 1, fsdp_size: int = None, if origin_world_size == 1: world_size *= dp_size mesh_dim_sizes.append(dp_size) - else: - mesh_dim_sizes.append(-1) if ep_size is not None: mesh_dim_sizes.append(ep_size) mesh_dim_names.append("ep") if origin_world_size == 1: world_size *= ep_size + if ep_fsdp_size is not None: + mesh_dim_sizes.append(ep_fsdp_size) + mesh_dim_names.append("ep_fsdp") + if origin_world_size == 1: + world_size *= ep_fsdp_size if cp_size is not None: mesh_dim_sizes.append(cp_size) mesh_dim_names.append("cp") @@ -107,6 +114,7 @@ def from_sizes(*, world_size: int = 1, dp_size: int = 1, fsdp_size: int = None, mesh_dim_names=tuple(mesh_dim_names), vpp_size=vpp_size, ep_size=ep_size, + ep_fsdp_size=ep_fsdp_size, etp_size=etp_size, ulysses_size=ulysses_size, sequence_parallel=sequence_parallel, @@ -116,7 +124,7 @@ def __post_init__(self): if not isinstance(self.mesh, np.ndarray): self.mesh = np.array(self.mesh) - valid_dim_names = {"dp", "fsdp", "tp", "pp", "cp", "ep"} + valid_dim_names = {"dp", "fsdp", "tp", "pp", "cp", "ep", "ep_fsdp"} if self.mesh_dim_names is not None: if len(self.mesh_dim_names) != len(self.mesh.shape): raise ValueError( @@ -128,6 +136,12 @@ def __post_init__(self): def create_process_group(self, dims): """Create a process group by dims""" import torch.distributed as dist + ranks = self.get_ranks_for_dims(dims) + return dist.new_group(ranks=ranks) + + def get_ranks_for_dims(self, dims): + if isinstance(dims, str): + dims = (dims,) rank = dist.get_rank() coords = np.argwhere(self.mesh == rank)[0] slices = [] @@ -137,8 +151,7 @@ def create_process_group(self, dims): else: slices.append(coords[i]) - ranks = sorted(self.mesh[tuple(slices)].flatten().tolist()) - return dist.new_group(ranks=ranks) + return sorted(self.mesh[tuple(slices)].flatten().tolist()) def get_dim_group(self, dims): import torch.distributed as dist @@ -185,7 +198,7 @@ def get_dim_group(self, dims): def order(self): """The order of the dimensions for megatron""" # TODO hard coded for now - return 'tp-cp-ep-dp-pp' + return 'tp-cp-ep-ep_fsdp-dp-pp' def to_torch_device_mesh(self): import torch @@ -263,6 +276,10 @@ def cp_rank(self) -> Optional[int]: def ep_rank(self) -> Optional[int]: return self._get_rank_for_dim("ep") + @property + def ep_fsdp_rank(self) -> Optional[int]: + return self._get_rank_for_dim("ep_fsdp") + @property def dp_world_size(self) -> int: return self._get_world_size_for_dim("dp") @@ -287,6 +304,10 @@ def cp_world_size(self) -> int: def ep_world_size(self) -> Optional[int]: return self._get_world_size_for_dim("ep") + @property + def ep_fsdp_world_size(self) -> Optional[int]: + return self._get_world_size_for_dim("ep_fsdp") + @property def etp_world_size(self) -> int: if self.etp_size is not None: diff --git a/tests/DeviceMesh/test_device_mesh.py b/tests/DeviceMesh/test_device_mesh.py index 51a12d22..05c2a3fc 100644 --- a/tests/DeviceMesh/test_device_mesh.py +++ 
b/tests/DeviceMesh/test_device_mesh.py @@ -183,6 +183,24 @@ def test_world_sizes(self): assert mesh.pp_world_size == 5 assert mesh.world_size == 2 * 3 * 4 * 5 + def test_ep_fsdp_rank_and_world_size(self): + mesh = DeviceMesh.from_sizes(dp_size=1, ep_size=2, ep_fsdp_size=3) + mesh_array = mesh.mesh.reshape(1, 2, 3) + + for ep_idx in range(2): + for ep_fsdp_idx in range(3): + global_rank = int(mesh_array[0, ep_idx, ep_fsdp_idx]) + with patch.object(Platform, 'get_rank', return_value=global_rank): + assert mesh.ep_rank == ep_idx + assert mesh.ep_fsdp_rank == ep_fsdp_idx + assert mesh.ep_world_size == 2 + assert mesh.ep_fsdp_world_size == 3 + + def test_without_dp_dimension(self): + mesh = DeviceMesh.from_sizes(dp_size=None, ep_size=2, ep_fsdp_size=2) + assert mesh.mesh.shape == (2, 2) + assert mesh.mesh_dim_names == ('ep', 'ep_fsdp') + def test_data_rank_with_dp_only(self): mesh = DeviceMesh.from_sizes(dp_size=4) From 3c7b18ef882b842619aa9a24efa35ada671a55d6 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Fri, 13 Feb 2026 09:51:16 +0800 Subject: [PATCH 2/3] wip --- .../model/transformers/moe/expert_parallel.py | 59 ++++- .../transformers/strategy/native_fsdp.py | 64 +++--- .../model/transformers/transformers.py | 14 -- src/twinkle/utils/grad_clip.py | 27 ++- src/twinkle/utils/platform.py | 205 ++++++++++-------- 5 files changed, 219 insertions(+), 150 deletions(-) diff --git a/src/twinkle/model/transformers/moe/expert_parallel.py b/src/twinkle/model/transformers/moe/expert_parallel.py index 68466b72..55134133 100644 --- a/src/twinkle/model/transformers/moe/expert_parallel.py +++ b/src/twinkle/model/transformers/moe/expert_parallel.py @@ -1,6 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from __future__ import annotations +import os from dataclasses import dataclass from typing import Any, Dict, Iterable, Optional, Tuple @@ -21,6 +22,7 @@ class ExpertParallelConfig: keep_router_logits: bool = True pad_to_max: bool = False ignore_shared_experts: bool = False + # Deprecated: EP_FSDP is inferred implicitly from dp/ep mesh topology. ep_fsdp: bool = False @@ -33,12 +35,7 @@ def apply_expert_parallel(model: nn.Module, device_mesh: DeviceMesh, config: Opt if ep_world_size <= 1: return model - # Only explicit ep_fsdp config enables this mode. - ep_fsdp_enabled = bool( - cfg.ep_fsdp - and device_mesh.has_dim("ep_fsdp") - and (device_mesh.ep_fsdp_world_size or 1) > 1 - ) + ep_fsdp_enabled = device_mesh.is_implicit_ep_fsdp_enabled() if cfg.pad_to_max: raise NotImplementedError("pad_to_max is not implemented.") @@ -152,6 +149,8 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs): else: raise ValueError(f"Unsupported hidden_states ndim: {hidden_states.ndim}") + _debug_log_ep_fsdp_runtime_once(block, gate, hidden_states_2d) + router_logits, routing_weights, selected_experts, cast_weights = _run_router( gate=gate, hidden_states=hidden_states_2d, @@ -347,6 +346,25 @@ def _run_expert(block: nn.Module, expert_id: int, expert_in: torch.Tensor) -> to expert = block.experts[expert_id] return _run_module_with_casting(expert, expert_in) experts = block.experts + if getattr(block, "_ep_fsdp_enabled", False): + # In EP+EP_FSDP mode, execute experts.forward so FSDP hooks can + # manage unshard/reshard around forward/backward safely. 
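+        # A single-column top_k_index pointing at this local expert, paired with
+        # a weight of 1.0, makes experts.forward compute exactly expert_id's
+        # gate_up/down projection for every row of expert_in.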
+ top_k_index = torch.full( + (expert_in.shape[0], 1), + int(expert_id), + dtype=torch.long, + device=expert_in.device, + ) + top_k_weights = torch.ones( + (expert_in.shape[0], 1), + dtype=expert_in.dtype, + device=expert_in.device, + ) + out = experts(expert_in, top_k_index, top_k_weights) + if out.dtype != input_dtype: + out = out.to(input_dtype) + return out + gate_up = experts.gate_up_proj[expert_id] down = experts.down_proj[expert_id] compute_dtype = gate_up.dtype @@ -397,3 +415,32 @@ def _run_router( if norm_topk_prob: routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) return router_logits, routing_weights, selected_experts, True + + +def _debug_log_ep_fsdp_runtime_once(block: nn.Module, gate: nn.Module, hidden_states: torch.Tensor) -> None: + if os.environ.get("TWINKLE_EP_FSDP_DEBUG", "1") != "1": + return + if not getattr(block, "_ep_fsdp_enabled", False): + return + if getattr(block, "_ep_fsdp_runtime_logged", False): + return + rank = dist.get_rank() if dist.is_initialized() else -1 + experts = getattr(block, "experts", None) + expert_param = next(experts.parameters(), None) if experts is not None else None + gate_param = next(gate.parameters(), None) if gate is not None else None + print( + f"[EP_FSDP][rank{rank}] runtime input={hidden_states.device} " + f"expert_param={_describe_param_mesh(expert_param)} " + f"gate_param={_describe_param_mesh(gate_param)}", + flush=True, + ) + block._ep_fsdp_runtime_logged = True + + +def _describe_param_mesh(param: Optional[nn.Parameter]) -> str: + if param is None: + return "none" + mesh = getattr(param, "device_mesh", None) + if mesh is None: + return f"local:{getattr(param, 'device', 'unknown')}" + return f"dtensor:{tuple(mesh.mesh_dim_names or ())}:{mesh.mesh.flatten().tolist()}" diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index 031f4f2b..051cd8cd 100644 --- a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -1,4 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import os from typing import Dict, Any, Optional, Literal, Set, TYPE_CHECKING import torch @@ -18,13 +19,11 @@ def __init__(self, device_mesh: Optional[DeviceMesh] = None, mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16', fsdp_config: Dict[str, Any] = None, - enable_ep: bool = True, - enable_ep_fsdp: bool = False): + enable_ep: bool = True): self.device_mesh = device_mesh self.mixed_precision = mixed_precision self.fsdp_config = fsdp_config or {} self.enable_ep = enable_ep - self.enable_ep_fsdp = enable_ep_fsdp def wrap_model(self, model, optimizer=None): if self.device_mesh is None: @@ -35,7 +34,6 @@ def wrap_model(self, model, optimizer=None): ep_fsdp_mode = _is_ep_fsdp_mode_enabled( self.device_mesh, self.enable_ep, - self.enable_ep_fsdp, ) if self.enable_ep: _ensure_moe_patched_if_needed(model, self.device_mesh) @@ -49,14 +47,15 @@ def wrap_model(self, model, optimizer=None): ep_fsdp_mesh = _build_ep_fsdp_mesh(self.device_mesh) if ep_fsdp_mesh is None: raise RuntimeError( - "expert_parallel.ep_fsdp is enabled but could not build an ep_fsdp device mesh." + "Implicit EP_FSDP requires dp dim with size > 1, but could not build an ep_fsdp mesh from dp." 
) - _maybe_shard_ep_expert_blocks( + sharded_blocks = _maybe_shard_ep_expert_blocks( model, mesh=ep_fsdp_mesh, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy, ) + _debug_log_ep_fsdp_sharding(model, self.device_mesh, sharded_blocks) _maybe_shard_layers( model, @@ -108,20 +107,18 @@ def _build_fsdp_mesh(device_mesh: DeviceMesh) -> Optional[TorchDeviceMesh]: def _build_ep_fsdp_mesh(device_mesh: DeviceMesh) -> Optional[TorchDeviceMesh]: - if device_mesh is None or not device_mesh.has_dim("ep_fsdp"): + if device_mesh is None or not device_mesh.has_dim("dp"): return None - ranks = device_mesh.get_ranks_for_dims("ep_fsdp") + ranks = device_mesh.get_ranks_for_dims("dp") if len(ranks) <= 1: return None return TorchDeviceMesh(device_mesh.device_type, ranks, mesh_dim_names=("ep_fsdp",)) -def _is_ep_fsdp_mode_enabled(device_mesh: Optional[DeviceMesh], enable_ep: bool, enable_ep_fsdp: bool) -> bool: - if not enable_ep or not enable_ep_fsdp or device_mesh is None: +def _is_ep_fsdp_mode_enabled(device_mesh: Optional[DeviceMesh], enable_ep: bool) -> bool: + if not enable_ep or device_mesh is None: return False - if not device_mesh.has_dim("ep_fsdp"): - return False - return (device_mesh.ep_fsdp_world_size or 1) > 1 + return device_mesh.is_implicit_ep_fsdp_enabled() def _collect_expert_params(model: nn.Module) -> Optional[Set[nn.Parameter]]: @@ -148,18 +145,6 @@ def _collect_expert_params(model: nn.Module) -> Optional[Set[nn.Parameter]]: return ignored or None -def _collect_block_expert_params(block: nn.Module) -> Set[nn.Parameter]: - experts = getattr(block, "experts", None) - if experts is None: - return set() - if isinstance(experts, nn.ModuleList): - params: Set[nn.Parameter] = set() - for expert in experts: - params.update(expert.parameters()) - return params - return set(experts.parameters()) - - def _place_ep_experts_on_local_device(model: nn.Module, device_mesh: DeviceMesh) -> None: ep_world_size = device_mesh.ep_world_size or 1 if ep_world_size <= 1: @@ -206,25 +191,38 @@ def _maybe_shard_ep_expert_blocks(model: nn.Module, *, mesh: TorchDeviceMesh, reshard_after_forward: Optional[bool], - mp_policy: MixedPrecisionPolicy) -> None: + mp_policy: 'MixedPrecisionPolicy') -> int: + from torch.distributed.fsdp import fully_shard + sharded_blocks = 0 for module in model.modules(): if not getattr(module, "_ep_patched", False): continue experts = getattr(module, "experts", None) if experts is None: continue - expert_params = _collect_block_expert_params(module) - if not expert_params: - continue - block_params = set(module.parameters()) - non_expert_params = block_params - expert_params + # Correct EP+EP_FSDP behavior: only experts are sharded on ep_fsdp mesh. + # Non-expert params (router/gate etc.) are left to global FSDP wrapping. 
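+        # fully_shard() attaches FSDP2 hooks to the experts module, so its
+        # parameters are all-gathered over the ep_fsdp group right before each
+        # use and resharded afterwards when reshard_after_forward is enabled.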
fully_shard( - module, + experts, mesh=mesh, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy, - ignored_params=non_expert_params or None, ) + sharded_blocks += 1 + return sharded_blocks + + +def _debug_log_ep_fsdp_sharding(model: nn.Module, device_mesh: DeviceMesh, sharded_blocks: int) -> None: + if os.environ.get("TWINKLE_EP_FSDP_DEBUG", "0") != "1": + return + + rank = Platform.get_rank() + ep_fsdp_ranks = device_mesh.get_ranks_for_dims("dp") + print( + f"[EP_FSDP][rank{rank}] enabled=1 sharded_blocks={sharded_blocks} " + f"ep_fsdp_group={ep_fsdp_ranks}", + flush=True, + ) def _maybe_shard_layers(model: nn.Module, diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 40576abe..0e32efa8 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -213,8 +213,6 @@ def _decide_strategy(self, strategy: Literal['accelerate', 'native_fsdp']): self._expert_parallel_config = self._fsdp_config.pop("expert_parallel", None) self._enable_expert_parallel = self._should_enable_expert_parallel( self._expert_parallel_config, self.device_mesh) - self._enable_expert_ep_fsdp = self._should_enable_expert_ep_fsdp( - self._expert_parallel_config, self.device_mesh) self._expert_parallel_applied = False use_native_fsdp = self._enable_expert_parallel or strategy == 'native_fsdp' if use_native_fsdp: @@ -223,7 +221,6 @@ def _decide_strategy(self, strategy: Literal['accelerate', 'native_fsdp']): fsdp_config=self._fsdp_config, device_mesh=self.device_mesh, enable_ep=self._enable_expert_parallel, - enable_ep_fsdp=self._enable_expert_ep_fsdp, ) else: self.strategy = AccelerateStrategy(mixed_precision=self.mixed_precision, ddp_config=self._ddp_config, @@ -303,17 +300,6 @@ def _maybe_apply_expert_parallel(self): ) self._expert_parallel_applied = True - @staticmethod - def _should_enable_expert_ep_fsdp(expert_parallel_config: Optional[Dict[str, Any]], - device_mesh: Optional[DeviceMesh]) -> bool: - if expert_parallel_config is None or device_mesh is None: - return False - if not expert_parallel_config.get("ep_fsdp", False): - return False - if not device_mesh.has_dim("ep_fsdp"): - return False - return (device_mesh.ep_fsdp_world_size or 1) > 1 - def _ensure_optimizer_dp_groups(self): for optimizer_group in self.optimizer_group.values(): if not isinstance(optimizer_group, OptimizerGroup): diff --git a/src/twinkle/utils/grad_clip.py b/src/twinkle/utils/grad_clip.py index 97134d3c..b9d7fb70 100644 --- a/src/twinkle/utils/grad_clip.py +++ b/src/twinkle/utils/grad_clip.py @@ -32,7 +32,23 @@ def normalize_and_clip_grad_norm(parameters: Iterable['torch.nn.Parameter'], has_dtensor_grad = any(hasattr(grad, "to_local") for grad in grads) has_local_tensor_grad = any(not hasattr(grad, "to_local") for grad in grads) - if not (has_dtensor_grad and has_local_tensor_grad): + dtensor_mesh_keys = set() + for grad in grads: + if not hasattr(grad, "to_local"): + continue + mesh = getattr(grad, "device_mesh", None) + if mesh is None: + dtensor_mesh_keys.add("dtensor:unknown") + continue + try: + mesh_key = (tuple(mesh.mesh.flatten().tolist()), tuple(mesh.mesh_dim_names or ())) + except Exception: + mesh_key = repr(mesh) + dtensor_mesh_keys.add(mesh_key) + + has_mixed_dtensor_mesh = len(dtensor_mesh_keys) > 1 + + if not (has_dtensor_grad and has_local_tensor_grad) and not has_mixed_dtensor_mesh: grad_norm = torch.nn.utils.clip_grad_norm_( parameters, max_grad_norm, @@ -62,6 +78,11 @@ def _local_grad(grad: 
torch.Tensor) -> torch.Tensor: reduce_device = torch.device(Platform.get_local_device()) else: reduce_device = torch.device("cpu") + reduce_group = group + if has_mixed_dtensor_mesh: + # Different DTensor meshes cannot be reduced by DTensor op propagation (e.g. aten.stack). + # Fall back to world reduction over local shards. + reduce_group = None if norm_type == float("inf"): local_norm = 0.0 @@ -72,7 +93,7 @@ def _local_grad(grad: torch.Tensor) -> torch.Tensor: local_norm = max(local_norm, local_grad.detach().abs().max().item()) total_norm_tensor = torch.tensor(local_norm, device=reduce_device, dtype=torch.float32) if dist.is_initialized(): - dist.all_reduce(total_norm_tensor, op=dist.ReduceOp.MAX, group=group) + dist.all_reduce(total_norm_tensor, op=dist.ReduceOp.MAX, group=reduce_group) total_norm = float(total_norm_tensor.item()) else: local_sq = 0.0 @@ -83,7 +104,7 @@ def _local_grad(grad: torch.Tensor) -> torch.Tensor: local_sq += local_grad.detach().float().pow(2).sum().item() total_sq_tensor = torch.tensor(local_sq, device=reduce_device, dtype=torch.float32) if dist.is_initialized(): - dist.all_reduce(total_sq_tensor, op=dist.ReduceOp.SUM, group=group) + dist.all_reduce(total_sq_tensor, op=dist.ReduceOp.SUM, group=reduce_group) total_norm = float(total_sq_tensor.sqrt().item()) clip_coef = float(max_grad_norm) / (total_norm + 1e-6) diff --git a/src/twinkle/utils/platform.py b/src/twinkle/utils/platform.py index f2641dfe..f5e28188 100644 --- a/src/twinkle/utils/platform.py +++ b/src/twinkle/utils/platform.py @@ -5,10 +5,10 @@ import subprocess from abc import ABC from dataclasses import dataclass, field -from functools import lru_cache from itertools import product -from typing import Dict, List, Optional, Type, Union - +from functools import lru_cache +from typing import Any, Dict, List, Optional, Type, Union +import torch.distributed as dist import numpy as np @@ -23,7 +23,6 @@ class DeviceMesh: - sequence_parallel: megatron sequence parallel - cp: Context Parallel - ep: Expert Parallel - - ep_fsdp: Expert FSDP Parallel - vpp: Virtual Pipeline Parallel Examples: @@ -36,7 +35,6 @@ class DeviceMesh: mesh: np.ndarray mesh_dim_names: Optional[tuple[str, ...]] ep_size: Optional[int] = None - ep_fsdp_size: Optional[int] = None etp_size: Optional[int] = None # megatron only vpp_size: Optional[int] = None @@ -49,8 +47,7 @@ class DeviceMesh: @staticmethod def from_sizes(*, world_size: int = 1, dp_size: int = 1, fsdp_size: int = None, tp_size: int = None, pp_size: int = None, ulysses_size: int = None, cp_size: int = None, ep_size: int = None, - ep_fsdp_size: int = None, etp_size: int = 1, vpp_size: int = None, device_type: str = 'cuda', - sequence_parallel: bool = False) -> "DeviceMesh": + etp_size: int = None,vpp_size: int = None, device_type: str = 'cuda', sequence_parallel: bool = False) -> "DeviceMesh": """Create a default device mesh from the given sizes. 
Args: @@ -62,7 +59,6 @@ def from_sizes(*, world_size: int = 1, dp_size: int = 1, fsdp_size: int = None, ulysses_size: The ulysses parallel size cp_size: The context parallel size ep_size: The expert parallel size - ep_fsdp_size: The expert fsdp parallel size etp_size: The expert tensor parallel size vpp_size: The virtual pipeline parallel size device_type: The device type @@ -91,11 +87,6 @@ def from_sizes(*, world_size: int = 1, dp_size: int = 1, fsdp_size: int = None, mesh_dim_sizes.append(dp_size) else: mesh_dim_sizes.append(-1) - if ep_size is not None: - mesh_dim_sizes.append(ep_size) - mesh_dim_names.append("ep") - if origin_world_size == 1: - world_size *= ep_size if cp_size is not None: mesh_dim_sizes.append(cp_size) mesh_dim_names.append("cp") @@ -112,7 +103,6 @@ def from_sizes(*, world_size: int = 1, dp_size: int = 1, fsdp_size: int = None, mesh_dim_names=tuple(mesh_dim_names), vpp_size=vpp_size, ep_size=ep_size, - ep_fsdp_size=ep_fsdp_size, etp_size=etp_size, ulysses_size=ulysses_size, sequence_parallel=sequence_parallel, @@ -122,7 +112,7 @@ def __post_init__(self): if not isinstance(self.mesh, np.ndarray): self.mesh = np.array(self.mesh) - valid_dim_names = {"dp", "fsdp", "tp", "pp", "cp", "ep", "ep_fsdp"} + valid_dim_names = {"dp", "fsdp", "tp", "pp", "cp", "ep"} if self.mesh_dim_names is not None: if len(self.mesh_dim_names) != len(self.mesh.shape): raise ValueError( @@ -134,12 +124,6 @@ def __post_init__(self): def create_process_group(self, dims): """Create a process group by dims""" import torch.distributed as dist - ranks = self.get_ranks_for_dims(dims) - return dist.new_group(ranks=ranks) - - def get_ranks_for_dims(self, dims): - if isinstance(dims, str): - dims = (dims,) rank = dist.get_rank() coords = np.argwhere(self.mesh == rank)[0] slices = [] @@ -149,7 +133,8 @@ def get_ranks_for_dims(self, dims): else: slices.append(coords[i]) - return sorted(self.mesh[tuple(slices)].flatten().tolist()) + ranks = sorted(self.mesh[tuple(slices)].flatten().tolist()) + return dist.new_group(ranks=ranks) def get_dim_group(self, dims): import torch.distributed as dist @@ -192,11 +177,55 @@ def get_dim_group(self, dims): key = tuple(c for i, c in enumerate(coord) if i != dim_idx) return group_map[key] + def get_ranks_for_dims(self, dims): + if self.mesh_dim_names is None: + raise ValueError("mesh_dim_names is not set.") + if isinstance(dims, str): + dims = (dims,) + for dim_name in dims: + if dim_name not in self.mesh_dim_names: + raise ValueError( + f"Dimension '{dim_name}' not found in mesh. Available: {self.mesh_dim_names}" + ) + + coord = self._get_coord() + if coord is None: + raise RuntimeError("Current rank is not found in mesh.") + + slices = [] + for i, dim_name in enumerate(self.mesh_dim_names): + if dim_name in dims: + slices.append(slice(None)) + else: + slices.append(coord[i]) + return sorted(self.mesh[tuple(slices)].flatten().tolist()) + + def is_implicit_ep_fsdp_enabled(self) -> bool: + ep_world_size = self.ep_world_size or 1 + dp_world_size = self.dp_world_size or 1 + if ep_world_size <= 1 or dp_world_size <= 1: + return False + + world_size = self.world_size or 1 + if world_size % ep_world_size != 0: + raise ValueError( + f"world_size ({world_size}) must be divisible by ep_world_size ({ep_world_size}) " + "to infer implicit EP_FSDP from dp." 
+ ) + expected_dp_size = world_size // ep_world_size + if dp_world_size != expected_dp_size: + raise ValueError( + f"Implicit EP_FSDP requires dp_world_size == world_size // ep_world_size, " + f"but got dp_world_size={dp_world_size}, world_size={world_size}, " + f"ep_world_size={ep_world_size}." + ) + return True + @property def order(self): """The order of the dimensions for megatron""" # TODO hard coded for now - return 'tp-cp-ep-ep_fsdp-dp-pp' + return 'tp-cp-ep-dp-pp' def to_torch_device_mesh(self): import torch @@ -274,10 +303,6 @@ def cp_rank(self) -> Optional[int]: def ep_rank(self) -> Optional[int]: return self._get_rank_for_dim("ep") - @property - def ep_fsdp_rank(self) -> Optional[int]: - return self._get_rank_for_dim("ep_fsdp") - @property def dp_world_size(self) -> int: return self._get_world_size_for_dim("dp") @@ -302,10 +327,6 @@ def cp_world_size(self) -> int: def ep_world_size(self) -> Optional[int]: return self._get_world_size_for_dim("ep") - @property - def ep_fsdp_world_size(self) -> Optional[int]: - return self._get_world_size_for_dim("ep_fsdp") - @property def etp_world_size(self) -> int: if self.etp_size is not None: @@ -371,16 +392,16 @@ def data_world_size(self) -> int: """Consider all dp/fsdp ranks, uses to determine how to distribute the data""" dp_world_size = self.dp_world_size fsdp_world_size = self.fsdp_world_size - ulysses_size = self.ulysses_size or 1 if fsdp_world_size is not None and fsdp_world_size > 1: - data_world_size = dp_world_size * fsdp_world_size if dp_world_size is not None else fsdp_world_size - else: - data_world_size = dp_world_size if dp_world_size is not None else 1 + if dp_world_size is not None: + return dp_world_size * fsdp_world_size + else: + return fsdp_world_size + + ulysses_size = self.ulysses_size or 1 + assert dp_world_size % ulysses_size == 0, f'dp_world_size: {dp_world_size} cannot be divided by ulysses_size: {ulysses_size}.' + return dp_world_size // ulysses_size - assert data_world_size % ulysses_size == 0, ( - f'data_world_size: {data_world_size} cannot be divided by ulysses_size: {ulysses_size}.' 
- ) - return data_world_size // ulysses_size def get_slice(self, total_length: int, rank: Optional[int] = None) -> slice: world_size = self.data_world_size if world_size == 1: @@ -396,57 +417,6 @@ def get_slice(self, total_length: int, rank: Optional[int] = None) -> slice: end = (rank + 1) * k + min(rank + 1, m) return slice(start, end) - def get_tp_ranks(self) -> List[int]: - """Get all ranks in the same TP group as the current rank.""" - rank = Platform.get_rank() - if not self._has_dim("tp"): - return [rank] - - tp_idx = self._get_dim_index("tp") - coords = self._get_coord_for_rank(rank) - - if coords is None: - return [] - - slices = [] - for i, dim_val in enumerate(coords): - if i == tp_idx: - slices.append(slice(None)) - else: - slices.append(dim_val) - - return sorted(self.mesh[tuple(slices)].flatten().tolist()) - - def get_tp_last_ranks(self) -> List[int]: - """Get a list of all ranks that are the last rank in their respective TP group.""" - if not self._has_dim("tp"): - return self.mesh.flatten().tolist() - - tp_idx = self._get_dim_index("tp") - tp_size = self.mesh.shape[tp_idx] - - slices = [slice(None)] * self.mesh.ndim - slices[tp_idx] = tp_size - 1 - - return sorted(self.mesh[tuple(slices)].flatten().tolist()) - - def is_tp_last_rank(self, rank: Optional[int] = None) -> bool: - """Check if the given rank is the last rank in its TP group.""" - if rank is None: - rank = Platform.get_rank() - - if not self._has_dim("tp"): - return True - - tp_idx = self._get_dim_index("tp") - coords = self._get_coord_for_rank(rank) - - if coords is None: - return False - - tp_size = self.mesh.shape[tp_idx] - return coords[tp_idx] == tp_size - 1 - def is_pp_first_rank(self) -> bool: pp_ranks = self.get_pp_first_ranks() if pp_ranks is None: @@ -506,6 +476,7 @@ class DeviceGroup: name: str ranks: Union[List[int], int] device_type: str + visible_devices: Optional[str] = None # Optional: explicitly set visible devices (e.g., "8,9") gpus_per_worker: int = 1 _device_mesh: Dict[str, DeviceMesh] = field(default_factory=dict) @@ -522,17 +493,63 @@ def _ensure_npu_backend() -> None: ) from exc @staticmethod - def visible_device_env(platform: str = None) -> str: - return Platform.get_platform(platform).visible_device_env() + def visible_device_env() -> str: + return Platform.get_platform().visible_device_env() @staticmethod - def device_prefix(platform: str = None) -> str: - return Platform.get_platform(platform).device_prefix() + def device_prefix() -> str: + return Platform.get_platform().device_prefix() @staticmethod def get_platform_names() -> List[str]: return ['GPU', 'NPU', 'MPS'] + @staticmethod + def resolve_visible_devices( + device_type: str, + *, + explicit: Any = None, + env_values: Optional[List[Any]] = None, + include_os_env: bool = True, + ) -> Optional[str]: + def _normalize(value: Any) -> Optional[str]: + if value is None: + return None + if isinstance(value, (list, tuple)): + return ','.join(str(v) for v in value) + if isinstance(value, int): + return str(value) + if isinstance(value, str): + return value + return None + + if not device_type: + return None + if device_type.upper() == "CPU": + return None + + try: + visible_env = Platform.get_platform(device_type.upper()).visible_device_env() + except Exception: + visible_env = None + if not visible_env: + return None + + normalized = _normalize(explicit) + if normalized: + return normalized + + if env_values: + for value in env_values: + normalized = _normalize(value) + if normalized: + return normalized + + if include_os_env: + return 
_normalize(os.environ.get(visible_env)) + + return None + @staticmethod def get_platform(platform: str = None) -> Type['Platform']: if platform is None: From ce711392aa08fbe9f1be7f319c38acb9ea301873 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Fri, 13 Feb 2026 10:07:03 +0800 Subject: [PATCH 3/3] wip --- src/twinkle/model/transformers/strategy/native_fsdp.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index 051cd8cd..1a4d69c5 100644 --- a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -193,6 +193,7 @@ def _maybe_shard_ep_expert_blocks(model: nn.Module, reshard_after_forward: Optional[bool], mp_policy: 'MixedPrecisionPolicy') -> int: from torch.distributed.fsdp import fully_shard + from torch.distributed.tensor import Shard sharded_blocks = 0 for module in model.modules(): if not getattr(module, "_ep_patched", False): @@ -207,6 +208,7 @@ def _maybe_shard_ep_expert_blocks(model: nn.Module, mesh=mesh, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy, + shard_placement_fn=lambda param: Shard(1), ) sharded_blocks += 1 return sharded_blocks
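
Note on the Shard(1) placement above -- a minimal illustrative sketch, not part of the applied patch (the shapes and the 2-way ep_fsdp group are assumptions): the stacked expert weights indexed as gate_up_proj[expert_id] keep the expert dimension at dim 0, which expert parallelism has already split across EP ranks, so the ep_fsdp shard is taken along dim 1 to leave every local expert whole on each rank.

    import torch
    from torch.distributed.tensor import Shard

    # Hypothetical stacked expert weight: 4 local experts, hidden=8, 2*ffn=16.
    gate_up_proj = torch.randn(4, 8, 16)

    # Shard(0) would split the expert dim again (EP already owns it);
    # Shard(1) splits within each expert's weight matrix instead.
    placement = Shard(1)
    rank0_piece, rank1_piece = gate_up_proj.chunk(2, dim=placement.dim)
    print(rank0_piece.shape)  # torch.Size([4, 4, 16]) -- all 4 local experts present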