From 73644881ea3686030b4136077b0cef35be46c48c Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Thu, 15 Jan 2026 15:31:07 +0800 Subject: [PATCH] [feat][plugin] Make ATOM work as plugin for upper framework Signed-off-by: zejunchen-zejun --- atom/__init__.py | 9 + atom/config.py | 37 +- atom/model_engine/model_runner.py | 8 +- atom/model_loader/loader.py | 60 +- atom/model_ops/__init__.py | 14 + atom/model_ops/attention_mha.py | 61 +- atom/model_ops/attention_mla.py | 4 +- atom/model_ops/attentions/aiter_attention.py | 40 +- atom/model_ops/attentions/gdn_attn.py | 2 +- atom/model_ops/base_attention.py | 252 +++++-- atom/model_ops/embed_head.py | 16 +- atom/model_ops/moe.py | 2 + atom/model_ops/paged_attention.py | 191 +++++ atom/model_ops/radix_attention.py | 116 +++ atom/models/qwen3.py | 125 +++- atom/models/qwen3_moe.py | 114 +-- atom/plugin/__init__.py | 13 + atom/plugin/attention.py | 653 ++++++++++++++++ atom/plugin/attention_mha.py | 739 +++++++++++++++++++ atom/plugin/config.py | 244 ++++++ atom/plugin/moe.py | 57 ++ atom/plugin/prepare.py | 85 +++ atom/plugin/register.py | 97 +++ atom/plugin/vllm/__init__.py | 5 + atom/plugin/vllm/model_wrapper.py | 141 ++++ atom/plugin/vllm/platform.py | 37 + atom/plugin/vllm/register.py | 110 +++ atom/utils/backends.py | 22 +- atom/utils/forward_context.py | 11 +- pyproject.toml | 15 +- recipes/SGLang-ATOM-Model-Impl-Backend.md | 72 ++ recipes/vLLM-ATOM-OOT-Plugin-Backend.md | 89 +++ 32 files changed, 3262 insertions(+), 179 deletions(-) create mode 100644 atom/model_ops/__init__.py create mode 100644 atom/model_ops/paged_attention.py create mode 100644 atom/model_ops/radix_attention.py create mode 100644 atom/plugin/__init__.py create mode 100644 atom/plugin/attention.py create mode 100644 atom/plugin/attention_mha.py create mode 100644 atom/plugin/config.py create mode 100644 atom/plugin/moe.py create mode 100644 atom/plugin/prepare.py create mode 100644 atom/plugin/register.py create mode 100644 
atom/plugin/vllm/__init__.py create mode 100644 atom/plugin/vllm/model_wrapper.py create mode 100644 atom/plugin/vllm/platform.py create mode 100644 atom/plugin/vllm/register.py create mode 100644 recipes/SGLang-ATOM-Model-Impl-Backend.md create mode 100644 recipes/vLLM-ATOM-OOT-Plugin-Backend.md diff --git a/atom/__init__.py b/atom/__init__.py index c1f9ed8b2..7c1c75eb3 100644 --- a/atom/__init__.py +++ b/atom/__init__.py @@ -3,3 +3,12 @@ from atom.model_engine.llm_engine import LLMEngine from atom.sampling_params import SamplingParams + +# interface for upper framework to construct the model from ATOM +from atom.plugin import prepare_model + +__all__ = [ + "LLMEngine", + "SamplingParams", + "prepare_model", +] diff --git a/atom/config.py b/atom/config.py index 8ef9be071..e273fb73c 100644 --- a/atom/config.py +++ b/atom/config.py @@ -17,6 +17,10 @@ from torch.distributed import ProcessGroup, ReduceOp from transformers import AutoConfig, GenerationConfig, PretrainedConfig +# plugin-related utilities +from atom.plugin import is_plugin_mode +from atom.plugin.config import PluginConfig + logger = logging.getLogger("atom") @@ -598,6 +602,9 @@ class Config: torch_dtype: torch.dtype = field(init=False) speculative_config: Optional[SpeculativeConfig] = None + # only use for plugin mode + plugin_config: Optional[PluginConfig] = None + def _set_cudagraph_sizes(self): if self.compilation_config.cudagraph_capture_sizes: self.graph_bs = self.compilation_config.cudagraph_capture_sizes @@ -621,6 +628,7 @@ def __post_init__(self): if rope_params is None: rope_params = {} rope_params["rope_theta"] = getattr(self.hf_config, "rope_theta", None) + rope_params["rope_type"] = getattr(self.hf_config, "rope_type", "default") self.hf_config.rope_parameters = rope_params self.generation_config = get_generation_config(self.model) @@ -640,16 +648,25 @@ def __post_init__(self): self.max_model_len, hf_config_max_position_embeddings ) # assert self.max_num_batched_tokens >= self.max_model_len - 
if self.torch_profiler_dir is not None: - os.makedirs(self.torch_profiler_dir, exist_ok=True) - assert self.torch_profiler_dir is None or os.path.isdir( - self.torch_profiler_dir - ), f"torch_profiler_dir {self.torch_profiler_dir} is not a valid directory" - if self.compilation_config.level == CompilationLevel.PIECEWISE: - self.compilation_config.set_splitting_ops_for_v1() - self._set_cudagraph_sizes() - self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE - self.compilation_config.init_with_cudagraph_sizes() + if not is_plugin_mode(): + if self.torch_profiler_dir is not None: + os.makedirs(self.torch_profiler_dir, exist_ok=True) + assert self.torch_profiler_dir is None or os.path.isdir( + self.torch_profiler_dir + ), f"torch_profiler_dir {self.torch_profiler_dir} is not a valid directory" + + # only for server mode or plugin mode(vllm) + # for torch compile policy, plugin mode(vllm) uses the ATOM compile policy + # for cuda graph capture, plugin mode(vllm) uses the vLLM's cuda graph capture policy + if not is_plugin_mode() or ( + self.plugin_config is not None and self.plugin_config.is_vllm + ): + if self.compilation_config.level == CompilationLevel.PIECEWISE: + self.compilation_config.set_splitting_ops_for_v1() + self._set_cudagraph_sizes() + self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE + self.compilation_config.init_with_cudagraph_sizes() + self.torch_dtype = ( self.hf_config.dtype if getattr(self.hf_config, "dtype", None) is not None diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index f110e8bf1..adf68e320 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -575,7 +575,9 @@ def __init__(self, rank: int, config: Config): self.drafter.load_model(self.model) torch.set_default_device(self.device) self.allocate_forward_vars() - self.attn_metadata_builder = self.attn_backend.get_builder_cls()(self) + self.attn_metadata_builder = 
self.attn_backend.get_builder_cls()( + model_runner=self + ) self.physical_block_size = self.attn_metadata_builder.block_size self.forward_done_event = torch.cuda.Event() self.warmup_model() @@ -1251,7 +1253,7 @@ def prepare_inputs(self, batch: ScheduledBatch, input_ids: torch.Tensor = None): self.forward_vars["cu_seqlens_q"].np[scheduled_bs + 1 : bs + 1] = ( self.forward_vars["cu_seqlens_q"].np[scheduled_bs] ) - attn_metadata, positions = self.attn_metadata_builder.build(batch, bs) + attn_metadata, positions = self.attn_metadata_builder.build(batch=batch, bs=bs) context_bs = batch.total_seqs_num_prefill if is_prefill else scheduled_bs # graph_bs should be batch size (number of sequences), not token count @@ -1503,7 +1505,7 @@ def capture_cudagraph(self): ) attn_metadata, context = ( - self.attn_metadata_builder.build_for_cudagraph_capture(bs) + self.attn_metadata_builder.build_for_cudagraph_capture(bs=bs) ) num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) num_tokens += num_pad diff --git a/atom/model_loader/loader.py b/atom/model_loader/loader.py index 090a653cf..d027b7c16 100644 --- a/atom/model_loader/loader.py +++ b/atom/model_loader/loader.py @@ -31,6 +31,8 @@ from aiter.dist.parallel_state import get_tp_group from atom.models.qwen3_next_mtp import remap_mtp_weight_name +from atom.plugin.prepare import is_sglang + logger = logging.getLogger("atom") @@ -81,12 +83,54 @@ def safetensors_weights_iterator( yield name, f.get_tensor(name) +# when plugin mode, model loader method is bind to model implementation +# thus call this interface to load the model, which leverages the load_model +# method +def load_model_in_plugin_mode( + model, + config, + prefix: str = "", +) -> set[str]: + + # during loading model, the outplace operation may consume more + # GPU mem, which cached in torch caching allocator, here actively + # call empty cache to free the extra reserved but not used memory + def _empty_cache(): + import gc + + gc.collect() + 
torch.cuda.empty_cache() + + assert ( + config.plugin_config is not None and config.plugin_config.is_plugin_mode + ), "ATOM is not running in plugin mode" + if config.plugin_config.is_vllm: + model_name_or_path = config.plugin_config.model_config.model + elif config.plugin_config.is_sglang: + model_name_or_path = config.plugin_config.model_config.model_path + + _empty_cache() + loaded_weights_record = load_model( + model=model, + model_name_or_path=model_name_or_path, + hf_config=config.hf_config, + load_dummy=config.load_dummy, + spec_decode=False, + prefix=prefix, + is_plugin_mode=True, + ) + _empty_cache() + return loaded_weights_record + + def load_model( model: nn.Module, model_name_or_path: str, hf_config: AutoConfig, load_dummy: bool = False, spec_decode: bool = False, + prefix: str = "", + is_plugin_mode: bool = False, ): def have_shared_expert(name): maybe_matching_list = ["mlp.shared_experts.", "mlp.shared_expert."] @@ -95,6 +139,10 @@ def have_shared_expert(name): return maybe_matching_name return None + # need to record the loaded weight name for vllm load check + # it is only used in plugin mode for vllm + loaded_weights_record: set[str] = set() + packed_modules_mapping = getattr(model, "packed_modules_mapping", {}) weights_mapping = getattr(model, "weights_mapping", {}) params_dict = dict(model.named_parameters()) @@ -161,6 +209,7 @@ def have_shared_expert(name): weight_loader, param, weight_tensor, shard_id ) ) + loaded_weights_record.add(prefix + param_name) break else: # Check if model has expert mapping before processing @@ -188,6 +237,7 @@ def have_shared_expert(name): expert_id, ) ) + loaded_weights_record.add(prefix + name) # weight_loader( # param, # weight_tensor, @@ -206,6 +256,7 @@ def have_shared_expert(name): futures.append( executor.submit(weight_loader, param, weight_tensor) ) + loaded_weights_record.add(prefix + name) # weight_loader(param, weight_tensor) else: # Model doesn't have expert mapping, use generic loading @@ -215,6 +266,7 
@@ def have_shared_expert(name): ) # weight_loader(param, weight_tensor) futures.append(executor.submit(weight_loader, param, weight_tensor)) + loaded_weights_record.add(prefix + name) # Wait for all tasks to complete and raise any exceptions. for future in concurrent.futures.as_completed(futures): future.result() @@ -222,7 +274,13 @@ def have_shared_expert(name): if hasattr(module, "process_weights_after_loading"): module.process_weights_after_loading() quant_method = getattr(module, "quant_method", None) - if isinstance(quant_method, QuantizeMethodBase): + + # when running plugin mode for sglang, don't do the post process here + # since sglang will call this func automatically after finishing loading + if isinstance(quant_method, QuantizeMethodBase) and not is_sglang(): quant_method.process_weights_after_loading(module) if isinstance(quant_method, FusedMoEMethodBase): quant_method.init_prepare_finalize(module) + + if is_plugin_mode: + return loaded_weights_record diff --git a/atom/model_ops/__init__.py b/atom/model_ops/__init__.py new file mode 100644 index 000000000..4b6c0b545 --- /dev/null +++ b/atom/model_ops/__init__.py @@ -0,0 +1,14 @@ +from .paged_attention import PagedAttention +from .radix_attention import RadixAttention + +# This global class is used to construct the attention op in model, +# it can be assigned to different attention ops. +# By default, PagedAttention is used. 
+# For sglang, RadixAttention will be assigned to Attention +Attention = PagedAttention + +__all__ = [ + "Attention", + "PagedAttention", + "RadixAttention", +] diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py index 174354b7d..b38ca0ee3 100644 --- a/atom/model_ops/attention_mha.py +++ b/atom/model_ops/attention_mha.py @@ -15,8 +15,15 @@ from .attention_mla import MLAModules +from atom.plugin.prepare import is_plugin_mode, is_vllm +from atom.plugin.attention_mha import PagedAttentionImplDecoratorForPluginMode -class Attention(nn.Module): + +@PagedAttentionImplDecoratorForPluginMode +class PagedAttentionImpl(nn.Module): + """ + Attention paged implementation + """ def __init__( self, @@ -24,11 +31,15 @@ def __init__( head_dim, scale, num_kv_heads, + alibi_slopes: list[float] | None, + sliding_window: Optional[int] = None, kv_cache_dtype="bf16", + logits_soft_cap: float | None = None, + attn_type=None, + kv_sharing_target_layer_name: int | None = None, layer_num=0, mla_modules: Optional[MLAModules] = None, sinks: Optional[nn.Parameter] = None, - sliding_window: Optional[int] = None, rotary_emb: Optional[torch.nn.Module] = None, q_norm: Optional[torch.nn.Module] = None, k_norm: Optional[torch.nn.Module] = None, @@ -37,12 +48,16 @@ def __init__( super().__init__() self.num_heads = num_heads self.head_dim = head_dim + # for upper framework, it uses head_size in built-in methods + self.head_size = head_dim self.scale = scale self.num_kv_heads = num_kv_heads + self.alibi_slopes = alibi_slopes self.k_cache = self.v_cache = torch.tensor([]) self.kv_cache_dtype = kv_cache_dtype self.max_model_len = 0 self.k_scale = self.v_scale = None + self.device = "cuda:" + str(torch.cuda.current_device()) self.layer_num = layer_num self.kv_scale_float = ( torch.finfo(torch.float8_e4m3fn).max / torch.finfo(aiter.dtypes.fp8).max @@ -56,7 +71,11 @@ def __init__( self.q_norm = q_norm self.k_norm = k_norm - def forward( + # for plugin mode(vllm), the query quant 
is disabled for now + if is_vllm(): + self.supports_quant_query_input = False + + def forward_impl_server_mode( self, q: torch.Tensor, k: torch.Tensor, @@ -416,3 +435,39 @@ def dispatch_backend(self, fwd_ctx: ForwardContext): if atom_config.kv_cache_block_size == 1024: return self.paged_attention_persistent_asm return self.paged_attention_asm + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor = None, + attn_metadata=None, + position: torch.Tensor = None, + q_scale: Optional[torch.Tensor] = None, + qkv: torch.Tensor = None, + output: torch.Tensor = None, + **kwargs, + ): + if is_plugin_mode(): + # forward impl method are added by the decorator + # PagedAttentionImplDecoratorForPluginMode + return self.forward_impl_plugin_mode( + layer=layer, + query=query, + key=key, + value=value, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + position=position, + q_scale=q_scale, + qkv=qkv, + ) + else: + # only for server mode, keep the original method + o = self.forward_impl_server_mode( + q=query, k=key, v=value, position=position, q_scale=q_scale, qkv=qkv + ) + + return o diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index c6d026cea..dccb9d620 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -96,7 +96,7 @@ class MLAAttention(nn.Module): def __init__( self, num_heads: int, - head_size: int, + head_dim: int, scale: float, num_kv_heads: int, kv_cache_dtype: str, @@ -107,7 +107,7 @@ def __init__( ) -> None: super().__init__() self.num_heads = num_heads - self.head_size = head_size + self.head_dim = head_dim self.scale = float(scale) self.num_kv_heads = num_kv_heads self.kv_cache_dtype = kv_cache_dtype if kv_cache_dtype == "fp8" else "auto" diff --git a/atom/model_ops/attentions/aiter_attention.py b/atom/model_ops/attentions/aiter_attention.py index b285a0838..68a8567d5 100644 --- 
a/atom/model_ops/attentions/aiter_attention.py +++ b/atom/model_ops/attentions/aiter_attention.py @@ -9,41 +9,67 @@ import torch from aiter.dist.parallel_state import get_tp_group from atom.model_engine.scheduler import ScheduledBatch -from atom.model_ops.attention_mha import Attention from atom.utils import CpuGpuBuffer from atom.utils.block_convert import ( block_table_convert_triton, kv_indices_generate_triton, ) +import atom.model_ops as ops +from atom.model_ops.paged_attention import PagedAttention +from atom.model_ops.attention_mha import PagedAttentionImpl +from atom.model_ops.radix_attention import RadixAttention from atom.utils.forward_context import AttentionMetaData, Context from .backends import AttentionBackend, CommonAttentionBuilder +from atom.plugin.prepare import is_plugin_mode +from atom.plugin.attention import AiterAttentionMetadataBuilderDecoratorForPluginMode +from atom.plugin.attention import AiterBackendDecoratorForPluginMode def cdiv(a, b): return (a + b - 1) // b +@AiterBackendDecoratorForPluginMode class AiterBackend(AttentionBackend): @staticmethod def get_name() -> str: - return "ROCM_AITER_ATTENTION" + return "ROCM_AITER_ATTENTION" if not is_plugin_mode() else "CUSTOM" @staticmethod def get_builder_cls() -> Type["AiterAttentionMetadataBuilder"]: return AiterAttentionMetadataBuilder @staticmethod - def get_impl_cls() -> Type["Attention"]: - return Attention + def get_impl_cls(): + attn_cls = ops.Attention + if attn_cls == PagedAttention: + return PagedAttentionImpl + elif attn_cls == RadixAttention: + raise NotImplementedError("RadixAttention is not supported for now") + raise NotImplementedError( + f"Unsupported attention class {attn_cls!r} configured in ops.Attention" + ) -class AiterAttentionMetadataBuilder(CommonAttentionBuilder): +@AiterAttentionMetadataBuilderDecoratorForPluginMode( + default_base_class=CommonAttentionBuilder +) +class AiterAttentionMetadataBuilder: BLOCK_TABLE_EXTENDER: list[list[int]] = [[]] - def __init__(self, 
model_runner): + def __init__( + self, + kv_cache_spec=None, + layer_names=None, + config=None, + device=None, + model_runner=None, + ): self.block_size = 1024 if model_runner.block_size == 1024 else 16 - super().__init__(model_runner) + # Note: Cannot use super() here because the class is dynamically created by decorator + # Use explicit parent class call instead + CommonAttentionBuilder.__init__(self, model_runner) config = model_runner.config hf_config = config.hf_config self.num_attention_heads = ( diff --git a/atom/model_ops/attentions/gdn_attn.py b/atom/model_ops/attentions/gdn_attn.py index fedc06a33..3913f84dc 100644 --- a/atom/model_ops/attentions/gdn_attn.py +++ b/atom/model_ops/attentions/gdn_attn.py @@ -69,7 +69,7 @@ def __init__( self, model_runner, ): - super().__init__(model_runner) + super().__init__(model_runner=model_runner) self.num_spec = 0 if hasattr(model_runner, "drafter"): self.num_spec = model_runner.drafter.mtp_k diff --git a/atom/model_ops/base_attention.py b/atom/model_ops/base_attention.py index c8144facf..3660b3ebd 100644 --- a/atom/model_ops/base_attention.py +++ b/atom/model_ops/base_attention.py @@ -3,9 +3,12 @@ # from flash_attn import flash_attn_with_kvcache from typing import Optional +from abc import ABC, abstractmethod import torch from torch import nn +import triton +import triton.language as tl from atom.utils import mark_spliting_op @@ -14,6 +17,169 @@ from atom.utils.selector import get_attn_backend +# frontend interface class for constructing attention +# op in model file +class Attention: + def __new__(cls, *args, **kwargs): + from atom.model_ops import Attention + + return Attention(*args, **kwargs) + + +# this triton kernel is used to fetch the stored kv in +# kv cache for computing the extend path(chunked prefill) +# and it can be used for both server mode and plugin mode +@triton.jit +def cp_mha_gather_cache_kernel( + key_cache_ptr, # [num_blocks, page_size, num_head, head_size] + value_cache_ptr, # [num_blocks, 
page_size, num_head, head_size] + key_ptr, # [num_tokens, num_heads, head_size] + value_ptr, # [num_tokens, num_heads, head_size] + block_table_ptr, # [num_batches, max_block_num] + cu_seqlens_kv_ptr, # [num_batches + 1] + token_to_batch_ptr, # [max_cum_tokens] + seq_start_ptr, # [num_batches] + k_scale_ptr, # [1] / [num_blocks, num_kv_heads, page_size] + v_scale_ptr, + num_heads, + head_size, + x, + max_block_num, + DEQUANT: tl.constexpr, + PAGE_SIZE: tl.constexpr, + CACHE_FORMAT: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + token_id = tl.program_id(0) + head_id = tl.program_id(1) + col_offsets = tl.arange(0, BLOCK_SIZE) + + key_ptr_offset = key_ptr + token_id * head_size * num_heads + head_id * head_size + value_ptr_offset = ( + value_ptr + token_id * head_size * num_heads + head_id * head_size + ) + batch_idx = tl.load(token_to_batch_ptr + token_id) + batch_start = tl.load(seq_start_ptr + batch_idx) + token_start = tl.load(cu_seqlens_kv_ptr + batch_idx) + batch_offset = token_id - token_start + batch_start + block_offset = batch_offset // PAGE_SIZE + block_id = tl.load(block_table_ptr + max_block_num * batch_idx + block_offset).to( + tl.int64 + ) + slot_id = batch_offset % PAGE_SIZE + + if CACHE_FORMAT == "NHD": + # for kv cache layout as + # K: [num_blocks, page_size, num_head, head_dim] + # V: [num_blocks, page_size, num_head, head_dim] + key_cache_ptr_offset = ( + key_cache_ptr + + block_id * num_heads * head_size * PAGE_SIZE + + slot_id * num_heads * head_size + + head_id * head_size + ) + value_cache_ptr_offset = ( + value_cache_ptr + + block_id * num_heads * head_size * PAGE_SIZE + + slot_id * num_heads * head_size + + head_id * head_size + ) + k_reg = tl.load(key_cache_ptr_offset + col_offsets) + v_reg = tl.load(value_cache_ptr_offset + col_offsets) + if DEQUANT: + k_scale = tl.load(k_scale_ptr) + v_scale = tl.load(v_scale_ptr) + k_dtype = k_reg.dtype + v_dtype = v_reg.dtype + k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype) + v_reg = 
(v_reg.to(tl.float32) * v_scale).to(v_dtype) + tl.store(key_ptr_offset + col_offsets, k_reg) + tl.store(value_ptr_offset + col_offsets, v_reg) + + elif CACHE_FORMAT == "SHUFFLE": + # for kv cache layout as + # K: [num_blocks, num_head, head_dim // x, page_size, x] + # V: [num_blocks, num_head, page_size // x, head_dim, x] + key_cache_ptr_offset = ( + key_cache_ptr + + block_id * num_heads * head_size * PAGE_SIZE + + head_id * head_size * PAGE_SIZE + + slot_id * x + ) + value_cache_ptr_offset = ( + value_cache_ptr + + block_id * num_heads * head_size * PAGE_SIZE + + head_id * head_size * PAGE_SIZE + + (slot_id // x) * head_size * x + + slot_id % x + ) + k_reg_offset = col_offsets // x * PAGE_SIZE * x + col_offsets % x + v_reg_offset = col_offsets * x + k_reg = tl.load(key_cache_ptr_offset + k_reg_offset) + v_reg = tl.load(value_cache_ptr_offset + v_reg_offset) + if DEQUANT: + k_scale = 1.0 + v_scale = 1.0 + k_reg = k_reg.to(tl.float32) * k_scale + v_reg = v_reg.to(tl.float32) * v_scale + tl.store(key_ptr_offset + col_offsets, k_reg) + tl.store(value_ptr_offset + col_offsets, v_reg) + + +def cp_mha_gather_cache( + key_cache: torch.Tensor, + value_cache: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + block_tables: torch.Tensor, + k_scales: Optional[torch.Tensor], + v_scales: Optional[torch.Tensor], + cu_seqlens_kv: torch.Tensor, + token_to_batch: torch.Tensor, + seq_starts: torch.Tensor, + dequant: bool, + kv_cache_layout: str, + total_tokens: int, +): + assert kv_cache_layout in [ + "NHD", + "SHUFFLE", + ], "kv_cache_layout only support NHD, SHUFFLE" + if dequant: + assert k_scales is not None and v_scales is not None + head_dim = key.shape[2] + x = 16 // key_cache.element_size() + # For k cache layout: [num_blocks, num_heads, page_size, head_dim] + assert head_dim == key_cache.shape[3], ( + "We assume your kv cache layout is [num_blocks, " + "page_size, num_heads, head_dim], but got otherwise" + ) + page_size = key_cache.shape[1] + num_heads = 
key_cache.shape[2] + + grid = lambda meta: (total_tokens, num_heads) # noqa: E731 + cp_mha_gather_cache_kernel[grid]( + key_cache, + value_cache, + key, + value, + block_tables, + cu_seqlens_kv, + token_to_batch, + seq_starts, + k_scales, + v_scales, + num_heads, + head_dim, + x, + block_tables.size(1), + DEQUANT=dequant, + PAGE_SIZE=page_size, + CACHE_FORMAT=kv_cache_layout, + BLOCK_SIZE=head_dim, + ) + + def fake_( q: torch.Tensor, q_scale: Optional[torch.Tensor], @@ -51,7 +217,18 @@ def unified_attention_with_output_base( ) -> torch.Tensor: atom_config = get_current_atom_config() self = atom_config.compilation_config.static_forward_context[layer_name] - return self.impl.forward(q, k, v, positions, q_scale, qkv) + if use_mla: + return self.impl.forward(q, k, v, positions, q_scale, qkv) + else: + return self.impl.forward( + layer=self, + query=q, + key=k, + value=v, + position=positions, + q_scale=q_scale, + qkv=qkv, + ) def linear_attention_with_output_base_fake( @@ -79,7 +256,12 @@ def linear_attention_with_output_base( return self.impl.forward(mixed_qkv, b, a, core_attn_out) -class Attention(nn.Module): +class BaseAttention(nn.Module, ABC): + """ + Abstract base class for attention + + This class defines the interface that all attention implementations must follow + """ def __init__( self, @@ -95,71 +277,23 @@ def __init__( per_layer_sliding_window: Optional[int] = None, rotary_emb: Optional[torch.nn.Module] = None, prefix: Optional[str] = None, - q_norm: Optional[torch.nn.Module] = None, - k_norm: Optional[torch.nn.Module] = None, **kwargs, ): super().__init__() - self.num_heads = num_heads - self.head_dim = head_dim - self.scale = scale - self.num_kv_heads = num_kv_heads - self.k_cache = self.v_cache = torch.tensor([]) - self.kv_cache_dtype = kv_cache_dtype - self.max_model_len = 0 - self.k_scale = self.v_scale = None - self.layer_num = layer_num - self.mla_modules = mla_modules - self.use_mla = use_mla - self.base_attention = None - self.kv_cache = 
torch.tensor([]) - self.indexer = mla_modules.indexer if mla_modules is not None else None - self.sinks = sinks - - atom_config = get_current_atom_config() - dtype = atom_config.torch_dtype - block_size = atom_config.kv_cache_block_size - self.attn_backend = get_attn_backend( - block_size, - use_mla=self.use_mla, - ) - impl_cls = self.attn_backend.get_impl_cls() - self.impl = impl_cls( - num_heads, - head_dim, - scale, - num_kv_heads, - kv_cache_dtype, - layer_num, - mla_modules, - sinks=sinks, - sliding_window=per_layer_sliding_window, - rotary_emb=rotary_emb, - dtype=dtype, - q_norm=q_norm, - k_norm=k_norm, - ) - - compilation_config = atom_config.compilation_config - default_name = f"MLA_{layer_num}" if self.use_mla else f"MHA_{layer_num}" - self.layer_name = prefix if prefix is not None else default_name - if self.layer_name in compilation_config.static_forward_context: - raise ValueError("Duplicate layer: {}".format(self.layer_name)) - compilation_config.static_forward_context[self.layer_name] = self + @abstractmethod def forward( self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - positions: torch.Tensor = None, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + positions: Optional[torch.Tensor] = None, q_scale: Optional[torch.Tensor] = None, - qkv: torch.Tensor = None, - ): - output = torch.ops.aiter.unified_attention_with_output_base( - q, q_scale, k, v, positions, self.layer_name, self.use_mla, qkv + **kwargs, + ) -> torch.Tensor: + raise NotImplementedError( + f"{self.__class__.__name__} must implement the forward() method" ) - return output class LinearAttention(nn.Module): diff --git a/atom/model_ops/embed_head.py b/atom/model_ops/embed_head.py index e5b4b3d10..dd7c95ead 100644 --- a/atom/model_ops/embed_head.py +++ b/atom/model_ops/embed_head.py @@ -8,6 +8,7 @@ from aiter.tuned_gemm import tgemm from atom.utils.forward_context import ForwardContext, get_forward_context from torch import nn +from atom.plugin import 
is_plugin_mode class VocabParallelEmbedding(nn.Module): @@ -67,13 +68,14 @@ def __init__( self.register_parameter("bias", None) def forward(self, x: torch.Tensor): - forward_context: ForwardContext = get_forward_context() - context = forward_context.context - attn_metadata = forward_context.attn_metadata - # context = get_context() - if context.is_prefill and not context.is_draft: - last_indices = attn_metadata.cu_seqlens_q[1:] - 1 - x = x[last_indices].contiguous() + if not is_plugin_mode(): + forward_context: ForwardContext = get_forward_context() + context = forward_context.context + attn_metadata = forward_context.attn_metadata + # context = get_context() + if context.is_prefill and not context.is_draft: + last_indices = attn_metadata.cu_seqlens_q[1:] - 1 + x = x[last_indices].contiguous() logits = tgemm.mm(x, self.weight, self.bias) if self.tp_size > 1: logits = tensor_model_parallel_all_gather(logits) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index e6813b1de..4cbb8fa00 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -47,6 +47,7 @@ from atom.utils.forward_context import get_forward_context from torch import nn from transformers import PretrainedConfig +from atom.plugin.moe import FusedMoEDecoratorForPluginMode class FusedMoeWeightScaleSupported(Enum): @@ -1773,6 +1774,7 @@ def moe_forward_fake( ) +@FusedMoEDecoratorForPluginMode class FusedMoE(torch.nn.Module): """FusedMoE layer for MoE models. diff --git a/atom/model_ops/paged_attention.py b/atom/model_ops/paged_attention.py new file mode 100644 index 000000000..ce2ac8635 --- /dev/null +++ b/atom/model_ops/paged_attention.py @@ -0,0 +1,191 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +# from flash_attn import flash_attn_with_kvcache +from typing import Optional + +import torch +from torch import nn + +from .attention_mla import MLAModules +from .base_attention import BaseAttention +from atom.config import get_current_atom_config +from atom.utils.selector import get_attn_backend +from atom.plugin.prepare import is_sglang, is_vllm +from atom.plugin.attention import unified_attention_with_output_base_for_plugin_mode + + +class PagedAttention(BaseAttention): + """ + Attention paged implementation + """ + + def __init__( + self, + num_heads, + head_dim, + scale, + num_kv_heads, + alibi_slopes: list[float] = None, + kv_cache_dtype="bf16", + layer_num=0, + use_mla: bool = False, + mla_modules: Optional[MLAModules] = None, + sinks: Optional[nn.Parameter] = None, + per_layer_sliding_window: Optional[int] = None, + rotary_emb: Optional[torch.nn.Module] = None, + prefix: Optional[str] = None, + q_norm: Optional[torch.nn.Module] = None, + k_norm: Optional[torch.nn.Module] = None, + **kwargs, + ): + # plugin mode(sglang) is not support paged attention + # for now, only support plugin mode(vllm) and atom server mode + assert ( + not is_sglang() + ), "PagedAttention is not supported for plugin mode(sglang) for now" + super().__init__( + num_heads=num_heads, + head_dim=head_dim, + scale=scale, + num_kv_heads=num_kv_heads, + kv_cache_dtype=kv_cache_dtype, + layer_num=layer_num, + use_mla=use_mla, + mla_modules=mla_modules, + sinks=sinks, + per_layer_sliding_window=per_layer_sliding_window, + rotary_emb=rotary_emb, + prefix=prefix, + **kwargs, + ) + + # for plugin mode + if is_vllm(): + self.use_mla = use_mla + self.rotary_emb = rotary_emb + + try: + from vllm.attention.layer import Attention, AttentionType + except ImportError: + from vllm.model_executor.layers.attention import Attention + from vllm.v1.attention.backend import AttentionType + + atom_config = get_current_atom_config() + assert ( + atom_config is not None + ), "atom_config is required for 
plugin mode to vllm" + + # use vllm cache config and quant config to follow the convention of vllm + cache_config = atom_config.plugin_config.vllm_cache_config + quant_config = atom_config.plugin_config.vllm_quant_config + + # add extra impl args, which are needed to be passed to the impl class + # while it only works for custom attention backend for vllm + extra_impl_args = {} + if atom_config.plugin_config.vllm_use_atom_attention: + extra_impl_args["sinks"] = sinks + extra_impl_args["rotary_emb"] = rotary_emb + extra_impl_args["q_norm"] = q_norm + extra_impl_args["k_norm"] = k_norm + + self.attn = Attention( + num_heads=num_heads, + head_size=head_dim, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=alibi_slopes, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=None, + per_layer_sliding_window=per_layer_sliding_window, + prefix=f"{prefix}", + attn_type=AttentionType.DECODER, + kv_sharing_target_layer_name=None, + **extra_impl_args, + ) + + compilation_config = atom_config.compilation_config + self.layer_name = prefix + if self.layer_name in compilation_config.static_forward_context: + raise ValueError("Duplicate layer: {}".format(self.layer_name)) + compilation_config.static_forward_context[self.layer_name] = self + return + + self.num_heads = num_heads + self.head_dim = head_dim + self.scale = scale + self.num_kv_heads = num_kv_heads + self.k_cache = self.v_cache = torch.tensor([]) + self.kv_cache_dtype = kv_cache_dtype + self.max_model_len = 0 + self.k_scale = self.v_scale = None + self.layer_num = layer_num + self.mla_modules = mla_modules + self.use_mla = use_mla + self.base_attention = None + self.kv_cache = torch.tensor([]) + self.indexer = mla_modules.indexer if mla_modules is not None else None + self.sinks = sinks + + atom_config = get_current_atom_config() + dtype = atom_config.torch_dtype + block_size = atom_config.kv_cache_block_size + self.attn_backend = get_attn_backend( + block_size, + use_mla=self.use_mla, + ) 
+ impl_cls = self.attn_backend.get_impl_cls() + self.impl = impl_cls( + num_heads=num_heads, + head_dim=head_dim, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=alibi_slopes, + kv_cache_dtype=kv_cache_dtype, + layer_num=layer_num, + mla_modules=mla_modules, + sinks=sinks, + sliding_window=per_layer_sliding_window, + rotary_emb=rotary_emb, + dtype=dtype, + q_norm=q_norm, + k_norm=k_norm, + **kwargs, + ) + + compilation_config = atom_config.compilation_config + default_name = f"MLA_{layer_num}" if self.use_mla else f"MHA_{layer_num}" + self.layer_name = prefix if prefix is not None else default_name + if self.layer_name in compilation_config.static_forward_context: + raise ValueError("Duplicate layer: {}".format(self.layer_name)) + compilation_config.static_forward_context[self.layer_name] = self + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + positions: torch.Tensor = None, + q_scale: Optional[torch.Tensor] = None, + qkv: torch.Tensor = None, + **kwargs, + ): + if is_vllm(): + output = unified_attention_with_output_base_for_plugin_mode( + query, + q_scale, + key, + value, + positions, + layer_name=self.layer_name, + use_mla=self.use_mla, + qkv=qkv, + ) + return output + + # for atom server mode + output = torch.ops.aiter.unified_attention_with_output_base( + query, q_scale, key, value, positions, self.layer_name, self.use_mla, qkv + ) + return output diff --git a/atom/model_ops/radix_attention.py b/atom/model_ops/radix_attention.py new file mode 100644 index 000000000..25388b384 --- /dev/null +++ b/atom/model_ops/radix_attention.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
class RadixAttention(BaseAttention):
    """
    Attention layer that delegates to SGLang's RadixAttention in plugin mode.

    Only the SGLang plugin path is implemented: construction and the forward
    pass raise NotImplementedError in every other mode.
    """

    def __init__(
        self,
        num_heads,
        head_dim,
        scale,
        num_kv_heads,
        kv_cache_dtype="bf16",
        layer_num=0,
        use_mla: bool = False,
        mla_modules: Optional[MLAModules] = None,
        sinks: Optional[nn.Parameter] = None,
        per_layer_sliding_window: Optional[int] = None,
        rotary_emb: Optional[torch.nn.Module] = None,
        prefix: Optional[str] = None,
        **kwargs,
    ):
        """Initialise the base attention state and build the SGLang layer.

        All arguments are forwarded to ``BaseAttention.__init__``. The SGLang
        layer itself receives ``num_heads``, ``head_dim``, ``scale`` (as
        ``scaling``), ``num_kv_heads``, ``layer_num`` (as ``layer_id``) and a
        prefix extended with ``"attn"``.

        Raises:
            NotImplementedError: if ``is_sglang()`` is false.
        """
        super().__init__(
            num_heads=num_heads,
            head_dim=head_dim,
            scale=scale,
            num_kv_heads=num_kv_heads,
            kv_cache_dtype=kv_cache_dtype,
            layer_num=layer_num,
            use_mla=use_mla,
            mla_modules=mla_modules,
            sinks=sinks,
            per_layer_sliding_window=per_layer_sliding_window,
            rotary_emb=rotary_emb,
            prefix=prefix,
            **kwargs,
        )
        self.rotary_emb = rotary_emb

        if not is_sglang():
            raise NotImplementedError(
                "RadixAttention is only supported for plugin mode for sglang for now"
            )

        # Imported lazily so sglang is only required when it is the host
        # framework; aliased to avoid shadowing this class's own name.
        from sglang.srt.layers.radix_attention import (
            RadixAttention as SGLangRadixAttention,
        )

        self.attn = SGLangRadixAttention(
            num_heads=num_heads,
            head_dim=head_dim,
            scaling=scale,
            num_kv_heads=num_kv_heads,
            layer_id=layer_num,
            prefix=maybe_prefix(prefix, "attn"),
        )

    def forward_impl_plugin_mode(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata=None,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
        positions: torch.Tensor = None,
        q_scale: torch.Tensor = None,
        **kwargs,
    ):
        """Apply the optional rotary embedding, then call the SGLang layer.

        ``forward_batch`` must be supplied through ``kwargs``. When
        ``self.rotary_emb`` is set, ``positions`` is required and the rotary
        embedding is applied to ``query`` and ``key`` before attention.
        ``attn_metadata``, ``output``, ``output_scale``,
        ``output_block_scale`` and ``q_scale`` are accepted but not used here.

        Raises:
            NotImplementedError: if ``is_sglang()`` is false.
            AssertionError: if ``forward_batch`` is missing, or ``positions``
                is missing while a rotary embedding is configured.
        """
        if not is_sglang():
            raise NotImplementedError(
                "RadixAttention is only supported for plugin mode for sglang for now"
            )

        # for sglang, forward_batch is required
        forward_batch = kwargs.get("forward_batch", None)
        assert forward_batch is not None, "forward_batch is required for sglang"

        rope = self.rotary_emb
        if rope is not None:
            assert positions is not None, "positions is required for ROPE"
            query, key = rope(positions, query, key)

        return self.attn(q=query, k=key, v=value, forward_batch=forward_batch)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        positions: torch.Tensor = None,
        q_scale: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Run attention in plugin mode.

        Forwards every argument to ``forward_impl_plugin_mode``.

        Raises:
            NotImplementedError: if ``is_plugin_mode()`` is false.
        """
        if not is_plugin_mode():
            raise NotImplementedError(
                "RadixAttention is not supported for server mode for now"
            )

        return self.forward_impl_plugin_mode(
            query=query,
            key=key,
            value=value,
            positions=positions,
            q_scale=q_scale,
            **kwargs,
        )
-from typing import Optional +from typing import Any, Iterable import torch # import torch.distributed as dist from aiter.dist.parallel_state import get_tp_group from aiter.rotary_embedding import get_rope -from atom.config import Config, QuantizationConfig +from atom.config import Config from atom.model_ops.activation import SiluAndMul # from atom.model_ops.attention import Attention @@ -47,6 +47,9 @@ from torch import nn from transformers import Qwen3Config +from atom.model_loader.loader import load_model_in_plugin_mode +from atom.models.utils import maybe_prefix + class Qwen3Attention(nn.Module): @@ -63,11 +66,11 @@ def __init__( rope_scaling: tuple | None = None, kv_cache_dtype: str = "fp16", layer_num: int = 0, - quant_config: Optional[QuantizationConfig] = None, + atom_config: Config = None, + prefix: str = "", ) -> None: super().__init__() tp_size = get_tp_group().world_size - self.layer_num = layer_num self.total_num_heads = num_heads assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size @@ -85,13 +88,15 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=qkv_bias, - quant_config=quant_config, + quant_config=atom_config.quant_config, + prefix=f"{prefix}.qkv_proj", ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - quant_config=quant_config, + quant_config=atom_config.quant_config, + prefix=f"{prefix}.o_proj", ) self.rotary_emb = get_rope( self.head_dim, @@ -101,14 +106,16 @@ def __init__( rope_scaling=rope_scaling, ) self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - self.num_kv_heads, + num_heads=self.num_heads, + head_dim=self.head_dim, + scale=self.scaling, + num_kv_heads=self.num_kv_heads, kv_cache_dtype=kv_cache_dtype, layer_num=layer_num, use_mla=False, rotary_emb=self.rotary_emb, + config=atom_config, + prefix=f"{prefix}.attn", ) self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, 
eps=rms_norm_eps) @@ -117,6 +124,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, + **model_kwargs: dict[str, Any] | None, ) -> torch.Tensor: qkv = self.qkv_proj(hidden_states) q, k, v = torch.split(qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -124,7 +132,7 @@ def forward( q = self.q_norm(q) k = self.k_norm(k) - o = self.attn(q, k, v, positions) + o = self.attn(q, k, v, positions, **model_kwargs) output = self.o_proj(o) return output @@ -136,7 +144,8 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config=None, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -144,12 +153,14 @@ def __init__( [intermediate_size] * 2, bias=False, quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", ) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, bias=False, quant_config=quant_config, + prefix=f"{prefix}.down_proj", ) assert hidden_act == "silu" self.act_fn = SiluAndMul() @@ -166,11 +177,12 @@ class Qwen3DecoderLayer(nn.Module): def __init__( self, config: Qwen3Config, - cache_config: str = "bf16", - quant_config: Optional[QuantizationConfig] = None, + atom_config: Config, layer_num: int = 0, + prefix: str = "", ) -> None: super().__init__() + kv_cache_dtype = atom_config.kv_cache_dtype self.layer_num = layer_num rope_params = config.rope_parameters self.self_attn = Qwen3Attention( @@ -183,15 +195,17 @@ def __init__( head_dim=getattr(config, "head_dim", None), rope_theta=rope_params["rope_theta"], rope_scaling=rope_params, - kv_cache_dtype=cache_config, + kv_cache_dtype=kv_cache_dtype, layer_num=layer_num, - quant_config=quant_config, + atom_config=atom_config, + prefix=f"{prefix}.self_attn", ) self.mlp = Qwen3MLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - quant_config=quant_config, + 
quant_config=atom_config.quant_config, + prefix=f"{prefix}.mlp", ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( @@ -203,51 +217,65 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: torch.Tensor | None, + **model_kwargs: dict[str, Any] | None, ) -> tuple[torch.Tensor, torch.Tensor]: if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm(hidden_states, residual) - hidden_states = self.self_attn(positions, hidden_states) + hidden_states = self.self_attn( + positions=positions, hidden_states=hidden_states, **model_kwargs + ) hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual -@support_torch_compile +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + } +) class Qwen3Model(nn.Module): - def __init__(self, atom_config: Config) -> None: + def __init__(self, *, atom_config: Config, prefix: str = "") -> None: super().__init__() - config = atom_config.hf_config - cache_config = atom_config.kv_cache_dtype - quant_config = atom_config.quant_config + hf_config = atom_config.hf_config + self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, config.hidden_size + hf_config.vocab_size, hf_config.hidden_size ) self.layers = nn.ModuleList( [ Qwen3DecoderLayer( - config, - cache_config=cache_config, - quant_config=quant_config, + config=hf_config, + atom_config=atom_config, layer_num=layer_num, + prefix=f"{prefix}.layers.{layer_num}", ) - for layer_num in range(config.num_hidden_layers) + for layer_num in range(hf_config.num_hidden_layers) ] ) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = RMSNorm(hf_config.hidden_size, eps=hf_config.rms_norm_eps) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, + 
**model_kwargs: dict[str, Any], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None for layer in self.layers: - hidden_states, residual = layer(positions, hidden_states, residual) + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + **model_kwargs, + ) + hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -261,20 +289,34 @@ class Qwen3ForCausalLM(nn.Module): "up_proj": ("gate_up_proj", 1), } - def __init__(self, atom_config: Config) -> None: + def __init__(self, config: Any, prefix: str = "") -> None: super().__init__() - config = atom_config.hf_config - self.model = Qwen3Model(atom_config) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - if config.tie_word_embeddings: + self.atom_config = config + self.hf_config = self.atom_config.hf_config + self.model = Qwen3Model( + atom_config=self.atom_config, prefix=maybe_prefix(prefix, "model") + ) + + self.lm_head = ParallelLMHead( + num_embeddings=self.hf_config.vocab_size, + embedding_dim=self.hf_config.hidden_size, + bias=False, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if self.hf_config.tie_word_embeddings: self.lm_head.weight.data = self.model.embed_tokens.weight.data def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, + intermediate_tensors=None, + inputs_embeds: torch.Tensor | None = None, + **model_kwargs: dict[str, Any], ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions) + hidden_states = self.model( + input_ids=input_ids, positions=positions, **model_kwargs + ) return hidden_states def compute_logits( @@ -283,3 +325,14 @@ def compute_logits( ) -> torch.Tensor: logits = self.lm_head(hidden_states) return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # load weights in plugin mode and discard passed weights generator + # here prefix is "model." 
because Qwen3ForCausalLM is constructed in model + # wrapper class, so the name of loaded weights are prefixed with "model.". + # The vLLM will check the name of the loaded weights to make sure all the + # weights are loaded correctly + loaded_weights_record = load_model_in_plugin_mode( + model=self, config=self.atom_config, prefix="model." + ) + return loaded_weights_record diff --git a/atom/models/qwen3_moe.py b/atom/models/qwen3_moe.py index 8da45b41e..9a0e1eba1 100644 --- a/atom/models/qwen3_moe.py +++ b/atom/models/qwen3_moe.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, Union, Any, Iterable import torch from aiter.dist.communication_op import tensor_model_parallel_all_reduce @@ -33,9 +33,10 @@ ) from atom.utils import envs from torch import nn +from atom.model_loader.loader import load_model_in_plugin_mode # import torch.distributed as dist -from transformers import PretrainedConfig, Qwen3Config +from transformers import PretrainedConfig ENABLE_ALLREDUCE_RMSNORM_FUSION = envs.ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION = ( @@ -152,7 +153,8 @@ def __init__( rope_scaling: tuple | None = None, kv_cache_dtype: str = "fp16", layer_num: int = 0, - quant_config: Optional[QuantizationConfig] = None, + atom_config: Config = None, + prefix: str = "", ) -> None: super().__init__() tp_size = get_tensor_model_parallel_world_size() @@ -180,14 +182,14 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=qkv_bias, - quant_config=quant_config, + quant_config=atom_config.quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - quant_config=quant_config, + quant_config=atom_config.quant_config, reduce_results=not ENABLE_ALLREDUCE_RMSNORM_FUSION, ) @@ -219,6 +221,7 @@ def __init__( layer_num=layer_num, use_mla=False, rotary_emb=self.rotary_emb, + prefix=f"{prefix}.attn", q_norm=self.q_norm if 
ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION else None, k_norm=self.k_norm if ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION else None, ) @@ -230,36 +233,40 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, + **model_kwargs: dict[str, Any] | None, ) -> torch.Tensor: qkv = self.qkv_proj(hidden_states) q, k, v = torch.split(qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1) if ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION: - attn_output = self.attn(q, k, v, positions, None, qkv) + q, k, v = torch.split( + qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1 + ) + attn_output = self.attn( + query=q, key=k, value=v, positions=positions, q_scale=None, qkv=qkv + ) else: # Add qk-norm q = self.q_norm(q) k = self.k_norm(k) - attn_output = self.attn(q, k, v, positions) + attn_output = self.attn( + query=q, key=k, value=v, positions=positions, **model_kwargs + ) output = self.o_proj(attn_output) return output class Qwen3MoeDecoderLayer(nn.Module): - def __init__( - self, - config: Qwen3Config, - prefix: str, - cache_config: str = "bf16", - quant_config: Optional[QuantizationConfig] = None, - layer_num: int = 0, - ) -> None: + def __init__(self, atom_config=None, layer_num: int = 0, prefix: str = "") -> None: super().__init__() + self.atom_config = atom_config + config = self.atom_config.hf_config self.hidden_size = config.hidden_size rope_params = config.rope_parameters rope_theta = rope_params["rope_theta"] rope_scaling = rope_params + kv_cache_dtype = atom_config.kv_cache_dtype max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # DecoderLayers are created with `make_layers` which passes the prefix # with the layer's index. 
@@ -275,9 +282,10 @@ def __init__( head_dim=getattr(config, "head_dim", None), rope_theta=rope_theta, rope_scaling=rope_scaling, - kv_cache_dtype=cache_config, + kv_cache_dtype=kv_cache_dtype, layer_num=layer_num, - quant_config=quant_config, + atom_config=atom_config, + prefix=f"{prefix}.self_attn", ) # `mlp_only_layers` in the config. @@ -289,14 +297,16 @@ def __init__( and (self.layer_idx + 1) % config.decoder_sparse_step == 0 ): self.mlp = Qwen3MoeSparseMoeBlock( - config, quant_config=quant_config, prefix=f"{prefix}.mlp" + config, + quant_config=self.atom_config.quant_config, + prefix=f"{prefix}.mlp", ) else: self.mlp = Qwen3MoeMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - quant_config=quant_config, + quant_config=self.atom_config.quant_config, reduce_results=not ENABLE_ALLREDUCE_RMSNORM_FUSION, prefix=f"{prefix}.mlp", ) @@ -316,6 +326,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: torch.Tensor | None, + **model_kwargs: dict[str, Any] | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: @@ -326,6 +337,7 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, + **model_kwargs, ) # Fully Connected @@ -344,42 +356,37 @@ def __init__( ): super().__init__() - config = atom_config.hf_config - cache_config = atom_config.kv_cache_dtype - quant_config = atom_config.quant_config - self.config = config + self.config = atom_config.hf_config if get_pp_group().is_first_rank: self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, + self.config.vocab_size, + self.config.hidden_size, ) else: self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, + self.config.num_hidden_layers, lambda prefix, layer_num=None: Qwen3MoeDecoderLayer( - config, - prefix, - cache_config=cache_config, - 
quant_config=quant_config, + atom_config=atom_config, layer_num=layer_num, + prefix=prefix, ), prefix=f"{prefix}.layers", layer_num_offset=0, ) if get_pp_group().is_last_rank: self.norm = RMSNorm( - config.hidden_size, - eps=config.rms_norm_eps, + self.config.hidden_size, + eps=self.config.rms_norm_eps, fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION, ) else: self.norm = PPMissingLayer() self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size + ["hidden_states", "residual"], self.config.hidden_size ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -391,6 +398,7 @@ def forward( positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, + **model_kwargs: dict[str, Any] | None, ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -404,7 +412,9 @@ def forward( residual = intermediate_tensors["residual"] for layer in self.layers[self.start_layer : self.end_layer]: - hidden_states, residual = layer(positions, hidden_states, residual) + hidden_states, residual = layer( + positions, hidden_states, residual, **model_kwargs + ) if not get_pp_group().is_last_rank: return IntermediateTensors( @@ -441,24 +451,22 @@ def __init__( layer_type: type[nn.Module] = Qwen3MoeDecoderLayer, ): super().__init__() - config = atom_config.hf_config - quant_config = atom_config.quant_config - self.config = config - self.quant_config = quant_config + self.atom_config = atom_config + self.config = self.atom_config.hf_config + # Only perform the following mapping when Qwen3MoeMLP exists - if getattr(config, "mlp_only_layers", []): + if getattr(self.config, "mlp_only_layers", []): self.packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"] self.model = Qwen3MoeModel( - atom_config=atom_config, + atom_config=self.atom_config, 
prefix=maybe_prefix(prefix, "model"), - layer_type=layer_type, ) if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, + num_embeddings=self.config.vocab_size, + embedding_dim=self.config.hidden_size, + bias=False, prefix=maybe_prefix(prefix, "lm_head"), ) else: @@ -479,9 +487,14 @@ def forward( positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, + **model_kwargs: dict[str, Any] | None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model( - input_ids, positions, intermediate_tensors, inputs_embeds + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, ) return hidden_states @@ -508,3 +521,14 @@ def make_empty_intermediate_tensors( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # load weights in plugin mode and discard passed weights generator + # here prefix is "model." because Qwen3MoeForCausalLM is constructed in model + # wrapper class, so the name of loaded weights are prefixed with "model.". + # The vLLM will check the name of the loaded weights to make sure all the + # weights are loaded correctly + loaded_weights_record = load_model_in_plugin_mode( + model=self, config=self.atom_config, prefix="model." 
+ ) + return loaded_weights_record diff --git a/atom/plugin/__init__.py b/atom/plugin/__init__.py new file mode 100644 index 000000000..27c855e51 --- /dev/null +++ b/atom/plugin/__init__.py @@ -0,0 +1,13 @@ +from .prepare import ( + prepare_model, + is_sglang, + is_vllm, + is_plugin_mode, +) + +__all__ = [ + "prepare_model", + "is_sglang", + "is_vllm", + "is_plugin_mode", +] diff --git a/atom/plugin/attention.py b/atom/plugin/attention.py new file mode 100644 index 000000000..2b136ebc8 --- /dev/null +++ b/atom/plugin/attention.py @@ -0,0 +1,653 @@ +from typing import Generic, Optional +import logging + +from dataclasses import dataclass + +import torch + +from atom.plugin.prepare import is_vllm, is_sglang +from atom.utils import CpuGpuBuffer +from atom.utils.forward_context import Context, AttentionMetaData +from atom.model_ops.attention_mha import PagedAttentionImpl + +logger = logging.getLogger("atom") + +_PARTITION_SIZE_ROCM = 256 +_CP_TOKENS_PER_ITER_ROCM = 32 * 1024 + + +@dataclass +class AiterFlashAttentionDecodeMetadata: + max_query_len: int + min_query_len: int + max_seq_len: int + query_start_loc: torch.Tensor + + +@dataclass +class AiterFlashAttentionPrefillMetadata: + max_query_len: int + min_query_len: int + max_seq_len: int + query_start_loc: torch.Tensor + + +@dataclass +class AiterChunkSlidingWindowMetadata: + swa_seqlens: torch.Tensor + swa_cu_seqlens: torch.Tensor + swa_seq_starts: torch.Tensor + swa_token_to_batch: torch.Tensor + swa_max_seqlens: int + swa_total_tokens: int + swa_workspace: torch.Tensor + + +@dataclass +class AiterChunkContextMetadata: + workspace: torch.Tensor + cu_seq_lens_chunk: torch.Tensor + chunk_starts: torch.Tensor + token_to_batch: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + seq_lens: torch.Tensor + num_chunks: int + total_token_per_batch: list[int] + swa_metadata: Optional[AiterChunkSlidingWindowMetadata] = None + + +@dataclass +class AiterFlashAttentionChunkPrefillMetadata: + max_query_len: int + 
class vllmAiterBackendMethods:
    """Namespace of vLLM attention-backend hooks to graft onto ATOM's backend.

    This class is never instantiated. ``AiterBackendDecoratorForPluginMode``
    copies its attributes onto the decorated backend class when running
    under vLLM.
    """

    # ATOM's attention does not accept an output buffer: as a model impl
    # backend it needs full freedom over the output buffer and its shape, so
    # this flag tells vLLM not to allocate one. ATOM handles the output
    # buffer by itself.
    accept_output_buffer: bool = False
    supported_dtypes: list = [torch.float16, torch.bfloat16]
    forward_includes_kv_cache_update: bool = True

    def __init__(self):
        raise TypeError(
            f"{self.__class__.__name__} is a utility class and should not be instantiated. "
            "Its methods are meant to be added to other classes via decorators."
        )

    @staticmethod
    def get_supported_kernel_block_sizes():
        """Return the kernel block sizes this backend reports as supported."""
        return [16, 32]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV cache shape ``(2, num_blocks, block_size, num_kv_heads, head_size)``.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @classmethod
    def is_mla(cls) -> bool:
        """Report that this backend is not an MLA backend."""
        return False

    @staticmethod
    def get_required_kv_cache_layout():
        """Return ``None``: no specific KV cache layout is required."""
        return None

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the head sizes this backend reports as supported."""
        return [64, 128, 256]

    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        """Return ``(module, qualname)`` of the class the method is bound to."""
        return (cls.__module__, cls.__qualname__)

    @classmethod
    def supports_alibi_sqrt(cls) -> bool:
        """Report that sqrt-ALiBi is not supported."""
        return False


