diff --git a/cookbook/legacy/sft/ep_fsdp_qwen3_moe.py b/cookbook/legacy/sft/ep_fsdp_qwen3_moe.py index 68c6d15a..468bfd5d 100644 --- a/cookbook/legacy/sft/ep_fsdp_qwen3_moe.py +++ b/cookbook/legacy/sft/ep_fsdp_qwen3_moe.py @@ -20,11 +20,12 @@ PROCESSOR_ID = os.environ.get('QWEN3_PROCESSOR_ID', 'AlpacaProcessor') NUM_LAYERS = int(os.environ.get('QWEN3_NUM_LAYERS', '1')) -# 4 GPUs: dp=2, ep=2 +# 4 GPUs: ep=2, ep_fsdp=2 device_mesh = DeviceMesh.from_sizes( device_type=Platform.get_platform().device_prefix(), - dp_size=2, + dp_size=None, ep_size=2, + ep_fsdp_size=2, ) twinkle.initialize( @@ -80,6 +81,7 @@ def train(): "router_dtype": "fp32", "all_to_all": "torch", "keep_router_logits": False, + "ep_fsdp": True, } }, ) diff --git a/src/twinkle/model/transformers/moe/expert_parallel.py b/src/twinkle/model/transformers/moe/expert_parallel.py index 99108bb7..55134133 100644 --- a/src/twinkle/model/transformers/moe/expert_parallel.py +++ b/src/twinkle/model/transformers/moe/expert_parallel.py @@ -1,6 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from __future__ import annotations +import os from dataclasses import dataclass from typing import Any, Dict, Iterable, Optional, Tuple @@ -21,6 +22,8 @@ class ExpertParallelConfig: keep_router_logits: bool = True pad_to_max: bool = False ignore_shared_experts: bool = False + # Deprecated: EP_FSDP is inferred implicitly from dp/ep mesh topology. 
+ ep_fsdp: bool = False def apply_expert_parallel(model: nn.Module, device_mesh: DeviceMesh, config: Optional[Dict[str, Any]] = None): @@ -32,6 +35,8 @@ def apply_expert_parallel(model: nn.Module, device_mesh: DeviceMesh, config: Opt if ep_world_size <= 1: return model + ep_fsdp_enabled = device_mesh.is_implicit_ep_fsdp_enabled() + if cfg.pad_to_max: raise NotImplementedError("pad_to_max is not implemented.") if cfg.all_to_all != "torch": @@ -45,7 +50,7 @@ def apply_expert_parallel(model: nn.Module, device_mesh: DeviceMesh, config: Opt raise RuntimeError("EP process group is not available in device_mesh.") for block in find_moe_blocks(model): - shard_experts(block, device_mesh, cfg) + shard_experts(block, device_mesh, cfg, ep_fsdp_enabled=ep_fsdp_enabled) patch_forward(block, device_mesh, cfg) return model @@ -76,7 +81,7 @@ def find_moe_blocks(model: nn.Module) -> Iterable[nn.Module]: return blocks -def shard_experts(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallelConfig) -> None: +def shard_experts(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallelConfig, *, ep_fsdp_enabled: bool) -> None: num_experts = _get_num_experts(block) ep_world_size = device_mesh.ep_world_size ep_rank = device_mesh.ep_rank @@ -91,6 +96,11 @@ def shard_experts(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallel local_end = local_start + experts_per_rank if isinstance(block.experts, nn.ModuleList): + if ep_fsdp_enabled: + raise NotImplementedError( + "EP+EP_FSDP currently does not support MoE experts stored as nn.ModuleList. " + "Only tensor experts (gate_up_proj/down_proj) are supported." 
+ ) local_experts = nn.ModuleList(block.experts[local_start:local_end]) block.experts = local_experts block._ep_tensor_experts = False @@ -105,6 +115,7 @@ def shard_experts(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallel block._ep_rank = ep_rank block._ep_world_size = ep_world_size block._ep_ignore_shared_experts = cfg.ignore_shared_experts + block._ep_fsdp_enabled = ep_fsdp_enabled def patch_forward(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallelConfig) -> None: @@ -138,6 +149,8 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs): else: raise ValueError(f"Unsupported hidden_states ndim: {hidden_states.ndim}") + _debug_log_ep_fsdp_runtime_once(block, gate, hidden_states_2d) + router_logits, routing_weights, selected_experts, cast_weights = _run_router( gate=gate, hidden_states=hidden_states_2d, @@ -333,6 +346,25 @@ def _run_expert(block: nn.Module, expert_id: int, expert_in: torch.Tensor) -> to expert = block.experts[expert_id] return _run_module_with_casting(expert, expert_in) experts = block.experts + if getattr(block, "_ep_fsdp_enabled", False): + # In EP+EP_FSDP mode, execute experts.forward so FSDP hooks can + # manage unshard/reshard around forward/backward safely. 
+ top_k_index = torch.full( + (expert_in.shape[0], 1), + int(expert_id), + dtype=torch.long, + device=expert_in.device, + ) + top_k_weights = torch.ones( + (expert_in.shape[0], 1), + dtype=expert_in.dtype, + device=expert_in.device, + ) + out = experts(expert_in, top_k_index, top_k_weights) + if out.dtype != input_dtype: + out = out.to(input_dtype) + return out + gate_up = experts.gate_up_proj[expert_id] down = experts.down_proj[expert_id] compute_dtype = gate_up.dtype @@ -383,3 +415,32 @@ def _run_router( if norm_topk_prob: routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) return router_logits, routing_weights, selected_experts, True + + +def _debug_log_ep_fsdp_runtime_once(block: nn.Module, gate: nn.Module, hidden_states: torch.Tensor) -> None: + if os.environ.get("TWINKLE_EP_FSDP_DEBUG", "1") != "1": + return + if not getattr(block, "_ep_fsdp_enabled", False): + return + if getattr(block, "_ep_fsdp_runtime_logged", False): + return + rank = dist.get_rank() if dist.is_initialized() else -1 + experts = getattr(block, "experts", None) + expert_param = next(experts.parameters(), None) if experts is not None else None + gate_param = next(gate.parameters(), None) if gate is not None else None + print( + f"[EP_FSDP][rank{rank}] runtime input={hidden_states.device} " + f"expert_param={_describe_param_mesh(expert_param)} " + f"gate_param={_describe_param_mesh(gate_param)}", + flush=True, + ) + block._ep_fsdp_runtime_logged = True + + +def _describe_param_mesh(param: Optional[nn.Parameter]) -> str: + if param is None: + return "none" + mesh = getattr(param, "device_mesh", None) + if mesh is None: + return f"local:{getattr(param, 'device', 'unknown')}" + return f"dtensor:{tuple(mesh.mesh_dim_names or ())}:{mesh.mesh.flatten().tolist()}" diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index b2a94faa..1a4d69c5 100644 --- 
a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -1,4 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import os from typing import Dict, Any, Optional, Literal, Set, TYPE_CHECKING import torch @@ -30,6 +31,10 @@ def wrap_model(self, model, optimizer=None): from torch.distributed.fsdp import fully_shard fsdp_mesh = _build_fsdp_mesh(self.device_mesh) if fsdp_mesh is not None: + ep_fsdp_mode = _is_ep_fsdp_mode_enabled( + self.device_mesh, + self.enable_ep, + ) if self.enable_ep: _ensure_moe_patched_if_needed(model, self.device_mesh) _place_ep_experts_on_local_device(model, self.device_mesh) @@ -37,6 +42,21 @@ def wrap_model(self, model, optimizer=None): reshard_after_forward = self.fsdp_config.get("reshard_after_forward", True) ignored_params = _collect_expert_params(model) if self.enable_ep else None + if ep_fsdp_mode: + _ensure_ep_fsdp_supported(model) + ep_fsdp_mesh = _build_ep_fsdp_mesh(self.device_mesh) + if ep_fsdp_mesh is None: + raise RuntimeError( + "Implicit EP_FSDP requires dp dim with size > 1, but could not build an ep_fsdp mesh from dp." 
+ ) + sharded_blocks = _maybe_shard_ep_expert_blocks( + model, + mesh=ep_fsdp_mesh, + reshard_after_forward=reshard_after_forward, + mp_policy=mp_policy, + ) + _debug_log_ep_fsdp_sharding(model, self.device_mesh, sharded_blocks) + _maybe_shard_layers( model, mesh=fsdp_mesh, @@ -86,6 +106,21 @@ def _build_fsdp_mesh(device_mesh: DeviceMesh) -> Optional[TorchDeviceMesh]: return TorchDeviceMesh(device_mesh.device_type, flat_mesh, mesh_dim_names=("fsdp",)) +def _build_ep_fsdp_mesh(device_mesh: DeviceMesh) -> Optional[TorchDeviceMesh]: + if device_mesh is None or not device_mesh.has_dim("dp"): + return None + ranks = device_mesh.get_ranks_for_dims("dp") + if len(ranks) <= 1: + return None + return TorchDeviceMesh(device_mesh.device_type, ranks, mesh_dim_names=("ep_fsdp",)) + + +def _is_ep_fsdp_mode_enabled(device_mesh: Optional[DeviceMesh], enable_ep: bool) -> bool: + if not enable_ep or device_mesh is None: + return False + return device_mesh.is_implicit_ep_fsdp_enabled() + + def _collect_expert_params(model: nn.Module) -> Optional[Set[nn.Parameter]]: ignored: Set[nn.Parameter] = set() ep_patched = False @@ -140,6 +175,58 @@ def _ensure_moe_patched_if_needed(model: nn.Module, device_mesh: DeviceMesh) -> ) +def _ensure_ep_fsdp_supported(model: nn.Module) -> None: + for module in model.modules(): + if not getattr(module, "_ep_patched", False): + continue + experts = getattr(module, "experts", None) + if isinstance(experts, nn.ModuleList): + raise NotImplementedError( + "EP+EP_FSDP currently does not support MoE experts stored as nn.ModuleList. " + "Only tensor experts (gate_up_proj/down_proj) are supported." 
+ ) + + +def _maybe_shard_ep_expert_blocks(model: nn.Module, + *, + mesh: TorchDeviceMesh, + reshard_after_forward: Optional[bool], + mp_policy: 'MixedPrecisionPolicy') -> int: + from torch.distributed.fsdp import fully_shard + from torch.distributed.tensor import Shard + sharded_blocks = 0 + for module in model.modules(): + if not getattr(module, "_ep_patched", False): + continue + experts = getattr(module, "experts", None) + if experts is None: + continue + # Correct EP+EP_FSDP behavior: only experts are sharded on ep_fsdp mesh. + # Non-expert params (router/gate etc.) are left to global FSDP wrapping. + fully_shard( + experts, + mesh=mesh, + reshard_after_forward=reshard_after_forward, + mp_policy=mp_policy, + shard_placement_fn=lambda param: Shard(1), + ) + sharded_blocks += 1 + return sharded_blocks + + +def _debug_log_ep_fsdp_sharding(model: nn.Module, device_mesh: DeviceMesh, sharded_blocks: int) -> None: + if os.environ.get("TWINKLE_EP_FSDP_DEBUG", "0") != "1": + return + + rank = Platform.get_rank() + ep_fsdp_ranks = device_mesh.get_ranks_for_dims("dp") + print( + f"[EP_FSDP][rank{rank}] enabled=1 sharded_blocks={sharded_blocks} " + f"ep_fsdp_group={ep_fsdp_ranks}", + flush=True, + ) + + def _maybe_shard_layers(model: nn.Module, *, mesh: TorchDeviceMesh, diff --git a/src/twinkle/utils/grad_clip.py b/src/twinkle/utils/grad_clip.py index 97134d3c..b9d7fb70 100644 --- a/src/twinkle/utils/grad_clip.py +++ b/src/twinkle/utils/grad_clip.py @@ -32,7 +32,23 @@ def normalize_and_clip_grad_norm(parameters: Iterable['torch.nn.Parameter'], has_dtensor_grad = any(hasattr(grad, "to_local") for grad in grads) has_local_tensor_grad = any(not hasattr(grad, "to_local") for grad in grads) - if not (has_dtensor_grad and has_local_tensor_grad): + dtensor_mesh_keys = set() + for grad in grads: + if not hasattr(grad, "to_local"): + continue + mesh = getattr(grad, "device_mesh", None) + if mesh is None: + dtensor_mesh_keys.add("dtensor:unknown") + continue + try: + mesh_key = 
(tuple(mesh.mesh.flatten().tolist()), tuple(mesh.mesh_dim_names or ())) + except Exception: + mesh_key = repr(mesh) + dtensor_mesh_keys.add(mesh_key) + + has_mixed_dtensor_mesh = len(dtensor_mesh_keys) > 1 + + if not (has_dtensor_grad and has_local_tensor_grad) and not has_mixed_dtensor_mesh: grad_norm = torch.nn.utils.clip_grad_norm_( parameters, max_grad_norm, @@ -62,6 +78,11 @@ def _local_grad(grad: torch.Tensor) -> torch.Tensor: reduce_device = torch.device(Platform.get_local_device()) else: reduce_device = torch.device("cpu") + reduce_group = group + if has_mixed_dtensor_mesh: + # Different DTensor meshes cannot be reduced by DTensor op propagation (e.g. aten.stack). + # Fall back to world reduction over local shards. + reduce_group = None if norm_type == float("inf"): local_norm = 0.0 @@ -72,7 +93,7 @@ def _local_grad(grad: torch.Tensor) -> torch.Tensor: local_norm = max(local_norm, local_grad.detach().abs().max().item()) total_norm_tensor = torch.tensor(local_norm, device=reduce_device, dtype=torch.float32) if dist.is_initialized(): - dist.all_reduce(total_norm_tensor, op=dist.ReduceOp.MAX, group=group) + dist.all_reduce(total_norm_tensor, op=dist.ReduceOp.MAX, group=reduce_group) total_norm = float(total_norm_tensor.item()) else: local_sq = 0.0 @@ -83,7 +104,7 @@ def _local_grad(grad: torch.Tensor) -> torch.Tensor: local_sq += local_grad.detach().float().pow(2).sum().item() total_sq_tensor = torch.tensor(local_sq, device=reduce_device, dtype=torch.float32) if dist.is_initialized(): - dist.all_reduce(total_sq_tensor, op=dist.ReduceOp.SUM, group=group) + dist.all_reduce(total_sq_tensor, op=dist.ReduceOp.SUM, group=reduce_group) total_norm = float(total_sq_tensor.sqrt().item()) clip_coef = float(max_grad_norm) / (total_norm + 1e-6) diff --git a/src/twinkle/utils/platform.py b/src/twinkle/utils/platform.py index 79e30c9e..c958b748 100644 --- a/src/twinkle/utils/platform.py +++ b/src/twinkle/utils/platform.py @@ -8,10 +8,10 @@ import subprocess from abc import 
ABC from dataclasses import dataclass, field -from functools import lru_cache from itertools import product -from typing import Dict, List, Optional, Type, Union - +from functools import lru_cache +from typing import Any, Dict, List, Optional, Type, Union +import torch.distributed as dist import numpy as np @@ -50,7 +50,7 @@ class DeviceMesh: @staticmethod def from_sizes(*, world_size: int = 1, dp_size: int = 1, fsdp_size: int = None, tp_size: int = None, pp_size: int = None, ulysses_size: int = None, cp_size: int = None, ep_size: int = None, - etp_size: int = 1, vpp_size: int = None, device_type: str = 'cuda', sequence_parallel: bool = False) -> "DeviceMesh": + etp_size: int = None, vpp_size: int = None, device_type: str = 'cuda', sequence_parallel: bool = False) -> "DeviceMesh": """Create a default device mesh from the given sizes. Args: @@ -180,6 +180,50 @@ def get_dim_group(self, dims): key = tuple(c for i, c in enumerate(coord) if i != dim_idx) return group_map[key] + def get_ranks_for_dims(self, dims): + if self.mesh_dim_names is None: + raise ValueError("mesh_dim_names is not set.") + if isinstance(dims, str): + dims = (dims,) + for dim_name in dims: + if dim_name not in self.mesh_dim_names: + raise ValueError( + f"Dimension '{dim_name}' not found in mesh. 
Available: {self.mesh_dim_names}" + ) + + coord = self._get_coord() + if coord is None: + raise RuntimeError("Current rank is not found in mesh.") + + slices = [] + for i, dim_name in enumerate(self.mesh_dim_names): + if dim_name in dims: + slices.append(slice(None)) + else: + slices.append(coord[i]) + return sorted(self.mesh[tuple(slices)].flatten().tolist()) + + def is_implicit_ep_fsdp_enabled(self) -> bool: + ep_world_size = self.ep_world_size or 1 + dp_world_size = self.dp_world_size or 1 + if ep_world_size <= 1 or dp_world_size <= 1: + return False + + world_size = self.world_size or 1 + if world_size % ep_world_size != 0: + raise ValueError( + f"world_size ({world_size}) must be divisible by ep_world_size ({ep_world_size}) " + "to infer implicit EP_FSDP from dp." + ) + expected_dp_size = world_size // ep_world_size + if dp_world_size != expected_dp_size: + raise ValueError( + f"Implicit EP_FSDP requires dp_world_size == world_size // ep_world_size, " + f"but got dp_world_size={dp_world_size}, world_size={world_size}, " + f"ep_world_size={ep_world_size}." + ) + return True + @property def order(self): """The order of the dimensions for megatron""" @@ -351,16 +395,16 @@ def data_world_size(self) -> int: """Consider all dp/fsdp ranks, uses to determine how to distribute the data""" dp_world_size = self.dp_world_size fsdp_world_size = self.fsdp_world_size - ulysses_size = self.ulysses_size or 1 if fsdp_world_size is not None and fsdp_world_size > 1: - data_world_size = dp_world_size * fsdp_world_size if dp_world_size is not None else fsdp_world_size - else: - data_world_size = dp_world_size if dp_world_size is not None else 1 + if dp_world_size is not None: + return dp_world_size * fsdp_world_size + else: + return fsdp_world_size + + ulysses_size = self.ulysses_size or 1 + assert dp_world_size % ulysses_size == 0, f'dp_world_size: {dp_world_size} cannot be divided by ulysses_size: {ulysses_size}.' 
+ return dp_world_size // ulysses_size - assert data_world_size % ulysses_size == 0, ( - f'data_world_size: {data_world_size} cannot be divided by ulysses_size: {ulysses_size}.' - ) - return data_world_size // ulysses_size def get_slice(self, total_length: int, rank: Optional[int] = None) -> slice: world_size = self.data_world_size if world_size == 1: @@ -376,57 +420,6 @@ def get_slice(self, total_length: int, rank: Optional[int] = None) -> slice: end = (rank + 1) * k + min(rank + 1, m) return slice(start, end) - def get_tp_ranks(self) -> List[int]: - """Get all ranks in the same TP group as the current rank.""" - rank = Platform.get_rank() - if not self._has_dim("tp"): - return [rank] - - tp_idx = self._get_dim_index("tp") - coords = self._get_coord_for_rank(rank) - - if coords is None: - return [] - - slices = [] - for i, dim_val in enumerate(coords): - if i == tp_idx: - slices.append(slice(None)) - else: - slices.append(dim_val) - - return sorted(self.mesh[tuple(slices)].flatten().tolist()) - - def get_tp_last_ranks(self) -> List[int]: - """Get a list of all ranks that are the last rank in their respective TP group.""" - if not self._has_dim("tp"): - return self.mesh.flatten().tolist() - - tp_idx = self._get_dim_index("tp") - tp_size = self.mesh.shape[tp_idx] - - slices = [slice(None)] * self.mesh.ndim - slices[tp_idx] = tp_size - 1 - - return sorted(self.mesh[tuple(slices)].flatten().tolist()) - - def is_tp_last_rank(self, rank: Optional[int] = None) -> bool: - """Check if the given rank is the last rank in its TP group.""" - if rank is None: - rank = Platform.get_rank() - - if not self._has_dim("tp"): - return True - - tp_idx = self._get_dim_index("tp") - coords = self._get_coord_for_rank(rank) - - if coords is None: - return False - - tp_size = self.mesh.shape[tp_idx] - return coords[tp_idx] == tp_size - 1 - def is_pp_first_rank(self) -> bool: pp_ranks = self.get_pp_first_ranks() if pp_ranks is None: @@ -486,6 +479,7 @@ class DeviceGroup: name: str ranks: 
Union[List[int], int] device_type: str + visible_devices: Optional[str] = None # Optional: explicitly set visible devices (e.g., "8,9") gpus_per_worker: int = 1 _device_mesh: Dict[str, DeviceMesh] = field(default_factory=dict) @@ -502,17 +496,63 @@ def _ensure_npu_backend() -> None: ) from exc @staticmethod - def visible_device_env(platform: str = None) -> str: - return Platform.get_platform(platform).visible_device_env() + def visible_device_env() -> str: + return Platform.get_platform().visible_device_env() @staticmethod - def device_prefix(platform: str = None) -> str: - return Platform.get_platform(platform).device_prefix() + def device_prefix() -> str: + return Platform.get_platform().device_prefix() @staticmethod def get_platform_names() -> List[str]: return ['GPU', 'NPU', 'MPS'] + @staticmethod + def resolve_visible_devices( + device_type: str, + *, + explicit: Any = None, + env_values: Optional[List[Any]] = None, + include_os_env: bool = True, + ) -> Optional[str]: + def _normalize(value: Any) -> Optional[str]: + if value is None: + return None + if isinstance(value, (list, tuple)): + return ','.join(str(v) for v in value) + if isinstance(value, int): + return str(value) + if isinstance(value, str): + return value + return None + + if not device_type: + return None + if device_type.upper() == "CPU": + return None + + try: + visible_env = Platform.get_platform(device_type.upper()).visible_device_env() + except Exception: + visible_env = None + if not visible_env: + return None + + normalized = _normalize(explicit) + if normalized: + return normalized + + if env_values: + for value in env_values: + normalized = _normalize(value) + if normalized: + return normalized + + if include_os_env: + return _normalize(os.environ.get(visible_env)) + + return None + @staticmethod def get_platform(platform: str = None) -> Type['Platform']: if platform is None: diff --git a/tests/DeviceMesh/test_device_mesh.py b/tests/DeviceMesh/test_device_mesh.py index 51a12d22..05c2a3fc 
100644 --- a/tests/DeviceMesh/test_device_mesh.py +++ b/tests/DeviceMesh/test_device_mesh.py @@ -183,6 +183,24 @@ def test_world_sizes(self): assert mesh.pp_world_size == 5 assert mesh.world_size == 2 * 3 * 4 * 5 + def test_ep_fsdp_rank_and_world_size(self): + mesh = DeviceMesh.from_sizes(dp_size=1, ep_size=2, ep_fsdp_size=3) + mesh_array = mesh.mesh.reshape(1, 2, 3) + + for ep_idx in range(2): + for ep_fsdp_idx in range(3): + global_rank = int(mesh_array[0, ep_idx, ep_fsdp_idx]) + with patch.object(Platform, 'get_rank', return_value=global_rank): + assert mesh.ep_rank == ep_idx + assert mesh.ep_fsdp_rank == ep_fsdp_idx + assert mesh.ep_world_size == 2 + assert mesh.ep_fsdp_world_size == 3 + + def test_without_dp_dimension(self): + mesh = DeviceMesh.from_sizes(dp_size=None, ep_size=2, ep_fsdp_size=2) + assert mesh.mesh.shape == (2, 2) + assert mesh.mesh_dim_names == ('ep', 'ep_fsdp') + def test_data_rank_with_dp_only(self): mesh = DeviceMesh.from_sizes(dp_size=4)