Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions atom/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ def __init__(
quant_name="",
quant_method=None,
exclude_layers: Optional[list[str]] = None,
packed_components: Optional[dict[str, list[str]]] = None,
):
super().__init__()
self["quant_type"] = quant_type if quant_type is not None else QuantType.No
Expand All @@ -263,6 +264,9 @@ def __init__(
self["is_dynamic"] = is_dynamic
self["quant_method"] = quant_method
self["exclude_layers"] = exclude_layers if exclude_layers is not None else []
self["packed_components"] = (
packed_components if packed_components is not None else {}
)

def get_name(self):
    """Return the quantization name stored under the "quant_name" key."""
    quant_name = self["quant_name"]
    return quant_name
Expand Down
12 changes: 12 additions & 0 deletions atom/model_engine/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from atom.model_engine.scheduler import ScheduledBatch, ScheduledBatchOutput
from atom.model_engine.sequence import Sequence, SequenceStatus, SequenceType
from atom.model_loader.loader import load_model
from atom.models.utils import build_packed_components_mapping
from atom.model_ops.rejection_sampler import RejectionSampler
from atom.model_ops.sampler import Sampler
from atom.spec_decode.eagle import EagleProposer
Expand Down Expand Up @@ -563,6 +564,7 @@ def __init__(self, rank: int, config: Config):
)

model_class = resolve_obj_by_qualname(support_model_arch_dict[hf_config.architectures[0]]) # type: ignore
self.build_inverse_mapping(model_class)
self.model = model_class(config)
torch.set_default_device(None)
load_model(self.model, config.model, config.hf_config, config.load_dummy)
Expand All @@ -585,6 +587,16 @@ def __init__(self, rank: int, config: Config):
if self.config.compilation_config.level == 1:
self.model = torch.compile(self.model, fullgraph=True, backend="eager")

def build_inverse_mapping(self, model_class: Any):
    """Populate ``quant_config["packed_components"]`` from *model_class*.

    Reads the class-level ``packed_modules_mapping`` attribute of the model
    class BEFORE instantiation, so that ``get_quant_config_for_layer`` can
    resolve packed layer names (e.g. "gate_up_proj") while the layers are
    being constructed.

    Args:
        model_class: The resolved model class. Only its class attribute
            ``packed_modules_mapping`` is read; the class is not
            instantiated here.
    """
    packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
    # Only build the inverse mapping when exclusion rules exist: the
    # packed_components dict is consumed when matching "exclude_layers"
    # entries against their packed counterparts.
    if packed_modules_mapping and self.config.quant_config.get("exclude_layers"):
        self.config.quant_config["packed_components"] = (
            build_packed_components_mapping(packed_modules_mapping)
        )

def is_deepseek_mla(self) -> bool:
if not hasattr(self.hf_text_config, "model_type"):
return False
Expand Down
87 changes: 77 additions & 10 deletions atom/model_ops/fused_moe_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,56 @@ def _swizzle_mxfp4(quant_tensor, scale):
return quant_tensor, InFlexData(), scale