def AiterBackendDecoratorForPluginMode(cls):
    """
    Decorator for AiterBackend to add specific methods and attributes for plugin mode.

    Under vLLM (``is_vllm()`` true) the hooks declared on
    ``vllmAiterBackendMethods`` are installed on ``cls``; in any other mode
    ``cls`` is returned untouched.

    The raw descriptors are copied from the utility class ``__dict__`` rather
    than via attribute access. Attribute access on a ``classmethod`` yields a
    method already bound to ``vllmAiterBackendMethods``, so
    ``cls.full_cls_name()`` would report the utility class instead of the
    decorated backend. Copying the descriptor lets Python re-bind it to
    ``cls`` (and keeps static methods static on instance access).

    Args:
        cls: The backend class to augment.

    Returns:
        The same ``cls`` object.
    """
    if not is_vllm():
        return cls

    # for vllm, add the required methods
    vllm_hook_names = (
        "full_cls_name",
        "accept_output_buffer",
        "supported_dtypes",
        "forward_includes_kv_cache_update",
        "get_supported_kernel_block_sizes",
        "get_kv_cache_shape",
        "is_mla",
        "get_required_kv_cache_layout",
        "get_supported_head_sizes",
        "supports_alibi_sqrt",
    )
    hook_namespace = vars(vllmAiterBackendMethods)
    for hook_name in vllm_hook_names:
        setattr(cls, hook_name, hook_namespace[hook_name])
    return cls
