@@ -800,7 +800,8 @@ def pad_and_split_inputs(self,
         # - In next-token-aligned labels, this appears at labels[b-1]
         boundary_starts = (real_position_ids == 0)
         prev = torch.zeros_like(boundary_starts, dtype=torch.bool)
-        prev[..., 1:] = boundary_starts[..., :-1]
+        # Mask token b-1 when boundary starts at b.
+        prev[..., :-1] = boundary_starts[..., 1:]
         labels = labels.clone()
         labels[prev] = -100
         # Also avoid any potential wrap-around supervision at the end of the concatenated stream.
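A minimal standalone sketch (illustrative tensors, not part of the commit) of the corrected masking above: when a packed sequence starts at position b, the next-token-aligned label at b-1 would supervise a cross-sequence prediction, so it is set to -100.

import torch

real_position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3]])  # two packed sequences
labels = torch.tensor([[11, 12, 13, 21, 22, 23, 24]])

boundary_starts = (real_position_ids == 0)
prev = torch.zeros_like(boundary_starts, dtype=torch.bool)
prev[..., :-1] = boundary_starts[..., 1:]   # mark position b-1 for every boundary at b
labels = labels.clone()
labels[prev] = -100
print(labels)  # tensor([[  11,   12, -100,   21,   22,   23,   24]])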
@@ -1003,13 +1004,15 @@ def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore
         )
         compensate_fsdp_avg = bool(self.sp_config.get("compensate_fsdp_avg", False))
         compensate_factor = float(self.ulysses_size if compensate_fsdp_avg else 1.0)
+        sum_metric_scale = float(self.ulysses_size)
 
         class _ReduceSequenceParallelLoss(torch.autograd.Function):
             @staticmethod
-            def forward(ctx, local_sum: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor:
-                if num_valid_tokens.item() == 0:
-                    local_sum = torch.nan_to_num(local_sum)
+            def forward(ctx, local_mean: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor:
                 local_tokens = num_valid_tokens.detach().clone()
+                local_sum = local_mean * local_tokens
+                if local_tokens.item() == 0:
+                    local_sum = torch.nan_to_num(local_sum)
                 global_sum = local_sum.detach().clone()
                 dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
                 global_tokens = num_valid_tokens.detach().clone()
@@ -1023,28 +1026,33 @@ def forward(ctx, local_sum: torch.Tensor, num_valid_tokens: torch.Tensor) -> tor
             def backward(ctx, grad_output: torch.Tensor):
                 local_tokens, global_tokens = ctx.saved_tensors
                 if global_tokens.item() == 0:
-                    return grad_output, None
-                grad_local_sum = grad_output * (local_tokens / global_tokens) * compensate_factor
-                return grad_local_sum, None
+                    return torch.zeros_like(grad_output), None
+                # d(global_mean)/d(local_mean) = local_tokens / global_tokens.
+                grad_local_mean = grad_output * (local_tokens / global_tokens) * compensate_factor
+                return grad_local_mean, None
 
         class _ReduceSequenceParallelSum(torch.autograd.Function):
             @staticmethod
             def forward(ctx, local_sum: torch.Tensor) -> torch.Tensor:
+                ctx.sum_metric_scale = sum_metric_scale
                 global_sum = local_sum.detach().clone()
                 dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
-                return global_sum
+                # Keep logging/metric value aligned with non-SP sum semantics under
+                # outer collect='mean' by removing one SP replication factor.
+                return global_sum / ctx.sum_metric_scale
 
             @staticmethod
             def backward(ctx, grad_output: torch.Tensor):
+                # Keep training gradient scale unchanged; forward-side scaling is for
+                # logging/metric alignment under outer collect='mean'.
                 return grad_output
 
         if reduction == "sum":
             return _ReduceSequenceParallelSum.apply(loss)
 
-        # Default to mean reduction: assume `loss` is local mean, convert to local sum.
+        # Default to mean reduction: `loss` is local mean.
         num_valid_tokens = (labels != ignore_index).sum().to(loss.device)
-        local_sum = loss * num_valid_tokens
-        return _ReduceSequenceParallelLoss.apply(local_sum, num_valid_tokens)
+        return _ReduceSequenceParallelLoss.apply(loss, num_valid_tokens)
 
     def wrap_model(self, model, optimizer=None):
         self.initialize()
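A minimal single-process sketch (toy numbers, no real process group) of the mean-reduction semantics implemented by _ReduceSequenceParallelLoss above: each SP rank contributes local_mean * local_tokens, and the all-reduced sums recover a token-weighted global mean.

import torch

local_means  = [torch.tensor(2.0), torch.tensor(5.0)]   # per-rank mean loss
local_tokens = [torch.tensor(3.0), torch.tensor(1.0)]   # per-rank valid-token counts

global_sum    = sum(m * t for m, t in zip(local_means, local_tokens))  # stands in for all_reduce of local sums
global_tokens = sum(local_tokens)                                      # stands in for all_reduce of token counts
global_mean   = global_sum / global_tokens
print(global_mean)  # tensor(2.7500) == (2*3 + 5*1) / (3 + 1)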