Skip to content

Commit 0857154

Browse files
committed
update
2 parents b4c0e65 + 454ebe8 commit 0857154

File tree

2 files changed

+6
-28
lines changed

2 files changed

+6
-28
lines changed

cookbook/client/tinker/short_math_grpo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __call__(self, sample):
6565
return Trajectory(messages=[], user_data=[])
6666

6767
def get_boxed_answer(text):
68-
match = re.search(r'\\boxed\{([^}]+)\}', text)
68+
match = re.search(r'\\boxed{([^}]*)}', text)
6969
return match.group(1) if match else None
7070

7171
ground_truth = get_boxed_answer(sample['solution'])

tests/moe/test_expert_parallel_qwen3_fsdp_sp.py

Lines changed: 5 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -260,13 +260,6 @@ def _compare_grad_dicts(
260260
b32 = b.to(dtype=torch.float32)
261261
diff = b32 - a32
262262
rel = diff.norm() / (a32.norm() + 1e-12)
263-
if rel.item() > rel_tol:
264-
abs_diff = diff.abs()
265-
print(
266-
f'[rank{rank}] {k} grad diff mean={abs_diff.mean().item():.6e} '
267-
f'max={abs_diff.max().item():.6e} rel_norm={rel.item():.6e} tol={rel_tol:.1e}',
268-
flush=True,
269-
)
270263
assert rel.item() <= rel_tol
271264

272265

@@ -414,7 +407,8 @@ def _run_worker_ep_fsdp_sp_align(
414407
model_sp, _ = fsdp_strategy.wrap_model(model_sp, optimizer=None)
415408

416409
# Preprocess labels through SP strategy so they are shifted + split consistently.
417-
sp_label_inputs = {'labels': labels_raw, 'position_ids': position_ids}
410+
# Keep label semantics consistent with the baseline path: next-token aligned labels.
411+
sp_label_inputs = {'labels': labels_shifted, 'position_ids': position_ids}
418412
sp_label_inputs = sp_strategy.preprocess_inputs(sp_label_inputs)
419413
sp_local_labels = sp_label_inputs['labels']
420414

@@ -433,10 +427,7 @@ def _run_worker_ep_fsdp_sp_align(
433427
sp_logits = sp_out.logits.detach()
434428

435429
# Forward alignment (full-seq logits reconstructed by SP gather).
436-
if not torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4):
437-
diff = (sp_logits - base_logits).abs()
438-
raise AssertionError(f'[rank{rank}] logits not close: mean_abs={diff.mean().item():.6e} '
439-
f'max_abs={diff.max().item():.6e} (rtol=1e-3, atol=1e-4)')
430+
assert torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4)
440431

441432
# Router alignment on this rank's slice: compare selected experts exactly.
442433
# SP captures only local tokens; baseline captures full tokens (we slice it).
@@ -448,9 +439,6 @@ def _run_worker_ep_fsdp_sp_align(
448439
sp_sel = sp_state['selected_experts']
449440
if sp_sel.dim() == 2:
450441
sp_sel = sp_sel.view(batch_size, end - start, -1)
451-
if not torch.equal(base_sel, sp_sel):
452-
mismatch = (base_sel != sp_sel).sum().item()
453-
print(f'[rank{rank}] block[{idx}] selected_experts mismatch count={mismatch}', flush=True)
454442
assert torch.equal(base_sel, sp_sel)
455443

456444
# Backward alignment (expert grads on active local experts for this slice).
@@ -609,7 +597,8 @@ def _run_worker_fsdp_sp_align(
609597
sp_embeds = embed_sp(input_ids).detach().requires_grad_(True)
610598
model_sp, _ = fsdp_strategy.wrap_model(model_sp, optimizer=None)
611599

612-
sp_label_inputs = {'labels': labels_raw, 'position_ids': position_ids}
600+
# Keep label semantics consistent with the baseline path: next-token aligned labels.
601+
sp_label_inputs = {'labels': labels_shifted, 'position_ids': position_ids}
613602
sp_label_inputs = sp_strategy.preprocess_inputs(sp_label_inputs)
614603
sp_local_labels = sp_label_inputs['labels']
615604

@@ -625,12 +614,6 @@ def _run_worker_fsdp_sp_align(
625614
sp_logits = sp_out.logits.detach()
626615

627616
# Forward alignment (full-seq logits reconstructed by SP gather).
628-
if not torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4):
629-
diff = (sp_logits - base_logits).abs()
630-
print(
631-
f'[rank{rank}] logits diff mean={diff.mean().item():.6e} max={diff.max().item():.6e}',
632-
flush=True,
633-
)
634617
assert torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4)
635618

636619
# Backward alignment: local CE(sum) on SP, compare gathered full-seq inputs_embeds grads.
@@ -651,11 +634,6 @@ def _run_worker_fsdp_sp_align(
651634
diff = sp_full - base_full
652635
rel = diff.norm() / (base_full.norm() + 1e-12)
653636
grad_rel_tol = float(os.environ.get('TWINKLE_INPUT_GRAD_REL_TOL', '1e-2'))
654-
if rel.item() > grad_rel_tol:
655-
abs_diff = diff.abs()
656-
raise AssertionError(
657-
f'[rank{rank}] inputs_embeds.grad(full) not close: mean_abs={abs_diff.mean().item():.6e} '
658-
f'max_abs={abs_diff.max().item():.6e} rel_norm={rel.item():.6e} tol={grad_rel_tol:.1e}')
659637
assert rel.item() <= grad_rel_tol
660638
finally:
661639
dist.destroy_process_group()

0 commit comments

Comments (0)