From 8e81de33ad52229da79b6f1d95e7fbf62fd06c67 Mon Sep 17 00:00:00 2001 From: kliuae Date: Thu, 11 Dec 2025 08:30:49 +0000 Subject: [PATCH] add aiter qknorm and rope fusion Signed-off-by: kliuae --- .../kernels/benchmark_qk_norm_rope_fusion.py | 346 ++++++++++++++++++ tests/compile/test_qk_norm_rope_fusion.py | 20 +- tests/kernels/core/test_fused_qk_norm_rope.py | 89 +++++ vllm/_aiter_ops.py | 109 +++++- vllm/compilation/matcher_utils.py | 4 +- vllm/compilation/qk_norm_rope_fusion.py | 19 +- vllm/config/compilation.py | 2 +- vllm/envs.py | 11 +- 8 files changed, 580 insertions(+), 20 deletions(-) create mode 100644 benchmarks/kernels/benchmark_qk_norm_rope_fusion.py diff --git a/benchmarks/kernels/benchmark_qk_norm_rope_fusion.py b/benchmarks/kernels/benchmark_qk_norm_rope_fusion.py new file mode 100644 index 000000000000..a1ac0ba80600 --- /dev/null +++ b/benchmarks/kernels/benchmark_qk_norm_rope_fusion.py @@ -0,0 +1,346 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import itertools + +import torch + +from vllm._aiter_ops import rocm_aiter_ops +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +from vllm.platforms import current_platform +from vllm.triton_utils import triton + + +def apply_qk_norm_rope_unfused( + qkv: torch.Tensor, + positions: torch.Tensor, + q_norm: RMSNorm, + k_norm: RMSNorm, + rope: RotaryEmbedding, + num_heads_q: int, + num_heads_kv: int, + head_dim: int, +) -> torch.Tensor: + q_size = num_heads_q * head_dim + kv_size = num_heads_kv * head_dim + + q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1) + + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim) + q_by_head = q_norm.forward_native(q_by_head) + q = q_by_head.view(q.shape) + + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim) + k_by_head = k_norm.forward_native(k_by_head) + k = k_by_head.view(k.shape) + + q, k = rope.forward_native(positions, q, k) + return torch.cat([q, k, v], dim=-1) + + +def apply_qk_norm_rope_vllm_cuda( + qkv: torch.Tensor, + positions: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + cos_sin_cache: torch.Tensor, + num_heads_q: int, + num_heads_kv: int, + head_dim: int, + eps: float, + is_neox: bool, +) -> torch.Tensor: + torch.ops._C.fused_qk_norm_rope( + qkv, + num_heads_q, + num_heads_kv, + num_heads_kv, + head_dim, + eps, + q_weight, + k_weight, + cos_sin_cache, + is_neox, + positions.view(-1), + ) + return qkv + + +def apply_qk_norm_rope_aiter( + qkv: torch.Tensor, + positions: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + cos_sin_cache: torch.Tensor, + num_heads_q: int, + num_heads_kv: int, + head_dim: int, + eps: float, + is_neox: bool, +) -> torch.Tensor: + rocm_aiter_ops.fused_qk_norm_rope( + qkv=qkv, + num_heads_q=num_heads_q, + num_heads_k=num_heads_kv, + num_heads_v=num_heads_kv, + head_dim=head_dim, + eps=eps, + q_weight=q_weight, + k_weight=k_weight, + cos_sin_cache=cos_sin_cache, + is_neox=is_neox, + position_ids=positions.view(-1), + ) + return qkv + + +def calculate_diff(num_tokens, num_heads, num_kv_heads, head_dim, dtype, is_neox, eps): + device = "cuda" + total_dim = (num_heads + 2 * num_kv_heads) * head_dim + + qkv_base = torch.randn(num_tokens, total_dim, dtype=dtype, device=device) + positions = torch.arange(num_tokens, dtype=torch.long, device=device) + + q_norm = RMSNorm(head_dim, eps=eps).to(device=device, dtype=dtype) + k_norm = RMSNorm(head_dim, eps=eps).to(device=device, dtype=dtype) + q_norm.weight.data.normal_(mean=1.0, std=0.1) + k_norm.weight.data.normal_(mean=1.0, std=0.1) + q_weight = q_norm.weight.data + k_weight = k_norm.weight.data + + rope = RotaryEmbedding( + head_size=head_dim, + rotary_dim=head_dim, + max_position_embeddings=4096, + base=10000.0, + is_neox_style=is_neox, + dtype=dtype, + ).to(device) + + # Unfused reference + output_unfused = apply_qk_norm_rope_unfused( + qkv_base.clone(), + positions, + q_norm, + k_norm, + rope, + num_heads, + num_kv_heads, + head_dim, + ) + + # vLLM CUDA kernel + if hasattr(torch.ops._C, "fused_qk_norm_rope"): + qkv_vllm = qkv_base.clone() + output_vllm = apply_qk_norm_rope_vllm_cuda( + qkv_vllm, + positions, + q_weight, + k_weight, + rope.cos_sin_cache, + num_heads, + num_kv_heads, + head_dim, + eps, + is_neox, + ) + vllm_matches = torch.allclose(output_unfused, output_vllm, atol=5e-2, rtol=1e-2) + print(f"vLLM CUDA kernel: {'Matches' if vllm_matches else 'Differs'}") + else: + print("vLLM CUDA kernel: Not available") + + # AITER kernel + if ( + current_platform.is_rocm() + and rocm_aiter_ops.is_enabled() + and rocm_aiter_ops.is_fused_qk_norm_rope_enabled() + ): + qkv_aiter = qkv_base.clone() + output_aiter = apply_qk_norm_rope_aiter( + qkv_aiter, + positions, + q_weight, + k_weight, + rope.cos_sin_cache, + num_heads, + num_kv_heads, + head_dim, + eps, + is_neox, + ) + aiter_matches = torch.allclose( + output_unfused, output_aiter, atol=5e-2, rtol=1e-2 + ) + print(f"AITER kernel: {'Matches' if aiter_matches else 'Differs'}") + else: + print( + "AITER kernel: Not available " + "(requires ROCm and VLLM_ROCM_USE_AITER_FUSED_QK_NORM_ROPE=1)" + ) + + +num_tokens_range = [64, 256, 1024, 4096] +num_heads_range = [32, 64] +num_kv_heads_range = [8, 16] +head_dim_range = [64, 128] +configs = list( + itertools.product( + num_tokens_range, num_heads_range, num_kv_heads_range, head_dim_range + ) +) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["num_tokens", "num_heads", "num_kv_heads", "head_dim"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["unfused", "vllm_cuda", "aiter"], + line_names=["Unfused", "vLLM CUDA", "AITER"], + styles=[("blue", "-"), ("green", "-"), ("red", "-")], + ylabel="us", + plot_name="qk-norm-rope-fusion-perf", + args={}, + ) +) +def benchmark(num_tokens, num_heads, num_kv_heads, head_dim, provider): + dtype = torch.bfloat16 + device = "cuda" + eps = 1e-6 + is_neox = True + + total_dim = (num_heads + 2 * num_kv_heads) * head_dim + qkv = torch.randn(num_tokens, total_dim, dtype=dtype, device=device) + positions = torch.arange(num_tokens, dtype=torch.long, device=device) + + q_norm = RMSNorm(head_dim, eps=eps).to(device=device, dtype=dtype) + k_norm = RMSNorm(head_dim, eps=eps).to(device=device, dtype=dtype) + q_weight = q_norm.weight.data + k_weight = k_norm.weight.data + + rope = RotaryEmbedding( + head_size=head_dim, + rotary_dim=head_dim, + max_position_embeddings=4096, + base=10000.0, + is_neox_style=is_neox, + dtype=dtype, + ).to(device) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "unfused": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: apply_qk_norm_rope_unfused( + qkv.clone(), + positions, + q_norm, + k_norm, + rope, + num_heads, + num_kv_heads, + head_dim, + ), + quantiles=quantiles, + ) + elif provider == "vllm_cuda": + if not hasattr(torch.ops._C, "fused_qk_norm_rope"): + return float("nan"), float("nan"), float("nan") + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: apply_qk_norm_rope_vllm_cuda( + qkv.clone(), + positions, + q_weight, + k_weight, + rope.cos_sin_cache, + num_heads, + num_kv_heads, + head_dim, + eps, + is_neox, + ), + quantiles=quantiles, + ) + elif provider == "aiter": + if not ( + current_platform.is_rocm() + and rocm_aiter_ops.is_enabled() + and rocm_aiter_ops.is_fused_qk_norm_rope_enabled() + ): + return float("nan"), float("nan"), float("nan") + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: apply_qk_norm_rope_aiter( + qkv.clone(), + positions, + q_weight, + k_weight, + rope.cos_sin_cache, + num_heads, + num_kv_heads, + head_dim, + eps, + is_neox, + ), + quantiles=quantiles, + ) + else: + raise ValueError(f"Unknown provider: {provider}") + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + import os + + parser = argparse.ArgumentParser( + description="Benchmark QK norm + RoPE fusion kernels" + ) + parser.add_argument("--num-tokens", type=int, default=256, help="Number of tokens") + parser.add_argument( + "--num-heads", type=int, default=32, help="Number of query heads" + ) + parser.add_argument( + "--num-kv-heads", type=int, default=8, help="Number of key/value heads" + ) + parser.add_argument("--head-dim", type=int, default=128, help="Head dimension") + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + choices=["float16", "bfloat16"], + help="Data type", + ) + parser.add_argument("--is-neox", action="store_true", help="Use Neox-style RoPE") + parser.add_argument("--eps", type=float, default=1e-6, help="RMSNorm epsilon") + parser.add_argument( + "--save-path", + type=str, + default="./configs/qk_norm_rope/", + help="Path to save benchmark results", + ) + + args = parser.parse_args() + + os.makedirs(args.save_path, exist_ok=True) + + dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16 + + print("=" * 80) + print("Correctness Test") + print("=" * 80) + calculate_diff( + args.num_tokens, + args.num_heads, + args.num_kv_heads, + args.head_dim, + dtype, + args.is_neox, + args.eps, + ) + + print("\n" + "=" * 80) + print("Performance Benchmark") + print("=" * 80) + benchmark.run(print_data=True, save_path=args.save_path) diff --git a/tests/compile/test_qk_norm_rope_fusion.py b/tests/compile/test_qk_norm_rope_fusion.py index 511e50f5fdc2..4a5a4bf98892 100644 --- a/tests/compile/test_qk_norm_rope_fusion.py +++ b/tests/compile/test_qk_norm_rope_fusion.py @@ -10,8 +10,8 @@ from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass from vllm.compilation.qk_norm_rope_fusion import ( - FUSED_QK_ROPE_OP, QKNormRoPEFusionPass, + get_fused_qknorm_rope_op, ) from vllm.config import ( CompilationConfig, @@ -104,7 +104,7 @@ def ops_in_model_before(self) -> list[torch._ops.OpOverload]: return ops def ops_in_model_after(self) -> list[torch._ops.OpOverload]: - return [FUSED_QK_ROPE_OP] + return [get_fused_qknorm_rope_op()] @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @@ -119,8 +119,14 @@ def ops_in_model_after(self) -> list[torch._ops.OpOverload]: def test_qk_norm_rope_fusion( eps, is_neox, enable_rms_norm_custom_op, enable_rope_custom_op, dtype ): - if not hasattr(torch.ops._C, "fused_qk_norm_rope"): - pytest.skip("fused_qk_norm_rope custom op not available") + has_vllm_cuda_kernel = hasattr(torch.ops._C, "fused_qk_norm_rope") + has_aiter_kernel = hasattr(torch.ops.vllm, "rocm_aiter_fused_qk_norm_rope") + + if not has_vllm_cuda_kernel and not has_aiter_kernel: + pytest.skip( + "Neither fused_qk_norm_rope (CUDA) nor rocm_aiter_fused_qk_norm_rope " + "(AITER) custom op is available" + ) torch.set_default_device("cuda") torch.set_default_dtype(dtype) @@ -180,10 +186,12 @@ def test_qk_norm_rope_fusion( model_unfused = torch.compile(model, backend=backend_baseline) q_unfused, k_unfused, v_unfused = model_unfused(qkv_unfused, pos_unfused) + # AITER kernel may have slightly different numerical behavior + # Use the tolerances from the AITER test suite if dtype == torch.float16: - ATOL, RTOL = (2e-3, 2e-3) + ATOL, RTOL = (5e-2, 1e-2) else: - ATOL, RTOL = (1e-2, 1e-2) + ATOL, RTOL = (5e-2, 1e-2) torch.testing.assert_close(q_unfused, q_fused, atol=ATOL, rtol=RTOL) torch.testing.assert_close(k_unfused, k_fused, atol=ATOL, rtol=RTOL) diff --git a/tests/kernels/core/test_fused_qk_norm_rope.py b/tests/kernels/core/test_fused_qk_norm_rope.py index a23959e353da..66cc929cd25a 100644 --- a/tests/kernels/core/test_fused_qk_norm_rope.py +++ b/tests/kernels/core/test_fused_qk_norm_rope.py @@ -5,6 +5,7 @@ import torch from tests.kernels.utils import opcheck +from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.platforms import current_platform @@ -139,3 +140,91 @@ def test_fused_qk_norm_rope_matches_reference( atol=ATOL, rtol=RTOL, ) + + +@pytest.mark.skipif( + not ( + current_platform.is_rocm() + and rocm_aiter_ops.is_enabled() + and rocm_aiter_ops.is_fused_qk_norm_rope_enabled() + ), + reason="aiter fused_qk_norm_rope requires rocm platform and " + "VLLM_ROCM_USE_AITER_FUSED_QK_NORM_ROPE=1", +) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("is_neox", IS_NEOX) +@pytest.mark.parametrize("eps", EPS_VALUES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_aiter_fused_qk_norm_rope_matches_reference( + device: str, + dtype: torch.dtype, + is_neox: bool, + eps: float, + seed: int, +): + torch.set_default_device(device) + current_platform.seed_everything(seed) + num_heads, num_kv_heads, head_dim = 16, 4, 128 + num_tokens = 4 + + total_dim = (num_heads + 2 * num_kv_heads) * head_dim + qkv_base = torch.randn(num_tokens, total_dim, dtype=dtype, device=device) + qkv_aiter = qkv_base.clone() + positions = torch.arange(num_tokens, dtype=torch.long, device=device) + + q_norm = RMSNorm(head_dim, eps=eps).to(device=device, dtype=dtype) + k_norm = RMSNorm(head_dim, eps=eps).to(device=device, dtype=dtype) + q_norm.weight.data.normal_(mean=1.0, std=0.1) + k_norm.weight.data.normal_(mean=1.0, std=0.1) + q_weight = q_norm.weight.data + k_weight = k_norm.weight.data + + rope = RotaryEmbedding( + head_size=head_dim, + rotary_dim=head_dim, + max_position_embeddings=4096, + base=10000.0, + is_neox_style=is_neox, + dtype=dtype, + ).to(device) + + ref_result = _apply_qk_norm_rope( + qkv=qkv_base, + positions=positions, + q_norm=q_norm, + k_norm=k_norm, + rope=rope, + num_heads_q=num_heads, + num_heads_kv=num_kv_heads, + head_dim=head_dim, + ) + + # Test aiter kernel + rocm_aiter_ops.fused_qk_norm_rope( + qkv=qkv_aiter, + num_heads_q=num_heads, + num_heads_k=num_kv_heads, + num_heads_v=num_kv_heads, + head_dim=head_dim, + eps=eps, + q_weight=q_weight, + k_weight=k_weight, + cos_sin_cache=rope.cos_sin_cache, + is_neox=is_neox, + position_ids=positions.view(-1), + ) + + # Use relaxed tolerances similar to aiter's test suite + if dtype == torch.float16: + ATOL, RTOL = (5e-2, 1e-2) + else: + ATOL, RTOL = (5e-2, 1e-2) + + torch.testing.assert_close( + qkv_aiter, + ref_result, + atol=ATOL, + rtol=RTOL, + ) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index a25f1a9d0c4f..90af9c99015f 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -462,6 +462,54 @@ def _rocm_aiter_rms_norm_fake( return torch.empty_like(x) +def _rocm_aiter_fused_qk_norm_rope_impl( + qkv: torch.Tensor, + num_heads_q: int, + num_heads_k: int, + num_heads_v: int, + head_dim: int, + eps: float, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + cos_sin_cache: torch.Tensor, + is_neox: bool, + position_ids: torch.Tensor, +) -> None: + from aiter import fused_rope_rms + + num_tokens = position_ids.numel() + fused_rope_rms( + qkv=qkv, + qw=q_weight, + kw=k_weight, + cos_sin=cos_sin_cache, + positions=position_ids, + num_tokens=num_tokens, + num_heads_q=num_heads_q, + num_heads_k=num_heads_k, + num_heads_v=num_heads_v, + head_size=head_dim, + is_neox_style=is_neox, + eps=eps, + ) + + +def _rocm_aiter_fused_qk_norm_rope_fake( + qkv: torch.Tensor, + num_heads_q: int, + num_heads_k: int, + num_heads_v: int, + head_dim: int, + eps: float, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + cos_sin_cache: torch.Tensor, + is_neox: bool, + position_ids: torch.Tensor, +) -> None: + pass + + def _rocm_aiter_rmsnorm2d_fwd_with_add_impl( x: torch.Tensor, residual: torch.Tensor, @@ -510,6 +558,7 @@ class rocm_aiter_ops: _TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE _MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS _TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM + _FUSED_QK_NORM_ROPE = envs.VLLM_ROCM_USE_AITER_FUSED_QK_NORM_ROPE @classmethod @if_aiter_supported @@ -583,6 +632,10 @@ def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool: @classmethod @if_aiter_supported def is_triton_rotary_embed_enabled(cls) -> bool: + # Disable Triton RoPE when QK fusion is enabled, as the fusion needs + # the RoPE custom op to be present in the graph for pattern matching + if cls._AITER_ENABLED and cls._FUSED_QK_NORM_ROPE: + return False return cls._AITER_ENABLED and cls._TRITON_ROTARY_EMBED @classmethod @@ -590,6 +643,12 @@ def is_triton_rotary_embed_enabled(cls) -> bool: def is_triton_gemm_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._TRITON_UNQUANT_GEMM + @classmethod + @if_aiter_supported + def is_fused_qk_norm_rope_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._FUSED_QK_NORM_ROPE + @staticmethod @if_aiter_supported def register_ops_once() -> None: @@ -703,6 +762,14 @@ def register_ops_once() -> None: dispatch_key=current_platform.dispatch_key, ) + direct_register_custom_op( + op_name="rocm_aiter_fused_qk_norm_rope", + op_func=_rocm_aiter_fused_qk_norm_rope_impl, + mutates_args=["qkv"], + fake_impl=_rocm_aiter_fused_qk_norm_rope_fake, + dispatch_key=current_platform.dispatch_key, + ) + _OPS_REGISTERED = True @staticmethod @@ -722,6 +789,34 @@ def rms_norm( ) -> torch.Tensor: return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon) + @staticmethod + def fused_qk_norm_rope( + qkv: torch.Tensor, + num_heads_q: int, + num_heads_k: int, + num_heads_v: int, + head_dim: int, + eps: float, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + cos_sin_cache: torch.Tensor, + is_neox: bool, + position_ids: torch.Tensor, + ) -> None: + torch.ops.vllm.rocm_aiter_fused_qk_norm_rope( + qkv, + num_heads_q, + num_heads_k, + num_heads_v, + head_dim, + eps, + q_weight, + k_weight, + cos_sin_cache, + is_neox, + position_ids, + ) + @staticmethod def gemm_a8w8( A: torch.Tensor, @@ -949,14 +1044,14 @@ def triton_rotary_embed( key_ = key[..., :rotary_dim] positions = positions.view(*query.shape[:1]) rope_cached_thd_positions_2c_fwd_inplace( - positions, - sin, - cos, - query_, - key_, - rotate_style, + positions=positions, + sin=sin, + cos=cos, + x=query_, + y=key_, + rotate_style=rotate_style, reuse_freqs_front_part=True, - is_nope_first=False, + nope_first=False, ) query = query.view(query_shape) key = key.view(key_shape) diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index 904a7ca39272..23c60d162254 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -27,8 +27,8 @@ RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default if _USE_AITER_RMS_NORM: - RMS_OP = rocm_aiter_ops.rms_norm - RMS_ADD_OP = rocm_aiter_ops.rms_norm2d_with_add + RMS_OP = torch.ops.vllm.rocm_aiter_rms_norm.default + RMS_ADD_OP = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default ROTARY_OP = torch.ops._C.rotary_embedding.default FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default diff --git a/vllm/compilation/qk_norm_rope_fusion.py b/vllm/compilation/qk_norm_rope_fusion.py index e3c399e07906..db85702a967a 100644 --- a/vllm/compilation/qk_norm_rope_fusion.py +++ b/vllm/compilation/qk_norm_rope_fusion.py @@ -9,10 +9,12 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass +from vllm._aiter_ops import rocm_aiter_ops from vllm.attention import Attention from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +from vllm.platforms import current_platform from .fusion import empty_bf16, empty_fp32, empty_i64 from .inductor_pass import enable_fake_mode @@ -21,7 +23,18 @@ logger = init_logger(__name__) -FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default + +def get_fused_qknorm_rope_op(): + use_aiter = ( + current_platform.is_rocm() + and rocm_aiter_ops.is_fused_qk_norm_rope_enabled() + and hasattr(torch.ops.vllm, "rocm_aiter_fused_qk_norm_rope") + ) + + if use_aiter: + return torch.ops.vllm.rocm_aiter_fused_qk_norm_rope.default + else: + return torch.ops._C.fused_qk_norm_rope.default class QkNormRopePattern: @@ -72,6 +85,8 @@ def __init__( use_flashinfer=self.rope_flashinfer, ) + self.fused_qk_norm_rope_op = get_fused_qknorm_rope_op() + def get_inputs(self): # Sample inputs to help pattern tracing T = 5 @@ -146,7 +161,7 @@ def replacement( ): # Run fused qk_norm_rope op result = auto_functionalized( - FUSED_QK_ROPE_OP, + self.fused_qk_norm_rope_op, qkv=qkv, num_heads_q=self.num_heads, num_heads_k=self.num_kv_heads, diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index a4344e41bf14..59d4bf4a6391 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -195,7 +195,7 @@ def __post_init__(self) -> None: if ( self.enable_aiter_allreduce_rmsnorm_fusion - and not current_platform.is_rcom() + and not current_platform.is_rocm() ): logger.warning_once( "AITER all-reduce + RMSNorm fusion enabled but the current platform" diff --git a/vllm/envs.py b/vllm/envs.py index 7e71a70ab1f1..b1a4033d83a7 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -120,6 +120,7 @@ VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = False VLLM_ROCM_USE_AITER_TRITON_GEMM: bool = True + VLLM_ROCM_USE_AITER_FUSED_QK_NORM_ROPE: bool = False VLLM_ROCM_USE_SKINNY_GEMM: bool = True VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True @@ -960,9 +961,9 @@ def get_vllm_port() -> int | None: os.getenv("VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", "False").lower() in ("true", "1") ), # Whether to use aiter rope. - # By default is enabled. + # By default is disabled. "VLLM_ROCM_USE_AITER_TRITON_ROPE": lambda: ( - os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "True").lower() in ("true", "1") + os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "False").lower() in ("true", "1") ), # Whether to use aiter triton fp8 bmm kernel # By default is enabled. @@ -985,6 +986,12 @@ def get_vllm_port() -> int | None: "VLLM_ROCM_USE_AITER_TRITON_GEMM": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_TRITON_GEMM", "True").lower() in ("true", "1") ), + # Whether to use aiter fused qk norm rope kernel. + # By default is disabled. + "VLLM_ROCM_USE_AITER_FUSED_QK_NORM_ROPE": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_FUSED_QK_NORM_ROPE", "False").lower() + in ("true", "1") + ), # use rocm skinny gemms "VLLM_ROCM_USE_SKINNY_GEMM": lambda: ( os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in ("true", "1")