From bc3eea5754f1549fdec89542b5193758914dcd8c Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Fri, 6 Feb 2026 10:03:13 +0800 Subject: [PATCH 01/14] feat(tests): add sequence parallel single attention test Add a new test file `test_sequence_parallel_single_attention.py` to verify the correctness of the sequence parallel attention implementation. The test includes a distributed setup using torch.distributed and compares outputs between sequence parallel and local attention modes. Also adds an empty `__init__.py` to the transformers test directory for proper module imports. --- tests/transformers/__init__.py | 2 + ...test_sequence_parallel_single_attention.py | 318 ++++++++++++++++++ 2 files changed, 320 insertions(+) create mode 100644 tests/transformers/__init__.py create mode 100644 tests/transformers/test_sequence_parallel_single_attention.py diff --git a/tests/transformers/__init__.py b/tests/transformers/__init__.py new file mode 100644 index 00000000..475cff64 --- /dev/null +++ b/tests/transformers/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. + diff --git a/tests/transformers/test_sequence_parallel_single_attention.py b/tests/transformers/test_sequence_parallel_single_attention.py new file mode 100644 index 00000000..15a85e94 --- /dev/null +++ b/tests/transformers/test_sequence_parallel_single_attention.py @@ -0,0 +1,318 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
+import os +import socket +import sys +import contextlib +from datetime import timedelta +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[2] +SRC_PATH = str(REPO_ROOT / "src") +if SRC_PATH not in sys.path: + sys.path.insert(0, SRC_PATH) + +import unittest + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from twinkle.model.transformers.strategy.sequence_parallel import ( + DistributedAttention, + _get_sp_group_from_device_mesh, + sequence_parallel, +) +from twinkle.model.transformers.strategy import NativeFSDPStrategy +from twinkle.utils import DeviceMesh + + +def _find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return sock.getsockname()[1] + + +def _broadcast_params(module: torch.nn.Module) -> None: + for p in module.parameters(): + dist.broadcast(p.data, src=0) + + +@contextlib.contextmanager +def _force_sdpa_math(): + """Force SDPA to use the math backend for stricter (more deterministic) alignment.""" + try: + from torch.nn.attention import SDPBackend, sdpa_kernel + + with sdpa_kernel(SDPBackend.MATH): + yield + return + except Exception: # noqa: BLE001 + pass + + # Fallback for older torch versions. 
+ if hasattr(torch.backends, "cuda") and hasattr(torch.backends.cuda, "sdp_kernel"): + with torch.backends.cuda.sdp_kernel( + enable_flash=False, + enable_math=True, + enable_mem_efficient=False, + enable_cudnn=False, + ): + yield + else: + yield + + +class _SingleAttention(torch.nn.Module): + def __init__(self, hidden_dim: int, num_heads: int, sp_enabled: bool): + super().__init__() + assert hidden_dim % num_heads == 0 + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.head_dim = hidden_dim // num_heads + self.sp_enabled = sp_enabled + + self.q_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + self.k_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + self.v_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + self.out_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + + self._dist_attn = DistributedAttention(self._local_attn, sequence_parallel) + + @staticmethod + def _local_attn( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + *, + position_ids=None, + is_causal: bool = True, + **_kwargs, + ) -> torch.Tensor: + # query/key/value: [B, S, H, D] + q = query.permute(0, 2, 1, 3).contiguous() + k = key.permute(0, 2, 1, 3).contiguous() + v = value.permute(0, 2, 1, 3).contiguous() + with _force_sdpa_math(): + out = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None, dropout_p=0.0, is_causal=is_causal + ) + return out.permute(0, 2, 1, 3).contiguous() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + bsz, seqlen, _ = x.shape + q = self.q_proj(x).view(bsz, seqlen, self.num_heads, self.head_dim) + k = self.k_proj(x).view(bsz, seqlen, self.num_heads, self.head_dim) + v = self.v_proj(x).view(bsz, seqlen, self.num_heads, self.head_dim) + + if self.sp_enabled and sequence_parallel.world_size and sequence_parallel.world_size > 1: + ctx = self._dist_attn(q, k, v, None, position_ids=None, is_causal=True) + else: + ctx = 
self._local_attn(q, k, v, None, position_ids=None, is_causal=True) + + out = self.out_proj(ctx.reshape(bsz, seqlen, self.hidden_dim)) + return out + + +def _init_dist(rank: int, world_size: int, port: int) -> torch.device: + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["LOCAL_WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for this test.") + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1" + dist.init_process_group( + backend="nccl", + rank=rank, + world_size=world_size, + init_method=f"tcp://127.0.0.1:{port}", + device_id=device, + timeout=timedelta(minutes=15), + ) + dist.barrier() + return device + + +def _setup_sp(device_mesh: DeviceMesh, sp_size: int) -> None: + sequence_parallel.world_size = sp_size + sequence_parallel._init_device_mesh(device_mesh) + + +def _run_worker_single_attn(rank: int, world_size: int, port: int, padding: bool): + device = _init_dist(rank, world_size, port) + try: + torch.manual_seed(0) + torch.cuda.manual_seed_all(0) + + sp_size = 2 + device_mesh = DeviceMesh.from_sizes(dp_size=world_size, ulysses_size=sp_size, device_type="cuda") + _setup_sp(device_mesh, sp_size) + + batch_size = 2 + unpad_seq_len = 127 if padding else 128 + hidden_dim = 256 + num_heads = 8 # must be divisible by sp_size + assert num_heads % sp_size == 0 + + full_x = torch.randn(batch_size, unpad_seq_len, hidden_dim, device=device, dtype=torch.bfloat16) + dist.broadcast(full_x, src=0) + + if padding: + x = sequence_parallel.pad(full_x, padding_value=0.0, position_ids=None, dim=1) + else: + x = full_x + + dp_x = x.detach().requires_grad_(True) + sp_x_local = sequence_parallel.split(x, dim=1, position_ids=None).detach().requires_grad_(True) + + attn_sp = _SingleAttention(hidden_dim, 
num_heads, sp_enabled=True).to(device=device, dtype=torch.bfloat16) + attn_dp = _SingleAttention(hidden_dim, num_heads, sp_enabled=False).to(device=device, dtype=torch.bfloat16) + _broadcast_params(attn_sp) + attn_dp.load_state_dict(attn_sp.state_dict()) + + # forward + sp_out_local = attn_sp(sp_x_local) + sp_out_full = sequence_parallel.gather(sp_out_local, dim=1, position_ids=None)[:, :unpad_seq_len] + + dp_out_full = attn_dp(dp_x)[:, :unpad_seq_len] + + torch.testing.assert_close(dp_out_full, sp_out_full, atol=2e-5, rtol=1e-5) + + # backward (use local loss; sum grads across SP group) + sp_loss = sp_out_local.sum() * 2.0 + sp_loss.backward() + sp_q_grad = attn_sp.q_proj.weight.grad.detach().clone() + sp_o_grad = attn_sp.out_proj.weight.grad.detach().clone() + dist.all_reduce(sp_q_grad, group=sequence_parallel._sp_group) + dist.all_reduce(sp_o_grad, group=sequence_parallel._sp_group) + sp_x_grad_full = sequence_parallel.gather(sp_x_local.grad.detach(), dim=1, position_ids=None)[:, :unpad_seq_len] + + dp_loss = dp_out_full.sum() * 2.0 + dp_loss.backward() + dp_q_grad = attn_dp.q_proj.weight.grad.detach().clone() + dp_o_grad = attn_dp.out_proj.weight.grad.detach().clone() + dp_x_grad_full = dp_x.grad.detach()[:, :unpad_seq_len] + + torch.testing.assert_close(dp_o_grad, sp_o_grad, atol=2e-3, rtol=1e-4) + torch.testing.assert_close(dp_q_grad, sp_q_grad, atol=2e-3, rtol=1e-4) + torch.testing.assert_close(dp_x_grad_full, sp_x_grad_full, atol=2e-5, rtol=1e-5) + finally: + dist.destroy_process_group() + + +def _run_worker_single_attn_fsdp(rank: int, world_size: int, port: int): + device = _init_dist(rank, world_size, port) + try: + torch.manual_seed(0) + torch.cuda.manual_seed_all(0) + + sp_size = 2 + # For FSDP+SP, SP is derived from dp/fsdp ranks. Use fsdp=world, dp=1. 
+ device_mesh = DeviceMesh.from_sizes(fsdp_size=world_size, dp_size=1, ulysses_size=sp_size, device_type="cuda") + _setup_sp(device_mesh, sp_size) + sp_group = _get_sp_group_from_device_mesh(device_mesh, sp_size) + + batch_size = 2 + unpad_seq_len = 128 + hidden_dim = 256 + num_heads = 8 + assert num_heads % sp_size == 0 + + full_x = torch.randn(batch_size, unpad_seq_len, hidden_dim, device=device, dtype=torch.bfloat16) + dist.broadcast(full_x, src=0) + + # Each SP rank uses its local slice loss; across SP ranks this equals the full loss. + # For comparing input grads, compare the local slice grad against the corresponding slice of baseline. + sp_rank = dist.get_rank(sp_group) if sp_group is not None else 0 + local = unpad_seq_len // sp_size + start = sp_rank * local + end = start + local + + dp_x = full_x.detach().requires_grad_(True) + sp_x_local = full_x[:, start:end].detach().requires_grad_(True) + + attn_sp = _SingleAttention(hidden_dim, num_heads, sp_enabled=True).to(device=device, dtype=torch.bfloat16) + attn_dp = _SingleAttention(hidden_dim, num_heads, sp_enabled=False).to(device=device, dtype=torch.bfloat16) + _broadcast_params(attn_sp) + attn_dp.load_state_dict(attn_sp.state_dict()) + + fsdp = NativeFSDPStrategy(device_mesh=device_mesh, mixed_precision="bf16", fsdp_config={}) + attn_sp, _ = fsdp.wrap_model(attn_sp, optimizer=None) + attn_dp, _ = fsdp.wrap_model(attn_dp, optimizer=None) + + sp_out_local = attn_sp(sp_x_local) + sp_out_full = sequence_parallel.gather(sp_out_local, dim=1, position_ids=None)[:, :unpad_seq_len] + dp_out_full = attn_dp(dp_x)[:, :unpad_seq_len] + + torch.testing.assert_close(dp_out_full, sp_out_full, atol=2e-5, rtol=1e-5) + + sp_loss = sp_out_local.sum() * 2.0 + dp_loss = dp_out_full[:, start:end].sum() * 2.0 + sp_loss.backward() + dp_loss.backward() + + # Under FSDP2, grads are sharded; compare local shards directly (same mesh, same wrapping). 
+ torch.testing.assert_close( + attn_dp.out_proj.weight.grad.detach(), attn_sp.out_proj.weight.grad.detach(), atol=2e-3, rtol=1e-4 + ) + torch.testing.assert_close( + attn_dp.q_proj.weight.grad.detach(), attn_sp.q_proj.weight.grad.detach(), atol=2e-3, rtol=1e-4 + ) + torch.testing.assert_close(dp_x.grad.detach()[:, start:end], sp_x_local.grad.detach(), atol=2e-5, rtol=1e-5) + finally: + dist.destroy_process_group() + + +class TestSequenceParallelSingleAttention(unittest.TestCase): + def test_single_attention(self): + if not dist.is_available(): + self.skipTest("torch.distributed is not available") + if not torch.cuda.is_available(): + self.skipTest("CUDA is required for this test.") + world_size = 4 + if torch.cuda.device_count() < world_size: + self.skipTest("Requires at least 4 GPUs for sequence-parallel attention test.") + port = _find_free_port() + mp.spawn( + _run_worker_single_attn, + args=(world_size, port, False), + nprocs=world_size, + join=True, + ) + + def test_single_attention_padding(self): + if not dist.is_available(): + self.skipTest("torch.distributed is not available") + if not torch.cuda.is_available(): + self.skipTest("CUDA is required for this test.") + world_size = 4 + if torch.cuda.device_count() < world_size: + self.skipTest("Requires at least 4 GPUs for sequence-parallel attention test.") + port = _find_free_port() + mp.spawn( + _run_worker_single_attn, + args=(world_size, port, True), + nprocs=world_size, + join=True, + ) + + def test_single_attention_fsdp(self): + if not dist.is_available(): + self.skipTest("torch.distributed is not available") + if not torch.cuda.is_available(): + self.skipTest("CUDA is required for this test.") + world_size = 4 + if torch.cuda.device_count() < world_size: + self.skipTest("Requires at least 4 GPUs for sequence-parallel attention test.") + port = _find_free_port() + mp.spawn( + _run_worker_single_attn_fsdp, + args=(world_size, port), + nprocs=world_size, + join=True, + ) From 
4858e37dafda226480ba78286cdca1b29af05c1b Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Fri, 6 Feb 2026 10:22:16 +0800 Subject: [PATCH 02/14] wip --- ...test_sequence_parallel_single_attention.py | 31 ++++++++++++------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/tests/transformers/test_sequence_parallel_single_attention.py b/tests/transformers/test_sequence_parallel_single_attention.py index 15a85e94..568cd041 100644 --- a/tests/transformers/test_sequence_parallel_single_attention.py +++ b/tests/transformers/test_sequence_parallel_single_attention.py @@ -152,6 +152,7 @@ def _run_worker_single_attn(rank: int, world_size: int, port: int, padding: bool sp_size = 2 device_mesh = DeviceMesh.from_sizes(dp_size=world_size, ulysses_size=sp_size, device_type="cuda") _setup_sp(device_mesh, sp_size) + sp_group = _get_sp_group_from_device_mesh(device_mesh, sp_size) batch_size = 2 unpad_seq_len = 127 if padding else 128 @@ -166,9 +167,17 @@ def _run_worker_single_attn(rank: int, world_size: int, port: int, padding: bool x = sequence_parallel.pad(full_x, padding_value=0.0, position_ids=None, dim=1) else: x = full_x + pad_seq_len = x.size(1) + assert pad_seq_len % sp_size == 0 dp_x = x.detach().requires_grad_(True) sp_x_local = sequence_parallel.split(x, dim=1, position_ids=None).detach().requires_grad_(True) + sp_rank = dist.get_rank(sp_group) if sp_group is not None else 0 + local = pad_seq_len // sp_size + start = sp_rank * local + end = start + local + valid_end = min(end, unpad_seq_len) + local_valid = max(valid_end - start, 0) attn_sp = _SingleAttention(hidden_dim, num_heads, sp_enabled=True).to(device=device, dtype=torch.bfloat16) attn_dp = _SingleAttention(hidden_dim, num_heads, sp_enabled=False).to(device=device, dtype=torch.bfloat16) @@ -183,24 +192,22 @@ def _run_worker_single_attn(rank: int, world_size: int, port: int, padding: bool torch.testing.assert_close(dp_out_full, sp_out_full, atol=2e-5, rtol=1e-5) - # backward (use 
local loss; sum grads across SP group) - sp_loss = sp_out_local.sum() * 2.0 + # backward: compare *local-slice* gradients directly (avoid any extra scaling assumptions). + sp_loss = sp_out_local[:, :local_valid].sum() * 2.0 sp_loss.backward() - sp_q_grad = attn_sp.q_proj.weight.grad.detach().clone() - sp_o_grad = attn_sp.out_proj.weight.grad.detach().clone() - dist.all_reduce(sp_q_grad, group=sequence_parallel._sp_group) - dist.all_reduce(sp_o_grad, group=sequence_parallel._sp_group) - sp_x_grad_full = sequence_parallel.gather(sp_x_local.grad.detach(), dim=1, position_ids=None)[:, :unpad_seq_len] + sp_q_grad = attn_sp.q_proj.weight.grad.detach() + sp_o_grad = attn_sp.out_proj.weight.grad.detach() + sp_x_grad_local = sp_x_local.grad.detach()[:, :local_valid] - dp_loss = dp_out_full.sum() * 2.0 + dp_loss = dp_out_full[:, start:valid_end].sum() * 2.0 dp_loss.backward() - dp_q_grad = attn_dp.q_proj.weight.grad.detach().clone() - dp_o_grad = attn_dp.out_proj.weight.grad.detach().clone() - dp_x_grad_full = dp_x.grad.detach()[:, :unpad_seq_len] + dp_q_grad = attn_dp.q_proj.weight.grad.detach() + dp_o_grad = attn_dp.out_proj.weight.grad.detach() + dp_x_grad_local = dp_x.grad.detach()[:, start:valid_end] torch.testing.assert_close(dp_o_grad, sp_o_grad, atol=2e-3, rtol=1e-4) torch.testing.assert_close(dp_q_grad, sp_q_grad, atol=2e-3, rtol=1e-4) - torch.testing.assert_close(dp_x_grad_full, sp_x_grad_full, atol=2e-5, rtol=1e-5) + torch.testing.assert_close(dp_x_grad_local, sp_x_grad_local, atol=2e-5, rtol=1e-5) finally: dist.destroy_process_group() From 1855ed74eaf5c107cdf6eaa4c54e1ce8d2201e32 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Fri, 6 Feb 2026 17:07:32 +0800 Subject: [PATCH 03/14] feat(tests): enhance sequence parallel attention test determinism - Add `_enable_strict_determinism` helper to disable TF32 and enable deterministic algorithms - Add `_to_local` helper to unwrap DTensors for gradient comparison - Update test to use full world size 
for sequence parallel group and increase head count - Switch to float32 dtype for stricter numerical alignment - Improve gradient comparison by cloning and unwrapping tensors --- ...test_sequence_parallel_single_attention.py | 143 +++++++++++++----- 1 file changed, 101 insertions(+), 42 deletions(-) diff --git a/tests/transformers/test_sequence_parallel_single_attention.py b/tests/transformers/test_sequence_parallel_single_attention.py index 568cd041..4ff489a4 100644 --- a/tests/transformers/test_sequence_parallel_single_attention.py +++ b/tests/transformers/test_sequence_parallel_single_attention.py @@ -37,6 +37,33 @@ def _broadcast_params(module: torch.nn.Module) -> None: dist.broadcast(p.data, src=0) +def _enable_strict_determinism() -> None: + # Reduce kernel variability: avoid TF32 and prefer deterministic algorithms. + # Note: some collectives/kernels may not have deterministic variants; keep warn_only. + if hasattr(torch.backends, "cuda") and hasattr(torch.backends.cuda, "matmul"): + torch.backends.cuda.matmul.allow_tf32 = False + if hasattr(torch.backends, "cudnn"): + torch.backends.cudnn.allow_tf32 = False + torch.backends.cudnn.benchmark = False + if hasattr(torch, "set_float32_matmul_precision"): + torch.set_float32_matmul_precision("highest") + torch.use_deterministic_algorithms(True, warn_only=True) + +def _to_local(x: torch.Tensor) -> torch.Tensor: + # FSDP2 grads can be DTensors; unwrap to local for comparison. 
+ if hasattr(x, "to_local"): + try: + return x.to_local() + except Exception: # noqa: BLE001 + pass + if hasattr(x, "local_tensor"): + try: + return x.local_tensor + except Exception: # noqa: BLE001 + pass + return x + + @contextlib.contextmanager def _force_sdpa_math(): """Force SDPA to use the math backend for stricter (more deterministic) alignment.""" @@ -96,7 +123,7 @@ def _local_attn( with _force_sdpa_math(): out = torch.nn.functional.scaled_dot_product_attention( q, k, v, attn_mask=None, dropout_p=0.0, is_causal=is_causal - ) + ) return out.permute(0, 2, 1, 3).contiguous() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -148,8 +175,10 @@ def _run_worker_single_attn(rank: int, world_size: int, port: int, padding: bool try: torch.manual_seed(0) torch.cuda.manual_seed_all(0) + _enable_strict_determinism() - sp_size = 2 + # Align with VeOmni's test semantics: one SP group across the whole world. + sp_size = world_size device_mesh = DeviceMesh.from_sizes(dp_size=world_size, ulysses_size=sp_size, device_type="cuda") _setup_sp(device_mesh, sp_size) sp_group = _get_sp_group_from_device_mesh(device_mesh, sp_size) @@ -157,10 +186,11 @@ def _run_worker_single_attn(rank: int, world_size: int, port: int, padding: bool batch_size = 2 unpad_seq_len = 127 if padding else 128 hidden_dim = 256 - num_heads = 8 # must be divisible by sp_size + num_heads = 16 # must be divisible by sp_size assert num_heads % sp_size == 0 + dtype = torch.float32 # maximize determinism for strict alignment - full_x = torch.randn(batch_size, unpad_seq_len, hidden_dim, device=device, dtype=torch.bfloat16) + full_x = torch.randn(batch_size, unpad_seq_len, hidden_dim, device=device, dtype=dtype) dist.broadcast(full_x, src=0) if padding: @@ -179,35 +209,52 @@ def _run_worker_single_attn(rank: int, world_size: int, port: int, padding: bool valid_end = min(end, unpad_seq_len) local_valid = max(valid_end - start, 0) - attn_sp = _SingleAttention(hidden_dim, num_heads, 
sp_enabled=True).to(device=device, dtype=torch.bfloat16) - attn_dp = _SingleAttention(hidden_dim, num_heads, sp_enabled=False).to(device=device, dtype=torch.bfloat16) + attn_sp = _SingleAttention(hidden_dim, num_heads, sp_enabled=True).to(device=device, dtype=dtype) + attn_dp = _SingleAttention(hidden_dim, num_heads, sp_enabled=False).to(device=device, dtype=dtype) _broadcast_params(attn_sp) attn_dp.load_state_dict(attn_sp.state_dict()) - # forward + # forward (SP) sp_out_local = attn_sp(sp_x_local) sp_out_full = sequence_parallel.gather(sp_out_local, dim=1, position_ids=None)[:, :unpad_seq_len] - dp_out_full = attn_dp(dp_x)[:, :unpad_seq_len] - - torch.testing.assert_close(dp_out_full, sp_out_full, atol=2e-5, rtol=1e-5) - - # backward: compare *local-slice* gradients directly (avoid any extra scaling assumptions). + # backward (SP): VeOmni uses overlapping grad on local output, then all-reduces param grads. sp_loss = sp_out_local[:, :local_valid].sum() * 2.0 sp_loss.backward() - sp_q_grad = attn_sp.q_proj.weight.grad.detach() - sp_o_grad = attn_sp.out_proj.weight.grad.detach() - sp_x_grad_local = sp_x_local.grad.detach()[:, :local_valid] + sp_q_grad = attn_sp.q_proj.weight.grad.detach().clone() + sp_o_grad = attn_sp.out_proj.weight.grad.detach().clone() + sp_x_grad_full = sequence_parallel.gather(sp_x_local.grad.detach(), dim=1, position_ids=None)[:, :unpad_seq_len] + + if sp_group is not None: + dist.all_reduce(sp_q_grad, op=dist.ReduceOp.SUM, group=sp_group) + dist.all_reduce(sp_o_grad, op=dist.ReduceOp.SUM, group=sp_group) + + # Disable SP for the baseline DP run (like set_ulysses_sequence_parallel_group(None)). 
+ saved_world = sequence_parallel.world_size + saved_sp_world = sequence_parallel.sp_world_size + saved_group = sequence_parallel._sp_group + sequence_parallel.world_size = 1 + sequence_parallel.sp_world_size = 1 + sequence_parallel._sp_group = None + + # forward/backward (DP full sequence) + dp_out_full = attn_dp(dp_x)[:, :unpad_seq_len] + torch.testing.assert_close(dp_out_full, sp_out_full, atol=1e-6, rtol=1e-5) - dp_loss = dp_out_full[:, start:valid_end].sum() * 2.0 + dp_loss = dp_out_full.sum() * 2.0 dp_loss.backward() dp_q_grad = attn_dp.q_proj.weight.grad.detach() dp_o_grad = attn_dp.out_proj.weight.grad.detach() - dp_x_grad_local = dp_x.grad.detach()[:, start:valid_end] + dp_x_grad_full = dp_x.grad.detach()[:, :unpad_seq_len] - torch.testing.assert_close(dp_o_grad, sp_o_grad, atol=2e-3, rtol=1e-4) - torch.testing.assert_close(dp_q_grad, sp_q_grad, atol=2e-3, rtol=1e-4) - torch.testing.assert_close(dp_x_grad_local, sp_x_grad_local, atol=2e-5, rtol=1e-5) + # Restore SP globals (not strictly needed, but keeps teardown clean). + sequence_parallel.world_size = saved_world + sequence_parallel.sp_world_size = saved_sp_world + sequence_parallel._sp_group = saved_group + + torch.testing.assert_close(dp_o_grad, sp_o_grad, atol=1e-3, rtol=1e-4) + torch.testing.assert_close(dp_q_grad, sp_q_grad, atol=1e-3, rtol=1e-4) + torch.testing.assert_close(dp_x_grad_full, sp_x_grad_full, atol=1e-5, rtol=1e-5) finally: dist.destroy_process_group() @@ -217,8 +264,10 @@ def _run_worker_single_attn_fsdp(rank: int, world_size: int, port: int): try: torch.manual_seed(0) torch.cuda.manual_seed_all(0) + _enable_strict_determinism() - sp_size = 2 + # Use sp_size=world_size to avoid multiple independent SP groups while FSDP shards globally. + sp_size = world_size # For FSDP+SP, SP is derived from dp/fsdp ranks. Use fsdp=world, dp=1. 
device_mesh = DeviceMesh.from_sizes(fsdp_size=world_size, dp_size=1, ulysses_size=sp_size, device_type="cuda") _setup_sp(device_mesh, sp_size) @@ -227,50 +276,60 @@ def _run_worker_single_attn_fsdp(rank: int, world_size: int, port: int): batch_size = 2 unpad_seq_len = 128 hidden_dim = 256 - num_heads = 8 + num_heads = 16 assert num_heads % sp_size == 0 + dtype = torch.float32 - full_x = torch.randn(batch_size, unpad_seq_len, hidden_dim, device=device, dtype=torch.bfloat16) + full_x = torch.randn(batch_size, unpad_seq_len, hidden_dim, device=device, dtype=dtype) dist.broadcast(full_x, src=0) - # Each SP rank uses its local slice loss; across SP ranks this equals the full loss. - # For comparing input grads, compare the local slice grad against the corresponding slice of baseline. - sp_rank = dist.get_rank(sp_group) if sp_group is not None else 0 - local = unpad_seq_len // sp_size - start = sp_rank * local - end = start + local - dp_x = full_x.detach().requires_grad_(True) - sp_x_local = full_x[:, start:end].detach().requires_grad_(True) + sp_x_local = sequence_parallel.split(full_x, dim=1, position_ids=None).detach().requires_grad_(True) - attn_sp = _SingleAttention(hidden_dim, num_heads, sp_enabled=True).to(device=device, dtype=torch.bfloat16) - attn_dp = _SingleAttention(hidden_dim, num_heads, sp_enabled=False).to(device=device, dtype=torch.bfloat16) + attn_sp = _SingleAttention(hidden_dim, num_heads, sp_enabled=True).to(device=device, dtype=dtype) + attn_dp = _SingleAttention(hidden_dim, num_heads, sp_enabled=False).to(device=device, dtype=dtype) _broadcast_params(attn_sp) attn_dp.load_state_dict(attn_sp.state_dict()) - fsdp = NativeFSDPStrategy(device_mesh=device_mesh, mixed_precision="bf16", fsdp_config={}) + fsdp = NativeFSDPStrategy(device_mesh=device_mesh, mixed_precision="no", fsdp_config={}) attn_sp, _ = fsdp.wrap_model(attn_sp, optimizer=None) attn_dp, _ = fsdp.wrap_model(attn_dp, optimizer=None) + # SP forward/backward: local loss; across ranks 
sum(loss_local) == DP full loss. sp_out_local = attn_sp(sp_x_local) sp_out_full = sequence_parallel.gather(sp_out_local, dim=1, position_ids=None)[:, :unpad_seq_len] - dp_out_full = attn_dp(dp_x)[:, :unpad_seq_len] - - torch.testing.assert_close(dp_out_full, sp_out_full, atol=2e-5, rtol=1e-5) - sp_loss = sp_out_local.sum() * 2.0 - dp_loss = dp_out_full[:, start:end].sum() * 2.0 sp_loss.backward() + sp_x_grad_local = sp_x_local.grad.detach() + sp_x_grad_full = sequence_parallel.gather(sp_x_grad_local, dim=1, position_ids=None)[:, :unpad_seq_len] + + # DP forward/backward: full loss on full sequence. + dp_out_full = attn_dp(dp_x)[:, :unpad_seq_len] + torch.testing.assert_close(dp_out_full, sp_out_full, atol=1e-6, rtol=1e-5) + # In FSDP, per-rank losses are effectively summed across ranks when forming param grads. + # DP uses identical full-seq inputs on every rank, so scale by world_size to match the + # global objective of SP (which partitions the sequence across ranks). + dp_loss = dp_out_full.sum() * (2.0 / float(world_size)) dp_loss.backward() # Under FSDP2, grads are sharded; compare local shards directly (same mesh, same wrapping). torch.testing.assert_close( - attn_dp.out_proj.weight.grad.detach(), attn_sp.out_proj.weight.grad.detach(), atol=2e-3, rtol=1e-4 + _to_local(attn_dp.out_proj.weight.grad.detach()), + _to_local(attn_sp.out_proj.weight.grad.detach()), + atol=1e-3, + rtol=1e-4, ) torch.testing.assert_close( - attn_dp.q_proj.weight.grad.detach(), attn_sp.q_proj.weight.grad.detach(), atol=2e-3, rtol=1e-4 + _to_local(attn_dp.q_proj.weight.grad.detach()), + _to_local(attn_sp.q_proj.weight.grad.detach()), + atol=1e-3, + rtol=1e-4, ) - torch.testing.assert_close(dp_x.grad.detach()[:, start:end], sp_x_local.grad.detach(), atol=2e-5, rtol=1e-5) + # dp_x.grad is not reduced across ranks; rescale to match the unscaled SP full loss. 
+ dp_x_grad_full = dp_x.grad.detach()[:, :unpad_seq_len] * float(world_size) + torch.testing.assert_close(dp_x_grad_full, sp_x_grad_full, atol=1e-5, rtol=1e-5) + + # Only validate forward + parameter grads for the FSDP+SP case. finally: dist.destroy_process_group() From cb6e343497133d089c8073893ee5a2e978dec065 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Fri, 6 Feb 2026 17:09:25 +0800 Subject: [PATCH 04/14] remove __init__ --- tests/transformers/__init__.py | 2 -- 1 file changed, 2 deletions(-) delete mode 100644 tests/transformers/__init__.py diff --git a/tests/transformers/__init__.py b/tests/transformers/__init__.py deleted file mode 100644 index 475cff64..00000000 --- a/tests/transformers/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. - From 37774fb6a373a9ab1611d602aab8c9542d9b6fea Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Sat, 7 Feb 2026 19:44:29 +0800 Subject: [PATCH 05/14] feat(sequence_parallel): refactor config handling and remove padding-free logic - Replace HfConfigFactory utility with direct get_config_attr function - Move get_llm_model to shared transformers utilities - Remove padding_free parameter and related conditional logic - Simplify attention mask construction for padded tokens - Update SequenceParallelConfig to drop padding_free field --- .../strategy/sequence_parallel.py | 84 ++++++------------- src/twinkle/utils/transformers_utils.py | 57 ++++++++++++- 2 files changed, 82 insertions(+), 59 deletions(-) diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel.py b/src/twinkle/model/transformers/strategy/sequence_parallel.py index 8e2aa749..e970cbd5 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel.py @@ -10,16 +10,11 @@ from dataclasses import asdict, dataclass, is_dataclass from twinkle.utils import DeviceMesh +from 
twinkle.utils.transformers_utils import get_llm_model -def get_llm_model(model): # type: ignore - return getattr(model, "language_model", model) - - -class HfConfigFactory: # type: ignore - @staticmethod - def get_config_attr(config, attr_name: str, include_vit: bool = False): - return getattr(config, attr_name, None) +def get_config_attr(config, key, default=None): + return getattr(config, key, default) def get_cu_seqlens_from_position_ids(position_ids: torch.LongTensor): @@ -582,7 +577,7 @@ def _is_moe_model(config) -> bool: if 'Moe' in config.__class__.__name__: return True for key in ['num_experts', 'num_experts_per_tok', 'moe_intermediate_size']: - if HfConfigFactory.get_config_attr(config, key): + if get_config_attr(config, key): return True return False @@ -591,18 +586,16 @@ def prepare( sp_size: int, model: torch.nn.Module, tokenizer: PreTrainedTokenizer, - padding_free: bool, device_mesh: Optional[DeviceMesh] = None, ): - self.num_heads = HfConfigFactory.get_config_attr(model.config, 'num_key_value_heads') + self.num_heads = get_config_attr(model.config, 'num_key_value_heads') if self.num_heads is None: - self.num_heads = HfConfigFactory.get_config_attr(model.config, 'num_attention_heads') + self.num_heads = get_config_attr(model.config, 'num_attention_heads') assert self.num_heads is not None, 'Cannot find num_heads config in config.json' if sp_size > 1 and self.num_heads % sp_size != 0: raise ValueError( f'sp_size ({sp_size}) must divide num_heads ({self.num_heads}) for ulysses sequence parallel.' 
) - self.padding_free = padding_free self.world_size = sp_size llm_model = get_llm_model(model) @@ -627,8 +620,6 @@ def prepare( self.model_dtype = next(model.parameters()).dtype self.tokenizer = tokenizer - if not self.padding_free: - pass def pad(self, tensor, padding_value, position_ids=None, dim=1): """Pad tensor for sequence parallel""" @@ -683,26 +674,6 @@ def split(self, input, dim: int, position_ids=None): output = tensor_list[rank].contiguous() return output - def pad_and_split_mm_tokens(self, visual_mask, mm_embeds): - input_ids = self.extra_kwargs['input_ids'] - empty_embeds = torch.empty( - (input_ids.shape[0], input_ids.shape[1], mm_embeds.shape[-1])).to(mm_embeds.device).to(mm_embeds.dtype) - empty_embeds[visual_mask] = mm_embeds - - embeds = SimpleNamespace(weight=mm_embeds) - - _, split_input_embeds, _, _, _, _, extra_values = self.pad_and_split_inputs( - None, - empty_embeds, - None, - None, - None, - None, - embeds, - self.real_position_ids, - extra_split_values=[(visual_mask, 0, -1)]) - visual_mask = extra_values[0] - return visual_mask, split_input_embeds[visual_mask] def pad_and_split_inputs(self, input_ids, @@ -754,12 +725,16 @@ def pad_and_split_inputs(self, loss_scale = self.pad(loss_scale, padding_value=0., position_ids=real_position_ids) if real_position_ids is not None: real_position_ids = self.pad(real_position_ids, padding_value=-1, position_ids=real_position_ids) - if (input_ids is not None or input_embeds is not None) and batch_size > 1: + if input_ids is not None or input_embeds is not None: inputs = input_ids if input_ids is not None else input_embeds attn_shape = inputs.shape[1] # The sequence length if attention_mask is None: + # Build an attention mask from the (unpadded) real_position_ids, then pad it to the + # communication-aligned length. This keeps padded tokens from affecting attention, + # including for packed/padding-free style batches with batch_size==1. 
attention_mask = torch.ones_like(real_position_ids) - # no need position_ids here, because padding_free does not need attention_mask, + # We don't need position_ids here. When attention_mask is used, it's only to + # keep padded tokens from affecting attention computation. attention_mask = self.pad(attention_mask, padding_value=0) cache_position = torch.arange(0, attn_shape, device=inputs.device) # pad attention mask to 4d to avoid calculation errors @@ -831,7 +806,6 @@ def prepare_inputs(self, inputs): class SequenceParallelConfig: enabled: bool = True ulysses_size: Optional[int] = None - padding_free: bool = False gather_logits: bool = True @@ -866,7 +840,6 @@ def __init__( self.sp_config = sp_config or {} self.enabled = bool(self.sp_config.get("enabled", True)) self.ulysses_size = _get_ulysses_size(device_mesh, self.sp_config) - self.padding_free = bool(self.sp_config.get("padding_free", False)) self._model_ref = model self._tokenizer_id = tokenizer_id self._tokenizer = None @@ -901,7 +874,6 @@ def initialize(self) -> bool: self.ulysses_size, self._model_ref, tokenizer, - self.padding_free, device_mesh=self.device_mesh, ) self._initialized = True @@ -919,29 +891,25 @@ def postprocess_outputs(self, outputs: Any) -> Any: or not self.sp_config.get("gather_logits", True) ): return outputs - # Optionally reassemble full-seq logits for downstream consumers. - logits = None - if hasattr(outputs, "logits"): - logits = outputs.logits - elif isinstance(outputs, dict): - logits = outputs.get("logits") - elif isinstance(outputs, (list, tuple)) and len(outputs) > 0 and torch.is_tensor(outputs[0]): - logits = outputs[0] + # Twinkle expects dict-like ModelOutput containers in the main training path + # (uses `.get(...)` and `outputs[...] = ...`). Keep SP postprocess consistent. + if outputs is None or not hasattr(outputs, "get") or not hasattr(outputs, "__setitem__"): + raise TypeError( + "SequenceParallelStrategy.postprocess_outputs expects a dict-like ModelOutput. 
" + f"Got type={type(outputs)}" + ) + logits = outputs.get("logits", None) if logits is None or not torch.is_tensor(logits) or logits.dim() < 2: return outputs gathered = sequence_parallel.gather( logits, dim=1, position_ids=sequence_parallel.real_position_ids ) - if hasattr(outputs, "logits"): - outputs.logits = gathered - return outputs - if isinstance(outputs, dict): - outputs["logits"] = gathered - return outputs - if isinstance(outputs, (list, tuple)) and len(outputs) > 0: - new = list(outputs) - new[0] = gathered - return type(outputs)(new) + # Scheme A: SP pads to make seq_len divisible by sp_size. Trim back to the original + # (unpadded) length using the cached real_position_ids. + real_pos = sequence_parallel.real_position_ids + if real_pos is not None and torch.is_tensor(real_pos) and real_pos.dim() >= 2: + gathered = gathered[:, : real_pos.shape[1]].contiguous() + outputs["logits"] = gathered return outputs def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore_index: int = -100) -> torch.Tensor: diff --git a/src/twinkle/utils/transformers_utils.py b/src/twinkle/utils/transformers_utils.py index 3a5ff344..f6fccd3c 100644 --- a/src/twinkle/utils/transformers_utils.py +++ b/src/twinkle/utils/transformers_utils.py @@ -133,4 +133,59 @@ def get_modules_to_not_convert(model): if 'linear' in m.__class__.__name__.lower() and (any(n.endswith(suffix) for suffix in suffix_list) or any(n.startswith(prefix) for prefix in prefix_list)): res.append(n) - return res if res else None \ No newline at end of file + return res if res else None + + +def get_llm_model(model, *, model_meta=None, inner_backbone: bool = True): + """Best-effort extraction of the LLM module from a (possibly wrapped) model. 
+ + This mirrors the common pattern used by Swift/PEFT/Accelerate stacks: + - unwrap parallel wrappers (DDP/FSDP/Accelerate) + - unwrap PEFT/Swift wrappers (if present) + - use `model_meta.model_arch.language_model` to locate the LLM in multimodal models + - optionally return the inner backbone (e.g. `QwenModel`/`LlamaModel`) via `.model` + """ + # 1) Unwrap parallel wrappers (Accelerate). + try: + from accelerate.utils import extract_model_from_parallel # type: ignore + + model = extract_model_from_parallel(model) + except Exception: + pass + + # 2) Unwrap PEFT/Swift wrappers. + try: + from peft import PeftModel # type: ignore + + if isinstance(model, PeftModel): + model = model.model + except Exception: + pass + + try: + from swift.tuners import SwiftModel # type: ignore + + if isinstance(model, SwiftModel): + model = model.model + except Exception: + pass + + # 3) Locate the language model module in multimodal containers via model_meta. + if model_meta is None: + model_meta = getattr(model, "model_meta", None) + llm_model = model + model_arch = getattr(model_meta, "model_arch", None) if model_meta is not None else None + llm_prefix = getattr(model_arch, "language_model", None) if model_arch is not None else None + if llm_prefix: + # Convention: `language_model` is a list of candidate prefixes. + llm_model = deep_getattr(model, llm_prefix[0]) + else: + llm_model = getattr(model, "language_model", model) + + # 4) Return the inner backbone if requested. 
+ if inner_backbone: + if hasattr(llm_model, "thinker"): + llm_model = llm_model.thinker.model + elif hasattr(llm_model, "model"): + llm_model = llm_model.model + return llm_model From 4b2215ed3de703670e31766fcf6907a83b90d110 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Sun, 8 Feb 2026 21:50:37 +0800 Subject: [PATCH 06/14] feat(sequence_parallel): enforce flash_attention_2 for packed batches - Add detection of packed batches via `_is_packed_position_ids` heuristic - Raise RuntimeError when SDPA backend is used with packed batches, as SDPA lacks native packed/varlen support - Build 2D attention_mask for padded sequences to ensure correct FlashAttention2 unpad behavior - Avoid unnecessary 4D causal mask generation for packed/padding-free batches --- .../strategy/sequence_parallel.py | 125 +++++++++++++++--- .../model/transformers/transformers.py | 4 + 2 files changed, 112 insertions(+), 17 deletions(-) diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel.py b/src/twinkle/model/transformers/strategy/sequence_parallel.py index e970cbd5..4f7ccaea 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel.py @@ -191,8 +191,8 @@ def forward(ctx, loss, labels, gather_idx=None, position_ids=None): @staticmethod def backward(ctx, *grad_output): - # Split grads back to local sequence chunk; scale for all-gather semantics. - _grad = grad_output[0] * sequence_parallel.world_size + # Split grads back to local sequence chunk. 
+ _grad = grad_output[0] if sequence_parallel.world_size > 1 and sequence_parallel._sp_group is not None: _grad = sequence_parallel.split(_grad, dim=ctx.gather_idx, position_ids=ctx.position_ids).contiguous() return _grad, None, None, None @@ -463,7 +463,31 @@ def _attention(query, key, value, *args, **kwargs): query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2) - if 'cu_seq_lens_q' in kwargs: + # Packed batches (produced by PackingDataset + padding_free collate) require FA2 varlen + # semantics to avoid cross-subsequence attention. We derive cu_seqlens from position_ids + # resets (0,1,...) and pass cu_seq_lens_* to FA2. + if self.extra_kwargs.get('is_packed', False): + position_ids = kwargs.get('position_ids') + if position_ids is None: + position_ids = self.real_position_ids + # Treat SP-alignment padding (-1) as separate 1-token sequences by mapping -1 -> 0. + pos = position_ids + if pos.dim() == 1: + pos = pos.unsqueeze(0) + pos = pos.clone() + pos[pos < 0] = 0 + + cu_seqlens = get_cu_seqlens_from_position_ids(pos).to(torch.int32) + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + assert query.shape[2] == cu_seqlens[-1] + kwargs['cu_seq_lens_q'] = cu_seqlens + kwargs['cu_seq_lens_k'] = cu_seqlens + kwargs['max_length_q'] = max_seqlen + kwargs['max_length_k'] = max_seqlen + # Do not use attention_mask-based unpadding when using explicit cu_seqlens. 
+ if len(args) > 0: + args = (None, *args[1:]) + elif 'cu_seq_lens_q' in kwargs: position_ids = kwargs.get('position_ids') if position_ids is None: position_ids = self.real_position_ids @@ -490,6 +514,14 @@ def local_sdpa_attn(module: torch.nn.Module, query_states, key_states, value_sta if self.world_size == 1 or module.__class__ not in [m.__class__ for m in text_model.modules()]: return ALL_ATTENTION_FUNCTIONS['sdpa_origin'](module, query_states, key_states, value_states, attention_mask, *args, **kwargs) + # Policy: packed (PackingDataset/padding-free) batches require FlashAttention2 varlen/packed semantics. + # SDPA does not have a native packed/varlen interface; supporting packed batches would require building a + # large block-diagonal causal mask (slow / memory heavy). + if self.extra_kwargs.get('is_packed', False): + raise RuntimeError( + 'SequenceParallel: detected packed batch (position_ids contains multiple sequences). ' + 'SDPA backend is not supported for packed batches; please use flash_attention_2.' + ) if dist_attn.local_attn is None: def _attention(query, key, value, *args, **kwargs): @@ -702,6 +734,8 @@ def pad_and_split_inputs(self, """ tokenizer = self.tokenizer real_position_ids = real_position_ids if real_position_ids is not None else position_ids + # Track packed batches to drive attention backend behavior (packed => require flash_attention_2 varlen). 
+ self.extra_kwargs['is_packed'] = self._is_packed_position_ids(real_position_ids) extra_values = [] batch_size = input_ids.shape[ 0] if input_ids is not None else input_embeds.shape[0] if input_embeds is not None else None @@ -725,16 +759,20 @@ def pad_and_split_inputs(self, loss_scale = self.pad(loss_scale, padding_value=0., position_ids=real_position_ids) if real_position_ids is not None: real_position_ids = self.pad(real_position_ids, padding_value=-1, position_ids=real_position_ids) - if input_ids is not None or input_embeds is not None: + # Build a 2D attention_mask whenever we padded for SP alignment so FlashAttention2 can unpad correctly. + # For packed batches (batch_size==1 with multiple position_id resets), relying on position_ids alone is + # unsafe if we also appended SP-alignment padding (position_ids=-1), because HF's FA2 varlen path will + # include the padded tail in the last segment when attention_mask is None. + if (input_ids is not None or input_embeds is not None) and batch_size > 1: + # not padding_free, so not ring-attention inputs = input_ids if input_ids is not None else input_embeds attn_shape = inputs.shape[1] # The sequence length if attention_mask is None: - # Build an attention mask from the (unpadded) real_position_ids, then pad it to the - # communication-aligned length. This keeps padded tokens from affecting attention, - # including for packed/padding-free style batches with batch_size==1. - attention_mask = torch.ones_like(real_position_ids) - # We don't need position_ids here. When attention_mask is used, it's only to - # keep padded tokens from affecting attention computation. + # Mask out padded positions introduced by sequence-parallel padding. + # `real_position_ids` is padded with `-1` (see above), so use it to build a valid-token mask. 
+ attention_mask = (real_position_ids != -1).to(dtype=torch.int64) + # no need position_ids here, because padding_free does not need attention_mask, + # so this is not ring-attention attention_mask = self.pad(attention_mask, padding_value=0) cache_position = torch.arange(0, attn_shape, device=inputs.device) # pad attention mask to 4d to avoid calculation errors @@ -750,8 +788,22 @@ def pad_and_split_inputs(self, if input_embeds is not None: input_embeds = self.split(input_embeds, dim=1, position_ids=real_position_ids) if labels is not None: - # Align next-token labels before splitting so each rank sees local targets. - labels = torch.roll(labels, shifts=-1, dims=-1) + if self.extra_kwargs.get('is_packed', False) and real_position_ids is not None: + # PackingDataset + padding_free collate concatenates multiple sequences into a single token stream. + # `position_ids` resets to 0 at each boundary, but our labels are already next-token aligned by + # Template._roll_labels(). Therefore the cross-subsequence supervision term lives at the *previous* + # token index (the token right before a boundary start). + # + # Example (boundary at index b where position_ids[b] == 0): + # - Bad term is: token[b-1] predicting token[b] + # - In next-token-aligned labels, this appears at labels[b-1] + boundary_starts = (real_position_ids == 0) + prev = torch.zeros_like(boundary_starts, dtype=torch.bool) + prev[..., 1:] = boundary_starts[..., :-1] + labels = labels.clone() + labels[prev] = -100 + # Also avoid any potential wrap-around supervision at the end of the concatenated stream. 
+ labels[..., -1] = -100 labels = self.split(labels, dim=-1, position_ids=real_position_ids) if loss_scale is not None: loss_scale = torch.roll(loss_scale, shifts=-1, dims=-1) @@ -777,6 +829,27 @@ def _init_device_mesh(self, device_mesh: Optional[DeviceMesh] = None): if self._sp_group is None and self.sp_world_size > 1: raise RuntimeError("Failed to create sequence-parallel group from DeviceMesh.") + @staticmethod + def _is_packed_position_ids(position_ids: Optional[torch.Tensor]) -> bool: + """Heuristic: detect packed samples by multiple (0,1,...) resets in position_ids. + + PackingDataset packs multiple sequences into one row by resetting position_ids to 0/1/... at each boundary. + """ + if position_ids is None or not torch.is_tensor(position_ids): + return False + if position_ids.dim() == 1: + position_ids = position_ids.unsqueeze(0) + if position_ids.dim() != 2: + return False + # A batch may contain multiple packed samples; consider it "packed" if any row is packed. + for i in range(position_ids.size(0)): + row = position_ids[i] + zero_count = int((row == 0).sum().item()) + one_count = int((row == 1).sum().item()) + if zero_count > 1 and one_count > 1: + return True + return False + def prepare_inputs(self, inputs): """Prepare inputs @@ -789,6 +862,7 @@ def prepare_inputs(self, inputs): position_ids = inputs.get('position_ids') if position_ids is not None and input_ids is not None and position_ids.shape[0] == input_ids.shape[0]: self.extra_kwargs['position_ids'] = position_ids.clone() + self.extra_kwargs['is_packed'] = self._is_packed_position_ids(position_ids) if input_ids is not None: self.extra_kwargs['input_ids'] = input_ids.clone() if 'labels' in inputs: @@ -807,6 +881,7 @@ class SequenceParallelConfig: enabled: bool = True ulysses_size: Optional[int] = None gather_logits: bool = True + loss_reduction: str = "mean" def _get_ulysses_size(device_mesh, sp_config: Optional[Dict[str, Any]] = None) -> int: @@ -917,12 +992,28 @@ def reduce_loss(self, loss: 
torch.Tensor, labels: Optional[torch.Tensor], ignore return loss if labels is None or sequence_parallel._sp_group is None: return loss - # Reduce loss with token-count normalization across SP ranks. + # Compute full-sequence loss in forward, but keep backward local to this rank. + reduction = str(self.sp_config.get("loss_reduction", "mean")).lower() + if reduction == "none": + raise ValueError( + "SequenceParallelStrategy.reduce_loss only supports reduction='sum' or 'mean'. " + "Please aggregate per-token losses before calling reduce_loss." + ) num_valid_tokens = (labels != ignore_index).sum().to(loss.device) - reduced_loss = loss * num_valid_tokens - dist.all_reduce(reduced_loss, group=sequence_parallel._sp_group) - dist.all_reduce(num_valid_tokens, group=sequence_parallel._sp_group) - return reduced_loss / num_valid_tokens + if reduction == "sum": + local_sum = loss + global_sum = local_sum.detach().clone() + dist.all_reduce(global_sum, group=sequence_parallel._sp_group) + return global_sum + (local_sum - local_sum.detach()) + # Default to mean reduction. 
+ local_sum = loss * num_valid_tokens + global_sum = local_sum.detach().clone() + dist.all_reduce(global_sum, group=sequence_parallel._sp_group) + global_tokens = num_valid_tokens.detach().clone() + dist.all_reduce(global_tokens, group=sequence_parallel._sp_group) + if global_tokens.item() == 0: + return loss + return (global_sum + (local_sum - local_sum.detach())) / global_tokens def wrap_model(self, model, optimizer=None): self.initialize() diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 2df9b5bf..eea51435 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -425,6 +425,10 @@ def calculate_loss(self, **kwargs): optimizer_config = self.optimizer_group[adapter_name] optimizer_config.num_tokens += counts.item() if self.sp_strategy is not None and 'labels' in inputs: + if "loss_reduction" not in self.sp_strategy.sp_config: + reduction = getattr(loss_instance, "reduction", None) + if reduction is not None: + self.sp_strategy.sp_config["loss_reduction"] = str(reduction) loss_value = self.sp_strategy.reduce_loss(loss_value, inputs['labels']) optimizer_config.loss_value += loss_value outputs['loss'] = optimizer_config.loss_value From f9a8f783824ce9e7b918b499a6e9d3f2946fa955 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Mon, 9 Feb 2026 14:56:31 +0800 Subject: [PATCH 07/14] feat(sft): add single controller SP packing example for Qwen2.5-7B Introduce a new cookbook script demonstrating supervised fine-tuning with a single controller using sequence parallelism (SP) and FSDP across 4 GPUs. 
The example includes: - Device mesh configuration with dp=2 and fsdp=2 dimensions - PackingDataset setup with self-cognition data and left truncation - Training loop with LoRA adapter, AdamW optimizer, and periodic evaluation - Checkpoint saving based on loss improvement - Validation of FSDP + SP input slicing across multiple GPUs --- .../sft/single_controller_sp_packing.py | 108 ++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 cookbook/legacy/sft/single_controller_sp_packing.py diff --git a/cookbook/legacy/sft/single_controller_sp_packing.py b/cookbook/legacy/sft/single_controller_sp_packing.py new file mode 100644 index 00000000..3197aa84 --- /dev/null +++ b/cookbook/legacy/sft/single_controller_sp_packing.py @@ -0,0 +1,108 @@ +from functools import partial +import numpy as np +from peft import LoraConfig + +import twinkle +from twinkle import get_logger, DeviceGroup, Platform, DeviceMesh +from twinkle.dataloader import DataLoader +from twinkle.dataset import PackingDataset, DatasetMeta +from twinkle.model import TransformersModel +from twinkle.preprocessor import SelfCognitionProcessor + +logger = get_logger() +MODEL_ID = 'ms://Qwen/Qwen2.5-7B-Instruct' + +device_group = [ + DeviceGroup( + name="default", + ranks=[0, 1, 2, 3], + device_type=Platform.get_platform().device_prefix(), + ) +] + +# FSDP + SP validation over 4 GPUs: dp=2, fsdp=2 (SP only affects input slicing) +device_mesh = DeviceMesh( + device_type="cuda", + mesh=np.arange(4).reshape(2, 2), + mesh_dim_names=("dp", "fsdp"), + ulysses_size=2, +) + +twinkle.initialize( + mode="ray", + nproc_per_node=4, + groups=device_group, + global_device_mesh=device_mesh, + lazy_collect=False, +) + + +def create_dataset(data_slice=None): + dataset = PackingDataset( + dataset_meta=DatasetMeta("ms://swift/self-cognition", data_slice=data_slice) + ) + dataset.set_template( + "Template", + model_id=MODEL_ID, + truncation_strategy="left", + max_length=64, + ) + 
dataset.map(SelfCognitionProcessor("twinkle模型", "twinkle团队")) + dataset.encode(batched=True) + dataset.pack_dataset() + return dataset + + +def eval(model: TransformersModel): + dataloader = DataLoader( + dataset=partial(create_dataset, data_slice=range(20)), + batch_size=4, + drop_last=True, + device_mesh=device_mesh, + remote_group="default", + ) + for step, batch in enumerate(dataloader): + model.forward_only(inputs=batch, adapter_name="default") + model.calculate_loss(adapter_name="default") + metrics = model.calculate_metric(is_training=False, adapter_name="default") + return metrics() + + +def train(): + dataloader = DataLoader( + dataset=partial(create_dataset, data_slice=None), + batch_size=4, + device_mesh=device_mesh, + remote_group="default", + ) + + model = TransformersModel( + model_id=MODEL_ID, + device_mesh=device_mesh, + strategy="native_fsdp", + remote_group="default", + ) + + lora_config = LoraConfig(target_modules="all-linear") + model.add_adapter_to_model("default", lora_config, gradient_accumulation_steps=1) + model.set_optimizer("AdamW", lr=1e-4, adapter_name="default") + + loss_metric = 99.0 + for step, batch in enumerate(dataloader): + if isinstance(batch, list) and len(batch) == 0: + continue + output = model.forward_backward(inputs=batch, adapter_name="default") + loss_value = output() if callable(output) else output + logger.info(f"step {step}, loss: {loss_value}") + model.clip_grad_and_step(adapter_name="default") + if step % 50 == 0 and step > 0: + metrics = eval(model) + logger.info(f"Current is step {step} of {len(dataloader)}, metric: {metrics}") + metrics["step"] = step + if loss_metric > metrics["loss"]: + model.save(f"checkpoint-{step}") + loss_metric = metrics["loss"] + + +if __name__ == "__main__": + train() From 2086e874c9322916ba8a1d2ff2ebde143a6c7907 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Mon, 9 Feb 2026 17:35:23 +0800 Subject: [PATCH 08/14] fix loss computation bug --- 
.../sft/single_controller_sp_packing.py | 4 +- cookbook/legacy/single_controller_sp.py | 43 ++++++------------- .../strategy/sequence_parallel.py | 13 +++++- .../model/transformers/transformers.py | 1 + 4 files changed, 29 insertions(+), 32 deletions(-) diff --git a/cookbook/legacy/sft/single_controller_sp_packing.py b/cookbook/legacy/sft/single_controller_sp_packing.py index 3197aa84..24282757 100644 --- a/cookbook/legacy/sft/single_controller_sp_packing.py +++ b/cookbook/legacy/sft/single_controller_sp_packing.py @@ -8,6 +8,7 @@ from twinkle.dataset import PackingDataset, DatasetMeta from twinkle.model import TransformersModel from twinkle.preprocessor import SelfCognitionProcessor +from twinkle.processor import InputProcessor logger = get_logger() MODEL_ID = 'ms://Qwen/Qwen2.5-7B-Instruct' @@ -82,9 +83,10 @@ def train(): strategy="native_fsdp", remote_group="default", ) - lora_config = LoraConfig(target_modules="all-linear") model.add_adapter_to_model("default", lora_config, gradient_accumulation_steps=1) + model.set_processor(InputProcessor, padding_free=True, adapter_name="default") + model.set_loss("CrossEntropyLoss", reduction="mean", adapter_name="default") model.set_optimizer("AdamW", lr=1e-4, adapter_name="default") loss_metric = 99.0 diff --git a/cookbook/legacy/single_controller_sp.py b/cookbook/legacy/single_controller_sp.py index 995d59d9..63ed61c9 100644 --- a/cookbook/legacy/single_controller_sp.py +++ b/cookbook/legacy/single_controller_sp.py @@ -46,27 +46,13 @@ def create_dataset(data_slice=None): "Template", model_id=MODEL_ID, truncation_strategy="left", - max_length=64, + max_length=256, ) dataset.map(SelfCognitionProcessor("twinkle模型", "twinkle团队")) dataset.encode(batched=True) return dataset -def eval(model: TransformersModel): - dataloader = DataLoader( - dataset=partial(create_dataset, data_slice=range(20)), - batch_size=4, - drop_last=True, - device_mesh=device_mesh, - remote_group="default", - ) - for step, batch in enumerate(dataloader): 
- model.forward_only(inputs=batch, adapter_name="default") - model.calculate_loss(adapter_name="default") - metrics = model.calculate_metric(is_training=False, adapter_name="default") - return metrics() - def train(): dataloader = DataLoader( @@ -87,21 +73,20 @@ def train(): model.add_adapter_to_model("default", lora_config, gradient_accumulation_steps=1) model.set_optimizer("AdamW", lr=1e-4, adapter_name="default") - loss_metric = 99.0 for step, batch in enumerate(dataloader): - if isinstance(batch, list) and len(batch) == 0: - continue - output = model.forward_backward(inputs=batch, adapter_name="default") - loss_value = output() if callable(output) else output - logger.info(f"step {step}, loss: {loss_value}") - model.clip_grad_and_step(adapter_name="default") - if step % 50 == 0 and step > 0: - metrics = eval(model) - logger.info(f"Current is step {step} of {len(dataloader)}, metric: {metrics}") - metrics["step"] = step - if loss_metric > metrics["loss"]: - model.save(f"checkpoint-{step}") - loss_metric = metrics["loss"] + model.forward_backward(inputs=batch, adapter_name='default') + model.clip_grad_and_step(adapter_name='default') + if step % 1 == 0: + metric = model.calculate_metric(is_training=True, adapter_name='default') + _metrics = {} + for key, value in metric.items(): + try: + value = float(value) + _metrics[key] = value + except: + pass + logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') + model.save(f'last-checkpoint', interval=1) if __name__ == "__main__": diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel.py b/src/twinkle/model/transformers/strategy/sequence_parallel.py index 4f7ccaea..c931ff23 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel.py @@ -1,5 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
import math +import os from functools import partial from types import SimpleNamespace from typing import Any, Dict, Optional, Tuple, Union @@ -1004,7 +1005,11 @@ def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore local_sum = loss global_sum = local_sum.detach().clone() dist.all_reduce(global_sum, group=sequence_parallel._sp_group) - return global_sum + (local_sum - local_sum.detach()) + out = global_sum + (local_sum - local_sum.detach()) + if sequence_parallel.world_size > 1: + out_metric = out.detach() / sequence_parallel.world_size + return out_metric + (out - out.detach()) + return out # Default to mean reduction. local_sum = loss * num_valid_tokens global_sum = local_sum.detach().clone() @@ -1013,7 +1018,11 @@ def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore dist.all_reduce(global_tokens, group=sequence_parallel._sp_group) if global_tokens.item() == 0: return loss - return (global_sum + (local_sum - local_sum.detach())) / global_tokens + out = (global_sum + (local_sum - local_sum.detach())) / global_tokens + if sequence_parallel.world_size > 1: + out_metric = out.detach() / sequence_parallel.world_size + return out_metric + (out - out.detach()) + return out def wrap_model(self, model, optimizer=None): self.initialize() diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index eea51435..432d1ca3 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -1,5 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
import contextlib +import os import json import os import re From 2547659bba6846480989faf4c2120efe645e2a32 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Mon, 9 Feb 2026 19:43:48 +0800 Subject: [PATCH 09/14] feat(cookbook): add single controller SP example and reorganize transformers cookbook - Add new single_controller_sp.py example demonstrating FSDP + SP validation over 4 GPUs - Move legacy single_controller_sp.py to transformers/sp_fsdp_dense.py - Add shell script sp_fsdp_dense.sh for running the example - Update imports and structure to use twinkle framework components --- cookbook/transformers/single_controller_sp.py | 90 +++++++++++++++++++ .../sp_fsdp_dense.py} | 0 cookbook/transformers/sp_fsdp_dense.sh | 1 + 3 files changed, 91 insertions(+) create mode 100644 cookbook/transformers/single_controller_sp.py rename cookbook/{legacy/single_controller_sp.py => transformers/sp_fsdp_dense.py} (100%) create mode 100644 cookbook/transformers/sp_fsdp_dense.sh diff --git a/cookbook/transformers/single_controller_sp.py b/cookbook/transformers/single_controller_sp.py new file mode 100644 index 00000000..2c5914a1 --- /dev/null +++ b/cookbook/transformers/single_controller_sp.py @@ -0,0 +1,90 @@ +from functools import partial +import numpy as np +import torch +from peft import LoraConfig + +import twinkle +from twinkle import get_logger, DeviceGroup, Platform, DeviceMesh +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import TransformersModel +from twinkle.preprocessor import SelfCognitionProcessor + +logger = get_logger() +MODEL_ID = 'ms://Qwen/Qwen2.5-7B-Instruct' + +device_group = [ + DeviceGroup( + name="default", + ranks=[0, 1, 2, 3], + device_type=Platform.get_platform().device_prefix(), + ) +] + +# FSDP + SP validation over 4 GPUs: dp=2, fsdp=2 (SP only affects input slicing) +device_mesh = DeviceMesh( + device_type="cuda", + mesh=np.arange(4).reshape(2, 2), + 
mesh_dim_names=("dp", "fsdp"), + ulysses_size=2, +) + +twinkle.initialize( + mode="local", + nproc_per_node=4, + global_device_mesh=device_mesh, + lazy_collect=False, +) + + +def create_dataset(data_slice=None): + dataset = Dataset( + dataset_meta=DatasetMeta("ms://swift/self-cognition", data_slice=data_slice) + ) + dataset.set_template( + "Template", + model_id=MODEL_ID, + truncation_strategy="left", + max_length=256, + ) + dataset.map(SelfCognitionProcessor("twinkle模型", "twinkle团队")) + dataset.encode(batched=True) + return dataset + + + +def train(): + dataloader = DataLoader( + dataset=partial(create_dataset, data_slice=None), + batch_size=4, + device_mesh=device_mesh, + ) + + model = TransformersModel( + model_id=MODEL_ID, + device_mesh=device_mesh, + strategy="native_fsdp", + ) + + lora_config = LoraConfig(target_modules="all-linear") + model.add_adapter_to_model("default", lora_config, gradient_accumulation_steps=1) + model.set_optimizer("AdamW", lr=1e-4, adapter_name="default") + + for step, batch in enumerate(dataloader): + model.forward_backward(inputs=batch, adapter_name='default') + model.clip_grad_and_step(adapter_name='default') + if step % 1 == 0: + metric = model.calculate_metric(is_training=True, adapter_name='default') + _metrics = {} + for key, value in metric.items(): + try: + value = float(value) + _metrics[key] = value + except: + pass + logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') + model.save(f'last-checkpoint', interval=1) + + +if __name__ == "__main__": + train() diff --git a/cookbook/legacy/single_controller_sp.py b/cookbook/transformers/sp_fsdp_dense.py similarity index 100% rename from cookbook/legacy/single_controller_sp.py rename to cookbook/transformers/sp_fsdp_dense.py diff --git a/cookbook/transformers/sp_fsdp_dense.sh b/cookbook/transformers/sp_fsdp_dense.sh new file mode 100644 index 00000000..b9eabc29 --- /dev/null +++ b/cookbook/transformers/sp_fsdp_dense.sh @@ -0,0 +1 @@ 
+CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 sp_fsdp_dense.py \ No newline at end of file From db97bb2db24a26031f6ce4003588f653c9f1dfe8 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Mon, 9 Feb 2026 19:52:42 +0800 Subject: [PATCH 10/14] refactor(tests): move sequence parallel attention test to dedicated directory Relocate test_sequence_parallel_single_attention.py from tests/transformers/ to tests/sequence_parallel/ to better organize test files by feature area. This improves maintainability and aligns with the project's test structure conventions. --- .../sft/single_controller_sp_packing.py | 110 ------------------ cookbook/transformers/single_controller_sp.py | 90 -------------- ...test_sequence_parallel_single_attention.py | 0 3 files changed, 200 deletions(-) delete mode 100644 cookbook/legacy/sft/single_controller_sp_packing.py delete mode 100644 cookbook/transformers/single_controller_sp.py rename tests/{transformers => sequence_parallel}/test_sequence_parallel_single_attention.py (100%) diff --git a/cookbook/legacy/sft/single_controller_sp_packing.py b/cookbook/legacy/sft/single_controller_sp_packing.py deleted file mode 100644 index 24282757..00000000 --- a/cookbook/legacy/sft/single_controller_sp_packing.py +++ /dev/null @@ -1,110 +0,0 @@ -from functools import partial -import numpy as np -from peft import LoraConfig - -import twinkle -from twinkle import get_logger, DeviceGroup, Platform, DeviceMesh -from twinkle.dataloader import DataLoader -from twinkle.dataset import PackingDataset, DatasetMeta -from twinkle.model import TransformersModel -from twinkle.preprocessor import SelfCognitionProcessor -from twinkle.processor import InputProcessor - -logger = get_logger() -MODEL_ID = 'ms://Qwen/Qwen2.5-7B-Instruct' - -device_group = [ - DeviceGroup( - name="default", - ranks=[0, 1, 2, 3], - device_type=Platform.get_platform().device_prefix(), - ) -] - -# FSDP + SP validation over 4 GPUs: dp=2, fsdp=2 (SP only affects input 
slicing) -device_mesh = DeviceMesh( - device_type="cuda", - mesh=np.arange(4).reshape(2, 2), - mesh_dim_names=("dp", "fsdp"), - ulysses_size=2, -) - -twinkle.initialize( - mode="ray", - nproc_per_node=4, - groups=device_group, - global_device_mesh=device_mesh, - lazy_collect=False, -) - - -def create_dataset(data_slice=None): - dataset = PackingDataset( - dataset_meta=DatasetMeta("ms://swift/self-cognition", data_slice=data_slice) - ) - dataset.set_template( - "Template", - model_id=MODEL_ID, - truncation_strategy="left", - max_length=64, - ) - dataset.map(SelfCognitionProcessor("twinkle模型", "twinkle团队")) - dataset.encode(batched=True) - dataset.pack_dataset() - return dataset - - -def eval(model: TransformersModel): - dataloader = DataLoader( - dataset=partial(create_dataset, data_slice=range(20)), - batch_size=4, - drop_last=True, - device_mesh=device_mesh, - remote_group="default", - ) - for step, batch in enumerate(dataloader): - model.forward_only(inputs=batch, adapter_name="default") - model.calculate_loss(adapter_name="default") - metrics = model.calculate_metric(is_training=False, adapter_name="default") - return metrics() - - -def train(): - dataloader = DataLoader( - dataset=partial(create_dataset, data_slice=None), - batch_size=4, - device_mesh=device_mesh, - remote_group="default", - ) - - model = TransformersModel( - model_id=MODEL_ID, - device_mesh=device_mesh, - strategy="native_fsdp", - remote_group="default", - ) - lora_config = LoraConfig(target_modules="all-linear") - model.add_adapter_to_model("default", lora_config, gradient_accumulation_steps=1) - model.set_processor(InputProcessor, padding_free=True, adapter_name="default") - model.set_loss("CrossEntropyLoss", reduction="mean", adapter_name="default") - model.set_optimizer("AdamW", lr=1e-4, adapter_name="default") - - loss_metric = 99.0 - for step, batch in enumerate(dataloader): - if isinstance(batch, list) and len(batch) == 0: - continue - output = model.forward_backward(inputs=batch, 
adapter_name="default") - loss_value = output() if callable(output) else output - logger.info(f"step {step}, loss: {loss_value}") - model.clip_grad_and_step(adapter_name="default") - if step % 50 == 0 and step > 0: - metrics = eval(model) - logger.info(f"Current is step {step} of {len(dataloader)}, metric: {metrics}") - metrics["step"] = step - if loss_metric > metrics["loss"]: - model.save(f"checkpoint-{step}") - loss_metric = metrics["loss"] - - -if __name__ == "__main__": - train() diff --git a/cookbook/transformers/single_controller_sp.py b/cookbook/transformers/single_controller_sp.py deleted file mode 100644 index 2c5914a1..00000000 --- a/cookbook/transformers/single_controller_sp.py +++ /dev/null @@ -1,90 +0,0 @@ -from functools import partial -import numpy as np -import torch -from peft import LoraConfig - -import twinkle -from twinkle import get_logger, DeviceGroup, Platform, DeviceMesh -from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta -from twinkle.model import TransformersModel -from twinkle.preprocessor import SelfCognitionProcessor - -logger = get_logger() -MODEL_ID = 'ms://Qwen/Qwen2.5-7B-Instruct' - -device_group = [ - DeviceGroup( - name="default", - ranks=[0, 1, 2, 3], - device_type=Platform.get_platform().device_prefix(), - ) -] - -# FSDP + SP validation over 4 GPUs: dp=2, fsdp=2 (SP only affects input slicing) -device_mesh = DeviceMesh( - device_type="cuda", - mesh=np.arange(4).reshape(2, 2), - mesh_dim_names=("dp", "fsdp"), - ulysses_size=2, -) - -twinkle.initialize( - mode="local", - nproc_per_node=4, - global_device_mesh=device_mesh, - lazy_collect=False, -) - - -def create_dataset(data_slice=None): - dataset = Dataset( - dataset_meta=DatasetMeta("ms://swift/self-cognition", data_slice=data_slice) - ) - dataset.set_template( - "Template", - model_id=MODEL_ID, - truncation_strategy="left", - max_length=256, - ) - dataset.map(SelfCognitionProcessor("twinkle模型", "twinkle团队")) - 
dataset.encode(batched=True) - return dataset - - - -def train(): - dataloader = DataLoader( - dataset=partial(create_dataset, data_slice=None), - batch_size=4, - device_mesh=device_mesh, - ) - - model = TransformersModel( - model_id=MODEL_ID, - device_mesh=device_mesh, - strategy="native_fsdp", - ) - - lora_config = LoraConfig(target_modules="all-linear") - model.add_adapter_to_model("default", lora_config, gradient_accumulation_steps=1) - model.set_optimizer("AdamW", lr=1e-4, adapter_name="default") - - for step, batch in enumerate(dataloader): - model.forward_backward(inputs=batch, adapter_name='default') - model.clip_grad_and_step(adapter_name='default') - if step % 1 == 0: - metric = model.calculate_metric(is_training=True, adapter_name='default') - _metrics = {} - for key, value in metric.items(): - try: - value = float(value) - _metrics[key] = value - except: - pass - logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') - model.save(f'last-checkpoint', interval=1) - - -if __name__ == "__main__": - train() diff --git a/tests/transformers/test_sequence_parallel_single_attention.py b/tests/sequence_parallel/test_sequence_parallel_single_attention.py similarity index 100% rename from tests/transformers/test_sequence_parallel_single_attention.py rename to tests/sequence_parallel/test_sequence_parallel_single_attention.py From 74993e3b05f0f081c0d6b7f17f810956cc912e57 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Mon, 9 Feb 2026 20:09:21 +0800 Subject: [PATCH 11/14] feat: add sequence parallelism instructions and clean up imports - Add bash script header and comments to `sp_fsdp_dense.sh` explaining how to enable sequence parallelism with ulysses_size - Remove duplicate `import os` statement in transformers.py for cleaner code - Fix minor formatting by removing extra blank line in transformers_utils.py --- cookbook/transformers/sp_fsdp_dense.sh | 9 +++++++++ src/twinkle/model/transformers/transformers.py | 1 - 
 src/twinkle/utils/transformers_utils.py       |  1 -
 3 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/cookbook/transformers/sp_fsdp_dense.sh b/cookbook/transformers/sp_fsdp_dense.sh
