6 changes: 4 additions & 2 deletions cookbook/legacy/sft/ep_fsdp_qwen3_moe.py
@@ -20,11 +20,12 @@
PROCESSOR_ID = os.environ.get('QWEN3_PROCESSOR_ID', 'AlpacaProcessor')
NUM_LAYERS = int(os.environ.get('QWEN3_NUM_LAYERS', '1'))

-# 4 GPUs: dp=2, ep=2
+# 4 GPUs: ep=2, ep_fsdp=2
device_mesh = DeviceMesh.from_sizes(
device_type=Platform.get_platform().device_prefix(),
-dp_size=2,
+dp_size=None,
ep_size=2,
ep_fsdp_size=2,
)

twinkle.initialize(
@@ -80,6 +81,7 @@ def train():
"router_dtype": "fp32",
"all_to_all": "torch",
"keep_router_logits": False,
"ep_fsdp": True,
}
},
)
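As an aside, a minimal sketch of how the 4-GPU factorisation above could map ranks to groups. The ep-major ordering here is an assumption; the actual layout is decided by DeviceMesh.from_sizes.

# Hypothetical layout: 4 ranks factored into ep=2 x ep_fsdp=2.
# Ranks with the same ep_rank own the same expert slice and together
# form the ep_fsdp group that FSDP-shards that slice.
ep_size, ep_fsdp_size = 2, 2
for rank in range(ep_size * ep_fsdp_size):
    ep_rank, ep_fsdp_rank = rank % ep_size, rank // ep_size
    print(f"rank{rank}: ep_rank={ep_rank}, ep_fsdp_rank={ep_fsdp_rank}")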
65 changes: 63 additions & 2 deletions src/twinkle/model/transformers/moe/expert_parallel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from __future__ import annotations

import os
from dataclasses import dataclass
from typing import Any, Dict, Iterable, Optional, Tuple

@@ -21,6 +22,8 @@ class ExpertParallelConfig:
keep_router_logits: bool = True
pad_to_max: bool = False
ignore_shared_experts: bool = False
# Deprecated: EP_FSDP is inferred implicitly from dp/ep mesh topology.
ep_fsdp: bool = False


def apply_expert_parallel(model: nn.Module, device_mesh: DeviceMesh, config: Optional[Dict[str, Any]] = None):
@@ -32,6 +35,8 @@ def apply_expert_parallel(model: nn.Module, device_mesh: DeviceMesh, config: Opt
if ep_world_size <= 1:
return model

ep_fsdp_enabled = device_mesh.is_implicit_ep_fsdp_enabled()

if cfg.pad_to_max:
raise NotImplementedError("pad_to_max is not implemented.")
if cfg.all_to_all != "torch":
@@ -45,7 +50,7 @@ def apply_expert_parallel(model: nn.Module, device_mesh: DeviceMesh, config: Opt
raise RuntimeError("EP process group is not available in device_mesh.")

for block in find_moe_blocks(model):
-shard_experts(block, device_mesh, cfg)
+shard_experts(block, device_mesh, cfg, ep_fsdp_enabled=ep_fsdp_enabled)
patch_forward(block, device_mesh, cfg)

return model
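The implicit check is provided by twinkle's DeviceMesh. A rough sketch of the inference the deprecated ep_fsdp flag alludes to, assuming it only keys off mesh topology (assumed behavior, not the actual implementation):

def is_implicit_ep_fsdp_enabled_sketch(mesh) -> bool:
    # EP_FSDP is implied when experts are parallelised (ep > 1) and a
    # dp dimension with more than one rank is available to shard them.
    return (mesh.ep_world_size > 1
            and mesh.has_dim("dp")
            and len(mesh.get_ranks_for_dims("dp")) > 1)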
@@ -76,7 +81,7 @@ def find_moe_blocks(model: nn.Module) -> Iterable[nn.Module]:
return blocks


-def shard_experts(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallelConfig) -> None:
+def shard_experts(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallelConfig, *, ep_fsdp_enabled: bool) -> None:
num_experts = _get_num_experts(block)
ep_world_size = device_mesh.ep_world_size
ep_rank = device_mesh.ep_rank
@@ -91,6 +96,11 @@ def shard_experts(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallel
local_end = local_start + experts_per_rank

if isinstance(block.experts, nn.ModuleList):
if ep_fsdp_enabled:
raise NotImplementedError(
"EP+EP_FSDP currently does not support MoE experts stored as nn.ModuleList. "
"Only tensor experts (gate_up_proj/down_proj) are supported."
)
local_experts = nn.ModuleList(block.experts[local_start:local_end])
block.experts = local_experts
block._ep_tensor_experts = False
@@ -105,6 +115,7 @@ def shard_experts(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallel
block._ep_rank = ep_rank
block._ep_world_size = ep_world_size
block._ep_ignore_shared_experts = cfg.ignore_shared_experts
block._ep_fsdp_enabled = ep_fsdp_enabled
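The slicing above is plain contiguous partitioning. A worked example with hypothetical sizes:

# num_experts=8, ep_world_size=2: rank 0 keeps experts [0, 4), rank 1 keeps [4, 8).
num_experts, ep_world_size = 8, 2
experts_per_rank = num_experts // ep_world_size
for ep_rank in range(ep_world_size):
    local_start = ep_rank * experts_per_rank
    local_end = local_start + experts_per_rank
    print(ep_rank, list(range(local_start, local_end)))  # 0 [0, 1, 2, 3] / 1 [4, 5, 6, 7]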


def patch_forward(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallelConfig) -> None:
@@ -138,6 +149,8 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs):
else:
raise ValueError(f"Unsupported hidden_states ndim: {hidden_states.ndim}")

_debug_log_ep_fsdp_runtime_once(block, gate, hidden_states_2d)

router_logits, routing_weights, selected_experts, cast_weights = _run_router(
gate=gate,
hidden_states=hidden_states_2d,
@@ -333,6 +346,25 @@ def _run_expert(block: nn.Module, expert_id: int, expert_in: torch.Tensor) -> to
expert = block.experts[expert_id]
return _run_module_with_casting(expert, expert_in)
experts = block.experts
if getattr(block, "_ep_fsdp_enabled", False):
# In EP+EP_FSDP mode, execute experts.forward so FSDP hooks can
# manage unshard/reshard around forward/backward safely.
top_k_index = torch.full(
(expert_in.shape[0], 1),
int(expert_id),
dtype=torch.long,
device=expert_in.device,
)
top_k_weights = torch.ones(
(expert_in.shape[0], 1),
dtype=expert_in.dtype,
device=expert_in.device,
)
out = experts(expert_in, top_k_index, top_k_weights)
if out.dtype != input_dtype:
out = out.to(input_dtype)
return out

gate_up = experts.gate_up_proj[expert_id]
down = experts.down_proj[expert_id]
compute_dtype = gate_up.dtype
Expand Down Expand Up @@ -383,3 +415,32 @@ def _run_router(
if norm_topk_prob:
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
return router_logits, routing_weights, selected_experts, True
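The tail of _run_router is the usual top-k renormalisation. Numerically, with hypothetical top-2 weights for a single token:

import torch

routing_weights = torch.tensor([[0.5, 0.3]])  # top-2 softmax mass, rest discarded
# norm_topk_prob rescales the kept probabilities to sum to 1:
normalized = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
print(normalized)  # tensor([[0.6250, 0.3750]])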


def _debug_log_ep_fsdp_runtime_once(block: nn.Module, gate: nn.Module, hidden_states: torch.Tensor) -> None:
if os.environ.get("TWINKLE_EP_FSDP_DEBUG", "1") != "1":
return
if not getattr(block, "_ep_fsdp_enabled", False):
return
if getattr(block, "_ep_fsdp_runtime_logged", False):
return
rank = dist.get_rank() if dist.is_initialized() else -1
experts = getattr(block, "experts", None)
expert_param = next(experts.parameters(), None) if experts is not None else None
gate_param = next(gate.parameters(), None) if gate is not None else None
print(
f"[EP_FSDP][rank{rank}] runtime input={hidden_states.device} "
f"expert_param={_describe_param_mesh(expert_param)} "
f"gate_param={_describe_param_mesh(gate_param)}",
flush=True,
)
block._ep_fsdp_runtime_logged = True


def _describe_param_mesh(param: Optional[nn.Parameter]) -> str:
if param is None:
return "none"
mesh = getattr(param, "device_mesh", None)
if mesh is None:
return f"local:{getattr(param, 'device', 'unknown')}"
return f"dtensor:{tuple(mesh.mesh_dim_names or ())}:{mesh.mesh.flatten().tolist()}"
87 changes: 87 additions & 0 deletions src/twinkle/model/transformers/strategy/native_fsdp.py
@@ -1,4 +1,5 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import os
from typing import Dict, Any, Optional, Literal, Set, TYPE_CHECKING

import torch
@@ -30,13 +31,32 @@ def wrap_model(self, model, optimizer=None):
from torch.distributed.fsdp import fully_shard
fsdp_mesh = _build_fsdp_mesh(self.device_mesh)
if fsdp_mesh is not None:
ep_fsdp_mode = _is_ep_fsdp_mode_enabled(
self.device_mesh,
self.enable_ep,
)
if self.enable_ep:
_ensure_moe_patched_if_needed(model, self.device_mesh)
_place_ep_experts_on_local_device(model, self.device_mesh)
mp_policy = _build_mp_policy(self.mixed_precision)
reshard_after_forward = self.fsdp_config.get("reshard_after_forward", True)
ignored_params = _collect_expert_params(model) if self.enable_ep else None

if ep_fsdp_mode:
_ensure_ep_fsdp_supported(model)
ep_fsdp_mesh = _build_ep_fsdp_mesh(self.device_mesh)
if ep_fsdp_mesh is None:
raise RuntimeError(
"Implicit EP_FSDP requires dp dim with size > 1, but could not build an ep_fsdp mesh from dp."
)
sharded_blocks = _maybe_shard_ep_expert_blocks(
model,
mesh=ep_fsdp_mesh,
reshard_after_forward=reshard_after_forward,
mp_policy=mp_policy,
)
_debug_log_ep_fsdp_sharding(model, self.device_mesh, sharded_blocks)

_maybe_shard_layers(
model,
mesh=fsdp_mesh,
@@ -86,6 +106,21 @@ def _build_fsdp_mesh(device_mesh: DeviceMesh) -> Optional[TorchDeviceMesh]:
return TorchDeviceMesh(device_mesh.device_type, flat_mesh, mesh_dim_names=("fsdp",))


def _build_ep_fsdp_mesh(device_mesh: DeviceMesh) -> Optional[TorchDeviceMesh]:
if device_mesh is None or not device_mesh.has_dim("dp"):
return None
ranks = device_mesh.get_ranks_for_dims("dp")
if len(ranks) <= 1:
return None
return TorchDeviceMesh(device_mesh.device_type, ranks, mesh_dim_names=("ep_fsdp",))
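For illustration, get_ranks_for_dims("dp") returns the calling rank's dp peer group, which doubles as its ep_fsdp group. A sketch assuming an ep-major rank order on 4 GPUs with ep=2 (the ordering is an assumption):

def dp_peers(rank: int, ep_size: int, world_size: int) -> list:
    # Ranks sharing an ep_rank own the same experts and therefore
    # form one ep_fsdp (expert-FSDP) group.
    return [r for r in range(world_size) if r % ep_size == rank % ep_size]

print(dp_peers(0, ep_size=2, world_size=4))  # [0, 2] (illustrative)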


def _is_ep_fsdp_mode_enabled(device_mesh: Optional[DeviceMesh], enable_ep: bool) -> bool:
if not enable_ep or device_mesh is None:
return False
return device_mesh.is_implicit_ep_fsdp_enabled()


def _collect_expert_params(model: nn.Module) -> Optional[Set[nn.Parameter]]:
ignored: Set[nn.Parameter] = set()
ep_patched = False
@@ -140,6 +175,58 @@ def _ensure_moe_patched_if_needed(model: nn.Module, device_mesh: DeviceMesh) ->
)


def _ensure_ep_fsdp_supported(model: nn.Module) -> None:
for module in model.modules():
if not getattr(module, "_ep_patched", False):
continue
experts = getattr(module, "experts", None)
if isinstance(experts, nn.ModuleList):
raise NotImplementedError(
"EP+EP_FSDP currently does not support MoE experts stored as nn.ModuleList. "
"Only tensor experts (gate_up_proj/down_proj) are supported."
)


def _maybe_shard_ep_expert_blocks(model: nn.Module,
*,
mesh: TorchDeviceMesh,
reshard_after_forward: Optional[bool],
mp_policy: 'MixedPrecisionPolicy') -> int:
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import Shard
sharded_blocks = 0
for module in model.modules():
if not getattr(module, "_ep_patched", False):
continue
experts = getattr(module, "experts", None)
if experts is None:
continue
# Correct EP+EP_FSDP behavior: only experts are sharded on ep_fsdp mesh.
# Non-expert params (router/gate etc.) are left to global FSDP wrapping.
fully_shard(
experts,
mesh=mesh,
reshard_after_forward=reshard_after_forward,
mp_policy=mp_policy,
shard_placement_fn=lambda param: Shard(1),
)
sharded_blocks += 1
return sharded_blocks
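shard_placement_fn pins FSDP's shard dimension to dim 1 instead of the default dim 0. A minimal sketch of the resulting layout, assuming tensor-expert weights shaped [num_local_experts, in_features, out_features]; the shapes and the motivation are assumptions, the PR only states that experts are sharded on the ep_fsdp mesh:

import torch

w = torch.randn(4, 8, 16)  # hypothetical gate_up_proj: 4 EP-local experts
# Shard(1) splits each expert's weight rows across 2 ep_fsdp ranks while
# keeping the expert dimension (dim 0) whole on every rank.
shards = w.chunk(2, dim=1)
print([tuple(s.shape) for s in shards])  # [(4, 4, 16), (4, 4, 16)]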


def _debug_log_ep_fsdp_sharding(model: nn.Module, device_mesh: DeviceMesh, sharded_blocks: int) -> None:
if os.environ.get("TWINKLE_EP_FSDP_DEBUG", "0") != "1":
return

rank = Platform.get_rank()
ep_fsdp_ranks = device_mesh.get_ranks_for_dims("dp")
print(
f"[EP_FSDP][rank{rank}] enabled=1 sharded_blocks={sharded_blocks} "
f"ep_fsdp_group={ep_fsdp_ranks}",
flush=True,
)


def _maybe_shard_layers(model: nn.Module,
*,
mesh: TorchDeviceMesh,
27 changes: 24 additions & 3 deletions src/twinkle/utils/grad_clip.py
@@ -32,7 +32,23 @@ def normalize_and_clip_grad_norm(parameters: Iterable['torch.nn.Parameter'],

has_dtensor_grad = any(hasattr(grad, "to_local") for grad in grads)
has_local_tensor_grad = any(not hasattr(grad, "to_local") for grad in grads)
-if not (has_dtensor_grad and has_local_tensor_grad):
dtensor_mesh_keys = set()
for grad in grads:
if not hasattr(grad, "to_local"):
continue
mesh = getattr(grad, "device_mesh", None)
if mesh is None:
dtensor_mesh_keys.add("dtensor:unknown")
continue
try:
mesh_key = (tuple(mesh.mesh.flatten().tolist()), tuple(mesh.mesh_dim_names or ()))
except Exception:
mesh_key = repr(mesh)
dtensor_mesh_keys.add(mesh_key)

has_mixed_dtensor_mesh = len(dtensor_mesh_keys) > 1

if not (has_dtensor_grad and has_local_tensor_grad) and not has_mixed_dtensor_mesh:
grad_norm = torch.nn.utils.clip_grad_norm_(
parameters,
max_grad_norm,
Expand Down Expand Up @@ -62,6 +78,11 @@ def _local_grad(grad: torch.Tensor) -> torch.Tensor:
reduce_device = torch.device(Platform.get_local_device())
else:
reduce_device = torch.device("cpu")
reduce_group = group
if has_mixed_dtensor_mesh:
# Different DTensor meshes cannot be reduced by DTensor op propagation (e.g. aten.stack).
# Fall back to world reduction over local shards.
reduce_group = None

if norm_type == float("inf"):
local_norm = 0.0
@@ -72,7 +93,7 @@ def _local_grad(grad: torch.Tensor) -> torch.Tensor:
local_norm = max(local_norm, local_grad.detach().abs().max().item())
total_norm_tensor = torch.tensor(local_norm, device=reduce_device, dtype=torch.float32)
if dist.is_initialized():
-dist.all_reduce(total_norm_tensor, op=dist.ReduceOp.MAX, group=group)
+dist.all_reduce(total_norm_tensor, op=dist.ReduceOp.MAX, group=reduce_group)
total_norm = float(total_norm_tensor.item())
else:
local_sq = 0.0
@@ -83,7 +104,7 @@ def _local_grad(grad: torch.Tensor) -> torch.Tensor:
local_sq += local_grad.detach().float().pow(2).sum().item()
total_sq_tensor = torch.tensor(local_sq, device=reduce_device, dtype=torch.float32)
if dist.is_initialized():
-dist.all_reduce(total_sq_tensor, op=dist.ReduceOp.SUM, group=group)
+dist.all_reduce(total_sq_tensor, op=dist.ReduceOp.SUM, group=reduce_group)
total_norm = float(total_sq_tensor.sqrt().item())

clip_coef = float(max_grad_norm) / (total_norm + 1e-6)
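The fallback reduces local shards over the world group (reduce_group = None). Assuming DTensor shards are disjoint (replicated placements would need de-duplication), each element contributes exactly once to the global sum, so the 2-norm is exact. A minimal numeric sketch, with the usual clamp of clip_coef to 1.0 assumed for the elided tail of the function:

import math

shard_sq = [3.0 ** 2, 4.0 ** 2]            # per-rank sums of squared grads
total_norm = math.sqrt(sum(shard_sq))      # all_reduce(SUM) then sqrt -> 5.0
clip_coef = 1.0 / (total_norm + 1e-6)      # max_grad_norm = 1.0
print(total_norm, min(clip_coef, 1.0))     # 5.0 0.19999996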