1 change: 1 addition & 0 deletions atom/model_engine/model_runner.py
@@ -54,6 +54,7 @@
     "GptOssForCausalLM": "atom.models.gpt_oss.GptOssForCausalLM",
     "Glm4MoeForCausalLM": "atom.models.glm4_moe.Glm4MoeForCausalLM",
     "Qwen3NextForCausalLM": "atom.models.qwen3_next.Qwen3NextForCausalLM",
+    "MiniMaxM2ForCausalLM": "atom.models.minimax_m2.MiniMaxM2ForCausalLM",
 }
 # seed = 34567
 # np.random.seed(seed)
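For orientation, dotted registry entries like the new `MiniMaxM2ForCausalLM` mapping are typically resolved at runtime by importing the module half of the path and looking up the class. A minimal sketch, assuming a hypothetical helper `resolve_model_class` and registry name `MODEL_REGISTRY` that are not taken from this repo:

```python
# Illustrative only: resolving a "package.module.ClassName" registry entry.
import importlib

MODEL_REGISTRY = {
    "MiniMaxM2ForCausalLM": "atom.models.minimax_m2.MiniMaxM2ForCausalLM",
}

def resolve_model_class(dotted_path: str) -> type:
    # Split the dotted path into module path and class name, then import lazily.
    module_path, _, class_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)

# e.g. model_cls = resolve_model_class(MODEL_REGISTRY["MiniMaxM2ForCausalLM"])
```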
86 changes: 65 additions & 21 deletions atom/model_ops/moe.py
@@ -1904,10 +1904,6 @@ def __init__(

         self.use_chunked = get_dp_group().world_size > 1

-        if self.scoring_func != "softmax" and not self.use_grouped_topk:
-            raise ValueError(
-                "Only softmax scoring function is supported for " "non-grouped topk."
-            )
         moe = FusedMoEConfig(
             num_experts=self.global_num_experts,
             experts_per_token=self.top_k,
@@ -2078,21 +2074,34 @@ def _load_w13(
         # Index the loaded weight for tp sharding.
         # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
         shard_size = expert_data.shape[shard_dim] // 2
-        loaded_weight = loaded_weight.narrow(
-            shard_dim, shard_size * tp_rank, shard_size
-        )
+
+        # Calculate the original (unpadded) shard size from loaded_weight.
+        # Here we assume loaded_weight is the full, unsharded tensor.
+        original_shard_size = loaded_weight.shape[shard_dim] // self.tp_size
+        valid_shard_size = min(shard_size, original_shard_size)
+
+        # Load the valid part from loaded_weight.
+        loaded_shard = loaded_weight.narrow(
+            shard_dim, original_shard_size * tp_rank, valid_shard_size
+        )

         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
         if shard_id == "w1":
-            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
+            expert_data_slice = expert_data.narrow(shard_dim, 0, shard_size)
         # w3, up_proj: Load into second logical weight of w13.
         else:
             assert shard_id == "w3"
-            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
+            expert_data_slice = expert_data.narrow(shard_dim, shard_size, shard_size)
+
+        # Determine the slice of expert_data to copy into (handles padding).
+        expert_data_valid = expert_data_slice.narrow(shard_dim, 0, valid_shard_size)

         if expert_data.dtype != dtypes.fp4x2:
-            expert_data.copy_(loaded_weight)
+            expert_data_valid.copy_(loaded_shard)
         else:
-            expert_data.view(torch.uint8).copy_(loaded_weight.view(torch.uint8))
+            expert_data_valid.view(torch.uint8).copy_(loaded_shard.view(torch.uint8))

     def _load_w2(
         self,
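The rewritten `_load_w13`/`_load_w2` paths let the per-rank parameter shard be larger than the checkpoint's shard (for example when the expert intermediate size is padded for alignment): only the valid prefix of the shard is copied and any padding is left untouched. A self-contained sketch of that copy, with made-up shapes:

```python
# Standalone illustration of the padded-shard copy above (shapes are invented).
import torch

tp_size, tp_rank, shard_dim = 4, 1, 0
loaded_weight = torch.randn(768, 512)   # full, unsharded checkpoint tensor
expert_data = torch.empty(256, 512)     # per-rank parameter shard, padded (256 > 768 // 4)

shard_size = expert_data.shape[shard_dim]                         # 256 (padded)
original_shard_size = loaded_weight.shape[shard_dim] // tp_size   # 192 (unpadded)
valid_shard_size = min(shard_size, original_shard_size)           # 192

# Take this rank's slice of the checkpoint tensor ...
loaded_shard = loaded_weight.narrow(
    shard_dim, original_shard_size * tp_rank, valid_shard_size
)
# ... and write it into the valid prefix of the padded parameter shard.
expert_data.narrow(shard_dim, 0, valid_shard_size).copy_(loaded_shard)
```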
@@ -2107,15 +2116,28 @@ def _load_w2(
         # down_proj: "RowParallel" so tp sharding on input_dim
         # Narrow parameter and load.
         shard_size = expert_data.shape[shard_dim]
+
         if not load_full:
-            loaded_weight = loaded_weight.narrow(
-                shard_dim, shard_size * tp_rank, shard_size
-            )
-        # w2, down_proj: Load into only logical weight of w2.
-        if expert_data.dtype != dtypes.fp4x2:
-            expert_data.copy_(loaded_weight)
-        else:
-            expert_data.view(torch.uint8).copy_(loaded_weight.view(torch.uint8))
+            original_shard_size = loaded_weight.shape[shard_dim] // self.tp_size
+            valid_shard_size = min(shard_size, original_shard_size)
+
+            loaded_shard = loaded_weight.narrow(
+                shard_dim, original_shard_size * tp_rank, valid_shard_size
+            )
+            # w2, down_proj: Load into only logical weight of w2.
+            expert_data_valid = expert_data.narrow(shard_dim, 0, valid_shard_size)
+
+            if expert_data.dtype != dtypes.fp4x2:
+                expert_data_valid.copy_(loaded_shard)
+            else:
+                expert_data_valid.view(torch.uint8).copy_(
+                    loaded_shard.view(torch.uint8)
+                )
+        else:
+            # Full load
+            if expert_data.dtype != dtypes.fp4x2:
+                expert_data.copy_(loaded_weight)
+            else:
+                expert_data.view(torch.uint8).copy_(loaded_weight.view(torch.uint8))

     def _load_single_value(
         self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int
@@ -2388,11 +2410,33 @@ def select_experts(
                 num_fused_shared_experts=num_fused_shared_experts,
             )
         else:
-            topk_weights, topk_ids = fused_topk(
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-            )
+            if scoring_func == "softmax":
+                topk_weights, topk_ids = fused_topk(
+                    gating_output=router_logits,
+                    topk=top_k,
+                    renormalize=renormalize,
+                )
+            elif scoring_func == "sigmoid":
+                routing_weights = torch.sigmoid(router_logits.float())
+                scores_for_choice = routing_weights
+                if e_score_correction_bias is not None:
+                    scores_for_choice = scores_for_choice + e_score_correction_bias
+
+                topk_ids = torch.topk(
+                    scores_for_choice, top_k, dim=-1, sorted=False
+                ).indices
+                topk_weights = routing_weights.gather(dim=-1, index=topk_ids)
+
+                if renormalize:
+                    topk_weights = topk_weights / topk_weights.sum(
+                        dim=-1, keepdim=True
+                    ).clamp_min(1e-20)
+
+                topk_ids = topk_ids.to(torch.int32)
+            else:
+                raise ValueError(
+                    f"Unsupported scoring function for non-grouped topk: {scoring_func}"
+                )

         return topk_weights, topk_ids
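The new `sigmoid` branch selects experts from bias-corrected sigmoid scores while taking the gate values from the uncorrected scores, then optionally renormalizes them. A toy rerun of that logic on invented shapes:

```python
# Toy reproduction of the sigmoid scoring branch (shapes are invented).
import torch

num_tokens, num_experts, top_k = 2, 8, 2
router_logits = torch.randn(num_tokens, num_experts)
e_score_correction_bias = torch.zeros(num_experts)  # per-expert bias; may be None

routing_weights = torch.sigmoid(router_logits.float())
# The correction bias only influences which experts are chosen ...
scores_for_choice = routing_weights + e_score_correction_bias
topk_ids = torch.topk(scores_for_choice, top_k, dim=-1, sorted=False).indices
# ... while the gate values come from the unbiased sigmoid scores.
topk_weights = routing_weights.gather(dim=-1, index=topk_ids)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True).clamp_min(1e-20)
topk_ids = topk_ids.to(torch.int32)
```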
