File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree
Original file line number | Diff line number | Diff line change
@@ -417,7 +417,8 @@ def _run_worker_ep_fsdp_sp_align(
417417 model_sp , _ = fsdp_strategy .wrap_model (model_sp , optimizer = None )
418418
419419 # Preprocess labels through SP strategy so they are shifted + split consistently.
420- sp_label_inputs = {'labels' : labels_raw , 'position_ids' : position_ids }
420+ # Keep label semantics consistent with the baseline path: next-token aligned labels.
421+ sp_label_inputs = {'labels' : labels_shifted , 'position_ids' : position_ids }
421422 sp_label_inputs = sp_strategy .preprocess_inputs (sp_label_inputs )
422423 sp_local_labels = sp_label_inputs ['labels' ]
423424
@@ -613,7 +614,8 @@ def _run_worker_fsdp_sp_align(
613614 sp_embeds = embed_sp (input_ids ).detach ().requires_grad_ (True )
614615 model_sp , _ = fsdp_strategy .wrap_model (model_sp , optimizer = None )
615616
616- sp_label_inputs = {'labels' : labels_raw , 'position_ids' : position_ids }
617+ # Keep label semantics consistent with the baseline path: next-token aligned labels.
618+ sp_label_inputs = {'labels' : labels_shifted , 'position_ids' : position_ids }
617619 sp_label_inputs = sp_strategy .preprocess_inputs (sp_label_inputs )
618620 sp_local_labels = sp_label_inputs ['labels' ]
619621
You can’t perform that action at this time.
0 commit comments