Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions atom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,12 @@

from atom.model_engine.llm_engine import LLMEngine
from atom.sampling_params import SamplingParams

# Interface through which an upper framework (e.g. vLLM / SGLang) constructs
# the model from ATOM.
from atom.plugin import prepare_model

# Public API of the ``atom`` package.
__all__ = [
    "LLMEngine",
    "SamplingParams",
    "prepare_model",
]
37 changes: 27 additions & 10 deletions atom/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
from torch.distributed import ProcessGroup, ReduceOp
from transformers import AutoConfig, GenerationConfig, PretrainedConfig

# plugin-related utilities
from atom.plugin import is_plugin_mode
from atom.plugin.config import PluginConfig

logger = logging.getLogger("atom")


Expand Down Expand Up @@ -598,6 +602,9 @@ class Config:
torch_dtype: torch.dtype = field(init=False)
speculative_config: Optional[SpeculativeConfig] = None

# only use for plugin mode
plugin_config: Optional[PluginConfig] = None

def _set_cudagraph_sizes(self):
if self.compilation_config.cudagraph_capture_sizes:
self.graph_bs = self.compilation_config.cudagraph_capture_sizes
Expand All @@ -621,6 +628,7 @@ def __post_init__(self):
if rope_params is None:
rope_params = {}
rope_params["rope_theta"] = getattr(self.hf_config, "rope_theta", None)
rope_params["rope_type"] = getattr(self.hf_config, "rope_type", "default")
self.hf_config.rope_parameters = rope_params

self.generation_config = get_generation_config(self.model)
Expand All @@ -640,16 +648,25 @@ def __post_init__(self):
self.max_model_len, hf_config_max_position_embeddings
)
# assert self.max_num_batched_tokens >= self.max_model_len
if self.torch_profiler_dir is not None:
os.makedirs(self.torch_profiler_dir, exist_ok=True)
assert self.torch_profiler_dir is None or os.path.isdir(
self.torch_profiler_dir
), f"torch_profiler_dir {self.torch_profiler_dir} is not a valid directory"
if self.compilation_config.level == CompilationLevel.PIECEWISE:
self.compilation_config.set_splitting_ops_for_v1()
self._set_cudagraph_sizes()
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
self.compilation_config.init_with_cudagraph_sizes()
if not is_plugin_mode():
if self.torch_profiler_dir is not None:
os.makedirs(self.torch_profiler_dir, exist_ok=True)
assert self.torch_profiler_dir is None or os.path.isdir(
self.torch_profiler_dir
), f"torch_profiler_dir {self.torch_profiler_dir} is not a valid directory"

        # Only for server mode or plugin mode (vLLM).
        # Torch compile policy: plugin mode (vLLM) follows ATOM's compile policy.
        # CUDA graph capture: plugin mode (vLLM) follows vLLM's capture policy.
if not is_plugin_mode() or (
self.plugin_config is not None and self.plugin_config.is_vllm
):
if self.compilation_config.level == CompilationLevel.PIECEWISE:
self.compilation_config.set_splitting_ops_for_v1()
self._set_cudagraph_sizes()
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
self.compilation_config.init_with_cudagraph_sizes()

self.torch_dtype = (
self.hf_config.dtype
if getattr(self.hf_config, "dtype", None) is not None
Expand Down
8 changes: 5 additions & 3 deletions atom/model_engine/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,9 @@ def __init__(self, rank: int, config: Config):
self.drafter.load_model(self.model)
torch.set_default_device(self.device)
self.allocate_forward_vars()
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(self)
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
model_runner=self
)
self.physical_block_size = self.attn_metadata_builder.block_size
self.forward_done_event = torch.cuda.Event()
self.warmup_model()
Expand Down Expand Up @@ -1251,7 +1253,7 @@ def prepare_inputs(self, batch: ScheduledBatch, input_ids: torch.Tensor = None):
self.forward_vars["cu_seqlens_q"].np[scheduled_bs + 1 : bs + 1] = (
self.forward_vars["cu_seqlens_q"].np[scheduled_bs]
)
attn_metadata, positions = self.attn_metadata_builder.build(batch, bs)
attn_metadata, positions = self.attn_metadata_builder.build(batch=batch, bs=bs)
context_bs = batch.total_seqs_num_prefill if is_prefill else scheduled_bs

# graph_bs should be batch size (number of sequences), not token count
Expand Down Expand Up @@ -1503,7 +1505,7 @@ def capture_cudagraph(self):
)

