Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions vllm/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
str_dtype_to_torch_dtype,
try_match_architecture_defaults,
)
from vllm.config.moe import MoeConfig, Mxfp4Backend
from vllm.config.multimodal import MultiModalConfig
from vllm.config.observability import ObservabilityConfig
from vllm.config.parallel import EPLBConfig, ParallelConfig
Expand Down Expand Up @@ -76,6 +77,9 @@
"LoadConfig",
# From vllm.config.lora
"LoRAConfig",
# From vllm.config.moe
"MoeConfig",
"Mxfp4Backend",
# From vllm.config.model
"ModelConfig",
"iter_architecture_defaults",
Expand Down
69 changes: 69 additions & 0 deletions vllm/config/moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Configuration for MoE backends."""

from enum import Enum
from typing import Any

from pydantic import field_validator
from pydantic.dataclasses import dataclass

from vllm.config.utils import config


class Mxfp4Backend(Enum):
    """Kernel backend choices for MXFP4 MoE execution.

    ``NONE`` means no backend has been resolved; every other member
    names a concrete kernel implementation.
    """

    # No backend selected / resolution failed.
    NONE = 0

    # FlashInfer backends (NVIDIA): SM100 TRT-LLM and CUTLASS kernels,
    # plus BF16-activation variants for SM100 and SM90.
    SM100_FI_MXFP4_MXFP8_TRTLLM = 1
    SM100_FI_MXFP4_MXFP8_CUTLASS = 2
    SM100_FI_MXFP4_BF16 = 3
    SM90_FI_MXFP4_BF16 = 4

    # Marlin kernels.
    MARLIN = 5

    # Triton kernels.
    TRITON = 6

    # ROCm Composable Kernel (CK) kernels.
    CK = 7


@config
@dataclass
class MoeConfig:
    """Configuration for the MoE (Mixture of Experts) kernel backend.

    If not specified via --moe_config.backend, the backend will be
    auto-selected based on platform and libraries

    usage:
    vllm serve model_name --moe_config.backend=TRITON
    vllm serve model_name --moe_config.backend=CK
    vllm serve model_name --moe_config.backend=MARLIN
    """

    backend: Mxfp4Backend | None = None
    """MoE backend to use. If None, will be selected automatically."""

    def compute_hash(self) -> str:
        """
        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # Local import to avoid an import cycle with vllm.config.utils.
        from vllm.config.utils import get_hash_factors, hash_factors

        ignored_factors: list[str] = []
        factors = get_hash_factors(self, ignored_factors)
        return hash_factors(factors)

    @field_validator("backend", mode="before")
    @classmethod
    def validate_backend_before(cls, value: Any) -> Any:
        """Enable parsing of the `backend` enum type from string.

        Raises:
            ValueError: if the string does not name a known backend.
                Pydantic turns ValueError raised in a validator into a
                clean ValidationError; a raw KeyError from the Enum
                lookup would escape unwrapped.
        """
        if isinstance(value, str):
            try:
                return Mxfp4Backend[value.upper()]
            except KeyError:
                valid = ", ".join(member.name for member in Mxfp4Backend)
                raise ValueError(
                    f"Invalid MoE backend {value!r}. Valid options: {valid}."
                ) from None
        return value
