From 3a4203b50b39754e08bc2d7da02b17b1126ac608 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 1 Mar 2026 19:03:54 +0800 Subject: [PATCH] fix logps under pipeline parallelism: compute only on the last PP stage --- cookbook/megatron/tp.py | 3 --- src/twinkle/model/megatron/megatron.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/cookbook/megatron/tp.py b/cookbook/megatron/tp.py index 8bf4525c..b09d1a60 100644 --- a/cookbook/megatron/tp.py +++ b/cookbook/megatron/tp.py @@ -8,9 +8,6 @@ from twinkle.dataset import Dataset, DatasetMeta from twinkle.model import MegatronModel from twinkle.preprocessor import SelfCognitionProcessor -from twinkle.server.tinker.common import input_feature_to_datum -from twinkle.server.tinker.common.compat_base import TwinkleCompatModelBase - # Construct a device_mesh, tp=pp=cp=2, dp=1 device_mesh = DeviceMesh.from_sizes(dp_size=1, tp_size=2, pp_size=2, cp_size=2) # use torchrun mode diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 22977721..419e03bd 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -481,7 +481,7 @@ def forward_step_func(data_iterator, model): output_tensor = model(**batch) batch['labels'] = labels logps = None - if labels is not None: + if labels is not None and mpu.is_pipeline_last_stage(): loss_mask = (labels != -100).bool() masked_labels = labels.clone() masked_labels[~loss_mask] = 0