diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py
index 125e4e382774..d18f095ba5b4 100644
--- a/vllm/attention/backends/registry.py
+++ b/vllm/attention/backends/registry.py
@@ -77,7 +77,8 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
     )
     CPU_ATTN = "vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend"
     # Placeholder for third-party/custom backends - must be registered before use
-    CUSTOM = ""
+    # Use None so this placeholder cannot alias a backend with an empty-string value
+    CUSTOM = None
 
     def get_path(self, include_classname: bool = True) -> str:
         """Get the class path for this backend (respects overrides).
@@ -139,7 +140,8 @@ class MambaAttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
     LINEAR = "vllm.v1.attention.backends.linear_attn.LinearAttentionBackend"
     GDN_ATTN = "vllm.v1.attention.backends.gdn_attn.GDNAttentionBackend"
     # Placeholder for third-party/custom backends - must be registered before use
-    CUSTOM = ""
+    # Use None so this placeholder cannot alias a backend with an empty-string value
+    CUSTOM = None
 
     def get_path(self, include_classname: bool = True) -> str:
         """Get the class path for this backend (respects overrides).
diff --git a/vllm/config/model.py b/vllm/config/model.py
index caa9a3440c41..d69e17326459 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -280,7 +280,8 @@ class ModelConfig:
     available.\n
     - "vllm" will use the vLLM model implementation.\n
     - "transformers" will use the Transformers model implementation.\n
-    - "terratorch" will use the TerraTorch model implementation.
+    - "terratorch" will use the TerraTorch model implementation.\n
+    - "atom" will use the ATOM model implementation for AMD GPUs.
     """
     override_attention_dtype: str | None = None
     """Override dtype for attention"""
@@ -741,10 +742,10 @@ def validate_model_config_after(self: "ModelConfig") -> "ModelConfig":
         raise ValueError("max_model_len must be an integer after __post_init__.")
         return self
 
-    def _get_transformers_backend_cls(self) -> str:
-        """Determine which Transformers modeling backend class will be used if
-        `model_impl` is set to `transformers` or `auto`."""
-        cls = "Transformers"
+    def _get_model_impl_backend_cls(self, model_impl: str = "Transformers") -> str:
+        """Determine which backend class of the given model implementation will
+        be used when `model_impl` is `transformers`, `auto`, or another backend."""
+        cls = model_impl
         # If 'hf_config != hf_text_config' it's a nested config, i.e. multimodal
         cls += "MultiModal" if self.hf_config != self.hf_text_config else ""
         cls += "MoE" if self.get_num_experts() > 1 else ""
@@ -771,7 +772,7 @@ def _get_transformers_backend_cls(self) -> str:
     def using_transformers_backend(self) -> bool:
         """Check if the model is using the Transformers modeling backend class."""
         used_cls = self._model_info.architecture
-        transformers_backend_cls = self._get_transformers_backend_cls()
+        transformers_backend_cls = self._get_model_impl_backend_cls()
         return used_cls == transformers_backend_cls
 
     @property
diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index 2021b68b8a60..44f0a1517e29 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -180,7 +180,7 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
         model_config=model_config,
     )
 
-    if arch == model_config._get_transformers_backend_cls():
+    if arch == model_config._get_model_impl_backend_cls():
         assert model_config.model_impl != "vllm"
         if model_config.model_impl == "auto":
             logger.warning_once(
diff --git a/vllm/model_executor/models/atom.py b/vllm/model_executor/models/atom.py
new
file mode 100644
index 000000000000..d41ba80073e1
--- /dev/null
+++ b/vllm/model_executor/models/atom.py
@@ -0,0 +1,110 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Copyright 2026 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Wrapper around `atom` models"""
+
+from collections.abc import Iterable
+
+import torch
+import torch.nn as nn
+
+from vllm.config import VllmConfig
+from vllm.distributed import get_pp_group, get_tp_group
+from vllm.logger import init_logger
+from vllm.model_executor.models.interfaces import (
+    SupportsPP,
+    SupportsQuant,
+)
+from vllm.model_executor.models.interfaces_base import VllmModel, VllmModelForTextGeneration
+from vllm.sequence import IntermediateTensors
+
+logger = init_logger(__name__)
+
+class ATOMModelBase(nn.Module, VllmModel, SupportsQuant, SupportsPP):
+    """Base class that delegates modeling to the external `atom` package."""
+    def __init_subclass__(cls, *args, **kwargs):
+        super().__init_subclass__(*args, **kwargs)
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        logger.info("Using ATOM modeling backend.")
+
+        self.config = vllm_config.model_config.hf_config
+        self.text_config = self.config.get_text_config()
+        self.cache_config = vllm_config.cache_config
+        self.device_config = vllm_config.device_config
+        self.model_config = vllm_config.model_config
+        self.parallel_config = vllm_config.parallel_config
+        self.quant_config = vllm_config.quant_config
+
+        self.pp_group = get_pp_group()
+        self.tp_group = get_tp_group()
+
+        # Weights to skip in `self.load_weights`
+        self.skip_prefixes: list[str] = []
+        self.skip_substrs: list[str] = []
+        self.ignore_unexpected_prefixes: list[str] = []
+        self.ignore_unexpected_suffixes: list[str] = []
+
+        import atom
+        self.model = atom.prepare_model(config=vllm_config, framework="vllm")
+        if self.model is None:
+            model_arch = vllm_config.model_config.architectures[0]
+            raise ValueError(f"Model {model_arch} is not supported by the ATOM backend")
+
+    def forward(
+        self,
+        input_ids: torch.Tensor | None,
+        positions: torch.Tensor,
+        intermediate_tensors: IntermediateTensors | None = None,
+        inputs_embeds: torch.Tensor | None = None,
+        **model_kwargs,
+    ) -> torch.Tensor | IntermediateTensors:
+        if not self.pp_group.is_first_rank:
+            assert intermediate_tensors is not None
+            input_ids = None
+            inputs_embeds = intermediate_tensors["hidden_states"]
+
+        hidden_states = self.model(
+            input_ids=input_ids,
+            positions=positions,
+            intermediate_tensors=intermediate_tensors,
+            inputs_embeds=inputs_embeds,
+            **model_kwargs,
+        )
+
+        if not self.pp_group.is_last_rank:
+            return IntermediateTensors({"hidden_states": hidden_states})
+
+        return hidden_states
+
+    def load_weights(
+        self,
+        weights: Iterable[tuple[str, torch.Tensor]],
+    ) -> set[str]:
+        return self.model.load_weights(weights)
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        logits = self.model.compute_logits(hidden_states)
+        return logits
+
+
+class ATOMForCausalLM(ATOMModelBase, VllmModelForTextGeneration): ...
+
+class ATOMMoEForCausalLM(ATOMModelBase, VllmModelForTextGeneration): ...
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index a0d8a78a2ae7..853d3a7934ec 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -470,6 +470,12 @@
     ),
 }
 