7 changes: 7 additions & 0 deletions vllm/config/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from .load import LoadConfig
from .lora import LoRAConfig
from .model import ModelConfig
from .moe import MoeConfig
from .observability import ObservabilityConfig
from .parallel import ParallelConfig
from .profiler import ProfilerConfig
Expand Down Expand Up @@ -224,6 +225,8 @@ class VllmConfig:
"""Load configuration."""
attention_config: AttentionConfig = Field(default_factory=AttentionConfig)
"""Attention configuration."""
moe_config: MoeConfig = Field(default_factory=MoeConfig)
"""MoE (Mixture of Experts) configuration."""
kernel_config: KernelConfig = Field(default_factory=KernelConfig)
"""Kernel configuration."""
lora_config: LoRAConfig | None = None
Expand Down Expand Up @@ -328,6 +331,10 @@ def compute_hash(self) -> str:
vllm_factors.append(self.attention_config.compute_hash())
else:
vllm_factors.append("None")
if self.moe_config:
vllm_factors.append(self.moe_config.compute_hash())
else:
vllm_factors.append("None")
if self.lora_config:
vllm_factors.append(self.lora_config.compute_hash())
else:
Expand Down
21 changes: 18 additions & 3 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
LoadConfig,
LoRAConfig,
ModelConfig,
MoeConfig,
MultiModalConfig,
ObservabilityConfig,
ParallelConfig,
Expand Down Expand Up @@ -538,6 +539,7 @@ class EngineArgs:
pooler_config: PoolerConfig | None = ModelConfig.pooler_config
compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
attention_config: AttentionConfig = get_field(VllmConfig, "attention_config")
moe_config: MoeConfig = get_field(VllmConfig, "moe_config")
kernel_config: KernelConfig = get_field(VllmConfig, "kernel_config")
enable_flashinfer_autotune: bool = get_field(
KernelConfig, "enable_flashinfer_autotune"
Expand Down Expand Up @@ -601,6 +603,8 @@ def __post_init__(self):
self.compilation_config = CompilationConfig(**self.compilation_config)
if isinstance(self.attention_config, dict):
self.attention_config = AttentionConfig(**self.attention_config)
if isinstance(self.moe_config, dict):
self.moe_config = MoeConfig(**self.moe_config)
if isinstance(self.kernel_config, dict):
self.kernel_config = KernelConfig(**self.kernel_config)
if isinstance(self.eplb_config, dict):
Expand Down Expand Up @@ -1211,6 +1215,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
vllm_group.add_argument(
"--attention-config", "-ac", **vllm_kwargs["attention_config"]
)
vllm_group.add_argument("--moe-config", "-mc", **vllm_kwargs["moe_config"])
vllm_group.add_argument("--kernel-config", **vllm_kwargs["kernel_config"])
vllm_group.add_argument(
"--additional-config", **vllm_kwargs["additional_config"]
Expand Down Expand Up @@ -1699,9 +1704,11 @@ def create_engine_config(
lora_dtype=self.lora_dtype,
enable_tower_connector_lora=self.enable_tower_connector_lora,
specialize_active_lora=self.specialize_active_lora,
max_cpu_loras=self.max_cpu_loras
if self.max_cpu_loras and self.max_cpu_loras > 0
else None,
max_cpu_loras=(
self.max_cpu_loras
if self.max_cpu_loras and self.max_cpu_loras > 0
else None
),
)
if self.enable_lora
else None
Expand Down Expand Up @@ -1802,6 +1809,7 @@ def create_engine_config(
device_config=device_config,
load_config=load_config,
attention_config=attention_config,
moe_config=self.moe_config,
kernel_config=kernel_config,
lora_config=lora_config,
speculative_config=speculative_config,
Expand All @@ -1817,6 +1825,13 @@ def create_engine_config(
weight_transfer_config=self.weight_transfer_config,
)

# Log MoE config being set
if config.moe_config.backend is not None:
logger.info_once(
"MoE backend set to %s (from --moe_config.backend).",
config.moe_config.backend.name,
)

return config

def _check_feature_supported(self, model_config: ModelConfig):
Expand Down
25 changes: 14 additions & 11 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,13 @@ def __init__(
vllm_config = get_current_vllm_config()
self.vllm_config = vllm_config

# Log MoE backend confirmation
if vllm_config.moe_config.backend is not None:
logger.info_once(
"MoE backend confirmed: %s (from --moe_config.backend).",
vllm_config.moe_config.backend.name,
)

# FIXME (varun): We should have a better way of inferring the activation
# datatype. This works for now as the tensor datatype entering the MoE
# operation is typically unquantized (i.e. float16/bfloat16).
Expand Down Expand Up @@ -1725,14 +1732,8 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
batched_hidden_states = self.batched_hidden_states
batched_router_logits = self.batched_router_logits

assert (
batched_hidden_states.size(0) # type: ignore
>= chunk_size
)
assert (
batched_router_logits.size(0) # type: ignore
>= chunk_size
)
assert batched_hidden_states.size(0) >= chunk_size # type: ignore
assert batched_router_logits.size(0) >= chunk_size # type: ignore
staged_hidden_states = batched_hidden_states[:chunk_size, :] # type: ignore
staged_router_logits = batched_router_logits[:chunk_size, :] # type: ignore
staged_hidden_states.copy_(hidden_states, non_blocking=True)
Expand Down Expand Up @@ -2022,9 +2023,11 @@ def make_expert_params_mapping(
return [
# (param_name, weight_name, expert_id, shard_id)
(
f"experts.{base_layer}w13_"
if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
else f"experts.{base_layer}w2_",
(
f"experts.{base_layer}w13_"
if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
else f"experts.{base_layer}w2_"
),
f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.{base_layer}",
expert_id,
shard_id,
Expand Down
55 changes: 37 additions & 18 deletions vllm/model_executor/layers/quantization/mxfp4.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum

import torch
from torch.nn.parameter import Parameter

from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import get_current_vllm_config
from vllm.config import (
Mxfp4Backend,
get_current_vllm_config,
get_current_vllm_config_or_none,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import (
Expand Down Expand Up @@ -60,23 +63,17 @@
logger = init_logger(__name__)


# enum for mxfp4 backend
class Mxfp4Backend(Enum):
NONE = 0

# FlashInfer Backend
SM100_FI_MXFP4_MXFP8_TRTLLM = 1
SM100_FI_MXFP4_MXFP8_CUTLASS = 2
SM100_FI_MXFP4_BF16 = 3
SM90_FI_MXFP4_BF16 = 4

# Marlin Backend
MARLIN = 5

# Triton Backend
TRITON = 6
def _get_user_specified_moe_backend() -> Mxfp4Backend | None:
    """Return the MoE backend explicitly requested by the user, if any.

    Reads ``moe_config.backend`` from the current vLLM config (set via
    ``--moe_config.backend``). Returns ``None`` when called outside an
    active config context or when no backend was specified, so the
    caller can fall back to auto-detection.
    """
    vllm_config = get_current_vllm_config_or_none()
    if vllm_config is None:
        # No active vLLM config context available.
        return None
    return vllm_config.moe_config.backend


def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
Expand Down Expand Up @@ -106,6 +103,16 @@ def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
# Backend Selection

# check if --moe_config.backend was used
user_backend = _get_user_specified_moe_backend()
if user_backend is not None:
logger.info_once(
"MoE backend being used: %s (from --moe_config.backend).",
user_backend.name,
)
return user_backend

# Fall back to auto-detection
if with_lora_support:
return get_mxfp4_backend_with_lora()

Expand Down Expand Up @@ -783,6 +790,18 @@ def _interleave_mxfp4_cutlass_sm90(w):
.contiguous()
.view(e, n, -1)
)
w13_aiter_weight = (
w13_aiter_weight.view(e, n // 2, 2, k)
.permute(0, 2, 1, 3)
.contiguous()
.view(e, n, k)
)
w13_aiter_scale = (
w13_aiter_scale.view(e, n // 2, 2, -1)
.permute(0, 2, 1, 3)
.contiguous()
.view(e, n, -1)
)

w13_aiter_weight = w13_aiter_weight.view(torch.float4_e2m1fn_x2)
w13_aiter_scale = w13_aiter_scale.view(-1, w13_aiter_scale.shape[-1])
Expand Down