diff --git a/tests/moe/test_expert_parallel_qwen3_fsdp_sp.py b/tests/moe/test_expert_parallel_qwen3_fsdp_sp.py
index e46573ff..a2031d14 100644
--- a/tests/moe/test_expert_parallel_qwen3_fsdp_sp.py
+++ b/tests/moe/test_expert_parallel_qwen3_fsdp_sp.py
@@ -260,13 +260,6 @@ def _compare_grad_dicts(
         b32 = b.to(dtype=torch.float32)
         diff = b32 - a32
         rel = diff.norm() / (a32.norm() + 1e-12)
-        if rel.item() > rel_tol:
-            abs_diff = diff.abs()
-            print(
-                f'[rank{rank}] {k} grad diff mean={abs_diff.mean().item():.6e} '
-                f'max={abs_diff.max().item():.6e} rel_norm={rel.item():.6e} tol={rel_tol:.1e}',
-                flush=True,
-            )
         assert rel.item() <= rel_tol
@@ -414,7 +407,8 @@ def _run_worker_ep_fsdp_sp_align(
     model_sp, _ = fsdp_strategy.wrap_model(model_sp, optimizer=None)
 
     # Preprocess labels through SP strategy so they are shifted + split consistently.
-    sp_label_inputs = {'labels': labels_raw, 'position_ids': position_ids}
+    # Keep label semantics consistent with the baseline path: next-token aligned labels.
+    sp_label_inputs = {'labels': labels_shifted, 'position_ids': position_ids}
     sp_label_inputs = sp_strategy.preprocess_inputs(sp_label_inputs)
     sp_local_labels = sp_label_inputs['labels']
 
@@ -433,10 +427,7 @@ def _run_worker_ep_fsdp_sp_align(
     sp_logits = sp_out.logits.detach()
 
     # Forward alignment (full-seq logits reconstructed by SP gather).
-    if not torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4):
-        diff = (sp_logits - base_logits).abs()
-        raise AssertionError(f'[rank{rank}] logits not close: mean_abs={diff.mean().item():.6e} '
-                             f'max_abs={diff.max().item():.6e} (rtol=1e-3, atol=1e-4)')
+    assert torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4)
 
     # Router alignment on this rank's slice: compare selected experts exactly.
     # SP captures only local tokens; baseline captures full tokens (we slice it).
@@ -448,9 +439,6 @@ def _run_worker_ep_fsdp_sp_align(
         sp_sel = sp_state['selected_experts']
         if sp_sel.dim() == 2:
             sp_sel = sp_sel.view(batch_size, end - start, -1)
-        if not torch.equal(base_sel, sp_sel):
-            mismatch = (base_sel != sp_sel).sum().item()
-            print(f'[rank{rank}] block[{idx}] selected_experts mismatch count={mismatch}', flush=True)
         assert torch.equal(base_sel, sp_sel)
 
     # Backward alignment (expert grads on active local experts for this slice).
@@ -609,7 +597,8 @@ def _run_worker_fsdp_sp_align(
     sp_embeds = embed_sp(input_ids).detach().requires_grad_(True)
 
     model_sp, _ = fsdp_strategy.wrap_model(model_sp, optimizer=None)
-    sp_label_inputs = {'labels': labels_raw, 'position_ids': position_ids}
+    # Keep label semantics consistent with the baseline path: next-token aligned labels.
+    sp_label_inputs = {'labels': labels_shifted, 'position_ids': position_ids}
     sp_label_inputs = sp_strategy.preprocess_inputs(sp_label_inputs)
     sp_local_labels = sp_label_inputs['labels']
 
@@ -625,12 +614,6 @@ def _run_worker_fsdp_sp_align(
     sp_logits = sp_out.logits.detach()
 
     # Forward alignment (full-seq logits reconstructed by SP gather).
-    if not torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4):
-        diff = (sp_logits - base_logits).abs()
-        print(
-            f'[rank{rank}] logits diff mean={diff.mean().item():.6e} max={diff.max().item():.6e}',
-            flush=True,
-        )
     assert torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4)
 
     # Backward alignment: local CE(sum) on SP, compare gathered full-seq inputs_embeds grads.
@@ -651,11 +634,6 @@ def _run_worker_fsdp_sp_align(
         diff = sp_full - base_full
         rel = diff.norm() / (base_full.norm() + 1e-12)
        grad_rel_tol = float(os.environ.get('TWINKLE_INPUT_GRAD_REL_TOL', '1e-2'))
-        if rel.item() > grad_rel_tol:
-            abs_diff = diff.abs()
-            raise AssertionError(
-                f'[rank{rank}] inputs_embeds.grad(full) not close: mean_abs={abs_diff.mean().item():.6e} '
-                f'max_abs={abs_diff.max().item():.6e} rel_norm={rel.item():.6e} tol={grad_rel_tol:.1e}')
         assert rel.item() <= grad_rel_tol
     finally:
         dist.destroy_process_group()
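Note on the labels_raw -> labels_shifted change: per the added comment, the SP path should receive the same next-token-aligned labels as the baseline path rather than raw labels. Under the usual causal-LM convention, logits[:, t] is scored against the token at position t + 1, and ignored positions are marked with -100. A minimal sketch of that alignment, assuming this convention holds here (the shift_labels helper is illustrative, not part of the test file):

import torch

def shift_labels(labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    # Align labels so logits[:, t] is compared against the token at t + 1;
    # the last position has no next token, so it is masked with ignore_index.
    shifted = labels.new_full(labels.shape, ignore_index)
    shifted[:, :-1] = labels[:, 1:]
    return shifted

# labels_raw     = [[5, 7,  9,   11]]
# labels_shifted = [[7, 9, 11, -100]]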
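The gradient checks kept by this patch compare tensors via a relative Frobenius norm, rel = ||b - a|| / (||a|| + 1e-12), which is more robust than elementwise torch.allclose when many gradient entries are near zero; the 1e-12 epsilon guards against a zero-norm baseline. A standalone sketch of the same check, mirroring the patched code (the function name is illustrative):

import torch

def assert_rel_close(a: torch.Tensor, b: torch.Tensor, rel_tol: float = 1e-2) -> None:
    # Upcast to float32 so bf16/fp16 gradients do not lose precision in the norms.
    a32 = a.to(dtype=torch.float32)
    b32 = b.to(dtype=torch.float32)
    rel = (b32 - a32).norm() / (a32.norm() + 1e-12)
    assert rel.item() <= rel_tol, f'rel_norm={rel.item():.3e} > tol={rel_tol:.1e}'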