atom/models/gpt_oss.py (8 additions, 1 deletion)
@@ -36,6 +36,8 @@
 from atom.model_ops.linear import QKVParallelLinear, ReplicatedLinear, RowParallelLinear
 from atom.model_ops.moe import FusedMoE
 
+from atom.utils import envs
+
 # from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from atom.models.utils import (
     IntermediateTensors,
@@ -141,7 +143,12 @@ def forward(
         qkv = self.qkv_proj(hidden_states)
         q, k, v = torch.split(qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1)
         # q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, positions)
+        if envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION:
+            attn_output = self.attn(
+                query=q, key=k, value=v, positions=positions, q_scale=None, qkv=qkv
+            )
+        else:
+            attn_output = self.attn(q, k, v, positions)
         output = self.o_proj(attn_output)
         return output
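For anyone exercising the new path on a live deployment: the fusion is opt-in via the environment flag, so both branches can be compared without code changes. A minimal sketch, assuming `atom.utils.envs` reads the flag from the process environment at import time (the model id and entry point are illustrative, not part of this PR):

```python
import os

# Assumption: the flag must be set before the engine (and the atom plugin)
# is imported, since env-flag modules are typically read once at import time.
os.environ["ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="openai/gpt-oss-20b")  # illustrative model id
out = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)
```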
atom/models/qwen3_moe.py (1 addition, 13 deletions)
@@ -1,4 +1,4 @@
-from typing import Optional, Union, Any, Iterable
+from typing import Optional, Union, Any
 
 import torch
 from aiter.dist.communication_op import tensor_model_parallel_all_reduce
@@ -33,7 +33,6 @@
 )
 from atom.utils import envs
 from torch import nn
-from atom.model_loader.loader import load_model_in_plugin_mode
 
 # import torch.distributed as dist
 from transformers import PretrainedConfig
@@ -521,14 +520,3 @@ def make_empty_intermediate_tensors(
 
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         return self.model.get_expert_mapping()
-
-    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
-        # load weights in plugin mode and discard passed weights generator
-        # here prefix is "model." because Qwen3MoeForCausalLM is constructed in model
-        # wrapper class, so the name of loaded weights are prefixed with "model.".
-        # The vLLM will check the name of the loaded weights to make sure all the
-        # weights are loaded correctly
-        loaded_weights_record = load_model_in_plugin_mode(
-            model=self, config=self.atom_config, prefix="model."
-        )
-        return loaded_weights_record
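The comment being deleted here explains the `prefix="model."` argument: in plugin mode the wrapper class (see `atom/plugin/vllm/model_wrapper.py` below, where this logic now lives) holds the inner `Qwen3MoeForCausalLM` as `self.model`, so every parameter name vLLM sees carries that prefix. A toy illustration with hypothetical module names:

```python
import torch.nn as nn

class Inner(nn.Module):  # stands in for Qwen3MoeForCausalLM
    def __init__(self):
        super().__init__()
        self.lm_head = nn.Linear(4, 4)

class Wrapper(nn.Module):  # stands in for the vLLM-facing wrapper
    def __init__(self):
        super().__init__()
        self.model = Inner()  # the attribute name "model" becomes the prefix

print([name for name, _ in Wrapper().named_parameters()])
# ['model.lm_head.weight', 'model.lm_head.bias'] -> hence prefix="model."
```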
atom/plugin/attention_mha.py (13 additions, 9 deletions)
@@ -75,7 +75,7 @@ def rope_cache_plugin_mode(
         # ATOM: [num_blocks, num_kv_heads, head_size, block_size],
         # vLLM: [num_blocks, num_kv_heads, block_size // x, head_size, x],
         v_cache_template = torch.empty(
-            [num_blocks, num_kv_heads, block_size // x, head_size, x],
+            [num_blocks, num_kv_heads, head_size, block_size],
             dtype=v_cache.dtype,
             device="meta",
         )
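The comment above pins down the two cache layouts this template bridges. A hedged sketch of how they relate, assuming the trailing `x` dimension packs tokens within a block (token `t` lives at group `t // x`, offset `t % x`); all sizes are illustrative:

```python
import torch

num_blocks, num_kv_heads, head_size, block_size, x = 2, 4, 64, 16, 8  # illustrative

vllm_cache = torch.randn(num_blocks, num_kv_heads, block_size // x, head_size, x)

# Move the token-group and offset dims next to each other, then fold them
# back into a single block_size axis to obtain the ATOM layout.
atom_cache = vllm_cache.permute(0, 1, 3, 2, 4).reshape(
    num_blocks, num_kv_heads, head_size, block_size
)
assert atom_cache.shape == (num_blocks, num_kv_heads, head_size, block_size)
```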
@@ -134,9 +134,13 @@ def rope_cache_plugin_mode(
                 [self.num_heads, self.num_kv_heads, self.num_kv_heads], dim=1
             )
         elif use_triton_attn and self.rotary_emb is not None:
-            k_scale = v_scale = self.kv_scale
-
-            q, k, k_cache, v_cache = fused_qk_rope_reshape_and_cache(
+            k_scale = v_scale = self.one_scale
+            qkv = qkv.view(qkv.shape[0], -1, self.head_dim)
+            q, k, v = qkv.split(
+                [self.num_heads, self.num_kv_heads, self.num_kv_heads], dim=1
+            )
+            q, k, _k_cache, _v_cache = fused_qk_rope_reshape_and_cache(
                 q,
                 k,
                 v,
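The new split mirrors the GQA packing of the fused QKV projection: the flat `[num_tokens, (H + 2*H_kv) * head_dim]` activation is viewed as `[num_tokens, H + 2*H_kv, head_dim]` and split along the head axis. A quick shape check with made-up sizes:

```python
import torch

num_tokens, num_heads, num_kv_heads, head_dim = 4, 32, 8, 64  # made-up sizes
qkv = torch.randn(num_tokens, (num_heads + 2 * num_kv_heads) * head_dim)

qkv = qkv.view(qkv.shape[0], -1, head_dim)
q, k, v = qkv.split([num_heads, num_kv_heads, num_kv_heads], dim=1)

assert q.shape == (num_tokens, num_heads, head_dim)
assert k.shape == v.shape == (num_tokens, num_kv_heads, head_dim)
```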
@@ -229,11 +233,9 @@ def paged_attention_triton_plugin_mode(
         )
 
         per_tensor = False
-        if k_scale is not None:
-            per_tensor = k_scale.numel() == 1
-            if not per_tensor:
-                k_scale = k_scale.unsqueeze(-1)
-                v_scale = v_scale.unsqueeze(-1)
+        if k_scale is not None and k_scale.numel() > 1:
+            k_scale = k_scale.unsqueeze(-1)
+            v_scale = v_scale.unsqueeze(-1)
         compute_type = (
             torch.bfloat16
             if self.kv_cache_dtype == "bf16" or per_tensor
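The collapsed condition preserves the broadcasting contract: a per-channel scale of shape `[num_kv_heads]` gains a trailing dimension so it broadcasts against a `[num_kv_heads, head_size]` cache slice, while a one-element per-tensor scale broadcasts as-is. One behavioral point visible in the hunk: `per_tensor` is no longer updated, so the `compute_type` selection below now always sees it as `False`. A shape-only illustration (sizes are made up):

```python
import torch

kv = torch.randn(8, 128)  # [num_kv_heads, head_size], made-up sizes
per_tensor_scale = torch.tensor([0.5])
per_channel_scale = torch.rand(8)

assert (kv * per_tensor_scale).shape == kv.shape  # numel() == 1: broadcasts as-is
assert (kv * per_channel_scale.unsqueeze(-1)).shape == kv.shape  # numel() > 1
```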
@@ -363,6 +365,7 @@ def extend_for_sliding_window(
             causal=True,
             window_size=sliding_window,
             alibi_slopes=self.alibi_slopes,
+            sink_ptr=self.sinks,
             return_lse=False,
             out=output,
         )
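`sink_ptr` threads the model's learned attention sinks into the sliding-window kernel; gpt-oss uses a per-head sink logit that joins the softmax normalization without contributing a value vector, letting a head park probability mass when nothing in the window is relevant. An illustrative (non-kernel) rendering of the mechanism:

```python
import torch

def softmax_with_sink(scores: torch.Tensor, sink: torch.Tensor) -> torch.Tensor:
    """scores: [..., seq_len] attention logits; sink: scalar learned logit
    (the real kernel uses one sink per head)."""
    # Append the sink logit, normalize over seq_len + 1 entries, then drop the
    # sink column: each row now sums to at most 1 instead of exactly 1.
    padded = torch.cat([scores, sink.expand(*scores.shape[:-1], 1)], dim=-1)
    return torch.softmax(padded, dim=-1)[..., :-1]
```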
@@ -557,6 +560,7 @@ def forward_impl_plugin_mode(
         # update the layer kv scale tensor
         self.k_scale = self.kv_scale[0]
         self.v_scale = self.kv_scale[1]
+        self.one_scale = torch.ones((1,), dtype=torch.float32, device=self.device)
         layer.k_scale = self.k_scale
         layer.v_scale = self.v_scale
 
@@ -669,7 +673,7 @@ def forward_impl_plugin_mode(
             device="meta",
         )
         v_cache_template = torch.empty(
-            [num_blocks, num_kv_heads, block_size // x, head_size, x],
+            [num_blocks, num_kv_heads, head_size, block_size],
             dtype=v_cache.dtype,
             device="meta",
         )
atom/plugin/vllm/model_wrapper.py (6 additions, 1 deletion)
@@ -20,6 +20,7 @@
 
 import atom  # noqa: F401
 from atom.plugin.config import generate_atom_config_for_plugin_mode
+from atom.model_loader.loader import load_model_in_plugin_mode
 
 import logging
 
@@ -29,6 +30,7 @@
 _ATOM_MODEL_CLASSES: dict[str, str] = {
     "Qwen3ForCausalLM": "atom.models.qwen3:Qwen3ForCausalLM",
     "Qwen3MoeForCausalLM": "atom.models.qwen3_moe:Qwen3MoeForCausalLM",
+    "GptOssForCausalLM": "atom.models.gpt_oss:GptOssForCausalLM",
 }
 
 
@@ -125,7 +127,10 @@ def load_weights(
         self,
         weights: Iterable[tuple[str, torch.Tensor]],
     ) -> set[str]:
-        return self.model.load_weights(weights)
+        loaded_weights_record = load_model_in_plugin_mode(
+            model=self.model, config=self.model.atom_config, prefix="model."
+        )
+        return loaded_weights_record
 
     def compute_logits(
         self,
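This is the body relocated from the `load_weights` method deleted in `atom/models/qwen3_moe.py` above: the weights iterator supplied by vLLM is deliberately discarded, `load_model_in_plugin_mode` reads the checkpoint itself, and the returned set of names satisfies vLLM's check that all weights were loaded.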
atom/plugin/vllm/register.py (1 addition)
@@ -22,6 +22,7 @@
 _VLLM_MODEL_REGISTRY_OVERRIDES: dict[str, str] = {
     "Qwen3ForCausalLM": ATOM_CAUSAL_LM_MODEL_WRAPPER,
     "Qwen3MoeForCausalLM": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER,
+    "GptOssForCausalLM": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER,
 }
 
 
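Together with the `_ATOM_MODEL_CLASSES` entry above, this completes the GPT-OSS wiring along the same path as Qwen3-MoE: vLLM resolves `GptOssForCausalLM` to the ATOM MoE wrapper, and the wrapper resolves the same architecture name to `atom.models.gpt_oss:GptOssForCausalLM`. A minimal sketch of that second lookup, assuming the `module:Class` string convention shown in the dict (the helper name is hypothetical):

```python
import importlib

_ATOM_MODEL_CLASSES = {
    "GptOssForCausalLM": "atom.models.gpt_oss:GptOssForCausalLM",
}

def resolve_atom_model(architecture: str):
    """Resolve an architecture name to its ATOM model class (sketch)."""
    module_path, _, class_name = _ATOM_MODEL_CLASSES[architecture].partition(":")
    return getattr(importlib.import_module(module_path), class_name)
```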