def routing_from_topk(topk_weights, topk_ids, n_expts_tot: int):
    """Convert FusedMoE.select_experts output to triton routing data structures.

    This bridges the gap between ATOM's grouped topk / sigmoid routing
    (which triton_kernels routing() does not support) and the triton
    matmul_ogs compute kernels.

    Args:
        topk_weights: (n_tokens, n_expts_act) routing weights from select_experts
        topk_ids: (n_tokens, n_expts_act) expert indices from select_experts
        n_expts_tot: total number of experts (global, before EP)

    Returns:
        (RoutingData, GatherIndx, ScatterIndx) compatible with triton_kernel_fused_experts
    """
    # Imported lazily so the module stays importable when triton_kernels is
    # not installed; only this code path requires it.
    from triton_kernels.routing import (
        RoutingData,
        GatherIndx,
        ScatterIndx,
        compute_expt_data,
    )

    n_tokens, n_expts_act = topk_weights.shape
    # Total number of (token, expert) routing pairs; passed to
    # compute_expt_data as the padded gate count.
    n_gates_pad = n_tokens * n_expts_act

    # Sort each token's selected experts by expert_id (required by triton kernels)
    expt_indx_sorted, sort_indices = torch.sort(topk_ids.int(), dim=1)
    # Reorder the weights with the same per-row permutation so weight i still
    # corresponds to expert i after the sort.
    expt_scal_sorted = torch.gather(topk_weights, 1, sort_indices.long())

    # Flatten to 1D
    expt_scal = expt_scal_sorted.reshape(-1).to(topk_weights.dtype)
    expt_indx = expt_indx_sorted.reshape(-1).to(torch.int32)

    # Sort by expert_id globally so experts are contiguous for the matmul.
    # topk_indx is the permutation that groups entries by expert; applying
    # argsort a second time yields its inverse permutation (gate_indx), which
    # maps each original (token, slot) position to its grouped location.
    # stable=True keeps same-expert entries in token order.
    topk_indx = torch.argsort(expt_indx, stable=True).int()
    gate_indx = torch.argsort(topk_indx, stable=True).int()
    gate_scal = expt_scal[topk_indx.long()]

    # Histogram of tokens over experts. histc requires float input; with
    # min=0 (default) and max=n_expts_tot-1 each integer expert id lands in
    # its own bin, so hist[e] == number of routed pairs for expert e.
    hist = torch.histc(expt_indx.float(), bins=n_expts_tot, max=n_expts_tot - 1).int()

    # Build routing data structures using triton-accelerated compute_expt_data.
    # NOTE(review): src/dst index semantics mirror triton_kernels.routing() —
    # gather scatters grouped rows back to token order and vice versa; confirm
    # against the triton_kernels documentation if it changes.
    gather_indx = GatherIndx(src_indx=topk_indx, dst_indx=gate_indx)
    scatter_indx = ScatterIndx(src_indx=gate_indx, dst_indx=topk_indx)
    expt_data = compute_expt_data(hist, n_expts_tot, n_gates_pad)

    routing_data = RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data)
    return routing_data, gather_indx, scatter_indx