+ kv_cache_spec=None, + layer_names=None, + config=None, + device=None, + model_runner=None, + ): + base_class.__init__(self, kv_cache_spec, layer_names, config, device) + logger.info("init AiterAttentionMetadataBuilder for plugin mode") + from vllm.config import VllmConfig, get_layers_from_vllm_config + + try: + from vllm.attention.layer import Attention + except ImportError: + from vllm.model_executor.layers.attention import Attention + + assert isinstance(config, VllmConfig) + + self.vllm_config = config + self.model_config = config.model_config + self.parallel_config = config.parallel_config + self.cache_config = config.cache_config + + self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config) + self.head_dim = self.model_config.get_head_size() + self.block_size = kv_cache_spec.block_size + + self.aot_sliding_window: Optional[tuple[int, int]] = None + self.total_tokens: int = 0 + + self.scheduler_config = config.scheduler_config + self.block_ratio = 1 + + sliding_window_sizes: set[tuple[int, int] | None] = set() + layers = get_layers_from_vllm_config(config, Attention) + for layer in layers.values(): + assert isinstance(layer.impl, PagedAttentionImpl) + sliding_window = layer.impl.sliding_window + if sliding_window is None or sliding_window == -1: + sliding_window_sizes.add(None) + elif isinstance(sliding_window, tuple): + sliding_window_sizes.add(sliding_window) + else: + sliding_window_sizes.add((sliding_window - 1, 0)) + + while len(sliding_window_sizes) > 0: + sliding_window_config = sliding_window_sizes.pop() + if sliding_window_config is not None and sliding_window_config[0] != -1: + assert ( + self.aot_sliding_window is None + ), "Aiter Backend only support one valid sliding window" + self.aot_sliding_window = sliding_window_config + + # for extend path to store the fetched key and value + # here buffer used for extend path is not calculated by vLLM and SGLang + # when profile_run, it is possible to exhaust the GPU memory when + # 
def setup_attn_metadata_builder_base_class_and_attributes(class_dict: dict):
    """
    Pick the vLLM base class for the metadata builder and seed its attributes.

    ``class_dict`` is updated in place with ``_cudagraph_support`` and
    ``reorder_batch_threshold``.

    Returns:
        A 4-tuple ``(base_class, generic_base, needs_generic, class_dict)``.
        Both class entries are vLLM's ``AttentionMetadataBuilder`` and
        ``needs_generic`` is ``True``.
    """
    from vllm.v1.attention.backend import (
        AttentionCGSupport,
        AttentionMetadataBuilder,
    )

    # align with vllm rocm aiter fa
    class_dict.update(
        {
            "_cudagraph_support": AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE,
            "reorder_batch_threshold": 1,
        }
    )

    return AttentionMetadataBuilder, AttentionMetadataBuilder, True, class_dict
