diff --git a/cookbook/transformers/sp_fsdp_dense.py b/cookbook/transformers/sp_fsdp_dense.py index 868b61c0..81eaf60d 100644 --- a/cookbook/transformers/sp_fsdp_dense.py +++ b/cookbook/transformers/sp_fsdp_dense.py @@ -1,12 +1,15 @@ -import numpy as np +import math from functools import partial + +import numpy as np from peft import LoraConfig import twinkle -from twinkle import DeviceGroup, DeviceMesh, Platform, get_logger +from twinkle import DeviceGroup, DeviceMesh, get_logger,Platform from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.model import TransformersModel +from twinkle.model.transformers.models import TwinkleQwen3_5ForCausalLM from twinkle.preprocessor import SelfCognitionProcessor logger = get_logger() @@ -64,29 +67,41 @@ def train(): model = TransformersModel( model_id=MODEL_ID, + model_cls=TwinkleQwen3_5ForCausalLM, device_mesh=device_mesh, strategy='native_fsdp', + attn_implementation='flash_attention_2' ) - lora_config = LoraConfig(target_modules='all-linear') - model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=1) + lora_config = LoraConfig(target_modules='all-linear', lora_dropout=0.0) + model.add_adapter_to_model('default', lora_config) + grad_accumulation_steps = model.optimizer_group['default'].gradient_accumulation_steps + num_optimizer_steps = math.ceil(len(dataloader) / grad_accumulation_steps) + log_every_optimizer_steps = 20 model.set_optimizer('AdamW', lr=1e-4, adapter_name='default') model.set_lr_scheduler( scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, - num_training_steps=len(dataloader), + num_training_steps=num_optimizer_steps, adapter_name='default', ) logger.info(model.get_train_configs(adapter_name='default')) - logger.info(f'Total steps: {len(dataloader)}') + logger.info( + f'Total micro steps: {len(dataloader)}, optimizer steps: {num_optimizer_steps}, ' + f'gradient_accumulation_steps: {grad_accumulation_steps}') for step, batch in enumerate(dataloader): model.forward_backward(inputs=batch, adapter_name='default') model.clip_grad_and_step(adapter_name='default') - if step % 20 == 0: + optimizer_step = step // grad_accumulation_steps + is_optimizer_boundary = (step + 1) % grad_accumulation_steps == 0 + if is_optimizer_boundary and optimizer_step % log_every_optimizer_steps == 0: metric = model.calculate_metric(is_training=True, adapter_name='default') - logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') + optimizer_step = metric.get('iters', optimizer_step) + logger.info( + f'Current is optimizer step {optimizer_step} of {num_optimizer_steps} ' + f'(micro step {step} of {len(dataloader)}), metric: {metric}') model.save('last-checkpoint', interval=1) diff --git a/src/twinkle/dataloader/dataloader.py b/src/twinkle/dataloader/dataloader.py index b3ce4f0f..0a591098 100644 --- a/src/twinkle/dataloader/dataloader.py +++ b/src/twinkle/dataloader/dataloader.py @@ -45,14 +45,31 @@ def __init__(self, self.max_retries = kwargs.pop('max_retries', 20) self.min_batch_size = min_batch_size if device_mesh is not None: - assert batch_size >= device_mesh.data_world_size and batch_size % device_mesh.data_world_size == 0 - self.batch_size = batch_size + required_world_size = self._required_data_world_size(device_mesh) + assert batch_size >= required_world_size and batch_size % required_world_size == 0 + self.batch_size = self._resolve_runtime_batch_size(batch_size, device_mesh) self.dataloader_params = kwargs - self.dataloader_params['batch_size'] = 
batch_size + self.dataloader_params['batch_size'] = self.batch_size self.device_mesh = device_mesh self.processor: Optional[InputProcessor] = None self._set_work_init_fn() + @staticmethod + def _required_data_world_size(device_mesh: Optional[DeviceMesh]) -> int: + if device_mesh is None: + return 1 + if (device_mesh.ulysses_size or 1) > 1: + return device_mesh.raw_data_world_size + return device_mesh.data_world_size + + def _resolve_runtime_batch_size(self, batch_size: int, device_mesh: Optional[DeviceMesh]) -> int: + if device_mesh is None: + return batch_size + ulysses_size = device_mesh.ulysses_size or 1 + if ulysses_size <= 1: + return batch_size + return batch_size // ulysses_size + def _set_work_init_fn(self): num_workers = self.dataloader_params.get('num_workers', 2) self.dataloader_params['worker_init_fn'] = partial( diff --git a/src/twinkle/model/transformers/__init__.py b/src/twinkle/model/transformers/__init__.py index 9ffe9866..afd16934 100644 --- a/src/twinkle/model/transformers/__init__.py +++ b/src/twinkle/model/transformers/__init__.py @@ -1,3 +1,32 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -from .multi_lora_transformers import MultiLoraTransformersModel -from .transformers import TransformersModel +from typing import TYPE_CHECKING + +from twinkle.utils.import_utils import _LazyModule + +if TYPE_CHECKING: + from .models import (TwinkleQwen3_5DecoderLayer, TwinkleQwen3_5ForCausalLM, TwinkleQwen3_5GatedDeltaNet, + TwinkleQwen3_5PreTrainedModel, TwinkleQwen3_5TextModel) + from .multi_lora_transformers import MultiLoraTransformersModel + from .transformers import TransformersModel +else: + _import_structure = { + 'transformers': ['TransformersModel'], + 'multi_lora_transformers': ['MultiLoraTransformersModel'], + 'models': [ + 'TwinkleQwen3_5PreTrainedModel', + 'TwinkleQwen3_5TextModel', + 'TwinkleQwen3_5DecoderLayer', + 'TwinkleQwen3_5GatedDeltaNet', + 'TwinkleQwen3_5ForCausalLM', + ], + } + + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, # noqa + extra_objects={}, + ) diff --git a/src/twinkle/model/transformers/models/__init__.py b/src/twinkle/model/transformers/models/__init__.py new file mode 100644 index 00000000..8c84298c --- /dev/null +++ b/src/twinkle/model/transformers/models/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from .qwen3_5 import (TwinkleQwen3_5DecoderLayer, TwinkleQwen3_5ForCausalLM, TwinkleQwen3_5GatedDeltaNet, + TwinkleQwen3_5PreTrainedModel, TwinkleQwen3_5TextModel) + +__all__ = [ + 'TwinkleQwen3_5PreTrainedModel', + 'TwinkleQwen3_5TextModel', + 'TwinkleQwen3_5DecoderLayer', + 'TwinkleQwen3_5GatedDeltaNet', + 'TwinkleQwen3_5ForCausalLM', +] diff --git a/src/twinkle/model/transformers/models/qwen3_5/__init__.py b/src/twinkle/model/transformers/models/qwen3_5/__init__.py new file mode 100644 index 00000000..60c8a808 --- /dev/null +++ b/src/twinkle/model/transformers/models/qwen3_5/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
+from .modeling_qwen3_5 import (TwinkleQwen3_5DecoderLayer, TwinkleQwen3_5ForCausalLM, TwinkleQwen3_5GatedDeltaNet, + TwinkleQwen3_5PreTrainedModel, TwinkleQwen3_5TextModel) + +__all__ = [ + 'TwinkleQwen3_5PreTrainedModel', + 'TwinkleQwen3_5TextModel', + 'TwinkleQwen3_5DecoderLayer', + 'TwinkleQwen3_5GatedDeltaNet', + 'TwinkleQwen3_5ForCausalLM', +] diff --git a/src/twinkle/model/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/twinkle/model/transformers/models/qwen3_5/modeling_qwen3_5.py new file mode 100644 index 00000000..4f7ea33e --- /dev/null +++ b/src/twinkle/model/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -0,0 +1,674 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from __future__ import annotations + +import importlib.util +import torch +import torch.nn.functional as F +from torch import nn +from transformers.cache_utils import Cache +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.models.qwen3_5 import modeling_qwen3_5 as hf_qwen35 +from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5Config, Qwen3_5TextConfig +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs, can_return_tuple +from transformers.utils.generic import merge_with_config_defaults +from transformers.utils.output_capturing import capture_outputs +from typing import Any, Callable, Optional + +try: + from fla.modules import FusedRMSNormGated as _FLA_FUSED_RMS_NORM_GATED + from fla.modules.convolution import causal_conv1d as _FLA_CAUSAL_CONV1D_FN + from fla.modules.convolution import causal_conv1d_update as _FLA_CAUSAL_CONV1D_UPDATE + from fla.ops.gated_delta_rule import chunk_gated_delta_rule as _FLA_CHUNK_GATED_DELTA_RULE + from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule as _FLA_FUSED_RECURRENT_GATED_DELTA_RULE +except ImportError: + _FLA_FUSED_RMS_NORM_GATED = None + _FLA_CAUSAL_CONV1D_FN = None + _FLA_CAUSAL_CONV1D_UPDATE = None + _FLA_CHUNK_GATED_DELTA_RULE = None + _FLA_FUSED_RECURRENT_GATED_DELTA_RULE = None + +_HAS_CAUSAL_CONV1D = importlib.util.find_spec('causal_conv1d') is not None + + +def _ensure_text_config(config: Qwen3_5TextConfig) -> Qwen3_5TextConfig: + if isinstance(config, Qwen3_5TextConfig): + return config + raise TypeError('TwinkleQwen3_5 text-only models require transformers.models.qwen3_5.Qwen3_5TextConfig. ' + f'Got {type(config).__name__}.') + + +def _ensure_linear_attention_fast_path() -> None: + missing = [] + if _FLA_CAUSAL_CONV1D_FN is None: + missing.append('fla.modules.convolution.causal_conv1d') + if _FLA_CHUNK_GATED_DELTA_RULE is None or _FLA_FUSED_RECURRENT_GATED_DELTA_RULE is None: + missing.append('fla.ops.gated_delta_rule') + if not _HAS_CAUSAL_CONV1D: + missing.append('causal-conv1d') + if missing: + raise ImportError('TwinkleQwen3_5 linear attention requires flash-linear-attention and causal-conv1d. 
' + f'Missing: {", ".join(missing)}') + + +def _maybe_slice_tensor_output(output: Any) -> torch.Tensor: + if isinstance(output, tuple): + return output[0] + return output + + +def _sp_is_enabled(sequence_parallel_context: Any | None) -> bool: + return bool( + sequence_parallel_context is not None and getattr(sequence_parallel_context, 'sp_world_size', 1) > 1 + and getattr(sequence_parallel_context, 'sp_group', None) is not None) + + +def _get_sp_rank(sequence_parallel_context: Any | None) -> int: + if not _sp_is_enabled(sequence_parallel_context): + return 0 + rank = getattr(sequence_parallel_context, 'rank', None) + if rank is not None: + return int(rank) + import torch.distributed as dist + + return dist.get_rank(sequence_parallel_context.sp_group) + + +def _seq_to_head_shard(tensor: torch.Tensor, sequence_parallel_context: Any | None) -> torch.Tensor: + if not _sp_is_enabled(sequence_parallel_context): + return tensor + from twinkle.model.transformers.strategy.sequence_parallel import _SeqAllToAll + + if tensor.dim() == 3: + return _SeqAllToAll.apply(sequence_parallel_context.sp_group, tensor.unsqueeze(-1), 2, 1).squeeze(-1) + return _SeqAllToAll.apply(sequence_parallel_context.sp_group, tensor, 2, 1) + + +def _head_to_seq_shard(tensor: torch.Tensor, sequence_parallel_context: Any | None) -> torch.Tensor: + if not _sp_is_enabled(sequence_parallel_context): + return tensor + from twinkle.model.transformers.strategy.sequence_parallel import _SeqAllToAll + + if tensor.dim() == 3: + return _SeqAllToAll.apply(sequence_parallel_context.sp_group, tensor.unsqueeze(-1), 1, 2).squeeze(-1) + return _SeqAllToAll.apply(sequence_parallel_context.sp_group, tensor, 1, 2) + + +def _resolve_local_padding_mask( + attention_mask: torch.Tensor | None, + seq_len: int, + sequence_parallel_context: Any | None = None, +) -> torch.Tensor | None: + if attention_mask is None or attention_mask.dim() != 2: + return attention_mask + if attention_mask.shape[-1] == seq_len: + return attention_mask + if _sp_is_enabled(sequence_parallel_context): + sp_world_size = int(sequence_parallel_context.sp_world_size) + full_seq_len = attention_mask.shape[-1] + if full_seq_len % sp_world_size == 0: + local_seq_len = full_seq_len // sp_world_size + if local_seq_len == seq_len: + sp_rank = _get_sp_rank(sequence_parallel_context) + start = sp_rank * local_seq_len + end = start + local_seq_len + return attention_mask[:, start:end].contiguous() + if attention_mask.shape[-1] >= seq_len: + return attention_mask[:, :seq_len].contiguous() + return attention_mask + + +def _flatten_varlen_batch(tensor: torch.Tensor) -> torch.Tensor: + return tensor.reshape(1, tensor.shape[0] * tensor.shape[1], *tensor.shape[2:]) + + +def _pad_or_trim_2d_tensor(tensor: torch.Tensor | None, target_len: int, pad_value: int) -> torch.Tensor | None: + if tensor is None: + return None + if tensor.dim() == 3: + tensor = tensor[0] + if tensor.shape[-1] == target_len: + return tensor + if tensor.shape[-1] > target_len: + return tensor[..., :target_len].contiguous() + pad_shape = (*tensor.shape[:-1], target_len - tensor.shape[-1]) + pad_tensor = torch.full(pad_shape, pad_value, dtype=tensor.dtype, device=tensor.device) + return torch.cat((tensor, pad_tensor), dim=-1) + + +def _build_varlen_metadata( + *, + position_ids: torch.Tensor | None, + attention_mask: torch.Tensor | None, + full_seq_len: int, +) -> tuple[torch.Tensor, torch.Tensor]: + position_ids = _pad_or_trim_2d_tensor(position_ids, full_seq_len, pad_value=-1) + + if position_ids is not None: + valid_mask = 
position_ids != -1 + elif attention_mask is not None: + attention_mask = _pad_or_trim_2d_tensor(attention_mask, full_seq_len, pad_value=0) + valid_mask = attention_mask != 0 + else: + raise ValueError('Varlen metadata requires at least one of position_ids or attention_mask.') + + cu_seqlens = [0] + total = 0 + for row_idx in range(valid_mask.shape[0]): + if position_ids is not None: + valid_positions = position_ids[row_idx][valid_mask[row_idx]] + if valid_positions.numel() == 0: + continue + seq_start_indices = torch.where(valid_positions == 0)[0] + if seq_start_indices.numel() == 0 or seq_start_indices[0].item() != 0: + seq_start_indices = torch.cat([ + torch.tensor([0], device=valid_positions.device, dtype=seq_start_indices.dtype), + seq_start_indices, + ]) + seq_end_indices = torch.cat([ + seq_start_indices[1:], + torch.tensor([valid_positions.numel()], device=valid_positions.device, dtype=seq_start_indices.dtype), + ]) + seq_lengths = (seq_end_indices - seq_start_indices).tolist() + else: + seq_lengths = [int(valid_mask[row_idx].sum().item())] + for seq_length in seq_lengths: + if seq_length <= 0: + continue + total += int(seq_length) + cu_seqlens.append(total) + return valid_mask, torch.tensor(cu_seqlens, device=valid_mask.device, dtype=torch.int32) + + +def _pack_varlen_tensor(tensor: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor: + return tensor[valid_mask].unsqueeze(0) + + +def _unpack_varlen_tensor( + packed_tensor: torch.Tensor, + valid_mask: torch.Tensor, + batch_size: int, + seq_len: int, +) -> torch.Tensor: + output = packed_tensor.new_zeros((batch_size, seq_len, *packed_tensor.shape[2:])) + output[valid_mask] = packed_tensor.squeeze(0) + return output + + +class TwinkleQwen3_5GatedDeltaNet(hf_qwen35.Qwen3_5GatedDeltaNet): + + def __init__(self, config: Qwen3_5TextConfig, layer_idx: int): + _ensure_linear_attention_fast_path() + super().__init__(config, layer_idx) + self.causal_conv1d_fn = _FLA_CAUSAL_CONV1D_FN + self.causal_conv1d_update = _FLA_CAUSAL_CONV1D_UPDATE or hf_qwen35.causal_conv1d_update + self.chunk_gated_delta_rule = _FLA_CHUNK_GATED_DELTA_RULE + self.recurrent_gated_delta_rule = _FLA_FUSED_RECURRENT_GATED_DELTA_RULE + if _FLA_FUSED_RMS_NORM_GATED is not None and torch.cuda.is_available(): + self.norm = _FLA_FUSED_RMS_NORM_GATED( + self.head_v_dim, + eps=self.layer_norm_epsilon, + activation=self.activation, + device=torch.cuda.current_device(), + dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(), + ) + + def _get_local_conv1d_weight(self, sp_rank: int, local_key_dim: int, local_value_dim: int) -> torch.Tensor: + w_full = self.conv1d.weight.squeeze(1) + key_offset = sp_rank * local_key_dim + value_offset = sp_rank * local_value_dim + w_q = w_full[key_offset:key_offset + local_key_dim] + w_k = w_full[self.key_dim + key_offset:self.key_dim + key_offset + local_key_dim] + w_v = w_full[2 * self.key_dim + value_offset:2 * self.key_dim + value_offset + local_value_dim] + return torch.cat((w_q, w_k, w_v), dim=0) + + def _apply_varlen_conv( + self, + mixed_qkv: torch.Tensor, + conv_weight: torch.Tensor, + cu_seq_lens_q: torch.Tensor | None, + ) -> torch.Tensor: + if self.causal_conv1d_fn is None: + raise ImportError( + 'TwinkleQwen3_5 linear attention requires fla.modules.convolution.causal_conv1d for prefill/train.') + output = self.causal_conv1d_fn( + x=mixed_qkv, + weight=conv_weight, + bias=self.conv1d.bias, + activation=self.activation, + seq_idx=None, + backend='triton', + cu_seqlens=cu_seq_lens_q, + ) + return 
_maybe_slice_tensor_output(output) + + def _apply_decode_conv( + self, + mixed_qkv: torch.Tensor, + conv_state: torch.Tensor, + conv_weight: torch.Tensor, + ) -> torch.Tensor: + if self.causal_conv1d_update is None: + raise ImportError( + 'TwinkleQwen3_5 decode requires a causal_conv1d_update implementation from flash-linear-attention ' + 'or causal-conv1d.') + mixed_qkv_t = mixed_qkv.transpose(1, 2).contiguous() + output = self.causal_conv1d_update( + mixed_qkv_t, + conv_state, + conv_weight, + self.conv1d.bias, + self.activation, + ) + output = _maybe_slice_tensor_output(output) + if output.dim() == 2: + output = output.unsqueeze(1) + elif output.dim() == 3 and output.shape[1] == conv_weight.shape[0]: + output = output.transpose(1, 2).contiguous() + return output + + def forward( + self, + hidden_states: torch.Tensor, + cache_params: hf_qwen35.Qwen3_5DynamicCache | None = None, + cache_position: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + cu_seq_lens_q: torch.Tensor | None = None, + sequence_parallel_context: Any | None = None, + ): + attention_mask = _resolve_local_padding_mask(attention_mask, hidden_states.shape[1], sequence_parallel_context) + hidden_states = hf_qwen35.apply_mask_to_padding_states(hidden_states, attention_mask) + batch_size, seq_len, _ = hidden_states.shape + use_precomputed_states = ( + cache_params is not None and cache_params.has_previous_state and seq_len == 1 + and cache_position is not None) + + if cache_params is not None: + conv_state = cache_params.conv_states[self.layer_idx] + recurrent_state = cache_params.recurrent_states[self.layer_idx] + else: + conv_state = None + recurrent_state = None + + mixed_qkv = self.in_proj_qkv(hidden_states) + z = self.in_proj_z(hidden_states).reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim) + b = self.in_proj_b(hidden_states) + a = self.in_proj_a(hidden_states) + full_attention_mask = attention_mask + + sp_enabled = _sp_is_enabled(sequence_parallel_context) + if sp_enabled: + sp_world_size = int(sequence_parallel_context.sp_world_size) + if self.num_k_heads % sp_world_size != 0 or self.num_v_heads % sp_world_size != 0: + raise RuntimeError( + 'TwinkleQwen3_5 linear attention requires sp_world_size to divide both ' + f'linear_num_key_heads ({self.num_k_heads}) and linear_num_value_heads ({self.num_v_heads}).') + local_num_k_heads = self.num_k_heads // sp_world_size + local_num_v_heads = self.num_v_heads // sp_world_size + local_key_dim = local_num_k_heads * self.head_k_dim + local_value_dim = local_num_v_heads * self.head_v_dim + + q_proj, k_proj, v_proj = torch.split(mixed_qkv, [self.key_dim, self.key_dim, self.value_dim], dim=-1) + q_proj = q_proj.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim) + k_proj = k_proj.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim) + v_proj = v_proj.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim) + q_proj = _seq_to_head_shard(q_proj, sequence_parallel_context) + k_proj = _seq_to_head_shard(k_proj, sequence_parallel_context) + v_proj = _seq_to_head_shard(v_proj, sequence_parallel_context) + b = _seq_to_head_shard(b.reshape(batch_size, seq_len, self.num_v_heads), sequence_parallel_context) + a = _seq_to_head_shard(a.reshape(batch_size, seq_len, self.num_v_heads), sequence_parallel_context) + + mixed_qkv = torch.cat( + ( + q_proj.reshape(batch_size, q_proj.shape[1], local_key_dim), + k_proj.reshape(batch_size, k_proj.shape[1], local_key_dim), + v_proj.reshape(batch_size, v_proj.shape[1], 
local_value_dim), + ), + dim=-1, + ) + conv_weight = self._get_local_conv1d_weight( + _get_sp_rank(sequence_parallel_context), local_key_dim, local_value_dim) + else: + local_num_k_heads = self.num_k_heads + local_num_v_heads = self.num_v_heads + local_key_dim = self.key_dim + local_value_dim = self.value_dim + b = b.reshape(batch_size, seq_len, self.num_v_heads) + a = a.reshape(batch_size, seq_len, self.num_v_heads) + conv_weight = self.conv1d.weight.squeeze(1) + + packed_valid_mask = None + packed_cu_seqlens = cu_seq_lens_q + packed_seq_len = mixed_qkv.shape[1] + use_varlen_pack = cu_seq_lens_q is not None and not use_precomputed_states + if use_varlen_pack: + full_position_ids = getattr(sequence_parallel_context, 'real_position_ids', None) + packed_valid_mask, packed_cu_seqlens = _build_varlen_metadata( + position_ids=full_position_ids, + attention_mask=full_attention_mask, + full_seq_len=packed_seq_len, + ) + mixed_qkv = _pack_varlen_tensor(mixed_qkv, packed_valid_mask) + b = _pack_varlen_tensor(b, packed_valid_mask) + a = _pack_varlen_tensor(a, packed_valid_mask) + + if use_precomputed_states: + if conv_state is None: + raise RuntimeError('Qwen3.5 decode requires initialized convolution state.') + mixed_qkv = self._apply_decode_conv(mixed_qkv, conv_state, conv_weight) + else: + if cache_params is not None: + cache_params.conv_states[self.layer_idx] = F.pad( + mixed_qkv.transpose(1, 2).contiguous(), + (self.conv_kernel_size - mixed_qkv.shape[1], 0), + ) + mixed_qkv = self._apply_varlen_conv(mixed_qkv, conv_weight, packed_cu_seqlens) + + query, key, value = torch.split(mixed_qkv, [local_key_dim, local_key_dim, local_value_dim], dim=-1) + qkv_batch_size = 1 if use_varlen_pack else batch_size + query = query.reshape(qkv_batch_size, query.shape[1], local_num_k_heads, self.head_k_dim) + key = key.reshape(qkv_batch_size, key.shape[1], local_num_k_heads, self.head_k_dim) + value = value.reshape(qkv_batch_size, value.shape[1], local_num_v_heads, self.head_v_dim) + + beta = b.sigmoid() + if sp_enabled: + head_offset = _get_sp_rank(sequence_parallel_context) * local_num_v_heads + head_slice = slice(head_offset, head_offset + local_num_v_heads) + g = -self.A_log[head_slice].float().exp() * F.softplus(a.float() + self.dt_bias[head_slice]) + else: + g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + + if self.num_v_heads // self.num_k_heads > 1: + repeat = self.num_v_heads // self.num_k_heads + query = query.repeat_interleave(repeat, dim=2) + key = key.repeat_interleave(repeat, dim=2) + + if use_precomputed_states: + core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + ) + else: + core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + cu_seqlens=packed_cu_seqlens, + ) + + if cache_params is not None: + cache_params.recurrent_states[self.layer_idx] = last_recurrent_state + + if use_varlen_pack: + core_attn_out = _unpack_varlen_tensor(core_attn_out, packed_valid_mask, batch_size, packed_seq_len) + core_attn_out = _head_to_seq_shard(core_attn_out, sequence_parallel_context) + core_attn_out = self.norm(core_attn_out.reshape(-1, self.head_v_dim), z.reshape(-1, self.head_v_dim)) + core_attn_out = core_attn_out.reshape(batch_size, seq_len, self.value_dim) 
+ return self.out_proj(core_attn_out) + + +class TwinkleQwen3_5DecoderLayer(hf_qwen35.Qwen3_5DecoderLayer): + + def __init__(self, config: Qwen3_5TextConfig, layer_idx: int): + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + self.layer_type = config.layer_types[layer_idx] + if self.layer_type == 'linear_attention': + self.linear_attn = TwinkleQwen3_5GatedDeltaNet(config, layer_idx) + elif self.layer_type == 'full_attention': + self.self_attn = hf_qwen35.Qwen3_5Attention(config, layer_idx) + else: + raise ValueError(f'Unsupported Qwen3.5 layer_type={self.layer_type!r}') + self.mlp = hf_qwen35.Qwen3_5MLP(config, config.intermediate_size) + self.input_layernorm = hf_qwen35.Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = hf_qwen35.Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + cache_position: torch.LongTensor | None = None, + cu_seq_lens_q: torch.Tensor | None = None, + sequence_parallel_context: Any | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.FloatTensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + if self.layer_type == 'linear_attention': + hidden_states = self.linear_attn( + hidden_states=hidden_states, + cache_params=past_key_values, + cache_position=cache_position, + attention_mask=attention_mask, + cu_seq_lens_q=cu_seq_lens_q, + sequence_parallel_context=sequence_parallel_context, + ) + else: + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + return residual + hidden_states + + +class TwinkleQwen3_5PreTrainedModel(hf_qwen35.Qwen3_5PreTrainedModel): + config_class = Qwen3_5TextConfig + _no_split_modules = ['TwinkleQwen3_5DecoderLayer'] + _can_record_outputs = { + 'hidden_states': TwinkleQwen3_5DecoderLayer, + 'attentions': hf_qwen35.Qwen3_5Attention, + } + + +class TwinkleQwen3_5TextModel(TwinkleQwen3_5PreTrainedModel): + + def __init__(self, config: Qwen3_5TextConfig): + config = _ensure_text_config(config) + super().__init__(config) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + self.layers = nn.ModuleList( + [TwinkleQwen3_5DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = hf_qwen35.Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = hf_qwen35.Qwen3_5TextRotaryEmbedding(config=config) + self.gradient_checkpointing = False + self._sequence_parallel_context = None + self.requires_cu_seq_lens_q = any(layer_type == 'linear_attention' for layer_type in config.layer_types) + self.post_init() + + def set_sequence_parallel_context(self, context: Any | None) -> None: + self._sequence_parallel_context = context + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def _update_linear_attn_mask(self, attention_mask, 
cache_position): + linear_attn_mask = attention_mask + if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)): + linear_attn_mask = None + return linear_attn_mask + + def _resolve_linear_attn_mask( + self, + attention_mask: torch.Tensor | None, + text_position_ids: torch.LongTensor | None, + seq_len: int, + ) -> torch.Tensor | None: + if attention_mask is not None and attention_mask.dim() == 2 and attention_mask.shape[-1] == seq_len: + return attention_mask + if text_position_ids is None: + return attention_mask if attention_mask is not None and attention_mask.dim() == 2 else None + dtype = attention_mask.dtype if attention_mask is not None else torch.int64 + return (text_position_ids != -1).to(dtype=dtype) + + @merge_with_config_defaults + @capture_outputs + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + cache_position: torch.LongTensor | None = None, + cu_seq_lens_q: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError('You must specify exactly one of input_ids or inputs_embeds') + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = hf_qwen35.Qwen3_5DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device) + + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + position_ids = position_ids[1:] + else: + text_position_ids = position_ids[0] + + causal_mask = hf_qwen35.create_causal_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=text_position_ids, + ) + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + sp_context = self._sequence_parallel_context + linear_attn_mask = self._update_linear_attn_mask( + self._resolve_linear_attn_mask(attention_mask, text_position_ids, hidden_states.shape[1]), + cache_position, + ) + if _sp_is_enabled(sp_context) and self.requires_cu_seq_lens_q and cu_seq_lens_q is None: + raise ValueError('TwinkleQwen3_5TextModel requires cu_seq_lens_q when sequence parallel is enabled.') + + for decoder_layer in self.layers[:self.config.num_hidden_layers]: + layer_mask = linear_attn_mask if decoder_layer.layer_type == 'linear_attention' else causal_mask + hidden_states = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=layer_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + cu_seq_lens_q=cu_seq_lens_q if decoder_layer.layer_type == 'linear_attention' else None, + sequence_parallel_context=sp_context if decoder_layer.layer_type == 
'linear_attention' else None, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return hf_qwen35.Qwen3_5ModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +class TwinkleQwen3_5ForCausalLM(TwinkleQwen3_5PreTrainedModel, GenerationMixin): + _tied_weights_keys = {'lm_head.weight': 'model.embed_tokens.weight'} + _tp_plan = {'lm_head': 'colwise_gather_output'} + _pp_plan = {'lm_head': (['hidden_states'], ['logits'])} + _keys_to_ignore_on_load_unexpected = [r'^mtp.*', r'^model\.visual.*'] + + def __init__(self, config: Qwen3_5TextConfig): + config = _ensure_text_config(config) + super().__init__(config) + self.model = TwinkleQwen3_5TextModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def set_sequence_parallel_context(self, context: Any | None) -> None: + self.model.set_sequence_parallel_context(context) + + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + cache_position: torch.LongTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, + cu_seq_lens_q: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + cu_seq_lens_q=cu_seq_lens_q, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py index 6033d943..a0a35562 100644 --- a/src/twinkle/model/transformers/multi_lora_transformers.py +++ b/src/twinkle/model/transformers/multi_lora_transformers.py @@ -14,7 +14,7 @@ from twinkle.processor import InputProcessor from ..multi_lora import MultiLora from .strategy import AccelerateStrategy -from .transformers import OptimizerGroup, TransformersModel +from .transformers import OptimizerGroup, TransformersModel, _default_gradient_accumulation_steps_for_device_mesh @remote_class() @@ -184,7 +184,9 @@ def add_adapter_to_model(self, adapter_name: str, config_or_dir: Union[PeftConfi self.optimizer_group[adapter_name] = self._construct_default_optimizer_group() self.optimizer_group[adapter_name].adapter_name = adapter_name self.optimizer_group[adapter_name].adapter_config = config_or_dir - _gas_default = kwargs.get('gradient_accumulation_steps', 1) + _gas_default = 
kwargs.get('gradient_accumulation_steps') + if _gas_default is None: + _gas_default = _default_gradient_accumulation_steps_for_device_mesh(self.device_mesh) self.optimizer_group[adapter_name].gradient_accumulation_steps = _gas_default self._default_tokenizer = self.optimizer_group[adapter_name].template.processor self.multi_adapter.acquire_lora(tenant_adapter_name=adapter_name, config=config_or_dir) diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel.py b/src/twinkle/model/transformers/strategy/sequence_parallel.py index 64ea34f3..4f2720a9 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel.py @@ -23,6 +23,40 @@ def get_cu_seqlens_from_position_ids(position_ids: torch.LongTensor): return cu_seqlens +def get_flattened_cu_seqlens_from_position_ids(position_ids: torch.LongTensor): + if position_ids.dim() == 1: + position_ids = position_ids.unsqueeze(0) + if position_ids.dim() != 2: + raise ValueError(f'Expected 1D or 2D position_ids, got shape={tuple(position_ids.shape)}') + + device = position_ids.device + cu_seqlens = [0] + total = 0 + for row in position_ids: + row = row.clone() + row[row < 0] = 0 + seq_start_indices = torch.where(row == 0)[0] + if seq_start_indices.numel() == 0 or seq_start_indices[0].item() != 0: + seq_start_indices = torch.cat( + [torch.tensor([0], device=device, dtype=seq_start_indices.dtype), seq_start_indices]) + seq_end_indices = torch.cat([seq_start_indices[1:], torch.tensor([len(row)], device=device)]) + seq_lengths = (seq_end_indices - seq_start_indices).tolist() + for seq_length in seq_lengths: + total += int(seq_length) + cu_seqlens.append(total) + return torch.tensor(cu_seqlens, device=device, dtype=torch.long) + + +@dataclass(frozen=True) +class SequenceParallelContext: + sp_group: Optional[dist.ProcessGroup] + sp_world_size: int + rank: int + world_size: int + real_position_ids: Optional[torch.Tensor] + is_packed: bool + + def _get_raw_data_world_size(device_mesh: DeviceMesh) -> int: dp_world_size = device_mesh.dp_world_size or 1 fsdp_world_size = device_mesh.fsdp_world_size or 1 @@ -307,14 +341,17 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, a else: query_layer, key_layer, value_layer = query, key, value - position_ids = kwargs.pop('position_ids') + position_ids = kwargs.pop('position_ids', None) if position_ids is not None: - shape0 = position_ids.shape[0] - position_ids_output = torch.empty((shape0 * self.sequence_parallel.sp_world_size, position_ids.shape[1]), - dtype=position_ids.dtype, - device=position_ids.device) + position_ids = position_ids.contiguous() + gathered_shape = (self.sequence_parallel.sp_world_size, *position_ids.shape) + position_ids_output = torch.empty( + gathered_shape, + dtype=position_ids.dtype, + device=position_ids.device, + ) dist.all_gather_into_tensor(position_ids_output, position_ids, group=self.sequence_parallel._sp_group) - position_ids = torch.cat(position_ids_output.split(shape0, dim=0), dim=1) + position_ids = torch.cat(tuple(position_ids_output.unbind(dim=0)), dim=-1).contiguous() context_layer = self.local_attn( query_layer, key_layer, value_layer, attention_mask, *args, position_ids=position_ids, **kwargs) @@ -343,13 +380,27 @@ def __init__(self): self._sp_group = None self.num_heads = None self.causal_mask_func = None + self.attn_implementation = None self.extra_kwargs = {} + self.requires_cu_seq_lens_q = False + self._bound_llm_model = None @property def 
real_position_ids(self) -> torch.Tensor: """The real position ids, this is different from the position_ids in mrope""" return self.extra_kwargs.get('position_ids') + def _build_context(self) -> SequenceParallelContext: + rank = dist.get_rank(self._sp_group) if self._sp_group is not None and dist.is_initialized() else 0 + return SequenceParallelContext( + sp_group=self._sp_group, + sp_world_size=int(self.sp_world_size or 1), + rank=rank, + world_size=int(self.world_size or 1), + real_position_ids=self.real_position_ids, + is_packed=bool(self.extra_kwargs.get('is_packed', False)), + ) + def _prepare_flash_attn(self, base_model: torch.nn.Module): try: from transformers import masking_utils @@ -491,6 +542,12 @@ def _attention(query, key, value, *args, **kwargs): kwargs['cu_seq_lens_k'] = cu_seqlens kwargs['max_length_q'] = max_seqlen kwargs['max_length_k'] = max_seqlen + else: + # Dense, non-packed SP path should not forward position_ids into FA2. + # Qwen3.5 has already applied RoPE before entering the attention interface, and keeping + # position_ids here can make HF's FA2 helper mis-detect packed/varlen mode, especially when + # batch_size == 1 and position_ids carries mRoPE-style leading dimensions. + kwargs.pop('position_ids', None) return ALL_ATTENTION_FUNCTIONS['flash_attention_2_origin'](module, query, key, value, *args, **kwargs)[0] @@ -629,6 +686,9 @@ def prepare( else: if hasattr(llm_model, '_update_causal_mask'): self.causal_mask_func = llm_model._update_causal_mask + self.attn_implementation = ( + get_config_attr(model.config, '_attn_implementation') + or get_config_attr(model.config, '_attn_implementation_internal')) if not SequenceParallel._global_inited: # these operations are global initializations and patches @@ -637,6 +697,10 @@ def prepare( SequenceParallel._global_inited = True self._prepare_forward_hook(llm_model) + self.requires_cu_seq_lens_q = bool(getattr(llm_model, 'requires_cu_seq_lens_q', False)) + self._bound_llm_model = llm_model + if hasattr(llm_model, 'set_sequence_parallel_context'): + llm_model.set_sequence_parallel_context(self._build_context()) if SequenceParallel._is_moe_model(getattr(model, 'config', None)): self._prepare_moe_aux_loss(llm_model) @@ -748,26 +812,28 @@ def pad_and_split_inputs(self, loss_scale = self.pad(loss_scale, padding_value=0., position_ids=real_position_ids) if real_position_ids is not None: real_position_ids = self.pad(real_position_ids, padding_value=-1, position_ids=real_position_ids) - # Build a 2D attention_mask whenever we padded for SP alignment so FlashAttention2 can unpad correctly. - # For packed batches (batch_size==1 with multiple position_id resets), relying on position_ids alone is - # unsafe if we also appended SP-alignment padding (position_ids=-1), because HF's FA2 varlen path will - # include the padded tail in the last segment when attention_mask is None. - if (input_ids is not None or input_embeds is not None) and batch_size > 1: - # not padding_free, so not ring-attention + # Preserve a 2D attention mask only when there is real padding to describe. + # For dense batches, FA2/FA3 should keep `attention_mask=None`; otherwise HF routes the kernel through + # its varlen/unpadding path and can introduce unnecessary overhead or invalid accesses in SP mode. 
+ if input_ids is not None or input_embeds is not None: inputs = input_ids if input_ids is not None else input_embeds - attn_shape = inputs.shape[1] # The sequence length + attn_shape = inputs.shape[1] if attention_mask is None: - # Mask out padded positions introduced by sequence-parallel padding. - # `real_position_ids` is padded with `-1` (see above), so use it to build a valid-token mask. - attention_mask = (real_position_ids != -1).to(dtype=torch.int64) - # no need position_ids here, because padding_free does not need attention_mask, - # so this is not ring-attention - attention_mask = self.pad(attention_mask, padding_value=0) - cache_position = torch.arange(0, attn_shape, device=inputs.device) - # pad attention mask to 4d to avoid calculation errors - if hasattr(self, 'causal_mask_func') and self.causal_mask_func is not None: - attention_mask = self.causal_mask_func(attention_mask, inputs.to(self.model_dtype), cache_position, - None, None) + has_padding = bool(real_position_ids is not None and torch.any(real_position_ids == -1)) + if has_padding: + attention_mask = (real_position_ids != -1).to(dtype=torch.int64) + else: + has_padding = not bool(torch.all(attention_mask != 0)) + if not has_padding: + attention_mask = None + if attention_mask is not None: + attention_mask = self.pad(attention_mask, padding_value=0) + if self.attn_implementation not in ('flash_attention_2', 'flash_attention_3'): + cache_position = torch.arange(0, attn_shape, device=inputs.device) + # SDPA/eager-style paths still expect a fully materialized causal mask here. + if hasattr(self, 'causal_mask_func') and self.causal_mask_func is not None: + attention_mask = self.causal_mask_func(attention_mask, inputs.to(self.model_dtype), + cache_position, None, None) if extra_split_values is not None: for (tensor, pad_value, split_dim) in extra_split_values: extra_values.append( @@ -848,12 +914,19 @@ def prepare_inputs(self, inputs): """ position_ids = None input_ids = inputs.get('input_ids') + inputs_embeds = inputs.get('inputs_embeds') position_ids = inputs.get('position_ids') - if position_ids is not None and input_ids is not None and position_ids.shape[0] == input_ids.shape[0]: + batch_source = input_ids if input_ids is not None else inputs_embeds + if position_ids is not None and batch_source is not None and position_ids.shape[0] == batch_source.shape[0]: self.extra_kwargs['position_ids'] = position_ids.clone() self.extra_kwargs['is_packed'] = self._is_packed_position_ids(position_ids) if input_ids is not None: self.extra_kwargs['input_ids'] = input_ids.clone() + if self._bound_llm_model is not None and hasattr(self._bound_llm_model, 'set_sequence_parallel_context'): + self._bound_llm_model.set_sequence_parallel_context(self._build_context()) + if self.requires_cu_seq_lens_q and position_ids is not None: + padded_position_ids = self.pad(position_ids, padding_value=-1, position_ids=position_ids, dim=-1) + inputs['cu_seq_lens_q'] = get_flattened_cu_seqlens_from_position_ids(padded_position_ids).to(torch.int32) if 'labels' in inputs: labels = inputs['labels'] _, _, labels, _, _, _, _ = self.pad_and_split_inputs( diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index d00e80ed..c6eb569a 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -151,6 +151,13 @@ def calculate_metrics(self, is_training): DEFAULT_WEIGHT_DECAY = 0.01 +def _default_gradient_accumulation_steps_for_device_mesh(device_mesh: 
Optional[DeviceMesh]) -> int: + if device_mesh is None: + return 1 + ulysses_size = getattr(device_mesh, 'ulysses_size', None) or 1 + return ulysses_size if ulysses_size > 1 else 1 + + @remote_class() class TransformersModel(TwinkleModel, PreTrainedModel, CheckpointEngineMixin): """The transformers model wrapper. @@ -352,6 +359,7 @@ def _construct_default_optimizer_group(self): loss_instance=CrossEntropyLoss(reduction='sum'), template=Template(self.tokenizer_id), processor=InputProcessor(self.device_mesh), + gradient_accumulation_steps=_default_gradient_accumulation_steps_for_device_mesh(self.device_mesh), _device_mesh=self.device_mesh, ) @@ -1010,7 +1018,9 @@ def _patch_adapter(self, adapter_name: str, config_or_dir: Union[PeftConfig, str self._construct_default_optimizer_group()) self.optimizer_group[adapter_name].adapter_name = adapter_name self.optimizer_group[adapter_name].adapter_config = config - _gas_default = kwargs.get('gradient_accumulation_steps', 1) + _gas_default = kwargs.get('gradient_accumulation_steps') + if _gas_default is None: + _gas_default = _default_gradient_accumulation_steps_for_device_mesh(self.device_mesh) self.optimizer_group[adapter_name].gradient_accumulation_steps = _gas_default self._default_tokenizer = self.optimizer_group[adapter_name].template.processor self.active_group = adapter_name diff --git a/src/twinkle/utils/device_mesh.py b/src/twinkle/utils/device_mesh.py index 9f5aa9e7..d8bbfc2e 100644 --- a/src/twinkle/utils/device_mesh.py +++ b/src/twinkle/utils/device_mesh.py @@ -352,18 +352,22 @@ def get_data_rank_from_global_rank(self, global_rank: int) -> int: @property def data_world_size(self) -> int: """Consider all dp/fsdp ranks, uses to determine how to distribute the data""" - dp_world_size = self.dp_world_size - fsdp_world_size = self.fsdp_world_size + data_world_size = self.raw_data_world_size ulysses_size = self.ulysses_size or 1 - if fsdp_world_size is not None and fsdp_world_size > 1: - data_world_size = dp_world_size * fsdp_world_size if dp_world_size is not None else fsdp_world_size - else: - data_world_size = dp_world_size if dp_world_size is not None else 1 assert data_world_size % ulysses_size == 0, ( f'data_world_size: {data_world_size} cannot be divided by ulysses_size: {ulysses_size}.') return data_world_size // ulysses_size + @property + def raw_data_world_size(self) -> int: + """The data world size before applying ulysses sequence parallel grouping.""" + dp_world_size = self.dp_world_size or 1 + fsdp_world_size = self.fsdp_world_size + if fsdp_world_size is not None and fsdp_world_size > 1: + return dp_world_size * fsdp_world_size + return dp_world_size + def get_slice(self, total_length: int, rank: Optional[int] = None) -> slice: world_size = self.data_world_size if world_size == 1: diff --git a/tests/dataloader/test_dataloader.py b/tests/dataloader/test_dataloader.py index 79bf78ad..778ea171 100644 --- a/tests/dataloader/test_dataloader.py +++ b/tests/dataloader/test_dataloader.py @@ -135,6 +135,21 @@ def test_device_mesh_sampler_with_encode(self): assert 'input_ids' in batch assert batch['input_ids'].shape[0] == 2 + def test_device_mesh_sampler_auto_adjusts_batch_for_ulysses(self): + csv_path = str(TEST_DATA_DIR / 'test.csv') + dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path)) + + device_mesh = DeviceMesh( + device_type='cpu', + mesh=np.arange(4).reshape(2, 2), + mesh_dim_names=('dp', 'fsdp'), + ulysses_size=2, + ) + + dataloader = DataLoader(dataset=dataset, batch_size=8, device_mesh=device_mesh) + + assert 
dataloader.batch_size == 4 + class TestRetrySampler: diff --git a/tests/sequence_parallel/test_twinkle_qwen3_5_text_model.py b/tests/sequence_parallel/test_twinkle_qwen3_5_text_model.py new file mode 100644 index 00000000..ad8c8108 --- /dev/null +++ b/tests/sequence_parallel/test_twinkle_qwen3_5_text_model.py @@ -0,0 +1,481 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import tempfile +import torch +import unittest +from contextlib import ExitStack +from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5Config, Qwen3_5TextConfig +from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM +from types import SimpleNamespace +from unittest.mock import patch + +from twinkle.model.transformers.models.qwen3_5 import modeling_qwen3_5 as tw_qwen35 +from twinkle.model.transformers.strategy.sequence_parallel import SequenceParallel, SequenceParallelContext + + +def _build_text_config(layer_types=None) -> Qwen3_5TextConfig: + layer_types = layer_types or ['full_attention'] + return Qwen3_5TextConfig( + vocab_size=64, + hidden_size=16, + intermediate_size=32, + num_hidden_layers=len(layer_types), + num_attention_heads=4, + num_key_value_heads=2, + head_dim=4, + hidden_act='silu', + max_position_embeddings=128, + rms_norm_eps=1e-6, + attention_dropout=0.0, + linear_conv_kernel_dim=3, + linear_key_head_dim=4, + linear_value_head_dim=4, + linear_num_key_heads=2, + linear_num_value_heads=4, + layer_types=layer_types, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + ) + + +def _linear_attention_runtime_available() -> bool: + return bool(torch.cuda.is_available() and tw_qwen35._FLA_CAUSAL_CONV1D_FN is not None + and tw_qwen35._FLA_CHUNK_GATED_DELTA_RULE is not None + and tw_qwen35._FLA_FUSED_RECURRENT_GATED_DELTA_RULE is not None and tw_qwen35._HAS_CAUSAL_CONV1D) + + +class _ContextReceiver: + + def __init__(self): + self.context = None + + def set_sequence_parallel_context(self, context): + self.context = context + + +class TestTwinkleQwen35TextModel(unittest.TestCase): + + def test_rejects_non_text_config(self): + with self.assertRaises(TypeError): + tw_qwen35.TwinkleQwen3_5ForCausalLM(Qwen3_5Config()) + + def test_text_model_accepts_sequence_parallel_context(self): + model = tw_qwen35.TwinkleQwen3_5TextModel(_build_text_config(['full_attention'])) + context = SequenceParallelContext( + sp_group=None, + sp_world_size=2, + rank=0, + world_size=2, + real_position_ids=torch.tensor([[0, 1, 2]], dtype=torch.long), + is_packed=False, + ) + model.set_sequence_parallel_context(context) + self.assertIs(model._sequence_parallel_context, context) + + def test_from_pretrained_loads_text_only_weights(self): + config = _build_text_config(['full_attention']) + hf_model = Qwen3_5ForCausalLM(config).eval() + with tempfile.TemporaryDirectory() as temp_dir: + hf_model.save_pretrained(temp_dir) + tw_model = tw_qwen35.TwinkleQwen3_5ForCausalLM.from_pretrained(temp_dir).eval() + + input_ids = torch.tensor([[1, 2, 3, 4]], dtype=torch.long) + attention_mask = torch.ones_like(input_ids) + position_ids = torch.arange(input_ids.shape[1], dtype=torch.long).unsqueeze(0) + with torch.no_grad(): + hf_outputs = hf_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, + return_dict=True, + ) + tw_outputs = tw_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, + return_dict=True, + ) + + torch.testing.assert_close(tw_outputs.logits, hf_outputs.logits, rtol=0, 
atol=0) + + def test_from_pretrained_loads_mixed_linear_weights(self): + config = _build_text_config(['full_attention', 'linear_attention']) + hf_model = Qwen3_5ForCausalLM(config).eval() + linear_param_names = [ + 'model.layers.1.linear_attn.in_proj_qkv.weight', + 'model.layers.1.linear_attn.in_proj_z.weight', + 'model.layers.1.linear_attn.out_proj.weight', + ] + + with tempfile.TemporaryDirectory() as temp_dir: + hf_model.save_pretrained(temp_dir) + with ExitStack() as stack: + stack.enter_context(patch.object(tw_qwen35, '_FLA_CAUSAL_CONV1D_FN', lambda *args, **kwargs: args[0])) + stack.enter_context( + patch.object(tw_qwen35, '_FLA_CAUSAL_CONV1D_UPDATE', lambda *args, **kwargs: args[0])) + stack.enter_context( + patch.object( + tw_qwen35, + '_FLA_CHUNK_GATED_DELTA_RULE', + lambda *args, **kwargs: (args[2], None), + )) + stack.enter_context( + patch.object( + tw_qwen35, + '_FLA_FUSED_RECURRENT_GATED_DELTA_RULE', + lambda *args, **kwargs: (args[2], None), + )) + stack.enter_context(patch.object(tw_qwen35, '_HAS_CAUSAL_CONV1D', True)) + tw_model = tw_qwen35.TwinkleQwen3_5ForCausalLM.from_pretrained(temp_dir).eval() + + hf_state_dict = hf_model.state_dict() + tw_state_dict = tw_model.state_dict() + for param_name in linear_param_names: + self.assertIn(param_name, tw_state_dict) + torch.testing.assert_close(tw_state_dict[param_name], hf_state_dict[param_name], rtol=0, atol=0) + + if _linear_attention_runtime_available(): + device = torch.device('cuda:0') + tw_model = tw_qwen35.TwinkleQwen3_5ForCausalLM.from_pretrained(temp_dir).to(device).eval() + input_ids = torch.tensor([[1, 2, 3, 4]], dtype=torch.long, device=device) + position_ids = torch.arange(input_ids.shape[1], dtype=torch.long, device=device).unsqueeze(0) + with torch.no_grad(): + outputs = tw_model( + input_ids=input_ids, + position_ids=position_ids, + use_cache=False, + return_dict=True, + ) + self.assertEqual(tuple(outputs.logits.shape), (1, 4, config.vocab_size)) + + def test_sequence_parallel_prepare_inputs_injects_cu_seq_lens(self): + sp = SequenceParallel() + sp.world_size = 2 + sp.sp_world_size = 2 + sp.requires_cu_seq_lens_q = True + receiver = _ContextReceiver() + sp._bound_llm_model = receiver + inputs = { + 'input_ids': torch.tensor([[1, 2, 3, 4]], dtype=torch.long), + 'position_ids': torch.tensor([[0, 1, 2, 3]], dtype=torch.long), + } + + outputs = sp.prepare_inputs(inputs) + + self.assertIn('cu_seq_lens_q', outputs) + self.assertTrue(torch.equal(outputs['cu_seq_lens_q'], torch.tensor([0, 4], dtype=torch.int32))) + self.assertIsNotNone(receiver.context) + self.assertFalse(receiver.context.is_packed) + self.assertTrue(torch.equal(receiver.context.real_position_ids, inputs['position_ids'])) + + def test_sequence_parallel_prepare_inputs_injects_flattened_cu_seq_lens_for_batched_rows(self): + sp = SequenceParallel() + sp.world_size = 2 + sp.sp_world_size = 2 + sp.requires_cu_seq_lens_q = True + receiver = _ContextReceiver() + sp._bound_llm_model = receiver + inputs = { + 'input_ids': torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.long), + 'position_ids': torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.long), + } + + outputs = sp.prepare_inputs(inputs) + + self.assertIn('cu_seq_lens_q', outputs) + self.assertTrue(torch.equal(outputs['cu_seq_lens_q'], torch.tensor([0, 4, 8], dtype=torch.int32))) + + def test_sequence_parallel_prepare_inputs_tracks_position_ids_for_inputs_embeds(self): + sp = SequenceParallel() + sp.world_size = 2 + sp.sp_world_size = 2 + sp.requires_cu_seq_lens_q = True + receiver = 
_ContextReceiver() + sp._bound_llm_model = receiver + inputs = { + 'inputs_embeds': torch.randn(2, 4, 8), + 'position_ids': torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.long), + } + + outputs = sp.prepare_inputs(inputs) + + self.assertIn('cu_seq_lens_q', outputs) + self.assertTrue(torch.equal(outputs['cu_seq_lens_q'], torch.tensor([0, 4, 8], dtype=torch.int32))) + self.assertIsNotNone(receiver.context) + self.assertTrue(torch.equal(receiver.context.real_position_ids, inputs['position_ids'])) + + def test_linear_attention_requires_fast_path_dependencies(self): + with patch.object(tw_qwen35, '_FLA_CAUSAL_CONV1D_FN', None), \ + patch.object(tw_qwen35, '_FLA_CAUSAL_CONV1D_UPDATE', None), \ + patch.object(tw_qwen35, '_FLA_CHUNK_GATED_DELTA_RULE', None), \ + patch.object(tw_qwen35, '_FLA_FUSED_RECURRENT_GATED_DELTA_RULE', None), \ + patch.object(tw_qwen35, '_HAS_CAUSAL_CONV1D', False): + with self.assertRaises(ImportError): + tw_qwen35.TwinkleQwen3_5TextModel(_build_text_config(['linear_attention'])) + + def test_linear_attention_sp_passes_cu_seq_lens_and_keeps_z_local(self): + captured = { + 'cu_seqlens': None, + 'seq_to_head_calls': 0, + 'head_to_seq_calls': 0, + 'norm_z_shape': None, + } + + def fake_conv(x, weight, bias, activation, seq_idx=None, backend=None, cu_seqlens=None): + del weight, bias, activation, seq_idx, backend + captured['cu_seqlens'] = cu_seqlens.clone() if cu_seqlens is not None else None + return x + + def fake_chunk_rule(query, + key, + value, + g, + beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, + cu_seqlens=None): + del query, key, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel + captured['cu_seqlens'] = cu_seqlens.clone() if cu_seqlens is not None else None + return value, None + + def fake_recurrent_rule(query, + key, + value, + g, + beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False): + del query, key, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel + return value, None + + def fake_seq_to_head(tensor, context): + captured['seq_to_head_calls'] += 1 + sp_world_size = context.sp_world_size + rank = context.rank + if tensor.dim() == 4: + local_heads = tensor.shape[2] // sp_world_size + start = rank * local_heads + end = start + local_heads + return tensor[:, :, start:end, :].contiguous() + if tensor.dim() == 3: + local_heads = tensor.shape[2] // sp_world_size + start = rank * local_heads + end = start + local_heads + return tensor[:, :, start:end].contiguous() + return tensor + + def fake_head_to_seq(tensor, context): + captured['head_to_seq_calls'] += 1 + return tensor.repeat_interleave(context.sp_world_size, dim=2) + + class DummyNorm(torch.nn.Module): + + def forward(self, x, z): + captured['norm_z_shape'] = tuple(z.shape) + return x + z + + with patch.object(tw_qwen35, '_FLA_CAUSAL_CONV1D_FN', fake_conv), \ + patch.object(tw_qwen35, '_FLA_CAUSAL_CONV1D_UPDATE', lambda *args, **kwargs: args[0]), \ + patch.object(tw_qwen35, '_FLA_CHUNK_GATED_DELTA_RULE', fake_chunk_rule), \ + patch.object(tw_qwen35, '_FLA_FUSED_RECURRENT_GATED_DELTA_RULE', fake_recurrent_rule), \ + patch.object(tw_qwen35, '_FLA_FUSED_RMS_NORM_GATED', None), \ + patch.object(tw_qwen35, '_HAS_CAUSAL_CONV1D', True), \ + patch.object(tw_qwen35, '_seq_to_head_shard', side_effect=fake_seq_to_head), \ + patch.object(tw_qwen35, '_head_to_seq_shard', side_effect=fake_head_to_seq): + config = _build_text_config(['linear_attention']) + module = 
tw_qwen35.TwinkleQwen3_5GatedDeltaNet(config, layer_idx=0) + module.norm = DummyNorm() + hidden_states = torch.randn(1, 2, config.hidden_size) + attention_mask = torch.ones(1, 2, dtype=torch.int64) + cu_seq_lens_q = torch.tensor([0, 2], dtype=torch.int32) + context = SequenceParallelContext( + sp_group='dummy_group', + sp_world_size=2, + rank=0, + world_size=2, + real_position_ids=torch.tensor([[0, 1]], dtype=torch.long), + is_packed=False, + ) + + output = module( + hidden_states=hidden_states, + attention_mask=attention_mask, + cu_seq_lens_q=cu_seq_lens_q, + sequence_parallel_context=context, + ) + + self.assertGreater(captured['seq_to_head_calls'], 0) + self.assertGreater(captured['head_to_seq_calls'], 0) + self.assertTrue(torch.equal(captured['cu_seqlens'], cu_seq_lens_q)) + self.assertEqual( + captured['norm_z_shape'], + (hidden_states.shape[0] * hidden_states.shape[1] * config.linear_num_value_heads, + config.linear_value_head_dim), + ) + self.assertEqual(tuple(output.shape), (1, 2, config.hidden_size)) + + def test_linear_attention_sp_flattens_batched_varlen_inputs(self): + captured = { + 'query_shape': None, + 'cu_seqlens': None, + } + + def fake_conv(x, weight, bias, activation, seq_idx=None, backend=None, cu_seqlens=None): + del weight, bias, activation, seq_idx, backend + captured['cu_seqlens'] = cu_seqlens.clone() if cu_seqlens is not None else None + return x + + def fake_chunk_rule(query, + key, + value, + g, + beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, + cu_seqlens=None): + del key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel + captured['query_shape'] = tuple(query.shape) + captured['cu_seqlens'] = cu_seqlens.clone() if cu_seqlens is not None else None + return query.new_zeros(query.shape[0], query.shape[1], 4, 4), None + + def fake_recurrent_rule(query, + key, + value, + g, + beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False): + del query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel + raise AssertionError('recurrent path should not be used') + + class DummyNorm(torch.nn.Module): + + def forward(self, x, z): + return x + z + + with patch.object(tw_qwen35, '_FLA_CAUSAL_CONV1D_FN', fake_conv), \ + patch.object(tw_qwen35, '_FLA_CAUSAL_CONV1D_UPDATE', lambda *args, **kwargs: args[0]), \ + patch.object(tw_qwen35, '_FLA_CHUNK_GATED_DELTA_RULE', fake_chunk_rule), \ + patch.object(tw_qwen35, '_FLA_FUSED_RECURRENT_GATED_DELTA_RULE', fake_recurrent_rule), \ + patch.object(tw_qwen35, '_FLA_FUSED_RMS_NORM_GATED', None), \ + patch.object(tw_qwen35, '_HAS_CAUSAL_CONV1D', True): + config = _build_text_config(['linear_attention']) + module = tw_qwen35.TwinkleQwen3_5GatedDeltaNet(config, layer_idx=0) + module.norm = DummyNorm() + hidden_states = torch.randn(2, 3, config.hidden_size) + attention_mask = torch.ones(2, 3, dtype=torch.int64) + cu_seq_lens_q = torch.tensor([0, 3, 6], dtype=torch.int32) + + output = module( + hidden_states=hidden_states, + attention_mask=attention_mask, + cu_seq_lens_q=cu_seq_lens_q, + ) + + self.assertEqual(captured['query_shape'], (1, 6, 4, 4)) + self.assertTrue(torch.equal(captured['cu_seqlens'], cu_seq_lens_q)) + self.assertEqual(tuple(output.shape), (2, 3, config.hidden_size)) + + def test_linear_attention_sp_uses_local_attention_mask(self): + captured = {'mask': None} + + def fake_conv(x, weight, bias, activation, seq_idx=None, backend=None, cu_seqlens=None): + del weight, bias, activation, seq_idx, backend, 
cu_seqlens + return x + + def fake_chunk_rule(query, + key, + value, + g, + beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, + cu_seqlens=None): + del query, key, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel, cu_seqlens + return value, None + + def fake_recurrent_rule(query, + key, + value, + g, + beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False): + del query, key, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel + return value, None + + with patch.object(tw_qwen35, '_FLA_CAUSAL_CONV1D_FN', fake_conv), \ + patch.object(tw_qwen35, '_FLA_CAUSAL_CONV1D_UPDATE', lambda *args, **kwargs: args[0]), \ + patch.object(tw_qwen35, '_FLA_CHUNK_GATED_DELTA_RULE', fake_chunk_rule), \ + patch.object(tw_qwen35, '_FLA_FUSED_RECURRENT_GATED_DELTA_RULE', fake_recurrent_rule), \ + patch.object(tw_qwen35, '_HAS_CAUSAL_CONV1D', True): + config = _build_text_config(['linear_attention']) + model = tw_qwen35.TwinkleQwen3_5TextModel(config) + model.set_sequence_parallel_context( + SequenceParallelContext( + sp_group='dummy_group', + sp_world_size=2, + rank=0, + world_size=2, + real_position_ids=torch.tensor([[0, 1, 2, -1]], dtype=torch.long), + is_packed=False, + )) + + def fake_linear_forward(hidden_states, + cache_params=None, + cache_position=None, + attention_mask=None, + cu_seq_lens_q=None, + sequence_parallel_context=None): + del hidden_states, cache_params, cache_position, cu_seq_lens_q, sequence_parallel_context + captured['mask'] = attention_mask.clone() if attention_mask is not None else None + return torch.zeros(1, 2, config.hidden_size) + + with patch.object(model.layers[0].linear_attn, 'forward', side_effect=fake_linear_forward): + _ = model( + input_ids=torch.tensor([[1, 2]], dtype=torch.long), + attention_mask=torch.tensor([[1, 1, 1, 0]], dtype=torch.int64), + position_ids=torch.tensor([[0, -1]], dtype=torch.long), + cache_position=torch.tensor([0, 1], dtype=torch.long), + cu_seq_lens_q=torch.tensor([0, 2], dtype=torch.int32), + use_cache=False, + ) + + self.assertIsNotNone(captured['mask']) + self.assertTrue(torch.equal(captured['mask'], torch.tensor([[1, 0]], dtype=torch.int64))) + + def test_sequence_parallel_drops_dense_attention_mask_for_flash_attention_2(self): + sp = SequenceParallel() + sp.world_size = 2 + sp.sp_world_size = 2 + sp.tokenizer = SimpleNamespace(pad_token_id=0) + sp.model_dtype = torch.bfloat16 + sp.attn_implementation = 'flash_attention_2' + sp.causal_mask_func = lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError('should not build 4d mask')) + + input_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.long) + position_ids = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.long) + + for attention_mask in (None, torch.ones(2, 4, dtype=torch.int64)): + _, _, _, _, resolved_attention_mask, _, _ = sp.pad_and_split_inputs( + input_ids=input_ids, + input_embeds=None, + labels=None, + position_ids=position_ids, + attention_mask=attention_mask, + loss_scale=None, + real_position_ids=position_ids, + ) + self.assertIsNone(resolved_attention_mask) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/sequence_parallel/test_twinkle_qwen3_5_text_model_parity.py b/tests/sequence_parallel/test_twinkle_qwen3_5_text_model_parity.py new file mode 100644 index 00000000..6b3fc1d1 --- /dev/null +++ b/tests/sequence_parallel/test_twinkle_qwen3_5_text_model_parity.py @@ -0,0 +1,364 @@ +# Copyright (c) ModelScope Contributors. 
All rights reserved.
+import copy
+import os
+import socket
+import tempfile
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import traceback
+import unittest
+from datetime import timedelta
+from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig
+from transformers.utils.import_utils import is_flash_attn_2_available
+from types import SimpleNamespace
+
+from twinkle.model.transformers.models.qwen3_5 import modeling_qwen3_5 as tw_qwen35
+from twinkle.model.transformers.strategy.sequence_parallel import SequenceParallel, SequenceParallelContext
+from twinkle.utils import DeviceMesh
+
+
+def _seed_everything(seed: int) -> None:
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.benchmark = False
+
+
+def _build_linear_parity_config() -> Qwen3_5TextConfig:
+    return Qwen3_5TextConfig(
+        vocab_size=64,
+        hidden_size=64,
+        intermediate_size=128,
+        num_hidden_layers=1,
+        num_attention_heads=4,
+        num_key_value_heads=2,
+        head_dim=16,
+        hidden_act='silu',
+        max_position_embeddings=128,
+        rms_norm_eps=1e-6,
+        attention_dropout=0.0,
+        linear_conv_kernel_dim=3,
+        linear_key_head_dim=16,
+        linear_value_head_dim=16,
+        linear_num_key_heads=2,
+        linear_num_value_heads=4,
+        layer_types=['linear_attention'],
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
+    )
+
+
+def _build_mixed_parity_config() -> Qwen3_5TextConfig:
+    config = Qwen3_5TextConfig(
+        vocab_size=64,
+        hidden_size=64,
+        intermediate_size=128,
+        num_hidden_layers=2,
+        num_attention_heads=4,
+        num_key_value_heads=2,
+        head_dim=16,
+        hidden_act='silu',
+        max_position_embeddings=128,
+        rms_norm_eps=1e-6,
+        attention_dropout=0.0,
+        linear_conv_kernel_dim=3,
+        linear_key_head_dim=16,
+        linear_value_head_dim=16,
+        linear_num_key_heads=2,
+        linear_num_value_heads=4,
+        layer_types=['full_attention', 'linear_attention'],
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
+    )
+    attn_implementation = 'flash_attention_2' if is_flash_attn_2_available() else 'sdpa'
+    config._attn_implementation = attn_implementation
+    config._attn_implementation_internal = attn_implementation
+    return config
+
+
+def _linear_attention_runtime_available() -> bool:
+    return bool(torch.cuda.is_available() and tw_qwen35._FLA_CAUSAL_CONV1D_FN is not None
+                and tw_qwen35._FLA_CHUNK_GATED_DELTA_RULE is not None
+                and tw_qwen35._FLA_FUSED_RECURRENT_GATED_DELTA_RULE is not None and tw_qwen35._HAS_CAUSAL_CONV1D)
+
+
+def _find_free_port() -> int:
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+        sock.bind(('127.0.0.1', 0))
+        return sock.getsockname()[1]
+
+
+def _all_gather_seq(local_tensor: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
+    world_size = dist.get_world_size(group)
+    chunks = [torch.empty_like(local_tensor) for _ in range(world_size)]
+    dist.all_gather(chunks, local_tensor.contiguous(), group=group)
+    return torch.cat(chunks, dim=1).contiguous()
+
+
+def _all_reduce_grads(module: torch.nn.Module, group: dist.ProcessGroup) -> None:
+    for param in module.parameters():
+        if param.grad is not None:
+            dist.all_reduce(param.grad, group=group)
+
+
+def _relative_error(actual: torch.Tensor, expected: torch.Tensor) -> float:
+    actual_fp32 = actual.detach().to(dtype=torch.float32)
+    expected_fp32 = expected.detach().to(dtype=torch.float32)
+    return float((actual_fp32 - expected_fp32).norm() / (expected_fp32.norm() + 1e-12))
+
+
+def _assert_relative_error(actual: torch.Tensor, expected: torch.Tensor, rel_tol: float, name: str) -> None:
+    rel = _relative_error(actual, expected)
+    if rel > rel_tol:
+        raise AssertionError(f'{name} relative error {rel:.4e} exceeds tolerance {rel_tol:.4e}')
+
+
+def _write_error(error_prefix: str, rank: int) -> None:
+    with open(f'{error_prefix}.rank{rank}.err', 'w', encoding='utf-8') as f:
+        f.write(traceback.format_exc())
+
+
+def _run_linear_attention_parity_worker(rank: int, world_size: int, port: int, error_prefix: str) -> None:
+    os.environ['MASTER_ADDR'] = '127.0.0.1'
+    os.environ['MASTER_PORT'] = str(port)
+    os.environ['RANK'] = str(rank)
+    os.environ['LOCAL_RANK'] = str(rank)
+    os.environ['WORLD_SIZE'] = str(world_size)
+    os.environ['LOCAL_WORLD_SIZE'] = str(world_size)
+    torch.cuda.set_device(rank)
+
+    try:
+        dist.init_process_group(
+            backend='nccl',
+            rank=rank,
+            world_size=world_size,
+            timeout=timedelta(minutes=10),
+        )
+
+        device = torch.device(f'cuda:{rank}')
+        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        seed = 1234
+        batch_size = 2
+        seq_len = 8
+        local_seq_len = seq_len // world_size
+        start = rank * local_seq_len
+        end = start + local_seq_len
+
+        _seed_everything(seed)
+        config = _build_linear_parity_config()
+        baseline_module = tw_qwen35.TwinkleQwen3_5GatedDeltaNet(
+            config, layer_idx=0).to(
+                device=device, dtype=dtype).eval()
+        sp_module = copy.deepcopy(baseline_module).to(device=device, dtype=dtype).eval()
+
+        full_hidden_states = torch.randn(batch_size, seq_len, config.hidden_size, device=device, dtype=dtype)
+        dist.broadcast(full_hidden_states, src=0)
+        full_position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1)
+        attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64, device=device)
+        cu_seq_lens_q = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32, device=device)
+
+        baseline_hidden_states = full_hidden_states.detach().clone().requires_grad_(True)
+        baseline_output = baseline_module(
+            hidden_states=baseline_hidden_states,
+            attention_mask=attention_mask,
+            cu_seq_lens_q=cu_seq_lens_q,
+        )
+        baseline_loss = baseline_output.float().square().sum()
+        baseline_loss.backward()
+        baseline_input_grad = baseline_hidden_states.grad.detach()
+        baseline_param_grads = {
+            name: param.grad.detach().clone()
+            for name, param in baseline_module.named_parameters() if param.grad is not None
+        }
+
+        sp_hidden_states = full_hidden_states[:, start:end].detach().clone().requires_grad_(True)
+        sp_output = sp_module(
+            hidden_states=sp_hidden_states,
+            attention_mask=attention_mask[:, start:end].contiguous(),
+            cu_seq_lens_q=cu_seq_lens_q,
+            sequence_parallel_context=SequenceParallelContext(
+                sp_group=dist.group.WORLD,
+                sp_world_size=world_size,
+                rank=rank,
+                world_size=world_size,
+                real_position_ids=full_position_ids,
+                is_packed=False,
+            ),
+        )
+        sp_loss = sp_output.float().square().sum()
+        sp_loss.backward()
+        _all_reduce_grads(sp_module, dist.group.WORLD)
+
+        sp_output_full = _all_gather_seq(sp_output.detach(), dist.group.WORLD)
+        sp_input_grad_full = _all_gather_seq(sp_hidden_states.grad.detach(), dist.group.WORLD)
+
+        torch.testing.assert_close(
+            sp_output_full.to(dtype=torch.float32),
+            baseline_output.detach().to(dtype=torch.float32),
+            rtol=1e-3,
+            atol=1e-3,
+        )
+        _assert_relative_error(sp_input_grad_full, baseline_input_grad, 1e-2, 'linear_attention.input_grad')
+
+        for name in (
+                'in_proj_qkv.weight',
+                'in_proj_z.weight',
+                'out_proj.weight',
+        ):
+            _assert_relative_error(
+                sp_module.get_parameter(name).grad,
+                baseline_param_grads[name],
+                2e-2,
+                f'linear_attention.{name}',
+            )
+    except Exception:
+        _write_error(error_prefix, rank)
+        raise
+    finally:
+        if dist.is_initialized():
+            dist.destroy_process_group()
+
+
+def _run_mixed_text_model_parity_worker(rank: int, world_size: int, port: int, error_prefix: str) -> None:
+    os.environ['MASTER_ADDR'] = '127.0.0.1'
+    os.environ['MASTER_PORT'] = str(port)
+    os.environ['RANK'] = str(rank)
+    os.environ['LOCAL_RANK'] = str(rank)
+    os.environ['WORLD_SIZE'] = str(world_size)
+    os.environ['LOCAL_WORLD_SIZE'] = str(world_size)
+    torch.cuda.set_device(rank)
+
+    try:
+        dist.init_process_group(
+            backend='nccl',
+            rank=rank,
+            world_size=world_size,
+            timeout=timedelta(minutes=10),
+        )
+
+        device = torch.device(f'cuda:{rank}')
+        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        seed = 5678
+        batch_size = 2
+        seq_len = 8
+
+        _seed_everything(seed)
+        config = _build_mixed_parity_config()
+        baseline_model = tw_qwen35.TwinkleQwen3_5TextModel(config).to(device=device, dtype=dtype).eval()
+        sp_model = copy.deepcopy(baseline_model).to(device=device, dtype=dtype).eval()
+
+        full_inputs_embeds = torch.randn(batch_size, seq_len, config.hidden_size, device=device, dtype=dtype)
+        dist.broadcast(full_inputs_embeds, src=0)
+        full_position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1)
+
+        baseline_inputs_embeds = full_inputs_embeds.detach().clone().requires_grad_(True)
+        baseline_outputs = baseline_model(
+            inputs_embeds=baseline_inputs_embeds,
+            position_ids=full_position_ids,
+            use_cache=False,
+        )
+        baseline_hidden = baseline_outputs.last_hidden_state
+        baseline_loss = baseline_hidden.float().square().sum()
+        baseline_loss.backward()
+        baseline_input_grad = baseline_inputs_embeds.grad.detach()
+        baseline_param_grads = {
+            name: param.grad.detach().clone()
+            for name, param in baseline_model.named_parameters() if param.grad is not None
+        }
+
+        device_mesh = DeviceMesh.from_sizes(
+            world_size=world_size,
+            dp_size=world_size,
+            ulysses_size=world_size,
+            device_type='cuda',
+        )
+        sp = SequenceParallel()
+        sp.prepare(world_size, sp_model, SimpleNamespace(pad_token_id=0), device_mesh=device_mesh)
+
+        sp_inputs_embeds = full_inputs_embeds.detach().clone().requires_grad_(True)
+        sp_inputs = sp.prepare_inputs({
+            'inputs_embeds': sp_inputs_embeds,
+            'position_ids': full_position_ids.clone(),
+            'use_cache': False,
+        })
+        sp_outputs = sp_model(**sp_inputs)
+        sp_hidden_local = sp_outputs.last_hidden_state
+        sp_loss = sp_hidden_local.float().square().sum()
+        sp_loss.backward()
+        dist.all_reduce(sp_inputs_embeds.grad, group=dist.group.WORLD)
+        _all_reduce_grads(sp_model, dist.group.WORLD)
+
+        sp_hidden_full = _all_gather_seq(sp_hidden_local.detach(), dist.group.WORLD)
+        torch.testing.assert_close(
+            sp_hidden_full.to(dtype=torch.float32),
+            baseline_hidden.detach().to(dtype=torch.float32),
+            rtol=5e-3,
+            atol=5e-3,
+        )
+        _assert_relative_error(sp_inputs_embeds.grad, baseline_input_grad, 1e-2, 'mixed_text_model.input_grad')
+
+        for name in (
+                'layers.0.self_attn.q_proj.weight',
+                'layers.1.linear_attn.in_proj_qkv.weight',
+                'layers.1.mlp.gate_proj.weight',
+        ):
+            _assert_relative_error(
+                sp_model.get_parameter(name).grad,
+                baseline_param_grads[name],
+                2e-2,
+                f'mixed_text_model.{name}',
+            )
+    except Exception:
+        _write_error(error_prefix, rank)
+        raise
+    finally:
+        if dist.is_initialized():
+            dist.destroy_process_group()
+
+
+class TestTwinkleQwen35TextModelParity(unittest.TestCase):
+
+    WORLD_SIZE = 2
+
+    def _run_spawned_parity_test(self, worker) -> None:
+        port = _find_free_port()
+        with tempfile.TemporaryDirectory() as temp_dir:
+            error_prefix = os.path.join(temp_dir, 'parity')
+            try:
+                mp.spawn(
+                    worker,
+                    args=(self.WORLD_SIZE, port, error_prefix),
+                    nprocs=self.WORLD_SIZE,
+                    join=True,
+                )
+            except Exception:
+                error_logs = []
+                for rank in range(self.WORLD_SIZE):
+                    error_path = f'{error_prefix}.rank{rank}.err'
+                    if os.path.exists(error_path):
+                        with open(error_path, encoding='utf-8') as f:
+                            error_logs.append(f'Rank {rank}:\n{f.read()}')
+                if error_logs:
+                    self.fail('\n\n'.join(error_logs))
+                raise
+
+    def test_linear_attention_sp_parity(self):
+        if not _linear_attention_runtime_available():
+            self.skipTest('CUDA + flash-linear-attention + causal-conv1d are required for linear attention parity.')
+        if torch.cuda.device_count() < self.WORLD_SIZE:
+            self.skipTest(f'Need at least {self.WORLD_SIZE} CUDA devices for SP parity.')
+        self._run_spawned_parity_test(_run_linear_attention_parity_worker)
+
+    def test_mixed_text_model_sp_parity(self):
+        if not _linear_attention_runtime_available():
+            self.skipTest('CUDA + flash-linear-attention + causal-conv1d are required for mixed model parity.')
+        if torch.cuda.device_count() < self.WORLD_SIZE:
+            self.skipTest(f'Need at least {self.WORLD_SIZE} CUDA devices for SP parity.')
+        self._run_spawned_parity_test(_run_mixed_text_model_parity_worker)
+
+
+if __name__ == '__main__':
+    unittest.main()
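For reference, below is a minimal usage sketch distilled from _run_mixed_text_model_parity_worker above; it is not part of the patch. It shows how the SequenceParallel strategy exercised by these tests wraps a TwinkleQwen3_5TextModel, assuming a two-rank NCCL process group is already initialized, CUDA with bf16 is available, and the toy _build_mixed_parity_config() helper from the test file is in scope; the batch size and sequence length are illustrative only.

# Minimal sketch (assumptions: 2-rank NCCL group already initialized; CUDA + bf16 available;
# _build_mixed_parity_config() taken from the parity test file above).
import torch
import torch.distributed as dist
from types import SimpleNamespace

from twinkle.model.transformers.models.qwen3_5 import modeling_qwen3_5 as tw_qwen35
from twinkle.model.transformers.strategy.sequence_parallel import SequenceParallel
from twinkle.utils import DeviceMesh

world_size = dist.get_world_size()
device = torch.device(f'cuda:{dist.get_rank()}')
config = _build_mixed_parity_config()  # toy config defined in the test file above
model = tw_qwen35.TwinkleQwen3_5TextModel(config).to(device=device, dtype=torch.bfloat16)

# Build a sequence-parallel mesh and bind the SP strategy to the model, as the parity worker does.
device_mesh = DeviceMesh.from_sizes(
    world_size=world_size, dp_size=world_size, ulysses_size=world_size, device_type='cuda')
sp = SequenceParallel()
sp.prepare(world_size, model, SimpleNamespace(pad_token_id=0), device_mesh=device_mesh)

# prepare_inputs shards the sequence across SP ranks and injects cu_seq_lens_q plus the
# SequenceParallelContext that the linear-attention layers consume.
inputs_embeds = torch.randn(2, 8, config.hidden_size, device=device, dtype=torch.bfloat16)
position_ids = torch.arange(8, device=device).unsqueeze(0).expand(2, -1)
sp_inputs = sp.prepare_inputs({'inputs_embeds': inputs_embeds, 'position_ids': position_ids, 'use_cache': False})
local_hidden = model(**sp_inputs).last_hidden_state  # each rank holds its local sequence shard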