# Patch summary (descriptive preamble placed before the first "diff --git" header).
# Function names below are taken from the "@@ ... @@" hunk headers, which only
# name the nearest preceding definition; confirm against the full files.
#
# atom/models/gpt_oss.py
#   - Adds `from atom.utils import envs`.
#   - In the hunk labelled `forward`: when
#     `envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION` is truthy, `self.attn`
#     is called with keyword arguments and additionally receives `q_scale=None`
#     and the unsplit `qkv` tensor; otherwise the previous positional call
#     `self.attn(q, k, v, positions)` is kept.
#
# atom/models/qwen3_moe.py
#   - Drops the `Iterable` and `load_model_in_plugin_mode` imports.
#   - Removes the model-level `load_weights` method (the one that called
#     `load_model_in_plugin_mode(..., prefix="model.")`).
#
# atom/plugin/attention_mha.py
#   - `v_cache_template` shape changes from
#     `[num_blocks, num_kv_heads, block_size // x, head_size, x]` to
#     `[num_blocks, num_kv_heads, head_size, block_size]` in two hunks
#     (labelled `rope_cache_plugin_mode` and `forward_impl_plugin_mode`).
#   - In the `use_triton_attn and self.rotary_emb is not None` branch:
#     `k_scale`/`v_scale` now come from `self.one_scale` instead of
#     `self.kv_scale`; `qkv` is viewed as `(qkv.shape[0], -1, self.head_dim)`
#     and split into q/k/v by head counts; the two cache values returned by
#     `fused_qk_rope_reshape_and_cache` are bound to `_k_cache`/`_v_cache`.
#   - In the hunk labelled `paged_attention_triton_plugin_mode`: the scales are
#     unsqueezed only when `k_scale` is not None and has more than one element.
#     NOTE(review): `per_tensor` is initialised to False and the only visible
#     reassignment is removed, while the `compute_type` expression in the
#     trailing context still reads `per_tensor` - confirm this is intended.
#   - In the hunk labelled `extend_for_sliding_window`: passes
#     `sink_ptr=self.sinks` to the attention call.
#   - In the hunk labelled `forward_impl_plugin_mode`: creates
#     `self.one_scale` as a one-element float32 tensor of ones on `self.device`.
#
# atom/plugin/vllm/model_wrapper.py
#   - Imports `load_model_in_plugin_mode`.
#   - Adds "GptOssForCausalLM" to `_ATOM_MODEL_CLASSES`.
#   - `load_weights` now calls `load_model_in_plugin_mode(model=self.model,
#     config=self.model.atom_config, prefix="model.")` and returns its result
#     instead of delegating to `self.model.load_weights(weights)`.
#     NOTE(review): the `weights` parameter is not read by the new body.
#
# atom/plugin/vllm/register.py
#   - Maps "GptOssForCausalLM" to `ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER` in
#     `_VLLM_MODEL_REGISTRY_OVERRIDES`.
diff --git a/atom/models/gpt_oss.py b/atom/models/gpt_oss.py index eee08a71..c0db38ae 100644 --- a/atom/models/gpt_oss.py +++ b/atom/models/gpt_oss.py @@ -36,6 +36,8 @@ from atom.model_ops.linear import QKVParallelLinear, ReplicatedLinear, RowParallelLinear from atom.model_ops.moe import FusedMoE +from atom.utils import envs + # from vllm.model_executor.model_loader.weight_utils import default_weight_loader from atom.models.utils import ( IntermediateTensors, @@ -141,7 +143,12 @@ def forward( qkv = self.qkv_proj(hidden_states) q, k, v = torch.split(qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1) # q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, positions) + if envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION: + attn_output = self.attn( + query=q, key=k, value=v, positions=positions, q_scale=None, qkv=qkv + ) + else: + attn_output = self.attn(q, k, v, positions) output = self.o_proj(attn_output) return output diff --git a/atom/models/qwen3_moe.py b/atom/models/qwen3_moe.py index 9a0e1eba..2c3df6eb 100644 --- a/atom/models/qwen3_moe.py +++ b/atom/models/qwen3_moe.py @@ -1,4 +1,4 @@ -from typing import Optional, Union, Any, Iterable +from typing import Optional, Union, Any import torch from aiter.dist.communication_op import tensor_model_parallel_all_reduce @@ -33,7 +33,6 @@ ) from atom.utils import envs from torch import nn -from atom.model_loader.loader import load_model_in_plugin_mode # import torch.distributed as dist from transformers import PretrainedConfig @@ -521,14 +520,3 @@ def make_empty_intermediate_tensors( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - # load weights in plugin mode and discard passed weights generator - # here prefix is "model." because Qwen3MoeForCausalLM is constructed in model - # wrapper class, so the name of loaded weights are prefixed with "model.". 
 - # The vLLM will check the name of the loaded weights to make sure all the - # weights are loaded correctly - loaded_weights_record = load_model_in_plugin_mode( - model=self, config=self.atom_config, prefix="model." - ) - return loaded_weights_record diff --git a/atom/plugin/attention_mha.py b/atom/plugin/attention_mha.py index fd6420bc..4dfbe7ab 100644 --- a/atom/plugin/attention_mha.py +++ b/atom/plugin/attention_mha.py @@ -75,7 +75,7 @@ def rope_cache_plugin_mode( # ATOM: [num_blocks, num_kv_heads, head_size, block_size], # vLLM: [num_blocks, num_kv_heads, block_size // x, head_size, x], v_cache_template = torch.empty( - [num_blocks, num_kv_heads, block_size // x, head_size, x], + [num_blocks, num_kv_heads, head_size, block_size], dtype=v_cache.dtype, device="meta", ) @@ -134,9 +134,13 @@ def rope_cache_plugin_mode( [self.num_heads, self.num_kv_heads, self.num_kv_heads], dim=1 ) elif use_triton_attn and self.rotary_emb is not None: - k_scale = v_scale = self.kv_scale - q, k, k_cache, v_cache = fused_qk_rope_reshape_and_cache( + k_scale = v_scale = self.one_scale + qkv = qkv.view(qkv.shape[0], -1, self.head_dim) + q, k, v = qkv.split( + [self.num_heads, self.num_kv_heads, self.num_kv_heads], dim=1 + ) + q, k, _k_cache, _v_cache = fused_qk_rope_reshape_and_cache( q, k, v, @@ -229,11 +233,9 @@ def paged_attention_triton_plugin_mode( ) per_tensor = False - if k_scale is not None: - per_tensor = k_scale.numel() == 1 - if not per_tensor: - k_scale = k_scale.unsqueeze(-1) - v_scale = v_scale.unsqueeze(-1) + if k_scale is not None and k_scale.numel() > 1: + k_scale = k_scale.unsqueeze(-1) + v_scale = v_scale.unsqueeze(-1) compute_type = ( torch.bfloat16 if self.kv_cache_dtype == "bf16" or per_tensor @@ -363,6 +365,7 @@ def extend_for_sliding_window( causal=True, window_size=sliding_window, alibi_slopes=self.alibi_slopes, + sink_ptr=self.sinks, return_lse=False, out=output, ) @@ -557,6 +560,7 @@ def forward_impl_plugin_mode( # update the layer kv scale tensor 
