Skip to content

Commit 7a783d4

Browse files
committed
wip
1 parent d6dd5c3 commit 7a783d4

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

tests/moe/test_expert_parallel_qwen3_fsdp_sp.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -417,7 +417,8 @@ def _run_worker_ep_fsdp_sp_align(
417 417
model_sp, _ = fsdp_strategy.wrap_model(model_sp, optimizer=None)
418 418

419 419
# Preprocess labels through SP strategy so they are shifted + split consistently.
420-
sp_label_inputs = {'labels': labels_raw, 'position_ids': position_ids}
420+
# Keep label semantics consistent with the baseline path: next-token aligned labels.
421+
sp_label_inputs = {'labels': labels_shifted, 'position_ids': position_ids}
421 422
sp_label_inputs = sp_strategy.preprocess_inputs(sp_label_inputs)
422 423
sp_local_labels = sp_label_inputs['labels']
423 424

@@ -613,7 +614,8 @@ def _run_worker_fsdp_sp_align(
613 614
sp_embeds = embed_sp(input_ids).detach().requires_grad_(True)
614 615
model_sp, _ = fsdp_strategy.wrap_model(model_sp, optimizer=None)
615 616

616-
sp_label_inputs = {'labels': labels_raw, 'position_ids': position_ids}
617+
# Keep label semantics consistent with the baseline path: next-token aligned labels.
618+
sp_label_inputs = {'labels': labels_shifted, 'position_ids': position_ids}
617 619
sp_label_inputs = sp_strategy.preprocess_inputs(sp_label_inputs)
618 620
sp_local_labels = sp_label_inputs['labels']
619 621

0 commit comments

Comments
 (0)