diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index dccb9d62..e451947f 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -35,6 +35,11 @@ batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as _aiter_triton_fp8_bmm, ) + +from atom.plugin import is_plugin_mode + +from atom.plugin.attention_mla import MLAAttentionImplDecoratorForPluginMode + # torch.set_printoptions(threshold=10_000) logger = logging.getLogger("atom") @@ -92,11 +97,12 @@ def dynamic_per_batched_tensor_quant( return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() +@MLAAttentionImplDecoratorForPluginMode class MLAAttention(nn.Module): def __init__( self, num_heads: int, - head_dim: int, + head_size: int, scale: float, num_kv_heads: int, kv_cache_dtype: str, @@ -107,7 +113,7 @@ def __init__( ) -> None: super().__init__() self.num_heads = num_heads - self.head_dim = head_dim + self.head_dim = head_size self.scale = float(scale) self.num_kv_heads = num_kv_heads self.kv_cache_dtype = kv_cache_dtype if kv_cache_dtype == "fp8" else "auto" @@ -134,7 +140,7 @@ def __init__( ) self.layer_num = layer_num - def process_weights_after_loading(self): + def process_weights_after_loading(self, act_dtype: Optional[torch.dtype] = None): if is_rocm_aiter_fp4bmm_enabled(): kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj) self.W_K, self.W_K_scale, W_V, self.W_V_scale = quark_post_load_weights( @@ -146,7 +152,7 @@ def process_weights_after_loading(self): self.W_K_scale = self.W_K_scale.transpose(-2, -1).contiguous() self.W_V = self.W_V.transpose(-2, -1).contiguous() self.W_V_scale = self.W_V_scale.transpose(-2, -1).contiguous() - else: # is_rocm_aiter_fp8bmm_enabled(): + else: # is_rocm_aiter_fp8bmm_enabled() kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T assert kv_b_proj_weight.shape == ( self.kv_lora_rank, @@ -175,7 +181,7 @@ def process_weights_after_loading(self): W_V, dtype=dtypes.fp8 
) - def _v_up_proj_and_o_proj(self, x): + def _v_up_proj(self, x): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) # Multiply (N, B, L) x (N, L, V) -> (N, B, V), Convert from (N, B, V) to (B, N, V) @@ -207,7 +213,7 @@ def _v_up_proj_and_o_proj(self, x): ) # Convert from (B, N, V) to (B, N * V) x = x.reshape(-1, self.num_heads * self.v_head_dim) - return self.o_proj(x) + return x def _q_proj_and_k_up_proj(self, x, x_scale=None): q_nope, q_pe = ( @@ -413,7 +419,7 @@ def _forward_prefill_mha( causal=True, ) - return self.o_proj(output.flatten(start_dim=-2)) + return output.flatten(start_dim=-2) def _forward_prefill_mla( self, @@ -480,7 +486,7 @@ def _forward_prefill_mla( None, ) - return self._v_up_proj_and_o_proj(o) + return self._v_up_proj(o) def _forward_decode( self, @@ -555,16 +561,15 @@ def _forward_decode( kv_scale=self._k_scale, ) - return self._v_up_proj_and_o_proj(o) + return self._v_up_proj(o) - def forward( + def forward_impl_server_mode( self, - q: torch.Tensor, # query in unified attn + q: torch.Tensor, k_nope: torch.Tensor, k_rope: torch.Tensor, - positions: torch.Tensor, - q_scale: Optional[torch.Tensor], - qkv: Optional[torch.Tensor], + positions: torch.Tensor = None, + q_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: # kv_cache = self.kv_cache forward_context: ForwardContext = get_forward_context() @@ -577,7 +582,7 @@ def forward( if forward_context.context.is_dummy_run: # dummy run: skip real attention and return output_shape = list(q.shape) - output_shape[-1] = 7168 + output_shape[-1] = self.num_heads * self.v_head_dim atom_config = get_current_atom_config() output_dtype = atom_config.torch_dtype output = torch.empty(output_shape, dtype=output_dtype, device=q.device) @@ -647,6 +652,41 @@ def forward( return output + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, # query in unified attn + k_nope: torch.Tensor, + k_rope: torch.Tensor, + kv_cache: torch.Tensor = 
None, + attn_metadata=None, + positions: torch.Tensor = None, + q_scale: Optional[torch.Tensor] = None, + output: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + if is_plugin_mode(): + # forward impl method are added by the decorator + # MLAAttentionImplDecoratorForPluginMode + return self.forward_impl_plugin_mode( + layer=layer, + q=query, + k_c_normed=k_nope, + k_pe=k_rope, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + output=output, + ) + else: + # only for server mode, keep the original method + return self.forward_impl_server_mode( + q=query, + k_nope=k_nope, + k_rope=k_rope, + positions=positions, + q_scale=q_scale, + ) + @triton.jit def _convert_req_index_to_global_index_kernel( diff --git a/atom/model_ops/attentions/aiter_mla.py b/atom/model_ops/attentions/aiter_mla.py index fa4ec6c3..168e0eab 100644 --- a/atom/model_ops/attentions/aiter_mla.py +++ b/atom/model_ops/attentions/aiter_mla.py @@ -19,6 +19,10 @@ from .backends import AttentionBackend, CommonAttentionBuilder +from atom.plugin.prepare import is_plugin_mode +from atom.plugin.attention import AiterMLAAttentionMetadataBuilderDecoratorForPluginMode +from atom.plugin.attention import AiterBackendDecoratorForPluginMode + logger = logging.getLogger("atom") @@ -26,10 +30,11 @@ def cdiv(a, b): return (a + b - 1) // b +@AiterBackendDecoratorForPluginMode class AiterMLABackend(AttentionBackend): @staticmethod def get_name() -> str: - return "ROCM_AITER_MLA" + return "ROCM_AITER_MLA" if not is_plugin_mode() else "CUSTOM" @staticmethod def get_builder_cls() -> Type["AiterMLAMetadataBuilder"]: @@ -40,11 +45,14 @@ def get_impl_cls() -> Type["MLAAttention"]: return MLAAttention +@AiterMLAAttentionMetadataBuilderDecoratorForPluginMode( + default_base_class=CommonAttentionBuilder +) class AiterMLAMetadataBuilder(CommonAttentionBuilder): def __init__(self, model_runner): self.block_size = 1 - super().__init__(model_runner) + CommonAttentionBuilder.__init__(self, model_runner) config = 
model_runner.config hf_config = config.hf_config self.num_attention_heads = ( @@ -190,7 +198,7 @@ def prepare_mtp_decode(self, bs: int, max_seqlen_q: int, max_seqlen_k: int): return self.set_mla_persistent_worker_buffers(bs, max_seqlen_q) def prepare_prefill(self, batch: ScheduledBatch): - attn_metadata, positions = super().prepare_prefill(batch) + attn_metadata, positions = CommonAttentionBuilder.prepare_prefill(self, batch) bs = batch.total_seqs_num_prefill sum_scheduled_tokens = batch.total_tokens_num_prefill var = self.model_runner.forward_vars diff --git a/atom/model_ops/base_attention.py b/atom/model_ops/base_attention.py index 3660b3eb..65996d09 100644 --- a/atom/model_ops/base_attention.py +++ b/atom/model_ops/base_attention.py @@ -191,10 +191,10 @@ def fake_( qkv: torch.Tensor, ) -> torch.Tensor: output_shape = list(q.shape) - if use_mla: - output_shape[-1] = 7168 # If we fusion rmsnorm and quant, the input dtype is fp8, but actually we use bf16 for output. atom_config = get_current_atom_config() + if use_mla: + output_shape[-1] = atom_config.hf_config.hidden_size output_dtype = atom_config.torch_dtype output = torch.zeros(output_shape, dtype=output_dtype, device=q.device) @@ -218,7 +218,15 @@ def unified_attention_with_output_base( atom_config = get_current_atom_config() self = atom_config.compilation_config.static_forward_context[layer_name] if use_mla: - return self.impl.forward(q, k, v, positions, q_scale, qkv) + output = self.impl.forward( + layer=self, + query=q, + k_nope=k, + k_rope=v, + positions=positions, + q_scale=q_scale, + ) + return self.impl.o_proj(output) else: return self.impl.forward( layer=self, diff --git a/atom/model_ops/paged_attention.py b/atom/model_ops/paged_attention.py index ce2ac863..17cbf4fb 100644 --- a/atom/model_ops/paged_attention.py +++ b/atom/model_ops/paged_attention.py @@ -3,7 +3,6 @@ # from flash_attn import flash_attn_with_kvcache from typing import Optional - import torch from torch import nn @@ -60,15 +59,15 @@ def 
__init__( **kwargs, ) + self.use_mla = use_mla # for plugin mode if is_vllm(): - self.use_mla = use_mla - self.rotary_emb = rotary_emb + self.rotary_emb = mla_modules.rotary_emb if use_mla else rotary_emb try: - from vllm.attention.layer import Attention, AttentionType + from vllm.attention.layer import Attention, MLAAttention, AttentionType except ImportError: - from vllm.model_executor.layers.attention import Attention + from vllm.model_executor.layers.attention import Attention, MLAAttention from vllm.v1.attention.backend import AttentionType atom_config = get_current_atom_config() @@ -88,28 +87,68 @@ def __init__( extra_impl_args["rotary_emb"] = rotary_emb extra_impl_args["q_norm"] = q_norm extra_impl_args["k_norm"] = k_norm - - self.attn = Attention( - num_heads=num_heads, - head_size=head_dim, - scale=scale, - num_kv_heads=num_kv_heads, - alibi_slopes=alibi_slopes, - cache_config=cache_config, - quant_config=quant_config, - logits_soft_cap=None, - per_layer_sliding_window=per_layer_sliding_window, - prefix=f"{prefix}", - attn_type=AttentionType.DECODER, - kv_sharing_target_layer_name=None, - **extra_impl_args, - ) + if use_mla: + extra_impl_args["layer_num"] = layer_num + extra_impl_args["mla_modules"] = mla_modules + + if use_mla: + assert ( + mla_modules.indexer is None + ), "MLAAttention is not supported for sparse mode" + self.num_heads = num_heads + self.v_head_dim = mla_modules.v_head_dim + self.qk_head_dim = mla_modules.qk_head_dim + self.qk_nope_head_dim = mla_modules.qk_nope_head_dim + self.q_proj = mla_modules.q_proj + self.o_proj = mla_modules.o_proj + + self.attn = MLAAttention( + num_heads=num_heads, + scale=scale, + qk_nope_head_dim=mla_modules.qk_nope_head_dim, + qk_rope_head_dim=mla_modules.qk_rope_head_dim, + v_head_dim=mla_modules.v_head_dim, + q_lora_rank=mla_modules.q_lora_rank, + kv_lora_rank=mla_modules.kv_lora_rank, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + kv_b_proj=mla_modules.kv_b_proj, + 
use_sparse=False, + indexer=mla_modules.indexer, + **extra_impl_args, + ) + else: + self.attn = Attention( + num_heads=num_heads, + head_size=head_dim, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=alibi_slopes, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=None, + per_layer_sliding_window=per_layer_sliding_window, + prefix=f"{prefix}", + attn_type=AttentionType.DECODER, + kv_sharing_target_layer_name=None, + **extra_impl_args, + ) compilation_config = atom_config.compilation_config self.layer_name = prefix if self.layer_name in compilation_config.static_forward_context: raise ValueError("Duplicate layer: {}".format(self.layer_name)) compilation_config.static_forward_context[self.layer_name] = self + + if self.use_mla: + if "positions" not in compilation_config.static_forward_context: + max_num_tokens = ( + atom_config.plugin_config.vllm_scheduler_config.max_num_batched_tokens + ) + compilation_config.static_forward_context["positions"] = ( + torch.zeros(max_num_tokens, dtype=torch.int64, device="cuda") + ) return self.num_heads = num_heads @@ -122,7 +161,6 @@ def __init__( self.k_scale = self.v_scale = None self.layer_num = layer_num self.mla_modules = mla_modules - self.use_mla = use_mla self.base_attention = None self.kv_cache = torch.tensor([]) self.indexer = mla_modules.indexer if mla_modules is not None else None @@ -136,7 +174,7 @@ def __init__( use_mla=self.use_mla, ) impl_cls = self.attn_backend.get_impl_cls() - self.impl = impl_cls( + impl_args = dict( num_heads=num_heads, head_dim=head_dim, scale=scale, @@ -153,7 +191,8 @@ def __init__( k_norm=k_norm, **kwargs, ) - + impl_args["head_size" if self.use_mla else "head_dim"] = head_dim + self.impl = impl_cls(**impl_args) compilation_config = atom_config.compilation_config default_name = f"MLA_{layer_num}" if self.use_mla else f"MHA_{layer_num}" self.layer_name = prefix if prefix is not None else default_name diff --git a/atom/plugin/attention.py 
b/atom/plugin/attention.py index 2b136ebc..cb46e6b5 100644 --- a/atom/plugin/attention.py +++ b/atom/plugin/attention.py @@ -1,5 +1,6 @@ -from typing import Generic, Optional +from typing import Generic, Optional, TypeVar import logging +import os from dataclasses import dataclass @@ -9,6 +10,7 @@ from atom.utils import CpuGpuBuffer from atom.utils.forward_context import Context, AttentionMetaData from atom.model_ops.attention_mha import PagedAttentionImpl +from atom.model_ops.attention_mla import MLAAttention logger = logging.getLogger("atom") @@ -17,19 +19,15 @@ @dataclass -class AiterFlashAttentionDecodeMetadata: +class AiterFlashAttentionPhaseMetadata: max_query_len: int min_query_len: int max_seq_len: int query_start_loc: torch.Tensor -@dataclass -class AiterFlashAttentionPrefillMetadata: - max_query_len: int - min_query_len: int - max_seq_len: int - query_start_loc: torch.Tensor +AiterFlashAttentionDecodeMetadata = AiterFlashAttentionPhaseMetadata +AiterFlashAttentionPrefillMetadata = AiterFlashAttentionPhaseMetadata @dataclass @@ -67,7 +65,7 @@ class AiterFlashAttentionChunkPrefillMetadata: @dataclass -class MetadataForPluginMode: +class AiterFlashAttentionMetadataForPluginMode: # NOTE(sang): Definition of context_len, query_len, and seq_len. 
# |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| @@ -104,7 +102,7 @@ class MetadataForPluginMode: context: Optional[Context] = None -class vllmAiterBackendMethods: +class vllmAiterAttentionBackendMethods: # here attention in ATOM doesn't accept the output buffer because # ATOM works as a model impl backend, it needs the maximum freedom # to decide the output buffer and shape, thus here use this flag to @@ -157,33 +155,6 @@ def supports_alibi_sqrt(cls) -> bool: return False -def AiterBackendDecoratorForPluginMode(cls): - """ - Decorator for AiterBackend to add specific methods and attributes for plugin mode - """ - is_vllm_mode = is_vllm() - - if is_vllm_mode: - # for vllm, add the required methods - cls.full_cls_name = vllmAiterBackendMethods.full_cls_name - cls.accept_output_buffer = vllmAiterBackendMethods.accept_output_buffer - cls.supported_dtypes = vllmAiterBackendMethods.supported_dtypes - cls.forward_includes_kv_cache_update = ( - vllmAiterBackendMethods.forward_includes_kv_cache_update - ) - cls.get_supported_kernel_block_sizes = ( - vllmAiterBackendMethods.get_supported_kernel_block_sizes - ) - cls.get_kv_cache_shape = vllmAiterBackendMethods.get_kv_cache_shape - cls.is_mla = vllmAiterBackendMethods.is_mla - cls.get_required_kv_cache_layout = ( - vllmAiterBackendMethods.get_required_kv_cache_layout - ) - cls.get_supported_head_sizes = vllmAiterBackendMethods.get_supported_head_sizes - cls.supports_alibi_sqrt = vllmAiterBackendMethods.supports_alibi_sqrt - return cls - - def create_attn_metadata_builder_init_method(base_class): """ Create the init method for metadata builder @@ -503,7 +474,7 @@ def build( graph_bs=context_graph_bs, ) - attn_metadata_for_plugin_mode = MetadataForPluginMode( + attn_metadata_for_plugin_mode = AiterFlashAttentionMetadataForPluginMode( num_actual_tokens=num_actual_tokens, num_actual_kv_tokens=num_actual_kv_tokens, max_query_len=common_attn_metadata.max_query_len, @@ -616,6 +587,693 @@ 
def decorator(cls): return decorator +# for MLA attention metadata for plugin mode +@dataclass +class AiterMLACommonDecodeMetadataForPluginMode: + block_table: torch.Tensor + seq_lens: torch.Tensor + dcp_tot_seq_lens: torch.Tensor | None + + +@dataclass +class AiterMLADecodeMetadataForPluginMode(AiterMLACommonDecodeMetadataForPluginMode): + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: torch.Tensor | None = None + # The page indices of the paged kv cache + paged_kv_indices: torch.Tensor | None = None + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_len: torch.Tensor | None = None + # The query indptr, shape : [num_decode + 1] + qo_indptr: torch.Tensor | None = None + # The dtype of MLA out tensor + attn_out_dtype: torch.dtype = torch.bfloat16 + # The max query output length: int + max_qo_len: int | None = None + + +@dataclass +class AiterMLACommonPrefillMetadataForPluginMode: + """Prefill Specific Metadata""" + + @dataclass + class AiterMLAChunkedContextMetadataForPluginMode: + # New for MLA (compared to FlashAttention) + # For handling chunked prefill + cu_seq_lens: torch.Tensor + starts: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + seq_lens: torch.Tensor + workspace: torch.Tensor + token_to_seq: torch.Tensor + chunk_total_token: list[int] + + # for mla DCP + padded_local_chunk_seq_lens: list[list[int]] | None = None + local_context_lens_allranks: list[list[int]] | None = None + padded_local_cu_seq_lens: torch.Tensor | None = None + cu_seq_lens_lst: list[list[int]] | None = None + chunk_size: int | None = None + + block_table: torch.Tensor + query_start_loc: torch.Tensor + max_query_len: int + chunked_context: AiterMLAChunkedContextMetadataForPluginMode | None = None + query_seq_lens: torch.Tensor | None = None + workspace_buffer: torch.Tensor | None = None + q_data_type: torch.dtype | None = None + + +D = TypeVar("D", 
bound=AiterMLACommonDecodeMetadataForPluginMode) + + +@dataclass +class AiterMLACommonMetadataForPluginMode(Generic[D]): + """Metadata for MLACommon. + + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + num_reqs: int + max_query_len: int + max_seq_len: int + + num_actual_tokens: int # Number of tokens excluding padding. + query_start_loc: torch.Tensor + slot_mapping: torch.Tensor + + # New for MLA (compared to FlashAttention) + # For handling prefill decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int + + # The dimension of the attention heads + head_dim: int | None = None + + decode: D | None = None + prefill: AiterMLACommonPrefillMetadataForPluginMode | None = None + + def __post_init__(self): + pass + # if self.head_dim is not None and not MLACommonBackend.supports_head_size( + # self.head_dim + # ): + # raise ValueError(f"Head dimension {self.head_dim} is not supported by MLA.") + + +class vllmMLAAttentionMetadataBuilderMethods: + def __init__(self): + raise TypeError( + f"{self.__class__.__name__} is a utility class and should not be instantiated. " + "Its methods are meant to be added to other classes via decorators." + ) + + def _build_decode( + self, + block_table_tensor: torch.Tensor, + seq_lens_device: torch.Tensor, + max_seq_len: int, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int, + dcp_tot_seq_lens_device: torch.Tensor | None, + ): + # kernel block size is always 1, although the kv block size is not 1. 
+ device = self.device + num_reqs = seq_lens_device.size(0) + + mask = torch.arange( + block_table_tensor.size(1), + dtype=block_table_tensor.dtype, + device=device, + ).unsqueeze(0) < seq_lens_device.unsqueeze(1) + paged_kv_indices = block_table_tensor[mask] + + # kernel block size is always 1, so each page has exactly 1 token. + # last_page_len is always 1 - just slice the pre-initialized buffer. + paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs] + + paged_kv_indptr = torch.cat( + [ + torch.zeros(1, dtype=seq_lens_device.dtype, device=device), + seq_lens_device.cumsum(dim=0, dtype=torch.int32), + ] + ) + qo_len = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + max_qo_len = qo_len.max().item() + + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + num_actual_pages = paged_kv_indices.size(0) + + self.paged_kv_indices[:num_actual_pages].copy_( + paged_kv_indices, non_blocking=True + ) + self.paged_kv_indices[num_actual_pages:].fill_(-1) + paged_kv_indices = self.paged_kv_indices[:num_actual_pages] + + self.paged_kv_indptr[: 1 + num_reqs].copy_( + paged_kv_indptr, non_blocking=True + ) + self.paged_kv_indptr[1 + num_reqs :].fill_(paged_kv_indptr[-1]) + paged_kv_indptr = self.paged_kv_indptr[: 1 + num_reqs] + + # paged_kv_last_page_len already uses the pre-initialized buffer slice + # (set above), so no copy needed - buffer is always 1s. 
+ + self.qo_indptr[: 1 + num_reqs].copy_( + query_start_loc_device, non_blocking=True + ) + self.qo_indptr[1 + num_reqs :] = query_start_loc_device[-1] + qo_indptr = self.qo_indptr[: 1 + num_reqs] + + else: + qo_indptr = torch.arange( + 0, num_reqs + 1, step=1, dtype=torch.int32, device=device + ) + + attn_metadata = AiterMLADecodeMetadataForPluginMode( + block_table=block_table_tensor, + seq_lens=seq_lens_device, + paged_kv_indptr=paged_kv_indptr, + paged_kv_indices=paged_kv_indices, + paged_kv_last_page_len=paged_kv_last_page_len, + qo_indptr=qo_indptr, + dcp_tot_seq_lens=dcp_tot_seq_lens_device, + max_qo_len=max_qo_len, + attn_out_dtype=self.decode_attn_out_dtype, + ) + + return attn_metadata + + def build_for_cudagraph_capture( + self, + common_attn_metadata=None, + ): + return self.build(0, common_attn_metadata) + + def build( + self, + common_prefix_len: int = 0, + common_attn_metadata=None, + fast_build: bool = False, + ): + + from vllm.v1.attention.backends.utils import split_decodes_and_prefills + from vllm.model_executor.layers.attention.mla_attention import ( + QueryLenSupport, + ) + + from vllm.utils.math_utils import cdiv, round_down + from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens + + num_reqs = common_attn_metadata.num_reqs + num_tokens = common_attn_metadata.num_actual_tokens + max_query_len = common_attn_metadata.max_query_len + max_seq_len = common_attn_metadata.max_seq_len + + # Note(simon): be careful about the CPU <> GPU memory movement in this + # function. We should avoid GPU -> CPU sync as much as possible because + # it blocks on all previous kernels. 
+ device = self.device + block_table_tensor = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping + + query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + seq_lens = common_attn_metadata.seq_lens + dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens + + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold, + require_uniform=(self.query_len_support != QueryLenSupport.VARLEN), + ) + ) + + assert num_decodes + num_prefills == num_reqs + assert num_decode_tokens + num_prefill_tokens == num_tokens + + prefill_metadata = None + if num_prefills > 0: + num_computed_tokens_cpu = ( + common_attn_metadata.compute_num_computed_tokens().cpu() + ) + + reqs_start = num_decodes # prefill_start + + context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] + max_context_len_cpu = context_lens_cpu.max().item() + num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() + prefill_query_start_loc = ( + query_start_loc[reqs_start:] - query_start_loc[reqs_start] + ) + + chunked_context_metadata = None + if max_context_len_cpu > 0: + # NOTE: it is recommend you read the `Chunked Prefill` section + # in the comment at the top of the file before trying to + # understand the following code + + # currently we allocate an equal amount of workspace for each + # prefill in the batch, we could probably use a more advanced + # algorithm here and allocate more workspace to prefills with + # longer context lengths + max_context_chunk = ( + self.chunked_prefill_workspace_size // num_prefills_with_context_cpu + ) + + if self.aot_schedule: + # align max_context_chunk to page_size by rounding down, + # currently the `gather_and_maybe_dequant_cache` kernel + # cannot handle `context_chunk_starts` that are not aligned + # to page_size + max_context_chunk = 
round_down(max_context_chunk, self.page_size) + + assert max_context_chunk > 0 + num_chunks = cdiv(max_context_len_cpu, max_context_chunk) + + # if `max_context_chunk = 256`, `num_chunks = 3`, and + # `num_prefills_with_context = 4`, create a tensor that looks + # like + # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] + # Note(simon): this is done in CPU because of downstream's + # of `to_list`. + chunk_starts = ( + torch.arange(num_chunks, dtype=torch.int32) + .unsqueeze(1) + .expand(-1, num_prefills) + * max_context_chunk + ) + chunk_ends = torch.min( + context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk + ) + chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) + + cu_seq_lens_cpu = torch.zeros( + num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True + ) + torch.cumsum( + chunk_seq_lens, + dim=1, + out=cu_seq_lens_cpu[:, 1:], + dtype=torch.int32, + ) + chunk_total_token = cu_seq_lens_cpu[:, -1] + + max_token_num_over_chunk = chunk_total_token.max().item() + token_to_seq_tensor_cpu = torch.zeros( + [num_chunks, max_token_num_over_chunk], dtype=torch.int32 + ) + range_idx = torch.arange(num_prefills, dtype=torch.int32) + for i in range(num_chunks): + chunk_token_to_seq_tensor = torch.repeat_interleave( + range_idx, chunk_seq_lens[i] + ) + chunk_len = chunk_token_to_seq_tensor.shape[0] + token_to_seq_tensor_cpu[i, :chunk_len] = chunk_token_to_seq_tensor + + if self.dcp_world_size > 1: + local_context_lens_allranks = get_dcp_local_seq_lens( + context_lens_cpu, + self.dcp_world_size, + None, + self.dcp_local_block_size, + ) + # Note(qcs): The max local context lengths + # padded to `dcp_local_block_size`. 
+ padded_local_context_lens_cpu: torch.Tensor = ( + cdiv( + context_lens_cpu, + self.dcp_virtual_block_size, + ) + * self.dcp_local_block_size + ) + # Note(hc): The above max_context_chunk already enforces + # block_size alignment, DCP just need the block_size can + # be divisible by dcp_world_size, because DCP use + # cp_gather_cache which not require `cp_chunk_starts` + # aligned to page_size. + assert max_context_chunk % self.dcp_world_size == 0 + padded_local_max_context_chunk_across_ranks = ( + cdiv( + max_context_chunk, + self.dcp_virtual_block_size, + ) + * self.dcp_local_block_size + ) + local_chunk_starts = ( + torch.arange(num_chunks, dtype=torch.int32) + .unsqueeze(1) + .expand(-1, num_prefills) + * padded_local_max_context_chunk_across_ranks + ) + local_chunk_ends = torch.min( + padded_local_context_lens_cpu.unsqueeze(0), + local_chunk_starts + + padded_local_max_context_chunk_across_ranks, + ) + padded_local_chunk_seq_lens = ( + local_chunk_ends - local_chunk_starts + ).clamp(min=0) + + padded_local_cu_chunk_seq_lens_cpu = torch.zeros( + num_chunks, + num_prefills + 1, + dtype=torch.int32, + pin_memory=True, + ) + torch.cumsum( + padded_local_chunk_seq_lens, + dim=1, + out=padded_local_cu_chunk_seq_lens_cpu[:, 1:], + dtype=torch.int32, + ) + + chunked_context_metadata_cls = ( + AiterMLACommonPrefillMetadataForPluginMode.AiterMLAChunkedContextMetadataForPluginMode + ) + if self.dcp_world_size > 1: + chunked_context_metadata = chunked_context_metadata_cls( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), + starts=local_chunk_starts.to(device, non_blocking=True), + seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, + token_to_seq=token_to_seq_tensor_cpu.to( + device, non_blocking=True + ), + chunk_total_token=chunk_total_token.tolist(), + workspace=self.chunked_prefill_workspace, + padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(), + 
local_context_lens_allranks=local_context_lens_allranks.tolist(), + padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.to( + device, non_blocking=True + ), + cu_seq_lens_lst=cu_seq_lens_cpu.tolist(), + chunk_size=padded_local_max_context_chunk_across_ranks, + ) + else: + chunked_context_metadata = chunked_context_metadata_cls( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), + starts=chunk_starts.to(device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, + token_to_seq=token_to_seq_tensor_cpu.to( + device, non_blocking=True + ), + chunk_total_token=chunk_total_token, + workspace=self.chunked_prefill_workspace, + ) + + if self._use_cudnn_prefill: + chunked_context_metadata.seq_lens = chunk_seq_lens + + assert ( + max(chunked_context_metadata.max_seq_lens) + <= self.chunked_prefill_workspace_size + ) + + prefill_metadata = AiterMLACommonPrefillMetadataForPluginMode( + block_table=block_table_tensor[reqs_start:, ...], + query_start_loc=prefill_query_start_loc, + max_query_len=max_query_len, + chunked_context=chunked_context_metadata, + ) + + decode_metadata = None + if num_decodes > 0: + dcp_tot_seq_lens_device = None + if self.dcp_world_size > 1: + dcp_tot_seq_lens_device = seq_lens[:num_decodes] + seq_lens = dcp_local_seq_lens + + # After DCP distribution, the maximum number of tokens for any rank is + # ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size, + # and I is cp_kv_cache_interleave_size. + # This eliminates GPU->CPU sync while minimizing workspace + # over-allocation. 
+ num_partitions = self.dcp_world_size * self.cp_kv_cache_interleave_size + max_seq_len = ( + (max_seq_len + num_partitions - 1) // num_partitions + ) * self.cp_kv_cache_interleave_size + + decode_metadata = self._build_decode( + block_table_tensor=block_table_tensor[:num_decodes, ...], + seq_lens_device=seq_lens[:num_decodes], + max_seq_len=max_seq_len, + query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1], + query_start_loc_device=query_start_loc[: num_decodes + 1], + num_decode_tokens=num_decode_tokens, + dcp_tot_seq_lens_device=dcp_tot_seq_lens_device, + ) + + attn_metadata_for_plugin_mode = AiterMLACommonMetadataForPluginMode( + num_reqs=common_attn_metadata.num_reqs, + max_query_len=common_attn_metadata.max_query_len, + max_seq_len=max_seq_len, + num_actual_tokens=num_tokens, + query_start_loc=query_start_loc, + slot_mapping=slot_mapping, + head_dim=self.model_config.get_head_size(), + # MLACommonMetadata Chunk prefill specific + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + prefill=prefill_metadata, + decode=decode_metadata, + ) + + attn_metadata = AttentionMetaData( + max_seqlen_q=common_attn_metadata.max_query_len, + block_tables=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping, + plugin_metadata=attn_metadata_for_plugin_mode, + ) + return attn_metadata + + +def create_mla_attn_metadata_builder_init_method(base_class): + """ + Create the init method for metadata builder + """ + + def init_method_under_plugin_mode( + self, + kv_cache_spec=None, + layer_names=None, + config=None, + device=None, + model_runner=None, + ): + base_class.__init__(self, kv_cache_spec, layer_names, config, device) + logger.info("init AiterAttentionMetadataBuilder for plugin mode") + from vllm.config import VllmConfig + + assert isinstance(config, VllmConfig) + + self.vllm_config = config + self.model_config = config.model_config + self.parallel_config = config.parallel_config + 
self.cache_config = config.cache_config + + self.compilation_config = self.vllm_config.compilation_config + self.decode_attn_out_dtype = self.vllm_config.model_config.dtype + # kernel block size is always 1. + max_num_pages_per_req = self.vllm_config.model_config.max_model_len + max_num_reqs = self.vllm_config.scheduler_config.max_num_seqs + max_num_pages = max_num_reqs * max_num_pages_per_req + + # Preparing persistent buffers + # TODO: we can disambiguate between decode and mixed-prefill decode here + # so we can only use the persistent buffer if a cudagraph is actually + # being used. + + # paged_kv_last_page_len is always 1s (kernel block size is always 1), + # so we create it once and reuse slices in both eager and cudagraph modes. + self.paged_kv_last_page_len = torch.ones( + max_num_reqs, dtype=torch.int32, device=device + ) + + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + self.paged_kv_indptr = torch.zeros( + max_num_reqs + 1, dtype=torch.int32, device=device + ) + self.paged_kv_indices = torch.zeros( + max_num_pages, dtype=torch.int32, device=device + ) + + self.qo_indptr = torch.zeros( + max_num_reqs + 1, dtype=torch.int32, device=device + ) + + return init_method_under_plugin_mode + + +def setup_mla_attn_metadata_builder_base_class_and_attributes(class_dict: dict): + """ + Setup the base class and attributes for attention metadata builder + """ + from vllm.model_executor.layers.attention.mla_attention import ( + MLACommonMetadataBuilder, + QueryLenSupport, + ) + from vllm.v1.attention.backend import AttentionCGSupport + + base_class = MLACommonMetadataBuilder + generic_base = MLACommonMetadataBuilder + needs_generic = True + + # align with vllm rocm aiter fa + class_dict["_cudagraph_support"] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + class_dict["reorder_batch_threshold"] = 1 + class_dict["query_len_support"] = QueryLenSupport.UNIFORM + + return base_class, generic_base, needs_generic, class_dict + + +def 
class vllmAiterMLABackendMethods:
    """
    Namespace of attributes and methods that are copied onto the MLA
    attention backend class when running as a vLLM plugin.

    The class is never instantiated; see ``AiterBackendDecoratorForPluginMode``.
    """

    # Class-level flags/values copied verbatim onto the decorated backend.
    accept_output_buffer: bool = True
    supported_dtypes: list = [torch.float16, torch.bfloat16]

    def __init__(self):
        raise TypeError(
            f"{self.__class__.__name__} is a utility class and should not be instantiated. "
            "Its methods are meant to be added to other classes via decorators."
        )

    @staticmethod
    def get_supported_kernel_block_sizes():
        """Return the list of supported kernel block sizes (only ``1``)."""
        return [1]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """
        Return the KV-cache shape ``(num_blocks, block_size, head_size)``.

        ``num_kv_heads`` and ``cache_dtype_str`` are accepted for interface
        compatibility and do not influence the result.
        """
        return (num_blocks, block_size, head_size)

    @classmethod
    def is_mla(cls) -> bool:
        """Report that this backend is an MLA backend."""
        return True

    @staticmethod
    def get_required_kv_cache_layout():
        """Return ``None``: no specific KV-cache layout is requested."""
        return None

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the supported head sizes (only ``576``)."""
        return [576]

    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        """Return ``(module, qualname)`` of the class the method is bound to."""
        return (cls.__module__, cls.__qualname__)

    @classmethod
    def supports_alibi_sqrt(cls) -> bool:
        """Report that the alibi-sqrt variant is not supported."""
        return False

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        """
        Return the permutation from ``get_kv_cache_shape`` to the memory layout.

        With ``include_num_layers_dimension`` a four-entry permutation is
        returned, otherwise the identity over the three cache dimensions.
        """
        # `stride_order` indicates the permutation that gets
        # us from `get_kv_cache_shape` to the actual memory layout we want.
        # (num_blocks, num_layers, block_size, head_size)
        return (1, 0, 2, 3) if include_num_layers_dimension else (0, 1, 2)


def AiterBackendDecoratorForPluginMode(cls):
    """
    Decorator for AiterBackend to add specific methods and attributes for plugin mode

    Outside vLLM mode ``cls`` is returned untouched. In vLLM mode every public
    member (name not starting with ``_``) of the matching method container is
    attached to ``cls``: ``vllmAiterMLABackendMethods`` when
    ``cls.get_impl_cls()`` is a subclass of ``MLAAttention``, otherwise
    ``vllmAiterAttentionBackendMethods``.

    Args:
        cls: backend class to extend; it is modified in place.

    Returns:
        The same ``cls`` object.
    """
    import inspect

    if not is_vllm():
        return cls

    if issubclass(cls.get_impl_cls(), MLAAttention):
        methods_cls = vllmAiterMLABackendMethods
    else:
        methods_cls = vllmAiterAttentionBackendMethods

    for name in dir(methods_cls):
        if name.startswith("_"):
            continue
        # Copy the raw descriptor rather than the result of a normal attribute
        # lookup. `getattr` would hand back classmethods already bound to the
        # container (so e.g. `full_cls_name` would describe the container, not
        # `cls`) and would unwrap staticmethods into plain functions that then
        # misbehave when called through an instance.
        setattr(cls, name, inspect.getattr_static(methods_cls, name))
    return cls
mode for now") + kv_c_normed = k + k_pe = v + self = atom_config.compilation_config.static_forward_context[layer_name] + q = self.q_proj(q, q_scale) + q = q.view(-1, self.num_heads, self.qk_head_dim) + # Add head dim of 1 to k_pe + if os.getenv("ATOM_DISABLE_VLLM_PLUGIN_ATTENTION", "0").lower() == "1": + k_pe = k_pe.unsqueeze(1) + if self.rotary_emb is not None: + q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb( + positions, q[..., self.qk_nope_head_dim :], k_pe + ) + # positions written at model entry (model_wrapper.forward) + output = self.attn( + q, + kv_c_normed, + k_pe, + output_shape=(q.shape[0], self.num_heads * self.v_head_dim), + ) + return self.o_proj(output) else: self = atom_config.compilation_config.static_forward_context[layer_name] # here is the standard vllm attention impl interface diff --git a/atom/plugin/attention_mla.py b/atom/plugin/attention_mla.py new file mode 100644 index 00000000..30f61ae8 --- /dev/null +++ b/atom/plugin/attention_mla.py @@ -0,0 +1,872 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +""" +Plugin mode extensions for MLAAttention. +This module provides additional methods for MLAAttention when running in plugin mode. 
def reorg_kvcache(
    allgatered_kv_c_normed: torch.Tensor,
    allgatered_k_pe: torch.Tensor,
    padded_local_chunk_seq_lens_lst: list[int],
    local_context_lens_allranks: list[list[int]],
    sum_seq_len: int,
    max_seq_len: int,
    chunk_size: int,
    chunk_idx: int,
    toks: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Reorder and unpad an all-gathered KV cache from rank-major order into
    sequence-major order.

    The gathered tensors hold ``toks`` rows per rank; within a rank each
    sequence occupies ``padded_local_chunk_seq_lens_lst[seq]`` rows, of which
    only the leading valid rows are kept. For every sequence the valid rows of
    rank 0, rank 1, ... are concatenated, e.g.

        gathered = [T0_0, T0_1, T0_2, T0_3, T1_0, T1_1, ...,
                    T0_4, T0_5, pad, pad, T1_2, pad, ...]
        result   = [T0_0, T0_1, T0_2, T0_3, T0_4, T0_5,
                    T1_0, T1_1, T1_2, ...]

    Args:
        allgatered_kv_c_normed: gathered latent rows, sliced along dim 0.
        allgatered_k_pe: gathered rope rows, sliced identically.
        padded_local_chunk_seq_lens_lst: padded per-sequence chunk length on a
            single rank.
        local_context_lens_allranks: for each sequence, the context length
            held by every rank.
        sum_seq_len: expected number of rows in the result (checked).
        max_seq_len: expected largest per-sequence row count (checked).
        chunk_size: chunk stride; the current chunk starts at
            ``chunk_idx * chunk_size`` within each rank's context.
        chunk_idx: index of the chunk being processed.
        toks: number of gathered rows contributed by each rank.

    Returns:
        ``(reorganized_kv_c_normed, reorganized_k_pe)``.

    Raises:
        AssertionError: if the result disagrees with ``sum_seq_len`` or
            ``max_seq_len``.
    """
    kv_c_pieces = []
    k_pe_pieces = []
    longest_seq = 0
    # Row offset of the current sequence inside one rank's slab.
    seq_base = 0
    chunk_start = chunk_idx * chunk_size

    for padded_len, per_rank_context_lens in zip(
        padded_local_chunk_seq_lens_lst, local_context_lens_allranks
    ):
        seq_total = 0
        for rank, context_len in enumerate(per_rank_context_lens):
            # A rank whose context ends inside (or before) this chunk
            # contributes fewer than `padded_len` valid rows, possibly none.
            valid_len = min(max(0, context_len - chunk_start), padded_len)
            if valid_len == 0:
                continue
            begin = rank * toks + seq_base
            rows = slice(begin, begin + valid_len)
            kv_c_pieces.append(allgatered_kv_c_normed[rows])
            k_pe_pieces.append(allgatered_k_pe[rows])
            seq_total += valid_len
        longest_seq = max(longest_seq, seq_total)
        seq_base += padded_len

    reorganized_kv_c_normed = torch.cat(kv_c_pieces, dim=0)
    reorganized_k_pe = torch.cat(k_pe_pieces, dim=0)
    assert reorganized_kv_c_normed.shape[0] == sum_seq_len
    assert reorganized_k_pe.shape[0] == sum_seq_len
    assert longest_seq == max_seq_len
    return reorganized_kv_c_normed, reorganized_k_pe
+ This class cannot be instantiated - it only serves as a namespace for methods + that will be added to PagedAttentionImpl via decorator. + """ + + def __init__(self): + raise TypeError( + "MLAAttentionImplPluginModeMethods cannot be instantiated. " + "It is only used as a method container for the decorator." + ) + + def _concat_k_nope_k_pe_plugin_mode( + self, k_nope: torch.Tensor, k_pe: torch.Tensor + ) -> torch.Tensor: + """ + Efficiently concatenate k_nope and k_pe tensors along the last dimension. + + This function avoids the performance penalty of torch.cat with expanded + non-contiguous tensors by pre-allocating the output and using direct copies. + + Args: + k_nope: Tensor of shape [..., nope_dim] + k_pe: Tensor to broadcast and concatenate, typically shape [..., 1, pe_dim] + or [..., pe_dim] + + Returns: + Tensor of shape [..., nope_dim + pe_dim] + """ + k = torch.empty( + (*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]), + dtype=k_nope.dtype, + device=k_nope.device, + ) + # Direct copies with efficient broadcasting + k[..., : k_nope.shape[-1]] = k_nope + k[..., k_nope.shape[-1] :] = k_pe + return k + + def _flash_attn_varlen_diff_headdims( + self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs + ): + output = self.flash_attn_varlen_func( + q=q, + k=k, + v=v, + softmax_scale=softmax_scale, + return_lse=return_softmax_lse, + **kwargs, + ) + + return output + + def _run_prefill_new_tokens_plugin_mode(self, prefill, q, k, v, return_softmax_lse): + return self._flash_attn_varlen_diff_headdims( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill.query_start_loc, + cu_seqlens_k=prefill.query_start_loc, + max_seqlen_q=prefill.max_query_len, + max_seqlen_k=prefill.max_query_len, + softmax_scale=self.scale, + causal=True, + return_softmax_lse=return_softmax_lse, + ) + + def _run_prefill_context_chunk_plugin_mode(self, prefill, chunk_idx, q, k, v): + assert prefill.chunked_context is not None + return self._flash_attn_varlen_diff_headdims( + q=q, + 
k=k, + v=v, + cu_seqlens_q=prefill.query_start_loc, + cu_seqlens_k=prefill.chunked_context.cu_seq_lens[chunk_idx], + max_seqlen_q=prefill.max_query_len, + max_seqlen_k=prefill.chunked_context.max_seq_lens[chunk_idx], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_softmax_lse=True, + ) + + def _context_parallel_compute_prefill_context_plugin_mode( + self, + q, + kv_c_and_k_pe_cache, + attn_metadata, + k_scale, + dcp_world_size, + ): + assert k_scale is None, "DCP not support scaled kvcache now." + assert attn_metadata.plugin_metadata.prefill is not None + prefill_metadata = attn_metadata.plugin_metadata.prefill + assert prefill_metadata.chunked_context is not None + assert prefill_metadata.chunked_context.padded_local_chunk_seq_lens is not None + assert prefill_metadata.chunked_context.local_context_lens_allranks is not None + assert prefill_metadata.chunked_context.padded_local_cu_seq_lens is not None + assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None + assert prefill_metadata.chunked_context.chunk_size is not None + + output = None + iters = len(prefill_metadata.chunked_context.seq_tot) + workspace = prefill_metadata.chunked_context.workspace + + from vllm import _custom_ops as ops + from vllm.distributed.parallel_state import get_dcp_group + from vllm.v1.attention.ops.merge_attn_states import merge_attn_states + + for i in range(iters): + toks = prefill_metadata.chunked_context.seq_tot[i] + ops.cp_gather_cache( + src_cache=kv_c_and_k_pe_cache, + dst=workspace, + block_table=prefill_metadata.block_table, + cu_seq_lens=prefill_metadata.chunked_context.padded_local_cu_seq_lens[ + i + ], + batch_size=attn_metadata.plugin_metadata.num_prefills, + seq_starts=prefill_metadata.chunked_context.starts[i], + ) + # workspace + # |------- N tokens --------|--------- N*dcp_size tokens ----------| + # |<- use for loca_gather ->|<--------- use for allgather -------->| + allgather_offset = workspace.shape[0] // (dcp_world_size + 1) + 
assert allgather_offset * (dcp_world_size + 1) == workspace.shape[0] + assert toks <= allgather_offset + local_gathered_kvcache = workspace[:toks] + cur_allgather_workspace = workspace[ + allgather_offset : allgather_offset * (1 + dcp_world_size) + ] + assert toks * dcp_world_size <= cur_allgather_workspace.shape[0] + cur_allgather_kvcache = cur_allgather_workspace[: toks * dcp_world_size] + cur_allgather_kvcache.copy_( + get_dcp_group().all_gather(local_gathered_kvcache, dim=0) + ) + assert ( + cur_allgather_kvcache.shape[-1] + == self.kv_lora_rank + self.qk_rope_head_dim + ) + allgatered_kv_c_normed, allgatered_k_pe = cur_allgather_kvcache.unsqueeze( + 1 + ).split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + kv_c_normed, k_pe = reorg_kvcache( + allgatered_kv_c_normed, + allgatered_k_pe, + padded_local_chunk_seq_lens_lst=prefill_metadata.chunked_context.padded_local_chunk_seq_lens[ + i + ], + local_context_lens_allranks=prefill_metadata.chunked_context.local_context_lens_allranks, + sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1], + max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i], + chunk_size=prefill_metadata.chunked_context.chunk_size, + chunk_idx=i, + toks=toks, + ) + + kv_nope = self.kv_b_proj(kv_c_normed).view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k = self._concat_k_nope_k_pe_plugin_mode(k_nope, k_pe) + + attn_output, attn_softmax_lse = self._run_prefill_context_chunk_plugin_mode( + prefill=prefill_metadata, + chunk_idx=i, + q=q, + k=k, + v=v, + ) + + if output is None: + output = attn_output + output_lse = attn_softmax_lse + else: + output_tmp = torch.empty_like(output) + output_lse_tmp = torch.empty_like(output_lse) + merge_attn_states( + output=output_tmp, + output_lse=output_lse_tmp, + prefix_output=output, + prefix_lse=output_lse, + suffix_output=attn_output, + suffix_lse=attn_softmax_lse, + ) + output = 
def _compute_prefill_context_plugin_mode(
    self,
    q,
    kv_c_and_k_pe_cache,
    attn_metadata,
    k_scale,
):
    """
    Attend the prefill queries to their already-cached context, one chunk at
    a time, and merge the per-chunk results.

    For every chunk listed in ``prefill.chunked_context`` the cached rows are
    gathered into the shared workspace, split into the latent part (first
    ``kv_lora_rank`` features) and the rope part (the rest), up-projected with
    ``kv_b_proj`` and passed to
    ``_run_prefill_context_chunk_plugin_mode``. Results of successive chunks
    are combined with ``merge_attn_states``.

    Args:
        q: prefill queries, forwarded unchanged to the chunk attention call.
        kv_c_and_k_pe_cache: cache the context rows are gathered from.
        attn_metadata: metadata whose ``plugin_metadata.prefill`` must carry a
            ``chunked_context``.
        k_scale: scale forwarded to ``gather_and_maybe_dequant_cache``.

    Returns:
        ``(output, output_lse)`` after merging all chunks.
    """
    assert attn_metadata.plugin_metadata.prefill is not None
    prefill_metadata = attn_metadata.plugin_metadata.prefill
    assert prefill_metadata.chunked_context is not None

    output = None
    iters = len(prefill_metadata.chunked_context.seq_tot)
    workspace = prefill_metadata.chunked_context.workspace

    # vLLM is imported lazily, only when this path actually runs.
    from vllm import _custom_ops as ops
    from vllm.v1.attention.ops.merge_attn_states import merge_attn_states

    for i in range(iters):
        toks = prefill_metadata.chunked_context.seq_tot[i]
        # Fill the workspace with this chunk's cached rows; the same workspace
        # is overwritten on every iteration.
        ops.gather_and_maybe_dequant_cache(
            src_cache=kv_c_and_k_pe_cache,
            dst=workspace,
            block_table=prefill_metadata.block_table,
            cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
            token_to_seq=prefill_metadata.chunked_context.token_to_seq[i],
            num_tokens=prefill_metadata.chunked_context.chunk_total_token[i],
            kv_cache_dtype=self.kv_cache_dtype,
            scale=k_scale,
            seq_starts=prefill_metadata.chunked_context.starts[i],
        )
        # Only the first `toks` workspace rows belong to this chunk.
        kv_c_normed = workspace[:toks][..., : self.kv_lora_rank]
        # Insert a singleton dim at position 1 so k_pe can broadcast against
        # the (-1, num_heads, ...) view of k_nope below.
        k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1)

        # The latent slice is a strided view of the workspace, hence the
        # explicit .contiguous() before the projection.
        kv_nope = self.kv_b_proj(kv_c_normed.contiguous()).view(
            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
        )
        k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        k = self._concat_k_nope_k_pe_plugin_mode(k_nope, k_pe)

        attn_output, attn_softmax_lse = self._run_prefill_context_chunk_plugin_mode(
            prefill=prefill_metadata,
            chunk_idx=i,
            q=q,
            k=k,
            v=v,
        )

        if output is None:
            # First chunk: nothing to merge with yet.
            output = attn_output
            output_lse = attn_softmax_lse
        else:
            # Merge into fresh buffers; the running result is passed as the
            # prefix and the new chunk as the suffix.
            output_tmp = torch.empty_like(output)
            output_lse_tmp = torch.empty_like(output_lse)
            merge_attn_states(
                output=output_tmp,
                output_lse=output_lse_tmp,
                prefix_output=output,
                prefix_lse=output_lse,
                suffix_output=attn_output,
                suffix_lse=attn_softmax_lse,
            )
            output = output_tmp
            output_lse = output_lse_tmp

    # NOTE(review): `output_lse` is only bound inside the loop, so an empty
    # `seq_tot` would raise UnboundLocalError here; presumably callers always
    # provide at least one chunk -- confirm against the metadata builder.
    return output, output_lse
