Skip to content
101 changes: 95 additions & 6 deletions tests/compile/test_fusion.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import itertools

import pytest
import torch

import vllm.plugins
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.matcher_utils import QUANT_OPS
Expand All @@ -18,6 +21,7 @@
VllmConfig,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
Expand Down Expand Up @@ -119,13 +123,79 @@ def ops_in_model_before_partial(self):
)


class TestAiterRmsnormGroupFp8QuantModel(torch.nn.Module):
    """Model chaining AITER rms_norm / rms_norm2d_with_add with block-fp8
    linears, so the rmsnorm + group-fp8-quant fusion pattern occurs several
    times (once without and three times with a residual add) in the graph.
    """

    def __init__(self, hidden_size: int, eps: float, **kwargs):
        super().__init__()
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(128, 128),
            act_quant_group_shape=GroupShape(1, 128),
            cutlass_block_fp8_supported=False,
            use_aiter_and_is_supported=True,
        )
        # Three fp8 weights (transposed for the matmul) and their matching
        # 128x128 block scale tensors. Keep creation order (weights first,
        # then scales) so values match under a fixed RNG seed.
        self.w = [
            torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
            for _ in range(3)
        ]

        n_scale_blocks = (hidden_size + 128 - 1) // 128  # ceil(hidden/128)
        self.wscale = [
            torch.rand((n_scale_blocks, n_scale_blocks), dtype=torch.float32)
            for _ in range(3)
        ]

        # One norm weight per rmsnorm call: 1 plain + 3 fused-add variants.
        self.norm_weight = [torch.ones(hidden_size) for _ in range(4)]
        self.eps = eps

    def forward(self, x):
        # avoid having graph input be an arg to a pattern directly
        x = resid = torch.relu(x)
        hidden = rocm_aiter_ops.rms_norm(x, self.norm_weight[0], self.eps)

        # make sure resid is used for replacement to work
        for i in range(3):
            proj = self.w8a8_block_fp8_linear.apply(
                hidden, self.w[i], self.wscale[i]
            )
            hidden, resid = rocm_aiter_ops.rms_norm2d_with_add(
                proj, resid, self.norm_weight[i + 1], self.eps
            )
        return hidden

    def ops_in_model_before(self):
        # Unfused ops that must be present before the fusion pass runs.
        return [
            torch.ops.vllm.rocm_aiter_rms_norm,
            torch.ops.vllm.rocm_aiter_group_fp8_quant,
        ]

    def ops_in_model_before_partial(self):
        return []

    def ops_in_model_after(self):
        # Fused ops expected in the graph after the fusion pass.
        return [
            torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant,
            torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant,
        ]


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False])
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
@pytest.mark.parametrize(
"model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op",
list(itertools.product([TestModel], [True, False], [True, False]))
+ [(TestAiterRmsnormGroupFp8QuantModel, False, False)],
)
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize(
Expand All @@ -140,10 +210,14 @@ def test_fusion_rmsnorm_quant(
num_tokens,
eps,
static,
model_class,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
cuda_force_torch,
):
if model_class is TestAiterRmsnormGroupFp8QuantModel and not IS_AITER_FOUND:
pytest.skip("AITER is not supported on this GPU.")

torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
Expand All @@ -167,12 +241,24 @@ def test_fusion_rmsnorm_quant(
with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
if model_class is TestAiterRmsnormGroupFp8QuantModel:
from vllm.compilation.rocm_aiter_fusion import (
RocmAiterRMSNormFp8GroupQuantFusionPass,
)

fusion_pass = RocmAiterRMSNormFp8GroupQuantFusionPass(vllm_config)
else:
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)

backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
backend2 = TestBackend(noop_pass, cleanup_pass)
model = TestModel(hidden_size, eps, static, cuda_force_torch)
model = model_class(
hidden_size=hidden_size,
eps=eps,
static=static,
cuda_force_torch=cuda_force_torch,
)

# First dimension dynamic
x = torch.rand(num_tokens, hidden_size)
Expand Down Expand Up @@ -202,7 +288,10 @@ def test_fusion_rmsnorm_quant(
# there's a risk that the fused add doesn't get included in the
# replacement and only the rms part gets fused with quant.
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
if not enable_rms_norm_custom_op:
if (
not enable_rms_norm_custom_op
and model_class is not TestAiterRmsnormGroupFp8QuantModel
):
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
assert n_add_nodes(backend.graph_pre_pass) == 7
Expand Down
56 changes: 54 additions & 2 deletions tests/compile/test_silu_mul_quant_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import vllm.envs as envs
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
from vllm._aiter_ops import IS_AITER_FOUND
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.compilation.activation_quant_fusion import (
FUSED_OPS,
Expand All @@ -24,6 +25,7 @@
set_current_vllm_config,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8StaticTensorSym,
Expand Down Expand Up @@ -126,14 +128,50 @@ def ops_in_model_after(self):
return [FUSED_OPS[kNvfp4Quant]]


class TestAiterSiluMulGroupFp8QuantModel(torch.nn.Module):
    """SiluAndMul followed by a block-fp8 linear, exercising the AITER
    silu_mul + group-fp8-quant fusion pattern.
    """

    def __init__(self, hidden_size: int, **kwargs):
        super().__init__()
        self.silu_and_mul = SiluAndMul()
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(128, 128),
            act_quant_group_shape=GroupShape(1, 128),
            cutlass_block_fp8_supported=False,
            use_aiter_and_is_supported=True,
        )
        # fp8 weight (transposed for the matmul) plus its 128x128 block
        # scale tensor; weight is created before the scale so values match
        # under a fixed RNG seed.
        self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()

        n_scale_blocks = (hidden_size + 128 - 1) // 128  # ceil(hidden/128)
        self.wscale = torch.rand(
            (n_scale_blocks, n_scale_blocks), dtype=torch.float32
        )

        # Records whether the custom silu_and_mul op (vs. the aten
        # decomposition) is active, which determines the pre-fusion op set.
        self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()

    def forward(self, x):
        activated = self.silu_and_mul(x)
        return self.w8a8_block_fp8_linear.apply(activated, self.w, self.wscale)

    def ops_in_model_before(self):
        # The op the fusion pass should consume: the custom silu_mul op when
        # enabled, otherwise the aten.mul from its decomposition.
        if self.enable_silu_mul_custom_op:
            return [SILU_MUL_OP]
        return [torch.ops.aten.mul]

    def ops_in_model_after(self):
        # Fused op expected in the graph after the fusion pass.
        return [torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant]


@pytest.mark.parametrize("num_tokens", [32, 64])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("enable_silu_mul_custom_op", [True, False])
@pytest.mark.parametrize(
"model_class, enable_quant_fp8_custom_op, cuda_force_torch",
list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False]))
+ [(TestSiluMulNvfp4QuantModel, False, False)],
+ [
(TestSiluMulNvfp4QuantModel, False, False),
(TestAiterSiluMulGroupFp8QuantModel, False, False),
],
)
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
Expand All @@ -144,13 +182,19 @@ def test_fusion_silu_and_mul_quant(
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
model_class: type[TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel],
model_class: type[
TestSiluMulFp8QuantModel
| TestSiluMulNvfp4QuantModel
| TestAiterSiluMulGroupFp8QuantModel
],
enable_silu_mul_custom_op: bool,
enable_quant_fp8_custom_op: bool,
cuda_force_torch: bool,
):
if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
pytest.skip("NVFP4 is not supported on this GPU.")
if model_class is TestAiterSiluMulGroupFp8QuantModel and not IS_AITER_FOUND:
pytest.skip("AITER is not supported on this GPU.")

torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
Expand All @@ -174,6 +218,12 @@ def test_fusion_silu_and_mul_quant(

with set_current_vllm_config(config):
fusion_pass = ActivationQuantFusionPass(config)
if model_class == TestAiterSiluMulGroupFp8QuantModel:
from vllm.compilation.rocm_aiter_fusion import (
RocmAiterSiluMulFp8GroupQuantFusionPass,
)

fusion_pass = RocmAiterSiluMulFp8GroupQuantFusionPass(config)

passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
backend = TestBackend(*passes)
Expand All @@ -194,6 +244,8 @@ def test_fusion_silu_and_mul_quant(
atol, rtol = 1e-3, 1e-3
elif model_class == TestSiluMulNvfp4QuantModel:
atol, rtol = 1e-1, 1e-1
elif model_class == TestAiterSiluMulGroupFp8QuantModel:
atol, rtol = 5e-2, 5e-2

torch.testing.assert_close(
result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
Expand Down
Loading