@@ -255,19 +255,22 @@ def _compare_grad_dicts(
         a = baseline.get(k)
         b = sp.get(k)
         if a is None or b is None:
-            raise AssertionError(f'[rank{rank}] Missing grad key={k} baseline={a is not None} sp={b is not None}')
+            raise AssertionError(
+                f'[rank{rank}] Missing grad key={k} baseline={a is not None} sp={b is not None} '
+                f'baseline_keys={len(baseline)} sp_keys={len(sp)}')
         a32 = a.to(dtype=torch.float32)
         b32 = b.to(dtype=torch.float32)
         diff = b32 - a32
         rel = diff.norm() / (a32.norm() + 1e-12)
         if rel.item() > rel_tol:
             abs_diff = diff.abs()
-            print(
-                f'[rank{rank}] {k} grad diff mean={abs_diff.mean().item():.6e} '
-                f'max={abs_diff.max().item():.6e} rel_norm={rel.item():.6e} tol={rel_tol:.1e}',
-                flush=True,
+            max_idx = int(abs_diff.reshape(-1).argmax().item())
+            raise AssertionError(
+                f'[rank{rank}] {k} grad not close: shape={tuple(a32.shape)} '
+                f'base_norm={a32.norm().item():.6e} sp_norm={b32.norm().item():.6e} '
+                f'mean_abs={abs_diff.mean().item():.6e} max_abs={abs_diff.max().item():.6e} '
+                f'max_flat_idx={max_idx} rel_norm={rel.item():.6e} tol={rel_tol:.1e}',
             )
-            assert rel.item() <= rel_tol


 def _run_worker_ep_fsdp_sp_align(
@@ -450,8 +453,9 @@ def _run_worker_ep_fsdp_sp_align(
             sp_sel = sp_sel.view(batch_size, end - start, -1)
             if not torch.equal(base_sel, sp_sel):
                 mismatch = (base_sel != sp_sel).sum().item()
-                print(f'[rank{rank}] block[{idx}] selected_experts mismatch count={mismatch}', flush=True)
-                assert torch.equal(base_sel, sp_sel)
+                raise AssertionError(
+                    f'[rank{rank}] block[{idx}] selected_experts mismatch count={mismatch} '
+                    f'base_shape={tuple(base_sel.shape)} sp_shape={tuple(sp_sel.shape)}')

         # Backward alignment (expert grads on active local experts for this slice).
         sp_loss_sum = F.cross_entropy(
@@ -627,11 +631,10 @@ def _run_worker_fsdp_sp_align(
         # Forward alignment (full-seq logits reconstructed by SP gather).
         if not torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4):
             diff = (sp_logits - base_logits).abs()
-            print(
-                f'[rank{rank}] logits diff mean={diff.mean().item():.6e} max={diff.max().item():.6e}',
-                flush=True,
+            raise AssertionError(
+                f'[rank{rank}] logits not close: mean_abs={diff.mean().item():.6e} '
+                f'max_abs={diff.max().item():.6e} (rtol=1e-3, atol=1e-4)',
             )
-            assert torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4)

         # Backward alignment: local CE(sum) on SP, compare gathered full-seq inputs_embeds grads.
         sp_loss_sum = F.cross_entropy(
@@ -656,7 +659,6 @@ def _run_worker_fsdp_sp_align(
             raise AssertionError(
                 f'[rank{rank}] inputs_embeds.grad(full) not close: mean_abs={abs_diff.mean().item():.6e} '
                 f'max_abs={abs_diff.max().item():.6e} rel_norm={rel.item():.6e} tol={grad_rel_tol:.1e}')
-        assert rel.item() <= grad_rel_tol
     finally:
         dist.destroy_process_group()
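
For reference, a minimal standalone sketch of the relative-norm criterion that `_compare_grad_dicts` applies to each gradient tensor, rel = ||sp - baseline|| / (||baseline|| + 1e-12), with the same raise-with-diagnostics behaviour introduced above. The helper name `rel_norm_check` and the 1e-3 default tolerance here are illustrative assumptions, not taken from the test file.

```python
# Minimal sketch; names and the default tolerance are assumptions for illustration.
import torch


def rel_norm_check(baseline: torch.Tensor, sp: torch.Tensor, rel_tol: float = 1e-3) -> None:
    a32 = baseline.to(dtype=torch.float32)
    b32 = sp.to(dtype=torch.float32)
    diff = b32 - a32
    # Relative L2 norm of the difference; the epsilon guards an all-zero baseline.
    rel = diff.norm() / (a32.norm() + 1e-12)
    if rel.item() > rel_tol:
        abs_diff = diff.abs()
        # Raising (instead of printing) puts the diagnostics into the test failure itself.
        raise AssertionError(
            f'grad not close: mean_abs={abs_diff.mean().item():.6e} '
            f'max_abs={abs_diff.max().item():.6e} rel_norm={rel.item():.6e} tol={rel_tol:.1e}')


x = torch.randn(4, 8)
rel_norm_check(x, x.clone())    # identical tensors pass
try:
    rel_norm_check(x, x + 0.1)  # roughly 10% relative perturbation exceeds the tolerance and raises
except AssertionError as exc:
    print(exc)
```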