Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
cae53bd
[feat][plugin] Make ATOM work as plugin for upper
zejunchen-zejun Jan 15, 2026
a5b5f3f
add
zejunchen-zejun Feb 2, 2026
6563083
add
zejunchen-zejun Feb 2, 2026
0f05699
add
zejunchen-zejun Feb 2, 2026
7712e1f
format ruff
zejunchen-zejun Feb 2, 2026
c203883
ruff format
zejunchen-zejun Feb 2, 2026
6bbaade
add
zejunchen-zejun Feb 2, 2026
a669441
add
zejunchen-zejun Feb 2, 2026
2f8e6ee
add
zejunchen-zejun Feb 2, 2026
113e587
add
zejunchen-zejun Feb 2, 2026
de036e8
add
zejunchen-zejun Feb 3, 2026
f0f0c94
add
zejunchen-zejun Feb 3, 2026
dd6e9b3
add
zejunchen-zejun Feb 3, 2026
b985b82
add
zejunchen-zejun Feb 3, 2026
b9806e0
fix sglang plugin mode acc issue
Feb 3, 2026
418d442
init vllm-atom, first commit
zejunchen-zejun Feb 7, 2026
7c54abe
add
zejunchen-zejun Feb 9, 2026
a44bed1
add
zejunchen-zejun Feb 9, 2026
43604c9
add
zejunchen-zejun Feb 9, 2026
285929a
add
zejunchen-zejun Feb 9, 2026
77795eb
add
zejunchen-zejun Feb 9, 2026
b13a670
make lint happy
zejunchen-zejun Feb 10, 2026
31ccb16
add
zejunchen-zejun Feb 10, 2026
f226b95
add
zejunchen-zejun Feb 10, 2026
b553cd2
add
zejunchen-zejun Feb 10, 2026
e1e83d4
add
zejunchen-zejun Feb 10, 2026
484e17d
add
zejunchen-zejun Feb 10, 2026
36b6fd3
add
zejunchen-zejun Feb 10, 2026
b1fb7b6
add
zejunchen-zejun Feb 10, 2026
0f0bedc
add
zejunchen-zejun Feb 10, 2026
a051118
add
zejunchen-zejun Feb 10, 2026
a00f59e
register attn backend to sgl from ATOM
Feb 11, 2026
8491ef7
make format happy
Feb 11, 2026
9eb1c19
add
zejunchen-zejun Feb 25, 2026
6d14b84
add
zejunchen-zejun Feb 25, 2026
2c0a44a
add
zejunchen-zejun Feb 25, 2026
9498684
add
zejunchen-zejun Feb 25, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions atom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,12 @@

from atom.model_engine.llm_engine import LLMEngine
from atom.sampling_params import SamplingParams

# interface for the upper framework to construct the model from ATOM
from atom.plugin import prepare_model

__all__ = [
"LLMEngine",
"SamplingParams",
"prepare_model",
]
36 changes: 26 additions & 10 deletions atom/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
from torch.distributed import ProcessGroup, ReduceOp
from transformers import AutoConfig, GenerationConfig, PretrainedConfig

# plugin-related utilities
from atom.plugin import is_plugin_mode
from atom.plugin.config import PluginConfig

logger = logging.getLogger("atom")


Expand Down Expand Up @@ -584,6 +588,9 @@ class Config:
torch_dtype: torch.dtype = field(init=False)
speculative_config: Optional[SpeculativeConfig] = None

# only use for plugin mode
plugin_config: Optional[PluginConfig] = None

def _set_cudagraph_sizes(self):
if self.compilation_config.cudagraph_capture_sizes:
self.graph_bs = self.compilation_config.cudagraph_capture_sizes
Expand Down Expand Up @@ -626,16 +633,25 @@ def __post_init__(self):
self.max_model_len, hf_config_max_position_embeddings
)
# assert self.max_num_batched_tokens >= self.max_model_len
if self.torch_profiler_dir is not None:
os.makedirs(self.torch_profiler_dir, exist_ok=True)
assert self.torch_profiler_dir is None or os.path.isdir(
self.torch_profiler_dir
), f"torch_profiler_dir {self.torch_profiler_dir} is not a valid directory"
if self.compilation_config.level == CompilationLevel.PIECEWISE:
self.compilation_config.set_splitting_ops_for_v1()
self._set_cudagraph_sizes()
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
self.compilation_config.init_with_cudagraph_sizes()
if not is_plugin_mode():
if self.torch_profiler_dir is not None:
os.makedirs(self.torch_profiler_dir, exist_ok=True)
assert self.torch_profiler_dir is None or os.path.isdir(
self.torch_profiler_dir
), f"torch_profiler_dir {self.torch_profiler_dir} is not a valid directory"