def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
"""
Shrink the given tensor and apply the given view to it. This is
Expand Down Expand Up @@ -161,26 +211,37 @@ def triton_kernel_fused_experts(
if global_num_experts == -1:
global_num_experts = E

half_N = N // 2

if intermediate_cache is None:
intermediate_cache = torch.empty(
(batch_dim, M * topk, N // 2),
(batch_dim, M * topk, half_N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)

# Add batch_dim to output buffer because matmul_ogs expects 3D output
intermediate_cache = _resize_cache(
intermediate_cache, (batch_dim, M * topk, N // 2)
intermediate_cache, (batch_dim, M * topk, half_N)
)
output_tensor = _resize_cache(output_tensor, (batch_dim, M, K))

act = FusedActivation(
FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
(swiglu_alpha, swiglu_limit),
2,
)
gammas = routing_data.gate_scal if routing_data else None

# NOTE: We intentionally do NOT use the triton fused SwiGLU activation
# because it expects interleaved [gate0, up0, gate1, up1, ...] layout
# while our w13 weights produce concatenated [gate | up] output.
# It also uses a non-standard formula: s*sigmoid(alpha*s)*(linear+1)
# with alpha=1.702, which differs from the standard SiLU activation
# (x*sigmoid(x)*up) used by most MoE models.
# Instead, we compute the matmul without fused activation and apply
# standard silu(gate) * up manually.
raw_intermediate = torch.empty(
(batch_dim, M * topk, N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)

matmul_ogs(
hidden_states,
w1,
Expand All @@ -189,12 +250,17 @@ def triton_kernel_fused_experts(
gather_indx=gather_indx,
precision_config=w13_precision_config,
gammas=gammas if apply_router_weight_on_input else None,
fused_activation=act,
y=intermediate_cache,
y=raw_intermediate,
)

# Standard SiLU/SwiGLU activation: silu(gate) * up
raw_2d = raw_intermediate.view(M * topk, N)
gate = raw_2d[:, :half_N]
up = raw_2d[:, half_N:]
intermediate_cache[0] = torch.nn.functional.silu(gate) * up

matmul_ogs(
intermediate_cache.view(M * topk, N // 2),
intermediate_cache.view(M * topk, half_N),
w2,
w2_bias,
routing_data,
Expand All @@ -203,5 +269,6 @@ def triton_kernel_fused_experts(
gammas=None if apply_router_weight_on_input else gammas,
y=output_tensor,
)

output_tensor = output_tensor.view(M, K)
return output_tensor
5 changes: 5 additions & 0 deletions atom/model_ops/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,9 +476,12 @@ def __init__(
bias: bool = False,
quant_config: Optional[QuantizationConfig] = None,
source_quant_dtype: torch.dtype = None,
prefix: str = "",
**kwargs,
):
self.tp_dim = 0
if quant_config is not None and prefix:
quant_config = get_quant_config_for_layer(quant_config, prefix)
super().__init__(
input_size,
output_size,
Expand Down Expand Up @@ -551,6 +554,7 @@ def __init__(
bias: bool = False,
quant_config: Optional[QuantizationConfig] = None,
source_quant_dtype: torch.dtype = None,
prefix: str = "",
**kwargs,
):
self.head_size = head_size
Expand Down Expand Up @@ -582,6 +586,7 @@ def __init__(
bias=bias,
quant_config=quant_config,
source_quant_dtype=source_quant_dtype,
prefix=prefix,
)

def weight_loader(
Expand Down
87 changes: 81 additions & 6 deletions atom/model_ops/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,68 @@ def apply(
activation: ActivationType = ActivationType.Silu,
) -> torch.Tensor:
if self.use_triton:
from atom.model_ops.fused_moe_triton import triton_kernel_moe_forward
from atom.model_ops.fused_moe_triton import (
triton_kernel_moe_forward,
triton_kernel_fused_experts,
routing_from_topk,
)

# Check if the model needs custom routing that triton routing()
# does not support (grouped topk, sigmoid scoring, bias correction).
needs_custom_routing = (
use_grouped_topk
or scoring_func != "softmax"
or e_score_correction_bias is not None
or custom_routing_function is not None
)

if needs_custom_routing:
# Use ATOM's full-featured select_experts for routing,
# then triton matmul_ogs for the actual MoE computation.
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
num_fused_shared_experts=layer.num_fused_shared_experts,
routed_scaling_factor=layer.routed_scaling_factor,
)

# Convert to triton routing data structures
n_expts_tot = router_logits.shape[-1]
if global_num_experts > 0:
n_expts_tot = global_num_experts

routing_data, gather_idx, scatter_idx = routing_from_topk(
topk_weights, topk_ids, n_expts_tot
)

output = torch.empty_like(x)
_moe_result = triton_kernel_fused_experts(
output,
x,
layer.w13_weight,
layer.w2_weight,
routing_data,
gather_idx,
scatter_idx,
topk=top_k,
activation=activation,
w13_precision_config=self.w13_precision_config,
w2_precision_config=self.w2_precision_config,
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
return _moe_result

return triton_kernel_moe_forward(
x,
Expand Down Expand Up @@ -2077,18 +2138,27 @@ def _load_w13(

# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
shard_size = expert_data.shape[shard_dim] // 2
expert_shard_size = expert_data.shape[shard_dim] // 2
# Derive shard size from loaded_weight (unpadded checkpoint) to avoid
# out-of-bounds when expert_data is padded (e.g. MXFP4 alignment).
load_shard_size = loaded_weight.shape[shard_dim] // self.tp_size
loaded_weight = loaded_weight.narrow(
shard_dim, shard_size * tp_rank, shard_size
shard_dim, load_shard_size * tp_rank, load_shard_size
)
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if shard_id == "w1":
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
expert_data = expert_data.narrow(shard_dim, 0, expert_shard_size)
# w3, up_proj: Load into second logical weight of w13.
else:
assert shard_id == "w3"
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
expert_data = expert_data.narrow(
shard_dim, expert_shard_size, expert_shard_size
)
# When expert_data is padded beyond the actual weight size, narrow to
# the loaded weight size so the copy shape matches.
if load_shard_size != expert_shard_size:
expert_data = expert_data.narrow(shard_dim, 0, load_shard_size)
if expert_data.dtype != dtypes.fp4x2:
expert_data.copy_(loaded_weight)
else:
Expand All @@ -2108,9 +2178,14 @@ def _load_w2(
# Narrow parameter and load.
shard_size = expert_data.shape[shard_dim]
if not load_full:
# Derive shard size from loaded_weight (unpadded checkpoint) to
# avoid out-of-bounds when expert_data is padded (e.g. MXFP4).
load_shard_size = loaded_weight.shape[shard_dim] // self.tp_size
loaded_weight = loaded_weight.narrow(
shard_dim, shard_size * tp_rank, shard_size
shard_dim, load_shard_size * tp_rank, load_shard_size
)
if load_shard_size != shard_size:
expert_data = expert_data.narrow(shard_dim, 0, load_shard_size)
# w2, down_proj: Load into only logical weight of w2.
if expert_data.dtype != dtypes.fp4x2:
expert_data.copy_(loaded_weight)
Expand Down
1 change: 1 addition & 0 deletions atom/models/glm4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def __init__(
prefix=f"{prefix}.experts",
scoring_func="sigmoid",
e_score_correction_bias=self.gate.e_score_correction_bias,
has_bias=getattr(config, "moe_ffn_bias", False),
config=config,
)

Expand Down
Loading