Commit c0bfaef

loss metric fix

1 parent 4937e53
2 files changed: +13 −9 lines

src/twinkle/model/transformers/strategy/sequence_parallel.py (+7 −2)

@@ -1041,6 +1041,7 @@ def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore
         )
         compensate_fsdp_avg = bool(self.sp_config.get("compensate_fsdp_avg", False))
         compensate_factor = float(self.ulysses_size if compensate_fsdp_avg else 1.0)
+        sum_metric_scale = float(self.ulysses_size)
         loss_in = loss.detach()
 
         class _ReduceSequenceParallelLoss(torch.autograd.Function):
@@ -1071,13 +1072,17 @@ def backward(ctx, grad_output: torch.Tensor):
         class _ReduceSequenceParallelSum(torch.autograd.Function):
             @staticmethod
             def forward(ctx, local_sum: torch.Tensor) -> torch.Tensor:
+                ctx.sum_metric_scale = sum_metric_scale
                 global_sum = local_sum.detach().clone()
                 dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
-                return global_sum
+                # Keep logging/metric value aligned with non-SP sum semantics under
+                # outer collect='mean' by removing one SP replication factor.
+                return global_sum / ctx.sum_metric_scale
 
             @staticmethod
             def backward(ctx, grad_output: torch.Tensor):
-                return grad_output
+                # Preserve original gradient scale (before forward-side metric scaling).
+                return grad_output * ctx.sum_metric_scale
 
         if reduction == "sum":
             out = _ReduceSequenceParallelSum.apply(loss)
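
This change works because a custom torch.autograd.Function decouples the value a forward pass reports from the gradient rule applied in backward: the forward can divide by the SP factor for metric purposes while the backward applies its own scaling to the incoming gradient. Below is a minimal single-process sketch of that mechanism; the `SCALE` constant and `_ScaledForLogging` name are illustrative stand-ins (the real code also all-reduces over the SP group, which is omitted here so the sketch runs anywhere).

```python
import torch

SCALE = 4.0  # stands in for ulysses_size; purely illustrative


class _ScaledForLogging(torch.autograd.Function):
    """Report a descaled value in forward; apply a custom scale in backward."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        ctx.scale = SCALE
        # The value seen by metrics/logging drops one factor of SCALE.
        return x / ctx.scale

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # The backward rule is independent of the forward arithmetic:
        # multiply the incoming gradient by SCALE instead of the 1/SCALE
        # that ordinary differentiation of the division would apply.
        return grad_output * ctx.scale


x = torch.tensor(8.0, requires_grad=True)
y = _ScaledForLogging.apply(x)
y.backward()
print(y.item())       # 2.0  -> descaled forward value
print(x.grad.item())  # 4.0  -> grad_output * SCALE, not 1/SCALE
```

The point of the pattern is exactly this decoupling: the commit can correct the logged sum without handing the optimizer the 1/SCALE-shrunken gradient that honest differentiation of the division would produce.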

src/twinkle/model/transformers/transformers.py (+6 −7)

@@ -80,13 +80,13 @@ def __post_init__(self):
 
     def _build_metrics(self):
         self.train_metrics = [
-            LossMetric(self._device_mesh, self._dp_group, loss_reduction='mean'),
+            LossMetric(self._device_mesh, self._dp_group, loss_reduction='sum'),
             Accuracy(self._device_mesh, self._dp_group),
             TrainMetric(self._device_mesh, self._dp_group),
         ]
 
         self.eval_metrics = [
-            LossMetric(self._device_mesh, self._dp_group, loss_reduction='mean'),
+            LossMetric(self._device_mesh, self._dp_group, loss_reduction='sum'),
             Accuracy(self._device_mesh, self._dp_group),
             TrainMetric(self._device_mesh, self._dp_group),
         ]
@@ -317,7 +317,7 @@ def _ensure_optimizer_dp_groups(self):
 
     def _construct_default_optimizer_group(self):
         return OptimizerGroup(
-            loss_instance=CrossEntropyLoss(reduction='mean'),
+            loss_instance=CrossEntropyLoss(reduction='sum'),
             template=Template(self.tokenizer_id),
             processor=InputProcessor(self.device_mesh),
             _device_mesh=self.device_mesh,
@@ -431,10 +431,9 @@ def calculate_loss(self, **kwargs):
         optimizer_config = self.optimizer_group[adapter_name]
         optimizer_config.num_tokens += counts.item()
         if self.sp_strategy is not None and 'labels' in inputs:
-            if "loss_reduction" not in self.sp_strategy.sp_config:
-                reduction = getattr(loss_instance, "reduction", None)
-                if reduction is not None:
-                    self.sp_strategy.sp_config["loss_reduction"] = str(reduction)
+            reduction = getattr(loss_instance, "reduction", None)
+            if reduction is not None:
+                self.sp_strategy.sp_config["loss_reduction"] = str(reduction)
             loss_value = self.sp_strategy.reduce_loss(loss_value, inputs['labels'])
         optimizer_config.loss_value += loss_value
         outputs['loss'] = optimizer_config.loss_value
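
Switching LossMetric and CrossEntropyLoss from 'mean' to 'sum' lets the trainer carry loss sums and token counts separately (note the `optimizer_config.num_tokens` accumulation above) and form the mean once, globally. The sketch below shows numerically why that matters when ranks see unequal token counts; the `rank_stats` tuples and the aggregation are illustrative, not the actual LossMetric internals.

```python
# Hypothetical per-rank (loss_sum, token_count) pairs.
rank_stats = [(8.0, 4), (3.0, 1)]  # rank 0: 4 tokens, rank 1: 1 token

# Averaging per-rank means weights every rank equally, which skews the
# result whenever ranks see different token counts:
mean_of_means = sum(s / n for s, n in rank_stats) / len(rank_stats)
print(mean_of_means)  # 2.5

# Reducing loss sums and token counts separately, then dividing once,
# recovers the exact token-weighted global mean:
global_mean = sum(s for s, _ in rank_stats) / sum(n for _, n in rank_stats)
print(global_mean)    # 2.2
```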