self.k_scale = self.kv_scale[0] self.v_scale = self.kv_scale[1] + self.one_scale = torch.ones((1,), dtype=torch.float32, device=self.device) layer.k_scale = self.k_scale layer.v_scale = self.v_scale @@ -669,7 +673,7 @@ def forward_impl_plugin_mode( device="meta", ) v_cache_template = torch.empty( - [num_blocks, num_kv_heads, block_size // x, head_size, x], + [num_blocks, num_kv_heads, head_size, block_size], dtype=v_cache.dtype, device="meta", ) diff --git a/atom/plugin/vllm/model_wrapper.py b/atom/plugin/vllm/model_wrapper.py index bcca17ad..e3a09771 100644 --- a/atom/plugin/vllm/model_wrapper.py +++ b/atom/plugin/vllm/model_wrapper.py @@ -20,6 +20,7 @@ import atom # noqa: F401 from atom.plugin.config import generate_atom_config_for_plugin_mode +from atom.model_loader.loader import load_model_in_plugin_mode import logging @@ -29,6 +30,7 @@ _ATOM_MODEL_CLASSES: dict[str, str] = { "Qwen3ForCausalLM": "atom.models.qwen3:Qwen3ForCausalLM", "Qwen3MoeForCausalLM": "atom.models.qwen3_moe:Qwen3MoeForCausalLM", + "GptOssForCausalLM": "atom.models.gpt_oss:GptOssForCausalLM", } @@ -125,7 +127,10 @@ def load_weights( self, weights: Iterable[tuple[str, torch.Tensor]], ) -> set[str]: - return self.model.load_weights(weights) + loaded_weights_record = load_model_in_plugin_mode( + model=self.model, config=self.model.atom_config, prefix="model." + ) + return loaded_weights_record def compute_logits( self, diff --git a/atom/plugin/vllm/register.py b/atom/plugin/vllm/register.py index 0330929a..24b00f04 100644 --- a/atom/plugin/vllm/register.py +++ b/atom/plugin/vllm/register.py @@ -22,6 +22,7 @@ _VLLM_MODEL_REGISTRY_OVERRIDES: dict[str, str] = { "Qwen3ForCausalLM": ATOM_CAUSAL_LM_MODEL_WRAPPER, "Qwen3MoeForCausalLM": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER, + "GptOssForCausalLM": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER, }