attn_metadata, context = (
self.attn_metadata_builder.build_for_cudagraph_capture(bs)
self.attn_metadata_builder.build_for_cudagraph_capture(bs=bs)
)
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_tokens += num_pad
Expand Down
60 changes: 59 additions & 1 deletion atom/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from aiter.dist.parallel_state import get_tp_group
from atom.models.qwen3_next_mtp import remap_mtp_weight_name

from atom.plugin.prepare import is_sglang

logger = logging.getLogger("atom")


Expand Down Expand Up @@ -81,12 +83,54 @@ def safetensors_weights_iterator(
yield name, f.get_tensor(name)


# In plugin mode, the model-loading method is bound to the model
# implementation, so the upper framework calls this interface, which
# delegates to load_model() below.
def load_model_in_plugin_mode(
    model,
    config,
    prefix: str = "",
) -> set[str]:
    """Load weights for *model* when ATOM runs as a plugin inside an
    upper framework (vLLM or SGLang).

    Args:
        model: The model instance whose weights should be loaded.
        config: ATOM ``Config``; ``config.plugin_config`` must be set and
            indicate plugin mode.
        prefix: Name prefix recorded for every loaded weight.

    Returns:
        The set of fully-qualified weight names that were loaded (consumed
        by vLLM's post-load check).

    Raises:
        AssertionError: If ATOM is not running in plugin mode.
        ValueError: If the plugin framework is neither vLLM nor SGLang.
    """

    # During loading, out-of-place operations may consume extra GPU memory
    # that stays cached in the torch caching allocator; actively empty the
    # cache to release reserved-but-unused memory.
    def _empty_cache():
        import gc

        gc.collect()
        torch.cuda.empty_cache()

    assert (
        config.plugin_config is not None and config.plugin_config.is_plugin_mode
    ), "ATOM is not running in plugin mode"
    if config.plugin_config.is_vllm:
        model_name_or_path = config.plugin_config.model_config.model
    elif config.plugin_config.is_sglang:
        model_name_or_path = config.plugin_config.model_config.model_path
    else:
        # Previously model_name_or_path would be unbound here and the call
        # below raised NameError; fail fast with a clear error instead.
        raise ValueError(
            "plugin mode requires the upper framework to be vLLM or SGLang"
        )

    _empty_cache()
    loaded_weights_record = load_model(
        model=model,
        model_name_or_path=model_name_or_path,
        hf_config=config.hf_config,
        load_dummy=config.load_dummy,
        spec_decode=False,
        prefix=prefix,
        is_plugin_mode=True,
    )
    _empty_cache()
    return loaded_weights_record


def load_model(
model: nn.Module,
model_name_or_path: str,
hf_config: AutoConfig,
load_dummy: bool = False,
spec_decode: bool = False,
prefix: str = "",
is_plugin_mode: bool = False,
):
def have_shared_expert(name):
maybe_matching_list = ["mlp.shared_experts.", "mlp.shared_expert."]
Expand All @@ -95,6 +139,10 @@ def have_shared_expert(name):
return maybe_matching_name
return None

    # Record each loaded weight name for vLLM's load check;
    # only used in plugin mode for vLLM.
loaded_weights_record: set[str] = set()

packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
weights_mapping = getattr(model, "weights_mapping", {})
params_dict = dict(model.named_parameters())
Expand Down Expand Up @@ -161,6 +209,7 @@ def have_shared_expert(name):
weight_loader, param, weight_tensor, shard_id
)
)
loaded_weights_record.add(prefix + param_name)
break
else:
# Check if model has expert mapping before processing
Expand Down Expand Up @@ -188,6 +237,7 @@ def have_shared_expert(name):
expert_id,
)
)
loaded_weights_record.add(prefix + name)
# weight_loader(
# param,
# weight_tensor,
Expand All @@ -206,6 +256,7 @@ def have_shared_expert(name):
futures.append(
executor.submit(weight_loader, param, weight_tensor)
)
loaded_weights_record.add(prefix + name)
# weight_loader(param, weight_tensor)
else:
# Model doesn't have expert mapping, use generic loading
Expand All @@ -215,14 +266,21 @@ def have_shared_expert(name):
)
# weight_loader(param, weight_tensor)
futures.append(executor.submit(weight_loader, param, weight_tensor))
loaded_weights_record.add(prefix + name)
# Wait for all tasks to complete and raise any exceptions.
for future in concurrent.futures.as_completed(futures):
future.result()
for _, module in model.named_modules():
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
quant_method = getattr(module, "quant_method", None)
if isinstance(quant_method, QuantizeMethodBase):

            # When running in plugin mode for SGLang, skip post-processing here,
            # since SGLang calls this function automatically after loading finishes.
