-
Notifications
You must be signed in to change notification settings - Fork 22
unittest fix ljl #62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
unittest fix ljl #62
Changes from all commits
c6ace7f
939466f
e344ba2
f47dbc7
d6dd5c3
7a783d4
2d4a19f
1d8094d
1a6d0ca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -260,13 +260,6 @@ def _compare_grad_dicts( | |
| b32 = b.to(dtype=torch.float32) | ||
| diff = b32 - a32 | ||
| rel = diff.norm() / (a32.norm() + 1e-12) | ||
| if rel.item() > rel_tol: | ||
| abs_diff = diff.abs() | ||
| print( | ||
| f'[rank{rank}] {k} grad diff mean={abs_diff.mean().item():.6e} ' | ||
| f'max={abs_diff.max().item():.6e} rel_norm={rel.item():.6e} tol={rel_tol:.1e}', | ||
| flush=True, | ||
| ) | ||
| assert rel.item() <= rel_tol | ||
|
|
||
|
|
||
|
|
@@ -414,7 +407,8 @@ def _run_worker_ep_fsdp_sp_align( | |
| model_sp, _ = fsdp_strategy.wrap_model(model_sp, optimizer=None) | ||
|
|
||
| # Preprocess labels through SP strategy so they are shifted + split consistently. | ||
| sp_label_inputs = {'labels': labels_raw, 'position_ids': position_ids} | ||
| # Keep label semantics consistent with the baseline path: next-token aligned labels. | ||
| sp_label_inputs = {'labels': labels_shifted, 'position_ids': position_ids} | ||
| sp_label_inputs = sp_strategy.preprocess_inputs(sp_label_inputs) | ||
| sp_local_labels = sp_label_inputs['labels'] | ||
|
|
||
|
|
@@ -433,10 +427,7 @@ def _run_worker_ep_fsdp_sp_align( | |
| sp_logits = sp_out.logits.detach() | ||
|
|
||
| # Forward alignment (full-seq logits reconstructed by SP gather). | ||
| if not torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4): | ||
| diff = (sp_logits - base_logits).abs() | ||
| raise AssertionError(f'[rank{rank}] logits not close: mean_abs={diff.mean().item():.6e} ' | ||
| f'max_abs={diff.max().item():.6e} (rtol=1e-3, atol=1e-4)') | ||
| assert torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Replacing the following block loses the diagnostic output: if not torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4):
    diff = (sp_logits - base_logits).abs()
    raise AssertionError(f'[rank{rank}] logits not close: mean_abs={diff.mean().item():.6e} '
                         f'max_abs={diff.max().item():.6e} (rtol=1e-3, atol=1e-4)') |
||
|
|
||
| # Router alignment on this rank's slice: compare selected experts exactly. | ||
| # SP captures only local tokens; baseline captures full tokens (we slice it). | ||
|
|
@@ -448,9 +439,6 @@ def _run_worker_ep_fsdp_sp_align( | |
| sp_sel = sp_state['selected_experts'] | ||
| if sp_sel.dim() == 2: | ||
| sp_sel = sp_sel.view(batch_size, end - start, -1) | ||
| if not torch.equal(base_sel, sp_sel): | ||
| mismatch = (base_sel != sp_sel).sum().item() | ||
| print(f'[rank{rank}] block[{idx}] selected_experts mismatch count={mismatch}', flush=True) | ||
| assert torch.equal(base_sel, sp_sel) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This change removes a helpful diagnostic. Suggested replacement that keeps the mismatch count in the failure message: assert torch.equal(base_sel, sp_sel), f'[rank{rank}] block[{idx}] selected_experts mismatch count={(base_sel != sp_sel).sum().item()}' |
||
|
|
||
| # Backward alignment (expert grads on active local experts for this slice). | ||
|
|
@@ -609,7 +597,8 @@ def _run_worker_fsdp_sp_align( | |
| sp_embeds = embed_sp(input_ids).detach().requires_grad_(True) | ||
| model_sp, _ = fsdp_strategy.wrap_model(model_sp, optimizer=None) | ||
|
|
||
| sp_label_inputs = {'labels': labels_raw, 'position_ids': position_ids} | ||
| # Keep label semantics consistent with the baseline path: next-token aligned labels. | ||
| sp_label_inputs = {'labels': labels_shifted, 'position_ids': position_ids} | ||
| sp_label_inputs = sp_strategy.preprocess_inputs(sp_label_inputs) | ||
| sp_local_labels = sp_label_inputs['labels'] | ||
|
|
||
|
|
@@ -625,12 +614,6 @@ def _run_worker_fsdp_sp_align( | |
| sp_logits = sp_out.logits.detach() | ||
|
|
||
| # Forward alignment (full-seq logits reconstructed by SP gather). | ||
| if not torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4): | ||
| diff = (sp_logits - base_logits).abs() | ||
| print( | ||
| f'[rank{rank}] logits diff mean={diff.mean().item():.6e} max={diff.max().item():.6e}', | ||
| flush=True, | ||
| ) | ||
| assert torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The removed diagnostics can be preserved in the assertion message: assert torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4), f'[rank{rank}] logits diff mean={(sp_logits - base_logits).abs().mean().item():.6e} max={(sp_logits - base_logits).abs().max().item():.6e}' |
||
|
|
||
| # Backward alignment: local CE(sum) on SP, compare gathered full-seq inputs_embeds grads. | ||
|
|
@@ -651,11 +634,6 @@ def _run_worker_fsdp_sp_align( | |
| diff = sp_full - base_full | ||
| rel = diff.norm() / (base_full.norm() + 1e-12) | ||
| grad_rel_tol = float(os.environ.get('TWINKLE_INPUT_GRAD_REL_TOL', '1e-2')) | ||
| if rel.item() > grad_rel_tol: | ||
| abs_diff = diff.abs() | ||
| raise AssertionError( | ||
| f'[rank{rank}] inputs_embeds.grad(full) not close: mean_abs={abs_diff.mean().item():.6e} ' | ||
| f'max_abs={abs_diff.max().item():.6e} rel_norm={rel.item():.6e} tol={grad_rel_tol:.1e}') | ||
| assert rel.item() <= grad_rel_tol | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Suggested replacement that keeps the detailed failure message: assert rel.item() <= grad_rel_tol, f'[rank{rank}] inputs_embeds.grad(full) not close: mean_abs={diff.abs().mean().item():.6e} max_abs={diff.abs().max().item():.6e} rel_norm={rel.item():.6e} tol={grad_rel_tol:.1e}' |
||
| finally: | ||
| dist.destroy_process_group() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While removing the redundant
`if` block is a good simplification, the detailed debugging information from the `print` statement is lost. When this assertion fails, it will be hard to know the actual and expected values. Please add a descriptive message to the assertion.