Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions vllm/attention/backends/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
)
CPU_ATTN = "vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend"
# Placeholder for third-party/custom backends - must be registered before use
CUSTOM = ""
# Set to None so CUSTOM does not alias another backend member whose value is an empty string
CUSTOM = None

def get_path(self, include_classname: bool = True) -> str:
"""Get the class path for this backend (respects overrides).
Expand Down Expand Up @@ -139,7 +140,8 @@ class MambaAttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
LINEAR = "vllm.v1.attention.backends.linear_attn.LinearAttentionBackend"
GDN_ATTN = "vllm.v1.attention.backends.gdn_attn.GDNAttentionBackend"
# Placeholder for third-party/custom backends - must be registered before use
CUSTOM = ""
# Set to None so CUSTOM does not alias another backend member whose value is an empty string
CUSTOM = None

def get_path(self, include_classname: bool = True) -> str:
"""Get the class path for this backend (respects overrides).
Expand Down
13 changes: 7 additions & 6 deletions vllm/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,8 @@ class ModelConfig:
available.\n
- "vllm" will use the vLLM model implementation.\n
- "transformers" will use the Transformers model implementation.\n
- "terratorch" will use the TerraTorch model implementation.
- "terratorch" will use the TerraTorch model implementation.\n
- "atom" will use the atom model implementation for AMD users.
"""
override_attention_dtype: str | None = None
"""Override dtype for attention"""
Expand Down Expand Up @@ -741,10 +742,10 @@ def validate_model_config_after(self: "ModelConfig") -> "ModelConfig":
raise ValueError("max_model_len must be an integer after __post_init__.")
return self

def _get_transformers_backend_cls(self) -> str:
"""Determine which Transformers modeling backend class will be used if
`model_impl` is set to `transformers` or `auto`."""
cls = "Transformers"
def _get_model_impl_backend_cls(self, model_impl: str = "Transformers") -> str:
"""Determine which modeling backend class of the model implementation will be used
when `model_impl` is set to `transformers`, `auto`, or another backend."""
cls = model_impl
# If 'hf_config != hf_text_config' it's a nested config, i.e. multimodal
cls += "MultiModal" if self.hf_config != self.hf_text_config else ""
cls += "MoE" if self.get_num_experts() > 1 else ""
Expand All @@ -771,7 +772,7 @@ def _get_transformers_backend_cls(self) -> str:
def using_transformers_backend(self) -> bool:
"""Check if the model is using the Transformers modeling backend class."""
used_cls = self._model_info.architecture
transformers_backend_cls = self._get_transformers_backend_cls()
transformers_backend_cls = self._get_model_impl_backend_cls()
return used_cls == transformers_backend_cls

@property
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/model_loader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
model_config=model_config,
)

if arch == model_config._get_transformers_backend_cls():
if arch == model_config._get_model_impl_backend_cls():
assert model_config.model_impl != "vllm"
if model_config.model_impl == "auto":
logger.warning_once(
Expand Down
110 changes: 110 additions & 0 deletions vllm/model_executor/models/atom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2026 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `atom` models"""

import torch
import torch.nn as nn
from collections.abc import Iterable

from vllm.model_executor.models.interfaces import (
SupportsPP,
SupportsQuant,
)
from vllm.model_executor.models.interfaces_base import VllmModel
from vllm.model_executor.models.interfaces_base import VllmModelForTextGeneration
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors

logger = init_logger(__name__)

