diff --git a/src/twinkle/model/megatron/args.py b/src/twinkle/model/megatron/args.py
index 80759d82..eacc10db 100644
--- a/src/twinkle/model/megatron/args.py
+++ b/src/twinkle/model/megatron/args.py
@@ -608,6 +608,7 @@ def _get_base_model(m):
         # Critical: Set finalize_model_grads_func for DP gradient synchronization
         # Uses custom wrapper that handles both DDP and PEFT/LoRA models
         finalize_model_grads_func=finalize_model_grads_for_lora,
+        calculate_per_token_loss=True,
         # MoE configuration
         **moe_kwargs,
     )
diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py
index 419e03bd..68e68f5a 100644
--- a/src/twinkle/model/megatron/megatron.py
+++ b/src/twinkle/model/megatron/megatron.py
@@ -527,17 +527,16 @@ def forward_step_func(data_iterator, model):
         if isinstance(loss_dict, dict):
             if 'loss' in loss_dict:
                 loss += loss_dict['loss']
-                count += 1
             if 'logits' in loss_dict:
                 logits.append(loss_dict['logits'])
             if 'logps' in loss_dict:
                 logps.append(loss_dict['logps'])
+            if 'num_tokens' in loss_dict:
+                count += loss_dict['num_tokens']
         elif isinstance(loss_dict, torch.Tensor):
-            loss += loss_dict
-            count += 1
+            raise ValueError('Expected loss dict, got tensor')
 
-    if count > 0:
-        loss /= count
+    loss = loss / (count or 1)
 
     # For PP > 1, broadcast loss from last PP stage to all ranks
     # Note: mpu is imported at module level, no need to reimport
diff --git a/src/twinkle/model/megatron/strategy/megatron.py b/src/twinkle/model/megatron/strategy/megatron.py
index 758edf82..362e65ae 100644
--- a/src/twinkle/model/megatron/strategy/megatron.py
+++ b/src/twinkle/model/megatron/strategy/megatron.py
@@ -134,8 +134,13 @@ def _wrap_with_megatron_ddp(
         return wrapped_models
 
     def reduce_loss(self, local_loss, local_count, logits, logps):
-        loss = local_loss / local_count.clamp(min=1)
-        return loss, {'loss': loss.detach(), 'logits': logits.detach(), 'logps': logps.detach()}
+        count = local_count.clamp(min=1).to(torch.int64)
+        return local_loss, count, {
+            'loss': local_loss.detach(),
+            'logits': logits.detach(),
+            'logps': logps.detach(),
+            'num_tokens': count
+        }
 
     def get_model_config(
         self,
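
Why the diff divides by token count rather than micro-batch count: averaging per-micro-batch means weights every micro-batch equally, so tokens in a short (heavily padded or truncated) micro-batch count more toward the loss than tokens in a long one. A minimal sketch of the difference, with made-up numbers that are not from this repo:

    import torch

    # Two micro-batches with very different token counts (illustrative values).
    mb1 = torch.tensor([2.0, 2.0, 2.0, 2.0])  # 4 tokens, per-token loss 2.0
    mb2 = torch.tensor([8.0])                  # 1 token, per-token loss 8.0

    # Old scheme: count += 1 per micro-batch -> mean of per-micro-batch means.
    per_microbatch = (mb1.mean() + mb2.mean()) / 2                      # 5.0

    # New scheme: count += num_tokens -> sum of losses / total tokens.
    per_token = (mb1.sum() + mb2.sum()) / (mb1.numel() + mb2.numel())   # 3.2

    print(per_microbatch.item(), per_token.item())  # 5.0 vs 3.2 (true per-token mean)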
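For context on the reshaped reduce_loss return value: with calculate_per_token_loss=True, Megatron-LM's pipeline schedules expect the per-micro-batch loss function to return an unreduced loss sum and a token count (rather than a pre-averaged scalar), and finalize_model_grads then normalizes gradients by the token count summed across micro-batches and data-parallel ranks. The (loss, num_tokens, metrics) triple above matches that shape. A hedged sketch of such a loss function, illustrative only and not this repo's actual implementation:

    import torch

    def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
        # Hypothetical loss function for Megatron's forward_backward_func
        # under calculate_per_token_loss=True; all names are illustrative.
        losses = output_tensor.float().view(-1)
        mask = loss_mask.view(-1).float()
        loss = torch.sum(losses * mask)           # unreduced sum over tokens
        num_tokens = mask.sum().to(torch.int64)   # tokens that contributed
        # The schedule accumulates num_tokens and divides once at the end,
        # instead of averaging per micro-batch.
        return loss, num_tokens, {'loss': loss.detach(), 'num_tokens': num_tokens}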