From c07d3a5235f1390a7a56bf0aa00d2ed92e32f36b Mon Sep 17 00:00:00 2001
From: Guoqing Bao
Date: Sat, 9 Aug 2025 03:57:49 +0000
Subject: [PATCH 1/2] Fix extra-long context kernel launch issue for indexed moe forward

---
 candle-core/src/quantized/cuda.rs | 79 +++++++++++++++++++++++++++++--
 1 file changed, 74 insertions(+), 5 deletions(-)

diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs
index 0ed6e63629..a19ce9803a 100644
--- a/candle-core/src/quantized/cuda.rs
+++ b/candle-core/src/quantized/cuda.rs
@@ -67,6 +67,30 @@
     Ok(())
 }
 
+fn quantize_q8_1_view(
+    src: &CudaView<f32>,
+    dst: &CudaView<u8>,
+    elem_count: usize,
+    ky: usize,
+    dev: &CudaDevice,
+) -> Result<()> {
+    let kx = elem_count;
+    let kx_padded = pad(kx, MATRIX_ROW_PADDING);
+    let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
+    let func = dev.get_or_load_func("quantize_q8_1", &candle_kernels::QUANTIZED)?;
+    let cfg = cudarc::driver::LaunchConfig {
+        grid_dim: (num_blocks as u32, ky as u32, 1),
+        block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
+        shared_mem_bytes: 0,
+    };
+    let mut builder = func.builder();
+    builder.arg(src);
+    builder.arg(dst);
+    barg!(builder, kx as i32, kx_padded as i32);
+    unsafe { builder.launch(cfg) }.w()?;
+    Ok(())
+}
+
 fn dequantize_f32(
     data: &PaddedCudaSlice,
     dtype: GgmlDType,
@@ -377,7 +401,7 @@ fn indexed_moe_forward_fused_q8_1_input(
     weight: &CudaView<u8>,
     w_shape: &crate::Shape, //[num_experts, n, k]
     w_dtype: GgmlDType,
-    input: &CudaView<f32>,
+    input: &CudaSlice<f32>,
     in_shape: &crate::Shape, //[batch, topk or 1, k]
     ids: &CudaView<u32>,
     idx_shape: &crate::Shape, //[batch, topk]
@@ -391,12 +415,57 @@
     assert!(batch == idx_shape.dims()[0], "batch dim not match!");
 
     //quant input into q8_1
+    let total_rows = batch * input_dim1;
     let k_padded = pad(k, MATRIX_ROW_PADDING);
-    let y_size_in_bytes =
-        batch * input_dim1 * k_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
+    // Get Q8_1 metadata.
+    let q8_1_block_size = GgmlDType::Q8_1.block_size();
+    let q8_1_type_size = GgmlDType::Q8_1.type_size();
+
+    // Calculate the size of the output buffer in bytes.
+    let num_blocks_per_row = k_padded / q8_1_block_size;
+    let dst_row_size_bytes = num_blocks_per_row * q8_1_type_size;
+    let y_size_in_bytes = total_rows * dst_row_size_bytes;
     let mut input_quant = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
-    quantize_q8_1(input, &mut input_quant, k, batch * input_dim1, dev)?;
+    const CHUNK_SIZE: usize = 65535; // gridDim.y limit
+
+    if total_rows > CHUNK_SIZE {
+        let mut rows_processed = 0;
+        while rows_processed < total_rows {
+            // --- calculate the number of rows for this chunk ---
+            let remaining_rows = total_rows - rows_processed;
+            let rows_in_chunk = std::cmp::min(CHUNK_SIZE, remaining_rows);
+
+            // --- slice the source (f32) tensor by elements ---
+            let src_start_elem = rows_processed * k;
+            let src_num_elems = rows_in_chunk * k;
+            let src_chunk = input.slice(src_start_elem..(src_start_elem + src_num_elems));
+
+            // --- slice the destination (u8) tensor by bytes ---
+            let dst_start_byte = rows_processed * dst_row_size_bytes;
+            let dst_num_bytes = rows_in_chunk * dst_row_size_bytes;
+            let dst_chunk = input_quant.slice(dst_start_byte..(dst_start_byte + dst_num_bytes));
+
+            // Launch the kernel for the current chunk.
+            quantize_q8_1_view(
+                &src_chunk,
+                &dst_chunk,
+                k,
+                rows_in_chunk, // This is our gridDim.y, now <= 65535
+                dev,
+            )?;
+
+            rows_processed += rows_in_chunk;
+        }
+    } else {
+        quantize_q8_1(
+            &input.slice(0..),
+            &mut input_quant,
+            k,
+            batch * input_dim1,
+            dev,
+        )?;
+    }
 
     // output buffer
     let outsize = batch * topk * n;
     let out = unsafe { dev.alloc::<f32>(outsize)? };
@@ -469,7 +538,7 @@ impl QCudaStorage {
             &self.data.inner.slice(0..),
             self_shape, //[num_experts, n, k]
             self.dtype(),
-            &input_storage.slice(0..),
+            &input_storage,
             input_l.shape(), //[batch, topk or 1, k]
             &ids_storage.slice(0..),
             ids_l.shape(), //[batch, topk]

From 2f8023af516ba190b0cb2040cd0461830a359375 Mon Sep 17 00:00:00 2001
From: Guoqing Bao
Date: Sat, 9 Aug 2025 15:27:42 +0000
Subject: [PATCH 2/2] Fix all entries for input quant with quantize_q8_1

---
 candle-core/src/quantized/cuda.rs | 152 +++++++++++++-----------------
 1 file changed, 65 insertions(+), 87 deletions(-)

diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs
index a19ce9803a..c9df617d1f 100644
--- a/candle-core/src/quantized/cuda.rs
+++ b/candle-core/src/quantized/cuda.rs
@@ -44,50 +44,59 @@ fn pad(p: usize, q: usize) -> usize {
 }
 
 fn quantize_q8_1(
-    src: &CudaView<f32>,
+    src: &CudaSlice<f32>,
     dst: &mut CudaSlice<u8>,
-    elem_count: usize,
+    k: usize,
     ky: usize,
     dev: &CudaDevice,
 ) -> Result<()> {
-    let kx = elem_count;
-    let kx_padded = pad(kx, MATRIX_ROW_PADDING);
+    let kx_padded = pad(k, MATRIX_ROW_PADDING);
     let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
-    let func = dev.get_or_load_func("quantize_q8_1", &candle_kernels::QUANTIZED)?;
-    let cfg = cudarc::driver::LaunchConfig {
-        grid_dim: (num_blocks as u32, ky as u32, 1),
-        block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
-        shared_mem_bytes: 0,
-    };
-    let mut builder = func.builder();
-    builder.arg(src);
-    builder.arg(dst);
-    barg!(builder, kx as i32, kx_padded as i32);
-    unsafe { builder.launch(cfg) }.w()?;
-    Ok(())
-}
 
-fn quantize_q8_1_view(
-    src: &CudaView<f32>,
-    dst: &CudaView<u8>,
-    elem_count: usize,
-    ky: usize,
-    dev: &CudaDevice,
-) -> Result<()> {
-    let kx = elem_count;
-    let kx_padded = pad(kx, MATRIX_ROW_PADDING);
-    let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
+    let total_rows = ky;
+    // Get Q8_1 metadata.
+    let q8_1_block_size = GgmlDType::Q8_1.block_size();
+    let q8_1_type_size = GgmlDType::Q8_1.type_size();
+
+    // Calculate the size of the output buffer in bytes.
+    let num_blocks_per_row = kx_padded / q8_1_block_size;
+    let dst_row_size_bytes = num_blocks_per_row * q8_1_type_size;
+
+    const CHUNK_SIZE: usize = 65535; // gridDim.y limit
     let func = dev.get_or_load_func("quantize_q8_1", &candle_kernels::QUANTIZED)?;
-    let cfg = cudarc::driver::LaunchConfig {
-        grid_dim: (num_blocks as u32, ky as u32, 1),
-        block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
-        shared_mem_bytes: 0,
-    };
-    let mut builder = func.builder();
-    builder.arg(src);
-    builder.arg(dst);
-    barg!(builder, kx as i32, kx_padded as i32);
-    unsafe { builder.launch(cfg) }.w()?;
+
+    let mut rows_processed = 0;
+    while rows_processed < total_rows {
+        // --- calculate the number of rows for this chunk ---
+        let remaining_rows = total_rows - rows_processed;
+        // This is our gridDim.y, now <= 65535
+        let rows_in_chunk = std::cmp::min(CHUNK_SIZE, remaining_rows);
+
+        // --- slice the source (f32) tensor by elements ---
+        let src_start_elem = rows_processed * k;
+        let src_num_elems = rows_in_chunk * k;
+        let src_chunk = src.slice(src_start_elem..(src_start_elem + src_num_elems));
+
+        // --- slice the destination (u8) tensor by bytes ---
+        let dst_start_byte = rows_processed * dst_row_size_bytes;
+        let dst_num_bytes = rows_in_chunk * dst_row_size_bytes;
+        let dst_chunk = dst.slice(dst_start_byte..(dst_start_byte + dst_num_bytes));
+
+        let cfg = cudarc::driver::LaunchConfig {
+            grid_dim: (num_blocks as u32, rows_in_chunk as u32, 1),
+            block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
+            shared_mem_bytes: 0,
+        };
+
+        let mut builder = func.builder();
+        builder.arg(&src_chunk);
+        builder.arg(&dst_chunk);
+        barg!(builder, k as i32, kx_padded as i32);
+        unsafe { builder.launch(cfg) }.w()?;
+
+        rows_processed += rows_in_chunk;
+    }
+
     Ok(())
 }
 
@@ -213,7 +222,7 @@ fn dequantize_f16(
 
 fn dequantize_mul_mat_vec(
     data: &PaddedCudaSlice,
-    y: &CudaView<f32>,
+    y: &CudaSlice<f32>,
     dtype: GgmlDType,
     ncols: usize,
     nrows: usize,
@@ -259,7 +268,7 @@
 
 fn mul_mat_vec_via_q8_1(
     data: &PaddedCudaSlice,
-    y: &CudaView<f32>,
+    y: &CudaSlice<f32>,
     dtype: GgmlDType,
     ncols: usize,
     nrows: usize,
@@ -330,7 +339,7 @@
 #[allow(clippy::too_many_arguments)]
 fn mul_mat_via_q8_1(
     data: &PaddedCudaSlice,
-    y: &CudaView<f32>,
+    y: &CudaSlice<f32>,
     dtype: GgmlDType,
     x_rows: usize,
     x_cols: usize,
@@ -427,45 +436,8 @@ fn indexed_moe_forward_fused_q8_1_input(
     let y_size_in_bytes = total_rows * dst_row_size_bytes;
     let mut input_quant = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
-    const CHUNK_SIZE: usize = 65535; // gridDim.y limit
+    quantize_q8_1(&input, &mut input_quant, k, total_rows, dev)?;
 
-    if total_rows > CHUNK_SIZE {
-        let mut rows_processed = 0;
-        while rows_processed < total_rows {
-            // --- calculate the number of rows for this chunk ---
-            let remaining_rows = total_rows - rows_processed;
-            let rows_in_chunk = std::cmp::min(CHUNK_SIZE, remaining_rows);
-
-            // --- slice the source (f32) tensor by elements ---
-            let src_start_elem = rows_processed * k;
-            let src_num_elems = rows_in_chunk * k;
-            let src_chunk = input.slice(src_start_elem..(src_start_elem + src_num_elems));
-
-            // --- slice the destination (u8) tensor by bytes ---
-            let dst_start_byte = rows_processed * dst_row_size_bytes;
-            let dst_num_bytes = rows_in_chunk * dst_row_size_bytes;
-            let dst_chunk = input_quant.slice(dst_start_byte..(dst_start_byte + dst_num_bytes));
-
-            // Launch the kernel for the current chunk.
-            quantize_q8_1_view(
-                &src_chunk,
-                &dst_chunk,
-                k,
-                rows_in_chunk, // This is our gridDim.y, now <= 65535
-                dev,
-            )?;
-
-            rows_processed += rows_in_chunk;
-        }
-    } else {
-        quantize_q8_1(
-            &input.slice(0..),
-            &mut input_quant,
-            k,
-            batch * input_dim1,
-            dev,
-        )?;
-    }
 
     // output buffer
     let outsize = batch * topk * n;
     let out = unsafe { dev.alloc::<f32>(outsize)? };
@@ -780,8 +752,11 @@
     ) -> Result<(CudaStorage, crate::Shape)> {
         let (nrows, ncols) = self_shape.dims2()?;
         let rhs = rhs.as_cuda_slice::<f32>()?;
-        let rhs = match rhs_l.contiguous_offsets() {
-            Some((o1, o2)) => rhs.slice(o1..o2),
+        match rhs_l.contiguous_offsets() {
+            Some((o1, _)) => assert!(
+                o1 == 0,
+                "sliced input is not supported in quantized matmul!"
+            ),
             None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
         };
         let (b_size, k) = match rhs_l.shape().dims() {
@@ -835,8 +810,11 @@
             storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?
         } else {
             let storage = storage.as_cuda_slice::<f32>()?;
-            let storage = match layout.contiguous_offsets() {
-                Some((o1, o2)) => storage.slice(o1..o2),
+            match layout.contiguous_offsets() {
+                Some((o1, _)) => assert!(
+                    o1 == 0,
+                    "sliced input is not supported in quantized matmul!"
+                ),
                 None => Err(crate::Error::RequiresContiguous {
                     op: "quantized-matmul",
                 }
@@ -895,7 +873,7 @@ mod test {
         let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
         let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
         let y = dev.memcpy_stod(&vs)?;
-        quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
+        quantize_q8_1(&y, &mut y_q8_1, el, 1, &dev)?;
         Ok(())
     }
 
@@ -909,7 +887,7 @@
         xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
         let cuda_storage = mul_mat_vec_via_q8_1(
             &xs.data,
-            &y.slice(..),
+            &y,
             /* dtype */ GgmlDType::Q4_0,
             /* ncols */ ncols,
             /* nrows */ 1,
@@ -925,7 +903,7 @@
 
         let cuda_storage = dequantize_mul_mat_vec(
             &xs.data,
-            &y.slice(..),
+            &y,
             /* dtype */ GgmlDType::Q4_0,
             /* ncols */ ncols,
             /* nrows */ 1,
@@ -948,7 +926,7 @@
         xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
         let cuda_storage = mul_mat_via_q8_1(
             &xs.data,
-            &y.slice(..),
+            &y,
             /* dtype */ GgmlDType::Q4_0,
             /* x_rows */ 4,
             /* x_cols */ ncols,
@@ -989,7 +967,7 @@
         xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
         let cuda_storage = mul_mat_via_q8_1(
             &xs.data,
-            &y.slice(..),
+            &y,
             /* dtype */ GgmlDType::Q4_0,
             /* x_rows */ x_rows,
             /* x_cols */ ncols,