From dccdac8679f61fd38f6d174c063a7d091cb6122e Mon Sep 17 00:00:00 2001 From: Daniel Date: Mon, 9 Feb 2026 10:53:40 -0500 Subject: [PATCH 1/2] solving formatting issues --- vllm/config/__init__.py | 4 ++ vllm/config/moe.py | 69 +++++++++++++++++++ vllm/config/vllm.py | 7 ++ vllm/engine/arg_utils.py | 14 +++- .../layers/quantization/mxfp4.py | 56 ++++++++++----- 5 files changed, 129 insertions(+), 21 deletions(-) create mode 100644 vllm/config/moe.py diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 5bcf9865c279..f7ee20a32fc8 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -22,6 +22,7 @@ str_dtype_to_torch_dtype, try_match_architecture_defaults, ) +from vllm.config.moe import MoeConfig, Mxfp4Backend from vllm.config.multimodal import MultiModalConfig from vllm.config.observability import ObservabilityConfig from vllm.config.parallel import EPLBConfig, ParallelConfig @@ -76,6 +77,9 @@ "LoadConfig", # From vllm.config.lora "LoRAConfig", + # From vllm.config.moe + "MoeConfig", + "Mxfp4Backend", # From vllm.config.model "ModelConfig", "iter_architecture_defaults", diff --git a/vllm/config/moe.py b/vllm/config/moe.py new file mode 100644 index 000000000000..ea425760c085 --- /dev/null +++ b/vllm/config/moe.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Configuration for MoE backends.""" + +from enum import Enum +from typing import Any + +from pydantic import field_validator +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config + + +class Mxfp4Backend(Enum): + NONE = 0 + + # FlashInfer Backends (NVIDIA) + SM100_FI_MXFP4_MXFP8_TRTLLM = 1 + SM100_FI_MXFP4_MXFP8_CUTLASS = 2 + SM100_FI_MXFP4_BF16 = 3 + SM90_FI_MXFP4_BF16 = 4 + + # Marlin Backend + MARLIN = 5 + + # Triton Backend + TRITON = 6 + + # ROCm CK (Composable Kernel) Backend + CK = 7 + + +@config +@dataclass +class MoeConfig: + """ + If not specified via --moe_config.backend, the backend will be + auto-selected based on platform and libraries + + usage: + vllm serve model_name --moe_config.backend=TRITON + vllm serve model_name --moe_config.backend=CK + vllm serve model_name --moe_config.backend=MARLIN + """ + + backend: Mxfp4Backend | None = None + """MoE backend to use. If None, will be selected automatically.""" + + def compute_hash(self) -> str: + """ + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + from vllm.config.utils import get_hash_factors, hash_factors + + ignored_factors: list[str] = [] + factors = get_hash_factors(self, ignored_factors) + return hash_factors(factors) + + @field_validator("backend", mode="before") + @classmethod + def validate_backend_before(cls, value: Any) -> Any: + """Enable parsing of the `backend` enum type from string.""" + if isinstance(value, str): + return Mxfp4Backend[value.upper()] + return value diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index c1ef8e6aae39..720fb934bb2f 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -36,6 +36,7 @@ from .load import LoadConfig from .lora import LoRAConfig from .model import ModelConfig +from .moe import MoeConfig from .observability import ObservabilityConfig from .parallel import ParallelConfig from .profiler import ProfilerConfig @@ -224,6 +225,8 @@ class VllmConfig: """Load configuration.""" attention_config: AttentionConfig = Field(default_factory=AttentionConfig) """Attention configuration.""" + moe_config: MoeConfig = Field(default_factory=MoeConfig) + """MoE (Mixture of Experts) configuration.""" kernel_config: KernelConfig = Field(default_factory=KernelConfig) """Kernel configuration.""" lora_config: LoRAConfig | None = None @@ -328,6 +331,10 @@ def compute_hash(self) -> str: vllm_factors.append(self.attention_config.compute_hash()) else: vllm_factors.append("None") + if self.moe_config: + vllm_factors.append(self.moe_config.compute_hash()) + else: + vllm_factors.append("None") if self.lora_config: vllm_factors.append(self.lora_config.compute_hash()) else: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c7c78ffd8e00..11dd10409221 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -46,6 +46,7 @@ LoadConfig, LoRAConfig, ModelConfig, + MoeConfig, MultiModalConfig, ObservabilityConfig, ParallelConfig, @@ -538,6 +539,7 @@ class EngineArgs: pooler_config: PoolerConfig | None = ModelConfig.pooler_config compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config") attention_config: AttentionConfig = get_field(VllmConfig, "attention_config") + moe_config: MoeConfig = get_field(VllmConfig, "moe_config") kernel_config: KernelConfig = get_field(VllmConfig, "kernel_config") enable_flashinfer_autotune: bool = get_field( KernelConfig, "enable_flashinfer_autotune" @@ -601,6 +603,8 @@ def __post_init__(self): self.compilation_config = CompilationConfig(**self.compilation_config) if isinstance(self.attention_config, dict): self.attention_config = AttentionConfig(**self.attention_config) + if isinstance(self.moe_config, dict): + self.moe_config = MoeConfig(**self.moe_config) if isinstance(self.kernel_config, dict): self.kernel_config = KernelConfig(**self.kernel_config) if isinstance(self.eplb_config, dict): @@ -1211,6 +1215,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: vllm_group.add_argument( "--attention-config", "-ac", **vllm_kwargs["attention_config"] ) + vllm_group.add_argument("--moe-config", "-mc", **vllm_kwargs["moe_config"]) vllm_group.add_argument("--kernel-config", **vllm_kwargs["kernel_config"]) vllm_group.add_argument( "--additional-config", **vllm_kwargs["additional_config"] @@ -1699,9 +1704,11 @@ def create_engine_config( lora_dtype=self.lora_dtype, enable_tower_connector_lora=self.enable_tower_connector_lora, specialize_active_lora=self.specialize_active_lora, - max_cpu_loras=self.max_cpu_loras - if self.max_cpu_loras and self.max_cpu_loras > 0 - else None, + max_cpu_loras=( + self.max_cpu_loras + if self.max_cpu_loras and self.max_cpu_loras > 0 + else None + ), ) if self.enable_lora else None @@ -1802,6 +1809,7 @@ def create_engine_config( device_config=device_config, load_config=load_config, attention_config=attention_config, + moe_config=self.moe_config, kernel_config=kernel_config, lora_config=lora_config, speculative_config=speculative_config, diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index dc1e7d88a91c..35e8afac281f 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -1,13 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from enum import Enum import torch from torch.nn.parameter import Parameter from vllm import envs from vllm._aiter_ops import rocm_aiter_ops -from vllm.config import get_current_vllm_config +from vllm.config import ( + Mxfp4Backend, + get_current_vllm_config, + get_current_vllm_config_or_none, +) from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.fused_moe import ( @@ -60,23 +63,22 @@ logger = init_logger(__name__) -# enum for mxfp4 backend -class Mxfp4Backend(Enum): - NONE = 0 - - # FlashInfer Backend - SM100_FI_MXFP4_MXFP8_TRTLLM = 1 - SM100_FI_MXFP4_MXFP8_CUTLASS = 2 - SM100_FI_MXFP4_BF16 = 3 - SM90_FI_MXFP4_BF16 = 4 - - # Marlin Backend - MARLIN = 5 - - # Triton Backend - TRITON = 6 +def _get_user_specified_moe_backend() -> Mxfp4Backend | None: + """ + Check if the user has explicitly specified a MoE backend. + Returns None if not specified or if unavailable + """ + vllm_config = get_current_vllm_config_or_none() + if vllm_config is None: + return None - CK = 7 + backend = vllm_config.moe_config.backend + if backend is not None: + logger.info_once( + "Using user-specified MoE backend: %s (via --moe_config.backend)", + backend.name, + ) + return backend def get_mxfp4_backend_with_lora() -> Mxfp4Backend: @@ -106,6 +108,12 @@ def get_mxfp4_backend_with_lora() -> Mxfp4Backend: def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend: # Backend Selection + # check if --moe_config.backend was used + user_backend = _get_user_specified_moe_backend() + if user_backend is not None: + return user_backend + + # Fall back to auto-detection if with_lora_support: return get_mxfp4_backend_with_lora() @@ -783,6 +791,18 @@ def _interleave_mxfp4_cutlass_sm90(w): .contiguous() .view(e, n, -1) ) + w13_aiter_weight = ( + w13_aiter_weight.view(e, n // 2, 2, k) + .permute(0, 2, 1, 3) + .contiguous() + .view(e, n, k) + ) + w13_aiter_scale = ( + w13_aiter_scale.view(e, n // 2, 2, -1) + .permute(0, 2, 1, 3) + .contiguous() + .view(e, n, -1) + ) w13_aiter_weight = w13_aiter_weight.view(torch.float4_e2m1fn_x2) w13_aiter_scale = w13_aiter_scale.view(-1, w13_aiter_scale.shape[-1]) From 0011725ab0e0464d495fb8467b9300fa5fcfd476 Mon Sep 17 00:00:00 2001 From: Daniel Date: Thu, 12 Feb 2026 11:05:18 -0500 Subject: [PATCH 2/2] console logs for moe usage --- vllm/engine/arg_utils.py | 7 ++++++ vllm/model_executor/layers/fused_moe/layer.py | 25 +++++++++++-------- .../layers/quantization/mxfp4.py | 9 +++---- 3 files changed, 25 insertions(+), 16 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 11dd10409221..f7f06771be98 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1825,6 +1825,13 @@ def create_engine_config( weight_transfer_config=self.weight_transfer_config, ) + # Log MoE config being set + if config.moe_config.backend is not None: + logger.info_once( + "MoE backend set to %s (from --moe_config.backend).", + config.moe_config.backend.name, + ) + return config def _check_feature_supported(self, model_config: ModelConfig): diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index f35ec87aac42..fa2d8300621c 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -369,6 +369,13 @@ def __init__( vllm_config = get_current_vllm_config() self.vllm_config = vllm_config + # Log MoE backend confirmation + if vllm_config.moe_config.backend is not None: + logger.info_once( + "MoE backend confirmed: %s (from --moe_config.backend).", + vllm_config.moe_config.backend.name, + ) + # FIXME (varun): We should have a better way of inferring the activation # datatype. This works for now as the tensor datatype entering the MoE # operation is typically unquantized (i.e. float16/bfloat16). @@ -1725,14 +1732,8 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): batched_hidden_states = self.batched_hidden_states batched_router_logits = self.batched_router_logits - assert ( - batched_hidden_states.size(0) # type: ignore - >= chunk_size - ) - assert ( - batched_router_logits.size(0) # type: ignore - >= chunk_size - ) + assert batched_hidden_states.size(0) >= chunk_size # type: ignore + assert batched_router_logits.size(0) >= chunk_size # type: ignore staged_hidden_states = batched_hidden_states[:chunk_size, :] # type: ignore staged_router_logits = batched_router_logits[:chunk_size, :] # type: ignore staged_hidden_states.copy_(hidden_states, non_blocking=True) @@ -2022,9 +2023,11 @@ def make_expert_params_mapping( return [ # (param_name, weight_name, expert_id, shard_id) ( - f"experts.{base_layer}w13_" - if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name] - else f"experts.{base_layer}w2_", + ( + f"experts.{base_layer}w13_" + if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name] + else f"experts.{base_layer}w2_" + ), f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.{base_layer}", expert_id, shard_id, diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 35e8afac281f..5fd224c511e2 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -73,11 +73,6 @@ def _get_user_specified_moe_backend() -> Mxfp4Backend | None: return None backend = vllm_config.moe_config.backend - if backend is not None: - logger.info_once( - "Using user-specified MoE backend: %s (via --moe_config.backend)", - backend.name, - ) return backend @@ -111,6 +106,10 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend: # check if --moe_config.backend was used user_backend = _get_user_specified_moe_backend() if user_backend is not None: + logger.info_once( + "MoE backend being used: %s (from --moe_config.backend).", + user_backend.name, + ) return user_backend # Fall back to auto-detection