
Commit 4937e53: loss debug
1 parent 39c40b8

File tree

2 files changed (+72, -13 lines)

src/twinkle/model/transformers/strategy/sequence_parallel.py

Lines changed: 69 additions & 10 deletions
```diff
@@ -1,4 +1,5 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
+import os
 from functools import partial
 from typing import Any, Dict, Optional, Tuple, Union
 
@@ -921,6 +922,43 @@ def __init__(
         self._tokenizer_id = tokenizer_id
         self._tokenizer = None
         self._initialized = False
+        debug_flag = os.getenv("TWINKLE_DEBUG_SP_LOSS", "").strip().lower()
+        self._debug_sp_loss = debug_flag not in ("", "0", "false", "off", "no")
+        try:
+            self._debug_sp_loss_max_steps = max(1, int(os.getenv("TWINKLE_DEBUG_SP_LOSS_STEPS", "8")))
+        except ValueError:
+            self._debug_sp_loss_max_steps = 8
+        self._debug_sp_loss_seen = 0
+
+    def _maybe_debug_reduce_loss(
+        self,
+        reduction: str,
+        loss_in: torch.Tensor,
+        loss_out: torch.Tensor,
+        num_valid_tokens: Optional[torch.Tensor],
+        compensate_factor: float,
+    ) -> None:
+        if not self._debug_sp_loss or self._debug_sp_loss_seen >= self._debug_sp_loss_max_steps:
+            return
+        rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
+        sp_rank = (
+            dist.get_rank(sequence_parallel._sp_group)
+            if sequence_parallel._sp_group is not None and dist.is_initialized()
+            else 0
+        )
+        token_str = "None"
+        if num_valid_tokens is not None:
+            token_str = str(int(num_valid_tokens.detach().item()))
+        in_val = float(loss_in.detach().item())
+        out_val = float(loss_out.detach().item())
+        print(
+            "[SP-LOSS-DEBUG] "
+            f"rank={rank} sp_rank={sp_rank} reduction={reduction} "
+            f"loss_in={in_val:.6f} loss_out={out_val:.6f} "
+            f"local_valid_tokens={token_str} compensate_factor={compensate_factor:.4f}",
+            flush=True,
+        )
+        self._debug_sp_loss_seen += 1
 
     def _get_tokenizer(self) -> Optional[PreTrainedTokenizer]:
         if self._tokenizer is not None:
```
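The debug path above is driven entirely by environment variables, so it can be switched on without touching code. A minimal sketch of enabling it (the variable names and accepted values come from the hunk above; everything else about the launch context is assumed):

```python
# Sketch: turn on the SP loss debug prints before the strategy is constructed.
import os

# any value outside ("", "0", "false", "off", "no") enables the prints
os.environ["TWINKLE_DEBUG_SP_LOSS"] = "1"
# per-rank cap on printed steps; clamped to >= 1, falls back to 8 on bad input
os.environ["TWINKLE_DEBUG_SP_LOSS_STEPS"] = "4"
```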
```diff
@@ -1003,13 +1041,15 @@ def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore
         )
         compensate_fsdp_avg = bool(self.sp_config.get("compensate_fsdp_avg", False))
         compensate_factor = float(self.ulysses_size if compensate_fsdp_avg else 1.0)
+        loss_in = loss.detach()
 
         class _ReduceSequenceParallelLoss(torch.autograd.Function):
             @staticmethod
-            def forward(ctx, local_sum: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor:
-                if num_valid_tokens.item() == 0:
-                    local_sum = torch.nan_to_num(local_sum)
+            def forward(ctx, local_mean: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor:
                 local_tokens = num_valid_tokens.detach().clone()
+                local_sum = local_mean * local_tokens
+                if local_tokens.item() == 0:
+                    local_sum = torch.nan_to_num(local_sum)
                 global_sum = local_sum.detach().clone()
                 dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
                 global_tokens = num_valid_tokens.detach().clone()
```
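The reworked forward takes the local mean directly and rebuilds the token-weighted sum before the all-reduces, so the result is a true global mean over the sequence-parallel group: global_mean = (sum_i mean_i * n_i) / (sum_i n_i). A single-process sketch of that arithmetic, with the dist.all_reduce calls replaced by explicit sums over two hypothetical SP ranks:

```python
# Single-process sketch of the forward math; two hypothetical SP ranks,
# dist.all_reduce replaced by explicit sums.
import torch

local_means = [torch.tensor(2.0), torch.tensor(4.0)]   # per-rank mean loss
local_tokens = [torch.tensor(3.0), torch.tensor(1.0)]  # per-rank valid tokens

local_sums = [m * n for m, n in zip(local_means, local_tokens)]
global_sum = local_sums[0] + local_sums[1]        # what all_reduce(global_sum) yields
global_tokens = local_tokens[0] + local_tokens[1]  # what all_reduce(global_tokens) yields
print(global_sum / global_tokens)  # tensor(2.5000) == (2*3 + 4*1) / (3 + 1)
```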
```diff
@@ -1023,9 +1063,10 @@ def forward(ctx, local_sum: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor:
             def backward(ctx, grad_output: torch.Tensor):
                 local_tokens, global_tokens = ctx.saved_tensors
                 if global_tokens.item() == 0:
-                    return grad_output, None
-                grad_local_sum = grad_output * (local_tokens / global_tokens) * compensate_factor
-                return grad_local_sum, None
+                    return torch.zeros_like(grad_output), None
+                # d(global_mean)/d(local_mean) = local_tokens / global_tokens.
+                grad_local_mean = grad_output * (local_tokens / global_tokens) * compensate_factor
+                return grad_local_mean, None
 
         class _ReduceSequenceParallelSum(torch.autograd.Function):
             @staticmethod
```
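The new comment in backward can be verified with plain autograd: for global_mean = (sum_j m_j * n_j) / N, the derivative with respect to one rank's local mean m_i is n_i / N, which is exactly the local_tokens / global_tokens factor (compensate_factor aside). A one-process check with the same two hypothetical ranks:

```python
# Autograd check of d(global_mean)/d(local_mean_i) = n_i / N.
import torch

m1 = torch.tensor(2.0, requires_grad=True)
m2 = torch.tensor(4.0, requires_grad=True)
n1, n2 = torch.tensor(3.0), torch.tensor(1.0)

global_mean = (m1 * n1 + m2 * n2) / (n1 + n2)
global_mean.backward()
print(m1.grad, m2.grad)  # tensor(0.7500) tensor(0.2500), i.e. n_i / (n1 + n2)
```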
```diff
@@ -1039,12 +1080,30 @@ def backward(ctx, grad_output: torch.Tensor):
                 return grad_output
 
         if reduction == "sum":
-            return _ReduceSequenceParallelSum.apply(loss)
+            out = _ReduceSequenceParallelSum.apply(loss)
+            num_valid_tokens = None
+            if self._debug_sp_loss:
+                num_valid_tokens = (labels != ignore_index).sum().to(loss.device).detach()
+            self._maybe_debug_reduce_loss(
+                reduction,
+                loss_in=loss_in,
+                loss_out=out.detach(),
+                num_valid_tokens=num_valid_tokens,
+                compensate_factor=compensate_factor,
+            )
+            return out
 
-        # Default to mean reduction: assume `loss` is local mean, convert to local sum.
+        # Default to mean reduction: `loss` is local mean.
         num_valid_tokens = (labels != ignore_index).sum().to(loss.device)
-        local_sum = loss * num_valid_tokens
-        return _ReduceSequenceParallelLoss.apply(local_sum, num_valid_tokens)
+        out = _ReduceSequenceParallelLoss.apply(loss, num_valid_tokens)
+        self._maybe_debug_reduce_loss(
+            reduction,
+            loss_in=loss_in,
+            loss_out=out.detach(),
+            num_valid_tokens=num_valid_tokens.detach(),
+            compensate_factor=compensate_factor,
+        )
+        return out
 
     def wrap_model(self, model, optimizer=None):
         self.initialize()
```
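Moving the mean-to-sum multiply from the call site into forward() is not cosmetic: autograd does not trace through a custom Function's forward, so the old external local_sum = loss * num_valid_tokens chained an extra chain-rule factor of n_i onto the n_i / N that backward() returned, over-scaling the local-mean gradient. A sketch of the two gradient scales, under that reading of the removed lines:

```python
# Gradient-scale comparison (assumes the old call site multiplied outside
# the autograd.Function, as the removed lines suggest).
import torch

n_i, N = torch.tensor(3.0), torch.tensor(4.0)

# old path: external multiply adds a chain-rule factor of n_i on top of the
# n_i / N returned by backward()
old_grad_scale = n_i * (n_i / N)  # 2.25, off by a factor of n_i
# new path: the multiply sits inside forward(), so backward() alone applies
new_grad_scale = n_i / N          # 0.75, matches the autograd check above
print(old_grad_scale.item(), new_grad_scale.item())
```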

src/twinkle/model/transformers/transformers.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -80,13 +80,13 @@ def __post_init__(self):
 
     def _build_metrics(self):
         self.train_metrics = [
-            LossMetric(self._device_mesh, self._dp_group, loss_reduction='sum'),
+            LossMetric(self._device_mesh, self._dp_group, loss_reduction='mean'),
             Accuracy(self._device_mesh, self._dp_group),
             TrainMetric(self._device_mesh, self._dp_group),
         ]
 
         self.eval_metrics = [
-            LossMetric(self._device_mesh, self._dp_group, loss_reduction='sum'),
+            LossMetric(self._device_mesh, self._dp_group, loss_reduction='mean'),
             Accuracy(self._device_mesh, self._dp_group),
             TrainMetric(self._device_mesh, self._dp_group),
         ]
@@ -317,7 +317,7 @@ def _ensure_optimizer_dp_groups(self):
 
     def _construct_default_optimizer_group(self):
         return OptimizerGroup(
-            loss_instance=CrossEntropyLoss(reduction='sum'),
+            loss_instance=CrossEntropyLoss(reduction='mean'),
             template=Template(self.tokenizer_id),
             processor=InputProcessor(self.device_mesh),
             _device_mesh=self.device_mesh,
```
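With reduce_loss now returning a global mean, the default LossMetric and CrossEntropyLoss reductions switch from 'sum' to 'mean' to match. For reference, PyTorch's mean-reduced cross entropy averages over non-ignored targets only, so mean and sum stay related by the valid-token count; a quick standalone check (plain PyTorch, no twinkle imports):

```python
# reduction='mean' averages over targets that are not ignore_index (-100).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 10)
labels = torch.tensor([1, 2, 3, -100, 4])  # one ignored position

loss_sum = F.cross_entropy(logits, labels, reduction="sum")
loss_mean = F.cross_entropy(logits, labels, reduction="mean")
valid = (labels != -100).sum()
print(torch.allclose(loss_mean, loss_sum / valid))  # True
```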