quant_func( + kv_c_normed, + quant_dtype=dtypes.fp8, + scale=getattr(self.kv_b_proj, "input_scale", None), + ) + + k, v = fused_gemm_a8w8_blockscale_preshuffle_split_cat( + q_input, + weight_shuffled, + k_pe.expand((-1, self.num_heads, -1)), + x_scale, + weight_scale, + self.qk_nope_head_dim, + self.v_head_dim, + output_dtype, + ) + else: + kv_nope = self.kv_b_proj(kv_c_normed).view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + k_nope, v = kv_nope.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + else: + kv_nope = self.kv_b_proj(kv_c_normed).view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + # k = self._concat_k_nope_k_pe_plugin_mode(k_nope, k_pe) + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + output_prefill = self._run_prefill_new_tokens_plugin_mode( + prefill=attn_metadata.plugin_metadata.prefill, + q=q, + k=k, + v=v, + return_softmax_lse=has_context, + ) + + from vllm.v1.attention.ops.merge_attn_states import merge_attn_states + + if has_context: + suffix_output, suffix_lse = output_prefill + if self.dcp_world_size > 1: + context_output, context_lse = ( + self._context_parallel_compute_prefill_context_plugin_mode( + q, + kv_c_and_k_pe_cache, + attn_metadata, + k_scale=None, + dcp_world_size=self.dcp_world_size, + ) + ) + else: + context_output, context_lse = self._compute_prefill_context_plugin_mode( + q, kv_c_and_k_pe_cache, attn_metadata, k_scale + ) + + # unpad if necessary + if self._pad_v: + context_output = context_output[..., : v.shape[-1]] + suffix_output = suffix_output[..., : v.shape[-1]] + + output = output.view(-1, self.num_heads, self.v_head_dim) + merge_attn_states( + output=output, + prefix_output=context_output, + prefix_lse=context_lse, + suffix_output=suffix_output, + suffix_lse=suffix_lse, + ) + else: + 
def _forward_decode_plugin_mode(
    self,
    q,
    kv_c_and_k_pe_cache,
    attn_metadata,
    layer,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """
    Run the decode attention kernel (``mla_decode_fwd``) for the decode tokens.

    Args:
        q: decode queries as a single tensor; the first dimension is the
            number of decode tokens and the last one the per-head width.
        kv_c_and_k_pe_cache: non-empty KV cache.
        attn_metadata: metadata carrying ``plugin_metadata.decode`` plus the
            persistent-mode work/reduce buffers.
        layer: attention layer providing ``_q_scale`` and ``_k_scale``.

    Returns:
        ``(o, None)`` where ``o`` has shape
        ``(q.shape[0], num_heads, kv_lora_rank)`` and the dtype given by
        ``decode.attn_out_dtype``. No LSE is produced.
    """
    assert kv_c_and_k_pe_cache.numel() > 0
    decode_meta = attn_metadata.plugin_metadata.decode
    assert decode_meta is not None
    assert decode_meta.max_qo_len is not None

    # if type(q) is tuple:
    #     q = torch.cat(q, dim=-1)

    assert isinstance(q, torch.Tensor)
    num_tokens = q.shape[0]
    o = torch.empty(
        num_tokens,
        self.num_heads,
        self.kv_lora_rank,
        dtype=decode_meta.attn_out_dtype,
        device=q.device,
    )

    kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)

    persistent_arg_names = (
        "work_meta_data",
        "work_indptr",
        "work_info_set",
        "reduce_indptr",
        "reduce_final_map",
        "reduce_partial_map",
    )
    if self.dcp_world_size > 1 and self.kv_cache_dtype == "fp8":
        # DP : disable persistent mode to avoid overflow
        persistent_args = dict.fromkeys(persistent_arg_names)
    else:
        persistent_args = {
            name: getattr(attn_metadata, name) for name in persistent_arg_names
        }

    mla_decode_fwd(
        q,
        kv_buffer.view(-1, 1, 1, q.shape[-1]),
        o,
        decode_meta.qo_indptr,
        decode_meta.paged_kv_indptr,
        decode_meta.paged_kv_indices,
        decode_meta.paged_kv_last_page_len,
        decode_meta.max_qo_len,
        sm_scale=self.scale,
        q_scale=layer._q_scale,
        kv_scale=layer._k_scale,
        **persistent_args,
    )
    return o, None