# only for server mode or plugin mode(vllm)
# for torch compile policy, plugin mode(vllm) uses the ATOM compile policy
# for cuda graph capture, plugin mode(vllm) uses the vLLM's cuda graph capture policy
if not is_plugin_mode() or (
self.plugin_config is not None and self.plugin_config.is_vllm
):
if self.compilation_config.level == CompilationLevel.PIECEWISE:
self.compilation_config.set_splitting_ops_for_v1()
self._set_cudagraph_sizes()
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
self.compilation_config.init_with_cudagraph_sizes()

self.torch_dtype = (
self.hf_config.torch_dtype
if getattr(self.hf_config, "torch_dtype", None) is not None
Expand Down
8 changes: 5 additions & 3 deletions atom/model_engine/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,9 @@ def __init__(self, rank: int, config: Config):
self.drafter.load_model(self.model)
torch.set_default_device(self.device)
self.allocate_forward_vars()
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(self)
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
model_runner=self
)
self.physical_block_size = self.attn_metadata_builder.block_size
self.forward_done_event = torch.cuda.Event()
self.warmup_model()
Expand Down Expand Up @@ -1171,7 +1173,7 @@ def prepare_inputs(self, batch: ScheduledBatch, input_ids: torch.Tensor = None):
self.forward_vars["cu_seqlens_q"].np[scheduled_bs + 1 : bs + 1] = (
self.forward_vars["cu_seqlens_q"].np[scheduled_bs]
)
attn_metadata, positions = self.attn_metadata_builder.build(batch, bs)
attn_metadata, positions = self.attn_metadata_builder.build(batch=batch, bs=bs)
context_bs = batch.total_seqs_num_prefill if is_prefill else scheduled_bs

# graph_bs should be batch size (number of sequences), not token count
Expand Down Expand Up @@ -1472,7 +1474,7 @@ def capture_cudagraph(self):
)

attn_metadata, context = (
self.attn_metadata_builder.build_for_cudagraph_capture(bs)
self.attn_metadata_builder.build_for_cudagraph_capture(bs=bs)
)
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_tokens += num_pad
Expand Down
71 changes: 69 additions & 2 deletions atom/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
get_spec_layer_idx_from_weight_name,
rewrite_spec_layer_name,
)
from atom.plugin.prepare import is_vllm, is_sglang

logger = logging.getLogger("atom")

