diff --git a/.dev_scripts/ci_container_test.sh b/.dev_scripts/ci_container_test.sh
index dd868b5b..66b4c7e1 100644
--- a/.dev_scripts/ci_container_test.sh
+++ b/.dev_scripts/ci_container_test.sh
@@ -1,3 +1,7 @@
+install_twinkle_with_kernels() {
+    pip install ".[kernels]" -i https://mirrors.aliyun.com/pypi/simple/ || pip install ".[kernels]"
+}
+
 if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
     # pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
     git config --global --add safe.directory /twinkle
@@ -25,8 +29,9 @@ if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
     pip install optimum
 
     # test with install
-    pip install .
+    install_twinkle_with_kernels
 else
+    install_twinkle_with_kernels
     echo "Running case in release image, run case directly!"
 fi
 # remove torch_extensions folder to avoid ci hang.
diff --git a/pyproject.toml b/pyproject.toml
index 6f1bd429..76ca660d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,6 +22,7 @@ transformers = [
     "torch>=2.6.0,<3.0.0",
     "torchvision",
 ]
+kernels = ["kernels"]
 megatron = ["megatron-core>=0.12.0", "transformer-engine[pytorch]"]
 vllm = ["vllm>=0.11"]
 ray = ["ray[serve]"]
diff --git a/src/twinkle/kernel/function.py b/src/twinkle/kernel/function.py
index 5f2bc130..94a2d817 100644
--- a/src/twinkle/kernel/function.py
+++ b/src/twinkle/kernel/function.py
@@ -35,8 +35,16 @@ def impl(*args, **kwargs):
     from kernels._versions import select_revision_or_version
     from kernels.utils import get_kernel
     assert repo_id is not None
-    resolved = select_revision_or_version(repo_id, revision, version)
-    kernel = get_kernel(repo_id, revision=resolved)
+    # kernels API changed across versions; use keyword args for modern API
+    # and fall back to repo_id-only for older variants.
+    try:
+        resolved = select_revision_or_version(repo_id, revision=revision, version=version)
+    except TypeError:
+        resolved = select_revision_or_version(repo_id)
+    try:
+        kernel = get_kernel(repo_id, revision=resolved)
+    except TypeError:
+        kernel = get_kernel(repo_id, resolved)
     func = getattr(kernel, func_name, None)
     if func is None:
         raise AttributeError(f'Kernel repo {repo_id} does not export {func_name}.')
diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel.py b/src/twinkle/model/transformers/strategy/sequence_parallel.py
index 28285070..64ea34f3 100644
--- a/src/twinkle/model/transformers/strategy/sequence_parallel.py
+++ b/src/twinkle/model/transformers/strategy/sequence_parallel.py
@@ -183,6 +183,9 @@ def backward(ctx, *grad_output):
         # Split grads back to local sequence chunk.
         _grad = grad_output[0]
         if sequence_parallel.world_size > 1 and sequence_parallel._sp_group is not None:
+            # Gather replicates the sequence dimension across SP ranks. Scale once here
+            # so downstream FSDP avg does not shrink this path by an extra SP factor.
+            _grad = _grad * sequence_parallel.world_size
             _grad = sequence_parallel.split(_grad, dim=ctx.gather_idx,
                                             position_ids=ctx.position_ids).contiguous()
         return _grad, None, None, None
@@ -785,7 +788,8 @@ def pad_and_split_inputs(self,
         # - In next-token-aligned labels, this appears at labels[b-1]
         boundary_starts = (real_position_ids == 0)
         prev = torch.zeros_like(boundary_starts, dtype=torch.bool)
-        prev[..., 1:] = boundary_starts[..., :-1]
+        # Mask token b-1 when boundary starts at b.
+        prev[..., :-1] = boundary_starts[..., 1:]
         labels = labels.clone()
         labels[prev] = -100
         # Also avoid any potential wrap-around supervision at the end of the concatenated stream.
@@ -867,6 +871,7 @@ class SequenceParallelConfig:
     ulysses_size: Optional[int] = None
     gather_logits: bool = True
     loss_reduction: str = 'mean'
+    compensate_fsdp_avg: bool = False
 
 
 def _get_ulysses_size(device_mesh, sp_config: Optional[Dict[str, Any]] = None) -> int:
@@ -969,34 +974,64 @@ def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore
             return loss
         if labels is None or sequence_parallel._sp_group is None:
             return loss
-        # Compute full-sequence loss in forward, but keep backward local to this rank.
+        # Compute global loss via autograd-aware all-reduce.
         reduction = str(self.sp_config.get('loss_reduction', 'mean')).lower()
         if reduction == 'none':
             raise ValueError("SequenceParallelStrategy.reduce_loss only supports reduction='sum' or 'mean'. "
                              'Please aggregate per-token losses before calling reduce_loss.')
-        num_valid_tokens = (labels != ignore_index).sum().to(loss.device)
+        compensate_fsdp_avg = bool(self.sp_config.get('compensate_fsdp_avg', False))
+        compensate_factor = float(self.ulysses_size if compensate_fsdp_avg else 1.0)
+        sum_metric_scale = float(self.ulysses_size)
+
+        class _ReduceSequenceParallelLoss(torch.autograd.Function):
+
+            @staticmethod
+            def forward(ctx, local_mean: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor:
+                local_tokens = num_valid_tokens.detach().clone()
+                local_sum = local_mean * local_tokens
+                if local_tokens.item() == 0:
+                    local_sum = torch.nan_to_num(local_sum)
+                global_sum = local_sum.detach().clone()
+                dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
+                global_tokens = num_valid_tokens.detach().clone()
+                dist.all_reduce(global_tokens, group=sequence_parallel._sp_group)
+                ctx.save_for_backward(local_tokens, global_tokens)
+                if global_tokens.item() == 0:
+                    return local_sum
+                return global_sum / global_tokens
+
+            @staticmethod
+            def backward(ctx, grad_output: torch.Tensor):
+                local_tokens, global_tokens = ctx.saved_tensors
+                if global_tokens.item() == 0:
+                    return torch.zeros_like(grad_output), None
+                # d(global_mean)/d(local_mean) = local_tokens / global_tokens.
+                grad_local_mean = grad_output * (local_tokens / global_tokens) * compensate_factor
+                return grad_local_mean, None
+
+        class _ReduceSequenceParallelSum(torch.autograd.Function):
+
+            @staticmethod
+            def forward(ctx, local_sum: torch.Tensor) -> torch.Tensor:
+                ctx.sum_metric_scale = sum_metric_scale
+                global_sum = local_sum.detach().clone()
+                dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
+                # Keep logging/metric value aligned with non-SP sum semantics under
+                # outer collect='mean' by removing one SP replication factor.
+                return global_sum / ctx.sum_metric_scale
+
+            @staticmethod
+            def backward(ctx, grad_output: torch.Tensor):
+                # Keep training gradient scale unchanged; forward-side scaling is for
+                # logging/metric alignment under outer collect='mean'.
+                return grad_output
+
         if reduction == 'sum':
-            local_sum = loss
-            global_sum = local_sum.detach().clone()
-            dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
-            out = global_sum + (local_sum - local_sum.detach())
-            if sequence_parallel.world_size > 1:
-                out_metric = out.detach() / sequence_parallel.world_size
-                return out_metric + (out - out.detach())
-            return out
-        # Default to mean reduction.
-        local_sum = loss * num_valid_tokens
-        global_sum = local_sum.detach().clone()
-        dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
-        global_tokens = num_valid_tokens.detach().clone()
-        dist.all_reduce(global_tokens, group=sequence_parallel._sp_group)
-        if global_tokens.item() == 0:
-            return loss
-        out = (global_sum + (local_sum - local_sum.detach())) / global_tokens
-        if sequence_parallel.world_size > 1:
-            out_metric = out.detach() / sequence_parallel.world_size
-            return out_metric + (out - out.detach())
-        return out
+            return _ReduceSequenceParallelSum.apply(loss)
+
+        # Default to mean reduction: `loss` is local mean.
+        num_valid_tokens = (labels != ignore_index).sum().to(loss.device)
+        return _ReduceSequenceParallelLoss.apply(loss, num_valid_tokens)
 
     def wrap_model(self, model, optimizer=None):
         self.initialize()
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 75aa8a1d..6f80699b 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -247,9 +247,15 @@ def _ensure_sp_strategy(self) -> None:
             return
 
         from .strategy.sequence_parallel import SequenceParallelStrategy
+        sp_config = {}
+        # When data-parallel gradient averaging runs across SP shards (native FSDP or
+        # accelerate DDP/FSDP paths), compensate SP loss backward to keep gradient scale.
+        if isinstance(self.strategy, (NativeFSDPStrategy, AccelerateStrategy)) and self.device_mesh is not None:
+            if (self.device_mesh.ulysses_size or 1) > 1 and (self.device_mesh.data_world_size or 1) > 1:
+                sp_config['compensate_fsdp_avg'] = True
         self.sp_strategy = SequenceParallelStrategy(
             self.device_mesh,
-            {},
+            sp_config,
             model=self.model,
             tokenizer_id=self.tokenizer_id,
         )
@@ -434,10 +440,9 @@ def calculate_loss(self, **kwargs):
         optimizer_config = self.optimizer_group[adapter_name]
         optimizer_config.num_tokens += counts.item()
         if self.sp_strategy is not None and 'labels' in inputs:
-            if 'loss_reduction' not in self.sp_strategy.sp_config:
-                reduction = getattr(loss_instance, 'reduction', None)
-                if reduction is not None:
-                    self.sp_strategy.sp_config['loss_reduction'] = str(reduction)
+            reduction = getattr(loss_instance, 'reduction', None)
+            if reduction is not None:
+                self.sp_strategy.sp_config['loss_reduction'] = str(reduction)
             loss_value = self.sp_strategy.reduce_loss(loss_value, inputs['labels'])
         optimizer_config.loss_value += loss_value
         outputs['loss'] = optimizer_config.loss_value