Commit 5cc3b2d

remove debug log

1 parent c0bfaef

1 file changed (+7, -63)

src/twinkle/model/transformers/strategy/sequence_parallel.py

Lines changed: 7 additions & 63 deletions
@@ -1,5 +1,4 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
-import os
 from functools import partial
 from typing import Any, Dict, Optional, Tuple, Union

@@ -801,7 +800,8 @@ def pad_and_split_inputs(self,
         # - In next-token-aligned labels, this appears at labels[b-1]
         boundary_starts = (real_position_ids == 0)
         prev = torch.zeros_like(boundary_starts, dtype=torch.bool)
-        prev[..., 1:] = boundary_starts[..., :-1]
+        # Mask token b-1 when boundary starts at b.
+        prev[..., :-1] = boundary_starts[..., 1:]
         labels = labels.clone()
         labels[prev] = -100
         # Also avoid any potential wrap-around supervision at the end of the concatenated stream.
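
Note on the label-boundary fix above: with next-token-aligned labels, labels[b-1] supervises the token at position b, so when a new packed sequence starts at position b (real_position_ids == 0) it is labels[b-1], the last-token slot of the previous sequence, that must be ignored. A minimal toy reproduction of the corrected masking (tensor values are illustrative, not from the repository):

import torch

# Two packed sequences of lengths 3 and 2; positions reset to 0 at each boundary.
real_position_ids = torch.tensor([[0, 1, 2, 0, 1]])
labels = torch.tensor([[11, 12, 13, 21, 22]])

boundary_starts = (real_position_ids == 0)          # [T, F, F, T, F]
prev = torch.zeros_like(boundary_starts, dtype=torch.bool)

# Old (buggy): prev[..., 1:] = boundary_starts[..., :-1] shifted the flags
# forward, masking the slot *after* each boundary -> [F, T, F, F, T].
# New: mask labels[b-1], the slot *before* each boundary.
prev[..., :-1] = boundary_starts[..., 1:]           # [F, F, T, F, F]

labels = labels.clone()
labels[prev] = -100
print(labels)  # -> [[11, 12, -100, 21, 22]]

Here the boundary at b=3 masks labels[2], which would otherwise supervise a prediction of the next sequence's first token from the previous sequence's last token.
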
@@ -922,43 +922,6 @@ def __init__(
         self._tokenizer_id = tokenizer_id
         self._tokenizer = None
         self._initialized = False
-        debug_flag = os.getenv("TWINKLE_DEBUG_SP_LOSS", "").strip().lower()
-        self._debug_sp_loss = debug_flag not in ("", "0", "false", "off", "no")
-        try:
-            self._debug_sp_loss_max_steps = max(1, int(os.getenv("TWINKLE_DEBUG_SP_LOSS_STEPS", "8")))
-        except ValueError:
-            self._debug_sp_loss_max_steps = 8
-        self._debug_sp_loss_seen = 0
-
-    def _maybe_debug_reduce_loss(
-        self,
-        reduction: str,
-        loss_in: torch.Tensor,
-        loss_out: torch.Tensor,
-        num_valid_tokens: Optional[torch.Tensor],
-        compensate_factor: float,
-    ) -> None:
-        if not self._debug_sp_loss or self._debug_sp_loss_seen >= self._debug_sp_loss_max_steps:
-            return
-        rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
-        sp_rank = (
-            dist.get_rank(sequence_parallel._sp_group)
-            if sequence_parallel._sp_group is not None and dist.is_initialized()
-            else 0
-        )
-        token_str = "None"
-        if num_valid_tokens is not None:
-            token_str = str(int(num_valid_tokens.detach().item()))
-        in_val = float(loss_in.detach().item())
-        out_val = float(loss_out.detach().item())
-        print(
-            "[SP-LOSS-DEBUG] "
-            f"rank={rank} sp_rank={sp_rank} reduction={reduction} "
-            f"loss_in={in_val:.6f} loss_out={out_val:.6f} "
-            f"local_valid_tokens={token_str} compensate_factor={compensate_factor:.4f}",
-            flush=True,
-        )
-        self._debug_sp_loss_seen += 1

     def _get_tokenizer(self) -> Optional[PreTrainedTokenizer]:
         if self._tokenizer is not None:
@@ -1042,7 +1005,6 @@ def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore
         compensate_fsdp_avg = bool(self.sp_config.get("compensate_fsdp_avg", False))
         compensate_factor = float(self.ulysses_size if compensate_fsdp_avg else 1.0)
         sum_metric_scale = float(self.ulysses_size)
-        loss_in = loss.detach()

         class _ReduceSequenceParallelLoss(torch.autograd.Function):
             @staticmethod
@@ -1081,34 +1043,16 @@ def forward(ctx, local_sum: torch.Tensor) -> torch.Tensor:

             @staticmethod
             def backward(ctx, grad_output: torch.Tensor):
-                # Preserve original gradient scale (before forward-side metric scaling).
-                return grad_output * ctx.sum_metric_scale
+                # Keep training gradient scale unchanged; forward-side scaling is for
+                # logging/metric alignment under outer collect='mean'.
+                return grad_output

         if reduction == "sum":
-            out = _ReduceSequenceParallelSum.apply(loss)
-            num_valid_tokens = None
-            if self._debug_sp_loss:
-                num_valid_tokens = (labels != ignore_index).sum().to(loss.device).detach()
-            self._maybe_debug_reduce_loss(
-                reduction,
-                loss_in=loss_in,
-                loss_out=out.detach(),
-                num_valid_tokens=num_valid_tokens,
-                compensate_factor=compensate_factor,
-            )
-            return out
+            return _ReduceSequenceParallelSum.apply(loss)

         # Default to mean reduction: `loss` is local mean.
         num_valid_tokens = (labels != ignore_index).sum().to(loss.device)
-        out = _ReduceSequenceParallelLoss.apply(loss, num_valid_tokens)
-        self._maybe_debug_reduce_loss(
-            reduction,
-            loss_in=loss_in,
-            loss_out=out.detach(),
-            num_valid_tokens=num_valid_tokens.detach(),
-            compensate_factor=compensate_factor,
-        )
-        return out
+        return _ReduceSequenceParallelLoss.apply(loss, num_valid_tokens)

     def wrap_model(self, model, optimizer=None):
         self.initialize()
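
The backward change in the hunk above is easiest to see in isolation: the forward pass scales the all-reduced sum purely so that an outer collect='mean' over the sequence-parallel ranks reports the true sum, and after this commit backward no longer multiplies the gradient by that factor. A self-contained toy re-creation of the pattern, assuming the default process group; the class name and the world-size-based scale are illustrative stand-ins, not the project's actual _ReduceSequenceParallelSum:

import torch
import torch.distributed as dist

class _SumLossAcrossRanks(torch.autograd.Function):
    """All-reduce a local sum loss; scale only the *reported* value."""

    @staticmethod
    def forward(ctx, local_sum: torch.Tensor) -> torch.Tensor:
        total = local_sum.clone()
        dist.all_reduce(total, op=dist.ReduceOp.SUM)
        # Forward-side scaling: every rank returns N * global_sum, so an
        # outer reducer that averages over the N ranks logs the true sum.
        ctx.sum_metric_scale = float(dist.get_world_size())
        return total * ctx.sum_metric_scale

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Post-commit behaviour: pass the gradient through unchanged
        # rather than multiplying by ctx.sum_metric_scale.
        return grad_output

Because every rank holds the same scaled value, an outer mean over the ulysses_size ranks cancels the multiplication in the logged metric, while the straight-through backward keeps the training gradient at the scale of the unscaled sum objective.
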
