diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py
index 24b56879a..aceb9803e 100644
--- a/atom/model_engine/model_runner.py
+++ b/atom/model_engine/model_runner.py
@@ -1367,6 +1367,7 @@ def run_model(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor
         positions = context.positions
         if is_prefill or self.enforce_eager or bs > self.graph_bs[-1]:
             hidden_states = self.model(input_ids, positions)
+            logits = self.model.compute_logits(hidden_states)
         else:
             graph_bs = context.graph_bs
             max_q_len = forward_context.attn_metadata.max_seqlen_q
@@ -1374,7 +1375,10 @@ def run_model(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor
             self.graphs[graph_key].replay()
             num_tokens = context.batch_size * max_q_len
             hidden_states = self.forward_vars["outputs"][:num_tokens]
-        logits = self.model.compute_logits(hidden_states)
+            if self.logits_in_graph:
+                logits = self.graph_logits[graph_key][:num_tokens]
+            else:
+                logits = self.model.compute_logits(hidden_states)
         return logits, hidden_states
@@ -1588,7 +1592,9 @@ def capture_cudagraph(self):
         outputs = self.forward_vars["outputs"]
 
         self.graphs: dict[tuple[int, int], torch.cuda.CUDAGraph] = dict()
+        self.graph_logits: dict[tuple[int, int], torch.Tensor] = dict()
         self.graph_pool = None
+        self.logits_in_graph = self.world_size == 1
 
         with graph_capture() as gc:
             capture_range = (
@@ -1625,17 +1631,27 @@ def capture_cudagraph(self):
                     num_tokens_across_dp=num_tokens_across_dp,
                 )
+                # Warmup
                 outputs[:num_tokens] = self.model(
                     input_ids[:num_tokens], positions[:num_tokens]
-                )  # warmup
+                )
+                if self.logits_in_graph:
+                    self.model.compute_logits(outputs[:num_tokens])
+                # Capture: include compute_logits only when TP=1 since
+                # ParallelLMHead uses NCCL all_gather which is not
+                # graph-capturable on HIP when TP > 1.
                 with torch.cuda.graph(graph, self.graph_pool, stream=gc.stream):
                     outputs[:num_tokens] = self.model(
                         input_ids[:num_tokens], positions[:num_tokens]
-                    )  # capture
+                    )
+                    if self.logits_in_graph:
+                        graph_logits = self.model.compute_logits(outputs[:num_tokens])
                 if self.graph_pool is None:
                     self.graph_pool = graph.pool()
                 self.graphs[(bs, max_q_len)] = graph
+                if self.logits_in_graph:
+                    self.graph_logits[(bs, max_q_len)] = graph_logits
                 torch.cuda.synchronize()
 
         self.graph_bs.sort(reverse=False)
         return time.time() - start_time, self.graph_bs
diff --git a/atom/models/gpt_oss.py b/atom/models/gpt_oss.py
index eee08a710..f4b006198 100644
--- a/atom/models/gpt_oss.py
+++ b/atom/models/gpt_oss.py
@@ -21,8 +21,12 @@
 import torch
 import torch.distributed as dist
+import torch.nn.functional as F
 from aiter import ActivationType
-from aiter.dist.communication_op import tensor_model_parallel_all_gather
+from aiter.dist.communication_op import (
+    tensor_model_parallel_all_gather,
+    tensor_model_parallel_all_reduce,
+)
 from aiter.dist.parallel_state import get_pp_group, get_tensor_model_parallel_world_size
 
 # from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -43,10 +47,13 @@
     make_layers,
     maybe_prefix,
 )
+from atom.utils import envs
 from atom.utils.decorators import support_torch_compile
 from torch import nn
 from transformers import GptOssConfig
 
+ENABLE_ALLREDUCE_RMSNORM_FUSION = envs.ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION
+
 
 def cdiv(x, y):
     return (x + y - 1) // y
@@ -114,6 +121,7 @@ def __init__(
             quant_config=None,
             prefix=f"{prefix}.o_proj",
             bias=True,
+            reduce_results=not ENABLE_ALLREDUCE_RMSNORM_FUSION,
         )
 
         self.num_local_attention_heads = config.num_attention_heads // tp_size
@@ -163,6 +171,7 @@ def __init__(
         self.hidden_size = config.hidden_size
         self.experts_per_token = config.num_experts_per_tok
         self.world_size = dist.get_world_size() if dist.is_initialized() else 1
+        self.tp_size = get_tensor_model_parallel_world_size()
         self.router = ReplicatedLinear(
             config.hidden_size,
             config.num_local_experts,
@@ -176,7 +185,7 @@
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
-            reduce_results=True,
+            reduce_results=not ENABLE_ALLREDUCE_RMSNORM_FUSION,
             renormalize=True,
             quant_config=quant_config,
             prefix=f"{prefix}.experts",
@@ -185,13 +194,33 @@
             activation=ActivationType.Swiglu,
             config=config,
         )
+        # Detect MXFP4 MoE GEMM padding requirement from the quant method.
+        # When hidden_size is not aligned to 256, MXFP4 weights are padded
+        # and the kernel expects padded input. We handle padding here instead
+        # of in the layernorm, so the layernorm can use fused AllReduce.
+        if hasattr(self.experts.quant_method, "hidden_pad"):
+            self.moe_hidden_pad = self.experts.quant_method.hidden_pad
+        else:
+            self.moe_hidden_pad = 0
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         num_tokens = x.shape[0]
-        g = self.router(x[..., : self.hidden_size])
+        g = self.router(x)
+
+        # Pad input for MXFP4 MoE GEMM alignment if needed
+        if self.moe_hidden_pad > 0:
+            x = F.pad(x, (0, self.moe_hidden_pad))
+
         x = self.experts(hidden_states=x, router_logits=g)
+        # Remove padding from output
+        if self.moe_hidden_pad > 0:
+            x = x[:, : self.hidden_size]
+
+        if self.tp_size > 1 and not ENABLE_ALLREDUCE_RMSNORM_FUSION:
+            x = tensor_model_parallel_all_reduce(x)
+
         if self.is_sequence_parallel:
             x = tensor_model_parallel_all_gather(x.contiguous(), 0)
             x = x[:num_tokens]
@@ -221,9 +250,20 @@ def __init__(
             layer_num=layer_num,
         )
         self.mlp = MLPBlock(atom_config, self.layer_idx, prefix=f"{prefix}.mlp")
-        self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
+        # Fuse MoE AllReduce into input_layernorm for layers > 0.
+        # Layer 0 receives already-reduced embedding output, so no fusion needed.
+        self.input_layernorm = RMSNorm(
+            config.hidden_size,
+            eps=1e-5,
+            fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION and layer_num > 0,
+        )
+        # Fuse o_proj AllReduce into post_attention_layernorm.
+        # Padding for MXFP4 MoE GEMM alignment is now handled inside MLPBlock,
+        # so this layernorm no longer needs x_pad_to_multiple.
         self.post_attention_layernorm = RMSNorm(
-            config.hidden_size, eps=1e-5, x_pad_to_multiple=256
+            config.hidden_size,
+            eps=1e-5,
+            fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION,
         )
@@ -242,7 +282,7 @@ def forward(
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
-        output = self.mlp(hidden_states)[:, : self.hidden_size]
+        output = self.mlp(hidden_states)
         return output, residual
@@ -273,7 +313,11 @@ def __init__(
             ),
             prefix=f"{prefix}.layers",
         )
-        self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)
+        self.norm = RMSNorm(
+            self.config.hidden_size,
+            eps=1e-5,
+            fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION,
+        )
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
             ["hidden_states", "residual"], self.config.hidden_size
         )
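
Note on the model_runner.py changes (not part of the patch): with TP == 1 the logits projection is recorded in the same CUDA graph as the model forward, so a replay fills a per-graph logits buffer that run_model slices instead of launching compute_logits separately. Below is a minimal, self-contained sketch of that capture/replay pattern; TinyLM, its layer sizes, and the buffer names are illustrative only, a CUDA/ROCm device is assumed, and the atom-specific machinery (graph keys, forward_vars, batch-size bucketing) is omitted.

import torch

class TinyLM(torch.nn.Module):
    def __init__(self, hidden=64, vocab=128):
        super().__init__()
        self.backbone = torch.nn.Linear(hidden, hidden)
        self.lm_head = torch.nn.Linear(hidden, vocab)

    def forward(self, x):
        return torch.relu(self.backbone(x))

    def compute_logits(self, hidden_states):
        return self.lm_head(hidden_states)

@torch.inference_mode()
def main():
    model, bs, hidden = TinyLM().cuda().eval(), 8, 64
    static_in = torch.zeros(bs, hidden, device="cuda")  # static input buffer

    # Warm up on a side stream before capture, as the CUDA graph docs recommend.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            model.compute_logits(model(static_in))
    torch.cuda.current_stream().wait_stream(s)

    # Capture forward + logits projection in one graph and keep the static
    # logits tensor, analogous to self.graph_logits[graph_key] in the patch.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_hidden = model(static_in)
        static_logits = model.compute_logits(static_hidden)

    x = torch.randn(bs, hidden, device="cuda")
    static_in.copy_(x)  # write fresh inputs into the static buffer
    graph.replay()      # one replay produces both hidden states and logits
    ref = model.compute_logits(model(x))
    assert torch.allclose(static_logits, ref, atol=1e-5)

if __name__ == "__main__":
    main()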
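
Note on the gpt_oss.py changes (not part of the patch): when ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION is set, o_proj and the MoE experts return per-rank partial sums (reduce_results=False) and the tensor-parallel all-reduce is expected to happen inside the following RMSNorm. The sketch below only pins down the math such a fused norm must preserve, simulating the reduction with an explicit sum over two fake ranks; the rmsnorm helper is illustrative and the aiter fused kernel is not invoked.

import torch

def rmsnorm(x, weight, eps=1e-5):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

torch.manual_seed(0)
hidden = 16
weight = torch.ones(hidden)
# Partial outputs from two simulated TP ranks (what reduce_results=False yields).
partials = [torch.randn(4, hidden) for _ in range(2)]
reduced = torch.stack(partials).sum(0)  # stands in for tensor_model_parallel_all_reduce

# A fused allreduce+RMSNorm must reproduce "reduce first, then normalize":
fused_reference = rmsnorm(reduced, weight)
# Normalizing each partial and reducing afterwards is not equivalent, which is
# why the reduction is folded into the norm rather than dropped or reordered.
normalize_then_reduce = torch.stack([rmsnorm(p, weight) for p in partials]).sum(0)
assert not torch.allclose(fused_reference, normalize_then_reduce)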