def build(
    self,
    common_prefix_len: int = 0,
    common_attn_metadata=None,
    fast_build: bool = False,
):
    """Build ATOM attention metadata from vLLM's common attention metadata.

    The batch is split into decode, extend (chunked prefill with existing
    context) and pure prefill requests, in that order, and one metadata
    object is built for each non-empty group.

    Args:
        common_prefix_len: Must be 0; cascade attention is not supported.
        common_attn_metadata: vLLM common attention metadata for the batch.
        fast_build: Accepted for interface compatibility; not read here.

    Returns:
        An ``AttentionMetaData`` whose ``plugin_metadata`` field carries the
        ``MetadataForPluginMode`` built here.

    Raises:
        ValueError: If ``common_prefix_len`` is greater than 0.
    """
    if common_prefix_len > 0:
        raise ValueError("ATOM does not support cascade attention yet")

    from vllm.v1.attention.backends.utils import split_decodes_prefills_and_extends

    # here assume the decode num token is 1 per request
    (
        num_decodes,
        num_extends,
        num_prefills,
        num_decode_tokens,
        num_extend_tokens,
        num_prefill_tokens,
    ) = split_decodes_prefills_and_extends(
        common_attn_metadata=common_attn_metadata, decode_threshold=1
    )

    query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
    seq_lens = common_attn_metadata.seq_lens.cpu()
    query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

    decode_metadata = None
    if num_decodes > 0:
        decode_query_lens = query_lens_cpu[:num_decodes]
        decode_metadata = AiterFlashAttentionDecodeMetadata(
            max_query_len=decode_query_lens.max().item(),
            min_query_len=decode_query_lens.min().item(),
            max_seq_len=seq_lens[:num_decodes].max().item(),
            query_start_loc=common_attn_metadata.query_start_loc[: num_decodes + 1],
        )

    extend_metadata = None
    if num_extends > 0:
        num_extends_slice = slice(num_decodes, num_decodes + num_extends)
        query_lens_for_extend = query_lens_cpu[num_extends_slice]
        seq_lens_for_extend = seq_lens[num_extends_slice]
        # KV tokens already present before this step's query tokens.
        computed_kv_lens = seq_lens_for_extend - query_lens_for_extend

        swa_metadata = None
        if self.aot_sliding_window is not None:
            swa_seqlen_for_extend = torch.minimum(
                seq_lens_for_extend,
                query_lens_for_extend + self.aot_sliding_window[0] + 1,
            )
            cu_seq_lens = torch.zeros(
                num_extends + 1,
                dtype=torch.int32,
                device=seq_lens_for_extend.device,
            )
            torch.cumsum(
                swa_seqlen_for_extend,
                dim=0,
                dtype=cu_seq_lens.dtype,
                out=cu_seq_lens[1:],
            )
            token_to_seq = torch.arange(
                0,
                num_extends,
                dtype=torch.int32,
                device=seq_lens_for_extend.device,
            )
            token_to_seq = torch.repeat_interleave(
                token_to_seq, swa_seqlen_for_extend
            )
            # Read the total once; it sizes the workspace and is also stored
            # in the metadata.
            swa_total_tokens = cu_seq_lens[-1].item()
            swa_workspace = torch.empty(
                (2, swa_total_tokens, self.num_heads_kv, self.head_dim),
                dtype=self.vllm_config.model_config.dtype,
                device=self.device,
            )

            seq_starts = seq_lens_for_extend - swa_seqlen_for_extend
            max_seqlen_k = swa_seqlen_for_extend.max().item()

            swa_metadata = AiterChunkSlidingWindowMetadata(
                swa_seqlens=swa_seqlen_for_extend.to(self.device, non_blocking=True),
                swa_cu_seqlens=cu_seq_lens.to(self.device, non_blocking=True),
                swa_seq_starts=seq_starts.to(self.device, non_blocking=True),
                swa_token_to_batch=token_to_seq.to(self.device, non_blocking=True),
                swa_max_seqlens=max_seqlen_k,
                swa_total_tokens=swa_total_tokens,
                swa_workspace=swa_workspace,
            )

        # allocate the equal amount of workspace for
        # each chunk prefill request
        max_context_chunk = _CP_TOKENS_PER_ITER_ROCM // num_extends
        from vllm.utils.math_utils import cdiv

        num_chunks = cdiv(computed_kv_lens.max().item(), max_context_chunk)

        chunk_starts = (
            torch.arange(num_chunks, dtype=torch.int32)
            .unsqueeze(1)
            .expand(-1, num_extends)
            * max_context_chunk
        )
        chunk_ends = torch.min(
            computed_kv_lens.unsqueeze(0), chunk_starts + max_context_chunk
        )
        # [num_chunks, num_extends]
        chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
        cu_seq_lens_cpu = torch.zeros(
            [num_chunks, num_extends + 1], dtype=torch.int32, pin_memory=True
        )
        torch.cumsum(
            chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32
        )
        max_cum_tokens = cu_seq_lens_cpu[:, -1].max().item()

        # Build token->batch mapping robustly, even with zero-length batches.
        token_to_batch_tensor = torch.zeros(
            (num_chunks, max_cum_tokens), dtype=torch.int32, pin_memory=True
        )
        batch_ids = torch.arange(num_extends, dtype=torch.int32)
        for chunk_idx in range(num_chunks):
            chunk_total = cu_seq_lens_cpu[chunk_idx, -1].item()
            if chunk_total == 0:
                continue
            token_to_batch_tensor[chunk_idx, :chunk_total] = torch.repeat_interleave(
                batch_ids, chunk_seq_lens[chunk_idx].to(torch.int64)
            )

        chunk_context_metadata = AiterChunkContextMetadata(
            workspace=self.extend_workspace,
            cu_seq_lens_chunk=cu_seq_lens_cpu.to(self.device, non_blocking=True),
            chunk_starts=chunk_starts.to(self.device, non_blocking=True),
            seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
            max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
            seq_lens=chunk_seq_lens,
            token_to_batch=token_to_batch_tensor.to(self.device, non_blocking=True),
            num_chunks=num_chunks,
            total_token_per_batch=cu_seq_lens_cpu[:, -1].tolist(),
            swa_metadata=swa_metadata,
        )

        # NOTE: a device-side cumulative sum of the extend seq_lens used to be
        # computed here, but its result was never read; it was dropped to
        # avoid a needless allocation and cumsum on every build.
        query_start_loc_device = common_attn_metadata.query_start_loc[
            num_decodes : num_decodes + num_extends + 1
        ]
        extend_metadata = AiterFlashAttentionChunkPrefillMetadata(
            max_query_len=query_lens_for_extend.max().item(),
            min_query_len=query_lens_for_extend.min().item(),
            max_seq_len=seq_lens[num_extends_slice].max().item(),
            # rebase so the extend group starts at offset 0
            query_start_loc=query_start_loc_device - query_start_loc_device[0],
            chunk_context_metadata=chunk_context_metadata,
        )

    prefill_metadata = None
    if num_prefills > 0:
        query_lens_for_prefill = query_lens_cpu[num_decodes + num_extends :]
        query_start_loc_device = common_attn_metadata.query_start_loc[
            num_decodes + num_extends :
        ]
        prefill_metadata = AiterFlashAttentionPrefillMetadata(
            max_query_len=query_lens_for_prefill.max().item(),
            min_query_len=query_lens_for_prefill.min().item(),
            max_seq_len=seq_lens[num_decodes + num_extends :].max().item(),
            # rebase so the prefill group starts at offset 0
            query_start_loc=query_start_loc_device - query_start_loc_device[0],
        )

    num_actual_kv_tokens = torch.sum(seq_lens).item()

    use_cascade = False

    has_prefill = num_prefills > 0 or num_extends > 0
    if has_prefill:
        context_batch_size = num_prefills + num_extends
    else:
        context_batch_size = num_decodes
    context_graph_bs = context_batch_size

    num_actual_tokens = common_attn_metadata.num_actual_tokens
    context = Context(
        positions=None,
        is_prefill=has_prefill,
        batch_size=context_batch_size,
        graph_bs=context_graph_bs,
    )

    attn_metadata_for_plugin_mode = MetadataForPluginMode(
        num_actual_tokens=num_actual_tokens,
        num_actual_kv_tokens=num_actual_kv_tokens,
        max_query_len=common_attn_metadata.max_query_len,
        query_start_loc=common_attn_metadata.query_start_loc,
        max_seq_len=common_attn_metadata.max_seq_len,
        seq_lens=common_attn_metadata.seq_lens,
        block_table=common_attn_metadata.block_table_tensor,
        slot_mapping=common_attn_metadata.slot_mapping,
        num_decodes=num_decodes,
        num_decode_tokens=num_decode_tokens,
        num_prefills=num_prefills,
        num_prefill_tokens=num_prefill_tokens,
        num_extends=num_extends,
        num_extend_tokens=num_extend_tokens,
        decode_metadata=decode_metadata,
        prefill_metadata=prefill_metadata,
        extend_metadata=extend_metadata,
        use_cascade=use_cascade,
        common_prefix_len=common_prefix_len,
        total_tokens=self.total_tokens,
        context=context,
    )

    attn_metadata = AttentionMetaData(
        max_seqlen_q=common_attn_metadata.max_query_len,
        block_tables=common_attn_metadata.block_table_tensor,
        slot_mapping=common_attn_metadata.slot_mapping,
        plugin_metadata=attn_metadata_for_plugin_mode,
    )

    return attn_metadata