class ATOMModelBase(nn.Module, VllmModel, SupportsQuant, SupportsPP):
    """Base wrapper that delegates modeling to the external ``atom`` package.

    The actual model is built by ``atom.prepare_model`` and stored in
    ``self.model``; the wrapped object is expected to be callable and to
    expose ``load_weights`` and ``compute_logits`` (see the delegating
    methods below).
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the underlying atom model from the given vLLM config.

        Args:
            vllm_config: Aggregated vLLM configuration; the relevant
                sub-configs are cached on the instance for convenience.
            prefix: Parameter-name prefix, kept for interface parity with
                other vLLM model classes (not used here).

        Raises:
            ImportError: If the optional ``atom`` package is not installed.
            ValueError: If ``atom`` does not support the requested
                architecture.
        """
        super().__init__()
        logger.info("Using ATOM modeling backend.")

        self.config = vllm_config.model_config.hf_config
        self.text_config = self.config.get_text_config()
        self.cache_config = vllm_config.cache_config
        self.device_config = vllm_config.device_config
        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.quant_config = vllm_config.quant_config

        self.pp_group = get_pp_group()
        self.tp_group = get_tp_group()

        # Weights to skip in `self.load_weights`
        self.skip_prefixes: list[str] = []
        self.skip_substrs: list[str] = []
        self.ignore_unexpected_prefixes: list[str] = []
        self.ignore_unexpected_suffixes: list[str] = []

        # Imported lazily so vLLM does not require `atom` unless this
        # backend is actually selected; fail with a clear message instead
        # of a bare ModuleNotFoundError.
        try:
            import atom
        except ImportError as err:
            raise ImportError(
                "model_impl='atom' requires the `atom` package to be "
                "installed."
            ) from err

        self.model = atom.prepare_model(config=vllm_config, framework="vllm")
        if self.model is None:
            model_arch = vllm_config.model_config.architectures[0]
            raise ValueError(
                f"The model {model_arch} is not supported by model impl "
                "backend atom"
            )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **model_kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the atom model, handling pipeline-parallel hand-off.

        On non-first PP ranks the previous stage's hidden states are read
        from ``intermediate_tensors``; on non-last ranks the result is
        re-wrapped into ``IntermediateTensors`` for the next stage.
        """
        if not self.pp_group.is_first_rank:
            assert intermediate_tensors is not None
            input_ids = None
            inputs_embeds = intermediate_tensors["hidden_states"]

        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            **model_kwargs,
        )

        if not self.pp_group.is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        return hidden_states

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        """Delegate weight loading to the atom model.

        Returns the set of parameter names that were loaded.
        """
        return self.model.load_weights(weights)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Delegate logits computation to the atom model."""
        return self.model.compute_logits(hidden_states)


class ATOMForCausalLM(ATOMModelBase, VllmModelForTextGeneration):
    """Causal language model served through the ATOM modeling backend."""

class ATOMMoEForCausalLM(ATOMModelBase, VllmModelForTextGeneration):
    """Mixture-of-experts causal language model served through the ATOM
    modeling backend."""
30 changes: 28 additions & 2 deletions vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,12 @@
),
}

_ATOM_BACKEND_MODELS = {
# Text generation models
"ATOMForCausalLM": ("atom", "ATOMForCausalLM"),
"ATOMMoEForCausalLM": ("atom", "ATOMMoEForCausalLM"),
}

_VLLM_MODELS = {
**_TEXT_GENERATION_MODELS,
**_EMBEDDING_MODELS,
Expand All @@ -478,6 +484,7 @@
**_SPECULATIVE_DECODING_MODELS,
**_TRANSFORMERS_SUPPORTED_MODELS,
**_TRANSFORMERS_BACKEND_MODELS,
**_ATOM_BACKEND_MODELS,
}

# This variable is used as the args for subprocess.run(). We
Expand Down Expand Up @@ -806,7 +813,16 @@ def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
if model_arch not in self.models:
return None

return _try_inspect_model_cls(model_arch, self.models[model_arch])
model = self.models[model_arch]
return _try_inspect_model_cls(model_arch, model)

def _try_resolve_atom(
self,
architecture: str,
model_config: ModelConfig,
) -> str | None:
cls_name = model_config._get_model_impl_backend_cls(model_impl="ATOM")
return cls_name

def _try_resolve_transformers(
self,
Expand Down Expand Up @@ -872,7 +888,7 @@ def _try_resolve_transformers(
"is not compatible with vLLM."
)

return model_config._get_transformers_backend_cls()
return model_config._get_model_impl_backend_cls()

def _normalize_arch(
self,
Expand Down Expand Up @@ -920,6 +936,10 @@ def inspect_model_cls(
elif model_config.model_impl == "terratorch":
model_info = self._try_inspect_model_cls("Terratorch")
return (model_info, "Terratorch")
elif model_config.model_impl == "atom":
arch = self._try_resolve_atom(architectures[0], model_config)
model_info = self._try_inspect_model_cls(arch)
return (model_info, arch)

# Fallback to transformers impl (after resolving convert_type)
if (
Expand Down Expand Up @@ -974,6 +994,12 @@ def resolve_model_cls(
model_cls = self._try_load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
elif model_config.model_impl == "atom":
arch = self._try_resolve_atom(architectures[0], model_config)
if arch is not None:
model_cls = self._try_load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)

# Fallback to transformers impl (after resolving convert_type)
if (
Expand Down
4 changes: 4 additions & 0 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,10 @@ def get_attn_backend_cls(
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import AttentionBackendEnum

if selected_backend == AttentionBackendEnum.CUSTOM:
logger.info("Using CUSTOM backend.")
return AttentionBackendEnum.CUSTOM.get_path()

if use_sparse:
if kv_cache_dtype.startswith("fp8"):
raise ValueError(
Expand Down