Commit 75d006a

feat: standardize import formatting and fix attention implementation string
- Change double quotes to single quotes for consistency in the `attn_implementation` parameter
- Reformat multi-line imports to single lines for better readability
- Remove unnecessary import error message in linear attention validation
- Maintain code style consistency across the codebase
1 parent 22a71f5 commit 75d006a
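For reference, the two style conventions this commit applies look like the following minimal sketch (stdlib names stand in for the project's own modules, purely for illustration; this is not code from the commit itself):

# Quote style: single quotes for string literals.
attn_implementation = 'flash_attention_2'

# Import style: parenthesized imports wrapped at the opening bracket instead of
# listing one name per line (collections is only a stand-in module here).
from collections import (OrderedDict, defaultdict,
                         namedtuple)

print(attn_implementation, OrderedDict, defaultdict, namedtuple)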

File tree

8 files changed (+109, -107 lines)

cookbook/transformers/sp_fsdp_dense.py

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ def train():
         model_cls=TwinkleQwen3_5ForCausalLM,
         device_mesh=device_mesh,
         strategy='native_fsdp',
-        attn_implementation="flash_attention_2"
+        attn_implementation='flash_attention_2'
     )

     lora_config = LoraConfig(target_modules='all-linear', lora_dropout=0.0)
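Background on the value being re-quoted above: 'flash_attention_2' is the standard Hugging Face transformers switch for the attention backend. A minimal, hypothetical sketch of passing it when loading a plain transformers model (checkpoint name and dtype are placeholders, not part of this commit; flash-attention-2 also requires the flash-attn package and a supported GPU):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    'Qwen/Qwen2.5-0.5B',                      # placeholder checkpoint
    torch_dtype=torch.bfloat16,
    attn_implementation='flash_attention_2',  # 'sdpa' or 'eager' are the usual alternatives without flash-attn
)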

src/twinkle/model/transformers/__init__.py

Lines changed: 2 additions & 7 deletions
@@ -4,13 +4,8 @@
 from twinkle.utils.import_utils import _LazyModule

 if TYPE_CHECKING:
-    from .models import (
-        TwinkleQwen3_5DecoderLayer,
-        TwinkleQwen3_5ForCausalLM,
-        TwinkleQwen3_5GatedDeltaNet,
-        TwinkleQwen3_5PreTrainedModel,
-        TwinkleQwen3_5TextModel,
-    )
+    from .models import (TwinkleQwen3_5DecoderLayer, TwinkleQwen3_5ForCausalLM, TwinkleQwen3_5GatedDeltaNet,
+                         TwinkleQwen3_5PreTrainedModel, TwinkleQwen3_5TextModel)
     from .multi_lora_transformers import MultiLoraTransformersModel
     from .transformers import TransformersModel
 else:

src/twinkle/model/transformers/models/__init__.py

Lines changed: 2 additions & 7 deletions
@@ -1,11 +1,6 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
-from .qwen3_5 import (
-    TwinkleQwen3_5DecoderLayer,
-    TwinkleQwen3_5ForCausalLM,
-    TwinkleQwen3_5GatedDeltaNet,
-    TwinkleQwen3_5PreTrainedModel,
-    TwinkleQwen3_5TextModel,
-)
+from .qwen3_5 import (TwinkleQwen3_5DecoderLayer, TwinkleQwen3_5ForCausalLM, TwinkleQwen3_5GatedDeltaNet,
+                      TwinkleQwen3_5PreTrainedModel, TwinkleQwen3_5TextModel)

 __all__ = [
     'TwinkleQwen3_5PreTrainedModel',

src/twinkle/model/transformers/models/qwen3_5/__init__.py

Lines changed: 2 additions & 7 deletions
@@ -1,11 +1,6 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
-from .modeling_qwen3_5 import (
-    TwinkleQwen3_5DecoderLayer,
-    TwinkleQwen3_5ForCausalLM,
-    TwinkleQwen3_5GatedDeltaNet,
-    TwinkleQwen3_5PreTrainedModel,
-    TwinkleQwen3_5TextModel,
-)
+from .modeling_qwen3_5 import (TwinkleQwen3_5DecoderLayer, TwinkleQwen3_5ForCausalLM, TwinkleQwen3_5GatedDeltaNet,
+                               TwinkleQwen3_5PreTrainedModel, TwinkleQwen3_5TextModel)

 __all__ = [
     'TwinkleQwen3_5PreTrainedModel',

src/twinkle/model/transformers/models/qwen3_5/modeling_qwen3_5.py

Lines changed: 17 additions & 32 deletions
@@ -2,21 +2,19 @@
 from __future__ import annotations

 import importlib.util
-from typing import Any, Callable, Optional
-
 import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers.cache_utils import Cache
 from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5Config, Qwen3_5TextConfig
 from transformers.models.qwen3_5 import modeling_qwen3_5 as hf_qwen35
+from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5Config, Qwen3_5TextConfig
 from transformers.processing_utils import Unpack
 from transformers.utils import TransformersKwargs, can_return_tuple
 from transformers.utils.generic import merge_with_config_defaults
 from transformers.utils.output_capturing import capture_outputs
-
+from typing import Any, Callable, Optional

 try:
     from fla.modules import FusedRMSNormGated as _FLA_FUSED_RMS_NORM_GATED

@@ -37,10 +35,8 @@
 def _ensure_text_config(config: Qwen3_5TextConfig) -> Qwen3_5TextConfig:
     if isinstance(config, Qwen3_5TextConfig):
         return config
-    raise TypeError(
-        'TwinkleQwen3_5 text-only models require transformers.models.qwen3_5.Qwen3_5TextConfig. '
-        f'Got {type(config).__name__}.'
-    )
+    raise TypeError('TwinkleQwen3_5 text-only models require transformers.models.qwen3_5.Qwen3_5TextConfig. '
+                    f'Got {type(config).__name__}.')


 def _ensure_linear_attention_fast_path() -> None:

@@ -52,10 +48,8 @@ def _ensure_linear_attention_fast_path() -> None:
     if not _HAS_CAUSAL_CONV1D:
         missing.append('causal-conv1d')
     if missing:
-        raise ImportError(
-            'TwinkleQwen3_5 linear attention requires flash-linear-attention and causal-conv1d. '
-            f'Missing: {", ".join(missing)}'
-        )
+        raise ImportError('TwinkleQwen3_5 linear attention requires flash-linear-attention and causal-conv1d. '
+                          f'Missing: {", ".join(missing)}')


 def _maybe_slice_tensor_output(output: Any) -> torch.Tensor:

@@ -66,10 +60,8 @@ def _maybe_slice_tensor_output(output: Any) -> torch.Tensor:


 def _sp_is_enabled(sequence_parallel_context: Any | None) -> bool:
     return bool(
-        sequence_parallel_context is not None
-        and getattr(sequence_parallel_context, 'sp_world_size', 1) > 1
-        and getattr(sequence_parallel_context, 'sp_group', None) is not None
-    )
+        sequence_parallel_context is not None and getattr(sequence_parallel_context, 'sp_world_size', 1) > 1
+        and getattr(sequence_parallel_context, 'sp_group', None) is not None)


 def _get_sp_rank(sequence_parallel_context: Any | None) -> int:

@@ -239,8 +231,7 @@ def _apply_varlen_conv(
     ) -> torch.Tensor:
         if self.causal_conv1d_fn is None:
             raise ImportError(
-                'TwinkleQwen3_5 linear attention requires fla.modules.convolution.causal_conv1d for prefill/train.'
-            )
+                'TwinkleQwen3_5 linear attention requires fla.modules.convolution.causal_conv1d for prefill/train.')
         output = self.causal_conv1d_fn(
             x=mixed_qkv,
             weight=conv_weight,

@@ -261,8 +252,7 @@ def _apply_decode_conv(
         if self.causal_conv1d_update is None:
             raise ImportError(
                 'TwinkleQwen3_5 decode requires a causal_conv1d_update implementation from flash-linear-attention '
-                'or causal-conv1d.'
-            )
+                'or causal-conv1d.')
         mixed_qkv_t = mixed_qkv.transpose(1, 2).contiguous()
         output = self.causal_conv1d_update(
             mixed_qkv_t,

@@ -291,11 +281,8 @@ def forward(
         hidden_states = hf_qwen35.apply_mask_to_padding_states(hidden_states, attention_mask)
         batch_size, seq_len, _ = hidden_states.shape
         use_precomputed_states = (
-            cache_params is not None
-            and cache_params.has_previous_state
-            and seq_len == 1
-            and cache_position is not None
-        )
+            cache_params is not None and cache_params.has_previous_state and seq_len == 1
+            and cache_position is not None)

         if cache_params is not None:
             conv_state = cache_params.conv_states[self.layer_idx]

@@ -316,8 +303,7 @@ def forward(
             if self.num_k_heads % sp_world_size != 0 or self.num_v_heads % sp_world_size != 0:
                 raise RuntimeError(
                     'TwinkleQwen3_5 linear attention requires sp_world_size to divide both '
-                    f'linear_num_key_heads ({self.num_k_heads}) and linear_num_value_heads ({self.num_v_heads}).'
-                )
+                    f'linear_num_key_heads ({self.num_k_heads}) and linear_num_value_heads ({self.num_v_heads}).')
             local_num_k_heads = self.num_k_heads // sp_world_size
             local_num_v_heads = self.num_v_heads // sp_world_size
             local_key_dim = local_num_k_heads * self.head_k_dim

@@ -341,7 +327,8 @@ def forward(
                 ),
                 dim=-1,
             )
-            conv_weight = self._get_local_conv1d_weight(_get_sp_rank(sequence_parallel_context), local_key_dim, local_value_dim)
+            conv_weight = self._get_local_conv1d_weight(
+                _get_sp_rank(sequence_parallel_context), local_key_dim, local_value_dim)
         else:
             local_num_k_heads = self.num_k_heads
             local_num_v_heads = self.num_v_heads

@@ -506,8 +493,7 @@ def __init__(self, config: Qwen3_5TextConfig):
         super().__init__(config)
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
         self.layers = nn.ModuleList(
-            [TwinkleQwen3_5DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
-        )
+            [TwinkleQwen3_5DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
         self.norm = hf_qwen35.Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.rotary_emb = hf_qwen35.Qwen3_5TextRotaryEmbedding(config=config)
         self.gradient_checkpointing = False

@@ -569,8 +555,7 @@ def forward(
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
             cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
-            )
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device)

         if position_ids is None:
             position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
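Context for the reflowed ImportError above: the module gates its linear-attention fast path on optional packages and raises one error naming everything that is missing. A minimal stand-alone sketch of that pattern, using importlib.util.find_spec probes as an assumption (the real module sets its _HAS_*/_FLA_* flags via try/except imports from fla and causal-conv1d instead):

import importlib.util

# Probe the optional acceleration packages without importing them eagerly.
_HAS_FLA = importlib.util.find_spec('fla') is not None
_HAS_CAUSAL_CONV1D = importlib.util.find_spec('causal_conv1d') is not None


def ensure_linear_attention_fast_path() -> None:
    """Raise a single ImportError listing every missing optional dependency."""
    missing = []
    if not _HAS_FLA:
        missing.append('flash-linear-attention')
    if not _HAS_CAUSAL_CONV1D:
        missing.append('causal-conv1d')
    if missing:
        raise ImportError('linear attention requires flash-linear-attention and causal-conv1d. '
                          f'Missing: {", ".join(missing)}')


if __name__ == '__main__':
    try:
        ensure_linear_attention_fast_path()
        print('linear attention fast path available')
    except ImportError as err:
        print(err)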

src/twinkle/model/transformers/strategy/sequence_parallel.py

Lines changed: 5 additions & 5 deletions
@@ -37,7 +37,8 @@ def get_flattened_cu_seqlens_from_position_ids(position_ids: torch.LongTensor):
         row[row < 0] = 0
         seq_start_indices = torch.where(row == 0)[0]
         if seq_start_indices.numel() == 0 or seq_start_indices[0].item() != 0:
-            seq_start_indices = torch.cat([torch.tensor([0], device=device, dtype=seq_start_indices.dtype), seq_start_indices])
+            seq_start_indices = torch.cat(
+                [torch.tensor([0], device=device, dtype=seq_start_indices.dtype), seq_start_indices])
         seq_end_indices = torch.cat([seq_start_indices[1:], torch.tensor([len(row)], device=device)])
         seq_lengths = (seq_end_indices - seq_start_indices).tolist()
         for seq_length in seq_lengths:

@@ -687,8 +688,7 @@ def prepare(
         self.causal_mask_func = llm_model._update_causal_mask
         self.attn_implementation = (
             get_config_attr(model.config, '_attn_implementation')
-            or get_config_attr(model.config, '_attn_implementation_internal')
-        )
+            or get_config_attr(model.config, '_attn_implementation_internal'))

         if not SequenceParallel._global_inited:
             # these operations are global initializations and patches

@@ -832,8 +832,8 @@ def pad_and_split_inputs(self,
             cache_position = torch.arange(0, attn_shape, device=inputs.device)
             # SDPA/eager-style paths still expect a fully materialized causal mask here.
             if hasattr(self, 'causal_mask_func') and self.causal_mask_func is not None:
-                attention_mask = self.causal_mask_func(
-                    attention_mask, inputs.to(self.model_dtype), cache_position, None, None)
+                attention_mask = self.causal_mask_func(attention_mask, inputs.to(self.model_dtype),
+                                                       cache_position, None, None)
         if extra_split_values is not None:
             for (tensor, pad_value, split_dim) in extra_split_values:
                 extra_values.append(
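Context for the first hunk in this file: get_flattened_cu_seqlens_from_position_ids derives flash-attention-style cu_seqlens boundaries from packed position_ids that restart at 0 for every sequence. A simplified stand-alone sketch of the same idea, reduced to one flattened row (the real helper also handles devices and multiple batch rows, as the diff shows):

import torch


def flattened_cu_seqlens(position_ids: torch.LongTensor) -> torch.Tensor:
    # Packed sequences of length 4 and 3 give position_ids [0, 1, 2, 3, 0, 1, 2];
    # every 0 marks the start of a new sequence.
    row = position_ids.flatten().clone()
    row[row < 0] = 0                                    # treat padding markers as position 0
    starts = torch.where(row == 0)[0]
    if starts.numel() == 0 or starts[0].item() != 0:    # ensure index 0 is a boundary
        starts = torch.cat([torch.tensor([0], dtype=starts.dtype), starts])
    ends = torch.cat([starts[1:], torch.tensor([len(row)])])
    lengths = ends - starts
    return torch.cat([torch.zeros(1, dtype=torch.int32),
                      torch.cumsum(lengths, dim=0).to(torch.int32)])


print(flattened_cu_seqlens(torch.tensor([0, 1, 2, 3, 0, 1, 2])))  # tensor([0, 4, 7], dtype=torch.int32)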

tests/sequence_parallel/test_twinkle_qwen3_5_text_model.py

Lines changed: 61 additions & 24 deletions
@@ -1,13 +1,12 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
 import tempfile
+import torch
 import unittest
 from contextlib import ExitStack
-from types import SimpleNamespace
-from unittest.mock import patch
-
-import torch
 from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5Config, Qwen3_5TextConfig
 from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM
+from types import SimpleNamespace
+from unittest.mock import patch

 from twinkle.model.transformers.models.qwen3_5 import modeling_qwen3_5 as tw_qwen35
 from twinkle.model.transformers.strategy.sequence_parallel import SequenceParallel, SequenceParallelContext

@@ -40,13 +39,9 @@ def _build_text_config(layer_types=None) -> Qwen3_5TextConfig:


 def _linear_attention_runtime_available() -> bool:
-    return bool(
-        torch.cuda.is_available()
-        and tw_qwen35._FLA_CAUSAL_CONV1D_FN is not None
-        and tw_qwen35._FLA_CHUNK_GATED_DELTA_RULE is not None
-        and tw_qwen35._FLA_FUSED_RECURRENT_GATED_DELTA_RULE is not None
-        and tw_qwen35._HAS_CAUSAL_CONV1D
-    )
+    return bool(torch.cuda.is_available() and tw_qwen35._FLA_CAUSAL_CONV1D_FN is not None
+                and tw_qwen35._FLA_CHUNK_GATED_DELTA_RULE is not None
+                and tw_qwen35._FLA_FUSED_RECURRENT_GATED_DELTA_RULE is not None and tw_qwen35._HAS_CAUSAL_CONV1D)


 class _ContextReceiver:

@@ -233,13 +228,26 @@ def fake_conv(x, weight, bias, activation, seq_idx=None, backend=None, cu_seqlen
                 captured['cu_seqlens'] = cu_seqlens.clone() if cu_seqlens is not None else None
                 return x

-            def fake_chunk_rule(query, key, value, g, beta, initial_state=None, output_final_state=False,
-                                use_qk_l2norm_in_kernel=False, cu_seqlens=None):
+            def fake_chunk_rule(query,
+                                key,
+                                value,
+                                g,
+                                beta,
+                                initial_state=None,
+                                output_final_state=False,
+                                use_qk_l2norm_in_kernel=False,
+                                cu_seqlens=None):
                 del query, key, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel
                 captured['cu_seqlens'] = cu_seqlens.clone() if cu_seqlens is not None else None
                 return value, None

-            def fake_recurrent_rule(query, key, value, g, beta, initial_state=None, output_final_state=False,
+            def fake_recurrent_rule(query,
+                                    key,
+                                    value,
+                                    g,
+                                    beta,
+                                    initial_state=None,
+                                    output_final_state=False,
                                     use_qk_l2norm_in_kernel=False):
                 del query, key, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel
                 return value, None

@@ -321,14 +329,27 @@ def fake_conv(x, weight, bias, activation, seq_idx=None, backend=None, cu_seqlen
                 captured['cu_seqlens'] = cu_seqlens.clone() if cu_seqlens is not None else None
                 return x

-            def fake_chunk_rule(query, key, value, g, beta, initial_state=None, output_final_state=False,
-                                use_qk_l2norm_in_kernel=False, cu_seqlens=None):
+            def fake_chunk_rule(query,
+                                key,
+                                value,
+                                g,
+                                beta,
+                                initial_state=None,
+                                output_final_state=False,
+                                use_qk_l2norm_in_kernel=False,
+                                cu_seqlens=None):
                 del key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel
                 captured['query_shape'] = tuple(query.shape)
                 captured['cu_seqlens'] = cu_seqlens.clone() if cu_seqlens is not None else None
                 return query.new_zeros(query.shape[0], query.shape[1], 4, 4), None

-            def fake_recurrent_rule(query, key, value, g, beta, initial_state=None, output_final_state=False,
+            def fake_recurrent_rule(query,
+                                    key,
+                                    value,
+                                    g,
+                                    beta,
+                                    initial_state=None,
+                                    output_final_state=False,
                                     use_qk_l2norm_in_kernel=False):
                 del query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel
                 raise AssertionError('recurrent path should not be used')

@@ -368,12 +389,25 @@ def fake_conv(x, weight, bias, activation, seq_idx=None, backend=None, cu_seqlen
                 del weight, bias, activation, seq_idx, backend, cu_seqlens
                 return x

-            def fake_chunk_rule(query, key, value, g, beta, initial_state=None, output_final_state=False,
-                                use_qk_l2norm_in_kernel=False, cu_seqlens=None):
+            def fake_chunk_rule(query,
+                                key,
+                                value,
+                                g,
+                                beta,
+                                initial_state=None,
+                                output_final_state=False,
+                                use_qk_l2norm_in_kernel=False,
+                                cu_seqlens=None):
                 del query, key, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel, cu_seqlens
                 return value, None

-            def fake_recurrent_rule(query, key, value, g, beta, initial_state=None, output_final_state=False,
+            def fake_recurrent_rule(query,
+                                    key,
+                                    value,
+                                    g,
+                                    beta,
+                                    initial_state=None,
+                                    output_final_state=False,
                                     use_qk_l2norm_in_kernel=False):
                 del query, key, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel
                 return value, None

@@ -395,8 +429,12 @@ def fake_recurrent_rule(query, key, value, g, beta, initial_state=None, output_f
                     is_packed=False,
                 ))

-            def fake_linear_forward(hidden_states, cache_params=None, cache_position=None, attention_mask=None,
-                                    cu_seq_lens_q=None, sequence_parallel_context=None):
+            def fake_linear_forward(hidden_states,
+                                    cache_params=None,
+                                    cache_position=None,
+                                    attention_mask=None,
+                                    cu_seq_lens_q=None,
+                                    sequence_parallel_context=None):
                 del hidden_states, cache_params, cache_position, cu_seq_lens_q, sequence_parallel_context
                 captured['mask'] = attention_mask.clone() if attention_mask is not None else None
                 return torch.zeros(1, 2, config.hidden_size)

@@ -421,8 +459,7 @@ def test_sequence_parallel_drops_dense_attention_mask_for_flash_attention_2(self
         sp.tokenizer = SimpleNamespace(pad_token_id=0)
         sp.model_dtype = torch.bfloat16
         sp.attn_implementation = 'flash_attention_2'
-        sp.causal_mask_func = lambda *args, **kwargs: (_ for _ in ()).throw(
-            AssertionError('should not build 4d mask'))
+        sp.causal_mask_func = lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError('should not build 4d mask'))

         input_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.long)
         position_ids = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.long)
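The fake_conv / fake_chunk_rule / fake_recurrent_rule stubs above stand in for the optional fla kernels. The wiring that installs them sits outside the changed lines, but with the imports this test file uses (ExitStack, unittest.mock.patch) it would typically look like the following sketch; the attribute names mirror the module-level hooks referenced in _linear_attention_runtime_available, and run_with_fake_kernels is a hypothetical helper, not code from this commit:

from contextlib import ExitStack
from unittest.mock import patch

from twinkle.model.transformers.models.qwen3_5 import modeling_qwen3_5 as tw_qwen35


def run_with_fake_kernels(fake_conv, fake_chunk_rule, fake_recurrent_rule, fn):
    # Swap the optional fla kernels for deterministic fakes, run fn, then restore them.
    with ExitStack() as stack:
        stack.enter_context(patch.object(tw_qwen35, '_FLA_CAUSAL_CONV1D_FN', fake_conv))
        stack.enter_context(patch.object(tw_qwen35, '_FLA_CHUNK_GATED_DELTA_RULE', fake_chunk_rule))
        stack.enter_context(patch.object(tw_qwen35, '_FLA_FUSED_RECURRENT_GATED_DELTA_RULE', fake_recurrent_rule))
        return fn()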