if isinstance(quant_method, QuantizeMethodBase) and not is_sglang():
quant_method.process_weights_after_loading(module)
if isinstance(quant_method, FusedMoEMethodBase):
quant_method.init_prepare_finalize(module)

if is_plugin_mode:
return loaded_weights_record
14 changes: 14 additions & 0 deletions atom/model_ops/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from .paged_attention import PagedAttention
from .radix_attention import RadixAttention

# Module-level alias used by models to construct their attention op; it can
# be rebound to a different attention implementation at runtime.
# Default: PagedAttention.
# For SGLang, RadixAttention is assigned to ``Attention`` instead.
Attention = PagedAttention

__all__ = [
    "Attention",
    "PagedAttention",
    "RadixAttention",
]
61 changes: 58 additions & 3 deletions atom/model_ops/attention_mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,31 @@

from .attention_mla import MLAModules

from atom.plugin.prepare import is_plugin_mode, is_vllm
from atom.plugin.attention_mha import PagedAttentionImplDecoratorForPluginMode

class Attention(nn.Module):

@PagedAttentionImplDecoratorForPluginMode
class PagedAttentionImpl(nn.Module):
"""
Attention paged implementation
"""

def __init__(
self,
num_heads,
head_dim,
scale,
num_kv_heads,
alibi_slopes: list[float] | None,
sliding_window: Optional[int] = None,
kv_cache_dtype="bf16",
logits_soft_cap: float | None = None,
attn_type=None,
kv_sharing_target_layer_name: int | None = None,
layer_num=0,
mla_modules: Optional[MLAModules] = None,
sinks: Optional[nn.Parameter] = None,
sliding_window: Optional[int] = None,
rotary_emb: Optional[torch.nn.Module] = None,
q_norm: Optional[torch.nn.Module] = None,
k_norm: Optional[torch.nn.Module] = None,
Expand All @@ -37,12 +48,16 @@ def __init__(
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
# for upper framework, it uses head_size in built-in methods
self.head_size = head_dim
self.scale = scale
self.num_kv_heads = num_kv_heads
self.alibi_slopes = alibi_slopes
self.k_cache = self.v_cache = torch.tensor([])
self.kv_cache_dtype = kv_cache_dtype
self.max_model_len = 0
self.k_scale = self.v_scale = None
self.device = "cuda:" + str(torch.cuda.current_device())
self.layer_num = layer_num
self.kv_scale_float = (
torch.finfo(torch.float8_e4m3fn).max / torch.finfo(aiter.dtypes.fp8).max
Expand All @@ -56,7 +71,11 @@ def __init__(
self.q_norm = q_norm
self.k_norm = k_norm

def forward(
# for plugin mode(vllm), the query quant is disabled for now
if is_vllm():
self.supports_quant_query_input = False

def forward_impl_server_mode(
self,
q: torch.Tensor,
k: torch.Tensor,
Expand Down Expand Up @@ -416,3 +435,39 @@ def dispatch_backend(self, fwd_ctx: ForwardContext):
if atom_config.kv_cache_block_size == 1024:
return self.paged_attention_persistent_asm
return self.paged_attention_asm

def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor = None,
attn_metadata=None,
position: torch.Tensor = None,
q_scale: Optional[torch.Tensor] = None,
qkv: torch.Tensor = None,
output: torch.Tensor = None,
**kwargs,
):
if is_plugin_mode():
# forward impl method are added by the decorator
# PagedAttentionImplDecoratorForPluginMode
return self.forward_impl_plugin_mode(
layer=layer,
query=query,
key=key,
value=value,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
position=position,
q_scale=q_scale,
qkv=qkv,
)
else:
# only for server mode, keep the original method
o = self.forward_impl_server_mode(
q=query, k=key, v=value, position=position, q_scale=q_scale, qkv=qkv
)

return o
4 changes: 2 additions & 2 deletions atom/model_ops/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class MLAAttention(nn.Module):
def __init__(
self,
num_heads: int,
head_size: int,
head_dim: int,
scale: float,
num_kv_heads: int,
kv_cache_dtype: str,
Expand All @@ -107,7 +107,7 @@ def __init__(
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.head_dim = head_dim
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.kv_cache_dtype = kv_cache_dtype if kv_cache_dtype == "fp8" else "auto"
Expand Down
Loading