+ k_c_normed = k_c_normed[:num_actual_toks, ...] + k_pe = k_pe[:num_actual_toks, ...] + + decode_q = q[:num_decode_tokens] + prefill_q = q[num_decode_tokens:] + prefill_k_pe = k_pe[num_decode_tokens:] + prefill_k_c_normed = k_c_normed[num_decode_tokens:] + + decode_only = has_decode and not has_prefill + if not decode_only: + if self.rotary_emb is not None: + self.rotary_emb(positions, q[..., self.qk_nope_head_dim :], k_pe) + + # write the latent and rope to kv cache + if kv_cache.numel() > 0: + aiter.concat_and_cache_mla( + k_c_normed, + k_pe.squeeze(1), + kv_cache, + attn_metadata.plugin_metadata.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + scale=layer._k_scale, + ) + + if fp8_attention: + kv_cache = kv_cache.view(current_platform.fp8_dtype()) + + if has_prefill: + self._forward_prefill_plugin_mode( + prefill_q, + prefill_k_c_normed, + prefill_k_pe, + kv_cache, + attn_metadata, + layer._k_scale, + output=output[num_decode_tokens:], + ) + + if has_decode: + assert attn_metadata.plugin_metadata.decode is not None + + decode_q_nope, decode_q_pe = decode_q.split( + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + # Convert from (B, N, P) to (N, B, P) + decode_q_nope = decode_q_nope.transpose(0, 1) + + if self.q_pad_num_heads is not None: + B, N, L = decode_q_pe.shape + decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L)) + decode_pe_padded.resize_((B, N, L)) + decode_pe_padded.copy_(decode_q_pe) + decode_q_pe = decode_pe_padded + + if self.is_aiter_triton_fp4_bmm_enabled: + decode_ql_nope = batched_gemm_a16wfp4( + decode_q_nope, + self.W_K, + self.W_K_scale, + transpose_bm=True, + prequant=True, + y_scale=layer._q_scale if fp8_attention else None, + ) + # elif self.is_aiter_triton_fp8_bmm_enabled: + else: + # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) + decode_ql_nope = _aiter_triton_fp8_bmm( + decode_q_nope, + self.W_K, + self.W_K_scale, + group_size=128, + transpose_bm=True, + ) + + if 
decode_only: + decode_q = torch.empty( + ( + decode_ql_nope.shape[0], + self.num_heads, + self.kv_lora_rank + self.qk_rope_head_dim, + ), + dtype=( + dtypes.fp8 + if self.kv_cache_dtype.startswith("fp8") + else self.dtype + ), + device=decode_ql_nope.device, + ) + aiter.fused_qk_rope_concat_and_cache_mla( + decode_ql_nope, + decode_q_pe, + k_c_normed, + k_pe.squeeze(1), + kv_cache.view( + kv_cache.shape[0], -1, self.kv_lora_rank + self.qk_rope_head_dim + ), + decode_q, + attn_metadata.plugin_metadata.slot_mapping, + layer._k_scale, + layer._q_scale, + positions, + self.rotary_emb.cos_cache, + self.rotary_emb.sin_cache, + is_neox=self.rotary_emb.is_neox_style, + is_nope_first=True, + ) + else: + if fp8_attention: + ql_nope_shape = decode_ql_nope.shape + q_pe_shape = decode_q_pe.shape + assert decode_ql_nope.shape[0] == decode_q_pe.shape[0] + assert decode_ql_nope.shape[1] == decode_q_pe.shape[1] + decode_q_shape = ( + ql_nope_shape[0], + ql_nope_shape[1], + ql_nope_shape[2] + q_pe_shape[2], + ) + # Using empty and copy since torch.cat introduces significant overhead. + decode_q0 = torch.empty( + decode_q_shape, + device=decode_ql_nope.device, + dtype=decode_ql_nope.dtype, + ) + decode_q0[..., : ql_nope_shape[2]].copy_(decode_ql_nope) + decode_q0[..., ql_nope_shape[2] :].copy_(decode_q_pe) + + decode_q, _ = ops.scaled_fp8_quant( + decode_q0.view(decode_q_shape[0], -1), + layer._q_scale, + ) + decode_q = decode_q.view(decode_q_shape) + else: + decode_q = (decode_ql_nope, decode_q_pe) + decode_q = torch.cat(decode_q, dim=-1) + if self.dcp_world_size > 1: + assert not fp8_attention, "DCP not support fp8 kvcache now." + # decode_q do allgather in head dim. + decode_q = get_dcp_group().all_gather(decode_q, dim=1) + + # call decode attn + attn_out, lse = self._forward_decode_plugin_mode( + decode_q, kv_cache, attn_metadata, layer + ) + + # correct dcp attn_out with lse. 
def MLAAttentionImplDecoratorForPluginMode(cls):
    """
    Class decorator that equips ``cls`` with the plugin-mode MLA methods.

    Each method listed below is copied from
    ``MLAAttentionImplPluginModeMethods`` onto ``cls``, and ``cls.__init__``
    is wrapped so that ``_mla_plugin_mode_init`` runs right after the original
    constructor with the same arguments.

    Args:
        cls: class to extend; it is modified in place.

    Returns:
        The same ``cls`` object.
    """
    import functools

    method_names = [
        "_concat_k_nope_k_pe_plugin_mode",
        "_flash_attn_varlen_diff_headdims",
        "_run_prefill_new_tokens_plugin_mode",
        "_run_prefill_context_chunk_plugin_mode",
        "_context_parallel_compute_prefill_context_plugin_mode",
        "_compute_prefill_context_plugin_mode",
        "_forward_prefill_plugin_mode",
        "_forward_decode_plugin_mode",
        "forward_impl_plugin_mode",
    ]

    logger.info("Use MLAAttentionImplDecoratorForPluginMode to decorate MLAAttention")

    # Add all methods to the target class
    for method_name in method_names:
        method = getattr(MLAAttentionImplPluginModeMethods, method_name)
        setattr(cls, method_name, method)

    # Wrap __init__ to inject plugin-mode initialization
    orig_init = cls.__init__

    # functools.wraps keeps the original constructor's name, docstring and
    # (through __wrapped__) its signature visible to introspection; without it
    # the decorated class would advertise a bare (*args, **kwargs) constructor.
    @functools.wraps(orig_init)
    def new_init(self, *args, **kwargs):
        orig_init(self, *args, **kwargs)
        _mla_plugin_mode_init(self, *args, **kwargs)

    cls.__init__ = new_init

    return cls
