From c6ace7fd88556e945c3bce61f8e982dff0a891ef Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Wed, 11 Feb 2026 15:41:47 +0800 Subject: [PATCH 1/6] feat(tests): replace manual sp_group retrieval with module attribute Replace calls to `_get_sp_group_from_device_mesh` with direct access to `sequence_parallel._sp_group` in sequence parallel attention tests. This simplifies the test setup by using the already initialized group stored in the module, improving code clarity and reducing redundancy. --- .../test_sequence_parallel_single_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/sequence_parallel/test_sequence_parallel_single_attention.py b/tests/sequence_parallel/test_sequence_parallel_single_attention.py index dde6b387..32e01aaa 100644 --- a/tests/sequence_parallel/test_sequence_parallel_single_attention.py +++ b/tests/sequence_parallel/test_sequence_parallel_single_attention.py @@ -181,7 +181,7 @@ def _run_worker_single_attn(rank: int, world_size: int, port: int, padding: bool sp_size = world_size device_mesh = DeviceMesh.from_sizes(dp_size=world_size, ulysses_size=sp_size, device_type="cuda") _setup_sp(device_mesh, sp_size) - sp_group = _get_sp_group_from_device_mesh(device_mesh, sp_size) + sp_group = sequence_parallel._sp_group batch_size = 2 unpad_seq_len = 127 if padding else 128 @@ -271,7 +271,7 @@ def _run_worker_single_attn_fsdp(rank: int, world_size: int, port: int): # For FSDP+SP, SP is derived from dp/fsdp ranks. Use fsdp=world, dp=1. device_mesh = DeviceMesh.from_sizes(fsdp_size=world_size, dp_size=1, ulysses_size=sp_size, device_type="cuda") _setup_sp(device_mesh, sp_size) - sp_group = _get_sp_group_from_device_mesh(device_mesh, sp_size) + sp_group = sequence_parallel._sp_group batch_size = 2 unpad_seq_len = 128 From 939466f1aafe8d5a213db8f2edeecdb353459f68 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Fri, 13 Feb 2026 15:43:14 +0800 Subject: [PATCH 2/6] feat(tests): improve kernel availability check in test_function_kernel Add additional imports and a try-except block to verify that the 'kernels-test/flattened-build' kernel can be successfully loaded in the current environment before proceeding with the test. This prevents test failures due to environment-specific loading issues and provides a more informative skip message. 
--- tests/kernel/test_function_kernel.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/kernel/test_function_kernel.py b/tests/kernel/test_function_kernel.py index 66f375ee..fe95bafa 100644 --- a/tests/kernel/test_function_kernel.py +++ b/tests/kernel/test_function_kernel.py @@ -54,10 +54,21 @@ def test_flattened_build_replaces_function(self): self.skipTest(f'HuggingFace unreachable: {e}') try: from kernels import has_kernel + from kernels._versions import select_revision_or_version + from kernels.utils import get_kernel except Exception: self.skipTest('kernels package missing has_kernel.') if not has_kernel('kernels-test/flattened-build'): self.skipTest('kernels-test/flattened-build not available.') + try: + revision = select_revision_or_version( + 'kernels-test/flattened-build', + revision=None, + version=None, + ) + get_kernel('kernels-test/flattened-build', revision=revision) + except Exception as exc: + self.skipTest(f'kernels-test/flattened-build cannot be loaded in this env: {exc}') _ensure_test_packages() module_name = 'tests.kernel._tmp_flattened_build_module' From d6dd5c384f60119771c379b31af320794eaa2b9d Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Fri, 13 Feb 2026 16:03:43 +0800 Subject: [PATCH 3/6] wip --- .../moe/test_expert_parallel_qwen3_fsdp_sp.py | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/tests/moe/test_expert_parallel_qwen3_fsdp_sp.py b/tests/moe/test_expert_parallel_qwen3_fsdp_sp.py index e46573ff..83b876c9 100644 --- a/tests/moe/test_expert_parallel_qwen3_fsdp_sp.py +++ b/tests/moe/test_expert_parallel_qwen3_fsdp_sp.py @@ -255,19 +255,22 @@ def _compare_grad_dicts( a = baseline.get(k) b = sp.get(k) if a is None or b is None: - raise AssertionError(f'[rank{rank}] Missing grad key={k} baseline={a is not None} sp={b is not None}') + raise AssertionError( + f'[rank{rank}] Missing grad key={k} baseline={a is not None} sp={b is not None} ' + f'baseline_keys={len(baseline)} sp_keys={len(sp)}') a32 = a.to(dtype=torch.float32) b32 = b.to(dtype=torch.float32) diff = b32 - a32 rel = diff.norm() / (a32.norm() + 1e-12) if rel.item() > rel_tol: abs_diff = diff.abs() - print( - f'[rank{rank}] {k} grad diff mean={abs_diff.mean().item():.6e} ' - f'max={abs_diff.max().item():.6e} rel_norm={rel.item():.6e} tol={rel_tol:.1e}', - flush=True, + max_idx = int(abs_diff.reshape(-1).argmax().item()) + raise AssertionError( + f'[rank{rank}] {k} grad not close: shape={tuple(a32.shape)} ' + f'base_norm={a32.norm().item():.6e} sp_norm={b32.norm().item():.6e} ' + f'mean_abs={abs_diff.mean().item():.6e} max_abs={abs_diff.max().item():.6e} ' + f'max_flat_idx={max_idx} rel_norm={rel.item():.6e} tol={rel_tol:.1e}', ) - assert rel.item() <= rel_tol def _run_worker_ep_fsdp_sp_align( @@ -450,8 +453,9 @@ def _run_worker_ep_fsdp_sp_align( sp_sel = sp_sel.view(batch_size, end - start, -1) if not torch.equal(base_sel, sp_sel): mismatch = (base_sel != sp_sel).sum().item() - print(f'[rank{rank}] block[{idx}] selected_experts mismatch count={mismatch}', flush=True) - assert torch.equal(base_sel, sp_sel) + raise AssertionError( + f'[rank{rank}] block[{idx}] selected_experts mismatch count={mismatch} ' + f'base_shape={tuple(base_sel.shape)} sp_shape={tuple(sp_sel.shape)}') # Backward alignment (expert grads on active local experts for this slice). sp_loss_sum = F.cross_entropy( @@ -627,11 +631,10 @@ def _run_worker_fsdp_sp_align( # Forward alignment (full-seq logits reconstructed by SP gather). 
if not torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4): diff = (sp_logits - base_logits).abs() - print( - f'[rank{rank}] logits diff mean={diff.mean().item():.6e} max={diff.max().item():.6e}', - flush=True, + raise AssertionError( + f'[rank{rank}] logits not close: mean_abs={diff.mean().item():.6e} ' + f'max_abs={diff.max().item():.6e} (rtol=1e-3, atol=1e-4)', ) - assert torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4) # Backward alignment: local CE(sum) on SP, compare gathered full-seq inputs_embeds grads. sp_loss_sum = F.cross_entropy( @@ -656,7 +659,6 @@ def _run_worker_fsdp_sp_align( raise AssertionError( f'[rank{rank}] inputs_embeds.grad(full) not close: mean_abs={abs_diff.mean().item():.6e} ' f'max_abs={abs_diff.max().item():.6e} rel_norm={rel.item():.6e} tol={grad_rel_tol:.1e}') - assert rel.item() <= grad_rel_tol finally: dist.destroy_process_group() From 7a783d4485b74b1c4d9607e3832676fe8a6b3c15 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Fri, 13 Feb 2026 16:18:12 +0800 Subject: [PATCH 4/6] wip --- tests/moe/test_expert_parallel_qwen3_fsdp_sp.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/moe/test_expert_parallel_qwen3_fsdp_sp.py b/tests/moe/test_expert_parallel_qwen3_fsdp_sp.py index 83b876c9..dc933fac 100644 --- a/tests/moe/test_expert_parallel_qwen3_fsdp_sp.py +++ b/tests/moe/test_expert_parallel_qwen3_fsdp_sp.py @@ -417,7 +417,8 @@ def _run_worker_ep_fsdp_sp_align( model_sp, _ = fsdp_strategy.wrap_model(model_sp, optimizer=None) # Preprocess labels through SP strategy so they are shifted + split consistently. - sp_label_inputs = {'labels': labels_raw, 'position_ids': position_ids} + # Keep label semantics consistent with the baseline path: next-token aligned labels. + sp_label_inputs = {'labels': labels_shifted, 'position_ids': position_ids} sp_label_inputs = sp_strategy.preprocess_inputs(sp_label_inputs) sp_local_labels = sp_label_inputs['labels'] @@ -613,7 +614,8 @@ def _run_worker_fsdp_sp_align( sp_embeds = embed_sp(input_ids).detach().requires_grad_(True) model_sp, _ = fsdp_strategy.wrap_model(model_sp, optimizer=None) - sp_label_inputs = {'labels': labels_raw, 'position_ids': position_ids} + # Keep label semantics consistent with the baseline path: next-token aligned labels. 
+ sp_label_inputs = {'labels': labels_shifted, 'position_ids': position_ids} sp_label_inputs = sp_strategy.preprocess_inputs(sp_label_inputs) sp_local_labels = sp_label_inputs['labels'] From 2d4a19f963bbcb0fbe83a1cc0c3dd4f9d17f09a7 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Fri, 13 Feb 2026 16:40:25 +0800 Subject: [PATCH 5/6] remove debug info --- .../moe/test_expert_parallel_qwen3_fsdp_sp.py | 38 +++---------------- 1 file changed, 6 insertions(+), 32 deletions(-) diff --git a/tests/moe/test_expert_parallel_qwen3_fsdp_sp.py b/tests/moe/test_expert_parallel_qwen3_fsdp_sp.py index dc933fac..a2031d14 100644 --- a/tests/moe/test_expert_parallel_qwen3_fsdp_sp.py +++ b/tests/moe/test_expert_parallel_qwen3_fsdp_sp.py @@ -255,22 +255,12 @@ def _compare_grad_dicts( a = baseline.get(k) b = sp.get(k) if a is None or b is None: - raise AssertionError( - f'[rank{rank}] Missing grad key={k} baseline={a is not None} sp={b is not None} ' - f'baseline_keys={len(baseline)} sp_keys={len(sp)}') + raise AssertionError(f'[rank{rank}] Missing grad key={k} baseline={a is not None} sp={b is not None}') a32 = a.to(dtype=torch.float32) b32 = b.to(dtype=torch.float32) diff = b32 - a32 rel = diff.norm() / (a32.norm() + 1e-12) - if rel.item() > rel_tol: - abs_diff = diff.abs() - max_idx = int(abs_diff.reshape(-1).argmax().item()) - raise AssertionError( - f'[rank{rank}] {k} grad not close: shape={tuple(a32.shape)} ' - f'base_norm={a32.norm().item():.6e} sp_norm={b32.norm().item():.6e} ' - f'mean_abs={abs_diff.mean().item():.6e} max_abs={abs_diff.max().item():.6e} ' - f'max_flat_idx={max_idx} rel_norm={rel.item():.6e} tol={rel_tol:.1e}', - ) + assert rel.item() <= rel_tol def _run_worker_ep_fsdp_sp_align( @@ -437,10 +427,7 @@ def _run_worker_ep_fsdp_sp_align( sp_logits = sp_out.logits.detach() # Forward alignment (full-seq logits reconstructed by SP gather). - if not torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4): - diff = (sp_logits - base_logits).abs() - raise AssertionError(f'[rank{rank}] logits not close: mean_abs={diff.mean().item():.6e} ' - f'max_abs={diff.max().item():.6e} (rtol=1e-3, atol=1e-4)') + assert torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4) # Router alignment on this rank's slice: compare selected experts exactly. # SP captures only local tokens; baseline captures full tokens (we slice it). @@ -452,11 +439,7 @@ def _run_worker_ep_fsdp_sp_align( sp_sel = sp_state['selected_experts'] if sp_sel.dim() == 2: sp_sel = sp_sel.view(batch_size, end - start, -1) - if not torch.equal(base_sel, sp_sel): - mismatch = (base_sel != sp_sel).sum().item() - raise AssertionError( - f'[rank{rank}] block[{idx}] selected_experts mismatch count={mismatch} ' - f'base_shape={tuple(base_sel.shape)} sp_shape={tuple(sp_sel.shape)}') + assert torch.equal(base_sel, sp_sel) # Backward alignment (expert grads on active local experts for this slice). sp_loss_sum = F.cross_entropy( @@ -631,12 +614,7 @@ def _run_worker_fsdp_sp_align( sp_logits = sp_out.logits.detach() # Forward alignment (full-seq logits reconstructed by SP gather). - if not torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4): - diff = (sp_logits - base_logits).abs() - raise AssertionError( - f'[rank{rank}] logits not close: mean_abs={diff.mean().item():.6e} ' - f'max_abs={diff.max().item():.6e} (rtol=1e-3, atol=1e-4)', - ) + assert torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4) # Backward alignment: local CE(sum) on SP, compare gathered full-seq inputs_embeds grads. 
sp_loss_sum = F.cross_entropy( @@ -656,11 +634,7 @@ def _run_worker_fsdp_sp_align( diff = sp_full - base_full rel = diff.norm() / (base_full.norm() + 1e-12) grad_rel_tol = float(os.environ.get('TWINKLE_INPUT_GRAD_REL_TOL', '1e-2')) - if rel.item() > grad_rel_tol: - abs_diff = diff.abs() - raise AssertionError( - f'[rank{rank}] inputs_embeds.grad(full) not close: mean_abs={abs_diff.mean().item():.6e} ' - f'max_abs={abs_diff.max().item():.6e} rel_norm={rel.item():.6e} tol={grad_rel_tol:.1e}') + assert rel.item() <= grad_rel_tol finally: dist.destroy_process_group() From 2b350078b62181af134214303c43558c3558bc59 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Fri, 13 Feb 2026 17:30:55 +0800 Subject: [PATCH 6/6] feat: add ep/sp FSDP MoE finetuning entry and update script - Add new entry for ep/sp FSDP MoE finetuning in README table - Update ep_fsdp_qwen3_moe.py script to include ulysses_size parameter for enhanced parallelism configuration --- README.md | 1 + cookbook/transformers/ep_fsdp_qwen3_moe.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/README.md b/README.md index 447ebf87..7ddd0070 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,7 @@ pip install -e . | --------------------------------- | --------------- | ------------------------------------------------- | | FSDP finetuning | transformers | [Script](cookbook/transformers/fsdp2.py) | | FSDP MoE finetuning | transformers | [Script](cookbook/transformers/fsdp2_moe.py) | +| ep/sp FSDP MoE finetuning | transformers | [Script](cookbook/transformers/ep_fsdp_qwen3_moe.py) | | EP MoE finetuning | transformers | [Script](cookbook/transformers/ep_fsdp_qwen3_moe.py) | | pp/tp/cp finetuning | megatron | [Script](cookbook/megatron/tp.py) | | pp/tp/cp MoE finetuning | megatron | [Script](cookbook/megatron/tp_moe.py) | diff --git a/cookbook/transformers/ep_fsdp_qwen3_moe.py b/cookbook/transformers/ep_fsdp_qwen3_moe.py index 16706eae..6473dc63 100644 --- a/cookbook/transformers/ep_fsdp_qwen3_moe.py +++ b/cookbook/transformers/ep_fsdp_qwen3_moe.py @@ -21,11 +21,13 @@ # 4 gpus, dp=2, ep=2 dp_size = 2 ep_size = 2 +ulysses_size = 2 device_mesh = DeviceMesh( device_type=Platform.get_platform().device_prefix(), mesh=np.arange(dp_size * ep_size).reshape(dp_size, ep_size), mesh_dim_names=('dp', 'ep'), + ulysses_size=ulysses_size, # enable sp ) twinkle.initialize(
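
For reference, a minimal sketch of the rank layout produced by the
`np.arange(dp_size * ep_size).reshape(dp_size, ep_size)` call in patch 6 for the 4-GPU
dp=2, ep=2 case. Only numpy is used here; the dp/ep group reading in the comments assumes
the usual convention that the last mesh dimension ('ep') varies fastest, which is an
illustrative assumption rather than something stated in the twinkle source.

    import numpy as np

    dp_size, ep_size = 2, 2
    mesh = np.arange(dp_size * ep_size).reshape(dp_size, ep_size)
    # mesh == [[0, 1],
    #          [2, 3]]
    # Assumed reading: row i holds the ranks of dp replica i, so {0, 1} and {2, 3}
    # would each form an expert-parallel group; column j holds the ranks that own
    # expert shard j across replicas, so {0, 2} and {1, 3} would be the dp groups.
    # ulysses_size=2 additionally enables sequence parallelism; per the test comment
    # in patch 1, the SP group is derived from the dp/fsdp ranks rather than from a
    # separate mesh dimension.
    print(mesh)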