
Commit a0b9ece

feat(sequence_parallel): refactor loss reduction using custom autograd functions
Replace manual gradient handling with `torch.autograd.Function` subclasses `_ReduceSequenceParallelLoss` and `_ReduceSequenceParallelSum` to compute global loss via autograd-aware all-reduce. This simplifies the logic for both sum and mean reductions, improves gradient correctness, and removes the need for separate metric scaling when `world_size > 1`.
1 parent bcfb465 commit a0b9ece
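
Why the mean path reduces token sums rather than averaging per-rank means: with sequence parallelism, ranks can hold different numbers of valid (non-ignored) tokens, so a plain average of local means would over-weight short shards. A small numeric sketch follows; the values and variable names are illustrative only, not taken from the repository.

# Illustrative example only (hypothetical values): why the mean path converts
# local means back to token sums before all-reducing.
local_means = [2.0, 4.0]   # per-rank mean loss over that rank's valid tokens
local_tokens = [3, 1]      # per-rank count of valid (non-ignored) tokens

# Averaging the per-rank means ignores the token counts and over-weights rank 1.
naive_mean = sum(local_means) / len(local_means)                    # 3.0

# What the new code computes: all-reduce the per-rank sums and token counts, then divide.
global_sum = sum(m * t for m, t in zip(local_means, local_tokens))  # 6.0 + 4.0 = 10.0
global_tokens = sum(local_tokens)                                   # 4
global_mean = global_sum / global_tokens                            # 2.5

print(naive_mean, global_mean)  # 3.0 2.5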

File tree

1 file changed: 41 additions, 22 deletions


src/twinkle/model/transformers/strategy/sequence_parallel.py

Lines changed: 41 additions & 22 deletions
@@ -990,36 +990,55 @@ def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore
             return loss
         if labels is None or sequence_parallel._sp_group is None:
             return loss
-        # Compute full-sequence loss in forward, but keep backward local to this rank.
+        # Compute global loss via autograd-aware all-reduce.
         reduction = str(self.sp_config.get("loss_reduction", "mean")).lower()
         if reduction == "none":
             raise ValueError(
                 "SequenceParallelStrategy.reduce_loss only supports reduction='sum' or 'mean'. "
                 "Please aggregate per-token losses before calling reduce_loss."
             )
-        num_valid_tokens = (labels != ignore_index).sum().to(loss.device)
+
+        class _ReduceSequenceParallelLoss(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, local_sum: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor:
+                if num_valid_tokens.item() == 0:
+                    local_sum = torch.nan_to_num(local_sum)
+                local_tokens = num_valid_tokens.detach().clone()
+                global_sum = local_sum.detach().clone()
+                dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
+                global_tokens = num_valid_tokens.detach().clone()
+                dist.all_reduce(global_tokens, group=sequence_parallel._sp_group)
+                ctx.save_for_backward(local_tokens, global_tokens)
+                if global_tokens.item() == 0:
+                    return local_sum
+                return global_sum / global_tokens
+
+            @staticmethod
+            def backward(ctx, grad_output: torch.Tensor):
+                local_tokens, global_tokens = ctx.saved_tensors
+                if global_tokens.item() == 0:
+                    return grad_output, None
+                grad_local_sum = grad_output * (local_tokens / global_tokens)
+                return grad_local_sum, None
+
+        class _ReduceSequenceParallelSum(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, local_sum: torch.Tensor) -> torch.Tensor:
+                global_sum = local_sum.detach().clone()
+                dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
+                return global_sum
+
+            @staticmethod
+            def backward(ctx, grad_output: torch.Tensor):
+                return grad_output
+
         if reduction == "sum":
-            local_sum = loss
-            global_sum = local_sum.detach().clone()
-            dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
-            out = global_sum + (local_sum - local_sum.detach())
-            if sequence_parallel.world_size > 1:
-                out_metric = out.detach() / sequence_parallel.world_size
-                return out_metric + (out - out.detach())
-            return out
-        # Default to mean reduction.
+            return _ReduceSequenceParallelSum.apply(loss)
+
+        # Default to mean reduction: assume `loss` is local mean, convert to local sum.
+        num_valid_tokens = (labels != ignore_index).sum().to(loss.device)
         local_sum = loss * num_valid_tokens
-        global_sum = local_sum.detach().clone()
-        dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
-        global_tokens = num_valid_tokens.detach().clone()
-        dist.all_reduce(global_tokens, group=sequence_parallel._sp_group)
-        if global_tokens.item() == 0:
-            return loss
-        out = (global_sum + (local_sum - local_sum.detach())) / global_tokens
-        if sequence_parallel.world_size > 1:
-            out_metric = out.detach() / sequence_parallel.world_size
-            return out_metric + (out - out.detach())
-        return out
+        return _ReduceSequenceParallelLoss.apply(local_sum, num_valid_tokens)

     def wrap_model(self, model, optimizer=None):
         self.initialize()
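
For context, the `torch.autograd.Function` pattern used above couples a collective in `forward` with an explicit rule in `backward` for how gradients flow back to each rank's local contribution. Below is a minimal, self-contained sketch of that pattern; the class name `_AllReduceSumSketch` is hypothetical, the snippet falls back to single-process behaviour when no process group is initialized, and it is not the repository's implementation (that is the `_ReduceSequenceParallelLoss` / `_ReduceSequenceParallelSum` code in the diff above).

# Minimal sketch of an autograd-aware all-reduce (illustrative only).
import torch
import torch.distributed as dist


class _AllReduceSumSketch(torch.autograd.Function):
    """Sum a tensor across ranks in forward; pass gradients through locally in backward."""

    @staticmethod
    def forward(ctx, local_value: torch.Tensor) -> torch.Tensor:
        global_value = local_value.detach().clone()
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(global_value, op=dist.ReduceOp.SUM)
        return global_value

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # No communication in backward: each rank keeps the gradient of its own contribution.
        return grad_output


if __name__ == "__main__":
    # Single-process check: with one rank the "global" sum equals the local value,
    # and the gradient w.r.t. x chains through the multiplication as usual.
    x = torch.tensor(2.0, requires_grad=True)
    y = _AllReduceSumSketch.apply(3.0 * x)
    y.backward()
    print(y.item(), x.grad.item())  # 6.0 3.0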
