Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions atom/model_engine/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,14 +1367,18 @@ def run_model(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor
positions = context.positions
if is_prefill or self.enforce_eager or bs > self.graph_bs[-1]:
hidden_states = self.model(input_ids, positions)
logits = self.model.compute_logits(hidden_states)
else:
graph_bs = context.graph_bs
max_q_len = forward_context.attn_metadata.max_seqlen_q
graph_key = (graph_bs, max_q_len)
self.graphs[graph_key].replay()
num_tokens = context.batch_size * max_q_len
hidden_states = self.forward_vars["outputs"][:num_tokens]
logits = self.model.compute_logits(hidden_states)
if self.logits_in_graph:
logits = self.graph_logits[graph_key][:num_tokens]
else:
logits = self.model.compute_logits(hidden_states)

return logits, hidden_states

Expand Down Expand Up @@ -1588,7 +1592,9 @@ def capture_cudagraph(self):
outputs = self.forward_vars["outputs"]

self.graphs: dict[tuple[int, int], torch.cuda.CUDAGraph] = dict()
self.graph_logits: dict[tuple[int, int], torch.Tensor] = dict()
self.graph_pool = None
self.logits_in_graph = self.world_size == 1

with graph_capture() as gc:
capture_range = (
Expand Down Expand Up @@ -1625,17 +1631,27 @@ def capture_cudagraph(self):
num_tokens_across_dp=num_tokens_across_dp,
)

# Warmup
outputs[:num_tokens] = self.model(
input_ids[:num_tokens], positions[:num_tokens]
) # warmup
)
if self.logits_in_graph:
self.model.compute_logits(outputs[:num_tokens])

# Capture: include compute_logits only when TP=1 since
# ParallelLMHead uses NCCL all_gather which is not
# graph-capturable on HIP when TP > 1.
with torch.cuda.graph(graph, self.graph_pool, stream=gc.stream):
outputs[:num_tokens] = self.model(
input_ids[:num_tokens], positions[:num_tokens]
) # capture
)
if self.logits_in_graph:
graph_logits = self.model.compute_logits(outputs[:num_tokens])
if self.graph_pool is None:
self.graph_pool = graph.pool()
self.graphs[(bs, max_q_len)] = graph
if self.logits_in_graph:
self.graph_logits[(bs, max_q_len)] = graph_logits
torch.cuda.synchronize()
self.graph_bs.sort(reverse=False)
return time.time() - start_time, self.graph_bs
58 changes: 51 additions & 7 deletions atom/models/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@

import torch
import torch.distributed as dist
import torch.nn.functional as F
from aiter import ActivationType
from aiter.dist.communication_op import tensor_model_parallel_all_gather
from aiter.dist.communication_op import (
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from aiter.dist.parallel_state import get_pp_group, get_tensor_model_parallel_world_size

# from vllm.model_executor.layers.logits_processor import LogitsProcessor
Expand All @@ -43,10 +47,13 @@
make_layers,
maybe_prefix,
)
from atom.utils import envs
from atom.utils.decorators import support_torch_compile
from torch import nn
from transformers import GptOssConfig

ENABLE_ALLREDUCE_RMSNORM_FUSION = envs.ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION


def cdiv(x, y):
    """Return the integer ceiling of x / y (smallest int >= x / y).

    Uses the classic add-(divisor-1)-then-floor-divide trick so it stays
    in integer arithmetic throughout.
    """
    quotient = (x + y - 1) // y
    return quotient
Expand Down Expand Up @@ -114,6 +121,7 @@ def __init__(
quant_config=None,
prefix=f"{prefix}.o_proj",
bias=True,
reduce_results=not ENABLE_ALLREDUCE_RMSNORM_FUSION,
)

self.num_local_attention_heads = config.num_attention_heads // tp_size
Expand Down Expand Up @@ -163,6 +171,7 @@ def __init__(
self.hidden_size = config.hidden_size
self.experts_per_token = config.num_experts_per_tok
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
self.tp_size = get_tensor_model_parallel_world_size()
self.router = ReplicatedLinear(
config.hidden_size,
config.num_local_experts,
Expand All @@ -176,7 +185,7 @@ def __init__(
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
reduce_results=True,
reduce_results=not ENABLE_ALLREDUCE_RMSNORM_FUSION,
renormalize=True,
quant_config=quant_config,
prefix=f"{prefix}.experts",
Expand All @@ -185,13 +194,33 @@ def __init__(
activation=ActivationType.Swiglu,
config=config,
)
# Detect MXFP4 MoE GEMM padding requirement from the quant method.
# When hidden_size is not aligned to 256, MXFP4 weights are padded
# and the kernel expects padded input. We handle padding here instead
# of in the layernorm, so the layernorm can use fused AllReduce.
if hasattr(self.experts.quant_method, "hidden_pad"):
self.moe_hidden_pad = self.experts.quant_method.hidden_pad
else:
self.moe_hidden_pad = 0

def forward(self, x: torch.Tensor) -> torch.Tensor:
num_tokens = x.shape[0]

g = self.router(x[..., : self.hidden_size])
g = self.router(x)

# Pad input for MXFP4 MoE GEMM alignment if needed
if self.moe_hidden_pad > 0:
x = F.pad(x, (0, self.moe_hidden_pad))

x = self.experts(hidden_states=x, router_logits=g)

# Remove padding from output
if self.moe_hidden_pad > 0:
x = x[:, : self.hidden_size]

if self.tp_size > 1 and not ENABLE_ALLREDUCE_RMSNORM_FUSION:
x = tensor_model_parallel_all_reduce(x)

if self.is_sequence_parallel:
x = tensor_model_parallel_all_gather(x.contiguous(), 0)
x = x[:num_tokens]
Expand Down Expand Up @@ -221,9 +250,20 @@ def __init__(
layer_num=layer_num,
)
self.mlp = MLPBlock(atom_config, self.layer_idx, prefix=f"{prefix}.mlp")
self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
# Fuse MoE AllReduce into input_layernorm for layers > 0.
# Layer 0 receives already-reduced embedding output, so no fusion needed.
self.input_layernorm = RMSNorm(
config.hidden_size,
eps=1e-5,
fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION and layer_num > 0,
)
# Fuse o_proj AllReduce into post_attention_layernorm.
# Padding for MXFP4 MoE GEMM alignment is now handled inside MLPBlock,
# so this layernorm no longer needs x_pad_to_multiple.
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=1e-5, x_pad_to_multiple=256
config.hidden_size,
eps=1e-5,
fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION,
)

def forward(
Expand All @@ -242,7 +282,7 @@ def forward(

# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
output = self.mlp(hidden_states)[:, : self.hidden_size]
output = self.mlp(hidden_states)
return output, residual


Expand Down Expand Up @@ -273,7 +313,11 @@ def __init__(
),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)
self.norm = RMSNorm(
self.config.hidden_size,
eps=1e-5,
fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION,
)
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], self.config.hidden_size
)
Expand Down