From f3d0a44ded030e992d0244bc203bc3ce4392d7a5 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Tue, 2 Dec 2025 15:32:41 +0000 Subject: [PATCH 01/11] add 2rms1fp8_group_quant fusion pass Signed-off-by: ShaoChunLee --- vllm/_aiter_ops.py | 54 +++++++++++++++ vllm/compilation/rocm_aiter_fusion.py | 95 +++++++++++++++++++++++++++ 2 files changed, 149 insertions(+) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 20d56e76462e..25a8f329958b 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -511,6 +511,52 @@ def _rocm_aiter_rmsnorm_fp8_group_quant_fake( ) +def _rocm_aiter_2rmsnorm_1fp8_group_quant_impl( + x1: torch.Tensor, + x2: torch.Tensor, + weight1: torch.Tensor, + variance_epsilon1: float, + weight2: torch.Tensor, + variance_epsilon2: float, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant + + (x_quant, x_quant_scales), _, x2_out, _ = fused_rms_fp8_group_quant( + x1, + weight1, + variance_epsilon1, + x2, + weight2, + variance_epsilon2, + group_size=group_size, + dtype_quant=AITER_FP8_DTYPE, + res1=None, + ) + return (x_quant, x_quant_scales, x2_out) + + +def _rocm_aiter_2rmsnorm_1fp8_group_quant_fake( + x1: torch.Tensor, + x2: torch.Tensor, + weight1: torch.Tensor, + variance_epsilon1: float, + weight2: torch.Tensor, + variance_epsilon2: float, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + M, N = x1.shape + scale_shape = ( + M, + (N + group_size - 1) // group_size, + ) + return ( + torch.empty_like(x1, dtype=AITER_FP8_DTYPE, device=x1.device), + torch.empty(scale_shape, dtype=torch.float32, device=x1.device), + torch.empty_like(x2, dtype=x2.dtype, device=x1.device), + ) + + def _rocm_aiter_group_fp8_quant_impl( x: torch.Tensor, group_size: int, @@ -768,6 +814,14 @@ def register_ops_once() -> None: fake_impl=_rocm_aiter_rmsnorm_fp8_group_quant_fake, ) + direct_register_custom_op( + 
op_name="rocm_aiter_2rmsnorm_1fp8_group_quant", + op_func=_rocm_aiter_2rmsnorm_1fp8_group_quant_impl, + mutates_args=[], + fake_impl=_rocm_aiter_2rmsnorm_1fp8_group_quant_fake, + dispatch_key=current_platform.dispatch_key, + ) + direct_register_custom_op( op_name="rocm_aiter_rmsnorm_with_add_fp8_group_quant", op_func=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl, diff --git a/vllm/compilation/rocm_aiter_fusion.py b/vllm/compilation/rocm_aiter_fusion.py index 8b5db9de3818..cd14014c6828 100644 --- a/vllm/compilation/rocm_aiter_fusion.py +++ b/vllm/compilation/rocm_aiter_fusion.py @@ -28,6 +28,7 @@ ) AITER_RMS_OP = torch.ops.vllm.rocm_aiter_rms_norm.default +AITER_2RMS_1GROUP_QUANT_OP = torch.ops.vllm.rocm_aiter_2rmsnorm_1fp8_group_quant.default AITER_RMS_ADD_OP = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default AITER_GROUP_FP8_QUANT_OP = torch.ops.vllm.rocm_aiter_group_fp8_quant.default @@ -35,6 +36,8 @@ FUSED_SILU_MUL_QUANT_OP = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default +SPLIT_WITH_SIZES_OP = torch.ops.aten.split_with_sizes.default + class AiterRMSFp8GroupQuantPattern: """ @@ -79,6 +82,87 @@ def replacement( pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) +class Aiter2RMS1GroupQuantFP8Pattern: + """ + This pattern fuses aiter rms_norm & group fp8 quant custom for input1 and + rms_norm for input2 + ops into an aiter rms_norm_group_fp8_quant op. 
+ """ + + def __init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + quant_op: OpOverload, + hidden_size1: int, + hidden_size2: int, + hidden_size3: int, + ): + self.epsilon = epsilon + self.quant_dtype = quant_dtype + self.quant_op = quant_op + self.hidden_size1 = hidden_size1 + self.hidden_size2 = hidden_size2 + self.hidden_size3 = hidden_size3 + + def register(self, pm_pass: PatternMatcherPass): + def pattern( + input: torch.Tensor, + weight1: torch.Tensor, + weight2: torch.Tensor, + ): + input1, input_split_0 = SPLIT_WITH_SIZES_OP( + input, + [self.hidden_size1, self.hidden_size2 + self.hidden_size3], + dim=-1, + ) + input2, at4 = SPLIT_WITH_SIZES_OP( + input_split_0, [self.hidden_size2, self.hidden_size3], dim=-1 + ) + + at1 = AITER_RMS_OP(x=input1, weight=weight1, variance_epsilon=self.epsilon) + at2 = self.quant_op(at1, 128) + at3 = AITER_RMS_OP(x=input2, weight=weight2, variance_epsilon=self.epsilon) + + return at2[0], at2[1], at3, at4 + + def replacement( + input: torch.Tensor, + weight1: torch.Tensor, + weight2: torch.Tensor, + ): + input1, input_split_0 = SPLIT_WITH_SIZES_OP( + input, + [self.hidden_size1, self.hidden_size2 + self.hidden_size3], + dim=-1, + ) + input2, at4 = SPLIT_WITH_SIZES_OP( + input_split_0, [self.hidden_size2, self.hidden_size3], dim=-1 + ) + + at = AITER_2RMS_1GROUP_QUANT_OP( + x1=input1, + x2=input2, + weight1=weight1, + variance_epsilon1=self.epsilon, + weight2=weight2, + variance_epsilon2=self.epsilon, + group_size=128, + ) + + return at[0], at[1], at[2], at4 + + inputs = [ + empty_bf16( + 5, self.hidden_size1 + self.hidden_size2 + self.hidden_size3 + ), # input + empty_bf16(1, self.hidden_size1), # weight1 + empty_bf16(1, self.hidden_size2), # weight2 + ] + + pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) + + class AiterFusedAddRMSFp8GroupQuantPattern: """ This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops @@ -152,6 +236,16 @@ def __init__(self, config: VllmConfig): 
for epsilon in [1e-5, 1e-6]: # Fuse rms_norm + dynamic group fp8 quant for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]: + for hidden_size1, hidden_size2, hidden_size3 in [(1536, 512, 64)]: + Aiter2RMS1GroupQuantFP8Pattern( + epsilon, + FP8_DTYPE, + quant_op, + hidden_size1, + hidden_size2, + hidden_size3, + ).register(self.patterns) + AiterRMSFp8GroupQuantPattern(epsilon, FP8_DTYPE, quant_op).register( self.patterns ) @@ -169,6 +263,7 @@ def __call__(self, graph: fx.Graph): def uuid(self) -> Any: fusion_patterns = [ + Aiter2RMS1GroupQuantFP8Pattern, AiterRMSFp8GroupQuantPattern, AiterFusedAddRMSFp8GroupQuantPattern, ] From f06b18457cae5d889778d0ab7eace4194f4b0e22 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Thu, 4 Dec 2025 22:36:05 +0000 Subject: [PATCH 02/11] moe fix Signed-off-by: ShaoChunLee --- .../layers/quantization/quark/quark_moe.py | 51 +++++++++++-------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 8be0299eaa66..b71f10eb9000 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -19,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, - ocp_mx_moe_quant_config, + mxfp4_w4a16_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( @@ -27,7 +27,6 @@ ) from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( OCP_MX_BLOCK_SIZE, - OCP_MX_Scheme, ) from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -435,13 +434,9 @@ def __init__( self.static_input_scales = not self.input_quant.get("is_dynamic") 
self.weight_dtype = self.weight_quant["dtype"].replace("fp", "mxfp") - self.input_dtype = self.input_quant["dtype"].replace("fp", "mxfp") + self.input_dtype = self.input_quant["dtype"] self.fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None) - self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype( - self.input_dtype, self.weight_dtype - ) - if self.static_input_scales: raise NotImplementedError( "QuarkOCP_MX_MoEMethod with static input scales is currently " @@ -450,14 +445,11 @@ def __init__( self.use_rocm_aiter_moe = rocm_aiter_ops.is_fused_moe_enabled() - self.emulate = not current_platform.supports_mx() or not ( - self.use_rocm_aiter_moe and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4" - ) + self.emulate = not current_platform.supports_mx() or not self.use_rocm_aiter_moe if self.emulate: logger.warning_once( f"The current mode (supports_mx={current_platform.supports_mx()}, " f"use_mxfp4_aiter_moe={self.use_rocm_aiter_moe}, " - f"ocp_mx_scheme={self.ocp_mx_scheme}) " "does not support native MXFP4/MXFP6 " "computation. Simulated weight dequantization and activation " "QDQ (quantize and dequantize) will be used, with the linear " @@ -578,14 +570,10 @@ def process_weights_after_loading(self, layer): def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: - return ocp_mx_moe_quant_config( - quant_dtype=self.input_dtype, - weight_dtype=self.weight_dtype, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=None, - a2_scale=None, - block_shape=None, + # The default mxfp4 recipe is with activation fp16/bf16 dynamic quantzied + # and weight mxfp4 offline quantized. 
+ return mxfp4_w4a16_moe_quant_config( + layer.w13_weight_scale, layer.w2_weight_scale ) @property @@ -625,15 +613,34 @@ def apply( rocm_aiter_fused_experts, ) + if hasattr(torch, "float4_e2m1fn_x2"): + w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2) + w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2) + else: + w13_weight = layer.w13_weight + w2_weight = layer.w2_weight + # print(f'>>> quant_config: {self.get_fused_moe_quant_config(layer)} {w13_weight.dtype=} {w2_weight.dtype=}') + out = rocm_aiter_fused_experts( x, - layer.w13_weight, - layer.w2_weight, + w13_weight, + w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, activation=activation, - quant_config=self.moe_quant_config, + quant_config=self.get_fused_moe_quant_config(layer), + # expert_map=expert_map, ) + + # out = rocm_aiter_fused_experts( + # x, + # layer.w13_weight, + # layer.w2_weight, + # topk_weights=topk_weights, + # topk_ids=topk_ids, + # activation=activation, + # quant_config=self.get_fused_moe_quant_config(layer), + # ) else: from vllm.model_executor.layers.fused_moe import fused_experts From 2ba9889eb618683d8ee554f06cf34c3231856364 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Thu, 4 Dec 2025 22:36:19 +0000 Subject: [PATCH 03/11] moe fix Signed-off-by: ShaoChunLee --- .../layers/quantization/quark/quark_moe.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index b71f10eb9000..0000bb227b77 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -619,7 +619,6 @@ def apply( else: w13_weight = layer.w13_weight w2_weight = layer.w2_weight - # print(f'>>> quant_config: {self.get_fused_moe_quant_config(layer)} {w13_weight.dtype=} {w2_weight.dtype=}') out = rocm_aiter_fused_experts( x, @@ -629,18 +628,8 @@ def apply( topk_ids=topk_ids, 
activation=activation, quant_config=self.get_fused_moe_quant_config(layer), - # expert_map=expert_map, + expert_map=expert_map, ) - - # out = rocm_aiter_fused_experts( - # x, - # layer.w13_weight, - # layer.w2_weight, - # topk_weights=topk_weights, - # topk_ids=topk_ids, - # activation=activation, - # quant_config=self.get_fused_moe_quant_config(layer), - # ) else: from vllm.model_executor.layers.fused_moe import fused_experts From 9b08275b357c5647842e13e122f74c2fd9368b94 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Sat, 6 Dec 2025 04:22:50 +0000 Subject: [PATCH 04/11] fp8 fusion Signed-off-by: ShaoChunLee --- vllm/_aiter_ops.py | 197 ++++++++++++- vllm/attention/layer.py | 42 ++- vllm/envs.py | 13 + vllm/model_executor/layers/mla.py | 8 +- .../layers/quantization/quark/quark_moe.py | 47 +-- vllm/model_executor/models/deepseek_v2.py | 248 ++++++++++++---- .../attention/backends/mla/rocm_aiter_mla.py | 269 +++++++++++++++++- 7 files changed, 725 insertions(+), 99 deletions(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 5a8bb10c0c9c..598662f694cd 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -652,7 +652,154 @@ def _rocm_aiter_act_mul_and_fp8_group_quant_fake( ) return x_fp8, out_bs +def _rocm_aiter_triton_fused_shared_expert_fp8_impl( + hidden_states_shared: torch.Tensor, + hidden_states_shared_scale: torch.Tensor, + weight_gate_up: torch.Tensor, + weight_scale_gate_up: torch.Tensor, + hidden_states_moe_gate: torch.Tensor, + weight_moe_gate: torch.Tensor, + bias_shared: torch.Tensor, + bias_moe_gate: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + from aiter.ops.triton.fused_gemm_a8w8_blockscale_a16w16 import fused_gemm_a8w8_blockscale_a16w16 + from aiter.ops.triton.fused_fp8_quant import fused_reduce_act_mul_fp8_group_quant + + shared_output, router_logits = fused_gemm_a8w8_blockscale_a16w16(hidden_states_shared, weight_gate_up, hidden_states_shared_scale, weight_scale_gate_up, hidden_states_moe_gate, 
weight_moe_gate, + bias_fp8=bias_shared, bias_bf16=bias_moe_gate, dtype=hidden_states_moe_gate.dtype, skip_reduce=True) + if shared_output.dim() == 3: + (shared_output_q, shared_output_s), router_logits = fused_reduce_act_mul_fp8_group_quant(shared_output, activation="silu", x2=router_logits, group_size=128, dtype_quant=AITER_FP8_DTYPE) + else: + (shared_output_q, shared_output_s), _ = fused_reduce_act_mul_fp8_group_quant(shared_output, activation="silu", x2=None, group_size=128, dtype_quant=AITER_FP8_DTYPE) + return shared_output_q, shared_output_s, router_logits + +def _rocm_aiter_triton_fused_shared_expert_fp8_fake( + hidden_states_shared: torch.Tensor, + hidden_states_shared_scale: torch.Tensor, + weight_gate_up: torch.Tensor, + weight_scale_gate_up: torch.Tensor, + hidden_states_moe_gate: torch.Tensor, + weight_moe_gate: torch.Tensor, + bias_shared: torch.Tensor, + bias_moe_gate: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + M = hidden_states_shared.shape[0] + N = weight_gate_up.shape[0] + N_moe = weight_moe_gate.shape[0] + device = hidden_states_shared.device + group_size = 128 + assert N % 2 == 0 + N_half = N // 2 + assert N_half == N_moe, f"{weight_moe_gate.shape}" + shared_output_q = torch.empty((M, N_half), dtype=AITER_FP8_DTYPE, device=device) + shared_output_s = torch.empty((M, (N_half + group_size - 1) // group_size), dtype=torch.float32, device=device) + router_logits = torch.empty((M, N_moe), dtype=hidden_states_moe_gate.dtype, device=device) + return shared_output_q, shared_output_s, router_logits + + +def _rocm_aiter_triton_fused_down_proj_mul_add_fp8_impl( + hidden_states_shared: torch.Tensor, + hidden_states_shared_scale: torch.Tensor, + weight_down_proj: torch.Tensor, + weight_scale_down_proj: torch.Tensor, + routed_scaling_factor: float, + final_hidden_states: torch.Tensor, +) -> torch.Tensor: + from aiter.ops.triton.fused_gemm_a8w8_blockscale_mul_add import fused_gemm_a8w8_blockscale_mul_add + + out = 
fused_gemm_a8w8_blockscale_mul_add(hidden_states_shared, weight_down_proj, hidden_states_shared_scale, weight_scale_down_proj, routed_scaling_factor, final_hidden_states, fuse_type=1) + return out + +def _rocm_aiter_triton_fused_down_proj_mul_add_fp8_fake( + hidden_states_shared: torch.Tensor, + hidden_states_shared_scale: torch.Tensor, + weight_down_proj: torch.Tensor, + weight_scale_down_proj: torch.Tensor, + routed_scaling_factor: float, + final_hidden_states: torch.Tensor, +) -> torch.Tensor: + out = torch.empty_like(final_hidden_states) + return out + +def _rocm_aiter_triton_fused_shared_expert_fp4_impl( + hidden_states_shared: torch.Tensor, + hidden_states_shared_scale: torch.Tensor, + weight_gate_up: torch.Tensor, + weight_scale_gate_up: torch.Tensor, + hidden_states_moe_gate: torch.Tensor, + weight_moe_gate: torch.Tensor, + bias_shared: torch.Tensor, + bias_moe_gate: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + from aiter.ops.triton.fused_gemm_afp4wfp4_a16w16 import fused_gemm_afp4wfp4_a16w16 + from aiter.ops.triton.fused_mxfp4_quant import fused_reduce_act_mul_and_mxfp4_quant + + shared_output, router_logits = fused_gemm_afp4wfp4_a16w16(hidden_states_shared, weight_gate_up, hidden_states_shared_scale, weight_scale_gate_up.T, hidden_states_moe_gate, weight_moe_gate, + is_fp4_preshuffled=False, bias_fp4=bias_shared, bias_bf16=bias_moe_gate, dtype=hidden_states_moe_gate.dtype, skip_reduce=True) + if shared_output.dim() == 3: + (shared_output_q, shared_output_s), router_logits = fused_reduce_act_mul_and_mxfp4_quant(shared_output, activation="silu", x2=router_logits, shuffle=False, scale_shuffle_padding=False, dtype=hidden_states_moe_gate.dtype) + else: + (shared_output_q, shared_output_s), _ = fused_reduce_act_mul_and_mxfp4_quant(shared_output, activation="silu", x2=None, shuffle=False, scale_shuffle_padding=False, dtype=hidden_states_moe_gate.dtype) + + # assert bias_shared is None + # shared_output = 
gemm_afp4wfp4(hidden_states_shared, weight_gate_up, hidden_states_shared_scale, weight_scale_gate_up.T) + # router_logits = gemm_a16w16(hidden_states_moe_gate, weight_moe_gate, bias=bias_moe_gate) + # shared_output_q, shared_output_s = act_mul_and_mxfp4_quant(shared_output, activation="silu", shuffle=False, scale_shuffle_padding=False) + + return shared_output_q, shared_output_s, router_logits + + +def _rocm_aiter_triton_fused_shared_expert_fp4_fake( + hidden_states_shared: torch.Tensor, + hidden_states_shared_scale: torch.Tensor, + weight_gate_up: torch.Tensor, + weight_scale_gate_up: torch.Tensor, + hidden_states_moe_gate: torch.Tensor, + weight_moe_gate: torch.Tensor, + bias_shared: torch.Tensor, + bias_moe_gate: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + M = hidden_states_shared.shape[0] + N = weight_gate_up.shape[0] + N_moe = weight_moe_gate.shape[0] + device = hidden_states_shared.device + group_size = 32 + assert N % 4 == 0 + N_half = N // 2 + assert N_half == 256, f"{weight_gate_up.shape}" + assert N_half == N_moe, f"{weight_moe_gate.shape}" + shared_output_q = torch.empty((M, N_half // 2), dtype=torch.uint8, device=device) + shared_output_s = torch.empty((M, (N_half + group_size - 1) // group_size), dtype=torch.uint8, device=device) + router_logits = torch.empty((M, N_moe), dtype=hidden_states_moe_gate.dtype, device=device) + return shared_output_q, shared_output_s, router_logits + + +def _rocm_aiter_triton_fused_down_proj_mul_add_fp4_impl( + hidden_states_shared: torch.Tensor, + hidden_states_shared_scale: torch.Tensor, + weight_down_proj: torch.Tensor, + weight_scale_down_proj: torch.Tensor, + routed_scaling_factor: float, + final_hidden_states: torch.Tensor, +) -> torch.Tensor: + from aiter.ops.triton.fused_gemm_afp4wfp4_mul_add import fused_gemm_afp4wfp4_mul_add + + out = fused_gemm_afp4wfp4_mul_add(hidden_states_shared, weight_down_proj, hidden_states_shared_scale, weight_scale_down_proj.T, routed_scaling_factor, 
final_hidden_states, fuse_type=1) + return out + + +def _rocm_aiter_triton_fused_down_proj_mul_add_fp4_fake( + hidden_states_shared: torch.Tensor, + hidden_states_shared_scale: torch.Tensor, + weight_down_proj: torch.Tensor, + weight_scale_down_proj: torch.Tensor, + routed_scaling_factor: float, + final_hidden_states: torch.Tensor, +) -> torch.Tensor: + out = torch.empty_like(final_hidden_states) + return out + # Global flag to ensure ops are registered only once _OPS_REGISTERED = False @@ -667,9 +814,11 @@ class rocm_aiter_ops: _MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA _TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM + _FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM _TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE _MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS + _TRITON_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_TRITON_FUSION_SHARED_EXPERTS _TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM @classmethod @@ -706,6 +855,11 @@ def is_fused_moe_enabled(cls) -> bool: @if_aiter_supported def is_fusion_moe_shared_experts_enabled(cls) -> bool: return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED + + @classmethod + @if_aiter_supported + def is_fusion_triton_shared_experts_enabled(cls) -> bool: + return cls.is_fused_moe_enabled() and cls._TRITON_SHARED_EXPERTS_ENABLED @classmethod @if_aiter_supported @@ -735,6 +889,11 @@ def is_triton_unified_attn_enabled(cls) -> bool: @if_aiter_supported def is_fp8bmm_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._FP8BMM_ENABLED + + @classmethod + @if_aiter_supported + def is_fp4bmm_enabled(cls) -> bool: + return cls._AITER_ENABLED and cls._FP4BMM_ENABLED @classmethod @if_aiter_supported @@ -874,6 +1033,40 @@ def register_ops_once() -> None: op_func=_rocm_aiter_act_mul_and_fp8_group_quant_impl, 
fake_impl=_rocm_aiter_act_mul_and_fp8_group_quant_fake, ) + + direct_register_custom_op( + op_name="rocm_aiter_triton_fused_shared_expert_fp8", + op_func=_rocm_aiter_triton_fused_shared_expert_fp8_impl, + mutates_args=[], + fake_impl=_rocm_aiter_triton_fused_shared_expert_fp8_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_triton_fused_down_proj_mul_add_fp8", + op_func=_rocm_aiter_triton_fused_down_proj_mul_add_fp8_impl, + mutates_args=[], + fake_impl=_rocm_aiter_triton_fused_down_proj_mul_add_fp8_fake, + dispatch_key=current_platform.dispatch_key, + ) + + + direct_register_custom_op( + op_name="rocm_aiter_triton_fused_shared_expert_fp4", + op_func=_rocm_aiter_triton_fused_shared_expert_fp4_impl, + mutates_args=[], + fake_impl=_rocm_aiter_triton_fused_shared_expert_fp4_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_triton_fused_down_proj_mul_add_fp4", + op_func=_rocm_aiter_triton_fused_down_proj_mul_add_fp4_impl, + mutates_args=[], + fake_impl=_rocm_aiter_triton_fused_down_proj_mul_add_fp4_fake, + dispatch_key=current_platform.dispatch_key, + ) + _OPS_REGISTERED = True @staticmethod @@ -1250,7 +1443,5 @@ def shuffle_weights( """ from aiter.ops.shuffle import shuffle_weight - return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors) - - + return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors) rocm_aiter_ops.register_ops_once() diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index da5a62617129..bfffc7dda7f0 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -587,6 +587,7 @@ def __init__( prefix: str = "", use_sparse: bool = False, indexer: object | None = None, + rotary_emb: nn.Module | None = None, **extra_impl_args, ): super().__init__() @@ -646,6 +647,7 @@ def __init__( indexer=indexer, **extra_impl_args, ) + self.impl.rotary_emb = rotary_emb self.use_direct_call = not 
current_platform.opaque_attention_op() @@ -674,6 +676,7 @@ def forward( kv_c_normed: torch.Tensor, k_pe: torch.Tensor, output_shape: torch.Size | None = None, + positions: torch.Tensor | None = None, ) -> torch.Tensor: if self.calculate_kv_scales: torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name) @@ -710,6 +713,7 @@ def forward( k_pe, output, self.layer_name, + positions, ) return output else: @@ -937,21 +941,37 @@ def unified_mla_attention_with_output( k_pe: torch.Tensor, output: torch.Tensor, layer_name: str, + positions: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> None: attn_metadata, self, kv_cache = get_attention_context(layer_name) - self.impl.forward( - self, - q, - kv_c_normed, - k_pe, - kv_cache, - attn_metadata, - output=output, - output_scale=output_scale, - output_block_scale=output_block_scale, - ) + from vllm.v1.attention.backends.mla.rocm_aiter_mla import AiterMLAImpl + if isinstance(self.impl, AiterMLAImpl): + self.impl.forward( + self, + q, + kv_c_normed, + k_pe, + kv_cache, + attn_metadata, + output=output, + output_scale=output_scale, + output_block_scale=output_block_scale, + positions=positions, + ) + else: + self.impl.forward( + self, + q, + kv_c_normed, + k_pe, + kv_cache, + attn_metadata, + output=output, + output_scale=output_scale, + output_block_scale=output_block_scale, + ) def unified_mla_attention_with_output_fake( diff --git a/vllm/envs.py b/vllm/envs.py index 37711dece9ab..aa49fa2edf58 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -117,8 +117,10 @@ VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False VLLM_ROCM_USE_AITER_FP8BMM: bool = True + VLLM_ROCM_USE_AITER_FP4BMM: bool = False VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = False + VLLM_ROCM_USE_AITER_TRITON_FUSION_SHARED_EXPERTS: bool = False VLLM_ROCM_USE_AITER_TRITON_GEMM: bool = 
True VLLM_ROCM_USE_SKINNY_GEMM: bool = True VLLM_ROCM_FP8_PADDING: bool = True @@ -981,6 +983,11 @@ def get_vllm_port() -> int | None: "VLLM_ROCM_USE_AITER_FP8BMM": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in ("true", "1") ), + # Whether to use aiter triton fp4 bmm kernel + # By default is disabled. + "VLLM_ROCM_USE_AITER_FP4BMM": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_FP4BMM", "False").lower() in ("true", "1") + ), # Use AITER triton unified attention for V1 attention "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "False").lower() @@ -992,6 +999,12 @@ def get_vllm_port() -> int | None: os.getenv("VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS", "False").lower() in ("true", "1") ), + # Whether to use aiter fusion triton shared experts ops. + # By default is disabled. + "VLLM_ROCM_USE_AITER_TRITON_FUSION_SHARED_EXPERTS": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSION_SHARED_EXPERTS", "False").lower() + in ("true", "1") + ), # Whether to use aiter triton kernels for gemm ops. # By default is enabled. 
"VLLM_ROCM_USE_AITER_TRITON_GEMM": lambda: ( diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index dad960160f2a..18f43126baeb 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -8,7 +8,7 @@ from vllm.config import CacheConfig from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.quantization import QuantizationConfig - +from vllm.platforms import current_platform @dataclass class MLAModules: @@ -103,6 +103,7 @@ def __init__( kv_b_proj=self.kv_b_proj, use_sparse=self.is_sparse, indexer=self.indexer, + rotary_emb = self.rotary_emb if current_platform.is_rocm() else None ) self.prefix = prefix @@ -150,7 +151,7 @@ def forward_native( # Add head dim of 1 to k_pe k_pe = k_pe.unsqueeze(1) - if self.rotary_emb is not None: + if self.rotary_emb is not None and not current_platform.is_rocm(): q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb( positions, q[..., self.qk_nope_head_dim :], k_pe ) @@ -163,12 +164,13 @@ def forward_native( if llama_4_scaling is not None: q *= llama_4_scaling + positions_rocm = None if not current_platform.is_rocm() else positions attn_out = self.mla_attn( q, kv_c_normed, k_pe, output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim), - ) + positions=positions_rocm) return self.o_proj(attn_out)[0] diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 0000bb227b77..b808acaa7862 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -19,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, - mxfp4_w4a16_moe_quant_config, + ocp_mx_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 
import ( @@ -27,6 +27,7 @@ ) from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( OCP_MX_BLOCK_SIZE, + OCP_MX_Scheme, ) from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -434,9 +435,13 @@ def __init__( self.static_input_scales = not self.input_quant.get("is_dynamic") self.weight_dtype = self.weight_quant["dtype"].replace("fp", "mxfp") - self.input_dtype = self.input_quant["dtype"] + self.input_dtype = self.input_quant["dtype"].replace("fp", "mxfp") self.fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None) + self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype( + self.input_dtype, self.weight_dtype + ) + if self.static_input_scales: raise NotImplementedError( "QuarkOCP_MX_MoEMethod with static input scales is currently " @@ -445,11 +450,14 @@ def __init__( self.use_rocm_aiter_moe = rocm_aiter_ops.is_fused_moe_enabled() - self.emulate = not current_platform.supports_mx() or not self.use_rocm_aiter_moe + self.emulate = not current_platform.supports_mx() or not ( + self.use_rocm_aiter_moe and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4" + ) if self.emulate: logger.warning_once( f"The current mode (supports_mx={current_platform.supports_mx()}, " f"use_mxfp4_aiter_moe={self.use_rocm_aiter_moe}, " + f"ocp_mx_scheme={self.ocp_mx_scheme}) " "does not support native MXFP4/MXFP6 " "computation. Simulated weight dequantization and activation " "QDQ (quantize and dequantize) will be used, with the linear " @@ -570,10 +578,14 @@ def process_weights_after_loading(self, layer): def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: - # The default mxfp4 recipe is with activation fp16/bf16 dynamic quantzied - # and weight mxfp4 offline quantized. 
- return mxfp4_w4a16_moe_quant_config( - layer.w13_weight_scale, layer.w2_weight_scale + return ocp_mx_moe_quant_config( + quant_dtype=self.input_dtype, + weight_dtype=self.weight_dtype, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=None, + a2_scale=None, + block_shape=None, ) @property @@ -613,22 +625,21 @@ def apply( rocm_aiter_fused_experts, ) - if hasattr(torch, "float4_e2m1fn_x2"): - w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2) - w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2) - else: - w13_weight = layer.w13_weight - w2_weight = layer.w2_weight + # if hasattr(torch, "float4_e2m1fn_x2"): + # w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2) + # w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2) + # else: + # w13_weight = layer.w13_weight + # w2_weight = layer.w2_weight out = rocm_aiter_fused_experts( x, - w13_weight, - w2_weight, + layer.w13_weight, + layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, activation=activation, - quant_config=self.get_fused_moe_quant_config(layer), - expert_map=expert_map, + quant_config=self.moe_quant_config, ) else: from vllm.model_executor.layers.fused_moe import fused_experts @@ -646,4 +657,4 @@ def apply( expert_map=expert_map, quant_config=self.moe_quant_config, ) - return out + return out \ No newline at end of file diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index a8eb4a69b6f2..8ef55c787e18 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -50,6 +50,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm from vllm.model_executor.layers.linear import ( 
ColumnParallelLinear, @@ -265,11 +266,15 @@ def __init__( quant_config=None, prefix=f"{prefix}.gate", ) - if getattr(config, "topk_method", None) == "noaux_tc": + self.is_fusion_triton_shared_experts_enabled = rocm_aiter_ops.is_fusion_triton_shared_experts_enabled() + if config.topk_method == "noaux_tc": self.gate.e_score_correction_bias = nn.Parameter( - torch.empty(config.n_routed_experts, dtype=torch.float32) - ) + torch.empty(config.n_routed_experts, dtype=torch.float32)) + e_score_correction_bias = self.gate.e_score_correction_bias + if self.is_fusion_triton_shared_experts_enabled: + e_score_correction_bias = self.gate.e_score_correction_bias.to(torch.bfloat16) else: + e_score_correction_bias = None self.gate.e_score_correction_bias = None # Load balancing settings. @@ -290,9 +295,45 @@ def __init__( self.is_fusion_moe_shared_experts_enabled = ( rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() ) - if config.n_shared_experts is None or self.is_fusion_moe_shared_experts_enabled: - self.shared_experts = None - else: + if self.is_fusion_triton_shared_experts_enabled: + self.use_triton_fused_shared_expert_fp8 = False + self.use_triton_fused_shared_expert_fp4 = False + self.rocm_aiter_triton_fused_shared_expert_func = None + self.rocm_aiter_triton_fused_down_proj_mul_add_func = None + if self.is_fusion_triton_shared_experts_enabled: + assert config.n_shared_experts is not None, f"config.n_shared_experts == None is detected in {self.__class__.__name__} please turn off VLLM_ROCM_USE_AITER_TRITON_FUSION_SHARED_EXPERTS" + if quant_config.get_name() == 'fp8': + self.use_triton_fused_shared_expert_fp8 = True + self.rocm_aiter_triton_fused_shared_expert_func = torch.ops.vllm.rocm_aiter_triton_fused_shared_expert_fp8 + self.rocm_aiter_triton_fused_down_proj_mul_add_func = torch.ops.vllm.rocm_aiter_triton_fused_down_proj_mul_add_fp8 + elif quant_config.get_name() == 'quark': + self.use_triton_fused_shared_expert_fp4 = True + 
self.rocm_aiter_triton_fused_shared_expert_func = torch.ops.vllm.rocm_aiter_triton_fused_shared_expert_fp4 + self.rocm_aiter_triton_fused_down_proj_mul_add_func = torch.ops.vllm.rocm_aiter_triton_fused_down_proj_mul_add_fp4 + else: + raise NotImplementedError(f"{quant_config.get_name()=} which is not supported for VLLM_ROCM_USE_AITER_TRITON_FUSION_SHARED_EXPERTS") + logger.info(f"[Aiter] {self.__class__.__name__} is registered with {self.rocm_aiter_triton_fused_shared_expert_func.__name__} and {self.rocm_aiter_triton_fused_down_proj_mul_add_func.__name__}") + + self.experts = FusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=getattr(config, "n_group", 1), + topk_group=getattr(config, "topk_group", 1), + prefix=f"{prefix}.experts", + scoring_func=getattr(config, "scoring_func", "softmax"), + # we do scaling outside, set factor to 1.0 to avoid double mul + routed_scaling_factor=1.0, + e_score_correction_bias=e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + ) intermediate_size = config.moe_intermediate_size * config.n_shared_experts self.shared_experts = DeepseekV2MLP( @@ -304,37 +345,58 @@ def __init__( reduce_results=False, prefix=f"{prefix}.shared_experts", ) - - self.experts = SharedFusedMoE( - shared_experts=self.shared_experts, - gate=self.gate, - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=getattr(config, "n_group", 1), - topk_group=getattr(config, "topk_group", 
1), - prefix=f"{prefix}.experts", - scoring_func=getattr(config, "scoring_func", "softmax"), - # we do scaling outside, set factor to 1.0 to avoid double mul - # aiter applies routed_scaling_factor internally - routed_scaling_factor=1.0 - if not self.is_rocm_aiter_moe_enabled - else self.routed_scaling_factor, - e_score_correction_bias=self.gate.e_score_correction_bias, - enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts, - is_sequence_parallel=self.is_sequence_parallel, - n_shared_experts=config.n_shared_experts - if self.is_fusion_moe_shared_experts_enabled - else None, + else: + if config.n_shared_experts is None: + self.shared_experts = None + else: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + + self.shared_experts = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + is_sequence_parallel=self.is_sequence_parallel, + reduce_results=False, + prefix=f"{prefix}.shared_experts", + ) + + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, + gate=self.gate, + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=getattr(config, "n_group", 1), + topk_group=getattr(config, "topk_group", 1), + prefix=f"{prefix}.experts", + scoring_func=getattr(config, "scoring_func", "softmax"), + # we do scaling outside, set factor to 1.0 to avoid double mul + # aiter applies routed_scaling_factor internally + routed_scaling_factor=1.0 + if not self.is_rocm_aiter_moe_enabled + else self.routed_scaling_factor, + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + 
is_sequence_parallel=self.is_sequence_parallel, + n_shared_experts=config.n_shared_experts + if self.is_fusion_moe_shared_experts_enabled + else None, + skip_shared_experts = self.is_fusion_triton_shared_experts_enabled, ) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: tuple | torch.Tensor) -> torch.Tensor: + if isinstance(hidden_states, tuple): + hidden_states_shared, hidden_states = hidden_states + else: + hidden_states_shared = hidden_states + num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) @@ -345,34 +407,70 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.is_sequence_parallel: hidden_states = sequence_parallel_chunk(hidden_states) - if self.experts.is_internal_router: - # In this case, the gate/router runs inside the FusedMoE class - fused_moe_out = self.experts( - hidden_states=hidden_states, router_logits=hidden_states + if self.is_fusion_triton_shared_experts_enabled: + # assert isinstance(hidden_states_shared, tuple), f"hidden_states_shared must be a tuple of quantized acitvation and scales" + shared_output = None + shared_output_q, shared_output_s = None, None + hidden_states_shared, hidden_states_shared_scale = hidden_states_shared + shared_output_q, shared_output_s, router_logits = ( + self.rocm_aiter_triton_fused_shared_expert_func( + hidden_states_shared=hidden_states_shared, + hidden_states_shared_scale=hidden_states_shared_scale, + weight_gate_up=self.shared_experts.gate_up_proj.weight, + weight_scale_gate_up=self.shared_experts.gate_up_proj.weight_scale, + hidden_states_moe_gate=hidden_states, + weight_moe_gate=self.gate.weight, + bias_shared=( + self.shared_experts.gate_up_proj.bias + if not self.shared_experts.gate_up_proj.skip_bias_add + else None + ), + bias_moe_gate=( + self.gate.bias if not self.gate.skip_bias_add else None + ), + ) ) - else: - # router_logits: (num_tokens, n_experts) - router_logits, _ = 
self.gate(hidden_states) - fused_moe_out = self.experts( + # shared_output = self.shared_experts(hidden_states) + # router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits ) + else: + if self.experts.is_internal_router: + # In this case, the gate/router runs inside the FusedMoE class + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=hidden_states + ) + else: + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) - shared_output, final_hidden_states = fused_moe_out - if self.shared_experts is None: - assert shared_output is None + shared_output, final_hidden_states = fused_moe_out + if self.shared_experts is None: + assert shared_output is None # Fix FP16 overflow # See DeepseekV2DecoderLayer for more details. - if hidden_states.dtype != torch.float16: - if not self.is_rocm_aiter_moe_enabled: - final_hidden_states *= self.routed_scaling_factor - elif self.shared_experts is not None: - assert shared_output is not None - shared_output *= 1.0 / self.routed_scaling_factor + if self.is_fusion_triton_shared_experts_enabled and hidden_states.dtype != torch.float16: + assert shared_output is None + final_hidden_states = self.rocm_aiter_triton_fused_down_proj_mul_add_func(shared_output_q, shared_output_s, self.shared_experts.down_proj.weight, self.shared_experts.down_proj.weight_scale, self.routed_scaling_factor, final_hidden_states) + # assert shared_output is not None + # final_hidden_states *= self.routed_scaling_factor + # final_hidden_states += shared_output + else: + if hidden_states.dtype != torch.float16: + if not self.is_rocm_aiter_moe_enabled: + final_hidden_states *= self.routed_scaling_factor + elif self.shared_experts is not None: + assert shared_output is not None + shared_output *= 1.0 / self.routed_scaling_factor - if 
self.shared_experts is not None: - assert shared_output is not None - final_hidden_states += shared_output + if self.shared_experts is not None: + assert shared_output is not None + final_hidden_states += shared_output if self.is_sequence_parallel: final_hidden_states = tensor_model_parallel_all_gather( @@ -1125,6 +1223,17 @@ def __init__( # with the layer's index. layer_idx = int(prefix.split(sep=".")[-1]) self.layer_idx = layer_idx + self.is_fusion_triton_shared_experts_enabled = rocm_aiter_ops.is_fusion_triton_shared_experts_enabled() + self.use_triton_fused_rmsnorm_fp8_quant = False + self.use_triton_fused_rmsnorm_fp4_quant = False + if self.is_fusion_triton_shared_experts_enabled: + if quant_config.get_name() == 'fp8': + self.use_triton_fused_rmsnorm_fp8_quant = True + elif quant_config.get_name() == 'quark': + self.use_triton_fused_rmsnorm_fp4_quant = True + else: + raise NotImplementedError(f"{quant_config.get_name()=} which is not supported for VLLM_ROCM_USE_AITER_TRITON_FUSION_SHARED_EXPERTS") + logger.info(f"[Aiter] {self.__class__.__name__} has {quant_config.get_name()=}") # verify MLA attention specific fields qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0) @@ -1221,7 +1330,34 @@ def forward( residual *= 1.0 / self.routed_scaling_factor # Fully Connected - hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + + if self.is_fusion_triton_shared_experts_enabled and isinstance(self.mlp, DeepseekV2MoE): + weight = self.post_attention_layernorm.weight + eps = self.post_attention_layernorm.variance_epsilon + if self.use_triton_fused_rmsnorm_fp8_quant: + from vllm._aiter_ops import AITER_FP8_DTYPE + from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant + (hidden_states_quant, hidden_states_quant_scales), hidden_states_unquant, _, residual = fused_rms_fp8_group_quant(hidden_states, weight, eps, + None, None, eps, + group_size=128, + dtype_quant=AITER_FP8_DTYPE, + res1=residual, + 
output_unquantized_inp1=isinstance(self.mlp, DeepseekV2MoE)) + else: + from aiter.ops.triton.fused_mxfp4_quant import fused_rms_mxfp4_quant + (hidden_states_quant, hidden_states_quant_scales), hidden_states_unquant, _, residual = fused_rms_mxfp4_quant(hidden_states, weight, eps, + None, None, eps, + res1=residual, + shuffle=False, + scale_shuffle_padding=False, + output_unquantized_inp1=isinstance(self.mlp, DeepseekV2MoE)) + + if isinstance(self.mlp, DeepseekV2MoE): + hidden_states = ((hidden_states_quant, hidden_states_quant_scales), hidden_states_unquant) + else: + hidden_states = (hidden_states_quant, hidden_states_quant_scales) + else: + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16: diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 00a0a77a1c2f..9c688c3f88e8 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -6,6 +6,7 @@ import torch +from vllm import _custom_ops as ops from vllm._aiter_ops import rocm_aiter_ops from vllm.attention.backends.abstract import AttentionLayer, MultipleOf from vllm.config import VllmConfig @@ -18,7 +19,12 @@ ) from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank +from vllm.attention.ops.common import cp_lse_ag_out_rs +from vllm.platforms import current_platform +from typing import ClassVar, Generic, TypeVar +M = TypeVar("M", bound=MLACommonMetadata) class AiterMLABackend(MLACommonBackend): @staticmethod @@ -200,6 +206,7 @@ def __init__( kv_sharing_target_layer_name, **mla_args, ) + self.is_aiter_triton_fp4_bmm_enabled = rocm_aiter_ops.is_fp4bmm_enabled() assert num_heads == 16 or num_heads == 128, ( 
f"Aiter MLA only supports 16 or 128 number of heads.\n" f"Provided {num_heads} number of heads.\n" @@ -236,22 +243,32 @@ def _forward_decode( kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: AiterMLAMetadata, layer: AttentionLayer, + mla_output_zeros: torch.Tensor | None = None, + decode_q_cat: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None - if type(q) is tuple: + if decode_q_cat is not None: + q = decode_q_cat + elif type(q) is tuple: q = torch.cat(q, dim=-1) assert isinstance(q, torch.Tensor) B = q.shape[0] - o = torch.zeros( - B, - self.num_heads, - self.kv_lora_rank, - dtype=attn_metadata.decode.attn_out_dtype, - device=q.device, - ) + if mla_output_zeros is not None: + o = mla_output_zeros + assert o.shape[0] == B, f"{o.shape[0]=} {B=}" + assert o.shape[1] == self.num_heads, f"{o.shape[1]=} {self.num_heads=}" + assert o.shape[2] == self.kv_lora_rank, f"{o.shape[2]=} {self.kv_lora_rank=}" + else: + o = torch.zeros( + B, + self.num_heads, + self.kv_lora_rank, + dtype=attn_metadata.decode.attn_out_dtype, + device=q.device, + ) kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) @@ -273,3 +290,239 @@ def _forward_decode( ) return o, None + + def forward( + self, + layer: AttentionLayer, + q: torch.Tensor, + k_c_normed: torch.Tensor, # key in unified attn + k_pe: torch.Tensor, # value in unified attn + kv_cache: torch.Tensor, + attn_metadata: M, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, + positions: torch.Tensor | None = None, + ) -> torch.Tensor: + assert output is not None, "Output tensor must be provided." 
+ + if output_scale is not None or output_block_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported for MLACommonImpl" + ) + + if attn_metadata is None: + # During the profile run try to simulate to worse case output size + # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context` + # since this can be large + _ = torch.empty( + ( + self.chunked_prefill_workspace_size, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ), + device=k_c_normed.device, + dtype=k_c_normed.dtype, + ) + + # The zero fill is required when used with DP + EP + # to ensure all ranks within a DP group compute the + # same expert outputs. + return output.fill_(0) + + if self.dcp_world_size is None: + self.dcp_world_size = get_dcp_group().world_size + + fp8_attention = self.kv_cache_dtype.startswith("fp8") + + num_actual_toks = attn_metadata.num_actual_tokens + + # Inputs and outputs may be padded for CUDA graphs + output_padded = output + output = output[:num_actual_toks, ...] + q = q[:num_actual_toks, ...] + k_c_normed = k_c_normed[:num_actual_toks, ...] + k_pe = k_pe[:num_actual_toks, ...] 
+ + assert ( + attn_metadata.num_decodes is not None + and attn_metadata.num_prefills is not None + and attn_metadata.num_decode_tokens is not None + ) + + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + + mla_output_zeros = None + decode_q_cat = None + # write the latent and rope to kv cache + if kv_cache.numel() > 0: + if positions is not None: + # positions is not None entails that Q and K are not RoPE embedded yet, therefore, fused_qk_rope_cat_and_cache_mla is called + assert hasattr(self, "rotary_emb"), f"rotary_emb not found in {self}" + from aiter.ops.triton.fused_kv_cache import fused_qk_rope_cat_and_cache_mla + cos, sin = self.rotary_emb.cos_sin_cache.chunk(2, dim = -1) + is_neox = self.rotary_emb.is_neox_style + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + q_out_dtype = current_platform.fp8_dtype() if fp8_attention else q.dtype + if self.is_aiter_triton_fp4_bmm_enabled or self.is_aiter_triton_fp8_bmm_enabled: + decode_q_cat = torch.empty((num_decode_tokens, self.num_heads, self.W_K.shape[1] + self.qk_rope_head_dim), dtype = q_out_dtype, device=q.device) + if fp8_attention: + kv_cache_og_dtype = kv_cache.dtype + kv_cache = kv_cache.view(q_out_dtype) + fused_output = fused_qk_rope_cat_and_cache_mla( + q_nope, + q_pe, + k_c_normed.unsqueeze(1), + k_pe, + kv_cache, + attn_metadata.slot_mapping.flatten(), + positions, + cos, + sin, + layer._k_scale, + is_neox, + num_decode_toks_for_zeros=num_decode_tokens, + apply_scale=(k_pe.dtype != kv_cache.dtype), + q_out=None, + decode_q_pe_out = decode_q_cat[... 
, -self.qk_rope_head_dim:] if self.is_aiter_triton_fp4_bmm_enabled or self.is_aiter_triton_fp8_bmm_enabled else None, + k_pe_out=k_pe, + ) + if num_decode_tokens > 0: + q, _, k_pe, kv_cache, mla_output_zeros = fused_output + else: + q, _, k_pe, kv_cache = fused_output + if fp8_attention: + kv_cache = kv_cache.view(kv_cache_og_dtype) + else: + ops.concat_and_cache_mla( + k_c_normed, + k_pe.squeeze(1), + kv_cache, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + scale=layer._k_scale, + ) + + decode_q = q[:num_decode_tokens] + + prefill_q = q[num_decode_tokens:] + prefill_k_pe = k_pe[num_decode_tokens:] + prefill_k_c_normed = k_c_normed[num_decode_tokens:] + + if fp8_attention: + kv_cache = kv_cache.view(current_platform.fp8_dtype()) + + if has_prefill: + self._forward_prefill( + prefill_q, + prefill_k_c_normed, + prefill_k_pe, + kv_cache, + attn_metadata, + layer._k_scale, + output=output[num_decode_tokens:], + ) + + if has_decode: + assert attn_metadata.decode is not None + + decode_q_nope, decode_q_pe = decode_q.split( + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + # Convert from (B, N, P) to (N, B, P) + decode_q_nope = decode_q_nope.transpose(0, 1) + + if self.q_pad_num_heads is not None: + B, N, L = decode_q_pe.shape + decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L)) + decode_pe_padded.resize_((B, N, L)) + decode_pe_padded.copy_(decode_q_pe) + decode_q_pe = decode_pe_padded + + if self.is_aiter_triton_fp4_bmm_enabled: + #x = x.view(-1, self.num_heads, self.kv_lora_rank) + decode_ql_nope = decode_q_cat[... 
, :self.W_K.shape[1]] if (kv_cache.numel() > 0 and positions is not None) else None + # decode_ql_nope = batched_gemm_a16wfp4( + # decode_q_nope, + # self.W_K, + # self.W_K_scale, + # y=decode_ql_nope, + # transpose_bm=True, + # prequant=True, + # y_scale=layer._q_scale if fp8_attention else None, + # ) + # decode_ql_nope = decode_ql_nope.transpose(0, 1) + elif self.is_aiter_triton_fp8_bmm_enabled: + # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) + decode_ql_nope = decode_q_cat[... , :self.W_K.shape[1]] if (kv_cache.numel() > 0 and positions is not None) else None + decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm( + decode_q_nope, + self.W_K, + self.W_K_scale, + group_size=128, + YQ=decode_ql_nope, + transpose_bm=True, + ) + else: + # Pads the head_dim if necessary (for the underlying kernel) + N, B, P = decode_q_nope.shape + _, _, L = self.W_UK_T.shape + + if self.q_pad_num_heads is not None: + decode_ql_nope = decode_q_nope.new_empty( + (self.q_pad_num_heads, B, L) + ) + decode_ql_nope.resize_((N, B, L)) + else: + decode_ql_nope = decode_q_nope.new_empty((N, B, L)) + + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope) + + # Convert from (N, B, L) to (B, N, L) + decode_ql_nope = decode_ql_nope.transpose(0, 1) + + if fp8_attention and not (self.is_aiter_triton_fp4_bmm_enabled or self.is_aiter_triton_fp8_bmm_enabled): + ql_nope_shape = decode_ql_nope.shape + decode_ql_nope, _ = ops.scaled_fp8_quant( + decode_ql_nope.reshape( + [ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2]] + ), + layer._q_scale, + ) + decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape) + q_pe_shape = decode_q_pe.shape + decode_q_pe, _ = ops.scaled_fp8_quant( + decode_q_pe.reshape([q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]), + layer._q_scale, + ) + decode_q_pe = decode_q_pe.reshape(q_pe_shape) + + decode_q = (decode_ql_nope, decode_q_pe) + if self.dcp_world_size > 1: + assert not fp8_attention, "DCP not support fp8 
kvcache now." + # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P) + decode_q = torch.cat(decode_q, dim=-1) + # decode_q do allgather in head dim. + decode_q = get_dcp_group().all_gather(decode_q, dim=1) + + # call decode attn + attn_out, lse = self._forward_decode( + decode_q, kv_cache, attn_metadata, layer, mla_output_zeros=mla_output_zeros, decode_q_cat=decode_q_cat + ) + + # correct dcp attn_out with lse. + if self.dcp_world_size > 1: + attn_out = cp_lse_ag_out_rs( + attn_out, + lse, + get_dcp_group(), + is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False), + ) + + # v_up projection + self._v_up_proj(attn_out, out=output[:num_decode_tokens]) + return output_padded From 690f8becac636fa27c4e401c3848e0d67eaa7711 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Mon, 8 Dec 2025 17:49:04 +0000 Subject: [PATCH 05/11] add fp4 bmm and qkv_a_proj_layernorm --- vllm/_aiter_ops.py | 236 +++++++++++++++++- vllm/compilation/rocm_aiter_fusion.py | 42 ++++ vllm/model_executor/layers/mla.py | 102 +++++--- .../quark/schemes/quark_ocp_mx.py | 8 + .../layers/quantization/quark/utils.py | 99 ++++++++ .../layers/quantization/utils/fp8_utils.py | 1 - vllm/model_executor/models/deepseek_v2.py | 51 +++- vllm/v1/attention/backends/mla/common.py | 66 ++++- .../attention/backends/mla/rocm_aiter_mla.py | 17 +- 9 files changed, 569 insertions(+), 53 deletions(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 598662f694cd..5a9b323f8a5c 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -421,6 +421,54 @@ def _rocm_aiter_gemm_a8w8_blockscale_fake( return Y +def _rocm_aiter_triton_gemm_afp4wfp4_impl( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4 + + return gemm_afp4wfp4(A, B, As, Bs.T, dtype=output_dtype) + + +def _rocm_aiter_triton_gemm_afp4wfp4_fake( + A: torch.Tensor, + B: torch.Tensor, + As: 
torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + m = A.shape[0] + n = B.shape[0] + Y = torch.empty(m, n, dtype=output_dtype, device=A.device) + return Y + + +def _rocm_aiter_triton_gemm_a16w8_blockscale_impl( + A: torch.Tensor, + B: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + from aiter.ops.triton.gemm_a16w8_blockscale import gemm_a16w8_blockscale + + return gemm_a16w8_blockscale(A, B, Bs, dtype=output_dtype) + + +def _rocm_aiter_triton_gemm_a16w8_blockscale_fake( + A: torch.Tensor, + B: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + m = A.shape[0] + n = B.shape[0] + Y = torch.empty(m, n, dtype=output_dtype, device=A.device) + return Y + + def _rocm_aiter_rms_norm_impl( x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float ) -> torch.Tensor: @@ -799,7 +847,126 @@ def _rocm_aiter_triton_fused_down_proj_mul_add_fp4_fake( ) -> torch.Tensor: out = torch.empty_like(final_hidden_states) return out + +def _rocm_aiter_triton_qkv_a_proj_layernorm_fp8_impl( + hidden_states_quant: torch.Tensor, + hidden_states_quant_scale: torch.Tensor, + weight_qkv_a_proj: torch.Tensor, + weight_scale_qkv_a_proj: torch.Tensor, + q_a_layernorm_weight: torch.Tensor, + q_a_layernorm_variance_epsilon: float, + kv_a_layernorm_weight: torch.Tensor, + kv_a_layernorm_variance_epsilon: float, + q_lora_rank: int, + kv_lora_rank: int, + qk_rope_head_dim: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + from aiter.ops.triton.fused_fp8_quant import fused_reduce_rms_fp8_group_quant + from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale + import aiter as rocm_aiter + qkv_lora = gemm_a8w8_blockscale(hidden_states_quant, weight_qkv_a_proj, hidden_states_quant_scale, weight_scale_qkv_a_proj, skip_reduce=True) + q_c, kv_c, k_pe = qkv_lora.split([q_lora_rank, kv_lora_rank, qk_rope_head_dim], + 
dim=-1, + ) + k_pe_reduced = None + k_pe_reduced_out = None + if k_pe.dim() == 3: + M = hidden_states_quant.shape[0] + device = hidden_states_quant.device + k_pe_reduced = k_pe + k_pe_reduced_out = torch.empty((M, q_lora_rank + kv_lora_rank + qk_rope_head_dim), dtype=torch.bfloat16, device=device)[..., :qk_rope_head_dim] + (q_c, q_c_scale), _, kv_c_normed, _, k_pe_reduced_out = fused_reduce_rms_fp8_group_quant(q_c, q_a_layernorm_weight, q_a_layernorm_variance_epsilon, + kv_c, kv_a_layernorm_weight, kv_a_layernorm_variance_epsilon, k_pe_reduced, + group_size=128, + dtype_quant=AITER_FP8_DTYPE, + dtype=torch.bfloat16, + res1=None, + out3=k_pe_reduced_out) + if k_pe_reduced_out is not None: + k_pe = k_pe_reduced_out + return q_c, q_c_scale, kv_c_normed, k_pe + +def _rocm_aiter_triton_qkv_a_proj_layernorm_fp8_fake( + hidden_states_quant: torch.Tensor, + hidden_states_quant_scale: torch.Tensor, + weight_qkv_a_proj: torch.Tensor, + weight_scale_qkv_a_proj: torch.Tensor, + q_a_layernorm_weight: torch.Tensor, + q_a_layernorm_variance_epsilon: float, + kv_a_layernorm_weight: torch.Tensor, + kv_a_layernorm_variance_epsilon: float, + q_lora_rank: int, + kv_lora_rank: int, + qk_rope_head_dim: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + M = hidden_states_quant.shape[0] + device = hidden_states_quant.device + q_c = torch.empty((M, q_lora_rank), dtype=AITER_FP8_DTYPE, device=device) + q_c_scale = torch.empty((M, (q_lora_rank + 128 - 1) // 128), dtype=torch.float32, device=device) + kv_c_normed = torch.empty((M, kv_lora_rank), dtype=torch.bfloat16, device=device) + k_pe = torch.empty((M, q_lora_rank + kv_lora_rank + qk_rope_head_dim), dtype=torch.bfloat16, device=device)[..., :qk_rope_head_dim] + return q_c, q_c_scale, kv_c_normed, k_pe + +def _rocm_aiter_triton_qkv_a_proj_layernorm_fp4_impl( + hidden_states_quant: torch.Tensor, + hidden_states_quant_scale: torch.Tensor, + weight_qkv_a_proj: torch.Tensor, + weight_scale_qkv_a_proj: torch.Tensor, + 
q_a_layernorm_weight: torch.Tensor, + q_a_layernorm_variance_epsilon: float, + kv_a_layernorm_weight: torch.Tensor, + kv_a_layernorm_variance_epsilon: float, + q_lora_rank: int, + kv_lora_rank: int, + qk_rope_head_dim: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4 + from aiter.ops.triton.fused_mxfp4_quant import fused_reduce_rms_mxfp4_quant + + qkv_lora = gemm_afp4wfp4(hidden_states_quant, weight_qkv_a_proj, hidden_states_quant_scale, weight_scale_qkv_a_proj.T, skip_reduce=True) + q_c, kv_c, k_pe = qkv_lora.split([q_lora_rank, kv_lora_rank, qk_rope_head_dim], + dim=-1, + ) + k_pe_reduced = None + k_pe_reduced_out = None + if k_pe.dim() == 3: + M = hidden_states_quant.shape[0] + device = hidden_states_quant.device + k_pe_reduced = k_pe + k_pe_reduced_out = torch.empty((M, q_lora_rank + kv_lora_rank + qk_rope_head_dim), dtype=torch.bfloat16, device=device)[..., :qk_rope_head_dim] + (q_c, q_c_scale), _, kv_c_normed, _, k_pe_reduced_out = fused_reduce_rms_mxfp4_quant(q_c, q_a_layernorm_weight, q_a_layernorm_variance_epsilon, + kv_c, kv_a_layernorm_weight, kv_a_layernorm_variance_epsilon, k_pe_reduced, + res1=None, + shuffle=False, + scale_shuffle_padding=False, + dtype=torch.bfloat16, + out3=k_pe_reduced_out) + if k_pe_reduced_out is not None: + k_pe = k_pe_reduced_out + return q_c, q_c_scale, kv_c_normed, k_pe + +def _rocm_aiter_triton_qkv_a_proj_layernorm_fp4_fake( + hidden_states_quant: torch.Tensor, + hidden_states_quant_scale: torch.Tensor, + weight_qkv_a_proj: torch.Tensor, + weight_scale_qkv_a_proj: torch.Tensor, + q_a_layernorm_weight: torch.Tensor, + q_a_layernorm_variance_epsilon: float, + kv_a_layernorm_weight: torch.Tensor, + kv_a_layernorm_variance_epsilon: float, + q_lora_rank: int, + kv_lora_rank: int, + qk_rope_head_dim: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + M = hidden_states_quant.shape[0] + device = hidden_states_quant.device + 
q_c = torch.empty((M, q_lora_rank // 2), dtype=torch.uint8, device=device) + q_c_scale = torch.empty((M, (q_lora_rank + 32 - 1) // 32), dtype=torch.float32, device=device) + kv_c_normed = torch.empty((M, kv_lora_rank), dtype=torch.bfloat16, device=device) + k_pe = torch.empty((M, q_lora_rank + kv_lora_rank + qk_rope_head_dim), dtype=torch.bfloat16, device=device)[..., :qk_rope_head_dim] + return q_c, q_c_scale, kv_c_normed, k_pe + # Global flag to ensure ops are registered only once _OPS_REGISTERED = False @@ -893,7 +1060,8 @@ def is_fp8bmm_enabled(cls) -> bool: @classmethod @if_aiter_supported def is_fp4bmm_enabled(cls) -> bool: - return cls._AITER_ENABLED and cls._FP4BMM_ENABLED + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._FP4BMM_ENABLED and current_platform.supports_mx() @classmethod @if_aiter_supported @@ -989,6 +1157,18 @@ def register_ops_once() -> None: op_func=_rocm_aiter_gemm_a8w8_blockscale_impl, fake_impl=_rocm_aiter_gemm_a8w8_blockscale_fake, ) + + direct_register_custom_op( + op_name="rocm_aiter_triton_gemm_afp4wfp4", + op_func=_rocm_aiter_triton_gemm_afp4wfp4_impl, + fake_impl=_rocm_aiter_triton_gemm_afp4wfp4_fake, + ) + + direct_register_custom_op( + op_name="rocm_aiter_triton_gemm_a16w8_blockscale", + op_func=_rocm_aiter_triton_gemm_a16w8_blockscale_impl, + fake_impl=_rocm_aiter_triton_gemm_a16w8_blockscale_fake, + ) direct_register_custom_op( op_name="rocm_aiter_rms_norm", @@ -1067,6 +1247,21 @@ def register_ops_once() -> None: dispatch_key=current_platform.dispatch_key, ) + direct_register_custom_op( + op_name="rocm_aiter_triton_qkv_a_proj_layernorm_fp8", + op_func=_rocm_aiter_triton_qkv_a_proj_layernorm_fp8_impl, + mutates_args=[], + fake_impl=_rocm_aiter_triton_qkv_a_proj_layernorm_fp8_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_triton_qkv_a_proj_layernorm_fp4", + op_func=_rocm_aiter_triton_qkv_a_proj_layernorm_fp4_impl, + 
mutates_args=[], + fake_impl=_rocm_aiter_triton_qkv_a_proj_layernorm_fp4_fake, + dispatch_key=current_platform.dispatch_key, + ) _OPS_REGISTERED = True @staticmethod @@ -1321,6 +1516,34 @@ def triton_rotary_embed( query = query.view(query_shape) key = key.view(key_shape) + @staticmethod + def triton_fp4_bmm( + X: torch.Tensor, + WQ: torch.Tensor, + w_scale: torch.Tensor, + dtype: torch.dtype | None = torch.bfloat16, + YQ: torch.Tensor | None = None, + transpose_bm: bool | None = False, + config: dict | None = None, + y_scale: dict | None = None, + ) -> torch.Tensor: + # ruff: noqa: E501 # isort: skip + from aiter.ops.triton.batched_gemm_a16wfp4 import ( + batched_gemm_a16wfp4 as aiter_triton_fp4_bmm, + ) + + return aiter_triton_fp4_bmm( + X, + WQ, + w_scale, + dtype=dtype, + y=YQ, + config=config, + transpose_bm=transpose_bm, + prequant=True, + y_scale=y_scale, + ) + @staticmethod def triton_fp8_bmm( X: torch.Tensor, @@ -1364,6 +1587,17 @@ def triton_gemm_a8w8_blockscale( return torch.ops.vllm.rocm_aiter_triton_gemm_a8w8_blockscale( A, B, As, Bs, output_dtype ) + + @staticmethod + def triton_gemm_a16w8_blockscale( + A: torch.Tensor, + B: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float16, + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_triton_gemm_a16w8_blockscale( + A, B, Bs, output_dtype + ) @staticmethod def group_fp8_quant( diff --git a/vllm/compilation/rocm_aiter_fusion.py b/vllm/compilation/rocm_aiter_fusion.py index cd14014c6828..8b22effd2889 100644 --- a/vllm/compilation/rocm_aiter_fusion.py +++ b/vllm/compilation/rocm_aiter_fusion.py @@ -39,6 +39,48 @@ SPLIT_WITH_SIZES_OP = torch.ops.aten.split_with_sizes.default +class AiterRMSFp8GroupQuantPattern: + """ + This pattern fuses aiter rms_norm & group fp8 quant custom + ops into an aiter rms_norm_group_fp8_quant op. 
+ """ + + def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload): + self.epsilon = epsilon + self.quant_dtype = quant_dtype + self.quant_op = quant_op + + def register(self, pm_pass: PatternMatcherPass): + def pattern( + input: torch.Tensor, + weight: torch.Tensor, + ): + at1 = AITER_RMS_OP(x=input, weight=weight, variance_epsilon=self.epsilon) + + at2 = self.quant_op(at1, 128) + + return at2[0], at2[1] + + def replacement( + input: torch.Tensor, + weight: torch.Tensor, + ): + at = AITER_RMS_GROUP_QUANT_OP( + x=input, + weight=weight, + variance_epsilon=self.epsilon, + group_size=128, + ) + + return at[0], at[1] + + inputs = [ + empty_bf16(5, 4), # input + empty_bf16(1, 5), # weight + ] + + pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) + class AiterRMSFp8GroupQuantPattern: """ This pattern fuses aiter rms_norm & group fp8 quant custom diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index 18f43126baeb..173562d1aa08 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -9,6 +9,7 @@ from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.platforms import current_platform +from vllm._aiter_ops import rocm_aiter_ops @dataclass class MLAModules: @@ -105,53 +106,92 @@ def __init__( indexer=self.indexer, rotary_emb = self.rotary_emb if current_platform.is_rocm() else None ) + self.use_aiter_triton = rocm_aiter_ops.is_enabled() + self.use_triton_qkv_a_proj_layernrom_fp8 = rocm_aiter_ops.is_enabled() and quant_config.get_name() == 'fp8' + self.use_triton_qkv_a_proj_layernrom_fp4 = rocm_aiter_ops.is_enabled() and quant_config.get_name() == 'quark' self.prefix = prefix def forward_native( self, positions: torch.Tensor, - hidden_states: torch.Tensor, + hidden_states: torch.Tensor | tuple[torch.Tensor, torch.Tensor], llama_4_scaling: torch.Tensor | None = None, ) -> 
torch.Tensor: q_c = None kv_lora = None - if self.q_lora_rank is not None: - assert self.fused_qkv_a_proj is not None, ( - "fused_qkv_a_proj is required when q_lora_rank is not None" - ) - assert self.q_a_layernorm is not None, ( - "q_a_layernorm is required when q_lora_rank is not None" - ) - assert self.q_b_proj is not None, ( - "q_b_proj is required when q_lora_rank is not None" - ) - qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] - q_c, kv_lora = qkv_lora.split( - [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], - dim=-1, - ) - q_c = self.q_a_layernorm(q_c) - q = self.q_b_proj(q_c)[0] + if self.use_triton_qkv_a_proj_layernrom_fp8: + assert self.q_lora_rank is not None + assert isinstance(hidden_states, tuple) + hidden_states, hidden_states_scales = hidden_states + q_c, q_c_scale, kv_c_normed, k_pe = torch.ops.vllm.rocm_aiter_triton_qkv_a_proj_layernorm_fp8( + hidden_states_quant=hidden_states, + hidden_states_quant_scale=hidden_states_scales, + weight_qkv_a_proj=self.fused_qkv_a_proj.weight, + weight_scale_qkv_a_proj=self.fused_qkv_a_proj.weight_scale, + q_a_layernorm_weight=self.q_a_layernorm.weight, + q_a_layernorm_variance_epsilon=self.q_a_layernorm.variance_epsilon, + kv_a_layernorm_weight=self.kv_a_layernorm.weight, + kv_a_layernorm_variance_epsilon=self.kv_a_layernorm.variance_epsilon, + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_rope_head_dim=self.qk_rope_head_dim) + q = torch.ops.vllm.rocm_aiter_triton_gemm_a8w8_blockscale(q_c, self.q_b_proj.weight, q_c_scale, self.q_b_proj.weight_scale, output_dtype=torch.bfloat16) + if self.use_triton_qkv_a_proj_layernrom_fp4: + assert self.q_lora_rank is not None + assert isinstance(hidden_states, tuple) + hidden_states, hidden_states_scales = hidden_states + q_c, q_c_scale, kv_c_normed, k_pe = torch.ops.vllm.rocm_aiter_triton_qkv_a_proj_layernorm_fp4( + hidden_states_quant=hidden_states, + hidden_states_quant_scale=hidden_states_scales, + 
weight_qkv_a_proj=self.fused_qkv_a_proj.weight, + weight_scale_qkv_a_proj=self.fused_qkv_a_proj.weight_scale, + q_a_layernorm_weight=self.q_a_layernorm.weight, + q_a_layernorm_variance_epsilon=self.q_a_layernorm.variance_epsilon, + kv_a_layernorm_weight=self.kv_a_layernorm.weight, + kv_a_layernorm_variance_epsilon=self.kv_a_layernorm.variance_epsilon, + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_rope_head_dim=self.qk_rope_head_dim) + q = torch.ops.vllm.rocm_aiter_triton_gemm_afp4wfp4(q_c, self.q_b_proj.weight, q_c_scale, self.q_b_proj.weight_scale, output_dtype=torch.bfloat16) else: - assert self.kv_a_proj_with_mqa is not None, ( - "kv_a_proj_with_mqa is required when q_lora_rank is None" - ) - assert self.q_proj is not None, ( - "q_proj is required when q_lora_rank is None" - ) - kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0] - q = self.q_proj(hidden_states)[0] - - kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - kv_c_normed = self.kv_a_layernorm(kv_c) + assert isinstance(hidden_states, torch.Tensor) + if self.q_lora_rank is not None: + assert self.fused_qkv_a_proj is not None, ( + "fused_qkv_a_proj is required when q_lora_rank is not None" + ) + assert self.q_a_layernorm is not None, ( + "q_a_layernorm is required when q_lora_rank is not None" + ) + assert self.q_b_proj is not None, ( + "q_b_proj is required when q_lora_rank is not None" + ) + qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] + q_c, kv_lora = qkv_lora.split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + dim=-1, + ) + q_c = self.q_a_layernorm(q_c) + q = self.q_b_proj(q_c)[0] + else: + assert self.kv_a_proj_with_mqa is not None, ( + "kv_a_proj_with_mqa is required when q_lora_rank is None" + ) + assert self.q_proj is not None, ( + "q_proj is required when q_lora_rank is None" + ) + kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0] + q = self.q_proj(hidden_states)[0] + + kv_c, k_pe = 
kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c) q = q.view(-1, self.num_heads, self.qk_head_dim) # Add head dim of 1 to k_pe k_pe = k_pe.unsqueeze(1) - if self.rotary_emb is not None and not current_platform.is_rocm(): + if self.rotary_emb is not None and not self.use_aiter_triton: q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb( positions, q[..., self.qk_nope_head_dim :], k_pe ) @@ -164,7 +204,7 @@ def forward_native( if llama_4_scaling is not None: q *= llama_4_scaling - positions_rocm = None if not current_platform.is_rocm() else positions + positions_rocm = None if not self.use_aiter_triton else positions attn_out = self.mla_attn( q, kv_c_normed, diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py index eeb60023dc0e..f9553dcff744 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py @@ -54,6 +54,7 @@ def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool: gemm_afp4wfp4, gemm_afp4wfp4_preshuffled_weight_scales, ) + from aiter.ops.triton.gemm_a16wfp4 import gemm_a16wfp4 from aiter.ops.triton.quant import dynamic_mxfp4_quant from vllm.utils.torch_utils import direct_register_custom_op @@ -123,6 +124,13 @@ def gemm_with_dynamic_quant( return y[:M] else: if x_scales is None: + if M <= 256 and weight.shape[0] == 7168 and x.shape[-1] == 2048: + y = torch.empty(M, + weight.shape[0], + device=x.device, + dtype=out_dtype) + gemm_a16wfp4(x, weight, weight_scale.T, dtype=out_dtype, y=y) + return y x_q, x_s = dynamic_mxfp4_quant(x) else: x_q = x diff --git a/vllm/model_executor/layers/quantization/quark/utils.py b/vllm/model_executor/layers/quantization/quark/utils.py index dc82f94ebbbf..ca20eccc262f 100644 --- a/vllm/model_executor/layers/quantization/quark/utils.py +++ 
b/vllm/model_executor/layers/quantization/quark/utils.py @@ -5,7 +5,10 @@ from types import MappingProxyType from typing import Any +import torch import regex as re +from torch import nn +from aiter.ops.triton.quant import dynamic_mxfp4_quant def deep_compare(dict1: Any, dict2: Any) -> bool: @@ -21,6 +24,102 @@ def deep_compare(dict1: Any, dict2: Any) -> bool: return dict1 == dict2 +# utility for tensor dims > 2 cases +def b_dynamic_mxfp4_quant(x): + h, b, d = x.shape + x, x_scales = dynamic_mxfp4_quant(x.reshape(-1, d)) + return x.view(h, b, d // 2), x_scales.view(h, b, d // 32) + #return x.view(h, b, d // 2), x_scales.view(h, b, d // 32) + + +def mxfp4_to_f32(x, is_threed): + # 2 because we pack fp4 in uint8. + x = x.repeat_interleave(2, dim=-1) + if is_threed: + x[..., ::2] = x[..., ::2] & 0xF + x[..., 1::2] = x[..., 1::2] >> 4 + else: + x[:, ::2] = x[:, ::2] & 0xF + x[:, 1::2] = x[:, 1::2] >> 4 + + mxfp4_list = [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ] + mxfp4_in_f32 = torch.tensor(mxfp4_list, dtype=torch.float32, device="cuda") + return mxfp4_in_f32[x.long()] + + +def e8m0_to_f32(x): + # Convert the input tensor `x` (assumed to be in e8m0 format) to float32. + # e8m0 is a custom 8-bit floating point format with 8 bits for exponent, 0 for mantissa. + # This means the value is essentially 2^(exponent - 127), similar to how IEEE-754 stores floats. + + # Convert x to float32 for computation, and compute the power of 2 by subtracting the bias (127). + x_f32 = 2 ** ((x.to(torch.float32)) - 127) + + # If the exponent value was 255 (i.e., 2^(128)), this is a special case usually used to represent NaN or Inf. + # Since this custom format has no mantissa, treat 2^128 as NaN. 
+ x_f32[x_f32 == 128] = float("nan") + return x_f32 + + +def quark_post_load_weights(self_attn: nn.Module, w: torch.Tensor, quant_format: str): + if "mxfp4" in quant_format: + # when dtype is bf16, the processing flow is to dynamic quantize bf16 tensor to uint8 tensor + # do w_kc (bf16) first to get the w_kc(uint8) w_s_kc(uint8) + # and w_vc repeating the same procedure of w_kc to get w_vc(uint8) w_s_vc(uint8) + if w.dtype == torch.bfloat16: + # w_kc, w_vc = w.split( + # [self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) + w_kc, w_vc = w.unflatten( + 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) + ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) + w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1)) + w_kc = w_kc.transpose(-2, -1) + w_s_kc = w_s_kc.transpose(-2, -1) + w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc) + w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2) + w_s_vc = w_s_vc.contiguous().transpose(1, 2) + elif w.dtype == torch.uint8: # static quant for mxfp4 + # when dtype is uint8, it means the w has been quantized to mxfp4 format + # but we must separate it to w_kc and w_vc. 
+ # The quantized tensor size is only half of original tensor size + # and the scaling factor is 1/32, the transpose behavior will be not correct + # need to upcast it to fp32 to separate w to w_kc and w_vc + # to ensure the following transpose behavior is correct + # and then do mxfp4 quant again + w = mxfp4_to_f32(w, True).to(torch.bfloat16) + w_scales = self_attn.kv_b_proj.weight_scale.repeat_interleave(32, dim=-1) + w_scales = e8m0_to_f32(w_scales).to(torch.bfloat16) + w = w * w_scales + w_kc, w_vc = w.unflatten( + 0, (-1, (self_attn.qk_nope_head_dim + self_attn.v_head_dim)) + ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) + w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1)) + w_kc = w_kc.transpose(-2, -1) + w_s_kc = w_s_kc.transpose(-2, -1) + w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc) + w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2) + w_s_vc = w_s_vc.contiguous().transpose(1, 2) + + return w_kc, w_s_kc, w_vc, w_s_vc + + def should_ignore_layer( layer_name: str | None, ignore: Iterable[str], diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 18126a6e3391..8c48de62d54f 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -357,7 +357,6 @@ def _run_aiter( not current_platform.is_fp8_fnuz() and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k) ) - if use_triton: gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale else: diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 8ef55c787e18..1f5e57e8e4ca 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -388,7 +388,6 @@ def __init__( n_shared_experts=config.n_shared_experts if self.is_fusion_moe_shared_experts_enabled else None, - skip_shared_experts = 
self.is_fusion_triton_shared_experts_enabled, ) def forward(self, hidden_states: tuple | torch.Tensor) -> torch.Tensor: @@ -1226,7 +1225,7 @@ def __init__( self.is_fusion_triton_shared_experts_enabled = rocm_aiter_ops.is_fusion_triton_shared_experts_enabled() self.use_triton_fused_rmsnorm_fp8_quant = False self.use_triton_fused_rmsnorm_fp4_quant = False - if self.is_fusion_triton_shared_experts_enabled: + if rocm_aiter_ops.is_enabled(): if quant_config.get_name() == 'fp8': self.use_triton_fused_rmsnorm_fp8_quant = True elif quant_config.get_name() == 'quark': @@ -1302,11 +1301,51 @@ def forward( llama_4_scaling: torch.Tensor | None = None, ) -> torch.Tensor: # Self Attention - if residual is None: - residual = hidden_states.clone() - hidden_states = self.input_layernorm(hidden_states) + if self.use_triton_fused_rmsnorm_fp8_quant: + weight = self.input_layernorm.weight + eps = self.input_layernorm.variance_epsilon + from vllm._aiter_ops import AITER_FP8_DTYPE + from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant + if residual is None: + residual = hidden_states + (hidden_states_quant, hidden_states_quant_scales), _, _, _ = fused_rms_fp8_group_quant(hidden_states, weight, eps, + None, None, eps, + group_size=128, + dtype_quant=AITER_FP8_DTYPE, + res1=None) + else: + (hidden_states_quant, hidden_states_quant_scales), _, _, residual = fused_rms_fp8_group_quant(hidden_states, weight, eps, + None, None, eps, + group_size=128, + dtype_quant=AITER_FP8_DTYPE, + res1=residual) + hidden_states = (hidden_states_quant, hidden_states_quant_scales) + elif self.use_triton_fused_rmsnorm_fp4_quant: + weight = self.input_layernorm.weight + eps = self.input_layernorm.variance_epsilon + from aiter.ops.triton.fused_mxfp4_quant import fused_rms_mxfp4_quant + if residual is None: + residual = hidden_states + (hidden_states_quant, hidden_states_quant_scales), _, _, _ = fused_rms_mxfp4_quant(hidden_states, weight, eps, + None, None, eps, + res1=None, + shuffle=False, + 
scale_shuffle_padding=False, + output_unquantized_inp1=False) + else: + (hidden_states_quant, hidden_states_quant_scales), _, _, residual = fused_rms_mxfp4_quant(hidden_states, weight, eps, + None, None, eps, + res1=residual, + shuffle=False, + scale_shuffle_padding=False, + output_unquantized_inp1=False) + hidden_states = (hidden_states_quant, hidden_states_quant_scales) else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) + if residual is None: + residual = hidden_states.clone() + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) attn_kwargs = { "positions": positions, diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 180625b6ce89..839486dc4275 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -1139,6 +1139,7 @@ def __init__( self.indexer = indexer self.q_pad_num_heads = q_pad_num_heads self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled() + self.is_aiter_triton_fp4_bmm_enabled = rocm_aiter_ops.is_fp4bmm_enabled() def process_weights_after_loading(self, act_dtype: torch.dtype): def get_layer_weight(layer): @@ -1167,7 +1168,23 @@ def get_and_maybe_dequant_weights(layer: LinearBase): # we currently do not have quantized bmm's which are needed for # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform # the bmm's in 16-bit, the extra memory overhead of this is fairly low - kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj) + + # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported + if self.is_aiter_triton_fp4_bmm_enabled: + from vllm.model_executor.layers.quantization.quark.utils import quark_post_load_weights + + self.W_K, self.W_K_scale, W_V, self.W_V_scale = ( + quark_post_load_weights(self, kv_b_proj_weight, "mxfp4")) 
+ self.W_V = W_V.contiguous().transpose(1, 2) + + self.W_K = self.W_K.transpose(-2, -1).contiguous() + self.W_K_scale = self.W_K_scale.transpose(-2, -1).contiguous() + self.W_V = self.W_V.transpose(-2, -1).contiguous() + self.W_V_scale = self.W_V_scale.transpose(-2, -1).contiguous() + return + + kv_b_proj_weight = kv_b_proj_weight.T assert kv_b_proj_weight.shape == ( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), @@ -1236,16 +1253,39 @@ def get_and_maybe_dequant_weights(layer: LinearBase): self.W_UK_T = W_UK.permute(1, 2, 0) def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): - # Convert from (B, N, L) to (N, B, L) - x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) - if self.is_aiter_triton_fp8_bmm_enabled: + if self.is_aiter_triton_fp4_bmm_enabled: + #print(f'>>> x pre (up_proj) {x.shape}') + out = out.view(-1, self.num_heads, self.v_head_dim) + x = x.view(-1, self.num_heads, self.kv_lora_rank) + x = x.transpose(0, 1) + + #print(f'>>> x {x.shape}, attn_bmm_output {attn_bmm_output.shape}, self.W_V {self.W_V.shape}') + out = rocm_aiter_ops.triton_fp4_bmm( + x, + self.W_V, + self.W_V_scale, + YQ=out, + transpose_bm=True, + y_scale=None, + ) + #print(f'>>> x before transpose {x.shape}') + out = out.view(-1, self.num_heads * self.v_head_dim) + # x = out + + elif self.is_aiter_triton_fp8_bmm_enabled: + # Convert from (B, N, L) to (N, B, L) + x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + out = out.view(-1, self.num_heads, self.v_head_dim) # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) x = rocm_aiter_ops.triton_fp8_bmm( x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out ) else: + # Convert from (B, N, L) to (N, B, L) + x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + # Convert from (B, N * V) to (N, B, V) out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1) @@ -1577,7 +1617,23 @@ def get_and_maybe_dequant_weights(layer: 
LinearBase): # we currently do not have quantized bmm's which are needed for # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform # the bmm's in 16-bit, the extra memory overhead of this is fairly low - kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj) + + # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported + if self.is_aiter_triton_fp4_bmm_enabled: + from vllm.model_executor.layers.quantization.quark.utils import quark_post_load_weights + + self.W_K, self.W_K_scale, W_V, self.W_V_scale = ( + quark_post_load_weights(self, kv_b_proj_weight, "mxfp4")) + self.W_V = W_V.contiguous().transpose(1, 2) + + self.W_K = self.W_K.transpose(-2, -1).contiguous() + self.W_K_scale = self.W_K_scale.transpose(-2, -1).contiguous() + self.W_V = self.W_V.transpose(-2, -1).contiguous() + self.W_V_scale = self.W_V_scale.transpose(-2, -1).contiguous() + return + + kv_b_proj_weight = kv_b_proj_weight.T assert kv_b_proj_weight.shape == ( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 9c688c3f88e8..c24ba2e4c5c1 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -445,15 +445,14 @@ def forward( if self.is_aiter_triton_fp4_bmm_enabled: #x = x.view(-1, self.num_heads, self.kv_lora_rank) decode_ql_nope = decode_q_cat[... 
, :self.W_K.shape[1]] if (kv_cache.numel() > 0 and positions is not None) else None - # decode_ql_nope = batched_gemm_a16wfp4( - # decode_q_nope, - # self.W_K, - # self.W_K_scale, - # y=decode_ql_nope, - # transpose_bm=True, - # prequant=True, - # y_scale=layer._q_scale if fp8_attention else None, - # ) + decode_ql_nope = rocm_aiter_ops.triton_fp4_bmm( + decode_q_nope, + self.W_K, + self.W_K_scale, + YQ=decode_ql_nope, + transpose_bm=True, + y_scale=layer._q_scale if fp8_attention else None, + ) # decode_ql_nope = decode_ql_nope.transpose(0, 1) elif self.is_aiter_triton_fp8_bmm_enabled: # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) From ed00f65c2cbeeb64f82da2c0285400111e8e88ad Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Mon, 8 Dec 2025 18:40:37 +0000 Subject: [PATCH 06/11] fix --- vllm/compilation/rocm_aiter_fusion.py | 43 ----------------------- vllm/model_executor/layers/mla.py | 12 +++---- vllm/model_executor/models/deepseek_v2.py | 2 +- 3 files changed, 7 insertions(+), 50 deletions(-) diff --git a/vllm/compilation/rocm_aiter_fusion.py b/vllm/compilation/rocm_aiter_fusion.py index 8b22effd2889..fd5bf8fb3de2 100644 --- a/vllm/compilation/rocm_aiter_fusion.py +++ b/vllm/compilation/rocm_aiter_fusion.py @@ -37,49 +37,6 @@ FUSED_SILU_MUL_QUANT_OP = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default SPLIT_WITH_SIZES_OP = torch.ops.aten.split_with_sizes.default - - -class AiterRMSFp8GroupQuantPattern: - """ - This pattern fuses aiter rms_norm & group fp8 quant custom - ops into an aiter rms_norm_group_fp8_quant op. 
- """ - - def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload): - self.epsilon = epsilon - self.quant_dtype = quant_dtype - self.quant_op = quant_op - - def register(self, pm_pass: PatternMatcherPass): - def pattern( - input: torch.Tensor, - weight: torch.Tensor, - ): - at1 = AITER_RMS_OP(x=input, weight=weight, variance_epsilon=self.epsilon) - - at2 = self.quant_op(at1, 128) - - return at2[0], at2[1] - - def replacement( - input: torch.Tensor, - weight: torch.Tensor, - ): - at = AITER_RMS_GROUP_QUANT_OP( - x=input, - weight=weight, - variance_epsilon=self.epsilon, - group_size=128, - ) - - return at[0], at[1] - - inputs = [ - empty_bf16(5, 4), # input - empty_bf16(1, 5), # weight - ] - - pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) class AiterRMSFp8GroupQuantPattern: """ diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index 173562d1aa08..8998a249fa43 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -121,11 +121,11 @@ def forward_native( q_c = None kv_lora = None - if self.use_triton_qkv_a_proj_layernrom_fp8: + if self.use_triton_qkv_a_proj_layernrom_fp4: assert self.q_lora_rank is not None assert isinstance(hidden_states, tuple) hidden_states, hidden_states_scales = hidden_states - q_c, q_c_scale, kv_c_normed, k_pe = torch.ops.vllm.rocm_aiter_triton_qkv_a_proj_layernorm_fp8( + q_c, q_c_scale, kv_c_normed, k_pe = torch.ops.vllm.rocm_aiter_triton_qkv_a_proj_layernorm_fp4( hidden_states_quant=hidden_states, hidden_states_quant_scale=hidden_states_scales, weight_qkv_a_proj=self.fused_qkv_a_proj.weight, @@ -137,12 +137,12 @@ def forward_native( q_lora_rank=self.q_lora_rank, kv_lora_rank=self.kv_lora_rank, qk_rope_head_dim=self.qk_rope_head_dim) - q = torch.ops.vllm.rocm_aiter_triton_gemm_a8w8_blockscale(q_c, self.q_b_proj.weight, q_c_scale, self.q_b_proj.weight_scale, output_dtype=torch.bfloat16) - if 
self.use_triton_qkv_a_proj_layernrom_fp4: + q = torch.ops.vllm.rocm_aiter_triton_gemm_afp4wfp4(q_c, self.q_b_proj.weight, q_c_scale, self.q_b_proj.weight_scale, output_dtype=torch.bfloat16) + elif self.use_triton_qkv_a_proj_layernrom_fp8: assert self.q_lora_rank is not None assert isinstance(hidden_states, tuple) hidden_states, hidden_states_scales = hidden_states - q_c, q_c_scale, kv_c_normed, k_pe = torch.ops.vllm.rocm_aiter_triton_qkv_a_proj_layernorm_fp4( + q_c, q_c_scale, kv_c_normed, k_pe = torch.ops.vllm.rocm_aiter_triton_qkv_a_proj_layernorm_fp8( hidden_states_quant=hidden_states, hidden_states_quant_scale=hidden_states_scales, weight_qkv_a_proj=self.fused_qkv_a_proj.weight, @@ -154,7 +154,7 @@ def forward_native( q_lora_rank=self.q_lora_rank, kv_lora_rank=self.kv_lora_rank, qk_rope_head_dim=self.qk_rope_head_dim) - q = torch.ops.vllm.rocm_aiter_triton_gemm_afp4wfp4(q_c, self.q_b_proj.weight, q_c_scale, self.q_b_proj.weight_scale, output_dtype=torch.bfloat16) + q = torch.ops.vllm.rocm_aiter_triton_gemm_a8w8_blockscale(q_c, self.q_b_proj.weight, q_c_scale, self.q_b_proj.weight_scale, output_dtype=torch.bfloat16) else: assert isinstance(hidden_states, torch.Tensor) if self.q_lora_rank is not None: diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 1f5e57e8e4ca..77e28a3ea5b8 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -1231,7 +1231,7 @@ def __init__( elif quant_config.get_name() == 'quark': self.use_triton_fused_rmsnorm_fp4_quant = True else: - raise NotImplementedError(f"{quant_config.get_name()=} which is not supported for VLLM_ROCM_USE_AITER_TRITON_FUSION_SHARED_EXPERTS") + raise NotImplementedError(f"{quant_config.get_name()=} which is not supported with the current version of AITER") logger.info(f"[Aiter] {self.__class__.__name__} has {quant_config.get_name()=}") # verify MLA attention specific fields From 
03f3472143a27f3d89cb4ff3cb0a874b9bf1d8e3 Mon Sep 17 00:00:00 2001
From: ShaoChunLee
Date: Mon, 8 Dec 2025 19:29:27 +0000
Subject: [PATCH 07/11] deliberately disable fp4 fusions

---
 vllm/model_executor/layers/mla.py         | 1 +
 vllm/model_executor/models/deepseek_v2.py | 3 +++
 2 files changed, 4 insertions(+)

diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py
index 8998a249fa43..e9e79a21147f 100644
--- a/vllm/model_executor/layers/mla.py
+++ b/vllm/model_executor/layers/mla.py
@@ -109,6 +109,7 @@ def __init__(
         self.use_aiter_triton = rocm_aiter_ops.is_enabled()
         self.use_triton_qkv_a_proj_layernrom_fp8 = rocm_aiter_ops.is_enabled() and quant_config.get_name() == 'fp8'
         self.use_triton_qkv_a_proj_layernrom_fp4 = rocm_aiter_ops.is_enabled() and quant_config.get_name() == 'quark'
+        self.use_triton_qkv_a_proj_layernrom_fp4 = False
 
         self.prefix = prefix
 
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 77e28a3ea5b8..537d9b7777ac 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -310,10 +310,12 @@ def __init__(
                 self.use_triton_fused_shared_expert_fp4 = True
                 self.rocm_aiter_triton_fused_shared_expert_func = torch.ops.vllm.rocm_aiter_triton_fused_shared_expert_fp4
                 self.rocm_aiter_triton_fused_down_proj_mul_add_func = torch.ops.vllm.rocm_aiter_triton_fused_down_proj_mul_add_fp4
+                self.is_fusion_triton_shared_experts_enabled = False
             else:
                 raise NotImplementedError(f"{quant_config.get_name()=} which is not supported for VLLM_ROCM_USE_AITER_TRITON_FUSION_SHARED_EXPERTS")
             logger.info(f"[Aiter] {self.__class__.__name__} is registered with {self.rocm_aiter_triton_fused_shared_expert_func.__name__} and {self.rocm_aiter_triton_fused_down_proj_mul_add_func.__name__}")
 
+        if self.is_fusion_triton_shared_experts_enabled:
         self.experts = FusedMoE(
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
@@ -1230,6 +1232,7 @@ def __init__(
self.use_triton_fused_rmsnorm_fp8_quant = True elif quant_config.get_name() == 'quark': self.use_triton_fused_rmsnorm_fp4_quant = True + self.use_triton_fused_rmsnorm_fp4_quant = False else: raise NotImplementedError(f"{quant_config.get_name()=} which is not supported with the current version of AITER") logger.info(f"[Aiter] {self.__class__.__name__} has {quant_config.get_name()=}") From a0389b633eb69cb0fda0b1547481cbfc3909cc50 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Mon, 8 Dec 2025 19:31:50 +0000 Subject: [PATCH 08/11] add gemm split cat in prefill --- vllm/v1/attention/backends/mla/common.py | 186 +++++++++++++++++++---- 1 file changed, 157 insertions(+), 29 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 839486dc4275..eebd9d9a5b93 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -204,6 +204,8 @@ AttentionLayer, MLAAttentionImpl, ) +from vllm.model_executor.layers.quantization.quark.quark import QuarkLinearMethod +from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod from vllm.attention.backends.utils import get_mla_dims from vllm.attention.ops.common import cp_lse_ag_out_rs from vllm.attention.ops.merge_attn_states import merge_attn_states @@ -1140,6 +1142,7 @@ def __init__( self.q_pad_num_heads = q_pad_num_heads self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled() self.is_aiter_triton_fp4_bmm_enabled = rocm_aiter_ops.is_fp4bmm_enabled() + self.is_aiter_enabled = rocm_aiter_ops.is_enabled() def process_weights_after_loading(self, act_dtype: torch.dtype): def get_layer_weight(layer): @@ -1732,20 +1735,61 @@ def _compute_prefill_context( kv_c_normed = workspace[:toks][..., : self.kv_lora_rank] k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1) - kv_nope = self.kv_b_proj(kv_c_normed)[0].view( - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim - ) - k_nope, v = 
kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + if self.is_aiter_enabled and isinstance(self.kv_b_proj.quant_method, QuarkLinearMethod) and False: + from aiter.ops.triton.fused_gemm_afp4wfp4_split_cat import fused_gemm_afp4wfp4_split_cat + from aiter.ops.triton.quant import dynamic_mxfp4_quant + input = kv_c_normed + weight = self.kv_b_proj.weight + weight_scale = self.kv_b_proj.weight_scale - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + input_2d = input.view(-1, input.shape[-1]) + output_dtype = input.dtype - attn_output, attn_softmax_lse = self._run_prefill_context_chunk( - prefill=prefill_metadata, - chunk_idx=i, - q=q, - k=k, - v=v, - ) + q_input, x_scale = dynamic_mxfp4_quant(input_2d) + + k, v = fused_gemm_afp4wfp4_split_cat( + q_input, weight, k_pe.expand((-1, self.num_heads, -1)), x_scale, weight_scale.T, self.qk_nope_head_dim, self.v_head_dim, output_dtype + ) + elif self.is_aiter_enabled and isinstance(self.kv_b_proj.quant_method, Fp8LinearMethod): + from aiter.ops.triton.fused_gemm_a8w8_blockscale_split_cat import fused_gemm_a8w8_blockscale_split_cat + import aiter as rocm_aiter + from aiter import get_hip_quant + aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128) + from vllm.model_executor.layers.quantization.utils.fp8_utils import per_token_group_quant_fp8 + + input = kv_c_normed + weight = self.kv_b_proj.weight + block_size = self.kv_b_proj.quant_method.quant_config.weight_block_size + weight_scale = self.kv_b_proj.weight_scale + + input_2d = input.view(-1, input.shape[-1]) + output_dtype = input.dtype + + if current_platform.is_fp8_fnuz(): + q_input, x_scale = aiter_per1x128_quant( + input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) + else: + q_input, x_scale = per_token_group_quant_fp8( + input_2d, block_size[1], column_major_scales=False) + + k, v = fused_gemm_a8w8_blockscale_split_cat( + q_input, weight, k_pe.expand((-1, self.num_heads, -1)), x_scale, weight_scale, 
self.qk_nope_head_dim, self.v_head_dim, output_dtype + ) + else: + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + attn_output, attn_softmax_lse = self._run_prefill_context_chunk( + prefill=prefill_metadata, + chunk_idx=i, + q=q, + k=k, + v=v, + ) if output is None: output = attn_output @@ -1837,19 +1881,60 @@ def _context_parallel_compute_prefill_context( toks=toks, ) - kv_nope = self.kv_b_proj(kv_c_normed)[0].view( - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim - ) - k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + if self.is_aiter_enabled and isinstance(self.kv_b_proj.quant_method, QuarkLinearMethod) and False: + from aiter.ops.triton.fused_gemm_afp4wfp4_split_cat import fused_gemm_afp4wfp4_split_cat + from aiter.ops.triton.quant import dynamic_mxfp4_quant + input = kv_c_normed + weight = self.kv_b_proj.weight + weight_scale = self.kv_b_proj.weight_scale - attn_output, attn_softmax_lse = self._run_prefill_context_chunk( - prefill=prefill_metadata, - chunk_idx=i, - q=q, - k=k, - v=v, - ) + input_2d = input.view(-1, input.shape[-1]) + output_dtype = input.dtype + + q_input, x_scale = dynamic_mxfp4_quant(input_2d) + + k, v = fused_gemm_afp4wfp4_split_cat( + q_input, weight, k_pe.expand((-1, self.num_heads, -1)), x_scale, weight_scale.T, self.qk_nope_head_dim, self.v_head_dim, output_dtype + ) + elif self.is_aiter_enabled and isinstance(self.kv_b_proj.quant_method, Fp8LinearMethod): + from aiter.ops.triton.fused_gemm_a8w8_blockscale_split_cat import fused_gemm_a8w8_blockscale_split_cat + import aiter as rocm_aiter + from aiter import get_hip_quant + aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128) + from 
vllm.model_executor.layers.quantization.utils.fp8_utils import per_token_group_quant_fp8 + + input = kv_c_normed + weight = self.kv_b_proj.weight + block_size = self.kv_b_proj.quant_method.quant_config.weight_block_size + weight_scale = self.kv_b_proj.weight_scale + + input_2d = input.view(-1, input.shape[-1]) + output_dtype = input.dtype + + if current_platform.is_fp8_fnuz(): + q_input, x_scale = aiter_per1x128_quant( + input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) + else: + q_input, x_scale = per_token_group_quant_fp8( + input_2d, block_size[1], column_major_scales=False) + + k, v = fused_gemm_a8w8_blockscale_split_cat( + q_input, weight, k_pe.expand((-1, self.num_heads, -1)), x_scale, weight_scale, self.qk_nope_head_dim, self.v_head_dim, output_dtype + ) + else: + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + attn_output, attn_softmax_lse = self._run_prefill_context_chunk( + prefill=prefill_metadata, + chunk_idx=i, + q=q, + k=k, + v=v, + ) if output is None: output = attn_output @@ -1885,12 +1970,55 @@ def _forward_prefill( assert self.dcp_world_size is not None has_context = attn_metadata.prefill.chunked_context is not None - kv_nope = self.kv_b_proj(kv_c_normed)[0].view( - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim - ) - k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + if self.is_aiter_enabled and isinstance(self.kv_b_proj.quant_method, QuarkLinearMethod) and False: + from aiter.ops.triton.fused_gemm_afp4wfp4_split_cat import fused_gemm_afp4wfp4_split_cat + from aiter.ops.triton.quant import dynamic_mxfp4_quant + input = kv_c_normed + weight = self.kv_b_proj.weight + weight_scale = self.kv_b_proj.weight_scale + 
+ input_2d = input.view(-1, input.shape[-1]) + output_dtype = input.dtype + + q_input, x_scale = dynamic_mxfp4_quant(input_2d) + + k, v = fused_gemm_afp4wfp4_split_cat( + q_input, weight, k_pe.expand((-1, self.num_heads, -1)), x_scale, weight_scale.T, self.qk_nope_head_dim, self.v_head_dim, output_dtype + ) + elif self.is_aiter_enabled and isinstance(self.kv_b_proj.quant_method, Fp8LinearMethod): + from aiter.ops.triton.fused_gemm_a8w8_blockscale_split_cat import fused_gemm_a8w8_blockscale_split_cat + import aiter as rocm_aiter + from aiter import get_hip_quant + aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128) + from vllm.model_executor.layers.quantization.utils.fp8_utils import per_token_group_quant_fp8 + + input = kv_c_normed + weight = self.kv_b_proj.weight + block_size = self.kv_b_proj.quant_method.quant_config.weight_block_size + weight_scale = self.kv_b_proj.weight_scale + + input_2d = input.view(-1, input.shape[-1]) + output_dtype = input.dtype + + if current_platform.is_fp8_fnuz(): + q_input, x_scale = aiter_per1x128_quant( + input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) + else: + q_input, x_scale = per_token_group_quant_fp8( + input_2d, block_size[1], column_major_scales=False) + + k, v = fused_gemm_a8w8_blockscale_split_cat( + q_input, weight, k_pe.expand((-1, self.num_heads, -1)), x_scale, weight_scale, self.qk_nope_head_dim, self.v_head_dim, output_dtype + ) + else: + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) output_prefill = self._run_prefill_new_tokens( prefill=attn_metadata.prefill, From 0edba99b16a3f78a0d40be5ce393f0887539248b Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Wed, 10 Dec 2025 17:09:47 +0000 Subject: [PATCH 09/11] add fused_rope_kv_cahce to Llama --- vllm/attention/layer.py | 23 
+++++- vllm/model_executor/models/llama.py | 18 +++-- vllm/v1/attention/backends/rocm_aiter_fa.py | 87 +++++++++++++++------ vllm/v1/attention/backends/rocm_attn.py | 69 +++++++++++++--- 4 files changed, 156 insertions(+), 41 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index bfffc7dda7f0..edac42d4da82 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -185,6 +185,7 @@ def __init__( attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: str | None = None, attn_backend: type[AttentionBackend] | None = None, + rotary_emb: nn.Module | None = None, **extra_impl_args, ) -> None: """ @@ -260,6 +261,7 @@ def __init__( kv_sharing_target_layer_name, **extra_impl_args, ) + self.impl.rotary_emb = rotary_emb backend_name = self.attn_backend.get_name() self.backend = AttentionBackendEnum.__members__.get(backend_name) self.dtype = dtype @@ -316,6 +318,7 @@ def forward( # shape does not match the query shape, so we optionally let the model # definition specify the output tensor shape. 
output_shape: torch.Size | None = None, + positions: torch.Tensor | None = None, ) -> torch.Tensor: """ The KV cache is stored inside this class and is accessed via @@ -365,7 +368,7 @@ def forward( ) else: torch.ops.vllm.unified_attention_with_output( - query, key, value, output, self.layer_name + query, key, value, output, self.layer_name, positions=positions ) return output.view(-1, hidden_size) else: @@ -868,8 +871,25 @@ def unified_attention_with_output( layer_name: str, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ) -> None: attn_metadata, self, kv_cache = get_attention_context(layer_name) + if positions is not None: + assert hasattr(self.impl, "rotary_emb") and self.impl.rotary_emb is not None + self.impl.forward( + self, + query, + key, + value, + kv_cache, + attn_metadata, + output=output, + output_scale=output_scale, + output_block_scale=output_block_scale, + positions=positions, + ) + return + self.impl.forward( self, query, @@ -891,6 +911,7 @@ def unified_attention_with_output_fake( layer_name: str, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ) -> None: return diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 167dfbca248c..843c40b7500a 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -67,7 +67,8 @@ make_layers, maybe_prefix, ) - +from vllm.platforms import current_platform +from vllm._aiter_ops import rocm_aiter_ops class LlamaMLP(nn.Module): def __init__( @@ -219,6 +220,7 @@ def __init__( per_layer_sliding_window=sliding_window, attn_type=attn_type, prefix=f"{prefix}.attn", + rotary_emb = self.rotary_emb if current_platform.is_rocm() and rocm_aiter_ops.is_enabled() else None ) def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor: @@ -239,11 +241,15 @@ def forward( ) -> 
torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - if self.do_llama_4_scaling: - attn_scale = self._get_llama_4_attn_scale(positions) - q = (q * attn_scale).to(q.dtype) - attn_output = self.attn(q, k, v) + if current_platform.is_rocm() and rocm_aiter_ops.is_enabled(): + assert not self.do_llama_4_scaling + attn_output = self.attn(q, k, v, positions=positions) + else: + q, k = self.rotary_emb(positions, q, k) + if self.do_llama_4_scaling: + attn_scale = self._get_llama_4_attn_scale(positions) + q = (q * attn_scale).to(q.dtype) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index b6aa0ae2be48..acbefc3410e3 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -41,6 +41,10 @@ def block_size(x, head_dim): def num_programs(total_tokens): return min(total_tokens, get_cu_count()) + + from vllm._aiter_ops import rocm_aiter_ops + if rocm_aiter_ops.is_enabled(): + from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache @triton.jit def cp_mha_gather_cache_kernel( @@ -782,6 +786,7 @@ def forward( output: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass with AiterFlashAttention. @@ -823,29 +828,67 @@ def forward( # key and value may be None in the case of cross attention. They are # calculated once based on the output from the encoder and then cached # in KV cache. - if ( - self.kv_sharing_target_layer_name is None - and key is not None - and value is not None - ): - # Reshape the input keys and values and store them in the cache. - # Skip this if sharing KV cache with an earlier attention layer. 
- # NOTE(woosuk): Here, key and value are padded while slot_mapping - # is not padded. However, we don't need to do - # key[:num_actual_tokens] and value[:num_actual_tokens] because - # the reshape_and_cache_flash op uses the slot_mapping's shape - # to determine the number of actual tokens. - - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, + if positions is not None and query.shape[0] <= 256 and rocm_aiter_ops.is_enabled(): + assert self.kv_sharing_target_layer_name is None + cos_sin_cache = self.rotary_emb.cos_sin_cache + is_neox = self.rotary_emb.is_neox_style + cos, sin = cos_sin_cache.chunk(2, dim=-1) + is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8") + if is_fp8_kv_cache: + key_cache = key_cache.view(current_platform.fp8_dtype()) + value_cache = value_cache.view(current_platform.fp8_dtype()) + query, key, key_cache, value_cache, output = ( + fused_qk_rope_reshape_and_cache( + query, + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + positions, + cos, + sin, + layer._k_scale, + layer._v_scale, + is_neox, + flash_layout=True, + apply_scale=is_fp8_kv_cache, + offs=None, + q_out=query, + k_out=key, + output_zeros=True, + zeros_out=output, + ) ) + else: + if positions is not None: + if current_platform.is_rocm(): + query, key = self.rotary_emb.forward_cuda(positions, query, key) + else: + query, key = self.rotary_emb(positions, query, key) + if ( + self.kv_sharing_target_layer_name is None + and key is not None + and value is not None + ): + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. + # NOTE(woosuk): Here, key and value are padded while slot_mapping + # is not padded. 
However, we don't need to do + # key[:num_actual_tokens] and value[:num_actual_tokens] because + # the reshape_and_cache_flash op uses the slot_mapping's shape + # to determine the number of actual tokens. + + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) if self.kv_cache_dtype.startswith("fp8"): key_cache = key_cache.view(current_platform.fp8_dtype()) diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 868143cc192e..6611a49e6caa 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -32,6 +32,12 @@ logger = init_logger(__name__) +if current_platform.is_rocm(): + + from vllm._aiter_ops import rocm_aiter_ops + if rocm_aiter_ops.is_enabled(): + from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache + @dataclass class RocmAttentionMetadata: # NOTE(sang): Definition of context_len, query_len, and seq_len. @@ -264,6 +270,7 @@ def forward( output: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -306,19 +313,57 @@ def forward( kv_cache, self.num_kv_heads, self.head_size ) - if self.kv_sharing_target_layer_name is None: - # Reshape the input keys and values and store them in the cache. - # Skip this if sharing KV cache with an earlier attention layer. 
- PagedAttention.write_to_paged_cache( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, + if positions is not None and query.shape[0] <= 256 and rocm_aiter_ops.is_enabled(): + assert self.kv_sharing_target_layer_name is None + cos_sin_cache = self.rotary_emb.cos_sin_cache + is_neox = self.rotary_emb.is_neox_style + cos, sin = cos_sin_cache.chunk(2, dim=-1) + is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8") + if is_fp8_kv_cache: + key_cache = key_cache.view(self.fp8_dtype) + value_cache = value_cache.view(self.fp8_dtype) + query, key, key_cache, value_cache, output = ( + fused_qk_rope_reshape_and_cache( + query, + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + positions, + cos, + sin, + layer._k_scale, + layer._v_scale, + is_neox, + flash_layout=False, + apply_scale=is_fp8_kv_cache, + offs=None, + q_out=query, + k_out=key, + output_zeros=True, + zeros_out=output, + ) ) + else: + if positions is not None: + if current_platform.is_rocm(): + query, key = self.rotary_emb.forward_cuda(positions, query, key) + else: + query, key = self.rotary_emb(positions, query, key) + if self.kv_sharing_target_layer_name is None: + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. 
+ PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) if self.kv_cache_dtype.startswith("fp8"): key_cache = key_cache.view(self.fp8_dtype) From 85d44b14a05319979560497885cc250ee99566ea Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Fri, 12 Dec 2025 17:04:10 +0000 Subject: [PATCH 10/11] fp4 proj gemm enablement --- vllm/model_executor/layers/mla.py | 1 - .../layers/quantization/quark/quark.py | 24 +++++++-- .../quantization/quark/schemes/__init__.py | 4 +- .../quark/schemes/quark_ocp_mx.py | 54 +++++++++++++++++++ vllm/model_executor/models/deepseek_v2.py | 5 -- vllm/v1/attention/backends/mla/common.py | 34 ++++++------ 6 files changed, 94 insertions(+), 28 deletions(-) diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index e9e79a21147f..8998a249fa43 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -109,7 +109,6 @@ def __init__( self.use_aiter_triton = rocm_aiter_ops.is_enabled() self.use_triton_qkv_a_proj_layernrom_fp8 = rocm_aiter_ops.is_enabled() and quant_config.get_name() == 'fp8' self.use_triton_qkv_a_proj_layernrom_fp4 = rocm_aiter_ops.is_enabled() and quant_config.get_name() == 'quark' - self.use_triton_qkv_a_proj_layernrom_fp4 = False self.prefix = prefix diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index 3640e5c45278..6da060315bdf 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -25,6 +25,7 @@ ) from vllm.model_executor.layers.quantization.quark.schemes import ( QuarkOCP_MX, + QuarkW16OCP_MX, QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8, @@ -108,7 +109,16 @@ def get_quant_method( if should_ignore_layer( prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping ): - return 
UnquantizedLinearMethod() + if current_platform.is_rocm(): + if prefix == "lm_head": + return UnquantizedLinearMethod() + + scheme = self.get_scheme(layer=layer, layer_name=prefix) + layer.scheme = scheme + return QuarkLinearMethod(self) + else: + return UnquantizedLinearMethod() + if isinstance(layer, LinearBase): scheme = self.get_scheme(layer=layer, layer_name=prefix) layer.scheme = scheme @@ -376,7 +386,7 @@ def _matches_pattern(layer_name, pattern): ) return global_quant_config - def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme": + def _get_scheme_from_config(self, config: dict[str, Any], layer_name: str) -> "QuarkScheme": if config.get("output_tensors") or config.get("bias"): raise NotImplementedError( "Currently, Quark models with output_tensors " @@ -399,6 +409,14 @@ def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme": input_symmetric=input_config.get("symmetric"), ) elif self._is_ocp_mx(weight_config, input_config): + if current_platform.is_rocm(): + exclude_layers = cast(list[str], self.quant_config.get("exclude")) + if should_ignore_layer( + layer_name, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping + ): + return QuarkW16OCP_MX(weight_config, input_config) + return QuarkOCP_MX(weight_config, input_config) + return QuarkOCP_MX(weight_config, input_config) raise NotImplementedError( @@ -411,7 +429,7 @@ def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme": layer_quant_config = self._find_matched_config(layer_name, layer) # Find the quant_scheme - scheme = self._get_scheme_from_config(layer_quant_config) + scheme = self._get_scheme_from_config(layer_quant_config, layer_name) # Raise error if device does not support the scheme # (e.g. 
fp8 needs ada lovelace) self._check_scheme_supported(scheme.get_min_capability()) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/__init__.py b/vllm/model_executor/layers/quantization/quark/schemes/__init__.py index 7620d6e41b58..7043467a1ec4 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/__init__.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from .quark_ocp_mx import QuarkOCP_MX +from .quark_ocp_mx import QuarkOCP_MX, QuarkW16OCP_MX from .quark_scheme import QuarkScheme from .quark_w8a8_fp8 import QuarkW8A8Fp8 from .quark_w8a8_int8 import QuarkW8A8Int8 -__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkOCP_MX"] +__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkOCP_MX", "QuarkW16OCP_MX"] diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py index f9553dcff744..f76d8b6a3c99 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py @@ -26,6 +26,8 @@ ) from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter from vllm.platforms import current_platform +from vllm.model_executor.parameter import ModelWeightParameter +from vllm.model_executor.utils import set_weight_attrs from .quark_scheme import QuarkScheme @@ -349,3 +351,55 @@ def apply_weights( self.rocm_use_aiter_fp4_asm_gemm, self.out_dtype, ) + + +class QuarkW16OCP_MX(QuarkOCP_MX): + def __init__(self, weight_quant_spec: dict[str, Any], + input_quant_spec: dict[str, Any]): + self.out_dtype = torch.get_default_dtype() + self.qscheme = "per_group" + self.weight_quant_spec = weight_quant_spec + self.input_quant_spec = input_quant_spec + self.emulate = not 
current_platform.supports_mx() + self.rocm_use_aiter_fp4_asm_gemm = is_rocm_aiter_fp4_asm_gemm_enabled() + if not self.emulate and (dynamic_mxfp4_quant is None + or gemm_afp4wfp4 is None): + # Currently need these kernels if not emulating + raise NotImplementedError( + f"{self.__class__.__name__} requires AITER to be installed " + "for non-emulation mode! Please refer to " + "https://github.com/ROCm/aiter for installation details.") + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + # This method creates unquantized linear weights. + # The weights are not quantized, and they are not sharded. + # The amount of memory allocated for the weights is + # sum(output_partition_sizes) * input_size_per_partition. + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + # set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0, "weight_loader": weight_loader}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, kwargs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.weight = torch.nn.Parameter(layer.weight.data, + requires_grad=False) + w_q, w_s = dynamic_mxfp4_quant(layer.weight) + layer.weight_scale = torch.nn.Parameter( + w_s.T.contiguous(), + requires_grad=False) + layer.weight = torch.nn.Parameter(w_q, + requires_grad=False) + #super().process_weights_after_loading(layer) \ No newline at end of file diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 537d9b7777ac..19717dacb776 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -310,7 +310,6 @@ def __init__( self.use_triton_fused_shared_expert_fp4 = True 
self.rocm_aiter_triton_fused_shared_expert_func = torch.ops.vllm.rocm_aiter_triton_fused_shared_expert_fp4 self.rocm_aiter_triton_fused_down_proj_mul_add_func = torch.ops.vllm.rocm_aiter_triton_fused_down_proj_mul_add_fp4 - self.is_fusion_triton_shared_experts_enabled = False else: raise NotImplementedError(f"{quant_config.get_name()=} which is not supported for VLLM_ROCM_USE_AITER_TRITON_FUSION_SHARED_EXPERTS") logger.info(f"[Aiter] {self.__class__.__name__} is registered with {self.rocm_aiter_triton_fused_shared_expert_func.__name__} and {self.rocm_aiter_triton_fused_down_proj_mul_add_func.__name__}") @@ -458,9 +457,6 @@ def forward(self, hidden_states: tuple | torch.Tensor) -> torch.Tensor: if self.is_fusion_triton_shared_experts_enabled and hidden_states.dtype != torch.float16: assert shared_output is None final_hidden_states = self.rocm_aiter_triton_fused_down_proj_mul_add_func(shared_output_q, shared_output_s, self.shared_experts.down_proj.weight, self.shared_experts.down_proj.weight_scale, self.routed_scaling_factor, final_hidden_states) - # assert shared_output is not None - # final_hidden_states *= self.routed_scaling_factor - # final_hidden_states += shared_output else: if hidden_states.dtype != torch.float16: if not self.is_rocm_aiter_moe_enabled: @@ -1232,7 +1228,6 @@ def __init__( self.use_triton_fused_rmsnorm_fp8_quant = True elif quant_config.get_name() == 'quark': self.use_triton_fused_rmsnorm_fp4_quant = True - self.use_triton_fused_rmsnorm_fp4_quant = False else: raise NotImplementedError(f"{quant_config.get_name()=} which is not supported with the current version of AITER") logger.info(f"[Aiter] {self.__class__.__name__} has {quant_config.get_name()=}") diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index eebd9d9a5b93..0a8f4229a2bb 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -1735,7 +1735,7 @@ def _compute_prefill_context( kv_c_normed = 
workspace[:toks][..., : self.kv_lora_rank] k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1) - if self.is_aiter_enabled and isinstance(self.kv_b_proj.quant_method, QuarkLinearMethod) and False: + if self.is_aiter_enabled and isinstance(self.kv_b_proj.quant_method, QuarkLinearMethod): from aiter.ops.triton.fused_gemm_afp4wfp4_split_cat import fused_gemm_afp4wfp4_split_cat from aiter.ops.triton.quant import dynamic_mxfp4_quant input = kv_c_normed @@ -1783,13 +1783,13 @@ def _compute_prefill_context( k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - attn_output, attn_softmax_lse = self._run_prefill_context_chunk( - prefill=prefill_metadata, - chunk_idx=i, - q=q, - k=k, - v=v, - ) + attn_output, attn_softmax_lse = self._run_prefill_context_chunk( + prefill=prefill_metadata, + chunk_idx=i, + q=q, + k=k, + v=v, + ) if output is None: output = attn_output @@ -1881,7 +1881,7 @@ def _context_parallel_compute_prefill_context( toks=toks, ) - if self.is_aiter_enabled and isinstance(self.kv_b_proj.quant_method, QuarkLinearMethod) and False: + if self.is_aiter_enabled and isinstance(self.kv_b_proj.quant_method, QuarkLinearMethod): from aiter.ops.triton.fused_gemm_afp4wfp4_split_cat import fused_gemm_afp4wfp4_split_cat from aiter.ops.triton.quant import dynamic_mxfp4_quant input = kv_c_normed @@ -1928,13 +1928,13 @@ def _context_parallel_compute_prefill_context( k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - attn_output, attn_softmax_lse = self._run_prefill_context_chunk( - prefill=prefill_metadata, - chunk_idx=i, - q=q, - k=k, - v=v, - ) + attn_output, attn_softmax_lse = self._run_prefill_context_chunk( + prefill=prefill_metadata, + chunk_idx=i, + q=q, + k=k, + v=v, + ) if output is None: output = attn_output @@ -1972,7 +1972,7 @@ def _forward_prefill( has_context = attn_metadata.prefill.chunked_context is not None - if self.is_aiter_enabled 
and isinstance(self.kv_b_proj.quant_method, QuarkLinearMethod) and False: + if self.is_aiter_enabled and isinstance(self.kv_b_proj.quant_method, QuarkLinearMethod): from aiter.ops.triton.fused_gemm_afp4wfp4_split_cat import fused_gemm_afp4wfp4_split_cat from aiter.ops.triton.quant import dynamic_mxfp4_quant input = kv_c_normed From 07759f58e2f72e0b90fc96d5154b3187f59595c9 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Fri, 9 Jan 2026 14:57:21 +0000 Subject: [PATCH 11/11] update import and fused_qk_rope_cat_and_cache_mla interface --- vllm/_aiter_ops.py | 40 +++++++++---------- .../attention/backends/mla/rocm_aiter_mla.py | 8 +--- 2 files changed, 22 insertions(+), 26 deletions(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 5a9b323f8a5c..344f0e76bc44 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -378,7 +378,7 @@ def _rocm_aiter_triton_gemm_a8w8_blockscale_impl( Bs: torch.Tensor, output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: - from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale + from aiter.ops.triton.gemm.basic.gemm_a8w8_blockscale import gemm_a8w8_blockscale return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype) @@ -428,7 +428,7 @@ def _rocm_aiter_triton_gemm_afp4wfp4_impl( Bs: torch.Tensor, output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: - from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4 + from aiter.ops.triton.gemm.basic.gemm_afp4wfp4 import gemm_afp4wfp4 return gemm_afp4wfp4(A, B, As, Bs.T, dtype=output_dtype) @@ -452,7 +452,7 @@ def _rocm_aiter_triton_gemm_a16w8_blockscale_impl( Bs: torch.Tensor, output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: - from aiter.ops.triton.gemm_a16w8_blockscale import gemm_a16w8_blockscale + from aiter.ops.triton.gemm.basic.gemm_a16w8_blockscale import gemm_a16w8_blockscale return gemm_a16w8_blockscale(A, B, Bs, dtype=output_dtype) @@ -526,7 +526,7 @@ def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl( variance_epsilon: 
float, group_size: int, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant + from aiter.ops.triton.quant.fused_fp8_quant import fused_rms_fp8_group_quant (x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant( x, @@ -564,7 +564,7 @@ def _rocm_aiter_rmsnorm_fp8_group_quant_impl( variance_epsilon: float, group_size: int, ) -> tuple[torch.Tensor, torch.Tensor]: - from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant + from aiter.ops.triton.quant.fused_fp8_quant import fused_rms_fp8_group_quant (x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant( x, @@ -603,7 +603,7 @@ def _rocm_aiter_2rmsnorm_1fp8_group_quant_impl( variance_epsilon2: float, group_size: int, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant + from aiter.ops.triton.quant.fused_fp8_quant import fused_rms_fp8_group_quant (x_quant, x_quant_scales), _, x2_out, _ = fused_rms_fp8_group_quant( x1, @@ -710,8 +710,8 @@ def _rocm_aiter_triton_fused_shared_expert_fp8_impl( bias_shared: torch.Tensor, bias_moe_gate: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - from aiter.ops.triton.fused_gemm_a8w8_blockscale_a16w16 import fused_gemm_a8w8_blockscale_a16w16 - from aiter.ops.triton.fused_fp8_quant import fused_reduce_act_mul_fp8_group_quant + from aiter.ops.triton.gemm.fused.fused_gemm_a8w8_blockscale_a16w16 import fused_gemm_a8w8_blockscale_a16w16 + from aiter.ops.triton.quant.fused_fp8_quant import fused_reduce_act_mul_fp8_group_quant shared_output, router_logits = fused_gemm_a8w8_blockscale_a16w16(hidden_states_shared, weight_gate_up, hidden_states_shared_scale, weight_scale_gate_up, hidden_states_moe_gate, weight_moe_gate, bias_fp8=bias_shared, bias_bf16=bias_moe_gate, dtype=hidden_states_moe_gate.dtype, skip_reduce=True) @@ -753,7 +753,7 @@ def _rocm_aiter_triton_fused_down_proj_mul_add_fp8_impl( 
routed_scaling_factor: float, final_hidden_states: torch.Tensor, ) -> torch.Tensor: - from aiter.ops.triton.fused_gemm_a8w8_blockscale_mul_add import fused_gemm_a8w8_blockscale_mul_add + from aiter.ops.triton.gemm.fused.fused_gemm_a8w8_blockscale_mul_add import fused_gemm_a8w8_blockscale_mul_add out = fused_gemm_a8w8_blockscale_mul_add(hidden_states_shared, weight_down_proj, hidden_states_shared_scale, weight_scale_down_proj, routed_scaling_factor, final_hidden_states, fuse_type=1) return out @@ -780,8 +780,8 @@ def _rocm_aiter_triton_fused_shared_expert_fp4_impl( bias_shared: torch.Tensor, bias_moe_gate: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - from aiter.ops.triton.fused_gemm_afp4wfp4_a16w16 import fused_gemm_afp4wfp4_a16w16 - from aiter.ops.triton.fused_mxfp4_quant import fused_reduce_act_mul_and_mxfp4_quant + from aiter.ops.triton.gemm.fused.fused_gemm_afp4wfp4_a16w16 import fused_gemm_afp4wfp4_a16w16 + from aiter.ops.triton.quant.fused_mxfp4_quant import fused_reduce_act_mul_and_mxfp4_quant shared_output, router_logits = fused_gemm_afp4wfp4_a16w16(hidden_states_shared, weight_gate_up, hidden_states_shared_scale, weight_scale_gate_up.T, hidden_states_moe_gate, weight_moe_gate, is_fp4_preshuffled=False, bias_fp4=bias_shared, bias_bf16=bias_moe_gate, dtype=hidden_states_moe_gate.dtype, skip_reduce=True) @@ -831,7 +831,7 @@ def _rocm_aiter_triton_fused_down_proj_mul_add_fp4_impl( routed_scaling_factor: float, final_hidden_states: torch.Tensor, ) -> torch.Tensor: - from aiter.ops.triton.fused_gemm_afp4wfp4_mul_add import fused_gemm_afp4wfp4_mul_add + from aiter.ops.triton.gemm.fused.fused_gemm_afp4wfp4_mul_add import fused_gemm_afp4wfp4_mul_add out = fused_gemm_afp4wfp4_mul_add(hidden_states_shared, weight_down_proj, hidden_states_shared_scale, weight_scale_down_proj.T, routed_scaling_factor, final_hidden_states, fuse_type=1) return out @@ -861,8 +861,8 @@ def _rocm_aiter_triton_qkv_a_proj_layernorm_fp8_impl( kv_lora_rank: int, 
qk_rope_head_dim: int, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - from aiter.ops.triton.fused_fp8_quant import fused_reduce_rms_fp8_group_quant - from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale + from aiter.ops.triton.quant.fused_fp8_quant import fused_reduce_rms_fp8_group_quant + from aiter.ops.triton.gemm.basic.gemm_a8w8_blockscale import gemm_a8w8_blockscale import aiter as rocm_aiter qkv_lora = gemm_a8w8_blockscale(hidden_states_quant, weight_qkv_a_proj, hidden_states_quant_scale, weight_scale_qkv_a_proj, skip_reduce=True) q_c, kv_c, k_pe = qkv_lora.split([q_lora_rank, kv_lora_rank, qk_rope_head_dim], @@ -920,8 +920,8 @@ def _rocm_aiter_triton_qkv_a_proj_layernorm_fp4_impl( kv_lora_rank: int, qk_rope_head_dim: int, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4 - from aiter.ops.triton.fused_mxfp4_quant import fused_reduce_rms_mxfp4_quant + from aiter.ops.triton.gemm.basic.gemm_afp4wfp4 import gemm_afp4wfp4 + from aiter.ops.triton.quant.fused_mxfp4_quant import fused_reduce_rms_mxfp4_quant qkv_lora = gemm_afp4wfp4(hidden_states_quant, weight_qkv_a_proj, hidden_states_quant_scale, weight_scale_qkv_a_proj.T, skip_reduce=True) q_c, kv_c, k_pe = qkv_lora.split([q_lora_rank, kv_lora_rank, qk_rope_head_dim], @@ -1464,7 +1464,7 @@ def triton_fp4_gemm_dynamic_qaunt( out_dtype: torch.dtype | None = torch.bfloat16, x_scales: torch.Tensor | None = None, ) -> torch.Tensor: - from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4 + from aiter.ops.triton.gemm.basic.gemm_afp4wfp4 import gemm_afp4wfp4 from aiter.ops.triton.quant import dynamic_mxfp4_quant if x_scales is None: @@ -1490,7 +1490,7 @@ def triton_rotary_embed( rotary_dim: int, is_neox_style: bool, ): - from aiter.ops.triton.rope import rope_cached_thd_positions_2c_fwd_inplace + from aiter.ops.triton.rope.rope import rope_cached_thd_positions_2c_fwd_inplace num_tokens = positions.numel() 
cos, sin = cos_sin_cache.chunk(2, dim=-1) @@ -1528,7 +1528,7 @@ def triton_fp4_bmm( y_scale: dict | None = None, ) -> torch.Tensor: # ruff: noqa: E501 # isort: skip - from aiter.ops.triton.batched_gemm_a16wfp4 import ( + from aiter.ops.triton.gemm.batched.batched_gemm_a16wfp4 import ( batched_gemm_a16wfp4 as aiter_triton_fp4_bmm, ) @@ -1558,7 +1558,7 @@ def triton_fp8_bmm( config: dict | None = None, ) -> torch.Tensor: # ruff: noqa: E501 # isort: skip - from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( + from aiter.ops.triton.gemm.batched.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm, ) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index c24ba2e4c5c1..3b255354ac6b 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -361,7 +361,7 @@ def forward( if positions is not None: # positions is not None entails that Q and K are not RoPE embedded yet, therefore, fused_qk_rope_cat_and_cache_mla is called assert hasattr(self, "rotary_emb"), f"rotary_emb not found in {self}" - from aiter.ops.triton.fused_kv_cache import fused_qk_rope_cat_and_cache_mla + from aiter.ops.triton.fusions.fused_kv_cache import fused_qk_rope_cat_and_cache_mla cos, sin = self.rotary_emb.cos_sin_cache.chunk(2, dim = -1) is_neox = self.rotary_emb.is_neox_style q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) @@ -371,7 +371,7 @@ def forward( if fp8_attention: kv_cache_og_dtype = kv_cache.dtype kv_cache = kv_cache.view(q_out_dtype) - fused_output = fused_qk_rope_cat_and_cache_mla( + q, _, k_pe, mla_output_zeros = fused_qk_rope_cat_and_cache_mla( q_nope, q_pe, k_c_normed.unsqueeze(1), @@ -389,10 +389,6 @@ def forward( decode_q_pe_out = decode_q_cat[... 
, -self.qk_rope_head_dim:] if self.is_aiter_triton_fp4_bmm_enabled or self.is_aiter_triton_fp8_bmm_enabled else None, k_pe_out=k_pe, ) - if num_decode_tokens > 0: - q, _, k_pe, kv_cache, mla_output_zeros = fused_output - else: - q, _, k_pe, kv_cache = fused_output if fp8_attention: kv_cache = kv_cache.view(kv_cache_og_dtype) else: