@@ -870,7 +870,7 @@ class SequenceParallelConfig:
870870 enabled : bool = True
871871 ulysses_size : Optional [int ] = None
872872 gather_logits : bool = True
873- loss_reduction : str = " mean"
873+ loss_reduction : str = ' mean'
874874 compensate_fsdp_avg : bool = False
875875
876876
@@ -975,17 +975,16 @@ def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore
975975 if labels is None or sequence_parallel ._sp_group is None :
976976 return loss
977977 # Compute global loss via autograd-aware all-reduce.
978- reduction = str (self .sp_config .get ("loss_reduction" , "mean" )).lower ()
979- if reduction == "none" :
980- raise ValueError (
981- "SequenceParallelStrategy.reduce_loss only supports reduction='sum' or 'mean'. "
982- "Please aggregate per-token losses before calling reduce_loss."
983- )
984- compensate_fsdp_avg = bool (self .sp_config .get ("compensate_fsdp_avg" , False ))
978+ reduction = str (self .sp_config .get ('loss_reduction' , 'mean' )).lower ()
979+ if reduction == 'none' :
980+ raise ValueError ("SequenceParallelStrategy.reduce_loss only supports reduction='sum' or 'mean'. "
981+ 'Please aggregate per-token losses before calling reduce_loss.' )
982+ compensate_fsdp_avg = bool (self .sp_config .get ('compensate_fsdp_avg' , False ))
985983 compensate_factor = float (self .ulysses_size if compensate_fsdp_avg else 1.0 )
986984 sum_metric_scale = float (self .ulysses_size )
987985
988986 class _ReduceSequenceParallelLoss (torch .autograd .Function ):
987+
989988 @staticmethod
990989 def forward (ctx , local_mean : torch .Tensor , num_valid_tokens : torch .Tensor ) -> torch .Tensor :
991990 local_tokens = num_valid_tokens .detach ().clone ()
@@ -1011,6 +1010,7 @@ def backward(ctx, grad_output: torch.Tensor):
10111010 return grad_local_mean , None
10121011
10131012 class _ReduceSequenceParallelSum (torch .autograd .Function ):
1013+
10141014 @staticmethod
10151015 def forward (ctx , local_sum : torch .Tensor ) -> torch .Tensor :
10161016 ctx .sum_metric_scale = sum_metric_scale
@@ -1026,7 +1026,7 @@ def backward(ctx, grad_output: torch.Tensor):
10261026 # logging/metric alignment under outer collect='mean'.
10271027 return grad_output
10281028
1029- if reduction == " sum" :
1029+ if reduction == ' sum' :
10301030 return _ReduceSequenceParallelSum .apply (loss )
10311031
10321032 # Default to mean reduction: `loss` is local mean.
0 commit comments