Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 78 additions & 31 deletions candle-core/src/quantized/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,26 +44,59 @@ fn pad(p: usize, q: usize) -> usize {
}

/// Quantize rows of `f32` data into the Q8_1 format on the GPU.
///
/// `src` holds `ky` rows of `k` contiguous `f32` elements each; the quantized
/// output is written into `dst` (raw Q8_1 blocks, one padded row at a time).
/// Each row is padded up to `MATRIX_ROW_PADDING` elements before quantization.
///
/// CUDA caps `gridDim.y` at 65535, so rows are processed in chunks of at most
/// 65535 per kernel launch; inputs with more rows require multiple launches.
///
/// # Errors
/// Returns an error if the kernel cannot be loaded or a launch fails.
fn quantize_q8_1(
    src: &CudaSlice<f32>,
    dst: &mut CudaSlice<u8>,
    k: usize,
    ky: usize,
    dev: &CudaDevice,
) -> Result<()> {
    let kx_padded = pad(k, MATRIX_ROW_PADDING);
    let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);

    let total_rows = ky;
    // Get Q8_1 metadata.
    let q8_1_block_size = GgmlDType::Q8_1.block_size();
    let q8_1_type_size = GgmlDType::Q8_1.type_size();

    // Size of one quantized (padded) row of output, in bytes.
    let num_blocks_per_row = kx_padded / q8_1_block_size;
    let dst_row_size_bytes = num_blocks_per_row * q8_1_type_size;

    const CHUNK_SIZE: usize = 65535; // gridDim.y limit
    let func = dev.get_or_load_func("quantize_q8_1", &candle_kernels::QUANTIZED)?;

    let mut rows_processed = 0;
    while rows_processed < total_rows {
        // Number of rows for this launch, capped at the gridDim.y limit.
        let remaining_rows = total_rows - rows_processed;
        let rows_in_chunk = std::cmp::min(CHUNK_SIZE, remaining_rows);

        // Slice the source (f32) tensor by elements.
        let src_start_elem = rows_processed * k;
        let src_num_elems = rows_in_chunk * k;
        let src_chunk = src.slice(src_start_elem..(src_start_elem + src_num_elems));

        // Slice the destination (u8) tensor by bytes.
        let dst_start_byte = rows_processed * dst_row_size_bytes;
        let dst_num_bytes = rows_in_chunk * dst_row_size_bytes;
        let dst_chunk = dst.slice(dst_start_byte..(dst_start_byte + dst_num_bytes));

        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (num_blocks as u32, rows_in_chunk as u32, 1),
            block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
            shared_mem_bytes: 0,
        };

        let mut builder = func.builder();
        builder.arg(&src_chunk);
        builder.arg(&dst_chunk);
        barg!(builder, k as i32, kx_padded as i32);
        unsafe { builder.launch(cfg) }.w()?;

        rows_processed += rows_in_chunk;
    }

    Ok(())
}

