diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index d0ba8385f4a0..c8464df3f88d 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -1,10 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import itertools + import pytest import torch import vllm.plugins +from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.matcher_utils import QUANT_OPS @@ -18,6 +21,7 @@ VllmConfig, ) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, QuantKey, @@ -119,13 +123,79 @@ def ops_in_model_before_partial(self): ) +class TestAiterRmsnormGroupFp8QuantModel(torch.nn.Module): + def __init__(self, hidden_size: int, eps: float, **kwargs): + super().__init__() + self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( + weight_group_shape=GroupShape(128, 128), + act_quant_group_shape=GroupShape(1, 128), + cutlass_block_fp8_supported=False, + use_aiter_and_is_supported=True, + ) + self.w = [ + torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() + for _ in range(3) + ] + + scale_hidden_size = (hidden_size + 128 - 1) // 128 + self.wscale = [ + torch.rand((scale_hidden_size, scale_hidden_size), dtype=torch.float32) + for _ in range(3) + ] + + self.norm_weight = [torch.ones(hidden_size) for _ in range(4)] + self.eps = eps + + def forward(self, x): + # avoid having graph input be an arg to a pattern directly + x = resid = torch.relu(x) + y = rocm_aiter_ops.rms_norm(x, self.norm_weight[0], self.eps) + + x2 = self.w8a8_block_fp8_linear.apply(y, self.w[0], self.wscale[0]) + # make sure resid is used for replacement to work + y2, resid = 
rocm_aiter_ops.rms_norm2d_with_add( + x2, resid, self.norm_weight[1], self.eps + ) + + x3 = self.w8a8_block_fp8_linear.apply(y2, self.w[1], self.wscale[1]) + + y3, resid = rocm_aiter_ops.rms_norm2d_with_add( + x3, resid, self.norm_weight[2], self.eps + ) + + x4 = self.w8a8_block_fp8_linear.apply(y3, self.w[2], self.wscale[2]) + + y4, resid = rocm_aiter_ops.rms_norm2d_with_add( + x4, resid, self.norm_weight[3], self.eps + ) + return y4 + + def ops_in_model_before(self): + return [ + torch.ops.vllm.rocm_aiter_rms_norm, + torch.ops.vllm.rocm_aiter_group_fp8_quant, + ] + + def ops_in_model_before_partial(self): + return [] + + def ops_in_model_after(self): + return [ + torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant, + torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant, + ] + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("hidden_size", [64]) +@pytest.mark.parametrize("hidden_size", [256]) @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("static", [True, False]) -@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) -@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False]) +@pytest.mark.parametrize( + "model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op", + list(itertools.product([TestModel], [True, False], [True, False])) + + [(TestAiterRmsnormGroupFp8QuantModel, False, False)], +) # cuda_force_torch used to test torch code path on platforms that # cutlass_fp8_supported() == True. 
@pytest.mark.parametrize( @@ -140,10 +210,14 @@ def test_fusion_rmsnorm_quant( num_tokens, eps, static, + model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op, cuda_force_torch, ): + if model_class is TestAiterRmsnormGroupFp8QuantModel and not IS_AITER_FOUND: + pytest.skip("AITER is not supported on this GPU.") + torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) @@ -167,12 +241,24 @@ def test_fusion_rmsnorm_quant( with vllm.config.set_current_vllm_config(vllm_config): # Reshape pass is needed for the fusion pass to work noop_pass = NoOpEliminationPass(vllm_config) - fusion_pass = RMSNormQuantFusionPass(vllm_config) + if model_class is TestAiterRmsnormGroupFp8QuantModel: + from vllm.compilation.rocm_aiter_fusion import ( + RocmAiterRMSNormFp8GroupQuantFusionPass, + ) + + fusion_pass = RocmAiterRMSNormFp8GroupQuantFusionPass(vllm_config) + else: + fusion_pass = RMSNormQuantFusionPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config) backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) backend2 = TestBackend(noop_pass, cleanup_pass) - model = TestModel(hidden_size, eps, static, cuda_force_torch) + model = model_class( + hidden_size=hidden_size, + eps=eps, + static=static, + cuda_force_torch=cuda_force_torch, + ) # First dimension dynamic x = torch.rand(num_tokens, hidden_size) @@ -202,7 +288,10 @@ def test_fusion_rmsnorm_quant( # there's a risk that the fused add doesn't get included in the # replacement and only the rms part gets fused with quant. # Hence, we check only 2 add nodes are left (final fused rmsnorm add). 
- if not enable_rms_norm_custom_op: + if ( + not enable_rms_norm_custom_op + and model_class is not TestAiterRmsnormGroupFp8QuantModel + ): n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g)) # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each) assert n_add_nodes(backend.graph_pre_pass) == 7 diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index c336a45955cb..b414ea9dd6c6 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -7,6 +7,7 @@ import vllm.envs as envs from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor +from vllm._aiter_ops import IS_AITER_FOUND from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.compilation.activation_quant_fusion import ( FUSED_OPS, @@ -24,6 +25,7 @@ set_current_vllm_config, ) from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, kFp8StaticTensorSym, @@ -126,6 +128,39 @@ def ops_in_model_after(self): return [FUSED_OPS[kNvfp4Quant]] +class TestAiterSiluMulGroupFp8QuantModel(torch.nn.Module): + def __init__(self, hidden_size: int, **kwargs): + super().__init__() + self.silu_and_mul = SiluAndMul() + self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( + weight_group_shape=GroupShape(128, 128), + act_quant_group_shape=GroupShape(1, 128), + cutlass_block_fp8_supported=False, + use_aiter_and_is_supported=True, + ) + self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() + + scale_hidden_size = (hidden_size + 128 - 1) // 128 + self.wscale = torch.rand( + (scale_hidden_size, scale_hidden_size), dtype=torch.float32 + ) + + self.enable_silu_mul_custom_op = self.silu_and_mul.enabled() + + def forward(self, x): + y = self.silu_and_mul(x) + x2 = self.w8a8_block_fp8_linear.apply(y, self.w, 
self.wscale) + return x2 + + def ops_in_model_before(self): + return [ + SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul, + ] + + def ops_in_model_after(self): + return [torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant] + + @pytest.mark.parametrize("num_tokens", [32, 64]) @pytest.mark.parametrize("hidden_size", [128, 256]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @@ -133,7 +168,10 @@ def ops_in_model_after(self): @pytest.mark.parametrize( "model_class, enable_quant_fp8_custom_op, cuda_force_torch", list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False])) - + [(TestSiluMulNvfp4QuantModel, False, False)], + + [ + (TestSiluMulNvfp4QuantModel, False, False), + (TestAiterSiluMulGroupFp8QuantModel, False, False), + ], ) # cuda_force_torch used to test torch code path on platforms that # cutlass_fp8_supported() == True. @@ -144,13 +182,19 @@ def test_fusion_silu_and_mul_quant( num_tokens: int, hidden_size: int, dtype: torch.dtype, - model_class: type[TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel], + model_class: type[ + TestSiluMulFp8QuantModel + | TestSiluMulNvfp4QuantModel + | TestAiterSiluMulGroupFp8QuantModel + ], enable_silu_mul_custom_op: bool, enable_quant_fp8_custom_op: bool, cuda_force_torch: bool, ): if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported(): pytest.skip("NVFP4 is not supported on this GPU.") + if model_class is TestAiterSiluMulGroupFp8QuantModel and not IS_AITER_FOUND: + pytest.skip("AITER is not supported on this GPU.") torch.set_default_device("cuda") torch.set_default_dtype(dtype) @@ -174,6 +218,12 @@ def test_fusion_silu_and_mul_quant( with set_current_vllm_config(config): fusion_pass = ActivationQuantFusionPass(config) + if model_class == TestAiterSiluMulGroupFp8QuantModel: + from vllm.compilation.rocm_aiter_fusion import ( + RocmAiterSiluMulFp8GroupQuantFusionPass, + ) + + fusion_pass = RocmAiterSiluMulFp8GroupQuantFusionPass(config) 
passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)] backend = TestBackend(*passes) @@ -194,6 +244,8 @@ def test_fusion_silu_and_mul_quant( atol, rtol = 1e-3, 1e-3 elif model_class == TestSiluMulNvfp4QuantModel: atol, rtol = 1e-1, 1e-1 + elif model_class == TestAiterSiluMulGroupFp8QuantModel: + atol, rtol = 5e-2, 5e-2 torch.testing.assert_close( result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 35920d826578..344f0e76bc44 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -22,6 +22,15 @@ def is_aiter_found() -> bool: # we keep this global outside to not cause torch compile breaks. IS_AITER_FOUND = is_aiter_found() +# Can't use dtypes.fp8 directly inside an op +# because it returns wrong result on gfx942. +# This is a workaround to get the correct FP8 dtype. +# This might because that the get_gfx() is wrapped as a custom op. +if IS_AITER_FOUND: + from aiter import dtypes + + AITER_FP8_DTYPE = dtypes.fp8 + def if_aiter_supported(func: Callable) -> Callable: """Decorator that only executes the function if @@ -43,36 +52,6 @@ def wrapper(*args, **kwargs): return wrapper -def _rocm_aiter_group_fp8_quant_impl( - x: torch.Tensor, - group_size: int, -) -> tuple[torch.Tensor, torch.Tensor]: - assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size" - from aiter import QuantType, dtypes, get_hip_quant - - aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128) - return aiter_per1x128_quant(x.contiguous(), quant_dtype=dtypes.fp8) - - -def _rocm_aiter_group_fp8_quant_fake( - x: torch.Tensor, - group_size: int, -) -> tuple[torch.Tensor, torch.Tensor]: - from aiter import dtypes - - M, N = x.shape - x_fp8 = torch.empty((M, N), dtype=dtypes.fp8, device=x.device) - out_bs = torch.empty( - ( - M, - (N + group_size - 1) // group_size, - ), - dtype=torch.float32, - device=x.device, - ) - return x_fp8, out_bs - - def 
_rocm_aiter_fused_moe_impl( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -392,6 +371,31 @@ def _rocm_aiter_gemm_a8w8_fake( return Y +def _rocm_aiter_triton_gemm_a8w8_blockscale_impl( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + from aiter.ops.triton.gemm.basic.gemm_a8w8_blockscale import gemm_a8w8_blockscale + + return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype) + + +def _rocm_aiter_triton_gemm_a8w8_blockscale_fake( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + m = A.shape[0] + n = B.shape[0] + Y = torch.empty(m, n, dtype=output_dtype, device=A.device) + return Y + + def _rocm_aiter_gemm_a8w8_blockscale_impl( A: torch.Tensor, B: torch.Tensor, @@ -417,6 +421,54 @@ def _rocm_aiter_gemm_a8w8_blockscale_fake( return Y +def _rocm_aiter_triton_gemm_afp4wfp4_impl( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + from aiter.ops.triton.gemm.basic.gemm_afp4wfp4 import gemm_afp4wfp4 + + return gemm_afp4wfp4(A, B, As, Bs.T, dtype=output_dtype) + + +def _rocm_aiter_triton_gemm_afp4wfp4_fake( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + m = A.shape[0] + n = B.shape[0] + Y = torch.empty(m, n, dtype=output_dtype, device=A.device) + return Y + + +def _rocm_aiter_triton_gemm_a16w8_blockscale_impl( + A: torch.Tensor, + B: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + from aiter.ops.triton.gemm.basic.gemm_a16w8_blockscale import gemm_a16w8_blockscale + + return gemm_a16w8_blockscale(A, B, Bs, dtype=output_dtype) + + +def _rocm_aiter_triton_gemm_a16w8_blockscale_fake( + A: torch.Tensor, + B: torch.Tensor, + Bs: torch.Tensor, + 
output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + m = A.shape[0] + n = B.shape[0] + Y = torch.empty(m, n, dtype=output_dtype, device=A.device) + return Y + + def _rocm_aiter_rms_norm_impl( x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float ) -> torch.Tensor: @@ -467,6 +519,454 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake( return torch.empty_like(x), torch.empty_like(residual) +def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + from aiter.ops.triton.quant.fused_fp8_quant import fused_rms_fp8_group_quant + + (x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant( + x, + weight, + variance_epsilon, + None, + None, + None, + group_size=group_size, + dtype_quant=AITER_FP8_DTYPE, + res1=residual, + ) + return (x_quant, x_quant_scales, res) + + +def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + M, N = x.shape + scale_shape = (M, (N + group_size - 1) // group_size) + return ( + torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device), + torch.empty(scale_shape, dtype=torch.float32, device=x.device), + torch.empty_like(residual, device=residual.device), + ) + + +def _rocm_aiter_rmsnorm_fp8_group_quant_impl( + x: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + from aiter.ops.triton.quant.fused_fp8_quant import fused_rms_fp8_group_quant + + (x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant( + x, + weight, + variance_epsilon, + None, + None, + None, + group_size=group_size, + dtype_quant=AITER_FP8_DTYPE, + res1=None, + ) + return (x_quant, x_quant_scales) + + +def 
_rocm_aiter_rmsnorm_fp8_group_quant_fake( + x: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + M, N = x.shape + scale_shape = (M, (N + group_size - 1) // group_size) + return ( + torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device), + torch.empty(scale_shape, dtype=torch.float32, device=x.device), + ) + + +def _rocm_aiter_2rmsnorm_1fp8_group_quant_impl( + x1: torch.Tensor, + x2: torch.Tensor, + weight1: torch.Tensor, + variance_epsilon1: float, + weight2: torch.Tensor, + variance_epsilon2: float, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + from aiter.ops.triton.quant.fused_fp8_quant import fused_rms_fp8_group_quant + + (x_quant, x_quant_scales), _, x2_out, _ = fused_rms_fp8_group_quant( + x1, + weight1, + variance_epsilon1, + x2, + weight2, + variance_epsilon2, + group_size=group_size, + dtype_quant=AITER_FP8_DTYPE, + res1=None, + ) + return (x_quant, x_quant_scales, x2_out) + + +def _rocm_aiter_2rmsnorm_1fp8_group_quant_fake( + x1: torch.Tensor, + x2: torch.Tensor, + weight1: torch.Tensor, + variance_epsilon1: float, + weight2: torch.Tensor, + variance_epsilon2: float, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + M, N = x1.shape + scale_shape = ( + M, + (N + group_size - 1) // group_size, + ) + return ( + torch.empty_like(x1, dtype=AITER_FP8_DTYPE, device=x1.device), + torch.empty(scale_shape, dtype=torch.float32, device=x1.device), + torch.empty_like(x2, dtype=x2.dtype, device=x1.device), + ) + + +def _rocm_aiter_group_fp8_quant_impl( + x: torch.Tensor, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size" + from aiter import QuantType, get_hip_quant + + aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128) + return aiter_per1x128_quant(x.contiguous(), quant_dtype=AITER_FP8_DTYPE) + + +def _rocm_aiter_group_fp8_quant_fake( 
+ x: torch.Tensor, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + M, N = x.shape + x_fp8 = torch.empty((M, N), dtype=AITER_FP8_DTYPE, device=x.device) + out_bs = torch.empty( + ( + M, + (N + group_size - 1) // group_size, + ), + dtype=torch.float32, + device=x.device, + ) + return x_fp8, out_bs + + +def _rocm_aiter_act_mul_and_fp8_group_quant_impl( + x: torch.Tensor, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + from aiter.ops.triton.activation import act_mul_and_fp8_group_quant + + return act_mul_and_fp8_group_quant( + x, + activation="silu", + group_size=group_size, + dtype_quant=AITER_FP8_DTYPE, + ) + + +def _rocm_aiter_act_mul_and_fp8_group_quant_fake( + x: torch.Tensor, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + M, N = x.shape + assert N % 2 == 0 + N_half = N // 2 + x_fp8 = torch.empty((M, N_half), dtype=AITER_FP8_DTYPE, device=x.device) + out_bs = torch.empty( + ( + M, + (N_half + group_size - 1) // group_size, + ), + dtype=torch.float32, + device=x.device, + ) + return x_fp8, out_bs + +def _rocm_aiter_triton_fused_shared_expert_fp8_impl( + hidden_states_shared: torch.Tensor, + hidden_states_shared_scale: torch.Tensor, + weight_gate_up: torch.Tensor, + weight_scale_gate_up: torch.Tensor, + hidden_states_moe_gate: torch.Tensor, + weight_moe_gate: torch.Tensor, + bias_shared: torch.Tensor, + bias_moe_gate: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + from aiter.ops.triton.gemm.fused.fused_gemm_a8w8_blockscale_a16w16 import fused_gemm_a8w8_blockscale_a16w16 + from aiter.ops.triton.quant.fused_fp8_quant import fused_reduce_act_mul_fp8_group_quant + + shared_output, router_logits = fused_gemm_a8w8_blockscale_a16w16(hidden_states_shared, weight_gate_up, hidden_states_shared_scale, weight_scale_gate_up, hidden_states_moe_gate, weight_moe_gate, + bias_fp8=bias_shared, bias_bf16=bias_moe_gate, dtype=hidden_states_moe_gate.dtype, skip_reduce=True) + if shared_output.dim() == 3: + 
(shared_output_q, shared_output_s), router_logits = fused_reduce_act_mul_fp8_group_quant(shared_output, activation="silu", x2=router_logits, group_size=128, dtype_quant=AITER_FP8_DTYPE) + else: + (shared_output_q, shared_output_s), _ = fused_reduce_act_mul_fp8_group_quant(shared_output, activation="silu", x2=None, group_size=128, dtype_quant=AITER_FP8_DTYPE) + return shared_output_q, shared_output_s, router_logits + +def _rocm_aiter_triton_fused_shared_expert_fp8_fake( + hidden_states_shared: torch.Tensor, + hidden_states_shared_scale: torch.Tensor, + weight_gate_up: torch.Tensor, + weight_scale_gate_up: torch.Tensor, + hidden_states_moe_gate: torch.Tensor, + weight_moe_gate: torch.Tensor, + bias_shared: torch.Tensor, + bias_moe_gate: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + M = hidden_states_shared.shape[0] + N = weight_gate_up.shape[0] + N_moe = weight_moe_gate.shape[0] + device = hidden_states_shared.device + group_size = 128 + assert N % 2 == 0 + N_half = N // 2 + assert N_half == N_moe, f"{weight_moe_gate.shape}" + shared_output_q = torch.empty((M, N_half), dtype=AITER_FP8_DTYPE, device=device) + shared_output_s = torch.empty((M, (N_half + group_size - 1) // group_size), dtype=torch.float32, device=device) + router_logits = torch.empty((M, N_moe), dtype=hidden_states_moe_gate.dtype, device=device) + return shared_output_q, shared_output_s, router_logits + + +def _rocm_aiter_triton_fused_down_proj_mul_add_fp8_impl( + hidden_states_shared: torch.Tensor, + hidden_states_shared_scale: torch.Tensor, + weight_down_proj: torch.Tensor, + weight_scale_down_proj: torch.Tensor, + routed_scaling_factor: float, + final_hidden_states: torch.Tensor, +) -> torch.Tensor: + from aiter.ops.triton.gemm.fused.fused_gemm_a8w8_blockscale_mul_add import fused_gemm_a8w8_blockscale_mul_add + + out = fused_gemm_a8w8_blockscale_mul_add(hidden_states_shared, weight_down_proj, hidden_states_shared_scale, weight_scale_down_proj, routed_scaling_factor, 
final_hidden_states, fuse_type=1) + return out + + +def _rocm_aiter_triton_fused_down_proj_mul_add_fp8_fake( + hidden_states_shared: torch.Tensor, + hidden_states_shared_scale: torch.Tensor, + weight_down_proj: torch.Tensor, + weight_scale_down_proj: torch.Tensor, + routed_scaling_factor: float, + final_hidden_states: torch.Tensor, +) -> torch.Tensor: + out = torch.empty_like(final_hidden_states) + return out + +def _rocm_aiter_triton_fused_shared_expert_fp4_impl( + hidden_states_shared: torch.Tensor, + hidden_states_shared_scale: torch.Tensor, + weight_gate_up: torch.Tensor, + weight_scale_gate_up: torch.Tensor, + hidden_states_moe_gate: torch.Tensor, + weight_moe_gate: torch.Tensor, + bias_shared: torch.Tensor, + bias_moe_gate: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + from aiter.ops.triton.gemm.fused.fused_gemm_afp4wfp4_a16w16 import fused_gemm_afp4wfp4_a16w16 + from aiter.ops.triton.quant.fused_mxfp4_quant import fused_reduce_act_mul_and_mxfp4_quant + + shared_output, router_logits = fused_gemm_afp4wfp4_a16w16(hidden_states_shared, weight_gate_up, hidden_states_shared_scale, weight_scale_gate_up.T, hidden_states_moe_gate, weight_moe_gate, + is_fp4_preshuffled=False, bias_fp4=bias_shared, bias_bf16=bias_moe_gate, dtype=hidden_states_moe_gate.dtype, skip_reduce=True) + if shared_output.dim() == 3: + (shared_output_q, shared_output_s), router_logits = fused_reduce_act_mul_and_mxfp4_quant(shared_output, activation="silu", x2=router_logits, shuffle=False, scale_shuffle_padding=False, dtype=hidden_states_moe_gate.dtype) + else: + (shared_output_q, shared_output_s), _ = fused_reduce_act_mul_and_mxfp4_quant(shared_output, activation="silu", x2=None, shuffle=False, scale_shuffle_padding=False, dtype=hidden_states_moe_gate.dtype) + + # assert bias_shared is None + # shared_output = gemm_afp4wfp4(hidden_states_shared, weight_gate_up, hidden_states_shared_scale, weight_scale_gate_up.T) + # router_logits = gemm_a16w16(hidden_states_moe_gate, 
weight_moe_gate, bias=bias_moe_gate) + # shared_output_q, shared_output_s = act_mul_and_mxfp4_quant(shared_output, activation="silu", shuffle=False, scale_shuffle_padding=False) + + return shared_output_q, shared_output_s, router_logits + + +def _rocm_aiter_triton_fused_shared_expert_fp4_fake( + hidden_states_shared: torch.Tensor, + hidden_states_shared_scale: torch.Tensor, + weight_gate_up: torch.Tensor, + weight_scale_gate_up: torch.Tensor, + hidden_states_moe_gate: torch.Tensor, + weight_moe_gate: torch.Tensor, + bias_shared: torch.Tensor, + bias_moe_gate: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + M = hidden_states_shared.shape[0] + N = weight_gate_up.shape[0] + N_moe = weight_moe_gate.shape[0] + device = hidden_states_shared.device + group_size = 32 + assert N % 4 == 0 + N_half = N // 2 + assert N_half == 256, f"{weight_gate_up.shape}" + assert N_half == N_moe, f"{weight_moe_gate.shape}" + shared_output_q = torch.empty((M, N_half // 2), dtype=torch.uint8, device=device) + shared_output_s = torch.empty((M, (N_half + group_size - 1) // group_size), dtype=torch.uint8, device=device) + router_logits = torch.empty((M, N_moe), dtype=hidden_states_moe_gate.dtype, device=device) + return shared_output_q, shared_output_s, router_logits + + +def _rocm_aiter_triton_fused_down_proj_mul_add_fp4_impl( + hidden_states_shared: torch.Tensor, + hidden_states_shared_scale: torch.Tensor, + weight_down_proj: torch.Tensor, + weight_scale_down_proj: torch.Tensor, + routed_scaling_factor: float, + final_hidden_states: torch.Tensor, +) -> torch.Tensor: + from aiter.ops.triton.gemm.fused.fused_gemm_afp4wfp4_mul_add import fused_gemm_afp4wfp4_mul_add + + out = fused_gemm_afp4wfp4_mul_add(hidden_states_shared, weight_down_proj, hidden_states_shared_scale, weight_scale_down_proj.T, routed_scaling_factor, final_hidden_states, fuse_type=1) + return out + + +def _rocm_aiter_triton_fused_down_proj_mul_add_fp4_fake( + hidden_states_shared: torch.Tensor, + 
hidden_states_shared_scale: torch.Tensor, + weight_down_proj: torch.Tensor, + weight_scale_down_proj: torch.Tensor, + routed_scaling_factor: float, + final_hidden_states: torch.Tensor, +) -> torch.Tensor: + out = torch.empty_like(final_hidden_states) + return out + +def _rocm_aiter_triton_qkv_a_proj_layernorm_fp8_impl( + hidden_states_quant: torch.Tensor, + hidden_states_quant_scale: torch.Tensor, + weight_qkv_a_proj: torch.Tensor, + weight_scale_qkv_a_proj: torch.Tensor, + q_a_layernorm_weight: torch.Tensor, + q_a_layernorm_variance_epsilon: float, + kv_a_layernorm_weight: torch.Tensor, + kv_a_layernorm_variance_epsilon: float, + q_lora_rank: int, + kv_lora_rank: int, + qk_rope_head_dim: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + from aiter.ops.triton.quant.fused_fp8_quant import fused_reduce_rms_fp8_group_quant + from aiter.ops.triton.gemm.basic.gemm_a8w8_blockscale import gemm_a8w8_blockscale + import aiter as rocm_aiter + qkv_lora = gemm_a8w8_blockscale(hidden_states_quant, weight_qkv_a_proj, hidden_states_quant_scale, weight_scale_qkv_a_proj, skip_reduce=True) + q_c, kv_c, k_pe = qkv_lora.split([q_lora_rank, kv_lora_rank, qk_rope_head_dim], + dim=-1, + ) + k_pe_reduced = None + k_pe_reduced_out = None + if k_pe.dim() == 3: + M = hidden_states_quant.shape[0] + device = hidden_states_quant.device + k_pe_reduced = k_pe + k_pe_reduced_out = torch.empty((M, q_lora_rank + kv_lora_rank + qk_rope_head_dim), dtype=torch.bfloat16, device=device)[..., :qk_rope_head_dim] + (q_c, q_c_scale), _, kv_c_normed, _, k_pe_reduced_out = fused_reduce_rms_fp8_group_quant(q_c, q_a_layernorm_weight, q_a_layernorm_variance_epsilon, + kv_c, kv_a_layernorm_weight, kv_a_layernorm_variance_epsilon, k_pe_reduced, + group_size=128, + dtype_quant=AITER_FP8_DTYPE, + dtype=torch.bfloat16, + res1=None, + out3=k_pe_reduced_out) + if k_pe_reduced_out is not None: + k_pe = k_pe_reduced_out + return q_c, q_c_scale, kv_c_normed, k_pe + +def 
_rocm_aiter_triton_qkv_a_proj_layernorm_fp8_fake( + hidden_states_quant: torch.Tensor, + hidden_states_quant_scale: torch.Tensor, + weight_qkv_a_proj: torch.Tensor, + weight_scale_qkv_a_proj: torch.Tensor, + q_a_layernorm_weight: torch.Tensor, + q_a_layernorm_variance_epsilon: float, + kv_a_layernorm_weight: torch.Tensor, + kv_a_layernorm_variance_epsilon: float, + q_lora_rank: int, + kv_lora_rank: int, + qk_rope_head_dim: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + M = hidden_states_quant.shape[0] + device = hidden_states_quant.device + q_c = torch.empty((M, q_lora_rank), dtype=AITER_FP8_DTYPE, device=device) + q_c_scale = torch.empty((M, (q_lora_rank + 128 - 1) // 128), dtype=torch.float32, device=device) + kv_c_normed = torch.empty((M, kv_lora_rank), dtype=torch.bfloat16, device=device) + k_pe = torch.empty((M, q_lora_rank + kv_lora_rank + qk_rope_head_dim), dtype=torch.bfloat16, device=device)[..., :qk_rope_head_dim] + return q_c, q_c_scale, kv_c_normed, k_pe + +def _rocm_aiter_triton_qkv_a_proj_layernorm_fp4_impl( + hidden_states_quant: torch.Tensor, + hidden_states_quant_scale: torch.Tensor, + weight_qkv_a_proj: torch.Tensor, + weight_scale_qkv_a_proj: torch.Tensor, + q_a_layernorm_weight: torch.Tensor, + q_a_layernorm_variance_epsilon: float, + kv_a_layernorm_weight: torch.Tensor, + kv_a_layernorm_variance_epsilon: float, + q_lora_rank: int, + kv_lora_rank: int, + qk_rope_head_dim: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + from aiter.ops.triton.gemm.basic.gemm_afp4wfp4 import gemm_afp4wfp4 + from aiter.ops.triton.quant.fused_mxfp4_quant import fused_reduce_rms_mxfp4_quant + + qkv_lora = gemm_afp4wfp4(hidden_states_quant, weight_qkv_a_proj, hidden_states_quant_scale, weight_scale_qkv_a_proj.T, skip_reduce=True) + q_c, kv_c, k_pe = qkv_lora.split([q_lora_rank, kv_lora_rank, qk_rope_head_dim], + dim=-1, + ) + k_pe_reduced = None + k_pe_reduced_out = None + if k_pe.dim() == 3: + M = 
hidden_states_quant.shape[0] + device = hidden_states_quant.device + k_pe_reduced = k_pe + k_pe_reduced_out = torch.empty((M, q_lora_rank + kv_lora_rank + qk_rope_head_dim), dtype=torch.bfloat16, device=device)[..., :qk_rope_head_dim] + (q_c, q_c_scale), _, kv_c_normed, _, k_pe_reduced_out = fused_reduce_rms_mxfp4_quant(q_c, q_a_layernorm_weight, q_a_layernorm_variance_epsilon, + kv_c, kv_a_layernorm_weight, kv_a_layernorm_variance_epsilon, k_pe_reduced, + res1=None, + shuffle=False, + scale_shuffle_padding=False, + dtype=torch.bfloat16, + out3=k_pe_reduced_out) + + if k_pe_reduced_out is not None: + k_pe = k_pe_reduced_out + return q_c, q_c_scale, kv_c_normed, k_pe + +def _rocm_aiter_triton_qkv_a_proj_layernorm_fp4_fake( + hidden_states_quant: torch.Tensor, + hidden_states_quant_scale: torch.Tensor, + weight_qkv_a_proj: torch.Tensor, + weight_scale_qkv_a_proj: torch.Tensor, + q_a_layernorm_weight: torch.Tensor, + q_a_layernorm_variance_epsilon: float, + kv_a_layernorm_weight: torch.Tensor, + kv_a_layernorm_variance_epsilon: float, + q_lora_rank: int, + kv_lora_rank: int, + qk_rope_head_dim: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + M = hidden_states_quant.shape[0] + device = hidden_states_quant.device + q_c = torch.empty((M, q_lora_rank // 2), dtype=torch.uint8, device=device) + q_c_scale = torch.empty((M, (q_lora_rank + 32 - 1) // 32), dtype=torch.float32, device=device) + kv_c_normed = torch.empty((M, kv_lora_rank), dtype=torch.bfloat16, device=device) + k_pe = torch.empty((M, q_lora_rank + kv_lora_rank + qk_rope_head_dim), dtype=torch.bfloat16, device=device)[..., :qk_rope_head_dim] + return q_c, q_c_scale, kv_c_normed, k_pe + # Global flag to ensure ops are registered only once _OPS_REGISTERED = False @@ -481,9 +981,11 @@ class rocm_aiter_ops: _MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA _TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM + _FP4BMM_ENABLED 
= envs.VLLM_ROCM_USE_AITER_FP4BMM _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM _TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE _MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS + _TRITON_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_TRITON_FUSION_SHARED_EXPERTS _TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM @classmethod @@ -502,7 +1004,7 @@ def is_linear_enabled(cls) -> bool: @if_aiter_supported def is_linear_fp8_enaled(cls) -> bool: """ "Verifies device specs and availability of env variable.""" - return cls.is_linear_enabled() and current_platform.is_fp8_fnuz() + return cls.is_linear_enabled() @classmethod @if_aiter_supported @@ -520,6 +1022,11 @@ def is_fused_moe_enabled(cls) -> bool: @if_aiter_supported def is_fusion_moe_shared_experts_enabled(cls) -> bool: return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED + + @classmethod + @if_aiter_supported + def is_fusion_triton_shared_experts_enabled(cls) -> bool: + return cls.is_fused_moe_enabled() and cls._TRITON_SHARED_EXPERTS_ENABLED @classmethod @if_aiter_supported @@ -549,6 +1056,12 @@ def is_triton_unified_attn_enabled(cls) -> bool: @if_aiter_supported def is_fp8bmm_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._FP8BMM_ENABLED + + @classmethod + @if_aiter_supported + def is_fp4bmm_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._FP4BMM_ENABLED and current_platform.supports_mx() @classmethod @if_aiter_supported @@ -577,14 +1090,6 @@ def register_ops_once() -> None: ) # register all the custom ops here - direct_register_custom_op( - op_name="rocm_aiter_group_fp8_quant", - op_func=_rocm_aiter_group_fp8_quant_impl, - mutates_args=[], - fake_impl=_rocm_aiter_group_fp8_quant_fake, - dispatch_key=current_platform.dispatch_key, - ) - direct_register_custom_op( op_name="rocm_aiter_asm_moe_tkw1", op_func=_rocm_aiter_asm_moe_tkw1_impl, @@ 
-641,30 +1146,122 @@ def register_ops_once() -> None: dispatch_key=current_platform.dispatch_key, ) + direct_register_custom_op( + op_name="rocm_aiter_triton_gemm_a8w8_blockscale", + op_func=_rocm_aiter_triton_gemm_a8w8_blockscale_impl, + fake_impl=_rocm_aiter_triton_gemm_a8w8_blockscale_fake, + ) + direct_register_custom_op( op_name="rocm_aiter_gemm_a8w8_blockscale", op_func=_rocm_aiter_gemm_a8w8_blockscale_impl, - mutates_args=[], fake_impl=_rocm_aiter_gemm_a8w8_blockscale_fake, - dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_triton_gemm_afp4wfp4", + op_func=_rocm_aiter_triton_gemm_afp4wfp4_impl, + fake_impl=_rocm_aiter_triton_gemm_afp4wfp4_fake, + ) + + direct_register_custom_op( + op_name="rocm_aiter_triton_gemm_a16w8_blockscale", + op_func=_rocm_aiter_triton_gemm_a16w8_blockscale_impl, + fake_impl=_rocm_aiter_triton_gemm_a16w8_blockscale_fake, ) direct_register_custom_op( op_name="rocm_aiter_rms_norm", op_func=_rocm_aiter_rms_norm_impl, - mutates_args=[], fake_impl=_rocm_aiter_rms_norm_fake, - dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( op_name="rocm_aiter_rmsnorm2d_fwd_with_add", op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl, - mutates_args=[], fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake, + ) + + direct_register_custom_op( + op_name="rocm_aiter_rmsnorm_fp8_group_quant", + op_func=_rocm_aiter_rmsnorm_fp8_group_quant_impl, + fake_impl=_rocm_aiter_rmsnorm_fp8_group_quant_fake, + ) + + direct_register_custom_op( + op_name="rocm_aiter_2rmsnorm_1fp8_group_quant", + op_func=_rocm_aiter_2rmsnorm_1fp8_group_quant_impl, + mutates_args=[], + fake_impl=_rocm_aiter_2rmsnorm_1fp8_group_quant_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_rmsnorm_with_add_fp8_group_quant", + op_func=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl, + fake_impl=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake, + ) + + 
direct_register_custom_op( + op_name="rocm_aiter_group_fp8_quant", + op_func=_rocm_aiter_group_fp8_quant_impl, + fake_impl=_rocm_aiter_group_fp8_quant_fake, + ) + + direct_register_custom_op( + op_name="rocm_aiter_act_mul_and_fp8_group_quant", + op_func=_rocm_aiter_act_mul_and_fp8_group_quant_impl, + fake_impl=_rocm_aiter_act_mul_and_fp8_group_quant_fake, + ) + + direct_register_custom_op( + op_name="rocm_aiter_triton_fused_shared_expert_fp8", + op_func=_rocm_aiter_triton_fused_shared_expert_fp8_impl, + mutates_args=[], + fake_impl=_rocm_aiter_triton_fused_shared_expert_fp8_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_triton_fused_down_proj_mul_add_fp8", + op_func=_rocm_aiter_triton_fused_down_proj_mul_add_fp8_impl, + mutates_args=[], + fake_impl=_rocm_aiter_triton_fused_down_proj_mul_add_fp8_fake, dispatch_key=current_platform.dispatch_key, ) + + direct_register_custom_op( + op_name="rocm_aiter_triton_fused_shared_expert_fp4", + op_func=_rocm_aiter_triton_fused_shared_expert_fp4_impl, + mutates_args=[], + fake_impl=_rocm_aiter_triton_fused_shared_expert_fp4_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_triton_fused_down_proj_mul_add_fp4", + op_func=_rocm_aiter_triton_fused_down_proj_mul_add_fp4_impl, + mutates_args=[], + fake_impl=_rocm_aiter_triton_fused_down_proj_mul_add_fp4_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_triton_qkv_a_proj_layernorm_fp8", + op_func=_rocm_aiter_triton_qkv_a_proj_layernorm_fp8_impl, + mutates_args=[], + fake_impl=_rocm_aiter_triton_qkv_a_proj_layernorm_fp8_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_triton_qkv_a_proj_layernorm_fp4", + op_func=_rocm_aiter_triton_qkv_a_proj_layernorm_fp4_impl, + mutates_args=[], + fake_impl=_rocm_aiter_triton_qkv_a_proj_layernorm_fp4_fake, + 
dispatch_key=current_platform.dispatch_key, + ) _OPS_REGISTERED = True @staticmethod @@ -867,7 +1464,7 @@ def triton_fp4_gemm_dynamic_qaunt( out_dtype: torch.dtype | None = torch.bfloat16, x_scales: torch.Tensor | None = None, ) -> torch.Tensor: - from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4 + from aiter.ops.triton.gemm.basic.gemm_afp4wfp4 import gemm_afp4wfp4 from aiter.ops.triton.quant import dynamic_mxfp4_quant if x_scales is None: @@ -893,7 +1490,7 @@ def triton_rotary_embed( rotary_dim: int, is_neox_style: bool, ): - from aiter.ops.triton.rope import rope_cached_thd_positions_2c_fwd_inplace + from aiter.ops.triton.rope.rope import rope_cached_thd_positions_2c_fwd_inplace num_tokens = positions.numel() cos, sin = cos_sin_cache.chunk(2, dim=-1) @@ -919,6 +1516,34 @@ def triton_rotary_embed( query = query.view(query_shape) key = key.view(key_shape) + @staticmethod + def triton_fp4_bmm( + X: torch.Tensor, + WQ: torch.Tensor, + w_scale: torch.Tensor, + dtype: torch.dtype | None = torch.bfloat16, + YQ: torch.Tensor | None = None, + transpose_bm: bool | None = False, + config: dict | None = None, + y_scale: dict | None = None, + ) -> torch.Tensor: + # ruff: noqa: E501 # isort: skip + from aiter.ops.triton.gemm.batched.batched_gemm_a16wfp4 import ( + batched_gemm_a16wfp4 as aiter_triton_fp4_bmm, + ) + + return aiter_triton_fp4_bmm( + X, + WQ, + w_scale, + dtype=dtype, + y=YQ, + config=config, + transpose_bm=transpose_bm, + prequant=True, + y_scale=y_scale, + ) + @staticmethod def triton_fp8_bmm( X: torch.Tensor, @@ -933,7 +1558,7 @@ def triton_fp8_bmm( config: dict | None = None, ) -> torch.Tensor: # ruff: noqa: E501 # isort: skip - from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( + from aiter.ops.triton.gemm.batched.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm, ) @@ -959,9 
+1584,20 @@ def triton_gemm_a8w8_blockscale( block_size: list[int], output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: - from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale - - return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype) + return torch.ops.vllm.rocm_aiter_triton_gemm_a8w8_blockscale( + A, B, As, Bs, output_dtype + ) + + @staticmethod + def triton_gemm_a16w8_blockscale( + A: torch.Tensor, + B: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float16, + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_triton_gemm_a16w8_blockscale( + A, B, Bs, output_dtype + ) @staticmethod def group_fp8_quant( @@ -1041,7 +1677,5 @@ def shuffle_weights( """ from aiter.ops.shuffle import shuffle_weight - return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors) - - + return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors) rocm_aiter_ops.register_ops_once() diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index da5a62617129..edac42d4da82 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -185,6 +185,7 @@ def __init__( attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: str | None = None, attn_backend: type[AttentionBackend] | None = None, + rotary_emb: nn.Module | None = None, **extra_impl_args, ) -> None: """ @@ -260,6 +261,7 @@ def __init__( kv_sharing_target_layer_name, **extra_impl_args, ) + self.impl.rotary_emb = rotary_emb backend_name = self.attn_backend.get_name() self.backend = AttentionBackendEnum.__members__.get(backend_name) self.dtype = dtype @@ -316,6 +318,7 @@ def forward( # shape does not match the query shape, so we optionally let the model # definition specify the output tensor shape. 
output_shape: torch.Size | None = None, + positions: torch.Tensor | None = None, ) -> torch.Tensor: """ The KV cache is stored inside this class and is accessed via @@ -365,7 +368,7 @@ def forward( ) else: torch.ops.vllm.unified_attention_with_output( - query, key, value, output, self.layer_name + query, key, value, output, self.layer_name, positions=positions ) return output.view(-1, hidden_size) else: @@ -587,6 +590,7 @@ def __init__( prefix: str = "", use_sparse: bool = False, indexer: object | None = None, + rotary_emb: nn.Module | None = None, **extra_impl_args, ): super().__init__() @@ -646,6 +650,7 @@ def __init__( indexer=indexer, **extra_impl_args, ) + self.impl.rotary_emb = rotary_emb self.use_direct_call = not current_platform.opaque_attention_op() @@ -674,6 +679,7 @@ def forward( kv_c_normed: torch.Tensor, k_pe: torch.Tensor, output_shape: torch.Size | None = None, + positions: torch.Tensor | None = None, ) -> torch.Tensor: if self.calculate_kv_scales: torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name) @@ -710,6 +716,7 @@ def forward( k_pe, output, self.layer_name, + positions, ) return output else: @@ -864,8 +871,25 @@ def unified_attention_with_output( layer_name: str, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ) -> None: attn_metadata, self, kv_cache = get_attention_context(layer_name) + if positions is not None: + assert hasattr(self.impl, "rotary_emb") and self.impl.rotary_emb is not None + self.impl.forward( + self, + query, + key, + value, + kv_cache, + attn_metadata, + output=output, + output_scale=output_scale, + output_block_scale=output_block_scale, + positions=positions, + ) + return + self.impl.forward( self, query, @@ -887,6 +911,7 @@ def unified_attention_with_output_fake( layer_name: str, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ) -> None: 
return @@ -937,21 +962,37 @@ def unified_mla_attention_with_output( k_pe: torch.Tensor, output: torch.Tensor, layer_name: str, + positions: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> None: attn_metadata, self, kv_cache = get_attention_context(layer_name) - self.impl.forward( - self, - q, - kv_c_normed, - k_pe, - kv_cache, - attn_metadata, - output=output, - output_scale=output_scale, - output_block_scale=output_block_scale, - ) + from vllm.v1.attention.backends.mla.rocm_aiter_mla import AiterMLAImpl + if isinstance(self.impl, AiterMLAImpl): + self.impl.forward( + self, + q, + kv_c_normed, + k_pe, + kv_cache, + attn_metadata, + output=output, + output_scale=output_scale, + output_block_scale=output_block_scale, + positions=positions, + ) + else: + self.impl.forward( + self, + q, + kv_c_normed, + k_pe, + kv_cache, + attn_metadata, + output=output, + output_scale=output_scale, + output_block_scale=output_block_scale, + ) def unified_mla_attention_with_output_fake( diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 37f48721ea20..d285478b7218 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -5,6 +5,7 @@ from torch import fx as fx from vllm import envs +from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig, set_current_vllm_config from vllm.logger import init_logger from vllm.platforms import current_platform @@ -13,6 +14,12 @@ from .post_cleanup import PostCleanupPass from .vllm_inductor_pass import VllmInductorPass +if rocm_aiter_ops.is_enabled(): + from vllm.compilation.rocm_aiter_fusion import ( + RocmAiterRMSNormFp8GroupQuantFusionPass, + RocmAiterSiluMulFp8GroupQuantFusionPass, + ) + if current_platform.is_cuda_alike(): from .activation_quant_fusion import ActivationQuantFusionPass from .fusion import RMSNormQuantFusionPass @@ -107,6 +114,9 @@ def configure(self, config: VllmConfig): 
self.passes += [RMSNormQuantFusionPass(config)] if self.pass_config.fuse_act_quant: self.passes += [ActivationQuantFusionPass(config)] + if rocm_aiter_ops.is_enabled(): + self.passes += [RocmAiterRMSNormFp8GroupQuantFusionPass(config)] + self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)] if self.pass_config.fuse_attn_quant: self.passes += [AttnFusionPass(config)] diff --git a/vllm/compilation/rocm_aiter_fusion.py b/vllm/compilation/rocm_aiter_fusion.py new file mode 100644 index 000000000000..fd5bf8fb3de2 --- /dev/null +++ b/vllm/compilation/rocm_aiter_fusion.py @@ -0,0 +1,336 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + +import torch +import torch._inductor.pattern_matcher as pm +from torch import fx +from torch._inductor.pattern_matcher import PatternMatcherPass +from torch._ops import OpOverload + +import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401 +from vllm.compilation.activation_quant_fusion import ActivationQuantPattern +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.platforms import current_platform + +from .fusion import empty_bf16 +from .inductor_pass import enable_fake_mode +from .matcher_utils import MatcherSiluAndMul +from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass + +logger = init_logger(__name__) +FP8_DTYPE = current_platform.fp8_dtype() + +AITER_RMS_GROUP_QUANT_OP = torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default +AITER_RMS_ADD_GROUP_QUANT_OP = ( + torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default +) + +AITER_RMS_OP = torch.ops.vllm.rocm_aiter_rms_norm.default +AITER_2RMS_1GROUP_QUANT_OP = torch.ops.vllm.rocm_aiter_2rmsnorm_1fp8_group_quant.default +AITER_RMS_ADD_OP = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default + +AITER_GROUP_FP8_QUANT_OP = torch.ops.vllm.rocm_aiter_group_fp8_quant.default +TRITON_GROUP_FP8_QUANT_OP = 
torch.ops.vllm.triton_per_token_group_quant_fp8.default + +FUSED_SILU_MUL_QUANT_OP = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default + +SPLIT_WITH_SIZES_OP = torch.ops.aten.split_with_sizes.default + +class AiterRMSFp8GroupQuantPattern: + """ + This pattern fuses aiter rms_norm & group fp8 quant custom + ops into an aiter rms_norm_group_fp8_quant op. + """ + + def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload): + self.epsilon = epsilon + self.quant_dtype = quant_dtype + self.quant_op = quant_op + + def register(self, pm_pass: PatternMatcherPass): + def pattern( + input: torch.Tensor, + weight: torch.Tensor, + ): + at1 = AITER_RMS_OP(x=input, weight=weight, variance_epsilon=self.epsilon) + + at2 = self.quant_op(at1, 128) + + return at2[0], at2[1] + + def replacement( + input: torch.Tensor, + weight: torch.Tensor, + ): + at = AITER_RMS_GROUP_QUANT_OP( + x=input, + weight=weight, + variance_epsilon=self.epsilon, + group_size=128, + ) + + return at[0], at[1] + + inputs = [ + empty_bf16(5, 4), # input + empty_bf16(1, 5), # weight + ] + + pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) + + +class Aiter2RMS1GroupQuantFP8Pattern: + """ + This pattern fuses aiter rms_norm & group fp8 quant custom for input1 and + rms_norm for input2 + ops into an aiter rms_norm_group_fp8_quant op. 
+ """ + + def __init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + quant_op: OpOverload, + hidden_size1: int, + hidden_size2: int, + hidden_size3: int, + ): + self.epsilon = epsilon + self.quant_dtype = quant_dtype + self.quant_op = quant_op + self.hidden_size1 = hidden_size1 + self.hidden_size2 = hidden_size2 + self.hidden_size3 = hidden_size3 + + def register(self, pm_pass: PatternMatcherPass): + def pattern( + input: torch.Tensor, + weight1: torch.Tensor, + weight2: torch.Tensor, + ): + input1, input_split_0 = SPLIT_WITH_SIZES_OP( + input, + [self.hidden_size1, self.hidden_size2 + self.hidden_size3], + dim=-1, + ) + input2, at4 = SPLIT_WITH_SIZES_OP( + input_split_0, [self.hidden_size2, self.hidden_size3], dim=-1 + ) + + at1 = AITER_RMS_OP(x=input1, weight=weight1, variance_epsilon=self.epsilon) + at2 = self.quant_op(at1, 128) + at3 = AITER_RMS_OP(x=input2, weight=weight2, variance_epsilon=self.epsilon) + + return at2[0], at2[1], at3, at4 + + def replacement( + input: torch.Tensor, + weight1: torch.Tensor, + weight2: torch.Tensor, + ): + input1, input_split_0 = SPLIT_WITH_SIZES_OP( + input, + [self.hidden_size1, self.hidden_size2 + self.hidden_size3], + dim=-1, + ) + input2, at4 = SPLIT_WITH_SIZES_OP( + input_split_0, [self.hidden_size2, self.hidden_size3], dim=-1 + ) + + at = AITER_2RMS_1GROUP_QUANT_OP( + x1=input1, + x2=input2, + weight1=weight1, + variance_epsilon1=self.epsilon, + weight2=weight2, + variance_epsilon2=self.epsilon, + group_size=128, + ) + + return at[0], at[1], at[2], at4 + + inputs = [ + empty_bf16( + 5, self.hidden_size1 + self.hidden_size2 + self.hidden_size3 + ), # input + empty_bf16(1, self.hidden_size1), # weight1 + empty_bf16(1, self.hidden_size2), # weight2 + ] + + pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) + + +class AiterFusedAddRMSFp8GroupQuantPattern: + """ + This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops + into a aiter rms_norm_with_add_group_fp8_quant op. 
+ """ + + def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload): + self.epsilon = epsilon + self.quant_dtype = quant_dtype + self.quant_op = quant_op + + def register(self, pm_pass: PatternMatcherPass): + def pattern( + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + ): + at1 = AITER_RMS_ADD_OP( + x=input, + residual=residual, + weight=weight, + variance_epsilon=self.epsilon, + ) + + at2 = self.quant_op(at1[0], 128) + + # result, scale, residual + return at2[0], at2[1], at1[1] + + def replacement( + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + ): + at = AITER_RMS_ADD_GROUP_QUANT_OP( + x=input, + residual=residual, + weight=weight, + variance_epsilon=self.epsilon, + group_size=128, + ) + + # result, scale, residual + return at[0], at[1], at[2] + + inputs = [ + empty_bf16(5, 4), # input + empty_bf16(5, 4), # residual + empty_bf16(1, 5), # weight + ] + + pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) + + +class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass): + """ + This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op. + It also supports fused_add_rms_norm. 
+ """ + + @enable_fake_mode + def __init__(self, config: VllmConfig): + super().__init__(config) + + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="rocm_aiter_rms_norm_fp8_group_quant_fusion_pass" + ) + + # Make sure fused add patterns are before simple rms norm, + # as the latter is a subset of the former in torch ops + for epsilon in [1e-5, 1e-6]: + # Fuse rms_norm + dynamic group fp8 quant + for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]: + for hidden_size1, hidden_size2, hidden_size3 in [(1536, 512, 64)]: + Aiter2RMS1GroupQuantFP8Pattern( + epsilon, + FP8_DTYPE, + quant_op, + hidden_size1, + hidden_size2, + hidden_size3, + ).register(self.patterns) + + AiterRMSFp8GroupQuantPattern(epsilon, FP8_DTYPE, quant_op).register( + self.patterns + ) + + AiterFusedAddRMSFp8GroupQuantPattern( + epsilon, FP8_DTYPE, quant_op + ).register(self.patterns) + + self.dump_patterns(config, self.patterns) + + @VllmInductorPass.time_and_log + def __call__(self, graph: fx.Graph): + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) + + def uuid(self) -> Any: + fusion_patterns = [ + Aiter2RMS1GroupQuantFP8Pattern, + AiterRMSFp8GroupQuantPattern, + AiterFusedAddRMSFp8GroupQuantPattern, + ] + return self.hash_source(self, *fusion_patterns) + + +class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern): + """ + This pattern fuses aiter silu_and_mul & group fp8 quant custom + ops into an aiter silu_and_mul_group_fp8_quant op. 
+ """ + + def __init__(self, quant_op: OpOverload): + self.silu_and_mul_matcher = MatcherSiluAndMul() + self.quant_op = quant_op + + def register(self, pm_pass: PatternMatcherPass): + def pattern( + input: torch.Tensor, + ): + at1 = self.silu_and_mul_matcher(input) + at2 = self.quant_op(at1, 128) + return at2[0], at2[1] + + def replacement( + input: torch.Tensor, + ): + at = FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128) + return at[0], at[1] + + inputs = [ + self.silu_and_mul_matcher.inputs()[0], + ] + + pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) + + +class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass): + """ + This pass fuses a pre-defined set of custom ops into fused ops. + It uses the torch pattern matcher to find the patterns and replace them. + + Because patterns can only be registered once, the pass is a singleton. + This will be addressed in a future version of PyTorch: + https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980 + """ + + @enable_fake_mode + def __init__(self, config: VllmConfig): + super().__init__(config) + + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass" + ) + + for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]: + AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns) + + self.dump_patterns(config, self.patterns) + + @VllmInductorPass.time_and_log + def __call__(self, graph: torch.fx.Graph): + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) + + def uuid(self): + fusion_patterns = [ + ActivationQuantPattern, + AiterSiluMulFp8GroupQuantPattern, + ] + return VllmInductorPass.hash_source(self, *fusion_patterns) diff --git a/vllm/envs.py b/vllm/envs.py index 37711dece9ab..aa49fa2edf58 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -117,8 +117,10 @@ VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False 
VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False VLLM_ROCM_USE_AITER_FP8BMM: bool = True + VLLM_ROCM_USE_AITER_FP4BMM: bool = False VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = False + VLLM_ROCM_USE_AITER_TRITON_FUSION_SHARED_EXPERTS: bool = False VLLM_ROCM_USE_AITER_TRITON_GEMM: bool = True VLLM_ROCM_USE_SKINNY_GEMM: bool = True VLLM_ROCM_FP8_PADDING: bool = True @@ -981,6 +983,11 @@ def get_vllm_port() -> int | None: "VLLM_ROCM_USE_AITER_FP8BMM": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in ("true", "1") ), + # Whether to use aiter triton fp4 bmm kernel + # By default is disabled. + "VLLM_ROCM_USE_AITER_FP4BMM": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_FP4BMM", "False").lower() in ("true", "1") + ), # Use AITER triton unified attention for V1 attention "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "False").lower() @@ -992,6 +999,12 @@ def get_vllm_port() -> int | None: os.getenv("VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS", "False").lower() in ("true", "1") ), + # Whether to use aiter fusion triton shared experts ops. + # By default is disabled. + "VLLM_ROCM_USE_AITER_TRITON_FUSION_SHARED_EXPERTS": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSION_SHARED_EXPERTS", "False").lower() + in ("true", "1") + ), # Whether to use aiter triton kernels for gemm ops. # By default is enabled. 
"VLLM_ROCM_USE_AITER_TRITON_GEMM": lambda: ( diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index dad960160f2a..8998a249fa43 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -8,7 +8,8 @@ from vllm.config import CacheConfig from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.quantization import QuantizationConfig - +from vllm.platforms import current_platform +from vllm._aiter_ops import rocm_aiter_ops @dataclass class MLAModules: @@ -103,54 +104,94 @@ def __init__( kv_b_proj=self.kv_b_proj, use_sparse=self.is_sparse, indexer=self.indexer, + rotary_emb = self.rotary_emb if current_platform.is_rocm() else None ) + self.use_aiter_triton = rocm_aiter_ops.is_enabled() + self.use_triton_qkv_a_proj_layernrom_fp8 = rocm_aiter_ops.is_enabled() and quant_config.get_name() == 'fp8' + self.use_triton_qkv_a_proj_layernrom_fp4 = rocm_aiter_ops.is_enabled() and quant_config.get_name() == 'quark' self.prefix = prefix def forward_native( self, positions: torch.Tensor, - hidden_states: torch.Tensor, + hidden_states: torch.Tensor | tuple[torch.Tensor, torch.Tensor], llama_4_scaling: torch.Tensor | None = None, ) -> torch.Tensor: q_c = None kv_lora = None - if self.q_lora_rank is not None: - assert self.fused_qkv_a_proj is not None, ( - "fused_qkv_a_proj is required when q_lora_rank is not None" - ) - assert self.q_a_layernorm is not None, ( - "q_a_layernorm is required when q_lora_rank is not None" - ) - assert self.q_b_proj is not None, ( - "q_b_proj is required when q_lora_rank is not None" - ) - qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] - q_c, kv_lora = qkv_lora.split( - [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], - dim=-1, - ) - q_c = self.q_a_layernorm(q_c) - q = self.q_b_proj(q_c)[0] + if self.use_triton_qkv_a_proj_layernrom_fp4: + assert self.q_lora_rank is not None + assert isinstance(hidden_states, tuple) + hidden_states, 
hidden_states_scales = hidden_states + q_c, q_c_scale, kv_c_normed, k_pe = torch.ops.vllm.rocm_aiter_triton_qkv_a_proj_layernorm_fp4( + hidden_states_quant=hidden_states, + hidden_states_quant_scale=hidden_states_scales, + weight_qkv_a_proj=self.fused_qkv_a_proj.weight, + weight_scale_qkv_a_proj=self.fused_qkv_a_proj.weight_scale, + q_a_layernorm_weight=self.q_a_layernorm.weight, + q_a_layernorm_variance_epsilon=self.q_a_layernorm.variance_epsilon, + kv_a_layernorm_weight=self.kv_a_layernorm.weight, + kv_a_layernorm_variance_epsilon=self.kv_a_layernorm.variance_epsilon, + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_rope_head_dim=self.qk_rope_head_dim) + q = torch.ops.vllm.rocm_aiter_triton_gemm_afp4wfp4(q_c, self.q_b_proj.weight, q_c_scale, self.q_b_proj.weight_scale, output_dtype=torch.bfloat16) + elif self.use_triton_qkv_a_proj_layernrom_fp8: + assert self.q_lora_rank is not None + assert isinstance(hidden_states, tuple) + hidden_states, hidden_states_scales = hidden_states + q_c, q_c_scale, kv_c_normed, k_pe = torch.ops.vllm.rocm_aiter_triton_qkv_a_proj_layernorm_fp8( + hidden_states_quant=hidden_states, + hidden_states_quant_scale=hidden_states_scales, + weight_qkv_a_proj=self.fused_qkv_a_proj.weight, + weight_scale_qkv_a_proj=self.fused_qkv_a_proj.weight_scale, + q_a_layernorm_weight=self.q_a_layernorm.weight, + q_a_layernorm_variance_epsilon=self.q_a_layernorm.variance_epsilon, + kv_a_layernorm_weight=self.kv_a_layernorm.weight, + kv_a_layernorm_variance_epsilon=self.kv_a_layernorm.variance_epsilon, + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_rope_head_dim=self.qk_rope_head_dim) + q = torch.ops.vllm.rocm_aiter_triton_gemm_a8w8_blockscale(q_c, self.q_b_proj.weight, q_c_scale, self.q_b_proj.weight_scale, output_dtype=torch.bfloat16) else: - assert self.kv_a_proj_with_mqa is not None, ( - "kv_a_proj_with_mqa is required when q_lora_rank is None" - ) - assert self.q_proj is not None, ( - "q_proj is required when 
q_lora_rank is None" - ) - kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0] - q = self.q_proj(hidden_states)[0] - - kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - kv_c_normed = self.kv_a_layernorm(kv_c) + assert isinstance(hidden_states, torch.Tensor) + if self.q_lora_rank is not None: + assert self.fused_qkv_a_proj is not None, ( + "fused_qkv_a_proj is required when q_lora_rank is not None" + ) + assert self.q_a_layernorm is not None, ( + "q_a_layernorm is required when q_lora_rank is not None" + ) + assert self.q_b_proj is not None, ( + "q_b_proj is required when q_lora_rank is not None" + ) + qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] + q_c, kv_lora = qkv_lora.split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + dim=-1, + ) + q_c = self.q_a_layernorm(q_c) + q = self.q_b_proj(q_c)[0] + else: + assert self.kv_a_proj_with_mqa is not None, ( + "kv_a_proj_with_mqa is required when q_lora_rank is None" + ) + assert self.q_proj is not None, ( + "q_proj is required when q_lora_rank is None" + ) + kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0] + q = self.q_proj(hidden_states)[0] + + kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c) q = q.view(-1, self.num_heads, self.qk_head_dim) # Add head dim of 1 to k_pe k_pe = k_pe.unsqueeze(1) - if self.rotary_emb is not None: + if self.rotary_emb is not None and not self.use_aiter_triton: q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb( positions, q[..., self.qk_nope_head_dim :], k_pe ) @@ -163,12 +204,13 @@ def forward_native( if llama_4_scaling is not None: q *= llama_4_scaling + positions_rocm = None if not self.use_aiter_triton else positions attn_out = self.mla_attn( q, kv_c_normed, k_pe, output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim), - ) + positions=positions_rocm) return self.o_proj(attn_out)[0] diff --git 
a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index 3640e5c45278..6da060315bdf 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -25,6 +25,7 @@ ) from vllm.model_executor.layers.quantization.quark.schemes import ( QuarkOCP_MX, + QuarkW16OCP_MX, QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8, @@ -108,7 +109,16 @@ def get_quant_method( if should_ignore_layer( prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping ): - return UnquantizedLinearMethod() + if current_platform.is_rocm(): + if prefix == "lm_head": + return UnquantizedLinearMethod() + + scheme = self.get_scheme(layer=layer, layer_name=prefix) + layer.scheme = scheme + return QuarkLinearMethod(self) + else: + return UnquantizedLinearMethod() + if isinstance(layer, LinearBase): scheme = self.get_scheme(layer=layer, layer_name=prefix) layer.scheme = scheme @@ -376,7 +386,7 @@ def _matches_pattern(layer_name, pattern): ) return global_quant_config - def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme": + def _get_scheme_from_config(self, config: dict[str, Any], layer_name: str) -> "QuarkScheme": if config.get("output_tensors") or config.get("bias"): raise NotImplementedError( "Currently, Quark models with output_tensors " @@ -399,6 +409,14 @@ def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme": input_symmetric=input_config.get("symmetric"), ) elif self._is_ocp_mx(weight_config, input_config): + if current_platform.is_rocm(): + exclude_layers = cast(list[str], self.quant_config.get("exclude")) + if should_ignore_layer( + layer_name, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping + ): + return QuarkW16OCP_MX(weight_config, input_config) + return QuarkOCP_MX(weight_config, input_config) + return QuarkOCP_MX(weight_config, input_config) raise NotImplementedError( @@ -411,7 +429,7 @@ def 
get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme": layer_quant_config = self._find_matched_config(layer_name, layer) # Find the quant_scheme - scheme = self._get_scheme_from_config(layer_quant_config) + scheme = self._get_scheme_from_config(layer_quant_config, layer_name) # Raise error if device does not support the scheme # (e.g. fp8 needs ada lovelace) self._check_scheme_supported(scheme.get_min_capability()) diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 8be0299eaa66..b808acaa7862 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -625,6 +625,13 @@ def apply( rocm_aiter_fused_experts, ) + # if hasattr(torch, "float4_e2m1fn_x2"): + # w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2) + # w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2) + # else: + # w13_weight = layer.w13_weight + # w2_weight = layer.w2_weight + out = rocm_aiter_fused_experts( x, layer.w13_weight, @@ -650,4 +657,4 @@ def apply( expert_map=expert_map, quant_config=self.moe_quant_config, ) - return out + return out \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/quark/schemes/__init__.py b/vllm/model_executor/layers/quantization/quark/schemes/__init__.py index 7620d6e41b58..7043467a1ec4 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/__init__.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from .quark_ocp_mx import QuarkOCP_MX +from .quark_ocp_mx import QuarkOCP_MX, QuarkW16OCP_MX from .quark_scheme import QuarkScheme from .quark_w8a8_fp8 import QuarkW8A8Fp8 from .quark_w8a8_int8 import QuarkW8A8Int8 -__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkOCP_MX"] 
+__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkOCP_MX", "QuarkW16OCP_MX"] diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py index eeb60023dc0e..f76d8b6a3c99 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py @@ -26,6 +26,8 @@ ) from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter from vllm.platforms import current_platform +from vllm.model_executor.parameter import ModelWeightParameter +from vllm.model_executor.utils import set_weight_attrs from .quark_scheme import QuarkScheme @@ -54,6 +56,7 @@ def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool: gemm_afp4wfp4, gemm_afp4wfp4_preshuffled_weight_scales, ) + from aiter.ops.triton.gemm_a16wfp4 import gemm_a16wfp4 from aiter.ops.triton.quant import dynamic_mxfp4_quant from vllm.utils.torch_utils import direct_register_custom_op @@ -123,6 +126,13 @@ def gemm_with_dynamic_quant( return y[:M] else: if x_scales is None: + if M <= 256 and weight.shape[0] == 7168 and x.shape[-1] == 2048: + y = torch.empty(M, + weight.shape[0], + device=x.device, + dtype=out_dtype) + gemm_a16wfp4(x, weight, weight_scale.T, dtype=out_dtype, y=y) + return y x_q, x_s = dynamic_mxfp4_quant(x) else: x_q = x @@ -341,3 +351,55 @@ def apply_weights( self.rocm_use_aiter_fp4_asm_gemm, self.out_dtype, ) + + +class QuarkW16OCP_MX(QuarkOCP_MX): + def __init__(self, weight_quant_spec: dict[str, Any], + input_quant_spec: dict[str, Any]): + self.out_dtype = torch.get_default_dtype() + self.qscheme = "per_group" + self.weight_quant_spec = weight_quant_spec + self.input_quant_spec = input_quant_spec + self.emulate = not current_platform.supports_mx() + self.rocm_use_aiter_fp4_asm_gemm = is_rocm_aiter_fp4_asm_gemm_enabled() + if not self.emulate and (dynamic_mxfp4_quant is None + or gemm_afp4wfp4 is 
None): + # Currently need these kernels if not emulating + raise NotImplementedError( + f"{self.__class__.__name__} requires AITER to be installed " + "for non-emulation mode! Please refer to " + "https://github.com/ROCm/aiter for installation details.") + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + # This method creates unquantized linear weights. + # The weights are not quantized, and they are not sharded. + # The amount of memory allocated for the weights is + # sum(output_partition_sizes) * input_size_per_partition. + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + # set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0, "weight_loader": weight_loader}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, kwargs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.weight = torch.nn.Parameter(layer.weight.data, + requires_grad=False) + w_q, w_s = dynamic_mxfp4_quant(layer.weight) + layer.weight_scale = torch.nn.Parameter( + w_s.T.contiguous(), + requires_grad=False) + layer.weight = torch.nn.Parameter(w_q, + requires_grad=False) + #super().process_weights_after_loading(layer) \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/quark/utils.py b/vllm/model_executor/layers/quantization/quark/utils.py index dc82f94ebbbf..ca20eccc262f 100644 --- a/vllm/model_executor/layers/quantization/quark/utils.py +++ b/vllm/model_executor/layers/quantization/quark/utils.py @@ -5,7 +5,10 @@ from types import MappingProxyType from typing import Any +import torch import regex as re +from torch import nn +from aiter.ops.triton.quant import dynamic_mxfp4_quant def deep_compare(dict1: Any, dict2: Any) -> 
bool: @@ -21,6 +24,102 @@ def deep_compare(dict1: Any, dict2: Any) -> bool: return dict1 == dict2 +# utility for tensor dims > 2 cases +def b_dynamic_mxfp4_quant(x): + h, b, d = x.shape + x, x_scales = dynamic_mxfp4_quant(x.reshape(-1, d)) + return x.view(h, b, d // 2), x_scales.view(h, b, d // 32) + #return x.view(h, b, d // 2), x_scales.view(h, b, d // 32) + + +def mxfp4_to_f32(x, is_threed): + # 2 because we pack fp4 in uint8. + x = x.repeat_interleave(2, dim=-1) + if is_threed: + x[..., ::2] = x[..., ::2] & 0xF + x[..., 1::2] = x[..., 1::2] >> 4 + else: + x[:, ::2] = x[:, ::2] & 0xF + x[:, 1::2] = x[:, 1::2] >> 4 + + mxfp4_list = [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ] + mxfp4_in_f32 = torch.tensor(mxfp4_list, dtype=torch.float32, device="cuda") + return mxfp4_in_f32[x.long()] + + +def e8m0_to_f32(x): + # Convert the input tensor `x` (assumed to be in e8m0 format) to float32. + # e8m0 is a custom 8-bit floating point format with 8 bits for exponent, 0 for mantissa. + # This means the value is essentially 2^(exponent - 127), similar to how IEEE-754 stores floats. + + # Convert x to float32 for computation, and compute the power of 2 by subtracting the bias (127). + x_f32 = 2 ** ((x.to(torch.float32)) - 127) + + # If the exponent value was 255 (i.e., 2^(128)), this is a special case usually used to represent NaN or Inf. + # Since this custom format has no mantissa, treat 2^128 as NaN. 
+    # Special-case: raw exponent byte 255 encodes NaN in e8m0. Mask on the
+    # raw input (x == 255), not on the computed float — 2**(255-127) = 2**128
+    # overflows float32 to inf and never equals 128, so the old mask
+    # `x_f32 == 128` matched nothing.
+    x_f32[x == 255] = float("nan")
+    return x_f32
+
+
+def quark_post_load_weights(self_attn: nn.Module, w: torch.Tensor, quant_format: str):
+    if "mxfp4" in quant_format:
+        # when dtype is bf16, the processing flow is to dynamic quantize bf16 tensor to uint8 tensor
+        # do w_kc (bf16) first to get the w_kc(uint8) w_s_kc(uint8)
+        # and w_vc repeating the same procedure of w_kc to get w_vc(uint8) w_s_vc(uint8)
+        if w.dtype == torch.bfloat16:
+            # w_kc, w_vc = w.split(
+            #     [self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+            w_kc, w_vc = w.unflatten(
+                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+            w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
+            w_kc = w_kc.transpose(-2, -1)
+            w_s_kc = w_s_kc.transpose(-2, -1)
+            w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
+            w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
+            w_s_vc = w_s_vc.contiguous().transpose(1, 2)
+        elif w.dtype == torch.uint8:  # static quant for mxfp4
+            # when dtype is uint8, it means the w has been quantized to mxfp4 format
+            # but we must separate it to w_kc and w_vc.
+ # The quantized tensor size is only half of original tensor size + # and the scaling factor is 1/32, the transpose behavior will be not correct + # need to upcast it to fp32 to separate w to w_kc and w_vc + # to ensure the following transpose behavior is correct + # and then do mxfp4 quant again + w = mxfp4_to_f32(w, True).to(torch.bfloat16) + w_scales = self_attn.kv_b_proj.weight_scale.repeat_interleave(32, dim=-1) + w_scales = e8m0_to_f32(w_scales).to(torch.bfloat16) + w = w * w_scales + w_kc, w_vc = w.unflatten( + 0, (-1, (self_attn.qk_nope_head_dim + self_attn.v_head_dim)) + ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) + w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1)) + w_kc = w_kc.transpose(-2, -1) + w_s_kc = w_s_kc.transpose(-2, -1) + w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc) + w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2) + w_s_vc = w_s_vc.contiguous().transpose(1, 2) + + return w_kc, w_s_kc, w_vc, w_s_vc + + def should_ignore_layer( layer_name: str | None, ignore: Iterable[str], diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 6e73833d1ae1..8c48de62d54f 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -194,6 +194,39 @@ def _fp8_gemm_nt_op_fake( ) +def _triton_per_token_group_quant_fp8_impl( + x: torch.Tensor, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + return per_token_group_quant_fp8( + x, group_size, column_major_scales=False, use_ue8m0=False + ) + + +def _triton_per_token_group_quant_fp8_fake( + x: torch.Tensor, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + M, N = x.shape + x_fp8 = torch.empty((M, N), dtype=current_platform.fp8_dtype(), device=x.device) + out_bs = torch.empty( + ( + M, + (N + group_size - 1) // group_size, + ), + dtype=torch.float32, + device=x.device, + ) + return x_fp8, 
out_bs + + +direct_register_custom_op( + "triton_per_token_group_quant_fp8", + _triton_per_token_group_quant_fp8_impl, + fake_impl=_triton_per_token_group_quant_fp8_fake, +) + + # TODO fix ROCm->Triton custom path: # https://github.com/vllm-project/vllm/issues/14397 class W8A8BlockFp8LinearOp: @@ -324,7 +357,6 @@ def _run_aiter( not current_platform.is_fp8_fnuz() and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k) ) - if use_triton: gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale else: @@ -332,17 +364,15 @@ def _run_aiter( if input_scale is not None: q_input = input_2d - # MI350 case uses triton kernel elif use_triton: - q_input, input_scale = per_token_group_quant_fp8( + q_input, input_scale = torch.ops.vllm.triton_per_token_group_quant_fp8( input_2d, self.act_quant_group_shape.col, - column_major_scales=False, - use_ue8m0=False, ) - # MI300 uses tuned AITER ASM/C++ kernel else: - q_input, input_scale = rocm_aiter_ops.group_fp8_quant(input_2d) + q_input, input_scale = rocm_aiter_ops.group_fp8_quant( + input_2d, self.act_quant_group_shape.col + ) return gemm_a8w8_blockscale_op( q_input, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index a8eb4a69b6f2..19717dacb776 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -50,6 +50,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -265,11 +266,15 @@ def __init__( quant_config=None, prefix=f"{prefix}.gate", ) - if getattr(config, "topk_method", None) == "noaux_tc": + self.is_fusion_triton_shared_experts_enabled = 
rocm_aiter_ops.is_fusion_triton_shared_experts_enabled() + if config.topk_method == "noaux_tc": self.gate.e_score_correction_bias = nn.Parameter( - torch.empty(config.n_routed_experts, dtype=torch.float32) - ) + torch.empty(config.n_routed_experts, dtype=torch.float32)) + e_score_correction_bias = self.gate.e_score_correction_bias + if self.is_fusion_triton_shared_experts_enabled: + e_score_correction_bias = self.gate.e_score_correction_bias.to(torch.bfloat16) else: + e_score_correction_bias = None self.gate.e_score_correction_bias = None # Load balancing settings. @@ -290,9 +295,46 @@ def __init__( self.is_fusion_moe_shared_experts_enabled = ( rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() ) - if config.n_shared_experts is None or self.is_fusion_moe_shared_experts_enabled: - self.shared_experts = None - else: + if self.is_fusion_triton_shared_experts_enabled: + self.use_triton_fused_shared_expert_fp8 = False + self.use_triton_fused_shared_expert_fp4 = False + self.rocm_aiter_triton_fused_shared_expert_func = None + self.rocm_aiter_triton_fused_down_proj_mul_add_func = None + if self.is_fusion_triton_shared_experts_enabled: + assert config.n_shared_experts is not None, f"config.n_shared_experts == None is detected in {self.__class__.__name__} please turn off VLLM_ROCM_USE_AITER_TRITON_FUSION_SHARED_EXPERTS" + if quant_config.get_name() == 'fp8': + self.use_triton_fused_shared_expert_fp8 = True + self.rocm_aiter_triton_fused_shared_expert_func = torch.ops.vllm.rocm_aiter_triton_fused_shared_expert_fp8 + self.rocm_aiter_triton_fused_down_proj_mul_add_func = torch.ops.vllm.rocm_aiter_triton_fused_down_proj_mul_add_fp8 + elif quant_config.get_name() == 'quark': + self.use_triton_fused_shared_expert_fp4 = True + self.rocm_aiter_triton_fused_shared_expert_func = torch.ops.vllm.rocm_aiter_triton_fused_shared_expert_fp4 + self.rocm_aiter_triton_fused_down_proj_mul_add_func = torch.ops.vllm.rocm_aiter_triton_fused_down_proj_mul_add_fp4 + else: + raise 
NotImplementedError(f"{quant_config.get_name()=} which is not supported for VLLM_ROCM_USE_AITER_TRITON_FUSION_SHARED_EXPERTS") + logger.info(f"[Aiter] {self.__class__.__name__} is registered with {self.rocm_aiter_triton_fused_shared_expert_func.__name__} and {self.rocm_aiter_triton_fused_down_proj_mul_add_func.__name__}") + + if self.is_fusion_triton_shared_experts_enabled: + self.experts = FusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=getattr(config, "n_group", 1), + topk_group=getattr(config, "topk_group", 1), + prefix=f"{prefix}.experts", + scoring_func=getattr(config, "scoring_func", "softmax"), + # we do scaling outside, set factor to 1.0 to avoid double mul + routed_scaling_factor=1.0, + e_score_correction_bias=e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + ) intermediate_size = config.moe_intermediate_size * config.n_shared_experts self.shared_experts = DeepseekV2MLP( @@ -304,37 +346,57 @@ def __init__( reduce_results=False, prefix=f"{prefix}.shared_experts", ) - - self.experts = SharedFusedMoE( - shared_experts=self.shared_experts, - gate=self.gate, - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=getattr(config, "n_group", 1), - topk_group=getattr(config, "topk_group", 1), - prefix=f"{prefix}.experts", - scoring_func=getattr(config, "scoring_func", "softmax"), - # we do scaling outside, set factor to 1.0 to avoid double mul - # aiter applies 
routed_scaling_factor internally - routed_scaling_factor=1.0 - if not self.is_rocm_aiter_moe_enabled - else self.routed_scaling_factor, - e_score_correction_bias=self.gate.e_score_correction_bias, - enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts, - is_sequence_parallel=self.is_sequence_parallel, - n_shared_experts=config.n_shared_experts - if self.is_fusion_moe_shared_experts_enabled - else None, + else: + if config.n_shared_experts is None: + self.shared_experts = None + else: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + + self.shared_experts = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + is_sequence_parallel=self.is_sequence_parallel, + reduce_results=False, + prefix=f"{prefix}.shared_experts", + ) + + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, + gate=self.gate, + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=getattr(config, "n_group", 1), + topk_group=getattr(config, "topk_group", 1), + prefix=f"{prefix}.experts", + scoring_func=getattr(config, "scoring_func", "softmax"), + # we do scaling outside, set factor to 1.0 to avoid double mul + # aiter applies routed_scaling_factor internally + routed_scaling_factor=1.0 + if not self.is_rocm_aiter_moe_enabled + else self.routed_scaling_factor, + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + n_shared_experts=config.n_shared_experts + if self.is_fusion_moe_shared_experts_enabled + else None, ) - def forward(self, hidden_states: torch.Tensor) 
-> torch.Tensor: + def forward(self, hidden_states: tuple | torch.Tensor) -> torch.Tensor: + if isinstance(hidden_states, tuple): + hidden_states_shared, hidden_states = hidden_states + else: + hidden_states_shared = hidden_states + num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) @@ -345,34 +407,67 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.is_sequence_parallel: hidden_states = sequence_parallel_chunk(hidden_states) - if self.experts.is_internal_router: - # In this case, the gate/router runs inside the FusedMoE class - fused_moe_out = self.experts( - hidden_states=hidden_states, router_logits=hidden_states + if self.is_fusion_triton_shared_experts_enabled: + # assert isinstance(hidden_states_shared, tuple), f"hidden_states_shared must be a tuple of quantized acitvation and scales" + shared_output = None + shared_output_q, shared_output_s = None, None + hidden_states_shared, hidden_states_shared_scale = hidden_states_shared + shared_output_q, shared_output_s, router_logits = ( + self.rocm_aiter_triton_fused_shared_expert_func( + hidden_states_shared=hidden_states_shared, + hidden_states_shared_scale=hidden_states_shared_scale, + weight_gate_up=self.shared_experts.gate_up_proj.weight, + weight_scale_gate_up=self.shared_experts.gate_up_proj.weight_scale, + hidden_states_moe_gate=hidden_states, + weight_moe_gate=self.gate.weight, + bias_shared=( + self.shared_experts.gate_up_proj.bias + if not self.shared_experts.gate_up_proj.skip_bias_add + else None + ), + bias_moe_gate=( + self.gate.bias if not self.gate.skip_bias_add else None + ), + ) ) - else: - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - fused_moe_out = self.experts( + # shared_output = self.shared_experts(hidden_states) + # router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits ) + else: + if 
self.experts.is_internal_router: + # In this case, the gate/router runs inside the FusedMoE class + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=hidden_states + ) + else: + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) - shared_output, final_hidden_states = fused_moe_out - if self.shared_experts is None: - assert shared_output is None + shared_output, final_hidden_states = fused_moe_out + if self.shared_experts is None: + assert shared_output is None # Fix FP16 overflow # See DeepseekV2DecoderLayer for more details. - if hidden_states.dtype != torch.float16: - if not self.is_rocm_aiter_moe_enabled: - final_hidden_states *= self.routed_scaling_factor - elif self.shared_experts is not None: - assert shared_output is not None - shared_output *= 1.0 / self.routed_scaling_factor + if self.is_fusion_triton_shared_experts_enabled and hidden_states.dtype != torch.float16: + assert shared_output is None + final_hidden_states = self.rocm_aiter_triton_fused_down_proj_mul_add_func(shared_output_q, shared_output_s, self.shared_experts.down_proj.weight, self.shared_experts.down_proj.weight_scale, self.routed_scaling_factor, final_hidden_states) + else: + if hidden_states.dtype != torch.float16: + if not self.is_rocm_aiter_moe_enabled: + final_hidden_states *= self.routed_scaling_factor + elif self.shared_experts is not None: + assert shared_output is not None + shared_output *= 1.0 / self.routed_scaling_factor - if self.shared_experts is not None: - assert shared_output is not None - final_hidden_states += shared_output + if self.shared_experts is not None: + assert shared_output is not None + final_hidden_states += shared_output if self.is_sequence_parallel: final_hidden_states = tensor_model_parallel_all_gather( @@ -1125,6 +1220,17 @@ def __init__( # with the layer's index. 
layer_idx = int(prefix.split(sep=".")[-1]) self.layer_idx = layer_idx + self.is_fusion_triton_shared_experts_enabled = rocm_aiter_ops.is_fusion_triton_shared_experts_enabled() + self.use_triton_fused_rmsnorm_fp8_quant = False + self.use_triton_fused_rmsnorm_fp4_quant = False + if rocm_aiter_ops.is_enabled(): + if quant_config.get_name() == 'fp8': + self.use_triton_fused_rmsnorm_fp8_quant = True + elif quant_config.get_name() == 'quark': + self.use_triton_fused_rmsnorm_fp4_quant = True + else: + raise NotImplementedError(f"{quant_config.get_name()=} which is not supported with the current version of AITER") + logger.info(f"[Aiter] {self.__class__.__name__} has {quant_config.get_name()=}") # verify MLA attention specific fields qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0) @@ -1193,11 +1299,51 @@ def forward( llama_4_scaling: torch.Tensor | None = None, ) -> torch.Tensor: # Self Attention - if residual is None: - residual = hidden_states.clone() - hidden_states = self.input_layernorm(hidden_states) + if self.use_triton_fused_rmsnorm_fp8_quant: + weight = self.input_layernorm.weight + eps = self.input_layernorm.variance_epsilon + from vllm._aiter_ops import AITER_FP8_DTYPE + from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant + if residual is None: + residual = hidden_states + (hidden_states_quant, hidden_states_quant_scales), _, _, _ = fused_rms_fp8_group_quant(hidden_states, weight, eps, + None, None, eps, + group_size=128, + dtype_quant=AITER_FP8_DTYPE, + res1=None) + else: + (hidden_states_quant, hidden_states_quant_scales), _, _, residual = fused_rms_fp8_group_quant(hidden_states, weight, eps, + None, None, eps, + group_size=128, + dtype_quant=AITER_FP8_DTYPE, + res1=residual) + hidden_states = (hidden_states_quant, hidden_states_quant_scales) + elif self.use_triton_fused_rmsnorm_fp4_quant: + weight = self.input_layernorm.weight + eps = self.input_layernorm.variance_epsilon + from aiter.ops.triton.fused_mxfp4_quant import 
fused_rms_mxfp4_quant + if residual is None: + residual = hidden_states + (hidden_states_quant, hidden_states_quant_scales), _, _, _ = fused_rms_mxfp4_quant(hidden_states, weight, eps, + None, None, eps, + res1=None, + shuffle=False, + scale_shuffle_padding=False, + output_unquantized_inp1=False) + else: + (hidden_states_quant, hidden_states_quant_scales), _, _, residual = fused_rms_mxfp4_quant(hidden_states, weight, eps, + None, None, eps, + res1=residual, + shuffle=False, + scale_shuffle_padding=False, + output_unquantized_inp1=False) + hidden_states = (hidden_states_quant, hidden_states_quant_scales) else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) + if residual is None: + residual = hidden_states.clone() + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) attn_kwargs = { "positions": positions, @@ -1221,7 +1367,34 @@ def forward( residual *= 1.0 / self.routed_scaling_factor # Fully Connected - hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + + if self.is_fusion_triton_shared_experts_enabled and isinstance(self.mlp, DeepseekV2MoE): + weight = self.post_attention_layernorm.weight + eps = self.post_attention_layernorm.variance_epsilon + if self.use_triton_fused_rmsnorm_fp8_quant: + from vllm._aiter_ops import AITER_FP8_DTYPE + from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant + (hidden_states_quant, hidden_states_quant_scales), hidden_states_unquant, _, residual = fused_rms_fp8_group_quant(hidden_states, weight, eps, + None, None, eps, + group_size=128, + dtype_quant=AITER_FP8_DTYPE, + res1=residual, + output_unquantized_inp1=isinstance(self.mlp, DeepseekV2MoE)) + else: + from aiter.ops.triton.fused_mxfp4_quant import fused_rms_mxfp4_quant + (hidden_states_quant, hidden_states_quant_scales), hidden_states_unquant, _, residual = fused_rms_mxfp4_quant(hidden_states, weight, eps, + None, None, eps, + 
res1=residual, + shuffle=False, + scale_shuffle_padding=False, + output_unquantized_inp1=isinstance(self.mlp, DeepseekV2MoE)) + + if isinstance(self.mlp, DeepseekV2MoE): + hidden_states = ((hidden_states_quant, hidden_states_quant_scales), hidden_states_unquant) + else: + hidden_states = (hidden_states_quant, hidden_states_quant_scales) + else: + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 167dfbca248c..843c40b7500a 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -67,7 +67,8 @@ make_layers, maybe_prefix, ) - +from vllm.platforms import current_platform +from vllm._aiter_ops import rocm_aiter_ops class LlamaMLP(nn.Module): def __init__( @@ -219,6 +220,7 @@ def __init__( per_layer_sliding_window=sliding_window, attn_type=attn_type, prefix=f"{prefix}.attn", + rotary_emb = self.rotary_emb if current_platform.is_rocm() and rocm_aiter_ops.is_enabled() else None ) def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor: @@ -239,11 +241,15 @@ def forward( ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - if self.do_llama_4_scaling: - attn_scale = self._get_llama_4_attn_scale(positions) - q = (q * attn_scale).to(q.dtype) - attn_output = self.attn(q, k, v) + if current_platform.is_rocm() and rocm_aiter_ops.is_enabled(): + assert not self.do_llama_4_scaling + attn_output = self.attn(q, k, v, positions=positions) + else: + q, k = self.rotary_emb(positions, q, k) + if self.do_llama_4_scaling: + attn_scale = self._get_llama_4_attn_scale(positions) + q = (q * attn_scale).to(q.dtype) + attn_output = self.attn(q, k, v) output, _ = 
self.o_proj(attn_output) return output diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 180625b6ce89..0a8f4229a2bb 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -204,6 +204,8 @@ AttentionLayer, MLAAttentionImpl, ) +from vllm.model_executor.layers.quantization.quark.quark import QuarkLinearMethod +from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod from vllm.attention.backends.utils import get_mla_dims from vllm.attention.ops.common import cp_lse_ag_out_rs from vllm.attention.ops.merge_attn_states import merge_attn_states @@ -1139,6 +1141,8 @@ def __init__( self.indexer = indexer self.q_pad_num_heads = q_pad_num_heads self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled() + self.is_aiter_triton_fp4_bmm_enabled = rocm_aiter_ops.is_fp4bmm_enabled() + self.is_aiter_enabled = rocm_aiter_ops.is_enabled() def process_weights_after_loading(self, act_dtype: torch.dtype): def get_layer_weight(layer): @@ -1167,7 +1171,23 @@ def get_and_maybe_dequant_weights(layer: LinearBase): # we currently do not have quantized bmm's which are needed for # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform # the bmm's in 16-bit, the extra memory overhead of this is fairly low - kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj) + + # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported + if self.is_aiter_triton_fp4_bmm_enabled: + from vllm.model_executor.layers.quantization.quark.utils import quark_post_load_weights + + self.W_K, self.W_K_scale, W_V, self.W_V_scale = ( + quark_post_load_weights(self, kv_b_proj_weight, "mxfp4")) + self.W_V = W_V.contiguous().transpose(1, 2) + + self.W_K = self.W_K.transpose(-2, -1).contiguous() + self.W_K_scale = self.W_K_scale.transpose(-2, -1).contiguous() + self.W_V = self.W_V.transpose(-2, 
-1).contiguous() + self.W_V_scale = self.W_V_scale.transpose(-2, -1).contiguous() + return + + kv_b_proj_weight = kv_b_proj_weight.T assert kv_b_proj_weight.shape == ( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), @@ -1236,16 +1256,39 @@ def get_and_maybe_dequant_weights(layer: LinearBase): self.W_UK_T = W_UK.permute(1, 2, 0) def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): - # Convert from (B, N, L) to (N, B, L) - x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) - if self.is_aiter_triton_fp8_bmm_enabled: + if self.is_aiter_triton_fp4_bmm_enabled: + #print(f'>>> x pre (up_proj) {x.shape}') + out = out.view(-1, self.num_heads, self.v_head_dim) + x = x.view(-1, self.num_heads, self.kv_lora_rank) + x = x.transpose(0, 1) + + #print(f'>>> x {x.shape}, attn_bmm_output {attn_bmm_output.shape}, self.W_V {self.W_V.shape}') + out = rocm_aiter_ops.triton_fp4_bmm( + x, + self.W_V, + self.W_V_scale, + YQ=out, + transpose_bm=True, + y_scale=None, + ) + #print(f'>>> x before transpose {x.shape}') + out = out.view(-1, self.num_heads * self.v_head_dim) + # x = out + + elif self.is_aiter_triton_fp8_bmm_enabled: + # Convert from (B, N, L) to (N, B, L) + x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + out = out.view(-1, self.num_heads, self.v_head_dim) # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) x = rocm_aiter_ops.triton_fp8_bmm( x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out ) else: + # Convert from (B, N, L) to (N, B, L) + x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + # Convert from (B, N * V) to (N, B, V) out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1) @@ -1577,7 +1620,23 @@ def get_and_maybe_dequant_weights(layer: LinearBase): # we currently do not have quantized bmm's which are needed for # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform # the bmm's in 16-bit, the extra memory overhead of this is 
fairly low - kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj) + + # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported + if self.is_aiter_triton_fp4_bmm_enabled: + from vllm.model_executor.layers.quantization.quark.utils import quark_post_load_weights + + self.W_K, self.W_K_scale, W_V, self.W_V_scale = ( + quark_post_load_weights(self, kv_b_proj_weight, "mxfp4")) + self.W_V = W_V.contiguous().transpose(1, 2) + + self.W_K = self.W_K.transpose(-2, -1).contiguous() + self.W_K_scale = self.W_K_scale.transpose(-2, -1).contiguous() + self.W_V = self.W_V.transpose(-2, -1).contiguous() + self.W_V_scale = self.W_V_scale.transpose(-2, -1).contiguous() + return + + kv_b_proj_weight = kv_b_proj_weight.T assert kv_b_proj_weight.shape == ( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), @@ -1676,12 +1735,53 @@ def _compute_prefill_context( kv_c_normed = workspace[:toks][..., : self.kv_lora_rank] k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1) - kv_nope = self.kv_b_proj(kv_c_normed)[0].view( - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim - ) - k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + if self.is_aiter_enabled and isinstance(self.kv_b_proj.quant_method, QuarkLinearMethod): + from aiter.ops.triton.fused_gemm_afp4wfp4_split_cat import fused_gemm_afp4wfp4_split_cat + from aiter.ops.triton.quant import dynamic_mxfp4_quant + input = kv_c_normed + weight = self.kv_b_proj.weight + weight_scale = self.kv_b_proj.weight_scale - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + input_2d = input.view(-1, input.shape[-1]) + output_dtype = input.dtype + + q_input, x_scale = dynamic_mxfp4_quant(input_2d) + + k, v = fused_gemm_afp4wfp4_split_cat( + q_input, weight, k_pe.expand((-1, self.num_heads, -1)), x_scale, weight_scale.T, self.qk_nope_head_dim, self.v_head_dim, 
output_dtype + ) + elif self.is_aiter_enabled and isinstance(self.kv_b_proj.quant_method, Fp8LinearMethod): + from aiter.ops.triton.fused_gemm_a8w8_blockscale_split_cat import fused_gemm_a8w8_blockscale_split_cat + import aiter as rocm_aiter + from aiter import get_hip_quant + aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128) + from vllm.model_executor.layers.quantization.utils.fp8_utils import per_token_group_quant_fp8 + + input = kv_c_normed + weight = self.kv_b_proj.weight + block_size = self.kv_b_proj.quant_method.quant_config.weight_block_size + weight_scale = self.kv_b_proj.weight_scale + + input_2d = input.view(-1, input.shape[-1]) + output_dtype = input.dtype + + if current_platform.is_fp8_fnuz(): + q_input, x_scale = aiter_per1x128_quant( + input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) + else: + q_input, x_scale = per_token_group_quant_fp8( + input_2d, block_size[1], column_major_scales=False) + + k, v = fused_gemm_a8w8_blockscale_split_cat( + q_input, weight, k_pe.expand((-1, self.num_heads, -1)), x_scale, weight_scale, self.qk_nope_head_dim, self.v_head_dim, output_dtype + ) + else: + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) attn_output, attn_softmax_lse = self._run_prefill_context_chunk( prefill=prefill_metadata, @@ -1781,11 +1881,52 @@ def _context_parallel_compute_prefill_context( toks=toks, ) - kv_nope = self.kv_b_proj(kv_c_normed)[0].view( - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim - ) - k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + if self.is_aiter_enabled and isinstance(self.kv_b_proj.quant_method, QuarkLinearMethod): + from aiter.ops.triton.fused_gemm_afp4wfp4_split_cat 
import fused_gemm_afp4wfp4_split_cat + from aiter.ops.triton.quant import dynamic_mxfp4_quant + input = kv_c_normed + weight = self.kv_b_proj.weight + weight_scale = self.kv_b_proj.weight_scale + + input_2d = input.view(-1, input.shape[-1]) + output_dtype = input.dtype + + q_input, x_scale = dynamic_mxfp4_quant(input_2d) + + k, v = fused_gemm_afp4wfp4_split_cat( + q_input, weight, k_pe.expand((-1, self.num_heads, -1)), x_scale, weight_scale.T, self.qk_nope_head_dim, self.v_head_dim, output_dtype + ) + elif self.is_aiter_enabled and isinstance(self.kv_b_proj.quant_method, Fp8LinearMethod): + from aiter.ops.triton.fused_gemm_a8w8_blockscale_split_cat import fused_gemm_a8w8_blockscale_split_cat + import aiter as rocm_aiter + from aiter import get_hip_quant + aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128) + from vllm.model_executor.layers.quantization.utils.fp8_utils import per_token_group_quant_fp8 + + input = kv_c_normed + weight = self.kv_b_proj.weight + block_size = self.kv_b_proj.quant_method.quant_config.weight_block_size + weight_scale = self.kv_b_proj.weight_scale + + input_2d = input.view(-1, input.shape[-1]) + output_dtype = input.dtype + + if current_platform.is_fp8_fnuz(): + q_input, x_scale = aiter_per1x128_quant( + input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) + else: + q_input, x_scale = per_token_group_quant_fp8( + input_2d, block_size[1], column_major_scales=False) + + k, v = fused_gemm_a8w8_blockscale_split_cat( + q_input, weight, k_pe.expand((-1, self.num_heads, -1)), x_scale, weight_scale, self.qk_nope_head_dim, self.v_head_dim, output_dtype + ) + else: + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) attn_output, attn_softmax_lse = self._run_prefill_context_chunk( prefill=prefill_metadata, @@ -1829,12 
+1970,55 @@ def _forward_prefill( assert self.dcp_world_size is not None has_context = attn_metadata.prefill.chunked_context is not None - kv_nope = self.kv_b_proj(kv_c_normed)[0].view( - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim - ) - k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + if self.is_aiter_enabled and isinstance(self.kv_b_proj.quant_method, QuarkLinearMethod): + from aiter.ops.triton.fused_gemm_afp4wfp4_split_cat import fused_gemm_afp4wfp4_split_cat + from aiter.ops.triton.quant import dynamic_mxfp4_quant + input = kv_c_normed + weight = self.kv_b_proj.weight + weight_scale = self.kv_b_proj.weight_scale + + input_2d = input.view(-1, input.shape[-1]) + output_dtype = input.dtype + + q_input, x_scale = dynamic_mxfp4_quant(input_2d) + + k, v = fused_gemm_afp4wfp4_split_cat( + q_input, weight, k_pe.expand((-1, self.num_heads, -1)), x_scale, weight_scale.T, self.qk_nope_head_dim, self.v_head_dim, output_dtype + ) + elif self.is_aiter_enabled and isinstance(self.kv_b_proj.quant_method, Fp8LinearMethod): + from aiter.ops.triton.fused_gemm_a8w8_blockscale_split_cat import fused_gemm_a8w8_blockscale_split_cat + import aiter as rocm_aiter + from aiter import get_hip_quant + aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128) + from vllm.model_executor.layers.quantization.utils.fp8_utils import per_token_group_quant_fp8 + + input = kv_c_normed + weight = self.kv_b_proj.weight + block_size = self.kv_b_proj.quant_method.quant_config.weight_block_size + weight_scale = self.kv_b_proj.weight_scale + + input_2d = input.view(-1, input.shape[-1]) + output_dtype = input.dtype + + if current_platform.is_fp8_fnuz(): + q_input, x_scale = aiter_per1x128_quant( + input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) + else: + q_input, x_scale = per_token_group_quant_fp8( + input_2d, block_size[1], column_major_scales=False) + + k, v = 
fused_gemm_a8w8_blockscale_split_cat( + q_input, weight, k_pe.expand((-1, self.num_heads, -1)), x_scale, weight_scale, self.qk_nope_head_dim, self.v_head_dim, output_dtype + ) + else: + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) output_prefill = self._run_prefill_new_tokens( prefill=attn_metadata.prefill, diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 00a0a77a1c2f..3b255354ac6b 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -6,6 +6,7 @@ import torch +from vllm import _custom_ops as ops from vllm._aiter_ops import rocm_aiter_ops from vllm.attention.backends.abstract import AttentionLayer, MultipleOf from vllm.config import VllmConfig @@ -18,7 +19,12 @@ ) from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank +from vllm.attention.ops.common import cp_lse_ag_out_rs +from vllm.platforms import current_platform +from typing import ClassVar, Generic, TypeVar +M = TypeVar("M", bound=MLACommonMetadata) class AiterMLABackend(MLACommonBackend): @staticmethod @@ -200,6 +206,7 @@ def __init__( kv_sharing_target_layer_name, **mla_args, ) + self.is_aiter_triton_fp4_bmm_enabled = rocm_aiter_ops.is_fp4bmm_enabled() assert num_heads == 16 or num_heads == 128, ( f"Aiter MLA only supports 16 or 128 number of heads.\n" f"Provided {num_heads} number of heads.\n" @@ -236,22 +243,32 @@ def _forward_decode( kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: AiterMLAMetadata, layer: AttentionLayer, + mla_output_zeros: torch.Tensor | None = None, + decode_q_cat: torch.Tensor | None = None, 
) -> tuple[torch.Tensor, torch.Tensor | None]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None - if type(q) is tuple: + if decode_q_cat is not None: + q = decode_q_cat + elif type(q) is tuple: q = torch.cat(q, dim=-1) assert isinstance(q, torch.Tensor) B = q.shape[0] - o = torch.zeros( - B, - self.num_heads, - self.kv_lora_rank, - dtype=attn_metadata.decode.attn_out_dtype, - device=q.device, - ) + if mla_output_zeros is not None: + o = mla_output_zeros + assert o.shape[0] == B, f"{o.shape[0]=} {B=}" + assert o.shape[1] == self.num_heads, f"{o.shape[1]=} {self.num_heads=}" + assert o.shape[2] == self.kv_lora_rank, f"{o.shape[2]=} {self.kv_lora_rank=}" + else: + o = torch.zeros( + B, + self.num_heads, + self.kv_lora_rank, + dtype=attn_metadata.decode.attn_out_dtype, + device=q.device, + ) kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) @@ -273,3 +290,234 @@ def _forward_decode( ) return o, None + + def forward( + self, + layer: AttentionLayer, + q: torch.Tensor, + k_c_normed: torch.Tensor, # key in unified attn + k_pe: torch.Tensor, # value in unified attn + kv_cache: torch.Tensor, + attn_metadata: M, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, + positions: torch.Tensor | None = None, + ) -> torch.Tensor: + assert output is not None, "Output tensor must be provided." 
+ + if output_scale is not None or output_block_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported for MLACommonImpl" + ) + + if attn_metadata is None: + # During the profile run try to simulate to worse case output size + # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context` + # since this can be large + _ = torch.empty( + ( + self.chunked_prefill_workspace_size, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ), + device=k_c_normed.device, + dtype=k_c_normed.dtype, + ) + + # The zero fill is required when used with DP + EP + # to ensure all ranks within a DP group compute the + # same expert outputs. + return output.fill_(0) + + if self.dcp_world_size is None: + self.dcp_world_size = get_dcp_group().world_size + + fp8_attention = self.kv_cache_dtype.startswith("fp8") + + num_actual_toks = attn_metadata.num_actual_tokens + + # Inputs and outputs may be padded for CUDA graphs + output_padded = output + output = output[:num_actual_toks, ...] + q = q[:num_actual_toks, ...] + k_c_normed = k_c_normed[:num_actual_toks, ...] + k_pe = k_pe[:num_actual_toks, ...] 
+ + assert ( + attn_metadata.num_decodes is not None + and attn_metadata.num_prefills is not None + and attn_metadata.num_decode_tokens is not None + ) + + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + + mla_output_zeros = None + decode_q_cat = None + # write the latent and rope to kv cache + if kv_cache.numel() > 0: + if positions is not None: + # positions is not None entails that Q and K are not RoPE embedded yet, therefore, fused_qk_rope_cat_and_cache_mla is called + assert hasattr(self, "rotary_emb"), f"rotary_emb not found in {self}" + from aiter.ops.triton.fusions.fused_kv_cache import fused_qk_rope_cat_and_cache_mla + cos, sin = self.rotary_emb.cos_sin_cache.chunk(2, dim = -1) + is_neox = self.rotary_emb.is_neox_style + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + q_out_dtype = current_platform.fp8_dtype() if fp8_attention else q.dtype + if self.is_aiter_triton_fp4_bmm_enabled or self.is_aiter_triton_fp8_bmm_enabled: + decode_q_cat = torch.empty((num_decode_tokens, self.num_heads, self.W_K.shape[1] + self.qk_rope_head_dim), dtype = q_out_dtype, device=q.device) + if fp8_attention: + kv_cache_og_dtype = kv_cache.dtype + kv_cache = kv_cache.view(q_out_dtype) + q, _, k_pe, mla_output_zeros = fused_qk_rope_cat_and_cache_mla( + q_nope, + q_pe, + k_c_normed.unsqueeze(1), + k_pe, + kv_cache, + attn_metadata.slot_mapping.flatten(), + positions, + cos, + sin, + layer._k_scale, + is_neox, + num_decode_toks_for_zeros=num_decode_tokens, + apply_scale=(k_pe.dtype != kv_cache.dtype), + q_out=None, + decode_q_pe_out = decode_q_cat[... 
, -self.qk_rope_head_dim:] if self.is_aiter_triton_fp4_bmm_enabled or self.is_aiter_triton_fp8_bmm_enabled else None, + k_pe_out=k_pe, + ) + if fp8_attention: + kv_cache = kv_cache.view(kv_cache_og_dtype) + else: + ops.concat_and_cache_mla( + k_c_normed, + k_pe.squeeze(1), + kv_cache, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + scale=layer._k_scale, + ) + + decode_q = q[:num_decode_tokens] + + prefill_q = q[num_decode_tokens:] + prefill_k_pe = k_pe[num_decode_tokens:] + prefill_k_c_normed = k_c_normed[num_decode_tokens:] + + if fp8_attention: + kv_cache = kv_cache.view(current_platform.fp8_dtype()) + + if has_prefill: + self._forward_prefill( + prefill_q, + prefill_k_c_normed, + prefill_k_pe, + kv_cache, + attn_metadata, + layer._k_scale, + output=output[num_decode_tokens:], + ) + + if has_decode: + assert attn_metadata.decode is not None + + decode_q_nope, decode_q_pe = decode_q.split( + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + # Convert from (B, N, P) to (N, B, P) + decode_q_nope = decode_q_nope.transpose(0, 1) + + if self.q_pad_num_heads is not None: + B, N, L = decode_q_pe.shape + decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L)) + decode_pe_padded.resize_((B, N, L)) + decode_pe_padded.copy_(decode_q_pe) + decode_q_pe = decode_pe_padded + + if self.is_aiter_triton_fp4_bmm_enabled: + #x = x.view(-1, self.num_heads, self.kv_lora_rank) + decode_ql_nope = decode_q_cat[... , :self.W_K.shape[1]] if (kv_cache.numel() > 0 and positions is not None) else None + decode_ql_nope = rocm_aiter_ops.triton_fp4_bmm( + decode_q_nope, + self.W_K, + self.W_K_scale, + YQ=decode_ql_nope, + transpose_bm=True, + y_scale=layer._q_scale if fp8_attention else None, + ) + # decode_ql_nope = decode_ql_nope.transpose(0, 1) + elif self.is_aiter_triton_fp8_bmm_enabled: + # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) + decode_ql_nope = decode_q_cat[... 
, :self.W_K.shape[1]] if (kv_cache.numel() > 0 and positions is not None) else None + decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm( + decode_q_nope, + self.W_K, + self.W_K_scale, + group_size=128, + YQ=decode_ql_nope, + transpose_bm=True, + ) + else: + # Pads the head_dim if necessary (for the underlying kernel) + N, B, P = decode_q_nope.shape + _, _, L = self.W_UK_T.shape + + if self.q_pad_num_heads is not None: + decode_ql_nope = decode_q_nope.new_empty( + (self.q_pad_num_heads, B, L) + ) + decode_ql_nope.resize_((N, B, L)) + else: + decode_ql_nope = decode_q_nope.new_empty((N, B, L)) + + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope) + + # Convert from (N, B, L) to (B, N, L) + decode_ql_nope = decode_ql_nope.transpose(0, 1) + + if fp8_attention and not (self.is_aiter_triton_fp4_bmm_enabled or self.is_aiter_triton_fp8_bmm_enabled): + ql_nope_shape = decode_ql_nope.shape + decode_ql_nope, _ = ops.scaled_fp8_quant( + decode_ql_nope.reshape( + [ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2]] + ), + layer._q_scale, + ) + decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape) + q_pe_shape = decode_q_pe.shape + decode_q_pe, _ = ops.scaled_fp8_quant( + decode_q_pe.reshape([q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]), + layer._q_scale, + ) + decode_q_pe = decode_q_pe.reshape(q_pe_shape) + + decode_q = (decode_ql_nope, decode_q_pe) + if self.dcp_world_size > 1: + assert not fp8_attention, "DCP not support fp8 kvcache now." + # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P) + decode_q = torch.cat(decode_q, dim=-1) + # decode_q do allgather in head dim. + decode_q = get_dcp_group().all_gather(decode_q, dim=1) + + # call decode attn + attn_out, lse = self._forward_decode( + decode_q, kv_cache, attn_metadata, layer, mla_output_zeros=mla_output_zeros, decode_q_cat=decode_q_cat + ) + + # correct dcp attn_out with lse. 
+ if self.dcp_world_size > 1: + attn_out = cp_lse_ag_out_rs( + attn_out, + lse, + get_dcp_group(), + is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False), + ) + + # v_up projection + self._v_up_proj(attn_out, out=output[:num_decode_tokens]) + return output_padded diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index b6aa0ae2be48..acbefc3410e3 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -41,6 +41,10 @@ def block_size(x, head_dim): def num_programs(total_tokens): return min(total_tokens, get_cu_count()) + + from vllm._aiter_ops import rocm_aiter_ops + if rocm_aiter_ops.is_enabled(): + from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache @triton.jit def cp_mha_gather_cache_kernel( @@ -782,6 +786,7 @@ def forward( output: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass with AiterFlashAttention. @@ -823,29 +828,67 @@ def forward( # key and value may be None in the case of cross attention. They are # calculated once based on the output from the encoder and then cached # in KV cache. - if ( - self.kv_sharing_target_layer_name is None - and key is not None - and value is not None - ): - # Reshape the input keys and values and store them in the cache. - # Skip this if sharing KV cache with an earlier attention layer. - # NOTE(woosuk): Here, key and value are padded while slot_mapping - # is not padded. However, we don't need to do - # key[:num_actual_tokens] and value[:num_actual_tokens] because - # the reshape_and_cache_flash op uses the slot_mapping's shape - # to determine the number of actual tokens. 
- - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, + if positions is not None and query.shape[0] <= 256 and rocm_aiter_ops.is_enabled(): + assert self.kv_sharing_target_layer_name is None + cos_sin_cache = self.rotary_emb.cos_sin_cache + is_neox = self.rotary_emb.is_neox_style + cos, sin = cos_sin_cache.chunk(2, dim=-1) + is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8") + if is_fp8_kv_cache: + key_cache = key_cache.view(current_platform.fp8_dtype()) + value_cache = value_cache.view(current_platform.fp8_dtype()) + query, key, key_cache, value_cache, output = ( + fused_qk_rope_reshape_and_cache( + query, + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + positions, + cos, + sin, + layer._k_scale, + layer._v_scale, + is_neox, + flash_layout=True, + apply_scale=is_fp8_kv_cache, + offs=None, + q_out=query, + k_out=key, + output_zeros=True, + zeros_out=output, + ) ) + else: + if positions is not None: + if current_platform.is_rocm(): + query, key = self.rotary_emb.forward_cuda(positions, query, key) + else: + query, key = self.rotary_emb(positions, query, key) + if ( + self.kv_sharing_target_layer_name is None + and key is not None + and value is not None + ): + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. + # NOTE(woosuk): Here, key and value are padded while slot_mapping + # is not padded. However, we don't need to do + # key[:num_actual_tokens] and value[:num_actual_tokens] because + # the reshape_and_cache_flash op uses the slot_mapping's shape + # to determine the number of actual tokens. 
+ + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) if self.kv_cache_dtype.startswith("fp8"): key_cache = key_cache.view(current_platform.fp8_dtype()) diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 868143cc192e..6611a49e6caa 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -32,6 +32,12 @@ logger = init_logger(__name__) +if current_platform.is_rocm(): + + from vllm._aiter_ops import rocm_aiter_ops + if rocm_aiter_ops.is_enabled(): + from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache + @dataclass class RocmAttentionMetadata: # NOTE(sang): Definition of context_len, query_len, and seq_len. @@ -264,6 +270,7 @@ def forward( output: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -306,19 +313,57 @@ def forward( kv_cache, self.num_kv_heads, self.head_size ) - if self.kv_sharing_target_layer_name is None: - # Reshape the input keys and values and store them in the cache. - # Skip this if sharing KV cache with an earlier attention layer. 
- PagedAttention.write_to_paged_cache( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, + if positions is not None and query.shape[0] <= 256 and rocm_aiter_ops.is_enabled(): + assert self.kv_sharing_target_layer_name is None + cos_sin_cache = self.rotary_emb.cos_sin_cache + is_neox = self.rotary_emb.is_neox_style + cos, sin = cos_sin_cache.chunk(2, dim=-1) + is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8") + if is_fp8_kv_cache: + key_cache = key_cache.view(self.fp8_dtype) + value_cache = value_cache.view(self.fp8_dtype) + query, key, key_cache, value_cache, output = ( + fused_qk_rope_reshape_and_cache( + query, + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + positions, + cos, + sin, + layer._k_scale, + layer._v_scale, + is_neox, + flash_layout=False, + apply_scale=is_fp8_kv_cache, + offs=None, + q_out=query, + k_out=key, + output_zeros=True, + zeros_out=output, + ) ) + else: + if positions is not None: + if current_platform.is_rocm(): + query, key = self.rotary_emb.forward_cuda(positions, query, key) + else: + query, key = self.rotary_emb(positions, query, key) + if self.kv_sharing_target_layer_name is None: + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) if self.kv_cache_dtype.startswith("fp8"): key_cache = key_cache.view(self.fp8_dtype)