+_ATOM_BACKEND_MODELS = {
+    # Text generation models
+    "ATOMForCausalLM": ("atom", "ATOMForCausalLM"),
+    "ATOMMoEForCausalLM": ("atom", "ATOMMoEForCausalLM"),
+}
+
 _VLLM_MODELS = {
     **_TEXT_GENERATION_MODELS,
     **_EMBEDDING_MODELS,
@@ -478,6 +484,7 @@
     **_SPECULATIVE_DECODING_MODELS,
     **_TRANSFORMERS_SUPPORTED_MODELS,
     **_TRANSFORMERS_BACKEND_MODELS,
+    **_ATOM_BACKEND_MODELS,
 }
 
 # This variable is used as the args for subprocess.run(). We
@@ -806,7 +813,15 @@ def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
         if model_arch not in self.models:
             return None
 
         return _try_inspect_model_cls(model_arch, self.models[model_arch])
+
+    def _try_resolve_atom(
+        self,
+        architecture: str,
+        model_config: ModelConfig,
+    ) -> str | None:
+        """Resolve the ATOM wrapper class name; `architecture` is currently unused."""
+        return model_config._get_model_impl_backend_cls(model_impl="ATOM")
 
     def _try_resolve_transformers(
         self,
@@ -872,7 +887,7 @@ def _try_resolve_transformers(
                 "is not compatible with vLLM."
             )
 
-        return model_config._get_transformers_backend_cls()
+        return model_config._get_model_impl_backend_cls()
 
     def _normalize_arch(
         self,
@@ -920,6 +935,10 @@ def inspect_model_cls(
         elif model_config.model_impl == "terratorch":
             model_info = self._try_inspect_model_cls("Terratorch")
             return (model_info, "Terratorch")
+        elif model_config.model_impl == "atom":
+            arch = self._try_resolve_atom(architectures[0], model_config)
+            model_info = self._try_inspect_model_cls(arch)
+            return (model_info, arch)
 
         # Fallback to transformers impl (after resolving convert_type)
         if (
@@ -974,6 +993,12 @@ def resolve_model_cls(
             model_cls = self._try_load_model_cls(arch)
             if model_cls is not None:
                 return (model_cls, arch)
+        elif model_config.model_impl == "atom":
+            arch = self._try_resolve_atom(architectures[0], model_config)
+            if arch is not None:
+                model_cls = self._try_load_model_cls(arch)
+                if model_cls is not None:
+                    return (model_cls, arch)
 
         # Fallback to transformers impl (after resolving convert_type)
         if (
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 90c3b9e341f4..b85f20fec68b 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -266,6 +266,10 @@ def get_attn_backend_cls(
         from vllm._aiter_ops import rocm_aiter_ops
         from vllm.attention.backends.registry import AttentionBackendEnum
 
+        if selected_backend == AttentionBackendEnum.CUSTOM:
+            logger.info("Using CUSTOM backend.")
+            return AttentionBackendEnum.CUSTOM.get_path()
+
         if use_sparse:
             if kv_cache_dtype.startswith("fp8"):
                 raise ValueError(