# this method will be called by vllm, so it follows the vllm's interface convention
def build_for_cudagraph_capture(
    self,
    common_attn_metadata=None,
):
    """Build attention metadata for a CUDA-graph capture run.

    ``self.total_tokens`` is temporarily set to
    ``max_model_len * max_num_partial_prefills`` while ``build`` runs, then
    reset to 0.

    Args:
        common_attn_metadata: Forwarded unchanged to ``build``.

    Returns:
        Whatever ``build`` returns.
    """
    self.total_tokens = (
        self.model_config.max_model_len
        * self.vllm_config.scheduler_config.max_num_partial_prefills
    )
    try:
        return self.build(
            common_prefix_len=0, common_attn_metadata=common_attn_metadata
        )
    finally:
        # Reset even when build() raises, so a failed capture cannot leak the
        # capture-time value into later regular builds.
        self.total_tokens = 0
# Deliberately not registered as a custom op or marked as a split point here:
# vLLM already registers the attention impl forward as a custom op, so doing it
# again would duplicate the registration. The split op is registered in the
# atom support_torch_compile decorator instead.
def unified_attention_with_output_base_for_plugin_mode(
    q: torch.Tensor,
    q_scale: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    positions: torch.Tensor,
    layer_name: str,
    use_mla: bool,
    qkv: torch.Tensor,
) -> torch.Tensor:
    """Run the attention layer registered under ``layer_name`` in plugin mode.

    The layer is looked up in the current ATOM config's
    ``compilation_config.static_forward_context``. ``q_scale`` is accepted
    but not read here.

    Raises:
        NotImplementedError: If ``use_mla`` is true.
    """
    from atom.config import get_current_atom_config
    from atom.utils import envs

    atom_config = get_current_atom_config()
    if use_mla:
        raise NotImplementedError("MLA is not supported for plugin mode for now")

    layer = atom_config.compilation_config.static_forward_context[layer_name]

    # This follows the standard vLLM attention impl interface.
    # [watch out] accept_output_buffer must be False for plugin mode because
    # vLLM must not manipulate the q/k/v and output buffers; ATOM handles all
    # of them here.
    if envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION:
        # With fusion enabled, positions and qkv travel through the k and v
        # argument slots.
        return layer.attn(q, positions, qkv)

    # Without fusion, apply rotary embedding to q and k before attention.
    if layer.rotary_emb is not None:
        q, k = layer.rotary_emb(positions, q, k)
    return layer.attn(q, k, v)
