From abf8497f057c73ef9261decffcaca1276555c165 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Mon, 23 Mar 2026 08:41:58 +0800 Subject: [PATCH 1/6] support qwen35 sp --- cookbook/transformers/sp_fsdp_dense.py | 2 + .../model/transformers/models/__init__.py | 1 + .../transformers/models/qwen3_5/__init__.py | 16 + .../models/qwen3_5/modeling_qwen3_5.py | 549 ++++++++++++++++++ .../strategy/sequence_parallel.py | 34 ++ .../test_twinkle_qwen3_5_text_model.py | 197 +++++++ 6 files changed, 799 insertions(+) create mode 100644 src/twinkle/model/transformers/models/__init__.py create mode 100644 src/twinkle/model/transformers/models/qwen3_5/__init__.py create mode 100644 src/twinkle/model/transformers/models/qwen3_5/modeling_qwen3_5.py create mode 100644 tests/sequence_parallel/test_twinkle_qwen3_5_text_model.py diff --git a/cookbook/transformers/sp_fsdp_dense.py b/cookbook/transformers/sp_fsdp_dense.py index 868b61c0..66ad0efc 100644 --- a/cookbook/transformers/sp_fsdp_dense.py +++ b/cookbook/transformers/sp_fsdp_dense.py @@ -7,6 +7,7 @@ from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.model import TransformersModel +from twinkle.model.transformers.models.qwen3_5 import TwinkleQwen3_5ForCausalLM from twinkle.preprocessor import SelfCognitionProcessor logger = get_logger() @@ -64,6 +65,7 @@ def train(): model = TransformersModel( model_id=MODEL_ID, + model_cls=TwinkleQwen3_5ForCausalLM, device_mesh=device_mesh, strategy='native_fsdp', ) diff --git a/src/twinkle/model/transformers/models/__init__.py b/src/twinkle/model/transformers/models/__init__.py new file mode 100644 index 00000000..85b3e739 --- /dev/null +++ b/src/twinkle/model/transformers/models/__init__.py @@ -0,0 +1 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. diff --git a/src/twinkle/model/transformers/models/qwen3_5/__init__.py b/src/twinkle/model/transformers/models/qwen3_5/__init__.py new file mode 100644 index 00000000..1b3cb561 --- /dev/null +++ b/src/twinkle/model/transformers/models/qwen3_5/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from .modeling_qwen3_5 import ( + TwinkleQwen3_5DecoderLayer, + TwinkleQwen3_5ForCausalLM, + TwinkleQwen3_5GatedDeltaNet, + TwinkleQwen3_5PreTrainedModel, + TwinkleQwen3_5TextModel, +) + +__all__ = [ + 'TwinkleQwen3_5PreTrainedModel', + 'TwinkleQwen3_5TextModel', + 'TwinkleQwen3_5DecoderLayer', + 'TwinkleQwen3_5GatedDeltaNet', + 'TwinkleQwen3_5ForCausalLM', +] diff --git a/src/twinkle/model/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/twinkle/model/transformers/models/qwen3_5/modeling_qwen3_5.py new file mode 100644 index 00000000..c44f21e0 --- /dev/null +++ b/src/twinkle/model/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -0,0 +1,549 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
+from __future__ import annotations + +import importlib.util +from typing import Any, Callable, Optional + +import torch +import torch.nn.functional as F +from torch import nn +from transformers.cache_utils import Cache +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5Config, Qwen3_5TextConfig +from transformers.models.qwen3_5 import modeling_qwen3_5 as hf_qwen35 +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs, can_return_tuple +from transformers.utils.generic import merge_with_config_defaults +from transformers.utils.output_capturing import capture_outputs + + +try: + from fla.modules import FusedRMSNormGated as _FLA_FUSED_RMS_NORM_GATED + from fla.modules.convolution import causal_conv1d as _FLA_CAUSAL_CONV1D_FN + from fla.modules.convolution import causal_conv1d_update as _FLA_CAUSAL_CONV1D_UPDATE + from fla.ops.gated_delta_rule import chunk_gated_delta_rule as _FLA_CHUNK_GATED_DELTA_RULE + from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule as _FLA_FUSED_RECURRENT_GATED_DELTA_RULE +except ImportError: + _FLA_FUSED_RMS_NORM_GATED = None + _FLA_CAUSAL_CONV1D_FN = None + _FLA_CAUSAL_CONV1D_UPDATE = None + _FLA_CHUNK_GATED_DELTA_RULE = None + _FLA_FUSED_RECURRENT_GATED_DELTA_RULE = None + +_HAS_CAUSAL_CONV1D = importlib.util.find_spec('causal_conv1d') is not None + + +def _ensure_text_config(config: Qwen3_5TextConfig) -> Qwen3_5TextConfig: + if isinstance(config, Qwen3_5TextConfig): + return config + raise TypeError( + 'TwinkleQwen3_5 text-only models require transformers.models.qwen3_5.Qwen3_5TextConfig. ' + f'Got {type(config).__name__}.' + ) + + +def _ensure_linear_attention_fast_path() -> None: + missing = [] + if _FLA_CAUSAL_CONV1D_FN is None: + missing.append('fla.modules.convolution.causal_conv1d') + if _FLA_CHUNK_GATED_DELTA_RULE is None or _FLA_FUSED_RECURRENT_GATED_DELTA_RULE is None: + missing.append('fla.ops.gated_delta_rule') + if not _HAS_CAUSAL_CONV1D: + missing.append('causal-conv1d') + if missing: + raise ImportError( + 'TwinkleQwen3_5 linear attention requires flash-linear-attention and causal-conv1d. 
' + f'Missing: {", ".join(missing)}' + ) + + +def _maybe_slice_tensor_output(output: Any) -> torch.Tensor: + if isinstance(output, tuple): + return output[0] + return output + + +def _sp_is_enabled(sequence_parallel_context: Any | None) -> bool: + return bool( + sequence_parallel_context is not None + and getattr(sequence_parallel_context, 'sp_world_size', 1) > 1 + and getattr(sequence_parallel_context, 'sp_group', None) is not None + ) + + +def _get_sp_rank(sequence_parallel_context: Any | None) -> int: + if not _sp_is_enabled(sequence_parallel_context): + return 0 + import torch.distributed as dist + + return dist.get_rank(sequence_parallel_context.sp_group) + + +def _seq_to_head_shard(tensor: torch.Tensor, sequence_parallel_context: Any | None) -> torch.Tensor: + if not _sp_is_enabled(sequence_parallel_context): + return tensor + from twinkle.model.transformers.strategy.sequence_parallel import _SeqAllToAll + + if tensor.dim() == 3: + return _SeqAllToAll.apply(sequence_parallel_context.sp_group, tensor.unsqueeze(-1), 2, 1).squeeze(-1) + return _SeqAllToAll.apply(sequence_parallel_context.sp_group, tensor, 2, 1) + + +def _head_to_seq_shard(tensor: torch.Tensor, sequence_parallel_context: Any | None) -> torch.Tensor: + if not _sp_is_enabled(sequence_parallel_context): + return tensor + from twinkle.model.transformers.strategy.sequence_parallel import _SeqAllToAll + + if tensor.dim() == 3: + return _SeqAllToAll.apply(sequence_parallel_context.sp_group, tensor.unsqueeze(-1), 1, 2).squeeze(-1) + return _SeqAllToAll.apply(sequence_parallel_context.sp_group, tensor, 1, 2) + + +class TwinkleQwen3_5GatedDeltaNet(hf_qwen35.Qwen3_5GatedDeltaNet): + + def __init__(self, config: Qwen3_5TextConfig, layer_idx: int): + _ensure_linear_attention_fast_path() + super().__init__(config, layer_idx) + self.causal_conv1d_fn = _FLA_CAUSAL_CONV1D_FN + self.causal_conv1d_update = _FLA_CAUSAL_CONV1D_UPDATE or hf_qwen35.causal_conv1d_update + self.chunk_gated_delta_rule = _FLA_CHUNK_GATED_DELTA_RULE + self.recurrent_gated_delta_rule = _FLA_FUSED_RECURRENT_GATED_DELTA_RULE + if _FLA_FUSED_RMS_NORM_GATED is not None and torch.cuda.is_available(): + self.norm = _FLA_FUSED_RMS_NORM_GATED( + self.head_v_dim, + eps=self.layer_norm_epsilon, + activation=self.activation, + device=torch.cuda.current_device(), + dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(), + ) + + def _get_local_conv1d_weight(self, sp_rank: int, local_key_dim: int, local_value_dim: int) -> torch.Tensor: + w_full = self.conv1d.weight.squeeze(1) + key_offset = sp_rank * local_key_dim + value_offset = sp_rank * local_value_dim + w_q = w_full[key_offset:key_offset + local_key_dim] + w_k = w_full[self.key_dim + key_offset:self.key_dim + key_offset + local_key_dim] + w_v = w_full[2 * self.key_dim + value_offset:2 * self.key_dim + value_offset + local_value_dim] + return torch.cat((w_q, w_k, w_v), dim=0) + + def _apply_varlen_conv( + self, + mixed_qkv: torch.Tensor, + conv_weight: torch.Tensor, + cu_seq_lens_q: torch.Tensor | None, + ) -> torch.Tensor: + if self.causal_conv1d_fn is None: + raise ImportError( + 'TwinkleQwen3_5 linear attention requires fla.modules.convolution.causal_conv1d for prefill/train.' 
+ ) + output = self.causal_conv1d_fn( + x=mixed_qkv, + weight=conv_weight, + bias=self.conv1d.bias, + activation=self.activation, + seq_idx=None, + backend='triton', + cu_seqlens=cu_seq_lens_q, + ) + return _maybe_slice_tensor_output(output) + + def _apply_decode_conv( + self, + mixed_qkv: torch.Tensor, + conv_state: torch.Tensor, + conv_weight: torch.Tensor, + ) -> torch.Tensor: + if self.causal_conv1d_update is None: + raise ImportError( + 'TwinkleQwen3_5 decode requires a causal_conv1d_update implementation from flash-linear-attention ' + 'or causal-conv1d.' + ) + mixed_qkv_t = mixed_qkv.transpose(1, 2).contiguous() + output = self.causal_conv1d_update( + mixed_qkv_t, + conv_state, + conv_weight, + self.conv1d.bias, + self.activation, + ) + output = _maybe_slice_tensor_output(output) + if output.dim() == 2: + output = output.unsqueeze(1) + elif output.dim() == 3 and output.shape[1] == conv_weight.shape[0]: + output = output.transpose(1, 2).contiguous() + return output + + def forward( + self, + hidden_states: torch.Tensor, + cache_params: hf_qwen35.Qwen3_5DynamicCache | None = None, + cache_position: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + cu_seq_lens_q: torch.Tensor | None = None, + sequence_parallel_context: Any | None = None, + ): + hidden_states = hf_qwen35.apply_mask_to_padding_states(hidden_states, attention_mask) + batch_size, seq_len, _ = hidden_states.shape + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_position is not None + ) + + if cache_params is not None: + conv_state = cache_params.conv_states[self.layer_idx] + recurrent_state = cache_params.recurrent_states[self.layer_idx] + else: + conv_state = None + recurrent_state = None + + mixed_qkv = self.in_proj_qkv(hidden_states) + z = self.in_proj_z(hidden_states).reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim) + b = self.in_proj_b(hidden_states) + a = self.in_proj_a(hidden_states) + + sp_enabled = _sp_is_enabled(sequence_parallel_context) + if sp_enabled: + sp_world_size = int(sequence_parallel_context.sp_world_size) + if self.num_k_heads % sp_world_size != 0 or self.num_v_heads % sp_world_size != 0: + raise RuntimeError( + 'TwinkleQwen3_5 linear attention requires sp_world_size to divide both ' + f'linear_num_key_heads ({self.num_k_heads}) and linear_num_value_heads ({self.num_v_heads}).' 
+ ) + local_num_k_heads = self.num_k_heads // sp_world_size + local_num_v_heads = self.num_v_heads // sp_world_size + local_key_dim = local_num_k_heads * self.head_k_dim + local_value_dim = local_num_v_heads * self.head_v_dim + + q_proj, k_proj, v_proj = torch.split(mixed_qkv, [self.key_dim, self.key_dim, self.value_dim], dim=-1) + q_proj = q_proj.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim) + k_proj = k_proj.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim) + v_proj = v_proj.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim) + q_proj = _seq_to_head_shard(q_proj, sequence_parallel_context) + k_proj = _seq_to_head_shard(k_proj, sequence_parallel_context) + v_proj = _seq_to_head_shard(v_proj, sequence_parallel_context) + b = _seq_to_head_shard(b.reshape(batch_size, seq_len, self.num_v_heads), sequence_parallel_context) + a = _seq_to_head_shard(a.reshape(batch_size, seq_len, self.num_v_heads), sequence_parallel_context) + + mixed_qkv = torch.cat( + ( + q_proj.reshape(batch_size, q_proj.shape[1], local_key_dim), + k_proj.reshape(batch_size, k_proj.shape[1], local_key_dim), + v_proj.reshape(batch_size, v_proj.shape[1], local_value_dim), + ), + dim=-1, + ) + conv_weight = self._get_local_conv1d_weight(_get_sp_rank(sequence_parallel_context), local_key_dim, local_value_dim) + else: + local_num_k_heads = self.num_k_heads + local_num_v_heads = self.num_v_heads + local_key_dim = self.key_dim + local_value_dim = self.value_dim + b = b.reshape(batch_size, seq_len, self.num_v_heads) + a = a.reshape(batch_size, seq_len, self.num_v_heads) + conv_weight = self.conv1d.weight.squeeze(1) + + if use_precomputed_states: + if conv_state is None: + raise RuntimeError('Qwen3.5 decode requires initialized convolution state.') + mixed_qkv = self._apply_decode_conv(mixed_qkv, conv_state, conv_weight) + else: + if cache_params is not None: + cache_params.conv_states[self.layer_idx] = F.pad( + mixed_qkv.transpose(1, 2).contiguous(), + (self.conv_kernel_size - mixed_qkv.shape[1], 0), + ) + mixed_qkv = self._apply_varlen_conv(mixed_qkv, conv_weight, cu_seq_lens_q) + + query, key, value = torch.split(mixed_qkv, [local_key_dim, local_key_dim, local_value_dim], dim=-1) + query = query.reshape(batch_size, query.shape[1], local_num_k_heads, self.head_k_dim) + key = key.reshape(batch_size, key.shape[1], local_num_k_heads, self.head_k_dim) + value = value.reshape(batch_size, value.shape[1], local_num_v_heads, self.head_v_dim) + + beta = b.sigmoid() + if sp_enabled: + head_offset = _get_sp_rank(sequence_parallel_context) * local_num_v_heads + head_slice = slice(head_offset, head_offset + local_num_v_heads) + g = -self.A_log[head_slice].float().exp() * F.softplus(a.float() + self.dt_bias[head_slice]) + else: + g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + + if self.num_v_heads // self.num_k_heads > 1: + repeat = self.num_v_heads // self.num_k_heads + query = query.repeat_interleave(repeat, dim=2) + key = key.repeat_interleave(repeat, dim=2) + + if use_precomputed_states: + core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + ) + else: + core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seq_lens_q, + ) + + if 
cache_params is not None: + cache_params.recurrent_states[self.layer_idx] = last_recurrent_state + + core_attn_out = _head_to_seq_shard(core_attn_out, sequence_parallel_context) + core_attn_out = self.norm(core_attn_out.reshape(-1, self.head_v_dim), z.reshape(-1, self.head_v_dim)) + core_attn_out = core_attn_out.reshape(batch_size, seq_len, self.value_dim) + return self.out_proj(core_attn_out) + + +class TwinkleQwen3_5DecoderLayer(hf_qwen35.Qwen3_5DecoderLayer): + + def __init__(self, config: Qwen3_5TextConfig, layer_idx: int): + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + self.layer_type = config.layer_types[layer_idx] + if self.layer_type == 'linear_attention': + self.linear_attn = TwinkleQwen3_5GatedDeltaNet(config, layer_idx) + elif self.layer_type == 'full_attention': + self.self_attn = hf_qwen35.Qwen3_5Attention(config, layer_idx) + else: + raise ValueError(f'Unsupported Qwen3.5 layer_type={self.layer_type!r}') + self.mlp = hf_qwen35.Qwen3_5MLP(config, config.intermediate_size) + self.input_layernorm = hf_qwen35.Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = hf_qwen35.Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + cache_position: torch.LongTensor | None = None, + cu_seq_lens_q: torch.Tensor | None = None, + sequence_parallel_context: Any | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.FloatTensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + if self.layer_type == 'linear_attention': + hidden_states = self.linear_attn( + hidden_states=hidden_states, + cache_params=past_key_values, + cache_position=cache_position, + attention_mask=attention_mask, + cu_seq_lens_q=cu_seq_lens_q, + sequence_parallel_context=sequence_parallel_context, + ) + else: + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + return residual + hidden_states + + +class TwinkleQwen3_5PreTrainedModel(hf_qwen35.Qwen3_5PreTrainedModel): + config_class = Qwen3_5TextConfig + _no_split_modules = ['TwinkleQwen3_5DecoderLayer'] + _can_record_outputs = { + 'hidden_states': TwinkleQwen3_5DecoderLayer, + 'attentions': hf_qwen35.Qwen3_5Attention, + } + + +class TwinkleQwen3_5TextModel(TwinkleQwen3_5PreTrainedModel): + + def __init__(self, config: Qwen3_5TextConfig): + config = _ensure_text_config(config) + super().__init__(config) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + self.layers = nn.ModuleList( + [TwinkleQwen3_5DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = hf_qwen35.Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = hf_qwen35.Qwen3_5TextRotaryEmbedding(config=config) + self.gradient_checkpointing = False + self._sequence_parallel_context = None + self.requires_cu_seq_lens_q = any(layer_type == 'linear_attention' for layer_type 
in config.layer_types) + self.post_init() + + def set_sequence_parallel_context(self, context: Any | None) -> None: + self._sequence_parallel_context = context + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def _update_linear_attn_mask(self, attention_mask, cache_position): + linear_attn_mask = attention_mask + if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)): + linear_attn_mask = None + return linear_attn_mask + + @merge_with_config_defaults + @capture_outputs + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + cache_position: torch.LongTensor | None = None, + cu_seq_lens_q: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError('You must specify exactly one of input_ids or inputs_embeds') + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = hf_qwen35.Qwen3_5DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + position_ids = position_ids[1:] + else: + text_position_ids = position_ids[0] + + causal_mask = hf_qwen35.create_causal_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=text_position_ids, + ) + linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position) + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + sp_context = self._sequence_parallel_context + if _sp_is_enabled(sp_context) and self.requires_cu_seq_lens_q and cu_seq_lens_q is None: + raise ValueError('TwinkleQwen3_5TextModel requires cu_seq_lens_q when sequence parallel is enabled.') + + for decoder_layer in self.layers[:self.config.num_hidden_layers]: + layer_mask = linear_attn_mask if decoder_layer.layer_type == 'linear_attention' else causal_mask + hidden_states = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=layer_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + cu_seq_lens_q=cu_seq_lens_q if decoder_layer.layer_type == 'linear_attention' else None, + sequence_parallel_context=sp_context if decoder_layer.layer_type == 'linear_attention' else None, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return hf_qwen35.Qwen3_5ModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +class TwinkleQwen3_5ForCausalLM(TwinkleQwen3_5PreTrainedModel, GenerationMixin): + 
_tied_weights_keys = {'lm_head.weight': 'model.embed_tokens.weight'} + _tp_plan = {'lm_head': 'colwise_gather_output'} + _pp_plan = {'lm_head': (['hidden_states'], ['logits'])} + _keys_to_ignore_on_load_unexpected = [r'^mtp.*', r'^model\.visual.*'] + + def __init__(self, config: Qwen3_5TextConfig): + config = _ensure_text_config(config) + super().__init__(config) + self.model = TwinkleQwen3_5TextModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def set_sequence_parallel_context(self, context: Any | None) -> None: + self.model.set_sequence_parallel_context(context) + + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + cache_position: torch.LongTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, + cu_seq_lens_q: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + cu_seq_lens_q=cu_seq_lens_q, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel.py b/src/twinkle/model/transformers/strategy/sequence_parallel.py index 64ea34f3..40bb843b 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel.py @@ -23,6 +23,16 @@ def get_cu_seqlens_from_position_ids(position_ids: torch.LongTensor): return cu_seqlens +@dataclass(frozen=True) +class SequenceParallelContext: + sp_group: Optional[dist.ProcessGroup] + sp_world_size: int + rank: int + world_size: int + real_position_ids: Optional[torch.Tensor] + is_packed: bool + + def _get_raw_data_world_size(device_mesh: DeviceMesh) -> int: dp_world_size = device_mesh.dp_world_size or 1 fsdp_world_size = device_mesh.fsdp_world_size or 1 @@ -344,12 +354,25 @@ def __init__(self): self.num_heads = None self.causal_mask_func = None self.extra_kwargs = {} + self.requires_cu_seq_lens_q = False + self._bound_llm_model = None @property def real_position_ids(self) -> torch.Tensor: """The real position ids, this is different from the position_ids in mrope""" return self.extra_kwargs.get('position_ids') + def _build_context(self) -> SequenceParallelContext: + rank = dist.get_rank(self._sp_group) if self._sp_group is not None and dist.is_initialized() else 0 + return 
SequenceParallelContext( + sp_group=self._sp_group, + sp_world_size=int(self.sp_world_size or 1), + rank=rank, + world_size=int(self.world_size or 1), + real_position_ids=self.real_position_ids, + is_packed=bool(self.extra_kwargs.get('is_packed', False)), + ) + def _prepare_flash_attn(self, base_model: torch.nn.Module): try: from transformers import masking_utils @@ -637,6 +660,10 @@ def prepare( SequenceParallel._global_inited = True self._prepare_forward_hook(llm_model) + self.requires_cu_seq_lens_q = bool(getattr(llm_model, 'requires_cu_seq_lens_q', False)) + self._bound_llm_model = llm_model + if hasattr(llm_model, 'set_sequence_parallel_context'): + llm_model.set_sequence_parallel_context(self._build_context()) if SequenceParallel._is_moe_model(getattr(model, 'config', None)): self._prepare_moe_aux_loss(llm_model) @@ -854,6 +881,13 @@ def prepare_inputs(self, inputs): self.extra_kwargs['is_packed'] = self._is_packed_position_ids(position_ids) if input_ids is not None: self.extra_kwargs['input_ids'] = input_ids.clone() + if self._bound_llm_model is not None and hasattr(self._bound_llm_model, 'set_sequence_parallel_context'): + self._bound_llm_model.set_sequence_parallel_context(self._build_context()) + if self.requires_cu_seq_lens_q and position_ids is not None: + padded_position_ids = self.pad(position_ids, padding_value=-1, position_ids=position_ids, dim=-1) + padded_position_ids = padded_position_ids.clone() + padded_position_ids[padded_position_ids < 0] = 0 + inputs['cu_seq_lens_q'] = get_cu_seqlens_from_position_ids(padded_position_ids).to(torch.int32) if 'labels' in inputs: labels = inputs['labels'] _, _, labels, _, _, _, _ = self.pad_and_split_inputs( diff --git a/tests/sequence_parallel/test_twinkle_qwen3_5_text_model.py b/tests/sequence_parallel/test_twinkle_qwen3_5_text_model.py new file mode 100644 index 00000000..ed67b4e3 --- /dev/null +++ b/tests/sequence_parallel/test_twinkle_qwen3_5_text_model.py @@ -0,0 +1,197 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
+import tempfile +import unittest +from types import SimpleNamespace +from unittest.mock import patch + +import torch +from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5Config, Qwen3_5TextConfig +from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM + +from twinkle.model.transformers.models.qwen3_5 import modeling_qwen3_5 as tw_qwen35 +from twinkle.model.transformers.strategy.sequence_parallel import SequenceParallel, SequenceParallelContext + + +def _build_text_config(layer_types=None) -> Qwen3_5TextConfig: + layer_types = layer_types or ['full_attention'] + return Qwen3_5TextConfig( + vocab_size=64, + hidden_size=16, + intermediate_size=32, + num_hidden_layers=len(layer_types), + num_attention_heads=4, + num_key_value_heads=2, + head_dim=4, + hidden_act='silu', + max_position_embeddings=128, + rms_norm_eps=1e-6, + linear_conv_kernel_dim=3, + linear_key_head_dim=4, + linear_value_head_dim=4, + linear_num_key_heads=2, + linear_num_value_heads=4, + layer_types=layer_types, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + ) + + +class _ContextReceiver: + + def __init__(self): + self.context = None + + def set_sequence_parallel_context(self, context): + self.context = context + + +class TestTwinkleQwen35TextModel(unittest.TestCase): + + def test_rejects_non_text_config(self): + with self.assertRaises(TypeError): + tw_qwen35.TwinkleQwen3_5ForCausalLM(Qwen3_5Config()) + + def test_text_model_accepts_sequence_parallel_context(self): + model = tw_qwen35.TwinkleQwen3_5TextModel(_build_text_config(['full_attention'])) + context = SequenceParallelContext( + sp_group=None, + sp_world_size=2, + rank=0, + world_size=2, + real_position_ids=torch.tensor([[0, 1, 2]], dtype=torch.long), + is_packed=False, + ) + model.set_sequence_parallel_context(context) + self.assertIs(model._sequence_parallel_context, context) + + def test_from_pretrained_loads_text_only_weights(self): + config = _build_text_config(['full_attention']) + hf_model = Qwen3_5ForCausalLM(config).eval() + with tempfile.TemporaryDirectory() as temp_dir: + hf_model.save_pretrained(temp_dir) + tw_model = tw_qwen35.TwinkleQwen3_5ForCausalLM.from_pretrained(temp_dir).eval() + + input_ids = torch.tensor([[1, 2, 3, 4]], dtype=torch.long) + attention_mask = torch.ones_like(input_ids) + position_ids = torch.arange(input_ids.shape[1], dtype=torch.long).unsqueeze(0) + with torch.no_grad(): + hf_outputs = hf_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, + return_dict=True, + ) + tw_outputs = tw_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, + return_dict=True, + ) + + torch.testing.assert_close(tw_outputs.logits, hf_outputs.logits, rtol=0, atol=0) + + def test_sequence_parallel_prepare_inputs_injects_cu_seq_lens(self): + sp = SequenceParallel() + sp.world_size = 2 + sp.sp_world_size = 2 + sp.requires_cu_seq_lens_q = True + receiver = _ContextReceiver() + sp._bound_llm_model = receiver + inputs = { + 'input_ids': torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long), + 'position_ids': torch.tensor([[0, 1, 2, 3, 4]], dtype=torch.long), + } + + outputs = sp.prepare_inputs(inputs) + + self.assertIn('cu_seq_lens_q', outputs) + self.assertTrue(torch.equal(outputs['cu_seq_lens_q'], torch.tensor([0, 5, 6], dtype=torch.int32))) + self.assertIsNotNone(receiver.context) + self.assertFalse(receiver.context.is_packed) + self.assertTrue(torch.equal(receiver.context.real_position_ids, 
inputs['position_ids'])) + + def test_linear_attention_requires_fast_path_dependencies(self): + with self.assertRaises(ImportError): + tw_qwen35.TwinkleQwen3_5TextModel(_build_text_config(['linear_attention'])) + + def test_linear_attention_sp_uses_cu_seq_lens_and_keeps_z_local(self): + captured = { + 'cu_seqlens': None, + 'seq_to_head_calls': 0, + 'head_to_seq_calls': 0, + 'norm_z_shape': None, + } + + def fake_conv(x, weight, bias, activation, seq_idx=None, backend=None, cu_seqlens=None): + del weight, bias, activation, seq_idx, backend + captured['cu_seqlens'] = cu_seqlens.clone() if cu_seqlens is not None else None + return x + + def fake_chunk_rule(query, key, value, g, beta, initial_state=None, output_final_state=False, + use_qk_l2norm_in_kernel=False, cu_seqlens=None): + del query, key, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel + captured['cu_seqlens'] = cu_seqlens.clone() if cu_seqlens is not None else None + return value, None + + def fake_recurrent_rule(query, key, value, g, beta, initial_state=None, output_final_state=False, + use_qk_l2norm_in_kernel=False): + del query, key, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel + return value, None + + def fake_seq_to_head(tensor, context): + del context + captured['seq_to_head_calls'] += 1 + return tensor + + def fake_head_to_seq(tensor, context): + del context + captured['head_to_seq_calls'] += 1 + return tensor + + class DummyNorm(torch.nn.Module): + + def forward(self, x, z): + captured['norm_z_shape'] = tuple(z.shape) + return x + z + + with patch.object(tw_qwen35, '_FLA_CAUSAL_CONV1D_FN', fake_conv), \ + patch.object(tw_qwen35, '_FLA_CAUSAL_CONV1D_UPDATE', lambda *args, **kwargs: args[0]), \ + patch.object(tw_qwen35, '_FLA_CHUNK_GATED_DELTA_RULE', fake_chunk_rule), \ + patch.object(tw_qwen35, '_FLA_FUSED_RECURRENT_GATED_DELTA_RULE', fake_recurrent_rule), \ + patch.object(tw_qwen35, '_HAS_CAUSAL_CONV1D', True), \ + patch.object(tw_qwen35, '_seq_to_head_shard', side_effect=fake_seq_to_head), \ + patch.object(tw_qwen35, '_head_to_seq_shard', side_effect=fake_head_to_seq): + config = _build_text_config(['linear_attention']) + module = tw_qwen35.TwinkleQwen3_5GatedDeltaNet(config, layer_idx=0) + module.norm = DummyNorm() + hidden_states = torch.randn(1, 2, config.hidden_size) + attention_mask = torch.ones(1, 2, dtype=torch.int64) + cu_seq_lens_q = torch.tensor([0, 2], dtype=torch.int32) + context = SequenceParallelContext( + sp_group='dummy_group', + sp_world_size=2, + rank=0, + world_size=2, + real_position_ids=torch.tensor([[0, 1]], dtype=torch.long), + is_packed=False, + ) + + output = module( + hidden_states=hidden_states, + attention_mask=attention_mask, + cu_seq_lens_q=cu_seq_lens_q, + sequence_parallel_context=context, + ) + + self.assertEqual(captured['seq_to_head_calls'], 5) + self.assertEqual(captured['head_to_seq_calls'], 1) + self.assertTrue(torch.equal(captured['cu_seqlens'], cu_seq_lens_q)) + self.assertEqual(captured['norm_z_shape'], (hidden_states.shape[0] * hidden_states.shape[1] * config.linear_num_value_heads, config.linear_value_head_dim)) + self.assertEqual(tuple(output.shape), (1, 2, config.hidden_size)) + + +if __name__ == '__main__': + unittest.main() From b331436af8aac2f9d3be471f77d9feb8dd4e0126 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Mon, 23 Mar 2026 09:12:03 +0800 Subject: [PATCH 2/6] fix qwen3.5 sp by using ulysses --- .../transformers/qwen3_5_sp_memory_bench.py | 430 ++++++++++++++++++ cookbook/transformers/sp_fsdp_dense.py | 264 
++++++++++- src/twinkle/dataloader/dataloader.py | 23 +- src/twinkle/model/transformers/__init__.py | 38 +- .../model/transformers/models/__init__.py | 15 + .../models/qwen3_5/modeling_qwen3_5.py | 152 ++++++- .../transformers/multi_lora_transformers.py | 6 +- .../strategy/sequence_parallel.py | 95 ++-- .../model/transformers/transformers.py | 12 +- src/twinkle/utils/device_mesh.py | 16 +- tests/dataloader/test_dataloader.py | 15 + .../test_twinkle_qwen3_5_text_model.py | 273 ++++++++++- .../test_twinkle_qwen3_5_text_model_parity.py | 369 +++++++++++++++ .../test_ulysses_auto_batch_policy.py | 18 + 14 files changed, 1657 insertions(+), 69 deletions(-) create mode 100644 cookbook/transformers/qwen3_5_sp_memory_bench.py create mode 100644 tests/sequence_parallel/test_twinkle_qwen3_5_text_model_parity.py create mode 100644 tests/transformers/test_ulysses_auto_batch_policy.py diff --git a/cookbook/transformers/qwen3_5_sp_memory_bench.py b/cookbook/transformers/qwen3_5_sp_memory_bench.py new file mode 100644 index 00000000..2a47b2f4 --- /dev/null +++ b/cookbook/transformers/qwen3_5_sp_memory_bench.py @@ -0,0 +1,430 @@ +import json +import os +import socket +import tempfile +import traceback +from datetime import timedelta +from types import SimpleNamespace + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig +from transformers.utils.import_utils import is_flash_attn_2_available + +from twinkle.model.transformers.models.qwen3_5 import modeling_qwen3_5 as tw_qwen35 +from twinkle.model.transformers.strategy.sequence_parallel import SequenceParallel, SequenceParallelContext +from twinkle.utils import DeviceMesh + +# Examples: +# CUDA_VISIBLE_DEVICES=0,1 PYTHONPATH=src python cookbook/transformers/qwen3_5_sp_memory_bench.py +# CUDA_VISIBLE_DEVICES=0,1 QWEN35_SP_MEMORY_MODE=linear PYTHONPATH=src \ +# python cookbook/transformers/qwen3_5_sp_memory_bench.py + + +def _build_linear_bench_config() -> Qwen3_5TextConfig: + hidden_size = int(os.environ.get('QWEN35_SP_MEMORY_HIDDEN_SIZE', '1024')) + head_dim = int(os.environ.get('QWEN35_SP_MEMORY_HEAD_DIM', '64')) + num_attention_heads = hidden_size // head_dim + return Qwen3_5TextConfig( + vocab_size=64, + hidden_size=hidden_size, + intermediate_size=hidden_size * 2, + num_hidden_layers=1, + num_attention_heads=num_attention_heads, + num_key_value_heads=max(1, num_attention_heads // 2), + head_dim=head_dim, + hidden_act='silu', + max_position_embeddings=16384, + rms_norm_eps=1e-6, + attention_dropout=0.0, + linear_conv_kernel_dim=3, + linear_key_head_dim=head_dim, + linear_value_head_dim=head_dim, + linear_num_key_heads=max(2, num_attention_heads // 2), + linear_num_value_heads=num_attention_heads, + layer_types=['linear_attention'], + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + ) + + +def _build_mixed_bench_config() -> Qwen3_5TextConfig: + hidden_size = int(os.environ.get('QWEN35_SP_MEMORY_HIDDEN_SIZE', '1024')) + head_dim = int(os.environ.get('QWEN35_SP_MEMORY_HEAD_DIM', '64')) + num_attention_heads = hidden_size // head_dim + config = Qwen3_5TextConfig( + vocab_size=64, + hidden_size=hidden_size, + intermediate_size=hidden_size * 2, + num_hidden_layers=2, + num_attention_heads=num_attention_heads, + num_key_value_heads=max(1, num_attention_heads // 2), + head_dim=head_dim, + hidden_act='silu', + max_position_embeddings=16384, + rms_norm_eps=1e-6, + attention_dropout=0.0, + linear_conv_kernel_dim=3, + linear_key_head_dim=head_dim, + 
linear_value_head_dim=head_dim, + linear_num_key_heads=max(2, num_attention_heads // 2), + linear_num_value_heads=num_attention_heads, + layer_types=['full_attention', 'linear_attention'], + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + ) + attn_implementation = os.environ.get( + 'QWEN35_SP_MEMORY_ATTN_IMPLEMENTATION', + 'flash_attention_2' if is_flash_attn_2_available() else 'sdpa', + ) + config._attn_implementation = attn_implementation + config._attn_implementation_internal = attn_implementation + return config + + +def _find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(('127.0.0.1', 0)) + return sock.getsockname()[1] + + +def _parse_cases(): + spec = os.environ.get('QWEN35_SP_MEMORY_CASES', '1x1024,1x2048,2x2048') + cases = [] + for item in spec.split(','): + item = item.strip() + if not item: + continue + batch_size, seq_len = item.lower().split('x', 1) + cases.append((int(batch_size), int(seq_len))) + return cases + + +def _measure_cuda_peak_stats(run_step): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + run_step() + torch.cuda.synchronize() + return { + 'peak_allocated_mib': torch.cuda.max_memory_allocated() / (1024 ** 2), + 'peak_reserved_mib': torch.cuda.max_memory_reserved() / (1024 ** 2), + } + + +def _run_linear_attention_memory_step( + module: torch.nn.Module, + hidden_states: torch.Tensor, + *, + attention_mask: torch.Tensor, + cu_seq_lens_q: torch.Tensor, + sequence_parallel_context: SequenceParallelContext | None = None, +) -> dict[str, float]: + + def _step(): + module.zero_grad(set_to_none=True) + local_hidden_states = hidden_states.detach().clone().requires_grad_(True) + output = module( + hidden_states=local_hidden_states, + attention_mask=attention_mask, + cu_seq_lens_q=cu_seq_lens_q, + sequence_parallel_context=sequence_parallel_context, + ) + loss = output.float().square().mean() + loss.backward() + + return _measure_cuda_peak_stats(_step) + + +def _run_text_model_memory_step( + model: torch.nn.Module, + model_inputs: dict[str, torch.Tensor | bool], +) -> dict[str, float]: + + def _step(): + model.zero_grad(set_to_none=True) + local_inputs = {} + for key, value in model_inputs.items(): + if torch.is_tensor(value): + local_inputs[key] = value.detach().clone() + else: + local_inputs[key] = value + outputs = model(**local_inputs) + loss = outputs.last_hidden_state.float().square().mean() + loss.backward() + + return _measure_cuda_peak_stats(_step) + + +def _write_error(error_prefix: str, rank: int) -> None: + with open(f'{error_prefix}.rank{rank}.err', 'w', encoding='utf-8') as f: + f.write(traceback.format_exc()) + + +def _run_linear_attention_memory_worker(rank: int, world_size: int, port: int, result_path: str, cases): + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = str(port) + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['LOCAL_WORLD_SIZE'] = str(world_size) + torch.cuda.set_device(rank) + error_prefix = result_path + try: + dist.init_process_group( + backend='nccl', + rank=rank, + world_size=world_size, + timeout=timedelta(minutes=10), + ) + + device = torch.device(f'cuda:{rank}') + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + config = _build_linear_bench_config() + results = [] + + for batch_size, seq_len in cases: + if seq_len % world_size != 0: + raise ValueError(f'seq_len ({seq_len}) must be divisible by world_size ({world_size})') + + 
full_attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64, device=device) + full_position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1) + cu_seq_lens_q = torch.arange( + 0, + (batch_size + 1) * seq_len, + step=seq_len, + dtype=torch.int32, + device=device, + ) + + baseline_module = tw_qwen35.TwinkleQwen3_5GatedDeltaNet(config, layer_idx=0).to(device=device, dtype=dtype) + baseline_module.train() + baseline_hidden_states = torch.randn(batch_size, seq_len, config.hidden_size, device=device, dtype=dtype) + baseline_stats = _run_linear_attention_memory_step( + baseline_module, + baseline_hidden_states, + attention_mask=full_attention_mask, + cu_seq_lens_q=cu_seq_lens_q, + sequence_parallel_context=None, + ) + del baseline_module, baseline_hidden_states + torch.cuda.empty_cache() + + local_seq_len = seq_len // world_size + start = rank * local_seq_len + end = start + local_seq_len + sp_attention_mask = full_attention_mask[:, start:end].contiguous() + sp_hidden_states = torch.randn(batch_size, local_seq_len, config.hidden_size, device=device, dtype=dtype) + sp_context = SequenceParallelContext( + sp_group=dist.group.WORLD, + sp_world_size=world_size, + rank=rank, + world_size=world_size, + real_position_ids=full_position_ids, + is_packed=False, + ) + sp_module = tw_qwen35.TwinkleQwen3_5GatedDeltaNet(config, layer_idx=0).to(device=device, dtype=dtype) + sp_module.train() + sp_stats = _run_linear_attention_memory_step( + sp_module, + sp_hidden_states, + attention_mask=sp_attention_mask, + cu_seq_lens_q=cu_seq_lens_q, + sequence_parallel_context=sp_context, + ) + del sp_module, sp_hidden_states + torch.cuda.empty_cache() + + payload = torch.tensor([ + baseline_stats['peak_allocated_mib'], + baseline_stats['peak_reserved_mib'], + sp_stats['peak_allocated_mib'], + sp_stats['peak_reserved_mib'], + ], device=device) + gathered = [torch.zeros_like(payload) for _ in range(world_size)] + dist.all_gather(gathered, payload) + + if rank == 0: + gathered_cpu = [tensor.cpu().tolist() for tensor in gathered] + results.append({ + 'batch_size': batch_size, + 'seq_len': seq_len, + 'baseline_peak_allocated_mib_per_rank': [row[0] for row in gathered_cpu], + 'baseline_peak_reserved_mib_per_rank': [row[1] for row in gathered_cpu], + 'sp_peak_allocated_mib_per_rank': [row[2] for row in gathered_cpu], + 'sp_peak_reserved_mib_per_rank': [row[3] for row in gathered_cpu], + 'baseline_peak_allocated_mib_max': max(row[0] for row in gathered_cpu), + 'sp_peak_allocated_mib_max': max(row[2] for row in gathered_cpu), + }) + + if rank == 0: + torch.save(results, result_path) + except Exception: + _write_error(error_prefix, rank) + raise + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + +def _run_mixed_text_model_memory_worker(rank: int, world_size: int, port: int, result_path: str, cases): + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = str(port) + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['LOCAL_WORLD_SIZE'] = str(world_size) + torch.cuda.set_device(rank) + error_prefix = result_path + try: + dist.init_process_group( + backend='nccl', + rank=rank, + world_size=world_size, + timeout=timedelta(minutes=10), + ) + + device = torch.device(f'cuda:{rank}') + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + baseline_results = [] + sp_results = [] + + for batch_size, seq_len in cases: + config = 
_build_mixed_bench_config() + full_input_ids = torch.randint(1, config.vocab_size, (batch_size, seq_len), device=device, dtype=torch.long) + full_position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1) + + baseline_model = tw_qwen35.TwinkleQwen3_5TextModel(config).to(device=device, dtype=dtype) + baseline_model.train() + baseline_stats = _run_text_model_memory_step( + baseline_model, + { + 'input_ids': full_input_ids, + 'position_ids': full_position_ids, + 'use_cache': False, + }, + ) + baseline_results.append(baseline_stats) + del baseline_model + torch.cuda.empty_cache() + + device_mesh = DeviceMesh.from_sizes( + world_size=world_size, + dp_size=world_size, + ulysses_size=world_size, + device_type='cuda', + ) + tokenizer = SimpleNamespace(pad_token_id=0) + sp = SequenceParallel() + + for batch_size, seq_len in cases: + if seq_len % world_size != 0: + raise ValueError(f'seq_len ({seq_len}) must be divisible by world_size ({world_size})') + + config = _build_mixed_bench_config() + full_input_ids = torch.randint(1, config.vocab_size, (batch_size, seq_len), device=device, dtype=torch.long) + full_position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1) + + sp_model = tw_qwen35.TwinkleQwen3_5TextModel(config).to(device=device, dtype=dtype) + sp_model.train() + sp.prepare(world_size, sp_model, tokenizer, device_mesh=device_mesh) + sp_inputs = sp.prepare_inputs({ + 'input_ids': full_input_ids, + 'position_ids': full_position_ids, + 'use_cache': False, + }) + sp_stats = _run_text_model_memory_step(sp_model, sp_inputs) + sp_results.append(sp_stats) + del sp_model + torch.cuda.empty_cache() + + gathered_results = [] + for (batch_size, seq_len), baseline_stats, sp_stats in zip(cases, baseline_results, sp_results, strict=False): + payload = torch.tensor([ + baseline_stats['peak_allocated_mib'], + baseline_stats['peak_reserved_mib'], + sp_stats['peak_allocated_mib'], + sp_stats['peak_reserved_mib'], + ], device=device) + gathered = [torch.zeros_like(payload) for _ in range(world_size)] + dist.all_gather(gathered, payload) + + if rank == 0: + gathered_cpu = [tensor.cpu().tolist() for tensor in gathered] + gathered_results.append({ + 'batch_size': batch_size, + 'seq_len': seq_len, + 'attn_implementation': getattr(config, '_attn_implementation', None), + 'baseline_peak_allocated_mib_per_rank': [row[0] for row in gathered_cpu], + 'baseline_peak_reserved_mib_per_rank': [row[1] for row in gathered_cpu], + 'sp_peak_allocated_mib_per_rank': [row[2] for row in gathered_cpu], + 'sp_peak_reserved_mib_per_rank': [row[3] for row in gathered_cpu], + 'baseline_peak_allocated_mib_max': max(row[0] for row in gathered_cpu), + 'sp_peak_allocated_mib_max': max(row[2] for row in gathered_cpu), + }) + + if rank == 0: + torch.save(gathered_results, result_path) + except Exception: + _write_error(error_prefix, rank) + raise + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + +def _run_spawned(worker, world_size: int, cases): + port = _find_free_port() + with tempfile.TemporaryDirectory() as temp_dir: + result_path = os.path.join(temp_dir, 'memory_results.pt') + try: + mp.spawn( + worker, + args=(world_size, port, result_path, cases), + nprocs=world_size, + join=True, + ) + except Exception: + error_logs = [] + for rank in range(world_size): + error_path = f'{result_path}.rank{rank}.err' + if os.path.exists(error_path): + with open(error_path, 'r', encoding='utf-8') as f: + error_logs.append(f'Rank 
{rank}:\n{f.read()}') + if error_logs: + raise RuntimeError('\n\n'.join(error_logs)) + raise + return torch.load(result_path, weights_only=False) + + +def main(): + if not torch.cuda.is_available(): + raise SystemExit('CUDA is required for the Qwen3.5 SP memory benchmark.') + + world_size = int(os.environ.get('QWEN35_SP_MEMORY_WORLD_SIZE', '2')) + if torch.cuda.device_count() < world_size: + raise SystemExit(f'Need at least {world_size} CUDA devices for the Qwen3.5 SP memory benchmark.') + + cases = _parse_cases() + mode = os.environ.get('QWEN35_SP_MEMORY_MODE', 'both') + results = {} + + if mode in ('linear', 'both'): + results['linear_attention'] = _run_spawned(_run_linear_attention_memory_worker, world_size, cases) + + if mode in ('mixed', 'both'): + results['mixed_text_model'] = _run_spawned(_run_mixed_text_model_memory_worker, world_size, cases) + + output = json.dumps(results, indent=2) + print(output) + + result_path = os.environ.get('QWEN35_SP_MEMORY_RESULT_PATH') + if result_path: + with open(result_path, 'w', encoding='utf-8') as f: + f.write(output) + + +if __name__ == '__main__': + main() diff --git a/cookbook/transformers/sp_fsdp_dense.py b/cookbook/transformers/sp_fsdp_dense.py index 66ad0efc..8dede7bd 100644 --- a/cookbook/transformers/sp_fsdp_dense.py +++ b/cookbook/transformers/sp_fsdp_dense.py @@ -1,5 +1,10 @@ -import numpy as np +import json +import math +import os from functools import partial + +import numpy as np +import torch from peft import LoraConfig import twinkle @@ -7,12 +12,22 @@ from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.model import TransformersModel -from twinkle.model.transformers.models.qwen3_5 import TwinkleQwen3_5ForCausalLM +from twinkle.model.transformers.models import TwinkleQwen3_5ForCausalLM from twinkle.preprocessor import SelfCognitionProcessor +# TWINKLE_PROFILE_MODULE_MEMORY=1 \ +# TWINKLE_PROFILE_MODULE_MEMORY_STEP=0 \ +# TWINKLE_PROFILE_MODULE_MEMORY_RANK=0 \ +# python cookbook/transformers/sp_fsdp_dense.py logger = get_logger() MODEL_ID = 'ms://Qwen/Qwen3.5-4B' DATASETS = 'ms://swift/self-cognition' +TRAIN_SEED = int(os.environ.get('TRAIN_SEED', '1234')) +TRAIN_DETERMINISTIC = os.environ.get('TRAIN_DETERMINISTIC', '1') == '1' +TRAIN_NUM_WORKERS = int(os.environ.get('TRAIN_NUM_WORKERS', '0')) +TRAIN_SHUFFLE = os.environ.get('TRAIN_SHUFFLE', '0') == '1' +TRAIN_ATTENTION_DROPOUT = float(os.environ.get('TRAIN_ATTENTION_DROPOUT', '0.0')) +TRAIN_LORA_DROPOUT = float(os.environ.get('TRAIN_LORA_DROPOUT', '0.0')) device_group = [DeviceGroup( name='default', @@ -34,6 +49,205 @@ global_device_mesh=device_mesh, lazy_collect=False, ) +twinkle.framework_util.seed_everything(TRAIN_SEED, TRAIN_DETERMINISTIC) + + +def _memory_api(): + device_type = Platform.get_platform().device_prefix() + device_api = getattr(torch, device_type, None) + if device_api is None or not hasattr(device_api, 'is_available') or not device_api.is_available(): + return None, None + return device_type, device_api + + +def _format_mib(num_bytes): + return f'{num_bytes / (1024 ** 2):.1f} MiB' + + +def _get_memory_stats(): + device_type, device_api = _memory_api() + if device_api is None: + return {} + + if hasattr(device_api, 'synchronize'): + device_api.synchronize() + + current_device = device_api.current_device() if hasattr(device_api, 'current_device') else 0 + return { + 'rank': Platform.get_rank(), + 'local_rank': Platform.get_local_rank(), + 'device': f'{device_type}:{current_device}', + 'mem_allocated': 
_format_mib(device_api.memory_allocated()), + 'mem_reserved': _format_mib(device_api.memory_reserved()), + 'mem_peak_allocated': _format_mib(device_api.max_memory_allocated()), + 'mem_peak_reserved': _format_mib(device_api.max_memory_reserved()), + } + + +def _reset_peak_memory_stats(): + _, device_api = _memory_api() + if device_api is not None and hasattr(device_api, 'reset_peak_memory_stats'): + device_api.reset_peak_memory_stats() + + +def _memory_mib_value(num_bytes): + return round(float(num_bytes) / (1024 ** 2), 3) + + +def _shape_of(value): + if torch.is_tensor(value): + return tuple(value.shape) + if isinstance(value, (list, tuple)): + for item in value: + shape = _shape_of(item) + if shape is not None: + return shape + return None + + +class _ModuleMemoryProfiler: + + TARGET_CLASS_NAMES = { + 'TwinkleQwen3_5DecoderLayer', + 'TwinkleQwen3_5GatedDeltaNet', + 'Qwen3_5Attention', + 'Qwen3_5MLP', + } + + def __init__(self, model: TransformersModel): + self.model = model + self.enabled = os.environ.get('TWINKLE_PROFILE_MODULE_MEMORY') == '1' + self.target_step = int(os.environ.get('TWINKLE_PROFILE_MODULE_MEMORY_STEP', '0')) + self.target_rank = int(os.environ.get('TWINKLE_PROFILE_MODULE_MEMORY_RANK', '0')) + self.max_records = int(os.environ.get('TWINKLE_PROFILE_MODULE_MEMORY_LIMIT', '16')) + self.active = False + self.handles = [] + self.entries = {} + self.records = [] + + def attach(self): + if not self.enabled or Platform.get_rank() != self.target_rank: + return + base_model = getattr(self.model, 'model', None) + if base_model is None: + return + for name, module in base_model.named_modules(): + if module.__class__.__name__ not in self.TARGET_CLASS_NAMES: + continue + self.handles.append(module.register_forward_pre_hook(self._make_pre_hook(name), with_kwargs=True)) + self.handles.append(module.register_forward_hook(self._make_post_hook(name), with_kwargs=True)) + + def _make_pre_hook(self, name): + + def _hook(module, args, kwargs): + if not self.active: + return + _, device_api = _memory_api() + if device_api is None: + return + if hasattr(device_api, 'synchronize'): + device_api.synchronize() + self.entries[id(module)] = { + 'name': name, + 'class_name': module.__class__.__name__, + 'input_shape': _shape_of(args) or _shape_of(kwargs), + 'pre_allocated_mib': _memory_mib_value(device_api.memory_allocated()), + 'pre_reserved_mib': _memory_mib_value(device_api.memory_reserved()), + 'pre_peak_allocated_mib': _memory_mib_value(device_api.max_memory_allocated()), + 'pre_peak_reserved_mib': _memory_mib_value(device_api.max_memory_reserved()), + } + + return _hook + + def _make_post_hook(self, name): + + def _hook(module, args, kwargs, output): + del args, kwargs + if not self.active: + return + _, device_api = _memory_api() + if device_api is None: + return + if hasattr(device_api, 'synchronize'): + device_api.synchronize() + entry = self.entries.pop(id(module), None) + if entry is None: + return + post_allocated_mib = _memory_mib_value(device_api.memory_allocated()) + post_reserved_mib = _memory_mib_value(device_api.memory_reserved()) + post_peak_allocated_mib = _memory_mib_value(device_api.max_memory_allocated()) + post_peak_reserved_mib = _memory_mib_value(device_api.max_memory_reserved()) + entry.update({ + 'output_shape': _shape_of(output), + 'post_allocated_mib': post_allocated_mib, + 'post_reserved_mib': post_reserved_mib, + 'post_peak_allocated_mib': post_peak_allocated_mib, + 'post_peak_reserved_mib': post_peak_reserved_mib, + 'delta_allocated_mib': round(post_allocated_mib - 
entry['pre_allocated_mib'], 3), + 'delta_reserved_mib': round(post_reserved_mib - entry['pre_reserved_mib'], 3), + 'delta_peak_allocated_mib': round(post_peak_allocated_mib - entry['pre_peak_allocated_mib'], 3), + 'delta_peak_reserved_mib': round(post_peak_reserved_mib - entry['pre_peak_reserved_mib'], 3), + }) + self.records.append(entry) + + return _hook + + def start_step(self, step: int): + self.active = self.enabled and Platform.get_rank() == self.target_rank and step == self.target_step + self.entries.clear() + self.records.clear() + if self.active: + _reset_peak_memory_stats() + + def finish_step(self, step: int): + if not self.active: + return + step_memory = _get_memory_stats() + sorted_records = sorted(self.records, key=lambda item: item['delta_peak_allocated_mib'], reverse=True) + logger.info( + 'Module memory profile summary: ' + + json.dumps( + { + 'step': step, + 'rank': Platform.get_rank(), + 'total_step_peak_allocated_mib': step_memory.get('mem_peak_allocated'), + 'total_step_peak_reserved_mib': step_memory.get('mem_peak_reserved'), + 'top_forward_modules_by_peak_allocated': sorted_records[:self.max_records], + }, + ensure_ascii=False, + )) + self.active = False + + def close(self): + for handle in self.handles: + handle.remove() + self.handles.clear() + + +def _get_runtime_backend_info(model: TransformersModel): + model._ensure_sp_strategy() + + underlying_model = getattr(model, 'model', None) + llm_model = getattr(underlying_model, 'model', underlying_model) + config = getattr(underlying_model, 'config', None) + + attn_implementation = None + attn_implementation_internal = None + if config is not None: + attn_implementation = getattr(config, '_attn_implementation', None) + attn_implementation_internal = getattr(config, '_attn_implementation_internal', None) + + return { + 'model_cls': type(underlying_model).__name__ if underlying_model is not None else None, + 'llm_model_cls': type(llm_model).__name__ if llm_model is not None else None, + 'attn_implementation': attn_implementation, + 'attn_implementation_internal': attn_implementation_internal, + 'requires_cu_seq_lens_q': bool(getattr(llm_model, 'requires_cu_seq_lens_q', False)), + 'sp_enabled': bool(getattr(model, '_enable_sp', False)), + 'ulysses_size': getattr(getattr(model, 'device_mesh', None), 'ulysses_size', None), + 'sp_strategy_enabled': bool(getattr(getattr(model, 'sp_strategy', None), 'enabled', False)), + 'sp_strategy_ulysses_size': getattr(getattr(model, 'sp_strategy', None), 'ulysses_size', None), + } def eval(model): @@ -41,6 +255,8 @@ def eval(model): dataset=partial(create_dataset, data_slice=range(100)), batch_size=4, device_mesh=device_mesh, + num_workers=TRAIN_NUM_WORKERS, + shuffle=False, ) for _, batch in enumerate(dataloader): model.forward_only(inputs=batch, adapter_name='default') @@ -61,6 +277,8 @@ def train(): dataset=partial(create_dataset, data_slice=None), batch_size=8, device_mesh=device_mesh, + num_workers=TRAIN_NUM_WORKERS, + shuffle=TRAIN_SHUFFLE, ) model = TransformersModel( @@ -68,28 +286,58 @@ def train(): model_cls=TwinkleQwen3_5ForCausalLM, device_mesh=device_mesh, strategy='native_fsdp', + attention_dropout=TRAIN_ATTENTION_DROPOUT, ) - lora_config = LoraConfig(target_modules='all-linear') - model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=1) + lora_config = LoraConfig(target_modules='all-linear', lora_dropout=TRAIN_LORA_DROPOUT) + model.add_adapter_to_model('default', lora_config) + grad_accumulation_steps = 
model.optimizer_group['default'].gradient_accumulation_steps + num_optimizer_steps = math.ceil(len(dataloader) / grad_accumulation_steps) + log_every_optimizer_steps = 20 model.set_optimizer('AdamW', lr=1e-4, adapter_name='default') model.set_lr_scheduler( scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, - num_training_steps=len(dataloader), + num_training_steps=num_optimizer_steps, adapter_name='default', ) logger.info(model.get_train_configs(adapter_name='default')) - logger.info(f'Total steps: {len(dataloader)}') + logger.info( + f'Total micro steps: {len(dataloader)}, optimizer steps: {num_optimizer_steps}, ' + f'gradient_accumulation_steps: {grad_accumulation_steps}') + logger.info( + 'Reproducibility config: ' + + str({ + 'seed': TRAIN_SEED, + 'deterministic': TRAIN_DETERMINISTIC, + 'dataloader_shuffle': TRAIN_SHUFFLE, + 'dataloader_num_workers': TRAIN_NUM_WORKERS, + 'attention_dropout': TRAIN_ATTENTION_DROPOUT, + 'lora_dropout': TRAIN_LORA_DROPOUT, + })) + logger.info(f'Backend info: {_get_runtime_backend_info(model)}') + logger.info(f'Initial memory: {_get_memory_stats()}') + _reset_peak_memory_stats() + module_memory_profiler = _ModuleMemoryProfiler(model) + module_memory_profiler.attach() for step, batch in enumerate(dataloader): + module_memory_profiler.start_step(step) model.forward_backward(inputs=batch, adapter_name='default') + module_memory_profiler.finish_step(step) model.clip_grad_and_step(adapter_name='default') - if step % 20 == 0: + optimizer_step = step // grad_accumulation_steps + is_optimizer_boundary = (step + 1) % grad_accumulation_steps == 0 + if is_optimizer_boundary and optimizer_step % log_every_optimizer_steps == 0: metric = model.calculate_metric(is_training=True, adapter_name='default') - logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') + metric.update(_get_memory_stats()) + optimizer_step = metric.get('iters', optimizer_step) + logger.info( + f'Current is optimizer step {optimizer_step} of {num_optimizer_steps} ' + f'(micro step {step} of {len(dataloader)}), metric: {metric}') model.save('last-checkpoint', interval=1) + module_memory_profiler.close() if __name__ == '__main__': diff --git a/src/twinkle/dataloader/dataloader.py b/src/twinkle/dataloader/dataloader.py index b3ce4f0f..0a591098 100644 --- a/src/twinkle/dataloader/dataloader.py +++ b/src/twinkle/dataloader/dataloader.py @@ -45,14 +45,31 @@ def __init__(self, self.max_retries = kwargs.pop('max_retries', 20) self.min_batch_size = min_batch_size if device_mesh is not None: - assert batch_size >= device_mesh.data_world_size and batch_size % device_mesh.data_world_size == 0 - self.batch_size = batch_size + required_world_size = self._required_data_world_size(device_mesh) + assert batch_size >= required_world_size and batch_size % required_world_size == 0 + self.batch_size = self._resolve_runtime_batch_size(batch_size, device_mesh) self.dataloader_params = kwargs - self.dataloader_params['batch_size'] = batch_size + self.dataloader_params['batch_size'] = self.batch_size self.device_mesh = device_mesh self.processor: Optional[InputProcessor] = None self._set_work_init_fn() + @staticmethod + def _required_data_world_size(device_mesh: Optional[DeviceMesh]) -> int: + if device_mesh is None: + return 1 + if (device_mesh.ulysses_size or 1) > 1: + return device_mesh.raw_data_world_size + return device_mesh.data_world_size + + def _resolve_runtime_batch_size(self, batch_size: int, device_mesh: Optional[DeviceMesh]) -> int: + if device_mesh is None: + return batch_size + 
ulysses_size = device_mesh.ulysses_size or 1 + if ulysses_size <= 1: + return batch_size + return batch_size // ulysses_size + def _set_work_init_fn(self): num_workers = self.dataloader_params.get('num_workers', 2) self.dataloader_params['worker_init_fn'] = partial( diff --git a/src/twinkle/model/transformers/__init__.py b/src/twinkle/model/transformers/__init__.py index 9ffe9866..2ce17822 100644 --- a/src/twinkle/model/transformers/__init__.py +++ b/src/twinkle/model/transformers/__init__.py @@ -1,3 +1,37 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -from .multi_lora_transformers import MultiLoraTransformersModel -from .transformers import TransformersModel +from typing import TYPE_CHECKING + +from twinkle.utils.import_utils import _LazyModule + +if TYPE_CHECKING: + from .models import ( + TwinkleQwen3_5DecoderLayer, + TwinkleQwen3_5ForCausalLM, + TwinkleQwen3_5GatedDeltaNet, + TwinkleQwen3_5PreTrainedModel, + TwinkleQwen3_5TextModel, + ) + from .multi_lora_transformers import MultiLoraTransformersModel + from .transformers import TransformersModel +else: + _import_structure = { + 'transformers': ['TransformersModel'], + 'multi_lora_transformers': ['MultiLoraTransformersModel'], + 'models': [ + 'TwinkleQwen3_5PreTrainedModel', + 'TwinkleQwen3_5TextModel', + 'TwinkleQwen3_5DecoderLayer', + 'TwinkleQwen3_5GatedDeltaNet', + 'TwinkleQwen3_5ForCausalLM', + ], + } + + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, # noqa + extra_objects={}, + ) diff --git a/src/twinkle/model/transformers/models/__init__.py b/src/twinkle/model/transformers/models/__init__.py index 85b3e739..68d7103b 100644 --- a/src/twinkle/model/transformers/models/__init__.py +++ b/src/twinkle/model/transformers/models/__init__.py @@ -1 +1,16 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
+from .qwen3_5 import ( + TwinkleQwen3_5DecoderLayer, + TwinkleQwen3_5ForCausalLM, + TwinkleQwen3_5GatedDeltaNet, + TwinkleQwen3_5PreTrainedModel, + TwinkleQwen3_5TextModel, +) + +__all__ = [ + 'TwinkleQwen3_5PreTrainedModel', + 'TwinkleQwen3_5TextModel', + 'TwinkleQwen3_5DecoderLayer', + 'TwinkleQwen3_5GatedDeltaNet', + 'TwinkleQwen3_5ForCausalLM', +] diff --git a/src/twinkle/model/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/twinkle/model/transformers/models/qwen3_5/modeling_qwen3_5.py index c44f21e0..369cae73 100644 --- a/src/twinkle/model/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/twinkle/model/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -75,6 +75,9 @@ def _sp_is_enabled(sequence_parallel_context: Any | None) -> bool: def _get_sp_rank(sequence_parallel_context: Any | None) -> int: if not _sp_is_enabled(sequence_parallel_context): return 0 + rank = getattr(sequence_parallel_context, 'rank', None) + if rank is not None: + return int(rank) import torch.distributed as dist return dist.get_rank(sequence_parallel_context.sp_group) @@ -100,6 +103,107 @@ def _head_to_seq_shard(tensor: torch.Tensor, sequence_parallel_context: Any | No return _SeqAllToAll.apply(sequence_parallel_context.sp_group, tensor, 1, 2) +def _resolve_local_padding_mask( + attention_mask: torch.Tensor | None, + seq_len: int, + sequence_parallel_context: Any | None = None, +) -> torch.Tensor | None: + if attention_mask is None or attention_mask.dim() != 2: + return attention_mask + if attention_mask.shape[-1] == seq_len: + return attention_mask + if _sp_is_enabled(sequence_parallel_context): + sp_world_size = int(sequence_parallel_context.sp_world_size) + full_seq_len = attention_mask.shape[-1] + if full_seq_len % sp_world_size == 0: + local_seq_len = full_seq_len // sp_world_size + if local_seq_len == seq_len: + sp_rank = _get_sp_rank(sequence_parallel_context) + start = sp_rank * local_seq_len + end = start + local_seq_len + return attention_mask[:, start:end].contiguous() + if attention_mask.shape[-1] >= seq_len: + return attention_mask[:, :seq_len].contiguous() + return attention_mask + + +def _flatten_varlen_batch(tensor: torch.Tensor) -> torch.Tensor: + return tensor.reshape(1, tensor.shape[0] * tensor.shape[1], *tensor.shape[2:]) + + +def _pad_or_trim_2d_tensor(tensor: torch.Tensor | None, target_len: int, pad_value: int) -> torch.Tensor | None: + if tensor is None: + return None + if tensor.dim() == 3: + tensor = tensor[0] + if tensor.shape[-1] == target_len: + return tensor + if tensor.shape[-1] > target_len: + return tensor[..., :target_len].contiguous() + pad_shape = (*tensor.shape[:-1], target_len - tensor.shape[-1]) + pad_tensor = torch.full(pad_shape, pad_value, dtype=tensor.dtype, device=tensor.device) + return torch.cat((tensor, pad_tensor), dim=-1) + + +def _build_varlen_metadata( + *, + position_ids: torch.Tensor | None, + attention_mask: torch.Tensor | None, + full_seq_len: int, +) -> tuple[torch.Tensor, torch.Tensor]: + position_ids = _pad_or_trim_2d_tensor(position_ids, full_seq_len, pad_value=-1) + + if position_ids is not None: + valid_mask = position_ids != -1 + elif attention_mask is not None: + attention_mask = _pad_or_trim_2d_tensor(attention_mask, full_seq_len, pad_value=0) + valid_mask = attention_mask != 0 + else: + raise ValueError('Varlen metadata requires at least one of position_ids or attention_mask.') + + cu_seqlens = [0] + total = 0 + for row_idx in range(valid_mask.shape[0]): + if position_ids is not None: + valid_positions = 
position_ids[row_idx][valid_mask[row_idx]] + if valid_positions.numel() == 0: + continue + seq_start_indices = torch.where(valid_positions == 0)[0] + if seq_start_indices.numel() == 0 or seq_start_indices[0].item() != 0: + seq_start_indices = torch.cat([ + torch.tensor([0], device=valid_positions.device, dtype=seq_start_indices.dtype), + seq_start_indices, + ]) + seq_end_indices = torch.cat([ + seq_start_indices[1:], + torch.tensor([valid_positions.numel()], device=valid_positions.device, dtype=seq_start_indices.dtype), + ]) + seq_lengths = (seq_end_indices - seq_start_indices).tolist() + else: + seq_lengths = [int(valid_mask[row_idx].sum().item())] + for seq_length in seq_lengths: + if seq_length <= 0: + continue + total += int(seq_length) + cu_seqlens.append(total) + return valid_mask, torch.tensor(cu_seqlens, device=valid_mask.device, dtype=torch.int32) + + +def _pack_varlen_tensor(tensor: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor: + return tensor[valid_mask].unsqueeze(0) + + +def _unpack_varlen_tensor( + packed_tensor: torch.Tensor, + valid_mask: torch.Tensor, + batch_size: int, + seq_len: int, +) -> torch.Tensor: + output = packed_tensor.new_zeros((batch_size, seq_len, *packed_tensor.shape[2:])) + output[valid_mask] = packed_tensor.squeeze(0) + return output + + class TwinkleQwen3_5GatedDeltaNet(hf_qwen35.Qwen3_5GatedDeltaNet): def __init__(self, config: Qwen3_5TextConfig, layer_idx: int): @@ -183,6 +287,7 @@ def forward( cu_seq_lens_q: torch.Tensor | None = None, sequence_parallel_context: Any | None = None, ): + attention_mask = _resolve_local_padding_mask(attention_mask, hidden_states.shape[1], sequence_parallel_context) hidden_states = hf_qwen35.apply_mask_to_padding_states(hidden_states, attention_mask) batch_size, seq_len, _ = hidden_states.shape use_precomputed_states = ( @@ -203,6 +308,7 @@ def forward( z = self.in_proj_z(hidden_states).reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim) b = self.in_proj_b(hidden_states) a = self.in_proj_a(hidden_states) + full_attention_mask = attention_mask sp_enabled = _sp_is_enabled(sequence_parallel_context) if sp_enabled: @@ -245,6 +351,21 @@ def forward( a = a.reshape(batch_size, seq_len, self.num_v_heads) conv_weight = self.conv1d.weight.squeeze(1) + packed_valid_mask = None + packed_cu_seqlens = cu_seq_lens_q + packed_seq_len = mixed_qkv.shape[1] + use_varlen_pack = cu_seq_lens_q is not None and not use_precomputed_states + if use_varlen_pack: + full_position_ids = getattr(sequence_parallel_context, 'real_position_ids', None) + packed_valid_mask, packed_cu_seqlens = _build_varlen_metadata( + position_ids=full_position_ids, + attention_mask=full_attention_mask, + full_seq_len=packed_seq_len, + ) + mixed_qkv = _pack_varlen_tensor(mixed_qkv, packed_valid_mask) + b = _pack_varlen_tensor(b, packed_valid_mask) + a = _pack_varlen_tensor(a, packed_valid_mask) + if use_precomputed_states: if conv_state is None: raise RuntimeError('Qwen3.5 decode requires initialized convolution state.') @@ -255,12 +376,13 @@ def forward( mixed_qkv.transpose(1, 2).contiguous(), (self.conv_kernel_size - mixed_qkv.shape[1], 0), ) - mixed_qkv = self._apply_varlen_conv(mixed_qkv, conv_weight, cu_seq_lens_q) + mixed_qkv = self._apply_varlen_conv(mixed_qkv, conv_weight, packed_cu_seqlens) query, key, value = torch.split(mixed_qkv, [local_key_dim, local_key_dim, local_value_dim], dim=-1) - query = query.reshape(batch_size, query.shape[1], local_num_k_heads, self.head_k_dim) - key = key.reshape(batch_size, key.shape[1], local_num_k_heads, 
self.head_k_dim) - value = value.reshape(batch_size, value.shape[1], local_num_v_heads, self.head_v_dim) + qkv_batch_size = 1 if use_varlen_pack else batch_size + query = query.reshape(qkv_batch_size, query.shape[1], local_num_k_heads, self.head_k_dim) + key = key.reshape(qkv_batch_size, key.shape[1], local_num_k_heads, self.head_k_dim) + value = value.reshape(qkv_batch_size, value.shape[1], local_num_v_heads, self.head_v_dim) beta = b.sigmoid() if sp_enabled: @@ -296,12 +418,14 @@ def forward( initial_state=None, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, - cu_seqlens=cu_seq_lens_q, + cu_seqlens=packed_cu_seqlens, ) if cache_params is not None: cache_params.recurrent_states[self.layer_idx] = last_recurrent_state + if use_varlen_pack: + core_attn_out = _unpack_varlen_tensor(core_attn_out, packed_valid_mask, batch_size, packed_seq_len) core_attn_out = _head_to_seq_shard(core_attn_out, sequence_parallel_context) core_attn_out = self.norm(core_attn_out.reshape(-1, self.head_v_dim), z.reshape(-1, self.head_v_dim)) core_attn_out = core_attn_out.reshape(batch_size, seq_len, self.value_dim) @@ -406,6 +530,19 @@ def _update_linear_attn_mask(self, attention_mask, cache_position): linear_attn_mask = None return linear_attn_mask + def _resolve_linear_attn_mask( + self, + attention_mask: torch.Tensor | None, + text_position_ids: torch.LongTensor | None, + seq_len: int, + ) -> torch.Tensor | None: + if attention_mask is not None and attention_mask.dim() == 2 and attention_mask.shape[-1] == seq_len: + return attention_mask + if text_position_ids is None: + return attention_mask if attention_mask is not None and attention_mask.dim() == 2 else None + dtype = attention_mask.dtype if attention_mask is not None else torch.int64 + return (text_position_ids != -1).to(dtype=dtype) + @merge_with_config_defaults @capture_outputs def forward( @@ -454,10 +591,13 @@ def forward( past_key_values=past_key_values, position_ids=text_position_ids, ) - linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position) hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) sp_context = self._sequence_parallel_context + linear_attn_mask = self._update_linear_attn_mask( + self._resolve_linear_attn_mask(attention_mask, text_position_ids, hidden_states.shape[1]), + cache_position, + ) if _sp_is_enabled(sp_context) and self.requires_cu_seq_lens_q and cu_seq_lens_q is None: raise ValueError('TwinkleQwen3_5TextModel requires cu_seq_lens_q when sequence parallel is enabled.') diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py index 6033d943..a0a35562 100644 --- a/src/twinkle/model/transformers/multi_lora_transformers.py +++ b/src/twinkle/model/transformers/multi_lora_transformers.py @@ -14,7 +14,7 @@ from twinkle.processor import InputProcessor from ..multi_lora import MultiLora from .strategy import AccelerateStrategy -from .transformers import OptimizerGroup, TransformersModel +from .transformers import OptimizerGroup, TransformersModel, _default_gradient_accumulation_steps_for_device_mesh @remote_class() @@ -184,7 +184,9 @@ def add_adapter_to_model(self, adapter_name: str, config_or_dir: Union[PeftConfi self.optimizer_group[adapter_name] = self._construct_default_optimizer_group() self.optimizer_group[adapter_name].adapter_name = adapter_name self.optimizer_group[adapter_name].adapter_config = config_or_dir - _gas_default = 
kwargs.get('gradient_accumulation_steps', 1) + _gas_default = kwargs.get('gradient_accumulation_steps') + if _gas_default is None: + _gas_default = _default_gradient_accumulation_steps_for_device_mesh(self.device_mesh) self.optimizer_group[adapter_name].gradient_accumulation_steps = _gas_default self._default_tokenizer = self.optimizer_group[adapter_name].template.processor self.multi_adapter.acquire_lora(tenant_adapter_name=adapter_name, config=config_or_dir) diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel.py b/src/twinkle/model/transformers/strategy/sequence_parallel.py index 40bb843b..35b95f40 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel.py @@ -23,6 +23,29 @@ def get_cu_seqlens_from_position_ids(position_ids: torch.LongTensor): return cu_seqlens +def get_flattened_cu_seqlens_from_position_ids(position_ids: torch.LongTensor): + if position_ids.dim() == 1: + position_ids = position_ids.unsqueeze(0) + if position_ids.dim() != 2: + raise ValueError(f'Expected 1D or 2D position_ids, got shape={tuple(position_ids.shape)}') + + device = position_ids.device + cu_seqlens = [0] + total = 0 + for row in position_ids: + row = row.clone() + row[row < 0] = 0 + seq_start_indices = torch.where(row == 0)[0] + if seq_start_indices.numel() == 0 or seq_start_indices[0].item() != 0: + seq_start_indices = torch.cat([torch.tensor([0], device=device, dtype=seq_start_indices.dtype), seq_start_indices]) + seq_end_indices = torch.cat([seq_start_indices[1:], torch.tensor([len(row)], device=device)]) + seq_lengths = (seq_end_indices - seq_start_indices).tolist() + for seq_length in seq_lengths: + total += int(seq_length) + cu_seqlens.append(total) + return torch.tensor(cu_seqlens, device=device, dtype=torch.long) + + @dataclass(frozen=True) class SequenceParallelContext: sp_group: Optional[dist.ProcessGroup] @@ -317,14 +340,17 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, a else: query_layer, key_layer, value_layer = query, key, value - position_ids = kwargs.pop('position_ids') + position_ids = kwargs.pop('position_ids', None) if position_ids is not None: - shape0 = position_ids.shape[0] - position_ids_output = torch.empty((shape0 * self.sequence_parallel.sp_world_size, position_ids.shape[1]), - dtype=position_ids.dtype, - device=position_ids.device) + position_ids = position_ids.contiguous() + gathered_shape = (self.sequence_parallel.sp_world_size, *position_ids.shape) + position_ids_output = torch.empty( + gathered_shape, + dtype=position_ids.dtype, + device=position_ids.device, + ) dist.all_gather_into_tensor(position_ids_output, position_ids, group=self.sequence_parallel._sp_group) - position_ids = torch.cat(position_ids_output.split(shape0, dim=0), dim=1) + position_ids = torch.cat(tuple(position_ids_output.unbind(dim=0)), dim=-1).contiguous() context_layer = self.local_attn( query_layer, key_layer, value_layer, attention_mask, *args, position_ids=position_ids, **kwargs) @@ -353,6 +379,7 @@ def __init__(self): self._sp_group = None self.num_heads = None self.causal_mask_func = None + self.attn_implementation = None self.extra_kwargs = {} self.requires_cu_seq_lens_q = False self._bound_llm_model = None @@ -514,6 +541,12 @@ def _attention(query, key, value, *args, **kwargs): kwargs['cu_seq_lens_k'] = cu_seqlens kwargs['max_length_q'] = max_seqlen kwargs['max_length_k'] = max_seqlen + else: + # Dense, non-packed SP path should not forward position_ids into FA2. 
+ # Qwen3.5 has already applied RoPE before entering the attention interface, and keeping + # position_ids here can make HF's FA2 helper mis-detect packed/varlen mode, especially when + # batch_size == 1 and position_ids carries mRoPE-style leading dimensions. + kwargs.pop('position_ids', None) return ALL_ATTENTION_FUNCTIONS['flash_attention_2_origin'](module, query, key, value, *args, **kwargs)[0] @@ -652,6 +685,10 @@ def prepare( else: if hasattr(llm_model, '_update_causal_mask'): self.causal_mask_func = llm_model._update_causal_mask + self.attn_implementation = ( + get_config_attr(model.config, '_attn_implementation') + or get_config_attr(model.config, '_attn_implementation_internal') + ) if not SequenceParallel._global_inited: # these operations are global initializations and patches @@ -775,26 +812,28 @@ def pad_and_split_inputs(self, loss_scale = self.pad(loss_scale, padding_value=0., position_ids=real_position_ids) if real_position_ids is not None: real_position_ids = self.pad(real_position_ids, padding_value=-1, position_ids=real_position_ids) - # Build a 2D attention_mask whenever we padded for SP alignment so FlashAttention2 can unpad correctly. - # For packed batches (batch_size==1 with multiple position_id resets), relying on position_ids alone is - # unsafe if we also appended SP-alignment padding (position_ids=-1), because HF's FA2 varlen path will - # include the padded tail in the last segment when attention_mask is None. - if (input_ids is not None or input_embeds is not None) and batch_size > 1: - # not padding_free, so not ring-attention + # Preserve a 2D attention mask only when there is real padding to describe. + # For dense batches, FA2/FA3 should keep `attention_mask=None`; otherwise HF routes the kernel through + # its varlen/unpadding path and can introduce unnecessary overhead or invalid accesses in SP mode. + if input_ids is not None or input_embeds is not None: inputs = input_ids if input_ids is not None else input_embeds - attn_shape = inputs.shape[1] # The sequence length + attn_shape = inputs.shape[1] if attention_mask is None: - # Mask out padded positions introduced by sequence-parallel padding. - # `real_position_ids` is padded with `-1` (see above), so use it to build a valid-token mask. - attention_mask = (real_position_ids != -1).to(dtype=torch.int64) - # no need position_ids here, because padding_free does not need attention_mask, - # so this is not ring-attention - attention_mask = self.pad(attention_mask, padding_value=0) - cache_position = torch.arange(0, attn_shape, device=inputs.device) - # pad attention mask to 4d to avoid calculation errors - if hasattr(self, 'causal_mask_func') and self.causal_mask_func is not None: - attention_mask = self.causal_mask_func(attention_mask, inputs.to(self.model_dtype), cache_position, - None, None) + has_padding = bool(real_position_ids is not None and torch.any(real_position_ids == -1)) + if has_padding: + attention_mask = (real_position_ids != -1).to(dtype=torch.int64) + else: + has_padding = not bool(torch.all(attention_mask != 0)) + if not has_padding: + attention_mask = None + if attention_mask is not None: + attention_mask = self.pad(attention_mask, padding_value=0) + if self.attn_implementation not in ('flash_attention_2', 'flash_attention_3'): + cache_position = torch.arange(0, attn_shape, device=inputs.device) + # SDPA/eager-style paths still expect a fully materialized causal mask here. 
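+ # Flash-attention paths, by contrast, keep the padded 2D mask as-is: HF's FA2/FA3 helpers unpad from it (or from cu_seq_lens), so materializing a 4D causal mask there would only add memory overhead.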
+ if hasattr(self, 'causal_mask_func') and self.causal_mask_func is not None: + attention_mask = self.causal_mask_func( + attention_mask, inputs.to(self.model_dtype), cache_position, None, None) if extra_split_values is not None: for (tensor, pad_value, split_dim) in extra_split_values: extra_values.append( @@ -875,8 +914,10 @@ def prepare_inputs(self, inputs): """ position_ids = None input_ids = inputs.get('input_ids') + inputs_embeds = inputs.get('inputs_embeds') position_ids = inputs.get('position_ids') - if position_ids is not None and input_ids is not None and position_ids.shape[0] == input_ids.shape[0]: + batch_source = input_ids if input_ids is not None else inputs_embeds + if position_ids is not None and batch_source is not None and position_ids.shape[0] == batch_source.shape[0]: self.extra_kwargs['position_ids'] = position_ids.clone() self.extra_kwargs['is_packed'] = self._is_packed_position_ids(position_ids) if input_ids is not None: @@ -885,9 +926,7 @@ def prepare_inputs(self, inputs): self._bound_llm_model.set_sequence_parallel_context(self._build_context()) if self.requires_cu_seq_lens_q and position_ids is not None: padded_position_ids = self.pad(position_ids, padding_value=-1, position_ids=position_ids, dim=-1) - padded_position_ids = padded_position_ids.clone() - padded_position_ids[padded_position_ids < 0] = 0 - inputs['cu_seq_lens_q'] = get_cu_seqlens_from_position_ids(padded_position_ids).to(torch.int32) + inputs['cu_seq_lens_q'] = get_flattened_cu_seqlens_from_position_ids(padded_position_ids).to(torch.int32) if 'labels' in inputs: labels = inputs['labels'] _, _, labels, _, _, _, _ = self.pad_and_split_inputs( diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index d00e80ed..c6eb569a 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -151,6 +151,13 @@ def calculate_metrics(self, is_training): DEFAULT_WEIGHT_DECAY = 0.01 +def _default_gradient_accumulation_steps_for_device_mesh(device_mesh: Optional[DeviceMesh]) -> int: + if device_mesh is None: + return 1 + ulysses_size = getattr(device_mesh, 'ulysses_size', None) or 1 + return ulysses_size if ulysses_size > 1 else 1 + + @remote_class() class TransformersModel(TwinkleModel, PreTrainedModel, CheckpointEngineMixin): """The transformers model wrapper. 
@@ -352,6 +359,7 @@ def _construct_default_optimizer_group(self): loss_instance=CrossEntropyLoss(reduction='sum'), template=Template(self.tokenizer_id), processor=InputProcessor(self.device_mesh), + gradient_accumulation_steps=_default_gradient_accumulation_steps_for_device_mesh(self.device_mesh), _device_mesh=self.device_mesh, ) @@ -1010,7 +1018,9 @@ def _patch_adapter(self, adapter_name: str, config_or_dir: Union[PeftConfig, str self._construct_default_optimizer_group()) self.optimizer_group[adapter_name].adapter_name = adapter_name self.optimizer_group[adapter_name].adapter_config = config - _gas_default = kwargs.get('gradient_accumulation_steps', 1) + _gas_default = kwargs.get('gradient_accumulation_steps') + if _gas_default is None: + _gas_default = _default_gradient_accumulation_steps_for_device_mesh(self.device_mesh) self.optimizer_group[adapter_name].gradient_accumulation_steps = _gas_default self._default_tokenizer = self.optimizer_group[adapter_name].template.processor self.active_group = adapter_name diff --git a/src/twinkle/utils/device_mesh.py b/src/twinkle/utils/device_mesh.py index 9f5aa9e7..d8bbfc2e 100644 --- a/src/twinkle/utils/device_mesh.py +++ b/src/twinkle/utils/device_mesh.py @@ -352,18 +352,22 @@ def get_data_rank_from_global_rank(self, global_rank: int) -> int: @property def data_world_size(self) -> int: """Consider all dp/fsdp ranks, uses to determine how to distribute the data""" - dp_world_size = self.dp_world_size - fsdp_world_size = self.fsdp_world_size + data_world_size = self.raw_data_world_size ulysses_size = self.ulysses_size or 1 - if fsdp_world_size is not None and fsdp_world_size > 1: - data_world_size = dp_world_size * fsdp_world_size if dp_world_size is not None else fsdp_world_size - else: - data_world_size = dp_world_size if dp_world_size is not None else 1 assert data_world_size % ulysses_size == 0, ( f'data_world_size: {data_world_size} cannot be divided by ulysses_size: {ulysses_size}.') return data_world_size // ulysses_size + @property + def raw_data_world_size(self) -> int: + """The data world size before applying ulysses sequence parallel grouping.""" + dp_world_size = self.dp_world_size or 1 + fsdp_world_size = self.fsdp_world_size + if fsdp_world_size is not None and fsdp_world_size > 1: + return dp_world_size * fsdp_world_size + return dp_world_size + def get_slice(self, total_length: int, rank: Optional[int] = None) -> slice: world_size = self.data_world_size if world_size == 1: diff --git a/tests/dataloader/test_dataloader.py b/tests/dataloader/test_dataloader.py index 79bf78ad..778ea171 100644 --- a/tests/dataloader/test_dataloader.py +++ b/tests/dataloader/test_dataloader.py @@ -135,6 +135,21 @@ def test_device_mesh_sampler_with_encode(self): assert 'input_ids' in batch assert batch['input_ids'].shape[0] == 2 + def test_device_mesh_sampler_auto_adjusts_batch_for_ulysses(self): + csv_path = str(TEST_DATA_DIR / 'test.csv') + dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path)) + + device_mesh = DeviceMesh( + device_type='cpu', + mesh=np.arange(4).reshape(2, 2), + mesh_dim_names=('dp', 'fsdp'), + ulysses_size=2, + ) + + dataloader = DataLoader(dataset=dataset, batch_size=8, device_mesh=device_mesh) + + assert dataloader.batch_size == 4 + class TestRetrySampler: diff --git a/tests/sequence_parallel/test_twinkle_qwen3_5_text_model.py b/tests/sequence_parallel/test_twinkle_qwen3_5_text_model.py index ed67b4e3..b638687e 100644 --- a/tests/sequence_parallel/test_twinkle_qwen3_5_text_model.py +++ 
b/tests/sequence_parallel/test_twinkle_qwen3_5_text_model.py @@ -1,6 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import tempfile import unittest +from contextlib import ExitStack from types import SimpleNamespace from unittest.mock import patch @@ -25,6 +26,7 @@ def _build_text_config(layer_types=None) -> Qwen3_5TextConfig: hidden_act='silu', max_position_embeddings=128, rms_norm_eps=1e-6, + attention_dropout=0.0, linear_conv_kernel_dim=3, linear_key_head_dim=4, linear_value_head_dim=4, @@ -37,6 +39,16 @@ def _build_text_config(layer_types=None) -> Qwen3_5TextConfig: ) +def _linear_attention_runtime_available() -> bool: + return bool( + torch.cuda.is_available() + and tw_qwen35._FLA_CAUSAL_CONV1D_FN is not None + and tw_qwen35._FLA_CHUNK_GATED_DELTA_RULE is not None + and tw_qwen35._FLA_FUSED_RECURRENT_GATED_DELTA_RULE is not None + and tw_qwen35._HAS_CAUSAL_CONV1D + ) + + class _ContextReceiver: def __init__(self): @@ -93,6 +105,56 @@ def test_from_pretrained_loads_text_only_weights(self): torch.testing.assert_close(tw_outputs.logits, hf_outputs.logits, rtol=0, atol=0) + def test_from_pretrained_loads_mixed_linear_weights(self): + config = _build_text_config(['full_attention', 'linear_attention']) + hf_model = Qwen3_5ForCausalLM(config).eval() + linear_param_names = [ + 'model.layers.1.linear_attn.in_proj_qkv.weight', + 'model.layers.1.linear_attn.in_proj_z.weight', + 'model.layers.1.linear_attn.out_proj.weight', + ] + + with tempfile.TemporaryDirectory() as temp_dir: + hf_model.save_pretrained(temp_dir) + with ExitStack() as stack: + stack.enter_context(patch.object(tw_qwen35, '_FLA_CAUSAL_CONV1D_FN', lambda *args, **kwargs: args[0])) + stack.enter_context( + patch.object(tw_qwen35, '_FLA_CAUSAL_CONV1D_UPDATE', lambda *args, **kwargs: args[0])) + stack.enter_context( + patch.object( + tw_qwen35, + '_FLA_CHUNK_GATED_DELTA_RULE', + lambda *args, **kwargs: (args[2], None), + )) + stack.enter_context( + patch.object( + tw_qwen35, + '_FLA_FUSED_RECURRENT_GATED_DELTA_RULE', + lambda *args, **kwargs: (args[2], None), + )) + stack.enter_context(patch.object(tw_qwen35, '_HAS_CAUSAL_CONV1D', True)) + tw_model = tw_qwen35.TwinkleQwen3_5ForCausalLM.from_pretrained(temp_dir).eval() + + hf_state_dict = hf_model.state_dict() + tw_state_dict = tw_model.state_dict() + for param_name in linear_param_names: + self.assertIn(param_name, tw_state_dict) + torch.testing.assert_close(tw_state_dict[param_name], hf_state_dict[param_name], rtol=0, atol=0) + + if _linear_attention_runtime_available(): + device = torch.device('cuda:0') + tw_model = tw_qwen35.TwinkleQwen3_5ForCausalLM.from_pretrained(temp_dir).to(device).eval() + input_ids = torch.tensor([[1, 2, 3, 4]], dtype=torch.long, device=device) + position_ids = torch.arange(input_ids.shape[1], dtype=torch.long, device=device).unsqueeze(0) + with torch.no_grad(): + outputs = tw_model( + input_ids=input_ids, + position_ids=position_ids, + use_cache=False, + return_dict=True, + ) + self.assertEqual(tuple(outputs.logits.shape), (1, 4, config.vocab_size)) + def test_sequence_parallel_prepare_inputs_injects_cu_seq_lens(self): sp = SequenceParallel() sp.world_size = 2 @@ -101,23 +163,64 @@ def test_sequence_parallel_prepare_inputs_injects_cu_seq_lens(self): receiver = _ContextReceiver() sp._bound_llm_model = receiver inputs = { - 'input_ids': torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long), - 'position_ids': torch.tensor([[0, 1, 2, 3, 4]], dtype=torch.long), + 'input_ids': torch.tensor([[1, 2, 3, 4]], dtype=torch.long), + 'position_ids': 
torch.tensor([[0, 1, 2, 3]], dtype=torch.long), } outputs = sp.prepare_inputs(inputs) self.assertIn('cu_seq_lens_q', outputs) - self.assertTrue(torch.equal(outputs['cu_seq_lens_q'], torch.tensor([0, 5, 6], dtype=torch.int32))) + self.assertTrue(torch.equal(outputs['cu_seq_lens_q'], torch.tensor([0, 4], dtype=torch.int32))) self.assertIsNotNone(receiver.context) self.assertFalse(receiver.context.is_packed) self.assertTrue(torch.equal(receiver.context.real_position_ids, inputs['position_ids'])) - def test_linear_attention_requires_fast_path_dependencies(self): - with self.assertRaises(ImportError): - tw_qwen35.TwinkleQwen3_5TextModel(_build_text_config(['linear_attention'])) + def test_sequence_parallel_prepare_inputs_injects_flattened_cu_seq_lens_for_batched_rows(self): + sp = SequenceParallel() + sp.world_size = 2 + sp.sp_world_size = 2 + sp.requires_cu_seq_lens_q = True + receiver = _ContextReceiver() + sp._bound_llm_model = receiver + inputs = { + 'input_ids': torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.long), + 'position_ids': torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.long), + } + + outputs = sp.prepare_inputs(inputs) + + self.assertIn('cu_seq_lens_q', outputs) + self.assertTrue(torch.equal(outputs['cu_seq_lens_q'], torch.tensor([0, 4, 8], dtype=torch.int32))) + + def test_sequence_parallel_prepare_inputs_tracks_position_ids_for_inputs_embeds(self): + sp = SequenceParallel() + sp.world_size = 2 + sp.sp_world_size = 2 + sp.requires_cu_seq_lens_q = True + receiver = _ContextReceiver() + sp._bound_llm_model = receiver + inputs = { + 'inputs_embeds': torch.randn(2, 4, 8), + 'position_ids': torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.long), + } - def test_linear_attention_sp_uses_cu_seq_lens_and_keeps_z_local(self): + outputs = sp.prepare_inputs(inputs) + + self.assertIn('cu_seq_lens_q', outputs) + self.assertTrue(torch.equal(outputs['cu_seq_lens_q'], torch.tensor([0, 4, 8], dtype=torch.int32))) + self.assertIsNotNone(receiver.context) + self.assertTrue(torch.equal(receiver.context.real_position_ids, inputs['position_ids'])) + + def test_linear_attention_requires_fast_path_dependencies(self): + with patch.object(tw_qwen35, '_FLA_CAUSAL_CONV1D_FN', None), \ + patch.object(tw_qwen35, '_FLA_CAUSAL_CONV1D_UPDATE', None), \ + patch.object(tw_qwen35, '_FLA_CHUNK_GATED_DELTA_RULE', None), \ + patch.object(tw_qwen35, '_FLA_FUSED_RECURRENT_GATED_DELTA_RULE', None), \ + patch.object(tw_qwen35, '_HAS_CAUSAL_CONV1D', False): + with self.assertRaises(ImportError): + tw_qwen35.TwinkleQwen3_5TextModel(_build_text_config(['linear_attention'])) + + def test_linear_attention_sp_passes_cu_seq_lens_and_keeps_z_local(self): captured = { 'cu_seqlens': None, 'seq_to_head_calls': 0, @@ -142,14 +245,24 @@ def fake_recurrent_rule(query, key, value, g, beta, initial_state=None, output_f return value, None def fake_seq_to_head(tensor, context): - del context captured['seq_to_head_calls'] += 1 + sp_world_size = context.sp_world_size + rank = context.rank + if tensor.dim() == 4: + local_heads = tensor.shape[2] // sp_world_size + start = rank * local_heads + end = start + local_heads + return tensor[:, :, start:end, :].contiguous() + if tensor.dim() == 3: + local_heads = tensor.shape[2] // sp_world_size + start = rank * local_heads + end = start + local_heads + return tensor[:, :, start:end].contiguous() return tensor def fake_head_to_seq(tensor, context): - del context captured['head_to_seq_calls'] += 1 - return tensor + return tensor.repeat_interleave(context.sp_world_size, dim=2) 
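+ # The patched shard helpers above mimic the real all-to-all exchange: fake_seq_to_head keeps only this rank's slice of the head dim, while fake_head_to_seq repeat-interleaves the head dim back to the full head count so downstream shapes stay consistent.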
class DummyNorm(torch.nn.Module): @@ -161,6 +274,7 @@ def forward(self, x, z): patch.object(tw_qwen35, '_FLA_CAUSAL_CONV1D_UPDATE', lambda *args, **kwargs: args[0]), \ patch.object(tw_qwen35, '_FLA_CHUNK_GATED_DELTA_RULE', fake_chunk_rule), \ patch.object(tw_qwen35, '_FLA_FUSED_RECURRENT_GATED_DELTA_RULE', fake_recurrent_rule), \ + patch.object(tw_qwen35, '_FLA_FUSED_RMS_NORM_GATED', None), \ patch.object(tw_qwen35, '_HAS_CAUSAL_CONV1D', True), \ patch.object(tw_qwen35, '_seq_to_head_shard', side_effect=fake_seq_to_head), \ patch.object(tw_qwen35, '_head_to_seq_shard', side_effect=fake_head_to_seq): @@ -186,12 +300,145 @@ def forward(self, x, z): sequence_parallel_context=context, ) - self.assertEqual(captured['seq_to_head_calls'], 5) - self.assertEqual(captured['head_to_seq_calls'], 1) + self.assertGreater(captured['seq_to_head_calls'], 0) + self.assertGreater(captured['head_to_seq_calls'], 0) self.assertTrue(torch.equal(captured['cu_seqlens'], cu_seq_lens_q)) - self.assertEqual(captured['norm_z_shape'], (hidden_states.shape[0] * hidden_states.shape[1] * config.linear_num_value_heads, config.linear_value_head_dim)) + self.assertEqual( + captured['norm_z_shape'], + (hidden_states.shape[0] * hidden_states.shape[1] * config.linear_num_value_heads, + config.linear_value_head_dim), + ) self.assertEqual(tuple(output.shape), (1, 2, config.hidden_size)) + def test_linear_attention_sp_flattens_batched_varlen_inputs(self): + captured = { + 'query_shape': None, + 'cu_seqlens': None, + } + + def fake_conv(x, weight, bias, activation, seq_idx=None, backend=None, cu_seqlens=None): + del weight, bias, activation, seq_idx, backend + captured['cu_seqlens'] = cu_seqlens.clone() if cu_seqlens is not None else None + return x + + def fake_chunk_rule(query, key, value, g, beta, initial_state=None, output_final_state=False, + use_qk_l2norm_in_kernel=False, cu_seqlens=None): + del key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel + captured['query_shape'] = tuple(query.shape) + captured['cu_seqlens'] = cu_seqlens.clone() if cu_seqlens is not None else None + return query.new_zeros(query.shape[0], query.shape[1], 4, 4), None + + def fake_recurrent_rule(query, key, value, g, beta, initial_state=None, output_final_state=False, + use_qk_l2norm_in_kernel=False): + del query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel + raise AssertionError('recurrent path should not be used') + + class DummyNorm(torch.nn.Module): + + def forward(self, x, z): + return x + z + + with patch.object(tw_qwen35, '_FLA_CAUSAL_CONV1D_FN', fake_conv), \ + patch.object(tw_qwen35, '_FLA_CAUSAL_CONV1D_UPDATE', lambda *args, **kwargs: args[0]), \ + patch.object(tw_qwen35, '_FLA_CHUNK_GATED_DELTA_RULE', fake_chunk_rule), \ + patch.object(tw_qwen35, '_FLA_FUSED_RECURRENT_GATED_DELTA_RULE', fake_recurrent_rule), \ + patch.object(tw_qwen35, '_FLA_FUSED_RMS_NORM_GATED', None), \ + patch.object(tw_qwen35, '_HAS_CAUSAL_CONV1D', True): + config = _build_text_config(['linear_attention']) + module = tw_qwen35.TwinkleQwen3_5GatedDeltaNet(config, layer_idx=0) + module.norm = DummyNorm() + hidden_states = torch.randn(2, 3, config.hidden_size) + attention_mask = torch.ones(2, 3, dtype=torch.int64) + cu_seq_lens_q = torch.tensor([0, 3, 6], dtype=torch.int32) + + output = module( + hidden_states=hidden_states, + attention_mask=attention_mask, + cu_seq_lens_q=cu_seq_lens_q, + ) + + self.assertEqual(captured['query_shape'], (1, 6, 4, 4)) + self.assertTrue(torch.equal(captured['cu_seqlens'], 
cu_seq_lens_q)) + self.assertEqual(tuple(output.shape), (2, 3, config.hidden_size)) + + def test_linear_attention_sp_uses_local_attention_mask(self): + captured = {'mask': None} + + def fake_conv(x, weight, bias, activation, seq_idx=None, backend=None, cu_seqlens=None): + del weight, bias, activation, seq_idx, backend, cu_seqlens + return x + + def fake_chunk_rule(query, key, value, g, beta, initial_state=None, output_final_state=False, + use_qk_l2norm_in_kernel=False, cu_seqlens=None): + del query, key, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel, cu_seqlens + return value, None + + def fake_recurrent_rule(query, key, value, g, beta, initial_state=None, output_final_state=False, + use_qk_l2norm_in_kernel=False): + del query, key, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel + return value, None + + with patch.object(tw_qwen35, '_FLA_CAUSAL_CONV1D_FN', fake_conv), \ + patch.object(tw_qwen35, '_FLA_CAUSAL_CONV1D_UPDATE', lambda *args, **kwargs: args[0]), \ + patch.object(tw_qwen35, '_FLA_CHUNK_GATED_DELTA_RULE', fake_chunk_rule), \ + patch.object(tw_qwen35, '_FLA_FUSED_RECURRENT_GATED_DELTA_RULE', fake_recurrent_rule), \ + patch.object(tw_qwen35, '_HAS_CAUSAL_CONV1D', True): + config = _build_text_config(['linear_attention']) + model = tw_qwen35.TwinkleQwen3_5TextModel(config) + model.set_sequence_parallel_context( + SequenceParallelContext( + sp_group='dummy_group', + sp_world_size=2, + rank=0, + world_size=2, + real_position_ids=torch.tensor([[0, 1, 2, -1]], dtype=torch.long), + is_packed=False, + )) + + def fake_linear_forward(hidden_states, cache_params=None, cache_position=None, attention_mask=None, + cu_seq_lens_q=None, sequence_parallel_context=None): + del hidden_states, cache_params, cache_position, cu_seq_lens_q, sequence_parallel_context + captured['mask'] = attention_mask.clone() if attention_mask is not None else None + return torch.zeros(1, 2, config.hidden_size) + + with patch.object(model.layers[0].linear_attn, 'forward', side_effect=fake_linear_forward): + _ = model( + input_ids=torch.tensor([[1, 2]], dtype=torch.long), + attention_mask=torch.tensor([[1, 1, 1, 0]], dtype=torch.int64), + position_ids=torch.tensor([[0, -1]], dtype=torch.long), + cache_position=torch.tensor([0, 1], dtype=torch.long), + cu_seq_lens_q=torch.tensor([0, 2], dtype=torch.int32), + use_cache=False, + ) + + self.assertIsNotNone(captured['mask']) + self.assertTrue(torch.equal(captured['mask'], torch.tensor([[1, 0]], dtype=torch.int64))) + + def test_sequence_parallel_drops_dense_attention_mask_for_flash_attention_2(self): + sp = SequenceParallel() + sp.world_size = 2 + sp.sp_world_size = 2 + sp.tokenizer = SimpleNamespace(pad_token_id=0) + sp.model_dtype = torch.bfloat16 + sp.attn_implementation = 'flash_attention_2' + sp.causal_mask_func = lambda *args, **kwargs: (_ for _ in ()).throw( + AssertionError('should not build 4d mask')) + + input_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.long) + position_ids = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.long) + + for attention_mask in (None, torch.ones(2, 4, dtype=torch.int64)): + _, _, _, _, resolved_attention_mask, _, _ = sp.pad_and_split_inputs( + input_ids=input_ids, + input_embeds=None, + labels=None, + position_ids=position_ids, + attention_mask=attention_mask, + loss_scale=None, + real_position_ids=position_ids, + ) + self.assertIsNone(resolved_attention_mask) + if __name__ == '__main__': unittest.main() diff --git 
a/tests/sequence_parallel/test_twinkle_qwen3_5_text_model_parity.py b/tests/sequence_parallel/test_twinkle_qwen3_5_text_model_parity.py new file mode 100644 index 00000000..5e765adb --- /dev/null +++ b/tests/sequence_parallel/test_twinkle_qwen3_5_text_model_parity.py @@ -0,0 +1,369 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import copy +import os +import socket +import tempfile +import traceback +import unittest +from datetime import timedelta +from types import SimpleNamespace + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig +from transformers.utils.import_utils import is_flash_attn_2_available + +from twinkle.model.transformers.models.qwen3_5 import modeling_qwen3_5 as tw_qwen35 +from twinkle.model.transformers.strategy.sequence_parallel import SequenceParallel, SequenceParallelContext +from twinkle.utils import DeviceMesh + + +def _seed_everything(seed: int) -> None: + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.benchmark = False + + +def _build_linear_parity_config() -> Qwen3_5TextConfig: + return Qwen3_5TextConfig( + vocab_size=64, + hidden_size=64, + intermediate_size=128, + num_hidden_layers=1, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=16, + hidden_act='silu', + max_position_embeddings=128, + rms_norm_eps=1e-6, + attention_dropout=0.0, + linear_conv_kernel_dim=3, + linear_key_head_dim=16, + linear_value_head_dim=16, + linear_num_key_heads=2, + linear_num_value_heads=4, + layer_types=['linear_attention'], + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + ) + + +def _build_mixed_parity_config() -> Qwen3_5TextConfig: + config = Qwen3_5TextConfig( + vocab_size=64, + hidden_size=64, + intermediate_size=128, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=16, + hidden_act='silu', + max_position_embeddings=128, + rms_norm_eps=1e-6, + attention_dropout=0.0, + linear_conv_kernel_dim=3, + linear_key_head_dim=16, + linear_value_head_dim=16, + linear_num_key_heads=2, + linear_num_value_heads=4, + layer_types=['full_attention', 'linear_attention'], + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + ) + attn_implementation = 'flash_attention_2' if is_flash_attn_2_available() else 'sdpa' + config._attn_implementation = attn_implementation + config._attn_implementation_internal = attn_implementation + return config + + +def _linear_attention_runtime_available() -> bool: + return bool( + torch.cuda.is_available() + and tw_qwen35._FLA_CAUSAL_CONV1D_FN is not None + and tw_qwen35._FLA_CHUNK_GATED_DELTA_RULE is not None + and tw_qwen35._FLA_FUSED_RECURRENT_GATED_DELTA_RULE is not None + and tw_qwen35._HAS_CAUSAL_CONV1D + ) + + +def _find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(('127.0.0.1', 0)) + return sock.getsockname()[1] + + +def _all_gather_seq(local_tensor: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor: + world_size = dist.get_world_size(group) + chunks = [torch.empty_like(local_tensor) for _ in range(world_size)] + dist.all_gather(chunks, local_tensor.contiguous(), group=group) + return torch.cat(chunks, dim=1).contiguous() + + +def _all_reduce_grads(module: torch.nn.Module, group: dist.ProcessGroup) -> None: + for param in module.parameters(): + if param.grad is not None: + dist.all_reduce(param.grad, group=group) + + +def _relative_error(actual: 
torch.Tensor, expected: torch.Tensor) -> float: + actual_fp32 = actual.detach().to(dtype=torch.float32) + expected_fp32 = expected.detach().to(dtype=torch.float32) + return float((actual_fp32 - expected_fp32).norm() / (expected_fp32.norm() + 1e-12)) + + +def _assert_relative_error(actual: torch.Tensor, expected: torch.Tensor, rel_tol: float, name: str) -> None: + rel = _relative_error(actual, expected) + if rel > rel_tol: + raise AssertionError(f'{name} relative error {rel:.4e} exceeds tolerance {rel_tol:.4e}') + + +def _write_error(error_prefix: str, rank: int) -> None: + with open(f'{error_prefix}.rank{rank}.err', 'w', encoding='utf-8') as f: + f.write(traceback.format_exc()) + + +def _run_linear_attention_parity_worker(rank: int, world_size: int, port: int, error_prefix: str) -> None: + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = str(port) + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['LOCAL_WORLD_SIZE'] = str(world_size) + torch.cuda.set_device(rank) + + try: + dist.init_process_group( + backend='nccl', + rank=rank, + world_size=world_size, + timeout=timedelta(minutes=10), + ) + + device = torch.device(f'cuda:{rank}') + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + seed = 1234 + batch_size = 2 + seq_len = 8 + local_seq_len = seq_len // world_size + start = rank * local_seq_len + end = start + local_seq_len + + _seed_everything(seed) + config = _build_linear_parity_config() + baseline_module = tw_qwen35.TwinkleQwen3_5GatedDeltaNet(config, layer_idx=0).to(device=device, dtype=dtype).eval() + sp_module = copy.deepcopy(baseline_module).to(device=device, dtype=dtype).eval() + + full_hidden_states = torch.randn(batch_size, seq_len, config.hidden_size, device=device, dtype=dtype) + dist.broadcast(full_hidden_states, src=0) + full_position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1) + attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64, device=device) + cu_seq_lens_q = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32, device=device) + + baseline_hidden_states = full_hidden_states.detach().clone().requires_grad_(True) + baseline_output = baseline_module( + hidden_states=baseline_hidden_states, + attention_mask=attention_mask, + cu_seq_lens_q=cu_seq_lens_q, + ) + baseline_loss = baseline_output.float().square().sum() + baseline_loss.backward() + baseline_input_grad = baseline_hidden_states.grad.detach() + baseline_param_grads = { + name: param.grad.detach().clone() + for name, param in baseline_module.named_parameters() + if param.grad is not None + } + + sp_hidden_states = full_hidden_states[:, start:end].detach().clone().requires_grad_(True) + sp_output = sp_module( + hidden_states=sp_hidden_states, + attention_mask=attention_mask[:, start:end].contiguous(), + cu_seq_lens_q=cu_seq_lens_q, + sequence_parallel_context=SequenceParallelContext( + sp_group=dist.group.WORLD, + sp_world_size=world_size, + rank=rank, + world_size=world_size, + real_position_ids=full_position_ids, + is_packed=False, + ), + ) + sp_loss = sp_output.float().square().sum() + sp_loss.backward() + _all_reduce_grads(sp_module, dist.group.WORLD) + + sp_output_full = _all_gather_seq(sp_output.detach(), dist.group.WORLD) + sp_input_grad_full = _all_gather_seq(sp_hidden_states.grad.detach(), dist.group.WORLD) + + torch.testing.assert_close( + sp_output_full.to(dtype=torch.float32), + 
baseline_output.detach().to(dtype=torch.float32), + rtol=1e-3, + atol=1e-3, + ) + _assert_relative_error(sp_input_grad_full, baseline_input_grad, 1e-2, 'linear_attention.input_grad') + + for name in ( + 'in_proj_qkv.weight', + 'in_proj_z.weight', + 'out_proj.weight', + ): + _assert_relative_error( + sp_module.get_parameter(name).grad, + baseline_param_grads[name], + 2e-2, + f'linear_attention.{name}', + ) + except Exception: + _write_error(error_prefix, rank) + raise + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + +def _run_mixed_text_model_parity_worker(rank: int, world_size: int, port: int, error_prefix: str) -> None: + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = str(port) + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['LOCAL_WORLD_SIZE'] = str(world_size) + torch.cuda.set_device(rank) + + try: + dist.init_process_group( + backend='nccl', + rank=rank, + world_size=world_size, + timeout=timedelta(minutes=10), + ) + + device = torch.device(f'cuda:{rank}') + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + seed = 5678 + batch_size = 2 + seq_len = 8 + + _seed_everything(seed) + config = _build_mixed_parity_config() + baseline_model = tw_qwen35.TwinkleQwen3_5TextModel(config).to(device=device, dtype=dtype).eval() + sp_model = copy.deepcopy(baseline_model).to(device=device, dtype=dtype).eval() + + full_inputs_embeds = torch.randn(batch_size, seq_len, config.hidden_size, device=device, dtype=dtype) + dist.broadcast(full_inputs_embeds, src=0) + full_position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1) + + baseline_inputs_embeds = full_inputs_embeds.detach().clone().requires_grad_(True) + baseline_outputs = baseline_model( + inputs_embeds=baseline_inputs_embeds, + position_ids=full_position_ids, + use_cache=False, + ) + baseline_hidden = baseline_outputs.last_hidden_state + baseline_loss = baseline_hidden.float().square().sum() + baseline_loss.backward() + baseline_input_grad = baseline_inputs_embeds.grad.detach() + baseline_param_grads = { + name: param.grad.detach().clone() + for name, param in baseline_model.named_parameters() + if param.grad is not None + } + + device_mesh = DeviceMesh.from_sizes( + world_size=world_size, + dp_size=world_size, + ulysses_size=world_size, + device_type='cuda', + ) + sp = SequenceParallel() + sp.prepare(world_size, sp_model, SimpleNamespace(pad_token_id=0), device_mesh=device_mesh) + + sp_inputs_embeds = full_inputs_embeds.detach().clone().requires_grad_(True) + sp_inputs = sp.prepare_inputs({ + 'inputs_embeds': sp_inputs_embeds, + 'position_ids': full_position_ids.clone(), + 'use_cache': False, + }) + sp_outputs = sp_model(**sp_inputs) + sp_hidden_local = sp_outputs.last_hidden_state + sp_loss = sp_hidden_local.float().square().sum() + sp_loss.backward() + dist.all_reduce(sp_inputs_embeds.grad, group=dist.group.WORLD) + _all_reduce_grads(sp_model, dist.group.WORLD) + + sp_hidden_full = _all_gather_seq(sp_hidden_local.detach(), dist.group.WORLD) + torch.testing.assert_close( + sp_hidden_full.to(dtype=torch.float32), + baseline_hidden.detach().to(dtype=torch.float32), + rtol=5e-3, + atol=5e-3, + ) + _assert_relative_error(sp_inputs_embeds.grad, baseline_input_grad, 1e-2, 'mixed_text_model.input_grad') + + for name in ( + 'layers.0.self_attn.q_proj.weight', + 'layers.1.linear_attn.in_proj_qkv.weight', + 'layers.1.mlp.gate_proj.weight', + ): + 
_assert_relative_error( + sp_model.get_parameter(name).grad, + baseline_param_grads[name], + 2e-2, + f'mixed_text_model.{name}', + ) + except Exception: + _write_error(error_prefix, rank) + raise + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + +class TestTwinkleQwen35TextModelParity(unittest.TestCase): + + WORLD_SIZE = 2 + + def _run_spawned_parity_test(self, worker) -> None: + port = _find_free_port() + with tempfile.TemporaryDirectory() as temp_dir: + error_prefix = os.path.join(temp_dir, 'parity') + try: + mp.spawn( + worker, + args=(self.WORLD_SIZE, port, error_prefix), + nprocs=self.WORLD_SIZE, + join=True, + ) + except Exception: + error_logs = [] + for rank in range(self.WORLD_SIZE): + error_path = f'{error_prefix}.rank{rank}.err' + if os.path.exists(error_path): + with open(error_path, 'r', encoding='utf-8') as f: + error_logs.append(f'Rank {rank}:\n{f.read()}') + if error_logs: + self.fail('\n\n'.join(error_logs)) + raise + + def test_linear_attention_sp_parity(self): + if not _linear_attention_runtime_available(): + self.skipTest('CUDA + flash-linear-attention + causal-conv1d are required for linear attention parity.') + if torch.cuda.device_count() < self.WORLD_SIZE: + self.skipTest(f'Need at least {self.WORLD_SIZE} CUDA devices for SP parity.') + self._run_spawned_parity_test(_run_linear_attention_parity_worker) + + def test_mixed_text_model_sp_parity(self): + if not _linear_attention_runtime_available(): + self.skipTest('CUDA + flash-linear-attention + causal-conv1d are required for mixed model parity.') + if torch.cuda.device_count() < self.WORLD_SIZE: + self.skipTest(f'Need at least {self.WORLD_SIZE} CUDA devices for SP parity.') + self._run_spawned_parity_test(_run_mixed_text_model_parity_worker) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/transformers/test_ulysses_auto_batch_policy.py b/tests/transformers/test_ulysses_auto_batch_policy.py new file mode 100644 index 00000000..e66c398a --- /dev/null +++ b/tests/transformers/test_ulysses_auto_batch_policy.py @@ -0,0 +1,18 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
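+"""Checks the Ulysses auto-batch policy: the default gradient accumulation steps should scale with ulysses_size."""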
+import numpy as np + +from twinkle import DeviceMesh +from twinkle.model.transformers.transformers import _default_gradient_accumulation_steps_for_device_mesh + + +class TestUlyssesAutoBatchPolicy: + + def test_default_gradient_accumulation_steps_scales_with_ulysses(self): + device_mesh = DeviceMesh( + device_type='cpu', + mesh=np.arange(4).reshape(2, 2), + mesh_dim_names=('dp', 'fsdp'), + ulysses_size=2, + ) + + assert _default_gradient_accumulation_steps_for_device_mesh(device_mesh) == 2 From bc739b192cd8abb50f1aee6769284aaf0d04b265 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Tue, 24 Mar 2026 16:45:48 +0800 Subject: [PATCH 3/6] delete unused files --- cookbook/transformers/sp_fsdp_dense.py | 241 +----------------- .../test_ulysses_auto_batch_policy.py | 18 -- 2 files changed, 2 insertions(+), 257 deletions(-) delete mode 100644 tests/transformers/test_ulysses_auto_batch_policy.py diff --git a/cookbook/transformers/sp_fsdp_dense.py b/cookbook/transformers/sp_fsdp_dense.py index 8dede7bd..7928f277 100644 --- a/cookbook/transformers/sp_fsdp_dense.py +++ b/cookbook/transformers/sp_fsdp_dense.py @@ -1,38 +1,24 @@ -import json import math -import os from functools import partial import numpy as np -import torch from peft import LoraConfig import twinkle -from twinkle import DeviceGroup, DeviceMesh, Platform, get_logger +from twinkle import DeviceGroup, DeviceMesh, get_logger from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.model import TransformersModel from twinkle.model.transformers.models import TwinkleQwen3_5ForCausalLM from twinkle.preprocessor import SelfCognitionProcessor -# TWINKLE_PROFILE_MODULE_MEMORY=1 \ -# TWINKLE_PROFILE_MODULE_MEMORY_STEP=0 \ -# TWINKLE_PROFILE_MODULE_MEMORY_RANK=0 \ -# python cookbook/transformers/sp_fsdp_dense.py logger = get_logger() MODEL_ID = 'ms://Qwen/Qwen3.5-4B' DATASETS = 'ms://swift/self-cognition' -TRAIN_SEED = int(os.environ.get('TRAIN_SEED', '1234')) -TRAIN_DETERMINISTIC = os.environ.get('TRAIN_DETERMINISTIC', '1') == '1' -TRAIN_NUM_WORKERS = int(os.environ.get('TRAIN_NUM_WORKERS', '0')) -TRAIN_SHUFFLE = os.environ.get('TRAIN_SHUFFLE', '0') == '1' -TRAIN_ATTENTION_DROPOUT = float(os.environ.get('TRAIN_ATTENTION_DROPOUT', '0.0')) -TRAIN_LORA_DROPOUT = float(os.environ.get('TRAIN_LORA_DROPOUT', '0.0')) device_group = [DeviceGroup( name='default', ranks=[0, 1, 2, 3], - device_type=Platform.get_platform().device_prefix(), )] # FSDP + SP validation over 4 GPUs: dp=2, fsdp=2 (SP only affects input slicing) @@ -49,205 +35,6 @@ global_device_mesh=device_mesh, lazy_collect=False, ) -twinkle.framework_util.seed_everything(TRAIN_SEED, TRAIN_DETERMINISTIC) - - -def _memory_api(): - device_type = Platform.get_platform().device_prefix() - device_api = getattr(torch, device_type, None) - if device_api is None or not hasattr(device_api, 'is_available') or not device_api.is_available(): - return None, None - return device_type, device_api - - -def _format_mib(num_bytes): - return f'{num_bytes / (1024 ** 2):.1f} MiB' - - -def _get_memory_stats(): - device_type, device_api = _memory_api() - if device_api is None: - return {} - - if hasattr(device_api, 'synchronize'): - device_api.synchronize() - - current_device = device_api.current_device() if hasattr(device_api, 'current_device') else 0 - return { - 'rank': Platform.get_rank(), - 'local_rank': Platform.get_local_rank(), - 'device': f'{device_type}:{current_device}', - 'mem_allocated': _format_mib(device_api.memory_allocated()), - 
'mem_reserved': _format_mib(device_api.memory_reserved()), - 'mem_peak_allocated': _format_mib(device_api.max_memory_allocated()), - 'mem_peak_reserved': _format_mib(device_api.max_memory_reserved()), - } - - -def _reset_peak_memory_stats(): - _, device_api = _memory_api() - if device_api is not None and hasattr(device_api, 'reset_peak_memory_stats'): - device_api.reset_peak_memory_stats() - - -def _memory_mib_value(num_bytes): - return round(float(num_bytes) / (1024 ** 2), 3) - - -def _shape_of(value): - if torch.is_tensor(value): - return tuple(value.shape) - if isinstance(value, (list, tuple)): - for item in value: - shape = _shape_of(item) - if shape is not None: - return shape - return None - - -class _ModuleMemoryProfiler: - - TARGET_CLASS_NAMES = { - 'TwinkleQwen3_5DecoderLayer', - 'TwinkleQwen3_5GatedDeltaNet', - 'Qwen3_5Attention', - 'Qwen3_5MLP', - } - - def __init__(self, model: TransformersModel): - self.model = model - self.enabled = os.environ.get('TWINKLE_PROFILE_MODULE_MEMORY') == '1' - self.target_step = int(os.environ.get('TWINKLE_PROFILE_MODULE_MEMORY_STEP', '0')) - self.target_rank = int(os.environ.get('TWINKLE_PROFILE_MODULE_MEMORY_RANK', '0')) - self.max_records = int(os.environ.get('TWINKLE_PROFILE_MODULE_MEMORY_LIMIT', '16')) - self.active = False - self.handles = [] - self.entries = {} - self.records = [] - - def attach(self): - if not self.enabled or Platform.get_rank() != self.target_rank: - return - base_model = getattr(self.model, 'model', None) - if base_model is None: - return - for name, module in base_model.named_modules(): - if module.__class__.__name__ not in self.TARGET_CLASS_NAMES: - continue - self.handles.append(module.register_forward_pre_hook(self._make_pre_hook(name), with_kwargs=True)) - self.handles.append(module.register_forward_hook(self._make_post_hook(name), with_kwargs=True)) - - def _make_pre_hook(self, name): - - def _hook(module, args, kwargs): - if not self.active: - return - _, device_api = _memory_api() - if device_api is None: - return - if hasattr(device_api, 'synchronize'): - device_api.synchronize() - self.entries[id(module)] = { - 'name': name, - 'class_name': module.__class__.__name__, - 'input_shape': _shape_of(args) or _shape_of(kwargs), - 'pre_allocated_mib': _memory_mib_value(device_api.memory_allocated()), - 'pre_reserved_mib': _memory_mib_value(device_api.memory_reserved()), - 'pre_peak_allocated_mib': _memory_mib_value(device_api.max_memory_allocated()), - 'pre_peak_reserved_mib': _memory_mib_value(device_api.max_memory_reserved()), - } - - return _hook - - def _make_post_hook(self, name): - - def _hook(module, args, kwargs, output): - del args, kwargs - if not self.active: - return - _, device_api = _memory_api() - if device_api is None: - return - if hasattr(device_api, 'synchronize'): - device_api.synchronize() - entry = self.entries.pop(id(module), None) - if entry is None: - return - post_allocated_mib = _memory_mib_value(device_api.memory_allocated()) - post_reserved_mib = _memory_mib_value(device_api.memory_reserved()) - post_peak_allocated_mib = _memory_mib_value(device_api.max_memory_allocated()) - post_peak_reserved_mib = _memory_mib_value(device_api.max_memory_reserved()) - entry.update({ - 'output_shape': _shape_of(output), - 'post_allocated_mib': post_allocated_mib, - 'post_reserved_mib': post_reserved_mib, - 'post_peak_allocated_mib': post_peak_allocated_mib, - 'post_peak_reserved_mib': post_peak_reserved_mib, - 'delta_allocated_mib': round(post_allocated_mib - entry['pre_allocated_mib'], 3), - 
'delta_reserved_mib': round(post_reserved_mib - entry['pre_reserved_mib'], 3), - 'delta_peak_allocated_mib': round(post_peak_allocated_mib - entry['pre_peak_allocated_mib'], 3), - 'delta_peak_reserved_mib': round(post_peak_reserved_mib - entry['pre_peak_reserved_mib'], 3), - }) - self.records.append(entry) - - return _hook - - def start_step(self, step: int): - self.active = self.enabled and Platform.get_rank() == self.target_rank and step == self.target_step - self.entries.clear() - self.records.clear() - if self.active: - _reset_peak_memory_stats() - - def finish_step(self, step: int): - if not self.active: - return - step_memory = _get_memory_stats() - sorted_records = sorted(self.records, key=lambda item: item['delta_peak_allocated_mib'], reverse=True) - logger.info( - 'Module memory profile summary: ' - + json.dumps( - { - 'step': step, - 'rank': Platform.get_rank(), - 'total_step_peak_allocated_mib': step_memory.get('mem_peak_allocated'), - 'total_step_peak_reserved_mib': step_memory.get('mem_peak_reserved'), - 'top_forward_modules_by_peak_allocated': sorted_records[:self.max_records], - }, - ensure_ascii=False, - )) - self.active = False - - def close(self): - for handle in self.handles: - handle.remove() - self.handles.clear() - - -def _get_runtime_backend_info(model: TransformersModel): - model._ensure_sp_strategy() - - underlying_model = getattr(model, 'model', None) - llm_model = getattr(underlying_model, 'model', underlying_model) - config = getattr(underlying_model, 'config', None) - - attn_implementation = None - attn_implementation_internal = None - if config is not None: - attn_implementation = getattr(config, '_attn_implementation', None) - attn_implementation_internal = getattr(config, '_attn_implementation_internal', None) - - return { - 'model_cls': type(underlying_model).__name__ if underlying_model is not None else None, - 'llm_model_cls': type(llm_model).__name__ if llm_model is not None else None, - 'attn_implementation': attn_implementation, - 'attn_implementation_internal': attn_implementation_internal, - 'requires_cu_seq_lens_q': bool(getattr(llm_model, 'requires_cu_seq_lens_q', False)), - 'sp_enabled': bool(getattr(model, '_enable_sp', False)), - 'ulysses_size': getattr(getattr(model, 'device_mesh', None), 'ulysses_size', None), - 'sp_strategy_enabled': bool(getattr(getattr(model, 'sp_strategy', None), 'enabled', False)), - 'sp_strategy_ulysses_size': getattr(getattr(model, 'sp_strategy', None), 'ulysses_size', None), - } def eval(model): @@ -255,8 +42,6 @@ def eval(model): dataset=partial(create_dataset, data_slice=range(100)), batch_size=4, device_mesh=device_mesh, - num_workers=TRAIN_NUM_WORKERS, - shuffle=False, ) for _, batch in enumerate(dataloader): model.forward_only(inputs=batch, adapter_name='default') @@ -277,8 +62,6 @@ def train(): dataset=partial(create_dataset, data_slice=None), batch_size=8, device_mesh=device_mesh, - num_workers=TRAIN_NUM_WORKERS, - shuffle=TRAIN_SHUFFLE, ) model = TransformersModel( @@ -286,10 +69,9 @@ def train(): model_cls=TwinkleQwen3_5ForCausalLM, device_mesh=device_mesh, strategy='native_fsdp', - attention_dropout=TRAIN_ATTENTION_DROPOUT, ) - lora_config = LoraConfig(target_modules='all-linear', lora_dropout=TRAIN_LORA_DROPOUT) + lora_config = LoraConfig(target_modules='all-linear', lora_dropout=0.0) model.add_adapter_to_model('default', lora_config) grad_accumulation_steps = model.optimizer_group['default'].gradient_accumulation_steps num_optimizer_steps = math.ceil(len(dataloader) / grad_accumulation_steps) @@ -306,38 
+88,19 @@ def train(): logger.info( f'Total micro steps: {len(dataloader)}, optimizer steps: {num_optimizer_steps}, ' f'gradient_accumulation_steps: {grad_accumulation_steps}') - logger.info( - 'Reproducibility config: ' - + str({ - 'seed': TRAIN_SEED, - 'deterministic': TRAIN_DETERMINISTIC, - 'dataloader_shuffle': TRAIN_SHUFFLE, - 'dataloader_num_workers': TRAIN_NUM_WORKERS, - 'attention_dropout': TRAIN_ATTENTION_DROPOUT, - 'lora_dropout': TRAIN_LORA_DROPOUT, - })) - logger.info(f'Backend info: {_get_runtime_backend_info(model)}') - logger.info(f'Initial memory: {_get_memory_stats()}') - _reset_peak_memory_stats() - module_memory_profiler = _ModuleMemoryProfiler(model) - module_memory_profiler.attach() for step, batch in enumerate(dataloader): - module_memory_profiler.start_step(step) model.forward_backward(inputs=batch, adapter_name='default') - module_memory_profiler.finish_step(step) model.clip_grad_and_step(adapter_name='default') optimizer_step = step // grad_accumulation_steps is_optimizer_boundary = (step + 1) % grad_accumulation_steps == 0 if is_optimizer_boundary and optimizer_step % log_every_optimizer_steps == 0: metric = model.calculate_metric(is_training=True, adapter_name='default') - metric.update(_get_memory_stats()) optimizer_step = metric.get('iters', optimizer_step) logger.info( f'Current is optimizer step {optimizer_step} of {num_optimizer_steps} ' f'(micro step {step} of {len(dataloader)}), metric: {metric}') model.save('last-checkpoint', interval=1) - module_memory_profiler.close() if __name__ == '__main__': diff --git a/tests/transformers/test_ulysses_auto_batch_policy.py b/tests/transformers/test_ulysses_auto_batch_policy.py deleted file mode 100644 index e66c398a..00000000 --- a/tests/transformers/test_ulysses_auto_batch_policy.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. 
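The trimmed training loop above keeps the original accumulation bookkeeping; a short self-contained sketch of that arithmetic with hypothetical sizes (the names mirror the cookbook script, the values are made up):

import math

num_micro_steps = 10          # stand-in for len(dataloader)
grad_accumulation_steps = 4   # stand-in for the optimizer group's setting

num_optimizer_steps = math.ceil(num_micro_steps / grad_accumulation_steps)  # -> 3

for step in range(num_micro_steps):
    optimizer_step = step // grad_accumulation_steps
    # The optimizer advances and metrics are logged only on the last micro step
    # of each accumulation window.
    is_optimizer_boundary = (step + 1) % grad_accumulation_steps == 0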
-import numpy as np - -from twinkle import DeviceMesh -from twinkle.model.transformers.transformers import _default_gradient_accumulation_steps_for_device_mesh - - -class TestUlyssesAutoBatchPolicy: - - def test_default_gradient_accumulation_steps_scales_with_ulysses(self): - device_mesh = DeviceMesh( - device_type='cpu', - mesh=np.arange(4).reshape(2, 2), - mesh_dim_names=('dp', 'fsdp'), - ulysses_size=2, - ) - - assert _default_gradient_accumulation_steps_for_device_mesh(device_mesh) == 2 From 1cf60eedc4f88be1b3ca443ad5b2ccb66aefcc07 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Tue, 24 Mar 2026 17:19:56 +0800 Subject: [PATCH 4/6] delete unusesd files --- .../transformers/qwen3_5_sp_memory_bench.py | 430 ------------------ 1 file changed, 430 deletions(-) delete mode 100644 cookbook/transformers/qwen3_5_sp_memory_bench.py diff --git a/cookbook/transformers/qwen3_5_sp_memory_bench.py b/cookbook/transformers/qwen3_5_sp_memory_bench.py deleted file mode 100644 index 2a47b2f4..00000000 --- a/cookbook/transformers/qwen3_5_sp_memory_bench.py +++ /dev/null @@ -1,430 +0,0 @@ -import json -import os -import socket -import tempfile -import traceback -from datetime import timedelta -from types import SimpleNamespace - -import torch -import torch.distributed as dist -import torch.multiprocessing as mp -from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig -from transformers.utils.import_utils import is_flash_attn_2_available - -from twinkle.model.transformers.models.qwen3_5 import modeling_qwen3_5 as tw_qwen35 -from twinkle.model.transformers.strategy.sequence_parallel import SequenceParallel, SequenceParallelContext -from twinkle.utils import DeviceMesh - -# Examples: -# CUDA_VISIBLE_DEVICES=0,1 PYTHONPATH=src python cookbook/transformers/qwen3_5_sp_memory_bench.py -# CUDA_VISIBLE_DEVICES=0,1 QWEN35_SP_MEMORY_MODE=linear PYTHONPATH=src \ -# python cookbook/transformers/qwen3_5_sp_memory_bench.py - - -def _build_linear_bench_config() -> Qwen3_5TextConfig: - hidden_size = int(os.environ.get('QWEN35_SP_MEMORY_HIDDEN_SIZE', '1024')) - head_dim = int(os.environ.get('QWEN35_SP_MEMORY_HEAD_DIM', '64')) - num_attention_heads = hidden_size // head_dim - return Qwen3_5TextConfig( - vocab_size=64, - hidden_size=hidden_size, - intermediate_size=hidden_size * 2, - num_hidden_layers=1, - num_attention_heads=num_attention_heads, - num_key_value_heads=max(1, num_attention_heads // 2), - head_dim=head_dim, - hidden_act='silu', - max_position_embeddings=16384, - rms_norm_eps=1e-6, - attention_dropout=0.0, - linear_conv_kernel_dim=3, - linear_key_head_dim=head_dim, - linear_value_head_dim=head_dim, - linear_num_key_heads=max(2, num_attention_heads // 2), - linear_num_value_heads=num_attention_heads, - layer_types=['linear_attention'], - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - ) - - -def _build_mixed_bench_config() -> Qwen3_5TextConfig: - hidden_size = int(os.environ.get('QWEN35_SP_MEMORY_HIDDEN_SIZE', '1024')) - head_dim = int(os.environ.get('QWEN35_SP_MEMORY_HEAD_DIM', '64')) - num_attention_heads = hidden_size // head_dim - config = Qwen3_5TextConfig( - vocab_size=64, - hidden_size=hidden_size, - intermediate_size=hidden_size * 2, - num_hidden_layers=2, - num_attention_heads=num_attention_heads, - num_key_value_heads=max(1, num_attention_heads // 2), - head_dim=head_dim, - hidden_act='silu', - max_position_embeddings=16384, - rms_norm_eps=1e-6, - attention_dropout=0.0, - linear_conv_kernel_dim=3, - linear_key_head_dim=head_dim, - 
linear_value_head_dim=head_dim, - linear_num_key_heads=max(2, num_attention_heads // 2), - linear_num_value_heads=num_attention_heads, - layer_types=['full_attention', 'linear_attention'], - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - ) - attn_implementation = os.environ.get( - 'QWEN35_SP_MEMORY_ATTN_IMPLEMENTATION', - 'flash_attention_2' if is_flash_attn_2_available() else 'sdpa', - ) - config._attn_implementation = attn_implementation - config._attn_implementation_internal = attn_implementation - return config - - -def _find_free_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(('127.0.0.1', 0)) - return sock.getsockname()[1] - - -def _parse_cases(): - spec = os.environ.get('QWEN35_SP_MEMORY_CASES', '1x1024,1x2048,2x2048') - cases = [] - for item in spec.split(','): - item = item.strip() - if not item: - continue - batch_size, seq_len = item.lower().split('x', 1) - cases.append((int(batch_size), int(seq_len))) - return cases - - -def _measure_cuda_peak_stats(run_step): - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - run_step() - torch.cuda.synchronize() - return { - 'peak_allocated_mib': torch.cuda.max_memory_allocated() / (1024 ** 2), - 'peak_reserved_mib': torch.cuda.max_memory_reserved() / (1024 ** 2), - } - - -def _run_linear_attention_memory_step( - module: torch.nn.Module, - hidden_states: torch.Tensor, - *, - attention_mask: torch.Tensor, - cu_seq_lens_q: torch.Tensor, - sequence_parallel_context: SequenceParallelContext | None = None, -) -> dict[str, float]: - - def _step(): - module.zero_grad(set_to_none=True) - local_hidden_states = hidden_states.detach().clone().requires_grad_(True) - output = module( - hidden_states=local_hidden_states, - attention_mask=attention_mask, - cu_seq_lens_q=cu_seq_lens_q, - sequence_parallel_context=sequence_parallel_context, - ) - loss = output.float().square().mean() - loss.backward() - - return _measure_cuda_peak_stats(_step) - - -def _run_text_model_memory_step( - model: torch.nn.Module, - model_inputs: dict[str, torch.Tensor | bool], -) -> dict[str, float]: - - def _step(): - model.zero_grad(set_to_none=True) - local_inputs = {} - for key, value in model_inputs.items(): - if torch.is_tensor(value): - local_inputs[key] = value.detach().clone() - else: - local_inputs[key] = value - outputs = model(**local_inputs) - loss = outputs.last_hidden_state.float().square().mean() - loss.backward() - - return _measure_cuda_peak_stats(_step) - - -def _write_error(error_prefix: str, rank: int) -> None: - with open(f'{error_prefix}.rank{rank}.err', 'w', encoding='utf-8') as f: - f.write(traceback.format_exc()) - - -def _run_linear_attention_memory_worker(rank: int, world_size: int, port: int, result_path: str, cases): - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = str(port) - os.environ['RANK'] = str(rank) - os.environ['LOCAL_RANK'] = str(rank) - os.environ['WORLD_SIZE'] = str(world_size) - os.environ['LOCAL_WORLD_SIZE'] = str(world_size) - torch.cuda.set_device(rank) - error_prefix = result_path - try: - dist.init_process_group( - backend='nccl', - rank=rank, - world_size=world_size, - timeout=timedelta(minutes=10), - ) - - device = torch.device(f'cuda:{rank}') - dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 - config = _build_linear_bench_config() - results = [] - - for batch_size, seq_len in cases: - if seq_len % world_size != 0: - raise ValueError(f'seq_len ({seq_len}) must be divisible by world_size ({world_size})') - - 
full_attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64, device=device) - full_position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1) - cu_seq_lens_q = torch.arange( - 0, - (batch_size + 1) * seq_len, - step=seq_len, - dtype=torch.int32, - device=device, - ) - - baseline_module = tw_qwen35.TwinkleQwen3_5GatedDeltaNet(config, layer_idx=0).to(device=device, dtype=dtype) - baseline_module.train() - baseline_hidden_states = torch.randn(batch_size, seq_len, config.hidden_size, device=device, dtype=dtype) - baseline_stats = _run_linear_attention_memory_step( - baseline_module, - baseline_hidden_states, - attention_mask=full_attention_mask, - cu_seq_lens_q=cu_seq_lens_q, - sequence_parallel_context=None, - ) - del baseline_module, baseline_hidden_states - torch.cuda.empty_cache() - - local_seq_len = seq_len // world_size - start = rank * local_seq_len - end = start + local_seq_len - sp_attention_mask = full_attention_mask[:, start:end].contiguous() - sp_hidden_states = torch.randn(batch_size, local_seq_len, config.hidden_size, device=device, dtype=dtype) - sp_context = SequenceParallelContext( - sp_group=dist.group.WORLD, - sp_world_size=world_size, - rank=rank, - world_size=world_size, - real_position_ids=full_position_ids, - is_packed=False, - ) - sp_module = tw_qwen35.TwinkleQwen3_5GatedDeltaNet(config, layer_idx=0).to(device=device, dtype=dtype) - sp_module.train() - sp_stats = _run_linear_attention_memory_step( - sp_module, - sp_hidden_states, - attention_mask=sp_attention_mask, - cu_seq_lens_q=cu_seq_lens_q, - sequence_parallel_context=sp_context, - ) - del sp_module, sp_hidden_states - torch.cuda.empty_cache() - - payload = torch.tensor([ - baseline_stats['peak_allocated_mib'], - baseline_stats['peak_reserved_mib'], - sp_stats['peak_allocated_mib'], - sp_stats['peak_reserved_mib'], - ], device=device) - gathered = [torch.zeros_like(payload) for _ in range(world_size)] - dist.all_gather(gathered, payload) - - if rank == 0: - gathered_cpu = [tensor.cpu().tolist() for tensor in gathered] - results.append({ - 'batch_size': batch_size, - 'seq_len': seq_len, - 'baseline_peak_allocated_mib_per_rank': [row[0] for row in gathered_cpu], - 'baseline_peak_reserved_mib_per_rank': [row[1] for row in gathered_cpu], - 'sp_peak_allocated_mib_per_rank': [row[2] for row in gathered_cpu], - 'sp_peak_reserved_mib_per_rank': [row[3] for row in gathered_cpu], - 'baseline_peak_allocated_mib_max': max(row[0] for row in gathered_cpu), - 'sp_peak_allocated_mib_max': max(row[2] for row in gathered_cpu), - }) - - if rank == 0: - torch.save(results, result_path) - except Exception: - _write_error(error_prefix, rank) - raise - finally: - if dist.is_initialized(): - dist.destroy_process_group() - - -def _run_mixed_text_model_memory_worker(rank: int, world_size: int, port: int, result_path: str, cases): - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = str(port) - os.environ['RANK'] = str(rank) - os.environ['LOCAL_RANK'] = str(rank) - os.environ['WORLD_SIZE'] = str(world_size) - os.environ['LOCAL_WORLD_SIZE'] = str(world_size) - torch.cuda.set_device(rank) - error_prefix = result_path - try: - dist.init_process_group( - backend='nccl', - rank=rank, - world_size=world_size, - timeout=timedelta(minutes=10), - ) - - device = torch.device(f'cuda:{rank}') - dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 - baseline_results = [] - sp_results = [] - - for batch_size, seq_len in cases: - config = 
_build_mixed_bench_config() - full_input_ids = torch.randint(1, config.vocab_size, (batch_size, seq_len), device=device, dtype=torch.long) - full_position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1) - - baseline_model = tw_qwen35.TwinkleQwen3_5TextModel(config).to(device=device, dtype=dtype) - baseline_model.train() - baseline_stats = _run_text_model_memory_step( - baseline_model, - { - 'input_ids': full_input_ids, - 'position_ids': full_position_ids, - 'use_cache': False, - }, - ) - baseline_results.append(baseline_stats) - del baseline_model - torch.cuda.empty_cache() - - device_mesh = DeviceMesh.from_sizes( - world_size=world_size, - dp_size=world_size, - ulysses_size=world_size, - device_type='cuda', - ) - tokenizer = SimpleNamespace(pad_token_id=0) - sp = SequenceParallel() - - for batch_size, seq_len in cases: - if seq_len % world_size != 0: - raise ValueError(f'seq_len ({seq_len}) must be divisible by world_size ({world_size})') - - config = _build_mixed_bench_config() - full_input_ids = torch.randint(1, config.vocab_size, (batch_size, seq_len), device=device, dtype=torch.long) - full_position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1) - - sp_model = tw_qwen35.TwinkleQwen3_5TextModel(config).to(device=device, dtype=dtype) - sp_model.train() - sp.prepare(world_size, sp_model, tokenizer, device_mesh=device_mesh) - sp_inputs = sp.prepare_inputs({ - 'input_ids': full_input_ids, - 'position_ids': full_position_ids, - 'use_cache': False, - }) - sp_stats = _run_text_model_memory_step(sp_model, sp_inputs) - sp_results.append(sp_stats) - del sp_model - torch.cuda.empty_cache() - - gathered_results = [] - for (batch_size, seq_len), baseline_stats, sp_stats in zip(cases, baseline_results, sp_results, strict=False): - payload = torch.tensor([ - baseline_stats['peak_allocated_mib'], - baseline_stats['peak_reserved_mib'], - sp_stats['peak_allocated_mib'], - sp_stats['peak_reserved_mib'], - ], device=device) - gathered = [torch.zeros_like(payload) for _ in range(world_size)] - dist.all_gather(gathered, payload) - - if rank == 0: - gathered_cpu = [tensor.cpu().tolist() for tensor in gathered] - gathered_results.append({ - 'batch_size': batch_size, - 'seq_len': seq_len, - 'attn_implementation': getattr(config, '_attn_implementation', None), - 'baseline_peak_allocated_mib_per_rank': [row[0] for row in gathered_cpu], - 'baseline_peak_reserved_mib_per_rank': [row[1] for row in gathered_cpu], - 'sp_peak_allocated_mib_per_rank': [row[2] for row in gathered_cpu], - 'sp_peak_reserved_mib_per_rank': [row[3] for row in gathered_cpu], - 'baseline_peak_allocated_mib_max': max(row[0] for row in gathered_cpu), - 'sp_peak_allocated_mib_max': max(row[2] for row in gathered_cpu), - }) - - if rank == 0: - torch.save(gathered_results, result_path) - except Exception: - _write_error(error_prefix, rank) - raise - finally: - if dist.is_initialized(): - dist.destroy_process_group() - - -def _run_spawned(worker, world_size: int, cases): - port = _find_free_port() - with tempfile.TemporaryDirectory() as temp_dir: - result_path = os.path.join(temp_dir, 'memory_results.pt') - try: - mp.spawn( - worker, - args=(world_size, port, result_path, cases), - nprocs=world_size, - join=True, - ) - except Exception: - error_logs = [] - for rank in range(world_size): - error_path = f'{result_path}.rank{rank}.err' - if os.path.exists(error_path): - with open(error_path, 'r', encoding='utf-8') as f: - error_logs.append(f'Rank 
{rank}:\n{f.read()}') - if error_logs: - raise RuntimeError('\n\n'.join(error_logs)) - raise - return torch.load(result_path, weights_only=False) - - -def main(): - if not torch.cuda.is_available(): - raise SystemExit('CUDA is required for the Qwen3.5 SP memory benchmark.') - - world_size = int(os.environ.get('QWEN35_SP_MEMORY_WORLD_SIZE', '2')) - if torch.cuda.device_count() < world_size: - raise SystemExit(f'Need at least {world_size} CUDA devices for the Qwen3.5 SP memory benchmark.') - - cases = _parse_cases() - mode = os.environ.get('QWEN35_SP_MEMORY_MODE', 'both') - results = {} - - if mode in ('linear', 'both'): - results['linear_attention'] = _run_spawned(_run_linear_attention_memory_worker, world_size, cases) - - if mode in ('mixed', 'both'): - results['mixed_text_model'] = _run_spawned(_run_mixed_text_model_memory_worker, world_size, cases) - - output = json.dumps(results, indent=2) - print(output) - - result_path = os.environ.get('QWEN35_SP_MEMORY_RESULT_PATH') - if result_path: - with open(result_path, 'w', encoding='utf-8') as f: - f.write(output) - - -if __name__ == '__main__': - main() From 22a71f5aa020cbfbfe300c5225c28b26baff9e17 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Wed, 25 Mar 2026 00:03:17 +0800 Subject: [PATCH 5/6] fix bug --- cookbook/transformers/sp_fsdp_dense.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cookbook/transformers/sp_fsdp_dense.py b/cookbook/transformers/sp_fsdp_dense.py index 7928f277..d99de2e3 100644 --- a/cookbook/transformers/sp_fsdp_dense.py +++ b/cookbook/transformers/sp_fsdp_dense.py @@ -5,7 +5,7 @@ from peft import LoraConfig import twinkle -from twinkle import DeviceGroup, DeviceMesh, get_logger +from twinkle import DeviceGroup, DeviceMesh, get_logger,Platform from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.model import TransformersModel @@ -19,6 +19,7 @@ device_group = [DeviceGroup( name='default', ranks=[0, 1, 2, 3], + device_type=Platform.get_platform().device_prefix(), )] # FSDP + SP validation over 4 GPUs: dp=2, fsdp=2 (SP only affects input slicing) @@ -69,6 +70,7 @@ def train(): model_cls=TwinkleQwen3_5ForCausalLM, device_mesh=device_mesh, strategy='native_fsdp', + attn_implementation="flash_attention_2" ) lora_config = LoraConfig(target_modules='all-linear', lora_dropout=0.0) From 75d006ad9f56276103c8347fccb882a44befafd8 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Wed, 25 Mar 2026 09:27:07 +0800 Subject: [PATCH 6/6] feat: standardize import formatting and fix attention implementation string - Change double quotes to single quotes for consistency in `attn_implementation` parameter - Reformat multi-line imports to single line for better readability - Remove unnecessary import error message in linear attention validation - Maintain code style consistency across the codebase --- cookbook/transformers/sp_fsdp_dense.py | 2 +- src/twinkle/model/transformers/__init__.py | 9 +- .../model/transformers/models/__init__.py | 9 +- .../transformers/models/qwen3_5/__init__.py | 9 +- .../models/qwen3_5/modeling_qwen3_5.py | 49 ++++------- .../strategy/sequence_parallel.py | 10 +-- .../test_twinkle_qwen3_5_text_model.py | 85 +++++++++++++------ .../test_twinkle_qwen3_5_text_model_parity.py | 43 +++++----- 8 files changed, 109 insertions(+), 107 deletions(-) diff --git a/cookbook/transformers/sp_fsdp_dense.py b/cookbook/transformers/sp_fsdp_dense.py index d99de2e3..81eaf60d 100644 --- 
a/cookbook/transformers/sp_fsdp_dense.py +++ b/cookbook/transformers/sp_fsdp_dense.py @@ -70,7 +70,7 @@ def train(): model_cls=TwinkleQwen3_5ForCausalLM, device_mesh=device_mesh, strategy='native_fsdp', - attn_implementation="flash_attention_2" + attn_implementation='flash_attention_2' ) lora_config = LoraConfig(target_modules='all-linear', lora_dropout=0.0) diff --git a/src/twinkle/model/transformers/__init__.py b/src/twinkle/model/transformers/__init__.py index 2ce17822..afd16934 100644 --- a/src/twinkle/model/transformers/__init__.py +++ b/src/twinkle/model/transformers/__init__.py @@ -4,13 +4,8 @@ from twinkle.utils.import_utils import _LazyModule if TYPE_CHECKING: - from .models import ( - TwinkleQwen3_5DecoderLayer, - TwinkleQwen3_5ForCausalLM, - TwinkleQwen3_5GatedDeltaNet, - TwinkleQwen3_5PreTrainedModel, - TwinkleQwen3_5TextModel, - ) + from .models import (TwinkleQwen3_5DecoderLayer, TwinkleQwen3_5ForCausalLM, TwinkleQwen3_5GatedDeltaNet, + TwinkleQwen3_5PreTrainedModel, TwinkleQwen3_5TextModel) from .multi_lora_transformers import MultiLoraTransformersModel from .transformers import TransformersModel else: diff --git a/src/twinkle/model/transformers/models/__init__.py b/src/twinkle/model/transformers/models/__init__.py index 68d7103b..8c84298c 100644 --- a/src/twinkle/model/transformers/models/__init__.py +++ b/src/twinkle/model/transformers/models/__init__.py @@ -1,11 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -from .qwen3_5 import ( - TwinkleQwen3_5DecoderLayer, - TwinkleQwen3_5ForCausalLM, - TwinkleQwen3_5GatedDeltaNet, - TwinkleQwen3_5PreTrainedModel, - TwinkleQwen3_5TextModel, -) +from .qwen3_5 import (TwinkleQwen3_5DecoderLayer, TwinkleQwen3_5ForCausalLM, TwinkleQwen3_5GatedDeltaNet, + TwinkleQwen3_5PreTrainedModel, TwinkleQwen3_5TextModel) __all__ = [ 'TwinkleQwen3_5PreTrainedModel', diff --git a/src/twinkle/model/transformers/models/qwen3_5/__init__.py b/src/twinkle/model/transformers/models/qwen3_5/__init__.py index 1b3cb561..60c8a808 100644 --- a/src/twinkle/model/transformers/models/qwen3_5/__init__.py +++ b/src/twinkle/model/transformers/models/qwen3_5/__init__.py @@ -1,11 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
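The cookbook now pins attn_implementation='flash_attention_2'. A guarded alternative, reusing the availability check the removed benchmark relied on, would look like the sketch below; the fallback to 'sdpa' is an assumption rather than part of this patch.

from transformers.utils.import_utils import is_flash_attn_2_available

# Assumption for illustration: fall back to SDPA when flash-attn 2 is not installed.
attn_implementation = 'flash_attention_2' if is_flash_attn_2_available() else 'sdpa'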
-from .modeling_qwen3_5 import ( - TwinkleQwen3_5DecoderLayer, - TwinkleQwen3_5ForCausalLM, - TwinkleQwen3_5GatedDeltaNet, - TwinkleQwen3_5PreTrainedModel, - TwinkleQwen3_5TextModel, -) +from .modeling_qwen3_5 import (TwinkleQwen3_5DecoderLayer, TwinkleQwen3_5ForCausalLM, TwinkleQwen3_5GatedDeltaNet, + TwinkleQwen3_5PreTrainedModel, TwinkleQwen3_5TextModel) __all__ = [ 'TwinkleQwen3_5PreTrainedModel', diff --git a/src/twinkle/model/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/twinkle/model/transformers/models/qwen3_5/modeling_qwen3_5.py index 369cae73..4f7ea33e 100644 --- a/src/twinkle/model/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/twinkle/model/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -2,21 +2,19 @@ from __future__ import annotations import importlib.util -from typing import Any, Callable, Optional - import torch import torch.nn.functional as F from torch import nn from transformers.cache_utils import Cache from transformers.generation import GenerationMixin from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5Config, Qwen3_5TextConfig from transformers.models.qwen3_5 import modeling_qwen3_5 as hf_qwen35 +from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5Config, Qwen3_5TextConfig from transformers.processing_utils import Unpack from transformers.utils import TransformersKwargs, can_return_tuple from transformers.utils.generic import merge_with_config_defaults from transformers.utils.output_capturing import capture_outputs - +from typing import Any, Callable, Optional try: from fla.modules import FusedRMSNormGated as _FLA_FUSED_RMS_NORM_GATED @@ -37,10 +35,8 @@ def _ensure_text_config(config: Qwen3_5TextConfig) -> Qwen3_5TextConfig: if isinstance(config, Qwen3_5TextConfig): return config - raise TypeError( - 'TwinkleQwen3_5 text-only models require transformers.models.qwen3_5.Qwen3_5TextConfig. ' - f'Got {type(config).__name__}.' - ) + raise TypeError('TwinkleQwen3_5 text-only models require transformers.models.qwen3_5.Qwen3_5TextConfig. ' + f'Got {type(config).__name__}.') def _ensure_linear_attention_fast_path() -> None: @@ -52,10 +48,8 @@ def _ensure_linear_attention_fast_path() -> None: if not _HAS_CAUSAL_CONV1D: missing.append('causal-conv1d') if missing: - raise ImportError( - 'TwinkleQwen3_5 linear attention requires flash-linear-attention and causal-conv1d. ' - f'Missing: {", ".join(missing)}' - ) + raise ImportError('TwinkleQwen3_5 linear attention requires flash-linear-attention and causal-conv1d. ' + f'Missing: {", ".join(missing)}') def _maybe_slice_tensor_output(output: Any) -> torch.Tensor: @@ -66,10 +60,8 @@ def _maybe_slice_tensor_output(output: Any) -> torch.Tensor: def _sp_is_enabled(sequence_parallel_context: Any | None) -> bool: return bool( - sequence_parallel_context is not None - and getattr(sequence_parallel_context, 'sp_world_size', 1) > 1 - and getattr(sequence_parallel_context, 'sp_group', None) is not None - ) + sequence_parallel_context is not None and getattr(sequence_parallel_context, 'sp_world_size', 1) > 1 + and getattr(sequence_parallel_context, 'sp_group', None) is not None) def _get_sp_rank(sequence_parallel_context: Any | None) -> int: @@ -239,8 +231,7 @@ def _apply_varlen_conv( ) -> torch.Tensor: if self.causal_conv1d_fn is None: raise ImportError( - 'TwinkleQwen3_5 linear attention requires fla.modules.convolution.causal_conv1d for prefill/train.' 
- ) + 'TwinkleQwen3_5 linear attention requires fla.modules.convolution.causal_conv1d for prefill/train.') output = self.causal_conv1d_fn( x=mixed_qkv, weight=conv_weight, @@ -261,8 +252,7 @@ def _apply_decode_conv( if self.causal_conv1d_update is None: raise ImportError( 'TwinkleQwen3_5 decode requires a causal_conv1d_update implementation from flash-linear-attention ' - 'or causal-conv1d.' - ) + 'or causal-conv1d.') mixed_qkv_t = mixed_qkv.transpose(1, 2).contiguous() output = self.causal_conv1d_update( mixed_qkv_t, @@ -291,11 +281,8 @@ def forward( hidden_states = hf_qwen35.apply_mask_to_padding_states(hidden_states, attention_mask) batch_size, seq_len, _ = hidden_states.shape use_precomputed_states = ( - cache_params is not None - and cache_params.has_previous_state - and seq_len == 1 - and cache_position is not None - ) + cache_params is not None and cache_params.has_previous_state and seq_len == 1 + and cache_position is not None) if cache_params is not None: conv_state = cache_params.conv_states[self.layer_idx] @@ -316,8 +303,7 @@ def forward( if self.num_k_heads % sp_world_size != 0 or self.num_v_heads % sp_world_size != 0: raise RuntimeError( 'TwinkleQwen3_5 linear attention requires sp_world_size to divide both ' - f'linear_num_key_heads ({self.num_k_heads}) and linear_num_value_heads ({self.num_v_heads}).' - ) + f'linear_num_key_heads ({self.num_k_heads}) and linear_num_value_heads ({self.num_v_heads}).') local_num_k_heads = self.num_k_heads // sp_world_size local_num_v_heads = self.num_v_heads // sp_world_size local_key_dim = local_num_k_heads * self.head_k_dim @@ -341,7 +327,8 @@ def forward( ), dim=-1, ) - conv_weight = self._get_local_conv1d_weight(_get_sp_rank(sequence_parallel_context), local_key_dim, local_value_dim) + conv_weight = self._get_local_conv1d_weight( + _get_sp_rank(sequence_parallel_context), local_key_dim, local_value_dim) else: local_num_k_heads = self.num_k_heads local_num_v_heads = self.num_v_heads @@ -506,8 +493,7 @@ def __init__(self, config: Qwen3_5TextConfig): super().__init__(config) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) self.layers = nn.ModuleList( - [TwinkleQwen3_5DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) + [TwinkleQwen3_5DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) self.norm = hf_qwen35.Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = hf_qwen35.Qwen3_5TextRotaryEmbedding(config=config) self.gradient_checkpointing = False @@ -569,8 +555,7 @@ def forward( if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device) if position_ids is None: position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel.py b/src/twinkle/model/transformers/strategy/sequence_parallel.py index 35b95f40..4f2720a9 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel.py @@ -37,7 +37,8 @@ def get_flattened_cu_seqlens_from_position_ids(position_ids: torch.LongTensor): row[row < 0] = 0 seq_start_indices = torch.where(row == 0)[0] if seq_start_indices.numel() == 0 or 
seq_start_indices[0].item() != 0: - seq_start_indices = torch.cat([torch.tensor([0], device=device, dtype=seq_start_indices.dtype), seq_start_indices]) + seq_start_indices = torch.cat( + [torch.tensor([0], device=device, dtype=seq_start_indices.dtype), seq_start_indices]) seq_end_indices = torch.cat([seq_start_indices[1:], torch.tensor([len(row)], device=device)]) seq_lengths = (seq_end_indices - seq_start_indices).tolist() for seq_length in seq_lengths: @@ -687,8 +688,7 @@ def prepare( self.causal_mask_func = llm_model._update_causal_mask self.attn_implementation = ( get_config_attr(model.config, '_attn_implementation') - or get_config_attr(model.config, '_attn_implementation_internal') - ) + or get_config_attr(model.config, '_attn_implementation_internal')) if not SequenceParallel._global_inited: # these operations are global initializations and patches @@ -832,8 +832,8 @@ def pad_and_split_inputs(self, cache_position = torch.arange(0, attn_shape, device=inputs.device) # SDPA/eager-style paths still expect a fully materialized causal mask here. if hasattr(self, 'causal_mask_func') and self.causal_mask_func is not None: - attention_mask = self.causal_mask_func( - attention_mask, inputs.to(self.model_dtype), cache_position, None, None) + attention_mask = self.causal_mask_func(attention_mask, inputs.to(self.model_dtype), + cache_position, None, None) if extra_split_values is not None: for (tensor, pad_value, split_dim) in extra_split_values: extra_values.append( diff --git a/tests/sequence_parallel/test_twinkle_qwen3_5_text_model.py b/tests/sequence_parallel/test_twinkle_qwen3_5_text_model.py index b638687e..ad8c8108 100644 --- a/tests/sequence_parallel/test_twinkle_qwen3_5_text_model.py +++ b/tests/sequence_parallel/test_twinkle_qwen3_5_text_model.py @@ -1,13 +1,12 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
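The cu_seqlens helper reformatted above derives flattened sequence boundaries from packed position ids. A simplified standalone sketch of that convention (single flattened row, no negative-position or missing-leading-zero handling, so not a drop-in replacement):

import torch

def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    # Packed samples restart their position ids at 0, so every 0 marks a sequence
    # start; the cumulative boundaries are those offsets plus the total length.
    row = position_ids.flatten()
    starts = torch.nonzero(row == 0, as_tuple=False).flatten()
    total = torch.tensor([row.numel()], dtype=starts.dtype, device=row.device)
    return torch.cat([starts, total]).to(torch.int32)

# cu_seqlens_from_position_ids(torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3]))
# -> tensor([0, 3, 5, 9], dtype=torch.int32)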
import tempfile +import torch import unittest from contextlib import ExitStack -from types import SimpleNamespace -from unittest.mock import patch - -import torch from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5Config, Qwen3_5TextConfig from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM +from types import SimpleNamespace +from unittest.mock import patch from twinkle.model.transformers.models.qwen3_5 import modeling_qwen3_5 as tw_qwen35 from twinkle.model.transformers.strategy.sequence_parallel import SequenceParallel, SequenceParallelContext @@ -40,13 +39,9 @@ def _build_text_config(layer_types=None) -> Qwen3_5TextConfig: def _linear_attention_runtime_available() -> bool: - return bool( - torch.cuda.is_available() - and tw_qwen35._FLA_CAUSAL_CONV1D_FN is not None - and tw_qwen35._FLA_CHUNK_GATED_DELTA_RULE is not None - and tw_qwen35._FLA_FUSED_RECURRENT_GATED_DELTA_RULE is not None - and tw_qwen35._HAS_CAUSAL_CONV1D - ) + return bool(torch.cuda.is_available() and tw_qwen35._FLA_CAUSAL_CONV1D_FN is not None + and tw_qwen35._FLA_CHUNK_GATED_DELTA_RULE is not None + and tw_qwen35._FLA_FUSED_RECURRENT_GATED_DELTA_RULE is not None and tw_qwen35._HAS_CAUSAL_CONV1D) class _ContextReceiver: @@ -233,13 +228,26 @@ def fake_conv(x, weight, bias, activation, seq_idx=None, backend=None, cu_seqlen captured['cu_seqlens'] = cu_seqlens.clone() if cu_seqlens is not None else None return x - def fake_chunk_rule(query, key, value, g, beta, initial_state=None, output_final_state=False, - use_qk_l2norm_in_kernel=False, cu_seqlens=None): + def fake_chunk_rule(query, + key, + value, + g, + beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, + cu_seqlens=None): del query, key, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel captured['cu_seqlens'] = cu_seqlens.clone() if cu_seqlens is not None else None return value, None - def fake_recurrent_rule(query, key, value, g, beta, initial_state=None, output_final_state=False, + def fake_recurrent_rule(query, + key, + value, + g, + beta, + initial_state=None, + output_final_state=False, use_qk_l2norm_in_kernel=False): del query, key, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel return value, None @@ -321,14 +329,27 @@ def fake_conv(x, weight, bias, activation, seq_idx=None, backend=None, cu_seqlen captured['cu_seqlens'] = cu_seqlens.clone() if cu_seqlens is not None else None return x - def fake_chunk_rule(query, key, value, g, beta, initial_state=None, output_final_state=False, - use_qk_l2norm_in_kernel=False, cu_seqlens=None): + def fake_chunk_rule(query, + key, + value, + g, + beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, + cu_seqlens=None): del key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel captured['query_shape'] = tuple(query.shape) captured['cu_seqlens'] = cu_seqlens.clone() if cu_seqlens is not None else None return query.new_zeros(query.shape[0], query.shape[1], 4, 4), None - def fake_recurrent_rule(query, key, value, g, beta, initial_state=None, output_final_state=False, + def fake_recurrent_rule(query, + key, + value, + g, + beta, + initial_state=None, + output_final_state=False, use_qk_l2norm_in_kernel=False): del query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel raise AssertionError('recurrent path should not be used') @@ -368,12 +389,25 @@ def fake_conv(x, weight, bias, activation, seq_idx=None, backend=None, cu_seqlen del 
weight, bias, activation, seq_idx, backend, cu_seqlens return x - def fake_chunk_rule(query, key, value, g, beta, initial_state=None, output_final_state=False, - use_qk_l2norm_in_kernel=False, cu_seqlens=None): + def fake_chunk_rule(query, + key, + value, + g, + beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, + cu_seqlens=None): del query, key, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel, cu_seqlens return value, None - def fake_recurrent_rule(query, key, value, g, beta, initial_state=None, output_final_state=False, + def fake_recurrent_rule(query, + key, + value, + g, + beta, + initial_state=None, + output_final_state=False, use_qk_l2norm_in_kernel=False): del query, key, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel return value, None @@ -395,8 +429,12 @@ def fake_recurrent_rule(query, key, value, g, beta, initial_state=None, output_f is_packed=False, )) - def fake_linear_forward(hidden_states, cache_params=None, cache_position=None, attention_mask=None, - cu_seq_lens_q=None, sequence_parallel_context=None): + def fake_linear_forward(hidden_states, + cache_params=None, + cache_position=None, + attention_mask=None, + cu_seq_lens_q=None, + sequence_parallel_context=None): del hidden_states, cache_params, cache_position, cu_seq_lens_q, sequence_parallel_context captured['mask'] = attention_mask.clone() if attention_mask is not None else None return torch.zeros(1, 2, config.hidden_size) @@ -421,8 +459,7 @@ def test_sequence_parallel_drops_dense_attention_mask_for_flash_attention_2(self sp.tokenizer = SimpleNamespace(pad_token_id=0) sp.model_dtype = torch.bfloat16 sp.attn_implementation = 'flash_attention_2' - sp.causal_mask_func = lambda *args, **kwargs: (_ for _ in ()).throw( - AssertionError('should not build 4d mask')) + sp.causal_mask_func = lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError('should not build 4d mask')) input_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.long) position_ids = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.long) diff --git a/tests/sequence_parallel/test_twinkle_qwen3_5_text_model_parity.py b/tests/sequence_parallel/test_twinkle_qwen3_5_text_model_parity.py index 5e765adb..6b3fc1d1 100644 --- a/tests/sequence_parallel/test_twinkle_qwen3_5_text_model_parity.py +++ b/tests/sequence_parallel/test_twinkle_qwen3_5_text_model_parity.py @@ -3,16 +3,15 @@ import os import socket import tempfile -import traceback -import unittest -from datetime import timedelta -from types import SimpleNamespace - import torch import torch.distributed as dist import torch.multiprocessing as mp +import traceback +import unittest +from datetime import timedelta from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig from transformers.utils.import_utils import is_flash_attn_2_available +from types import SimpleNamespace from twinkle.model.transformers.models.qwen3_5 import modeling_qwen3_5 as tw_qwen35 from twinkle.model.transformers.strategy.sequence_parallel import SequenceParallel, SequenceParallelContext @@ -81,13 +80,9 @@ def _build_mixed_parity_config() -> Qwen3_5TextConfig: def _linear_attention_runtime_available() -> bool: - return bool( - torch.cuda.is_available() - and tw_qwen35._FLA_CAUSAL_CONV1D_FN is not None - and tw_qwen35._FLA_CHUNK_GATED_DELTA_RULE is not None - and tw_qwen35._FLA_FUSED_RECURRENT_GATED_DELTA_RULE is not None - and tw_qwen35._HAS_CAUSAL_CONV1D - ) + return bool(torch.cuda.is_available() and 
tw_qwen35._FLA_CAUSAL_CONV1D_FN is not None + and tw_qwen35._FLA_CHUNK_GATED_DELTA_RULE is not None + and tw_qwen35._FLA_FUSED_RECURRENT_GATED_DELTA_RULE is not None and tw_qwen35._HAS_CAUSAL_CONV1D) def _find_free_port() -> int: @@ -154,7 +149,9 @@ def _run_linear_attention_parity_worker(rank: int, world_size: int, port: int, e _seed_everything(seed) config = _build_linear_parity_config() - baseline_module = tw_qwen35.TwinkleQwen3_5GatedDeltaNet(config, layer_idx=0).to(device=device, dtype=dtype).eval() + baseline_module = tw_qwen35.TwinkleQwen3_5GatedDeltaNet( + config, layer_idx=0).to( + device=device, dtype=dtype).eval() sp_module = copy.deepcopy(baseline_module).to(device=device, dtype=dtype).eval() full_hidden_states = torch.randn(batch_size, seq_len, config.hidden_size, device=device, dtype=dtype) @@ -174,8 +171,7 @@ def _run_linear_attention_parity_worker(rank: int, world_size: int, port: int, e baseline_input_grad = baseline_hidden_states.grad.detach() baseline_param_grads = { name: param.grad.detach().clone() - for name, param in baseline_module.named_parameters() - if param.grad is not None + for name, param in baseline_module.named_parameters() if param.grad is not None } sp_hidden_states = full_hidden_states[:, start:end].detach().clone().requires_grad_(True) @@ -208,9 +204,9 @@ def _run_linear_attention_parity_worker(rank: int, world_size: int, port: int, e _assert_relative_error(sp_input_grad_full, baseline_input_grad, 1e-2, 'linear_attention.input_grad') for name in ( - 'in_proj_qkv.weight', - 'in_proj_z.weight', - 'out_proj.weight', + 'in_proj_qkv.weight', + 'in_proj_z.weight', + 'out_proj.weight', ): _assert_relative_error( sp_module.get_parameter(name).grad, @@ -270,8 +266,7 @@ def _run_mixed_text_model_parity_worker(rank: int, world_size: int, port: int, e baseline_input_grad = baseline_inputs_embeds.grad.detach() baseline_param_grads = { name: param.grad.detach().clone() - for name, param in baseline_model.named_parameters() - if param.grad is not None + for name, param in baseline_model.named_parameters() if param.grad is not None } device_mesh = DeviceMesh.from_sizes( @@ -306,9 +301,9 @@ def _run_mixed_text_model_parity_worker(rank: int, world_size: int, port: int, e _assert_relative_error(sp_inputs_embeds.grad, baseline_input_grad, 1e-2, 'mixed_text_model.input_grad') for name in ( - 'layers.0.self_attn.q_proj.weight', - 'layers.1.linear_attn.in_proj_qkv.weight', - 'layers.1.mlp.gate_proj.weight', + 'layers.0.self_attn.q_proj.weight', + 'layers.1.linear_attn.in_proj_qkv.weight', + 'layers.1.mlp.gate_proj.weight', ): _assert_relative_error( sp_model.get_parameter(name).grad, @@ -344,7 +339,7 @@ def _run_spawned_parity_test(self, worker) -> None: for rank in range(self.WORLD_SIZE): error_path = f'{error_prefix}.rank{rank}.err' if os.path.exists(error_path): - with open(error_path, 'r', encoding='utf-8') as f: + with open(error_path, encoding='utf-8') as f: error_logs.append(f'Rank {rank}:\n{f.read()}') if error_logs: self.fail('\n\n'.join(error_logs))
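The parity workers compare sharded outputs and gradients against the single-rank baseline through _assert_relative_error, whose definition sits outside these hunks. The sketch below shows the kind of norm-based check such a helper typically performs; it is an illustration, not the project's implementation.

import torch

def assert_relative_error(actual: torch.Tensor, expected: torch.Tensor, tol: float, name: str) -> None:
    # Relative Frobenius-norm error against the baseline; the clamp avoids a zero
    # denominator for all-zero reference tensors.
    denom = expected.float().norm().clamp_min(1e-12)
    rel = (actual.float() - expected.float()).norm() / denom
    if rel.item() > tol:
        raise AssertionError(f'{name}: relative error {rel.item():.3e} exceeds tolerance {tol:.1e}')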