Commit 9d239da

delete unused unit test
1 parent 9823afb commit 9d239da

2 files changed: +22 −15 lines changed

src/twinkle/model/transformers/strategy/sequence_parallel.py

Lines changed: 19 additions & 11 deletions

@@ -788,7 +788,8 @@ def pad_and_split_inputs(self,
         # - In next-token-aligned labels, this appears at labels[b-1]
         boundary_starts = (real_position_ids == 0)
         prev = torch.zeros_like(boundary_starts, dtype=torch.bool)
-        prev[..., 1:] = boundary_starts[..., :-1]
+        # Mask token b-1 when boundary starts at b.
+        prev[..., :-1] = boundary_starts[..., 1:]
         labels = labels.clone()
         labels[prev] = -100
         # Also avoid any potential wrap-around supervision at the end of the concatenated stream.
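
The masking fix above can be checked in isolation. A minimal sketch (a hypothetical standalone snippet, not part of the repository) with two packed sequences of lengths 3 and 2: when a new sequence starts at position b, the next-token-aligned label at b-1 would supervise a cross-sequence prediction, so it must be masked.

import torch

# Position ids restart at 0 at each packed-sequence boundary.
real_position_ids = torch.tensor([[0, 1, 2, 0, 1]])
labels = torch.tensor([[11, 12, 13, 21, 22]])

boundary_starts = (real_position_ids == 0)
prev = torch.zeros_like(boundary_starts, dtype=torch.bool)
# Old, buggy shift flagged token b+1: prev[..., 1:] = boundary_starts[..., :-1]
# Fixed shift flags token b-1 when a boundary starts at b:
prev[..., :-1] = boundary_starts[..., 1:]

labels = labels.clone()
labels[prev] = -100
print(labels)  # tensor([[  11,   12, -100,   21,   22]])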
@@ -982,13 +983,15 @@ def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore_index
         )
         compensate_fsdp_avg = bool(self.sp_config.get("compensate_fsdp_avg", False))
         compensate_factor = float(self.ulysses_size if compensate_fsdp_avg else 1.0)
+        sum_metric_scale = float(self.ulysses_size)
 
         class _ReduceSequenceParallelLoss(torch.autograd.Function):
             @staticmethod
-            def forward(ctx, local_sum: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor:
-                if num_valid_tokens.item() == 0:
-                    local_sum = torch.nan_to_num(local_sum)
+            def forward(ctx, local_mean: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor:
                 local_tokens = num_valid_tokens.detach().clone()
+                local_sum = local_mean * local_tokens
+                if local_tokens.item() == 0:
+                    local_sum = torch.nan_to_num(local_sum)
                 global_sum = local_sum.detach().clone()
                 dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
                 global_tokens = num_valid_tokens.detach().clone()
@@ -1002,28 +1005,33 @@ def forward(ctx, local_sum: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor:
             def backward(ctx, grad_output: torch.Tensor):
                 local_tokens, global_tokens = ctx.saved_tensors
                 if global_tokens.item() == 0:
-                    return grad_output, None
-                grad_local_sum = grad_output * (local_tokens / global_tokens) * compensate_factor
-                return grad_local_sum, None
+                    return torch.zeros_like(grad_output), None
+                # d(global_mean)/d(local_mean) = local_tokens / global_tokens.
+                grad_local_mean = grad_output * (local_tokens / global_tokens) * compensate_factor
+                return grad_local_mean, None
 
         class _ReduceSequenceParallelSum(torch.autograd.Function):
             @staticmethod
             def forward(ctx, local_sum: torch.Tensor) -> torch.Tensor:
+                ctx.sum_metric_scale = sum_metric_scale
                 global_sum = local_sum.detach().clone()
                 dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
-                return global_sum
+                # Keep logging/metric value aligned with non-SP sum semantics under
+                # outer collect='mean' by removing one SP replication factor.
+                return global_sum / ctx.sum_metric_scale
 
             @staticmethod
             def backward(ctx, grad_output: torch.Tensor):
+                # Keep training gradient scale unchanged; forward-side scaling is for
+                # logging/metric alignment under outer collect='mean'.
                 return grad_output
 
         if reduction == "sum":
             return _ReduceSequenceParallelSum.apply(loss)
 
-        # Default to mean reduction: assume `loss` is local mean, convert to local sum.
+        # Default to mean reduction: `loss` is local mean.
         num_valid_tokens = (labels != ignore_index).sum().to(loss.device)
-        local_sum = loss * num_valid_tokens
-        return _ReduceSequenceParallelLoss.apply(local_sum, num_valid_tokens)
+        return _ReduceSequenceParallelLoss.apply(loss, num_valid_tokens)
 
     def wrap_model(self, model, optimizer=None):
         self.initialize()
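
Both reduction paths can be sanity-checked without a process group. A sketch under the assumption of two SP ranks, with plain Python sums standing in for dist.all_reduce (the per-rank values are made up for illustration):

import torch

# Hypothetical per-rank values: local mean loss and valid-token count.
local_means = [torch.tensor(2.0), torch.tensor(4.0)]
local_tokens = [torch.tensor(3.0), torch.tensor(1.0)]

# Mean path, as in _ReduceSequenceParallelLoss.forward:
# local mean -> local sum, then all-reduce sums and token counts.
local_sums = [m * t for m, t in zip(local_means, local_tokens)]
global_sum = sum(local_sums)          # stands in for dist.all_reduce
global_tokens = sum(local_tokens)
global_mean = global_sum / global_tokens
print(global_mean)  # tensor(2.5000) == (2.0*3 + 4.0*1) / 4

# Backward: d(global_mean)/d(local_mean_i) = local_tokens_i / global_tokens,
# times compensate_factor when compensate_fsdp_avg is set. A unit upstream
# gradient therefore splits 3/4 and 1/4 across the two ranks.
print([t / global_tokens for t in local_tokens])  # [tensor(0.7500), tensor(0.2500)]

# Sum path: after all-reduce every SP rank holds the same global sum (10.0),
# so an outer collect='mean' over ranks would report 10.0 rather than the
# non-SP mean of per-rank sums, (6.0 + 4.0) / 2 = 5.0. Dividing by
# sum_metric_scale = ulysses_size restores sum semantics: 10.0 / 2 = 5.0.
ulysses_size = 2
print(global_sum / ulysses_size)  # tensor(5.)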

src/twinkle/model/transformers/transformers.py

Lines changed: 3 additions & 4 deletions

@@ -440,10 +440,9 @@ def calculate_loss(self, **kwargs):
         optimizer_config = self.optimizer_group[adapter_name]
         optimizer_config.num_tokens += counts.item()
         if self.sp_strategy is not None and 'labels' in inputs:
-            if 'loss_reduction' not in self.sp_strategy.sp_config:
-                reduction = getattr(loss_instance, 'reduction', None)
-                if reduction is not None:
-                    self.sp_strategy.sp_config['loss_reduction'] = str(reduction)
+            reduction = getattr(loss_instance, "reduction", None)
+            if reduction is not None:
+                self.sp_strategy.sp_config["loss_reduction"] = str(reduction)
             loss_value = self.sp_strategy.reduce_loss(loss_value, inputs['labels'])
             optimizer_config.loss_value += loss_value
             outputs['loss'] = optimizer_config.loss_value
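
This change makes the cached reduction track the current loss instance on every call, rather than being set only once when the key is missing. The attribute it reads is the standard one on torch loss modules; a minimal sketch (loss_instance here is a made-up stand-in for whatever criterion the trainer holds):

import torch.nn as nn

sp_config = {}
loss_instance = nn.CrossEntropyLoss(reduction="sum")

reduction = getattr(loss_instance, "reduction", None)
if reduction is not None:
    sp_config["loss_reduction"] = str(reduction)
print(sp_config)  # {'loss_reduction': 'sum'}

# Swapping the criterion now updates the cached value on the next call,
# which the old "set once" logic would have skipped.
loss_instance = nn.CrossEntropyLoss(reduction="mean")
sp_config["loss_reduction"] = str(getattr(loss_instance, "reduction"))
print(sp_config)  # {'loss_reduction': 'mean'}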
