Skip to content

Commit 3982a74

Browse files
Fix megatron loss (#90)
1 parent f682b5d commit 3982a74

File tree

3 files changed

+12
-7
lines changed

3 files changed

+12
-7
lines changed

src/twinkle/model/megatron/args.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,7 @@ def _get_base_model(m):
608608
# Critical: Set finalize_model_grads_func for DP gradient synchronization
609609
# Uses custom wrapper that handles both DDP and PEFT/LoRA models
610610
finalize_model_grads_func=finalize_model_grads_for_lora,
611+
calculate_per_token_loss=True,
611612
# MoE configuration
612613
**moe_kwargs,
613614
)

src/twinkle/model/megatron/megatron.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -527,17 +527,16 @@ def forward_step_func(data_iterator, model):
527527
if isinstance(loss_dict, dict):
528528
if 'loss' in loss_dict:
529529
loss += loss_dict['loss']
530-
count += 1
531530
if 'logits' in loss_dict:
532531
logits.append(loss_dict['logits'])
533532
if 'logps' in loss_dict:
534533
logps.append(loss_dict['logps'])
534+
if 'num_tokens' in loss_dict:
535+
count += loss_dict['num_tokens']
535536
elif isinstance(loss_dict, torch.Tensor):
536-
loss += loss_dict
537-
count += 1
537+
raise ValueError('Expected loss dict, got tensor')
538538

539-
if count > 0:
540-
loss /= count
539+
loss = loss / (count or 1)
541540

542541
# For PP > 1, broadcast loss from last PP stage to all ranks
543542
# Note: mpu is imported at module level, no need to reimport

src/twinkle/model/megatron/strategy/megatron.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,13 @@ def _wrap_with_megatron_ddp(
134134
return wrapped_models
135135

136136
def reduce_loss(self, local_loss, local_count, logits, logps):
137-
loss = local_loss / local_count.clamp(min=1)
138-
return loss, {'loss': loss.detach(), 'logits': logits.detach(), 'logps': logps.detach()}
137+
count = local_count.clamp(min=1).to(torch.int64)
138+
return local_loss, count, {
139+
'loss': local_loss.detach(),
140+
'logits': logits.detach(),
141+
'logps': logps.detach(),
142+
'num_tokens': count
143+
}
139144

140145
def get_model_config(
141146
self,

0 commit comments

Comments (0)