"""Qwen3Next architecture adapter.

Qwen3NextForCausalLM is a hybrid linear-attention + full-attention architecture
with a sparse Mixture-of-Experts MLP on every layer. Layers alternate between
GatedDeltaNet (linear attention) and standard full attention blocks, while the
MLP is always a Qwen3NextSparseMoeBlock (gate router + batched experts +
shared expert).

Since self_attn is absent on linear-attention layers, we only map submodules
that exist on ALL layers (norms, MLP). The HF native forward handles
linear/full attention dispatch internally, and MoEBridge delegates the entire
MoE forward (including router, experts, and shared expert) to the native
implementation.

Hook coverage:
- Block-level: hook_resid_pre, hook_resid_post on every layer
- Normalization: ln1 (input_layernorm), ln2 (post_attention_layernorm)
- MLP: hook_in, hook_out on the MoE block (MoEBridge)
- Attention internals are NOT individually hooked (self_attn absent on
  linear-attention layers; mapping it would crash on those layers)
- Expert-level internals are NOT individually hooked (batched expert params
  live inside Qwen3NextExperts; MoEBridge delegates to HF forward)

Optional parameters:
- n_key_value_heads: only set when using GQA (num_key_value_heads != num_attention_heads)
"""
| 27 | + |
| 28 | +from typing import Any |
| 29 | + |
| 30 | +import torch |
| 31 | + |
| 32 | +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter |
| 33 | +from transformer_lens.model_bridge.generalized_components import ( |
| 34 | + BlockBridge, |
| 35 | + EmbeddingBridge, |
| 36 | + MoEBridge, |
| 37 | + RMSNormalizationBridge, |
| 38 | + RotaryEmbeddingBridge, |
| 39 | + UnembeddingBridge, |
| 40 | +) |
| 41 | + |
| 42 | + |
class Qwen3NextArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Qwen3Next models.

    Qwen3NextForCausalLM is a hybrid linear-attention + full-attention
    architecture with sparse MoE MLPs:
    - Uses RMSNorm for all normalizations
    - Uses rotary position embeddings (RoPE) with partial rotation
    - Every 4th layer is a full-attention layer (self_attn); the rest are
      GatedDeltaNet linear-attention layers (linear_attn)
    - Uses Qwen3NextSparseMoeBlock on ALL layers (decoder_sparse_step=1 and
      mlp_only_layers=[] on every real checkpoint). The MoE block contains a
      top-K router, batched Qwen3NextExperts (experts.gate_up_proj /
      experts.down_proj as 3D tensors), plus a shared_expert (gated MLP) and
      shared_expert_gate. Each expert is internally a gated MLP.
    - No biases on any linear layers
    - Full-attention layers have Q/K normalization (q_norm, k_norm)
    - Full-attention q_proj outputs n_heads * head_dim * 2 (interleaved
      query+gate layout); the preprocess_weights method slices the query half

    Since self_attn is absent on linear-attention layers, only universally
    present submodules (norms, MLP) are mapped as block submodules. The HF
    native forward handles per-layer attention dispatch internally, and
    MoEBridge delegates the MoE forward pass (including router + experts +
    shared expert) to the native Qwen3NextSparseMoeBlock implementation.

    Optional parameters:
    - n_key_value_heads: propagated whenever the incoming config provides a
      non-None value (typically GQA checkpoints)
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the Qwen3Next architecture adapter.

        Args:
            cfg: Model configuration object. Read for ``n_key_value_heads``;
                all other attributes consumed by the base class.
        """
        super().__init__(cfg)

        # Core config attributes
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True
        self.cfg.default_prepend_bos = False

        # Disable fold_ln: ln1 is followed by self_attn on full-attention
        # layers and by linear_attn (GatedDeltaNet) on linear-attention layers,
        # but neither is mapped as a bridge submodule (see class docstring for
        # why). With no bridge-mapped target to fold into, the standard fold_ln
        # pass leaves LN weights in an inconsistent state and the processed
        # bridge output diverges from the unprocessed / HF output. Skipping
        # fold_ln keeps processed-mode forward passes numerically equivalent.
        self.supports_fold_ln = False

        # Use eager attention to support output_attentions for hook_attn_scores
        # and hook_pattern. SDPA doesn't support output_attentions.
        self.cfg.attn_implementation = "eager"

        # Propagate n_key_value_heads whenever the config provides one. Note:
        # this copies the value even when it equals n_heads (i.e. not strictly
        # GQA); downstream code treats the two cases identically.
        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.cfg.n_key_value_heads = cfg.n_key_value_heads

        self.weight_processing_conversions: dict = {}
        self.component_mapping: dict = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    # Qwen3NextSparseMoeBlock has a custom Qwen3NextTopKRouter
                    # (not an nn.Linear) as `gate`, plus batched experts and a
                    # shared expert. MoEBridge wraps the whole MoE module and
                    # delegates to HF's native forward, so we don't enumerate
                    # the internal structure here.
                    "mlp": MoEBridge(name="mlp", config=self.cfg),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """No-op for hybrid models.

        Hybrid models don't map attention as a block submodule (self_attn is
        absent on linear-attention layers), so there are no rotary embedding
        references to set up.

        Note: to find which layers are full_attention at runtime, use:
            layer_types = getattr(hf_model.config, "layer_types", [])
            first_full_attn_idx = next(
                i for i, t in enumerate(layer_types) if t == "full_attention"
            )
        Do NOT use hf_model.config.full_attention_interval -- it is not stored
        on the config object (consumed during __init__ to build layer_types).
        """

    def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Slice query half from q_proj.weight (interleaved per-head layout).

        In Qwen3Next, q_proj.weight has shape (n_heads * head_dim * 2, hidden_size).
        Rows are organized as per-head interleaved:
            head_0_query (d_head rows), head_0_gate (d_head rows),
            head_1_query (d_head rows), head_1_gate (d_head rows), ...

        A naive first-half slice would be wrong. We must reshape by head, then
        take the first d_head rows of each head (the query half).

        Note: since self_attn is NOT currently mapped as a bridge submodule,
        these weights will not be loaded by the bridge. This method is included
        for correctness and forward-compatibility.

        Args:
            state_dict: Checkpoint state dict; mutated in place.

        Returns:
            The same state dict, with every ``.self_attn.q_proj.weight`` entry
            replaced by its query half of shape (n_heads * d_head, hidden_size).
        """
        n_heads = self.cfg.n_heads
        d_head = self.cfg.d_head
        keys_to_update = [k for k in state_dict if k.endswith(".self_attn.q_proj.weight")]
        for key in keys_to_update:
            w = state_dict[key]  # shape: (n_heads * d_head * 2, hidden_size)
            # reshape (not view): view raises on non-contiguous tensors, which
            # can occur for checkpoint weights loaded via sharding/transposes.
            w = w.reshape(n_heads, d_head * 2, -1)
            # Take only the first d_head rows of each head (query half)
            state_dict[key] = w[:, :d_head, :].reshape(n_heads * d_head, -1)
        return state_dict