From 4170cf08e4c6db3d4691c34fa9500dd974026727 Mon Sep 17 00:00:00 2001
From: tastelikefeet
Date: Sun, 1 Mar 2026 21:28:47 +0800
Subject: [PATCH 1/3] fix

---
 src/twinkle/model/megatron/args.py              |  1 +
 src/twinkle/model/megatron/megatron.py          | 11 +++++------
 src/twinkle/model/megatron/strategy/megatron.py |  3 +--
 3 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/src/twinkle/model/megatron/args.py b/src/twinkle/model/megatron/args.py
index 80759d82..eacc10db 100644
--- a/src/twinkle/model/megatron/args.py
+++ b/src/twinkle/model/megatron/args.py
@@ -608,6 +608,7 @@ def _get_base_model(m):
         # Critical: Set finalize_model_grads_func for DP gradient synchronization
         # Uses custom wrapper that handles both DDP and PEFT/LoRA models
         finalize_model_grads_func=finalize_model_grads_for_lora,
+        calculate_per_token_loss=True,
         # MoE configuration
         **moe_kwargs,
     )
diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py
index 419e03bd..787a4bb4 100644
--- a/src/twinkle/model/megatron/megatron.py
+++ b/src/twinkle/model/megatron/megatron.py
@@ -527,17 +527,16 @@ def forward_step_func(data_iterator, model):
             if isinstance(loss_dict, dict):
                 if 'loss' in loss_dict:
                     loss += loss_dict['loss']
-                    count += 1
                 if 'logits' in loss_dict:
                     logits.append(loss_dict['logits'])
                 if 'logps' in loss_dict:
                     logps.append(loss_dict['logps'])
+                if 'num_tokens' in loss_dict:
+                    count += loss_dict['num_tokens']
             elif isinstance(loss_dict, torch.Tensor):
-                loss += loss_dict
-                count += 1
-
-            if count > 0:
-                loss /= count
+                raise ValueError('Expected loss dict, got tensor')
+
+            loss = loss / (count or 1)
 
             # For PP > 1, broadcast loss from last PP stage to all ranks
             # Note: mpu is imported at module level, no need to reimport
diff --git a/src/twinkle/model/megatron/strategy/megatron.py b/src/twinkle/model/megatron/strategy/megatron.py
index 758afb4e..f58b25c8 100644
--- a/src/twinkle/model/megatron/strategy/megatron.py
+++ b/src/twinkle/model/megatron/strategy/megatron.py
@@ -134,8 +134,7 @@ def _wrap_with_megatron_ddp(
         return wrapped_models
 
     def reduce_loss(self, local_loss, local_count, logits, logps):
-        loss = local_loss / local_count.clamp(min=1)
-        return loss, {'loss': loss.detach(), 'logits': logits.detach(), 'logps': logps.detach()}
+        return local_loss, local_count.clamp(min=1).to(torch.int64), {'loss': local_loss.detach(), 'logits': logits.detach(), 'logps': logps.detach(), 'num_tokens': local_count.clamp(min=1).to(torch.int64)}
 
     def get_model_config(
         self,

From 95686088eaaafdf59e34d74e3bbd1c9fbb37b5f7 Mon Sep 17 00:00:00 2001
From: tastelikefeet
Date: Sun, 1 Mar 2026 21:29:58 +0800
Subject: [PATCH 2/3] fix

---
 src/twinkle/model/megatron/megatron.py          | 2 +-
 src/twinkle/model/megatron/strategy/megatron.py | 7 ++++++-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py
index 787a4bb4..68e68f5a 100644
--- a/src/twinkle/model/megatron/megatron.py
+++ b/src/twinkle/model/megatron/megatron.py
@@ -535,7 +535,7 @@ def forward_step_func(data_iterator, model):
                     count += loss_dict['num_tokens']
             elif isinstance(loss_dict, torch.Tensor):
                 raise ValueError('Expected loss dict, got tensor')
-            
+
             loss = loss / (count or 1)
 
             # For PP > 1, broadcast loss from last PP stage to all ranks
diff --git a/src/twinkle/model/megatron/strategy/megatron.py b/src/twinkle/model/megatron/strategy/megatron.py
index f58b25c8..d96d4b41 100644
--- a/src/twinkle/model/megatron/strategy/megatron.py
+++ b/src/twinkle/model/megatron/strategy/megatron.py
@@ -134,7 +134,12 @@ def _wrap_with_megatron_ddp(
         return wrapped_models
 
     def reduce_loss(self, local_loss, local_count, logits, logps):
-        return local_loss, local_count.clamp(min=1).to(torch.int64), {'loss': local_loss.detach(), 'logits': logits.detach(), 'logps': logps.detach(), 'num_tokens': local_count.clamp(min=1).to(torch.int64)}
+        return local_loss, local_count.clamp(min=1).to(torch.int64), {
+            'loss': local_loss.detach(),
+            'logits': logits.detach(),
+            'logps': logps.detach(),
+            'num_tokens': local_count.clamp(min=1).to(torch.int64)
+        }
 
     def get_model_config(
         self,

From 2d10a9f5947d45f78afb856c40533bf415d033ac Mon Sep 17 00:00:00 2001
From: tastelikefeet
Date: Sun, 1 Mar 2026 21:30:49 +0800
Subject: [PATCH 3/3] fix

---
 src/twinkle/model/megatron/strategy/megatron.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/twinkle/model/megatron/strategy/megatron.py b/src/twinkle/model/megatron/strategy/megatron.py
index d96d4b41..362e65ae 100644
--- a/src/twinkle/model/megatron/strategy/megatron.py
+++ b/src/twinkle/model/megatron/strategy/megatron.py
@@ -134,11 +134,12 @@ def _wrap_with_megatron_ddp(
         return wrapped_models
 
     def reduce_loss(self, local_loss, local_count, logits, logps):
-        return local_loss, local_count.clamp(min=1).to(torch.int64), {
+        count = local_count.clamp(min=1).to(torch.int64)
+        return local_loss, count, {
             'loss': local_loss.detach(),
             'logits': logits.detach(),
             'logps': logps.detach(),
-            'num_tokens': local_count.clamp(min=1).to(torch.int64)
+            'num_tokens': count
         }
 
     def get_model_config(
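Reviewer note (commentary, not part of the patch series): the net effect of the three commits is to switch loss averaging from per-micro-batch to per-token. reduce_loss now returns the unreduced loss sum alongside a token count, forward_step_func accumulates both and divides once at the end, and calculate_per_token_loss=True asks Megatron-Core's config to defer loss scaling to the caller rather than averaging over micro-batches. Below is a minimal self-contained sketch of that accounting; accumulate_per_token_loss and the sample dicts are illustrative stand-ins, not the repo's API.

import torch

def accumulate_per_token_loss(loss_dicts):
    """Sketch of the per-token averaging these patches introduce.

    Each entry mirrors the dict reduce_loss() returns after patch 3: an
    unreduced 'loss' (a sum over tokens) plus the 'num_tokens' it covers.
    Losses and token counts are summed across micro-batches and divided
    once at the end, so a short sequence no longer carries the same
    weight as a long one.
    """
    loss = torch.tensor(0.0)
    count = torch.tensor(0, dtype=torch.int64)
    for loss_dict in loss_dicts:
        if not isinstance(loss_dict, dict):
            # Patch 1 turns the old bare-tensor fallback into a hard error.
            raise ValueError('Expected loss dict, got tensor')
        if 'loss' in loss_dict:
            loss = loss + loss_dict['loss']
        if 'num_tokens' in loss_dict:
            count = count + loss_dict['num_tokens']
    # The patch writes `loss / (count or 1)`: a 0-dim tensor supports
    # truthiness, so an all-padding batch falls back to dividing by 1.
    return loss / (count or 1)

# Two hypothetical micro-batches: 7 tokens summing to 14.0, 1 token at 0.5.
mb1 = {'loss': torch.tensor(14.0), 'num_tokens': torch.tensor(7)}
mb2 = {'loss': torch.tensor(0.5), 'num_tokens': torch.tensor(1)}
print(accumulate_per_token_loss([mb1, mb2]))
# tensor(1.8125) per token; the old per-micro-batch mean was (2.0 + 0.5) / 2 = 1.25

One design point worth noting: the division moved out of reduce_loss and into forward_step_func, presumably so the raw token count can travel with the loss (and be reduced across ranks or consumed by Megatron's gradient finalization) instead of being folded into a premature average.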