@@ -990,36 +990,55 @@ def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore
             return loss
         if labels is None or sequence_parallel._sp_group is None:
             return loss
-        # Compute full-sequence loss in forward, but keep backward local to this rank.
+        # Compute global loss via autograd-aware all-reduce.
         reduction = str(self.sp_config.get("loss_reduction", "mean")).lower()
         if reduction == "none":
             raise ValueError(
                 "SequenceParallelStrategy.reduce_loss only supports reduction='sum' or 'mean'. "
                 "Please aggregate per-token losses before calling reduce_loss."
             )
-        num_valid_tokens = (labels != ignore_index).sum().to(loss.device)
+
+        class _ReduceSequenceParallelLoss(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, local_sum: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor:
+                if num_valid_tokens.item() == 0:
+                    local_sum = torch.nan_to_num(local_sum)
+                local_tokens = num_valid_tokens.detach().clone()
+                global_sum = local_sum.detach().clone()
+                dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
+                global_tokens = num_valid_tokens.detach().clone()
+                dist.all_reduce(global_tokens, group=sequence_parallel._sp_group)
+                ctx.save_for_backward(local_tokens, global_tokens)
+                if global_tokens.item() == 0:
+                    return local_sum
+                return global_sum / global_tokens
+
+            @staticmethod
+            def backward(ctx, grad_output: torch.Tensor):
+                local_tokens, global_tokens = ctx.saved_tensors
+                if global_tokens.item() == 0:
+                    return grad_output, None
+                grad_local_sum = grad_output * (local_tokens / global_tokens)
+                return grad_local_sum, None
+
+        class _ReduceSequenceParallelSum(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, local_sum: torch.Tensor) -> torch.Tensor:
+                global_sum = local_sum.detach().clone()
+                dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
+                return global_sum
+
+            @staticmethod
+            def backward(ctx, grad_output: torch.Tensor):
+                return grad_output
+
         if reduction == "sum":
-            local_sum = loss
-            global_sum = local_sum.detach().clone()
-            dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
-            out = global_sum + (local_sum - local_sum.detach())
-            if sequence_parallel.world_size > 1:
-                out_metric = out.detach() / sequence_parallel.world_size
-                return out_metric + (out - out.detach())
-            return out
-        # Default to mean reduction.
+            return _ReduceSequenceParallelSum.apply(loss)
+
+        # Default to mean reduction: assume `loss` is local mean, convert to local sum.
+        num_valid_tokens = (labels != ignore_index).sum().to(loss.device)
         local_sum = loss * num_valid_tokens
-        global_sum = local_sum.detach().clone()
-        dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
-        global_tokens = num_valid_tokens.detach().clone()
-        dist.all_reduce(global_tokens, group=sequence_parallel._sp_group)
-        if global_tokens.item() == 0:
-            return loss
-        out = (global_sum + (local_sum - local_sum.detach())) / global_tokens
-        if sequence_parallel.world_size > 1:
-            out_metric = out.detach() / sequence_parallel.world_size
-            return out_metric + (out - out.detach())
-        return out
+        return _ReduceSequenceParallelLoss.apply(local_sum, num_valid_tokens)
 
     def wrap_model(self, model, optimizer=None):
         self.initialize()
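
As a sanity check on the mean path, here is a standalone sketch (not part of the patch) of the arithmetic `_ReduceSequenceParallelLoss.forward` performs: two tensor halves stand in for two sequence-parallel ranks, and plain sums stand in for the `dist.all_reduce` calls over `sequence_parallel._sp_group`. The tensor values, labels, and the two-shard split are invented for illustration.

    import torch

    ignore_index = -100
    per_token_loss = torch.tensor([0.5, 1.5, 99.0, 2.0, 4.0, 99.0])
    labels = torch.tensor([7, 7, -100, 7, 7, -100])

    # Split the sequence into two "rank-local" shards.
    shards = [(per_token_loss[:3], labels[:3]), (per_token_loss[3:], labels[3:])]

    local_sums, local_tokens = [], []
    for shard_loss, shard_labels in shards:
        valid = shard_labels != ignore_index
        n = valid.sum()
        local_mean = shard_loss[valid].sum() / n   # what each rank passes in as `loss`
        local_sums.append(local_mean * n)          # local_sum = loss * num_valid_tokens
        local_tokens.append(n)

    global_sum = torch.stack(local_sums).sum()       # stands in for dist.all_reduce(global_sum)
    global_tokens = torch.stack(local_tokens).sum()  # stands in for dist.all_reduce(global_tokens)
    print(global_sum / global_tokens)                # tensor(2.) -- token-weighted global mean

Positions whose label equals `ignore_index` (loss 99.0 above) never enter either rank's local sum, so the returned value is the mean over all valid tokens in the full sequence, not the average of the two rank-local means.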