From fcf215e90ea57a164c13d27e2b2ce9043a5c145d Mon Sep 17 00:00:00 2001 From: Li Date: Sun, 8 Mar 2026 15:57:12 -0700 Subject: [PATCH] [ROCm][Bugfix] Fix MXFP4 MoE emulate fallback logic on MX-capable hardware Fix a Boolean logic regression in QuarkOCP_MX_MoEMethod that prevented fallback to emulation mode on MI350X (gfx950) and other MX-capable hardware, causing gibberish output when AITER CK kernels are incompatible (e.g. ROCm version mismatch). The previous logic: emulate = (not supports_mx() or not scheme.startswith("w_mxfp4")) and (backend is None or not use_aiter_moe) On MI350X with w_mxfp4, the first clause is (False or False) = False, making the entire AND expression always False regardless of whether AITER is available. This silently disabled the emulation fallback and ignored VLLM_ROCM_USE_AITER_MOE=0. The fix restructures the logic to be explicit: can_use_native_ck = supports_mx and w_mxfp4 and aiter_enabled can_use_backend = backend is not None emulate = not (can_use_native_ck or can_use_backend) Also adds: - AITER version logging for easier debugging - Workaround hint in the emulation warning message - Parametrized unit test covering the full dispatch matrix (14 cases) Fixes #36337 Made-with: Cursor Signed-off-by: Li Made-with: Cursor --- tests/quantization/test_quark_moe_emulate.py | 141 ++++++++++++++++++ .../layers/quantization/quark/quark_moe.py | 29 +++- 2 files changed, 165 insertions(+), 5 deletions(-) create mode 100644 tests/quantization/test_quark_moe_emulate.py diff --git a/tests/quantization/test_quark_moe_emulate.py b/tests/quantization/test_quark_moe_emulate.py new file mode 100644 index 000000000000..b4c8d37fe69c --- /dev/null +++ b/tests/quantization/test_quark_moe_emulate.py @@ -0,0 +1,141 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for the QuarkOCP_MX_MoEMethod emulate dispatch logic. + +The emulate flag determines whether native CK / Triton MXFP4 kernels are +used or whether the computation falls back to high-precision emulation. +A Boolean-logic regression in this flag (PR #29008) caused gibberish +output on MI350X (Issue #36337) because the fallback was silently +disabled on MX-capable hardware. + +These tests verify the flag is set correctly for every relevant +combination of (hardware_support × scheme × aiter_enabled × backend). +No GPU is required — all platform / env-var dependencies are mocked. +""" + +import pytest + + +def _compute_emulate( + supports_mx: bool, + ocp_mx_scheme: str | None, + use_rocm_aiter_moe: bool, + mxfp4_backend_available: bool, +) -> bool: + """Mirror the emulate logic from QuarkOCP_MX_MoEMethod.__init__. + + See vllm/model_executor/layers/quantization/quark/quark_moe.py, + around line 733. + """ + can_use_native_ck = ( + supports_mx + and ocp_mx_scheme is not None + and ocp_mx_scheme.startswith("w_mxfp4") + and use_rocm_aiter_moe + ) + can_use_mxfp4_backend = mxfp4_backend_available + + return not (can_use_native_ck or can_use_mxfp4_backend) + + +# ── Native CK path tests ────────────────────────────────────────────── + + +@pytest.mark.parametrize( + "supports_mx, scheme, aiter_enabled, backend, expected_emulate", + [ + # All conditions met → native CK → no emulation + (True, "w_mxfp4_a_mxfp4", True, False, False), + (True, "w_mxfp4_a_fp8", True, False, False), + (True, "w_mxfp4", True, False, False), + # AITER disabled (VLLM_ROCM_USE_AITER_MOE=0) → must emulate + (True, "w_mxfp4_a_mxfp4", False, False, True), + # Hardware doesn't support MX → must emulate + (False, "w_mxfp4_a_mxfp4", True, False, True), + (False, "w_mxfp4_a_mxfp4", False, False, True), + # Non-mxfp4 scheme → must emulate (no backend either) + (True, "w_mxfp6_e3m2", True, False, True), + (True, "w_mxfp6_e3m2_a_mxfp6_e3m2", True, False, True), + (False, "w_mxfp6_e3m2", True, False, True), + # scheme is None → must emulate + (True, None, True, False, True), + ], + ids=[ + "mi350x-w4a4-aiter_on", + "mi350x-w4afp8-aiter_on", + "mi350x-w4_only-aiter_on", + "mi350x-w4a4-aiter_off", + "no_mx-w4a4-aiter_on", + "no_mx-w4a4-aiter_off", + "mi350x-mxfp6-aiter_on", + "mi350x-mxfp6_sym-aiter_on", + "no_mx-mxfp6-aiter_on", + "mi350x-none_scheme-aiter_on", + ], +) +def test_emulate_native_ck_path( + supports_mx: bool, + scheme: str | None, + aiter_enabled: bool, + backend: bool, + expected_emulate: bool, +): + result = _compute_emulate(supports_mx, scheme, aiter_enabled, backend) + assert result == expected_emulate, ( + f"emulate should be {expected_emulate} for " + f"supports_mx={supports_mx}, scheme={scheme!r}, " + f"aiter_enabled={aiter_enabled}, backend={backend}" + ) + + +# ── Triton mxfp4 backend tests ──────────────────────────────────────── + + +@pytest.mark.parametrize( + "supports_mx, scheme, aiter_enabled, backend, expected_emulate", + [ + # Backend available → no emulation, even without CK + (False, "w_mxfp4", False, True, False), + (True, "w_mxfp4", False, True, False), + # Backend available + CK also available → still no emulation + (True, "w_mxfp4_a_mxfp4", True, True, False), + ], + ids=[ + "no_mx-backend_on-aiter_off", + "mi350x-backend_on-aiter_off", + "mi350x-backend_on-aiter_on", + ], +) +def test_emulate_mxfp4_backend_path( + supports_mx: bool, + scheme: str | None, + aiter_enabled: bool, + backend: bool, + expected_emulate: bool, +): + result = _compute_emulate(supports_mx, scheme, aiter_enabled, backend) + assert result == expected_emulate + + +# ── Regression test for Issue #36337 ────────────────────────────────── + + +def test_regression_issue_36337_aiter_disabled_forces_emulation(): + """On MI350X (supports_mx=True) with w_mxfp4_a_mxfp4 scheme, + setting VLLM_ROCM_USE_AITER_MOE=0 (aiter_enabled=False) MUST + result in emulate=True so the user can fall back to the safe + emulation path when AITER CK kernels are incompatible. + + The old logic (PR #29008) evaluated to emulate=False here because: + (not True or not True) and (...) → (False) and (...) → False + """ + result = _compute_emulate( + supports_mx=True, + ocp_mx_scheme="w_mxfp4_a_mxfp4", + use_rocm_aiter_moe=False, + mxfp4_backend_available=False, + ) + assert result is True, ( + "emulate must be True when AITER is disabled on MI350X — " + "this is the exact regression from Issue #36337" + ) diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 0a5db4e71fdb..c05ec66c7a74 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -730,10 +730,19 @@ def __init__( get_current_vllm_config().model_config.hf_config, "model_type", None ) - self.emulate = ( - not current_platform.supports_mx() - or not self.ocp_mx_scheme.startswith("w_mxfp4") - ) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe) + # Native CK path requires MX hardware + w_mxfp4 scheme + AITER MoE. + # The Triton mxfp4 backend is available for weight-only mxfp4. + # If neither path is available, fall back to emulation (dequant to + # high-precision and compute in BF16). + can_use_native_ck = ( + current_platform.supports_mx() + and self.ocp_mx_scheme is not None + and self.ocp_mx_scheme.startswith("w_mxfp4") + and self.use_rocm_aiter_moe + ) + can_use_mxfp4_backend = self.mxfp4_backend is not None + + self.emulate = not (can_use_native_ck or can_use_mxfp4_backend) # CK's pre-compiled MXFP4 MoE GEMM kernel instances have dimension # alignment requirements. When violated (e.g. MiniMax-M2.1 with @@ -769,7 +778,9 @@ def __init__( "does not support native MXFP4/MXFP6 " "computation. Simulated weight dequantization and activation " "QDQ (quantize and dequantize) will be used, with the linear " - "layers computed in high precision." + "layers computed in high precision. If you see gibberish " + "output with native mode, try VLLM_ROCM_USE_AITER_MOE=0 " + "to force emulation as a workaround." ) else: logger.warning_once( @@ -966,6 +977,14 @@ def process_weights_after_loading(self, layer): from aiter.utility.fp4_utils import e8m0_shuffle + try: + import aiter + + aiter_version = getattr(aiter, "__version__", "unknown") + except ImportError: + aiter_version = "unknown" + logger.info("Using AITER %s for MXFP4 MoE weight processing", aiter_version) + # Pre-shuffle weight scales s0, s1, _ = layer.w13_weight_scale.shape w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1)