Skip to content

Commit 6bdaaca

Browse files
committed
wip
1 parent aa86099 commit 6bdaaca

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

src/twinkle/model/megatron/megatron.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,12 @@ def post_loss_function(output_tensor, inputs, logps):
479479
losses = result['loss']
480480
counts = result['num_tokens']
481481
if not counts:
482+
# This value will be gathered later, so it becomes:
483+
# 1. SUM loss: gather_sum(local_num_tokens) = global_num_tokens
484+
# 2. PER TOKEN MEAN loss: gather_sum(1 * gradient_accumulation_steps) = gradient_accumulation_steps * world_size
485+
# Then, grad will be divided by this value:
486+
# 1. SUM loss: (global_sum_grad) / (global_num_tokens) = global_sum_grad/global_num_tokens
487+
# 2. PER TOKEN MEAN loss: (gather_sum(per_token_grad * gradient_accumulation_steps)) / (gradient_accumulation_steps * world_size) = avg_per_token_grad
482488
counts = torch.tensor(1, device=losses.device)
483489
return self.strategy.reduce_loss(losses, counts, output_tensor, logps)
484490

src/twinkle/model/transformers/transformers.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,18 @@ def calculate_loss(self, **kwargs):
500500
loss_value = result['loss']
501501
counts = result['num_tokens']
502502
if not counts:
503-
counts = torch.tensor(0, device=loss_value.device)
503+
counts = torch.tensor(1, device=loss_value.device)
504+
# This value will be gathered later, so it becomes:
505+
# 1. SUM loss: gather_sum(local_num_tokens / dp_world_size) = global_num_tokens / dp_world_size
506+
# 2. PER TOKEN MEAN loss: gather_sum(1 * gradient_accumulation_steps / dp_world_size) = gradient_accumulation_steps
507+
# Then, grad will be divided by this value:
508+
# 1. SUM loss: gather_mean(local_sum_grad) / (global_num_tokens / dp_world_size)
509+
# = (global_sum_grad / dp_world_size) / (global_num_tokens / dp_world_size)
510+
# = global_sum_grad/global_num_tokens
511+
# 2. PER TOKEN MEAN loss: gather_mean(per_token_grad * gradient_accumulation_steps) / gradient_accumulation_steps
512+
# = (global_per_token_grad * gradient_accumulation_steps / dp_world_size ) / gradient_accumulation_steps
513+
# = global_per_token_grad / dp_world_size = avg_per_token_grad
514+
counts = counts / self.device_mesh.data_world_size
504515
optimizer_config = self.optimizer_group[adapter_name]
505516
optimizer_config.num_tokens += counts.item()
506517
if self.sp_strategy is not None and 'labels' in inputs:

0 commit comments

Comments
 (0)