Skip to content

Commit d6dd5c3

Browse files
committed
wip
1 parent f47dbc7 commit d6dd5c3

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

tests/moe/test_expert_parallel_qwen3_fsdp_sp.py

Lines changed: 15 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -255,19 +255,22 @@ def _compare_grad_dicts(
255255
a = baseline.get(k)
256256
b = sp.get(k)
257257
if a is None or b is None:
258-
raise AssertionError(f'[rank{rank}] Missing grad key={k} baseline={a is not None} sp={b is not None}')
258+
raise AssertionError(
259+
f'[rank{rank}] Missing grad key={k} baseline={a is not None} sp={b is not None} '
260+
f'baseline_keys={len(baseline)} sp_keys={len(sp)}')
259261
a32 = a.to(dtype=torch.float32)
260262
b32 = b.to(dtype=torch.float32)
261263
diff = b32 - a32
262264
rel = diff.norm() / (a32.norm() + 1e-12)
263265
if rel.item() > rel_tol:
264266
abs_diff = diff.abs()
265-
print(
266-
f'[rank{rank}] {k} grad diff mean={abs_diff.mean().item():.6e} '
267-
f'max={abs_diff.max().item():.6e} rel_norm={rel.item():.6e} tol={rel_tol:.1e}',
268-
flush=True,
267+
max_idx = int(abs_diff.reshape(-1).argmax().item())
268+
raise AssertionError(
269+
f'[rank{rank}] {k} grad not close: shape={tuple(a32.shape)} '
270+
f'base_norm={a32.norm().item():.6e} sp_norm={b32.norm().item():.6e} '
271+
f'mean_abs={abs_diff.mean().item():.6e} max_abs={abs_diff.max().item():.6e} '
272+
f'max_flat_idx={max_idx} rel_norm={rel.item():.6e} tol={rel_tol:.1e}',
269273
)
270-
assert rel.item() <= rel_tol
271274

272275

273276
def _run_worker_ep_fsdp_sp_align(
@@ -450,8 +453,9 @@ def _run_worker_ep_fsdp_sp_align(
450453
sp_sel = sp_sel.view(batch_size, end - start, -1)
451454
if not torch.equal(base_sel, sp_sel):
452455
mismatch = (base_sel != sp_sel).sum().item()
453-
print(f'[rank{rank}] block[{idx}] selected_experts mismatch count={mismatch}', flush=True)
454-
assert torch.equal(base_sel, sp_sel)
456+
raise AssertionError(
457+
f'[rank{rank}] block[{idx}] selected_experts mismatch count={mismatch} '
458+
f'base_shape={tuple(base_sel.shape)} sp_shape={tuple(sp_sel.shape)}')
455459

456460
# Backward alignment (expert grads on active local experts for this slice).
457461
sp_loss_sum = F.cross_entropy(
@@ -627,11 +631,10 @@ def _run_worker_fsdp_sp_align(
627631
# Forward alignment (full-seq logits reconstructed by SP gather).
628632
if not torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4):
629633
diff = (sp_logits - base_logits).abs()
630-
print(
631-
f'[rank{rank}] logits diff mean={diff.mean().item():.6e} max={diff.max().item():.6e}',
632-
flush=True,
634+
raise AssertionError(
635+
f'[rank{rank}] logits not close: mean_abs={diff.mean().item():.6e} '
636+
f'max_abs={diff.max().item():.6e} (rtol=1e-3, atol=1e-4)',
633637
)
634-
assert torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4)
635638

636639
# Backward alignment: local CE(sum) on SP, compare gathered full-seq inputs_embeds grads.
637640
sp_loss_sum = F.cross_entropy(
@@ -656,7 +659,6 @@ def _run_worker_fsdp_sp_align(
656659
raise AssertionError(
657660
f'[rank{rank}] inputs_embeds.grad(full) not close: mean_abs={abs_diff.mean().item():.6e} '
658661
f'max_abs={abs_diff.max().item():.6e} rel_norm={rel.item():.6e} tol={grad_rel_tol:.1e}')
659-
assert rel.item() <= grad_rel_tol
660662
finally:
661663
dist.destroy_process_group()
662664

0 commit comments

Comments (0)