|
1 | 1 | # Copyright (c) ModelScope Contributors. All rights reserved. |
2 | | -import os |
3 | 2 | from functools import partial |
4 | 3 | from typing import Any, Dict, Optional, Tuple, Union |
5 | 4 |
|
@@ -801,7 +800,8 @@ def pad_and_split_inputs(self, |
801 | 800 | # - In next-token-aligned labels, this appears at labels[b-1] |
802 | 801 | boundary_starts = (real_position_ids == 0) |
803 | 802 | prev = torch.zeros_like(boundary_starts, dtype=torch.bool) |
804 | | - prev[..., 1:] = boundary_starts[..., :-1] |
| 803 | + # When a new sequence boundary starts at position b, mask labels[b-1]: |
| 804 | + prev[..., :-1] = boundary_starts[..., 1:] |
805 | 805 | labels = labels.clone() |
806 | 806 | labels[prev] = -100 |
807 | 807 | # Also avoid any potential wrap-around supervision at the end of the concatenated stream. |
@@ -922,43 +922,6 @@ def __init__( |
922 | 922 | self._tokenizer_id = tokenizer_id |
923 | 923 | self._tokenizer = None |
924 | 924 | self._initialized = False |
925 | | - debug_flag = os.getenv("TWINKLE_DEBUG_SP_LOSS", "").strip().lower() |
926 | | - self._debug_sp_loss = debug_flag not in ("", "0", "false", "off", "no") |
927 | | - try: |
928 | | - self._debug_sp_loss_max_steps = max(1, int(os.getenv("TWINKLE_DEBUG_SP_LOSS_STEPS", "8"))) |
929 | | - except ValueError: |
930 | | - self._debug_sp_loss_max_steps = 8 |
931 | | - self._debug_sp_loss_seen = 0 |
932 | | - |
933 | | - def _maybe_debug_reduce_loss( |
934 | | - self, |
935 | | - reduction: str, |
936 | | - loss_in: torch.Tensor, |
937 | | - loss_out: torch.Tensor, |
938 | | - num_valid_tokens: Optional[torch.Tensor], |
939 | | - compensate_factor: float, |
940 | | - ) -> None: |
941 | | - if not self._debug_sp_loss or self._debug_sp_loss_seen >= self._debug_sp_loss_max_steps: |
942 | | - return |
943 | | - rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0 |
944 | | - sp_rank = ( |
945 | | - dist.get_rank(sequence_parallel._sp_group) |
946 | | - if sequence_parallel._sp_group is not None and dist.is_initialized() |
947 | | - else 0 |
948 | | - ) |
949 | | - token_str = "None" |
950 | | - if num_valid_tokens is not None: |
951 | | - token_str = str(int(num_valid_tokens.detach().item())) |
952 | | - in_val = float(loss_in.detach().item()) |
953 | | - out_val = float(loss_out.detach().item()) |
954 | | - print( |
955 | | - "[SP-LOSS-DEBUG] " |
956 | | - f"rank={rank} sp_rank={sp_rank} reduction={reduction} " |
957 | | - f"loss_in={in_val:.6f} loss_out={out_val:.6f} " |
958 | | - f"local_valid_tokens={token_str} compensate_factor={compensate_factor:.4f}", |
959 | | - flush=True, |
960 | | - ) |
961 | | - self._debug_sp_loss_seen += 1 |
962 | 925 |
|
963 | 926 | def _get_tokenizer(self) -> Optional[PreTrainedTokenizer]: |
964 | 927 | if self._tokenizer is not None: |
@@ -1042,7 +1005,6 @@ def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore |
1042 | 1005 | compensate_fsdp_avg = bool(self.sp_config.get("compensate_fsdp_avg", False)) |
1043 | 1006 | compensate_factor = float(self.ulysses_size if compensate_fsdp_avg else 1.0) |
1044 | 1007 | sum_metric_scale = float(self.ulysses_size) |
1045 | | - loss_in = loss.detach() |
1046 | 1008 |
|
1047 | 1009 | class _ReduceSequenceParallelLoss(torch.autograd.Function): |
1048 | 1010 | @staticmethod |
@@ -1081,34 +1043,16 @@ def forward(ctx, local_sum: torch.Tensor) -> torch.Tensor: |
1081 | 1043 |
|
1082 | 1044 | @staticmethod |
1083 | 1045 | def backward(ctx, grad_output: torch.Tensor): |
1084 | | - # Preserve original gradient scale (before forward-side metric scaling). |
1085 | | - return grad_output * ctx.sum_metric_scale |
| 1046 | + # Keep training gradient scale unchanged; forward-side scaling is for |
| 1047 | + # logging/metric alignment under outer collect='mean'. |
| 1048 | + return grad_output |
1086 | 1049 |
|
1087 | 1050 | if reduction == "sum": |
1088 | | - out = _ReduceSequenceParallelSum.apply(loss) |
1089 | | - num_valid_tokens = None |
1090 | | - if self._debug_sp_loss: |
1091 | | - num_valid_tokens = (labels != ignore_index).sum().to(loss.device).detach() |
1092 | | - self._maybe_debug_reduce_loss( |
1093 | | - reduction, |
1094 | | - loss_in=loss_in, |
1095 | | - loss_out=out.detach(), |
1096 | | - num_valid_tokens=num_valid_tokens, |
1097 | | - compensate_factor=compensate_factor, |
1098 | | - ) |
1099 | | - return out |
| 1051 | + return _ReduceSequenceParallelSum.apply(loss) |
1100 | 1052 |
|
1101 | 1053 | # Default to mean reduction: `loss` is local mean. |
1102 | 1054 | num_valid_tokens = (labels != ignore_index).sum().to(loss.device) |
1103 | | - out = _ReduceSequenceParallelLoss.apply(loss, num_valid_tokens) |
1104 | | - self._maybe_debug_reduce_loss( |
1105 | | - reduction, |
1106 | | - loss_in=loss_in, |
1107 | | - loss_out=out.detach(), |
1108 | | - num_valid_tokens=num_valid_tokens.detach(), |
1109 | | - compensate_factor=compensate_factor, |
1110 | | - ) |
1111 | | - return out |
| 1055 | + return _ReduceSequenceParallelLoss.apply(loss, num_valid_tokens) |
1112 | 1056 |
|
1113 | 1057 | def wrap_model(self, model, optimizer=None): |
1114 | 1058 | self.initialize() |
|
0 commit comments