Skip to content

Commit 24e6b79

Browse files
authored
Qwen 3.5 Architecture Adapter (#1244)
* Qwen 3.5 architecture adapter * Qwen 3.5 architecture adapter complete * Cleaning up tests
1 parent f0dd689 commit 24e6b79

File tree

8 files changed

+1028
-6
lines changed

8 files changed

+1028
-6
lines changed

tests/unit/test_qwen3_next_adapter.py

Lines changed: 585 additions & 0 deletions
Large diffs are not rendered by default.

transformer_lens/factories/architecture_adapter_factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
PhiArchitectureAdapter,
4545
Qwen2ArchitectureAdapter,
4646
Qwen3ArchitectureAdapter,
47+
Qwen3NextArchitectureAdapter,
4748
QwenArchitectureAdapter,
4849
StableLmArchitectureAdapter,
4950
T5ArchitectureAdapter,
@@ -90,6 +91,7 @@
9091
"QwenForCausalLM": QwenArchitectureAdapter,
9192
"Qwen2ForCausalLM": Qwen2ArchitectureAdapter,
9293
"Qwen3ForCausalLM": Qwen3ArchitectureAdapter,
94+
"Qwen3NextForCausalLM": Qwen3NextArchitectureAdapter,
9395
"StableLmForCausalLM": StableLmArchitectureAdapter,
9496
"T5ForConditionalGeneration": T5ArchitectureAdapter,
9597
"NanoGPTForCausalLM": NanogptArchitectureAdapter,

transformer_lens/model_bridge/supported_architectures/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@
123123
from transformer_lens.model_bridge.supported_architectures.qwen3 import (
124124
Qwen3ArchitectureAdapter,
125125
)
126+
from transformer_lens.model_bridge.supported_architectures.qwen3_next import (
127+
Qwen3NextArchitectureAdapter,
128+
)
126129
from transformer_lens.model_bridge.supported_architectures.stablelm import (
127130
StableLmArchitectureAdapter,
128131
)
@@ -171,6 +174,7 @@
171174
"QwenArchitectureAdapter",
172175
"Qwen2ArchitectureAdapter",
173176
"Qwen3ArchitectureAdapter",
177+
"Qwen3NextArchitectureAdapter",
174178
"StableLmArchitectureAdapter",
175179
"T5ArchitectureAdapter",
176180
]
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
"""Qwen3Next architecture adapter.
2+
3+
Qwen3NextForCausalLM is a hybrid linear-attention + full-attention architecture
4+
with a sparse Mixture-of-Experts MLP on every layer. Layers alternate between
5+
GatedDeltaNet (linear attention) and standard full attention blocks, while the
6+
MLP is always a Qwen3NextSparseMoeBlock (gate router + batched experts +
7+
shared expert).
8+
9+
Since self_attn is absent on linear-attention layers, we only map submodules
10+
that exist on ALL layers (norms, MLP). The HF native forward handles
11+
linear/full attention dispatch internally, and MoEBridge delegates the entire
12+
MoE forward (including router, experts, and shared expert) to the native
13+
implementation.
14+
15+
Hook coverage:
16+
- Block-level: hook_resid_pre, hook_resid_post on every layer
17+
- Normalization: ln1 (input_layernorm), ln2 (post_attention_layernorm)
18+
- MLP: hook_in, hook_out on the MoE block (MoEBridge)
19+
- Attention internals are NOT individually hooked (self_attn absent on
20+
linear-attention layers; mapping it would crash on those layers)
21+
- Expert-level internals are NOT individually hooked (batched expert params
22+
live inside Qwen3NextExperts; MoEBridge delegates to HF forward)
23+
24+
Optional parameters:
25+
- n_key_value_heads: only set when using GQA (num_key_value_heads != num_attention_heads)
26+
"""
27+
28+
from typing import Any
29+
30+
import torch
31+
32+
from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
33+
from transformer_lens.model_bridge.generalized_components import (
34+
BlockBridge,
35+
EmbeddingBridge,
36+
MoEBridge,
37+
RMSNormalizationBridge,
38+
RotaryEmbeddingBridge,
39+
UnembeddingBridge,
40+
)
41+
42+
43+
class Qwen3NextArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Qwen3Next models.

    Qwen3NextForCausalLM is a hybrid linear-attention + full-attention
    architecture with sparse MoE MLPs, sharing the same design as Qwen3.5:

    - Uses RMSNorm for all normalizations
    - Uses rotary position embeddings (RoPE) with partial rotation
    - Every 4th layer is a full-attention layer (self_attn); the rest are
      GatedDeltaNet linear-attention layers (linear_attn)
    - Uses Qwen3NextSparseMoeBlock on ALL layers (decoder_sparse_step=1 and
      mlp_only_layers=[] on every real checkpoint). The MoE block contains a
      top-K router, batched Qwen3NextExperts (experts.gate_up_proj /
      experts.down_proj as 3D tensors), plus a shared_expert (gated MLP) and
      shared_expert_gate. Each expert is internally a gated MLP.
    - No biases on any linear layers
    - Full-attention layers have Q/K normalization (q_norm, k_norm)
    - Full-attention q_proj outputs n_heads * head_dim * 2 (interleaved
      query+gate layout); the preprocess_weights method slices the query half

    Since self_attn is absent on linear-attention layers, only universally
    present submodules (norms, MLP) are mapped as block submodules. The HF
    native forward handles per-layer attention dispatch internally, and
    MoEBridge delegates the MoE forward pass (including router + experts +
    shared expert) to the native Qwen3NextSparseMoeBlock implementation.

    Optional parameters:
    - n_key_value_heads: set when num_key_value_heads != num_attention_heads (GQA)
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the Qwen3Next architecture adapter.

        Args:
            cfg: TransformerLens-style config object; mutated in place with
                Qwen3Next-specific defaults (normalization, rotary, MoE flags).
        """
        super().__init__(cfg)

        # Core config attributes
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True
        self.cfg.default_prepend_bos = False

        # Disable fold_ln: ln1 is followed by self_attn on full-attention
        # layers and by linear_attn (GatedDeltaNet) on linear-attention layers,
        # but neither is mapped as a bridge submodule (see class docstring for
        # why). With no bridge-mapped target to fold into, the standard fold_ln
        # pass leaves LN weights in an inconsistent state and the processed
        # bridge output diverges from the unprocessed / HF output. Skipping
        # fold_ln keeps processed-mode forward passes numerically equivalent.
        self.supports_fold_ln = False

        # Use eager attention to support output_attentions for hook_attn_scores
        # and hook_pattern. SDPA doesn't support output_attentions.
        self.cfg.attn_implementation = "eager"

        # GQA: only set n_key_value_heads when using grouped-query attention
        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.cfg.n_key_value_heads = cfg.n_key_value_heads

        self.weight_processing_conversions: dict = {}
        self.component_mapping: dict = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    # Qwen3NextSparseMoeBlock has a custom Qwen3NextTopKRouter
                    # (not an nn.Linear) as `gate`, plus batched experts and a
                    # shared expert. MoEBridge wraps the whole MoE module and
                    # delegates to HF's native forward, so we don't enumerate
                    # the internal structure here.
                    "mlp": MoEBridge(name="mlp", config=self.cfg),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """No-op for hybrid models.

        Hybrid models don't map attention as a block submodule (self_attn is
        absent on linear-attention layers), so there are no rotary embedding
        references to set up.

        Note: to find which layers are full_attention at runtime, use:
            layer_types = getattr(hf_model.config, "layer_types", [])
            first_full_attn_idx = next(
                i for i, t in enumerate(layer_types) if t == "full_attention"
            )
        Do NOT use hf_model.config.full_attention_interval -- it is not stored
        on the config object (consumed during __init__ to build layer_types).
        """

    def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Slice query half from q_proj.weight (interleaved per-head layout).

        In Qwen3Next, q_proj.weight has shape (n_heads * head_dim * 2, hidden_size).
        Rows are organized as per-head interleaved:
            head_0_query (d_head rows), head_0_gate (d_head rows),
            head_1_query (d_head rows), head_1_gate (d_head rows), ...

        A naive first-half slice would be wrong. We must reshape by head, then
        take the first d_head rows of each head (the query half).

        Note: since self_attn is NOT currently mapped as a bridge submodule,
        these weights will not be loaded by the bridge. This method is included
        for correctness and forward-compatibility.

        Args:
            state_dict: model state dict; q_proj weights are replaced in place.

        Returns:
            The same state dict with each ``.self_attn.q_proj.weight`` reduced
            to its query half, shape (n_heads * d_head, hidden_size).
        """
        n_heads = self.cfg.n_heads
        d_head = self.cfg.d_head
        expected_rows = n_heads * d_head * 2
        keys_to_update = [k for k in state_dict if k.endswith(".self_attn.q_proj.weight")]
        for key in keys_to_update:
            w = state_dict[key]  # shape: (n_heads * d_head * 2, hidden_size)
            # Idempotency guard: skip weights that were already sliced (or have
            # an unexpected layout). Without this, calling preprocess_weights a
            # second time would crash or silently corrupt the weights.
            if w.shape[0] != expected_rows:
                continue
            # reshape (not view) so non-contiguous checkpoint tensors also work
            w = w.reshape(n_heads, d_head * 2, -1)
            # Take only the first d_head rows of each head (query half)
            state_dict[key] = w[:, :d_head, :].reshape(n_heads * d_head, -1)
        return state_dict

