11# Copyright (c) ModelScope Contributors. All rights reserved.
2+ import os
23from functools import partial
34from typing import Any , Dict , Optional , Tuple , Union
45
@@ -921,6 +922,43 @@ def __init__(
921922 self ._tokenizer_id = tokenizer_id
922923 self ._tokenizer = None
923924 self ._initialized = False
925+ debug_flag = os .getenv ("TWINKLE_DEBUG_SP_LOSS" , "" ).strip ().lower ()
926+ self ._debug_sp_loss = debug_flag not in ("" , "0" , "false" , "off" , "no" )
927+ try :
928+ self ._debug_sp_loss_max_steps = max (1 , int (os .getenv ("TWINKLE_DEBUG_SP_LOSS_STEPS" , "8" )))
929+ except ValueError :
930+ self ._debug_sp_loss_max_steps = 8
931+ self ._debug_sp_loss_seen = 0
932+
933+ def _maybe_debug_reduce_loss (
934+ self ,
935+ reduction : str ,
936+ loss_in : torch .Tensor ,
937+ loss_out : torch .Tensor ,
938+ num_valid_tokens : Optional [torch .Tensor ],
939+ compensate_factor : float ,
940+ ) -> None :
941+ if not self ._debug_sp_loss or self ._debug_sp_loss_seen >= self ._debug_sp_loss_max_steps :
942+ return
943+ rank = dist .get_rank () if dist .is_available () and dist .is_initialized () else 0
944+ sp_rank = (
945+ dist .get_rank (sequence_parallel ._sp_group )
946+ if sequence_parallel ._sp_group is not None and dist .is_initialized ()
947+ else 0
948+ )
949+ token_str = "None"
950+ if num_valid_tokens is not None :
951+ token_str = str (int (num_valid_tokens .detach ().item ()))
952+ in_val = float (loss_in .detach ().item ())
953+ out_val = float (loss_out .detach ().item ())
954+ print (
955+ "[SP-LOSS-DEBUG] "
956+ f"rank={ rank } sp_rank={ sp_rank } reduction={ reduction } "
957+ f"loss_in={ in_val :.6f} loss_out={ out_val :.6f} "
958+ f"local_valid_tokens={ token_str } compensate_factor={ compensate_factor :.4f} " ,
959+ flush = True ,
960+ )
961+ self ._debug_sp_loss_seen += 1
924962
925963 def _get_tokenizer (self ) -> Optional [PreTrainedTokenizer ]:
926964 if self ._tokenizer is not None :
@@ -1003,13 +1041,15 @@ def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore
10031041 )
10041042 compensate_fsdp_avg = bool (self .sp_config .get ("compensate_fsdp_avg" , False ))
10051043 compensate_factor = float (self .ulysses_size if compensate_fsdp_avg else 1.0 )
1044+ loss_in = loss .detach ()
10061045
10071046 class _ReduceSequenceParallelLoss (torch .autograd .Function ):
10081047 @staticmethod
1009- def forward (ctx , local_sum : torch .Tensor , num_valid_tokens : torch .Tensor ) -> torch .Tensor :
1010- if num_valid_tokens .item () == 0 :
1011- local_sum = torch .nan_to_num (local_sum )
1048+ def forward (ctx , local_mean : torch .Tensor , num_valid_tokens : torch .Tensor ) -> torch .Tensor :
10121049 local_tokens = num_valid_tokens .detach ().clone ()
1050+ local_sum = local_mean * local_tokens
1051+ if local_tokens .item () == 0 :
1052+ local_sum = torch .nan_to_num (local_sum )
10131053 global_sum = local_sum .detach ().clone ()
10141054 dist .all_reduce (global_sum , group = sequence_parallel ._sp_group )
10151055 global_tokens = num_valid_tokens .detach ().clone ()
@@ -1023,9 +1063,10 @@ def forward(ctx, local_sum: torch.Tensor, num_valid_tokens: torch.Tensor) -> tor
10231063 def backward (ctx , grad_output : torch .Tensor ):
10241064 local_tokens , global_tokens = ctx .saved_tensors
10251065 if global_tokens .item () == 0 :
1026- return grad_output , None
1027- grad_local_sum = grad_output * (local_tokens / global_tokens ) * compensate_factor
1028- return grad_local_sum , None
1066+ return torch .zeros_like (grad_output ), None
1067+ # d(global_mean)/d(local_mean) = local_tokens / global_tokens.
1068+ grad_local_mean = grad_output * (local_tokens / global_tokens ) * compensate_factor
1069+ return grad_local_mean , None
10291070
10301071 class _ReduceSequenceParallelSum (torch .autograd .Function ):
10311072 @staticmethod
@@ -1039,12 +1080,30 @@ def backward(ctx, grad_output: torch.Tensor):
10391080 return grad_output
10401081
10411082 if reduction == "sum" :
1042- return _ReduceSequenceParallelSum .apply (loss )
1083+ out = _ReduceSequenceParallelSum .apply (loss )
1084+ num_valid_tokens = None
1085+ if self ._debug_sp_loss :
1086+ num_valid_tokens = (labels != ignore_index ).sum ().to (loss .device ).detach ()
1087+ self ._maybe_debug_reduce_loss (
1088+ reduction ,
1089+ loss_in = loss_in ,
1090+ loss_out = out .detach (),
1091+ num_valid_tokens = num_valid_tokens ,
1092+ compensate_factor = compensate_factor ,
1093+ )
1094+ return out
10431095
1044- # Default to mean reduction: assume `loss` is local mean, convert to local sum .
1096+ # Default to mean reduction: `loss` is local mean.
10451097 num_valid_tokens = (labels != ignore_index ).sum ().to (loss .device )
1046- local_sum = loss * num_valid_tokens
1047- return _ReduceSequenceParallelLoss .apply (local_sum , num_valid_tokens )
1098+ out = _ReduceSequenceParallelLoss .apply (loss , num_valid_tokens )
1099+ self ._maybe_debug_reduce_loss (
1100+ reduction ,
1101+ loss_in = loss_in ,
1102+ loss_out = out .detach (),
1103+ num_valid_tokens = num_valid_tokens .detach (),
1104+ compensate_factor = compensate_factor ,
1105+ )
1106+ return out
10481107
10491108 def wrap_model (self , model , optimizer = None ):
10501109 self .initialize ()
0 commit comments