diff --git a/atom/__init__.py b/atom/__init__.py index c1f9ed8b2..dde5eeb84 100644 --- a/atom/__init__.py +++ b/atom/__init__.py @@ -3,3 +3,12 @@ from atom.model_engine.llm_engine import LLMEngine from atom.sampling_params import SamplingParams + +# interface for upper framework to constructe the model from ATOM +from atom.plugin import prepare_model + +__all__ = [ + "LLMEngine", + "SamplingParams", + "prepare_model", +] diff --git a/atom/config.py b/atom/config.py index 677ae41c9..9ecbdb365 100644 --- a/atom/config.py +++ b/atom/config.py @@ -17,6 +17,10 @@ from torch.distributed import ProcessGroup, ReduceOp from transformers import AutoConfig, GenerationConfig, PretrainedConfig +# plugin-related utilities +from atom.plugin import is_plugin_mode +from atom.plugin.config import PluginConfig + logger = logging.getLogger("atom") @@ -584,6 +588,9 @@ class Config: torch_dtype: torch.dtype = field(init=False) speculative_config: Optional[SpeculativeConfig] = None + # only use for plugin mode + plugin_config: Optional[PluginConfig] = None + def _set_cudagraph_sizes(self): if self.compilation_config.cudagraph_capture_sizes: self.graph_bs = self.compilation_config.cudagraph_capture_sizes @@ -626,16 +633,25 @@ def __post_init__(self): self.max_model_len, hf_config_max_position_embeddings ) # assert self.max_num_batched_tokens >= self.max_model_len - if self.torch_profiler_dir is not None: - os.makedirs(self.torch_profiler_dir, exist_ok=True) - assert self.torch_profiler_dir is None or os.path.isdir( - self.torch_profiler_dir - ), f"torch_profiler_dir {self.torch_profiler_dir} is not a valid directory" - if self.compilation_config.level == CompilationLevel.PIECEWISE: - self.compilation_config.set_splitting_ops_for_v1() - self._set_cudagraph_sizes() - self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE - self.compilation_config.init_with_cudagraph_sizes() + if not is_plugin_mode(): + if self.torch_profiler_dir is not None: + os.makedirs(self.torch_profiler_dir, 
exist_ok=True) + assert self.torch_profiler_dir is None or os.path.isdir( + self.torch_profiler_dir + ), f"torch_profiler_dir {self.torch_profiler_dir} is not a valid directory" + + # only for server mode or plugin mode(vllm) + # for torch compile policy, plugin mode(vllm) uses the ATOM compile policy + # for cuda graph capture, plugin mode(vllm) uses the vLLM's cuda graph capture policy + if not is_plugin_mode() or ( + self.plugin_config is not None and self.plugin_config.is_vllm + ): + if self.compilation_config.level == CompilationLevel.PIECEWISE: + self.compilation_config.set_splitting_ops_for_v1() + self._set_cudagraph_sizes() + self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE + self.compilation_config.init_with_cudagraph_sizes() + self.torch_dtype = ( self.hf_config.torch_dtype if getattr(self.hf_config, "torch_dtype", None) is not None diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index 6861af392..0c208c905 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -608,7 +608,9 @@ def __init__(self, rank: int, config: Config): self.drafter.load_model(self.model) torch.set_default_device(self.device) self.allocate_forward_vars() - self.attn_metadata_builder = self.attn_backend.get_builder_cls()(self) + self.attn_metadata_builder = self.attn_backend.get_builder_cls()( + model_runner=self + ) self.physical_block_size = self.attn_metadata_builder.block_size self.forward_done_event = torch.cuda.Event() self.warmup_model() @@ -1171,7 +1173,7 @@ def prepare_inputs(self, batch: ScheduledBatch, input_ids: torch.Tensor = None): self.forward_vars["cu_seqlens_q"].np[scheduled_bs + 1 : bs + 1] = ( self.forward_vars["cu_seqlens_q"].np[scheduled_bs] ) - attn_metadata, positions = self.attn_metadata_builder.build(batch, bs) + attn_metadata, positions = self.attn_metadata_builder.build(batch=batch, bs=bs) context_bs = batch.total_seqs_num_prefill if is_prefill else scheduled_bs # graph_bs 
should be batch size (number of sequences), not token count @@ -1472,7 +1474,7 @@ def capture_cudagraph(self): ) attn_metadata, context = ( - self.attn_metadata_builder.build_for_cudagraph_capture(bs) + self.attn_metadata_builder.build_for_cudagraph_capture(bs=bs) ) num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) num_tokens += num_pad diff --git a/atom/model_loader/loader.py b/atom/model_loader/loader.py index 032b51385..a38c68883 100644 --- a/atom/model_loader/loader.py +++ b/atom/model_loader/loader.py @@ -29,6 +29,7 @@ get_spec_layer_idx_from_weight_name, rewrite_spec_layer_name, ) +from atom.plugin.prepare import is_vllm, is_sglang logger = logging.getLogger("atom") @@ -80,13 +81,61 @@ def safetensors_weights_iterator( yield name, f.get_tensor(name) +# when plugin mode, model loader method is bind to model implementation +# thus call this interface to load the model, which leverages the load_model +# method +def load_model_in_plugin_mode( + model, + config, + prefix: str = "", +) -> set[str]: + + # during loading model, the outplace operation may consume more + # GPU mem, which cached in torch caching allocator, here actively + # call empty cache to free the extra reserved but not used memory + def _empty_cache(): + import gc + + gc.collect() + torch.cuda.empty_cache() + + assert ( + config.plugin_config is not None and config.plugin_config.is_plugin_mode + ), "ATOM is not running in plugin mode" + if config.plugin_config.is_vllm: + model_name_or_path = config.plugin_config.model_config.model + elif config.plugin_config.is_sglang: + model_name_or_path = config.plugin_config.model_config.model_path + + _empty_cache() + loaded_weights_record = load_model( + model=model, + model_name_or_path=model_name_or_path, + hf_config=config.hf_config, + load_dummy=config.load_dummy, + spec_decode=False, + prefix=prefix, + is_plugin_mode=True, + act_dtype=config.plugin_config.model_config.dtype, + ) + _empty_cache() + return loaded_weights_record + + def 
load_model( model: nn.Module, model_name_or_path: str, hf_config: AutoConfig, load_dummy: bool = False, spec_decode: bool = False, + prefix: str = "", + is_plugin_mode: bool = False, + act_dtype: torch.dtype = None, ): + # need to record the loaded weight name for vllm load check + # it is only used in plugin mode for vllm + loaded_weights_record: set[str] = set() + packed_modules_mapping = getattr(model, "packed_modules_mapping", {}) weights_mapping = getattr(model, "weights_mapping", {}) params_dict = dict(model.named_parameters()) @@ -145,6 +194,7 @@ def load_model( weight_loader, param, weight_tensor, shard_id ) ) + loaded_weights_record.add(prefix + param_name) break else: # Check if model has expert mapping before processing @@ -170,6 +220,7 @@ def load_model( expert_id, ) ) + loaded_weights_record.add(prefix + name) # weight_loader( # param, # weight_tensor, @@ -186,6 +237,7 @@ def load_model( futures.append( executor.submit(weight_loader, param, weight_tensor) ) + loaded_weights_record.add(prefix + name) # weight_loader(param, weight_tensor) else: # Model doesn't have expert mapping, use generic loading @@ -195,14 +247,29 @@ def load_model( ) # weight_loader(param, weight_tensor) futures.append(executor.submit(weight_loader, param, weight_tensor)) + loaded_weights_record.add(prefix + name) # Wait for all tasks to complete and raise any exceptions. 
for future in concurrent.futures.as_completed(futures): future.result() for _, module in model.named_modules(): if hasattr(module, "process_weights_after_loading"): - module.process_weights_after_loading() + if is_vllm(): + from vllm.attention.layer import Attention + + # call vLLM attn weights post processing with act_dtype if using vLLM attention module + if isinstance(module, Attention): + module.process_weights_after_loading(act_dtype=act_dtype) + else: + module.process_weights_after_loading() + else: + module.process_weights_after_loading() quant_method = getattr(module, "quant_method", None) - if isinstance(quant_method, QuantizeMethodBase): + # when running plugin mode for sglang, don't do the post process here + # since sglang will call this func automatically after finishing loading + if isinstance(quant_method, QuantizeMethodBase) and not is_sglang(): quant_method.process_weights_after_loading(module) if isinstance(quant_method, FusedMoEMethodBase): quant_method.init_prepare_finalize(module) + + if is_plugin_mode: + return loaded_weights_record diff --git a/atom/model_ops/__init__.py b/atom/model_ops/__init__.py new file mode 100644 index 000000000..8bce2960e --- /dev/null +++ b/atom/model_ops/__init__.py @@ -0,0 +1,14 @@ +from .paged_attention import PagedAttention +from .radix_attention import RadixAttention + +# This global class is used to construct the attention op in model, +# it can be assigned to different attention ops. +# By default, PagedAttention is used. 
+# For sglang, RadixAttention will be assigned to ATTN_CLS +ATTN_CLS = PagedAttention + +__all__ = [ + "ATTN_CLS", + "PagedAttention", + "RadixAttention", +] diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py index 31f38bb6a..6626b908b 100644 --- a/atom/model_ops/attention_mha.py +++ b/atom/model_ops/attention_mha.py @@ -15,8 +15,15 @@ from .attention_mla import MLAModules +from atom.plugin.prepare import is_plugin_mode, is_vllm +from atom.plugin.attention_mha import PagedAttentionImplDecoratorForPluginMode -class Attention(nn.Module): + +@PagedAttentionImplDecoratorForPluginMode +class PagedAttentionImpl(nn.Module): + """ + Attention paged implementation + """ def __init__( self, @@ -24,11 +31,15 @@ def __init__( head_dim, scale, num_kv_heads, + alibi_slopes: list[float] | None, + sliding_window: Optional[int] = None, kv_cache_dtype="bf16", + logits_soft_cap: float | None = None, + attn_type=None, + kv_sharing_target_layer_name: int | None = None, layer_num=0, mla_modules: Optional[MLAModules] = None, sinks: Optional[nn.Parameter] = None, - sliding_window: Optional[int] = None, rotary_emb: Optional[torch.nn.Module] = None, q_norm: Optional[torch.nn.Module] = None, k_norm: Optional[torch.nn.Module] = None, @@ -37,12 +48,16 @@ def __init__( super().__init__() self.num_heads = num_heads self.head_dim = head_dim + # for upper framework, it uses head_size in built-in methods + self.head_size = head_dim self.scale = scale self.num_kv_heads = num_kv_heads + self.alibi_slopes = alibi_slopes self.k_cache = self.v_cache = torch.tensor([]) self.kv_cache_dtype = kv_cache_dtype self.max_model_len = 0 self.k_scale = self.v_scale = None + self.device = "cuda:" + str(torch.cuda.current_device()) self.layer_num = layer_num self.kv_scale_float = ( torch.finfo(torch.float8_e4m3fn).max / torch.finfo(aiter.dtypes.fp8).max @@ -56,7 +71,16 @@ def __init__( self.q_norm = q_norm self.k_norm = k_norm - def forward( + # for plugin mode(vllm), the query quant is 
disabled for now + if is_vllm(): + self.supports_quant_query_input = False + + # this method will just be called by vLLM and there is no logic in this method + # as ATOM handles the process after loading weights for all ops by itself + def process_weights_after_loading(self, act_dtype: torch.dtype = torch.bfloat16): + pass + + def forward_impl_server_mode( self, q: torch.Tensor, k: torch.Tensor, @@ -414,3 +438,39 @@ def dispatch_backend(self, fwd_ctx: ForwardContext): if atom_config.kv_cache_block_size == 1024: return self.paged_attention_persistent_asm return self.paged_attention_asm + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor = None, + attn_metadata=None, + position: torch.Tensor = None, + q_scale: Optional[torch.Tensor] = None, + qkv: torch.Tensor = None, + output: torch.Tensor = None, + **kwargs, + ): + if is_plugin_mode(): + # forward impl method are added by the decorator + # PagedAttentionImplDecoratorForPluginMode + return self.forward_impl_plugin_mode( + layer=layer, + query=query, + key=key, + value=value, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + position=position, + q_scale=q_scale, + qkv=qkv, + ) + else: + # only for server mode, keep the original method + o = self.forward_impl_server_mode( + q=query, k=key, v=value, position=position, q_scale=q_scale, qkv=qkv + ) + + return o diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index 6b2452cd1..30ebc2fd3 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -93,7 +93,7 @@ class MLAAttention(nn.Module): def __init__( self, num_heads: int, - head_size: int, + head_dim: int, scale: float, num_kv_heads: int, kv_cache_dtype: str, @@ -104,7 +104,7 @@ def __init__( ) -> None: super().__init__() self.num_heads = num_heads - self.head_size = head_size + self.head_dim = head_dim self.scale = float(scale) self.num_kv_heads = num_kv_heads 
self.kv_cache_dtype = kv_cache_dtype if kv_cache_dtype == "fp8" else "auto" diff --git a/atom/model_ops/attentions/aiter_attention.py b/atom/model_ops/attentions/aiter_attention.py index 0bcea5a6c..0489e3b21 100644 --- a/atom/model_ops/attentions/aiter_attention.py +++ b/atom/model_ops/attentions/aiter_attention.py @@ -9,34 +9,60 @@ import torch from aiter.dist.parallel_state import get_tp_group from atom.model_engine.scheduler import ScheduledBatch -from atom.model_ops.attention_mha import Attention from atom.utils import CpuGpuBuffer +import atom.model_ops as ops +from atom.model_ops.paged_attention import PagedAttention +from atom.model_ops.attention_mha import PagedAttentionImpl +from atom.model_ops.radix_attention import RadixAttention from atom.utils.block_convert import block_table_convert_triton from atom.utils.forward_context import AttentionMetaData, Context from .backends import AttentionBackend, CommonAttentionBuilder +from atom.plugin.prepare import is_plugin_mode +from atom.plugin.attention import AiterAttentionMetadataBuilderDecoratorForPluginMode +from atom.plugin.attention import AiterBackendDecoratorForPluginMode +@AiterBackendDecoratorForPluginMode class AiterBackend(AttentionBackend): @staticmethod def get_name() -> str: - return "ROCM_AITER_ATTENTION" + return "ROCM_AITER_ATTENTION" if not is_plugin_mode() else "CUSTOM" @staticmethod def get_builder_cls() -> Type["AiterAttentionMetadataBuilder"]: return AiterAttentionMetadataBuilder @staticmethod - def get_impl_cls() -> Type["Attention"]: - return Attention + def get_impl_cls(): + attn_cls = ops.ATTN_CLS + if attn_cls == PagedAttention: + return PagedAttentionImpl + elif attn_cls == RadixAttention: + raise NotImplementedError("RadixAttention is not supported for now") + raise NotImplementedError( + f"Unsupported attention class {attn_cls!r} configured in ops.ATTN_CLS" + ) -class AiterAttentionMetadataBuilder(CommonAttentionBuilder): +@AiterAttentionMetadataBuilderDecoratorForPluginMode( + 
default_base_class=CommonAttentionBuilder +) +class AiterAttentionMetadataBuilder: BLOCK_TABLE_EXTENDER: list[list[int]] = [[]] - def __init__(self, model_runner): + def __init__( + self, + kv_cache_spec=None, + layer_names=None, + config=None, + device=None, + model_runner=None, + ): self.block_size = 1024 if model_runner.block_size == 1024 else 16 - super().__init__(model_runner) + # Note: Cannot use super() here because the class is dynamically created by decorator + # Use explicit parent class call instead + CommonAttentionBuilder.__init__(self, model_runner) config = model_runner.config hf_config = config.hf_config self.num_attention_heads = ( diff --git a/atom/model_ops/base_attention.py b/atom/model_ops/base_attention.py index e12c1a386..4a9fd1720 100644 --- a/atom/model_ops/base_attention.py +++ b/atom/model_ops/base_attention.py @@ -3,6 +3,7 @@ # from flash_attn import flash_attn_with_kvcache from typing import Optional +from abc import ABC, abstractmethod import torch from torch import nn @@ -11,7 +12,6 @@ from atom.utils import mark_spliting_op from .attention_mla import MLAModules from atom.config import get_current_atom_config -from atom.utils.selector import get_attn_backend def fake_( @@ -51,10 +51,26 @@ def unified_attention_with_output_base( ) -> torch.Tensor: atom_config = get_current_atom_config() self = atom_config.compilation_config.static_forward_context[layer_name] - return self.impl.forward(q, k, v, positions, q_scale, qkv) + if use_mla: + return self.impl.forward(q, k, v, positions, q_scale, qkv) + else: + return self.impl.forward( + layer=self, + query=q, + key=k, + value=v, + position=positions, + q_scale=q_scale, + qkv=qkv, + ) + +class BaseAttention(nn.Module, ABC): + """ + Abstract base class for attention -class Attention(nn.Module): + This class defines the interface that all attention implementations must follow + """ def __init__( self, @@ -70,68 +86,20 @@ def __init__( per_layer_sliding_window: Optional[int] = None, rotary_emb: 
Optional[torch.nn.Module] = None, prefix: Optional[str] = None, - q_norm: Optional[torch.nn.Module] = None, - k_norm: Optional[torch.nn.Module] = None, **kwargs, ): super().__init__() - self.num_heads = num_heads - self.head_dim = head_dim - self.scale = scale - self.num_kv_heads = num_kv_heads - self.k_cache = self.v_cache = torch.tensor([]) - self.kv_cache_dtype = kv_cache_dtype - self.max_model_len = 0 - self.k_scale = self.v_scale = None - self.layer_num = layer_num - self.mla_modules = mla_modules - self.use_mla = use_mla - self.base_attention = None - self.kv_cache = torch.tensor([]) - self.indexer = mla_modules.indexer if mla_modules is not None else None - self.sinks = sinks - - atom_config = get_current_atom_config() - dtype = atom_config.torch_dtype - block_size = atom_config.kv_cache_block_size - self.attn_backend = get_attn_backend( - block_size, - use_mla=self.use_mla, - ) - impl_cls = self.attn_backend.get_impl_cls() - self.impl = impl_cls( - num_heads, - head_dim, - scale, - num_kv_heads, - kv_cache_dtype, - layer_num, - mla_modules, - sinks=sinks, - sliding_window=per_layer_sliding_window, - rotary_emb=rotary_emb, - dtype=dtype, - q_norm=q_norm, - k_norm=k_norm, - ) - - compilation_config = atom_config.compilation_config - default_name = f"MLA_{layer_num}" if self.use_mla else f"MHA_{layer_num}" - self.layer_name = prefix if prefix is not None else default_name - if self.layer_name in compilation_config.static_forward_context: - raise ValueError("Duplicate layer: {}".format(self.layer_name)) - compilation_config.static_forward_context[self.layer_name] = self + @abstractmethod def forward( self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - positions: torch.Tensor = None, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + positions: Optional[torch.Tensor] = None, q_scale: Optional[torch.Tensor] = None, - qkv: torch.Tensor = None, - ): - output = torch.ops.aiter.unified_attention_with_output_base( - q, q_scale, k, v, 
positions, self.layer_name, self.use_mla, qkv + **kwargs, + ) -> torch.Tensor: + raise NotImplementedError( + f"{self.__class__.__name__} must implement the forward() method" ) - return output diff --git a/atom/model_ops/embed_head.py b/atom/model_ops/embed_head.py index 77281d4f9..371c0324e 100644 --- a/atom/model_ops/embed_head.py +++ b/atom/model_ops/embed_head.py @@ -9,6 +9,7 @@ from torch import nn from atom.utils.forward_context import ForwardContext, get_forward_context +from atom.plugin import is_plugin_mode class VocabParallelEmbedding(nn.Module): @@ -68,13 +69,14 @@ def __init__( self.register_parameter("bias", None) def forward(self, x: torch.Tensor): - forward_context: ForwardContext = get_forward_context() - context = forward_context.context - attn_metadata = forward_context.attn_metadata - # context = get_context() - if context.is_prefill and not context.is_draft: - last_indices = attn_metadata.cu_seqlens_q[1:] - 1 - x = x[last_indices].contiguous() + if not is_plugin_mode(): + forward_context: ForwardContext = get_forward_context() + context = forward_context.context + attn_metadata = forward_context.attn_metadata + # context = get_context() + if context.is_prefill and not context.is_draft: + last_indices = attn_metadata.cu_seqlens_q[1:] - 1 + x = x[last_indices].contiguous() logits = tgemm.mm(x, self.weight, self.bias) if self.tp_size > 1: logits = tensor_model_parallel_all_gather(logits) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index b22fe317f..9e0710d00 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -47,6 +47,7 @@ ) from atom.utils.custom_register import direct_register_custom_op from atom.utils.forward_context import get_forward_context +from atom.plugin.moe import FusedMoEDecoratorForPluginMode class FusedMoeWeightScaleSupported(Enum): @@ -1757,6 +1758,7 @@ def moe_forward_fake( ) +@FusedMoEDecoratorForPluginMode class FusedMoE(torch.nn.Module): """FusedMoE layer for MoE models. 
diff --git a/atom/model_ops/paged_attention.py b/atom/model_ops/paged_attention.py new file mode 100644 index 000000000..20a26576c --- /dev/null +++ b/atom/model_ops/paged_attention.py @@ -0,0 +1,186 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +# from flash_attn import flash_attn_with_kvcache +from typing import Optional + +import torch +from torch import nn + +from .attention_mla import MLAModules +from .base_attention import BaseAttention +from atom.config import get_current_atom_config +from atom.utils.selector import get_attn_backend +from atom.plugin.prepare import is_sglang, is_vllm +from atom.plugin.attention import unified_attention_with_output_base_for_plugin_mode + + +class PagedAttention(BaseAttention): + """ + Attention paged implementation + """ + + def __init__( + self, + num_heads, + head_dim, + scale, + num_kv_heads, + alibi_slopes: list[float] = None, + kv_cache_dtype="bf16", + layer_num=0, + use_mla: bool = False, + mla_modules: Optional[MLAModules] = None, + sinks: Optional[nn.Parameter] = None, + per_layer_sliding_window: Optional[int] = None, + rotary_emb: Optional[torch.nn.Module] = None, + prefix: Optional[str] = None, + q_norm: Optional[torch.nn.Module] = None, + k_norm: Optional[torch.nn.Module] = None, + **kwargs, + ): + # plugin mode(sglang) is not support paged attention + # for now, only support plugin mode(vllm) and atom server mode + assert ( + not is_sglang() + ), "PagedAttention is not supported for plugin mode(sglang) for now" + super().__init__( + num_heads=num_heads, + head_dim=head_dim, + scale=scale, + num_kv_heads=num_kv_heads, + kv_cache_dtype=kv_cache_dtype, + layer_num=layer_num, + use_mla=use_mla, + mla_modules=mla_modules, + sinks=sinks, + per_layer_sliding_window=per_layer_sliding_window, + rotary_emb=rotary_emb, + prefix=prefix, + **kwargs, + ) + + # for plugin mode + if is_vllm(): + self.use_mla = use_mla + self.rotary_emb = rotary_emb + from 
vllm.attention.layer import Attention, AttentionType + + atom_config = get_current_atom_config() + assert ( + atom_config is not None + ), "atom_config is required for plugin mode to vllm" + + # use vllm cache config and quant config to follow the convention of vllm + cache_config = atom_config.plugin_config.vllm_cache_config + quant_config = atom_config.plugin_config.vllm_quant_config + + # add extra impl args, which are needed to be passed to the impl class + # while it only works for custom attention backend for vllm + extra_impl_args = {} + if atom_config.plugin_config.vllm_use_atom_attention: + extra_impl_args["sinks"] = sinks + extra_impl_args["rotary_emb"] = rotary_emb + extra_impl_args["q_norm"] = q_norm + extra_impl_args["k_norm"] = k_norm + + self.attn = Attention( + num_heads=num_heads, + head_size=head_dim, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=alibi_slopes, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=None, + per_layer_sliding_window=per_layer_sliding_window, + prefix=f"{prefix}", + attn_type=AttentionType.DECODER, + kv_sharing_target_layer_name=None, + **extra_impl_args, + ) + + compilation_config = atom_config.compilation_config + self.layer_name = prefix + if self.layer_name in compilation_config.static_forward_context: + raise ValueError("Duplicate layer: {}".format(self.layer_name)) + compilation_config.static_forward_context[self.layer_name] = self + return + + self.num_heads = num_heads + self.head_dim = head_dim + self.scale = scale + self.num_kv_heads = num_kv_heads + self.k_cache = self.v_cache = torch.tensor([]) + self.kv_cache_dtype = kv_cache_dtype + self.max_model_len = 0 + self.k_scale = self.v_scale = None + self.layer_num = layer_num + self.mla_modules = mla_modules + self.use_mla = use_mla + self.base_attention = None + self.kv_cache = torch.tensor([]) + self.indexer = mla_modules.indexer if mla_modules is not None else None + self.sinks = sinks + + atom_config = 
get_current_atom_config() + dtype = atom_config.torch_dtype + block_size = atom_config.kv_cache_block_size + self.attn_backend = get_attn_backend( + block_size, + use_mla=self.use_mla, + ) + impl_cls = self.attn_backend.get_impl_cls() + self.impl = impl_cls( + num_heads=num_heads, + head_dim=head_dim, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=alibi_slopes, + kv_cache_dtype=kv_cache_dtype, + layer_num=layer_num, + mla_modules=mla_modules, + sinks=sinks, + sliding_window=per_layer_sliding_window, + rotary_emb=rotary_emb, + dtype=dtype, + q_norm=q_norm, + k_norm=k_norm, + **kwargs, + ) + + compilation_config = atom_config.compilation_config + default_name = f"MLA_{layer_num}" if self.use_mla else f"MHA_{layer_num}" + self.layer_name = prefix if prefix is not None else default_name + if self.layer_name in compilation_config.static_forward_context: + raise ValueError("Duplicate layer: {}".format(self.layer_name)) + compilation_config.static_forward_context[self.layer_name] = self + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + positions: torch.Tensor = None, + q_scale: Optional[torch.Tensor] = None, + qkv: torch.Tensor = None, + **kwargs, + ): + if is_vllm(): + output = unified_attention_with_output_base_for_plugin_mode( + query, + q_scale, + key, + value, + positions, + layer_name=self.layer_name, + use_mla=self.use_mla, + qkv=qkv, + ) + return output + + # for atom server mode + output = torch.ops.aiter.unified_attention_with_output_base( + query, q_scale, key, value, positions, self.layer_name, self.use_mla, qkv + ) + return output diff --git a/atom/model_ops/radix_attention.py b/atom/model_ops/radix_attention.py new file mode 100644 index 000000000..6ebe7a562 --- /dev/null +++ b/atom/model_ops/radix_attention.py @@ -0,0 +1,170 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +import torch +from torch import nn +from typing import Optional + +from .attention_mla import MLAModules +from .base_attention import BaseAttention +from atom.plugin.prepare import is_plugin_mode, is_sglang +from atom.models.utils import maybe_prefix +from atom.utils import envs + +from aiter.rotary_embedding import AiterFusedSetKVBufferArg + +class RadixAttention(BaseAttention): + """ + Attention radix implementation + """ + + def __init__( + self, + num_heads, + head_dim, + scale, + num_kv_heads, + kv_cache_dtype="bf16", + layer_num=0, + use_mla: bool = False, + mla_modules: Optional[MLAModules] = None, + sinks: Optional[nn.Parameter] = None, + per_layer_sliding_window: Optional[int] = None, + rotary_emb: Optional[torch.nn.Module] = None, + prefix: Optional[str] = None, + **kwargs, + ): + super().__init__( + num_heads=num_heads, + head_dim=head_dim, + scale=scale, + num_kv_heads=num_kv_heads, + kv_cache_dtype=kv_cache_dtype, + layer_num=layer_num, + use_mla=use_mla, + mla_modules=mla_modules, + sinks=sinks, + per_layer_sliding_window=per_layer_sliding_window, + rotary_emb=rotary_emb, + prefix=prefix, + **kwargs, + ) + + if is_sglang(): + self.rotary_emb = rotary_emb + self.layer_num = layer_num + + self.k_scale = torch.tensor([1.0], dtype=torch.float32) + self.v_scale = torch.tensor([1.0], dtype=torch.float32) + + # if True, save cache will be done in rope + self.use_rope_fused_qknorm = envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION + + from sglang.srt.layers.radix_attention import RadixAttention + self.attn = RadixAttention( + num_heads=num_heads, + head_dim=head_dim, + scaling=scale, + num_kv_heads=num_kv_heads, + layer_id=self.layer_num, + prefix=maybe_prefix(prefix, "attn"), + ) + else: + raise NotImplementedError( + "RadixAttention is only supported for plugin mode for sglang for now" + ) + + def forward_impl_plugin_mode( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata=None, + output: torch.Tensor | None = 
None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, + positions: torch.Tensor = None, + q_scale: torch.Tensor = None, + qkv: torch.Tensor = None, + **kwargs, + ): + # for sglang, forward_batch is required + forward_batch = kwargs.get("forward_batch", None) + assert forward_batch is not None, "forward_batch is required for sglang" + + if self.use_rope_fused_qknorm: + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + self.layer_num + ) + block_size = 1024 # Default fallback + if hasattr(forward_batch, "attn_backend") and hasattr( + forward_batch.attn_backend, "page_size" + ): + block_size = forward_batch.attn_backend.page_size + elif hasattr(forward_batch.token_to_kv_pool, "allocator") and hasattr( + forward_batch.token_to_kv_pool.allocator, "page_size" + ): + block_size = forward_batch.token_to_kv_pool.allocator.page_size + elif hasattr(forward_batch.token_to_kv_pool, "page_size"): + block_size = forward_batch.token_to_kv_pool.page_size + x = 16 // k_buffer.element_size() + aiter_fused_set_kv_buffer_arg = AiterFusedSetKVBufferArg( + kv_cache=(k_buffer, v_buffer), + cache_loc=forward_batch.out_cache_loc, + k_scale=self.k_scale, + v_scale=self.v_scale, + return_kv=True, + use_shuffle_layout=True, + block_size=block_size, + x=x, + ) + q, k, v = self.rotary_emb( + qkv, + self.q_norm.weight, + self.k_norm.weight, + positions, + self.num_heads, + self.num_kv_heads, + self.q_norm.eps, + fused_set_kv_buffer_arg=aiter_fused_set_kv_buffer_arg, + ) + else: + # calculate the q and k with rotary embedding + assert self.rotary_emb is not None, "rotary_emb is required" + q, k = self.rotary_emb(positions, q, k) + v = value + + save_kv_cache = not self.use_rope_fused_qknorm + return self.attn( + q, + k, + v, + forward_batch=forward_batch, + save_kv_cache=save_kv_cache, + ) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + positions: torch.Tensor = None, + q_scale: 
Optional[torch.Tensor] = None, + qkv: torch.Tensor = None, + **kwargs, + ): + if is_plugin_mode(): + o = self.forward_impl_plugin_mode( + query=query, + key=key, + value=value, + positions=positions, + q_scale=q_scale, + qkv=qkv, + **kwargs, + ) + else: + raise NotImplementedError( + "RadixAttention is not supported for server mode for now" + ) + return o diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index 4a88fb863..ac58cabfc 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -24,7 +24,7 @@ """Inference-only DeepseekV2/DeepseekV3 model.""" import logging -from typing import Any, Dict, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch from aiter import ( @@ -59,7 +59,7 @@ ) from atom.model_ops.activation import SiluAndMul from atom.model_ops.attention_mla import MLAModules, is_rocm_aiter_fp4bmm_enabled -from atom.model_ops.base_attention import Attention +import atom.model_ops as ops from atom.model_ops.embed_head import ParallelLMHead, VocabParallelEmbedding from atom.model_ops.layernorm import LayerNorm, RMSNorm from atom.model_ops.linear import ( @@ -1397,11 +1397,12 @@ def __init__( indexer=self.indexer, ) - self.mla_attn = Attention( + self.mla_attn = ops.ATTN_CLS( num_heads=self.num_local_heads, head_dim=self.kv_lora_rank + self.qk_rope_head_dim, scale=self.scaling, num_kv_heads=1, + alibi_slopes=None, kv_cache_dtype=cache_config, layer_num=layer_num, use_mla=True, diff --git a/atom/models/gpt_oss.py b/atom/models/gpt_oss.py index eee08a710..3114802ea 100644 --- a/atom/models/gpt_oss.py +++ b/atom/models/gpt_oss.py @@ -28,7 +28,7 @@ # from vllm.model_executor.layers.logits_processor import LogitsProcessor from aiter.rotary_embedding import get_rope from atom.config import Config, QuantizationConfig -from atom.model_ops.base_attention import Attention +import atom.model_ops as ops from atom.model_ops.embed_head import ParallelLMHead, VocabParallelEmbedding # from 
vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig @@ -121,11 +121,12 @@ def __init__( # Only apply sliding window to every other layer sliding_window = config.sliding_window if self.layer_idx % 2 == 0 else None - self.attn = Attention( + self.attn = ops.ATTN_CLS( self.num_local_attention_heads, self.head_dim, self.scaling, num_kv_heads=self.num_local_key_value_heads, + alibi_slopes=None, kv_cache_dtype=cache_config, quant_config=quant_config, per_layer_sliding_window=sliding_window, diff --git a/atom/models/llama.py b/atom/models/llama.py index 4f485de69..7d2a95c4a 100644 --- a/atom/models/llama.py +++ b/atom/models/llama.py @@ -34,7 +34,7 @@ from atom.model_ops.activation import SiluAndMul # from atom.model_ops.attention import Attention -from atom.model_ops.base_attention import Attention +import atom.model_ops as ops from atom.model_ops.embed_head import ParallelLMHead, VocabParallelEmbedding from atom.model_ops.layernorm import RMSNorm from atom.model_ops.linear import ( @@ -190,11 +190,12 @@ def __init__( if is_sliding: sliding_window = config.sliding_window - self.attn = Attention( + self.attn = ops.ATTN_CLS( self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, + alibi_slopes=None, kv_cache_dtype=cache_config, layer_num=layer_num, per_layer_sliding_window=sliding_window, diff --git a/atom/models/mixtral.py b/atom/models/mixtral.py index e9f4b66a7..ff7e2159e 100644 --- a/atom/models/mixtral.py +++ b/atom/models/mixtral.py @@ -32,7 +32,7 @@ from atom.config import Config, QuantizationConfig # from atom.model_ops.attention import Attention -from atom.model_ops.base_attention import Attention +import atom.model_ops as ops from atom.model_ops.embed_head import ParallelLMHead, VocabParallelEmbedding from atom.model_ops.layernorm import RMSNorm from atom.model_ops.linear import QKVParallelLinear, ReplicatedLinear, RowParallelLinear @@ -170,11 +170,12 @@ def __init__( base=int(self.rope_theta), is_neox_style=True, ) - 
self.attn = Attention( + self.attn = ops.ATTN_CLS( self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, + alibi_slopes=None, kv_cache_dtype=cache_config, layer_num=layer_num, quant_config=quant_config, diff --git a/atom/models/qwen3.py b/atom/models/qwen3.py index ccd4f7648..1c932b990 100644 --- a/atom/models/qwen3.py +++ b/atom/models/qwen3.py @@ -24,18 +24,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Any, Iterable import torch # import torch.distributed as dist from aiter.dist.parallel_state import get_tp_group from aiter.rotary_embedding import get_rope -from atom.config import Config, QuantizationConfig +from atom.config import Config from atom.model_ops.activation import SiluAndMul # from atom.model_ops.attention import Attention -from atom.model_ops.base_attention import Attention +import atom.model_ops as ops from atom.model_ops.embed_head import ParallelLMHead, VocabParallelEmbedding from atom.model_ops.layernorm import RMSNorm from atom.model_ops.linear import ( @@ -47,6 +47,9 @@ from torch import nn from transformers import Qwen3Config +from atom.model_loader.loader import load_model_in_plugin_mode +from atom.models.utils import maybe_prefix + class Qwen3Attention(nn.Module): @@ -63,11 +66,11 @@ def __init__( rope_scaling: tuple | None = None, kv_cache_dtype: str = "fp16", layer_num: int = 0, - quant_config: Optional[QuantizationConfig] = None, + atom_config: Config = None, + prefix: str = "", ) -> None: super().__init__() tp_size = get_tp_group().world_size - self.layer_num = layer_num self.total_num_heads = num_heads assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size @@ -85,13 +88,15 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=qkv_bias, - quant_config=quant_config, + quant_config=atom_config.quant_config, + prefix=f"{prefix}.qkv_proj", ) self.o_proj = 
RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - quant_config=quant_config, + quant_config=atom_config.quant_config, + prefix=f"{prefix}.o_proj", ) self.rotary_emb = get_rope( self.head_dim, @@ -100,15 +105,18 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - self.num_kv_heads, + self.attn = ops.ATTN_CLS( + num_heads=self.num_heads, + head_dim=self.head_dim, + scale=self.scaling, + num_kv_heads=self.num_kv_heads, + alibi_slopes=None, kv_cache_dtype=kv_cache_dtype, layer_num=layer_num, use_mla=False, rotary_emb=self.rotary_emb, + config=atom_config, + prefix=f"{prefix}.attn", ) self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) @@ -117,6 +125,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, + **model_kwargs: dict[str, Any] | None, ) -> torch.Tensor: qkv = self.qkv_proj(hidden_states) q, k, v = torch.split(qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -124,7 +133,7 @@ def forward( q = self.q_norm(q) k = self.k_norm(k) - o = self.attn(q, k, v, positions) + o = self.attn(q, k, v, positions, **model_kwargs) output = self.o_proj(o) return output @@ -136,7 +145,8 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config=None, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -144,12 +154,14 @@ def __init__( [intermediate_size] * 2, bias=False, quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", ) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, bias=False, quant_config=quant_config, + prefix=f"{prefix}.down_proj", ) assert hidden_act == "silu" self.act_fn = SiluAndMul() @@ -166,11 +178,12 @@ class Qwen3DecoderLayer(nn.Module): def __init__( self, config: Qwen3Config, - cache_config: str 
= "bf16", - quant_config: Optional[QuantizationConfig] = None, + atom_config: Config, layer_num: int = 0, + prefix: str = "", ) -> None: super().__init__() + kv_cache_dtype = atom_config.kv_cache_dtype self.layer_num = layer_num rope_params = config.rope_parameters self.self_attn = Qwen3Attention( @@ -183,15 +196,17 @@ def __init__( head_dim=getattr(config, "head_dim", None), rope_theta=rope_params["rope_theta"], rope_scaling=rope_params, - kv_cache_dtype=cache_config, + kv_cache_dtype=kv_cache_dtype, layer_num=layer_num, - quant_config=quant_config, + atom_config=atom_config, + prefix=f"{prefix}.self_attn", ) self.mlp = Qwen3MLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - quant_config=quant_config, + quant_config=atom_config.quant_config, + prefix=f"{prefix}.mlp", ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( @@ -203,51 +218,65 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: torch.Tensor | None, + **model_kwargs: dict[str, Any] | None, ) -> tuple[torch.Tensor, torch.Tensor]: if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm(hidden_states, residual) - hidden_states = self.self_attn(positions, hidden_states) + hidden_states = self.self_attn( + positions=positions, hidden_states=hidden_states, **model_kwargs + ) hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual -@support_torch_compile +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + } +) class Qwen3Model(nn.Module): - def __init__(self, atom_config: Config) -> None: + def __init__(self, *, atom_config: Config, prefix: str = "") -> None: super().__init__() - config = atom_config.hf_config - cache_config = 
atom_config.kv_cache_dtype - quant_config = atom_config.quant_config + hf_config = atom_config.hf_config + self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, config.hidden_size + hf_config.vocab_size, hf_config.hidden_size ) self.layers = nn.ModuleList( [ Qwen3DecoderLayer( - config, - cache_config=cache_config, - quant_config=quant_config, + config=hf_config, + atom_config=atom_config, layer_num=layer_num, + prefix=f"{prefix}.layers.{layer_num}", ) - for layer_num in range(config.num_hidden_layers) + for layer_num in range(hf_config.num_hidden_layers) ] ) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = RMSNorm(hf_config.hidden_size, eps=hf_config.rms_norm_eps) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, + **model_kwargs: dict[str, Any], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None for layer in self.layers: - hidden_states, residual = layer(positions, hidden_states, residual) + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + **model_kwargs, + ) + hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -261,20 +290,34 @@ class Qwen3ForCausalLM(nn.Module): "up_proj": ("gate_up_proj", 1), } - def __init__(self, atom_config: Config) -> None: + def __init__(self, config: Any, prefix: str = "") -> None: super().__init__() - config = atom_config.hf_config - self.model = Qwen3Model(atom_config) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - if config.tie_word_embeddings: + self.atom_config = config + self.hf_config = self.atom_config.hf_config + self.model = Qwen3Model( + atom_config=self.atom_config, prefix=maybe_prefix(prefix, "model") + ) + + self.lm_head = ParallelLMHead( + num_embeddings=self.hf_config.vocab_size, + embedding_dim=self.hf_config.hidden_size, + bias=False, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if 
self.hf_config.tie_word_embeddings: self.lm_head.weight.data = self.model.embed_tokens.weight.data def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, + intermediate_tensors=None, + inputs_embeds: torch.Tensor | None = None, + **model_kwargs: dict[str, Any], ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions) + hidden_states = self.model( + input_ids=input_ids, positions=positions, **model_kwargs + ) return hidden_states def compute_logits( @@ -283,3 +326,10 @@ def compute_logits( ) -> torch.Tensor: logits = self.lm_head(hidden_states) return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # load weights in plugin mode and discard passed weights generator + loaded_weights_record = load_model_in_plugin_mode( + model=self, config=self.atom_config, prefix="model." + ) + return loaded_weights_record diff --git a/atom/models/qwen3_moe.py b/atom/models/qwen3_moe.py index 57a89f1e5..a5fa065a9 100644 --- a/atom/models/qwen3_moe.py +++ b/atom/models/qwen3_moe.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, Union, Any, Iterable import torch from aiter.dist.communication_op import tensor_model_parallel_all_reduce @@ -9,8 +9,7 @@ from atom.config import Config, QuantizationConfig from atom.model_ops.activation import SiluAndMul -# from atom.model_ops.attention import Attention -from atom.model_ops.base_attention import Attention +import atom.model_ops as ops from atom.model_ops.embed_head import ParallelLMHead, VocabParallelEmbedding from atom.model_ops.layernorm import RMSNorm from atom.model_ops.linear import ( @@ -30,9 +29,10 @@ from atom.utils import envs from atom.utils.decorators import support_torch_compile from torch import nn +from atom.model_loader.loader import load_model_in_plugin_mode # import torch.distributed as dist -from transformers import PretrainedConfig, Qwen3Config +from transformers import PretrainedConfig ENABLE_ALLREDUCE_RMSNORM_FUSION 
= envs.ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION = ( @@ -149,7 +149,8 @@ def __init__( rope_scaling: tuple | None = None, kv_cache_dtype: str = "fp16", layer_num: int = 0, - quant_config: Optional[QuantizationConfig] = None, + atom_config: Config = None, + prefix: str = "", ) -> None: super().__init__() tp_size = get_tensor_model_parallel_world_size() @@ -177,14 +178,14 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=qkv_bias, - quant_config=quant_config, + quant_config=atom_config.quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - quant_config=quant_config, + quant_config=atom_config.quant_config, reduce_results=not ENABLE_ALLREDUCE_RMSNORM_FUSION, ) @@ -207,7 +208,7 @@ def __init__( self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) - self.attn = Attention( + self.attn = ops.ATTN_CLS( self.num_heads, self.head_dim, self.scaling, @@ -216,6 +217,7 @@ def __init__( layer_num=layer_num, use_mla=False, rotary_emb=self.rotary_emb, + prefix=f"{prefix}.attn", q_norm=self.q_norm if ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION else None, k_norm=self.k_norm if ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION else None, ) @@ -227,36 +229,43 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, + **model_kwargs: dict[str, Any] | None, ) -> torch.Tensor: qkv = self.qkv_proj(hidden_states) q, k, v = torch.split(qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1) if ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION: - attn_output = self.attn(q, k, v, positions, None, qkv) + attn_output = self.attn(query=q, + key=k, + value=v, + positions=positions, + q_scale=None, + qkv=qkv, + **model_kwargs) else: # Add qk-norm q = self.q_norm(q) k = self.k_norm(k) - attn_output = self.attn(q, k, v, positions) + attn_output = self.attn(query=q, + key=k, + value=v, + positions=positions, + 
**model_kwargs) output = self.o_proj(attn_output) return output class Qwen3MoeDecoderLayer(nn.Module): - def __init__( - self, - config: Qwen3Config, - prefix: str, - cache_config: str = "bf16", - quant_config: Optional[QuantizationConfig] = None, - layer_num: int = 0, - ) -> None: + def __init__(self, atom_config=None, layer_num: int = 0, prefix: str = "") -> None: super().__init__() + self.atom_config = atom_config + config = self.atom_config.hf_config self.hidden_size = config.hidden_size rope_params = config.rope_parameters rope_theta = rope_params["rope_theta"] rope_scaling = rope_params + kv_cache_dtype = atom_config.kv_cache_dtype max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # DecoderLayers are created with `make_layers` which passes the prefix # with the layer's index. @@ -272,9 +281,10 @@ def __init__( head_dim=getattr(config, "head_dim", None), rope_theta=rope_theta, rope_scaling=rope_scaling, - kv_cache_dtype=cache_config, + kv_cache_dtype=kv_cache_dtype, layer_num=layer_num, - quant_config=quant_config, + atom_config=atom_config, + prefix=f"{prefix}.self_attn", ) # `mlp_only_layers` in the config. 
@@ -286,14 +296,16 @@ def __init__( and (self.layer_idx + 1) % config.decoder_sparse_step == 0 ): self.mlp = Qwen3MoeSparseMoeBlock( - config, quant_config=quant_config, prefix=f"{prefix}.mlp" + config, + quant_config=self.atom_config.quant_config, + prefix=f"{prefix}.mlp", ) else: self.mlp = Qwen3MoeMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - quant_config=quant_config, + quant_config=self.atom_config.quant_config, reduce_results=not ENABLE_ALLREDUCE_RMSNORM_FUSION, prefix=f"{prefix}.mlp", ) @@ -313,6 +325,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: torch.Tensor | None, + **model_kwargs: dict[str, Any] | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: @@ -323,6 +336,7 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, + **model_kwargs, ) # Fully Connected @@ -341,42 +355,37 @@ def __init__( ): super().__init__() - config = atom_config.hf_config - cache_config = atom_config.kv_cache_dtype - quant_config = atom_config.quant_config - self.config = config + self.config = atom_config.hf_config if get_pp_group().is_first_rank: self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, + self.config.vocab_size, + self.config.hidden_size, ) else: self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, + self.config.num_hidden_layers, lambda prefix, layer_num=None: Qwen3MoeDecoderLayer( - config, - prefix, - cache_config=cache_config, - quant_config=quant_config, + atom_config=atom_config, layer_num=layer_num, + prefix=prefix, ), prefix=f"{prefix}.layers", layer_num_offset=0, ) if get_pp_group().is_last_rank: self.norm = RMSNorm( - config.hidden_size, - eps=config.rms_norm_eps, + self.config.hidden_size, + eps=self.config.rms_norm_eps, fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION, ) else: 
self.norm = PPMissingLayer() self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size + ["hidden_states", "residual"], self.config.hidden_size ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -388,6 +397,7 @@ def forward( positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, + **model_kwargs: dict[str, Any] | None, ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -401,7 +411,9 @@ def forward( residual = intermediate_tensors["residual"] for layer in self.layers[self.start_layer : self.end_layer]: - hidden_states, residual = layer(positions, hidden_states, residual) + hidden_states, residual = layer( + positions, hidden_states, residual, **model_kwargs + ) if not get_pp_group().is_last_rank: return IntermediateTensors( @@ -438,24 +450,22 @@ def __init__( layer_type: type[nn.Module] = Qwen3MoeDecoderLayer, ): super().__init__() - config = atom_config.hf_config - quant_config = atom_config.quant_config - self.config = config - self.quant_config = quant_config + self.atom_config = atom_config + self.config = self.atom_config.hf_config + # Only perform the following mapping when Qwen3MoeMLP exists - if getattr(config, "mlp_only_layers", []): + if getattr(self.config, "mlp_only_layers", []): self.packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"] self.model = Qwen3MoeModel( - atom_config=atom_config, + atom_config=self.atom_config, prefix=maybe_prefix(prefix, "model"), - layer_type=layer_type, ) if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, + num_embeddings=self.config.vocab_size, + embedding_dim=self.config.hidden_size, + bias=False, prefix=maybe_prefix(prefix, "lm_head"), ) else: @@ 
-476,9 +486,14 @@ def forward( positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, + **model_kwargs: dict[str, Any] | None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model( - input_ids, positions, intermediate_tensors, inputs_embeds + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, ) return hidden_states @@ -505,3 +520,10 @@ def make_empty_intermediate_tensors( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # load weights in plugin mode and discard passed weights generator + loaded_weights_record = load_model_in_plugin_mode( + model=self, config=self.atom_config, prefix="model." + ) + return loaded_weights_record diff --git a/atom/plugin/__init__.py b/atom/plugin/__init__.py new file mode 100644 index 000000000..27c855e51 --- /dev/null +++ b/atom/plugin/__init__.py @@ -0,0 +1,13 @@ +from .prepare import ( + prepare_model, + is_sglang, + is_vllm, + is_plugin_mode, +) + +__all__ = [ + "prepare_model", + "is_sglang", + "is_vllm", + "is_plugin_mode", +] diff --git a/atom/plugin/attention.py b/atom/plugin/attention.py new file mode 100644 index 000000000..9d1446a63 --- /dev/null +++ b/atom/plugin/attention.py @@ -0,0 +1,617 @@ +from typing import Optional +import logging + +from dataclasses import dataclass + +import torch + +from atom.plugin.prepare import is_vllm, is_sglang +from atom.utils import CpuGpuBuffer +from atom.utils.forward_context import Context, AttentionMetaData +from atom.model_ops.attention_mha import PagedAttentionImpl + +logger = logging.getLogger("atom") + +_PARTITION_SIZE_ROCM = 256 +_CP_TOKENS_PER_ITER_ROCM = 32 * 1024 + + +@dataclass +class AiterFlashAttentionDecodeMetadata: + max_query_len: int + min_query_len: int + 
max_seq_len: int + query_start_loc: torch.Tensor + + +@dataclass +class AiterFlashAttentionPrefillMetadata: + max_query_len: int + min_query_len: int + max_seq_len: int + query_start_loc: torch.Tensor + + +@dataclass +class AiterChunkSlidingWindowMetadata: + swa_seqlens: torch.Tensor + swa_cu_seqlens: torch.Tensor + swa_seq_starts: torch.Tensor + swa_token_to_batch: torch.Tensor + swa_max_seqlens: int + swa_total_tokens: int + swa_workspace: torch.Tensor + + +@dataclass +class AiterChunkContextMetadata: + workspace: torch.Tensor + cu_seq_lens_chunk: torch.Tensor + chunk_starts: torch.Tensor + token_to_batch: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + seq_lens: torch.Tensor + num_chunks: int + total_token_per_batch: list[int] + swa_metadata: Optional[AiterChunkSlidingWindowMetadata] = None + + +@dataclass +class AiterFlashAttentionChunkPrefillMetadata: + max_query_len: int + min_query_len: int + max_seq_len: int + query_start_loc: torch.Tensor + chunk_context_metadata: AiterChunkContextMetadata + + +@dataclass +class MetadataForPluginMode: + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + num_actual_tokens: int # Number of tokens excluding padding. 
+ num_actual_kv_tokens: int + max_query_len: int + query_start_loc: torch.Tensor + max_seq_len: int + seq_lens: torch.Tensor + slot_mapping: torch.Tensor + block_table: torch.Tensor + + # prefill and decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int + num_prefill_tokens: int + num_extends: int + num_extend_tokens: int + + decode_metadata: Optional[AiterFlashAttentionDecodeMetadata] = None + prefill_metadata: Optional[AiterFlashAttentionPrefillMetadata] = None + extend_metadata: Optional[AiterFlashAttentionChunkPrefillMetadata] = None + + use_cascade: bool = False + common_prefix_len: int = 0 + total_tokens: int = 0 + + context: Optional[Context] = None + + +class vllmAiterBackendMethods: + # here attention in ATOM doesn't accept the output buffer because + # ATOM works as a model impl backend, it needs the maximum freedom + # to decide the output buffer and shape, thus here use this flag to + # let vllm don't allocate the output buffer for ATOM. ATOM will + # handle the output buffer by itself + accept_output_buffer: bool = False + supported_dtypes: list = [torch.float16, torch.bfloat16] + + def __init__(self): + raise TypeError( + f"{self.__class__.__name__} is a utility class and should not be instantiated. " + "Its methods are meant to be added to other classes via decorators." 
+ ) + + @staticmethod + def get_supported_kernel_block_sizes(): + from vllm.v1.attention.backend import MultipleOf + + return [MultipleOf(16)] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @classmethod + def is_mla(cls) -> bool: + return False + + @staticmethod + def get_required_kv_cache_layout(): + return None + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [64, 128, 256] + + @classmethod + def full_cls_name(cls) -> tuple[str, str]: + return (cls.__module__, cls.__qualname__) + + @classmethod + def supports_alibi_sqrt(cls) -> bool: + return False + + +def AiterBackendDecoratorForPluginMode(cls): + """ + Decorator for AiterBackend to add specific methods and attributes for plugin mode + """ + is_vllm_mode = is_vllm() + + if is_vllm_mode: + # for vllm, add the required methods + cls.full_cls_name = vllmAiterBackendMethods.full_cls_name + cls.accept_output_buffer = vllmAiterBackendMethods.accept_output_buffer + cls.supported_dtypes = vllmAiterBackendMethods.supported_dtypes + cls.get_supported_kernel_block_sizes = ( + vllmAiterBackendMethods.get_supported_kernel_block_sizes + ) + cls.get_kv_cache_shape = vllmAiterBackendMethods.get_kv_cache_shape + cls.is_mla = vllmAiterBackendMethods.is_mla + cls.get_required_kv_cache_layout = ( + vllmAiterBackendMethods.get_required_kv_cache_layout + ) + cls.get_supported_head_sizes = vllmAiterBackendMethods.get_supported_head_sizes + cls.supports_alibi_sqrt = vllmAiterBackendMethods.supports_alibi_sqrt + return cls + + +def create_attn_metadata_builder_init_method(base_class): + """ + Create the init method for metadata builder + """ + + def init_method_under_plugin_mode( + self, + kv_cache_spec=None, + layer_names=None, + 
config=None, + device=None, + model_runner=None, + ): + base_class.__init__(self, kv_cache_spec, layer_names, config, device) + logger.info("init AiterAttentionMetadataBuilder for plugin mode") + from vllm.config import VllmConfig, get_layers_from_vllm_config + from vllm.attention.layer import Attention + + assert isinstance(config, VllmConfig) + + self.vllm_config = config + self.model_config = config.model_config + self.parallel_config = config.parallel_config + self.cache_config = config.cache_config + + self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config) + self.head_dim = self.model_config.get_head_size() + self.block_size = kv_cache_spec.block_size + + self.aot_sliding_window: Optional[tuple[int, int]] = None + self.total_tokens: int = 0 + + self.scheduler_config = config.scheduler_config + self.block_ratio = 1 + + sliding_window_sizes: set[tuple[int, int] | None] = set() + layers = get_layers_from_vllm_config(config, Attention) + for layer in layers.values(): + assert isinstance(layer.impl, PagedAttentionImpl) + sliding_window_sizes.add((layer.impl.sliding_window - 1, 0)) + + while len(sliding_window_sizes) > 0: + sliding_window_config = sliding_window_sizes.pop() + if sliding_window_config is not None and sliding_window_config[0] != -1: + assert ( + self.aot_sliding_window is None + ), "Aiter Backend only support one valid sliding window" + self.aot_sliding_window = sliding_window_config + + # for extend path to store the fetched key and value + # here buffer used for extend path is not calculated by vLLM and SGLang + # when profile_run, it is possible to exhaust the GPU memory when + # gpu_mem_utilization is much higher + self.extend_workspace = torch.empty( + [2, _CP_TOKENS_PER_ITER_ROCM, self.num_heads_kv, self.head_dim], + dtype=self.model_config.dtype, + device=device, + ) + + # used for ROPE + max_num_batched_tokens = config.scheduler_config.max_num_batched_tokens + i64_kwargs = {"dtype": torch.int64, "device": device} + 
self.positions = CpuGpuBuffer(max_num_batched_tokens, **i64_kwargs) + + return init_method_under_plugin_mode + + +def setup_attn_metadata_builder_base_class_and_attributes(class_dict: dict): + """ + Setup the base class and attributes for attention metadata builder + """ + from vllm.v1.attention.backend import ( + AttentionCGSupport, + AttentionMetadataBuilder, + ) + + base_class = AttentionMetadataBuilder + generic_base = AttentionMetadataBuilder + needs_generic = True + + # align with vllm rocm aiter fa + class_dict["_cudagraph_support"] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + class_dict["reorder_batch_threshold"] = 1 + + return base_class, generic_base, needs_generic, class_dict + + +class vllmAttentionMetadataBuilderMethods: + def __init__(self): + raise TypeError( + f"{self.__class__.__name__} is a utility class and should not be instantiated. " + "Its methods are meant to be added to other classes via decorators." + ) + + def build( + self, + common_prefix_len: int = 0, + common_attn_metadata=None, + fast_build: bool = False, + ): + if common_prefix_len > 0: + raise ValueError("ATOM does not support cascade attention yet") + + from vllm.v1.attention.backends.utils import split_decodes_prefills_and_extends + + # here assume the decode num token is 1 per request + split_ret = split_decodes_prefills_and_extends( + common_attn_metadata=common_attn_metadata, decode_threshold=1 + ) + + ( + num_decodes, + num_extends, + num_prefills, + num_decode_tokens, + num_extend_tokens, + num_prefill_tokens, + ) = split_ret + + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + seq_lens = common_attn_metadata.seq_lens.cpu() + query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + + decode_metadata = None + if num_decodes > 0: + decode_metadata = AiterFlashAttentionDecodeMetadata( + max_query_len=query_lens_cpu[:num_decodes].max().item(), + min_query_len=query_lens_cpu[:num_decodes].min().item(), + 
max_seq_len=seq_lens[:num_decodes].max().item(), + query_start_loc=common_attn_metadata.query_start_loc[: num_decodes + 1], + ) + + extend_metadata = None + if num_extends > 0: + num_extends_slice = slice(num_decodes, num_decodes + num_extends) + query_lens_for_extend = query_lens_cpu[num_extends_slice] + seq_lens_for_extend = seq_lens[num_extends_slice] + computed_kv_lens = seq_lens_for_extend - query_lens_for_extend + swa_metadata = None + if self.aot_sliding_window is not None: + swa_seqlen_for_extend = torch.minimum( + seq_lens_for_extend, + query_lens_for_extend + self.aot_sliding_window[0] + 1, + ) + cu_seq_lens = torch.zeros( + num_extends + 1, + dtype=torch.int32, + device=seq_lens_for_extend.device, + ) + torch.cumsum( + swa_seqlen_for_extend, + dim=0, + dtype=cu_seq_lens.dtype, + out=cu_seq_lens[1:], + ) + token_to_seq = torch.arange( + 0, + num_extends, + dtype=torch.int32, + device=seq_lens_for_extend.device, + ) + token_to_seq = torch.repeat_interleave( + token_to_seq, swa_seqlen_for_extend + ) + fetched_shape = cu_seq_lens[-1].item() + swa_workspace = torch.empty( + (2, fetched_shape, self.num_heads_kv, self.head_dim), + dtype=self.vllm_config.model_config.dtype, + device=self.device, + ) + + seq_starts = seq_lens_for_extend - swa_seqlen_for_extend + max_seqlen_k = swa_seqlen_for_extend.max().item() + total_tokens = cu_seq_lens[-1].item() + + swa_metadata = AiterChunkSlidingWindowMetadata( + swa_seqlens=swa_seqlen_for_extend.to( + self.device, non_blocking=True + ), + swa_cu_seqlens=cu_seq_lens.to(self.device, non_blocking=True), + swa_seq_starts=seq_starts.to(self.device, non_blocking=True), + swa_token_to_batch=token_to_seq.to(self.device, non_blocking=True), + swa_max_seqlens=max_seqlen_k, + swa_total_tokens=total_tokens, + swa_workspace=swa_workspace, + ) + + # allocate the equal amount of workspace for + # each chunk prefill request + max_context_chunk = _CP_TOKENS_PER_ITER_ROCM // num_extends + from vllm.utils.math_utils import cdiv + + 
num_chunks = cdiv(computed_kv_lens.max().item(), max_context_chunk) + + chunk_starts = ( + torch.arange(num_chunks, dtype=torch.int32) + .unsqueeze(1) + .expand(-1, num_extends) + * max_context_chunk + ) + chunk_ends = torch.min( + computed_kv_lens.unsqueeze(0), chunk_starts + max_context_chunk + ) + chunk_seq_lens = (chunk_ends - chunk_starts).clamp( + min=0 + ) # [num_chunks, num_extends] + cu_seq_lens_cpu = torch.zeros( + [num_chunks, num_extends + 1], dtype=torch.int32, pin_memory=True + ) + torch.cumsum( + chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32 + ) + max_cum_tokens = cu_seq_lens_cpu[:, -1].max().item() + + # Build token->batch mapping robustly, even with zero-length batches. + token_to_batch_tensor = torch.zeros( + (num_chunks, max_cum_tokens), dtype=torch.int32, pin_memory=True + ) + batch_ids = torch.arange(num_extends, dtype=torch.int32) + for chunk_idx in range(num_chunks): + total_tokens = cu_seq_lens_cpu[chunk_idx, -1].item() + if total_tokens == 0: + continue + token_to_batch = torch.repeat_interleave( + batch_ids, chunk_seq_lens[chunk_idx].to(torch.int64) + ) + token_to_batch_tensor[chunk_idx, :total_tokens] = token_to_batch + + chunk_context_metadata = AiterChunkContextMetadata( + workspace=self.extend_workspace, + cu_seq_lens_chunk=cu_seq_lens_cpu.to(self.device, non_blocking=True), + chunk_starts=chunk_starts.to(self.device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, + token_to_batch=token_to_batch_tensor.to(self.device, non_blocking=True), + num_chunks=num_chunks, + total_token_per_batch=cu_seq_lens_cpu[:, -1].tolist(), + swa_metadata=swa_metadata, + ) + + query_start_loc_device = common_attn_metadata.query_start_loc[ + num_decodes : num_decodes + num_extends + 1 + ] + seq_lens_device = common_attn_metadata.seq_lens[num_extends_slice] + cu_seq_lens = torch.zeros( + num_extends + 1, dtype=torch.int32, 
device=seq_lens_device.device + ) + torch.cumsum( + seq_lens_device, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:] + ) + extend_metadata = AiterFlashAttentionChunkPrefillMetadata( + max_query_len=query_lens_for_extend.max().item(), + min_query_len=query_lens_for_extend.min().item(), + max_seq_len=seq_lens[num_extends_slice].max().item(), + query_start_loc=query_start_loc_device - query_start_loc_device[0], + chunk_context_metadata=chunk_context_metadata, + ) + + prefill_metadata = None + if num_prefills > 0: + query_lens_for_prefill = query_lens_cpu[num_decodes + num_extends :] + query_start_loc_device = common_attn_metadata.query_start_loc[ + num_decodes + num_extends : + ] + prefill_metadata = AiterFlashAttentionPrefillMetadata( + max_query_len=query_lens_for_prefill.max().item(), + min_query_len=query_lens_for_prefill.min().item(), + max_seq_len=seq_lens[num_decodes + num_extends :].max().item(), + query_start_loc=query_start_loc_device - query_start_loc_device[0], + ) + + num_actual_kv_tokens = torch.sum(seq_lens).item() + + use_cascade = False + + context_batch_size = 0 + has_prefill = bool(num_prefills > 0 or num_extends > 0) + if has_prefill: + context_batch_size = num_prefills + num_extends + else: + context_batch_size = num_decodes + context_graph_bs = context_batch_size + + num_actual_tokens = common_attn_metadata.num_actual_tokens + context = Context( + positions=None, + is_prefill=has_prefill, + batch_size=context_batch_size, + graph_bs=context_graph_bs, + ) + + attn_metadata_for_plugin_mode = MetadataForPluginMode( + num_actual_tokens=num_actual_tokens, + num_actual_kv_tokens=num_actual_kv_tokens, + max_query_len=common_attn_metadata.max_query_len, + query_start_loc=common_attn_metadata.query_start_loc, + max_seq_len=common_attn_metadata.max_seq_len, + seq_lens=common_attn_metadata.seq_lens, + block_table=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping, + num_decodes=num_decodes, + 
num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + num_extends=num_extends, + num_extend_tokens=num_extend_tokens, + decode_metadata=decode_metadata, + prefill_metadata=prefill_metadata, + extend_metadata=extend_metadata, + use_cascade=use_cascade, + common_prefix_len=common_prefix_len, + total_tokens=self.total_tokens, + context=context, + ) + + attn_metadata = AttentionMetaData( + max_seqlen_q=common_attn_metadata.max_query_len, + block_tables=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping, + plugin_metadata=attn_metadata_for_plugin_mode, + ) + + return attn_metadata + + # this method will be called by vllm, so it follows the vllm's interface convention + def build_for_cudagraph_capture( + self, + common_attn_metadata=None, + ): + self.total_tokens = ( + self.model_config.max_model_len + * self.vllm_config.scheduler_config.max_num_partial_prefills + ) + attn_metadata = self.build( + common_prefix_len=0, common_attn_metadata=common_attn_metadata + ) + self.total_tokens = 0 + return attn_metadata + + +def AiterAttentionMetadataBuilderDecoratorForPluginMode(default_base_class): + def decorator(cls): + is_vllm_mode = is_vllm() + is_sglang_mode = is_sglang() + + base_class = default_base_class + class_dict = {} + + # record original decorated cls methods + for key, value in cls.__dict__.items(): + if not key.startswith("__") or key in ( + "__annotations__", + "__init__", + "__module__", + "__qualname__", + "__doc__", + ): + class_dict[key] = value + + # handle the generic base class + needs_generic = False + generic_base = None + + if is_vllm_mode: + # get the base class and generic base class + base_class, generic_base, needs_generic, class_dict = ( + setup_attn_metadata_builder_base_class_and_attributes(class_dict) + ) + + # replace the __init__ method in the decorated class + class_dict["__init__"] = create_attn_metadata_builder_init_method( + base_class + ) + + # 
add the methods to the decorated class + for method_name in dir(vllmAttentionMetadataBuilderMethods): + if not method_name.startswith("_"): + method = getattr(vllmAttentionMetadataBuilderMethods, method_name) + if callable(method): + class_dict[method_name] = method + elif is_sglang_mode: + raise NotImplementedError( + "AttentionMetadataBuilder for sglang is not implemented yet" + ) + + # create the new class + new_class = type(cls.__name__, (base_class,), class_dict) + + # replace the inherit base class for plugin mode, meanwhile support generic base class + if needs_generic and generic_base is not None: + new_class.__orig_bases__ = (generic_base[new_class],) + + return new_class + + return decorator + + +# here not register it as a custom op and mark split because vllm +# will register attention impl forward as a custom op, so here +# avoid duplicated registration, and the split op is registered +# into the atom support_torch_compile decorator +def unified_attention_with_output_base_for_plugin_mode( + q: torch.Tensor, + q_scale: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + positions: torch.Tensor, + layer_name: str, + use_mla: bool, + qkv: torch.Tensor, +) -> torch.Tensor: + from atom.config import get_current_atom_config + from atom.utils import envs + + atom_config = get_current_atom_config() + if use_mla: + raise NotImplementedError("MLA is not supported for plugin mode for now") + else: + self = atom_config.compilation_config.static_forward_context[layer_name] + # here is the standard vllm attention impl interface + # when using fusion, we need to pass the qkv and positions through the q,k,v + # [watch out] accept_output_buffer must be False for plugin mode + # because we don't want vllm to manipulate the q k v and output buffer + # ATOM needs to handle all of the buffer here + if envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION: + output = self.attn(q, positions, qkv) + else: + # calculate the q and k with rotary embedding + if self.rotary_emb is 
not None: + q, k = self.rotary_emb(positions, q, k) + output = self.attn(q, k, v) + return output diff --git a/atom/plugin/attention_mha.py b/atom/plugin/attention_mha.py new file mode 100644 index 000000000..f861df487 --- /dev/null +++ b/atom/plugin/attention_mha.py @@ -0,0 +1,890 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +""" +Plugin mode extensions for PagedAttentionImpl. +This module provides additional methods for PagedAttentionImpl when running in plugin mode. +""" + +import torch +import aiter +from aiter import dtypes, fused_qk_norm_rope_cache_quant_shuffle +from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache +from aiter.ops.triton.gluon.pa_decode_gluon import get_recommended_splits +import triton +import triton.language as tl +from typing import TYPE_CHECKING +from atom.utils import envs + +import logging + +logger = logging.getLogger("atom") + +if TYPE_CHECKING: + from atom.utils.forward_context import AttentionMetaData + +ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION = ( + envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION +) + +_PARTITION_SIZE_ROCM = 256 +_CP_TOKENS_PER_ITER_ROCM = 32 * 1024 +_QWEN_GLUON_PA_DECODE_BS = 64 + + +@triton.jit +def cp_mha_gather_cache_kernel( + key_cache_ptr, # [num_blocks, page_size, num_head, head_size] + value_cache_ptr, # [num_blocks, page_size, num_head, head_size] + key_ptr, # [num_tokens, num_heads, head_size] + value_ptr, # [num_tokens, num_heads, head_size] + block_table_ptr, # [num_batches, max_block_num] + cu_seqlens_kv_ptr, # [num_batches + 1] + token_to_batch_ptr, # [max_cum_tokens] + seq_start_ptr, # [num_batches] + k_scale_ptr, # [1] / [num_blocks, num_kv_heads, page_size] + v_scale_ptr, + num_heads, + head_size, + x, + max_block_num, + DEQUANT: tl.constexpr, + PAGE_SIZE: tl.constexpr, + CACHE_FORMAT: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + token_id = tl.program_id(0) + head_id = tl.program_id(1) + col_offsets 
= tl.arange(0, BLOCK_SIZE) + + key_ptr_offset = key_ptr + token_id * head_size * num_heads + head_id * head_size + value_ptr_offset = ( + value_ptr + token_id * head_size * num_heads + head_id * head_size + ) + batch_idx = tl.load(token_to_batch_ptr + token_id) + batch_start = tl.load(seq_start_ptr + batch_idx) + token_start = tl.load(cu_seqlens_kv_ptr + batch_idx) + batch_offset = token_id - token_start + batch_start + block_offset = batch_offset // PAGE_SIZE + block_id = tl.load(block_table_ptr + max_block_num * batch_idx + block_offset).to( + tl.int64 + ) + slot_id = batch_offset % PAGE_SIZE + + if CACHE_FORMAT == "NHD": + # for kv cache layout as + # K: [num_blocks, page_size, num_head, head_dim] + # V: [num_blocks, page_size, num_head, head_dim] + key_cache_ptr_offset = ( + key_cache_ptr + + block_id * num_heads * head_size * PAGE_SIZE + + slot_id * num_heads * head_size + + head_id * head_size + ) + value_cache_ptr_offset = ( + value_cache_ptr + + block_id * num_heads * head_size * PAGE_SIZE + + slot_id * num_heads * head_size + + head_id * head_size + ) + k_reg = tl.load(key_cache_ptr_offset + col_offsets) + v_reg = tl.load(value_cache_ptr_offset + col_offsets) + if DEQUANT: + k_scale = tl.load(k_scale_ptr) + v_scale = tl.load(v_scale_ptr) + k_dtype = k_reg.dtype + v_dtype = v_reg.dtype + k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype) + v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype) + tl.store(key_ptr_offset + col_offsets, k_reg) + tl.store(value_ptr_offset + col_offsets, v_reg) + + elif CACHE_FORMAT == "SHUFFLE": + # for kv cache layout as + # K: [num_blocks, num_head, head_dim // x, page_size, x] + # V: [num_blocks, num_head, page_size // x, head_dim, x] + key_cache_ptr_offset = ( + key_cache_ptr + + block_id * num_heads * head_size * PAGE_SIZE + + head_id * head_size * PAGE_SIZE + + slot_id * x + ) + value_cache_ptr_offset = ( + value_cache_ptr + + block_id * num_heads * head_size * PAGE_SIZE + + head_id * head_size * PAGE_SIZE + + (slot_id // 
x) * head_size * x + + slot_id % x + ) + k_reg_offset = col_offsets // x * PAGE_SIZE * x + col_offsets % x + v_reg_offset = col_offsets * x + k_reg = tl.load(key_cache_ptr_offset + k_reg_offset) + v_reg = tl.load(value_cache_ptr_offset + v_reg_offset) + if DEQUANT: + k_scale = 1.0 + v_scale = 1.0 + k_reg = k_reg.to(tl.float32) * k_scale + v_reg = v_reg.to(tl.float32) * v_scale + tl.store(key_ptr_offset + col_offsets, k_reg) + tl.store(value_ptr_offset + col_offsets, v_reg) + + +def cp_mha_gather_cache( + key_cache: torch.Tensor, + value_cache: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + block_tables: torch.Tensor, + k_scales: torch.Tensor, + v_scales: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + token_to_batch: torch.Tensor, + seq_starts: torch.Tensor, + dequant: bool, + kv_cache_layout: str, + total_tokens: int, +): + assert kv_cache_layout in [ + "NHD", + "SHUFFLE", + ], "kv_cache_layout only support NHD, SHUFFLE" + head_dim = key.shape[2] + x = 16 // key_cache.element_size() + # assert dequant is True, "Currently, we only support "\ + # "gather cache with dequant" + # For k cache layout: [num_blocks, num_heads, page_size, head_dim] + assert head_dim == key_cache.shape[3], ( + "We assume your kv cache layout is [num_blocks, " + "page_size, num_heads, head_dim], but got otherwise" + ) + page_size = key_cache.shape[1] + num_heads = key_cache.shape[2] + + grid = lambda meta: (total_tokens, num_heads) # noqa: E731 + cp_mha_gather_cache_kernel[grid]( + key_cache, + value_cache, + key, + value, + block_tables, + cu_seqlens_kv, + token_to_batch, + seq_starts, + k_scales, + v_scales, + num_heads, + head_dim, + x, + block_tables.size(1), + DEQUANT=dequant, + PAGE_SIZE=page_size, + CACHE_FORMAT=kv_cache_layout, + BLOCK_SIZE=head_dim, + ) + + +class PagedAttentionImplPluginModeMethods: + """ + Container class for plugin mode methods. 
+ This class cannot be instantiated - it only serves as a namespace for methods + that will be added to PagedAttentionImpl via decorator. + """ + + def __init__(self): + raise TypeError( + "PagedAttentionImplPluginModeMethods cannot be instantiated. " + "It is only used as a method container for the decorator." + ) + + def rope_cache_plugin_mode( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + qkv: torch.Tensor, + position: torch.Tensor, + attention_metadata: "AttentionMetaData", + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + flash_layout: bool = False, + ): + num_blocks, block_size, num_kv_heads, head_size = k_cache.shape + + if not flash_layout: + x = 16 // k_cache.element_size() + k_cache_template = torch.empty( + [num_blocks, num_kv_heads, head_size // x, block_size, x], + dtype=k_cache.dtype, + device="meta", + ) + # ATOM: [num_blocks, num_kv_heads, head_size, block_size], + # vLLM: [num_blocks, num_kv_heads, block_size // x, head_size, x], + v_cache_template = torch.empty( + [num_blocks, num_kv_heads, block_size // x, head_size, x], + dtype=v_cache.dtype, + device="meta", + ) + new_key_cache = k_cache.view_as(k_cache_template) + new_value_cache = v_cache.view_as(v_cache_template) + else: + new_key_cache = k_cache + new_value_cache = v_cache + + # if flash kv_cache layout, the shape of kv_cache is: + # + # key_cache: [num_blocks, block_size, num_kv_heads, head_size] + # value_cache: [num_blocks, num_kv_heads, head_size, block_size] + # + # if not, the shape is: + # + # key_cache: [num_blocks, num_kv_heads, head_size // x, block_size, x] + # value_cache: [num_blocks, num_kv_heads, head_size, block_size] + # + # and the origin kv cache layout in fwd_args is not flash + + attn_metadata = attention_metadata + + use_triton_attn = self.sliding_window != -1 or self.head_dim != 128 + self.use_triton_attn = use_triton_attn + + if ( + self.rotary_emb is not None + and self.q_norm is not None + and 
self.k_norm is not None + ): + fused_qk_norm_rope_cache_quant_shuffle( + qkv, + num_heads_q=self.num_heads, + num_heads_k=self.num_kv_heads, + num_heads_v=self.num_kv_heads, + head_dim=self.head_dim, + eps=self.q_norm.eps, + qw=self.q_norm.weight, + kw=self.k_norm.weight, + cos_sin_cache=self.rotary_emb.cos_sin_cache, + is_neox_style=self.rotary_emb.is_neox_style, + pos_ids=position, + k_cache=new_key_cache, + v_cache=new_value_cache, + slot_mapping=attn_metadata.slot_mapping, + kv_cache_dtype=( + "auto" if self.kv_cache_dtype == "bf16" else self.kv_cache_dtype + ), + k_scale=k_scale, + v_scale=v_scale, + ) + + qkv = qkv.view(qkv.shape[0], -1, self.head_dim) + q, k, v = qkv.split( + [self.num_heads, self.num_kv_heads, self.num_kv_heads], dim=1 + ) + # elif use_triton_attn and self.rotary_emb is not None: + elif 0: + # FIXME: this should be fixed by moving rope outside of attention + k_scale = v_scale = self.kv_scale + + q, k, k_cache, v_cache = fused_qk_rope_reshape_and_cache( + q, + k, + v, + new_key_cache, + new_value_cache, + attn_metadata.slot_mapping, + position, + self.rotary_emb.cos_cache, + self.rotary_emb.sin_cache, + k_scale, + v_scale, + self.rotary_emb.is_neox_style, + flash_layout=flash_layout, + apply_scale=self.kv_cache_dtype.startswith("fp8"), + offs=None, + q_out=q, + k_out=k, + output_zeros=False, + ) + else: + if self.q_norm is not None: + q = self.q_norm(q) + if self.k_norm is not None: + k = self.k_norm(k) + if self.kv_cache_dtype == "fp8": + aiter.reshape_and_cache_with_pertoken_quant( + k, + v, + new_key_cache, + new_value_cache, + k_scale, + v_scale, + attn_metadata.slot_mapping, + asm_layout=True, + ) + else: + aiter.reshape_and_cache( + k, + v, + new_key_cache, + new_value_cache, + attn_metadata.slot_mapping, + kv_cache_dtype="auto", + k_scale=None, + v_scale=None, + asm_layout=True, + ) + + return q, k, v, k_cache, v_cache, k_scale, v_scale + + def paged_attention_triton_plugin_mode( + self, + q: torch.Tensor, + k_cache: torch.Tensor, + 
v_cache: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + out: torch.Tensor, + attn_metadata: "AttentionMetaData", + ): + o = out + num_seqs, num_q_heads_total, head_size = q.shape + num_blocks, num_kv_heads, _, block_size, _ = k_cache.shape + query_group_size = num_q_heads_total // num_kv_heads + assert num_q_heads_total % num_kv_heads == 0 + + max_context_partition_num = get_recommended_splits(num_seqs, num_kv_heads) + + context_partition_size = 256 + if self.sliding_window > 0: + max_context_partition_num = 1 + context_partition_size = 128 + + # Output buffers (same as Triton) + intermediate_shape = ( + num_seqs, + num_kv_heads, + max_context_partition_num, + query_group_size, + ) + exp_sums = torch.empty(intermediate_shape, dtype=torch.float32, device=q.device) + max_logits = torch.empty( + intermediate_shape, dtype=torch.float32, device=q.device + ) + temporary_output = torch.empty( + *intermediate_shape, + head_size, + dtype=q.dtype, + device=q.device, + ) + + per_tensor = False + if k_scale is not None: + per_tensor = k_scale.numel() == 1 + if not per_tensor: + k_scale = k_scale.unsqueeze(-1) + v_scale = v_scale.unsqueeze(-1) + compute_type = ( + torch.bfloat16 + if self.kv_cache_dtype == "bf16" or per_tensor + else aiter.dtypes.fp8 + ) + + torch.ops.aiter.pa_decode_gluon( + o, + q, + k_cache, + v_cache, + attn_metadata.plugin_metadata.seq_lens, + attn_metadata.block_tables, + self.scale, + 1, # query_lenth + max_context_partition_num, + context_partition_size, + compute_type, + None, + None if self.kv_cache_dtype == "bf16" else k_scale, + None if self.kv_cache_dtype == "bf16" else v_scale, + exp_sums=exp_sums, + max_logits=max_logits, + temporary_output=temporary_output, + alibi_slopes=None, + sinks=self.sinks, + sliding_window=self.sliding_window, + ps=True, + ) + + return o + + def paged_attention_asm_plugin_mode( + self, + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, 
+ num_decodes: int, + num_decode_tokens: int, + attn_metadata: "AttentionMetaData", + out: torch.Tensor, + ): + aiter.pa_fwd_asm( + Q=q, + K=k_cache, + V=v_cache, + block_tables=attn_metadata.plugin_metadata.block_table[:num_decodes], + context_lens=attn_metadata.plugin_metadata.seq_lens[:num_decodes], + block_tables_stride0=attn_metadata.plugin_metadata.block_table[ + :num_decodes + ].stride(0), + K_QScale=k_scale, + V_QScale=v_scale, + out_=out[:num_decode_tokens], + high_precision=0, + ) + + return + + def extend_for_sliding_window( + self, + attn_metadata: "AttentionMetaData", + query: torch.Tensor, + key_cache, + value_cache, + output: torch.Tensor, + cu_seqlens_q: torch.Tensor, + max_seqlen_q: int, + block_table: torch.Tensor, + k_scale: float, + v_scale: float, + ): + assert attn_metadata.plugin_metadata.extend_metadata is not None + assert ( + attn_metadata.plugin_metadata.extend_metadata.chunk_context_metadata + is not None + ) + chunked_metadata = ( + attn_metadata.plugin_metadata.extend_metadata.chunk_context_metadata + ) + swa_metadata = chunked_metadata.swa_metadata + assert swa_metadata is not None + swa_cu_seqlens = swa_metadata.swa_cu_seqlens + swa_seq_starts = swa_metadata.swa_seq_starts + swa_token_to_batch = swa_metadata.swa_token_to_batch + swa_max_seqlens = swa_metadata.swa_max_seqlens + swa_total_tokens = swa_metadata.swa_total_tokens + key_fetched, value_fetched = ( + swa_metadata.swa_workspace[0], + swa_metadata.swa_workspace[1], + ) + + cp_mha_gather_cache( + key_cache=key_cache, + value_cache=value_cache, + key=key_fetched, + value=value_fetched, + block_tables=block_table, + k_scales=k_scale, + v_scales=v_scale, + cu_seqlens_kv=swa_cu_seqlens, + token_to_batch=swa_token_to_batch, + seq_starts=swa_seq_starts, + dequant=self.kv_cache_dtype.startswith("fp8"), + kv_cache_layout="NHD", + total_tokens=swa_total_tokens, + ) + + sliding_window = ( + (self.sliding_window, 0, 0) + if self.sliding_window is not None + else (-1, -1, 0) + ) + 
aiter.flash_attn_varlen_func( + q=query, + k=key_fetched, + v=value_fetched, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=swa_cu_seqlens, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=swa_max_seqlens, + min_seqlen_q=1, + dropout_p=0.0, + softmax_scale=self.scale, + causal=True, + window_size=sliding_window, + alibi_slopes=self.alibi_slopes, + return_lse=False, + out=output, + ) + + def extend_forward( + self, + attn_metadata: "AttentionMetaData", + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + output: torch.Tensor, + cu_seqlens_q: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + min_seqlen_q: int, + block_table: torch.Tensor, + slot_mapping: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + ): + from vllm.v1.attention.ops.merge_attn_states import merge_attn_states + + if self.sliding_window != -1: + self.extend_for_sliding_window( + attn_metadata, + query, + key_cache, + value_cache, + output, + cu_seqlens_q, + max_seqlen_q, + block_table, + k_scale, + v_scale, + ) + return + out, lse = aiter.flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_q, + min_seqlen_q=min_seqlen_q, + dropout_p=0.0, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + return_lse=True, + ) + assert attn_metadata.plugin_metadata.extend_metadata is not None + chunk_context_metadata = ( + attn_metadata.plugin_metadata.extend_metadata.chunk_context_metadata + ) + num_chunks = chunk_context_metadata.num_chunks + workspace = chunk_context_metadata.workspace + cu_seqlens_kv = chunk_context_metadata.cu_seq_lens_chunk + max_seqlens = chunk_context_metadata.max_seq_lens + chunk_starts = chunk_context_metadata.chunk_starts + token_to_batch = chunk_context_metadata.token_to_batch + total_token_per_batch = chunk_context_metadata.total_token_per_batch + key_fetched, 
value_fetched = workspace[0], workspace[1] + chunked_output = None + chunked_lse = None + for chunk_idx in range(num_chunks): + cp_mha_gather_cache( + key_cache=key_cache, + value_cache=value_cache, + key=key_fetched, + value=value_fetched, + block_tables=block_table, + k_scales=k_scale, + v_scales=v_scale, + cu_seqlens_kv=cu_seqlens_kv[chunk_idx], + token_to_batch=token_to_batch[chunk_idx], + seq_starts=chunk_starts[chunk_idx], + dequant=self.kv_cache_dtype.startswith("fp8"), + kv_cache_layout="SHUFFLE", + total_tokens=total_token_per_batch[chunk_idx], + ) + + suf_out, suf_lse = aiter.flash_attn_varlen_func( + q=query, + k=key_fetched, + v=value_fetched, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_kv[chunk_idx], + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlens[chunk_idx], + min_seqlen_q=min_seqlen_q, + dropout_p=0.0, + softmax_scale=self.scale, + causal=False, + window_size=(-1, -1, 0), + alibi_slopes=self.alibi_slopes, + return_lse=True, + ) + + if chunked_output is None: + chunked_output = suf_out + chunked_lse = suf_lse + else: + tmp_output = torch.empty_like(out) + tmp_lse = torch.empty_like(lse) + merge_attn_states( + output=tmp_output, + output_lse=tmp_lse, + prefix_output=chunked_output, + prefix_lse=chunked_lse, + suffix_output=suf_out, + suffix_lse=suf_lse, + ) + chunked_output = tmp_output + chunked_lse = tmp_lse + + merge_attn_states( + output=output, + prefix_output=chunked_output, + prefix_lse=chunked_lse, + suffix_output=out, + suffix_lse=lse, + ) + + def forward_impl_plugin_mode( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: "AttentionMetaData" = None, + position: torch.Tensor = None, + q_scale: torch.Tensor = None, + qkv: torch.Tensor = None, + output: torch.Tensor = None, + ): + # create the output here, it use query shape + num_tokens = query.shape[0] + output_dtype = query.dtype + output_shape = torch.Size((num_tokens, 
self.num_heads * self.head_size)) + output = torch.empty(output_shape, dtype=output_dtype, device=query.device) + + # dummy run will skip attention in cuda graph capture phase + if attn_metadata is None: + return output.fill_(0) + + # when using this optimization, the qkv tensor and + # position tensor are passed through q,k,v + if ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION: + assert ( + position is None + ), "position should be None because it is passed through k" + + position = key + qkv = value + + q_size = self.num_heads * self.head_dim + kv_size = self.num_kv_heads * self.head_dim + query, key, value = torch.split(qkv, [q_size, kv_size, kv_size], dim=-1) + else: + # the position is computed by ATOM, and contained in attention metadata + # when dummy run, the attn metadata is None + if attn_metadata is not None: + position = attn_metadata.plugin_metadata.context.positions + + query = query.view(-1, self.num_heads, self.head_dim) + key = key.view(-1, self.num_kv_heads, self.head_dim) + value = value.view(-1, self.num_kv_heads, self.head_dim) + output = output.view(-1, self.num_heads, self.head_dim) + + num_actual_tokens = attn_metadata.plugin_metadata.num_actual_tokens + k_cache, v_cache = kv_cache.unbind(0) + num_blocks, block_size, num_kv_heads, _ = k_cache.shape + + if self.kv_cache_dtype == "fp8": + target_dtype = dtypes.d_dtypes[self.kv_cache_dtype] + k_cache = k_cache.view(target_dtype) + v_cache = v_cache.view(target_dtype) + + # create kv scale according to the num_blocks + # usually it is created when cuda graph capture for decode phase + if self.kv_cache_dtype == "fp8": + if self.k_scale is None or self.v_scale is None: + self.kv_scale = torch.zeros( + 2, + num_blocks, + num_kv_heads, + block_size, + dtype=dtypes.fp32, + device=self.device, + ) + # update the layer kv scale tensor + self.k_scale = self.kv_scale[0] + self.v_scale = self.kv_scale[1] + layer.k_scale = self.k_scale + layer.v_scale = self.v_scale + + # rope and cache flush fusion. 
ATOM always use shuffle layout for kv cache + result = self.rope_cache_plugin_mode( + q=query, + k=key, + v=value, + qkv=qkv, + position=position, + attention_metadata=attn_metadata, + k_cache=k_cache, + v_cache=v_cache, + k_scale=self.k_scale, + v_scale=self.v_scale, + flash_layout=False, + ) + query, key, value, k_cache, v_cache, k_scale, v_scale = result + + # as vLLM cuda graph capture padding mechanism, here split the qkvo with + # the actual tokens + query = query[:num_actual_tokens] + if key is not None: + key = key[:num_actual_tokens] + if value is not None: + value = value[:num_actual_tokens] + output_actual_tokens = output[:num_actual_tokens] + + num_decodes = attn_metadata.plugin_metadata.num_decodes + num_prefills = attn_metadata.plugin_metadata.num_prefills + num_extends = attn_metadata.plugin_metadata.num_extends + + num_decode_tokens = attn_metadata.plugin_metadata.num_decode_tokens + num_extend_tokens = attn_metadata.plugin_metadata.num_extend_tokens + + # calculate for prefills + if num_prefills > 0: + assert attn_metadata.plugin_metadata.prefill_metadata is not None + + # prefill part is after decode and extend + prefill_query = query[num_decode_tokens + num_extend_tokens :] + prefill_key = key[num_decode_tokens + num_extend_tokens :] + prefill_value = value[num_decode_tokens + num_extend_tokens :] + + sliding_window = ( + (self.sliding_window, 0, 0) + if self.sliding_window is not None + else (-1, -1, 0) + ) + + aiter.flash_attn_varlen_func( + q=prefill_query, + k=prefill_key, + v=prefill_value, + cu_seqlens_q=attn_metadata.plugin_metadata.prefill_metadata.query_start_loc, + cu_seqlens_k=attn_metadata.plugin_metadata.prefill_metadata.query_start_loc, + max_seqlen_q=attn_metadata.plugin_metadata.prefill_metadata.max_query_len, + max_seqlen_k=attn_metadata.plugin_metadata.prefill_metadata.max_seq_len, + min_seqlen_q=1, + dropout_p=attn_metadata.dropout_p, + softmax_scale=self.scale, + causal=True, + window_size=sliding_window, + alibi_slopes=None, 
+ sink_ptr=self.sinks, + out=output_actual_tokens[num_decode_tokens + num_extend_tokens :], + ) + + # calculate for extends + if num_extends > 0: + assert attn_metadata.plugin_metadata.extend_metadata is not None + extend_tokens_slice = slice( + num_decode_tokens, num_decode_tokens + num_extend_tokens + ) + extend_querys = query[extend_tokens_slice] + extend_keys = key[extend_tokens_slice] + extend_values = value[extend_tokens_slice] + extend_outputs = output[extend_tokens_slice] + self.extend_forward( + attn_metadata=attn_metadata, + query=extend_querys, + key=extend_keys, + value=extend_values, + key_cache=k_cache, + value_cache=v_cache, + output=extend_outputs, + cu_seqlens_q=attn_metadata.plugin_metadata.extend_metadata.query_start_loc, + max_seqlen_q=attn_metadata.plugin_metadata.extend_metadata.max_query_len, + max_seqlen_k=attn_metadata.plugin_metadata.extend_metadata.max_seq_len, + min_seqlen_q=1, + block_table=attn_metadata.plugin_metadata.block_table[ + num_decodes : num_decodes + num_extends + ], + slot_mapping=attn_metadata.plugin_metadata.slot_mapping[ + num_decode_tokens : num_decode_tokens + num_extend_tokens + ], + k_scale=k_scale, + v_scale=v_scale, + ) + + # calculate for decodes + if num_decodes > 0: + assert attn_metadata.plugin_metadata.decode_metadata is not None + + num_blocks, block_size, num_kv_heads, head_size = k_cache.shape + x = 16 // k_cache.element_size() + k_cache_template = torch.empty( + [num_blocks, num_kv_heads, head_size // x, block_size, x], + dtype=k_cache.dtype, + device="meta", + ) + v_cache_template = torch.empty( + [num_blocks, num_kv_heads, block_size // x, head_size, x], + dtype=v_cache.dtype, + device="meta", + ) + new_key_cache = k_cache.view_as(k_cache_template) + new_value_cache = v_cache.view_as(v_cache_template) + + if self.use_triton_attn: + self.paged_attention_triton_plugin_mode( + q=query[:num_decode_tokens], + k_cache=new_key_cache, + v_cache=new_value_cache, + k_scale=k_scale, + v_scale=v_scale, + 
out=output_actual_tokens[:num_decode_tokens], + attn_metadata=attn_metadata, + ) + else: + # Qwen only uses gluon pa decode when bs=64 + if num_decodes == _QWEN_GLUON_PA_DECODE_BS: + self.paged_attention_triton_plugin_mode( + q=query[:num_decode_tokens], + k_cache=new_key_cache, + v_cache=new_value_cache, + k_scale=k_scale, + v_scale=v_scale, + out=output_actual_tokens[:num_decode_tokens], + attn_metadata=attn_metadata, + ) + else: + self.paged_attention_asm_plugin_mode( + q=query[:num_decode_tokens], + k_cache=new_key_cache, + v_cache=new_value_cache, + k_scale=k_scale, + v_scale=v_scale, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + out=output_actual_tokens[:num_decode_tokens], + attn_metadata=attn_metadata, + ) + + output = output.view(-1, self.num_heads * self.head_dim) + + return output + + +def PagedAttentionImplDecoratorForPluginMode(cls): + method_names = [ + "rope_cache_plugin_mode", + "paged_attention_triton_plugin_mode", + "paged_attention_asm_plugin_mode", + "extend_for_sliding_window", + "extend_forward", + "forward_impl_plugin_mode", + ] + + logger.info( + "Use PagedAttentionImplDecoratorForPluginMode to decorate PagedAttentionImpl" + ) + + # Add all methods to the target class + for method_name in method_names: + method = getattr(PagedAttentionImplPluginModeMethods, method_name) + setattr(cls, method_name, method) + + return cls diff --git a/atom/plugin/config.py b/atom/plugin/config.py new file mode 100644 index 000000000..bf1af0f62 --- /dev/null +++ b/atom/plugin/config.py @@ -0,0 +1,244 @@ +import os +import sys + +from typing import Any, Optional +from dataclasses import dataclass + +import torch +import logging + +logger = logging.getLogger("atom") +_KNOWN_ISSUE_MAX_NUM_BATCHED_TOKENS_THRESHOLD = 18 * 1024 + + +@dataclass +class PluginConfig: + # common config for both framework + model_config: Any = None + rank: int = 0 + is_plugin_mode: bool = False + is_vllm: bool = False + is_sglang: bool = False + + # vllm specific + 
vllm_scheduler_config: Any = None + vllm_cache_config: Any = None + vllm_quant_config: Any = None + vllm_use_atom_attention: bool = False + + # sglang specific + sglang_model_opt_config: Any = None + sglang_load_config: Any = None + sglang_enable_torch_compile: bool = False + sglang_disable_cuda_graph: bool = False + sglang_enable_dp_attention: bool = False + sglang_dist_init_addr: Optional[str] = None + sglang_port_args: Any = None + + +def _generate_atom_config_from_vllm_config(config: Any) -> PluginConfig: + from atom.config import Config, CompilationConfig + + vllm_model_config = config.model_config + vllm_scheduler_config = config.scheduler_config + vllm_cache_config = config.cache_config + vllm_parallel_config = config.parallel_config + vllm_use_atom_attention = bool( + os.getenv("ATOM_DISABLE_VLLM_PLUGIN_ATTENTION", "0").lower() == "0" + ) + + # here use the ATOM compilation config, as the ATOM compile policy is used + # instead of vLLM one for torch compile, while for cuda graph capture, + # still use the vLLM because it has FULL_AND_PIECEWISE feature + # when you don't want to use atom torch compile, you can also use + # --enforce-eager to disable the atom torch compile when launch vllm server + compilation_config = config.compilation_config + vllm_compilation_config = CompilationConfig( + # use mode because vllm level argument is deprecated + level=compilation_config.mode, + use_cudagraph=False, + cudagraph_mode=None, + ) + + vllm_quant_config = config.quant_config + + plugin_config = PluginConfig( + # common config + model_config=vllm_model_config, + rank=vllm_parallel_config.rank, + is_plugin_mode=True, + is_vllm=True, + is_sglang=False, + # vllm specific + vllm_scheduler_config=vllm_scheduler_config, + vllm_cache_config=vllm_cache_config, + vllm_quant_config=vllm_quant_config, + vllm_use_atom_attention=vllm_use_atom_attention, + ) + + # specific + max_model_len = vllm_model_config.max_model_len + if hasattr(vllm_scheduler_config, "max_model_len"): + 
max_model_len = vllm_scheduler_config.max_model_len + + max_num_batched_tokens = vllm_scheduler_config.max_num_batched_tokens + # FIXME: known issue for illegal mem access in fused_moe kernel + if max_num_batched_tokens >= _KNOWN_ISSUE_MAX_NUM_BATCHED_TOKENS_THRESHOLD: + logger.warning( + "For plugin mode, when setting max_num_batched_tokens >= " + + f"{_KNOWN_ISSUE_MAX_NUM_BATCHED_TOKENS_THRESHOLD}, there is a known issue " + + "for illegal mem access in asm fused_moe kernel, if you met the issue, " + + "please set max_num_batched_tokens smaller or choose the ck fused_moe " + + "kernel instead of asm ones" + ) + + return Config( + model=vllm_model_config.model, + max_num_batched_tokens=max_num_batched_tokens, + max_num_seqs=vllm_scheduler_config.max_num_seqs, + max_model_len=max_model_len, + gpu_memory_utilization=vllm_cache_config.gpu_memory_utilization, + tensor_parallel_size=vllm_parallel_config.tensor_parallel_size, + enforce_eager=True, # disable using atom cuda graph + parallel_config=vllm_parallel_config, + kv_cache_block_size=vllm_cache_config.block_size, + num_kvcache_blocks=vllm_cache_config.num_gpu_blocks, + kv_cache_dtype=vllm_cache_config.cache_dtype, + enable_prefix_caching=vllm_cache_config.enable_prefix_caching, + port=None, + torch_profiler_dir=None, + compilation_config=vllm_compilation_config, + asyncio_mode=False, + load_dummy=False, + enable_expert_parallel=vllm_parallel_config.enable_expert_parallel, + master_addr=None, + enable_dp_attention=False, + plugin_config=plugin_config, + ) + + +def _generate_atom_config_from_sglang_config(config: Any): + from sglang.srt.server_args import ( + ServerArgs, + prepare_server_args, + PortArgs, + ) + from sglang.srt.configs.model_config import ModelConfig as SglangModelConfig + from sglang.srt.configs.modelopt_config import ModelOptConfig + from sglang.srt.configs.load_config import LoadConfig + from atom.config import Config, ParallelConfig, CompilationConfig + + # sglang has no global config variable 
like vllm, + # so here construct the server args from sys.argv passed by users + # this is the only way to get full arguments + server_args: ServerArgs = prepare_server_args(sys.argv[1:]) + + sgl_model_config = SglangModelConfig.from_server_args(server_args) + sgl_model_opt_config = ModelOptConfig( + quant=server_args.modelopt_quant, + checkpoint_restore_path=server_args.modelopt_checkpoint_restore_path, + checkpoint_save_path=server_args.modelopt_checkpoint_save_path, + export_path=server_args.modelopt_export_path, + ) + + sgl_load_config = LoadConfig( + load_format=server_args.load_format, + download_dir=server_args.download_dir, + model_loader_extra_config=server_args.model_loader_extra_config, + remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip, + remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port, + remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports, + remote_instance_weight_loader_backend=server_args.remote_instance_weight_loader_backend, + modelopt_config=sgl_model_opt_config, + rl_quant_profile=server_args.rl_quant_profile, + ) + + # sglang doesn't passed the rank number in config, so ATOM plugin + # get rank number through the torch.distributed.get_rank() + rank = torch.distributed.get_rank() + + # sglang uses the atom parallel config + sgl_parallel_config = ParallelConfig( + data_parallel_size=server_args.dp_size, + ) + + # use sglang torch compile policy and cuda graph policy + # because sglang doesn't use the compile decorator for model, + # we have no method to define self policy + sgl_compilation_config = CompilationConfig( + level=0, + use_cudagraph=False, + cudagraph_mode=None, + ) + + plugin_config = PluginConfig( + # common config + model_config=sgl_model_config, + rank=rank, + is_plugin_mode=True, + is_vllm=False, + is_sglang=True, + # sglang 
specific + sglang_model_opt_config=sgl_model_opt_config, + sglang_load_config=sgl_load_config, + sglang_enable_torch_compile=server_args.enable_torch_compile, + sglang_disable_cuda_graph=server_args.disable_cuda_graph, + sglang_enable_dp_attention=server_args.enable_dp_attention, + sglang_dist_init_addr=server_args.dist_init_addr, + sglang_port_args=PortArgs.init_new(server_args), + ) + + # force max num batched tokens to 16K because sgl doesn't have + # concept for max num batched tokens + return Config( + model=None, + max_num_batched_tokens=16384, + max_num_seqs=server_args.max_running_requests, + max_model_len=server_args.context_length, + gpu_memory_utilization=server_args.mem_fraction_static, + tensor_parallel_size=server_args.tp_size, + enforce_eager=True, # disable using atom cuda graph + parallel_config=sgl_parallel_config, + kv_cache_dtype=server_args.kv_cache_dtype, + enable_prefix_caching=False, + port=None, + torch_profiler_dir=None, + compilation_config=sgl_compilation_config, + asyncio_mode=False, + load_dummy=False, + enable_expert_parallel=bool(server_args.ep_size > 1), + master_addr=None, + enable_dp_attention=server_args.enable_dp_attention, + plugin_config=plugin_config, + ) + + +def generate_atom_config_for_plugin_mode(config: Any = None): + """ + Generate the atom config in plugin mode, be called when create the custom model + config: + - for vllm: config is VllmConfig and contains all config value from vllm + - for sglang: config is only model specific config passed from sglang, so the + server args is used + """ + + logger.info("Generate atom config for plugin mode from passed config") + + atom_config = None + from atom.plugin import is_vllm, is_sglang + from atom.config import set_current_atom_config + + if is_vllm(): + atom_config = _generate_atom_config_from_vllm_config(config) + elif is_sglang(): + atom_config = _generate_atom_config_from_sglang_config(config) + else: + raise ValueError( + "Make sure ATOM is running in plugin mode; " + 
+            "generate_atom_config_for_plugin_mode should be called in plugin mode."
+        )
+
+    # set the current atom config for the custom model
+    set_current_atom_config(atom_config)
+
+    return atom_config
diff --git a/atom/plugin/moe.py b/atom/plugin/moe.py
new file mode 100644
index 000000000..f0ed21ff0
--- /dev/null
+++ b/atom/plugin/moe.py
@@ -0,0 +1,57 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+from typing import Type, TypeVar
+import logging
+from atom.plugin.prepare import is_vllm
+
+logger = logging.getLogger("atom")
+
+T = TypeVar("T")
+
+
+def _apply_moe_decoration(original_cls: Type[T]) -> Type[T]:
+    """
+    Apply the actual decoration to the MoE class.
+    This is called lazily during instantiation.
+    """
+    is_vllm_mode = is_vllm()
+    if is_vllm_mode:
+        # Rename the class so vLLM no longer sees it as "FusedMoE":
+        # vLLM calls the modular-kernel init method on every module whose
+        # class name is FusedMoE, but in plugin mode ATOM manages the
+        # kernel lifecycle itself, so that init must not run on the vLLM
+        # side; renaming to "ATOMFusedMoE" opts this class out of the hook.
+        original_cls.__name__ = "ATOMFusedMoE"
+        original_cls.__qualname__ = "ATOMFusedMoE"
+
+    return original_cls
+
+
+def FusedMoEDecoratorForPluginMode(cls: Type[T]) -> Type[T]:
+    """
+    Lazy decorator that defers the class modification until first instantiation.
+    """
+    original_cls = cls
+    decorated_cls_cache = {"value": None}
+
+    def get_decorated_class():
+        if decorated_cls_cache["value"] is not None:
+            return decorated_cls_cache["value"]
+
+        decorated = _apply_moe_decoration(original_cls)
+        decorated_cls_cache["value"] = decorated
+        return decorated
+
+    class LazyMoEWrapper(original_cls):
+        def __new__(cls, *args, **kwargs):
+            decorated_cls = get_decorated_class()
+            return decorated_cls(*args, **kwargs)
+
+    # Preserve the original class name and module for the wrapper
+    LazyMoEWrapper.__name__ = original_cls.__name__
+    LazyMoEWrapper.__qualname__ = original_cls.__qualname__
+    LazyMoEWrapper.__module__ = original_cls.__module__
+
+    logger.info("Create lazy wrapper for FusedMoE to change the naming")
+    return LazyMoEWrapper
diff --git a/atom/plugin/prepare.py b/atom/plugin/prepare.py
new file mode 100644
index 000000000..6b3f80b13
--- /dev/null
+++ b/atom/plugin/prepare.py
@@ -0,0 +1,89 @@
+from typing import Any
+import logging
+
+logger = logging.getLogger("atom")
+
+# all of the supported frameworks, covering both server mode and plugin mode
+_SUPPORTED_FRAMEWORKS = ["vllm", "sglang", "sgl", "atom"]
+
+# frameworks that can host ATOM as a plugin
+_SUPPORTED_FRAMEWORKS_FOR_PLUGIN_MODE = ["vllm", "sglang", "sgl"]
+
+# default is "atom", i.e. standalone server mode
+_CURRENT_FRAMEWORK = "atom"
+
+
+def is_sglang() -> bool:
+    global _CURRENT_FRAMEWORK
+    return bool(_CURRENT_FRAMEWORK.lower() in ["sglang", "sgl"])
+
+
+def is_vllm() -> bool:
+    global _CURRENT_FRAMEWORK
+    return bool(_CURRENT_FRAMEWORK.lower() in ["vllm"])
+
+
+def is_plugin_mode() -> bool:
+    global _CURRENT_FRAMEWORK
+    return bool(_CURRENT_FRAMEWORK.lower() in _SUPPORTED_FRAMEWORKS_FOR_PLUGIN_MODE)
+
+
+def _set_framework_backbone(framework: str) -> None:
+    if framework.lower() not in _SUPPORTED_FRAMEWORKS:
+        raise ValueError(f"Unsupported framework {framework} for ATOM to plug in")
+    global _CURRENT_FRAMEWORK
+    _CURRENT_FRAMEWORK = framework
+
+
+def prepare_model(config: Any, engine: str):
+    """
+    Prepare an ATOM model for the upper framework: register custom
+    ops and initialize the aiter distributed environment.
+    """
+    logger.info(f"Prepare model for plugin mode, the upper engine is {engine}")
+
+    _set_framework_backbone(engine)
+
+    # each engine passes a differently shaped config object
+    if is_vllm():
+        # FIXME: remove the legacy code here
+        # model_arch = config.model_config.architectures[0]
+        raise NotImplementedError("VLLM will not be supported for now")
+    elif is_sglang():
+        model_arch = config.architectures[0]
+
+    # import here to avoid partial module initialization
+    from .register import (
+        _ATOM_SUPPORTED_MODELS,
+        # register_ops_to_vllm,
+        register_ops_to_sglang,
+        init_aiter_dist,
+        set_attn_cls,
+    )
+
+    if model_arch not in _ATOM_SUPPORTED_MODELS:
+        supported_archs = list(_ATOM_SUPPORTED_MODELS.keys())
+        raise ValueError(
+            f"ATOM does not support the required model architecture: {model_arch}. "
+            f"For now supported model architectures: {supported_archs}"
+        )
+
+    from atom.plugin.config import generate_atom_config_for_plugin_mode
+
+    atom_config = generate_atom_config_for_plugin_mode(config)
+
+    model_cls = _ATOM_SUPPORTED_MODELS[model_arch]
+    logger.info(f"ATOM model class for {model_arch} is {model_cls}")
+
+    if is_vllm():
+        # register_ops_to_vllm(atom_config=atom_config)
+        raise NotImplementedError("VLLM will not be supported for now")
+    elif is_sglang():
+        register_ops_to_sglang(atom_config=atom_config)
+
+    set_attn_cls()
+
+    # initialize aiter dist so aiter custom collective ops can be used
+    init_aiter_dist(config=atom_config)
+
+    return model_cls(atom_config=atom_config)
diff --git a/atom/plugin/register.py b/atom/plugin/register.py
new file mode 100644
index 000000000..cd4f6f9f0
--- /dev/null
+++ b/atom/plugin/register.py
@@ -0,0 +1,97 @@
+import logging
+
+from atom.models.qwen3 import Qwen3ForCausalLM
+from atom.models.qwen3_moe import Qwen3MoeForCausalLM
+from atom.config import Config
+from atom.plugin.prepare import is_vllm, is_sglang
+
+logger = logging.getLogger("atom")
+
+_ATOM_SUPPORTED_MODELS = {
+    "Qwen3ForCausalLM": Qwen3ForCausalLM,
+    "Qwen3MoeForCausalLM": Qwen3MoeForCausalLM,
+}
+
+
+def _register_custom_attention_to_sglang() -> None:
+
+    from sglang.srt.layers.attention.attention_registry import (
+        register_attention_backend,
+    )
+
+    logger.info("Register custom attention backend to SGLang")
+
+    # Register the custom attention backend under the name "aiter":
+    # sglang only accepts a fixed, in-tree set of attention backend
+    # names, so the ATOM backend reuses the existing "aiter" name.
+    @register_attention_backend("aiter")
+    def create_atom_backend(runner):
+        from atom.plugin.sglang.sgl_attn_backend import ATOMAttnBackendForSGLPluginMode
+
+        return ATOMAttnBackendForSGLPluginMode(runner)
+
+
+def register_ops_to_sglang(atom_config: Config) -> None:
+    """
+    Register ATOM custom ops (currently attention) with sglang.
+    """
+    _register_custom_attention_to_sglang()
+
+
+def set_attn_cls() -> None:
+    """
+    Select the attention class used to construct the model, per framework.
+    """
+    import atom.model_ops as ops
+
+    if is_vllm():
+        ops.ATTN_CLS = ops.PagedAttention
+        logger.info("Set ATTN_CLS to PagedAttention for vLLM")
+    elif is_sglang():
+        ops.ATTN_CLS = ops.RadixAttention
+        logger.info("Set ATTN_CLS to RadixAttention for SGLang")
+
+
+def init_aiter_dist(config: Config) -> None:
+    """
+    Initialize the aiter distributed env for aiter custom collective ops.
+    """
+    logger.info(
+        "Initialize aiter dist for using aiter custom collective op for plugin mode"
+    )
+
+    from aiter import init_dist_env
+    from aiter.dist.utils import get_distributed_init_method
+
+    rank = config.plugin_config.rank
+    tensor_parallel_size = config.tensor_parallel_size
+
+    assert (
+        config.plugin_config.is_plugin_mode
+    ), "Make sure ATOM is running in plugin mode"
+
+    if config.plugin_config.is_vllm:
+        dp_master_ip = config.parallel_config.data_parallel_master_ip
+        dp_master_port = config.parallel_config.data_parallel_master_port
+    elif config.plugin_config.is_sglang:
+        if config.plugin_config.sglang_dist_init_addr is not None:
+            # NOTE(review): split(":") yields a str port here, while the
+            # nccl_port fallback below is presumably an int — confirm that
+            # get_distributed_init_method accepts both representations.
+            dp_master_ip, dp_master_port = (
+                config.plugin_config.sglang_dist_init_addr.split(":")
+            )
+        else:
+            dp_master_ip = "127.0.0.1"
+            dp_master_port = config.plugin_config.sglang_port_args.nccl_port
+
+    distributed_init_method = get_distributed_init_method(dp_master_ip, dp_master_port)
+
+    logger.info(
+        f"Initialize aiter dist for using aiter custom collective op for plugin mode, rank:{rank}"
+    )
+    init_dist_env(
+        tensor_model_parallel_size=tensor_parallel_size,
+        rankID=rank,
+        backend="nccl",
+        distributed_init_method=distributed_init_method,
+ data_parallel_size=config.parallel_config.data_parallel_size, + data_parallel_rank=config.parallel_config.data_parallel_rank, + ) diff --git a/atom/plugin/sglang/__init__.py b/atom/plugin/sglang/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/atom/plugin/sglang/sgl_attn_backend.py b/atom/plugin/sglang/sgl_attn_backend.py new file mode 100644 index 000000000..2b17b5acf --- /dev/null +++ b/atom/plugin/sglang/sgl_attn_backend.py @@ -0,0 +1,1116 @@ +from __future__ import annotations + +""" +end to end attention solution with aiter kernels +""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +import torch + +from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend +from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode + +if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention + from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.speculative.spec_info import SpecInput + +try: + from aiter import ( + flash_attn_varlen_func, + dtypes, + get_pa_metadata_info_v1, + get_pa_metadata_v1, + pa_fwd_asm, + pa_persistent_fwd, + ) +except ImportError: + print( + "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device." 
+ ) + +import triton +import triton.language as tl + + +@triton.jit +def reshape_and_cache_shuffle_kernel( + key_ptr, # [num_tokens, num_kv_heads, head_size] + value_ptr, # [num_tokens, num_kv_heads, head_size] + key_cache_ptr, # [num_blocks, num_kv_heads, head_size // x, block_size, x] + value_cache_ptr, # [num_blocks, num_kv_heads, block_size // x, head_size, x] + slot_mapping_ptr, # [num_tokens] + k_scale_ptr, + v_scale_ptr, + x, + k_stride0, + v_stride0, + block_size, + head_size, + num_kv_heads, + BLOCK_SIZE: tl.constexpr, + QUANT: tl.constexpr, +): + tid = tl.program_id(0) + head_id = tl.program_id(1) + offset = tl.arange(0, BLOCK_SIZE) + src_offset_k = tid * k_stride0 + head_id * head_size + src_offset_v = tid * v_stride0 + head_id * head_size + slot_id = tl.load(slot_mapping_ptr + tid) + if slot_id < 0: + return + block_id = slot_id // block_size + block_offset = slot_id % block_size + dst_offset = ( + block_id * num_kv_heads * head_size * block_size + + head_id * head_size * block_size + ) + dst_k_shuffle_offset = ( + dst_offset + offset // x * block_size * x + block_offset * x + offset % x + ) + dst_v_shuffle_offset = ( + dst_offset + block_offset // x * head_size * x + offset * x + block_offset % x + ) + k_val = tl.load(key_ptr + src_offset_k + offset) + v_val = tl.load(value_ptr + src_offset_v + offset) + if QUANT: + k_scale = tl.load(k_scale_ptr) + v_scale = tl.load(v_scale_ptr) + k_dtype = key_cache_ptr.type.element_ty + v_dtype = value_cache_ptr.type.element_ty + k_val = (k_val.to(tl.float32) / k_scale).to(k_dtype) + v_val = (v_val.to(tl.float32) / v_scale).to(v_dtype) + tl.store(key_cache_ptr + dst_k_shuffle_offset, k_val) + tl.store(value_cache_ptr + dst_v_shuffle_offset, v_val) + + +def reshape_and_cache_shuffle_triton( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scales: torch.Tensor, + v_scales: torch.Tensor, +): + num_tokens = 
slot_mapping.shape[0] + _, num_kv_heads, head_size = key.shape + num_blocks, block_size, _, _ = key_cache.shape + x = 16 // key_cache.element_size() + k_cache_template = torch.empty( + [num_blocks, num_kv_heads, head_size // x, block_size, x], + dtype=key_cache.dtype, + device="meta", + ) + v_cache_template = torch.empty( + [num_blocks, num_kv_heads, block_size // x, head_size, x], + dtype=value_cache.dtype, + device="meta", + ) + new_key_cache = key_cache.view_as(k_cache_template) + new_value_cache = value_cache.view_as(v_cache_template) + QUANT = False + if kv_cache_dtype.startswith("fp8"): + QUANT = True + grid = ( + num_tokens, + num_kv_heads, + ) + reshape_and_cache_shuffle_kernel[grid]( + key, + value, + new_key_cache, + new_value_cache, + slot_mapping, + k_scales, + v_scales, + x, + key.stride(0), + value.stride(0), + block_size, + head_size, + num_kv_heads, + BLOCK_SIZE=head_size, + QUANT=QUANT, + ) + + +@dataclass +class ForwardMetadata: + # kv_indptr and kv_indices are only used in MLA mode, optional for non-MLA mode + kv_indptr: Optional[torch.Tensor] + kv_indices: Optional[torch.Tensor] + qo_indptr: Optional[torch.Tensor] + kv_last_page_len: Optional[torch.Tensor] + max_q_len: Optional[int] + max_kv_len: Optional[int] + page_table: Optional[torch.Tensor] + kv_lens: Optional[torch.Tensor] + # PA metadata for pa_persistent_fwd (only used in decode mode, non-MLA) + pa_metadata_qo_indptr: Optional[torch.Tensor] = None + pa_metadata_pages_kv_indptr: Optional[torch.Tensor] = None + pa_metadata_kv_indices: Optional[torch.Tensor] = None + pa_metadata_context_lens: Optional[torch.Tensor] = None + pa_metadata_max_qlen: Optional[int] = None + pa_metadata_tp_q_head_num: Optional[int] = None + # Prefill metadata for mha_batch_prefill_func (only used in prefill mode, non-MLA) + # prefill_pages_kv_indptr: Optional[torch.Tensor] = None + # prefill_kv_indices: Optional[torch.Tensor] = None + # prefill_kv_last_page_lens: Optional[torch.Tensor] = None + + +class 
ATOMAttnBackendForSGLPluginMode(AiterAttnBackend): + def __init__( + self, + model_runner: ModelRunner, + skip_prefill: bool = False, + kv_indptr_buf: Optional[torch.Tensor] = None, + ): + super().__init__(model_runner, skip_prefill, kv_indptr_buf) + + mapping = getattr( + model_runner.token_to_kv_pool, "full_attention_layer_id_mapping", None + ) + + if isinstance(mapping, dict) and mapping: + first_full_attn_id = next(iter(mapping.keys())) + else: + first_full_attn_id = 0 + + self.q_dtype = model_runner.dtype # Save q dtype for pa_metadata building + + assert ( + not self.use_mla + ), "MLA mode is not implemented yet in ATOMAttnBackendForSGLPluginMode." + + # Pre-initialized qo_indptr for pa_persistent_fwd decode mode: [0, 1, 2, ..., max_bs] + # In decode mode, each sequence has 1 token, so this is always [0, 1, 2, ..., batch_size] + max_bs = model_runner.req_to_token_pool.size + self.pa_decode_qo_indptr = torch.arange( + 0, max_bs + 1, dtype=torch.int32, device=model_runner.device + ) + self.seq_lens = torch.zeros( + (max_bs,), dtype=torch.int32, device=model_runner.device + ) + self.page_table = torch.zeros( + (max_bs, self.max_context_len // self.page_size), + dtype=torch.int32, + device=model_runner.device, + ) + # Pre-compute strided indices for page_table construction (used in both CUDA Graph and non-CUDA Graph modes) + self.strided_indices = torch.arange( + 0, self.max_context_len, self.page_size, device=model_runner.device + ) + + if not self.use_mla: + # Pre-allocate buffers for pa_persistent_fwd (used in both CUDA graph and non-CUDA graph modes) + max_num_blocks_per_seq = ( + self.max_context_len + self.page_size - 1 + ) // self.page_size + max_total_blocks = max_bs * max_num_blocks_per_seq + self.pa_kv_indices = torch.zeros( + max_total_blocks, dtype=torch.int32, device=self.device + ) + # Pre-allocate pa_kv_indptr buffer (similar to self.kv_indptr, but dedicated for pa_persistent_fwd) + self.pa_kv_indptr = torch.zeros( + (max_bs + 1,), 
dtype=torch.int32, device=self.device + ) + # Pre-initialized batch indices [0, 1, 2, ..., max_bs-1] for Triton kernel + self.pa_batch_indices = torch.arange( + 0, max_bs, dtype=torch.int32, device=self.device + ) + + # Pre-allocated descale tensors for FP8 attention (q, k, v all use scale=1.0) + + self.logits_soft_cap = 0.0 + + self.forward_metadata: ForwardMetadata = None + + self.pa_metadata_buffers = None + + k_buffer, _ = model_runner.token_to_kv_pool.get_kv_buffer(first_full_attn_id) + num_slots, num_kv_heads, _ = k_buffer.shape + block_size = self.page_size + num_blocks = num_slots // block_size + max_total_tokens = num_blocks * block_size + self.k_qscale = torch.ones( + num_kv_heads, max_total_tokens, dtype=torch.float32, device=self.device + ) + self.v_qscale = torch.ones( + num_kv_heads, max_total_tokens, dtype=torch.float32, device=self.device + ) + self.decode_using_pa_ps = self.page_size == 1024 + + def init_forward_metadata(self, forward_batch: ForwardBatch): + """Init auxiliary variables for triton attention backend.""" + + bs = forward_batch.batch_size + kv_indptr = self.kv_indptr + spec_info = forward_batch.spec_info + # qo_indptr = None + # kv_last_page_len = None + # max_q_len = None + page_table = None + + if forward_batch.forward_mode.is_decode_or_idle(): + if spec_info is None: + kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.empty( + forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + else: + kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices + bs = kv_indptr.shape[0] - 1 + + if self.use_mla: + raise NotImplementedError( + "MLA decode mode is not implemented yet in ATOMAttnBackendForSGLPluginMode." 
+ ) + else: + if self.decode_using_pa_ps: + # Non-MLA decode mode: use same logic as CUDA Graph mode for page_table construction + seq_lens_cpu = forward_batch.seq_lens_cpu + if seq_lens_cpu is None: + seq_lens_cpu = forward_batch.seq_lens.cpu() + + # Common setup consistent with CUDA Graph mode (init_forward_metadata_replay_cuda_graph) + page_table_persistent = self.page_table + seq_lens_persistent = self.seq_lens + seq_lens_persistent.fill_(0) + page_table_persistent.fill_(0) + seq_lens_persistent[:bs].copy_( + forward_batch.seq_lens, non_blocking=True + ) + max_seq_pages = ( + seq_lens_cpu.max().item() + self.page_size - 1 + ) // self.page_size + 1 + page_table = self.req_to_token[ + forward_batch.req_pool_indices[:, None], + self.strided_indices[:max_seq_pages][None, :], + ] + page_table_persistent[:bs, :max_seq_pages].copy_( + page_table // self.page_size, non_blocking=True + ) + else: + page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : + ] + + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + None, # qo_indptr not used in non-MLA mode + None, # kv_last_page_len not used in non-MLA mode + 1, # max_q_len = 1 for decode mode + None, + ( + page_table_persistent[:bs, :max_seq_pages] + if self.decode_using_pa_ps + else page_table + ), + ( + seq_lens_persistent[:bs] + if self.decode_using_pa_ps + else forward_batch.seq_lens + ), + ) + + # Build pa_metadata for pa_persistent_fwd + if self.decode_using_pa_ps: + self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) + # return # Early return for non-MLA decode mode + else: + prefix_lens = forward_batch.extend_prefix_lens + + if self.use_mla: + raise NotImplementedError( + "MLA prefill mode is not implemented yet in ATOMAttnBackendForSGLPluginMode." 
+ ) + else: + self.indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + prefix_lens, + encoder_lens=forward_batch.encoder_lens, + spec_info=None, + ) + # Get page_table for mha_batch_prefill_func + page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : + ] + self.forward_metadata = ForwardMetadata( + self.indices_updater_prefill.kv_indptr, + self.indices_updater_prefill.kv_indices, + self.qo_indptr[ + : bs + 1 + ], # qo_indptr is set by indices_updater_prefill + None, + self.indices_updater_prefill.max_q_len, + self.indices_updater_prefill.max_kv_len, + None, + forward_batch.seq_lens, + ) + + if ( + forward_batch.forward_mode.is_extend() + and not self.use_mla + and self.forward_metadata.page_table is not None + ): + if self.page_size > 1: + seq_lens_cpu = forward_batch.seq_lens_cpu + if seq_lens_cpu is None: + seq_lens_cpu = forward_batch.seq_lens.cpu() + max_seq_pages = ( + seq_lens_cpu.max().item() + self.page_size - 1 + ) // self.page_size + 1 + self.forward_metadata.page_table = ( + self.forward_metadata.page_table[ + :, self.strided_indices[:max_seq_pages] + ] + // self.page_size + ) + if self.decode_using_pa_ps: + self._build_pa_metadata_for_prefill(forward_batch.batch_size) + if ( + not self.decode_using_pa_ps + and self.page_size > 1 + and self.forward_metadata.page_table is not None + ): + self.forward_metadata.page_table = ( + self.forward_metadata.page_table[:, self.strided_indices] + // self.page_size + ) + + def _allocate_pa_metadata_buffers( + self, + work_metadata_ptrs_size, + work_metadata_ptrs_type, + work_indptr_size, + work_indptr_type, + work_info_size, + work_info_type, + reduce_indptr_size, + reduce_indptr_type, + reduce_final_map_size, + reduce_final_map_type, + reduce_partial_map_size, + reduce_partial_map_type, + ): + """Allocate or reuse pa_metadata buffers.""" + if self.pa_metadata_buffers is None: + self.pa_metadata_buffers = 
{} + + def _get_size_val(size): + return size[0] if isinstance(size, tuple) else size + + # Allocate work_metadata_ptrs + size_val = _get_size_val(work_metadata_ptrs_size) + if ( + "work_metadata_ptrs" not in self.pa_metadata_buffers + or self.pa_metadata_buffers["work_metadata_ptrs"].shape[0] < size_val + ): + self.pa_metadata_buffers["work_metadata_ptrs"] = torch.empty( + work_metadata_ptrs_size, + dtype=work_metadata_ptrs_type, + device=self.device, + ) + + # Allocate work_indptr + size_val = _get_size_val(work_indptr_size) + if ( + "work_indptr" not in self.pa_metadata_buffers + or self.pa_metadata_buffers["work_indptr"].shape[0] < size_val + ): + self.pa_metadata_buffers["work_indptr"] = torch.zeros( + work_indptr_size, dtype=work_indptr_type, device=self.device + ) + else: + self.pa_metadata_buffers["work_indptr"].zero_() + + # Allocate work_info + size_val = _get_size_val(work_info_size) + if ( + "work_info" not in self.pa_metadata_buffers + or len(self.pa_metadata_buffers["work_info"].shape) < len(work_info_size) + or self.pa_metadata_buffers["work_info"].shape[0] < size_val + ): + self.pa_metadata_buffers["work_info"] = torch.zeros( + work_info_size, dtype=work_info_type, device=self.device + ) + else: + self.pa_metadata_buffers["work_info"].zero_() + + # Allocate reduce_indptr + size_val = _get_size_val(reduce_indptr_size) + if ( + "reduce_indptr" not in self.pa_metadata_buffers + or self.pa_metadata_buffers["reduce_indptr"].shape[0] < size_val + ): + self.pa_metadata_buffers["reduce_indptr"] = torch.zeros( + reduce_indptr_size, dtype=reduce_indptr_type, device=self.device + ) + else: + self.pa_metadata_buffers["reduce_indptr"].zero_() + + # Allocate reduce_final_map + size_val = _get_size_val(reduce_final_map_size) + if ( + "reduce_final_map" not in self.pa_metadata_buffers + or len(self.pa_metadata_buffers["reduce_final_map"].shape) + < len(reduce_final_map_size) + or self.pa_metadata_buffers["reduce_final_map"].shape[0] < size_val + ): + 
self.pa_metadata_buffers["reduce_final_map"] = torch.zeros( + reduce_final_map_size, dtype=reduce_final_map_type, device=self.device + ) + else: + self.pa_metadata_buffers["reduce_final_map"].zero_() + + # Allocate reduce_partial_map + reduce_partial_map_size_val = ( + reduce_partial_map_size + if isinstance(reduce_partial_map_size, int) + else reduce_partial_map_size[0] + ) + if ( + "reduce_partial_map" not in self.pa_metadata_buffers + or self.pa_metadata_buffers["reduce_partial_map"].shape[0] + < reduce_partial_map_size_val + ): + self.pa_metadata_buffers["reduce_partial_map"] = torch.zeros( + reduce_partial_map_size, + dtype=reduce_partial_map_type, + device=self.device, + ) + else: + self.pa_metadata_buffers["reduce_partial_map"].zero_() + + def _build_pa_metadata_for_decode( + self, + batch_size: int, + tp_q_head_num: Optional[int] = None, + ): + """Build pa_metadata buffers for pa_persistent_fwd in decode mode. + + This method prepares all metadata buffers needed for pa_persistent_fwd kernel. + The metadata can be reused across multiple layers in the same forward pass. + + Args: + batch_size: Batch size for the current forward pass + tp_q_head_num: Number of Q heads per TP rank. If None, uses self.num_head. 
+ """ + max_qlen = 1 + + # Use provided tp_q_head_num or default to self.num_head + if tp_q_head_num is None: + tp_q_head_num = self.num_head + + # kv_dtype_for_metadata = dtypes.fp8 + ( + (work_metadata_ptrs_size, work_metadata_ptrs_type), + (work_indptr_size, work_indptr_type), + (work_info_size, work_info_type), + (reduce_indptr_size, reduce_indptr_type), + (reduce_final_map_size, reduce_final_map_type), + (reduce_partial_map_size, reduce_partial_map_type), + ) = get_pa_metadata_info_v1( + batch_size, + self.num_kv_head, + ) + # Allocate metadata buffers with reuse optimization for multi-layer forward passes + self._allocate_pa_metadata_buffers( + work_metadata_ptrs_size, + work_metadata_ptrs_type, + work_indptr_size, + work_indptr_type, + work_info_size, + work_info_type, + reduce_indptr_size, + reduce_indptr_type, + reduce_final_map_size, + reduce_final_map_type, + reduce_partial_map_size, + reduce_partial_map_type, + ) + qo_indptr = self.pa_decode_qo_indptr[: batch_size + 1] + + # Get context_lens (kv_lens is always set before calling _build_pa_metadata_for_decode) + # Note: kv_lens comes from self.seq_lens which is already int32 + context_lens = self.forward_metadata.kv_lens + + kernel_block_size = self.page_size + num_blocks_per_seq = (context_lens + kernel_block_size - 1) // kernel_block_size + # Use dedicated pa_kv_indptr buffer (similar to self.kv_indptr, but for pa_persistent_fwd) + pages_kv_indptr = self.pa_kv_indptr[: batch_size + 1] + pages_kv_indptr[1 : batch_size + 1] = torch.cumsum(num_blocks_per_seq, dim=0) + + # Convert page_table to kv_indices (block indices) using Triton kernel to avoid sync + # page_table shape: [batch_size, max_num_blocks_per_seq] + # Note: page_table comes from self.page_table which is already int32 and always set before this call + page_table = self.forward_metadata.page_table + + # Use Triton kernel to gather kv_indices from page_table (avoids high-level indexing sync) + create_flashinfer_kv_indices_triton[(batch_size,)]( 
+ page_table, + self.pa_batch_indices[:batch_size], # [0, 1, 2, ..., batch_size-1] + num_blocks_per_seq, + pages_kv_indptr, + None, # kv_start_idx + self.pa_kv_indices, + page_table.stride(0), + ) + # Use the full buffer - pa_persistent_fwd reads only valid elements based on pages_kv_indptr + kv_indices = self.pa_kv_indices + + get_pa_metadata_v1( + seqlens_qo_indptr=qo_indptr, + pages_kv_indptr=pages_kv_indptr, + context_lens=context_lens.int(), + num_heads_per_head_k=tp_q_head_num // self.num_kv_head, + num_heads_k=self.num_kv_head, + is_causal=True, + work_metadata_ptrs=self.pa_metadata_buffers["work_metadata_ptrs"], + work_indptr=self.pa_metadata_buffers["work_indptr"], + work_info=self.pa_metadata_buffers["work_info"], + reduce_indptr=self.pa_metadata_buffers["reduce_indptr"], + reduce_final_map=self.pa_metadata_buffers["reduce_final_map"], + reduce_partial_map=self.pa_metadata_buffers["reduce_partial_map"], + kv_granularity=max(kernel_block_size, 16), + block_size=kernel_block_size, + max_seqlen_qo=max_qlen, + uni_seqlen_qo=max_qlen, + fast_mode=True, + topk=-1, + max_split_per_batch=-1, + ) + # Store computed values in ForwardMetadata for reuse in forward_decode + self.forward_metadata.pa_metadata_qo_indptr = qo_indptr + self.forward_metadata.pa_metadata_pages_kv_indptr = pages_kv_indptr + self.forward_metadata.pa_metadata_kv_indices = kv_indices + self.forward_metadata.pa_metadata_context_lens = context_lens + self.forward_metadata.pa_metadata_max_qlen = max_qlen + self.forward_metadata.pa_metadata_tp_q_head_num = tp_q_head_num + + def _build_pa_metadata_for_prefill(self, batch_size: int): + """Build metadata for mha_batch_prefill_func in prefill mode. + + This method prepares page-level metadata needed for mha_batch_prefill_func. + The metadata is computed once per forward pass and reused across all layers. 
+ """ + block_size = self.page_size + context_lens = self.forward_metadata.kv_lens + num_blocks_per_seq = (context_lens + block_size - 1) // block_size + + # Page-level kv_indptr (reuse pa_kv_indptr buffer) + pages_kv_indptr = self.pa_kv_indptr[: batch_size + 1] + pages_kv_indptr[1 : batch_size + 1] = torch.cumsum(num_blocks_per_seq, dim=0) + + # Build kv_indices from page_table using triton kernel + page_table = self.forward_metadata.page_table + create_flashinfer_kv_indices_triton[(batch_size,)]( + page_table, + self.pa_batch_indices[:batch_size], + num_blocks_per_seq, + pages_kv_indptr, + None, # kv_start_idx + self.pa_kv_indices, + page_table.stride(0), + ) + # kv_indices = self.pa_kv_indices + + # Compute kv_last_page_lens for each sequence + # kv_last_page_lens = ((context_lens - 1) % block_size + 1).int() + + # Store in ForwardMetadata for reuse in forward_extend + # self.forward_metadata.prefill_pages_kv_indptr = pages_kv_indptr + # self.forward_metadata.prefill_kv_indices = kv_indices + # self.forward_metadata.prefill_kv_last_page_lens = kv_last_page_lens + + def init_cuda_graph_state( + self, + max_bs: int, + max_num_tokens: int, + kv_indices_buf: Optional[torch.Tensor] = None, + ): + self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int) + if kv_indices_buf is None: + self.cuda_graph_kv_indices = torch.zeros( + (max_bs * self.max_context_len), + dtype=torch.int32, + device=self.device, + ) + else: + self.cuda_graph_kv_indices = kv_indices_buf + + # Always use preshuffle layout for pa_fwd_asm + self.page_table = torch.zeros( + (max_bs, self.max_context_len // self.page_size), + dtype=torch.int32, + device=self.device, + ) + self.seq_lens = torch.zeros((max_bs,), dtype=torch.int32, device=self.device) + self.strided_indices = torch.arange( + 0, self.max_context_len, self.page_size, device=self.device + ) + + # Pre-allocate buffers for pa_metadata in CUDA graph mode (non-MLA decode) + if self.decode_using_pa_ps and not self.use_mla: + # 
Pre-allocate pa_metadata buffers for CUDA graph compatibility + # These buffers will be reused in capture and replay phases + # Use max_bs and max_qlen=1 (decode mode) to calculate buffer sizes + # max_qlen = 1 # decode mode + # kv_dtype_for_metadata = dtypes.fp8 + ( + (work_metadata_ptrs_size, work_metadata_ptrs_type), + (work_indptr_size, work_indptr_type), + (work_info_size, work_info_type), + (reduce_indptr_size, reduce_indptr_type), + (reduce_final_map_size, reduce_final_map_type), + (reduce_partial_map_size, reduce_partial_map_type), + ) = get_pa_metadata_info_v1( + max_bs, + self.num_kv_head, + ) + + # Pre-allocate buffers with maximum size for CUDA graph compatibility + self._allocate_pa_metadata_buffers( + work_metadata_ptrs_size, + work_metadata_ptrs_type, + work_indptr_size, + work_indptr_type, + work_info_size, + work_info_type, + reduce_indptr_size, + reduce_indptr_type, + reduce_final_map_size, + reduce_final_map_type, + reduce_partial_map_size, + reduce_partial_map_type, + ) + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInput], + ): + if forward_mode.is_decode_or_idle(): + if self.use_mla: + # MLA mode: kv_indptr and kv_indices are used in forward_decode + raise NotImplementedError( + "MLA decode mode is not implemented yet in ATOMAttnBackendForSGLPluginMode." 
+ ) + else: + # Non-MLA decode mode: kv_indptr and kv_indices are NOT used in forward_decode + # (forward_decode uses pa_metadata_pages_kv_indptr and pa_metadata_kv_indices instead) + page_table = self.page_table[:bs, :] + self.seq_lens[:bs].copy_(seq_lens, non_blocking=True) + seq_lens_persistent = self.seq_lens[:bs] + self.forward_metadata = ForwardMetadata( + None, # kv_indptr not used in non-MLA decode mode + None, # kv_indices not used in non-MLA decode mode + None, # qo_indptr will be set by _build_pa_metadata_for_decode + None, # kv_last_page_len not used in non-MLA mode + 1, # max_q_len = 1 for decode mode + None, # max_kv_len + page_table, + seq_lens_persistent, + ) + + # Build pa_metadata using CUDA graph buffers + if self.decode_using_pa_ps: + self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) + return # Early return for non-MLA decode mode + else: + raise ValueError(f"Invalid mode: {forward_mode=}") + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInput], + seq_lens_cpu: Optional[torch.Tensor], + out_cache_loc: Optional[torch.Tensor] = None, + ): + if forward_mode.is_decode_or_idle(): + # Common setup for both MLA and non-MLA modes + page_table_persistent = self.page_table + seq_lens_persistent = self.seq_lens + seq_lens_persistent.fill_(0) + page_table_persistent.fill_(0) + seq_lens_persistent[:bs].copy_(seq_lens, non_blocking=True) + max_seq_pages = ( + seq_lens_cpu.max().item() + self.page_size - 1 + ) // self.page_size + 1 + page_table = self.req_to_token[ + req_pool_indices[:, None], + self.strided_indices[:max_seq_pages][None, :], + ] + page_table_persistent[:bs, :max_seq_pages].copy_( + page_table // self.page_size, non_blocking=True + ) + + if self.use_mla: + # MLA mode: kv_indptr and kv_indices are used in forward_decode + raise 
NotImplementedError( + "MLA decode mode is not implemented yet in ATOMAttnBackendForSGLPluginMode." + ) + else: + # Non-MLA decode mode: kv_indptr and kv_indices are NOT used in forward_decode + # (forward_decode uses pa_metadata_pages_kv_indptr and pa_metadata_kv_indices instead) + self.forward_metadata = ForwardMetadata( + None, # kv_indptr not used in non-MLA decode mode + None, # kv_indices not used in non-MLA decode mode + None, + None, # kv_last_page_len not used in non-MLA mode + 1, # max_q_len = 1 for decode mode, non-MTP + None, # max_kv_len + page_table_persistent[:bs, :max_seq_pages], + seq_lens_persistent[:bs], + ) + + # Rebuild pa_metadata using CUDA graph buffers (updates content, keeps same addresses) + if self.decode_using_pa_ps: + self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) + else: + raise ValueError("Invalid forward mode") + + def set_kv_buffer_with_layout_shuffle( + self, + cache_loc, + k, + v, + k_buffer, + v_buffer, + k_scale, + v_scale, + block_size, + ): + num_slots, num_kv_heads, head_dim = k_buffer.shape + num_blocks = num_slots // block_size + num_slots_with_block = num_blocks * block_size + k_buffer = k_buffer[:num_slots_with_block].view( + num_blocks, block_size, num_kv_heads, head_dim + ) + v_buffer = v_buffer[:num_slots_with_block].view( + num_blocks, block_size, num_kv_heads, head_dim + ) + reshape_and_cache_shuffle_triton( + k, + v, + k_buffer, + v_buffer, + cache_loc, + "auto", + k_scale, + v_scale, + ) + + def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True): + # print(f"Running forward_extend with q shape {q.shape}, k shape {k.shape}, v shape {v.shape}", flush=True) + # print(f"q dtype: {q.dtype}, k dtype: {k.dtype}, v dtype: {v.dtype}", flush=True) + cache_loc = ( + forward_batch.out_cache_loc + if not layer.is_cross_attention + else forward_batch.encoder_out_cache_loc + ) + + self.logits_soft_cap = layer.logit_cap + + if k is not None: + assert v is not None + if save_kv_cache: + 
k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + self.set_kv_buffer_with_layout_shuffle( + cache_loc, + k, + v, + k_buffer, + v_buffer, + layer.k_scale, + layer.v_scale, + self.page_size, + ) + # forward_batch.token_to_kv_pool.set_kv_buffer( + # layer, cache_loc, k, v, layer.k_scale, layer.v_scale + # ) + + seqlens_in_batch = forward_batch.seq_lens + cu_seqlens_q = torch.nn.functional.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0) + ) + # use fp8 mha directly + if q.dtype != k.dtype and k.dtype == dtypes.fp8: + q = q.to(dtypes.fp8) + o = flash_attn_varlen_func( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim), + v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim), + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_q, + max_seqlen_q=self.forward_metadata.max_q_len, + max_seqlen_k=self.forward_metadata.max_kv_len, + min_seqlen_q=0, + dropout_p=0.0, + softmax_scale=self.scale, + causal=True, + window_size=(-1, -1, 0), + sink_ptr=None, + ) + + return o.view(-1, layer.tp_q_head_num * layer.head_dim) + + def forward_decode_pa( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): + q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) + + if layer.qk_head_dim != layer.v_head_dim: + o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) + else: + o = torch.empty_like(q) + + if save_kv_cache: + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + self.set_kv_buffer_with_layout_shuffle( + forward_batch.out_cache_loc, + k, + v, + k_buffer, + v_buffer, + layer.k_scale, + layer.v_scale, + self.page_size, + ) + + if self.use_mla: + raise NotImplementedError( + "MLA decode mode is not implemented yet in ATOMAttnBackendForSGLPluginMode." 
+ ) + else: + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + + block_size = self.page_size + num_slots, num_kv_heads, head_size = k_buffer.shape + num_blocks = num_slots // block_size + k_buffer = k_buffer[: num_blocks * block_size].view( + num_blocks, block_size, num_kv_heads, head_size + ) + v_buffer = v_buffer[: num_blocks * block_size].view( + num_blocks, block_size, num_kv_heads, head_size + ) + + x = 16 // k_buffer.element_size() + k_cache_template = torch.empty( + [num_blocks, num_kv_heads, head_size // x, block_size, x], + dtype=k_buffer.dtype, + device="meta", + ) + v_cache_template = torch.empty( + [num_blocks, num_kv_heads, block_size // x, head_size, x], + dtype=v_buffer.dtype, + device="meta", + ) + new_key_cache = k_buffer.view_as(k_cache_template) + new_value_cache = v_buffer.view_as(v_cache_template) + q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) + pa_fwd_asm( + Q=q, + K=new_key_cache, + V=new_value_cache, + block_tables=self.forward_metadata.page_table, + context_lens=self.forward_metadata.kv_lens, + block_tables_stride0=self.forward_metadata.page_table.stride(0), + K_QScale=self.k_scale, + V_QScale=self.v_scale, + out_=o, + ) + return o + + def forward_decode_pa_ps( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): + q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) + + # Create o as 3D tensor [batch_size, num_heads, head_dim] for both MLA and pa_fwd_asm + # In decode mode, q.shape[0] equals batch_size (each sequence has 1 token) + # Use q.shape[0] instead of forward_batch.batch_size to be safe + batch_size = q.shape[0] + head_dim_out = ( + layer.v_head_dim + if layer.qk_head_dim != layer.v_head_dim + else layer.head_dim + ) + o = q.new_empty((batch_size, layer.tp_q_head_num, head_dim_out)) + + if save_kv_cache: + if self.use_mla: + raise NotImplementedError( + "MLA decode mode is 
not implemented yet in ATOMAttnBackendForSGLPluginMode." + ) + else: + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + self.set_kv_buffer_with_layout_shuffle( + forward_batch.out_cache_loc, + k, + v, + k_buffer, + v_buffer, + layer.k_scale, + layer.v_scale, + self.page_size, + ) + # Shuffle operation is already fused in rotary_emb, so just save directly + # forward_batch.token_to_kv_pool.set_kv_buffer( + # layer, forward_batch.out_cache_loc, k, v, layer.k_scale, layer.v_scale + # ) + + if self.use_mla: + raise NotImplementedError( + "MLA decode mode is not implemented yet in ATOMAttnBackendForSGLPluginMode." + ) + else: + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + num_slots, num_kv_heads, head_size = k_buffer.shape + block_size = self.page_size + num_blocks = num_slots // block_size + k_buffer = k_buffer[: num_blocks * block_size].view( + num_blocks, block_size, num_kv_heads, head_size + ) + v_buffer = v_buffer[: num_blocks * block_size].view( + num_blocks, block_size, num_kv_heads, head_size + ) + + quant_dtype = dtypes.fp8 + x = 16 // quant_dtype.itemsize + k_cache_template = torch.empty( + [num_blocks, num_kv_heads, head_size // x, block_size, x], + dtype=k_buffer.dtype, + device="meta", + ) + # V: [num_blocks, block_size, num_kv_heads, head_size] -> [num_blocks, num_kv_heads, block_size // x, head_size, x] + v_cache_template = torch.empty( + [num_blocks, num_kv_heads, block_size // x, head_size, x], + dtype=v_buffer.dtype, + device="meta", + ) + new_key_cache = k_buffer.view_as(k_cache_template) + new_value_cache = v_buffer.view_as(v_cache_template) + + total_tokens = num_blocks * block_size + k_qscale = self.k_qscale[:, :total_tokens] + v_qscale = self.v_qscale[:, :total_tokens] + + q = q.view(batch_size, layer.tp_q_head_num, layer.head_dim) + + assert ( + self.forward_metadata.pa_metadata_qo_indptr is not None + ), "pa_metadata_qo_indptr should be set by 
_build_pa_metadata_for_decode" + assert ( + self.forward_metadata.pa_metadata_pages_kv_indptr is not None + ), "pa_metadata_pages_kv_indptr should be set by _build_pa_metadata_for_decode" + assert ( + self.forward_metadata.pa_metadata_kv_indices is not None + ), "pa_metadata_kv_indices should be set by _build_pa_metadata_for_decode" + assert ( + self.forward_metadata.pa_metadata_context_lens is not None + ), "pa_metadata_context_lens should be set by _build_pa_metadata_for_decode" + assert ( + self.forward_metadata.pa_metadata_max_qlen is not None + ), "pa_metadata_max_qlen should be set by _build_pa_metadata_for_decode" + + qo_indptr = self.forward_metadata.pa_metadata_qo_indptr + kv_indptr = self.forward_metadata.pa_metadata_pages_kv_indptr + kv_indices = self.forward_metadata.pa_metadata_kv_indices + context_lens = self.forward_metadata.pa_metadata_context_lens + max_qlen = self.forward_metadata.pa_metadata_max_qlen + + _, _ = pa_persistent_fwd( + Q=q, + K=new_key_cache, + V=new_value_cache, + output=o, + max_qlen=max_qlen, + qo_indptr=qo_indptr, + kv_indptr=kv_indptr, + kv_indices=kv_indices, + context_lens=context_lens, + work_indptr=self.pa_metadata_buffers["work_indptr"], + work_info=self.pa_metadata_buffers["work_info"], + reduce_indptr=self.pa_metadata_buffers["reduce_indptr"], + reduce_final_map=self.pa_metadata_buffers["reduce_final_map"], + reduce_partial_map=self.pa_metadata_buffers["reduce_partial_map"], + K_QScale=k_qscale, + V_QScale=v_qscale, + softmax_scale=layer.scaling, + mask=1, + ) + return o.view(-1, layer.tp_q_head_num * head_dim_out) + + def forward_decode( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): + if self.use_mla: + raise NotImplementedError( + "MLA decode mode is not implemented yet in ATOMAttnBackendForSGLPluginMode." 
+ ) + else: + if self.decode_using_pa_ps: + return self.forward_decode_pa_ps( + q, k, v, layer, forward_batch, save_kv_cache + ) + else: + return self.forward_decode_pa( + q, k, v, layer, forward_batch, save_kv_cache + ) diff --git a/atom/plugin/vllm/__init__.py b/atom/plugin/vllm/__init__.py new file mode 100644 index 000000000..a76f02676 --- /dev/null +++ b/atom/plugin/vllm/__init__.py @@ -0,0 +1,5 @@ +"""vLLM plugin integration for ATOM.""" + +from .register import register_model, register_platform + +__all__ = ["register_platform", "register_model"] diff --git a/atom/plugin/vllm/model_wrapper.py b/atom/plugin/vllm/model_wrapper.py new file mode 100644 index 000000000..bcca17ad1 --- /dev/null +++ b/atom/plugin/vllm/model_wrapper.py @@ -0,0 +1,141 @@ +from collections.abc import Iterable + +import importlib +import torch +import torch.nn as nn +from aiter.dist.parallel_state import ( + get_pp_group, + get_tp_group, +) +from vllm.config import VllmConfig +from vllm.model_executor.models.interfaces import ( + SupportsPP, + SupportsQuant, +) +from vllm.model_executor.models.interfaces_base import ( + VllmModel, + VllmModelForTextGeneration, +) +from vllm.sequence import IntermediateTensors + +import atom # noqa: F401 +from atom.plugin.config import generate_atom_config_for_plugin_mode + +import logging + +logger = logging.getLogger("atom") + + +_ATOM_MODEL_CLASSES: dict[str, str] = { + "Qwen3ForCausalLM": "atom.models.qwen3:Qwen3ForCausalLM", + "Qwen3MoeForCausalLM": "atom.models.qwen3_moe:Qwen3MoeForCausalLM", +} + + +def _get_atom_model_cls(model_arch: str) -> type: + if model_arch is not None and model_arch in _ATOM_MODEL_CLASSES: + model_ref = _ATOM_MODEL_CLASSES[model_arch] + else: + raise ValueError(f"The {model_arch} is not supported by ATOM OOT backend") + + module_path, class_name = model_ref.split(":", 1) + return getattr(importlib.import_module(module_path), class_name) + + +def _prepare_env(atom_config) -> None: + from atom.plugin.register import 
set_attn_cls, init_aiter_dist + + # set global attention class + logger.info("Set global attention class") + set_attn_cls() + + # init aiter dist for using aiter custom collective ops + logger.info("Init aiter dist for using aiter custom collective ops") + init_aiter_dist(config=atom_config) + + +class ATOMModelBase(nn.Module, VllmModel, SupportsQuant, SupportsPP): + def __init_subclass__(cls, *args, **kwargs): + super().__init_subclass__(*args, **kwargs) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + self.config = vllm_config.model_config.hf_config + self.text_config = self.config.get_text_config() + self.cache_config = vllm_config.cache_config + self.device_config = vllm_config.device_config + self.model_config = vllm_config.model_config + self.parallel_config = vllm_config.parallel_config + self.quant_config = vllm_config.quant_config + + # Weights to skip in `self.load_weights` + self.skip_prefixes: list[str] = [] + self.skip_substrs: list[str] = [] + self.ignore_unexpected_prefixes: list[str] = [] + self.ignore_unexpected_suffixes: list[str] = [] + + self.atom_config = generate_atom_config_for_plugin_mode(vllm_config) + + _prepare_env(atom_config=self.atom_config) + + model_arch = vllm_config.model_config.architectures[0] + model_cls = _get_atom_model_cls(model_arch) + + logger.info(f"Construct ATOM model {model_arch} for vLLM plugin mode") + self.model = model_cls(self.atom_config) + + if self.model is None: + model_arch = vllm_config.model_config.architectures[0] + raise ValueError( + f"The model {model_arch} is not supported by model impl backend atom" + ) + + # here init aiter dist for using aiter custom collective ops + self.pp_group = get_pp_group() + self.tp_group = get_tp_group() + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **model_kwargs, + ) -> torch.Tensor | 
IntermediateTensors: + if not self.pp_group.is_first_rank: + assert intermediate_tensors is not None + input_ids = None + inputs_embeds = intermediate_tensors["hidden_states"] + + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + ) + + if not self.pp_group.is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + + return hidden_states + + def load_weights( + self, + weights: Iterable[tuple[str, torch.Tensor]], + ) -> set[str]: + return self.model.load_weights(weights) + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + logits = self.model.compute_logits(hidden_states) + return logits + + +class ATOMForCausalLM(ATOMModelBase, VllmModelForTextGeneration): ... + + +class ATOMMoEForCausalLM(ATOMModelBase, VllmModelForTextGeneration): ... diff --git a/atom/plugin/vllm/platform.py b/atom/plugin/vllm/platform.py new file mode 100644 index 000000000..aaaa9657f --- /dev/null +++ b/atom/plugin/vllm/platform.py @@ -0,0 +1,38 @@ +"""ATOM vLLM platform integration. + +This module contains the vLLM `Platform` subclass used in ATOM's vLLM plugin +mode. Keep platform behavior here so `register.py` can focus on registration +and wiring only. +""" + +import logging +import os + +logger = logging.getLogger("atom") + +# This flag is used to enable the vLLM plugin mode. +disable_vllm_plugin = os.getenv("ATOM_DISABLE_VLLM_PLUGIN", "0").lower() == "1" +disable_vllm_plugin_attention = ( + os.getenv("ATOM_DISABLE_VLLM_PLUGIN_ATTENTION", "0").lower() == "1" +) + +if not disable_vllm_plugin: + from vllm.platforms.rocm import RocmPlatform +else: + # Keep the module importable even when vLLM isn't available / plugin disabled. 
+    RocmPlatform = object  # type: ignore[assignment]
+
+
+class ATOMPlatform(RocmPlatform):
+    # For multi-modality models, to make AiterBackend supported by ViT,
+    # get_supported_vit_attn_backends may need to be overridden.
+    @classmethod
+    def get_attn_backend_cls(cls, selected_backend, attn_selector_config) -> str:
+        # Fall back to original behavior of vLLM mainline.
+        if disable_vllm_plugin_attention:
+            logger.info("Fallback to original vLLM attention backend")
+            return super().get_attn_backend_cls(selected_backend, attn_selector_config)
+
+        # Return ATOM attention backend.
+        logger.info("Use atom attention backend")
+        return "atom.model_ops.attentions.aiter_attention.AiterBackend"
diff --git a/atom/plugin/vllm/register.py b/atom/plugin/vllm/register.py
new file mode 100644
index 000000000..1dea4ff15
--- /dev/null
+++ b/atom/plugin/vllm/register.py
@@ -0,0 +1,71 @@
+import os
+from typing import Optional
+import logging
+
+from atom.plugin.prepare import _set_framework_backbone
+
+logger = logging.getLogger("atom")
+
+# when this flag is set to 1, the vllm plugin mode is disabled
+disable_vllm_plugin = os.getenv("ATOM_DISABLE_VLLM_PLUGIN", "0").lower() == "1"
+disable_vllm_plugin_attention = (
+    os.getenv("ATOM_DISABLE_VLLM_PLUGIN_ATTENTION", "0").lower() == "1"
+)
+
+# these 2 model wrappers cover most dense and moe models
+ATOM_CAUSAL_LM_MODEL_WRAPPER = "atom.plugin.vllm.model_wrapper:ATOMForCausalLM"
+ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER = "atom.plugin.vllm.model_wrapper:ATOMMoEForCausalLM"
+
+# when registering a new model to vllm, add it here
+# keys are the hf config arch names
+_VLLM_MODEL_REGISTRY_OVERRIDES: dict[str, str] = {
+    "Qwen3ForCausalLM": ATOM_CAUSAL_LM_MODEL_WRAPPER,
+    "Qwen3MoeForCausalLM": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER,
+}
+
+
+def _set_plugin_mode() -> None:
+    _set_framework_backbone("vllm")
+
+
+def register_platform() -> Optional[str]:
+
+    if disable_vllm_plugin:
+        # return None instead of error because the flag can be used to
+        # run pure vllm mode without
ATOM plugin + logger.info("Disable ATOM OOT plugin platforms") + return None + + _set_plugin_mode() + + # return the ATOM platform to vllm + return "atom.plugin.vllm.platform.ATOMPlatform" + + +def register_model() -> None: + if disable_vllm_plugin: + logger.info("Disable ATOM model register") + return + + import vllm.model_executor.models.registry as vllm_model_registry + + any_updated = False + for arch, qual in _VLLM_MODEL_REGISTRY_OVERRIDES.items(): + module_name, class_name = qual.split(":", 1) + existing = vllm_model_registry.ModelRegistry.models.get(arch) + if existing is not None: + # If already overridden to the same target, skip re-registering. + if ( + getattr(existing, "module_name", None) == module_name + and getattr(existing, "class_name", None) == class_name + ): + continue + + logger.info(f"Register model {arch} to vLLM with {qual}") + vllm_model_registry.ModelRegistry.register_model(arch, qual) + any_updated = True + + # clear lru cache + if any_updated: + vllm_model_registry._try_load_model_cls.cache_clear() + vllm_model_registry._try_inspect_model_cls.cache_clear() diff --git a/atom/utils/backends.py b/atom/utils/backends.py index b583ba447..57eec887a 100644 --- a/atom/utils/backends.py +++ b/atom/utils/backends.py @@ -255,6 +255,24 @@ class SplitItem: graph: fx.GraphModule +# used to judge whether the node should be split or not +def _split_judge_func(node: fx.Node) -> bool: + # ATOM use mark_spliting_op to mark the attn as splitting op + if node.op == "call_function" and ( + hasattr(node.target, "spliting_op") and (node.target.spliting_op) + ): + return True + + # When plugin mode(vLLM), the attention impl op is registered + # as unified_attention + from atom.plugin import is_vllm + + if is_vllm() and "unified_attention" in node.name: + return True + + return False + + def split_graph( graph: fx.GraphModule, ops: list[str] ) -> tuple[fx.GraphModule, list[SplitItem]]: @@ -265,9 +283,7 @@ def split_graph( for node in graph.graph.nodes: if node.op 
in ("output", "placeholder"): continue - if node.op == "call_function" and ( - hasattr(node.target, "spliting_op") and (node.target.spliting_op) - ): + if _split_judge_func(node): subgraph_id += 1 node_to_subgraph_id[node] = subgraph_id split_op_graphs.append(subgraph_id) diff --git a/atom/utils/forward_context.py b/atom/utils/forward_context.py index 3682d98f5..df0cba55e 100644 --- a/atom/utils/forward_context.py +++ b/atom/utils/forward_context.py @@ -3,12 +3,15 @@ from contextlib import contextmanager from dataclasses import dataclass, field, fields -from typing import Any, Dict, Optional, Set, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Union import numpy as np import torch from atom.config import Config, KVCacheTensor, ParallelConfig +if TYPE_CHECKING: + from atom.plugin.attention import MetadataForPluginMode + def _compute_chunked_local_num_tokens( num_tokens_across_dp_cpu: list[int], max_num_tokens: int, chunk_idx: int @@ -185,6 +188,9 @@ class AttentionMetaData: block_tables_converted: Optional[torch.Tensor] = None kv_indices_converted: Optional[torch.Tensor] = None + # only used for plugin mode to store the metadata for attn + plugin_metadata: Optional["MetadataForPluginMode"] = None + def __init__( self, cu_seqlens_q: Optional[torch.Tensor] = None, @@ -212,6 +218,7 @@ def __init__( kv_indices_converted: Optional[torch.Tensor] = None, sparse_cu_seqlens_q: Optional[torch.Tensor] = None, token_to_seq_idxs: Optional[torch.Tensor] = None, + plugin_metadata: Optional["MetadataForPluginMode"] = None, ): self.cu_seqlens_q = cu_seqlens_q self.cu_seqlens_k = cu_seqlens_k @@ -240,6 +247,8 @@ def __init__( self.kv_indices = kv_indices_converted self.sparse_cu_seqlens_q = sparse_cu_seqlens_q self.token_to_seq_idxs = token_to_seq_idxs + if plugin_metadata is not None: + self.plugin_metadata = plugin_metadata def asdict_zerocopy(self, skip_fields: Optional[Set[str]] = None) -> Dict[str, Any]: """Similar to dataclasses.asdict, but avoids 
deepcopying."""
diff --git a/pyproject.toml b/pyproject.toml
index 49d3d8c0b..7e6c623bb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,5 +1,6 @@
 [build-system]
-requires = ["setuptools_scm[toml]>=6.2"]
+requires = ["setuptools>=61", "wheel", "setuptools_scm[toml]>=6.2"]
+build-backend = "setuptools.build_meta"
 
 [project]
 name = "atom"
@@ -24,3 +25,15 @@ fallback_version = "0.1.0"
 [tool.setuptools.packages.find]
 where = ["."]
 include = ["atom*"]
+
+[project.entry-points."vllm.platform_plugins"]
+# by default the entry-points will be installed anyway
+# but the plugin mode for platforms can be disabled by
+# ATOM_DISABLE_VLLM_PLUGIN=1
+atom = "atom.plugin.vllm.register:register_platform"
+
+[project.entry-points."vllm.general_plugins"]
+# by default the entry-points will be installed anyway
+# but the model registration can also be disabled by
+# ATOM_DISABLE_VLLM_PLUGIN=1 (register_model is gated by that flag)
+atom_model_registry = "atom.plugin.vllm.register:register_model"
diff --git a/recipes/SGLang-ATOM-Model-Impl-Backend.md b/recipes/SGLang-ATOM-Model-Impl-Backend.md
new file mode 100644
index 000000000..fe409a269
--- /dev/null
+++ b/recipes/SGLang-ATOM-Model-Impl-Backend.md
@@ -0,0 +1,72 @@
+# Model Impl Backend of SGLang
+ATOM can work as a model implementation backend for popular frameworks such as SGLang. Users can launch the server as before and specify one extra argument to enable the ATOM model backend, which provides the optimized implementation of the requested target model for SGLang to execute. When ATOM works in this mode, framework-level features from SGLang and the latest model-level fusion kernels from ATOM/AITER can be combined to achieve competitive performance.
+
+- Here is a detailed design slide for this feature: https://amdcloud-my.sharepoint.com/:p:/g/personal/zejchen_amd_com/IQCFdvmEeLTWT7ysApmZv_hVAfw2nTo8iesJZGblHS0evqQ?e=hjnIDM
+- Here is the RFC to introduce ATOM as a model impl backend into SGLang: TODO
+
+## Preparing environment for SGLang with ATOM model backend
+Here is the PR that introduces ATOM into SGLang: https://github.com/sgl-project/sglang/pull/16944. Once this PR is merged, the official SGLang can be used, but for now you need to use the development SGLang branch below.
+Pull the latest nightly SGLang docker image for ROCm from https://hub.docker.com/r/rocm/sgl-dev/tags
+```bash
+docker pull rocm/sgl-dev:v0.5.8-rocm720-mi35x-20260130-preview
+```
+Launch the container as usual; all of the following steps are executed inside the container.
+A specific SGLang branch is required because the PR that introduces ATOM into SGLang has not been merged yet, so you need to:
+```bash
+git clone https://github.com/zejunchen-zejun/sglang.git
+git checkout remotes/origin/zejun/model_impl
+pip uninstall sglang -y
+pip uninstall sgl-kernel -y
+cd sgl-kernel
+python3 setup_rocm.py install
+export PYTHONPATH=
+```
+Then ATOM should be installed:
+```bash
+git clone https://github.com/zejunchen-zejun/ATOM.git
+cd ATOM
+git checkout origin/zejun/plugin_for_atom_1223
+pip install -e . 2>&1 | tee build.log
+```
+For AITER, there are no specific requirements; however, if you find that any of the latest fusion kernels are missing, you may need to upgrade AITER.
+
+### Launching server of SGLang with ATOM model backend
+You only need a single change: add `--model-impl atom` to your SGLang server command.
Here is an example: +```bash +export ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1 + +# quick allreduce +export AITER_QUICK_REDUCE_QUANTIZATION=INT4 +model_path=/data/models/Qwen3-235B-A22B-Instruct-2507-FP8 + +python3 -m sglang.launch_server \ + --model-path $model_path \ + --host localhost \ + --port 8000 \ + --trust-remote-code \ + --tensor-parallel-size 8 \ + --expert-parallel-size 8 \ + --kv-cache-dtype fp8_e4m3 \ + --mem-fraction-static 0.8 \ + --model-impl atom \ + 2>&1 | tee log.serve.log & +``` + +### Launching client for validating the accuracy +```bash +addr=localhost +port=8000 +url=http://${addr}:${port}/v1/completions +model= +task=gsm8k +lm_eval --model local-completions \ + --model_args model=${model},base_url=${url},num_concurrent=65,max_retries=1,tokenized_requests=False \ + --tasks ${task} \ + --num_fewshot 3 \ + 2>&1 | tee log.lmeval.log +``` + +### Known Limitations +There are some known limitations for now: +- Only Qwen-Dense and Qwen-MoE family models are supported +- Only TP and EP are supported diff --git a/recipes/vLLM-ATOM-OOT-Plugin-Backend.md b/recipes/vLLM-ATOM-OOT-Plugin-Backend.md new file mode 100644 index 000000000..504753d3e --- /dev/null +++ b/recipes/vLLM-ATOM-OOT-Plugin-Backend.md @@ -0,0 +1,89 @@ +# vLLM out-of-tree ATOM Plugin Backend +ATOM can work as the OOT plugin backend of vLLM. The OOT registration mechanism is quite mature, and most accelerators have leveraged this design to register their devices with vLLM without any code changes in the upper framework. ATOM follows this design and provides the layer/op and model implementations to vLLM. Frontend users can launch the vLLM server as before, and there is no need to specify any extra arguments. Meanwhile, the ATOM platform can leverage most of the vLLM features and focus more on model- and kernel-level optimizations. 
For the overall design, here is an RFC to enable ATOM to work as the OOT plugin platform of vLLM: https://github.com/ROCm/ATOM/issues/201 + +## Preparing environment for vLLM with ATOM model backend +Pull the official vLLM docker for ROCm. If you are using the vLLM nightly docker, there could be incompatibility errors, because vLLM is changing its code and may break the class/module imports in ATOM +```bash +docker pull rocm/vllm-dev:nightly_main_20260118 +``` + +Then ATOM should be installed. Once the following PR is merged, you can use the ATOM main branch +```bash +git clone https://github.com/zejunchen-zejun/ATOM.git +cd ATOM +git checkout origin/zejun/plugin_for_atom_1223 +pip install -e . 2>&1 | tee build.log +``` +For AITER, there are no specific requirements; however, if you find that any of the latest fusion kernels are missing, you may need to upgrade AITER +Additionally, you may need to install some dependencies: +```bash +pip install --upgrade triton +pip install transformers==5.0.0 +pip install git+https://github.com/foundation-model-stack/fastsafetensors.git +``` + +### Launching server of vLLM with ATOM OOT Plugin Platform +There are no code changes on the vLLM side, so you can launch the vLLM server as before without any specific arguments +```bash +export SAFETENSORS_FAST_GPU=1 +export VLLM_ROCM_USE_AITER=1 +export VLLM_RPC_TIMEOUT=1800000 + +export VLLM_CACHE_ROOT=/root/.cache/vllm +export TORCHINDUCTOR_CACHE_DIR=/root/.cache/inductor + +rm -rf /root/.cache/ + +model_path= + +vllm serve $model_path \ + --host localhost \ + --port 8000 \ + --tensor-parallel-size 8 \ + --enable-expert-parallel \ + --trust-remote-code \ + --disable-log-requests \ + --gpu_memory_utilization 0.9 \ + --async-scheduling \ + --load-format fastsafetensors \ + --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \ + --kv-cache-dtype fp8 \ + --max-num-batched-tokens 18432 \ + --max-model-len 16384 \ + --no-enable-prefix-caching \ + 2>&1 | tee log.serve.log & +``` + +If you want to 
disable the ATOM OOT plugin platform and model registration, you can use the env flag below. The default value is 0 +```bash +export ATOM_DISABLE_VLLM_PLUGIN=1 +``` +If you only want to disable the ATOM attention backend, you can use the env flag below. The default value is 0 +```bash +export ATOM_DISABLE_VLLM_PLUGIN_ATTENTION=1 +``` + +### Launching client for validating the accuracy +```bash +addr=localhost +port=8000 +url=http://${addr}:${port}/v1/completions +model= +task=gsm8k +lm_eval --model local-completions \ + --model_args model=${model},base_url=${url},num_concurrent=65,max_retries=1,tokenized_requests=False \ + --tasks ${task} \ + --num_fewshot 3 \ + 2>&1 | tee log.lmeval.log +``` + +### Results for accuracy validation +|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr| +|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:| +|gsm8k| 3|flexible-extract| 3|exact_match|↑ |0.9037|± |0.0081| +| | |strict-match | 3|exact_match|↑ |0.8832|± |0.0088| + +### Known Limitations +There are some known limitations for now: +- Only Qwen-Dense and Qwen-MoE family models are supported +- Only TP and EP are supported