Expand Down Expand Up @@ -80,13 +81,61 @@ def safetensors_weights_iterator(
yield name, f.get_tensor(name)


# In plugin mode the model-loading method is bound to the model implementation,
# so callers use this interface to load the model; it delegates to the
# load_model method below.
def load_model_in_plugin_mode(
    model,
    config,
    prefix: str = "",
) -> set[str]:
    """Load weights for *model* when ATOM runs as a plugin of an upper framework.

    Resolves the checkpoint path from ``config.plugin_config`` (vLLM and SGLang
    store it under different attribute names) and delegates to ``load_model``.

    Args:
        model: The model instance whose parameters will be loaded.
        config: ATOM ``Config``; must carry a ``plugin_config`` with
            ``is_plugin_mode`` set.
        prefix: Parameter-name prefix recorded for each loaded weight.

    Returns:
        The set of fully-qualified parameter names that were loaded
        (used by vLLM's post-load weight check).

    Raises:
        ValueError: If the plugin framework is neither vLLM nor SGLang.
    """

    # During loading, out-of-place operations may consume extra GPU memory
    # that stays cached in the torch caching allocator; actively empty the
    # cache to free reserved-but-unused memory.
    def _empty_cache():
        import gc

        gc.collect()
        torch.cuda.empty_cache()

    assert (
        config.plugin_config is not None and config.plugin_config.is_plugin_mode
    ), "ATOM is not running in plugin mode"
    if config.plugin_config.is_vllm:
        model_name_or_path = config.plugin_config.model_config.model
    elif config.plugin_config.is_sglang:
        model_name_or_path = config.plugin_config.model_config.model_path
    else:
        # Bug fix: previously model_name_or_path was left unbound here and the
        # load_model(...) call below failed with an opaque NameError.
        raise ValueError(
            "Unsupported plugin framework: expected vLLM or SGLang in plugin_config"
        )

    _empty_cache()
    loaded_weights_record = load_model(
        model=model,
        model_name_or_path=model_name_or_path,
        hf_config=config.hf_config,
        load_dummy=config.load_dummy,
        spec_decode=False,
        prefix=prefix,
        is_plugin_mode=True,
        act_dtype=config.plugin_config.model_config.dtype,
    )
    _empty_cache()
    return loaded_weights_record


def load_model(
model: nn.Module,
model_name_or_path: str,
hf_config: AutoConfig,
load_dummy: bool = False,
spec_decode: bool = False,
prefix: str = "",
is_plugin_mode: bool = False,
act_dtype: torch.dtype = None,
):
# need to record the loaded weight name for vllm load check
# it is only used in plugin mode for vllm
loaded_weights_record: set[str] = set()

packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
weights_mapping = getattr(model, "weights_mapping", {})
params_dict = dict(model.named_parameters())
Expand Down Expand Up @@ -145,6 +194,7 @@ def load_model(
weight_loader, param, weight_tensor, shard_id
)
)
loaded_weights_record.add(prefix + param_name)
break
else:
# Check if model has expert mapping before processing
Expand All @@ -170,6 +220,7 @@ def load_model(
expert_id,
)
)
loaded_weights_record.add(prefix + name)
# weight_loader(
# param,
# weight_tensor,
Expand All @@ -186,6 +237,7 @@ def load_model(
futures.append(
executor.submit(weight_loader, param, weight_tensor)
)
loaded_weights_record.add(prefix + name)
# weight_loader(param, weight_tensor)
else:
# Model doesn't have expert mapping, use generic loading
Expand All @@ -195,14 +247,29 @@ def load_model(
)
# weight_loader(param, weight_tensor)
futures.append(executor.submit(weight_loader, param, weight_tensor))
loaded_weights_record.add(prefix + name)
# Wait for all tasks to complete and raise any exceptions.
for future in concurrent.futures.as_completed(futures):
future.result()
for _, module in model.named_modules():
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
if is_vllm():
from vllm.attention.layer import Attention

# call vLLM attn weights post processing with act_dtype if using vLLM attention module
if isinstance(module, Attention):
module.process_weights_after_loading(act_dtype=act_dtype)
else:
module.process_weights_after_loading()
else:
module.process_weights_after_loading()
quant_method = getattr(module, "quant_method", None)
if isinstance(quant_method, QuantizeMethodBase):
# when running plugin mode for sglang, don't do the post process here
# since sglang will call this func automatically after finishing loading
if isinstance(quant_method, QuantizeMethodBase) and not is_sglang():
quant_method.process_weights_after_loading(module)
if isinstance(quant_method, FusedMoEMethodBase):
quant_method.init_prepare_finalize(module)

if is_plugin_mode:
return loaded_weights_record
14 changes: 14 additions & 0 deletions atom/model_ops/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from .paged_attention import PagedAttention
from .radix_attention import RadixAttention

# Global attention class used when constructing the attention op in a model.
# It may be reassigned to a different attention implementation:
# by default PagedAttention is used; when running under SGLang,
# RadixAttention is assigned to ATTN_CLS instead.
ATTN_CLS = PagedAttention

__all__ = [
    "ATTN_CLS",
    "PagedAttention",
    "RadixAttention",
]
66 changes: 63 additions & 3 deletions atom/model_ops/attention_mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,31 @@

from .attention_mla import MLAModules

from atom.plugin.prepare import is_plugin_mode, is_vllm
from atom.plugin.attention_mha import PagedAttentionImplDecoratorForPluginMode

class Attention(nn.Module):

@PagedAttentionImplDecoratorForPluginMode
class PagedAttentionImpl(nn.Module):
"""
Attention paged implementation
"""

def __init__(
self,
num_heads,
head_dim,
scale,
num_kv_heads,
alibi_slopes: list[float] | None,
sliding_window: Optional[int] = None,
kv_cache_dtype="bf16",
logits_soft_cap: float | None = None,
attn_type=None,
kv_sharing_target_layer_name: int | None = None,
layer_num=0,
mla_modules: Optional[MLAModules] = None,
sinks: Optional[nn.Parameter] = None,
sliding_window: Optional[int] = None,
rotary_emb: Optional[torch.nn.Module] = None,
q_norm: Optional[torch.nn.Module] = None,
k_norm: Optional[torch.nn.Module] = None,
Expand All @@ -37,12 +48,16 @@ def __init__(
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
# for upper framework, it uses head_size in built-in methods
self.head_size = head_dim
self.scale = scale
self.num_kv_heads = num_kv_heads
self.alibi_slopes = alibi_slopes
self.k_cache = self.v_cache = torch.tensor([])
self.kv_cache_dtype = kv_cache_dtype
self.max_model_len = 0
self.k_scale = self.v_scale = None
self.device = "cuda:" + str(torch.cuda.current_device())
self.layer_num = layer_num
self.kv_scale_float = (
torch.finfo(torch.float8_e4m3fn).max / torch.finfo(aiter.dtypes.fp8).max
Expand All @@ -56,7 +71,16 @@ def __init__(
self.q_norm = q_norm
self.k_norm = k_norm

def forward(
# for plugin mode(vllm), the query quant is disabled for now
if is_vllm():
self.supports_quant_query_input = False

# Called only by vLLM after weight loading; intentionally a no-op because
# ATOM performs its own post-load processing for all ops itself.
def process_weights_after_loading(self, act_dtype: torch.dtype = torch.bfloat16):
    """No-op hook for vLLM's post-load processing of attention layers.

    Args:
        act_dtype: Activation dtype supplied by vLLM; accepted for interface
            compatibility and ignored.
    """
    pass

def forward_impl_server_mode(
self,
q: torch.Tensor,
k: torch.Tensor,
Expand Down Expand Up @@ -414,3 +438,39 @@ def dispatch_backend(self, fwd_ctx: ForwardContext):
if atom_config.kv_cache_block_size == 1024:
return self.paged_attention_persistent_asm
return self.paged_attention_asm

def forward(
    self,
    layer: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor = None,
    attn_metadata=None,
    position: torch.Tensor = None,
    q_scale: Optional[torch.Tensor] = None,
    qkv: torch.Tensor = None,
    output: torch.Tensor = None,
    **kwargs,
):
    """Dispatch the attention forward to the active execution mode.

    In plugin mode, delegates to ``forward_impl_plugin_mode`` (attached to
    this class by ``PagedAttentionImplDecoratorForPluginMode``); otherwise
    runs the original server-mode path. ``output`` and any extra ``kwargs``
    are accepted for interface compatibility and not used here.
    """
    if not is_plugin_mode():
        # Server mode: keep the original implementation path.
        return self.forward_impl_server_mode(
            q=query, k=key, v=value, position=position, q_scale=q_scale, qkv=qkv
        )
    # Plugin mode: the impl method is injected by the class decorator.
    return self.forward_impl_plugin_mode(
        layer=layer,
        query=query,
        key=key,
        value=value,
        kv_cache=kv_cache,
        attn_metadata=attn_metadata,
        position=position,
        q_scale=q_scale,
        qkv=qkv,
    )
4 changes: 2 additions & 2 deletions atom/model_ops/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class MLAAttention(nn.Module):
def __init__(
self,
num_heads: int,
head_size: int,
head_dim: int,
scale: float,
num_kv_heads: int,
kv_cache_dtype: str,
Expand All @@ -104,7 +104,7 @@ def __init__(
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.head_dim = head_dim
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.kv_cache_dtype = kv_cache_dtype if kv_cache_dtype == "fp8" else "auto"
Expand Down
Loading