From 0e982151afe4336a0dc62970f907145e13be1289 Mon Sep 17 00:00:00 2001 From: kliuae Date: Thu, 5 Mar 2026 09:01:20 +0000 Subject: [PATCH 1/4] add vllm OOT plugin support for glm4moe Signed-off-by: kliuae --- atom/models/glm4_moe.py | 62 ++++++++++++++++++++----------- atom/plugin/register.py | 2 + atom/plugin/vllm/model_wrapper.py | 1 + atom/plugin/vllm/register.py | 1 + 4 files changed, 45 insertions(+), 21 deletions(-) diff --git a/atom/models/glm4_moe.py b/atom/models/glm4_moe.py index 31fe1d243..16984dc57 100644 --- a/atom/models/glm4_moe.py +++ b/atom/models/glm4_moe.py @@ -24,6 +24,7 @@ from atom.utils.decorators import support_torch_compile from torch import nn from transformers.models.glm4_moe import Glm4MoeConfig +from typing import Any, Iterable from .utils import ( IntermediateTensors, @@ -33,6 +34,8 @@ maybe_prefix, ) +from atom.model_loader.loader import load_model_in_plugin_mode + class Glm4MoeMLP(nn.Module): def __init__( @@ -197,7 +200,7 @@ def __init__( qkv_bias: bool = False, use_qk_norm: bool = False, cache_config: str = "bf16", - quant_config: QuantizationConfig | None = None, + atom_config: Config | None = None, prefix: str = "", rope_theta: float = 10000, layer_num: int = 0, @@ -231,7 +234,7 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=qkv_bias, - quant_config=quant_config, + quant_config=atom_config.quant_config, prefix=f"{prefix}.qkv_proj", ) @@ -239,7 +242,7 @@ def __init__( self.total_num_heads * self.head_dim, hidden_size, bias=False, - quant_config=quant_config, + quant_config=atom_config.quant_config, prefix=f"{prefix}.o_proj", ) @@ -254,12 +257,12 @@ def __init__( partial_rotary_factor=partial_rotary_factor, ) self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, + num_heads=self.num_heads, + head_dim=self.head_dim, + scale=self.scaling, num_kv_heads=self.num_kv_heads, kv_cache_dtype=cache_config, - quant_config=quant_config, + config=atom_config, prefix=f"{prefix}.attn", layer_num=layer_num, 
use_mla=False, @@ -274,6 +277,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, + **model_kwargs: dict[str, Any] | None, ) -> torch.Tensor: qkv = self.qkv_proj(hidden_states) q, k, v = torch.split(qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -286,7 +290,7 @@ def forward( ) # q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, positions) + attn_output = self.attn(q, k, v, positions, **model_kwargs) output = self.o_proj(attn_output) return output @@ -295,8 +299,7 @@ class Glm4MoeDecoderLayer(nn.Module): def __init__( self, config: Glm4MoeConfig, - cache_config: str = "bf16", - quant_config: QuantizationConfig | None = None, + atom_config: Config | None = None, prefix: str = "", layer_num: int = 0, enable_eplb: bool = False, @@ -320,8 +323,8 @@ def __init__( head_dim=config.head_dim, rms_norm_eps=config.rms_norm_eps, qkv_bias=config.attention_bias, - cache_config=cache_config, - quant_config=quant_config, + cache_config=atom_config.kv_cache_dtype, + atom_config=atom_config, prefix=f"{prefix}.self_attn", use_qk_norm=config.use_qk_norm, rope_theta=rope_theta, @@ -334,7 +337,7 @@ def __init__( ): self.mlp = Glm4MoE( config=config, - quant_config=quant_config, + quant_config=atom_config.quant_config, prefix=f"{prefix}.mlp", enable_eplb=enable_eplb, ) @@ -343,7 +346,7 @@ def __init__( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - quant_config=quant_config, + quant_config=atom_config.quant_config, prefix=f"{prefix}.mlp", ) @@ -358,13 +361,16 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: torch.Tensor | None, + **model_kwargs: dict[str, Any] | None, ) -> tuple[torch.Tensor, torch.Tensor]: if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm(hidden_states, residual) - hidden_states = self.self_attn(positions=positions, 
hidden_states=hidden_states) + hidden_states = self.self_attn( + positions=positions, hidden_states=hidden_states, **model_kwargs + ) hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @@ -373,6 +379,7 @@ def forward( @support_torch_compile( dynamic_arg_dims={ "input_ids": 0, + "positions": -1, "intermediate_tensors": 0, "inputs_embeds": 0, } @@ -382,8 +389,6 @@ def __init__(self, *, atom_config: Config, prefix: str = ""): super().__init__() config = atom_config.hf_config - cache_config = atom_config.kv_cache_dtype - quant_config = atom_config.quant_config self.config = config self.vocab_size = config.vocab_size @@ -401,8 +406,7 @@ def __init__(self, *, atom_config: Config, prefix: str = ""): config.num_hidden_layers, lambda prefix, layer_num=None: Glm4MoeDecoderLayer( config=config, - cache_config=cache_config, - quant_config=quant_config, + atom_config=atom_config, prefix=prefix, layer_num=layer_num, ), @@ -427,6 +431,7 @@ def forward( positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, + **model_kwargs: dict[str, Any], ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -440,7 +445,9 @@ def forward( residual = intermediate_tensors["residual"] for layer in islice(self.layers, self.start_layer, self.end_layer): - hidden_states, residual = layer(positions, hidden_states, residual) + hidden_states, residual = layer( + positions, hidden_states, residual, **model_kwargs + ) if not get_pp_group().is_last_rank: return IntermediateTensors( @@ -525,6 +532,7 @@ class Glm4MoeForCausalLM(nn.Module, Glm4MixtureOfExperts): def __init__(self, atom_config: Config, prefix: str = ""): super().__init__() + self.atom_config = atom_config config = atom_config.hf_config quant_config = atom_config.quant_config self.config = config @@ -577,9 +585,10 @@ def 
forward( positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, + **model_kwargs: dict[str, Any], ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( - input_ids, positions, intermediate_tensors, inputs_embeds + input_ids, positions, intermediate_tensors, inputs_embeds, **model_kwargs ) return hidden_states @@ -612,6 +621,17 @@ def get_spec_layer_idx_from_weight_name( return layer_idx + i return None + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # Load weights in plugin mode; the passed-in weights generator is discarded. + # The prefix is "model." because Glm4MoeForCausalLM is constructed inside the + # model wrapper class, so the names of the loaded weights are prefixed with + # "model.". vLLM checks the names of the loaded weights to make sure all + # weights are loaded correctly. + loaded_weights_record = load_model_in_plugin_mode( + model=self, config=self.atom_config, prefix="model." 
+ ) + return loaded_weights_record + def get_spec_layer_idx_from_weight_name( config: Glm4MoeConfig, weight_name: str diff --git a/atom/plugin/register.py b/atom/plugin/register.py index d76e2e86c..6a52cf54d 100644 --- a/atom/plugin/register.py +++ b/atom/plugin/register.py @@ -2,6 +2,7 @@ from atom.models.qwen3 import Qwen3ForCausalLM from atom.models.qwen3_moe import Qwen3MoeForCausalLM +from atom.models.glm4_moe import Glm4MoeForCausalLM from atom.config import Config from atom.plugin.prepare import is_vllm, is_sglang @@ -10,6 +11,7 @@ _ATOM_SUPPORTED_MODELS = { "Qwen3ForCausalLM": Qwen3ForCausalLM, "Qwen3MoeForCausalLM": Qwen3MoeForCausalLM, + "Glm4MoeForCausalLM": Glm4MoeForCausalLM, } diff --git a/atom/plugin/vllm/model_wrapper.py b/atom/plugin/vllm/model_wrapper.py index bcca17ad1..4aaf0036b 100644 --- a/atom/plugin/vllm/model_wrapper.py +++ b/atom/plugin/vllm/model_wrapper.py @@ -29,6 +29,7 @@ _ATOM_MODEL_CLASSES: dict[str, str] = { "Qwen3ForCausalLM": "atom.models.qwen3:Qwen3ForCausalLM", "Qwen3MoeForCausalLM": "atom.models.qwen3_moe:Qwen3MoeForCausalLM", + "Glm4MoeForCausalLM": "atom.models.glm4_moe:Glm4MoeForCausalLM", } diff --git a/atom/plugin/vllm/register.py b/atom/plugin/vllm/register.py index 0330929a5..2ed96e179 100644 --- a/atom/plugin/vllm/register.py +++ b/atom/plugin/vllm/register.py @@ -22,6 +22,7 @@ _VLLM_MODEL_REGISTRY_OVERRIDES: dict[str, str] = { "Qwen3ForCausalLM": ATOM_CAUSAL_LM_MODEL_WRAPPER, "Qwen3MoeForCausalLM": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER, + "Glm4MoeForCausalLM": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER, } From 7de55c05aa1cd0614209df3389f231a23ffe0477 Mon Sep 17 00:00:00 2001 From: kliuae Date: Thu, 5 Mar 2026 16:10:32 +0000 Subject: [PATCH 2/4] fp8 model Signed-off-by: kliuae --- atom/model_ops/moe.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 4cbb8fa00..efc7c0a34 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -545,6 +545,8 @@ def rocm_asm_moe_impl( 
quant_type_ in [QuantType.per_Token, QuantType.per_1x128] and hidden_states.dtype in [torch.float16, torch.bfloat16] and w1.dtype in [torch.int8, torch.uint8, torch.float8_e4m3fnuz] + and fc1_smooth_scale_fixed is not None + and fc2_smooth_scale_fixed is not None ) return asm_moe( From b4d0ebbc8ae6c262609e070c2d40ff026f2fc486 Mon Sep 17 00:00:00 2001 From: kliuae Date: Fri, 6 Mar 2026 07:51:19 +0000 Subject: [PATCH 3/4] fix large batch size accuracy Signed-off-by: kliuae --- atom/model_ops/base_attention.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/atom/model_ops/base_attention.py b/atom/model_ops/base_attention.py index 3660b3ebd..3ee77b823 100644 --- a/atom/model_ops/base_attention.py +++ b/atom/model_ops/base_attention.py @@ -118,8 +118,9 @@ def cp_mha_gather_cache_kernel( k_reg = tl.load(key_cache_ptr_offset + k_reg_offset) v_reg = tl.load(value_cache_ptr_offset + v_reg_offset) if DEQUANT: - k_scale = 1.0 - v_scale = 1.0 + scale_offset = block_id * num_heads * PAGE_SIZE + head_id * PAGE_SIZE + slot_id + k_scale = tl.load(k_scale_ptr + scale_offset) + v_scale = tl.load(v_scale_ptr + scale_offset) k_reg = k_reg.to(tl.float32) * k_scale v_reg = v_reg.to(tl.float32) * v_scale tl.store(key_ptr_offset + col_offsets, k_reg) From cc7528e8aa817656cec4003c094a31358233608d Mon Sep 17 00:00:00 2001 From: kliuae Date: Fri, 6 Mar 2026 09:51:13 +0000 Subject: [PATCH 4/4] format Signed-off-by: kliuae --- atom/model_ops/base_attention.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/atom/model_ops/base_attention.py b/atom/model_ops/base_attention.py index 3ee77b823..6b6eb0743 100644 --- a/atom/model_ops/base_attention.py +++ b/atom/model_ops/base_attention.py @@ -118,7 +118,9 @@ def cp_mha_gather_cache_kernel( k_reg = tl.load(key_cache_ptr_offset + k_reg_offset) v_reg = tl.load(value_cache_ptr_offset + v_reg_offset) if DEQUANT: - scale_offset = block_id * num_heads * PAGE_SIZE + head_id * PAGE_SIZE + slot_id + scale_offset = ( 
+ block_id * num_heads * PAGE_SIZE + head_id * PAGE_SIZE + slot_id + ) k_scale = tl.load(k_scale_ptr + scale_offset) v_scale = tl.load(v_scale_ptr + scale_offset) k_reg = k_reg.to(tl.float32) * k_scale