From 04f4f98e0eac0af91622e7371c2989a934f9c5aa Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Thu, 12 Feb 2026 14:41:44 +0800 Subject: [PATCH 1/6] feat(sequence_parallel): refactor loss reduction using custom autograd functions Replace manual gradient handling with `torch.autograd.Function` subclasses `_ReduceSequenceParallelLoss` and `_ReduceSequenceParallelSum` to compute global loss via autograd-aware all-reduce. This simplifies the logic for both sum and mean reductions, improves gradient correctness, and removes the need for separate metric scaling when `world_size > 1`. --- .../strategy/sequence_parallel.py | 73 ++++++++++++------- 1 file changed, 47 insertions(+), 26 deletions(-) diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel.py b/src/twinkle/model/transformers/strategy/sequence_parallel.py index 28285070..359b8d79 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel.py @@ -969,34 +969,55 @@ def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore return loss if labels is None or sequence_parallel._sp_group is None: return loss - # Compute full-sequence loss in forward, but keep backward local to this rank. - reduction = str(self.sp_config.get('loss_reduction', 'mean')).lower() - if reduction == 'none': - raise ValueError("SequenceParallelStrategy.reduce_loss only supports reduction='sum' or 'mean'. " - 'Please aggregate per-token losses before calling reduce_loss.') + # Compute global loss via autograd-aware all-reduce. + reduction = str(self.sp_config.get("loss_reduction", "mean")).lower() + if reduction == "none": + raise ValueError( + "SequenceParallelStrategy.reduce_loss only supports reduction='sum' or 'mean'. " + "Please aggregate per-token losses before calling reduce_loss." 
+ ) + + class _ReduceSequenceParallelLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, local_sum: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor: + if num_valid_tokens.item() == 0: + local_sum = torch.nan_to_num(local_sum) + local_tokens = num_valid_tokens.detach().clone() + global_sum = local_sum.detach().clone() + dist.all_reduce(global_sum, group=sequence_parallel._sp_group) + global_tokens = num_valid_tokens.detach().clone() + dist.all_reduce(global_tokens, group=sequence_parallel._sp_group) + ctx.save_for_backward(local_tokens, global_tokens) + if global_tokens.item() == 0: + return local_sum + return global_sum / global_tokens + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + local_tokens, global_tokens = ctx.saved_tensors + if global_tokens.item() == 0: + return grad_output, None + grad_local_sum = grad_output * (local_tokens / global_tokens) + return grad_local_sum, None + + class _ReduceSequenceParallelSum(torch.autograd.Function): + @staticmethod + def forward(ctx, local_sum: torch.Tensor) -> torch.Tensor: + global_sum = local_sum.detach().clone() + dist.all_reduce(global_sum, group=sequence_parallel._sp_group) + return global_sum + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + return grad_output + + if reduction == "sum": + return _ReduceSequenceParallelSum.apply(loss) + + # Default to mean reduction: assume `loss` is local mean, convert to local sum. num_valid_tokens = (labels != ignore_index).sum().to(loss.device) - if reduction == 'sum': - local_sum = loss - global_sum = local_sum.detach().clone() - dist.all_reduce(global_sum, group=sequence_parallel._sp_group) - out = global_sum + (local_sum - local_sum.detach()) - if sequence_parallel.world_size > 1: - out_metric = out.detach() / sequence_parallel.world_size - return out_metric + (out - out.detach()) - return out - # Default to mean reduction. 
local_sum = loss * num_valid_tokens - global_sum = local_sum.detach().clone() - dist.all_reduce(global_sum, group=sequence_parallel._sp_group) - global_tokens = num_valid_tokens.detach().clone() - dist.all_reduce(global_tokens, group=sequence_parallel._sp_group) - if global_tokens.item() == 0: - return loss - out = (global_sum + (local_sum - local_sum.detach())) / global_tokens - if sequence_parallel.world_size > 1: - out_metric = out.detach() / sequence_parallel.world_size - return out_metric + (out - out.detach()) - return out + return _ReduceSequenceParallelLoss.apply(local_sum, num_valid_tokens) def wrap_model(self, model, optimizer=None): self.initialize() From 9823afb912d50762ae41e23c4b534bbe832e8f4e Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Fri, 13 Feb 2026 09:08:48 +0800 Subject: [PATCH 2/6] feat(sequence_parallel): compensate gradient scaling for FSDP averaging Add `compensate_fsdp_avg` config flag to adjust loss reduction when sequence parallel (SP) is combined with FSDP or accelerate DDP/FSDP. This prevents gradient magnitude from being incorrectly scaled down by an extra factor of SP world size during data-parallel averaging. - In `GatherLoss` backward, scale gradients by SP world size before splitting, so downstream FSDP averaging does not shrink this path. - In `SequenceParallelStrategy.reduce_loss`, apply a compensation factor (ulysses_size) when `compensate_fsdp_avg` is enabled. - Automatically set `compensate_fsdp_avg=True` in `TransformersModel` when using NativeFSDPStrategy or AccelerateStrategy with both SP and data parallelism active. 
--- .../model/transformers/strategy/sequence_parallel.py | 10 ++++++++-- src/twinkle/model/transformers/transformers.py | 8 +++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel.py b/src/twinkle/model/transformers/strategy/sequence_parallel.py index 359b8d79..6fdac2e9 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel.py @@ -183,6 +183,9 @@ def backward(ctx, *grad_output): # Split grads back to local sequence chunk. _grad = grad_output[0] if sequence_parallel.world_size > 1 and sequence_parallel._sp_group is not None: + # Gather replicates the sequence dimension across SP ranks. Scale once here + # so downstream FSDP avg does not shrink this path by an extra SP factor. + _grad = _grad * sequence_parallel.world_size _grad = sequence_parallel.split(_grad, dim=ctx.gather_idx, position_ids=ctx.position_ids).contiguous() return _grad, None, None, None @@ -866,7 +869,8 @@ class SequenceParallelConfig: enabled: bool = True ulysses_size: Optional[int] = None gather_logits: bool = True - loss_reduction: str = 'mean' + loss_reduction: str = "mean" + compensate_fsdp_avg: bool = False def _get_ulysses_size(device_mesh, sp_config: Optional[Dict[str, Any]] = None) -> int: @@ -976,6 +980,8 @@ def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore "SequenceParallelStrategy.reduce_loss only supports reduction='sum' or 'mean'. " "Please aggregate per-token losses before calling reduce_loss." 
) + compensate_fsdp_avg = bool(self.sp_config.get("compensate_fsdp_avg", False)) + compensate_factor = float(self.ulysses_size if compensate_fsdp_avg else 1.0) class _ReduceSequenceParallelLoss(torch.autograd.Function): @staticmethod @@ -997,7 +1003,7 @@ def backward(ctx, grad_output: torch.Tensor): local_tokens, global_tokens = ctx.saved_tensors if global_tokens.item() == 0: return grad_output, None - grad_local_sum = grad_output * (local_tokens / global_tokens) + grad_local_sum = grad_output * (local_tokens / global_tokens) * compensate_factor return grad_local_sum, None class _ReduceSequenceParallelSum(torch.autograd.Function): diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 75aa8a1d..6dc74ad9 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -247,9 +247,15 @@ def _ensure_sp_strategy(self) -> None: return from .strategy.sequence_parallel import SequenceParallelStrategy + sp_config = {} + # When data-parallel gradient averaging runs across SP shards (native FSDP or + # accelerate DDP/FSDP paths), compensate SP loss backward to keep gradient scale. 
+ if isinstance(self.strategy, (NativeFSDPStrategy, AccelerateStrategy)) and self.device_mesh is not None: + if (self.device_mesh.ulysses_size or 1) > 1 and (self.device_mesh.data_world_size or 1) > 1: + sp_config["compensate_fsdp_avg"] = True self.sp_strategy = SequenceParallelStrategy( self.device_mesh, - {}, + sp_config, model=self.model, tokenizer_id=self.tokenizer_id, ) From 9d239dad12ff7c9ca5e8955d7b49d32e1c9c460d Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Fri, 13 Feb 2026 10:06:12 +0800 Subject: [PATCH 3/6] fix(sequence_parallel): correct boundary label masking and loss reduction scaling --- .../strategy/sequence_parallel.py | 30 ++++++++++++------- .../model/transformers/transformers.py | 7 ++--- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel.py b/src/twinkle/model/transformers/strategy/sequence_parallel.py index 6fdac2e9..281b4fe6 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel.py @@ -788,7 +788,8 @@ def pad_and_split_inputs(self, # - In next-token-aligned labels, this appears at labels[b-1] boundary_starts = (real_position_ids == 0) prev = torch.zeros_like(boundary_starts, dtype=torch.bool) - prev[..., 1:] = boundary_starts[..., :-1] + # Mask token b-1 when boundary starts at b. + prev[..., :-1] = boundary_starts[..., 1:] labels = labels.clone() labels[prev] = -100 # Also avoid any potential wrap-around supervision at the end of the concatenated stream. 
@@ -982,13 +983,15 @@ def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore ) compensate_fsdp_avg = bool(self.sp_config.get("compensate_fsdp_avg", False)) compensate_factor = float(self.ulysses_size if compensate_fsdp_avg else 1.0) + sum_metric_scale = float(self.ulysses_size) class _ReduceSequenceParallelLoss(torch.autograd.Function): @staticmethod - def forward(ctx, local_sum: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor: - if num_valid_tokens.item() == 0: - local_sum = torch.nan_to_num(local_sum) + def forward(ctx, local_mean: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor: local_tokens = num_valid_tokens.detach().clone() + local_sum = local_mean * local_tokens + if local_tokens.item() == 0: + local_sum = torch.nan_to_num(local_sum) global_sum = local_sum.detach().clone() dist.all_reduce(global_sum, group=sequence_parallel._sp_group) global_tokens = num_valid_tokens.detach().clone() @@ -1002,28 +1005,33 @@ def forward(ctx, local_sum: torch.Tensor, num_valid_tokens: torch.Tensor) -> tor def backward(ctx, grad_output: torch.Tensor): local_tokens, global_tokens = ctx.saved_tensors if global_tokens.item() == 0: - return grad_output, None - grad_local_sum = grad_output * (local_tokens / global_tokens) * compensate_factor - return grad_local_sum, None + return torch.zeros_like(grad_output), None + # d(global_mean)/d(local_mean) = local_tokens / global_tokens. + grad_local_mean = grad_output * (local_tokens / global_tokens) * compensate_factor + return grad_local_mean, None class _ReduceSequenceParallelSum(torch.autograd.Function): @staticmethod def forward(ctx, local_sum: torch.Tensor) -> torch.Tensor: + ctx.sum_metric_scale = sum_metric_scale global_sum = local_sum.detach().clone() dist.all_reduce(global_sum, group=sequence_parallel._sp_group) - return global_sum + # Keep logging/metric value aligned with non-SP sum semantics under + # outer collect='mean' by removing one SP replication factor. 
+ return global_sum / ctx.sum_metric_scale @staticmethod def backward(ctx, grad_output: torch.Tensor): + # Keep training gradient scale unchanged; forward-side scaling is for + # logging/metric alignment under outer collect='mean'. return grad_output if reduction == "sum": return _ReduceSequenceParallelSum.apply(loss) - # Default to mean reduction: assume `loss` is local mean, convert to local sum. + # Default to mean reduction: `loss` is local mean. num_valid_tokens = (labels != ignore_index).sum().to(loss.device) - local_sum = loss * num_valid_tokens - return _ReduceSequenceParallelLoss.apply(local_sum, num_valid_tokens) + return _ReduceSequenceParallelLoss.apply(loss, num_valid_tokens) def wrap_model(self, model, optimizer=None): self.initialize() diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 6dc74ad9..5a46eee2 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -440,10 +440,9 @@ def calculate_loss(self, **kwargs): optimizer_config = self.optimizer_group[adapter_name] optimizer_config.num_tokens += counts.item() if self.sp_strategy is not None and 'labels' in inputs: - if 'loss_reduction' not in self.sp_strategy.sp_config: - reduction = getattr(loss_instance, 'reduction', None) - if reduction is not None: - self.sp_strategy.sp_config['loss_reduction'] = str(reduction) + reduction = getattr(loss_instance, "reduction", None) + if reduction is not None: + self.sp_strategy.sp_config["loss_reduction"] = str(reduction) loss_value = self.sp_strategy.reduce_loss(loss_value, inputs['labels']) optimizer_config.loss_value += loss_value outputs['loss'] = optimizer_config.loss_value From f04c1f8938f5c09ae941946ee13516dc33ba5f65 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Fri, 13 Feb 2026 12:58:48 +0800 Subject: [PATCH 4/6] fix lint --- .../transformers/strategy/sequence_parallel.py | 18 +++++++++--------- 
src/twinkle/model/transformers/transformers.py | 6 +++--- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel.py b/src/twinkle/model/transformers/strategy/sequence_parallel.py index 281b4fe6..64ea34f3 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel.py @@ -870,7 +870,7 @@ class SequenceParallelConfig: enabled: bool = True ulysses_size: Optional[int] = None gather_logits: bool = True - loss_reduction: str = "mean" + loss_reduction: str = 'mean' compensate_fsdp_avg: bool = False @@ -975,17 +975,16 @@ def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore if labels is None or sequence_parallel._sp_group is None: return loss # Compute global loss via autograd-aware all-reduce. - reduction = str(self.sp_config.get("loss_reduction", "mean")).lower() - if reduction == "none": - raise ValueError( - "SequenceParallelStrategy.reduce_loss only supports reduction='sum' or 'mean'. " - "Please aggregate per-token losses before calling reduce_loss." - ) - compensate_fsdp_avg = bool(self.sp_config.get("compensate_fsdp_avg", False)) + reduction = str(self.sp_config.get('loss_reduction', 'mean')).lower() + if reduction == 'none': + raise ValueError("SequenceParallelStrategy.reduce_loss only supports reduction='sum' or 'mean'. 
" + 'Please aggregate per-token losses before calling reduce_loss.') + compensate_fsdp_avg = bool(self.sp_config.get('compensate_fsdp_avg', False)) compensate_factor = float(self.ulysses_size if compensate_fsdp_avg else 1.0) sum_metric_scale = float(self.ulysses_size) class _ReduceSequenceParallelLoss(torch.autograd.Function): + @staticmethod def forward(ctx, local_mean: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor: local_tokens = num_valid_tokens.detach().clone() @@ -1011,6 +1010,7 @@ def backward(ctx, grad_output: torch.Tensor): return grad_local_mean, None class _ReduceSequenceParallelSum(torch.autograd.Function): + @staticmethod def forward(ctx, local_sum: torch.Tensor) -> torch.Tensor: ctx.sum_metric_scale = sum_metric_scale @@ -1026,7 +1026,7 @@ def backward(ctx, grad_output: torch.Tensor): # logging/metric alignment under outer collect='mean'. return grad_output - if reduction == "sum": + if reduction == 'sum': return _ReduceSequenceParallelSum.apply(loss) # Default to mean reduction: `loss` is local mean. diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 5a46eee2..6f80699b 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -252,7 +252,7 @@ def _ensure_sp_strategy(self) -> None: # accelerate DDP/FSDP paths), compensate SP loss backward to keep gradient scale. 
if isinstance(self.strategy, (NativeFSDPStrategy, AccelerateStrategy)) and self.device_mesh is not None: if (self.device_mesh.ulysses_size or 1) > 1 and (self.device_mesh.data_world_size or 1) > 1: - sp_config["compensate_fsdp_avg"] = True + sp_config['compensate_fsdp_avg'] = True self.sp_strategy = SequenceParallelStrategy( self.device_mesh, sp_config, @@ -440,9 +440,9 @@ def calculate_loss(self, **kwargs): optimizer_config = self.optimizer_group[adapter_name] optimizer_config.num_tokens += counts.item() if self.sp_strategy is not None and 'labels' in inputs: - reduction = getattr(loss_instance, "reduction", None) + reduction = getattr(loss_instance, 'reduction', None) if reduction is not None: - self.sp_strategy.sp_config["loss_reduction"] = str(reduction) + self.sp_strategy.sp_config['loss_reduction'] = str(reduction) loss_value = self.sp_strategy.reduce_loss(loss_value, inputs['labels']) optimizer_config.loss_value += loss_value outputs['loss'] = optimizer_config.loss_value From 95f01ac2a57c3079ee5e6d12896192c17913c5dc Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Fri, 13 Feb 2026 14:37:02 +0800 Subject: [PATCH 5/6] feat: add kernels optional dependency and refactor CI installation - Add 'kernels' as an optional dependency group in pyproject.toml - Refactor CI container test script to use a reusable installation function - Install twinkle with kernels in both debug and release modes for consistency - Improve maintainability by centralizing the installation command --- .dev_scripts/ci_container_test.sh | 7 ++++++- pyproject.toml | 1 + 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/.dev_scripts/ci_container_test.sh b/.dev_scripts/ci_container_test.sh index dd868b5b..66b4c7e1 100644 --- a/.dev_scripts/ci_container_test.sh +++ b/.dev_scripts/ci_container_test.sh @@ -1,3 +1,7 @@ +install_twinkle_with_kernels() { + pip install ".[kernels]" -i https://mirrors.aliyun.com/pypi/simple/ || pip install ".[kernels]" +} + if [ 
"$MODELSCOPE_SDK_DEBUG" == "True" ]; then # pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple git config --global --add safe.directory /twinkle @@ -25,8 +29,9 @@ if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then pip install optimum # test with install - pip install . + install_twinkle_with_kernels else + install_twinkle_with_kernels echo "Running case in release image, run case directly!" fi # remove torch_extensions folder to avoid ci hang. diff --git a/pyproject.toml b/pyproject.toml index 6f1bd429..76ca660d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ transformers = [ "torch>=2.6.0,<3.0.0", "torchvision", ] +kernels = ["kernels"] megatron = ["megatron-core>=0.12.0", "transformer-engine[pytorch]"] vllm = ["vllm>=0.11"] ray = ["ray[serve]"] From de5bc35996beffe17270d25ba97d0ffbd7bcec8a Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Fri, 13 Feb 2026 14:56:49 +0800 Subject: [PATCH 6/6] feat(kernel): add backward compatibility for kernels API changes Update `_load_from_hub` function to handle API changes in `select_revision_or_version` and `get_kernel` calls. The changes introduce try-except blocks to catch `TypeError` exceptions, allowing the function to work with both modern keyword-based APIs and older positional argument variants. This ensures compatibility across different versions of the kernels module without breaking existing functionality. 
--- src/twinkle/kernel/function.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/twinkle/kernel/function.py b/src/twinkle/kernel/function.py index 5f2bc130..94a2d817 100644 --- a/src/twinkle/kernel/function.py +++ b/src/twinkle/kernel/function.py @@ -35,8 +35,16 @@ def impl(*args, **kwargs): from kernels._versions import select_revision_or_version from kernels.utils import get_kernel assert repo_id is not None - resolved = select_revision_or_version(repo_id, revision, version) - kernel = get_kernel(repo_id, revision=resolved) + # kernels API changed across versions; use keyword args for modern API + # and fall back to repo_id-only for older variants. + try: + resolved = select_revision_or_version(repo_id, revision=revision, version=version) + except TypeError: + resolved = select_revision_or_version(repo_id) + try: + kernel = get_kernel(repo_id, revision=resolved) + except TypeError: + kernel = get_kernel(repo_id, resolved) func = getattr(kernel, func_name, None) if func is None: raise AttributeError(f'Kernel repo {repo_id} does not export {func_name}.')