From 5cb5778a49620a9cbac8feb442ce209d76f1cde8 Mon Sep 17 00:00:00 2001
From: Farel Lukas
Date: Mon, 1 Dec 2025 16:42:49 +0000
Subject: [PATCH 1/2] Fuse kv_b_proj with cat

---
 vllm/envs.py                              |  5 ++
 vllm/v1/attention/backends/mla/common.py  | 59 ++++++++++++++++++++++--
 2 files changed, 59 insertions(+), 5 deletions(-)

diff --git a/vllm/envs.py b/vllm/envs.py
index ce719cd3d2d9..6c231e5bd326 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -182,6 +182,7 @@
     VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT: bool = True
     VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD: bool = True
     VLLM_ROCM_USE_AITER_TRITON_BF16_GEMM: bool = True
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP8_SPLIT_CAT: bool = True
     ROCM_TRITON_MOE_PRESHUFFLE_SCALES: bool = True
     VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4: bool = False
     VLLM_ROCM_USE_AITER_TRITON_FUSED_SHARED_EXPERTS: bool = True
@@ -1289,6 +1290,10 @@ def get_vllm_port() -> Optional[int]:
     # Use AITER Triton MLA
     "VLLM_ROCM_USE_AITER_TRITON_MLA":
     lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_MLA", "0"))),
+
+    # Use AITER Triton fused FP8 GEMM + split + cat
+    "VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP8_SPLIT_CAT":
+    lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP8_SPLIT_CAT", "1"))),
 
     # If set, enables CK fp4 MoE
     "VLLM_ROCM_USE_CK_MXFP4_MOE": lambda: (
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 943208ec7b1f..c2ffb1aa8024 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -255,10 +255,22 @@ def dynamic_per_batched_tensor_quant(
         return x_quant, x_quant_scale
 
     from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant
+
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP8_SPLIT_CAT = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP8_SPLIT_CAT
+    if VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP8_SPLIT_CAT:
+        from aiter.ops.triton.fused_gemm_a8w8_blockscale_split_cat import fused_gemm_a8w8_blockscale_split_cat
+        from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
+        from vllm.model_executor.layers.quantization.utils.fp8_utils import per_token_group_quant_fp8
+        if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR
+                and current_platform.is_fp8_fnuz()):
+            import aiter as rocm_aiter
+            from aiter import get_hip_quant
+            aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
 else:
     VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
     VLLM_ROCM_USE_AITER_TRITON_FP8_BMM = False
     VLLM_ROCM_USE_AITER_TRITON_FP8_BMM_MAX_BATCH_SIZE = 0
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP8_SPLIT_CAT = False
 
 logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FP8_BMM=} {VLLM_ROCM_USE_AITER_TRITON_FP8_BMM_MAX_BATCH_SIZE=}")
 logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=}")
@@ -1168,12 +1180,49 @@ def _forward_prefill(
         assert attn_metadata.prefill is not None
         has_context = attn_metadata.prefill.chunked_context is not None
 
-        kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
-            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
-        k_nope, v = kv_nope\
-            .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
+        if (
+            VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP8_SPLIT_CAT
+            and (self.kv_b_proj.bias is None
+                 or self.kv_b_proj.skip_bias_add)
+            and self.kv_b_proj.quant_method is not None
+            and isinstance(self.kv_b_proj.quant_method, Fp8LinearMethod)
+            and not self.kv_b_proj.gather_output
+            and self.kv_b_proj.quant_method.block_quant
+        ):
+            assert self.kv_b_proj.quant_method.quant_config.weight_block_size is not None
+
+            input = kv_c_normed
+            weight = self.kv_b_proj.weight
+            block_size = self.kv_b_proj.quant_method.quant_config.weight_block_size
+            weight_scale = self.kv_b_proj.weight_scale_inv
+            input_scale = self.kv_b_proj.input_scale
+            use_aiter_and_is_supported = (current_platform.is_rocm()
+                and envs.VLLM_ROCM_USE_AITER
+                and (envs.VLLM_ROCM_USE_AITER_LINEAR or envs.VLLM_ROCM_USE_AITER_TRITON_LINEAR))
+            assert input_scale is None
+
+            # View input as 2D matrix for fp8 methods
+            input_2d = input.view(-1, input.shape[-1])
+            output_dtype = input.dtype
+
+            if use_aiter_and_is_supported and current_platform.is_fp8_fnuz():
+                q_input, x_scale = aiter_per1x128_quant(
+                    input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
+            else:
+                q_input, x_scale = per_token_group_quant_fp8(
+                    input_2d, block_size[1], column_major_scales=False)
+
+            k, v = fused_gemm_a8w8_blockscale_split_cat(
+                q_input, weight, k_pe.expand((-1, self.num_heads, -1)), x_scale, weight_scale, self.qk_nope_head_dim, self.v_head_dim, output_dtype
+            )
+        else:
+            kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
+                -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+            k_nope, v = kv_nope\
+                .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+
+            k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
 
         output = self._run_prefill_new_tokens(
             prefill=attn_metadata.prefill,

From 4db9794548de128dffb38789854f0eadf95a14ad Mon Sep 17 00:00:00 2001
From: Farel Lukas
Date: Wed, 3 Dec 2025 21:04:34 +0000
Subject: [PATCH 2/2] Fused kv_b_proj with cat in _compute_prefill_context

---
 vllm/v1/attention/backends/mla/common.py | 48 +++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 42 insertions(+), 6 deletions(-)

diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index c2ffb1aa8024..87cd65fd4729 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -1133,13 +1133,49 @@ def _compute_prefill_context(
             k_pe = workspace[:toks]\
                 [..., self.kv_lora_rank:].unsqueeze(1)
 
-            kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
-                -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
-            k_nope, v = kv_nope\
-                .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+            if (
+                VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP8_SPLIT_CAT
+                and (self.kv_b_proj.bias is None
+                     or self.kv_b_proj.skip_bias_add)
+                and self.kv_b_proj.quant_method is not None
+                and isinstance(self.kv_b_proj.quant_method, Fp8LinearMethod)
+                and not self.kv_b_proj.gather_output
+                and self.kv_b_proj.quant_method.block_quant
+            ):
+                assert self.kv_b_proj.quant_method.quant_config.weight_block_size is not None
+
+                input = kv_c_normed
+                weight = self.kv_b_proj.weight
+                block_size = self.kv_b_proj.quant_method.quant_config.weight_block_size
+                weight_scale = self.kv_b_proj.weight_scale_inv
+                input_scale = self.kv_b_proj.input_scale
+                use_aiter_and_is_supported = (current_platform.is_rocm()
+                    and envs.VLLM_ROCM_USE_AITER
+                    and (envs.VLLM_ROCM_USE_AITER_LINEAR or envs.VLLM_ROCM_USE_AITER_TRITON_LINEAR))
+                assert input_scale is None
+
+                # View input as 2D matrix for fp8 methods
+                input_2d = input.view(-1, input.shape[-1])
+                output_dtype = input.dtype
+
+                if use_aiter_and_is_supported and current_platform.is_fp8_fnuz():
+                    q_input, x_scale = aiter_per1x128_quant(
+                        input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
+                else:
+                    q_input, x_scale = per_token_group_quant_fp8(
+                        input_2d, block_size[1], column_major_scales=False)
+
+                k, v = fused_gemm_a8w8_blockscale_split_cat(
+                    q_input, weight, k_pe.expand((-1, self.num_heads, -1)), x_scale, weight_scale, self.qk_nope_head_dim, self.v_head_dim, output_dtype
+                )
+            else:
+                kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
+                    -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+                k_nope, v = kv_nope\
+                    .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
 
-            k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
-                          dim=-1)
+                k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
+                              dim=-1)
 
             attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
                 prefill=prefill_metadata,
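Note (not part of the patches above): both hunks replace the same three-step sequence -- the kv_b_proj GEMM, the split of its output into k_nope and v, and the concatenation of the shared k_pe RoPE slice onto every head's k -- with a single fused_gemm_a8w8_blockscale_split_cat call. The standalone sketch below mirrors the unfused else-branch with plain float32 tensors so the intended output shapes are easy to check; the tensor sizes, the helper name, and the use of an unquantized weight are illustrative assumptions only, whereas the real aiter kernel consumes FP8-quantized activations plus block scales and writes k and v directly.

# Illustrative reference only: unfused semantics the fused kernel is expected
# to reproduce (names and shapes here are assumptions, not vLLM or aiter API).
import torch


def kv_b_proj_split_cat_reference(
    kv_c_normed: torch.Tensor,  # [num_tokens, kv_lora_rank]
    w_kv_b: torch.Tensor,       # [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
    k_pe: torch.Tensor,         # [num_tokens, 1, qk_rope_head_dim]
    num_heads: int,
    qk_nope_head_dim: int,
    v_head_dim: int,
):
    # GEMM: project the normalized latent KV up to per-head features.
    kv_nope = (kv_c_normed @ w_kv_b.T).view(
        -1, num_heads, qk_nope_head_dim + v_head_dim)
    # Split: separate the "nope" part of K from V.
    k_nope, v = kv_nope.split([qk_nope_head_dim, v_head_dim], dim=-1)
    # Cat: broadcast the shared RoPE slice onto every head of K.
    k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
    return k, v


if __name__ == "__main__":
    num_tokens, kv_lora_rank = 4, 512
    num_heads, qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 16, 128, 64, 128
    x = torch.randn(num_tokens, kv_lora_rank)
    w = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)
    k_pe = torch.randn(num_tokens, 1, qk_rope_head_dim)
    k, v = kv_b_proj_split_cat_reference(
        x, w, k_pe, num_heads, qk_nope_head_dim, v_head_dim)
    print(k.shape, v.shape)  # torch.Size([4, 16, 192]) torch.Size([4, 16, 128])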