diff --git a/cookbook/legacy/single_controller_sp.py b/cookbook/transformers/sp_fsdp_dense.py
similarity index 62%
rename from cookbook/legacy/single_controller_sp.py
rename to cookbook/transformers/sp_fsdp_dense.py
index 995d59d9..99c2e4ec 100644
--- a/cookbook/legacy/single_controller_sp.py
+++ b/cookbook/transformers/sp_fsdp_dense.py
@@ -1,6 +1,5 @@
 from functools import partial
 
 import numpy as np
-import torch
 from peft import LoraConfig
 import twinkle
@@ -12,6 +11,7 @@
 logger = get_logger()
 
 MODEL_ID = 'ms://Qwen/Qwen2.5-7B-Instruct'
+DATASETS = 'ms://swift/self-cognition'
 
 device_group = [
     DeviceGroup(
@@ -30,79 +30,70 @@
 )
 
 twinkle.initialize(
-    mode="ray",
+    mode="local",
     nproc_per_node=4,
-    groups=device_group,
     global_device_mesh=device_mesh,
     lazy_collect=False,
 )
 
+
+def eval(model):
+    dataloader = DataLoader(
+        dataset=partial(create_dataset, data_slice=range(100)),
+        batch_size=4,
+        device_mesh=device_mesh,
+    )
+    for _, batch in enumerate(dataloader):
+        model.forward_only(inputs=batch, adapter_name="default")
+        model.calculate_loss(adapter_name="default")
+    return model.calculate_metric(is_training=False, adapter_name="default")
+
+
 def create_dataset(data_slice=None):
     dataset = Dataset(
-        dataset_meta=DatasetMeta("ms://swift/self-cognition", data_slice=data_slice)
+        dataset_meta=DatasetMeta(DATASETS, data_slice=data_slice if data_slice is not None else range(500))
     )
     dataset.set_template(
         "Template",
-        model_id=MODEL_ID,
-        truncation_strategy="left",
-        max_length=64,
+        model_id=MODEL_ID
     )
     dataset.map(SelfCognitionProcessor("twinkle模型", "twinkle团队"))
     dataset.encode(batched=True)
     return dataset
-
-
-def eval(model: TransformersModel):
-    dataloader = DataLoader(
-        dataset=partial(create_dataset, data_slice=range(20)),
-        batch_size=4,
-        drop_last=True,
-        device_mesh=device_mesh,
-        remote_group="default",
-    )
-    for step, batch in enumerate(dataloader):
-        model.forward_only(inputs=batch, adapter_name="default")
-        model.calculate_loss(adapter_name="default")
-        metrics = model.calculate_metric(is_training=False, adapter_name="default")
-    return metrics()
-
-
 def train():
     dataloader = DataLoader(
         dataset=partial(create_dataset, data_slice=None),
-        batch_size=4,
+        batch_size=8,
         device_mesh=device_mesh,
-        remote_group="default",
     )
     model = TransformersModel(
         model_id=MODEL_ID,
         device_mesh=device_mesh,
         strategy="native_fsdp",
-        remote_group="default",
     )
     lora_config = LoraConfig(target_modules="all-linear")
     model.add_adapter_to_model("default", lora_config, gradient_accumulation_steps=1)
     model.set_optimizer("AdamW", lr=1e-4, adapter_name="default")
+    model.set_lr_scheduler(
+        scheduler_cls="CosineWarmupScheduler",
+        num_warmup_steps=5,
+        num_training_steps=len(dataloader),
+        adapter_name="default",
+    )
+
+    logger.info(model.get_train_configs(adapter_name="default"))
+    logger.info(f"Total steps: {len(dataloader)}")
+
-    loss_metric = 99.0
     for step, batch in enumerate(dataloader):
-        if isinstance(batch, list) and len(batch) == 0:
-            continue
-        output = model.forward_backward(inputs=batch, adapter_name="default")
-        loss_value = output() if callable(output) else output
-        logger.info(f"step {step}, loss: {loss_value}")
+        model.forward_backward(inputs=batch, adapter_name="default")
         model.clip_grad_and_step(adapter_name="default")
-        if step % 50 == 0 and step > 0:
-            metrics = eval(model)
-            logger.info(f"Current is step {step} of {len(dataloader)}, metric: {metrics}")
-            metrics["step"] = step
-            if loss_metric > metrics["loss"]:
-                model.save(f"checkpoint-{step}")
-                loss_metric = metrics["loss"]
+        if step % 20 == 0:
+            metric = model.calculate_metric(is_training=True, adapter_name="default")
+            logger.info(f"step {step}/{len(dataloader)}, metric: {metric}")
+        model.save("last-checkpoint", interval=1)
 
 
 if __name__ == "__main__":
-    train()
+    train()
\ No newline at end of file
diff --git a/cookbook/transformers/sp_fsdp_dense.sh b/cookbook/transformers/sp_fsdp_dense.sh
new file mode 100644
index 00000000..9603780e
--- /dev/null
+++ b/cookbook/transformers/sp_fsdp_dense.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+# To enable sequence parallelism, set ulysses_size > 1, e.g.:
+# device_mesh = DeviceMesh(
+#     device_type="cuda",
+#     mesh=np.arange(4).reshape(2, 2),
+#     mesh_dim_names=("dp", "fsdp"),
+#     ulysses_size=2,
+# )
+#
+CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 sp_fsdp_dense.py
\ No newline at end of file
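For reference, the Ulysses-enabled mesh described by the comment in sp_fsdp_dense.sh looks like this in Python. This is a sketch only, not part of the patch, and it assumes the DeviceMesh constructor accepts exactly the keyword arguments shown in that comment:

    import numpy as np
    from twinkle.utils import DeviceMesh

    # 4 GPUs arranged as 2 (dp) x 2 (fsdp); ulysses_size=2 additionally splits every
    # sequence across two ranks (Ulysses sequence parallelism).
    device_mesh = DeviceMesh(
        device_type="cuda",
        mesh=np.arange(4).reshape(2, 2),
        mesh_dim_names=("dp", "fsdp"),
        ulysses_size=2,
    )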
diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel.py b/src/twinkle/model/transformers/strategy/sequence_parallel.py
index 8e2aa749..e3a95a79 100644
--- a/src/twinkle/model/transformers/strategy/sequence_parallel.py
+++ b/src/twinkle/model/transformers/strategy/sequence_parallel.py
@@ -1,7 +1,5 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
-import math
 from functools import partial
-from types import SimpleNamespace
 from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
@@ -10,16 +8,11 @@
 from dataclasses import asdict, dataclass, is_dataclass
 
 from twinkle.utils import DeviceMesh
+from twinkle.utils.transformers_utils import get_llm_model
 
 
-def get_llm_model(model):  # type: ignore
-    return getattr(model, "language_model", model)
-
-
-class HfConfigFactory:  # type: ignore
-    @staticmethod
-    def get_config_attr(config, attr_name: str, include_vit: bool = False):
-        return getattr(config, attr_name, None)
+def get_config_attr(config, key, default=None):
+    return getattr(config, key, default)
 
 
 def get_cu_seqlens_from_position_ids(position_ids: torch.LongTensor):
@@ -196,8 +189,8 @@ def forward(ctx, loss, labels, gather_idx=None, position_ids=None):
 
     @staticmethod
     def backward(ctx, *grad_output):
-        # Split grads back to local sequence chunk; scale for all-gather semantics.
-        _grad = grad_output[0] * sequence_parallel.world_size
+        # Split grads back to local sequence chunk.
+        _grad = grad_output[0]
         if sequence_parallel.world_size > 1 and sequence_parallel._sp_group is not None:
             _grad = sequence_parallel.split(_grad, dim=ctx.gather_idx, position_ids=ctx.position_ids).contiguous()
         return _grad, None, None, None
@@ -468,7 +461,31 @@ def _attention(query, key, value, *args, **kwargs):
             query = query.transpose(1, 2)
             key = key.transpose(1, 2)
             value = value.transpose(1, 2)
-            if 'cu_seq_lens_q' in kwargs:
+            # Packed batches (produced by PackingDataset + padding_free collate) require FA2 varlen
+            # semantics to avoid cross-subsequence attention. We derive cu_seqlens from position_ids
+            # resets (0,1,...) and pass cu_seq_lens_* to FA2.
+            if self.extra_kwargs.get('is_packed', False):
+                position_ids = kwargs.get('position_ids')
+                if position_ids is None:
+                    position_ids = self.real_position_ids
+                # Treat SP-alignment padding (-1) as separate 1-token sequences by mapping -1 -> 0.
+                pos = position_ids
+                if pos.dim() == 1:
+                    pos = pos.unsqueeze(0)
+                pos = pos.clone()
+                pos[pos < 0] = 0
+
+                cu_seqlens = get_cu_seqlens_from_position_ids(pos).to(torch.int32)
+                max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+                assert query.shape[2] == cu_seqlens[-1]
+                kwargs['cu_seq_lens_q'] = cu_seqlens
+                kwargs['cu_seq_lens_k'] = cu_seqlens
+                kwargs['max_length_q'] = max_seqlen
+                kwargs['max_length_k'] = max_seqlen
+                # Do not use attention_mask-based unpadding when using explicit cu_seqlens.
+                if len(args) > 0:
+                    args = (None, *args[1:])
+            elif 'cu_seq_lens_q' in kwargs:
                 position_ids = kwargs.get('position_ids')
                 if position_ids is None:
                     position_ids = self.real_position_ids
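The varlen handling above hinges on cumulative sequence lengths derived from position_id resets. The stand-alone sketch below is independent of the patch; cu_seqlens_from_position_ids is a stand-in for what get_cu_seqlens_from_position_ids is assumed to compute. It shows the convention: position_ids restart at 0 at every subsequence boundary, and cu_seqlens lists the packed offsets FlashAttention2 expects:

    import torch

    def cu_seqlens_from_position_ids(position_ids: torch.LongTensor) -> torch.Tensor:
        # A packed row such as [0, 1, 2, 0, 1, 0, 1, 2, 3] encodes three subsequences
        # of lengths 3, 2 and 4; every 0 marks the start of a new subsequence.
        flat = position_ids.flatten()
        starts = torch.nonzero(flat == 0, as_tuple=False).flatten()
        ends = torch.cat([starts[1:], torch.tensor([flat.numel()])])
        lengths = ends - starts
        return torch.nn.functional.pad(lengths.cumsum(0), (1, 0)).to(torch.int32)

    pos = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])
    print(cu_seqlens_from_position_ids(pos))  # tensor([0, 3, 5, 9], dtype=torch.int32)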
@@ -495,6 +512,14 @@ def local_sdpa_attn(module: torch.nn.Module, query_states, key_states, value_sta
         if self.world_size == 1 or module.__class__ not in [m.__class__ for m in text_model.modules()]:
             return ALL_ATTENTION_FUNCTIONS['sdpa_origin'](module, query_states, key_states, value_states,
                                                           attention_mask, *args, **kwargs)
+        # Policy: packed (PackingDataset/padding-free) batches require FlashAttention2 varlen/packed semantics.
+        # SDPA does not have a native packed/varlen interface; supporting packed batches would require building a
+        # large block-diagonal causal mask (slow / memory heavy).
+        if self.extra_kwargs.get('is_packed', False):
+            raise RuntimeError(
+                'SequenceParallel: detected packed batch (position_ids contains multiple sequences). '
+                'SDPA backend is not supported for packed batches; please use flash_attention_2.'
+            )
         if dist_attn.local_attn is None:
 
             def _attention(query, key, value, *args, **kwargs):
@@ -582,7 +607,7 @@ def _is_moe_model(config) -> bool:
     if 'Moe' in config.__class__.__name__:
         return True
     for key in ['num_experts', 'num_experts_per_tok', 'moe_intermediate_size']:
-        if HfConfigFactory.get_config_attr(config, key):
+        if get_config_attr(config, key):
             return True
     return False
 
@@ -591,18 +616,16 @@ def prepare(
         sp_size: int,
         model: torch.nn.Module,
         tokenizer: PreTrainedTokenizer,
-        padding_free: bool,
        device_mesh: Optional[DeviceMesh] = None,
     ):
-        self.num_heads = HfConfigFactory.get_config_attr(model.config, 'num_key_value_heads')
+        self.num_heads = get_config_attr(model.config, 'num_key_value_heads')
         if self.num_heads is None:
-            self.num_heads = HfConfigFactory.get_config_attr(model.config, 'num_attention_heads')
+            self.num_heads = get_config_attr(model.config, 'num_attention_heads')
         assert self.num_heads is not None, 'Cannot find num_heads config in config.json'
         if sp_size > 1 and self.num_heads % sp_size != 0:
             raise ValueError(
                 f'sp_size ({sp_size}) must divide num_heads ({self.num_heads}) for ulysses sequence parallel.'
             )
-        self.padding_free = padding_free
         self.world_size = sp_size
         llm_model = get_llm_model(model)
@@ -627,8 +650,6 @@ def prepare(
 
         self.model_dtype = next(model.parameters()).dtype
         self.tokenizer = tokenizer
-        if not self.padding_free:
-            pass
 
     def pad(self, tensor, padding_value, position_ids=None, dim=1):
         """Pad tensor for sequence parallel"""
@@ -683,26 +704,6 @@ def split(self, input, dim: int, position_ids=None):
             output = tensor_list[rank].contiguous()
         return output
 
-    def pad_and_split_mm_tokens(self, visual_mask, mm_embeds):
-        input_ids = self.extra_kwargs['input_ids']
-        empty_embeds = torch.empty(
-            (input_ids.shape[0], input_ids.shape[1], mm_embeds.shape[-1])).to(mm_embeds.device).to(mm_embeds.dtype)
-        empty_embeds[visual_mask] = mm_embeds
-
-        embeds = SimpleNamespace(weight=mm_embeds)
-
-        _, split_input_embeds, _, _, _, _, extra_values = self.pad_and_split_inputs(
-            None,
-            empty_embeds,
-            None,
-            None,
-            None,
-            None,
-            embeds,
-            self.real_position_ids,
-            extra_split_values=[(visual_mask, 0, -1)])
-        visual_mask = extra_values[0]
-        return visual_mask, split_input_embeds[visual_mask]
 
     def pad_and_split_inputs(self,
                              input_ids,
@@ -731,6 +732,8 @@ def pad_and_split_inputs(self,
         """
         tokenizer = self.tokenizer
         real_position_ids = real_position_ids if real_position_ids is not None else position_ids
+        # Track packed batches to drive attention backend behavior (packed => require flash_attention_2 varlen).
+        self.extra_kwargs['is_packed'] = self._is_packed_position_ids(real_position_ids)
         extra_values = []
         batch_size = input_ids.shape[
             0] if input_ids is not None else input_embeds.shape[0] if input_embeds is not None else None
@@ -754,12 +757,20 @@
             loss_scale = self.pad(loss_scale, padding_value=0., position_ids=real_position_ids)
         if real_position_ids is not None:
             real_position_ids = self.pad(real_position_ids, padding_value=-1, position_ids=real_position_ids)
+        # Build a 2D attention_mask whenever we padded for SP alignment so FlashAttention2 can unpad correctly.
+        # For packed batches (batch_size==1 with multiple position_id resets), relying on position_ids alone is
+        # unsafe if we also appended SP-alignment padding (position_ids=-1), because HF's FA2 varlen path will
+        # include the padded tail in the last segment when attention_mask is None.
         if (input_ids is not None or input_embeds is not None) and batch_size > 1:
+            # not padding_free, so not ring-attention
             inputs = input_ids if input_ids is not None else input_embeds
             attn_shape = inputs.shape[1]  # The sequence length
             if attention_mask is None:
-                attention_mask = torch.ones_like(real_position_ids)
+                # Mask out padded positions introduced by sequence-parallel padding.
+                # `real_position_ids` is padded with `-1` (see above), so use it to build a valid-token mask.
+                attention_mask = (real_position_ids != -1).to(dtype=torch.int64)
                 # no need position_ids here, because padding_free does not need attention_mask,
                 # so this is not ring-attention
                 attention_mask = self.pad(attention_mask, padding_value=0)
             cache_position = torch.arange(0, attn_shape, device=inputs.device)
             # pad attention mask to 4d to avoid calculation errors
@@ -775,8 +786,22 @@
         if input_embeds is not None:
             input_embeds = self.split(input_embeds, dim=1, position_ids=real_position_ids)
         if labels is not None:
-            # Align next-token labels before splitting so each rank sees local targets.
-            labels = torch.roll(labels, shifts=-1, dims=-1)
+            if self.extra_kwargs.get('is_packed', False) and real_position_ids is not None:
+                # PackingDataset + padding_free collate concatenates multiple sequences into a single token stream.
+                # `position_ids` resets to 0 at each boundary, but our labels are already next-token aligned by
+                # Template._roll_labels(). Therefore the cross-subsequence supervision term lives at the *previous*
+                # token index (the token right before a boundary start).
+                #
+                # Example (boundary at index b where position_ids[b] == 0):
+                #   - Bad term is: token[b-1] predicting token[b]
+                #   - In next-token-aligned labels, this appears at labels[b-1]
+                boundary_starts = (real_position_ids == 0)
+                prev = torch.zeros_like(boundary_starts, dtype=torch.bool)
+                prev[..., 1:] = boundary_starts[..., :-1]
+                labels = labels.clone()
+                labels[prev] = -100
+                # Also avoid any potential wrap-around supervision at the end of the concatenated stream.
+                labels[..., -1] = -100
             labels = self.split(labels, dim=-1, position_ids=real_position_ids)
         if loss_scale is not None:
             loss_scale = torch.roll(loss_scale, shifts=-1, dims=-1)
@@ -802,6 +827,27 @@ def _init_device_mesh(self, device_mesh: Optional[DeviceMesh] = None):
         if self._sp_group is None and self.sp_world_size > 1:
             raise RuntimeError("Failed to create sequence-parallel group from DeviceMesh.")
 
+    @staticmethod
+    def _is_packed_position_ids(position_ids: Optional[torch.Tensor]) -> bool:
+        """Heuristic: detect packed samples by multiple (0,1,...) resets in position_ids.
+
+        PackingDataset packs multiple sequences into one row by resetting position_ids to 0/1/... at each boundary.
+        """
+        if position_ids is None or not torch.is_tensor(position_ids):
+            return False
+        if position_ids.dim() == 1:
+            position_ids = position_ids.unsqueeze(0)
+        if position_ids.dim() != 2:
+            return False
+        # A batch may contain multiple packed samples; consider it "packed" if any row is packed.
+        for i in range(position_ids.size(0)):
+            row = position_ids[i]
+            zero_count = int((row == 0).sum().item())
+            one_count = int((row == 1).sum().item())
+            if zero_count > 1 and one_count > 1:
+                return True
+        return False
+
     def prepare_inputs(self, inputs):
         """Prepare inputs
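A quick sanity check of the heuristic, assuming the twinkle sources are importable (the staticmethod is reachable through the module's sequence_parallel singleton):

    import torch
    from twinkle.model.transformers.strategy.sequence_parallel import sequence_parallel

    packed = torch.tensor([[0, 1, 2, 3, 0, 1, 2]])  # two subsequences packed into one row
    plain = torch.tensor([[0, 1, 2, 3, 4, 5, 6]])   # one ordinary sequence
    assert sequence_parallel._is_packed_position_ids(packed)
    assert not sequence_parallel._is_packed_position_ids(plain)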
@@ -814,6 +860,7 @@ def prepare_inputs(self, inputs):
         position_ids = inputs.get('position_ids')
         if position_ids is not None and input_ids is not None and position_ids.shape[0] == input_ids.shape[0]:
             self.extra_kwargs['position_ids'] = position_ids.clone()
+            self.extra_kwargs['is_packed'] = self._is_packed_position_ids(position_ids)
         if input_ids is not None:
             self.extra_kwargs['input_ids'] = input_ids.clone()
         if 'labels' in inputs:
@@ -831,8 +878,8 @@ class SequenceParallelConfig:
     enabled: bool = True
     ulysses_size: Optional[int] = None
-    padding_free: bool = False
     gather_logits: bool = True
+    loss_reduction: str = "mean"
 
 
 def _get_ulysses_size(device_mesh, sp_config: Optional[Dict[str, Any]] = None) -> int:
@@ -866,7 +913,6 @@ def __init__(
         self.sp_config = sp_config or {}
         self.enabled = bool(self.sp_config.get("enabled", True))
         self.ulysses_size = _get_ulysses_size(device_mesh, self.sp_config)
-        self.padding_free = bool(self.sp_config.get("padding_free", False))
         self._model_ref = model
         self._tokenizer_id = tokenizer_id
         self._tokenizer = None
@@ -901,7 +947,6 @@ def initialize(self) -> bool:
             self.ulysses_size,
             self._model_ref,
             tokenizer,
-            self.padding_free,
             device_mesh=self.device_mesh,
         )
         self._initialized = True
@@ -919,29 +964,25 @@ def postprocess_outputs(self, outputs: Any) -> Any:
             or not self.sp_config.get("gather_logits", True)
         ):
             return outputs
-        # Optionally reassemble full-seq logits for downstream consumers.
-        logits = None
-        if hasattr(outputs, "logits"):
-            logits = outputs.logits
-        elif isinstance(outputs, dict):
-            logits = outputs.get("logits")
-        elif isinstance(outputs, (list, tuple)) and len(outputs) > 0 and torch.is_tensor(outputs[0]):
-            logits = outputs[0]
+        # Twinkle expects dict-like ModelOutput containers in the main training path
+        # (uses `.get(...)` and `outputs[...] = ...`). Keep SP postprocess consistent.
+        if outputs is None or not hasattr(outputs, "get") or not hasattr(outputs, "__setitem__"):
+            raise TypeError(
+                "SequenceParallelStrategy.postprocess_outputs expects a dict-like ModelOutput. "
+                f"Got type={type(outputs)}"
+            )
+        logits = outputs.get("logits", None)
         if logits is None or not torch.is_tensor(logits) or logits.dim() < 2:
             return outputs
         gathered = sequence_parallel.gather(
             logits, dim=1, position_ids=sequence_parallel.real_position_ids
         )
-        if hasattr(outputs, "logits"):
-            outputs.logits = gathered
-            return outputs
-        if isinstance(outputs, dict):
-            outputs["logits"] = gathered
-            return outputs
-        if isinstance(outputs, (list, tuple)) and len(outputs) > 0:
-            new = list(outputs)
-            new[0] = gathered
-            return type(outputs)(new)
+        # Scheme A: SP pads to make seq_len divisible by sp_size. Trim back to the original
+        # (unpadded) length using the cached real_position_ids.
+        real_pos = sequence_parallel.real_position_ids
+        if real_pos is not None and torch.is_tensor(real_pos) and real_pos.dim() >= 2:
+            gathered = gathered[:, : real_pos.shape[1]].contiguous()
+        outputs["logits"] = gathered
         return outputs
 
     def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore_index: int = -100) -> torch.Tensor:
@@ -949,12 +990,36 @@ def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore
             return loss
         if labels is None or sequence_parallel._sp_group is None:
             return loss
-        # Reduce loss with token-count normalization across SP ranks.
+        # Compute full-sequence loss in forward, but keep backward local to this rank.
+        reduction = str(self.sp_config.get("loss_reduction", "mean")).lower()
+        if reduction == "none":
+            raise ValueError(
+                "SequenceParallelStrategy.reduce_loss only supports reduction='sum' or 'mean'. "
+                "Please aggregate per-token losses before calling reduce_loss."
+            )
         num_valid_tokens = (labels != ignore_index).sum().to(loss.device)
-        reduced_loss = loss * num_valid_tokens
-        dist.all_reduce(reduced_loss, group=sequence_parallel._sp_group)
-        dist.all_reduce(num_valid_tokens, group=sequence_parallel._sp_group)
-        return reduced_loss / num_valid_tokens
+        if reduction == "sum":
+            local_sum = loss
+            global_sum = local_sum.detach().clone()
+            dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
+            out = global_sum + (local_sum - local_sum.detach())
+            if sequence_parallel.world_size > 1:
+                out_metric = out.detach() / sequence_parallel.world_size
+                return out_metric + (out - out.detach())
+            return out
+        # Default to mean reduction.
+        local_sum = loss * num_valid_tokens
+        global_sum = local_sum.detach().clone()
+        dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
+        global_tokens = num_valid_tokens.detach().clone()
+        dist.all_reduce(global_tokens, group=sequence_parallel._sp_group)
+        if global_tokens.item() == 0:
+            return loss
+        out = (global_sum + (local_sum - local_sum.detach())) / global_tokens
+        if sequence_parallel.world_size > 1:
+            out_metric = out.detach() / sequence_parallel.world_size
+            return out_metric + (out - out.detach())
+        return out
 
     def wrap_model(self, model, optimizer=None):
         self.initialize()
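The reduce_loss rewrite above leans on a straight-through pattern, out = global.detach() + (local - local.detach()): the forward value equals the all-reduced (global) loss while gradients flow only through the rank-local term. A minimal single-process sketch of the mean branch, with two simulated ranks and hypothetical helper names (the patch's additional division by world_size, which compensates for a later cross-rank aggregation, is left out):

    import torch

    def reduce_mean_like(local_mean_losses, local_token_counts):
        # all_reduce over the SP group becomes an explicit sum across simulated ranks.
        global_sum = sum((l * n).detach() for l, n in zip(local_mean_losses, local_token_counts))
        global_tokens = sum(local_token_counts)
        outs = []
        for local_mean, n in zip(local_mean_losses, local_token_counts):
            local_sum = local_mean * n
            # Forward value: global token-weighted mean. Backward: only local_sum carries grad.
            outs.append((global_sum + (local_sum - local_sum.detach())) / global_tokens)
        return outs

    rank_losses = [torch.tensor(2.0, requires_grad=True), torch.tensor(4.0, requires_grad=True)]
    token_counts = [torch.tensor(3.0), torch.tensor(1.0)]
    outs = reduce_mean_like(rank_losses, token_counts)
    print([o.item() for o in outs])                  # [2.5, 2.5]  -- every rank sees the global mean
    outs[0].backward()
    print(rank_losses[0].grad, rank_losses[1].grad)  # tensor(0.7500) None -- gradient stays local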
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 2df9b5bf..eea51435 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -425,6 +425,10 @@ def calculate_loss(self, **kwargs):
             optimizer_config = self.optimizer_group[adapter_name]
             optimizer_config.num_tokens += counts.item()
             if self.sp_strategy is not None and 'labels' in inputs:
+                if "loss_reduction" not in self.sp_strategy.sp_config:
+                    reduction = getattr(loss_instance, "reduction", None)
+                    if reduction is not None:
+                        self.sp_strategy.sp_config["loss_reduction"] = str(reduction)
                 loss_value = self.sp_strategy.reduce_loss(loss_value, inputs['labels'])
             optimizer_config.loss_value += loss_value
             outputs['loss'] = optimizer_config.loss_value
diff --git a/src/twinkle/utils/transformers_utils.py b/src/twinkle/utils/transformers_utils.py
index 3a5ff344..2623b867 100644
--- a/src/twinkle/utils/transformers_utils.py
+++ b/src/twinkle/utils/transformers_utils.py
@@ -133,4 +133,50 @@ def get_modules_to_not_convert(model):
         if 'linear' in m.__class__.__name__.lower() and (any(n.endswith(suffix) for suffix in suffix_list)
                                                          or any(n.startswith(prefix) for prefix in prefix_list)):
             res.append(n)
-    return res if res else None
\ No newline at end of file
+    return res if res else None
+
+def get_llm_model(model, *, model_meta=None, inner_backbone: bool = True):
+    """Best-effort extraction of the LLM module from a (possibly wrapped) model.
+
+    This mirrors the common pattern used by Swift/PEFT/Accelerate stacks:
+    - unwrap parallel wrappers (DDP/FSDP/Accelerate)
+    - unwrap PEFT/Swift wrappers (if present)
+    - use `model_meta.model_arch.language_model` to locate the LLM in multimodal models
+    - optionally return the inner backbone (e.g. `QwenModel`/`LlamaModel`) via `.model`
+    """
+    # 1) Unwrap parallel wrappers (Accelerate).
+    try:
+        from accelerate.utils import extract_model_from_parallel  # type: ignore
+
+        model = extract_model_from_parallel(model)
+    except Exception:
+        pass
+
+    # 2) Unwrap PEFT wrappers.
+    try:
+        from peft import PeftModel  # type: ignore
+
+        if isinstance(model, PeftModel):
+            model = model.model
+    except Exception:
+        pass
+
+    # 3) Locate the language model module in multimodal containers via model_meta.
+    if model_meta is None:
+        model_meta = getattr(model, "model_meta", None)
+    llm_model = model
+    model_arch = getattr(model_meta, "model_arch", None) if model_meta is not None else None
+    llm_prefix = getattr(model_arch, "language_model", None) if model_arch is not None else None
+    if llm_prefix:
+        # Convention: `language_model` is a list of candidate prefixes.
+        llm_model = deep_getattr(model, llm_prefix[0])
+    else:
+        llm_model = getattr(model, "language_model", model)
+
+    # 4) Return the inner backbone if requested.
+    if inner_backbone:
+        if hasattr(llm_model, "thinker"):
+            llm_model = llm_model.thinker.model
+        elif hasattr(llm_model, "model"):
+            llm_model = llm_model.model
+    return llm_model
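A small usage sketch for the new helper, with toy stand-in modules. The container classes below are hypothetical and only the fallback path (no model_meta) is exercised; it assumes the twinkle package is importable:

    import torch
    from twinkle.utils.transformers_utils import get_llm_model

    class Backbone(torch.nn.Module):      # stands in for e.g. a LlamaModel
        pass

    class CausalLM(torch.nn.Module):      # stands in for a *ForCausalLM wrapper
        def __init__(self):
            super().__init__()
            self.model = Backbone()

    class MultimodalLM(torch.nn.Module):  # multimodal container exposing .language_model
        def __init__(self):
            super().__init__()
            self.language_model = CausalLM()

    mm = MultimodalLM()
    assert get_llm_model(mm) is mm.language_model.model             # inner backbone
    assert get_llm_model(mm, inner_backbone=False) is mm.language_model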
diff --git a/tests/sequence_parallel/test_sequence_parallel_single_attention.py b/tests/sequence_parallel/test_sequence_parallel_single_attention.py
new file mode 100644
index 00000000..4ff489a4
--- /dev/null
+++ b/tests/sequence_parallel/test_sequence_parallel_single_attention.py
@@ -0,0 +1,384 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import os
+import socket
+import sys
+import contextlib
+from datetime import timedelta
+from pathlib import Path
+
+REPO_ROOT = Path(__file__).resolve().parents[2]
+SRC_PATH = str(REPO_ROOT / "src")
+if SRC_PATH not in sys.path:
+    sys.path.insert(0, SRC_PATH)
+
+import unittest
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+from twinkle.model.transformers.strategy.sequence_parallel import (
+    DistributedAttention,
+    _get_sp_group_from_device_mesh,
+    sequence_parallel,
+)
+from twinkle.model.transformers.strategy import NativeFSDPStrategy
+from twinkle.utils import DeviceMesh
+
+
+def _find_free_port() -> int:
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+        sock.bind(("127.0.0.1", 0))
+        return sock.getsockname()[1]
+
+
+def _broadcast_params(module: torch.nn.Module) -> None:
+    for p in module.parameters():
+        dist.broadcast(p.data, src=0)
+
+
+def _enable_strict_determinism() -> None:
+    # Reduce kernel variability: avoid TF32 and prefer deterministic algorithms.
+    # Note: some collectives/kernels may not have deterministic variants; keep warn_only.
+    if hasattr(torch.backends, "cuda") and hasattr(torch.backends.cuda, "matmul"):
+        torch.backends.cuda.matmul.allow_tf32 = False
+    if hasattr(torch.backends, "cudnn"):
+        torch.backends.cudnn.allow_tf32 = False
+        torch.backends.cudnn.benchmark = False
+    if hasattr(torch, "set_float32_matmul_precision"):
+        torch.set_float32_matmul_precision("highest")
+    torch.use_deterministic_algorithms(True, warn_only=True)
+
+def _to_local(x: torch.Tensor) -> torch.Tensor:
+    # FSDP2 grads can be DTensors; unwrap to local for comparison.
+    if hasattr(x, "to_local"):
+        try:
+            return x.to_local()
+        except Exception:  # noqa: BLE001
+            pass
+    if hasattr(x, "local_tensor"):
+        try:
+            return x.local_tensor
+        except Exception:  # noqa: BLE001
+            pass
+    return x
+
+
+@contextlib.contextmanager
+def _force_sdpa_math():
+    """Force SDPA to use the math backend for stricter (more deterministic) alignment."""
+    try:
+        from torch.nn.attention import SDPBackend, sdpa_kernel
+
+        with sdpa_kernel(SDPBackend.MATH):
+            yield
+        return
+    except Exception:  # noqa: BLE001
+        pass
+
+    # Fallback for older torch versions.
+    if hasattr(torch.backends, "cuda") and hasattr(torch.backends.cuda, "sdp_kernel"):
+        with torch.backends.cuda.sdp_kernel(
+            enable_flash=False,
+            enable_math=True,
+            enable_mem_efficient=False,
+            enable_cudnn=False,
+        ):
+            yield
+    else:
+        yield
+
+
+class _SingleAttention(torch.nn.Module):
+    def __init__(self, hidden_dim: int, num_heads: int, sp_enabled: bool):
+        super().__init__()
+        assert hidden_dim % num_heads == 0
+        self.hidden_dim = hidden_dim
+        self.num_heads = num_heads
+        self.head_dim = hidden_dim // num_heads
+        self.sp_enabled = sp_enabled
+
+        self.q_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
+        self.k_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
+        self.v_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
+        self.out_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
+
+        self._dist_attn = DistributedAttention(self._local_attn, sequence_parallel)
+
+    @staticmethod
+    def _local_attn(
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attention_mask: torch.Tensor | None,
+        *,
+        position_ids=None,
+        is_causal: bool = True,
+        **_kwargs,
+    ) -> torch.Tensor:
+        # query/key/value: [B, S, H, D]
+        q = query.permute(0, 2, 1, 3).contiguous()
+        k = key.permute(0, 2, 1, 3).contiguous()
+        v = value.permute(0, 2, 1, 3).contiguous()
+        with _force_sdpa_math():
+            out = torch.nn.functional.scaled_dot_product_attention(
+                q, k, v, attn_mask=None, dropout_p=0.0, is_causal=is_causal
+            )
+        return out.permute(0, 2, 1, 3).contiguous()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        bsz, seqlen, _ = x.shape
+        q = self.q_proj(x).view(bsz, seqlen, self.num_heads, self.head_dim)
+        k = self.k_proj(x).view(bsz, seqlen, self.num_heads, self.head_dim)
+        v = self.v_proj(x).view(bsz, seqlen, self.num_heads, self.head_dim)
+
+        if self.sp_enabled and sequence_parallel.world_size and sequence_parallel.world_size > 1:
+            ctx = self._dist_attn(q, k, v, None, position_ids=None, is_causal=True)
+        else:
+            ctx = self._local_attn(q, k, v, None, position_ids=None, is_causal=True)
+
+        out = self.out_proj(ctx.reshape(bsz, seqlen, self.hidden_dim))
+        return out
+
+
+def _init_dist(rank: int, world_size: int, port: int) -> torch.device:
+    os.environ["RANK"] = str(rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["LOCAL_RANK"] = str(rank)
+    os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = str(port)
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is required for this test.")
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
+    dist.init_process_group(
+        backend="nccl",
+        rank=rank,
+        world_size=world_size,
+        init_method=f"tcp://127.0.0.1:{port}",
+        device_id=device,
+        timeout=timedelta(minutes=15),
+    )
+    dist.barrier()
+    return device
+
+
+def _setup_sp(device_mesh: DeviceMesh, sp_size: int) -> None:
+    sequence_parallel.world_size = sp_size
+    sequence_parallel._init_device_mesh(device_mesh)
+
+
+def _run_worker_single_attn(rank: int, world_size: int, port: int, padding: bool):
+    device = _init_dist(rank, world_size, port)
+    try:
+        torch.manual_seed(0)
+        torch.cuda.manual_seed_all(0)
+        _enable_strict_determinism()
+
+        # Align with VeOmni's test semantics: one SP group across the whole world.
+        sp_size = world_size
+        device_mesh = DeviceMesh.from_sizes(dp_size=world_size, ulysses_size=sp_size, device_type="cuda")
+        _setup_sp(device_mesh, sp_size)
+        sp_group = _get_sp_group_from_device_mesh(device_mesh, sp_size)
+
+        batch_size = 2
+        unpad_seq_len = 127 if padding else 128
+        hidden_dim = 256
+        num_heads = 16  # must be divisible by sp_size
+        assert num_heads % sp_size == 0
+        dtype = torch.float32  # maximize determinism for strict alignment
+
+        full_x = torch.randn(batch_size, unpad_seq_len, hidden_dim, device=device, dtype=dtype)
+        dist.broadcast(full_x, src=0)
+
+        if padding:
+            x = sequence_parallel.pad(full_x, padding_value=0.0, position_ids=None, dim=1)
+        else:
+            x = full_x
+        pad_seq_len = x.size(1)
+        assert pad_seq_len % sp_size == 0
+
+        dp_x = x.detach().requires_grad_(True)
+        sp_x_local = sequence_parallel.split(x, dim=1, position_ids=None).detach().requires_grad_(True)
+        sp_rank = dist.get_rank(sp_group) if sp_group is not None else 0
+        local = pad_seq_len // sp_size
+        start = sp_rank * local
+        end = start + local
+        valid_end = min(end, unpad_seq_len)
+        local_valid = max(valid_end - start, 0)
+
+        attn_sp = _SingleAttention(hidden_dim, num_heads, sp_enabled=True).to(device=device, dtype=dtype)
+        attn_dp = _SingleAttention(hidden_dim, num_heads, sp_enabled=False).to(device=device, dtype=dtype)
+        _broadcast_params(attn_sp)
+        attn_dp.load_state_dict(attn_sp.state_dict())
+
+        # forward (SP)
+        sp_out_local = attn_sp(sp_x_local)
+        sp_out_full = sequence_parallel.gather(sp_out_local, dim=1, position_ids=None)[:, :unpad_seq_len]
+
+        # backward (SP): VeOmni uses overlapping grad on local output, then all-reduces param grads.
+        sp_loss = sp_out_local[:, :local_valid].sum() * 2.0
+        sp_loss.backward()
+        sp_q_grad = attn_sp.q_proj.weight.grad.detach().clone()
+        sp_o_grad = attn_sp.out_proj.weight.grad.detach().clone()
+        sp_x_grad_full = sequence_parallel.gather(sp_x_local.grad.detach(), dim=1, position_ids=None)[:, :unpad_seq_len]
+
+        if sp_group is not None:
+            dist.all_reduce(sp_q_grad, op=dist.ReduceOp.SUM, group=sp_group)
+            dist.all_reduce(sp_o_grad, op=dist.ReduceOp.SUM, group=sp_group)
+
+        # Disable SP for the baseline DP run (like set_ulysses_sequence_parallel_group(None)).
+        saved_world = sequence_parallel.world_size
+        saved_sp_world = sequence_parallel.sp_world_size
+        saved_group = sequence_parallel._sp_group
+        sequence_parallel.world_size = 1
+        sequence_parallel.sp_world_size = 1
+        sequence_parallel._sp_group = None
+
+        # forward/backward (DP full sequence)
+        dp_out_full = attn_dp(dp_x)[:, :unpad_seq_len]
+        torch.testing.assert_close(dp_out_full, sp_out_full, atol=1e-6, rtol=1e-5)
+
+        dp_loss = dp_out_full.sum() * 2.0
+        dp_loss.backward()
+        dp_q_grad = attn_dp.q_proj.weight.grad.detach()
+        dp_o_grad = attn_dp.out_proj.weight.grad.detach()
+        dp_x_grad_full = dp_x.grad.detach()[:, :unpad_seq_len]
+
+        # Restore SP globals (not strictly needed, but keeps teardown clean).
+        sequence_parallel.world_size = saved_world
+        sequence_parallel.sp_world_size = saved_sp_world
+        sequence_parallel._sp_group = saved_group
+
+        torch.testing.assert_close(dp_o_grad, sp_o_grad, atol=1e-3, rtol=1e-4)
+        torch.testing.assert_close(dp_q_grad, sp_q_grad, atol=1e-3, rtol=1e-4)
+        torch.testing.assert_close(dp_x_grad_full, sp_x_grad_full, atol=1e-5, rtol=1e-5)
+    finally:
+        dist.destroy_process_group()
+
+
+def _run_worker_single_attn_fsdp(rank: int, world_size: int, port: int):
+    device = _init_dist(rank, world_size, port)
+    try:
+        torch.manual_seed(0)
+        torch.cuda.manual_seed_all(0)
+        _enable_strict_determinism()
+
+        # Use sp_size=world_size to avoid multiple independent SP groups while FSDP shards globally.
+        sp_size = world_size
+        # For FSDP+SP, SP is derived from dp/fsdp ranks. Use fsdp=world, dp=1.
+        device_mesh = DeviceMesh.from_sizes(fsdp_size=world_size, dp_size=1, ulysses_size=sp_size, device_type="cuda")
+        _setup_sp(device_mesh, sp_size)
+        sp_group = _get_sp_group_from_device_mesh(device_mesh, sp_size)
+
+        batch_size = 2
+        unpad_seq_len = 128
+        hidden_dim = 256
+        num_heads = 16
+        assert num_heads % sp_size == 0
+        dtype = torch.float32
+
+        full_x = torch.randn(batch_size, unpad_seq_len, hidden_dim, device=device, dtype=dtype)
+        dist.broadcast(full_x, src=0)
+
+        dp_x = full_x.detach().requires_grad_(True)
+        sp_x_local = sequence_parallel.split(full_x, dim=1, position_ids=None).detach().requires_grad_(True)
+
+        attn_sp = _SingleAttention(hidden_dim, num_heads, sp_enabled=True).to(device=device, dtype=dtype)
+        attn_dp = _SingleAttention(hidden_dim, num_heads, sp_enabled=False).to(device=device, dtype=dtype)
+        _broadcast_params(attn_sp)
+        attn_dp.load_state_dict(attn_sp.state_dict())
+
+        fsdp = NativeFSDPStrategy(device_mesh=device_mesh, mixed_precision="no", fsdp_config={})
+        attn_sp, _ = fsdp.wrap_model(attn_sp, optimizer=None)
+        attn_dp, _ = fsdp.wrap_model(attn_dp, optimizer=None)
+
+        # SP forward/backward: local loss; across ranks sum(loss_local) == DP full loss.
+        sp_out_local = attn_sp(sp_x_local)
+        sp_out_full = sequence_parallel.gather(sp_out_local, dim=1, position_ids=None)[:, :unpad_seq_len]
+        sp_loss = sp_out_local.sum() * 2.0
+        sp_loss.backward()
+        sp_x_grad_local = sp_x_local.grad.detach()
+        sp_x_grad_full = sequence_parallel.gather(sp_x_grad_local, dim=1, position_ids=None)[:, :unpad_seq_len]
+
+        # DP forward/backward: full loss on full sequence.
+        dp_out_full = attn_dp(dp_x)[:, :unpad_seq_len]
+        torch.testing.assert_close(dp_out_full, sp_out_full, atol=1e-6, rtol=1e-5)
+        # In FSDP, per-rank losses are effectively summed across ranks when forming param grads.
+        # DP uses identical full-seq inputs on every rank, so scale by world_size to match the
+        # global objective of SP (which partitions the sequence across ranks).
+        dp_loss = dp_out_full.sum() * (2.0 / float(world_size))
+        dp_loss.backward()
+
+        # Under FSDP2, grads are sharded; compare local shards directly (same mesh, same wrapping).
+        torch.testing.assert_close(
+            _to_local(attn_dp.out_proj.weight.grad.detach()),
+            _to_local(attn_sp.out_proj.weight.grad.detach()),
+            atol=1e-3,
+            rtol=1e-4,
+        )
+        torch.testing.assert_close(
+            _to_local(attn_dp.q_proj.weight.grad.detach()),
+            _to_local(attn_sp.q_proj.weight.grad.detach()),
+            atol=1e-3,
+            rtol=1e-4,
+        )
+        # dp_x.grad is not reduced across ranks; rescale to match the unscaled SP full loss.
+        dp_x_grad_full = dp_x.grad.detach()[:, :unpad_seq_len] * float(world_size)
+        torch.testing.assert_close(dp_x_grad_full, sp_x_grad_full, atol=1e-5, rtol=1e-5)
+
+        # Forward outputs, parameter grads and input grads are checked; no optimizer step is exercised here.
+    finally:
+        dist.destroy_process_group()
+
+
+class TestSequenceParallelSingleAttention(unittest.TestCase):
+    def test_single_attention(self):
+        if not dist.is_available():
+            self.skipTest("torch.distributed is not available")
+        if not torch.cuda.is_available():
+            self.skipTest("CUDA is required for this test.")
+        world_size = 4
+        if torch.cuda.device_count() < world_size:
+            self.skipTest("Requires at least 4 GPUs for sequence-parallel attention test.")
+        port = _find_free_port()
+        mp.spawn(
+            _run_worker_single_attn,
+            args=(world_size, port, False),
+            nprocs=world_size,
+            join=True,
+        )
+
+    def test_single_attention_padding(self):
+        if not dist.is_available():
+            self.skipTest("torch.distributed is not available")
+        if not torch.cuda.is_available():
+            self.skipTest("CUDA is required for this test.")
+        world_size = 4
+        if torch.cuda.device_count() < world_size:
+            self.skipTest("Requires at least 4 GPUs for sequence-parallel attention test.")
+        port = _find_free_port()
+        mp.spawn(
+            _run_worker_single_attn,
+            args=(world_size, port, True),
+            nprocs=world_size,
+            join=True,
+        )
+
+    def test_single_attention_fsdp(self):
+        if not dist.is_available():
+            self.skipTest("torch.distributed is not available")
+        if not torch.cuda.is_available():
+            self.skipTest("CUDA is required for this test.")
+        world_size = 4
+        if torch.cuda.device_count() < world_size:
+            self.skipTest("Requires at least 4 GPUs for sequence-parallel attention test.")
+        port = _find_free_port()
+        mp.spawn(
+            _run_worker_single_attn_fsdp,
+            args=(world_size, port),
+            nprocs=world_size,
+            join=True,
+        )