index b9eabc29..9603780e 100644
--- a/cookbook/transformers/sp_fsdp_dense.sh
+++ b/cookbook/transformers/sp_fsdp_dense.sh
@@ -1 +1,10 @@
+#!/bin/bash
+# To enable sequence parallelism, please set ulysses_size > 1
+# device_mesh = DeviceMesh(
+#     device_type="cuda",
+#     mesh=np.arange(4).reshape(2, 2),
+#     mesh_dim_names=("dp", "fsdp"),
+#     ulysses_size=2,
+# )
+#
 CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 sp_fsdp_dense.py
\ No newline at end of file
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 432d1ca3..c6a0ce8d 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -2,7 +2,6 @@
 import contextlib
 import os
 import json
-import os
 import re
 from dataclasses import dataclass, field
 from typing import Dict, Any, List, Literal, Callable
diff --git a/src/twinkle/utils/transformers_utils.py b/src/twinkle/utils/transformers_utils.py
index f6fccd3c..8b0f59e1 100644
--- a/src/twinkle/utils/transformers_utils.py
+++ b/src/twinkle/utils/transformers_utils.py
@@ -135,7 +135,6 @@ def get_modules_to_not_convert(model):
         res.append(n)
     return res if res else None
 
-
 def get_llm_model(model, *, model_meta=None, inner_backbone: bool = True):
     """Best-effort extraction of the LLM module from a (possibly wrapped) model.
From c385e8b3c87c34d8966d177eff35dd3fff2513c4 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Mon, 9 Feb 2026 20:18:04 +0800 Subject: [PATCH 12/14] refactor --- src/twinkle/model/transformers/transformers.py | 2 +- src/twinkle/utils/transformers_utils.py | 10 +--------- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index c6a0ce8d..eea51435 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -1,7 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import contextlib -import os import json +import os import re from dataclasses import dataclass, field from typing import Dict, Any, List, Literal, Callable diff --git a/src/twinkle/utils/transformers_utils.py b/src/twinkle/utils/transformers_utils.py index 8b0f59e1..2623b867 100644 --- a/src/twinkle/utils/transformers_utils.py +++ b/src/twinkle/utils/transformers_utils.py @@ -152,7 +152,7 @@ def get_llm_model(model, *, model_meta=None, inner_backbone: bool = True): except Exception: pass - # 2) Unwrap PEFT/Swift wrappers. + # 2) Unwrap PEFT wrappers. try: from peft import PeftModel # type: ignore @@ -161,14 +161,6 @@ def get_llm_model(model, *, model_meta=None, inner_backbone: bool = True): except Exception: pass - try: - from swift.tuners import SwiftModel # type: ignore - - if isinstance(model, SwiftModel): - model = model.model - except Exception: - pass - # 3) Locate the language model module in multimodal containers via model_meta. 
if model_meta is None: model_meta = getattr(model, "model_meta", None) From 488d648ae46d4a700d5eba72879ddbe5c24f3c00 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Mon, 9 Feb 2026 21:03:48 +0800 Subject: [PATCH 13/14] feat: update training script with local mode and evaluation - Switch from `ray` to `local` mode for twinkle initialization - Add evaluation function with separate dataset slice - Increase dataset size from 100 to 500 samples - Add cosine warmup learning rate scheduler - Remove unused torch import and remote_group parameters - Adjust batch size from 4 to 8 and logging frequency to every 20 steps - Improve logging with train configs and total steps information --- cookbook/transformers/sp_fsdp_dense.py | 60 ++++++++++++++------------ 1 file changed, 33 insertions(+), 27 deletions(-) diff --git a/cookbook/transformers/sp_fsdp_dense.py b/cookbook/transformers/sp_fsdp_dense.py index 63ed61c9..99c2e4ec 100644 --- a/cookbook/transformers/sp_fsdp_dense.py +++ b/cookbook/transformers/sp_fsdp_dense.py @@ -1,6 +1,5 @@ from functools import partial import numpy as np -import torch from peft import LoraConfig import twinkle @@ -12,6 +11,7 @@ logger = get_logger() MODEL_ID = 'ms://Qwen/Qwen2.5-7B-Instruct' +DATASETS='ms://swift/self-cognition' device_group = [ DeviceGroup( @@ -30,64 +30,70 @@ ) twinkle.initialize( - mode="ray", + mode="local", nproc_per_node=4, - groups=device_group, global_device_mesh=device_mesh, lazy_collect=False, ) +def eval(model): + dataloader = DataLoader( + dataset=partial(create_dataset, data_slice=range(100)), + batch_size=4, + device_mesh=device_mesh, + ) + for _, batch in enumerate(dataloader): + model.forward_only(inputs=batch, adapter_name="default") + model.calculate_loss(adapter_name="default") + return model.calculate_metric(is_training=False, adapter_name="default") + def create_dataset(data_slice=None): dataset = Dataset( - dataset_meta=DatasetMeta("ms://swift/self-cognition", data_slice=data_slice) + 
dataset_meta=DatasetMeta(DATASETS, data_slice=range(500)) ) dataset.set_template( "Template", - model_id=MODEL_ID, - truncation_strategy="left", - max_length=256, + model_id=MODEL_ID ) dataset.map(SelfCognitionProcessor("twinkle模型", "twinkle团队")) dataset.encode(batched=True) return dataset - - - def train(): dataloader = DataLoader( dataset=partial(create_dataset, data_slice=None), - batch_size=4, + batch_size=8, device_mesh=device_mesh, - remote_group="default", ) model = TransformersModel( model_id=MODEL_ID, device_mesh=device_mesh, strategy="native_fsdp", - remote_group="default", ) lora_config = LoraConfig(target_modules="all-linear") model.add_adapter_to_model("default", lora_config, gradient_accumulation_steps=1) model.set_optimizer("AdamW", lr=1e-4, adapter_name="default") + model.set_lr_scheduler( + scheduler_cls="CosineWarmupScheduler", + num_warmup_steps=5, + num_training_steps=len(dataloader), + adapter_name="default", + ) + + logger.info(model.get_train_configs(adapter_name="default")) + logger.info(f"Total steps: {len(dataloader)}") + for step, batch in enumerate(dataloader): - model.forward_backward(inputs=batch, adapter_name='default') - model.clip_grad_and_step(adapter_name='default') - if step % 1 == 0: - metric = model.calculate_metric(is_training=True, adapter_name='default') - _metrics = {} - for key, value in metric.items(): - try: - value = float(value) - _metrics[key] = value - except: - pass - logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') - model.save(f'last-checkpoint', interval=1) + model.forward_backward(inputs=batch, adapter_name="default") + model.clip_grad_and_step(adapter_name="default") + if step % 20 == 0: + metric = model.calculate_metric(is_training=True, adapter_name="default") + logger.info(f"Current is step {step} of {len(dataloader)}, metric: {metric}") + model.save("last-checkpoint", interval=1) if __name__ == "__main__": - train() + train() \ No newline at end of file From 
6783a9a5ab245bc7a1f918ca8d2cc301e214d465 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Mon, 9 Feb 2026 21:06:04 +0800 Subject: [PATCH 14/14] feat(transformers): remove unused imports in sequence_parallel module Removed unnecessary imports (`math`, `os`, `SimpleNamespace`) from the sequence_parallel strategy file to clean up the codebase and improve maintainability. --- src/twinkle/model/transformers/strategy/sequence_parallel.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel.py b/src/twinkle/model/transformers/strategy/sequence_parallel.py index c931ff23..e3a95a79 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel.py @@ -1,8 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import math -import os from functools import partial -from types import SimpleNamespace from typing import Any, Dict, Optional, Tuple, Union import torch