@@ -260,13 +260,6 @@ def _compare_grad_dicts(
260260 b32 = b .to (dtype = torch .float32 )
261261 diff = b32 - a32
262262 rel = diff .norm () / (a32 .norm () + 1e-12 )
263- if rel .item () > rel_tol :
264- abs_diff = diff .abs ()
265- print (
266- f'[rank{ rank } ] { k } grad diff mean={ abs_diff .mean ().item ():.6e} '
267- f'max={ abs_diff .max ().item ():.6e} rel_norm={ rel .item ():.6e} tol={ rel_tol :.1e} ' ,
268- flush = True ,
269- )
270263 assert rel .item () <= rel_tol
271264
272265
@@ -414,7 +407,8 @@ def _run_worker_ep_fsdp_sp_align(
414407 model_sp , _ = fsdp_strategy .wrap_model (model_sp , optimizer = None )
415408
416409 # Preprocess labels through SP strategy so they are shifted + split consistently.
417- sp_label_inputs = {'labels' : labels_raw , 'position_ids' : position_ids }
410+ # Keep label semantics consistent with the baseline path: next-token aligned labels.
411+ sp_label_inputs = {'labels' : labels_shifted , 'position_ids' : position_ids }
418412 sp_label_inputs = sp_strategy .preprocess_inputs (sp_label_inputs )
419413 sp_local_labels = sp_label_inputs ['labels' ]
420414
@@ -433,10 +427,7 @@ def _run_worker_ep_fsdp_sp_align(
433427 sp_logits = sp_out .logits .detach ()
434428
435429 # Forward alignment (full-seq logits reconstructed by SP gather).
436- if not torch .allclose (sp_logits , base_logits , rtol = 1e-3 , atol = 1e-4 ):
437- diff = (sp_logits - base_logits ).abs ()
438- raise AssertionError (f'[rank{ rank } ] logits not close: mean_abs={ diff .mean ().item ():.6e} '
439- f'max_abs={ diff .max ().item ():.6e} (rtol=1e-3, atol=1e-4)' )
430+ assert torch .allclose (sp_logits , base_logits , rtol = 1e-3 , atol = 1e-4 )
440431
441432 # Router alignment on this rank's slice: compare selected experts exactly.
442433 # SP captures only local tokens; baseline captures full tokens (we slice it).
@@ -448,9 +439,6 @@ def _run_worker_ep_fsdp_sp_align(
448439 sp_sel = sp_state ['selected_experts' ]
449440 if sp_sel .dim () == 2 :
450441 sp_sel = sp_sel .view (batch_size , end - start , - 1 )
451- if not torch .equal (base_sel , sp_sel ):
452- mismatch = (base_sel != sp_sel ).sum ().item ()
453- print (f'[rank{ rank } ] block[{ idx } ] selected_experts mismatch count={ mismatch } ' , flush = True )
454442 assert torch .equal (base_sel , sp_sel )
455443
456444 # Backward alignment (expert grads on active local experts for this slice).
@@ -609,7 +597,8 @@ def _run_worker_fsdp_sp_align(
609597 sp_embeds = embed_sp (input_ids ).detach ().requires_grad_ (True )
610598 model_sp , _ = fsdp_strategy .wrap_model (model_sp , optimizer = None )
611599
612- sp_label_inputs = {'labels' : labels_raw , 'position_ids' : position_ids }
600+ # Keep label semantics consistent with the baseline path: next-token aligned labels.
601+ sp_label_inputs = {'labels' : labels_shifted , 'position_ids' : position_ids }
613602 sp_label_inputs = sp_strategy .preprocess_inputs (sp_label_inputs )
614603 sp_local_labels = sp_label_inputs ['labels' ]
615604
@@ -625,12 +614,6 @@ def _run_worker_fsdp_sp_align(
625614 sp_logits = sp_out .logits .detach ()
626615
627616 # Forward alignment (full-seq logits reconstructed by SP gather).
628- if not torch .allclose (sp_logits , base_logits , rtol = 1e-3 , atol = 1e-4 ):
629- diff = (sp_logits - base_logits ).abs ()
630- print (
631- f'[rank{ rank } ] logits diff mean={ diff .mean ().item ():.6e} max={ diff .max ().item ():.6e} ' ,
632- flush = True ,
633- )
634617 assert torch .allclose (sp_logits , base_logits , rtol = 1e-3 , atol = 1e-4 )
635618
636619 # Backward alignment: local CE(sum) on SP, compare gathered full-seq inputs_embeds grads.
@@ -651,11 +634,6 @@ def _run_worker_fsdp_sp_align(
651634 diff = sp_full - base_full
652635 rel = diff .norm () / (base_full .norm () + 1e-12 )
653636 grad_rel_tol = float (os .environ .get ('TWINKLE_INPUT_GRAD_REL_TOL' , '1e-2' ))
654- if rel .item () > grad_rel_tol :
655- abs_diff = diff .abs ()
656- raise AssertionError (
657- f'[rank{ rank } ] inputs_embeds.grad(full) not close: mean_abs={ abs_diff .mean ().item ():.6e} '
658- f'max_abs={ abs_diff .max ().item ():.6e} rel_norm={ rel .item ():.6e} tol={ grad_rel_tol :.1e} ' )
659637 assert rel .item () <= grad_rel_tol
660638 finally :
661639 dist .destroy_process_group ()
0 commit comments