@@ -788,7 +788,8 @@ def pad_and_split_inputs(self,
         # - In next-token-aligned labels, this appears at labels[b-1]
         boundary_starts = (real_position_ids == 0)
         prev = torch.zeros_like(boundary_starts, dtype=torch.bool)
-        prev[..., 1:] = boundary_starts[..., :-1]
+        # Mask token b-1 when boundary starts at b.
+        prev[..., :-1] = boundary_starts[..., 1:]
         labels = labels.clone()
         labels[prev] = -100
         # Also avoid any potential wrap-around supervision at the end of the concatenated stream.
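Side note on the boundary-shift fix above: a minimal standalone sketch of what the corrected shift does on a toy packed stream. The `real_position_ids` and `labels` values below are hypothetical, assuming next-token-aligned labels where labels[b-1] would otherwise supervise predicting the first token of the sequence starting at b:

    import torch

    # Two packed sequences; position ids restart at 0 at each sequence boundary.
    real_position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3]])
    labels = torch.tensor([[10, 11, 12, 20, 21, 22, 23]])

    boundary_starts = (real_position_ids == 0)
    prev = torch.zeros_like(boundary_starts, dtype=torch.bool)
    # Corrected shift: a boundary at position b marks position b-1 for masking.
    prev[..., :-1] = boundary_starts[..., 1:]

    labels = labels.clone()
    labels[prev] = -100
    print(labels)  # tensor([[  10,   11, -100,   20,   21,   22,   23]])
    # The old shift (prev[..., 1:] = boundary_starts[..., :-1]) would instead
    # mark positions b+1, leaving the cross-sequence label at b-1 supervised.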
@@ -982,13 +983,15 @@ def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore
         )
         compensate_fsdp_avg = bool(self.sp_config.get("compensate_fsdp_avg", False))
         compensate_factor = float(self.ulysses_size if compensate_fsdp_avg else 1.0)
+        sum_metric_scale = float(self.ulysses_size)
 
         class _ReduceSequenceParallelLoss(torch.autograd.Function):
             @staticmethod
-            def forward(ctx, local_sum: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor:
-                if num_valid_tokens.item() == 0:
-                    local_sum = torch.nan_to_num(local_sum)
+            def forward(ctx, local_mean: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor:
                 local_tokens = num_valid_tokens.detach().clone()
+                local_sum = local_mean * local_tokens
+                if local_tokens.item() == 0:
+                    local_sum = torch.nan_to_num(local_sum)
                 global_sum = local_sum.detach().clone()
                 dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
                 global_tokens = num_valid_tokens.detach().clone()
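For intuition on the mean path's forward logic: a minimal sketch of the same math without torch.distributed. The two (local_mean, num_valid_tokens) pairs below are hypothetical stand-ins for two SP ranks, and the plain Python sums emulate the all_reduce:

    import torch

    # Hypothetical per-rank inputs: (local_mean, num_valid_tokens).
    ranks = [(torch.tensor(2.0), torch.tensor(3)),
             (torch.tensor(0.5), torch.tensor(1))]

    # Convert each local mean back to a local sum, then "all-reduce" sums and counts.
    global_sum = sum(mean * tokens for mean, tokens in ranks)   # 6.0 + 0.5 = 6.5
    global_tokens = sum(tokens for _, tokens in ranks)          # 3 + 1 = 4
    print(global_sum / global_tokens)  # tensor(1.6250), the token-weighted mean
    # Note this differs from naively averaging the rank means: (2.0 + 0.5) / 2 = 1.25.

    # Why the zero-token guard: a rank with no valid tokens has a 0/0 = NaN local
    # mean, and NaN * 0 tokens is still NaN, which would poison the all_reduce.
    print(torch.nan_to_num(torch.tensor(float("nan")) * torch.tensor(0)))  # tensor(0.)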
@@ -1002,28 +1005,33 @@ def forward(ctx, local_sum: torch.Tensor, num_valid_tokens: torch.Tensor) -> tor
             def backward(ctx, grad_output: torch.Tensor):
                 local_tokens, global_tokens = ctx.saved_tensors
                 if global_tokens.item() == 0:
-                    return grad_output, None
-                grad_local_sum = grad_output * (local_tokens / global_tokens) * compensate_factor
-                return grad_local_sum, None
+                    return torch.zeros_like(grad_output), None
+                # d(global_mean)/d(local_mean) = local_tokens / global_tokens.
+                grad_local_mean = grad_output * (local_tokens / global_tokens) * compensate_factor
+                return grad_local_mean, None
 
         class _ReduceSequenceParallelSum(torch.autograd.Function):
             @staticmethod
             def forward(ctx, local_sum: torch.Tensor) -> torch.Tensor:
+                ctx.sum_metric_scale = sum_metric_scale
                 global_sum = local_sum.detach().clone()
                 dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
-                return global_sum
+                # Keep logging/metric value aligned with non-SP sum semantics under
+                # outer collect='mean' by removing one SP replication factor.
+                return global_sum / ctx.sum_metric_scale
 
             @staticmethod
             def backward(ctx, grad_output: torch.Tensor):
+                # Keep training gradient scale unchanged; forward-side scaling is for
+                # logging/metric alignment under outer collect='mean'.
                 return grad_output
 
         if reduction == "sum":
             return _ReduceSequenceParallelSum.apply(loss)
 
-        # Default to mean reduction: assume `loss` is local mean, convert to local sum.
+        # Default to mean reduction: `loss` is local mean.
         num_valid_tokens = (labels != ignore_index).sum().to(loss.device)
-        local_sum = loss * num_valid_tokens
-        return _ReduceSequenceParallelLoss.apply(local_sum, num_valid_tokens)
+        return _ReduceSequenceParallelLoss.apply(loss, num_valid_tokens)
 
     def wrap_model(self, model, optimizer=None):
         self.initialize()
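Finally, a hedged numeric sketch of the sum path's new metric scaling. The ulysses_size=2 and per-rank sums below are hypothetical, and this assumes the returned value later feeds an outer per-rank average (the collect='mean' the diff's comments refer to):

    import torch

    ulysses_size = 2                                      # hypothetical SP group size
    local_sums = [torch.tensor(3.0), torch.tensor(5.0)]   # per-rank loss sums

    # Non-SP baseline: each rank logs its own sum; the outer mean reports 4.0.
    baseline_metric = sum(local_sums) / len(local_sums)

    # Under SP, all_reduce replicates the total (8.0) on every rank; without the
    # division the outer mean would report 8.0. Dividing by ulysses_size first
    # makes each rank return 4.0, so the outer mean again reports 4.0.
    per_rank_value = sum(local_sums) / ulysses_size
    sp_metric = sum([per_rank_value] * ulysses_size) / ulysses_size
    print(baseline_metric, sp_metric)  # tensor(4.) tensor(4.)
    assert sp_metric == baseline_metric

The backward pass deliberately skips this division: per the diff's own comments, the scaling is a logging correction only, so the training gradient through the sum path is left unchanged.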