From cae53bd32c2c72daa7edb4c5dcb01d039e06c05c Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Thu, 15 Jan 2026 15:31:07 +0800 Subject: [PATCH 01/37] [feat][plugin] Make ATOM work as plugin for upper framework Signed-off-by: zejunchen-zejun --- atom/__init__.py | 3 + atom/config.py | 39 +- atom/model_engine/model_runner.py | 6 +- atom/model_loader/loader.py | 63 +- atom/model_ops/__init__.py | 9 + atom/model_ops/attention_mha.py | 64 +- atom/model_ops/attention_mla.py | 4 +- atom/model_ops/attentions/aiter_attention.py | 34 +- atom/model_ops/base_attention.py | 85 +-- atom/model_ops/embed_head.py | 16 +- atom/model_ops/moe.py | 2 + atom/model_ops/paged_attention.py | 177 +++++ atom/model_ops/radix_attention.py | 106 +++ atom/models/deepseek_v2.py | 5 +- atom/models/gpt_oss.py | 5 +- atom/models/llama.py | 5 +- atom/models/mixtral.py | 5 +- atom/models/qwen3.py | 120 ++- atom/models/qwen3_moe.py | 108 +-- atom/plugin/__init__.py | 6 + atom/plugin/attention.py | 624 ++++++++++++++++ atom/plugin/attention_mha.py | 737 +++++++++++++++++++ atom/plugin/config.py | 234 ++++++ atom/plugin/moe.py | 60 ++ atom/plugin/prepare.py | 85 +++ atom/plugin/register.py | 106 +++ atom/utils/backends.py | 21 +- atom/utils/forward_context.py | 11 + recipes/Model-Impl-Backend.md | 168 +++++ 29 files changed, 2720 insertions(+), 188 deletions(-) create mode 100644 atom/model_ops/__init__.py create mode 100644 atom/model_ops/paged_attention.py create mode 100644 atom/model_ops/radix_attention.py create mode 100644 atom/plugin/__init__.py create mode 100644 atom/plugin/attention.py create mode 100644 atom/plugin/attention_mha.py create mode 100644 atom/plugin/config.py create mode 100644 atom/plugin/moe.py create mode 100644 atom/plugin/prepare.py create mode 100644 atom/plugin/register.py create mode 100644 recipes/Model-Impl-Backend.md diff --git a/atom/__init__.py b/atom/__init__.py index c1f9ed8b2..bcfab6841 100644 --- a/atom/__init__.py +++ b/atom/__init__.py @@ -3,3 +3,6 @@ from 
atom.model_engine.llm_engine import LLMEngine from atom.sampling_params import SamplingParams + +# interface for upper framework to constructe the model from ATOM +from atom.plugin import prepare_model diff --git a/atom/config.py b/atom/config.py index 677ae41c9..597f29146 100644 --- a/atom/config.py +++ b/atom/config.py @@ -17,6 +17,10 @@ from torch.distributed import ProcessGroup, ReduceOp from transformers import AutoConfig, GenerationConfig, PretrainedConfig +# only for plugin mode +from atom.plugin import is_plugin_mode +from atom.plugin.config import PluginConfig + logger = logging.getLogger("atom") @@ -584,6 +588,9 @@ class Config: torch_dtype: torch.dtype = field(init=False) speculative_config: Optional[SpeculativeConfig] = None + # only use for plugin mode + plugin_config: Optional[PluginConfig] = None + def _set_cudagraph_sizes(self): if self.compilation_config.cudagraph_capture_sizes: self.graph_bs = self.compilation_config.cudagraph_capture_sizes @@ -602,7 +609,10 @@ def __post_init__(self): self.kv_cache_block_size % 16 == 0 or self.kv_cache_block_size == 1 ), f"kv_cache_block_size ({self.kv_cache_block_size}) must be a multiple of 16 or 1" assert 1 <= self.tensor_parallel_size <= 8 - self.hf_config = get_hf_config(self.model) + if is_plugin_mode(): + self.hf_config = self.plugin_config.model_config.hf_config + else: + self.hf_config = get_hf_config(self.model) if not hasattr(self.hf_config, "rope_parameters"): # Compatible with both transformers < 5 rope_params = getattr(self.hf_config, "rope_scaling", {}) @@ -626,16 +636,23 @@ def __post_init__(self): self.max_model_len, hf_config_max_position_embeddings ) # assert self.max_num_batched_tokens >= self.max_model_len - if self.torch_profiler_dir is not None: - os.makedirs(self.torch_profiler_dir, exist_ok=True) - assert self.torch_profiler_dir is None or os.path.isdir( - self.torch_profiler_dir - ), f"torch_profiler_dir {self.torch_profiler_dir} is not a valid directory" - if 
self.compilation_config.level == CompilationLevel.PIECEWISE: - self.compilation_config.set_splitting_ops_for_v1() - self._set_cudagraph_sizes() - self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE - self.compilation_config.init_with_cudagraph_sizes() + if not is_plugin_mode(): + if self.torch_profiler_dir is not None: + os.makedirs(self.torch_profiler_dir, exist_ok=True) + assert self.torch_profiler_dir is None or os.path.isdir( + self.torch_profiler_dir + ), f"torch_profiler_dir {self.torch_profiler_dir} is not a valid directory" + + # only for server mode or plugin mode(vllm) + # for torch compile policy, plugin mode(vllm) uses the ATOM compile policy + # for cuda graph capture, plugin mode(vllm) uses the vLLM's cuda graph capture policy + if not is_plugin_mode() or self.plugin_config.is_vllm: + if self.compilation_config.level == CompilationLevel.PIECEWISE: + self.compilation_config.set_splitting_ops_for_v1() + self._set_cudagraph_sizes() + self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE + self.compilation_config.init_with_cudagraph_sizes() + self.torch_dtype = ( self.hf_config.torch_dtype if getattr(self.hf_config, "torch_dtype", None) is not None diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index 6861af392..455cde057 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -608,7 +608,7 @@ def __init__(self, rank: int, config: Config): self.drafter.load_model(self.model) torch.set_default_device(self.device) self.allocate_forward_vars() - self.attn_metadata_builder = self.attn_backend.get_builder_cls()(self) + self.attn_metadata_builder = self.attn_backend.get_builder_cls()(model_runner=self) self.physical_block_size = self.attn_metadata_builder.block_size self.forward_done_event = torch.cuda.Event() self.warmup_model() @@ -1171,7 +1171,7 @@ def prepare_inputs(self, batch: ScheduledBatch, input_ids: torch.Tensor = None): 
self.forward_vars["cu_seqlens_q"].np[scheduled_bs + 1 : bs + 1] = ( self.forward_vars["cu_seqlens_q"].np[scheduled_bs] ) - attn_metadata, positions = self.attn_metadata_builder.build(batch, bs) + attn_metadata, positions = self.attn_metadata_builder.build(batch=batch, bs=bs) context_bs = batch.total_seqs_num_prefill if is_prefill else scheduled_bs # graph_bs should be batch size (number of sequences), not token count @@ -1472,7 +1472,7 @@ def capture_cudagraph(self): ) attn_metadata, context = ( - self.attn_metadata_builder.build_for_cudagraph_capture(bs) + self.attn_metadata_builder.build_for_cudagraph_capture(bs=bs) ) num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) num_tokens += num_pad diff --git a/atom/model_loader/loader.py b/atom/model_loader/loader.py index 032b51385..7c6e4f314 100644 --- a/atom/model_loader/loader.py +++ b/atom/model_loader/loader.py @@ -29,6 +29,7 @@ get_spec_layer_idx_from_weight_name, rewrite_spec_layer_name, ) +from atom.plugin.prepare import is_vllm logger = logging.getLogger("atom") @@ -80,13 +81,58 @@ def safetensors_weights_iterator( yield name, f.get_tensor(name) +# when plugin mode, model loader method is bind to model implementation +# thus call this interface to load the model, which leverags the load_model +# method +def load_model_in_plugin_mode( + model, + config, + prefix: str = "", +) -> set[str] | None: + + # during loading model, the outplace operation may consume more + # GPU mem, which cached in torch caching allocator, here actively + # call empty cache to free the extra reserved but not used memory + def _empty_cache(): + import gc + gc.collect() + torch.cuda.empty_cache() + + assert config.plugin_config is not None and \ + config.plugin_config.is_plugin_mode, \ + "ATOM is not running in plugin mode" + if config.plugin_config.is_vllm: + model_name_or_path = config.plugin_config.model_config.model + elif config.plugin_config.is_sglang: + model_name_or_path = config.plugin_config.model_config.model_path 
+ + _empty_cache() + loaded_weights_record = load_model(model=model, + model_name_or_path=model_name_or_path, + hf_config=config.hf_config, + load_dummy=config.load_dummy, + spec_decode=False, + prefix=prefix, + is_plugin_mode=True, + act_dtype=config.plugin_config.model_config.dtype) + _empty_cache() + return loaded_weights_record + + def load_model( model: nn.Module, model_name_or_path: str, hf_config: AutoConfig, load_dummy: bool = False, spec_decode: bool = False, + prefix: str = "", + is_plugin_mode: bool = False, + act_dtype: torch.dtype = None, ): + # need to record the loaded weight name for vllm load check + # it is only used in plugin mode for vllm + loaded_weights_record = set[str]() + packed_modules_mapping = getattr(model, "packed_modules_mapping", {}) weights_mapping = getattr(model, "weights_mapping", {}) params_dict = dict(model.named_parameters()) @@ -145,6 +191,7 @@ def load_model( weight_loader, param, weight_tensor, shard_id ) ) + loaded_weights_record.add(prefix + param_name) break else: # Check if model has expert mapping before processing @@ -170,6 +217,7 @@ def load_model( expert_id, ) ) + loaded_weights_record.add(prefix + name) # weight_loader( # param, # weight_tensor, @@ -186,6 +234,7 @@ def load_model( futures.append( executor.submit(weight_loader, param, weight_tensor) ) + loaded_weights_record.add(prefix + name) # weight_loader(param, weight_tensor) else: # Model doesn't have expert mapping, use generic loading @@ -195,14 +244,26 @@ def load_model( ) # weight_loader(param, weight_tensor) futures.append(executor.submit(weight_loader, param, weight_tensor)) + loaded_weights_record.add(prefix + name) # Wait for all tasks to complete and raise any exceptions. 
for future in concurrent.futures.as_completed(futures): future.result() for _, module in model.named_modules(): if hasattr(module, "process_weights_after_loading"): - module.process_weights_after_loading() + if is_vllm(): + from vllm.attention.layer import Attention + # call vLLM attn weights post processing with act_dtype if using vLLM attention module + if isinstance(module, Attention): + module.process_weights_after_loading(act_dtype=act_dtype) + else: + module.process_weights_after_loading() + else: + module.process_weights_after_loading() quant_method = getattr(module, "quant_method", None) if isinstance(quant_method, QuantizeMethodBase): quant_method.process_weights_after_loading(module) if isinstance(quant_method, FusedMoEMethodBase): quant_method.init_prepare_finalize(module) + + if is_plugin_mode: + return loaded_weights_record diff --git a/atom/model_ops/__init__.py b/atom/model_ops/__init__.py new file mode 100644 index 000000000..9c06892ec --- /dev/null +++ b/atom/model_ops/__init__.py @@ -0,0 +1,9 @@ +from .paged_attention import PagedAttention +from .radix_attention import RadixAttention + +# this global class is used to construct the attention op in model +# it can be assigned to different attention op +# default PagedAttention is used as ATOM for now supports PagedAttention +# for sglang, RadixAttention will be assigned to ATTN_CLS +# TODO: add env flag or argument to swicth the attention class +ATTN_CLS = PagedAttention diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py index 31f38bb6a..e5be98244 100644 --- a/atom/model_ops/attention_mha.py +++ b/atom/model_ops/attention_mha.py @@ -15,20 +15,30 @@ from .attention_mla import MLAModules +from atom.plugin.prepare import is_plugin_mode, is_vllm +from atom.plugin.attention_mha import PagedAttentionImplDecoratorForPluginMode -class Attention(nn.Module): +@PagedAttentionImplDecoratorForPluginMode +class PagedAttentionImpl(nn.Module): + """ + Attention paged implementation + """ 
def __init__( self, num_heads, head_dim, scale, num_kv_heads, + alibi_slopes: list[float] | None, + sliding_window: Optional[int] = None, kv_cache_dtype="bf16", + logits_soft_cap: float | None = None, + attn_type = None, + kv_sharing_target_layer_name: int | None = None, layer_num=0, mla_modules: Optional[MLAModules] = None, sinks: Optional[nn.Parameter] = None, - sliding_window: Optional[int] = None, rotary_emb: Optional[torch.nn.Module] = None, q_norm: Optional[torch.nn.Module] = None, k_norm: Optional[torch.nn.Module] = None, @@ -37,12 +47,16 @@ def __init__( super().__init__() self.num_heads = num_heads self.head_dim = head_dim + # for upper framework, it uses head_size in built-in methods + self.head_size = head_dim self.scale = scale self.num_kv_heads = num_kv_heads + self.alibi_slopes = alibi_slopes self.k_cache = self.v_cache = torch.tensor([]) self.kv_cache_dtype = kv_cache_dtype self.max_model_len = 0 self.k_scale = self.v_scale = None + self.device = 'cuda:' + str(torch.cuda.current_device()) self.layer_num = layer_num self.kv_scale_float = ( torch.finfo(torch.float8_e4m3fn).max / torch.finfo(aiter.dtypes.fp8).max @@ -56,7 +70,14 @@ def __init__( self.q_norm = q_norm self.k_norm = k_norm - def forward( + # for plugin mode(vllm), the query quant is disabled for now + if is_vllm(): + self.supports_quant_query_input = False + + def process_weights_after_loading(self, act_dtype: torch.dtype = torch.bfloat16): + pass + + def forward_impl_server_mode( self, q: torch.Tensor, k: torch.Tensor, @@ -414,3 +435,40 @@ def dispatch_backend(self, fwd_ctx: ForwardContext): if atom_config.kv_cache_block_size == 1024: return self.paged_attention_persistent_asm return self.paged_attention_asm + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor = None, + attn_metadata = None, + position: torch.Tensor = None, + q_scale: Optional[torch.Tensor]=None, + qkv: torch.Tensor = None, + output: 
torch.Tensor = None, + **kwargs, + ): + if is_plugin_mode(): + # forward impl method are added by the decorator + # PagedAttentionImplDecoratorForPluginMode + return self.forward_impl_plugin_mode(layer=layer, + query=query, + key=key, + value=value, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + position=position, + q_scale=q_scale, + qkv=qkv) + else: + # only for server mode, keep the original method + o = self.forward_impl_server_mode(q=query, + k=key, + v=value, + position=position, + q_scale=q_scale, + qkv=qkv) + + return o diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index 6b2452cd1..30ebc2fd3 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -93,7 +93,7 @@ class MLAAttention(nn.Module): def __init__( self, num_heads: int, - head_size: int, + head_dim: int, scale: float, num_kv_heads: int, kv_cache_dtype: str, @@ -104,7 +104,7 @@ def __init__( ) -> None: super().__init__() self.num_heads = num_heads - self.head_size = head_size + self.head_dim = head_dim self.scale = float(scale) self.num_kv_heads = num_kv_heads self.kv_cache_dtype = kv_cache_dtype if kv_cache_dtype == "fp8" else "auto" diff --git a/atom/model_ops/attentions/aiter_attention.py b/atom/model_ops/attentions/aiter_attention.py index 0bcea5a6c..ef69c40dc 100644 --- a/atom/model_ops/attentions/aiter_attention.py +++ b/atom/model_ops/attentions/aiter_attention.py @@ -9,34 +9,54 @@ import torch from aiter.dist.parallel_state import get_tp_group from atom.model_engine.scheduler import ScheduledBatch -from atom.model_ops.attention_mha import Attention from atom.utils import CpuGpuBuffer +import atom.model_ops as ops +from atom.model_ops.paged_attention import PagedAttention +from atom.model_ops.attention_mha import PagedAttentionImpl +from atom.model_ops.radix_attention import RadixAttention from atom.utils.block_convert import block_table_convert_triton from atom.utils.forward_context import AttentionMetaData, Context from 
.backends import AttentionBackend, CommonAttentionBuilder +from atom.plugin.prepare import is_plugin_mode +from atom.plugin.attention import AiterAttentionMetadataBuilderDecoratorForPluginMode +from atom.plugin.attention import AiterBackendDecoratorForPluginMode +@AiterBackendDecoratorForPluginMode class AiterBackend(AttentionBackend): @staticmethod def get_name() -> str: - return "ROCM_AITER_ATTENTION" + return "ROCM_AITER_ATTENTION" if not is_plugin_mode() else "CUSTOM" @staticmethod def get_builder_cls() -> Type["AiterAttentionMetadataBuilder"]: return AiterAttentionMetadataBuilder @staticmethod - def get_impl_cls() -> Type["Attention"]: - return Attention + def get_impl_cls(): + if ops.ATTN_CLS == PagedAttention: + return PagedAttentionImpl + elif ops.ATTN_CLS == RadixAttention: + raise NotImplementedError("RadixAttention is not supported for now") -class AiterAttentionMetadataBuilder(CommonAttentionBuilder): +@AiterAttentionMetadataBuilderDecoratorForPluginMode(default_base_class=CommonAttentionBuilder) +class AiterAttentionMetadataBuilder: BLOCK_TABLE_EXTENDER: list[list[int]] = [[]] - def __init__(self, model_runner): + def __init__( + self, + kv_cache_spec = None, + layer_names = None, + config = None, + device = None, + model_runner = None, + ): self.block_size = 1024 if model_runner.block_size == 1024 else 16 - super().__init__(model_runner) + # Note: Cannot use super() here because the class is dynamically created by decorator + # Use explicit parent class call instead + CommonAttentionBuilder.__init__(self, model_runner) config = model_runner.config hf_config = config.hf_config self.num_attention_heads = ( diff --git a/atom/model_ops/base_attention.py b/atom/model_ops/base_attention.py index e12c1a386..544aa5fba 100644 --- a/atom/model_ops/base_attention.py +++ b/atom/model_ops/base_attention.py @@ -3,6 +3,7 @@ # from flash_attn import flash_attn_with_kvcache from typing import Optional +from abc import ABC, abstractmethod import torch from torch import 
nn @@ -51,10 +52,23 @@ def unified_attention_with_output_base( ) -> torch.Tensor: atom_config = get_current_atom_config() self = atom_config.compilation_config.static_forward_context[layer_name] - return self.impl.forward(q, k, v, positions, q_scale, qkv) + if use_mla: + return self.impl.forward(q, k, v, positions, q_scale, qkv) + else: + return self.impl.forward(layer=self, + query=q, + key=k, + value=v, + position=positions, + q_scale=q_scale, + qkv=qkv) +class BaseAttention(nn.Module, ABC): + """ + Abstract base class for attention -class Attention(nn.Module): + This class defines the interface that all attention implementations must follow + """ def __init__( self, @@ -70,68 +84,21 @@ def __init__( per_layer_sliding_window: Optional[int] = None, rotary_emb: Optional[torch.nn.Module] = None, prefix: Optional[str] = None, - q_norm: Optional[torch.nn.Module] = None, - k_norm: Optional[torch.nn.Module] = None, **kwargs, ): super().__init__() - self.num_heads = num_heads - self.head_dim = head_dim - self.scale = scale - self.num_kv_heads = num_kv_heads - self.k_cache = self.v_cache = torch.tensor([]) - self.kv_cache_dtype = kv_cache_dtype - self.max_model_len = 0 - self.k_scale = self.v_scale = None - self.layer_num = layer_num - self.mla_modules = mla_modules - self.use_mla = use_mla - self.base_attention = None - self.kv_cache = torch.tensor([]) - self.indexer = mla_modules.indexer if mla_modules is not None else None - self.sinks = sinks - - atom_config = get_current_atom_config() - dtype = atom_config.torch_dtype - block_size = atom_config.kv_cache_block_size - self.attn_backend = get_attn_backend( - block_size, - use_mla=self.use_mla, - ) - impl_cls = self.attn_backend.get_impl_cls() - self.impl = impl_cls( - num_heads, - head_dim, - scale, - num_kv_heads, - kv_cache_dtype, - layer_num, - mla_modules, - sinks=sinks, - sliding_window=per_layer_sliding_window, - rotary_emb=rotary_emb, - dtype=dtype, - q_norm=q_norm, - k_norm=k_norm, - ) - - compilation_config = 
atom_config.compilation_config - default_name = f"MLA_{layer_num}" if self.use_mla else f"MHA_{layer_num}" - self.layer_name = prefix if prefix is not None else default_name - if self.layer_name in compilation_config.static_forward_context: - raise ValueError("Duplicate layer: {}".format(self.layer_name)) - compilation_config.static_forward_context[self.layer_name] = self + @abstractmethod def forward( self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - positions: torch.Tensor = None, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + positions: Optional[torch.Tensor] = None, q_scale: Optional[torch.Tensor] = None, - qkv: torch.Tensor = None, - ): - output = torch.ops.aiter.unified_attention_with_output_base( - q, q_scale, k, v, positions, self.layer_name, self.use_mla, qkv + **kwargs, + ) -> torch.Tensor: + raise NotImplementedError( + f"{self.__class__.__name__} must implement the forward() method" ) - return output + diff --git a/atom/model_ops/embed_head.py b/atom/model_ops/embed_head.py index 77281d4f9..371c0324e 100644 --- a/atom/model_ops/embed_head.py +++ b/atom/model_ops/embed_head.py @@ -9,6 +9,7 @@ from torch import nn from atom.utils.forward_context import ForwardContext, get_forward_context +from atom.plugin import is_plugin_mode class VocabParallelEmbedding(nn.Module): @@ -68,13 +69,14 @@ def __init__( self.register_parameter("bias", None) def forward(self, x: torch.Tensor): - forward_context: ForwardContext = get_forward_context() - context = forward_context.context - attn_metadata = forward_context.attn_metadata - # context = get_context() - if context.is_prefill and not context.is_draft: - last_indices = attn_metadata.cu_seqlens_q[1:] - 1 - x = x[last_indices].contiguous() + if not is_plugin_mode(): + forward_context: ForwardContext = get_forward_context() + context = forward_context.context + attn_metadata = forward_context.attn_metadata + # context = get_context() + if context.is_prefill and not context.is_draft: + 
last_indices = attn_metadata.cu_seqlens_q[1:] - 1 + x = x[last_indices].contiguous() logits = tgemm.mm(x, self.weight, self.bias) if self.tp_size > 1: logits = tensor_model_parallel_all_gather(logits) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index b22fe317f..9e0710d00 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -47,6 +47,7 @@ ) from atom.utils.custom_register import direct_register_custom_op from atom.utils.forward_context import get_forward_context +from atom.plugin.moe import FusedMoEDecoratorForPluginMode class FusedMoeWeightScaleSupported(Enum): @@ -1757,6 +1758,7 @@ def moe_forward_fake( ) +@FusedMoEDecoratorForPluginMode class FusedMoE(torch.nn.Module): """FusedMoE layer for MoE models. diff --git a/atom/model_ops/paged_attention.py b/atom/model_ops/paged_attention.py new file mode 100644 index 000000000..75ca9b314 --- /dev/null +++ b/atom/model_ops/paged_attention.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +# from flash_attn import flash_attn_with_kvcache +from typing import Optional + +import torch +from torch import nn + +from .attention_mla import MLAModules +from .base_attention import BaseAttention +from atom.config import get_current_atom_config +from atom.utils.selector import get_attn_backend +from atom.plugin.prepare import is_sglang, is_vllm +from atom.plugin.attention import unified_attention_with_output_base_for_plugin_mode + +class PagedAttention(BaseAttention): + """ + Attention paged implementation + """ + def __init__( + self, + num_heads, + head_dim, + scale, + num_kv_heads, + alibi_slopes: list[float] = None, + kv_cache_dtype="bf16", + layer_num=0, + use_mla: bool = False, + mla_modules: Optional[MLAModules] = None, + sinks: Optional[nn.Parameter] = None, + per_layer_sliding_window: Optional[int] = None, + rotary_emb: Optional[torch.nn.Module] = None, + prefix: Optional[str] = None, + q_norm: Optional[torch.nn.Module] = None, + k_norm: Optional[torch.nn.Module] = None, + **kwargs, + ): + # plugin mode(sglang) is not support paged attention + # for now, only support plugin mode(vllm) and atom server mode + assert not is_sglang(), "PagedAttention is not supported for plugin mode(sglang) for now" + super().__init__(num_heads=num_heads, + head_dim=head_dim, + scale=scale, + num_kv_heads=num_kv_heads, + kv_cache_dtype=kv_cache_dtype, + layer_num=layer_num, + use_mla=use_mla, + mla_modules=mla_modules, + sinks=sinks, + per_layer_sliding_window=per_layer_sliding_window, + rotary_emb=rotary_emb, + prefix=prefix, + **kwargs) + + # for plugin mode + if is_vllm(): + self.use_mla = use_mla + from vllm.attention.layer import Attention, AttentionType + + atom_config = get_current_atom_config() + assert atom_config is not None, "atom_config is required for plugin mode to vllm" + + # use vllm cache config and quant config to follow the convention of vllm + cache_config = atom_config.plugin_config.vllm_cache_config + quant_config = 
atom_config.plugin_config.vllm_quant_config + + # add exter impl args, which are needed to be passed to the impl class + # while it only works for custom attention backend for vllm + extra_impl_args = {} + if atom_config.plugin_config.vllm_use_custom_attention: + extra_impl_args['sinks'] = sinks + extra_impl_args['rotary_emb'] = rotary_emb + extra_impl_args['q_norm'] = q_norm + extra_impl_args['k_norm'] = k_norm + + self.attn = Attention( + num_heads=num_heads, + head_size=head_dim, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=alibi_slopes, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=None, + per_layer_sliding_window=per_layer_sliding_window, + prefix=f"{prefix}", + attn_type=AttentionType.DECODER, + kv_sharing_target_layer_name=None, + **extra_impl_args, + ) + + compilation_config = atom_config.compilation_config + self.layer_name = prefix + if self.layer_name in compilation_config.static_forward_context: + raise ValueError("Duplicate layer: {}".format(self.layer_name)) + compilation_config.static_forward_context[self.layer_name] = self + return + + self.num_heads = num_heads + self.head_dim = head_dim + self.scale = scale + self.num_kv_heads = num_kv_heads + self.k_cache = self.v_cache = torch.tensor([]) + self.kv_cache_dtype = kv_cache_dtype + self.max_model_len = 0 + self.k_scale = self.v_scale = None + self.layer_num = layer_num + self.mla_modules = mla_modules + self.use_mla = use_mla + self.base_attention = None + self.kv_cache = torch.tensor([]) + self.indexer = mla_modules.indexer if mla_modules is not None else None + self.sinks = sinks + + atom_config = get_current_atom_config() + dtype = atom_config.torch_dtype + block_size = atom_config.kv_cache_block_size + self.attn_backend = get_attn_backend( + block_size, + use_mla=self.use_mla, + ) + impl_cls = self.attn_backend.get_impl_cls() + self.impl = impl_cls( + num_heads=num_heads, + head_dim=head_dim, + scale=scale, + num_kv_heads=num_kv_heads, + 
alibi_slopes=alibi_slopes, + kv_cache_dtype=kv_cache_dtype, + layer_num=layer_num, + mla_modules=mla_modules, + sinks=sinks, + sliding_window=per_layer_sliding_window, + rotary_emb=rotary_emb, + dtype=dtype, + q_norm=q_norm, + k_norm=k_norm, + **kwargs, + ) + + compilation_config = atom_config.compilation_config + default_name = f"MLA_{layer_num}" if self.use_mla else f"MHA_{layer_num}" + self.layer_name = prefix if prefix is not None else default_name + if self.layer_name in compilation_config.static_forward_context: + raise ValueError("Duplicate layer: {}".format(self.layer_name)) + compilation_config.static_forward_context[self.layer_name] = self + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + positions: torch.Tensor = None, + q_scale: Optional[torch.Tensor]=None, + qkv: torch.Tensor = None, + **kwargs, + ): + if is_vllm(): + output = unified_attention_with_output_base_for_plugin_mode( + query, + q_scale, + key, + value, + positions, + layer_name=self.layer_name, + use_mla=self.use_mla, + qkv=qkv, + ) + return output + + # for atom server mode + output = torch.ops.aiter.unified_attention_with_output_base( + query, q_scale, key, value, positions, self.layer_name, self.use_mla, qkv + ) + return output diff --git a/atom/model_ops/radix_attention.py b/atom/model_ops/radix_attention.py new file mode 100644 index 000000000..9f892e797 --- /dev/null +++ b/atom/model_ops/radix_attention.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +import aiter +import torch +from torch import nn +from typing import Optional + +from .attention_mla import MLAModules +from .base_attention import BaseAttention +from atom.plugin.prepare import is_plugin_mode, is_sglang +from atom.models.utils import maybe_prefix + + +class RadixAttention(BaseAttention): + """ + Attention radix implementation + """ + def __init__( + self, + num_heads, + head_dim, + scale, + num_kv_heads, + kv_cache_dtype="bf16", + layer_num=0, + use_mla: bool = False, + mla_modules: Optional[MLAModules] = None, + sinks: Optional[nn.Parameter] = None, + per_layer_sliding_window: Optional[int] = None, + rotary_emb: Optional[torch.nn.Module] = None, + prefix: Optional[str] = None, + **kwargs, + ): + super().__init__(num_heads=num_heads, + head_dim=head_dim, + scale=scale, + num_kv_heads=num_kv_heads, + kv_cache_dtype=kv_cache_dtype, + layer_num=layer_num, + use_mla=use_mla, + mla_modules=mla_modules, + sinks=sinks, + per_layer_sliding_window=per_layer_sliding_window, + rotary_emb=rotary_emb, + prefix=prefix, + **kwargs) + + if is_sglang(): + from sglang.srt.layers.radix_attention import RadixAttention + self.attn = RadixAttention( + num_heads=num_heads, + head_dim=head_dim, + scaling=scale, + num_kv_heads=num_kv_heads, + layer_id=layer_num, + prefix=maybe_prefix(prefix, "attn"), + ) + else: + raise NotImplementedError("RadixAttention is only supported for plugin mode for sglang for now") + + def forward_impl_plugin_mode( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata = None, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, + position: torch.Tensor = None, + q_scale: torch.Tensor=None, + **kwargs, + ): + if is_sglang(): + # for sglang, forward_batch is required + forward_batch = kwargs.get("forward_batch", None) + assert forward_batch is not None, "forward_batch is required for sglang" + return self.attn(q=query, + k=key, + 
v=value, + forward_batch=forward_batch) + else: + raise NotImplementedError("RadixAttention is only supported \ + for plugin mode for sglang for now") + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + position: torch.Tensor = None, + q_scale: Optional[torch.Tensor]=None, + **kwargs, + ): + if is_plugin_mode(): + o = self.forward_impl_plugin_mode(query=query, + key=key, + value=value, + position=position, + q_scale=q_scale, + **kwargs) + else: + raise NotImplementedError("RadixAttention is not supported for server \ + mode for now") + return o diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index 4a88fb863..f39504195 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -59,7 +59,7 @@ ) from atom.model_ops.activation import SiluAndMul from atom.model_ops.attention_mla import MLAModules, is_rocm_aiter_fp4bmm_enabled -from atom.model_ops.base_attention import Attention +import atom.model_ops as ops from atom.model_ops.embed_head import ParallelLMHead, VocabParallelEmbedding from atom.model_ops.layernorm import LayerNorm, RMSNorm from atom.model_ops.linear import ( @@ -1397,11 +1397,12 @@ def __init__( indexer=self.indexer, ) - self.mla_attn = Attention( + self.mla_attn = ops.ATTN_CLS( num_heads=self.num_local_heads, head_dim=self.kv_lora_rank + self.qk_rope_head_dim, scale=self.scaling, num_kv_heads=1, + alibi_slopes=None, kv_cache_dtype=cache_config, layer_num=layer_num, use_mla=True, diff --git a/atom/models/gpt_oss.py b/atom/models/gpt_oss.py index eee08a710..3114802ea 100644 --- a/atom/models/gpt_oss.py +++ b/atom/models/gpt_oss.py @@ -28,7 +28,7 @@ # from vllm.model_executor.layers.logits_processor import LogitsProcessor from aiter.rotary_embedding import get_rope from atom.config import Config, QuantizationConfig -from atom.model_ops.base_attention import Attention +import atom.model_ops as ops from atom.model_ops.embed_head import ParallelLMHead, VocabParallelEmbedding # from 
vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig @@ -121,11 +121,12 @@ def __init__( # Only apply sliding window to every other layer sliding_window = config.sliding_window if self.layer_idx % 2 == 0 else None - self.attn = Attention( + self.attn = ops.ATTN_CLS( self.num_local_attention_heads, self.head_dim, self.scaling, num_kv_heads=self.num_local_key_value_heads, + alibi_slopes=None, kv_cache_dtype=cache_config, quant_config=quant_config, per_layer_sliding_window=sliding_window, diff --git a/atom/models/llama.py b/atom/models/llama.py index 4f485de69..7d2a95c4a 100644 --- a/atom/models/llama.py +++ b/atom/models/llama.py @@ -34,7 +34,7 @@ from atom.model_ops.activation import SiluAndMul # from atom.model_ops.attention import Attention -from atom.model_ops.base_attention import Attention +import atom.model_ops as ops from atom.model_ops.embed_head import ParallelLMHead, VocabParallelEmbedding from atom.model_ops.layernorm import RMSNorm from atom.model_ops.linear import ( @@ -190,11 +190,12 @@ def __init__( if is_sliding: sliding_window = config.sliding_window - self.attn = Attention( + self.attn = ops.ATTN_CLS( self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, + alibi_slopes=None, kv_cache_dtype=cache_config, layer_num=layer_num, per_layer_sliding_window=sliding_window, diff --git a/atom/models/mixtral.py b/atom/models/mixtral.py index e9f4b66a7..ff7e2159e 100644 --- a/atom/models/mixtral.py +++ b/atom/models/mixtral.py @@ -32,7 +32,7 @@ from atom.config import Config, QuantizationConfig # from atom.model_ops.attention import Attention -from atom.model_ops.base_attention import Attention +import atom.model_ops as ops from atom.model_ops.embed_head import ParallelLMHead, VocabParallelEmbedding from atom.model_ops.layernorm import RMSNorm from atom.model_ops.linear import QKVParallelLinear, ReplicatedLinear, RowParallelLinear @@ -170,11 +170,12 @@ def __init__( base=int(self.rope_theta), is_neox_style=True, ) - 
self.attn = Attention( + self.attn = ops.ATTN_CLS( self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, + alibi_slopes=None, kv_cache_dtype=cache_config, layer_num=layer_num, quant_config=quant_config, diff --git a/atom/models/qwen3.py b/atom/models/qwen3.py index ccd4f7648..517bd9604 100644 --- a/atom/models/qwen3.py +++ b/atom/models/qwen3.py @@ -24,7 +24,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, Any, Iterable import torch @@ -35,7 +35,7 @@ from atom.model_ops.activation import SiluAndMul # from atom.model_ops.attention import Attention -from atom.model_ops.base_attention import Attention +import atom.model_ops as ops from atom.model_ops.embed_head import ParallelLMHead, VocabParallelEmbedding from atom.model_ops.layernorm import RMSNorm from atom.model_ops.linear import ( @@ -47,6 +47,9 @@ from torch import nn from transformers import Qwen3Config +from atom.model_loader.loader import load_model_in_plugin_mode +from atom.models.utils import maybe_prefix + class Qwen3Attention(nn.Module): @@ -63,11 +66,11 @@ def __init__( rope_scaling: tuple | None = None, kv_cache_dtype: str = "fp16", layer_num: int = 0, - quant_config: Optional[QuantizationConfig] = None, + atom_config: Config = None, + prefix: str = "", ) -> None: super().__init__() tp_size = get_tp_group().world_size - self.layer_num = layer_num self.total_num_heads = num_heads assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size @@ -85,13 +88,15 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=qkv_bias, - quant_config=quant_config, + quant_config=atom_config.quant_config, + prefix=f"{prefix}.qkv_proj", ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - quant_config=quant_config, + quant_config=atom_config.quant_config, + prefix=f"{prefix}.o_proj", ) 
self.rotary_emb = get_rope( self.head_dim, @@ -100,15 +105,18 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - self.num_kv_heads, + self.attn = ops.ATTN_CLS( + num_heads=self.num_heads, + head_dim=self.head_dim, + scale=self.scaling, + num_kv_heads=self.num_kv_heads, + alibi_slopes=None, kv_cache_dtype=kv_cache_dtype, layer_num=layer_num, use_mla=False, rotary_emb=self.rotary_emb, + config=atom_config, + prefix=f"{prefix}.attn", ) self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) @@ -117,6 +125,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, + **model_kwargs: dict[str, Any] | None, ) -> torch.Tensor: qkv = self.qkv_proj(hidden_states) q, k, v = torch.split(qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -124,7 +133,7 @@ def forward( q = self.q_norm(q) k = self.k_norm(k) - o = self.attn(q, k, v, positions) + o = self.attn(q, k, v, positions, **model_kwargs) output = self.o_proj(o) return output @@ -136,7 +145,8 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config = None, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -144,12 +154,14 @@ def __init__( [intermediate_size] * 2, bias=False, quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", ) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, bias=False, quant_config=quant_config, + prefix=f"{prefix}.down_proj", ) assert hidden_act == "silu" self.act_fn = SiluAndMul() @@ -166,11 +178,12 @@ class Qwen3DecoderLayer(nn.Module): def __init__( self, config: Qwen3Config, - cache_config: str = "bf16", - quant_config: Optional[QuantizationConfig] = None, + atom_config: Config, layer_num: int = 0, + prefix: str = "", ) -> None: super().__init__() + kv_cache_dtype = 
atom_config.kv_cache_dtype self.layer_num = layer_num rope_params = config.rope_parameters self.self_attn = Qwen3Attention( @@ -183,15 +196,17 @@ def __init__( head_dim=getattr(config, "head_dim", None), rope_theta=rope_params["rope_theta"], rope_scaling=rope_params, - kv_cache_dtype=cache_config, + kv_cache_dtype=kv_cache_dtype, layer_num=layer_num, - quant_config=quant_config, + atom_config=atom_config, + prefix=f"{prefix}.self_attn", ) self.mlp = Qwen3MLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - quant_config=quant_config, + quant_config=atom_config.quant_config, + prefix=f"{prefix}.mlp", ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( @@ -203,51 +218,63 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: torch.Tensor | None, + **model_kwargs: dict[str, Any] | None, ) -> tuple[torch.Tensor, torch.Tensor]: if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm(hidden_states, residual) - hidden_states = self.self_attn(positions, hidden_states) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states, + **model_kwargs) hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual -@support_torch_compile +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + } +) class Qwen3Model(nn.Module): - def __init__(self, atom_config: Config) -> None: + def __init__(self, *, atom_config: Config, prefix: str = "") -> None: super().__init__() - config = atom_config.hf_config - cache_config = atom_config.kv_cache_dtype - quant_config = atom_config.quant_config + hf_config = atom_config.hf_config + self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, 
config.hidden_size + hf_config.vocab_size, hf_config.hidden_size ) self.layers = nn.ModuleList( [ Qwen3DecoderLayer( - config, - cache_config=cache_config, - quant_config=quant_config, + config=hf_config, + atom_config=atom_config, layer_num=layer_num, + prefix=f"{prefix}.layers.{layer_num}", ) - for layer_num in range(config.num_hidden_layers) + for layer_num in range(hf_config.num_hidden_layers) ] ) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = RMSNorm(hf_config.hidden_size, eps=hf_config.rms_norm_eps) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, + **model_kwargs: dict[str, Any], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None for layer in self.layers: - hidden_states, residual = layer(positions, hidden_states, residual) + hidden_states, residual = layer(positions=positions, + hidden_states=hidden_states, + residual=residual, + **model_kwargs) + hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -261,20 +288,32 @@ class Qwen3ForCausalLM(nn.Module): "up_proj": ("gate_up_proj", 1), } - def __init__(self, atom_config: Config) -> None: + def __init__(self, config: Any, prefix: str = "") -> None: super().__init__() - config = atom_config.hf_config - self.model = Qwen3Model(atom_config) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - if config.tie_word_embeddings: + self.atom_config = config + self.hf_config = self.atom_config.hf_config + self.model = Qwen3Model( + atom_config=self.atom_config, prefix=maybe_prefix(prefix, "model") + ) + + self.lm_head = ParallelLMHead(num_embeddings=self.hf_config.vocab_size, + embedding_dim=self.hf_config.hidden_size, + bias=False, + prefix=maybe_prefix(prefix, "lm_head")) + if self.hf_config.tie_word_embeddings: self.lm_head.weight.data = self.model.embed_tokens.weight.data def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, + intermediate_tensors = None, + inputs_embeds: 
torch.Tensor | None = None, + **model_kwargs: dict[str, Any], ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions) + hidden_states = self.model(input_ids=input_ids, + positions=positions, + **model_kwargs) return hidden_states def compute_logits( @@ -283,3 +322,10 @@ def compute_logits( ) -> torch.Tensor: logits = self.lm_head(hidden_states) return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # load weights in plugin mode and discard passed weights generator + loaded_weights_record = load_model_in_plugin_mode(model=self, + config=self.atom_config, + prefix="model.") + return loaded_weights_record diff --git a/atom/models/qwen3_moe.py b/atom/models/qwen3_moe.py index 57a89f1e5..f509d2e51 100644 --- a/atom/models/qwen3_moe.py +++ b/atom/models/qwen3_moe.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, Tuple, Union, Any, Iterable import torch from aiter.dist.communication_op import tensor_model_parallel_all_reduce @@ -9,8 +9,7 @@ from atom.config import Config, QuantizationConfig from atom.model_ops.activation import SiluAndMul -# from atom.model_ops.attention import Attention -from atom.model_ops.base_attention import Attention +import atom.model_ops as ops from atom.model_ops.embed_head import ParallelLMHead, VocabParallelEmbedding from atom.model_ops.layernorm import RMSNorm from atom.model_ops.linear import ( @@ -30,6 +29,7 @@ from atom.utils import envs from atom.utils.decorators import support_torch_compile from torch import nn +from atom.model_loader.loader import load_model_in_plugin_mode # import torch.distributed as dist from transformers import PretrainedConfig, Qwen3Config @@ -149,7 +149,8 @@ def __init__( rope_scaling: tuple | None = None, kv_cache_dtype: str = "fp16", layer_num: int = 0, - quant_config: Optional[QuantizationConfig] = None, + atom_config: Config = None, + prefix: str = "", ) -> None: super().__init__() tp_size = 
get_tensor_model_parallel_world_size() @@ -177,14 +178,14 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=qkv_bias, - quant_config=quant_config, + quant_config=atom_config.quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - quant_config=quant_config, + quant_config=atom_config.quant_config, reduce_results=not ENABLE_ALLREDUCE_RMSNORM_FUSION, ) @@ -207,7 +208,7 @@ def __init__( self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) - self.attn = Attention( + self.attn = ops.ATTN_CLS( self.num_heads, self.head_dim, self.scaling, @@ -216,6 +217,7 @@ def __init__( layer_num=layer_num, use_mla=False, rotary_emb=self.rotary_emb, + prefix=f"{prefix}.attn", q_norm=self.q_norm if ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION else None, k_norm=self.k_norm if ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION else None, ) @@ -227,17 +229,21 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, + **model_kwargs: dict[str, Any] | None, ) -> torch.Tensor: qkv = self.qkv_proj(hidden_states) q, k, v = torch.split(qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1) if ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION: - attn_output = self.attn(q, k, v, positions, None, qkv) + q, k, v = torch.split( + qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1 + ) + attn_output = self.attn(query=q, key=k, value=v, positions=positions, q_scale=None, qkv=qkv) else: # Add qk-norm q = self.q_norm(q) k = self.k_norm(k) - attn_output = self.attn(q, k, v, positions) + attn_output = self.attn(query=q, key=k, value=v, **model_kwargs) output = self.o_proj(attn_output) return output @@ -245,18 +251,19 @@ def forward( class Qwen3MoeDecoderLayer(nn.Module): def __init__( self, - config: Qwen3Config, - prefix: str, - cache_config: str = "bf16", - quant_config: Optional[QuantizationConfig] = None, + atom_config = None, layer_num: int = 0, + prefix: str = 
"" ) -> None: super().__init__() + self.atom_config = atom_config + config = self.atom_config.hf_config self.hidden_size = config.hidden_size rope_params = config.rope_parameters rope_theta = rope_params["rope_theta"] rope_scaling = rope_params + kv_cache_dtype = atom_config.kv_cache_dtype max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # DecoderLayers are created with `make_layers` which passes the prefix # with the layer's index. @@ -272,9 +279,10 @@ def __init__( head_dim=getattr(config, "head_dim", None), rope_theta=rope_theta, rope_scaling=rope_scaling, - kv_cache_dtype=cache_config, + kv_cache_dtype=kv_cache_dtype, layer_num=layer_num, - quant_config=quant_config, + atom_config=atom_config, + prefix=f"{prefix}.self_attn", ) # `mlp_only_layers` in the config. @@ -286,14 +294,14 @@ def __init__( and (self.layer_idx + 1) % config.decoder_sparse_step == 0 ): self.mlp = Qwen3MoeSparseMoeBlock( - config, quant_config=quant_config, prefix=f"{prefix}.mlp" + config, quant_config=self.atom_config.quant_config, prefix=f"{prefix}.mlp" ) else: self.mlp = Qwen3MoeMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - quant_config=quant_config, + quant_config=self.atom_config.quant_config, reduce_results=not ENABLE_ALLREDUCE_RMSNORM_FUSION, prefix=f"{prefix}.mlp", ) @@ -313,6 +321,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: torch.Tensor | None, + **model_kwargs: dict[str, Any] | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: @@ -323,6 +332,7 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, + **model_kwargs, ) # Fully Connected @@ -341,42 +351,37 @@ def __init__( ): super().__init__() - config = atom_config.hf_config - cache_config = atom_config.kv_cache_dtype - quant_config = atom_config.quant_config - self.config = config + self.config = atom_config.hf_config if 
get_pp_group().is_first_rank: self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, + self.config.vocab_size, + self.config.hidden_size, ) else: self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, + self.config.num_hidden_layers, lambda prefix, layer_num=None: Qwen3MoeDecoderLayer( - config, - prefix, - cache_config=cache_config, - quant_config=quant_config, + atom_config=atom_config, layer_num=layer_num, + prefix=prefix, ), prefix=f"{prefix}.layers", layer_num_offset=0, ) if get_pp_group().is_last_rank: self.norm = RMSNorm( - config.hidden_size, - eps=config.rms_norm_eps, - fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION, + self.config.hidden_size, + eps=self.config.rms_norm_eps, + fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION ) else: self.norm = PPMissingLayer() self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size + ["hidden_states", "residual"], self.config.hidden_size ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -388,6 +393,7 @@ def forward( positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, + **model_kwargs: dict[str, Any] | None, ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -401,7 +407,7 @@ def forward( residual = intermediate_tensors["residual"] for layer in self.layers[self.start_layer : self.end_layer]: - hidden_states, residual = layer(positions, hidden_states, residual) + hidden_states, residual = layer(positions, hidden_states, residual, **model_kwargs) if not get_pp_group().is_last_rank: return IntermediateTensors( @@ -438,26 +444,22 @@ def __init__( layer_type: type[nn.Module] = Qwen3MoeDecoderLayer, ): super().__init__() - config = atom_config.hf_config - 
quant_config = atom_config.quant_config - self.config = config - self.quant_config = quant_config + self.atom_config = atom_config + self.config = self.atom_config.hf_config + # Only perform the following mapping when Qwen3MoeMLP exists - if getattr(config, "mlp_only_layers", []): + if getattr(self.config, "mlp_only_layers", []): self.packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"] self.model = Qwen3MoeModel( - atom_config=atom_config, + atom_config=self.atom_config, prefix=maybe_prefix(prefix, "model"), - layer_type=layer_type, ) if get_pp_group().is_last_rank: - self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - prefix=maybe_prefix(prefix, "lm_head"), - ) + self.lm_head = ParallelLMHead(num_embeddings=self.config.vocab_size, + embedding_dim=self.config.hidden_size, + bias=False, + prefix=maybe_prefix(prefix, "lm_head")) else: self.lm_head = PPMissingLayer() if self.config.tie_word_embeddings: @@ -476,9 +478,14 @@ def forward( positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, + **model_kwargs: dict[str, Any] | None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model( - input_ids, positions, intermediate_tensors, inputs_embeds + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, ) return hidden_states @@ -505,3 +512,10 @@ def make_empty_intermediate_tensors( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # load weights in plugin mode and discard passed weights generator + loaded_weights_record = load_model_in_plugin_mode(model=self, + config=self.atom_config, + prefix="model.") + return loaded_weights_record diff --git a/atom/plugin/__init__.py 
b/atom/plugin/__init__.py new file mode 100644 index 000000000..fc6671481 --- /dev/null +++ b/atom/plugin/__init__.py @@ -0,0 +1,6 @@ +from .prepare import ( + prepare_model, + is_sglang, + is_vllm, + is_plugin_mode, +) diff --git a/atom/plugin/attention.py b/atom/plugin/attention.py new file mode 100644 index 000000000..617e9d062 --- /dev/null +++ b/atom/plugin/attention.py @@ -0,0 +1,624 @@ +from typing import Any, Type, Optional +import logging + +from dataclasses import dataclass + +import torch + +from atom.plugin.prepare import is_vllm, is_sglang +from atom.utils import CpuGpuBuffer +from atom.utils.forward_context import Context, AttentionMetaData +from atom.model_ops.attention_mha import PagedAttentionImpl + +logger = logging.getLogger("atom") + +_PARTITION_SIZE_ROCM = 256 +_CP_TOKENS_PER_ITER_ROCM = 32 * 1024 + +@dataclass +class AiterFlashAttentionDecodeMetadata: + max_query_len: int + min_query_len: int + max_seq_len: int + query_start_loc: torch.Tensor + + +@dataclass +class AiterFlashAttentionPrefillMetadata: + max_query_len: int + min_query_len: int + max_seq_len: int + query_start_loc: torch.Tensor + + +@dataclass +class AiterChunkSlidingWindowMetadata: + swa_seqlens: torch.Tensor + swa_cu_seqlens: torch.Tensor + swa_seq_starts: torch.Tensor + swa_token_to_batch: torch.Tensor + swa_max_seqlens: int + swa_total_tokens: int + swa_workspace: torch.Tensor + + +@dataclass +class AiterChunkContextMetadata: + workspace: torch.Tensor + cu_seq_lens_chunk: torch.Tensor + chunk_starts: torch.Tensor + token_to_batch: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + seq_lens: torch.Tensor + num_chunks: int + total_token_per_batch: list[int] + swa_metadata: AiterChunkSlidingWindowMetadata | None + + +@dataclass +class AiterFlashAttentionChunkPrefillMetadata: + max_query_len: int + min_query_len: int + max_seq_len: int + query_start_loc: torch.Tensor + chunk_context_metadata: AiterChunkContextMetadata + + +@dataclass +class MetadataForPluginMode: + # 
NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + num_actual_tokens: int # Number of tokens excluding padding. + num_actual_kv_tokens: int + max_query_len: int + query_start_loc: torch.Tensor + max_seq_len: int + seq_lens: torch.Tensor + slot_mapping: torch.Tensor + block_table: torch.Tensor + + # prefill and deocde split + num_decodes: int + num_decode_tokens: int + num_prefills: int + num_prefill_tokens: int + num_extends: int + num_extend_tokens: int + + decode_metadata: AiterFlashAttentionDecodeMetadata | None + prefill_metadata: AiterFlashAttentionPrefillMetadata | None + extend_metadata: AiterFlashAttentionChunkPrefillMetadata | None + + # For cascade attention. + use_cascade: bool + common_prefix_len: int + total_tokens: int + + context: Optional[Context] = None + + # # Only for fp8 shuffle layout kv cache, we allocate kv_scale for each layer + # # since we might integrate per token quant for kv cache in the future. + # k_scale: dict[str, torch.Tensor] | None + # v_scale: dict[str, torch.Tensor] | None + +class vllmAiterBackendMethods: + # here attention in ATOM doesn't accept the output buffer because + # ATOM works as a model impl backend, it needs the maximum freedom + # to decide the output buffer and shape, thus here use this flag to + # let vllm don't allocate the output buffer for ATOM. ATOM will + # handle the output buffer by itself + accept_output_buffer: bool = False + supported_dtypes: list = [torch.float16, torch.bfloat16] + + def __init__(self): + raise TypeError( + f"{self.__class__.__name__} is a utility class and should not be instantiated. " + "Its methods are meant to be added to other classes via decorators." 
+ ) + + @staticmethod + def get_supported_kernel_block_sizes(): + from vllm.v1.attention.backend import MultipleOf + return [MultipleOf(16)] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @classmethod + def is_mla(cls) -> bool: + return False + + @staticmethod + def get_required_kv_cache_layout(): + return None + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [64, 128, 256] + + @classmethod + def full_cls_name(cls) -> tuple[str, str]: + return (cls.__module__, cls.__qualname__) + + @classmethod + def supports_alibi_sqrt(cls) -> bool: + return False + + +def AiterBackendDecoratorForPluginMode(cls): + ''' + Decorator for AiterBackend to add specific methods and attributes for plugin mode + ''' + is_vllm_mode = is_vllm() + + if is_vllm_mode: + # for vllm, add the required methods + cls.full_cls_name = vllmAiterBackendMethods.full_cls_name + cls.accept_output_buffer = vllmAiterBackendMethods.accept_output_buffer + cls.supported_dtypes = vllmAiterBackendMethods.supported_dtypes + cls.get_supported_kernel_block_sizes = vllmAiterBackendMethods.get_supported_kernel_block_sizes + cls.get_kv_cache_shape = vllmAiterBackendMethods.get_kv_cache_shape + cls.is_mla = vllmAiterBackendMethods.is_mla + cls.get_required_kv_cache_layout = vllmAiterBackendMethods.get_required_kv_cache_layout + cls.get_supported_head_sizes = vllmAiterBackendMethods.get_supported_head_sizes + cls.supports_alibi_sqrt = vllmAiterBackendMethods.supports_alibi_sqrt + return cls + + +def create_attn_metadata_builder_init_method(base_class): + ''' + Create the init method for metadata builder + ''' + def init_method_under_plugin_mode(self, + kv_cache_spec=None, + layer_names=None, + config=None, + 
device=None, + model_runner=None): + base_class.__init__(self, + kv_cache_spec, + layer_names, + config, + device) + logger.info(f"init AiterAttentionMetadataBuilder for plugin mode") + from vllm.config import VllmConfig,get_layers_from_vllm_config + from vllm.attention.layer import Attention + + assert isinstance(config, VllmConfig) + + self.vllm_config = config + self.model_config = config.model_config + self.parallel_config = config.parallel_config + self.cache_config = config.cache_config + + self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config) + self.head_dim = self.model_config.get_head_size() + self.block_size = kv_cache_spec.block_size + + self.aot_sliding_window: tuple[int, int] | None = None + self.total_tokens: int = 0 + + self.scheduler_config = config.scheduler_config + self.block_ratio = 1 + + sliding_window_sizes: set[tuple[int, int] | None] = set() + layers = get_layers_from_vllm_config(config, Attention) + for layer in layers.values(): + assert isinstance(layer.impl, PagedAttentionImpl) + sliding_window_sizes.add((layer.impl.sliding_window - 1, 0)) + + while len(sliding_window_sizes) > 0: + sliding_window_config = sliding_window_sizes.pop() + if sliding_window_config is not None and sliding_window_config[0] != -1: + assert self.aot_sliding_window is None, ( + "Aiter Backend only support one valid sliding window" + ) + self.aot_sliding_window = sliding_window_config + + # for extend path to store the fetched key and value + # here buffer used for extend path is not calculated by vLLM and SGLang + # when profile_run, it is possible to exhaust the GPU memory when + # gpu_mem_utilization is much higher + self.extend_workspace = torch.empty( + [2, _CP_TOKENS_PER_ITER_ROCM, self.num_heads_kv, self.head_dim], + dtype=self.model_config.dtype, + device=device, + ) + + # used for ROPE + max_num_batched_tokens = config.scheduler_config.max_num_batched_tokens + i64_kwargs = {"dtype": torch.int64, "device": device} + self.positions = 
CpuGpuBuffer(max_num_batched_tokens, **i64_kwargs) + + return init_method_under_plugin_mode + + +def setup_attn_metadata_builder_base_class_and_attributes(class_dict: dict): + ''' + Setup the base class and attributes for attention metadata builder + ''' + from vllm.v1.attention.backend import ( + AttentionCGSupport, + AttentionMetadataBuilder, + ) + + base_class = AttentionMetadataBuilder + generic_base = AttentionMetadataBuilder + needs_generic = True + + # align with vllm rocm aiter fa + class_dict['_cudagraph_support'] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + class_dict['reorder_batch_threshold'] = 1 + + return base_class, generic_base, needs_generic, class_dict + + +class vllmAttentionMetadataBuilderMethods: + def __init__(self): + raise TypeError( + f"{self.__class__.__name__} is a utility class and should not be instantiated. " + "Its methods are meant to be added to other classes via decorators." + ) + + def build( + self, + common_prefix_len: int = 0, + common_attn_metadata = None, + fast_build: bool = False, + ): + if common_prefix_len > 0: + raise ValueError("ATOM does not support cascade attention yet") + + from vllm.v1.attention.backends.utils import split_decodes_prefills_and_extends + + # here assume the decode num token is 1 per request + split_ret = split_decodes_prefills_and_extends( + common_attn_metadata=common_attn_metadata, + decode_threshold=1 + ) + + ( + num_decodes, + num_extends, + num_prefills, + num_decode_tokens, + num_extend_tokens, + num_prefill_tokens, + ) = split_ret + + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + seq_lens = common_attn_metadata.seq_lens.cpu() + query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + + # used to store the positions of each tokens of each request + # for computing ROPE + positions = [] + + decode_metadata = None + if num_decodes > 0: + decode_metadata = AiterFlashAttentionDecodeMetadata( + max_query_len=query_lens_cpu[:num_decodes].max().item(), + 
min_query_len=query_lens_cpu[:num_decodes].min().item(), + max_seq_len=seq_lens[:num_decodes].max().item(), + query_start_loc=common_attn_metadata.query_start_loc[: num_decodes + 1], + ) + for seq_len in seq_lens[:num_decodes]: + positions.append(seq_len - 1) + + extend_metadata = None + if num_extends > 0: + num_extends_slice = slice(num_decodes, num_decodes + num_extends) + query_lens_for_extend = query_lens_cpu[num_extends_slice] + seq_lens_for_extend = seq_lens[num_extends_slice] + computed_kv_lens = seq_lens_for_extend - query_lens_for_extend + swa_metadata = None + if self.aot_sliding_window is not None: + swa_seqlen_for_extend = torch.minimum( + seq_lens_for_extend, + query_lens_for_extend + self.aot_sliding_window[0] + 1, + ) + cu_seq_lens = torch.zeros( + num_extends + 1, + dtype=torch.int32, + device=seq_lens_for_extend.device, + ) + torch.cumsum( + swa_seqlen_for_extend, + dim=0, + dtype=cu_seq_lens.dtype, + out=cu_seq_lens[1:], + ) + token_to_seq = torch.arange( + 0, + num_extends, + dtype=torch.int32, + device=seq_lens_for_extend.device, + ) + token_to_seq = torch.repeat_interleave( + token_to_seq, swa_seqlen_for_extend + ) + fetched_shape = cu_seq_lens[-1].item() + swa_workspace = torch.empty( + (2, fetched_shape, self.num_heads_kv, self.head_dim), + dtype=self.vllm_config.model_config.dtype, + device=self.device, + ) + + seq_starts = seq_lens_for_extend - swa_seqlen_for_extend + max_seqlen_k = swa_seqlen_for_extend.max().item() + total_tokens = cu_seq_lens[-1].item() + + swa_metadata = AiterChunkSlidingWindowMetadata( + swa_seqlens=swa_seqlen_for_extend.to( + self.device, non_blocking=True + ), + swa_cu_seqlens=cu_seq_lens.to(self.device, non_blocking=True), + swa_seq_starts=seq_starts.to(self.device, non_blocking=True), + swa_token_to_batch=token_to_seq.to(self.device, non_blocking=True), + swa_max_seqlens=max_seqlen_k, + swa_total_tokens=total_tokens, + swa_workspace=swa_workspace, + ) + + # allocate the equal amount of workspace for + # each chunk 
prefill request + max_context_chunk = _CP_TOKENS_PER_ITER_ROCM // num_extends + from vllm.utils.math_utils import cdiv + num_chunks = cdiv(computed_kv_lens.max().item(), max_context_chunk) + + chunk_starts = ( + torch.arange(num_chunks, dtype=torch.int32) + .unsqueeze(1) + .expand(-1, num_extends) + * max_context_chunk + ) + chunk_ends = torch.min( + computed_kv_lens.unsqueeze(0), chunk_starts + max_context_chunk + ) + chunk_seq_lens = (chunk_ends - chunk_starts).clamp( + min=0 + ) # [num_chunks, num_extends] + cu_seq_lens_cpu = torch.zeros( + [num_chunks, num_extends + 1], dtype=torch.int32, pin_memory=True + ) + torch.cumsum( + chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32 + ) + max_cum_tokens = cu_seq_lens_cpu[:, -1].max().item() + + # Build token->batch mapping robustly, even with zero-length batches. + token_to_batch_tensor = torch.zeros( + (num_chunks, max_cum_tokens), dtype=torch.int32, pin_memory=True + ) + batch_ids = torch.arange(num_extends, dtype=torch.int32) + for chunk_idx in range(num_chunks): + total_tokens = cu_seq_lens_cpu[chunk_idx, -1].item() + if total_tokens == 0: + continue + token_to_batch = torch.repeat_interleave( + batch_ids, chunk_seq_lens[chunk_idx].to(torch.int64) + ) + token_to_batch_tensor[chunk_idx, :total_tokens] = token_to_batch + + chunk_context_metadata = AiterChunkContextMetadata( + workspace=self.extend_workspace, + cu_seq_lens_chunk=cu_seq_lens_cpu.to(self.device, non_blocking=True), + chunk_starts=chunk_starts.to(self.device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, + token_to_batch=token_to_batch_tensor.to(self.device, non_blocking=True), + num_chunks=num_chunks, + total_token_per_batch=cu_seq_lens_cpu[:, -1].tolist(), + swa_metadata=swa_metadata, + ) + + query_start_loc_device = common_attn_metadata.query_start_loc[ + num_decodes : num_decodes + num_extends + 1 + ] + seq_lens_device = 
common_attn_metadata.seq_lens[num_extends_slice] + cu_seq_lens = torch.zeros( + num_extends + 1, dtype=torch.int32, device=seq_lens_device.device + ) + torch.cumsum( + seq_lens_device, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:] + ) + extend_metadata = AiterFlashAttentionChunkPrefillMetadata( + max_query_len=query_lens_for_extend.max().item(), + min_query_len=query_lens_for_extend.min().item(), + max_seq_len=seq_lens[num_extends_slice].max().item(), + query_start_loc=query_start_loc_device - query_start_loc_device[0], + chunk_context_metadata=chunk_context_metadata, + ) + + for idx in range(num_extends): + extend_start_seq_len = seq_lens_for_extend[idx] - query_lens_for_extend[idx] + extend_end_seq_len = seq_lens_for_extend[idx] + for pos in range(extend_start_seq_len, extend_end_seq_len): + positions.append(pos) + + prefill_metadata = None + if num_prefills > 0: + query_lens_for_prefill = query_lens_cpu[num_decodes + num_extends :] + query_start_loc_device = common_attn_metadata.query_start_loc[ + num_decodes + num_extends : + ] + prefill_metadata = AiterFlashAttentionPrefillMetadata( + max_query_len=query_lens_for_prefill.max().item(), + min_query_len=query_lens_for_prefill.min().item(), + max_seq_len=seq_lens[num_decodes + num_extends :].max().item(), + query_start_loc=query_start_loc_device - query_start_loc_device[0], + ) + for prefill_seq_len in seq_lens[num_decodes + num_extends :]: + for pos in range(prefill_seq_len): + positions.append(pos) + + num_actual_kv_tokens = torch.sum(seq_lens).item() + + use_cascade = common_prefix_len > 0 + + context_batch_size = 0 + has_prefill = bool(num_prefills > 0 or num_extends > 0) + if has_prefill: + context_batch_size = num_prefills + num_extends + else: + context_batch_size = num_decodes + context_graph_bs = context_batch_size + + num_actual_tokens = common_attn_metadata.num_actual_tokens + self.positions.np[:num_actual_tokens] = positions + context=Context( + 
positions=self.positions.copy_to_gpu(num_actual_tokens), + is_prefill=has_prefill, + batch_size=context_batch_size, + graph_bs=context_graph_bs, + ) + + attn_metadata_for_plugin_mode = MetadataForPluginMode( + num_actual_tokens=num_actual_tokens, + num_actual_kv_tokens=num_actual_kv_tokens, + max_query_len=common_attn_metadata.max_query_len, + query_start_loc=common_attn_metadata.query_start_loc, + max_seq_len=common_attn_metadata.max_seq_len, + seq_lens=common_attn_metadata.seq_lens, + block_table=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + num_extends=num_extends, + num_extend_tokens=num_extend_tokens, + decode_metadata=decode_metadata, + prefill_metadata=prefill_metadata, + extend_metadata=extend_metadata, + use_cascade=use_cascade, + common_prefix_len=common_prefix_len, + total_tokens=self.total_tokens, + context=context, + ) + + attn_metadata = AttentionMetaData( + max_seqlen_q=common_attn_metadata.max_query_len, + block_tables=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping, + plugin_metadata=attn_metadata_for_plugin_mode, + ) + + # TODO: set the forward context + # set_forward_context( + # attn_metadata=attn_metadata, + # atom_config=self.config, + # context=context, + # num_tokens=num_tokens, + # num_tokens_across_dp=num_tokens_across_dp, + # ) + + return attn_metadata + + # this method will be called by vllm, so it follows the vllm's interface convention + def build_for_cudagraph_capture( + self, + common_attn_metadata = None, + ): + self.total_tokens = ( + self.model_config.max_model_len + * self.vllm_config.scheduler_config.max_num_partial_prefills + ) + attn_metadata = self.build(common_prefix_len=0, common_attn_metadata=common_attn_metadata) + self.total_tokens = 0 + return attn_metadata + + +def 
AiterAttentionMetadataBuilderDecoratorForPluginMode(default_base_class): + def decorator(cls): + is_vllm_mode = is_vllm() + is_sglang_mode = is_sglang() + + base_class = default_base_class + class_dict = {} + + for key, value in cls.__dict__.items(): + if not key.startswith('__') or key in ('__annotations__',): + class_dict[key] = value + + # handle the generic base class + needs_generic = False + generic_base = None + + if is_vllm_mode: + # get the base class and generic base class + base_class, generic_base, needs_generic, class_dict = \ + setup_attn_metadata_builder_base_class_and_attributes(class_dict) + + # replace the __init__ method to the decorated class + class_dict['__init__'] = create_attn_metadata_builder_init_method(base_class) + + # add the methods to the decorated class + for method_name in dir(vllmAttentionMetadataBuilderMethods): + if not method_name.startswith('_'): + method = getattr(vllmAttentionMetadataBuilderMethods, method_name) + if callable(method): + class_dict[method_name] = method + elif is_sglang_mode: + raise NotImplementedError("AttentionMetadataBuilder for sglang is not implemented yet") + + # create the new class + new_class = type(cls.__name__, (base_class,), class_dict) + + # replace the inherit base class for plugin mode, meanwhile support generic base class + if needs_generic and generic_base is not None: + new_class.__orig_bases__ = (generic_base[new_class],) + + return new_class + + return decorator + + +# here not register it as a custom op and mark split because vllm +# will register attention impl forward as a custom op, so here +# avoid duplicated registration, and the split op is registered +# into the atom support_torch_compile decorator +def unified_attention_with_output_base_for_plugin_mode( + q: torch.Tensor, + q_scale: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + positions: torch.Tensor, + layer_name: str, + use_mla: bool, + qkv: torch.Tensor, +) -> torch.Tensor: + from atom.config import 
get_current_atom_config + from atom.utils import envs + atom_config = get_current_atom_config() + if use_mla: + raise NotImplementedError("MLA is not supported for plugin mode for now") + else: + self = atom_config.compilation_config.static_forward_context[layer_name] + # here is the standard vllm attention impl interface + # when using fusion, we need to pass the qkv and positions through the q,k,v + # [watch out] accept_output_buffer must be False for plugin mode + # because we don't want vllm to manipulate the q k v and output buffer + # ATOM needs to handle all of the buffer here + if envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION: + output = self.attn(q, positions, qkv) + else: + output = self.attn(q, k, v) + return output diff --git a/atom/plugin/attention_mha.py b/atom/plugin/attention_mha.py new file mode 100644 index 000000000..5d61292a5 --- /dev/null +++ b/atom/plugin/attention_mha.py @@ -0,0 +1,737 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +""" +Plugin mode extensions for PagedAttentionImpl. +This module provides additional methods for PagedAttentionImpl when running in plugin mode. +""" + +import torch +import aiter +from aiter import dtypes, fused_qk_norm_rope_cache_quant_shuffle +from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache +from aiter.ops.triton.gluon.pa_decode_gluon import get_recommended_splits +from typing import TYPE_CHECKING + +import logging +logger = logging.getLogger("atom") + +if TYPE_CHECKING: + from atom.utils.forward_context import AttentionMetaData + +from atom.utils import envs + +ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION = envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION + +_PARTITION_SIZE_ROCM = 256 +_CP_TOKENS_PER_ITER_ROCM = 32 * 1024 + + +class PagedAttentionImplPluginModeMethods: + """ + Container class for plugin mode methods. 
+ This class cannot be instantiated - it only serves as a namespace for methods + that will be added to PagedAttentionImpl via decorator. + """ + + def __init__(self): + raise TypeError( + "PagedAttentionImplPluginModeMethods cannot be instantiated. " + "It is only used as a method container for the decorator." + ) + + def rope_cache_plugin_mode(self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + qkv: torch.Tensor, + position: torch.Tensor, + attention_metadata: "AttentionMetaData", + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + flash_layout: bool = False): + + num_blocks, block_size, num_kv_heads, head_size = k_cache.shape + + if not flash_layout: + x = 16 // k_cache.element_size() + k_cache_template = torch.empty( + [num_blocks, num_kv_heads, head_size // x, block_size, x], + dtype=k_cache.dtype, + device="meta", + ) + # ATOM: [num_blocks, num_kv_heads, head_size, block_size], + # vLLM: [num_blocks, num_kv_heads, block_size // x, head_size, x], + v_cache_template = torch.empty( + [num_blocks, num_kv_heads, block_size // x, head_size, x], + dtype=v_cache.dtype, + device="meta", + ) + new_key_cache = k_cache.view_as(k_cache_template) + new_value_cache = v_cache.view_as(v_cache_template) + else: + new_key_cache = k_cache + new_value_cache = v_cache + + # if flash kv_cache layout, the shape of kv_cache is: + # + # key_cache: [num_blocks, block_size, num_kv_heads, head_size] + # value_cache: [num_blocks, num_kv_heads, head_size, block_size] + # + # if not, the shape is: + # + # key_cache: [num_blocks, num_kv_heads, head_size // x, block_size, x] + # value_cache: [num_blocks, num_kv_heads, head_size, block_size] + # + # and the origin kv cache layout in fwd_args is not flash + + attn_metadata = attention_metadata + + use_triton_attn = self.sliding_window != -1 or self.head_dim != 128 + self.use_triton_attn = use_triton_attn + + if ( + self.rotary_emb is not None + and self.q_norm is not None + and 
self.k_norm is not None + ): + fused_qk_norm_rope_cache_quant_shuffle( + qkv, + num_heads_q=self.num_heads, + num_heads_k=self.num_kv_heads, + num_heads_v=self.num_kv_heads, + head_dim=self.head_dim, + eps=self.q_norm.eps, + qw=self.q_norm.weight, + kw=self.k_norm.weight, + cos_sin_cache=self.rotary_emb.cos_sin_cache, + is_neox_style=self.rotary_emb.is_neox_style, + pos_ids=position, + k_cache=new_key_cache, + v_cache=new_value_cache, + slot_mapping=attn_metadata.slot_mapping, + kv_cache_dtype=( + "auto" if self.kv_cache_dtype == "bf16" else self.kv_cache_dtype + ), + k_scale=k_scale, + v_scale=v_scale, + ) + + qkv = qkv.view(qkv.shape[0], -1, self.head_dim) + q, k, v = qkv.split( + [self.num_heads, self.num_kv_heads, self.num_kv_heads], dim=1) + elif use_triton_attn and self.rotary_emb is not None: + k_scale = v_scale = self.kv_scale + + q, k, k_cache, v_cache = fused_qk_rope_reshape_and_cache( + q, + k, + v, + new_key_cache, + new_value_cache, + attn_metadata.slot_mapping, + position, + self.rotary_emb.cos_cache, + self.rotary_emb.sin_cache, + k_scale, + v_scale, + self.rotary_emb.is_neox_style, + flash_layout=flash_layout, + apply_scale=self.kv_cache_dtype.startswith("fp8"), + offs=None, + q_out=q, + k_out=k, + output_zeros=False, + ) + else: + # for asm paged attention + if self.rotary_emb is not None: + assert position is not None + q, k = self.rotary_emb(position, q, k) + if self.q_norm is not None: + q = self.q_norm(q) + if self.k_norm is not None: + k = self.k_norm(k) + if self.kv_cache_dtype == "fp8": + aiter.reshape_and_cache_with_pertoken_quant( + k, + v, + new_key_cache, + new_value_cache, + k_scale, + v_scale, + attn_metadata.slot_mapping, + asm_layout=True, + ) + else: + aiter.reshape_and_cache( + k, + v, + new_key_cache, + new_value_cache, + attn_metadata.slot_mapping, + kv_cache_dtype="auto", + k_scale=None, + v_scale=None, + asm_layout=True, + ) + + return q, k, v, k_cache, v_cache, k_scale, v_scale + + def _get_cp_mha_gather_cache_views( + self, 
key_cache: torch.Tensor, value_cache: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, int]: + # For SHUFFLE layout, the wrapper derives PAGE_SIZE/num_heads from + # tensor shapes; provide a reshape-only view to keep storage unchanged. + if key_cache.ndim == 5: + num_blocks = key_cache.shape[0] + num_heads = key_cache.shape[1] + page_size = key_cache.shape[3] + x = key_cache.shape[4] + head_size = key_cache.shape[2] * x + key_cache = key_cache.view(num_blocks, page_size, num_heads, head_size) + value_cache = value_cache.view(num_blocks, page_size, num_heads, head_size) + return key_cache, value_cache, page_size + return key_cache, value_cache, key_cache.shape[1] + + def paged_attention_triton_plugin_mode(self, + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + out: torch.Tensor, + attn_metadata: "AttentionMetaData"): + + o = out + num_seqs, num_q_heads_total, head_size = q.shape + num_blocks, num_kv_heads, _, block_size, _ = k_cache.shape + query_group_size = num_q_heads_total // num_kv_heads + assert num_q_heads_total % num_kv_heads == 0 + + max_context_partition_num = get_recommended_splits(num_seqs, num_kv_heads) + + context_partition_size = 256 + if self.sliding_window > 0: + max_context_partition_num = 1 + context_partition_size = 128 + + # Output buffers (same as Triton) + intermediate_shape = ( + num_seqs, + num_kv_heads, + max_context_partition_num, + query_group_size, + ) + exp_sums = torch.empty( + intermediate_shape, dtype=torch.float32, device=q.device + ) + max_logits = torch.empty( + intermediate_shape, dtype=torch.float32, device=q.device + ) + temporary_output = torch.empty( + *intermediate_shape, + head_size, + dtype=q.dtype, + device=q.device, + ) + + per_tensor = False + if k_scale is not None: + per_tensor = k_scale.numel() == 1 + if not per_tensor: + k_scale = k_scale.unsqueeze(-1) + v_scale = v_scale.unsqueeze(-1) + compute_type = torch.bfloat16 if self.kv_cache_dtype 
== "bf16" or per_tensor else aiter.dtypes.fp8 + + torch.ops.aiter.pa_decode_gluon( + o, + q, + k_cache, + v_cache, + attn_metadata.plugin_metadata.seq_lens, + attn_metadata.block_tables, + self.scale, + 1, # query_lenth + max_context_partition_num, + context_partition_size, + compute_type, + None, + None if self.kv_cache_dtype == "bf16" else k_scale, + None if self.kv_cache_dtype == "bf16" else v_scale, + exp_sums=exp_sums, + max_logits=max_logits, + temporary_output=temporary_output, + alibi_slopes=None, + sinks=self.sinks, + sliding_window=self.sliding_window, + ps=True, + ) + + return o + + def paged_attention_asm_plugin_mode(self, + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + num_decodes: int, + num_decode_tokens: int, + attn_metadata: "AttentionMetaData", + out: torch.Tensor): + aiter.pa_fwd_asm( + Q=q, + K=k_cache, + V=v_cache, + block_tables=attn_metadata.plugin_metadata.block_table[:num_decodes], + context_lens=attn_metadata.plugin_metadata.seq_lens[:num_decodes], + block_tables_stride0=attn_metadata.plugin_metadata.block_table[ + :num_decodes + ].stride(0), + K_QScale=k_scale, + V_QScale=v_scale, + out_=out[:num_decode_tokens], + high_precision=0, + ) + + return + + def extend_for_sliding_window( + self, + attn_metadata: "AttentionMetaData", + query: torch.Tensor, + key_cache, + value_cache, + output: torch.Tensor, + cu_seqlens_q: torch.Tensor, + max_seqlen_q: int, + block_table: torch.Tensor, + k_scale: float, + v_scale: float, + ): + assert attn_metadata.plugin_metadata.extend_metadata is not None + assert attn_metadata.plugin_metadata.extend_metadata.chunk_context_metadata is not None + chunked_metadata = attn_metadata.plugin_metadata.extend_metadata.chunk_context_metadata + swa_metadata = chunked_metadata.swa_metadata + assert swa_metadata is not None + swa_cu_seqlens = swa_metadata.swa_cu_seqlens + swa_seq_starts = swa_metadata.swa_seq_starts + swa_token_to_batch = 
swa_metadata.swa_token_to_batch + swa_max_seqlens = swa_metadata.swa_max_seqlens + swa_total_tokens = swa_metadata.swa_total_tokens + key_fetched, value_fetched = ( + swa_metadata.swa_workspace[0], + swa_metadata.swa_workspace[1], + ) + + from vllm.v1.attention.backends.rocm_aiter_fa import cp_mha_gather_cache + # key_cache_for_gather, value_cache_for_gather, _ = ( + # self._get_cp_mha_gather_cache_views(key_cache, value_cache) + # ) + cp_mha_gather_cache( + key_cache=key_cache, + value_cache=value_cache, + key=key_fetched, + value=value_fetched, + block_tables=block_table, + k_scales=k_scale, + v_scales=v_scale, + cu_seqlens_kv=swa_cu_seqlens, + token_to_batch=swa_token_to_batch, + seq_starts=swa_seq_starts, + dequant=self.kv_cache_dtype.startswith("fp8"), + kv_cache_layout="NHD", + total_tokens=swa_total_tokens, + ) + + sliding_window = (self.sliding_window, 0, 0) if self.sliding_window is not None else (-1, -1, 0) + aiter.flash_attn_varlen_func( + q=query, + k=key_fetched, + v=value_fetched, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=swa_cu_seqlens, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=swa_max_seqlens, + min_seqlen_q=1, + dropout_p=0.0, + softmax_scale=self.scale, + causal=True, + window_size=sliding_window, + alibi_slopes=self.alibi_slopes, + return_lse=False, + out=output, + ) + + def extend_forward( + self, + attn_metadata: "AttentionMetaData", + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + output: torch.Tensor, + cu_seqlens_q: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + min_seqlen_q: int, + block_table: torch.Tensor, + slot_mapping: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + ): + from vllm.v1.attention.ops.merge_attn_states import merge_attn_states + from vllm.v1.attention.backends.rocm_aiter_fa import cp_mha_gather_cache + + if self.sliding_window != -1: + self.extend_for_sliding_window( + attn_metadata, + query, + key_cache, + 
value_cache, + output, + cu_seqlens_q, + max_seqlen_q, + block_table, + k_scale, + v_scale, + ) + return + out, lse = aiter.flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_q, + min_seqlen_q=min_seqlen_q, + dropout_p=0.0, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + return_lse=True, + ) + assert attn_metadata.plugin_metadata.extend_metadata is not None + chunk_context_metadata = attn_metadata.plugin_metadata.extend_metadata.chunk_context_metadata + num_chunks = chunk_context_metadata.num_chunks + workspace = chunk_context_metadata.workspace + cu_seqlens_kv = chunk_context_metadata.cu_seq_lens_chunk + max_seqlens = chunk_context_metadata.max_seq_lens + chunk_starts = chunk_context_metadata.chunk_starts + token_to_batch = chunk_context_metadata.token_to_batch + total_token_per_batch = chunk_context_metadata.total_token_per_batch + key_fetched, value_fetched = workspace[0], workspace[1] + chunked_output = None + chunked_lse = None + # key_cache_for_gather, value_cache_for_gather, _ = ( + # self._get_cp_mha_gather_cache_views(key_cache, value_cache) + # ) + for chunk_idx in range(num_chunks): + cp_mha_gather_cache( + key_cache=key_cache, + value_cache=value_cache, + key=key_fetched, + value=value_fetched, + block_tables=block_table, + k_scales=k_scale, + v_scales=v_scale, + cu_seqlens_kv=cu_seqlens_kv[chunk_idx], + token_to_batch=token_to_batch[chunk_idx], + seq_starts=chunk_starts[chunk_idx], + dequant=self.kv_cache_dtype.startswith("fp8"), + kv_cache_layout="SHUFFLE", + total_tokens=total_token_per_batch[chunk_idx], + ) + + suf_out, suf_lse = aiter.flash_attn_varlen_func( + q=query, + k=key_fetched, + v=value_fetched, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_kv[chunk_idx], + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlens[chunk_idx], + min_seqlen_q=min_seqlen_q, + dropout_p=0.0, + 
softmax_scale=self.scale, + causal=False, + window_size=(-1, -1, 0), + alibi_slopes=self.alibi_slopes, + return_lse=True, + ) + + if chunked_output is None: + chunked_output = suf_out + chunked_lse = suf_lse + else: + tmp_output = torch.empty_like(out) + tmp_lse = torch.empty_like(lse) + merge_attn_states( + output=tmp_output, + output_lse=tmp_lse, + prefix_output=chunked_output, + prefix_lse=chunked_lse, + suffix_output=suf_out, + suffix_lse=suf_lse, + ) + chunked_output = tmp_output + chunked_lse = tmp_lse + + merge_attn_states( + output=output, + prefix_output=chunked_output, + prefix_lse=chunked_lse, + suffix_output=out, + suffix_lse=lse, + ) + + def forward_impl_plugin_mode( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: "AttentionMetaData" = None, + position: torch.Tensor = None, + q_scale: torch.Tensor = None, + qkv: torch.Tensor = None, + output: torch.Tensor = None, + ): + # create the output here, it use query shape + num_tokens = query.shape[0] + output_dtype = query.dtype + output_shape = torch.Size( + (num_tokens, self.num_heads * self.head_size) + ) + output = torch.empty(output_shape, dtype=output_dtype, device=query.device) + + # dummy run will skip attention in cuda graph capture phase + if attn_metadata is None: + return output.fill_(0) + + # when using this optimization, the qkv tensor and + # position tensor are passed through q,k,v + if ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION: + assert position is None, "position should be None because it is passed through k" + + position = key + qkv = value + + q_size = self.num_heads * self.head_dim + kv_size = self.num_kv_heads * self.head_dim + query, key, value = torch.split(qkv, [q_size, kv_size, kv_size], dim=-1) + else: + # the position is computed by ATOM, and contained in attention metadata + # when dummy run, the attn metadata is None + if attn_metadata is not None: + position = 
attn_metadata.plugin_metadata.context.positions + + query = query.view(-1, self.num_heads, self.head_dim) + key = key.view(-1, self.num_kv_heads, self.head_dim) + value = value.view(-1, self.num_kv_heads, self.head_dim) + output = output.view(-1, self.num_heads, self.head_dim) + + num_actual_tokens = attn_metadata.plugin_metadata.num_actual_tokens + k_cache, v_cache = kv_cache.unbind(0) + num_blocks, block_size, num_kv_heads, _ = k_cache.shape + + if self.kv_cache_dtype == "fp8": + target_dtype = dtypes.d_dtypes[self.kv_cache_dtype] + k_cache = k_cache.view(target_dtype) + v_cache = v_cache.view(target_dtype) + + # create kv scale according to the num_blocks + # usually it is created when cuda graph capture for decode phase + if self.kv_cache_dtype == "fp8": + if self.k_scale is None or self.v_scale is None: + self.kv_scale = torch.zeros( + 2, + num_blocks, + num_kv_heads, + block_size, + dtype=dtypes.fp32, + device=self.device, + ) + # update the layer kv scale tensor + self.k_scale = self.kv_scale[0] + self.v_scale = self.kv_scale[1] + layer.k_scale = self.k_scale + layer.v_scale = self.v_scale + + # rope and cache flush fusion + result = self.rope_cache_plugin_mode(q=query, + k=key, + v=value, + qkv=qkv, + position=position, + attention_metadata=attn_metadata, + k_cache=k_cache, + v_cache=v_cache, + k_scale=self.k_scale, + v_scale=self.v_scale, + flash_layout=False) + (query, key, value, k_cache, v_cache, k_scale, v_scale) = result + + # The tokens are storaged as [decode:extend:prefill] order + # which is decided by the vllm + query = query[:num_actual_tokens] + if key is not None: + key = key[:num_actual_tokens] + if value is not None: + value = value[:num_actual_tokens] + + output_actual_tokens = output[:num_actual_tokens] + + num_decodes = attn_metadata.plugin_metadata.num_decodes + num_prefills = attn_metadata.plugin_metadata.num_prefills + num_extends = attn_metadata.plugin_metadata.num_extends + + num_decode_tokens = 
attn_metadata.plugin_metadata.num_decode_tokens + num_extend_tokens = attn_metadata.plugin_metadata.num_extend_tokens + + # calculate for prefills + if num_prefills > 0: + assert attn_metadata.plugin_metadata.prefill_metadata is not None + + # prefill part is after decode and extend + prefill_query = query[num_decode_tokens + num_extend_tokens :] + prefill_key = key[num_decode_tokens + num_extend_tokens :] + prefill_value = value[num_decode_tokens + num_extend_tokens :] + + sliding_window = (self.sliding_window, 0, 0) if self.sliding_window is not None else (-1, -1, 0) + + aiter.flash_attn_varlen_func( + q=prefill_query, + k=prefill_key, + v=prefill_value, + cu_seqlens_q=attn_metadata.plugin_metadata.prefill_metadata.query_start_loc, + cu_seqlens_k=attn_metadata.plugin_metadata.prefill_metadata.query_start_loc, + max_seqlen_q=attn_metadata.plugin_metadata.prefill_metadata.max_query_len, + max_seqlen_k=attn_metadata.plugin_metadata.prefill_metadata.max_seq_len, + min_seqlen_q=1, + dropout_p=attn_metadata.dropout_p, + softmax_scale=self.scale, + causal=True, + window_size=sliding_window, + alibi_slopes=None, + sink_ptr=self.sinks, + out=output_actual_tokens[num_decode_tokens + num_extend_tokens :], + ) + + # calculate for extends + if num_extends > 0: + assert attn_metadata.plugin_metadata.extend_metadata is not None + extend_tokens_slice = slice( + num_decode_tokens, num_decode_tokens + num_extend_tokens + ) + extend_querys = query[extend_tokens_slice] + extend_keys = key[extend_tokens_slice] + extend_values = value[extend_tokens_slice] + extend_outputs = output[extend_tokens_slice] + self.extend_forward( + attn_metadata=attn_metadata, + query=extend_querys, + key=extend_keys, + value=extend_values, + key_cache=k_cache, + value_cache=v_cache, + output=extend_outputs, + cu_seqlens_q=attn_metadata.plugin_metadata.extend_metadata.query_start_loc, + max_seqlen_q=attn_metadata.plugin_metadata.extend_metadata.max_query_len, + 
max_seqlen_k=attn_metadata.plugin_metadata.extend_metadata.max_seq_len,
+                min_seqlen_q=1,
+                block_table=attn_metadata.plugin_metadata.block_table[
+                    num_decodes : num_decodes + num_extends
+                ],
+                slot_mapping=attn_metadata.plugin_metadata.slot_mapping[
+                    num_decodes : num_decodes + num_extends
+                ],
+                k_scale=k_scale,
+                v_scale=v_scale,
+            )
+
+        # calculate for decodes
+        if num_decodes > 0:
+            assert attn_metadata.plugin_metadata.decode_metadata is not None
+
+            num_blocks, block_size, num_kv_heads, head_size = k_cache.shape
+            x = 16 // k_cache.element_size()
+            k_cache_template = torch.empty(
+                [num_blocks, num_kv_heads, head_size // x, block_size, x],
+                dtype=k_cache.dtype,
+                device="meta",
+            )
+            v_cache_template = torch.empty(
+                [num_blocks, num_kv_heads, block_size // x, head_size, x],
+                dtype=v_cache.dtype,
+                device="meta",
+            )
+            new_key_cache = k_cache.view_as(k_cache_template)
+            new_value_cache = v_cache.view_as(v_cache_template)
+
+            if self.use_triton_attn:
+                self.paged_attention_triton_plugin_mode(
+                    q=query[:num_decode_tokens],
+                    k_cache=new_key_cache,
+                    v_cache=new_value_cache,
+                    k_scale=k_scale,
+                    v_scale=v_scale,
+                    out=output_actual_tokens[:num_decode_tokens],
+                    attn_metadata=attn_metadata,
+                )
+            else:
+                # Qwen only uses gluon pa decode when bs=64
+                if num_decodes == 64:
+                    self.paged_attention_triton_plugin_mode(
+                        q=query[:num_decode_tokens],
+                        k_cache=new_key_cache,
+                        v_cache=new_value_cache,
+                        k_scale=k_scale,
+                        v_scale=v_scale,
+                        out=output_actual_tokens[:num_decode_tokens],
+                        attn_metadata=attn_metadata,
+                    )
+                else:
+                    self.paged_attention_asm_plugin_mode(
+                        q=query[:num_decode_tokens],
+                        k_cache=new_key_cache,
+                        v_cache=new_value_cache,
+                        k_scale=k_scale,
+                        v_scale=v_scale,
+                        num_decodes=num_decodes,
+                        num_decode_tokens=num_decode_tokens,
+                        out=output_actual_tokens[:num_decode_tokens],
+                        attn_metadata=attn_metadata,
+                    )
+
+        output = output.view(-1, self.num_heads * self.head_dim)
+
+        return output
+
+
+def PagedAttentionImplDecoratorForPluginMode(cls):
+
+    
method_names = [
+        'rope_cache_plugin_mode',
+        '_get_cp_mha_gather_cache_views',
+        'paged_attention_triton_plugin_mode',
+        'paged_attention_asm_plugin_mode',
+        'extend_for_sliding_window',
+        'extend_forward',
+        'forward_impl_plugin_mode',
+    ]
+
+    logger.info('Use PagedAttentionImplDecoratorForPluginMode to decorate PagedAttentionImpl')
+
+    # Add all methods to the target class
+    for method_name in method_names:
+        method = getattr(PagedAttentionImplPluginModeMethods, method_name)
+        setattr(cls, method_name, method)
+
+    return cls
diff --git a/atom/plugin/config.py b/atom/plugin/config.py
new file mode 100644
index 000000000..758180aff
--- /dev/null
+++ b/atom/plugin/config.py
@@ -0,0 +1,234 @@
+import os
+# from sre_parse import MAX_UNTIL  # NOTE(review): stray unused import; sre_parse is private and removed in Python 3.13
+import sys
+
+from typing import Any, Optional
+from dataclasses import dataclass
+
+import torch
+import logging
+
+logger = logging.getLogger("atom")
+_KNOWN_ISSUE_MAX_NUM_BATCHED_TOKENS_THRESHOLD = 18 * 1024
+
+@dataclass
+class PluginConfig:
+    # common config for both framework
+    model_config: Any = None
+    rank: int = 0
+    is_plugin_mode: bool = False
+    is_vllm: bool = False
+    is_sglang: bool = False
+
+    # vllm specific
+    vllm_scheduler_config: Any = None
+    vllm_cache_config: Any = None
+    vllm_quant_config: Any = None
+    vllm_use_custom_attention: bool = False
+
+    # sglang specific
+    sglang_model_opt_config: Any = None
+    sglang_load_config: Any = None
+    sglang_enable_torch_compile: bool = False
+    sglang_disable_cuda_graph: bool = False
+    sglang_enable_dp_attention: bool = False
+    sglang_dist_init_addr: Optional[str] = None
+    sglang_port_args: Any = None
+
+
+def _generate_atom_config_from_vllm_config(config: Any):
+    from atom.config import Config, CompilationConfig
+
+    vllm_model_config = config.model_config
+    vllm_scheduler_config = config.scheduler_config
+    vllm_cache_config = config.cache_config
+    vllm_parallel_config = config.parallel_config
+    vllm_use_custom_attention = bool(os.getenv("VLLM_ATTENTION_BACKEND", 
"None").lower() == "custom") + + # here use the ATOM compilation config, as the ATOM compile policy is used + # instead of vLLM one for torch compile, while for cuda graph capture, + # still use the vLLM + # when you don't want to use atom torch compile, you can also use + # --enforce-eager to disable the atom torch compile when launch vllm server + compilation_config = config.compilation_config + vllm_compilation_config = CompilationConfig( + # use mode because vllm level argument is deprecated + level=compilation_config.mode, + use_cudagraph=False, + cudagraph_mode=None, + ) + + vllm_quant_config = config.quant_config + + plugin_config = PluginConfig( + # common config + model_config=vllm_model_config, + rank=vllm_parallel_config.rank, + is_plugin_mode=True, + is_vllm=True, + is_sglang=False, + # vllm specific + vllm_scheduler_config=vllm_scheduler_config, + vllm_cache_config=vllm_cache_config, + vllm_quant_config=vllm_quant_config, + vllm_use_custom_attention=vllm_use_custom_attention, + ) + + # specific + max_model_len = vllm_model_config.max_model_len + if hasattr(vllm_scheduler_config, 'max_model_len'): + max_model_len = vllm_scheduler_config.max_model_len + + max_num_batched_tokens = vllm_scheduler_config.max_num_batched_tokens + # FIXME: known issue for illegal mem access in fused_moe kernel + if max_num_batched_tokens >= _KNOWN_ISSUE_MAX_NUM_BATCHED_TOKENS_THRESHOLD: + logger.warning("For plugin mode, when setting max_num_batched_tokens >= " + + f"{_KNOWN_ISSUE_MAX_NUM_BATCHED_TOKENS_THRESHOLD}, there is a known issue " + + "for illegal mem access in asm fused_moe kernel, if you met the issue, " + + "please set max_num_batched_tokens smaller or choose the ck fused_moe " + + "kernel instead of asm ones") + + return Config( + model=None, + max_num_batched_tokens=max_num_batched_tokens, + max_num_seqs=vllm_scheduler_config.max_num_seqs, + max_model_len=max_model_len, + gpu_memory_utilization=vllm_cache_config.gpu_memory_utilization, + 
tensor_parallel_size=vllm_parallel_config.tensor_parallel_size, + enforce_eager=vllm_model_config.enforce_eager, + parallel_config=vllm_parallel_config, + kv_cache_block_size=vllm_cache_config.block_size, + num_kvcache_blocks=vllm_cache_config.num_gpu_blocks, + kv_cache_dtype=vllm_cache_config.cache_dtype, + enable_prefix_caching=vllm_cache_config.enable_prefix_caching, + port=None, + torch_profiler_dir=None, + compilation_config=vllm_compilation_config, + asyncio_mode=False, + load_dummy=False, + enable_expert_parallel=vllm_parallel_config.enable_expert_parallel, + master_addr=None, + enable_dp_attention=False, + plugin_config=plugin_config, + ) + + +def _generate_atom_config_from_sglang_config(config: Any): + from sglang.srt.server_args import ( + ServerArgs, + prepare_server_args, + PortArgs, + ) + from sglang.srt.configs.model_config import ModelConfig as SglangModelConfig + from sglang.srt.configs.modelopt_config import ModelOptConfig + from sglang.srt.configs.load_config import LoadConfig + from atom.config import Config, ParallelConfig, CompilationConfig + + # sglang has no global config variable like vllm, + # so here construct the server args from sys.argv passed by users + # this is the only way to get full arguments + server_args: ServerArgs = prepare_server_args(sys.argv[1:]) + + sgl_model_config = SglangModelConfig.from_server_args(server_args) + sgl_model_opt_config = ModelOptConfig( + quant=server_args.modelopt_quant, + checkpoint_restore_path=server_args.modelopt_checkpoint_restore_path, + checkpoint_save_path=server_args.modelopt_checkpoint_save_path, + export_path=server_args.modelopt_export_path, + ) + + sgl_load_config = LoadConfig( + load_format=server_args.load_format, + download_dir=server_args.download_dir, + model_loader_extra_config=server_args.model_loader_extra_config, + remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip, + 
remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port, + remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports, + remote_instance_weight_loader_backend=server_args.remote_instance_weight_loader_backend, + modelopt_config=sgl_model_opt_config, + rl_quant_profile=server_args.rl_quant_profile, + ) + + # sglang doesn't passed the rank number in config, so ATOM plugin + # get rank number through the torch.distributed.get_rank() + rank = torch.distributed.get_rank() + + # sglang uses the atom parallel config + sgl_parallel_config = ParallelConfig( + data_parallel_size=server_args.dp_size, + ) + + # create the compilation config static_forward_context for sglang + sgl_compilation_config = CompilationConfig() + # for now disable using atom torch compile as sglang uses its own torch compile + sgl_compilation_config.level = 0 + sgl_compilation_config.use_cudagraph = False + sgl_compilation_config.static_forward_context = {} + + plugin_config = PluginConfig( + # common config + model_config=sgl_model_config, + rank=rank, + is_plugin_mode=True, + is_vllm=False, + is_sglang=True, + # sglang specific + sglang_model_opt_config=sgl_model_opt_config, + sglang_load_config=sgl_load_config, + sglang_enable_torch_compile=server_args.enable_torch_compile, + sglang_disable_cuda_graph=server_args.disable_cuda_graph, + sglang_enable_dp_attention=server_args.enable_dp_attention, + sglang_dist_init_addr=server_args.dist_init_addr, + sglang_port_args=PortArgs.init_new(server_args), + ) + + # TODO: sgl doesn't have max num batched tokens, so force to 16k + return Config( + model=None, + max_num_batched_tokens=16384, + max_num_seqs=server_args.max_running_requests, + max_model_len=server_args.context_length, + gpu_memory_utilization=server_args.mem_fraction_static, + tensor_parallel_size=server_args.tp_size, + enforce_eager=not 
server_args.enable_torch_compile, + parallel_config=sgl_parallel_config, + kv_cache_dtype=server_args.kv_cache_dtype, + enable_prefix_caching=False, + port=None, + torch_profiler_dir=None, + compilation_config=sgl_compilation_config, + asyncio_mode=False, + load_dummy=False, + enable_expert_parallel=bool(server_args.ep_size > 1), + master_addr=None, + enable_dp_attention=server_args.enable_dp_attention, + plugin_config=plugin_config, + ) + + +def generate_atom_config_for_plugin_mode(config: Any = None): + ''' + Generate the atom config in plugin mode, be called when create the custom model + config: + - for vllm: config is VllmConfig and contains all config value from vllm + - for sglang: config is only model specific config passed from sglang, so the + server args is used + ''' + + logger.info('Generate atom config for plugin mode from passed config') + + atom_config = None + from atom.plugin import is_vllm, is_sglang + from atom.config import set_current_atom_config + if is_vllm(): + atom_config = _generate_atom_config_from_vllm_config(config) + elif is_sglang(): + atom_config = _generate_atom_config_from_sglang_config(config) + else: + raise ValueError("Make sure ATOM is running in plugin mode, \ + the function generate_atom_config_for_plugin_mode should be called in plugin mode") + + # set the current atom config for the custom model + set_current_atom_config(atom_config) + + return atom_config \ No newline at end of file diff --git a/atom/plugin/moe.py b/atom/plugin/moe.py new file mode 100644 index 000000000..524faf0a1 --- /dev/null +++ b/atom/plugin/moe.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +from typing import Type, TypeVar +import logging +from atom.plugin.prepare import is_vllm + +logger = logging.getLogger("atom") + +T = TypeVar('T') + + +def _apply_moe_decoration(original_cls: Type[T]) -> Type[T]: + """ + Apply the actual decoration to the MoE class. 
+ This is called lazily during instantiation. + """ + is_vllm_mode = is_vllm() + if is_vllm_mode: + # rename this class because the vllm will call the modular + # kernel init method for all modules of the model, whose name is FusedMoE, + # to init the inside kernel, while for plugin mode, the atom maintains + # the kernel lifecycle by itself, so there is no need to call init on + # vllm side + original_cls.__name__ = "ATOMFusedMoE" + original_cls.__qualname__ = "ATOMFusedMoE" + + return original_cls + + +def FusedMoEDecoratorForPluginMode(cls: Type[T]) -> Type[T]: + """ + Lazy decorator that defers class modification until first instantiation + """ + original_cls = cls + decorated_cls_cache = {'value': None} + + def get_decorated_class(): + if decorated_cls_cache['value'] is not None: + return decorated_cls_cache['value'] + + decorated = _apply_moe_decoration(original_cls) + decorated_cls_cache['value'] = decorated + return decorated + + class LazyMoEWrapper(original_cls): + """Wrapper that defers decoration until first instantiation.""" + + def __new__(cls_wrapper, *args, **kwargs): + decorated_cls = get_decorated_class() + return decorated_cls(*args, **kwargs) + + # Preserve the original class name and module for the wrapper + LazyMoEWrapper.__name__ = original_cls.__name__ + LazyMoEWrapper.__qualname__ = original_cls.__qualname__ + LazyMoEWrapper.__module__ = original_cls.__module__ + + logger.info(f'Create lazy wrapper for FusedMoE to change the naming') + + return LazyMoEWrapper diff --git a/atom/plugin/prepare.py b/atom/plugin/prepare.py new file mode 100644 index 000000000..d6bb4caa3 --- /dev/null +++ b/atom/plugin/prepare.py @@ -0,0 +1,85 @@ +from typing import Any +import logging + +logger = logging.getLogger("atom") + +# all of the supported frameworks, including server mode and plugin mode +_SUPPORTED_FRAMEWORKS = ["vllm", "sglang", "sgl", "atom"] + +# supported frameworks for plugin mode +_SUPPORTED_FRAMEWORKS_FOR_PLUGIN_MODE = ["vllm", "sglang", "sgl"] 
+ +# default is atom for server mode +_CURRENT_FRAMEWORK = "atom" + + +def is_sglang() -> bool: + global _CURRENT_FRAMEWORK + if _CURRENT_FRAMEWORK is None: + raise ValueError("_CURRENT_FRAMEWORK must be set before use") + return bool(_CURRENT_FRAMEWORK.lower() in ["sglang", "sgl"]) + + +def is_vllm() -> bool: + global _CURRENT_FRAMEWORK + if _CURRENT_FRAMEWORK is None: + raise ValueError("_CURRENT_FRAMEWORK must be set before use") + return bool(_CURRENT_FRAMEWORK.lower() in ["vllm"]) + + +def is_plugin_mode() -> bool: + global _CURRENT_FRAMEWORK + return bool(_CURRENT_FRAMEWORK.lower() in _SUPPORTED_FRAMEWORKS_FOR_PLUGIN_MODE) + + +def _set_framework_backbone(framework: str) -> None: + if framework.lower() not in _SUPPORTED_FRAMEWORKS: + raise ValueError(f"Unsupported framework {framework} for ATOM to plug in") + global _CURRENT_FRAMEWORK + _CURRENT_FRAMEWORK = framework + + +def prepare_model(config: Any, engine: str): + ''' + Prepare the model to upper framework, including + register custom ops and init aiter dist + ''' + logging.info(f'Prepare model for plugin mode, the upper engine is {engine}') + + _set_framework_backbone(engine) + + # different engine passed different config + if is_vllm(): + model_arch = config.model_config.architectures[0] + elif is_sglang(): + model_arch = config.architectures[0] + + # import here to avoid partial initialization + from .register import ( + _ATOM_SUPPORTED_MODELS, + register_ops_to_vllm, + register_ops_to_sglang, + init_aiter_dist, + set_attn_cls, + ) + + if model_arch not in _ATOM_SUPPORTED_MODELS: + logger.warning(f"ATOM does not support the required model architecture: {model_arch}") + + from atom.plugin.config import generate_atom_config_for_plugin_mode + atom_config = generate_atom_config_for_plugin_mode(config) + + model_cls = _ATOM_SUPPORTED_MODELS[model_arch] + logger.info(f'ATOM model class for {model_arch} is {model_cls}') + + if is_vllm(): + register_ops_to_vllm(atom_config=atom_config) + elif is_sglang(): + 
register_ops_to_sglang(atom_config=atom_config) + + set_attn_cls() + + # init aiter dist for using aiter custom collective ops + init_aiter_dist(config=atom_config) + + return model_cls(atom_config=atom_config) diff --git a/atom/plugin/register.py b/atom/plugin/register.py new file mode 100644 index 000000000..187501fd6 --- /dev/null +++ b/atom/plugin/register.py @@ -0,0 +1,106 @@ +import torch +import logging + +from atom.models.qwen3 import Qwen3ForCausalLM +from atom.models.qwen3_moe import Qwen3MoeForCausalLM +from atom.config import Config +from atom.plugin.prepare import is_vllm, is_sglang + +logger = logging.getLogger("atom") + +_ATOM_SUPPORTED_MODELS = { + "Qwen3ForCausalLM" : Qwen3ForCausalLM, + "Qwen3MoeForCausalLM" : Qwen3MoeForCausalLM, +} + + +def _register_custom_attention_to_vllm() -> None: + from vllm.v1.attention.backends.registry import register_backend, AttentionBackendEnum + logger.info('Register custom attention backend AiterBackend to vLLM') + register_backend(backend=AttentionBackendEnum.CUSTOM, + is_mamba=False, + class_path="atom.model_ops.attentions.aiter_attention.AiterBackend") + + +def _register_custom_attention_to_sglang() -> None: + + from sglang.srt.layers.attention.attention_registry import register_attention_backend + + # here register the custom attention backend with the name "aiter" + # as sglang defines the fixed attention backend choices, which must be + # in-tree + logger.info('Register custom attention backend AiterBackend to SGLang') + + @register_attention_backend("aiter") + def create_atom_backend(runner): + from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend + return AiterAttnBackend(runner) + + +def register_ops_to_vllm(atom_config: Config) -> None: + ''' + Register custom ops to vllm, including attention + ''' + if atom_config.plugin_config.vllm_use_custom_attention: + _register_custom_attention_to_vllm() + else: + logger.warning("Please export VLLM_ATTENTION_BACKEND=CUSTOM to use atom attention") + 
+ +def register_ops_to_sglang(atom_config: Config) -> None: + ''' + Register custom ops to sglang, including attention + ''' + _register_custom_attention_to_sglang() + + +def set_attn_cls() -> None: + ''' + Set the attention class for constructing the model based on the framework + ''' + import atom.model_ops as ops + from atom.model_ops import PagedAttention, RadixAttention + + if is_vllm(): + ops.ATTN_CLS = PagedAttention + logger.info('Set ATTN_CLS to PagedAttention for vLLM') + elif is_sglang(): + ops.ATTN_CLS = RadixAttention + logger.info('Set ATTN_CLS to RadixAttention for SGLang') + + +def init_aiter_dist(config: Config) -> None: + ''' + Initialize aiter dist for using aiter custom collective op + ''' + logger.info('Initialize aiter dist for using aiter custom collective op for plugin mode') + + from aiter import init_dist_env + from aiter.dist.utils import get_distributed_init_method + + rank = config.plugin_config.rank + tensor_parallel_size = config.tensor_parallel_size + + assert config.plugin_config.is_plugin_mode, "Make sure ATOM is running in plugin mode" + + if config.plugin_config.is_vllm: + dp_master_ip = config.parallel_config.data_parallel_master_ip + dp_master_port = config.parallel_config.data_parallel_master_port + elif config.plugin_config.is_sglang: + if config.plugin_config.sglang_dist_init_addr is not None: + dp_master_ip, dp_master_port = config.plugin_config.sglang_dist_init_addr.split(":") + else: + dp_master_ip = f"127.0.0.1" + dp_master_port = config.plugin_config.sglang_port_args.nccl_port + + distributed_init_method = get_distributed_init_method(dp_master_ip, dp_master_port) + + logger.info(f'Initialize aiter dist for using aiter custom collective op for plugin mode, rank:{rank}') + init_dist_env( + tensor_model_parallel_size=tensor_parallel_size, + rankID=rank, + backend="nccl", + distributed_init_method=distributed_init_method, + data_parallel_size=config.parallel_config.data_parallel_size, + 
data_parallel_rank=config.parallel_config.data_parallel_rank, + ) diff --git a/atom/utils/backends.py b/atom/utils/backends.py index b583ba447..9f606ab9a 100644 --- a/atom/utils/backends.py +++ b/atom/utils/backends.py @@ -255,6 +255,23 @@ class SplitItem: graph: fx.GraphModule +# used to judge whether the node should be split or not +def _split_judge_func(node: fx.Node) -> bool: + # ATOM use mark_spliting_op to mark the attn as splitting op + if node.op == "call_function" and ( + hasattr(node.target, "spliting_op") and (node.target.spliting_op) + ): + return True + + # When plugin mode(vLLM), the attention impl op is registered + # as unified_attention + from atom.plugin import is_vllm + if is_vllm() and "unified_attention" in node.name: + return True + + return False + + def split_graph( graph: fx.GraphModule, ops: list[str] ) -> tuple[fx.GraphModule, list[SplitItem]]: @@ -265,9 +282,7 @@ def split_graph( for node in graph.graph.nodes: if node.op in ("output", "placeholder"): continue - if node.op == "call_function" and ( - hasattr(node.target, "spliting_op") and (node.target.spliting_op) - ): + if _split_judge_func(node): subgraph_id += 1 node_to_subgraph_id[node] = subgraph_id split_op_graphs.append(subgraph_id) diff --git a/atom/utils/forward_context.py b/atom/utils/forward_context.py index 3682d98f5..78c751c9a 100644 --- a/atom/utils/forward_context.py +++ b/atom/utils/forward_context.py @@ -9,6 +9,9 @@ import torch from atom.config import Config, KVCacheTensor, ParallelConfig +if TYPE_CHECKING: + from atom.plugin.attention import MetadataForPluginMode + def _compute_chunked_local_num_tokens( num_tokens_across_dp_cpu: list[int], max_num_tokens: int, chunk_idx: int @@ -185,6 +188,9 @@ class AttentionMetaData: block_tables_converted: Optional[torch.Tensor] = None kv_indices_converted: Optional[torch.Tensor] = None + # only used for plugin mode to store the metadata for attn + plugin_metadata: Optional["MetadataForPluginMode"] = None + def __init__( self, 
cu_seqlens_q: Optional[torch.Tensor] = None, @@ -212,6 +218,7 @@ def __init__( kv_indices_converted: Optional[torch.Tensor] = None, sparse_cu_seqlens_q: Optional[torch.Tensor] = None, token_to_seq_idxs: Optional[torch.Tensor] = None, + plugin_metadata: Optional["MetadataForPluginMode"] = None, ): self.cu_seqlens_q = cu_seqlens_q self.cu_seqlens_k = cu_seqlens_k @@ -238,6 +245,10 @@ def __init__( self.block_tables = block_tables_converted if kv_indices_converted is not None: self.kv_indices = kv_indices_converted + if plugin_metadata is not None: + from atom.plugin.prepare import is_plugin_mode + assert is_plugin_mode(), "plugin_metadata is only supported for plugin mode" + self.plugin_metadata = plugin_metadata self.sparse_cu_seqlens_q = sparse_cu_seqlens_q self.token_to_seq_idxs = token_to_seq_idxs diff --git a/recipes/Model-Impl-Backend.md b/recipes/Model-Impl-Backend.md new file mode 100644 index 000000000..ecf1782dc --- /dev/null +++ b/recipes/Model-Impl-Backend.md @@ -0,0 +1,168 @@ +# Model Impl Backend of vLLM and SGLang +ATOM can work as model implementation backend of popular framework, like vLLM and SGLang. The users can launch vLLM and SGLang server like before and specify an extra argument to enable the ATOM model backend, where the optimized implementation of the required target model will be provided to vLLM and SGLang to execute. When ATOM working under this mode, both framework-level features from vLLM/SGLang and latest model-level fusion kernels from ATOM/AITER can be combined together to achieve the competitive performance. 
+ +- Here is a detailed design slide for this feature: https://amdcloud-my.sharepoint.com/:p:/g/personal/zejchen_amd_com/IQCFdvmEeLTWT7ysApmZv_hVAfw2nTo8iesJZGblHS0evqQ?e=hjnIDM +- Here is the RFC to introduce the ATOM as model impl backend into vLLM: https://github.com/vllm-project/vllm/issues/33478 +- Here is the RFC to introduce the ATOM as model impl backend into SGLang: TODO + +## Preparing environment for vLLM with ATOM model backend +Here is the PR to introduce ATOM into vLLM: https://github.com/vllm-project/vllm/pull/32160. Once this PR is merged, the official vLLM can be used, but for now you need to use a development vLLM branch + +Pull the latest docker from vLLM official nightly docker for ROCm from https://hub.docker.com/r/rocm/vllm-dev/tags +```bash +docker pull rocm/vllm-dev:nightly +``` +Launch the container as usual, then all the next operations will be executed inside the container +Then the specific vLLM branch should be used because the PR to introduce ATOM into vLLM has not been merged yet, so you need to: +```bash +pip uninstall -y vllm +git clone https://github.com/zejunchen-zejun/vllm.git +cd vllm +git checkout origin/zejun/model_impl +export PYTORCH_ROCM_ARCH="gfx950" +python3 setup.py develop 2>&1 | tee build.log +``` +Then ATOM should be installed +```bash +git clone https://github.com/zejunchen-zejun/ATOM.git +cd ATOM +git checkout origin/zejun/plugin_for_atom_1223 +pip install -e . 2>&1 | tee build.log +``` +For AITER, there are no specific requirements; however, if you find any of the latest fusion kernels are missing, you may need to upgrade AITER +Additionally, you may need to upgrade your triton version by: +```bash +pip install --upgrade triton +``` + +### Launching server of vLLM with ATOM model backend +You just need to make 2 code changes to your previous server launch command. One is using the CUSTOM vLLM attention backend; the other is a new argument specifying the ATOM model impl backend. Here is an example.
From the example, the specific fusion kernels are used by enabling the env flags, which is not easy to integrate into vLLM as vLLM has some heuristic to stipulate the boundary of ops and layers, where ATOM can provide the those kernels +```bash +export VLLM_ATTENTION_BACKEND=CUSTOM + +export ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1 +export AITER_QUICK_REDUCE_QUANTIZATION=INT4 + +export SAFETENSORS_FAST_GPU=1 +export VLLM_ROCM_USE_AITER=1 +export VLLM_RPC_TIMEOUT=1800000 + +export VLLM_CACHE_ROOT=/root/.cache/vllm +export TORCHINDUCTOR_CACHE_DIR=/root/.cache/inductor + +rm -rf /root/.cache/ + +model_path= + +vllm serve $model_path \ + --host localhost \ + --port 8000 \ + --tensor-parallel-size 8 \ + --enable-expert-parallel \ + --trust-remote-code \ + --disable-log-requests \ + --gpu_memory_utilization 0.9 \ + --async-scheduling \ + --load-format fastsafetensors \ + --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \ + --kv-cache-dtype fp8 \ + --max-num-batched-tokens 18432 \ + --max-model-len 16384 \ + --no-enable-prefix-caching \ + --model-impl atom \ + 2>&1 | tee log.serve.log & +``` + +### Launching client for validating the accuracy +```bash +addr=localhost +port=8000 +url=http://${addr}:${port}/v1/completions +model= +task=gsm8k +lm_eval --model local-completions \ + --model_args model=${model},base_url=${url},num_concurrent=65,max_retries=1,tokenized_requests=False \ + --tasks ${task} \ + --num_fewshot 3 \ + 2>&1 | tee log.lmeval.log +``` + +### Results for accuracy validation +|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr| +|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:| +|gsm8k| 3|flexible-extract| 3|exact_match|↑ |0.8901|± |0.0086| +| | |strict-match | 3|exact_match|↑ |0.8772|± |0.0090| + +### Known Limitations +There are some known limitations for now: +- Only Qwen-Dense and Qwen-MoE family models are supported +- Only TP and EP are supported + + +## Preparing environment for SGLang with ATOM model 
backend +Here is the PR to introduce ATOM into SGLang: https://github.com/sgl-project/sglang/pull/16944. Once this PR is merged, the official SGLang can be used, but for now you need to use a development SGLang branch +Pull the latest docker from SGLang official nightly docker for ROCm from https://hub.docker.com/r/rocm/sgl-dev/tags +```bash +docker pull rocm/sgl-dev:v0.5.8-rocm720-mi35x-20260130-preview +``` +Launch the container as usual, then all the next operations will be executed inside the container +Then the specific SGLang branch should be used because the PR to introduce ATOM into SGLang has not been merged yet, so you need to: +```bash +git clone https://github.com/zejunchen-zejun/sglang.git +git checkout remotes/origin/zejun/model_impl +pip uninstall sglang -y +pip uninstall sgl-kernel -y +cd sgl-kernel +python3 setup_rocm.py install +export PYTHONPATH= +``` +Then ATOM should be installed +```bash +git clone https://github.com/zejunchen-zejun/ATOM.git +cd ATOM +git checkout origin/zejun/plugin_for_atom_1223 +pip install -e . 2>&1 | tee build.log +``` +For AITER, there are no specific requirements; however, if you find any of the latest fusion kernels are missing, you may need to upgrade AITER + +### Launching server of SGLang with ATOM model backend +You just need to make a single code change: add --model-impl atom to your SGLang server command.
Here is an example: +```bash +export ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1 + +# quick allreduce +export AITER_QUICK_REDUCE_QUANTIZATION=INT4 +model_path=/data/models/Qwen3-235B-A22B-Instruct-2507-FP8 + +python3 -m sglang.launch_server \ + --model-path $model_path \ + --host localhost \ + --port 8000 \ + --trust-remote-code \ + --tensor-parallel-size 8 \ + --expert-parallel-size 8 \ + --kv-cache-dtype fp8_e4m3 \ + --mem-fraction-static 0.8 \ + --model-impl atom \ + 2>&1 | tee log.serve.log & +``` + +### Launching client for validating the accuracy +```bash +addr=localhost +port=8000 +url=http://${addr}:${port}/v1/completions +model= +task=gsm8k +lm_eval --model local-completions \ + --model_args model=${model},base_url=${url},num_concurrent=65,max_retries=1,tokenized_requests=False \ + --tasks ${task} \ + --num_fewshot 3 \ + 2>&1 | tee log.lmeval.log +``` + +### Known Limitations +There are some known limitations for now: +- Only Qwen-Dense and Qwen-MoE family models are supported +- Only TP and EP are supported +- For SGLang, there is still accuracy issue for now, but we will fix it soon From a5b5f3f749e497b42dd57d2aba08f3d4936fb6c4 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Mon, 2 Feb 2026 16:34:52 +0800 Subject: [PATCH 02/37] add Signed-off-by: zejunchen-zejun --- atom/config.py | 7 +++++-- atom/model_loader/loader.py | 6 +++--- atom/model_ops/__init__.py | 9 ++++----- atom/model_ops/attention_mha.py | 2 ++ atom/model_ops/radix_attention.py | 17 +++++++++-------- atom/plugin/attention.py | 23 +++++++---------------- atom/plugin/attention_mha.py | 6 +----- atom/plugin/config.py | 7 ++++--- atom/plugin/moe.py | 7 ++----- atom/plugin/prepare.py | 10 +++++----- atom/plugin/register.py | 1 - atom/utils/forward_context.py | 6 ++---- 12 files changed, 44 insertions(+), 57 deletions(-) diff --git a/atom/config.py b/atom/config.py index 597f29146..d5ed63c30 100644 --- a/atom/config.py +++ b/atom/config.py @@ -17,7 +17,7 @@ from torch.distributed import 
ProcessGroup, ReduceOp from transformers import AutoConfig, GenerationConfig, PretrainedConfig -# only for plugin mode +# plugin-related utilities from atom.plugin import is_plugin_mode from atom.plugin.config import PluginConfig @@ -610,6 +610,7 @@ def __post_init__(self): ), f"kv_cache_block_size ({self.kv_cache_block_size}) must be a multiple of 16 or 1" assert 1 <= self.tensor_parallel_size <= 8 if is_plugin_mode(): + assert self.plugin_config is not None, "plugin_config is required in plugin mode" self.hf_config = self.plugin_config.model_config.hf_config else: self.hf_config = get_hf_config(self.model) @@ -646,7 +647,9 @@ def __post_init__(self): # only for server mode or plugin mode(vllm) # for torch compile policy, plugin mode(vllm) uses the ATOM compile policy # for cuda graph capture, plugin mode(vllm) uses the vLLM's cuda graph capture policy - if not is_plugin_mode() or self.plugin_config.is_vllm: + if not is_plugin_mode() or ( + self.plugin_config is not None and self.plugin_config.is_vllm + ): if self.compilation_config.level == CompilationLevel.PIECEWISE: self.compilation_config.set_splitting_ops_for_v1() self._set_cudagraph_sizes() diff --git a/atom/model_loader/loader.py b/atom/model_loader/loader.py index 7c6e4f314..c5d3a8977 100644 --- a/atom/model_loader/loader.py +++ b/atom/model_loader/loader.py @@ -82,13 +82,13 @@ def safetensors_weights_iterator( # when plugin mode, model loader method is bind to model implementation -# thus call this interface to load the model, which leverags the load_model +# thus call this interface to load the model, which leverages the load_model # method def load_model_in_plugin_mode( model, config, prefix: str = "", -) -> set[str] | None: +) -> set[str]: # during loading model, the outplace operation may consume more # GPU mem, which cached in torch caching allocator, here actively @@ -131,7 +131,7 @@ def load_model( ): # need to record the loaded weight name for vllm load check # it is only used in plugin mode for 
vllm - loaded_weights_record = set[str]() + loaded_weights_record: set[str] = set() packed_modules_mapping = getattr(model, "packed_modules_mapping", {}) weights_mapping = getattr(model, "weights_mapping", {}) diff --git a/atom/model_ops/__init__.py b/atom/model_ops/__init__.py index 9c06892ec..fc42e0468 100644 --- a/atom/model_ops/__init__.py +++ b/atom/model_ops/__init__.py @@ -1,9 +1,8 @@ from .paged_attention import PagedAttention from .radix_attention import RadixAttention -# this global class is used to construct the attention op in model -# it can be assigned to different attention op -# default PagedAttention is used as ATOM for now supports PagedAttention -# for sglang, RadixAttention will be assigned to ATTN_CLS -# TODO: add env flag or argument to swicth the attention class +# This global class is used to construct the attention op in model, +# it can be assigned to different attention ops. +# By default, PagedAttention is used. +# For sglang, RadixAttention will be assigned to ATTN_CLS ATTN_CLS = PagedAttention diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py index e5be98244..ab443a7a2 100644 --- a/atom/model_ops/attention_mha.py +++ b/atom/model_ops/attention_mha.py @@ -74,6 +74,8 @@ def __init__( if is_vllm(): self.supports_quant_query_input = False + # this method will just be called by vLLM and there is no logic in this method + # as ATOM handle the process after loading weights for all ops by itself def process_weights_after_loading(self, act_dtype: torch.dtype = torch.bfloat16): pass diff --git a/atom/model_ops/radix_attention.py b/atom/model_ops/radix_attention.py index 9f892e797..d2a10ff70 100644 --- a/atom/model_ops/radix_attention.py +++ b/atom/model_ops/radix_attention.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
-import aiter import torch from torch import nn from typing import Optional @@ -68,7 +67,7 @@ def forward_impl_plugin_mode( output: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, - position: torch.Tensor = None, + positions: torch.Tensor = None, q_scale: torch.Tensor=None, **kwargs, ): @@ -81,15 +80,16 @@ def forward_impl_plugin_mode( v=value, forward_batch=forward_batch) else: - raise NotImplementedError("RadixAttention is only supported \ - for plugin mode for sglang for now") + raise NotImplementedError( + "RadixAttention is only supported for plugin mode for sglang for now" + ) def forward( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - position: torch.Tensor = None, + positions: torch.Tensor = None, q_scale: Optional[torch.Tensor]=None, **kwargs, ): @@ -97,10 +97,11 @@ def forward( o = self.forward_impl_plugin_mode(query=query, key=key, value=value, - position=position, + positions=positions, q_scale=q_scale, **kwargs) else: - raise NotImplementedError("RadixAttention is not supported for server \ - mode for now") + raise NotImplementedError( + "RadixAttention is not supported for server mode for now" + ) return o diff --git a/atom/plugin/attention.py b/atom/plugin/attention.py index 617e9d062..7d71454f6 100644 --- a/atom/plugin/attention.py +++ b/atom/plugin/attention.py @@ -1,4 +1,4 @@ -from typing import Any, Type, Optional +from typing import Optional import logging from dataclasses import dataclass @@ -53,7 +53,7 @@ class AiterChunkContextMetadata: seq_lens: torch.Tensor num_chunks: int total_token_per_batch: list[int] - swa_metadata: AiterChunkSlidingWindowMetadata | None + swa_metadata: Optional[AiterChunkSlidingWindowMetadata] = None @dataclass @@ -84,7 +84,7 @@ class MetadataForPluginMode: slot_mapping: torch.Tensor block_table: torch.Tensor - # prefill and deocde split + # prefill and decode split num_decodes: int num_decode_tokens: int num_prefills: int @@ 
-92,9 +92,9 @@ class MetadataForPluginMode: num_extends: int num_extend_tokens: int - decode_metadata: AiterFlashAttentionDecodeMetadata | None - prefill_metadata: AiterFlashAttentionPrefillMetadata | None - extend_metadata: AiterFlashAttentionChunkPrefillMetadata | None + decode_metadata: Optional[AiterFlashAttentionDecodeMetadata] = None + prefill_metadata: Optional[AiterFlashAttentionPrefillMetadata] = None + extend_metadata: Optional[AiterFlashAttentionChunkPrefillMetadata] = None # For cascade attention. use_cascade: bool @@ -211,7 +211,7 @@ def init_method_under_plugin_mode(self, self.head_dim = self.model_config.get_head_size() self.block_size = kv_cache_spec.block_size - self.aot_sliding_window: tuple[int, int] | None = None + self.aot_sliding_window: Optional[tuple[int, int]] = None self.total_tokens: int = 0 self.scheduler_config = config.scheduler_config @@ -521,15 +521,6 @@ def build( plugin_metadata=attn_metadata_for_plugin_mode, ) - # TODO: set the forward context - # set_forward_context( - # attn_metadata=attn_metadata, - # atom_config=self.config, - # context=context, - # num_tokens=num_tokens, - # num_tokens_across_dp=num_tokens_across_dp, - # ) - return attn_metadata # this method will be called by vllm, so it follows the vllm's interface convention diff --git a/atom/plugin/attention_mha.py b/atom/plugin/attention_mha.py index 5d61292a5..35fbf8149 100644 --- a/atom/plugin/attention_mha.py +++ b/atom/plugin/attention_mha.py @@ -146,7 +146,6 @@ def rope_cache_plugin_mode(self, output_zeros=False, ) else: - # for asm paged attention if self.rotary_emb is not None: assert position is not None q, k = self.rotary_emb(position, q, k) @@ -428,9 +427,6 @@ def extend_forward( key_fetched, value_fetched = workspace[0], workspace[1] chunked_output = None chunked_lse = None - # key_cache_for_gather, value_cache_for_gather, _ = ( - # self._get_cp_mha_gather_cache_views(key_cache, value_cache) - # ) for chunk_idx in range(num_chunks): cp_mha_gather_cache( 
key_cache=key_cache, @@ -564,7 +560,7 @@ def forward_impl_plugin_mode( layer.k_scale = self.k_scale layer.v_scale = self.v_scale - # rope and cache flush fusion + # rope and cache flush fusion. ATOM always use shuffle layout for kv cache result = self.rope_cache_plugin_mode(q=query, k=key, v=value, diff --git a/atom/plugin/config.py b/atom/plugin/config.py index 758180aff..cd1771d91 100644 --- a/atom/plugin/config.py +++ b/atom/plugin/config.py @@ -1,5 +1,4 @@ import os -from sre_parse import MAX_UNTIL import sys from typing import Any, Optional @@ -225,8 +224,10 @@ def generate_atom_config_for_plugin_mode(config: Any = None): elif is_sglang(): atom_config = _generate_atom_config_from_sglang_config(config) else: - raise ValueError("Make sure ATOM is running in plugin mode, \ - the function generate_atom_config_for_plugin_mode should be called in plugin mode") + raise ValueError( + "Make sure ATOM is running in plugin mode; " + "generate_atom_config_for_plugin_mode should be called in plugin mode." 
+ ) # set the current atom config for the custom model set_current_atom_config(atom_config) diff --git a/atom/plugin/moe.py b/atom/plugin/moe.py index 524faf0a1..1f7658675 100644 --- a/atom/plugin/moe.py +++ b/atom/plugin/moe.py @@ -17,7 +17,7 @@ def _apply_moe_decoration(original_cls: Type[T]) -> Type[T]: """ is_vllm_mode = is_vllm() if is_vllm_mode: - # rename this class because the vllm will call the modular + # Rename this class because vLLM will call the modular # kernel init method for all modules of the model, whose name is FusedMoE, # to init the inside kernel, while for plugin mode, the atom maintains # the kernel lifecycle by itself, so there is no need to call init on @@ -44,9 +44,7 @@ def get_decorated_class(): return decorated class LazyMoEWrapper(original_cls): - """Wrapper that defers decoration until first instantiation.""" - - def __new__(cls_wrapper, *args, **kwargs): + def __new__(cls, *args, **kwargs): decorated_cls = get_decorated_class() return decorated_cls(*args, **kwargs) @@ -56,5 +54,4 @@ def __new__(cls_wrapper, *args, **kwargs): LazyMoEWrapper.__module__ = original_cls.__module__ logger.info(f'Create lazy wrapper for FusedMoE to change the naming') - return LazyMoEWrapper diff --git a/atom/plugin/prepare.py b/atom/plugin/prepare.py index d6bb4caa3..ec11d07ca 100644 --- a/atom/plugin/prepare.py +++ b/atom/plugin/prepare.py @@ -15,15 +15,11 @@ def is_sglang() -> bool: global _CURRENT_FRAMEWORK - if _CURRENT_FRAMEWORK is None: - raise ValueError("_CURRENT_FRAMEWORK must be set before use") return bool(_CURRENT_FRAMEWORK.lower() in ["sglang", "sgl"]) def is_vllm() -> bool: global _CURRENT_FRAMEWORK - if _CURRENT_FRAMEWORK is None: - raise ValueError("_CURRENT_FRAMEWORK must be set before use") return bool(_CURRENT_FRAMEWORK.lower() in ["vllm"]) @@ -64,7 +60,11 @@ def prepare_model(config: Any, engine: str): ) if model_arch not in _ATOM_SUPPORTED_MODELS: - logger.warning(f"ATOM does not support the required model architecture: {model_arch}") 
+ supported_archs = list(_ATOM_SUPPORTED_MODELS.keys()) + raise ValueError( + f"ATOM does not support the required model architecture: {model_arch}. " + f"For now supported model architectures: {supported_archs}" + ) from atom.plugin.config import generate_atom_config_for_plugin_mode atom_config = generate_atom_config_for_plugin_mode(config) diff --git a/atom/plugin/register.py b/atom/plugin/register.py index 187501fd6..5e7584f02 100644 --- a/atom/plugin/register.py +++ b/atom/plugin/register.py @@ -1,4 +1,3 @@ -import torch import logging from atom.models.qwen3 import Qwen3ForCausalLM diff --git a/atom/utils/forward_context.py b/atom/utils/forward_context.py index 78c751c9a..9581f1e75 100644 --- a/atom/utils/forward_context.py +++ b/atom/utils/forward_context.py @@ -245,12 +245,10 @@ def __init__( self.block_tables = block_tables_converted if kv_indices_converted is not None: self.kv_indices = kv_indices_converted - if plugin_metadata is not None: - from atom.plugin.prepare import is_plugin_mode - assert is_plugin_mode(), "plugin_metadata is only supported for plugin mode" - self.plugin_metadata = plugin_metadata self.sparse_cu_seqlens_q = sparse_cu_seqlens_q self.token_to_seq_idxs = token_to_seq_idxs + if plugin_metadata is not None: + self.plugin_metadata = plugin_metadata def asdict_zerocopy(self, skip_fields: Optional[Set[str]] = None) -> Dict[str, Any]: """Similar to dataclasses.asdict, but avoids deepcopying.""" From 6563083255e37d50435df83a068a1adf1dcd4521 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Mon, 2 Feb 2026 16:38:40 +0800 Subject: [PATCH 03/37] add Signed-off-by: zejunchen-zejun --- atom/config.py | 5 +- atom/model_engine/model_runner.py | 4 +- atom/model_loader/loader.py | 26 +-- atom/model_ops/attention_mha.py | 38 ++-- atom/model_ops/attentions/aiter_attention.py | 14 +- atom/model_ops/base_attention.py | 18 +- atom/model_ops/paged_attention.py | 48 +++--- atom/model_ops/radix_attention.py | 59 ++++--- atom/models/qwen3.py | 42 +++-- 
atom/models/qwen3_moe.py | 37 ++-- atom/plugin/attention.py | 90 ++++++---- atom/plugin/attention_mha.py | 172 +++++++++++-------- atom/plugin/config.py | 38 ++-- atom/plugin/moe.py | 14 +- atom/plugin/prepare.py | 9 +- atom/plugin/register.py | 67 +++++--- atom/utils/backends.py | 1 + 17 files changed, 395 insertions(+), 287 deletions(-) diff --git a/atom/config.py b/atom/config.py index d5ed63c30..6a5b80725 100644 --- a/atom/config.py +++ b/atom/config.py @@ -610,7 +610,10 @@ def __post_init__(self): ), f"kv_cache_block_size ({self.kv_cache_block_size}) must be a multiple of 16 or 1" assert 1 <= self.tensor_parallel_size <= 8 if is_plugin_mode(): - assert self.plugin_config is not None, "plugin_config is required in plugin mode" + # plugin mode + assert ( + self.plugin_config is not None + ), "plugin_config is required in plugin mode" self.hf_config = self.plugin_config.model_config.hf_config else: self.hf_config = get_hf_config(self.model) diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index 455cde057..0c208c905 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -608,7 +608,9 @@ def __init__(self, rank: int, config: Config): self.drafter.load_model(self.model) torch.set_default_device(self.device) self.allocate_forward_vars() - self.attn_metadata_builder = self.attn_backend.get_builder_cls()(model_runner=self) + self.attn_metadata_builder = self.attn_backend.get_builder_cls()( + model_runner=self + ) self.physical_block_size = self.attn_metadata_builder.block_size self.forward_done_event = torch.cuda.Event() self.warmup_model() diff --git a/atom/model_loader/loader.py b/atom/model_loader/loader.py index c5d3a8977..f40de844f 100644 --- a/atom/model_loader/loader.py +++ b/atom/model_loader/loader.py @@ -95,26 +95,29 @@ def load_model_in_plugin_mode( # call empty cache to free the extra reserved but not used memory def _empty_cache(): import gc + gc.collect() torch.cuda.empty_cache() - 
assert config.plugin_config is not None and \ - config.plugin_config.is_plugin_mode, \ - "ATOM is not running in plugin mode" + assert ( + config.plugin_config is not None and config.plugin_config.is_plugin_mode + ), "ATOM is not running in plugin mode" if config.plugin_config.is_vllm: model_name_or_path = config.plugin_config.model_config.model elif config.plugin_config.is_sglang: model_name_or_path = config.plugin_config.model_config.model_path _empty_cache() - loaded_weights_record = load_model(model=model, - model_name_or_path=model_name_or_path, - hf_config=config.hf_config, - load_dummy=config.load_dummy, - spec_decode=False, - prefix=prefix, - is_plugin_mode=True, - act_dtype=config.plugin_config.model_config.dtype) + loaded_weights_record = load_model( + model=model, + model_name_or_path=model_name_or_path, + hf_config=config.hf_config, + load_dummy=config.load_dummy, + spec_decode=False, + prefix=prefix, + is_plugin_mode=True, + act_dtype=config.plugin_config.model_config.dtype, + ) _empty_cache() return loaded_weights_record @@ -252,6 +255,7 @@ def load_model( if hasattr(module, "process_weights_after_loading"): if is_vllm(): from vllm.attention.layer import Attention + # call vLLM attn weights post processing with act_dtype if using vLLM attention module if isinstance(module, Attention): module.process_weights_after_loading(act_dtype=act_dtype) diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py index ab443a7a2..b21055bfc 100644 --- a/atom/model_ops/attention_mha.py +++ b/atom/model_ops/attention_mha.py @@ -24,6 +24,7 @@ class PagedAttentionImpl(nn.Module): """ Attention paged implementation """ + def __init__( self, num_heads, @@ -34,7 +35,7 @@ def __init__( sliding_window: Optional[int] = None, kv_cache_dtype="bf16", logits_soft_cap: float | None = None, - attn_type = None, + attn_type=None, kv_sharing_target_layer_name: int | None = None, layer_num=0, mla_modules: Optional[MLAModules] = None, @@ -56,7 +57,7 @@ def __init__( 
self.kv_cache_dtype = kv_cache_dtype self.max_model_len = 0 self.k_scale = self.v_scale = None - self.device = 'cuda:' + str(torch.cuda.current_device()) + self.device = "cuda:" + str(torch.cuda.current_device()) self.layer_num = layer_num self.kv_scale_float = ( torch.finfo(torch.float8_e4m3fn).max / torch.finfo(aiter.dtypes.fp8).max @@ -445,9 +446,9 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor = None, - attn_metadata = None, + attn_metadata=None, position: torch.Tensor = None, - q_scale: Optional[torch.Tensor]=None, + q_scale: Optional[torch.Tensor] = None, qkv: torch.Tensor = None, output: torch.Tensor = None, **kwargs, @@ -455,22 +456,21 @@ def forward( if is_plugin_mode(): # forward impl method are added by the decorator # PagedAttentionImplDecoratorForPluginMode - return self.forward_impl_plugin_mode(layer=layer, - query=query, - key=key, - value=value, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - position=position, - q_scale=q_scale, - qkv=qkv) + return self.forward_impl_plugin_mode( + layer=layer, + query=query, + key=key, + value=value, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + position=position, + q_scale=q_scale, + qkv=qkv, + ) else: # only for server mode, keep the original method - o = self.forward_impl_server_mode(q=query, - k=key, - v=value, - position=position, - q_scale=q_scale, - qkv=qkv) + o = self.forward_impl_server_mode( + q=query, k=key, v=value, position=position, q_scale=q_scale, qkv=qkv + ) return o diff --git a/atom/model_ops/attentions/aiter_attention.py b/atom/model_ops/attentions/aiter_attention.py index ef69c40dc..8097a8738 100644 --- a/atom/model_ops/attentions/aiter_attention.py +++ b/atom/model_ops/attentions/aiter_attention.py @@ -41,17 +41,19 @@ def get_impl_cls(): raise NotImplementedError("RadixAttention is not supported for now") -@AiterAttentionMetadataBuilderDecoratorForPluginMode(default_base_class=CommonAttentionBuilder) 
+@AiterAttentionMetadataBuilderDecoratorForPluginMode( + default_base_class=CommonAttentionBuilder +) class AiterAttentionMetadataBuilder: BLOCK_TABLE_EXTENDER: list[list[int]] = [[]] def __init__( self, - kv_cache_spec = None, - layer_names = None, - config = None, - device = None, - model_runner = None, + kv_cache_spec=None, + layer_names=None, + config=None, + device=None, + model_runner=None, ): self.block_size = 1024 if model_runner.block_size == 1024 else 16 # Note: Cannot use super() here because the class is dynamically created by decorator diff --git a/atom/model_ops/base_attention.py b/atom/model_ops/base_attention.py index 544aa5fba..3391ba349 100644 --- a/atom/model_ops/base_attention.py +++ b/atom/model_ops/base_attention.py @@ -55,13 +55,16 @@ def unified_attention_with_output_base( if use_mla: return self.impl.forward(q, k, v, positions, q_scale, qkv) else: - return self.impl.forward(layer=self, - query=q, - key=k, - value=v, - position=positions, - q_scale=q_scale, - qkv=qkv) + return self.impl.forward( + layer=self, + query=q, + key=k, + value=v, + position=positions, + q_scale=q_scale, + qkv=qkv, + ) + class BaseAttention(nn.Module, ABC): """ @@ -101,4 +104,3 @@ def forward( raise NotImplementedError( f"{self.__class__.__name__} must implement the forward() method" ) - diff --git a/atom/model_ops/paged_attention.py b/atom/model_ops/paged_attention.py index 75ca9b314..cffb79d4f 100644 --- a/atom/model_ops/paged_attention.py +++ b/atom/model_ops/paged_attention.py @@ -14,10 +14,12 @@ from atom.plugin.prepare import is_sglang, is_vllm from atom.plugin.attention import unified_attention_with_output_base_for_plugin_mode + class PagedAttention(BaseAttention): """ Attention paged implementation """ + def __init__( self, num_heads, @@ -39,20 +41,24 @@ def __init__( ): # plugin mode(sglang) is not support paged attention # for now, only support plugin mode(vllm) and atom server mode - assert not is_sglang(), "PagedAttention is not supported for plugin 
mode(sglang) for now" - super().__init__(num_heads=num_heads, - head_dim=head_dim, - scale=scale, - num_kv_heads=num_kv_heads, - kv_cache_dtype=kv_cache_dtype, - layer_num=layer_num, - use_mla=use_mla, - mla_modules=mla_modules, - sinks=sinks, - per_layer_sliding_window=per_layer_sliding_window, - rotary_emb=rotary_emb, - prefix=prefix, - **kwargs) + assert ( + not is_sglang() + ), "PagedAttention is not supported for plugin mode(sglang) for now" + super().__init__( + num_heads=num_heads, + head_dim=head_dim, + scale=scale, + num_kv_heads=num_kv_heads, + kv_cache_dtype=kv_cache_dtype, + layer_num=layer_num, + use_mla=use_mla, + mla_modules=mla_modules, + sinks=sinks, + per_layer_sliding_window=per_layer_sliding_window, + rotary_emb=rotary_emb, + prefix=prefix, + **kwargs, + ) # for plugin mode if is_vllm(): @@ -60,7 +66,9 @@ def __init__( from vllm.attention.layer import Attention, AttentionType atom_config = get_current_atom_config() - assert atom_config is not None, "atom_config is required for plugin mode to vllm" + assert ( + atom_config is not None + ), "atom_config is required for plugin mode to vllm" # use vllm cache config and quant config to follow the convention of vllm cache_config = atom_config.plugin_config.vllm_cache_config @@ -70,10 +78,10 @@ def __init__( # while it only works for custom attention backend for vllm extra_impl_args = {} if atom_config.plugin_config.vllm_use_custom_attention: - extra_impl_args['sinks'] = sinks - extra_impl_args['rotary_emb'] = rotary_emb - extra_impl_args['q_norm'] = q_norm - extra_impl_args['k_norm'] = k_norm + extra_impl_args["sinks"] = sinks + extra_impl_args["rotary_emb"] = rotary_emb + extra_impl_args["q_norm"] = q_norm + extra_impl_args["k_norm"] = k_norm self.attn = Attention( num_heads=num_heads, @@ -153,7 +161,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, positions: torch.Tensor = None, - q_scale: Optional[torch.Tensor]=None, + q_scale: Optional[torch.Tensor] = None, qkv: torch.Tensor = None, 
**kwargs, ): diff --git a/atom/model_ops/radix_attention.py b/atom/model_ops/radix_attention.py index d2a10ff70..ce45201f0 100644 --- a/atom/model_ops/radix_attention.py +++ b/atom/model_ops/radix_attention.py @@ -15,6 +15,7 @@ class RadixAttention(BaseAttention): """ Attention radix implementation """ + def __init__( self, num_heads, @@ -31,22 +32,25 @@ def __init__( prefix: Optional[str] = None, **kwargs, ): - super().__init__(num_heads=num_heads, - head_dim=head_dim, - scale=scale, - num_kv_heads=num_kv_heads, - kv_cache_dtype=kv_cache_dtype, - layer_num=layer_num, - use_mla=use_mla, - mla_modules=mla_modules, - sinks=sinks, - per_layer_sliding_window=per_layer_sliding_window, - rotary_emb=rotary_emb, - prefix=prefix, - **kwargs) + super().__init__( + num_heads=num_heads, + head_dim=head_dim, + scale=scale, + num_kv_heads=num_kv_heads, + kv_cache_dtype=kv_cache_dtype, + layer_num=layer_num, + use_mla=use_mla, + mla_modules=mla_modules, + sinks=sinks, + per_layer_sliding_window=per_layer_sliding_window, + rotary_emb=rotary_emb, + prefix=prefix, + **kwargs, + ) if is_sglang(): from sglang.srt.layers.radix_attention import RadixAttention + self.attn = RadixAttention( num_heads=num_heads, head_dim=head_dim, @@ -56,29 +60,28 @@ def __init__( prefix=maybe_prefix(prefix, "attn"), ) else: - raise NotImplementedError("RadixAttention is only supported for plugin mode for sglang for now") + raise NotImplementedError( + "RadixAttention is only supported for plugin mode for sglang for now" + ) def forward_impl_plugin_mode( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attn_metadata = None, + attn_metadata=None, output: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, positions: torch.Tensor = None, - q_scale: torch.Tensor=None, + q_scale: torch.Tensor = None, **kwargs, ): if is_sglang(): # for sglang, forward_batch is required forward_batch = kwargs.get("forward_batch", None) assert 
forward_batch is not None, "forward_batch is required for sglang" - return self.attn(q=query, - k=key, - v=value, - forward_batch=forward_batch) + return self.attn(q=query, k=key, v=value, forward_batch=forward_batch) else: raise NotImplementedError( "RadixAttention is only supported for plugin mode for sglang for now" @@ -90,16 +93,18 @@ def forward( key: torch.Tensor, value: torch.Tensor, positions: torch.Tensor = None, - q_scale: Optional[torch.Tensor]=None, + q_scale: Optional[torch.Tensor] = None, **kwargs, ): if is_plugin_mode(): - o = self.forward_impl_plugin_mode(query=query, - key=key, - value=value, - positions=positions, - q_scale=q_scale, - **kwargs) + o = self.forward_impl_plugin_mode( + query=query, + key=key, + value=value, + positions=positions, + q_scale=q_scale, + **kwargs, + ) else: raise NotImplementedError( "RadixAttention is not supported for server mode for now" diff --git a/atom/models/qwen3.py b/atom/models/qwen3.py index 517bd9604..75c760ff7 100644 --- a/atom/models/qwen3.py +++ b/atom/models/qwen3.py @@ -145,7 +145,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config = None, + quant_config=None, prefix: str = "", ) -> None: super().__init__() @@ -225,9 +225,9 @@ def forward( hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm(hidden_states, residual) - hidden_states = self.self_attn(positions=positions, - hidden_states=hidden_states, - **model_kwargs) + hidden_states = self.self_attn( + positions=positions, hidden_states=hidden_states, **model_kwargs + ) hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @@ -270,10 +270,12 @@ def forward( hidden_states = self.embed_tokens(input_ids) residual = None for layer in self.layers: - hidden_states, residual = layer(positions=positions, - hidden_states=hidden_states, - residual=residual, - **model_kwargs) + 
hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + **model_kwargs, + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -296,10 +298,12 @@ def __init__(self, config: Any, prefix: str = "") -> None: atom_config=self.atom_config, prefix=maybe_prefix(prefix, "model") ) - self.lm_head = ParallelLMHead(num_embeddings=self.hf_config.vocab_size, - embedding_dim=self.hf_config.hidden_size, - bias=False, - prefix=maybe_prefix(prefix, "lm_head")) + self.lm_head = ParallelLMHead( + num_embeddings=self.hf_config.vocab_size, + embedding_dim=self.hf_config.hidden_size, + bias=False, + prefix=maybe_prefix(prefix, "lm_head"), + ) if self.hf_config.tie_word_embeddings: self.lm_head.weight.data = self.model.embed_tokens.weight.data @@ -307,13 +311,13 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors = None, + intermediate_tensors=None, inputs_embeds: torch.Tensor | None = None, **model_kwargs: dict[str, Any], ) -> torch.Tensor: - hidden_states = self.model(input_ids=input_ids, - positions=positions, - **model_kwargs) + hidden_states = self.model( + input_ids=input_ids, positions=positions, **model_kwargs + ) return hidden_states def compute_logits( @@ -325,7 +329,7 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # load weights in plugin mode and discard passed weights generator - loaded_weights_record = load_model_in_plugin_mode(model=self, - config=self.atom_config, - prefix="model.") + loaded_weights_record = load_model_in_plugin_mode( + model=self, config=self.atom_config, prefix="model." 
+ ) return loaded_weights_record diff --git a/atom/models/qwen3_moe.py b/atom/models/qwen3_moe.py index f509d2e51..c46c23778 100644 --- a/atom/models/qwen3_moe.py +++ b/atom/models/qwen3_moe.py @@ -237,7 +237,9 @@ def forward( q, k, v = torch.split( qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1 ) - attn_output = self.attn(query=q, key=k, value=v, positions=positions, q_scale=None, qkv=qkv) + attn_output = self.attn( + query=q, key=k, value=v, positions=positions, q_scale=None, qkv=qkv + ) else: # Add qk-norm q = self.q_norm(q) @@ -249,12 +251,7 @@ def forward( class Qwen3MoeDecoderLayer(nn.Module): - def __init__( - self, - atom_config = None, - layer_num: int = 0, - prefix: str = "" - ) -> None: + def __init__(self, atom_config=None, layer_num: int = 0, prefix: str = "") -> None: super().__init__() self.atom_config = atom_config @@ -294,7 +291,9 @@ def __init__( and (self.layer_idx + 1) % config.decoder_sparse_step == 0 ): self.mlp = Qwen3MoeSparseMoeBlock( - config, quant_config=self.atom_config.quant_config, prefix=f"{prefix}.mlp" + config, + quant_config=self.atom_config.quant_config, + prefix=f"{prefix}.mlp", ) else: self.mlp = Qwen3MoeMLP( @@ -375,7 +374,7 @@ def __init__( self.norm = RMSNorm( self.config.hidden_size, eps=self.config.rms_norm_eps, - fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION + fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION, ) else: self.norm = PPMissingLayer() @@ -407,7 +406,9 @@ def forward( residual = intermediate_tensors["residual"] for layer in self.layers[self.start_layer : self.end_layer]: - hidden_states, residual = layer(positions, hidden_states, residual, **model_kwargs) + hidden_states, residual = layer( + positions, hidden_states, residual, **model_kwargs + ) if not get_pp_group().is_last_rank: return IntermediateTensors( @@ -456,10 +457,12 @@ def __init__( ) if get_pp_group().is_last_rank: - self.lm_head = ParallelLMHead(num_embeddings=self.config.vocab_size, - embedding_dim=self.config.hidden_size, - bias=False, - 
prefix=maybe_prefix(prefix, "lm_head")) + self.lm_head = ParallelLMHead( + num_embeddings=self.config.vocab_size, + embedding_dim=self.config.hidden_size, + bias=False, + prefix=maybe_prefix(prefix, "lm_head"), + ) else: self.lm_head = PPMissingLayer() if self.config.tie_word_embeddings: @@ -515,7 +518,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # load weights in plugin mode and discard passed weights generator - loaded_weights_record = load_model_in_plugin_mode(model=self, - config=self.atom_config, - prefix="model.") + loaded_weights_record = load_model_in_plugin_mode( + model=self, config=self.atom_config, prefix="model." + ) return loaded_weights_record diff --git a/atom/plugin/attention.py b/atom/plugin/attention.py index 7d71454f6..37b5d46ae 100644 --- a/atom/plugin/attention.py +++ b/atom/plugin/attention.py @@ -15,6 +15,7 @@ _PARTITION_SIZE_ROCM = 256 _CP_TOKENS_PER_ITER_ROCM = 32 * 1024 + @dataclass class AiterFlashAttentionDecodeMetadata: max_query_len: int @@ -108,6 +109,7 @@ class MetadataForPluginMode: # k_scale: dict[str, torch.Tensor] | None # v_scale: dict[str, torch.Tensor] | None + class vllmAiterBackendMethods: # here attention in ATOM doesn't accept the output buffer because # ATOM works as a model impl backend, it needs the maximum freedom @@ -126,6 +128,7 @@ def __init__(self): @staticmethod def get_supported_kernel_block_sizes(): from vllm.v1.attention.backend import MultipleOf + return [MultipleOf(16)] @staticmethod @@ -162,9 +165,9 @@ def supports_alibi_sqrt(cls) -> bool: def AiterBackendDecoratorForPluginMode(cls): - ''' + """ Decorator for AiterBackend to add specific methods and attributes for plugin mode - ''' + """ is_vllm_mode = is_vllm() if is_vllm_mode: @@ -172,32 +175,35 @@ def AiterBackendDecoratorForPluginMode(cls): cls.full_cls_name = vllmAiterBackendMethods.full_cls_name cls.accept_output_buffer = 
vllmAiterBackendMethods.accept_output_buffer cls.supported_dtypes = vllmAiterBackendMethods.supported_dtypes - cls.get_supported_kernel_block_sizes = vllmAiterBackendMethods.get_supported_kernel_block_sizes + cls.get_supported_kernel_block_sizes = ( + vllmAiterBackendMethods.get_supported_kernel_block_sizes + ) cls.get_kv_cache_shape = vllmAiterBackendMethods.get_kv_cache_shape cls.is_mla = vllmAiterBackendMethods.is_mla - cls.get_required_kv_cache_layout = vllmAiterBackendMethods.get_required_kv_cache_layout + cls.get_required_kv_cache_layout = ( + vllmAiterBackendMethods.get_required_kv_cache_layout + ) cls.get_supported_head_sizes = vllmAiterBackendMethods.get_supported_head_sizes cls.supports_alibi_sqrt = vllmAiterBackendMethods.supports_alibi_sqrt return cls def create_attn_metadata_builder_init_method(base_class): - ''' + """ Create the init method for metadata builder - ''' - def init_method_under_plugin_mode(self, - kv_cache_spec=None, - layer_names=None, - config=None, - device=None, - model_runner=None): - base_class.__init__(self, - kv_cache_spec, - layer_names, - config, - device) + """ + + def init_method_under_plugin_mode( + self, + kv_cache_spec=None, + layer_names=None, + config=None, + device=None, + model_runner=None, + ): + base_class.__init__(self, kv_cache_spec, layer_names, config, device) logger.info(f"init AiterAttentionMetadataBuilder for plugin mode") - from vllm.config import VllmConfig,get_layers_from_vllm_config + from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.attention.layer import Attention assert isinstance(config, VllmConfig) @@ -226,9 +232,9 @@ def init_method_under_plugin_mode(self, while len(sliding_window_sizes) > 0: sliding_window_config = sliding_window_sizes.pop() if sliding_window_config is not None and sliding_window_config[0] != -1: - assert self.aot_sliding_window is None, ( - "Aiter Backend only support one valid sliding window" - ) + assert ( + self.aot_sliding_window is None + ), "Aiter 
Backend only support one valid sliding window" self.aot_sliding_window = sliding_window_config # for extend path to store the fetched key and value @@ -250,9 +256,9 @@ def init_method_under_plugin_mode(self, def setup_attn_metadata_builder_base_class_and_attributes(class_dict: dict): - ''' + """ Setup the base class and attributes for attention metadata builder - ''' + """ from vllm.v1.attention.backend import ( AttentionCGSupport, AttentionMetadataBuilder, @@ -263,8 +269,8 @@ def setup_attn_metadata_builder_base_class_and_attributes(class_dict: dict): needs_generic = True # align with vllm rocm aiter fa - class_dict['_cudagraph_support'] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE - class_dict['reorder_batch_threshold'] = 1 + class_dict["_cudagraph_support"] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + class_dict["reorder_batch_threshold"] = 1 return base_class, generic_base, needs_generic, class_dict @@ -279,7 +285,7 @@ def __init__(self): def build( self, common_prefix_len: int = 0, - common_attn_metadata = None, + common_attn_metadata=None, fast_build: bool = False, ): if common_prefix_len > 0: @@ -289,8 +295,7 @@ def build( # here assume the decode num token is 1 per request split_ret = split_decodes_prefills_and_extends( - common_attn_metadata=common_attn_metadata, - decode_threshold=1 + common_attn_metadata=common_attn_metadata, decode_threshold=1 ) ( @@ -380,6 +385,7 @@ def build( # each chunk prefill request max_context_chunk = _CP_TOKENS_PER_ITER_ROCM // num_extends from vllm.utils.math_utils import cdiv + num_chunks = cdiv(computed_kv_lens.max().item(), max_context_chunk) chunk_starts = ( @@ -448,7 +454,9 @@ def build( ) for idx in range(num_extends): - extend_start_seq_len = seq_lens_for_extend[idx] - query_lens_for_extend[idx] + extend_start_seq_len = ( + seq_lens_for_extend[idx] - query_lens_for_extend[idx] + ) extend_end_seq_len = seq_lens_for_extend[idx] for pos in range(extend_start_seq_len, extend_end_seq_len): positions.append(pos) @@ 
-483,7 +491,7 @@ def build( num_actual_tokens = common_attn_metadata.num_actual_tokens self.positions.np[:num_actual_tokens] = positions - context=Context( + context = Context( positions=self.positions.copy_to_gpu(num_actual_tokens), is_prefill=has_prefill, batch_size=context_batch_size, @@ -526,13 +534,15 @@ def build( # this method will be called by vllm, so it follows the vllm's interface convention def build_for_cudagraph_capture( self, - common_attn_metadata = None, + common_attn_metadata=None, ): self.total_tokens = ( self.model_config.max_model_len * self.vllm_config.scheduler_config.max_num_partial_prefills ) - attn_metadata = self.build(common_prefix_len=0, common_attn_metadata=common_attn_metadata) + attn_metadata = self.build( + common_prefix_len=0, common_attn_metadata=common_attn_metadata + ) self.total_tokens = 0 return attn_metadata @@ -546,7 +556,7 @@ def decorator(cls): class_dict = {} for key, value in cls.__dict__.items(): - if not key.startswith('__') or key in ('__annotations__',): + if not key.startswith("__") or key in ("__annotations__",): class_dict[key] = value # handle the generic base class @@ -555,20 +565,25 @@ def decorator(cls): if is_vllm_mode: # get the base class and generic base class - base_class, generic_base, needs_generic, class_dict = \ + base_class, generic_base, needs_generic, class_dict = ( setup_attn_metadata_builder_base_class_and_attributes(class_dict) + ) # replace the __init__ method to the decorated class - class_dict['__init__'] = create_attn_metadata_builder_init_method(base_class) + class_dict["__init__"] = create_attn_metadata_builder_init_method( + base_class + ) # add the methods to the decorated class for method_name in dir(vllmAttentionMetadataBuilderMethods): - if not method_name.startswith('_'): + if not method_name.startswith("_"): method = getattr(vllmAttentionMetadataBuilderMethods, method_name) if callable(method): class_dict[method_name] = method elif is_sglang_mode: - raise 
NotImplementedError("AttentionMetadataBuilder for sglang is not implemented yet") + raise NotImplementedError( + "AttentionMetadataBuilder for sglang is not implemented yet" + ) # create the new class new_class = type(cls.__name__, (base_class,), class_dict) @@ -598,6 +613,7 @@ def unified_attention_with_output_base_for_plugin_mode( ) -> torch.Tensor: from atom.config import get_current_atom_config from atom.utils import envs + atom_config = get_current_atom_config() if use_mla: raise NotImplementedError("MLA is not supported for plugin mode for now") diff --git a/atom/plugin/attention_mha.py b/atom/plugin/attention_mha.py index 35fbf8149..f4800a92a 100644 --- a/atom/plugin/attention_mha.py +++ b/atom/plugin/attention_mha.py @@ -14,6 +14,7 @@ from typing import TYPE_CHECKING import logging + logger = logging.getLogger("atom") if TYPE_CHECKING: @@ -21,7 +22,9 @@ from atom.utils import envs -ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION = envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION +ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION = ( + envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION +) _PARTITION_SIZE_ROCM = 256 _CP_TOKENS_PER_ITER_ROCM = 32 * 1024 @@ -40,18 +43,20 @@ def __init__(self): "It is only used as a method container for the decorator." 
) - def rope_cache_plugin_mode(self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - qkv: torch.Tensor, - position: torch.Tensor, - attention_metadata: "AttentionMetaData", - k_cache: torch.Tensor, - v_cache: torch.Tensor, - k_scale: torch.Tensor, - v_scale: torch.Tensor, - flash_layout: bool = False): + def rope_cache_plugin_mode( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + qkv: torch.Tensor, + position: torch.Tensor, + attention_metadata: "AttentionMetaData", + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + flash_layout: bool = False, + ): num_blocks, block_size, num_kv_heads, head_size = k_cache.shape @@ -84,7 +89,7 @@ def rope_cache_plugin_mode(self, # # key_cache: [num_blocks, num_kv_heads, head_size // x, block_size, x] # value_cache: [num_blocks, num_kv_heads, head_size, block_size] - # + # # and the origin kv cache layout in fwd_args is not flash attn_metadata = attention_metadata @@ -121,7 +126,8 @@ def rope_cache_plugin_mode(self, qkv = qkv.view(qkv.shape[0], -1, self.head_dim) q, k, v = qkv.split( - [self.num_heads, self.num_kv_heads, self.num_kv_heads], dim=1) + [self.num_heads, self.num_kv_heads, self.num_kv_heads], dim=1 + ) elif use_triton_attn and self.rotary_emb is not None: k_scale = v_scale = self.kv_scale @@ -195,14 +201,16 @@ def _get_cp_mha_gather_cache_views( return key_cache, value_cache, page_size return key_cache, value_cache, key_cache.shape[1] - def paged_attention_triton_plugin_mode(self, - q: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, - k_scale: torch.Tensor, - v_scale: torch.Tensor, - out: torch.Tensor, - attn_metadata: "AttentionMetaData"): + def paged_attention_triton_plugin_mode( + self, + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + out: torch.Tensor, + attn_metadata: "AttentionMetaData", + ): o = out num_seqs, num_q_heads_total, head_size = q.shape @@ 
-224,9 +232,7 @@ def paged_attention_triton_plugin_mode(self, max_context_partition_num, query_group_size, ) - exp_sums = torch.empty( - intermediate_shape, dtype=torch.float32, device=q.device - ) + exp_sums = torch.empty(intermediate_shape, dtype=torch.float32, device=q.device) max_logits = torch.empty( intermediate_shape, dtype=torch.float32, device=q.device ) @@ -243,7 +249,11 @@ def paged_attention_triton_plugin_mode(self, if not per_tensor: k_scale = k_scale.unsqueeze(-1) v_scale = v_scale.unsqueeze(-1) - compute_type = torch.bfloat16 if self.kv_cache_dtype == "bf16" or per_tensor else aiter.dtypes.fp8 + compute_type = ( + torch.bfloat16 + if self.kv_cache_dtype == "bf16" or per_tensor + else aiter.dtypes.fp8 + ) torch.ops.aiter.pa_decode_gluon( o, @@ -253,8 +263,8 @@ def paged_attention_triton_plugin_mode(self, attn_metadata.plugin_metadata.seq_lens, attn_metadata.block_tables, self.scale, - 1, # query_lenth - max_context_partition_num, + 1, # query_lenth + max_context_partition_num, context_partition_size, compute_type, None, @@ -271,16 +281,18 @@ def paged_attention_triton_plugin_mode(self, return o - def paged_attention_asm_plugin_mode(self, - q: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, - k_scale: torch.Tensor, - v_scale: torch.Tensor, - num_decodes: int, - num_decode_tokens: int, - attn_metadata: "AttentionMetaData", - out: torch.Tensor): + def paged_attention_asm_plugin_mode( + self, + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + num_decodes: int, + num_decode_tokens: int, + attn_metadata: "AttentionMetaData", + out: torch.Tensor, + ): aiter.pa_fwd_asm( Q=q, K=k_cache, @@ -312,8 +324,13 @@ def extend_for_sliding_window( v_scale: float, ): assert attn_metadata.plugin_metadata.extend_metadata is not None - assert attn_metadata.plugin_metadata.extend_metadata.chunk_context_metadata is not None - chunked_metadata = 
attn_metadata.plugin_metadata.extend_metadata.chunk_context_metadata + assert ( + attn_metadata.plugin_metadata.extend_metadata.chunk_context_metadata + is not None + ) + chunked_metadata = ( + attn_metadata.plugin_metadata.extend_metadata.chunk_context_metadata + ) swa_metadata = chunked_metadata.swa_metadata assert swa_metadata is not None swa_cu_seqlens = swa_metadata.swa_cu_seqlens @@ -327,6 +344,7 @@ def extend_for_sliding_window( ) from vllm.v1.attention.backends.rocm_aiter_fa import cp_mha_gather_cache + # key_cache_for_gather, value_cache_for_gather, _ = ( # self._get_cp_mha_gather_cache_views(key_cache, value_cache) # ) @@ -346,7 +364,11 @@ def extend_for_sliding_window( total_tokens=swa_total_tokens, ) - sliding_window = (self.sliding_window, 0, 0) if self.sliding_window is not None else (-1, -1, 0) + sliding_window = ( + (self.sliding_window, 0, 0) + if self.sliding_window is not None + else (-1, -1, 0) + ) aiter.flash_attn_varlen_func( q=query, k=key_fetched, @@ -416,7 +438,9 @@ def extend_forward( return_lse=True, ) assert attn_metadata.plugin_metadata.extend_metadata is not None - chunk_context_metadata = attn_metadata.plugin_metadata.extend_metadata.chunk_context_metadata + chunk_context_metadata = ( + attn_metadata.plugin_metadata.extend_metadata.chunk_context_metadata + ) num_chunks = chunk_context_metadata.num_chunks workspace = chunk_context_metadata.workspace cu_seqlens_kv = chunk_context_metadata.cu_seq_lens_chunk @@ -502,19 +526,19 @@ def forward_impl_plugin_mode( # create the output here, it use query shape num_tokens = query.shape[0] output_dtype = query.dtype - output_shape = torch.Size( - (num_tokens, self.num_heads * self.head_size) - ) + output_shape = torch.Size((num_tokens, self.num_heads * self.head_size)) output = torch.empty(output_shape, dtype=output_dtype, device=query.device) # dummy run will skip attention in cuda graph capture phase if attn_metadata is None: return output.fill_(0) - # when using this optimization, the qkv 
tensor and + # when using this optimization, the qkv tensor and # position tensor are passed through q,k,v if ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION: - assert position is None, "position should be None because it is passed through k" + assert ( + position is None + ), "position should be None because it is passed through k" position = key qkv = value @@ -561,18 +585,20 @@ def forward_impl_plugin_mode( layer.v_scale = self.v_scale # rope and cache flush fusion. ATOM always use shuffle layout for kv cache - result = self.rope_cache_plugin_mode(q=query, - k=key, - v=value, - qkv=qkv, - position=position, - attention_metadata=attn_metadata, - k_cache=k_cache, - v_cache=v_cache, - k_scale=self.k_scale, - v_scale=self.v_scale, - flash_layout=False) - (query, key, value, k_cache, v_cache, k_scale, v_scale) = result + result = self.rope_cache_plugin_mode( + q=query, + k=key, + v=value, + qkv=qkv, + position=position, + attention_metadata=attn_metadata, + k_cache=k_cache, + v_cache=v_cache, + k_scale=self.k_scale, + v_scale=self.v_scale, + flash_layout=False, + ) + query, key, value, k_cache, v_cache, k_scale, v_scale = result # The tokens are storaged as [decode:extend:prefill] order # which is decided by the vllm @@ -600,7 +626,11 @@ def forward_impl_plugin_mode( prefill_key = key[num_decode_tokens + num_extend_tokens :] prefill_value = value[num_decode_tokens + num_extend_tokens :] - sliding_window = (self.sliding_window, 0, 0) if self.sliding_window is not None else (-1, -1, 0) + sliding_window = ( + (self.sliding_window, 0, 0) + if self.sliding_window is not None + else (-1, -1, 0) + ) aiter.flash_attn_varlen_func( q=prefill_query, @@ -680,7 +710,7 @@ def forward_impl_plugin_mode( v_scale=v_scale, out=output_actual_tokens[:num_decode_tokens], attn_metadata=attn_metadata, - ) + ) else: # Qwen only uses gluon pa decode when bs=64 if num_decodes == 64: @@ -714,16 +744,18 @@ def forward_impl_plugin_mode( def PagedAttentionImplDecoratorForPluginMode(cls): method_names 
= [ - 'rope_cache_plugin_mode', - '_get_cp_mha_gather_cache_views', - 'paged_attention_triton_plugin_mode', - 'paged_attention_asm_plugin_mode', - 'extend_for_sliding_window', - 'extend_forward', - 'forward_impl_plugin_mode', + "rope_cache_plugin_mode", + "_get_cp_mha_gather_cache_views", + "paged_attention_triton_plugin_mode", + "paged_attention_asm_plugin_mode", + "extend_for_sliding_window", + "extend_forward", + "forward_impl_plugin_mode", ] - logger.info('Use PagedAttentionImplDecoratorForPluginMode to decorate PagedAttentionImpl') + logger.info( + "Use PagedAttentionImplDecoratorForPluginMode to decorate PagedAttentionImpl" + ) # Add all methods to the target class for method_name in method_names: diff --git a/atom/plugin/config.py b/atom/plugin/config.py index cd1771d91..9cbfd9272 100644 --- a/atom/plugin/config.py +++ b/atom/plugin/config.py @@ -10,6 +10,7 @@ logger = logging.getLogger("atom") _KNOWN_ISSUE_MAX_NUM_BATCHED_TOKENS_THRESHOLD = 18 * 1024 + @dataclass class PluginConfig: # common config for both framework @@ -42,12 +43,14 @@ def _generate_atom_config_from_vllm_config(config: Any) -> PluginConfig: vllm_scheduler_config = config.scheduler_config vllm_cache_config = config.cache_config vllm_parallel_config = config.parallel_config - vllm_use_custom_attention = bool(os.getenv("VLLM_ATTENTION_BACKEND", "None").lower() == "custom") + vllm_use_custom_attention = bool( + os.getenv("VLLM_ATTENTION_BACKEND", "None").lower() == "custom" + ) - # here use the ATOM compilation config, as the ATOM compile policy is used + # here use the ATOM compilation config, as the ATOM compile policy is used # instead of vLLM one for torch compile, while for cuda graph capture, # still use the vLLM - # when you don't want to use atom torch compile, you can also use + # when you don't want to use atom torch compile, you can also use # --enforce-eager to disable the atom torch compile when launch vllm server compilation_config = config.compilation_config 
vllm_compilation_config = CompilationConfig( @@ -75,17 +78,19 @@ def _generate_atom_config_from_vllm_config(config: Any) -> PluginConfig: # specific max_model_len = vllm_model_config.max_model_len - if hasattr(vllm_scheduler_config, 'max_model_len'): - max_model_len = vllm_scheduler_config.max_model_len + if hasattr(vllm_scheduler_config, "max_model_len"): + max_model_len = vllm_scheduler_config.max_model_len max_num_batched_tokens = vllm_scheduler_config.max_num_batched_tokens # FIXME: known issue for illegal mem access in fused_moe kernel if max_num_batched_tokens >= _KNOWN_ISSUE_MAX_NUM_BATCHED_TOKENS_THRESHOLD: - logger.warning("For plugin mode, when setting max_num_batched_tokens >= " + - f"{_KNOWN_ISSUE_MAX_NUM_BATCHED_TOKENS_THRESHOLD}, there is a known issue " + - "for illegal mem access in asm fused_moe kernel, if you met the issue, " + - "please set max_num_batched_tokens smaller or choose the ck fused_moe " + - "kernel instead of asm ones") + logger.warning( + "For plugin mode, when setting max_num_batched_tokens >= " + + f"{_KNOWN_ISSUE_MAX_NUM_BATCHED_TOKENS_THRESHOLD}, there is a known issue " + + "for illegal mem access in asm fused_moe kernel, if you met the issue, " + + "please set max_num_batched_tokens smaller or choose the ck fused_moe " + + "kernel instead of asm ones" + ) return Config( model=None, @@ -148,7 +153,7 @@ def _generate_atom_config_from_sglang_config(config: Any): rl_quant_profile=server_args.rl_quant_profile, ) - # sglang doesn't passed the rank number in config, so ATOM plugin + # sglang doesn't passed the rank number in config, so ATOM plugin # get rank number through the torch.distributed.get_rank() rank = torch.distributed.get_rank() @@ -206,19 +211,20 @@ def _generate_atom_config_from_sglang_config(config: Any): def generate_atom_config_for_plugin_mode(config: Any = None): - ''' + """ Generate the atom config in plugin mode, be called when create the custom model - config: + config: - for vllm: config is VllmConfig and 
contains all config value from vllm - for sglang: config is only model specific config passed from sglang, so the server args is used - ''' + """ - logger.info('Generate atom config for plugin mode from passed config') + logger.info("Generate atom config for plugin mode from passed config") atom_config = None from atom.plugin import is_vllm, is_sglang from atom.config import set_current_atom_config + if is_vllm(): atom_config = _generate_atom_config_from_vllm_config(config) elif is_sglang(): @@ -232,4 +238,4 @@ def generate_atom_config_for_plugin_mode(config: Any = None): # set the current atom config for the custom model set_current_atom_config(atom_config) - return atom_config \ No newline at end of file + return atom_config diff --git a/atom/plugin/moe.py b/atom/plugin/moe.py index 1f7658675..c2db374df 100644 --- a/atom/plugin/moe.py +++ b/atom/plugin/moe.py @@ -7,7 +7,7 @@ logger = logging.getLogger("atom") -T = TypeVar('T') +T = TypeVar("T") def _apply_moe_decoration(original_cls: Type[T]) -> Type[T]: @@ -17,7 +17,7 @@ def _apply_moe_decoration(original_cls: Type[T]) -> Type[T]: """ is_vllm_mode = is_vllm() if is_vllm_mode: - # Rename this class because vLLM will call the modular + # Rename this class because vLLM will call the modular # kernel init method for all modules of the model, whose name is FusedMoE, # to init the inside kernel, while for plugin mode, the atom maintains # the kernel lifecycle by itself, so there is no need to call init on @@ -33,14 +33,14 @@ def FusedMoEDecoratorForPluginMode(cls: Type[T]) -> Type[T]: Lazy decorator that defers class modification until first instantiation """ original_cls = cls - decorated_cls_cache = {'value': None} + decorated_cls_cache = {"value": None} def get_decorated_class(): - if decorated_cls_cache['value'] is not None: - return decorated_cls_cache['value'] + if decorated_cls_cache["value"] is not None: + return decorated_cls_cache["value"] decorated = _apply_moe_decoration(original_cls) - 
decorated_cls_cache['value'] = decorated + decorated_cls_cache["value"] = decorated return decorated class LazyMoEWrapper(original_cls): @@ -53,5 +53,5 @@ def __new__(cls, *args, **kwargs): LazyMoEWrapper.__qualname__ = original_cls.__qualname__ LazyMoEWrapper.__module__ = original_cls.__module__ - logger.info(f'Create lazy wrapper for FusedMoE to change the naming') + logger.info(f"Create lazy wrapper for FusedMoE to change the naming") return LazyMoEWrapper diff --git a/atom/plugin/prepare.py b/atom/plugin/prepare.py index ec11d07ca..29686e622 100644 --- a/atom/plugin/prepare.py +++ b/atom/plugin/prepare.py @@ -36,11 +36,11 @@ def _set_framework_backbone(framework: str) -> None: def prepare_model(config: Any, engine: str): - ''' + """ Prepare the model to upper framework, including register custom ops and init aiter dist - ''' - logging.info(f'Prepare model for plugin mode, the upper engine is {engine}') + """ + logging.info(f"Prepare model for plugin mode, the upper engine is {engine}") _set_framework_backbone(engine) @@ -67,10 +67,11 @@ def prepare_model(config: Any, engine: str): ) from atom.plugin.config import generate_atom_config_for_plugin_mode + atom_config = generate_atom_config_for_plugin_mode(config) model_cls = _ATOM_SUPPORTED_MODELS[model_arch] - logger.info(f'ATOM model class for {model_arch} is {model_cls}') + logger.info(f"ATOM model class for {model_arch} is {model_cls}") if is_vllm(): register_ops_to_vllm(atom_config=atom_config) diff --git a/atom/plugin/register.py b/atom/plugin/register.py index 5e7584f02..d8cf0d6fb 100644 --- a/atom/plugin/register.py +++ b/atom/plugin/register.py @@ -8,71 +8,84 @@ logger = logging.getLogger("atom") _ATOM_SUPPORTED_MODELS = { - "Qwen3ForCausalLM" : Qwen3ForCausalLM, - "Qwen3MoeForCausalLM" : Qwen3MoeForCausalLM, + "Qwen3ForCausalLM": Qwen3ForCausalLM, + "Qwen3MoeForCausalLM": Qwen3MoeForCausalLM, } def _register_custom_attention_to_vllm() -> None: - from vllm.v1.attention.backends.registry import 
register_backend, AttentionBackendEnum - logger.info('Register custom attention backend AiterBackend to vLLM') - register_backend(backend=AttentionBackendEnum.CUSTOM, - is_mamba=False, - class_path="atom.model_ops.attentions.aiter_attention.AiterBackend") + from vllm.v1.attention.backends.registry import ( + register_backend, + AttentionBackendEnum, + ) + + logger.info("Register custom attention backend AiterBackend to vLLM") + register_backend( + backend=AttentionBackendEnum.CUSTOM, + is_mamba=False, + class_path="atom.model_ops.attentions.aiter_attention.AiterBackend", + ) def _register_custom_attention_to_sglang() -> None: - from sglang.srt.layers.attention.attention_registry import register_attention_backend + from sglang.srt.layers.attention.attention_registry import ( + register_attention_backend, + ) # here register the custom attention backend with the name "aiter" # as sglang defines the fixed attention backend choices, which must be # in-tree - logger.info('Register custom attention backend AiterBackend to SGLang') + logger.info("Register custom attention backend AiterBackend to SGLang") @register_attention_backend("aiter") def create_atom_backend(runner): from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend + return AiterAttnBackend(runner) def register_ops_to_vllm(atom_config: Config) -> None: - ''' + """ Register custom ops to vllm, including attention - ''' + """ if atom_config.plugin_config.vllm_use_custom_attention: _register_custom_attention_to_vllm() else: - logger.warning("Please export VLLM_ATTENTION_BACKEND=CUSTOM to use atom attention") + logger.warning( + "Please export VLLM_ATTENTION_BACKEND=CUSTOM to use atom attention" + ) def register_ops_to_sglang(atom_config: Config) -> None: - ''' + """ Register custom ops to sglang, including attention - ''' + """ _register_custom_attention_to_sglang() def set_attn_cls() -> None: - ''' + """ Set the attention class for constructing the model based on the framework - ''' + """ import 
atom.model_ops as ops from atom.model_ops import PagedAttention, RadixAttention if is_vllm(): ops.ATTN_CLS = PagedAttention - logger.info('Set ATTN_CLS to PagedAttention for vLLM') + logger.info("Set ATTN_CLS to PagedAttention for vLLM") elif is_sglang(): ops.ATTN_CLS = RadixAttention - logger.info('Set ATTN_CLS to RadixAttention for SGLang') + logger.info("Set ATTN_CLS to RadixAttention for SGLang") def init_aiter_dist(config: Config) -> None: - ''' + """ Initialize aiter dist for using aiter custom collective op - ''' - logger.info('Initialize aiter dist for using aiter custom collective op for plugin mode') + """ + logger.info( + "Initialize aiter dist for using aiter custom collective op for plugin mode" + ) from aiter import init_dist_env from aiter.dist.utils import get_distributed_init_method @@ -80,21 +93,27 @@ def init_aiter_dist(config: Config) -> None: rank = config.plugin_config.rank tensor_parallel_size = config.tensor_parallel_size - assert config.plugin_config.is_plugin_mode, "Make sure ATOM is running in plugin mode" + assert ( + config.plugin_config.is_plugin_mode + ), "Make sure ATOM is running in plugin mode" if config.plugin_config.is_vllm: dp_master_ip = config.parallel_config.data_parallel_master_ip dp_master_port = config.parallel_config.data_parallel_master_port elif config.plugin_config.is_sglang: if config.plugin_config.sglang_dist_init_addr is not None: - dp_master_ip, dp_master_port = config.plugin_config.sglang_dist_init_addr.split(":") + dp_master_ip, dp_master_port = ( + config.plugin_config.sglang_dist_init_addr.split(":") + ) else: dp_master_ip = f"127.0.0.1" dp_master_port = config.plugin_config.sglang_port_args.nccl_port distributed_init_method = get_distributed_init_method(dp_master_ip, dp_master_port) - logger.info(f'Initialize aiter dist for using aiter custom collective op for plugin mode, rank:{rank}') + logger.info( + f"Initialize aiter dist for using aiter custom collective op for plugin mode, rank:{rank}" + ) 
init_dist_env( tensor_model_parallel_size=tensor_parallel_size, rankID=rank, diff --git a/atom/utils/backends.py b/atom/utils/backends.py index 9f606ab9a..57eec887a 100644 --- a/atom/utils/backends.py +++ b/atom/utils/backends.py @@ -266,6 +266,7 @@ def _split_judge_func(node: fx.Node) -> bool: # When plugin mode(vLLM), the attention impl op is registered # as unified_attention from atom.plugin import is_vllm + if is_vllm() and "unified_attention" in node.name: return True From 0f05699b972cf3ef7a57a74958055db87432bfc4 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Mon, 2 Feb 2026 17:08:35 +0800 Subject: [PATCH 04/37] add Signed-off-by: zejunchen-zejun --- atom/plugin/attention.py | 2 +- atom/plugin/attention_mha.py | 3 ++- atom/plugin/register.py | 5 ++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/atom/plugin/attention.py b/atom/plugin/attention.py index 37b5d46ae..b473a4050 100644 --- a/atom/plugin/attention.py +++ b/atom/plugin/attention.py @@ -479,7 +479,7 @@ def build( num_actual_kv_tokens = torch.sum(seq_lens).item() - use_cascade = common_prefix_len > 0 + use_cascade = False context_batch_size = 0 has_prefill = bool(num_prefills > 0 or num_extends > 0) diff --git a/atom/plugin/attention_mha.py b/atom/plugin/attention_mha.py index f4800a92a..ca0142e29 100644 --- a/atom/plugin/attention_mha.py +++ b/atom/plugin/attention_mha.py @@ -713,7 +713,8 @@ def forward_impl_plugin_mode( ) else: # Qwen only uses gluon pa decode when bs=64 - if num_decodes == 64: + QWEN_GLUON_PA_DECODE_BS = 64 + if num_decodes == QWEN_GLUON_PA_DECODE_BS: self.paged_attention_triton_plugin_mode( q=query[:num_decode_tokens], k_cache=new_key_cache, diff --git a/atom/plugin/register.py b/atom/plugin/register.py index d8cf0d6fb..22a2cf205 100644 --- a/atom/plugin/register.py +++ b/atom/plugin/register.py @@ -69,13 +69,12 @@ def set_attn_cls() -> None: Set the attention class for constructing the model based on the framework """ import atom.model_ops as ops - from 
atom.model_ops import PagedAttention, RadixAttention if is_vllm(): - ops.ATTN_CLS = PagedAttention + ops.ATTN_CLS = ops.PagedAttention logger.info("Set ATTN_CLS to PagedAttention for vLLM") elif is_sglang(): - ops.ATTN_CLS = RadixAttention + ops.ATTN_CLS = ops.RadixAttention logger.info("Set ATTN_CLS to RadixAttention for SGLang") From 7712e1f78916f7af78dea1b118f8aca1521b44bb Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Mon, 2 Feb 2026 17:44:33 +0800 Subject: [PATCH 05/37] format ruff Signed-off-by: zejunchen-zejun --- atom/__init__.py | 6 ++++++ atom/model_ops/__init__.py | 6 ++++++ atom/model_ops/base_attention.py | 1 - atom/plugin/__init__.py | 7 +++++++ atom/plugin/attention.py | 2 +- atom/plugin/attention_mha.py | 3 +-- atom/plugin/moe.py | 2 +- atom/plugin/register.py | 2 +- 8 files changed, 23 insertions(+), 6 deletions(-) diff --git a/atom/__init__.py b/atom/__init__.py index bcfab6841..dde5eeb84 100644 --- a/atom/__init__.py +++ b/atom/__init__.py @@ -6,3 +6,9 @@ # interface for upper framework to constructe the model from ATOM from atom.plugin import prepare_model + +__all__ = [ + "LLMEngine", + "SamplingParams", + "prepare_model", +] diff --git a/atom/model_ops/__init__.py b/atom/model_ops/__init__.py index fc42e0468..8bce2960e 100644 --- a/atom/model_ops/__init__.py +++ b/atom/model_ops/__init__.py @@ -6,3 +6,9 @@ # By default, PagedAttention is used. 
# For sglang, RadixAttention will be assigned to ATTN_CLS ATTN_CLS = PagedAttention + +__all__ = [ + "ATTN_CLS", + "PagedAttention", + "RadixAttention", +] diff --git a/atom/model_ops/base_attention.py b/atom/model_ops/base_attention.py index 3391ba349..4a9fd1720 100644 --- a/atom/model_ops/base_attention.py +++ b/atom/model_ops/base_attention.py @@ -12,7 +12,6 @@ from atom.utils import mark_spliting_op from .attention_mla import MLAModules from atom.config import get_current_atom_config -from atom.utils.selector import get_attn_backend def fake_( diff --git a/atom/plugin/__init__.py b/atom/plugin/__init__.py index fc6671481..27c855e51 100644 --- a/atom/plugin/__init__.py +++ b/atom/plugin/__init__.py @@ -4,3 +4,10 @@ is_vllm, is_plugin_mode, ) + +__all__ = [ + "prepare_model", + "is_sglang", + "is_vllm", + "is_plugin_mode", +] diff --git a/atom/plugin/attention.py b/atom/plugin/attention.py index b473a4050..151569744 100644 --- a/atom/plugin/attention.py +++ b/atom/plugin/attention.py @@ -202,7 +202,7 @@ def init_method_under_plugin_mode( model_runner=None, ): base_class.__init__(self, kv_cache_spec, layer_names, config, device) - logger.info(f"init AiterAttentionMetadataBuilder for plugin mode") + logger.info("init AiterAttentionMetadataBuilder for plugin mode") from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.attention.layer import Attention diff --git a/atom/plugin/attention_mha.py b/atom/plugin/attention_mha.py index ca0142e29..8d81d8da0 100644 --- a/atom/plugin/attention_mha.py +++ b/atom/plugin/attention_mha.py @@ -12,6 +12,7 @@ from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache from aiter.ops.triton.gluon.pa_decode_gluon import get_recommended_splits from typing import TYPE_CHECKING +from atom.utils import envs import logging @@ -20,8 +21,6 @@ if TYPE_CHECKING: from atom.utils.forward_context import AttentionMetaData -from atom.utils import envs - ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION = ( 
envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION ) diff --git a/atom/plugin/moe.py b/atom/plugin/moe.py index c2db374df..f0ed21ff0 100644 --- a/atom/plugin/moe.py +++ b/atom/plugin/moe.py @@ -53,5 +53,5 @@ def __new__(cls, *args, **kwargs): LazyMoEWrapper.__qualname__ = original_cls.__qualname__ LazyMoEWrapper.__module__ = original_cls.__module__ - logger.info(f"Create lazy wrapper for FusedMoE to change the naming") + logger.info("Create lazy wrapper for FusedMoE to change the naming") return LazyMoEWrapper diff --git a/atom/plugin/register.py b/atom/plugin/register.py index 22a2cf205..88b4ac87c 100644 --- a/atom/plugin/register.py +++ b/atom/plugin/register.py @@ -105,7 +105,7 @@ def init_aiter_dist(config: Config) -> None: config.plugin_config.sglang_dist_init_addr.split(":") ) else: - dp_master_ip = f"127.0.0.1" + dp_master_ip = "127.0.0.1" dp_master_port = config.plugin_config.sglang_port_args.nccl_port distributed_init_method = get_distributed_init_method(dp_master_ip, dp_master_port) From c203883e3620635cf126e8dbeebaf80e8fe18e73 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Mon, 2 Feb 2026 17:48:27 +0800 Subject: [PATCH 06/37] ruff format Signed-off-by: zejunchen-zejun --- atom/utils/forward_context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/atom/utils/forward_context.py b/atom/utils/forward_context.py index 9581f1e75..df0cba55e 100644 --- a/atom/utils/forward_context.py +++ b/atom/utils/forward_context.py @@ -3,7 +3,7 @@ from contextlib import contextmanager from dataclasses import dataclass, field, fields -from typing import Any, Dict, Optional, Set, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Union import numpy as np import torch From 6bbaadea12942f9988abf4a97b04e6b83ab0a485 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Mon, 2 Feb 2026 20:05:56 +0800 Subject: [PATCH 07/37] add Signed-off-by: zejunchen-zejun --- atom/plugin/attention.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 
deletions(-) diff --git a/atom/plugin/attention.py b/atom/plugin/attention.py index 151569744..dee2cc208 100644 --- a/atom/plugin/attention.py +++ b/atom/plugin/attention.py @@ -97,18 +97,12 @@ class MetadataForPluginMode: prefill_metadata: Optional[AiterFlashAttentionPrefillMetadata] = None extend_metadata: Optional[AiterFlashAttentionChunkPrefillMetadata] = None - # For cascade attention. - use_cascade: bool - common_prefix_len: int - total_tokens: int + use_cascade: bool = False + common_prefix_len: int = 0 + total_tokens: int = 0 context: Optional[Context] = None - # # Only for fp8 shuffle layout kv cache, we allocate kv_scale for each layer - # # since we might integrate per token quant for kv cache in the future. - # k_scale: dict[str, torch.Tensor] | None - # v_scale: dict[str, torch.Tensor] | None - class vllmAiterBackendMethods: # here attention in ATOM doesn't accept the output buffer because From a66944157d18025b7e1ff827f2249f42e478f7b8 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Mon, 2 Feb 2026 21:47:49 +0800 Subject: [PATCH 08/37] add Signed-off-by: zejunchen-zejun --- atom/model_ops/attention_mha.py | 2 +- atom/model_ops/attentions/aiter_attention.py | 8 ++++++-- atom/model_ops/paged_attention.py | 2 +- atom/plugin/attention.py | 8 +++++++- atom/plugin/attention_mha.py | 8 ++------ atom/plugin/prepare.py | 2 +- 6 files changed, 18 insertions(+), 12 deletions(-) diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py index b21055bfc..6626b908b 100644 --- a/atom/model_ops/attention_mha.py +++ b/atom/model_ops/attention_mha.py @@ -76,7 +76,7 @@ def __init__( self.supports_quant_query_input = False # this method will just be called by vLLM and there is no logic in this method - # as ATOM handle the process after loading weights for all ops by itself + # as ATOM handles the process after loading weights for all ops by itself def process_weights_after_loading(self, act_dtype: torch.dtype = torch.bfloat16): pass diff --git 
a/atom/model_ops/attentions/aiter_attention.py b/atom/model_ops/attentions/aiter_attention.py index 8097a8738..0489e3b21 100644 --- a/atom/model_ops/attentions/aiter_attention.py +++ b/atom/model_ops/attentions/aiter_attention.py @@ -35,10 +35,14 @@ def get_builder_cls() -> Type["AiterAttentionMetadataBuilder"]: @staticmethod def get_impl_cls(): - if ops.ATTN_CLS == PagedAttention: + attn_cls = ops.ATTN_CLS + if attn_cls == PagedAttention: return PagedAttentionImpl - elif ops.ATTN_CLS == RadixAttention: + elif attn_cls == RadixAttention: raise NotImplementedError("RadixAttention is not supported for now") + raise NotImplementedError( + f"Unsupported attention class {attn_cls!r} configured in ops.ATTN_CLS" + ) @AiterAttentionMetadataBuilderDecoratorForPluginMode( diff --git a/atom/model_ops/paged_attention.py b/atom/model_ops/paged_attention.py index cffb79d4f..4f90f8648 100644 --- a/atom/model_ops/paged_attention.py +++ b/atom/model_ops/paged_attention.py @@ -74,7 +74,7 @@ def __init__( cache_config = atom_config.plugin_config.vllm_cache_config quant_config = atom_config.plugin_config.vllm_quant_config - # add exter impl args, which are needed to be passed to the impl class + # add extra impl args, which are needed to be passed to the impl class # while it only works for custom attention backend for vllm extra_impl_args = {} if atom_config.plugin_config.vllm_use_custom_attention: diff --git a/atom/plugin/attention.py b/atom/plugin/attention.py index dee2cc208..3c1ed46da 100644 --- a/atom/plugin/attention.py +++ b/atom/plugin/attention.py @@ -546,6 +546,12 @@ def decorator(cls): is_vllm_mode = is_vllm() is_sglang_mode = is_sglang() + # Outside of plugin integrations (vLLM / SGLang), we should keep the + # original class intact. In ATOM server mode, the metadata builder + # defines its own __init__ method. 
+ if not is_vllm_mode and not is_sglang_mode: + return cls + base_class = default_base_class class_dict = {} @@ -563,7 +569,7 @@ def decorator(cls): setup_attn_metadata_builder_base_class_and_attributes(class_dict) ) - # replace the __init__ method to the decorated class + # replace the __init__ method in the decorated class class_dict["__init__"] = create_attn_metadata_builder_init_method( base_class ) diff --git a/atom/plugin/attention_mha.py b/atom/plugin/attention_mha.py index 8d81d8da0..b0dddc454 100644 --- a/atom/plugin/attention_mha.py +++ b/atom/plugin/attention_mha.py @@ -27,6 +27,7 @@ _PARTITION_SIZE_ROCM = 256 _CP_TOKENS_PER_ITER_ROCM = 32 * 1024 +_QWEN_GLUON_PA_DECODE_BS = 64 class PagedAttentionImplPluginModeMethods: @@ -343,10 +344,6 @@ def extend_for_sliding_window( ) from vllm.v1.attention.backends.rocm_aiter_fa import cp_mha_gather_cache - - # key_cache_for_gather, value_cache_for_gather, _ = ( - # self._get_cp_mha_gather_cache_views(key_cache, value_cache) - # ) cp_mha_gather_cache( key_cache=key_cache, value_cache=value_cache, @@ -712,8 +709,7 @@ def forward_impl_plugin_mode( ) else: # Qwen only uses gluon pa decode when bs=64 - QWEN_GLUON_PA_DECODE_BS = 64 - if num_decodes == QWEN_GLUON_PA_DECODE_BS: + if num_decodes == _QWEN_GLUON_PA_DECODE_BS: self.paged_attention_triton_plugin_mode( q=query[:num_decode_tokens], k_cache=new_key_cache, diff --git a/atom/plugin/prepare.py b/atom/plugin/prepare.py index 29686e622..c7e13a3ce 100644 --- a/atom/plugin/prepare.py +++ b/atom/plugin/prepare.py @@ -40,7 +40,7 @@ def prepare_model(config: Any, engine: str): Prepare the model to upper framework, including register custom ops and init aiter dist """ - logging.info(f"Prepare model for plugin mode, the upper engine is {engine}") + logger.info(f"Prepare model for plugin mode, the upper engine is {engine}") _set_framework_backbone(engine) From 2f8e6ee518508c89e0b6c6606edc669769c57a1a Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Mon, 2 Feb 2026 22:11:06 
+0800 Subject: [PATCH 09/37] add Signed-off-by: zejunchen-zejun --- atom/plugin/attention.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/atom/plugin/attention.py b/atom/plugin/attention.py index 3c1ed46da..2214fb9c5 100644 --- a/atom/plugin/attention.py +++ b/atom/plugin/attention.py @@ -546,17 +546,15 @@ def decorator(cls): is_vllm_mode = is_vllm() is_sglang_mode = is_sglang() - # Outside of plugin integrations (vLLM / SGLang), we should keep the - # original class intact. In ATOM server mode, the metadata builder - # defines its own __init__ method. - if not is_vllm_mode and not is_sglang_mode: - return cls - base_class = default_base_class class_dict = {} + # record original decorated cls methods for key, value in cls.__dict__.items(): - if not key.startswith("__") or key in ("__annotations__",): + if ( + not key.startswith("__") + or key in ("__annotations__", "__init__", "__module__", "__qualname__", "__doc__") + ): class_dict[key] = value # handle the generic base class From 113e587d6f71a780e6f44888db99a37c66665798 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Mon, 2 Feb 2026 22:52:15 +0800 Subject: [PATCH 10/37] add Signed-off-by: zejunchen-zejun --- atom/plugin/attention.py | 9 ++++++--- atom/plugin/attention_mha.py | 1 + 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/atom/plugin/attention.py b/atom/plugin/attention.py index 2214fb9c5..c574039fa 100644 --- a/atom/plugin/attention.py +++ b/atom/plugin/attention.py @@ -551,9 +551,12 @@ def decorator(cls): # record original decorated cls methods for key, value in cls.__dict__.items(): - if ( - not key.startswith("__") - or key in ("__annotations__", "__init__", "__module__", "__qualname__", "__doc__") + if not key.startswith("__") or key in ( + "__annotations__", + "__init__", + "__module__", + "__qualname__", + "__doc__", ): class_dict[key] = value diff --git a/atom/plugin/attention_mha.py b/atom/plugin/attention_mha.py index b0dddc454..09a83cd41 100644 
--- a/atom/plugin/attention_mha.py +++ b/atom/plugin/attention_mha.py @@ -344,6 +344,7 @@ def extend_for_sliding_window( ) from vllm.v1.attention.backends.rocm_aiter_fa import cp_mha_gather_cache + cp_mha_gather_cache( key_cache=key_cache, value_cache=value_cache, From de036e863fbae745a4dcb50f2263351e23ed9596 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Tue, 3 Feb 2026 10:58:16 +0800 Subject: [PATCH 11/37] add Signed-off-by: zejunchen-zejun --- atom/plugin/config.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/atom/plugin/config.py b/atom/plugin/config.py index 9cbfd9272..40bdbc680 100644 --- a/atom/plugin/config.py +++ b/atom/plugin/config.py @@ -49,7 +49,7 @@ def _generate_atom_config_from_vllm_config(config: Any) -> PluginConfig: # here use the ATOM compilation config, as the ATOM compile policy is used # instead of vLLM one for torch compile, while for cuda graph capture, - # still use the vLLM + # still use the vLLM because it has FULL_AND_PIECEWISE feature # when you don't want to use atom torch compile, you can also use # --enforce-eager to disable the atom torch compile when launch vllm server compilation_config = config.compilation_config @@ -99,7 +99,7 @@ def _generate_atom_config_from_vllm_config(config: Any) -> PluginConfig: max_model_len=max_model_len, gpu_memory_utilization=vllm_cache_config.gpu_memory_utilization, tensor_parallel_size=vllm_parallel_config.tensor_parallel_size, - enforce_eager=vllm_model_config.enforce_eager, + enforce_eager=True, # disable using atom cuda graph parallel_config=vllm_parallel_config, kv_cache_block_size=vllm_cache_config.block_size, num_kvcache_blocks=vllm_cache_config.num_gpu_blocks, @@ -162,12 +162,14 @@ def _generate_atom_config_from_sglang_config(config: Any): data_parallel_size=server_args.dp_size, ) - # create the compilation config static_forward_context for sglang - sgl_compilation_config = CompilationConfig() - # for now disable using atom torch compile as 
sglang uses its own torch compile - sgl_compilation_config.level = 0 - sgl_compilation_config.use_cudagraph = False - sgl_compilation_config.static_forward_context = {} + # use sglang torch compile policy and cuda graph policy + # because sglang doesn't use the compile decorator for model, + # we have no method to define self policy + sgl_compilation_config = CompilationConfig( + level=0, + use_cudagraph=False, + cudagraph_mode=None, + ) plugin_config = PluginConfig( # common config @@ -186,7 +188,8 @@ def _generate_atom_config_from_sglang_config(config: Any): sglang_port_args=PortArgs.init_new(server_args), ) - # TODO: sgl doesn't have max num batched tokens, so force to 16k + # force max num batched tokens to 16K because sgl doesn't have + # concept for max num batched tokens return Config( model=None, max_num_batched_tokens=16384, @@ -194,7 +197,7 @@ def _generate_atom_config_from_sglang_config(config: Any): max_model_len=server_args.context_length, gpu_memory_utilization=server_args.mem_fraction_static, tensor_parallel_size=server_args.tp_size, - enforce_eager=not server_args.enable_torch_compile, + enforce_eager=True, # disable using atom cuda graph parallel_config=sgl_parallel_config, kv_cache_dtype=server_args.kv_cache_dtype, enable_prefix_caching=False, From f0f0c94a8753adaad74456e1b2740b8cdf3cee01 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Tue, 3 Feb 2026 11:57:09 +0800 Subject: [PATCH 12/37] add Signed-off-by: zejunchen-zejun --- atom/plugin/attention_mha.py | 173 +++++++++++++++++++++++++++++++---- atom/plugin/config.py | 6 +- 2 files changed, 156 insertions(+), 23 deletions(-) diff --git a/atom/plugin/attention_mha.py b/atom/plugin/attention_mha.py index 09a83cd41..9129657c2 100644 --- a/atom/plugin/attention_mha.py +++ b/atom/plugin/attention_mha.py @@ -11,6 +11,8 @@ from aiter import dtypes, fused_qk_norm_rope_cache_quant_shuffle from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache from 
aiter.ops.triton.gluon.pa_decode_gluon import get_recommended_splits +import triton +import triton.language as tl from typing import TYPE_CHECKING from atom.utils import envs @@ -30,6 +32,157 @@ _QWEN_GLUON_PA_DECODE_BS = 64 +@triton.jit +def cp_mha_gather_cache_kernel( + key_cache_ptr, # [num_blocks, page_size, num_head, head_size] + value_cache_ptr, # [num_blocks, page_size, num_head, head_size] + key_ptr, # [num_tokens, num_heads, head_size] + value_ptr, # [num_tokens, num_heads, head_size] + block_table_ptr, # [num_batches, max_block_num] + cu_seqlens_kv_ptr, # [num_batches + 1] + token_to_batch_ptr, # [max_cum_tokens] + seq_start_ptr, # [num_batches] + k_scale_ptr, # [1] / [num_blocks, num_kv_heads, page_size] + v_scale_ptr, + num_heads, + head_size, + x, + max_block_num, + DEQUANT: tl.constexpr, + PAGE_SIZE: tl.constexpr, + CACHE_FORMAT: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + token_id = tl.program_id(0) + head_id = tl.program_id(1) + col_offsets = tl.arange(0, BLOCK_SIZE) + + key_ptr_offset = key_ptr + token_id * head_size * num_heads + head_id * head_size + value_ptr_offset = ( + value_ptr + token_id * head_size * num_heads + head_id * head_size + ) + batch_idx = tl.load(token_to_batch_ptr + token_id) + batch_start = tl.load(seq_start_ptr + batch_idx) + token_start = tl.load(cu_seqlens_kv_ptr + batch_idx) + batch_offset = token_id - token_start + batch_start + block_offset = batch_offset // PAGE_SIZE + block_id = tl.load(block_table_ptr + max_block_num * batch_idx + block_offset).to( + tl.int64 + ) + slot_id = batch_offset % PAGE_SIZE + + if CACHE_FORMAT == "NHD": + # for kv cache layout as + # K: [num_blocks, page_size, num_head, head_dim] + # V: [num_blocks, page_size, num_head, head_dim] + key_cache_ptr_offset = ( + key_cache_ptr + + block_id * num_heads * head_size * PAGE_SIZE + + slot_id * num_heads * head_size + + head_id * head_size + ) + value_cache_ptr_offset = ( + value_cache_ptr + + block_id * num_heads * head_size * PAGE_SIZE + + slot_id 
* num_heads * head_size + + head_id * head_size + ) + k_reg = tl.load(key_cache_ptr_offset + col_offsets) + v_reg = tl.load(value_cache_ptr_offset + col_offsets) + if DEQUANT: + k_scale = tl.load(k_scale_ptr) + v_scale = tl.load(v_scale_ptr) + k_dtype = k_reg.dtype + v_dtype = v_reg.dtype + k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype) + v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype) + tl.store(key_ptr_offset + col_offsets, k_reg) + tl.store(value_ptr_offset + col_offsets, v_reg) + + elif CACHE_FORMAT == "SHUFFLE": + # for kv cache layout as + # K: [num_blocks, num_head, head_dim // x, page_size, x] + # V: [num_blocks, num_head, page_size // x, head_dim, x] + key_cache_ptr_offset = ( + key_cache_ptr + + block_id * num_heads * head_size * PAGE_SIZE + + head_id * head_size * PAGE_SIZE + + slot_id * x + ) + value_cache_ptr_offset = ( + value_cache_ptr + + block_id * num_heads * head_size * PAGE_SIZE + + head_id * head_size * PAGE_SIZE + + (slot_id // x) * head_size * x + + slot_id % x + ) + k_reg_offset = col_offsets // x * PAGE_SIZE * x + col_offsets % x + v_reg_offset = col_offsets * x + k_reg = tl.load(key_cache_ptr_offset + k_reg_offset) + v_reg = tl.load(value_cache_ptr_offset + v_reg_offset) + if DEQUANT: + k_scale = 1.0 + v_scale = 1.0 + k_reg = k_reg.to(tl.float32) * k_scale + v_reg = v_reg.to(tl.float32) * v_scale + tl.store(key_ptr_offset + col_offsets, k_reg) + tl.store(value_ptr_offset + col_offsets, v_reg) + + +def cp_mha_gather_cache( + key_cache: torch.Tensor, + value_cache: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + block_tables: torch.Tensor, + k_scales: torch.Tensor, + v_scales: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + token_to_batch: torch.Tensor, + seq_starts: torch.Tensor, + dequant: bool, + kv_cache_layout: str, + total_tokens: int, +): + assert kv_cache_layout in [ + "NHD", + "SHUFFLE", + ], "kv_cache_layout only support NHD, SHUFFLE" + head_dim = key.shape[2] + x = 16 // key_cache.element_size() + # assert 
dequant is True, "Currently, we only support "\ + # "gather cache with dequant" + # For k cache layout: [num_blocks, num_heads, page_size, head_dim] + assert head_dim == key_cache.shape[3], ( + "We assume your kv cache layout is [num_blocks, " + "page_size, num_heads, head_dim], but got otherwise" + ) + page_size = key_cache.shape[1] + num_heads = key_cache.shape[2] + + grid = lambda meta: (total_tokens, num_heads) + cp_mha_gather_cache_kernel[grid]( + key_cache, + value_cache, + key, + value, + block_tables, + cu_seqlens_kv, + token_to_batch, + seq_starts, + k_scales, + v_scales, + num_heads, + head_dim, + x, + block_tables.size(1), + DEQUANT=dequant, + PAGE_SIZE=page_size, + CACHE_FORMAT=kv_cache_layout, + BLOCK_SIZE=head_dim, + ) + + class PagedAttentionImplPluginModeMethods: """ Container class for plugin mode methods. @@ -185,22 +338,6 @@ def rope_cache_plugin_mode( return q, k, v, k_cache, v_cache, k_scale, v_scale - def _get_cp_mha_gather_cache_views( - self, key_cache: torch.Tensor, value_cache: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor, int]: - # For SHUFFLE layout, the wrapper derives PAGE_SIZE/num_heads from - # tensor shapes; provide a reshape-only view to keep storage unchanged. 
- if key_cache.ndim == 5: - num_blocks = key_cache.shape[0] - num_heads = key_cache.shape[1] - page_size = key_cache.shape[3] - x = key_cache.shape[4] - head_size = key_cache.shape[2] * x - key_cache = key_cache.view(num_blocks, page_size, num_heads, head_size) - value_cache = value_cache.view(num_blocks, page_size, num_heads, head_size) - return key_cache, value_cache, page_size - return key_cache, value_cache, key_cache.shape[1] - def paged_attention_triton_plugin_mode( self, q: torch.Tensor, @@ -343,8 +480,6 @@ def extend_for_sliding_window( swa_metadata.swa_workspace[1], ) - from vllm.v1.attention.backends.rocm_aiter_fa import cp_mha_gather_cache - cp_mha_gather_cache( key_cache=key_cache, value_cache=value_cache, @@ -403,7 +538,6 @@ def extend_forward( v_scale: torch.Tensor, ): from vllm.v1.attention.ops.merge_attn_states import merge_attn_states - from vllm.v1.attention.backends.rocm_aiter_fa import cp_mha_gather_cache if self.sliding_window != -1: self.extend_for_sliding_window( @@ -742,7 +876,6 @@ def PagedAttentionImplDecoratorForPluginMode(cls): method_names = [ "rope_cache_plugin_mode", - "_get_cp_mha_gather_cache_views", "paged_attention_triton_plugin_mode", "paged_attention_asm_plugin_mode", "extend_for_sliding_window", diff --git a/atom/plugin/config.py b/atom/plugin/config.py index 40bdbc680..09ff01b1f 100644 --- a/atom/plugin/config.py +++ b/atom/plugin/config.py @@ -99,7 +99,7 @@ def _generate_atom_config_from_vllm_config(config: Any) -> PluginConfig: max_model_len=max_model_len, gpu_memory_utilization=vllm_cache_config.gpu_memory_utilization, tensor_parallel_size=vllm_parallel_config.tensor_parallel_size, - enforce_eager=True, # disable using atom cuda graph + enforce_eager=True, # disable using atom cuda graph parallel_config=vllm_parallel_config, kv_cache_block_size=vllm_cache_config.block_size, num_kvcache_blocks=vllm_cache_config.num_gpu_blocks, @@ -188,7 +188,7 @@ def _generate_atom_config_from_sglang_config(config: Any): 
sglang_port_args=PortArgs.init_new(server_args), ) - # force max num batched tokens to 16K because sgl doesn't have + # force max num batched tokens to 16K because sgl doesn't have # concept for max num batched tokens return Config( model=None, @@ -197,7 +197,7 @@ def _generate_atom_config_from_sglang_config(config: Any): max_model_len=server_args.context_length, gpu_memory_utilization=server_args.mem_fraction_static, tensor_parallel_size=server_args.tp_size, - enforce_eager=True, # disable using atom cuda graph + enforce_eager=True, # disable using atom cuda graph parallel_config=sgl_parallel_config, kv_cache_dtype=server_args.kv_cache_dtype, enable_prefix_caching=False, From dd6e9b301453ca980152bff9904c2f823fe71733 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Tue, 3 Feb 2026 12:06:12 +0800 Subject: [PATCH 13/37] add Signed-off-by: zejunchen-zejun --- atom/plugin/attention_mha.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/atom/plugin/attention_mha.py b/atom/plugin/attention_mha.py index 9129657c2..2d45f25f9 100644 --- a/atom/plugin/attention_mha.py +++ b/atom/plugin/attention_mha.py @@ -160,7 +160,7 @@ def cp_mha_gather_cache( page_size = key_cache.shape[1] num_heads = key_cache.shape[2] - grid = lambda meta: (total_tokens, num_heads) + grid = lambda meta: (total_tokens, num_heads) # noqa: E731 cp_mha_gather_cache_kernel[grid]( key_cache, value_cache, @@ -210,7 +210,6 @@ def rope_cache_plugin_mode( v_scale: torch.Tensor, flash_layout: bool = False, ): - num_blocks, block_size, num_kv_heads, head_size = k_cache.shape if not flash_layout: @@ -348,7 +347,6 @@ def paged_attention_triton_plugin_mode( out: torch.Tensor, attn_metadata: "AttentionMetaData", ): - o = out num_seqs, num_q_heads_total, head_size = q.shape num_blocks, num_kv_heads, _, block_size, _ = k_cache.shape @@ -667,9 +665,9 @@ def forward_impl_plugin_mode( # when using this optimization, the qkv tensor and # position tensor are passed through q,k,v if 
ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION: - assert ( - position is None - ), "position should be None because it is passed through k" + assert position is None, ( + "position should be None because it is passed through k" + ) position = key qkv = value @@ -873,7 +871,6 @@ def forward_impl_plugin_mode( def PagedAttentionImplDecoratorForPluginMode(cls): - method_names = [ "rope_cache_plugin_mode", "paged_attention_triton_plugin_mode", From b985b8287bf71b4786a81a4c8f0ca9678bb5a9ec Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Tue, 3 Feb 2026 12:15:57 +0800 Subject: [PATCH 14/37] add Signed-off-by: zejunchen-zejun --- atom/plugin/attention_mha.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/atom/plugin/attention_mha.py b/atom/plugin/attention_mha.py index 2d45f25f9..42fea642c 100644 --- a/atom/plugin/attention_mha.py +++ b/atom/plugin/attention_mha.py @@ -665,9 +665,9 @@ def forward_impl_plugin_mode( # when using this optimization, the qkv tensor and # position tensor are passed through q,k,v if ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION: - assert position is None, ( - "position should be None because it is passed through k" - ) + assert ( + position is None + ), "position should be None because it is passed through k" position = key qkv = value From b9806e055efd07b1c57bd488943702883da9b790 Mon Sep 17 00:00:00 2001 From: Guanbao Yu Date: Tue, 3 Feb 2026 15:40:44 +0000 Subject: [PATCH 15/37] fix sglang plugin mode acc issue --- atom/model_loader/loader.py | 6 ++++-- atom/model_ops/radix_attention.py | 4 ++++ atom/models/qwen3_moe.py | 4 +++- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/atom/model_loader/loader.py b/atom/model_loader/loader.py index f40de844f..a38c68883 100644 --- a/atom/model_loader/loader.py +++ b/atom/model_loader/loader.py @@ -29,7 +29,7 @@ get_spec_layer_idx_from_weight_name, rewrite_spec_layer_name, ) -from atom.plugin.prepare import is_vllm +from atom.plugin.prepare import is_vllm, is_sglang 
logger = logging.getLogger("atom") @@ -264,7 +264,9 @@ def load_model( else: module.process_weights_after_loading() quant_method = getattr(module, "quant_method", None) - if isinstance(quant_method, QuantizeMethodBase): + # when running plugin mode for sglang, don't do the post process here + # since sglang will call this func automatically after finishing loading + if isinstance(quant_method, QuantizeMethodBase) and not is_sglang(): quant_method.process_weights_after_loading(module) if isinstance(quant_method, FusedMoEMethodBase): quant_method.init_prepare_finalize(module) diff --git a/atom/model_ops/radix_attention.py b/atom/model_ops/radix_attention.py index ce45201f0..25388b384 100644 --- a/atom/model_ops/radix_attention.py +++ b/atom/model_ops/radix_attention.py @@ -47,6 +47,7 @@ def __init__( prefix=prefix, **kwargs, ) + self.rotary_emb = rotary_emb if is_sglang(): from sglang.srt.layers.radix_attention import RadixAttention @@ -81,6 +82,9 @@ def forward_impl_plugin_mode( # for sglang, forward_batch is required forward_batch = kwargs.get("forward_batch", None) assert forward_batch is not None, "forward_batch is required for sglang" + if self.rotary_emb is not None: + assert positions is not None, "positions is required for ROPE" + query, key = self.rotary_emb(positions, query, key) return self.attn(q=query, k=key, v=value, forward_batch=forward_batch) else: raise NotImplementedError( diff --git a/atom/models/qwen3_moe.py b/atom/models/qwen3_moe.py index c46c23778..f8adfc9cf 100644 --- a/atom/models/qwen3_moe.py +++ b/atom/models/qwen3_moe.py @@ -245,7 +245,9 @@ def forward( q = self.q_norm(q) k = self.k_norm(k) - attn_output = self.attn(query=q, key=k, value=v, **model_kwargs) + attn_output = self.attn( + query=q, key=k, value=v, positions=positions, **model_kwargs + ) output = self.o_proj(attn_output) return output From 418d442ea6da65efd99a5901730687dd7728d024 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Sat, 7 Feb 2026 12:40:58 +0800 Subject: [PATCH 
16/37] init vllm-atom, first commit Signed-off-by: zejunchen-zejun --- atom/model_ops/paged_attention.py | 2 +- atom/plugin/config.py | 8 +- atom/plugin/prepare.py | 9 +- atom/plugin/register.py | 26 ------ atom/plugin/vllm/__init__.py | 5 ++ atom/plugin/vllm/model_wrapper.py | 141 ++++++++++++++++++++++++++++++ atom/plugin/vllm/register.py | 91 +++++++++++++++++++ pyproject.toml | 15 +++- 8 files changed, 262 insertions(+), 35 deletions(-) create mode 100644 atom/plugin/vllm/__init__.py create mode 100644 atom/plugin/vllm/model_wrapper.py create mode 100644 atom/plugin/vllm/register.py diff --git a/atom/model_ops/paged_attention.py b/atom/model_ops/paged_attention.py index 4f90f8648..c498f34ce 100644 --- a/atom/model_ops/paged_attention.py +++ b/atom/model_ops/paged_attention.py @@ -77,7 +77,7 @@ def __init__( # add extra impl args, which are needed to be passed to the impl class # while it only works for custom attention backend for vllm extra_impl_args = {} - if atom_config.plugin_config.vllm_use_custom_attention: + if atom_config.plugin_config.vllm_use_atom_attention: extra_impl_args["sinks"] = sinks extra_impl_args["rotary_emb"] = rotary_emb extra_impl_args["q_norm"] = q_norm diff --git a/atom/plugin/config.py b/atom/plugin/config.py index 09ff01b1f..234dfb2d3 100644 --- a/atom/plugin/config.py +++ b/atom/plugin/config.py @@ -24,7 +24,7 @@ class PluginConfig: vllm_scheduler_config: Any = None vllm_cache_config: Any = None vllm_quant_config: Any = None - vllm_use_custom_attention: bool = False + vllm_use_atom_attention: bool = False # sglang specific sglang_model_opt_config: Any = None @@ -43,8 +43,8 @@ def _generate_atom_config_from_vllm_config(config: Any) -> PluginConfig: vllm_scheduler_config = config.scheduler_config vllm_cache_config = config.cache_config vllm_parallel_config = config.parallel_config - vllm_use_custom_attention = bool( - os.getenv("VLLM_ATTENTION_BACKEND", "None").lower() == "custom" + vllm_use_atom_attention = bool( + 
os.getenv("ATOM_DISABLE_VLLM_PLUGIN_ATTENTION", "0").lower() == "0" ) # here use the ATOM compilation config, as the ATOM compile policy is used @@ -73,7 +73,7 @@ def _generate_atom_config_from_vllm_config(config: Any) -> PluginConfig: vllm_scheduler_config=vllm_scheduler_config, vllm_cache_config=vllm_cache_config, vllm_quant_config=vllm_quant_config, - vllm_use_custom_attention=vllm_use_custom_attention, + vllm_use_atom_attention=vllm_use_atom_attention, ) # specific diff --git a/atom/plugin/prepare.py b/atom/plugin/prepare.py index c7e13a3ce..6b3f80b13 100644 --- a/atom/plugin/prepare.py +++ b/atom/plugin/prepare.py @@ -46,14 +46,16 @@ def prepare_model(config: Any, engine: str): # different engine passed different config if is_vllm(): - model_arch = config.model_config.architectures[0] + # FIXME: remove the legacy code here + # model_arch = config.model_config.architectures[0] + raise NotImplementedError("VLLM will not be supported for now") elif is_sglang(): model_arch = config.architectures[0] # import here to avoid partial initialization from .register import ( _ATOM_SUPPORTED_MODELS, - register_ops_to_vllm, + # register_ops_to_vllm, register_ops_to_sglang, init_aiter_dist, set_attn_cls, @@ -74,7 +76,8 @@ def prepare_model(config: Any, engine: str): logger.info(f"ATOM model class for {model_arch} is {model_cls}") if is_vllm(): - register_ops_to_vllm(atom_config=atom_config) + # register_ops_to_vllm(atom_config=atom_config) + raise NotImplementedError("VLLM will not be supported for now") elif is_sglang(): register_ops_to_sglang(atom_config=atom_config) diff --git a/atom/plugin/register.py b/atom/plugin/register.py index 88b4ac87c..08d94b13d 100644 --- a/atom/plugin/register.py +++ b/atom/plugin/register.py @@ -13,20 +13,6 @@ } -def _register_custom_attention_to_vllm() -> None: - from vllm.v1.attention.backends.registry import ( - register_backend, - AttentionBackendEnum, - ) - - logger.info("Register custom attention backend AiterBackend to vLLM") - 
register_backend( - backend=AttentionBackendEnum.CUSTOM, - is_mamba=False, - class_path="atom.model_ops.attentions.aiter_attention.AiterBackend", - ) - - def _register_custom_attention_to_sglang() -> None: from sglang.srt.layers.attention.attention_registry import ( @@ -45,18 +31,6 @@ def create_atom_backend(runner): return AiterAttnBackend(runner) -def register_ops_to_vllm(atom_config: Config) -> None: - """ - Register custom ops to vllm, including attention - """ - if atom_config.plugin_config.vllm_use_custom_attention: - _register_custom_attention_to_vllm() - else: - logger.warning( - "Please export VLLM_ATTENTION_BACKEND=CUSTOM to use atom attention" - ) - - def register_ops_to_sglang(atom_config: Config) -> None: """ Register custom ops to sglang, including attention diff --git a/atom/plugin/vllm/__init__.py b/atom/plugin/vllm/__init__.py new file mode 100644 index 000000000..a3a23d027 --- /dev/null +++ b/atom/plugin/vllm/__init__.py @@ -0,0 +1,5 @@ +"""vLLM plugin integration for ATOM.""" + +from .register import patch_model_registry, register_platform + +__all__ = ["register_platform", "patch_model_registry"] diff --git a/atom/plugin/vllm/model_wrapper.py b/atom/plugin/vllm/model_wrapper.py new file mode 100644 index 000000000..49852ddab --- /dev/null +++ b/atom/plugin/vllm/model_wrapper.py @@ -0,0 +1,141 @@ +from collections.abc import Iterable + +import importlib +import torch +import torch.nn as nn +from aiter.dist.parallel_state import ( + get_pp_group, + get_tp_group, +) +from vllm.config import VllmConfig +from vllm.model_executor.models.interfaces import ( + SupportsPP, + SupportsQuant, +) +from vllm.model_executor.models.interfaces_base import ( + VllmModel, + VllmModelForTextGeneration, +) +from vllm.sequence import IntermediateTensors + +import atom +from atom.plugin.config import generate_atom_config_for_plugin_mode + +import logging + + +logger = logging.getLogger("atom") + + +_ATOM_MODEL_CLASSES: dict[str, str] = { + "Qwen3ForCausalLM": 
"atom.models.qwen3:Qwen3ForCausalLM", + "Qwen3MoeForCausalLM": "atom.models.qwen3_moe:Qwen3MoeForCausalLM", +} + + +def _get_atom_model_cls(model_arch: str) -> type: + try: + model_ref = _ATOM_MODEL_CLASSES[model_arch] + except KeyError as e: + raise ValueError(f"Unsupported ATOM model architecture: {model_arch}") from e + + module_path, class_name = model_ref.split(":", 1) + return getattr(importlib.import_module(module_path), class_name) + + +def _prepare_env(atom_config) -> None: + from atom.plugin.register import set_attn_cls, init_aiter_dist + # set global attention class + logger.info("Set global attention class") + set_attn_cls() + + # init aiter dist for using aiter custom collective ops + logger.info("Init aiter dist for using aiter custom collective ops") + init_aiter_dist(config=atom_config) + + +class ATOMModelBase(nn.Module, VllmModel, SupportsQuant, SupportsPP): + def __init_subclass__(cls, *args, **kwargs): + super().__init_subclass__(*args, **kwargs) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + self.config = vllm_config.model_config.hf_config + self.text_config = self.config.get_text_config() + self.cache_config = vllm_config.cache_config + self.device_config = vllm_config.device_config + self.model_config = vllm_config.model_config + self.parallel_config = vllm_config.parallel_config + self.quant_config = vllm_config.quant_config + + # Weights to skip in `self.load_weights` + self.skip_prefixes: list[str] = [] + self.skip_substrs: list[str] = [] + self.ignore_unexpected_prefixes: list[str] = [] + self.ignore_unexpected_suffixes: list[str] = [] + + self.atom_config = generate_atom_config_for_plugin_mode(vllm_config) + + _prepare_env(atom_config = self.atom_config) + + model_arch = vllm_config.model_config.architectures[0] + model_cls = _get_atom_model_cls(model_arch) + + logger.info(f"Construct ATOM model {model_arch} for vLLM plugin mode") + self.model = model_cls(self.atom_config) + + if self.model 
is None: + model_arch = vllm_config.model_config.architectures[0] + raise ValueError( + f"The model {model_arch} is not supported by model impl backend atom" + ) + + # here init aiter dist for using aiter custom collective ops + self.pp_group = get_pp_group() + self.tp_group = get_tp_group() + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **model_kwargs, + ) -> torch.Tensor | IntermediateTensors: + if not self.pp_group.is_first_rank: + assert intermediate_tensors is not None + input_ids = None + inputs_embeds = intermediate_tensors["hidden_states"] + + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + ) + + if not self.pp_group.is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + + return hidden_states + + def load_weights( + self, + weights: Iterable[tuple[str, torch.Tensor]], + ) -> set[str]: + return self.model.load_weights(weights) + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + logits = self.model.compute_logits(hidden_states) + return logits + + +class ATOMForCausalLM(ATOMModelBase, VllmModelForTextGeneration): ... + + +class ATOMMoEForCausalLM(ATOMModelBase, VllmModelForTextGeneration): ... 
diff --git a/atom/plugin/vllm/register.py b/atom/plugin/vllm/register.py new file mode 100644 index 000000000..250238ce7 --- /dev/null +++ b/atom/plugin/vllm/register.py @@ -0,0 +1,91 @@ +import os +from typing import Optional +import logging + +import atom +from atom.plugin.prepare import _set_framework_backbone + +logger = logging.getLogger("atom") + +# this flag is used to enable the vllm plugin mode +disable_vllm_plugin = os.getenv("ATOM_DISABLE_VLLM_PLUGIN", "0").lower() == "1" +disable_vllm_plugin_attention = os.getenv("ATOM_DISABLE_VLLM_PLUGIN_ATTENTION", "0").lower() == "1" + +# those 2 models are covering most of dense and moe models +ATOM_CAUSAL_LM_MODEL_WRAPPER = "atom.plugin.vllm.model_wrapper:ATOMForCausalLM" +ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER = "atom.plugin.vllm.model_wrapper:ATOMMoEForCausalLM" + +# when register new model to vllm, add here +# Keys is from hf config arch name +_VLLM_MODEL_REGISTRY_OVERRIDES: dict[str, str] = { + "Qwen3ForCausalLM": ATOM_CAUSAL_LM_MODEL_WRAPPER, + "Qwen3MoeForCausalLM": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER, +} + + +if not disable_vllm_plugin: + from vllm.platforms.rocm import RocmPlatform + logger.info("Enable vLLM plugin mode") +else: + logger.info("Disable vLLM plugin mode") + RocmPlatform = object + + +def _set_plugin_mode() -> None: + _set_framework_backbone("vllm") + + +class ATOMPlatform(RocmPlatform): + # for multi-modality model, for makeing AiterBackend supported by vit + # get_supported_vit_attn_backends needs to be overridden + @classmethod + def get_attn_backend_cls(cls, selected_backend, attn_selector_config) -> str: + # fallback to original behavior of vllm mainline + if disable_vllm_plugin_attention: + logger.info("Fallback to original behavior of vLLM mainline") + return super().get_attn_backend_cls(selected_backend, attn_selector_config) + + # return atom attention backend + logger.info("Use atom attention backend") + return "atom.model_ops.attentions.aiter_attention.AiterBackend" + + +def 
register_platform() -> Optional[str]: + + if disable_vllm_plugin: + # return None instead of error because the flag can be used to + # run pure vllm mode without ATOM plugin + return None + + _set_plugin_mode() + + # return the ATOM platform to vllm + return f"{__name__}.ATOMPlatform" + + +def patch_model_registry() -> None: + if disable_vllm_plugin: + return + + import vllm.model_executor.models.registry as vllm_model_registry + + any_updated = False + for arch, qual in _VLLM_MODEL_REGISTRY_OVERRIDES.items(): + module_name, class_name = qual.split(":", 1) + existing = vllm_model_registry.ModelRegistry.models.get(arch) + if existing is not None: + # If already overridden to the same target, skip re-registering. + if ( + getattr(existing, "module_name", None) == module_name + and getattr(existing, "class_name", None) == class_name + ): + continue + + logger.info(f"Register model {arch} to vLLM with {qual}") + vllm_model_registry.ModelRegistry.register_model(arch, qual) + any_updated = True + + # clear lru cache + if any_updated: + vllm_model_registry._try_load_model_cls.cache_clear() + vllm_model_registry._try_inspect_model_cls.cache_clear() diff --git a/pyproject.toml b/pyproject.toml index 49d3d8c0b..6828fe861 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,6 @@ [build-system] -requires = ["setuptools_scm[toml]>=6.2"] +requires = ["setuptools>=61", "wheel", "setuptools_scm[toml]>=6.2"] +build-backend = "setuptools.build_meta" [project] name = "atom" @@ -24,3 +25,15 @@ fallback_version = "0.1.0" [tool.setuptools.packages.find] where = ["."] include = ["atom*"] + +[project.entry-points."vllm.platform_plugins"] +# by default the entry-points will be installed anyway +# but the plugin mode for platforms can be disabled by +# ATOM_DISABLE_VLLM_PLUGIN=1 +atom = "atom.plugin.vllm.register:register_platform" + +[project.entry-points."vllm.general_plugins"] +# by default the entry-points will be installed anyway +# but the plugin mode for models can be disabled 
by +# ATOM_DISABLE_VLLM_PLUGIN_ATTENTION=1 +atom_model_registry = "atom.plugin.vllm.register:patch_model_registry" From 7c54abe7db156f1f3603e00d5e9949f8493754cd Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Mon, 9 Feb 2026 11:48:16 +0800 Subject: [PATCH 17/37] add Signed-off-by: zejunchen-zejun --- atom/plugin/vllm/__init__.py | 4 +- atom/plugin/vllm/platform.py | 39 +++++ atom/plugin/vllm/register.py | 34 +---- pyproject.toml | 2 +- recipes/Model-Impl-Backend.md | 168 ---------------------- recipes/SGLang-ATOM-Model-Impl-Backend.md | 73 ++++++++++ recipes/vLLM-ATOM-OOT-Plugin-Backend.md | 87 +++++++++++ 7 files changed, 209 insertions(+), 198 deletions(-) create mode 100644 atom/plugin/vllm/platform.py delete mode 100644 recipes/Model-Impl-Backend.md create mode 100644 recipes/SGLang-ATOM-Model-Impl-Backend.md create mode 100644 recipes/vLLM-ATOM-OOT-Plugin-Backend.md diff --git a/atom/plugin/vllm/__init__.py b/atom/plugin/vllm/__init__.py index a3a23d027..a76f02676 100644 --- a/atom/plugin/vllm/__init__.py +++ b/atom/plugin/vllm/__init__.py @@ -1,5 +1,5 @@ """vLLM plugin integration for ATOM.""" -from .register import patch_model_registry, register_platform +from .register import register_model, register_platform -__all__ = ["register_platform", "patch_model_registry"] +__all__ = ["register_platform", "register_model"] diff --git a/atom/plugin/vllm/platform.py b/atom/plugin/vllm/platform.py new file mode 100644 index 000000000..08298f9f0 --- /dev/null +++ b/atom/plugin/vllm/platform.py @@ -0,0 +1,39 @@ +"""ATOM vLLM platform integration. + +This module contains the vLLM `Platform` subclass used in ATOM's vLLM plugin +mode. Keep platform behavior here so `register.py` can focus on registration +and wiring only. +""" + +import logging +import os + +logger = logging.getLogger("atom") + +# This flag is used to enable the vLLM plugin mode. 
+disable_vllm_plugin = os.getenv("ATOM_DISABLE_VLLM_PLUGIN", "0").lower() == "1" +disable_vllm_plugin_attention = ( + os.getenv("ATOM_DISABLE_VLLM_PLUGIN_ATTENTION", "0").lower() == "1" +) + +if not disable_vllm_plugin: + from vllm.platforms.rocm import RocmPlatform +else: + # Keep the module importable even when vLLM isn't available / plugin disabled. + RocmPlatform = object # type: ignore[assignment] + + +class ATOMPlatform(RocmPlatform): + # For multi-modality models, to make AiterBackend supported by ViT, + # get_supported_vit_attn_backends may need to be overridden. + @classmethod + def get_attn_backend_cls(cls, selected_backend, attn_selector_config) -> str: + # Fall back to original behavior of vLLM mainline. + if disable_vllm_plugin_attention: + logger.info("Fallback to original vLLM attention backend") + return super().get_attn_backend_cls(selected_backend, attn_selector_config) + + # Return ATOM attention backend. + logger.info("Use atom attention backend") + return "atom.model_ops.attentions.aiter_attention.AiterBackend" + diff --git a/atom/plugin/vllm/register.py b/atom/plugin/vllm/register.py index 250238ce7..2b1e0a24b 100644 --- a/atom/plugin/vllm/register.py +++ b/atom/plugin/vllm/register.py @@ -9,7 +9,9 @@ # this flag is used to enable the vllm plugin mode disable_vllm_plugin = os.getenv("ATOM_DISABLE_VLLM_PLUGIN", "0").lower() == "1" -disable_vllm_plugin_attention = os.getenv("ATOM_DISABLE_VLLM_PLUGIN_ATTENTION", "0").lower() == "1" +disable_vllm_plugin_attention = ( + os.getenv("ATOM_DISABLE_VLLM_PLUGIN_ATTENTION", "0").lower() == "1" +) # those 2 models are covering most of dense and moe models ATOM_CAUSAL_LM_MODEL_WRAPPER = "atom.plugin.vllm.model_wrapper:ATOMForCausalLM" @@ -22,49 +24,27 @@ "Qwen3MoeForCausalLM": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER, } - -if not disable_vllm_plugin: - from vllm.platforms.rocm import RocmPlatform - logger.info("Enable vLLM plugin mode") -else: - logger.info("Disable vLLM plugin mode") - RocmPlatform = object - - 
def _set_plugin_mode() -> None: _set_framework_backbone("vllm") -class ATOMPlatform(RocmPlatform): - # for multi-modality model, for makeing AiterBackend supported by vit - # get_supported_vit_attn_backends needs to be overridden - @classmethod - def get_attn_backend_cls(cls, selected_backend, attn_selector_config) -> str: - # fallback to original behavior of vllm mainline - if disable_vllm_plugin_attention: - logger.info("Fallback to original behavior of vLLM mainline") - return super().get_attn_backend_cls(selected_backend, attn_selector_config) - - # return atom attention backend - logger.info("Use atom attention backend") - return "atom.model_ops.attentions.aiter_attention.AiterBackend" - - def register_platform() -> Optional[str]: if disable_vllm_plugin: # return None instead of error because the flag can be used to # run pure vllm mode without ATOM plugin + logger.info("Disable ATOM OOT plugin platforms") return None _set_plugin_mode() # return the ATOM platform to vllm - return f"{__name__}.ATOMPlatform" + return "atom.plugin.vllm.platform.ATOMPlatform" -def patch_model_registry() -> None: +def register_model() -> None: if disable_vllm_plugin: + logger.info("Disable ATOM model register") return import vllm.model_executor.models.registry as vllm_model_registry diff --git a/pyproject.toml b/pyproject.toml index 6828fe861..7e6c623bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,4 +36,4 @@ atom = "atom.plugin.vllm.register:register_platform" # by default the entry-points will be installed anyway # but the plugin mode for models can be disabled by # ATOM_DISABLE_VLLM_PLUGIN_ATTENTION=1 -atom_model_registry = "atom.plugin.vllm.register:patch_model_registry" +atom_model_registry = "atom.plugin.vllm.register:register_model" diff --git a/recipes/Model-Impl-Backend.md b/recipes/Model-Impl-Backend.md deleted file mode 100644 index ecf1782dc..000000000 --- a/recipes/Model-Impl-Backend.md +++ /dev/null @@ -1,168 +0,0 @@ -# Model Impl Backend of vLLM and SGLang 
-ATOM can work as model implementation backend of popular framework, like vLLM and SGLang. The users can launch vLLM and SGLang server like before and specify an extra argument to enable the ATOM model backend, where the optimized implementation of the required target model will be provided to vLLM and SGLang to execute. When ATOM working under this mode, both framework-level features from vLLM/SGLang and latest model-level fusion kernels from ATOM/AITER can be combined together to achieve the competitive performance. - -- Here is a detailed design slide for this feature: https://amdcloud-my.sharepoint.com/:p:/g/personal/zejchen_amd_com/IQCFdvmEeLTWT7ysApmZv_hVAfw2nTo8iesJZGblHS0evqQ?e=hjnIDM -- Here is the RFC to introduce the ATOM as model impl backend into vLLM: https://github.com/vllm-project/vllm/issues/33478 -- Here is the RFC to introduce the ATOM as model impl backend into SGLang: TODO - -## Preparing environment for vLLM with ATOM model backend -Here is the PR to introduce ATOM into vLLM: https://github.com/vllm-project/vllm/pull/32160, when this PR would be merged, the official vLLM can be used, but for now you need to use develop vllm branch - -Pull the latest docker from vLLM official nightly docker for ROCm from https://hub.docker.com/r/rocm/vllm-dev/tags -```bash -docker pull rocm/vllm-dev:nightly -``` -Launch the container as usual, then all the next operations will be executed inside the container -Then the specific vLLM should be used because the PR to introduce the ATOM into vLLM has not been merged yet, so you need to: -```bash -pip uninstall -y vllm -git clone https://github.com/zejunchen-zejun/vllm.git -cd vllm -git checkout origin/zejun/model_impl -export PYTORCH_ROCM_ARCH="gfx950" -python3 setup.py develop 2>&1 | tee build.log -``` -Then the ATOM should be installed -```bash -git clone https://github.com/zejunchen-zejun/ATOM.git -cd ATOM -git checkout origin/zejun/plugin_for_atom_1223 -pip install -e . 
2>&1 | tee build.log -``` -For AITER, there is no specific requirements, however, if you find any latest fusion kernels are missing, you may need to upgrade the AITER -Additionally, you may need to upgrade your triton version by: -```bash -pip install --upgrade triton -``` - -### Launching server of vLLM with ATOM model backend -You just need to deploy 2 code changes to your previous server launch command. The one is using CUSTOM vLLM attention backend, the other is a new argument of specifying the ATOM model impl backend. Here is the an example. From the example, the specific fusion kernels are used by enabling the env flags, which is not easy to integrate into vLLM as vLLM has some heuristic to stipulate the boundary of ops and layers, where ATOM can provide the those kernels -```bash -export VLLM_ATTENTION_BACKEND=CUSTOM - -export ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1 -export AITER_QUICK_REDUCE_QUANTIZATION=INT4 - -export SAFETENSORS_FAST_GPU=1 -export VLLM_ROCM_USE_AITER=1 -export VLLM_RPC_TIMEOUT=1800000 - -export VLLM_CACHE_ROOT=/root/.cache/vllm -export TORCHINDUCTOR_CACHE_DIR=/root/.cache/inductor - -rm -rf /root/.cache/ - -model_path= - -vllm serve $model_path \ - --host localhost \ - --port 8000 \ - --tensor-parallel-size 8 \ - --enable-expert-parallel \ - --trust-remote-code \ - --disable-log-requests \ - --gpu_memory_utilization 0.9 \ - --async-scheduling \ - --load-format fastsafetensors \ - --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \ - --kv-cache-dtype fp8 \ - --max-num-batched-tokens 18432 \ - --max-model-len 16384 \ - --no-enable-prefix-caching \ - --model-impl atom \ - 2>&1 | tee log.serve.log & -``` - -### Launching client for validating the accuracy -```bash -addr=localhost -port=8000 -url=http://${addr}:${port}/v1/completions -model= -task=gsm8k -lm_eval --model local-completions \ - --model_args model=${model},base_url=${url},num_concurrent=65,max_retries=1,tokenized_requests=False \ - --tasks ${task} \ - 
--num_fewshot 3 \ - 2>&1 | tee log.lmeval.log -``` - -### Results for accuracy validation -|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr| -|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:| -|gsm8k| 3|flexible-extract| 3|exact_match|↑ |0.8901|± |0.0086| -| | |strict-match | 3|exact_match|↑ |0.8772|± |0.0090| - -### Known Limitations -There are some known limitations for now: -- Only Qwen-Dense and Qwen-MoE family models are supported -- Only TP and EP are supported - - -## Preparing environment for SGLang with ATOM model backend -Here is the PR to introduce ATOM into SGLang: https://github.com/sgl-project/sglang/pull/16944, when this PR would be merged, the official SGLang can be used, but for now you need to use develop vllm branch -Pull the latest docker from SGLang official nightly docker for ROCm from https://hub.docker.com/r/rocm/sgl-dev/tags -```bash -docker pull rocm/sgl-dev:v0.5.8-rocm720-mi35x-20260130-preview -``` -Launch the container as usual, then all the next operations will be executed inside the container -Then the specific SGLang should be used because the PR to introduce the ATOM into SGLang has not been merged yet, so you need to: -```bash -git clone https://github.com/zejunchen-zejun/sglang.git -git checkout remotes/origin/zejun/model_impl -pip uninstall sglang -y -pip uninstall sgl-kernel -y -cd sgl-kernel -python3 setup_rocm.py install -export PYTHONPATH= -``` -Then the ATOM should be installed -```bash -git clone https://github.com/zejunchen-zejun/ATOM.git -cd ATOM -git checkout origin/zejun/plugin_for_atom_1223 -pip install -e . 2>&1 | tee build.log -``` -For AITER, there is no specific requirements, however, if you find any latest fusion kernels are missing, you may need to upgrade the AITER - -### Launching server of SGLang with ATOM model backend -You just need to deploy single code change, as add --model-impl atom to your SGLang server command. 
Here is an example: -```bash -export ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1 - -# quick allreduce -export AITER_QUICK_REDUCE_QUANTIZATION=INT4 -model_path=/data/models/Qwen3-235B-A22B-Instruct-2507-FP8 - -python3 -m sglang.launch_server \ - --model-path $model_path \ - --host localhost \ - --port 8000 \ - --trust-remote-code \ - --tensor-parallel-size 8 \ - --expert-parallel-size 8 \ - --kv-cache-dtype fp8_e4m3 \ - --mem-fraction-static 0.8 \ - --model-impl atom \ - 2>&1 | tee log.serve.log & -``` - -### Launching client for validating the accuracy -```bash -addr=localhost -port=8000 -url=http://${addr}:${port}/v1/completions -model= -task=gsm8k -lm_eval --model local-completions \ - --model_args model=${model},base_url=${url},num_concurrent=65,max_retries=1,tokenized_requests=False \ - --tasks ${task} \ - --num_fewshot 3 \ - 2>&1 | tee log.lmeval.log -``` - -### Known Limitations -There are some known limitations for now: -- Only Qwen-Dense and Qwen-MoE family models are supported -- Only TP and EP are supported -- For SGLang, there is still accuracy issue for now, but we will fix it soon diff --git a/recipes/SGLang-ATOM-Model-Impl-Backend.md b/recipes/SGLang-ATOM-Model-Impl-Backend.md new file mode 100644 index 000000000..015279af8 --- /dev/null +++ b/recipes/SGLang-ATOM-Model-Impl-Backend.md @@ -0,0 +1,73 @@ +# Model Impl Backend of SGLang +ATOM can work as model implementation backend of popular framework, like SGLang. The users can launch the server like before and specify an extra argument to enable the ATOM model backend, where the optimized implementation of the required target model will be provided to SGLang to execute. When ATOM working under this mode, both framework-level features from SGLang and latest model-level fusion kernels from ATOM/AITER can be combined together to achieve the competitive performance. 
+ +- Here is a detailed design slide for this feature: https://amdcloud-my.sharepoint.com/:p:/g/personal/zejchen_amd_com/IQCFdvmEeLTWT7ysApmZv_hVAfw2nTo8iesJZGblHS0evqQ?e=hjnIDM +- Here is the RFC to introduce the ATOM as model impl backend into SGLang: TODO + +## Preparing environment for SGLang with ATOM model backend +Here is the PR to introduce ATOM into SGLang: https://github.com/sgl-project/sglang/pull/16944, when this PR would be merged, the official SGLang can be used, but for now you need to use develop vllm branch +Pull the latest docker from SGLang official nightly docker for ROCm from https://hub.docker.com/r/rocm/sgl-dev/tags +```bash +docker pull rocm/sgl-dev:v0.5.8-rocm720-mi35x-20260130-preview +``` +Launch the container as usual, then all the next operations will be executed inside the container +Then the specific SGLang should be used because the PR to introduce the ATOM into SGLang has not been merged yet, so you need to: +```bash +git clone https://github.com/zejunchen-zejun/sglang.git +git checkout remotes/origin/zejun/model_impl +pip uninstall sglang -y +pip uninstall sgl-kernel -y +cd sgl-kernel +python3 setup_rocm.py install +export PYTHONPATH= +``` +Then the ATOM should be installed +```bash +git clone https://github.com/zejunchen-zejun/ATOM.git +cd ATOM +git checkout origin/zejun/plugin_for_atom_1223 +pip install -e . 2>&1 | tee build.log +``` +For AITER, there is no specific requirements, however, if you find any latest fusion kernels are missing, you may need to upgrade the AITER + +### Launching server of SGLang with ATOM model backend +You just need to deploy single code change, as add --model-impl atom to your SGLang server command. 
Here is an example: +```bash +export ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1 + +# quick allreduce +export AITER_QUICK_REDUCE_QUANTIZATION=INT4 +model_path=/data/models/Qwen3-235B-A22B-Instruct-2507-FP8 + +python3 -m sglang.launch_server \ + --model-path $model_path \ + --host localhost \ + --port 8000 \ + --trust-remote-code \ + --tensor-parallel-size 8 \ + --expert-parallel-size 8 \ + --kv-cache-dtype fp8_e4m3 \ + --mem-fraction-static 0.8 \ + --model-impl atom \ + 2>&1 | tee log.serve.log & +``` + +### Launching client for validating the accuracy +```bash +addr=localhost +port=8000 +url=http://${addr}:${port}/v1/completions +model= +task=gsm8k +lm_eval --model local-completions \ + --model_args model=${model},base_url=${url},num_concurrent=65,max_retries=1,tokenized_requests=False \ + --tasks ${task} \ + --num_fewshot 3 \ + 2>&1 | tee log.lmeval.log +``` + +### Known Limitations +There are some known limitations for now: +- Only Qwen-Dense and Qwen-MoE family models are supported +- Only TP and EP are supported +- For SGLang, there is still accuracy issue for now, but we will fix it soon diff --git a/recipes/vLLM-ATOM-OOT-Plugin-Backend.md b/recipes/vLLM-ATOM-OOT-Plugin-Backend.md new file mode 100644 index 000000000..a0629e850 --- /dev/null +++ b/recipes/vLLM-ATOM-OOT-Plugin-Backend.md @@ -0,0 +1,87 @@ +# vLLM out-of-tree ATOM Plugin Backend +ATOM can work as the OOT plugin backend of vLLM. The OOT register mechanism is quite mature and most of accelerators have leveraged this design to register their devices into vLLM without any code changes in upper framework. ATOM follows this design and provide the layer/op and model implementations to vLLM. The frontend users can launch vLLM server like before and there is no need to specify any arguments. Meanwhile the ATOM platform can leverage most of the vLLM features and focus more on model- and kernel-level optimizations. 
For the overall design, here is a RFC to enable ATOM work as the OOT plugin platform of vLLM: https://github.com/ROCm/ATOM/issues/201 + +## Preparing environment for vLLM with ATOM model backend +Pull the latest docker from vLLM official nightly docker for ROCm +```bash +docker pull rocm/vllm-dev:nightly +``` + +Then the ATOM should be installed. When the following PR merged, you can use ATOM main branch +```bash +git clone https://github.com/zejunchen-zejun/ATOM.git +cd ATOM +git checkout origin/zejun/plugin_for_atom_1223 +pip install -e . 2>&1 | tee build.log +``` +For AITER, there is no specific requirements, however, if you find any latest fusion kernels are missing, you may need to upgrade the AITER +Additionally, you may need to upgrade your triton version by: +```bash +pip install --upgrade triton +``` + +### Launching server of vLLM with ATOM OOT Plugin Platform +There is no code change to vLLM side, so you can launch the vLLM server like before without any specific argument +```bash +export SAFETENSORS_FAST_GPU=1 +export VLLM_ROCM_USE_AITER=1 +export VLLM_RPC_TIMEOUT=1800000 + +export VLLM_CACHE_ROOT=/root/.cache/vllm +export TORCHINDUCTOR_CACHE_DIR=/root/.cache/inductor + +rm -rf /root/.cache/ + +model_path= + +vllm serve $model_path \ + --host localhost \ + --port 8000 \ + --tensor-parallel-size 8 \ + --enable-expert-parallel \ + --trust-remote-code \ + --disable-log-requests \ + --gpu_memory_utilization 0.9 \ + --async-scheduling \ + --load-format fastsafetensors \ + --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \ + --kv-cache-dtype fp8 \ + --max-num-batched-tokens 18432 \ + --max-model-len 16384 \ + --no-enable-prefix-caching \ + 2>&1 | tee log.serve.log & +``` + +If you want to disable the ATOM OOT plugin platform, you can use below env flags. The default value is 0 +```bash +export ATOM_DISABLE_VLLM_PLUGIN=1 +``` +If you want to disable the ATOM Attention Backend, you can use below env flags. 
The default value is 0 +```bash +export ATOM_DISABLE_VLLM_PLUGIN_ATTENTION=1 +``` + +### Launching client for validating the accuracy +```bash +addr=localhost +port=8000 +url=http://${addr}:${port}/v1/completions +model= +task=gsm8k +lm_eval --model local-completions \ + --model_args model=${model},base_url=${url},num_concurrent=65,max_retries=1,tokenized_requests=False \ + --tasks ${task} \ + --num_fewshot 3 \ + 2>&1 | tee log.lmeval.log +``` + +### Results for accuracy validation +|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr| +|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:| +|gsm8k| 3|flexible-extract| 3|exact_match|↑ |0.9037|± |0.0081| +| | |strict-match | 3|exact_match|↑ |0.8832|± |0.0088| + +### Known Limitations +There are some known limitations for now: +- Only Qwen-Dense and Qwen-MoE family models are supported +- Only TP and EP are supported From a44bed1c343ddf84e325b3ab264cc696877d6881 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Mon, 9 Feb 2026 13:28:41 +0800 Subject: [PATCH 18/37] add Signed-off-by: zejunchen-zejun --- recipes/SGLang-ATOM-Model-Impl-Backend.md | 1 - 1 file changed, 1 deletion(-) diff --git a/recipes/SGLang-ATOM-Model-Impl-Backend.md b/recipes/SGLang-ATOM-Model-Impl-Backend.md index 015279af8..fe409a269 100644 --- a/recipes/SGLang-ATOM-Model-Impl-Backend.md +++ b/recipes/SGLang-ATOM-Model-Impl-Backend.md @@ -70,4 +70,3 @@ lm_eval --model local-completions \ There are some known limitations for now: - Only Qwen-Dense and Qwen-MoE family models are supported - Only TP and EP are supported -- For SGLang, there is still accuracy issue for now, but we will fix it soon From 43604c9948b67e9dd893bf0954021978ddce357c Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Mon, 9 Feb 2026 15:22:54 +0800 Subject: [PATCH 19/37] add Signed-off-by: zejunchen-zejun --- atom/plugin/attention_mha.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/atom/plugin/attention_mha.py 
b/atom/plugin/attention_mha.py index 42fea642c..e962634a1 100644 --- a/atom/plugin/attention_mha.py +++ b/atom/plugin/attention_mha.py @@ -713,6 +713,13 @@ def forward_impl_plugin_mode( layer.k_scale = self.k_scale layer.v_scale = self.v_scale + query = query[:num_actual_tokens] + if key is not None: + key = key[:num_actual_tokens] + if value is not None: + value = value[:num_actual_tokens] + output_actual_tokens = output[:num_actual_tokens] + # rope and cache flush fusion. ATOM always use shuffle layout for kv cache result = self.rope_cache_plugin_mode( q=query, @@ -729,16 +736,6 @@ def forward_impl_plugin_mode( ) query, key, value, k_cache, v_cache, k_scale, v_scale = result - # The tokens are storaged as [decode:extend:prefill] order - # which is decided by the vllm - query = query[:num_actual_tokens] - if key is not None: - key = key[:num_actual_tokens] - if value is not None: - value = value[:num_actual_tokens] - - output_actual_tokens = output[:num_actual_tokens] - num_decodes = attn_metadata.plugin_metadata.num_decodes num_prefills = attn_metadata.plugin_metadata.num_prefills num_extends = attn_metadata.plugin_metadata.num_extends From 285929a3a74897632ec5ca2cc7e6d226f05cd2d8 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Mon, 9 Feb 2026 19:01:27 +0800 Subject: [PATCH 20/37] add Signed-off-by: zejunchen-zejun --- atom/plugin/attention_mha.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/atom/plugin/attention_mha.py b/atom/plugin/attention_mha.py index e962634a1..a02ff6781 100644 --- a/atom/plugin/attention_mha.py +++ b/atom/plugin/attention_mha.py @@ -713,13 +713,6 @@ def forward_impl_plugin_mode( layer.k_scale = self.k_scale layer.v_scale = self.v_scale - query = query[:num_actual_tokens] - if key is not None: - key = key[:num_actual_tokens] - if value is not None: - value = value[:num_actual_tokens] - output_actual_tokens = output[:num_actual_tokens] - # rope and cache flush fusion. 
ATOM always use shuffle layout for kv cache result = self.rope_cache_plugin_mode( q=query, @@ -736,6 +729,15 @@ def forward_impl_plugin_mode( ) query, key, value, k_cache, v_cache, k_scale, v_scale = result + # as vLLM cuda graph capture padding mechanism, here split the qkvo with + # the actual tokens + query = query[:num_actual_tokens] + if key is not None: + key = key[:num_actual_tokens] + if value is not None: + value = value[:num_actual_tokens] + output_actual_tokens = output[:num_actual_tokens] + num_decodes = attn_metadata.plugin_metadata.num_decodes num_prefills = attn_metadata.plugin_metadata.num_prefills num_extends = attn_metadata.plugin_metadata.num_extends From 77795ebeb11fa983eb1f7100a92b39a3fa483c09 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Mon, 9 Feb 2026 21:24:09 +0800 Subject: [PATCH 21/37] add Signed-off-by: zejunchen-zejun --- atom/model_ops/paged_attention.py | 1 + atom/plugin/attention.py | 23 ++++------------------- atom/plugin/attention_mha.py | 7 +++---- 3 files changed, 8 insertions(+), 23 deletions(-) diff --git a/atom/model_ops/paged_attention.py b/atom/model_ops/paged_attention.py index c498f34ce..20a26576c 100644 --- a/atom/model_ops/paged_attention.py +++ b/atom/model_ops/paged_attention.py @@ -63,6 +63,7 @@ def __init__( # for plugin mode if is_vllm(): self.use_mla = use_mla + self.rotary_emb = rotary_emb from vllm.attention.layer import Attention, AttentionType atom_config = get_current_atom_config() diff --git a/atom/plugin/attention.py b/atom/plugin/attention.py index c574039fa..9d1446a63 100644 --- a/atom/plugin/attention.py +++ b/atom/plugin/attention.py @@ -305,10 +305,6 @@ def build( seq_lens = common_attn_metadata.seq_lens.cpu() query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] - # used to store the positions of each tokens of each request - # for computing ROPE - positions = [] - decode_metadata = None if num_decodes > 0: decode_metadata = AiterFlashAttentionDecodeMetadata( @@ -317,8 +313,6 @@ 
def build( max_seq_len=seq_lens[:num_decodes].max().item(), query_start_loc=common_attn_metadata.query_start_loc[: num_decodes + 1], ) - for seq_len in seq_lens[:num_decodes]: - positions.append(seq_len - 1) extend_metadata = None if num_extends > 0: @@ -447,14 +441,6 @@ def build( chunk_context_metadata=chunk_context_metadata, ) - for idx in range(num_extends): - extend_start_seq_len = ( - seq_lens_for_extend[idx] - query_lens_for_extend[idx] - ) - extend_end_seq_len = seq_lens_for_extend[idx] - for pos in range(extend_start_seq_len, extend_end_seq_len): - positions.append(pos) - prefill_metadata = None if num_prefills > 0: query_lens_for_prefill = query_lens_cpu[num_decodes + num_extends :] @@ -467,9 +453,6 @@ def build( max_seq_len=seq_lens[num_decodes + num_extends :].max().item(), query_start_loc=query_start_loc_device - query_start_loc_device[0], ) - for prefill_seq_len in seq_lens[num_decodes + num_extends :]: - for pos in range(prefill_seq_len): - positions.append(pos) num_actual_kv_tokens = torch.sum(seq_lens).item() @@ -484,9 +467,8 @@ def build( context_graph_bs = context_batch_size num_actual_tokens = common_attn_metadata.num_actual_tokens - self.positions.np[:num_actual_tokens] = positions context = Context( - positions=self.positions.copy_to_gpu(num_actual_tokens), + positions=None, is_prefill=has_prefill, batch_size=context_batch_size, graph_bs=context_graph_bs, @@ -628,5 +610,8 @@ def unified_attention_with_output_base_for_plugin_mode( if envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION: output = self.attn(q, positions, qkv) else: + # calculate the q and k with rotary embedding + if self.rotary_emb is not None: + q, k = self.rotary_emb(positions, q, k) output = self.attn(q, k, v) return output diff --git a/atom/plugin/attention_mha.py b/atom/plugin/attention_mha.py index a02ff6781..f861df487 100644 --- a/atom/plugin/attention_mha.py +++ b/atom/plugin/attention_mha.py @@ -280,7 +280,9 @@ def rope_cache_plugin_mode( q, k, v = qkv.split( 
[self.num_heads, self.num_kv_heads, self.num_kv_heads], dim=1 ) - elif use_triton_attn and self.rotary_emb is not None: + # elif use_triton_attn and self.rotary_emb is not None: + elif 0: + # FIXME: this should be fixed by moving rope outside of attention k_scale = v_scale = self.kv_scale q, k, k_cache, v_cache = fused_qk_rope_reshape_and_cache( @@ -304,9 +306,6 @@ def rope_cache_plugin_mode( output_zeros=False, ) else: - if self.rotary_emb is not None: - assert position is not None - q, k = self.rotary_emb(position, q, k) if self.q_norm is not None: q = self.q_norm(q) if self.k_norm is not None: From b13a67007bdce90c0e365f8d51484c2d814a1921 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Tue, 10 Feb 2026 10:06:18 +0800 Subject: [PATCH 22/37] make lint happy Signed-off-by: zejunchen-zejun --- atom/models/deepseek_v2.py | 2 +- atom/models/qwen3.py | 4 ++-- atom/models/qwen3_moe.py | 4 ++-- atom/plugin/vllm/model_wrapper.py | 5 ++--- atom/plugin/vllm/platform.py | 1 - atom/plugin/vllm/register.py | 2 +- 6 files changed, 8 insertions(+), 10 deletions(-) diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index f39504195..ac58cabfc 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -24,7 +24,7 @@ """Inference-only DeepseekV2/DeepseekV3 model.""" import logging -from typing import Any, Dict, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch from aiter import ( diff --git a/atom/models/qwen3.py b/atom/models/qwen3.py index 75c760ff7..1c932b990 100644 --- a/atom/models/qwen3.py +++ b/atom/models/qwen3.py @@ -24,14 +24,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional, Any, Iterable +from typing import Any, Iterable import torch # import torch.distributed as dist from aiter.dist.parallel_state import get_tp_group from aiter.rotary_embedding import get_rope -from atom.config import Config, QuantizationConfig +from atom.config import Config from atom.model_ops.activation import SiluAndMul # from atom.model_ops.attention import Attention diff --git a/atom/models/qwen3_moe.py b/atom/models/qwen3_moe.py index f8adfc9cf..97cec9829 100644 --- a/atom/models/qwen3_moe.py +++ b/atom/models/qwen3_moe.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union, Any, Iterable +from typing import Optional, Union, Any, Iterable import torch from aiter.dist.communication_op import tensor_model_parallel_all_reduce @@ -32,7 +32,7 @@ from atom.model_loader.loader import load_model_in_plugin_mode # import torch.distributed as dist -from transformers import PretrainedConfig, Qwen3Config +from transformers import PretrainedConfig ENABLE_ALLREDUCE_RMSNORM_FUSION = envs.ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION = ( diff --git a/atom/plugin/vllm/model_wrapper.py b/atom/plugin/vllm/model_wrapper.py index 49852ddab..7304b1c59 100644 --- a/atom/plugin/vllm/model_wrapper.py +++ b/atom/plugin/vllm/model_wrapper.py @@ -18,12 +18,10 @@ ) from vllm.sequence import IntermediateTensors -import atom from atom.plugin.config import generate_atom_config_for_plugin_mode import logging - logger = logging.getLogger("atom") @@ -45,6 +43,7 @@ def _get_atom_model_cls(model_arch: str) -> type: def _prepare_env(atom_config) -> None: from atom.plugin.register import set_attn_cls, init_aiter_dist + # set global attention class logger.info("Set global attention class") set_attn_cls() @@ -77,7 +76,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.atom_config = generate_atom_config_for_plugin_mode(vllm_config) - _prepare_env(atom_config = self.atom_config) + 
_prepare_env(atom_config=self.atom_config) model_arch = vllm_config.model_config.architectures[0] model_cls = _get_atom_model_cls(model_arch) diff --git a/atom/plugin/vllm/platform.py b/atom/plugin/vllm/platform.py index 08298f9f0..aaaa9657f 100644 --- a/atom/plugin/vllm/platform.py +++ b/atom/plugin/vllm/platform.py @@ -36,4 +36,3 @@ def get_attn_backend_cls(cls, selected_backend, attn_selector_config) -> str: # Return ATOM attention backend. logger.info("Use atom attention backend") return "atom.model_ops.attentions.aiter_attention.AiterBackend" - diff --git a/atom/plugin/vllm/register.py b/atom/plugin/vllm/register.py index 2b1e0a24b..1dea4ff15 100644 --- a/atom/plugin/vllm/register.py +++ b/atom/plugin/vllm/register.py @@ -2,7 +2,6 @@ from typing import Optional import logging -import atom from atom.plugin.prepare import _set_framework_backbone logger = logging.getLogger("atom") @@ -24,6 +23,7 @@ "Qwen3MoeForCausalLM": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER, } + def _set_plugin_mode() -> None: _set_framework_backbone("vllm") From 31ccb16ce5e4a1ca99564460d38d8682dfd322b3 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Tue, 10 Feb 2026 10:09:59 +0800 Subject: [PATCH 23/37] add Signed-off-by: zejunchen-zejun --- atom/plugin/vllm/model_wrapper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/atom/plugin/vllm/model_wrapper.py b/atom/plugin/vllm/model_wrapper.py index 7304b1c59..a1d32fcf8 100644 --- a/atom/plugin/vllm/model_wrapper.py +++ b/atom/plugin/vllm/model_wrapper.py @@ -18,6 +18,7 @@ ) from vllm.sequence import IntermediateTensors +import atom # noqa: F401 from atom.plugin.config import generate_atom_config_for_plugin_mode import logging From f226b9550e022d7755334440ca30d53fd56fc6da Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Tue, 10 Feb 2026 10:36:19 +0800 Subject: [PATCH 24/37] add Signed-off-by: zejunchen-zejun --- atom/config.py | 16 ++++++++-------- atom/plugin/config.py | 1 + 2 files changed, 9 insertions(+), 8 deletions(-) diff --git 
a/atom/config.py b/atom/config.py index 6a5b80725..343253489 100644 --- a/atom/config.py +++ b/atom/config.py @@ -609,14 +609,14 @@ def __post_init__(self): self.kv_cache_block_size % 16 == 0 or self.kv_cache_block_size == 1 ), f"kv_cache_block_size ({self.kv_cache_block_size}) must be a multiple of 16 or 1" assert 1 <= self.tensor_parallel_size <= 8 - if is_plugin_mode(): - # plugin mode - assert ( - self.plugin_config is not None - ), "plugin_config is required in plugin mode" - self.hf_config = self.plugin_config.model_config.hf_config - else: - self.hf_config = get_hf_config(self.model) + # if is_plugin_mode(): + # # plugin mode + # assert ( + # self.plugin_config is not None + # ), "plugin_config is required in plugin mode" + # self.hf_config = self.plugin_config.model_config.hf_config + # else: + self.hf_config = get_hf_config(self.model) if not hasattr(self.hf_config, "rope_parameters"): # Compatible with both transformers < 5 rope_params = getattr(self.hf_config, "rope_scaling", {}) diff --git a/atom/plugin/config.py b/atom/plugin/config.py index 234dfb2d3..23592118e 100644 --- a/atom/plugin/config.py +++ b/atom/plugin/config.py @@ -38,6 +38,7 @@ class PluginConfig: def _generate_atom_config_from_vllm_config(config: Any) -> PluginConfig: from atom.config import Config, CompilationConfig + print('[zejun] ATOM vllm plugin config = ', config, flush=True) vllm_model_config = config.model_config vllm_scheduler_config = config.scheduler_config From b553cd2f305667abec4953aba748c10c53a46d93 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Tue, 10 Feb 2026 10:38:11 +0800 Subject: [PATCH 25/37] add Signed-off-by: zejunchen-zejun --- atom/plugin/config.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/atom/plugin/config.py b/atom/plugin/config.py index 23592118e..1bbc321f6 100644 --- a/atom/plugin/config.py +++ b/atom/plugin/config.py @@ -38,7 +38,6 @@ class PluginConfig: def _generate_atom_config_from_vllm_config(config: Any) -> PluginConfig: 
from atom.config import Config, CompilationConfig - print('[zejun] ATOM vllm plugin config = ', config, flush=True) vllm_model_config = config.model_config vllm_scheduler_config = config.scheduler_config @@ -94,7 +93,7 @@ def _generate_atom_config_from_vllm_config(config: Any) -> PluginConfig: ) return Config( - model=None, + model=config.model, max_num_batched_tokens=max_num_batched_tokens, max_num_seqs=vllm_scheduler_config.max_num_seqs, max_model_len=max_model_len, From e1e83d4aa65b32ead056de6b16b46417f1e6e5c9 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Tue, 10 Feb 2026 10:51:04 +0800 Subject: [PATCH 26/37] add Signed-off-by: zejunchen-zejun --- atom/config.py | 7 ------- atom/plugin/config.py | 2 +- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/atom/config.py b/atom/config.py index 343253489..9ecbdb365 100644 --- a/atom/config.py +++ b/atom/config.py @@ -609,13 +609,6 @@ def __post_init__(self): self.kv_cache_block_size % 16 == 0 or self.kv_cache_block_size == 1 ), f"kv_cache_block_size ({self.kv_cache_block_size}) must be a multiple of 16 or 1" assert 1 <= self.tensor_parallel_size <= 8 - # if is_plugin_mode(): - # # plugin mode - # assert ( - # self.plugin_config is not None - # ), "plugin_config is required in plugin mode" - # self.hf_config = self.plugin_config.model_config.hf_config - # else: self.hf_config = get_hf_config(self.model) if not hasattr(self.hf_config, "rope_parameters"): # Compatible with both transformers < 5 diff --git a/atom/plugin/config.py b/atom/plugin/config.py index 1bbc321f6..bf1af0f62 100644 --- a/atom/plugin/config.py +++ b/atom/plugin/config.py @@ -93,7 +93,7 @@ def _generate_atom_config_from_vllm_config(config: Any) -> PluginConfig: ) return Config( - model=config.model, + model=vllm_model_config.model, max_num_batched_tokens=max_num_batched_tokens, max_num_seqs=vllm_scheduler_config.max_num_seqs, max_model_len=max_model_len, From 484e17d1397d6a8c9592a3cf897f8842cc328648 Mon Sep 17 00:00:00 2001 From: 
zejunchen-zejun Date: Tue, 10 Feb 2026 10:59:23 +0800 Subject: [PATCH 27/37] add Signed-off-by: zejunchen-zejun --- recipes/vLLM-ATOM-OOT-Plugin-Backend.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/recipes/vLLM-ATOM-OOT-Plugin-Backend.md b/recipes/vLLM-ATOM-OOT-Plugin-Backend.md index a0629e850..aa37c4244 100644 --- a/recipes/vLLM-ATOM-OOT-Plugin-Backend.md +++ b/recipes/vLLM-ATOM-OOT-Plugin-Backend.md @@ -15,9 +15,10 @@ git checkout origin/zejun/plugin_for_atom_1223 pip install -e . 2>&1 | tee build.log ``` For AITER, there is no specific requirements, however, if you find any latest fusion kernels are missing, you may need to upgrade the AITER -Additionally, you may need to upgrade your triton version by: +Additionally, you may need to upgrade your triton and transformers by: ```bash pip install --upgrade triton +pip install transformers==5.0.0 ``` ### Launching server of vLLM with ATOM OOT Plugin Platform From 36b6fd393a9e46681638922d8d84f0b69fe75d24 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Tue, 10 Feb 2026 16:44:18 +0800 Subject: [PATCH 28/37] add Signed-off-by: zejunchen-zejun --- recipes/vLLM-ATOM-OOT-Plugin-Backend.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/recipes/vLLM-ATOM-OOT-Plugin-Backend.md b/recipes/vLLM-ATOM-OOT-Plugin-Backend.md index aa37c4244..6ba2d4e23 100644 --- a/recipes/vLLM-ATOM-OOT-Plugin-Backend.md +++ b/recipes/vLLM-ATOM-OOT-Plugin-Backend.md @@ -15,10 +15,11 @@ git checkout origin/zejun/plugin_for_atom_1223 pip install -e . 
2>&1 | tee build.log ``` For AITER, there is no specific requirements, however, if you find any latest fusion kernels are missing, you may need to upgrade the AITER -Additionally, you may need to upgrade your triton and transformers by: +Additionally, you may need to install some dependencies by: ```bash pip install --upgrade triton pip install transformers==5.0.0 +pip install git+https://github.com/foundation-model-stack/fastsafetensors.git ``` ### Launching server of vLLM with ATOM OOT Plugin Platform From b1fb7b662e80ced17daf512fd0fdc056cf51ed83 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Tue, 10 Feb 2026 16:54:30 +0800 Subject: [PATCH 29/37] add Signed-off-by: zejunchen-zejun --- recipes/vLLM-ATOM-OOT-Plugin-Backend.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/recipes/vLLM-ATOM-OOT-Plugin-Backend.md b/recipes/vLLM-ATOM-OOT-Plugin-Backend.md index 6ba2d4e23..24e11b538 100644 --- a/recipes/vLLM-ATOM-OOT-Plugin-Backend.md +++ b/recipes/vLLM-ATOM-OOT-Plugin-Backend.md @@ -54,11 +54,11 @@ vllm serve $model_path \ 2>&1 | tee log.serve.log & ``` -If you want to disable the ATOM OOT plugin platform, you can use below env flags. The default value is 0 +If you want to disable the ATOM OOT plugin platform and model register, you can use below env flags. The default value is 0 ```bash export ATOM_DISABLE_VLLM_PLUGIN=1 ``` -If you want to disable the ATOM Attention Backend, you can use below env flags. The default value is 0 +If you only want to disable the ATOM Attention Backend, you can use below env flags. 
The default value is 0 ```bash export ATOM_DISABLE_VLLM_PLUGIN_ATTENTION=1 ``` From 0f0bedc456bee285fd774f716c6cced685b0278e Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Tue, 10 Feb 2026 17:25:16 +0800 Subject: [PATCH 30/37] add Signed-off-by: zejunchen-zejun --- atom/plugin/vllm/model_wrapper.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/atom/plugin/vllm/model_wrapper.py b/atom/plugin/vllm/model_wrapper.py index a1d32fcf8..bcca17ad1 100644 --- a/atom/plugin/vllm/model_wrapper.py +++ b/atom/plugin/vllm/model_wrapper.py @@ -33,10 +33,10 @@ def _get_atom_model_cls(model_arch: str) -> type: - try: + if model_arch is not None and model_arch in _ATOM_MODEL_CLASSES: model_ref = _ATOM_MODEL_CLASSES[model_arch] - except KeyError as e: - raise ValueError(f"Unsupported ATOM model architecture: {model_arch}") from e + else: + raise ValueError(f"The {model_arch} is not supported by ATOM OOT backend") module_path, class_name = model_ref.split(":", 1) return getattr(importlib.import_module(module_path), class_name) From a05111847b6dec612cd040155d5136c7fa0964e8 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Tue, 10 Feb 2026 18:20:01 +0800 Subject: [PATCH 31/37] add Signed-off-by: zejunchen-zejun --- recipes/vLLM-ATOM-OOT-Plugin-Backend.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/recipes/vLLM-ATOM-OOT-Plugin-Backend.md b/recipes/vLLM-ATOM-OOT-Plugin-Backend.md index 24e11b538..504753d3e 100644 --- a/recipes/vLLM-ATOM-OOT-Plugin-Backend.md +++ b/recipes/vLLM-ATOM-OOT-Plugin-Backend.md @@ -2,9 +2,9 @@ ATOM can work as the OOT plugin backend of vLLM. The OOT register mechanism is quite mature and most of accelerators have leveraged this design to register their devices into vLLM without any code changes in upper framework. ATOM follows this design and provide the layer/op and model implementations to vLLM. The frontend users can launch vLLM server like before and there is no need to specify any arguments. 
Meanwhile the ATOM platform can leverage most of the vLLM features and focus more on model- and kernel-level optimizations. For the overall design, here is a RFC to enable ATOM work as the OOT plugin platform of vLLM: https://github.com/ROCm/ATOM/issues/201 ## Preparing environment for vLLM with ATOM model backend -Pull the latest docker from vLLM official nightly docker for ROCm +Pull the vLLM official docker for ROCm. If you are using the vLLM nightly docker, there could be incompatible error because vLLM is changing its code and may break the class/module import in ATOM ```bash -docker pull rocm/vllm-dev:nightly +docker pull rocm/vllm-dev:nightly_main_20260118 ``` Then the ATOM should be installed. When the following PR merged, you can use ATOM main branch From a00f59e0ad33a50b7c210b08ed4bde5630b61b7e Mon Sep 17 00:00:00 2001 From: Guanbao Yu Date: Wed, 11 Feb 2026 19:00:58 +0800 Subject: [PATCH 32/37] register attn backend to sgl from ATOM --- atom/config.py | 26 +- atom/model_ops/radix_attention.py | 10 +- atom/models/qwen3_moe.py | 77 +- atom/plugin/attention_backend/__init__.py | 0 .../attention_backend/sgl_attn_backend.py | 981 ++++++++++++++++++ atom/plugin/register.py | 4 +- atom/utils/envs.py | 1 + 7 files changed, 1074 insertions(+), 25 deletions(-) create mode 100644 atom/plugin/attention_backend/__init__.py create mode 100644 atom/plugin/attention_backend/sgl_attn_backend.py diff --git a/atom/config.py b/atom/config.py index 9ecbdb365..b10168428 100644 --- a/atom/config.py +++ b/atom/config.py @@ -609,19 +609,27 @@ def __post_init__(self): self.kv_cache_block_size % 16 == 0 or self.kv_cache_block_size == 1 ), f"kv_cache_block_size ({self.kv_cache_block_size}) must be a multiple of 16 or 1" assert 1 <= self.tensor_parallel_size <= 8 - self.hf_config = get_hf_config(self.model) + if is_plugin_mode(): + # plugin mode + assert ( + self.plugin_config is not None + ), "plugin_config is required in plugin mode" + self.hf_config = 
self.plugin_config.model_config.hf_config + else: + self.hf_config = get_hf_config(self.model) + + self.generation_config = get_generation_config(self.model) + if self.generation_config is not None: + if ( + eos_ids := getattr(self.generation_config, "eos_token_id", None) + ) is not None: + self.stop_token_ids = [eos_ids] if isinstance(eos_ids, int) else eos_ids if not hasattr(self.hf_config, "rope_parameters"): # Compatible with both transformers < 5 - rope_params = getattr(self.hf_config, "rope_scaling", {}) + rope_params = getattr(self.hf_config, "rope_scaling", {}) or {} rope_params["rope_theta"] = self.hf_config.rope_theta + rope_params["rope_type"] = getattr(rope_params, "rope_type", "default") self.hf_config.rope_parameters = rope_params - - self.generation_config = get_generation_config(self.model) - if self.generation_config is not None: - if ( - eos_ids := getattr(self.generation_config, "eos_token_id", None) - ) is not None: - self.stop_token_ids = [eos_ids] if isinstance(eos_ids, int) else eos_ids self.quant_config = get_quant_config(self.hf_config) hf_config_max_position_embeddings = getattr( self.hf_config, "max_position_embeddings", 8192 diff --git a/atom/model_ops/radix_attention.py b/atom/model_ops/radix_attention.py index 25388b384..af5e94056 100644 --- a/atom/model_ops/radix_attention.py +++ b/atom/model_ops/radix_attention.py @@ -9,6 +9,7 @@ from .base_attention import BaseAttention from atom.plugin.prepare import is_plugin_mode, is_sglang from atom.models.utils import maybe_prefix +from atom.utils import envs class RadixAttention(BaseAttention): @@ -47,7 +48,6 @@ def __init__( prefix=prefix, **kwargs, ) - self.rotary_emb = rotary_emb if is_sglang(): from sglang.srt.layers.radix_attention import RadixAttention @@ -64,6 +64,8 @@ def __init__( raise NotImplementedError( "RadixAttention is only supported for plugin mode for sglang for now" ) + # if True, save cache will be done in rope + self.use_aiter_rope_fused_qknorm = 
envs.ATOM_ROPE_FUSED_QKNORM def forward_impl_plugin_mode( self, @@ -82,10 +84,8 @@ def forward_impl_plugin_mode( # for sglang, forward_batch is required forward_batch = kwargs.get("forward_batch", None) assert forward_batch is not None, "forward_batch is required for sglang" - if self.rotary_emb is not None: - assert positions is not None, "positions is required for ROPE" - query, key = self.rotary_emb(positions, query, key) - return self.attn(q=query, k=key, v=value, forward_batch=forward_batch) + # forward_batch contains the filed attn_backend, which will find the backend registered in ATOM + return self.attn(query, key, value, forward_batch=forward_batch, save_kv_cache=not self.use_aiter_rope_fused_qknorm) else: raise NotImplementedError( "RadixAttention is only supported for plugin mode for sglang for now" diff --git a/atom/models/qwen3_moe.py b/atom/models/qwen3_moe.py index 97cec9829..91471c762 100644 --- a/atom/models/qwen3_moe.py +++ b/atom/models/qwen3_moe.py @@ -5,7 +5,7 @@ from aiter.dist.parallel_state import get_pp_group, get_tensor_model_parallel_world_size # from atom.model_ops.rotary_embedding import get_rope -from aiter.rotary_embedding import get_rope +from aiter.rotary_embedding import get_rope, AiterFusedSetKVBufferArg from atom.config import Config, QuantizationConfig from atom.model_ops.activation import SiluAndMul @@ -30,6 +30,7 @@ from atom.utils.decorators import support_torch_compile from torch import nn from atom.model_loader.loader import load_model_in_plugin_mode +from atom.plugin.prepare import is_sglang # import torch.distributed as dist from transformers import PretrainedConfig @@ -38,7 +39,7 @@ ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION = ( envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION ) - +ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE = envs.ATOM_ROPE_FUSED_QKNORM class Qwen3MoeMLP(nn.Module): def __init__( @@ -224,6 +225,61 @@ def __init__( self.kv_cache_dtype = kv_cache_dtype self.layer_num = layer_num + self.k_scale = 
torch.tensor([1.0], dtype=torch.float32) + self.v_scale = torch.tensor([1.0], dtype=torch.float32) + + def forward_sgl_plugin_mode( + self, + positions: torch.Tensor, + qkv: torch.Tensor, + **model_kwargs: dict[str, Any] | None, + ): + if ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE: + forward_batch = model_kwargs.get("forward_batch", None) + assert forward_batch is not None, "forward_batch is required for sglang" + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(self.layer_num) + block_size = 1024 # Default fallback + if hasattr(forward_batch, 'attn_backend') and hasattr(forward_batch.attn_backend, 'page_size'): + block_size = forward_batch.attn_backend.page_size + elif hasattr(forward_batch.token_to_kv_pool, 'allocator') and hasattr(forward_batch.token_to_kv_pool.allocator, 'page_size'): + block_size = forward_batch.token_to_kv_pool.allocator.page_size + elif hasattr(forward_batch.token_to_kv_pool, 'page_size'): + block_size = forward_batch.token_to_kv_pool.page_size + x = 16 // k_buffer.element_size() + aiter_fused_set_kv_buffer_arg = AiterFusedSetKVBufferArg( + kv_cache = (k_buffer, v_buffer), + cache_loc = forward_batch.out_cache_loc, + k_scale = self.k_scale, + v_scale = self.v_scale, + return_kv = True, + use_shuffle_layout = True, + block_size = block_size, + x = x, + ) + q, k, v = self.rotary_emb( + qkv, + self.q_norm.weight, + self.k_norm.weight, + positions, + self.num_heads, + self.num_kv_heads, + self.q_norm.eps, + fused_set_kv_buffer_arg=aiter_fused_set_kv_buffer_arg, + ) + else: + q, k, v = torch.split( + qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1 + ) + # Add qk-norm + q = self.q_norm(q) + k = self.k_norm(k) + + q, k = self.rotary_emb(positions, q, k) + + attn_output = self.attn( + q, k, v, positions=positions, **model_kwargs + ) + return attn_output def forward( self, @@ -241,13 +297,16 @@ def forward( query=q, key=k, value=v, positions=positions, q_scale=None, qkv=qkv ) else: - # Add qk-norm - q = self.q_norm(q) 
- k = self.k_norm(k) + if is_sglang(): + attn_output = self.forward_sgl_plugin_mode(positions, qkv, **model_kwargs) + else: + # Add qk-norm + q = self.q_norm(q) + k = self.k_norm(k) - attn_output = self.attn( - query=q, key=k, value=v, positions=positions, **model_kwargs - ) + attn_output = self.attn( + query=q, key=k, value=v, positions=positions, **model_kwargs + ) output = self.o_proj(attn_output) return output @@ -261,7 +320,7 @@ def __init__(self, atom_config=None, layer_num: int = 0, prefix: str = "") -> No self.hidden_size = config.hidden_size rope_params = config.rope_parameters rope_theta = rope_params["rope_theta"] - rope_scaling = rope_params + rope_scaling = None if rope_params["rope_type"] == "default" else rope_params kv_cache_dtype = atom_config.kv_cache_dtype max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # DecoderLayers are created with `make_layers` which passes the prefix diff --git a/atom/plugin/attention_backend/__init__.py b/atom/plugin/attention_backend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/atom/plugin/attention_backend/sgl_attn_backend.py b/atom/plugin/attention_backend/sgl_attn_backend.py new file mode 100644 index 000000000..662c072bf --- /dev/null +++ b/atom/plugin/attention_backend/sgl_attn_backend.py @@ -0,0 +1,981 @@ +from __future__ import annotations + +""" +end to end attention solution with aiter kernels +""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +import torch + +from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend +from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode + +if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention + from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.speculative.spec_info import SpecInput + +try: + from aiter import 
( + flash_attn_varlen_func, + dtypes, + get_pa_metadata_info_v1, + get_pa_metadata_v1, + pa_fwd_asm, + pa_persistent_fwd, + ) +except ImportError: + print( + "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device." + ) + +import triton +import triton.language as tl + +@triton.jit +def reshape_and_cache_shuffle_kernel( + key_ptr, # [num_tokens, num_kv_heads, head_size] + value_ptr, # [num_tokens, num_kv_heads, head_size] + key_cache_ptr, # [num_blocks, num_kv_heads, head_size // x, block_size, x] + value_cache_ptr, # [num_blocks, num_kv_heads, block_size // x, head_size, x] + slot_mapping_ptr, # [num_tokens] + k_scale_ptr, + v_scale_ptr, + x, + k_stride0, + v_stride0, + block_size, + head_size, + num_kv_heads, + BLOCK_SIZE: tl.constexpr, + QUANT: tl.constexpr, +): + tid = tl.program_id(0) + head_id = tl.program_id(1) + offset = tl.arange(0, BLOCK_SIZE) + src_offset_k = tid * k_stride0 + head_id * head_size + src_offset_v = tid * v_stride0 + head_id * head_size + slot_id = tl.load(slot_mapping_ptr + tid) + if slot_id < 0: + return + block_id = slot_id // block_size + block_offset = slot_id % block_size + dst_offset = ( + block_id * num_kv_heads * head_size * block_size + + head_id * head_size * block_size + ) + dst_k_shuffle_offset = ( + dst_offset + offset // x * block_size * x + block_offset * x + offset % x + ) + dst_v_shuffle_offset = ( + dst_offset + + block_offset // x * head_size * x + + offset * x + + block_offset % x + ) + k_val = tl.load(key_ptr + src_offset_k + offset) + v_val = tl.load(value_ptr + src_offset_v + offset) + if QUANT: + k_scale = tl.load(k_scale_ptr) + v_scale = tl.load(v_scale_ptr) + k_dtype = key_cache_ptr.type.element_ty + v_dtype = value_cache_ptr.type.element_ty + k_val = (k_val.to(tl.float32) / k_scale).to(k_dtype) + v_val = (v_val.to(tl.float32) / v_scale).to(v_dtype) + tl.store(key_cache_ptr + dst_k_shuffle_offset, k_val) + tl.store(value_cache_ptr + dst_v_shuffle_offset, v_val) + +def 
reshape_and_cache_shuffle_triton( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scales: torch.Tensor, + v_scales: torch.Tensor, +): + num_tokens = slot_mapping.shape[0] + _, num_kv_heads, head_size = key.shape + num_blocks, block_size, _, _ = key_cache.shape + x = 16 // key_cache.element_size() + k_cache_template = torch.empty( + [num_blocks, num_kv_heads, head_size // x, block_size, x], + dtype=key_cache.dtype, + device="meta", + ) + v_cache_template = torch.empty( + [num_blocks, num_kv_heads, block_size // x, head_size, x], + dtype=value_cache.dtype, + device="meta", + ) + new_key_cache = key_cache.view_as(k_cache_template) + new_value_cache = value_cache.view_as(v_cache_template) + QUANT = False + if kv_cache_dtype.startswith("fp8"): + QUANT = True + grid = ( + num_tokens, + num_kv_heads, + ) + reshape_and_cache_shuffle_kernel[grid]( + key, + value, + new_key_cache, + new_value_cache, + slot_mapping, + k_scales, + v_scales, + x, + key.stride(0), + value.stride(0), + block_size, + head_size, + num_kv_heads, + BLOCK_SIZE=head_size, + QUANT=QUANT, + ) + +@dataclass +class ForwardMetadata: + # kv_indptr and kv_indices are only used in MLA mode, optional for non-MLA mode + kv_indptr: Optional[torch.Tensor] + kv_indices: Optional[torch.Tensor] + qo_indptr: Optional[torch.Tensor] + kv_last_page_len: Optional[torch.Tensor] + max_q_len: Optional[int] + max_kv_len: Optional[int] + page_table: Optional[torch.Tensor] + kv_lens: Optional[torch.Tensor] + # PA metadata for pa_persistent_fwd (only used in decode mode, non-MLA) + pa_metadata_qo_indptr: Optional[torch.Tensor] = None + pa_metadata_pages_kv_indptr: Optional[torch.Tensor] = None + pa_metadata_kv_indices: Optional[torch.Tensor] = None + pa_metadata_context_lens: Optional[torch.Tensor] = None + pa_metadata_max_qlen: Optional[int] = None + pa_metadata_tp_q_head_num: Optional[int] = None + # Prefill metadata 
for mha_batch_prefill_func (only used in prefill mode, non-MLA) + # prefill_pages_kv_indptr: Optional[torch.Tensor] = None + # prefill_kv_indices: Optional[torch.Tensor] = None + # prefill_kv_last_page_lens: Optional[torch.Tensor] = None + + +class ATOMAttnBackendForSgl(AiterAttnBackend): + def __init__( + self, + model_runner: ModelRunner, + skip_prefill: bool = False, + kv_indptr_buf: Optional[torch.Tensor] = None, + ): + super().__init__(model_runner, skip_prefill, kv_indptr_buf) + + mapping = getattr( + model_runner.token_to_kv_pool, "full_attention_layer_id_mapping", None + ) + + if isinstance(mapping, dict) and mapping: + first_full_attn_id = next(iter(mapping.keys())) + else: + first_full_attn_id = 0 + + self.q_dtype = model_runner.dtype # Save q dtype for pa_metadata building + + assert not self.use_mla, "MLA mode is not implemented yet in ATOMAttnBackendForSgl." + + # Pre-initialized qo_indptr for pa_persistent_fwd decode mode: [0, 1, 2, ..., max_bs] + # In decode mode, each sequence has 1 token, so this is always [0, 1, 2, ..., batch_size] + max_bs = model_runner.req_to_token_pool.size + self.pa_decode_qo_indptr = torch.arange( + 0, max_bs + 1, dtype=torch.int32, device=model_runner.device + ) + self.seq_lens = torch.zeros( + (max_bs,), dtype=torch.int32, device=model_runner.device + ) + self.page_table = torch.zeros( + (max_bs, self.max_context_len // self.page_size), dtype=torch.int32, device=model_runner.device + ) + # Pre-compute strided indices for page_table construction (used in both CUDA Graph and non-CUDA Graph modes) + self.strided_indices = torch.arange( + 0, self.max_context_len, self.page_size, device=model_runner.device + ) + + if not self.use_mla: + # Pre-allocate buffers for pa_persistent_fwd (used in both CUDA graph and non-CUDA graph modes) + max_num_blocks_per_seq = (self.max_context_len + self.page_size - 1) // self.page_size + max_total_blocks = max_bs * max_num_blocks_per_seq + self.pa_kv_indices = torch.zeros( + max_total_blocks, 
dtype=torch.int32, device=self.device + ) + # Pre-allocate pa_kv_indptr buffer (similar to self.kv_indptr, but dedicated for pa_persistent_fwd) + self.pa_kv_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=self.device + ) + # Pre-initialized batch indices [0, 1, 2, ..., max_bs-1] for Triton kernel + self.pa_batch_indices = torch.arange( + 0, max_bs, dtype=torch.int32, device=self.device + ) + + # Pre-allocated descale tensors for FP8 attention (q, k, v all use scale=1.0) + + + self.logits_soft_cap = 0.0 + + self.forward_metadata: ForwardMetadata = None + + self.pa_metadata_buffers = None + + k_buffer, _ = model_runner.token_to_kv_pool.get_kv_buffer(first_full_attn_id) + num_slots, num_kv_heads, _ = k_buffer.shape + block_size = self.page_size + num_blocks = num_slots // block_size + max_total_tokens = num_blocks * block_size + self.k_qscale = torch.ones( + num_kv_heads, max_total_tokens, dtype=torch.float32, device=self.device + ) + self.v_qscale = torch.ones( + num_kv_heads, max_total_tokens, dtype=torch.float32, device=self.device + ) + self.decode_using_pa_ps = self.page_size == 1024 + + + def init_forward_metadata(self, forward_batch: ForwardBatch): + """Init auxiliary variables for triton attention backend.""" + + bs = forward_batch.batch_size + kv_indptr = self.kv_indptr + spec_info = forward_batch.spec_info + # qo_indptr = None + # kv_last_page_len = None + # max_q_len = None + page_table = None + + if forward_batch.forward_mode.is_decode_or_idle(): + if spec_info is None: + kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.empty( + forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + else: + kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices + 
bs = kv_indptr.shape[0] - 1 + + if self.use_mla: + raise NotImplementedError("MLA decode mode is not implemented yet in ATOMAttnBackendForSgl.") + else: + if self.decode_using_pa_ps: + # Non-MLA decode mode: use same logic as CUDA Graph mode for page_table construction + seq_lens_cpu = forward_batch.seq_lens_cpu + if seq_lens_cpu is None: + seq_lens_cpu = forward_batch.seq_lens.cpu() + + # Common setup consistent with CUDA Graph mode (init_forward_metadata_replay_cuda_graph) + page_table_persistent = self.page_table + seq_lens_persistent = self.seq_lens + seq_lens_persistent.fill_(0) + page_table_persistent.fill_(0) + seq_lens_persistent[:bs].copy_(forward_batch.seq_lens, non_blocking=True) + max_seq_pages = (seq_lens_cpu.max().item() + self.page_size - 1) // self.page_size + 1 + page_table = self.req_to_token[forward_batch.req_pool_indices[:, None], self.strided_indices[:max_seq_pages][None, :],] + page_table_persistent[:bs, :max_seq_pages].copy_(page_table // self.page_size, non_blocking=True) + else: + page_table = forward_batch.req_to_token_pool.req_to_token[forward_batch.req_pool_indices, :] + + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + None, # qo_indptr not used in non-MLA mode + None, # kv_last_page_len not used in non-MLA mode + 1, # max_q_len = 1 for decode mode + None, + page_table_persistent[:bs, :max_seq_pages] if self.decode_using_pa_ps else page_table, + seq_lens_persistent[:bs] if self.decode_using_pa_ps else forward_batch.seq_lens, + ) + + # Build pa_metadata for pa_persistent_fwd + if self.decode_using_pa_ps: + self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) + # return # Early return for non-MLA decode mode + else: + prefix_lens = forward_batch.extend_prefix_lens + + if self.use_mla: + raise NotImplementedError("MLA prefill mode is not implemented yet in ATOMAttnBackendForSgl.") + else: + self.indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + 
forward_batch.seq_lens_sum, + prefix_lens, + encoder_lens=forward_batch.encoder_lens, + spec_info=None, + ) + # Get page_table for mha_batch_prefill_func + page_table = forward_batch.req_to_token_pool.req_to_token[forward_batch.req_pool_indices, :] + self.forward_metadata = ForwardMetadata( + self.indices_updater_prefill.kv_indptr, + self.indices_updater_prefill.kv_indices, + self.qo_indptr[: bs + 1], # qo_indptr is set by indices_updater_prefill + None, + self.indices_updater_prefill.max_q_len, + self.indices_updater_prefill.max_kv_len, + None, + forward_batch.seq_lens, + ) + + if (forward_batch.forward_mode.is_extend() and + not self.use_mla and + self.forward_metadata.page_table is not None): + if self.page_size > 1: + seq_lens_cpu = forward_batch.seq_lens_cpu + if seq_lens_cpu is None: + seq_lens_cpu = forward_batch.seq_lens.cpu() + max_seq_pages = (seq_lens_cpu.max().item() + self.page_size - 1) // self.page_size + 1 + self.forward_metadata.page_table = ( + self.forward_metadata.page_table[:, self.strided_indices[:max_seq_pages]] // self.page_size + ) + if self.decode_using_pa_ps: + self._build_pa_metadata_for_prefill(forward_batch.batch_size) + if not self.decode_using_pa_ps and self.page_size > 1 and self.forward_metadata.page_table is not None: + self.forward_metadata.page_table = ( + self.forward_metadata.page_table[:, self.strided_indices] // self.page_size + ) + + def _allocate_pa_metadata_buffers( + self, + work_metadata_ptrs_size, + work_metadata_ptrs_type, + work_indptr_size, + work_indptr_type, + work_info_size, + work_info_type, + reduce_indptr_size, + reduce_indptr_type, + reduce_final_map_size, + reduce_final_map_type, + reduce_partial_map_size, + reduce_partial_map_type, + ): + """Allocate or reuse pa_metadata buffers.""" + if self.pa_metadata_buffers is None: + self.pa_metadata_buffers = {} + + def _get_size_val(size): + return size[0] if isinstance(size, tuple) else size + + # Allocate work_metadata_ptrs + size_val = 
_get_size_val(work_metadata_ptrs_size) + if ("work_metadata_ptrs" not in self.pa_metadata_buffers or + self.pa_metadata_buffers["work_metadata_ptrs"].shape[0] < size_val): + self.pa_metadata_buffers["work_metadata_ptrs"] = torch.empty( + work_metadata_ptrs_size, dtype=work_metadata_ptrs_type, device=self.device + ) + + # Allocate work_indptr + size_val = _get_size_val(work_indptr_size) + if ("work_indptr" not in self.pa_metadata_buffers or + self.pa_metadata_buffers["work_indptr"].shape[0] < size_val): + self.pa_metadata_buffers["work_indptr"] = torch.zeros( + work_indptr_size, dtype=work_indptr_type, device=self.device + ) + else: + self.pa_metadata_buffers["work_indptr"].zero_() + + # Allocate work_info + size_val = _get_size_val(work_info_size) + if ("work_info" not in self.pa_metadata_buffers or + len(self.pa_metadata_buffers["work_info"].shape) < len(work_info_size) or + self.pa_metadata_buffers["work_info"].shape[0] < size_val): + self.pa_metadata_buffers["work_info"] = torch.zeros( + work_info_size, dtype=work_info_type, device=self.device + ) + else: + self.pa_metadata_buffers["work_info"].zero_() + + # Allocate reduce_indptr + size_val = _get_size_val(reduce_indptr_size) + if ("reduce_indptr" not in self.pa_metadata_buffers or + self.pa_metadata_buffers["reduce_indptr"].shape[0] < size_val): + self.pa_metadata_buffers["reduce_indptr"] = torch.zeros( + reduce_indptr_size, dtype=reduce_indptr_type, device=self.device + ) + else: + self.pa_metadata_buffers["reduce_indptr"].zero_() + + # Allocate reduce_final_map + size_val = _get_size_val(reduce_final_map_size) + if ("reduce_final_map" not in self.pa_metadata_buffers or + len(self.pa_metadata_buffers["reduce_final_map"].shape) < len(reduce_final_map_size) or + self.pa_metadata_buffers["reduce_final_map"].shape[0] < size_val): + self.pa_metadata_buffers["reduce_final_map"] = torch.zeros( + reduce_final_map_size, dtype=reduce_final_map_type, device=self.device + ) + else: + 
self.pa_metadata_buffers["reduce_final_map"].zero_() + + # Allocate reduce_partial_map + reduce_partial_map_size_val = reduce_partial_map_size if isinstance(reduce_partial_map_size, int) else reduce_partial_map_size[0] + if ("reduce_partial_map" not in self.pa_metadata_buffers or + self.pa_metadata_buffers["reduce_partial_map"].shape[0] < reduce_partial_map_size_val): + self.pa_metadata_buffers["reduce_partial_map"] = torch.zeros( + reduce_partial_map_size, dtype=reduce_partial_map_type, device=self.device + ) + else: + self.pa_metadata_buffers["reduce_partial_map"].zero_() + + def _build_pa_metadata_for_decode( + self, + batch_size: int, + tp_q_head_num: Optional[int] = None, + ): + """Build pa_metadata buffers for pa_persistent_fwd in decode mode. + + This method prepares all metadata buffers needed for pa_persistent_fwd kernel. + The metadata can be reused across multiple layers in the same forward pass. + + Args: + batch_size: Batch size for the current forward pass + tp_q_head_num: Number of Q heads per TP rank. If None, uses self.num_head. 
+ """ + max_qlen = 1 + + # Use provided tp_q_head_num or default to self.num_head + if tp_q_head_num is None: + tp_q_head_num = self.num_head + + # kv_dtype_for_metadata = dtypes.fp8 + ( + (work_metadata_ptrs_size, work_metadata_ptrs_type), + (work_indptr_size, work_indptr_type), + (work_info_size, work_info_type), + (reduce_indptr_size, reduce_indptr_type), + (reduce_final_map_size, reduce_final_map_type), + (reduce_partial_map_size, reduce_partial_map_type), + ) = get_pa_metadata_info_v1( + batch_size, + self.num_kv_head, + ) + # Allocate metadata buffers with reuse optimization for multi-layer forward passes + self._allocate_pa_metadata_buffers( + work_metadata_ptrs_size, + work_metadata_ptrs_type, + work_indptr_size, + work_indptr_type, + work_info_size, + work_info_type, + reduce_indptr_size, + reduce_indptr_type, + reduce_final_map_size, + reduce_final_map_type, + reduce_partial_map_size, + reduce_partial_map_type, + ) + qo_indptr = self.pa_decode_qo_indptr[: batch_size + 1] + + # Get context_lens (kv_lens is always set before calling _build_pa_metadata_for_decode) + # Note: kv_lens comes from self.seq_lens which is already int32 + context_lens = self.forward_metadata.kv_lens + + kernel_block_size = self.page_size + num_blocks_per_seq = (context_lens + kernel_block_size - 1) // kernel_block_size + # Use dedicated pa_kv_indptr buffer (similar to self.kv_indptr, but for pa_persistent_fwd) + pages_kv_indptr = self.pa_kv_indptr[: batch_size + 1] + pages_kv_indptr[1 : batch_size + 1] = torch.cumsum(num_blocks_per_seq, dim=0) + + # Convert page_table to kv_indices (block indices) using Triton kernel to avoid sync + # page_table shape: [batch_size, max_num_blocks_per_seq] + # Note: page_table comes from self.page_table which is already int32 and always set before this call + page_table = self.forward_metadata.page_table + + # Use Triton kernel to gather kv_indices from page_table (avoids high-level indexing sync) + create_flashinfer_kv_indices_triton[(batch_size,)]( 
+ page_table, + self.pa_batch_indices[:batch_size], # [0, 1, 2, ..., batch_size-1] + num_blocks_per_seq, + pages_kv_indptr, + None, # kv_start_idx + self.pa_kv_indices, + page_table.stride(0), + ) + # Use the full buffer - pa_persistent_fwd reads only valid elements based on pages_kv_indptr + kv_indices = self.pa_kv_indices + + get_pa_metadata_v1( + seqlens_qo_indptr=qo_indptr, + pages_kv_indptr=pages_kv_indptr, + context_lens=context_lens.int(), + num_heads_per_head_k=tp_q_head_num // self.num_kv_head, + num_heads_k=self.num_kv_head, + is_causal=True, + work_metadata_ptrs=self.pa_metadata_buffers["work_metadata_ptrs"], + work_indptr=self.pa_metadata_buffers["work_indptr"], + work_info=self.pa_metadata_buffers["work_info"], + reduce_indptr=self.pa_metadata_buffers["reduce_indptr"], + reduce_final_map=self.pa_metadata_buffers["reduce_final_map"], + reduce_partial_map=self.pa_metadata_buffers["reduce_partial_map"], + kv_granularity=max(kernel_block_size, 16), + block_size=kernel_block_size, + max_seqlen_qo=max_qlen, + uni_seqlen_qo=max_qlen, + fast_mode=True, + topk=-1, + max_split_per_batch=-1, + ) + # Store computed values in ForwardMetadata for reuse in forward_decode + self.forward_metadata.pa_metadata_qo_indptr = qo_indptr + self.forward_metadata.pa_metadata_pages_kv_indptr = pages_kv_indptr + self.forward_metadata.pa_metadata_kv_indices = kv_indices + self.forward_metadata.pa_metadata_context_lens = context_lens + self.forward_metadata.pa_metadata_max_qlen = max_qlen + self.forward_metadata.pa_metadata_tp_q_head_num = tp_q_head_num + + def _build_pa_metadata_for_prefill(self, batch_size: int): + """Build metadata for mha_batch_prefill_func in prefill mode. + + This method prepares page-level metadata needed for mha_batch_prefill_func. + The metadata is computed once per forward pass and reused across all layers. 
+ """ + block_size = self.page_size + context_lens = self.forward_metadata.kv_lens + num_blocks_per_seq = (context_lens + block_size - 1) // block_size + + # Page-level kv_indptr (reuse pa_kv_indptr buffer) + pages_kv_indptr = self.pa_kv_indptr[: batch_size + 1] + pages_kv_indptr[1 : batch_size + 1] = torch.cumsum(num_blocks_per_seq, dim=0) + + # Build kv_indices from page_table using triton kernel + page_table = self.forward_metadata.page_table + create_flashinfer_kv_indices_triton[(batch_size,)]( + page_table, + self.pa_batch_indices[:batch_size], + num_blocks_per_seq, + pages_kv_indptr, + None, # kv_start_idx + self.pa_kv_indices, + page_table.stride(0), + ) + # kv_indices = self.pa_kv_indices + + # Compute kv_last_page_lens for each sequence + # kv_last_page_lens = ((context_lens - 1) % block_size + 1).int() + + # Store in ForwardMetadata for reuse in forward_extend + # self.forward_metadata.prefill_pages_kv_indptr = pages_kv_indptr + # self.forward_metadata.prefill_kv_indices = kv_indices + # self.forward_metadata.prefill_kv_last_page_lens = kv_last_page_lens + + def init_cuda_graph_state( + self, + max_bs: int, + max_num_tokens: int, + kv_indices_buf: Optional[torch.Tensor] = None, + ): + self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int) + if kv_indices_buf is None: + self.cuda_graph_kv_indices = torch.zeros( + (max_bs * self.max_context_len), + dtype=torch.int32, + device=self.device, + ) + else: + self.cuda_graph_kv_indices = kv_indices_buf + + # Always use preshuffle layout for pa_fwd_asm + self.page_table = torch.zeros( + (max_bs, self.max_context_len // self.page_size), dtype=torch.int32, device=self.device + ) + self.seq_lens = torch.zeros( + (max_bs,), dtype=torch.int32, device=self.device + ) + self.strided_indices = torch.arange( + 0, self.max_context_len, self.page_size, device=self.device + ) + + # Pre-allocate buffers for pa_metadata in CUDA graph mode (non-MLA decode) + if self.decode_using_pa_ps and not self.use_mla: + # 
Pre-allocate pa_metadata buffers for CUDA graph compatibility + # These buffers will be reused in capture and replay phases + # Use max_bs and max_qlen=1 (decode mode) to calculate buffer sizes + # max_qlen = 1 # decode mode + # kv_dtype_for_metadata = dtypes.fp8 + ( + (work_metadata_ptrs_size, work_metadata_ptrs_type), + (work_indptr_size, work_indptr_type), + (work_info_size, work_info_type), + (reduce_indptr_size, reduce_indptr_type), + (reduce_final_map_size, reduce_final_map_type), + (reduce_partial_map_size, reduce_partial_map_type), + ) = get_pa_metadata_info_v1( + max_bs, + self.num_kv_head, + ) + + # Pre-allocate buffers with maximum size for CUDA graph compatibility + self._allocate_pa_metadata_buffers( + work_metadata_ptrs_size, + work_metadata_ptrs_type, + work_indptr_size, + work_indptr_type, + work_info_size, + work_info_type, + reduce_indptr_size, + reduce_indptr_type, + reduce_final_map_size, + reduce_final_map_type, + reduce_partial_map_size, + reduce_partial_map_type, + ) + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInput], + ): + if forward_mode.is_decode_or_idle(): + if self.use_mla: + # MLA mode: kv_indptr and kv_indices are used in forward_decode + raise NotImplementedError("MLA decode mode is not implemented yet in ATOMAttnBackendForSgl.") + else: + # Non-MLA decode mode: kv_indptr and kv_indices are NOT used in forward_decode + # (forward_decode uses pa_metadata_pages_kv_indptr and pa_metadata_kv_indices instead) + page_table = self.page_table[:bs, :] + self.seq_lens[:bs].copy_(seq_lens, non_blocking=True) + seq_lens_persistent = self.seq_lens[:bs] + self.forward_metadata = ForwardMetadata( + None, # kv_indptr not used in non-MLA decode mode + None, # kv_indices not used in non-MLA decode mode + None, # qo_indptr will be set by 
_build_pa_metadata_for_decode + None, # kv_last_page_len not used in non-MLA mode + 1, # max_q_len = 1 for decode mode + None, # max_kv_len + page_table, + seq_lens_persistent, + ) + + # Build pa_metadata using CUDA graph buffers + if self.decode_using_pa_ps: + self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) + return # Early return for non-MLA decode mode + else: + raise ValueError(f"Invalid mode: {forward_mode=}") + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInput], + seq_lens_cpu: Optional[torch.Tensor], + out_cache_loc: Optional[torch.Tensor] = None, + ): + if forward_mode.is_decode_or_idle(): + # Common setup for both MLA and non-MLA modes + page_table_persistent = self.page_table + seq_lens_persistent = self.seq_lens + seq_lens_persistent.fill_(0) + page_table_persistent.fill_(0) + seq_lens_persistent[:bs].copy_(seq_lens, non_blocking=True) + max_seq_pages = (seq_lens_cpu.max().item() + self.page_size - 1) // self.page_size + 1 + page_table = self.req_to_token[req_pool_indices[:, None], self.strided_indices[:max_seq_pages][None, :],] + page_table_persistent[:bs, :max_seq_pages].copy_(page_table // self.page_size, non_blocking=True) + + if self.use_mla: + # MLA mode: kv_indptr and kv_indices are used in forward_decode + raise NotImplementedError("MLA decode mode is not implemented yet in ATOMAttnBackendForSgl.") + else: + # Non-MLA decode mode: kv_indptr and kv_indices are NOT used in forward_decode + # (forward_decode uses pa_metadata_pages_kv_indptr and pa_metadata_kv_indices instead) + self.forward_metadata = ForwardMetadata( + None, # kv_indptr not used in non-MLA decode mode + None, # kv_indices not used in non-MLA decode mode + None, + None, # kv_last_page_len not used in non-MLA mode + 1, # max_q_len = 1 for decode mode, non-MTP + None, # 
max_kv_len + page_table_persistent[:bs, :max_seq_pages], + seq_lens_persistent[:bs], + ) + + # Rebuild pa_metadata using CUDA graph buffers (updates content, keeps same addresses) + if self.decode_using_pa_ps: + self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) + else: + raise ValueError("Invalid forward mode") + + def set_kv_buffer_with_layout_shuffle( + self, + cache_loc, + k, + v, + k_buffer, + v_buffer, + k_scale, + v_scale, + block_size, + ): + num_slots, num_kv_heads, head_dim = k_buffer.shape + num_blocks = num_slots // block_size + num_slots_with_block = num_blocks * block_size + k_buffer = k_buffer[:num_slots_with_block].view(num_blocks, block_size, num_kv_heads, head_dim) + v_buffer = v_buffer[:num_slots_with_block].view(num_blocks, block_size, num_kv_heads, head_dim) + reshape_and_cache_shuffle_triton( + k, + v, + k_buffer, + v_buffer, + cache_loc, + "auto", + k_scale, + v_scale, + ) + + def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True): + # print(f"Running forward_extend with q shape {q.shape}, k shape {k.shape}, v shape {v.shape}", flush=True) + # print(f"q dtype: {q.dtype}, k dtype: {k.dtype}, v dtype: {v.dtype}", flush=True) + cache_loc = ( + forward_batch.out_cache_loc + if not layer.is_cross_attention + else forward_batch.encoder_out_cache_loc + ) + + self.logits_soft_cap = layer.logit_cap + + if k is not None: + assert v is not None + if save_kv_cache: + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + self.set_kv_buffer_with_layout_shuffle(cache_loc, k, v, k_buffer, v_buffer, layer.k_scale, layer.v_scale, self.page_size) + # forward_batch.token_to_kv_pool.set_kv_buffer( + # layer, cache_loc, k, v, layer.k_scale, layer.v_scale + # ) + + seqlens_in_batch = forward_batch.seq_lens + cu_seqlens_q = torch.nn.functional.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0) + ) + # use fp8 mha directly + if q.dtype != k.dtype and k.dtype == dtypes.fp8: 
+ q = q.to(dtypes.fp8) + o = flash_attn_varlen_func( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim), + v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim), + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_q, + max_seqlen_q=self.forward_metadata.max_q_len, + max_seqlen_k=self.forward_metadata.max_kv_len, + min_seqlen_q=0, + dropout_p=0.0, + softmax_scale=self.scale, + causal=True, + window_size=(-1, -1, 0), + sink_ptr=None, + ) + + return o.view(-1, layer.tp_q_head_num * layer.head_dim) + + + def forward_decode_pa( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): + q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) + + if layer.qk_head_dim != layer.v_head_dim: + o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) + else: + o = torch.empty_like(q) + + if save_kv_cache: + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + self.set_kv_buffer_with_layout_shuffle(forward_batch.out_cache_loc, k, v, k_buffer, v_buffer, layer.k_scale, layer.v_scale, self.page_size) + + if self.use_mla: + raise NotImplementedError("MLA decode mode is not implemented yet in ATOMAttnBackendForSgl.") + else: + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + + block_size = self.page_size + num_slots, num_kv_heads, head_size = k_buffer.shape + num_blocks = num_slots // block_size + k_buffer = k_buffer[:num_blocks * block_size].view(num_blocks, block_size, num_kv_heads, head_size) + v_buffer = v_buffer[:num_blocks * block_size].view(num_blocks, block_size, num_kv_heads, head_size) + + x = 16 // k_buffer.element_size() + k_cache_template = torch.empty( + [num_blocks, num_kv_heads, head_size // x, block_size, x], + dtype=k_buffer.dtype, + device="meta", + ) + v_cache_template = torch.empty( + [num_blocks, 
num_kv_heads, block_size // x, head_size, x], + dtype=v_buffer.dtype, + device="meta", + ) + new_key_cache = k_buffer.view_as(k_cache_template) + new_value_cache = v_buffer.view_as(v_cache_template) + q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) + pa_fwd_asm( + Q=q, + K=new_key_cache, + V=new_value_cache, + block_tables=self.forward_metadata.page_table, + context_lens=self.forward_metadata.kv_lens, + block_tables_stride0=self.forward_metadata.page_table.stride(0), + K_QScale=self.k_scale, + V_QScale=self.v_scale, + out_=o, + ) + return o + + def forward_decode_pa_ps( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): + q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) + + # Create o as 3D tensor [batch_size, num_heads, head_dim] for both MLA and pa_fwd_asm + # In decode mode, q.shape[0] equals batch_size (each sequence has 1 token) + # Use q.shape[0] instead of forward_batch.batch_size to be safe + batch_size = q.shape[0] + head_dim_out = layer.v_head_dim if layer.qk_head_dim != layer.v_head_dim else layer.head_dim + o = q.new_empty((batch_size, layer.tp_q_head_num, head_dim_out)) + + if save_kv_cache: + if self.use_mla: + raise NotImplementedError("MLA decode mode is not implemented yet in ATOMAttnBackendForSgl.") + else: + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + self.set_kv_buffer_with_layout_shuffle(forward_batch.out_cache_loc, k, v, k_buffer, v_buffer, layer.k_scale, layer.v_scale, self.page_size) + # Shuffle operation is already fused in rotary_emb, so just save directly + # forward_batch.token_to_kv_pool.set_kv_buffer( + # layer, forward_batch.out_cache_loc, k, v, layer.k_scale, layer.v_scale + # ) + + if self.use_mla: + raise NotImplementedError("MLA decode mode is not implemented yet in ATOMAttnBackendForSgl.") + else: + k_buffer, v_buffer = 
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + num_slots, num_kv_heads, head_size = k_buffer.shape + block_size = self.page_size + num_blocks = num_slots // block_size + k_buffer = k_buffer[:num_blocks * block_size].view(num_blocks, block_size, num_kv_heads, head_size) + v_buffer = v_buffer[:num_blocks * block_size].view(num_blocks, block_size, num_kv_heads, head_size) + + + quant_dtype = dtypes.fp8 + x = 16 // quant_dtype.itemsize + k_cache_template = torch.empty( + [num_blocks, num_kv_heads, head_size // x, block_size, x], + dtype=k_buffer.dtype, + device="meta", + ) + # V: [num_blocks, block_size, num_kv_heads, head_size] -> [num_blocks, num_kv_heads, block_size // x, head_size, x] + v_cache_template = torch.empty( + [num_blocks, num_kv_heads, block_size // x, head_size, x], + dtype=v_buffer.dtype, + device="meta", + ) + new_key_cache = k_buffer.view_as(k_cache_template) + new_value_cache = v_buffer.view_as(v_cache_template) + + total_tokens = num_blocks * block_size + k_qscale = self.k_qscale[:, :total_tokens] + v_qscale = self.v_qscale[:, :total_tokens] + + q = q.view(batch_size, layer.tp_q_head_num, layer.head_dim) + + + assert self.forward_metadata.pa_metadata_qo_indptr is not None, "pa_metadata_qo_indptr should be set by _build_pa_metadata_for_decode" + assert self.forward_metadata.pa_metadata_pages_kv_indptr is not None, "pa_metadata_pages_kv_indptr should be set by _build_pa_metadata_for_decode" + assert self.forward_metadata.pa_metadata_kv_indices is not None, "pa_metadata_kv_indices should be set by _build_pa_metadata_for_decode" + assert self.forward_metadata.pa_metadata_context_lens is not None, "pa_metadata_context_lens should be set by _build_pa_metadata_for_decode" + assert self.forward_metadata.pa_metadata_max_qlen is not None, "pa_metadata_max_qlen should be set by _build_pa_metadata_for_decode" + + qo_indptr = self.forward_metadata.pa_metadata_qo_indptr + kv_indptr = self.forward_metadata.pa_metadata_pages_kv_indptr + kv_indices 
= self.forward_metadata.pa_metadata_kv_indices + context_lens = self.forward_metadata.pa_metadata_context_lens + max_qlen = self.forward_metadata.pa_metadata_max_qlen + + + _, _ = pa_persistent_fwd( + Q=q, + K=new_key_cache, + V=new_value_cache, + output=o, + max_qlen=max_qlen, + qo_indptr=qo_indptr, + kv_indptr=kv_indptr, + kv_indices=kv_indices, + context_lens=context_lens, + work_indptr=self.pa_metadata_buffers["work_indptr"], + work_info=self.pa_metadata_buffers["work_info"], + reduce_indptr=self.pa_metadata_buffers["reduce_indptr"], + reduce_final_map=self.pa_metadata_buffers["reduce_final_map"], + reduce_partial_map=self.pa_metadata_buffers["reduce_partial_map"], + K_QScale=k_qscale, + V_QScale=v_qscale, + softmax_scale=layer.scaling, + mask=1, + ) + return o.view(-1, layer.tp_q_head_num * head_dim_out) + + + def forward_decode( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): + if self.use_mla: + raise NotImplementedError("MLA decode mode is not implemented yet in ATOMAttnBackendForSgl.") + else: + if self.decode_using_pa_ps: + return self.forward_decode_pa_ps(q, k, v, layer, forward_batch, save_kv_cache) + else: + return self.forward_decode_pa(q, k, v, layer, forward_batch, save_kv_cache) + + + + + \ No newline at end of file diff --git a/atom/plugin/register.py b/atom/plugin/register.py index 08d94b13d..b7e27c352 100644 --- a/atom/plugin/register.py +++ b/atom/plugin/register.py @@ -26,9 +26,9 @@ def _register_custom_attention_to_sglang() -> None: @register_attention_backend("aiter") def create_atom_backend(runner): - from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend + from atom.plugin.attention_backend.sgl_attn_backend import ATOMAttnBackendForSgl - return AiterAttnBackend(runner) + return ATOMAttnBackendForSgl(runner) def register_ops_to_sglang(atom_config: Config) -> None: diff --git a/atom/utils/envs.py b/atom/utils/envs.py index 
62ce11bb5..3c9100c54 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -42,6 +42,7 @@ "ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_SILU_MUL_QUANT", "1" ) == "1", + "ATOM_ROPE_FUSED_QKNORM": lambda: os.getenv("AITER_ROPE_FUSED_QKNORM", "0") == "1", } From 8491ef7cd979fa46d1107ec51e3d140c659e28ba Mon Sep 17 00:00:00 2001 From: Guanbao Yu Date: Wed, 11 Feb 2026 19:03:24 +0800 Subject: [PATCH 33/37] make format happy --- atom/config.py | 4 +- atom/model_ops/attentions/aiter_mla.py | 1 - atom/model_ops/radix_attention.py | 8 +- atom/models/qwen3_moe.py | 39 +- .../attention_backend/sgl_attn_backend.py | 397 ++++++++++++------ 5 files changed, 299 insertions(+), 150 deletions(-) diff --git a/atom/config.py b/atom/config.py index b10168428..6b08eb7b8 100644 --- a/atom/config.py +++ b/atom/config.py @@ -623,7 +623,9 @@ def __post_init__(self): if ( eos_ids := getattr(self.generation_config, "eos_token_id", None) ) is not None: - self.stop_token_ids = [eos_ids] if isinstance(eos_ids, int) else eos_ids + self.stop_token_ids = ( + [eos_ids] if isinstance(eos_ids, int) else eos_ids + ) if not hasattr(self.hf_config, "rope_parameters"): # Compatible with both transformers < 5 rope_params = getattr(self.hf_config, "rope_scaling", {}) or {} diff --git a/atom/model_ops/attentions/aiter_mla.py b/atom/model_ops/attentions/aiter_mla.py index 6520818e9..af5df28be 100644 --- a/atom/model_ops/attentions/aiter_mla.py +++ b/atom/model_ops/attentions/aiter_mla.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
-import itertools import logging from typing import Type diff --git a/atom/model_ops/radix_attention.py b/atom/model_ops/radix_attention.py index af5e94056..34fdf0f90 100644 --- a/atom/model_ops/radix_attention.py +++ b/atom/model_ops/radix_attention.py @@ -85,7 +85,13 @@ def forward_impl_plugin_mode( forward_batch = kwargs.get("forward_batch", None) assert forward_batch is not None, "forward_batch is required for sglang" # forward_batch contains the filed attn_backend, which will find the backend registered in ATOM - return self.attn(query, key, value, forward_batch=forward_batch, save_kv_cache=not self.use_aiter_rope_fused_qknorm) + return self.attn( + query, + key, + value, + forward_batch=forward_batch, + save_kv_cache=not self.use_aiter_rope_fused_qknorm, + ) else: raise NotImplementedError( "RadixAttention is only supported for plugin mode for sglang for now" diff --git a/atom/models/qwen3_moe.py b/atom/models/qwen3_moe.py index 91471c762..c629e91ef 100644 --- a/atom/models/qwen3_moe.py +++ b/atom/models/qwen3_moe.py @@ -41,6 +41,7 @@ ) ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE = envs.ATOM_ROPE_FUSED_QKNORM + class Qwen3MoeMLP(nn.Module): def __init__( self, @@ -237,24 +238,30 @@ def forward_sgl_plugin_mode( if ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE: forward_batch = model_kwargs.get("forward_batch", None) assert forward_batch is not None, "forward_batch is required for sglang" - k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(self.layer_num) + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + self.layer_num + ) block_size = 1024 # Default fallback - if hasattr(forward_batch, 'attn_backend') and hasattr(forward_batch.attn_backend, 'page_size'): + if hasattr(forward_batch, "attn_backend") and hasattr( + forward_batch.attn_backend, "page_size" + ): block_size = forward_batch.attn_backend.page_size - elif hasattr(forward_batch.token_to_kv_pool, 'allocator') and 
hasattr(forward_batch.token_to_kv_pool.allocator, 'page_size'): + elif hasattr(forward_batch.token_to_kv_pool, "allocator") and hasattr( + forward_batch.token_to_kv_pool.allocator, "page_size" + ): block_size = forward_batch.token_to_kv_pool.allocator.page_size - elif hasattr(forward_batch.token_to_kv_pool, 'page_size'): + elif hasattr(forward_batch.token_to_kv_pool, "page_size"): block_size = forward_batch.token_to_kv_pool.page_size x = 16 // k_buffer.element_size() aiter_fused_set_kv_buffer_arg = AiterFusedSetKVBufferArg( - kv_cache = (k_buffer, v_buffer), - cache_loc = forward_batch.out_cache_loc, - k_scale = self.k_scale, - v_scale = self.v_scale, - return_kv = True, - use_shuffle_layout = True, - block_size = block_size, - x = x, + kv_cache=(k_buffer, v_buffer), + cache_loc=forward_batch.out_cache_loc, + k_scale=self.k_scale, + v_scale=self.v_scale, + return_kv=True, + use_shuffle_layout=True, + block_size=block_size, + x=x, ) q, k, v = self.rotary_emb( qkv, @@ -276,9 +283,7 @@ def forward_sgl_plugin_mode( q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn( - q, k, v, positions=positions, **model_kwargs - ) + attn_output = self.attn(q, k, v, positions=positions, **model_kwargs) return attn_output def forward( @@ -298,7 +303,9 @@ def forward( ) else: if is_sglang(): - attn_output = self.forward_sgl_plugin_mode(positions, qkv, **model_kwargs) + attn_output = self.forward_sgl_plugin_mode( + positions, qkv, **model_kwargs + ) else: # Add qk-norm q = self.q_norm(q) diff --git a/atom/plugin/attention_backend/sgl_attn_backend.py b/atom/plugin/attention_backend/sgl_attn_backend.py index 662c072bf..87ed05af9 100644 --- a/atom/plugin/attention_backend/sgl_attn_backend.py +++ b/atom/plugin/attention_backend/sgl_attn_backend.py @@ -35,6 +35,7 @@ import triton import triton.language as tl + @triton.jit def reshape_and_cache_shuffle_kernel( key_ptr, # [num_tokens, num_kv_heads, head_size] @@ -71,10 +72,7 @@ def reshape_and_cache_shuffle_kernel( dst_offset + 
offset // x * block_size * x + block_offset * x + offset % x ) dst_v_shuffle_offset = ( - dst_offset - + block_offset // x * head_size * x - + offset * x - + block_offset % x + dst_offset + block_offset // x * head_size * x + offset * x + block_offset % x ) k_val = tl.load(key_ptr + src_offset_k + offset) v_val = tl.load(value_ptr + src_offset_v + offset) @@ -88,6 +86,7 @@ def reshape_and_cache_shuffle_kernel( tl.store(key_cache_ptr + dst_k_shuffle_offset, k_val) tl.store(value_cache_ptr + dst_v_shuffle_offset, v_val) + def reshape_and_cache_shuffle_triton( key: torch.Tensor, value: torch.Tensor, @@ -139,6 +138,7 @@ def reshape_and_cache_shuffle_triton( QUANT=QUANT, ) + @dataclass class ForwardMetadata: # kv_indptr and kv_indices are only used in MLA mode, optional for non-MLA mode @@ -183,7 +183,9 @@ def __init__( self.q_dtype = model_runner.dtype # Save q dtype for pa_metadata building - assert not self.use_mla, "MLA mode is not implemented yet in ATOMAttnBackendForSgl." + assert ( + not self.use_mla + ), "MLA mode is not implemented yet in ATOMAttnBackendForSgl." 
# Pre-initialized qo_indptr for pa_persistent_fwd decode mode: [0, 1, 2, ..., max_bs] # In decode mode, each sequence has 1 token, so this is always [0, 1, 2, ..., batch_size] @@ -195,7 +197,9 @@ def __init__( (max_bs,), dtype=torch.int32, device=model_runner.device ) self.page_table = torch.zeros( - (max_bs, self.max_context_len // self.page_size), dtype=torch.int32, device=model_runner.device + (max_bs, self.max_context_len // self.page_size), + dtype=torch.int32, + device=model_runner.device, ) # Pre-compute strided indices for page_table construction (used in both CUDA Graph and non-CUDA Graph modes) self.strided_indices = torch.arange( @@ -204,7 +208,9 @@ def __init__( if not self.use_mla: # Pre-allocate buffers for pa_persistent_fwd (used in both CUDA graph and non-CUDA graph modes) - max_num_blocks_per_seq = (self.max_context_len + self.page_size - 1) // self.page_size + max_num_blocks_per_seq = ( + self.max_context_len + self.page_size - 1 + ) // self.page_size max_total_blocks = max_bs * max_num_blocks_per_seq self.pa_kv_indices = torch.zeros( max_total_blocks, dtype=torch.int32, device=self.device @@ -220,13 +226,12 @@ def __init__( # Pre-allocated descale tensors for FP8 attention (q, k, v all use scale=1.0) - self.logits_soft_cap = 0.0 self.forward_metadata: ForwardMetadata = None - + self.pa_metadata_buffers = None - + k_buffer, _ = model_runner.token_to_kv_pool.get_kv_buffer(first_full_attn_id) num_slots, num_kv_heads, _ = k_buffer.shape block_size = self.page_size @@ -239,7 +244,6 @@ def __init__( num_kv_heads, max_total_tokens, dtype=torch.float32, device=self.device ) self.decode_using_pa_ps = self.page_size == 1024 - def init_forward_metadata(self, forward_batch: ForwardBatch): """Init auxiliary variables for triton attention backend.""" @@ -273,46 +277,69 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): bs = kv_indptr.shape[0] - 1 if self.use_mla: - raise NotImplementedError("MLA decode mode is not implemented yet in 
ATOMAttnBackendForSgl.") + raise NotImplementedError( + "MLA decode mode is not implemented yet in ATOMAttnBackendForSgl." + ) else: if self.decode_using_pa_ps: # Non-MLA decode mode: use same logic as CUDA Graph mode for page_table construction seq_lens_cpu = forward_batch.seq_lens_cpu if seq_lens_cpu is None: seq_lens_cpu = forward_batch.seq_lens.cpu() - + # Common setup consistent with CUDA Graph mode (init_forward_metadata_replay_cuda_graph) page_table_persistent = self.page_table seq_lens_persistent = self.seq_lens seq_lens_persistent.fill_(0) page_table_persistent.fill_(0) - seq_lens_persistent[:bs].copy_(forward_batch.seq_lens, non_blocking=True) - max_seq_pages = (seq_lens_cpu.max().item() + self.page_size - 1) // self.page_size + 1 - page_table = self.req_to_token[forward_batch.req_pool_indices[:, None], self.strided_indices[:max_seq_pages][None, :],] - page_table_persistent[:bs, :max_seq_pages].copy_(page_table // self.page_size, non_blocking=True) + seq_lens_persistent[:bs].copy_( + forward_batch.seq_lens, non_blocking=True + ) + max_seq_pages = ( + seq_lens_cpu.max().item() + self.page_size - 1 + ) // self.page_size + 1 + page_table = self.req_to_token[ + forward_batch.req_pool_indices[:, None], + self.strided_indices[:max_seq_pages][None, :], + ] + page_table_persistent[:bs, :max_seq_pages].copy_( + page_table // self.page_size, non_blocking=True + ) else: - page_table = forward_batch.req_to_token_pool.req_to_token[forward_batch.req_pool_indices, :] - + page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : + ] + self.forward_metadata = ForwardMetadata( kv_indptr, kv_indices, None, # qo_indptr not used in non-MLA mode None, # kv_last_page_len not used in non-MLA mode - 1, # max_q_len = 1 for decode mode + 1, # max_q_len = 1 for decode mode None, - page_table_persistent[:bs, :max_seq_pages] if self.decode_using_pa_ps else page_table, - seq_lens_persistent[:bs] if self.decode_using_pa_ps else 
forward_batch.seq_lens, + ( + page_table_persistent[:bs, :max_seq_pages] + if self.decode_using_pa_ps + else page_table + ), + ( + seq_lens_persistent[:bs] + if self.decode_using_pa_ps + else forward_batch.seq_lens + ), ) - + # Build pa_metadata for pa_persistent_fwd if self.decode_using_pa_ps: self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) - # return # Early return for non-MLA decode mode + # return # Early return for non-MLA decode mode else: prefix_lens = forward_batch.extend_prefix_lens if self.use_mla: - raise NotImplementedError("MLA prefill mode is not implemented yet in ATOMAttnBackendForSgl.") + raise NotImplementedError( + "MLA prefill mode is not implemented yet in ATOMAttnBackendForSgl." + ) else: self.indices_updater_prefill.update( forward_batch.req_pool_indices, @@ -323,11 +350,15 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): spec_info=None, ) # Get page_table for mha_batch_prefill_func - page_table = forward_batch.req_to_token_pool.req_to_token[forward_batch.req_pool_indices, :] + page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : + ] self.forward_metadata = ForwardMetadata( self.indices_updater_prefill.kv_indptr, self.indices_updater_prefill.kv_indices, - self.qo_indptr[: bs + 1], # qo_indptr is set by indices_updater_prefill + self.qo_indptr[ + : bs + 1 + ], # qo_indptr is set by indices_updater_prefill None, self.indices_updater_prefill.max_q_len, self.indices_updater_prefill.max_kv_len, @@ -335,22 +366,34 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): forward_batch.seq_lens, ) - if (forward_batch.forward_mode.is_extend() and - not self.use_mla and - self.forward_metadata.page_table is not None): + if ( + forward_batch.forward_mode.is_extend() + and not self.use_mla + and self.forward_metadata.page_table is not None + ): if self.page_size > 1: seq_lens_cpu = forward_batch.seq_lens_cpu if seq_lens_cpu is None: seq_lens_cpu = 
forward_batch.seq_lens.cpu() - max_seq_pages = (seq_lens_cpu.max().item() + self.page_size - 1) // self.page_size + 1 + max_seq_pages = ( + seq_lens_cpu.max().item() + self.page_size - 1 + ) // self.page_size + 1 self.forward_metadata.page_table = ( - self.forward_metadata.page_table[:, self.strided_indices[:max_seq_pages]] // self.page_size + self.forward_metadata.page_table[ + :, self.strided_indices[:max_seq_pages] + ] + // self.page_size ) if self.decode_using_pa_ps: self._build_pa_metadata_for_prefill(forward_batch.batch_size) - if not self.decode_using_pa_ps and self.page_size > 1 and self.forward_metadata.page_table is not None: + if ( + not self.decode_using_pa_ps + and self.page_size > 1 + and self.forward_metadata.page_table is not None + ): self.forward_metadata.page_table = ( - self.forward_metadata.page_table[:, self.strided_indices] // self.page_size + self.forward_metadata.page_table[:, self.strided_indices] + // self.page_size ) def _allocate_pa_metadata_buffers( @@ -371,90 +414,112 @@ def _allocate_pa_metadata_buffers( """Allocate or reuse pa_metadata buffers.""" if self.pa_metadata_buffers is None: self.pa_metadata_buffers = {} - + def _get_size_val(size): return size[0] if isinstance(size, tuple) else size - + # Allocate work_metadata_ptrs size_val = _get_size_val(work_metadata_ptrs_size) - if ("work_metadata_ptrs" not in self.pa_metadata_buffers or - self.pa_metadata_buffers["work_metadata_ptrs"].shape[0] < size_val): + if ( + "work_metadata_ptrs" not in self.pa_metadata_buffers + or self.pa_metadata_buffers["work_metadata_ptrs"].shape[0] < size_val + ): self.pa_metadata_buffers["work_metadata_ptrs"] = torch.empty( - work_metadata_ptrs_size, dtype=work_metadata_ptrs_type, device=self.device + work_metadata_ptrs_size, + dtype=work_metadata_ptrs_type, + device=self.device, ) - + # Allocate work_indptr size_val = _get_size_val(work_indptr_size) - if ("work_indptr" not in self.pa_metadata_buffers or - self.pa_metadata_buffers["work_indptr"].shape[0] 
< size_val): + if ( + "work_indptr" not in self.pa_metadata_buffers + or self.pa_metadata_buffers["work_indptr"].shape[0] < size_val + ): self.pa_metadata_buffers["work_indptr"] = torch.zeros( work_indptr_size, dtype=work_indptr_type, device=self.device ) else: self.pa_metadata_buffers["work_indptr"].zero_() - + # Allocate work_info size_val = _get_size_val(work_info_size) - if ("work_info" not in self.pa_metadata_buffers or - len(self.pa_metadata_buffers["work_info"].shape) < len(work_info_size) or - self.pa_metadata_buffers["work_info"].shape[0] < size_val): + if ( + "work_info" not in self.pa_metadata_buffers + or len(self.pa_metadata_buffers["work_info"].shape) < len(work_info_size) + or self.pa_metadata_buffers["work_info"].shape[0] < size_val + ): self.pa_metadata_buffers["work_info"] = torch.zeros( work_info_size, dtype=work_info_type, device=self.device ) else: self.pa_metadata_buffers["work_info"].zero_() - + # Allocate reduce_indptr size_val = _get_size_val(reduce_indptr_size) - if ("reduce_indptr" not in self.pa_metadata_buffers or - self.pa_metadata_buffers["reduce_indptr"].shape[0] < size_val): + if ( + "reduce_indptr" not in self.pa_metadata_buffers + or self.pa_metadata_buffers["reduce_indptr"].shape[0] < size_val + ): self.pa_metadata_buffers["reduce_indptr"] = torch.zeros( reduce_indptr_size, dtype=reduce_indptr_type, device=self.device ) else: self.pa_metadata_buffers["reduce_indptr"].zero_() - + # Allocate reduce_final_map size_val = _get_size_val(reduce_final_map_size) - if ("reduce_final_map" not in self.pa_metadata_buffers or - len(self.pa_metadata_buffers["reduce_final_map"].shape) < len(reduce_final_map_size) or - self.pa_metadata_buffers["reduce_final_map"].shape[0] < size_val): + if ( + "reduce_final_map" not in self.pa_metadata_buffers + or len(self.pa_metadata_buffers["reduce_final_map"].shape) + < len(reduce_final_map_size) + or self.pa_metadata_buffers["reduce_final_map"].shape[0] < size_val + ): 
self.pa_metadata_buffers["reduce_final_map"] = torch.zeros( reduce_final_map_size, dtype=reduce_final_map_type, device=self.device ) else: self.pa_metadata_buffers["reduce_final_map"].zero_() - + # Allocate reduce_partial_map - reduce_partial_map_size_val = reduce_partial_map_size if isinstance(reduce_partial_map_size, int) else reduce_partial_map_size[0] - if ("reduce_partial_map" not in self.pa_metadata_buffers or - self.pa_metadata_buffers["reduce_partial_map"].shape[0] < reduce_partial_map_size_val): + reduce_partial_map_size_val = ( + reduce_partial_map_size + if isinstance(reduce_partial_map_size, int) + else reduce_partial_map_size[0] + ) + if ( + "reduce_partial_map" not in self.pa_metadata_buffers + or self.pa_metadata_buffers["reduce_partial_map"].shape[0] + < reduce_partial_map_size_val + ): self.pa_metadata_buffers["reduce_partial_map"] = torch.zeros( - reduce_partial_map_size, dtype=reduce_partial_map_type, device=self.device + reduce_partial_map_size, + dtype=reduce_partial_map_type, + device=self.device, ) else: self.pa_metadata_buffers["reduce_partial_map"].zero_() def _build_pa_metadata_for_decode( - self, - batch_size: int, + self, + batch_size: int, tp_q_head_num: Optional[int] = None, ): """Build pa_metadata buffers for pa_persistent_fwd in decode mode. - + This method prepares all metadata buffers needed for pa_persistent_fwd kernel. The metadata can be reused across multiple layers in the same forward pass. - + Args: batch_size: Batch size for the current forward pass tp_q_head_num: Number of Q heads per TP rank. If None, uses self.num_head. 
""" max_qlen = 1 - + # Use provided tp_q_head_num or default to self.num_head if tp_q_head_num is None: tp_q_head_num = self.num_head - + # kv_dtype_for_metadata = dtypes.fp8 ( (work_metadata_ptrs_size, work_metadata_ptrs_type), @@ -483,22 +548,22 @@ def _build_pa_metadata_for_decode( reduce_partial_map_type, ) qo_indptr = self.pa_decode_qo_indptr[: batch_size + 1] - + # Get context_lens (kv_lens is always set before calling _build_pa_metadata_for_decode) # Note: kv_lens comes from self.seq_lens which is already int32 context_lens = self.forward_metadata.kv_lens - + kernel_block_size = self.page_size num_blocks_per_seq = (context_lens + kernel_block_size - 1) // kernel_block_size # Use dedicated pa_kv_indptr buffer (similar to self.kv_indptr, but for pa_persistent_fwd) pages_kv_indptr = self.pa_kv_indptr[: batch_size + 1] pages_kv_indptr[1 : batch_size + 1] = torch.cumsum(num_blocks_per_seq, dim=0) - + # Convert page_table to kv_indices (block indices) using Triton kernel to avoid sync # page_table shape: [batch_size, max_num_blocks_per_seq] # Note: page_table comes from self.page_table which is already int32 and always set before this call page_table = self.forward_metadata.page_table - + # Use Triton kernel to gather kv_indices from page_table (avoids high-level indexing sync) create_flashinfer_kv_indices_triton[(batch_size,)]( page_table, @@ -554,7 +619,7 @@ def _build_pa_metadata_for_prefill(self, batch_size: int): # Page-level kv_indptr (reuse pa_kv_indptr buffer) pages_kv_indptr = self.pa_kv_indptr[: batch_size + 1] pages_kv_indptr[1 : batch_size + 1] = torch.cumsum(num_blocks_per_seq, dim=0) - + # Build kv_indices from page_table using triton kernel page_table = self.forward_metadata.page_table create_flashinfer_kv_indices_triton[(batch_size,)]( @@ -594,15 +659,15 @@ def init_cuda_graph_state( # Always use preshuffle layout for pa_fwd_asm self.page_table = torch.zeros( - (max_bs, self.max_context_len // self.page_size), dtype=torch.int32, device=self.device 
- ) - self.seq_lens = torch.zeros( - (max_bs,), dtype=torch.int32, device=self.device + (max_bs, self.max_context_len // self.page_size), + dtype=torch.int32, + device=self.device, ) + self.seq_lens = torch.zeros((max_bs,), dtype=torch.int32, device=self.device) self.strided_indices = torch.arange( 0, self.max_context_len, self.page_size, device=self.device ) - + # Pre-allocate buffers for pa_metadata in CUDA graph mode (non-MLA decode) if self.decode_using_pa_ps and not self.use_mla: # Pre-allocate pa_metadata buffers for CUDA graph compatibility @@ -621,7 +686,7 @@ def init_cuda_graph_state( max_bs, self.num_kv_head, ) - + # Pre-allocate buffers with maximum size for CUDA graph compatibility self._allocate_pa_metadata_buffers( work_metadata_ptrs_size, @@ -651,7 +716,9 @@ def init_forward_metadata_capture_cuda_graph( if forward_mode.is_decode_or_idle(): if self.use_mla: # MLA mode: kv_indptr and kv_indices are used in forward_decode - raise NotImplementedError("MLA decode mode is not implemented yet in ATOMAttnBackendForSgl.") + raise NotImplementedError( + "MLA decode mode is not implemented yet in ATOMAttnBackendForSgl." 
+ ) else: # Non-MLA decode mode: kv_indptr and kv_indices are NOT used in forward_decode # (forward_decode uses pa_metadata_pages_kv_indptr and pa_metadata_kv_indices instead) @@ -668,7 +735,7 @@ def init_forward_metadata_capture_cuda_graph( page_table, seq_lens_persistent, ) - + # Build pa_metadata using CUDA graph buffers if self.decode_using_pa_ps: self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) @@ -695,27 +762,36 @@ def init_forward_metadata_replay_cuda_graph( seq_lens_persistent.fill_(0) page_table_persistent.fill_(0) seq_lens_persistent[:bs].copy_(seq_lens, non_blocking=True) - max_seq_pages = (seq_lens_cpu.max().item() + self.page_size - 1) // self.page_size + 1 - page_table = self.req_to_token[req_pool_indices[:, None], self.strided_indices[:max_seq_pages][None, :],] - page_table_persistent[:bs, :max_seq_pages].copy_(page_table // self.page_size, non_blocking=True) - + max_seq_pages = ( + seq_lens_cpu.max().item() + self.page_size - 1 + ) // self.page_size + 1 + page_table = self.req_to_token[ + req_pool_indices[:, None], + self.strided_indices[:max_seq_pages][None, :], + ] + page_table_persistent[:bs, :max_seq_pages].copy_( + page_table // self.page_size, non_blocking=True + ) + if self.use_mla: # MLA mode: kv_indptr and kv_indices are used in forward_decode - raise NotImplementedError("MLA decode mode is not implemented yet in ATOMAttnBackendForSgl.") + raise NotImplementedError( + "MLA decode mode is not implemented yet in ATOMAttnBackendForSgl." 
+ ) else: # Non-MLA decode mode: kv_indptr and kv_indices are NOT used in forward_decode # (forward_decode uses pa_metadata_pages_kv_indptr and pa_metadata_kv_indices instead) self.forward_metadata = ForwardMetadata( None, # kv_indptr not used in non-MLA decode mode None, # kv_indices not used in non-MLA decode mode - None, + None, None, # kv_last_page_len not used in non-MLA mode 1, # max_q_len = 1 for decode mode, non-MTP None, # max_kv_len page_table_persistent[:bs, :max_seq_pages], seq_lens_persistent[:bs], ) - + # Rebuild pa_metadata using CUDA graph buffers (updates content, keeps same addresses) if self.decode_using_pa_ps: self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) @@ -736,8 +812,12 @@ def set_kv_buffer_with_layout_shuffle( num_slots, num_kv_heads, head_dim = k_buffer.shape num_blocks = num_slots // block_size num_slots_with_block = num_blocks * block_size - k_buffer = k_buffer[:num_slots_with_block].view(num_blocks, block_size, num_kv_heads, head_dim) - v_buffer = v_buffer[:num_slots_with_block].view(num_blocks, block_size, num_kv_heads, head_dim) + k_buffer = k_buffer[:num_slots_with_block].view( + num_blocks, block_size, num_kv_heads, head_dim + ) + v_buffer = v_buffer[:num_slots_with_block].view( + num_blocks, block_size, num_kv_heads, head_dim + ) reshape_and_cache_shuffle_triton( k, v, @@ -766,7 +846,16 @@ def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True): k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( layer.layer_id ) - self.set_kv_buffer_with_layout_shuffle(cache_loc, k, v, k_buffer, v_buffer, layer.k_scale, layer.v_scale, self.page_size) + self.set_kv_buffer_with_layout_shuffle( + cache_loc, + k, + v, + k_buffer, + v_buffer, + layer.k_scale, + layer.v_scale, + self.page_size, + ) # forward_batch.token_to_kv_pool.set_kv_buffer( # layer, cache_loc, k, v, layer.k_scale, layer.v_scale # ) @@ -795,7 +884,6 @@ def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True): 
) return o.view(-1, layer.tp_q_head_num * layer.head_dim) - def forward_decode_pa( self, @@ -817,18 +905,35 @@ def forward_decode_pa( k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( layer.layer_id ) - self.set_kv_buffer_with_layout_shuffle(forward_batch.out_cache_loc, k, v, k_buffer, v_buffer, layer.k_scale, layer.v_scale, self.page_size) + self.set_kv_buffer_with_layout_shuffle( + forward_batch.out_cache_loc, + k, + v, + k_buffer, + v_buffer, + layer.k_scale, + layer.v_scale, + self.page_size, + ) if self.use_mla: - raise NotImplementedError("MLA decode mode is not implemented yet in ATOMAttnBackendForSgl.") + raise NotImplementedError( + "MLA decode mode is not implemented yet in ATOMAttnBackendForSgl." + ) else: - k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) block_size = self.page_size num_slots, num_kv_heads, head_size = k_buffer.shape num_blocks = num_slots // block_size - k_buffer = k_buffer[:num_blocks * block_size].view(num_blocks, block_size, num_kv_heads, head_size) - v_buffer = v_buffer[:num_blocks * block_size].view(num_blocks, block_size, num_kv_heads, head_size) + k_buffer = k_buffer[: num_blocks * block_size].view( + num_blocks, block_size, num_kv_heads, head_size + ) + v_buffer = v_buffer[: num_blocks * block_size].view( + num_blocks, block_size, num_kv_heads, head_size + ) x = 16 // k_buffer.element_size() k_cache_template = torch.empty( @@ -872,32 +977,54 @@ def forward_decode_pa_ps( # In decode mode, q.shape[0] equals batch_size (each sequence has 1 token) # Use q.shape[0] instead of forward_batch.batch_size to be safe batch_size = q.shape[0] - head_dim_out = layer.v_head_dim if layer.qk_head_dim != layer.v_head_dim else layer.head_dim + head_dim_out = ( + layer.v_head_dim + if layer.qk_head_dim != layer.v_head_dim + else layer.head_dim + ) o = q.new_empty((batch_size, layer.tp_q_head_num, 
head_dim_out)) if save_kv_cache: if self.use_mla: - raise NotImplementedError("MLA decode mode is not implemented yet in ATOMAttnBackendForSgl.") + raise NotImplementedError( + "MLA decode mode is not implemented yet in ATOMAttnBackendForSgl." + ) else: k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( layer.layer_id ) - self.set_kv_buffer_with_layout_shuffle(forward_batch.out_cache_loc, k, v, k_buffer, v_buffer, layer.k_scale, layer.v_scale, self.page_size) + self.set_kv_buffer_with_layout_shuffle( + forward_batch.out_cache_loc, + k, + v, + k_buffer, + v_buffer, + layer.k_scale, + layer.v_scale, + self.page_size, + ) # Shuffle operation is already fused in rotary_emb, so just save directly # forward_batch.token_to_kv_pool.set_kv_buffer( # layer, forward_batch.out_cache_loc, k, v, layer.k_scale, layer.v_scale # ) if self.use_mla: - raise NotImplementedError("MLA decode mode is not implemented yet in ATOMAttnBackendForSgl.") + raise NotImplementedError( + "MLA decode mode is not implemented yet in ATOMAttnBackendForSgl." 
+ ) else: - k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) num_slots, num_kv_heads, head_size = k_buffer.shape block_size = self.page_size num_blocks = num_slots // block_size - k_buffer = k_buffer[:num_blocks * block_size].view(num_blocks, block_size, num_kv_heads, head_size) - v_buffer = v_buffer[:num_blocks * block_size].view(num_blocks, block_size, num_kv_heads, head_size) - + k_buffer = k_buffer[: num_blocks * block_size].view( + num_blocks, block_size, num_kv_heads, head_size + ) + v_buffer = v_buffer[: num_blocks * block_size].view( + num_blocks, block_size, num_kv_heads, head_size + ) quant_dtype = dtypes.fp8 x = 16 // quant_dtype.itemsize @@ -914,27 +1041,35 @@ def forward_decode_pa_ps( ) new_key_cache = k_buffer.view_as(k_cache_template) new_value_cache = v_buffer.view_as(v_cache_template) - + total_tokens = num_blocks * block_size k_qscale = self.k_qscale[:, :total_tokens] v_qscale = self.v_qscale[:, :total_tokens] - + q = q.view(batch_size, layer.tp_q_head_num, layer.head_dim) - - - assert self.forward_metadata.pa_metadata_qo_indptr is not None, "pa_metadata_qo_indptr should be set by _build_pa_metadata_for_decode" - assert self.forward_metadata.pa_metadata_pages_kv_indptr is not None, "pa_metadata_pages_kv_indptr should be set by _build_pa_metadata_for_decode" - assert self.forward_metadata.pa_metadata_kv_indices is not None, "pa_metadata_kv_indices should be set by _build_pa_metadata_for_decode" - assert self.forward_metadata.pa_metadata_context_lens is not None, "pa_metadata_context_lens should be set by _build_pa_metadata_for_decode" - assert self.forward_metadata.pa_metadata_max_qlen is not None, "pa_metadata_max_qlen should be set by _build_pa_metadata_for_decode" - + + assert ( + self.forward_metadata.pa_metadata_qo_indptr is not None + ), "pa_metadata_qo_indptr should be set by _build_pa_metadata_for_decode" + assert ( + 
self.forward_metadata.pa_metadata_pages_kv_indptr is not None + ), "pa_metadata_pages_kv_indptr should be set by _build_pa_metadata_for_decode" + assert ( + self.forward_metadata.pa_metadata_kv_indices is not None + ), "pa_metadata_kv_indices should be set by _build_pa_metadata_for_decode" + assert ( + self.forward_metadata.pa_metadata_context_lens is not None + ), "pa_metadata_context_lens should be set by _build_pa_metadata_for_decode" + assert ( + self.forward_metadata.pa_metadata_max_qlen is not None + ), "pa_metadata_max_qlen should be set by _build_pa_metadata_for_decode" + qo_indptr = self.forward_metadata.pa_metadata_qo_indptr kv_indptr = self.forward_metadata.pa_metadata_pages_kv_indptr kv_indices = self.forward_metadata.pa_metadata_kv_indices context_lens = self.forward_metadata.pa_metadata_context_lens max_qlen = self.forward_metadata.pa_metadata_max_qlen - - + _, _ = pa_persistent_fwd( Q=q, K=new_key_cache, @@ -953,29 +1088,29 @@ def forward_decode_pa_ps( K_QScale=k_qscale, V_QScale=v_qscale, softmax_scale=layer.scaling, - mask=1, + mask=1, ) return o.view(-1, layer.tp_q_head_num * head_dim_out) - def forward_decode( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - layer: RadixAttention, - forward_batch: ForwardBatch, - save_kv_cache=True, - ): + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): if self.use_mla: - raise NotImplementedError("MLA decode mode is not implemented yet in ATOMAttnBackendForSgl.") + raise NotImplementedError( + "MLA decode mode is not implemented yet in ATOMAttnBackendForSgl." 
+ ) else: if self.decode_using_pa_ps: - return self.forward_decode_pa_ps(q, k, v, layer, forward_batch, save_kv_cache) + return self.forward_decode_pa_ps( + q, k, v, layer, forward_batch, save_kv_cache + ) else: - return self.forward_decode_pa(q, k, v, layer, forward_batch, save_kv_cache) - - - - - \ No newline at end of file + return self.forward_decode_pa( + q, k, v, layer, forward_batch, save_kv_cache + ) From 9eb1c19d35fe22d6e5f9c675f31fa129d8596ca5 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Wed, 25 Feb 2026 10:31:00 +0800 Subject: [PATCH 34/37] add Signed-off-by: zejunchen-zejun --- atom/config.py | 27 ++++++++------------------ atom/model_ops/attentions/aiter_mla.py | 1 + atom/model_ops/radix_attention.py | 6 +++--- atom/models/qwen3_moe.py | 4 ++-- atom/utils/envs.py | 2 +- 5 files changed, 15 insertions(+), 25 deletions(-) diff --git a/atom/config.py b/atom/config.py index 6b08eb7b8..cc1bb9696 100644 --- a/atom/config.py +++ b/atom/config.py @@ -609,29 +609,18 @@ def __post_init__(self): self.kv_cache_block_size % 16 == 0 or self.kv_cache_block_size == 1 ), f"kv_cache_block_size ({self.kv_cache_block_size}) must be a multiple of 16 or 1" assert 1 <= self.tensor_parallel_size <= 8 - if is_plugin_mode(): - # plugin mode - assert ( - self.plugin_config is not None - ), "plugin_config is required in plugin mode" - self.hf_config = self.plugin_config.model_config.hf_config - else: - self.hf_config = get_hf_config(self.model) - - self.generation_config = get_generation_config(self.model) - if self.generation_config is not None: - if ( - eos_ids := getattr(self.generation_config, "eos_token_id", None) - ) is not None: - self.stop_token_ids = ( - [eos_ids] if isinstance(eos_ids, int) else eos_ids - ) + self.hf_config = get_hf_config(self.model) if not hasattr(self.hf_config, "rope_parameters"): # Compatible with both transformers < 5 - rope_params = getattr(self.hf_config, "rope_scaling", {}) or {} + rope_params = getattr(self.hf_config, "rope_scaling", 
{}) rope_params["rope_theta"] = self.hf_config.rope_theta - rope_params["rope_type"] = getattr(rope_params, "rope_type", "default") self.hf_config.rope_parameters = rope_params + self.generation_config = get_generation_config(self.model) + if self.generation_config is not None: + if ( + eos_ids := getattr(self.generation_config, "eos_token_id", None) + ) is not None: + self.stop_token_ids = [eos_ids] if isinstance(eos_ids, int) else eos_ids self.quant_config = get_quant_config(self.hf_config) hf_config_max_position_embeddings = getattr( self.hf_config, "max_position_embeddings", 8192 diff --git a/atom/model_ops/attentions/aiter_mla.py b/atom/model_ops/attentions/aiter_mla.py index af5df28be..6520818e9 100644 --- a/atom/model_ops/attentions/aiter_mla.py +++ b/atom/model_ops/attentions/aiter_mla.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +import itertools import logging from typing import Type diff --git a/atom/model_ops/radix_attention.py b/atom/model_ops/radix_attention.py index 34fdf0f90..b25e1aaba 100644 --- a/atom/model_ops/radix_attention.py +++ b/atom/model_ops/radix_attention.py @@ -65,7 +65,7 @@ def __init__( "RadixAttention is only supported for plugin mode for sglang for now" ) # if True, save cache will be done in rope - self.use_aiter_rope_fused_qknorm = envs.ATOM_ROPE_FUSED_QKNORM + self.use_rope_fused_qknorm = envs.ATOM_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE def forward_impl_plugin_mode( self, @@ -84,13 +84,13 @@ def forward_impl_plugin_mode( # for sglang, forward_batch is required forward_batch = kwargs.get("forward_batch", None) assert forward_batch is not None, "forward_batch is required for sglang" - # forward_batch contains the filed attn_backend, which will find the backend registered in ATOM + save_kv_cache = not self.use_rope_fused_qknorm return self.attn( query, key, value, forward_batch=forward_batch, - save_kv_cache=not self.use_aiter_rope_fused_qknorm, + 
save_kv_cache=save_kv_cache, ) else: raise NotImplementedError( diff --git a/atom/models/qwen3_moe.py b/atom/models/qwen3_moe.py index c629e91ef..51927fe2c 100644 --- a/atom/models/qwen3_moe.py +++ b/atom/models/qwen3_moe.py @@ -39,7 +39,7 @@ ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION = ( envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION ) -ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE = envs.ATOM_ROPE_FUSED_QKNORM +ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE = envs.ATOM_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE class Qwen3MoeMLP(nn.Module): @@ -327,7 +327,7 @@ def __init__(self, atom_config=None, layer_num: int = 0, prefix: str = "") -> No self.hidden_size = config.hidden_size rope_params = config.rope_parameters rope_theta = rope_params["rope_theta"] - rope_scaling = None if rope_params["rope_type"] == "default" else rope_params + rope_scaling = rope_params kv_cache_dtype = atom_config.kv_cache_dtype max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # DecoderLayers are created with `make_layers` which passes the prefix diff --git a/atom/utils/envs.py b/atom/utils/envs.py index 3c9100c54..0f6b77760 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -42,7 +42,7 @@ "ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_SILU_MUL_QUANT", "1" ) == "1", - "ATOM_ROPE_FUSED_QKNORM": lambda: os.getenv("AITER_ROPE_FUSED_QKNORM", "0") == "1", + "ATOM_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE": lambda: os.getenv("ATOM_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE", "0") == "1", } From 6d14b84682f58caf4a1426c8f4edf56a6b36fc34 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Wed, 25 Feb 2026 10:43:08 +0800 Subject: [PATCH 35/37] add Signed-off-by: zejunchen-zejun --- atom/config.py | 1 + .../attention_backend/sgl_attn_backend.py | 20 +++++++++---------- atom/plugin/register.py | 8 ++++---- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/atom/config.py b/atom/config.py index cc1bb9696..9ecbdb365 100644 --- a/atom/config.py +++ b/atom/config.py @@ 
-615,6 +615,7 @@ def __post_init__(self): rope_params = getattr(self.hf_config, "rope_scaling", {}) rope_params["rope_theta"] = self.hf_config.rope_theta self.hf_config.rope_parameters = rope_params + self.generation_config = get_generation_config(self.model) if self.generation_config is not None: if ( diff --git a/atom/plugin/attention_backend/sgl_attn_backend.py b/atom/plugin/attention_backend/sgl_attn_backend.py index 87ed05af9..2b17b5acf 100644 --- a/atom/plugin/attention_backend/sgl_attn_backend.py +++ b/atom/plugin/attention_backend/sgl_attn_backend.py @@ -163,7 +163,7 @@ class ForwardMetadata: # prefill_kv_last_page_lens: Optional[torch.Tensor] = None -class ATOMAttnBackendForSgl(AiterAttnBackend): +class ATOMAttnBackendForSGLPluginMode(AiterAttnBackend): def __init__( self, model_runner: ModelRunner, @@ -185,7 +185,7 @@ def __init__( assert ( not self.use_mla - ), "MLA mode is not implemented yet in ATOMAttnBackendForSgl." + ), "MLA mode is not implemented yet in ATOMAttnBackendForSGLPluginMode." # Pre-initialized qo_indptr for pa_persistent_fwd decode mode: [0, 1, 2, ..., max_bs] # In decode mode, each sequence has 1 token, so this is always [0, 1, 2, ..., batch_size] @@ -278,7 +278,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): if self.use_mla: raise NotImplementedError( - "MLA decode mode is not implemented yet in ATOMAttnBackendForSgl." + "MLA decode mode is not implemented yet in ATOMAttnBackendForSGLPluginMode." ) else: if self.decode_using_pa_ps: @@ -338,7 +338,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): if self.use_mla: raise NotImplementedError( - "MLA prefill mode is not implemented yet in ATOMAttnBackendForSgl." + "MLA prefill mode is not implemented yet in ATOMAttnBackendForSGLPluginMode." 
) else: self.indices_updater_prefill.update( @@ -717,7 +717,7 @@ def init_forward_metadata_capture_cuda_graph( if self.use_mla: # MLA mode: kv_indptr and kv_indices are used in forward_decode raise NotImplementedError( - "MLA decode mode is not implemented yet in ATOMAttnBackendForSgl." + "MLA decode mode is not implemented yet in ATOMAttnBackendForSGLPluginMode." ) else: # Non-MLA decode mode: kv_indptr and kv_indices are NOT used in forward_decode @@ -776,7 +776,7 @@ def init_forward_metadata_replay_cuda_graph( if self.use_mla: # MLA mode: kv_indptr and kv_indices are used in forward_decode raise NotImplementedError( - "MLA decode mode is not implemented yet in ATOMAttnBackendForSgl." + "MLA decode mode is not implemented yet in ATOMAttnBackendForSGLPluginMode." ) else: # Non-MLA decode mode: kv_indptr and kv_indices are NOT used in forward_decode @@ -918,7 +918,7 @@ def forward_decode_pa( if self.use_mla: raise NotImplementedError( - "MLA decode mode is not implemented yet in ATOMAttnBackendForSgl." + "MLA decode mode is not implemented yet in ATOMAttnBackendForSGLPluginMode." ) else: k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( @@ -987,7 +987,7 @@ def forward_decode_pa_ps( if save_kv_cache: if self.use_mla: raise NotImplementedError( - "MLA decode mode is not implemented yet in ATOMAttnBackendForSgl." + "MLA decode mode is not implemented yet in ATOMAttnBackendForSGLPluginMode." ) else: k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( @@ -1010,7 +1010,7 @@ def forward_decode_pa_ps( if self.use_mla: raise NotImplementedError( - "MLA decode mode is not implemented yet in ATOMAttnBackendForSgl." + "MLA decode mode is not implemented yet in ATOMAttnBackendForSGLPluginMode." ) else: k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( @@ -1103,7 +1103,7 @@ def forward_decode( ): if self.use_mla: raise NotImplementedError( - "MLA decode mode is not implemented yet in ATOMAttnBackendForSgl." 
+ "MLA decode mode is not implemented yet in ATOMAttnBackendForSGLPluginMode." ) else: if self.decode_using_pa_ps: diff --git a/atom/plugin/register.py b/atom/plugin/register.py index b7e27c352..70c26affb 100644 --- a/atom/plugin/register.py +++ b/atom/plugin/register.py @@ -19,16 +19,16 @@ def _register_custom_attention_to_sglang() -> None: register_attention_backend, ) + logger.info("Register custom attention backend to SGLang") + # here register the custom attention backend with the name "aiter" # as sglang defines the fixed attention backend choices, which must be # in-tree - logger.info("Register custom attention backend AiterBackend to SGLang") - @register_attention_backend("aiter") def create_atom_backend(runner): - from atom.plugin.attention_backend.sgl_attn_backend import ATOMAttnBackendForSgl + from atom.plugin.attention_backend.sgl_attn_backend import ATOMAttnBackendForSGLPluginMode - return ATOMAttnBackendForSgl(runner) + return ATOMAttnBackendForSGLPluginMode(runner) def register_ops_to_sglang(atom_config: Config) -> None: From 2c0a44a406d3df2220039f85062765405902be26 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Wed, 25 Feb 2026 16:07:20 +0800 Subject: [PATCH 36/37] add Signed-off-by: zejunchen-zejun --- atom/model_ops/radix_attention.py | 84 +++++++++++++++++++++------ atom/models/qwen3_moe.py | 95 ++++++------------------------- atom/utils/envs.py | 1 - 3 files changed, 82 insertions(+), 98 deletions(-) diff --git a/atom/model_ops/radix_attention.py b/atom/model_ops/radix_attention.py index b25e1aaba..6ebe7a562 100644 --- a/atom/model_ops/radix_attention.py +++ b/atom/model_ops/radix_attention.py @@ -11,6 +11,7 @@ from atom.models.utils import maybe_prefix from atom.utils import envs +from aiter.rotary_embedding import AiterFusedSetKVBufferArg class RadixAttention(BaseAttention): """ @@ -50,22 +51,28 @@ def __init__( ) if is_sglang(): - from sglang.srt.layers.radix_attention import RadixAttention + self.rotary_emb = rotary_emb + 
self.layer_num = layer_num + + self.k_scale = torch.tensor([1.0], dtype=torch.float32) + self.v_scale = torch.tensor([1.0], dtype=torch.float32) + # if True, save cache will be done in rope + self.use_rope_fused_qknorm = envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION + + from sglang.srt.layers.radix_attention import RadixAttention self.attn = RadixAttention( num_heads=num_heads, head_dim=head_dim, scaling=scale, num_kv_heads=num_kv_heads, - layer_id=layer_num, + layer_id=self.layer_num, prefix=maybe_prefix(prefix, "attn"), ) else: raise NotImplementedError( "RadixAttention is only supported for plugin mode for sglang for now" ) - # if True, save cache will be done in rope - self.use_rope_fused_qknorm = envs.ATOM_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE def forward_impl_plugin_mode( self, @@ -78,24 +85,63 @@ def forward_impl_plugin_mode( output_block_scale: torch.Tensor | None = None, positions: torch.Tensor = None, q_scale: torch.Tensor = None, + qkv: torch.Tensor = None, **kwargs, ): - if is_sglang(): - # for sglang, forward_batch is required - forward_batch = kwargs.get("forward_batch", None) - assert forward_batch is not None, "forward_batch is required for sglang" - save_kv_cache = not self.use_rope_fused_qknorm - return self.attn( - query, - key, - value, - forward_batch=forward_batch, - save_kv_cache=save_kv_cache, + # for sglang, forward_batch is required + forward_batch = kwargs.get("forward_batch", None) + assert forward_batch is not None, "forward_batch is required for sglang" + + if self.use_rope_fused_qknorm: + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + self.layer_num ) - else: - raise NotImplementedError( - "RadixAttention is only supported for plugin mode for sglang for now" + block_size = 1024 # Default fallback + if hasattr(forward_batch, "attn_backend") and hasattr( + forward_batch.attn_backend, "page_size" + ): + block_size = forward_batch.attn_backend.page_size + elif hasattr(forward_batch.token_to_kv_pool, "allocator") 
and hasattr( + forward_batch.token_to_kv_pool.allocator, "page_size" + ): + block_size = forward_batch.token_to_kv_pool.allocator.page_size + elif hasattr(forward_batch.token_to_kv_pool, "page_size"): + block_size = forward_batch.token_to_kv_pool.page_size + x = 16 // k_buffer.element_size() + aiter_fused_set_kv_buffer_arg = AiterFusedSetKVBufferArg( + kv_cache=(k_buffer, v_buffer), + cache_loc=forward_batch.out_cache_loc, + k_scale=self.k_scale, + v_scale=self.v_scale, + return_kv=True, + use_shuffle_layout=True, + block_size=block_size, + x=x, + ) + q, k, v = self.rotary_emb( + qkv, + self.q_norm.weight, + self.k_norm.weight, + positions, + self.num_heads, + self.num_kv_heads, + self.q_norm.eps, + fused_set_kv_buffer_arg=aiter_fused_set_kv_buffer_arg, ) + else: + # calculate the q and k with rotary embedding + assert self.rotary_emb is not None, "rotary_emb is required" + q, k = self.rotary_emb(positions, q, k) + v = value + + save_kv_cache = not self.use_rope_fused_qknorm + return self.attn( + q, + k, + v, + forward_batch=forward_batch, + save_kv_cache=save_kv_cache, + ) def forward( self, @@ -104,6 +150,7 @@ def forward( value: torch.Tensor, positions: torch.Tensor = None, q_scale: Optional[torch.Tensor] = None, + qkv: torch.Tensor = None, **kwargs, ): if is_plugin_mode(): @@ -113,6 +160,7 @@ def forward( value=value, positions=positions, q_scale=q_scale, + qkv=qkv, **kwargs, ) else: diff --git a/atom/models/qwen3_moe.py b/atom/models/qwen3_moe.py index 51927fe2c..a5fa065a9 100644 --- a/atom/models/qwen3_moe.py +++ b/atom/models/qwen3_moe.py @@ -5,7 +5,7 @@ from aiter.dist.parallel_state import get_pp_group, get_tensor_model_parallel_world_size # from atom.model_ops.rotary_embedding import get_rope -from aiter.rotary_embedding import get_rope, AiterFusedSetKVBufferArg +from aiter.rotary_embedding import get_rope from atom.config import Config, QuantizationConfig from atom.model_ops.activation import SiluAndMul @@ -30,7 +30,6 @@ from atom.utils.decorators import 
support_torch_compile from torch import nn from atom.model_loader.loader import load_model_in_plugin_mode -from atom.plugin.prepare import is_sglang # import torch.distributed as dist from transformers import PretrainedConfig @@ -39,7 +38,6 @@ ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION = ( envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION ) -ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE = envs.ATOM_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE class Qwen3MoeMLP(nn.Module): @@ -226,65 +224,6 @@ def __init__( self.kv_cache_dtype = kv_cache_dtype self.layer_num = layer_num - self.k_scale = torch.tensor([1.0], dtype=torch.float32) - self.v_scale = torch.tensor([1.0], dtype=torch.float32) - - def forward_sgl_plugin_mode( - self, - positions: torch.Tensor, - qkv: torch.Tensor, - **model_kwargs: dict[str, Any] | None, - ): - if ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE: - forward_batch = model_kwargs.get("forward_batch", None) - assert forward_batch is not None, "forward_batch is required for sglang" - k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( - self.layer_num - ) - block_size = 1024 # Default fallback - if hasattr(forward_batch, "attn_backend") and hasattr( - forward_batch.attn_backend, "page_size" - ): - block_size = forward_batch.attn_backend.page_size - elif hasattr(forward_batch.token_to_kv_pool, "allocator") and hasattr( - forward_batch.token_to_kv_pool.allocator, "page_size" - ): - block_size = forward_batch.token_to_kv_pool.allocator.page_size - elif hasattr(forward_batch.token_to_kv_pool, "page_size"): - block_size = forward_batch.token_to_kv_pool.page_size - x = 16 // k_buffer.element_size() - aiter_fused_set_kv_buffer_arg = AiterFusedSetKVBufferArg( - kv_cache=(k_buffer, v_buffer), - cache_loc=forward_batch.out_cache_loc, - k_scale=self.k_scale, - v_scale=self.v_scale, - return_kv=True, - use_shuffle_layout=True, - block_size=block_size, - x=x, - ) - q, k, v = self.rotary_emb( - qkv, - self.q_norm.weight, - self.k_norm.weight, - 
positions, - self.num_heads, - self.num_kv_heads, - self.q_norm.eps, - fused_set_kv_buffer_arg=aiter_fused_set_kv_buffer_arg, - ) - else: - q, k, v = torch.split( - qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1 - ) - # Add qk-norm - q = self.q_norm(q) - k = self.k_norm(k) - - q, k = self.rotary_emb(positions, q, k) - - attn_output = self.attn(q, k, v, positions=positions, **model_kwargs) - return attn_output def forward( self, @@ -295,25 +234,23 @@ def forward( qkv = self.qkv_proj(hidden_states) q, k, v = torch.split(qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1) if ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION: - q, k, v = torch.split( - qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1 - ) - attn_output = self.attn( - query=q, key=k, value=v, positions=positions, q_scale=None, qkv=qkv - ) + attn_output = self.attn(query=q, + key=k, + value=v, + positions=positions, + q_scale=None, + qkv=qkv, + **model_kwargs) else: - if is_sglang(): - attn_output = self.forward_sgl_plugin_mode( - positions, qkv, **model_kwargs - ) - else: - # Add qk-norm - q = self.q_norm(q) - k = self.k_norm(k) + # Add qk-norm + q = self.q_norm(q) + k = self.k_norm(k) - attn_output = self.attn( - query=q, key=k, value=v, positions=positions, **model_kwargs - ) + attn_output = self.attn(query=q, + key=k, + value=v, + positions=positions, + **model_kwargs) output = self.o_proj(attn_output) return output diff --git a/atom/utils/envs.py b/atom/utils/envs.py index 0f6b77760..62ce11bb5 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -42,7 +42,6 @@ "ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_SILU_MUL_QUANT", "1" ) == "1", - "ATOM_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE": lambda: os.getenv("ATOM_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE", "0") == "1", } From 949868491462c4f357d4ba17583d2288e95f3046 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Wed, 25 Feb 2026 21:28:19 +0800 Subject: [PATCH 37/37] add Signed-off-by: zejunchen-zejun --- atom/plugin/register.py | 2 +- 
atom/plugin/{attention_backend => sglang}/__init__.py | 0 atom/plugin/{attention_backend => sglang}/sgl_attn_backend.py | 0 3 files changed, 1 insertion(+), 1 deletion(-) rename atom/plugin/{attention_backend => sglang}/__init__.py (100%) rename atom/plugin/{attention_backend => sglang}/sgl_attn_backend.py (100%) diff --git a/atom/plugin/register.py b/atom/plugin/register.py index 70c26affb..cd4f6f9f0 100644 --- a/atom/plugin/register.py +++ b/atom/plugin/register.py @@ -26,7 +26,7 @@ def _register_custom_attention_to_sglang() -> None: # in-tree @register_attention_backend("aiter") def create_atom_backend(runner): - from atom.plugin.attention_backend.sgl_attn_backend import ATOMAttnBackendForSGLPluginMode + from atom.plugin.sglang.sgl_attn_backend import ATOMAttnBackendForSGLPluginMode return ATOMAttnBackendForSGLPluginMode(runner) diff --git a/atom/plugin/attention_backend/__init__.py b/atom/plugin/sglang/__init__.py similarity index 100% rename from atom/plugin/attention_backend/__init__.py rename to atom/plugin/sglang/__init__.py diff --git a/atom/plugin/attention_backend/sgl_attn_backend.py b/atom/plugin/sglang/sgl_attn_backend.py similarity index 100% rename from atom/plugin/attention_backend/sgl_attn_backend.py rename to atom/plugin/sglang/sgl_attn_backend.py