From 55bb8c143e52e5c2188cf6e17ec4354ef4f55b68 Mon Sep 17 00:00:00 2001 From: Li Date: Sun, 15 Feb 2026 01:59:56 -0800 Subject: [PATCH 1/5] Enable AllReduce+RMSNorm fusion for GPT-OSS model Apply the same AllReduce+RMSNorm fusion pattern already used by Qwen3-MoE and DeepSeek-V2 to the GPT-OSS model. When ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION=1 (default) and TP > 1, the MoE output AllReduce is deferred and fused into the next layers input_layernorm, reducing one AllReduce kernel launch per layer. Changes: Set FusedMoE reduce_results=False when fusion is enabled with a fallback manual AllReduce when disabled. Enable fused_allreduce on input_layernorm (layer > 0) and the final norm. post_attention_layernorm is unchanged because x_pad_to_multiple=256 is incompatible with the fused kernel. Mathematically equivalent to the original. Can be disabled via ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION=0. Co-authored-by: Cursor --- atom/models/gpt_oss.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/atom/models/gpt_oss.py b/atom/models/gpt_oss.py index eee08a710..313f916aa 100644 --- a/atom/models/gpt_oss.py +++ b/atom/models/gpt_oss.py @@ -22,7 +22,10 @@ import torch import torch.distributed as dist from aiter import ActivationType -from aiter.dist.communication_op import tensor_model_parallel_all_gather +from aiter.dist.communication_op import ( + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) from aiter.dist.parallel_state import get_pp_group, get_tensor_model_parallel_world_size # from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -43,10 +46,13 @@ make_layers, maybe_prefix, ) +from atom.utils import envs from atom.utils.decorators import support_torch_compile from torch import nn from transformers import GptOssConfig +ENABLE_ALLREDUCE_RMSNORM_FUSION = envs.ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION + def cdiv(x, y): return (x + y - 1) // y @@ -163,6 +169,7 @@ def __init__( self.hidden_size = config.hidden_size self.experts_per_token = config.num_experts_per_tok self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + self.tp_size = get_tensor_model_parallel_world_size() self.router = ReplicatedLinear( config.hidden_size, config.num_local_experts, @@ -176,7 +183,7 @@ def __init__( top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, - reduce_results=True, + reduce_results=not ENABLE_ALLREDUCE_RMSNORM_FUSION, renormalize=True, quant_config=quant_config, prefix=f"{prefix}.experts", @@ -192,6 +199,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: g = self.router(x[..., : self.hidden_size]) x = self.experts(hidden_states=x, router_logits=g) + if self.tp_size > 1 and not ENABLE_ALLREDUCE_RMSNORM_FUSION: + x = tensor_model_parallel_all_reduce(x) + if self.is_sequence_parallel: x = tensor_model_parallel_all_gather(x.contiguous(), 0) x = x[:num_tokens] @@ -221,7 +231,15 @@ def __init__( layer_num=layer_num, ) self.mlp = MLPBlock(atom_config, self.layer_idx, prefix=f"{prefix}.mlp") - self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5) + # Fuse MoE AllReduce into input_layernorm for layers > 0. + # Layer 0 receives already-reduced embedding output, so no fusion needed. + self.input_layernorm = RMSNorm( + config.hidden_size, + eps=1e-5, + fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION and layer_num > 0, + ) + # post_attention_layernorm cannot use fused_allreduce because + # x_pad_to_multiple is incompatible with the fused kernel. self.post_attention_layernorm = RMSNorm( config.hidden_size, eps=1e-5, x_pad_to_multiple=256 ) @@ -273,7 +291,11 @@ def __init__( ), prefix=f"{prefix}.layers", ) - self.norm = RMSNorm(self.config.hidden_size, eps=1e-5) + self.norm = RMSNorm( + self.config.hidden_size, + eps=1e-5, + fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION, + ) self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], self.config.hidden_size ) From fd6d67c8c371064dd4cc817a3768253fd8531181 Mon Sep 17 00:00:00 2001 From: Li Date: Sun, 15 Feb 2026 02:16:48 -0800 Subject: [PATCH 2/5] Move MXFP4 padding from layernorm to MLPBlock, enable full AllReduce+RMSNorm fusion Previously, post_attention_layernorm used x_pad_to_multiple=256 for MXFP4 GEMM alignment, which prevented fused AllReduce on the attention output path. This moves the padding logic into MLPBlock.forward(), allowing both AllReduces per layer to be fused with their downstream RMSNorm: - o_proj AllReduce fused into post_attention_layernorm (all layers) - MoE AllReduce fused into next layer input_layernorm (layers > 0) This eliminates all standalone AllReduce kernel launches (70 per forward pass for 36 layers), reducing per-layer latency by ~60us. Co-authored-by: Cursor --- atom/models/gpt_oss.py | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/atom/models/gpt_oss.py b/atom/models/gpt_oss.py index 313f916aa..f4b006198 100644 --- a/atom/models/gpt_oss.py +++ b/atom/models/gpt_oss.py @@ -21,6 +21,7 @@ import torch import torch.distributed as dist +import torch.nn.functional as F from aiter import ActivationType from aiter.dist.communication_op import ( tensor_model_parallel_all_gather, @@ -120,6 +121,7 @@ def __init__( quant_config=None, prefix=f"{prefix}.o_proj", bias=True, + reduce_results=not ENABLE_ALLREDUCE_RMSNORM_FUSION, ) self.num_local_attention_heads = config.num_attention_heads // tp_size @@ -192,13 +194,30 @@ def __init__( activation=ActivationType.Swiglu, config=config, ) + # Detect MXFP4 MoE GEMM padding requirement from the quant method. + # When hidden_size is not aligned to 256, MXFP4 weights are padded + # and the kernel expects padded input. We handle padding here instead + # of in the layernorm, so the layernorm can use fused AllReduce. + if hasattr(self.experts.quant_method, "hidden_pad"): + self.moe_hidden_pad = self.experts.quant_method.hidden_pad + else: + self.moe_hidden_pad = 0 def forward(self, x: torch.Tensor) -> torch.Tensor: num_tokens = x.shape[0] - g = self.router(x[..., : self.hidden_size]) + g = self.router(x) + + # Pad input for MXFP4 MoE GEMM alignment if needed + if self.moe_hidden_pad > 0: + x = F.pad(x, (0, self.moe_hidden_pad)) + x = self.experts(hidden_states=x, router_logits=g) + # Remove padding from output + if self.moe_hidden_pad > 0: + x = x[:, : self.hidden_size] + if self.tp_size > 1 and not ENABLE_ALLREDUCE_RMSNORM_FUSION: x = tensor_model_parallel_all_reduce(x) @@ -238,10 +257,13 @@ def __init__( eps=1e-5, fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION and layer_num > 0, ) - # post_attention_layernorm cannot use fused_allreduce because - # x_pad_to_multiple is incompatible with the fused kernel. + # Fuse o_proj AllReduce into post_attention_layernorm. + # Padding for MXFP4 MoE GEMM alignment is now handled inside MLPBlock, + # so this layernorm no longer needs x_pad_to_multiple. self.post_attention_layernorm = RMSNorm( - config.hidden_size, eps=1e-5, x_pad_to_multiple=256 + config.hidden_size, + eps=1e-5, + fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION, ) def forward( @@ -260,7 +282,7 @@ def forward( # Fully Connected hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) - output = self.mlp(hidden_states)[:, : self.hidden_size] + output = self.mlp(hidden_states) return output, residual From 0595e0b3e4d2f658db10420b61d32e11da390a65 Mon Sep 17 00:00:00 2001 From: Li Date: Sun, 15 Feb 2026 02:43:34 -0800 Subject: [PATCH 3/5] Include compute_logits in HIP Graph for decode Previously, compute_logits (lm_head GEMM + TP AllGather) ran eagerly after every graph replay, adding Python overhead and kernel launch gaps. Now it is captured in the HIP Graph alongside the model forward, eliminating this overhead during decode. For prefill and eager fallback, compute_logits continues to run eagerly. The graph_logits dict stores references to graph-pool tensors that are updated in-place during replay, avoiding any extra allocation. Co-authored-by: Cursor --- atom/model_engine/model_runner.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index 24b56879a..5c3c376d2 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -1367,6 +1367,7 @@ def run_model(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor positions = context.positions if is_prefill or self.enforce_eager or bs > self.graph_bs[-1]: hidden_states = self.model(input_ids, positions) + logits = self.model.compute_logits(hidden_states) else: graph_bs = context.graph_bs max_q_len = forward_context.attn_metadata.max_seqlen_q @@ -1374,7 +1375,7 @@ def run_model(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor self.graphs[graph_key].replay() num_tokens = context.batch_size * max_q_len hidden_states = self.forward_vars["outputs"][:num_tokens] - logits = self.model.compute_logits(hidden_states) + logits = self.graph_logits[graph_key][:num_tokens] return logits, hidden_states @@ -1588,6 +1589,7 @@ def capture_cudagraph(self): outputs = self.forward_vars["outputs"] self.graphs: dict[tuple[int, int], torch.cuda.CUDAGraph] = dict() + self.graph_logits: dict[tuple[int, int], torch.Tensor] = dict() self.graph_pool = None with graph_capture() as gc: @@ -1625,17 +1627,23 @@ def capture_cudagraph(self): num_tokens_across_dp=num_tokens_across_dp, ) + # Warmup: run model forward + compute_logits outputs[:num_tokens] = self.model( input_ids[:num_tokens], positions[:num_tokens] - ) # warmup + ) + self.model.compute_logits(outputs[:num_tokens]) + # Capture: include compute_logits in the graph to avoid + # eager execution overhead during decode replay. with torch.cuda.graph(graph, self.graph_pool, stream=gc.stream): outputs[:num_tokens] = self.model( input_ids[:num_tokens], positions[:num_tokens] - ) # capture + ) + graph_logits = self.model.compute_logits(outputs[:num_tokens]) if self.graph_pool is None: self.graph_pool = graph.pool() self.graphs[(bs, max_q_len)] = graph + self.graph_logits[(bs, max_q_len)] = graph_logits torch.cuda.synchronize() self.graph_bs.sort(reverse=False) return time.time() - start_time, self.graph_bs From 311d2acd763688ef5a92cea32bd2757a864d9fff Mon Sep 17 00:00:00 2001 From: Li Date: Sun, 15 Feb 2026 16:18:41 -0800 Subject: [PATCH 4/5] Guard compute_logits in HIP Graph for TP=1 only Co-authored-by: Cursor --- atom/model_engine/model_runner.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index 5c3c376d2..7f3e366cb 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -1375,7 +1375,10 @@ def run_model(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor self.graphs[graph_key].replay() num_tokens = context.batch_size * max_q_len hidden_states = self.forward_vars["outputs"][:num_tokens] - logits = self.graph_logits[graph_key][:num_tokens] + if self.logits_in_graph: + logits = self.graph_logits[graph_key][:num_tokens] + else: + logits = self.model.compute_logits(hidden_states) return logits, hidden_states @@ -1591,6 +1594,7 @@ def capture_cudagraph(self): self.graphs: dict[tuple[int, int], torch.cuda.CUDAGraph] = dict() self.graph_logits: dict[tuple[int, int], torch.Tensor] = dict() self.graph_pool = None + self.logits_in_graph = self.world_size == 1 with graph_capture() as gc: capture_range = ( @@ -1627,23 +1631,29 @@ def capture_cudagraph(self): num_tokens_across_dp=num_tokens_across_dp, ) - # Warmup: run model forward + compute_logits + # Warmup outputs[:num_tokens] = self.model( input_ids[:num_tokens], positions[:num_tokens] ) - self.model.compute_logits(outputs[:num_tokens]) + if self.logits_in_graph: + self.model.compute_logits(outputs[:num_tokens]) - # Capture: include compute_logits in the graph to avoid - # eager execution overhead during decode replay. + # Capture: include compute_logits only when TP=1 since + # ParallelLMHead uses NCCL all_gather which is not + # graph-capturable on HIP when TP > 1. with torch.cuda.graph(graph, self.graph_pool, stream=gc.stream): outputs[:num_tokens] = self.model( input_ids[:num_tokens], positions[:num_tokens] ) - graph_logits = self.model.compute_logits(outputs[:num_tokens]) + if self.logits_in_graph: + graph_logits = self.model.compute_logits( + outputs[:num_tokens] + ) if self.graph_pool is None: self.graph_pool = graph.pool() self.graphs[(bs, max_q_len)] = graph - self.graph_logits[(bs, max_q_len)] = graph_logits + if self.logits_in_graph: + self.graph_logits[(bs, max_q_len)] = graph_logits torch.cuda.synchronize() self.graph_bs.sort(reverse=False) return time.time() - start_time, self.graph_bs From 4b432098dedfcab4b1275e0e4d97b8c224f5b0be Mon Sep 17 00:00:00 2001 From: Li Date: Sun, 15 Feb 2026 16:38:48 -0800 Subject: [PATCH 5/5] Fix Black formatting for compute_logits guard Co-authored-by: Cursor --- atom/model_engine/model_runner.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index 7f3e366cb..aceb9803e 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -1646,9 +1646,7 @@ def capture_cudagraph(self): input_ids[:num_tokens], positions[:num_tokens] ) if self.logits_in_graph: - graph_logits = self.model.compute_logits( - outputs[:num_tokens] - ) + graph_logits = self.model.compute_logits(outputs[:num_tokens]) if self.graph_pool is None: self.graph_pool = graph.pool() self.graphs[(bs, max_q_len)] = graph