Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm/model_executor/models/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def forward(
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# q, k = self.rotary_emb(positions, q, k)
v = v.contiguous()
# v = v.contiguous()
attn_output = self.attn(q, k, v, positions=positions)
output, _ = self.o_proj(attn_output)
return output
Expand Down
77 changes: 61 additions & 16 deletions vllm/v1/attention/backends/rocm_aiter_unified_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import AttentionType
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
Expand All @@ -20,6 +21,10 @@

logger = init_logger(__name__)

if current_platform.is_rocm():
from vllm._aiter_ops import rocm_aiter_ops
if rocm_aiter_ops.is_enabled():
from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache

class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
accept_output_buffer: bool = True
Expand Down Expand Up @@ -102,6 +107,7 @@
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
positions: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.

Expand Down Expand Up @@ -145,23 +151,62 @@
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,

if positions is not None and query.shape[0] <= 256 and rocm_aiter_ops.is_enabled():

Check failure on line 155 in vllm/v1/attention/backends/rocm_aiter_unified_attn.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/v1/attention/backends/rocm_aiter_unified_attn.py:155:89: E501 Line too long (91 > 88)
assert self.kv_sharing_target_layer_name is None
cos_sin_cache = self.rotary_emb.cos_sin_cache
is_neox = self.rotary_emb.is_neox_style
cos, sin = cos_sin_cache.chunk(2, dim=-1)
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
if is_fp8_kv_cache:
key_cache = key_cache.view(current_platform.fp8_dtype())
value_cache = value_cache.view(current_platform.fp8_dtype())
query, key, key_cache, value_cache, output = (
fused_qk_rope_reshape_and_cache(
query,
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
positions,
cos,
sin,
layer._k_scale,
layer._v_scale,
is_neox,
flash_layout=True,
apply_scale=is_fp8_kv_cache,
offs=None,
q_out=query,
k_out=key,
output_zeros=True,
zeros_out=output,
)
)
else:
if positions is not None:
if current_platform.is_rocm():
query, key = self.rotary_emb.forward_cuda(positions, query, key)
else:
query, key = self.rotary_emb(positions, query, key)
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)

if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype)
Expand Down
Loading