From f7e67ffe3fd2f314065a9b050bf77da8b362a4b9 Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Fri, 24 Oct 2025 23:28:07 -0400
Subject: [PATCH 1/6] Qwen3 grouped moe testing

---
 mistralrs-core/src/models/qwen3_moe.rs | 207 ++++++++++++++++++++++++-
 1 file changed, 199 insertions(+), 8 deletions(-)

diff --git a/mistralrs-core/src/models/qwen3_moe.rs b/mistralrs-core/src/models/qwen3_moe.rs
index 80eb93c2e4..1854d474f0 100644
--- a/mistralrs-core/src/models/qwen3_moe.rs
+++ b/mistralrs-core/src/models/qwen3_moe.rs
@@ -1,9 +1,10 @@
 #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
 
 use candle_core::{DType, Device, Module, Result, Tensor, D};
+use candle_nn::Linear;
 use mistralrs_quant::{
-    ColumnParallelLayer, FusedExperts, QuantMethod, QuantizedConfig, ReplicatedLayer,
-    RowParallelLayer, ShardedVarBuilder,
+    apply_immediate_isq, ColumnParallelLayer, FusedExperts, QuantMethod, QuantMethodConfig,
+    QuantizedConfig, ReplicatedLayer, RowParallelLayer, ShardedVarBuilder, UnquantLinear,
 };
 use serde::{Deserialize, Serialize};
 use std::{collections::HashMap, sync::Arc};
@@ -452,6 +453,192 @@ impl FastMoeMlp {
     }
 }
 
+struct GroupedMoeMlp {
+    gate: Arc<dyn QuantMethod>,
+    gate_proj_vec: Vec<Arc<dyn QuantMethod>>,
+    up_proj_vec: Vec<Arc<dyn QuantMethod>>,
+    down_proj_vec: Vec<Arc<dyn QuantMethod>>,
+    act: Activation,
+    norm_topk_prob: bool,
+    num_experts_per_tok: usize,
+}
+
+impl GroupedMoeMlp {
+    fn new(
+        cfg: &Config,
+        vb: ShardedVarBuilder,
+        layer_device: Device,
+        _comm: &Arc<mistralrs_quant::Comm>,
+    ) -> Result<Self> {
+        let num_experts = cfg.num_experts;
+        let gate = mistralrs_quant::linear_no_bias(
+            cfg.hidden_size,
+            num_experts,
+            &cfg.quantization_config,
+            vb.pp("gate").set_device(layer_device),
+        )?;
+
+        let experts_vb = vb.pp("experts");
+        let mut gate_proj_vec: Vec<Arc<dyn QuantMethod>> = Vec::new();
+        let mut up_proj_vec: Vec<Arc<dyn QuantMethod>> = Vec::new();
+        let mut down_proj_vec: Vec<Arc<dyn QuantMethod>> = Vec::new();
+        for i in 0..num_experts {
+            let vb = experts_vb.pp(i);
+            let gate_proj = vb.get(
+                (cfg.moe_intermediate_size, cfg.hidden_size),
+                "gate_proj.weight",
+            )?;
+            let up_proj = vb.get(
+                (cfg.moe_intermediate_size, cfg.hidden_size),
+                "up_proj.weight",
+            )?;
+            let down_proj = vb.get(
+                (cfg.hidden_size, cfg.moe_intermediate_size),
+                "down_proj.weight",
+            )?;
+
+            gate_proj_vec.push(apply_immediate_isq(
+                Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
+                    Linear::new(gate_proj, None),
+                ))?),
+                vb.pp("gate_proj"),
+            )?);
+            up_proj_vec.push(apply_immediate_isq(
+                Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
+                    Linear::new(up_proj, None),
+                ))?),
+                vb.pp("up_proj"),
+            )?);
+            down_proj_vec.push(apply_immediate_isq(
+                Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
+                    Linear::new(down_proj, None),
+                ))?),
+                vb.pp("down_proj"),
+            )?);
+        }
+
+        Ok(Self {
+            gate,
+            gate_proj_vec,
+            up_proj_vec,
+            down_proj_vec,
+            act: cfg.hidden_act,
+            norm_topk_prob: cfg.norm_topk_prob,
+            num_experts_per_tok: cfg.num_experts_per_tok,
+        })
+    }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let original_dtype = xs.dtype();
+
+        let (b_size, seq_len, hidden_dim) = xs.dims3()?;
+
+        let router_logits = self.gate.forward_autocast(xs)?;
+        let routing_weights =
+            candle_nn::ops::softmax_last_dim(&router_logits.to_dtype(DType::F32)?)?;
+
+        let indices = routing_weights.arg_sort_last_dim(false)?.narrow(
+            D::Minus1,
+            0,
+            self.num_experts_per_tok,
+        )?;
+        let mut scores = routing_weights.gather(&indices.contiguous()?, D::Minus1)?;
+
+        if self.norm_topk_prob {
+            scores = scores.broadcast_div(&scores.sum_keepdim(D::Minus1)?)?;
+        }
+
+        let ys = {
+            let gate = grouped_mat_mul(&xs, &self.gate_proj_vec, &indices)?;
+            let up = grouped_mat_mul(&xs, &self.up_proj_vec, &indices)?;
+            let xs = grouped_mat_mul(
+                &(up * gate.apply(&self.act)?)?,
+                &self.down_proj_vec,
+                &indices,
+            )?;
+            xs.squeeze(D::Minus2)?
+        };
+
+        ys.to_dtype(DType::F32)?
+            .broadcast_mul(&scores.unsqueeze(D::Minus1)?)?
+            .sum(D::Minus2)?
+            .reshape((b_size, seq_len, hidden_dim))?
+            .to_dtype(original_dtype)
+    }
+}
+
+/// - `xs`: (bs, seqlen, hidden) or (tokens, hidden)
+/// - `ids`: (bs, seqlen, topk) or (tokens, topk)
+fn grouped_mat_mul(
+    xs: &Tensor,
+    experts: &Vec<Arc<dyn QuantMethod>>,
+    ids: &Tensor,
+) -> Result<Tensor> {
+    // Handle both 2D and 3D inputs
+    let (original_shape, xs_2d, ids_2d) = match xs.dims().len() {
+        2 => {
+            // Already 2D: (tokens, hidden)
+            (None, xs.clone(), ids.clone())
+        }
+        3 => {
+            // 3D: (bs, seqlen, hidden) -> reshape to (bs * seqlen, hidden)
+            let (bs, seqlen, hidden) = xs.dims3()?;
+            let xs_2d = xs.reshape((bs * seqlen, hidden))?;
+            let ids_2d = ids.reshape((bs * seqlen, ()))?; // (bs * seqlen, topk)
+            (Some((bs, seqlen)), xs_2d, ids_2d)
+        }
+        _ => candle_core::bail!(
+            "grouped_mat_mul expects xs to be 2D or 3D, got shape {:?}",
+            xs.dims()
+        ),
+    };
+
+    let (_n_tok, hidden) = xs_2d.dims2()?;
+    let ids_vec = ids_2d.to_vec2::<u32>()?;
+
+    // Build mapping: expert_id -> list of token indices
+    let mut ids_to_sorted = vec![vec![]; experts.len()];
+    for expert_id in 0..experts.len() {
+        for (i, token) in ids_vec.iter().enumerate() {
+            for selected in token {
+                if *selected as usize == expert_id {
+                    ids_to_sorted[expert_id].push(i as u32);
+                }
+            }
+        }
+    }
+
+    // Initialize output tensor with zeros
+    let mut ys = xs_2d.zeros_like()?;
+
+    // Process each expert
+    for (expert_idx, expert_layer) in experts.iter().enumerate() {
+        let expert_tokens = &ids_to_sorted[expert_idx];
+        if expert_tokens.is_empty() {
+            continue;
+        }
+
+        // Gather tokens for this expert
+        let token_indices = Tensor::new(expert_tokens.as_slice(), xs_2d.device())?;
+        let expert_input = xs_2d
+            .index_select(&token_indices, 0)?
+            .reshape(((), hidden))?;
+
+        // Run expert forward pass
+        let expert_output = MatMul.qmethod_matmul(&expert_input, &**expert_layer)?;
+
+        // Add expert outputs back to their original positions
+        ys = ys.index_add(&token_indices, &expert_output, 0)?;
+    }
+
+    // Reshape back to original shape if needed
+    if let Some((bs, seqlen)) = original_shape {
+        ys = ys.reshape((bs, seqlen, hidden))?;
+    }
+
+    Ok(ys)
+}
+
 struct SlowMoeMlp {
     gate: candle_nn::Linear,
     experts: Vec<Mlp>,
@@ -554,7 +741,7 @@ impl SlowMoeMlp {
 
 enum MoeOrMlp {
     FastMoe(FastMoeMlp),
-    SlowMoe(SlowMoeMlp),
+    SlowMoe(GroupedMoeMlp),
     Mlp(Mlp),
 }
 
@@ -614,7 +801,7 @@ impl DecoderLayer {
                 comm,
             )?)
         } else {
-            MoeOrMlp::SlowMoe(SlowMoeMlp::new(
+            MoeOrMlp::SlowMoe(GroupedMoeMlp::new(
                 cfg,
                 vb,
                 mapper
@@ -947,10 +1134,14 @@ impl IsqModel for Model {
                     tensors.push((&mut layer.fused_down_proj, Some(i)));
                 }
                 MoeOrMlp::SlowMoe(layer) => {
-                    for expert in &mut layer.experts {
-                        tensors.push((&mut expert.gate_proj, Some(i)));
-                        tensors.push((&mut expert.up_proj, Some(i)));
-                        tensors.push((&mut expert.down_proj, Some(i)));
+                    for expert in &mut layer.gate_proj_vec {
+                        tensors.push((expert, Some(i)));
+                    }
+                    for expert in &mut layer.up_proj_vec {
+                        tensors.push((expert, Some(i)));
+                    }
+                    for expert in &mut layer.down_proj_vec {
+                        tensors.push((expert, Some(i)));
                     }
                 }
             }
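[Note on PATCH 1: the `grouped_mat_mul` introduced above is a gather/compute/scatter
dispatch: group token indices by the expert they were routed to, run each expert once
over its gathered rows, then scatter-add the results back to the original row positions.
A minimal standalone sketch of the same pattern, with a scalar multiply standing in for
the real quantized expert projections (hypothetical code, not from the repository):

    use candle_core::{Result, Tensor};

    fn dispatch(xs: &Tensor, routes: &[Vec<usize>], num_experts: usize) -> Result<Tensor> {
        // routes[token] lists the expert ids selected for that token.
        let mut per_expert: Vec<Vec<u32>> = vec![vec![]; num_experts];
        for (tok, experts) in routes.iter().enumerate() {
            for &e in experts {
                per_expert[e].push(tok as u32);
            }
        }
        let mut ys = xs.zeros_like()?;
        for (e, toks) in per_expert.iter().enumerate() {
            if toks.is_empty() {
                continue;
            }
            // Gather the rows routed to expert `e`, run the expert, scatter-add back.
            let idx = Tensor::new(toks.as_slice(), xs.device())?;
            let gathered = xs.index_select(&idx, 0)?;
            let out = (gathered * (e as f64 + 1.0))?; // stand-in for the expert forward
            ys = ys.index_add(&idx, &out, 0)?;
        }
        Ok(ys)
    }

Note the in-place accumulation into `ys`; the later patches in this series abandon that
in favor of keeping the topk expert outputs separate.]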
From 4e2561682d4beaf0574b392a04aa120405fbdff6 Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Fri, 24 Oct 2025 23:48:19 -0400
Subject: [PATCH 2/6] Fix forward autocast

---
 mistralrs-core/src/models/qwen3_moe.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mistralrs-core/src/models/qwen3_moe.rs b/mistralrs-core/src/models/qwen3_moe.rs
index 1854d474f0..1012edc71d 100644
--- a/mistralrs-core/src/models/qwen3_moe.rs
+++ b/mistralrs-core/src/models/qwen3_moe.rs
@@ -625,7 +625,7 @@ fn grouped_mat_mul(
             .reshape(((), hidden))?;
 
         // Run expert forward pass
-        let expert_output = MatMul.qmethod_matmul(&expert_input, &**expert_layer)?;
+        let expert_output = expert_layer.forward_autocast(&expert_input)?;
 
         // Add expert outputs back to their original positions
         ys = ys.index_add(&token_indices, &expert_output, 0)?;

From 2c18182ccb4ef9809bc7f85849a16acd66c132ff Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Sat, 25 Oct 2025 07:22:29 -0400
Subject: [PATCH 3/6] Fix shape?

---
 mistralrs-core/src/models/qwen3_moe.rs | 24 ++++++++++++++----------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/mistralrs-core/src/models/qwen3_moe.rs b/mistralrs-core/src/models/qwen3_moe.rs
index 1012edc71d..1137b50e80 100644
--- a/mistralrs-core/src/models/qwen3_moe.rs
+++ b/mistralrs-core/src/models/qwen3_moe.rs
@@ -593,7 +593,7 @@ fn grouped_mat_mul(
         ),
     };
 
-    let (_n_tok, hidden) = xs_2d.dims2()?;
+    let (n_tok, _in_dim) = xs_2d.dims2()?;
     let ids_vec = ids_2d.to_vec2::<u32>()?;
 
     // Build mapping: expert_id -> list of token indices
@@ -608,10 +608,8 @@ fn grouped_mat_mul(
         }
     }
 
-    // Initialize output tensor with zeros
-    let mut ys = xs_2d.zeros_like()?;
+    let mut ys: Option<Tensor> = None;
 
-    // Process each expert
     for (expert_idx, expert_layer) in experts.iter().enumerate() {
         let expert_tokens = &ids_to_sorted[expert_idx];
         if expert_tokens.is_empty() {
@@ -620,20 +618,26 @@ fn grouped_mat_mul(
 
         // Gather tokens for this expert
         let token_indices = Tensor::new(expert_tokens.as_slice(), xs_2d.device())?;
-        let expert_input = xs_2d
-            .index_select(&token_indices, 0)?
-            .reshape(((), hidden))?;
+        let expert_input = xs_2d.index_select(&token_indices, 0)?;
 
         // Run expert forward pass
         let expert_output = expert_layer.forward_autocast(&expert_input)?;
 
+        if ys.is_none() {
+            let out_dim = expert_output.dim(D::Minus1)?;
+            ys = Some(Tensor::zeros((n_tok, out_dim), expert_output.dtype(), xs_2d.device())?);
+        }
+
         // Add expert outputs back to their original positions
-        ys = ys.index_add(&token_indices, &expert_output, 0)?;
+        let ys_tensor = ys.as_ref().unwrap();
+        ys = Some(ys_tensor.index_add(&token_indices, &expert_output, 0)?);
     }
 
-    // Reshape back to original shape if needed
+    let mut ys = ys.ok_or_else(|| candle_core::Error::Msg("No experts were selected".to_string()))?;
+
     if let Some((bs, seqlen)) = original_shape {
-        ys = ys.reshape((bs, seqlen, hidden))?;
+        let out_dim = ys.dim(D::Minus1)?;
+        ys = ys.reshape((bs, seqlen, out_dim))?;
     }
 
     Ok(ys)

From acf15cfa331bdae6596e2405855055caeebc6577 Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Sat, 25 Oct 2025 07:37:26 -0400
Subject: [PATCH 4/6] Fix shape?

---
 mistralrs-core/src/models/qwen3_moe.rs | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mistralrs-core/src/models/qwen3_moe.rs b/mistralrs-core/src/models/qwen3_moe.rs
index 1137b50e80..9291d66083 100644
--- a/mistralrs-core/src/models/qwen3_moe.rs
+++ b/mistralrs-core/src/models/qwen3_moe.rs
@@ -551,13 +551,13 @@ impl GroupedMoeMlp {
         let ys = {
             let gate = grouped_mat_mul(&xs, &self.gate_proj_vec, &indices)?;
             let up = grouped_mat_mul(&xs, &self.up_proj_vec, &indices)?;
-            let xs = grouped_mat_mul(
+            grouped_mat_mul(
                 &(up * gate.apply(&self.act)?)?,
                 &self.down_proj_vec,
                 &indices,
-            )?;
-            xs.squeeze(D::Minus2)?
+            )?
         };
+        dbg!(&ys, &scores);
 
         ys.to_dtype(DType::F32)?
             .broadcast_mul(&scores.unsqueeze(D::Minus1)?)?
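[Note on PATCHES 3-4: the gate/up experts project hidden -> moe_intermediate_size, so
`grouped_mat_mul`'s output cannot be allocated as `xs_2d.zeros_like()`; PATCH 3 sizes it
lazily from the first expert output instead. The router half of `GroupedMoeMlp::forward`
stays fixed while these patches rework the expert half; shown standalone for reference
(the free function and its signature are hypothetical, the body mirrors the forward pass
from PATCH 1):

    use candle_core::{DType, Result, Tensor, D};

    /// Softmax over router logits, keep the top_k experts per token, and
    /// optionally renormalize the kept scores so they sum to 1.
    fn route(router_logits: &Tensor, top_k: usize, norm_topk_prob: bool) -> Result<(Tensor, Tensor)> {
        let weights = candle_nn::ops::softmax_last_dim(&router_logits.to_dtype(DType::F32)?)?;
        let indices = weights.arg_sort_last_dim(false)?.narrow(D::Minus1, 0, top_k)?;
        let mut scores = weights.gather(&indices.contiguous()?, D::Minus1)?;
        if norm_topk_prob {
            scores = scores.broadcast_div(&scores.sum_keepdim(D::Minus1)?)?;
        }
        Ok((scores, indices))
    }

The final mix is y[t] = sum_k scores[t, k] * expert_{indices[t, k]}(x[t]) — the
`broadcast_mul(...).sum(D::Minus2)` tail of `forward` — which is why PATCH 5 below makes
`grouped_mat_mul` return per-slot outputs of shape (bs, seqlen, topk, out_dim) rather
than accumulating them in place.]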
From 0f1eddf560a5db28c999c00e2a6e0f766651d57c Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Sat, 25 Oct 2025 07:52:03 -0400
Subject: [PATCH 5/6] Fix shape?

---
 mistralrs-core/src/models/qwen3_moe.rs | 103 ++++++++++++++++---------
 1 file changed, 65 insertions(+), 38 deletions(-)

diff --git a/mistralrs-core/src/models/qwen3_moe.rs b/mistralrs-core/src/models/qwen3_moe.rs
index 9291d66083..e55d8e2d58 100644
--- a/mistralrs-core/src/models/qwen3_moe.rs
+++ b/mistralrs-core/src/models/qwen3_moe.rs
@@ -444,6 +444,7 @@ impl FastMoeMlp {
                 .gather_forward_autocast(&(up * gate.apply(&self.act)?)?, &indices)?;
             xs.squeeze(D::Minus2)?
         };
+        dbg!(&ys, &scores);
 
         ys.to_dtype(DType::F32)?
             .broadcast_mul(&scores.unsqueeze(D::Minus1)?)?
@@ -567,25 +568,30 @@ impl GroupedMoeMlp {
     }
 }
 
+/// Grouped matrix multiplication for MoE layers.
+///
+/// For each token, applies the selected experts and returns their outputs separately
+/// (not accumulated), preserving the topk dimension for later weighting.
+///
 /// - `xs`: (bs, seqlen, hidden) or (tokens, hidden)
-/// - `ids`: (bs, seqlen, topk) or (tokens, topk)
+/// - `ids`: (bs, seqlen, topk) - expert indices for each token
+///
+/// Returns: (bs, seqlen, topk, output_dim) where each token has topk expert outputs
 fn grouped_mat_mul(
     xs: &Tensor,
     experts: &Vec<Arc<dyn QuantMethod>>,
     ids: &Tensor,
 ) -> Result<Tensor> {
-    // Handle both 2D and 3D inputs
-    let (original_shape, xs_2d, ids_2d) = match xs.dims().len() {
+    let (bs, seqlen, hidden, topk) = match xs.dims().len() {
         2 => {
-            // Already 2D: (tokens, hidden)
-            (None, xs.clone(), ids.clone())
+            let (tokens, hidden) = xs.dims2()?;
+            let topk = ids.dim(D::Minus1)?;
+            (1, tokens, hidden, topk)
         }
         3 => {
-            // 3D: (bs, seqlen, hidden) -> reshape to (bs * seqlen, hidden)
             let (bs, seqlen, hidden) = xs.dims3()?;
-            let xs_2d = xs.reshape((bs * seqlen, hidden))?;
-            let ids_2d = ids.reshape((bs * seqlen, ()))?; // (bs * seqlen, topk)
-            (Some((bs, seqlen)), xs_2d, ids_2d)
+            let topk = ids.dim(D::Minus1)?;
+            (bs, seqlen, hidden, topk)
         }
         _ => candle_core::bail!(
             "grouped_mat_mul expects xs to be 2D or 3D, got shape {:?}",
             xs.dims()
         ),
     };
 
-    let (n_tok, _in_dim) = xs_2d.dims2()?;
+    // Flatten to (bs * seqlen, hidden) and (bs * seqlen, topk)
+    let xs_2d = xs.reshape((bs * seqlen, hidden))?;
+    let ids_2d = ids.reshape((bs * seqlen, topk))?;
     let ids_vec = ids_2d.to_vec2::<u32>()?;
 
-    // Build mapping: expert_id -> list of token indices
-    let mut ids_to_sorted = vec![vec![]; experts.len()];
-    for expert_id in 0..experts.len() {
-        for (i, token) in ids_vec.iter().enumerate() {
-            for selected in token {
-                if *selected as usize == expert_id {
-                    ids_to_sorted[expert_id].push(i as u32);
-                }
-            }
+    // Build a map: (token_idx, topk_idx) -> expert_id
+    // We need to track which expert corresponds to which position in the topk dimension
+    let mut expert_token_map: Vec<Vec<(usize, usize)>> = vec![vec![]; experts.len()];
+    for (token_idx, expert_ids) in ids_vec.iter().enumerate() {
+        for (topk_idx, &expert_id) in expert_ids.iter().enumerate() {
+            expert_token_map[expert_id as usize].push((token_idx, topk_idx));
         }
     }
 
-    let mut ys: Option<Tensor> = None;
+    // Determine output dimension by processing first non-empty expert
+    let mut out_dim: Option<usize> = None;
+    let mut dtype: Option<DType> = None;
+    for expert_layer in experts.iter() {
+        // Take a single token as a test
+        let test_input = xs_2d.narrow(0, 0, 1)?;
+        let test_output = expert_layer.forward_autocast(&test_input)?;
+        out_dim = Some(test_output.dim(D::Minus1)?);
+        dtype = Some(test_output.dtype());
+        break;
+    }
+    let out_dim = out_dim.ok_or_else(|| candle_core::Error::Msg("No experts available".to_string()))?;
+    let dtype = dtype.unwrap();
+
+    // Build output by collecting all (token, topk) outputs
+    // Use a Vec to build the output, then convert to tensor
+    let mut output_data: Vec<f32> = vec![0.0; bs * seqlen * topk * out_dim];
 
+    // Process each expert and place outputs in the correct positions
     for (expert_idx, expert_layer) in experts.iter().enumerate() {
-        let expert_tokens = &ids_to_sorted[expert_idx];
-        if expert_tokens.is_empty() {
+        let token_topk_pairs = &expert_token_map[expert_idx];
+        if token_topk_pairs.is_empty() {
             continue;
         }
 
-        // Gather tokens for this expert
-        let token_indices = Tensor::new(expert_tokens.as_slice(), xs_2d.device())?;
-        let expert_input = xs_2d.index_select(&token_indices, 0)?;
+        // Gather inputs for this expert
+        let token_indices: Vec<u32> = token_topk_pairs.iter().map(|(t, _)| *t as u32).collect();
+        let token_indices_tensor = Tensor::new(token_indices.as_slice(), xs_2d.device())?;
+        let expert_input = xs_2d.index_select(&token_indices_tensor, 0)?;
 
         // Run expert forward pass
-        let expert_output = expert_layer.forward_autocast(&expert_input)?;
+        let expert_output = expert_layer.forward_autocast(&expert_input)?; // (n_tokens_for_expert, out_dim)
 
-        if ys.is_none() {
-            let out_dim = expert_output.dim(D::Minus1)?;
-            ys = Some(Tensor::zeros((n_tok, out_dim), expert_output.dtype(), xs_2d.device())?);
-        }
+        // Convert to f32 on CPU for easy manipulation
+        let expert_output_f32 = expert_output.to_dtype(DType::F32)?.to_device(&candle_core::Device::Cpu)?.to_vec2::<f32>()?;
 
-        // Add expert outputs back to their original positions
-        let ys_tensor = ys.as_ref().unwrap();
-        ys = Some(ys_tensor.index_add(&token_indices, &expert_output, 0)?);
+        // Place outputs in the correct (token_idx, topk_idx) positions
+        for (i, &(token_idx, topk_idx)) in token_topk_pairs.iter().enumerate() {
+            let base_idx = (token_idx * topk * out_dim) + (topk_idx * out_dim);
+            for j in 0..out_dim {
+                output_data[base_idx + j] = expert_output_f32[i][j];
+            }
+        }
     }
 
-    let mut ys = ys.ok_or_else(|| candle_core::Error::Msg("No experts were selected".to_string()))?;
-
-    if let Some((bs, seqlen)) = original_shape {
-        let out_dim = ys.dim(D::Minus1)?;
-        ys = ys.reshape((bs, seqlen, out_dim))?;
-    }
+    // Convert to tensor and move to correct device
+    let mut ys = Tensor::from_vec(output_data, (bs * seqlen, topk, out_dim), &candle_core::Device::Cpu)?
+        .to_dtype(dtype)?
+        .to_device(xs_2d.device())?;
+
+    // Reshape to (bs, seqlen, topk, out_dim)
+    let out_dim = ys.dim(D::Minus1)?;
+    ys = ys.reshape((bs, seqlen, topk, out_dim))?;
 
     Ok(ys)
 }
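[Note on PATCH 6: the remaining problem is the down_proj stage, whose input is the 4D
gate/up intermediate of shape (bs, seqlen, topk, moe_intermediate). The patch flattens it
to (bs * seqlen * topk, moe_intermediate) and relies on row-major layout: token t,
top-k slot s sits at flat row t * topk + s, which is also the `flat_idx` used to place
rows into `output_data`. A worked check of that arithmetic (values hypothetical):

    // topk = 4, out_dim = 8: token 2, slot 3 -> flat row 2 * 4 + 3 = 11,
    // and element j of that row lands at index 11 * 8 + j of `output_data`.
    let (topk, out_dim) = (4usize, 8usize);
    let flat_idx = 2 * topk + 3;
    assert_eq!(flat_idx, 11);
    assert_eq!(flat_idx * out_dim, 88); // `base_idx` for (token 2, slot 3)
]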
From c112e4454cc894186ae1ba2d90a2792576491dc0 Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Sat, 25 Oct 2025 07:58:27 -0400
Subject: [PATCH 6/6] Fix shape?

---
 mistralrs-core/src/models/qwen3_moe.rs | 99 ++++++++++++++++++--------
 1 file changed, 68 insertions(+), 31 deletions(-)

diff --git a/mistralrs-core/src/models/qwen3_moe.rs b/mistralrs-core/src/models/qwen3_moe.rs
index e55d8e2d58..491a2f3993 100644
--- a/mistralrs-core/src/models/qwen3_moe.rs
+++ b/mistralrs-core/src/models/qwen3_moe.rs
@@ -573,7 +573,10 @@ impl GroupedMoeMlp {
 /// For each token, applies the selected experts and returns their outputs separately
 /// (not accumulated), preserving the topk dimension for later weighting.
 ///
-/// - `xs`: (bs, seqlen, hidden) or (tokens, hidden)
+/// - `xs`: Input tensor - can be:
+///   - 2D: (tokens, hidden)
+///   - 3D: (bs, seqlen, hidden)
+///   - 4D: (bs, seqlen, topk, intermediate) - for down_proj after gate*up
 /// - `ids`: (bs, seqlen, topk) - expert indices for each token
 ///
 /// Returns: (bs, seqlen, topk, output_dim) where each token has topk expert outputs
 fn grouped_mat_mul(
     xs: &Tensor,
     experts: &Vec<Arc<dyn QuantMethod>>,
     ids: &Tensor,
 ) -> Result<Tensor> {
-    let (bs, seqlen, hidden, topk) = match xs.dims().len() {
+    let (bs, seqlen, hidden, topk, is_4d) = match xs.dims().len() {
         2 => {
             let (tokens, hidden) = xs.dims2()?;
             let topk = ids.dim(D::Minus1)?;
-            (1, tokens, hidden, topk)
+            (1, tokens, hidden, topk, false)
         }
         3 => {
             let (bs, seqlen, hidden) = xs.dims3()?;
             let topk = ids.dim(D::Minus1)?;
-            (bs, seqlen, hidden, topk)
+            (bs, seqlen, hidden, topk, false)
+        }
+        4 => {
+            // Input is (bs, seqlen, topk, hidden) - intermediate activations
+            // In this case, ids should still be (bs, seqlen, topk)
+            let dims4 = xs.dims4()?;
+            (dims4.0, dims4.1, dims4.3, dims4.2, true)
         }
         _ => candle_core::bail!(
-            "grouped_mat_mul expects xs to be 2D or 3D, got shape {:?}",
+            "grouped_mat_mul expects xs to be 2D, 3D, or 4D, got shape {:?}",
             xs.dims()
         ),
     };
 
-    // Flatten to (bs * seqlen, hidden) and (bs * seqlen, topk)
-    let xs_2d = xs.reshape((bs * seqlen, hidden))?;
-    let ids_2d = ids.reshape((bs * seqlen, topk))?;
-    let ids_vec = ids_2d.to_vec2::<u32>()?;
+    // Flatten to (bs * seqlen, hidden) or (bs * seqlen * topk, hidden) for 4D case
+    let (xs_2d, ids_vec) = if is_4d {
+        // For 4D input (bs, seqlen, topk, hidden), reshape to (bs * seqlen * topk, hidden)
+        let xs_flat = xs.reshape((bs * seqlen * topk, hidden))?;
+        // ids is (bs, seqlen, topk)
+        let ids_2d = ids.reshape((bs * seqlen, topk))?;
+        let ids_vec = ids_2d.to_vec2::<u32>()?;
+        (xs_flat, ids_vec)
+    } else {
+        // For 2D/3D input, flatten to (bs * seqlen, hidden)
+        let xs_2d = xs.reshape((bs * seqlen, hidden))?;
+        let ids_2d = ids.reshape((bs * seqlen, topk))?;
+        let ids_vec = ids_2d.to_vec2::<u32>()?;
+        (xs_2d, ids_vec)
+    };
 
-    // Build a map: (token_idx, topk_idx) -> expert_id
-    // We need to track which expert corresponds to which position in the topk dimension
-    let mut expert_token_map: Vec<Vec<(usize, usize)>> = vec![vec![]; experts.len()];
-    for (token_idx, expert_ids) in ids_vec.iter().enumerate() {
-        for (topk_idx, &expert_id) in expert_ids.iter().enumerate() {
-            expert_token_map[expert_id as usize].push((token_idx, topk_idx));
+    // Build a map based on whether input is 4D or not
+    let expert_indices_map: Vec<Vec<usize>> = if is_4d {
+        // For 4D input: map expert_id -> flat indices into (bs*seqlen*topk) dimension
+        let mut map: Vec<Vec<usize>> = vec![vec![]; experts.len()];
+        for (token_idx, expert_ids) in ids_vec.iter().enumerate() {
+            for (topk_idx, &expert_id) in expert_ids.iter().enumerate() {
+                let flat_idx = token_idx * topk + topk_idx;
+                map[expert_id as usize].push(flat_idx);
+            }
         }
-    }
+        map
+    } else {
+        // For 2D/3D input: build the (token_idx, topk_idx) mapping as before
+        // We'll convert this to flat indices later
+        let mut map: Vec<Vec<usize>> = vec![vec![]; experts.len()];
+        for (token_idx, expert_ids) in ids_vec.iter().enumerate() {
+            for (topk_idx, &expert_id) in expert_ids.iter().enumerate() {
+                // For non-4D, we'll use a different approach - store the flat output index
+                let flat_output_idx = token_idx * topk + topk_idx;
+                map[expert_id as usize].push(flat_output_idx);
+            }
+        }
+        map
+    };
 
+    // Determine output dimension by processing first non-empty expert
     let mut out_dim: Option<usize> = None;
     let mut dtype: Option<DType> = None;
     for expert_layer in experts.iter() {
-        // Take a single token as a test
         let test_input = xs_2d.narrow(0, 0, 1)?;
         let test_output = expert_layer.forward_autocast(&test_input)?;
         out_dim = Some(test_output.dim(D::Minus1)?);
@@ -627,31 +663,32 @@ fn grouped_mat_mul(
     let out_dim = out_dim.ok_or_else(|| candle_core::Error::Msg("No experts available".to_string()))?;
     let dtype = dtype.unwrap();
 
-    // Build output by collecting all (token, topk) outputs
-    // Use a Vec to build the output, then convert to tensor
+    // Build output
     let mut output_data: Vec<f32> = vec![0.0; bs * seqlen * topk * out_dim];
 
-    // Process each expert and place outputs in the correct positions
+    // Process each expert
     for (expert_idx, expert_layer) in experts.iter().enumerate() {
-        let token_topk_pairs = &expert_token_map[expert_idx];
-        if token_topk_pairs.is_empty() {
+        let flat_indices = &expert_indices_map[expert_idx];
+        if flat_indices.is_empty() {
             continue;
         }
 
         // Gather inputs for this expert
-        let token_indices: Vec<u32> = token_topk_pairs.iter().map(|(t, _)| *t as u32).collect();
-        let token_indices_tensor = Tensor::new(token_indices.as_slice(), xs_2d.device())?;
-        let expert_input = xs_2d.index_select(&token_indices_tensor, 0)?;
+        let indices_tensor = Tensor::new(
+            flat_indices.iter().map(|&x| x as u32).collect::<Vec<u32>>().as_slice(),
+            xs_2d.device(),
+        )?;
+        let expert_input = xs_2d.index_select(&indices_tensor, 0)?;
 
         // Run expert forward pass
-        let expert_output = expert_layer.forward_autocast(&expert_input)?; // (n_tokens_for_expert, out_dim)
+        let expert_output = expert_layer.forward_autocast(&expert_input)?;
 
-        // Convert to f32 on CPU for easy manipulation
+        // Convert to f32 on CPU
        let expert_output_f32 = expert_output.to_dtype(DType::F32)?.to_device(&candle_core::Device::Cpu)?.to_vec2::<f32>()?;
 
-        // Place outputs in the correct (token_idx, topk_idx) positions
-        for (i, &(token_idx, topk_idx)) in token_topk_pairs.iter().enumerate() {
-            let base_idx = (token_idx * topk * out_dim) + (topk_idx * out_dim);
+        // Place outputs in correct positions
+        for (i, &flat_idx) in flat_indices.iter().enumerate() {
+            let base_idx = flat_idx * out_dim;
             for j in 0..out_dim {
                 output_data[base_idx + j] = expert_output_f32[i][j];
             }