From c5165f590400c3998f08f32b2abbc4b91923b414 Mon Sep 17 00:00:00 2001 From: perzhang Date: Wed, 4 Mar 2026 06:40:31 +0000 Subject: [PATCH 1/5] [feat](gpt-oss): support gpt-oss for vllm plugin --- atom/models/gpt_oss.py | 16 +++++++++++++++- atom/plugin/vllm/model_wrapper.py | 1 + atom/plugin/vllm/register.py | 1 + 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/atom/models/gpt_oss.py b/atom/models/gpt_oss.py index eee08a710..138862271 100644 --- a/atom/models/gpt_oss.py +++ b/atom/models/gpt_oss.py @@ -17,7 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, Iterable import torch import torch.distributed as dist @@ -36,6 +36,8 @@ from atom.model_ops.linear import QKVParallelLinear, ReplicatedLinear, RowParallelLinear from atom.model_ops.moe import FusedMoE +from atom.model_loader.loader import load_model_in_plugin_mode + # from vllm.model_executor.model_loader.weight_utils import default_weight_loader from atom.models.utils import ( IntermediateTensors, @@ -391,3 +393,15 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: ckpt_up_proj_name="up_proj", num_experts=self.config.num_local_experts, ) + + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # load weights in plugin mode and discard passed weights generator + # here prefix is "model." because GptOssForCausalLM is constructed in model + # wrapper class, so the name of loaded weights are prefixed with "model.". + # The vLLM will check the name of the loaded weights to make sure all the + # weights are loaded correctly + loaded_weights_record = load_model_in_plugin_mode( + model=self, config=self.atom_config, prefix="model." + ) + return loaded_weights_record diff --git a/atom/plugin/vllm/model_wrapper.py b/atom/plugin/vllm/model_wrapper.py index bcca17ad1..dc1cae106 100644 --- a/atom/plugin/vllm/model_wrapper.py +++ b/atom/plugin/vllm/model_wrapper.py @@ -29,6 +29,7 @@ _ATOM_MODEL_CLASSES: dict[str, str] = { "Qwen3ForCausalLM": "atom.models.qwen3:Qwen3ForCausalLM", "Qwen3MoeForCausalLM": "atom.models.qwen3_moe:Qwen3MoeForCausalLM", + "GptOssForCausalLM": "atom.models.gpt_oss:GptOssForCausalLM", } diff --git a/atom/plugin/vllm/register.py b/atom/plugin/vllm/register.py index 0330929a5..24b00f04d 100644 --- a/atom/plugin/vllm/register.py +++ b/atom/plugin/vllm/register.py @@ -22,6 +22,7 @@ _VLLM_MODEL_REGISTRY_OVERRIDES: dict[str, str] = { "Qwen3ForCausalLM": ATOM_CAUSAL_LM_MODEL_WRAPPER, "Qwen3MoeForCausalLM": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER, + "GptOssForCausalLM": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER, } From 62198706ab764d8e9c6febfb6fa1ca49d98817a7 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 5 Mar 2026 06:26:11 +0000 Subject: [PATCH 2/5] [Fix](gpt-oss): fix attention for gpt-oss --- atom/models/gpt_oss.py | 8 +++++++- atom/plugin/attention_mha.py | 25 +++++++++++++++++-------- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/atom/models/gpt_oss.py b/atom/models/gpt_oss.py index 138862271..1e676b4d2 100644 --- a/atom/models/gpt_oss.py +++ b/atom/models/gpt_oss.py @@ -37,6 +37,7 @@ from atom.model_ops.moe import FusedMoE from atom.model_loader.loader import load_model_in_plugin_mode +from atom.utils import envs # from vllm.model_executor.model_loader.weight_utils import default_weight_loader from atom.models.utils import ( @@ -143,7 +144,12 @@ def forward( qkv = self.qkv_proj(hidden_states) q, k, v = torch.split(qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1) # q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, positions) + if envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION: + attn_output = self.attn( + query=q, key=k, value=v, positions=positions, q_scale=None, qkv=qkv + ) + else: + attn_output = self.attn(q, k, v, positions) output = self.o_proj(attn_output) return output diff --git a/atom/plugin/attention_mha.py b/atom/plugin/attention_mha.py index fd6420bc3..b521c1229 100644 --- a/atom/plugin/attention_mha.py +++ b/atom/plugin/attention_mha.py @@ -75,7 +75,7 @@ def rope_cache_plugin_mode( # ATOM: [num_blocks, num_kv_heads, head_size, block_size], # vLLM: [num_blocks, num_kv_heads, block_size // x, head_size, x], v_cache_template = torch.empty( - [num_blocks, num_kv_heads, block_size // x, head_size, x], + [num_blocks, num_kv_heads, head_size, block_size], dtype=v_cache.dtype, device="meta", ) @@ -134,9 +134,13 @@ def rope_cache_plugin_mode( [self.num_heads, self.num_kv_heads, self.num_kv_heads], dim=1 ) elif use_triton_attn and self.rotary_emb is not None: - k_scale = v_scale = self.kv_scale - - q, k, k_cache, v_cache = fused_qk_rope_reshape_and_cache( + + k_scale = v_scale = self.one_scale + qkv = qkv.view(qkv.shape[0], -1, self.head_dim) + q, k, v = qkv.split( + [self.num_heads, self.num_kv_heads, self.num_kv_heads], dim=1 + ) + q, k, _k_cache, _v_cache = fused_qk_rope_reshape_and_cache( q, k, v, @@ -229,9 +233,7 @@ def paged_attention_triton_plugin_mode( ) per_tensor = False - if k_scale is not None: - per_tensor = k_scale.numel() == 1 - if not per_tensor: + if k_scale is not None and k_scale.numel() > 1: k_scale = k_scale.unsqueeze(-1) v_scale = v_scale.unsqueeze(-1) compute_type = ( @@ -557,6 +559,13 @@ def forward_impl_plugin_mode( # update the layer kv scale tensor self.k_scale = self.kv_scale[0] self.v_scale = self.kv_scale[1] + one_kv_scale = ( + torch.finfo(torch.float8_e4m3fn).max / torch.finfo(aiter.dtypes.fp8).max + if self.kv_cache_dtype == "fp8" + else 1.0 + ) + kv_scale = torch.ones((1,), dtype=torch.float32, device=self.device) + self.one_scale = kv_scale layer.k_scale = self.k_scale layer.v_scale = self.v_scale @@ -669,7 +678,7 @@ def forward_impl_plugin_mode( device="meta", ) v_cache_template = torch.empty( - [num_blocks, num_kv_heads, block_size // x, head_size, x], + [num_blocks, num_kv_heads, head_size, block_size], dtype=v_cache.dtype, device="meta", ) From b57030060859ec48d500222211f49e5a280bbdb6 Mon Sep 17 00:00:00 2001 From: perzhang Date: Fri, 6 Mar 2026 03:26:13 +0000 Subject: [PATCH 3/5] [Fix](gpt-oss): add sink for extend forward --- atom/plugin/attention_mha.py | 1 + 1 file changed, 1 insertion(+) diff --git a/atom/plugin/attention_mha.py b/atom/plugin/attention_mha.py index b521c1229..ba84dfdbc 100644 --- a/atom/plugin/attention_mha.py +++ b/atom/plugin/attention_mha.py @@ -365,6 +365,7 @@ def extend_for_sliding_window( causal=True, window_size=sliding_window, alibi_slopes=self.alibi_slopes, + sink_ptr=self.sinks, return_lse=False, out=output, ) From 8e746ed23f3c42be8a0d24221b55bdd92ef46037 Mon Sep 17 00:00:00 2001 From: perzhang Date: Fri, 6 Mar 2026 06:49:25 +0000 Subject: [PATCH 4/5] [fix](gpt-oss): mv load_weights func to base class --- atom/models/gpt_oss.py | 12 ------------ atom/models/qwen3_moe.py | 11 ----------- atom/plugin/attention_mha.py | 8 +------- atom/plugin/vllm/model_wrapper.py | 6 +++++- 4 files changed, 6 insertions(+), 31 deletions(-) diff --git a/atom/models/gpt_oss.py b/atom/models/gpt_oss.py index 1e676b4d2..381ad684c 100644 --- a/atom/models/gpt_oss.py +++ b/atom/models/gpt_oss.py @@ -399,15 +399,3 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: ckpt_up_proj_name="up_proj", num_experts=self.config.num_local_experts, ) - - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - # load weights in plugin mode and discard passed weights generator - # here prefix is "model." because GptOssForCausalLM is constructed in model - # wrapper class, so the name of loaded weights are prefixed with "model.". - # The vLLM will check the name of the loaded weights to make sure all the - # weights are loaded correctly - loaded_weights_record = load_model_in_plugin_mode( - model=self, config=self.atom_config, prefix="model." - ) - return loaded_weights_record diff --git a/atom/models/qwen3_moe.py b/atom/models/qwen3_moe.py index 9a0e1eba1..8dc1b60c7 100644 --- a/atom/models/qwen3_moe.py +++ b/atom/models/qwen3_moe.py @@ -521,14 +521,3 @@ def make_empty_intermediate_tensors( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - # load weights in plugin mode and discard passed weights generator - # here prefix is "model." because Qwen3MoeForCausalLM is constructed in model - # wrapper class, so the name of loaded weights are prefixed with "model.". - # The vLLM will check the name of the loaded weights to make sure all the - # weights are loaded correctly - loaded_weights_record = load_model_in_plugin_mode( - model=self, config=self.atom_config, prefix="model." - ) - return loaded_weights_record diff --git a/atom/plugin/attention_mha.py b/atom/plugin/attention_mha.py index ba84dfdbc..f7b436226 100644 --- a/atom/plugin/attention_mha.py +++ b/atom/plugin/attention_mha.py @@ -560,13 +560,7 @@ def forward_impl_plugin_mode( # update the layer kv scale tensor self.k_scale = self.kv_scale[0] self.v_scale = self.kv_scale[1] - one_kv_scale = ( - torch.finfo(torch.float8_e4m3fn).max / torch.finfo(aiter.dtypes.fp8).max - if self.kv_cache_dtype == "fp8" - else 1.0 - ) - kv_scale = torch.ones((1,), dtype=torch.float32, device=self.device) - self.one_scale = kv_scale + self.one_scale = torch.ones((1,), dtype=torch.float32, device=self.device) layer.k_scale = self.k_scale layer.v_scale = self.v_scale diff --git a/atom/plugin/vllm/model_wrapper.py b/atom/plugin/vllm/model_wrapper.py index dc1cae106..e3a097713 100644 --- a/atom/plugin/vllm/model_wrapper.py +++ b/atom/plugin/vllm/model_wrapper.py @@ -20,6 +20,7 @@ import atom # noqa: F401 from atom.plugin.config import generate_atom_config_for_plugin_mode +from atom.model_loader.loader import load_model_in_plugin_mode import logging @@ -126,7 +127,10 @@ def load_weights( self, weights: Iterable[tuple[str, torch.Tensor]], ) -> set[str]: - return self.model.load_weights(weights) + loaded_weights_record = load_model_in_plugin_mode( + model=self.model, config=self.model.atom_config, prefix="model." + ) + return loaded_weights_record def compute_logits( self, From 441702e3dd8036b31fb955a19304bd186bdda12a Mon Sep 17 00:00:00 2001 From: perzhang Date: Fri, 6 Mar 2026 07:02:22 +0000 Subject: [PATCH 5/5] [fix](gpt-oss): fix format error --- atom/models/gpt_oss.py | 5 ++--- atom/models/qwen3_moe.py | 3 +-- atom/plugin/attention_mha.py | 6 +++--- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/atom/models/gpt_oss.py b/atom/models/gpt_oss.py index 381ad684c..c0db38aea 100644 --- a/atom/models/gpt_oss.py +++ b/atom/models/gpt_oss.py @@ -17,7 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Iterable +from typing import Optional import torch import torch.distributed as dist @@ -36,7 +36,6 @@ from atom.model_ops.linear import QKVParallelLinear, ReplicatedLinear, RowParallelLinear from atom.model_ops.moe import FusedMoE -from atom.model_loader.loader import load_model_in_plugin_mode from atom.utils import envs # from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -148,7 +147,7 @@ def forward( attn_output = self.attn( query=q, key=k, value=v, positions=positions, q_scale=None, qkv=qkv ) - else: + else: attn_output = self.attn(q, k, v, positions) output = self.o_proj(attn_output) return output diff --git a/atom/models/qwen3_moe.py b/atom/models/qwen3_moe.py index 8dc1b60c7..2c3df6eb6 100644 --- a/atom/models/qwen3_moe.py +++ b/atom/models/qwen3_moe.py @@ -1,4 +1,4 @@ -from typing import Optional, Union, Any, Iterable +from typing import Optional, Union, Any import torch from aiter.dist.communication_op import tensor_model_parallel_all_reduce @@ -33,7 +33,6 @@ ) from atom.utils import envs from torch import nn -from atom.model_loader.loader import load_model_in_plugin_mode # import torch.distributed as dist from transformers import PretrainedConfig diff --git a/atom/plugin/attention_mha.py b/atom/plugin/attention_mha.py index f7b436226..4dfbe7ab9 100644 --- a/atom/plugin/attention_mha.py +++ b/atom/plugin/attention_mha.py @@ -134,7 +134,7 @@ def rope_cache_plugin_mode( [self.num_heads, self.num_kv_heads, self.num_kv_heads], dim=1 ) elif use_triton_attn and self.rotary_emb is not None: - + k_scale = v_scale = self.one_scale qkv = qkv.view(qkv.shape[0], -1, self.head_dim) q, k, v = qkv.split( @@ -234,8 +234,8 @@ def paged_attention_triton_plugin_mode( per_tensor = False if k_scale is not None and k_scale.numel() > 1: - k_scale = k_scale.unsqueeze(-1) - v_scale = v_scale.unsqueeze(-1) + k_scale = k_scale.unsqueeze(-1) + v_scale = v_scale.unsqueeze(-1) compute_type = ( torch.bfloat16 if self.kv_cache_dtype == "bf16" or per_tensor