From aa11defa7e3d2aa5b8dd3059a699a7160503e464 Mon Sep 17 00:00:00 2001 From: Li Date: Sun, 15 Feb 2026 01:59:56 -0800 Subject: [PATCH 1/6] Enable AllReduce+RMSNorm fusion for GPT-OSS model Apply the same AllReduce+RMSNorm fusion pattern already used by Qwen3-MoE and DeepSeek-V2 to the GPT-OSS model. When ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION=1 (default) and TP > 1, the MoE output AllReduce is deferred and fused into the next layers input_layernorm, reducing one AllReduce kernel launch per layer. Changes: Set FusedMoE reduce_results=False when fusion is enabled with a fallback manual AllReduce when disabled. Enable fused_allreduce on input_layernorm (layer > 0) and the final norm. post_attention_layernorm is unchanged because x_pad_to_multiple=256 is incompatible with the fused kernel. Mathematically equivalent to the original. Can be disabled via ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION=0. Co-authored-by: Cursor --- atom/models/gpt_oss.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/atom/models/gpt_oss.py b/atom/models/gpt_oss.py index eee08a710..313f916aa 100644 --- a/atom/models/gpt_oss.py +++ b/atom/models/gpt_oss.py @@ -22,7 +22,10 @@ import torch import torch.distributed as dist from aiter import ActivationType -from aiter.dist.communication_op import tensor_model_parallel_all_gather +from aiter.dist.communication_op import ( + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) from aiter.dist.parallel_state import get_pp_group, get_tensor_model_parallel_world_size # from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -43,10 +46,13 @@ make_layers, maybe_prefix, ) +from atom.utils import envs from atom.utils.decorators import support_torch_compile from torch import nn from transformers import GptOssConfig +ENABLE_ALLREDUCE_RMSNORM_FUSION = envs.ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION + def cdiv(x, y): return (x + y - 1) // y @@ -163,6 +169,7 @@ def __init__( self.hidden_size = config.hidden_size self.experts_per_token = config.num_experts_per_tok self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + self.tp_size = get_tensor_model_parallel_world_size() self.router = ReplicatedLinear( config.hidden_size, config.num_local_experts, @@ -176,7 +183,7 @@ def __init__( top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, - reduce_results=True, + reduce_results=not ENABLE_ALLREDUCE_RMSNORM_FUSION, renormalize=True, quant_config=quant_config, prefix=f"{prefix}.experts", @@ -192,6 +199,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: g = self.router(x[..., : self.hidden_size]) x = self.experts(hidden_states=x, router_logits=g) + if self.tp_size > 1 and not ENABLE_ALLREDUCE_RMSNORM_FUSION: + x = tensor_model_parallel_all_reduce(x) + if self.is_sequence_parallel: x = tensor_model_parallel_all_gather(x.contiguous(), 0) x = x[:num_tokens] @@ -221,7 +231,15 @@ def __init__( layer_num=layer_num, ) self.mlp = MLPBlock(atom_config, self.layer_idx, prefix=f"{prefix}.mlp") - self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5) + # Fuse MoE AllReduce into input_layernorm for layers > 0. + # Layer 0 receives already-reduced embedding output, so no fusion needed. + self.input_layernorm = RMSNorm( + config.hidden_size, + eps=1e-5, + fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION and layer_num > 0, + ) + # post_attention_layernorm cannot use fused_allreduce because + # x_pad_to_multiple is incompatible with the fused kernel. self.post_attention_layernorm = RMSNorm( config.hidden_size, eps=1e-5, x_pad_to_multiple=256 ) @@ -273,7 +291,11 @@ def __init__( ), prefix=f"{prefix}.layers", ) - self.norm = RMSNorm(self.config.hidden_size, eps=1e-5) + self.norm = RMSNorm( + self.config.hidden_size, + eps=1e-5, + fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION, + ) self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], self.config.hidden_size ) From 9ba5eadfeea3ece9e20bb64fe10e8c01effa61b6 Mon Sep 17 00:00:00 2001 From: Li Date: Sun, 15 Feb 2026 02:16:48 -0800 Subject: [PATCH 2/6] Move MXFP4 padding from layernorm to MLPBlock, enable full AllReduce+RMSNorm fusion Previously, post_attention_layernorm used x_pad_to_multiple=256 for MXFP4 GEMM alignment, which prevented fused AllReduce on the attention output path. This moves the padding logic into MLPBlock.forward(), allowing both AllReduces per layer to be fused with their downstream RMSNorm: - o_proj AllReduce fused into post_attention_layernorm (all layers) - MoE AllReduce fused into next layer input_layernorm (layers > 0) This eliminates all standalone AllReduce kernel launches (70 per forward pass for 36 layers), reducing per-layer latency by ~60us. Co-authored-by: Cursor --- atom/models/gpt_oss.py | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/atom/models/gpt_oss.py b/atom/models/gpt_oss.py index 313f916aa..f4b006198 100644 --- a/atom/models/gpt_oss.py +++ b/atom/models/gpt_oss.py @@ -21,6 +21,7 @@ import torch import torch.distributed as dist +import torch.nn.functional as F from aiter import ActivationType from aiter.dist.communication_op import ( tensor_model_parallel_all_gather, @@ -120,6 +121,7 @@ def __init__( quant_config=None, prefix=f"{prefix}.o_proj", bias=True, + reduce_results=not ENABLE_ALLREDUCE_RMSNORM_FUSION, ) self.num_local_attention_heads = config.num_attention_heads // tp_size @@ -192,13 +194,30 @@ def __init__( activation=ActivationType.Swiglu, config=config, ) + # Detect MXFP4 MoE GEMM padding requirement from the quant method. + # When hidden_size is not aligned to 256, MXFP4 weights are padded + # and the kernel expects padded input. We handle padding here instead + # of in the layernorm, so the layernorm can use fused AllReduce. + if hasattr(self.experts.quant_method, "hidden_pad"): + self.moe_hidden_pad = self.experts.quant_method.hidden_pad + else: + self.moe_hidden_pad = 0 def forward(self, x: torch.Tensor) -> torch.Tensor: num_tokens = x.shape[0] - g = self.router(x[..., : self.hidden_size]) + g = self.router(x) + + # Pad input for MXFP4 MoE GEMM alignment if needed + if self.moe_hidden_pad > 0: + x = F.pad(x, (0, self.moe_hidden_pad)) + x = self.experts(hidden_states=x, router_logits=g) + # Remove padding from output + if self.moe_hidden_pad > 0: + x = x[:, : self.hidden_size] + if self.tp_size > 1 and not ENABLE_ALLREDUCE_RMSNORM_FUSION: x = tensor_model_parallel_all_reduce(x) @@ -238,10 +257,13 @@ def __init__( eps=1e-5, fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION and layer_num > 0, ) - # post_attention_layernorm cannot use fused_allreduce because - # x_pad_to_multiple is incompatible with the fused kernel. + # Fuse o_proj AllReduce into post_attention_layernorm. + # Padding for MXFP4 MoE GEMM alignment is now handled inside MLPBlock, + # so this layernorm no longer needs x_pad_to_multiple. self.post_attention_layernorm = RMSNorm( - config.hidden_size, eps=1e-5, x_pad_to_multiple=256 + config.hidden_size, + eps=1e-5, + fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION, ) def forward( @@ -260,7 +282,7 @@ def forward( # Fully Connected hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) - output = self.mlp(hidden_states)[:, : self.hidden_size] + output = self.mlp(hidden_states) return output, residual From 2ffab672ed93353958509f94c6843108ff34230f Mon Sep 17 00:00:00 2001 From: Li Date: Sun, 15 Feb 2026 02:43:34 -0800 Subject: [PATCH 3/6] Include compute_logits in HIP Graph for decode Previously, compute_logits (lm_head GEMM + TP AllGather) ran eagerly after every graph replay, adding Python overhead and kernel launch gaps. Now it is captured in the HIP Graph alongside the model forward, eliminating this overhead during decode. For prefill and eager fallback, compute_logits continues to run eagerly. The graph_logits dict stores references to graph-pool tensors that are updated in-place during replay, avoiding any extra allocation. Co-authored-by: Cursor --- atom/model_engine/model_runner.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index 4be2c9a4e..62941a13e 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -1317,6 +1317,7 @@ def run_model(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor positions = context.positions if is_prefill or self.enforce_eager or bs > self.graph_bs[-1]: hidden_states = self.model(input_ids, positions) + logits = self.model.compute_logits(hidden_states) else: graph_bs = context.graph_bs max_q_len = forward_context.attn_metadata.max_seqlen_q @@ -1324,7 +1325,7 @@ def run_model(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor self.graphs[graph_key].replay() num_tokens = context.batch_size * max_q_len hidden_states = self.forward_vars["outputs"][:num_tokens] - logits = self.model.compute_logits(hidden_states) + logits = self.graph_logits[graph_key][:num_tokens] return logits, hidden_states @@ -1484,6 +1485,7 @@ def capture_cudagraph(self): self.forward_vars["kv_indptr"].gpu.zero_() self.graphs: dict[tuple[int, int], torch.cuda.CUDAGraph] = dict() + self.graph_logits: dict[tuple[int, int], torch.Tensor] = dict() self.graph_pool = None with graph_capture() as gc: @@ -1521,17 +1523,23 @@ def capture_cudagraph(self): num_tokens_across_dp=num_tokens_across_dp, ) + # Warmup: run model forward + compute_logits outputs[:num_tokens] = self.model( input_ids[:num_tokens], positions[:num_tokens] - ) # warmup + ) + self.model.compute_logits(outputs[:num_tokens]) + # Capture: include compute_logits in the graph to avoid + # eager execution overhead during decode replay. with torch.cuda.graph(graph, self.graph_pool, stream=gc.stream): outputs[:num_tokens] = self.model( input_ids[:num_tokens], positions[:num_tokens] - ) # capture + ) + graph_logits = self.model.compute_logits(outputs[:num_tokens]) if self.graph_pool is None: self.graph_pool = graph.pool() self.graphs[(bs, max_q_len)] = graph + self.graph_logits[(bs, max_q_len)] = graph_logits torch.cuda.synchronize() self.graph_bs.sort(reverse=False) return time.time() - start_time, self.graph_bs From 0cad4aabeb466601b72867bcc9a4c5eed670d17a Mon Sep 17 00:00:00 2001 From: Li Date: Sun, 15 Feb 2026 16:18:41 -0800 Subject: [PATCH 4/6] Guard compute_logits in HIP Graph for TP=1 only Co-authored-by: Cursor --- atom/model_engine/model_runner.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index 62941a13e..5fcc6bd46 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -1325,7 +1325,10 @@ def run_model(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor self.graphs[graph_key].replay() num_tokens = context.batch_size * max_q_len hidden_states = self.forward_vars["outputs"][:num_tokens] - logits = self.graph_logits[graph_key][:num_tokens] + if self.logits_in_graph: + logits = self.graph_logits[graph_key][:num_tokens] + else: + logits = self.model.compute_logits(hidden_states) return logits, hidden_states @@ -1487,6 +1490,7 @@ def capture_cudagraph(self): self.graphs: dict[tuple[int, int], torch.cuda.CUDAGraph] = dict() self.graph_logits: dict[tuple[int, int], torch.Tensor] = dict() self.graph_pool = None + self.logits_in_graph = self.world_size == 1 with graph_capture() as gc: capture_range = ( @@ -1523,23 +1527,29 @@ def capture_cudagraph(self): num_tokens_across_dp=num_tokens_across_dp, ) - # Warmup: run model forward + compute_logits + # Warmup outputs[:num_tokens] = self.model( input_ids[:num_tokens], positions[:num_tokens] ) - self.model.compute_logits(outputs[:num_tokens]) + if self.logits_in_graph: + self.model.compute_logits(outputs[:num_tokens]) - # Capture: include compute_logits in the graph to avoid - # eager execution overhead during decode replay. + # Capture: include compute_logits only when TP=1 since + # ParallelLMHead uses NCCL all_gather which is not + # graph-capturable on HIP when TP > 1. with torch.cuda.graph(graph, self.graph_pool, stream=gc.stream): outputs[:num_tokens] = self.model( input_ids[:num_tokens], positions[:num_tokens] ) - graph_logits = self.model.compute_logits(outputs[:num_tokens]) + if self.logits_in_graph: + graph_logits = self.model.compute_logits( + outputs[:num_tokens] + ) if self.graph_pool is None: self.graph_pool = graph.pool() self.graphs[(bs, max_q_len)] = graph - self.graph_logits[(bs, max_q_len)] = graph_logits + if self.logits_in_graph: + self.graph_logits[(bs, max_q_len)] = graph_logits torch.cuda.synchronize() self.graph_bs.sort(reverse=False) return time.time() - start_time, self.graph_bs From adab0172ebe8354777af743fecedd8d988c0ccd3 Mon Sep 17 00:00:00 2001 From: Li Date: Sun, 15 Feb 2026 16:38:48 -0800 Subject: [PATCH 5/6] Fix Black formatting for compute_logits guard Co-authored-by: Cursor --- atom/model_engine/model_runner.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index 5fcc6bd46..508c6b4f6 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -1542,9 +1542,7 @@ def capture_cudagraph(self): input_ids[:num_tokens], positions[:num_tokens] ) if self.logits_in_graph: - graph_logits = self.model.compute_logits( - outputs[:num_tokens] - ) + graph_logits = self.model.compute_logits(outputs[:num_tokens]) if self.graph_pool is None: self.graph_pool = graph.pool() self.graphs[(bs, max_q_len)] = graph From 120af642ed1cbda171859684c0ad20db7ca86129 Mon Sep 17 00:00:00 2001 From: Li Date: Mon, 16 Feb 2026 00:43:24 -0800 Subject: [PATCH 6/6] Enable Triton MXFP4 MoE on gfx950 when ATOM_USE_TRITON_GEMM=1 Extend the Triton MoE kernel path (matmul_ogs + routing from triton_kernels) to gfx950 (MI355X) when ATOM_USE_TRITON_GEMM is enabled. The triton_kernels package already supports gfx950 via GFX950MXScaleLayout. This allows GPT-OSS MXFP4 models on MI355X to use the optimized Triton MoE path with fused routing, Swiglu activation, and matmul_ogs GEMM. The change is opt-in: without ATOM_USE_TRITON_GEMM=1, gfx950 continues to use the CK/ASM path. Co-authored-by: Cursor --- atom/model_ops/moe.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 4cbb8fa00..b6820cdda 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -43,6 +43,7 @@ per_tensor_dequantize, shuffle_weights, ) +from atom.utils import envs from atom.utils.custom_register import direct_register_custom_op from atom.utils.forward_context import get_forward_context from torch import nn @@ -639,7 +640,10 @@ def __init__(self, quant_config: QuantizationConfig, moe: FusedMoEConfig): self.quant_type == QuantType.per_1x128 or self.quant_type == QuantType.per_1x32 ) - self.use_triton = get_gfx().startswith("gfx94") + gfx = get_gfx() + self.use_triton = gfx.startswith("gfx94") or ( + gfx.startswith("gfx95") and envs.ATOM_USE_TRITON_GEMM + ) if self.use_triton: from atom.model_ops.utils import has_triton_kernels