diff --git a/mistralrs-core/src/models/qwen3_moe.rs b/mistralrs-core/src/models/qwen3_moe.rs
index 80eb93c2e4..491a2f3993 100644
--- a/mistralrs-core/src/models/qwen3_moe.rs
+++ b/mistralrs-core/src/models/qwen3_moe.rs
@@ -1,9 +1,10 @@
 #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
 use candle_core::{DType, Device, Module, Result, Tensor, D};
+use candle_nn::Linear;
 use mistralrs_quant::{
-    ColumnParallelLayer, FusedExperts, QuantMethod, QuantizedConfig, ReplicatedLayer,
-    RowParallelLayer, ShardedVarBuilder,
+    apply_immediate_isq, ColumnParallelLayer, FusedExperts, QuantMethod, QuantMethodConfig,
+    QuantizedConfig, ReplicatedLayer, RowParallelLayer, ShardedVarBuilder, UnquantLinear,
 };
 use serde::{Deserialize, Serialize};
 use std::{collections::HashMap, sync::Arc};
 
@@ -452,6 +453,231 @@ impl FastMoeMlp {
     }
 }
 
+struct GroupedMoeMlp {
+    gate: Arc<dyn QuantMethod>,
+    gate_proj_vec: Vec<Arc<dyn QuantMethod>>,
+    up_proj_vec: Vec<Arc<dyn QuantMethod>>,
+    down_proj_vec: Vec<Arc<dyn QuantMethod>>,
+    act: Activation,
+    norm_topk_prob: bool,
+    num_experts_per_tok: usize,
+}
+
+impl GroupedMoeMlp {
+    fn new(
+        cfg: &Config,
+        vb: ShardedVarBuilder,
+        layer_device: Device,
+        _comm: &Arc<mistralrs_quant::Comm>,
+    ) -> Result<Self> {
+        let num_experts = cfg.num_experts;
+        let gate = mistralrs_quant::linear_no_bias(
+            cfg.hidden_size,
+            num_experts,
+            &cfg.quantization_config,
+            vb.pp("gate").set_device(layer_device),
+        )?;
+
+        let experts_vb = vb.pp("experts");
+        let mut gate_proj_vec: Vec<Arc<dyn QuantMethod>> = Vec::new();
+        let mut up_proj_vec: Vec<Arc<dyn QuantMethod>> = Vec::new();
+        let mut down_proj_vec: Vec<Arc<dyn QuantMethod>> = Vec::new();
+        for i in 0..num_experts {
+            let vb = experts_vb.pp(i);
+            let gate_proj = vb.get(
+                (cfg.moe_intermediate_size, cfg.hidden_size),
+                "gate_proj.weight",
+            )?;
+            let up_proj = vb.get(
+                (cfg.moe_intermediate_size, cfg.hidden_size),
+                "up_proj.weight",
+            )?;
+            let down_proj = vb.get(
+                (cfg.hidden_size, cfg.moe_intermediate_size),
+                "down_proj.weight",
+            )?;
+
+            gate_proj_vec.push(apply_immediate_isq(
+                Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
+                    Linear::new(gate_proj, None),
+                ))?),
+                vb.pp("gate_proj"),
+            )?);
+            up_proj_vec.push(apply_immediate_isq(
+                Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
+                    Linear::new(up_proj, None),
+                ))?),
+                vb.pp("up_proj"),
+            )?);
+            down_proj_vec.push(apply_immediate_isq(
+                Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
+                    Linear::new(down_proj, None),
+                ))?),
+                vb.pp("down_proj"),
+            )?);
+        }
+
+        Ok(Self {
+            gate,
+            gate_proj_vec,
+            up_proj_vec,
+            down_proj_vec,
+            act: cfg.hidden_act,
+            norm_topk_prob: cfg.norm_topk_prob,
+            num_experts_per_tok: cfg.num_experts_per_tok,
+        })
+    }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let original_dtype = xs.dtype();
+
+        let (b_size, seq_len, hidden_dim) = xs.dims3()?;
+
+        let router_logits = self.gate.forward_autocast(xs)?;
+        let routing_weights =
+            candle_nn::ops::softmax_last_dim(&router_logits.to_dtype(DType::F32)?)?;
+
+        let indices = routing_weights.arg_sort_last_dim(false)?.narrow(
+            D::Minus1,
+            0,
+            self.num_experts_per_tok,
+        )?;
+        let mut scores = routing_weights.gather(&indices.contiguous()?, D::Minus1)?;
+
+        if self.norm_topk_prob {
+            scores = scores.broadcast_div(&scores.sum_keepdim(D::Minus1)?)?;
+        }
+
+        let ys = {
+            let gate = grouped_mat_mul(xs, &self.gate_proj_vec, &indices)?;
+            let up = grouped_mat_mul(xs, &self.up_proj_vec, &indices)?;
+            grouped_mat_mul(
+                &(up * gate.apply(&self.act)?)?,
+                &self.down_proj_vec,
+                &indices,
+            )?
+        };
+
+        ys.to_dtype(DType::F32)?
+            .broadcast_mul(&scores.unsqueeze(D::Minus1)?)?
+            .sum(D::Minus2)?
+            .reshape((b_size, seq_len, hidden_dim))?
+            .to_dtype(original_dtype)
+    }
+}
+
+/// Grouped matrix multiplication for MoE layers.
+///
+/// For each token, applies the selected experts and returns their outputs separately
+/// (not accumulated), preserving the topk dimension for later weighting.
+///
+/// - `xs`: Input tensor, one of:
+///   - 2D: (tokens, hidden)
+///   - 3D: (bs, seqlen, hidden)
+///   - 4D: (bs, seqlen, topk, intermediate), for down_proj after gate*up
+/// - `ids`: (bs, seqlen, topk), expert indices for each token
+///
+/// Returns: (bs, seqlen, topk, output_dim), i.e. topk expert outputs per token.
+fn grouped_mat_mul(
+    xs: &Tensor,
+    experts: &[Arc<dyn QuantMethod>],
+    ids: &Tensor,
+) -> Result<Tensor> {
+    let (bs, seqlen, hidden, topk, is_4d) = match xs.dims().len() {
+        2 => {
+            let (tokens, hidden) = xs.dims2()?;
+            let topk = ids.dim(D::Minus1)?;
+            (1, tokens, hidden, topk, false)
+        }
+        3 => {
+            let (bs, seqlen, hidden) = xs.dims3()?;
+            let topk = ids.dim(D::Minus1)?;
+            (bs, seqlen, hidden, topk, false)
+        }
+        4 => {
+            // Input is (bs, seqlen, topk, hidden), i.e. intermediate activations;
+            // ids is still (bs, seqlen, topk).
+            let dims4 = xs.dims4()?;
+            (dims4.0, dims4.1, dims4.3, dims4.2, true)
+        }
+        _ => candle_core::bail!(
+            "grouped_mat_mul expects xs to be 2D, 3D, or 4D, got shape {:?}",
+            xs.dims()
+        ),
+    };
+
+    // Flatten the input: (bs * seqlen, hidden) for 2D/3D input, or
+    // (bs * seqlen * topk, hidden) for 4D input, where every (token, topk)
+    // pair already has its own row.
+    let xs_2d = if is_4d {
+        xs.reshape((bs * seqlen * topk, hidden))?
+    } else {
+        xs.reshape((bs * seqlen, hidden))?
+    };
+    let ids_vec = ids.reshape((bs * seqlen, topk))?.to_vec2::<u32>()?;
+
+    // For each expert, collect the input rows to gather and the output slots to
+    // fill. Output slots always index the flat (bs * seqlen * topk) dimension.
+    // Input rows index xs_2d: for 4D input every (token, topk) pair has its own
+    // row, while for 2D/3D input all topk choices of a token read the same row.
+    let mut input_rows: Vec<Vec<usize>> = vec![vec![]; experts.len()];
+    let mut output_slots: Vec<Vec<usize>> = vec![vec![]; experts.len()];
+    for (token_idx, expert_ids) in ids_vec.iter().enumerate() {
+        for (topk_idx, &expert_id) in expert_ids.iter().enumerate() {
+            let slot = token_idx * topk + topk_idx;
+            input_rows[expert_id as usize].push(if is_4d { slot } else { token_idx });
+            output_slots[expert_id as usize].push(slot);
+        }
+    }
+
+    // Probe the first expert on a single row to determine the output dim and dtype.
+    let first = experts
+        .first()
+        .ok_or_else(|| candle_core::Error::Msg("grouped_mat_mul: no experts".to_string()))?;
+    let probe = first.forward_autocast(&xs_2d.narrow(0, 0, 1)?)?;
+    let out_dim = probe.dim(D::Minus1)?;
+    let dtype = probe.dtype();
+
+    // Accumulate the output as f32 on the CPU, then upload once at the end.
+    let mut output_data: Vec<f32> = vec![0.0; bs * seqlen * topk * out_dim];
+
+    // Process each expert over the tokens routed to it.
+    for (expert_idx, expert_layer) in experts.iter().enumerate() {
+        let rows = &input_rows[expert_idx];
+        if rows.is_empty() {
+            continue;
+        }
+
+        // Gather this expert's input rows.
+        let rows_u32 = rows.iter().map(|&r| r as u32).collect::<Vec<u32>>();
+        let indices_tensor = Tensor::new(rows_u32.as_slice(), xs_2d.device())?;
+        let expert_input = xs_2d.index_select(&indices_tensor, 0)?;
+
+        // Run the expert forward pass and stage the result as f32 on the CPU.
+        let expert_output = expert_layer.forward_autocast(&expert_input)?;
+        let expert_output_f32 = expert_output
+            .to_dtype(DType::F32)?
+            .to_device(&candle_core::Device::Cpu)?
+            .to_vec2::<f32>()?;
+
+        // Scatter each output row into its (token, topk) slot.
+        for (i, &slot) in output_slots[expert_idx].iter().enumerate() {
+            let base = slot * out_dim;
+            output_data[base..base + out_dim].copy_from_slice(&expert_output_f32[i]);
+        }
+    }
+
+    // Rebuild the tensor, restore the dtype, and move back to the input device.
+    Tensor::from_vec(
+        output_data,
+        (bs * seqlen, topk, out_dim),
+        &candle_core::Device::Cpu,
+    )?
+    .to_dtype(dtype)?
+    .to_device(xs_2d.device())?
+    .reshape((bs, seqlen, topk, out_dim))
+}
+
 struct SlowMoeMlp {
     gate: candle_nn::Linear,
     experts: Vec<Mlp>,
@@ -554,7 +780,7 @@ impl SlowMoeMlp {
 
 enum MoeOrMlp {
     FastMoe(FastMoeMlp),
-    SlowMoe(SlowMoeMlp),
+    SlowMoe(GroupedMoeMlp),
     Mlp(Mlp),
 }
 
@@ -614,7 +840,7 @@ impl DecoderLayer {
                 comm,
             )?)
         } else {
-            MoeOrMlp::SlowMoe(SlowMoeMlp::new(
+            MoeOrMlp::SlowMoe(GroupedMoeMlp::new(
                 cfg,
                 vb,
                 mapper
@@ -947,10 +1173,14 @@ impl IsqModel for Model {
                     tensors.push((&mut layer.fused_down_proj, Some(i)));
                 }
                 MoeOrMlp::SlowMoe(layer) => {
-                    for expert in &mut layer.experts {
-                        tensors.push((&mut expert.gate_proj, Some(i)));
-                        tensors.push((&mut expert.up_proj, Some(i)));
-                        tensors.push((&mut expert.down_proj, Some(i)));
+                    for expert in &mut layer.gate_proj_vec {
+                        tensors.push((expert, Some(i)));
+                    }
+                    for expert in &mut layer.up_proj_vec {
+                        tensors.push((expert, Some(i)));
+                    }
+                    for expert in &mut layer.down_proj_vec {
+                        tensors.push((expert, Some(i)));
                     }
                 }
             }
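
For reference, here is the top-k routing performed by GroupedMoeMlp::forward in isolation: a minimal sketch on dummy data, using only the candle calls already present in the patch (softmax_last_dim, arg_sort_last_dim, narrow, gather, broadcast_div). The shapes and logits are made up; this is not part of the diff.

use candle_core::{DType, Device, Result, Tensor, D};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (bs, seqlen, num_experts, topk) = (1, 2, 4, 2);

    // Fake router logits: (bs, seqlen, num_experts).
    let router_logits = Tensor::rand(0f32, 1f32, (bs, seqlen, num_experts), &dev)?;
    let routing_weights =
        candle_nn::ops::softmax_last_dim(&router_logits.to_dtype(DType::F32)?)?;

    // Top-k expert indices per token, descending by routing weight.
    let indices = routing_weights
        .arg_sort_last_dim(false)?
        .narrow(D::Minus1, 0, topk)?;
    let scores = routing_weights.gather(&indices.contiguous()?, D::Minus1)?;

    // With norm_topk_prob set, the kept scores are renormalized to sum to 1.
    let scores = scores.broadcast_div(&scores.sum_keepdim(D::Minus1)?)?;

    println!("indices: {indices}");
    println!("scores: {scores}");
    Ok(())
}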
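
And the bucketing that grouped_mat_mul builds before dispatching to the experts, as a plain-Rust sketch on made-up ids: each (token, topk) pair gets an output slot in the flat tokens * topk dimension, while for 2D/3D input every topk choice of a token reads the same input row. Again illustrative only, not part of the diff.

fn main() {
    let num_experts = 4;
    let topk = 2;
    // Top-k expert ids per token, shape (tokens, topk).
    let ids: Vec<Vec<u32>> = vec![vec![0, 2], vec![2, 3], vec![0, 3]];

    let mut input_rows: Vec<Vec<usize>> = vec![vec![]; num_experts];
    let mut output_slots: Vec<Vec<usize>> = vec![vec![]; num_experts];
    for (token_idx, expert_ids) in ids.iter().enumerate() {
        for (topk_idx, &expert_id) in expert_ids.iter().enumerate() {
            let slot = token_idx * topk + topk_idx;
            input_rows[expert_id as usize].push(token_idx); // 2D/3D case
            output_slots[expert_id as usize].push(slot);
        }
    }

    // input_rows   = [[0, 2], [], [0, 1], [1, 2]]
    // output_slots = [[0, 4], [], [1, 2], [3, 5]]
    println!("{input_rows:?}");
    println!("{output_slots:?}");
}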