5 changes: 5 additions & 0 deletions vllm/envs.py
@@ -182,6 +182,7 @@
VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT: bool = True
VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD: bool = True
VLLM_ROCM_USE_AITER_TRITON_BF16_GEMM: bool = True
VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP8_SPLIT_CAT: bool = True
ROCM_TRITON_MOE_PRESHUFFLE_SCALES: bool = True
VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4: bool = False
VLLM_ROCM_USE_AITER_TRITON_FUSED_SHARED_EXPERTS: bool = True
@@ -1241,15 +1242,15 @@
# Use AITER Triton fused RMSNORM + Quantization
"VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT", "1"))),

Check failure (GitHub Actions / pre-commit): vllm/envs.py:1245:81: E501 Line too long (92 > 80)
# Use AITER Triton fused elementwise multiply + elementwise addition
"VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD", "1"))),

Check failure (GitHub Actions / pre-commit): vllm/envs.py:1249:81: E501 Line too long (82 > 80)
# Use AITER Triton fused rope + zeros + reshape_and_cache
"VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE", "1"))),

Check failure (GitHub Actions / pre-commit): vllm/envs.py:1253:81: E501 Line too long (94 > 80)
# Use AITER Triton fused FP8 per-token group quant + FP8 batched GEMM
"VLLM_ROCM_USE_AITER_TRITON_FP8_BMM":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FP8_BMM", "1"))),
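The pre-commit annotations above flag Ruff E501 (line too long) on the getenv lambdas. One possible wrap that keeps each line under the 80-column limit (illustrative only, not the fix committed in this PR):

```python
import os

# Illustrative E501 fix (not the committed change): break the lambda body
# across lines so the dict entry stays within 80 columns.
environment_variables = {
    "VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT": lambda: bool(
        int(os.getenv(
            "VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT", "1"))),
}
```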
@@ -1289,6 +1290,10 @@
# Use AITER Triton MLA
"VLLM_ROCM_USE_AITER_TRITON_MLA":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_MLA", "0"))),

# Use AITER Triton fused FP8 GEMM + split + cat
"VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP8_SPLIT_CAT":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP8_SPLIT_CAT", "1"))),

# If set, enables CK fp4 MoE
"VLLM_ROCM_USE_CK_MXFP4_MOE": lambda: (
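For context, the new flag follows the same getenv-lambda pattern as the existing entries: it defaults to on. The sketch below is my reading of that pattern, not code from this PR, and shows how the value resolves and how to turn the fused path off.

```python
import os

# Sketch of how the flag resolves, assuming the getenv-lambda pattern above:
# the default is "1", so the fused GEMM + split + cat path is enabled unless
# the variable is explicitly set to 0.
fused_gemm_fp8_split_cat_enabled = bool(int(
    os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP8_SPLIT_CAT", "1")))

# Example (shell): disable the fused path for a single run
#   VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP8_SPLIT_CAT=0 vllm serve <model>
```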
107 changes: 96 additions & 11 deletions vllm/v1/attention/backends/mla/common.py
@@ -255,10 +255,22 @@ def dynamic_per_batched_tensor_quant(
return x_quant, x_quant_scale

from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant

VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP8_SPLIT_CAT = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP8_SPLIT_CAT
if VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP8_SPLIT_CAT:
from aiter.ops.triton.fused_gemm_a8w8_blockscale_split_cat import fused_gemm_a8w8_blockscale_split_cat
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.utils.fp8_utils import per_token_group_quant_fp8
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR
and current_platform.is_fp8_fnuz()):
import aiter as rocm_aiter
from aiter import get_hip_quant
aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
else:
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
VLLM_ROCM_USE_AITER_TRITON_FP8_BMM = False
VLLM_ROCM_USE_AITER_TRITON_FP8_BMM_MAX_BATCH_SIZE = 0
VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP8_SPLIT_CAT = False

logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FP8_BMM=} {VLLM_ROCM_USE_AITER_TRITON_FP8_BMM_MAX_BATCH_SIZE=}")
logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=}")
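The activation-quantization branch introduced here (HIP per-1x128 quant on fp8-fnuz platforms, Triton per-token-group FP8 quant otherwise) is reused twice in the prefill paths below. A hypothetical helper, with names of my choosing, that captures the same choice:

```python
import torch

def quantize_prefill_input(x_2d: torch.Tensor, group_size: int,
                           use_hip_quant: bool):
    """Hypothetical helper (not in this PR) mirroring the branch above:
    returns (quantized input, per-group scales)."""
    if use_hip_quant:
        # Assumes an fp8-fnuz platform with the AITER linear path enabled.
        import aiter as rocm_aiter
        from aiter import get_hip_quant
        quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
        return quant(x_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
    from vllm.model_executor.layers.quantization.utils.fp8_utils import (
        per_token_group_quant_fp8)
    return per_token_group_quant_fp8(x_2d, group_size,
                                     column_major_scales=False)
```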
@@ -1121,13 +1133,49 @@ def _compute_prefill_context(
            k_pe = workspace[:toks]\
                [..., self.kv_lora_rank:].unsqueeze(1)

            kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
                -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = kv_nope\
                .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            if (
                VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP8_SPLIT_CAT
                and (self.kv_b_proj.bias is None
                     or self.kv_b_proj.skip_bias_add)
                and self.kv_b_proj.quant_method is not None
                and isinstance(self.kv_b_proj.quant_method, Fp8LinearMethod)
                and not self.kv_b_proj.gather_output
                and self.kv_b_proj.quant_method.block_quant
            ):
                assert self.kv_b_proj.quant_method.quant_config.weight_block_size is not None

                input = kv_c_normed
                weight = self.kv_b_proj.weight
                block_size = self.kv_b_proj.quant_method.quant_config.weight_block_size
                weight_scale = self.kv_b_proj.weight_scale_inv
                input_scale = self.kv_b_proj.input_scale
                use_aiter_and_is_supported = (current_platform.is_rocm()
                    and envs.VLLM_ROCM_USE_AITER
                    and (envs.VLLM_ROCM_USE_AITER_LINEAR or envs.VLLM_ROCM_USE_AITER_TRITON_LINEAR))
                assert input_scale is None

                # View input as 2D matrix for fp8 methods
                input_2d = input.view(-1, input.shape[-1])
                output_dtype = input.dtype

                if use_aiter_and_is_supported and current_platform.is_fp8_fnuz():
                    q_input, x_scale = aiter_per1x128_quant(
                        input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
                else:
                    q_input, x_scale = per_token_group_quant_fp8(
                        input_2d, block_size[1], column_major_scales=False)

                k, v = fused_gemm_a8w8_blockscale_split_cat(
                    q_input, weight, k_pe.expand((-1, self.num_heads, -1)), x_scale, weight_scale, self.qk_nope_head_dim, self.v_head_dim, output_dtype
                )
            else:
                kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
                    -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
                k_nope, v = kv_nope\
                    .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

            k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                          dim=-1)
                k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                              dim=-1)

            attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
                prefill=prefill_metadata,
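For readers unfamiliar with the AITER kernel: ignoring the FP8 block-scale quantization, the fused call above computes what the unfused else branch computes. A plain-PyTorch sketch of those semantics, where the shapes and the [out_features, in_features] weight layout are my assumptions:

```python
import torch

def gemm_split_cat_reference(kv_c_normed: torch.Tensor,  # [T, kv_lora_rank]
                             kv_b_weight: torch.Tensor,  # [H*(Dn+Dv), kv_lora_rank]
                             k_pe: torch.Tensor,         # [T, 1, Dr]
                             num_heads: int,
                             qk_nope_head_dim: int,
                             v_head_dim: int):
    """Unfused reference for fused_gemm_a8w8_blockscale_split_cat (sketch,
    quantization omitted): project kv_c_normed through kv_b_proj, split the
    result into per-head k_nope / v, and concat k_pe onto k_nope."""
    kv_nope = (kv_c_normed @ kv_b_weight.t()).view(
        -1, num_heads, qk_nope_head_dim + v_head_dim)
    k_nope, v = kv_nope.split([qk_nope_head_dim, v_head_dim], dim=-1)
    k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
    return k, v
```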
@@ -1168,12 +1216,49 @@ def _forward_prefill(
        assert attn_metadata.prefill is not None

        has_context = attn_metadata.prefill.chunked_context is not None
        kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv_nope\
            .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
        if (
            VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP8_SPLIT_CAT
            and (self.kv_b_proj.bias is None
                 or self.kv_b_proj.skip_bias_add)
            and self.kv_b_proj.quant_method is not None
            and isinstance(self.kv_b_proj.quant_method, Fp8LinearMethod)
            and not self.kv_b_proj.gather_output
            and self.kv_b_proj.quant_method.block_quant
        ):
            assert self.kv_b_proj.quant_method.quant_config.weight_block_size is not None

            input = kv_c_normed
            weight = self.kv_b_proj.weight
            block_size = self.kv_b_proj.quant_method.quant_config.weight_block_size
            weight_scale = self.kv_b_proj.weight_scale_inv
            input_scale = self.kv_b_proj.input_scale
            use_aiter_and_is_supported = (current_platform.is_rocm()
                and envs.VLLM_ROCM_USE_AITER
                and (envs.VLLM_ROCM_USE_AITER_LINEAR or envs.VLLM_ROCM_USE_AITER_TRITON_LINEAR))
            assert input_scale is None

            # View input as 2D matrix for fp8 methods
            input_2d = input.view(-1, input.shape[-1])
            output_dtype = input.dtype

            if use_aiter_and_is_supported and current_platform.is_fp8_fnuz():
                q_input, x_scale = aiter_per1x128_quant(
                    input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
            else:
                q_input, x_scale = per_token_group_quant_fp8(
                    input_2d, block_size[1], column_major_scales=False)

            k, v = fused_gemm_a8w8_blockscale_split_cat(
                q_input, weight, k_pe.expand((-1, self.num_heads, -1)), x_scale, weight_scale, self.qk_nope_head_dim, self.v_head_dim, output_dtype
            )
        else:
            kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
                -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = kv_nope\
                .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

            k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

        output = self._run_prefill_new_tokens(
            prefill=attn_metadata.prefill,
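The eligibility check and the fused call are duplicated verbatim between `_compute_prefill_context` and `_forward_prefill`; a follow-up could hoist the check into a shared helper, for example (hypothetical sketch, names are mine, not part of this PR):

```python
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod

def can_use_fused_gemm_fp8_split_cat(kv_b_proj, flag_enabled: bool) -> bool:
    """Hypothetical shared predicate. `flag_enabled` stands in for the
    module-level VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP8_SPLIT_CAT value."""
    return (flag_enabled
            and (kv_b_proj.bias is None or kv_b_proj.skip_bias_add)
            and kv_b_proj.quant_method is not None
            and isinstance(kv_b_proj.quant_method, Fp8LinearMethod)
            and not kv_b_proj.gather_output
            and kv_b_proj.quant_method.block_quant)
```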