From 31ea5c1633ec3a798017d936dd59399ab9df2606 Mon Sep 17 00:00:00 2001
From: zhuyuhua-v
Date: Fri, 14 Nov 2025 10:08:08 +0800
Subject: [PATCH 1/2] use aiter triton kernel as triton mha fallback path

Signed-off-by: zhuyuhua-v
---
 vllm/v1/attention/backends/mla/triton_mla.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py
index 781f77e96319..959af0ed1e96 100644
--- a/vllm/v1/attention/backends/mla/triton_mla.py
+++ b/vllm/v1/attention/backends/mla/triton_mla.py
@@ -129,14 +129,20 @@ def _flash_attn_varlen_diff_headdims(
                 q, k, v, softmax_scale=softmax_scale, **kwargs
             )
         else:
-            return super()._flash_attn_varlen_diff_headdims(
-                q,
-                k,
-                v,
-                return_softmax_lse=return_softmax_lse,
+            from aiter.ops.triton.mha import flash_attn_varlen_func
+            result = flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=v,
+                return_lse=return_softmax_lse,
                 softmax_scale=softmax_scale,
                 **kwargs,
             )
+            if type(result) is tuple and return_softmax_lse:
+                output, lse = result
+                lse = lse.T.contiguous()
+                return (output, lse)
+            return result
 
     def _forward_decode(
         self,

From 58a25ff716105ea54e8d7d655d13da7d2ef9d85e Mon Sep 17 00:00:00 2001
From: Zhu Yuhua
Date: Thu, 20 Nov 2025 17:24:20 +0800
Subject: [PATCH 2/2] update ROCM condition in flash attention

Refactor ROCM handling in triton mha method.
---
 vllm/v1/attention/backends/mla/triton_mla.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py
index 959af0ed1e96..1b98ec7b6624 100644
--- a/vllm/v1/attention/backends/mla/triton_mla.py
+++ b/vllm/v1/attention/backends/mla/triton_mla.py
@@ -128,21 +128,30 @@ def _flash_attn_varlen_diff_headdims(
             return self._flash_attn_varlen_diff_headdims_rocm(
                 q, k, v, softmax_scale=softmax_scale, **kwargs
             )
-        else:
+        elif current_platform.is_rocm():
             from aiter.ops.triton.mha import flash_attn_varlen_func
             result = flash_attn_varlen_func(
                 q=q,
                 k=k,
                 v=v,
                 return_lse=return_softmax_lse,
                 softmax_scale=softmax_scale,
                 **kwargs,
             )
             if type(result) is tuple and return_softmax_lse:
                 output, lse = result
                 lse = lse.T.contiguous()
                 return (output, lse)
             return result
+        else:
+            return super()._flash_attn_varlen_diff_headdims(
+                q,
+                k,
+                v,
+                return_softmax_lse=return_softmax_lse,
+                softmax_scale=softmax_scale,
+                **kwargs,
+            )
 
     def _forward_decode(
         self,
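
For reference, the dispatch that results from applying both patches to
_flash_attn_varlen_diff_headdims is sketched below. This is a minimal
illustration, not part of either patch: the method signature is reconstructed
from the parameters used in the hunks, and the guard on the first branch is
not visible in the hunks, so it is stubbed as a hypothetical
_use_rocm_fast_path() helper; the aiter import, the flash_attn_varlen_func
call with return_lse, the LSE transpose, and the super() fallback are taken
directly from the hunks above.

    # Sketch only. current_platform is imported from vllm.platforms in the
    # real file; _use_rocm_fast_path() is a hypothetical stand-in for the
    # unshown guard on the first branch.
    def _flash_attn_varlen_diff_headdims(
        self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
    ):
        if self._use_rocm_fast_path(return_softmax_lse):  # hypothetical guard
            # Pre-existing ROCm-specific path, unchanged by these patches.
            return self._flash_attn_varlen_diff_headdims_rocm(
                q, k, v, softmax_scale=softmax_scale, **kwargs
            )
        elif current_platform.is_rocm():
            # Patches 1/2 and 2/2: on ROCm, fall back to the aiter Triton MHA
            # kernel instead of the generic super() implementation.
            from aiter.ops.triton.mha import flash_attn_varlen_func

            result = flash_attn_varlen_func(
                q=q,
                k=k,
                v=v,
                return_lse=return_softmax_lse,
                softmax_scale=softmax_scale,
                **kwargs,
            )
            if type(result) is tuple and return_softmax_lse:
                # The patch transposes the LSE returned by aiter so it matches
                # the layout the MLA caller expects.
                output, lse = result
                lse = lse.T.contiguous()
                return (output, lse)
            return result
        else:
            # Patch 2/2: non-ROCm platforms keep the original super() fallback.
            return super()._flash_attn_varlen_diff_headdims(
                q,
                k,
                v,
                return_softmax_lse=return_softmax_lse,
                softmax_scale=softmax_scale,
                **kwargs,
            )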