275 changes: 267 additions & 8 deletions mistralrs-core/src/models/qwen3_moe.rs
@@ -1,9 +1,10 @@
#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{DType, Device, Module, Result, Tensor, D};
use candle_nn::Linear;
use mistralrs_quant::{
ColumnParallelLayer, FusedExperts, QuantMethod, QuantizedConfig, ReplicatedLayer,
RowParallelLayer, ShardedVarBuilder,
apply_immediate_isq, ColumnParallelLayer, FusedExperts, QuantMethod, QuantMethodConfig,
QuantizedConfig, ReplicatedLayer, RowParallelLayer, ShardedVarBuilder, UnquantLinear,
};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, sync::Arc};
@@ -443,6 +444,7 @@
.gather_forward_autocast(&(up * gate.apply(&self.act)?)?, &indices)?;
xs.squeeze(D::Minus2)?
};
dbg!(&ys, &scores);
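// NOTE: this `dbg!` (and the matching one in GroupedMoeMlp::forward below)
// appears to be leftover debug output; it prints full tensors on every
// forward pass and would normally be removed before merge.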

ys.to_dtype(DType::F32)?
.broadcast_mul(&scores.unsqueeze(D::Minus1)?)?
@@ -452,7 +454,260 @@
}
}

struct GroupedMoeMlp {
gate: Arc<dyn QuantMethod>,
gate_proj_vec: Vec<Arc<dyn QuantMethod>>,
up_proj_vec: Vec<Arc<dyn QuantMethod>>,
down_proj_vec: Vec<Arc<dyn QuantMethod>>,
act: Activation,
norm_topk_prob: bool,
num_experts_per_tok: usize,
}

impl GroupedMoeMlp {
fn new(
cfg: &Config,
vb: ShardedVarBuilder,
layer_device: Device,
_comm: &Arc<mistralrs_quant::Comm>,
) -> Result<Self> {
let num_experts = cfg.num_experts;
let gate = mistralrs_quant::linear_no_bias(
cfg.hidden_size,
num_experts,
&cfg.quantization_config,
vb.pp("gate").set_device(layer_device),
)?;

let experts_vb = vb.pp("experts");
let mut gate_proj_vec: Vec<Arc<dyn QuantMethod>> = Vec::new();
let mut up_proj_vec: Vec<Arc<dyn QuantMethod>> = Vec::new();
let mut down_proj_vec: Vec<Arc<dyn QuantMethod>> = Vec::new();
for i in 0..num_experts {
let vb = experts_vb.pp(i);
let gate_proj = vb.get(
(cfg.moe_intermediate_size, cfg.hidden_size),
"gate_proj.weight",
)?;
let up_proj = vb.get(
(cfg.moe_intermediate_size, cfg.hidden_size),
"up_proj.weight",
)?;
let down_proj = vb.get(
(cfg.hidden_size, cfg.moe_intermediate_size),
"down_proj.weight",
)?;

gate_proj_vec.push(apply_immediate_isq(
Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
Linear::new(gate_proj, None),
))?),
vb.pp("gate_proj"),
)?);
up_proj_vec.push(apply_immediate_isq(
Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
Linear::new(up_proj, None),
))?),
vb.pp("up_proj"),
)?);
down_proj_vec.push(apply_immediate_isq(
Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
Linear::new(down_proj, None),
))?),
vb.pp("down_proj"),
)?);
}

Ok(Self {
gate,
gate_proj_vec,
up_proj_vec,
down_proj_vec,
act: cfg.hidden_act,
norm_topk_prob: cfg.norm_topk_prob,
num_experts_per_tok: cfg.num_experts_per_tok,
})
}

fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let original_dtype = xs.dtype();

let (b_size, seq_len, hidden_dim) = xs.dims3()?;

let router_logits = self.gate.forward_autocast(xs)?;
let routing_weights =
candle_nn::ops::softmax_last_dim(&router_logits.to_dtype(DType::F32)?)?;

let indices = routing_weights.arg_sort_last_dim(false)?.narrow(
D::Minus1,
0,
self.num_experts_per_tok,
)?;
let mut scores = routing_weights.gather(&indices.contiguous()?, D::Minus1)?;

if self.norm_topk_prob {
scores = scores.broadcast_div(&scores.sum_keepdim(D::Minus1)?)?;
}
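// Worked example (illustration only, not in the PR): with norm_topk_prob and
// top-2 raw softmax weights (0.5, 0.3), the renormalized scores are
// (0.5 / 0.8, 0.3 / 0.8) = (0.625, 0.375), so the kept weights sum to 1.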

let ys = {
let gate = grouped_mat_mul(&xs, &self.gate_proj_vec, &indices)?;

Check failure on line 553 (GitHub Actions / Clippy): this expression creates a reference which is immediately dereferenced by the compiler
let up = grouped_mat_mul(&xs, &self.up_proj_vec, &indices)?;

Check failure on line 554 (GitHub Actions / Clippy): this expression creates a reference which is immediately dereferenced by the compiler
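// Possible fix for the two Clippy lints above (a sketch, not applied in this
// diff): `xs` is already a `&Tensor`, so it can be passed directly:
//     let gate = grouped_mat_mul(xs, &self.gate_proj_vec, &indices)?;
//     let up = grouped_mat_mul(xs, &self.up_proj_vec, &indices)?;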
grouped_mat_mul(
&(up * gate.apply(&self.act)?)?,
&self.down_proj_vec,
&indices,
)?
};
dbg!(&ys, &scores);

ys.to_dtype(DType::F32)?
.broadcast_mul(&scores.unsqueeze(D::Minus1)?)?
.sum(D::Minus2)?
.reshape((b_size, seq_len, hidden_dim))?
.to_dtype(original_dtype)
}
}

/// Grouped matrix multiplication for MoE layers.
///
/// For each token, applies the selected experts and returns their outputs separately
/// (not accumulated), preserving the topk dimension for later weighting.
///
/// - `xs`: Input tensor - can be:
/// - 2D: (tokens, hidden)
/// - 3D: (bs, seqlen, hidden)
/// - 4D: (bs, seqlen, topk, intermediate) - for down_proj after gate*up
/// - `ids`: (bs, seqlen, topk) - expert indices for each token
///
/// Returns: (bs, seqlen, topk, output_dim) where each token has topk expert outputs
fn grouped_mat_mul(
xs: &Tensor,
experts: &Vec<Arc<dyn QuantMethod>>,
ids: &Tensor,
) -> Result<Tensor> {
let (bs, seqlen, hidden, topk, is_4d) = match xs.dims().len() {
2 => {
let (tokens, hidden) = xs.dims2()?;
let topk = ids.dim(D::Minus1)?;
(1, tokens, hidden, topk, false)
}
3 => {
let (bs, seqlen, hidden) = xs.dims3()?;
let topk = ids.dim(D::Minus1)?;
(bs, seqlen, hidden, topk, false)
}
4 => {
// Input is (bs, seqlen, topk, hidden) - intermediate activations
// In this case, ids should still be (bs, seqlen, topk)
let dims4 = xs.dims4()?;
(dims4.0, dims4.1, dims4.3, dims4.2, true)
}
_ => candle_core::bail!(
"grouped_mat_mul expects xs to be 2D, 3D, or 4D, got shape {:?}",
xs.dims()
),
};

// Flatten to (bs * seqlen, hidden) or (bs * seqlen * topk, hidden) for 4D case
let (xs_2d, ids_vec) = if is_4d {
// For 4D input (bs, seqlen, topk, hidden), reshape to (bs * seqlen * topk, hidden)
let xs_flat = xs.reshape((bs * seqlen * topk, hidden))?;
// ids is (bs, seqlen, topk)
let ids_2d = ids.reshape((bs * seqlen, topk))?;
let ids_vec = ids_2d.to_vec2::<u32>()?;
(xs_flat, ids_vec)
} else {
// For 2D/3D input, flatten to (bs * seqlen, hidden)
let xs_2d = xs.reshape((bs * seqlen, hidden))?;
let ids_2d = ids.reshape((bs * seqlen, topk))?;
let ids_vec = ids_2d.to_vec2::<u32>()?;
(xs_2d, ids_vec)
};

// Build a map based on whether input is 4D or not
let expert_indices_map: Vec<Vec<usize>> = if is_4d {
// For 4D input: map expert_id -> flat indices into (bs*seqlen*topk) dimension
let mut map: Vec<Vec<usize>> = vec![vec![]; experts.len()];
for (token_idx, expert_ids) in ids_vec.iter().enumerate() {
for (topk_idx, &expert_id) in expert_ids.iter().enumerate() {
let flat_idx = token_idx * topk + topk_idx;
map[expert_id as usize].push(flat_idx);
}
}
map
} else {
// For 2D/3D input: same layout as the 4D case. Map expert_id -> flat output
// indices, where token_idx * topk + topk_idx addresses the (token, slot) pair.
let mut map: Vec<Vec<usize>> = vec![vec![]; experts.len()];
for (token_idx, expert_ids) in ids_vec.iter().enumerate() {
for (topk_idx, &expert_id) in expert_ids.iter().enumerate() {
let flat_output_idx = token_idx * topk + topk_idx;
map[expert_id as usize].push(flat_output_idx);
}
}
map
};


// Determine output dimension by processing first non-empty expert
let mut out_dim: Option<usize> = None;
let mut dtype: Option<DType> = None;
for expert_layer in experts.iter() {

Check failure on line 656 (GitHub Actions / Clippy): this loop never actually loops
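// Possible fix for the lint above (a sketch, not applied here): probe only
// the first expert instead of a loop that always breaks on iteration one:
//     let first = experts
//         .first()
//         .ok_or_else(|| candle_core::Error::Msg("No experts available".to_string()))?;
//     let probe = first.forward_autocast(&xs_2d.narrow(0, 0, 1)?)?;
//     let (out_dim, dtype) = (probe.dim(D::Minus1)?, probe.dtype());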
let test_input = xs_2d.narrow(0, 0, 1)?;
let test_output = expert_layer.forward_autocast(&test_input)?;
out_dim = Some(test_output.dim(D::Minus1)?);
dtype = Some(test_output.dtype());
break;
}
let out_dim = out_dim.ok_or_else(|| candle_core::Error::Msg("No experts available".to_string()))?;
let dtype = dtype.unwrap();

// Build output
let mut output_data: Vec<f32> = vec![0.0; bs * seqlen * topk * out_dim];

// Process each expert
for (expert_idx, expert_layer) in experts.iter().enumerate() {
let flat_indices = &expert_indices_map[expert_idx];
if flat_indices.is_empty() {
continue;
}

// Gather inputs for this expert
let indices_tensor = Tensor::new(
flat_indices.iter().map(|&x| x as u32).collect::<Vec<_>>().as_slice(),
xs_2d.device(),
)?;
let expert_input = xs_2d.index_select(&indices_tensor, 0)?;

// Run expert forward pass
let expert_output = expert_layer.forward_autocast(&expert_input)?;

// Convert to f32 on CPU
let expert_output_f32 = expert_output.to_dtype(DType::F32)?.to_device(&candle_core::Device::Cpu)?.to_vec2::<f32>()?;

// Place outputs in correct positions
for (i, &flat_idx) in flat_indices.iter().enumerate() {
let base_idx = flat_idx * out_dim;
for j in 0..out_dim {

Check failure on line 692 (GitHub Actions / Clippy): it looks like you're manually copying between slices
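// Possible fix for the lint above (a sketch, not applied here): replace this
// inner loop with a single slice copy:
//     output_data[base_idx..base_idx + out_dim]
//         .copy_from_slice(&expert_output_f32[i]);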
output_data[base_idx + j] = expert_output_f32[i][j];
}
}
}

// Convert to tensor and move to correct device
let mut ys = Tensor::from_vec(output_data, (bs * seqlen, topk, out_dim), &candle_core::Device::Cpu)?
.to_dtype(dtype)?
.to_device(xs_2d.device())?;

// Reshape to (bs, seqlen, topk, out_dim)
let out_dim = ys.dim(D::Minus1)?;
ys = ys.reshape((bs, seqlen, topk, out_dim))?;

Ok(ys)
}
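// A self-contained illustration of the index bookkeeping in `grouped_mat_mul`
// (added here for clarity; not part of the original PR): token t's k-th
// selected expert writes its output to flat slot t * topk + k.
#[cfg(test)]
mod grouped_mat_mul_indexing_sketch {
    #[test]
    fn expert_to_flat_slot_mapping() {
        let topk = 2;
        let num_experts = 4;
        // ids[token] = the expert ids chosen for that token (topk = 2)
        let ids: Vec<Vec<u32>> = vec![vec![0, 2], vec![2, 3]];
        let mut map: Vec<Vec<usize>> = vec![vec![]; num_experts];
        for (token_idx, expert_ids) in ids.iter().enumerate() {
            for (k, &e) in expert_ids.iter().enumerate() {
                map[e as usize].push(token_idx * topk + k);
            }
        }
        assert_eq!(map[0], vec![0]); // expert 0: token 0, slot 0
        assert_eq!(map[2], vec![1, 2]); // expert 2: token 0 slot 1, token 1 slot 0
        assert_eq!(map[3], vec![3]); // expert 3: token 1, slot 1
        assert!(map[1].is_empty()); // expert 1 receives no tokens
    }
}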

struct SlowMoeMlp {

Check on line 710: struct `SlowMoeMlp` is never constructed (failure in GitHub Actions / Clippy; warnings in Check (ubuntu/macOS/windows, stable), MSRV Check (1.90), Docs, and Test Suite (ubuntu/macOS/windows, stable))
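// NOTE: with this PR the slow path is constructed as GroupedMoeMlp (see the
// MoeOrMlp::SlowMoe changes below), so SlowMoeMlp and its `new`/`forward`
// become dead code; removing the struct or marking it #[allow(dead_code)]
// would clear the warnings above.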
gate: candle_nn::Linear,
experts: Vec<Mlp>,
norm_topk_prob: bool,
@@ -460,7 +715,7 @@
}

impl SlowMoeMlp {
fn new(

Check on line 718: associated items `new` and `forward` are never used (failure in GitHub Actions / Clippy; warnings in the same Check, MSRV Check (1.90), Docs, and Test Suite jobs)
cfg: &Config,
vb: ShardedVarBuilder,
layer_device: Device,
@@ -554,7 +809,7 @@

enum MoeOrMlp {
FastMoe(FastMoeMlp),
SlowMoe(SlowMoeMlp),
SlowMoe(GroupedMoeMlp),
Mlp(Mlp),
}

@@ -614,7 +869,7 @@
comm,
)?)
} else {
MoeOrMlp::SlowMoe(SlowMoeMlp::new(
MoeOrMlp::SlowMoe(GroupedMoeMlp::new(
cfg,
vb,
mapper
@@ -947,10 +1202,14 @@
tensors.push((&mut layer.fused_down_proj, Some(i)));
}
MoeOrMlp::SlowMoe(layer) => {
for expert in &mut layer.experts {
tensors.push((&mut expert.gate_proj, Some(i)));
tensors.push((&mut expert.up_proj, Some(i)));
tensors.push((&mut expert.down_proj, Some(i)));
for expert in &mut layer.gate_proj_vec {
tensors.push((expert, Some(i)));
}
for expert in &mut layer.up_proj_vec {
tensors.push((expert, Some(i)));
}
for expert in &mut layer.down_proj_vec {
tensors.push((expert, Some(i)));
}
}
}