Skip to content

Commit f5ca919

Browse files
committed
delete unused unit test
1 parent 39c40b8 commit f5ca919

File tree

3 files changed

+22
-402
lines changed

3 files changed

+22
-402
lines changed

src/twinkle/model/transformers/strategy/sequence_parallel.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -800,7 +800,8 @@ def pad_and_split_inputs(self,
800800
# - In next-token-aligned labels, this appears at labels[b-1]
801801
boundary_starts = (real_position_ids == 0)
802802
prev = torch.zeros_like(boundary_starts, dtype=torch.bool)
803-
prev[..., 1:] = boundary_starts[..., :-1]
803+
# Mask token b-1 when boundary starts at b.
804+
prev[..., :-1] = boundary_starts[..., 1:]
804805
labels = labels.clone()
805806
labels[prev] = -100
806807
# Also avoid any potential wrap-around supervision at the end of the concatenated stream.
@@ -1003,13 +1004,15 @@ def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore
10031004
)
10041005
compensate_fsdp_avg = bool(self.sp_config.get("compensate_fsdp_avg", False))
10051006
compensate_factor = float(self.ulysses_size if compensate_fsdp_avg else 1.0)
1007+
sum_metric_scale = float(self.ulysses_size)
10061008

10071009
class _ReduceSequenceParallelLoss(torch.autograd.Function):
10081010
@staticmethod
1009-
def forward(ctx, local_sum: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor:
1010-
if num_valid_tokens.item() == 0:
1011-
local_sum = torch.nan_to_num(local_sum)
1011+
def forward(ctx, local_mean: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor:
10121012
local_tokens = num_valid_tokens.detach().clone()
1013+
local_sum = local_mean * local_tokens
1014+
if local_tokens.item() == 0:
1015+
local_sum = torch.nan_to_num(local_sum)
10131016
global_sum = local_sum.detach().clone()
10141017
dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
10151018
global_tokens = num_valid_tokens.detach().clone()
@@ -1023,28 +1026,33 @@ def forward(ctx, local_sum: torch.Tensor, num_valid_tokens: torch.Tensor) -> tor
10231026
def backward(ctx, grad_output: torch.Tensor):
10241027
local_tokens, global_tokens = ctx.saved_tensors
10251028
if global_tokens.item() == 0:
1026-
return grad_output, None
1027-
grad_local_sum = grad_output * (local_tokens / global_tokens) * compensate_factor
1028-
return grad_local_sum, None
1029+
return torch.zeros_like(grad_output), None
1030+
# d(global_mean)/d(local_mean) = local_tokens / global_tokens.
1031+
grad_local_mean = grad_output * (local_tokens / global_tokens) * compensate_factor
1032+
return grad_local_mean, None
10291033

10301034
class _ReduceSequenceParallelSum(torch.autograd.Function):
10311035
@staticmethod
10321036
def forward(ctx, local_sum: torch.Tensor) -> torch.Tensor:
1037+
ctx.sum_metric_scale = sum_metric_scale
10331038
global_sum = local_sum.detach().clone()
10341039
dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
1035-
return global_sum
1040+
# Keep logging/metric value aligned with non-SP sum semantics under
1041+
# outer collect='mean' by removing one SP replication factor.
1042+
return global_sum / ctx.sum_metric_scale
10361043

10371044
@staticmethod
10381045
def backward(ctx, grad_output: torch.Tensor):
1046+
# Keep training gradient scale unchanged; forward-side scaling is for
1047+
# logging/metric alignment under outer collect='mean'.
10391048
return grad_output
10401049

10411050
if reduction == "sum":
10421051
return _ReduceSequenceParallelSum.apply(loss)
10431052

1044-
# Default to mean reduction: assume `loss` is local mean, convert to local sum.
1053+
# Default to mean reduction: `loss` is local mean.
10451054
num_valid_tokens = (labels != ignore_index).sum().to(loss.device)
1046-
local_sum = loss * num_valid_tokens
1047-
return _ReduceSequenceParallelLoss.apply(local_sum, num_valid_tokens)
1055+
return _ReduceSequenceParallelLoss.apply(loss, num_valid_tokens)
10481056

10491057
def wrap_model(self, model, optimizer=None):
10501058
self.initialize()

src/twinkle/model/transformers/transformers.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -431,10 +431,9 @@ def calculate_loss(self, **kwargs):
431431
optimizer_config = self.optimizer_group[adapter_name]
432432
optimizer_config.num_tokens += counts.item()
433433
if self.sp_strategy is not None and 'labels' in inputs:
434-
if "loss_reduction" not in self.sp_strategy.sp_config:
435-
reduction = getattr(loss_instance, "reduction", None)
436-
if reduction is not None:
437-
self.sp_strategy.sp_config["loss_reduction"] = str(reduction)
434+
reduction = getattr(loss_instance, "reduction", None)
435+
if reduction is not None:
436+
self.sp_strategy.sp_config["loss_reduction"] = str(reduction)
438437
loss_value = self.sp_strategy.reduce_loss(loss_value, inputs['labels'])
439438
optimizer_config.loss_value += loss_value
440439
outputs['loss'] = optimizer_config.loss_value

0 commit comments

Comments (0)