diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py
index 781f77e96319..1b98ec7b6624 100644
--- a/vllm/v1/attention/backends/mla/triton_mla.py
+++ b/vllm/v1/attention/backends/mla/triton_mla.py
@@ -128,6 +128,21 @@ def _flash_attn_varlen_diff_headdims(
             return self._flash_attn_varlen_diff_headdims_rocm(
                 q, k, v, softmax_scale=softmax_scale, **kwargs
             )
+        elif current_platform.is_rocm():
+            from aiter.ops.triton.mha import flash_attn_varlen_func
+            result = flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=v,
+                return_lse=return_softmax_lse,
+                softmax_scale=softmax_scale,
+                **kwargs,
+            )
+            if type(result) is tuple and return_softmax_lse:
+                output, lse = result
+                lse = lse.T.contiguous()
+                return (output, lse)
+            return result
         else:
             return super()._flash_attn_varlen_diff_headdims(
                 q,
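
For context, here is a minimal standalone sketch of the LSE post-processing the new ROCm branch performs. The tensor shapes are assumptions chosen for illustration (the diff itself does not show the layout the aiter Triton kernel returns, only that the caller wants the transposed, contiguous form); normalize_lse and the toy tensors below are hypothetical, not part of the patch:

import torch

def normalize_lse(result, return_softmax_lse: bool):
    # aiter's flash_attn_varlen_func returns a bare output tensor when
    # return_lse=False and an (output, lse) tuple when return_lse=True,
    # which is why the branch type-checks the result before unpacking.
    if isinstance(result, tuple) and return_softmax_lse:
        output, lse = result
        # Transpose the LSE to the layout the MLA caller expects and make
        # it contiguous so downstream code can assume dense strides.
        return output, lse.T.contiguous()
    return result

# Toy usage with fake tensors standing in for the kernel's outputs:
out = torch.randn(6, 4, 128)  # (num_tokens, num_heads, head_dim), assumed
lse = torch.randn(6, 4)       # (num_tokens, num_heads), assumed layout
o, l = normalize_lse((out, lse), True)
assert l.shape == (4, 6) and l.is_contiguous()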