+This module provides additional methods for PagedAttentionImpl when running in plugin mode. +""" + +import torch +import aiter +from aiter import dtypes, fused_qk_norm_rope_cache_quant_shuffle +from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache +from aiter.ops.triton.gluon.pa_decode_gluon import get_recommended_splits +from typing import TYPE_CHECKING, Optional +from atom.utils import envs +from atom.model_ops.base_attention import cp_mha_gather_cache + +import logging + +logger = logging.getLogger("atom") + +if TYPE_CHECKING: + from atom.utils.forward_context import AttentionMetaData + +ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION = ( + envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION +) + +_PARTITION_SIZE_ROCM = 256 +_CP_TOKENS_PER_ITER_ROCM = 32 * 1024 +_QWEN_GLUON_PA_DECODE_BS = 64 + + +class PagedAttentionImplPluginModeMethods: + """ + Container class for plugin mode methods. + This class cannot be instantiated - it only serves as a namespace for methods + that will be added to PagedAttentionImpl via decorator. + """ + + def __init__(self): + raise TypeError( + "PagedAttentionImplPluginModeMethods cannot be instantiated. " + "It is only used as a method container for the decorator." 
# this method will just be called by vLLM and there is no logic in this method
# as ATOM handles the process after loading weights for all ops by itself
def process_weights_after_loading(self, act_dtype: torch.dtype = torch.bfloat16):
    """Intentional no-op hook called by vLLM after weights are loaded.

    ``act_dtype`` is accepted but never read.
    """
    pass
def paged_attention_triton_plugin_mode(
    self,
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    out: torch.Tensor,
    attn_metadata: "AttentionMetaData",
):
    """Run decode attention through ``torch.ops.aiter.pa_decode_gluon``.

    The result is written into ``out``, which is also returned. ``q`` is
    unpacked as ``(num_seqs, num_q_heads, head_size)`` and ``k_cache`` as a
    5-D tensor whose dims 1 and 3 are the KV head count and the block size.
    """
    num_seqs, num_q_heads_total, head_size = q.shape
    num_blocks, num_kv_heads, _, block_size, _ = k_cache.shape
    query_group_size = num_q_heads_total // num_kv_heads
    assert num_q_heads_total % num_kv_heads == 0

    max_context_partition_num = get_recommended_splits(num_seqs, num_kv_heads)

    if self.sliding_window > 0:
        # a sliding window uses a single partition of 128
        max_context_partition_num, context_partition_size = 1, 128
    else:
        context_partition_size = 256

    # Output buffers (same as Triton)
    intermediate_shape = (
        num_seqs,
        num_kv_heads,
        max_context_partition_num,
        query_group_size,
    )
    fp32_buffer_kwargs = {"dtype": torch.float32, "device": q.device}
    exp_sums = torch.empty(intermediate_shape, **fp32_buffer_kwargs)
    max_logits = torch.empty(intermediate_shape, **fp32_buffer_kwargs)
    temporary_output = torch.empty(
        (*intermediate_shape, head_size),
        dtype=q.dtype,
        device=q.device,
    )

    # A single-element scale is treated as per-tensor; any other scale gets
    # a trailing singleton dim before being passed to the kernel.
    per_tensor = k_scale is not None and k_scale.numel() == 1
    if k_scale is not None and not per_tensor:
        k_scale = k_scale.unsqueeze(-1)
        v_scale = v_scale.unsqueeze(-1)

    is_bf16_cache = self.kv_cache_dtype == "bf16"
    if is_bf16_cache or per_tensor:
        compute_type = torch.bfloat16
    else:
        compute_type = aiter.dtypes.fp8

    torch.ops.aiter.pa_decode_gluon(
        out,
        q,
        k_cache,
        v_cache,
        attn_metadata.plugin_metadata.seq_lens,
        attn_metadata.block_tables,
        self.scale,
        1,  # query length
        max_context_partition_num,
        context_partition_size,
        compute_type,
        None,
        None if is_bf16_cache else k_scale,
        None if is_bf16_cache else v_scale,
        exp_sums=exp_sums,
        max_logits=max_logits,
        temporary_output=temporary_output,
        alibi_slopes=None,
        sinks=self.sinks,
        sliding_window=self.sliding_window,
        ps=True,
    )

    return out
def extend_for_sliding_window(
    self,
    attn_metadata: "AttentionMetaData",
    query: torch.Tensor,
    key_cache,
    value_cache,
    output: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    max_seqlen_q: int,
    block_table: torch.Tensor,
    k_scale: Optional[torch.Tensor],
    v_scale: Optional[torch.Tensor],
):
    """Extend-phase attention for sliding-window layers.

    K/V are first gathered from the paged cache into the sliding-window
    workspace carried by the extend metadata, then varlen flash attention
    runs over the gathered tensors and writes into ``output``. Nothing is
    returned.
    """
    extend_md = attn_metadata.plugin_metadata.extend_metadata
    assert extend_md is not None
    chunk_md = extend_md.chunk_context_metadata
    assert chunk_md is not None
    swa = chunk_md.swa_metadata
    assert swa is not None

    # workspace[0] receives the keys, workspace[1] the values
    key_fetched = swa.swa_workspace[0]
    value_fetched = swa.swa_workspace[1]

    cp_mha_gather_cache(
        key_cache=key_cache,
        value_cache=value_cache,
        key=key_fetched,
        value=value_fetched,
        block_tables=block_table,
        k_scales=k_scale,
        v_scales=v_scale,
        cu_seqlens_kv=swa.swa_cu_seqlens,
        token_to_batch=swa.swa_token_to_batch,
        seq_starts=swa.swa_seq_starts,
        dequant=self.kv_cache_dtype.startswith("fp8"),
        kv_cache_layout="NHD",
        total_tokens=swa.swa_total_tokens,
    )

    if self.sliding_window is not None:
        window = (self.sliding_window, 0, 0)
    else:
        window = (-1, -1, 0)

    aiter.flash_attn_varlen_func(
        q=query,
        k=key_fetched,
        v=value_fetched,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=swa.swa_cu_seqlens,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=swa.swa_max_seqlens,
        min_seqlen_q=1,
        dropout_p=0.0,
        softmax_scale=self.scale,
        causal=True,
        window_size=window,
        alibi_slopes=self.alibi_slopes,
        return_lse=False,
        out=output,
    )
= None + chunked_lse = None + for chunk_idx in range(num_chunks): + cp_mha_gather_cache( + key_cache=key_cache, + value_cache=value_cache, + key=key_fetched, + value=value_fetched, + block_tables=block_table, + k_scales=k_scale, + v_scales=v_scale, + cu_seqlens_kv=cu_seqlens_kv[chunk_idx], + token_to_batch=token_to_batch[chunk_idx], + seq_starts=chunk_starts[chunk_idx], + dequant=self.kv_cache_dtype.startswith("fp8"), + kv_cache_layout="SHUFFLE", + total_tokens=total_token_per_batch[chunk_idx], + ) + + suf_out, suf_lse = aiter.flash_attn_varlen_func( + q=query, + k=key_fetched, + v=value_fetched, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_kv[chunk_idx], + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlens[chunk_idx], + min_seqlen_q=min_seqlen_q, + dropout_p=0.0, + softmax_scale=self.scale, + causal=False, + window_size=(-1, -1, 0), + alibi_slopes=self.alibi_slopes, + return_lse=True, + ) + + if chunked_output is None: + chunked_output = suf_out + chunked_lse = suf_lse + else: + tmp_output = torch.empty_like(out) + tmp_lse = torch.empty_like(lse) + merge_attn_states( + output=tmp_output, + output_lse=tmp_lse, + prefix_output=chunked_output, + prefix_lse=chunked_lse, + suffix_output=suf_out, + suffix_lse=suf_lse, + ) + chunked_output = tmp_output + chunked_lse = tmp_lse + + merge_attn_states( + output=output, + prefix_output=chunked_output, + prefix_lse=chunked_lse, + suffix_output=out, + suffix_lse=lse, + ) + + def forward_impl_plugin_mode( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: "AttentionMetaData" = None, + position: torch.Tensor = None, + q_scale: torch.Tensor = None, + qkv: torch.Tensor = None, + output: torch.Tensor = None, + ): + # create the output here, it use query shape + num_tokens = query.shape[0] + output_dtype = query.dtype + output_shape = torch.Size((num_tokens, self.num_heads * self.head_size)) + output = torch.empty(output_shape, 
dtype=output_dtype, device=query.device) + + # dummy run will skip attention in cuda graph capture phase + if attn_metadata is None: + return output.fill_(0) + + # when using this optimization, the qkv tensor and + # position tensor are passed through q,k,v + # when not using this optimization, the position is not + # needed as the ROPE has been calculated outside of attention + if ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION: + assert ( + position is None + ), "position should be None because it is passed through k" + + position = key + qkv = value + + q_size = self.num_heads * self.head_dim + kv_size = self.num_kv_heads * self.head_dim + query, key, value = torch.split(qkv, [q_size, kv_size, kv_size], dim=-1) + + query = query.view(-1, self.num_heads, self.head_dim) + key = key.view(-1, self.num_kv_heads, self.head_dim) + value = value.view(-1, self.num_kv_heads, self.head_dim) + output = output.view(-1, self.num_heads, self.head_dim) + + num_actual_tokens = attn_metadata.plugin_metadata.num_actual_tokens + k_cache, v_cache = kv_cache.unbind(0) + num_blocks, block_size, num_kv_heads, _ = k_cache.shape + + if self.kv_cache_dtype == "fp8": + target_dtype = dtypes.d_dtypes[self.kv_cache_dtype] + k_cache = k_cache.view(target_dtype) + v_cache = v_cache.view(target_dtype) + + # create kv scale according to the num_blocks + # usually it is created when cuda graph capture for decode phase + if self.kv_cache_dtype == "fp8": + if self.k_scale is None or self.v_scale is None: + self.kv_scale = torch.zeros( + 2, + num_blocks, + num_kv_heads, + block_size, + dtype=dtypes.fp32, + device=self.device, + ) + # update the layer kv scale tensor + self.k_scale = self.kv_scale[0] + self.v_scale = self.kv_scale[1] + layer.k_scale = self.k_scale + layer.v_scale = self.v_scale + + # rope and cache flush fusion. 
ATOM always use shuffle layout for kv cache + result = self.rope_cache_plugin_mode( + q=query, + k=key, + v=value, + qkv=qkv, + position=position, + attention_metadata=attn_metadata, + k_cache=k_cache, + v_cache=v_cache, + k_scale=self.k_scale, + v_scale=self.v_scale, + flash_layout=False, + ) + query, key, value, k_cache, v_cache, k_scale, v_scale = result + + # as vLLM cuda graph capture padding mechanism, here split the qkvo with + # the actual tokens + query = query[:num_actual_tokens] + if key is not None: + key = key[:num_actual_tokens] + if value is not None: + value = value[:num_actual_tokens] + output_actual_tokens = output[:num_actual_tokens] + + num_decodes = attn_metadata.plugin_metadata.num_decodes + num_prefills = attn_metadata.plugin_metadata.num_prefills + num_extends = attn_metadata.plugin_metadata.num_extends + + num_decode_tokens = attn_metadata.plugin_metadata.num_decode_tokens + num_extend_tokens = attn_metadata.plugin_metadata.num_extend_tokens + + # calculate for prefills + if num_prefills > 0: + assert attn_metadata.plugin_metadata.prefill_metadata is not None + + # prefill part is after decode and extend + prefill_query = query[num_decode_tokens + num_extend_tokens :] + prefill_key = key[num_decode_tokens + num_extend_tokens :] + prefill_value = value[num_decode_tokens + num_extend_tokens :] + + sliding_window = ( + (self.sliding_window, 0, 0) + if self.sliding_window is not None + else (-1, -1, 0) + ) + + aiter.flash_attn_varlen_func( + q=prefill_query, + k=prefill_key, + v=prefill_value, + cu_seqlens_q=attn_metadata.plugin_metadata.prefill_metadata.query_start_loc, + cu_seqlens_k=attn_metadata.plugin_metadata.prefill_metadata.query_start_loc, + max_seqlen_q=attn_metadata.plugin_metadata.prefill_metadata.max_query_len, + max_seqlen_k=attn_metadata.plugin_metadata.prefill_metadata.max_seq_len, + min_seqlen_q=1, + dropout_p=attn_metadata.dropout_p, + softmax_scale=self.scale, + causal=True, + window_size=sliding_window, + alibi_slopes=None, 
+ sink_ptr=self.sinks, + out=output_actual_tokens[num_decode_tokens + num_extend_tokens :], + ) + + # calculate for extends + if num_extends > 0: + assert attn_metadata.plugin_metadata.extend_metadata is not None + extend_tokens_slice = slice( + num_decode_tokens, num_decode_tokens + num_extend_tokens + ) + extend_querys = query[extend_tokens_slice] + extend_keys = key[extend_tokens_slice] + extend_values = value[extend_tokens_slice] + extend_outputs = output[extend_tokens_slice] + self.extend_forward( + attn_metadata=attn_metadata, + query=extend_querys, + key=extend_keys, + value=extend_values, + key_cache=k_cache, + value_cache=v_cache, + output=extend_outputs, + cu_seqlens_q=attn_metadata.plugin_metadata.extend_metadata.query_start_loc, + max_seqlen_q=attn_metadata.plugin_metadata.extend_metadata.max_query_len, + max_seqlen_k=attn_metadata.plugin_metadata.extend_metadata.max_seq_len, + min_seqlen_q=1, + block_table=attn_metadata.plugin_metadata.block_table[ + num_decodes : num_decodes + num_extends + ], + slot_mapping=attn_metadata.plugin_metadata.slot_mapping[ + num_decodes : num_decodes + num_extends + ], + k_scale=k_scale, + v_scale=v_scale, + ) + + # calculate for decodes + if num_decodes > 0: + assert attn_metadata.plugin_metadata.decode_metadata is not None + + num_blocks, block_size, num_kv_heads, head_size = k_cache.shape + x = 16 // k_cache.element_size() + k_cache_template = torch.empty( + [num_blocks, num_kv_heads, head_size // x, block_size, x], + dtype=k_cache.dtype, + device="meta", + ) + v_cache_template = torch.empty( + [num_blocks, num_kv_heads, block_size // x, head_size, x], + dtype=v_cache.dtype, + device="meta", + ) + new_key_cache = k_cache.view_as(k_cache_template) + new_value_cache = v_cache.view_as(v_cache_template) + + if self.use_triton_attn: + self.paged_attention_triton_plugin_mode( + q=query[:num_decode_tokens], + k_cache=new_key_cache, + v_cache=new_value_cache, + k_scale=k_scale, + v_scale=v_scale, + 
def PagedAttentionImplDecoratorForPluginMode(cls):
    """Class decorator that attaches the plugin-mode methods to ``cls``.

    Each listed method is read from ``PagedAttentionImplPluginModeMethods``
    and set on ``cls`` under the same name, replacing any existing attribute.
    The same class object is returned.
    """
    logger.info(
        "Use PagedAttentionImplDecoratorForPluginMode to decorate PagedAttentionImpl"
    )

    plugin_method_names = (
        "process_weights_after_loading",
        "rope_cache_plugin_mode",
        "paged_attention_triton_plugin_mode",
        "paged_attention_asm_plugin_mode",
        "extend_for_sliding_window",
        "extend_forward",
        "forward_impl_plugin_mode",
    )
    # Copy every plugin-mode method onto the target class.
    for name in plugin_method_names:
        setattr(cls, name, getattr(PagedAttentionImplPluginModeMethods, name))

    return cls
def _generate_atom_config_from_vllm_config(config: Any) -> Any:
    """Build the ATOM ``Config`` for plugin mode from a vLLM config object.

    Args:
        config: vLLM config object. The attributes read here are
            ``model_config``, ``scheduler_config``, ``cache_config``,
            ``parallel_config``, ``compilation_config`` and ``quant_config``.

    Returns:
        An ``atom.config.Config`` whose ``plugin_config`` field is a
        ``PluginConfig`` holding the vLLM-side config objects. The annotation
        is ``Any`` because ``Config`` is only imported inside the function
        body (the previous ``PluginConfig`` annotation did not match what is
        returned).
    """
    # imported lazily: atom.config itself imports atom.plugin.config
    from atom.config import Config, CompilationConfig

    vllm_model_config = config.model_config
    vllm_scheduler_config = config.scheduler_config
    vllm_cache_config = config.cache_config
    vllm_parallel_config = config.parallel_config
    # ATOM attention stays enabled unless the variable is exactly "1". This is
    # the same test atom/plugin/vllm/platform.py applies when it selects the
    # attention backend, so the two places cannot disagree on values such as
    # "true" or "2".
    vllm_use_atom_attention = (
        os.getenv("ATOM_DISABLE_VLLM_PLUGIN_ATTENTION", "0").lower() != "1"
    )

    # here use the ATOM compilation config, as the ATOM compile policy is used
    # instead of vLLM one for torch compile, while for cuda graph capture,
    # still use the vLLM because it has FULL_AND_PIECEWISE feature
    # when you don't want to use atom torch compile, you can also use
    # --enforce-eager to disable the atom torch compile when launch vllm server
    compilation_config = config.compilation_config
    vllm_compilation_config = CompilationConfig(
        # use mode because vllm level argument is deprecated
        level=compilation_config.mode,
        use_cudagraph=False,
        cudagraph_mode=None,
    )

    vllm_quant_config = config.quant_config

    plugin_config = PluginConfig(
        # common config
        model_config=vllm_model_config,
        rank=vllm_parallel_config.rank,
        is_plugin_mode=True,
        is_vllm=True,
        is_sglang=False,
        # vllm specific
        vllm_scheduler_config=vllm_scheduler_config,
        vllm_cache_config=vllm_cache_config,
        vllm_quant_config=vllm_quant_config,
        vllm_use_atom_attention=vllm_use_atom_attention,
    )

    # the scheduler's max_model_len, when it has one, overrides the model's
    max_model_len = vllm_model_config.max_model_len
    if hasattr(vllm_scheduler_config, "max_model_len"):
        max_model_len = vllm_scheduler_config.max_model_len

    max_num_batched_tokens = vllm_scheduler_config.max_num_batched_tokens
    # FIXME: known issue for illegal mem access in fused_moe kernel
    if max_num_batched_tokens >= _KNOWN_ISSUE_MAX_NUM_BATCHED_TOKENS_THRESHOLD:
        logger.warning(
            "For plugin mode, when setting max_num_batched_tokens >= "
            + f"{_KNOWN_ISSUE_MAX_NUM_BATCHED_TOKENS_THRESHOLD}, there is a known issue "
            + "for illegal mem access in asm fused_moe kernel, if you met the issue, "
            + "please set max_num_batched_tokens smaller or choose the ck fused_moe "
            + "kernel instead of asm ones"
        )

    return Config(
        model=vllm_model_config.model,
        max_num_batched_tokens=max_num_batched_tokens,
        max_num_seqs=vllm_scheduler_config.max_num_seqs,
        max_model_len=max_model_len,
        gpu_memory_utilization=vllm_cache_config.gpu_memory_utilization,
        tensor_parallel_size=vllm_parallel_config.tensor_parallel_size,
        enforce_eager=True,  # disable using atom cuda graph
        parallel_config=vllm_parallel_config,
        kv_cache_block_size=vllm_cache_config.block_size,
        num_kvcache_blocks=vllm_cache_config.num_gpu_blocks,
        kv_cache_dtype=vllm_cache_config.cache_dtype,
        enable_prefix_caching=vllm_cache_config.enable_prefix_caching,
        port=None,
        torch_profiler_dir=None,
        compilation_config=vllm_compilation_config,
        asyncio_mode=False,
        load_dummy=False,
        enable_expert_parallel=vllm_parallel_config.enable_expert_parallel,
        master_addr=None,
        enable_dp_attention=False,
        plugin_config=plugin_config,
    )
def _generate_atom_config_from_sglang_config(config: Any):
    """Build the ATOM ``Config`` for plugin mode when running under SGLang.

    Args:
        config: not read by this function; every setting is rebuilt from the
            SGLang server arguments parsed out of ``sys.argv[1:]``.

    Returns:
        An ``atom.config.Config`` whose ``plugin_config`` field carries the
        SGLang-side objects built here.
    """
    from sglang.srt.server_args import (
        ServerArgs,
        prepare_server_args,
        PortArgs,
    )
    from sglang.srt.configs.model_config import ModelConfig as SglangModelConfig
    from sglang.srt.configs.modelopt_config import ModelOptConfig
    from sglang.srt.configs.load_config import LoadConfig
    from atom.config import Config, ParallelConfig, CompilationConfig

    # sglang has no global config variable like vllm, so the server args are
    # re-parsed from the command line the user launched the process with
    args: ServerArgs = prepare_server_args(sys.argv[1:])

    model_config = SglangModelConfig.from_server_args(args)
    modelopt_config = ModelOptConfig(
        quant=args.modelopt_quant,
        checkpoint_restore_path=args.modelopt_checkpoint_restore_path,
        checkpoint_save_path=args.modelopt_checkpoint_save_path,
        export_path=args.modelopt_export_path,
    )

    # LoadConfig fields that are copied from the server args under the same name
    mirrored_fields = (
        "load_format",
        "download_dir",
        "model_loader_extra_config",
        "remote_instance_weight_loader_seed_instance_ip",
        "remote_instance_weight_loader_seed_instance_service_port",
        "remote_instance_weight_loader_send_weights_group_ports",
        "remote_instance_weight_loader_backend",
        "rl_quant_profile",
    )
    load_config = LoadConfig(
        modelopt_config=modelopt_config,
        **{field_name: getattr(args, field_name) for field_name in mirrored_fields},
    )

    # sglang does not pass the rank in its config, so read it from
    # torch.distributed
    rank = torch.distributed.get_rank()

    # sglang uses the atom parallel config
    parallel_config = ParallelConfig(
        data_parallel_size=args.dp_size,
    )

    # use sglang torch compile policy and cuda graph policy
    # because sglang doesn't use the compile decorator for model,
    # we have no method to define self policy
    compilation_config = CompilationConfig(
        level=0,
        use_cudagraph=False,
        cudagraph_mode=None,
    )

    plugin_config = PluginConfig(
        # common config
        model_config=model_config,
        rank=rank,
        is_plugin_mode=True,
        is_vllm=False,
        is_sglang=True,
        # sglang specific
        sglang_model_opt_config=modelopt_config,
        sglang_load_config=load_config,
        sglang_enable_torch_compile=args.enable_torch_compile,
        sglang_disable_cuda_graph=args.disable_cuda_graph,
        sglang_enable_dp_attention=args.enable_dp_attention,
        sglang_dist_init_addr=args.dist_init_addr,
        sglang_port_args=PortArgs.init_new(args),
    )

    # force max num batched tokens to 16K because sgl doesn't have
    # concept for max num batched tokens
    return Config(
        model=None,
        max_num_batched_tokens=16384,
        max_num_seqs=args.max_running_requests,
        max_model_len=args.context_length,
        gpu_memory_utilization=args.mem_fraction_static,
        tensor_parallel_size=args.tp_size,
        enforce_eager=True,  # disable using atom cuda graph
        parallel_config=parallel_config,
        kv_cache_dtype=args.kv_cache_dtype,
        enable_prefix_caching=False,
        port=None,
        torch_profiler_dir=None,
        compilation_config=compilation_config,
        asyncio_mode=False,
        load_dummy=False,
        enable_expert_parallel=bool(args.ep_size > 1),
        master_addr=None,
        enable_dp_attention=args.enable_dp_attention,
        plugin_config=plugin_config,
    )
def _apply_moe_decoration(original_cls: Type[T]) -> Type[T]:
    """Rename the MoE class in place when running under vLLM, then return it.

    Called lazily, on the first instantiation of the wrapped class.
    """
    if is_vllm():
        # Rename this class because vLLM will call the modular
        # kernel init method for all modules of the model, whose name is FusedMoE,
        # to init the inside kernel, while for plugin mode, the atom maintains
        # the kernel lifecycle by itself, so there is no need to call init on
        # vllm side
        for attr in ("__name__", "__qualname__"):
            setattr(original_cls, attr, "ATOMFusedMoE")

    return original_cls


def FusedMoEDecoratorForPluginMode(cls: Type[T]) -> Type[T]:
    """Return a subclass of ``cls`` that defers the renaming to first use.

    The returned wrapper class keeps the name, qualname and module that ``cls``
    has at decoration time. Instantiating the wrapper runs
    ``_apply_moe_decoration`` on ``cls`` once and returns an instance built
    from the resulting class.
    """
    resolved_cls = None

    def resolve():
        # run the decoration once and remember its result
        nonlocal resolved_cls
        if resolved_cls is None:
            resolved_cls = _apply_moe_decoration(cls)
        return resolved_cls

    class LazyMoEWrapper(cls):
        def __new__(wrapper_cls, *args, **kwargs):
            return resolve()(*args, **kwargs)

    # Preserve the original class name and module for the wrapper
    for attr in ("__name__", "__qualname__", "__module__"):
        setattr(LazyMoEWrapper, attr, getattr(cls, attr))

    logger.info("Create lazy wrapper for FusedMoE to change the naming")
    return LazyMoEWrapper
# every framework name ATOM accepts, covering server mode and plugin mode
_SUPPORTED_FRAMEWORKS = ["vllm", "sglang", "sgl", "atom"]

# the subset of names that mean ATOM is plugged into another framework
_SUPPORTED_FRAMEWORKS_FOR_PLUGIN_MODE = ["vllm", "sglang", "sgl"]

# "atom" (server mode) until _set_framework_backbone() says otherwise
_CURRENT_FRAMEWORK = "atom"


def is_sglang() -> bool:
    """Return True when the current framework is SGLang ("sglang" or "sgl")."""
    return _CURRENT_FRAMEWORK.lower() in ("sglang", "sgl")


def is_vllm() -> bool:
    """Return True when the current framework is vLLM."""
    return _CURRENT_FRAMEWORK.lower() == "vllm"


def is_plugin_mode() -> bool:
    """Return True when the current framework is one ATOM plugs into."""
    return _CURRENT_FRAMEWORK.lower() in _SUPPORTED_FRAMEWORKS_FOR_PLUGIN_MODE


def _set_framework_backbone(framework: str) -> None:
    """Record ``framework`` as the current framework.

    The name is validated case-insensitively but stored as given.

    Raises:
        ValueError: if ``framework`` is not a supported framework name.
    """
    global _CURRENT_FRAMEWORK
    if framework.lower() in _SUPPORTED_FRAMEWORKS:
        _CURRENT_FRAMEWORK = framework
        return
    raise ValueError(f"Unsupported framework {framework} for ATOM to plug in")
def prepare_model(config: Any, engine: str):
    """Construct the ATOM model for an upper framework (SGLang only for now).

    Args:
        config: model config from the upper framework; its ``architectures[0]``
            entry selects the ATOM model class.
        engine: name of the upper framework, recorded via
            ``_set_framework_backbone``.

    Returns:
        An instance of the ATOM model class, built with the generated ATOM
        config.

    Raises:
        ValueError: if ``engine`` is not a supported framework name, is not
            SGLang, or if the model architecture has no ATOM implementation.
    """
    logger.info(f"Prepare model for plugin mode, the upper engine is {engine}")

    _set_framework_backbone(engine)

    # only SGLang reaches the model through this entry point
    if not is_sglang():
        raise ValueError(
            f"prepare_model does not support engine {engine!r} "
            f"with config type {type(config)}"
        )
    model_arch = config.architectures[0]

    # import here to avoid partial initialization
    from .register import (
        _ATOM_SUPPORTED_MODELS,
        register_ops_to_sglang,
        init_aiter_dist,
        set_attn_cls,
    )

    if model_arch not in _ATOM_SUPPORTED_MODELS:
        supported_archs = list(_ATOM_SUPPORTED_MODELS.keys())
        raise ValueError(
            f"ATOM does not support the required model architecture: {model_arch}. "
            f"For now supported model architectures: {supported_archs}"
        )

    from atom.plugin.config import generate_atom_config_for_plugin_mode

    atom_config = generate_atom_config_for_plugin_mode(config)

    model_cls = _ATOM_SUPPORTED_MODELS[model_arch]
    logger.info(f"ATOM model class for {model_arch} is {model_cls}")

    # the framework is known to be SGLang here (checked above)
    register_ops_to_sglang(atom_config=atom_config)
    set_attn_cls()

    # init aiter dist for using aiter custom collective ops
    init_aiter_dist(config=atom_config)

    return model_cls(atom_config=atom_config)
def set_attn_cls() -> None:
    """Point ``atom.model_ops.Attention`` at the framework's attention class.

    Under vLLM it becomes ``PagedAttention``, under SGLang ``RadixAttention``.
    In any other mode the attribute is left untouched.
    """
    import atom.model_ops as ops

    if is_vllm():
        ops.Attention = ops.PagedAttention
        logger.info("Set Attention to PagedAttention for vLLM")
    elif is_sglang():
        ops.Attention = ops.RadixAttention
        logger.info("Set Attention to RadixAttention for SGLang")


def init_aiter_dist(config: Config) -> None:
    """Initialize the aiter distributed environment for plugin mode.

    The master endpoint comes from the vLLM parallel config under vLLM. Under
    SGLang it comes from ``sglang_dist_init_addr`` (``host:port``) or, when
    that is unset, from ``127.0.0.1`` plus the NCCL port of the SGLang port
    args.

    Args:
        config: ATOM config whose ``plugin_config`` was filled in for plugin
            mode.

    Raises:
        AssertionError: if ``plugin_config.is_plugin_mode`` is false.
        ValueError: if the plugin config names neither vLLM nor SGLang, or if
            ``sglang_dist_init_addr`` is not of the form ``host:port``.
    """
    from aiter import init_dist_env
    from aiter.dist.utils import get_distributed_init_method

    plugin_config = config.plugin_config
    assert plugin_config.is_plugin_mode, "Make sure ATOM is running in plugin mode"

    rank = plugin_config.rank
    tensor_parallel_size = config.tensor_parallel_size

    if plugin_config.is_vllm:
        dp_master_ip = config.parallel_config.data_parallel_master_ip
        dp_master_port = config.parallel_config.data_parallel_master_port
    elif plugin_config.is_sglang:
        dist_init_addr = plugin_config.sglang_dist_init_addr
        if dist_init_addr is not None:
            # split on the LAST colon so a host that itself contains colons
            # still yields exactly (host, port); the port becomes an int
            # rather than being left as a string
            dp_master_ip, sep, port_text = dist_init_addr.rpartition(":")
            if not sep or not dp_master_ip or not port_text.isdigit():
                raise ValueError(
                    f"sglang dist_init_addr must look like host:port, "
                    f"got {dist_init_addr!r}"
                )
            dp_master_port = int(port_text)
        else:
            dp_master_ip = "127.0.0.1"
            dp_master_port = plugin_config.sglang_port_args.nccl_port
    else:
        # previously this fell through and failed later on an unbound local
        raise ValueError(
            "init_aiter_dist needs a plugin config for either vLLM or SGLang"
        )

    distributed_init_method = get_distributed_init_method(dp_master_ip, dp_master_port)

    logger.info(
        f"Initialize aiter dist for using aiter custom collective op for plugin mode, rank:{rank}"
    )
    init_dist_env(
        tensor_model_parallel_size=tensor_parallel_size,
        rankID=rank,
        backend="nccl",
        distributed_init_method=distributed_init_method,
        data_parallel_size=config.parallel_config.data_parallel_size,
        data_parallel_rank=config.parallel_config.data_parallel_rank,
    )
# architecture name -> "module.path:ClassName" of the ATOM implementation
_ATOM_MODEL_CLASSES: dict[str, str] = {
    "Qwen3ForCausalLM": "atom.models.qwen3:Qwen3ForCausalLM",
    "Qwen3MoeForCausalLM": "atom.models.qwen3_moe:Qwen3MoeForCausalLM",
}


def _get_atom_model_cls(model_arch: str) -> type:
    """Import and return the ATOM model class registered for ``model_arch``.

    Raises:
        ValueError: if ``model_arch`` is ``None`` or has no entry in
            ``_ATOM_MODEL_CLASSES``.
    """
    if model_arch is None or model_arch not in _ATOM_MODEL_CLASSES:
        raise ValueError(f"The {model_arch} is not supported by ATOM OOT backend")

    module_path, class_name = _ATOM_MODEL_CLASSES[model_arch].split(":", 1)
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


def _prepare_env(atom_config) -> None:
    """Select the attention class and initialize aiter dist, in that order."""
    from atom.plugin.register import set_attn_cls, init_aiter_dist

    logger.info("Set global attention class")
    set_attn_cls()

    # aiter dist is needed for the aiter custom collective ops
    logger.info("Init aiter dist for using aiter custom collective ops")
    init_aiter_dist(config=atom_config)
class ATOMModelBase(nn.Module, VllmModel, SupportsQuant, SupportsPP):
    """vLLM-facing wrapper around an ATOM model implementation.

    Construction generates the ATOM config from the vLLM config, prepares the
    attention class and aiter dist, then builds the ATOM model for the first
    architecture listed in the vLLM model config. ``forward``, ``load_weights``
    and ``compute_logits`` delegate to that inner model.
    """

    def __init_subclass__(cls, *args, **kwargs):
        super().__init_subclass__(*args, **kwargs)

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        self.config = vllm_config.model_config.hf_config
        self.text_config = self.config.get_text_config()
        self.cache_config = vllm_config.cache_config
        self.device_config = vllm_config.device_config
        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.quant_config = vllm_config.quant_config

        # Weights to skip in `self.load_weights`
        self.skip_prefixes: list[str] = []
        self.skip_substrs: list[str] = []
        self.ignore_unexpected_prefixes: list[str] = []
        self.ignore_unexpected_suffixes: list[str] = []

        self.atom_config = generate_atom_config_for_plugin_mode(vllm_config)

        # sets the attention class and initializes aiter dist
        _prepare_env(atom_config=self.atom_config)

        # an unsupported architecture raises ValueError in _get_atom_model_cls,
        # so no further "model is None" check is needed after construction
        model_arch = vllm_config.model_config.architectures[0]
        model_cls = _get_atom_model_cls(model_arch)

        logger.info(f"Construct ATOM model {model_arch} for vLLM plugin mode")
        self.model = model_cls(self.atom_config)

        # groups come from aiter's parallel state, initialized in _prepare_env
        self.pp_group = get_pp_group()
        self.tp_group = get_tp_group()

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **model_kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the inner model for this pipeline-parallel rank.

        Ranks other than the first take their input from
        ``intermediate_tensors["hidden_states"]`` instead of ``input_ids``.
        Ranks other than the last return the hidden states wrapped in
        ``IntermediateTensors``; the last rank returns them directly.
        """
        if not self.pp_group.is_first_rank:
            assert intermediate_tensors is not None
            input_ids = None
            inputs_embeds = intermediate_tensors["hidden_states"]

        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            **model_kwargs,
        )

        if not self.pp_group.is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        return hidden_states

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        """Delegate weight loading to the inner ATOM model."""
        return self.model.load_weights(weights)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Delegate logit computation to the inner ATOM model."""
        logits = self.model.compute_logits(hidden_states)
        return logits


class ATOMForCausalLM(ATOMModelBase, VllmModelForTextGeneration): ...


class ATOMMoEForCausalLM(ATOMModelBase, VllmModelForTextGeneration): ...
# Both switches are opt-out: only the exact value "1" disables the feature.
disable_vllm_plugin = os.getenv("ATOM_DISABLE_VLLM_PLUGIN", "0").lower() == "1"
disable_vllm_plugin_attention = (
    os.getenv("ATOM_DISABLE_VLLM_PLUGIN_ATTENTION", "0").lower() == "1"
)

if disable_vllm_plugin:
    # plugin disabled: no platform class is defined and vLLM is not imported
    ATOMPlatform = None
else:
    from vllm.platforms.rocm import RocmPlatform

    class ATOMPlatform(RocmPlatform):
        """ROCm platform that routes attention to the ATOM backend."""

        # For multi-modality models, to make AiterBackend supported by ViT,
        # get_supported_vit_attn_backends may need to be overridden here
        @classmethod
        def get_attn_backend_cls(cls, selected_backend, attn_selector_config) -> str:
            """Return the attention backend class path vLLM should load."""
            if not disable_vllm_plugin_attention:
                logger.info("Use atom attention backend")
                return "atom.model_ops.attentions.aiter_attention.AiterBackend"

            logger.info("Fallback to original vLLM attention backend")
            return super().get_attn_backend_cls(
                selected_backend, attn_selector_config
            )
def register_platform() -> Optional[str]:
    """Entry point for vLLM platform plugins.

    Returns:
        The dotted path of the ATOM platform class, or ``None`` when the
        plugin is disabled so that plain vLLM runs without ATOM.
    """
    if disable_vllm_plugin:
        # return None instead of error because the flag can be used to
        # run pure vllm mode without ATOM plugin
        logger.info("Disable ATOM OOT plugin platforms")
        return None

    _set_plugin_mode()

    # return the ATOM platform to vllm
    return "atom.plugin.vllm.platform.ATOMPlatform"


def _patch_vllm_attention_process_weights_after_loading() -> None:
    """Give vLLM's ``Attention.process_weights_after_loading`` a default dtype.

    The method is wrapped so that ``act_dtype`` defaults to ``torch.bfloat16``.
    Nothing is done when the method was already wrapped by this function or
    when its ``act_dtype`` parameter already has a default.
    """
    import functools
    import inspect

    try:
        from vllm.attention.layer import Attention
    except ImportError:
        from vllm.model_executor.layers.attention import Attention

    orig = Attention.process_weights_after_loading

    if getattr(orig, "_atom_default_act_dtype_patched", False):
        return

    try:
        act_dtype_param = inspect.signature(orig).parameters.get("act_dtype")
    except (TypeError, ValueError):
        # inspect.signature raises these when no signature is available;
        # install the wrapper in that case, as before
        act_dtype_param = None

    if (
        act_dtype_param is not None
        and act_dtype_param.default is not inspect.Parameter.empty
    ):
        return

    @functools.wraps(orig)
    def wrapped(self, act_dtype: "torch.dtype" = torch.bfloat16):
        return orig(self, act_dtype)

    # marker checked above so that a second call does not wrap again
    wrapped._atom_default_act_dtype_patched = True
    Attention.process_weights_after_loading = wrapped
# used to judge whether the node should be split or not
def _split_judge_func(node: fx.Node) -> bool:
    """Return True when the graph should be split at ``node``.

    A node is a split point when it is a ``call_function`` whose target has a
    truthy ``spliting_op`` attribute, or, under vLLM plugin mode, when its
    name contains ``unified_attention``.
    """
    # ATOM use mark_spliting_op to mark the attn as splitting op
    marked_by_atom = node.op == "call_function" and getattr(
        node.target, "spliting_op", False
    )
    if marked_by_atom:
        return True

    # When plugin mode(vLLM), the attention impl op is registered
    # as unified_attention
    from atom.plugin import is_vllm

    return is_vllm() and "unified_attention" in node.name
contextmanager from dataclasses import dataclass, field, fields -from typing import Any, Dict, Optional, Set, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Union import numpy as np import torch from atom.config import Config, KVCacheTensor, ParallelConfig +if TYPE_CHECKING: + from atom.plugin.attention import MetadataForPluginMode + def _compute_chunked_local_num_tokens( num_tokens_across_dp_cpu: list[int], max_num_tokens: int, chunk_idx: int @@ -186,6 +189,9 @@ class AttentionMetaData: block_tables_converted: Optional[torch.Tensor] = None + # only used for plugin mode to store the metadata for attn + plugin_metadata: Optional["MetadataForPluginMode"] = None + def __init__( self, cu_seqlens_q: Optional[torch.Tensor] = None, @@ -212,6 +218,7 @@ def __init__( block_tables_converted: Optional[torch.Tensor] = None, sparse_cu_seqlens_q: Optional[torch.Tensor] = None, token_to_seq_idxs: Optional[torch.Tensor] = None, + plugin_metadata: Optional["MetadataForPluginMode"] = None, ): self.cu_seqlens_q = cu_seqlens_q self.cu_seqlens_k = cu_seqlens_k @@ -238,6 +245,8 @@ def __init__( self.block_tables = block_tables_converted self.sparse_cu_seqlens_q = sparse_cu_seqlens_q self.token_to_seq_idxs = token_to_seq_idxs + if plugin_metadata is not None: + self.plugin_metadata = plugin_metadata def asdict_zerocopy(self, skip_fields: Optional[Set[str]] = None) -> Dict[str, Any]: """Similar to dataclasses.asdict, but avoids deepcopying.""" diff --git a/pyproject.toml b/pyproject.toml index 49d3d8c0b..7e6c623bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,6 @@ [build-system] -requires = ["setuptools_scm[toml]>=6.2"] +requires = ["setuptools>=61", "wheel", "setuptools_scm[toml]>=6.2"] +build-backend = "setuptools.build_meta" [project] name = "atom" @@ -24,3 +25,15 @@ fallback_version = "0.1.0" [tool.setuptools.packages.find] where = ["."] include = ["atom*"] + +[project.entry-points."vllm.platform_plugins"] +# by default the entry-points will be 
installed anyway +# but the plugin mode for platforms can be disabled by +# ATOM_DISABLE_VLLM_PLUGIN=1 +atom = "atom.plugin.vllm.register:register_platform" + +[project.entry-points."vllm.general_plugins"] +# by default the entry-points will be installed anyway +# but the plugin mode for models can be disabled by +# ATOM_DISABLE_VLLM_PLUGIN=1 +atom_model_registry = "atom.plugin.vllm.register:register_model" diff --git a/recipes/SGLang-ATOM-Model-Impl-Backend.md b/recipes/SGLang-ATOM-Model-Impl-Backend.md new file mode 100644 index 000000000..fe409a269 --- /dev/null +++ b/recipes/SGLang-ATOM-Model-Impl-Backend.md @@ -0,0 +1,72 @@ +# Model Impl Backend of SGLang +ATOM can work as the model implementation backend of popular frameworks such as SGLang. Users can launch the server as before and specify one extra argument to enable the ATOM model backend, which provides SGLang with an optimized implementation of the target model to execute. When ATOM works in this mode, framework-level features from SGLang and the latest model-level fusion kernels from ATOM/AITER are combined to achieve competitive performance.
+ +- Here is a detailed design slide for this feature: https://amdcloud-my.sharepoint.com/:p:/g/personal/zejchen_amd_com/IQCFdvmEeLTWT7ysApmZv_hVAfw2nTo8iesJZGblHS0evqQ?e=hjnIDM +- Here is the RFC to introduce the ATOM as model impl backend into SGLang: TODO + +## Preparing environment for SGLang with ATOM model backend +Here is the PR to introduce ATOM into SGLang: https://github.com/sgl-project/sglang/pull/16944. Once this PR is merged, the official SGLang can be used; for now you need to use the development SGLang branch shown below +Pull the latest docker from SGLang official nightly docker for ROCm from https://hub.docker.com/r/rocm/sgl-dev/tags +```bash +docker pull rocm/sgl-dev:v0.5.8-rocm720-mi35x-20260130-preview +``` +Launch the container as usual; all the following operations are executed inside the container +A specific SGLang build must be used because the PR that introduces ATOM into SGLang has not been merged yet, so you need to: +```bash +git clone https://github.com/zejunchen-zejun/sglang.git && cd sglang +git checkout remotes/origin/zejun/model_impl +pip uninstall sglang -y +pip uninstall sgl-kernel -y +cd sgl-kernel +python3 setup_rocm.py install +export PYTHONPATH= +``` +Then ATOM should be installed +```bash +git clone https://github.com/zejunchen-zejun/ATOM.git +cd ATOM +git checkout origin/zejun/plugin_for_atom_1223 +pip install -e . 2>&1 | tee build.log +``` +For AITER, there are no specific requirements; however, if you find that any of the latest fusion kernels are missing, you may need to upgrade AITER + +### Launching server of SGLang with ATOM model backend +You only need one change: add --model-impl atom to your SGLang server command.
Here is an example: +```bash +export ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1 + +# quick allreduce +export AITER_QUICK_REDUCE_QUANTIZATION=INT4 +model_path=/data/models/Qwen3-235B-A22B-Instruct-2507-FP8 + +python3 -m sglang.launch_server \ + --model-path $model_path \ + --host localhost \ + --port 8000 \ + --trust-remote-code \ + --tensor-parallel-size 8 \ + --expert-parallel-size 8 \ + --kv-cache-dtype fp8_e4m3 \ + --mem-fraction-static 0.8 \ + --model-impl atom \ + 2>&1 | tee log.serve.log & +``` + +### Launching client for validating the accuracy +```bash +addr=localhost +port=8000 +url=http://${addr}:${port}/v1/completions +model= +task=gsm8k +lm_eval --model local-completions \ + --model_args model=${model},base_url=${url},num_concurrent=65,max_retries=1,tokenized_requests=False \ + --tasks ${task} \ + --num_fewshot 3 \ + 2>&1 | tee log.lmeval.log +``` + +### Known Limitations +There are some known limitations for now: +- Only Qwen-Dense and Qwen-MoE family models are supported +- Only TP and EP are supported diff --git a/recipes/vLLM-ATOM-OOT-Plugin-Backend.md b/recipes/vLLM-ATOM-OOT-Plugin-Backend.md new file mode 100644 index 000000000..f3c2ba085 --- /dev/null +++ b/recipes/vLLM-ATOM-OOT-Plugin-Backend.md @@ -0,0 +1,89 @@ +# vLLM out-of-tree ATOM Plugin Backend +ATOM can work as the OOT plugin backend of vLLM. The OOT register mechanism is quite mature and most of accelerators have leveraged this design to register their devices into vLLM without any code changes in upper framework. ATOM follows this design and provide the layer/op and model implementations to vLLM. The frontend users can launch vLLM server like before and there is no need to specify any arguments. Meanwhile the ATOM platform can leverage most of the vLLM features and focus more on model- and kernel-level optimizations. 
For the overall design, here is an RFC to enable ATOM to work as the OOT plugin platform of vLLM: https://github.com/ROCm/ATOM/issues/201 + +## Preparing environment for vLLM with ATOM model backend +Pull the vLLM official docker image for ROCm. If you are using a vLLM nightly docker image, incompatibility errors may occur because vLLM's code keeps changing and may break the class/module imports in ATOM +```bash +docker pull rocm/vllm-dev:nightly_main_20260118 +``` + +Then install ATOM. Once the corresponding PR is merged, you can use the ATOM main branch +```bash +git clone https://github.com/zejunchen-zejun/ATOM.git +cd ATOM +git checkout origin/zejun/plugin_for_atom_1223 +pip install -e . 2>&1 | tee build.log +``` +For AITER, you can build AITER from the latest main branch. The AITER version requirement for ATOM OOT is consistent with the AITER version requirement of the ATOM server mode. +Additionally, you may need to install some dependencies: +```bash +pip install --upgrade triton +pip install transformers==5.0.0 +pip install git+https://github.com/foundation-model-stack/fastsafetensors.git +``` + +### Launching server of vLLM with ATOM OOT Plugin Platform +No code change is needed on the vLLM side, so you can launch the vLLM server as before without any specific argument +```bash +export SAFETENSORS_FAST_GPU=1 +export VLLM_ROCM_USE_AITER=1 +export VLLM_RPC_TIMEOUT=1800000 + +export VLLM_CACHE_ROOT=/root/.cache/vllm +export TORCHINDUCTOR_CACHE_DIR=/root/.cache/inductor + +rm -rf /root/.cache/ + +model_path= + +vllm serve $model_path \ + --host localhost \ + --port 8000 \ + --tensor-parallel-size 8 \ + --enable-expert-parallel \ + --trust-remote-code \ + --disable-log-requests \ + --gpu_memory_utilization 0.9 \ + --async-scheduling \ + --load-format fastsafetensors \ + --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \ + --kv-cache-dtype fp8 \ + --max-num-batched-tokens 18432 \ + --max-model-len 16384 \ + --no-enable-prefix-caching \ + 2>&1 | tee
log.serve.log & +``` + +If you want to disable the ATOM OOT plugin platform and model registration, set the env flag below. The default value is 0 +```bash +export ATOM_DISABLE_VLLM_PLUGIN=1 +``` +If you only want to disable the ATOM Attention Backend, set the env flag below. The default value is 0 +```bash +export ATOM_DISABLE_VLLM_PLUGIN_ATTENTION=1 +``` + +### Launching client for validating the accuracy +```bash +addr=localhost +port=8000 +url=http://${addr}:${port}/v1/completions +model= +task=gsm8k +lm_eval --model local-completions \ + --model_args model=${model},base_url=${url},num_concurrent=65,max_retries=1,tokenized_requests=False \ + --tasks ${task} \ + --num_fewshot 3 \ + 2>&1 | tee log.lmeval.log +``` + +### Results for accuracy validation +|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr| +|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:| +|gsm8k| 3|flexible-extract| 3|exact_match|↑ |0.9037|± |0.0081| +| | |strict-match | 3|exact_match|↑ |0.8832|± |0.0088| + +### Known Limitations +There are some known limitations for now: +- Only Qwen-Dense and Qwen-MoE family models are supported +- Only TP and EP are supported