diff --git a/docs/assets/contributing/dockerfile-stages-dependency.png b/docs/assets/contributing/dockerfile-stages-dependency.png index b327eb2151f5..0caf7429d39d 100644 Binary files a/docs/assets/contributing/dockerfile-stages-dependency.png and b/docs/assets/contributing/dockerfile-stages-dependency.png differ diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index eb0dee8d4e39..538ec3e2988b 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -17,6 +17,12 @@ from vllm.compilation.fusion import QUANT_OPS from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass +from vllm.compilation.rocm_aiter_fusion import ( + AITER_PER_TOKEN_QUANT_OP, + FUSED_SILU_MUL_PER_TOKEN_QUANT_OP, + VLLM_PER_TOKEN_QUANT_OP, + RocmAiterSiluMulFp8PerTokenQuantFusionPass, +) from vllm.config import ( CompilationConfig, CompilationMode, @@ -161,8 +167,59 @@ def ops_in_model_after(self): return [torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant] +class TestSiluMulPerTokenQuantModel(torch.nn.Module): + def __init__(self, hidden_size: int, **kwargs): + super().__init__() + self.silu_and_mul = SiluAndMul() + self.hidden_size = hidden_size + self.enable_silu_mul_custom_op = self.silu_and_mul.enabled() + + self.fp8_linear = Fp8LinearOp( + act_quant_static=False, + act_quant_group_shape=GroupShape.PER_TOKEN, + pad_output=True, + ) + + self.use_aiter_quant = ( + self.fp8_linear.quant_fp8.use_aiter + if hasattr(self.fp8_linear.quant_fp8, "use_aiter") + else False + ) + + weight_bf16 = torch.randn(hidden_size, hidden_size, dtype=torch.bfloat16) + weight_absmax = torch.max(torch.abs(weight_bf16), dim=0, keepdim=True)[ + 0 + ] # [1, hidden_size] + fp8_max = torch.finfo(FP8_DTYPE).max + self.wscale = ( + (weight_absmax / fp8_max).clamp(min=1e-12).to(torch.float32).t() + ) # [hidden_size, 1] + self.w = (weight_bf16 / weight_absmax).to(FP8_DTYPE).t() + + def forward(self, x): + y = self.silu_and_mul(x) + x2 = self.fp8_linear.apply(y, self.w, self.wscale, out_dtype=torch.bfloat16) + return x2, None # mimic 2-element output + + def ops_in_model_before(self): + silu_mul_op = ( + SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul + ) + + quant_op = ( + AITER_PER_TOKEN_QUANT_OP + if self.use_aiter_quant + else VLLM_PER_TOKEN_QUANT_OP + ) + + return [silu_mul_op, quant_op] + + def ops_in_model_after(self): + return [FUSED_SILU_MUL_PER_TOKEN_QUANT_OP] + + @pytest.mark.parametrize("num_tokens", [32, 64]) -@pytest.mark.parametrize("hidden_size", [128, 256]) +@pytest.mark.parametrize("hidden_size", [128, 256, 4096]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("enable_silu_mul_custom_op", [True, False]) @pytest.mark.parametrize( @@ -171,6 +228,7 @@ def ops_in_model_after(self): + [ (TestSiluMulNvfp4QuantModel, False, False), (TestSiluMulGroupFp8QuantModel, False, False), + (TestSiluMulPerTokenQuantModel, False, False), ], ) # cuda_force_torch used to test torch code path on platforms that @@ -186,6 +244,7 @@ def test_fusion_silu_and_mul_quant( TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel | TestSiluMulGroupFp8QuantModel + | TestSiluMulPerTokenQuantModel ], enable_silu_mul_custom_op: bool, enable_quant_fp8_custom_op: bool, @@ -195,6 +254,8 @@ def test_fusion_silu_and_mul_quant( pytest.skip("NVFP4 is not supported on this GPU.") if model_class is TestSiluMulGroupFp8QuantModel and not IS_AITER_FOUND: pytest.skip("AITER is not supported on this GPU.") + if model_class is TestSiluMulPerTokenQuantModel and not IS_AITER_FOUND: + pytest.skip("AITER is not supported on this GPU.") torch.set_default_device("cuda") torch.set_default_dtype(dtype) @@ -224,6 +285,8 @@ def test_fusion_silu_and_mul_quant( ) fusion_passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)] + if model_class is TestSiluMulPerTokenQuantModel: + fusion_passes += [RocmAiterSiluMulFp8PerTokenQuantFusionPass(config)] passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)] backend = TestBackend(*passes) @@ -246,6 +309,8 @@ def test_fusion_silu_and_mul_quant( atol, rtol = 1e-1, 1e-1 elif model_class == TestSiluMulGroupFp8QuantModel: atol, rtol = 5e-2, 5e-2 + elif model_class == TestSiluMulPerTokenQuantModel: + atol, rtol = 1e-2, 1e-2 torch.testing.assert_close( result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index adba69c177fc..3444d7b9a5ec 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -663,6 +663,20 @@ def _rocm_aiter_act_mul_and_fp8_group_quant_fake( return x_fp8, out_bs +def _rocm_aiter_fused_silu_mul_per_token_quant_impl( + out: torch.Tensor, scales: torch.Tensor, input: torch.Tensor +) -> None: + from aiter.ops.activation import fused_silu_mul_per_token_quant + + fused_silu_mul_per_token_quant(out, scales, input) + + +def _rocm_aiter_fused_silu_mul_per_token_quant_fake( + out: torch.Tensor, scales: torch.Tensor, input: torch.Tensor +) -> None: + pass + + # Global flag to ensure ops are registered only once _OPS_REGISTERED = False @@ -901,6 +915,14 @@ def register_ops_once() -> None: dispatch_key=current_platform.dispatch_key, ) + direct_register_custom_op( + op_name="rocm_aiter_fused_silu_mul_per_token_quant", + op_func=_rocm_aiter_fused_silu_mul_per_token_quant_impl, + mutates_args=["out", "scales"], + fake_impl=_rocm_aiter_fused_silu_mul_per_token_quant_fake, + dispatch_key=current_platform.dispatch_key, + ) + _OPS_REGISTERED = True @staticmethod @@ -1125,6 +1147,24 @@ def per_token_quant( torch.ops.vllm.rocm_aiter_per_token_quant(out, x, scale) return out, scale + @staticmethod + def fused_silu_mul_per_token_quant( + input: torch.Tensor, + quant_dtype: torch.dtype, + ) -> tuple[torch.Tensor, torch.Tensor]: + assert quant_dtype in [torch.int8, _FP8_DTYPE] + assert input.ndim == 2, "Input must be 2D tensor (num_tokens, 2*d)" + assert input.shape[-1] % 2 == 0, "Input last dimension must be even" + + num_tokens, input_dim = input.shape + d = input_dim // 2 + + out = torch.empty((num_tokens, d), dtype=quant_dtype, device=input.device) + scales = torch.empty((num_tokens, 1), dtype=torch.float32, device=input.device) + + torch.ops.vllm.rocm_aiter_fused_silu_mul_per_token_quant(out, scales, input) + return out, scales + @staticmethod def triton_fp4_gemm_dynamic_qaunt( x: torch.Tensor, diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index e5051c934999..e3cafcfd7ee2 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -18,6 +18,7 @@ from vllm.compilation.rocm_aiter_fusion import ( RocmAiterRMSNormFp8GroupQuantFusionPass, RocmAiterSiluMulFp8GroupQuantFusionPass, + RocmAiterSiluMulFp8PerTokenQuantFusionPass, ) if current_platform.is_cuda_alike(): @@ -132,6 +133,7 @@ def configure(self, config: VllmConfig): self.passes += [ActivationQuantFusionPass(config)] if rocm_aiter_ops.is_enabled(): self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)] + self.passes += [RocmAiterSiluMulFp8PerTokenQuantFusionPass(config)] # ROCm AITER all-reduce + RMSNorm fusion if ( diff --git a/vllm/compilation/rocm_aiter_fusion.py b/vllm/compilation/rocm_aiter_fusion.py index 8b5db9de3818..4271ecc87e5a 100644 --- a/vllm/compilation/rocm_aiter_fusion.py +++ b/vllm/compilation/rocm_aiter_fusion.py @@ -33,7 +33,15 @@ AITER_GROUP_FP8_QUANT_OP = torch.ops.vllm.rocm_aiter_group_fp8_quant.default TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default -FUSED_SILU_MUL_QUANT_OP = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default +FUSED_SILU_MUL_GROUP_QUANT_OP = ( + torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default +) +FUSED_SILU_MUL_PER_TOKEN_QUANT_OP = ( + torch.ops.vllm.rocm_aiter_fused_silu_mul_per_token_quant.default +) + +AITER_PER_TOKEN_QUANT_OP = torch.ops.vllm.rocm_aiter_per_token_quant.default +VLLM_PER_TOKEN_QUANT_OP = torch.ops._C.dynamic_per_token_scaled_fp8_quant.default class AiterRMSFp8GroupQuantPattern: @@ -196,7 +204,7 @@ def pattern( def replacement( input: torch.Tensor, ): - at = FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128) + at = FUSED_SILU_MUL_GROUP_QUANT_OP(x=input, group_size=128) return at[0], at[1] inputs = [ @@ -206,6 +214,74 @@ def replacement( pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) +class AiterSiluMulFp8PerTokenQuantPattern(ActivationQuantPattern): + """ + This pattern fuses aiter silu_and_mul and per-token fp8 quant custom + ops into an aiter fused_silu_mul_per_token_quant op. + """ + + def __init__(self, quant_op: OpOverload): + self.silu_and_mul_matcher = MatcherSiluAndMul() + self.quant_op = quant_op + + def register(self, pm_pass: PatternMatcherPass): + from torch._higher_order_ops.auto_functionalize import auto_functionalized + + def pattern( + input: torch.Tensor, + ): + at1 = self.silu_and_mul_matcher(input) + + d = input.shape[-1] // 2 + out_shape = input.shape[:-1] + (d,) + out = torch.empty(out_shape, dtype=FP8_DTYPE, device=input.device) + + scale_shape = out_shape[:-1] + (1,) + scale = torch.empty(scale_shape, dtype=torch.float32, device=input.device) + + if self.quant_op == AITER_PER_TOKEN_QUANT_OP: + at2 = auto_functionalized( + self.quant_op, + out=out, + x=at1, + scale=scale, + ) + return at2[1], at2[2] + else: + at2 = auto_functionalized( + self.quant_op, + result=out, + input=at1, + scale=scale, + scale_ub=None, + ) + return at2[1], at2[2] + + def replacement( + input: torch.Tensor, + ): + d = input.shape[-1] // 2 + out_shape = input.shape[:-1] + (d,) + out = torch.empty(out_shape, dtype=FP8_DTYPE, device=input.device) + + scale_shape = out_shape[:-1] + (1,) + scales = torch.empty(scale_shape, dtype=torch.float32, device=input.device) + + at = auto_functionalized( + FUSED_SILU_MUL_PER_TOKEN_QUANT_OP, + out=out, + scales=scales, + input=input, + ) + return at[1], at[2] + + inputs = [ + self.silu_and_mul_matcher.inputs()[0], + ] + + pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) + + class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass): """ This pass fuses a pre-defined set of custom ops into fused ops. @@ -240,3 +316,35 @@ def uuid(self): AiterSiluMulFp8GroupQuantPattern, ] return VllmInductorPass.hash_source(self, *fusion_patterns) + + +class RocmAiterSiluMulFp8PerTokenQuantFusionPass(VllmPatternMatcherPass): + """ + This pass fuses SiLUAndMul with per-token FP8 quantization. + """ + + @enable_fake_mode + def __init__(self, config: VllmConfig): + super().__init__(config) + + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="rocm_aiter_silu_mul_fp8_per_token_quant_fusion_pass" + ) + + # Register patterns for both aiter and vllm per-token quant ops + for quant_op in [AITER_PER_TOKEN_QUANT_OP, VLLM_PER_TOKEN_QUANT_OP]: + AiterSiluMulFp8PerTokenQuantPattern(quant_op).register(self.patterns) + + self.dump_patterns(config, self.patterns) + + @VllmInductorPass.time_and_log + def __call__(self, graph: torch.fx.Graph): + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) + + def uuid(self): + fusion_patterns = [ + ActivationQuantPattern, + AiterSiluMulFp8PerTokenQuantPattern, + ] + return VllmInductorPass.hash_source(self, *fusion_patterns)