From 42bc296a60f1d2b7cd8492165b28617a30725c4c Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Sun, 8 Feb 2026 18:32:47 +0000 Subject: [PATCH] add Signed-off-by: Roger Wang --- vllm/model_executor/models/qwen3_5_moe.py | 1527 +++++++++++++++++++++ vllm/model_executor/models/registry.py | 6 +- 2 files changed, 1532 insertions(+), 1 deletion(-) create mode 100644 vllm/model_executor/models/qwen3_5_moe.py diff --git a/vllm/model_executor/models/qwen3_5_moe.py b/vllm/model_executor/models/qwen3_5_moe.py new file mode 100644 index 000000000000..dac242a8c159 --- /dev/null +++ b/vllm/model_executor/models/qwen3_5_moe.py @@ -0,0 +1,1527 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only Qwen3.5-MoE model (VL + hybrid GDN/attention + MoE).""" + +import typing +from collections.abc import Callable, Iterable +from itertools import islice + +import torch +from einops import rearrange +from torch import nn +from transformers.activations import ACT2FN + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import ( + CacheConfig, + ModelConfig, + SpeculativeConfig, + VllmConfig, + get_current_vllm_config, +) +from vllm.distributed import ( + divide, + get_ep_group, + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.layernorm import ( + GemmaRMSNorm as Qwen3_5MoeRMSNorm, +) +from vllm.model_executor.layers.layernorm import RMSNormGated +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.abstract import MambaBase +from 
vllm.model_executor.layers.mamba.mamba_mixer2 import ( + mamba_v2_sharded_weight_loader, +) +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, + sharded_weight_loader, +) +from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3_5MoeMLP +from vllm.model_executor.models.qwen3_next import ( + Qwen3NextAttention, + fused_gdn_gating, +) +from vllm.model_executor.models.utils import sequence_parallel_chunk +from vllm.model_executor.utils import set_weight_attrs +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors + +from .interfaces import ( + HasInnerState, + IsHybrid, + MixtureOfExperts, + SupportsLoRA, + SupportsPP, +) +from .qwen3_vl import ( + Qwen3_VisionTransformer, + Qwen3VLDummyInputsBuilder, + Qwen3VLForConditionalGeneration, + Qwen3VLMultiModalProcessor, + Qwen3VLProcessingInfo, +) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + WeightsMapper, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) + +logger = init_logger(__name__) + +KVCache = tuple[torch.Tensor, torch.Tensor] + + +# --------------------------------------------------------------------------- +# Gated Delta Net (linear attention) with SEPARATE projections +# --------------------------------------------------------------------------- +class Qwen3_5MoeGatedDeltaNet(nn.Module, MambaBase): + """Gated Delta Net for Qwen3.5-MoE. 
+ + Key difference from Qwen3NextGatedDeltaNet: uses 4 separate projection + layers (in_proj_qkv, in_proj_z, in_proj_b, in_proj_a) instead of 2 + combined projections (in_proj_qkvz, in_proj_ba). + """ + + @property + def mamba_type(self) -> str: + return "gdn_attention" + + def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: + return MambaStateDtypeCalculator.gated_delta_net_state_dtype( + self.model_config.dtype, self.cache_config.mamba_cache_dtype + ) + + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: + return MambaStateShapeCalculator.gated_delta_net_state_shape( + self.tp_size, + self.num_k_heads, + self.num_v_heads, + self.head_k_dim, + self.head_v_dim, + self.conv_kernel_size, + self.num_spec, + ) + + def __init__( + self, + config, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + speculative_config: SpeculativeConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + + self.conv_kernel_size = config.linear_conv_kernel_dim + self.layer_idx = extract_layer_index(prefix) + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.layer_norm_epsilon = config.rms_norm_eps + self.prefix = prefix + + self.config = config + self.model_config = model_config + self.cache_config = cache_config + self.quant_config = quant_config + self.speculative_config = speculative_config + self.num_spec = ( + self.speculative_config.num_speculative_tokens + if 
self.speculative_config + else 0 + ) + + # ---- Convolution ---- + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.conv_dim, + bias=False, + prefix=f"{prefix}.conv1d", + ) + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + query_key_settings = (self.key_dim, 0, False) + value_settings = (self.value_dim, 0, False) + + delattr(self.conv1d.weight, "weight_loader") + set_weight_attrs( + self.conv1d.weight, + { + "weight_loader": mamba_v2_sharded_weight_loader( + [query_key_settings, query_key_settings, value_settings], + self.tp_size, + self.tp_rank, + ) + }, + ) + + # ---- Separate projections (Qwen3.5 style) ---- + # in_proj_qkv: projects hidden → Q + K + V + self.qkv_proj_size = self.key_dim * 2 + self.value_dim + self.in_proj_qkv = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=self.qkv_proj_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.in_proj_qkv", + ) + # Custom weight loader for non-uniform Q/K/V split + delattr(self.in_proj_qkv.weight, "weight_loader") + set_weight_attrs( + self.in_proj_qkv.weight, + { + "weight_loader": mamba_v2_sharded_weight_loader( + [query_key_settings, query_key_settings, value_settings], + self.tp_size, + self.tp_rank, + ) + }, + ) + + # in_proj_z: projects hidden → Z (gating, same dim as V) + self.in_proj_z = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=self.value_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.in_proj_z", + ) + + # in_proj_b: projects hidden → beta (per v-head) + self.in_proj_b = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=self.num_v_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.in_proj_b", + ) + + # in_proj_a: projects hidden → alpha (per v-head) + self.in_proj_a = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=self.num_v_heads, + bias=False, + 
quant_config=quant_config, + prefix=f"{prefix}.in_proj_a", + ) + + # ---- Selective projection parameters ---- + self.dt_bias = nn.Parameter( + torch.ones(self.num_v_heads // self.tp_size), + ) + self.A_log = nn.Parameter( + torch.empty(divide(self.num_v_heads, self.tp_size)), + ) + + set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)}) + set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) + + self.norm = RMSNormGated( + self.head_v_dim, + eps=self.layer_norm_epsilon, + group_size=None, + norm_before_gate=True, + device=current_platform.current_device(), + dtype=config.dtype, + ) + + self.out_proj = RowParallelLinear( + self.value_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + def rearrange_mixed_qkv(self, mixed_qkv): + if mixed_qkv is None: + return None, None, None + query, key, value = torch.split( + mixed_qkv, + [ + self.key_dim // self.tp_size, + self.key_dim // self.tp_size, + self.value_dim // self.tp_size, + ], + dim=-1, + ) + query, key = map( + lambda x: rearrange(x, "l (h d) -> 1 l h d", d=self.head_k_dim), + (query, key), + ) + value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim) + return query.contiguous(), key.contiguous(), value.contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + ): + num_tokens = hidden_states.size(0) + + # ---- Input Projection (separate projections) ---- + mixed_qkv, _ = self.in_proj_qkv(hidden_states) + z, _ = self.in_proj_z(hidden_states) + b, _ = self.in_proj_b(hidden_states) + a, _ = self.in_proj_a(hidden_states) + + # Reshape z to [num_tokens, num_v_heads//tp, head_v_dim] + z = z.view(num_tokens, 
self.num_v_heads // self.tp_size, self.head_v_dim) + + # ---- Core Attention (Custom Op) ---- + core_attn_out = torch.zeros( + (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + torch.ops.vllm.gdn_attention_core( + mixed_qkv, + b, + a, + core_attn_out, + self.prefix, + ) + + # ---- Output Projection ---- + z_shape_og = z.shape + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") + output[:num_tokens], _ = self.out_proj(core_attn_out) + + def _forward_core( + self, + mixed_qkv: torch.Tensor, + b: torch.Tensor, + a: torch.Tensor, + core_attn_out: torch.Tensor, + ): + """Core attention computation (called by custom op). + + This is identical to Qwen3NextGatedDeltaNet._forward_core. + """ + from vllm.model_executor.layers.fla.ops import ( + chunk_gated_delta_rule, + fused_recurrent_gated_delta_rule, + ) + from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, + causal_conv1d_update, + ) + from vllm.v1.attention.backend import AttentionMetadata + from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata + + forward_context = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + + if attn_metadata is None: + return + + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, GDNAttentionMetadata) + has_initial_state = attn_metadata.has_initial_state + spec_query_start_loc = attn_metadata.spec_query_start_loc + non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc + spec_sequence_masks = attn_metadata.spec_sequence_masks + spec_token_indx = attn_metadata.spec_token_indx + non_spec_token_indx = attn_metadata.non_spec_token_indx + 
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor + non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + num_actual_tokens = attn_metadata.num_actual_tokens + num_accepted_tokens = attn_metadata.num_accepted_tokens + + mixed_qkv = mixed_qkv[:num_actual_tokens] + b = b[:num_actual_tokens] + a = a[:num_actual_tokens] + + # 1. Convolution sequence transformation + conv_weights = self.conv1d.weight.view( + self.conv1d.weight.size(0), self.conv1d.weight.size(2) + ) + + if spec_sequence_masks is not None: + if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: + mixed_qkv_spec = mixed_qkv + mixed_qkv_non_spec = None + else: + mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx) + mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx) + else: + mixed_qkv_spec = None + mixed_qkv_non_spec = mixed_qkv + + # 1.1: Process the multi-query part + if spec_sequence_masks is not None: + mixed_qkv_spec = causal_conv1d_update( + mixed_qkv_spec, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=spec_state_indices_tensor[:, 0][ + : attn_metadata.num_spec_decodes + ], + num_accepted_tokens=num_accepted_tokens, + query_start_loc=spec_query_start_loc, + max_query_len=spec_state_indices_tensor.size(-1), + validate_data=False, + ) + + # 1.2: Process the remaining part + if attn_metadata.num_prefills > 0: + mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1) + mixed_qkv_non_spec = causal_conv1d_fn( + mixed_qkv_non_spec_T, + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=conv_state, + has_initial_state=has_initial_state, + cache_indices=non_spec_state_indices_tensor, + query_start_loc=non_spec_query_start_loc, + metadata=attn_metadata, + ).transpose(0, 1) + elif attn_metadata.num_decodes > 0: + 
mixed_qkv_non_spec = causal_conv1d_update( + mixed_qkv_non_spec, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=non_spec_state_indices_tensor[ + : attn_metadata.num_decodes + ], + validate_data=True, + ) + else: + mixed_qkv_non_spec = None + + query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec) + query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv( + mixed_qkv_non_spec + ) + + g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias) + + if spec_sequence_masks is not None: + if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: + g_spec = g + beta_spec = beta + g_non_spec = None + beta_non_spec = None + else: + g_spec = g.index_select(1, spec_token_indx) + beta_spec = beta.index_select(1, spec_token_indx) + g_non_spec = g.index_select(1, non_spec_token_indx) + beta_non_spec = beta.index_select(1, non_spec_token_indx) + else: + g_spec = None + beta_spec = None + g_non_spec = g + beta_non_spec = beta + + # 2. Recurrent attention + + # 2.1: Process the multi-query part + if spec_sequence_masks is not None: + core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule( + q=query_spec, + k=key_spec, + v=value_spec, + g=g_spec, + beta=beta_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1], + ssm_state_indices=spec_state_indices_tensor, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=True, + ) + else: + core_attn_out_spec, last_recurrent_state = None, None + + # 2.2: Process the remaining part + if attn_metadata.num_prefills > 0: + initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() + initial_state[~has_initial_state, ...] 
= 0 + ( + core_attn_out_non_spec, + last_recurrent_state, + ) = chunk_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=initial_state, + output_final_state=True, + cu_seqlens=non_spec_query_start_loc, + head_first=False, + use_qk_l2norm_in_kernel=True, + ) + ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( + ssm_state.dtype + ) + elif attn_metadata.num_decodes > 0: + core_attn_out_non_spec, last_recurrent_state = ( + fused_recurrent_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=non_spec_query_start_loc[ + : attn_metadata.num_decodes + 1 + ], + ssm_state_indices=non_spec_state_indices_tensor, + use_qk_l2norm_in_kernel=True, + ) + ) + else: + core_attn_out_non_spec, last_recurrent_state = None, None + + # 3. Merge core attention output + if spec_sequence_masks is not None and core_attn_out_non_spec is not None: + merged_out = torch.empty( + (1, num_actual_tokens, *core_attn_out_spec.shape[2:]), + dtype=core_attn_out_non_spec.dtype, + device=core_attn_out_non_spec.device, + ) + merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec) + merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec) + core_attn_out[:num_actual_tokens] = merged_out.squeeze(0) + elif spec_sequence_masks is not None: + core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0) + else: + core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0) + + +# --------------------------------------------------------------------------- +# Sparse MoE block – uses hf_text_config and handles missing norm_topk_prob +# --------------------------------------------------------------------------- +class Qwen3_5MoeSparseMoeBlock(nn.Module): + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = 
vllm_config.model_config.hf_text_config + parallel_config = vllm_config.parallel_config + quant_config = vllm_config.quant_config + + self.tp_size = get_tensor_model_parallel_world_size() + + self.ep_group = get_ep_group().device_group + self.ep_rank = get_ep_group().rank_in_group + self.ep_size = self.ep_group.size() + self.n_routed_experts = config.num_experts + + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe + + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}." + ) + + vllm_config = get_current_vllm_config() + eplb_config = vllm_config.parallel_config.eplb_config + self.enable_eplb = parallel_config.enable_eplb + + self.n_logical_experts = self.n_routed_experts + self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate", + ) + + self.shared_expert_gate = ReplicatedLinear( + config.hidden_size, + 1, + bias=False, + quant_config=None, + prefix=f"{prefix}.shared_expert_gate", + ) + + shared_expert_intermediate_size = getattr( + config, "shared_expert_intermediate_size", 0 + ) + if shared_expert_intermediate_size > 0: + self.shared_expert = Qwen3_5MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=shared_expert_intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + expert_gate=self.shared_expert_gate, + prefix=f"{prefix}.shared_expert", + ) + else: + self.shared_expert = None + + self.experts = 
SharedFusedMoE( + shared_experts=self.shared_expert, + gate=self.gate, + num_experts=self.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=getattr(config, "norm_topk_prob", True), + quant_config=quant_config, + prefix=f"{prefix}.experts", + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + orig_shape = hidden_states.shape + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + if self.is_sequence_parallel: + hidden_states = sequence_parallel_chunk(hidden_states) + + if self.experts.is_internal_router: + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=hidden_states + ) + else: + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + + if self.shared_expert is not None: + final_hidden_states = final_hidden_states[0] + final_hidden_states[1] + + if self.is_sequence_parallel: + final_hidden_states = tensor_model_parallel_all_gather( + final_hidden_states, 0 + ) + final_hidden_states = final_hidden_states[:num_tokens] + elif self.tp_size > 1: + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states + ) + + return final_hidden_states.view(orig_shape) + + +# --------------------------------------------------------------------------- +# Decoder layer – always MoE, dispatch to GDN or full attention +# --------------------------------------------------------------------------- +class Qwen3_5MoeDecoderLayer(nn.Module): + def __init__( + self, + vllm_config: VllmConfig, + layer_type: str, + prefix: str = "", + ) -> None: + super().__init__() + + config = 
vllm_config.model_config.hf_text_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + speculative_config = vllm_config.speculative_config + + self.layer_type = layer_type + self.layer_idx = extract_layer_index(prefix) + + if self.layer_type == "linear_attention": + self.linear_attn = Qwen3_5MoeGatedDeltaNet( + config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + speculative_config=speculative_config, + prefix=f"{prefix}.linear_attn", + ) + elif self.layer_type == "full_attention": + self.self_attn = Qwen3NextAttention( + config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + else: + raise ValueError(f"Invalid layer_type {self.layer_type}") + + # Always MoE (no dense MLP fallback) + self.mlp = Qwen3_5MoeSparseMoeBlock( + vllm_config=vllm_config, + prefix=f"{prefix}.mlp", + ) + + self.input_layernorm = Qwen3_5MoeRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = Qwen3_5MoeRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + positions: torch.Tensor = None, + **kwargs: object, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + self_attention_output = torch.empty_like(hidden_states) + if self.layer_type == "linear_attention": + self.linear_attn( + hidden_states=hidden_states, + output=self_attention_output, + ) + elif self.layer_type == "full_attention": + self.self_attn( + hidden_states=hidden_states, + output=self_attention_output, + positions=positions, + ) + else: + raise ValueError("Invalid layer_type") + hidden_states = self_attention_output + + # Fully Connected (always MoE) + 
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +# --------------------------------------------------------------------------- +# Model backbone +# --------------------------------------------------------------------------- +@support_torch_compile +class Qwen3_5MoeModel(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_text_config + parallel_config = vllm_config.parallel_config + + eplb_config = parallel_config.eplb_config + self.num_redundant_experts = eplb_config.num_redundant_experts + + self.config = config + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + ) + + self.aux_hidden_state_layers: tuple[int, ...] = () + + def get_layer(prefix: str): + return Qwen3_5MoeDecoderLayer( + vllm_config, + layer_type=config.layer_types[extract_layer_index(prefix)], + prefix=prefix, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + + if get_pp_group().is_last_rank: + self.norm = Qwen3_5MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_input_ids(input_ids) + residual = None + else: 
+ assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return SharedFusedMoE.make_expert_params_mapping( + self, + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + num_redundant_experts=self.num_redundant_experts, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if name.startswith("mtp."): + continue + + if name.endswith("scale"): + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + 
weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict: + logger.warning_once( + "Parameter %s not found in params_dict, skip loading", + name, + ) + continue + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +# --------------------------------------------------------------------------- +# MoE interface for text model +# --------------------------------------------------------------------------- +class Qwen3_5MoeMixtureOfExperts(MixtureOfExperts): + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + for layer in self.model.layers: + if isinstance(layer.mlp, Qwen3_5MoeSparseMoeBlock): + moe = layer.mlp + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + 
moe.experts.update_expert_map() + + def set_moe_parameters(self): + self.expert_weights = [] + self.moe_layers = [] + example_moe = None + for layer in self.model.layers: + if isinstance(layer, Qwen3_5MoeDecoderLayer) and isinstance( + layer.mlp, Qwen3_5MoeSparseMoeBlock + ): + example_moe = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_moe is None: + raise RuntimeError("No Qwen3_5Moe MoE layer found in model.layers.") + + self.num_moe_layers = len(self.moe_layers) + self.num_expert_groups = 1 + self.num_shared_experts = 0 + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + +# --------------------------------------------------------------------------- +# CausalLM (text model) +# --------------------------------------------------------------------------- +class Qwen3_5MoeForCausalLM( + nn.Module, + HasInnerState, + SupportsLoRA, + SupportsPP, + Qwen3_5MoeMixtureOfExperts, + IsHybrid, +): + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_text_config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + scheduler_config = vllm_config.scheduler_config + + if cache_config.mamba_cache_mode == "all": + raise NotImplementedError( + "Qwen3_5Moe currently does not support 'all' prefix caching, " + "please use '--mamba-cache-mode=align' instead" + ) + self.quant_config = vllm_config.quant_config + + super().__init__() + self.config = config + self.scheduler_config = scheduler_config + self.model = Qwen3_5MoeModel( + vllm_config=vllm_config, 
prefix=maybe_prefix(prefix, "model") + ) + + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + self.set_moe_parameters() + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ): + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + return hidden_states + + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + return MambaStateDtypeCalculator.gated_delta_net_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + ) + + @classmethod + def get_mamba_state_shape_from_config( + cls, vllm_config: "VllmConfig" + ) -> tuple[tuple[int, int], tuple[int, int]]: + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_text_config + tp_size = parallel_config.tensor_parallel_size + num_spec = ( + vllm_config.speculative_config.num_speculative_tokens + if vllm_config.speculative_config + else 0 + ) + return MambaStateShapeCalculator.gated_delta_net_state_shape( + tp_size, + hf_config.linear_num_key_heads, + hf_config.linear_num_value_heads, + hf_config.linear_key_head_dim, + hf_config.linear_value_head_dim, + hf_config.linear_conv_kernel_dim, + num_spec, + ) + + @classmethod + def get_mamba_state_copy_func( + cls, + ) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.gated_delta_net_state_copy_func() + + def compute_logits( + self, + hidden_states: 
torch.Tensor, + ) -> torch.Tensor | None: + return self.logits_processor(self.lm_head, hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self, skip_prefixes=["mtp."]) + return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() + + +# =================================================================== +# Vision-Language (VL) wrappers +# =================================================================== + + +class Qwen3_5MoeVLProcessingInfo(Qwen3VLProcessingInfo): + def get_hf_config(self): + from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import ( + Qwen3_5MoeConfig, + ) + + return self.ctx.get_hf_config(Qwen3_5MoeConfig) + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + "deepstack_input_embeds": 0, + } +) +class Qwen3_5MoeLLMModel(Qwen3_5MoeModel): + """Qwen3_5MoeModel with deepstack support for VL usage.""" + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + vision_config = vllm_config.model_config.hf_config.vision_config + if not get_pp_group().is_first_rank and hasattr( + vision_config, "deepstack_visual_indexes" + ): + assert self.start_layer >= len(vision_config.deepstack_visual_indexes), ( + "start_layer should be >= len(deepstack_visual_indexes)" + ) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + deepstack_input_embeds: IntermediateTensors | None = None, + ) -> torch.Tensor | IntermediateTensors: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_input_ids(input_ids) + residual = None + 
else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + aux_hidden_states = [] + for layer_idx, layer in islice( + enumerate(self.layers), self.start_layer, self.end_layer + ): + if layer_idx in self.aux_hidden_state_layers: + aux_hidden_states.append(hidden_states + residual) + + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + ) + + if deepstack_input_embeds is not None and layer_idx in range( + 0, len(deepstack_input_embeds) + ): + hidden_states = ( + hidden_states + + deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"] + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + hidden_states, _ = self.norm(hidden_states, residual) + + if len(aux_hidden_states) > 0: + return hidden_states, aux_hidden_states + return hidden_states + + def load_fused_expert_weights( + self, + name: str, + params_dict: dict, + loaded_weight: torch.Tensor, + shard_id: str, + num_experts: int, + ) -> bool: + param = params_dict[name] + weight_loader = typing.cast(Callable[..., bool], param.weight_loader) + loaded_local_expert = False + for expert_id in range(num_experts): + curr_expert_weight = loaded_weight[expert_id] + success = weight_loader( + param, + curr_expert_weight, + name, + shard_id, + expert_id, + return_success=True, + ) + if success: + loaded_local_expert = True + return loaded_local_expert + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + ignore_suffixes = ( + ".bias", + "_bias", + ".k_scale", + "_k_scale", + ".v_scale", + "_v_scale", + ".weight_scale", + "_weight_scale", + ".input_scale", + 
"_input_scale", + ) + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() + is_fused_expert = False + fused_expert_params_mapping = [ + ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"), + ("experts.w2_weight", "experts.down_proj", 0, "w2"), + ] + num_experts = self.config.num_experts + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if "experts.gate_up_proj" in name or "experts.down_proj" in name: + is_fused_expert = True + expert_params_mapping = fused_expert_params_mapping + + if weight_name not in name: + continue + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + if name.endswith(ignore_suffixes) and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + if name.endswith("scale"): + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + if weight_loader == default_weight_loader: + weight_loader(param, loaded_weight) + else: + weight_loader(param, loaded_weight, shard_id) + break + else: + is_expert_weight = False + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + is_expert_weight = True + name_mapped = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name_mapped, self): + continue + if is_fused_expert: + if "experts.gate_up_proj" in name: + loaded_weight = loaded_weight.chunk(2, dim=-2) + success_w1 = self.load_fused_expert_weights( + name_mapped, + params_dict, + loaded_weight[0], + "w1", + num_experts, + ) + success_w3 = self.load_fused_expert_weights( + name_mapped, + params_dict, + loaded_weight[1], + "w3", + num_experts, + ) + success = success_w1 and success_w3 + else: + 
success = self.load_fused_expert_weights( + name_mapped, + params_dict, + loaded_weight, + shard_id, + num_experts, + ) + else: + if ( + name_mapped.endswith(ignore_suffixes) + and name_mapped not in params_dict + ): + continue + param = params_dict[name_mapped] + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + name = name_mapped + break + else: + if is_expert_weight: + continue + if name.endswith(ignore_suffixes) and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale" + ) + if remapped_kv_scale_name not in params_dict: + logger.warning_once( + "Found kv scale in checkpoint (e.g. %s), " + "but not found expected name in model " + "(e.g. %s). kv-scale is not loaded.", + name, + remapped_kv_scale_name, + ) + continue + else: + name = remapped_kv_scale_name + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Qwen3_5MoeLLMForCausalLM(Qwen3_5MoeForCausalLM): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super(Qwen3_5MoeForCausalLM, self).__init__() + self.config = vllm_config.model_config.hf_config.text_config + self.quant_config = vllm_config.quant_config + self.model = Qwen3_5MoeLLMModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(self.config.vocab_size) + 
self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + +# --------------------------------------------------------------------------- +# MoE interface for VL model +# --------------------------------------------------------------------------- +class Qwen3_5MoeVLMixtureOfExperts(MixtureOfExperts): + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + for layer in self.language_model.model.layers: + if isinstance(layer.mlp, Qwen3_5MoeSparseMoeBlock): + moe = layer.mlp + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + + def set_moe_parameters(self): + self.expert_weights = [] + self.moe_layers = [] + example_moe = None + for layer in self.language_model.model.layers: + if hasattr(layer, "mlp") and isinstance( + layer.mlp, Qwen3_5MoeSparseMoeBlock + ): + example_moe = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_moe is None: + raise RuntimeError("No Qwen3_5Moe layer found in the language_model.") + + self.num_moe_layers = len(self.moe_layers) + self.num_expert_groups = 1 + self.num_shared_experts = 0 + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + +# --------------------------------------------------------------------------- +# VL top-level model +# 
--------------------------------------------------------------------------- +@MULTIMODAL_REGISTRY.register_processor( + Qwen3VLMultiModalProcessor, + info=Qwen3_5MoeVLProcessingInfo, + dummy_inputs=Qwen3VLDummyInputsBuilder, +) +class Qwen3_5MoeForConditionalGeneration( + Qwen3VLForConditionalGeneration, + Qwen3_5MoeVLMixtureOfExperts, + HasInnerState, + IsHybrid, +): + is_3d_moe_weight: bool = True + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + } + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.visual.": "visual.", + "lm_head.": "language_model.lm_head.", + "model.language_model.": "language_model.model.", + } + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): + super(Qwen3VLForConditionalGeneration, self).__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + self.video_pruning_rate = multimodal_config.video_pruning_rate + self.is_multimodal_pruning_enabled = ( + multimodal_config.is_multimodal_pruning_enabled() + ) + + self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes") + self.deepstack_num_level = ( + len(config.vision_config.deepstack_visual_indexes) + if self.use_deepstack + else 0 + ) + self.visual_dim = config.vision_config.out_hidden_size + self.multiscale_dim = self.visual_dim * self.deepstack_num_level + + with self._mark_tower_model(vllm_config, {"image", "video"}): + self.visual = Qwen3_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + ) + + if self.use_deepstack: + self.deepstack_input_embeds = [ + torch.zeros( + vllm_config.scheduler_config.max_num_batched_tokens, + 
config.text_config.hidden_size, + ) + for _ in range(self.deepstack_num_level) + ] + + with self._mark_language_model(vllm_config): + self.language_model = Qwen3_5MoeLLMForCausalLM( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + + self.packed_modules_mapping = ( + self.packed_modules_mapping | self.language_model.packed_modules_mapping + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + self.set_moe_parameters() + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self, skip_prefixes=["mtp."]) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + return Qwen3_5MoeForCausalLM.get_mamba_state_dtype_from_config(vllm_config) + + @classmethod + def get_mamba_state_shape_from_config( + cls, vllm_config: "VllmConfig" + ) -> tuple[tuple[int, int], tuple[int, int]]: + return Qwen3_5MoeForCausalLM.get_mamba_state_shape_from_config(vllm_config) + + @classmethod + def get_mamba_state_copy_func( + cls, + ) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: + return Qwen3_5MoeForCausalLM.get_mamba_state_copy_func() diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index c310f6f177dd..608e77064fbb 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -20,8 +20,8 @@ from typing import TYPE_CHECKING, Any, TypeVar import torch.nn as nn -import transformers +import transformers from vllm import envs from vllm.config import ( ModelConfig, @@ -465,6 +465,10 @@ "qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration", ), + "Qwen3_5MoeForConditionalGeneration": ( + "qwen3_5_moe", + "Qwen3_5MoeForConditionalGeneration", + ), "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"), 
"Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"), # noqa: E501 "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501