diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 752e478d9b..cd44ec3944 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -20,6 +20,7 @@ gemm = { workspace = true } half = { workspace = true } float8 = { workspace = true } intel-mkl-src = { workspace = true, optional = true } +itertools = "0.12.1" libc = { workspace = true, optional = true } memmap2 = { workspace = true } num-traits = { workspace = true } @@ -46,6 +47,7 @@ nccl = ["cuda", "cudarc/nccl"] mkl = ["dep:libc", "dep:intel-mkl-src"] accelerate = ["dep:libc", "dep:accelerate-src"] metal = ["dep:metal", "dep:candle-metal-kernels"] +arm-nightly-feat = [] [[bench]] name = "bench_main" diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index f7571c45de..a0efa00998 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -47,6 +47,11 @@ //! - [candle-transformers](https://docs.rs/candle-transformers/). Candle implemntation of many published transformer models. //! +#![cfg_attr(feature = "arm-nightly-feat", feature(stdarch_neon_dotprod))] +#![cfg_attr(feature = "arm-nightly-feat", feature(array_chunks))] +#![cfg_attr(feature = "arm-nightly-feat", feature(stdarch_neon_i8mm))] +#![cfg_attr(feature = "arm-nightly-feat", feature(portable_simd))] + #[cfg(feature = "accelerate")] mod accelerate; pub mod backend; diff --git a/candle-core/src/quantized/avx.rs b/candle-core/src/quantized/avx.rs index 664f7653ee..427b34b133 100644 --- a/candle-core/src/quantized/avx.rs +++ b/candle-core/src/quantized/avx.rs @@ -1,5 +1,6 @@ use super::k_quants::{ - BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K, + BlockF8Q8, BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, + QK8_0, QK_K, }; use crate::Result; use byteorder::{ByteOrder, LittleEndian}; @@ -87,6 +88,25 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> } } +#[inline(always)] +pub(crate) fn vec_dot_f8q8_q8_0(n: usize, xs: &[BlockF8Q8], ys: &[BlockQ8_0]) -> Result { + let qk = QK8_0; + if n % QK8_0 != 0 { + crate::bail!("vec_dot_f8q8_q8_0: {n} is not divisible by {qk}") + } + unsafe { + let mut acc = _mm256_setzero_ps(); + for (x, y) in xs.iter().zip(ys.iter()) { + let d = _mm256_set1_ps(x.dq_d() * f16::to_f32(y.d)); + let bx = _mm256_loadu_si256(x.qs.as_ptr() as *const __m256i); + let by = _mm256_loadu_si256(y.qs.as_ptr() as *const __m256i); + let q = mul_sum_i8_pairs_float(bx, by); + acc = _mm256_fmadd_ps(d, q, acc); + } + Ok(hsum_float_8(acc)) + } +} + #[inline(always)] unsafe fn get_scale_shuffle(i: usize) -> __m128i { const K_SHUFFLE: [u8; 128] = [ diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs index 19ca38495f..158fb019a0 100644 --- a/candle-core/src/quantized/cuda.rs +++ b/candle-core/src/quantized/cuda.rs @@ -1,5 +1,5 @@ use super::{GgmlDType, QStorage}; -use crate::quantized::k_quants::GgmlType; +use crate::quantized::quants::GgmlType; use crate::{backend::BackendDevice, cuda_backend::WrapErr}; use crate::{CudaDevice, CudaStorage, Result}; use half::f16; diff --git a/candle-core/src/quantized/ggml_file.rs b/candle-core/src/quantized/ggml_file.rs index ea5ec02578..bae6cafdae 100644 --- a/candle-core/src/quantized/ggml_file.rs +++ b/candle-core/src/quantized/ggml_file.rs @@ -1,6 +1,6 @@ //! Support for the GGML file format. 
-use super::{k_quants, GgmlDType, QStorage}; +use super::{iq_quants, k_quants, GgmlDType, QStorage}; use crate::{Device, Result}; use byteorder::{LittleEndian, ReadBytesExt}; use std::collections::HashMap; @@ -184,6 +184,9 @@ pub fn qtensor_from_ggml( GgmlDType::Q6K => { from_raw_data::(raw_data, size_in_bytes, dims, device) } + GgmlDType::Iq4Xs => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"), } } diff --git a/candle-core/src/quantized/iq_quants/mod.rs b/candle-core/src/quantized/iq_quants/mod.rs new file mode 100644 index 0000000000..298617a1e5 --- /dev/null +++ b/candle-core/src/quantized/iq_quants/mod.rs @@ -0,0 +1,651 @@ +use std::ptr; + +use half::f16; +use utils::{quantize_row_iq3_xxs_impl, quantize_row_iq4_nl_impl}; + +use crate::{bail, Result}; + +mod utils; + +use super::{k_quants::BlockQ8_0, BlockQ8K, GgmlDType, GgmlType, QK_K}; + +pub const QK4_NL: usize = 32; + +pub(super) const KVALUES_IQ4NL: [i8; 16] = [ + -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113, +]; + +const K_MASK_IQ2XS: [u8; 8] = [1, 2, 4, 8, 16, 32, 64, 128]; + +const K_SIGNS_IQ2XS: [u8; 128] = [ + 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15, 144, 17, 18, 147, 20, 149, + 150, 23, 24, 153, 154, 27, 156, 29, 30, 159, 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, + 43, 172, 45, 46, 175, 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63, + 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207, 80, 209, 210, 83, 212, + 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95, 96, 225, 226, 99, 228, 101, 102, 231, 232, + 105, 106, 235, 108, 237, 238, 111, 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, + 252, 125, 126, 255, +]; + +const IQ3XXS_GRID: [u32; 256] = [ + 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414, + 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, + 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404, + 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e, + 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c, + 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c, + 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34, + 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c, + 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c, + 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04, + 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c, + 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414, + 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434, + 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c, + 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e, + 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24, + 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24, + 0x143e040c, 0x143e041c, 0x143e0c34, 
0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c, + 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c, + 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14, + 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414, + 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e, + 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404, + 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c, + 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c, + 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14, + 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c, + 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c, + 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14, + 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14, + 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c, + 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, +]; + +#[derive(Debug, Clone, PartialEq)] +#[repr(C)] +pub struct BlockIQ4xs { + pub(crate) d: f16, + pub(crate) scales_h: u16, + pub(crate) scales_l: [u8; QK_K / 64], + pub(crate) qs: [u8; QK_K / 2], +} + +const _: () = assert!( + std::mem::size_of::() + == std::mem::size_of::() + std::mem::size_of::() + QK_K / 64 + QK_K / 2, + "wrong iq4_xs block size/padding" +); + +impl GgmlType for BlockIQ4xs { + const DTYPE: GgmlDType = GgmlDType::Iq4Xs; + const BLCK_SIZE: usize = QK_K; + type VecDotType = BlockQ8K; + const SUPPORTS_I8MM: bool = false; + + fn to_float(xs: &[Self], mut ys: &mut [f32]) -> Result<()> { + let k = ys.len(); + if k % QK_K != 0 { + crate::bail!("dequantize block iq4xs {k} is not divisible by {QK_K}"); + } + + let nb = k / QK_K; + for block in xs.iter().take(nb) { + let d = block.d.to_f32(); + let mut qs = &block.qs[..]; + + for ib in 0..(QK_K / 32) { + let ib_div_2 = ib / 2; + let ib_mod_2 = ib % 2; + + let ls_low = (block.scales_l[ib_div_2] as i32 >> (4 * ib_mod_2 as i32)) & 0xF; + let ls_high = ((block.scales_h as i32 >> (2 * ib as i32)) & 3) << 4; + let ls = ls_low | ls_high; + + let dl = d * (ls as f32 - 32.); + + for j in 0..16 { + ys[j] = dl * KVALUES_IQ4NL[(qs[j] & 0xF) as usize] as f32; + ys[j + 16] = dl * KVALUES_IQ4NL[(qs[j] >> 4) as usize] as f32; + } + + qs = &qs[16..]; + ys = &mut ys[32..]; + } + } + Ok(()) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + let k = xs.len(); + if k % QK_K != 0 { + bail!("Input length must be multiple of QK_K = {}", QK_K); + } + + quantize_iq4_xs(xs, ys, 1, k, None)?; + + Ok(()) + } + + fn from_float_imatrix( + xs: &[f32], + ys: &mut [Self], + imatrix_weights: &[f32], + n_per_row: usize, + ) -> Result<()> { + let k = xs.len(); + if k % QK_K != 0 { + bail!("Input length must be multiple of QK_K = {}", QK_K); + } + let nrow = xs.len() / n_per_row; + + quantize_iq4_xs(xs, ys, nrow, n_per_row, Some(imatrix_weights))?; + + Ok(()) + } + + #[allow(unreachable_code)] + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + // #[cfg(target_feature = "avx")] + // todo!(); + + #[cfg(target_feature = 
"neon")] + return super::neon::vec_dot_iq4_xs_q8k(n, xs, ys); + + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + if n % QK_K != 0 { + bail!("n must be a multiple of QK_K"); + } + let nb = n / QK_K; + + let mut sumf = 0.0f32; + + // Loop over each block + for ibl in 0..nb { + // x[ibl], y[ibl] + let x = &xs[ibl]; + let y = &ys[ibl]; + + // Convert x.d from fp16 to fp32, then multiply by y.d + let d4d8 = x.d.to_f32() * y.d; + + // We'll track the "h" scales + let mut h = x.scales_h; + + let mut qsi = 0; // index for x.qs + let mut q8i = 0; // index for y.qs + + // so we step by 2 in Rust as well + for ib2 in (0..(QK_K / 32)).step_by(2) { + // Reproduce the logic of ls1, ls2 + let ls1 = (x.scales_l[ib2 / 2] & 0x0f) | (((h << 4) & 0x30) as u8); + let ls2 = (x.scales_l[ib2 / 2] >> 4) | (((h << 2) & 0x30) as u8); + // Then we shift h by 4 in the original code + h >>= 4; + + // Convert ls1, ls2 to "scaled" floats + let d1 = d4d8 * ((ls1 as i32) - 32) as f32; + let d2 = d4d8 * ((ls2 as i32) - 32) as f32; + + // Two sets of 16 items each + // sum of the first 16-lane block + let mut sumi1 = 0; + let mut sumi2 = 0; + + // The first pass + for j in 0..16 { + // q8i + j vs q8i + j + 16 + // qs[qsi + j] & 0xf vs (qs[qsi + j] >> 4) + sumi1 += (y.qs[q8i + j] as i32) + * KVALUES_IQ4NL[(x.qs[qsi + j] & 0x0f) as usize] as i32; + sumi2 += (y.qs[q8i + j + 16] as i32) + * KVALUES_IQ4NL[((x.qs[qsi + j] >> 4) & 0x0f) as usize] as i32; + } + sumf += d1 * ((sumi1 + sumi2) as f32); + + qsi += 16; + q8i += 32; + + // The second pass + sumi1 = 0; + sumi2 = 0; + for j in 0..16 { + sumi1 += (y.qs[q8i + j] as i32) + * KVALUES_IQ4NL[(x.qs[qsi + j] & 0x0f) as usize] as i32; + sumi2 += (y.qs[q8i + j + 16] as i32) + * KVALUES_IQ4NL[((x.qs[qsi + j] >> 4) & 0x0f) as usize] as i32; + } + sumf += d2 * ((sumi1 + sumi2) as f32); + + qsi += 16; + q8i += 32; + } + } + + Ok(sumf) + } + + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + crate::bail!("Unsupported block type for i8mm"); + } +} + +#[derive(Debug, Clone, PartialEq)] +#[repr(C)] +pub struct BlockIQ4nl { + pub(crate) d: f16, + pub(crate) qs: [u8; QK4_NL / 2], +} + +const _: () = assert!( + std::mem::size_of::() == std::mem::size_of::() + QK4_NL / 2, + "wrong iq4_nl block size/padding" +); + +impl GgmlType for BlockIQ4nl { + const DTYPE: GgmlDType = GgmlDType::Iq4Nl; + const BLCK_SIZE: usize = QK4_NL; + type VecDotType = BlockQ8_0; + const SUPPORTS_I8MM: bool = false; + + fn to_float(xs: &[Self], mut ys: &mut [f32]) -> Result<()> { + let k = ys.len(); + if k % QK4_NL != 0 { + crate::bail!("dequantize block iq4nl {k} is not divisible by {QK4_NL}"); + } + + let nb = k / QK4_NL; + for block in xs.iter().take(nb) { + let d = block.d.to_f32(); + let qs = &block.qs[..]; + + for j in 0..(QK4_NL / 2) { + ys[j] = d * KVALUES_IQ4NL[(qs[j] & 0xf) as usize] as f32; + ys[j + QK4_NL / 2] = d * KVALUES_IQ4NL[(qs[j] >> 4) as usize] as f32; + } + ys = &mut ys[QK4_NL..]; + } + Ok(()) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + let k = xs.len(); + if k % QK4_NL != 0 { + bail!("Input length must be multiple of QK4_NL = {}", QK4_NL); + } + + quantize_iq4_nl(xs, ys, 1, k, None)?; + + Ok(()) + } + + fn from_float_imatrix( + xs: &[f32], + ys: &mut [Self], + imatrix_weights: &[f32], + n_per_row: usize, + ) -> Result<()> { + let k = xs.len(); + if k % 
QK4_NL != 0 { + bail!("Input length must be multiple of QK4_NL = {}", QK4_NL); + } + let nrow = xs.len() / n_per_row; + + quantize_iq4_nl(xs, ys, nrow, n_per_row, Some(imatrix_weights))?; + + Ok(()) + } + + #[allow(unreachable_code)] + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + // #[cfg(target_feature = "avx")] + // todo!(); + + #[cfg(target_feature = "neon")] + return super::neon::vec_dot_iq4_nl_q8k(n, xs, ys); + + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + if n % QK4_NL != 0 { + bail!("n must be a multiple of QK4_NL"); + } + let nb = n / QK4_NL; + + let mut sumf = 0.0f32; + + // Loop over each block + for ibl in 0..nb { + let x = &xs[ibl]; + let y = &ys[ibl]; + + let d = x.d.to_f32() * y.d.to_f32(); + + let mut sumi1 = 0; + let mut sumi2 = 0; + + for j in 0..QK4_NL / 2 { + sumi1 += y.qs[j] as i32 * KVALUES_IQ4NL[(x.qs[j] & 0xf) as usize] as i32; + sumi2 += + y.qs[j + QK4_NL / 2] as i32 * KVALUES_IQ4NL[(x.qs[j] >> 4) as usize] as i32; + } + + sumf += d * (sumi1 + sumi2) as f32; + } + + Ok(sumf) + } + + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + crate::bail!("Unsupported block type for i8mm"); + } +} + +#[derive(Debug, Clone, PartialEq)] +#[repr(C)] +pub struct BlockIQ3xxs { + pub(crate) d: f16, + pub(crate) qs: [u8; 3 * QK_K / 8], +} + +const _: () = assert!( + std::mem::size_of::() == std::mem::size_of::() + 3 * QK_K / 8, + "wrong iq3_xxs block size/padding" +); + +impl GgmlType for BlockIQ3xxs { + const DTYPE: GgmlDType = GgmlDType::Iq3Xxs; + const BLCK_SIZE: usize = QK_K; + type VecDotType = BlockQ8K; + const SUPPORTS_I8MM: bool = false; + + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + let k = ys.len(); + if k % QK_K != 0 { + crate::bail!("dequantize block iq3xxs {k} is not divisible by {QK_K}"); + } + let nb = k / QK_K; + + let mut aux32: u32; + + let mut y = ys.as_mut_ptr(); + + unsafe { + // Process each block. + for i in 0..nb { + // Access the i'th block. + let block = &xs[i]; + // Convert FP16 to f32. + let d = block.d.to_f32(); + // Get a pointer to the beginning of qs. + let mut qs = block.qs.as_ptr(); + // scales_and_signs is located QK_K/4 bytes after qs. + let scales_and_signs = qs.add((QK_K / 4) as usize); + + // Loop over each 32-byte subblock. + for ib32 in 0..(QK_K / 32) { + // Copy 4 bytes from scales_and_signs + 4*ib32 into aux32. + aux32 = ptr::read_unaligned( + scales_and_signs.add((4 * ib32) as usize) as *const u32 + ); + // Compute db = d * (0.5 + (aux32 >> 28)) * 0.5 + let db = d * (0.5 + ((aux32 >> 28) as f32)) * 0.5; + + // Process 4 groups per 32-byte subblock. + for l in 0..4 { + let shift = 7 * l; + let idx = ((aux32 >> shift) & 127) as usize; + // Get the corresponding 'signs' value. + let signs = K_SIGNS_IQ2XS[idx]; + + // Get pointers to grid1 and grid2. + // qs[2*l+0] and qs[2*l+1] are used as offsets into IQ3XXS_GRID. + let idx1 = *qs.add((2 * l + 0) as usize) as usize; + let grid1 = IQ3XXS_GRID.as_ptr().add(idx1) as *const u8; + let idx2 = *qs.add((2 * l + 1) as usize) as usize; + let grid2 = IQ3XXS_GRID.as_ptr().add(idx2) as *const u8; + + // For each of 4 values in grid1 and grid2. 
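+                    // Sign decoding: `signs` carries one sign bit per output value.
+                    // Bits 0..=3 (masks 1, 2, 4, 8) flip the four values taken from
+                    // `grid1`; bits 4..=7 (masks 16, 32, 64, 128) flip the four values
+                    // taken from `grid2`. The quantizer only stores 7 sign bits per
+                    // group of 8; K_SIGNS_IQ2XS restores the eighth as an even-parity bit.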
+ for j in 0..4 { + let mask1 = K_MASK_IQ2XS[(j + 0) as usize]; + let mask2 = K_MASK_IQ2XS[(j + 4) as usize]; + let sign1 = if (signs & mask1) != 0 { -1.0 } else { 1.0 }; + let sign2 = if (signs & mask2) != 0 { -1.0 } else { 1.0 }; + + let grid1_val = *grid1.add(j as usize) as f32; + let grid2_val = *grid2.add(j as usize) as f32; + + // Write the dequantized values. + ptr::write(y.add(j + 0), db * grid1_val * sign1); + ptr::write(y.add(j + 4), db * grid2_val * sign2); + } + // Advance y by 8 floats. + y = y.add(8); + } + // Advance qs by 8 bytes. + qs = qs.add(8); + } + } + } + Ok(()) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + let k = xs.len(); + if k % QK_K != 0 { + bail!("Input length must be multiple of QK_K = {}", QK_K); + } + + unsafe { quantize_iq3_xxs(xs, ys, 1, k, None)? }; + + Ok(()) + } + + fn from_float_imatrix( + xs: &[f32], + ys: &mut [Self], + imatrix_weights: &[f32], + n_per_row: usize, + ) -> Result<()> { + let k = xs.len(); + if k % QK_K != 0 { + bail!("Input length must be multiple of QK_K = {}", QK_K); + } + let nrow = xs.len() / n_per_row; + + unsafe { quantize_iq3_xxs(xs, ys, nrow, n_per_row, Some(imatrix_weights))? }; + + Ok(()) + } + + #[allow(unreachable_code)] + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + todo!() + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + todo!() + } + + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + crate::bail!("Unsupported block type for i8mm"); + } +} + +fn quantize_iq4_xs( + src: &[f32], + ys: &mut [BlockIQ4xs], + nrow: usize, + n_per_row: usize, + quant_weights: Option<&[f32]>, +) -> Result<()> { + if n_per_row % QK_K != 0 { + bail!("n_per_row must be multiple of QK_K = {}", QK_K); + } + + let nblock = n_per_row / QK_K; + // We expect exactly nrow * nblock blocks in `ys`. + if ys.len() != nrow * nblock { + bail!( + "Output buffer size mismatch: want {} blocks, got {}", + nrow * nblock, + ys.len() + ); + } + + let mut lbuf = vec![0u8; QK_K]; // L[QK_K] + let mut weight = vec![0f32; 32]; // weight[32] (the block_size is 32) + let mut scales = vec![0f32; QK_K / 32]; // scales[QK_K/32], e.g. 256/32=8 + + let mut src_offset = 0; + let mut dst_offset = 0; + + for _row in 0..nrow { + // Each row has `nblock` blocks: + for ibl in 0..nblock { + let block = &mut ys[dst_offset + ibl]; + + let qw = quant_weights.map(|qw_all| { + let start = QK_K * ibl; + &qw_all[start..start + QK_K] + }); + + quantize_row_iq4_nl_impl( + /* super_block_size = */ QK_K, + /* block_size = */ 32, + /* x = */ + &src[src_offset + QK_K * ibl..src_offset + QK_K * (ibl + 1)], + /* dh = */ &mut block.d, + /* q4 = */ &mut block.qs, + /* scales_h = */ Some(&mut block.scales_h), + /* scales_l = */ Some(&mut block.scales_l), + /* scales = */ &mut scales, + /* weight = */ &mut weight, + /* L = */ &mut lbuf, + /* values = */ &KVALUES_IQ4NL, + /* quant_weights = */ qw, + /* ntry = */ 7, + ); + } + src_offset += n_per_row; + dst_offset += nblock; + } + + Ok(()) +} + +fn quantize_iq4_nl( + src: &[f32], + ys: &mut [BlockIQ4nl], + nrow: usize, + n_per_row: usize, + quant_weights: Option<&[f32]>, +) -> Result<()> { + if n_per_row % QK4_NL != 0 { + bail!("n_per_row must be multiple of QK4_NL = {}", QK4_NL); + } + + let nblock = n_per_row / QK4_NL; + // We expect exactly nrow * nblock blocks in `ys`. 
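+    // (For example, a 4096-wide row yields 4096 / QK4_NL = 4096 / 32 = 128 blocks.)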
+ if ys.len() != nrow * nblock { + bail!( + "Output buffer size mismatch: want {} blocks, got {}", + nrow * nblock, + ys.len() + ); + } + + let mut lbuf = vec![0u8; QK4_NL]; // L[QK4_NL] + let mut weight = vec![0f32; QK4_NL]; // weight[QK4_NL] + let mut scales = vec![0f32]; // scales[1] + + let mut src_offset = 0; + let mut dst_offset = 0; + + for _row in 0..nrow { + // Each row has `nblock` blocks: + for ibl in 0..nblock { + let block = &mut ys[dst_offset + ibl]; + + let qw = quant_weights.map(|qw_all| { + let start = QK4_NL * ibl; + &qw_all[start..start + QK4_NL] + }); + + quantize_row_iq4_nl_impl( + /* super_block_size = */ QK4_NL, + /* block_size = */ 32, + /* x = */ + &src[src_offset + QK4_NL * ibl..src_offset + QK4_NL * (ibl + 1)], + /* dh = */ &mut block.d, + /* q4 = */ &mut block.qs, + /* scales_h = */ None, + /* scales_l = */ None, + /* scales = */ &mut scales, + /* weight = */ &mut weight, + /* L = */ &mut lbuf, + /* values = */ &KVALUES_IQ4NL, + /* quant_weights = */ qw, + /* ntry = */ 7, + ); + } + src_offset += n_per_row; + dst_offset += nblock; + } + + Ok(()) +} + +unsafe fn quantize_iq3_xxs( + src: &[f32], + ys: &mut [BlockIQ3xxs], + nrow: usize, + n_per_row: usize, + quant_weights: Option<&[f32]>, +) -> Result<()> { + // Assert that n_per_row is a multiple of QK_K. + assert!( + n_per_row % QK_K == 0, + "n_per_row must be a multiple of QK_K" + ); + let nblock = n_per_row / QK_K; + + let mut src_ptr = src.as_ptr(); + let mut dst_ptr = ys.as_mut_ptr(); + + for _row in 0..nrow { + quantize_row_iq3_xxs_impl(256, src_ptr, dst_ptr, n_per_row as i64, quant_weights); + + src_ptr = src_ptr.add(n_per_row); + dst_ptr = dst_ptr.add(nblock); + } + Ok(()) +} diff --git a/candle-core/src/quantized/iq_quants/utils.rs b/candle-core/src/quantized/iq_quants/utils.rs new file mode 100644 index 0000000000..4614de6169 --- /dev/null +++ b/candle-core/src/quantized/iq_quants/utils.rs @@ -0,0 +1,843 @@ +use std::{ffi::c_void, mem, ptr}; + +use half::f16; + +use crate::quantized::{k_quants::QK_K, GgmlType}; + +use super::BlockIQ3xxs; + +const GROUP_MAX_EPS: f32 = 1e-15; +const GROUP_MAX_EPS_IQ3_XXS: f32 = 1e-8; + +fn nearest_int(x: f32) -> i32 { + x.round() as i32 +} + +#[allow(clippy::too_many_arguments)] +pub(super) fn quantize_row_iq4_nl_impl( + super_block_size: usize, + block_size: usize, + x: &[f32], + dh: &mut f16, + q4: &mut [u8], + scales_h: Option<&mut u16>, + scales_l: Option<&mut [u8]>, + scales: &mut [f32], + weight: &mut [f32], + lbuf: &mut [u8], + values: &[i8], + quant_weights: Option<&[f32]>, + ntry: i32, +) { + // For safety, confirm the slices have correct lengths: + let sb_div_2 = super_block_size / 2; + let sb_div_32 = super_block_size / 32; + let sb_div_64 = super_block_size / 64; + assert_eq!(q4.len(), sb_div_2); + assert_eq!(scales.len(), sb_div_32); + assert_eq!(lbuf.len(), super_block_size); + assert_eq!(weight.len(), block_size); + + // 1. compute sigma2 + let mut sigma2 = 0f32; + for x in x.iter().take(super_block_size) { + sigma2 += x * x; + } + sigma2 *= 2.0 / (super_block_size as f32); + + // 2. 
zero out q4, set dh to 0 + for qi in q4.iter_mut() { + *qi = 0; + } + *dh = f16::from_f32(0.0); + + // Track the max absolute scale across sub-blocks + let mut max_scale = 0.0_f32; + let mut amax_scale = 0.0_f32; + + // For each 32-float block within the 256-float super-block: + let nblocks = super_block_size / block_size; + + for ib in 0..nblocks { + let xb = &x[ib * block_size..ib * block_size + block_size]; + let lb = &mut lbuf[ib * block_size..ib * block_size + block_size]; + + // If we have external `quant_weights`, fill `weight[j] = quant_weights[j]*sqrt(...)`, + // else `weight[j] = xb[j]*xb[j]` + if let Some(qw) = quant_weights { + let qw_block = &qw[ib * block_size..ib * block_size + block_size]; + for j in 0..block_size { + let val = xb[j]; + weight[j] = qw_block[j] * (sigma2 + val * val).sqrt(); + } + } else { + for j in 0..block_size { + let val = xb[j]; + weight[j] = val * val; + } + } + + // 3. find amax (largest absolute value in block) + let mut amax = 0.0_f32; + let mut max_v = 0.0_f32; + for &xx in xb { + let ax = xx.abs(); + if ax > amax { + amax = ax; + max_v = xx; + } + } + + // If amax is extremely small, scale = 0 + if amax < GROUP_MAX_EPS { + scales[ib] = 0.0; + continue; + } + + // 4. initial guess for d + let sign_factor = if ntry > 0 { -1.0 } else { 1.0 }; + let mut d = sign_factor * max_v / (values[0] as f32); + let id = 1.0 / d; + + // 5. compute an initial sumqx, sumq2 + let mut sumqx = 0.0_f32; + let mut sumq2 = 0.0_f32; + for j in 0..block_size { + let val = xb[j]; + let al = id * val; + let l = best_index_int8(values, al); + lb[j] = l as u8; + + let q = values[l] as f32; + let w = weight[j]; + sumqx += w * q * val; + sumq2 += w * q * q; + } + d = sumqx / sumq2; + let mut best = d * sumqx; + + // 6. do extra tries around that initial guess + for itry in -ntry..=ntry { + let test_id = (itry as f32 + values[0] as f32) / max_v; + let mut tmp_sumqx = 0.0_f32; + let mut tmp_sumq2 = 0.0_f32; + for j in 0..block_size { + let val = xb[j]; + let al = test_id * val; + let l = best_index_int8(values, al); + let q = values[l] as f32; + let w = weight[j]; + tmp_sumqx += w * q * val; + tmp_sumq2 += w * q * q; + } + if tmp_sumq2 > 0.0 { + let maybe_d = tmp_sumqx / tmp_sumq2; + let maybe_best = maybe_d * tmp_sumqx; + if maybe_best > best { + best = maybe_best; + d = maybe_d; + } + } + } + + // 7. record the chosen scale + scales[ib] = d; + let abs_d = d.abs(); + if abs_d > amax_scale { + amax_scale = abs_d; + max_scale = d; + } + } + + // 8. 
If we have more than one 32-float block in the super-block: + if nblocks > 1 { + let scales_h = scales_h.expect("Expected scales_h, nblocks > 1"); + let scales_l = scales_l.expect("Expected scales_l, nblocks > 1"); + assert_eq!(scales_l.len(), sb_div_64); + + // zero scales_h, because we store 2 bits per block in it + // for nblocks=8, we store them in a single 16-bit value + *scales_h = 0; + for sl in scales_l.iter_mut() { + *sl = 0; + } + + let d = -max_scale / 32.0; + *dh = f16::from_f32(d); + let id = if d != 0.0 { 1.0 / d } else { 0.0 }; + + for ib in 0..nblocks { + // l = nearest_int(id * scales[ib]), clamp to [-32..31] + let mut l = (id * scales[ib]).round() as i32; + l = l.clamp(-32, 31); + + // refine block + let dl = d * (l as f32); + let idl = if dl != 0.0 { 1.0 / dl } else { 0.0 }; + + let xb = &x[ib * block_size..ib * block_size + block_size]; + let lb = &mut lbuf[ib * block_size..ib * block_size + block_size]; + for j in 0..block_size { + let val = xb[j]; + lb[j] = best_index_int8(values, idl * val) as u8; + } + + // store l in 4 bits + 4 bits + let l_offset = (l + 32) as u8; // now in [0..64) + let l_low = l_offset & 0x0f; + let l_high = l_offset >> 4; + + // scales_l[ib/2] uses the nibble for this block + if ib % 2 == 0 { + scales_l[ib / 2] = l_low; + } else { + scales_l[ib / 2] |= l_low << 4; + } + // scales_h for each block (2 bits per block) => stored in a 16-bit + // scaled_h[ib/8] with (l_high << (2*(ib%8))) + let shift = 2 * (ib % 8); + *scales_h |= (l_high as u16) << shift; + } + } else { + // single 32-float block => just store d + *dh = f16::from_f32(scales[0]); + if ntry > 0 { + let id = if scales[0] != 0.0 { + 1.0 / scales[0] + } else { + 0.0 + }; + for j in 0..super_block_size { + lbuf[j] = best_index_int8(values, id * x[j]) as u8; + } + } + } + + // 9. Finally, pack all 4-bit values from L into q4 + // q4[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4) + for i in 0..(super_block_size / 32) { + for j in 0..16 { + let lo = lbuf[32 * i + j] & 0x0f; + let hi = (lbuf[32 * i + 16 + j] & 0x0f) << 4; + q4[16 * i + j] = lo | hi; + } + } +} + +/// Finds the best index i in [0..values.len()) such that +/// `values[i]` is closest to `x`. The array `values` is strictly +/// ascending/ +fn best_index_int8(values: &[i8], x: f32) -> usize { + // Quick boundary checks + if x <= values[0] as f32 { + return 0; + } + let n = values.len(); + let last = (n - 1).max(0); + if x >= values[last] as f32 { + return last; + } + + // Binary search + let mut ml = 0; + let mut mu = last; + while mu - ml > 1 { + let mav = (ml + mu) / 2; + if x < values[mav] as f32 { + mu = mav; + } else { + ml = mav; + } + } + + // Return whichever is closer among values[mu-1], values[mu] + // But watch out if mu == 0 or mu == n-1 ... + // (the boundary checks above should keep mu>0) + let dist_left = (x - values[ml] as f32).abs(); + let dist_right = (values[mu] as f32 - x).abs(); + if dist_left <= dist_right { + ml + } else { + mu + } +} + +/// Global state analogous to the C code’s iq3_data. +/// (Using unsafe mutable statics; in production code consider using a safe wrapper.) 
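+/// A safer alternative (a sketch only, not what the port below uses) would build the
+/// tables once behind a `std::sync::OnceLock`:
+///
+/// ```ignore
+/// static IQ3_TABLES: OnceLock<[Iq3Entry; 2]> = OnceLock::new();
+///
+/// fn iq3_tables() -> &'static [Iq3Entry; 2] {
+///     // `build_iq3_tables` is a hypothetical helper running the same
+///     // initialization as `iq3xs_init_impl` for both grid sizes.
+///     IQ3_TABLES.get_or_init(build_iq3_tables)
+/// }
+/// ```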
+#[derive(Debug)] +struct Iq3Entry { + grid: Option>, + map: Option>, + neighbours: Option>, +} + +static mut IQ3_DATA: [Iq3Entry; 2] = [ + Iq3Entry { + grid: None, + map: None, + neighbours: None, + }, + Iq3Entry { + grid: None, + map: None, + neighbours: None, + }, +]; + +static KGRID_256: [u16; 256] = [ + 0, 2, 4, 9, 11, 15, 16, 18, 25, 34, 59, 61, 65, 67, 72, 74, 81, 85, 88, 90, 97, 108, 120, 128, + 130, 132, 137, 144, 146, 153, 155, 159, 169, 175, 189, 193, 199, 200, 202, 213, 248, 267, 287, + 292, 303, 315, 317, 321, 327, 346, 362, 413, 436, 456, 460, 462, 483, 497, 513, 515, 520, 522, + 529, 531, 536, 538, 540, 551, 552, 576, 578, 585, 592, 594, 641, 643, 648, 650, 657, 664, 698, + 704, 706, 720, 729, 742, 758, 769, 773, 808, 848, 852, 870, 889, 901, 978, 992, 1024, 1026, + 1033, 1035, 1040, 1042, 1046, 1049, 1058, 1089, 1091, 1093, 1096, 1098, 1105, 1112, 1139, 1143, + 1144, 1152, 1154, 1161, 1167, 1168, 1170, 1183, 1184, 1197, 1217, 1224, 1228, 1272, 1276, 1309, + 1323, 1347, 1367, 1377, 1404, 1473, 1475, 1486, 1509, 1537, 1544, 1546, 1553, 1555, 1576, 1589, + 1594, 1600, 1602, 1616, 1625, 1636, 1638, 1665, 1667, 1672, 1685, 1706, 1722, 1737, 1755, 1816, + 1831, 1850, 1856, 1862, 1874, 1901, 1932, 1950, 1971, 2011, 2032, 2052, 2063, 2077, 2079, 2091, + 2095, 2172, 2192, 2207, 2208, 2224, 2230, 2247, 2277, 2308, 2345, 2356, 2389, 2403, 2424, 2501, + 2504, 2506, 2520, 2570, 2593, 2616, 2624, 2630, 2646, 2669, 2700, 2714, 2746, 2754, 2795, 2824, + 2835, 2839, 2874, 2882, 2905, 2984, 3028, 3042, 3092, 3108, 3110, 3124, 3153, 3185, 3215, 3252, + 3288, 3294, 3364, 3397, 3434, 3483, 3523, 3537, 3587, 3589, 3591, 3592, 3610, 3626, 3670, 3680, + 3722, 3749, 3754, 3776, 3789, 3803, 3824, 3857, 3873, 3904, 3906, 3924, 3992, +]; + +static KGRID_512: [u16; 512] = [ + 0, 1, 2, 5, 7, 8, 9, 10, 12, 14, 16, 17, 21, 27, 32, 34, 37, 39, 41, 43, 48, 50, 57, 60, 63, + 64, 65, 66, 68, 72, 73, 77, 80, 83, 87, 89, 93, 100, 113, 117, 122, 128, 129, 133, 135, 136, + 139, 142, 145, 149, 152, 156, 162, 165, 167, 169, 171, 184, 187, 195, 201, 205, 208, 210, 217, + 219, 222, 228, 232, 234, 247, 249, 253, 256, 267, 271, 273, 276, 282, 288, 291, 297, 312, 322, + 324, 336, 338, 342, 347, 353, 357, 359, 374, 379, 390, 393, 395, 409, 426, 441, 448, 450, 452, + 464, 466, 470, 475, 488, 492, 512, 513, 514, 516, 520, 521, 523, 525, 527, 528, 530, 537, 540, + 542, 556, 558, 561, 570, 576, 577, 579, 582, 584, 588, 593, 600, 603, 609, 616, 618, 632, 638, + 640, 650, 653, 655, 656, 660, 666, 672, 675, 685, 688, 698, 705, 708, 711, 712, 715, 721, 727, + 728, 732, 737, 754, 760, 771, 773, 778, 780, 793, 795, 802, 806, 808, 812, 833, 840, 843, 849, + 856, 858, 873, 912, 916, 919, 932, 934, 961, 963, 968, 970, 977, 989, 993, 1010, 1016, 1024, + 1025, 1027, 1029, 1031, 1032, 1034, 1036, 1038, 1041, 1043, 1047, 1048, 1050, 1057, 1059, 1061, + 1064, 1066, 1079, 1080, 1083, 1085, 1088, 1090, 1096, 1099, 1103, 1106, 1109, 1113, 1116, 1122, + 1129, 1153, 1156, 1159, 1169, 1171, 1176, 1183, 1185, 1195, 1199, 1209, 1212, 1216, 1218, 1221, + 1225, 1234, 1236, 1241, 1243, 1250, 1256, 1270, 1281, 1287, 1296, 1299, 1306, 1309, 1313, 1338, + 1341, 1348, 1353, 1362, 1375, 1376, 1387, 1400, 1408, 1410, 1415, 1425, 1453, 1457, 1477, 1481, + 1494, 1496, 1507, 1512, 1538, 1545, 1547, 1549, 1551, 1554, 1561, 1563, 1565, 1570, 1572, 1575, + 1577, 1587, 1593, 1601, 1603, 1605, 1612, 1617, 1619, 1632, 1648, 1658, 1662, 1664, 1674, 1680, + 1690, 1692, 1704, 1729, 1736, 1740, 1745, 1747, 1751, 1752, 1761, 1763, 1767, 1773, 1787, 1795, + 1801, 1806, 
1810, 1817, 1834, 1840, 1844, 1857, 1864, 1866, 1877, 1882, 1892, 1902, 1915, 1934, + 1953, 1985, 1987, 2000, 2002, 2013, 2048, 2052, 2058, 2064, 2068, 2071, 2074, 2081, 2088, 2104, + 2114, 2119, 2121, 2123, 2130, 2136, 2141, 2147, 2153, 2157, 2177, 2179, 2184, 2189, 2193, 2203, + 2208, 2223, 2226, 2232, 2244, 2249, 2251, 2256, 2258, 2265, 2269, 2304, 2306, 2324, 2335, 2336, + 2361, 2373, 2375, 2385, 2418, 2443, 2460, 2480, 2504, 2509, 2520, 2531, 2537, 2562, 2568, 2572, + 2578, 2592, 2596, 2599, 2602, 2614, 2620, 2625, 2627, 2629, 2634, 2641, 2650, 2682, 2688, 2697, + 2707, 2712, 2718, 2731, 2754, 2759, 2760, 2775, 2788, 2793, 2805, 2811, 2817, 2820, 2832, 2842, + 2854, 2890, 2902, 2921, 2923, 2978, 3010, 3012, 3026, 3081, 3083, 3085, 3097, 3099, 3120, 3136, + 3152, 3159, 3188, 3210, 3228, 3234, 3245, 3250, 3256, 3264, 3276, 3281, 3296, 3349, 3363, 3378, + 3392, 3395, 3420, 3440, 3461, 3488, 3529, 3531, 3584, 3588, 3591, 3600, 3602, 3614, 3616, 3628, + 3634, 3650, 3657, 3668, 3683, 3685, 3713, 3716, 3720, 3726, 3729, 3736, 3753, 3778, 3802, 3805, + 3819, 3841, 3845, 3851, 3856, 3880, 3922, 3938, 3970, 3993, 4032, +]; + +/// Returns the index into IQ3_DATA for a given grid size. +/// Panics if grid_size is not 256 or 512. +fn iq3_data_index(grid_size: i32) -> usize { + assert!( + grid_size == 256 || grid_size == 512, + "grid_size must be 256 or 512" + ); + if grid_size == 256 { + 0 + } else { + 1 + } +} + +/// Helper: given a grid value (stored as u32) reinterpreted as 4 bytes, +/// compute the “index” value using: for each byte b, compute ((b-1)/2) +/// and pack it as bits (3 bits per coordinate). +fn compute_index_from_grid_val(val: u32) -> usize { + let bytes = val.to_ne_bytes(); + let mut index = 0usize; + for k in 0..4 { + // (b - 1) / 2 + let q = (bytes[k].saturating_sub(1)) / 2; + index |= (q as usize) << (3 * k); + } + index +} + +/// Computes the squared Euclidean distance between two 4–byte positions. +fn dist2(a: &[u8; 4], b: &[u8; 4]) -> i32 { + let mut d2 = 0; + for k in 0..4 { + let diff = a[k] as i32 - b[k] as i32; + d2 += diff * diff; + } + d2 +} + +/// Main initialization function. This reproduces the C function iq3xs_init_impl. +fn iq3xs_init_impl(grid_size: i32) { + // Determine which slot to use. + let gindex = iq3_data_index(grid_size); + // Use unsafe to access the global mutable state. + unsafe { + if IQ3_DATA[gindex].grid.is_some() { + return; + } + } + + // Choose constants based on grid_size. + let (kgrid, nwant) = if grid_size == 256 { + (&KGRID_256[..], 2) + } else { + (&KGRID_512[..], 3) + }; + let kmap_size = 4096; + + // --- Allocate and initialize the grid --- + // For each element, we compute 4 bytes as: for each i in 0..4: + // byte = 2 * ((kgrid[k] >> (3*i)) & 0x7) + 1 + let mut grid: Vec = Vec::with_capacity(grid_size as usize); + for &kg in kgrid.iter().take(grid_size as usize) { + let mut bytes = [0u8; 4]; + for i in 0..4 { + let l = (kg >> (3 * i)) & 0x7; + bytes[i] = (2 * l + 1) as u8; + } + grid.push(u32::from_ne_bytes(bytes)); + } + + // --- Allocate and initialize the map --- + // kmap: size = 4096, all initialized to -1. + let mut kmap: Vec = vec![-1; kmap_size]; + + // For each grid element, compute its index and store the grid index. 
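+    // Each map index packs the four coordinates into 12 bits (3 bits per coordinate),
+    // hence the 2^12 = 4096 entries. Worked example: grid bytes [3, 1, 7, 5] give
+    // quantized coordinates ((b - 1) / 2) = [1, 0, 3, 2], so
+    // index = 1 | (0 << 3) | (3 << 6) | (2 << 9) = 1217.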
+ for (j, &val) in grid.iter().enumerate() { + let index = compute_index_from_grid_val(val); + kmap[index] = j as i32; + } + + // --- First pass: determine total space needed for neighbours --- + let mut total_neighbors = 0; + let mut num_not_in_map = 0; + for i in 0..kmap_size { + if kmap[i] >= 0 { + continue; + } + num_not_in_map += 1; + + // Reconstruct the “position” from the map index. + let mut pos = [0u8; 4]; + for k in 0..4 { + let l = (i >> (3 * k)) & 0x7; + pos[k] = (2 * l + 1) as u8; + } + // Build a vector of (distance, grid index) pairs. + let mut dist_vec: Vec<(i32, usize)> = Vec::with_capacity(grid_size as usize); + for (j, &grid_val) in grid.iter().enumerate() { + let grid_bytes = grid_val.to_ne_bytes(); + let d = dist2(&grid_bytes, &pos); + dist_vec.push((d, j)); + } + // Sort the vector: first by distance, then by grid index. + dist_vec.sort_by(|a, b| a.cmp(&b)); + // Count how many neighbors to include. + let mut n = 0; + let mut nhave = 1; + let mut d_current = dist_vec[0].0; + for &(d, _) in dist_vec.iter() { + if d > d_current { + if nhave == nwant { + break; + } + d_current = d; + nhave += 1; + } + n += 1; + } + total_neighbors += n; + } + + // Allocate neighbours vector with the total number of u16 values. + // Note: the C code allocates (num_neighbors + num_not_in_map) elements. + let total_nbrs_size = total_neighbors + num_not_in_map; + let mut neighbours: Vec = Vec::with_capacity(total_nbrs_size); + + // --- Second pass: fill in the neighbours data and update kmap --- + let mut nbr_counter = 0; // global counter in the neighbours vector + for i in 0..kmap_size { + if kmap[i] >= 0 { + continue; + } + // Reconstruct the “position” from the map index. + let mut pos = [0u8; 4]; + for k in 0..4 { + let l = (i >> (3 * k)) & 0x7; + pos[k] = (2 * l + 1) as u8; + } + // Build and sort the distances for all grid elements. + let mut dist_vec: Vec<(i32, usize)> = Vec::with_capacity(grid_size as usize); + for (j, &grid_val) in grid.iter().enumerate() { + let grid_bytes = grid_val.to_ne_bytes(); + let d = dist2(&grid_bytes, &pos); + dist_vec.push((d, j)); + } + dist_vec.sort_by(|a, b| a.cmp(&b)); + + // Store negative index in kmap to indicate start offset in the neighbours vector. + kmap[i] = -((nbr_counter as i32) + 1); + + // Reserve a slot for the count of neighbours. + neighbours.push(0); // placeholder; will update later + nbr_counter += 1; + + // Now, add the neighbour indices. + let mut n = 0; + let mut nhave = 1; + let mut d_current = dist_vec[0].0; + for &(d, j) in dist_vec.iter() { + if d > d_current { + if nhave == nwant { + break; + } + d_current = d; + nhave += 1; + } + // Store the grid index as u16. + neighbours.push(j as u16); + nbr_counter += 1; + n += 1; + } + // Update the placeholder with the count of neighbours for this cell. + neighbours[nbr_counter - n - 1] = n as u16; + } + + // Finally, update the global IQ3_DATA entry. + unsafe { + IQ3_DATA[gindex].grid = Some(grid); + IQ3_DATA[gindex].map = Some(kmap); + IQ3_DATA[gindex].neighbours = Some(neighbours); + } +} + +pub unsafe fn iq3_find_best_neighbour( + neighbours: *const u16, + grid: *const u32, + xval: *const f32, + weight: *const f32, + scale: f32, + L: *mut i8, +) -> i32 { + // neighbours[0] holds the number of neighbours. 
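+    // Layout: for every off-grid 12-bit index, `kmap` stores `-(offset + 1)`; at that
+    // offset the neighbours buffer holds `[count, idx_0, .., idx_{count - 1}]`, so the
+    // pointer passed here starts at the neighbour count and is followed by grid indices.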
+ let num_neighbors = *neighbours as i32; + assert!(num_neighbors > 0); + let mut best_d2 = f32::MAX; + let mut grid_index: i32 = -1; + // j from 1 to num_neighbors (inclusive) + for j in 1..=num_neighbors { + // neighbours[j] + let neigh = *neighbours.add(j as usize); + // Compute pointer pg = (const int8_t*)(grid + neighbours[j]) + let pg = (grid.add(neigh as usize)) as *const i8; + let mut d2 = 0f32; + for i in 0..4 { + // Note: pg[i] is read as i8 then converted to f32. + let q = *pg.add(i) as f32; + let diff = scale * q - *xval.add(i); + d2 += *weight.add(i) * diff * diff; + } + if d2 < best_d2 { + best_d2 = d2; + grid_index = neigh as i32; + } + } + assert!(grid_index >= 0); + let pg = (grid.add(grid_index as usize)) as *const i8; + for i in 0..4 { + // Here we assume that (pg[i]-1)/2 uses integer arithmetic. + *L.add(i) = ((*pg.add(i)) - 1) / 2; + } + grid_index +} + +pub unsafe fn quantize_row_iq3_xxs_impl( + grid_size: i32, + x: *const f32, + y: *mut BlockIQ3xxs, + n: i64, + quant_weights: Option<&[f32]>, +) { + iq3xs_init_impl(grid_size); + + // Assume iq3_data_index is defined elsewhere. + let gindex = iq3_data_index(grid_size); + + // Assume iq3_data is a global array with fields: grid, map, neighbours. + let kgrid_q3xs: *const u32 = IQ3_DATA[gindex].grid.as_ref().unwrap().as_ptr(); + let kmap_q3xs: *const i32 = IQ3_DATA[gindex].map.as_ref().unwrap().as_ptr(); + let kneighbors_q3xs: *const u16 = IQ3_DATA[gindex].neighbours.as_ref().unwrap().as_ptr(); + + assert!(n % (QK_K as i64) == 0); + + let k_max_q: i32 = 8; + let nbl = n / (QK_K as i64); + + // Variables to hold output pointers. + let mut dh = &raw mut (*y).d; + let mut qs = (*y).qs.as_mut_ptr(); + let block_size = mem::size_of::(); + let quant_size = block_size - mem::size_of::(); + + // Allocate temporary arrays on the stack. + let mut scales = [0f32; QK_K as usize / 32]; + let mut weight_arr = [0f32; 32]; + let mut xval_arr = [0f32; 32]; + let mut L_arr = [0i8; 32]; + let mut Laux_arr = [0i8; 32]; + let mut waux_arr = [0f32; 32]; + let mut is_on_grid = [false; 8]; + let mut is_on_grid_aux = [false; 8]; + let mut block_signs = [0u8; 8]; + let mut q3 = [0u8; 3 * (QK_K as usize / 8) + (QK_K as usize / 32)]; + + // Calculate pointers into q3 + let scales_and_signs = q3[QK_K as usize / 4..].as_mut_ptr() as *mut u32; + let qh = q3[3 * (QK_K as usize / 8)..].as_mut_ptr(); + + // For each block of QK_K values: + for ibl in 0..nbl as usize { + // Set the first fp16 value to zero. 
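+        // `dh` and `qs` point into the current output block; at the end of the loop
+        // body they are advanced by `size_of::<BlockIQ3xxs>()` bytes, which is sound
+        // because the block is `repr(C)` with no padding (see the size assertion next
+        // to the struct definition).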
+ *dh = f16::from_f32(0.0); + ptr::write_bytes( + q3.as_mut_ptr(), + 0, + 3 * (QK_K as usize / 8) + (QK_K as usize / 32), + ); + + let mut max_scale = 0f32; + let xbl = x.add(QK_K as usize * ibl); + let mut sumx2 = 0f32; + for i in 0..(QK_K as usize) { + let xi = *xbl.add(i); + sumx2 += xi * xi; + } + let sigma2 = 2.0 * sumx2 / QK_K as f32; + + for ib in 0..(QK_K as usize / 32) { + let xb = xbl.add(32 * ib); + if let Some(quant_weights) = quant_weights { + let qw = &quant_weights[QK_K as usize * ibl + 32 * ib..]; + for i in 0..32 { + weight_arr[i] = qw[i] * ((sigma2 + (*xb.add(i)) * (*xb.add(i))).sqrt()); + } + } else { + for i in 0..32 { + weight_arr[i] = *xb.add(i) * (*xb.add(i)); + } + } + for i in 0..32 { + waux_arr[i] = weight_arr[i].sqrt(); + } + for k in 0..4 { + let mut nflip = 0; + let mut s: u8 = 0; + for i in 0..8 { + let val = *xb.add(8 * k + i); + if val >= 0.0 { + xval_arr[8 * k + i] = val; + } else { + xval_arr[8 * k + i] = -val; + nflip += 1; + s |= 1 << i; + } + } + if nflip % 2 != 0 { + let mut imin = 0; + let mut min_val = weight_arr[8 * k + imin] + * (*xb.add(8 * k + imin)) + * (*xb.add(8 * k + imin)); + for i in 1..8 { + let ax = + weight_arr[8 * k + i] * (*xb.add(8 * k + i)) * (*xb.add(8 * k + i)); + if ax < min_val { + min_val = ax; + imin = i; + } + } + xval_arr[8 * k + imin] = -xval_arr[8 * k + imin]; + s ^= 1 << imin; + } + block_signs[k] = s & 127; + } + let mut max_val = xval_arr[0]; + for i in 1..32 { + max_val = if max_val > xval_arr[i] { + max_val + } else { + xval_arr[i] + }; + } + if max_val < GROUP_MAX_EPS_IQ3_XXS { + scales[ib] = 0.0; + for i in 0..32 { + L_arr[i] = 0; + } + continue; + } + let mut best = 0f32; + let mut scale = max_val / ((2 * k_max_q - 1) as f32); + for is in -15..=15 { + let id = ((2 * k_max_q - 1) as f32 + (is as f32) * 0.2) / max_val; + let this_scale = 1.0 / id; + for k in 0..8 { + for i in 0..4 { + // nearest_int and clamp must be defined elsewhere. 
+ let l = nearest_int(0.5 * (id * xval_arr[4 * k + i] - 1.0)); + Laux_arr[4 * k + i] = l.clamp(0, k_max_q - 1) as i8; + } + let mut u: u16 = 0; + for i in 0..4 { + u |= (Laux_arr[4 * k + i] as u16) << (3 * i); + } + let mut grid_index = *kmap_q3xs.add(u as usize); + is_on_grid_aux[k] = true; + if grid_index < 0 { + is_on_grid_aux[k] = false; + let neighbours = + kneighbors_q3xs.offset(-(*kmap_q3xs.add(u as usize) + 1) as isize); + grid_index = iq3_find_best_neighbour( + neighbours, + kgrid_q3xs, + xval_arr.as_ptr().add(4 * k), + waux_arr.as_ptr().add(4 * k), + this_scale, + Laux_arr.as_mut_ptr().add(4 * k), + ); + } + } + let mut sumqx = 0f32; + let mut sumq2 = 0f32; + for i in 0..32 { + let w = weight_arr[i]; + let q = 2.0 * (Laux_arr[i] as f32) + 1.0; + sumqx += w * xval_arr[i] * q; + sumq2 += w * q * q; + } + if sumq2 > 0.0 && sumqx * sumqx > best * sumq2 { + scale = sumqx / sumq2; + best = scale * sumqx; + for i in 0..32 { + L_arr[i] = Laux_arr[i]; + } + for k in 0..8 { + is_on_grid[k] = is_on_grid_aux[k]; + } + } + } + let mut n_not_ongrid = 0; + for k in 0..8 { + if !is_on_grid[k] { + n_not_ongrid += 1; + } + } + if n_not_ongrid > 0 && scale > 0.0 { + let id = 1.0 / scale; + for k in 0..8 { + if is_on_grid[k] { + continue; + } + let mut u: u16 = 0; + for i in 0..4 { + let mut l = nearest_int(0.5 * (id * xval_arr[4 * k + i] - 1.0)); + l = l.clamp(0, k_max_q - 1); + u |= (l as u16) << (3 * i); + } + let mut grid_index = *kmap_q3xs.add(u as usize); + if grid_index < 0 { + let neighbours = + kneighbors_q3xs.offset(-(*kmap_q3xs.add(u as usize) + 1) as isize); + grid_index = iq3_find_best_neighbour( + neighbours, + kgrid_q3xs, + xval_arr.as_ptr().add(4 * k), + waux_arr.as_ptr().add(4 * k), + scale, + L_arr.as_mut_ptr().add(4 * k), + ); + } + let pg = (kgrid_q3xs.add(grid_index as usize)) as *const i8; + for i in 0..4 { + L_arr[4 * k + i] = ((*pg.add(i)) - 1) / 2; + } + } + let mut sumqx = 0f32; + let mut sumq2 = 0f32; + for i in 0..32 { + let w = weight_arr[i]; + let q = 2.0 * (L_arr[i] as f32) + 1.0; + sumqx += w * xval_arr[i] * q; + sumq2 += w * q * q; + } + if sumq2 > 0.0 { + scale = sumqx / sumq2; + } + } + if scale < 0.0 { + // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale) + // and correspondingly flip quant signs. 
+ scale = -scale; + for k in 0..4 { + block_signs[k] = (!block_signs[k]) & 127; + } + } + for k in 0..8 { + let mut u: u16 = 0; + for i in 0..4 { + u |= (L_arr[4 * k + i] as u16) << (3 * i); + } + let grid_index = *kmap_q3xs.add(u as usize); + if grid_index < 0 { + println!("Oops: found point {} not on grid:", u); + for i in 0..4 { + print!(" {}", L_arr[4 * k + i]); + } + println!(); + panic!("fatal error"); + } + if grid_size == 256 { + q3[8 * ibl + k] = grid_index as u8; + } else { + q3[8 * ibl + k] = (grid_index & 255) as u8; + *qh |= ((grid_index >> 8) as u8) << k; + } + } + // Pack block_signs into scales_and_signs + *scales_and_signs.add(ibl) = block_signs[0] as u32 + | ((block_signs[1] as u32) << 7) + | ((block_signs[2] as u32) << 14) + | ((block_signs[3] as u32) << 21); + assert!(scale >= 0.0); + scales[ibl] = scale; + if scale > max_scale { + max_scale = scale; + } + } + + if max_scale == 0.0 { + ptr::write_bytes(qs, 0, quant_size as usize); + dh = dh.add(block_size as usize / mem::size_of::()); + qs = qs.add(block_size as usize); + continue; + } + let d = max_scale / 31.0; + *dh = f16::from_f32(d * 1.0125); + let id = 1.0 / d; + for ib in 0..(QK_K as usize / 32) { + let l = nearest_int(0.5 * (id * scales[ib] - 1.0)); + let l = l.clamp(0, 15); + let prev = *scales_and_signs.add(ib); + *scales_and_signs.add(ib) = prev | ((l as u32) << 28); + } + ptr::copy_nonoverlapping(q3.as_ptr(), qs, quant_size as usize); + dh = dh.add(block_size as usize / mem::size_of::()); + qs = qs.add(block_size as usize); + } +} diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants/mod.rs similarity index 90% rename from candle-core/src/quantized/k_quants.rs rename to candle-core/src/quantized/k_quants/mod.rs index 27cc984a20..1c07847a73 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants/mod.rs @@ -1,13 +1,13 @@ -use super::utils::{ +mod utils; +use super::k_quants::utils::{ get_scale_min_k4, group_for_dequantization, group_for_quantization, make_q3_quants, - make_qkx1_quants, make_qx_quants, nearest_int, + make_qkx1_quants, make_qkx3_quants, make_qp_quants, make_qx_quants, nearest_int, }; -use super::GgmlDType; -use crate::quantized::utils::{make_qkx3_quants, make_qp_quants}; +use super::{GgmlDType, GgmlType}; use crate::Result; use byteorder::{ByteOrder, LittleEndian}; -use half::{bf16, f16}; -use rayon::prelude::*; +use float8::F8E4M3; +use half::f16; // Default to QK_K 256 rather than 64. pub const QK_K: usize = 256; @@ -20,37 +20,6 @@ pub const QK5_1: usize = 32; pub const QK8_0: usize = 32; pub const QK8_1: usize = 32; -pub trait GgmlType: Sized + Clone + Send + Sync { - const DTYPE: GgmlDType; - const BLCK_SIZE: usize; - type VecDotType: GgmlType; - - // This is only safe for types that include immediate values such as float/int/... - fn zeros() -> Self { - unsafe { std::mem::MaybeUninit::zeroed().assume_init() } - } - fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()>; - fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()>; - fn from_float_imatrix( - _xs: &[f32], - _ys: &mut [Self], - _imatrix_weights: &[f32], - _n_per_row: usize, - ) -> Result<()> { - crate::bail!( - "`from_float_imatrix` is unimplemented for {:?}", - Self::DTYPE - ); - } - - /// Dot product used as a building block for quantized mat-mul. - /// n is the number of elements to be considered. - fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result; - - /// Generic implementation of the dot product without simd optimizations. 
- fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result; -} - #[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ4_0 { @@ -95,6 +64,20 @@ pub struct BlockQ8_0 { } const _: () = assert!(std::mem::size_of::() == 34); +#[derive(Debug, Clone, PartialEq)] +#[repr(C)] +pub struct BlockF8Q8 { + d: F8E4M3, + pub(crate) qs: [i8; QK8_0], +} +const _: () = assert!(std::mem::size_of::() == 33); + +impl BlockF8Q8 { + pub fn dq_d(&self) -> f32 { + self.d.to_f32() / F8E4M3::MAX.to_f32() + } +} + #[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ8_1 { @@ -170,6 +153,7 @@ impl GgmlType for BlockQ4_0 { const DTYPE: GgmlDType = GgmlDType::Q4_0; const BLCK_SIZE: usize = QK4_0; type VecDotType = BlockQ8_0; + const SUPPORTS_I8MM: bool = true; // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1525 fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { @@ -264,12 +248,28 @@ impl GgmlType for BlockQ4_0 { } Ok(sumf) } + + #[allow(unreachable_code)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + #[cfg(target_feature = "neon")] + return super::neon::i8mm_q4_0_q8_0(n, xs_0, xs_1, ys_0, ys_1); + + crate::bail!("Unsupported block type for i8mm"); + } } impl GgmlType for BlockQ4_1 { const DTYPE: GgmlDType = GgmlDType::Q4_1; const BLCK_SIZE: usize = QK4_1; type VecDotType = BlockQ8_1; + const SUPPORTS_I8MM: bool = false; fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { Self::vec_dot_unopt(n, xs, ys) @@ -359,12 +359,26 @@ impl GgmlType for BlockQ4_1 { } Ok(()) } + + #[allow(unreachable_code)] + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + crate::bail!("Unsupported block type for i8mm"); + } } impl GgmlType for BlockQ5_0 { const DTYPE: GgmlDType = GgmlDType::Q5_0; const BLCK_SIZE: usize = QK5_0; type VecDotType = BlockQ8_0; + const SUPPORTS_I8MM: bool = false; fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { let qk = Self::BLCK_SIZE; @@ -461,12 +475,25 @@ impl GgmlType for BlockQ5_0 { } Ok(()) } + + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + crate::bail!("Unsupported block type for i8mm"); + } } impl GgmlType for BlockQ5_1 { const DTYPE: GgmlDType = GgmlDType::Q5_1; const BLCK_SIZE: usize = QK5_1; type VecDotType = BlockQ8_1; + const SUPPORTS_I8MM: bool = false; fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { Self::vec_dot_unopt(n, xs, ys) @@ -569,12 +596,25 @@ impl GgmlType for BlockQ5_1 { } Ok(()) } + + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + crate::bail!("Unsupported block type for i8mm"); + } } impl GgmlType for BlockQ8_0 { const DTYPE: GgmlDType = GgmlDType::Q8_0; const BLCK_SIZE: usize = QK8_0; type VecDotType = BlockQ8_0; + const SUPPORTS_I8MM: bool = true; // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1619 fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { @@ -659,12 +699,136 @@ impl 
GgmlType for BlockQ8_0 { } Ok(sumf) } + + #[allow(unreachable_code)] + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + #[cfg(target_feature = "neon")] + return super::neon::i8mm_q8_0_q8_0(n, xs_0, xs_1, ys_0, ys_1); + + crate::bail!("Unsupported block type for i8mm"); + } +} + +impl GgmlType for BlockF8Q8 { + const DTYPE: GgmlDType = GgmlDType::F8Q8; + const BLCK_SIZE: usize = QK8_0; + type VecDotType = BlockQ8_0; + const SUPPORTS_I8MM: bool = true; + + // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1619 + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + let k = ys.len(); + if k % QK8_0 != 0 { + crate::bail!("dequantize_row_f8q8: {k} is not divisible by {QK8_0}"); + } + + let nb = k / QK8_0; + + for i in 0..nb { + let d = xs[i].dq_d(); + + for j in 0..QK8_0 { + ys[i * QK8_0 + j] = xs[i].qs[j] as f32 * d; + } + } + Ok(()) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + // quantize_row_q8_0 + let k = xs.len(); + if k % Self::BLCK_SIZE != 0 { + crate::bail!("{k} is not divisible by {}", Self::BLCK_SIZE); + }; + let nb = k / Self::BLCK_SIZE; + if ys.len() != nb { + crate::bail!( + "size mismatch {} {} {}", + xs.len(), + ys.len(), + Self::BLCK_SIZE + ) + } + for (i, ys) in ys.iter_mut().enumerate() { + let mut amax = 0f32; + let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE]; + for &x in xs.iter() { + amax = amax.max(x.abs()) + } + let d = amax / ((1 << 7) - 1) as f32; + let id = if d != 0f32 { 1. / d } else { 0. }; + ys.d = F8E4M3::from_f32(d * F8E4M3::MAX.to_f32()); + for (y, &x) in ys.qs.iter_mut().zip(xs.iter()) { + *y = f32::round(x * id) as i8 + } + } + Ok(()) + } + + #[allow(unreachable_code)] + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + #[cfg(target_feature = "avx")] + return super::avx::vec_dot_f8q8_q8_0(n, xs, ys); + + #[cfg(target_feature = "neon")] + return super::neon::vec_dot_f8q8_q8_0(n, xs, ys); + + #[cfg(target_feature = "simd128")] + return super::simd128::vec_dot_f8q8_q8_0(n, xs, ys); + + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + let qk = QK8_0; + if n % QK8_0 != 0 { + crate::bail!("vec_dot_f8q8_q8_0: {n} is not divisible by {qk}") + } + + // Generic implementation. 
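+        // Per block this computes sum_i(x.qs[i] * y.qs[i]) * x.dq_d() * f16::to_f32(y.d).
+        // The F8Q8 scale is stored pre-multiplied by `F8E4M3::MAX` in `from_float` so
+        // the 8-bit float uses its full range, and `dq_d()` divides it back out.
+        // Rough illustration, assuming `F8E4M3::MAX == 448.0` (the usual E4M3 maximum):
+        //   amax = 0.5          =>  d = 0.5 / 127 ≈ 0.003937
+        //   stored              =>  F8E4M3::from_f32(0.003937 * 448.0) ≈ 1.76
+        //   dq_d() at dot time  =>  ≈ 1.76 / 448.0 ≈ 0.0039 (up to fp8 rounding)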
+ let mut sumf = 0f32; + for (xs, ys) in xs.iter().zip(ys.iter()) { + let sum_i = xs + .qs + .iter() + .zip(ys.qs.iter()) + .map(|(&x, &y)| x as i32 * y as i32) + .sum::(); + sumf += sum_i as f32 * xs.dq_d() * f16::to_f32(ys.d) + } + Ok(sumf) + } + + #[allow(unreachable_code)] + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + #[cfg(target_feature = "neon")] + return super::neon::i8mm_f8q8_q8_0(n, xs_0, xs_1, ys_0, ys_1); + + crate::bail!("Unsupported block type for i8mm"); + } } impl GgmlType for BlockQ8_1 { const DTYPE: GgmlDType = GgmlDType::Q8_1; const BLCK_SIZE: usize = QK8_1; type VecDotType = BlockQ8_1; + const SUPPORTS_I8MM: bool = false; fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { Self::vec_dot_unopt(n, xs, ys) @@ -705,12 +869,23 @@ impl GgmlType for BlockQ8_1 { fn to_float(_xs: &[Self], _ys: &mut [f32]) -> Result<()> { unimplemented!("no support for vec-dot on Q8_1") } + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + _n: usize, + _xs_0: &[Self], + _xs_1: &[Self], + _ys_0: &[Self::VecDotType], + _ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + unimplemented!("no support for i8mm matmul on Q8_1") + } } impl GgmlType for BlockQ2K { const DTYPE: GgmlDType = GgmlDType::Q2K; const BLCK_SIZE: usize = QK_K; type VecDotType = BlockQ8K; + const SUPPORTS_I8MM: bool = true; #[allow(unreachable_code)] fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { @@ -947,12 +1122,28 @@ impl GgmlType for BlockQ2K { } Ok(()) } + + #[allow(unreachable_code)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + #[cfg(target_feature = "neon")] + return super::neon::i8mm_q2k_q8k(n, xs_0, xs_1, ys_0, ys_1); + + crate::bail!("Unsupported block type for i8mm"); + } } impl GgmlType for BlockQ3K { const DTYPE: GgmlDType = GgmlDType::Q3K; const BLCK_SIZE: usize = QK_K; type VecDotType = BlockQ8K; + const SUPPORTS_I8MM: bool = false; #[allow(unreachable_code)] fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { @@ -1329,12 +1520,29 @@ impl GgmlType for BlockQ3K { Ok(()) } + + #[allow(unreachable_code)] + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + #[cfg(target_feature = "neon")] + return super::neon::i8mm_q3k_q8k(n, xs_0, xs_1, ys_0, ys_1); + + crate::bail!("Unsupported block type for i8mm"); + } } impl GgmlType for BlockQ4K { const DTYPE: GgmlDType = GgmlDType::Q4K; const BLCK_SIZE: usize = QK_K; type VecDotType = BlockQ8K; + const SUPPORTS_I8MM: bool = true; #[allow(unreachable_code)] fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { @@ -1594,6 +1802,22 @@ impl GgmlType for BlockQ4K { } Ok(()) } + + #[allow(unreachable_code)] + #[allow(dead_code)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + #[cfg(target_feature = "neon")] + return super::neon::i8mm_q4k_q8k(n, xs_0, xs_1, ys_0, ys_1); + + crate::bail!("Unsupported block type for i8mm"); + } } // 
https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L928 @@ -1601,6 +1825,7 @@ impl GgmlType for BlockQ5K { const DTYPE: GgmlDType = GgmlDType::Q5K; const BLCK_SIZE: usize = QK_K; type VecDotType = BlockQ8K; + const SUPPORTS_I8MM: bool = true; #[allow(unreachable_code)] fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { @@ -1905,12 +2130,29 @@ impl GgmlType for BlockQ5K { } Ok(()) } + + #[allow(unreachable_code)] + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + #[cfg(target_feature = "neon")] + return super::neon::i8mm_q5k_q8k(n, xs_0, xs_1, ys_0, ys_1); + + crate::bail!("Unsupported block type for i8mm"); + } } impl GgmlType for BlockQ6K { const DTYPE: GgmlDType = GgmlDType::Q6K; const BLCK_SIZE: usize = QK_K; type VecDotType = BlockQ8K; + const SUPPORTS_I8MM: bool = true; #[allow(unreachable_code)] fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { @@ -2175,12 +2417,29 @@ impl GgmlType for BlockQ6K { } Ok(()) } + + #[allow(unreachable_code)] + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + #[cfg(target_feature = "neon")] + return super::neon::i8mm_q6k_q8k(n, xs_0, xs_1, ys_0, ys_1); + + crate::bail!("Unsupported block type for i8mm"); + } } impl GgmlType for BlockQ8K { const DTYPE: GgmlDType = GgmlDType::Q8K; const BLCK_SIZE: usize = QK_K; type VecDotType = BlockQ8K; + const SUPPORTS_I8MM: bool = false; #[allow(unreachable_code)] fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { @@ -2267,174 +2526,14 @@ impl GgmlType for BlockQ8K { } Ok(()) } -} - -// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L10605 -pub fn matmul( - mkn: (usize, usize, usize), - lhs: &[f32], - rhs_t: &[T], - dst: &mut [f32], -) -> Result<()> { - let (m, k, n) = mkn; - if m * k != lhs.len() { - crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len()); - } - - let k_in_lhs_blocks = k.div_ceil(T::BLCK_SIZE); - let k_in_rhs_blocks = k.div_ceil(T::VecDotType::BLCK_SIZE); - // TODO: Do not make this copy if the DotType is f32. - // TODO: Pre-allocate this. - let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks]; - for row_idx in 0..m { - let lhs_b = &mut lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; - let lhs = &lhs[row_idx * k..(row_idx + 1) * k]; - T::VecDotType::from_float(lhs, lhs_b)? 
- } - let lhs_b = lhs_b.as_slice(); - - for row_idx in 0..m { - let lhs_row = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; - let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n]; - - let result: Result> = dst_row - .into_par_iter() - .enumerate() - .with_min_len(128) - .with_max_len(512) - .map(|(col_idx, dst)| { - let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks]; - T::vec_dot(k, rhs_col, lhs_row).map(|value| *dst = value) - }) - .collect(); - - result?; - } - Ok(()) -} - -impl GgmlType for f32 { - const DTYPE: GgmlDType = GgmlDType::F32; - const BLCK_SIZE: usize = 1; - type VecDotType = f32; - - fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { - Self::vec_dot_unopt(n, xs, ys) - } - - fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { - if xs.len() < n { - crate::bail!("size mismatch {} < {n}", xs.len()) - } - if ys.len() < n { - crate::bail!("size mismatch {} < {n}", ys.len()) - } - let mut res = 0f32; - unsafe { crate::cpu::vec_dot_f32(xs.as_ptr(), ys.as_ptr(), &mut res, n) }; - Ok(res) - } - - fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { - if xs.len() != ys.len() { - crate::bail!("size mismatch {} {}", xs.len(), ys.len()); - } - ys.copy_from_slice(xs); - Ok(()) - } - - fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { - if xs.len() != ys.len() { - crate::bail!("size mismatch {} {}", xs.len(), ys.len()); - } - ys.copy_from_slice(xs); - Ok(()) - } -} - -impl GgmlType for f16 { - const DTYPE: GgmlDType = GgmlDType::F16; - const BLCK_SIZE: usize = 1; - type VecDotType = f16; - - fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { - Self::vec_dot_unopt(n, xs, ys) - } - - fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { - if xs.len() < n { - crate::bail!("size mismatch {} < {n}", xs.len()) - } - if ys.len() < n { - crate::bail!("size mismatch {} < {n}", ys.len()) - } - let mut res = 0f32; - unsafe { crate::cpu::vec_dot_f16(xs.as_ptr(), ys.as_ptr(), &mut res, n) }; - Ok(res) - } - - fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { - if xs.len() != ys.len() { - crate::bail!("size mismatch {} {}", xs.len(), ys.len()); - } - // TODO: vectorize - for (x, y) in xs.iter().zip(ys.iter_mut()) { - *y = f16::from_f32(*x) - } - Ok(()) - } - - fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { - if xs.len() != ys.len() { - crate::bail!("size mismatch {} {}", xs.len(), ys.len()); - } - // TODO: vectorize - for (x, y) in xs.iter().zip(ys.iter_mut()) { - *y = x.to_f32() - } - Ok(()) - } -} - -impl GgmlType for bf16 { - const DTYPE: GgmlDType = GgmlDType::BF16; - const BLCK_SIZE: usize = 1; - type VecDotType = bf16; - - fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { - Self::vec_dot_unopt(n, xs, ys) - } - - fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { - if xs.len() < n { - crate::bail!("size mismatch {} < {n}", xs.len()) - } - if ys.len() < n { - crate::bail!("size mismatch {} < {n}", ys.len()) - } - let mut res = 0f32; - unsafe { crate::cpu::vec_dot_bf16(xs.as_ptr(), ys.as_ptr(), &mut res, n) }; - Ok(res) - } - - fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { - if xs.len() != ys.len() { - crate::bail!("size mismatch {} {}", xs.len(), ys.len()); - } - // TODO: vectorize - for (x, y) in xs.iter().zip(ys.iter_mut()) { - *y = bf16::from_f32(*x) - } - Ok(()) - } - - fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { - if xs.len() != ys.len() { - 
crate::bail!("size mismatch {} {}", xs.len(), ys.len()); - } - // TODO: vectorize - for (x, y) in xs.iter().zip(ys.iter_mut()) { - *y = x.to_f32() - } - Ok(()) + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + _n: usize, + _xs_0: &[Self], + _xs_1: &[Self], + _ys_0: &[Self::VecDotType], + _ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + unreachable!(); } } diff --git a/candle-core/src/quantized/utils.rs b/candle-core/src/quantized/k_quants/utils.rs similarity index 99% rename from candle-core/src/quantized/utils.rs rename to candle-core/src/quantized/k_quants/utils.rs index 0a087cddbb..7db44df455 100644 --- a/candle-core/src/quantized/utils.rs +++ b/candle-core/src/quantized/k_quants/utils.rs @@ -7,7 +7,7 @@ pub(super) fn nearest_int(v: f32) -> i32 { /// Validates that the input and output are the right size and returns an iterator which maps each /// input region `xs` to its corresponding output block in `ys`. Each output region is guaranteed /// to be `T::BLCK_SIZE` long. -pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>( +pub(super) fn group_for_quantization<'a, 'b, T: super::GgmlType>( xs: &'b [f32], ys: &'a mut [T], ) -> Result> { @@ -28,7 +28,7 @@ pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>( /// Validates that the input and output are the right size and returns an iterator which maps each /// input block `xs` to its corresponding output region in `ys`. Each output region is guaranteed /// to be `T::BLCK_SIZE` long. -pub(super) fn group_for_dequantization<'a, 'b, T: super::k_quants::GgmlType>( +pub(super) fn group_for_dequantization<'a, 'b, T: super::GgmlType>( xs: &'a [T], ys: &'b mut [f32], ) -> Result> { diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs index d06b0674a7..9ff1268f14 100644 --- a/candle-core/src/quantized/metal.rs +++ b/candle-core/src/quantized/metal.rs @@ -34,7 +34,7 @@ impl QMetalStorage { } pub fn dequantize(&self, elem_count: usize) -> Result { - use crate::quantized::k_quants::GgmlType; + use crate::quantized::quants::GgmlType; let buffer = self.device.new_buffer_managed(self.buffer.length())?; let command_buffer = self.device.command_buffer()?; @@ -107,6 +107,22 @@ impl QMetalStorage { let vec: Vec = read_to_vec(&buffer, block_len); crate::quantized::BlockQ8K::to_float(&vec, &mut out)?; } + GgmlDType::Iq4Xs => { + let vec: Vec = read_to_vec(&buffer, block_len); + crate::quantized::BlockIQ4xs::to_float(&vec, &mut out)?; + } + GgmlDType::Iq4Nl => { + let vec: Vec = read_to_vec(&buffer, block_len); + crate::quantized::BlockIQ4nl::to_float(&vec, &mut out)?; + } + GgmlDType::Iq3Xxs => { + let vec: Vec = read_to_vec(&buffer, block_len); + crate::quantized::BlockIQ3xxs::to_float(&vec, &mut out)?; + } + GgmlDType::F8Q8 => { + let vec: Vec = read_to_vec(&buffer, block_len); + crate::quantized::BlockF8Q8::to_float(&vec, &mut out)?; + } } let buffer = self.device.new_buffer_with_data(&out)?; @@ -387,9 +403,13 @@ impl From for candle_metal_kernels::GgmlDType { GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K, GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K, GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K, + GgmlDType::Iq4Xs => candle_metal_kernels::GgmlDType::Iq4Xs, + GgmlDType::Iq4Nl => candle_metal_kernels::GgmlDType::Iq4Nl, + GgmlDType::Iq3Xxs => candle_metal_kernels::GgmlDType::Iq3Xxs, GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16, GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32, GgmlDType::BF16 => 
candle_metal_kernels::GgmlDType::F16, + GgmlDType::F8Q8 => todo!("F8Q8 is unsupported on Metal"), } } } diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 8591aa259d..727ebad4d3 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -1,4 +1,5 @@ use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor, D}; +use iq_quants::*; use k_quants::*; use std::borrow::Cow; @@ -9,9 +10,11 @@ mod dummy_metal; pub mod ggml_file; pub mod gguf_file; pub mod imatrix_file; +pub mod iq_quants; pub mod k_quants; #[cfg(feature = "metal")] pub mod metal; +pub mod quants; #[cfg(not(feature = "metal"))] mod metal { pub use super::dummy_metal::*; @@ -27,10 +30,9 @@ mod cuda { pub mod neon; #[cfg(target_feature = "simd128")] pub mod simd128; -pub mod utils; use half::{bf16, f16}; -pub use k_quants::GgmlType; +pub use quants::GgmlType; pub struct QTensor { storage: QStorage, @@ -200,10 +202,14 @@ pub enum GgmlDType { Q5K, Q6K, Q8K, + Iq4Xs, + Iq4Nl, + Iq3Xxs, + F8Q8, } impl GgmlDType { - pub(crate) fn from_u32(u: u32) -> Result { + pub fn from_u32(u: u32) -> Result { let dtype = match u { 0 => Self::F32, 1 => Self::F16, @@ -219,14 +225,18 @@ impl GgmlDType { 13 => Self::Q5K, 14 => Self::Q6K, 15 => Self::Q8K, + 18 => Self::Iq3Xxs, + 20 => Self::Iq4Nl, + 23 => Self::Iq4Xs, // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389 30 => Self::BF16, + 100 => Self::F8Q8, _ => crate::bail!("unknown dtype for tensor {u}"), }; Ok(dtype) } - pub(crate) fn to_u32(self) -> u32 { + pub fn to_u32(self) -> u32 { match self { Self::F32 => 0, Self::F16 => 1, @@ -242,8 +252,12 @@ impl GgmlDType { Self::Q5K => 13, Self::Q6K => 14, Self::Q8K => 15, + Self::Iq3Xxs => 18, + Self::Iq4Nl => 20, + Self::Iq4Xs => 23, // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389 Self::BF16 => 30, + Self::F8Q8 => 100, } } @@ -264,6 +278,19 @@ impl GgmlDType { Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]), Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]), Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]), + Self::Iq4Nl => Box::new(vec![ + BlockIQ4nl::zeros(); + elem_count / BlockIQ4nl::BLCK_SIZE + ]), + Self::Iq4Xs => Box::new(vec![ + BlockIQ4xs::zeros(); + elem_count / BlockIQ4xs::BLCK_SIZE + ]), + Self::Iq3Xxs => Box::new(vec![ + BlockIQ3xxs::zeros(); + elem_count / BlockIQ3xxs::BLCK_SIZE + ]), + Self::F8Q8 => Box::new(vec![BlockF8Q8::zeros(); elem_count / BlockF8Q8::BLCK_SIZE]), Self::BF16 => Box::new(vec![bf16::zeros(); elem_count]), } } @@ -286,6 +313,10 @@ impl GgmlDType { Self::Q5K => std::mem::size_of::(), Self::Q6K => std::mem::size_of::(), Self::Q8K => std::mem::size_of::(), + Self::Iq4Nl => std::mem::size_of::(), + Self::Iq4Xs => std::mem::size_of::(), + Self::Iq3Xxs => std::mem::size_of::(), + Self::F8Q8 => std::mem::size_of::(), } } @@ -298,9 +329,17 @@ impl GgmlDType { Self::Q4_1 => k_quants::QK4_1, Self::Q5_0 => k_quants::QK5_0, Self::Q5_1 => k_quants::QK5_1, - Self::Q8_0 => k_quants::QK8_0, + Self::Q8_0 | Self::F8Q8 => k_quants::QK8_0, Self::Q8_1 => k_quants::QK8_1, - Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K, + Self::Iq4Nl => iq_quants::QK4_NL, + Self::Q2K + | Self::Q3K + | Self::Q4K + | Self::Q5K + | Self::Q6K + | Self::Q8K + | Self::Iq4Xs + | Self::Iq3Xxs => k_quants::QK_K, } } } @@ -325,9 +364,9 @@ pub trait 
QuantizedType: Send + Sync { fn size(&self) -> usize; } -impl QuantizedType for Vec { +impl QuantizedType for Vec { fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> { - k_quants::matmul(mkn, lhs, self.as_slice(), dst) + quants::matmul(mkn, lhs, self.as_slice(), dst) } fn size(&self) -> usize { diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs index c4d5d6f41a..200911ddf3 100644 --- a/candle-core/src/quantized/neon.rs +++ b/candle-core/src/quantized/neon.rs @@ -1,8 +1,14 @@ -use super::k_quants::{ - BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K, +use super::{ + iq_quants::{BlockIQ4nl, BlockIQ4xs, QK4_NL}, + k_quants::{ + BlockF8Q8, BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, + BlockQ8_0, QK8_0, QK_K, + }, }; -use crate::Result; +use crate::{quantized::KVALUES_IQ4NL, Result}; use byteorder::{ByteOrder, LittleEndian}; +#[cfg(feature = "arm-nightly-feat")] +use itertools::izip; #[allow(unused_imports)] #[cfg(target_arch = "arm")] @@ -11,14 +17,8 @@ use core::arch::arm::*; #[allow(unused_imports)] #[cfg(target_arch = "aarch64")] use core::arch::aarch64::*; - -#[inline(always)] -unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t { - // TODO: dotprod - let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b)); - let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b)); - vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)) -} +#[cfg(feature = "arm-nightly-feat")] +use std::arch::is_aarch64_feature_detected; #[inline(always)] pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result { @@ -51,8 +51,8 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> let v1_0l = vld1q_s8(y0.qs.as_ptr()); let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16)); - let pl0 = vdotq_s32(v0_0ls, v1_0l); - let ph0 = vdotq_s32(v0_0hs, v1_0h); + let pl0 = vdotq_s32_local(vdupq_n_s32(0), v0_0ls, v1_0l); + let ph0 = vdotq_s32_local(vdupq_n_s32(0), v0_0hs, v1_0h); sumv0 = vmlaq_n_f32( sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), @@ -62,7 +62,85 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Ok(vaddvq_f32(sumv0)) } } +#[inline(always)] +#[cfg(feature = "arm-nightly-feat")] +pub(crate) fn i8mm_q4_0_q8_0( + n: usize, + xs_0: &[BlockQ4_0], + xs_1: &[BlockQ4_0], + ys_0: &[BlockQ8_0], + ys_1: &[BlockQ8_0], +) -> Result<[f32; 4]> { + let qk = QK8_0; + let nb = n / qk; + if n % QK8_0 != 0 { + crate::bail!("i8mm_q4_0_q8_0: {n} is not divisible by {qk}") + } + //let (xs_0, xs_1) = xs.split_at_mut(xs.len() / 2); + //let (ys_0, ys_1) = ys.split_at_mut(ys.len() / 2); + assert_eq!(xs_0.len(), xs_1.len()); + assert_eq!(ys_0.len(), ys_1.len()); + assert_eq!(xs_0.len(), ys_0.len()); + + unsafe { + let mut sum_f32 = vdupq_n_f32(0.0); + let m4b = vdupq_n_u8(0x0F); + let s8b = vdupq_n_s8(0x8); + + for i in 0..nb { + let x0 = &xs_0[i]; + let x1 = &xs_1[i]; + let y0 = &ys_0[i]; + let y1 = &ys_1[i]; + + let factor_00: f32 = x0.d.to_f32() * y0.d.to_f32(); + let factor_01: f32 = x1.d.to_f32() * y0.d.to_f32(); + let factor_10: f32 = x0.d.to_f32() * y1.d.to_f32(); + let factor_11: f32 = x1.d.to_f32() * y1.d.to_f32(); + + let xv0 = vld1q_u8(x0.qs.as_ptr()); //16xu8 + let xv1 = vld1q_u8(x1.qs.as_ptr()); //16xu8 + + // convert u8s to i4s so we have equal amount of row elements + // and columns elements to multiply + let xv0_0 = vreinterpretq_s8_u8(vandq_u8(xv0, m4b)); + let xv0_1 = vreinterpretq_s8_u8(vshrq_n_u8(xv0, 4)); + let xv1_0 = 
vreinterpretq_s8_u8(vandq_u8(xv1, m4b)); + let xv1_1 = vreinterpretq_s8_u8(vshrq_n_u8(xv1, 4)); + + // sub 8 + let xv0_0s = vsubq_s8(xv0_0, s8b); + let xv0_1s = vsubq_s8(xv0_1, s8b); + let xv1_0s = vsubq_s8(xv1_0, s8b); + let xv1_1s = vsubq_s8(xv1_1, s8b); + //end of conversion + + let yv0_0 = vld1q_s8(y0.qs.as_ptr()); //16xi8 + let yv0_1 = vld1q_s8(y0.qs.as_ptr().add(16)); // 16xi8 + let yv1_0 = vld1q_s8(y1.qs.as_ptr()); //16xi8 + let yv1_1 = vld1q_s8(y1.qs.as_ptr().add(16)); // 16xi8 + + let i8mm = i8mm_params::new(xv0_0s, xv0_1s, xv1_0s, xv1_1s, yv0_0, yv0_1, yv1_0, yv1_1); + let loop_sum_s32 = i8mm.calculate(vdupq_n_s32(0)); + + // scaling + let factor_elems: [f32; 4] = [factor_00, factor_01, factor_10, factor_11]; + let rawptr = &factor_elems as *const f32; + let factor: float32x4_t = vld1q_f32(rawptr); + let loop_sum_f32 = vcvtq_f32_s32(loop_sum_s32); + + sum_f32 = vmlaq_f32(sum_f32, loop_sum_f32, factor); + } + // extract elements of the vector register + let f0 = vgetq_lane_f32(sum_f32, 0); + let f1 = vgetq_lane_f32(sum_f32, 1); + let f2 = vgetq_lane_f32(sum_f32, 2); + let f3 = vgetq_lane_f32(sum_f32, 3); + let res: [f32; 4] = [f0, f1, f2, f3]; + Ok(res) + } +} #[inline(always)] pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result { let qk = QK8_0; @@ -83,8 +161,8 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> let y0_0 = vld1q_s8(y0.qs.as_ptr()); let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16)); - let p0 = vdotq_s32(x0_0, y0_0); - let p1 = vdotq_s32(x0_1, y0_1); + let p0 = vdotq_s32_local(vdupq_n_s32(0), x0_0, y0_0); + let p1 = vdotq_s32_local(vdupq_n_s32(0), x0_1, y0_1); sumv0 = vmlaq_n_f32( sumv0, @@ -95,6 +173,161 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Ok(vaddvq_f32(sumv0)) } } +#[inline(always)] +#[cfg(feature = "arm-nightly-feat")] +pub(crate) fn i8mm_q8_0_q8_0( + n: usize, + xs_0: &[BlockQ8_0], + xs_1: &[BlockQ8_0], + ys_0: &[BlockQ8_0], + ys_1: &[BlockQ8_0], +) -> Result<[f32; 4]> { + assert_eq!(xs_0.len(), xs_1.len()); + assert_eq!(ys_0.len(), ys_1.len()); + assert_eq!(xs_0.len(), ys_0.len()); + let qk = QK8_0; + if n % QK8_0 != 0 { + crate::bail!("i8mm_q8_0_q8_0: {n} is not divisible by {qk}") + } + let nb = n / QK8_0; + unsafe { + let mut sum_f32 = vdupq_n_f32(0.0); + + for i in 0..nb { + let x0 = &xs_0[i]; + let x1 = &xs_1[i]; + let y0 = &ys_0[i]; + let y1 = &ys_1[i]; + + let factor_00: f32 = x0.d.to_f32() * y0.d.to_f32(); + let factor_01: f32 = x1.d.to_f32() * y0.d.to_f32(); + let factor_10: f32 = x0.d.to_f32() * y1.d.to_f32(); + let factor_11: f32 = x1.d.to_f32() * y1.d.to_f32(); + + let xv0_0 = vld1q_s8(x0.qs.as_ptr()); + let xv0_1 = vld1q_s8(x0.qs.as_ptr().add(16)); + let xv1_0 = vld1q_s8(x1.qs.as_ptr()); + let xv1_1 = vld1q_s8(x1.qs.as_ptr().add(16)); + + let yv0_0 = vld1q_s8(y0.qs.as_ptr()); + let yv0_1 = vld1q_s8(y0.qs.as_ptr().add(16)); + let yv1_0 = vld1q_s8(y1.qs.as_ptr()); + let yv1_1 = vld1q_s8(y1.qs.as_ptr().add(16)); + + let i8mm = i8mm_params::new(xv0_0, xv0_1, xv1_0, xv1_1, yv0_0, yv0_1, yv1_0, yv1_1); + let loop_sum_s32 = i8mm.calculate(vdupq_n_s32(0)); + + // scaling + let factor_elems: [f32; 4] = [factor_00, factor_01, factor_10, factor_11]; + let rawptr = &factor_elems as *const f32; + let factor: float32x4_t = vld1q_f32(rawptr); + let loop_sum_f32 = vcvtq_f32_s32(loop_sum_s32); + + sum_f32 = vmlaq_f32(sum_f32, loop_sum_f32, factor); + } + // extract elements of the vector register + let f0 = vgetq_lane_f32(sum_f32, 0); + let f1 = 
vgetq_lane_f32(sum_f32, 1); + let f2 = vgetq_lane_f32(sum_f32, 2); + let f3 = vgetq_lane_f32(sum_f32, 3); + let res: [f32; 4] = [f0, f1, f2, f3]; + Ok(res) + } +} + +#[inline(always)] +pub(crate) fn vec_dot_f8q8_q8_0(n: usize, xs: &[BlockF8Q8], ys: &[BlockQ8_0]) -> Result { + let qk = QK8_0; + if n % QK8_0 != 0 { + crate::bail!("vec_dot_f8q8_q8_0: {n} is not divisible by {qk}") + } + let nb = n / QK8_0; + unsafe { + let mut sumv0 = vdupq_n_f32(0.0f32); + for i in 0..nb { + let x0 = &xs[i]; + let y0 = &ys[i]; + + let x0_0 = vld1q_s8(x0.qs.as_ptr()); + let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16)); + + // load y + let y0_0 = vld1q_s8(y0.qs.as_ptr()); + let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16)); + + let p0 = vdotq_s32_local(vdupq_n_s32(0), x0_0, y0_0); + let p1 = vdotq_s32_local(vdupq_n_s32(0), x0_1, y0_1); + + sumv0 = vmlaq_n_f32( + sumv0, + vcvtq_f32_s32(vaddq_s32(p0, p1)), + x0.dq_d() * y0.d.to_f32(), + ); + } + Ok(vaddvq_f32(sumv0)) + } +} +#[inline(always)] +#[cfg(feature = "arm-nightly-feat")] +pub(crate) fn i8mm_f8q8_q8_0( + n: usize, + xs_0: &[BlockF8Q8], + xs_1: &[BlockF8Q8], + ys_0: &[BlockQ8_0], + ys_1: &[BlockQ8_0], +) -> Result<[f32; 4]> { + assert_eq!(xs_0.len(), xs_1.len()); + assert_eq!(ys_0.len(), ys_1.len()); + assert_eq!(xs_0.len(), ys_0.len()); + let qk = QK8_0; + if n % QK8_0 != 0 { + crate::bail!("i8mm_q8_0_q8_0: {n} is not divisible by {qk}") + } + let nb = n / QK8_0; + unsafe { + let mut sum_f32 = vdupq_n_f32(0.0); + + for i in 0..nb { + let x0 = &xs_0[i]; + let x1 = &xs_1[i]; + let y0 = &ys_0[i]; + let y1 = &ys_1[i]; + + let factor_00: f32 = x0.dq_d() * y0.d.to_f32(); + let factor_01: f32 = x1.dq_d() * y0.d.to_f32(); + let factor_10: f32 = x0.dq_d() * y1.d.to_f32(); + let factor_11: f32 = x1.dq_d() * y1.d.to_f32(); + + let xv0_0 = vld1q_s8(x0.qs.as_ptr()); + let xv0_1 = vld1q_s8(x0.qs.as_ptr().add(16)); + let xv1_0 = vld1q_s8(x1.qs.as_ptr()); + let xv1_1 = vld1q_s8(x1.qs.as_ptr().add(16)); + + let yv0_0 = vld1q_s8(y0.qs.as_ptr()); + let yv0_1 = vld1q_s8(y0.qs.as_ptr().add(16)); + let yv1_0 = vld1q_s8(y1.qs.as_ptr()); + let yv1_1 = vld1q_s8(y1.qs.as_ptr().add(16)); + + let i8mm = i8mm_params::new(xv0_0, xv0_1, xv1_0, xv1_1, yv0_0, yv0_1, yv1_0, yv1_1); + let loop_sum_s32 = i8mm.calculate(vdupq_n_s32(0)); + + // scaling + let factor_elems: [f32; 4] = [factor_00, factor_01, factor_10, factor_11]; + let rawptr = &factor_elems as *const f32; + let factor: float32x4_t = vld1q_f32(rawptr); + let loop_sum_f32 = vcvtq_f32_s32(loop_sum_s32); + + sum_f32 = vmlaq_f32(sum_f32, loop_sum_f32, factor); + } + // extract elements of the vector register + let f0 = vgetq_lane_f32(sum_f32, 0); + let f1 = vgetq_lane_f32(sum_f32, 1); + let f2 = vgetq_lane_f32(sum_f32, 2); + let f3 = vgetq_lane_f32(sum_f32, 3); + let res: [f32; 4] = [f0, f1, f2, f3]; + Ok(res) + } +} #[inline(always)] pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result { @@ -113,8 +346,8 @@ pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Res for i in (0..QK_K).step_by(16) { let xs = vld1q_s8(xs.add(i)); let ys = vld1q_s8(ys.add(i)); - let xy = vdotq_s32(xs, ys); - sum_i = vaddq_s32(sum_i, xy) + + sum_i = vdotq_s32_local(sum_i, xs, ys); } sumf += vaddvq_s32(sum_i) as f32 * scale } @@ -130,8 +363,8 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res let mut sum = 0f32; unsafe { let m4b = vdupq_n_u8(0xF); - let mone = vdupq_n_u8(3); + let mzero = vdupq_n_s32(0); for (x, y) in xs.iter().zip(ys.iter()) { let d_all = x.d.to_f32(); @@ 
-183,14 +416,14 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2)); let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3)); - let p0 = vdotq_s32(q6bytes_0, q8bytes.0); - let p1 = vdotq_s32(q6bytes_1, q8bytes.1); + let p0 = vdotq_s32_local(mzero, q6bytes_0, q8bytes.0); + let p1 = vdotq_s32_local(mzero, q6bytes_1, q8bytes.1); let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1; scale = scale.add(2); - let p2 = vdotq_s32(q6bytes_2, q8bytes.2); - let p3 = vdotq_s32(q6bytes_3, q8bytes.3); + let p2 = vdotq_s32_local(mzero, q6bytes_2, q8bytes.2); + let p3 = vdotq_s32_local(mzero, q6bytes_3, q8bytes.3); let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1; scale = scale.add(2); @@ -212,14 +445,14 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2)); let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3)); - let p0 = vdotq_s32(q6bytes_0, q8bytes.0); - let p1 = vdotq_s32(q6bytes_1, q8bytes.1); + let p0 = vdotq_s32_local(mzero, q6bytes_0, q8bytes.0); + let p1 = vdotq_s32_local(mzero, q6bytes_1, q8bytes.1); let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1; scale = scale.add(2); - let p2 = vdotq_s32(q6bytes_2, q8bytes.2); - let p3 = vdotq_s32(q6bytes_3, q8bytes.3); + let p2 = vdotq_s32_local(mzero, q6bytes_2, q8bytes.2); + let p3 = vdotq_s32_local(mzero, q6bytes_3, q8bytes.3); let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1; scale = scale.add(2); @@ -229,6 +462,274 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res } Ok(sum) } +// QK_K = 256 +#[inline(always)] +#[cfg(feature = "arm-nightly-feat")] +pub(crate) fn i8mm_q6k_q8k( + _n: usize, + xs_0: &[BlockQ6K], + xs_1: &[BlockQ6K], + ys_0: &[BlockQ8K], + ys_1: &[BlockQ8K], +) -> Result<[f32; 4]> { + unsafe { + let mut fsum = vdupq_n_f32(0.0); + let m4b = vdupq_n_u8(0xF); + let mone = vdupq_n_u8(3); + for (x0, x1, y0, y1) in izip!(xs_0, xs_1, ys_0, ys_1) { + let d_00: f32 = x0.d.to_f32() * y0.d; + let d_01: f32 = x1.d.to_f32() * y0.d; + let d_10: f32 = x0.d.to_f32() * y1.d; + let d_11: f32 = x1.d.to_f32() * y1.d; + + let mut q6_0 = x0.ql.as_ptr(); + let mut q6_1 = x1.ql.as_ptr(); + let mut qh_0 = x0.qh.as_ptr(); + let mut qh_1 = x1.qh.as_ptr(); + let mut q8_0 = y0.qs.as_ptr(); + let mut q8_1 = y1.qs.as_ptr(); + + let mut scale_0 = x0.scales.as_ptr(); + let mut scale_1 = x1.scales.as_ptr(); + + let q8sums_0 = vld1q_s16_x2(y0.bsums.as_ptr()); + let q8sums_1 = vld1q_s16_x2(y1.bsums.as_ptr()); + let scales_0 = vld1q_s8(scale_0); + let scales_1 = vld1q_s8(scale_1); + + let q6scales_0 = int16x8x2_t( + vmovl_s8(vget_low_s8(scales_0)), + vmovl_s8(vget_high_s8(scales_0)), + ); + let q6scales_1 = int16x8x2_t( + vmovl_s8(vget_low_s8(scales_1)), + vmovl_s8(vget_high_s8(scales_1)), + ); + + // y0 x0 + let prod_00 = vaddq_s32( + vaddq_s32( + vmull_s16(vget_low_s16(q8sums_0.0), vget_low_s16(q6scales_0.0)), + vmull_s16(vget_high_s16(q8sums_0.0), vget_high_s16(q6scales_0.0)), + ), + vaddq_s32( + vmull_s16(vget_low_s16(q8sums_0.1), vget_low_s16(q6scales_0.1)), + vmull_s16(vget_high_s16(q8sums_0.1), 
vget_high_s16(q6scales_0.1)), + ), + ); + // y0 x1 + let prod_01 = vaddq_s32( + vaddq_s32( + vmull_s16(vget_low_s16(q8sums_0.0), vget_low_s16(q6scales_1.0)), + vmull_s16(vget_high_s16(q8sums_0.0), vget_high_s16(q6scales_1.0)), + ), + vaddq_s32( + vmull_s16(vget_low_s16(q8sums_0.1), vget_low_s16(q6scales_1.1)), + vmull_s16(vget_high_s16(q8sums_0.1), vget_high_s16(q6scales_1.1)), + ), + ); + // y1 x0 + let prod_10 = vaddq_s32( + vaddq_s32( + vmull_s16(vget_low_s16(q8sums_1.0), vget_low_s16(q6scales_0.0)), + vmull_s16(vget_high_s16(q8sums_1.0), vget_high_s16(q6scales_0.0)), + ), + vaddq_s32( + vmull_s16(vget_low_s16(q8sums_1.1), vget_low_s16(q6scales_0.1)), + vmull_s16(vget_high_s16(q8sums_1.1), vget_high_s16(q6scales_0.1)), + ), + ); + // y1 x1 + let prod_11 = vaddq_s32( + vaddq_s32( + vmull_s16(vget_low_s16(q8sums_1.0), vget_low_s16(q6scales_1.0)), + vmull_s16(vget_high_s16(q8sums_1.0), vget_high_s16(q6scales_1.0)), + ), + vaddq_s32( + vmull_s16(vget_low_s16(q8sums_1.1), vget_low_s16(q6scales_1.1)), + vmull_s16(vget_high_s16(q8sums_1.1), vget_high_s16(q6scales_1.1)), + ), + ); + let sumi_mins_00 = vaddvq_s32(prod_00); + let sumi_mins_01 = vaddvq_s32(prod_01); + let sumi_mins_10 = vaddvq_s32(prod_10); + let sumi_mins_11 = vaddvq_s32(prod_11); + + let mut isum = vdupq_n_s32(0); + for _j in 0..QK_K / 128 { + let qhbits_0 = vld1q_u8_x2(qh_0); + let qhbits_1 = vld1q_u8_x2(qh_1); + qh_0 = qh_0.add(32); + qh_1 = qh_1.add(32); + + let q6bits_0 = vld1q_u8_x4(q6_0); + let q6bits_1 = vld1q_u8_x4(q6_1); + q6_0 = q6_0.add(64); + q6_1 = q6_1.add(64); + + let q8bytes0_0 = vld1q_s8_x4(q8_0); + let q8bytes1_0 = vld1q_s8_x4(q8_1); + q8_0 = q8_0.add(64); + q8_1 = q8_1.add(64); + + let q8bytes0_1 = vld1q_s8_x4(q8_0); + let q8bytes1_1 = vld1q_s8_x4(q8_1); + q8_0 = q8_0.add(64); + q8_1 = q8_1.add(64); + + let q6h0_0 = vshlq_n_u8(vandq_u8(mone, qhbits_0.0), 4); + let q6h0_1 = vshlq_n_u8(vandq_u8(mone, qhbits_0.1), 4); + let shifted = vshrq_n_u8(qhbits_0.0, 2); + let q6h0_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + let shifted = vshrq_n_u8(qhbits_0.1, 2); + let q6h0_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + let q6h1_0 = vshlq_n_u8(vandq_u8(mone, qhbits_1.0), 4); + let q6h1_1 = vshlq_n_u8(vandq_u8(mone, qhbits_1.1), 4); + let shifted = vshrq_n_u8(qhbits_1.0, 2); + let q6h1_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + let shifted = vshrq_n_u8(qhbits_1.1, 2); + let q6h1_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + let q6bytes0_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits_0.0, m4b), q6h0_0)); + let q6bytes0_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits_0.1, m4b), q6h0_1)); + let q6bytes0_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits_0.2, m4b), q6h0_2)); + let q6bytes0_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits_0.3, m4b), q6h0_3)); + + let q6bytes1_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits_1.0, m4b), q6h1_0)); + let q6bytes1_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits_1.1, m4b), q6h1_1)); + let q6bytes1_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits_1.2, m4b), q6h1_2)); + let q6bytes1_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits_1.3, m4b), q6h1_3)); + + let sc = i8mm_x_scales::new(&x_scales { + x00: *scale_0.add(0) as i32, + x01: *scale_0.add(1) as i32, + x10: *scale_1.add(0) as i32, + x11: *scale_1.add(1) as i32, + }); + let i8mm = i8mm_params::new( + q6bytes0_0, + q6bytes0_1, + q6bytes1_0, + q6bytes1_1, + q8bytes0_0.0, + q8bytes0_0.1, + q8bytes1_0.0, + q8bytes1_0.1, + ); + isum = i8mm.calculate_with_scales(isum, sc); + + let sc = i8mm_x_scales::new(&x_scales { + x00: 
*scale_0.add(2) as i32, + x01: *scale_0.add(3) as i32, + x10: *scale_1.add(2) as i32, + x11: *scale_1.add(3) as i32, + }); + let i8mm = i8mm_params::new( + q6bytes0_2, + q6bytes0_3, + q6bytes1_2, + q6bytes1_3, + q8bytes0_0.2, + q8bytes0_0.3, + q8bytes1_0.2, + q8bytes1_0.3, + ); + isum = i8mm.calculate_with_scales(isum, sc); + + scale_0 = scale_0.add(4); + scale_1 = scale_1.add(4); + + let shifted = vshrq_n_u8(qhbits_0.0, 4); + let q6h0_0 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + let shifted = vshrq_n_u8(qhbits_0.1, 4); + let q6h0_1 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + let shifted = vshrq_n_u8(qhbits_0.0, 6); + let q6h0_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + let shifted = vshrq_n_u8(qhbits_0.1, 6); + let q6h0_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + let shifted = vshrq_n_u8(qhbits_1.0, 4); + let q6h1_0 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + let shifted = vshrq_n_u8(qhbits_1.1, 4); + let q6h1_1 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + let shifted = vshrq_n_u8(qhbits_1.0, 6); + let q6h1_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + let shifted = vshrq_n_u8(qhbits_1.1, 6); + let q6h1_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + let q6bytes0_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits_0.0, 4), q6h0_0)); + let q6bytes0_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits_0.1, 4), q6h0_1)); + let q6bytes0_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits_0.2, 4), q6h0_2)); + let q6bytes0_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits_0.3, 4), q6h0_3)); + + let q6bytes1_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits_1.0, 4), q6h1_0)); + let q6bytes1_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits_1.1, 4), q6h1_1)); + let q6bytes1_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits_1.2, 4), q6h1_2)); + let q6bytes1_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits_1.3, 4), q6h1_3)); + + let sc = i8mm_x_scales::new(&x_scales { + x00: *scale_0.add(0) as i32, + x01: *scale_0.add(1) as i32, + x10: *scale_1.add(0) as i32, + x11: *scale_1.add(1) as i32, + }); + let i8mm = i8mm_params::new( + q6bytes0_0, + q6bytes0_1, + q6bytes1_0, + q6bytes1_1, + q8bytes0_1.0, + q8bytes0_1.1, + q8bytes1_1.0, + q8bytes1_1.1, + ); + isum = i8mm.calculate_with_scales(isum, sc); + + let sc = i8mm_x_scales::new(&x_scales { + x00: *scale_0.add(2) as i32, + x01: *scale_0.add(3) as i32, + x10: *scale_1.add(2) as i32, + x11: *scale_1.add(3) as i32, + }); + let i8mm = i8mm_params::new( + q6bytes0_2, + q6bytes0_3, + q6bytes1_2, + q6bytes1_3, + q8bytes0_1.2, + q8bytes0_1.3, + q8bytes1_1.2, + q8bytes1_1.3, + ); + isum = i8mm.calculate_with_scales(isum, sc); + + scale_0 = scale_0.add(4); + scale_1 = scale_1.add(4); + } + let factor_elems: [f32; 4] = [d_00, d_01, d_10, d_11]; + let rawptr = &factor_elems as *const f32; + let factor: float32x4_t = vld1q_f32(rawptr); + //sum += d_all * y.d * ((isum - 32 * isum_mins) as f32); + let sumi_mins_arr: [i32; 4] = [ + -sumi_mins_00 * 32, + -sumi_mins_01 * 32, + -sumi_mins_10 * 32, + -sumi_mins_11 * 32, + ]; + let rawptr = &sumi_mins_arr as *const i32; + let sumi_minsv: int32x4_t = vld1q_s32(rawptr); + fsum = vmlaq_f32(fsum, factor, vcvtq_f32_s32(vaddq_s32(sumi_minsv, isum))); + } + // extract elements of the vector register + let f0 = vgetq_lane_f32(fsum, 0); + let f1 = vgetq_lane_f32(fsum, 1); + let f2 = vgetq_lane_f32(fsum, 2); + let f3 = vgetq_lane_f32(fsum, 3); + let res: [f32; 4] = [f0, f1, f2, f3]; + Ok(res) + } +} #[inline(always)] pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result { @@ -243,6 +744,7 @@ 
pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res unsafe { let m4b = vdupq_n_u8(0xF); + let mzero = vdupq_n_s32(0); let mone = vdupq_n_u8(1); let mtwo = vdupq_n_u8(2); @@ -298,13 +800,13 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2)); let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3)); - let p0 = vdotq_s32(q5bytes_0, q8bytes.0); - let p1 = vdotq_s32(q5bytes_1, q8bytes.1); + let p0 = vdotq_s32_local(mzero, q5bytes_0, q8bytes.0); + let p1 = vdotq_s32_local(mzero, q5bytes_1, q8bytes.1); sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32; scales = scales.add(1); - let p2 = vdotq_s32(q5bytes_2, q8bytes.2); - let p3 = vdotq_s32(q5bytes_3, q8bytes.3); + let p2 = vdotq_s32_local(mzero, q5bytes_2, q8bytes.2); + let p3 = vdotq_s32_local(mzero, q5bytes_3, q8bytes.3); sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32; scales = scales.add(1); } @@ -313,6 +815,212 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res } Ok(sumf) } +#[inline(always)] +#[cfg(feature = "arm-nightly-feat")] +pub(crate) fn i8mm_q5k_q8k( + _n: usize, + xs_0: &[BlockQ5K], + xs_1: &[BlockQ5K], + ys_0: &[BlockQ8K], + ys_1: &[BlockQ8K], +) -> Result<[f32; 4]> { + const KMASK1: u32 = 0x3f3f3f3f; + const KMASK2: u32 = 0x0f0f0f0f; + const KMASK3: u32 = 0x03030303; + unsafe { + let mut sumfv = vdupq_n_f32(0.0); + let mut utmp_0 = [0u32; 4]; + let mut utmp_1 = [0u32; 4]; + let m4b = vdupq_n_u8(0xF); + let mone = vdupq_n_u8(1); + let mtwo = vdupq_n_u8(2); + let mzero = vdupq_n_s32(0); + for (x0, x1, y0, y1) in izip!(xs_0, xs_1, ys_0, ys_1) { + let d_00: f32 = x0.d.to_f32() * y0.d; + let d_01: f32 = x1.d.to_f32() * y0.d; + let d_10: f32 = x0.d.to_f32() * y1.d; + let d_11: f32 = x1.d.to_f32() * y1.d; + + let dmin_00 = -y0.d * x0.dmin.to_f32(); + let dmin_01 = -y0.d * x1.dmin.to_f32(); + let dmin_10 = -y1.d * x0.dmin.to_f32(); + let dmin_11 = -y1.d * x1.dmin.to_f32(); + + let q8sums_0 = vpaddq_s16( + vld1q_s16(y0.bsums.as_ptr()), + vld1q_s16(y0.bsums.as_ptr().add(8)), + ); + let q8sums_1 = vpaddq_s16( + vld1q_s16(y1.bsums.as_ptr()), + vld1q_s16(y1.bsums.as_ptr().add(8)), + ); + + LittleEndian::read_u32_into(&x0.scales, &mut utmp_0[0..3]); + LittleEndian::read_u32_into(&x1.scales, &mut utmp_1[0..3]); + + utmp_0[3] = ((utmp_0[2] >> 4) & KMASK2) | (((utmp_0[1] >> 6) & KMASK3) << 4); + let uaux = utmp_0[1] & KMASK1; + utmp_0[1] = (utmp_0[2] & KMASK2) | (((utmp_0[0] >> 6) & KMASK3) << 4); + utmp_0[2] = uaux; + utmp_0[0] &= KMASK1; + + utmp_1[3] = ((utmp_1[2] >> 4) & KMASK2) | (((utmp_1[1] >> 6) & KMASK3) << 4); + let uaux = utmp_1[1] & KMASK1; + utmp_1[1] = (utmp_1[2] & KMASK2) | (((utmp_1[0] >> 6) & KMASK3) << 4); + utmp_1[2] = uaux; + utmp_1[0] &= KMASK1; + + let mins8_0 = vld1_u8((utmp_0.as_ptr() as *const u8).add(8)); + let mins8_1 = vld1_u8((utmp_1.as_ptr() as *const u8).add(8)); + let mins_0 = vreinterpretq_s16_u16(vmovl_u8(mins8_0)); + let mins_1 = vreinterpretq_s16_u16(vmovl_u8(mins8_1)); + + // y0 x0 + let prod_00 = vaddq_s32( + vmull_s16(vget_low_s16(q8sums_0), vget_low_s16(mins_0)), + vmull_s16(vget_high_s16(q8sums_0), vget_high_s16(mins_0)), + ); + // y0 x1 + let prod_01 = vaddq_s32( + vmull_s16(vget_low_s16(q8sums_0), vget_low_s16(mins_1)), + vmull_s16(vget_high_s16(q8sums_0), vget_high_s16(mins_1)), + ); + // y1 x0 + let prod_10 = vaddq_s32( + vmull_s16(vget_low_s16(q8sums_1), vget_low_s16(mins_0)), + 
vmull_s16(vget_high_s16(q8sums_1), vget_high_s16(mins_0)), + ); + // y1 x1 + let prod_11 = vaddq_s32( + vmull_s16(vget_low_s16(q8sums_1), vget_low_s16(mins_1)), + vmull_s16(vget_high_s16(q8sums_1), vget_high_s16(mins_1)), + ); + let sumi_mins_00 = vaddvq_s32(prod_00); + let sumi_mins_01 = vaddvq_s32(prod_01); + let sumi_mins_10 = vaddvq_s32(prod_10); + let sumi_mins_11 = vaddvq_s32(prod_11); + + let mut scales_0 = utmp_0.as_ptr() as *const u8; + let mut scales_1 = utmp_1.as_ptr() as *const u8; + + let mut q5_0 = x0.qs.as_ptr(); + let mut q5_1 = x1.qs.as_ptr(); + let mut q8_0 = y0.qs.as_ptr(); + let mut q8_1 = y1.qs.as_ptr(); + + let mut qhbits_0 = vld1q_u8_x2(x0.qh.as_ptr()); + let mut qhbits_1 = vld1q_u8_x2(x1.qh.as_ptr()); + + let mut isum = vdupq_n_s32(0); + for _j in 0..QK_K / 64 { + let q5bits_0 = vld1q_u8_x2(q5_0); + let q5bits_1 = vld1q_u8_x2(q5_1); + q5_0 = q5_0.add(32); + q5_1 = q5_1.add(32); + let q8bytes_0 = vld1q_s8_x4(q8_0); + let q8bytes_1 = vld1q_s8_x4(q8_1); + q8_0 = q8_0.add(64); + q8_1 = q8_1.add(64); + + let q5h0_0 = vshlq_n_u8(vandq_u8(mone, qhbits_0.0), 4); + let q5h0_1 = vshlq_n_u8(vandq_u8(mone, qhbits_0.1), 4); + let q5h0_2 = vshlq_n_u8(vandq_u8(mtwo, qhbits_0.0), 3); + let q5h0_3 = vshlq_n_u8(vandq_u8(mtwo, qhbits_0.1), 3); + + let q5h1_0 = vshlq_n_u8(vandq_u8(mone, qhbits_1.0), 4); + let q5h1_1 = vshlq_n_u8(vandq_u8(mone, qhbits_1.1), 4); + let q5h1_2 = vshlq_n_u8(vandq_u8(mtwo, qhbits_1.0), 3); + let q5h1_3 = vshlq_n_u8(vandq_u8(mtwo, qhbits_1.1), 3); + + qhbits_0.0 = vshrq_n_u8(qhbits_0.0, 2); + qhbits_0.1 = vshrq_n_u8(qhbits_0.1, 2); + qhbits_1.0 = vshrq_n_u8(qhbits_1.0, 2); + qhbits_1.1 = vshrq_n_u8(qhbits_1.1, 2); + + let q5bytes0_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits_0.0, m4b), q5h0_0)); + let q5bytes0_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits_0.1, m4b), q5h0_1)); + let q5bytes0_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits_0.0, 4), q5h0_2)); + let q5bytes0_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits_0.1, 4), q5h0_3)); + + let q5bytes1_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits_1.0, m4b), q5h1_0)); + let q5bytes1_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits_1.1, m4b), q5h1_1)); + let q5bytes1_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits_1.0, 4), q5h1_2)); + let q5bytes1_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits_1.1, 4), q5h1_3)); + + let i8mm = i8mm_params::new( + q5bytes0_0, + q5bytes0_1, + q5bytes1_0, + q5bytes1_1, + q8bytes_0.0, + q8bytes_0.1, + q8bytes_1.0, + q8bytes_1.1, + ); + let i8mmres = i8mm.calculate(mzero); + + let sc_arr = [ + *scales_0 as i32, + *scales_1 as i32, + *scales_0 as i32, + *scales_1 as i32, + ]; + let rawptr = &sc_arr as *const i32; + let sc: int32x4_t = vld1q_s32(rawptr); + isum = vmlaq_s32(isum, i8mmres, sc); + + scales_0 = scales_0.add(1); + scales_1 = scales_1.add(1); + + let i8mm = i8mm_params::new( + q5bytes0_2, + q5bytes0_3, + q5bytes1_2, + q5bytes1_3, + q8bytes_0.2, + q8bytes_0.3, + q8bytes_1.2, + q8bytes_1.3, + ); + let i8mmres = i8mm.calculate(mzero); + let sc_arr = [ + *scales_0 as i32, + *scales_1 as i32, + *scales_0 as i32, + *scales_1 as i32, + ]; + let rawptr = &sc_arr as *const i32; + let sc: int32x4_t = vld1q_s32(rawptr); + isum = vmlaq_s32(isum, i8mmres, sc); + + scales_0 = scales_0.add(1); + scales_1 = scales_1.add(1); + } + let factor_elems: [f32; 4] = [d_00, d_01, d_10, d_11]; + let rawptr = &factor_elems as *const f32; + let factor: float32x4_t = vld1q_f32(rawptr); + + let dmin_arr: [f32; 4] = [dmin_00, dmin_01, dmin_10, dmin_11]; + let rawptr = &dmin_arr 
as *const f32; + let dminv: float32x4_t = vld1q_f32(rawptr); + + let sumi_mins_arr: [i32; 4] = [sumi_mins_00, sumi_mins_01, sumi_mins_10, sumi_mins_11]; + let rawptr = &sumi_mins_arr as *const i32; + let sumi_minsv: float32x4_t = vcvtq_f32_s32(vld1q_s32(rawptr)); + + let fsum = vcvtq_f32_s32(isum); + sumfv = vmlaq_f32(sumfv, fsum, factor); + sumfv = vmlaq_f32(sumfv, dminv, sumi_minsv); + } + // extract elements of the vector register + let f0 = vgetq_lane_f32(sumfv, 0); + let f1 = vgetq_lane_f32(sumfv, 1); + let f2 = vgetq_lane_f32(sumfv, 2); + let f3 = vgetq_lane_f32(sumfv, 3); + let res: [f32; 4] = [f0, f1, f2, f3]; + Ok(res) + } +} #[inline(always)] pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result { @@ -328,6 +1036,7 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res unsafe { let m4b = vdupq_n_u8(0xF); + let mzero = vdupq_n_s32(0); for (x, y) in xs.iter().zip(ys.iter()) { let d = y.d * x.d.to_f32(); @@ -374,8 +1083,8 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)), vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)), ); - let p0 = vdotq_s32(q4bytes.0, q8bytes.0); - let p1 = vdotq_s32(q4bytes.1, q8bytes.1); + let p0 = vdotq_s32_local(mzero, q4bytes.0, q8bytes.0); + let p1 = vdotq_s32_local(mzero, q4bytes.1, q8bytes.1); sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as i32; let q8bytes = vld1q_s8_x2(q8); @@ -384,8 +1093,8 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)), vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)), ); - let p2 = vdotq_s32(q4bytes.0, q8bytes.0); - let p3 = vdotq_s32(q4bytes.1, q8bytes.1); + let p2 = vdotq_s32_local(mzero, q4bytes.0, q8bytes.0); + let p3 = vdotq_s32_local(mzero, q4bytes.1, q8bytes.1); sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32; } sumf += d * (sumi1 + sumi2) as f32; @@ -394,6 +1103,202 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res Ok(sumf) } +#[inline(always)] +#[cfg(feature = "arm-nightly-feat")] +pub(crate) fn i8mm_q4k_q8k( + n: usize, + xs_0: &[BlockQ4K], + xs_1: &[BlockQ4K], + ys_0: &[BlockQ8K], + ys_1: &[BlockQ8K], +) -> Result<[f32; 4]> { + if n % QK_K != 0 { + crate::bail!("i8mm_q4k_q8k: {n} is not divisible by {QK_K}") + } + let mut utmp_0 = [0u32; 4]; + let mut utmp_1 = [0u32; 4]; + let mut scales_0 = [0u8; 16]; + let mut scales_1 = [0u8; 16]; + const KMASK1: u32 = 0x3f3f3f3f; + const KMASK2: u32 = 0x0f0f0f0f; + const KMASK3: u32 = 0x03030303; + + unsafe { + let mut sumfv = vdupq_n_f32(0.0); + let m4b = vdupq_n_u8(0xF); + + for (x0, x1, y0, y1) in izip!(xs_0, xs_1, ys_0, ys_1) { + let d_00: f32 = x0.d.to_f32() * y0.d; + let d_01: f32 = x1.d.to_f32() * y0.d; + let d_10: f32 = x0.d.to_f32() * y1.d; + let d_11: f32 = x1.d.to_f32() * y1.d; + + let dmin_00 = x0.dmin.to_f32() * y0.d; + let dmin_01 = x1.dmin.to_f32() * y0.d; + let dmin_10 = x0.dmin.to_f32() * y1.d; + let dmin_11 = x1.dmin.to_f32() * y1.d; + + let q8sums_0 = vpaddq_s16( + vld1q_s16(y0.bsums.as_ptr()), + vld1q_s16(y0.bsums.as_ptr().add(8)), + ); + let q8sums_1 = vpaddq_s16( + vld1q_s16(y1.bsums.as_ptr()), + vld1q_s16(y1.bsums.as_ptr().add(8)), + ); + LittleEndian::read_u32_into(&x0.scales, &mut utmp_0[0..3]); + LittleEndian::read_u32_into(&x1.scales, &mut utmp_1[0..3]); + + let mins8_0 = vld1_u32( + [ + utmp_0[1] & KMASK1, + ((utmp_0[2] >> 4) & KMASK2) | (((utmp_0[1] >> 6) & KMASK3) << 4), + ] + 
.as_ptr(), + ); + let mins8_1 = vld1_u32( + [ + utmp_1[1] & KMASK1, + ((utmp_1[2] >> 4) & KMASK2) | (((utmp_1[1] >> 6) & KMASK3) << 4), + ] + .as_ptr(), + ); + utmp_0[1] = (utmp_0[2] & KMASK2) | (((utmp_0[0] >> 6) & KMASK3) << 4); + utmp_0[0] &= KMASK1; + + utmp_1[1] = (utmp_1[2] & KMASK2) | (((utmp_1[0] >> 6) & KMASK3) << 4); + utmp_1[0] &= KMASK1; + + let mins_0 = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8_0))); // from x0 + let mins_1 = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8_1))); // from x1 + + // y0 x0 + let prod_00 = vaddq_s32( + vmull_s16(vget_low_s16(q8sums_0), vget_low_s16(mins_0)), + vmull_s16(vget_high_s16(q8sums_0), vget_high_s16(mins_0)), + ); + // y0 x1 + let prod_01 = vaddq_s32( + vmull_s16(vget_low_s16(q8sums_0), vget_low_s16(mins_1)), + vmull_s16(vget_high_s16(q8sums_0), vget_high_s16(mins_1)), + ); + // y1 x0 + let prod_10 = vaddq_s32( + vmull_s16(vget_low_s16(q8sums_1), vget_low_s16(mins_0)), + vmull_s16(vget_high_s16(q8sums_1), vget_high_s16(mins_0)), + ); + // y1 x1 + let prod_11 = vaddq_s32( + vmull_s16(vget_low_s16(q8sums_1), vget_low_s16(mins_1)), + vmull_s16(vget_high_s16(q8sums_1), vget_high_s16(mins_1)), + ); + + let s = [ + -dmin_00 * vaddvq_s32(prod_00) as f32, + -dmin_01 * vaddvq_s32(prod_01) as f32, + -dmin_10 * vaddvq_s32(prod_10) as f32, + -dmin_11 * vaddvq_s32(prod_11) as f32, + ]; + let rawptr = &s as *const f32; + let sumdiff: float32x4_t = vld1q_f32(rawptr); + sumfv = vaddq_f32(sumfv, sumdiff); + + LittleEndian::write_u32_into(&utmp_0, &mut scales_0); + LittleEndian::write_u32_into(&utmp_1, &mut scales_1); + + let mut q4_0 = x0.qs.as_ptr(); + let mut q4_1 = x1.qs.as_ptr(); + let mut q8_0 = y0.qs.as_ptr(); + let mut q8_1 = y1.qs.as_ptr(); + + let mut sumi1 = vdupq_n_s32(0); + let mut sumi2 = vdupq_n_s32(0); + // 0..4 + for j in 0..QK_K / 64 { + let xv0 = vld1q_u8_x2(q4_0); + let xv0_0_original = xv0.0; + let xv0_1_original = xv0.1; + q4_0 = q4_0.add(32); + + let xv1 = vld1q_u8_x2(q4_1); + let xv1_0_original = xv1.0; + let xv1_1_original = xv1.1; + q4_1 = q4_1.add(32); + + let yv0 = vld1q_s8_x2(q8_0); + let yv0_0 = yv0.0; + let yv0_1 = yv0.1; + q8_0 = q8_0.add(32); + + let yv1 = vld1q_s8_x2(q8_1); + let yv1_0 = yv1.0; + let yv1_1 = yv1.1; + q8_1 = q8_1.add(32); + + let xv0_0 = vreinterpretq_s8_u8(vandq_u8(xv0_0_original, m4b)); + let xv0_1 = vreinterpretq_s8_u8(vandq_u8(xv0_1_original, m4b)); + let xv1_0 = vreinterpretq_s8_u8(vandq_u8(xv1_0_original, m4b)); + let xv1_1 = vreinterpretq_s8_u8(vandq_u8(xv1_1_original, m4b)); + + let i8mm = i8mm_params::new(xv0_0, xv0_1, xv1_0, xv1_1, yv0_0, yv0_1, yv1_0, yv1_1); + let p1 = i8mm.calculate(vdupq_n_s32(0)); + + // x0 | x1 + // y0 | sc_0 sc_1 + // y1 | sc_0 sc_1 + let scarr = [ + scales_0[2 * j] as i32, + scales_1[2 * j] as i32, + scales_0[2 * j] as i32, + scales_1[2 * j] as i32, + ]; + let rawptr = &scarr as *const i32; + let sc: int32x4_t = vld1q_s32(rawptr); + sumi1 = vmlaq_s32(sumi1, p1, sc); + + let yv0 = vld1q_s8_x2(q8_0); + let yv0_0 = yv0.0; + let yv0_1 = yv0.1; + q8_0 = q8_0.add(32); + let yv1 = vld1q_s8_x2(q8_1); + let yv1_0 = yv1.0; + let yv1_1 = yv1.1; + q8_1 = q8_1.add(32); + + let xv0_0 = vreinterpretq_s8_u8(vshrq_n_u8(xv0_0_original, 4)); + let xv0_1 = vreinterpretq_s8_u8(vshrq_n_u8(xv0_1_original, 4)); + let xv1_0 = vreinterpretq_s8_u8(vshrq_n_u8(xv1_0_original, 4)); + let xv1_1 = vreinterpretq_s8_u8(vshrq_n_u8(xv1_1_original, 4)); + + let i8mm = i8mm_params::new(xv0_0, xv0_1, xv1_0, xv1_1, yv0_0, yv0_1, yv1_0, yv1_1); + let p2 = i8mm.calculate(vdupq_n_s32(0)); + 
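+                    // The accumulator lanes follow the same [x0·y0, x1·y0, x0·y1, x1·y1]
+                    // order as `factor_elems` below, so each i8mm tile is scaled with the
+                    // two x-row sub-block scales repeated across both y rows. This second
+                    // tile covers the high nibbles and therefore uses the odd-indexed
+                    // (2 * j + 1) scales.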
let sc_arr = [ + scales_0[2 * j + 1] as i32, + scales_1[2 * j + 1] as i32, + scales_0[2 * j + 1] as i32, + scales_1[2 * j + 1] as i32, + ]; + let rawptr = &sc_arr as *const i32; + let sc: int32x4_t = vld1q_s32(rawptr); + sumi2 = vmlaq_s32(sumi2, p2, sc); + } + let factor_elems: [f32; 4] = [d_00, d_01, d_10, d_11]; + let rawptr = &factor_elems as *const f32; + let factor: float32x4_t = vld1q_f32(rawptr); + + let loop_sum_f32 = vcvtq_f32_s32(vaddq_s32(sumi1, sumi2)); + sumfv = vmlaq_f32(sumfv, loop_sum_f32, factor); + } + // extract elements of the vector register + let f0 = vgetq_lane_f32(sumfv, 0); + let f1 = vgetq_lane_f32(sumfv, 1); + let f2 = vgetq_lane_f32(sumfv, 2); + let f3 = vgetq_lane_f32(sumfv, 3); + let res: [f32; 4] = [f0, f1, f2, f3]; + Ok(res) + } +} + #[inline(always)] pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result { if n % QK_K != 0 { @@ -407,6 +1312,7 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res unsafe { let m3b = vdupq_n_u8(0x3); + let mzero = vdupq_n_s32(0); let m0 = vdupq_n_u8(1); let m1 = vshlq_n_u8(m0, 1); let m2 = vshlq_n_u8(m0, 2); @@ -464,10 +1370,10 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res vreinterpretq_s8_u8(q3h_3), ); - let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0); - let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1); - let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2); - let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3); + let p0 = vdotq_s32_local(mzero, q3bytes_0, q8bytes_1.0); + let p1 = vdotq_s32_local(mzero, q3bytes_1, q8bytes_1.1); + let p2 = vdotq_s32_local(mzero, q3bytes_2, q8bytes_1.2); + let p3 = vdotq_s32_local(mzero, q3bytes_3, q8bytes_1.3); isum += vaddvq_s32(p0) * *scale as i32 + vaddvq_s32(p1) * *scale.add(1) as i32 + vaddvq_s32(p2) * *scale.add(2) as i32 @@ -496,10 +1402,10 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res vreinterpretq_s8_u8(q3h_3), ); - let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0); - let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1); - let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2); - let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3); + let p0 = vdotq_s32_local(mzero, q3bytes_0, q8bytes_2.0); + let p1 = vdotq_s32_local(mzero, q3bytes_1, q8bytes_2.1); + let p2 = vdotq_s32_local(mzero, q3bytes_2, q8bytes_2.2); + let p3 = vdotq_s32_local(mzero, q3bytes_3, q8bytes_2.3); isum += vaddvq_s32(p0) * *scale as i32 + vaddvq_s32(p1) * *scale.add(1) as i32 + vaddvq_s32(p2) * *scale.add(2) as i32 @@ -517,6 +1423,293 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res Ok(sumf) } +// NOTE: disabled for now, see BlockQ3K::SUPPORTS_I8MM +#[inline(always)] +#[cfg(feature = "arm-nightly-feat")] +pub(crate) fn i8mm_q3k_q8k( + n: usize, + xs_0: &[BlockQ3K], + xs_1: &[BlockQ3K], + ys_0: &[BlockQ8K], + ys_1: &[BlockQ8K], +) -> Result<[f32; 4]> { + if n % QK_K != 0 { + crate::bail!("i8mm_q3k_q8k: {n} is not divisible by {QK_K}") + } + unsafe { + let mut sumfv = vdupq_n_f32(0.0); + let mut utmp_0 = [0u32; 4]; + let mut utmp_1 = [0u32; 4]; + let mut aux_0 = [0u32; 3]; + let mut aux_1 = [0u32; 3]; + const KMASK1: u32 = 0x03030303; + const KMASK2: u32 = 0x0f0f0f0f; + + let m3b = vdupq_n_u8(0x3); + let m0 = vdupq_n_u8(1); + let m1 = vshlq_n_u8(m0, 1); + let m2 = vshlq_n_u8(m0, 2); + let m3 = vshlq_n_u8(m0, 3); + + for (x0, x1, y0, y1) in izip!(xs_0, xs_1, ys_0, ys_1) { + let d_00: f32 = x0.d.to_f32() * y0.d; + let d_01: f32 = x1.d.to_f32() * y0.d; + let d_10: f32 = x0.d.to_f32() * y1.d; + let d_11: f32 = x1.d.to_f32() * 
y1.d; + + let mut q3_0 = x0.qs.as_ptr(); + let mut q3_1 = x1.qs.as_ptr(); + + let qh_0 = x0.hmask.as_ptr(); + let qh_1 = x1.hmask.as_ptr(); + + let mut q8_0 = y0.qs.as_ptr(); + let mut q8_1 = y1.qs.as_ptr(); + + let mut qhbits_0 = vld1q_u8_x2(qh_0); + let mut qhbits_1 = vld1q_u8_x2(qh_1); + + let mut isum = vdupq_n_s32(0); + + // Set up scales + LittleEndian::read_u32_into(&x0.scales, &mut aux_0); + LittleEndian::read_u32_into(&x1.scales, &mut aux_1); + + utmp_0[3] = ((aux_0[1] >> 4) & KMASK2) | (((aux_0[2] >> 6) & KMASK1) << 4); + utmp_0[2] = ((aux_0[0] >> 4) & KMASK2) | (((aux_0[2] >> 4) & KMASK1) << 4); + utmp_0[1] = (aux_0[1] & KMASK2) | (((aux_0[2] >> 2) & KMASK1) << 4); + utmp_0[0] = (aux_0[0] & KMASK2) | ((aux_0[2] & KMASK1) << 4); + + utmp_1[3] = ((aux_1[1] >> 4) & KMASK2) | (((aux_1[2] >> 6) & KMASK1) << 4); + utmp_1[2] = ((aux_1[0] >> 4) & KMASK2) | (((aux_1[2] >> 4) & KMASK1) << 4); + utmp_1[1] = (aux_1[1] & KMASK2) | (((aux_1[2] >> 2) & KMASK1) << 4); + utmp_1[0] = (aux_1[0] & KMASK2) | ((aux_1[2] & KMASK1) << 4); + + let mut scale_0 = utmp_0.as_mut_ptr() as *mut i8; + for j in 0..16 { + *scale_0.add(j) -= 32i8 + } + let mut scale_1 = utmp_1.as_mut_ptr() as *mut i8; + for j in 0..16 { + *scale_1.add(j) -= 32i8 + } + for j in 0..QK_K / 128 { + let q3bits_0 = vld1q_u8_x2(q3_0); + let q3bits_1 = vld1q_u8_x2(q3_1); + q3_0 = q3_0.add(32); + q3_1 = q3_1.add(32); + + // "y0" + let q8bytes0_1 = vld1q_s8_x4(q8_0); + q8_0 = q8_0.add(64); + let q8bytes0_2 = vld1q_s8_x4(q8_0); + q8_0 = q8_0.add(64); + + // "y1" + let q8bytes1_1 = vld1q_s8_x4(q8_1); + q8_1 = q8_1.add(64); + let q8bytes1_2 = vld1q_s8_x4(q8_1); + q8_1 = q8_1.add(64); + + // "x0" + let q3h_0_0 = vshlq_n_u8(vbicq_u8(m0, qhbits_0.0), 2); + let q3h_0_1 = vshlq_n_u8(vbicq_u8(m0, qhbits_0.1), 2); + let q3h_0_2 = vshlq_n_u8(vbicq_u8(m1, qhbits_0.0), 1); + let q3h_0_3 = vshlq_n_u8(vbicq_u8(m1, qhbits_0.1), 1); + + // "x1" + let q3h_1_0 = vshlq_n_u8(vbicq_u8(m0, qhbits_1.0), 2); + let q3h_1_1 = vshlq_n_u8(vbicq_u8(m0, qhbits_1.1), 2); + let q3h_1_2 = vshlq_n_u8(vbicq_u8(m1, qhbits_1.0), 1); + let q3h_1_3 = vshlq_n_u8(vbicq_u8(m1, qhbits_1.1), 1); + + // "x0" + let q3bytes_0_0 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(q3bits_0.0, m3b)), + vreinterpretq_s8_u8(q3h_0_0), + ); + let q3bytes_0_1 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(q3bits_0.1, m3b)), + vreinterpretq_s8_u8(q3h_0_1), + ); + let q3bytes_0_2 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits_0.0, 2), m3b)), + vreinterpretq_s8_u8(q3h_0_2), + ); + let q3bytes_0_3 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits_0.1, 2), m3b)), + vreinterpretq_s8_u8(q3h_0_3), + ); + // "x1" + let q3bytes_1_0 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(q3bits_1.0, m3b)), + vreinterpretq_s8_u8(q3h_1_0), + ); + let q3bytes_1_1 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(q3bits_1.1, m3b)), + vreinterpretq_s8_u8(q3h_1_1), + ); + let q3bytes_1_2 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits_1.0, 2), m3b)), + vreinterpretq_s8_u8(q3h_1_2), + ); + let q3bytes_1_3 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits_1.1, 2), m3b)), + vreinterpretq_s8_u8(q3h_1_3), + ); + + /* 4 x0s, 4 x1s + * 4 y0s, 4 y1s + * 1 step of i8mm needs 2 of each + * -> 2 sets of i8mm calcs are needed + */ + let sc = i8mm_x_scales::new(&x_scales { + x00: *scale_0.add(0) as i32, + x01: *scale_0.add(1) as i32, + x10: *scale_1.add(0) as i32, + x11: *scale_1.add(1) as i32, + }); + let i8mm = i8mm_params::new( + q3bytes_0_0, + q3bytes_0_1, + q3bytes_1_0, + q3bytes_1_1, + q8bytes0_1.0, + 
q8bytes0_1.1, + q8bytes1_1.0, + q8bytes1_1.1, + ); + isum = i8mm.calculate_with_scales(isum, sc); + + let sc = i8mm_x_scales::new(&x_scales { + x00: *scale_0.add(2) as i32, + x01: *scale_0.add(3) as i32, + x10: *scale_1.add(2) as i32, + x11: *scale_1.add(3) as i32, + }); + let i8mm = i8mm_params::new( + q3bytes_0_2, + q3bytes_0_3, + q3bytes_1_2, + q3bytes_1_3, + q8bytes0_1.2, + q8bytes0_1.3, + q8bytes1_1.2, + q8bytes1_1.3, + ); + isum = i8mm.calculate_with_scales(isum, sc); + + scale_0 = scale_0.add(4); + scale_1 = scale_1.add(4); + + let q3h_0_0 = vbicq_u8(m2, qhbits_0.0); + let q3h_0_1 = vbicq_u8(m2, qhbits_0.1); + + let q3h_0_3 = vshrq_n_u8(vbicq_u8(m3, qhbits_0.1), 1); + + let q3h_1_0 = vbicq_u8(m2, qhbits_1.0); + let q3h_1_1 = vbicq_u8(m2, qhbits_1.1); + let q3h_1_2 = vshrq_n_u8(vbicq_u8(m3, qhbits_1.0), 1); + let q3h_1_3 = vshrq_n_u8(vbicq_u8(m3, qhbits_1.1), 1); + + let q3bytes_0_0 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits_0.0, 4), m3b)), + vreinterpretq_s8_u8(q3h_0_0), + ); + let q3bytes_0_1 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits_0.1, 4), m3b)), + vreinterpretq_s8_u8(q3h_0_1), + ); + let q3bytes_0_2 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits_0.0, 6), m3b)), + vreinterpretq_s8_u8(q3h_0_2), + ); + let q3bytes_0_3 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits_0.1, 6), m3b)), + vreinterpretq_s8_u8(q3h_0_3), + ); + + let q3bytes_1_0 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits_1.0, 4), m3b)), + vreinterpretq_s8_u8(q3h_1_0), + ); + let q3bytes_1_1 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits_1.1, 4), m3b)), + vreinterpretq_s8_u8(q3h_1_1), + ); + let q3bytes_1_2 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits_1.0, 6), m3b)), + vreinterpretq_s8_u8(q3h_1_2), + ); + let q3bytes_1_3 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits_1.1, 6), m3b)), + vreinterpretq_s8_u8(q3h_1_3), + ); + + // Same as above + let sc = i8mm_x_scales::new(&x_scales { + x00: *scale_0.add(0) as i32, + x01: *scale_0.add(1) as i32, + x10: *scale_1.add(0) as i32, + x11: *scale_1.add(1) as i32, + }); + let i8mm = i8mm_params::new( + q3bytes_0_0, + q3bytes_0_1, + q3bytes_1_0, + q3bytes_1_1, + q8bytes0_2.0, + q8bytes0_2.1, + q8bytes1_2.0, + q8bytes1_2.1, + ); + isum = i8mm.calculate_with_scales(isum, sc); + + let sc = i8mm_x_scales::new(&x_scales { + x00: *scale_0.add(2) as i32, + x01: *scale_0.add(3) as i32, + x10: *scale_1.add(2) as i32, + x11: *scale_1.add(3) as i32, + }); + let i8mm = i8mm_params::new( + q3bytes_0_2, + q3bytes_0_3, + q3bytes_1_2, + q3bytes_1_3, + q8bytes0_2.2, + q8bytes0_2.3, + q8bytes1_2.2, + q8bytes1_2.3, + ); + isum = i8mm.calculate_with_scales(isum, sc); + + scale_0 = scale_0.add(4); + scale_1 = scale_1.add(4); + + if j == 0 { + qhbits_0.0 = vshrq_n_u8(qhbits_0.0, 4); + qhbits_0.1 = vshrq_n_u8(qhbits_0.1, 4); + qhbits_1.0 = vshrq_n_u8(qhbits_1.0, 4); + qhbits_1.1 = vshrq_n_u8(qhbits_1.1, 4); + } + } + let factor_elems: [f32; 4] = [d_00, d_01, d_10, d_11]; + let rawptr = &factor_elems as *const f32; + let factor: float32x4_t = vld1q_f32(rawptr); + + let fsum = vcvtq_f32_s32(isum); + sumfv = vmlaq_f32(sumfv, fsum, factor); + } + // extract elements of the vector register + let f0 = vgetq_lane_f32(sumfv, 0); + let f1 = vgetq_lane_f32(sumfv, 1); + let f2 = vgetq_lane_f32(sumfv, 2); + let f3 = vgetq_lane_f32(sumfv, 3); + let res: [f32; 4] = [f0, f1, f2, f3]; + Ok(res) + } +} + #[inline(always)] pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> 
Result { if n % QK_K != 0 { @@ -560,7 +1753,6 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Res let mut isum = 0i32; let mut is = 0usize; - // TODO: dotprod for _j in 0..QK_K / 128 { let q2bits = vld1q_u8_x2(q2); q2 = q2.add(32); @@ -599,6 +1791,206 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Res Ok(sumf) } +#[inline(always)] +#[cfg(feature = "arm-nightly-feat")] +pub(crate) fn i8mm_q2k_q8k( + _n: usize, + xs_0: &[BlockQ2K], + xs_1: &[BlockQ2K], + ys_0: &[BlockQ8K], + ys_1: &[BlockQ8K], +) -> Result<[f32; 4]> { + let mut aux_0 = [0u8; 16]; + let mut aux_1 = [0u8; 16]; + + unsafe { + let mut sumfv = vdupq_n_f32(0.0); + let m3 = vdupq_n_u8(0x3); + let m4 = vdupq_n_u8(0xF); + for (x0, x1, y0, y1) in izip!(xs_0, xs_1, ys_0, ys_1) { + let d_00: f32 = x0.d.to_f32() * y0.d; + let d_01: f32 = x1.d.to_f32() * y0.d; + let d_10: f32 = x0.d.to_f32() * y1.d; + let d_11: f32 = x1.d.to_f32() * y1.d; + + let dmin_00 = -y0.d * x0.dmin.to_f32(); + let dmin_01 = -y0.d * x1.dmin.to_f32(); + let dmin_10 = -y1.d * x0.dmin.to_f32(); + let dmin_11 = -y1.d * x1.dmin.to_f32(); + + let mut q2_0 = x0.qs.as_ptr(); + let mut q2_1 = x1.qs.as_ptr(); + let mut q8_0 = y0.qs.as_ptr(); + let mut q8_1 = y1.qs.as_ptr(); + + let sc_0 = x0.scales.as_ptr(); + let sc_1 = x1.scales.as_ptr(); + + let mins_and_scales_0 = vld1q_u8(sc_0); + let mins_and_scales_1 = vld1q_u8(sc_1); + + let scales_0 = vandq_u8(mins_and_scales_0, m4); + let scales_1 = vandq_u8(mins_and_scales_1, m4); + + vst1q_u8(aux_0.as_mut_ptr(), scales_0); + vst1q_u8(aux_1.as_mut_ptr(), scales_1); + + let mins_0 = vshrq_n_u8(mins_and_scales_0, 4); + let mins_1 = vshrq_n_u8(mins_and_scales_1, 4); + + let q8sums_0 = vld1q_s16_x2(y0.bsums.as_ptr()); + let q8sums_1 = vld1q_s16_x2(y1.bsums.as_ptr()); + + let mins16_0 = int16x8x2_t( + vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins_0))), + vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins_0))), + ); + let mins16_1 = int16x8x2_t( + vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins_1))), + vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins_1))), + ); + // x --> mins16 + // y --> q8sums + let s00l = vaddq_s32( + vmull_s16(vget_low_s16(mins16_0.0), vget_low_s16(q8sums_0.0)), + vmull_s16(vget_high_s16(mins16_0.0), vget_high_s16(q8sums_0.0)), + ); + let s00h = vaddq_s32( + vmull_s16(vget_low_s16(mins16_0.1), vget_low_s16(q8sums_0.1)), + vmull_s16(vget_high_s16(mins16_0.1), vget_high_s16(q8sums_0.1)), + ); + + // 01 -> y0 * x1 + let s01l = vaddq_s32( + vmull_s16(vget_low_s16(mins16_1.0), vget_low_s16(q8sums_0.0)), + vmull_s16(vget_high_s16(mins16_1.0), vget_high_s16(q8sums_0.0)), + ); + let s01h = vaddq_s32( + vmull_s16(vget_low_s16(mins16_1.1), vget_low_s16(q8sums_0.1)), + vmull_s16(vget_high_s16(mins16_1.1), vget_high_s16(q8sums_0.1)), + ); + + // 10 -> y1 * x0 + let s10l = vaddq_s32( + vmull_s16(vget_low_s16(mins16_0.0), vget_low_s16(q8sums_1.0)), + vmull_s16(vget_high_s16(mins16_0.0), vget_high_s16(q8sums_1.0)), + ); + let s10h = vaddq_s32( + vmull_s16(vget_low_s16(mins16_0.1), vget_low_s16(q8sums_1.1)), + vmull_s16(vget_high_s16(mins16_0.1), vget_high_s16(q8sums_1.1)), + ); + + // 11 -> y1 * x1 + let s11l = vaddq_s32( + vmull_s16(vget_low_s16(mins16_1.0), vget_low_s16(q8sums_1.0)), + vmull_s16(vget_high_s16(mins16_1.0), vget_high_s16(q8sums_1.0)), + ); + let s11h = vaddq_s32( + vmull_s16(vget_low_s16(mins16_1.1), vget_low_s16(q8sums_1.1)), + vmull_s16(vget_high_s16(mins16_1.1), vget_high_s16(q8sums_1.1)), + ); + + let sumf_elems: [f32; 4] = [ + dmin_00 * 
vaddvq_s32(vaddq_s32(s00l, s00h)) as f32, + dmin_01 * vaddvq_s32(vaddq_s32(s01l, s01h)) as f32, + dmin_10 * vaddvq_s32(vaddq_s32(s10l, s10h)) as f32, + dmin_11 * vaddvq_s32(vaddq_s32(s11l, s11h)) as f32, + ]; + let rawptr = &sumf_elems as *const f32; + sumfv = vaddq_f32(sumfv, vld1q_f32(rawptr)); + + let mut isum = vdupq_n_s32(0i32); + let mut is = 0usize; + + for _j in 0..QK_K / 128 { + let q2bits_0 = vld1q_u8_x2(q2_0); + q2_0 = q2_0.add(32); + let mut q2bytes_0 = int8x16x2_t( + vreinterpretq_s8_u8(vandq_u8(q2bits_0.0, m3)), + vreinterpretq_s8_u8(vandq_u8(q2bits_0.1, m3)), + ); + let q2bits_1 = vld1q_u8_x2(q2_1); + q2_1 = q2_1.add(32); + let mut q2bytes_1 = int8x16x2_t( + vreinterpretq_s8_u8(vandq_u8(q2bits_1.0, m3)), + vreinterpretq_s8_u8(vandq_u8(q2bits_1.1, m3)), + ); + + let q8bytes_0 = vld1q_s8_x2(q8_0); + q8_0 = q8_0.add(32); + let q8bytes_1 = vld1q_s8_x2(q8_1); + q8_1 = q8_1.add(32); + isum = vaddq_s32( + isum, + i8mm_accum_with_scale( + &aux_0, &aux_1, is, 0, q2bytes_0, q2bytes_1, q8bytes_0, q8bytes_1, + ), + ); + + q2bytes_0.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits_0.0, 2), m3)); + q2bytes_0.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits_0.1, 2), m3)); + q2bytes_1.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits_1.0, 2), m3)); + q2bytes_1.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits_1.1, 2), m3)); + let q8bytes_0 = vld1q_s8_x2(q8_0); + q8_0 = q8_0.add(32); + let q8bytes_1 = vld1q_s8_x2(q8_1); + q8_1 = q8_1.add(32); + isum = vaddq_s32( + isum, + i8mm_accum_with_scale( + &aux_0, &aux_1, is, 2, q2bytes_0, q2bytes_1, q8bytes_0, q8bytes_1, + ), + ); + + q2bytes_0.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits_0.0, 4), m3)); + q2bytes_0.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits_0.1, 4), m3)); + q2bytes_1.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits_1.0, 4), m3)); + q2bytes_1.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits_1.1, 4), m3)); + let q8bytes_0 = vld1q_s8_x2(q8_0); + q8_0 = q8_0.add(32); + let q8bytes_1 = vld1q_s8_x2(q8_1); + q8_1 = q8_1.add(32); + isum = vaddq_s32( + isum, + i8mm_accum_with_scale( + &aux_0, &aux_1, is, 4, q2bytes_0, q2bytes_1, q8bytes_0, q8bytes_1, + ), + ); + + q2bytes_0.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits_0.0, 6), m3)); + q2bytes_0.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits_0.1, 6), m3)); + q2bytes_1.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits_1.0, 6), m3)); + q2bytes_1.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits_1.1, 6), m3)); + let q8bytes_0 = vld1q_s8_x2(q8_0); + q8_0 = q8_0.add(32); + let q8bytes_1 = vld1q_s8_x2(q8_1); + q8_1 = q8_1.add(32); + isum = vaddq_s32( + isum, + i8mm_accum_with_scale( + &aux_0, &aux_1, is, 6, q2bytes_0, q2bytes_1, q8bytes_0, q8bytes_1, + ), + ); + is += 8; + } + + let factor_elems: [f32; 4] = [d_00, d_01, d_10, d_11]; + let rawptr = &factor_elems as *const f32; + let factor: float32x4_t = vld1q_f32(rawptr); + + let fsum = vcvtq_f32_s32(isum); + sumfv = vmlaq_f32(sumfv, fsum, factor); + } + // extract elements of the vector register + let f0 = vgetq_lane_f32(sumfv, 0); + let f1 = vgetq_lane_f32(sumfv, 1); + let f2 = vgetq_lane_f32(sumfv, 2); + let f3 = vgetq_lane_f32(sumfv, 3); + let res: [f32; 4] = [f0, f1, f2, f3]; + Ok(res) + } +} + #[inline(always)] unsafe fn multiply_accum_with_scale( aux: &[u8; 16], @@ -607,7 +1999,314 @@ unsafe fn multiply_accum_with_scale( q2bytes: int8x16x2_t, q8bytes: int8x16x2_t, ) -> i32 { - let p1 = vdotq_s32(q2bytes.0, q8bytes.0); - let p2 = vdotq_s32(q2bytes.1, q8bytes.1); + let mzero = vdupq_n_s32(0); + let p1 = 
vdotq_s32_local(mzero, q2bytes.0, q8bytes.0); + let p2 = vdotq_s32_local(mzero, q2bytes.1, q8bytes.1); vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32 } + +#[inline(always)] +pub(crate) fn vec_dot_iq4_xs_q8k(n: usize, xs: &[BlockIQ4xs], ys: &[BlockQ8K]) -> Result { + if n % QK_K != 0 { + crate::bail!("vec_dot_iq4_xs_q8k: {n} is not divisible by {QK_K}") + } + + unsafe { + let values = vld1q_s8(KVALUES_IQ4NL.as_ptr()); + let m4b = vdupq_n_u8(0x0f); + + let mut q4b = int8x16x4_t(vdupq_n_s8(0), vdupq_n_s8(0), vdupq_n_s8(0), vdupq_n_s8(0)); + + let mut sumf = 0f32; + + let nb = n / QK_K; + for ibl in 0..nb { + let y_block = &ys[ibl]; + let x_block = &xs[ibl]; + let mut q8 = y_block.qs.as_ptr(); + let mut q4 = x_block.qs.as_ptr(); + let mut h = x_block.scales_h; + + let mut sumi1 = 0; + let mut sumi2 = 0; + for ib in 0..(QK_K / 64) { + let q4bits = vld1q_u8_x2(q4); + q4 = q4.add(32); + let q8b = vld1q_s8_x4(q8); + q8 = q8.add(64); + + q4b.0 = vqtbl1q_s8(values, vandq_u8(q4bits.0, m4b)); + q4b.1 = vqtbl1q_s8(values, vshrq_n_u8(q4bits.0, 4)); + q4b.2 = vqtbl1q_s8(values, vandq_u8(q4bits.1, m4b)); + q4b.3 = vqtbl1q_s8(values, vshrq_n_u8(q4bits.1, 4)); + + let prod1 = + vdotq_s32_local(vdotq_s32_local(vdupq_n_s32(0), q4b.0, q8b.0), q4b.1, q8b.1); + let prod2 = + vdotq_s32_local(vdotq_s32_local(vdupq_n_s32(0), q4b.2, q8b.2), q4b.3, q8b.3); + + let ls1 = (x_block.scales_l[ib] & 0xf) as i32 | ((h << 4) & 0x30) as i32 - 32; + let ls2 = (x_block.scales_l[ib] >> 4) as i32 | ((h << 2) & 0x30) as i32 - 32; + h >>= 4; + + sumi1 += vaddvq_s32(prod1) * ls1; + sumi2 += vaddvq_s32(prod2) * ls2; + } + + sumf += xs[ibl].d.to_f32() * ys[ibl].d * (sumi1 + sumi2) as f32; + } + + Ok(sumf) + } +} + +#[inline(always)] +pub(crate) fn vec_dot_iq4_nl_q8k(n: usize, xs: &[BlockIQ4nl], ys: &[BlockQ8_0]) -> Result { + if n % QK4_NL != 0 { + crate::bail!("vec_dot_iq4_nl_q8k: {n} is not divisible by {QK4_NL}") + } + + unsafe { + let values = vld1q_s8(KVALUES_IQ4NL.as_ptr()); + let m4b = vdupq_n_u8(0x0f); + + let mut q4b = int8x16x4_t(vdupq_n_s8(0), vdupq_n_s8(0), vdupq_n_s8(0), vdupq_n_s8(0)); + let mut q8b = int8x16x4_t(vdupq_n_s8(0), vdupq_n_s8(0), vdupq_n_s8(0), vdupq_n_s8(0)); + let mut q4bits = uint8x16x2_t(vdupq_n_u8(0), vdupq_n_u8(0)); + + let mut sumf = 0f32; + + let nb = n / QK4_NL; + for ib in (0..nb - 1).step_by(2) { + q4bits.0 = vld1q_u8(xs[ib].qs.as_ptr()); + q4bits.1 = vld1q_u8(xs[ib + 1].qs.as_ptr()); + q8b.0 = vld1q_s8(ys[ib].qs.as_ptr()); + q8b.1 = vld1q_s8(ys[ib].qs.as_ptr().add(16)); + q8b.2 = vld1q_s8(ys[ib + 1].qs.as_ptr()); + q8b.3 = vld1q_s8(ys[ib + 1].qs.as_ptr().add(16)); + + q4b.0 = vqtbl1q_s8(values, vandq_u8(q4bits.0, m4b)); + q4b.1 = vqtbl1q_s8(values, vshrq_n_u8(q4bits.0, 4)); + q4b.2 = vqtbl1q_s8(values, vandq_u8(q4bits.1, m4b)); + q4b.3 = vqtbl1q_s8(values, vshrq_n_u8(q4bits.1, 4)); + + let prod1 = + vdotq_s32_local(vdotq_s32_local(vdupq_n_s32(0), q4b.0, q8b.0), q4b.1, q8b.1); + let prod2 = + vdotq_s32_local(vdotq_s32_local(vdupq_n_s32(0), q4b.2, q8b.2), q4b.3, q8b.3); + + sumf += xs[ib].d.to_f32() * ys[ib].d.to_f32() * vaddvq_s32(prod1) as f32 + + xs[ib + 1].d.to_f32() * ys[ib + 1].d.to_f32() * vaddvq_s32(prod2) as f32; + } + + Ok(sumf) + } +} + +#[inline(always)] +#[cfg(feature = "arm-nightly-feat")] +unsafe fn i8mm_accum_with_scale( + aux_0: &[u8; 16], + aux_1: &[u8; 16], + is: usize, + index: usize, + q2bytes_0: int8x16x2_t, + q2bytes_1: int8x16x2_t, + q8bytes_0: int8x16x2_t, + q8bytes_1: int8x16x2_t, +) -> int32x4_t { + let mzero = vdupq_n_s32(0); + + 
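+    // sub-block scales for this 32-value chunk: (c00, c01) come from x0 and (c10, c11) from x1,
+    // one scale per 16-value half, matching the 2x2 i8mm tile assembled below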
let c00 = aux_0[is + index] as i32; + let c01 = aux_0[is + index + 1] as i32; + let c10 = aux_1[is + index] as i32; + let c11 = aux_1[is + index + 1] as i32; + + let x00 = q2bytes_0.0; + let x01 = q2bytes_0.1; + let x10 = q2bytes_1.0; + let x11 = q2bytes_1.1; + + let y00 = q8bytes_0.0; + let y01 = q8bytes_0.1; + let y10 = q8bytes_1.0; + let y11 = q8bytes_1.1; + + let x_sc = x_scales { + x00: c00, + x01: c01, + x10: c10, + x11: c11, + }; + let i8mm_sc = i8mm_x_scales::new(&x_sc); + let mm = i8mm_params::new(x00, x01, x10, x11, y00, y01, y10, y11); + mm.calculate_with_scales(mzero, i8mm_sc) +} +#[allow(non_camel_case_types)] +#[cfg(feature = "arm-nightly-feat")] +struct i8mm_params { + x0: int8x16_t, + x1: int8x16_t, + x2: int8x16_t, + x3: int8x16_t, + y0: int8x16_t, + y1: int8x16_t, + y2: int8x16_t, + y3: int8x16_t, +} + +#[allow(non_camel_case_types)] +#[cfg(feature = "arm-nightly-feat")] +/// scales from scalar version +struct x_scales { + x00: i32, + x01: i32, + x10: i32, + x11: i32, +} +#[allow(non_camel_case_types)] +#[cfg(feature = "arm-nightly-feat")] +/// scales reorganized to fit i8mm calculations +struct i8mm_x_scales { + sc0: int32x4_t, + sc1: int32x4_t, +} + +#[cfg(feature = "arm-nightly-feat")] +impl i8mm_x_scales { + #[inline(always)] + unsafe fn new(sc: &x_scales) -> Self { + let v00 = vdupq_n_s32(sc.x00); + let v01 = vdupq_n_s32(sc.x01); + let v10 = vdupq_n_s32(sc.x10); + let v11 = vdupq_n_s32(sc.x11); + + let sc0 = vzip1q_s32(v00, v10); + let sc1 = vzip1q_s32(v01, v11); + + i8mm_x_scales { sc0, sc1 } + } +} + +#[cfg(feature = "arm-nightly-feat")] +impl i8mm_params { + #[inline(always)] + unsafe fn new( + xv0_0: int8x16_t, + xv0_1: int8x16_t, + xv1_0: int8x16_t, + xv1_1: int8x16_t, + yv0_0: int8x16_t, + yv0_1: int8x16_t, + yv1_0: int8x16_t, + yv1_1: int8x16_t, + ) -> Self { + // 1. 16xi8 -> 2xi64 + let xv0_0 = vreinterpretq_s64_s8(xv0_0); + let xv0_1 = vreinterpretq_s64_s8(xv0_1); + let xv1_0 = vreinterpretq_s64_s8(xv1_0); + let xv1_1 = vreinterpretq_s64_s8(xv1_1); + + let yv0_0 = vreinterpretq_s64_s8(yv0_0); + let yv0_1 = vreinterpretq_s64_s8(yv0_1); + let yv1_0 = vreinterpretq_s64_s8(yv1_0); + let yv1_1 = vreinterpretq_s64_s8(yv1_1); + + // 2. ZIP + let x0_0 = vzip1q_s64(xv0_0, xv1_0); + let x0_1 = vzip2q_s64(xv0_0, xv1_0); + let x1_0 = vzip1q_s64(xv0_1, xv1_1); + let x1_1 = vzip2q_s64(xv0_1, xv1_1); + + let y0_0 = vzip1q_s64(yv0_0, yv1_0); + let y0_1 = vzip2q_s64(yv0_0, yv1_0); + let y1_0 = vzip1q_s64(yv0_1, yv1_1); + let y1_1 = vzip2q_s64(yv0_1, yv1_1); + + // 3. 
interpret back + let x0_0 = vreinterpretq_s8_s64(x0_0); + let x0_1 = vreinterpretq_s8_s64(x0_1); + let x1_0 = vreinterpretq_s8_s64(x1_0); + let x1_1 = vreinterpretq_s8_s64(x1_1); + + let y0_0 = vreinterpretq_s8_s64(y0_0); + let y0_1 = vreinterpretq_s8_s64(y0_1); + let y1_0 = vreinterpretq_s8_s64(y1_0); + let y1_1 = vreinterpretq_s8_s64(y1_1); + + i8mm_params { + x0: x0_0, + x1: x0_1, + x2: x1_0, + x3: x1_1, + y0: y0_0, + y1: y0_1, + y2: y1_0, + y3: y1_1, + } + } + + #[inline(always)] + unsafe fn calculate(&self, acc: int32x4_t) -> int32x4_t { + if is_aarch64_feature_detected!("i8mm") { + self.impl_calc(acc) + } else { + // never takes this branch, but the check is needed + // for inlining the vmmlaq intrinsics + // see: + // https://community.arm.com/arm-community-blogs/b/architectures-and-processors-blog/posts/rust-neon-intrinsics + unreachable!(); + } + } + unsafe fn impl_calc(&self, acc: int32x4_t) -> int32x4_t { + let mut a = acc; + a = vmmlaq_s32(a, self.y0, self.x0); + a = vmmlaq_s32(a, self.y1, self.x1); + a = vmmlaq_s32(a, self.y2, self.x2); + vmmlaq_s32(a, self.y3, self.x3) + } + + unsafe fn calculate_with_scales(&self, acc: int32x4_t, scales: i8mm_x_scales) -> int32x4_t { + if is_aarch64_feature_detected!("i8mm") { + self.impl_calc_scales(acc, scales) + } else { + // never takes this branch, but the check is needed + // for inlining the vmmlaq intrinsics + // see: + // https://community.arm.com/arm-community-blogs/b/architectures-and-processors-blog/posts/rust-neon-intrinsics + unreachable!(); + } + } + #[inline(always)] + unsafe fn impl_calc_scales(&self, acc: int32x4_t, scales: i8mm_x_scales) -> int32x4_t { + let mzero = vdupq_n_s32(0); + let a = vmulq_s32(vmmlaq_s32(mzero, self.y0, self.x0), scales.sc0); + let b = vmulq_s32(vmmlaq_s32(mzero, self.y1, self.x1), scales.sc0); + let c = vmulq_s32(vmmlaq_s32(mzero, self.y2, self.x2), scales.sc1); + let d = vmulq_s32(vmmlaq_s32(mzero, self.y3, self.x3), scales.sc1); + + let mut sum; + sum = vaddq_s32(acc, a); + sum = vaddq_s32(sum, b); + sum = vaddq_s32(sum, c); + sum = vaddq_s32(sum, d); + sum + } +} + +#[inline(always)] +#[cfg(feature = "arm-nightly-feat")] +unsafe fn vdotq_s32_local(vz: int32x4_t, a: int8x16_t, b: int8x16_t) -> int32x4_t { + if is_aarch64_feature_detected!("dotprod") { + vdotq_s32(vz, a, b) + } else { + unreachable!(); + } +} +#[inline(always)] +#[cfg(not(feature = "arm-nightly-feat"))] +unsafe fn vdotq_s32_local(vz: int32x4_t, a: int8x16_t, b: int8x16_t) -> int32x4_t { + let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b)); + let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b)); + vaddq_s32(vz, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))) +} diff --git a/candle-core/src/quantized/quants.rs b/candle-core/src/quantized/quants.rs new file mode 100644 index 0000000000..77e95632fe --- /dev/null +++ b/candle-core/src/quantized/quants.rs @@ -0,0 +1,357 @@ +use super::GgmlDType; +use crate::Result; +use half::{bf16, f16}; +use rayon::prelude::*; + +pub trait GgmlType: Sized + Clone + Send + Sync { + const DTYPE: GgmlDType; + const BLCK_SIZE: usize; + type VecDotType: GgmlType; + const SUPPORTS_I8MM: bool; + + // This is only safe for types that include immediate values such as float/int/... 
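+    // (i.e. #[repr(C)] blocks whose fields are only integers, floats or arrays of those, so the
+    //  all-zero bit pattern is a valid value; types holding references or heap pointers would
+    //  make this undefined behaviour)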
+ fn zeros() -> Self { + unsafe { std::mem::MaybeUninit::zeroed().assume_init() } + } + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()>; + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()>; + fn from_float_imatrix( + _xs: &[f32], + _ys: &mut [Self], + _imatrix_weights: &[f32], + _n_per_row: usize, + ) -> Result<()> { + crate::bail!( + "`from_float_imatrix` is unimplemented for {:?}", + Self::DTYPE + ); + } + + /// Dot product used as a building block for quantized mat-mul. + /// n is the number of elements to be considered. + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result; + + /// Generic implementation of the dot product without simd optimizations. + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result; + + /// Multiply 2 rows by 2 columns and return a 2x2 matrix + /// based on aarch64 NEON i8mm instructions + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]>; +} + +impl GgmlType for f32 { + const DTYPE: GgmlDType = GgmlDType::F32; + const BLCK_SIZE: usize = 1; + type VecDotType = f32; + const SUPPORTS_I8MM: bool = false; + + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + if xs.len() < n { + crate::bail!("size mismatch {} < {n}", xs.len()) + } + if ys.len() < n { + crate::bail!("size mismatch {} < {n}", ys.len()) + } + let mut res = 0f32; + unsafe { crate::cpu::vec_dot_f32(xs.as_ptr(), ys.as_ptr(), &mut res, n) }; + Ok(res) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + if xs.len() != ys.len() { + crate::bail!("size mismatch {} {}", xs.len(), ys.len()); + } + ys.copy_from_slice(xs); + Ok(()) + } + + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + if xs.len() != ys.len() { + crate::bail!("size mismatch {} {}", xs.len(), ys.len()); + } + ys.copy_from_slice(xs); + Ok(()) + } + + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + crate::bail!("Unsupported block type for i8mm"); + } +} + +impl GgmlType for f16 { + const DTYPE: GgmlDType = GgmlDType::F16; + const BLCK_SIZE: usize = 1; + type VecDotType = f16; + const SUPPORTS_I8MM: bool = false; + + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + if xs.len() < n { + crate::bail!("size mismatch {} < {n}", xs.len()) + } + if ys.len() < n { + crate::bail!("size mismatch {} < {n}", ys.len()) + } + let mut res = 0f32; + unsafe { crate::cpu::vec_dot_f16(xs.as_ptr(), ys.as_ptr(), &mut res, n) }; + Ok(res) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + if xs.len() != ys.len() { + crate::bail!("size mismatch {} {}", xs.len(), ys.len()); + } + // TODO: vectorize + for (x, y) in xs.iter().zip(ys.iter_mut()) { + *y = f16::from_f32(*x) + } + Ok(()) + } + + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + if xs.len() != ys.len() { + crate::bail!("size mismatch {} {}", xs.len(), ys.len()); + } + // TODO: vectorize + for (x, y) in xs.iter().zip(ys.iter_mut()) { + *y = x.to_f32() + } + Ok(()) + } + + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + 
n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + crate::bail!("Unsupported block type for i8mm"); + } +} + +impl GgmlType for bf16 { + const DTYPE: GgmlDType = GgmlDType::BF16; + const BLCK_SIZE: usize = 1; + type VecDotType = bf16; + const SUPPORTS_I8MM: bool = false; + + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + if xs.len() < n { + crate::bail!("size mismatch {} < {n}", xs.len()) + } + if ys.len() < n { + crate::bail!("size mismatch {} < {n}", ys.len()) + } + let mut res = 0f32; + unsafe { crate::cpu::vec_dot_bf16(xs.as_ptr(), ys.as_ptr(), &mut res, n) }; + Ok(res) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + if xs.len() != ys.len() { + crate::bail!("size mismatch {} {}", xs.len(), ys.len()); + } + // TODO: vectorize + for (x, y) in xs.iter().zip(ys.iter_mut()) { + *y = bf16::from_f32(*x) + } + Ok(()) + } + + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + if xs.len() != ys.len() { + crate::bail!("size mismatch {} {}", xs.len(), ys.len()); + } + // TODO: vectorize + for (x, y) in xs.iter().zip(ys.iter_mut()) { + *y = x.to_f32() + } + Ok(()) + } + + #[allow(unused)] + #[cfg(feature = "arm-nightly-feat")] + fn matmul_i8mm( + n: usize, + xs_0: &[Self], + xs_1: &[Self], + ys_0: &[Self::VecDotType], + ys_1: &[Self::VecDotType], + ) -> Result<[f32; 4]> { + crate::bail!("Unsupported block type for i8mm"); + } +} + +fn matmul_ref( + mkn: (usize, usize, usize), + lhs: &[f32], + rhs_t: &[T], + dst: &mut [f32], +) -> Result<()> { + let (m, k, n) = mkn; + if m * k != lhs.len() { + crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len()); + } + + let k_in_lhs_blocks = k.div_ceil(T::BLCK_SIZE); + let k_in_rhs_blocks = k.div_ceil(T::VecDotType::BLCK_SIZE); + // TODO: Do not make this copy if the DotType is f32. + // TODO: Pre-allocate this. + let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks]; + for row_idx in 0..m { + let lhs_b = &mut lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; + let lhs = &lhs[row_idx * k..(row_idx + 1) * k]; + T::VecDotType::from_float(lhs, lhs_b)? 
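+        // each f32 lhs row is quantized once into VecDotType blocks and then reused against
+        // every rhs column in the loop below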
+ } + let lhs_b = lhs_b.as_slice(); + + for row_idx in 0..m { + let lhs_row = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; + let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n]; + + let result: Result> = dst_row + .into_par_iter() + .enumerate() + .with_min_len(128) + .with_max_len(512) + .map(|(col_idx, dst)| { + let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks]; + T::vec_dot(k, rhs_col, lhs_row).map(|value| *dst = value) + }) + .collect(); + + result?; + } + Ok(()) +} + +// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L10605 +#[cfg(not(all(feature = "arm-nightly-feat", target_feature = "i8mm")))] +pub fn matmul( + mkn: (usize, usize, usize), + lhs: &[f32], + rhs_t: &[T], + dst: &mut [f32], +) -> Result<()> { + matmul_ref(mkn, lhs, rhs_t, dst) +} + +#[cfg(all(feature = "arm-nightly-feat", target_feature = "i8mm"))] +pub fn matmul( + mkn: (usize, usize, usize), + lhs: &[f32], + rhs_t: &[T], + dst: &mut [f32], +) -> Result<()> { + if !T::SUPPORTS_I8MM { + return matmul_ref(mkn, lhs, rhs_t, dst); + } + + let (m, k, n) = mkn; + if m * k != lhs.len() { + crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len()); + } + + let k_in_lhs_blocks = (k + T::BLCK_SIZE - 1) / T::BLCK_SIZE; + let k_in_rhs_blocks = (k + T::VecDotType::BLCK_SIZE - 1) / T::VecDotType::BLCK_SIZE; + // TODO: Do not make this copy if the DotType is f32. + // TODO: Pre-allocate this. + let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks]; + for row_idx in 0..m { + let lhs_b = &mut lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; + let lhs = &lhs[row_idx * k..(row_idx + 1) * k]; + T::VecDotType::from_float(lhs, lhs_b)? + } + let lhs_b = lhs_b.as_slice(); + + let m_even = m % 2 == 0; + let m_limit = if m_even { m } else { m - 1 }; + let n_even = n % 2 == 0; + let n_limit = if n_even { n } else { n - 1 }; + + for row_idx in (0..m_limit).step_by(2) { + let lhs_row_0 = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; + let lhs_row_1 = &lhs_b[(row_idx + 1) * k_in_lhs_blocks..(row_idx + 2) * k_in_lhs_blocks]; + + let dst_2_rows = &mut dst[row_idx * n..(row_idx + 2) * n]; + let (dst_row_0, dst_row_1) = dst_2_rows.split_at_mut(dst_2_rows.len() / 2); + + let dst_row_0_n = &mut dst_row_0[0..n_limit]; + let dst_row_1_n = &mut dst_row_1[0..n_limit]; + + let _result: Vec<_> = dst_row_0_n + .par_chunks_mut(2) + .zip(dst_row_1_n.par_chunks_mut(2)) + .enumerate() + .with_min_len(128) + .with_max_len(512) + .map(|(half_of_col_idx, (dst_0, dst_1))| { + let col_idx = half_of_col_idx * 2; // each step has 2 columns + let rhs_col_0 = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks]; + let rhs_col_1 = + &rhs_t[(col_idx + 1) * k_in_rhs_blocks..(col_idx + 2) * k_in_rhs_blocks]; + + T::matmul_i8mm(k, rhs_col_0, rhs_col_1, lhs_row_0, lhs_row_1).map(|mm| { + dst_0[0] = mm[0]; + dst_0[1] = mm[1]; + dst_1[0] = mm[2]; + dst_1[1] = mm[3]; + }) + }) + .collect(); + if !n_even { + let col_idx = n - 1; + let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks]; + dst_row_0[col_idx] = T::vec_dot(k, rhs_col, lhs_row_0).unwrap(); + dst_row_1[col_idx] = T::vec_dot(k, rhs_col, lhs_row_1).unwrap(); + } + } + if !m_even { + let row_idx = m - 1; + let lhs_row = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; + + let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n]; + let result: Result> = dst_row + .into_par_iter() + .enumerate() + .with_min_len(128) + 
.with_max_len(512) + .map(|(col_idx, dst)| { + let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks]; + T::vec_dot(k, rhs_col, lhs_row).map(|value| *dst = value) + }) + .collect(); + + result?; + } + Ok(()) +} diff --git a/candle-core/src/quantized/simd128.rs b/candle-core/src/quantized/simd128.rs index 1c8c0f2068..66ef2bf5b4 100644 --- a/candle-core/src/quantized/simd128.rs +++ b/candle-core/src/quantized/simd128.rs @@ -1,4 +1,6 @@ -use super::k_quants::{BlockQ2K, BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K}; +use super::k_quants::{ + BlockF8Q8, BlockQ2K, BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K, +}; use crate::Result; use byteorder::{ByteOrder, LittleEndian}; use half::f16; @@ -91,6 +93,46 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> } } +#[inline(always)] +pub(crate) fn vec_dot_f8q8_q8_0(n: usize, xs: &[BlockF8Q8], ys: &[BlockQ8_0]) -> Result { + let qk = QK8_0; + if n % QK8_0 != 0 { + crate::bail!("vec_dot_f8q8_q8_0: {n} is not divisible by {qk}") + } + unsafe { + let mut acc = f32x4_splat(0.0f32); + for (x, y) in xs.iter().zip(ys.iter()) { + let x1 = i16x8_load_extend_i8x8(x.qs.as_ptr()); + let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr()); + let sum_xy = i32x4_dot_i16x8(x1, y1); + + let x2 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(8)); + let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8)); + let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2)); + + let x3 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(16)); + let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16)); + let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3)); + + let x4 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(24)); + let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24)); + let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4)); + + let sum_xy = f32x4_convert_i32x4(sum_xy); + + // f32x4_relaxed_madd is nightly only. 
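+            // so the per-block scale x.dq_d() * y.d is applied with a separate f32x4_mul
+            // followed by f32x4_add instead.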
+ let d = f32x4_splat(x.dq_d() * f16::to_f32(y.d)); + let scaled = f32x4_mul(sum_xy, d); + acc = f32x4_add(acc, scaled) + } + let res = f32x4_extract_lane::<0>(acc) + + f32x4_extract_lane::<1>(acc) + + f32x4_extract_lane::<2>(acc) + + f32x4_extract_lane::<3>(acc); + Ok(res) + } +} + #[inline(always)] pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result { if n % QK_K != 0 { diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index ab3f15bcf8..66bbd3f8c2 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -1,11 +1,11 @@ use candle_core::{ bail, - quantized::{self, GgmlDType}, + quantized::{self, quants, GgmlDType}, test_device, test_utils::to_vec2_round, DType, Device, IndexOp, Module, Result, Tensor, Var, }; -use quantized::{k_quants, GgmlType}; +use quantized::{iq_quants, k_quants, GgmlType}; use rand::prelude::*; const GGML_TEST_SIZE: usize = 32 * 128; @@ -90,7 +90,7 @@ fn quantized_matmul(device: &Device) -> Result<()> { let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8]; let rhs = (0..(k * n)).map(|v| v as f32).collect::>(); k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?; - k_quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?; + quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?; assert_eq!( dst.iter().map(|x| x.round()).collect::>(), &[ @@ -155,7 +155,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> { .collect::>(); let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?; k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?; - k_quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?; + quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?; assert_eq!( dst.iter().map(|x| x.round()).collect::>(), &[ @@ -386,6 +386,26 @@ fn quantize_q5_1(device: &Device) -> Result<()> { Ok(()) } +fn quantize_f8q8(device: &Device) -> Result<()> { + let dtype = GgmlDType::F8Q8; + let src = get_test_vector2(0.5, 1024, device)?; + let quant = quantized::QTensor::quantize(&src, dtype)?; + let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); + + let src = src.to_vec1::()?; + let dst = dst.to_vec1::()?; + compare_with_error(dst.as_slice(), src.as_slice(), 0.01); + + Ok(()) +} + fn get_test_vector2(bound: f32, size: usize, device: &Device) -> Result { assert!( size % crate::quantized::k_quants::QK_K == 0, @@ -949,6 +969,209 @@ fn quantize_q8k(device: &Device) -> Result<()> { Ok(()) } +fn quantize_iq4_xs(device: &Device) -> Result<()> { + let dtype = GgmlDType::Iq4Xs; + let src = get_test_vector2(0.5, 1024, device)?; + let quant = quantized::QTensor::quantize(&src, dtype)?; + let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); + + let src = src.to_vec1::()?; + let dst = dst.to_vec1::()?; + compare_with_error(dst.as_slice(), src.as_slice(), 0.025); + + let src_big = get_test_vector2(128.0, 1024, device)?; + let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; + let dst_big = quant_big.dequantize(device)?; + let dst_big_f16 = quant_big.dequantize_f16(device)?; + let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? 
+ .to_vec0::()?; + assert_eq!(diff, 0.); + + let src_big = src_big.to_vec1::()?; + let dst_big = dst_big.to_vec1::()?; + compare_with_error(dst_big.as_slice(), src_big.as_slice(), 5.9); + + ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + + Ok(()) +} + +fn quantize_iq4_nl(device: &Device) -> Result<()> { + let dtype = GgmlDType::Iq4Nl; + let src = get_test_vector2(0.5, 1024, device)?; + let quant = quantized::QTensor::quantize(&src, dtype)?; + let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); + + let src = src.to_vec1::()?; + let dst = dst.to_vec1::()?; + compare_with_error(dst.as_slice(), src.as_slice(), 0.025); + + let src_big = get_test_vector2(128.0, 1024, device)?; + let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; + let dst_big = quant_big.dequantize(device)?; + let dst_big_f16 = quant_big.dequantize_f16(device)?; + let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); + + let src_big = src_big.to_vec1::()?; + let dst_big = dst_big.to_vec1::()?; + compare_with_error(dst_big.as_slice(), src_big.as_slice(), 5.9); + + ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + + Ok(()) +} + +fn quantize_iq3_xxs(device: &Device) -> Result<()> { + let dtype = GgmlDType::Iq3Xxs; + let src = get_test_vector2(0.5, 1024, device)?; + let quant = quantized::QTensor::quantize(&src, dtype)?; + let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + dbg!(&src.mean_all()?); + dbg!(&dst.mean_all()?); + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); + + let src = src.to_vec1::()?; + let dst = dst.to_vec1::()?; + compare_with_error(dst.as_slice(), src.as_slice(), 0.025); + + let src_big = get_test_vector2(128.0, 1024, device)?; + let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; + let dst_big = quant_big.dequantize(device)?; + let dst_big_f16 = quant_big.dequantize_f16(device)?; + let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); + + let src_big = src_big.to_vec1::()?; + let dst_big = dst_big.to_vec1::()?; + compare_with_error(dst_big.as_slice(), src_big.as_slice(), 5.9); + + ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + + Ok(()) +} + +#[test] +fn imatrix_quantize_iq4_xs() -> Result<()> { + // let data = + // quantized::imatrix_file::load_imatrix("../Llama-3.2-3B-Instruct.imatrix").unwrap(); + // for (name, weights) in &data { + // println!("{name}, {} elems", weights.len()); + // } + // dbg!(&data["blk.0.attn_q.weight"].len()); + + let cpu = &Device::Cpu; + + let mut row_counts = 0f64; + let mut ncall = 0f64; + let mut values = Tensor::zeros((768,), DType::F32, cpu)?; + + for _ in 0..10 { + let lhs = Var::from_tensor(&Tensor::randn(0f32, 1f32, (1024, 512), cpu)?)?; + let rhs = Var::from_tensor(&Tensor::randn(0f32, 1f32, (512, 768), cpu)?)?; + let res = lhs.matmul(&rhs)?; + + // https://github.com/ggerganov/llama.cpp/blob/678d7994f4da0af3d29046be99950ac999ee9762/examples/imatrix/imatrix.cpp#L180-L186 + values = (values + res.sqr()?.sum(0)?)?; + row_counts += res.dim(0)? 
as f64; + ncall += 1.; + } + + // https://github.com/ggerganov/llama.cpp/blob/678d7994f4da0af3d29046be99950ac999ee9762/examples/imatrix/imatrix.cpp#L275 + let out = ((values / row_counts)? * ncall)?; + let imatrix = out.to_vec1::()?; + + let xs = Tensor::randn(0f32, 1f32, (1024, 768), cpu)?; + + let quant1 = quantized::QTensor::quantize(&xs, GgmlDType::Iq4Xs)?; + let quant2 = quantized::QTensor::quantize_imatrix(&xs, &imatrix, GgmlDType::Iq4Xs)?; + + let dequant1 = quant1.dequantize(cpu)?; + let dequant2 = quant2.dequantize(cpu)?; + + let err1 = (dequant1 - &xs)?.abs()?.mean_all()?.to_scalar::()?; + let err2 = (dequant2 - &xs)?.abs()?.mean_all()?.to_scalar::()?; + assert!(err2 < err1, "err2 {err2} > err1 {err1}"); + + Ok(()) +} + +#[test] +fn imatrix_quantize_iq4_nl() -> Result<()> { + // let data = + // quantized::imatrix_file::load_imatrix("../Llama-3.2-3B-Instruct.imatrix").unwrap(); + // for (name, weights) in &data { + // println!("{name}, {} elems", weights.len()); + // } + // dbg!(&data["blk.0.attn_q.weight"].len()); + + let cpu = &Device::Cpu; + + let mut row_counts = 0f64; + let mut ncall = 0f64; + let mut values = Tensor::zeros((768,), DType::F32, cpu)?; + + for _ in 0..10 { + let lhs = Var::from_tensor(&Tensor::randn(0f32, 1f32, (1024, 512), cpu)?)?; + let rhs = Var::from_tensor(&Tensor::randn(0f32, 1f32, (512, 768), cpu)?)?; + let res = lhs.matmul(&rhs)?; + + // https://github.com/ggerganov/llama.cpp/blob/678d7994f4da0af3d29046be99950ac999ee9762/examples/imatrix/imatrix.cpp#L180-L186 + values = (values + res.sqr()?.sum(0)?)?; + row_counts += res.dim(0)? as f64; + ncall += 1.; + } + + // https://github.com/ggerganov/llama.cpp/blob/678d7994f4da0af3d29046be99950ac999ee9762/examples/imatrix/imatrix.cpp#L275 + let out = ((values / row_counts)? * ncall)?; + let imatrix = out.to_vec1::()?; + + let xs = Tensor::randn(0f32, 1f32, (1024, 768), cpu)?; + + let quant1 = quantized::QTensor::quantize(&xs, GgmlDType::Iq4Nl)?; + let quant2 = quantized::QTensor::quantize_imatrix(&xs, &imatrix, GgmlDType::Iq4Nl)?; + + let dequant1 = quant1.dequantize(cpu)?; + let dequant2 = quant2.dequantize(cpu)?; + + let err1 = (dequant1 - &xs)?.abs()?.mean_all()?.to_scalar::()?; + let err2 = (dequant2 - &xs)?.abs()?.mean_all()?.to_scalar::()?; + assert!(err2 < err1, "err2 {err2} > err1 {err1}"); + + Ok(()) +} + test_device!( quantize_q4_0, quantize_q4_0_cpu, @@ -973,6 +1196,12 @@ test_device!( quantize_q5_1_cuda, quantize_q5_1_metal ); +test_device!( + quantize_f8q8, + quantize_f8q8_cpu, + quantize_f8q8_cuda, + quantize_f8q8_metal +); test_device!( quantize_q2k, quantize_q2k_cpu, @@ -1009,6 +1238,24 @@ test_device!( quantize_q8k_cuda, quantize_q8k_metal ); +test_device!( + quantize_iq4_xs, + quantize_iq4_xs_cpu, + quantize_iq4_xs_cuda, + quantize_iq4_xs_metal +); +test_device!( + quantize_iq4_nl, + quantize_iq4_nl_cpu, + quantize_iq4_nl_cuda, + quantize_iq4_nl_metal +); +test_device!( + quantize_iq3_xxs, + quantize_iq3_xxs_cpu, + quantize_iq3_xxs_cuda, + quantize_iq3_xxs_metal +); /// Very simple dot product implementation fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 { @@ -1029,6 +1276,8 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result { GgmlDType::Q5_0 => 0.001353, GgmlDType::Q5_1 => 0.00149, GgmlDType::Q8_0 => 0.000092, + GgmlDType::Iq4Xs => 0.001903, + GgmlDType::Iq4Nl => 0.002716, // Not from the ggml repo. 
GgmlDType::Q8K => 0.00065, @@ -1198,6 +1447,20 @@ quantized_matmul!( quantized_matmul_q3k_metal, GgmlDType::Q3K ); +quantized_matmul!( + quantized_matmul_iq4xs_bis, + quantized_matmul_iq4xs_cpu, + quantized_matmul_iq4xs_cuda, + quantized_matmul_iq4xs_metal, + GgmlDType::Iq4Xs +); +quantized_matmul!( + quantized_matmul_iq4nl_bis, + quantized_matmul_iq4nl_cpu, + quantized_matmul_iq4nl_cuda, + quantized_matmul_iq4nl_metal, + GgmlDType::Iq4Nl +); quantized_matmul!( quantized_matmul_q4k_bis, quantized_matmul_q4k_cpu, @@ -1306,6 +1569,58 @@ fn quantized_matmul_q4k() -> Result<()> { Ok(()) } +#[test] +fn quantized_matmul_iq4xs() -> Result<()> { + use iq_quants::BlockIQ4xs; + + let cpu = &Device::Cpu; + let (m, k, n) = (11, 512, 21); + let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?; + assert_eq!(mm.dims(), [m, n]); + let dst = mm.flatten_all()?.to_vec1::()?; + let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); + assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); + + let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Iq4Xs)?; + let rhs = quantized::QMatMul::from_qtensor(rhs)?; + let mm = rhs.forward(&lhs)?; + + assert_eq!(mm.dims(), [m, n]); + let dst = mm.flatten_all()?.to_vec1::()?; + let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); + assert_eq!(dst, [1.442, 1.509, -0.293, 1.631]); + + ggml_matmul_error_test::()?; + + Ok(()) +} + +#[test] +fn quantized_matmul_iq4nl() -> Result<()> { + use iq_quants::BlockIQ4nl; + + let cpu = &Device::Cpu; + let (m, k, n) = (11, 512, 21); + let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?; + assert_eq!(mm.dims(), [m, n]); + let dst = mm.flatten_all()?.to_vec1::()?; + let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); + assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); + + let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Iq4Nl)?; + let rhs = quantized::QMatMul::from_qtensor(rhs)?; + let mm = rhs.forward(&lhs)?; + + assert_eq!(mm.dims(), [m, n]); + let dst = mm.flatten_all()?.to_vec1::()?; + let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); + assert_eq!(dst, [1.432, 1.469, -0.312, 1.602]); + + ggml_matmul_error_test::()?; + + Ok(()) +} + #[test] fn quantized_matmul_q5k() -> Result<()> { use k_quants::BlockQ5K; diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 7231640081..6d6ece376f 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -17,11 +17,7 @@ const CONV: &str = include_str!("conv.metal"); const FILL: &str = include_str!("fill.metal"); const INDEXING: &str = include_str!("indexing.metal"); // Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle -#[cfg(not(target_os = "ios"))] const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib"); -// Current source: https://github.com/philipturner/metal-flash-attention/releases/tag/v1.0.1 -#[cfg(target_os = "ios")] -const MFA: &[u8] = include_bytes!("libMetalFlashAttention.ios.metallib"); const MLX_GEMM: &str = include_str!("mlx_gemm.metal"); const QUANTIZED: &str = include_str!("quantized.metal"); const RANDOM: &str = include_str!("random.metal"); @@ -2017,12 +2013,7 @@ pub fn call_sdpa_vector( alpha }; - let constants = Some(ConstantValues::new(vec![( - 20, - Value::Bool(/* sdpa_vector_has_mask */ false), - )])); - - let pipeline = kernels.load_pipeline_with_constants(device, Source::Sdpa, &name, constants)?; + let pipeline = 
kernels.load_pipeline(device, Source::Sdpa, &name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); @@ -2134,13 +2125,7 @@ pub fn call_sdpa_vector_2pass( alpha }; - let constants = Some(ConstantValues::new(vec![( - 20, - Value::Bool(/* sdpa_vector_has_mask */ false), - )])); - - let pipeline = - kernels.load_pipeline_with_constants(device, Source::Sdpa, &name_pass1, constants)?; + let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name_pass1)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); @@ -2462,6 +2447,9 @@ pub enum GgmlDType { F16, F32, BF16, + Iq4Xs, + Iq4Nl, + Iq3Xxs, } #[allow(clippy::too_many_arguments)] @@ -2501,7 +2489,7 @@ pub fn call_quantized_matmul_mv_t( let r2: u32 = (ne12 / ne02) as u32; let r3: u32 = (ne13 / ne03) as u32; - let (nth0, nth1, align) = match dtype { + let (nth0, nth1, align, mem_size_bytes) = match dtype { GgmlDType::Q4_0 | GgmlDType::Q4_1 | GgmlDType::Q5_0 @@ -2511,7 +2499,7 @@ pub fn call_quantized_matmul_mv_t( let nth0 = 8; let nth1 = 8; let align = 8; - (nth0, nth1, align) + (nth0, nth1, align, None) } GgmlDType::Q2K => { // Fixing a bug in Metal for GGML @@ -2519,38 +2507,50 @@ pub fn call_quantized_matmul_mv_t( let nth0 = 2; let nth1 = 32; let align = 4; - (nth0, nth1, align) + (nth0, nth1, align, None) } GgmlDType::Q4K => { let nth0 = 4; let nth1 = 8; let align = 4; - (nth0, nth1, align) + (nth0, nth1, align, None) } GgmlDType::Q3K | GgmlDType::Q5K => { let nth0 = 2; let nth1 = 32; let align = 4; - (nth0, nth1, align) + (nth0, nth1, align, None) } GgmlDType::Q6K => { let nth0 = 2; let nth1 = 32; let align = 2; - (nth0, nth1, align) + (nth0, nth1, align, None) } GgmlDType::F16 | GgmlDType::BF16 | GgmlDType::Q8K => { // Original implem uses rows let nth0 = 32; let nth1 = 1; let align = 8; - (nth0, nth1, align) + (nth0, nth1, align, None) } GgmlDType::F32 => { let nth0 = 32; let nth1 = 1; let align = 8; - (nth0, nth1, align) + (nth0, nth1, align, None) + } + GgmlDType::Iq4Xs | GgmlDType::Iq4Nl => { + let nth0 = 4; + let nth1 = 16; + let align = 4; + (nth0, nth1, align, Some(32*std::mem::size_of::())) + } + GgmlDType::Iq3Xxs => { + let nth0 = 4; + let nth1 = 16; + let align = 8; + (nth0, nth1, align, Some( 256*4+128)) } }; let thread_groups_count = MTLSize { @@ -2579,6 +2579,9 @@ pub fn call_quantized_matmul_mv_t( GgmlDType::F16 => "kernel_mul_mv_f16_f32", GgmlDType::BF16 => "kernel_mul_mv_bf16_f32", GgmlDType::F32 => "kernel_mul_mv_f32_f32", + GgmlDType::Iq4Xs => "kernel_mul_mv_iq4_xs_f32", + GgmlDType::Iq4Nl => "kernel_mul_mv_iq4_nl_f32", + GgmlDType::Iq3Xxs => "kernel_mul_mv_iq3_xxs_f32", }; let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?; @@ -2586,6 +2589,10 @@ pub fn call_quantized_matmul_mv_t( let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); + if let Some(mem_size_bytes) = mem_size_bytes { + encoder.set_threadgroup_memory_length(0, mem_size_bytes as u64); + } + set_params!( encoder, ( @@ -2687,6 +2694,9 @@ pub fn call_quantized_matmul_mm_t( GgmlDType::F16 => "kernel_mul_mm_f16_f32", GgmlDType::BF16 => "kernel_mul_mm_bf16_f32", GgmlDType::F32 => "kernel_mul_mm_f32_f32", + GgmlDType::Iq4Xs => "kernel_mul_mm_iq4_xs_f32", + GgmlDType::Iq4Nl => "kernel_mul_mm_iq4_nl_f32", + GgmlDType::Iq3Xxs => "kernel_mul_mm_iq3_xxs_f32", }; let pipeline = kernels.load_pipeline(device, Source::Quantized, 
name)?; diff --git a/candle-metal-kernels/src/scaled_dot_product_attention.metal b/candle-metal-kernels/src/scaled_dot_product_attention.metal index a8873ad681..6afad8126b 100644 --- a/candle-metal-kernels/src/scaled_dot_product_attention.metal +++ b/candle-metal-kernels/src/scaled_dot_product_attention.metal @@ -299,8 +299,6 @@ struct MLXScaledDotProductAttentionParams { // ============ "mlx/backend/metal/kernels/scaled_dot_product_attention_params.sdpa_vector" -constant bool sdpa_vector_has_mask [[function_constant(20)]]; - template [[kernel]] void sdpa_vector( const device T* queries [[buffer(0)]], @@ -313,16 +311,14 @@ template const constant size_t& v_stride, const constant float& scale, const constant float& softcapping, - const device bool* mask [[function_constant(sdpa_vector_has_mask)]], - const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]], - const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int BN = 32; constexpr int BD = 32; constexpr int elem_per_thread = D / BD; - constexpr int stride = BN * D; + + const int stride = BN * D; typedef float U; @@ -340,9 +336,6 @@ template queries += head_idx * D + simd_lid * elem_per_thread; keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread; values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread; - if (sdpa_vector_has_mask) { - mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride; - } out += head_idx * D + simd_gid * elem_per_thread; // Read the query and 0 the output accumulator @@ -358,43 +351,38 @@ template // For each key for (int i = simd_gid; i < N; i += BN) { - if (!sdpa_vector_has_mask || mask[0]) { - // Read the key - for (int j = 0; j < elem_per_thread; j++) { - k[j] = keys[j]; - } + // Read the key + for (int i = 0; i < elem_per_thread; i++) { + k[i] = keys[i]; + } - // Compute the i-th score - U score = 0; - for (int j = 0; j < elem_per_thread; j++) { - score += q[j] * k[j]; - } - score = simd_sum(score); - if (softcapping != 1.) { - score = precise::tanh(score); - score = score * softcapping; - } + // Compute the i-th score + U score = 0; + for (int i = 0; i < elem_per_thread; i++) { + score += q[i] * k[i]; + } + score = simd_sum(score); + if (softcapping != 1.) { + score = precise::tanh(score); + score = score * softcapping; + } - // Update the accumulators - U new_max = max(max_score, score); - U factor = fast::exp(max_score - new_max); - U exp_score = fast::exp(score - new_max); + // Update the accumulators + U new_max = max(max_score, score); + U factor = fast::exp(max_score - new_max); + U exp_score = fast::exp(score - new_max); - max_score = new_max; - sum_exp_score = sum_exp_score * factor + exp_score; + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; - // Update the output accumulator - for (int j = 0; j < elem_per_thread; j++) { - o[j] = o[j] * factor + exp_score * values[j]; - } + // Update the output accumulator + for (int i = 0; i < elem_per_thread; i++) { + o[i] = o[i] * factor + exp_score * values[i]; } // Move the pointers to the next kv keys += stride; values += stride; - if (sdpa_vector_has_mask) { - mask += BN * mask_seq_stride; - } } // Each thread has a partial part of the output so we need to combine them. 
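The sdpa_vector kernels keep a running maximum, a running sum of exponentials and a partially accumulated output, rescaling all three whenever a key produces a larger score; this is the usual online-softmax formulation, so an attention row is computed in one pass without materialising the full score vector. Below is a minimal scalar Rust sketch of that per-key update, with hypothetical names and without the masking and softcapping handled by the kernel.

fn sdpa_row(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>], scale: f32) -> Vec<f32> {
    let mut max_score = f32::NEG_INFINITY;
    let mut sum_exp = 0.0f32;
    let mut out = vec![0.0f32; q.len()];
    for (k, v) in keys.iter().zip(values.iter()) {
        // score for this key
        let score = scale * q.iter().zip(k.iter()).map(|(a, b)| a * b).sum::<f32>();
        // rescale the running accumulators when a new maximum appears
        let new_max = max_score.max(score);
        let factor = (max_score - new_max).exp();
        let exp_score = (score - new_max).exp();
        max_score = new_max;
        sum_exp = sum_exp * factor + exp_score;
        for (o, &vj) in out.iter_mut().zip(v.iter()) {
            *o = *o * factor + exp_score * vj;
        }
    }
    // normalise by the softmax denominator at the end
    for o in out.iter_mut() {
        *o /= sum_exp;
    }
    out
}

The kernel splits this loop over 32 SIMD groups and then merges their partial (max, sum, output) triples with the same rescaling rule, which is what the threadgroup reduction after the loop implements.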
@@ -440,9 +428,6 @@ template const constant size_t& v_stride, const constant float& scale, const constant float& softcapping, - const device bool* mask [[function_constant(sdpa_vector_has_mask)]], - const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]], - const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { @@ -472,10 +457,6 @@ template values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D + simd_lid * elem_per_thread; out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread; - if (sdpa_vector_has_mask) { - mask += head_idx * mask_head_stride + - (block_idx * BN + simd_gid) * mask_seq_stride; - } sums += head_idx * blocks + block_idx; maxs += head_idx * blocks + block_idx; @@ -492,43 +473,75 @@ template // For each key for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) { - if (!sdpa_vector_has_mask || mask[0]) { - // Read the key - for (int i = 0; i < elem_per_thread; i++) { - k[i] = keys[i]; - } + // Read the key + for (int i = 0; i < elem_per_thread; i++) { + k[i] = keys[i]; + } - // Compute the i-th score - U score = 0; - for (int i = 0; i < elem_per_thread; i++) { - score += q[i] * k[i]; - } - score = simd_sum(score); - if (softcapping != 1.) { - score = precise::tanh(score); - score = score * softcapping; - } + // Compute the i-th score + U score = 0; + for (int i = 0; i < elem_per_thread; i++) { + score += q[i] * k[i]; + } + score = simd_sum(score); + if (softcapping != 1.) { + score = precise::tanh(score); + score = score * softcapping; + } - // Update the accumulators - U new_max = max(max_score, score); - U factor = fast::exp(max_score - new_max); - U exp_score = fast::exp(score - new_max); + // Update the accumulators + U new_max = max(max_score, score); + U factor = fast::exp(max_score - new_max); + U exp_score = fast::exp(score - new_max); - max_score = new_max; - sum_exp_score = sum_exp_score * factor + exp_score; + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; - // Update the output accumulator - for (int i = 0; i < elem_per_thread; i++) { - o[i] = o[i] * factor + exp_score * values[i]; - } + // Update the output accumulator + for (int i = 0; i < elem_per_thread; i++) { + o[i] = o[i] * factor + exp_score * values[i]; } // Move the pointers to the next kv keys += blocks * stride; values += blocks * stride; - if (sdpa_vector_has_mask) { - mask += BN * blocks * mask_seq_stride; + } + + // Each thread has a partial part of the output so we need to combine them. + + // First let's communicate the max and sum_exp + if (simd_lid == 0) { + max_scores[simd_gid] = max_score; + sum_exp_scores[simd_gid] = sum_exp_score; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9; + U new_max = simd_max(max_score); + U factor = fast::exp(max_score - new_max); + sum_exp_score = (simd_lid < BN) ? 
sum_exp_scores[simd_lid] : 0; + sum_exp_score = simd_sum(sum_exp_score * factor); + + // Write the sum and new max + if (simd_gid == 0) { + sums[0] = sum_exp_score; + maxs[0] = new_max; + } + + // Now we need to aggregate all the outputs + for (int i = 0; i < elem_per_thread; i++) { + outputs[simd_lid * BN + simd_gid] = + o[i] * fast::exp(max_scores[simd_gid] - new_max); + threadgroup_barrier(mem_flags::mem_threadgroup); + + // And write the output + if (simd_gid == 0) { + U output = outputs[simd_lid * BN]; + for (int j = 1; j < BN; j++) { + output += outputs[simd_lid * BN + j]; + } + out[i] = static_cast(output); } + threadgroup_barrier(mem_flags::mem_threadgroup); } } @@ -1656,9 +1669,6 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2); const constant size_t& v_stride, \ const constant float& scale, \ const constant float& softcapping, \ - const device bool* mask [[function_constant(sdpa_vector_has_mask)]], \ - const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]], \ - const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]], \ uint3 tid [[threadgroup_position_in_grid]], \ uint simd_gid [[simdgroup_index_in_threadgroup]], \ uint simd_lid [[thread_index_in_simdgroup]]); \ @@ -1676,9 +1686,6 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2); const constant size_t& v_stride, \ const constant float& scale, \ const constant float& softcapping, \ - const device bool* mask [[function_constant(sdpa_vector_has_mask)]], \ - const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]], \ - const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]], \ uint3 tid [[threadgroup_position_in_grid]], \ uint simd_gid [[simdgroup_index_in_threadgroup]], \ uint simd_lid [[thread_index_in_simdgroup]]); \ diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs index e77697340e..074e3e194b 100644 --- a/candle-transformers/src/models/llama.rs +++ b/candle-transformers/src/models/llama.rs @@ -1,6 +1,9 @@ -use super::with_tracing::{linear_no_bias as linear, Linear, RmsNorm}; -use candle::{DType, Device, IndexOp, Result, Tensor, D}; -use candle_nn::{embedding, Embedding, Module, VarBuilder}; +use super::with_tracing::RmsNorm; +use candle::{ + quantized::{GgmlDType, QMatMul, QTensor}, + DType, Device, IndexOp, Result, Tensor, D, +}; +use candle_nn::{embedding, linear_no_bias as linear, Embedding, Linear, Module, VarBuilder}; use std::{collections::HashMap, f32::consts::PI}; pub const DEFAULT_MAX_SEQ_LEN: usize = 4096; @@ -225,10 +228,10 @@ impl Cache { #[derive(Debug, Clone)] struct CausalSelfAttention { - q_proj: Linear, - k_proj: Linear, - v_proj: Linear, - o_proj: Linear, + q_proj: QMatMul, + k_proj: QMatMul, + v_proj: QMatMul, + o_proj: QMatMul, num_attention_heads: usize, num_key_value_heads: usize, head_dim: usize, @@ -365,6 +368,14 @@ impl CausalSelfAttention { let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?; let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?; let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?; + + println!("A"); + let q_proj = QMatMul::from_qtensor(QTensor::quantize(q_proj.weight(), GgmlDType::Iq4Xs)?)?; + let k_proj = QMatMul::from_qtensor(QTensor::quantize(k_proj.weight(), GgmlDType::F32)?)?; + let v_proj = QMatMul::from_qtensor(QTensor::quantize(v_proj.weight(), GgmlDType::F32)?)?; + let o_proj = QMatMul::from_qtensor(QTensor::quantize(o_proj.weight(), GgmlDType::F32)?)?; + 
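+        // only q_proj is actually compressed here; the GgmlDType::F32 "quantization" of k/v/o
+        // keeps the weights in f32 and just routes them through QMatMul.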
println!("B"); + Ok(Self { q_proj, k_proj, @@ -511,7 +522,7 @@ impl Llama { pub fn load(vb: VarBuilder, cfg: &Config) -> Result { let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?; let lm_head = if cfg.tie_word_embeddings { - Linear::from_weights(wte.embeddings().clone(), None) + Linear::new(wte.embeddings().clone(), None) } else { linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))? };