Expand Down Expand Up @@ -189,7 +222,7 @@ fn dequantize_f16(

fn dequantize_mul_mat_vec(
data: &PaddedCudaSlice,
y: &CudaView<f32>,
y: &CudaSlice<f32>,
dtype: GgmlDType,
ncols: usize,
nrows: usize,
Expand Down Expand Up @@ -235,7 +268,7 @@ fn dequantize_mul_mat_vec(

fn mul_mat_vec_via_q8_1(
data: &PaddedCudaSlice,
y: &CudaView<f32>,
y: &CudaSlice<f32>,
dtype: GgmlDType,
ncols: usize,
nrows: usize,
Expand Down Expand Up @@ -306,7 +339,7 @@ fn mul_mat_vec_via_q8_1(
#[allow(clippy::too_many_arguments)]
fn mul_mat_via_q8_1(
data: &PaddedCudaSlice,
y: &CudaView<f32>,
y: &CudaSlice<f32>,
dtype: GgmlDType,
x_rows: usize,
x_cols: usize,
Expand Down Expand Up @@ -377,7 +410,7 @@ fn indexed_moe_forward_fused_q8_1_input(
weight: &CudaView<u8>,
w_shape: &crate::Shape, //[num_experts, n, k]
w_dtype: GgmlDType,
input: &CudaView<f32>,
input: &CudaSlice<f32>,
in_shape: &crate::Shape, //[batch, topk or 1, k]
ids: &CudaView<u32>,
idx_shape: &crate::Shape, //[batch, topk]
Expand All @@ -391,11 +424,19 @@ fn indexed_moe_forward_fused_q8_1_input(
assert!(batch == idx_shape.dims()[0], "batch dim not match!");

//quant input into q8_1
let total_rows = batch * input_dim1;
let k_padded = pad(k, MATRIX_ROW_PADDING);
let y_size_in_bytes =
batch * input_dim1 * k_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
// Get Q8_1 metadata.
let q8_1_block_size = GgmlDType::Q8_1.block_size();
let q8_1_type_size = GgmlDType::Q8_1.type_size();

// Calculate the size of the output buffer in bytes.
let num_blocks_per_row = k_padded / q8_1_block_size;
let dst_row_size_bytes = num_blocks_per_row * q8_1_type_size;
let y_size_in_bytes = total_rows * dst_row_size_bytes;
let mut input_quant = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
quantize_q8_1(input, &mut input_quant, k, batch * input_dim1, dev)?;

quantize_q8_1(&input, &mut input_quant, k, total_rows, dev)?;

// output buffer
let outsize = batch * topk * n;
Expand Down Expand Up @@ -469,7 +510,7 @@ impl QCudaStorage {
&self.data.inner.slice(0..),
self_shape, //[num_experts, n, k]
self.dtype(),
&input_storage.slice(0..),
&input_storage,
input_l.shape(), //[batch, topk or 1, k]
&ids_storage.slice(0..),
ids_l.shape(), //[batch, topk]
Expand Down Expand Up @@ -711,8 +752,11 @@ impl QCudaStorage {
) -> Result<(CudaStorage, crate::Shape)> {
let (nrows, ncols) = self_shape.dims2()?;
let rhs = rhs.as_cuda_slice::<f32>()?;
let rhs = match rhs_l.contiguous_offsets() {
Some((o1, o2)) => rhs.slice(o1..o2),
match rhs_l.contiguous_offsets() {
Some((o1, _)) => assert!(
o1 == 0,
"sliced input is not supported in quantized matmul!"
),
None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
};
let (b_size, k) = match rhs_l.shape().dims() {
Expand Down Expand Up @@ -766,8 +810,11 @@ impl QCudaStorage {
storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?
} else {
let storage = storage.as_cuda_slice::<f32>()?;
let storage = match layout.contiguous_offsets() {
Some((o1, o2)) => storage.slice(o1..o2),
match layout.contiguous_offsets() {
Some((o1, _)) => assert!(
o1 == 0,
"sliced input is not supported in quantized matmul!"
),
None => Err(crate::Error::RequiresContiguous {
op: "quantized-matmul",
}
Expand Down Expand Up @@ -826,7 +873,7 @@ mod test {
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
let y = dev.memcpy_stod(&vs)?;
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
quantize_q8_1(&y, &mut y_q8_1, el, 1, &dev)?;
Ok(())
}

Expand All @@ -840,7 +887,7 @@ mod test {
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_vec_via_q8_1(
&xs.data,
&y.slice(..),
&y,
/* dtype */ GgmlDType::Q4_0,
/* ncols */ ncols,
/* nrows */ 1,
Expand All @@ -856,7 +903,7 @@ mod test {

let cuda_storage = dequantize_mul_mat_vec(
&xs.data,
&y.slice(..),
&y,
/* dtype */ GgmlDType::Q4_0,
/* ncols */ ncols,
/* nrows */ 1,
Expand All @@ -879,7 +926,7 @@ mod test {
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_via_q8_1(
&xs.data,
&y.slice(..),
&y,
/* dtype */ GgmlDType::Q4_0,
/* x_rows */ 4,
/* x_cols */ ncols,
Expand Down Expand Up @@ -920,7 +967,7 @@ mod test {
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_via_q8_1(
&xs.data,
&y.slice(..),
&y,
/* dtype */ GgmlDType::Q4_0,
/* x_rows */ x_rows,
/* x_cols */ ncols,
Expand Down
Loading