+ ) + return loaded_weights_record def compute_logits( self, diff --git a/atom/plugin/vllm/platform.py b/atom/plugin/vllm/platform.py index d003c4ae..7e3903c2 100644 --- a/atom/plugin/vllm/platform.py +++ b/atom/plugin/vllm/platform.py @@ -9,7 +9,6 @@ import os logger = logging.getLogger("atom") - # This flag is used to enable the vLLM plugin mode. disable_vllm_plugin = os.getenv("ATOM_DISABLE_VLLM_PLUGIN", "0").lower() == "1" disable_vllm_plugin_attention = ( @@ -31,6 +30,8 @@ def get_attn_backend_cls(cls, selected_backend, attn_selector_config) -> str: ) logger.info("Use atom attention backend") + if attn_selector_config.use_mla: + return "atom.model_ops.attentions.aiter_mla.AiterMLABackend" return "atom.model_ops.attentions.aiter_attention.AiterBackend" else: diff --git a/atom/plugin/vllm/register.py b/atom/plugin/vllm/register.py index 0330929a..38bd571b 100644 --- a/atom/plugin/vllm/register.py +++ b/atom/plugin/vllm/register.py @@ -22,6 +22,7 @@ _VLLM_MODEL_REGISTRY_OVERRIDES: dict[str, str] = { "Qwen3ForCausalLM": ATOM_CAUSAL_LM_MODEL_WRAPPER, "Qwen3MoeForCausalLM": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER, + "DeepseekV3ForCausalLM": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER, } @@ -43,13 +44,8 @@ def register_platform() -> Optional[str]: return "atom.plugin.vllm.platform.ATOMPlatform" -def _patch_vllm_attention_process_weights_after_loading() -> None: - try: - from vllm.attention.layer import Attention - except ImportError: - from vllm.model_executor.layers.attention import Attention - - orig = Attention.process_weights_after_loading +def _patch_vllm_attention_process_weights_after_loading(attention) -> None: + orig = attention.process_weights_after_loading if getattr(orig, "_atom_default_act_dtype_patched", False): return @@ -74,7 +70,7 @@ def wrapped(self, act_dtype: "torch.dtype" = torch.bfloat16): return orig(self, act_dtype) setattr(wrapped, "_atom_default_act_dtype_patched", True) - Attention.process_weights_after_loading = wrapped + 
attention.process_weights_after_loading = wrapped def register_model() -> None: @@ -107,4 +103,10 @@ def register_model() -> None: # patch attention process weights after loading # to avoid the specific handle in ATOM loader - _patch_vllm_attention_process_weights_after_loading() + try: + from vllm.attention.layer import Attention, MLAAttention + except ImportError: + from vllm.model_executor.layers.attention import Attention, MLAAttention + + _patch_vllm_attention_process_weights_after_loading(Attention) + _patch_vllm_attention_process_weights_after_loading(MLAAttention) diff --git a/atom/utils/backends.py b/atom/utils/backends.py index 57eec887..8058595c 100644 --- a/atom/utils/backends.py +++ b/atom/utils/backends.py @@ -555,9 +555,9 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: hash_content = [] for filepath in forward_code_files: hash_content.append(filepath) - if filepath == "": + if filepath == "" or filepath == "": # This means the function was dynamically generated, with - # e.g. exec(). We can't actually check these. + # e.g. exec() or frozen os module. We can't actually check these. continue with open(filepath) as f: hash_content.append(f.read())