Skip to content
32 changes: 5 additions & 27 deletions tests/moe/test_expert_parallel_qwen3_fsdp_sp.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,13 +260,6 @@ def _compare_grad_dicts(
b32 = b.to(dtype=torch.float32)
diff = b32 - a32
rel = diff.norm() / (a32.norm() + 1e-12)
if rel.item() > rel_tol:
abs_diff = diff.abs()
print(
f'[rank{rank}] {k} grad diff mean={abs_diff.mean().item():.6e} '
f'max={abs_diff.max().item():.6e} rel_norm={rel.item():.6e} tol={rel_tol:.1e}',
flush=True,
)
assert rel.item() <= rel_tol
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

While removing the redundant if block is a good simplification, the detailed debugging information from the print statement is lost. When this assertion fails, it will be hard to know the actual and expected values. Please add a descriptive message to the assertion.

Suggested change
assert rel.item() <= rel_tol
assert rel.item() <= rel_tol, f'[rank{rank}] {k} grad diff mean={diff.abs().mean().item():.6e} max={diff.abs().max().item():.6e} rel_norm={rel.item():.6e} tol={rel_tol:.1e}'



Expand Down Expand Up @@ -414,7 +407,8 @@ def _run_worker_ep_fsdp_sp_align(
model_sp, _ = fsdp_strategy.wrap_model(model_sp, optimizer=None)

# Preprocess labels through SP strategy so they are shifted + split consistently.
sp_label_inputs = {'labels': labels_raw, 'position_ids': position_ids}
# Keep label semantics consistent with the baseline path: next-token aligned labels.
sp_label_inputs = {'labels': labels_shifted, 'position_ids': position_ids}
sp_label_inputs = sp_strategy.preprocess_inputs(sp_label_inputs)
sp_local_labels = sp_label_inputs['labels']

Expand All @@ -433,10 +427,7 @@ def _run_worker_ep_fsdp_sp_align(
sp_logits = sp_out.logits.detach()

# Forward alignment (full-seq logits reconstructed by SP gather).
if not torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4):
diff = (sp_logits - base_logits).abs()
raise AssertionError(f'[rank{rank}] logits not close: mean_abs={diff.mean().item():.6e} '
f'max_abs={diff.max().item():.6e} (rtol=1e-3, atol=1e-4)')
assert torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Replacing the if block with a simple assert loses the detailed error message, which is very helpful for debugging. Please include the debugging information in the assertion message. The original implementation with if not torch.allclose(...) was also more efficient as it only computed the difference when the check failed.

if not torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4):
    diff = (sp_logits - base_logits).abs()
    raise AssertionError(f'[rank{rank}] logits not close: mean_abs={diff.mean().item():.6e} '
                         f'max_abs={diff.max().item():.6e} (rtol=1e-3, atol=1e-4)')


# Router alignment on this rank's slice: compare selected experts exactly.
# SP captures only local tokens; baseline captures full tokens (we slice it).
Expand All @@ -448,9 +439,6 @@ def _run_worker_ep_fsdp_sp_align(
sp_sel = sp_state['selected_experts']
if sp_sel.dim() == 2:
sp_sel = sp_sel.view(batch_size, end - start, -1)
if not torch.equal(base_sel, sp_sel):
mismatch = (base_sel != sp_sel).sum().item()
print(f'[rank{rank}] block[{idx}] selected_experts mismatch count={mismatch}', flush=True)
assert torch.equal(base_sel, sp_sel)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This change removes a helpful print statement that shows the mismatch count when the assertion fails. This information is valuable for debugging. Please include it in the assertion message.

assert torch.equal(base_sel, sp_sel), f'[rank{rank}] block[{idx}] selected_experts mismatch count={(base_sel != sp_sel).sum().item()}'


# Backward alignment (expert grads on active local experts for this slice).
Expand Down Expand Up @@ -609,7 +597,8 @@ def _run_worker_fsdp_sp_align(
sp_embeds = embed_sp(input_ids).detach().requires_grad_(True)
model_sp, _ = fsdp_strategy.wrap_model(model_sp, optimizer=None)

sp_label_inputs = {'labels': labels_raw, 'position_ids': position_ids}
# Keep label semantics consistent with the baseline path: next-token aligned labels.
sp_label_inputs = {'labels': labels_shifted, 'position_ids': position_ids}
sp_label_inputs = sp_strategy.preprocess_inputs(sp_label_inputs)
sp_local_labels = sp_label_inputs['labels']

Expand All @@ -625,12 +614,6 @@ def _run_worker_fsdp_sp_align(
sp_logits = sp_out.logits.detach()

# Forward alignment (full-seq logits reconstructed by SP gather).
if not torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4):
diff = (sp_logits - base_logits).abs()
print(
f'[rank{rank}] logits diff mean={diff.mean().item():.6e} max={diff.max().item():.6e}',
flush=True,
)
assert torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The removed print statement contained useful debugging information (mean and max difference) for when the allclose check fails. It's better to include this information in the assertion message to aid debugging.

assert torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4), f'[rank{rank}] logits diff mean={(sp_logits - base_logits).abs().mean().item():.6e} max={(sp_logits - base_logits).abs().max().item():.6e}'


# Backward alignment: local CE(sum) on SP, compare gathered full-seq inputs_embeds grads.
Expand All @@ -651,11 +634,6 @@ def _run_worker_fsdp_sp_align(
diff = sp_full - base_full
rel = diff.norm() / (base_full.norm() + 1e-12)
grad_rel_tol = float(os.environ.get('TWINKLE_INPUT_GRAD_REL_TOL', '1e-2'))
if rel.item() > grad_rel_tol:
abs_diff = diff.abs()
raise AssertionError(
f'[rank{rank}] inputs_embeds.grad(full) not close: mean_abs={abs_diff.mean().item():.6e} '
f'max_abs={abs_diff.max().item():.6e} rel_norm={rel.item():.6e} tol={grad_rel_tol:.1e}')
assert rel.item() <= grad_rel_tol
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Replacing the if block with a simple assert loses the detailed error message, which is very helpful for debugging. Please include the debugging information in the assertion message. The diff and rel tensors are already computed, so they can be used in the message.

assert rel.item() <= grad_rel_tol, f'[rank{rank}] inputs_embeds.grad(full) not close: mean_abs={diff.abs().mean().item():.6e} max_abs={diff.abs().max().item():.6e} rel_norm={rel.item():.6e} tol={grad_rel_tol:.1e}'

finally:
dist.destroy_process_group()
Expand Down
Loading