From fef62beed977e4df95b7f29ca8dd6a87196a250c Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Wed, 4 Mar 2026 11:41:23 +0000 Subject: [PATCH 01/13] [feat][plugin] make ATOM mla attention works for vllm Signed-off-by: XiaobingSuper --- atom/model_ops/attention_mla.py | 97 ++- atom/model_ops/attentions/aiter_attention.py | 1 - atom/model_ops/attentions/aiter_mla.py | 14 +- atom/model_ops/base_attention.py | 10 +- atom/model_ops/linear.py | 3 +- atom/model_ops/paged_attention.py | 105 ++- atom/model_ops/utils.py | 8 - atom/models/deepseek_v2.py | 12 +- atom/plugin/attention.py | 806 +++++++++++++++++- atom/plugin/attention_mla.py | 834 +++++++++++++++++++ atom/plugin/vllm/model_wrapper.py | 9 + atom/plugin/vllm/platform.py | 3 +- atom/plugin/vllm/register.py | 20 +- atom/utils/backends.py | 4 +- 14 files changed, 1828 insertions(+), 98 deletions(-) create mode 100644 atom/plugin/attention_mla.py diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index dccb9d620..03616ab54 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -35,6 +35,11 @@ batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as _aiter_triton_fp8_bmm, ) + +from atom.plugin import is_plugin_mode, is_vllm + +from atom.plugin.attention_mla import MLAAttentionImplDecoratorForPluginMode + # torch.set_printoptions(threshold=10_000) logger = logging.getLogger("atom") @@ -92,11 +97,12 @@ def dynamic_per_batched_tensor_quant( return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() +@MLAAttentionImplDecoratorForPluginMode class MLAAttention(nn.Module): def __init__( self, num_heads: int, - head_dim: int, + head_size: int, scale: float, num_kv_heads: int, kv_cache_dtype: str, @@ -107,7 +113,7 @@ def __init__( ) -> None: super().__init__() self.num_heads = num_heads - self.head_dim = head_dim + self.head_dim = head_size self.scale = float(scale) self.num_kv_heads = num_kv_heads self.kv_cache_dtype = kv_cache_dtype if 
kv_cache_dtype == "fp8" else "auto" @@ -134,7 +140,34 @@ def __init__( ) self.layer_num = layer_num - def process_weights_after_loading(self): + # for plugin mode(vllm) + if is_vllm(): + self.supports_quant_query_input = False + self.dcp_world_size: int = -1 + from vllm.config import get_current_vllm_config + from vllm.model_executor.layers.attention.mla_attention import ( + MLACommonMetadataBuilder, + ) + + self.chunked_prefill_workspace_size = ( + MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size( + get_current_vllm_config() + ) + ) + self.cp_kv_cache_interleave_size: int = ( + get_current_vllm_config().parallel_config.cp_kv_cache_interleave_size + ) + + self.is_aiter_triton_fp4_bmm_enabled = ( + is_rocm_aiter_fp4bmm_enabled() + and self.kv_b_proj.weight.dtype == torch.bfloat16 + ) + # q_pad_num_heads in kwargs + self.q_pad_num_heads = kwargs.get("q_pad_num_heads", None) + self._pad_v = True + self.flash_attn_varlen_func = flash_attn_varlen_func + + def process_weights_after_loading(self, act_dtype: Optional[torch.dtype] = None): if is_rocm_aiter_fp4bmm_enabled(): kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj) self.W_K, self.W_K_scale, W_V, self.W_V_scale = quark_post_load_weights( @@ -146,7 +179,7 @@ def process_weights_after_loading(self): self.W_K_scale = self.W_K_scale.transpose(-2, -1).contiguous() self.W_V = self.W_V.transpose(-2, -1).contiguous() self.W_V_scale = self.W_V_scale.transpose(-2, -1).contiguous() - else: # is_rocm_aiter_fp8bmm_enabled(): + else: # is_rocm_aiter_fp8bmm_enabled() kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T assert kv_b_proj_weight.shape == ( self.kv_lora_rank, @@ -175,7 +208,7 @@ def process_weights_after_loading(self): W_V, dtype=dtypes.fp8 ) - def _v_up_proj_and_o_proj(self, x): + def _v_up_proj(self, x): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) # Multiply (N, B, L) x (N, L, V) -> (N, B, V), Convert from (N, B, 
V) to (B, N, V) @@ -207,7 +240,7 @@ def _v_up_proj_and_o_proj(self, x): ) # Convert from (B, N, V) to (B, N * V) x = x.reshape(-1, self.num_heads * self.v_head_dim) - return self.o_proj(x) + return x def _q_proj_and_k_up_proj(self, x, x_scale=None): q_nope, q_pe = ( @@ -413,7 +446,7 @@ def _forward_prefill_mha( causal=True, ) - return self.o_proj(output.flatten(start_dim=-2)) + return output.flatten(start_dim=-2) def _forward_prefill_mla( self, @@ -480,7 +513,7 @@ def _forward_prefill_mla( None, ) - return self._v_up_proj_and_o_proj(o) + return self._v_up_proj(o) def _forward_decode( self, @@ -555,16 +588,15 @@ def _forward_decode( kv_scale=self._k_scale, ) - return self._v_up_proj_and_o_proj(o) + return self._v_up_proj(o) - def forward( + def forward_impl_server_mode( self, - q: torch.Tensor, # query in unified attn + q: torch.Tensor, k_nope: torch.Tensor, k_rope: torch.Tensor, - positions: torch.Tensor, - q_scale: Optional[torch.Tensor], - qkv: Optional[torch.Tensor], + positions: torch.Tensor = None, + q_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: # kv_cache = self.kv_cache forward_context: ForwardContext = get_forward_context() @@ -577,7 +609,7 @@ def forward( if forward_context.context.is_dummy_run: # dummy run: skip real attention and return output_shape = list(q.shape) - output_shape[-1] = 7168 + output_shape[-1] = self.num_heads * self.v_head_dim atom_config = get_current_atom_config() output_dtype = atom_config.torch_dtype output = torch.empty(output_shape, dtype=output_dtype, device=q.device) @@ -647,6 +679,41 @@ def forward( return output + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, # query in unified attn + k_nope: torch.Tensor, + k_rope: torch.Tensor, + kv_cache: torch.Tensor = None, + attn_metadata=None, + positions: torch.Tensor = None, + q_scale: Optional[torch.Tensor] = None, + output: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + if is_plugin_mode(): + # forward impl method are added by the 
decorator + # MLAAttentionImplDecoratorForPluginMode + return self.forward_impl_plugin_mode( + layer=layer, + q=query, + k_c_normed=k_nope, + k_pe=k_rope, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + output=output, + ) + else: + # only for server mode, keep the original method + return self.forward_impl_server_mode( + q=query, + k_nope=k_nope, + k_rope=k_rope, + positions=positions, + q_scale=q_scale, + ) + @triton.jit def _convert_req_index_to_global_index_kernel( diff --git a/atom/model_ops/attentions/aiter_attention.py b/atom/model_ops/attentions/aiter_attention.py index 68a8567d5..f91c73e27 100644 --- a/atom/model_ops/attentions/aiter_attention.py +++ b/atom/model_ops/attentions/aiter_attention.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -import itertools from typing import Type import aiter diff --git a/atom/model_ops/attentions/aiter_mla.py b/atom/model_ops/attentions/aiter_mla.py index fa4ec6c36..168e0eabe 100644 --- a/atom/model_ops/attentions/aiter_mla.py +++ b/atom/model_ops/attentions/aiter_mla.py @@ -19,6 +19,10 @@ from .backends import AttentionBackend, CommonAttentionBuilder +from atom.plugin.prepare import is_plugin_mode +from atom.plugin.attention import AiterMLAAttentionMetadataBuilderDecoratorForPluginMode +from atom.plugin.attention import AiterBackendDecoratorForPluginMode + logger = logging.getLogger("atom") @@ -26,10 +30,11 @@ def cdiv(a, b): return (a + b - 1) // b +@AiterBackendDecoratorForPluginMode class AiterMLABackend(AttentionBackend): @staticmethod def get_name() -> str: - return "ROCM_AITER_MLA" + return "ROCM_AITER_MLA" if not is_plugin_mode() else "CUSTOM" @staticmethod def get_builder_cls() -> Type["AiterMLAMetadataBuilder"]: @@ -40,11 +45,14 @@ def get_impl_cls() -> Type["MLAAttention"]: return MLAAttention +@AiterMLAAttentionMetadataBuilderDecoratorForPluginMode( + default_base_class=CommonAttentionBuilder +) class 
AiterMLAMetadataBuilder(CommonAttentionBuilder): def __init__(self, model_runner): self.block_size = 1 - super().__init__(model_runner) + CommonAttentionBuilder.__init__(self, model_runner) config = model_runner.config hf_config = config.hf_config self.num_attention_heads = ( @@ -190,7 +198,7 @@ def prepare_mtp_decode(self, bs: int, max_seqlen_q: int, max_seqlen_k: int): return self.set_mla_persistent_worker_buffers(bs, max_seqlen_q) def prepare_prefill(self, batch: ScheduledBatch): - attn_metadata, positions = super().prepare_prefill(batch) + attn_metadata, positions = CommonAttentionBuilder.prepare_prefill(self, batch) bs = batch.total_seqs_num_prefill sum_scheduled_tokens = batch.total_tokens_num_prefill var = self.model_runner.forward_vars diff --git a/atom/model_ops/base_attention.py b/atom/model_ops/base_attention.py index 3660b3ebd..116484a9d 100644 --- a/atom/model_ops/base_attention.py +++ b/atom/model_ops/base_attention.py @@ -218,7 +218,15 @@ def unified_attention_with_output_base( atom_config = get_current_atom_config() self = atom_config.compilation_config.static_forward_context[layer_name] if use_mla: - return self.impl.forward(q, k, v, positions, q_scale, qkv) + output = self.impl.forward( + layer=self, + query=q, + k_nope=k, + k_rope=v, + positions=positions, + q_scale=q_scale, + ) + return self.impl.o_proj(output) else: return self.impl.forward( layer=self, diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index a3d7b4ef7..53a8897d1 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -385,8 +385,9 @@ def forward( if self.quant_type.value == QuantType.per_1x128.value: quant_func = functools_partial(quant_func, transpose_scale=True) if self.quant_type.value != QuantType.per_1x32.value: + # quant_func will call view, so we need to call contiguous to avoid view error x, x_scale = quant_func( - x, + x.contiguous(), quant_dtype=self.params_dtype, scale=getattr(self, "input_scale", None), ) diff --git 
a/atom/model_ops/paged_attention.py b/atom/model_ops/paged_attention.py index ce2ac8635..b9c742d8d 100644 --- a/atom/model_ops/paged_attention.py +++ b/atom/model_ops/paged_attention.py @@ -3,7 +3,7 @@ # from flash_attn import flash_attn_with_kvcache from typing import Optional - +import os import torch from torch import nn @@ -60,15 +60,15 @@ def __init__( **kwargs, ) + self.use_mla = use_mla # for plugin mode if is_vllm(): - self.use_mla = use_mla - self.rotary_emb = rotary_emb + self.rotary_emb = mla_modules.rotary_emb try: - from vllm.attention.layer import Attention, AttentionType + from vllm.attention.layer import Attention, MLAAttention, AttentionType except ImportError: - from vllm.model_executor.layers.attention import Attention + from vllm.model_executor.layers.attention import Attention, MLAAttention from vllm.v1.attention.backend import AttentionType atom_config = get_current_atom_config() @@ -88,28 +88,85 @@ def __init__( extra_impl_args["rotary_emb"] = rotary_emb extra_impl_args["q_norm"] = q_norm extra_impl_args["k_norm"] = k_norm - - self.attn = Attention( - num_heads=num_heads, - head_size=head_dim, - scale=scale, - num_kv_heads=num_kv_heads, - alibi_slopes=alibi_slopes, - cache_config=cache_config, - quant_config=quant_config, - logits_soft_cap=None, - per_layer_sliding_window=per_layer_sliding_window, - prefix=f"{prefix}", - attn_type=AttentionType.DECODER, - kv_sharing_target_layer_name=None, - **extra_impl_args, - ) + if use_mla: + extra_impl_args["layer_num"] = layer_num + extra_impl_args["mla_modules"] = mla_modules + + if use_mla: + assert ( + mla_modules.indexer is None + ), "MLAAttention is not supported for sparse mode" + self.num_heads = num_heads + self.v_head_dim = mla_modules.v_head_dim + self.qk_head_dim = mla_modules.qk_head_dim + self.qk_nope_head_dim = mla_modules.qk_nope_head_dim + self.q_proj = mla_modules.q_proj + self.o_proj = mla_modules.o_proj + + self.attn = MLAAttention( + num_heads=num_heads, + scale=scale, + 
qk_nope_head_dim=mla_modules.qk_nope_head_dim, + qk_rope_head_dim=mla_modules.qk_rope_head_dim, + v_head_dim=mla_modules.v_head_dim, + q_lora_rank=mla_modules.q_lora_rank, + kv_lora_rank=mla_modules.kv_lora_rank, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + kv_b_proj=mla_modules.kv_b_proj, + use_sparse=False, + indexer=mla_modules.indexer, + **extra_impl_args, + ) + + def wrap_kv_b_proj(module_instance): + orig_impl = module_instance.forward + + def new_forward(*args, **kwargs): + out = orig_impl(*args, **kwargs) + if ( + os.getenv("ATOM_DISABLE_VLLM_PLUGIN_ATTENTION", "0").lower() + == "0" + ): + return out + return out, None + + module_instance.forward = new_forward + return module_instance + + # vllm kv_b_proj return two values (output, bias), so we need to wrap it for fallback path. + self.attn.impl.kv_b_proj = wrap_kv_b_proj(self.attn.impl.kv_b_proj) + else: + self.attn = Attention( + num_heads=num_heads, + head_size=head_dim, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=alibi_slopes, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=None, + per_layer_sliding_window=per_layer_sliding_window, + prefix=f"{prefix}", + attn_type=AttentionType.DECODER, + kv_sharing_target_layer_name=None, + **extra_impl_args, + ) compilation_config = atom_config.compilation_config self.layer_name = prefix if self.layer_name in compilation_config.static_forward_context: raise ValueError("Duplicate layer: {}".format(self.layer_name)) compilation_config.static_forward_context[self.layer_name] = self + + if self.use_mla: + max_num_tokens = ( + atom_config.plugin_config.vllm_scheduler_config.max_num_batched_tokens + ) + compilation_config.static_forward_context["positions"] = torch.zeros( + max_num_tokens, dtype=torch.int64, device="cuda" + ) return self.num_heads = num_heads @@ -122,7 +179,6 @@ def __init__( self.k_scale = self.v_scale = None self.layer_num = layer_num self.mla_modules = mla_modules - 
self.use_mla = use_mla self.base_attention = None self.kv_cache = torch.tensor([]) self.indexer = mla_modules.indexer if mla_modules is not None else None @@ -136,7 +192,7 @@ def __init__( use_mla=self.use_mla, ) impl_cls = self.attn_backend.get_impl_cls() - self.impl = impl_cls( + impl_args = dict( num_heads=num_heads, head_dim=head_dim, scale=scale, @@ -153,7 +209,8 @@ def __init__( k_norm=k_norm, **kwargs, ) - + impl_args["head_size" if self.use_mla else "head_dim"] = head_dim + self.impl = impl_cls(**impl_args) compilation_config = atom_config.compilation_config default_name = f"MLA_{layer_num}" if self.use_mla else f"MHA_{layer_num}" self.layer_name = prefix if prefix is not None else default_name diff --git a/atom/model_ops/utils.py b/atom/model_ops/utils.py index c20154a97..1f9c19649 100644 --- a/atom/model_ops/utils.py +++ b/atom/model_ops/utils.py @@ -135,14 +135,6 @@ def all_close_1d(x: torch.Tensor) -> bool: return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) -def per_tensor_dequantize( - tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor] -) -> torch.Tensor: - fake_qweight = tensor.to(torch.float16) - dq_weight = fake_qweight * inv_scale - return dq_weight - - def get_and_maybe_dequant_weights(layer: nn.Module) -> torch.Tensor: if layer.quant_type != QuantType.No: # NOTE: This should only be used offline, since it's O(N^3) diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index f0342dce1..93414d383 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -24,7 +24,7 @@ """Inference-only DeepseekV2/DeepseekV3 model.""" import logging -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, Iterable import torch from aiter import ( @@ -1823,6 +1823,7 @@ def __init__( layer_type: type[nn.Module] = DeepseekV2DecoderLayer, ): super().__init__() + self.atom_config = atom_config config = atom_config.hf_config quant_config = atom_config.quant_config self.config = config @@ 
-1899,6 +1900,15 @@ def make_empty_intermediate_tensors( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # load weights in plugin mode and discard passed weights generator + from atom.model_loader.loader import load_model_in_plugin_mode + + loaded_weights_record = load_model_in_plugin_mode( + model=self, config=self.atom_config, prefix="model." + ) + return loaded_weights_record + class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass diff --git a/atom/plugin/attention.py b/atom/plugin/attention.py index 2b136ebc8..861c7af06 100644 --- a/atom/plugin/attention.py +++ b/atom/plugin/attention.py @@ -1,5 +1,6 @@ -from typing import Generic, Optional +from typing import Generic, Optional, TypeVar import logging +import os from dataclasses import dataclass @@ -9,6 +10,7 @@ from atom.utils import CpuGpuBuffer from atom.utils.forward_context import Context, AttentionMetaData from atom.model_ops.attention_mha import PagedAttentionImpl +from atom.model_ops.attention_mla import MLAAttention logger = logging.getLogger("atom") @@ -16,6 +18,9 @@ _CP_TOKENS_PER_ITER_ROCM = 32 * 1024 +_MLA_ATTENTION_FOR_PLUGIN_MODE = True + + @dataclass class AiterFlashAttentionDecodeMetadata: max_query_len: int @@ -67,7 +72,7 @@ class AiterFlashAttentionChunkPrefillMetadata: @dataclass -class MetadataForPluginMode: +class AiterFlashAttentionMetadataForPluginMode: # NOTE(sang): Definition of context_len, query_len, and seq_len. 
# |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| @@ -104,7 +109,7 @@ class MetadataForPluginMode: context: Optional[Context] = None -class vllmAiterBackendMethods: +class vllmAiterAttentionBackendMethods: # here attention in ATOM doesn't accept the output buffer because # ATOM works as a model impl backend, it needs the maximum freedom # to decide the output buffer and shape, thus here use this flag to @@ -157,33 +162,6 @@ def supports_alibi_sqrt(cls) -> bool: return False -def AiterBackendDecoratorForPluginMode(cls): - """ - Decorator for AiterBackend to add specific methods and attributes for plugin mode - """ - is_vllm_mode = is_vllm() - - if is_vllm_mode: - # for vllm, add the required methods - cls.full_cls_name = vllmAiterBackendMethods.full_cls_name - cls.accept_output_buffer = vllmAiterBackendMethods.accept_output_buffer - cls.supported_dtypes = vllmAiterBackendMethods.supported_dtypes - cls.forward_includes_kv_cache_update = ( - vllmAiterBackendMethods.forward_includes_kv_cache_update - ) - cls.get_supported_kernel_block_sizes = ( - vllmAiterBackendMethods.get_supported_kernel_block_sizes - ) - cls.get_kv_cache_shape = vllmAiterBackendMethods.get_kv_cache_shape - cls.is_mla = vllmAiterBackendMethods.is_mla - cls.get_required_kv_cache_layout = ( - vllmAiterBackendMethods.get_required_kv_cache_layout - ) - cls.get_supported_head_sizes = vllmAiterBackendMethods.get_supported_head_sizes - cls.supports_alibi_sqrt = vllmAiterBackendMethods.supports_alibi_sqrt - return cls - - def create_attn_metadata_builder_init_method(base_class): """ Create the init method for metadata builder @@ -503,7 +481,7 @@ def build( graph_bs=context_graph_bs, ) - attn_metadata_for_plugin_mode = MetadataForPluginMode( + attn_metadata_for_plugin_mode = AiterFlashAttentionMetadataForPluginMode( num_actual_tokens=num_actual_tokens, num_actual_kv_tokens=num_actual_kv_tokens, max_query_len=common_attn_metadata.max_query_len, @@ -616,6 +594,750 @@ 
def decorator(cls): return decorator +# for MLA attention metadata for plugin mode +if _MLA_ATTENTION_FOR_PLUGIN_MODE: + + @dataclass + class AiterMLACommonDecodeMetadataForPluginMode: + block_table: torch.Tensor + seq_lens: torch.Tensor + dcp_tot_seq_lens: torch.Tensor | None + + @dataclass + class AiterMLADecodeMetadataForPluginMode( + AiterMLACommonDecodeMetadataForPluginMode + ): + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: torch.Tensor | None = None + # The page indices of the paged kv cache + paged_kv_indices: torch.Tensor | None = None + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_len: torch.Tensor | None = None + # The query indptr, shape : [num_decode + 1] + qo_indptr: torch.Tensor | None = None + # The dtype of MLA out tensor + attn_out_dtype: torch.dtype = torch.bfloat16 + # The max query output length: int + max_qo_len: int | None = None + + @dataclass + class AiterMLACommonPrefillMetadataForPluginMode: + """Prefill Specific Metadata""" + + @dataclass + class AiterMLAChunkedContextMetadataForPluginMode: + # New for MLA (compared to FlashAttention) + # For handling chunked prefill + cu_seq_lens: torch.Tensor + starts: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + seq_lens: torch.Tensor + workspace: torch.Tensor + token_to_seq: torch.Tensor + chunk_total_token: list[int] + + # for mla DCP + padded_local_chunk_seq_lens: list[list[int]] | None = None + local_context_lens_allranks: list[list[int]] | None = None + padded_local_cu_seq_lens: torch.Tensor | None = None + cu_seq_lens_lst: list[list[int]] | None = None + chunk_size: int | None = None + + block_table: torch.Tensor + query_start_loc: torch.Tensor + max_query_len: int + chunked_context: AiterMLAChunkedContextMetadataForPluginMode | None = None + query_seq_lens: torch.Tensor | None = None + workspace_buffer: torch.Tensor | None = None + q_data_type: torch.dtype | None = 
None + + D = TypeVar("D", bound=AiterMLACommonDecodeMetadataForPluginMode) + + @dataclass + class AiterMLACommonMetadataForPluginMode(Generic[D]): + """Metadata for MLACommon. + + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + num_reqs: int + max_query_len: int + max_seq_len: int + + num_actual_tokens: int # Number of tokens excluding padding. + query_start_loc: torch.Tensor + slot_mapping: torch.Tensor + + # New for MLA (compared to FlashAttention) + # For handling prefill decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int + + # The dimension of the attention heads + head_dim: int | None = None + + decode: D | None = None + prefill: AiterMLACommonPrefillMetadataForPluginMode | None = None + + def __post_init__(self): + pass + # if self.head_dim is not None and not MLACommonBackend.supports_head_size( + # self.head_dim + # ): + # raise ValueError(f"Head dimension {self.head_dim} is not supported by MLA.") + + class vllmMLAAttentionMetadataBuilderMethods: + def __init__(self): + raise TypeError( + f"{self.__class__.__name__} is a utility class and should not be instantiated. " + "Its methods are meant to be added to other classes via decorators." + ) + + def _build_decode( + self, + block_table_tensor: torch.Tensor, + seq_lens_device: torch.Tensor, + max_seq_len: int, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int, + dcp_tot_seq_lens_device: torch.Tensor | None, + ): + # kernel block size is always 1, although the kv block size is not 1. 
+ device = self.device + num_reqs = seq_lens_device.size(0) + + mask = torch.arange( + block_table_tensor.size(1), + dtype=block_table_tensor.dtype, + device=device, + ).unsqueeze(0) < seq_lens_device.unsqueeze(1) + paged_kv_indices = block_table_tensor[mask] + + # kernel block size is always 1, so each page has exactly 1 token. + # last_page_len is always 1 - just slice the pre-initialized buffer. + paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs] + + paged_kv_indptr = torch.cat( + [ + torch.zeros(1, dtype=seq_lens_device.dtype, device=device), + seq_lens_device.cumsum(dim=0, dtype=torch.int32), + ] + ) + qo_len = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + max_qo_len = qo_len.max().item() + + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + num_actual_pages = paged_kv_indices.size(0) + + self.paged_kv_indices[:num_actual_pages].copy_( + paged_kv_indices, non_blocking=True + ) + self.paged_kv_indices[num_actual_pages:].fill_(-1) + paged_kv_indices = self.paged_kv_indices[:num_actual_pages] + + self.paged_kv_indptr[: 1 + num_reqs].copy_( + paged_kv_indptr, non_blocking=True + ) + self.paged_kv_indptr[1 + num_reqs :].fill_(paged_kv_indptr[-1]) + paged_kv_indptr = self.paged_kv_indptr[: 1 + num_reqs] + + # paged_kv_last_page_len already uses the pre-initialized buffer slice + # (set above), so no copy needed - buffer is always 1s. 
+ + self.qo_indptr[: 1 + num_reqs].copy_( + query_start_loc_device, non_blocking=True + ) + self.qo_indptr[1 + num_reqs :] = query_start_loc_device[-1] + qo_indptr = self.qo_indptr[: 1 + num_reqs] + + else: + qo_indptr = torch.arange( + 0, num_reqs + 1, step=1, dtype=torch.int32, device=device + ) + + attn_metadata = AiterMLADecodeMetadataForPluginMode( + block_table=block_table_tensor, + seq_lens=seq_lens_device, + paged_kv_indptr=paged_kv_indptr, + paged_kv_indices=paged_kv_indices, + paged_kv_last_page_len=paged_kv_last_page_len, + qo_indptr=qo_indptr, + dcp_tot_seq_lens=dcp_tot_seq_lens_device, + max_qo_len=max_qo_len, + attn_out_dtype=self.decode_attn_out_dtype, + ) + + return attn_metadata + + def build_for_cudagraph_capture( + self, + common_attn_metadata=None, + ): + m = common_attn_metadata + # assert m.num_reqs <= (m.num_actual_tokens * self.reorder_batch_threshold), ( + # "MLA only supports decode-only full CUDAGraph capture. " + # "Make sure all cudagraph capture sizes <= max_num_seq." + # ) + + # assert m.max_query_len <= self.reorder_batch_threshold # decode only + + return self.build(0, m) + + def build( + self, + common_prefix_len: int = 0, + common_attn_metadata=None, + fast_build: bool = False, + ): + + from vllm.v1.attention.backends.utils import split_decodes_and_prefills + from vllm.model_executor.layers.attention.mla_attention import ( + QueryLenSupport, + ) + + from vllm.utils.math_utils import cdiv, round_down + from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens + + num_reqs = common_attn_metadata.num_reqs + num_tokens = common_attn_metadata.num_actual_tokens + max_query_len = common_attn_metadata.max_query_len + max_seq_len = common_attn_metadata.max_seq_len + + # Note(simon): be careful about the CPU <> GPU memory movement in this + # function. We should avoid GPU -> CPU sync as much as possible because + # it blocks on all previous kernels. 
+ device = self.device + block_table_tensor = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping + + query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + seq_lens = common_attn_metadata.seq_lens + dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens + + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold, + require_uniform=(self.query_len_support != QueryLenSupport.VARLEN), + ) + ) + + assert num_decodes + num_prefills == num_reqs + assert num_decode_tokens + num_prefill_tokens == num_tokens + + prefill_metadata = None + if num_prefills > 0: + num_computed_tokens_cpu = ( + common_attn_metadata.compute_num_computed_tokens().cpu() + ) + + reqs_start = num_decodes # prefill_start + + context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] + max_context_len_cpu = context_lens_cpu.max().item() + num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() + prefill_query_start_loc = ( + query_start_loc[reqs_start:] - query_start_loc[reqs_start] + ) + + chunked_context_metadata = None + if max_context_len_cpu > 0: + # NOTE: it is recommend you read the `Chunked Prefill` section + # in the comment at the top of the file before trying to + # understand the following code + + # currently we allocate an equal amount of workspace for each + # prefill in the batch, we could probably use a more advanced + # algorithm here and allocate more workspace to prefills with + # longer context lengths + max_context_chunk = ( + self.chunked_prefill_workspace_size + // num_prefills_with_context_cpu + ) + + if self.aot_schedule: + # align max_context_chunk to page_size by rounding down, + # currently the `gather_and_maybe_dequant_cache` kernel + # cannot handle `context_chunk_starts` that are not aligned + # to page_size + max_context_chunk = 
round_down( + max_context_chunk, self.page_size + ) + + assert max_context_chunk > 0 + num_chunks = cdiv(max_context_len_cpu, max_context_chunk) + + # if `max_context_chunk = 256`, `num_chunks = 3`, and + # `num_prefills_with_context = 4`, create a tensor that looks + # like + # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] + # Note(simon): this is done in CPU because of downstream's + # of `to_list`. + chunk_starts = ( + torch.arange(num_chunks, dtype=torch.int32) + .unsqueeze(1) + .expand(-1, num_prefills) + * max_context_chunk + ) + chunk_ends = torch.min( + context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk + ) + chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) + + cu_seq_lens_cpu = torch.zeros( + num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True + ) + torch.cumsum( + chunk_seq_lens, + dim=1, + out=cu_seq_lens_cpu[:, 1:], + dtype=torch.int32, + ) + chunk_total_token = cu_seq_lens_cpu[:, -1] + + max_token_num_over_chunk = chunk_total_token.max().item() + token_to_seq_tensor_cpu = torch.zeros( + [num_chunks, max_token_num_over_chunk], dtype=torch.int32 + ) + range_idx = torch.arange(num_prefills, dtype=torch.int32) + for i in range(num_chunks): + chunk_token_to_seq_tensor = torch.repeat_interleave( + range_idx, chunk_seq_lens[i] + ) + chunk_len = chunk_token_to_seq_tensor.shape[0] + token_to_seq_tensor_cpu[i, :chunk_len] = ( + chunk_token_to_seq_tensor + ) + + if self.dcp_world_size > 1: + local_context_lens_allranks = get_dcp_local_seq_lens( + context_lens_cpu, + self.dcp_world_size, + None, + self.dcp_local_block_size, + ) + # Note(qcs): The max local context lengths + # padded to `dcp_local_block_size`. 
+ padded_local_context_lens_cpu: torch.Tensor = ( + cdiv( + context_lens_cpu, + self.dcp_virtual_block_size, + ) + * self.dcp_local_block_size + ) + # Note(hc): The above max_context_chunk already enforces + # block_size alignment, DCP just need the block_size can + # be divisible by dcp_world_size, because DCP use + # cp_gather_cache which not require `cp_chunk_starts` + # aligned to page_size. + assert max_context_chunk % self.dcp_world_size == 0 + padded_local_max_context_chunk_across_ranks = ( + cdiv( + max_context_chunk, + self.dcp_virtual_block_size, + ) + * self.dcp_local_block_size + ) + local_chunk_starts = ( + torch.arange(num_chunks, dtype=torch.int32) + .unsqueeze(1) + .expand(-1, num_prefills) + * padded_local_max_context_chunk_across_ranks + ) + local_chunk_ends = torch.min( + padded_local_context_lens_cpu.unsqueeze(0), + local_chunk_starts + + padded_local_max_context_chunk_across_ranks, + ) + padded_local_chunk_seq_lens = ( + local_chunk_ends - local_chunk_starts + ).clamp(min=0) + + padded_local_cu_chunk_seq_lens_cpu = torch.zeros( + num_chunks, + num_prefills + 1, + dtype=torch.int32, + pin_memory=True, + ) + torch.cumsum( + padded_local_chunk_seq_lens, + dim=1, + out=padded_local_cu_chunk_seq_lens_cpu[:, 1:], + dtype=torch.int32, + ) + + chunked_context_metadata_cls = ( + AiterMLACommonPrefillMetadataForPluginMode.AiterMLAChunkedContextMetadataForPluginMode + ) + if self.dcp_world_size > 1: + chunked_context_metadata = chunked_context_metadata_cls( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), + starts=local_chunk_starts.to(device, non_blocking=True), + seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, + token_to_seq=token_to_seq_tensor_cpu.to( + device, non_blocking=True + ), + chunk_total_token=chunk_total_token.tolist(), + workspace=self.chunked_prefill_workspace, + padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(), + 
local_context_lens_allranks=local_context_lens_allranks.tolist(), + padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.to( + device, non_blocking=True + ), + cu_seq_lens_lst=cu_seq_lens_cpu.tolist(), + chunk_size=padded_local_max_context_chunk_across_ranks, + ) + else: + chunked_context_metadata = chunked_context_metadata_cls( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), + starts=chunk_starts.to(device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, + token_to_seq=token_to_seq_tensor_cpu.to( + device, non_blocking=True + ), + chunk_total_token=chunk_total_token, + workspace=self.chunked_prefill_workspace, + ) + + if self._use_cudnn_prefill: + chunked_context_metadata.seq_lens = chunk_seq_lens + + assert ( + max(chunked_context_metadata.max_seq_lens) + <= self.chunked_prefill_workspace_size + ) + + prefill_metadata = AiterMLACommonPrefillMetadataForPluginMode( + block_table=block_table_tensor[reqs_start:, ...], + query_start_loc=prefill_query_start_loc, + max_query_len=max_query_len, + chunked_context=chunked_context_metadata, + ) + + decode_metadata = None + if num_decodes > 0: + dcp_tot_seq_lens_device = None + if self.dcp_world_size > 1: + dcp_tot_seq_lens_device = seq_lens[:num_decodes] + seq_lens = dcp_local_seq_lens + + # After DCP distribution, the maximum number of tokens for any rank is + # ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size, + # and I is cp_kv_cache_interleave_size. + # This eliminates GPU->CPU sync while minimizing workspace + # over-allocation. 
+ num_partitions = ( + self.dcp_world_size * self.cp_kv_cache_interleave_size + ) + max_seq_len = ( + (max_seq_len + num_partitions - 1) // num_partitions + ) * self.cp_kv_cache_interleave_size + + decode_metadata = self._build_decode( + block_table_tensor=block_table_tensor[:num_decodes, ...], + seq_lens_device=seq_lens[:num_decodes], + max_seq_len=max_seq_len, + query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1], + query_start_loc_device=query_start_loc[: num_decodes + 1], + num_decode_tokens=num_decode_tokens, + dcp_tot_seq_lens_device=dcp_tot_seq_lens_device, + ) + + attn_metadata_for_plugin_mode = AiterMLACommonMetadataForPluginMode( + num_reqs=common_attn_metadata.num_reqs, + max_query_len=common_attn_metadata.max_query_len, + max_seq_len=max_seq_len, + num_actual_tokens=num_tokens, + query_start_loc=query_start_loc, + slot_mapping=slot_mapping, + head_dim=self.model_config.get_head_size(), + # MLACommonMetadata Chunk prefill specific + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + prefill=prefill_metadata, + decode=decode_metadata, + ) + + attn_metadata = AttentionMetaData( + max_seqlen_q=common_attn_metadata.max_query_len, + block_tables=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping, + plugin_metadata=attn_metadata_for_plugin_mode, + ) + return attn_metadata + + def create_mla_attn_metadata_builder_init_method(base_class): + """ + Create the init method for metadata builder + """ + + def init_method_under_plugin_mode( + self, + kv_cache_spec=None, + layer_names=None, + config=None, + device=None, + model_runner=None, + ): + base_class.__init__(self, kv_cache_spec, layer_names, config, device) + logger.info("init AiterAttentionMetadataBuilder for plugin mode") + from vllm.config import VllmConfig + + assert isinstance(config, VllmConfig) + + self.vllm_config = config + self.model_config = config.model_config + self.parallel_config = config.parallel_config + 
def AiterMLAAttentionMetadataBuilderDecoratorForPluginMode(default_base_class):
    """
    Build a class decorator that re-creates an MLA metadata builder for plugin mode.

    The returned decorator rebuilds the decorated class with ``type(...)`` so
    that its single base class can be swapped depending on the host framework:

    * vLLM: the base class, the cudagraph/query-length class attributes, the
      ``__init__`` and every public callable of
      ``vllmMLAAttentionMetadataBuilderMethods`` are injected.
    * SGLang: not supported, raises ``NotImplementedError``.
    * otherwise: the class is rebuilt on top of ``default_base_class`` with
      its own attributes only.
    """
    # Dunder entries that must survive the rebuild; every other ``__x__`` key
    # of the decorated class is dropped from the new namespace.
    preserved_dunders = (
        "__annotations__",
        "__init__",
        "__module__",
        "__qualname__",
        "__doc__",
    )

    def decorator(cls):
        # Both framework probes are evaluated up front, in this order.
        vllm_active = is_vllm()
        sglang_active = is_sglang()

        # Snapshot of what the decorated class defines itself.
        namespace = {
            attr: obj
            for attr, obj in cls.__dict__.items()
            if not attr.startswith("__") or attr in preserved_dunders
        }

        parent = default_base_class
        generic_parent = None
        wants_generic = False

        if vllm_active:
            # Resolve the vLLM base class and add the class-level attributes.
            parent, generic_parent, wants_generic, namespace = (
                setup_mla_attn_metadata_builder_base_class_and_attributes(
                    namespace
                )
            )

            # The plugin-mode constructor replaces whatever the class declared.
            namespace["__init__"] = create_mla_attn_metadata_builder_init_method(
                parent
            )

            # Inject every public callable from the method container.
            provider = vllmMLAAttentionMetadataBuilderMethods
            for attr in dir(provider):
                if attr.startswith("__"):
                    continue
                candidate = getattr(provider, attr)
                if callable(candidate):
                    namespace[attr] = candidate
        elif sglang_active:
            raise NotImplementedError(
                "AttentionMetadataBuilder for sglang is not implemented yet"
            )

        rebuilt = type(cls.__name__, (parent,), namespace)

        # Record the parametrised generic base so typing introspection of the
        # rebuilt class still sees it.
        if wants_generic and generic_parent is not None:
            rebuilt.__orig_bases__ = (generic_parent[rebuilt],)

        return rebuilt

    return decorator
ATOM will + # handle the output buffer by itself + accept_output_buffer: bool = True + supported_dtypes: list = [torch.float16, torch.bfloat16] + + def __init__(self): + raise TypeError( + f"{self.__class__.__name__} is a utility class and should not be instantiated. " + "Its methods are meant to be added to other classes via decorators." + ) + + @staticmethod + def get_supported_kernel_block_sizes(): + return [1] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + return (num_blocks, block_size, head_size) + + @classmethod + def is_mla(cls) -> bool: + return True + + @staticmethod + def get_required_kv_cache_layout(): + return None + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [576] + + @classmethod + def full_cls_name(cls) -> tuple[str, str]: + return (cls.__module__, cls.__qualname__) + + @classmethod + def supports_alibi_sqrt(cls) -> bool: + return False + + @staticmethod + def get_kv_cache_stride_order( + include_num_layers_dimension: bool = False, + ) -> tuple[int, ...]: + # `stride_order` indicates the permutation that gets + # us from `get_kv_cache_shape` to the actual memory layout we want. 
def AiterBackendDecoratorForPluginMode(cls):
    """
    Decorator for AiterBackend to add specific methods and attributes for plugin mode.

    Under vLLM the backend-contract attributes are grafted onto ``cls``:
    from ``vllmAiterMLABackendMethods`` when ``cls.get_impl_cls()`` is a
    subclass of ``MLAAttention``, otherwise from
    ``vllmAiterAttentionBackendMethods``. Outside vLLM ``cls`` is returned
    unchanged.

    The attributes are copied as *raw descriptors* (``inspect.getattr_static``)
    rather than through normal attribute access. ``Helper.some_classmethod``
    evaluates to a method already bound to ``Helper``; storing that on ``cls``
    made ``cls.full_cls_name()`` report the helper's module/qualname instead
    of the backend's own. Copying the descriptor lets ``classmethod`` rebind
    to ``cls`` and keeps ``staticmethod`` static even when reached through an
    instance.

    Args:
        cls: the attention backend class being decorated; must expose
            ``get_impl_cls()``.

    Returns:
        The same ``cls`` object, mutated in place when running under vLLM.
    """
    # Local import: only the stdlib helper is needed and only in this function.
    import inspect

    if not is_vllm():
        return cls

    # Attributes provided by both helper containers.
    common_attrs = (
        "full_cls_name",
        "accept_output_buffer",
        "supported_dtypes",
        "get_supported_kernel_block_sizes",
        "get_kv_cache_shape",
        "is_mla",
        "get_required_kv_cache_layout",
        "get_supported_head_sizes",
        "supports_alibi_sqrt",
    )

    if issubclass(cls.get_impl_cls(), MLAAttention):
        # for mla, add the required methods
        provider = vllmAiterMLABackendMethods
        extra_attrs = ("get_kv_cache_stride_order",)
    else:
        # for standard (non-MLA) attention, add the required methods
        provider = vllmAiterAttentionBackendMethods
        extra_attrs = ("forward_includes_kv_cache_update",)

    for attr in common_attrs + extra_attrs:
        # getattr_static walks the MRO without triggering the descriptor
        # protocol, so classmethod/staticmethod wrappers are preserved. A
        # missing attribute still raises AttributeError, as before.
        setattr(cls, attr, inspect.getattr_static(provider, attr))

    return cls
def reorg_kvcache(
    allgatered_kv_c_normed: torch.Tensor,
    allgatered_k_pe: torch.Tensor,
    padded_local_chunk_seq_lens_lst: list[int],
    local_context_lens_allranks: list[list[int]],
    sum_seq_len: int,
    max_seq_len: int,
    chunk_size: int,
    chunk_idx: int,
    toks: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Reorganize and unpad the KV cache after a CP local gather into the
    per-sequence (TP) layout expected by the attention kernel.

    e.g.
    allgatered_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T1_0, T1_1, ...,
                              T0_4, T0_5, pad, pad, T1_2, pad, ...]
    -> reorganized_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T0_4, T0_5,
                                  T1_0, T1_1, T1_2, ...]

    Args:
        allgatered_kv_c_normed: all-gathered latent KV; rank ``r`` occupies
            rows ``[r * toks, (r + 1) * toks)`` along dim 0.
        allgatered_k_pe: all-gathered rope keys, indexed like
            ``allgatered_kv_c_normed``.
        padded_local_chunk_seq_lens_lst: padded local chunk context length of
            each sequence under the current CP rank.
        local_context_lens_allranks: for each sequence, the local context
            length held by every CP rank.
        sum_seq_len: expected total number of tokens in the result.
        max_seq_len: expected maximum per-sequence token count in the result.
        chunk_size: the local padded max context chunk from
            chunked_context_metadata building.
        chunk_idx: chunk index of chunked prefill.
        toks: the number of tokens each rank contributed to the local gather.

    Returns:
        ``(reorganized_kv_c_normed, reorganized_k_pe)`` concatenated along
        dim 0. When no rank holds any token for this chunk, zero-length
        slices of the inputs are returned instead of failing inside
        ``torch.cat``.

    Raises:
        AssertionError: if the reorganized lengths disagree with
            ``sum_seq_len`` / ``max_seq_len``.
    """
    kv_c_segments = []
    k_pe_segments = []
    # Number of local tokens already consumed by the preceding chunks.
    consumed_by_prev_chunks = chunk_idx * chunk_size
    src_token_idx = 0
    max_seq_len_check = 0
    for padded_local_chunk_seq_len, local_context_lens in zip(
        padded_local_chunk_seq_lens_lst, local_context_lens_allranks
    ):
        cur_seq_len = 0
        for rank, local_context_len in enumerate(local_context_lens):
            # Note(qcs): We split the context into multiple chunks,
            # depending on the size of the workspace.
            # local_context in dcp0:   |-----------------|
            # local_context in dcp1:   |--------------|
            # n*padded_local_chunk:    |-----|-----|-----|
            # local_chunk_len in dcp1: |-----|-----|--|
            # so the last chunk on a rank may be shorter than the padding.
            local_chunk_len = min(
                max(0, local_context_len - consumed_by_prev_chunks),
                padded_local_chunk_seq_len,
            )
            if local_chunk_len == 0:
                continue
            # Compute the window once; both tensors share the same layout.
            start = rank * toks + src_token_idx
            stop = start + local_chunk_len
            kv_c_segments.append(allgatered_kv_c_normed[start:stop])
            k_pe_segments.append(allgatered_k_pe[start:stop])
            cur_seq_len += local_chunk_len
        max_seq_len_check = max(max_seq_len_check, cur_seq_len)
        # Every sequence occupies its padded length in each rank's section.
        src_token_idx += padded_local_chunk_seq_len

    if kv_c_segments:
        reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
        reorganized_k_pe = torch.cat(k_pe_segments, dim=0)
    else:
        # torch.cat rejects an empty list; keep dtype/device/trailing dims.
        reorganized_kv_c_normed = allgatered_kv_c_normed[:0]
        reorganized_k_pe = allgatered_k_pe[:0]

    assert reorganized_kv_c_normed.shape[0] == sum_seq_len
    assert reorganized_k_pe.shape[0] == sum_seq_len
    assert max_seq_len_check == max_seq_len
    return reorganized_kv_c_normed, reorganized_k_pe
+ """ + + def __init__(self): + raise TypeError( + "MLAAttentionImplPluginModeMethods cannot be instantiated. " + "It is only used as a method container for the decorator." + ) + + def _concat_k_nope_k_pe_plugin_mode( + self, k_nope: torch.Tensor, k_pe: torch.Tensor + ) -> torch.Tensor: + """ + Efficiently concatenate k_nope and k_pe tensors along the last dimension. + + This function avoids the performance penalty of torch.cat with expanded + non-contiguous tensors by pre-allocating the output and using direct copies. + + Args: + k_nope: Tensor of shape [..., nope_dim] + k_pe: Tensor to broadcast and concatenate, typically shape [..., 1, pe_dim] + or [..., pe_dim] + + Returns: + Tensor of shape [..., nope_dim + pe_dim] + """ + k = torch.empty( + (*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]), + dtype=k_nope.dtype, + device=k_nope.device, + ) + # Direct copies with efficient broadcasting + k[..., : k_nope.shape[-1]] = k_nope + k[..., k_nope.shape[-1] :] = k_pe + return k + + def _flash_attn_varlen_diff_headdims( + self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs + ): + output = self.flash_attn_varlen_func( + q=q, + k=k, + v=v, + softmax_scale=softmax_scale, + return_lse=return_softmax_lse, + **kwargs, + ) + + return output + + def _run_prefill_new_tokens_plugin_mode(self, prefill, q, k, v, return_softmax_lse): + return self._flash_attn_varlen_diff_headdims( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill.query_start_loc, + cu_seqlens_k=prefill.query_start_loc, + max_seqlen_q=prefill.max_query_len, + max_seqlen_k=prefill.max_query_len, + softmax_scale=self.scale, + causal=True, + return_softmax_lse=return_softmax_lse, + ) + + def _run_prefill_context_chunk_plugin_mode(self, prefill, chunk_idx, q, k, v): + assert prefill.chunked_context is not None + assert prefill.chunked_context is not None + return self._flash_attn_varlen_diff_headdims( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill.query_start_loc, + 
def _context_parallel_compute_prefill_context_plugin_mode(
    self,
    q,
    kv_c_and_k_pe_cache,
    attn_metadata,
    k_scale,
    dcp_world_size,
):
    """
    Attend the prefill queries to their cached context when the KV cache is
    sharded across decode-context-parallel (DCP) ranks.

    For every context chunk: gather this rank's cache slice into the
    workspace, all-gather it across the DCP group, reorder it into
    per-sequence layout with ``reorg_kvcache``, up-project through
    ``kv_b_proj``, run non-causal attention, and merge the partial result
    into the running output with ``merge_attn_states``.

    Args:
        q: prefill queries, forwarded unchanged to the attention kernel.
        kv_c_and_k_pe_cache: paged cache holding latent KV and rope keys.
        attn_metadata: metadata whose ``plugin_metadata.prefill.chunked_context``
            carries the DCP-specific fields asserted below.
        k_scale: must be ``None``; a scaled KV cache is not supported here.
        dcp_world_size: number of ranks in the DCP group.

    Returns:
        ``(output, output_lse)`` merged over all context chunks.
    """
    assert k_scale is None, "DCP not support scaled kvcache now."
    assert attn_metadata.plugin_metadata.prefill is not None
    prefill_metadata = attn_metadata.plugin_metadata.prefill
    chunked = prefill_metadata.chunked_context
    assert chunked is not None
    assert chunked.padded_local_chunk_seq_lens is not None
    assert chunked.local_context_lens_allranks is not None
    assert chunked.padded_local_cu_seq_lens is not None
    assert chunked.cu_seq_lens_lst is not None
    assert chunked.chunk_size is not None

    output = None
    iters = len(chunked.seq_tot)
    workspace = chunked.workspace

    from vllm import _custom_ops as ops
    from vllm.distributed.parallel_state import get_dcp_group
    from vllm.v1.attention.ops.merge_attn_states import merge_attn_states

    for i in range(iters):
        toks = chunked.seq_tot[i]
        ops.cp_gather_cache(
            src_cache=kv_c_and_k_pe_cache,
            dst=workspace,
            block_table=prefill_metadata.block_table,
            cu_seq_lens=chunked.padded_local_cu_seq_lens[i],
            batch_size=attn_metadata.plugin_metadata.num_prefills,
            seq_starts=chunked.starts[i],
        )
        # workspace
        # |------- N tokens ---------|--------- N*dcp_size tokens ----------|
        # |<- use for local gather ->|<--------- use for allgather -------->|
        allgather_offset = workspace.shape[0] // (dcp_world_size + 1)
        assert allgather_offset * (dcp_world_size + 1) == workspace.shape[0]
        assert toks <= allgather_offset
        local_gathered_kvcache = workspace[:toks]
        cur_allgather_workspace = workspace[
            allgather_offset : allgather_offset * (1 + dcp_world_size)
        ]
        assert toks * dcp_world_size <= cur_allgather_workspace.shape[0]
        cur_allgather_kvcache = cur_allgather_workspace[: toks * dcp_world_size]
        cur_allgather_kvcache.copy_(
            get_dcp_group().all_gather(local_gathered_kvcache, dim=0)
        )
        assert (
            cur_allgather_kvcache.shape[-1]
            == self.kv_lora_rank + self.qk_rope_head_dim
        )
        # Insert a singleton dim at position 1, then split the last dim into
        # latent KV and rope key.
        allgatered_kv_c_normed, allgatered_k_pe = cur_allgather_kvcache.unsqueeze(
            1
        ).split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        kv_c_normed, k_pe = reorg_kvcache(
            allgatered_kv_c_normed,
            allgatered_k_pe,
            padded_local_chunk_seq_lens_lst=chunked.padded_local_chunk_seq_lens[i],
            local_context_lens_allranks=chunked.local_context_lens_allranks,
            sum_seq_len=chunked.cu_seq_lens_lst[i][-1],
            max_seq_len=chunked.max_seq_lens[i],
            chunk_size=chunked.chunk_size,
            chunk_idx=i,
            toks=toks,
        )

        kv_nope = self.kv_b_proj(kv_c_normed).view(
            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
        )
        k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        # Use the plugin-mode helpers defined alongside this method; the
        # non-DCP sibling `_compute_prefill_context_plugin_mode` calls the
        # same ones. The unsuffixed names are not defined in this container.
        k = self._concat_k_nope_k_pe_plugin_mode(k_nope, k_pe)

        attn_output, attn_softmax_lse = self._run_prefill_context_chunk_plugin_mode(
            prefill=prefill_metadata,
            chunk_idx=i,
            q=q,
            k=k,
            v=v,
        )

        if output is None:
            # First chunk seeds the running result.
            output = attn_output
            output_lse = attn_softmax_lse
        else:
            # Merge into fresh buffers; the previous result is an input here.
            output_tmp = torch.empty_like(output)
            output_lse_tmp = torch.empty_like(output_lse)
            merge_attn_states(
                output=output_tmp,
                output_lse=output_lse_tmp,
                prefix_output=output,
                prefix_lse=output_lse,
                suffix_output=attn_output,
                suffix_lse=attn_softmax_lse,
            )
            output = output_tmp
            output_lse = output_lse_tmp

    return output, output_lse
def _forward_prefill_plugin_mode(
    self,
    q,
    kv_c_normed,
    k_pe,
    kv_c_and_k_pe_cache,
    attn_metadata,
    k_scale,
    output,
):
    """
    Run the prefill part of MLA attention and write the result into ``output``.

    Steps:
      1. Up-project ``kv_c_normed`` through ``kv_b_proj`` into per-head
         ``k`` / ``v``, using a fused quantised GEMM + split + concat kernel
         when one is available for the weight dtype, otherwise the plain
         projection followed by ``torch.cat`` with the expanded ``k_pe``.
      2. Run causal attention over the new tokens.
      3. If chunked context exists, compute attention over the cached context
         (DCP or single-rank variant) and merge both results into ``output``;
         otherwise copy the new-token result into ``output``.

    Args:
        q: prefill queries.
        kv_c_normed: latent KV of the new prefill tokens.
        k_pe: rope keys of the new prefill tokens; expanded over heads below,
            so presumably it has a singleton head dim — TODO confirm with caller.
        kv_c_and_k_pe_cache: paged KV cache, used only for the context part.
        attn_metadata: metadata carrying ``plugin_metadata.prefill``.
        k_scale: cache scale forwarded to the non-DCP context path.
        output: destination tensor, written in place.

    Returns:
        None. The result is delivered through ``output``.
    """
    # TODO (zyongye): Prefill function here
    assert attn_metadata.plugin_metadata.prefill is not None
    # dcp_world_size == -1 means it has not been resolved yet.
    assert self.dcp_world_size != -1

    has_context = attn_metadata.plugin_metadata.prefill.chunked_context is not None

    if use_triton_gemm():
        weight = self.kv_b_proj.weight
        weight_scale = self.kv_b_proj.weight_scale
        if (
            fused_gemm_afp4wfp4_preshuffle_split_cat is not None
            and weight.dtype == dtypes.fp4x2
        ):  # FP4 GEMM + split + cat
            m = kv_c_normed.shape[0]
            output_dtype = kv_c_normed.dtype

            # Quantise the activation with the HIP per_1x32 quantiser; the
            # scale is shuffled only when there are at least 32 rows.
            quant_func = aiter.get_hip_quant(aiter.QuantType.per_1x32)
            q_input, x_scale = quant_func(
                kv_c_normed,
                quant_dtype=dtypes.fp4x2,
                shuffle=(m >= 32),
            )

            if m >= 32:
                x_scale = x_scale.view(torch.uint8).view(x_scale.shape[0] // 32, -1)
            else:
                # Keep only the first m rows of the scale in the small case.
                x_scale = x_scale[:m, ...].view(torch.uint8)

            k, v = fused_gemm_afp4wfp4_preshuffle_split_cat(
                q_input.view(torch.uint8),
                weight.view(torch.uint8).view(weight.shape[0] // 16, -1),
                k_pe.expand((-1, self.num_heads, -1)),
                x_scale,
                weight_scale.view(torch.uint8).view(
                    weight_scale.shape[0] // 32, -1
                ),
                self.qk_nope_head_dim,
                self.v_head_dim,
                output_dtype,
            )
        elif (
            fused_gemm_a8w8_blockscale_preshuffle_split_cat is not None
            and weight.dtype == dtypes.fp8
        ):  # FP8 GEMM + split + cat
            # NOTE(review): the reshape presumably matches the kernel's
            # pre-shuffled weight layout — confirm against the aiter kernel.
            weight_shuffled = weight.reshape(
                weight.shape[0] // 16, weight.shape[1] * 16
            )

            output_dtype = kv_c_normed.dtype

            quant_func = functools_partial(
                aiter.get_hip_quant(aiter.QuantType.per_1x128), transpose_scale=True
            )
            # input_scale is optional on kv_b_proj; None is passed when absent.
            q_input, x_scale = quant_func(
                kv_c_normed,
                quant_dtype=dtypes.fp8,
                scale=getattr(self.kv_b_proj, "input_scale", None),
            )

            k, v = fused_gemm_a8w8_blockscale_preshuffle_split_cat(
                q_input,
                weight_shuffled,
                k_pe.expand((-1, self.num_heads, -1)),
                x_scale,
                weight_scale,
                self.qk_nope_head_dim,
                self.v_head_dim,
                output_dtype,
            )
        else:
            # No fused kernel for this weight dtype (or its import failed):
            # plain projection, split, then concat with the expanded rope key.
            kv_nope = self.kv_b_proj(kv_c_normed).view(
                -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
            )
            k_nope, v = kv_nope.split(
                [self.qk_nope_head_dim, self.v_head_dim], dim=-1
            )

            k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
    else:
        kv_nope = self.kv_b_proj(kv_c_normed).view(
            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
        )
        k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        # torch.cat is used here rather than _concat_k_nope_k_pe_plugin_mode.
        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

    # The LSE is requested only when a context result must be merged later.
    output_prefill = self._run_prefill_new_tokens_plugin_mode(
        prefill=attn_metadata.plugin_metadata.prefill,
        q=q,
        k=k,
        v=v,
        return_softmax_lse=has_context,
    )

    from vllm.v1.attention.ops.merge_attn_states import merge_attn_states

    if has_context:
        suffix_output, suffix_lse = output_prefill
        if self.dcp_world_size > 1:
            # The DCP path asserts k_scale is None, so None is passed explicitly.
            context_output, context_lse = (
                self._context_parallel_compute_prefill_context_plugin_mode(
                    q,
                    kv_c_and_k_pe_cache,
                    attn_metadata,
                    k_scale=None,
                    dcp_world_size=self.dcp_world_size,
                )
            )
        else:
            context_output, context_lse = self._compute_prefill_context_plugin_mode(
                q, kv_c_and_k_pe_cache, attn_metadata, k_scale
            )

        # unpad if necessary
        if self._pad_v:
            context_output = context_output[..., : v.shape[-1]]
            suffix_output = suffix_output[..., : v.shape[-1]]

        # View the destination per head so the merge writes into it in place.
        output = output.view(-1, self.num_heads, self.v_head_dim)
        merge_attn_states(
            output=output,
            prefix_output=context_output,
            prefix_lse=context_lse,
            suffix_output=suffix_output,
            suffix_lse=suffix_lse,
        )
    else:
        # No cached context: trim to the v width, flatten the last two dims,
        # and copy into the destination.
        output_prefill = output_prefill[..., : v.shape[-1]].flatten(start_dim=-2)
        output.copy_(output_prefill)
forward_impl_plugin_mode( + self, + layer, + q, + k_c_normed, + k_pe, + kv_cache, + attn_metadata=None, + output=None, + ): + assert output is not None, "Output tensor must be provided." + from vllm.distributed.parallel_state import get_dcp_group + from vllm import _custom_ops as ops + from vllm.platforms import current_platform + from vllm.v1.attention.ops.common import cp_lse_ag_out_rs + + # create the output here, it use query shape + if attn_metadata is None: + # During the profile run try to simulate to worse case output size + # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context` + # since this can be large + _ = torch.empty( + ( + self.chunked_prefill_workspace_size, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ), + device=k_c_normed.device, + dtype=k_c_normed.dtype, + ) + + # The zero fill is required when used with DP + EP + # to ensure all ranks within a DP group compute the + # same expert outputs. + return output.fill_(0) + + if self.dcp_world_size == -1: + self.dcp_world_size = get_dcp_group().world_size + + fp8_attention = self.kv_cache_dtype.startswith("fp8") + + num_actual_toks = attn_metadata.plugin_metadata.num_actual_tokens + + # Inputs and outputs may be padded for CUDA graphs + assert ( + attn_metadata.plugin_metadata.num_decodes is not None + and attn_metadata.plugin_metadata.num_prefills is not None + and attn_metadata.plugin_metadata.num_decode_tokens is not None + ) + + has_decode = attn_metadata.plugin_metadata.num_decodes > 0 + has_prefill = attn_metadata.plugin_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.plugin_metadata.num_decode_tokens + + atom_config = get_current_atom_config() + positions = atom_config.compilation_config.static_forward_context["positions"][ + :num_actual_toks + ] + k_pe = k_pe.unsqueeze(1) + output_padded = output + output = output[:num_actual_toks, ...] + q = q[:num_actual_toks, ...] + k_c_normed = k_c_normed[:num_actual_toks, ...] + k_pe = k_pe[:num_actual_toks, ...] 
+ + decode_q = q[:num_decode_tokens] + prefill_q = q[num_decode_tokens:] + prefill_k_pe = k_pe[num_decode_tokens:] + prefill_k_c_normed = k_c_normed[num_decode_tokens:] + + decode_only = has_decode and not has_prefill + if not decode_only: + if self.rotary_emb is not None: + self.rotary_emb(positions, q[..., self.qk_nope_head_dim :], k_pe) + + # write the latent and rope to kv cache + if kv_cache.numel() > 0: + aiter.concat_and_cache_mla( + k_c_normed, + k_pe.squeeze(1), + kv_cache, + attn_metadata.plugin_metadata.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + scale=layer._k_scale, + ) + + if fp8_attention: + kv_cache = kv_cache.view(current_platform.fp8_dtype()) + + if has_prefill: + self._forward_prefill_plugin_mode( + prefill_q, + prefill_k_c_normed, + prefill_k_pe, + kv_cache, + attn_metadata, + layer._k_scale, + output=output[num_decode_tokens:], + ) + + if has_decode: + assert attn_metadata.plugin_metadata.decode is not None + + decode_q_nope, decode_q_pe = decode_q.split( + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + # Convert from (B, N, P) to (N, B, P) + decode_q_nope = decode_q_nope.transpose(0, 1) + + if self.q_pad_num_heads is not None: + B, N, L = decode_q_pe.shape + decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L)) + decode_pe_padded.resize_((B, N, L)) + decode_pe_padded.copy_(decode_q_pe) + decode_q_pe = decode_pe_padded + + if self.is_aiter_triton_fp4_bmm_enabled: + decode_ql_nope = batched_gemm_a16wfp4( + decode_q_nope, + self.W_K, + self.W_K_scale, + transpose_bm=True, + prequant=True, + y_scale=layer._q_scale if fp8_attention else None, + ) + # elif self.is_aiter_triton_fp8_bmm_enabled: + else: + # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) + decode_ql_nope = _aiter_triton_fp8_bmm( + decode_q_nope, + self.W_K, + self.W_K_scale, + group_size=128, + transpose_bm=True, + ) + + if decode_only: + decode_q = torch.empty( + ( + decode_ql_nope.shape[0], + self.num_heads, + 
self.kv_lora_rank + self.qk_rope_head_dim, + ), + dtype=( + dtypes.fp8 + if self.kv_cache_dtype.startswith("fp8") + else self.dtype + ), + device=decode_ql_nope.device, + ) + aiter.fused_qk_rope_concat_and_cache_mla( + decode_ql_nope, + decode_q_pe, + k_c_normed, + k_pe.squeeze(1), + kv_cache.view( + kv_cache.shape[0], -1, self.kv_lora_rank + self.qk_rope_head_dim + ), + decode_q, + attn_metadata.plugin_metadata.slot_mapping, + self._k_scale, + self._q_scale, + positions, + self.rotary_emb.cos_cache, + self.rotary_emb.sin_cache, + is_neox=self.rotary_emb.is_neox_style, + is_nope_first=True, + ) + else: + if fp8_attention: + ql_nope_shape = decode_ql_nope.shape + q_pe_shape = decode_q_pe.shape + assert decode_ql_nope.shape[0] == decode_q_pe.shape[0] + assert decode_ql_nope.shape[1] == decode_q_pe.shape[1] + decode_q_shape = ( + ql_nope_shape[0], + ql_nope_shape[1], + ql_nope_shape[2] + q_pe_shape[2], + ) + # Using empty and copy since torch.cat introduces significant overhead. + decode_q0 = torch.empty( + decode_q_shape, + device=decode_ql_nope.device, + dtype=decode_ql_nope.dtype, + ) + decode_q0[..., : ql_nope_shape[2]].copy_(decode_ql_nope) + decode_q0[..., ql_nope_shape[2] :].copy_(decode_q_pe) + + decode_q, _ = ops.scaled_fp8_quant( + decode_q0.view(decode_q_shape[0], -1), + layer._q_scale, + ) + decode_q = decode_q.view(decode_q_shape) + else: + decode_q = (decode_ql_nope, decode_q_pe) + decode_q = torch.cat(decode_q, dim=-1) + if self.dcp_world_size > 1: + assert not fp8_attention, "DCP not support fp8 kvcache now." + # decode_q do allgather in head dim. + decode_q = get_dcp_group().all_gather(decode_q, dim=1) + + # call decode attn + attn_out, lse = self._forward_decode_plugin_mode( + decode_q, kv_cache, attn_metadata, layer + ) + + # correct dcp attn_out with lse. 
+ if self.dcp_world_size > 1: + attn_out = cp_lse_ag_out_rs( + attn_out, + lse, + get_dcp_group(), + is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False), + ) + + # v_up projection + # self._v_up_proj(attn_out, out=output[:num_decode_tokens]) + # TODO: remove this copy. + out_up_proj = self._v_up_proj(attn_out) + output[:num_decode_tokens] = out_up_proj + + return output_padded + + +def MLAAttentionImplDecoratorForPluginMode(cls): + method_names = [ + "_concat_k_nope_k_pe_plugin_mode", + "_flash_attn_varlen_diff_headdims", + "_run_prefill_new_tokens_plugin_mode", + "_run_prefill_context_chunk_plugin_mode", + "_context_parallel_compute_prefill_context_plugin_mode", + "_compute_prefill_context_plugin_mode", + "_forward_prefill_plugin_mode", + "_forward_decode_plugin_mode", + "forward_impl_plugin_mode", + ] + + logger.info("Use MLAAttentionImplDecoratorForPluginMode to decorate MLAAttention") + + # Add all methods to the target class + for method_name in method_names: + method = getattr(MLAAttentionImplPluginModeMethods, method_name) + setattr(cls, method_name, method) + + return cls diff --git a/atom/plugin/vllm/model_wrapper.py b/atom/plugin/vllm/model_wrapper.py index bcca17ad1..7d6ae1391 100644 --- a/atom/plugin/vllm/model_wrapper.py +++ b/atom/plugin/vllm/model_wrapper.py @@ -29,6 +29,7 @@ _ATOM_MODEL_CLASSES: dict[str, str] = { "Qwen3ForCausalLM": "atom.models.qwen3:Qwen3ForCausalLM", "Qwen3MoeForCausalLM": "atom.models.qwen3_moe:Qwen3MoeForCausalLM", + "DeepseekV3ForCausalLM": "atom.models.deepseek_v2:DeepseekV3ForCausalLM", } @@ -108,6 +109,14 @@ def forward( input_ids = None inputs_embeds = intermediate_tensors["hidden_states"] + # capture. This ensures attention_mla reads correct positions in graph mode. + # This is only for mla attention in plugin mode. 
+ if "positions" in self.atom_config.compilation_config.static_forward_context: + buf = self.atom_config.compilation_config.static_forward_context[ + "positions" + ] + buf[: positions.numel()].copy_(positions) + hidden_states = self.model( input_ids=input_ids, positions=positions, diff --git a/atom/plugin/vllm/platform.py b/atom/plugin/vllm/platform.py index d003c4ae0..7e3903c26 100644 --- a/atom/plugin/vllm/platform.py +++ b/atom/plugin/vllm/platform.py @@ -9,7 +9,6 @@ import os logger = logging.getLogger("atom") - # This flag is used to enable the vLLM plugin mode. disable_vllm_plugin = os.getenv("ATOM_DISABLE_VLLM_PLUGIN", "0").lower() == "1" disable_vllm_plugin_attention = ( @@ -31,6 +30,8 @@ def get_attn_backend_cls(cls, selected_backend, attn_selector_config) -> str: ) logger.info("Use atom attention backend") + if attn_selector_config.use_mla: + return "atom.model_ops.attentions.aiter_mla.AiterMLABackend" return "atom.model_ops.attentions.aiter_attention.AiterBackend" else: diff --git a/atom/plugin/vllm/register.py b/atom/plugin/vllm/register.py index 0330929a5..38bd571b8 100644 --- a/atom/plugin/vllm/register.py +++ b/atom/plugin/vllm/register.py @@ -22,6 +22,7 @@ _VLLM_MODEL_REGISTRY_OVERRIDES: dict[str, str] = { "Qwen3ForCausalLM": ATOM_CAUSAL_LM_MODEL_WRAPPER, "Qwen3MoeForCausalLM": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER, + "DeepseekV3ForCausalLM": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER, } @@ -43,13 +44,8 @@ def register_platform() -> Optional[str]: return "atom.plugin.vllm.platform.ATOMPlatform" -def _patch_vllm_attention_process_weights_after_loading() -> None: - try: - from vllm.attention.layer import Attention - except ImportError: - from vllm.model_executor.layers.attention import Attention - - orig = Attention.process_weights_after_loading +def _patch_vllm_attention_process_weights_after_loading(attention) -> None: + orig = attention.process_weights_after_loading if getattr(orig, "_atom_default_act_dtype_patched", False): return @@ -74,7 +70,7 @@ def 
wrapped(self, act_dtype: "torch.dtype" = torch.bfloat16): return orig(self, act_dtype) setattr(wrapped, "_atom_default_act_dtype_patched", True) - Attention.process_weights_after_loading = wrapped + attention.process_weights_after_loading = wrapped def register_model() -> None: @@ -107,4 +103,10 @@ def register_model() -> None: # patch attention process weights after loading # to avoid the specific handle in ATOM loader - _patch_vllm_attention_process_weights_after_loading() + try: + from vllm.attention.layer import Attention, MLAAttention + except ImportError: + from vllm.model_executor.layers.attention import Attention, MLAAttention + + _patch_vllm_attention_process_weights_after_loading(Attention) + _patch_vllm_attention_process_weights_after_loading(MLAAttention) diff --git a/atom/utils/backends.py b/atom/utils/backends.py index 57eec887a..8058595cd 100644 --- a/atom/utils/backends.py +++ b/atom/utils/backends.py @@ -555,9 +555,9 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: hash_content = [] for filepath in forward_code_files: hash_content.append(filepath) - if filepath == "<string>": + if filepath == "<string>" or filepath == "<frozen os>": # This means the function was dynamically generated, with - # e.g. exec(). We can't actually check these. + # e.g. exec() or frozen os module. We can't actually check these.
continue with open(filepath) as f: hash_content.append(f.read()) From b9a0acf7533aa8bcfa08ae924d79ccedeaaa56b7 Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Wed, 4 Mar 2026 12:12:32 +0000 Subject: [PATCH 02/13] recover unrelated code --- atom/model_ops/attentions/aiter_attention.py | 1 + atom/model_ops/utils.py | 8 + atom/plugin/attention.py | 1201 +++++++++--------- 3 files changed, 601 insertions(+), 609 deletions(-) diff --git a/atom/model_ops/attentions/aiter_attention.py b/atom/model_ops/attentions/aiter_attention.py index f91c73e27..68a8567d5 100644 --- a/atom/model_ops/attentions/aiter_attention.py +++ b/atom/model_ops/attentions/aiter_attention.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +import itertools from typing import Type import aiter diff --git a/atom/model_ops/utils.py b/atom/model_ops/utils.py index 1f9c19649..c20154a97 100644 --- a/atom/model_ops/utils.py +++ b/atom/model_ops/utils.py @@ -135,6 +135,14 @@ def all_close_1d(x: torch.Tensor) -> bool: return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) +def per_tensor_dequantize( + tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor] +) -> torch.Tensor: + fake_qweight = tensor.to(torch.float16) + dq_weight = fake_qweight * inv_scale + return dq_weight + + def get_and_maybe_dequant_weights(layer: nn.Module) -> torch.Tensor: if layer.quant_type != QuantType.No: # NOTE: This should only be used offline, since it's O(N^3) diff --git a/atom/plugin/attention.py b/atom/plugin/attention.py index 861c7af06..d01d25078 100644 --- a/atom/plugin/attention.py +++ b/atom/plugin/attention.py @@ -18,9 +18,6 @@ _CP_TOKENS_PER_ITER_ROCM = 32 * 1024 -_MLA_ATTENTION_FOR_PLUGIN_MODE = True - - @dataclass class AiterFlashAttentionDecodeMetadata: max_query_len: int @@ -595,692 +592,678 @@ def decorator(cls): # for MLA attention metadata for plugin mode -if _MLA_ATTENTION_FOR_PLUGIN_MODE: +@dataclass +class 
AiterMLACommonDecodeMetadataForPluginMode: + block_table: torch.Tensor + seq_lens: torch.Tensor + dcp_tot_seq_lens: torch.Tensor | None + + +@dataclass +class AiterMLADecodeMetadataForPluginMode(AiterMLACommonDecodeMetadataForPluginMode): + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: torch.Tensor | None = None + # The page indices of the paged kv cache + paged_kv_indices: torch.Tensor | None = None + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_len: torch.Tensor | None = None + # The query indptr, shape : [num_decode + 1] + qo_indptr: torch.Tensor | None = None + # The dtype of MLA out tensor + attn_out_dtype: torch.dtype = torch.bfloat16 + # The max query output length: int + max_qo_len: int | None = None + + +@dataclass +class AiterMLACommonPrefillMetadataForPluginMode: + """Prefill Specific Metadata""" @dataclass - class AiterMLACommonDecodeMetadataForPluginMode: - block_table: torch.Tensor + class AiterMLAChunkedContextMetadataForPluginMode: + # New for MLA (compared to FlashAttention) + # For handling chunked prefill + cu_seq_lens: torch.Tensor + starts: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] seq_lens: torch.Tensor - dcp_tot_seq_lens: torch.Tensor | None + workspace: torch.Tensor + token_to_seq: torch.Tensor + chunk_total_token: list[int] - @dataclass - class AiterMLADecodeMetadataForPluginMode( - AiterMLACommonDecodeMetadataForPluginMode - ): - # The indptr of the paged kv cache, shape: [batch_size + 1] - paged_kv_indptr: torch.Tensor | None = None - # The page indices of the paged kv cache - paged_kv_indices: torch.Tensor | None = None - # The number of entries in the last page of each request in - # the paged kv cache, shape: [batch_size] - paged_kv_last_page_len: torch.Tensor | None = None - # The query indptr, shape : [num_decode + 1] - qo_indptr: torch.Tensor | None = None - # The dtype of MLA out tensor - 
attn_out_dtype: torch.dtype = torch.bfloat16 - # The max query output length: int - max_qo_len: int | None = None + # for mla DCP + padded_local_chunk_seq_lens: list[list[int]] | None = None + local_context_lens_allranks: list[list[int]] | None = None + padded_local_cu_seq_lens: torch.Tensor | None = None + cu_seq_lens_lst: list[list[int]] | None = None + chunk_size: int | None = None - @dataclass - class AiterMLACommonPrefillMetadataForPluginMode: - """Prefill Specific Metadata""" - - @dataclass - class AiterMLAChunkedContextMetadataForPluginMode: - # New for MLA (compared to FlashAttention) - # For handling chunked prefill - cu_seq_lens: torch.Tensor - starts: torch.Tensor - seq_tot: list[int] - max_seq_lens: list[int] - seq_lens: torch.Tensor - workspace: torch.Tensor - token_to_seq: torch.Tensor - chunk_total_token: list[int] - - # for mla DCP - padded_local_chunk_seq_lens: list[list[int]] | None = None - local_context_lens_allranks: list[list[int]] | None = None - padded_local_cu_seq_lens: torch.Tensor | None = None - cu_seq_lens_lst: list[list[int]] | None = None - chunk_size: int | None = None - - block_table: torch.Tensor - query_start_loc: torch.Tensor - max_query_len: int - chunked_context: AiterMLAChunkedContextMetadataForPluginMode | None = None - query_seq_lens: torch.Tensor | None = None - workspace_buffer: torch.Tensor | None = None - q_data_type: torch.dtype | None = None - - D = TypeVar("D", bound=AiterMLACommonDecodeMetadataForPluginMode) + block_table: torch.Tensor + query_start_loc: torch.Tensor + max_query_len: int + chunked_context: AiterMLAChunkedContextMetadataForPluginMode | None = None + query_seq_lens: torch.Tensor | None = None + workspace_buffer: torch.Tensor | None = None + q_data_type: torch.dtype | None = None - @dataclass - class AiterMLACommonMetadataForPluginMode(Generic[D]): - """Metadata for MLACommon. 
- NOTE: Please read the comment at the top of the file before trying to - understand this class - """ +D = TypeVar("D", bound=AiterMLACommonDecodeMetadataForPluginMode) - # NOTE(sang): Definition of context_len, query_len, and seq_len. - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - num_reqs: int - max_query_len: int - max_seq_len: int +@dataclass +class AiterMLACommonMetadataForPluginMode(Generic[D]): + """Metadata for MLACommon. - num_actual_tokens: int # Number of tokens excluding padding. - query_start_loc: torch.Tensor - slot_mapping: torch.Tensor + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ - # New for MLA (compared to FlashAttention) - # For handling prefill decode split - num_decodes: int - num_decode_tokens: int - num_prefills: int - - # The dimension of the attention heads - head_dim: int | None = None - - decode: D | None = None - prefill: AiterMLACommonPrefillMetadataForPluginMode | None = None - - def __post_init__(self): - pass - # if self.head_dim is not None and not MLACommonBackend.supports_head_size( - # self.head_dim - # ): - # raise ValueError(f"Head dimension {self.head_dim} is not supported by MLA.") - - class vllmMLAAttentionMetadataBuilderMethods: - def __init__(self): - raise TypeError( - f"{self.__class__.__name__} is a utility class and should not be instantiated. " - "Its methods are meant to be added to other classes via decorators." - ) + # NOTE(sang): Definition of context_len, query_len, and seq_len. 
+ # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| - def _build_decode( - self, - block_table_tensor: torch.Tensor, - seq_lens_device: torch.Tensor, - max_seq_len: int, - query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor, - num_decode_tokens: int, - dcp_tot_seq_lens_device: torch.Tensor | None, - ): - # kernel block size is always 1, although the kv block size is not 1. - device = self.device - num_reqs = seq_lens_device.size(0) - - mask = torch.arange( - block_table_tensor.size(1), - dtype=block_table_tensor.dtype, - device=device, - ).unsqueeze(0) < seq_lens_device.unsqueeze(1) - paged_kv_indices = block_table_tensor[mask] - - # kernel block size is always 1, so each page has exactly 1 token. - # last_page_len is always 1 - just slice the pre-initialized buffer. - paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs] - - paged_kv_indptr = torch.cat( - [ - torch.zeros(1, dtype=seq_lens_device.dtype, device=device), - seq_lens_device.cumsum(dim=0, dtype=torch.int32), - ] - ) - qo_len = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] - max_qo_len = qo_len.max().item() + num_reqs: int + max_query_len: int + max_seq_len: int - if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): - num_actual_pages = paged_kv_indices.size(0) + num_actual_tokens: int # Number of tokens excluding padding. 
+ query_start_loc: torch.Tensor + slot_mapping: torch.Tensor - self.paged_kv_indices[:num_actual_pages].copy_( - paged_kv_indices, non_blocking=True - ) - self.paged_kv_indices[num_actual_pages:].fill_(-1) - paged_kv_indices = self.paged_kv_indices[:num_actual_pages] + # New for MLA (compared to FlashAttention) + # For handling prefill decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int - self.paged_kv_indptr[: 1 + num_reqs].copy_( - paged_kv_indptr, non_blocking=True - ) - self.paged_kv_indptr[1 + num_reqs :].fill_(paged_kv_indptr[-1]) - paged_kv_indptr = self.paged_kv_indptr[: 1 + num_reqs] + # The dimension of the attention heads + head_dim: int | None = None - # paged_kv_last_page_len already uses the pre-initialized buffer slice - # (set above), so no copy needed - buffer is always 1s. + decode: D | None = None + prefill: AiterMLACommonPrefillMetadataForPluginMode | None = None - self.qo_indptr[: 1 + num_reqs].copy_( - query_start_loc_device, non_blocking=True - ) - self.qo_indptr[1 + num_reqs :] = query_start_loc_device[-1] - qo_indptr = self.qo_indptr[: 1 + num_reqs] + def __post_init__(self): + pass + # if self.head_dim is not None and not MLACommonBackend.supports_head_size( + # self.head_dim + # ): + # raise ValueError(f"Head dimension {self.head_dim} is not supported by MLA.") - else: - qo_indptr = torch.arange( - 0, num_reqs + 1, step=1, dtype=torch.int32, device=device - ) - attn_metadata = AiterMLADecodeMetadataForPluginMode( - block_table=block_table_tensor, - seq_lens=seq_lens_device, - paged_kv_indptr=paged_kv_indptr, - paged_kv_indices=paged_kv_indices, - paged_kv_last_page_len=paged_kv_last_page_len, - qo_indptr=qo_indptr, - dcp_tot_seq_lens=dcp_tot_seq_lens_device, - max_qo_len=max_qo_len, - attn_out_dtype=self.decode_attn_out_dtype, +class vllmMLAAttentionMetadataBuilderMethods: + def __init__(self): + raise TypeError( + f"{self.__class__.__name__} is a utility class and should not be instantiated. 
" + "Its methods are meant to be added to other classes via decorators." + ) + + def _build_decode( + self, + block_table_tensor: torch.Tensor, + seq_lens_device: torch.Tensor, + max_seq_len: int, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int, + dcp_tot_seq_lens_device: torch.Tensor | None, + ): + # kernel block size is always 1, although the kv block size is not 1. + device = self.device + num_reqs = seq_lens_device.size(0) + + mask = torch.arange( + block_table_tensor.size(1), + dtype=block_table_tensor.dtype, + device=device, + ).unsqueeze(0) < seq_lens_device.unsqueeze(1) + paged_kv_indices = block_table_tensor[mask] + + # kernel block size is always 1, so each page has exactly 1 token. + # last_page_len is always 1 - just slice the pre-initialized buffer. + paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs] + + paged_kv_indptr = torch.cat( + [ + torch.zeros(1, dtype=seq_lens_device.dtype, device=device), + seq_lens_device.cumsum(dim=0, dtype=torch.int32), + ] + ) + qo_len = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + max_qo_len = qo_len.max().item() + + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + num_actual_pages = paged_kv_indices.size(0) + + self.paged_kv_indices[:num_actual_pages].copy_( + paged_kv_indices, non_blocking=True + ) + self.paged_kv_indices[num_actual_pages:].fill_(-1) + paged_kv_indices = self.paged_kv_indices[:num_actual_pages] + + self.paged_kv_indptr[: 1 + num_reqs].copy_( + paged_kv_indptr, non_blocking=True ) + self.paged_kv_indptr[1 + num_reqs :].fill_(paged_kv_indptr[-1]) + paged_kv_indptr = self.paged_kv_indptr[: 1 + num_reqs] + + # paged_kv_last_page_len already uses the pre-initialized buffer slice + # (set above), so no copy needed - buffer is always 1s. 
+ + self.qo_indptr[: 1 + num_reqs].copy_( + query_start_loc_device, non_blocking=True + ) + self.qo_indptr[1 + num_reqs :] = query_start_loc_device[-1] + qo_indptr = self.qo_indptr[: 1 + num_reqs] + + else: + qo_indptr = torch.arange( + 0, num_reqs + 1, step=1, dtype=torch.int32, device=device + ) + + attn_metadata = AiterMLADecodeMetadataForPluginMode( + block_table=block_table_tensor, + seq_lens=seq_lens_device, + paged_kv_indptr=paged_kv_indptr, + paged_kv_indices=paged_kv_indices, + paged_kv_last_page_len=paged_kv_last_page_len, + qo_indptr=qo_indptr, + dcp_tot_seq_lens=dcp_tot_seq_lens_device, + max_qo_len=max_qo_len, + attn_out_dtype=self.decode_attn_out_dtype, + ) + + return attn_metadata - return attn_metadata + def build_for_cudagraph_capture( + self, + common_attn_metadata=None, + ): + return self.build(0, common_attn_metadata) + + def build( + self, + common_prefix_len: int = 0, + common_attn_metadata=None, + fast_build: bool = False, + ): - def build_for_cudagraph_capture( - self, - common_attn_metadata=None, - ): - m = common_attn_metadata - # assert m.num_reqs <= (m.num_actual_tokens * self.reorder_batch_threshold), ( - # "MLA only supports decode-only full CUDAGraph capture. " - # "Make sure all cudagraph capture sizes <= max_num_seq." 
- # ) + from vllm.v1.attention.backends.utils import split_decodes_and_prefills + from vllm.model_executor.layers.attention.mla_attention import ( + QueryLenSupport, + ) - # assert m.max_query_len <= self.reorder_batch_threshold # decode only + from vllm.utils.math_utils import cdiv, round_down + from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens - return self.build(0, m) + num_reqs = common_attn_metadata.num_reqs + num_tokens = common_attn_metadata.num_actual_tokens + max_query_len = common_attn_metadata.max_query_len + max_seq_len = common_attn_metadata.max_seq_len - def build( - self, - common_prefix_len: int = 0, - common_attn_metadata=None, - fast_build: bool = False, - ): + # Note(simon): be careful about the CPU <> GPU memory movement in this + # function. We should avoid GPU -> CPU sync as much as possible because + # it blocks on all previous kernels. + device = self.device + block_table_tensor = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping - from vllm.v1.attention.backends.utils import split_decodes_and_prefills - from vllm.model_executor.layers.attention.mla_attention import ( - QueryLenSupport, + query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + seq_lens = common_attn_metadata.seq_lens + dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens + + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold, + require_uniform=(self.query_len_support != QueryLenSupport.VARLEN), ) + ) - from vllm.utils.math_utils import cdiv, round_down - from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens - - num_reqs = common_attn_metadata.num_reqs - num_tokens = common_attn_metadata.num_actual_tokens - max_query_len = common_attn_metadata.max_query_len - max_seq_len = common_attn_metadata.max_seq_len - - # 
Note(simon): be careful about the CPU <> GPU memory movement in this - # function. We should avoid GPU -> CPU sync as much as possible because - # it blocks on all previous kernels. - device = self.device - block_table_tensor = common_attn_metadata.block_table_tensor - slot_mapping = common_attn_metadata.slot_mapping - - query_start_loc = common_attn_metadata.query_start_loc - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - seq_lens = common_attn_metadata.seq_lens - dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens - - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills( - common_attn_metadata, - decode_threshold=self.reorder_batch_threshold, - require_uniform=(self.query_len_support != QueryLenSupport.VARLEN), - ) + assert num_decodes + num_prefills == num_reqs + assert num_decode_tokens + num_prefill_tokens == num_tokens + + prefill_metadata = None + if num_prefills > 0: + num_computed_tokens_cpu = ( + common_attn_metadata.compute_num_computed_tokens().cpu() ) - assert num_decodes + num_prefills == num_reqs - assert num_decode_tokens + num_prefill_tokens == num_tokens + reqs_start = num_decodes # prefill_start + + context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] + max_context_len_cpu = context_lens_cpu.max().item() + num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() + prefill_query_start_loc = ( + query_start_loc[reqs_start:] - query_start_loc[reqs_start] + ) - prefill_metadata = None - if num_prefills > 0: - num_computed_tokens_cpu = ( - common_attn_metadata.compute_num_computed_tokens().cpu() + chunked_context_metadata = None + if max_context_len_cpu > 0: + # NOTE: it is recommend you read the `Chunked Prefill` section + # in the comment at the top of the file before trying to + # understand the following code + + # currently we allocate an equal amount of workspace for each + # prefill in the batch, we could probably use a more advanced + # algorithm here 
and allocate more workspace to prefills with + # longer context lengths + max_context_chunk = ( + self.chunked_prefill_workspace_size // num_prefills_with_context_cpu ) - reqs_start = num_decodes # prefill_start + if self.aot_schedule: + # align max_context_chunk to page_size by rounding down, + # currently the `gather_and_maybe_dequant_cache` kernel + # cannot handle `context_chunk_starts` that are not aligned + # to page_size + max_context_chunk = round_down(max_context_chunk, self.page_size) + + assert max_context_chunk > 0 + num_chunks = cdiv(max_context_len_cpu, max_context_chunk) + + # if `max_context_chunk = 256`, `num_chunks = 3`, and + # `num_prefills_with_context = 4`, create a tensor that looks + # like + # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] + # Note(simon): this is done in CPU because of downstream's + # of `to_list`. + chunk_starts = ( + torch.arange(num_chunks, dtype=torch.int32) + .unsqueeze(1) + .expand(-1, num_prefills) + * max_context_chunk + ) + chunk_ends = torch.min( + context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk + ) + chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) - context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] - max_context_len_cpu = context_lens_cpu.max().item() - num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() - prefill_query_start_loc = ( - query_start_loc[reqs_start:] - query_start_loc[reqs_start] + cu_seq_lens_cpu = torch.zeros( + num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True + ) + torch.cumsum( + chunk_seq_lens, + dim=1, + out=cu_seq_lens_cpu[:, 1:], + dtype=torch.int32, ) + chunk_total_token = cu_seq_lens_cpu[:, -1] - chunked_context_metadata = None - if max_context_len_cpu > 0: - # NOTE: it is recommend you read the `Chunked Prefill` section - # in the comment at the top of the file before trying to - # understand the following code - - # currently we allocate an equal amount of workspace for each - # prefill in the batch, we 
could probably use a more advanced - # algorithm here and allocate more workspace to prefills with - # longer context lengths - max_context_chunk = ( - self.chunked_prefill_workspace_size - // num_prefills_with_context_cpu + max_token_num_over_chunk = chunk_total_token.max().item() + token_to_seq_tensor_cpu = torch.zeros( + [num_chunks, max_token_num_over_chunk], dtype=torch.int32 + ) + range_idx = torch.arange(num_prefills, dtype=torch.int32) + for i in range(num_chunks): + chunk_token_to_seq_tensor = torch.repeat_interleave( + range_idx, chunk_seq_lens[i] ) + chunk_len = chunk_token_to_seq_tensor.shape[0] + token_to_seq_tensor_cpu[i, :chunk_len] = chunk_token_to_seq_tensor - if self.aot_schedule: - # align max_context_chunk to page_size by rounding down, - # currently the `gather_and_maybe_dequant_cache` kernel - # cannot handle `context_chunk_starts` that are not aligned - # to page_size - max_context_chunk = round_down( - max_context_chunk, self.page_size + if self.dcp_world_size > 1: + local_context_lens_allranks = get_dcp_local_seq_lens( + context_lens_cpu, + self.dcp_world_size, + None, + self.dcp_local_block_size, + ) + # Note(qcs): The max local context lengths + # padded to `dcp_local_block_size`. + padded_local_context_lens_cpu: torch.Tensor = ( + cdiv( + context_lens_cpu, + self.dcp_virtual_block_size, ) - - assert max_context_chunk > 0 - num_chunks = cdiv(max_context_len_cpu, max_context_chunk) - - # if `max_context_chunk = 256`, `num_chunks = 3`, and - # `num_prefills_with_context = 4`, create a tensor that looks - # like - # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] - # Note(simon): this is done in CPU because of downstream's - # of `to_list`. 
- chunk_starts = ( + * self.dcp_local_block_size + ) + # Note(hc): The above max_context_chunk already enforces + # block_size alignment, DCP just need the block_size can + # be divisible by dcp_world_size, because DCP use + # cp_gather_cache which not require `cp_chunk_starts` + # aligned to page_size. + assert max_context_chunk % self.dcp_world_size == 0 + padded_local_max_context_chunk_across_ranks = ( + cdiv( + max_context_chunk, + self.dcp_virtual_block_size, + ) + * self.dcp_local_block_size + ) + local_chunk_starts = ( torch.arange(num_chunks, dtype=torch.int32) .unsqueeze(1) .expand(-1, num_prefills) - * max_context_chunk + * padded_local_max_context_chunk_across_ranks ) - chunk_ends = torch.min( - context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk + local_chunk_ends = torch.min( + padded_local_context_lens_cpu.unsqueeze(0), + local_chunk_starts + + padded_local_max_context_chunk_across_ranks, ) - chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) + padded_local_chunk_seq_lens = ( + local_chunk_ends - local_chunk_starts + ).clamp(min=0) - cu_seq_lens_cpu = torch.zeros( - num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True + padded_local_cu_chunk_seq_lens_cpu = torch.zeros( + num_chunks, + num_prefills + 1, + dtype=torch.int32, + pin_memory=True, ) torch.cumsum( - chunk_seq_lens, + padded_local_chunk_seq_lens, dim=1, - out=cu_seq_lens_cpu[:, 1:], + out=padded_local_cu_chunk_seq_lens_cpu[:, 1:], dtype=torch.int32, ) - chunk_total_token = cu_seq_lens_cpu[:, -1] - max_token_num_over_chunk = chunk_total_token.max().item() - token_to_seq_tensor_cpu = torch.zeros( - [num_chunks, max_token_num_over_chunk], dtype=torch.int32 + chunked_context_metadata_cls = ( + AiterMLACommonPrefillMetadataForPluginMode.AiterMLAChunkedContextMetadataForPluginMode + ) + if self.dcp_world_size > 1: + chunked_context_metadata = chunked_context_metadata_cls( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), + 
starts=local_chunk_starts.to(device, non_blocking=True), + seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, + token_to_seq=token_to_seq_tensor_cpu.to( + device, non_blocking=True + ), + chunk_total_token=chunk_total_token.tolist(), + workspace=self.chunked_prefill_workspace, + padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(), + local_context_lens_allranks=local_context_lens_allranks.tolist(), + padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.to( + device, non_blocking=True + ), + cu_seq_lens_lst=cu_seq_lens_cpu.tolist(), + chunk_size=padded_local_max_context_chunk_across_ranks, ) - range_idx = torch.arange(num_prefills, dtype=torch.int32) - for i in range(num_chunks): - chunk_token_to_seq_tensor = torch.repeat_interleave( - range_idx, chunk_seq_lens[i] - ) - chunk_len = chunk_token_to_seq_tensor.shape[0] - token_to_seq_tensor_cpu[i, :chunk_len] = ( - chunk_token_to_seq_tensor - ) - - if self.dcp_world_size > 1: - local_context_lens_allranks = get_dcp_local_seq_lens( - context_lens_cpu, - self.dcp_world_size, - None, - self.dcp_local_block_size, - ) - # Note(qcs): The max local context lengths - # padded to `dcp_local_block_size`. - padded_local_context_lens_cpu: torch.Tensor = ( - cdiv( - context_lens_cpu, - self.dcp_virtual_block_size, - ) - * self.dcp_local_block_size - ) - # Note(hc): The above max_context_chunk already enforces - # block_size alignment, DCP just need the block_size can - # be divisible by dcp_world_size, because DCP use - # cp_gather_cache which not require `cp_chunk_starts` - # aligned to page_size. 
- assert max_context_chunk % self.dcp_world_size == 0 - padded_local_max_context_chunk_across_ranks = ( - cdiv( - max_context_chunk, - self.dcp_virtual_block_size, - ) - * self.dcp_local_block_size - ) - local_chunk_starts = ( - torch.arange(num_chunks, dtype=torch.int32) - .unsqueeze(1) - .expand(-1, num_prefills) - * padded_local_max_context_chunk_across_ranks - ) - local_chunk_ends = torch.min( - padded_local_context_lens_cpu.unsqueeze(0), - local_chunk_starts - + padded_local_max_context_chunk_across_ranks, - ) - padded_local_chunk_seq_lens = ( - local_chunk_ends - local_chunk_starts - ).clamp(min=0) - - padded_local_cu_chunk_seq_lens_cpu = torch.zeros( - num_chunks, - num_prefills + 1, - dtype=torch.int32, - pin_memory=True, - ) - torch.cumsum( - padded_local_chunk_seq_lens, - dim=1, - out=padded_local_cu_chunk_seq_lens_cpu[:, 1:], - dtype=torch.int32, - ) - - chunked_context_metadata_cls = ( - AiterMLACommonPrefillMetadataForPluginMode.AiterMLAChunkedContextMetadataForPluginMode + else: + chunked_context_metadata = chunked_context_metadata_cls( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), + starts=chunk_starts.to(device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, + token_to_seq=token_to_seq_tensor_cpu.to( + device, non_blocking=True + ), + chunk_total_token=chunk_total_token, + workspace=self.chunked_prefill_workspace, ) - if self.dcp_world_size > 1: - chunked_context_metadata = chunked_context_metadata_cls( - cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), - starts=local_chunk_starts.to(device, non_blocking=True), - seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(), - max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), - seq_lens=chunk_seq_lens, - token_to_seq=token_to_seq_tensor_cpu.to( - device, non_blocking=True - ), - chunk_total_token=chunk_total_token.tolist(), - workspace=self.chunked_prefill_workspace, - 
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(), - local_context_lens_allranks=local_context_lens_allranks.tolist(), - padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.to( - device, non_blocking=True - ), - cu_seq_lens_lst=cu_seq_lens_cpu.tolist(), - chunk_size=padded_local_max_context_chunk_across_ranks, - ) - else: - chunked_context_metadata = chunked_context_metadata_cls( - cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), - starts=chunk_starts.to(device, non_blocking=True), - seq_tot=chunk_seq_lens.sum(dim=1).tolist(), - max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), - seq_lens=chunk_seq_lens, - token_to_seq=token_to_seq_tensor_cpu.to( - device, non_blocking=True - ), - chunk_total_token=chunk_total_token, - workspace=self.chunked_prefill_workspace, - ) - if self._use_cudnn_prefill: - chunked_context_metadata.seq_lens = chunk_seq_lens + if self._use_cudnn_prefill: + chunked_context_metadata.seq_lens = chunk_seq_lens - assert ( - max(chunked_context_metadata.max_seq_lens) - <= self.chunked_prefill_workspace_size - ) - - prefill_metadata = AiterMLACommonPrefillMetadataForPluginMode( - block_table=block_table_tensor[reqs_start:, ...], - query_start_loc=prefill_query_start_loc, - max_query_len=max_query_len, - chunked_context=chunked_context_metadata, + assert ( + max(chunked_context_metadata.max_seq_lens) + <= self.chunked_prefill_workspace_size ) - decode_metadata = None - if num_decodes > 0: - dcp_tot_seq_lens_device = None - if self.dcp_world_size > 1: - dcp_tot_seq_lens_device = seq_lens[:num_decodes] - seq_lens = dcp_local_seq_lens - - # After DCP distribution, the maximum number of tokens for any rank is - # ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size, - # and I is cp_kv_cache_interleave_size. - # This eliminates GPU->CPU sync while minimizing workspace - # over-allocation. 
- num_partitions = ( - self.dcp_world_size * self.cp_kv_cache_interleave_size - ) - max_seq_len = ( - (max_seq_len + num_partitions - 1) // num_partitions - ) * self.cp_kv_cache_interleave_size - - decode_metadata = self._build_decode( - block_table_tensor=block_table_tensor[:num_decodes, ...], - seq_lens_device=seq_lens[:num_decodes], - max_seq_len=max_seq_len, - query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1], - query_start_loc_device=query_start_loc[: num_decodes + 1], - num_decode_tokens=num_decode_tokens, - dcp_tot_seq_lens_device=dcp_tot_seq_lens_device, - ) + prefill_metadata = AiterMLACommonPrefillMetadataForPluginMode( + block_table=block_table_tensor[reqs_start:, ...], + query_start_loc=prefill_query_start_loc, + max_query_len=max_query_len, + chunked_context=chunked_context_metadata, + ) - attn_metadata_for_plugin_mode = AiterMLACommonMetadataForPluginMode( - num_reqs=common_attn_metadata.num_reqs, - max_query_len=common_attn_metadata.max_query_len, + decode_metadata = None + if num_decodes > 0: + dcp_tot_seq_lens_device = None + if self.dcp_world_size > 1: + dcp_tot_seq_lens_device = seq_lens[:num_decodes] + seq_lens = dcp_local_seq_lens + + # After DCP distribution, the maximum number of tokens for any rank is + # ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size, + # and I is cp_kv_cache_interleave_size. + # This eliminates GPU->CPU sync while minimizing workspace + # over-allocation. 
+ num_partitions = self.dcp_world_size * self.cp_kv_cache_interleave_size + max_seq_len = ( + (max_seq_len + num_partitions - 1) // num_partitions + ) * self.cp_kv_cache_interleave_size + + decode_metadata = self._build_decode( + block_table_tensor=block_table_tensor[:num_decodes, ...], + seq_lens_device=seq_lens[:num_decodes], max_seq_len=max_seq_len, - num_actual_tokens=num_tokens, - query_start_loc=query_start_loc, - slot_mapping=slot_mapping, - head_dim=self.model_config.get_head_size(), - # MLACommonMetadata Chunk prefill specific - num_decodes=num_decodes, + query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1], + query_start_loc_device=query_start_loc[: num_decodes + 1], num_decode_tokens=num_decode_tokens, - num_prefills=num_prefills, - prefill=prefill_metadata, - decode=decode_metadata, + dcp_tot_seq_lens_device=dcp_tot_seq_lens_device, ) - attn_metadata = AttentionMetaData( - max_seqlen_q=common_attn_metadata.max_query_len, - block_tables=common_attn_metadata.block_table_tensor, - slot_mapping=common_attn_metadata.slot_mapping, - plugin_metadata=attn_metadata_for_plugin_mode, + attn_metadata_for_plugin_mode = AiterMLACommonMetadataForPluginMode( + num_reqs=common_attn_metadata.num_reqs, + max_query_len=common_attn_metadata.max_query_len, + max_seq_len=max_seq_len, + num_actual_tokens=num_tokens, + query_start_loc=query_start_loc, + slot_mapping=slot_mapping, + head_dim=self.model_config.get_head_size(), + # MLACommonMetadata Chunk prefill specific + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + prefill=prefill_metadata, + decode=decode_metadata, + ) + + attn_metadata = AttentionMetaData( + max_seqlen_q=common_attn_metadata.max_query_len, + block_tables=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping, + plugin_metadata=attn_metadata_for_plugin_mode, + ) + return attn_metadata + + +def create_mla_attn_metadata_builder_init_method(base_class): + """ + Create the 
init method for metadata builder + """ + + def init_method_under_plugin_mode( + self, + kv_cache_spec=None, + layer_names=None, + config=None, + device=None, + model_runner=None, + ): + base_class.__init__(self, kv_cache_spec, layer_names, config, device) + logger.info("init AiterAttentionMetadataBuilder for plugin mode") + from vllm.config import VllmConfig + + assert isinstance(config, VllmConfig) + + self.vllm_config = config + self.model_config = config.model_config + self.parallel_config = config.parallel_config + self.cache_config = config.cache_config + + self.compilation_config = self.vllm_config.compilation_config + self.decode_attn_out_dtype = self.vllm_config.model_config.dtype + # kernel block size is always 1. + max_num_pages_per_req = self.vllm_config.model_config.max_model_len + max_num_reqs = self.vllm_config.scheduler_config.max_num_seqs + max_num_pages = max_num_reqs * max_num_pages_per_req + + # Preparing persistent buffers + # TODO: we can disambiguate between decode and mixed-prefill decode here + # so we can only use the persistent buffer if a cudagraph is actually + # being used. + + # paged_kv_last_page_len is always 1s (kernel block size is always 1), + # so we create it once and reuse slices in both eager and cudagraph modes. 
+ self.paged_kv_last_page_len = torch.ones( + max_num_reqs, dtype=torch.int32, device=device + ) + + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + self.paged_kv_indptr = torch.zeros( + max_num_reqs + 1, dtype=torch.int32, device=device ) - return attn_metadata - - def create_mla_attn_metadata_builder_init_method(base_class): - """ - Create the init method for metadata builder - """ - - def init_method_under_plugin_mode( - self, - kv_cache_spec=None, - layer_names=None, - config=None, - device=None, - model_runner=None, - ): - base_class.__init__(self, kv_cache_spec, layer_names, config, device) - logger.info("init AiterAttentionMetadataBuilder for plugin mode") - from vllm.config import VllmConfig - - assert isinstance(config, VllmConfig) - - self.vllm_config = config - self.model_config = config.model_config - self.parallel_config = config.parallel_config - self.cache_config = config.cache_config - - self.compilation_config = self.vllm_config.compilation_config - self.decode_attn_out_dtype = self.vllm_config.model_config.dtype - # kernel block size is always 1. - max_num_pages_per_req = self.vllm_config.model_config.max_model_len - max_num_reqs = self.vllm_config.scheduler_config.max_num_seqs - max_num_pages = max_num_reqs * max_num_pages_per_req - - # Preparing persistent buffers - # TODO: we can disambiguate between decode and mixed-prefill decode here - # so we can only use the persistent buffer if a cudagraph is actually - # being used. - - # paged_kv_last_page_len is always 1s (kernel block size is always 1), - # so we create it once and reuse slices in both eager and cudagraph modes. 
- self.paged_kv_last_page_len = torch.ones( - max_num_reqs, dtype=torch.int32, device=device + self.paged_kv_indices = torch.zeros( + max_num_pages, dtype=torch.int32, device=device ) - if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): - self.paged_kv_indptr = torch.zeros( - max_num_reqs + 1, dtype=torch.int32, device=device - ) - self.paged_kv_indices = torch.zeros( - max_num_pages, dtype=torch.int32, device=device - ) + self.qo_indptr = torch.zeros( + max_num_reqs + 1, dtype=torch.int32, device=device + ) - self.qo_indptr = torch.zeros( - max_num_reqs + 1, dtype=torch.int32, device=device - ) + return init_method_under_plugin_mode - return init_method_under_plugin_mode - def setup_mla_attn_metadata_builder_base_class_and_attributes(class_dict: dict): - """ - Setup the base class and attributes for attention metadata builder - """ - from vllm.model_executor.layers.attention.mla_attention import ( - MLACommonMetadataBuilder, - QueryLenSupport, - ) - from vllm.v1.attention.backend import AttentionCGSupport +def setup_mla_attn_metadata_builder_base_class_and_attributes(class_dict: dict): + """ + Setup the base class and attributes for attention metadata builder + """ + from vllm.model_executor.layers.attention.mla_attention import ( + MLACommonMetadataBuilder, + QueryLenSupport, + ) + from vllm.v1.attention.backend import AttentionCGSupport - base_class = MLACommonMetadataBuilder - generic_base = MLACommonMetadataBuilder - needs_generic = True + base_class = MLACommonMetadataBuilder + generic_base = MLACommonMetadataBuilder + needs_generic = True - # align with vllm rocm aiter fa - class_dict["_cudagraph_support"] = ( - AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE - ) - class_dict["reorder_batch_threshold"] = 1 - class_dict["query_len_support"] = QueryLenSupport.UNIFORM - - return base_class, generic_base, needs_generic, class_dict - - def AiterMLAAttentionMetadataBuilderDecoratorForPluginMode(default_base_class): - def decorator(cls): - is_vllm_mode 
= is_vllm() - is_sglang_mode = is_sglang() - - base_class = default_base_class - class_dict = {} - - # record original decorated cls methods - for key, value in cls.__dict__.items(): - if not key.startswith("__") or key in ( - "__annotations__", - "__init__", - "__module__", - "__qualname__", - "__doc__", - ): - class_dict[key] = value - - # handle the generic base class - needs_generic = False - generic_base = None - - if is_vllm_mode: - # get the base class and generic base class - base_class, generic_base, needs_generic, class_dict = ( - setup_mla_attn_metadata_builder_base_class_and_attributes( - class_dict - ) - ) + # align with vllm rocm aiter fa + class_dict["_cudagraph_support"] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + class_dict["reorder_batch_threshold"] = 1 + class_dict["query_len_support"] = QueryLenSupport.UNIFORM - # replace the __init__ method in the decorated class - class_dict["__init__"] = create_mla_attn_metadata_builder_init_method( - base_class - ) + return base_class, generic_base, needs_generic, class_dict - # add the methods to the decorated class - for method_name in dir(vllmMLAAttentionMetadataBuilderMethods): - if not method_name.startswith("__"): - method = getattr( - vllmMLAAttentionMetadataBuilderMethods, method_name - ) - if callable(method): - class_dict[method_name] = method - elif is_sglang_mode: - raise NotImplementedError( - "AttentionMetadataBuilder for sglang is not implemented yet" - ) - # create the new class - new_class = type(cls.__name__, (base_class,), class_dict) +def AiterMLAAttentionMetadataBuilderDecoratorForPluginMode(default_base_class): + def decorator(cls): + is_vllm_mode = is_vllm() + is_sglang_mode = is_sglang() - # replace the inherit base class for plugin mode, meanwhile support generic base class - if needs_generic and generic_base is not None: - new_class.__orig_bases__ = (generic_base[new_class],) + base_class = default_base_class + class_dict = {} - return new_class + # record original decorated 
cls methods + for key, value in cls.__dict__.items(): + if not key.startswith("__") or key in ( + "__annotations__", + "__init__", + "__module__", + "__qualname__", + "__doc__", + ): + class_dict[key] = value + + # handle the generic base class + needs_generic = False + generic_base = None - return decorator + if is_vllm_mode: + # get the base class and generic base class + base_class, generic_base, needs_generic, class_dict = ( + setup_mla_attn_metadata_builder_base_class_and_attributes(class_dict) + ) - class vllmAiterMLABackendMethods: - # here attention in ATOM doesn't accept the output buffer because - # ATOM works as a model impl backend, it needs the maximum freedom - # to decide the output buffer and shape, thus here use this flag to - # let vllm don't allocate the output buffer for ATOM. ATOM will - # handle the output buffer by itself - accept_output_buffer: bool = True - supported_dtypes: list = [torch.float16, torch.bfloat16] + # replace the __init__ method in the decorated class + class_dict["__init__"] = create_mla_attn_metadata_builder_init_method( + base_class + ) - def __init__(self): - raise TypeError( - f"{self.__class__.__name__} is a utility class and should not be instantiated. " - "Its methods are meant to be added to other classes via decorators." 
+ # add the methods to the decorated class + for method_name in dir(vllmMLAAttentionMetadataBuilderMethods): + if not method_name.startswith("__"): + method = getattr( + vllmMLAAttentionMetadataBuilderMethods, method_name + ) + if callable(method): + class_dict[method_name] = method + elif is_sglang_mode: + raise NotImplementedError( + "AttentionMetadataBuilder for sglang is not implemented yet" ) - @staticmethod - def get_supported_kernel_block_sizes(): - return [1] - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - return (num_blocks, block_size, head_size) - - @classmethod - def is_mla(cls) -> bool: - return True - - @staticmethod - def get_required_kv_cache_layout(): - return None - - @classmethod - def get_supported_head_sizes(cls) -> list[int]: - return [576] - - @classmethod - def full_cls_name(cls) -> tuple[str, str]: - return (cls.__module__, cls.__qualname__) - - @classmethod - def supports_alibi_sqrt(cls) -> bool: - return False - - @staticmethod - def get_kv_cache_stride_order( - include_num_layers_dimension: bool = False, - ) -> tuple[int, ...]: - # `stride_order` indicates the permutation that gets - # us from `get_kv_cache_shape` to the actual memory layout we want. 
- # (num_blocks, num_layers, block_size, head_size) - return (1, 0, 2, 3) if include_num_layers_dimension else (0, 1, 2) + # create the new class + new_class = type(cls.__name__, (base_class,), class_dict) + + # replace the inherit base class for plugin mode, meanwhile support generic base class + if needs_generic and generic_base is not None: + new_class.__orig_bases__ = (generic_base[new_class],) + + return new_class + + return decorator + + +class vllmAiterMLABackendMethods: + # here attention in ATOM doesn't accept the output buffer because + # ATOM works as a model impl backend, it needs the maximum freedom + # to decide the output buffer and shape, thus here use this flag to + # let vllm don't allocate the output buffer for ATOM. ATOM will + # handle the output buffer by itself + accept_output_buffer: bool = True + supported_dtypes: list = [torch.float16, torch.bfloat16] + + def __init__(self): + raise TypeError( + f"{self.__class__.__name__} is a utility class and should not be instantiated. " + "Its methods are meant to be added to other classes via decorators." 
+ ) + + @staticmethod + def get_supported_kernel_block_sizes(): + return [1] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + return (num_blocks, block_size, head_size) + + @classmethod + def is_mla(cls) -> bool: + return True + + @staticmethod + def get_required_kv_cache_layout(): + return None + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [576] + + @classmethod + def full_cls_name(cls) -> tuple[str, str]: + return (cls.__module__, cls.__qualname__) + + @classmethod + def supports_alibi_sqrt(cls) -> bool: + return False + + @staticmethod + def get_kv_cache_stride_order( + include_num_layers_dimension: bool = False, + ) -> tuple[int, ...]: + # `stride_order` indicates the permutation that gets + # us from `get_kv_cache_shape` to the actual memory layout we want. + # (num_blocks, num_layers, block_size, head_size) + return (1, 0, 2, 3) if include_num_layers_dimension else (0, 1, 2) def AiterBackendDecoratorForPluginMode(cls): From 07fd49782216c96505af0933e930b8eee6dc820c Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Wed, 4 Mar 2026 13:00:06 +0000 Subject: [PATCH 03/13] simplify attention.py code Signed-off-by: XiaobingSuper --- atom/plugin/attention.py | 60 ++++++---------------------------------- 1 file changed, 9 insertions(+), 51 deletions(-) diff --git a/atom/plugin/attention.py b/atom/plugin/attention.py index d01d25078..aadf739ed 100644 --- a/atom/plugin/attention.py +++ b/atom/plugin/attention.py @@ -19,19 +19,15 @@ @dataclass -class AiterFlashAttentionDecodeMetadata: +class AiterFlashAttentionPhaseMetadata: max_query_len: int min_query_len: int max_seq_len: int query_start_loc: torch.Tensor -@dataclass -class AiterFlashAttentionPrefillMetadata: - max_query_len: int - min_query_len: int - max_seq_len: int - query_start_loc: torch.Tensor +AiterFlashAttentionDecodeMetadata = 
AiterFlashAttentionPhaseMetadata +AiterFlashAttentionPrefillMetadata = AiterFlashAttentionPhaseMetadata @dataclass @@ -1271,53 +1267,15 @@ def AiterBackendDecoratorForPluginMode(cls): Decorator for AiterBackend to add specific methods and attributes for plugin mode """ is_vllm_mode = is_vllm() - if is_vllm_mode: - # for vllm, add the required methods if not issubclass(cls.get_impl_cls(), MLAAttention): - - cls.full_cls_name = vllmAiterAttentionBackendMethods.full_cls_name - cls.accept_output_buffer = ( - vllmAiterAttentionBackendMethods.accept_output_buffer - ) - cls.supported_dtypes = vllmAiterAttentionBackendMethods.supported_dtypes - cls.forward_includes_kv_cache_update = ( - vllmAiterAttentionBackendMethods.forward_includes_kv_cache_update - ) - cls.get_supported_kernel_block_sizes = ( - vllmAiterAttentionBackendMethods.get_supported_kernel_block_sizes - ) - cls.get_kv_cache_shape = vllmAiterAttentionBackendMethods.get_kv_cache_shape - cls.is_mla = vllmAiterAttentionBackendMethods.is_mla - cls.get_required_kv_cache_layout = ( - vllmAiterAttentionBackendMethods.get_required_kv_cache_layout - ) - cls.get_supported_head_sizes = ( - vllmAiterAttentionBackendMethods.get_supported_head_sizes - ) - cls.supports_alibi_sqrt = ( - vllmAiterAttentionBackendMethods.supports_alibi_sqrt - ) + methods_cls = vllmAiterAttentionBackendMethods else: - # for mla, add the required methods - cls.full_cls_name = vllmAiterMLABackendMethods.full_cls_name - cls.accept_output_buffer = vllmAiterMLABackendMethods.accept_output_buffer - cls.supported_dtypes = vllmAiterMLABackendMethods.supported_dtypes - cls.get_supported_kernel_block_sizes = ( - vllmAiterMLABackendMethods.get_supported_kernel_block_sizes - ) - cls.get_kv_cache_shape = vllmAiterMLABackendMethods.get_kv_cache_shape - cls.is_mla = vllmAiterMLABackendMethods.is_mla - cls.get_required_kv_cache_layout = ( - vllmAiterMLABackendMethods.get_required_kv_cache_layout - ) - cls.get_supported_head_sizes = ( - 
vllmAiterMLABackendMethods.get_supported_head_sizes - ) - cls.supports_alibi_sqrt = vllmAiterMLABackendMethods.supports_alibi_sqrt - cls.get_kv_cache_stride_order = ( - vllmAiterMLABackendMethods.get_kv_cache_stride_order - ) + methods_cls = vllmAiterMLABackendMethods + for name in dir(methods_cls): + if name.startswith("_"): + continue + setattr(cls, name, getattr(methods_cls, name)) return cls From d45cd8ad6276928b08d44b67150e675efb9109cb Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Thu, 5 Mar 2026 03:48:02 +0000 Subject: [PATCH 04/13] update postions init --- atom/model_ops/paged_attention.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/atom/model_ops/paged_attention.py b/atom/model_ops/paged_attention.py index b9c742d8d..a10f8ab4d 100644 --- a/atom/model_ops/paged_attention.py +++ b/atom/model_ops/paged_attention.py @@ -161,12 +161,13 @@ def new_forward(*args, **kwargs): compilation_config.static_forward_context[self.layer_name] = self if self.use_mla: - max_num_tokens = ( - atom_config.plugin_config.vllm_scheduler_config.max_num_batched_tokens - ) - compilation_config.static_forward_context["positions"] = torch.zeros( - max_num_tokens, dtype=torch.int64, device="cuda" - ) + if "positions" not in compilation_config.static_forward_context: + max_num_tokens = ( + atom_config.plugin_config.vllm_scheduler_config.max_num_batched_tokens + ) + compilation_config.static_forward_context["positions"] = torch.zeros( + max_num_tokens, dtype=torch.int64, device="cuda" + ) return self.num_heads = num_heads From 5e633ed4ae93f888e1d7df76bb9577c93c033b1a Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Thu, 5 Mar 2026 05:49:25 +0000 Subject: [PATCH 05/13] cleare code v1 --- atom/model_ops/paged_attention.py | 11 +++-------- atom/plugin/attention_mla.py | 9 ++++----- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/atom/model_ops/paged_attention.py b/atom/model_ops/paged_attention.py index a10f8ab4d..063661802 100644 --- 
a/atom/model_ops/paged_attention.py +++ b/atom/model_ops/paged_attention.py @@ -63,7 +63,7 @@ def __init__( self.use_mla = use_mla # for plugin mode if is_vllm(): - self.rotary_emb = mla_modules.rotary_emb + self.rotary_emb = mla_modules.rotary_emb if use_mla else rotary_emb try: from vllm.attention.layer import Attention, MLAAttention, AttentionType @@ -125,11 +125,6 @@ def wrap_kv_b_proj(module_instance): def new_forward(*args, **kwargs): out = orig_impl(*args, **kwargs) - if ( - os.getenv("ATOM_DISABLE_VLLM_PLUGIN_ATTENTION", "0").lower() - == "0" - ): - return out return out, None module_instance.forward = new_forward @@ -165,8 +160,8 @@ def new_forward(*args, **kwargs): max_num_tokens = ( atom_config.plugin_config.vllm_scheduler_config.max_num_batched_tokens ) - compilation_config.static_forward_context["positions"] = torch.zeros( - max_num_tokens, dtype=torch.int64, device="cuda" + compilation_config.static_forward_context["positions"] = ( + torch.zeros(max_num_tokens, dtype=torch.int64, device="cuda") ) return diff --git a/atom/plugin/attention_mla.py b/atom/plugin/attention_mla.py index 093579fca..1b290c4e5 100644 --- a/atom/plugin/attention_mla.py +++ b/atom/plugin/attention_mla.py @@ -183,7 +183,6 @@ def _run_prefill_new_tokens_plugin_mode(self, prefill, q, k, v, return_softmax_l ) def _run_prefill_context_chunk_plugin_mode(self, prefill, chunk_idx, q, k, v): - assert prefill.chunked_context is not None assert prefill.chunked_context is not None return self._flash_attn_varlen_diff_headdims( q=q, @@ -273,7 +272,7 @@ def _context_parallel_compute_prefill_context_plugin_mode( toks=toks, ) - kv_nope = self.kv_b_proj(kv_c_normed).view( + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim ) k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) @@ -340,7 +339,7 @@ def _compute_prefill_context_plugin_mode( kv_c_normed = workspace[:toks][..., : self.kv_lora_rank] k_pe = 
workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1) - kv_nope = self.kv_b_proj(kv_c_normed).view( + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim ) k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) @@ -458,7 +457,7 @@ def _forward_prefill_plugin_mode( output_dtype, ) else: - kv_nope = self.kv_b_proj(kv_c_normed).view( + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim ) k_nope, v = kv_nope.split( @@ -467,7 +466,7 @@ def _forward_prefill_plugin_mode( k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) else: - kv_nope = self.kv_b_proj(kv_c_normed).view( + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim ) k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) From 007c829ef6e4e7cf08d1f0d066dccc0307767bda Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Thu, 5 Mar 2026 06:35:49 +0000 Subject: [PATCH 06/13] update scale use --- atom/plugin/attention_mla.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/atom/plugin/attention_mla.py b/atom/plugin/attention_mla.py index 1b290c4e5..94a2c8d66 100644 --- a/atom/plugin/attention_mla.py +++ b/atom/plugin/attention_mla.py @@ -746,8 +746,8 @@ def forward_impl_plugin_mode( ), decode_q, attn_metadata.plugin_metadata.slot_mapping, - self._k_scale, - self._q_scale, + layer._k_scale, + layer._q_scale, positions, self.rotary_emb.cos_cache, self.rotary_emb.sin_cache, From 4d9971f7ae9832b31e160fbd93437e46d66162b4 Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Thu, 5 Mar 2026 06:58:03 +0000 Subject: [PATCH 07/13] fix typo --- atom/plugin/attention.py | 5 ----- atom/plugin/attention_mla.py | 4 ++-- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/atom/plugin/attention.py b/atom/plugin/attention.py index aadf739ed..cb46e6b5a 100644 --- 
a/atom/plugin/attention.py +++ b/atom/plugin/attention.py @@ -1204,11 +1204,6 @@ def decorator(cls): class vllmAiterMLABackendMethods: - # here attention in ATOM doesn't accept the output buffer because - # ATOM works as a model impl backend, it needs the maximum freedom - # to decide the output buffer and shape, thus here use this flag to - # let vllm don't allocate the output buffer for ATOM. ATOM will - # handle the output buffer by itself accept_output_buffer: bool = True supported_dtypes: list = [torch.float16, torch.bfloat16] diff --git a/atom/plugin/attention_mla.py b/atom/plugin/attention_mla.py index 94a2c8d66..8a839f61e 100644 --- a/atom/plugin/attention_mla.py +++ b/atom/plugin/attention_mla.py @@ -276,9 +276,9 @@ def _context_parallel_compute_prefill_context_plugin_mode( -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim ) k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k = self._concat_k_nope_k_pe(k_nope, k_pe) + k = self._concat_k_nope_k_pe_plugin_mode(k_nope, k_pe) - attn_output, attn_softmax_lse = self._run_prefill_context_chunk( + attn_output, attn_softmax_lse = self._run_prefill_context_chunk_plugin_mode( prefill=prefill_metadata, chunk_idx=i, q=q, From 45f7c4ea82bbcb018228aa9d5f43d2cf92e432df Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Thu, 5 Mar 2026 07:52:58 +0000 Subject: [PATCH 08/13] fix ruff issue --- atom/model_ops/paged_attention.py | 1 - 1 file changed, 1 deletion(-) diff --git a/atom/model_ops/paged_attention.py b/atom/model_ops/paged_attention.py index 063661802..0ca91d034 100644 --- a/atom/model_ops/paged_attention.py +++ b/atom/model_ops/paged_attention.py @@ -3,7 +3,6 @@ # from flash_attn import flash_attn_with_kvcache from typing import Optional -import os import torch from torch import nn From 5c460b9c14a29e390752a9638f2df39d3d453d87 Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Fri, 6 Mar 2026 08:05:00 +0000 Subject: [PATCH 09/13] update base_attention --- 
atom/model_ops/base_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/atom/model_ops/base_attention.py b/atom/model_ops/base_attention.py index 116484a9d..65996d09d 100644 --- a/atom/model_ops/base_attention.py +++ b/atom/model_ops/base_attention.py @@ -191,10 +191,10 @@ def fake_( qkv: torch.Tensor, ) -> torch.Tensor: output_shape = list(q.shape) - if use_mla: - output_shape[-1] = 7168 # If we fusion rmsnorm and quant, the input dtype is fp8, but actually we use bf16 for output. atom_config = get_current_atom_config() + if use_mla: + output_shape[-1] = atom_config.hf_config.hidden_size output_dtype = atom_config.torch_dtype output = torch.zeros(output_shape, dtype=output_dtype, device=q.device) From f7341d23e4b195f8855c9c477de4731ccf1c4bf9 Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Fri, 6 Mar 2026 10:23:03 +0000 Subject: [PATCH 10/13] clear mla init --- atom/model_ops/attention_mla.py | 29 +--------------------- atom/models/deepseek_v2.py | 11 +-------- atom/plugin/attention_mla.py | 41 ++++++++++++++++++++++++++++++- atom/plugin/vllm/model_wrapper.py | 6 ++++- 4 files changed, 47 insertions(+), 40 deletions(-) diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index 03616ab54..e451947fd 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -36,7 +36,7 @@ ) -from atom.plugin import is_plugin_mode, is_vllm +from atom.plugin import is_plugin_mode from atom.plugin.attention_mla import MLAAttentionImplDecoratorForPluginMode @@ -140,33 +140,6 @@ def __init__( ) self.layer_num = layer_num - # for plugin mode(vllm) - if is_vllm(): - self.supports_quant_query_input = False - self.dcp_world_size: int = -1 - from vllm.config import get_current_vllm_config - from vllm.model_executor.layers.attention.mla_attention import ( - MLACommonMetadataBuilder, - ) - - self.chunked_prefill_workspace_size = ( - MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size( - 
get_current_vllm_config() - ) - ) - self.cp_kv_cache_interleave_size: int = ( - get_current_vllm_config().parallel_config.cp_kv_cache_interleave_size - ) - - self.is_aiter_triton_fp4_bmm_enabled = ( - is_rocm_aiter_fp4bmm_enabled() - and self.kv_b_proj.weight.dtype == torch.bfloat16 - ) - # q_pad_num_heads in kwargs - self.q_pad_num_heads = kwargs.get("q_pad_num_heads", None) - self._pad_v = True - self.flash_attn_varlen_func = flash_attn_varlen_func - def process_weights_after_loading(self, act_dtype: Optional[torch.dtype] = None): if is_rocm_aiter_fp4bmm_enabled(): kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj) diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index 93414d383..effe3a6b8 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -24,7 +24,7 @@ """Inference-only DeepseekV2/DeepseekV3 model.""" import logging -from typing import Optional, Tuple, Union, Iterable +from typing import Optional, Tuple, Union import torch from aiter import ( @@ -1900,15 +1900,6 @@ def make_empty_intermediate_tensors( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - # load weights in plugin mode and discard passed weights generator - from atom.model_loader.loader import load_model_in_plugin_mode - - loaded_weights_record = load_model_in_plugin_mode( - model=self, config=self.atom_config, prefix="model." 
- ) - return loaded_weights_record - class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass diff --git a/atom/plugin/attention_mla.py b/atom/plugin/attention_mla.py index 8a839f61e..2367f3d44 100644 --- a/atom/plugin/attention_mla.py +++ b/atom/plugin/attention_mla.py @@ -19,6 +19,9 @@ from functools import partial as functools_partial from atom.config import get_current_atom_config from atom.model_ops.linear import use_triton_gemm +from atom.plugin.prepare import is_vllm +from atom.utils import envs + import logging @@ -533,7 +536,7 @@ def _forward_decode_plugin_mode( assert isinstance(q, torch.Tensor) B = q.shape[0] - o = torch.zeros( + o = torch.empty( B, self.num_heads, self.kv_lora_rank, @@ -810,6 +813,33 @@ def forward_impl_plugin_mode( return output_padded +def _mla_plugin_mode_init(self, *args, **kwargs): + """Extra initialization for MLAAttentionImpl in plugin mode (vllm).""" + if is_vllm(): + from vllm.config import get_current_vllm_config + from vllm.model_executor.layers.attention.mla_attention import ( + MLACommonMetadataBuilder, + ) + + self.supports_quant_query_input = False + self.dcp_world_size: int = -1 + self.chunked_prefill_workspace_size = ( + MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size( + get_current_vllm_config() + ) + ) + self.cp_kv_cache_interleave_size: int = ( + get_current_vllm_config().parallel_config.cp_kv_cache_interleave_size + ) + self.is_aiter_triton_fp4_bmm_enabled = ( + envs.ATOM_USE_TRITON_MXFP4_BMM + and self.kv_b_proj.weight.dtype == torch.bfloat16 + ) + self.q_pad_num_heads = kwargs.get("q_pad_num_heads", None) + self._pad_v = True + self.flash_attn_varlen_func = aiter.flash_attn_varlen_func + + def MLAAttentionImplDecoratorForPluginMode(cls): method_names = [ "_concat_k_nope_k_pe_plugin_mode", @@ -830,4 +860,13 @@ def MLAAttentionImplDecoratorForPluginMode(cls): method = getattr(MLAAttentionImplPluginModeMethods, method_name) setattr(cls, method_name, method) + # Wrap __init__ to inject plugin-mode 
initialization + orig_init = cls.__init__ + + def new_init(self, *args, **kwargs): + orig_init(self, *args, **kwargs) + _mla_plugin_mode_init(self, *args, **kwargs) + + cls.__init__ = new_init + return cls diff --git a/atom/plugin/vllm/model_wrapper.py b/atom/plugin/vllm/model_wrapper.py index 7d6ae1391..d2a2f8e4e 100644 --- a/atom/plugin/vllm/model_wrapper.py +++ b/atom/plugin/vllm/model_wrapper.py @@ -20,6 +20,7 @@ import atom # noqa: F401 from atom.plugin.config import generate_atom_config_for_plugin_mode +from atom.model_loader.loader import load_model_in_plugin_mode import logging @@ -134,7 +135,10 @@ def load_weights( self, weights: Iterable[tuple[str, torch.Tensor]], ) -> set[str]: - return self.model.load_weights(weights) + loaded_weights_record = load_model_in_plugin_mode( + model=self.model, config=self.model.atom_config, prefix="model." + ) + return loaded_weights_record def compute_logits( self, From 5d696d719c1c8dfc1644de912b61f136424843aa Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Fri, 6 Mar 2026 10:38:52 +0000 Subject: [PATCH 11/13] clear code --- atom/models/deepseek_v2.py | 1 - atom/plugin/vllm/model_wrapper.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index effe3a6b8..f0342dce1 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -1823,7 +1823,6 @@ def __init__( layer_type: type[nn.Module] = DeepseekV2DecoderLayer, ): super().__init__() - self.atom_config = atom_config config = atom_config.hf_config quant_config = atom_config.quant_config self.config = config diff --git a/atom/plugin/vllm/model_wrapper.py b/atom/plugin/vllm/model_wrapper.py index d2a2f8e4e..622cd0625 100644 --- a/atom/plugin/vllm/model_wrapper.py +++ b/atom/plugin/vllm/model_wrapper.py @@ -136,7 +136,7 @@ def load_weights( weights: Iterable[tuple[str, torch.Tensor]], ) -> set[str]: loaded_weights_record = load_model_in_plugin_mode( - model=self.model, 
config=self.model.atom_config, prefix="model." + model=self.model, config=self.atom_config, prefix="model." ) return loaded_weights_record From 9e8486af9f51ac8c72513e4080f8061bcb2b08c8 Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Fri, 6 Mar 2026 15:42:20 +0000 Subject: [PATCH 12/13] avoid copy for quant_func --- atom/model_ops/linear.py | 3 +-- atom/plugin/attention_mla.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index 53a8897d1..a3d7b4ef7 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -385,9 +385,8 @@ def forward( if self.quant_type.value == QuantType.per_1x128.value: quant_func = functools_partial(quant_func, transpose_scale=True) if self.quant_type.value != QuantType.per_1x32.value: - # quant_func will call view, so we need to call contiguous to avoid view error x, x_scale = quant_func( - x.contiguous(), + x, quant_dtype=self.params_dtype, scale=getattr(self, "input_scale", None), ) diff --git a/atom/plugin/attention_mla.py b/atom/plugin/attention_mla.py index 2367f3d44..c481a49bb 100644 --- a/atom/plugin/attention_mla.py +++ b/atom/plugin/attention_mla.py @@ -342,7 +342,7 @@ def _compute_prefill_context_plugin_mode( kv_c_normed = workspace[:toks][..., : self.kv_lora_rank] k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1) - kv_nope = self.kv_b_proj(kv_c_normed)[0].view( + kv_nope = self.kv_b_proj(kv_c_normed.contiguous())[0].view( -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim ) k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) From f57e9ff6f243f1d8fa220e2d6899f9a17a435938 Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Fri, 6 Mar 2026 16:40:58 +0000 Subject: [PATCH 13/13] simplify code --- atom/model_ops/paged_attention.py | 13 ------------- atom/plugin/attention_mla.py | 8 ++++---- 2 files changed, 4 insertions(+), 17 deletions(-) diff --git a/atom/model_ops/paged_attention.py
b/atom/model_ops/paged_attention.py index 0ca91d034..17cbf4fbd 100644 --- a/atom/model_ops/paged_attention.py +++ b/atom/model_ops/paged_attention.py @@ -118,19 +118,6 @@ def __init__( indexer=mla_modules.indexer, **extra_impl_args, ) - - def wrap_kv_b_proj(module_instance): - orig_impl = module_instance.forward - - def new_forward(*args, **kwargs): - out = orig_impl(*args, **kwargs) - return out, None - - module_instance.forward = new_forward - return module_instance - - # vllm kv_b_proj return two values (output, bias), so we need to wrap it for fallback path. - self.attn.impl.kv_b_proj = wrap_kv_b_proj(self.attn.impl.kv_b_proj) else: self.attn = Attention( num_heads=num_heads, diff --git a/atom/plugin/attention_mla.py b/atom/plugin/attention_mla.py index c481a49bb..30f61ae8c 100644 --- a/atom/plugin/attention_mla.py +++ b/atom/plugin/attention_mla.py @@ -275,7 +275,7 @@ def _context_parallel_compute_prefill_context_plugin_mode( toks=toks, ) - kv_nope = self.kv_b_proj(kv_c_normed)[0].view( + kv_nope = self.kv_b_proj(kv_c_normed).view( -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim ) k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) @@ -342,7 +342,7 @@ def _compute_prefill_context_plugin_mode( kv_c_normed = workspace[:toks][..., : self.kv_lora_rank] k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1) - kv_nope = self.kv_b_proj(kv_c_normed.contiguous())[0].view( + kv_nope = self.kv_b_proj(kv_c_normed.contiguous()).view( -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim ) k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) @@ -460,7 +460,7 @@ def _forward_prefill_plugin_mode( output_dtype, ) else: - kv_nope = self.kv_b_proj(kv_c_normed)[0].view( + kv_nope = self.kv_b_proj(kv_c_normed).view( -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim ) k_nope, v = kv_nope.split( @@ -469,7 +469,7 @@ def _forward_prefill_plugin_mode( k = torch.cat((k_nope, 
k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) else: - kv_nope = self.kv_b_proj(kv_c_normed)[0].view( + kv_nope = self.kv_b_proj(kv_c_normed).view( -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim ) k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)