transformer_lens/tools/model_registry/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
"QwenForCausalLM",
7979
"Qwen2ForCausalLM",
8080
"Qwen3ForCausalLM",
81+
"Qwen3NextForCausalLM",
8182
"StableLmForCausalLM",
8283
"T5ForConditionalGeneration",
8384
}

transformer_lens/tools/model_registry/data/supported_models.json

Lines changed: 156 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,166 @@
11
{
2-
"generated_at": "2026-04-09",
2+
"generated_at": "2026-04-10",
33
"scan_info": {
44
"total_scanned": 6354,
55
"task_filter": "text-generation",
66
"min_downloads": 500,
7-
"scan_duration_seconds": 12.1
7+
"scan_duration_seconds": 0.0
88
},
9-
"total_architectures": 36,
10-
"total_models": 6686,
11-
"total_verified": 690,
9+
"total_architectures": 37,
10+
"total_models": 5563,
11+
"total_verified": 693,
1212
"models": [
13+
{
14+
"architecture_id": "Qwen3NextForCausalLM",
15+
"model_id": "Qwen/Qwen3-Next-80B-A3B-Instruct",
16+
"status": 2,
17+
"verified_date": "2026-04-10",
18+
"metadata": null,
19+
"note": "Estimated 6929.6 GB exceeds 96.0 GB limit",
20+
"phase1_score": null,
21+
"phase2_score": null,
22+
"phase3_score": null,
23+
"phase4_score": null,
24+
"phase7_score": null,
25+
"phase8_score": null
26+
},
27+
{
28+
"architecture_id": "Qwen3NextForCausalLM",
29+
"model_id": "unsloth/Qwen3-Coder-Next",
30+
"status": 2,
31+
"verified_date": "2026-04-10",
32+
"metadata": null,
33+
"note": "Estimated 6929.6 GB exceeds 96.0 GB limit",
34+
"phase1_score": null,
35+
"phase2_score": null,
36+
"phase3_score": null,
37+
"phase4_score": null,
38+
"phase7_score": null,
39+
"phase8_score": null
40+
},
41+
{
42+
"architecture_id": "Qwen3NextForCausalLM",
43+
"model_id": "Qwen/Qwen3-Next-80B-A3B-Thinking",
44+
"status": 2,
45+
"verified_date": "2026-04-10",
46+
"metadata": null,
47+
"note": "Estimated 6929.6 GB exceeds 96.0 GB limit",
48+
"phase1_score": null,
49+
"phase2_score": null,
50+
"phase3_score": null,
51+
"phase4_score": null,
52+
"phase7_score": null,
53+
"phase8_score": null
54+
},
55+
{
56+
"architecture_id": "Qwen3NextForCausalLM",
57+
"model_id": "tiny-random/qwen3-next-moe",
58+
"status": 1,
59+
"verified_date": "2026-04-10",
60+
"metadata": null,
61+
"note": "Full verification completed",
62+
"phase1_score": 100.0,
63+
"phase2_score": 100.0,
64+
"phase3_score": 100.0,
65+
"phase4_score": 75.7,
66+
"phase7_score": null,
67+
"phase8_score": null
68+
},
69+
{
70+
"architecture_id": "Qwen3NextForCausalLM",
71+
"model_id": "optimum-intel-internal-testing/tiny-random-qwen3-next",
72+
"status": 1,
73+
"verified_date": "2026-04-10",
74+
"metadata": null,
75+
"note": "Full verification completed",
76+
"phase1_score": 100.0,
77+
"phase2_score": 100.0,
78+
"phase3_score": 100.0,
79+
"phase4_score": 55.9,
80+
"phase7_score": null,
81+
"phase8_score": null
82+
},
83+
{
84+
"architecture_id": "Qwen3NextForCausalLM",
85+
"model_id": "yujiepan/qwen3-next-moe-tiny-random",
86+
"status": 1,
87+
"verified_date": "2026-04-10",
88+
"metadata": null,
89+
"note": "Full verification completed",
90+
"phase1_score": 100.0,
91+
"phase2_score": 100.0,
92+
"phase3_score": 100.0,
93+
"phase4_score": 75.7,
94+
"phase7_score": null,
95+
"phase8_score": null
96+
},
97+
{
98+
"architecture_id": "Qwen3NextForCausalLM",
99+
"model_id": "huihui-ai/Huihui-Qwen3-Coder-Next-abliterated",
100+
"status": 2,
101+
"verified_date": "2026-04-10",
102+
"metadata": null,
103+
"note": "Estimated 6929.6 GB exceeds 96.0 GB limit",
104+
"phase1_score": null,
105+
"phase2_score": null,
106+
"phase3_score": null,
107+
"phase4_score": null,
108+
"phase7_score": null,
109+
"phase8_score": null
110+
},
111+
{
112+
"architecture_id": "Qwen3NextForCausalLM",
113+
"model_id": "Qwen/Qwen3-Coder-Next-Base",
114+
"status": 2,
115+
"verified_date": "2026-04-10",
116+
"metadata": null,
117+
"note": "Estimated 6929.6 GB exceeds 96.0 GB limit",
118+
"phase1_score": null,
119+
"phase2_score": null,
120+
"phase3_score": null,
121+
"phase4_score": null,
122+
"phase7_score": null,
123+
"phase8_score": null
124+
},
125+
{
126+
"architecture_id": "Qwen3NextForCausalLM",
127+
"model_id": "bknyaz/Qwen3-Coder-Next-REAM",
128+
"status": 2,
129+
"verified_date": "2026-04-10",
130+
"metadata": null,
131+
"note": "Estimated 5201.5 GB exceeds 96.0 GB limit",
132+
"phase1_score": null,
133+
"phase2_score": null,
134+
"phase3_score": null,
135+
"phase4_score": null,
136+
"phase7_score": null,
137+
"phase8_score": null
138+
},
139+
{
140+
"architecture_id": "Qwen3NextForCausalLM",
141+
"model_id": "Qwen/Qwen3-Coder-Next",
142+
"status": 2,
143+
"verified_date": "2026-04-10",
144+
"metadata": {
145+
"downloads": 664116,
146+
"likes": 0,
147+
"last_modified": null,
148+
"tags": [
149+
"transformers",
150+
"safetensors",
151+
"qwen3_next",
152+
"text-generation"
153+
],
154+
"parameter_count": 79674391296
155+
},
156+
"note": "Estimated 6929.6 GB exceeds 96.0 GB limit",
157+
"phase1_score": null,
158+
"phase2_score": null,
159+
"phase3_score": null,
160+
"phase4_score": null,
161+
"phase7_score": null,
162+
"phase8_score": null
163+
},
13164
{
14165
"architecture_id": "Qwen3ForCausalLM",
15166
"model_id": "Qwen/Qwen3-0.6B",

0 commit comments

Comments
 (0)