From 19e45a52361d6f3857221b1906410f865b282d4d Mon Sep 17 00:00:00 2001 From: ganyi Date: Tue, 4 Nov 2025 09:09:49 +0000 Subject: [PATCH 1/9] enable mla_asm in sparse_mla backend Signed-off-by: ganyi --- vllm/attention/ops/rocm_aiter_mla_sparse.py | 20 +- vllm/model_executor/models/deepseek_v2.py | 268 +++++++++++++++++- vllm/platforms/rocm.py | 3 - vllm/v1/attention/backends/mla/indexer.py | 12 +- .../backends/mla/rocm_aiter_mla_sparse.py | 147 +++++++++- 5 files changed, 412 insertions(+), 38 deletions(-) diff --git a/vllm/attention/ops/rocm_aiter_mla_sparse.py b/vllm/attention/ops/rocm_aiter_mla_sparse.py index 080e92ecc940..bc19559ae51c 100644 --- a/vllm/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/attention/ops/rocm_aiter_mla_sparse.py @@ -185,25 +185,31 @@ def rocm_fp8_paged_mqa_logits( """ if rocm_aiter_ops.is_enabled(): - from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits_stage1 + batch_size, next_n, heads, head_dim = q_fp8.shape + num_blocks, block_size, _, _ = kv_cache_fp8.shape - batch_size, next_n, heads, _ = q_fp8.shape - out_qk = torch.full( - (heads, batch_size * next_n, max_model_len), + from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits + + out_logits = torch.full( + [batch_size * next_n, max_model_len], float("-inf"), device="cuda", dtype=torch.float32, ) - deepgemm_fp8_paged_mqa_logits_stage1( + deepgemm_fp8_paged_mqa_logits( q_fp8, kv_cache_fp8, weights, - out_qk, + out_logits, context_lens, block_tables, max_model_len, + ChunkK=256, + Preshuffle=block_size == 64, + KVBlockSize=block_size, + WavePerEU=2, ) - return out_qk.sum(dim=0) + return out_logits else: return fp8_paged_mqa_logits_torch( q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 0b6513789aea..339783e187da 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -76,6 +76,7 @@ from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +from vllm.triton_utils import tl, triton from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.mla.indexer import ( @@ -93,10 +94,8 @@ maybe_prefix, ) -if current_platform.is_cuda_alike(): - from vllm import _custom_ops as ops -elif current_platform.is_xpu(): - from vllm._ipex_ops import ipex_ops as ops +if current_platform.is_cuda_alike() or current_platform.is_xpu(): + pass logger = init_logger(__name__) @@ -600,6 +599,187 @@ def get_attn_backend(self) -> AttentionBackend: return DeepseekV32IndexerBackend +@triton.jit +def _indexer_k_quant_and_cache_kernel( + k_ptr, # [num_tokens, head_dim] + kv_cache_ptr, # [n_blks, blk_size//tile_block, head_dim // 16B, tile_block, 16B] + kv_cache_scale_ptr, # [n_blks, blk_size] + slot_mapping_ptr, # [num_tokens] + kv_cache_scale_stride, + kv_cache_value_stride, + block_size, + num_tokens, + head_dim: tl.constexpr, + BLOCK_TILE_SIZE: tl.constexpr, + HEAD_TILE_SIZE: tl.constexpr, + IS_FNUZ: tl.constexpr, + USE_UE8M0: tl.constexpr, +): + tid = tl.program_id(0) + offset = tl.arange(0, head_dim) + tile_offset = ( + offset // HEAD_TILE_SIZE * BLOCK_TILE_SIZE * HEAD_TILE_SIZE + + offset % HEAD_TILE_SIZE + ) + tile_store_offset = tile_offset + # for idx in tl.range(tid, num_tokens, n_program): + src_ptr = k_ptr + tid * head_dim + slot_id = tl.load(slot_mapping_ptr + tid) + if slot_id < 0: + return + block_id = slot_id // block_size + block_offset = slot_id % block_size + tile_block_id = block_offset // BLOCK_TILE_SIZE + tile_block_offset = block_offset % BLOCK_TILE_SIZE + val = tl.load(src_ptr + offset) + amax = tl.max(val.abs(), axis=-1).to(tl.float32) + if IS_FNUZ: + scale = tl.maximum(1e-4, amax) / 224.0 + else: + scale = tl.maximum(1e-4, amax) / 448.0 + + if USE_UE8M0: + scale = tl.exp2(tl.ceil(tl.log2(scale))) + + fp8_val = (val.to(tl.float32) / scale).to(kv_cache_ptr.type.element_ty) + dst_ptr = ( + kv_cache_ptr + + block_id * kv_cache_value_stride + + tile_block_id * BLOCK_TILE_SIZE * head_dim + + tile_block_offset * HEAD_TILE_SIZE + ) + tl.store(dst_ptr + tile_store_offset, fp8_val) + dst_scale_ptr = kv_cache_scale_ptr + block_id * kv_cache_scale_stride + block_offset + tl.store(dst_scale_ptr, scale) + + +def indexer_k_quant_and_cache_triton( + k: torch.Tensor, + kv_cache: torch.Tensor, # [num_blocks, block_size, head_dim + 4] + slot_mapping: torch.Tensor, + scale_fmt, + block_tile_size=16, + head_tile_size=16, +): + num_blocks = kv_cache.shape[0] + head_dim = k.shape[-1] + num_tokens = slot_mapping.shape[0] + block_size = kv_cache.shape[1] + # In real layout, we store the first portion as kv cache value + # and second portion as kv cache scale + kv_cache = kv_cache.view(num_blocks, -1) + kv_cache_value = kv_cache[:, : block_size * head_dim] + kv_cache_scale = kv_cache[:, block_size * head_dim :].view(torch.float32) + head_tile_size = head_tile_size // kv_cache.element_size() + grid = (num_tokens,) + _indexer_k_quant_and_cache_kernel[grid]( + k, + kv_cache_value, + kv_cache_scale, + slot_mapping, + kv_cache_scale.stride(0), + kv_cache_value.stride(0), + block_size, + num_tokens, + head_dim, + block_tile_size, + head_tile_size, + IS_FNUZ=current_platform.fp8_dtype() == torch.float8_e4m3fnuz, + USE_UE8M0=scale_fmt == "ue8m0", + ) + + +@triton.jit +def _cp_gather_indexer_quant_cache_kernel( + kv_cache_ptr, # [n_blks,blk_size//tile_blk,head_dim//16B,tile_blk,16B] + kv_cache_scale_ptr, # [n_blks, blk_size] + k_fp8_ptr, # [num_tokens, head_dim] + k_scale_ptr, # [num_tokens] + block_table_ptr, # [batch_size, block_table_stride] + cu_seqlen_ptr, # [batch_size + 1] + token_to_seq_ptr, # [num_tokens] + block_size, + block_table_stride, + kv_cache_stride, + kv_cache_scale_stride, + HEAD_DIM: tl.constexpr, + BLOCK_TILE_SIZE: tl.constexpr, + HEAD_TILE_SIZE: tl.constexpr, +): + tid = tl.program_id(0) + offset = tl.arange(0, HEAD_DIM) + batch_id = tl.load(token_to_seq_ptr + tid) + batch_start = tl.load(cu_seqlen_ptr + batch_id) + batch_end = tl.load(cu_seqlen_ptr + batch_id + 1) + batch_offset = tid - batch_start + if tid >= batch_end: + return + block_table_id = batch_offset // block_size + block_offset = batch_offset % block_size + block_table_offset = batch_id * block_table_stride + block_table_id + block_id = tl.load(block_table_ptr + block_table_offset) + tiled_block_id = block_offset // BLOCK_TILE_SIZE + tiled_block_offset = block_offset % BLOCK_TILE_SIZE + src_cache_offset = ( + block_id * kv_cache_stride + + tiled_block_id * HEAD_DIM * BLOCK_TILE_SIZE + + tiled_block_offset * HEAD_TILE_SIZE + ) + src_scale_offset = block_id * kv_cache_scale_stride + block_offset + dst_offset = tid * HEAD_DIM + src_scale_ptr = kv_cache_scale_ptr + src_scale_offset + src_cache_ptr = kv_cache_ptr + src_cache_offset + dst_k_ptr = k_fp8_ptr + dst_offset + scale_val = tl.load(src_scale_ptr) + tl.store(k_scale_ptr + tid, scale_val) + tiled_src_offset = ( + offset // HEAD_TILE_SIZE * HEAD_TILE_SIZE * BLOCK_TILE_SIZE + + offset % HEAD_TILE_SIZE + ) + val = tl.load(src_cache_ptr + tiled_src_offset) + tl.store(dst_k_ptr + offset, val) + + +def cp_gather_indexer_k_quant_cache_triton( + k_cache: torch.Tensor, # [num_blocks, block_size, head_dim + 4] + k_fp8: torch.Tensor, + k_fp8_scale: torch.Tensor, + block_table: torch.Tensor, + cu_seqlen: torch.Tensor, + token_to_seq: torch.Tensor, + block_tile_size: int = 16, + head_tile_size: int = 16, +): + num_tokens = k_fp8.size(0) + block_size = k_cache.size(1) + block_table_stride = block_table.stride(0) + head_dim = k_fp8.shape[-1] + num_blocks = k_cache.shape[0] + # we assume the kv cache already been split to 2 portion + k_cache = k_cache.view(num_blocks, -1) + fp8_dtype = current_platform.fp8_dtype() + k_cache_value = k_cache[:, : block_size * head_dim].view(fp8_dtype) + k_cache_scale = k_cache[:, block_size * head_dim :].view(torch.float32) + grid = (num_tokens,) + k_fp8_scale = k_fp8_scale.view(torch.float32) + _cp_gather_indexer_quant_cache_kernel[grid]( + k_cache_value, + k_cache_scale, + k_fp8, + k_fp8_scale, + block_table, + cu_seqlen, + token_to_seq, + block_size, + block_table_stride, + k_cache_value.stride(0), + k_cache_scale.stride(0), + head_dim, + block_tile_size, + head_tile_size, + ) + + def sparse_attn_indexer( hidden_states: torch.Tensor, k_cache_prefix: str, @@ -642,13 +822,26 @@ def sparse_attn_indexer( has_prefill = attn_metadata.num_prefills > 0 num_decode_tokens = attn_metadata.num_decode_tokens - ops.indexer_k_quant_and_cache( - k, - kv_cache, - slot_mapping, - quant_block_size, - scale_fmt, - ) + # ops.indexer_k_quant_and_cache( + # k, + # kv_cache, + # slot_mapping, + # quant_block_size, + # scale_fmt, + # ) + # print("indexer and quant", flush=True) + # import torch.distributed as dist + # rid = dist.get_rank() + # print("k shape: ", k.shape, flush=True) + # torch.save(k, f"k_{rid}.pt") + # torch.save(kv_cache, f"kv_cache_{rid}.pt") + # torch.save(slot_mapping, f"slot_mapping_{rid}.pt") + # torch.cuda.synchronize() + + indexer_k_quant_and_cache_triton(k, kv_cache, slot_mapping, scale_fmt) + + # torch.cuda.synchronize() + # print("end of indexer and quant", flush=True) topk_indices_buffer[: hidden_states.shape[0]] = -1 if has_prefill: @@ -664,13 +857,37 @@ def sparse_attn_indexer( device=k.device, dtype=torch.uint8, ) - ops.cp_gather_indexer_k_quant_cache( + # ops.cp_gather_indexer_k_quant_cache( + # kv_cache, + # k_fp8, + # k_scale, + # chunk.block_table, + # chunk.cu_seq_lens, + # ) + # print("cp gather indexer quant", flush=True) + # import torch.distributed as dist + # rid = dist.get_rank() + # print("kv_cache shape: ", kv_cache.shape, kv_cache.dtype, flush=True) + # print("k_fp8_shape: ", k_fp8.shape, flush=True) + # print("block_table shape: ", chunk.block_table, flush=True) + # torch.save(kv_cache, f"kv_cache_{rid}.pt") + # torch.save(k_fp8, f"k_fp8_{rid}.pt") + # torch.save(k_scale, f"k_scale_{rid}.pt") + # torch.save(chunk.block_table, f"block_table_{rid}.pt") + # torch.save(chunk.cu_seq_lens, f"cu_seq_lens_{rid}.pt") + # torch.save(chunk.token_to_seq, f"token_to_seq_{rid}.pt") + # torch.cuda.synchronize() + # print("done saving", flush=True) + cp_gather_indexer_k_quant_cache_triton( kv_cache, k_fp8, k_scale, chunk.block_table, chunk.cu_seq_lens, + chunk.token_to_seq, ) + # torch.cuda.synchronize() + # print("end of cp gather indexer quant", flush=True) fp8_mqa_logits_func = fp8_mqa_logits if current_platform.is_rocm(): from vllm.attention.ops.rocm_aiter_mla_sparse import rocm_fp8_mqa_logits @@ -683,6 +900,8 @@ def sparse_attn_indexer( chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, ) + # torch.cuda.synchronize() + # print("end of mqa logits: ", flush=True) num_rows = logits.shape[0] topk_indices = topk_indices_buffer[ chunk.token_start : chunk.token_end, :topk_tokens @@ -697,6 +916,8 @@ def sparse_attn_indexer( logits.stride(1), topk_tokens, ) + # torch.cuda.synchronize() + # print("end of topk: ", flush=True) if has_decode: decode_metadata = attn_metadata.decode @@ -728,6 +949,8 @@ def sparse_attn_indexer( ) fp8_paged_mqa_logits_func = rocm_fp8_paged_mqa_logits + # print("start fp8_mqa pa", flush=True) + logits = fp8_paged_mqa_logits_func( padded_q_fp8_decode_tokens, kv_cache, @@ -737,9 +960,24 @@ def sparse_attn_indexer( decode_metadata.schedule_metadata, max_model_len=max_model_len, ) + + # torch.cuda.synchronize() + # print("end of fp8_mqa pa", flush=True) + # _, _, heads, _ = padded_q_fp8_decode_tokens.shape + # logits = torch.empty( + # (batch_size * next_n, max_model_len), + # device="cuda", + # dtype=torch.float32, + # ) num_rows = logits.shape[0] topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens] - + # print("start topk per row decode", flush=True) + # print("next n: ", next_n, flush=True) + # print("seq_lens: ", decode_metadata.seq_lens, flush=True) + # print("num rows: ", num_rows, flush=True) + # torch.save(logits, "logits.pt") + # torch.save(topk_indices, "topk_indices.pt") + # torch.cuda.synchronize() torch.ops._C.top_k_per_row_decode( logits, next_n, @@ -750,6 +988,8 @@ def sparse_attn_indexer( logits.stride(1), topk_tokens, ) + # torch.cuda.synchronize() + # print("end of fp8_mqa pa topk", flush=True) if decode_metadata.requires_padding: # if padded, we need to unpack # the topk indices removing padded tokens @@ -760,6 +1000,8 @@ def sparse_attn_indexer( topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = ( topk_indices ) + # torch.cuda.synchronize() + # print("end of unpack", flush=True) return topk_indices_buffer diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 876114c2d33a..04758091da8c 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -226,9 +226,6 @@ def get_attn_backend_cls( raise ValueError( "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype." ) - assert block_size == 1, ( - "Sparse MLA backend on ROCm only supports block size 1 for now." - ) logger.info_once("Using Sparse MLA backend on V1 engine.") return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path() diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 77f1ba00d5b0..2060e18c40e7 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -11,7 +11,6 @@ ) from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.platforms import current_platform from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, is_deep_gemm_supported from vllm.v1.attention.backends.utils import ( AttentionCGSupport, @@ -24,9 +23,8 @@ class DeepseekV32IndexerBackend(AttentionBackend): - @staticmethod - def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: - return [1 if current_platform.is_rocm() else 64] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64] + @classmethod def get_supported_head_sizes(cls) -> list[int]: @@ -62,6 +60,7 @@ class DeepseekV32IndexerPrefillChunkMetadata: cu_seqlen_ks: torch.Tensor cu_seqlen_ke: torch.Tensor cu_seq_lens: torch.Tensor + token_to_seq: torch.Tensor total_seq_lens: int token_start: int token_end: int @@ -258,6 +257,10 @@ def build_one_prefill_chunk( token_start = query_start_loc_cpu[reqs_start].item() token_end = query_start_loc_cpu[reqs_end].item() total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum() + seq_idx = torch.arange(0, reqs_end - reqs_start, dtype=torch.int32) + token_to_seq = torch.repeat_interleave( + seq_idx, seq_lens_cpu[reqs_start:reqs_end] + ).to(self.device) assert total_seq_lens <= self.max_prefill_buffer_size cu_seq_lens = ( torch.cat( @@ -273,6 +276,7 @@ def build_one_prefill_chunk( cu_seqlen_ks=cu_seqlen_ks, cu_seqlen_ke=cu_seqlen_ke, cu_seq_lens=cu_seq_lens, + token_to_seq=token_to_seq, total_seq_lens=total_seq_lens, block_table=block_table[reqs_start:reqs_end], token_start=token_start, diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py index c0e7f0e380b9..da6f66d827c6 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py @@ -17,6 +17,7 @@ from vllm.attention.backends.utils import get_mla_dims from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.triton_utils import tl, triton from vllm.v1.attention.backends.mla.common import ( MLACommonBaseImpl, ) @@ -35,6 +36,48 @@ logger = init_logger(__name__) +@triton.jit +def fetch_id_to_ragged_kernel( + in_tensor_ptr, # [num_seq, topk] + cumsum_ptr, # [num_seq + 1] + out_tensor_ptr, # [max_num_seq * topk] + in_tensor_ptr_stride, + TOPK: tl.constexpr, + TOKEN_NUM: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + seq_id = tl.program_id(0) + block_id = tl.program_id(1) + offset = tl.arange(0, BLOCK_SIZE) + token_start = tl.load(cumsum_ptr + seq_id) + token_end = tl.load(cumsum_ptr + seq_id + 1) + token_num = token_end - token_start + row_offset = block_id * BLOCK_SIZE + if row_offset >= token_num: + return + in_tensor_offset = seq_id * in_tensor_ptr_stride + row_offset + offset + in_tensor_mask = (row_offset + offset) < TOPK + in_tensor_val = tl.load(in_tensor_ptr + in_tensor_offset, mask=in_tensor_mask) + out_tensor_offset = token_start + row_offset + offset + out_tensor_mask = (out_tensor_offset < token_end) & in_tensor_mask + tl.store(out_tensor_ptr + out_tensor_offset, in_tensor_val, mask=out_tensor_mask) + + +def fetch_id_to_ragged_triton( + in_tensor: torch.Tensor, cumsum: torch.Tensor, out_tensor: torch.Tensor, topk +): + num_tokens = in_tensor.size(0) + block_size = 64 + num_block_per_row = triton.cdiv(topk, block_size) + grid = ( + num_tokens, + num_block_per_row, + ) + fetch_id_to_ragged_kernel[grid]( + in_tensor, cumsum, out_tensor, in_tensor.stride(0), topk, num_tokens, block_size + ) + + class ROCMAiterMLASparseBackend(AttentionBackend): accept_output_buffer: bool = True @@ -85,6 +128,16 @@ class ROCMAiterMLASparseMetadata: block_table: torch.Tensor req_id_per_token: torch.Tensor + + qo_indptr: torch.Tensor + paged_kv_last_page_len: torch.Tensor + paged_kv_indices: torch.Tensor + paged_kv_indptr: torch.Tensor + paged_kv_indptr_rest: torch.Tensor + + # buffer_for_paged_kv_indices: torch.Tensor + # buffer_for_paged_kv_indptr: torch.Tensor + block_size: int = 1 topk_tokens: int = 2048 @@ -93,7 +146,11 @@ class ROCMAiterMLASparseMetadata: class ROCMAiterMLASparseMetadataBuilder( AttentionMetadataBuilder[ROCMAiterMLASparseMetadata] ): - cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER + cudagraph_support: ClassVar[AttentionCGSupport] = ( + AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + ) + buffer_for_paged_kv_indices: torch.Tensor = None + buffer_for_paged_kv_indptr: torch.Tensor = None def __init__( self, @@ -106,6 +163,8 @@ def __init__( self.model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config self.device = device + max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens + max_num_seqs = vllm_config.scheduler_config.max_num_seqs self.num_heads = self.model_config.get_num_attention_heads(parallel_config) self.mla_dims = get_mla_dims(self.model_config) @@ -126,6 +185,32 @@ def __init__( dtype=torch.int32, device=device, ) + self.qo_indptr = torch.arange( + 0, max_num_batched_tokens + 1, dtype=torch.int32, device=device + ) + self.paged_kv_last_page_len = torch.ones( + max_num_seqs, dtype=torch.int32, device=device + ) + + # These two needs to be calculated in runtime, + # but we still needs to prepare the buffer + self.paged_kv_indices = torch.zeros( + [max_num_batched_tokens * self.topk_tokens], + dtype=torch.int32, + device=device, + ) + self.paged_kv_indptr = torch.zeros( + [max_num_seqs + 1], dtype=torch.int32, device=device + ) + if ROCMAiterMLASparseMetadataBuilder.buffer_for_paged_kv_indices is None: + ROCMAiterMLASparseMetadataBuilder.buffer_for_paged_kv_indices = torch.zeros( + [max_num_batched_tokens * self.topk_tokens], + dtype=torch.int32, + device=device, + ) + ROCMAiterMLASparseMetadataBuilder.buffer_for_paged_kv_indptr = torch.zeros( + [max_num_seqs + 1], dtype=torch.int32, device=device + ) def build( self, @@ -144,8 +229,16 @@ def build( self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_( torch.from_numpy(req_id_per_token), non_blocking=True ) - req_id_per_token = self.req_id_per_token_buffer[:num_tokens] + self.paged_kv_indices.fill_(0) + self.paged_kv_indptr.fill_(0) + req_id_per_token = self.req_id_per_token_buffer[:num_tokens] + qo_indptr = self.qo_indptr[: num_tokens + 1] + paged_kv_last_page_len = self.paged_kv_last_page_len[:num_tokens] + paged_kv_indices = self.paged_kv_indices[: num_tokens * self.topk_tokens] + paged_kv_indptr = self.paged_kv_indptr[: num_tokens + 1] + paged_kv_indptr_rest = self.paged_kv_indptr[num_tokens + 1 :] + # print("paged kv indptr shape: ", paged_kv_indptr.shape) metadata = ROCMAiterMLASparseMetadata( num_reqs=common_attn_metadata.num_reqs, max_query_len=common_attn_metadata.max_query_len, @@ -157,6 +250,11 @@ def build( req_id_per_token=req_id_per_token, block_size=self.kv_cache_spec.block_size, topk_tokens=self.topk_tokens, + qo_indptr=qo_indptr, + paged_kv_last_page_len=paged_kv_last_page_len, + paged_kv_indices=paged_kv_indices, + paged_kv_indptr=paged_kv_indptr, + paged_kv_indptr_rest=paged_kv_indptr_rest, ) return metadata @@ -228,20 +326,47 @@ def __init__( def _forward_bf16_kv( self, - q: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - topk_indices: torch.Tensor, + q: torch.Tensor, # [sq, heads, d_qk] + kv_c_and_k_pe_cache: torch.Tensor, # [blocks, heads, d_qk] + topk_indices: torch.Tensor, # [sq, topk] attn_metadata: ROCMAiterMLASparseMetadata, ) -> torch.Tensor: + # print("into sparse mla", flush=True) num_tokens = q.shape[0] - kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view( - -1, 1, kv_c_and_k_pe_cache.shape[-1] + output = torch.empty( + [num_tokens, self.num_heads, self.kv_lora_rank], + dtype=q.dtype, + device=q.device, + ) + # We assume the block size is 1 + # seq_len = (topk_indices == -1).int().argmax(dim=-1) + # torch.cumsum(seq_len, dim=0, out=attn_metadata.paged_kv_indptr[1:]) + seq_len = (topk_indices != -1).sum(dim=-1) + # print("seqlens: ", seq_len, flush=True) + # print("topk shape: ", topk_indices.shape) + torch.cumsum(seq_len, dim=0, out=attn_metadata.paged_kv_indptr[1:]) + attn_metadata.paged_kv_indptr_rest.fill_(attn_metadata.paged_kv_indptr[-1]) + # print("paged kv indptr shape: ", attn_metadata.paged_kv_indptr.shape) + # print("topk indices: ", topk_indices, flush=True) + fetch_id_to_ragged_triton( + topk_indices, + attn_metadata.paged_kv_indptr, + attn_metadata.paged_kv_indices, + attn_metadata.topk_tokens, + ) + + rocm_aiter_ops.mla_decode_fwd( + q, + kv_c_and_k_pe_cache, + output, + self.scale, + attn_metadata.qo_indptr, + 1, + attn_metadata.paged_kv_indptr, + attn_metadata.paged_kv_indices, + attn_metadata.paged_kv_last_page_len, ) - topk_indices = topk_indices.view(num_tokens, 1, -1) - output = reference_mla_sparse_prefill( - q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale, 512 - )[0] return output[:, : self.num_heads, :] def forward( From 80184fbd2c26e32a5db6775e44d35195cb17fffc Mon Sep 17 00:00:00 2001 From: ganyi Date: Mon, 24 Nov 2025 02:21:03 +0000 Subject: [PATCH 2/9] remove unnecessary code and comments Signed-off-by: ganyi --- vllm/model_executor/models/deepseek_v2.py | 52 +------------------ .../backends/mla/rocm_aiter_mla_sparse.py | 24 +-------- 2 files changed, 2 insertions(+), 74 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 339783e187da..a24619bc56ec 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -829,20 +829,9 @@ def sparse_attn_indexer( # quant_block_size, # scale_fmt, # ) - # print("indexer and quant", flush=True) - # import torch.distributed as dist - # rid = dist.get_rank() - # print("k shape: ", k.shape, flush=True) - # torch.save(k, f"k_{rid}.pt") - # torch.save(kv_cache, f"kv_cache_{rid}.pt") - # torch.save(slot_mapping, f"slot_mapping_{rid}.pt") - # torch.cuda.synchronize() indexer_k_quant_and_cache_triton(k, kv_cache, slot_mapping, scale_fmt) - # torch.cuda.synchronize() - # print("end of indexer and quant", flush=True) - topk_indices_buffer[: hidden_states.shape[0]] = -1 if has_prefill: prefill_metadata = attn_metadata.prefill @@ -864,20 +853,6 @@ def sparse_attn_indexer( # chunk.block_table, # chunk.cu_seq_lens, # ) - # print("cp gather indexer quant", flush=True) - # import torch.distributed as dist - # rid = dist.get_rank() - # print("kv_cache shape: ", kv_cache.shape, kv_cache.dtype, flush=True) - # print("k_fp8_shape: ", k_fp8.shape, flush=True) - # print("block_table shape: ", chunk.block_table, flush=True) - # torch.save(kv_cache, f"kv_cache_{rid}.pt") - # torch.save(k_fp8, f"k_fp8_{rid}.pt") - # torch.save(k_scale, f"k_scale_{rid}.pt") - # torch.save(chunk.block_table, f"block_table_{rid}.pt") - # torch.save(chunk.cu_seq_lens, f"cu_seq_lens_{rid}.pt") - # torch.save(chunk.token_to_seq, f"token_to_seq_{rid}.pt") - # torch.cuda.synchronize() - # print("done saving", flush=True) cp_gather_indexer_k_quant_cache_triton( kv_cache, k_fp8, @@ -886,8 +861,6 @@ def sparse_attn_indexer( chunk.cu_seq_lens, chunk.token_to_seq, ) - # torch.cuda.synchronize() - # print("end of cp gather indexer quant", flush=True) fp8_mqa_logits_func = fp8_mqa_logits if current_platform.is_rocm(): from vllm.attention.ops.rocm_aiter_mla_sparse import rocm_fp8_mqa_logits @@ -900,8 +873,6 @@ def sparse_attn_indexer( chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, ) - # torch.cuda.synchronize() - # print("end of mqa logits: ", flush=True) num_rows = logits.shape[0] topk_indices = topk_indices_buffer[ chunk.token_start : chunk.token_end, :topk_tokens @@ -916,8 +887,6 @@ def sparse_attn_indexer( logits.stride(1), topk_tokens, ) - # torch.cuda.synchronize() - # print("end of topk: ", flush=True) if has_decode: decode_metadata = attn_metadata.decode @@ -949,7 +918,6 @@ def sparse_attn_indexer( ) fp8_paged_mqa_logits_func = rocm_fp8_paged_mqa_logits - # print("start fp8_mqa pa", flush=True) logits = fp8_paged_mqa_logits_func( padded_q_fp8_decode_tokens, @@ -961,23 +929,8 @@ def sparse_attn_indexer( max_model_len=max_model_len, ) - # torch.cuda.synchronize() - # print("end of fp8_mqa pa", flush=True) - # _, _, heads, _ = padded_q_fp8_decode_tokens.shape - # logits = torch.empty( - # (batch_size * next_n, max_model_len), - # device="cuda", - # dtype=torch.float32, - # ) num_rows = logits.shape[0] topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens] - # print("start topk per row decode", flush=True) - # print("next n: ", next_n, flush=True) - # print("seq_lens: ", decode_metadata.seq_lens, flush=True) - # print("num rows: ", num_rows, flush=True) - # torch.save(logits, "logits.pt") - # torch.save(topk_indices, "topk_indices.pt") - # torch.cuda.synchronize() torch.ops._C.top_k_per_row_decode( logits, next_n, @@ -988,8 +941,7 @@ def sparse_attn_indexer( logits.stride(1), topk_tokens, ) - # torch.cuda.synchronize() - # print("end of fp8_mqa pa topk", flush=True) + if decode_metadata.requires_padding: # if padded, we need to unpack # the topk indices removing padded tokens @@ -1000,8 +952,6 @@ def sparse_attn_indexer( topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = ( topk_indices ) - # torch.cuda.synchronize() - # print("end of unpack", flush=True) return topk_indices_buffer diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py index da6f66d827c6..2051d092adca 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py @@ -135,9 +135,6 @@ class ROCMAiterMLASparseMetadata: paged_kv_indptr: torch.Tensor paged_kv_indptr_rest: torch.Tensor - # buffer_for_paged_kv_indices: torch.Tensor - # buffer_for_paged_kv_indptr: torch.Tensor - block_size: int = 1 topk_tokens: int = 2048 @@ -149,8 +146,6 @@ class ROCMAiterMLASparseMetadataBuilder( cudagraph_support: ClassVar[AttentionCGSupport] = ( AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE ) - buffer_for_paged_kv_indices: torch.Tensor = None - buffer_for_paged_kv_indptr: torch.Tensor = None def __init__( self, @@ -202,15 +197,6 @@ def __init__( self.paged_kv_indptr = torch.zeros( [max_num_seqs + 1], dtype=torch.int32, device=device ) - if ROCMAiterMLASparseMetadataBuilder.buffer_for_paged_kv_indices is None: - ROCMAiterMLASparseMetadataBuilder.buffer_for_paged_kv_indices = torch.zeros( - [max_num_batched_tokens * self.topk_tokens], - dtype=torch.int32, - device=device, - ) - ROCMAiterMLASparseMetadataBuilder.buffer_for_paged_kv_indptr = torch.zeros( - [max_num_seqs + 1], dtype=torch.int32, device=device - ) def build( self, @@ -238,7 +224,7 @@ def build( paged_kv_indices = self.paged_kv_indices[: num_tokens * self.topk_tokens] paged_kv_indptr = self.paged_kv_indptr[: num_tokens + 1] paged_kv_indptr_rest = self.paged_kv_indptr[num_tokens + 1 :] - # print("paged kv indptr shape: ", paged_kv_indptr.shape) + metadata = ROCMAiterMLASparseMetadata( num_reqs=common_attn_metadata.num_reqs, max_query_len=common_attn_metadata.max_query_len, @@ -331,23 +317,15 @@ def _forward_bf16_kv( topk_indices: torch.Tensor, # [sq, topk] attn_metadata: ROCMAiterMLASparseMetadata, ) -> torch.Tensor: - # print("into sparse mla", flush=True) num_tokens = q.shape[0] output = torch.empty( [num_tokens, self.num_heads, self.kv_lora_rank], dtype=q.dtype, device=q.device, ) - # We assume the block size is 1 - # seq_len = (topk_indices == -1).int().argmax(dim=-1) - # torch.cumsum(seq_len, dim=0, out=attn_metadata.paged_kv_indptr[1:]) seq_len = (topk_indices != -1).sum(dim=-1) - # print("seqlens: ", seq_len, flush=True) - # print("topk shape: ", topk_indices.shape) torch.cumsum(seq_len, dim=0, out=attn_metadata.paged_kv_indptr[1:]) attn_metadata.paged_kv_indptr_rest.fill_(attn_metadata.paged_kv_indptr[-1]) - # print("paged kv indptr shape: ", attn_metadata.paged_kv_indptr.shape) - # print("topk indices: ", topk_indices, flush=True) fetch_id_to_ragged_triton( topk_indices, attn_metadata.paged_kv_indptr, From 1f70773b7fa14e3c59c3de92f98e869c8efffb00 Mon Sep 17 00:00:00 2001 From: ganyi Date: Mon, 24 Nov 2025 02:37:59 +0000 Subject: [PATCH 3/9] move the triton kernel to rocm_aiter_mla_sparse.py Signed-off-by: ganyi --- vllm/attention/ops/rocm_aiter_mla_sparse.py | 182 +++++++++++++++ vllm/model_executor/models/deepseek_v2.py | 233 +++----------------- 2 files changed, 214 insertions(+), 201 deletions(-) diff --git a/vllm/attention/ops/rocm_aiter_mla_sparse.py b/vllm/attention/ops/rocm_aiter_mla_sparse.py index bc19559ae51c..b4c773246809 100644 --- a/vllm/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/attention/ops/rocm_aiter_mla_sparse.py @@ -8,10 +8,192 @@ from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton logger = init_logger(__name__) +@triton.jit +def _indexer_k_quant_and_cache_kernel( + k_ptr, # [num_tokens, head_dim] + kv_cache_ptr, # [n_blks, blk_size//tile_block, head_dim // 16B, tile_block, 16B] + kv_cache_scale_ptr, # [n_blks, blk_size] + slot_mapping_ptr, # [num_tokens] + kv_cache_scale_stride, + kv_cache_value_stride, + block_size, + num_tokens, + head_dim: tl.constexpr, + BLOCK_TILE_SIZE: tl.constexpr, + HEAD_TILE_SIZE: tl.constexpr, + IS_FNUZ: tl.constexpr, + USE_UE8M0: tl.constexpr, +): + tid = tl.program_id(0) + offset = tl.arange(0, head_dim) + tile_offset = ( + offset // HEAD_TILE_SIZE * BLOCK_TILE_SIZE * HEAD_TILE_SIZE + + offset % HEAD_TILE_SIZE + ) + tile_store_offset = tile_offset + # for idx in tl.range(tid, num_tokens, n_program): + src_ptr = k_ptr + tid * head_dim + slot_id = tl.load(slot_mapping_ptr + tid) + if slot_id < 0: + return + block_id = slot_id // block_size + block_offset = slot_id % block_size + tile_block_id = block_offset // BLOCK_TILE_SIZE + tile_block_offset = block_offset % BLOCK_TILE_SIZE + val = tl.load(src_ptr + offset) + amax = tl.max(val.abs(), axis=-1).to(tl.float32) + if IS_FNUZ: + scale = tl.maximum(1e-4, amax) / 224.0 + else: + scale = tl.maximum(1e-4, amax) / 448.0 + + if USE_UE8M0: + scale = tl.exp2(tl.ceil(tl.log2(scale))) + + fp8_val = (val.to(tl.float32) / scale).to(kv_cache_ptr.type.element_ty) + dst_ptr = ( + kv_cache_ptr + + block_id * kv_cache_value_stride + + tile_block_id * BLOCK_TILE_SIZE * head_dim + + tile_block_offset * HEAD_TILE_SIZE + ) + tl.store(dst_ptr + tile_store_offset, fp8_val) + dst_scale_ptr = kv_cache_scale_ptr + block_id * kv_cache_scale_stride + block_offset + tl.store(dst_scale_ptr, scale) + + +def indexer_k_quant_and_cache_triton( + k: torch.Tensor, + kv_cache: torch.Tensor, # [num_blocks, block_size, head_dim + 4] + slot_mapping: torch.Tensor, + scale_fmt, + block_tile_size=16, + head_tile_size=16, +): + num_blocks = kv_cache.shape[0] + head_dim = k.shape[-1] + num_tokens = slot_mapping.shape[0] + block_size = kv_cache.shape[1] + # In real layout, we store the first portion as kv cache value + # and second portion as kv cache scale + kv_cache = kv_cache.view(num_blocks, -1) + kv_cache_value = kv_cache[:, : block_size * head_dim] + kv_cache_scale = kv_cache[:, block_size * head_dim :].view(torch.float32) + head_tile_size = head_tile_size // kv_cache.element_size() + grid = (num_tokens,) + _indexer_k_quant_and_cache_kernel[grid]( + k, + kv_cache_value, + kv_cache_scale, + slot_mapping, + kv_cache_scale.stride(0), + kv_cache_value.stride(0), + block_size, + num_tokens, + head_dim, + block_tile_size, + head_tile_size, + IS_FNUZ=current_platform.fp8_dtype() == torch.float8_e4m3fnuz, + USE_UE8M0=scale_fmt == "ue8m0", + ) + + +@triton.jit +def _cp_gather_indexer_quant_cache_kernel( + kv_cache_ptr, # [n_blks,blk_size//tile_blk,head_dim//16B,tile_blk,16B] + kv_cache_scale_ptr, # [n_blks, blk_size] + k_fp8_ptr, # [num_tokens, head_dim] + k_scale_ptr, # [num_tokens] + block_table_ptr, # [batch_size, block_table_stride] + cu_seqlen_ptr, # [batch_size + 1] + token_to_seq_ptr, # [num_tokens] + block_size, + block_table_stride, + kv_cache_stride, + kv_cache_scale_stride, + HEAD_DIM: tl.constexpr, + BLOCK_TILE_SIZE: tl.constexpr, + HEAD_TILE_SIZE: tl.constexpr, +): + tid = tl.program_id(0) + offset = tl.arange(0, HEAD_DIM) + batch_id = tl.load(token_to_seq_ptr + tid) + batch_start = tl.load(cu_seqlen_ptr + batch_id) + batch_end = tl.load(cu_seqlen_ptr + batch_id + 1) + batch_offset = tid - batch_start + if tid >= batch_end: + return + block_table_id = batch_offset // block_size + block_offset = batch_offset % block_size + block_table_offset = batch_id * block_table_stride + block_table_id + block_id = tl.load(block_table_ptr + block_table_offset) + tiled_block_id = block_offset // BLOCK_TILE_SIZE + tiled_block_offset = block_offset % BLOCK_TILE_SIZE + src_cache_offset = ( + block_id * kv_cache_stride + + tiled_block_id * HEAD_DIM * BLOCK_TILE_SIZE + + tiled_block_offset * HEAD_TILE_SIZE + ) + src_scale_offset = block_id * kv_cache_scale_stride + block_offset + dst_offset = tid * HEAD_DIM + src_scale_ptr = kv_cache_scale_ptr + src_scale_offset + src_cache_ptr = kv_cache_ptr + src_cache_offset + dst_k_ptr = k_fp8_ptr + dst_offset + scale_val = tl.load(src_scale_ptr) + tl.store(k_scale_ptr + tid, scale_val) + tiled_src_offset = ( + offset // HEAD_TILE_SIZE * HEAD_TILE_SIZE * BLOCK_TILE_SIZE + + offset % HEAD_TILE_SIZE + ) + val = tl.load(src_cache_ptr + tiled_src_offset) + tl.store(dst_k_ptr + offset, val) + + +def cp_gather_indexer_k_quant_cache_triton( + k_cache: torch.Tensor, # [num_blocks, block_size, head_dim + 4] + k_fp8: torch.Tensor, + k_fp8_scale: torch.Tensor, + block_table: torch.Tensor, + cu_seqlen: torch.Tensor, + token_to_seq: torch.Tensor, + block_tile_size: int = 16, + head_tile_size: int = 16, +): + num_tokens = k_fp8.size(0) + block_size = k_cache.size(1) + block_table_stride = block_table.stride(0) + head_dim = k_fp8.shape[-1] + num_blocks = k_cache.shape[0] + # we assume the kv cache already been split to 2 portion + k_cache = k_cache.view(num_blocks, -1) + fp8_dtype = current_platform.fp8_dtype() + k_cache_value = k_cache[:, : block_size * head_dim].view(fp8_dtype) + k_cache_scale = k_cache[:, block_size * head_dim :].view(torch.float32) + grid = (num_tokens,) + k_fp8_scale = k_fp8_scale.view(torch.float32) + _cp_gather_indexer_quant_cache_kernel[grid]( + k_cache_value, + k_cache_scale, + k_fp8, + k_fp8_scale, + block_table, + cu_seqlen, + token_to_seq, + block_size, + block_table_stride, + k_cache_value.stride(0), + k_cache_scale.stride(0), + head_dim, + block_tile_size, + head_tile_size, + ) + + # Take from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L84 def fp8_mqa_logits_torch( q: torch.Tensor, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index a24619bc56ec..a057764bbeb4 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -76,7 +76,6 @@ from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.triton_utils import tl, triton from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.mla.indexer import ( @@ -94,8 +93,10 @@ maybe_prefix, ) -if current_platform.is_cuda_alike() or current_platform.is_xpu(): - pass +if current_platform.is_cuda_alike(): + from vllm import _custom_ops as ops +elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops as ops logger = init_logger(__name__) @@ -599,187 +600,6 @@ def get_attn_backend(self) -> AttentionBackend: return DeepseekV32IndexerBackend -@triton.jit -def _indexer_k_quant_and_cache_kernel( - k_ptr, # [num_tokens, head_dim] - kv_cache_ptr, # [n_blks, blk_size//tile_block, head_dim // 16B, tile_block, 16B] - kv_cache_scale_ptr, # [n_blks, blk_size] - slot_mapping_ptr, # [num_tokens] - kv_cache_scale_stride, - kv_cache_value_stride, - block_size, - num_tokens, - head_dim: tl.constexpr, - BLOCK_TILE_SIZE: tl.constexpr, - HEAD_TILE_SIZE: tl.constexpr, - IS_FNUZ: tl.constexpr, - USE_UE8M0: tl.constexpr, -): - tid = tl.program_id(0) - offset = tl.arange(0, head_dim) - tile_offset = ( - offset // HEAD_TILE_SIZE * BLOCK_TILE_SIZE * HEAD_TILE_SIZE - + offset % HEAD_TILE_SIZE - ) - tile_store_offset = tile_offset - # for idx in tl.range(tid, num_tokens, n_program): - src_ptr = k_ptr + tid * head_dim - slot_id = tl.load(slot_mapping_ptr + tid) - if slot_id < 0: - return - block_id = slot_id // block_size - block_offset = slot_id % block_size - tile_block_id = block_offset // BLOCK_TILE_SIZE - tile_block_offset = block_offset % BLOCK_TILE_SIZE - val = tl.load(src_ptr + offset) - amax = tl.max(val.abs(), axis=-1).to(tl.float32) - if IS_FNUZ: - scale = tl.maximum(1e-4, amax) / 224.0 - else: - scale = tl.maximum(1e-4, amax) / 448.0 - - if USE_UE8M0: - scale = tl.exp2(tl.ceil(tl.log2(scale))) - - fp8_val = (val.to(tl.float32) / scale).to(kv_cache_ptr.type.element_ty) - dst_ptr = ( - kv_cache_ptr - + block_id * kv_cache_value_stride - + tile_block_id * BLOCK_TILE_SIZE * head_dim - + tile_block_offset * HEAD_TILE_SIZE - ) - tl.store(dst_ptr + tile_store_offset, fp8_val) - dst_scale_ptr = kv_cache_scale_ptr + block_id * kv_cache_scale_stride + block_offset - tl.store(dst_scale_ptr, scale) - - -def indexer_k_quant_and_cache_triton( - k: torch.Tensor, - kv_cache: torch.Tensor, # [num_blocks, block_size, head_dim + 4] - slot_mapping: torch.Tensor, - scale_fmt, - block_tile_size=16, - head_tile_size=16, -): - num_blocks = kv_cache.shape[0] - head_dim = k.shape[-1] - num_tokens = slot_mapping.shape[0] - block_size = kv_cache.shape[1] - # In real layout, we store the first portion as kv cache value - # and second portion as kv cache scale - kv_cache = kv_cache.view(num_blocks, -1) - kv_cache_value = kv_cache[:, : block_size * head_dim] - kv_cache_scale = kv_cache[:, block_size * head_dim :].view(torch.float32) - head_tile_size = head_tile_size // kv_cache.element_size() - grid = (num_tokens,) - _indexer_k_quant_and_cache_kernel[grid]( - k, - kv_cache_value, - kv_cache_scale, - slot_mapping, - kv_cache_scale.stride(0), - kv_cache_value.stride(0), - block_size, - num_tokens, - head_dim, - block_tile_size, - head_tile_size, - IS_FNUZ=current_platform.fp8_dtype() == torch.float8_e4m3fnuz, - USE_UE8M0=scale_fmt == "ue8m0", - ) - - -@triton.jit -def _cp_gather_indexer_quant_cache_kernel( - kv_cache_ptr, # [n_blks,blk_size//tile_blk,head_dim//16B,tile_blk,16B] - kv_cache_scale_ptr, # [n_blks, blk_size] - k_fp8_ptr, # [num_tokens, head_dim] - k_scale_ptr, # [num_tokens] - block_table_ptr, # [batch_size, block_table_stride] - cu_seqlen_ptr, # [batch_size + 1] - token_to_seq_ptr, # [num_tokens] - block_size, - block_table_stride, - kv_cache_stride, - kv_cache_scale_stride, - HEAD_DIM: tl.constexpr, - BLOCK_TILE_SIZE: tl.constexpr, - HEAD_TILE_SIZE: tl.constexpr, -): - tid = tl.program_id(0) - offset = tl.arange(0, HEAD_DIM) - batch_id = tl.load(token_to_seq_ptr + tid) - batch_start = tl.load(cu_seqlen_ptr + batch_id) - batch_end = tl.load(cu_seqlen_ptr + batch_id + 1) - batch_offset = tid - batch_start - if tid >= batch_end: - return - block_table_id = batch_offset // block_size - block_offset = batch_offset % block_size - block_table_offset = batch_id * block_table_stride + block_table_id - block_id = tl.load(block_table_ptr + block_table_offset) - tiled_block_id = block_offset // BLOCK_TILE_SIZE - tiled_block_offset = block_offset % BLOCK_TILE_SIZE - src_cache_offset = ( - block_id * kv_cache_stride - + tiled_block_id * HEAD_DIM * BLOCK_TILE_SIZE - + tiled_block_offset * HEAD_TILE_SIZE - ) - src_scale_offset = block_id * kv_cache_scale_stride + block_offset - dst_offset = tid * HEAD_DIM - src_scale_ptr = kv_cache_scale_ptr + src_scale_offset - src_cache_ptr = kv_cache_ptr + src_cache_offset - dst_k_ptr = k_fp8_ptr + dst_offset - scale_val = tl.load(src_scale_ptr) - tl.store(k_scale_ptr + tid, scale_val) - tiled_src_offset = ( - offset // HEAD_TILE_SIZE * HEAD_TILE_SIZE * BLOCK_TILE_SIZE - + offset % HEAD_TILE_SIZE - ) - val = tl.load(src_cache_ptr + tiled_src_offset) - tl.store(dst_k_ptr + offset, val) - - -def cp_gather_indexer_k_quant_cache_triton( - k_cache: torch.Tensor, # [num_blocks, block_size, head_dim + 4] - k_fp8: torch.Tensor, - k_fp8_scale: torch.Tensor, - block_table: torch.Tensor, - cu_seqlen: torch.Tensor, - token_to_seq: torch.Tensor, - block_tile_size: int = 16, - head_tile_size: int = 16, -): - num_tokens = k_fp8.size(0) - block_size = k_cache.size(1) - block_table_stride = block_table.stride(0) - head_dim = k_fp8.shape[-1] - num_blocks = k_cache.shape[0] - # we assume the kv cache already been split to 2 portion - k_cache = k_cache.view(num_blocks, -1) - fp8_dtype = current_platform.fp8_dtype() - k_cache_value = k_cache[:, : block_size * head_dim].view(fp8_dtype) - k_cache_scale = k_cache[:, block_size * head_dim :].view(torch.float32) - grid = (num_tokens,) - k_fp8_scale = k_fp8_scale.view(torch.float32) - _cp_gather_indexer_quant_cache_kernel[grid]( - k_cache_value, - k_cache_scale, - k_fp8, - k_fp8_scale, - block_table, - cu_seqlen, - token_to_seq, - block_size, - block_table_stride, - k_cache_value.stride(0), - k_cache_scale.stride(0), - head_dim, - block_tile_size, - head_tile_size, - ) - - def sparse_attn_indexer( hidden_states: torch.Tensor, k_cache_prefix: str, @@ -822,15 +642,21 @@ def sparse_attn_indexer( has_prefill = attn_metadata.num_prefills > 0 num_decode_tokens = attn_metadata.num_decode_tokens - # ops.indexer_k_quant_and_cache( - # k, - # kv_cache, - # slot_mapping, - # quant_block_size, - # scale_fmt, - # ) + indexer_k_quant_cache_and_cache_func = ops.indexer_k_quant_and_cache + if current_platform.is_rocm(): + from vllm.attention.ops.rocm_aiter_mla_sparse import ( + indexer_k_quant_and_cache_triton, + ) - indexer_k_quant_and_cache_triton(k, kv_cache, slot_mapping, scale_fmt) + indexer_k_quant_cache_and_cache_func = indexer_k_quant_and_cache_triton + + indexer_k_quant_cache_and_cache_func( + k, + kv_cache, + slot_mapping, + quant_block_size, + scale_fmt, + ) topk_indices_buffer[: hidden_states.shape[0]] = -1 if has_prefill: @@ -846,21 +672,26 @@ def sparse_attn_indexer( device=k.device, dtype=torch.uint8, ) - # ops.cp_gather_indexer_k_quant_cache( - # kv_cache, - # k_fp8, - # k_scale, - # chunk.block_table, - # chunk.cu_seq_lens, - # ) - cp_gather_indexer_k_quant_cache_triton( + cp_gather_indexer_k_quant_cache_func = ops.cp_gather_indexer_k_quant_cache + if current_platform.is_rocm(): + from functools import partial + + from vllm.attention.ops.rocm_aiter_mla_sparse import ( + cp_gather_indexer_k_quant_cache_triton, + ) + + cp_gather_indexer_k_quant_cache_func = partial( + cp_gather_indexer_k_quant_cache_triton, + token_to_seq=chunk.token_to_seq, + ) + cp_gather_indexer_k_quant_cache_func( kv_cache, k_fp8, k_scale, chunk.block_table, chunk.cu_seq_lens, - chunk.token_to_seq, ) + fp8_mqa_logits_func = fp8_mqa_logits if current_platform.is_rocm(): from vllm.attention.ops.rocm_aiter_mla_sparse import rocm_fp8_mqa_logits From 819202057b1090d3c168e30aa3d291aa712bdf4e Mon Sep 17 00:00:00 2001 From: ganyi Date: Mon, 24 Nov 2025 02:55:29 +0000 Subject: [PATCH 4/9] format the code Signed-off-by: ganyi --- vllm/v1/attention/backends/mla/indexer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 2060e18c40e7..23ea33351068 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -25,7 +25,6 @@ class DeepseekV32IndexerBackend(AttentionBackend): supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64] - @classmethod def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 128] From 0ed3e5b2918d2039dd90895b1f695b842b20d44f Mon Sep 17 00:00:00 2001 From: ganyi Date: Thu, 27 Nov 2025 01:25:15 +0000 Subject: [PATCH 5/9] bug fix Signed-off-by: ganyi --- vllm/attention/ops/rocm_aiter_mla_sparse.py | 1 + vllm/v1/attention/backends/mla/common.py | 4 ++-- vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py | 3 +++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/attention/ops/rocm_aiter_mla_sparse.py b/vllm/attention/ops/rocm_aiter_mla_sparse.py index b4c773246809..c433030e60fe 100644 --- a/vllm/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/attention/ops/rocm_aiter_mla_sparse.py @@ -71,6 +71,7 @@ def indexer_k_quant_and_cache_triton( k: torch.Tensor, kv_cache: torch.Tensor, # [num_blocks, block_size, head_dim + 4] slot_mapping: torch.Tensor, + quant_block_size, scale_fmt, block_tile_size=16, head_tile_size=16, diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 0a5257a1d87d..8f6e7ffaee4e 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -256,8 +256,8 @@ class QueryLenSupport(Enum): is_vllm_fa = True except ImportError: # For rocm use upstream flash attention - if current_platform.is_rocm(): - from flash_attn import flash_attn_varlen_func + # if current_platform.is_rocm(): + # from flash_attn import flash_attn_varlen_func is_vllm_fa = False try: diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py index 2051d092adca..9de524093362 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py @@ -326,6 +326,7 @@ def _forward_bf16_kv( seq_len = (topk_indices != -1).sum(dim=-1) torch.cumsum(seq_len, dim=0, out=attn_metadata.paged_kv_indptr[1:]) attn_metadata.paged_kv_indptr_rest.fill_(attn_metadata.paged_kv_indptr[-1]) + print("paged_kv_indptr:", attn_metadata.paged_kv_indptr) fetch_id_to_ragged_triton( topk_indices, attn_metadata.paged_kv_indptr, @@ -333,6 +334,8 @@ def _forward_bf16_kv( attn_metadata.topk_tokens, ) + print("paged_kv_indices:", attn_metadata.paged_kv_indices) + rocm_aiter_ops.mla_decode_fwd( q, kv_c_and_k_pe_cache, From 1cef4726a84ad2e16829d9ec8cb3138b150afc84 Mon Sep 17 00:00:00 2001 From: ganyi Date: Thu, 27 Nov 2025 03:42:01 +0000 Subject: [PATCH 6/9] remove print Signed-off-by: ganyi --- vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py index 9de524093362..2051d092adca 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py @@ -326,7 +326,6 @@ def _forward_bf16_kv( seq_len = (topk_indices != -1).sum(dim=-1) torch.cumsum(seq_len, dim=0, out=attn_metadata.paged_kv_indptr[1:]) attn_metadata.paged_kv_indptr_rest.fill_(attn_metadata.paged_kv_indptr[-1]) - print("paged_kv_indptr:", attn_metadata.paged_kv_indptr) fetch_id_to_ragged_triton( topk_indices, attn_metadata.paged_kv_indptr, @@ -334,8 +333,6 @@ def _forward_bf16_kv( attn_metadata.topk_tokens, ) - print("paged_kv_indices:", attn_metadata.paged_kv_indices) - rocm_aiter_ops.mla_decode_fwd( q, kv_c_and_k_pe_cache, From 83df375f50975c76683c81fb6a8ec4079a19c0d4 Mon Sep 17 00:00:00 2001 From: ganyi Date: Thu, 11 Dec 2025 06:08:07 +0000 Subject: [PATCH 7/9] refactor the SparseAttnIndexer as CustomOp Signed-off-by: ganyi --- vllm/_aiter_ops.py | 584 ++++++++++++++++++ vllm/attention/ops/rocm_aiter_mla_sparse.py | 399 ------------ vllm/config/compilation.py | 1 + .../layers/sparse_attn_indexer.py | 314 ++++++++++ vllm/model_executor/models/deepseek_v2.py | 261 +------- vllm/platforms/rocm.py | 7 + 6 files changed, 920 insertions(+), 646 deletions(-) delete mode 100644 vllm/attention/ops/rocm_aiter_mla_sparse.py create mode 100644 vllm/model_executor/layers/sparse_attn_indexer.py diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 010817e79a93..8d96a92b4c2b 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -1,13 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools +import importlib from collections.abc import Callable import torch import vllm.envs as envs +from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton +from vllm.forward_context import get_forward_context from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer +from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata _FP8_DTYPE = current_platform.fp8_dtype() @@ -54,6 +59,188 @@ def wrapper(*args, **kwargs): return wrapper +@triton.jit +def _indexer_k_quant_and_cache_kernel( + k_ptr, # [num_tokens, head_dim] + kv_cache_ptr, # [n_blks, blk_size//tile_block, head_dim // 16B, tile_block, 16B] + kv_cache_scale_ptr, # [n_blks, blk_size] + slot_mapping_ptr, # [num_tokens] + kv_cache_scale_stride, + kv_cache_value_stride, + block_size, + num_tokens, + head_dim: tl.constexpr, + BLOCK_TILE_SIZE: tl.constexpr, + HEAD_TILE_SIZE: tl.constexpr, + IS_FNUZ: tl.constexpr, + USE_UE8M0: tl.constexpr, +): + tid = tl.program_id(0) + offset = tl.arange(0, head_dim) + tile_offset = ( + offset // HEAD_TILE_SIZE * BLOCK_TILE_SIZE * HEAD_TILE_SIZE + + offset % HEAD_TILE_SIZE + ) + tile_store_offset = tile_offset + # for idx in tl.range(tid, num_tokens, n_program): + src_ptr = k_ptr + tid * head_dim + slot_id = tl.load(slot_mapping_ptr + tid) + if slot_id < 0: + return + block_id = slot_id // block_size + block_offset = slot_id % block_size + tile_block_id = block_offset // BLOCK_TILE_SIZE + tile_block_offset = block_offset % BLOCK_TILE_SIZE + val = tl.load(src_ptr + offset) + amax = tl.max(val.abs(), axis=-1).to(tl.float32) + if IS_FNUZ: + scale = tl.maximum(1e-4, amax) / 224.0 + else: + scale = tl.maximum(1e-4, amax) / 448.0 + + if USE_UE8M0: + scale = tl.exp2(tl.ceil(tl.log2(scale))) + + fp8_val = (val.to(tl.float32) / scale).to(kv_cache_ptr.type.element_ty) + dst_ptr = ( + kv_cache_ptr + + block_id * kv_cache_value_stride + + tile_block_id * BLOCK_TILE_SIZE * head_dim + + tile_block_offset * HEAD_TILE_SIZE + ) + tl.store(dst_ptr + tile_store_offset, fp8_val) + dst_scale_ptr = kv_cache_scale_ptr + block_id * kv_cache_scale_stride + block_offset + tl.store(dst_scale_ptr, scale) + + +def indexer_k_quant_and_cache_triton( + k: torch.Tensor, + kv_cache: torch.Tensor, # [num_blocks, block_size, head_dim + 4] + slot_mapping: torch.Tensor, + quant_block_size, + scale_fmt, + block_tile_size=16, + head_tile_size=16, +): + num_blocks = kv_cache.shape[0] + head_dim = k.shape[-1] + num_tokens = slot_mapping.shape[0] + block_size = kv_cache.shape[1] + # In real layout, we store the first portion as kv cache value + # and second portion as kv cache scale + kv_cache = kv_cache.view(num_blocks, -1) + kv_cache_value = kv_cache[:, : block_size * head_dim] + kv_cache_scale = kv_cache[:, block_size * head_dim :].view(torch.float32) + head_tile_size = head_tile_size // kv_cache.element_size() + grid = (num_tokens,) + _indexer_k_quant_and_cache_kernel[grid]( + k, + kv_cache_value, + kv_cache_scale, + slot_mapping, + kv_cache_scale.stride(0), + kv_cache_value.stride(0), + block_size, + num_tokens, + head_dim, + block_tile_size, + head_tile_size, + IS_FNUZ=current_platform.fp8_dtype() == torch.float8_e4m3fnuz, + USE_UE8M0=scale_fmt == "ue8m0", + ) + + +@triton.jit +def _cp_gather_indexer_quant_cache_kernel( + kv_cache_ptr, # [n_blks,blk_size//tile_blk,head_dim//16B,tile_blk,16B] + kv_cache_scale_ptr, # [n_blks, blk_size] + k_fp8_ptr, # [num_tokens, head_dim] + k_scale_ptr, # [num_tokens] + block_table_ptr, # [batch_size, block_table_stride] + cu_seqlen_ptr, # [batch_size + 1] + token_to_seq_ptr, # [num_tokens] + block_size, + block_table_stride, + kv_cache_stride, + kv_cache_scale_stride, + HEAD_DIM: tl.constexpr, + BLOCK_TILE_SIZE: tl.constexpr, + HEAD_TILE_SIZE: tl.constexpr, +): + tid = tl.program_id(0) + offset = tl.arange(0, HEAD_DIM) + batch_id = tl.load(token_to_seq_ptr + tid) + batch_start = tl.load(cu_seqlen_ptr + batch_id) + batch_end = tl.load(cu_seqlen_ptr + batch_id + 1) + batch_offset = tid - batch_start + if tid >= batch_end: + return + block_table_id = batch_offset // block_size + block_offset = batch_offset % block_size + block_table_offset = batch_id * block_table_stride + block_table_id + block_id = tl.load(block_table_ptr + block_table_offset) + tiled_block_id = block_offset // BLOCK_TILE_SIZE + tiled_block_offset = block_offset % BLOCK_TILE_SIZE + src_cache_offset = ( + block_id * kv_cache_stride + + tiled_block_id * HEAD_DIM * BLOCK_TILE_SIZE + + tiled_block_offset * HEAD_TILE_SIZE + ) + src_scale_offset = block_id * kv_cache_scale_stride + block_offset + dst_offset = tid * HEAD_DIM + src_scale_ptr = kv_cache_scale_ptr + src_scale_offset + src_cache_ptr = kv_cache_ptr + src_cache_offset + dst_k_ptr = k_fp8_ptr + dst_offset + scale_val = tl.load(src_scale_ptr) + tl.store(k_scale_ptr + tid, scale_val) + tiled_src_offset = ( + offset // HEAD_TILE_SIZE * HEAD_TILE_SIZE * BLOCK_TILE_SIZE + + offset % HEAD_TILE_SIZE + ) + val = tl.load(src_cache_ptr + tiled_src_offset) + tl.store(dst_k_ptr + offset, val) + + +def cp_gather_indexer_k_quant_cache_triton( + k_cache: torch.Tensor, # [num_blocks, block_size, head_dim + 4] + k_fp8: torch.Tensor, + k_fp8_scale: torch.Tensor, + block_table: torch.Tensor, + cu_seqlen: torch.Tensor, + token_to_seq: torch.Tensor, + block_tile_size: int = 16, + head_tile_size: int = 16, +): + num_tokens = k_fp8.size(0) + block_size = k_cache.size(1) + block_table_stride = block_table.stride(0) + head_dim = k_fp8.shape[-1] + num_blocks = k_cache.shape[0] + # we assume the kv cache already been split to 2 portion + k_cache = k_cache.view(num_blocks, -1) + fp8_dtype = current_platform.fp8_dtype() + k_cache_value = k_cache[:, : block_size * head_dim].view(fp8_dtype) + k_cache_scale = k_cache[:, block_size * head_dim :].view(torch.float32) + grid = (num_tokens,) + k_fp8_scale = k_fp8_scale.view(torch.float32) + _cp_gather_indexer_quant_cache_kernel[grid]( + k_cache_value, + k_cache_scale, + k_fp8, + k_fp8_scale, + block_table, + cu_seqlen, + token_to_seq, + block_size, + block_table_stride, + k_cache_value.stride(0), + k_cache_scale.stride(0), + head_dim, + block_tile_size, + head_tile_size, + ) + + def _rocm_aiter_fused_moe_impl( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -637,6 +824,395 @@ def _rocm_aiter_act_mul_and_fp8_group_quant_fake( return x_fp8, out_bs +# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156 +def fp8_paged_mqa_logits_torch( + q: torch.Tensor, + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, +): + from vllm.utils.math_utils import cdiv + + fp8_dtype = current_platform.fp8_dtype() + batch_size, next_n, _, dim = q.size() + kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:] + scale = scale.contiguous().view(torch.float) + q = q.float() + kv_cache = kv_cache.view(fp8_dtype).float() * scale + num_block, block_size, _, dim = kv_cache.size() + logits = torch.full( + [batch_size * next_n, max_model_len], + float("-inf"), + device=q.device, + dtype=torch.float32, + ) + context_lens = context_lens.tolist() + for i in range(batch_size): + context_len = context_lens[i] + q_offsets = torch.arange(context_len - next_n, context_len, device="cuda") + weight_slice = ( + weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous() + ) + for block_rk in range(cdiv(context_len, block_size)): + block_idx = block_tables[i][block_rk] + qx, kx = q[i], kv_cache[block_idx] + k_offsets = torch.arange( + block_rk * block_size, (block_rk + 1) * block_size, device="cuda" + ) + mask = (k_offsets[None, :] < context_len) & ( + k_offsets[None, :] <= q_offsets[:, None] + ) + s = torch.where( + mask[None, :, :], + (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to( + logits.dtype + ), + float("-inf"), + ) + s = torch.relu(s) * weight_slice[..., None] + s = s.sum(dim=0) + logits[ + i * next_n : (i + 1) * next_n, + block_rk * block_size : (block_rk + 1) * block_size, + ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf")) + return logits + + +def rocm_fp8_paged_mqa_logits( + q_fp8: torch.Tensor, + kv_cache_fp8: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + schedule_metadata: torch.Tensor, + max_model_len: int, +) -> torch.Tensor: + """Compute FP8 MQA logits using paged KV-cache. + + Args: + q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to + `torch.float8_e4m3fn` by caller. + kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape + [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last + 4 bytes per (block,pos) store the `float` dequant scale. + weights: Tensor of shape [B * next_n, H], dtype `torch.float32`. + context_lens: Tensor of shape [B], dtype int32; effective context length + for each batch element. + block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical + block indices to physical blocks in the paged cache. + schedule_metadata: Returned by `get_paged_mqa_logits_metadata`; + used to distribute work across SMs. + max_model_len: Maximum sequence length used to size the logits output. + + Returns: + Logits tensor of shape [B * next_n, max_model_len], dtype + `torch.float32`. + """ + + if rocm_aiter_ops.is_enabled(): + batch_size, next_n, heads, head_dim = q_fp8.shape + num_blocks, block_size, _, _ = kv_cache_fp8.shape + + from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits + + out_logits = torch.full( + [batch_size * next_n, max_model_len], + float("-inf"), + device="cuda", + dtype=torch.float32, + ) + deepgemm_fp8_paged_mqa_logits( + q_fp8, + kv_cache_fp8, + weights, + out_logits, + context_lens, + block_tables, + max_model_len, + ChunkK=256, + Preshuffle=block_size == 64, + KVBlockSize=block_size, + WavePerEU=2, + ) + return out_logits + else: + return fp8_paged_mqa_logits_torch( + q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len + ) + + +# Take from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L84 +def fp8_mqa_logits_torch( + q: torch.Tensor, + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +) -> torch.Tensor: + """Compute FP8 MQA logits for a single sequence without KV paging. + + Args: + q: Query tensor of shape [M, H, D]. Casted to + `torch.float8_e4m3fn` by caller. + kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with + dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or + [N, 1]) with dtype `torch.float32`. + weights: weights of shape [M, H], dtype `torch.float32`. + cu_seqlen_ks: Start indices (inclusive) for valid K per query position, + shape [M], dtype int32. + cu_seqlen_ke: End indices (exclusive) for valid K per query position, + shape [M], dtype int32. + + Returns: + Logits tensor of shape [M, N], dtype `torch.float32`. + """ + kv, scale = kv + seq_len_kv = kv.shape[0] + k = kv.to(torch.bfloat16) + q = q.to(torch.bfloat16) + + mask_lo = ( + torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None] + ) + mask_hi = ( + torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None] + ) + mask = mask_lo & mask_hi + + score = torch.einsum("mhd,nd->hmn", q, k).float() * scale + logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) + logits = logits.masked_fill(~mask, float("-inf")) + + return logits + + +def rocm_fp8_mqa_logits( + q: torch.Tensor, + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +) -> torch.Tensor: + """Compute FP8 MQA logits for a single sequence without KV paging. + + Args: + q: Query tensor of shape [M, H, D]. Casted to + `torch.float8_e4m3fn` by caller. + kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with + dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or + [N, 1]) with dtype `torch.float32`. + weights: weights of shape [M, H], dtype `torch.float32`. + cu_seqlen_ks: Start indices (inclusive) for valid K per query position, + shape [M], dtype int32. + cu_seqlen_ke: End indices (exclusive) for valid K per query position, + shape [M], dtype int32. + + Returns: + Logits tensor of shape [M, N], dtype `torch.float32`. + """ + + # TODO(ganyi): Temporarily workaround, will remove the module check and reference + # path after aiter merge this kernel into main + @functools.lru_cache + def has_mqa_logits_module(): + return importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None + + if rocm_aiter_ops.is_enabled() and has_mqa_logits_module(): + from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits + + kv, scale = kv + return fp8_mqa_logits(q, kv, scale, weights, cu_seqlen_ks, cu_seqlen_ke) + else: + return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke) + + +def rocm_aiter_sparse_attn_indexer_fake( + hidden_states: torch.Tensor, + k_cache_prefix: str, + kv_cache: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: str | None, + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: torch.Tensor | None, +) -> torch.Tensor: + # profile run + # NOTE(Chen): create the max possible flattened_kv. So that + # profile_run can get correct memory usage. + _flattened_kv = torch.empty( + [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8 + ) + fp8_dtype = current_platform.fp8_dtype() + _k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous() + _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous() + return topk_indices_buffer + + +def rocm_aiter_sparse_attn_indexer( + hidden_states: torch.Tensor, + k_cache_prefix: str, + kv_cache: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: str | None, + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: torch.Tensor | None, +) -> torch.Tensor: + # careful! this will be None in dummy run + attn_metadata = get_forward_context().attn_metadata + fp8_dtype = current_platform.fp8_dtype() + # assert isinstance(attn_metadata, dict) + if not isinstance(attn_metadata, dict): + return rocm_aiter_sparse_attn_indexer_fake( + hidden_states, + k_cache_prefix, + kv_cache, + q_fp8, + k, + weights, + quant_block_size, + scale_fmt, + topk_tokens, + head_dim, + max_model_len, + total_seq_lens, + topk_indices_buffer, + ) + attn_metadata = attn_metadata[k_cache_prefix] + assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) + slot_mapping = attn_metadata.slot_mapping + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + + indexer_k_quant_and_cache_triton( + k, + kv_cache, + slot_mapping, + quant_block_size, + scale_fmt, + ) + + topk_indices_buffer[: hidden_states.shape[0]] = -1 + if has_prefill: + prefill_metadata = attn_metadata.prefill + for chunk in prefill_metadata.chunks: + k_fp8 = torch.empty( + [chunk.total_seq_lens, head_dim], + device=k.device, + dtype=fp8_dtype, + ) + k_scale = torch.empty( + [chunk.total_seq_lens, 4], + device=k.device, + dtype=torch.uint8, + ) + cp_gather_indexer_k_quant_cache_triton( + kv_cache, + k_fp8, + k_scale, + chunk.block_table, + chunk.cu_seq_lens, + chunk.token_to_seq, + ) + + logits = rocm_fp8_mqa_logits( + q_fp8[chunk.token_start : chunk.token_end], + (k_fp8, k_scale.view(torch.float32)), + weights[chunk.token_start : chunk.token_end], + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + ) + num_rows = logits.shape[0] + assert topk_tokens == 2048, "top_k_per_row assumes size 2048" + topk_indices = topk_indices_buffer[ + chunk.token_start : chunk.token_end, :topk_tokens + ] + torch.ops._C.top_k_per_row_prefill( + logits, + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) + + if has_decode: + decode_metadata = attn_metadata.decode + # kv_cache size requirement [num_block, block_size, n_head, head_dim], + # we only have [num_block, block_size, head_dim], + kv_cache = kv_cache.unsqueeze(-2) + decode_lens = decode_metadata.decode_lens + if decode_metadata.requires_padding: + # pad in edge case where we have short chunked prefill length < + # decode_threshold since we unstrictly split + # prefill and decode by decode_threshold + # (currently set to 1 + speculative tokens) + padded_q_fp8_decode_tokens = pack_seq_triton( + q_fp8[:num_decode_tokens], decode_lens + ) + else: + padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape( + decode_lens.shape[0], -1, *q_fp8.shape[1:] + ) + # TODO: move and optimize below logic with triton kernels + batch_size = padded_q_fp8_decode_tokens.shape[0] + next_n = padded_q_fp8_decode_tokens.shape[1] + assert batch_size == decode_metadata.seq_lens.shape[0] + num_padded_tokens = batch_size * next_n + + logits = rocm_fp8_paged_mqa_logits( + padded_q_fp8_decode_tokens, + kv_cache, + weights[:num_padded_tokens], + decode_metadata.seq_lens, + decode_metadata.block_table, + decode_metadata.schedule_metadata, + max_model_len=max_model_len, + ) + + num_rows = logits.shape[0] + assert topk_tokens == 2048, "top_k_per_row assumes size 2048" + topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens] + torch.ops._C.top_k_per_row_decode( + logits, + next_n, + decode_metadata.seq_lens, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) + + if decode_metadata.requires_padding: + # if padded, we need to unpack + # the topk indices removing padded tokens + topk_indices = unpack_seq_triton( + topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]), + decode_lens, + ) + topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = ( + topk_indices + ) + + return topk_indices_buffer + + # Global flag to ensure ops are registered only once _OPS_REGISTERED = False @@ -862,6 +1438,14 @@ def register_ops_once() -> None: dispatch_key=current_platform.dispatch_key, ) + direct_register_custom_op( + op_name="rocm_aiter_sparse_attn_indexer", + op_func=rocm_aiter_sparse_attn_indexer, + mutates_args=["topk_indices_buffer"], + fake_impl=rocm_aiter_sparse_attn_indexer_fake, + dispatch_key=current_platform.dispatch_key, + ) + _OPS_REGISTERED = True @staticmethod diff --git a/vllm/attention/ops/rocm_aiter_mla_sparse.py b/vllm/attention/ops/rocm_aiter_mla_sparse.py deleted file mode 100644 index c433030e60fe..000000000000 --- a/vllm/attention/ops/rocm_aiter_mla_sparse.py +++ /dev/null @@ -1,399 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import importlib -from functools import lru_cache - -import torch - -from vllm._aiter_ops import rocm_aiter_ops -from vllm.logger import init_logger -from vllm.platforms import current_platform -from vllm.triton_utils import tl, triton - -logger = init_logger(__name__) - - -@triton.jit -def _indexer_k_quant_and_cache_kernel( - k_ptr, # [num_tokens, head_dim] - kv_cache_ptr, # [n_blks, blk_size//tile_block, head_dim // 16B, tile_block, 16B] - kv_cache_scale_ptr, # [n_blks, blk_size] - slot_mapping_ptr, # [num_tokens] - kv_cache_scale_stride, - kv_cache_value_stride, - block_size, - num_tokens, - head_dim: tl.constexpr, - BLOCK_TILE_SIZE: tl.constexpr, - HEAD_TILE_SIZE: tl.constexpr, - IS_FNUZ: tl.constexpr, - USE_UE8M0: tl.constexpr, -): - tid = tl.program_id(0) - offset = tl.arange(0, head_dim) - tile_offset = ( - offset // HEAD_TILE_SIZE * BLOCK_TILE_SIZE * HEAD_TILE_SIZE - + offset % HEAD_TILE_SIZE - ) - tile_store_offset = tile_offset - # for idx in tl.range(tid, num_tokens, n_program): - src_ptr = k_ptr + tid * head_dim - slot_id = tl.load(slot_mapping_ptr + tid) - if slot_id < 0: - return - block_id = slot_id // block_size - block_offset = slot_id % block_size - tile_block_id = block_offset // BLOCK_TILE_SIZE - tile_block_offset = block_offset % BLOCK_TILE_SIZE - val = tl.load(src_ptr + offset) - amax = tl.max(val.abs(), axis=-1).to(tl.float32) - if IS_FNUZ: - scale = tl.maximum(1e-4, amax) / 224.0 - else: - scale = tl.maximum(1e-4, amax) / 448.0 - - if USE_UE8M0: - scale = tl.exp2(tl.ceil(tl.log2(scale))) - - fp8_val = (val.to(tl.float32) / scale).to(kv_cache_ptr.type.element_ty) - dst_ptr = ( - kv_cache_ptr - + block_id * kv_cache_value_stride - + tile_block_id * BLOCK_TILE_SIZE * head_dim - + tile_block_offset * HEAD_TILE_SIZE - ) - tl.store(dst_ptr + tile_store_offset, fp8_val) - dst_scale_ptr = kv_cache_scale_ptr + block_id * kv_cache_scale_stride + block_offset - tl.store(dst_scale_ptr, scale) - - -def indexer_k_quant_and_cache_triton( - k: torch.Tensor, - kv_cache: torch.Tensor, # [num_blocks, block_size, head_dim + 4] - slot_mapping: torch.Tensor, - quant_block_size, - scale_fmt, - block_tile_size=16, - head_tile_size=16, -): - num_blocks = kv_cache.shape[0] - head_dim = k.shape[-1] - num_tokens = slot_mapping.shape[0] - block_size = kv_cache.shape[1] - # In real layout, we store the first portion as kv cache value - # and second portion as kv cache scale - kv_cache = kv_cache.view(num_blocks, -1) - kv_cache_value = kv_cache[:, : block_size * head_dim] - kv_cache_scale = kv_cache[:, block_size * head_dim :].view(torch.float32) - head_tile_size = head_tile_size // kv_cache.element_size() - grid = (num_tokens,) - _indexer_k_quant_and_cache_kernel[grid]( - k, - kv_cache_value, - kv_cache_scale, - slot_mapping, - kv_cache_scale.stride(0), - kv_cache_value.stride(0), - block_size, - num_tokens, - head_dim, - block_tile_size, - head_tile_size, - IS_FNUZ=current_platform.fp8_dtype() == torch.float8_e4m3fnuz, - USE_UE8M0=scale_fmt == "ue8m0", - ) - - -@triton.jit -def _cp_gather_indexer_quant_cache_kernel( - kv_cache_ptr, # [n_blks,blk_size//tile_blk,head_dim//16B,tile_blk,16B] - kv_cache_scale_ptr, # [n_blks, blk_size] - k_fp8_ptr, # [num_tokens, head_dim] - k_scale_ptr, # [num_tokens] - block_table_ptr, # [batch_size, block_table_stride] - cu_seqlen_ptr, # [batch_size + 1] - token_to_seq_ptr, # [num_tokens] - block_size, - block_table_stride, - kv_cache_stride, - kv_cache_scale_stride, - HEAD_DIM: tl.constexpr, - BLOCK_TILE_SIZE: tl.constexpr, - HEAD_TILE_SIZE: tl.constexpr, -): - tid = tl.program_id(0) - offset = tl.arange(0, HEAD_DIM) - batch_id = tl.load(token_to_seq_ptr + tid) - batch_start = tl.load(cu_seqlen_ptr + batch_id) - batch_end = tl.load(cu_seqlen_ptr + batch_id + 1) - batch_offset = tid - batch_start - if tid >= batch_end: - return - block_table_id = batch_offset // block_size - block_offset = batch_offset % block_size - block_table_offset = batch_id * block_table_stride + block_table_id - block_id = tl.load(block_table_ptr + block_table_offset) - tiled_block_id = block_offset // BLOCK_TILE_SIZE - tiled_block_offset = block_offset % BLOCK_TILE_SIZE - src_cache_offset = ( - block_id * kv_cache_stride - + tiled_block_id * HEAD_DIM * BLOCK_TILE_SIZE - + tiled_block_offset * HEAD_TILE_SIZE - ) - src_scale_offset = block_id * kv_cache_scale_stride + block_offset - dst_offset = tid * HEAD_DIM - src_scale_ptr = kv_cache_scale_ptr + src_scale_offset - src_cache_ptr = kv_cache_ptr + src_cache_offset - dst_k_ptr = k_fp8_ptr + dst_offset - scale_val = tl.load(src_scale_ptr) - tl.store(k_scale_ptr + tid, scale_val) - tiled_src_offset = ( - offset // HEAD_TILE_SIZE * HEAD_TILE_SIZE * BLOCK_TILE_SIZE - + offset % HEAD_TILE_SIZE - ) - val = tl.load(src_cache_ptr + tiled_src_offset) - tl.store(dst_k_ptr + offset, val) - - -def cp_gather_indexer_k_quant_cache_triton( - k_cache: torch.Tensor, # [num_blocks, block_size, head_dim + 4] - k_fp8: torch.Tensor, - k_fp8_scale: torch.Tensor, - block_table: torch.Tensor, - cu_seqlen: torch.Tensor, - token_to_seq: torch.Tensor, - block_tile_size: int = 16, - head_tile_size: int = 16, -): - num_tokens = k_fp8.size(0) - block_size = k_cache.size(1) - block_table_stride = block_table.stride(0) - head_dim = k_fp8.shape[-1] - num_blocks = k_cache.shape[0] - # we assume the kv cache already been split to 2 portion - k_cache = k_cache.view(num_blocks, -1) - fp8_dtype = current_platform.fp8_dtype() - k_cache_value = k_cache[:, : block_size * head_dim].view(fp8_dtype) - k_cache_scale = k_cache[:, block_size * head_dim :].view(torch.float32) - grid = (num_tokens,) - k_fp8_scale = k_fp8_scale.view(torch.float32) - _cp_gather_indexer_quant_cache_kernel[grid]( - k_cache_value, - k_cache_scale, - k_fp8, - k_fp8_scale, - block_table, - cu_seqlen, - token_to_seq, - block_size, - block_table_stride, - k_cache_value.stride(0), - k_cache_scale.stride(0), - head_dim, - block_tile_size, - head_tile_size, - ) - - -# Take from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L84 -def fp8_mqa_logits_torch( - q: torch.Tensor, - kv: tuple[torch.Tensor, torch.Tensor], - weights: torch.Tensor, - cu_seqlen_ks: torch.Tensor, - cu_seqlen_ke: torch.Tensor, -) -> torch.Tensor: - """Compute FP8 MQA logits for a single sequence without KV paging. - - Args: - q: Query tensor of shape [M, H, D]. Casted to - `torch.float8_e4m3fn` by caller. - kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with - dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or - [N, 1]) with dtype `torch.float32`. - weights: weights of shape [M, H], dtype `torch.float32`. - cu_seqlen_ks: Start indices (inclusive) for valid K per query position, - shape [M], dtype int32. - cu_seqlen_ke: End indices (exclusive) for valid K per query position, - shape [M], dtype int32. - - Returns: - Logits tensor of shape [M, N], dtype `torch.float32`. - """ - kv, scale = kv - seq_len_kv = kv.shape[0] - k = kv.to(torch.bfloat16) - q = q.to(torch.bfloat16) - - mask_lo = ( - torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None] - ) - mask_hi = ( - torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None] - ) - mask = mask_lo & mask_hi - - score = torch.einsum("mhd,nd->hmn", q, k).float() * scale - logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) - logits = logits.masked_fill(~mask, float("-inf")) - - return logits - - -def rocm_fp8_mqa_logits( - q: torch.Tensor, - kv: tuple[torch.Tensor, torch.Tensor], - weights: torch.Tensor, - cu_seqlen_ks: torch.Tensor, - cu_seqlen_ke: torch.Tensor, -) -> torch.Tensor: - """Compute FP8 MQA logits for a single sequence without KV paging. - - Args: - q: Query tensor of shape [M, H, D]. Casted to - `torch.float8_e4m3fn` by caller. - kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with - dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or - [N, 1]) with dtype `torch.float32`. - weights: weights of shape [M, H], dtype `torch.float32`. - cu_seqlen_ks: Start indices (inclusive) for valid K per query position, - shape [M], dtype int32. - cu_seqlen_ke: End indices (exclusive) for valid K per query position, - shape [M], dtype int32. - - Returns: - Logits tensor of shape [M, N], dtype `torch.float32`. - """ - - # TODO(ganyi): Temporarily workaround, will remove the module check and reference - # path after aiter merge this kernel into main - @lru_cache - def has_mqa_logits_module(): - return importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None - - if rocm_aiter_ops.is_enabled() and has_mqa_logits_module(): - from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits - - kv, scale = kv - return fp8_mqa_logits(q, kv, scale, weights, cu_seqlen_ks, cu_seqlen_ke) - else: - return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke) - - -# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156 -def fp8_paged_mqa_logits_torch( - q: torch.Tensor, - kv_cache: torch.Tensor, - weights: torch.Tensor, - context_lens: torch.Tensor, - block_tables: torch.Tensor, - max_model_len: int, -): - from vllm.utils.math_utils import cdiv - - fp8_dtype = current_platform.fp8_dtype() - batch_size, next_n, _, dim = q.size() - kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:] - scale = scale.contiguous().view(torch.float) - q = q.float() - kv_cache = kv_cache.view(fp8_dtype).float() * scale - num_block, block_size, _, dim = kv_cache.size() - logits = torch.full( - [batch_size * next_n, max_model_len], - float("-inf"), - device=q.device, - dtype=torch.float32, - ) - context_lens = context_lens.tolist() - for i in range(batch_size): - context_len = context_lens[i] - q_offsets = torch.arange(context_len - next_n, context_len, device="cuda") - weight_slice = ( - weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous() - ) - for block_rk in range(cdiv(context_len, block_size)): - block_idx = block_tables[i][block_rk] - qx, kx = q[i], kv_cache[block_idx] - k_offsets = torch.arange( - block_rk * block_size, (block_rk + 1) * block_size, device="cuda" - ) - mask = (k_offsets[None, :] < context_len) & ( - k_offsets[None, :] <= q_offsets[:, None] - ) - s = torch.where( - mask[None, :, :], - (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to( - logits.dtype - ), - float("-inf"), - ) - s = torch.relu(s) * weight_slice[..., None] - s = s.sum(dim=0) - logits[ - i * next_n : (i + 1) * next_n, - block_rk * block_size : (block_rk + 1) * block_size, - ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf")) - return logits - - -def rocm_fp8_paged_mqa_logits( - q_fp8: torch.Tensor, - kv_cache_fp8: torch.Tensor, - weights: torch.Tensor, - context_lens: torch.Tensor, - block_tables: torch.Tensor, - schedule_metadata: torch.Tensor, - max_model_len: int, -) -> torch.Tensor: - """Compute FP8 MQA logits using paged KV-cache. - - Args: - q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to - `torch.float8_e4m3fn` by caller. - kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape - [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last - 4 bytes per (block,pos) store the `float` dequant scale. - weights: Tensor of shape [B * next_n, H], dtype `torch.float32`. - context_lens: Tensor of shape [B], dtype int32; effective context length - for each batch element. - block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical - block indices to physical blocks in the paged cache. - schedule_metadata: Returned by `get_paged_mqa_logits_metadata`; - used to distribute work across SMs. - max_model_len: Maximum sequence length used to size the logits output. - - Returns: - Logits tensor of shape [B * next_n, max_model_len], dtype - `torch.float32`. - """ - - if rocm_aiter_ops.is_enabled(): - batch_size, next_n, heads, head_dim = q_fp8.shape - num_blocks, block_size, _, _ = kv_cache_fp8.shape - - from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits - - out_logits = torch.full( - [batch_size * next_n, max_model_len], - float("-inf"), - device="cuda", - dtype=torch.float32, - ) - deepgemm_fp8_paged_mqa_logits( - q_fp8, - kv_cache_fp8, - weights, - out_logits, - context_lens, - block_tables, - max_model_len, - ChunkK=256, - Preshuffle=block_size == 64, - KVBlockSize=block_size, - WavePerEU=2, - ) - return out_logits - else: - return fp8_paged_mqa_logits_torch( - q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len - ) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 3b6cb8a34360..2fa470978be7 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -611,6 +611,7 @@ class CompilationConfig: "vllm::gdn_attention_core", "vllm::kda_attention", "vllm::sparse_attn_indexer", + "vllm::rocm_aiter_sparse_attn_indexer", ] def compute_hash(self) -> str: diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py new file mode 100644 index 000000000000..b05da89460af --- /dev/null +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -0,0 +1,314 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Custom Sparse Attention Indexer layers.""" + +import torch + +from vllm._aiter_ops import rocm_aiter_ops +from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.custom_op import CustomOp +from vllm.platforms import current_platform +from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits +from vllm.utils.torch_utils import direct_register_custom_op +from vllm.v1.attention.backends.mla.indexer import ( + DeepseekV32IndexerMetadata, +) + +if current_platform.is_cuda_alike(): + from vllm import _custom_ops as ops +elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops as ops + +logger = init_logger(__name__) + + +def sparse_attn_indexer( + hidden_states: torch.Tensor, + k_cache_prefix: str, + kv_cache: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: str | None, + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: torch.Tensor, +) -> torch.Tensor: + # careful! this will be None in dummy run + attn_metadata = get_forward_context().attn_metadata + fp8_dtype = current_platform.fp8_dtype() + print("in cuda path, which is wrong!", flush=True) + # assert isinstance(attn_metadata, dict) + if not isinstance(attn_metadata, dict): + return sparse_attn_indexer_fake( + hidden_states, + k_cache_prefix, + kv_cache, + q_fp8, + k, + weights, + quant_block_size, + scale_fmt, + topk_tokens, + head_dim, + max_model_len, + total_seq_lens, + topk_indices_buffer, + ) + attn_metadata = attn_metadata[k_cache_prefix] + assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) + slot_mapping = attn_metadata.slot_mapping + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + + ops.indexer_k_quant_and_cache( + k, + kv_cache, + slot_mapping, + quant_block_size, + scale_fmt, + ) + + topk_indices_buffer[: hidden_states.shape[0]] = -1 + if has_prefill: + prefill_metadata = attn_metadata.prefill + for chunk in prefill_metadata.chunks: + k_fp8 = torch.empty( + [chunk.total_seq_lens, head_dim], + device=k.device, + dtype=fp8_dtype, + ) + k_scale = torch.empty( + [chunk.total_seq_lens, 4], + device=k.device, + dtype=torch.uint8, + ) + ops.cp_gather_indexer_k_quant_cache( + kv_cache, + k_fp8, + k_scale, + chunk.block_table, + chunk.cu_seq_lens, + ) + + logits = fp8_mqa_logits( + q_fp8[chunk.token_start : chunk.token_end], + (k_fp8, k_scale.view(torch.float32)), + weights[chunk.token_start : chunk.token_end], + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + ) + num_rows = logits.shape[0] + assert topk_tokens == 2048, "top_k_per_row assumes size 2048" + topk_indices = topk_indices_buffer[ + chunk.token_start : chunk.token_end, :topk_tokens + ] + torch.ops._C.top_k_per_row_prefill( + logits, + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) + + if has_decode: + decode_metadata = attn_metadata.decode + # kv_cache size requirement [num_block, block_size, n_head, head_dim], + # we only have [num_block, block_size, head_dim], + kv_cache = kv_cache.unsqueeze(-2) + decode_lens = decode_metadata.decode_lens + if decode_metadata.requires_padding: + # pad in edge case where we have short chunked prefill length < + # decode_threshold since we unstrictly split + # prefill and decode by decode_threshold + # (currently set to 1 + speculative tokens) + padded_q_fp8_decode_tokens = pack_seq_triton( + q_fp8[:num_decode_tokens], decode_lens + ) + else: + padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape( + decode_lens.shape[0], -1, *q_fp8.shape[1:] + ) + # TODO: move and optimize below logic with triton kernels + batch_size = padded_q_fp8_decode_tokens.shape[0] + next_n = padded_q_fp8_decode_tokens.shape[1] + assert batch_size == decode_metadata.seq_lens.shape[0] + num_padded_tokens = batch_size * next_n + + logits = fp8_paged_mqa_logits( + padded_q_fp8_decode_tokens, + kv_cache, + weights[:num_padded_tokens], + decode_metadata.seq_lens, + decode_metadata.block_table, + decode_metadata.schedule_metadata, + max_model_len=max_model_len, + ) + + num_rows = logits.shape[0] + assert topk_tokens == 2048, "top_k_per_row assumes size 2048" + topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens] + torch.ops._C.top_k_per_row_decode( + logits, + next_n, + decode_metadata.seq_lens, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) + + if decode_metadata.requires_padding: + # if padded, we need to unpack + # the topk indices removing padded tokens + topk_indices = unpack_seq_triton( + topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]), + decode_lens, + ) + topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = ( + topk_indices + ) + + return topk_indices_buffer + + +def sparse_attn_indexer_fake( + hidden_states: torch.Tensor, + k_cache_prefix: str, + kv_cache: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: str | None, + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: torch.Tensor | None, +) -> torch.Tensor: + # profile run + # NOTE(Chen): create the max possible flattened_kv. So that + # profile_run can get correct memory usage. + _flattened_kv = torch.empty( + [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8 + ) + fp8_dtype = current_platform.fp8_dtype() + _k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous() + _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous() + return topk_indices_buffer + + +direct_register_custom_op( + op_name="sparse_attn_indexer", + op_func=sparse_attn_indexer, + mutates_args=["topk_indices_buffer"], + fake_impl=sparse_attn_indexer_fake, + dispatch_key=current_platform.dispatch_key, +) + + +@CustomOp.register("sparse_attn_indexer") +class SparseAttnIndexer(CustomOp): + """Sparse Attention Indexer Custom Op Layer. This layer is extracted as a + separate custom op since it involves heavy custom kernels like `mqa_logits`, + `paged_mqa_logits` and `top_k_per_row`, etc. Those kernels maybe requires + specific memory layout or implementation for different hardware backends to + achieve optimal performance. + + For now, the default native path will use CUDA backend path. Other platform + may requires add the corresponding Custom Op name `sparse_attn_indexer` to + `custom_ops` in `CompilationConfig` to enable the platform specific path. + """ + + def __init__( + self, + k_cache, + quant_block_size: int, + scale_fmt: str, + topk_tokens: int, + head_dim: int, + max_model_len: int, + max_total_seq_len: int, + topk_indices_buffer: torch.Tensor, + ): + super().__init__() + self.k_cache = k_cache + self.quant_block_size = quant_block_size + self.scale_fmt = scale_fmt + self.topk_tokens = topk_tokens + self.head_dim = head_dim + self.max_model_len = max_model_len + self.max_total_seq_len = max_total_seq_len + self.topk_indices_buffer = topk_indices_buffer + + def forwrad_native( + self, + hidden_states: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + ): + return self.forward_cuda(hidden_states, q_fp8, k, weights) + + def forward_cuda( + self, + hidden_states: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + ): + return torch.ops.vllm.sparse_attn_indexer( + hidden_states, + self.k_cache.layer_prefix, + self.k_cache.kv_cache[0], + q_fp8, + k, + weights, + self.quant_block_size, + self.scale_fmt, + self.topk_tokens, + self.head_dim, + self.max_model_len, + self.max_total_seq_len, + self.topk_indices_buffer, + ) + + def forward_hip( + self, + hidden_states: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + ): + if rocm_aiter_ops.is_enabled(): + return torch.ops.vllm.rocm_aiter_sparse_attn_indexer( + hidden_states, + self.k_cache.prefix, + self.k_cache.kv_cache[0], + q_fp8, + k, + weights, + self.quant_block_size, + self.scale_fmt, + self.topk_tokens, + self.head_dim, + self.max_model_len, + self.max_total_seq_len, + self.topk_indices_buffer, + ) + else: + raise RuntimeError( + "Sparse attention indexer ROCm custom op requires ROCm " + "Aiter ops to be enabled." + ) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index a057764bbeb4..0f0c8dcd906c 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -35,7 +35,6 @@ from vllm._aiter_ops import rocm_aiter_ops from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.layer import Attention -from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ParallelConfig, VllmConfig, get_current_vllm_config from vllm.distributed import ( @@ -45,7 +44,6 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, ) -from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase @@ -65,6 +63,7 @@ per_token_group_quant_fp8, ) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -76,11 +75,8 @@ from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits -from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.mla.indexer import ( DeepseekV32IndexerBackend, - DeepseekV32IndexerMetadata, ) from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec @@ -93,10 +89,8 @@ maybe_prefix, ) -if current_platform.is_cuda_alike(): - from vllm import _custom_ops as ops -elif current_platform.is_xpu(): - from vllm._ipex_ops import ipex_ops as ops +if current_platform.is_cuda_alike() or current_platform.is_xpu(): + pass logger = init_logger(__name__) @@ -600,229 +594,6 @@ def get_attn_backend(self) -> AttentionBackend: return DeepseekV32IndexerBackend -def sparse_attn_indexer( - hidden_states: torch.Tensor, - k_cache_prefix: str, - kv_cache: torch.Tensor, - q_fp8: torch.Tensor, - k: torch.Tensor, - weights: torch.Tensor, - quant_block_size: int, - scale_fmt: str | None, - topk_tokens: int, - head_dim: int, - max_model_len: int, - total_seq_lens: int, - topk_indices_buffer: torch.Tensor | None, -) -> torch.Tensor: - # careful! this will be None in dummy run - attn_metadata = get_forward_context().attn_metadata - fp8_dtype = current_platform.fp8_dtype() - # assert isinstance(attn_metadata, dict) - if not isinstance(attn_metadata, dict): - return sparse_attn_indexer_fake( - hidden_states, - k_cache_prefix, - kv_cache, - q_fp8, - k, - weights, - quant_block_size, - scale_fmt, - topk_tokens, - head_dim, - max_model_len, - total_seq_lens, - topk_indices_buffer, - ) - attn_metadata = attn_metadata[k_cache_prefix] - assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) - slot_mapping = attn_metadata.slot_mapping - has_decode = attn_metadata.num_decodes > 0 - has_prefill = attn_metadata.num_prefills > 0 - num_decode_tokens = attn_metadata.num_decode_tokens - - indexer_k_quant_cache_and_cache_func = ops.indexer_k_quant_and_cache - if current_platform.is_rocm(): - from vllm.attention.ops.rocm_aiter_mla_sparse import ( - indexer_k_quant_and_cache_triton, - ) - - indexer_k_quant_cache_and_cache_func = indexer_k_quant_and_cache_triton - - indexer_k_quant_cache_and_cache_func( - k, - kv_cache, - slot_mapping, - quant_block_size, - scale_fmt, - ) - - topk_indices_buffer[: hidden_states.shape[0]] = -1 - if has_prefill: - prefill_metadata = attn_metadata.prefill - for chunk in prefill_metadata.chunks: - k_fp8 = torch.empty( - [chunk.total_seq_lens, head_dim], - device=k.device, - dtype=fp8_dtype, - ) - k_scale = torch.empty( - [chunk.total_seq_lens, 4], - device=k.device, - dtype=torch.uint8, - ) - cp_gather_indexer_k_quant_cache_func = ops.cp_gather_indexer_k_quant_cache - if current_platform.is_rocm(): - from functools import partial - - from vllm.attention.ops.rocm_aiter_mla_sparse import ( - cp_gather_indexer_k_quant_cache_triton, - ) - - cp_gather_indexer_k_quant_cache_func = partial( - cp_gather_indexer_k_quant_cache_triton, - token_to_seq=chunk.token_to_seq, - ) - cp_gather_indexer_k_quant_cache_func( - kv_cache, - k_fp8, - k_scale, - chunk.block_table, - chunk.cu_seq_lens, - ) - - fp8_mqa_logits_func = fp8_mqa_logits - if current_platform.is_rocm(): - from vllm.attention.ops.rocm_aiter_mla_sparse import rocm_fp8_mqa_logits - - fp8_mqa_logits_func = rocm_fp8_mqa_logits - logits = fp8_mqa_logits_func( - q_fp8[chunk.token_start : chunk.token_end], - (k_fp8, k_scale.view(torch.float32)), - weights[chunk.token_start : chunk.token_end], - chunk.cu_seqlen_ks, - chunk.cu_seqlen_ke, - ) - num_rows = logits.shape[0] - topk_indices = topk_indices_buffer[ - chunk.token_start : chunk.token_end, :topk_tokens - ] - torch.ops._C.top_k_per_row_prefill( - logits, - chunk.cu_seqlen_ks, - chunk.cu_seqlen_ke, - topk_indices, - num_rows, - logits.stride(0), - logits.stride(1), - topk_tokens, - ) - - if has_decode: - decode_metadata = attn_metadata.decode - # kv_cache size requirement [num_block, block_size, n_head, head_dim], - # we only have [num_block, block_size, head_dim], - kv_cache = kv_cache.unsqueeze(-2) - decode_lens = decode_metadata.decode_lens - if decode_metadata.requires_padding: - # pad in edge case where we have short chunked prefill length < - # decode_threshold since we unstrictly split - # prefill and decode by decode_threshold - # (currently set to 1 + speculative tokens) - padded_q_fp8_decode_tokens = pack_seq_triton( - q_fp8[:num_decode_tokens], decode_lens - ) - else: - padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape( - decode_lens.shape[0], -1, *q_fp8.shape[1:] - ) - # TODO: move and optimize below logic with triton kernels - batch_size = padded_q_fp8_decode_tokens.shape[0] - next_n = padded_q_fp8_decode_tokens.shape[1] - assert batch_size == decode_metadata.seq_lens.shape[0] - num_padded_tokens = batch_size * next_n - fp8_paged_mqa_logits_func = fp8_paged_mqa_logits - if current_platform.is_rocm(): - from vllm.attention.ops.rocm_aiter_mla_sparse import ( - rocm_fp8_paged_mqa_logits, - ) - - fp8_paged_mqa_logits_func = rocm_fp8_paged_mqa_logits - - logits = fp8_paged_mqa_logits_func( - padded_q_fp8_decode_tokens, - kv_cache, - weights[:num_padded_tokens], - decode_metadata.seq_lens, - decode_metadata.block_table, - decode_metadata.schedule_metadata, - max_model_len=max_model_len, - ) - - num_rows = logits.shape[0] - topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens] - torch.ops._C.top_k_per_row_decode( - logits, - next_n, - decode_metadata.seq_lens, - topk_indices, - num_rows, - logits.stride(0), - logits.stride(1), - topk_tokens, - ) - - if decode_metadata.requires_padding: - # if padded, we need to unpack - # the topk indices removing padded tokens - topk_indices = unpack_seq_triton( - topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]), - decode_lens, - ) - topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = ( - topk_indices - ) - - return topk_indices_buffer - - -def sparse_attn_indexer_fake( - hidden_states: torch.Tensor, - k_cache_prefix: str, - kv_cache: torch.Tensor, - q_fp8: torch.Tensor, - k: torch.Tensor, - weights: torch.Tensor, - quant_block_size: int, - scale_fmt: str | None, - topk_tokens: int, - head_dim: int, - max_model_len: int, - total_seq_lens: int, - topk_indices_buffer: torch.Tensor | None, -) -> torch.Tensor: - # profile run - # NOTE(Chen): create the max possible flattened_kv. So that - # profile_run can get correct memory usage. - _flattened_kv = torch.empty( - [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8 - ) - fp8_dtype = current_platform.fp8_dtype() - _k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous() - _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous() - return topk_indices_buffer - - -direct_register_custom_op( - op_name="sparse_attn_indexer", - op_func=sparse_attn_indexer, - mutates_args=["topk_indices_buffer"], - fake_impl=sparse_attn_indexer_fake, - dispatch_key=current_platform.dispatch_key, -) - - class Indexer(nn.Module): def __init__( self, @@ -883,6 +654,16 @@ def __init__( from vllm.v1.attention.backends.mla.indexer import get_max_prefill_buffer_size self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config) + self.indexer_op = SparseAttnIndexer( + self.k_cache, + self.quant_block_size, + self.scale_fmt, + self.topk_tokens, + self.head_dim, + self.max_model_len, + self.max_total_seq_len, + self.topk_indices_buffer, + ) def forward( self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, rotary_emb @@ -920,21 +701,7 @@ def forward( ) weights = weights.squeeze(-1) - return torch.ops.vllm.sparse_attn_indexer( - hidden_states, - self.k_cache.prefix, - self.k_cache.kv_cache[0], - q_fp8, - k, - weights, - self.quant_block_size, - self.scale_fmt, - self.topk_tokens, - self.head_dim, - self.max_model_len, - self.max_total_seq_len, - self.topk_indices_buffer, - ) + return self.indexer_op(hidden_states, q_fp8, k, weights) class DeepseekV2MLAAttention(nn.Module): diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 04758091da8c..1455f625f118 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -377,6 +377,8 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: cache_config = vllm_config.cache_config compilation_config = vllm_config.compilation_config parallel_config = vllm_config.parallel_config + model_config = vllm_config.model_config + hf_config = model_config.hf_config is_eager_execution = compilation_config == CUDAGraphMode.NONE use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled() use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enaled() @@ -429,6 +431,11 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops: compilation_config.custom_ops.append("+quant_fp8") + # Default dispatch to rocm's sparse_attn_indexer implementation + if hf_config is not None and hasattr(hf_config, "index_topk"): + print("add sparse attn indexer to rocm custom ops", flush=True) + compilation_config.custom_ops.append("+sparse_attn_indexer") + @classmethod def verify_model_arch(cls, model_arch: str) -> None: if model_arch in _ROCM_UNSUPPORTED_MODELS: From c3261f2e378b5d2f87c2ef98f9830500bce6a6bf Mon Sep 17 00:00:00 2001 From: ganyi Date: Thu, 11 Dec 2025 06:12:59 +0000 Subject: [PATCH 8/9] remove unnecessary change Signed-off-by: ganyi --- vllm/v1/attention/backends/mla/common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 8f6e7ffaee4e..0a5257a1d87d 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -256,8 +256,8 @@ class QueryLenSupport(Enum): is_vllm_fa = True except ImportError: # For rocm use upstream flash attention - # if current_platform.is_rocm(): - # from flash_attn import flash_attn_varlen_func + if current_platform.is_rocm(): + from flash_attn import flash_attn_varlen_func is_vllm_fa = False try: From df386f47afaf714aa4b92755e64c2d0da61674cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?= Date: Thu, 18 Dec 2025 12:08:02 +0200 Subject: [PATCH 9/9] fix: add missing underscore to _cudagraph_support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ROCMAiterMLASparseMetadataBuilder sets cudagraph_support, but the actual checked name is _cudagraph_support. This causes the default value AttentionCGSupport.NEVER to be used instead of the intended AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE Signed-off-by: Stig-Arne Grönroos --- vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py index 2051d092adca..066f5d5089a9 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py @@ -143,7 +143,7 @@ class ROCMAiterMLASparseMetadata: class ROCMAiterMLASparseMetadataBuilder( AttentionMetadataBuilder[ROCMAiterMLASparseMetadata] ): - cudagraph_support: ClassVar[AttentionCGSupport] = ( + _cudagraph_support: ClassVar[AttentionCGSupport] = ( AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE )