Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions atom/model_ops/base_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,11 @@ def cp_mha_gather_cache_kernel(
k_reg = tl.load(key_cache_ptr_offset + k_reg_offset)
v_reg = tl.load(value_cache_ptr_offset + v_reg_offset)
if DEQUANT:
k_scale = 1.0
v_scale = 1.0
scale_offset = (
block_id * num_heads * PAGE_SIZE + head_id * PAGE_SIZE + slot_id
)
k_scale = tl.load(k_scale_ptr + scale_offset)
v_scale = tl.load(v_scale_ptr + scale_offset)
k_reg = k_reg.to(tl.float32) * k_scale
v_reg = v_reg.to(tl.float32) * v_scale
tl.store(key_ptr_offset + col_offsets, k_reg)
Expand Down
2 changes: 2 additions & 0 deletions atom/model_ops/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,8 @@ def rocm_asm_moe_impl(
quant_type_ in [QuantType.per_Token, QuantType.per_1x128]
and hidden_states.dtype in [torch.float16, torch.bfloat16]
and w1.dtype in [torch.int8, torch.uint8, torch.float8_e4m3fnuz]
and fc1_smooth_scale_fixed is not None
and fc2_smooth_scale_fixed is not None
)

return asm_moe(
Expand Down
62 changes: 41 additions & 21 deletions atom/models/glm4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from atom.utils.decorators import support_torch_compile
from torch import nn
from transformers.models.glm4_moe import Glm4MoeConfig
from typing import Any, Iterable

from .utils import (
IntermediateTensors,
Expand All @@ -33,6 +34,8 @@
maybe_prefix,
)

from atom.model_loader.loader import load_model_in_plugin_mode


class Glm4MoeMLP(nn.Module):
def __init__(
Expand Down Expand Up @@ -197,7 +200,7 @@ def __init__(
qkv_bias: bool = False,
use_qk_norm: bool = False,
cache_config: str = "bf16",
quant_config: QuantizationConfig | None = None,
atom_config: Config | None = None,
prefix: str = "",
rope_theta: float = 10000,
layer_num: int = 0,
Expand Down Expand Up @@ -231,15 +234,15 @@ def __init__(
self.total_num_heads,
self.total_num_kv_heads,
bias=qkv_bias,
quant_config=quant_config,
quant_config=atom_config.quant_config,
prefix=f"{prefix}.qkv_proj",
)

self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
quant_config=atom_config.quant_config,
prefix=f"{prefix}.o_proj",
)

Expand All @@ -254,12 +257,12 @@ def __init__(
partial_rotary_factor=partial_rotary_factor,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_heads=self.num_heads,
head_dim=self.head_dim,
scale=self.scaling,
num_kv_heads=self.num_kv_heads,
kv_cache_dtype=cache_config,
quant_config=quant_config,
config=atom_config,
prefix=f"{prefix}.attn",
layer_num=layer_num,
use_mla=False,
Expand All @@ -274,6 +277,7 @@ def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
**model_kwargs: dict[str, Any] | None,
) -> torch.Tensor:
qkv = self.qkv_proj(hidden_states)
q, k, v = torch.split(qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1)
Expand All @@ -286,7 +290,7 @@ def forward(
)

# q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, positions)
attn_output = self.attn(q, k, v, positions, **model_kwargs)
output = self.o_proj(attn_output)
return output

Expand All @@ -295,8 +299,7 @@ class Glm4MoeDecoderLayer(nn.Module):
def __init__(
self,
config: Glm4MoeConfig,
cache_config: str = "bf16",
quant_config: QuantizationConfig | None = None,
atom_config: Config | None = None,
prefix: str = "",
layer_num: int = 0,
enable_eplb: bool = False,
Expand All @@ -320,8 +323,8 @@ def __init__(
head_dim=config.head_dim,
rms_norm_eps=config.rms_norm_eps,
qkv_bias=config.attention_bias,
cache_config=cache_config,
quant_config=quant_config,
cache_config=atom_config.kv_cache_dtype,
atom_config=atom_config,
prefix=f"{prefix}.self_attn",
use_qk_norm=config.use_qk_norm,
rope_theta=rope_theta,
Expand All @@ -334,7 +337,7 @@ def __init__(
):
self.mlp = Glm4MoE(
config=config,
quant_config=quant_config,
quant_config=atom_config.quant_config,
prefix=f"{prefix}.mlp",
enable_eplb=enable_eplb,
)
Expand All @@ -343,7 +346,7 @@ def __init__(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
quant_config=atom_config.quant_config,
prefix=f"{prefix}.mlp",
)

Expand All @@ -358,13 +361,16 @@ def forward(
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
**model_kwargs: dict[str, Any] | None,
) -> tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)
hidden_states = self.self_attn(
positions=positions, hidden_states=hidden_states, **model_kwargs
)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
Expand All @@ -373,6 +379,7 @@ def forward(
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
}
Expand All @@ -382,8 +389,6 @@ def __init__(self, *, atom_config: Config, prefix: str = ""):
super().__init__()

config = atom_config.hf_config
cache_config = atom_config.kv_cache_dtype
quant_config = atom_config.quant_config
self.config = config

self.vocab_size = config.vocab_size
Expand All @@ -401,8 +406,7 @@ def __init__(self, *, atom_config: Config, prefix: str = ""):
config.num_hidden_layers,
lambda prefix, layer_num=None: Glm4MoeDecoderLayer(
config=config,
cache_config=cache_config,
quant_config=quant_config,
atom_config=atom_config,
prefix=prefix,
layer_num=layer_num,
),
Expand All @@ -427,6 +431,7 @@ def forward(
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**model_kwargs: dict[str, Any],
) -> torch.Tensor | IntermediateTensors:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
Expand All @@ -440,7 +445,9 @@ def forward(
residual = intermediate_tensors["residual"]

for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual)
hidden_states, residual = layer(
positions, hidden_states, residual, **model_kwargs
)

if not get_pp_group().is_last_rank:
return IntermediateTensors(
Expand Down Expand Up @@ -525,6 +532,7 @@ class Glm4MoeForCausalLM(nn.Module, Glm4MixtureOfExperts):

def __init__(self, atom_config: Config, prefix: str = ""):
super().__init__()
self.atom_config = atom_config
config = atom_config.hf_config
quant_config = atom_config.quant_config
self.config = config
Expand Down Expand Up @@ -577,9 +585,10 @@ def forward(
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**model_kwargs: dict[str, Any],
) -> torch.Tensor | IntermediateTensors:
hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds
input_ids, positions, intermediate_tensors, inputs_embeds, **model_kwargs
)
return hidden_states

Expand Down Expand Up @@ -612,6 +621,17 @@ def get_spec_layer_idx_from_weight_name(
return layer_idx + i
return None

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load this model's weights through the plugin-mode loader.

    The ``weights`` iterator passed by the caller is intentionally ignored:
    in plugin mode the loader fetches the weights itself.  ``prefix="model."``
    is used because Glm4MoeForCausalLM is constructed inside a model-wrapper
    class, so every loaded weight name carries a ``model.`` prefix — vLLM
    checks those names to confirm all weights were loaded.

    Returns:
        The set of weight names that were actually loaded.
    """
    return load_model_in_plugin_mode(
        model=self, config=self.atom_config, prefix="model."
    )


def get_spec_layer_idx_from_weight_name(
config: Glm4MoeConfig, weight_name: str
Expand Down
2 changes: 2 additions & 0 deletions atom/plugin/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from atom.models.qwen3 import Qwen3ForCausalLM
from atom.models.qwen3_moe import Qwen3MoeForCausalLM
from atom.models.glm4_moe import Glm4MoeForCausalLM
from atom.config import Config
from atom.plugin.prepare import is_vllm, is_sglang

Expand All @@ -10,6 +11,7 @@
# Maps architecture class names to the Atom model classes that implement them.
_ATOM_SUPPORTED_MODELS = {
"Qwen3ForCausalLM": Qwen3ForCausalLM,
"Qwen3MoeForCausalLM": Qwen3MoeForCausalLM,
"Glm4MoeForCausalLM": Glm4MoeForCausalLM,
}


Expand Down
1 change: 1 addition & 0 deletions atom/plugin/vllm/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
# Maps architecture class names to "module:ClassName" import strings,
# presumably resolved lazily at load time — confirm against the resolver.
_ATOM_MODEL_CLASSES: dict[str, str] = {
"Qwen3ForCausalLM": "atom.models.qwen3:Qwen3ForCausalLM",
"Qwen3MoeForCausalLM": "atom.models.qwen3_moe:Qwen3MoeForCausalLM",
"Glm4MoeForCausalLM": "atom.models.glm4_moe:Glm4MoeForCausalLM",
}


Expand Down
1 change: 1 addition & 0 deletions atom/plugin/vllm/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
# Maps architecture class names to the Atom wrapper registered with vLLM's
# model registry in place of the stock implementation.
_VLLM_MODEL_REGISTRY_OVERRIDES: dict[str, str] = {
"Qwen3ForCausalLM": ATOM_CAUSAL_LM_MODEL_WRAPPER,
"Qwen3MoeForCausalLM": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER,
"Glm4MoeForCausalLM": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER,
}


Expand Down
Loading