From c1d785cf291d3cd03b8bc718ce54da9a4bd665eb Mon Sep 17 00:00:00 2001 From: ganyi Date: Tue, 4 Nov 2025 09:09:49 +0000 Subject: [PATCH 1/9] enable mla_asm in sparse_mla backend Signed-off-by: ganyi --- vllm/model_executor/models/deepseek_v2.py | 28 ++- vllm/platforms/rocm.py | 5 +- vllm/v1/attention/backends/mla/indexer.py | 11 +- .../backends/mla/rocm_aiter_mla_sparse.py | 123 ++++++++++- .../v1/attention/ops/rocm_aiter_mla_sparse.py | 203 +++++++++++++++++- 5 files changed, 343 insertions(+), 27 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 13bb3cbd0846..928fcb83a54c 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -648,7 +648,15 @@ def sparse_attn_indexer( has_prefill = attn_metadata.num_prefills > 0 num_decode_tokens = attn_metadata.num_decode_tokens - ops.indexer_k_quant_and_cache( + indexer_k_quant_cache_and_cache_func = ops.indexer_k_quant_and_cache + if current_platform.is_rocm(): + from vllm.attention.ops.rocm_aiter_mla_sparse import ( + indexer_k_quant_and_cache_triton, + ) + + indexer_k_quant_cache_and_cache_func = indexer_k_quant_and_cache_triton + + indexer_k_quant_cache_and_cache_func( k, kv_cache, slot_mapping, @@ -670,13 +678,26 @@ def sparse_attn_indexer( for chunk in prefill_metadata.chunks: k_fp8 = k_fp8_full[: chunk.total_seq_lens] k_scale = k_scale_full[: chunk.total_seq_lens] - ops.cp_gather_indexer_k_quant_cache( + cp_gather_indexer_k_quant_cache_func = ops.cp_gather_indexer_k_quant_cache + if current_platform.is_rocm(): + from functools import partial + + from vllm.attention.ops.rocm_aiter_mla_sparse import ( + cp_gather_indexer_k_quant_cache_triton, + ) + + cp_gather_indexer_k_quant_cache_func = partial( + cp_gather_indexer_k_quant_cache_triton, + token_to_seq=chunk.token_to_seq, + ) + cp_gather_indexer_k_quant_cache_func( kv_cache, k_fp8, k_scale, chunk.block_table, chunk.cu_seq_lens, ) + fp8_mqa_logits_func = fp8_mqa_logits if current_platform.is_rocm(): from vllm.v1.attention.ops.rocm_aiter_mla_sparse import ( @@ -743,6 +764,7 @@ def sparse_attn_indexer( ) fp8_paged_mqa_logits_func = rocm_fp8_paged_mqa_logits + logits = fp8_paged_mqa_logits_func( padded_q_fp8_decode_tokens, kv_cache, @@ -752,6 +774,7 @@ def sparse_attn_indexer( decode_metadata.schedule_metadata, max_model_len=max_model_len, ) + num_rows = logits.shape[0] topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens] @@ -765,6 +788,7 @@ def sparse_attn_indexer( logits.stride(1), topk_tokens, ) + if decode_metadata.requires_padding: # if padded, we need to unpack # the topk indices removing padded tokens diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 45e3d50e7159..fbb03fc1d3c6 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -214,10 +214,7 @@ def get_attn_backend_cls( raise ValueError( "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype." ) - assert block_size == 1, ( - "Sparse MLA backend on ROCm only supports block size 1 for now." 
- ) - logger.info_once("Using Sparse MLA backend.") + logger.info_once("Using Sparse MLA backend on V1 engine.") return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path() if attn_selector_config.use_mla: diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 3af785620a7a..791710f19c20 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -7,7 +7,6 @@ from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.platforms import current_platform from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, is_deep_gemm_supported from vllm.v1.attention.backend import ( AttentionBackend, @@ -25,9 +24,7 @@ class DeepseekV32IndexerBackend(AttentionBackend): - @staticmethod - def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: - return [1 if current_platform.is_rocm() else 64] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64] @classmethod def get_supported_head_sizes(cls) -> list[int]: @@ -63,6 +60,7 @@ class DeepseekV32IndexerPrefillChunkMetadata: cu_seqlen_ks: torch.Tensor cu_seqlen_ke: torch.Tensor cu_seq_lens: torch.Tensor + token_to_seq: torch.Tensor total_seq_lens: int token_start: int token_end: int @@ -234,6 +232,10 @@ def build_one_prefill_chunk( token_start = query_start_loc_cpu[reqs_start].item() token_end = query_start_loc_cpu[reqs_end].item() total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum() + seq_idx = torch.arange(0, reqs_end - reqs_start, dtype=torch.int32) + token_to_seq = torch.repeat_interleave( + seq_idx, seq_lens_cpu[reqs_start:reqs_end] + ).to(self.device) assert total_seq_lens <= self.max_prefill_buffer_size cu_seq_lens = ( torch.cat( @@ -249,6 +251,7 @@ def build_one_prefill_chunk( cu_seqlen_ks=cu_seqlen_ks, cu_seqlen_ke=cu_seqlen_ke, cu_seq_lens=cu_seq_lens, + token_to_seq=token_to_seq, total_seq_lens=total_seq_lens, block_table=block_table[reqs_start:reqs_end], token_start=token_start, diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py index 704fc43840e1..3c56dbe378b9 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py @@ -11,6 +11,7 @@ from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.triton_utils import tl, triton from vllm.model_executor.layers.attention.mla_attention import ( MLACommonBaseImpl, get_mla_dims, @@ -33,6 +34,48 @@ logger = init_logger(__name__) +@triton.jit +def fetch_id_to_ragged_kernel( + in_tensor_ptr, # [num_seq, topk] + cumsum_ptr, # [num_seq + 1] + out_tensor_ptr, # [max_num_seq * topk] + in_tensor_ptr_stride, + TOPK: tl.constexpr, + TOKEN_NUM: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + seq_id = tl.program_id(0) + block_id = tl.program_id(1) + offset = tl.arange(0, BLOCK_SIZE) + token_start = tl.load(cumsum_ptr + seq_id) + token_end = tl.load(cumsum_ptr + seq_id + 1) + token_num = token_end - token_start + row_offset = block_id * BLOCK_SIZE + if row_offset >= token_num: + return + in_tensor_offset = seq_id * in_tensor_ptr_stride + row_offset + offset + in_tensor_mask = (row_offset + offset) < TOPK + in_tensor_val = tl.load(in_tensor_ptr + in_tensor_offset, mask=in_tensor_mask) + out_tensor_offset = token_start + row_offset + offset + out_tensor_mask = (out_tensor_offset < token_end) & in_tensor_mask + tl.store(out_tensor_ptr + out_tensor_offset, 
in_tensor_val, mask=out_tensor_mask) + + +def fetch_id_to_ragged_triton( + in_tensor: torch.Tensor, cumsum: torch.Tensor, out_tensor: torch.Tensor, topk +): + num_tokens = in_tensor.size(0) + block_size = 64 + num_block_per_row = triton.cdiv(topk, block_size) + grid = ( + num_tokens, + num_block_per_row, + ) + fetch_id_to_ragged_kernel[grid]( + in_tensor, cumsum, out_tensor, in_tensor.stride(0), topk, num_tokens, block_size + ) + + class ROCMAiterMLASparseBackend(AttentionBackend): accept_output_buffer: bool = True @@ -83,6 +126,13 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata): block_table: torch.Tensor req_id_per_token: torch.Tensor + + qo_indptr: torch.Tensor + paged_kv_last_page_len: torch.Tensor + paged_kv_indices: torch.Tensor + paged_kv_indptr: torch.Tensor + paged_kv_indptr_rest: torch.Tensor + block_size: int = 1 topk_tokens: int = 2048 @@ -91,7 +141,9 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata): class ROCMAiterMLASparseMetadataBuilder( AttentionMetadataBuilder[ROCMAiterMLASparseMetadata] ): - cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER + cudagraph_support: ClassVar[AttentionCGSupport] = ( + AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + ) def __init__( self, @@ -104,6 +156,8 @@ def __init__( self.model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config self.device = device + max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens + max_num_seqs = vllm_config.scheduler_config.max_num_seqs self.num_heads = self.model_config.get_num_attention_heads(parallel_config) self.mla_dims = get_mla_dims(self.model_config) @@ -124,6 +178,23 @@ def __init__( dtype=torch.int32, device=device, ) + self.qo_indptr = torch.arange( + 0, max_num_batched_tokens + 1, dtype=torch.int32, device=device + ) + self.paged_kv_last_page_len = torch.ones( + max_num_seqs, dtype=torch.int32, device=device + ) + + # These two needs to be calculated in runtime, + # but we still needs to prepare the buffer + self.paged_kv_indices = torch.zeros( + [max_num_batched_tokens * self.topk_tokens], + dtype=torch.int32, + device=device, + ) + self.paged_kv_indptr = torch.zeros( + [max_num_seqs + 1], dtype=torch.int32, device=device + ) def build( self, @@ -142,7 +213,15 @@ def build( self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_( torch.from_numpy(req_id_per_token), non_blocking=True ) + self.paged_kv_indices.fill_(0) + self.paged_kv_indptr.fill_(0) + req_id_per_token = self.req_id_per_token_buffer[:num_tokens] + qo_indptr = self.qo_indptr[: num_tokens + 1] + paged_kv_last_page_len = self.paged_kv_last_page_len[:num_tokens] + paged_kv_indices = self.paged_kv_indices[: num_tokens * self.topk_tokens] + paged_kv_indptr = self.paged_kv_indptr[: num_tokens + 1] + paged_kv_indptr_rest = self.paged_kv_indptr[num_tokens + 1 :] metadata = ROCMAiterMLASparseMetadata( num_reqs=common_attn_metadata.num_reqs, @@ -155,6 +234,11 @@ def build( req_id_per_token=req_id_per_token, block_size=self.kv_cache_spec.block_size, topk_tokens=self.topk_tokens, + qo_indptr=qo_indptr, + paged_kv_last_page_len=paged_kv_last_page_len, + paged_kv_indices=paged_kv_indices, + paged_kv_indptr=paged_kv_indptr, + paged_kv_indptr_rest=paged_kv_indptr_rest, ) return metadata @@ -226,20 +310,39 @@ def __init__( def _forward_bf16_kv( self, - q: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - topk_indices: torch.Tensor, + q: torch.Tensor, # [sq, heads, d_qk] + kv_c_and_k_pe_cache: torch.Tensor, # [blocks, heads, d_qk] + topk_indices: 
torch.Tensor, # [sq, topk] attn_metadata: ROCMAiterMLASparseMetadata, ) -> torch.Tensor: num_tokens = q.shape[0] - kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view( - -1, 1, kv_c_and_k_pe_cache.shape[-1] + output = torch.empty( + [num_tokens, self.num_heads, self.kv_lora_rank], + dtype=q.dtype, + device=q.device, + ) + seq_len = (topk_indices != -1).sum(dim=-1) + torch.cumsum(seq_len, dim=0, out=attn_metadata.paged_kv_indptr[1:]) + attn_metadata.paged_kv_indptr_rest.fill_(attn_metadata.paged_kv_indptr[-1]) + fetch_id_to_ragged_triton( + topk_indices, + attn_metadata.paged_kv_indptr, + attn_metadata.paged_kv_indices, + attn_metadata.topk_tokens, + ) + + rocm_aiter_ops.mla_decode_fwd( + q, + kv_c_and_k_pe_cache, + output, + self.scale, + attn_metadata.qo_indptr, + 1, + attn_metadata.paged_kv_indptr, + attn_metadata.paged_kv_indices, + attn_metadata.paged_kv_last_page_len, ) - topk_indices = topk_indices.view(num_tokens, 1, -1) - output = reference_mla_sparse_prefill( - q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale, 512 - )[0] return output[:, : self.num_heads, :] def forward( diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 1e89d48dbbb6..c6bad7e89644 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -8,10 +8,193 @@ from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton logger = init_logger(__name__) +@triton.jit +def _indexer_k_quant_and_cache_kernel( + k_ptr, # [num_tokens, head_dim] + kv_cache_ptr, # [n_blks, blk_size//tile_block, head_dim // 16B, tile_block, 16B] + kv_cache_scale_ptr, # [n_blks, blk_size] + slot_mapping_ptr, # [num_tokens] + kv_cache_scale_stride, + kv_cache_value_stride, + block_size, + num_tokens, + head_dim: tl.constexpr, + BLOCK_TILE_SIZE: tl.constexpr, + HEAD_TILE_SIZE: tl.constexpr, + IS_FNUZ: tl.constexpr, + USE_UE8M0: tl.constexpr, +): + tid = tl.program_id(0) + offset = tl.arange(0, head_dim) + tile_offset = ( + offset // HEAD_TILE_SIZE * BLOCK_TILE_SIZE * HEAD_TILE_SIZE + + offset % HEAD_TILE_SIZE + ) + tile_store_offset = tile_offset + # for idx in tl.range(tid, num_tokens, n_program): + src_ptr = k_ptr + tid * head_dim + slot_id = tl.load(slot_mapping_ptr + tid) + if slot_id < 0: + return + block_id = slot_id // block_size + block_offset = slot_id % block_size + tile_block_id = block_offset // BLOCK_TILE_SIZE + tile_block_offset = block_offset % BLOCK_TILE_SIZE + val = tl.load(src_ptr + offset) + amax = tl.max(val.abs(), axis=-1).to(tl.float32) + if IS_FNUZ: + scale = tl.maximum(1e-4, amax) / 224.0 + else: + scale = tl.maximum(1e-4, amax) / 448.0 + + if USE_UE8M0: + scale = tl.exp2(tl.ceil(tl.log2(scale))) + + fp8_val = (val.to(tl.float32) / scale).to(kv_cache_ptr.type.element_ty) + dst_ptr = ( + kv_cache_ptr + + block_id * kv_cache_value_stride + + tile_block_id * BLOCK_TILE_SIZE * head_dim + + tile_block_offset * HEAD_TILE_SIZE + ) + tl.store(dst_ptr + tile_store_offset, fp8_val) + dst_scale_ptr = kv_cache_scale_ptr + block_id * kv_cache_scale_stride + block_offset + tl.store(dst_scale_ptr, scale) + + +def indexer_k_quant_and_cache_triton( + k: torch.Tensor, + kv_cache: torch.Tensor, # [num_blocks, block_size, head_dim + 4] + slot_mapping: torch.Tensor, + quant_block_size, + scale_fmt, + block_tile_size=16, + head_tile_size=16, +): + num_blocks = kv_cache.shape[0] + head_dim = k.shape[-1] + num_tokens = 
slot_mapping.shape[0] + block_size = kv_cache.shape[1] + # In real layout, we store the first portion as kv cache value + # and second portion as kv cache scale + kv_cache = kv_cache.view(num_blocks, -1) + kv_cache_value = kv_cache[:, : block_size * head_dim] + kv_cache_scale = kv_cache[:, block_size * head_dim :].view(torch.float32) + head_tile_size = head_tile_size // kv_cache.element_size() + grid = (num_tokens,) + _indexer_k_quant_and_cache_kernel[grid]( + k, + kv_cache_value, + kv_cache_scale, + slot_mapping, + kv_cache_scale.stride(0), + kv_cache_value.stride(0), + block_size, + num_tokens, + head_dim, + block_tile_size, + head_tile_size, + IS_FNUZ=current_platform.fp8_dtype() == torch.float8_e4m3fnuz, + USE_UE8M0=scale_fmt == "ue8m0", + ) + + +@triton.jit +def _cp_gather_indexer_quant_cache_kernel( + kv_cache_ptr, # [n_blks,blk_size//tile_blk,head_dim//16B,tile_blk,16B] + kv_cache_scale_ptr, # [n_blks, blk_size] + k_fp8_ptr, # [num_tokens, head_dim] + k_scale_ptr, # [num_tokens] + block_table_ptr, # [batch_size, block_table_stride] + cu_seqlen_ptr, # [batch_size + 1] + token_to_seq_ptr, # [num_tokens] + block_size, + block_table_stride, + kv_cache_stride, + kv_cache_scale_stride, + HEAD_DIM: tl.constexpr, + BLOCK_TILE_SIZE: tl.constexpr, + HEAD_TILE_SIZE: tl.constexpr, +): + tid = tl.program_id(0) + offset = tl.arange(0, HEAD_DIM) + batch_id = tl.load(token_to_seq_ptr + tid) + batch_start = tl.load(cu_seqlen_ptr + batch_id) + batch_end = tl.load(cu_seqlen_ptr + batch_id + 1) + batch_offset = tid - batch_start + if tid >= batch_end: + return + block_table_id = batch_offset // block_size + block_offset = batch_offset % block_size + block_table_offset = batch_id * block_table_stride + block_table_id + block_id = tl.load(block_table_ptr + block_table_offset) + tiled_block_id = block_offset // BLOCK_TILE_SIZE + tiled_block_offset = block_offset % BLOCK_TILE_SIZE + src_cache_offset = ( + block_id * kv_cache_stride + + tiled_block_id * HEAD_DIM * BLOCK_TILE_SIZE + + tiled_block_offset * HEAD_TILE_SIZE + ) + src_scale_offset = block_id * kv_cache_scale_stride + block_offset + dst_offset = tid * HEAD_DIM + src_scale_ptr = kv_cache_scale_ptr + src_scale_offset + src_cache_ptr = kv_cache_ptr + src_cache_offset + dst_k_ptr = k_fp8_ptr + dst_offset + scale_val = tl.load(src_scale_ptr) + tl.store(k_scale_ptr + tid, scale_val) + tiled_src_offset = ( + offset // HEAD_TILE_SIZE * HEAD_TILE_SIZE * BLOCK_TILE_SIZE + + offset % HEAD_TILE_SIZE + ) + val = tl.load(src_cache_ptr + tiled_src_offset) + tl.store(dst_k_ptr + offset, val) + + +def cp_gather_indexer_k_quant_cache_triton( + k_cache: torch.Tensor, # [num_blocks, block_size, head_dim + 4] + k_fp8: torch.Tensor, + k_fp8_scale: torch.Tensor, + block_table: torch.Tensor, + cu_seqlen: torch.Tensor, + token_to_seq: torch.Tensor, + block_tile_size: int = 16, + head_tile_size: int = 16, +): + num_tokens = k_fp8.size(0) + block_size = k_cache.size(1) + block_table_stride = block_table.stride(0) + head_dim = k_fp8.shape[-1] + num_blocks = k_cache.shape[0] + # we assume the kv cache already been split to 2 portion + k_cache = k_cache.view(num_blocks, -1) + fp8_dtype = current_platform.fp8_dtype() + k_cache_value = k_cache[:, : block_size * head_dim].view(fp8_dtype) + k_cache_scale = k_cache[:, block_size * head_dim :].view(torch.float32) + grid = (num_tokens,) + k_fp8_scale = k_fp8_scale.view(torch.float32) + _cp_gather_indexer_quant_cache_kernel[grid]( + k_cache_value, + k_cache_scale, + k_fp8, + k_fp8_scale, + block_table, + cu_seqlen, + token_to_seq, + 
block_size, + block_table_stride, + k_cache_value.stride(0), + k_cache_scale.stride(0), + head_dim, + block_tile_size, + head_tile_size, + ) + + # Take from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L84 def fp8_mqa_logits_torch( q: torch.Tensor, @@ -185,25 +368,31 @@ def rocm_fp8_paged_mqa_logits( """ if rocm_aiter_ops.is_enabled(): - from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits_stage1 + batch_size, next_n, heads, head_dim = q_fp8.shape + num_blocks, block_size, _, _ = kv_cache_fp8.shape + + from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits - batch_size, next_n, heads, _ = q_fp8.shape - out_qk = torch.full( - (heads, batch_size * next_n, max_model_len), + out_logits = torch.full( + [batch_size * next_n, max_model_len], float("-inf"), device="cuda", dtype=torch.float32, ) - deepgemm_fp8_paged_mqa_logits_stage1( + deepgemm_fp8_paged_mqa_logits( q_fp8, kv_cache_fp8, weights, - out_qk, + out_logits, context_lens, block_tables, max_model_len, + ChunkK=256, + Preshuffle=block_size == 64, + KVBlockSize=block_size, + WavePerEU=2, ) - return out_qk.sum(dim=0) + return out_logits else: return fp8_paged_mqa_logits_torch( q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len From 57b0f87658a4d76dd706bc02b86a17465a44bc1c Mon Sep 17 00:00:00 2001 From: ganyi Date: Thu, 11 Dec 2025 06:08:07 +0000 Subject: [PATCH 2/9] refactor the SparseAttnIndexer as CustomOp Signed-off-by: ganyi --- vllm/_aiter_ops.py | 12 + vllm/config/compilation.py | 1 + .../layers/sparse_attn_indexer.py | 310 +++++++++++++++ vllm/model_executor/models/deepseek_v2.py | 271 +------------ vllm/platforms/rocm.py | 3 + .../backends/mla/rocm_aiter_mla_sparse.py | 2 +- .../v1/attention/ops/rocm_aiter_mla_sparse.py | 368 +++++++++++++----- 7 files changed, 620 insertions(+), 347 deletions(-) create mode 100644 vllm/model_executor/layers/sparse_attn_indexer.py diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 2e0c4a69c82f..9ff10a42fdd9 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -7,6 +7,10 @@ from torch._ops import OpOverload import vllm.envs as envs +from vllm.attention.ops.rocm_aiter_mla_sparse import ( + rocm_aiter_sparse_attn_indexer, + rocm_aiter_sparse_attn_indexer_fake, +) from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer @@ -1091,6 +1095,14 @@ def register_ops_once() -> None: dispatch_key=current_platform.dispatch_key, ) + direct_register_custom_op( + op_name="rocm_aiter_sparse_attn_indexer", + op_func=rocm_aiter_sparse_attn_indexer, + mutates_args=["topk_indices_buffer"], + fake_impl=rocm_aiter_sparse_attn_indexer_fake, + dispatch_key=current_platform.dispatch_key, + ) + _OPS_REGISTERED = True @staticmethod diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 035aa24e33c7..fbc372ecb303 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -620,6 +620,7 @@ class CompilationConfig: "vllm::gdn_attention_core", "vllm::kda_attention", "vllm::sparse_attn_indexer", + "vllm::rocm_aiter_sparse_attn_indexer", ] def compute_hash(self) -> str: diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py new file mode 100644 index 000000000000..0274a7ba3785 --- /dev/null +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -0,0 +1,310 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM 
project +"""Custom Sparse Attention Indexer layers.""" + +import torch + +from vllm._aiter_ops import rocm_aiter_ops +from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.custom_op import CustomOp +from vllm.platforms import current_platform +from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits +from vllm.utils.torch_utils import direct_register_custom_op +from vllm.v1.attention.backends.mla.indexer import ( + DeepseekV32IndexerMetadata, +) +from vllm.v1.worker.workspace import current_workspace_manager + +if current_platform.is_cuda_alike(): + from vllm import _custom_ops as ops +elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops as ops + +logger = init_logger(__name__) + + +def sparse_attn_indexer( + hidden_states: torch.Tensor, + k_cache_prefix: str, + kv_cache: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: str | None, + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: torch.Tensor, +) -> torch.Tensor: + # careful! this will be None in dummy run + attn_metadata = get_forward_context().attn_metadata + fp8_dtype = current_platform.fp8_dtype() + + # assert isinstance(attn_metadata, dict) + if not isinstance(attn_metadata, dict): + # Reserve workspace for indexer during profiling run + current_workspace_manager().get_simultaneous( + ((total_seq_lens, head_dim), torch.float8_e4m3fn), + ((total_seq_lens, 4), torch.uint8), + ) + return sparse_attn_indexer_fake( + hidden_states, + k_cache_prefix, + kv_cache, + q_fp8, + k, + weights, + quant_block_size, + scale_fmt, + topk_tokens, + head_dim, + max_model_len, + total_seq_lens, + topk_indices_buffer, + ) + attn_metadata = attn_metadata[k_cache_prefix] + assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) + slot_mapping = attn_metadata.slot_mapping + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + + ops.indexer_k_quant_and_cache( + k, + kv_cache, + slot_mapping, + quant_block_size, + scale_fmt, + ) + + topk_indices_buffer[: hidden_states.shape[0]] = -1 + if has_prefill: + prefill_metadata = attn_metadata.prefill + + # Get the full shared workspace buffers once (will allocate on first use) + workspace_manager = current_workspace_manager() + k_fp8_full, k_scale_full = workspace_manager.get_simultaneous( + ((total_seq_lens, head_dim), fp8_dtype), + ((total_seq_lens, 4), torch.uint8), + ) + for chunk in prefill_metadata.chunks: + k_fp8 = k_fp8_full[: chunk.total_seq_lens] + k_scale = k_scale_full[: chunk.total_seq_lens] + ops.cp_gather_indexer_k_quant_cache( + kv_cache, + k_fp8, + k_scale, + chunk.block_table, + chunk.cu_seq_lens, + ) + + logits = fp8_mqa_logits( + q_fp8[chunk.token_start : chunk.token_end], + (k_fp8, k_scale.view(torch.float32)), + weights[chunk.token_start : chunk.token_end], + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + ) + num_rows = logits.shape[0] + + topk_indices = topk_indices_buffer[ + chunk.token_start : chunk.token_end, :topk_tokens + ] + torch.ops._C.top_k_per_row_prefill( + logits, + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) + + if has_decode: + decode_metadata = attn_metadata.decode + # kv_cache size requirement [num_block, block_size, 
n_head, head_dim], + # we only have [num_block, block_size, head_dim], + kv_cache = kv_cache.unsqueeze(-2) + decode_lens = decode_metadata.decode_lens + if decode_metadata.requires_padding: + # pad in edge case where we have short chunked prefill length < + # decode_threshold since we unstrictly split + # prefill and decode by decode_threshold + # (currently set to 1 + speculative tokens) + padded_q_fp8_decode_tokens = pack_seq_triton( + q_fp8[:num_decode_tokens], decode_lens + ) + else: + padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape( + decode_lens.shape[0], -1, *q_fp8.shape[1:] + ) + # TODO: move and optimize below logic with triton kernels + batch_size = padded_q_fp8_decode_tokens.shape[0] + next_n = padded_q_fp8_decode_tokens.shape[1] + assert batch_size == decode_metadata.seq_lens.shape[0] + num_padded_tokens = batch_size * next_n + + logits = fp8_paged_mqa_logits( + padded_q_fp8_decode_tokens, + kv_cache, + weights[:num_padded_tokens], + decode_metadata.seq_lens, + decode_metadata.block_table, + decode_metadata.schedule_metadata, + max_model_len=max_model_len, + ) + + num_rows = logits.shape[0] + + topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens] + torch.ops._C.top_k_per_row_decode( + logits, + next_n, + decode_metadata.seq_lens, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) + + if decode_metadata.requires_padding: + # if padded, we need to unpack + # the topk indices removing padded tokens + topk_indices = unpack_seq_triton( + topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]), + decode_lens, + ) + topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = ( + topk_indices + ) + + return topk_indices_buffer + + +def sparse_attn_indexer_fake( + hidden_states: torch.Tensor, + k_cache_prefix: str, + kv_cache: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: str | None, + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: torch.Tensor | None, +) -> torch.Tensor: + return topk_indices_buffer + + +direct_register_custom_op( + op_name="sparse_attn_indexer", + op_func=sparse_attn_indexer, + mutates_args=["topk_indices_buffer"], + fake_impl=sparse_attn_indexer_fake, + dispatch_key=current_platform.dispatch_key, +) + + +@CustomOp.register("sparse_attn_indexer") +class SparseAttnIndexer(CustomOp): + """Sparse Attention Indexer Custom Op Layer. This layer is extracted as a + separate custom op since it involves heavy custom kernels like `mqa_logits`, + `paged_mqa_logits` and `top_k_per_row`, etc. Those kernels maybe requires + specific memory layout or implementation for different hardware backends to + achieve optimal performance. + + For now, the default native path will use CUDA backend path. Other platform + may requires add the corresponding Custom Op name `sparse_attn_indexer` to + `custom_ops` in `CompilationConfig` to enable the platform specific path. 
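+
+    Illustrative usage, mirroring how `Indexer` wires this op up in
+    `deepseek_v2.py` later in this patch (the constructor arguments are the
+    indexer module's own state):
+
+        self.indexer_op = SparseAttnIndexer(
+            self.k_cache, self.quant_block_size, self.scale_fmt,
+            self.topk_tokens, self.head_dim, self.max_model_len,
+            self.max_total_seq_len, self.topk_indices_buffer,
+        )
+        topk_indices = self.indexer_op(hidden_states, q_fp8, k, weights)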
+ """ + + def __init__( + self, + k_cache, + quant_block_size: int, + scale_fmt: str, + topk_tokens: int, + head_dim: int, + max_model_len: int, + max_total_seq_len: int, + topk_indices_buffer: torch.Tensor, + ): + super().__init__() + self.k_cache = k_cache + self.quant_block_size = quant_block_size + self.scale_fmt = scale_fmt + self.topk_tokens = topk_tokens + self.head_dim = head_dim + self.max_model_len = max_model_len + self.max_total_seq_len = max_total_seq_len + self.topk_indices_buffer = topk_indices_buffer + + def forwrad_native( + self, + hidden_states: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + ): + return self.forward_cuda(hidden_states, q_fp8, k, weights) + + def forward_cuda( + self, + hidden_states: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + ): + return torch.ops.vllm.sparse_attn_indexer( + hidden_states, + self.k_cache.layer_prefix, + self.k_cache.kv_cache[0], + q_fp8, + k, + weights, + self.quant_block_size, + self.scale_fmt, + self.topk_tokens, + self.head_dim, + self.max_model_len, + self.max_total_seq_len, + self.topk_indices_buffer, + ) + + def forward_hip( + self, + hidden_states: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + ): + if rocm_aiter_ops.is_enabled(): + return torch.ops.vllm.rocm_aiter_sparse_attn_indexer( + hidden_states, + self.k_cache.prefix, + self.k_cache.kv_cache[0], + q_fp8, + k, + weights, + self.quant_block_size, + self.scale_fmt, + self.topk_tokens, + self.head_dim, + self.max_model_len, + self.max_total_seq_len, + self.topk_indices_buffer, + ) + else: + raise RuntimeError( + "Sparse attention indexer ROCm custom op requires ROCm " + "Aiter ops to be enabled." + ) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 928fcb83a54c..978095d95316 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -43,7 +43,6 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, ) -from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase @@ -63,6 +62,7 @@ per_token_group_quant_fp8, ) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -74,16 +74,12 @@ from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits -from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backend import AttentionBackend from vllm.v1.attention.backends.mla.indexer import ( DeepseekV32IndexerBackend, - DeepseekV32IndexerMetadata, ) from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec -from vllm.v1.worker.workspace import current_workspace_manager from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP from .utils import ( @@ -94,10 +90,8 @@ maybe_prefix, ) -if current_platform.is_cuda_alike(): - from vllm import _custom_ops as ops -elif current_platform.is_xpu(): - from 
vllm._ipex_ops import ipex_ops as ops +if current_platform.is_cuda_alike() or current_platform.is_xpu(): + pass logger = init_logger(__name__) @@ -599,237 +593,6 @@ def get_attn_backend(self) -> AttentionBackend: return DeepseekV32IndexerBackend -def sparse_attn_indexer( - hidden_states: torch.Tensor, - k_cache_prefix: str, - kv_cache: torch.Tensor, - q_fp8: torch.Tensor, - k: torch.Tensor, - weights: torch.Tensor, - quant_block_size: int, - scale_fmt: str | None, - topk_tokens: int, - head_dim: int, - max_model_len: int, - total_seq_lens: int, - topk_indices_buffer: torch.Tensor | None, -) -> torch.Tensor: - # careful! this will be None in dummy run - attn_metadata = get_forward_context().attn_metadata - fp8_dtype = current_platform.fp8_dtype() - - # assert isinstance(attn_metadata, dict) - if not isinstance(attn_metadata, dict): - # Reserve workspace for indexer during profiling run - current_workspace_manager().get_simultaneous( - ((total_seq_lens, head_dim), torch.float8_e4m3fn), - ((total_seq_lens, 4), torch.uint8), - ) - - return sparse_attn_indexer_fake( - hidden_states, - k_cache_prefix, - kv_cache, - q_fp8, - k, - weights, - quant_block_size, - scale_fmt, - topk_tokens, - head_dim, - max_model_len, - total_seq_lens, - topk_indices_buffer, - ) - attn_metadata = attn_metadata[k_cache_prefix] - assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) - slot_mapping = attn_metadata.slot_mapping - has_decode = attn_metadata.num_decodes > 0 - has_prefill = attn_metadata.num_prefills > 0 - num_decode_tokens = attn_metadata.num_decode_tokens - - indexer_k_quant_cache_and_cache_func = ops.indexer_k_quant_and_cache - if current_platform.is_rocm(): - from vllm.attention.ops.rocm_aiter_mla_sparse import ( - indexer_k_quant_and_cache_triton, - ) - - indexer_k_quant_cache_and_cache_func = indexer_k_quant_and_cache_triton - - indexer_k_quant_cache_and_cache_func( - k, - kv_cache, - slot_mapping, - quant_block_size, - scale_fmt, - ) - - topk_indices_buffer[: hidden_states.shape[0]] = -1 - if has_prefill: - prefill_metadata = attn_metadata.prefill - - # Get the full shared workspace buffers once (will allocate on first use) - workspace_manager = current_workspace_manager() - k_fp8_full, k_scale_full = workspace_manager.get_simultaneous( - ((total_seq_lens, head_dim), fp8_dtype), - ((total_seq_lens, 4), torch.uint8), - ) - - for chunk in prefill_metadata.chunks: - k_fp8 = k_fp8_full[: chunk.total_seq_lens] - k_scale = k_scale_full[: chunk.total_seq_lens] - cp_gather_indexer_k_quant_cache_func = ops.cp_gather_indexer_k_quant_cache - if current_platform.is_rocm(): - from functools import partial - - from vllm.attention.ops.rocm_aiter_mla_sparse import ( - cp_gather_indexer_k_quant_cache_triton, - ) - - cp_gather_indexer_k_quant_cache_func = partial( - cp_gather_indexer_k_quant_cache_triton, - token_to_seq=chunk.token_to_seq, - ) - cp_gather_indexer_k_quant_cache_func( - kv_cache, - k_fp8, - k_scale, - chunk.block_table, - chunk.cu_seq_lens, - ) - - fp8_mqa_logits_func = fp8_mqa_logits - if current_platform.is_rocm(): - from vllm.v1.attention.ops.rocm_aiter_mla_sparse import ( - rocm_fp8_mqa_logits, - ) - - fp8_mqa_logits_func = rocm_fp8_mqa_logits - logits = fp8_mqa_logits_func( - q_fp8[chunk.token_start : chunk.token_end], - (k_fp8, k_scale.view(torch.float32)), - weights[chunk.token_start : chunk.token_end], - chunk.cu_seqlen_ks, - chunk.cu_seqlen_ke, - ) - num_rows = logits.shape[0] - topk_indices = topk_indices_buffer[ - chunk.token_start : chunk.token_end, :topk_tokens - ] - 
torch.ops._C.top_k_per_row_prefill( - logits, - chunk.cu_seqlen_ks, - chunk.cu_seqlen_ke, - topk_indices, - num_rows, - logits.stride(0), - logits.stride(1), - topk_tokens, - ) - - if has_decode: - decode_metadata = attn_metadata.decode - # kv_cache size requirement [num_block, block_size, n_head, head_dim], - # we only have [num_block, block_size, head_dim], - kv_cache = kv_cache.unsqueeze(-2) - decode_lens = decode_metadata.decode_lens - if decode_metadata.requires_padding: - # pad in edge case where we have short chunked prefill length < - # decode_threshold since we unstrictly split - # prefill and decode by decode_threshold - # (currently set to 1 + speculative tokens) - - # [num_decode_tokens, n_head, head_dim] -> [bs, 1+next_n, n_head, head_dim] - padded_q_fp8_decode_tokens = pack_seq_triton( - q_fp8[:num_decode_tokens], decode_lens - ) - # [num_decode_tokens, n_head] -> [bs, 1+next_n, n_head] - padded_weights = pack_seq_triton(weights[:num_decode_tokens], decode_lens) - # [bs, 1+next_n, n_head] -> [bs * next_n, n_head] - padded_weights = padded_weights.flatten(0, 1) - else: - padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape( - decode_lens.shape[0], -1, *q_fp8.shape[1:] - ) - padded_weights = weights - # TODO: move and optimize below logic with triton kernels - batch_size = padded_q_fp8_decode_tokens.shape[0] - next_n = padded_q_fp8_decode_tokens.shape[1] - assert batch_size == decode_metadata.seq_lens.shape[0] - num_padded_tokens = batch_size * next_n - fp8_paged_mqa_logits_func = fp8_paged_mqa_logits - if current_platform.is_rocm(): - from vllm.v1.attention.ops.rocm_aiter_mla_sparse import ( - rocm_fp8_paged_mqa_logits, - ) - - fp8_paged_mqa_logits_func = rocm_fp8_paged_mqa_logits - - logits = fp8_paged_mqa_logits_func( - padded_q_fp8_decode_tokens, - kv_cache, - padded_weights[:num_padded_tokens], - decode_metadata.seq_lens, - decode_metadata.block_table, - decode_metadata.schedule_metadata, - max_model_len=max_model_len, - ) - - num_rows = logits.shape[0] - topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens] - - torch.ops._C.top_k_per_row_decode( - logits, - next_n, - decode_metadata.seq_lens, - topk_indices, - num_rows, - logits.stride(0), - logits.stride(1), - topk_tokens, - ) - - if decode_metadata.requires_padding: - # if padded, we need to unpack - # the topk indices removing padded tokens - topk_indices = unpack_seq_triton( - topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]), - decode_lens, - ) - topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = ( - topk_indices - ) - - return topk_indices_buffer - - -def sparse_attn_indexer_fake( - hidden_states: torch.Tensor, - k_cache_prefix: str, - kv_cache: torch.Tensor, - q_fp8: torch.Tensor, - k: torch.Tensor, - weights: torch.Tensor, - quant_block_size: int, - scale_fmt: str | None, - topk_tokens: int, - head_dim: int, - max_model_len: int, - total_seq_lens: int, - topk_indices_buffer: torch.Tensor | None, -) -> torch.Tensor: - return topk_indices_buffer - - -direct_register_custom_op( - op_name="sparse_attn_indexer", - op_func=sparse_attn_indexer, - mutates_args=["topk_indices_buffer"], - fake_impl=sparse_attn_indexer_fake, - dispatch_key=current_platform.dispatch_key, -) - - class Indexer(nn.Module): def __init__( self, @@ -894,6 +657,16 @@ def __init__( from vllm.v1.attention.backends.mla.indexer import get_max_prefill_buffer_size self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config) + self.indexer_op = SparseAttnIndexer( + self.k_cache, + 
self.quant_block_size, + self.scale_fmt, + self.topk_tokens, + self.head_dim, + self.max_model_len, + self.max_total_seq_len, + self.topk_indices_buffer, + ) def forward( self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, rotary_emb @@ -916,6 +689,8 @@ def forward( q_pe = q_pe.reshape(-1, self.n_head, self.rope_dim) k_pe = k_pe.reshape(-1, 1, self.rope_dim) + # `rotary_emb` is shape-preserving; `q_pe` is already + # [num_tokens, n_head, rope_dim]. q = torch.cat([q_pe, q_nope], dim=-1) # `k_pe` is [num_tokens, 1, rope_dim] (MQA). k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1) @@ -937,21 +712,7 @@ def forward( ) weights = weights.squeeze(-1) - return torch.ops.vllm.sparse_attn_indexer( - hidden_states, - self.k_cache.prefix, - self.k_cache.kv_cache[0], - q_fp8, - k, - weights, - self.quant_block_size, - self.scale_fmt, - self.topk_tokens, - self.head_dim, - self.max_model_len, - self.max_total_seq_len, - self.topk_indices_buffer, - ) + return self.indexer_op(hidden_states, q_fp8, k, weights) class DeepseekV2MLAAttention(nn.Module): diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index fbb03fc1d3c6..4453a149b3fe 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -477,6 +477,9 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: ): compilation_config.custom_ops.append("+grouped_topk") + # Default dispatch to rocm's sparse_attn_indexer implementation + compilation_config.custom_ops.append("+sparse_attn_indexer") + @classmethod def verify_model_arch(cls, model_arch: str) -> None: if model_arch in _ROCM_UNSUPPORTED_MODELS: diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py index 3c56dbe378b9..77e7b0c04da5 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py @@ -141,7 +141,7 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata): class ROCMAiterMLASparseMetadataBuilder( AttentionMetadataBuilder[ROCMAiterMLASparseMetadata] ): - cudagraph_support: ClassVar[AttentionCGSupport] = ( + _cudagraph_support: ClassVar[AttentionCGSupport] = ( AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE ) diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index c6bad7e89644..19b5ad9c6bc9 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -1,16 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools import importlib -from functools import lru_cache import torch -from vllm._aiter_ops import rocm_aiter_ops -from vllm.logger import init_logger +from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton +from vllm.forward_context import get_forward_context from vllm.platforms import current_platform from vllm.triton_utils import tl, triton - -logger = init_logger(__name__) +from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata @triton.jit @@ -194,92 +193,6 @@ def cp_gather_indexer_k_quant_cache_triton( head_tile_size, ) - -# Take from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L84 -def fp8_mqa_logits_torch( - q: torch.Tensor, - kv: tuple[torch.Tensor, torch.Tensor], - weights: torch.Tensor, - cu_seqlen_ks: torch.Tensor, - cu_seqlen_ke: torch.Tensor, -) -> torch.Tensor: - """Compute FP8 MQA logits for a single sequence without KV paging. 
- - Args: - q: Query tensor of shape [M, H, D]. Casted to - `torch.float8_e4m3fn` by caller. - kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with - dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or - [N, 1]) with dtype `torch.float32`. - weights: weights of shape [M, H], dtype `torch.float32`. - cu_seqlen_ks: Start indices (inclusive) for valid K per query position, - shape [M], dtype int32. - cu_seqlen_ke: End indices (exclusive) for valid K per query position, - shape [M], dtype int32. - - Returns: - Logits tensor of shape [M, N], dtype `torch.float32`. - """ - k_fp8, scale = kv - seq_len_kv = k_fp8.shape[0] - k = k_fp8.to(torch.bfloat16) - q = q.to(torch.bfloat16) - - mask_lo = ( - torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None] - ) - mask_hi = ( - torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None] - ) - mask = mask_lo & mask_hi - - score = torch.einsum("mhd,nd->hmn", q, k).float() * scale - logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) - logits = logits.masked_fill(~mask, float("-inf")) - - return logits - - -def rocm_fp8_mqa_logits( - q: torch.Tensor, - kv: tuple[torch.Tensor, torch.Tensor], - weights: torch.Tensor, - cu_seqlen_ks: torch.Tensor, - cu_seqlen_ke: torch.Tensor, -) -> torch.Tensor: - """Compute FP8 MQA logits for a single sequence without KV paging. - - Args: - q: Query tensor of shape [M, H, D]. Casted to - `torch.float8_e4m3fn` by caller. - kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with - dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or - [N, 1]) with dtype `torch.float32`. - weights: weights of shape [M, H], dtype `torch.float32`. - cu_seqlen_ks: Start indices (inclusive) for valid K per query position, - shape [M], dtype int32. - cu_seqlen_ke: End indices (exclusive) for valid K per query position, - shape [M], dtype int32. - - Returns: - Logits tensor of shape [M, N], dtype `torch.float32`. - """ - - # TODO(ganyi): Temporarily workaround, will remove the module check and reference - # path after aiter merge this kernel into main - @lru_cache - def has_mqa_logits_module(): - return importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None - - if rocm_aiter_ops.is_enabled() and has_mqa_logits_module(): - from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits - - kv, scale = kv - return fp8_mqa_logits(q, kv, scale, weights, cu_seqlen_ks, cu_seqlen_ke) - else: - return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke) - - # Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156 def fp8_paged_mqa_logits_torch( q: torch.Tensor, @@ -366,6 +279,7 @@ def rocm_fp8_paged_mqa_logits( Logits tensor of shape [B * next_n, max_model_len], dtype `torch.float32`. """ + from vllm._aiter_ops import rocm_aiter_ops if rocm_aiter_ops.is_enabled(): batch_size, next_n, heads, head_dim = q_fp8.shape @@ -397,3 +311,275 @@ def rocm_fp8_paged_mqa_logits( return fp8_paged_mqa_logits_torch( q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len ) + + +# Take from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L84 +def fp8_mqa_logits_torch( + q: torch.Tensor, + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +) -> torch.Tensor: + """Compute FP8 MQA logits for a single sequence without KV paging. + + Args: + q: Query tensor of shape [M, H, D]. 
Casted to + `torch.float8_e4m3fn` by caller. + kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with + dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or + [N, 1]) with dtype `torch.float32`. + weights: weights of shape [M, H], dtype `torch.float32`. + cu_seqlen_ks: Start indices (inclusive) for valid K per query position, + shape [M], dtype int32. + cu_seqlen_ke: End indices (exclusive) for valid K per query position, + shape [M], dtype int32. + + Returns: + Logits tensor of shape [M, N], dtype `torch.float32`. + """ + kv, scale = kv + seq_len_kv = kv.shape[0] + k = kv.to(torch.bfloat16) + q = q.to(torch.bfloat16) + + mask_lo = ( + torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None] + ) + mask_hi = ( + torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None] + ) + mask = mask_lo & mask_hi + + score = torch.einsum("mhd,nd->hmn", q, k).float() * scale + logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) + logits = logits.masked_fill(~mask, float("-inf")) + + return logits + + +def rocm_fp8_mqa_logits( + q: torch.Tensor, + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +) -> torch.Tensor: + """Compute FP8 MQA logits for a single sequence without KV paging. + + Args: + q: Query tensor of shape [M, H, D]. Casted to + `torch.float8_e4m3fn` by caller. + kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with + dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or + [N, 1]) with dtype `torch.float32`. + weights: weights of shape [M, H], dtype `torch.float32`. + cu_seqlen_ks: Start indices (inclusive) for valid K per query position, + shape [M], dtype int32. + cu_seqlen_ke: End indices (exclusive) for valid K per query position, + shape [M], dtype int32. + + Returns: + Logits tensor of shape [M, N], dtype `torch.float32`. + """ + + # TODO(ganyi): Temporarily workaround, will remove the module check and reference + # path after aiter merge this kernel into main + from vllm._aiter_ops import rocm_aiter_ops + + @functools.lru_cache + def has_mqa_logits_module(): + return importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None + + if rocm_aiter_ops.is_enabled() and has_mqa_logits_module(): + from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits + + kv, scale = kv + return fp8_mqa_logits(q, kv, scale, weights, cu_seqlen_ks, cu_seqlen_ke) + else: + return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke) + + +def rocm_aiter_sparse_attn_indexer_fake( + hidden_states: torch.Tensor, + k_cache_prefix: str, + kv_cache: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: str | None, + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: torch.Tensor | None, +) -> torch.Tensor: + # profile run + # NOTE(Chen): create the max possible flattened_kv. So that + # profile_run can get correct memory usage. 
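+    # Layout note: each token row packs `head_dim` fp8 bytes followed by a
+    # 4-byte scale, which is reinterpreted as float32 below.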
+ _flattened_kv = torch.empty( + [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8 + ) + fp8_dtype = current_platform.fp8_dtype() + _k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous() + _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous() + return topk_indices_buffer + + +def rocm_aiter_sparse_attn_indexer( + hidden_states: torch.Tensor, + k_cache_prefix: str, + kv_cache: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: str | None, + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: torch.Tensor | None, +) -> torch.Tensor: + # careful! this will be None in dummy run + attn_metadata = get_forward_context().attn_metadata + fp8_dtype = current_platform.fp8_dtype() + # assert isinstance(attn_metadata, dict) + if not isinstance(attn_metadata, dict): + return rocm_aiter_sparse_attn_indexer_fake( + hidden_states, + k_cache_prefix, + kv_cache, + q_fp8, + k, + weights, + quant_block_size, + scale_fmt, + topk_tokens, + head_dim, + max_model_len, + total_seq_lens, + topk_indices_buffer, + ) + attn_metadata = attn_metadata[k_cache_prefix] + assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) + slot_mapping = attn_metadata.slot_mapping + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + + indexer_k_quant_and_cache_triton( + k, + kv_cache, + slot_mapping, + quant_block_size, + scale_fmt, + ) + + topk_indices_buffer[: hidden_states.shape[0]] = -1 + if has_prefill: + prefill_metadata = attn_metadata.prefill + for chunk in prefill_metadata.chunks: + k_fp8 = torch.empty( + [chunk.total_seq_lens, head_dim], + device=k.device, + dtype=fp8_dtype, + ) + k_scale = torch.empty( + [chunk.total_seq_lens, 4], + device=k.device, + dtype=torch.uint8, + ) + cp_gather_indexer_k_quant_cache_triton( + kv_cache, + k_fp8, + k_scale, + chunk.block_table, + chunk.cu_seq_lens, + chunk.token_to_seq, + ) + + logits = rocm_fp8_mqa_logits( + q_fp8[chunk.token_start : chunk.token_end], + (k_fp8, k_scale.view(torch.float32)), + weights[chunk.token_start : chunk.token_end], + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + ) + num_rows = logits.shape[0] + assert topk_tokens == 2048, "top_k_per_row assumes size 2048" + topk_indices = topk_indices_buffer[ + chunk.token_start : chunk.token_end, :topk_tokens + ] + torch.ops._C.top_k_per_row_prefill( + logits, + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) + + if has_decode: + decode_metadata = attn_metadata.decode + # kv_cache size requirement [num_block, block_size, n_head, head_dim], + # we only have [num_block, block_size, head_dim], + kv_cache = kv_cache.unsqueeze(-2) + decode_lens = decode_metadata.decode_lens + if decode_metadata.requires_padding: + # pad in edge case where we have short chunked prefill length < + # decode_threshold since we unstrictly split + # prefill and decode by decode_threshold + # (currently set to 1 + speculative tokens) + padded_q_fp8_decode_tokens = pack_seq_triton( + q_fp8[:num_decode_tokens], decode_lens + ) + else: + padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape( + decode_lens.shape[0], -1, *q_fp8.shape[1:] + ) + # TODO: move and optimize below logic with triton kernels + batch_size = padded_q_fp8_decode_tokens.shape[0] + next_n = padded_q_fp8_decode_tokens.shape[1] + assert 
batch_size == decode_metadata.seq_lens.shape[0] + num_padded_tokens = batch_size * next_n + + logits = rocm_fp8_paged_mqa_logits( + padded_q_fp8_decode_tokens, + kv_cache, + weights[:num_padded_tokens], + decode_metadata.seq_lens, + decode_metadata.block_table, + decode_metadata.schedule_metadata, + max_model_len=max_model_len, + ) + + num_rows = logits.shape[0] + assert topk_tokens == 2048, "top_k_per_row assumes size 2048" + topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens] + torch.ops._C.top_k_per_row_decode( + logits, + next_n, + decode_metadata.seq_lens, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) + + if decode_metadata.requires_padding: + # if padded, we need to unpack + # the topk indices removing padded tokens + topk_indices = unpack_seq_triton( + topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]), + decode_lens, + ) + topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = ( + topk_indices + ) + + return topk_indices_buffer From 608e41fdfd3de56c86c6b7fb769266f01fc6720d Mon Sep 17 00:00:00 2001 From: ganyi Date: Wed, 7 Jan 2026 06:20:31 +0000 Subject: [PATCH 3/9] raise NotImplementedError for other platform Signed-off-by: ganyi --- vllm/model_executor/layers/sparse_attn_indexer.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index 0274a7ba3785..3f196dddcf73 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -255,7 +255,15 @@ def forwrad_native( k: torch.Tensor, weights: torch.Tensor, ): - return self.forward_cuda(hidden_states, q_fp8, k, weights) + if current_platform.is_cuda(): + return self.forward_cuda(hidden_states, q_fp8, k, weights) + elif current_platform.is_rocm(): + return self.forward_hip(hidden_states, q_fp8, k, weights) + else: + raise NotImplementedError( + "SparseAttnIndexer native forward is only implemented for " + "CUDA and ROCm platform." 
+ ) def forward_cuda( self, From 439f16c7e53675c9de78a1c89041c4d5c37666ed Mon Sep 17 00:00:00 2001 From: ganyi Date: Wed, 14 Jan 2026 06:09:19 +0000 Subject: [PATCH 4/9] remove import Signed-off-by: ganyi --- vllm/_aiter_ops.py | 6 +- .../layers/sparse_attn_indexer.py | 6 +- vllm/model_executor/models/deepseek_v2.py | 4 - vllm/platforms/rocm.py | 5 +- vllm/v1/attention/backends/mla/indexer.py | 5 +- .../backends/mla/rocm_aiter_mla_sparse.py | 7 +- .../v1/attention/ops/rocm_aiter_mla_sparse.py | 108 ++++++++++++------ 7 files changed, 88 insertions(+), 53 deletions(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 9ff10a42fdd9..3e232d61924a 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -7,12 +7,12 @@ from torch._ops import OpOverload import vllm.envs as envs -from vllm.attention.ops.rocm_aiter_mla_sparse import ( +from vllm.platforms import current_platform +from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer +from vllm.v1.attention.ops.rocm_aiter_mla_sparse import ( rocm_aiter_sparse_attn_indexer, rocm_aiter_sparse_attn_indexer_fake, ) -from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer _FP8_DTYPE = current_platform.fp8_dtype() diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index 3f196dddcf73..aeffbf227746 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -5,7 +5,6 @@ import torch from vllm._aiter_ops import rocm_aiter_ops -from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp @@ -15,6 +14,7 @@ from vllm.v1.attention.backends.mla.indexer import ( DeepseekV32IndexerMetadata, ) +from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton from vllm.v1.worker.workspace import current_workspace_manager if current_platform.is_cuda_alike(): @@ -248,7 +248,7 @@ def __init__( self.max_total_seq_len = max_total_seq_len self.topk_indices_buffer = topk_indices_buffer - def forwrad_native( + def forward_native( self, hidden_states: torch.Tensor, q_fp8: torch.Tensor, @@ -274,7 +274,7 @@ def forward_cuda( ): return torch.ops.vllm.sparse_attn_indexer( hidden_states, - self.k_cache.layer_prefix, + self.k_cache.prefix, self.k_cache.kv_cache[0], q_fp8, k, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 978095d95316..c8b6533dcdf2 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -78,7 +78,6 @@ from vllm.v1.attention.backends.mla.indexer import ( DeepseekV32IndexerBackend, ) -from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP @@ -90,9 +89,6 @@ maybe_prefix, ) -if current_platform.is_cuda_alike() or current_platform.is_xpu(): - pass - logger = init_logger(__name__) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 4453a149b3fe..782235af8297 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -214,7 +214,10 @@ def get_attn_backend_cls( raise ValueError( "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype." 
) - logger.info_once("Using Sparse MLA backend on V1 engine.") + assert block_size == 1, ( + "Sparse MLA backend on ROCm only supports block size 1 for now." + ) + logger.info_once("Using Sparse MLA backend.") return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path() if attn_selector_config.use_mla: diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 791710f19c20..363979b4a236 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -7,6 +7,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, is_deep_gemm_supported from vllm.v1.attention.backend import ( AttentionBackend, @@ -24,7 +25,9 @@ class DeepseekV32IndexerBackend(AttentionBackend): - supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64] + @staticmethod + def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: + return [1 if current_platform.is_rocm() else 64] @classmethod def get_supported_head_sizes(cls) -> list[int]: diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py index 77e7b0c04da5..47543ef1efb7 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py @@ -11,11 +11,11 @@ from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.triton_utils import tl, triton from vllm.model_executor.layers.attention.mla_attention import ( MLACommonBaseImpl, get_mla_dims, ) +from vllm.triton_utils import tl, triton from vllm.v1.attention.backend import ( AttentionBackend, AttentionCGSupport, @@ -157,7 +157,6 @@ def __init__( parallel_config = vllm_config.parallel_config self.device = device max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens - max_num_seqs = vllm_config.scheduler_config.max_num_seqs self.num_heads = self.model_config.get_num_attention_heads(parallel_config) self.mla_dims = get_mla_dims(self.model_config) @@ -182,7 +181,7 @@ def __init__( 0, max_num_batched_tokens + 1, dtype=torch.int32, device=device ) self.paged_kv_last_page_len = torch.ones( - max_num_seqs, dtype=torch.int32, device=device + max_num_batched_tokens, dtype=torch.int32, device=device ) # These two needs to be calculated in runtime, @@ -193,7 +192,7 @@ def __init__( device=device, ) self.paged_kv_indptr = torch.zeros( - [max_num_seqs + 1], dtype=torch.int32, device=device + [max_num_batched_tokens + 1], dtype=torch.int32, device=device ) def build( diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 19b5ad9c6bc9..a496ed349e63 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -5,17 +5,21 @@ import torch -from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton from vllm.forward_context import get_forward_context from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata +from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton + +if current_platform.is_cuda_alike(): + from vllm import _custom_ops as ops @triton.jit def _indexer_k_quant_and_cache_kernel( k_ptr, # [num_tokens, head_dim] kv_cache_ptr, # 
[n_blks, blk_size//tile_block, head_dim // 16B, tile_block, 16B] + # [n_blocks, blk_size, head_dim] kv_cache_scale_ptr, # [n_blks, blk_size] slot_mapping_ptr, # [num_tokens] kv_cache_scale_stride, @@ -23,6 +27,7 @@ def _indexer_k_quant_and_cache_kernel( block_size, num_tokens, head_dim: tl.constexpr, + LAYOUT: tl.constexpr, BLOCK_TILE_SIZE: tl.constexpr, HEAD_TILE_SIZE: tl.constexpr, IS_FNUZ: tl.constexpr, @@ -30,10 +35,13 @@ def _indexer_k_quant_and_cache_kernel( ): tid = tl.program_id(0) offset = tl.arange(0, head_dim) - tile_offset = ( - offset // HEAD_TILE_SIZE * BLOCK_TILE_SIZE * HEAD_TILE_SIZE - + offset % HEAD_TILE_SIZE - ) + if LAYOUT == "SHUFFLE": + tile_offset = ( + offset // HEAD_TILE_SIZE * BLOCK_TILE_SIZE * HEAD_TILE_SIZE + + offset % HEAD_TILE_SIZE + ) + else: + tile_offset = offset tile_store_offset = tile_offset # for idx in tl.range(tid, num_tokens, n_program): src_ptr = k_ptr + tid * head_dim @@ -55,12 +63,17 @@ def _indexer_k_quant_and_cache_kernel( scale = tl.exp2(tl.ceil(tl.log2(scale))) fp8_val = (val.to(tl.float32) / scale).to(kv_cache_ptr.type.element_ty) - dst_ptr = ( - kv_cache_ptr - + block_id * kv_cache_value_stride - + tile_block_id * BLOCK_TILE_SIZE * head_dim - + tile_block_offset * HEAD_TILE_SIZE - ) + if LAYOUT == "SHUFFLE": + dst_ptr = ( + kv_cache_ptr + + block_id * kv_cache_value_stride + + tile_block_id * BLOCK_TILE_SIZE * head_dim + + tile_block_offset * HEAD_TILE_SIZE + ) + else: + dst_ptr = ( + kv_cache_ptr + block_id * kv_cache_value_stride + block_offset * head_dim + ) tl.store(dst_ptr + tile_store_offset, fp8_val) dst_scale_ptr = kv_cache_scale_ptr + block_id * kv_cache_scale_stride + block_offset tl.store(dst_scale_ptr, scale) @@ -96,6 +109,7 @@ def indexer_k_quant_and_cache_triton( block_size, num_tokens, head_dim, + "NHD", block_tile_size, head_tile_size, IS_FNUZ=current_platform.fp8_dtype() == torch.float8_e4m3fnuz, @@ -106,6 +120,7 @@ def indexer_k_quant_and_cache_triton( @triton.jit def _cp_gather_indexer_quant_cache_kernel( kv_cache_ptr, # [n_blks,blk_size//tile_blk,head_dim//16B,tile_blk,16B] + # [n_blks, blk_size, head_dim] kv_cache_scale_ptr, # [n_blks, blk_size] k_fp8_ptr, # [num_tokens, head_dim] k_scale_ptr, # [num_tokens] @@ -116,6 +131,7 @@ def _cp_gather_indexer_quant_cache_kernel( block_table_stride, kv_cache_stride, kv_cache_scale_stride, + LAYOUT: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_TILE_SIZE: tl.constexpr, HEAD_TILE_SIZE: tl.constexpr, @@ -134,11 +150,14 @@ def _cp_gather_indexer_quant_cache_kernel( block_id = tl.load(block_table_ptr + block_table_offset) tiled_block_id = block_offset // BLOCK_TILE_SIZE tiled_block_offset = block_offset % BLOCK_TILE_SIZE - src_cache_offset = ( - block_id * kv_cache_stride - + tiled_block_id * HEAD_DIM * BLOCK_TILE_SIZE - + tiled_block_offset * HEAD_TILE_SIZE - ) + if LAYOUT == "SHUFFLE": + src_cache_offset = ( + block_id * kv_cache_stride + + tiled_block_id * HEAD_DIM * BLOCK_TILE_SIZE + + tiled_block_offset * HEAD_TILE_SIZE + ) + else: + src_cache_offset = block_id * kv_cache_stride + block_offset * HEAD_DIM src_scale_offset = block_id * kv_cache_scale_stride + block_offset dst_offset = tid * HEAD_DIM src_scale_ptr = kv_cache_scale_ptr + src_scale_offset @@ -146,10 +165,13 @@ def _cp_gather_indexer_quant_cache_kernel( dst_k_ptr = k_fp8_ptr + dst_offset scale_val = tl.load(src_scale_ptr) tl.store(k_scale_ptr + tid, scale_val) - tiled_src_offset = ( - offset // HEAD_TILE_SIZE * HEAD_TILE_SIZE * BLOCK_TILE_SIZE - + offset % HEAD_TILE_SIZE - ) + if LAYOUT == "SHUFFLE": + tiled_src_offset = 
( + offset // HEAD_TILE_SIZE * HEAD_TILE_SIZE * BLOCK_TILE_SIZE + + offset % HEAD_TILE_SIZE + ) + else: + tiled_src_offset = offset val = tl.load(src_cache_ptr + tiled_src_offset) tl.store(dst_k_ptr + offset, val) @@ -188,11 +210,13 @@ def cp_gather_indexer_k_quant_cache_triton( block_table_stride, k_cache_value.stride(0), k_cache_scale.stride(0), + "NHD", head_dim, block_tile_size, head_tile_size, ) + # Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156 def fp8_paged_mqa_logits_torch( q: torch.Tensor, @@ -282,31 +306,27 @@ def rocm_fp8_paged_mqa_logits( from vllm._aiter_ops import rocm_aiter_ops if rocm_aiter_ops.is_enabled(): - batch_size, next_n, heads, head_dim = q_fp8.shape - num_blocks, block_size, _, _ = kv_cache_fp8.shape - - from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits + from aiter.ops.triton.attention.pa_mqa_logits import ( + deepgemm_fp8_paged_mqa_logits_stage1, + ) - out_logits = torch.full( - [batch_size * next_n, max_model_len], + batch_size, next_n, heads, _ = q_fp8.shape + out_qk = torch.full( + (heads, batch_size * next_n, max_model_len), float("-inf"), device="cuda", dtype=torch.float32, ) - deepgemm_fp8_paged_mqa_logits( + deepgemm_fp8_paged_mqa_logits_stage1( q_fp8, kv_cache_fp8, weights, - out_logits, + out_qk, context_lens, block_tables, max_model_len, - ChunkK=256, - Preshuffle=block_size == 64, - KVBlockSize=block_size, - WavePerEU=2, ) - return out_logits + return out_qk.sum(dim=0) else: return fp8_paged_mqa_logits_torch( q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len @@ -392,7 +412,7 @@ def has_mqa_logits_module(): return importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None if rocm_aiter_ops.is_enabled() and has_mqa_logits_module(): - from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits + from aiter.ops.triton.attention.fp8_mqa_logits import fp8_mqa_logits kv, scale = kv return fp8_mqa_logits(q, kv, scale, weights, cu_seqlen_ks, cu_seqlen_ke) @@ -469,7 +489,14 @@ def rocm_aiter_sparse_attn_indexer( has_prefill = attn_metadata.num_prefills > 0 num_decode_tokens = attn_metadata.num_decode_tokens - indexer_k_quant_and_cache_triton( + # indexer_k_quant_and_cache_triton( + # k, + # kv_cache, + # slot_mapping, + # quant_block_size, + # scale_fmt, + # ) + ops.indexer_k_quant_and_cache( k, kv_cache, slot_mapping, @@ -491,13 +518,20 @@ def rocm_aiter_sparse_attn_indexer( device=k.device, dtype=torch.uint8, ) - cp_gather_indexer_k_quant_cache_triton( + # cp_gather_indexer_k_quant_cache_triton( + # kv_cache, + # k_fp8, + # k_scale, + # chunk.block_table, + # chunk.cu_seq_lens, + # chunk.token_to_seq, + # ) + ops.cp_gather_indexer_k_quant_cache( kv_cache, k_fp8, k_scale, chunk.block_table, chunk.cu_seq_lens, - chunk.token_to_seq, ) logits = rocm_fp8_mqa_logits( From d243f6f932504de87d8f727697aabf24a9869017 Mon Sep 17 00:00:00 2001 From: ganyi Date: Thu, 27 Nov 2025 06:04:24 +0000 Subject: [PATCH 5/9] further optimize dsv3.2 Signed-off-by: ganyi --- vllm/model_executor/models/deepseek_v2.py | 20 +- .../backends/mla/rocm_aiter_mla_sparse.py | 243 +++++++++++++++--- 2 files changed, 217 insertions(+), 46 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index c8b6533dcdf2..c6ed406fcb96 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -669,27 +669,13 @@ def forward( ) -> torch.Tensor: q, _ = self.wq_b(qr) q = q.view(-1, self.n_head, self.head_dim) 
- q_pe, q_nope = torch.split( - q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 - ) k, _ = self.wk(hidden_states) k = self.k_norm(k) - k_pe, k_nope = torch.split( - k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 - ) - q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1)) - # Note: RoPE (NeoX) can introduce extra leading dimensions during compilation - # so we need to reshape back to token-flattened shapes - q_pe = q_pe.reshape(-1, self.n_head, self.rope_dim) - k_pe = k_pe.reshape(-1, 1, self.rope_dim) - - # `rotary_emb` is shape-preserving; `q_pe` is already - # [num_tokens, n_head, rope_dim]. - q = torch.cat([q_pe, q_nope], dim=-1) - # `k_pe` is [num_tokens, 1, rope_dim] (MQA). - k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1) + rotary_emb( + positions, q[..., : self.rope_dim], k[..., : self.rope_dim].unsqueeze(1) + ) # we only quant q here since k quant is fused with cache insertion q = q.view(-1, self.head_dim) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py index 47543ef1efb7..afafaed47c4a 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py @@ -24,9 +24,6 @@ AttentionMetadataBuilder, CommonAttentionMetadata, ) -from vllm.v1.attention.backends.mla.flashmla_sparse import ( - triton_convert_req_index_to_global_index, -) from vllm.v1.kv_cache_interface import AttentionSpec if TYPE_CHECKING: @@ -34,6 +31,184 @@ logger = init_logger(__name__) +@triton.jit +def _convert_req_index_to_global_index_kernel( + req_id_ptr, # int32 [num_tokens] + block_table_ptr, # int32 [num_requests, max_num_blocks_per_req] + token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] + cu_seqlens_ptr, # int32 [num_tokens + 1] + out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] + # shapes (compile-time where possible) + max_num_blocks_per_req: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, # tile width along columns + # strides (in elements) + bt_stride0, + bt_stride1, + ti_stride0, + ti_stride1, +): + # program_id(0) -> token_id (row) + # program_id(1) -> tile index along columns + token_id = tl.program_id(0) + tile_id = tl.program_id(1) + + # Each program covers BLOCK_N consecutive columns + indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N) + + # Load request id for this token (no mask: grid is exact) + req = tl.load(req_id_ptr + token_id) + + # Load cumulative sequence lengths to get starting index of this request + seq_start = tl.load(cu_seqlens_ptr + token_id) + seq_end = tl.load(cu_seqlens_ptr + token_id + 1) + + if tile_id * BLOCK_N + seq_start >= seq_end: + return + + # Load token indices for this tile + ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1 + tok = tl.load(ti_ptr) # int32 + + # Only token == -1 should propagate as -1 + is_invalid_tok = tok < 0 + + # Compute block id and in-block offset + block_id = tok // BLOCK_SIZE + inblock_off = tok % BLOCK_SIZE + + # Guard block_table access + valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0) + bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1 + base = tl.load(bt_ptr, mask=valid_block, other=0) + + # # If token == -1 OR block_id OOB, output 0; else base * BLOCK_SIZE + offset + out_val = tl.where( + is_invalid_tok | (~valid_block), 0, base * BLOCK_SIZE + inblock_off + ) + out_ptr_ij = out_ptr + seq_start + indice_id + out_ptr_ij_mask = (seq_start + indice_id) < seq_end + + # store the results 
with mask
+    tl.store(out_ptr_ij, out_val, mask=out_ptr_ij_mask)
+
+
+def triton_convert_req_index_to_global_index(
+    req_id: torch.Tensor,  # int32 [num_tokens]
+    block_table: torch.Tensor,  # int32 [num_requests, max_num_blocks_per_req]
+    token_indices: torch.Tensor,  # int32 [num_tokens, NUM_TOPK_TOKENS]
+    cu_seqlens: torch.Tensor,  # int32 [num_tokens + 1]
+    paged_kv_indices: torch.Tensor,  # int32 [num_tokens * topk] out_buffer
+    BLOCK_SIZE: int = 64,
+    NUM_TOPK_TOKENS: int = 2048,
+    BLOCK_N: int = 128,  # tile width along columns
+):
+    """
+    For each token, write the global KV slot of every selected top-k index
+    into the ragged output buffer:
+        paged_kv_indices[cu_seqlens[token_id] + indice_id] =
+            block_table[req_id[token_id], tok // BLOCK_SIZE] * BLOCK_SIZE
+            + tok % BLOCK_SIZE, where tok = token_indices[token_id, indice_id].
+    Entries whose token index is -1 (or whose derived block_id would be
+    out-of-bounds) are written as 0, and only the first
+    cu_seqlens[token_id + 1] - cu_seqlens[token_id] entries per row are stored.
+    """
+    assert req_id.dtype == torch.int32
+    assert block_table.dtype == torch.int32
+    assert token_indices.dtype == torch.int32
+    assert token_indices.shape[1] == NUM_TOPK_TOKENS
+    assert NUM_TOPK_TOKENS % BLOCK_N == 0, (
+        f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})"
+    )
+    # paged_kv_indices is filled in place; nothing is returned.
+    num_tokens = req_id.shape[0]
+    _, max_num_blocks_per_req = block_table.shape
+    tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N
+
+    # Ensure contiguous tensors on the same device
+    req_id_c = req_id.contiguous()
+    block_table_c = block_table.contiguous()
+    token_indices_c = token_indices.contiguous()
+
+    # Strides in elements
+    bt_stride0, bt_stride1 = block_table_c.stride()
+    ti_stride0, ti_stride1 = token_indices_c.stride()
+
+    # Exact 2D grid: tokens × column tiles
+    grid = (num_tokens, tiles_per_row)
+
+    _convert_req_index_to_global_index_kernel[grid](
+        req_id_c,
+        block_table_c,
+        token_indices_c,
+        cu_seqlens,
+        paged_kv_indices,
+        # shapes / constexprs
+        max_num_blocks_per_req,
+        BLOCK_SIZE,
+        BLOCK_N,
+        # strides
+        bt_stride0,
+        bt_stride1,
+        ti_stride0,
+        ti_stride1,
+    )
+    return
+
+
+@triton.jit
+def generate_sparse_seqlen_kernel(
+    seq_len_ptr,  # [num_seq]
+    cu_query_lens_ptr,  # [num_seq]
+    out_ptr,  # [num_query_tokens]
+    topk_token: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    seq_id = tl.program_id(0)
+    query_offset = tl.program_id(1) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    query_start = tl.load(cu_query_lens_ptr + seq_id)
+    query_end = tl.load(cu_query_lens_ptr + seq_id + 1)
+    if query_start + tl.program_id(1) * BLOCK_SIZE > query_end:
+        return
+    query_len = query_end - query_start
+    query_mask = query_offset + query_start < query_end
+    seq_len = tl.load(seq_len_ptr + seq_id)
+    context_start_point = seq_len - query_len
+    sparse_seqlen = context_start_point + query_offset
+    sparse_seqlen_masked = tl.where(
+        sparse_seqlen + 1 < topk_token, sparse_seqlen + 1, topk_token
+    )
+    tl.store(
+        out_ptr + query_start + query_offset, sparse_seqlen_masked, mask=query_mask
+    )
+
+
+def generate_sparse_seqlen_triton(
+    query_lens: torch.Tensor,
+    seq_lens: torch.Tensor,
+    cu_query_lens: torch.Tensor,
+    topk_token: int,
+    num_tokens: int,
+    max_query_len: int,
+):
+    num_seqs = query_lens.size(0)
+    out = torch.empty([num_tokens], dtype=torch.int32, device=query_lens.device)
+    block_size = 64
+    num_block_per_row = triton.cdiv(max_query_len, block_size)
+    grid = (
+        num_seqs,
+        num_block_per_row,
+    )
+    generate_sparse_seqlen_kernel[grid](
+        seq_lens,
+        cu_query_lens,
+        out,
+        topk_token,
+        block_size,
+    )
+    return out
+
+
 @triton.jit
 def fetch_id_to_ragged_kernel(
     in_tensor_ptr,  # 
[num_seq, topk] @@ -131,7 +306,6 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata): paged_kv_last_page_len: torch.Tensor paged_kv_indices: torch.Tensor paged_kv_indptr: torch.Tensor - paged_kv_indptr_rest: torch.Tensor block_size: int = 1 topk_tokens: int = 2048 @@ -161,9 +335,6 @@ def __init__( self.num_heads = self.model_config.get_num_attention_heads(parallel_config) self.mla_dims = get_mla_dims(self.model_config) self.topk_tokens = vllm_config.model_config.hf_config.index_topk - self.topk_tokens_tensor = torch.tensor( - [self.topk_tokens], device=device, dtype=torch.int32 - ) self.max_model_len_tensor = torch.tensor( [self.model_config.max_model_len], device=device, dtype=torch.int32 ) @@ -209,18 +380,33 @@ def build( ) # Zero-fill for cudagraphs self.req_id_per_token_buffer.fill_(0) + self.paged_kv_indices.fill_(0) + self.paged_kv_indptr.fill_(0) self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_( torch.from_numpy(req_id_per_token), non_blocking=True ) - self.paged_kv_indices.fill_(0) - self.paged_kv_indptr.fill_(0) + query_lens = ( + common_attn_metadata.query_start_loc[1:] + - common_attn_metadata.query_start_loc[:-1] + ) + seq_lens = common_attn_metadata.seq_lens + sparse_seqlen = generate_sparse_seqlen_triton( + query_lens, + seq_lens, + common_attn_metadata.query_start_loc, + self.topk_tokens, + num_tokens, + common_attn_metadata.max_query_len, + ) + + torch.cumsum(sparse_seqlen, dim=0, out=self.paged_kv_indptr[1 : num_tokens + 1]) + self.paged_kv_indptr[num_tokens + 1 :].fill_(self.paged_kv_indptr[num_tokens]) req_id_per_token = self.req_id_per_token_buffer[:num_tokens] qo_indptr = self.qo_indptr[: num_tokens + 1] paged_kv_last_page_len = self.paged_kv_last_page_len[:num_tokens] - paged_kv_indices = self.paged_kv_indices[: num_tokens * self.topk_tokens] paged_kv_indptr = self.paged_kv_indptr[: num_tokens + 1] - paged_kv_indptr_rest = self.paged_kv_indptr[num_tokens + 1 :] + paged_kv_indices = self.paged_kv_indices[: num_tokens * self.topk_tokens] metadata = ROCMAiterMLASparseMetadata( num_reqs=common_attn_metadata.num_reqs, @@ -237,7 +423,6 @@ def build( paged_kv_last_page_len=paged_kv_last_page_len, paged_kv_indices=paged_kv_indices, paged_kv_indptr=paged_kv_indptr, - paged_kv_indptr_rest=paged_kv_indptr_rest, ) return metadata @@ -311,7 +496,6 @@ def _forward_bf16_kv( self, q: torch.Tensor, # [sq, heads, d_qk] kv_c_and_k_pe_cache: torch.Tensor, # [blocks, heads, d_qk] - topk_indices: torch.Tensor, # [sq, topk] attn_metadata: ROCMAiterMLASparseMetadata, ) -> torch.Tensor: num_tokens = q.shape[0] @@ -320,15 +504,6 @@ def _forward_bf16_kv( dtype=q.dtype, device=q.device, ) - seq_len = (topk_indices != -1).sum(dim=-1) - torch.cumsum(seq_len, dim=0, out=attn_metadata.paged_kv_indptr[1:]) - attn_metadata.paged_kv_indptr_rest.fill_(attn_metadata.paged_kv_indptr[-1]) - fetch_id_to_ragged_triton( - topk_indices, - attn_metadata.paged_kv_indptr, - attn_metadata.paged_kv_indices, - attn_metadata.topk_tokens, - ) rocm_aiter_ops.mla_decode_fwd( q, @@ -384,9 +559,21 @@ def forward( # Convert from (B, N, P) to (N, B, P) q_nope = q_nope.transpose(0, 1) if self.is_fp8bmm_enabled: + num_tokens = q.shape[0] + q = torch.empty( + [num_tokens, self.num_heads, self.kv_lora_rank + self.qk_rope_head_dim], + dtype=q.dtype, + device=q.device, + ) + q[:, :, self.kv_lora_rank :] = q_pe # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) ql_nope = rocm_aiter_ops.triton_fp8_bmm( - q_nope, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True + q_nope, + self.W_K, + 
self.W_K_scale, + group_size=128, + transpose_bm=True, + YQ=q[:, :, : self.kv_lora_rank], ) else: # Multiply (N, B, P) x (N, P, L) -> (N, B, L) @@ -397,16 +584,16 @@ def forward( assert self.topk_indices_buffer is not None topk_indices = self.topk_indices_buffer[:num_actual_toks] - topk_indices_global = triton_convert_req_index_to_global_index( + triton_convert_req_index_to_global_index( attn_metadata.req_id_per_token, attn_metadata.block_table, topk_indices, + attn_metadata.paged_kv_indptr, + attn_metadata.paged_kv_indices, BLOCK_SIZE=attn_metadata.block_size, NUM_TOPK_TOKENS=attn_metadata.topk_tokens, ) - q = torch.cat([ql_nope, q_pe], dim=-1) - # write the latent and rope to kv cache if kv_cache.numel() > 0: ops.concat_and_cache_mla( @@ -418,9 +605,7 @@ def forward( scale=layer._k_scale, ) - attn_out = self._forward_bf16_kv( - q, kv_cache, topk_indices_global, attn_metadata - ) + attn_out = self._forward_bf16_kv(q, kv_cache, attn_metadata) self._v_up_proj(attn_out, out=output[:num_actual_toks]) return output From 803117e3c8f58471f0b0204d741d0e5f8e4ad2f3 Mon Sep 17 00:00:00 2001 From: ganyi Date: Tue, 20 Jan 2026 08:39:06 +0000 Subject: [PATCH 6/9] make gluon impl as default Signed-off-by: ganyi --- vllm/platforms/rocm.py | 3 - vllm/v1/attention/backends/mla/indexer.py | 3 +- .../v1/attention/ops/rocm_aiter_mla_sparse.py | 59 +++++++++++-------- 3 files changed, 34 insertions(+), 31 deletions(-) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 782235af8297..d7eb44a528f7 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -214,9 +214,6 @@ def get_attn_backend_cls( raise ValueError( "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype." ) - assert block_size == 1, ( - "Sparse MLA backend on ROCm only supports block size 1 for now." 
- ) logger.info_once("Using Sparse MLA backend.") return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path() diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 363979b4a236..52524b8890e2 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -7,7 +7,6 @@ from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.platforms import current_platform from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, is_deep_gemm_supported from vllm.v1.attention.backend import ( AttentionBackend, @@ -27,7 +26,7 @@ class DeepseekV32IndexerBackend(AttentionBackend): @staticmethod def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: - return [1 if current_platform.is_rocm() else 64] + return [64] @classmethod def get_supported_head_sizes(cls) -> list[int]: diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index a496ed349e63..3bdf84e84e45 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -12,7 +12,7 @@ from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton if current_platform.is_cuda_alike(): - from vllm import _custom_ops as ops + pass @triton.jit @@ -109,7 +109,7 @@ def indexer_k_quant_and_cache_triton( block_size, num_tokens, head_dim, - "NHD", + "SHUFFLE", block_tile_size, head_tile_size, IS_FNUZ=current_platform.fp8_dtype() == torch.float8_e4m3fnuz, @@ -210,7 +210,7 @@ def cp_gather_indexer_k_quant_cache_triton( block_table_stride, k_cache_value.stride(0), k_cache_scale.stride(0), - "NHD", + "SHUFFLE", head_dim, block_tile_size, head_tile_size, @@ -306,27 +306,34 @@ def rocm_fp8_paged_mqa_logits( from vllm._aiter_ops import rocm_aiter_ops if rocm_aiter_ops.is_enabled(): + batch_size, next_n, heads, head_dim = q_fp8.shape + num_blocks, block_size, _, _ = kv_cache_fp8.shape + from aiter.ops.triton.attention.pa_mqa_logits import ( - deepgemm_fp8_paged_mqa_logits_stage1, + deepgemm_fp8_paged_mqa_logits, ) batch_size, next_n, heads, _ = q_fp8.shape - out_qk = torch.full( - (heads, batch_size * next_n, max_model_len), + out_logits = torch.full( + [batch_size * next_n, max_model_len], float("-inf"), device="cuda", dtype=torch.float32, ) - deepgemm_fp8_paged_mqa_logits_stage1( + deepgemm_fp8_paged_mqa_logits( q_fp8, kv_cache_fp8, weights, - out_qk, + out_logits, context_lens, block_tables, max_model_len, + ChunkK=256, + Preshuffle=block_size == 64, + KVBlockSize=block_size, + WavePerEU=2, ) - return out_qk.sum(dim=0) + return out_logits else: return fp8_paged_mqa_logits_torch( q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len @@ -489,20 +496,20 @@ def rocm_aiter_sparse_attn_indexer( has_prefill = attn_metadata.num_prefills > 0 num_decode_tokens = attn_metadata.num_decode_tokens - # indexer_k_quant_and_cache_triton( - # k, - # kv_cache, - # slot_mapping, - # quant_block_size, - # scale_fmt, - # ) - ops.indexer_k_quant_and_cache( + indexer_k_quant_and_cache_triton( k, kv_cache, slot_mapping, quant_block_size, scale_fmt, ) + # ops.indexer_k_quant_and_cache( + # k, + # kv_cache, + # slot_mapping, + # quant_block_size, + # scale_fmt, + # ) topk_indices_buffer[: hidden_states.shape[0]] = -1 if has_prefill: @@ -518,21 +525,21 @@ def rocm_aiter_sparse_attn_indexer( device=k.device, dtype=torch.uint8, ) - # cp_gather_indexer_k_quant_cache_triton( - # kv_cache, - # k_fp8, - # k_scale, - # chunk.block_table, - 
# chunk.cu_seq_lens, - # chunk.token_to_seq, - # ) - ops.cp_gather_indexer_k_quant_cache( + cp_gather_indexer_k_quant_cache_triton( kv_cache, k_fp8, k_scale, chunk.block_table, chunk.cu_seq_lens, + chunk.token_to_seq, ) + # ops.cp_gather_indexer_k_quant_cache( + # kv_cache, + # k_fp8, + # k_scale, + # chunk.block_table, + # chunk.cu_seq_lens, + # ) logits = rocm_fp8_mqa_logits( q_fp8[chunk.token_start : chunk.token_end], From e569fa2cfc42f23eb259eb3c4f37f0ab25986e79 Mon Sep 17 00:00:00 2001 From: ganyi Date: Tue, 20 Jan 2026 15:30:10 +0000 Subject: [PATCH 7/9] fix sparse len calculation issue Signed-off-by: ganyi --- vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py index afafaed47c4a..14a0803e9af5 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py @@ -173,6 +173,9 @@ def generate_sparse_seqlen_kernel( query_len = query_end - query_start query_mask = query_offset + query_start < query_end seq_len = tl.load(seq_len_ptr + seq_id) + # Just return since the out_ptr is zero initialized. + if seq_len == 0: + return context_start_point = seq_len - query_len sparse_seqlen = context_start_point + query_offset sparse_seqlen_masked = tl.where( @@ -192,7 +195,8 @@ def generate_sparse_seqlen_triton( max_query_len: int, ): num_seqs = query_lens.size(0) - out = torch.empty([num_tokens], dtype=torch.int32, device=query_lens.device) + # zero initialize the tensor to make sure invalid positions will be zero + out = torch.zeros([num_tokens], dtype=torch.int32, device=query_lens.device) block_size = 64 num_block_per_row = triton.cdiv(max_query_len, block_size) grid = ( From 9b8f5ef15bc5b75724ed3ab556ccef15fba7a880 Mon Sep 17 00:00:00 2001 From: ganyi Date: Mon, 26 Jan 2026 02:28:38 +0000 Subject: [PATCH 8/9] fix ptpc scale load issue for fused shared expert path in deepseek mtp Signed-off-by: ganyi --- vllm/model_executor/models/deepseek_mtp.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 8fb2bfb16d73..653afe285fab 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -316,7 +316,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # Determine split axis based on op type # gate/up: ColumnParallel → split along dim 0 # down: RowParallel → split along dim 1 - split_dim = 1 if "down_proj.weight" in name else 0 + split_dim = ( + 1 + if ("down_proj.weight" in name and loaded_weight.ndim > 1) + else 0 + ) total = loaded_weight.shape[split_dim] assert total % num_chunks == 0, ( f"Shared expert weight dim {total} " @@ -329,14 +333,13 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: weight_to_load = loaded_weight if is_fusion_moe_shared_experts_layer: - if split_dim == 0: - weight_to_load = loaded_weight[ - j * chunk_size : (j + 1) * chunk_size, : - ] + chunk_slice = slice(j * chunk_size, (j + 1) * chunk_size) + if loaded_weight.ndim == 1: + weight_to_load = loaded_weight[chunk_slice] + elif split_dim == 0: + weight_to_load = loaded_weight[chunk_slice, :] else: - weight_to_load = loaded_weight[ - :, j * chunk_size : (j + 1) * chunk_size - ] + weight_to_load = loaded_weight[:, chunk_slice] # 
Synthesize an expert-style name so expert mapping # can route it chunk_name = name.replace( From 51208a1c208b07b09a6c44c64a54f35afcae9236 Mon Sep 17 00:00:00 2001 From: ganyi Date: Tue, 27 Jan 2026 02:16:34 +0000 Subject: [PATCH 9/9] fp8 kvcache support Signed-off-by: ganyi --- vllm/model_executor/models/config.py | 4 +++- vllm/platforms/rocm.py | 5 ---- .../backends/mla/rocm_aiter_mla_sparse.py | 24 ++++++++++++++++--- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index e51a110ce0b3..bd0dddd38a86 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -519,7 +519,9 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: # For DeepSeekV3.2, a custom fp8 format is used when fp8 kv-cache is enabled. cache_config = vllm_config.cache_config - if cache_config.cache_dtype.startswith("fp8"): + if not current_platform.is_rocm() and cache_config.cache_dtype.startswith( + "fp8" + ): cache_config.cache_dtype = "fp8_ds_mla" logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2") if cache_config.cache_dtype == "bfloat16": diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index d7eb44a528f7..732f9a3494e7 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -207,13 +207,8 @@ def get_attn_backend_cls( from vllm._aiter_ops import rocm_aiter_ops block_size = attn_selector_config.block_size - kv_cache_dtype = attn_selector_config.kv_cache_dtype if attn_selector_config.use_sparse: - if kv_cache_dtype and kv_cache_dtype.startswith("fp8"): - raise ValueError( - "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype." - ) logger.info_once("Using Sparse MLA backend.") return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path() diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py index 14a0803e9af5..9812634db5a9 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py @@ -15,6 +15,7 @@ MLACommonBaseImpl, get_mla_dims, ) +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.v1.attention.backend import ( AttentionBackend, @@ -310,6 +311,7 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata): paged_kv_last_page_len: torch.Tensor paged_kv_indices: torch.Tensor paged_kv_indptr: torch.Tensor + attn_out_dtype: torch.dtype block_size: int = 1 topk_tokens: int = 2048 @@ -332,6 +334,7 @@ def __init__( ): self.kv_cache_spec = kv_cache_spec self.model_config = vllm_config.model_config + self.model_dtype = vllm_config.model_config.dtype parallel_config = vllm_config.parallel_config self.device = device max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens @@ -422,6 +425,7 @@ def build( block_table=common_attn_metadata.block_table_tensor, req_id_per_token=req_id_per_token, block_size=self.kv_cache_spec.block_size, + attn_out_dtype=self.model_dtype, topk_tokens=self.topk_tokens, qo_indptr=qo_indptr, paged_kv_last_page_len=paged_kv_last_page_len, @@ -496,8 +500,9 @@ def __init__( self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer self.is_fp8bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled() - def _forward_bf16_kv( + def _forward_mla( self, + layer: AttentionLayer, q: torch.Tensor, # [sq, heads, d_qk] kv_c_and_k_pe_cache: torch.Tensor, # [blocks, heads, d_qk] attn_metadata: ROCMAiterMLASparseMetadata, @@ 
-505,10 +510,10 @@ def _forward_bf16_kv(
         num_tokens = q.shape[0]
         output = torch.empty(
             [num_tokens, self.num_heads, self.kv_lora_rank],
-            dtype=q.dtype,
+            dtype=attn_metadata.attn_out_dtype,
             device=q.device,
         )
 
         rocm_aiter_ops.mla_decode_fwd(
             q,
             kv_c_and_k_pe_cache,
@@ -519,6 +524,8 @@ def _forward_bf16_kv(
             attn_metadata.paged_kv_indptr,
             attn_metadata.paged_kv_indices,
             attn_metadata.paged_kv_last_page_len,
+            q_scale=layer._q_scale,
+            kv_scale=layer._k_scale,
         )
         return output[:, : self.num_heads, :]
 
@@ -609,7 +616,14 @@ def forward(
                 scale=layer._k_scale,
             )
 
-        attn_out = self._forward_bf16_kv(q, kv_cache, attn_metadata)
+        fp8_attention = self.kv_cache_dtype.startswith("fp8")
+        if fp8_attention:
+            original_q_shape = q.shape
+            kv_cache = kv_cache.view(current_platform.fp8_dtype())
+            q, _ = ops.scaled_fp8_quant(q.view(q.shape[0], -1), layer._q_scale)
+            q = q.view(original_q_shape)
+
+        attn_out = self._forward_mla(layer, q, kv_cache, attn_metadata)
         self._v_up_proj(attn_out, out=output[:num_actual_toks])
 
         return output
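# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patches above: PATCH 9/9 quantizes q
# with a static per-tensor scale via ops.scaled_fp8_quant(q.view(q.shape[0], -1),
# layer._q_scale) before invoking the fp8 MLA decode kernel. The reference
# below only shows the assumed math; the helper name and the choice of
# torch.float8_e4m3fnuz (ROCm's fp8 format) are illustrative assumptions, and
# the real op also returns the scale it applied (hence the `q, _ =` unpacking).
import torch


def scaled_fp8_quant_ref(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Static per-tensor quantization: divide by the precomputed scale, clamp
    # to the representable fp8 range, then cast to the fp8 storage dtype.
    fp8_dtype = torch.float8_e4m3fnuz
    finfo = torch.finfo(fp8_dtype)
    return (x.float() / scale.float()).clamp(finfo.min, finfo.max).to(fp8_dtype)


# Example with toy shapes: flatten a [num_tokens, heads, head_dim] query to 2D,
# quantize with the layer's q scale, and view it back, mirroring the patch.
q = torch.randn(4, 16, 576)
q_fp8 = scaled_fp8_quant_ref(q.view(q.shape[0], -1), torch.tensor(0.05)).view(q.shape)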