From bcbbf14f17384c3ef98f76135e182fd2ab73ddb6 Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Thu, 14 Nov 2024 20:51:51 -0500
Subject: [PATCH 01/26] Sketch iq quants?

---
 candle-core/src/quantized/iq_quants.rs | 73 ++++++++++++++++++++++++++
 candle-core/src/quantized/mod.rs       |  2 +
 2 files changed, 75 insertions(+)
 create mode 100644 candle-core/src/quantized/iq_quants.rs

diff --git a/candle-core/src/quantized/iq_quants.rs b/candle-core/src/quantized/iq_quants.rs
new file mode 100644
index 0000000000..ad4e93066e
--- /dev/null
+++ b/candle-core/src/quantized/iq_quants.rs
@@ -0,0 +1,73 @@
+use half::f16;
+
+use crate::Result;
+
+use super::{BlockQ8K, GgmlDType, GgmlType, QK_K};
+
+pub const QK4_NL: usize = 32;
+
+#[derive(Debug, Clone, PartialEq)]
+#[repr(C)]
+struct BlockIQ4xs {
+    pub(crate) d: f16,
+    pub(crate) scales_h: u16,
+    pub(crate) scales_l: [u8; QK_K / 64],
+    pub(crate) qs: [u8; QK_K / 2],
+}
+
+const _: () = assert!(
+    std::mem::size_of::<BlockIQ4xs>()
+        == std::mem::size_of::<f16>() + std::mem::size_of::<u16>() + QK_K / 64 + QK_K / 2,
+    "wrong iq4_xs block size/padding"
+);
+
+impl GgmlType for BlockIQ4xs {
+    const DTYPE: GgmlDType = GgmlDType::Q3K;
+    const BLCK_SIZE: usize = QK_K;
+    type VecDotType = BlockQ8K;
+
+    fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
+        todo!()
+    }
+
+    fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
+        // quantize_row_q8_0
+        let k = xs.len();
+        if k % Self::BLCK_SIZE != 0 {
+            crate::bail!("{k} is not divisible by {}", Self::BLCK_SIZE);
+        };
+        let nb = k / Self::BLCK_SIZE;
+        if ys.len() != nb {
+            crate::bail!(
+                "size mismatch {} {} {}",
+                xs.len(),
+                ys.len(),
+                Self::BLCK_SIZE
+            )
+        }
+        for (i, ys) in ys.iter_mut().enumerate() {
+            let mut amax = 0f32;
+            let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE];
+            for &x in xs.iter() {
+                amax = amax.max(x.abs())
+            }
+            let d = amax / ((1 << 7) - 1) as f32;
+            let id = if d != 0f32 { 1. / d } else { 0.
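+            // NB: this sketch reuses the symmetric q8_0-style scale: d = amax / 127
+            // and id = 1/d, so round(x * id) always lands in [-127, 127].
+            // Worked example (illustrative, not from the patch): amax = 2.54
+            // gives d = 0.02, and x = 1.27 rounds to code 64. The real IQ4_XS
+            // coding, 4-bit indices into a non-linear value table, only
+            // arrives in PATCH 02 below.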
+            };
+            ys.d = f16::from_f32(d);
+            for (y, &x) in ys.qs.iter_mut().zip(xs.iter()) {
+                *y = f32::round(x * id) as i8
+            }
+        }
+        Ok(())
+    }
+
+    #[allow(unreachable_code)]
+    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        todo!()
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        todo!()
+    }
+}
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index 7f8dbfcf2a..5514dbf841 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -8,6 +8,7 @@ mod dummy_cuda;
 mod dummy_metal;
 pub mod ggml_file;
 pub mod gguf_file;
+pub mod iq_quants;
 pub mod k_quants;
 #[cfg(feature = "metal")]
 pub mod metal;
@@ -157,6 +158,7 @@ pub enum GgmlDType {
     Q5K,
     Q6K,
     Q8K,
+    Q4KXS,
 }
 
 impl GgmlDType {

From 3168dc2195abab5f4987c622089773de105d2415 Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Sun, 12 Jan 2025 18:48:37 -0500
Subject: [PATCH 02/26] Initial impl

---
 candle-core/src/quantized/iq_quants.rs |  98 +++++++++--
 candle-core/src/quantized/metal.rs     |   5 +
 candle-core/src/quantized/mod.rs       |  20 ++-
 candle-core/src/quantized/utils.rs     | 235 +++++++++++++++++++++++++
 candle-core/tests/quantized_tests.rs   |  54 ++++++
 5 files changed, 392 insertions(+), 20 deletions(-)

diff --git a/candle-core/src/quantized/iq_quants.rs b/candle-core/src/quantized/iq_quants.rs
index ad4e93066e..54011df2fa 100644
--- a/candle-core/src/quantized/iq_quants.rs
+++ b/candle-core/src/quantized/iq_quants.rs
@@ -1,6 +1,9 @@
 use half::f16;
 
-use crate::Result;
+use crate::{
+    quantized::utils::{group_for_quantization, quantize_iq4_nl},
+    Result,
+};
 
 use super::{BlockQ8K, GgmlDType, GgmlType, QK_K};
 
@@ -8,7 +11,7 @@ pub const QK4_NL: usize = 32;
 
 #[derive(Debug, Clone, PartialEq)]
 #[repr(C)]
-struct BlockIQ4xs {
+pub struct BlockIQ4xs {
     pub(crate) d: f16,
     pub(crate) scales_h: u16,
     pub(crate) scales_l: [u8; QK_K / 64],
     pub(crate) qs: [u8; QK_K / 2],
 }
@@ -21,18 +24,74 @@ const _: () = assert!(
     "wrong iq4_xs block size/padding"
 );
 
+const KVALUES_IQ4NL: [i8; 16] = [
+    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
+];
 
 impl GgmlType for BlockIQ4xs {
-    const DTYPE: GgmlDType = GgmlDType::Q3K;
+    const DTYPE: GgmlDType = GgmlDType::IQ4_XS;
     const BLCK_SIZE: usize = QK_K;
     type VecDotType = BlockQ8K;
 
     fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
-        todo!()
+        let k = ys.len();
+        if k % QK_K != 0 {
+            crate::bail!("dequantis block iq4xs {k} is not divisible by {QK_K}");
+        }
+
+        let nb = k / QK_K;
+        for i in 0..nb {
+            let block = &xs[i];
+
+            let d = block.d.to_f32();
+            let qs = &block.qs;
+
+            let mut qs_offset = 0;
+
+            // A pointer (offset) into out_chunk:
+            let mut y_offset = 0;
+
+            // 2. For each sub-block of size 32:
+            //    QK_K/32 sub-blocks, each sub-block contributes 32 floats of output.
+            for ib in 0..(QK_K / 32) {
+                // 2a. Reconstruct `ls` from scales_l/scales_h:
+                //     This matches the C code:
+                //       ls = ((scales_l[ib/2] >> (4*(ib%2))) & 0xf)
+                //          | (((scales_h >> (2*ib)) & 3) << 4);
+                let ib_div_2 = ib / 2;
+                let ib_mod_2 = ib % 2;
+
+                let ls_low = (block.scales_l[ib_div_2] >> (4 * ib_mod_2)) & 0xF;
+                let ls_high = ((block.scales_h >> (2 * ib)) & 0x3) << 4;
+                let ls = (ls_low as u16 | ls_high) as i32; // range [0..63]
+
+                // 2b. Compute the scale for this sub-block
+                //     In the C code: float dl = d * (ls - 32).
+                let dl = d * ((ls - 32) as f32);
+
+                // 2c. Now fill 32 floats of output by reading 16 bytes from qs.
+                //     Each byte in qs has two 4-bit indices: low nibble, high nibble.
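+                //     Worked example (illustrative, not from the C source): if
+                //     qs[j] == 0xB3, the low nibble 0x3 selects KVALUES_IQ4NL[3] = -65
+                //     for y[j], and the high nibble 0xB selects KVALUES_IQ4NL[11] = 38
+                //     for y[j+16], each then multiplied by the sub-block scale dl.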
+ // So we do 16 times: + // y[j+0] = dl * kvalues_iq4nl[ qs[j] & 0xF ]; + // y[j+16] = dl * kvalues_iq4nl[ qs[j] >> 4 ]; + for j in 0..16 { + let byte_val = qs[qs_offset + j]; + let idx0 = (byte_val & 0xF) as usize; // low nibble + let idx1 = (byte_val >> 4) as usize; // high nibble + + ys[y_offset + j] = dl * KVALUES_IQ4NL[idx0] as f32; + ys[y_offset + j + 16] = dl * KVALUES_IQ4NL[idx1] as f32; + } + + // Advance by 16 bytes in qs, 32 floats in y + qs_offset += 16; + y_offset += 32; + } + } + Ok(()) } fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { - // quantize_row_q8_0 let k = xs.len(); if k % Self::BLCK_SIZE != 0 { crate::bail!("{k} is not divisible by {}", Self::BLCK_SIZE); @@ -46,18 +105,23 @@ impl GgmlType for BlockIQ4xs { Self::BLCK_SIZE ) } - for (i, ys) in ys.iter_mut().enumerate() { - let mut amax = 0f32; - let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE]; - for &x in xs.iter() { - amax = amax.max(x.abs()) - } - let d = amax / ((1 << 7) - 1) as f32; - let id = if d != 0f32 { 1. / d } else { 0. }; - ys.d = f16::from_f32(d); - for (y, &x) in ys.qs.iter_mut().zip(xs.iter()) { - *y = f32::round(x * id) as i8 - } + const SUPER_BLOCK_SIZE: usize = QK_K; + const BLOCK_SIZE: usize = 32; + const NTRY: i32 = 7; + + for (ys_block, xs_block) in group_for_quantization(xs, ys)? { + quantize_iq4_nl( + xs_block, + SUPER_BLOCK_SIZE, + BLOCK_SIZE, + &mut ys_block.d, + &mut ys_block.qs, + &mut [ys_block.scales_h], + &mut ys_block.scales_l, + &KVALUES_IQ4NL, + None, + NTRY, + ); } Ok(()) } diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs index d06b0674a7..2339256bb0 100644 --- a/candle-core/src/quantized/metal.rs +++ b/candle-core/src/quantized/metal.rs @@ -107,6 +107,10 @@ impl QMetalStorage { let vec: Vec = read_to_vec(&buffer, block_len); crate::quantized::BlockQ8K::to_float(&vec, &mut out)?; } + GgmlDType::IQ4_XS => { + let vec: Vec = read_to_vec(&buffer, block_len); + crate::quantized::BlockIQ4xs::to_float(&vec, &mut out)?; + } } let buffer = self.device.new_buffer_with_data(&out)?; @@ -387,6 +391,7 @@ impl From for candle_metal_kernels::GgmlDType { GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K, GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K, GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K, + GgmlDType::IQ4_XS => candle_metal_kernels::GgmlDType::Q8_0, GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16, GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32, GgmlDType::BF16 => candle_metal_kernels::GgmlDType::F16, diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 96a144078d..ab6f1ef1c1 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -1,4 +1,5 @@ use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor, D}; +use iq_quants::*; use k_quants::*; use std::borrow::Cow; @@ -8,8 +9,8 @@ mod dummy_cuda; mod dummy_metal; pub mod ggml_file; pub mod gguf_file; -pub mod iq_quants; pub mod imatrix_file; +pub mod iq_quants; pub mod k_quants; #[cfg(feature = "metal")] pub mod metal; @@ -201,7 +202,7 @@ pub enum GgmlDType { Q5K, Q6K, Q8K, - Q4KXS, + IQ4_XS, } impl GgmlDType { @@ -221,6 +222,7 @@ impl GgmlDType { 13 => Self::Q5K, 14 => Self::Q6K, 15 => Self::Q8K, + 23 => Self::IQ4_XS, // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389 30 => Self::BF16, _ => crate::bail!("unknown dtype for tensor {u}"), @@ -244,6 +246,7 @@ impl GgmlDType { Self::Q5K => 13, Self::Q6K => 14, Self::Q8K => 
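            // These numeric ids track ggml's ggml_type enum (see the ggml.h link
            // below), so using 23 keeps IQ4_XS tensors interoperable with
            // standard GGUF files.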
15, + Self::IQ4_XS => 23, // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389 Self::BF16 => 30, } @@ -266,6 +269,10 @@ impl GgmlDType { Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]), Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]), Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]), + Self::IQ4_XS => Box::new(vec![ + BlockIQ4xs::zeros(); + elem_count / BlockIQ4xs::BLCK_SIZE + ]), Self::BF16 => Box::new(vec![bf16::zeros(); elem_count]), } } @@ -288,6 +295,7 @@ impl GgmlDType { Self::Q5K => std::mem::size_of::(), Self::Q6K => std::mem::size_of::(), Self::Q8K => std::mem::size_of::(), + Self::IQ4_XS => std::mem::size_of::(), } } @@ -302,7 +310,13 @@ impl GgmlDType { Self::Q5_1 => k_quants::QK5_1, Self::Q8_0 => k_quants::QK8_0, Self::Q8_1 => k_quants::QK8_1, - Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K, + Self::Q2K + | Self::Q3K + | Self::Q4K + | Self::Q5K + | Self::Q6K + | Self::Q8K + | Self::IQ4_XS => k_quants::QK_K, } } } diff --git a/candle-core/src/quantized/utils.rs b/candle-core/src/quantized/utils.rs index 0a087cddbb..fdff4447d7 100644 --- a/candle-core/src/quantized/utils.rs +++ b/candle-core/src/quantized/utils.rs @@ -1,5 +1,9 @@ +use half::f16; + use crate::Result; +const GROUP_MAX_EPS: f32 = 1e-15; + pub(super) fn nearest_int(v: f32) -> i32 { v.round() as i32 } @@ -560,3 +564,234 @@ pub(super) fn make_qp_quants( sumlx / suml2 } + +fn best_index_int8(n: usize, val: &[i8], x: f32) -> usize { + if x <= val[0] as f32 { + return 0; + } + if x >= val[n - 1] as f32 { + return n - 1; + } + let mut ml = 0_usize; + let mut mu = n - 1; + while mu - ml > 1 { + let mav = (ml + mu) / 2; + if x < val[mav] as f32 { + mu = mav; + } else { + ml = mav; + } + } + let dist_low = (x - val[mu - 1] as f32).abs(); + let dist_high = (val[mu] as f32 - x).abs(); + if dist_low < dist_high { + mu - 1 + } else { + mu + } +} + +#[allow(non_snake_case)] +pub(super) fn quantize_iq4_nl( + xs_block: &[f32], + super_block_size: usize, + block_size: usize, + d: &mut f16, + qs: &mut [u8], + scales_h: &mut [u16], + scales_l: &mut [u8], + values: &[i8], + quant_weights: Option<&[f32]>, + ntry: i32, +) { + // 1. Compute sigma2 = sum(x^2) * 2.f/super_block_size + let mut sigma2 = 0.0f32; + for &val in xs_block.iter() { + sigma2 += val * val; + } + sigma2 *= 2.0 / (super_block_size as f32); + + // 2. Zero out q4 region for safety (super_block_size/2 bytes) + // (Your block_out struct might store q4 in an array of length super_block_size/2) + qs.iter_mut().for_each(|x| *x = 0); + + // 3. dh[0] = 0 (in half float). We store it in block_out.dm for example: + *d = f16::from_f32(0.); // you may convert to half if needed + + // We'll store sub-block scales in a temporary float array + let mut scales = vec![0.0f32; super_block_size / block_size]; + let mut weight = vec![0.0f32; block_size]; + // We'll store indexes in a temporary array L + let mut L = vec![0u8; super_block_size]; + + let mut max_scale = 0.0f32; + let mut amax_scale = 0.0f32; + let nb = super_block_size / block_size; + + // 4. 
For each sub-block ib + for ib in 0..nb { + let start = ib * block_size; + let end = start + block_size; + let xb = &xs_block[start..end]; + let Lb = &mut L[start..end]; + + // optional quant_weights + if let Some(qw) = quant_weights { + let qw_block = &qw[start..end]; + for j in 0..block_size { + // weight[j] = qw[j] * sqrtf(sigma2 + xb[j]^2) + weight[j] = qw_block[j] * (sigma2 + xb[j] * xb[j]).sqrt(); + } + } else { + // weight[j] = xb[j]^2 + for j in 0..block_size { + weight[j] = xb[j] * xb[j]; + } + } + + // 5. find amax and max + let mut amax = 0.0f32; + let mut max = 0.0f32; + for &v in xb.iter() { + let ax = v.abs(); + if ax > amax { + amax = ax; + max = v; + } + } + // 6. if (amax < GROUP_MAX_EPS) => scales[ib] = 0; continue + if amax < GROUP_MAX_EPS { + scales[ib] = 0.0; + continue; + } + + // 7. do the initial d = ±(max/values[0]) (depending on ntry>0) + let sign = if ntry > 0 { -1.0 } else { 1.0 }; + let mut d = sign * (max / (values[0] as f32)); + let id = 1.0 / d; + let mut sumqx = 0.0f32; + let mut sumq2 = 0.0f32; + + // 7a. compute sumqx, sumq2 with that scale + for j in 0..block_size { + let al = id * xb[j]; + let l = best_index_int8(16, &values, al); + let q = values[l] as f32; + let w = weight[j]; + sumqx += w * q * xb[j]; + sumq2 += w * q * q; + Lb[j] = l as u8; + } + // 7b. refine d => sumqx / sumq2 + if sumq2 != 0.0 { + d = sumqx / sumq2; + } + let mut best = d * sumqx; + + // 8. search in range -ntry..ntry + for itry in -ntry..=ntry { + let itryf = itry as f32; + // id = (itry + values[0]) / max (from the code) + let attempt_id = (itryf + values[0] as f32) / max; + let mut attempt_sumqx = 0.0f32; + let mut attempt_sumq2 = 0.0f32; + for j in 0..block_size { + let al = attempt_id * xb[j]; + let l = best_index_int8(16, &values, al); + let q = values[l] as f32; + let w = weight[j]; + attempt_sumqx += w * q * xb[j]; + attempt_sumq2 += w * q * q; + } + if attempt_sumq2 > 0.0 { + let candidate_d = attempt_sumqx / attempt_sumq2; + let candidate_val = candidate_d * attempt_sumqx; + if candidate_val * candidate_val > best * attempt_sumq2 { + d = candidate_d; + best = candidate_val; + } + } + } + + // store final scale for this sub-block + scales[ib] = d; + let abs_d = d.abs(); + if abs_d > amax_scale { + amax_scale = abs_d; + max_scale = d; + } + } + + // 9. 
Now handle the second pass if we have more than one sub-block + if nb > 1 { + // block_out.dm = half of d => from the code: d = -max_scale/32 + let d_f32 = -max_scale / 32.0; + *d = f16::from_f32(d_f32); // or convert to half as needed + + let id = if d_f32 != 0.0 { 1.0 / d_f32 } else { 0.0 }; + + // scales_h/l might be stored in block_out as arrays + // zero out scales_h first + scales_h.iter_mut().for_each(|x| *x = 0); + + for ib in 0..nb { + // nearest int in the range [-32..31] + let mut l = nearest_int(id * scales[ib]); + l = l.clamp(-32, 31); + + // compute dl = d*l + let dl = d_f32 * (l as f32); + let idl = if dl != 0.0 { 1.0 / dl } else { 0.0 }; + + let start = ib * block_size; + let end = start + block_size; + let xb = &xs_block[start..end]; + let Lb = &mut L[start..end]; + + // re-assign Lb with the refined scale + for j in 0..block_size { + let al = idl * xb[j]; + let idx = best_index_int8(16, &values, al); + Lb[j] = idx as u8; + } + + // pack l into scales_h/l + let l_packed = (l + 32) as u8; // shift from [-32..31] to [0..63] + let l_l = l_packed & 0xF; + let l_h = l_packed >> 4; + + // store into block_out.scales_l[ib/2], block_out.scales_h[ib/8] + if ib % 2 == 0 { + scales_l[ib / 2] = l_l; + } else { + scales_l[ib / 2] |= l_l << 4; + } + scales_h[ib / 8] |= (l_h as u16) << (2 * (ib as u16 % 8)) ; + } + } else { + // super_block_size/block_size <= 1 + // the code sets block_out.dm = scales[0] + *d = f16::from_f32(scales[0]); + if ntry > 0 { + let id = if scales[0] != 0.0 { + 1.0 / scales[0] + } else { + 0.0 + }; + for j in 0..super_block_size { + let idx = best_index_int8(16, &values, id * xs_block[j]); + L[j] = idx as u8; + } + } + } + + // 10. Finally, build q4 from L + // "for i in 0..super_block_size/32 { for j in 0..16 { q4[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4) } }" + for i in 0..(super_block_size / 32) { + for j in 0..16 { + let l0 = L[32 * i + j]; + let l1 = L[32 * i + 16 + j] << 4; + qs[16 * i + j] = l0 | l1; + } + } +} diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index ab3f15bcf8..830c9b3dd9 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -949,6 +949,54 @@ fn quantize_q8k(device: &Device) -> Result<()> { Ok(()) } +fn quantize_iq4_xs(device: &Device) -> Result<()> { + let dtype = GgmlDType::IQ4_XS; + let src = get_test_vector2(0.5, 256, device)?; + let quant = quantized::QTensor::quantize(&src, dtype)?; + let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); + + let src = src.to_vec1::()?; + let dst = dst.to_vec1::()?; + dbg!(&src[100..110], &dst[100..110]); + compare_with_error(dst.as_slice(), src.as_slice(), 0.017); + + // Test some specific values + assert_eq!( + [src[0], src[128], src[256], src[512], src[800], src[1023]], + [-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344] + ); + let dst = round_vector(&dst); + assert_eq!( + [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]], + [-0.5, -0.373, -0.25, 0.0, 0.288, 0.498] + ); + + let src_big = get_test_vector2(128.0, 1024, device)?; + let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; + let dst_big = quant_big.dequantize(device)?; + let dst_big_f16 = quant_big.dequantize_f16(device)?; + let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? 
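+        // dequantize produces f32 while dequantize_f16 dequantizes into f16
+        // storage; the two paths must agree bit-for-bit once both are in f16,
+        // so the summed absolute difference is asserted to be exactly zero.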
+ .to_vec0::()?; + assert_eq!(diff, 0.); + + let src_big = src_big.to_vec1::()?; + let dst_big = dst_big.to_vec1::()?; + compare_with_error(dst_big.as_slice(), src_big.as_slice(), 4.5); + + ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + Ok(()) +} + test_device!( quantize_q4_0, quantize_q4_0_cpu, @@ -1009,6 +1057,12 @@ test_device!( quantize_q8k_cuda, quantize_q8k_metal ); +test_device!( + quantize_iq4_xs, + quantize_iq4_xs_cpu, + quantize_iq4_xs_cuda, + quantize_iq4_xs_metal +); /// Very simple dot product implementation fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 { From 0c55e369e0b1c8017f273a3cf3e160f954dad6db Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Mon, 13 Jan 2025 09:46:59 -0500 Subject: [PATCH 03/26] Work --- candle-core/src/quantized/iq_quants.rs | 39 +++------- candle-core/src/quantized/utils.rs | 36 +++++----- candle-core/tests/quantized_tests.rs | 98 +++++++++++++++----------- 3 files changed, 81 insertions(+), 92 deletions(-) diff --git a/candle-core/src/quantized/iq_quants.rs b/candle-core/src/quantized/iq_quants.rs index 54011df2fa..6701d02b45 100644 --- a/candle-core/src/quantized/iq_quants.rs +++ b/candle-core/src/quantized/iq_quants.rs @@ -33,7 +33,7 @@ impl GgmlType for BlockIQ4xs { const BLCK_SIZE: usize = QK_K; type VecDotType = BlockQ8K; - fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + fn to_float(xs: &[Self], mut ys: &mut [f32]) -> Result<()> { let k = ys.len(); if k % QK_K != 0 { crate::bail!("dequantis block iq4xs {k} is not divisible by {QK_K}"); @@ -44,48 +44,29 @@ impl GgmlType for BlockIQ4xs { let block = &xs[i]; let d = block.d.to_f32(); - let qs = &block.qs; + let mut qs = &block.qs[..]; - let mut qs_offset = 0; - - // A pointer (offset) into out_chunk: - let mut y_offset = 0; - - // 2. For each sub-block of size 32: - // QK_K/32 sub-blocks, each sub-block contributes 32 floats of output. for ib in 0..(QK_K / 32) { - // 2a. Reconstruct `ls` from scales_l/scales_h: - // This matches the C code: - // ls = ((scales_l[ib/2] >> (4*(ib%2))) & 0xf) - // | (((scales_h >> (2*ib)) & 3) << 4); let ib_div_2 = ib / 2; let ib_mod_2 = ib % 2; let ls_low = (block.scales_l[ib_div_2] >> (4 * ib_mod_2)) & 0xF; let ls_high = ((block.scales_h >> (2 * ib)) & 0x3) << 4; - let ls = (ls_low as u16 | ls_high) as i32; // range [0..63] + let ls = (ls_low as u16 | ls_high) as i32; - // 2b. Compute the scale for this sub-block - // In the C code: float dl = d * (ls - 32). let dl = d * ((ls - 32) as f32); - // 2c. Now fill 32 floats of output by reading 16 bytes from qs. - // Each byte in qs has two 4-bit indices: low nibble, high nibble. 
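 // (Patch note: this rewrite drops the explicit qs_offset/y_offset bookkeeping;
 // the loop below re-slices with qs = &qs[16..] and ys = &mut ys[32..] after
 // each 16-byte group, which is why to_float now takes `mut ys`.)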
- // So we do 16 times: - // y[j+0] = dl * kvalues_iq4nl[ qs[j] & 0xF ]; - // y[j+16] = dl * kvalues_iq4nl[ qs[j] >> 4 ]; for j in 0..16 { - let byte_val = qs[qs_offset + j]; - let idx0 = (byte_val & 0xF) as usize; // low nibble - let idx1 = (byte_val >> 4) as usize; // high nibble + let byte_val = qs[j]; + let idx0 = (byte_val & 0xF) as usize; + let idx1 = (byte_val >> 4) as usize; - ys[y_offset + j] = dl * KVALUES_IQ4NL[idx0] as f32; - ys[y_offset + j + 16] = dl * KVALUES_IQ4NL[idx1] as f32; + ys[j] = dl * KVALUES_IQ4NL[idx0] as f32; + ys[j + 16] = dl * KVALUES_IQ4NL[idx1] as f32; } - // Advance by 16 bytes in qs, 32 floats in y - qs_offset += 16; - y_offset += 32; + qs = &qs[16..]; + ys = &mut ys[32..]; } } Ok(()) diff --git a/candle-core/src/quantized/utils.rs b/candle-core/src/quantized/utils.rs index fdff4447d7..12a7a2ad4e 100644 --- a/candle-core/src/quantized/utils.rs +++ b/candle-core/src/quantized/utils.rs @@ -582,8 +582,8 @@ fn best_index_int8(n: usize, val: &[i8], x: f32) -> usize { ml = mav; } } - let dist_low = (x - val[mu - 1] as f32).abs(); - let dist_high = (val[mu] as f32 - x).abs(); + let dist_low = x - val[mu - 1] as f32; + let dist_high = val[mu] as f32 - x; if dist_low < dist_high { mu - 1 } else { @@ -668,7 +668,7 @@ pub(super) fn quantize_iq4_nl( // 7. do the initial d = ±(max/values[0]) (depending on ntry>0) let sign = if ntry > 0 { -1.0 } else { 1.0 }; let mut d = sign * (max / (values[0] as f32)); - let id = 1.0 / d; + let mut id = 1.0 / d; let mut sumqx = 0.0f32; let mut sumq2 = 0.0f32; @@ -692,24 +692,20 @@ pub(super) fn quantize_iq4_nl( for itry in -ntry..=ntry { let itryf = itry as f32; // id = (itry + values[0]) / max (from the code) - let attempt_id = (itryf + values[0] as f32) / max; - let mut attempt_sumqx = 0.0f32; - let mut attempt_sumq2 = 0.0f32; + id = (itryf + values[0] as f32) / max; + sumqx = 0.0f32; + sumq2 = 0.0f32; for j in 0..block_size { - let al = attempt_id * xb[j]; + let al = id * xb[j]; let l = best_index_int8(16, &values, al); let q = values[l] as f32; let w = weight[j]; - attempt_sumqx += w * q * xb[j]; - attempt_sumq2 += w * q * q; + sumqx += w * q * xb[j]; + sumq2 += w * q * q; } - if attempt_sumq2 > 0.0 { - let candidate_d = attempt_sumqx / attempt_sumq2; - let candidate_val = candidate_d * attempt_sumqx; - if candidate_val * candidate_val > best * attempt_sumq2 { - d = candidate_d; - best = candidate_val; - } + if sumq2 > 0. 
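+            // Why this comparison (our gloss): for fixed codes q_j, the weighted
+            // least-squares scale minimizing sum_j w_j * (x_j - d * q_j)^2 is
+            // d = sumqx / sumq2, and the resulting objective gain is
+            // sumqx^2 / sumq2. A trial therefore beats `best` exactly when
+            // sumqx * sumqx > best * sumq2; cross-multiplying avoids a division
+            // per candidate.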
&& sumqx * sumqx > best * sumq2 { + d = sumqx / sumq2; + best = d * sumqx; } } @@ -756,9 +752,9 @@ pub(super) fn quantize_iq4_nl( } // pack l into scales_h/l - let l_packed = (l + 32) as u8; // shift from [-32..31] to [0..63] - let l_l = l_packed & 0xF; - let l_h = l_packed >> 4; + l += 32; // shift from [-32..31] to [0..63] + let l_l = (l & 0xF) as u8; + let l_h = (l >> 4) as u8; // store into block_out.scales_l[ib/2], block_out.scales_h[ib/8] if ib % 2 == 0 { @@ -766,7 +762,7 @@ pub(super) fn quantize_iq4_nl( } else { scales_l[ib / 2] |= l_l << 4; } - scales_h[ib / 8] |= (l_h as u16) << (2 * (ib as u16 % 8)) ; + scales_h[ib / 8] |= (l_h as u16) << (2 * (ib as u16 % 8)); } } else { // super_block_size/block_size <= 1 diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index 830c9b3dd9..43c034d4d3 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -950,50 +950,62 @@ fn quantize_q8k(device: &Device) -> Result<()> { } fn quantize_iq4_xs(device: &Device) -> Result<()> { - let dtype = GgmlDType::IQ4_XS; - let src = get_test_vector2(0.5, 256, device)?; - let quant = quantized::QTensor::quantize(&src, dtype)?; - let dst = quant.dequantize(device)?; - let dst_f16 = quant.dequantize_f16(device)?; - let diff = (dst.to_dtype(DType::F16)? - dst_f16)? - .to_dtype(DType::F32)? - .abs()? - .sum_all()? - .to_vec0::()?; - assert_eq!(diff, 0.); - - let src = src.to_vec1::()?; - let dst = dst.to_vec1::()?; - dbg!(&src[100..110], &dst[100..110]); - compare_with_error(dst.as_slice(), src.as_slice(), 0.017); - - // Test some specific values - assert_eq!( - [src[0], src[128], src[256], src[512], src[800], src[1023]], - [-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344] - ); - let dst = round_vector(&dst); - assert_eq!( - [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]], - [-0.5, -0.373, -0.25, 0.0, 0.288, 0.498] - ); + // let dtype = GgmlDType::IQ4_XS; + // let src = get_test_vector2(0.5, 256, device)?; + // let quant = quantized::QTensor::quantize(&src, dtype)?; + // let dst = quant.dequantize(device)?; + // let dst_f16 = quant.dequantize_f16(device)?; + // let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + // .to_dtype(DType::F32)? + // .abs()? + // .sum_all()? + // .to_vec0::()?; + // assert_eq!(diff, 0.); + + // let src = src.to_vec1::()?; + // let dst = dst.to_vec1::()?; + // dbg!(&src[10 * 10..(10 + 1) * 10], &dst[10 * 10..(10 + 1) * 10]); + // compare_with_error(dst.as_slice(), src.as_slice(), 0.017); + + // // Test some specific values + // assert_eq!( + // [src[0], src[128], src[256], src[512], src[800], src[1023]], + // [-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344] + // ); + // let dst = round_vector(&dst); + // assert_eq!( + // [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]], + // [-0.5, -0.373, -0.25, 0.0, 0.288, 0.498] + // ); + + // let src_big = get_test_vector2(128.0, 1024, device)?; + // let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; + // let dst_big = quant_big.dequantize(device)?; + // let dst_big_f16 = quant_big.dequantize_f16(device)?; + // let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)? + // .to_dtype(DType::F32)? + // .abs()? + // .sum_all()? 
+ // .to_vec0::()?; + // assert_eq!(diff, 0.); + + // let src_big = src_big.to_vec1::()?; + // let dst_big = dst_big.to_vec1::()?; + // compare_with_error(dst_big.as_slice(), src_big.as_slice(), 4.5); + + // ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; - let src_big = get_test_vector2(128.0, 1024, device)?; - let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; - let dst_big = quant_big.dequantize(device)?; - let dst_big_f16 = quant_big.dequantize_f16(device)?; - let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)? - .to_dtype(DType::F32)? - .abs()? - .sum_all()? - .to_vec0::()?; - assert_eq!(diff, 0.); - - let src_big = src_big.to_vec1::()?; - let dst_big = dst_big.to_vec1::()?; - compare_with_error(dst_big.as_slice(), src_big.as_slice(), 4.5); - - ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + let dtype = GgmlDType::IQ4_XS; + let tgt = Tensor::from_vec( + (1..=256).map(|x| 1. / x as f32).collect(), + (256,), + &Device::Cpu, + )?; + let q = quantized::QTensor::quantize(&tgt, dtype)?; + let res = q.dequantize(&Device::Cpu)?; + + println!("tgt {}", tgt.narrow(0, 0, 10)?); + println!("res {}", res.narrow(0, 0, 10)?); Ok(()) } From 533e2d4cf30f23814b9b3342d03302a8379b01c2 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Thu, 16 Jan 2025 21:56:52 -0500 Subject: [PATCH 04/26] Debugging --- candle-core/src/quantized/iq_quants.rs | 69 ++++++++++++++++---------- candle-core/src/quantized/utils.rs | 8 ++- candle-core/tests/quantized_tests.rs | 15 ++++-- 3 files changed, 59 insertions(+), 33 deletions(-) diff --git a/candle-core/src/quantized/iq_quants.rs b/candle-core/src/quantized/iq_quants.rs index 6701d02b45..8b69564f81 100644 --- a/candle-core/src/quantized/iq_quants.rs +++ b/candle-core/src/quantized/iq_quants.rs @@ -1,7 +1,7 @@ use half::f16; use crate::{ - quantized::utils::{group_for_quantization, quantize_iq4_nl}, + quantized::utils::{group_for_quantization, quantize_row_iq4_nl}, Result, }; @@ -36,7 +36,7 @@ impl GgmlType for BlockIQ4xs { fn to_float(xs: &[Self], mut ys: &mut [f32]) -> Result<()> { let k = ys.len(); if k % QK_K != 0 { - crate::bail!("dequantis block iq4xs {k} is not divisible by {QK_K}"); + crate::bail!("dequantize block iq4xs {k} is not divisible by {QK_K}"); } let nb = k / QK_K; @@ -50,19 +50,15 @@ impl GgmlType for BlockIQ4xs { let ib_div_2 = ib / 2; let ib_mod_2 = ib % 2; - let ls_low = (block.scales_l[ib_div_2] >> (4 * ib_mod_2)) & 0xF; - let ls_high = ((block.scales_h >> (2 * ib)) & 0x3) << 4; - let ls = (ls_low as u16 | ls_high) as i32; + let ls_low = (block.scales_l[ib_div_2] as i32 >> (4 * ib_mod_2 as i32)) & 0xF; + let ls_high = ((block.scales_h as i32 >> (2 * ib as i32)) & 3) << 4; + let ls = ls_low | ls_high; - let dl = d * ((ls - 32) as f32); + let dl = d * (ls as f32 - 32.); for j in 0..16 { - let byte_val = qs[j]; - let idx0 = (byte_val & 0xF) as usize; - let idx1 = (byte_val >> 4) as usize; - - ys[j] = dl * KVALUES_IQ4NL[idx0] as f32; - ys[j + 16] = dl * KVALUES_IQ4NL[idx1] as f32; + ys[j] = dl * KVALUES_IQ4NL[(qs[j] & 0xF) as usize] as f32; + ys[j + 16] = dl * KVALUES_IQ4NL[(qs[j] >> 4) as usize] as f32; } qs = &qs[16..]; @@ -90,20 +86,43 @@ impl GgmlType for BlockIQ4xs { const BLOCK_SIZE: usize = 32; const NTRY: i32 = 7; - for (ys_block, xs_block) in group_for_quantization(xs, ys)? 
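+        // (Debugging shortcut: nrow and n_per_row are hard-coded to 1 and 256
+        // above, so only a single 256-wide row is handled here; PATCH 06 below
+        // generalizes this into quantize_iq4_xs with real nrow/n_per_row
+        // arguments.)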
{ - quantize_iq4_nl( - xs_block, - SUPER_BLOCK_SIZE, - BLOCK_SIZE, - &mut ys_block.d, - &mut ys_block.qs, - &mut [ys_block.scales_h], - &mut ys_block.scales_l, - &KVALUES_IQ4NL, - None, - NTRY, - ); + let nrow = 1; + let n_per_row = 256; + let nblock = n_per_row / QK_K; + + for row in 0..nrow { + let ys_block = &mut ys[nblock * row]; + for ibl in 0..nblock { + let xs_block = &xs[n_per_row * row + QK_K * ibl..]; + quantize_row_iq4_nl( + xs_block, + SUPER_BLOCK_SIZE, + BLOCK_SIZE, + &mut ys_block.d, + &mut ys_block.qs, + &mut [ys_block.scales_h], + &mut ys_block.scales_l, + &KVALUES_IQ4NL, + None, + NTRY, + ); + } } + + // for (ys_block, xs_block) in group_for_quantization(xs, ys)? { + // quantize_row_iq4_nl( + // xs_block, + // SUPER_BLOCK_SIZE, + // BLOCK_SIZE, + // &mut ys_block.d, + // &mut ys_block.qs, + // &mut [ys_block.scales_h], + // &mut ys_block.scales_l, + // &KVALUES_IQ4NL, + // None, + // NTRY, + // ); + // } Ok(()) } diff --git a/candle-core/src/quantized/utils.rs b/candle-core/src/quantized/utils.rs index 12a7a2ad4e..0f5550a5e9 100644 --- a/candle-core/src/quantized/utils.rs +++ b/candle-core/src/quantized/utils.rs @@ -592,7 +592,7 @@ fn best_index_int8(n: usize, val: &[i8], x: f32) -> usize { } #[allow(non_snake_case)] -pub(super) fn quantize_iq4_nl( +pub(super) fn quantize_row_iq4_nl( xs_block: &[f32], super_block_size: usize, block_size: usize, @@ -676,16 +676,14 @@ pub(super) fn quantize_iq4_nl( for j in 0..block_size { let al = id * xb[j]; let l = best_index_int8(16, &values, al); + Lb[j] = l as u8; let q = values[l] as f32; let w = weight[j]; sumqx += w * q * xb[j]; sumq2 += w * q * q; - Lb[j] = l as u8; } // 7b. refine d => sumqx / sumq2 - if sumq2 != 0.0 { - d = sumqx / sumq2; - } + d = sumqx / sumq2; let mut best = d * sumqx; // 8. search in range -ntry..ntry diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index 43c034d4d3..8a82c78c72 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -996,9 +996,15 @@ fn quantize_iq4_xs(device: &Device) -> Result<()> { // ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; let dtype = GgmlDType::IQ4_XS; - let tgt = Tensor::from_vec( - (1..=256).map(|x| 1. / x as f32).collect(), - (256,), + // let tgt = Tensor::from_vec( + // (1..=256).map(|x| 1. / x as f32).collect(), + // (256,), + // &Device::Cpu, + // )?; + let tgt = Tensor::randn( + 0f32, + 1f32, + 256, &Device::Cpu, )?; let q = quantized::QTensor::quantize(&tgt, dtype)?; @@ -1006,6 +1012,9 @@ fn quantize_iq4_xs(device: &Device) -> Result<()> { println!("tgt {}", tgt.narrow(0, 0, 10)?); println!("res {}", res.narrow(0, 0, 10)?); + + let diff = (tgt - res)?.abs()?.sum_all()?.to_scalar::()?; + dbg!(&diff); Ok(()) } From ac9f7c720a5a6edf9ef063122b634b3c518654a1 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Fri, 17 Jan 2025 21:41:20 -0500 Subject: [PATCH 05/26] Remove comments --- candle-core/src/quantized/utils.rs | 37 +++------------------------- candle-core/tests/quantized_tests.rs | 7 +----- 2 files changed, 4 insertions(+), 40 deletions(-) diff --git a/candle-core/src/quantized/utils.rs b/candle-core/src/quantized/utils.rs index 0f5550a5e9..20c3b1259e 100644 --- a/candle-core/src/quantized/utils.rs +++ b/candle-core/src/quantized/utils.rs @@ -604,52 +604,41 @@ pub(super) fn quantize_row_iq4_nl( quant_weights: Option<&[f32]>, ntry: i32, ) { - // 1. 
Compute sigma2 = sum(x^2) * 2.f/super_block_size let mut sigma2 = 0.0f32; for &val in xs_block.iter() { sigma2 += val * val; } sigma2 *= 2.0 / (super_block_size as f32); - // 2. Zero out q4 region for safety (super_block_size/2 bytes) - // (Your block_out struct might store q4 in an array of length super_block_size/2) qs.iter_mut().for_each(|x| *x = 0); - // 3. dh[0] = 0 (in half float). We store it in block_out.dm for example: - *d = f16::from_f32(0.); // you may convert to half if needed + *d = f16::from_f32(0.); - // We'll store sub-block scales in a temporary float array let mut scales = vec![0.0f32; super_block_size / block_size]; let mut weight = vec![0.0f32; block_size]; - // We'll store indexes in a temporary array L let mut L = vec![0u8; super_block_size]; let mut max_scale = 0.0f32; let mut amax_scale = 0.0f32; let nb = super_block_size / block_size; - // 4. For each sub-block ib for ib in 0..nb { let start = ib * block_size; let end = start + block_size; let xb = &xs_block[start..end]; let Lb = &mut L[start..end]; - // optional quant_weights if let Some(qw) = quant_weights { let qw_block = &qw[start..end]; for j in 0..block_size { - // weight[j] = qw[j] * sqrtf(sigma2 + xb[j]^2) weight[j] = qw_block[j] * (sigma2 + xb[j] * xb[j]).sqrt(); } } else { - // weight[j] = xb[j]^2 for j in 0..block_size { weight[j] = xb[j] * xb[j]; } } - // 5. find amax and max let mut amax = 0.0f32; let mut max = 0.0f32; for &v in xb.iter() { @@ -659,20 +648,17 @@ pub(super) fn quantize_row_iq4_nl( max = v; } } - // 6. if (amax < GROUP_MAX_EPS) => scales[ib] = 0; continue if amax < GROUP_MAX_EPS { scales[ib] = 0.0; continue; } - // 7. do the initial d = ±(max/values[0]) (depending on ntry>0) let sign = if ntry > 0 { -1.0 } else { 1.0 }; let mut d = sign * (max / (values[0] as f32)); let mut id = 1.0 / d; let mut sumqx = 0.0f32; let mut sumq2 = 0.0f32; - // 7a. compute sumqx, sumq2 with that scale for j in 0..block_size { let al = id * xb[j]; let l = best_index_int8(16, &values, al); @@ -682,14 +668,11 @@ pub(super) fn quantize_row_iq4_nl( sumqx += w * q * xb[j]; sumq2 += w * q * q; } - // 7b. refine d => sumqx / sumq2 d = sumqx / sumq2; let mut best = d * sumqx; - // 8. search in range -ntry..ntry for itry in -ntry..=ntry { let itryf = itry as f32; - // id = (itry + values[0]) / max (from the code) id = (itryf + values[0] as f32) / max; sumqx = 0.0f32; sumq2 = 0.0f32; @@ -707,7 +690,6 @@ pub(super) fn quantize_row_iq4_nl( } } - // store final scale for this sub-block scales[ib] = d; let abs_d = d.abs(); if abs_d > amax_scale { @@ -716,24 +698,18 @@ pub(super) fn quantize_row_iq4_nl( } } - // 9. 
Now handle the second pass if we have more than one sub-block if nb > 1 { - // block_out.dm = half of d => from the code: d = -max_scale/32 let d_f32 = -max_scale / 32.0; - *d = f16::from_f32(d_f32); // or convert to half as needed + *d = f16::from_f32(d_f32); let id = if d_f32 != 0.0 { 1.0 / d_f32 } else { 0.0 }; - // scales_h/l might be stored in block_out as arrays - // zero out scales_h first scales_h.iter_mut().for_each(|x| *x = 0); for ib in 0..nb { - // nearest int in the range [-32..31] let mut l = nearest_int(id * scales[ib]); l = l.clamp(-32, 31); - // compute dl = d*l let dl = d_f32 * (l as f32); let idl = if dl != 0.0 { 1.0 / dl } else { 0.0 }; @@ -742,19 +718,16 @@ pub(super) fn quantize_row_iq4_nl( let xb = &xs_block[start..end]; let Lb = &mut L[start..end]; - // re-assign Lb with the refined scale for j in 0..block_size { let al = idl * xb[j]; let idx = best_index_int8(16, &values, al); Lb[j] = idx as u8; } - // pack l into scales_h/l - l += 32; // shift from [-32..31] to [0..63] + l += 32; let l_l = (l & 0xF) as u8; let l_h = (l >> 4) as u8; - // store into block_out.scales_l[ib/2], block_out.scales_h[ib/8] if ib % 2 == 0 { scales_l[ib / 2] = l_l; } else { @@ -763,8 +736,6 @@ pub(super) fn quantize_row_iq4_nl( scales_h[ib / 8] |= (l_h as u16) << (2 * (ib as u16 % 8)); } } else { - // super_block_size/block_size <= 1 - // the code sets block_out.dm = scales[0] *d = f16::from_f32(scales[0]); if ntry > 0 { let id = if scales[0] != 0.0 { @@ -779,8 +750,6 @@ pub(super) fn quantize_row_iq4_nl( } } - // 10. Finally, build q4 from L - // "for i in 0..super_block_size/32 { for j in 0..16 { q4[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4) } }" for i in 0..(super_block_size / 32) { for j in 0..16 { let l0 = L[32 * i + j]; diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index 8a82c78c72..b133fa6295 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -1001,12 +1001,7 @@ fn quantize_iq4_xs(device: &Device) -> Result<()> { // (256,), // &Device::Cpu, // )?; - let tgt = Tensor::randn( - 0f32, - 1f32, - 256, - &Device::Cpu, - )?; + let tgt = Tensor::randn(0f32, 1f32, 256, &Device::Cpu)?; let q = quantized::QTensor::quantize(&tgt, dtype)?; let res = q.dequantize(&Device::Cpu)?; From 9c23e27dcc482d80550bf4de68f3ca4c98f3bbe9 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Sun, 2 Feb 2025 21:43:02 -0500 Subject: [PATCH 06/26] Add cpu impl for isq4x --- candle-core/src/quantized/cuda.rs | 2 +- candle-core/src/quantized/iq_quants.rs | 137 ---------- candle-core/src/quantized/iq_quants/mod.rs | 232 ++++++++++++++++ candle-core/src/quantized/iq_quants/utils.rs | 254 ++++++++++++++++++ .../{k_quants.rs => k_quants/mod.rs} | 212 +-------------- .../src/quantized/{ => k_quants}/utils.rs | 202 +------------- candle-core/src/quantized/metal.rs | 6 +- candle-core/src/quantized/mod.rs | 28 +- candle-core/src/quantized/quants.rs | 205 ++++++++++++++ candle-core/tests/quantized_tests.rs | 144 +++++----- candle-transformers/src/models/llama.rs | 15 +- 11 files changed, 809 insertions(+), 628 deletions(-) delete mode 100644 candle-core/src/quantized/iq_quants.rs create mode 100644 candle-core/src/quantized/iq_quants/mod.rs create mode 100644 candle-core/src/quantized/iq_quants/utils.rs rename candle-core/src/quantized/{k_quants.rs => k_quants/mod.rs} (92%) rename candle-core/src/quantized/{ => k_quants}/utils.rs (74%) create mode 100644 candle-core/src/quantized/quants.rs diff --git a/candle-core/src/quantized/cuda.rs 
b/candle-core/src/quantized/cuda.rs index 19ca38495f..158fb019a0 100644 --- a/candle-core/src/quantized/cuda.rs +++ b/candle-core/src/quantized/cuda.rs @@ -1,5 +1,5 @@ use super::{GgmlDType, QStorage}; -use crate::quantized::k_quants::GgmlType; +use crate::quantized::quants::GgmlType; use crate::{backend::BackendDevice, cuda_backend::WrapErr}; use crate::{CudaDevice, CudaStorage, Result}; use half::f16; diff --git a/candle-core/src/quantized/iq_quants.rs b/candle-core/src/quantized/iq_quants.rs deleted file mode 100644 index 8b69564f81..0000000000 --- a/candle-core/src/quantized/iq_quants.rs +++ /dev/null @@ -1,137 +0,0 @@ -use half::f16; - -use crate::{ - quantized::utils::{group_for_quantization, quantize_row_iq4_nl}, - Result, -}; - -use super::{BlockQ8K, GgmlDType, GgmlType, QK_K}; - -pub const QK4_NL: usize = 32; - -#[derive(Debug, Clone, PartialEq)] -#[repr(C)] -pub struct BlockIQ4xs { - pub(crate) d: f16, - pub(crate) scales_h: u16, - pub(crate) scales_l: [u8; QK_K / 64], - pub(crate) qs: [u8; QK_K / 2], -} - -const _: () = assert!( - std::mem::size_of::() - == std::mem::size_of::() + std::mem::size_of::() + QK_K / 64 + QK_K / 2, - "wrong iq4_xs block size/padding" -); - -const KVALUES_IQ4NL: [i8; 16] = [ - -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113, -]; - -impl GgmlType for BlockIQ4xs { - const DTYPE: GgmlDType = GgmlDType::IQ4_XS; - const BLCK_SIZE: usize = QK_K; - type VecDotType = BlockQ8K; - - fn to_float(xs: &[Self], mut ys: &mut [f32]) -> Result<()> { - let k = ys.len(); - if k % QK_K != 0 { - crate::bail!("dequantize block iq4xs {k} is not divisible by {QK_K}"); - } - - let nb = k / QK_K; - for i in 0..nb { - let block = &xs[i]; - - let d = block.d.to_f32(); - let mut qs = &block.qs[..]; - - for ib in 0..(QK_K / 32) { - let ib_div_2 = ib / 2; - let ib_mod_2 = ib % 2; - - let ls_low = (block.scales_l[ib_div_2] as i32 >> (4 * ib_mod_2 as i32)) & 0xF; - let ls_high = ((block.scales_h as i32 >> (2 * ib as i32)) & 3) << 4; - let ls = ls_low | ls_high; - - let dl = d * (ls as f32 - 32.); - - for j in 0..16 { - ys[j] = dl * KVALUES_IQ4NL[(qs[j] & 0xF) as usize] as f32; - ys[j + 16] = dl * KVALUES_IQ4NL[(qs[j] >> 4) as usize] as f32; - } - - qs = &qs[16..]; - ys = &mut ys[32..]; - } - } - Ok(()) - } - - fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { - let k = xs.len(); - if k % Self::BLCK_SIZE != 0 { - crate::bail!("{k} is not divisible by {}", Self::BLCK_SIZE); - }; - let nb = k / Self::BLCK_SIZE; - if ys.len() != nb { - crate::bail!( - "size mismatch {} {} {}", - xs.len(), - ys.len(), - Self::BLCK_SIZE - ) - } - const SUPER_BLOCK_SIZE: usize = QK_K; - const BLOCK_SIZE: usize = 32; - const NTRY: i32 = 7; - - let nrow = 1; - let n_per_row = 256; - let nblock = n_per_row / QK_K; - - for row in 0..nrow { - let ys_block = &mut ys[nblock * row]; - for ibl in 0..nblock { - let xs_block = &xs[n_per_row * row + QK_K * ibl..]; - quantize_row_iq4_nl( - xs_block, - SUPER_BLOCK_SIZE, - BLOCK_SIZE, - &mut ys_block.d, - &mut ys_block.qs, - &mut [ys_block.scales_h], - &mut ys_block.scales_l, - &KVALUES_IQ4NL, - None, - NTRY, - ); - } - } - - // for (ys_block, xs_block) in group_for_quantization(xs, ys)? 
{
-        //     quantize_row_iq4_nl(
-        //         xs_block,
-        //         SUPER_BLOCK_SIZE,
-        //         BLOCK_SIZE,
-        //         &mut ys_block.d,
-        //         &mut ys_block.qs,
-        //         &mut [ys_block.scales_h],
-        //         &mut ys_block.scales_l,
-        //         &KVALUES_IQ4NL,
-        //         None,
-        //         NTRY,
-        //     );
-        // }
-        Ok(())
-    }
-
-    #[allow(unreachable_code)]
-    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
-        todo!()
-    }
-
-    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
-        todo!()
-    }
-}
diff --git a/candle-core/src/quantized/iq_quants/mod.rs b/candle-core/src/quantized/iq_quants/mod.rs
new file mode 100644
index 0000000000..b7281ebb8c
--- /dev/null
+++ b/candle-core/src/quantized/iq_quants/mod.rs
@@ -0,0 +1,232 @@
+use half::f16;
+use utils::quantize_row_iq4_nl_impl;
+
+use crate::{bail, Result};
+
+mod utils;
+
+use super::{BlockQ8K, GgmlDType, GgmlType, QK_K};
+
+pub const QK4_NL: usize = 32;
+
+const KVALUES_IQ4NL: [i8; 16] = [
+    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
+];
+
+#[derive(Debug, Clone, PartialEq)]
+#[repr(C)]
+pub struct BlockIQ4xs {
+    pub(crate) d: f16,
+    pub(crate) scales_h: u16,
+    pub(crate) scales_l: [u8; QK_K / 64],
+    pub(crate) qs: [u8; QK_K / 2],
+}
+
+const _: () = assert!(
+    std::mem::size_of::<BlockIQ4xs>()
+        == std::mem::size_of::<f16>() + std::mem::size_of::<u16>() + QK_K / 64 + QK_K / 2,
+    "wrong iq4_xs block size/padding"
+);
+
+impl GgmlType for BlockIQ4xs {
+    const DTYPE: GgmlDType = GgmlDType::Iq4Xs;
+    const BLCK_SIZE: usize = QK_K;
+    type VecDotType = BlockQ8K;
+
+    fn to_float(xs: &[Self], mut ys: &mut [f32]) -> Result<()> {
+        let k = ys.len();
+        if k % QK_K != 0 {
+            crate::bail!("dequantize block iq4xs {k} is not divisible by {QK_K}");
+        }
+
+        let nb = k / QK_K;
+        for i in 0..nb {
+            let block = &xs[i];
+
+            let d = block.d.to_f32();
+            let mut qs = &block.qs[..];
+
+            for ib in 0..(QK_K / 32) {
+                let ib_div_2 = ib / 2;
+                let ib_mod_2 = ib % 2;
+
+                let ls_low = (block.scales_l[ib_div_2] as i32 >> (4 * ib_mod_2 as i32)) & 0xF;
+                let ls_high = ((block.scales_h as i32 >> (2 * ib as i32)) & 3) << 4;
+                let ls = ls_low | ls_high;
+
+                let dl = d * (ls as f32 - 32.);
+
+                for j in 0..16 {
+                    ys[j] = dl * KVALUES_IQ4NL[(qs[j] & 0xF) as usize] as f32;
+                    ys[j + 16] = dl * KVALUES_IQ4NL[(qs[j] >> 4) as usize] as f32;
+                }
+
+                qs = &qs[16..];
+                ys = &mut ys[32..];
+            }
+        }
+        Ok(())
+    }
+
+    fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
+        let k = xs.len();
+        if k % QK_K != 0 {
+            bail!("Input length must be multiple of QK_K = {}", QK_K);
+        }
+
+        quantize_iq4_xs(xs, ys, 1, k, None)?;
+
+        Ok(())
+    }
+
+    fn from_float_imatrix(
+        xs: &[f32],
+        ys: &mut [Self],
+        imatrix_weights: &[f32],
+        n_per_row: usize,
+    ) -> Result<()> {
+        let k = xs.len();
+        if k % QK_K != 0 {
+            bail!("Input length must be multiple of QK_K = {}", QK_K);
+        }
+        let nrow = xs.len() / n_per_row;
+
+        quantize_iq4_xs_imatrix(xs, ys, nrow, n_per_row, Some(imatrix_weights));
+
+        Ok(())
+    }
+
+    #[allow(unreachable_code)]
+    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        todo!()
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        todo!()
+    }
+}
+
+fn quantize_iq4_xs(
+    src: &[f32],
+    ys: &mut [BlockIQ4xs],
+    nrow: usize,
+    n_per_row: usize,
+    quant_weights: Option<&[f32]>,
+) -> Result<()> {
+    // Basic sanity checks, similar to the C macro GGML_ASSERT
+    if n_per_row % QK_K != 0 {
+        bail!("n_per_row must be multiple of QK_K = {}", QK_K);
+    }
+
+    let nblock = n_per_row / QK_K;
+    // We expect exactly nrow * nblock blocks in `ys`.
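+    // Size check (our arithmetic): one BlockIQ4xs packs 2 (d) + 2 (scales_h)
+    // + QK_K / 64 = 4 (scales_l) + QK_K / 2 = 128 (qs) = 136 bytes per
+    // QK_K = 256 weights, i.e. 136 * 8 / 256 = 4.25 bits per weight, which is
+    // where the IQ4_XS "4.25 bpw" figure comes from.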
+ if ys.len() != nrow * nblock { + bail!( + "Output buffer size mismatch: want {} blocks, got {}", + nrow * nblock, + ys.len() + ); + } + + // We'll need some local buffers that match the usage in the C code + let mut lbuf = vec![0u8; QK_K]; // L[QK_K] + let mut weight = vec![0f32; 32]; // weight[32] (the block_size is 32) + let mut scales = vec![0f32; QK_K / 32]; // scales[QK_K/32], e.g. 256/32=8 + + let mut src_offset = 0; + let mut dst_offset = 0; + + for _row in 0..nrow { + // Each row has `nblock` blocks: + for ibl in 0..nblock { + // In C: block_iq4_xs * iq4 = (block_iq4_xs *)qrow; + let block = &mut ys[dst_offset + ibl]; + + // quant_weights? + let qw = quant_weights.map(|qw_all| { + let start = QK_K * ibl; + &qw_all[start..start + QK_K] + }); + + quantize_row_iq4_nl_impl( + /* super_block_size = */ QK_K, + /* block_size = */ 32, + /* x = */ + &src[src_offset + QK_K * ibl..src_offset + QK_K * (ibl + 1)], + /* dh = */ &mut block.d, + /* q4 = */ &mut block.qs, + /* scales_h = */ &mut block.scales_h, + /* scales_l = */ &mut block.scales_l, + /* scales = */ &mut scales, + /* weight = */ &mut weight, + /* L = */ &mut lbuf, + /* values = */ &KVALUES_IQ4NL, + /* quant_weights = */ qw, + /* ntry = */ 7, + ); + } + src_offset += n_per_row; + dst_offset += nblock; + } + + Ok(()) +} + +pub fn quantize_iq4_xs_imatrix( + src: &[f32], + dst: &mut [BlockIQ4xs], + nrow: usize, + n_per_row: usize, + quant_weights: Option<&[f32]>, +) { + // 1. Check that n_per_row is multiple of QK_K + assert_eq!(n_per_row % QK_K, 0, "n_per_row must be multiple of QK_K"); + let nblock = n_per_row / QK_K; + + // 2. We expect nrow * nblock blocks in `dst` + assert_eq!( + dst.len(), + nrow * nblock, + "Output slice must have exactly nrow*nblock elements" + ); + + // 3. Local buffers matching the C usage + let mut lbuf = vec![0u8; QK_K]; + let mut weight = vec![0f32; 32]; + let mut scales = vec![0f32; QK_K / 32]; + + // We'll track how far we've consumed `src`. + let mut src_offset = 0; + // Also track how far we move in `dst`. + let mut dst_offset = 0; + + // 4. 
Outer loop over rows + for _row in 0..nrow { + // In C: block_iq4_xs * iq4 = (block_iq4_xs *)qrow; + // Here: let block_slice = &mut dst[dst_offset..dst_offset + nblock]; + for ibl in 0..nblock { + let block = &mut dst[dst_offset + ibl]; + + // If quant_weights is Some, get the sub-slice for this block + let qw_block = quant_weights.map(|qw_all| &qw_all[ibl * QK_K..(ibl + 1) * QK_K]); + + quantize_row_iq4_nl_impl( + QK_K, // super_block_size + 32, // block_size + &src[src_offset + ibl * QK_K..src_offset + (ibl + 1) * QK_K], + &mut block.d, + &mut block.qs, + &mut block.scales_h, + &mut block.scales_l, + &mut scales, + &mut weight, + &mut lbuf, + &KVALUES_IQ4NL, + qw_block, + 7, // ntry + ); + } + src_offset += n_per_row; + dst_offset += nblock; + } +} diff --git a/candle-core/src/quantized/iq_quants/utils.rs b/candle-core/src/quantized/iq_quants/utils.rs new file mode 100644 index 0000000000..228f9e7d78 --- /dev/null +++ b/candle-core/src/quantized/iq_quants/utils.rs @@ -0,0 +1,254 @@ +use half::f16; + +const GROUP_MAX_EPS: f32 = 1e-15; + +#[allow(clippy::too_many_arguments)] +pub(super) fn quantize_row_iq4_nl_impl( + super_block_size: usize, + block_size: usize, + x: &[f32], + dh: &mut f16, + q4: &mut [u8], + scales_h: &mut u16, + scales_l: &mut [u8], + scales: &mut [f32], + weight: &mut [f32], + lbuf: &mut [u8], + values: &[i8], + quant_weights: Option<&[f32]>, + ntry: i32, +) { + // For safety, confirm the slices have correct lengths: + let sb_div_2 = super_block_size / 2; + let sb_div_32 = super_block_size / 32; + let sb_div_64 = super_block_size / 64; + assert_eq!(q4.len(), sb_div_2); + assert_eq!(scales.len(), sb_div_32); + assert_eq!(scales_l.len(), sb_div_64); + assert_eq!(lbuf.len(), super_block_size); + assert_eq!(weight.len(), block_size); + + // 1. compute sigma2 + let mut sigma2 = 0f32; + for j in 0..super_block_size { + sigma2 += x[j] * x[j]; + } + sigma2 *= 2.0 / (super_block_size as f32); + + // 2. zero out q4, set dh to 0 + for qi in q4.iter_mut() { + *qi = 0; + } + *dh = f16::from_f32(0.0); + + // Track the max absolute scale across sub-blocks + let mut max_scale = 0.0_f32; + let mut amax_scale = 0.0_f32; + + // For each 32-float block within the 256-float super-block: + let nblocks = super_block_size / block_size; + + for ib in 0..nblocks { + let xb = &x[ib * block_size..ib * block_size + block_size]; + let lb = &mut lbuf[ib * block_size..ib * block_size + block_size]; + + // If we have external `quant_weights`, fill `weight[j] = quant_weights[j]*sqrt(...)`, + // else `weight[j] = xb[j]*xb[j]` + if let Some(qw) = quant_weights { + let qw_block = &qw[ib * block_size..ib * block_size + block_size]; + for j in 0..block_size { + let val = xb[j]; + weight[j] = qw_block[j] * (sigma2 + val * val).sqrt(); + } + } else { + for j in 0..block_size { + let val = xb[j]; + weight[j] = val * val; + } + } + + // 3. find amax (largest absolute value in block) + let mut amax = 0.0_f32; + let mut max_v = 0.0_f32; + for &xx in xb { + let ax = xx.abs(); + if ax > amax { + amax = ax; + max_v = xx; + } + } + + // If amax is extremely small, scale = 0 + if amax < GROUP_MAX_EPS { + scales[ib] = 0.0; + continue; + } + + // 4. initial guess for d + let sign_factor = if ntry > 0 { -1.0 } else { 1.0 }; + let mut d = sign_factor * max_v / (values[0] as f32); + let id = 1.0 / d; + + // 5. 
compute an initial sumqx, sumq2 + let mut sumqx = 0.0_f32; + let mut sumq2 = 0.0_f32; + for j in 0..block_size { + let val = xb[j]; + let al = id * val; + let l = best_index_int8(values, al); + lb[j] = l as u8; + + let q = values[l] as f32; + let w = weight[j]; + sumqx += w * q * val; + sumq2 += w * q * q; + } + d = sumqx / sumq2; + let mut best = d * sumqx; + + // 6. do extra tries around that initial guess + for itry in -ntry..=ntry { + let test_id = (itry as f32 + values[0] as f32) / max_v; + let mut tmp_sumqx = 0.0_f32; + let mut tmp_sumq2 = 0.0_f32; + for j in 0..block_size { + let val = xb[j]; + let al = test_id * val; + let l = best_index_int8(values, al); + let q = values[l] as f32; + let w = weight[j]; + tmp_sumqx += w * q * val; + tmp_sumq2 += w * q * q; + } + if tmp_sumq2 > 0.0 { + let maybe_d = tmp_sumqx / tmp_sumq2; + let maybe_best = maybe_d * tmp_sumqx; + if maybe_best > best { + best = maybe_best; + d = maybe_d; + } + } + } + + // 7. record the chosen scale + scales[ib] = d; + let abs_d = d.abs(); + if abs_d > amax_scale { + amax_scale = abs_d; + max_scale = d; + } + } + + // 8. If we have more than one 32-float block in the super-block: + if nblocks > 1 { + // zero scales_h, because we store 2 bits per block in it + // for nblocks=8, we store them in a single 16-bit value + *scales_h = 0; + for sl in scales_l.iter_mut() { + *sl = 0; + } + + let d = -max_scale / 32.0; + *dh = f16::from_f32(d); + let id = if d != 0.0 { 1.0 / d } else { 0.0 }; + + for ib in 0..nblocks { + // l = nearest_int(id * scales[ib]), clamp to [-32..31] + let mut l = (id * scales[ib]).round() as i32; + if l < -32 { + l = -32; + } + if l > 31 { + l = 31; + } + + // refine block + let dl = d * (l as f32); + let idl = if dl != 0.0 { 1.0 / dl } else { 0.0 }; + + let xb = &x[ib * block_size..ib * block_size + block_size]; + let lb = &mut lbuf[ib * block_size..ib * block_size + block_size]; + for j in 0..block_size { + let val = xb[j]; + lb[j] = best_index_int8(values, idl * val) as u8; + } + + // store l in 4 bits + 4 bits + let l_offset = (l + 32) as u8; // now in [0..64) + let l_low = l_offset & 0x0f; + let l_high = l_offset >> 4; + + // scales_l[ib/2] uses the nibble for this block + if ib % 2 == 0 { + scales_l[ib / 2] = l_low; + } else { + scales_l[ib / 2] |= l_low << 4; + } + // scales_h for each block (2 bits per block) => stored in a 16-bit + // scaled_h[ib/8] with (l_high << (2*(ib%8))) + let shift = 2 * (ib % 8); + *scales_h |= (l_high as u16) << shift; + } + } else { + // single 32-float block => just store d + *dh = f16::from_f32(scales[0]); + if ntry > 0 { + let id = if scales[0] != 0.0 { + 1.0 / scales[0] + } else { + 0.0 + }; + for j in 0..super_block_size { + lbuf[j] = best_index_int8(values, id * x[j]) as u8; + } + } + } + + // 9. Finally, pack all 4-bit values from L into q4 + // q4[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4) + for i in 0..(super_block_size / 32) { + for j in 0..16 { + let lo = lbuf[32 * i + j] & 0x0f; + let hi = (lbuf[32 * i + 16 + j] & 0x0f) << 4; + q4[16 * i + j] = lo | hi; + } + } +} + +/// Finds the best index i in [0..values.len()) such that +/// `values[i]` is closest to `x`. 
The array `values` is strictly +/// ascending/ +fn best_index_int8(values: &[i8], x: f32) -> usize { + // Quick boundary checks + if x <= values[0] as f32 { + return 0; + } + let n = values.len(); + let last = (n - 1).max(0); + if x >= values[last] as f32 { + return last; + } + + // Binary search + let mut ml = 0; + let mut mu = last; + while mu - ml > 1 { + let mav = (ml + mu) / 2; + if x < values[mav] as f32 { + mu = mav; + } else { + ml = mav; + } + } + + // Return whichever is closer among values[mu-1], values[mu] + // But watch out if mu == 0 or mu == n-1 ... + // (the boundary checks above should keep mu>0) + let dist_left = (x - values[ml] as f32).abs(); + let dist_right = (values[mu] as f32 - x).abs(); + if dist_left <= dist_right { + ml + } else { + mu + } +} diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants/mod.rs similarity index 92% rename from candle-core/src/quantized/k_quants.rs rename to candle-core/src/quantized/k_quants/mod.rs index 27cc984a20..3b3b533ecb 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants/mod.rs @@ -1,13 +1,12 @@ -use super::utils::{ +mod utils; +use super::k_quants::utils::{ get_scale_min_k4, group_for_dequantization, group_for_quantization, make_q3_quants, - make_qkx1_quants, make_qx_quants, nearest_int, + make_qkx1_quants, make_qkx3_quants, make_qp_quants, make_qx_quants, nearest_int, }; -use super::GgmlDType; -use crate::quantized::utils::{make_qkx3_quants, make_qp_quants}; +use super::{GgmlDType, GgmlType}; use crate::Result; use byteorder::{ByteOrder, LittleEndian}; -use half::{bf16, f16}; -use rayon::prelude::*; +use half::f16; // Default to QK_K 256 rather than 64. pub const QK_K: usize = 256; @@ -20,37 +19,6 @@ pub const QK5_1: usize = 32; pub const QK8_0: usize = 32; pub const QK8_1: usize = 32; -pub trait GgmlType: Sized + Clone + Send + Sync { - const DTYPE: GgmlDType; - const BLCK_SIZE: usize; - type VecDotType: GgmlType; - - // This is only safe for types that include immediate values such as float/int/... - fn zeros() -> Self { - unsafe { std::mem::MaybeUninit::zeroed().assume_init() } - } - fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()>; - fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()>; - fn from_float_imatrix( - _xs: &[f32], - _ys: &mut [Self], - _imatrix_weights: &[f32], - _n_per_row: usize, - ) -> Result<()> { - crate::bail!( - "`from_float_imatrix` is unimplemented for {:?}", - Self::DTYPE - ); - } - - /// Dot product used as a building block for quantized mat-mul. - /// n is the number of elements to be considered. - fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result; - - /// Generic implementation of the dot product without simd optimizations. - fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result; -} - #[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ4_0 { @@ -2268,173 +2236,3 @@ impl GgmlType for BlockQ8K { Ok(()) } } - -// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L10605 -pub fn matmul( - mkn: (usize, usize, usize), - lhs: &[f32], - rhs_t: &[T], - dst: &mut [f32], -) -> Result<()> { - let (m, k, n) = mkn; - if m * k != lhs.len() { - crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len()); - } - - let k_in_lhs_blocks = k.div_ceil(T::BLCK_SIZE); - let k_in_rhs_blocks = k.div_ceil(T::VecDotType::BLCK_SIZE); - // TODO: Do not make this copy if the DotType is f32. - // TODO: Pre-allocate this. 
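 // (Move note: matmul appears to relocate wholesale into the new quants.rs
 // alongside the GgmlType trait, per the file list above. Strategy: quantize
 // each f32 lhs row to T::VecDotType, BlockQ8K for the k-quants, then compute
 // every output element as a vec_dot of a quantized rhs column against that
 // row, parallelized over columns with rayon.)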
- let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks]; - for row_idx in 0..m { - let lhs_b = &mut lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; - let lhs = &lhs[row_idx * k..(row_idx + 1) * k]; - T::VecDotType::from_float(lhs, lhs_b)? - } - let lhs_b = lhs_b.as_slice(); - - for row_idx in 0..m { - let lhs_row = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; - let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n]; - - let result: Result> = dst_row - .into_par_iter() - .enumerate() - .with_min_len(128) - .with_max_len(512) - .map(|(col_idx, dst)| { - let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks]; - T::vec_dot(k, rhs_col, lhs_row).map(|value| *dst = value) - }) - .collect(); - - result?; - } - Ok(()) -} - -impl GgmlType for f32 { - const DTYPE: GgmlDType = GgmlDType::F32; - const BLCK_SIZE: usize = 1; - type VecDotType = f32; - - fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { - Self::vec_dot_unopt(n, xs, ys) - } - - fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { - if xs.len() < n { - crate::bail!("size mismatch {} < {n}", xs.len()) - } - if ys.len() < n { - crate::bail!("size mismatch {} < {n}", ys.len()) - } - let mut res = 0f32; - unsafe { crate::cpu::vec_dot_f32(xs.as_ptr(), ys.as_ptr(), &mut res, n) }; - Ok(res) - } - - fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { - if xs.len() != ys.len() { - crate::bail!("size mismatch {} {}", xs.len(), ys.len()); - } - ys.copy_from_slice(xs); - Ok(()) - } - - fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { - if xs.len() != ys.len() { - crate::bail!("size mismatch {} {}", xs.len(), ys.len()); - } - ys.copy_from_slice(xs); - Ok(()) - } -} - -impl GgmlType for f16 { - const DTYPE: GgmlDType = GgmlDType::F16; - const BLCK_SIZE: usize = 1; - type VecDotType = f16; - - fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { - Self::vec_dot_unopt(n, xs, ys) - } - - fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { - if xs.len() < n { - crate::bail!("size mismatch {} < {n}", xs.len()) - } - if ys.len() < n { - crate::bail!("size mismatch {} < {n}", ys.len()) - } - let mut res = 0f32; - unsafe { crate::cpu::vec_dot_f16(xs.as_ptr(), ys.as_ptr(), &mut res, n) }; - Ok(res) - } - - fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { - if xs.len() != ys.len() { - crate::bail!("size mismatch {} {}", xs.len(), ys.len()); - } - // TODO: vectorize - for (x, y) in xs.iter().zip(ys.iter_mut()) { - *y = f16::from_f32(*x) - } - Ok(()) - } - - fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { - if xs.len() != ys.len() { - crate::bail!("size mismatch {} {}", xs.len(), ys.len()); - } - // TODO: vectorize - for (x, y) in xs.iter().zip(ys.iter_mut()) { - *y = x.to_f32() - } - Ok(()) - } -} - -impl GgmlType for bf16 { - const DTYPE: GgmlDType = GgmlDType::BF16; - const BLCK_SIZE: usize = 1; - type VecDotType = bf16; - - fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { - Self::vec_dot_unopt(n, xs, ys) - } - - fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { - if xs.len() < n { - crate::bail!("size mismatch {} < {n}", xs.len()) - } - if ys.len() < n { - crate::bail!("size mismatch {} < {n}", ys.len()) - } - let mut res = 0f32; - unsafe { crate::cpu::vec_dot_bf16(xs.as_ptr(), ys.as_ptr(), &mut res, n) }; - Ok(res) - } - - fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { - if xs.len() != ys.len() { - 
crate::bail!("size mismatch {} {}", xs.len(), ys.len()); - } - // TODO: vectorize - for (x, y) in xs.iter().zip(ys.iter_mut()) { - *y = bf16::from_f32(*x) - } - Ok(()) - } - - fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { - if xs.len() != ys.len() { - crate::bail!("size mismatch {} {}", xs.len(), ys.len()); - } - // TODO: vectorize - for (x, y) in xs.iter().zip(ys.iter_mut()) { - *y = x.to_f32() - } - Ok(()) - } -} diff --git a/candle-core/src/quantized/utils.rs b/candle-core/src/quantized/k_quants/utils.rs similarity index 74% rename from candle-core/src/quantized/utils.rs rename to candle-core/src/quantized/k_quants/utils.rs index 20c3b1259e..7db44df455 100644 --- a/candle-core/src/quantized/utils.rs +++ b/candle-core/src/quantized/k_quants/utils.rs @@ -1,9 +1,5 @@ -use half::f16; - use crate::Result; -const GROUP_MAX_EPS: f32 = 1e-15; - pub(super) fn nearest_int(v: f32) -> i32 { v.round() as i32 } @@ -11,7 +7,7 @@ pub(super) fn nearest_int(v: f32) -> i32 { /// Validates that the input and output are the right size and returns an iterator which maps each /// input region `xs` to its corresponding output block in `ys`. Each output region is guaranteed /// to be `T::BLCK_SIZE` long. -pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>( +pub(super) fn group_for_quantization<'a, 'b, T: super::GgmlType>( xs: &'b [f32], ys: &'a mut [T], ) -> Result> { @@ -32,7 +28,7 @@ pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>( /// Validates that the input and output are the right size and returns an iterator which maps each /// input block `xs` to its corresponding output region in `ys`. Each output region is guaranteed /// to be `T::BLCK_SIZE` long. -pub(super) fn group_for_dequantization<'a, 'b, T: super::k_quants::GgmlType>( +pub(super) fn group_for_dequantization<'a, 'b, T: super::GgmlType>( xs: &'a [T], ys: &'b mut [f32], ) -> Result> { @@ -564,197 +560,3 @@ pub(super) fn make_qp_quants( sumlx / suml2 } - -fn best_index_int8(n: usize, val: &[i8], x: f32) -> usize { - if x <= val[0] as f32 { - return 0; - } - if x >= val[n - 1] as f32 { - return n - 1; - } - let mut ml = 0_usize; - let mut mu = n - 1; - while mu - ml > 1 { - let mav = (ml + mu) / 2; - if x < val[mav] as f32 { - mu = mav; - } else { - ml = mav; - } - } - let dist_low = x - val[mu - 1] as f32; - let dist_high = val[mu] as f32 - x; - if dist_low < dist_high { - mu - 1 - } else { - mu - } -} - -#[allow(non_snake_case)] -pub(super) fn quantize_row_iq4_nl( - xs_block: &[f32], - super_block_size: usize, - block_size: usize, - d: &mut f16, - qs: &mut [u8], - scales_h: &mut [u16], - scales_l: &mut [u8], - values: &[i8], - quant_weights: Option<&[f32]>, - ntry: i32, -) { - let mut sigma2 = 0.0f32; - for &val in xs_block.iter() { - sigma2 += val * val; - } - sigma2 *= 2.0 / (super_block_size as f32); - - qs.iter_mut().for_each(|x| *x = 0); - - *d = f16::from_f32(0.); - - let mut scales = vec![0.0f32; super_block_size / block_size]; - let mut weight = vec![0.0f32; block_size]; - let mut L = vec![0u8; super_block_size]; - - let mut max_scale = 0.0f32; - let mut amax_scale = 0.0f32; - let nb = super_block_size / block_size; - - for ib in 0..nb { - let start = ib * block_size; - let end = start + block_size; - let xb = &xs_block[start..end]; - let Lb = &mut L[start..end]; - - if let Some(qw) = quant_weights { - let qw_block = &qw[start..end]; - for j in 0..block_size { - weight[j] = qw_block[j] * (sigma2 + xb[j] * xb[j]).sqrt(); - } - } else { - for j in 0..block_size { - weight[j] = 
xb[j] * xb[j]; - } - } - - let mut amax = 0.0f32; - let mut max = 0.0f32; - for &v in xb.iter() { - let ax = v.abs(); - if ax > amax { - amax = ax; - max = v; - } - } - if amax < GROUP_MAX_EPS { - scales[ib] = 0.0; - continue; - } - - let sign = if ntry > 0 { -1.0 } else { 1.0 }; - let mut d = sign * (max / (values[0] as f32)); - let mut id = 1.0 / d; - let mut sumqx = 0.0f32; - let mut sumq2 = 0.0f32; - - for j in 0..block_size { - let al = id * xb[j]; - let l = best_index_int8(16, &values, al); - Lb[j] = l as u8; - let q = values[l] as f32; - let w = weight[j]; - sumqx += w * q * xb[j]; - sumq2 += w * q * q; - } - d = sumqx / sumq2; - let mut best = d * sumqx; - - for itry in -ntry..=ntry { - let itryf = itry as f32; - id = (itryf + values[0] as f32) / max; - sumqx = 0.0f32; - sumq2 = 0.0f32; - for j in 0..block_size { - let al = id * xb[j]; - let l = best_index_int8(16, &values, al); - let q = values[l] as f32; - let w = weight[j]; - sumqx += w * q * xb[j]; - sumq2 += w * q * q; - } - if sumq2 > 0. && sumqx * sumqx > best * sumq2 { - d = sumqx / sumq2; - best = d * sumqx; - } - } - - scales[ib] = d; - let abs_d = d.abs(); - if abs_d > amax_scale { - amax_scale = abs_d; - max_scale = d; - } - } - - if nb > 1 { - let d_f32 = -max_scale / 32.0; - *d = f16::from_f32(d_f32); - - let id = if d_f32 != 0.0 { 1.0 / d_f32 } else { 0.0 }; - - scales_h.iter_mut().for_each(|x| *x = 0); - - for ib in 0..nb { - let mut l = nearest_int(id * scales[ib]); - l = l.clamp(-32, 31); - - let dl = d_f32 * (l as f32); - let idl = if dl != 0.0 { 1.0 / dl } else { 0.0 }; - - let start = ib * block_size; - let end = start + block_size; - let xb = &xs_block[start..end]; - let Lb = &mut L[start..end]; - - for j in 0..block_size { - let al = idl * xb[j]; - let idx = best_index_int8(16, &values, al); - Lb[j] = idx as u8; - } - - l += 32; - let l_l = (l & 0xF) as u8; - let l_h = (l >> 4) as u8; - - if ib % 2 == 0 { - scales_l[ib / 2] = l_l; - } else { - scales_l[ib / 2] |= l_l << 4; - } - scales_h[ib / 8] |= (l_h as u16) << (2 * (ib as u16 % 8)); - } - } else { - *d = f16::from_f32(scales[0]); - if ntry > 0 { - let id = if scales[0] != 0.0 { - 1.0 / scales[0] - } else { - 0.0 - }; - for j in 0..super_block_size { - let idx = best_index_int8(16, &values, id * xs_block[j]); - L[j] = idx as u8; - } - } - } - - for i in 0..(super_block_size / 32) { - for j in 0..16 { - let l0 = L[32 * i + j]; - let l1 = L[32 * i + 16 + j] << 4; - qs[16 * i + j] = l0 | l1; - } - } -} diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs index 2339256bb0..511a5b6ae2 100644 --- a/candle-core/src/quantized/metal.rs +++ b/candle-core/src/quantized/metal.rs @@ -34,7 +34,7 @@ impl QMetalStorage { } pub fn dequantize(&self, elem_count: usize) -> Result { - use crate::quantized::k_quants::GgmlType; + use crate::quantized::quants::GgmlType; let buffer = self.device.new_buffer_managed(self.buffer.length())?; let command_buffer = self.device.command_buffer()?; @@ -107,7 +107,7 @@ impl QMetalStorage { let vec: Vec = read_to_vec(&buffer, block_len); crate::quantized::BlockQ8K::to_float(&vec, &mut out)?; } - GgmlDType::IQ4_XS => { + GgmlDType::Iq4Xs => { let vec: Vec = read_to_vec(&buffer, block_len); crate::quantized::BlockIQ4xs::to_float(&vec, &mut out)?; } @@ -391,7 +391,7 @@ impl From for candle_metal_kernels::GgmlDType { GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K, GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K, GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K, - GgmlDType::IQ4_XS => 
candle_metal_kernels::GgmlDType::Q8_0, + GgmlDType::Iq4Xs => candle_metal_kernels::GgmlDType::Q8_0, GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16, GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32, GgmlDType::BF16 => candle_metal_kernels::GgmlDType::F16, diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index ab6f1ef1c1..761069656f 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -14,6 +14,7 @@ pub mod iq_quants; pub mod k_quants; #[cfg(feature = "metal")] pub mod metal; +pub mod quants; #[cfg(not(feature = "metal"))] mod metal { pub use super::dummy_metal::*; @@ -29,10 +30,9 @@ mod cuda { pub mod neon; #[cfg(target_feature = "simd128")] pub mod simd128; -pub mod utils; use half::{bf16, f16}; -pub use k_quants::GgmlType; +pub use quants::GgmlType; pub struct QTensor { storage: QStorage, @@ -202,7 +202,7 @@ pub enum GgmlDType { Q5K, Q6K, Q8K, - IQ4_XS, + Iq4Xs, } impl GgmlDType { @@ -222,7 +222,7 @@ impl GgmlDType { 13 => Self::Q5K, 14 => Self::Q6K, 15 => Self::Q8K, - 23 => Self::IQ4_XS, + 23 => Self::Iq4Xs, // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389 30 => Self::BF16, _ => crate::bail!("unknown dtype for tensor {u}"), @@ -246,7 +246,7 @@ impl GgmlDType { Self::Q5K => 13, Self::Q6K => 14, Self::Q8K => 15, - Self::IQ4_XS => 23, + Self::Iq4Xs => 23, // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389 Self::BF16 => 30, } @@ -269,7 +269,7 @@ impl GgmlDType { Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]), Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]), Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]), - Self::IQ4_XS => Box::new(vec![ + Self::Iq4Xs => Box::new(vec![ BlockIQ4xs::zeros(); elem_count / BlockIQ4xs::BLCK_SIZE ]), @@ -295,7 +295,7 @@ impl GgmlDType { Self::Q5K => std::mem::size_of::(), Self::Q6K => std::mem::size_of::(), Self::Q8K => std::mem::size_of::(), - Self::IQ4_XS => std::mem::size_of::(), + Self::Iq4Xs => std::mem::size_of::(), } } @@ -310,13 +310,9 @@ impl GgmlDType { Self::Q5_1 => k_quants::QK5_1, Self::Q8_0 => k_quants::QK8_0, Self::Q8_1 => k_quants::QK8_1, - Self::Q2K - | Self::Q3K - | Self::Q4K - | Self::Q5K - | Self::Q6K - | Self::Q8K - | Self::IQ4_XS => k_quants::QK_K, + Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K | Self::Iq4Xs => { + k_quants::QK_K + } } } } @@ -341,9 +337,9 @@ pub trait QuantizedType: Send + Sync { fn size(&self) -> usize; } -impl QuantizedType for Vec { +impl QuantizedType for Vec { fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> { - k_quants::matmul(mkn, lhs, self.as_slice(), dst) + quants::matmul(mkn, lhs, self.as_slice(), dst) } fn size(&self) -> usize { diff --git a/candle-core/src/quantized/quants.rs b/candle-core/src/quantized/quants.rs new file mode 100644 index 0000000000..db3230271f --- /dev/null +++ b/candle-core/src/quantized/quants.rs @@ -0,0 +1,205 @@ +use super::GgmlDType; +use crate::Result; +use half::{bf16, f16}; +use rayon::prelude::*; + +pub trait GgmlType: Sized + Clone + Send + Sync { + const DTYPE: GgmlDType; + const BLCK_SIZE: usize; + type VecDotType: GgmlType; + + // This is only safe for types that include immediate values such as float/int/... 
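+    // Zeroed memory is a valid bit pattern for every block type defined here
+    // (plain #[repr(C)] structs of integers and f16 fields); it would be UB
+    // for types carrying invariants such as references or NonZero integers.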
+ fn zeros() -> Self { + unsafe { std::mem::MaybeUninit::zeroed().assume_init() } + } + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()>; + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()>; + fn from_float_imatrix( + _xs: &[f32], + _ys: &mut [Self], + _imatrix_weights: &[f32], + _n_per_row: usize, + ) -> Result<()> { + crate::bail!( + "`from_float_imatrix` is unimplemented for {:?}", + Self::DTYPE + ); + } + + /// Dot product used as a building block for quantized mat-mul. + /// n is the number of elements to be considered. + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result; + + /// Generic implementation of the dot product without simd optimizations. + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result; +} + +// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L10605 +pub fn matmul( + mkn: (usize, usize, usize), + lhs: &[f32], + rhs_t: &[T], + dst: &mut [f32], +) -> Result<()> { + let (m, k, n) = mkn; + if m * k != lhs.len() { + crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len()); + } + + let k_in_lhs_blocks = k.div_ceil(T::BLCK_SIZE); + let k_in_rhs_blocks = k.div_ceil(T::VecDotType::BLCK_SIZE); + // TODO: Do not make this copy if the DotType is f32. + // TODO: Pre-allocate this. + let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks]; + for row_idx in 0..m { + let lhs_b = &mut lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; + let lhs = &lhs[row_idx * k..(row_idx + 1) * k]; + T::VecDotType::from_float(lhs, lhs_b)? + } + let lhs_b = lhs_b.as_slice(); + + for row_idx in 0..m { + let lhs_row = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; + let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n]; + + let result: Result> = dst_row + .into_par_iter() + .enumerate() + .with_min_len(128) + .with_max_len(512) + .map(|(col_idx, dst)| { + let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks]; + T::vec_dot(k, rhs_col, lhs_row).map(|value| *dst = value) + }) + .collect(); + + result?; + } + Ok(()) +} + +impl GgmlType for f32 { + const DTYPE: GgmlDType = GgmlDType::F32; + const BLCK_SIZE: usize = 1; + type VecDotType = f32; + + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + if xs.len() < n { + crate::bail!("size mismatch {} < {n}", xs.len()) + } + if ys.len() < n { + crate::bail!("size mismatch {} < {n}", ys.len()) + } + let mut res = 0f32; + unsafe { crate::cpu::vec_dot_f32(xs.as_ptr(), ys.as_ptr(), &mut res, n) }; + Ok(res) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + if xs.len() != ys.len() { + crate::bail!("size mismatch {} {}", xs.len(), ys.len()); + } + ys.copy_from_slice(xs); + Ok(()) + } + + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + if xs.len() != ys.len() { + crate::bail!("size mismatch {} {}", xs.len(), ys.len()); + } + ys.copy_from_slice(xs); + Ok(()) + } +} + +impl GgmlType for f16 { + const DTYPE: GgmlDType = GgmlDType::F16; + const BLCK_SIZE: usize = 1; + type VecDotType = f16; + + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + if xs.len() < n { + crate::bail!("size mismatch {} < {n}", xs.len()) + } + if ys.len() < n { + crate::bail!("size mismatch {} < {n}", ys.len()) + } + let 
mut res = 0f32; + unsafe { crate::cpu::vec_dot_f16(xs.as_ptr(), ys.as_ptr(), &mut res, n) }; + Ok(res) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + if xs.len() != ys.len() { + crate::bail!("size mismatch {} {}", xs.len(), ys.len()); + } + // TODO: vectorize + for (x, y) in xs.iter().zip(ys.iter_mut()) { + *y = f16::from_f32(*x) + } + Ok(()) + } + + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + if xs.len() != ys.len() { + crate::bail!("size mismatch {} {}", xs.len(), ys.len()); + } + // TODO: vectorize + for (x, y) in xs.iter().zip(ys.iter_mut()) { + *y = x.to_f32() + } + Ok(()) + } +} + +impl GgmlType for bf16 { + const DTYPE: GgmlDType = GgmlDType::BF16; + const BLCK_SIZE: usize = 1; + type VecDotType = bf16; + + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + if xs.len() < n { + crate::bail!("size mismatch {} < {n}", xs.len()) + } + if ys.len() < n { + crate::bail!("size mismatch {} < {n}", ys.len()) + } + let mut res = 0f32; + unsafe { crate::cpu::vec_dot_bf16(xs.as_ptr(), ys.as_ptr(), &mut res, n) }; + Ok(res) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + if xs.len() != ys.len() { + crate::bail!("size mismatch {} {}", xs.len(), ys.len()); + } + // TODO: vectorize + for (x, y) in xs.iter().zip(ys.iter_mut()) { + *y = bf16::from_f32(*x) + } + Ok(()) + } + + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + if xs.len() != ys.len() { + crate::bail!("size mismatch {} {}", xs.len(), ys.len()); + } + // TODO: vectorize + for (x, y) in xs.iter().zip(ys.iter_mut()) { + *y = x.to_f32() + } + Ok(()) + } +} diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index b133fa6295..593c6b9bd6 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -1,6 +1,6 @@ use candle_core::{ bail, - quantized::{self, GgmlDType}, + quantized::{self, quants, GgmlDType}, test_device, test_utils::to_vec2_round, DType, Device, IndexOp, Module, Result, Tensor, Var, @@ -90,7 +90,7 @@ fn quantized_matmul(device: &Device) -> Result<()> { let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8]; let rhs = (0..(k * n)).map(|v| v as f32).collect::>(); k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?; - k_quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?; + quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?; assert_eq!( dst.iter().map(|x| x.round()).collect::>(), &[ @@ -155,7 +155,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> { .collect::>(); let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?; k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?; - k_quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?; + quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?; assert_eq!( dst.iter().map(|x| x.round()).collect::>(), &[ @@ -950,66 +950,84 @@ fn quantize_q8k(device: &Device) -> Result<()> { } fn quantize_iq4_xs(device: &Device) -> Result<()> { - // let dtype = GgmlDType::IQ4_XS; - // let src = get_test_vector2(0.5, 256, device)?; - // let quant = quantized::QTensor::quantize(&src, dtype)?; - // let dst = quant.dequantize(device)?; - // let dst_f16 = quant.dequantize_f16(device)?; - // let diff = (dst.to_dtype(DType::F16)? - dst_f16)? - // .to_dtype(DType::F32)? - // .abs()? - // .sum_all()? 
- // .to_vec0::()?; - // assert_eq!(diff, 0.); - - // let src = src.to_vec1::()?; - // let dst = dst.to_vec1::()?; - // dbg!(&src[10 * 10..(10 + 1) * 10], &dst[10 * 10..(10 + 1) * 10]); - // compare_with_error(dst.as_slice(), src.as_slice(), 0.017); - - // // Test some specific values - // assert_eq!( - // [src[0], src[128], src[256], src[512], src[800], src[1023]], - // [-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344] - // ); - // let dst = round_vector(&dst); - // assert_eq!( - // [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]], - // [-0.5, -0.373, -0.25, 0.0, 0.288, 0.498] - // ); - - // let src_big = get_test_vector2(128.0, 1024, device)?; - // let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; - // let dst_big = quant_big.dequantize(device)?; - // let dst_big_f16 = quant_big.dequantize_f16(device)?; - // let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)? - // .to_dtype(DType::F32)? - // .abs()? - // .sum_all()? - // .to_vec0::()?; - // assert_eq!(diff, 0.); - - // let src_big = src_big.to_vec1::()?; - // let dst_big = dst_big.to_vec1::()?; - // compare_with_error(dst_big.as_slice(), src_big.as_slice(), 4.5); - - // ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; - - let dtype = GgmlDType::IQ4_XS; - // let tgt = Tensor::from_vec( - // (1..=256).map(|x| 1. / x as f32).collect(), - // (256,), - // &Device::Cpu, - // )?; - let tgt = Tensor::randn(0f32, 1f32, 256, &Device::Cpu)?; - let q = quantized::QTensor::quantize(&tgt, dtype)?; - let res = q.dequantize(&Device::Cpu)?; - - println!("tgt {}", tgt.narrow(0, 0, 10)?); - println!("res {}", res.narrow(0, 0, 10)?); - - let diff = (tgt - res)?.abs()?.sum_all()?.to_scalar::()?; - dbg!(&diff); + let dtype = GgmlDType::Iq4Xs; + let src = get_test_vector2(0.5, 1024, device)?; + let quant = quantized::QTensor::quantize(&src, dtype)?; + let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); + + let src = src.to_vec1::()?; + let dst = dst.to_vec1::()?; + compare_with_error(dst.as_slice(), src.as_slice(), 0.025); + + let src_big = get_test_vector2(128.0, 1024, device)?; + let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; + let dst_big = quant_big.dequantize(device)?; + let dst_big_f16 = quant_big.dequantize_f16(device)?; + let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? 
+ .to_vec0::()?; + assert_eq!(diff, 0.); + + let src_big = src_big.to_vec1::()?; + let dst_big = dst_big.to_vec1::()?; + compare_with_error(dst_big.as_slice(), src_big.as_slice(), 5.9); + + ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + + Ok(()) +} + +#[test] +fn imatrix_quantize_iq4_xs() -> Result<()> { + // let data = + // quantized::imatrix_file::load_imatrix("../Llama-3.2-3B-Instruct.imatrix").unwrap(); + // for (name, weights) in &data { + // println!("{name}, {} elems", weights.len()); + // } + // dbg!(&data["blk.0.attn_q.weight"].len()); + + let cpu = &Device::Cpu; + + let mut row_counts = 0f64; + let mut ncall = 0f64; + let mut values = Tensor::zeros((768,), DType::F32, cpu)?; + + for _ in 0..10 { + let lhs = Var::from_tensor(&Tensor::randn(0f32, 1f32, (1024, 512), cpu)?)?; + let rhs = Var::from_tensor(&Tensor::randn(0f32, 1f32, (512, 768), cpu)?)?; + let res = lhs.matmul(&rhs)?; + + // https://github.com/ggerganov/llama.cpp/blob/678d7994f4da0af3d29046be99950ac999ee9762/examples/imatrix/imatrix.cpp#L180-L186 + values = (values + res.sqr()?.sum(0)?)?; + row_counts += res.dim(0)? as f64; + ncall += 1.; + } + + // https://github.com/ggerganov/llama.cpp/blob/678d7994f4da0af3d29046be99950ac999ee9762/examples/imatrix/imatrix.cpp#L275 + let out = ((values / row_counts)? * ncall)?; + let imatrix = out.to_vec1::()?; + + let xs = Tensor::randn(0f32, 1f32, (1024, 768), cpu)?; + + let quant1 = quantized::QTensor::quantize(&xs, GgmlDType::Iq4Xs)?; + let quant2 = quantized::QTensor::quantize_imatrix(&xs, &imatrix, GgmlDType::Iq4Xs)?; + + let dequant1 = quant1.dequantize(cpu)?; + let dequant2 = quant2.dequantize(cpu)?; + + let err1 = (dequant1 - &xs)?.abs()?.mean_all()?.to_scalar::()?; + let err2 = (dequant2 - &xs)?.abs()?.mean_all()?.to_scalar::()?; + assert!(err2 < err1, "err2 {err2} > err1 {err1}"); + Ok(()) } diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs index e77697340e..bb1c1d9870 100644 --- a/candle-transformers/src/models/llama.rs +++ b/candle-transformers/src/models/llama.rs @@ -1,5 +1,8 @@ use super::with_tracing::{linear_no_bias as linear, Linear, RmsNorm}; -use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle::{ + quantized::{GgmlDType, QTensor}, + DType, Device, IndexOp, Result, Tensor, D, +}; use candle_nn::{embedding, Embedding, Module, VarBuilder}; use std::{collections::HashMap, f32::consts::PI}; @@ -276,6 +279,16 @@ impl CausalSelfAttention { let k = self.k_proj.forward(x)?; let v = self.v_proj.forward(x)?; + let q = QTensor::quantize(&q, GgmlDType::Iq4Xs)? + .dequantize(q.device())? + .to_dtype(q.dtype())?; + let k = QTensor::quantize(&k, GgmlDType::Iq4Xs)? + .dequantize(q.device())? + .to_dtype(q.dtype())?; + let v = QTensor::quantize(&v, GgmlDType::Iq4Xs)? + .dequantize(q.device())? + .to_dtype(q.dtype())?; + let q = q .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))? .transpose(1, 2)? 
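For reference, the scale search in the IQ4 quantizer above is a weighted
least-squares fit: for a fixed set of quantized levels q_j, the block error
sum_j w_j * (x_j - d * q_j)^2 is minimized at d = sum(w*q*x) / sum(w*q*q),
which is exactly the `sumqx / sumq2` update in the quantization code, and the
imatrix test works because the importance weights enter this same fit. A
minimal sketch of the closed form (hypothetical standalone helper, not part
of the patch):

    /// Closed-form weighted least-squares scale: for fixed levels `q`,
    /// sum(w * (x - d*q)^2) is minimized at d = sum(w*q*x) / sum(w*q*q).
    fn optimal_scale(x: &[f32], q: &[i8], w: &[f32]) -> f32 {
        let mut sumqx = 0f32;
        let mut sumq2 = 0f32;
        for ((&x, &q), &w) in x.iter().zip(q).zip(w) {
            let q = q as f32;
            sumqx += w * q * x;
            sumq2 += w * q * q;
        }
        if sumq2 > 0.0 {
            sumqx / sumq2
        } else {
            0.0
        }
    }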
From cabf37fecbe5b70927b373ff7e67404b6e88d350 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Sun, 2 Feb 2025 23:14:12 -0500 Subject: [PATCH 07/26] Implement neon vec dot --- candle-core/src/quantized/iq_quants/mod.rs | 87 +++++++++++++++++++--- candle-core/src/quantized/neon.rs | 61 ++++++++++++++- candle-transformers/src/models/llama.rs | 34 ++++----- 3 files changed, 152 insertions(+), 30 deletions(-) diff --git a/candle-core/src/quantized/iq_quants/mod.rs b/candle-core/src/quantized/iq_quants/mod.rs index b7281ebb8c..8099ee9b38 100644 --- a/candle-core/src/quantized/iq_quants/mod.rs +++ b/candle-core/src/quantized/iq_quants/mod.rs @@ -9,7 +9,7 @@ use super::{BlockQ8K, GgmlDType, GgmlType, QK_K}; pub const QK4_NL: usize = 32; -const KVALUES_IQ4NL: [i8; 16] = [ +pub(super) const KVALUES_IQ4NL: [i8; 16] = [ -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113, ]; @@ -98,11 +98,86 @@ impl GgmlType for BlockIQ4xs { #[allow(unreachable_code)] fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { - todo!() + // #[cfg(target_feature = "avx")] + // todo!(); + + #[cfg(target_feature = "neon")] + return super::neon::vec_dot_iq4_xs_q8k(n, xs, ys); + + Self::vec_dot_unopt(n, xs, ys) } fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { - todo!() + if n % QK_K != 0 { + bail!("n must be a multiple of QK_K"); + } + let nb = n / QK_K; + + let mut sumf = 0.0f32; + + // Loop over each block + for ibl in 0..nb { + // x[ibl], y[ibl] + let x = &xs[ibl]; + let y = &ys[ibl]; + + // Convert x.d from fp16 to fp32, then multiply by y.d + let d4d8 = x.d.to_f32() * y.d; + + // We'll track the "h" scales + let mut h = x.scales_h; + + let mut qsi = 0; // index for x.qs + let mut q8i = 0; // index for y.qs + + // so we step by 2 in Rust as well + for ib2 in (0..(QK_K / 32)).step_by(2) { + // Reproduce the logic of ls1, ls2 + let ls1 = (x.scales_l[ib2 / 2] & 0x0f) | (((h << 4) & 0x30) as u8); + let ls2 = (x.scales_l[ib2 / 2] >> 4) | (((h << 2) & 0x30) as u8); + // Then we shift h by 4 in the original code + h >>= 4; + + // Convert ls1, ls2 to "scaled" floats + let d1 = d4d8 * ((ls1 as i32) - 32) as f32; + let d2 = d4d8 * ((ls2 as i32) - 32) as f32; + + // Two sets of 16 items each + // sum of the first 16-lane block + let mut sumi1 = 0; + let mut sumi2 = 0; + + // The first pass + for j in 0..16 { + // q8i + j vs q8i + j + 16 + // qs[qsi + j] & 0xf vs (qs[qsi + j] >> 4) + sumi1 += (y.qs[q8i + j] as i32) + * KVALUES_IQ4NL[(x.qs[qsi + j] & 0x0f) as usize] as i32; + sumi2 += (y.qs[q8i + j + 16] as i32) + * KVALUES_IQ4NL[((x.qs[qsi + j] >> 4) & 0x0f) as usize] as i32; + } + sumf += d1 * ((sumi1 + sumi2) as f32); + + qsi += 16; + q8i += 32; + + // The second pass + sumi1 = 0; + sumi2 = 0; + for j in 0..16 { + sumi1 += (y.qs[q8i + j] as i32) + * KVALUES_IQ4NL[(x.qs[qsi + j] & 0x0f) as usize] as i32; + sumi2 += (y.qs[q8i + j + 16] as i32) + * KVALUES_IQ4NL[((x.qs[qsi + j] >> 4) & 0x0f) as usize] as i32; + } + sumf += d2 * ((sumi1 + sumi2) as f32); + + qsi += 16; + q8i += 32; + } + } + + Ok(sumf) } } @@ -113,7 +188,6 @@ fn quantize_iq4_xs( n_per_row: usize, quant_weights: Option<&[f32]>, ) -> Result<()> { - // Basic sanity checks, similar to the C macro GGML_ASSERT if n_per_row % QK_K != 0 { bail!("n_per_row must be multiple of QK_K = {}", QK_K); } @@ -128,7 +202,6 @@ fn quantize_iq4_xs( ); } - // We'll need some local buffers that match the usage in the C code let mut lbuf = vec![0u8; QK_K]; // L[QK_K] let mut weight = vec![0f32; 32]; // weight[32] (the 
block_size is 32) let mut scales = vec![0f32; QK_K / 32]; // scales[QK_K/32], e.g. 256/32=8 @@ -139,10 +212,8 @@ fn quantize_iq4_xs( for _row in 0..nrow { // Each row has `nblock` blocks: for ibl in 0..nblock { - // In C: block_iq4_xs * iq4 = (block_iq4_xs *)qrow; let block = &mut ys[dst_offset + ibl]; - // quant_weights? let qw = quant_weights.map(|qw_all| { let start = QK_K * ibl; &qw_all[start..start + QK_K] @@ -202,8 +273,6 @@ pub fn quantize_iq4_xs_imatrix( // 4. Outer loop over rows for _row in 0..nrow { - // In C: block_iq4_xs * iq4 = (block_iq4_xs *)qrow; - // Here: let block_slice = &mut dst[dst_offset..dst_offset + nblock]; for ibl in 0..nblock { let block = &mut dst[dst_offset + ibl]; diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs index c4d5d6f41a..e469afb835 100644 --- a/candle-core/src/quantized/neon.rs +++ b/candle-core/src/quantized/neon.rs @@ -1,7 +1,11 @@ -use super::k_quants::{ - BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K, +use super::{ + iq_quants::BlockIQ4xs, + k_quants::{ + BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, + QK_K, + }, }; -use crate::Result; +use crate::{quantized::KVALUES_IQ4NL, Result}; use byteorder::{ByteOrder, LittleEndian}; #[allow(unused_imports)] @@ -611,3 +615,54 @@ unsafe fn multiply_accum_with_scale( let p2 = vdotq_s32(q2bytes.1, q8bytes.1); vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32 } + +#[inline(always)] +pub(crate) fn vec_dot_iq4_xs_q8k(n: usize, xs: &[BlockIQ4xs], ys: &[BlockQ8K]) -> Result { + if n % QK_K != 0 { + crate::bail!("vec_dot_iq4_xs_q8k: {n} is not divisible by {QK_K}") + } + + unsafe { + let values = vld1q_s8(KVALUES_IQ4NL.as_ptr()); + let m4b = vdupq_n_u8(0x0f); + + let mut q4b = int8x16x4_t(vdupq_n_s8(0), vdupq_n_s8(0), vdupq_n_s8(0), vdupq_n_s8(0)); + + let mut sumf = 0f32; + + let nb = n / QK_K; + for ibl in 0..nb { + let mut q8 = ys[ibl].qs.as_ptr(); + let mut q4 = xs[ibl].qs.as_ptr(); + let mut h = xs[ibl].scales_h; + + let mut sumi1 = 0; + let mut sumi2 = 0; + for ib in 0..(QK_K / 64) { + let q4bits = vld1q_u8_x2(q4); + q4 = q4.add(32); + let q8b = vld1q_s8_x4(q8); + q8 = q8.add(64); + + q4b.0 = vqtbl1q_s8(values, vandq_u8(q4bits.0, m4b)); + q4b.1 = vqtbl1q_s8(values, vshrq_n_u8(q4bits.0, 4)); + q4b.2 = vqtbl1q_s8(values, vandq_u8(q4bits.1, m4b)); + q4b.3 = vqtbl1q_s8(values, vshrq_n_u8(q4bits.1, 4)); + + let prod1 = vaddq_s32(vdotq_s32(q4b.0, q8b.0), vdotq_s32(q4b.1, q8b.1)); + let prod2 = vaddq_s32(vdotq_s32(q4b.2, q8b.2), vdotq_s32(q4b.3, q8b.3)); + + let ls1 = (xs[ibl].scales_l[ib] & 0xf) as i32 | ((h << 4) & 0x30) as i32 - 32; + let ls2 = (xs[ibl].scales_l[ib] >> 4) as i32 | ((h << 2) & 0x30) as i32 - 32; + h = h >> 4; + + sumi1 += vaddvq_s32(prod1) * ls1; + sumi2 += vaddvq_s32(prod2) * ls2; + } + + sumf += xs[ibl].d.to_f32() * ys[ibl].d * (sumi1 + sumi2) as f32; + } + + Ok(sumf) + } +} diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs index bb1c1d9870..074e3e194b 100644 --- a/candle-transformers/src/models/llama.rs +++ b/candle-transformers/src/models/llama.rs @@ -1,9 +1,9 @@ -use super::with_tracing::{linear_no_bias as linear, Linear, RmsNorm}; +use super::with_tracing::RmsNorm; use candle::{ - quantized::{GgmlDType, QTensor}, + quantized::{GgmlDType, QMatMul, QTensor}, DType, Device, IndexOp, Result, Tensor, D, }; -use candle_nn::{embedding, Embedding, Module, VarBuilder}; +use candle_nn::{embedding, 
linear_no_bias as linear, Embedding, Linear, Module, VarBuilder}; use std::{collections::HashMap, f32::consts::PI}; pub const DEFAULT_MAX_SEQ_LEN: usize = 4096; @@ -228,10 +228,10 @@ impl Cache { #[derive(Debug, Clone)] struct CausalSelfAttention { - q_proj: Linear, - k_proj: Linear, - v_proj: Linear, - o_proj: Linear, + q_proj: QMatMul, + k_proj: QMatMul, + v_proj: QMatMul, + o_proj: QMatMul, num_attention_heads: usize, num_key_value_heads: usize, head_dim: usize, @@ -279,16 +279,6 @@ impl CausalSelfAttention { let k = self.k_proj.forward(x)?; let v = self.v_proj.forward(x)?; - let q = QTensor::quantize(&q, GgmlDType::Iq4Xs)? - .dequantize(q.device())? - .to_dtype(q.dtype())?; - let k = QTensor::quantize(&k, GgmlDType::Iq4Xs)? - .dequantize(q.device())? - .to_dtype(q.dtype())?; - let v = QTensor::quantize(&v, GgmlDType::Iq4Xs)? - .dequantize(q.device())? - .to_dtype(q.dtype())?; - let q = q .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))? .transpose(1, 2)? @@ -378,6 +368,14 @@ impl CausalSelfAttention { let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?; let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?; let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?; + + println!("A"); + let q_proj = QMatMul::from_qtensor(QTensor::quantize(q_proj.weight(), GgmlDType::Iq4Xs)?)?; + let k_proj = QMatMul::from_qtensor(QTensor::quantize(k_proj.weight(), GgmlDType::F32)?)?; + let v_proj = QMatMul::from_qtensor(QTensor::quantize(v_proj.weight(), GgmlDType::F32)?)?; + let o_proj = QMatMul::from_qtensor(QTensor::quantize(o_proj.weight(), GgmlDType::F32)?)?; + println!("B"); + Ok(Self { q_proj, k_proj, @@ -524,7 +522,7 @@ impl Llama { pub fn load(vb: VarBuilder, cfg: &Config) -> Result { let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?; let lm_head = if cfg.tie_word_embeddings { - Linear::from_weights(wte.embeddings().clone(), None) + Linear::new(wte.embeddings().clone(), None) } else { linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))? 
}; From 0dce868daf569ce8656ad9b0ab573dc500e1da63 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Mon, 3 Feb 2025 06:00:34 -0500 Subject: [PATCH 08/26] Small optimization --- candle-core/src/quantized/neon.rs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs index e469afb835..b598ef1e76 100644 --- a/candle-core/src/quantized/neon.rs +++ b/candle-core/src/quantized/neon.rs @@ -632,9 +632,11 @@ pub(crate) fn vec_dot_iq4_xs_q8k(n: usize, xs: &[BlockIQ4xs], ys: &[BlockQ8K]) - let nb = n / QK_K; for ibl in 0..nb { - let mut q8 = ys[ibl].qs.as_ptr(); - let mut q4 = xs[ibl].qs.as_ptr(); - let mut h = xs[ibl].scales_h; + let y_block = &ys[ibl]; + let x_block = &xs[ibl]; + let mut q8 = y_block.qs.as_ptr(); + let mut q4 = x_block.qs.as_ptr(); + let mut h = x_block.scales_h; let mut sumi1 = 0; let mut sumi2 = 0; @@ -652,8 +654,8 @@ pub(crate) fn vec_dot_iq4_xs_q8k(n: usize, xs: &[BlockIQ4xs], ys: &[BlockQ8K]) - let prod1 = vaddq_s32(vdotq_s32(q4b.0, q8b.0), vdotq_s32(q4b.1, q8b.1)); let prod2 = vaddq_s32(vdotq_s32(q4b.2, q8b.2), vdotq_s32(q4b.3, q8b.3)); - let ls1 = (xs[ibl].scales_l[ib] & 0xf) as i32 | ((h << 4) & 0x30) as i32 - 32; - let ls2 = (xs[ibl].scales_l[ib] >> 4) as i32 | ((h << 2) & 0x30) as i32 - 32; + let ls1 = (x_block.scales_l[ib] & 0xf) as i32 | ((h << 4) & 0x30) as i32 - 32; + let ls2 = (x_block.scales_l[ib] >> 4) as i32 | ((h << 2) & 0x30) as i32 - 32; h = h >> 4; sumi1 += vaddvq_s32(prod1) * ls1; From 8f111f7756eab6ca6d741f80820ce63c0289d918 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Mon, 3 Feb 2025 06:02:28 -0500 Subject: [PATCH 09/26] Clippy --- candle-core/src/quantized/iq_quants/mod.rs | 4 +--- candle-core/src/quantized/iq_quants/utils.rs | 11 +++-------- candle-core/src/quantized/neon.rs | 2 +- 3 files changed, 5 insertions(+), 12 deletions(-) diff --git a/candle-core/src/quantized/iq_quants/mod.rs b/candle-core/src/quantized/iq_quants/mod.rs index 8099ee9b38..8e4d90106b 100644 --- a/candle-core/src/quantized/iq_quants/mod.rs +++ b/candle-core/src/quantized/iq_quants/mod.rs @@ -40,9 +40,7 @@ impl GgmlType for BlockIQ4xs { } let nb = k / QK_K; - for i in 0..nb { - let block = &xs[i]; - + for block in xs.iter().take(nb) { let d = block.d.to_f32(); let mut qs = &block.qs[..]; diff --git a/candle-core/src/quantized/iq_quants/utils.rs b/candle-core/src/quantized/iq_quants/utils.rs index 228f9e7d78..206f8cb110 100644 --- a/candle-core/src/quantized/iq_quants/utils.rs +++ b/candle-core/src/quantized/iq_quants/utils.rs @@ -30,8 +30,8 @@ pub(super) fn quantize_row_iq4_nl_impl( // 1. 
compute sigma2 let mut sigma2 = 0f32; - for j in 0..super_block_size { - sigma2 += x[j] * x[j]; + for x in x.iter().take(super_block_size) { + sigma2 += x * x; } sigma2 *= 2.0 / (super_block_size as f32); @@ -155,12 +155,7 @@ pub(super) fn quantize_row_iq4_nl_impl( for ib in 0..nblocks { // l = nearest_int(id * scales[ib]), clamp to [-32..31] let mut l = (id * scales[ib]).round() as i32; - if l < -32 { - l = -32; - } - if l > 31 { - l = 31; - } + l = l.clamp(-32, 31); // refine block let dl = d * (l as f32); diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs index b598ef1e76..c5fffc9194 100644 --- a/candle-core/src/quantized/neon.rs +++ b/candle-core/src/quantized/neon.rs @@ -656,7 +656,7 @@ pub(crate) fn vec_dot_iq4_xs_q8k(n: usize, xs: &[BlockIQ4xs], ys: &[BlockQ8K]) - let ls1 = (x_block.scales_l[ib] & 0xf) as i32 | ((h << 4) & 0x30) as i32 - 32; let ls2 = (x_block.scales_l[ib] >> 4) as i32 | ((h << 2) & 0x30) as i32 - 32; - h = h >> 4; + h >>= 4; sumi1 += vaddvq_s32(prod1) * ls1; sumi2 += vaddvq_s32(prod2) * ls2; From dbe13c2873470c0a52659edfd9f449a4bd476f61 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Mon, 3 Feb 2025 08:48:42 -0500 Subject: [PATCH 10/26] Revert mask in sdpa vector kernels --- candle-metal-kernels/src/lib.rs | 19 +-- .../src/scaled_dot_product_attention.metal | 161 +++++++++--------- 2 files changed, 86 insertions(+), 94 deletions(-) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 7231640081..706e07c1e4 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -17,11 +17,7 @@ const CONV: &str = include_str!("conv.metal"); const FILL: &str = include_str!("fill.metal"); const INDEXING: &str = include_str!("indexing.metal"); // Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle -#[cfg(not(target_os = "ios"))] const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib"); -// Current source: https://github.com/philipturner/metal-flash-attention/releases/tag/v1.0.1 -#[cfg(target_os = "ios")] -const MFA: &[u8] = include_bytes!("libMetalFlashAttention.ios.metallib"); const MLX_GEMM: &str = include_str!("mlx_gemm.metal"); const QUANTIZED: &str = include_str!("quantized.metal"); const RANDOM: &str = include_str!("random.metal"); @@ -2017,12 +2013,7 @@ pub fn call_sdpa_vector( alpha }; - let constants = Some(ConstantValues::new(vec![( - 20, - Value::Bool(/* sdpa_vector_has_mask */ false), - )])); - - let pipeline = kernels.load_pipeline_with_constants(device, Source::Sdpa, &name, constants)?; + let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); @@ -2134,13 +2125,7 @@ pub fn call_sdpa_vector_2pass( alpha }; - let constants = Some(ConstantValues::new(vec![( - 20, - Value::Bool(/* sdpa_vector_has_mask */ false), - )])); - - let pipeline = - kernels.load_pipeline_with_constants(device, Source::Sdpa, &name_pass1, constants)?; + let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name_pass1)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); diff --git a/candle-metal-kernels/src/scaled_dot_product_attention.metal b/candle-metal-kernels/src/scaled_dot_product_attention.metal index a8873ad681..6afad8126b 100644 --- a/candle-metal-kernels/src/scaled_dot_product_attention.metal +++ 
b/candle-metal-kernels/src/scaled_dot_product_attention.metal @@ -299,8 +299,6 @@ struct MLXScaledDotProductAttentionParams { // ============ "mlx/backend/metal/kernels/scaled_dot_product_attention_params.sdpa_vector" -constant bool sdpa_vector_has_mask [[function_constant(20)]]; - template [[kernel]] void sdpa_vector( const device T* queries [[buffer(0)]], @@ -313,16 +311,14 @@ template const constant size_t& v_stride, const constant float& scale, const constant float& softcapping, - const device bool* mask [[function_constant(sdpa_vector_has_mask)]], - const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]], - const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int BN = 32; constexpr int BD = 32; constexpr int elem_per_thread = D / BD; - constexpr int stride = BN * D; + + const int stride = BN * D; typedef float U; @@ -340,9 +336,6 @@ template queries += head_idx * D + simd_lid * elem_per_thread; keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread; values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread; - if (sdpa_vector_has_mask) { - mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride; - } out += head_idx * D + simd_gid * elem_per_thread; // Read the query and 0 the output accumulator @@ -358,43 +351,38 @@ template // For each key for (int i = simd_gid; i < N; i += BN) { - if (!sdpa_vector_has_mask || mask[0]) { - // Read the key - for (int j = 0; j < elem_per_thread; j++) { - k[j] = keys[j]; - } + // Read the key + for (int i = 0; i < elem_per_thread; i++) { + k[i] = keys[i]; + } - // Compute the i-th score - U score = 0; - for (int j = 0; j < elem_per_thread; j++) { - score += q[j] * k[j]; - } - score = simd_sum(score); - if (softcapping != 1.) { - score = precise::tanh(score); - score = score * softcapping; - } + // Compute the i-th score + U score = 0; + for (int i = 0; i < elem_per_thread; i++) { + score += q[i] * k[i]; + } + score = simd_sum(score); + if (softcapping != 1.) { + score = precise::tanh(score); + score = score * softcapping; + } - // Update the accumulators - U new_max = max(max_score, score); - U factor = fast::exp(max_score - new_max); - U exp_score = fast::exp(score - new_max); + // Update the accumulators + U new_max = max(max_score, score); + U factor = fast::exp(max_score - new_max); + U exp_score = fast::exp(score - new_max); - max_score = new_max; - sum_exp_score = sum_exp_score * factor + exp_score; + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; - // Update the output accumulator - for (int j = 0; j < elem_per_thread; j++) { - o[j] = o[j] * factor + exp_score * values[j]; - } + // Update the output accumulator + for (int i = 0; i < elem_per_thread; i++) { + o[i] = o[i] * factor + exp_score * values[i]; } // Move the pointers to the next kv keys += stride; values += stride; - if (sdpa_vector_has_mask) { - mask += BN * mask_seq_stride; - } } // Each thread has a partial part of the output so we need to combine them. 
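Both kernels above maintain a numerically stable softmax in a single pass
over the keys: whenever a larger running maximum appears, the previously
accumulated sum and output are rescaled by exp(old_max - new_max) before the
new term is folded in. A scalar sketch of that per-key update (illustration
only, not part of the Metal source):

    /// One-pass ("online") softmax accumulation, mirroring the per-key
    /// update in sdpa_vector: rescale prior state when the running max
    /// grows, then add the current key's contribution.
    fn online_softmax_step(
        max: &mut f32,
        sum: &mut f32,
        out: &mut [f32],
        score: f32,
        value: &[f32],
    ) {
        let new_max = (*max).max(score);
        let factor = (*max - new_max).exp();
        let exp_score = (score - new_max).exp();
        *max = new_max;
        *sum = *sum * factor + exp_score;
        for (o, &v) in out.iter_mut().zip(value) {
            *o = *o * factor + exp_score * v;
        }
    }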
@@ -440,9 +428,6 @@ template const constant size_t& v_stride, const constant float& scale, const constant float& softcapping, - const device bool* mask [[function_constant(sdpa_vector_has_mask)]], - const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]], - const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { @@ -472,10 +457,6 @@ template values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D + simd_lid * elem_per_thread; out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread; - if (sdpa_vector_has_mask) { - mask += head_idx * mask_head_stride + - (block_idx * BN + simd_gid) * mask_seq_stride; - } sums += head_idx * blocks + block_idx; maxs += head_idx * blocks + block_idx; @@ -492,43 +473,75 @@ template // For each key for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) { - if (!sdpa_vector_has_mask || mask[0]) { - // Read the key - for (int i = 0; i < elem_per_thread; i++) { - k[i] = keys[i]; - } + // Read the key + for (int i = 0; i < elem_per_thread; i++) { + k[i] = keys[i]; + } - // Compute the i-th score - U score = 0; - for (int i = 0; i < elem_per_thread; i++) { - score += q[i] * k[i]; - } - score = simd_sum(score); - if (softcapping != 1.) { - score = precise::tanh(score); - score = score * softcapping; - } + // Compute the i-th score + U score = 0; + for (int i = 0; i < elem_per_thread; i++) { + score += q[i] * k[i]; + } + score = simd_sum(score); + if (softcapping != 1.) { + score = precise::tanh(score); + score = score * softcapping; + } - // Update the accumulators - U new_max = max(max_score, score); - U factor = fast::exp(max_score - new_max); - U exp_score = fast::exp(score - new_max); + // Update the accumulators + U new_max = max(max_score, score); + U factor = fast::exp(max_score - new_max); + U exp_score = fast::exp(score - new_max); - max_score = new_max; - sum_exp_score = sum_exp_score * factor + exp_score; + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; - // Update the output accumulator - for (int i = 0; i < elem_per_thread; i++) { - o[i] = o[i] * factor + exp_score * values[i]; - } + // Update the output accumulator + for (int i = 0; i < elem_per_thread; i++) { + o[i] = o[i] * factor + exp_score * values[i]; } // Move the pointers to the next kv keys += blocks * stride; values += blocks * stride; - if (sdpa_vector_has_mask) { - mask += BN * blocks * mask_seq_stride; + } + + // Each thread has a partial part of the output so we need to combine them. + + // First let's communicate the max and sum_exp + if (simd_lid == 0) { + max_scores[simd_gid] = max_score; + sum_exp_scores[simd_gid] = sum_exp_score; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9; + U new_max = simd_max(max_score); + U factor = fast::exp(max_score - new_max); + sum_exp_score = (simd_lid < BN) ? 
sum_exp_scores[simd_lid] : 0; + sum_exp_score = simd_sum(sum_exp_score * factor); + + // Write the sum and new max + if (simd_gid == 0) { + sums[0] = sum_exp_score; + maxs[0] = new_max; + } + + // Now we need to aggregate all the outputs + for (int i = 0; i < elem_per_thread; i++) { + outputs[simd_lid * BN + simd_gid] = + o[i] * fast::exp(max_scores[simd_gid] - new_max); + threadgroup_barrier(mem_flags::mem_threadgroup); + + // And write the output + if (simd_gid == 0) { + U output = outputs[simd_lid * BN]; + for (int j = 1; j < BN; j++) { + output += outputs[simd_lid * BN + j]; + } + out[i] = static_cast(output); } + threadgroup_barrier(mem_flags::mem_threadgroup); } } @@ -1656,9 +1669,6 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2); const constant size_t& v_stride, \ const constant float& scale, \ const constant float& softcapping, \ - const device bool* mask [[function_constant(sdpa_vector_has_mask)]], \ - const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]], \ - const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]], \ uint3 tid [[threadgroup_position_in_grid]], \ uint simd_gid [[simdgroup_index_in_threadgroup]], \ uint simd_lid [[thread_index_in_simdgroup]]); \ @@ -1676,9 +1686,6 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2); const constant size_t& v_stride, \ const constant float& scale, \ const constant float& softcapping, \ - const device bool* mask [[function_constant(sdpa_vector_has_mask)]], \ - const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]], \ - const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]], \ uint3 tid [[threadgroup_position_in_grid]], \ uint simd_gid [[simdgroup_index_in_threadgroup]], \ uint simd_lid [[thread_index_in_simdgroup]]); \ From 9e0f120c34550d5908b3c5a2f2301fd19875b8ca Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Mon, 3 Feb 2025 18:58:20 -0500 Subject: [PATCH 11/26] Update check_shape --- candle-core/src/quantized/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 761069656f..555714d727 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -393,7 +393,7 @@ fn check_shape(shape: &Shape, block_size: usize) -> Result<()> { if dims.is_empty() { crate::bail!("scalar tensor cannot be quantized {shape:?}") } - if dims[dims.len() - 1] % block_size != 0 { + if dims.iter().product::() % block_size != 0 { crate::bail!( "quantized tensor must have their last dim divisible by block size {shape:?} {}", block_size From 7ff6e0903fa3978bb68e72543f1b9b55d5b59a47 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Mon, 3 Feb 2025 19:02:26 -0500 Subject: [PATCH 12/26] Update check_shape --- candle-core/src/quantized/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 555714d727..761069656f 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -393,7 +393,7 @@ fn check_shape(shape: &Shape, block_size: usize) -> Result<()> { if dims.is_empty() { crate::bail!("scalar tensor cannot be quantized {shape:?}") } - if dims.iter().product::() % block_size != 0 { + if dims[dims.len() - 1] % block_size != 0 { crate::bail!( "quantized tensor must have their last dim divisible by block size {shape:?} {}", block_size From 
e4efab3243cb3620c0a31e31f910c6aeb263d081 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Mon, 3 Feb 2025 19:21:57 -0500 Subject: [PATCH 13/26] Support qtensor_from_ggml --- candle-core/src/quantized/ggml_file.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/candle-core/src/quantized/ggml_file.rs b/candle-core/src/quantized/ggml_file.rs index ea5ec02578..bae6cafdae 100644 --- a/candle-core/src/quantized/ggml_file.rs +++ b/candle-core/src/quantized/ggml_file.rs @@ -1,6 +1,6 @@ //! Support for the GGML file format. -use super::{k_quants, GgmlDType, QStorage}; +use super::{iq_quants, k_quants, GgmlDType, QStorage}; use crate::{Device, Result}; use byteorder::{LittleEndian, ReadBytesExt}; use std::collections::HashMap; @@ -184,6 +184,9 @@ pub fn qtensor_from_ggml( GgmlDType::Q6K => { from_raw_data::(raw_data, size_in_bytes, dims, device) } + GgmlDType::Iq4Xs => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"), } } From 1baec7c48d3eaccbdf73ce2a79d0dcf6e650c86a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gergely=20B=C3=A1lint?= Date: Fri, 27 Oct 2023 14:54:13 +0200 Subject: [PATCH 14/26] [AArch64] Quantized MatMul performance improvement on Arm CPUs --- candle-core/Cargo.toml | 2 + candle-core/src/lib.rs | 5 + candle-core/src/quantized/iq_quants/mod.rs | 13 + candle-core/src/quantized/k_quants/mod.rs | 179 +++ candle-core/src/quantized/neon.rs | 1593 +++++++++++++++++++- candle-core/src/quantized/quants.rs | 238 ++- 6 files changed, 1943 insertions(+), 87 deletions(-) diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 752e478d9b..cd44ec3944 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -20,6 +20,7 @@ gemm = { workspace = true } half = { workspace = true } float8 = { workspace = true } intel-mkl-src = { workspace = true, optional = true } +itertools = "0.12.1" libc = { workspace = true, optional = true } memmap2 = { workspace = true } num-traits = { workspace = true } @@ -46,6 +47,7 @@ nccl = ["cuda", "cudarc/nccl"] mkl = ["dep:libc", "dep:intel-mkl-src"] accelerate = ["dep:libc", "dep:accelerate-src"] metal = ["dep:metal", "dep:candle-metal-kernels"] +arm-nightly-feat = [] [[bench]] name = "bench_main" diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index f7571c45de..a0efa00998 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -47,6 +47,11 @@ //! - [candle-transformers](https://docs.rs/candle-transformers/). Candle implemntation of many published transformer models. //! 
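+// Nightly-only features backing the optional `arm-nightly-feat` NEON path
+// (dotprod and i8mm intrinsics, `array_chunks`, portable SIMD); builds that
+// do not enable that Cargo feature are unaffected.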
+#![cfg_attr(feature = "arm-nightly-feat", feature(stdarch_neon_dotprod))] +#![cfg_attr(feature = "arm-nightly-feat", feature(array_chunks))] +#![cfg_attr(feature = "arm-nightly-feat", feature(stdarch_neon_i8mm))] +#![cfg_attr(feature = "arm-nightly-feat", feature(portable_simd))] + #[cfg(feature = "accelerate")] mod accelerate; pub mod backend; diff --git a/candle-core/src/quantized/iq_quants/mod.rs b/candle-core/src/quantized/iq_quants/mod.rs index 8e4d90106b..faa61fc08f 100644 --- a/candle-core/src/quantized/iq_quants/mod.rs +++ b/candle-core/src/quantized/iq_quants/mod.rs @@ -32,6 +32,7 @@ impl GgmlType for BlockIQ4xs { const DTYPE: GgmlDType = GgmlDType::Iq4Xs; const BLCK_SIZE: usize = QK_K; type VecDotType = BlockQ8K; + const SUPPORTS_I8MM: bool = false; fn to_float(xs: &[Self], mut ys: &mut [f32]) -> Result<()> { let k = ys.len(); @@ -177,6 +178,18 @@ impl GgmlType for BlockIQ4xs { Ok(sumf) } + + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + crate::bail!("Unsupported block type for i8mm"); + } } fn quantize_iq4_xs( diff --git a/candle-core/src/quantized/k_quants/mod.rs b/candle-core/src/quantized/k_quants/mod.rs index 3b3b533ecb..e00fa3a906 100644 --- a/candle-core/src/quantized/k_quants/mod.rs +++ b/candle-core/src/quantized/k_quants/mod.rs @@ -138,6 +138,7 @@ impl GgmlType for BlockQ4_0 { const DTYPE: GgmlDType = GgmlDType::Q4_0; const BLCK_SIZE: usize = QK4_0; type VecDotType = BlockQ8_0; + const SUPPORTS_I8MM: bool = true; // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1525 fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { @@ -232,12 +233,28 @@ impl GgmlType for BlockQ4_0 { } Ok(sumf) } + + #[allow(unreachable_code)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + #[cfg(target_feature = "neon")] + return super::neon::i8mm_q4_0_q8_0(n, xs_0, xs_1, ys_0, ys_1); + + crate::bail!("Unsupported block type for i8mm"); + } } impl GgmlType for BlockQ4_1 { const DTYPE: GgmlDType = GgmlDType::Q4_1; const BLCK_SIZE: usize = QK4_1; type VecDotType = BlockQ8_1; + const SUPPORTS_I8MM: bool = false; fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { Self::vec_dot_unopt(n, xs, ys) @@ -327,12 +344,26 @@ impl GgmlType for BlockQ4_1 { } Ok(()) } + + #[allow(unreachable_code)] + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + crate::bail!("Unsupported block type for i8mm"); + } } impl GgmlType for BlockQ5_0 { const DTYPE: GgmlDType = GgmlDType::Q5_0; const BLCK_SIZE: usize = QK5_0; type VecDotType = BlockQ8_0; + const SUPPORTS_I8MM: bool = false; fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { let qk = Self::BLCK_SIZE; @@ -429,12 +460,25 @@ impl GgmlType for BlockQ5_0 { } Ok(()) } + + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + crate::bail!("Unsupported block type for i8mm"); + } } impl GgmlType for BlockQ5_1 { const DTYPE: GgmlDType = GgmlDType::Q5_1; const BLCK_SIZE: usize = QK5_1; type 
VecDotType = BlockQ8_1; + const SUPPORTS_I8MM: bool = false; fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { Self::vec_dot_unopt(n, xs, ys) @@ -537,12 +581,25 @@ impl GgmlType for BlockQ5_1 { } Ok(()) } + + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + crate::bail!("Unsupported block type for i8mm"); + } } impl GgmlType for BlockQ8_0 { const DTYPE: GgmlDType = GgmlDType::Q8_0; const BLCK_SIZE: usize = QK8_0; type VecDotType = BlockQ8_0; + const SUPPORTS_I8MM: bool = true; // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1619 fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { @@ -627,12 +684,29 @@ impl GgmlType for BlockQ8_0 { } Ok(sumf) } + + #[allow(unreachable_code)] + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + #[cfg(target_feature = "neon")] + return super::neon::i8mm_q8_0_q8_0(n, xs_0, xs_1, ys_0, ys_1); + + crate::bail!("Unsupported block type for i8mm"); + } } impl GgmlType for BlockQ8_1 { const DTYPE: GgmlDType = GgmlDType::Q8_1; const BLCK_SIZE: usize = QK8_1; type VecDotType = BlockQ8_1; + const SUPPORTS_I8MM: bool = false; fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { Self::vec_dot_unopt(n, xs, ys) @@ -673,12 +747,23 @@ impl GgmlType for BlockQ8_1 { fn to_float(_xs: &[Self], _ys: &mut [f32]) -> Result<()> { unimplemented!("no support for vec-dot on Q8_1") } + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + _n: usize, + _xs_0: &[Self], + _xs_1: &[Self], + _ys_0: &[Self::VecDotType], + _ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + unimplemented!("no support for i8mm matmul on Q8_1") + } } impl GgmlType for BlockQ2K { const DTYPE: GgmlDType = GgmlDType::Q2K; const BLCK_SIZE: usize = QK_K; type VecDotType = BlockQ8K; + const SUPPORTS_I8MM: bool = true; #[allow(unreachable_code)] fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { @@ -915,12 +1000,28 @@ impl GgmlType for BlockQ2K { } Ok(()) } + + #[allow(unreachable_code)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + #[cfg(target_feature = "neon")] + return super::neon::i8mm_q2k_q8k(n, xs_0, xs_1, ys_0, ys_1); + + crate::bail!("Unsupported block type for i8mm"); + } } impl GgmlType for BlockQ3K { const DTYPE: GgmlDType = GgmlDType::Q3K; const BLCK_SIZE: usize = QK_K; type VecDotType = BlockQ8K; + const SUPPORTS_I8MM: bool = false; #[allow(unreachable_code)] fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { @@ -1297,12 +1398,29 @@ impl GgmlType for BlockQ3K { Ok(()) } + + #[allow(unreachable_code)] + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + #[cfg(target_feature = "neon")] + return super::neon::i8mm_q3k_q8k(n, xs_0, xs_1, ys_0, ys_1); + + crate::bail!("Unsupported block type for i8mm"); + } } impl GgmlType for BlockQ4K { const DTYPE: GgmlDType = GgmlDType::Q4K; const BLCK_SIZE: usize = QK_K; type VecDotType = BlockQ8K; + const SUPPORTS_I8MM: bool = true; 
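+    // Editorial note (not part of the original patch): `SUPPORTS_I8MM` is the
+    // per-block-type opt-in for the 2x2 `matmul_i8mm` path; Q4K opts in because
+    // `neon::i8mm_q4k_q8k` exists. A hypothetical dispatcher might use it as:
+    //
+    //     if T::SUPPORTS_I8MM && std::arch::is_aarch64_feature_detected!("i8mm") {
+    //         let [d00, d01, d10, d11] = T::matmul_i8mm(k, x0, x1, y0, y1)?;
+    //         // scatter the 2x2 tile of dot products into dst
+    //     } else {
+    //         // fall back to per-element T::vec_dot
+    //     }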
#[allow(unreachable_code)] fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { @@ -1562,6 +1680,22 @@ impl GgmlType for BlockQ4K { } Ok(()) } + + #[allow(unreachable_code)] + #[allow(dead_code)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + #[cfg(target_feature = "neon")] + return super::neon::i8mm_q4k_q8k(n, xs_0, xs_1, ys_0, ys_1); + + crate::bail!("Unsupported block type for i8mm"); + } } // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L928 @@ -1569,6 +1703,7 @@ impl GgmlType for BlockQ5K { const DTYPE: GgmlDType = GgmlDType::Q5K; const BLCK_SIZE: usize = QK_K; type VecDotType = BlockQ8K; + const SUPPORTS_I8MM: bool = true; #[allow(unreachable_code)] fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { @@ -1873,12 +2008,29 @@ impl GgmlType for BlockQ5K { } Ok(()) } + + #[allow(unreachable_code)] + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + #[cfg(target_feature = "neon")] + return super::neon::i8mm_q5k_q8k(n, xs_0, xs_1, ys_0, ys_1); + + crate::bail!("Unsupported block type for i8mm"); + } } impl GgmlType for BlockQ6K { const DTYPE: GgmlDType = GgmlDType::Q6K; const BLCK_SIZE: usize = QK_K; type VecDotType = BlockQ8K; + const SUPPORTS_I8MM: bool = true; #[allow(unreachable_code)] fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { @@ -2143,12 +2295,29 @@ impl GgmlType for BlockQ6K { } Ok(()) } + + #[allow(unreachable_code)] + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + #[cfg(target_feature = "neon")] + return super::neon::i8mm_q6k_q8k(n, xs_0, xs_1, ys_0, ys_1); + + crate::bail!("Unsupported block type for i8mm"); + } } impl GgmlType for BlockQ8K { const DTYPE: GgmlDType = GgmlDType::Q8K; const BLCK_SIZE: usize = QK_K; type VecDotType = BlockQ8K; + const SUPPORTS_I8MM: bool = false; #[allow(unreachable_code)] fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { @@ -2235,4 +2404,14 @@ impl GgmlType for BlockQ8K { } Ok(()) } + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + _n: usize, + _xs_0: &[Self], + _xs_1: &[Self], + _ys_0: &[Self::VecDotType], + _ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + unreachable!(); + } } diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs index c5fffc9194..af388ca7ee 100644 --- a/candle-core/src/quantized/neon.rs +++ b/candle-core/src/quantized/neon.rs @@ -7,6 +7,8 @@ use super::{ }; use crate::{quantized::KVALUES_IQ4NL, Result}; use byteorder::{ByteOrder, LittleEndian}; +#[cfg(feature = "arm-nightly-feat")] +use itertools::izip; #[allow(unused_imports)] #[cfg(target_arch = "arm")] @@ -15,14 +17,8 @@ use core::arch::arm::*; #[allow(unused_imports)] #[cfg(target_arch = "aarch64")] use core::arch::aarch64::*; - -#[inline(always)] -unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t { - // TODO: dotprod - let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b)); - let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b)); - vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)) -} +#[cfg(feature = "arm-nightly-feat")] +use 
std::arch::is_aarch64_feature_detected; #[inline(always)] pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result { @@ -55,8 +51,8 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> let v1_0l = vld1q_s8(y0.qs.as_ptr()); let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16)); - let pl0 = vdotq_s32(v0_0ls, v1_0l); - let ph0 = vdotq_s32(v0_0hs, v1_0h); + let pl0 = vdotq_s32_local(vdupq_n_s32(0), v0_0ls, v1_0l); + let ph0 = vdotq_s32_local(vdupq_n_s32(0), v0_0hs, v1_0h); sumv0 = vmlaq_n_f32( sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), @@ -66,7 +62,85 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Ok(vaddvq_f32(sumv0)) } } +#[inline(always)] +#[cfg(feature = "arm-nightly-feat")] +pub(crate) fn i8mm_q4_0_q8_0( + n: usize, + xs_0: &[BlockQ4_0], + xs_1: &[BlockQ4_0], + ys_0: &[BlockQ8_0], + ys_1: &[BlockQ8_0], +) -> Result<[f32; 4]> { + let qk = QK8_0; + let nb = n / qk; + if n % QK8_0 != 0 { + crate::bail!("i8mm_q4_0_q8_0: {n} is not divisible by {qk}") + } + //let (xs_0, xs_1) = xs.split_at_mut(xs.len() / 2); + //let (ys_0, ys_1) = ys.split_at_mut(ys.len() / 2); + assert_eq!(xs_0.len(), xs_1.len()); + assert_eq!(ys_0.len(), ys_1.len()); + assert_eq!(xs_0.len(), ys_0.len()); + + unsafe { + let mut sum_f32 = vdupq_n_f32(0.0); + let m4b = vdupq_n_u8(0x0F); + let s8b = vdupq_n_s8(0x8); + + for i in 0..nb { + let x0 = &xs_0[i]; + let x1 = &xs_1[i]; + let y0 = &ys_0[i]; + let y1 = &ys_1[i]; + + let factor_00: f32 = x0.d.to_f32() * y0.d.to_f32(); + let factor_01: f32 = x1.d.to_f32() * y0.d.to_f32(); + let factor_10: f32 = x0.d.to_f32() * y1.d.to_f32(); + let factor_11: f32 = x1.d.to_f32() * y1.d.to_f32(); + + let xv0 = vld1q_u8(x0.qs.as_ptr()); //16xu8 + let xv1 = vld1q_u8(x1.qs.as_ptr()); //16xu8 + + // convert u8s to i4s so we have equal amount of row elements + // and columns elements to multiply + let xv0_0 = vreinterpretq_s8_u8(vandq_u8(xv0, m4b)); + let xv0_1 = vreinterpretq_s8_u8(vshrq_n_u8(xv0, 4)); + let xv1_0 = vreinterpretq_s8_u8(vandq_u8(xv1, m4b)); + let xv1_1 = vreinterpretq_s8_u8(vshrq_n_u8(xv1, 4)); + + // sub 8 + let xv0_0s = vsubq_s8(xv0_0, s8b); + let xv0_1s = vsubq_s8(xv0_1, s8b); + let xv1_0s = vsubq_s8(xv1_0, s8b); + let xv1_1s = vsubq_s8(xv1_1, s8b); + //end of conversion + + let yv0_0 = vld1q_s8(y0.qs.as_ptr()); //16xi8 + let yv0_1 = vld1q_s8(y0.qs.as_ptr().add(16)); // 16xi8 + let yv1_0 = vld1q_s8(y1.qs.as_ptr()); //16xi8 + let yv1_1 = vld1q_s8(y1.qs.as_ptr().add(16)); // 16xi8 + + let i8mm = i8mm_params::new(xv0_0s, xv0_1s, xv1_0s, xv1_1s, yv0_0, yv0_1, yv1_0, yv1_1); + let loop_sum_s32 = i8mm.calculate(vdupq_n_s32(0)); + + // scaling + let factor_elems: [f32; 4] = [factor_00, factor_01, factor_10, factor_11]; + let rawptr = &factor_elems as *const f32; + let factor: float32x4_t = vld1q_f32(rawptr); + let loop_sum_f32 = vcvtq_f32_s32(loop_sum_s32); + + sum_f32 = vmlaq_f32(sum_f32, loop_sum_f32, factor); + } + // extract elements of the vector register + let f0 = vgetq_lane_f32(sum_f32, 0); + let f1 = vgetq_lane_f32(sum_f32, 1); + let f2 = vgetq_lane_f32(sum_f32, 2); + let f3 = vgetq_lane_f32(sum_f32, 3); + let res: [f32; 4] = [f0, f1, f2, f3]; + Ok(res) + } +} #[inline(always)] pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result { let qk = QK8_0; @@ -87,8 +161,8 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> let y0_0 = vld1q_s8(y0.qs.as_ptr()); let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16)); - let p0 = vdotq_s32(x0_0, 
y0_0); - let p1 = vdotq_s32(x0_1, y0_1); + let p0 = vdotq_s32_local(vdupq_n_s32(0), x0_0, y0_0); + let p1 = vdotq_s32_local(vdupq_n_s32(0), x0_1, y0_1); sumv0 = vmlaq_n_f32( sumv0, @@ -99,6 +173,67 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Ok(vaddvq_f32(sumv0)) } } +#[inline(always)] +#[cfg(feature = "arm-nightly-feat")] +pub(crate) fn i8mm_q8_0_q8_0( + n: usize, + xs_0: &[BlockQ8_0], + xs_1: &[BlockQ8_0], + ys_0: &[BlockQ8_0], + ys_1: &[BlockQ8_0], +) -> Result<[f32; 4]> { + assert_eq!(xs_0.len(), xs_1.len()); + assert_eq!(ys_0.len(), ys_1.len()); + assert_eq!(xs_0.len(), ys_0.len()); + let qk = QK8_0; + if n % QK8_0 != 0 { + crate::bail!("i8mm_q8_0_q8_0: {n} is not divisible by {qk}") + } + let nb = n / QK8_0; + unsafe { + let mut sum_f32 = vdupq_n_f32(0.0); + + for i in 0..nb { + let x0 = &xs_0[i]; + let x1 = &xs_1[i]; + let y0 = &ys_0[i]; + let y1 = &ys_1[i]; + + let factor_00: f32 = x0.d.to_f32() * y0.d.to_f32(); + let factor_01: f32 = x1.d.to_f32() * y0.d.to_f32(); + let factor_10: f32 = x0.d.to_f32() * y1.d.to_f32(); + let factor_11: f32 = x1.d.to_f32() * y1.d.to_f32(); + + let xv0_0 = vld1q_s8(x0.qs.as_ptr()); + let xv0_1 = vld1q_s8(x0.qs.as_ptr().add(16)); + let xv1_0 = vld1q_s8(x1.qs.as_ptr()); + let xv1_1 = vld1q_s8(x1.qs.as_ptr().add(16)); + + let yv0_0 = vld1q_s8(y0.qs.as_ptr()); + let yv0_1 = vld1q_s8(y0.qs.as_ptr().add(16)); + let yv1_0 = vld1q_s8(y1.qs.as_ptr()); + let yv1_1 = vld1q_s8(y1.qs.as_ptr().add(16)); + + let i8mm = i8mm_params::new(xv0_0, xv0_1, xv1_0, xv1_1, yv0_0, yv0_1, yv1_0, yv1_1); + let loop_sum_s32 = i8mm.calculate(vdupq_n_s32(0)); + + // scaling + let factor_elems: [f32; 4] = [factor_00, factor_01, factor_10, factor_11]; + let rawptr = &factor_elems as *const f32; + let factor: float32x4_t = vld1q_f32(rawptr); + let loop_sum_f32 = vcvtq_f32_s32(loop_sum_s32); + + sum_f32 = vmlaq_f32(sum_f32, loop_sum_f32, factor); + } + // extract elements of the vector register + let f0 = vgetq_lane_f32(sum_f32, 0); + let f1 = vgetq_lane_f32(sum_f32, 1); + let f2 = vgetq_lane_f32(sum_f32, 2); + let f3 = vgetq_lane_f32(sum_f32, 3); + let res: [f32; 4] = [f0, f1, f2, f3]; + Ok(res) + } +} #[inline(always)] pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result { @@ -117,8 +252,8 @@ pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Res for i in (0..QK_K).step_by(16) { let xs = vld1q_s8(xs.add(i)); let ys = vld1q_s8(ys.add(i)); - let xy = vdotq_s32(xs, ys); - sum_i = vaddq_s32(sum_i, xy) + + sum_i = vdotq_s32_local(sum_i, xs, ys); } sumf += vaddvq_s32(sum_i) as f32 * scale } @@ -134,8 +269,8 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res let mut sum = 0f32; unsafe { let m4b = vdupq_n_u8(0xF); - let mone = vdupq_n_u8(3); + let mzero = vdupq_n_s32(0); for (x, y) in xs.iter().zip(ys.iter()) { let d_all = x.d.to_f32(); @@ -187,14 +322,14 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2)); let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3)); - let p0 = vdotq_s32(q6bytes_0, q8bytes.0); - let p1 = vdotq_s32(q6bytes_1, q8bytes.1); + let p0 = vdotq_s32_local(mzero, q6bytes_0, q8bytes.0); + let p1 = vdotq_s32_local(mzero, q6bytes_1, q8bytes.1); let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1; scale = scale.add(2); - let p2 = vdotq_s32(q6bytes_2, 
q8bytes.2); - let p3 = vdotq_s32(q6bytes_3, q8bytes.3); + let p2 = vdotq_s32_local(mzero, q6bytes_2, q8bytes.2); + let p3 = vdotq_s32_local(mzero, q6bytes_3, q8bytes.3); let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1; scale = scale.add(2); @@ -216,14 +351,14 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2)); let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3)); - let p0 = vdotq_s32(q6bytes_0, q8bytes.0); - let p1 = vdotq_s32(q6bytes_1, q8bytes.1); + let p0 = vdotq_s32_local(mzero, q6bytes_0, q8bytes.0); + let p1 = vdotq_s32_local(mzero, q6bytes_1, q8bytes.1); let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1; scale = scale.add(2); - let p2 = vdotq_s32(q6bytes_2, q8bytes.2); - let p3 = vdotq_s32(q6bytes_3, q8bytes.3); + let p2 = vdotq_s32_local(mzero, q6bytes_2, q8bytes.2); + let p3 = vdotq_s32_local(mzero, q6bytes_3, q8bytes.3); let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1; scale = scale.add(2); @@ -233,6 +368,274 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res } Ok(sum) } +// QK_K = 256 +#[inline(always)] +#[cfg(feature = "arm-nightly-feat")] +pub(crate) fn i8mm_q6k_q8k( + _n: usize, + xs_0: &[BlockQ6K], + xs_1: &[BlockQ6K], + ys_0: &[BlockQ8K], + ys_1: &[BlockQ8K], +) -> Result<[f32; 4]> { + unsafe { + let mut fsum = vdupq_n_f32(0.0); + let m4b = vdupq_n_u8(0xF); + let mone = vdupq_n_u8(3); + for (x0, x1, y0, y1) in izip!(xs_0, xs_1, ys_0, ys_1) { + let d_00: f32 = x0.d.to_f32() * y0.d; + let d_01: f32 = x1.d.to_f32() * y0.d; + let d_10: f32 = x0.d.to_f32() * y1.d; + let d_11: f32 = x1.d.to_f32() * y1.d; + + let mut q6_0 = x0.ql.as_ptr(); + let mut q6_1 = x1.ql.as_ptr(); + let mut qh_0 = x0.qh.as_ptr(); + let mut qh_1 = x1.qh.as_ptr(); + let mut q8_0 = y0.qs.as_ptr(); + let mut q8_1 = y1.qs.as_ptr(); + + let mut scale_0 = x0.scales.as_ptr(); + let mut scale_1 = x1.scales.as_ptr(); + + let q8sums_0 = vld1q_s16_x2(y0.bsums.as_ptr()); + let q8sums_1 = vld1q_s16_x2(y1.bsums.as_ptr()); + let scales_0 = vld1q_s8(scale_0); + let scales_1 = vld1q_s8(scale_1); + + let q6scales_0 = int16x8x2_t( + vmovl_s8(vget_low_s8(scales_0)), + vmovl_s8(vget_high_s8(scales_0)), + ); + let q6scales_1 = int16x8x2_t( + vmovl_s8(vget_low_s8(scales_1)), + vmovl_s8(vget_high_s8(scales_1)), + ); + + // y0 x0 + let prod_00 = vaddq_s32( + vaddq_s32( + vmull_s16(vget_low_s16(q8sums_0.0), vget_low_s16(q6scales_0.0)), + vmull_s16(vget_high_s16(q8sums_0.0), vget_high_s16(q6scales_0.0)), + ), + vaddq_s32( + vmull_s16(vget_low_s16(q8sums_0.1), vget_low_s16(q6scales_0.1)), + vmull_s16(vget_high_s16(q8sums_0.1), vget_high_s16(q6scales_0.1)), + ), + ); + // y0 x1 + let prod_01 = vaddq_s32( + vaddq_s32( + vmull_s16(vget_low_s16(q8sums_0.0), vget_low_s16(q6scales_1.0)), + vmull_s16(vget_high_s16(q8sums_0.0), vget_high_s16(q6scales_1.0)), + ), + vaddq_s32( + vmull_s16(vget_low_s16(q8sums_0.1), vget_low_s16(q6scales_1.1)), + vmull_s16(vget_high_s16(q8sums_0.1), vget_high_s16(q6scales_1.1)), + ), + ); + // y1 x0 + let prod_10 = vaddq_s32( + vaddq_s32( + vmull_s16(vget_low_s16(q8sums_1.0), vget_low_s16(q6scales_0.0)), + vmull_s16(vget_high_s16(q8sums_1.0), vget_high_s16(q6scales_0.0)), + ), + vaddq_s32( + vmull_s16(vget_low_s16(q8sums_1.1), 
vget_low_s16(q6scales_0.1)), + vmull_s16(vget_high_s16(q8sums_1.1), vget_high_s16(q6scales_0.1)), + ), + ); + // y1 x1 + let prod_11 = vaddq_s32( + vaddq_s32( + vmull_s16(vget_low_s16(q8sums_1.0), vget_low_s16(q6scales_1.0)), + vmull_s16(vget_high_s16(q8sums_1.0), vget_high_s16(q6scales_1.0)), + ), + vaddq_s32( + vmull_s16(vget_low_s16(q8sums_1.1), vget_low_s16(q6scales_1.1)), + vmull_s16(vget_high_s16(q8sums_1.1), vget_high_s16(q6scales_1.1)), + ), + ); + let sumi_mins_00 = vaddvq_s32(prod_00); + let sumi_mins_01 = vaddvq_s32(prod_01); + let sumi_mins_10 = vaddvq_s32(prod_10); + let sumi_mins_11 = vaddvq_s32(prod_11); + + let mut isum = vdupq_n_s32(0); + for _j in 0..QK_K / 128 { + let qhbits_0 = vld1q_u8_x2(qh_0); + let qhbits_1 = vld1q_u8_x2(qh_1); + qh_0 = qh_0.add(32); + qh_1 = qh_1.add(32); + + let q6bits_0 = vld1q_u8_x4(q6_0); + let q6bits_1 = vld1q_u8_x4(q6_1); + q6_0 = q6_0.add(64); + q6_1 = q6_1.add(64); + + let q8bytes0_0 = vld1q_s8_x4(q8_0); + let q8bytes1_0 = vld1q_s8_x4(q8_1); + q8_0 = q8_0.add(64); + q8_1 = q8_1.add(64); + + let q8bytes0_1 = vld1q_s8_x4(q8_0); + let q8bytes1_1 = vld1q_s8_x4(q8_1); + q8_0 = q8_0.add(64); + q8_1 = q8_1.add(64); + + let q6h0_0 = vshlq_n_u8(vandq_u8(mone, qhbits_0.0), 4); + let q6h0_1 = vshlq_n_u8(vandq_u8(mone, qhbits_0.1), 4); + let shifted = vshrq_n_u8(qhbits_0.0, 2); + let q6h0_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + let shifted = vshrq_n_u8(qhbits_0.1, 2); + let q6h0_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + let q6h1_0 = vshlq_n_u8(vandq_u8(mone, qhbits_1.0), 4); + let q6h1_1 = vshlq_n_u8(vandq_u8(mone, qhbits_1.1), 4); + let shifted = vshrq_n_u8(qhbits_1.0, 2); + let q6h1_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + let shifted = vshrq_n_u8(qhbits_1.1, 2); + let q6h1_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + let q6bytes0_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits_0.0, m4b), q6h0_0)); + let q6bytes0_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits_0.1, m4b), q6h0_1)); + let q6bytes0_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits_0.2, m4b), q6h0_2)); + let q6bytes0_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits_0.3, m4b), q6h0_3)); + + let q6bytes1_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits_1.0, m4b), q6h1_0)); + let q6bytes1_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits_1.1, m4b), q6h1_1)); + let q6bytes1_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits_1.2, m4b), q6h1_2)); + let q6bytes1_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits_1.3, m4b), q6h1_3)); + + let sc = i8mm_x_scales::new(&x_scales { + x00: *scale_0.add(0) as i32, + x01: *scale_0.add(1) as i32, + x10: *scale_1.add(0) as i32, + x11: *scale_1.add(1) as i32, + }); + let i8mm = i8mm_params::new( + q6bytes0_0, + q6bytes0_1, + q6bytes1_0, + q6bytes1_1, + q8bytes0_0.0, + q8bytes0_0.1, + q8bytes1_0.0, + q8bytes1_0.1, + ); + isum = i8mm.calculate_with_scales(isum, sc); + + let sc = i8mm_x_scales::new(&x_scales { + x00: *scale_0.add(2) as i32, + x01: *scale_0.add(3) as i32, + x10: *scale_1.add(2) as i32, + x11: *scale_1.add(3) as i32, + }); + let i8mm = i8mm_params::new( + q6bytes0_2, + q6bytes0_3, + q6bytes1_2, + q6bytes1_3, + q8bytes0_0.2, + q8bytes0_0.3, + q8bytes1_0.2, + q8bytes1_0.3, + ); + isum = i8mm.calculate_with_scales(isum, sc); + + scale_0 = scale_0.add(4); + scale_1 = scale_1.add(4); + + let shifted = vshrq_n_u8(qhbits_0.0, 4); + let q6h0_0 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + let shifted = vshrq_n_u8(qhbits_0.1, 4); + let q6h0_1 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + let shifted = vshrq_n_u8(qhbits_0.0, 6); + let q6h0_2 
= vshlq_n_u8(vandq_u8(mone, shifted), 4); + let shifted = vshrq_n_u8(qhbits_0.1, 6); + let q6h0_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + let shifted = vshrq_n_u8(qhbits_1.0, 4); + let q6h1_0 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + let shifted = vshrq_n_u8(qhbits_1.1, 4); + let q6h1_1 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + let shifted = vshrq_n_u8(qhbits_1.0, 6); + let q6h1_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + let shifted = vshrq_n_u8(qhbits_1.1, 6); + let q6h1_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + let q6bytes0_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits_0.0, 4), q6h0_0)); + let q6bytes0_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits_0.1, 4), q6h0_1)); + let q6bytes0_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits_0.2, 4), q6h0_2)); + let q6bytes0_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits_0.3, 4), q6h0_3)); + + let q6bytes1_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits_1.0, 4), q6h1_0)); + let q6bytes1_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits_1.1, 4), q6h1_1)); + let q6bytes1_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits_1.2, 4), q6h1_2)); + let q6bytes1_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits_1.3, 4), q6h1_3)); + + let sc = i8mm_x_scales::new(&x_scales { + x00: *scale_0.add(0) as i32, + x01: *scale_0.add(1) as i32, + x10: *scale_1.add(0) as i32, + x11: *scale_1.add(1) as i32, + }); + let i8mm = i8mm_params::new( + q6bytes0_0, + q6bytes0_1, + q6bytes1_0, + q6bytes1_1, + q8bytes0_1.0, + q8bytes0_1.1, + q8bytes1_1.0, + q8bytes1_1.1, + ); + isum = i8mm.calculate_with_scales(isum, sc); + + let sc = i8mm_x_scales::new(&x_scales { + x00: *scale_0.add(2) as i32, + x01: *scale_0.add(3) as i32, + x10: *scale_1.add(2) as i32, + x11: *scale_1.add(3) as i32, + }); + let i8mm = i8mm_params::new( + q6bytes0_2, + q6bytes0_3, + q6bytes1_2, + q6bytes1_3, + q8bytes0_1.2, + q8bytes0_1.3, + q8bytes1_1.2, + q8bytes1_1.3, + ); + isum = i8mm.calculate_with_scales(isum, sc); + + scale_0 = scale_0.add(4); + scale_1 = scale_1.add(4); + } + let factor_elems: [f32; 4] = [d_00, d_01, d_10, d_11]; + let rawptr = &factor_elems as *const f32; + let factor: float32x4_t = vld1q_f32(rawptr); + //sum += d_all * y.d * ((isum - 32 * isum_mins) as f32); + let sumi_mins_arr: [i32; 4] = [ + -sumi_mins_00 * 32, + -sumi_mins_01 * 32, + -sumi_mins_10 * 32, + -sumi_mins_11 * 32, + ]; + let rawptr = &sumi_mins_arr as *const i32; + let sumi_minsv: int32x4_t = vld1q_s32(rawptr); + fsum = vmlaq_f32(fsum, factor, vcvtq_f32_s32(vaddq_s32(sumi_minsv, isum))); + } + // extract elements of the vector register + let f0 = vgetq_lane_f32(fsum, 0); + let f1 = vgetq_lane_f32(fsum, 1); + let f2 = vgetq_lane_f32(fsum, 2); + let f3 = vgetq_lane_f32(fsum, 3); + let res: [f32; 4] = [f0, f1, f2, f3]; + Ok(res) + } +} #[inline(always)] pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result { @@ -247,6 +650,7 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res unsafe { let m4b = vdupq_n_u8(0xF); + let mzero = vdupq_n_s32(0); let mone = vdupq_n_u8(1); let mtwo = vdupq_n_u8(2); @@ -302,13 +706,13 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2)); let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3)); - let p0 = vdotq_s32(q5bytes_0, q8bytes.0); - let p1 = vdotq_s32(q5bytes_1, q8bytes.1); + let p0 = vdotq_s32_local(mzero, q5bytes_0, q8bytes.0); + let p1 = vdotq_s32_local(mzero, 
q5bytes_1, q8bytes.1); sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32; scales = scales.add(1); - let p2 = vdotq_s32(q5bytes_2, q8bytes.2); - let p3 = vdotq_s32(q5bytes_3, q8bytes.3); + let p2 = vdotq_s32_local(mzero, q5bytes_2, q8bytes.2); + let p3 = vdotq_s32_local(mzero, q5bytes_3, q8bytes.3); sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32; scales = scales.add(1); } @@ -317,6 +721,212 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res } Ok(sumf) } +#[inline(always)] +#[cfg(feature = "arm-nightly-feat")] +pub(crate) fn i8mm_q5k_q8k( + _n: usize, + xs_0: &[BlockQ5K], + xs_1: &[BlockQ5K], + ys_0: &[BlockQ8K], + ys_1: &[BlockQ8K], +) -> Result<[f32; 4]> { + const KMASK1: u32 = 0x3f3f3f3f; + const KMASK2: u32 = 0x0f0f0f0f; + const KMASK3: u32 = 0x03030303; + unsafe { + let mut sumfv = vdupq_n_f32(0.0); + let mut utmp_0 = [0u32; 4]; + let mut utmp_1 = [0u32; 4]; + let m4b = vdupq_n_u8(0xF); + let mone = vdupq_n_u8(1); + let mtwo = vdupq_n_u8(2); + let mzero = vdupq_n_s32(0); + for (x0, x1, y0, y1) in izip!(xs_0, xs_1, ys_0, ys_1) { + let d_00: f32 = x0.d.to_f32() * y0.d; + let d_01: f32 = x1.d.to_f32() * y0.d; + let d_10: f32 = x0.d.to_f32() * y1.d; + let d_11: f32 = x1.d.to_f32() * y1.d; + + let dmin_00 = -y0.d * x0.dmin.to_f32(); + let dmin_01 = -y0.d * x1.dmin.to_f32(); + let dmin_10 = -y1.d * x0.dmin.to_f32(); + let dmin_11 = -y1.d * x1.dmin.to_f32(); + + let q8sums_0 = vpaddq_s16( + vld1q_s16(y0.bsums.as_ptr()), + vld1q_s16(y0.bsums.as_ptr().add(8)), + ); + let q8sums_1 = vpaddq_s16( + vld1q_s16(y1.bsums.as_ptr()), + vld1q_s16(y1.bsums.as_ptr().add(8)), + ); + + LittleEndian::read_u32_into(&x0.scales, &mut utmp_0[0..3]); + LittleEndian::read_u32_into(&x1.scales, &mut utmp_1[0..3]); + + utmp_0[3] = ((utmp_0[2] >> 4) & KMASK2) | (((utmp_0[1] >> 6) & KMASK3) << 4); + let uaux = utmp_0[1] & KMASK1; + utmp_0[1] = (utmp_0[2] & KMASK2) | (((utmp_0[0] >> 6) & KMASK3) << 4); + utmp_0[2] = uaux; + utmp_0[0] &= KMASK1; + + utmp_1[3] = ((utmp_1[2] >> 4) & KMASK2) | (((utmp_1[1] >> 6) & KMASK3) << 4); + let uaux = utmp_1[1] & KMASK1; + utmp_1[1] = (utmp_1[2] & KMASK2) | (((utmp_1[0] >> 6) & KMASK3) << 4); + utmp_1[2] = uaux; + utmp_1[0] &= KMASK1; + + let mins8_0 = vld1_u8((utmp_0.as_ptr() as *const u8).add(8)); + let mins8_1 = vld1_u8((utmp_1.as_ptr() as *const u8).add(8)); + let mins_0 = vreinterpretq_s16_u16(vmovl_u8(mins8_0)); + let mins_1 = vreinterpretq_s16_u16(vmovl_u8(mins8_1)); + + // y0 x0 + let prod_00 = vaddq_s32( + vmull_s16(vget_low_s16(q8sums_0), vget_low_s16(mins_0)), + vmull_s16(vget_high_s16(q8sums_0), vget_high_s16(mins_0)), + ); + // y0 x1 + let prod_01 = vaddq_s32( + vmull_s16(vget_low_s16(q8sums_0), vget_low_s16(mins_1)), + vmull_s16(vget_high_s16(q8sums_0), vget_high_s16(mins_1)), + ); + // y1 x0 + let prod_10 = vaddq_s32( + vmull_s16(vget_low_s16(q8sums_1), vget_low_s16(mins_0)), + vmull_s16(vget_high_s16(q8sums_1), vget_high_s16(mins_0)), + ); + // y1 x1 + let prod_11 = vaddq_s32( + vmull_s16(vget_low_s16(q8sums_1), vget_low_s16(mins_1)), + vmull_s16(vget_high_s16(q8sums_1), vget_high_s16(mins_1)), + ); + let sumi_mins_00 = vaddvq_s32(prod_00); + let sumi_mins_01 = vaddvq_s32(prod_01); + let sumi_mins_10 = vaddvq_s32(prod_10); + let sumi_mins_11 = vaddvq_s32(prod_11); + + let mut scales_0 = utmp_0.as_ptr() as *const u8; + let mut scales_1 = utmp_1.as_ptr() as *const u8; + + let mut q5_0 = x0.qs.as_ptr(); + let mut q5_1 = x1.qs.as_ptr(); + let mut q8_0 = y0.qs.as_ptr(); + let mut q8_1 = y1.qs.as_ptr(); + + let mut qhbits_0 = 
vld1q_u8_x2(x0.qh.as_ptr()); + let mut qhbits_1 = vld1q_u8_x2(x1.qh.as_ptr()); + + let mut isum = vdupq_n_s32(0); + for _j in 0..QK_K / 64 { + let q5bits_0 = vld1q_u8_x2(q5_0); + let q5bits_1 = vld1q_u8_x2(q5_1); + q5_0 = q5_0.add(32); + q5_1 = q5_1.add(32); + let q8bytes_0 = vld1q_s8_x4(q8_0); + let q8bytes_1 = vld1q_s8_x4(q8_1); + q8_0 = q8_0.add(64); + q8_1 = q8_1.add(64); + + let q5h0_0 = vshlq_n_u8(vandq_u8(mone, qhbits_0.0), 4); + let q5h0_1 = vshlq_n_u8(vandq_u8(mone, qhbits_0.1), 4); + let q5h0_2 = vshlq_n_u8(vandq_u8(mtwo, qhbits_0.0), 3); + let q5h0_3 = vshlq_n_u8(vandq_u8(mtwo, qhbits_0.1), 3); + + let q5h1_0 = vshlq_n_u8(vandq_u8(mone, qhbits_1.0), 4); + let q5h1_1 = vshlq_n_u8(vandq_u8(mone, qhbits_1.1), 4); + let q5h1_2 = vshlq_n_u8(vandq_u8(mtwo, qhbits_1.0), 3); + let q5h1_3 = vshlq_n_u8(vandq_u8(mtwo, qhbits_1.1), 3); + + qhbits_0.0 = vshrq_n_u8(qhbits_0.0, 2); + qhbits_0.1 = vshrq_n_u8(qhbits_0.1, 2); + qhbits_1.0 = vshrq_n_u8(qhbits_1.0, 2); + qhbits_1.1 = vshrq_n_u8(qhbits_1.1, 2); + + let q5bytes0_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits_0.0, m4b), q5h0_0)); + let q5bytes0_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits_0.1, m4b), q5h0_1)); + let q5bytes0_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits_0.0, 4), q5h0_2)); + let q5bytes0_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits_0.1, 4), q5h0_3)); + + let q5bytes1_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits_1.0, m4b), q5h1_0)); + let q5bytes1_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits_1.1, m4b), q5h1_1)); + let q5bytes1_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits_1.0, 4), q5h1_2)); + let q5bytes1_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits_1.1, 4), q5h1_3)); + + let i8mm = i8mm_params::new( + q5bytes0_0, + q5bytes0_1, + q5bytes1_0, + q5bytes1_1, + q8bytes_0.0, + q8bytes_0.1, + q8bytes_1.0, + q8bytes_1.1, + ); + let i8mmres = i8mm.calculate(mzero); + + let sc_arr = [ + *scales_0 as i32, + *scales_1 as i32, + *scales_0 as i32, + *scales_1 as i32, + ]; + let rawptr = &sc_arr as *const i32; + let sc: int32x4_t = vld1q_s32(rawptr); + isum = vmlaq_s32(isum, i8mmres, sc); + + scales_0 = scales_0.add(1); + scales_1 = scales_1.add(1); + + let i8mm = i8mm_params::new( + q5bytes0_2, + q5bytes0_3, + q5bytes1_2, + q5bytes1_3, + q8bytes_0.2, + q8bytes_0.3, + q8bytes_1.2, + q8bytes_1.3, + ); + let i8mmres = i8mm.calculate(mzero); + let sc_arr = [ + *scales_0 as i32, + *scales_1 as i32, + *scales_0 as i32, + *scales_1 as i32, + ]; + let rawptr = &sc_arr as *const i32; + let sc: int32x4_t = vld1q_s32(rawptr); + isum = vmlaq_s32(isum, i8mmres, sc); + + scales_0 = scales_0.add(1); + scales_1 = scales_1.add(1); + } + let factor_elems: [f32; 4] = [d_00, d_01, d_10, d_11]; + let rawptr = &factor_elems as *const f32; + let factor: float32x4_t = vld1q_f32(rawptr); + + let dmin_arr: [f32; 4] = [dmin_00, dmin_01, dmin_10, dmin_11]; + let rawptr = &dmin_arr as *const f32; + let dminv: float32x4_t = vld1q_f32(rawptr); + + let sumi_mins_arr: [i32; 4] = [sumi_mins_00, sumi_mins_01, sumi_mins_10, sumi_mins_11]; + let rawptr = &sumi_mins_arr as *const i32; + let sumi_minsv: float32x4_t = vcvtq_f32_s32(vld1q_s32(rawptr)); + + let fsum = vcvtq_f32_s32(isum); + sumfv = vmlaq_f32(sumfv, fsum, factor); + sumfv = vmlaq_f32(sumfv, dminv, sumi_minsv); + } + // extract elements of the vector register + let f0 = vgetq_lane_f32(sumfv, 0); + let f1 = vgetq_lane_f32(sumfv, 1); + let f2 = vgetq_lane_f32(sumfv, 2); + let f3 = vgetq_lane_f32(sumfv, 3); + let res: [f32; 4] = [f0, f1, f2, f3]; + Ok(res) + } +} 
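+
+// Editorial sketch (not part of the original patch): a scalar model of a single
+// `vmmlaq_s32(acc, y, x)` step as used by the i8mm kernels above. Each 16-byte
+// register holds a row-major 2x8 matrix of i8 and the accumulator a row-major
+// 2x2 matrix of i32, so acc[r][c] += dot(row r of y, row c of x). The name
+// `smmla_reference` is ours; it is only a cross-check for the vzip/vreinterpret
+// shuffling done in `i8mm_params::new`, not part of the hot path.
+#[cfg(all(test, feature = "arm-nightly-feat"))]
+fn smmla_reference(acc: [i32; 4], y: [i8; 16], x: [i8; 16]) -> [i32; 4] {
+    let mut out = acc;
+    for r in 0..2 {
+        for c in 0..2 {
+            for k in 0..8 {
+                out[2 * r + c] += i32::from(y[8 * r + k]) * i32::from(x[8 * c + k]);
+            }
+        }
+    }
+    out
+}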
#[inline(always)] pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result { @@ -332,6 +942,7 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res unsafe { let m4b = vdupq_n_u8(0xF); + let mzero = vdupq_n_s32(0); for (x, y) in xs.iter().zip(ys.iter()) { let d = y.d * x.d.to_f32(); @@ -378,8 +989,8 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)), vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)), ); - let p0 = vdotq_s32(q4bytes.0, q8bytes.0); - let p1 = vdotq_s32(q4bytes.1, q8bytes.1); + let p0 = vdotq_s32_local(mzero, q4bytes.0, q8bytes.0); + let p1 = vdotq_s32_local(mzero, q4bytes.1, q8bytes.1); sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as i32; let q8bytes = vld1q_s8_x2(q8); @@ -388,8 +999,8 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)), vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)), ); - let p2 = vdotq_s32(q4bytes.0, q8bytes.0); - let p3 = vdotq_s32(q4bytes.1, q8bytes.1); + let p2 = vdotq_s32_local(mzero, q4bytes.0, q8bytes.0); + let p3 = vdotq_s32_local(mzero, q4bytes.1, q8bytes.1); sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32; } sumf += d * (sumi1 + sumi2) as f32; @@ -398,6 +1009,202 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res Ok(sumf) } +#[inline(always)] +#[cfg(feature = "arm-nightly-feat")] +pub(crate) fn i8mm_q4k_q8k( + n: usize, + xs_0: &[BlockQ4K], + xs_1: &[BlockQ4K], + ys_0: &[BlockQ8K], + ys_1: &[BlockQ8K], +) -> Result<[f32; 4]> { + if n % QK_K != 0 { + crate::bail!("i8mm_q4k_q8k: {n} is not divisible by {QK_K}") + } + let mut utmp_0 = [0u32; 4]; + let mut utmp_1 = [0u32; 4]; + let mut scales_0 = [0u8; 16]; + let mut scales_1 = [0u8; 16]; + const KMASK1: u32 = 0x3f3f3f3f; + const KMASK2: u32 = 0x0f0f0f0f; + const KMASK3: u32 = 0x03030303; + + unsafe { + let mut sumfv = vdupq_n_f32(0.0); + let m4b = vdupq_n_u8(0xF); + + for (x0, x1, y0, y1) in izip!(xs_0, xs_1, ys_0, ys_1) { + let d_00: f32 = x0.d.to_f32() * y0.d; + let d_01: f32 = x1.d.to_f32() * y0.d; + let d_10: f32 = x0.d.to_f32() * y1.d; + let d_11: f32 = x1.d.to_f32() * y1.d; + + let dmin_00 = x0.dmin.to_f32() * y0.d; + let dmin_01 = x1.dmin.to_f32() * y0.d; + let dmin_10 = x0.dmin.to_f32() * y1.d; + let dmin_11 = x1.dmin.to_f32() * y1.d; + + let q8sums_0 = vpaddq_s16( + vld1q_s16(y0.bsums.as_ptr()), + vld1q_s16(y0.bsums.as_ptr().add(8)), + ); + let q8sums_1 = vpaddq_s16( + vld1q_s16(y1.bsums.as_ptr()), + vld1q_s16(y1.bsums.as_ptr().add(8)), + ); + LittleEndian::read_u32_into(&x0.scales, &mut utmp_0[0..3]); + LittleEndian::read_u32_into(&x1.scales, &mut utmp_1[0..3]); + + let mins8_0 = vld1_u32( + [ + utmp_0[1] & KMASK1, + ((utmp_0[2] >> 4) & KMASK2) | (((utmp_0[1] >> 6) & KMASK3) << 4), + ] + .as_ptr(), + ); + let mins8_1 = vld1_u32( + [ + utmp_1[1] & KMASK1, + ((utmp_1[2] >> 4) & KMASK2) | (((utmp_1[1] >> 6) & KMASK3) << 4), + ] + .as_ptr(), + ); + utmp_0[1] = (utmp_0[2] & KMASK2) | (((utmp_0[0] >> 6) & KMASK3) << 4); + utmp_0[0] &= KMASK1; + + utmp_1[1] = (utmp_1[2] & KMASK2) | (((utmp_1[0] >> 6) & KMASK3) << 4); + utmp_1[0] &= KMASK1; + + let mins_0 = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8_0))); // from x0 + let mins_1 = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8_1))); // from x1 + + // y0 x0 + let prod_00 = vaddq_s32( + vmull_s16(vget_low_s16(q8sums_0), vget_low_s16(mins_0)), + 
vmull_s16(vget_high_s16(q8sums_0), vget_high_s16(mins_0)), + ); + // y0 x1 + let prod_01 = vaddq_s32( + vmull_s16(vget_low_s16(q8sums_0), vget_low_s16(mins_1)), + vmull_s16(vget_high_s16(q8sums_0), vget_high_s16(mins_1)), + ); + // y1 x0 + let prod_10 = vaddq_s32( + vmull_s16(vget_low_s16(q8sums_1), vget_low_s16(mins_0)), + vmull_s16(vget_high_s16(q8sums_1), vget_high_s16(mins_0)), + ); + // y1 x1 + let prod_11 = vaddq_s32( + vmull_s16(vget_low_s16(q8sums_1), vget_low_s16(mins_1)), + vmull_s16(vget_high_s16(q8sums_1), vget_high_s16(mins_1)), + ); + + let s = [ + -dmin_00 * vaddvq_s32(prod_00) as f32, + -dmin_01 * vaddvq_s32(prod_01) as f32, + -dmin_10 * vaddvq_s32(prod_10) as f32, + -dmin_11 * vaddvq_s32(prod_11) as f32, + ]; + let rawptr = &s as *const f32; + let sumdiff: float32x4_t = vld1q_f32(rawptr); + sumfv = vaddq_f32(sumfv, sumdiff); + + LittleEndian::write_u32_into(&utmp_0, &mut scales_0); + LittleEndian::write_u32_into(&utmp_1, &mut scales_1); + + let mut q4_0 = x0.qs.as_ptr(); + let mut q4_1 = x1.qs.as_ptr(); + let mut q8_0 = y0.qs.as_ptr(); + let mut q8_1 = y1.qs.as_ptr(); + + let mut sumi1 = vdupq_n_s32(0); + let mut sumi2 = vdupq_n_s32(0); + // 0..4 + for j in 0..QK_K / 64 { + let xv0 = vld1q_u8_x2(q4_0); + let xv0_0_original = xv0.0; + let xv0_1_original = xv0.1; + q4_0 = q4_0.add(32); + + let xv1 = vld1q_u8_x2(q4_1); + let xv1_0_original = xv1.0; + let xv1_1_original = xv1.1; + q4_1 = q4_1.add(32); + + let yv0 = vld1q_s8_x2(q8_0); + let yv0_0 = yv0.0; + let yv0_1 = yv0.1; + q8_0 = q8_0.add(32); + + let yv1 = vld1q_s8_x2(q8_1); + let yv1_0 = yv1.0; + let yv1_1 = yv1.1; + q8_1 = q8_1.add(32); + + let xv0_0 = vreinterpretq_s8_u8(vandq_u8(xv0_0_original, m4b)); + let xv0_1 = vreinterpretq_s8_u8(vandq_u8(xv0_1_original, m4b)); + let xv1_0 = vreinterpretq_s8_u8(vandq_u8(xv1_0_original, m4b)); + let xv1_1 = vreinterpretq_s8_u8(vandq_u8(xv1_1_original, m4b)); + + let i8mm = i8mm_params::new(xv0_0, xv0_1, xv1_0, xv1_1, yv0_0, yv0_1, yv1_0, yv1_1); + let p1 = i8mm.calculate(vdupq_n_s32(0)); + + // x0 | x1 + // y0 | sc_0 sc_1 + // y1 | sc_0 sc_1 + let scarr = [ + scales_0[2 * j] as i32, + scales_1[2 * j] as i32, + scales_0[2 * j] as i32, + scales_1[2 * j] as i32, + ]; + let rawptr = &scarr as *const i32; + let sc: int32x4_t = vld1q_s32(rawptr); + sumi1 = vmlaq_s32(sumi1, p1, sc); + + let yv0 = vld1q_s8_x2(q8_0); + let yv0_0 = yv0.0; + let yv0_1 = yv0.1; + q8_0 = q8_0.add(32); + let yv1 = vld1q_s8_x2(q8_1); + let yv1_0 = yv1.0; + let yv1_1 = yv1.1; + q8_1 = q8_1.add(32); + + let xv0_0 = vreinterpretq_s8_u8(vshrq_n_u8(xv0_0_original, 4)); + let xv0_1 = vreinterpretq_s8_u8(vshrq_n_u8(xv0_1_original, 4)); + let xv1_0 = vreinterpretq_s8_u8(vshrq_n_u8(xv1_0_original, 4)); + let xv1_1 = vreinterpretq_s8_u8(vshrq_n_u8(xv1_1_original, 4)); + + let i8mm = i8mm_params::new(xv0_0, xv0_1, xv1_0, xv1_1, yv0_0, yv0_1, yv1_0, yv1_1); + let p2 = i8mm.calculate(vdupq_n_s32(0)); + let sc_arr = [ + scales_0[2 * j + 1] as i32, + scales_1[2 * j + 1] as i32, + scales_0[2 * j + 1] as i32, + scales_1[2 * j + 1] as i32, + ]; + let rawptr = &sc_arr as *const i32; + let sc: int32x4_t = vld1q_s32(rawptr); + sumi2 = vmlaq_s32(sumi2, p2, sc); + } + let factor_elems: [f32; 4] = [d_00, d_01, d_10, d_11]; + let rawptr = &factor_elems as *const f32; + let factor: float32x4_t = vld1q_f32(rawptr); + + let loop_sum_f32 = vcvtq_f32_s32(vaddq_s32(sumi1, sumi2)); + sumfv = vmlaq_f32(sumfv, loop_sum_f32, factor); + } + // extract elements of the vector register + let f0 = vgetq_lane_f32(sumfv, 0); + let f1 = 
vgetq_lane_f32(sumfv, 1); + let f2 = vgetq_lane_f32(sumfv, 2); + let f3 = vgetq_lane_f32(sumfv, 3); + let res: [f32; 4] = [f0, f1, f2, f3]; + Ok(res) + } +} + #[inline(always)] pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result { if n % QK_K != 0 { @@ -411,6 +1218,7 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res unsafe { let m3b = vdupq_n_u8(0x3); + let mzero = vdupq_n_s32(0); let m0 = vdupq_n_u8(1); let m1 = vshlq_n_u8(m0, 1); let m2 = vshlq_n_u8(m0, 2); @@ -468,10 +1276,10 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res vreinterpretq_s8_u8(q3h_3), ); - let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0); - let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1); - let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2); - let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3); + let p0 = vdotq_s32_local(mzero, q3bytes_0, q8bytes_1.0); + let p1 = vdotq_s32_local(mzero, q3bytes_1, q8bytes_1.1); + let p2 = vdotq_s32_local(mzero, q3bytes_2, q8bytes_1.2); + let p3 = vdotq_s32_local(mzero, q3bytes_3, q8bytes_1.3); isum += vaddvq_s32(p0) * *scale as i32 + vaddvq_s32(p1) * *scale.add(1) as i32 + vaddvq_s32(p2) * *scale.add(2) as i32 @@ -500,10 +1308,10 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res vreinterpretq_s8_u8(q3h_3), ); - let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0); - let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1); - let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2); - let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3); + let p0 = vdotq_s32_local(mzero, q3bytes_0, q8bytes_2.0); + let p1 = vdotq_s32_local(mzero, q3bytes_1, q8bytes_2.1); + let p2 = vdotq_s32_local(mzero, q3bytes_2, q8bytes_2.2); + let p3 = vdotq_s32_local(mzero, q3bytes_3, q8bytes_2.3); isum += vaddvq_s32(p0) * *scale as i32 + vaddvq_s32(p1) * *scale.add(1) as i32 + vaddvq_s32(p2) * *scale.add(2) as i32 @@ -521,6 +1329,293 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res Ok(sumf) } +// NOTE: disabled for now, see BlockQ3K::SUPPORTS_I8MM +#[inline(always)] +#[cfg(feature = "arm-nightly-feat")] +pub(crate) fn i8mm_q3k_q8k( + n: usize, + xs_0: &[BlockQ3K], + xs_1: &[BlockQ3K], + ys_0: &[BlockQ8K], + ys_1: &[BlockQ8K], +) -> Result<[f32; 4]> { + if n % QK_K != 0 { + crate::bail!("i8mm_q3k_q8k: {n} is not divisible by {QK_K}") + } + unsafe { + let mut sumfv = vdupq_n_f32(0.0); + let mut utmp_0 = [0u32; 4]; + let mut utmp_1 = [0u32; 4]; + let mut aux_0 = [0u32; 3]; + let mut aux_1 = [0u32; 3]; + const KMASK1: u32 = 0x03030303; + const KMASK2: u32 = 0x0f0f0f0f; + + let m3b = vdupq_n_u8(0x3); + let m0 = vdupq_n_u8(1); + let m1 = vshlq_n_u8(m0, 1); + let m2 = vshlq_n_u8(m0, 2); + let m3 = vshlq_n_u8(m0, 3); + + for (x0, x1, y0, y1) in izip!(xs_0, xs_1, ys_0, ys_1) { + let d_00: f32 = x0.d.to_f32() * y0.d; + let d_01: f32 = x1.d.to_f32() * y0.d; + let d_10: f32 = x0.d.to_f32() * y1.d; + let d_11: f32 = x1.d.to_f32() * y1.d; + + let mut q3_0 = x0.qs.as_ptr(); + let mut q3_1 = x1.qs.as_ptr(); + + let qh_0 = x0.hmask.as_ptr(); + let qh_1 = x1.hmask.as_ptr(); + + let mut q8_0 = y0.qs.as_ptr(); + let mut q8_1 = y1.qs.as_ptr(); + + let mut qhbits_0 = vld1q_u8_x2(qh_0); + let mut qhbits_1 = vld1q_u8_x2(qh_1); + + let mut isum = vdupq_n_s32(0); + + // Set up scales + LittleEndian::read_u32_into(&x0.scales, &mut aux_0); + LittleEndian::read_u32_into(&x1.scales, &mut aux_1); + + utmp_0[3] = ((aux_0[1] >> 4) & KMASK2) | (((aux_0[2] >> 6) & KMASK1) << 4); + utmp_0[2] = ((aux_0[0] >> 4) & KMASK2) | (((aux_0[2] >> 4) & KMASK1) << 4); + 
utmp_0[1] = (aux_0[1] & KMASK2) | (((aux_0[2] >> 2) & KMASK1) << 4); + utmp_0[0] = (aux_0[0] & KMASK2) | ((aux_0[2] & KMASK1) << 4); + + utmp_1[3] = ((aux_1[1] >> 4) & KMASK2) | (((aux_1[2] >> 6) & KMASK1) << 4); + utmp_1[2] = ((aux_1[0] >> 4) & KMASK2) | (((aux_1[2] >> 4) & KMASK1) << 4); + utmp_1[1] = (aux_1[1] & KMASK2) | (((aux_1[2] >> 2) & KMASK1) << 4); + utmp_1[0] = (aux_1[0] & KMASK2) | ((aux_1[2] & KMASK1) << 4); + + let mut scale_0 = utmp_0.as_mut_ptr() as *mut i8; + for j in 0..16 { + *scale_0.add(j) -= 32i8 + } + let mut scale_1 = utmp_1.as_mut_ptr() as *mut i8; + for j in 0..16 { + *scale_1.add(j) -= 32i8 + } + for j in 0..QK_K / 128 { + let q3bits_0 = vld1q_u8_x2(q3_0); + let q3bits_1 = vld1q_u8_x2(q3_1); + q3_0 = q3_0.add(32); + q3_1 = q3_1.add(32); + + // "y0" + let q8bytes0_1 = vld1q_s8_x4(q8_0); + q8_0 = q8_0.add(64); + let q8bytes0_2 = vld1q_s8_x4(q8_0); + q8_0 = q8_0.add(64); + + // "y1" + let q8bytes1_1 = vld1q_s8_x4(q8_1); + q8_1 = q8_1.add(64); + let q8bytes1_2 = vld1q_s8_x4(q8_1); + q8_1 = q8_1.add(64); + + // "x0" + let q3h_0_0 = vshlq_n_u8(vbicq_u8(m0, qhbits_0.0), 2); + let q3h_0_1 = vshlq_n_u8(vbicq_u8(m0, qhbits_0.1), 2); + let q3h_0_2 = vshlq_n_u8(vbicq_u8(m1, qhbits_0.0), 1); + let q3h_0_3 = vshlq_n_u8(vbicq_u8(m1, qhbits_0.1), 1); + + // "x1" + let q3h_1_0 = vshlq_n_u8(vbicq_u8(m0, qhbits_1.0), 2); + let q3h_1_1 = vshlq_n_u8(vbicq_u8(m0, qhbits_1.1), 2); + let q3h_1_2 = vshlq_n_u8(vbicq_u8(m1, qhbits_1.0), 1); + let q3h_1_3 = vshlq_n_u8(vbicq_u8(m1, qhbits_1.1), 1); + + // "x0" + let q3bytes_0_0 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(q3bits_0.0, m3b)), + vreinterpretq_s8_u8(q3h_0_0), + ); + let q3bytes_0_1 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(q3bits_0.1, m3b)), + vreinterpretq_s8_u8(q3h_0_1), + ); + let q3bytes_0_2 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits_0.0, 2), m3b)), + vreinterpretq_s8_u8(q3h_0_2), + ); + let q3bytes_0_3 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits_0.1, 2), m3b)), + vreinterpretq_s8_u8(q3h_0_3), + ); + // "x1" + let q3bytes_1_0 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(q3bits_1.0, m3b)), + vreinterpretq_s8_u8(q3h_1_0), + ); + let q3bytes_1_1 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(q3bits_1.1, m3b)), + vreinterpretq_s8_u8(q3h_1_1), + ); + let q3bytes_1_2 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits_1.0, 2), m3b)), + vreinterpretq_s8_u8(q3h_1_2), + ); + let q3bytes_1_3 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits_1.1, 2), m3b)), + vreinterpretq_s8_u8(q3h_1_3), + ); + + /* 4 x0s, 4 x1s + * 4 y0s, 4 y1s + * 1 step of i8mm needs 2 of each + * -> 2 sets of i8mm calcs are needed + */ + let sc = i8mm_x_scales::new(&x_scales { + x00: *scale_0.add(0) as i32, + x01: *scale_0.add(1) as i32, + x10: *scale_1.add(0) as i32, + x11: *scale_1.add(1) as i32, + }); + let i8mm = i8mm_params::new( + q3bytes_0_0, + q3bytes_0_1, + q3bytes_1_0, + q3bytes_1_1, + q8bytes0_1.0, + q8bytes0_1.1, + q8bytes1_1.0, + q8bytes1_1.1, + ); + isum = i8mm.calculate_with_scales(isum, sc); + + let sc = i8mm_x_scales::new(&x_scales { + x00: *scale_0.add(2) as i32, + x01: *scale_0.add(3) as i32, + x10: *scale_1.add(2) as i32, + x11: *scale_1.add(3) as i32, + }); + let i8mm = i8mm_params::new( + q3bytes_0_2, + q3bytes_0_3, + q3bytes_1_2, + q3bytes_1_3, + q8bytes0_1.2, + q8bytes0_1.3, + q8bytes1_1.2, + q8bytes1_1.3, + ); + isum = i8mm.calculate_with_scales(isum, sc); + + scale_0 = scale_0.add(4); + scale_1 = scale_1.add(4); + + let q3h_0_0 = vbicq_u8(m2, qhbits_0.0); + let q3h_0_1 = vbicq_u8(m2, 
qhbits_0.1);
+                let q3h_0_2 = vshrq_n_u8(vbicq_u8(m3, qhbits_0.0), 1);
+                let q3h_0_3 = vshrq_n_u8(vbicq_u8(m3, qhbits_0.1), 1);
+
+                let q3h_1_0 = vbicq_u8(m2, qhbits_1.0);
+                let q3h_1_1 = vbicq_u8(m2, qhbits_1.1);
+                let q3h_1_2 = vshrq_n_u8(vbicq_u8(m3, qhbits_1.0), 1);
+                let q3h_1_3 = vshrq_n_u8(vbicq_u8(m3, qhbits_1.1), 1);
+
+                let q3bytes_0_0 = vsubq_s8(
+                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits_0.0, 4), m3b)),
+                    vreinterpretq_s8_u8(q3h_0_0),
+                );
+                let q3bytes_0_1 = vsubq_s8(
+                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits_0.1, 4), m3b)),
+                    vreinterpretq_s8_u8(q3h_0_1),
+                );
+                let q3bytes_0_2 = vsubq_s8(
+                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits_0.0, 6), m3b)),
+                    vreinterpretq_s8_u8(q3h_0_2),
+                );
+                let q3bytes_0_3 = vsubq_s8(
+                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits_0.1, 6), m3b)),
+                    vreinterpretq_s8_u8(q3h_0_3),
+                );
+
+                let q3bytes_1_0 = vsubq_s8(
+                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits_1.0, 4), m3b)),
+                    vreinterpretq_s8_u8(q3h_1_0),
+                );
+                let q3bytes_1_1 = vsubq_s8(
+                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits_1.1, 4), m3b)),
+                    vreinterpretq_s8_u8(q3h_1_1),
+                );
+                let q3bytes_1_2 = vsubq_s8(
+                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits_1.0, 6), m3b)),
+                    vreinterpretq_s8_u8(q3h_1_2),
+                );
+                let q3bytes_1_3 = vsubq_s8(
+                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits_1.1, 6), m3b)),
+                    vreinterpretq_s8_u8(q3h_1_3),
+                );
+
+                // Same as above
+                let sc = i8mm_x_scales::new(&x_scales {
+                    x00: *scale_0.add(0) as i32,
+                    x01: *scale_0.add(1) as i32,
+                    x10: *scale_1.add(0) as i32,
+                    x11: *scale_1.add(1) as i32,
+                });
+                let i8mm = i8mm_params::new(
+                    q3bytes_0_0,
+                    q3bytes_0_1,
+                    q3bytes_1_0,
+                    q3bytes_1_1,
+                    q8bytes0_2.0,
+                    q8bytes0_2.1,
+                    q8bytes1_2.0,
+                    q8bytes1_2.1,
+                );
+                isum = i8mm.calculate_with_scales(isum, sc);
+
+                let sc = i8mm_x_scales::new(&x_scales {
+                    x00: *scale_0.add(2) as i32,
+                    x01: *scale_0.add(3) as i32,
+                    x10: *scale_1.add(2) as i32,
+                    x11: *scale_1.add(3) as i32,
+                });
+                let i8mm = i8mm_params::new(
+                    q3bytes_0_2,
+                    q3bytes_0_3,
+                    q3bytes_1_2,
+                    q3bytes_1_3,
+                    q8bytes0_2.2,
+                    q8bytes0_2.3,
+                    q8bytes1_2.2,
+                    q8bytes1_2.3,
+                );
+                isum = i8mm.calculate_with_scales(isum, sc);
+
+                scale_0 = scale_0.add(4);
+                scale_1 = scale_1.add(4);
+
+                if j == 0 {
+                    qhbits_0.0 = vshrq_n_u8(qhbits_0.0, 4);
+                    qhbits_0.1 = vshrq_n_u8(qhbits_0.1, 4);
+                    qhbits_1.0 = vshrq_n_u8(qhbits_1.0, 4);
+                    qhbits_1.1 = vshrq_n_u8(qhbits_1.1, 4);
+                }
+            }
+            let factor_elems: [f32; 4] = [d_00, d_01, d_10, d_11];
+            let rawptr = &factor_elems as *const f32;
+            let factor: float32x4_t = vld1q_f32(rawptr);
+
+            let fsum = vcvtq_f32_s32(isum);
+            sumfv = vmlaq_f32(sumfv, fsum, factor);
+        }
+        // extract elements of the vector register
+        let f0 = vgetq_lane_f32(sumfv, 0);
+        let f1 = vgetq_lane_f32(sumfv, 1);
+        let f2 = vgetq_lane_f32(sumfv, 2);
+        let f3 = vgetq_lane_f32(sumfv, 3);
+        let res: [f32; 4] = [f0, f1, f2, f3];
+        Ok(res)
+    }
+}
+
 #[inline(always)]
 pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
     if n % QK_K != 0 {
@@ -564,7 +1659,6 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Res
     let mut isum = 0i32;
     let mut is = 0usize;
 
-    // TODO: dotprod
     for _j in 0..QK_K / 128 {
         let q2bits = vld1q_u8_x2(q2);
         q2 = q2.add(32);
@@ -603,6 +1697,206 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Res
     Ok(sumf)
 }
 
+#[inline(always)]
+#[cfg(feature = "arm-nightly-feat")]
+pub(crate) fn i8mm_q2k_q8k(
+    _n: usize,
+    xs_0: &[BlockQ2K],
+    xs_1: &[BlockQ2K],
+    ys_0: &[BlockQ8K],
+    ys_1: &[BlockQ8K],
+) -> Result<[f32; 4]> {
+    let mut aux_0 = [0u8;
16]; + let mut aux_1 = [0u8; 16]; + + unsafe { + let mut sumfv = vdupq_n_f32(0.0); + let m3 = vdupq_n_u8(0x3); + let m4 = vdupq_n_u8(0xF); + for (x0, x1, y0, y1) in izip!(xs_0, xs_1, ys_0, ys_1) { + let d_00: f32 = x0.d.to_f32() * y0.d; + let d_01: f32 = x1.d.to_f32() * y0.d; + let d_10: f32 = x0.d.to_f32() * y1.d; + let d_11: f32 = x1.d.to_f32() * y1.d; + + let dmin_00 = -y0.d * x0.dmin.to_f32(); + let dmin_01 = -y0.d * x1.dmin.to_f32(); + let dmin_10 = -y1.d * x0.dmin.to_f32(); + let dmin_11 = -y1.d * x1.dmin.to_f32(); + + let mut q2_0 = x0.qs.as_ptr(); + let mut q2_1 = x1.qs.as_ptr(); + let mut q8_0 = y0.qs.as_ptr(); + let mut q8_1 = y1.qs.as_ptr(); + + let sc_0 = x0.scales.as_ptr(); + let sc_1 = x1.scales.as_ptr(); + + let mins_and_scales_0 = vld1q_u8(sc_0); + let mins_and_scales_1 = vld1q_u8(sc_1); + + let scales_0 = vandq_u8(mins_and_scales_0, m4); + let scales_1 = vandq_u8(mins_and_scales_1, m4); + + vst1q_u8(aux_0.as_mut_ptr(), scales_0); + vst1q_u8(aux_1.as_mut_ptr(), scales_1); + + let mins_0 = vshrq_n_u8(mins_and_scales_0, 4); + let mins_1 = vshrq_n_u8(mins_and_scales_1, 4); + + let q8sums_0 = vld1q_s16_x2(y0.bsums.as_ptr()); + let q8sums_1 = vld1q_s16_x2(y1.bsums.as_ptr()); + + let mins16_0 = int16x8x2_t( + vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins_0))), + vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins_0))), + ); + let mins16_1 = int16x8x2_t( + vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins_1))), + vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins_1))), + ); + // x --> mins16 + // y --> q8sums + let s00l = vaddq_s32( + vmull_s16(vget_low_s16(mins16_0.0), vget_low_s16(q8sums_0.0)), + vmull_s16(vget_high_s16(mins16_0.0), vget_high_s16(q8sums_0.0)), + ); + let s00h = vaddq_s32( + vmull_s16(vget_low_s16(mins16_0.1), vget_low_s16(q8sums_0.1)), + vmull_s16(vget_high_s16(mins16_0.1), vget_high_s16(q8sums_0.1)), + ); + + // 01 -> y0 * x1 + let s01l = vaddq_s32( + vmull_s16(vget_low_s16(mins16_1.0), vget_low_s16(q8sums_0.0)), + vmull_s16(vget_high_s16(mins16_1.0), vget_high_s16(q8sums_0.0)), + ); + let s01h = vaddq_s32( + vmull_s16(vget_low_s16(mins16_1.1), vget_low_s16(q8sums_0.1)), + vmull_s16(vget_high_s16(mins16_1.1), vget_high_s16(q8sums_0.1)), + ); + + // 10 -> y1 * x0 + let s10l = vaddq_s32( + vmull_s16(vget_low_s16(mins16_0.0), vget_low_s16(q8sums_1.0)), + vmull_s16(vget_high_s16(mins16_0.0), vget_high_s16(q8sums_1.0)), + ); + let s10h = vaddq_s32( + vmull_s16(vget_low_s16(mins16_0.1), vget_low_s16(q8sums_1.1)), + vmull_s16(vget_high_s16(mins16_0.1), vget_high_s16(q8sums_1.1)), + ); + + // 11 -> y1 * x1 + let s11l = vaddq_s32( + vmull_s16(vget_low_s16(mins16_1.0), vget_low_s16(q8sums_1.0)), + vmull_s16(vget_high_s16(mins16_1.0), vget_high_s16(q8sums_1.0)), + ); + let s11h = vaddq_s32( + vmull_s16(vget_low_s16(mins16_1.1), vget_low_s16(q8sums_1.1)), + vmull_s16(vget_high_s16(mins16_1.1), vget_high_s16(q8sums_1.1)), + ); + + let sumf_elems: [f32; 4] = [ + dmin_00 * vaddvq_s32(vaddq_s32(s00l, s00h)) as f32, + dmin_01 * vaddvq_s32(vaddq_s32(s01l, s01h)) as f32, + dmin_10 * vaddvq_s32(vaddq_s32(s10l, s10h)) as f32, + dmin_11 * vaddvq_s32(vaddq_s32(s11l, s11h)) as f32, + ]; + let rawptr = &sumf_elems as *const f32; + sumfv = vaddq_f32(sumfv, vld1q_f32(rawptr)); + + let mut isum = vdupq_n_s32(0i32); + let mut is = 0usize; + + for _j in 0..QK_K / 128 { + let q2bits_0 = vld1q_u8_x2(q2_0); + q2_0 = q2_0.add(32); + let mut q2bytes_0 = int8x16x2_t( + vreinterpretq_s8_u8(vandq_u8(q2bits_0.0, m3)), + vreinterpretq_s8_u8(vandq_u8(q2bits_0.1, m3)), + ); + let q2bits_1 = 
vld1q_u8_x2(q2_1); + q2_1 = q2_1.add(32); + let mut q2bytes_1 = int8x16x2_t( + vreinterpretq_s8_u8(vandq_u8(q2bits_1.0, m3)), + vreinterpretq_s8_u8(vandq_u8(q2bits_1.1, m3)), + ); + + let q8bytes_0 = vld1q_s8_x2(q8_0); + q8_0 = q8_0.add(32); + let q8bytes_1 = vld1q_s8_x2(q8_1); + q8_1 = q8_1.add(32); + isum = vaddq_s32( + isum, + i8mm_accum_with_scale( + &aux_0, &aux_1, is, 0, q2bytes_0, q2bytes_1, q8bytes_0, q8bytes_1, + ), + ); + + q2bytes_0.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits_0.0, 2), m3)); + q2bytes_0.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits_0.1, 2), m3)); + q2bytes_1.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits_1.0, 2), m3)); + q2bytes_1.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits_1.1, 2), m3)); + let q8bytes_0 = vld1q_s8_x2(q8_0); + q8_0 = q8_0.add(32); + let q8bytes_1 = vld1q_s8_x2(q8_1); + q8_1 = q8_1.add(32); + isum = vaddq_s32( + isum, + i8mm_accum_with_scale( + &aux_0, &aux_1, is, 2, q2bytes_0, q2bytes_1, q8bytes_0, q8bytes_1, + ), + ); + + q2bytes_0.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits_0.0, 4), m3)); + q2bytes_0.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits_0.1, 4), m3)); + q2bytes_1.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits_1.0, 4), m3)); + q2bytes_1.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits_1.1, 4), m3)); + let q8bytes_0 = vld1q_s8_x2(q8_0); + q8_0 = q8_0.add(32); + let q8bytes_1 = vld1q_s8_x2(q8_1); + q8_1 = q8_1.add(32); + isum = vaddq_s32( + isum, + i8mm_accum_with_scale( + &aux_0, &aux_1, is, 4, q2bytes_0, q2bytes_1, q8bytes_0, q8bytes_1, + ), + ); + + q2bytes_0.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits_0.0, 6), m3)); + q2bytes_0.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits_0.1, 6), m3)); + q2bytes_1.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits_1.0, 6), m3)); + q2bytes_1.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits_1.1, 6), m3)); + let q8bytes_0 = vld1q_s8_x2(q8_0); + q8_0 = q8_0.add(32); + let q8bytes_1 = vld1q_s8_x2(q8_1); + q8_1 = q8_1.add(32); + isum = vaddq_s32( + isum, + i8mm_accum_with_scale( + &aux_0, &aux_1, is, 6, q2bytes_0, q2bytes_1, q8bytes_0, q8bytes_1, + ), + ); + is += 8; + } + + let factor_elems: [f32; 4] = [d_00, d_01, d_10, d_11]; + let rawptr = &factor_elems as *const f32; + let factor: float32x4_t = vld1q_f32(rawptr); + + let fsum = vcvtq_f32_s32(isum); + sumfv = vmlaq_f32(sumfv, fsum, factor); + } + // extract elements of the vector register + let f0 = vgetq_lane_f32(sumfv, 0); + let f1 = vgetq_lane_f32(sumfv, 1); + let f2 = vgetq_lane_f32(sumfv, 2); + let f3 = vgetq_lane_f32(sumfv, 3); + let res: [f32; 4] = [f0, f1, f2, f3]; + Ok(res) + } +} + #[inline(always)] unsafe fn multiply_accum_with_scale( aux: &[u8; 16], @@ -611,8 +1905,9 @@ unsafe fn multiply_accum_with_scale( q2bytes: int8x16x2_t, q8bytes: int8x16x2_t, ) -> i32 { - let p1 = vdotq_s32(q2bytes.0, q8bytes.0); - let p2 = vdotq_s32(q2bytes.1, q8bytes.1); + let mzero = vdupq_n_s32(0); + let p1 = vdotq_s32_local(mzero, q2bytes.0, q8bytes.0); + let p2 = vdotq_s32_local(mzero, q2bytes.1, q8bytes.1); vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32 } @@ -651,8 +1946,10 @@ pub(crate) fn vec_dot_iq4_xs_q8k(n: usize, xs: &[BlockIQ4xs], ys: &[BlockQ8K]) - q4b.2 = vqtbl1q_s8(values, vandq_u8(q4bits.1, m4b)); q4b.3 = vqtbl1q_s8(values, vshrq_n_u8(q4bits.1, 4)); - let prod1 = vaddq_s32(vdotq_s32(q4b.0, q8b.0), vdotq_s32(q4b.1, q8b.1)); - let prod2 = vaddq_s32(vdotq_s32(q4b.2, q8b.2), vdotq_s32(q4b.3, q8b.3)); + let prod1 = + 
vdotq_s32_local(vdotq_s32_local(vdupq_n_s32(0), q4b.0, q8b.0), q4b.1, q8b.1); + let prod2 = + vdotq_s32_local(vdotq_s32_local(vdupq_n_s32(0), q4b.2, q8b.2), q4b.3, q8b.3); let ls1 = (x_block.scales_l[ib] & 0xf) as i32 | ((h << 4) & 0x30) as i32 - 32; let ls2 = (x_block.scales_l[ib] >> 4) as i32 | ((h << 2) & 0x30) as i32 - 32; @@ -668,3 +1965,211 @@ pub(crate) fn vec_dot_iq4_xs_q8k(n: usize, xs: &[BlockIQ4xs], ys: &[BlockQ8K]) - Ok(sumf) } } + +#[inline(always)] +#[cfg(feature = "arm-nightly-feat")] +unsafe fn i8mm_accum_with_scale( + aux_0: &[u8; 16], + aux_1: &[u8; 16], + is: usize, + index: usize, + q2bytes_0: int8x16x2_t, + q2bytes_1: int8x16x2_t, + q8bytes_0: int8x16x2_t, + q8bytes_1: int8x16x2_t, +) -> int32x4_t { + let mzero = vdupq_n_s32(0); + + let c00 = aux_0[is + index] as i32; + let c01 = aux_0[is + index + 1] as i32; + let c10 = aux_1[is + index] as i32; + let c11 = aux_1[is + index + 1] as i32; + + let x00 = q2bytes_0.0; + let x01 = q2bytes_0.1; + let x10 = q2bytes_1.0; + let x11 = q2bytes_1.1; + + let y00 = q8bytes_0.0; + let y01 = q8bytes_0.1; + let y10 = q8bytes_1.0; + let y11 = q8bytes_1.1; + + let x_sc = x_scales { + x00: c00, + x01: c01, + x10: c10, + x11: c11, + }; + let i8mm_sc = i8mm_x_scales::new(&x_sc); + let mm = i8mm_params::new(x00, x01, x10, x11, y00, y01, y10, y11); + mm.calculate_with_scales(mzero, i8mm_sc) +} +#[allow(non_camel_case_types)] +#[cfg(feature = "arm-nightly-feat")] +struct i8mm_params { + x0: int8x16_t, + x1: int8x16_t, + x2: int8x16_t, + x3: int8x16_t, + y0: int8x16_t, + y1: int8x16_t, + y2: int8x16_t, + y3: int8x16_t, +} + +#[allow(non_camel_case_types)] +#[cfg(feature = "arm-nightly-feat")] +/// scales from scalar version +struct x_scales { + x00: i32, + x01: i32, + x10: i32, + x11: i32, +} +#[allow(non_camel_case_types)] +#[cfg(feature = "arm-nightly-feat")] +/// scales reorganized to fit i8mm calculations +struct i8mm_x_scales { + sc0: int32x4_t, + sc1: int32x4_t, +} + +#[cfg(feature = "arm-nightly-feat")] +impl i8mm_x_scales { + #[inline(always)] + unsafe fn new(sc: &x_scales) -> Self { + let v00 = vdupq_n_s32(sc.x00); + let v01 = vdupq_n_s32(sc.x01); + let v10 = vdupq_n_s32(sc.x10); + let v11 = vdupq_n_s32(sc.x11); + + let sc0 = vzip1q_s32(v00, v10); + let sc1 = vzip1q_s32(v01, v11); + + i8mm_x_scales { sc0, sc1 } + } +} + +#[cfg(feature = "arm-nightly-feat")] +impl i8mm_params { + #[inline(always)] + unsafe fn new( + xv0_0: int8x16_t, + xv0_1: int8x16_t, + xv1_0: int8x16_t, + xv1_1: int8x16_t, + yv0_0: int8x16_t, + yv0_1: int8x16_t, + yv1_0: int8x16_t, + yv1_1: int8x16_t, + ) -> Self { + // 1. 16xi8 -> 2xi64 + let xv0_0 = vreinterpretq_s64_s8(xv0_0); + let xv0_1 = vreinterpretq_s64_s8(xv0_1); + let xv1_0 = vreinterpretq_s64_s8(xv1_0); + let xv1_1 = vreinterpretq_s64_s8(xv1_1); + + let yv0_0 = vreinterpretq_s64_s8(yv0_0); + let yv0_1 = vreinterpretq_s64_s8(yv0_1); + let yv1_0 = vreinterpretq_s64_s8(yv1_0); + let yv1_1 = vreinterpretq_s64_s8(yv1_1); + + // 2. ZIP + let x0_0 = vzip1q_s64(xv0_0, xv1_0); + let x0_1 = vzip2q_s64(xv0_0, xv1_0); + let x1_0 = vzip1q_s64(xv0_1, xv1_1); + let x1_1 = vzip2q_s64(xv0_1, xv1_1); + + let y0_0 = vzip1q_s64(yv0_0, yv1_0); + let y0_1 = vzip2q_s64(yv0_0, yv1_0); + let y1_0 = vzip1q_s64(yv0_1, yv1_1); + let y1_1 = vzip2q_s64(yv0_1, yv1_1); + + // 3. 
interpret back + let x0_0 = vreinterpretq_s8_s64(x0_0); + let x0_1 = vreinterpretq_s8_s64(x0_1); + let x1_0 = vreinterpretq_s8_s64(x1_0); + let x1_1 = vreinterpretq_s8_s64(x1_1); + + let y0_0 = vreinterpretq_s8_s64(y0_0); + let y0_1 = vreinterpretq_s8_s64(y0_1); + let y1_0 = vreinterpretq_s8_s64(y1_0); + let y1_1 = vreinterpretq_s8_s64(y1_1); + + i8mm_params { + x0: x0_0, + x1: x0_1, + x2: x1_0, + x3: x1_1, + y0: y0_0, + y1: y0_1, + y2: y1_0, + y3: y1_1, + } + } + + #[inline(always)] + unsafe fn calculate(&self, acc: int32x4_t) -> int32x4_t { + if is_aarch64_feature_detected!("i8mm") { + self.impl_calc(acc) + } else { + // never takes this branch, but the check is needed + // for inlining the vmmlaq intrinsics + // see: + // https://community.arm.com/arm-community-blogs/b/architectures-and-processors-blog/posts/rust-neon-intrinsics + unreachable!(); + } + } + unsafe fn impl_calc(&self, acc: int32x4_t) -> int32x4_t { + let mut a = acc; + a = vmmlaq_s32(a, self.y0, self.x0); + a = vmmlaq_s32(a, self.y1, self.x1); + a = vmmlaq_s32(a, self.y2, self.x2); + vmmlaq_s32(a, self.y3, self.x3) + } + + unsafe fn calculate_with_scales(&self, acc: int32x4_t, scales: i8mm_x_scales) -> int32x4_t { + if is_aarch64_feature_detected!("i8mm") { + self.impl_calc_scales(acc, scales) + } else { + // never takes this branch, but the check is needed + // for inlining the vmmlaq intrinsics + // see: + // https://community.arm.com/arm-community-blogs/b/architectures-and-processors-blog/posts/rust-neon-intrinsics + unreachable!(); + } + } + #[inline(always)] + unsafe fn impl_calc_scales(&self, acc: int32x4_t, scales: i8mm_x_scales) -> int32x4_t { + let mzero = vdupq_n_s32(0); + let a = vmulq_s32(vmmlaq_s32(mzero, self.y0, self.x0), scales.sc0); + let b = vmulq_s32(vmmlaq_s32(mzero, self.y1, self.x1), scales.sc0); + let c = vmulq_s32(vmmlaq_s32(mzero, self.y2, self.x2), scales.sc1); + let d = vmulq_s32(vmmlaq_s32(mzero, self.y3, self.x3), scales.sc1); + + let mut sum; + sum = vaddq_s32(acc, a); + sum = vaddq_s32(sum, b); + sum = vaddq_s32(sum, c); + sum = vaddq_s32(sum, d); + sum + } +} + +#[inline(always)] +#[cfg(feature = "arm-nightly-feat")] +unsafe fn vdotq_s32_local(vz: int32x4_t, a: int8x16_t, b: int8x16_t) -> int32x4_t { + if is_aarch64_feature_detected!("dotprod") { + vdotq_s32(vz, a, b) + } else { + unreachable!(); + } +} +#[inline(always)] +#[cfg(not(feature = "arm-nightly-feat"))] +unsafe fn vdotq_s32_local(vz: int32x4_t, a: int8x16_t, b: int8x16_t) -> int32x4_t { + let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b)); + let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b)); + vaddq_s32(vz, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))) +} diff --git a/candle-core/src/quantized/quants.rs b/candle-core/src/quantized/quants.rs index db3230271f..77e95632fe 100644 --- a/candle-core/src/quantized/quants.rs +++ b/candle-core/src/quantized/quants.rs @@ -7,6 +7,7 @@ pub trait GgmlType: Sized + Clone + Send + Sync { const DTYPE: GgmlDType; const BLCK_SIZE: usize; type VecDotType: GgmlType; + const SUPPORTS_I8MM: bool; // This is only safe for types that include immediate values such as float/int/... fn zeros() -> Self { @@ -32,56 +33,24 @@ pub trait GgmlType: Sized + Clone + Send + Sync { /// Generic implementation of the dot product without simd optimizations. 
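    /// The SIMD entry points fall back to this path: `vec_dot` routes here
    /// whenever no optimized implementation (e.g. neon) is compiled in, so
    /// every block type must supply a correct scalar version.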
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result; -} - -// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L10605 -pub fn matmul( - mkn: (usize, usize, usize), - lhs: &[f32], - rhs_t: &[T], - dst: &mut [f32], -) -> Result<()> { - let (m, k, n) = mkn; - if m * k != lhs.len() { - crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len()); - } - let k_in_lhs_blocks = k.div_ceil(T::BLCK_SIZE); - let k_in_rhs_blocks = k.div_ceil(T::VecDotType::BLCK_SIZE); - // TODO: Do not make this copy if the DotType is f32. - // TODO: Pre-allocate this. - let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks]; - for row_idx in 0..m { - let lhs_b = &mut lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; - let lhs = &lhs[row_idx * k..(row_idx + 1) * k]; - T::VecDotType::from_float(lhs, lhs_b)? - } - let lhs_b = lhs_b.as_slice(); - - for row_idx in 0..m { - let lhs_row = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; - let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n]; - - let result: Result> = dst_row - .into_par_iter() - .enumerate() - .with_min_len(128) - .with_max_len(512) - .map(|(col_idx, dst)| { - let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks]; - T::vec_dot(k, rhs_col, lhs_row).map(|value| *dst = value) - }) - .collect(); - - result?; - } - Ok(()) + /// Multiply 2 rows by 2 columns and return a 2x2 matrix + /// based on aarch64 NEON i8mm instructions + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]>; } impl GgmlType for f32 { const DTYPE: GgmlDType = GgmlDType::F32; const BLCK_SIZE: usize = 1; type VecDotType = f32; + const SUPPORTS_I8MM: bool = false; fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { Self::vec_dot_unopt(n, xs, ys) @@ -114,12 +83,25 @@ impl GgmlType for f32 { ys.copy_from_slice(xs); Ok(()) } + + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + crate::bail!("Unsupported block type for i8mm"); + } } impl GgmlType for f16 { const DTYPE: GgmlDType = GgmlDType::F16; const BLCK_SIZE: usize = 1; type VecDotType = f16; + const SUPPORTS_I8MM: bool = false; fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { Self::vec_dot_unopt(n, xs, ys) @@ -158,12 +140,25 @@ impl GgmlType for f16 { } Ok(()) } + + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + crate::bail!("Unsupported block type for i8mm"); + } } impl GgmlType for bf16 { const DTYPE: GgmlDType = GgmlDType::BF16; const BLCK_SIZE: usize = 1; type VecDotType = bf16; + const SUPPORTS_I8MM: bool = false; fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { Self::vec_dot_unopt(n, xs, ys) @@ -202,4 +197,161 @@ impl GgmlType for bf16 { } Ok(()) } + + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + crate::bail!("Unsupported block type for i8mm"); + } +} + +fn matmul_ref( + mkn: (usize, usize, usize), + lhs: &[f32], + rhs_t: &[T], + dst: &mut 
[f32], +) -> Result<()> { + let (m, k, n) = mkn; + if m * k != lhs.len() { + crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len()); + } + + let k_in_lhs_blocks = k.div_ceil(T::BLCK_SIZE); + let k_in_rhs_blocks = k.div_ceil(T::VecDotType::BLCK_SIZE); + // TODO: Do not make this copy if the DotType is f32. + // TODO: Pre-allocate this. + let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks]; + for row_idx in 0..m { + let lhs_b = &mut lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; + let lhs = &lhs[row_idx * k..(row_idx + 1) * k]; + T::VecDotType::from_float(lhs, lhs_b)? + } + let lhs_b = lhs_b.as_slice(); + + for row_idx in 0..m { + let lhs_row = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; + let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n]; + + let result: Result> = dst_row + .into_par_iter() + .enumerate() + .with_min_len(128) + .with_max_len(512) + .map(|(col_idx, dst)| { + let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks]; + T::vec_dot(k, rhs_col, lhs_row).map(|value| *dst = value) + }) + .collect(); + + result?; + } + Ok(()) +} + +// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L10605 +#[cfg(not(all(feature = "arm-nightly-feat", target_feature = "i8mm")))] +pub fn matmul( + mkn: (usize, usize, usize), + lhs: &[f32], + rhs_t: &[T], + dst: &mut [f32], +) -> Result<()> { + matmul_ref(mkn, lhs, rhs_t, dst) +} + +#[cfg(all(feature = "arm-nightly-feat", target_feature = "i8mm"))] +pub fn matmul( + mkn: (usize, usize, usize), + lhs: &[f32], + rhs_t: &[T], + dst: &mut [f32], +) -> Result<()> { + if !T::SUPPORTS_I8MM { + return matmul_ref(mkn, lhs, rhs_t, dst); + } + + let (m, k, n) = mkn; + if m * k != lhs.len() { + crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len()); + } + + let k_in_lhs_blocks = (k + T::BLCK_SIZE - 1) / T::BLCK_SIZE; + let k_in_rhs_blocks = (k + T::VecDotType::BLCK_SIZE - 1) / T::VecDotType::BLCK_SIZE; + // TODO: Do not make this copy if the DotType is f32. + // TODO: Pre-allocate this. + let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks]; + for row_idx in 0..m { + let lhs_b = &mut lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; + let lhs = &lhs[row_idx * k..(row_idx + 1) * k]; + T::VecDotType::from_float(lhs, lhs_b)? 
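+        // Each pass quantizes one row of the f32 activations into the
+        // kernel's dot-product block format (`T::VecDotType`, e.g. BlockQ8K
+        // for the k-quants) exactly once, so the 2x2-tiled i8mm loop below
+        // reuses the same quantized row for every pair of output columns.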
+ } + let lhs_b = lhs_b.as_slice(); + + let m_even = m % 2 == 0; + let m_limit = if m_even { m } else { m - 1 }; + let n_even = n % 2 == 0; + let n_limit = if n_even { n } else { n - 1 }; + + for row_idx in (0..m_limit).step_by(2) { + let lhs_row_0 = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; + let lhs_row_1 = &lhs_b[(row_idx + 1) * k_in_lhs_blocks..(row_idx + 2) * k_in_lhs_blocks]; + + let dst_2_rows = &mut dst[row_idx * n..(row_idx + 2) * n]; + let (dst_row_0, dst_row_1) = dst_2_rows.split_at_mut(dst_2_rows.len() / 2); + + let dst_row_0_n = &mut dst_row_0[0..n_limit]; + let dst_row_1_n = &mut dst_row_1[0..n_limit]; + + let _result: Vec<_> = dst_row_0_n + .par_chunks_mut(2) + .zip(dst_row_1_n.par_chunks_mut(2)) + .enumerate() + .with_min_len(128) + .with_max_len(512) + .map(|(half_of_col_idx, (dst_0, dst_1))| { + let col_idx = half_of_col_idx * 2; // each step has 2 columns + let rhs_col_0 = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks]; + let rhs_col_1 = + &rhs_t[(col_idx + 1) * k_in_rhs_blocks..(col_idx + 2) * k_in_rhs_blocks]; + + T::matmul_i8mm(k, rhs_col_0, rhs_col_1, lhs_row_0, lhs_row_1).map(|mm| { + dst_0[0] = mm[0]; + dst_0[1] = mm[1]; + dst_1[0] = mm[2]; + dst_1[1] = mm[3]; + }) + }) + .collect(); + if !n_even { + let col_idx = n - 1; + let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks]; + dst_row_0[col_idx] = T::vec_dot(k, rhs_col, lhs_row_0).unwrap(); + dst_row_1[col_idx] = T::vec_dot(k, rhs_col, lhs_row_1).unwrap(); + } + } + if !m_even { + let row_idx = m - 1; + let lhs_row = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; + + let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n]; + let result: Result> = dst_row + .into_par_iter() + .enumerate() + .with_min_len(128) + .with_max_len(512) + .map(|(col_idx, dst)| { + let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks]; + T::vec_dot(k, rhs_col, lhs_row).map(|value| *dst = value) + }) + .collect(); + + result?; + } + Ok(()) } From 25cbfcaf7a3da869f8e7f61d0cb75b78d0cc92c5 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Wed, 5 Feb 2025 21:21:16 -0500 Subject: [PATCH 15/26] Add metal support --- candle-core/src/quantized/metal.rs | 2 +- candle-core/tests/quantized_tests.rs | 36 +++++++++++++++++++++++++++- candle-metal-kernels/src/lib.rs | 29 +++++++++++++++------- 3 files changed, 57 insertions(+), 10 deletions(-) diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs index 511a5b6ae2..d4bf2afeda 100644 --- a/candle-core/src/quantized/metal.rs +++ b/candle-core/src/quantized/metal.rs @@ -391,7 +391,7 @@ impl From for candle_metal_kernels::GgmlDType { GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K, GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K, GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K, - GgmlDType::Iq4Xs => candle_metal_kernels::GgmlDType::Q8_0, + GgmlDType::Iq4Xs => candle_metal_kernels::GgmlDType::Iq4Xs, GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16, GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32, GgmlDType::BF16 => candle_metal_kernels::GgmlDType::F16, diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index 593c6b9bd6..bce0840f84 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -5,7 +5,7 @@ use candle_core::{ test_utils::to_vec2_round, DType, Device, IndexOp, Module, Result, Tensor, Var, }; -use quantized::{k_quants, GgmlType}; +use 
quantized::{iq_quants, k_quants, GgmlType}; use rand::prelude::*; const GGML_TEST_SIZE: usize = 32 * 128; @@ -1117,6 +1117,7 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result { GgmlDType::Q5_0 => 0.001353, GgmlDType::Q5_1 => 0.00149, GgmlDType::Q8_0 => 0.000092, + GgmlDType::Iq4Xs => 0.001903, // Not from the ggml repo. GgmlDType::Q8K => 0.00065, @@ -1286,6 +1287,13 @@ quantized_matmul!( quantized_matmul_q3k_metal, GgmlDType::Q3K ); +quantized_matmul!( + quantized_matmul_iq4xs_bis, + quantized_matmul_iq4xs_cpu, + quantized_matmul_iq4xs_cuda, + quantized_matmul_iq4xs_metal, + GgmlDType::Q4K +); quantized_matmul!( quantized_matmul_q4k_bis, quantized_matmul_q4k_cpu, @@ -1394,6 +1402,32 @@ fn quantized_matmul_q4k() -> Result<()> { Ok(()) } +#[test] +fn quantized_matmul_iq4xs() -> Result<()> { + use iq_quants::BlockIQ4xs; + + let cpu = &Device::Cpu; + let (m, k, n) = (11, 512, 21); + let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?; + assert_eq!(mm.dims(), [m, n]); + let dst = mm.flatten_all()?.to_vec1::()?; + let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); + assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); + + let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Iq4Xs)?; + let rhs = quantized::QMatMul::from_qtensor(rhs)?; + let mm = rhs.forward(&lhs)?; + + assert_eq!(mm.dims(), [m, n]); + let dst = mm.flatten_all()?.to_vec1::()?; + let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); + assert_eq!(dst, [1.442, 1.509, -0.293, 1.631]); + + ggml_matmul_error_test::()?; + + Ok(()) +} + #[test] fn quantized_matmul_q5k() -> Result<()> { use k_quants::BlockQ5K; diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 706e07c1e4..c6710be672 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -2447,6 +2447,7 @@ pub enum GgmlDType { F16, F32, BF16, + Iq4Xs } #[allow(clippy::too_many_arguments)] @@ -2486,7 +2487,7 @@ pub fn call_quantized_matmul_mv_t( let r2: u32 = (ne12 / ne02) as u32; let r3: u32 = (ne13 / ne03) as u32; - let (nth0, nth1, align) = match dtype { + let (nth0, nth1, align, mem_size_bytes) = match dtype { GgmlDType::Q4_0 | GgmlDType::Q4_1 | GgmlDType::Q5_0 @@ -2496,7 +2497,7 @@ pub fn call_quantized_matmul_mv_t( let nth0 = 8; let nth1 = 8; let align = 8; - (nth0, nth1, align) + (nth0, nth1, align, None) } GgmlDType::Q2K => { // Fixing a bug in Metal for GGML @@ -2504,38 +2505,44 @@ pub fn call_quantized_matmul_mv_t( let nth0 = 2; let nth1 = 32; let align = 4; - (nth0, nth1, align) + (nth0, nth1, align, None) } GgmlDType::Q4K => { let nth0 = 4; let nth1 = 8; let align = 4; - (nth0, nth1, align) + (nth0, nth1, align, None) } GgmlDType::Q3K | GgmlDType::Q5K => { let nth0 = 2; let nth1 = 32; let align = 4; - (nth0, nth1, align) + (nth0, nth1, align, None) } GgmlDType::Q6K => { let nth0 = 2; let nth1 = 32; let align = 2; - (nth0, nth1, align) + (nth0, nth1, align, None) } GgmlDType::F16 | GgmlDType::BF16 | GgmlDType::Q8K => { // Original implem uses rows let nth0 = 32; let nth1 = 1; let align = 8; - (nth0, nth1, align) + (nth0, nth1, align, None) } GgmlDType::F32 => { let nth0 = 32; let nth1 = 1; let align = 8; - (nth0, nth1, align) + (nth0, nth1, align, None) + } + GgmlDType::Iq4Xs => { + let nth0 = 4; + let nth1 = 16; + let align = 4; + (nth0, nth1, align, Some(32*std::mem::size_of::())) } }; let thread_groups_count = MTLSize { @@ -2564,6 +2571,7 @@ pub fn call_quantized_matmul_mv_t( GgmlDType::F16 => "kernel_mul_mv_f16_f32", GgmlDType::BF16 
=> "kernel_mul_mv_bf16_f32", GgmlDType::F32 => "kernel_mul_mv_f32_f32", + GgmlDType::Iq4Xs => "kernel_mul_mm_iq4_xs_f32", }; let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?; @@ -2571,6 +2579,10 @@ pub fn call_quantized_matmul_mv_t( let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); + if let Some(mem_size_bytes) = mem_size_bytes { + encoder.set_threadgroup_memory_length(0, mem_size_bytes as u64); + } + set_params!( encoder, ( @@ -2672,6 +2684,7 @@ pub fn call_quantized_matmul_mm_t( GgmlDType::F16 => "kernel_mul_mm_f16_f32", GgmlDType::BF16 => "kernel_mul_mm_bf16_f32", GgmlDType::F32 => "kernel_mul_mm_f32_f32", + GgmlDType::Iq4Xs => "kernel_mul_mm_iq4_xs_f32", }; let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?; From 8456349deebd2465c75129065bc33f2ead004417 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Wed, 5 Feb 2025 21:26:50 -0500 Subject: [PATCH 16/26] debug --- candle-core/src/quantized/iq_quants/mod.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/candle-core/src/quantized/iq_quants/mod.rs b/candle-core/src/quantized/iq_quants/mod.rs index faa61fc08f..af099ab3e0 100644 --- a/candle-core/src/quantized/iq_quants/mod.rs +++ b/candle-core/src/quantized/iq_quants/mod.rs @@ -73,6 +73,7 @@ impl GgmlType for BlockIQ4xs { bail!("Input length must be multiple of QK_K = {}", QK_K); } + println!("starting"); quantize_iq4_xs(xs, ys, 1, k, None)?; Ok(()) @@ -223,6 +224,7 @@ fn quantize_iq4_xs( for _row in 0..nrow { // Each row has `nblock` blocks: for ibl in 0..nblock { + println!("ibl {ibl}"); let block = &mut ys[dst_offset + ibl]; let qw = quant_weights.map(|qw_all| { From 3a94e5e6cd7d8f129744a31b9410109a05c369f1 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Wed, 5 Feb 2025 21:31:10 -0500 Subject: [PATCH 17/26] debug --- candle-core/src/quantized/iq_quants/mod.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/candle-core/src/quantized/iq_quants/mod.rs b/candle-core/src/quantized/iq_quants/mod.rs index af099ab3e0..e04e1545a7 100644 --- a/candle-core/src/quantized/iq_quants/mod.rs +++ b/candle-core/src/quantized/iq_quants/mod.rs @@ -73,8 +73,7 @@ impl GgmlType for BlockIQ4xs { bail!("Input length must be multiple of QK_K = {}", QK_K); } - println!("starting"); - quantize_iq4_xs(xs, ys, 1, k, None)?; += quantize_iq4_xs(xs, ys, 1, k, None)?; Ok(()) } @@ -224,7 +223,6 @@ fn quantize_iq4_xs( for _row in 0..nrow { // Each row has `nblock` blocks: for ibl in 0..nblock { - println!("ibl {ibl}"); let block = &mut ys[dst_offset + ibl]; let qw = quant_weights.map(|qw_all| { From a79581c25368b616b9265ae183a11d3cedfd0ecd Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Wed, 5 Feb 2025 21:31:25 -0500 Subject: [PATCH 18/26] debug --- candle-core/src/quantized/iq_quants/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candle-core/src/quantized/iq_quants/mod.rs b/candle-core/src/quantized/iq_quants/mod.rs index e04e1545a7..faa61fc08f 100644 --- a/candle-core/src/quantized/iq_quants/mod.rs +++ b/candle-core/src/quantized/iq_quants/mod.rs @@ -73,7 +73,7 @@ impl GgmlType for BlockIQ4xs { bail!("Input length must be multiple of QK_K = {}", QK_K); } -= quantize_iq4_xs(xs, ys, 1, k, None)?; + quantize_iq4_xs(xs, ys, 1, k, None)?; Ok(()) } From e4d5edc92db8c8739812123fc6dd26191f7124b9 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Wed, 5 Feb 2025 21:55:16 -0500 Subject: [PATCH 19/26] FIx kernel name for mv --- candle-metal-kernels/src/lib.rs | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index c6710be672..5fd4ec5be7 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -2571,7 +2571,7 @@ pub fn call_quantized_matmul_mv_t( GgmlDType::F16 => "kernel_mul_mv_f16_f32", GgmlDType::BF16 => "kernel_mul_mv_bf16_f32", GgmlDType::F32 => "kernel_mul_mv_f32_f32", - GgmlDType::Iq4Xs => "kernel_mul_mm_iq4_xs_f32", + GgmlDType::Iq4Xs => "kernel_mul_mv_iq4_xs_f32", }; let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?; From e9c139d319039dc2248b79ec1e71515e8551f77b Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Thu, 6 Feb 2025 08:30:06 -0500 Subject: [PATCH 20/26] Add metal, neon iq4nl --- candle-core/src/quantized/iq_quants/mod.rs | 203 +++++++++++++++---- candle-core/src/quantized/iq_quants/utils.rs | 9 +- candle-core/src/quantized/metal.rs | 5 + candle-core/src/quantized/mod.rs | 8 + candle-core/src/quantized/neon.rs | 45 +++- candle-core/tests/quantized_tests.rs | 124 ++++++++++- candle-metal-kernels/src/lib.rs | 7 +- 7 files changed, 355 insertions(+), 46 deletions(-) diff --git a/candle-core/src/quantized/iq_quants/mod.rs b/candle-core/src/quantized/iq_quants/mod.rs index faa61fc08f..7ad3ac401c 100644 --- a/candle-core/src/quantized/iq_quants/mod.rs +++ b/candle-core/src/quantized/iq_quants/mod.rs @@ -5,7 +5,7 @@ use crate::{bail, Result}; mod utils; -use super::{BlockQ8K, GgmlDType, GgmlType, QK_K}; +use super::{k_quants::BlockQ8_0, BlockQ8K, GgmlDType, GgmlType, QK_K}; pub const QK4_NL: usize = 32; @@ -90,7 +90,7 @@ impl GgmlType for BlockIQ4xs { } let nrow = xs.len() / n_per_row; - quantize_iq4_xs_imatrix(xs, ys, nrow, n_per_row, Some(imatrix_weights)); + quantize_iq4_xs(xs, ys, nrow, n_per_row, Some(imatrix_weights))?; Ok(()) } @@ -192,6 +192,126 @@ impl GgmlType for BlockIQ4xs { } } +#[derive(Debug, Clone, PartialEq)] +#[repr(C)] +pub struct BlockIQ4nl { + pub(crate) d: f16, + pub(crate) qs: [u8; QK4_NL / 2], +} + +const _: () = assert!( + std::mem::size_of::() == std::mem::size_of::() + QK4_NL / 2, + "wrong iq4_nl block size/padding" +); + +impl GgmlType for BlockIQ4nl { + const DTYPE: GgmlDType = GgmlDType::Iq4Nl; + const BLCK_SIZE: usize = QK4_NL; + type VecDotType = BlockQ8_0; + const SUPPORTS_I8MM: bool = false; + + fn to_float(xs: &[Self], mut ys: &mut [f32]) -> Result<()> { + let k = ys.len(); + if k % QK4_NL != 0 { + crate::bail!("dequantize block iq4xs {k} is not divisible by {QK4_NL}"); + } + + let nb = k / QK4_NL; + for block in xs.iter().take(nb) { + let d = block.d.to_f32(); + let qs = &block.qs[..]; + + for j in 0..(QK4_NL / 2) { + ys[j] = d * KVALUES_IQ4NL[(qs[j] & 0xf) as usize] as f32; + ys[j + QK4_NL / 2] = d * KVALUES_IQ4NL[(qs[j] >> 4) as usize] as f32; + } + ys = &mut ys[QK4_NL..]; + } + Ok(()) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + let k = xs.len(); + if k % QK4_NL != 0 { + bail!("Input length must be multiple of QK4_NL = {}", QK4_NL); + } + + quantize_iq4_nl(xs, ys, 1, k, None)?; + + Ok(()) + } + + fn from_float_imatrix( + xs: &[f32], + ys: &mut [Self], + imatrix_weights: &[f32], + n_per_row: usize, + ) -> Result<()> { + let k = xs.len(); + if k % QK4_NL != 0 { + bail!("Input length must be multiple of QK4_NL = {}", QK4_NL); + } + let nrow = xs.len() / n_per_row; + + quantize_iq4_nl(xs, ys, nrow, n_per_row, Some(imatrix_weights))?; + + Ok(()) + } + + #[allow(unreachable_code)] + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + // 
#[cfg(target_feature = "avx")] + // todo!(); + + #[cfg(target_feature = "neon")] + return super::neon::vec_dot_iq4_nl_q8k(n, xs, ys); + + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + if n % QK4_NL != 0 { + bail!("n must be a multiple of QK4_NL"); + } + let nb = n / QK4_NL; + + let mut sumf = 0.0f32; + + // Loop over each block + for ibl in 0..nb { + let x = &xs[ibl]; + let y = &ys[ibl]; + + let d = x.d.to_f32() * y.d.to_f32(); + + let mut sumi1 = 0; + let mut sumi2 = 0; + + for j in 0..QK4_NL / 2 { + sumi1 += y.qs[j] as i32 * KVALUES_IQ4NL[(x.qs[j] & 0xf) as usize] as i32; + sumi2 += + y.qs[j + QK4_NL / 2] as i32 * KVALUES_IQ4NL[(x.qs[j] >> 4) as usize] as i32; + } + + sumf += d * (sumi1 + sumi2) as f32; + } + + Ok(sumf) + } + + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + crate::bail!("Unsupported block type for i8mm"); + } +} + fn quantize_iq4_xs( src: &[f32], ys: &mut [BlockIQ4xs], @@ -237,8 +357,8 @@ fn quantize_iq4_xs( &src[src_offset + QK_K * ibl..src_offset + QK_K * (ibl + 1)], /* dh = */ &mut block.d, /* q4 = */ &mut block.qs, - /* scales_h = */ &mut block.scales_h, - /* scales_l = */ &mut block.scales_l, + /* scales_h = */ Some(&mut block.scales_h), + /* scales_l = */ Some(&mut block.scales_l), /* scales = */ &mut scales, /* weight = */ &mut weight, /* L = */ &mut lbuf, @@ -254,59 +374,64 @@ fn quantize_iq4_xs( Ok(()) } -pub fn quantize_iq4_xs_imatrix( +fn quantize_iq4_nl( src: &[f32], - dst: &mut [BlockIQ4xs], + ys: &mut [BlockIQ4nl], nrow: usize, n_per_row: usize, quant_weights: Option<&[f32]>, -) { - // 1. Check that n_per_row is multiple of QK_K - assert_eq!(n_per_row % QK_K, 0, "n_per_row must be multiple of QK_K"); - let nblock = n_per_row / QK_K; +) -> Result<()> { + if n_per_row % QK4_NL != 0 { + bail!("n_per_row must be multiple of QK4_NL = {}", QK4_NL); + } - // 2. We expect nrow * nblock blocks in `dst` - assert_eq!( - dst.len(), - nrow * nblock, - "Output slice must have exactly nrow*nblock elements" - ); + let nblock = n_per_row / QK4_NL; + // We expect exactly nrow * nblock blocks in `ys`. + if ys.len() != nrow * nblock { + bail!( + "Output buffer size mismatch: want {} blocks, got {}", + nrow * nblock, + ys.len() + ); + } - // 3. Local buffers matching the C usage - let mut lbuf = vec![0u8; QK_K]; - let mut weight = vec![0f32; 32]; - let mut scales = vec![0f32; QK_K / 32]; + let mut lbuf = vec![0u8; QK4_NL]; // L[QK4_NL] + let mut weight = vec![0f32; QK4_NL]; // weight[QK4_NL] + let mut scales = vec![0f32]; // scales[1] - // We'll track how far we've consumed `src`. let mut src_offset = 0; - // Also track how far we move in `dst`. let mut dst_offset = 0; - // 4. 
Outer loop over rows for _row in 0..nrow { + // Each row has `nblock` blocks: for ibl in 0..nblock { - let block = &mut dst[dst_offset + ibl]; + let block = &mut ys[dst_offset + ibl]; - // If quant_weights is Some, get the sub-slice for this block - let qw_block = quant_weights.map(|qw_all| &qw_all[ibl * QK_K..(ibl + 1) * QK_K]); + let qw = quant_weights.map(|qw_all| { + let start = QK4_NL * ibl; + &qw_all[start..start + QK4_NL] + }); quantize_row_iq4_nl_impl( - QK_K, // super_block_size - 32, // block_size - &src[src_offset + ibl * QK_K..src_offset + (ibl + 1) * QK_K], - &mut block.d, - &mut block.qs, - &mut block.scales_h, - &mut block.scales_l, - &mut scales, - &mut weight, - &mut lbuf, - &KVALUES_IQ4NL, - qw_block, - 7, // ntry + /* super_block_size = */ QK4_NL, + /* block_size = */ 32, + /* x = */ + &src[src_offset + QK4_NL * ibl..src_offset + QK4_NL * (ibl + 1)], + /* dh = */ &mut block.d, + /* q4 = */ &mut block.qs, + /* scales_h = */ None, + /* scales_l = */ None, + /* scales = */ &mut scales, + /* weight = */ &mut weight, + /* L = */ &mut lbuf, + /* values = */ &KVALUES_IQ4NL, + /* quant_weights = */ qw, + /* ntry = */ 7, ); } src_offset += n_per_row; dst_offset += nblock; } + + Ok(()) } diff --git a/candle-core/src/quantized/iq_quants/utils.rs b/candle-core/src/quantized/iq_quants/utils.rs index 206f8cb110..58bb7dd87d 100644 --- a/candle-core/src/quantized/iq_quants/utils.rs +++ b/candle-core/src/quantized/iq_quants/utils.rs @@ -9,8 +9,8 @@ pub(super) fn quantize_row_iq4_nl_impl( x: &[f32], dh: &mut f16, q4: &mut [u8], - scales_h: &mut u16, - scales_l: &mut [u8], + scales_h: Option<&mut u16>, + scales_l: Option<&mut [u8]>, scales: &mut [f32], weight: &mut [f32], lbuf: &mut [u8], @@ -24,7 +24,6 @@ pub(super) fn quantize_row_iq4_nl_impl( let sb_div_64 = super_block_size / 64; assert_eq!(q4.len(), sb_div_2); assert_eq!(scales.len(), sb_div_32); - assert_eq!(scales_l.len(), sb_div_64); assert_eq!(lbuf.len(), super_block_size); assert_eq!(weight.len(), block_size); @@ -141,6 +140,10 @@ pub(super) fn quantize_row_iq4_nl_impl( // 8. 
If we have more than one 32-float block in the super-block: if nblocks > 1 { + let scales_h = scales_h.expect("Expected scales_h, nblocks > 1"); + let scales_l = scales_l.expect("Expected scales_l, nblocks > 1"); + assert_eq!(scales_l.len(), sb_div_64); + // zero scales_h, because we store 2 bits per block in it // for nblocks=8, we store them in a single 16-bit value *scales_h = 0; diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs index d4bf2afeda..6ea9d23fad 100644 --- a/candle-core/src/quantized/metal.rs +++ b/candle-core/src/quantized/metal.rs @@ -111,6 +111,10 @@ impl QMetalStorage { let vec: Vec = read_to_vec(&buffer, block_len); crate::quantized::BlockIQ4xs::to_float(&vec, &mut out)?; } + GgmlDType::Iq4Nl => { + let vec: Vec = read_to_vec(&buffer, block_len); + crate::quantized::BlockIQ4nl::to_float(&vec, &mut out)?; + } } let buffer = self.device.new_buffer_with_data(&out)?; @@ -392,6 +396,7 @@ impl From for candle_metal_kernels::GgmlDType { GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K, GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K, GgmlDType::Iq4Xs => candle_metal_kernels::GgmlDType::Iq4Xs, + GgmlDType::Iq4Nl => candle_metal_kernels::GgmlDType::Iq4Nl, GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16, GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32, GgmlDType::BF16 => candle_metal_kernels::GgmlDType::F16, diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 761069656f..d9ebcca125 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -203,6 +203,7 @@ pub enum GgmlDType { Q6K, Q8K, Iq4Xs, + Iq4Nl, } impl GgmlDType { @@ -246,6 +247,7 @@ impl GgmlDType { Self::Q5K => 13, Self::Q6K => 14, Self::Q8K => 15, + Self::Iq4Nl => 20, Self::Iq4Xs => 23, // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389 Self::BF16 => 30, @@ -269,6 +271,10 @@ impl GgmlDType { Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]), Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]), Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]), + Self::Iq4Nl => Box::new(vec![ + BlockIQ4nl::zeros(); + elem_count / BlockIQ4nl::BLCK_SIZE + ]), Self::Iq4Xs => Box::new(vec![ BlockIQ4xs::zeros(); elem_count / BlockIQ4xs::BLCK_SIZE @@ -295,6 +301,7 @@ impl GgmlDType { Self::Q5K => std::mem::size_of::(), Self::Q6K => std::mem::size_of::(), Self::Q8K => std::mem::size_of::(), + Self::Iq4Nl => std::mem::size_of::(), Self::Iq4Xs => std::mem::size_of::(), } } @@ -310,6 +317,7 @@ impl GgmlDType { Self::Q5_1 => k_quants::QK5_1, Self::Q8_0 => k_quants::QK8_0, Self::Q8_1 => k_quants::QK8_1, + Self::Iq4Nl => iq_quants::QK4_NL, Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K | Self::Iq4Xs => { k_quants::QK_K } diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs index af388ca7ee..df8389075e 100644 --- a/candle-core/src/quantized/neon.rs +++ b/candle-core/src/quantized/neon.rs @@ -1,5 +1,5 @@ use super::{ - iq_quants::BlockIQ4xs, + iq_quants::{BlockIQ4nl, BlockIQ4xs, QK4_NL}, k_quants::{ BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K, @@ -1966,6 +1966,49 @@ pub(crate) fn vec_dot_iq4_xs_q8k(n: usize, xs: &[BlockIQ4xs], ys: &[BlockQ8K]) - } } +#[inline(always)] +pub(crate) fn vec_dot_iq4_nl_q8k(n: usize, xs: &[BlockIQ4nl], ys: &[BlockQ8_0]) -> Result { + if n % QK4_NL != 0 { + 
crate::bail!("vec_dot_iq4_nl_q8k: {n} is not divisible by {QK4_NL}") + } + + unsafe { + let values = vld1q_s8(KVALUES_IQ4NL.as_ptr()); + let m4b = vdupq_n_u8(0x0f); + + let mut q4b = int8x16x4_t(vdupq_n_s8(0), vdupq_n_s8(0), vdupq_n_s8(0), vdupq_n_s8(0)); + let mut q8b = int8x16x4_t(vdupq_n_s8(0), vdupq_n_s8(0), vdupq_n_s8(0), vdupq_n_s8(0)); + let mut q4bits = uint8x16x2_t(vdupq_n_u8(0), vdupq_n_u8(0)); + + let mut sumf = 0f32; + + let nb = n / QK4_NL; + for ib in (0..nb - 1).step_by(2) { + q4bits.0 = vld1q_u8(xs[ib].qs.as_ptr()); + q4bits.1 = vld1q_u8(xs[ib + 1].qs.as_ptr()); + q8b.0 = vld1q_s8(ys[ib].qs.as_ptr()); + q8b.1 = vld1q_s8(ys[ib].qs.as_ptr().add(16)); + q8b.2 = vld1q_s8(ys[ib + 1].qs.as_ptr()); + q8b.3 = vld1q_s8(ys[ib + 1].qs.as_ptr().add(16)); + + q4b.0 = vqtbl1q_s8(values, vandq_u8(q4bits.0, m4b)); + q4b.1 = vqtbl1q_s8(values, vshrq_n_u8(q4bits.0, 4)); + q4b.2 = vqtbl1q_s8(values, vandq_u8(q4bits.1, m4b)); + q4b.3 = vqtbl1q_s8(values, vshrq_n_u8(q4bits.1, 4)); + + let prod1 = + vdotq_s32_local(vdotq_s32_local(vdupq_n_s32(0), q4b.0, q8b.0), q4b.1, q8b.1); + let prod2 = + vdotq_s32_local(vdotq_s32_local(vdupq_n_s32(0), q4b.2, q8b.2), q4b.3, q8b.3); + + sumf += xs[ib].d.to_f32() * ys[ib].d.to_f32() * vaddvq_s32(prod1) as f32 + + xs[ib + 1].d.to_f32() * ys[ib + 1].d.to_f32() * vaddvq_s32(prod2) as f32; + } + + Ok(sumf) + } +} + #[inline(always)] #[cfg(feature = "arm-nightly-feat")] unsafe fn i8mm_accum_with_scale( diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index bce0840f84..34fc6afc9e 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -986,6 +986,43 @@ fn quantize_iq4_xs(device: &Device) -> Result<()> { Ok(()) } +fn quantize_iq4_nl(device: &Device) -> Result<()> { + let dtype = GgmlDType::Iq4Nl; + let src = get_test_vector2(0.5, 1024, device)?; + let quant = quantized::QTensor::quantize(&src, dtype)?; + let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); + + let src = src.to_vec1::()?; + let dst = dst.to_vec1::()?; + compare_with_error(dst.as_slice(), src.as_slice(), 0.025); + + let src_big = get_test_vector2(128.0, 1024, device)?; + let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; + let dst_big = quant_big.dequantize(device)?; + let dst_big_f16 = quant_big.dequantize_f16(device)?; + let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? 
+ .to_vec0::()?; + assert_eq!(diff, 0.); + + let src_big = src_big.to_vec1::()?; + let dst_big = dst_big.to_vec1::()?; + compare_with_error(dst_big.as_slice(), src_big.as_slice(), 5.9); + + ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + + Ok(()) +} + #[test] fn imatrix_quantize_iq4_xs() -> Result<()> { // let data = @@ -1031,6 +1068,51 @@ fn imatrix_quantize_iq4_xs() -> Result<()> { Ok(()) } +#[test] +fn imatrix_quantize_iq4_nl() -> Result<()> { + // let data = + // quantized::imatrix_file::load_imatrix("../Llama-3.2-3B-Instruct.imatrix").unwrap(); + // for (name, weights) in &data { + // println!("{name}, {} elems", weights.len()); + // } + // dbg!(&data["blk.0.attn_q.weight"].len()); + + let cpu = &Device::Cpu; + + let mut row_counts = 0f64; + let mut ncall = 0f64; + let mut values = Tensor::zeros((768,), DType::F32, cpu)?; + + for _ in 0..10 { + let lhs = Var::from_tensor(&Tensor::randn(0f32, 1f32, (1024, 512), cpu)?)?; + let rhs = Var::from_tensor(&Tensor::randn(0f32, 1f32, (512, 768), cpu)?)?; + let res = lhs.matmul(&rhs)?; + + // https://github.com/ggerganov/llama.cpp/blob/678d7994f4da0af3d29046be99950ac999ee9762/examples/imatrix/imatrix.cpp#L180-L186 + values = (values + res.sqr()?.sum(0)?)?; + row_counts += res.dim(0)? as f64; + ncall += 1.; + } + + // https://github.com/ggerganov/llama.cpp/blob/678d7994f4da0af3d29046be99950ac999ee9762/examples/imatrix/imatrix.cpp#L275 + let out = ((values / row_counts)? * ncall)?; + let imatrix = out.to_vec1::()?; + + let xs = Tensor::randn(0f32, 1f32, (1024, 768), cpu)?; + + let quant1 = quantized::QTensor::quantize(&xs, GgmlDType::Iq4Nl)?; + let quant2 = quantized::QTensor::quantize_imatrix(&xs, &imatrix, GgmlDType::Iq4Nl)?; + + let dequant1 = quant1.dequantize(cpu)?; + let dequant2 = quant2.dequantize(cpu)?; + + let err1 = (dequant1 - &xs)?.abs()?.mean_all()?.to_scalar::()?; + let err2 = (dequant2 - &xs)?.abs()?.mean_all()?.to_scalar::()?; + assert!(err2 < err1, "err2 {err2} > err1 {err1}"); + + Ok(()) +} + test_device!( quantize_q4_0, quantize_q4_0_cpu, @@ -1097,6 +1179,12 @@ test_device!( quantize_iq4_xs_cuda, quantize_iq4_xs_metal ); +test_device!( + quantize_iq4_nl, + quantize_iq4_nl_cpu, + quantize_iq4_nl_cuda, + quantize_iq4_nl_metal +); /// Very simple dot product implementation fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 { @@ -1118,6 +1206,7 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result { GgmlDType::Q5_1 => 0.00149, GgmlDType::Q8_0 => 0.000092, GgmlDType::Iq4Xs => 0.001903, + GgmlDType::Iq4Nl => 0.002716, // Not from the ggml repo. 
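        // Error budget that ggml_matmul_error_test enforces for each dtype
        // when comparing the quantized dot product against the
        // vec_dot_reference implementation above.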
GgmlDType::Q8K => 0.00065, @@ -1292,7 +1381,14 @@ quantized_matmul!( quantized_matmul_iq4xs_cpu, quantized_matmul_iq4xs_cuda, quantized_matmul_iq4xs_metal, - GgmlDType::Q4K + GgmlDType::Iq4Xs +); +quantized_matmul!( + quantized_matmul_iq4nl_bis, + quantized_matmul_iq4nl_cpu, + quantized_matmul_iq4nl_cuda, + quantized_matmul_iq4nl_metal, + GgmlDType::Iq4Nl ); quantized_matmul!( quantized_matmul_q4k_bis, @@ -1428,6 +1524,32 @@ fn quantized_matmul_iq4xs() -> Result<()> { Ok(()) } +#[test] +fn quantized_matmul_iq4nl() -> Result<()> { + use iq_quants::BlockIQ4nl; + + let cpu = &Device::Cpu; + let (m, k, n) = (11, 512, 21); + let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?; + assert_eq!(mm.dims(), [m, n]); + let dst = mm.flatten_all()?.to_vec1::()?; + let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); + assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); + + let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Iq4Nl)?; + let rhs = quantized::QMatMul::from_qtensor(rhs)?; + let mm = rhs.forward(&lhs)?; + + assert_eq!(mm.dims(), [m, n]); + let dst = mm.flatten_all()?.to_vec1::()?; + let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); + assert_eq!(dst, [1.432, 1.469, -0.312, 1.602]); + + ggml_matmul_error_test::()?; + + Ok(()) +} + #[test] fn quantized_matmul_q5k() -> Result<()> { use k_quants::BlockQ5K; diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 5fd4ec5be7..9d31dd7f81 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -2447,7 +2447,8 @@ pub enum GgmlDType { F16, F32, BF16, - Iq4Xs + Iq4Xs, + Iq4Nl, } #[allow(clippy::too_many_arguments)] @@ -2538,7 +2539,7 @@ pub fn call_quantized_matmul_mv_t( let align = 8; (nth0, nth1, align, None) } - GgmlDType::Iq4Xs => { + GgmlDType::Iq4Xs | GgmlDType::Iq4Nl => { let nth0 = 4; let nth1 = 16; let align = 4; @@ -2572,6 +2573,7 @@ pub fn call_quantized_matmul_mv_t( GgmlDType::BF16 => "kernel_mul_mv_bf16_f32", GgmlDType::F32 => "kernel_mul_mv_f32_f32", GgmlDType::Iq4Xs => "kernel_mul_mv_iq4_xs_f32", + GgmlDType::Iq4Nl => "kernel_mul_mv_iq4_nl_f32", }; let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?; @@ -2685,6 +2687,7 @@ pub fn call_quantized_matmul_mm_t( GgmlDType::BF16 => "kernel_mul_mm_bf16_f32", GgmlDType::F32 => "kernel_mul_mm_f32_f32", GgmlDType::Iq4Xs => "kernel_mul_mm_iq4_xs_f32", + GgmlDType::Iq4Nl => "kernel_mul_mm_iq4_nl_f32", }; let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?; From f5cca346767302b2b9a6ee12e3d68590662fd3e5 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Thu, 6 Feb 2025 08:36:29 -0500 Subject: [PATCH 21/26] Fix missing handling in from_u32 --- candle-core/src/quantized/mod.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index d9ebcca125..65aaace433 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -207,7 +207,7 @@ pub enum GgmlDType { } impl GgmlDType { - pub(crate) fn from_u32(u: u32) -> Result { + pub fn from_u32(u: u32) -> Result { let dtype = match u { 0 => Self::F32, 1 => Self::F16, @@ -223,6 +223,7 @@ impl GgmlDType { 13 => Self::Q5K, 14 => Self::Q6K, 15 => Self::Q8K, + 20 => Self::Iq4Nl, 23 => Self::Iq4Xs, // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389 30 => Self::BF16, @@ -231,7 +232,7 @@ impl GgmlDType { Ok(dtype) } - pub(crate) fn to_u32(self) -> 
u32 { + pub fn to_u32(self) -> u32 { match self { Self::F32 => 0, Self::F16 => 1, From 9db53908c730d72e30f547000481cd4a1daef8a1 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Thu, 6 Feb 2025 11:44:30 -0500 Subject: [PATCH 22/26] Sketch iq3xxs --- candle-core/src/quantized/iq_quants/mod.rs | 218 ++++++- candle-core/src/quantized/iq_quants/utils.rs | 591 +++++++++++++++++++ candle-core/src/quantized/metal.rs | 5 + candle-core/src/quantized/mod.rs | 19 +- candle-core/tests/quantized_tests.rs | 44 ++ candle-metal-kernels/src/lib.rs | 9 + 6 files changed, 881 insertions(+), 5 deletions(-) diff --git a/candle-core/src/quantized/iq_quants/mod.rs b/candle-core/src/quantized/iq_quants/mod.rs index 7ad3ac401c..2c2c6f836b 100644 --- a/candle-core/src/quantized/iq_quants/mod.rs +++ b/candle-core/src/quantized/iq_quants/mod.rs @@ -1,5 +1,7 @@ +use std::ptr; + use half::f16; -use utils::quantize_row_iq4_nl_impl; +use utils::{quantize_row_iq3_xxs_impl, quantize_row_iq4_nl_impl}; use crate::{bail, Result}; @@ -13,6 +15,53 @@ pub(super) const KVALUES_IQ4NL: [i8; 16] = [ -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113, ]; +const K_MASK_IQ2XS: [u8; 8] = [1, 2, 4, 8, 16, 32, 64, 128]; + +const K_SIGNS_IQ2XS: [u8; 128] = [ + 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15, 144, 17, 18, 147, 20, 149, + 150, 23, 24, 153, 154, 27, 156, 29, 30, 159, 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, + 43, 172, 45, 46, 175, 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63, + 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207, 80, 209, 210, 83, 212, + 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95, 96, 225, 226, 99, 228, 101, 102, 231, 232, + 105, 106, 235, 108, 237, 238, 111, 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, + 252, 125, 126, 255, +]; + +const IQ3XXS_GRID: [u32; 256] = [ + 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414, + 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, + 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404, + 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e, + 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c, + 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c, + 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34, + 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c, + 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c, + 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04, + 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c, + 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414, + 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434, + 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c, + 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e, + 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24, + 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 
0x142c3e24, + 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c, + 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c, + 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14, + 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414, + 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e, + 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404, + 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c, + 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c, + 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14, + 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c, + 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c, + 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14, + 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14, + 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c, + 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, +]; + #[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockIQ4xs { @@ -213,7 +262,7 @@ impl GgmlType for BlockIQ4nl { fn to_float(xs: &[Self], mut ys: &mut [f32]) -> Result<()> { let k = ys.len(); if k % QK4_NL != 0 { - crate::bail!("dequantize block iq4xs {k} is not divisible by {QK4_NL}"); + crate::bail!("dequantize block iq4nl {k} is not divisible by {QK4_NL}"); } let nb = k / QK4_NL; @@ -312,6 +361,145 @@ impl GgmlType for BlockIQ4nl { } } +#[derive(Debug, Clone, PartialEq)] +#[repr(C)] +pub struct BlockIQ3xxs { + pub(crate) d: f16, + pub(crate) qs: [u8; 3 * QK_K / 8], +} + +const _: () = assert!( + std::mem::size_of::() == std::mem::size_of::() + 3 * QK_K / 8, + "wrong iq3_xxs block size/padding" +); + +impl GgmlType for BlockIQ3xxs { + const DTYPE: GgmlDType = GgmlDType::Iq3Xxs; + const BLCK_SIZE: usize = QK_K; + type VecDotType = BlockQ8K; + const SUPPORTS_I8MM: bool = false; + + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + let k = ys.len(); + if k % QK_K != 0 { + crate::bail!("dequantize block iq3xxs {k} is not divisible by {QK_K}"); + } + let nb = k / QK_K; + + let mut aux32: u32; + + let mut y = ys.as_mut_ptr(); + + unsafe { + // Process each block. + for i in 0..nb { + // Access the i'th block. + let block = &xs[i]; + // Convert FP16 to f32. + let d = block.d.to_f32(); + // Get a pointer to the beginning of qs. + let mut qs = block.qs.as_ptr(); + // scales_and_signs is located QK_K/4 bytes after qs. + let scales_and_signs = qs.add((QK_K / 4) as usize); + + // Loop over each 32-byte subblock. + for ib32 in 0..(QK_K / 32) { + // Copy 4 bytes from scales_and_signs + 4*ib32 into aux32. + aux32 = ptr::read_unaligned( + scales_and_signs.add((4 * ib32) as usize) as *const u32 + ); + // Compute db = d * (0.5 + (aux32 >> 28)) * 0.5 + let db = d * (0.5 + ((aux32 >> 28) as f32)) * 0.5; + + // Process 4 groups per 32-byte subblock. + for l in 0..4 { + let shift = 7 * l; + let idx = ((aux32 >> shift) & 127) as usize; + // Get the corresponding 'signs' value. 
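+                    // The sign table packs 8 signs into 7 stored bits: entry
+                    // i equals i with bit 7 set to the parity of i, so every
+                    // entry has an even number of sign bits set and the 8th
+                    // sign never needs to be stored explicitly.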
+ let signs = K_SIGNS_IQ2XS[idx]; + + // Get pointers to grid1 and grid2. + // qs[2*l+0] and qs[2*l+1] are used as offsets into IQ3XXS_GRID. + let idx1 = *qs.add((2 * l + 0) as usize) as usize; + let grid1 = IQ3XXS_GRID.as_ptr().add(idx1); + let idx2 = *qs.add((2 * l + 1) as usize) as usize; + let grid2 = IQ3XXS_GRID.as_ptr().add(idx2); + + // For each of 4 values in grid1 and grid2. + for j in 0..4 { + let mask1 = K_MASK_IQ2XS[(j + 0) as usize]; + let mask2 = K_MASK_IQ2XS[(j + 4) as usize]; + let sign1 = if (signs & mask1) != 0 { -1.0 } else { 1.0 }; + let sign2 = if (signs & mask2) != 0 { -1.0 } else { 1.0 }; + + let grid1_val = *grid1.add(j as usize) as f32; + let grid2_val = *grid2.add(j as usize) as f32; + + // Write the dequantized values. + ptr::write(y.add(j + 0), db * grid1_val * sign1); + ptr::write(y.add(j + 4), db * grid2_val * sign2); + } + // Advance y by 8 floats. + y = y.add(8); + } + // Advance qs by 8 bytes. + qs = qs.add(8); + } + } + } + Ok(()) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + let k = xs.len(); + if k % QK4_NL != 0 { + bail!("Input length must be multiple of QK4_NL = {}", QK4_NL); + } + + unsafe { quantize_iq3_xxs(xs, ys, 1, k, None)? }; + + Ok(()) + } + + fn from_float_imatrix( + xs: &[f32], + ys: &mut [Self], + imatrix_weights: &[f32], + n_per_row: usize, + ) -> Result<()> { + let k = xs.len(); + if k % QK4_NL != 0 { + bail!("Input length must be multiple of QK4_NL = {}", QK4_NL); + } + let nrow = xs.len() / n_per_row; + + unsafe { quantize_iq3_xxs(xs, ys, nrow, n_per_row, Some(imatrix_weights))? }; + + Ok(()) + } + + #[allow(unreachable_code)] + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + todo!() + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + todo!() + } + + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + crate::bail!("Unsupported block type for i8mm"); + } +} + fn quantize_iq4_xs( src: &[f32], ys: &mut [BlockIQ4xs], @@ -435,3 +623,29 @@ fn quantize_iq4_nl( Ok(()) } + +unsafe fn quantize_iq3_xxs( + src: &[f32], + ys: &mut [BlockIQ3xxs], + nrow: usize, + n_per_row: usize, + quant_weights: Option<&[f32]>, +) -> Result<()> { + // Assert that n_per_row is a multiple of QK_K. + assert!( + n_per_row % QK_K == 0, + "n_per_row must be a multiple of QK_K" + ); + let nblock = n_per_row / QK_K; + + let mut src_ptr = src.as_ptr(); + let mut dst_ptr = ys.as_mut_ptr(); + + for _row in 0..nrow { + quantize_row_iq3_xxs_impl(256, src_ptr, dst_ptr, n_per_row as i64, quant_weights); + + src_ptr = src_ptr.add(n_per_row); + dst_ptr = dst_ptr.add(nblock); + } + Ok(()) +} diff --git a/candle-core/src/quantized/iq_quants/utils.rs b/candle-core/src/quantized/iq_quants/utils.rs index 58bb7dd87d..e4b379a33a 100644 --- a/candle-core/src/quantized/iq_quants/utils.rs +++ b/candle-core/src/quantized/iq_quants/utils.rs @@ -1,6 +1,17 @@ +use std::{ffi::c_void, mem, ptr}; + use half::f16; +use crate::quantized::{k_quants::QK_K, GgmlType}; + +use super::BlockIQ3xxs; + const GROUP_MAX_EPS: f32 = 1e-15; +const GROUP_MAX_EPS_IQ3_XXS: f32 = 1e-8; + +fn nearest_int(x: f32) -> i32 { + x.round() as i32 +} #[allow(clippy::too_many_arguments)] pub(super) fn quantize_row_iq4_nl_impl( @@ -250,3 +261,583 @@ fn best_index_int8(values: &[i8], x: f32) -> usize { mu } } + +/// Global state analogous to the C code’s iq3_data. 
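+/// Each of the two slots caches the tables for one grid size (256 or 512):
+/// the expanded lattice points, the 4096-entry position-to-grid-index map,
+/// and the neighbour runs used when a quantized point falls off the grid.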
+/// (Using unsafe mutable statics; in production code consider using a safe wrapper.) +#[derive(Debug)] +struct Iq3Entry { + grid: Option>, + map: Option>, + neighbours: Option>, +} + +static mut IQ3_DATA: [Iq3Entry; 2] = [ + Iq3Entry { + grid: None, + map: None, + neighbours: None, + }, + Iq3Entry { + grid: None, + map: None, + neighbours: None, + }, +]; + +static KGRID_256: [u16; 256] = [ + 0, 2, 4, 9, 11, 15, 16, 18, 25, 34, 59, 61, 65, 67, 72, 74, 81, 85, 88, 90, 97, 108, 120, 128, + 130, 132, 137, 144, 146, 153, 155, 159, 169, 175, 189, 193, 199, 200, 202, 213, 248, 267, 287, + 292, 303, 315, 317, 321, 327, 346, 362, 413, 436, 456, 460, 462, 483, 497, 513, 515, 520, 522, + 529, 531, 536, 538, 540, 551, 552, 576, 578, 585, 592, 594, 641, 643, 648, 650, 657, 664, 698, + 704, 706, 720, 729, 742, 758, 769, 773, 808, 848, 852, 870, 889, 901, 978, 992, 1024, 1026, + 1033, 1035, 1040, 1042, 1046, 1049, 1058, 1089, 1091, 1093, 1096, 1098, 1105, 1112, 1139, 1143, + 1144, 1152, 1154, 1161, 1167, 1168, 1170, 1183, 1184, 1197, 1217, 1224, 1228, 1272, 1276, 1309, + 1323, 1347, 1367, 1377, 1404, 1473, 1475, 1486, 1509, 1537, 1544, 1546, 1553, 1555, 1576, 1589, + 1594, 1600, 1602, 1616, 1625, 1636, 1638, 1665, 1667, 1672, 1685, 1706, 1722, 1737, 1755, 1816, + 1831, 1850, 1856, 1862, 1874, 1901, 1932, 1950, 1971, 2011, 2032, 2052, 2063, 2077, 2079, 2091, + 2095, 2172, 2192, 2207, 2208, 2224, 2230, 2247, 2277, 2308, 2345, 2356, 2389, 2403, 2424, 2501, + 2504, 2506, 2520, 2570, 2593, 2616, 2624, 2630, 2646, 2669, 2700, 2714, 2746, 2754, 2795, 2824, + 2835, 2839, 2874, 2882, 2905, 2984, 3028, 3042, 3092, 3108, 3110, 3124, 3153, 3185, 3215, 3252, + 3288, 3294, 3364, 3397, 3434, 3483, 3523, 3537, 3587, 3589, 3591, 3592, 3610, 3626, 3670, 3680, + 3722, 3749, 3754, 3776, 3789, 3803, 3824, 3857, 3873, 3904, 3906, 3924, 3992, +]; + +static KGRID_512: [u16; 512] = [ + 0, 1, 2, 5, 7, 8, 9, 10, 12, 14, 16, 17, 21, 27, 32, 34, 37, 39, 41, 43, 48, 50, 57, 60, 63, + 64, 65, 66, 68, 72, 73, 77, 80, 83, 87, 89, 93, 100, 113, 117, 122, 128, 129, 133, 135, 136, + 139, 142, 145, 149, 152, 156, 162, 165, 167, 169, 171, 184, 187, 195, 201, 205, 208, 210, 217, + 219, 222, 228, 232, 234, 247, 249, 253, 256, 267, 271, 273, 276, 282, 288, 291, 297, 312, 322, + 324, 336, 338, 342, 347, 353, 357, 359, 374, 379, 390, 393, 395, 409, 426, 441, 448, 450, 452, + 464, 466, 470, 475, 488, 492, 512, 513, 514, 516, 520, 521, 523, 525, 527, 528, 530, 537, 540, + 542, 556, 558, 561, 570, 576, 577, 579, 582, 584, 588, 593, 600, 603, 609, 616, 618, 632, 638, + 640, 650, 653, 655, 656, 660, 666, 672, 675, 685, 688, 698, 705, 708, 711, 712, 715, 721, 727, + 728, 732, 737, 754, 760, 771, 773, 778, 780, 793, 795, 802, 806, 808, 812, 833, 840, 843, 849, + 856, 858, 873, 912, 916, 919, 932, 934, 961, 963, 968, 970, 977, 989, 993, 1010, 1016, 1024, + 1025, 1027, 1029, 1031, 1032, 1034, 1036, 1038, 1041, 1043, 1047, 1048, 1050, 1057, 1059, 1061, + 1064, 1066, 1079, 1080, 1083, 1085, 1088, 1090, 1096, 1099, 1103, 1106, 1109, 1113, 1116, 1122, + 1129, 1153, 1156, 1159, 1169, 1171, 1176, 1183, 1185, 1195, 1199, 1209, 1212, 1216, 1218, 1221, + 1225, 1234, 1236, 1241, 1243, 1250, 1256, 1270, 1281, 1287, 1296, 1299, 1306, 1309, 1313, 1338, + 1341, 1348, 1353, 1362, 1375, 1376, 1387, 1400, 1408, 1410, 1415, 1425, 1453, 1457, 1477, 1481, + 1494, 1496, 1507, 1512, 1538, 1545, 1547, 1549, 1551, 1554, 1561, 1563, 1565, 1570, 1572, 1575, + 1577, 1587, 1593, 1601, 1603, 1605, 1612, 1617, 1619, 1632, 1648, 1658, 1662, 1664, 1674, 1680, + 1690, 1692, 1704, 1729, 
1736, 1740, 1745, 1747, 1751, 1752, 1761, 1763, 1767, 1773, 1787, 1795, + 1801, 1806, 1810, 1817, 1834, 1840, 1844, 1857, 1864, 1866, 1877, 1882, 1892, 1902, 1915, 1934, + 1953, 1985, 1987, 2000, 2002, 2013, 2048, 2052, 2058, 2064, 2068, 2071, 2074, 2081, 2088, 2104, + 2114, 2119, 2121, 2123, 2130, 2136, 2141, 2147, 2153, 2157, 2177, 2179, 2184, 2189, 2193, 2203, + 2208, 2223, 2226, 2232, 2244, 2249, 2251, 2256, 2258, 2265, 2269, 2304, 2306, 2324, 2335, 2336, + 2361, 2373, 2375, 2385, 2418, 2443, 2460, 2480, 2504, 2509, 2520, 2531, 2537, 2562, 2568, 2572, + 2578, 2592, 2596, 2599, 2602, 2614, 2620, 2625, 2627, 2629, 2634, 2641, 2650, 2682, 2688, 2697, + 2707, 2712, 2718, 2731, 2754, 2759, 2760, 2775, 2788, 2793, 2805, 2811, 2817, 2820, 2832, 2842, + 2854, 2890, 2902, 2921, 2923, 2978, 3010, 3012, 3026, 3081, 3083, 3085, 3097, 3099, 3120, 3136, + 3152, 3159, 3188, 3210, 3228, 3234, 3245, 3250, 3256, 3264, 3276, 3281, 3296, 3349, 3363, 3378, + 3392, 3395, 3420, 3440, 3461, 3488, 3529, 3531, 3584, 3588, 3591, 3600, 3602, 3614, 3616, 3628, + 3634, 3650, 3657, 3668, 3683, 3685, 3713, 3716, 3720, 3726, 3729, 3736, 3753, 3778, 3802, 3805, + 3819, 3841, 3845, 3851, 3856, 3880, 3922, 3938, 3970, 3993, 4032, +]; + +/// Returns the index into IQ3_DATA for a given grid size. +/// Panics if grid_size is not 256 or 512. +fn iq3_data_index(grid_size: i32) -> usize { + assert!( + grid_size == 256 || grid_size == 512, + "grid_size must be 256 or 512" + ); + if grid_size == 256 { + 0 + } else { + 1 + } +} + +/// Helper: given a grid value (stored as u32) reinterpreted as 4 bytes, +/// compute the “index” value using: for each byte b, compute ((b-1)/2) +/// and pack it as bits (3 bits per coordinate). +fn compute_index_from_grid_val(val: u32) -> usize { + let bytes = val.to_ne_bytes(); + let mut index = 0usize; + for k in 0..4 { + // (b - 1) / 2 + let q = (bytes[k].saturating_sub(1)) / 2; + index |= (q as usize) << (3 * k); + } + index +} + +/// Computes the squared Euclidean distance between two 4–byte positions. +fn dist2(a: &[u8; 4], b: &[u8; 4]) -> i32 { + let mut d2 = 0; + for k in 0..4 { + let diff = a[k] as i32 - b[k] as i32; + d2 += diff * diff; + } + d2 +} + +/// Main initialization function. This reproduces the C function iq3xs_init_impl. +fn iq3xs_init_impl(grid_size: i32) { + // Determine which slot to use. + let gindex = iq3_data_index(grid_size); + // Use unsafe to access the global mutable state. + unsafe { + if IQ3_DATA[gindex].grid.is_some() { + return; + } + } + + // Choose constants based on grid_size. + let (kgrid, nwant) = if grid_size == 256 { + (&KGRID_256[..], 2) + } else { + (&KGRID_512[..], 3) + }; + let kmap_size = 4096; + + // --- Allocate and initialize the grid --- + // For each element, we compute 4 bytes as: for each i in 0..4: + // byte = 2 * ((kgrid[k] >> (3*i)) & 0x7) + 1 + let mut grid: Vec = Vec::with_capacity(grid_size as usize); + for &kg in kgrid.iter().take(grid_size as usize) { + let mut bytes = [0u8; 4]; + for i in 0..4 { + let l = (kg >> (3 * i)) & 0x7; + bytes[i] = (2 * l + 1) as u8; + } + grid.push(u32::from_ne_bytes(bytes)); + } + + // --- Allocate and initialize the map --- + // kmap: size = 4096, all initialized to -1. + let mut kmap: Vec = vec![-1; kmap_size]; + + // For each grid element, compute its index and store the grid index. 
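+    // This is the inverse map: a lattice position packed as four 3-bit
+    // coordinates (12 bits, hence the 4096 slots) maps back to its grid
+    // index, while positions with no grid entry keep -1 and get neighbour
+    // lists in the passes below.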
+    for (j, &val) in grid.iter().enumerate() {
+        let index = compute_index_from_grid_val(val);
+        kmap[index] = j as i32;
+    }
+
+    // --- First pass: determine total space needed for neighbours ---
+    let mut total_neighbors = 0;
+    let mut num_not_in_map = 0;
+    for i in 0..kmap_size {
+        if kmap[i] >= 0 {
+            continue;
+        }
+        num_not_in_map += 1;
+
+        // Reconstruct the “position” from the map index.
+        let mut pos = [0u8; 4];
+        for k in 0..4 {
+            let l = (i >> (3 * k)) & 0x7;
+            pos[k] = (2 * l + 1) as u8;
+        }
+        // Build a vector of (distance, grid index) pairs.
+        let mut dist_vec: Vec<(i32, usize)> = Vec::with_capacity(grid_size as usize);
+        for (j, &grid_val) in grid.iter().enumerate() {
+            let grid_bytes = grid_val.to_ne_bytes();
+            let d = dist2(&grid_bytes, &pos);
+            dist_vec.push((d, j));
+        }
+        // Sort the vector: first by distance, then by grid index.
+        dist_vec.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)));
+        // Count how many neighbors to include.
+        let mut n = 0;
+        let mut nhave = 1;
+        let mut d_current = dist_vec[0].0;
+        for &(d, _) in dist_vec.iter().step_by(2) {
+            if d > d_current {
+                if nhave == nwant {
+                    break;
+                }
+                d_current = d;
+                nhave += 1;
+            }
+            n += 1;
+        }
+        total_neighbors += n;
+    }
+
+    // Allocate neighbours vector with the total number of u16 values.
+    // Note: the C code allocates (num_neighbors + num_not_in_map) elements.
+    let total_nbrs_size = total_neighbors + num_not_in_map;
+    let mut neighbours: Vec<u16> = Vec::with_capacity(total_nbrs_size);
+
+    // --- Second pass: fill in the neighbours data and update kmap ---
+    let mut nbr_counter = 0; // global counter in the neighbours vector
+    for i in 0..kmap_size {
+        if kmap[i] >= 0 {
+            continue;
+        }
+        // Reconstruct the “position” from the map index.
+        let mut pos = [0u8; 4];
+        for k in 0..4 {
+            let l = (i >> (3 * k)) & 0x7;
+            pos[k] = (2 * l + 1) as u8;
+        }
+        // Build and sort the distances for all grid elements.
+        let mut dist_vec: Vec<(i32, usize)> = Vec::with_capacity(grid_size as usize);
+        for (j, &grid_val) in grid.iter().enumerate() {
+            let grid_bytes = grid_val.to_ne_bytes();
+            let d = dist2(&grid_bytes, &pos);
+            dist_vec.push((d, j));
+        }
+        dist_vec.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)));
+
+        // Store negative index in kmap to indicate start offset in the neighbours vector.
+        kmap[i] = -((nbr_counter as i32) + 1);
+
+        // Reserve a slot for the count of neighbours.
+        neighbours.push(0); // placeholder; will update later
+        nbr_counter += 1;
+
+        // Now, add the neighbour indices.
+        let mut n = 0;
+        let mut nhave = 1;
+        let mut d_current = dist_vec[0].0;
+        for &(d, j) in dist_vec.iter().step_by(2) {
+            if d > d_current {
+                if nhave == nwant {
+                    break;
+                }
+                d_current = d;
+                nhave += 1;
+            }
+            // Store the grid index as u16.
+            neighbours.push(j as u16 + 1);
+            nbr_counter += 1;
+            n += 1;
+        }
+        // Update the placeholder with the count of neighbours for this cell.
+        neighbours[nbr_counter - n - 1] = n as u16;
+    }
+
+    // Finally, update the global IQ3_DATA entry.
+    unsafe {
+        IQ3_DATA[gindex].grid = Some(grid);
+        IQ3_DATA[gindex].map = Some(kmap);
+        IQ3_DATA[gindex].neighbours = Some(neighbours);
+    }
+}
+
+pub unsafe fn iq3_find_best_neighbour(
+    neighbours: *const u16,
+    grid: *const u32,
+    xval: *const f32,
+    weight: *const f32,
+    scale: f32,
+    L: *mut i8,
+) -> i32 {
+    // neighbours[0] holds the number of neighbours.
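+    // Layout, as produced by iq3xs_init_impl above: each entry in `neighbours`
+    // is [count, grid_idx_1, ..., grid_idx_count], and kmap stores -(offset + 1)
+    // for every position that is not itself on the grid, pointing at the count
+    // slot.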
+    let num_neighbors = *neighbours as i32;
+    assert!(num_neighbors > 0);
+    let mut best_d2 = f32::MAX;
+    let mut grid_index: i32 = -1;
+    // j from 1 to num_neighbors (inclusive)
+    for j in 1..=num_neighbors {
+        // neighbours[j]
+        let neigh = *neighbours.add(j as usize);
+        // Compute pointer pg = (const int8_t*)(grid + neighbours[j])
+        let pg = (grid.add(neigh as usize)) as *const i8;
+        let mut d2 = 0f32;
+        for i in 0..4 {
+            // Note: pg[i] is read as i8 then converted to f32.
+            let q = *pg.add(i) as f32;
+            let diff = scale * q - *xval.add(i);
+            d2 += *weight.add(i) * diff * diff;
+        }
+        if d2 < best_d2 {
+            best_d2 = d2;
+            grid_index = neigh as i32;
+        }
+    }
+    assert!(grid_index >= 0);
+    let pg = (grid.add(grid_index as usize)) as *const i8;
+    for i in 0..4 {
+        // Here we assume that (pg[i]-1)/2 uses integer arithmetic.
+        *L.add(i) = ((*pg.add(i)) - 1) / 2;
+    }
+    grid_index
+}
+
+pub unsafe fn quantize_row_iq3_xxs_impl(
+    grid_size: i32,
+    x: *const f32,
+    y: *mut BlockIQ3xxs,
+    n: i64,
+    quant_weights: Option<&[f32]>,
+) {
+    iq3xs_init_impl(grid_size);
+
+    // Assume iq3_data_index is defined elsewhere.
+    let gindex = iq3_data_index(grid_size);
+
+    // Assume iq3_data is a global array with fields: grid, map, neighbours.
+    let kgrid_q3xs: *const u32 = IQ3_DATA[gindex].grid.as_ref().unwrap().as_ptr();
+    let kmap_q3xs: *const i32 = IQ3_DATA[gindex].map.as_ref().unwrap().as_ptr();
+    let kneighbors_q3xs: *const u16 = IQ3_DATA[gindex].neighbours.as_ref().unwrap().as_ptr();
+
+    assert!(n % (QK_K as i64) == 0);
+
+    let k_max_q: i32 = 8;
+    let nbl = n / (QK_K as i64);
+
+    // Variables to hold output pointers.
+    let mut dh = &raw mut (*y).d;
+    let mut qs = (*y).qs.as_mut_ptr();
+    let block_size = mem::size_of::<BlockIQ3xxs>();
+    let quant_size = block_size - mem::size_of::<f16>();
+
+    // Allocate temporary arrays on the stack.
+    let mut scales = [0f32; QK_K as usize / 32];
+    let mut weight_arr = [0f32; 32];
+    let mut xval_arr = [0f32; 32];
+    let mut L_arr = [0i8; 32];
+    let mut Laux_arr = [0i8; 32];
+    let mut waux_arr = [0f32; 32];
+    let mut is_on_grid = [false; 8];
+    let mut is_on_grid_aux = [false; 8];
+    let mut block_signs = [0u8; 8];
+    let mut q3 = [0u8; 3 * (QK_K as usize / 8) + (QK_K as usize / 32)];
+
+    // Calculate pointers into q3
+    let scales_and_signs = q3[QK_K as usize / 4..].as_mut_ptr() as *mut u32;
+    let qh = q3[3 * (QK_K as usize / 8)..].as_mut_ptr();
+
+    // For each block of QK_K values:
+    for ibl in 0..nbl as usize {
+        // Set the first fp16 value to zero.
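+        // (`dh` aliases block ibl's raw `d` field and `qs` its quantized
+        // payload; both pointers are advanced by `block_size` bytes at the
+        // end of each iteration.)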
+ *dh = f16::from_f32(0.0); + ptr::write_bytes( + q3.as_mut_ptr(), + 0, + 3 * (QK_K as usize / 8) + (QK_K as usize / 32), + ); + + let mut max_scale = 0f32; + let xbl = x.add(QK_K as usize * ibl); + let mut sumx2 = 0f32; + for i in 0..(QK_K as usize) { + let xi = *xbl.add(i); + sumx2 += xi * xi; + } + let sigma2 = 2.0 * sumx2 / QK_K as f32; + + for ib in 0..(QK_K as usize / 32) { + let xb = xbl.add(32 * ib); + if let Some(quant_weights) = quant_weights { + let qw = &quant_weights[QK_K as usize * ibl + 32 * ib..]; + for i in 0..32 { + weight_arr[i] = qw[i] * ((sigma2 + (*xb.add(i)) * (*xb.add(i))).sqrt()); + } + } else { + for i in 0..32 { + weight_arr[i] = *xb.add(i) * (*xb.add(i)); + } + } + for i in 0..32 { + waux_arr[i] = weight_arr[i].sqrt(); + } + for k in 0..4 { + let mut nflip = 0; + let mut s: u8 = 0; + for i in 0..8 { + let val = *xb.add(8 * k + i); + if val >= 0.0 { + xval_arr[8 * k + i] = val; + } else { + xval_arr[8 * k + i] = -val; + nflip += 1; + s |= 1 << i; + } + } + if nflip % 2 != 0 { + let mut imin = 0; + let mut min_val = weight_arr[8 * k + imin] + * (*xb.add(8 * k + imin)) + * (*xb.add(8 * k + imin)); + for i in 1..8 { + let ax = + weight_arr[8 * k + i] * (*xb.add(8 * k + i)) * (*xb.add(8 * k + i)); + if ax < min_val { + min_val = ax; + imin = i; + } + } + xval_arr[8 * k + imin] = -xval_arr[8 * k + imin]; + s ^= 1 << imin; + } + block_signs[k] = s & 127; + } + let mut max_val = xval_arr[0]; + for i in 1..32 { + max_val = if max_val > xval_arr[i] { + max_val + } else { + xval_arr[i] + }; + } + if max_val < GROUP_MAX_EPS_IQ3_XXS { + scales[ib] = 0.0; + for i in 0..32 { + L_arr[i] = 0; + } + continue; + } + let mut best = 0f32; + let mut scale = max_val / ((2 * k_max_q - 1) as f32); + for is in -15..=15 { + let id = ((2 * k_max_q - 1) as f32 + (is as f32) * 0.2) / max_val; + let this_scale = 1.0 / id; + for k in 0..8 { + for i in 0..4 { + // nearest_int and clamp must be defined elsewhere. 
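+                        // nearest_int is assumed to behave like llama.cpp's
+                        // helper, i.e. round to the closest integer, roughly
+                        // `fn nearest_int(x: f32) -> i32 { x.round() as i32 }`.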
+ let l = nearest_int(0.5 * (id * xval_arr[4 * k + i] - 1.0)); + Laux_arr[4 * k + i] = l.clamp(0, k_max_q - 1) as i8; + } + let mut u: u16 = 0; + for i in 0..4 { + u |= (Laux_arr[4 * k + i] as u16) << (3 * i); + } + let mut grid_index = *kmap_q3xs.add(u as usize); + is_on_grid_aux[k] = true; + if grid_index < 0 { + is_on_grid_aux[k] = false; + let neighbours = + kneighbors_q3xs.offset(-(*kmap_q3xs.add(u as usize)) as isize - 1); + grid_index = iq3_find_best_neighbour( + neighbours, + kgrid_q3xs, + xval_arr.as_ptr().add(4 * k), + waux_arr.as_ptr().add(4 * k), + this_scale, + Laux_arr.as_mut_ptr().add(4 * k), + ); + } + } + let mut sumqx = 0f32; + let mut sumq2 = 0f32; + for i in 0..32 { + let w = weight_arr[i]; + let q = 2.0 * (Laux_arr[i] as f32) + 1.0; + sumqx += w * xval_arr[i] * q; + sumq2 += w * q * q; + } + if sumq2 > 0.0 && sumqx * sumqx > best * sumq2 { + scale = sumqx / sumq2; + best = scale * sumqx; + for i in 0..32 { + L_arr[i] = Laux_arr[i]; + } + for k in 0..8 { + is_on_grid[k] = is_on_grid_aux[k]; + } + } + } + let mut n_not_ongrid = 0; + for k in 0..8 { + if !is_on_grid[k] { + n_not_ongrid += 1; + } + } + if n_not_ongrid > 0 && scale > 0.0 { + let id = 1.0 / scale; + for k in 0..8 { + if is_on_grid[k] { + continue; + } + let mut u: u16 = 0; + for i in 0..4 { + let mut l = nearest_int(0.5 * (id * xval_arr[4 * k + i] - 1.0)); + l = l.clamp(0, k_max_q - 1); + u |= (l as u16) << (3 * i); + } + let mut grid_index = *kmap_q3xs.add(u as usize); + if grid_index < 0 { + let neighbours = + kneighbors_q3xs.offset(-(*kmap_q3xs.add(u as usize)) as isize - 1); + grid_index = iq3_find_best_neighbour( + neighbours, + kgrid_q3xs, + xval_arr.as_ptr().add(4 * k), + waux_arr.as_ptr().add(4 * k), + scale, + L_arr.as_mut_ptr().add(4 * k), + ); + } + let pg = (kgrid_q3xs.add(grid_index as usize)) as *const i8; + for i in 0..4 { + L_arr[4 * k + i] = ((*pg.add(i)) - 1) / 2; + } + } + let mut sumqx = 0f32; + let mut sumq2 = 0f32; + for i in 0..32 { + let w = weight_arr[i]; + let q = 2.0 * (L_arr[i] as f32) + 1.0; + sumqx += w * xval_arr[i] * q; + sumq2 += w * q * q; + } + if sumq2 > 0.0 { + scale = sumqx / sumq2; + } + } + if scale < 0.0 { + // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale) + // and correspondingly flip quant signs. 
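+            // (Each block_signs entry keeps only 7 of the 8 sign bits; the
+            // parity fix-up during sign extraction makes the eighth bit
+            // recoverable, and flipping all eight signs preserves that even
+            // parity.)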
+            scale = -scale;
+            for k in 0..4 {
+                block_signs[k] = (!block_signs[k]) & 127;
+            }
+        }
+        for k in 0..8 {
+            let mut u: u16 = 0;
+            for i in 0..4 {
+                u |= (L_arr[4 * k + i] as u16) << (3 * i);
+            }
+            let grid_index = *kmap_q3xs.add(u as usize);
+            if grid_index < 0 {
+                println!("Oops: found point {} not on grid:", u);
+                for i in 0..4 {
+                    print!(" {}", L_arr[4 * k + i]);
+                }
+                println!();
+                panic!("fatal error");
+            }
+            // Note: these are indexed by the sub-block `ib`, not the block
+            // `ibl` (using `ibl` here would run past the 8-entry buffers and
+            // disagree with the `scales[ib]` read in the packing loop below).
+            if grid_size == 256 {
+                q3[8 * ib + k] = grid_index as u8;
+            } else {
+                q3[8 * ib + k] = (grid_index & 255) as u8;
+                *qh.add(ib) |= ((grid_index >> 8) as u8) << k;
+            }
+        }
+        // Pack block_signs into scales_and_signs
+        *scales_and_signs.add(ib) = block_signs[0] as u32
+            | ((block_signs[1] as u32) << 7)
+            | ((block_signs[2] as u32) << 14)
+            | ((block_signs[3] as u32) << 21);
+        assert!(scale >= 0.0);
+        scales[ib] = scale;
+        if scale > max_scale {
+            max_scale = scale;
+        }
+        }
+
+        if max_scale == 0.0 {
+            ptr::write_bytes(qs, 0, quant_size as usize);
+            dh = dh.add(block_size as usize / mem::size_of::<f16>());
+            qs = qs.add(block_size as usize);
+            continue;
+        }
+        let d = max_scale / 31.0;
+        *dh = f16::from_f32(d * 1.0125);
+        let id = 1.0 / d;
+        for ib in 0..(QK_K as usize / 32) {
+            let l = nearest_int(0.5 * (id * scales[ib] - 1.0));
+            let l = l.clamp(0, 15);
+            let prev = *scales_and_signs.add(ib);
+            *scales_and_signs.add(ib) = prev | ((l as u32) << 28);
+        }
+        ptr::copy_nonoverlapping(q3.as_ptr(), qs, quant_size as usize);
+        dh = dh.add(block_size as usize / mem::size_of::<f16>());
+        qs = qs.add(block_size as usize);
+    }
+}
diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs
index 6ea9d23fad..e4b5bf9a28 100644
--- a/candle-core/src/quantized/metal.rs
+++ b/candle-core/src/quantized/metal.rs
@@ -115,6 +115,10 @@ impl QMetalStorage {
                 let vec: Vec<crate::quantized::BlockIQ4nl> = read_to_vec(&buffer, block_len);
                 crate::quantized::BlockIQ4nl::to_float(&vec, &mut out)?;
             }
+            GgmlDType::Iq3Xxs => {
+                let vec: Vec<crate::quantized::BlockIQ3xxs> = read_to_vec(&buffer, block_len);
+                crate::quantized::BlockIQ3xxs::to_float(&vec, &mut out)?;
+            }
         }
 
         let buffer = self.device.new_buffer_with_data(&out)?;
@@ -397,6 +401,7 @@ impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
             GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
             GgmlDType::Iq4Xs => candle_metal_kernels::GgmlDType::Iq4Xs,
             GgmlDType::Iq4Nl => candle_metal_kernels::GgmlDType::Iq4Nl,
+            GgmlDType::Iq3Xxs => candle_metal_kernels::GgmlDType::Iq3Xxs,
             GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
             GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
             GgmlDType::BF16 => candle_metal_kernels::GgmlDType::F16,
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index 65aaace433..01abafae25 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -204,6 +204,7 @@ pub enum GgmlDType {
     Q8K,
     Iq4Xs,
     Iq4Nl,
+    Iq3Xxs,
 }
 
 impl GgmlDType {
@@ -223,6 +224,7 @@ impl GgmlDType {
             13 => Self::Q5K,
             14 => Self::Q6K,
             15 => Self::Q8K,
+            18 => Self::Iq3Xxs,
             20 => Self::Iq4Nl,
             23 => Self::Iq4Xs,
             // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
@@ -248,6 +250,7 @@ impl GgmlDType {
             Self::Q5K => 13,
             Self::Q6K => 14,
             Self::Q8K => 15,
+            Self::Iq3Xxs => 18,
             Self::Iq4Nl => 20,
             Self::Iq4Xs => 23,
             // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
@@ -280,6 +283,10 @@ impl GgmlDType {
                 BlockIQ4xs::zeros();
                 elem_count / BlockIQ4xs::BLCK_SIZE
             ]),
+            Self::Iq3Xxs => Box::new(vec![
+                BlockIQ3xxs::zeros();
+                elem_count / BlockIQ3xxs::BLCK_SIZE
+            ]),
             Self::BF16 => Box::new(vec![bf16::zeros(); elem_count]),
         }
     }
 
@@ -304,6 +311,7 @@ impl GgmlDType {
             Self::Q8K => std::mem::size_of::<BlockQ8K>(),
             Self::Iq4Nl => std::mem::size_of::<BlockIQ4nl>(),
             Self::Iq4Xs => std::mem::size_of::<BlockIQ4xs>(),
+            Self::Iq3Xxs => std::mem::size_of::<BlockIQ3xxs>(),
         }
     }
 
@@ -319,9 +327,14 @@ impl GgmlDType {
             Self::Q8_0 => k_quants::QK8_0,
             Self::Q8_1 => k_quants::QK8_1,
             Self::Iq4Nl => iq_quants::QK4_NL,
-            Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K | Self::Iq4Xs => {
-                k_quants::QK_K
-            }
+            Self::Q2K
+            | Self::Q3K
+            | Self::Q4K
+            | Self::Q5K
+            | Self::Q6K
+            | Self::Q8K
+            | Self::Iq4Xs
+            | Self::Iq3Xxs => k_quants::QK_K,
         }
     }
 }
diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs
index 34fc6afc9e..b665b1d2ac 100644
--- a/candle-core/tests/quantized_tests.rs
+++ b/candle-core/tests/quantized_tests.rs
@@ -1023,6 +1023,44 @@ fn quantize_iq4_nl(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn quantize_iq3_xxs(device: &Device) -> Result<()> {
+    let dtype = GgmlDType::Iq3Xxs;
+    let src = get_test_vector2(0.5, 1024, device)?;
+    let quant = quantized::QTensor::quantize(&src, dtype)?;
+    let dst = quant.dequantize(device)?;
+    let dst_f16 = quant.dequantize_f16(device)?;
+    dbg!(&dst.mean_all()?);
+    let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
+        .to_dtype(DType::F32)?
+        .abs()?
+        .sum_all()?
+        .to_vec0::<f32>()?;
+    assert_eq!(diff, 0.);
+
+    let src = src.to_vec1::<f32>()?;
+    let dst = dst.to_vec1::<f32>()?;
+    compare_with_error(dst.as_slice(), src.as_slice(), 0.025);
+
+    let src_big = get_test_vector2(128.0, 1024, device)?;
+    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
+    let dst_big = quant_big.dequantize(device)?;
+    let dst_big_f16 = quant_big.dequantize_f16(device)?;
+    let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
+        .to_dtype(DType::F32)?
+        .abs()?
+        .sum_all()?
+        .to_vec0::<f32>()?;
+    assert_eq!(diff, 0.);
+
+    let src_big = src_big.to_vec1::<f32>()?;
+    let dst_big = dst_big.to_vec1::<f32>()?;
+    compare_with_error(dst_big.as_slice(), src_big.as_slice(), 5.9);
+
+    ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
+
+    Ok(())
+}
+
 #[test]
 fn imatrix_quantize_iq4_xs() -> Result<()> {
     // let data =
@@ -1185,6 +1223,12 @@ test_device!(
     quantize_iq4_nl_cuda,
     quantize_iq4_nl_metal
 );
+test_device!(
+    quantize_iq3_xxs,
+    quantize_iq3_xxs_cpu,
+    quantize_iq3_xxs_cuda,
+    quantize_iq3_xxs_metal
+);
 
 /// Very simple dot product implementation
 fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 {
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 9d31dd7f81..6d6ece376f 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -2449,6 +2449,7 @@ pub enum GgmlDType {
     BF16,
     Iq4Xs,
     Iq4Nl,
+    Iq3Xxs,
 }
 
 #[allow(clippy::too_many_arguments)]
@@ -2545,6 +2546,12 @@ pub fn call_quantized_matmul_mv_t(
             let align = 4;
             (nth0, nth1, align, Some(32 * std::mem::size_of::<f32>()))
         }
+        GgmlDType::Iq3Xxs => {
+            let nth0 = 4;
+            let nth1 = 16;
+            let align = 8;
+            (nth0, nth1, align, Some(256 * 4 + 128))
+        }
     };
     let thread_groups_count = MTLSize {
         width: divide(ne01 as usize, align),
@@ -2574,6 +2581,7 @@ pub fn call_quantized_matmul_mv_t(
         GgmlDType::F32 => "kernel_mul_mv_f32_f32",
         GgmlDType::Iq4Xs => "kernel_mul_mv_iq4_xs_f32",
         GgmlDType::Iq4Nl => "kernel_mul_mv_iq4_nl_f32",
+        GgmlDType::Iq3Xxs => "kernel_mul_mv_iq3_xxs_f32",
     };
     let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
 
@@ -2688,6 +2696,7 @@ pub fn call_quantized_matmul_mm_t(
         GgmlDType::F32 => "kernel_mul_mm_f32_f32",
         GgmlDType::Iq4Xs => "kernel_mul_mm_iq4_xs_f32",
         GgmlDType::Iq4Nl => "kernel_mul_mm_iq4_nl_f32",
+        GgmlDType::Iq3Xxs => "kernel_mul_mm_iq3_xxs_f32",
     };
     let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
 
From d3909ae33af7dd7370245745b9ad9ff23bea1474 Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Thu, 6 Feb 2025 19:55:46 -0500
Subject: [PATCH 23/26] Some fixes

---
 candle-core/src/quantized/iq_quants/mod.rs   | 12 ++++++------
 candle-core/src/quantized/iq_quants/utils.rs | 14 +++++++-------
 candle-core/tests/quantized_tests.rs         |  1 +
 3 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/candle-core/src/quantized/iq_quants/mod.rs b/candle-core/src/quantized/iq_quants/mod.rs
index 2c2c6f836b..298617a1e5 100644
--- a/candle-core/src/quantized/iq_quants/mod.rs
+++ b/candle-core/src/quantized/iq_quants/mod.rs
@@ -421,9 +421,9 @@ impl GgmlType for BlockIQ3xxs {
             // Get pointers to grid1 and grid2.
             // qs[2*l+0] and qs[2*l+1] are used as offsets into IQ3XXS_GRID.
             let idx1 = *qs.add((2 * l + 0) as usize) as usize;
-            let grid1 = IQ3XXS_GRID.as_ptr().add(idx1);
+            let grid1 = IQ3XXS_GRID.as_ptr().add(idx1) as *const u8;
             let idx2 = *qs.add((2 * l + 1) as usize) as usize;
-            let grid2 = IQ3XXS_GRID.as_ptr().add(idx2);
+            let grid2 = IQ3XXS_GRID.as_ptr().add(idx2) as *const u8;
 
             // For each of 4 values in grid1 and grid2.
             for j in 0..4 {
@@ -452,8 +452,8 @@ impl GgmlType for BlockIQ3xxs {
 
     fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
         let k = xs.len();
-        if k % QK4_NL != 0 {
-            bail!("Input length must be multiple of QK4_NL = {}", QK4_NL);
+        if k % QK_K != 0 {
+            bail!("Input length must be multiple of QK_K = {}", QK_K);
         }
 
         unsafe { quantize_iq3_xxs(xs, ys, 1, k, None)?
}; @@ -468,8 +468,8 @@ impl GgmlType for BlockIQ3xxs { n_per_row: usize, ) -> Result<()> { let k = xs.len(); - if k % QK4_NL != 0 { - bail!("Input length must be multiple of QK4_NL = {}", QK4_NL); + if k % QK_K != 0 { + bail!("Input length must be multiple of QK_K = {}", QK_K); } let nrow = xs.len() / n_per_row; diff --git a/candle-core/src/quantized/iq_quants/utils.rs b/candle-core/src/quantized/iq_quants/utils.rs index e4b379a33a..4614de6169 100644 --- a/candle-core/src/quantized/iq_quants/utils.rs +++ b/candle-core/src/quantized/iq_quants/utils.rs @@ -438,12 +438,12 @@ fn iq3xs_init_impl(grid_size: i32) { dist_vec.push((d, j)); } // Sort the vector: first by distance, then by grid index. - dist_vec.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1))); + dist_vec.sort_by(|a, b| a.cmp(&b)); // Count how many neighbors to include. let mut n = 0; let mut nhave = 1; let mut d_current = dist_vec[0].0; - for &(d, _) in dist_vec.iter().step_by(2) { + for &(d, _) in dist_vec.iter() { if d > d_current { if nhave == nwant { break; @@ -480,7 +480,7 @@ fn iq3xs_init_impl(grid_size: i32) { let d = dist2(&grid_bytes, &pos); dist_vec.push((d, j)); } - dist_vec.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1))); + dist_vec.sort_by(|a, b| a.cmp(&b)); // Store negative index in kmap to indicate start offset in the neighbours vector. kmap[i] = -((nbr_counter as i32) + 1); @@ -493,7 +493,7 @@ fn iq3xs_init_impl(grid_size: i32) { let mut n = 0; let mut nhave = 1; let mut d_current = dist_vec[0].0; - for &(d, j) in dist_vec.iter().step_by(2) { + for &(d, j) in dist_vec.iter() { if d > d_current { if nhave == nwant { break; @@ -502,7 +502,7 @@ fn iq3xs_init_impl(grid_size: i32) { nhave += 1; } // Store the grid index as u16. - neighbours.push(j as u16 + 1); + neighbours.push(j as u16); nbr_counter += 1; n += 1; } @@ -702,7 +702,7 @@ pub unsafe fn quantize_row_iq3_xxs_impl( if grid_index < 0 { is_on_grid_aux[k] = false; let neighbours = - kneighbors_q3xs.offset(-(*kmap_q3xs.add(u as usize)) as isize - 1); + kneighbors_q3xs.offset(-(*kmap_q3xs.add(u as usize) + 1) as isize); grid_index = iq3_find_best_neighbour( neighbours, kgrid_q3xs, @@ -753,7 +753,7 @@ pub unsafe fn quantize_row_iq3_xxs_impl( let mut grid_index = *kmap_q3xs.add(u as usize); if grid_index < 0 { let neighbours = - kneighbors_q3xs.offset(-(*kmap_q3xs.add(u as usize)) as isize - 1); + kneighbors_q3xs.offset(-(*kmap_q3xs.add(u as usize) + 1) as isize); grid_index = iq3_find_best_neighbour( neighbours, kgrid_q3xs, diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index b665b1d2ac..5de4828d7b 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -1029,6 +1029,7 @@ fn quantize_iq3_xxs(device: &Device) -> Result<()> { let quant = quantized::QTensor::quantize(&src, dtype)?; let dst = quant.dequantize(device)?; let dst_f16 = quant.dequantize_f16(device)?; + dbg!(&src.mean_all()?); dbg!(&dst.mean_all()?); let diff = (dst.to_dtype(DType::F16)? - dst_f16)? .to_dtype(DType::F32)? 
From 0f862b2d0b5445bdc8fac7679616926b8f127a92 Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Fri, 7 Feb 2025 14:49:43 -0500
Subject: [PATCH 24/26] Add f8q8

---
 candle-core/src/quantized/avx.rs          |  22 +++-
 candle-core/src/quantized/k_quants/mod.rs | 122 ++++++++++++++++++++++
 candle-core/src/quantized/metal.rs        |   5 +
 candle-core/src/quantized/mod.rs          |   7 +-
 candle-core/src/quantized/neon.rs         |  98 ++++++++++++++++-
 candle-core/src/quantized/simd128.rs      |  44 +++++++-
 candle-core/tests/quantized_tests.rs      |  26 +++++
 7 files changed, 319 insertions(+), 5 deletions(-)

diff --git a/candle-core/src/quantized/avx.rs b/candle-core/src/quantized/avx.rs
index 664f7653ee..427b34b133 100644
--- a/candle-core/src/quantized/avx.rs
+++ b/candle-core/src/quantized/avx.rs
@@ -1,5 +1,6 @@
 use super::k_quants::{
-    BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K,
+    BlockF8Q8, BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0,
+    QK8_0, QK_K,
 };
 use crate::Result;
 use byteorder::{ByteOrder, LittleEndian};
@@ -87,6 +88,25 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
     }
 }
 
+#[inline(always)]
+pub(crate) fn vec_dot_f8q8_q8_0(n: usize, xs: &[BlockF8Q8], ys: &[BlockQ8_0]) -> Result<f32> {
+    let qk = QK8_0;
+    if n % QK8_0 != 0 {
+        crate::bail!("vec_dot_f8q8_q8_0: {n} is not divisible by {qk}")
+    }
+    unsafe {
+        let mut acc = _mm256_setzero_ps();
+        for (x, y) in xs.iter().zip(ys.iter()) {
+            let d = _mm256_set1_ps(x.dq_d() * f16::to_f32(y.d));
+            let bx = _mm256_loadu_si256(x.qs.as_ptr() as *const __m256i);
+            let by = _mm256_loadu_si256(y.qs.as_ptr() as *const __m256i);
+            let q = mul_sum_i8_pairs_float(bx, by);
+            acc = _mm256_fmadd_ps(d, q, acc);
+        }
+        Ok(hsum_float_8(acc))
+    }
+}
+
 #[inline(always)]
 unsafe fn get_scale_shuffle(i: usize) -> __m128i {
     const K_SHUFFLE: [u8; 128] = [
diff --git a/candle-core/src/quantized/k_quants/mod.rs b/candle-core/src/quantized/k_quants/mod.rs
index e00fa3a906..fc4c90950c 100644
--- a/candle-core/src/quantized/k_quants/mod.rs
+++ b/candle-core/src/quantized/k_quants/mod.rs
@@ -6,6 +6,7 @@ use super::k_quants::utils::{
 use super::{GgmlDType, GgmlType};
 use crate::Result;
 use byteorder::{ByteOrder, LittleEndian};
+use float8::F8E4M3;
 use half::f16;
 
 // Default to QK_K 256 rather than 64.
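[Editor's note] The F8Q8 block added below stores its per-block scale in an
F8E4M3 float. The raw scale `d = amax / 127` is typically far below F8E4M3's
normal range, so the patch stores `d * F8E4M3::MAX` and `dq_d()` divides it
back out on use. A minimal sketch of that round trip, using the same `float8`
crate as the patch (the numeric values here are illustrative):

    use float8::F8E4M3;

    fn main() {
        let amax = 0.5f32; // largest |x| in a 32-value block (made up)
        let d = amax / 127.0; // per-block scale, ~0.0039
        // Scale up into F8E4M3's well-represented range before storing...
        let stored = F8E4M3::from_f32(d * F8E4M3::MAX.to_f32());
        // ...and divide back out when dequantizing (what dq_d() does).
        let recovered = stored.to_f32() / F8E4M3::MAX.to_f32();
        println!("d = {d:.6}, recovered = {recovered:.6}");
    }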
@@ -63,6 +64,20 @@
 }
 const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);
 
+#[derive(Debug, Clone, PartialEq)]
+#[repr(C)]
+pub struct BlockF8Q8 {
+    d: F8E4M3,
+    pub(crate) qs: [i8; QK8_0],
+}
+const _: () = assert!(std::mem::size_of::<BlockF8Q8>() == 33);
+
+impl BlockF8Q8 {
+    pub fn dq_d(&self) -> f32 {
+        self.d.to_f32() / F8E4M3::MAX.to_f32()
+    }
+}
+
 #[derive(Debug, Clone, PartialEq)]
 #[repr(C)]
 pub struct BlockQ8_1 {
@@ -702,6 +717,113 @@ impl GgmlType for BlockQ8_0 {
     }
 }
 
+impl GgmlType for BlockF8Q8 {
+    const DTYPE: GgmlDType = GgmlDType::F8Q8;
+    const BLCK_SIZE: usize = QK8_0;
+    type VecDotType = BlockQ8_0;
+    const SUPPORTS_I8MM: bool = true;
+
+    // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1619
+    fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
+        let k = ys.len();
+        if k % QK8_0 != 0 {
+            crate::bail!("dequantize_row_f8q8: {k} is not divisible by {QK8_0}");
+        }
+
+        let nb = k / QK8_0;
+
+        for i in 0..nb {
+            let d = xs[i].d.to_f32();
+
+            for j in 0..QK8_0 {
+                ys[i * QK8_0 + j] = xs[i].qs[j] as f32 * d;
+            }
+        }
+        Ok(())
+    }
+
+    fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
+        // quantize_row_q8_0
+        let k = xs.len();
+        if k % Self::BLCK_SIZE != 0 {
+            crate::bail!("{k} is not divisible by {}", Self::BLCK_SIZE);
+        };
+        let nb = k / Self::BLCK_SIZE;
+        if ys.len() != nb {
+            crate::bail!(
+                "size mismatch {} {} {}",
+                xs.len(),
+                ys.len(),
+                Self::BLCK_SIZE
+            )
+        }
+        for (i, ys) in ys.iter_mut().enumerate() {
+            let mut amax = 0f32;
+            let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE];
+            for &x in xs.iter() {
+                amax = amax.max(x.abs())
+            }
+            let d = amax / ((1 << 7) - 1) as f32;
+            let id = if d != 0f32 { 1. / d } else { 0. };
+            ys.d = F8E4M3::from_f32(d * F8E4M3::MAX.to_f32());
+            for (y, &x) in ys.qs.iter_mut().zip(xs.iter()) {
+                *y = f32::round(x * id) as i8
+            }
+        }
+        Ok(())
+    }
+
+    #[allow(unreachable_code)]
+    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        #[cfg(target_feature = "avx")]
+        return super::avx::vec_dot_f8q8_q8_0(n, xs, ys);
+
+        #[cfg(target_feature = "neon")]
+        return super::neon::vec_dot_f8q8_q8_0(n, xs, ys);
+
+        #[cfg(target_feature = "simd128")]
+        return super::simd128::vec_dot_f8q8_q8_0(n, xs, ys);
+
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        let qk = QK8_0;
+        if n % QK8_0 != 0 {
+            crate::bail!("vec_dot_f8q8_q8_0: {n} is not divisible by {qk}")
+        }
+
+        // Generic implementation.
+        let mut sumf = 0f32;
+        for (xs, ys) in xs.iter().zip(ys.iter()) {
+            let sum_i = xs
+                .qs
+                .iter()
+                .zip(ys.qs.iter())
+                .map(|(&x, &y)| x as i32 * y as i32)
+                .sum::<i32>();
+            sumf += sum_i as f32 * xs.dq_d() * f16::to_f32(ys.d)
+        }
+        Ok(sumf)
+    }
+
+    #[allow(unreachable_code)]
+    #[allow(unused)]
+    #[cfg(feature = "arm-nightly-feat")]
+    fn matmul_i8mm(
+        n: usize,
+        xs_0: &[Self],
+        xs_1: &[Self],
+        ys_0: &[Self::VecDotType],
+        ys_1: &[Self::VecDotType],
+    ) -> Result<[f32; 4]> {
+        #[cfg(target_feature = "neon")]
+        return super::neon::i8mm_q8_0_q8_0(n, xs_0, xs_1, ys_0, ys_1);
+
+        crate::bail!("Unsupported block type for i8mm");
+    }
+}
+
 impl GgmlType for BlockQ8_1 {
     const DTYPE: GgmlDType = GgmlDType::Q8_1;
     const BLCK_SIZE: usize = QK8_1;
diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs
index e4b5bf9a28..9ff1268f14 100644
--- a/candle-core/src/quantized/metal.rs
+++ b/candle-core/src/quantized/metal.rs
@@ -119,6 +119,10 @@ impl QMetalStorage {
                 let vec: Vec<crate::quantized::BlockIQ3xxs> = read_to_vec(&buffer, block_len);
                 crate::quantized::BlockIQ3xxs::to_float(&vec, &mut out)?;
             }
+            GgmlDType::F8Q8 => {
+                let vec: Vec<crate::quantized::BlockF8Q8> = read_to_vec(&buffer, block_len);
+                crate::quantized::BlockF8Q8::to_float(&vec, &mut out)?;
+            }
         }
 
         let buffer = self.device.new_buffer_with_data(&out)?;
@@ -405,6 +409,7 @@ impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
             GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
             GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
             GgmlDType::BF16 => candle_metal_kernels::GgmlDType::F16,
+            GgmlDType::F8Q8 => todo!("F8Q8 is unsupported on Metal"),
         }
     }
 }
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index 01abafae25..727ebad4d3 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -205,6 +205,7 @@ pub enum GgmlDType {
     Iq4Xs,
     Iq4Nl,
     Iq3Xxs,
+    F8Q8,
 }
 
 impl GgmlDType {
@@ -229,6 +230,7 @@ impl GgmlDType {
             23 => Self::Iq4Xs,
             // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
             30 => Self::BF16,
+            100 => Self::F8Q8,
             _ => crate::bail!("unknown dtype for tensor {u}"),
         };
         Ok(dtype)
@@ -255,6 +257,7 @@ impl GgmlDType {
             Self::Iq4Xs => 23,
             // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
             Self::BF16 => 30,
+            Self::F8Q8 => 100,
         }
     }
 
@@ -287,6 +290,7 @@ impl GgmlDType {
                 BlockIQ3xxs::zeros();
                 elem_count / BlockIQ3xxs::BLCK_SIZE
             ]),
+            Self::F8Q8 => Box::new(vec![BlockF8Q8::zeros(); elem_count / BlockF8Q8::BLCK_SIZE]),
             Self::BF16 => Box::new(vec![bf16::zeros(); elem_count]),
         }
     }
@@ -312,6 +316,7 @@ impl GgmlDType {
             Self::Q8K => std::mem::size_of::<BlockQ8K>(),
             Self::Iq4Nl => std::mem::size_of::<BlockIQ4nl>(),
             Self::Iq4Xs => std::mem::size_of::<BlockIQ4xs>(),
            Self::Iq3Xxs => std::mem::size_of::<BlockIQ3xxs>(),
+            Self::F8Q8 => std::mem::size_of::<BlockF8Q8>(),
         }
     }
 
@@ -324,7 +329,7 @@ impl GgmlDType {
             Self::Q4_1 => k_quants::QK4_1,
             Self::Q5_0 => k_quants::QK5_0,
             Self::Q5_1 => k_quants::QK5_1,
-            Self::Q8_0 => k_quants::QK8_0,
+            Self::Q8_0 | Self::F8Q8 => k_quants::QK8_0,
             Self::Q8_1 => k_quants::QK8_1,
             Self::Iq4Nl => iq_quants::QK4_NL,
             Self::Q2K
diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs
index df8389075e..23c8484d57 100644
--- a/candle-core/src/quantized/neon.rs
+++ b/candle-core/src/quantized/neon.rs
@@ -1,8 +1,8 @@
 use super::{
     iq_quants::{BlockIQ4nl, BlockIQ4xs, QK4_NL},
     k_quants::{
-        BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0,
-        QK_K,
+        BlockF8Q8, BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K,
+        BlockQ8_0, QK8_0, QK_K,
     },
 };
 use crate::{quantized::KVALUES_IQ4NL, Result};
@@ -235,6 +235,100 @@ pub(crate) fn i8mm_q8_0_q8_0(
     }
 }
 
+#[inline(always)]
+pub(crate) fn vec_dot_f8q8_q8_0(n: usize, xs: &[BlockF8Q8], ys: &[BlockQ8_0]) -> Result<f32> {
+    let qk = QK8_0;
+    if n % QK8_0 != 0 {
+        crate::bail!("vec_dot_f8q8_q8_0: {n} is not divisible by {qk}")
+    }
+    let nb = n / QK8_0;
+    unsafe {
+        let mut sumv0 = vdupq_n_f32(0.0f32);
+        for i in 0..nb {
+            let x0 = &xs[i];
+            let y0 = &ys[i];
+
+            let x0_0 = vld1q_s8(x0.qs.as_ptr());
+            let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16));
+
+            // load y
+            let y0_0 = vld1q_s8(y0.qs.as_ptr());
+            let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
+
+            let p0 = vdotq_s32_local(vdupq_n_s32(0), x0_0, y0_0);
+            let p1 = vdotq_s32_local(vdupq_n_s32(0), x0_1, y0_1);
+
+            sumv0 = vmlaq_n_f32(
+                sumv0,
+                vcvtq_f32_s32(vaddq_s32(p0, p1)),
+                x0.dq_d() * y0.d.to_f32(),
+            );
+        }
+        Ok(vaddvq_f32(sumv0))
+    }
+}
+#[inline(always)]
+#[cfg(feature = "arm-nightly-feat")]
+pub(crate) fn i8mm_q8_0_q8_0(
+    n: usize,
+    xs_0: &[BlockF8Q8],
+    xs_1: &[BlockF8Q8],
+    ys_0: &[BlockQ8_0],
+    ys_1: &[BlockQ8_0],
+) -> Result<[f32; 4]> {
+    assert_eq!(xs_0.len(), xs_1.len());
+    assert_eq!(ys_0.len(), ys_1.len());
+    assert_eq!(xs_0.len(), ys_0.len());
+    let qk = QK8_0;
+    if n % QK8_0 != 0 {
+        crate::bail!("i8mm_q8_0_q8_0: {n} is not divisible by {qk}")
+    }
+    let nb = n / QK8_0;
+    unsafe {
+        let mut sum_f32 = vdupq_n_f32(0.0);
+
+        for i in 0..nb {
+            let x0 = &xs_0[i];
+            let x1 = &xs_1[i];
+            let y0 = &ys_0[i];
+            let y1 = &ys_1[i];
+
+            let factor_00: f32 = x0.dq_d() * y0.d.to_f32();
+            let factor_01: f32 = x1.dq_d() * y0.d.to_f32();
+            let factor_10: f32 = x0.dq_d() * y1.d.to_f32();
+            let factor_11: f32 = x1.dq_d() * y1.d.to_f32();
+
+            let xv0_0 = vld1q_s8(x0.qs.as_ptr());
+            let xv0_1 = vld1q_s8(x0.qs.as_ptr().add(16));
+            let xv1_0 = vld1q_s8(x1.qs.as_ptr());
+            let xv1_1 = vld1q_s8(x1.qs.as_ptr().add(16));
+
+            let yv0_0 = vld1q_s8(y0.qs.as_ptr());
+            let yv0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
+            let yv1_0 = vld1q_s8(y1.qs.as_ptr());
+            let yv1_1 = vld1q_s8(y1.qs.as_ptr().add(16));
+
+            let i8mm = i8mm_params::new(xv0_0, xv0_1, xv1_0, xv1_1, yv0_0, yv0_1, yv1_0, yv1_1);
+            let loop_sum_s32 = i8mm.calculate(vdupq_n_s32(0));
+
+            // scaling
+            let factor_elems: [f32; 4] = [factor_00, factor_01, factor_10, factor_11];
+            let rawptr = factor_elems.as_ptr();
+            let factor: float32x4_t = vld1q_f32(rawptr);
+            let loop_sum_f32 = vcvtq_f32_s32(loop_sum_s32);
+
+            sum_f32 = vmlaq_f32(sum_f32, loop_sum_f32, factor);
+        }
+        // extract elements of the vector register (the std intrinsic takes the
+        // lane as a const generic)
+        let f0 = vgetq_lane_f32::<0>(sum_f32);
+        let f1 = vgetq_lane_f32::<1>(sum_f32);
+        let f2 = vgetq_lane_f32::<2>(sum_f32);
+        let f3 = vgetq_lane_f32::<3>(sum_f32);
+        let res: [f32; 4] = [f0, f1, f2, f3];
+        Ok(res)
+    }
+}
+
 #[inline(always)]
 pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
     let qk = QK_K;
diff --git a/candle-core/src/quantized/simd128.rs b/candle-core/src/quantized/simd128.rs
index 1c8c0f2068..66ef2bf5b4 100644
--- a/candle-core/src/quantized/simd128.rs
+++ b/candle-core/src/quantized/simd128.rs
@@ -1,4 +1,6 @@
-use super::k_quants::{BlockQ2K, BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
+use super::k_quants::{
+    BlockF8Q8, BlockQ2K, BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K,
+};
 use crate::Result;
 use byteorder::{ByteOrder, LittleEndian};
 use half::f16;
@@ -91,6 +93,46 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
     }
 }
 
+#[inline(always)]
+pub(crate) fn vec_dot_f8q8_q8_0(n: usize, xs: &[BlockF8Q8], ys: &[BlockQ8_0]) -> Result<f32> {
+    let qk = QK8_0;
+    if n % QK8_0 != 0 {
+        crate::bail!("vec_dot_f8q8_q8_0: {n} is not divisible by {qk}")
+    }
+    unsafe {
+        let mut acc = f32x4_splat(0.0f32);
+        for (x, y) in xs.iter().zip(ys.iter()) {
+            let x1 = i16x8_load_extend_i8x8(x.qs.as_ptr());
+            let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr());
+            let sum_xy = i32x4_dot_i16x8(x1, y1);
+
+            let x2 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(8));
+            let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8));
+            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2));
+
+            let x3 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(16));
+            let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16));
+            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3));
+
+            let x4 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(24));
+            let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24));
+            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4));
+
+            let sum_xy = f32x4_convert_i32x4(sum_xy);
+
+            // f32x4_relaxed_madd is nightly only.
+            let d = f32x4_splat(x.dq_d() * f16::to_f32(y.d));
+            let scaled = f32x4_mul(sum_xy, d);
+            acc = f32x4_add(acc, scaled)
+        }
+        let res = f32x4_extract_lane::<0>(acc)
+            + f32x4_extract_lane::<1>(acc)
+            + f32x4_extract_lane::<2>(acc)
+            + f32x4_extract_lane::<3>(acc);
+        Ok(res)
+    }
+}
+
 #[inline(always)]
 pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
     if n % QK_K != 0 {
diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs
index 5de4828d7b..66bbd3f8c2 100644
--- a/candle-core/tests/quantized_tests.rs
+++ b/candle-core/tests/quantized_tests.rs
@@ -386,6 +386,26 @@ fn quantize_q5_1(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn quantize_f8q8(device: &Device) -> Result<()> {
+    let dtype = GgmlDType::F8Q8;
+    let src = get_test_vector2(0.5, 1024, device)?;
+    let quant = quantized::QTensor::quantize(&src, dtype)?;
+    let dst = quant.dequantize(device)?;
+    let dst_f16 = quant.dequantize_f16(device)?;
+    let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
+        .to_dtype(DType::F32)?
+        .abs()?
+        .sum_all()?
+        .to_vec0::<f32>()?;
+    assert_eq!(diff, 0.);
+
+    let src = src.to_vec1::<f32>()?;
+    let dst = dst.to_vec1::<f32>()?;
+    compare_with_error(dst.as_slice(), src.as_slice(), 0.01);
+
+    Ok(())
+}
+
 fn get_test_vector2(bound: f32, size: usize, device: &Device) -> Result<Tensor> {
     assert!(
         size % crate::quantized::k_quants::QK_K == 0,
@@ -1176,6 +1196,12 @@ test_device!(
     quantize_q5_1_cuda,
     quantize_q5_1_metal
 );
+test_device!(
+    quantize_f8q8,
+    quantize_f8q8_cpu,
+    quantize_f8q8_cuda,
+    quantize_f8q8_metal
+);
 test_device!(
     quantize_q2k,
     quantize_q2k_cpu,
From 9557e2355af95fc39c9ed11a48202beb2a4503fa Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Sat, 8 Feb 2025 11:04:46 -0500
Subject: [PATCH 25/26] Fix f8q8 dequant

---
 candle-core/src/quantized/k_quants/mod.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/candle-core/src/quantized/k_quants/mod.rs b/candle-core/src/quantized/k_quants/mod.rs
index fc4c90950c..9b83767d58 100644
--- a/candle-core/src/quantized/k_quants/mod.rs
+++ b/candle-core/src/quantized/k_quants/mod.rs
@@ -733,7 +733,7 @@ impl GgmlType for BlockF8Q8 {
         let nb = k / QK8_0;
 
         for i in 0..nb {
-            let d = xs[i].d.to_f32();
+            let d = xs[i].dq_d();
 
             for j in 0..QK8_0 {
                 ys[i * QK8_0 + j] = xs[i].qs[j] as f32 * d;
From 48804a550d75dac331cf8920d0400a3d95156af3 Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Sat, 22 Feb 2025 16:28:45 -0500
Subject: [PATCH 26/26] Fix i8mm

---
 candle-core/src/quantized/k_quants/mod.rs | 2 +-
 candle-core/src/quantized/neon.rs         | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/candle-core/src/quantized/k_quants/mod.rs b/candle-core/src/quantized/k_quants/mod.rs
index 9b83767d58..1c07847a73 100644
--- a/candle-core/src/quantized/k_quants/mod.rs
+++ b/candle-core/src/quantized/k_quants/mod.rs
@@ -818,7 +818,7 @@ impl GgmlType for BlockF8Q8 {
         ys_1: &[Self::VecDotType],
     ) -> Result<[f32; 4]> {
         #[cfg(target_feature = "neon")]
-        return super::neon::i8mm_q8_0_q8_0(n, xs_0, xs_1, ys_0, ys_1);
+        return super::neon::i8mm_f8q8_q8_0(n, xs_0, xs_1, ys_0, ys_1);
 
         crate::bail!("Unsupported block type for i8mm");
     }
diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs
index 23c8484d57..200911ddf3 100644
--- a/candle-core/src/quantized/neon.rs
+++ b/candle-core/src/quantized/neon.rs
@@ -269,7 +269,7 @@ pub(crate) fn vec_dot_f8q8_q8_0(n: usize, xs: &[BlockF8Q8], ys: &[BlockQ8_0]) -> Result<f32> {
 }
 #[inline(always)]
 #[cfg(feature = "arm-nightly-feat")]
-pub(crate) fn i8mm_q8_0_q8_0(
+pub(crate) fn i8mm_f8q8_q8_0(
     n: usize,
     xs_0: &[BlockF8Q8],
     xs_1: &[BlockF8Q8],
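[Editor's note] A minimal end-to-end sketch of the quantize/dequantize path
exercised by the tests in this series (tensor contents and shapes are
illustrative; the APIs are the ones the tests call):

    use candle_core::{quantized, Device, Tensor};
    use candle_core::quantized::GgmlDType;

    fn main() -> candle_core::Result<()> {
        let device = Device::Cpu;
        // IQ3_XXS blocks cover 256 values, so lengths must be multiples of 256.
        let src = Tensor::rand(-1f32, 1f32, 1024, &device)?;
        let quant = quantized::QTensor::quantize(&src, GgmlDType::Iq3Xxs)?;
        let dst = quant.dequantize(&device)?;
        let err = (dst - src)?.abs()?.mean_all()?.to_vec0::<f32>()?;
        println!("mean abs dequantization error: {err}");
        Ok(())
    }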