From e8830759d41ac2d4ec367059140034b68919de86 Mon Sep 17 00:00:00 2001
From: ShaoChunLee
Date: Fri, 2 Jan 2026 20:42:00 +0000
Subject: [PATCH] fix rope kv_cache for RocmAiterUnifiedAttentionImpl

---
 vllm/model_executor/models/gpt_oss.py         |  2 +-
 .../backends/rocm_aiter_unified_attn.py       | 77 +++++++++++++++----
 2 files changed, 62 insertions(+), 17 deletions(-)

diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py
index a423d54a3188..c9bf7d11b05c 100644
--- a/vllm/model_executor/models/gpt_oss.py
+++ b/vllm/model_executor/models/gpt_oss.py
@@ -137,7 +137,7 @@ def forward(
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         # q, k = self.rotary_emb(positions, q, k)
-        v = v.contiguous()
+        # v = v.contiguous()
         attn_output = self.attn(q, k, v, positions=positions)
         output, _ = self.o_proj(attn_output)
         return output
diff --git a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
index 16fb52ab501c..377d10250918 100644
--- a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
+++ b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
@@ -7,6 +7,7 @@
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import AttentionType
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
     kFp8StaticTensorSym,
@@ -20,6 +21,10 @@
 logger = init_logger(__name__)
 
+if current_platform.is_rocm():
+    from vllm._aiter_ops import rocm_aiter_ops
+    if rocm_aiter_ops.is_enabled():
+        from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache
 
 
 class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
     accept_output_buffer: bool = True
@@ -102,6 +107,7 @@ def forward(
         output: torch.Tensor | None = None,
         output_scale: torch.Tensor | None = None,
         output_block_scale: torch.Tensor | None = None,
+        positions: torch.Tensor | None = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
 
@@ -145,23 +151,62 @@
         # key and value may be None in the case of cross attention. They are
         # calculated once based on the output from the encoder and then cached
         # in KV cache.
-        if (
-            self.kv_sharing_target_layer_name is None
-            and key is not None
-            and value is not None
-        ):
-            # Reshape the input keys and values and store them in the cache.
-            # Skip this if sharing KV cache with an earlier attention layer.
-            ops.reshape_and_cache_flash(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                attn_metadata.slot_mapping,
-                self.kv_cache_dtype,
-                layer._k_scale,
-                layer._v_scale,
+
+        if positions is not None and query.shape[0] <= 256 and rocm_aiter_ops.is_enabled():
+            assert self.kv_sharing_target_layer_name is None
+            cos_sin_cache = self.rotary_emb.cos_sin_cache
+            is_neox = self.rotary_emb.is_neox_style
+            cos, sin = cos_sin_cache.chunk(2, dim=-1)
+            is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
+            if is_fp8_kv_cache:
+                key_cache = key_cache.view(current_platform.fp8_dtype())
+                value_cache = value_cache.view(current_platform.fp8_dtype())
+            query, key, key_cache, value_cache, output = (
+                fused_qk_rope_reshape_and_cache(
+                    query,
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    attn_metadata.slot_mapping,
+                    positions,
+                    cos,
+                    sin,
+                    layer._k_scale,
+                    layer._v_scale,
+                    is_neox,
+                    flash_layout=True,
+                    apply_scale=is_fp8_kv_cache,
+                    offs=None,
+                    q_out=query,
+                    k_out=key,
+                    output_zeros=True,
+                    zeros_out=output,
+                )
             )
+        else:
+            if positions is not None:
+                if current_platform.is_rocm():
+                    query, key = self.rotary_emb.forward_cuda(positions, query, key)
+                else:
+                    query, key = self.rotary_emb(positions, query, key)
+            if (
+                self.kv_sharing_target_layer_name is None
+                and key is not None
+                and value is not None
+            ):
+                # Reshape the input keys and values and store them in the cache.
+                # Skip this if sharing KV cache with an earlier attention layer.
+                ops.reshape_and_cache_flash(
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    attn_metadata.slot_mapping,
+                    self.kv_cache_dtype,
+                    layer._k_scale,
+                    layer._v_scale,
+                )
 
         if self.kv_cache_dtype.startswith("fp8"):
             key_cache = key_cache.view(self.fp8_dtype)