Commit 5daba3e: Qwen3 MoE Adapter (#1245)

* Initial Qwen3MoE adapter setup
* MLP handling in verify models
* Verification of models
* Comment cleanup
* Fixed missing closing brace

1 parent 24e6b79
9 files changed: +5051 −12 lines
Lines changed: 173 additions & 0 deletions

@@ -0,0 +1,173 @@

"""Integration tests for the Qwen3MoE TransformerBridge.

Uses a tiny programmatic config on the meta device — no network access or
weight downloads. Tensor ops can't execute on meta, so forward-pass tests are
skipped and run manually during verification. Fixture pattern mirrors
tests/unit/model_bridge/test_gpt_oss_moe.py.
"""

import pytest
import torch
from transformers import AutoConfig, AutoModelForCausalLM

from transformer_lens.config import TransformerBridgeConfig
from transformer_lens.model_bridge.bridge import TransformerBridge
from transformer_lens.model_bridge.generalized_components import MoEBridge
from transformer_lens.model_bridge.sources.transformers import (
    map_default_transformer_lens_config,
)
from transformer_lens.model_bridge.supported_architectures.qwen3_moe import (
    Qwen3MoeArchitectureAdapter,
)


class _MockTokenizer:
    """Stand-in to satisfy TransformerBridge(tokenizer=...)."""

    pass


@pytest.fixture(scope="module")
def tiny_qwen3moe_config():
    """Small Qwen3MoeConfig: 2 layers, 4 heads, 4 experts."""
    return AutoConfig.for_model(
        "qwen3_moe",
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        head_dim=16,
        moe_intermediate_size=32,
        num_experts=4,
        num_experts_per_tok=2,
        vocab_size=256,
        max_position_embeddings=128,
        decoder_sparse_step=1,
        mlp_only_layers=[],
    )


@pytest.fixture(scope="module")
def tiny_qwen3moe_model_meta(tiny_qwen3moe_config):
    """Qwen3MoE model on meta device (no weights loaded)."""
    with torch.device("meta"):
        model = AutoModelForCausalLM.from_config(tiny_qwen3moe_config)
    return model


@pytest.fixture(scope="module")
def tiny_qwen3moe_bridge(tiny_qwen3moe_config, tiny_qwen3moe_model_meta):
    """TransformerBridge wrapping the tiny meta-device Qwen3MoE model."""
    tl_config = map_default_transformer_lens_config(tiny_qwen3moe_config)

    bridge_config = TransformerBridgeConfig(
        d_model=tl_config.d_model,
        d_head=tl_config.d_head,
        n_layers=tl_config.n_layers,
        n_ctx=tl_config.n_ctx,
        n_heads=tl_config.n_heads,
        n_key_value_heads=tl_config.n_key_value_heads,
        d_vocab=tl_config.d_vocab,
        architecture="Qwen3MoeForCausalLM",
    )

    adapter = Qwen3MoeArchitectureAdapter(bridge_config)

    return TransformerBridge(
        model=tiny_qwen3moe_model_meta,
        adapter=adapter,
        tokenizer=_MockTokenizer(),
    )


class TestQwen3MoeModelStructure:
    def test_model_has_layers(self, tiny_qwen3moe_model_meta) -> None:
        assert hasattr(tiny_qwen3moe_model_meta, "model")
        assert hasattr(tiny_qwen3moe_model_meta.model, "layers")
        assert len(tiny_qwen3moe_model_meta.model.layers) == 2

    def test_layer_has_sparse_moe_block(self, tiny_qwen3moe_model_meta) -> None:
        # Qwen3MoeSparseMoeBlock stores experts as batched 3D tensors, not a ModuleList
        layer0_mlp = tiny_qwen3moe_model_meta.model.layers[0].mlp
        assert hasattr(layer0_mlp, "experts")
        experts = layer0_mlp.experts
        assert hasattr(experts, "gate_up_proj")
        assert hasattr(experts, "down_proj")
        assert not hasattr(experts, "__iter__")

    def test_layer_has_gate_router(self, tiny_qwen3moe_model_meta) -> None:
        layer0_mlp = tiny_qwen3moe_model_meta.model.layers[0].mlp
        assert hasattr(layer0_mlp, "gate")

    def test_attention_has_q_norm_k_norm(self, tiny_qwen3moe_model_meta) -> None:
        attn = tiny_qwen3moe_model_meta.model.layers[0].self_attn
        assert hasattr(attn, "q_norm")
        assert hasattr(attn, "k_norm")


class TestQwen3MoeBridgeStructure:
    def test_block_count(self, tiny_qwen3moe_bridge) -> None:
        assert len(tiny_qwen3moe_bridge.blocks) == 2

    def test_has_core_components(self, tiny_qwen3moe_bridge) -> None:
        assert hasattr(tiny_qwen3moe_bridge, "embed")
        assert hasattr(tiny_qwen3moe_bridge, "unembed")
        assert hasattr(tiny_qwen3moe_bridge, "ln_final")

    def test_cfg_final_rms_is_true(self, tiny_qwen3moe_bridge) -> None:
        """Qwen3MoE uses final_rms=True; OLMoE uses False."""
        assert tiny_qwen3moe_bridge.cfg.final_rms is True

    def test_cfg_n_kv_heads(self, tiny_qwen3moe_bridge) -> None:
        assert tiny_qwen3moe_bridge.cfg.n_key_value_heads == 2

    def test_cfg_positional_embedding_type(self, tiny_qwen3moe_bridge) -> None:
        assert tiny_qwen3moe_bridge.cfg.positional_embedding_type == "rotary"

    def test_cfg_normalization_type(self, tiny_qwen3moe_bridge) -> None:
        assert tiny_qwen3moe_bridge.cfg.normalization_type == "RMS"

    def test_mlp_blocks_are_moe_bridge(self, tiny_qwen3moe_bridge) -> None:
        for i, block in enumerate(tiny_qwen3moe_bridge.blocks):
            assert isinstance(
                block.mlp, MoEBridge
            ), f"Block {i} mlp is {type(block.mlp).__name__}, expected MoEBridge"

    def test_moe_bridge_has_router_scores_hook(self, tiny_qwen3moe_bridge) -> None:
        mlp = tiny_qwen3moe_bridge.blocks[0].mlp
        assert hasattr(mlp, "hook_router_scores")

    def test_block_has_ln1_and_ln2(self, tiny_qwen3moe_bridge) -> None:
        block = tiny_qwen3moe_bridge.blocks[0]
        assert hasattr(block, "ln1")
        assert hasattr(block, "ln2")

    def test_block_attn_has_q_norm_k_norm(self, tiny_qwen3moe_bridge) -> None:
        attn = tiny_qwen3moe_bridge.blocks[0].attn
        assert hasattr(attn, "q_norm")
        assert hasattr(attn, "k_norm")


# Forward-pass tests require real weights — meta-device tensor ops raise
# NotImplementedError. Run these manually during Step 3 verification.


@pytest.mark.skip(reason="Requires real weights — run manually during verification")
def test_forward_pass_matches_hf(tiny_qwen3moe_bridge) -> None:
    """Bridge logits match the HF model."""
    tokens = torch.tensor([[1, 2, 3, 4]])
    with torch.no_grad():
        bridge_out = tiny_qwen3moe_bridge(tokens)
        hf_out = tiny_qwen3moe_bridge.original_model(tokens).logits
    max_diff = (bridge_out - hf_out).abs().max().item()
    assert max_diff < 1e-4, f"Bridge vs HF max diff = {max_diff}"


@pytest.mark.skip(reason="Requires real weights — run manually during verification")
def test_run_with_cache_captures_moe_router_scores(tiny_qwen3moe_bridge) -> None:
    """MoEBridge captures router scores in the activation cache."""
    tiny_qwen3moe_bridge.enable_compatibility_mode(no_processing=True)
    tokens = torch.tensor([[1, 2, 3, 4]])
    _, cache = tiny_qwen3moe_bridge.run_with_cache(tokens)
    for i in range(len(tiny_qwen3moe_bridge.blocks)):
        assert f"blocks.{i}.mlp.hook_router_scores" in cache, f"Missing router scores for block {i}"
Lines changed: 194 additions & 0 deletions

@@ -0,0 +1,194 @@

"""Unit tests for the Qwen3MoeArchitectureAdapter.

All tests use programmatic TransformerBridgeConfig instances — no network access
or model downloads.
"""

import pytest

from transformer_lens.config import TransformerBridgeConfig
from transformer_lens.conversion_utils.conversion_steps.rearrange_tensor_conversion import (
    RearrangeTensorConversion,
)
from transformer_lens.conversion_utils.param_processing_conversion import (
    ParamProcessingConversion,
)
from transformer_lens.factories.architecture_adapter_factory import (
    SUPPORTED_ARCHITECTURES,
)
from transformer_lens.model_bridge.generalized_components import (
    MoEBridge,
    RMSNormalizationBridge,
)
from transformer_lens.model_bridge.supported_architectures.qwen3_moe import (
    Qwen3MoeArchitectureAdapter,
)


@pytest.fixture
def cfg() -> TransformerBridgeConfig:
    return TransformerBridgeConfig(
        d_model=64,
        d_head=16,
        n_layers=2,
        n_ctx=128,
        n_heads=4,
        n_key_value_heads=2,
        d_vocab=256,
        architecture="Qwen3MoeForCausalLM",
    )


@pytest.fixture
def adapter(cfg: TransformerBridgeConfig) -> Qwen3MoeArchitectureAdapter:
    return Qwen3MoeArchitectureAdapter(cfg)


class TestQwen3MoeAdapterConfig:
    def test_normalization_type_is_rms(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
        assert adapter.cfg.normalization_type == "RMS"

    def test_positional_embedding_type_is_rotary(
        self, adapter: Qwen3MoeArchitectureAdapter
    ) -> None:
        assert adapter.cfg.positional_embedding_type == "rotary"

    def test_final_rms_is_true(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
        """Qwen3MoE uses final_rms=True; OLMoE uses False."""
        assert adapter.cfg.final_rms is True

    def test_gated_mlp_is_true(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
        assert adapter.cfg.gated_mlp is True

    def test_uses_rms_norm_is_true(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
        assert adapter.cfg.uses_rms_norm is True

    def test_attn_implementation_is_eager(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
        assert adapter.cfg.attn_implementation == "eager"

    def test_default_prepend_bos_is_false(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
        assert adapter.cfg.default_prepend_bos is False

    def test_n_kv_heads_propagated(self) -> None:
        """n_key_value_heads from the loaded config is preserved."""
        cfg = TransformerBridgeConfig(
            d_model=64,
            d_head=16,
            n_layers=2,
            n_ctx=128,
            n_heads=4,
            n_key_value_heads=2,
            d_vocab=256,
            architecture="Qwen3MoeForCausalLM",
        )
        adapter = Qwen3MoeArchitectureAdapter(cfg)
        assert adapter.cfg.n_key_value_heads == 2


class TestQwen3MoeWeightConversions:
    def test_has_qkvo_keys(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
        convs = adapter.weight_processing_conversions
        assert convs is not None
        assert "blocks.{i}.attn.q.weight" in convs
        assert "blocks.{i}.attn.k.weight" in convs
        assert "blocks.{i}.attn.v.weight" in convs
        assert "blocks.{i}.attn.o.weight" in convs

    def test_q_rearrange_uses_n_heads(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
        """Q rearrange uses n_heads (4)."""
        convs = adapter.weight_processing_conversions
        assert convs is not None
        q_conv = convs["blocks.{i}.attn.q.weight"]
        assert isinstance(q_conv, ParamProcessingConversion)
        assert isinstance(q_conv.tensor_conversion, RearrangeTensorConversion)
        axes = q_conv.tensor_conversion.axes_lengths
        assert axes.get("n") == 4

    def test_kv_rearrange_uses_n_kv_heads(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
        """K/V rearrange uses n_key_value_heads (2) for GQA."""
        convs = adapter.weight_processing_conversions
        assert convs is not None
        k_conv = convs["blocks.{i}.attn.k.weight"]
        v_conv = convs["blocks.{i}.attn.v.weight"]
        assert isinstance(k_conv, ParamProcessingConversion)
        assert isinstance(v_conv, ParamProcessingConversion)
        assert isinstance(k_conv.tensor_conversion, RearrangeTensorConversion)
        assert isinstance(v_conv.tensor_conversion, RearrangeTensorConversion)
        assert k_conv.tensor_conversion.axes_lengths.get("n") == 2
        assert v_conv.tensor_conversion.axes_lengths.get("n") == 2

    def test_o_rearrange_uses_n_heads(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
        """O rearrange uses n_heads (4)."""
        convs = adapter.weight_processing_conversions
        assert convs is not None
        o_conv = convs["blocks.{i}.attn.o.weight"]
        assert isinstance(o_conv, ParamProcessingConversion)
        assert isinstance(o_conv.tensor_conversion, RearrangeTensorConversion)
        assert o_conv.tensor_conversion.axes_lengths.get("n") == 4


class TestQwen3MoeComponentMapping:
    def test_has_required_top_level_keys(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
        mapping = adapter.component_mapping
        assert mapping is not None
        for key in ("embed", "rotary_emb", "blocks", "ln_final", "unembed"):
            assert key in mapping, f"Missing top-level key: {key!r}"

    def test_blocks_has_required_submodules(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
        mapping = adapter.component_mapping
        assert mapping is not None
        blocks = mapping["blocks"]
        for key in ("ln1", "ln2", "attn", "mlp"):
            assert key in blocks.submodules, f"Missing blocks submodule: {key!r}"

    def test_attn_has_all_submodules(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
        mapping = adapter.component_mapping
        assert mapping is not None
        attn = mapping["blocks"].submodules["attn"]
        for key in ("q", "k", "v", "o", "q_norm", "k_norm"):
            assert key in attn.submodules, f"Missing attn submodule: {key!r}"

    def test_ln1_ln2_are_rms_norm_bridges(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
        mapping = adapter.component_mapping
        assert mapping is not None
        subs = mapping["blocks"].submodules
        assert isinstance(subs["ln1"], RMSNormalizationBridge)
        assert isinstance(subs["ln2"], RMSNormalizationBridge)

    def test_mlp_is_moe_bridge(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
        mapping = adapter.component_mapping
        assert mapping is not None
        mlp = mapping["blocks"].submodules["mlp"]
        assert isinstance(mlp, MoEBridge)

    def test_mlp_has_gate_submodule(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
        mapping = adapter.component_mapping
        assert mapping is not None
        mlp = mapping["blocks"].submodules["mlp"]
        assert "gate" in mlp.submodules

    def test_q_norm_k_norm_are_rms_norm_bridges(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
        mapping = adapter.component_mapping
        assert mapping is not None
        attn_subs = mapping["blocks"].submodules["attn"].submodules
        assert isinstance(attn_subs["q_norm"], RMSNormalizationBridge)
        assert isinstance(attn_subs["k_norm"], RMSNormalizationBridge)

    def test_hf_module_paths(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
        """HF module path names are mapped correctly."""
        mapping = adapter.component_mapping
        assert mapping is not None
        assert mapping["embed"].name == "model.embed_tokens"
        assert mapping["ln_final"].name == "model.norm"
        assert mapping["unembed"].name == "lm_head"
        assert mapping["blocks"].name == "model.layers"
        subs = mapping["blocks"].submodules
        assert subs["ln1"].name == "input_layernorm"
        assert subs["ln2"].name == "post_attention_layernorm"
        assert subs["attn"].name == "self_attn"
        assert subs["mlp"].name == "mlp"


class TestQwen3MoeFactoryRegistration:
    def test_factory_lookup_returns_adapter_class(self) -> None:
        assert SUPPORTED_ARCHITECTURES["Qwen3MoeForCausalLM"] is Qwen3MoeArchitectureAdapter
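
The GQA tests above only pin down the axes_lengths, asserting n=4 for Q/O but n=2 for K/V. A short einops sketch of why those differ; the pattern string "(n h) m -> n m h" is an assumption (the conventional HF-to-TransformerLens weight layout), not read from the adapter:

    # Illustration of the head-count bookkeeping the conversion tests check.
    import torch
    from einops import rearrange

    d_model, d_head, n_heads, n_kv_heads = 64, 16, 4, 2

    q_hf = torch.randn(n_heads * d_head, d_model)     # HF q_proj.weight: (out, in)
    k_hf = torch.randn(n_kv_heads * d_head, d_model)  # GQA: K/V have fewer heads

    q_tl = rearrange(q_hf, "(n h) m -> n m h", n=n_heads)     # (4, 64, 16)
    k_tl = rearrange(k_hf, "(n h) m -> n m h", n=n_kv_heads)  # (2, 64, 16)

    # Passing n=4 to the K rearrange would split its 32 rows into 4 heads of
    # dimension 8 instead of 2 heads of dimension 16, which is why the tests
    # assert the "n" axis length separately for Q/O (4) and K/V (2).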

transformer_lens/factories/architecture_adapter_factory.py

Lines changed: 2 additions & 0 deletions

@@ -44,6 +44,7 @@
     PhiArchitectureAdapter,
     Qwen2ArchitectureAdapter,
     Qwen3ArchitectureAdapter,
+    Qwen3MoeArchitectureAdapter,
     Qwen3NextArchitectureAdapter,
     QwenArchitectureAdapter,
     StableLmArchitectureAdapter,
@@ -91,6 +92,7 @@
     "QwenForCausalLM": QwenArchitectureAdapter,
     "Qwen2ForCausalLM": Qwen2ArchitectureAdapter,
     "Qwen3ForCausalLM": Qwen3ArchitectureAdapter,
+    "Qwen3MoeForCausalLM": Qwen3MoeArchitectureAdapter,
     "Qwen3NextForCausalLM": Qwen3NextArchitectureAdapter,
     "StableLmForCausalLM": StableLmArchitectureAdapter,
     "T5ForConditionalGeneration": T5ArchitectureAdapter,
