diff --git a/.gitmodules b/.gitmodules index 12631cbc27..cae9aba11a 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "candle-examples/examples/flash-attn/cutlass"] path = candle-flash-attn/cutlass url = https://github.com/NVIDIA/cutlass.git +[submodule "candle-flash-attn-v3/cutlass"] + path = candle-flash-attn-v3/cutlass + url = https://github.com/NVIDIA/cutlass.git diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index b2dbd68012..0000000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "[python]": { - "editor.defaultFormatter": "ms-python.black-formatter" - }, - "python.formatting.provider": "none", - "python.testing.pytestArgs": [ - "candle-pyo3" - ], - "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true -} \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index f1f10ffb9f..11566c33e7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ members = [ exclude = [ "candle-book", "candle-flash-attn", + "candle-flash-attn-v3", "candle-kernels", "candle-metal-kernels", "candle-onnx", @@ -36,6 +37,7 @@ byteorder = "1.4.3" candle = { path = "./candle-core", package = "candle-core", version = "0.9.1" } candle-datasets = { path = "./candle-datasets", version = "0.9.1" } candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.1" } +candle-flash-attn-v3 = { path = "./candle-flash-attn-v3", version = "0.9.1" } candle-kernels = { path = "./candle-kernels", version = "0.9.1" } candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.1" } candle-nn = { path = "./candle-nn", version = "0.9.1" } @@ -43,11 +45,12 @@ candle-onnx = { path = "./candle-onnx", version = "0.9.1" } candle-transformers = { path = "./candle-transformers", version = "0.9.1" } clap = { version = "4.2.4", features = ["derive"] } criterion = { version = "0.5.1", default-features=false } -cudarc = { version = "0.16.3", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false } +cudarc = { version = "0.16.6", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false } fancy-regex = "0.13.0" gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] } hf-hub = "0.4.1" half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] } +float8 = { version = "0.3.0", features = ["num-traits"] } hound = "3.5.1" image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] } imageproc = { version = "0.24.0", default-features = false } @@ -61,7 +64,7 @@ parquet = { version = "51.0.0" } rand = "0.9.0" rand_distr = "0.5.1" rayon = "1.7.0" -safetensors = "0.4.1" +safetensors = "0.6.0" serde = { version = "1.0.171", features = ["derive"] } serde_plain = "1.0.2" serde_json = "1.0.99" diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index ebd2c51934..54087c99a2 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -18,6 +18,7 @@ metal = { workspace = true, optional = true } cudarc = { workspace = true, optional = true } gemm = { workspace = true } half = { workspace = true } +float8 = { workspace = true } intel-mkl-src = { workspace = true, optional = true } libc = { workspace = true, optional = true } memmap2 = { workspace = true } @@ -43,8 +44,9 @@ criterion = { workspace = true } [features] default = [] -cuda = ["cudarc", 
"dep:candle-kernels", "dep:ug-cuda"] +cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda", "float8/cuda"] cudnn = ["cuda", "cudarc/cudnn"] +nccl = ["cuda", "cudarc/nccl"] mkl = ["dep:libc", "dep:intel-mkl-src"] accelerate = ["dep:libc", "dep:accelerate-src"] metal = ["dep:metal", "dep:candle-metal-kernels", "dep:ug-metal"] diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs index 34f45d3d22..70b68c7282 100644 --- a/candle-core/benches/benchmarks/mod.rs +++ b/candle-core/benches/benchmarks/mod.rs @@ -8,6 +8,8 @@ pub(crate) mod reduce; pub(crate) mod unary; pub(crate) mod where_cond; +#[cfg(feature = "cuda")] +use candle_core::backend::BackendDevice; use candle_core::{Device, Result}; pub(crate) trait BenchDevice { @@ -26,13 +28,13 @@ impl BenchDevice for Device { .synchronize() .map_err(|e| candle_core::Error::Cuda(Box::new(e)))?); #[cfg(not(feature = "cuda"))] - panic!("Cuda device without cuda feature enabled: {:?}", device) + panic!("Cuda device without cuda feature enabled: {device:?}") } Device::Metal(device) => { #[cfg(feature = "metal")] return Ok(device.wait_until_completed()?); #[cfg(not(feature = "metal"))] - panic!("Metal device without metal feature enabled: {:?}", device) + panic!("Metal device without metal feature enabled: {device:?}") } } } diff --git a/candle-core/benches/benchmarks/qmatmul.rs b/candle-core/benches/benchmarks/qmatmul.rs index 4d34588b36..bd1d815fc4 100644 --- a/candle-core/benches/benchmarks/qmatmul.rs +++ b/candle-core/benches/benchmarks/qmatmul.rs @@ -31,7 +31,7 @@ fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) { let flops = b * m * n * k; - let mut group = c.benchmark_group(device.bench_name(format!("qmatmul_{:?}", dtype))); + let mut group = c.benchmark_group(device.bench_name(format!("qmatmul_{dtype:?}"))); group.sample_size(200); group.throughput(Throughput::Bytes(flops as u64)); group.bench_function("iter", move |b| { diff --git a/candle-core/benches/benchmarks/reduce.rs b/candle-core/benches/benchmarks/reduce.rs index e0755a7080..4bbe5e829f 100644 --- a/candle-core/benches/benchmarks/reduce.rs +++ b/candle-core/benches/benchmarks/reduce.rs @@ -44,12 +44,12 @@ fn run_reduce( let k = 1024; let a = if strided { - Tensor::rand(lo, up, (b, m, k), &device) + Tensor::rand(lo, up, (b, m, k), device) .unwrap() .transpose(0, 2) .unwrap() } else { - Tensor::rand(lo, up, (b, m, k), &device).unwrap() + Tensor::rand(lo, up, (b, m, k), device).unwrap() }; let flops = b * m * k * T::DTYPE.size_in_bytes(); @@ -105,12 +105,12 @@ fn run_arg_reduce( let k = 1024; let a = if strided { - Tensor::rand(lo, up, (b, m, k), &device) + Tensor::rand(lo, up, (b, m, k), device) .unwrap() .transpose(0, 2) .unwrap() } else { - Tensor::rand(lo, up, (b, m, k), &device).unwrap() + Tensor::rand(lo, up, (b, m, k), device).unwrap() }; let flops = b * m * k * T::DTYPE.size_in_bytes(); diff --git a/candle-core/benches/benchmarks/unary.rs b/candle-core/benches/benchmarks/unary.rs index 9efd75093d..616ee1609d 100644 --- a/candle-core/benches/benchmarks/unary.rs +++ b/candle-core/benches/benchmarks/unary.rs @@ -40,7 +40,7 @@ fn criterion_benchmark(c: &mut Criterion) { let handler = BenchDeviceHandler::new().unwrap(); for device in handler.devices { for dtype in [DType::F32, DType::BF16, DType::F16] { - let name = format!("sqrt_{:?}", dtype); + let name = format!("sqrt_{dtype:?}"); run_unary_benchmark(c, &device, dtype, &name); } } diff --git a/candle-core/benches/benchmarks/where_cond.rs 
b/candle-core/benches/benchmarks/where_cond.rs index 0e91f656fc..2f97d9cac7 100644 --- a/candle-core/benches/benchmarks/where_cond.rs +++ b/candle-core/benches/benchmarks/where_cond.rs @@ -22,7 +22,7 @@ const M: usize = 1024; const K: usize = 1024; const SIZE: usize = B * M * K; -const DATA: [u8; SIZE] = create_cond_arr::(); +static DATA: [u8; SIZE] = create_cond_arr::(); fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) { let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), device).unwrap(); diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index a85f8d36d2..b61d46d2de 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -158,6 +158,7 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone { fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result; fn set_seed(&self, _: u64) -> Result<()>; + fn get_current_seed(&self) -> Result; /// Synchronize should block until all the operations on the device are completed. fn synchronize(&self) -> Result<()>; diff --git a/candle-core/src/convert.rs b/candle-core/src/convert.rs index 5ea5612a7c..38e7a7c9a6 100644 --- a/candle-core/src/convert.rs +++ b/candle-core/src/convert.rs @@ -93,6 +93,8 @@ from_tensor!(f32); from_tensor!(f16); from_tensor!(bf16); from_tensor!(i64); +from_tensor!(i32); +from_tensor!(i16); from_tensor!(u32); from_tensor!(u8); @@ -130,6 +132,16 @@ impl Tensor { f.write_u32::(v)? } } + DType::I16 => { + for v in vs.to_vec1::()? { + f.write_i16::(v)? + } + } + DType::I32 => { + for v in vs.to_vec1::()? { + f.write_i32::(v)? + } + } DType::I64 => { for v in vs.to_vec1::()? { f.write_i64::(v)? @@ -139,6 +151,15 @@ impl Tensor { let vs = vs.to_vec1::()?; f.write_all(&vs)?; } + DType::F8E4M3 => { + let vs = vs.to_vec1::()?; + for v in vs { + f.write_u8(v.to_bits())? 
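The `write_bytes` arms added above push the new integer dtypes through `byteorder` and emit each F8E4M3 value as its raw bit pattern via `to_bits()` (the packed dummy dtypes are rejected just below). A minimal standalone sketch of the same byte layout, assuming little-endian output (the byte-order generic is elided in the hunk as rendered) and using only the calls that appear in the diff:

```rust
// Hedged sketch: mirrors how the new `Tensor::write_bytes` arms serialize data.
// I16/I32 values go through byteorder as (assumed) little-endian integers, while
// each F8E4M3 value is emitted as a single raw byte via `to_bits()`.
use byteorder::{LittleEndian, WriteBytesExt};
use float8::F8E4M3;

fn main() -> std::io::Result<()> {
    let ints: Vec<i16> = vec![-1, 0, 1];
    let fp8: Vec<F8E4M3> = [0.5f32, 1.0, 2.0]
        .iter()
        .map(|&v| F8E4M3::from_f32(v))
        .collect();

    let mut buf: Vec<u8> = Vec::new();
    for v in &ints {
        buf.write_i16::<LittleEndian>(*v)?; // same call the I16 branch uses
    }
    for v in &fp8 {
        buf.write_u8(v.to_bits())?; // one byte per F8E4M3 value
    }
    assert_eq!(buf.len(), ints.len() * 2 + fp8.len());
    Ok(())
}
```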
+ } + } + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + return Err(crate::Error::UnsupportedDTypeForOp(self.dtype(), "write_bytes").bt()) + } } Ok(()) } diff --git a/candle-core/src/cpu/avx.rs b/candle-core/src/cpu/avx.rs index 9398a3460a..113fc14ced 100644 --- a/candle-core/src/cpu/avx.rs +++ b/candle-core/src/cpu/avx.rs @@ -1,10 +1,10 @@ -use super::{Cpu, CpuF16}; +use super::{Cpu, CpuBF16, CpuF16}; #[cfg(target_arch = "x86")] use core::arch::x86::*; #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*; -use half::f16; +use half::{bf16, f16}; pub struct CurrentCpu {} @@ -146,3 +146,82 @@ impl CpuF16 for CurrentCpuF16 { *y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); } } + +pub struct CurrentCpuBF16 {} +impl CpuBF16 for CurrentCpuBF16 { + type Unit = __m256; + type Array = [__m256; ARR]; + + const STEP: usize = STEP; + const EPR: usize = EPR; + + fn n() -> usize { + ARR + } + + unsafe fn zero() -> Self::Unit { + _mm256_setzero_ps() + } + + unsafe fn zero_array() -> Self::Array { + [Self::zero(); ARR] + } + + unsafe fn from_f32(v: f32) -> Self::Unit { + _mm256_set1_ps(v) + } + + #[cfg(target_feature = "f16c")] + unsafe fn load(mem_addr: *const bf16) -> Self::Unit { + _mm256_cvtph_ps(_mm_loadu_si128(mem_addr as *const __m128i)) + } + + #[cfg(not(target_feature = "f16c"))] + unsafe fn load(mem_addr: *const bf16) -> Self::Unit { + let mut tmp = [0.0f32; 8]; + for i in 0..8 { + tmp[i] = (*mem_addr.add(i)).to_f32(); + } + _mm256_loadu_ps(tmp.as_ptr()) + } + + unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit { + _mm256_add_ps(a, b) + } + + unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit { + _mm256_add_ps(_mm256_mul_ps(b, c), a) + } + + #[cfg(target_feature = "f16c")] + unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) { + _mm_storeu_si128(mem_addr as *mut __m128i, _mm256_cvtps_ph(a, 0)) + } + + #[cfg(not(target_feature = "f16c"))] + unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) { + let mut tmp = [0.0f32; 8]; + _mm256_storeu_ps(tmp.as_mut_ptr(), a); + for i in 0..8 { + *mem_addr.add(i) = bf16::from_f32(tmp[i]); + } + } + + unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) { + let mut offset = ARR >> 1; + for i in 0..offset { + x[i] = _mm256_add_ps(x[i], x[offset + i]); + } + offset >>= 1; + for i in 0..offset { + x[i] = _mm256_add_ps(x[i], x[offset + i]); + } + offset >>= 1; + for i in 0..offset { + x[i] = _mm256_add_ps(x[i], x[offset + i]); + } + let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1)); + let t1 = _mm_hadd_ps(t0, t0); + *y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); + } +} diff --git a/candle-core/src/cpu/kernels.rs b/candle-core/src/cpu/kernels.rs index 527646d62b..bca76adcc8 100644 --- a/candle-core/src/cpu/kernels.rs +++ b/candle-core/src/cpu/kernels.rs @@ -121,6 +121,13 @@ impl VecOps for half::bf16 { fn max(self, other: Self) -> Self { Self::max(self, other) } + + #[inline(always)] + unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) { + let mut res_f32 = 0f32; + super::vec_dot_bf16(lhs, rhs, &mut res_f32, len); + *res = half::bf16::from_f32(res_f32); + } } impl VecOps for u8 { #[inline(always)] @@ -144,6 +151,28 @@ impl VecOps for u32 { ::max(self, other) } } +impl VecOps for i16 { + #[inline(always)] + fn min(self, other: Self) -> Self { + ::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + ::max(self, other) + } +} +impl VecOps for i32 { + #[inline(always)] + fn min(self, other: Self) -> Self { + ::min(self, 
other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + ::max(self, other) + } +} impl VecOps for i64 { #[inline(always)] fn min(self, other: Self) -> Self { @@ -156,6 +185,18 @@ impl VecOps for i64 { } } +impl VecOps for float8::F8E4M3 { + #[inline(always)] + fn min(self, other: Self) -> Self { + Self::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + Self::max(self, other) + } +} + #[inline(always)] pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) { if n_threads == 1 { diff --git a/candle-core/src/cpu/mod.rs b/candle-core/src/cpu/mod.rs index be5b99128e..0b77e6ecb7 100644 --- a/candle-core/src/cpu/mod.rs +++ b/candle-core/src/cpu/mod.rs @@ -1,5 +1,3 @@ -//! Traits and methods for CPU-backed Tensors - pub mod erf; pub mod kernels; @@ -38,14 +36,33 @@ trait CpuF16 { unsafe fn from_f32(v: f32) -> Self::Unit; unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit); } -use half::f16; + +#[allow(unused)] +trait CpuBF16 { + type Unit; + type Array; + const STEP: usize; + const EPR: usize; + + fn n() -> usize; + unsafe fn zero() -> Self::Unit; + unsafe fn zero_array() -> Self::Array; + unsafe fn load(mem_addr: *const bf16) -> Self::Unit; + unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit; + unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit; + unsafe fn vec_reduce(x: Self::Array, y: *mut f32); + unsafe fn from_f32(v: f32) -> Self::Unit; + unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit); +} + +use half::{bf16, f16}; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] #[cfg(target_feature = "avx")] pub mod avx; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] #[cfg(target_feature = "avx")] -pub use avx::{CurrentCpu, CurrentCpuF16}; +pub use avx::{CurrentCpu, CurrentCpuBF16, CurrentCpuF16}; #[cfg(target_arch = "wasm32")] #[cfg(target_feature = "simd128")] @@ -172,6 +189,34 @@ pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f *c = sumf; } +#[cfg(target_feature = "avx")] +#[inline(always)] +pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) { + let mut sumf = 0.0f32; + let np = k & !(CurrentCpuBF16::STEP - 1); + + let mut sum = CurrentCpuBF16::zero_array(); + let mut ax = CurrentCpuBF16::zero_array(); + let mut ay = CurrentCpuBF16::zero_array(); + + for i in (0..np).step_by(CurrentCpuBF16::STEP) { + for j in 0..CurrentCpuBF16::n() { + ax[j] = CurrentCpuBF16::load(a_row.add(i + j * CurrentCpuBF16::EPR)); + ay[j] = CurrentCpuBF16::load(b_row.add(i + j * CurrentCpuBF16::EPR)); + + sum[j] = CurrentCpuBF16::vec_fma(sum[j], ax[j], ay[j]); + } + } + + CurrentCpuBF16::vec_reduce(sum, &mut sumf); + + // leftovers + for i in np..k { + sumf += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32(); + } + *c = sumf; +} + #[cfg(not(target_feature = "avx"))] #[inline(always)] pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) { @@ -182,3 +227,14 @@ pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f } *c = sum; } + +#[cfg(not(target_feature = "avx"))] +#[inline(always)] +pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) { + // leftovers + let mut sum = 0.0; + for i in 0..k { + sum += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32(); + } + *c = sum; +} diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index af7cb5bd4f..7d35c9e52a 100644 --- 
a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -2,6 +2,7 @@ use crate::backend::{BackendDevice, BackendStorage}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType}; +use float8::F8E4M3 as f8e4m3; use half::{bf16, f16}; use rayon::prelude::*; @@ -20,22 +21,38 @@ const USE_IM2COL_CONV2D: bool = true; pub enum CpuStorage { U8(Vec), U32(Vec), + I16(Vec), + I32(Vec), I64(Vec), BF16(Vec), F16(Vec), F32(Vec), F64(Vec), + F8E4M3(Vec), + // Dummy types that store raw bytes + F6E2M3(Vec), + F6E3M2(Vec), + F4(Vec), + F8E8M0(Vec), } #[derive(Debug, Clone)] pub enum CpuStorageRef<'a> { U8(&'a [u8]), U32(&'a [u32]), + I16(&'a [i16]), + I32(&'a [i32]), I64(&'a [i64]), BF16(&'a [bf16]), F16(&'a [f16]), F32(&'a [f32]), F64(&'a [f64]), + F8E4M3(&'a [f8e4m3]), + // Dummy types that store raw bytes + F6E2M3(&'a [u8]), + F6E3M2(&'a [u8]), + F4(&'a [u8]), + F8E8M0(&'a [u8]), } #[derive(Debug, Clone)] @@ -1636,6 +1653,28 @@ impl CpuStorage { .concat(); Self::U32(storages) } + Self::I16(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::I16(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::I16(storages) + } + Self::I32(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::I32(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::I32(storages) + } Self::I64(_) => { let storages = storages .iter() @@ -1691,6 +1730,61 @@ impl CpuStorage { .concat(); Self::F64(storages) } + Self::F8E4M3(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::F8E4M3(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::F8E4M3(storages) + } + Self::F6E2M3(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::F6E2M3(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::F6E2M3(storages) + } + Self::F6E3M2(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::F6E3M2(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::F6E3M2(storages) + } + Self::F4(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::F4(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::F4(storages) + } + Self::F8E8M0(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::F8E8M0(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? 
+ .concat(); + Self::F8E8M0(storages) + } }; Ok(s) } @@ -1703,11 +1797,18 @@ impl BackendStorage for CpuStorage { match self { Self::U8(_) => DType::U8, Self::U32(_) => DType::U32, + Self::I16(_) => DType::I16, + Self::I32(_) => DType::I32, Self::I64(_) => DType::I64, Self::BF16(_) => DType::BF16, Self::F16(_) => DType::F16, Self::F32(_) => DType::F32, Self::F64(_) => DType::F64, + Self::F8E4M3(_) => DType::F8E4M3, + Self::F6E2M3(_) => DType::F6E2M3, + Self::F6E3M2(_) => DType::F6E3M2, + Self::F4(_) => DType::F4, + Self::F8E8M0(_) => DType::F8E8M0, } } @@ -1910,6 +2011,226 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v); Ok(Self::F64(data)) } + // Conversions to F8E4M3 + (Self::U8(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + (Self::U32(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + (Self::I64(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + (Self::BF16(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v.to_f32())); + Ok(Self::F8E4M3(data)) + } + (Self::F16(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v.to_f32())); + Ok(Self::F8E4M3(data)) + } + (Self::F32(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, f8e4m3::from_f32); + Ok(Self::F8E4M3(data)) + } + (Self::F64(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, f8e4m3::from_f64); + Ok(Self::F8E4M3(data)) + } + (Self::F8E4M3(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| v); + Ok(Self::F8E4M3(data)) + } + // Conversions from F8E4M3 + (Self::F8E4M3(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v.to_f32() as u8); + Ok(Self::U8(data)) + } + (Self::F8E4M3(storage), DType::U32) => { + let data = unary_map(storage, layout, |v| v.to_f32() as u32); + Ok(Self::U32(data)) + } + (Self::F8E4M3(storage), DType::I64) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i64); + Ok(Self::I64(data)) + } + (Self::F8E4M3(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32())); + Ok(Self::BF16(data)) + } + (Self::F8E4M3(storage), DType::F16) => { + let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32())); + Ok(Self::F16(data)) + } + (Self::F8E4M3(storage), DType::F32) => { + let data = unary_map(storage, layout, |v| v.to_f32()); + Ok(Self::F32(data)) + } + (Self::F8E4M3(storage), DType::F64) => { + let data = unary_map(storage, layout, |v| v.to_f64()); + Ok(Self::F64(data)) + } + // Conversions to I16 + (Self::U8(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i16); + Ok(Self::I16(data)) + } + (Self::U32(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i16); + Ok(Self::I16(data)) + } + (Self::I16(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v); + Ok(Self::I16(data)) + } + (Self::I32(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i16); + Ok(Self::I16(data)) + } + (Self::I64(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i16); + Ok(Self::I16(data)) + } + (Self::BF16(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i16); + Ok(Self::I16(data)) + } + 
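The surrounding arms extend CPU `to_dtype` to the new `I16` and `I32` dtypes with plain `as` casts, so floats truncate toward zero. A hedged sketch of the user-facing behaviour, assuming the `DType::I16`/`DType::I32` variants and their `WithDType` impls land together with these storage arms:

```rust
// Hedged sketch of the new conversion arms seen from the public Tensor API.
// Assumes DType::I16 / DType::I32 and their WithDType impls ship with this change.
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::new(&[0.9f32, 1.5, -2.2], &dev)?;

    // f32 -> i16 uses the `v as i16` truncation in the match arms above.
    let i16_t = t.to_dtype(DType::I16)?;
    assert_eq!(i16_t.to_vec1::<i16>()?, vec![0i16, 1, -2]);

    // i16 -> i32 widens, i32 -> f32 converts back to floating point.
    let back = i16_t.to_dtype(DType::I32)?.to_dtype(DType::F32)?;
    assert_eq!(back.to_vec1::<f32>()?, vec![0f32, 1.0, -2.0]);
    Ok(())
}
```

The packed F6E2M3/F6E3M2/F4/F8E8M0 formats stay non-convertible and keep returning `UnsupportedDTypeForOp`, matching the catch-all arms added at the end of this match.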
(Self::F16(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i16); + Ok(Self::I16(data)) + } + (Self::F32(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i16); + Ok(Self::I16(data)) + } + (Self::F64(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i16); + Ok(Self::I16(data)) + } + (Self::F8E4M3(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i16); + Ok(Self::I16(data)) + } + // Conversions to I32 + (Self::U8(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); + Ok(Self::I32(data)) + } + (Self::U32(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); + Ok(Self::I32(data)) + } + (Self::I16(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); + Ok(Self::I32(data)) + } + (Self::I32(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v); + Ok(Self::I32(data)) + } + (Self::I64(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); + Ok(Self::I32(data)) + } + (Self::BF16(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i32); + Ok(Self::I32(data)) + } + (Self::F16(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i32); + Ok(Self::I32(data)) + } + (Self::F32(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); + Ok(Self::I32(data)) + } + (Self::F64(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); + Ok(Self::I32(data)) + } + (Self::F8E4M3(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i32); + Ok(Self::I32(data)) + } + // Conversions from I16 + (Self::I16(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v as u8); + Ok(Self::U8(data)) + } + (Self::I16(storage), DType::U32) => { + let data = unary_map(storage, layout, |v| v as u32); + Ok(Self::U32(data)) + } + (Self::I16(storage), DType::I64) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } + (Self::I16(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); + Ok(Self::BF16(data)) + } + (Self::I16(storage), DType::F16) => { + let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); + Ok(Self::F16(data)) + } + (Self::I16(storage), DType::F32) => { + let data = unary_map(storage, layout, |v| v as f32); + Ok(Self::F32(data)) + } + (Self::I16(storage), DType::F64) => { + let data = unary_map(storage, layout, |v| v as f64); + Ok(Self::F64(data)) + } + (Self::I16(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + // Conversions from I32 + (Self::I32(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v as u8); + Ok(Self::U8(data)) + } + (Self::I32(storage), DType::U32) => { + let data = unary_map(storage, layout, |v| v as u32); + Ok(Self::U32(data)) + } + (Self::I32(storage), DType::I64) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } + (Self::I32(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); + Ok(Self::BF16(data)) + } + (Self::I32(storage), DType::F16) => { + let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); + Ok(Self::F16(data)) + } + (Self::I32(storage), DType::F32) => { + let data = unary_map(storage, layout, |v| v 
as f32); + Ok(Self::F32(data)) + } + (Self::I32(storage), DType::F64) => { + let data = unary_map(storage, layout, |v| v as f64); + Ok(Self::F64(data)) + } + (Self::I32(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| f8e4m3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + // Dummy types - return error for all conversions to/from dummy types + (_, DType::F6E2M3) | (_, DType::F6E3M2) | (_, DType::F4) | (_, DType::F8E8M0) => { + Err(Error::UnsupportedDTypeForOp(dtype, "to_dtype").bt()) + } + (Self::F6E2M3(_), _) + | (Self::F6E3M2(_), _) + | (Self::F4(_), _) + | (Self::F8E8M0(_), _) => { + Err(Error::UnsupportedDTypeForOp(self.dtype(), "to_dtype").bt()) + } } } @@ -2023,9 +2344,19 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v.powf(e)); Ok(Self::F64(data)) } - Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()), - Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()), - Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()), + Self::F8E4M3(storage) => { + let data = unary_map(storage, layout, |v| v.powf(f8e4m3::from_f64(e))); + Ok(Self::F8E4M3(data)) + } + Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "powf").bt()), + Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "powf").bt()), + Self::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "powf").bt()), + Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "powf").bt()), + Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "powf").bt()), + Self::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E2M3, "powf").bt()), + Self::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E3M2, "powf").bt()), + Self::F4(_) => Err(Error::UnsupportedDTypeForOp(DType::F4, "powf").bt()), + Self::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E8M0, "powf").bt()), } } @@ -2048,9 +2379,19 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| elu(v, alpha)); Ok(Self::F64(data)) } + Self::F8E4M3(storage) => { + let data = unary_map(storage, layout, |v| elu(v, f8e4m3::from_f64(alpha))); + Ok(Self::F8E4M3(data)) + } Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()), Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()), + Self::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "elu").bt()), + Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "elu").bt()), Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()), + Self::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E2M3, "elu").bt()), + Self::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E3M2, "elu").bt()), + Self::F4(_) => Err(Error::UnsupportedDTypeForOp(DType::F4, "elu").bt()), + Self::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E8M0, "elu").bt()), } } @@ -2100,10 +2441,26 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, B::u32); Ok(Self::U32(data)) } + Self::I16(storage) => { + let data = unary_map(storage, layout, B::i16); + Ok(Self::I16(data)) + } + Self::I32(storage) => { + let data = unary_map(storage, layout, B::i32); + Ok(Self::I32(data)) + } Self::I64(storage) => { let data = unary_map(storage, layout, B::i64); Ok(Self::I64(data)) } + Self::F8E4M3(storage) => { + let data = unary_map(storage, layout, B::f8e4m3); + Ok(Self::F8E4M3(data)) + } + Self::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E2M3, "unary").bt()), + Self::F6E3M2(_) => 
Err(Error::UnsupportedDTypeForOp(DType::F6E3M2, "unary").bt()), + Self::F4(_) => Err(Error::UnsupportedDTypeForOp(DType::F4, "unary").bt()), + Self::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E8M0, "unary").bt()), } } @@ -2154,6 +2511,14 @@ impl BackendStorage for CpuStorage { }; Ok(Self::U32(data)) } + (Self::I16(lhs), Self::I16(rhs)) => { + let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::i16); + Ok(Self::I16(data)) + } + (Self::I32(lhs), Self::I32(rhs)) => { + let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::i32); + Ok(Self::I32(data)) + } (Self::I64(lhs), Self::I64(rhs)) => { let data = if B::I64_VEC { binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i64, B::i64_vec) @@ -2170,6 +2535,10 @@ impl BackendStorage for CpuStorage { }; Ok(Self::U8(data)) } + (Self::F8E4M3(lhs), Self::F8E4M3(rhs)) => { + let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f8e4m3); + Ok(Self::F8E4M3(data)) + } _ => { // This should be covered by the dtype check above. Err(Error::DTypeMismatchBinaryOp { @@ -2197,6 +2566,12 @@ impl BackendStorage for CpuStorage { (Self::U32(src), Self::U32(dst)) => { copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) } + (Self::I16(src), Self::I16(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (Self::I32(src), Self::I32(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } (Self::I64(src), Self::I64(dst)) => { copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) } @@ -2212,6 +2587,19 @@ impl BackendStorage for CpuStorage { (Self::F64(src), Self::F64(dst)) => { copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) } + (Self::F8E4M3(src), Self::F8E4M3(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (Self::F6E2M3(src), Self::F6E2M3(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (Self::F6E3M2(src), Self::F6E3M2(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (Self::F4(src), Self::F4(dst)) => copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o), + (Self::F8E8M0(src), Self::F8E8M0(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } (_, dst) => { return Err(Error::DTypeMismatchBinaryOp { lhs: self.dtype(), @@ -2228,11 +2616,26 @@ impl BackendStorage for CpuStorage { match (self, dst) { (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::I16(src), Self::I16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::I32(src), Self::I32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::F64(src), Self::F64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::F8E4M3(src), Self::F8E4M3(dst)) => { + copy_strided_src_(src, dst, dst_offset, src_l) + } + (Self::F6E2M3(src), Self::F6E2M3(dst)) => { + copy_strided_src_(src, dst, dst_offset, src_l) + } + (Self::F6E3M2(src), Self::F6E3M2(dst)) => { + copy_strided_src_(src, dst, dst_offset, src_l) + } + (Self::F4(src), Self::F4(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::F8E8M0(src), Self::F8E8M0(dst)) => { + copy_strided_src_(src, dst, dst_offset, src_l) + } (_, dst) => { // This should be covered by 
the dtype check above. return Err(Error::DTypeMismatchBinaryOp { @@ -2257,6 +2660,8 @@ impl BackendStorage for CpuStorage { match self { Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l), Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l), + Self::I16(pred) => WCond(pred, layout).map(t, t_l, f, f_l), + Self::I32(pred) => WCond(pred, layout).map(t, t_l, f, f_l), Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")), } @@ -2469,6 +2874,8 @@ impl BackendStorage for CpuStorage { match ids { Self::U8(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l), Self::U32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l), + Self::I16(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l), + Self::I32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l), Self::I64(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()), } @@ -2498,6 +2905,20 @@ impl BackendStorage for CpuStorage { }; IndexAdd { ids, dim }.map(self, l, src, src_l) } + Self::I16(ids) => { + let ids = match ids_l.contiguous_offsets() { + Some((a, b)) => &ids[a..b], + None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, + }; + IndexAdd { ids, dim }.map(self, l, src, src_l) + } + Self::I32(ids) => { + let ids = match ids_l.contiguous_offsets() { + Some((a, b)) => &ids[a..b], + None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, + }; + IndexAdd { ids, dim }.map(self, l, src, src_l) + } Self::I64(ids) => { let ids = match ids_l.contiguous_offsets() { Some((a, b)) => &ids[a..b], @@ -2563,7 +2984,23 @@ impl BackendStorage for CpuStorage { (Self::F64(storage), Scalar::F64(v)) => set(storage, l, v), (Self::U8(storage), Scalar::U8(v)) => set(storage, l, v), (Self::U32(storage), Scalar::U32(v)) => set(storage, l, v), + (Self::I16(storage), Scalar::I16(v)) => set(storage, l, v), + (Self::I32(storage), Scalar::I32(v)) => set(storage, l, v), (Self::I64(storage), Scalar::I64(v)) => set(storage, l, v), + (Self::F8E4M3(storage), Scalar::F8E4M3(v)) => set(storage, l, v), + // Dummy types don't support scalar operations + (Self::F6E2M3(_), _) => { + crate::bail!("const_set not supported for dummy type F6E2M3") + } + (Self::F6E3M2(_), _) => { + crate::bail!("const_set not supported for dummy type F6E3M2") + } + (Self::F4(_), _) => { + crate::bail!("const_set not supported for dummy type F4") + } + (Self::F8E8M0(_), _) => { + crate::bail!("const_set not supported for dummy type F8E8M0") + } (st, s) => crate::bail!( "const_set dtype mismatch, expected {:?} but got {:?}", st.dtype(), @@ -2605,15 +3042,26 @@ impl BackendDevice for CpuDevice { crate::bail!("cannot seed the CPU rng with set_seed") } + fn get_current_seed(&self) -> Result { + crate::bail!("cannot get the CPU rng seed with get_current_seed") + } + fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result { use rand::prelude::*; let elem_count = shape.elem_count(); let mut rng = rand::rng(); match dtype { - DType::U8 | DType::U32 | DType::I64 => { - Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()) - } + DType::U8 + | DType::U32 + | DType::I16 + | DType::I32 + | DType::I64 + | DType::F8E4M3 + | DType::F6E2M3 + | DType::F6E3M2 + | DType::F4 + | DType::F8E8M0 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()), DType::BF16 => { let mut data = 
Vec::with_capacity(elem_count); let uniform = rand::distr::Uniform::new(bf16::from_f64(min), bf16::from_f64(max)) @@ -2658,9 +3106,16 @@ impl BackendDevice for CpuDevice { let elem_count = shape.elem_count(); let mut rng = rand::rng(); match dtype { - DType::U8 | DType::U32 | DType::I64 => { - Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()) - } + DType::U8 + | DType::U32 + | DType::I16 + | DType::I32 + | DType::I64 + | DType::F8E4M3 + | DType::F6E2M3 + | DType::F6E3M2 + | DType::F4 + | DType::F8E8M0 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()), DType::BF16 => { let mut data = Vec::with_capacity(elem_count); let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std)) @@ -2717,6 +3172,16 @@ impl BackendDevice for CpuDevice { v.set_len(elem_count); CpuStorage::U32(v) } + DType::I16 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::I16(v) + } + DType::I32 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::I32(v) + } DType::I64 => { let mut v = Vec::with_capacity(elem_count); v.set_len(elem_count); @@ -2742,6 +3207,14 @@ impl BackendDevice for CpuDevice { v.set_len(elem_count); CpuStorage::F64(v) } + DType::F8E4M3 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::F8E4M3(v) + } + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + return Err(Error::UnsupportedDTypeForOp(dtype, "alloc_uninit").bt()) + } }; Ok(storage) } @@ -2751,11 +3224,17 @@ impl BackendDevice for CpuDevice { let storage = match dtype { DType::U8 => CpuStorage::U8(vec![0u8; elem_count]), DType::U32 => CpuStorage::U32(vec![0u32; elem_count]), + DType::I16 => CpuStorage::I16(vec![0i16; elem_count]), + DType::I32 => CpuStorage::I32(vec![0i32; elem_count]), DType::I64 => CpuStorage::I64(vec![0i64; elem_count]), DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]), DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]), DType::F32 => CpuStorage::F32(vec![0f32; elem_count]), DType::F64 => CpuStorage::F64(vec![0f64; elem_count]), + DType::F8E4M3 => CpuStorage::F8E4M3(vec![f8e4m3::ZERO; elem_count]), + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + return Err(Error::UnsupportedDTypeForOp(dtype, "zeros").bt()) + } }; Ok(storage) } diff --git a/candle-core/src/cpu_backend/utils.rs b/candle-core/src/cpu_backend/utils.rs index c404c3ad99..1f800a928b 100644 --- a/candle-core/src/cpu_backend/utils.rs +++ b/candle-core/src/cpu_backend/utils.rs @@ -10,11 +10,19 @@ pub trait Map1 { match vs { C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)), C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)), + C::I16(vs) => Ok(C::I16(self.f(vs, layout)?)), + C::I32(vs) => Ok(C::I32(self.f(vs, layout)?)), C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)), C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)), C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)), C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)), C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)), + C::F8E4M3(vs) => Ok(C::F8E4M3(self.f(vs, layout)?)), + // Dummy types don't support Map1 operations + C::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1").bt()), + C::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1").bt()), + C::F4(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1").bt()), + C::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1").bt()), } } } @@ -26,11 +34,19 @@ pub trait Map1Any { match vs { C::U8(vs) => Ok(self.f(vs, layout, C::U8)?), C::U32(vs) => 
Ok(self.f(vs, layout, C::U32)?), + C::I16(vs) => Ok(self.f(vs, layout, C::I16)?), + C::I32(vs) => Ok(self.f(vs, layout, C::I32)?), C::I64(vs) => Ok(self.f(vs, layout, C::I64)?), C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?), C::F16(vs) => Ok(self.f(vs, layout, C::F16)?), C::F32(vs) => Ok(self.f(vs, layout, C::F32)?), C::F64(vs) => Ok(self.f(vs, layout, C::F64)?), + C::F8E4M3(vs) => Ok(self.f(vs, layout, C::F8E4M3)?), + // Dummy types don't support Map1Any operations + C::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1any").bt()), + C::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1any").bt()), + C::F4(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1any").bt()), + C::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1any").bt()), } } } @@ -43,11 +59,14 @@ pub trait Map2 { match (v1, v2) { (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)), + (C::I16(v1), C::I16(v2)) => Ok(C::I16(self.f(v1, l1, v2, l2)?)), + (C::I32(v1), C::I32(v2)) => Ok(C::I32(self.f(v1, l1, v2, l2)?)), (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)), (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)), (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)), (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)), (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)), + (C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::F8E4M3(self.f(v1, l1, v2, l2)?)), _ => Err(Error::DTypeMismatchBinaryOp { lhs: v1.dtype(), rhs: v2.dtype(), @@ -66,11 +85,14 @@ pub trait Map2InPlace { match (v1, v2) { (C::U8(v1), C::U8(v2)) => self.f(v1, l1, v2, l2)?, (C::U32(v1), C::U32(v2)) => self.f(v1, l1, v2, l2)?, + (C::I16(v1), C::I16(v2)) => self.f(v1, l1, v2, l2)?, + (C::I32(v1), C::I32(v2)) => self.f(v1, l1, v2, l2)?, (C::I64(v1), C::I64(v2)) => self.f(v1, l1, v2, l2)?, (C::BF16(v1), C::BF16(v2)) => self.f(v1, l1, v2, l2)?, (C::F16(v1), C::F16(v2)) => self.f(v1, l1, v2, l2)?, (C::F32(v1), C::F32(v2)) => self.f(v1, l1, v2, l2)?, (C::F64(v1), C::F64(v2)) => self.f(v1, l1, v2, l2)?, + (C::F8E4M3(v1), C::F8E4M3(v2)) => self.f(v1, l1, v2, l2)?, (v1, v2) => Err(Error::DTypeMismatchBinaryOp { lhs: v1.dtype(), rhs: v2.dtype(), @@ -90,11 +112,14 @@ pub trait Map2U8 { match (v1, v2) { (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::I16(v1), C::I16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::I32(v1), C::I32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), _ => Err(Error::DTypeMismatchBinaryOp { lhs: v1.dtype(), rhs: v2.dtype(), diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs index ba3267e03a..b3526ed7e5 100644 --- a/candle-core/src/cuda_backend/device.rs +++ b/candle-core/src/cuda_backend/device.rs @@ -1,11 +1,11 @@ -use crate::backend::BackendDevice; +use crate::backend::{BackendDevice, BackendStorage}; use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape}; pub use candle_kernels as kernels; pub use cudarc; use cudarc::driver::CudaFunction; use 
half::{bf16, f16}; use std::collections::HashMap; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, RwLock}; use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr}; @@ -38,6 +38,7 @@ pub struct CudaDevice { stream: Arc, pub(crate) blas: Arc, curand: Arc>, + seed_value: Arc>, } impl std::fmt::Debug for CudaDevice { @@ -93,6 +94,18 @@ impl CudaDevice { self.stream.memcpy_dtod(src, dst).w() } + pub fn memcpy_dtoh< + T: cudarc::driver::DeviceRepr, + Src: cudarc::driver::DevicePtr, + Dst: cudarc::driver::HostSlice, + >( + &self, + src: &Src, + dst: &mut Dst, + ) -> Result<()> { + self.stream.memcpy_dtoh(src, dst).w() + } + pub fn memcpy_stod< T: cudarc::driver::DeviceRepr, Src: cudarc::driver::HostSlice + ?Sized, @@ -232,6 +245,10 @@ impl CudaDevice { stream: self.stream.clone(), }) } + + pub fn cublas_handle(&self) -> Arc { + self.blas.clone() + } } impl CudaDevice { @@ -251,6 +268,7 @@ impl CudaDevice { curand: Arc::new(Mutex::new(CudaRng(curand))), modules: Arc::new(std::sync::RwLock::new(module_store)), custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())), + seed_value: Arc::new(RwLock::new(299792458)), }) } } @@ -274,6 +292,7 @@ impl BackendDevice for CudaDevice { curand: Arc::new(Mutex::new(CudaRng(curand))), modules: Arc::new(std::sync::RwLock::new(module_store)), custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())), + seed_value: Arc::new(RwLock::new(299792458)), }) } @@ -282,9 +301,14 @@ impl BackendDevice for CudaDevice { // state will be identical and the same random numbers will be generated. let mut curand = self.curand.lock().unwrap(); curand.0 = cudarc::curand::CudaRng::new(seed, self.stream.clone()).w()?; + *self.seed_value.write().unwrap() = seed; Ok(()) } + fn get_current_seed(&self) -> Result { + Ok(*self.seed_value.read().unwrap()) + } + fn location(&self) -> crate::DeviceLocation { crate::DeviceLocation::Cuda { gpu_id: self.context.ordinal(), @@ -306,6 +330,14 @@ impl BackendDevice for CudaDevice { let data = self.alloc_zeros::(elem_count)?; CudaStorageSlice::U32(data) } + DType::I16 => { + let data = self.alloc_zeros::(elem_count)?; + CudaStorageSlice::I16(data) + } + DType::I32 => { + let data = self.alloc_zeros::(elem_count)?; + CudaStorageSlice::I32(data) + } DType::I64 => { let data = self.alloc_zeros::(elem_count)?; CudaStorageSlice::I64(data) @@ -326,6 +358,14 @@ impl BackendDevice for CudaDevice { let data = self.alloc_zeros::(elem_count)?; CudaStorageSlice::F64(data) } + DType::F8E4M3 => { + return Err(CudaError::InternalError("F8E4M3 not supported in CUDA backend").into()) + } + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + return Err( + CudaError::InternalError("Dummy types not supported in CUDA backend").into(), + ) + } }; Ok(CudaStorage { slice, @@ -339,13 +379,17 @@ impl BackendDevice for CudaDevice { let slice = match dtype { // TODO: Add support for F16 and BF16 though this is likely to require some upstream // cudarc changes. - DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => { - Err(CudaError::UnsupportedDtype { - dtype, - op: "rand_uniform", - }) - .w()? - } + DType::U8 + | DType::U32 + | DType::I16 + | DType::I32 + | DType::I64 + | DType::F16 + | DType::BF16 => Err(CudaError::UnsupportedDtype { + dtype, + op: "rand_uniform", + }) + .w()?, DType::F32 => { let mut data = unsafe { self.alloc::(elem_count)? 
}; curand.0.fill_with_uniform(&mut data).w()?; @@ -356,6 +400,13 @@ impl BackendDevice for CudaDevice { curand.0.fill_with_uniform(&mut data).w()?; CudaStorageSlice::F64(data) } + DType::F8E4M3 | DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + Err(CudaError::UnsupportedDtype { + dtype, + op: "rand_uniform", + }) + .w()? + } }; let slice = if lo == 0. && up == 1.0 { slice @@ -383,13 +434,17 @@ impl BackendDevice for CudaDevice { elem_count }; let slice = match dtype { - DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => { - Err(CudaError::UnsupportedDtype { - dtype, - op: "rand_normal", - }) - .w()? - } + DType::U8 + | DType::U32 + | DType::I16 + | DType::I32 + | DType::I64 + | DType::F16 + | DType::BF16 => Err(CudaError::UnsupportedDtype { + dtype, + op: "rand_normal", + }) + .w()?, DType::F32 => { let mut data = unsafe { self.alloc::(elem_count_round)? }; curand @@ -403,6 +458,13 @@ impl BackendDevice for CudaDevice { curand.0.fill_with_normal(&mut data, mean, std).w()?; CudaStorageSlice::F64(data) } + DType::F8E4M3 | DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + Err(CudaError::UnsupportedDtype { + dtype, + op: "rand_normal", + }) + .w()? + } }; Ok(CudaStorage { slice, @@ -421,6 +483,14 @@ impl BackendDevice for CudaDevice { let data = self.alloc::(elem_count)?; CudaStorageSlice::U32(data) } + DType::I16 => { + let data = self.alloc::(elem_count)?; + CudaStorageSlice::I16(data) + } + DType::I32 => { + let data = self.alloc::(elem_count)?; + CudaStorageSlice::I32(data) + } DType::I64 => { let data = self.alloc::(elem_count)?; CudaStorageSlice::I64(data) @@ -441,6 +511,14 @@ impl BackendDevice for CudaDevice { let data = self.alloc::(elem_count)?; CudaStorageSlice::F64(data) } + DType::F8E4M3 => { + return Err(CudaError::InternalError("F8E4M3 not supported in CUDA backend").into()) + } + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + return Err( + CudaError::InternalError("Dummy types not supported in CUDA backend").into(), + ) + } }; Ok(CudaStorage { slice, @@ -458,6 +536,14 @@ impl BackendDevice for CudaDevice { let data = self.memcpy_stod(storage)?; CudaStorageSlice::U32(data) } + CpuStorageRef::I16(storage) => { + let data = self.memcpy_stod(storage)?; + CudaStorageSlice::I16(data) + } + CpuStorageRef::I32(storage) => { + let data = self.memcpy_stod(storage)?; + CudaStorageSlice::I32(data) + } CpuStorageRef::I64(storage) => { let data = self.memcpy_stod(storage)?; CudaStorageSlice::I64(data) @@ -478,6 +564,20 @@ impl BackendDevice for CudaDevice { let data = self.memcpy_stod(storage)?; CudaStorageSlice::F64(data) } + CpuStorageRef::F8E4M3(storage) => { + let data = self.memcpy_stod(storage)?; + CudaStorageSlice::F8E4M3(data) + } + CpuStorageRef::F4(_) + | CpuStorageRef::F6E2M3(_) + | CpuStorageRef::F6E3M2(_) + | CpuStorageRef::F8E8M0(_) => { + return Err(CudaError::UnsupportedDtype { + dtype: T::DTYPE, + op: "storage_from_slice", + } + .into()); + } }; Ok(CudaStorage { slice, @@ -495,6 +595,14 @@ impl BackendDevice for CudaDevice { let data = self.memcpy_stod(storage)?; CudaStorageSlice::U32(data) } + CpuStorage::I16(storage) => { + let data = self.memcpy_stod(storage)?; + CudaStorageSlice::I16(data) + } + CpuStorage::I32(storage) => { + let data = self.memcpy_stod(storage)?; + CudaStorageSlice::I32(data) + } CpuStorage::I64(storage) => { let data = self.memcpy_stod(storage)?; CudaStorageSlice::I64(data) @@ -515,6 +623,20 @@ impl BackendDevice for CudaDevice { let data = self.memcpy_stod(storage)?; 
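These constructors move host buffers onto the device with `memcpy_stod`; the same change also exposes the reverse copy as a public `CudaDevice::memcpy_dtoh` helper in the device.rs hunk earlier. A hedged round-trip sketch, assuming the `cuda` feature is enabled and that `Vec<f32>` satisfies cudarc's `HostSlice` bound as the slice-based copies here do:

```rust
// Hedged sketch: round-tripping a host buffer through the CUDA device using the
// helpers touched in this diff (`memcpy_stod` up, the newly public `memcpy_dtoh`
// back down). Requires the `cuda` feature.
use candle_core::backend::BackendDevice;
use candle_core::cuda_backend::CudaDevice;
use candle_core::Result;

fn roundtrip() -> Result<Vec<f32>> {
    let dev = CudaDevice::new(0)?;
    let host = vec![1f32, 2.0, 3.0, 4.0];

    // Host -> device: returns a CudaSlice<f32> bound to this device's stream.
    let dslice = dev.memcpy_stod(host.as_slice())?;

    // Device -> host via the new helper; the destination must be pre-sized.
    let mut out = vec![0f32; host.len()];
    dev.memcpy_dtoh(&dslice, &mut out)?;
    Ok(out)
}
```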
CudaStorageSlice::F64(data) } + CpuStorage::F8E4M3(storage) => { + let data = self.memcpy_stod(storage)?; + CudaStorageSlice::F8E4M3(data) + } + CpuStorage::F4(_) + | CpuStorage::F6E2M3(_) + | CpuStorage::F6E3M2(_) + | CpuStorage::F8E8M0(_) => { + return Err(CudaError::UnsupportedDtype { + dtype: storage.dtype(), + op: "storage_from_cpu_storage", + } + .into()); + } }; Ok(CudaStorage { slice, @@ -532,6 +654,14 @@ impl BackendDevice for CudaDevice { let data = self.memcpy_stod(&storage)?; CudaStorageSlice::U32(data) } + CpuStorage::I16(storage) => { + let data = self.memcpy_stod(&storage)?; + CudaStorageSlice::I16(data) + } + CpuStorage::I32(storage) => { + let data = self.memcpy_stod(&storage)?; + CudaStorageSlice::I32(data) + } CpuStorage::I64(storage) => { let data = self.memcpy_stod(&storage)?; CudaStorageSlice::I64(data) @@ -552,6 +682,20 @@ impl BackendDevice for CudaDevice { let data = self.memcpy_stod(&storage)?; CudaStorageSlice::F64(data) } + CpuStorage::F8E4M3(storage) => { + let data = self.memcpy_stod(&storage)?; + CudaStorageSlice::F8E4M3(data) + } + CpuStorage::F4(_) + | CpuStorage::F6E2M3(_) + | CpuStorage::F6E3M2(_) + | CpuStorage::F8E8M0(_) => { + return Err(CudaError::UnsupportedDtype { + dtype: storage.dtype(), + op: "storage_from_cpu_storage_owned", + } + .into()); + } }; Ok(CudaStorage { slice, diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 95987ba033..51edd5de44 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -40,11 +40,14 @@ impl crate::scalar::Scalar { match self { Scalar::U8(v) => builder.arg(v), Scalar::U32(v) => builder.arg(v), + Scalar::I16(v) => builder.arg(v), + Scalar::I32(v) => builder.arg(v), Scalar::I64(v) => builder.arg(v), Scalar::F32(v) => builder.arg(v), Scalar::F64(v) => builder.arg(v), Scalar::F16(v) => builder.arg(v), Scalar::BF16(v) => builder.arg(v), + Scalar::F8E4M3(v) => builder.arg(v), }; } } @@ -64,11 +67,19 @@ impl SlicePtrOrNull { pub enum CudaStorageSlice { U8(CudaSlice), U32(CudaSlice), + I16(CudaSlice), + I32(CudaSlice), I64(CudaSlice), BF16(CudaSlice), F16(CudaSlice), F32(CudaSlice), F64(CudaSlice), + F8E4M3(CudaSlice), + // Dummy types that store raw bytes + F6E2M3(CudaSlice), + F6E3M2(CudaSlice), + F4(CudaSlice), + F8E8M0(CudaSlice), } struct Clone; @@ -1173,11 +1184,14 @@ macro_rules! 
cuda_dtype { } cuda_dtype!(u8, U8); cuda_dtype!(u32, U32); +cuda_dtype!(i16, I16); +cuda_dtype!(i32, I32); cuda_dtype!(i64, I64); cuda_dtype!(f16, F16); cuda_dtype!(bf16, BF16); cuda_dtype!(f32, F32); cuda_dtype!(f64, F64); +cuda_dtype!(float8::F8E4M3, F8E4M3); impl CudaStorage { pub fn wrap_cuda_slice(slice: CudaSlice, device: CudaDevice) -> CudaStorage { @@ -1298,11 +1312,18 @@ impl BackendStorage for CudaStorage { match self.slice { CudaStorageSlice::U8(_) => DType::U8, CudaStorageSlice::U32(_) => DType::U32, + CudaStorageSlice::I16(_) => DType::I16, + CudaStorageSlice::I32(_) => DType::I32, CudaStorageSlice::I64(_) => DType::I64, CudaStorageSlice::BF16(_) => DType::BF16, CudaStorageSlice::F16(_) => DType::F16, CudaStorageSlice::F32(_) => DType::F32, CudaStorageSlice::F64(_) => DType::F64, + CudaStorageSlice::F8E4M3(_) => DType::F8E4M3, + CudaStorageSlice::F6E2M3(_) => DType::F6E2M3, + CudaStorageSlice::F6E3M2(_) => DType::F6E3M2, + CudaStorageSlice::F4(_) => DType::F4, + CudaStorageSlice::F8E8M0(_) => DType::F8E8M0, } } @@ -1321,11 +1342,21 @@ impl BackendStorage for CudaStorage { let ((src, _guard_src), kernel_name) = match &mut self.slice { S::U8(s) => (slice_ptr(s, src_o), "const_set_u8"), S::U32(s) => (slice_ptr(s, src_o), "const_set_u32"), + S::I16(s) => (slice_ptr(s, src_o), "const_set_i16"), + S::I32(s) => (slice_ptr(s, src_o), "const_set_i32"), S::I64(s) => (slice_ptr(s, src_o), "const_set_i64"), S::BF16(s) => (slice_ptr(s, src_o), "const_set_bf16"), S::F16(s) => (slice_ptr(s, src_o), "const_set_f16"), S::F32(s) => (slice_ptr(s, src_o), "const_set_f32"), S::F64(s) => (slice_ptr(s, src_o), "const_set_f64"), + S::F8E4M3(s) => (slice_ptr(s, src_o), "const_set_f8e4m3"), + S::F4(_) | S::F6E2M3(_) | S::F6E3M2(_) | S::F8E8M0(_) => { + return Err(CudaError::UnsupportedDtype { + dtype: self.dtype(), + op: "const_set", + } + .into()); + } }; let func = dev.get_or_load_func(kernel_name, &kernels::FILL)?; @@ -1354,11 +1385,24 @@ impl BackendStorage for CudaStorage { let (inp, _guard) = match &self.slice { CudaStorageSlice::U8(inp) => slice_ptr(inp, start_o), CudaStorageSlice::U32(inp) => slice_ptr(inp, start_o), + CudaStorageSlice::I16(inp) => slice_ptr(inp, start_o), + CudaStorageSlice::I32(inp) => slice_ptr(inp, start_o), CudaStorageSlice::I64(inp) => slice_ptr(inp, start_o), CudaStorageSlice::BF16(inp) => slice_ptr(inp, start_o), CudaStorageSlice::F16(inp) => slice_ptr(inp, start_o), CudaStorageSlice::F32(inp) => slice_ptr(inp, start_o), CudaStorageSlice::F64(inp) => slice_ptr(inp, start_o), + CudaStorageSlice::F8E4M3(inp) => slice_ptr(inp, start_o), + CudaStorageSlice::F4(_) + | CudaStorageSlice::F6E2M3(_) + | CudaStorageSlice::F6E3M2(_) + | CudaStorageSlice::F8E8M0(_) => { + return Err(CudaError::UnsupportedDtype { + dtype: self.dtype(), + op: "to_dtype", + } + .into()); + } }; let inp = &inp; @@ -1442,6 +1486,25 @@ impl BackendStorage for CudaStorage { unsafe { builder.launch(cfg) }.w()?; CudaStorageSlice::F64(out) } + DType::F8E4M3 => { + let out = unsafe { dev.alloc::(el)? 
}; + let mut builder = func.builder(); + barg!(builder, el); + barg!(builder, dims.len()); + ds.builder_arg(&mut builder); + barg!(builder, *inp); + builder.arg(&out); + unsafe { builder.launch(cfg) }.w()?; + CudaStorageSlice::F8E4M3(out) + } + DType::I16 | DType::I32 => { + return Err(CudaError::InternalError("i16,i32 dtypes are not supported").into()) + } + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + return Err( + CudaError::InternalError("Dummy types not supported in CUDA backend").into(), + ) + } }; Ok(Self { slice, @@ -1506,6 +1569,14 @@ impl BackendStorage for CudaStorage { let cpu_storage = slice.stream().memcpy_dtov(slice).w()?; Ok(CpuStorage::U32(cpu_storage)) } + CudaStorageSlice::I16(slice) => { + let cpu_storage = slice.stream().memcpy_dtov(slice).w()?; + Ok(CpuStorage::I16(cpu_storage)) + } + CudaStorageSlice::I32(slice) => { + let cpu_storage = slice.stream().memcpy_dtov(slice).w()?; + Ok(CpuStorage::I32(cpu_storage)) + } CudaStorageSlice::I64(slice) => { let cpu_storage = slice.stream().memcpy_dtov(slice).w()?; Ok(CpuStorage::I64(cpu_storage)) @@ -1526,6 +1597,18 @@ impl BackendStorage for CudaStorage { let cpu_storage = slice.stream().memcpy_dtov(slice).w()?; Ok(CpuStorage::F64(cpu_storage)) } + CudaStorageSlice::F8E4M3(slice) => { + let cpu_storage = slice.stream().memcpy_dtov(slice).w()?; + Ok(CpuStorage::F8E4M3(cpu_storage)) + } + CudaStorageSlice::F4(_) + | CudaStorageSlice::F6E2M3(_) + | CudaStorageSlice::F6E3M2(_) + | CudaStorageSlice::F8E8M0(_) => Err(CudaError::UnsupportedDtype { + dtype: self.dtype(), + op: "to_cpu_storage", + } + .into()), } } @@ -1653,7 +1736,12 @@ impl BackendStorage for CudaStorage { S::F64(out) } (S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv1d does not support u32"))?, + (S::I16(_), S::I16(_)) => Err(CudaError::InternalError("conv1d does not support i16"))?, + (S::I32(_), S::I32(_)) => Err(CudaError::InternalError("conv1d does not support i32"))?, (S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv1d does not support i64"))?, + (S::F8E4M3(_), S::F8E4M3(_)) => { + Err(CudaError::InternalError("conv1d does not support f8e4m3"))? + } _ => Err(CudaError::InternalError("dtype mismatch in conv1d"))?, }; Ok(Self { slice, device }) @@ -1833,7 +1921,12 @@ impl BackendStorage for CudaStorage { S::F64(out) } (S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv2d does not support u32"))?, + (S::I16(_), S::I16(_)) => Err(CudaError::InternalError("conv2d does not support i16"))?, + (S::I32(_), S::I32(_)) => Err(CudaError::InternalError("conv2d does not support i32"))?, (S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv2d does not support i64"))?, + (S::F8E4M3(_), S::F8E4M3(_)) => { + Err(CudaError::InternalError("conv2d does not support f8e4m3"))? 
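As the arm just above shows, the CUDA conv kernels reject `f8e4m3` (and the new integer dtypes), so fp8-stored weights have to be upcast before convolving. A sketch under the assumption that the matching f32/f8e4m3 cast kernels ship in candle-kernels alongside the `to_dtype` arm added earlier in this file:

```rust
// Hedged sketch: conv does not accept F8E4M3, so weights stored in fp8 are
// upcast to f32 right before the convolution. Shapes follow candle's
// (batch, channels, length) conv1d layout; assumes the f32<->f8e4m3 cast
// kernels exist for the chosen device.
use candle_core::{DType, Device, Result, Tensor};

fn conv_with_fp8_weights(dev: &Device) -> Result<Tensor> {
    let x = Tensor::randn(0f32, 1.0, (1, 4, 16), dev)?;
    let w8 = Tensor::randn(0f32, 1.0, (8, 4, 3), dev)?.to_dtype(DType::F8E4M3)?;

    // Upcast the stored weights to f32, then convolve in f32.
    let w = w8.to_dtype(DType::F32)?;
    x.conv1d(&w, /*padding=*/ 1, /*stride=*/ 1, /*dilation=*/ 1, /*groups=*/ 1)
}
```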
+ } _ => Err(CudaError::InternalError("dtype mismatch in conv2d"))?, }; Ok(Self { slice, device }) @@ -2017,11 +2110,16 @@ impl BackendStorage for CudaStorage { let ((src, _guard_src), (dst, _guard_dst), kname) = match (&self.slice, &mut dst.slice) { (S::U8(s), S::U8(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_u8"), (S::U32(s), S::U32(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_u32"), + (S::I16(s), S::I16(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_i16"), + (S::I32(s), S::I32(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_i32"), (S::I64(s), S::I64(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_i64"), (S::BF16(s), S::BF16(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_bf16"), (S::F16(s), S::F16(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_f16"), (S::F32(s), S::F32(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_f32"), (S::F64(s), S::F64(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_f64"), + (S::F8E4M3(_s), S::F8E4M3(_d)) => { + Err(CudaError::InternalError("copy2d not supported for f8e4m3"))? + } _ => Err(CudaError::InternalError("dtype mismatch in copy2d"))?, }; let func = dev.get_or_load_func(kname, &kernels::FILL)?; @@ -2129,6 +2227,38 @@ impl BackendStorage for CudaStorage { unsafe { builder.launch(cfg) }.w()?; } } + (CudaStorageSlice::I16(src), CudaStorageSlice::I16(dst)) => { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { + dev.memcpy_dtod(&src, &mut dst)? + } else { + let func = dev.get_or_load_func("ucopy_i16", &kernels::UNARY)?; + let mut builder = func.builder(); + barg!(builder, el_count); + barg!(builder, dims.len()); + ds.builder_arg(&mut builder); + builder.arg(&src); + builder.arg(&mut dst); + // SAFETY: ffi. + unsafe { builder.launch(cfg) }.w()?; + } + } + (CudaStorageSlice::I32(src), CudaStorageSlice::I32(dst)) => { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { + dev.memcpy_dtod(&src, &mut dst)? + } else { + let func = dev.get_or_load_func("ucopy_i32", &kernels::UNARY)?; + let mut builder = func.builder(); + barg!(builder, el_count); + barg!(builder, dims.len()); + ds.builder_arg(&mut builder); + builder.arg(&src); + builder.arg(&mut dst); + // SAFETY: ffi. + unsafe { builder.launch(cfg) }.w()?; + } + } (CudaStorageSlice::I64(src), CudaStorageSlice::I64(dst)) => { let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); if src_l.is_contiguous() { @@ -2161,6 +2291,22 @@ impl BackendStorage for CudaStorage { unsafe { builder.launch(cfg) }.w()?; } } + (CudaStorageSlice::F8E4M3(src), CudaStorageSlice::F8E4M3(dst)) => { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { + dev.memcpy_dtod(&src, &mut dst)? + } else { + let func = dev.get_or_load_func("ucopy_f8e4m3", &kernels::UNARY)?; + let mut builder = func.builder(); + barg!(builder, el_count); + barg!(builder, dims.len()); + ds.builder_arg(&mut builder); + builder.arg(&src); + builder.arg(&mut dst); + // SAFETY: ffi. 
+ unsafe { builder.launch(cfg) }.w()?; + } + } _ => Err(CudaError::InternalError( "dtype mismatch in copy_strided op", ))?, diff --git a/candle-core/src/cuda_backend/utils.rs b/candle-core/src/cuda_backend/utils.rs index 0a81f0ac7f..014b9e6c39 100644 --- a/candle-core/src/cuda_backend/utils.rs +++ b/candle-core/src/cuda_backend/utils.rs @@ -19,11 +19,16 @@ pub trait Map1 { let out = match s { S::U8(s) => S::U8(self.f(s, d, l)?), S::U32(s) => S::U32(self.f(s, d, l)?), + S::I16(s) => S::I16(self.f(s, d, l)?), + S::I32(s) => S::I32(self.f(s, d, l)?), S::I64(s) => S::I64(self.f(s, d, l)?), S::BF16(s) => S::BF16(self.f(s, d, l)?), S::F16(s) => S::F16(self.f(s, d, l)?), S::F32(s) => S::F32(self.f(s, d, l)?), S::F64(s) => S::F64(self.f(s, d, l)?), + S::F8E4M3(_) | S::F4(_) | S::F6E2M3(_) | S::F6E3M2(_) | S::F8E8M0(_) => { + crate::bail!("Map1 does not uspport this dtype."); + } }; Ok(out) } @@ -43,11 +48,16 @@ pub trait Map2 { let out = match (s1, s2) { (S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?), (S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?), + (S::I16(s1), S::I16(s2)) => S::I16(self.f(s1, l1, s2, l2, d)?), + (S::I32(s1), S::I32(s2)) => S::I32(self.f(s1, l1, s2, l2, d)?), (S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?), (S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?), (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?), (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?), (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?), + (S::F8E4M3(_), S::F8E4M3(_)) => { + Err(CudaError::InternalError("Map2 not supported for F8E4M3"))? + } _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, }; Ok(out) @@ -113,11 +123,16 @@ pub trait Map2InPlace { match (dst, src) { (S::U8(dst), S::U8(src)) => self.f(dst, dst_l, src, src_l, d), (S::U32(dst), S::U32(src)) => self.f(dst, dst_l, src, src_l, d), + (S::I16(dst), S::I16(src)) => self.f(dst, dst_l, src, src_l, d), + (S::I32(dst), S::I32(src)) => self.f(dst, dst_l, src, src_l, d), (S::I64(dst), S::I64(src)) => self.f(dst, dst_l, src, src_l, d), (S::BF16(dst), S::BF16(src)) => self.f(dst, dst_l, src, src_l, d), (S::F16(dst), S::F16(src)) => self.f(dst, dst_l, src, src_l, d), (S::F32(dst), S::F32(src)) => self.f(dst, dst_l, src, src_l, d), (S::F64(dst), S::F64(src)) => self.f(dst, dst_l, src, src_l, d), + (S::F8E4M3(_), S::F8E4M3(_)) => Err(CudaError::InternalError( + "Map2InPlace not supported for F8E4M3", + ))?, _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, } } @@ -136,11 +151,16 @@ pub trait Map1Any { let out = match s { S::U8(s) => self.f(s, d, l, S::U8)?, S::U32(s) => self.f(s, d, l, S::U32)?, + S::I16(s) => self.f(s, d, l, S::I16)?, + S::I32(s) => self.f(s, d, l, S::I32)?, S::I64(s) => self.f(s, d, l, S::I64)?, S::BF16(s) => self.f(s, d, l, S::BF16)?, S::F16(s) => self.f(s, d, l, S::F16)?, S::F32(s) => self.f(s, d, l, S::F32)?, S::F64(s) => self.f(s, d, l, S::F64)?, + S::F8E4M3(_) | S::F4(_) | S::F6E2M3(_) | S::F6E3M2(_) | S::F8E8M0(_) => { + crate::bail!("Map1 does not uspport this dtype."); + } }; Ok(out) } diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 8d0b8b3595..64c57fc80c 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -267,6 +267,15 @@ impl Device { } } + /// Get the current seed for the device RNG. 
+ pub fn get_current_seed(&self) -> Result { + match self { + Self::Cpu => CpuDevice.get_current_seed(), + Self::Cuda(c) => c.get_current_seed(), + Self::Metal(m) => m.get_current_seed(), + } + } + pub fn same_device(&self, rhs: &Self) -> bool { match (self, rhs) { (Self::Cpu, Self::Cpu) => true, diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs index 76d39010a9..a9b53947f3 100644 --- a/candle-core/src/display.rs +++ b/candle-core/src/display.rs @@ -13,10 +13,10 @@ impl Tensor { let device_str = match self.device().location() { crate::DeviceLocation::Cpu => "".to_owned(), crate::DeviceLocation::Cuda { gpu_id } => { - format!(", cuda:{}", gpu_id) + format!(", cuda:{gpu_id}") } crate::DeviceLocation::Metal { gpu_id } => { - format!(", metal:{}", gpu_id) + format!(", metal:{gpu_id}") } }; @@ -56,11 +56,22 @@ impl std::fmt::Debug for Tensor { match self.dtype() { DType::U8 => self.fmt_dt::(f), DType::U32 => self.fmt_dt::(f), + DType::I16 => self.fmt_dt::(f), + DType::I32 => self.fmt_dt::(f), DType::I64 => self.fmt_dt::(f), DType::BF16 => self.fmt_dt::(f), DType::F16 => self.fmt_dt::(f), DType::F32 => self.fmt_dt::(f), DType::F64 => self.fmt_dt::(f), + DType::F8E4M3 => self.fmt_dt::(f), + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + write!( + f, + "Tensor[{:?}; dtype={}, unsupported dummy type]", + self.shape(), + self.dtype().as_str() + ) + } } } } @@ -464,6 +475,18 @@ impl std::fmt::Display for Tensor { tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; writeln!(f)?; } + DType::I16 => { + let tf: IntFormatter = IntFormatter::new(); + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + DType::I32 => { + let tf: IntFormatter = IntFormatter::new(); + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } DType::I64 => { let tf: IntFormatter = IntFormatter::new(); let max_w = tf.max_width(&to_display); @@ -498,15 +521,29 @@ impl std::fmt::Display for Tensor { writeln!(f)?; } } + DType::F8E4M3 => { + if let Ok(tf) = FloatFormatter::::new(&to_display, &po) { + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + } + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + writeln!( + f, + "Dummy type {} (not supported for display)", + self.dtype().as_str() + )?; + } }; let device_str = match self.device().location() { crate::DeviceLocation::Cpu => "".to_owned(), crate::DeviceLocation::Cuda { gpu_id } => { - format!(", cuda:{}", gpu_id) + format!(", cuda:{gpu_id}") } crate::DeviceLocation::Metal { gpu_id } => { - format!(", metal:{}", gpu_id) + format!(", metal:{gpu_id}") } }; diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index b0697c1935..035ca6d503 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -10,6 +10,10 @@ pub enum DType { U8, // Unsigned 32 bits integer. U32, + // Signed 16 bits integer. + I16, + // Signed 32 bits integer. + I32, // Signed 64 bits integer. I64, // Brain floating-point using half precision (16 bits). @@ -20,6 +24,16 @@ pub enum DType { F32, // Floating-point using double precision (64 bits). F64, + // 8-bit floating point with 4-bit exponent and 3-bit mantissa. 
+ F8E4M3, + /// 6-bit float with 2 exponent bits and 3 mantissa bits (MX6 format) + F6E2M3, + /// 6-bit float with 3 exponent bits and 2 mantissa bits (MX6 format) + F6E3M2, + /// 4-bit float (MX4 format) + F4, + /// 8-bit float with 8 exponent bits and 0 mantissa bits + F8E8M0, } #[derive(Debug, PartialEq, Eq)] @@ -39,11 +53,18 @@ impl std::str::FromStr for DType { match s { "u8" => Ok(Self::U8), "u32" => Ok(Self::U32), + "i16" => Ok(Self::I16), + "i32" => Ok(Self::I32), "i64" => Ok(Self::I64), "bf16" => Ok(Self::BF16), "f16" => Ok(Self::F16), "f32" => Ok(Self::F32), "f64" => Ok(Self::F64), + "f8e4m3" => Ok(Self::F8E4M3), + "f6e2m3" => Ok(Self::F6E2M3), + "f6e3m2" => Ok(Self::F6E3M2), + "f4" => Ok(Self::F4), + "f8e8m0" => Ok(Self::F8E8M0), _ => Err(DTypeParseError(s.to_string())), } } @@ -55,11 +76,18 @@ impl DType { match self { Self::U8 => "u8", Self::U32 => "u32", + Self::I16 => "i16", + Self::I32 => "i32", Self::I64 => "i64", Self::BF16 => "bf16", Self::F16 => "f16", Self::F32 => "f32", Self::F64 => "f64", + Self::F8E4M3 => "f8e4m3", + Self::F6E2M3 => "f6e2m3", + Self::F6E3M2 => "f6e3m2", + Self::F4 => "f4", + Self::F8E8M0 => "f8e8m0", } } @@ -68,25 +96,48 @@ impl DType { match self { Self::U8 => 1, Self::U32 => 4, + Self::I16 => 2, + Self::I32 => 4, Self::I64 => 8, Self::BF16 => 2, Self::F16 => 2, Self::F32 => 4, Self::F64 => 8, + Self::F8E4M3 => 1, + Self::F6E2M3 => 0, // 6 bits + Self::F6E3M2 => 0, // 6 bits + Self::F4 => 0, // 4 bits + Self::F8E8M0 => 1, } } pub fn is_int(&self) -> bool { match self { - Self::U8 | Self::U32 | Self::I64 => true, - Self::BF16 | Self::F16 | Self::F32 | Self::F64 => false, + Self::U8 | Self::U32 | Self::I16 | Self::I32 | Self::I64 => true, + Self::BF16 + | Self::F16 + | Self::F32 + | Self::F64 + | Self::F8E4M3 + | Self::F6E2M3 + | Self::F6E3M2 + | Self::F4 + | Self::F8E8M0 => false, } } pub fn is_float(&self) -> bool { match self { - Self::U8 | Self::U32 | Self::I64 => false, - Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true, + Self::U8 | Self::U32 | Self::I16 | Self::I32 | Self::I64 => false, + Self::BF16 + | Self::F16 + | Self::F32 + | Self::F64 + | Self::F8E4M3 + | Self::F6E2M3 + | Self::F6E3M2 + | Self::F4 + | Self::F8E8M0 => true, } } } @@ -170,15 +221,19 @@ macro_rules! 
with_dtype { } }; } +use float8::F8E4M3 as f8e4m3; use half::{bf16, f16}; with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64); with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64); +with_dtype!(i16, I16, |v: f64| v as i16, |v: i16| v as f64); +with_dtype!(i32, I32, |v: f64| v as i32, |v: i32| v as f64); with_dtype!(i64, I64, |v: f64| v as i64, |v: i64| v as f64); with_dtype!(f16, F16, f16::from_f64, f16::to_f64); with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64); with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64); with_dtype!(f64, F64, |v: f64| v, |v: f64| v); +with_dtype!(f8e4m3, F8E4M3, f8e4m3::from_f64, |v: f8e4m3| v.to_f64()); pub trait IntDType: WithDType + num_traits::Bounded { fn is_true(&self) -> bool; @@ -212,9 +267,28 @@ impl IntDType for u8 { } } +impl IntDType for i16 { + fn is_true(&self) -> bool { + *self != 0 + } + fn as_usize(&self) -> usize { + *self as usize + } +} + +impl IntDType for i32 { + fn is_true(&self) -> bool { + *self != 0 + } + fn as_usize(&self) -> usize { + *self as usize + } +} + pub trait FloatDType: WithDType {} impl FloatDType for f16 {} impl FloatDType for bf16 {} impl FloatDType for f32 {} impl FloatDType for f64 {} +impl FloatDType for f8e4m3 {} diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index 329099354b..f55f39308d 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -218,6 +218,10 @@ impl crate::backend::BackendDevice for CudaDevice { Err(Error::NotCompiledWithCudaSupport) } + fn get_current_seed(&self) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + fn location(&self) -> crate::DeviceLocation { fail!() } diff --git a/candle-core/src/dummy_dtype.rs b/candle-core/src/dummy_dtype.rs new file mode 100644 index 0000000000..85b9d2efa5 --- /dev/null +++ b/candle-core/src/dummy_dtype.rs @@ -0,0 +1,268 @@ +//! Dummy data types for experimental/future float formats +//! +//! These are placeholder types for experimental floating-point formats +//! that are defined in the safetensors spec but not yet fully implemented. + +use crate::{DType, Error, Result, WithDType}; + +/// 6-bit float with 2 exponent bits and 3 mantissa bits (MX6 format) +/// This is a dummy type. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct F6E2M3; + +/// 6-bit float with 3 exponent bits and 2 mantissa bits (MX6 format) +/// This is a dummy type. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct F6E3M2; + +/// 4-bit float (MX4 format) +/// This is a dummy type. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct F4; + +/// 8-bit float with 8 exponent bits and 0 mantissa bits +/// This is a dummy type. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct F8E8M0; + +// Implement WithDType for dummy types +macro_rules! 
dummy_with_dtype { + ($ty:ty, $dtype:ident) => { + impl WithDType for $ty { + const DTYPE: DType = DType::$dtype; + + fn from_f64(_v: f64) -> Self { + panic!( + "{} is a dummy type and cannot be constructed", + stringify!($ty) + ) + } + + fn to_f64(self) -> f64 { + panic!( + "{} is a dummy type and cannot be converted", + stringify!($ty) + ) + } + + fn to_scalar(self) -> crate::scalar::Scalar { + panic!( + "{} is a dummy type and cannot be converted to scalar", + stringify!($ty) + ) + } + + fn cpu_storage_ref(_data: &[Self]) -> crate::CpuStorageRef { + panic!( + "{} is a dummy type and does not support storage", + stringify!($ty) + ) + } + + fn to_cpu_storage_owned(_data: Vec) -> crate::CpuStorage { + panic!( + "{} is a dummy type and does not support storage", + stringify!($ty) + ) + } + + fn cpu_storage_data(_s: crate::CpuStorage) -> Result> { + Err(Error::UnsupportedDTypeForOp(DType::$dtype, "cpu_storage_data").bt()) + } + + fn cpu_storage_as_slice(_s: &crate::CpuStorage) -> Result<&[Self]> { + Err(Error::UnsupportedDTypeForOp(DType::$dtype, "cpu_storage_as_slice").bt()) + } + } + }; +} + +dummy_with_dtype!(F6E2M3, F6E2M3); +dummy_with_dtype!(F6E3M2, F6E3M2); +dummy_with_dtype!(F4, F4); +dummy_with_dtype!(F8E8M0, F8E8M0); + +// Implement NumAssign traits for dummy types +macro_rules! dummy_num_assign { + ($ty:ty) => { + impl std::ops::AddAssign for $ty { + fn add_assign(&mut self, _other: Self) { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + } + + impl std::ops::SubAssign for $ty { + fn sub_assign(&mut self, _other: Self) { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + } + + impl std::ops::MulAssign for $ty { + fn mul_assign(&mut self, _other: Self) { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + } + + impl std::ops::DivAssign for $ty { + fn div_assign(&mut self, _other: Self) { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + } + + impl std::ops::RemAssign for $ty { + fn rem_assign(&mut self, _other: Self) { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + } + + impl std::ops::Add for $ty { + type Output = Self; + fn add(self, _other: Self) -> Self { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + } + + impl std::ops::Sub for $ty { + type Output = Self; + fn sub(self, _other: Self) -> Self { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + } + + impl std::ops::Mul for $ty { + type Output = Self; + fn mul(self, _other: Self) -> Self { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + } + + impl std::ops::Div for $ty { + type Output = Self; + fn div(self, _other: Self) -> Self { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + } + + impl std::ops::Rem for $ty { + type Output = Self; + fn rem(self, _other: Self) -> Self { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + } + + impl num_traits::Zero for $ty { + fn zero() -> Self { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + + fn is_zero(&self) -> bool { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + } + + impl num_traits::One for $ty { + fn one() -> Self { + panic!( + "{} is a dummy type and 
does not support operations", + stringify!($ty) + ) + } + } + + impl num_traits::Num for $ty { + type FromStrRadixErr = std::num::ParseFloatError; + + fn from_str_radix( + _str: &str, + _radix: u32, + ) -> std::result::Result { + panic!( + "{} is a dummy type and does not support parsing", + stringify!($ty) + ) + } + } + + impl crate::cpu::kernels::VecOps for $ty { + fn min(self, _other: Self) -> Self { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + + fn max(self, _other: Self) -> Self { + panic!( + "{} is a dummy type and does not support operations", + stringify!($ty) + ) + } + } + }; +} + +dummy_num_assign!(F6E2M3); +dummy_num_assign!(F6E3M2); +dummy_num_assign!(F4); +dummy_num_assign!(F8E8M0); + +// Display implementations +impl std::fmt::Display for F6E2M3 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "F6E2M3") + } +} + +impl std::fmt::Display for F6E3M2 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "F6E3M2") + } +} + +impl std::fmt::Display for F4 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "F4") + } +} + +impl std::fmt::Display for F8E8M0 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "F8E8M0") + } +} diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs index de43f243fb..f4955f2d17 100644 --- a/candle-core/src/dummy_metal_backend.rs +++ b/candle-core/src/dummy_metal_backend.rs @@ -222,6 +222,10 @@ impl crate::backend::BackendDevice for MetalDevice { Err(Error::NotCompiledWithMetalSupport) } + fn get_current_seed(&self) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + fn location(&self) -> crate::DeviceLocation { fail!() } diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index 5729013be3..0a9b35bbeb 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -1,4 +1,6 @@ //! Candle-specific Error and Result +use std::{convert::Infallible, fmt::Display}; + use crate::{DType, DeviceLocation, Layout, MetalError, Shape}; #[derive(Debug, Clone)] @@ -209,6 +211,13 @@ pub enum Error { #[error("{0}")] Wrapped(Box), + /// Arbitrary errors wrapping with context. + #[error("{wrapped:?}\n{context:?}")] + WrappedContext { + wrapped: Box, + context: String, + }, + #[error("{context}\n{inner}")] Context { inner: Box, @@ -299,40 +308,85 @@ pub fn zip(r1: Result, r2: Result) -> Result<(T, U)> { } } -// Taken from anyhow. -pub trait Context { +pub(crate) mod private { + pub trait Sealed {} + + impl Sealed for std::result::Result where E: std::error::Error {} + impl Sealed for Option {} +} + +/// Attach more context to an error. +/// +/// Inspired by [`anyhow::Context`]. +pub trait Context: private::Sealed { /// Wrap the error value with additional context. - fn context(self, context: C) -> Result + fn context(self, context: C) -> std::result::Result where - C: std::fmt::Display + Send + Sync + 'static; + C: Display + Send + Sync + 'static; /// Wrap the error value with additional context that is evaluated lazily /// only once an error does occur. 
- fn with_context(self, f: F) -> Result + fn with_context(self, f: F) -> std::result::Result where - C: std::fmt::Display + Send + Sync + 'static, + C: Display + Send + Sync + 'static, F: FnOnce() -> C; } -impl Context for Option { - fn context(self, context: C) -> Result +impl Context for std::result::Result +where + E: std::error::Error + Send + Sync + 'static, +{ + fn context(self, context: C) -> std::result::Result + where + C: Display + Send + Sync + 'static, + { + // Not using map_err to save 2 useless frames off the captured backtrace + // in ext_context. + match self { + Ok(ok) => Ok(ok), + Err(error) => Err(Error::WrappedContext { + wrapped: Box::new(error), + context: context.to_string(), + }), + } + } + + fn with_context(self, context: F) -> std::result::Result + where + C: Display + Send + Sync + 'static, + F: FnOnce() -> C, + { + match self { + Ok(ok) => Ok(ok), + Err(error) => Err(Error::WrappedContext { + wrapped: Box::new(error), + context: context().to_string(), + }), + } + } +} + +impl Context for Option { + fn context(self, context: C) -> std::result::Result where - C: std::fmt::Display + Send + Sync + 'static, + C: Display + Send + Sync + 'static, { + // Not using ok_or_else to save 2 useless frames off the captured + // backtrace. match self { - Some(v) => Ok(v), - None => Err(Error::UnwrapNone.context(context).bt()), + Some(ok) => Ok(ok), + None => Err(Error::msg(context)), } } - fn with_context(self, f: F) -> Result + fn with_context(self, context: F) -> std::result::Result where - C: std::fmt::Display + Send + Sync + 'static, + C: Display + Send + Sync + 'static, F: FnOnce() -> C, { match self { - Some(v) => Ok(v), - None => Err(Error::UnwrapNone.context(f()).bt()), + Some(ok) => Ok(ok), + None => Err(Error::msg(context())), } } } diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 16dc8e02aa..d75e23542f 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -62,6 +62,7 @@ mod device; pub mod display; mod dtype; pub mod dummy_cuda_backend; +pub mod dummy_dtype; mod dummy_metal_backend; pub mod error; mod indexer; @@ -94,6 +95,7 @@ pub use cpu_backend::{CpuStorage, CpuStorageRef}; pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3, UgIOp1}; pub use device::{Device, DeviceLocation, NdArray}; pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType}; +pub use dummy_dtype::{F4, F6E2M3, F6E3M2, F8E8M0}; pub use error::{Context, Error, Result}; pub use indexer::{IndexOp, TensorIndexer}; pub use layout::Layout; @@ -101,7 +103,7 @@ pub use shape::{Shape, D}; pub use storage::Storage; pub use streaming::{StreamTensor, StreamingBinOp, StreamingModule}; pub use strided_index::{StridedBlocks, StridedIndex}; -pub use tensor::{Tensor, TensorId}; +pub use tensor::{from_storage_no_op, Tensor, TensorId}; pub use variable::Var; #[cfg(feature = "cuda")] diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs index 43869a0c3a..f249202ddd 100644 --- a/candle-core/src/metal_backend/device.rs +++ b/candle-core/src/metal_backend/device.rs @@ -58,14 +58,21 @@ impl Commands { }) } + pub fn flush_command_buffer(&mut self) -> Result<()> { + self.command_buffer.commit(); + let command_buffer = self.command_queue.new_command_buffer().to_owned(); + self.command_buffer = command_buffer.clone(); + self.command_buffer_index = 0; + + Ok(()) + } + pub fn command_buffer(&mut self) -> Result<(bool, CommandBuffer)> { let mut command_buffer = self.command_buffer.to_owned(); let mut 
flushed = false; if self.command_buffer_index > self.compute_per_buffer { - self.command_buffer.commit(); - command_buffer = self.command_queue.new_command_buffer().to_owned(); - self.command_buffer = command_buffer.clone(); - self.command_buffer_index = 0; + self.flush_command_buffer()?; + command_buffer = self.command_buffer.to_owned(); flushed = true; } self.command_buffer_index += 1; @@ -120,6 +127,8 @@ pub struct MetalDevice { pub(crate) kernels: Arc, /// Seed for random number generation. pub(crate) seed: Arc>, + /// Value of the current seed + pub(crate) seed_value: Arc>, } impl std::fmt::Debug for MetalDevice { @@ -181,6 +190,13 @@ impl MetalDevice { Ok(()) } + pub fn flush_command_buffer(&self) -> Result<()> { + let mut commands = self.commands.write().map_err(MetalError::from)?; + commands.flush_command_buffer()?; + + Ok(()) + } + pub fn command_buffer(&self) -> Result { let mut commands = self.commands.write().map_err(MetalError::from)?; let (flushed, command_buffer) = commands.command_buffer()?; @@ -190,6 +206,11 @@ impl MetalDevice { Ok(command_buffer) } + pub fn new_command_buffer(&self) -> Result { + let commands = self.commands.write().map_err(MetalError::from)?; + Ok(commands.command_queue.new_command_buffer().to_owned()) + } + pub fn wait_until_completed(&self) -> Result<()> { let mut commands = self.commands.write().map_err(MetalError::from)?; commands.wait_until_completed() @@ -218,6 +239,16 @@ impl MetalDevice { self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name) } + pub fn new_buffer_private( + &self, + element_count: usize, + dtype: DType, + name: &str, + ) -> Result> { + let size = (element_count * dtype.size_in_bytes()) as NSUInteger; + self.allocate_buffer(size, metal::MTLResourceOptions::StorageModePrivate, name) + } + /// Creates a new buffer (not necessarily zeroed). 
/// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode) /// This means the buffer can be read on the CPU but will require manual diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 2bb07ea44d..91f1b36e52 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -3,7 +3,7 @@ use crate::backend::{BackendDevice, BackendStorage}; use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; -use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape}; +use crate::{CpuStorage, CpuStorageRef, DType, Error, Layout, Result, Shape}; use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels}; use metal::{Buffer, MTLResourceOptions, NSUInteger}; use std::collections::HashMap; @@ -98,11 +98,17 @@ impl BackendStorage for MetalStorage { match self.dtype { DType::U8 => Ok(CpuStorage::U8(self.to_cpu()?)), DType::U32 => Ok(CpuStorage::U32(self.to_cpu()?)), + DType::I16 => Ok(CpuStorage::I16(self.to_cpu()?)), + DType::I32 => Ok(CpuStorage::I32(self.to_cpu()?)), DType::I64 => Ok(CpuStorage::I64(self.to_cpu()?)), DType::F16 => Ok(CpuStorage::F16(self.to_cpu()?)), DType::BF16 => Ok(CpuStorage::BF16(self.to_cpu()?)), DType::F32 => Ok(CpuStorage::F32(self.to_cpu()?)), DType::F64 => Ok(CpuStorage::F64(self.to_cpu()?)), + DType::F8E4M3 => Ok(CpuStorage::F8E4M3(self.to_cpu()?)), + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + Err(crate::Error::UnsupportedDTypeForOp(self.dtype, "to_cpu_storage").bt()) + } } } @@ -457,6 +463,12 @@ impl BackendStorage for MetalStorage { DType::U32 => contiguous::const_set::U32, DType::U8 => contiguous::const_set::U8, DType::F64 => crate::bail!("unsupported const-set f64"), + DType::F8E4M3 => crate::bail!("unsupported const-set f8e4m3"), + DType::I16 => crate::bail!("unsupported const-set i16"), + DType::I32 => crate::bail!("unsupported const-set i32"), + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + crate::bail!("unsupported const-set for dummy type {:?}", dtype) + } }; candle_metal_kernels::call_const_set_contiguous( &device.device, @@ -479,6 +491,12 @@ impl BackendStorage for MetalStorage { DType::U32 => strided::const_set::U32, DType::U8 => strided::const_set::U8, DType::F64 => crate::bail!("unsupported const-set f64"), + DType::F8E4M3 => crate::bail!("unsupported const-set f8e4m3"), + DType::I16 => crate::bail!("unsupported const-set i16"), + DType::I32 => crate::bail!("unsupported const-set i32"), + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + crate::bail!("unsupported const-set for dummy type {:?}", dtype) + } }; candle_metal_kernels::call_const_set_strided( &device.device, @@ -503,6 +521,7 @@ impl BackendStorage for MetalStorage { (DType::BF16, Scalar::BF16(s)) => set(self, s, l), (DType::F32, Scalar::F32(s)) => set(self, s, l), (DType::F64, Scalar::F64(s)) => set(self, s, l), + (DType::F8E4M3, _) => crate::bail!("Metal const_set does not support f8e4m3"), _ => crate::bail!("dtype mismatch, expected {:?}, got {:?}", self.dtype, s), } } @@ -2055,6 +2074,7 @@ impl BackendDevice for MetalDevice { buffers: Arc::new(RwLock::new(HashMap::new())), kernels, seed, + seed_value: Arc::new(RwLock::new(299792458)), }) } @@ -2093,11 +2113,20 @@ impl BackendDevice for MetalDevice { let (count, buffer) = match T::cpu_storage_ref(s) { CpuStorageRef::U8(storage) => (storage.len(), 
self.new_buffer_with_data(storage)), CpuStorageRef::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::I16(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::I32(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::F32(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::F64(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::F8E4M3(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::F6E2M3(_) + | CpuStorageRef::F6E3M2(_) + | CpuStorageRef::F4(_) + | CpuStorageRef::F8E8M0(_) => { + return Err(Error::UnsupportedDTypeForOp(T::DTYPE, "to_dtype").bt()) + } }; Ok(Self::Storage::new(buffer?, self.clone(), count, T::DTYPE)) } @@ -2106,11 +2135,20 @@ impl BackendDevice for MetalDevice { let (count, buffer) = match storage { CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorage::I16(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorage::I32(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::F32(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::F64(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorage::F8E4M3(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorage::F6E2M3(_) + | CpuStorage::F6E3M2(_) + | CpuStorage::F4(_) + | CpuStorage::F8E8M0(_) => { + return Err(Error::UnsupportedDTypeForOp(storage.dtype(), "to_dtype").bt()) + } }; Ok(Self::Storage::new( buffer?, @@ -2197,6 +2235,8 @@ impl BackendDevice for MetalDevice { } fn set_seed(&self, seed: u64) -> Result<()> { + *self.seed_value.write().unwrap() = seed; + let seed: u32 = seed.try_into().map_err(|_| { MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string()) })?; @@ -2211,6 +2251,10 @@ impl BackendDevice for MetalDevice { Ok(()) } + fn get_current_seed(&self) -> Result { + Ok(*self.seed_value.read().unwrap()) + } + fn synchronize(&self) -> Result<()> { self.wait_until_completed() } diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs index 83e4f6527f..496465ec33 100644 --- a/candle-core/src/npy.rs +++ b/candle-core/src/npy.rs @@ -85,9 +85,16 @@ impl Header { DType::F16 => "f2", DType::F32 => "f4", DType::F64 => "f8", + DType::I16 => "i2", + DType::I32 => "i4", DType::I64 => "i8", DType::U32 => "u4", DType::U8 => "u1", + DType::F8E4M3 => Err(Error::Npy("f8e4m3 is not supported".into()))?, + DType::F6E2M3 => Err(Error::Npy("f6e2m3 is not supported".into()))?, + DType::F6E3M2 => Err(Error::Npy("f6e3m2 is not supported".into()))?, + DType::F4 => Err(Error::Npy("f4 is not supported".into()))?, + DType::F8E8M0 => Err(Error::Npy("f8e8m0 is not supported".into()))?, }; if !shape.is_empty() { shape.push(',') @@ -106,7 +113,7 @@ impl Header { let mut parts: Vec = vec![]; let mut start_index = 0usize; let mut 
cnt_parenthesis = 0i64; - for (index, c) in header.chars().enumerate() { + for (index, c) in header.char_indices() { match c { '(' => cnt_parenthesis += 1, ')' => cnt_parenthesis -= 1, @@ -160,9 +167,9 @@ impl Header { "e" | "f2" => DType::F16, "f" | "f4" => DType::F32, "d" | "f8" => DType::F64, - // "i" | "i4" => DType::S32, + "i" | "i4" => DType::I32, "q" | "i8" => DType::I64, - // "h" | "i2" => DType::S16, + "h" | "i2" => DType::I16, // "b" | "i1" => DType::S8, "B" | "u1" => DType::U8, "I" | "u4" => DType::U32, @@ -234,11 +241,31 @@ impl Tensor { reader.read_u32_into::(&mut data_t)?; Tensor::from_vec(data_t, shape, &Device::Cpu) } + DType::I16 => { + let mut data_t = vec![0i16; elem_count]; + reader.read_i16_into::(&mut data_t)?; + Tensor::from_vec(data_t, shape, &Device::Cpu) + } + DType::I32 => { + let mut data_t = vec![0i32; elem_count]; + reader.read_i32_into::(&mut data_t)?; + Tensor::from_vec(data_t, shape, &Device::Cpu) + } DType::I64 => { let mut data_t = vec![0i64; elem_count]; reader.read_i64_into::(&mut data_t)?; Tensor::from_vec(data_t, shape, &Device::Cpu) } + DType::F8E4M3 => { + let mut data_t = vec![0u8; elem_count]; + reader.read_exact(&mut data_t)?; + let data_f8: Vec = + data_t.into_iter().map(float8::F8E4M3::from_bits).collect(); + Tensor::from_vec(data_f8, shape, &Device::Cpu) + } + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + Err(Error::UnsupportedDTypeForOp(dtype, "from_reader").bt()) + } } } diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index e2627f762a..7962a71487 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -2,6 +2,7 @@ //! #![allow(clippy::redundant_closure_call)] use crate::Tensor; +use float8::F8E4M3 as f8e4m3; use half::{bf16, f16}; use num_traits::float::Float; @@ -192,7 +193,10 @@ pub trait UnaryOpT { fn f64(v1: f64) -> f64; fn u8(v1: u8) -> u8; fn u32(v1: u32) -> u32; + fn i16(v1: i16) -> i16; + fn i32(v1: i32) -> i32; fn i64(v1: i64) -> i64; + fn f8e4m3(v1: f8e4m3) -> f8e4m3; // There is no very good way to represent optional function in traits so we go for an explicit // boolean flag to mark the function as existing. @@ -216,7 +220,10 @@ pub trait BinaryOpT { fn f64(v1: f64, v2: f64) -> f64; fn u8(v1: u8, v2: u8) -> u8; fn u32(v1: u32, v2: u32) -> u32; + fn i16(v1: i16, v2: i16) -> i16; + fn i32(v1: i32, v2: i32) -> i32; fn i64(v1: i64, v2: i64) -> i64; + fn f8e4m3(v1: f8e4m3, v2: f8e4m3) -> f8e4m3; const BF16_VEC: bool = false; fn bf16_vec(_xs1: &[bf16], _xs2: &[bf16], _ys: &mut [bf16]) {} @@ -291,9 +298,21 @@ macro_rules! bin_op { $e(v1, v2) } #[inline(always)] + fn i16(v1: i16, v2: i16) -> i16 { + $e(v1, v2) + } + #[inline(always)] + fn i32(v1: i32, v2: i32) -> i32 { + $e(v1, v2) + } + #[inline(always)] fn i64(v1: i64, v2: i64) -> i64 { $e(v1, v2) } + #[inline(always)] + fn f8e4m3(v1: f8e4m3, v2: f8e4m3) -> f8e4m3 { + $e(v1, v2) + } #[cfg(feature = "mkl")] const F32_VEC: bool = true; @@ -379,9 +398,21 @@ macro_rules! unary_op { todo!("no unary function for u32") } #[inline(always)] + fn i16(_: i16) -> i16 { + todo!("no unary function for i16") + } + #[inline(always)] + fn i32(_: i32) -> i32 { + todo!("no unary function for i32") + } + #[inline(always)] fn i64(_: i64) -> i64 { todo!("no unary function for i64") } + #[inline(always)] + fn f8e4m3($a: f8e4m3) -> f8e4m3 { + $e + } } }; @@ -415,9 +446,21 @@ macro_rules! 
unary_op { todo!("no unary function for u32") } #[inline(always)] + fn i16(_: i16) -> i16 { + todo!("no unary function for i16") + } + #[inline(always)] + fn i32(_: i32) -> i32 { + todo!("no unary function for i32") + } + #[inline(always)] fn i64(_: i64) -> i64 { todo!("no unary function for i64") } + #[inline(always)] + fn f8e4m3($a: f8e4m3) -> f8e4m3 { + $e + } #[cfg(feature = "mkl")] const F32_VEC: bool = true; @@ -514,9 +557,28 @@ impl UnaryOpT for Gelu { 0 } #[inline(always)] + fn i16(_: i16) -> i16 { + 0 + } + #[inline(always)] + fn i32(_: i32) -> i32 { + 0 + } + #[inline(always)] fn i64(_: i64) -> i64 { 0 } + #[inline(always)] + fn f8e4m3(v: f8e4m3) -> f8e4m3 { + f8e4m3::from_f32(0.5) + * v + * (f8e4m3::ONE + + f8e4m3::tanh( + f8e4m3::from_f32(SQRT_TWO_OVER_PI_F32) + * v + * (f8e4m3::ONE + f8e4m3::from_f32(0.044715) * v * v), + )) + } const KERNEL: &'static str = "ugelu"; #[cfg(feature = "mkl")] @@ -587,9 +649,21 @@ impl UnaryOpT for Erf { 0 } #[inline(always)] + fn i16(_: i16) -> i16 { + 0 + } + #[inline(always)] + fn i32(_: i32) -> i32 { + 0 + } + #[inline(always)] fn i64(_: i64) -> i64 { 0 } + #[inline(always)] + fn f8e4m3(v: f8e4m3) -> f8e4m3 { + f8e4m3::from_f64(Self::f64(v.to_f64())) + } } /// Silu operation @@ -621,9 +695,21 @@ impl UnaryOpT for Silu { 0 } #[inline(always)] + fn i16(_: i16) -> i16 { + 0 + } + #[inline(always)] + fn i32(_: i32) -> i32 { + 0 + } + #[inline(always)] fn i64(_: i64) -> i64 { 0 } + #[inline(always)] + fn f8e4m3(v: f8e4m3) -> f8e4m3 { + v / (f8e4m3::ONE + (-v).exp()) + } const KERNEL: &'static str = "usilu"; #[cfg(feature = "mkl")] @@ -692,9 +778,21 @@ impl UnaryOpT for Abs { v } #[inline(always)] + fn i16(v: i16) -> i16 { + v.abs() + } + #[inline(always)] + fn i32(v: i32) -> i32 { + v.abs() + } + #[inline(always)] fn i64(v: i64) -> i64 { v.abs() } + #[inline(always)] + fn f8e4m3(v: f8e4m3) -> f8e4m3 { + v.abs() + } } impl UnaryOpT for Ceil { @@ -726,9 +824,21 @@ impl UnaryOpT for Ceil { v } #[inline(always)] + fn i16(v: i16) -> i16 { + v + } + #[inline(always)] + fn i32(v: i32) -> i32 { + v + } + #[inline(always)] fn i64(v: i64) -> i64 { v } + #[inline(always)] + fn f8e4m3(v: f8e4m3) -> f8e4m3 { + v.ceil() + } } impl UnaryOpT for Floor { @@ -760,9 +870,21 @@ impl UnaryOpT for Floor { v } #[inline(always)] + fn i16(v: i16) -> i16 { + v + } + #[inline(always)] + fn i32(v: i32) -> i32 { + v + } + #[inline(always)] fn i64(v: i64) -> i64 { v } + #[inline(always)] + fn f8e4m3(v: f8e4m3) -> f8e4m3 { + v.floor() + } } impl UnaryOpT for Round { @@ -794,9 +916,21 @@ impl UnaryOpT for Round { v } #[inline(always)] + fn i16(v: i16) -> i16 { + v + } + #[inline(always)] + fn i32(v: i32) -> i32 { + v + } + #[inline(always)] fn i64(v: i64) -> i64 { v } + #[inline(always)] + fn f8e4m3(v: f8e4m3) -> f8e4m3 { + v.round() + } } impl UnaryOpT for GeluErf { @@ -828,9 +962,21 @@ impl UnaryOpT for GeluErf { 0 } #[inline(always)] + fn i16(_: i16) -> i16 { + 0 + } + #[inline(always)] + fn i32(_: i32) -> i32 { + 0 + } + #[inline(always)] fn i64(_: i64) -> i64 { 0 } + #[inline(always)] + fn f8e4m3(v: f8e4m3) -> f8e4m3 { + f8e4m3::from_f64(Self::f64(v.to_f64())) + } } impl UnaryOpT for Relu { @@ -862,8 +1008,20 @@ impl UnaryOpT for Relu { v } #[inline(always)] + fn i16(v: i16) -> i16 { + v.max(0) + } + #[inline(always)] + fn i32(v: i32) -> i32 { + v.max(0) + } + #[inline(always)] fn i64(v: i64) -> i64 { - v + v.max(0) + } + #[inline(always)] + fn f8e4m3(v: f8e4m3) -> f8e4m3 { + v.max(f8e4m3::ZERO) } } @@ -960,7 +1118,25 @@ impl UnaryOpT for Sign { u32::min(1, v) } 
#[inline(always)] + fn i16(v: i16) -> i16 { + (v > 0) as i16 - (v < 0) as i16 + } + #[inline(always)] + fn i32(v: i32) -> i32 { + (v > 0) as i32 - (v < 0) as i32 + } + #[inline(always)] fn i64(v: i64) -> i64 { (v > 0) as i64 - (v < 0) as i64 } + #[inline(always)] + fn f8e4m3(v: f8e4m3) -> f8e4m3 { + if v > f8e4m3::ZERO { + f8e4m3::ONE + } else if v < f8e4m3::ZERO { + -f8e4m3::ONE + } else { + f8e4m3::ZERO + } + } } diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs index c8d483a37a..c2ac1d705d 100644 --- a/candle-core/src/quantized/cuda.rs +++ b/candle-core/src/quantized/cuda.rs @@ -443,6 +443,7 @@ impl QCudaStorage { GgmlDType::Q5K => deq::(&buffer, block_len, &mut out)?, GgmlDType::Q6K => deq::(&buffer, block_len, &mut out)?, GgmlDType::Q8K => deq::(&buffer, block_len, &mut out)?, + GgmlDType::BF16 => deq::(&buffer, block_len, &mut out)?, } self.device @@ -476,6 +477,87 @@ impl QCudaStorage { Ok(()) } + pub fn quantize_imatrix( + &mut self, + src: &CudaStorage, + imatrix_weights: &[f32], + n_per_row: usize, + ) -> Result<()> { + // Run the quantization on cpu. + let src = match &src.slice { + crate::cuda_backend::CudaStorageSlice::F32(data) => self.device.memcpy_dtov(data)?, + _ => crate::bail!("only f32 can be quantized"), + }; + let src_len = src.len(); + let src = crate::Storage::Cpu(crate::CpuStorage::F32(src)); + let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?; + qcpu_storage.quantize_imatrix(&src, imatrix_weights, n_per_row)?; + let data = qcpu_storage.data()?; + let padded_len = + data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size(); + let mut inner = unsafe { self.device.alloc::(padded_len)? }; + self.device + .memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))?; + self.data = PaddedCudaSlice { + inner, + len: data.len(), + }; + Ok(()) + } + + pub fn quantize_imatrix_onto( + &mut self, + src: &crate::CpuStorage, + imatrix_weights: &[f32], + n_per_row: usize, + ) -> Result<()> { + // Run the quantization on cpu. + let src_len = src.as_slice::()?.len(); + let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?; + + if let QStorage::Cpu(storage) = &mut qcpu_storage { + storage.from_float_imatrix(src.as_slice::()?, imatrix_weights, n_per_row)?; + } else { + unreachable!() + } + + let data = qcpu_storage.data()?; + let padded_len = + data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size(); + let mut inner = unsafe { self.device.alloc::(padded_len)? }; + self.device + .memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))?; + self.data = PaddedCudaSlice { + inner, + len: data.len(), + }; + Ok(()) + } + + pub fn quantize_onto(&mut self, src: &crate::CpuStorage) -> Result<()> { + // Run the quantization on cpu. + let src_len = src.as_slice::()?.len(); + let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?; + + if let QStorage::Cpu(storage) = &mut qcpu_storage { + storage.from_float(src.as_slice::()?)?; + } else { + unreachable!() + } + + let data = qcpu_storage.data()?; + let padded_len = + data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size(); + let mut inner = unsafe { self.device.alloc::(padded_len)? 
}; + self.device + .memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))?; + self.data = PaddedCudaSlice { + inner, + len: data.len(), + }; + Ok(()) + } + pub fn storage_size_in_bytes(&self) -> usize { self.data.len } @@ -502,6 +584,13 @@ impl QCudaStorage { self.dequantize_matmul(self_shape, storage, layout) } } + + pub fn data(&self) -> Result> { + let mut out = vec![0u8; self.data.len]; + self.device + .memcpy_dtoh(&self.data.inner.slice(..self.data.len), &mut out)?; + Ok(out) + } } impl QCudaStorage { diff --git a/candle-core/src/quantized/dummy_cuda.rs b/candle-core/src/quantized/dummy_cuda.rs index ca7b812084..1636f50bb7 100644 --- a/candle-core/src/quantized/dummy_cuda.rs +++ b/candle-core/src/quantized/dummy_cuda.rs @@ -32,6 +32,28 @@ impl QCudaStorage { Err(Error::NotCompiledWithCudaSupport) } + pub fn quantize_imatrix( + &mut self, + _src: &CudaStorage, + _imatrix_weights: &[f32], + _n_per_row: usize, + ) -> Result<()> { + Err(Error::NotCompiledWithCudaSupport) + } + + pub fn quantize_imatrix_onto( + &mut self, + _src: &crate::CpuStorage, + _imatrix_weights: &[f32], + _n_per_row: usize, + ) -> Result<()> { + Err(Error::NotCompiledWithCudaSupport) + } + + pub fn quantize_onto(&mut self, _src: &crate::CpuStorage) -> Result<()> { + Err(Error::NotCompiledWithCudaSupport) + } + pub fn storage_size_in_bytes(&self) -> usize { 0 } @@ -44,6 +66,10 @@ impl QCudaStorage { ) -> Result<(CudaStorage, crate::Shape)> { Err(Error::NotCompiledWithCudaSupport) } + + pub fn data(&self) -> Result> { + Err(Error::NotCompiledWithCudaSupport) + } } pub fn load_quantized( diff --git a/candle-core/src/quantized/dummy_metal.rs b/candle-core/src/quantized/dummy_metal.rs index 520d0ed49a..d4d87861f9 100644 --- a/candle-core/src/quantized/dummy_metal.rs +++ b/candle-core/src/quantized/dummy_metal.rs @@ -28,6 +28,28 @@ impl QMetalStorage { Err(Error::NotCompiledWithMetalSupport) } + pub fn quantize_imatrix( + &mut self, + _src: &MetalStorage, + _imatrix_weights: &[f32], + _n_per_row: usize, + ) -> Result<()> { + Err(Error::NotCompiledWithMetalSupport) + } + + pub fn quantize_imatrix_onto( + &mut self, + _src: &crate::CpuStorage, + _imatrix_weights: &[f32], + _n_per_row: usize, + ) -> Result<()> { + Err(Error::NotCompiledWithMetalSupport) + } + + pub fn quantize_onto(&mut self, _src: &crate::CpuStorage) -> Result<()> { + Err(Error::NotCompiledWithCudaSupport) + } + pub fn storage_size_in_bytes(&self) -> usize { 0 } @@ -40,6 +62,10 @@ impl QMetalStorage { ) -> Result<(MetalStorage, crate::Shape)> { Err(Error::NotCompiledWithMetalSupport) } + + pub fn data(&self) -> Result> { + Err(Error::NotCompiledWithMetalSupport) + } } pub fn load_quantized( diff --git a/candle-core/src/quantized/ggml_file.rs b/candle-core/src/quantized/ggml_file.rs index 0f7e9c118c..ea5ec02578 100644 --- a/candle-core/src/quantized/ggml_file.rs +++ b/candle-core/src/quantized/ggml_file.rs @@ -134,7 +134,7 @@ fn from_raw_data( super::QTensor::new(data, dims) } -/// Creates a Tensor from a raw GGML tensor. +/// Creates a [Tensor] from a raw GGML tensor. 
pub fn qtensor_from_ggml( ggml_dtype: GgmlDType, raw_data: &[u8], @@ -153,6 +153,7 @@ pub fn qtensor_from_ggml( match ggml_dtype { GgmlDType::F32 => from_raw_data::(raw_data, size_in_bytes, dims, device), GgmlDType::F16 => from_raw_data::(raw_data, size_in_bytes, dims, device), + GgmlDType::BF16 => from_raw_data::(raw_data, size_in_bytes, dims, device), GgmlDType::Q4_0 => { from_raw_data::(raw_data, size_in_bytes, dims, device) } diff --git a/candle-core/src/quantized/gguf_file.rs b/candle-core/src/quantized/gguf_file.rs index 2ea6c7a34c..cabe41d647 100644 --- a/candle-core/src/quantized/gguf_file.rs +++ b/candle-core/src/quantized/gguf_file.rs @@ -1,8 +1,9 @@ -//! Support for the [GGUF file format](https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md). +//! Support for the GGUF file format. //! +//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md use super::{GgmlDType, QTensor}; -use crate::{Context, Device, Result}; +use crate::{Device, Result}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::collections::HashMap; @@ -338,7 +339,7 @@ impl Value { if value_type.len() != 1 { crate::bail!("multiple value-types in the same array {value_type:?}") } - value_type.into_iter().next().context("empty value_type")? + value_type.into_iter().next().unwrap() }; w.write_u32::(value_type.to_u32())?; w.write_u64::(v.len() as u64)?; diff --git a/candle-core/src/quantized/imatrix_file.rs b/candle-core/src/quantized/imatrix_file.rs new file mode 100644 index 0000000000..db434f7f3e --- /dev/null +++ b/candle-core/src/quantized/imatrix_file.rs @@ -0,0 +1,85 @@ +use std::collections::HashMap; +use std::fs::File; +use std::io::{Cursor, Read}; +use std::path::Path; + +use byteorder::{LittleEndian, ReadBytesExt}; + +use crate::Result; + +pub fn load_imatrix>(fname: P) -> Result>> { + let mut all_data = HashMap::new(); + + let mut file = File::open(&fname).map_err(|e| { + crate::Error::msg(format!( + "Failed to open {}: {}", + fname.as_ref().display(), + e + )) + })?; + let mut buffer = Vec::new(); + file.read_to_end(&mut buffer).map_err(|e| { + crate::Error::msg(format!( + "Failed to read file {}: {}", + fname.as_ref().display(), + e + )) + })?; + + let mut cursor = Cursor::new(buffer); + + let n_entries = cursor + .read_i32::() + .map_err(|e| crate::Error::msg(format!("Failed to read number of entries: {}", e)))? + as usize; + + if n_entries < 1 { + crate::bail!("No data in file {}", fname.as_ref().display()); + } + + for i in 0..n_entries { + // Read length of the name + let len = cursor.read_i32::().map_err(|e| { + crate::Error::msg(format!( + "Failed to read name length for entry {}: {}", + i + 1, + e + )) + })? as usize; + + // Read the name + let mut name_buf = vec![0u8; len]; + cursor.read_exact(&mut name_buf).map_err(|e| { + crate::Error::msg(format!("Failed to read name for entry {}: {}", i + 1, e)) + })?; + let name = String::from_utf8(name_buf).map_err(|e| { + crate::Error::msg(format!("Invalid UTF-8 name for entry {}: {}", i + 1, e)) + })?; + + // Read ncall and nval + let ncall = cursor.read_i32::().map_err(|e| { + crate::Error::msg(format!("Failed to read ncall for entry {}: {}", i + 1, e)) + })? as usize; + + let nval = cursor.read_i32::().map_err(|e| { + crate::Error::msg(format!("Failed to read nval for entry {}: {}", i + 1, e)) + })? 
as usize; + + if nval < 1 { + crate::bail!("Invalid nval for entry {}: {}", i + 1, nval); + } + + let mut data = Vec::with_capacity(nval); + for _ in 0..nval { + let v = cursor.read_f32::().unwrap(); + if ncall == 0 { + data.push(v); + } else { + data.push(v / ncall as f32); + } + } + all_data.insert(name, data); + } + + Ok(all_data) +} diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index 1d3e053898..3789849fa9 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -3,9 +3,10 @@ use super::utils::{ make_qkx1_quants, make_qx_quants, nearest_int, }; use super::GgmlDType; +use crate::quantized::utils::{make_qkx3_quants, make_qp_quants}; use crate::Result; use byteorder::{ByteOrder, LittleEndian}; -use half::f16; +use half::{bf16, f16}; use rayon::prelude::*; // Default to QK_K 256 rather than 64. @@ -30,6 +31,17 @@ pub trait GgmlType: Sized + Clone + Send + Sync { } fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()>; fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()>; + fn from_float_imatrix( + _xs: &[f32], + _ys: &mut [Self], + _imatrix_weights: &[f32], + _n_per_row: usize, + ) -> Result<()> { + crate::bail!( + "`from_float_imatrix` is unimplemented for {:?}", + Self::DTYPE + ); + } /// Dot product used as a building block for quantized mat-mul. /// n is the number of elements to be considered. @@ -830,6 +842,72 @@ impl GgmlType for BlockQ2K { } Ok(()) } + + fn from_float_imatrix( + xs: &[f32], + ys: &mut [Self], + imatrix_weights: &[f32], + n_per_row: usize, + ) -> Result<()> { + for (sblk_idx, (block, x)) in group_for_quantization(xs, ys)?.into_iter().enumerate() { + //calculate scales and mins + let mut mins: [f32; QK_K / 16] = [0.0; QK_K / 16]; + let mut scales: [f32; QK_K / 16] = [0.0; QK_K / 16]; + let mut weights: [f32; 16] = [0.0; 16]; + let mut sw: [f32; QK_K / 16] = [0.0; QK_K / 16]; + let mut ls: [u8; QK_K / 16] = [0; QK_K / 16]; + let mut lm: [u8; QK_K / 16] = [0; QK_K / 16]; + + let sum_x2 = x.iter().map(|x| x * x).sum::(); + let sigma2 = sum_x2 / QK_K as f32; + for (j, x_scale_slice) in x.chunks_exact(16).enumerate() { + for (l, (w_elem, x_elem)) in weights.iter_mut().zip(x_scale_slice).enumerate() { + let imatrix_row = sblk_idx % (n_per_row / QK_K); + let imatrix_w = imatrix_weights[imatrix_row * QK_K + 16 * j + l]; + *w_elem = imatrix_w * (sigma2 + x_elem * x_elem).sqrt(); + } + let sumw = weights.iter().sum::(); + sw[j] = sumw; + (scales[j], mins[j]) = + make_qkx3_quants(3, x_scale_slice, Some(&weights), -0.9, 0.05, 36, false); + } + + let d_block = make_qp_quants(QK_K / 16, 15, &scales, &mut ls, &sw); + let m_block = make_qp_quants(QK_K / 16, 15, &mins, &mut lm, &sw); + + block.d = f16::from_f32(d_block); + block.dmin = f16::from_f32(m_block); + + for j in 0..QK_K / 16 { + block.scales[j] = ls[j] | (lm[j] << 4); + } + + let mut big_l: [u8; QK_K] = [0; QK_K]; + + for j in 0..QK_K / 16 { + let d = block.d.to_f32() * (block.scales[j] & 0xF) as f32; + if d == 0.0 { + continue; + } + let dm = block.dmin.to_f32() * (block.scales[j] >> 4) as f32; + for ii in 0..16 { + let ll = nearest_int((x[16 * j + ii] + dm) / d).clamp(0, 3); + big_l[16 * j + ii] = ll as u8; + } + } + + for j in (0..QK_K).step_by(128) { + for ll in 0..32 { + block.qs[j / 4 + ll] = big_l[j + ll] + | (big_l[j + ll + 32] << 2) + | (big_l[j + ll + 64] << 4) + | (big_l[j + ll + 96] << 6); + } + } + } + Ok(()) + } + // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L354 fn 
to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { for (block, y) in group_for_dequantization(xs, ys)? { @@ -1091,6 +1169,110 @@ impl GgmlType for BlockQ3K { Ok(()) } + fn from_float_imatrix( + xs: &[f32], + ys: &mut [Self], + imatrix_weights: &[f32], + n_per_row: usize, + ) -> Result<()> { + for (sblk_idx, (block, x)) in group_for_quantization(xs, ys)?.into_iter().enumerate() { + let mut scales: [f32; QK_K / 16] = [0.0; QK_K / 16]; + let mut weights: [f32; 16] = [0.0; 16]; + let mut sw: [f32; QK_K / 16] = [0.0; QK_K / 16]; + let mut ls: [i8; QK_K / 16] = [0; QK_K / 16]; + let mut l: [i8; QK_K] = [0; QK_K]; + + let sum_x2 = x.iter().map(|x| x * x).sum::(); + let sigma2 = 2. * sum_x2 / QK_K as f32; + // let av_x = sigma2.sqrt(); + + for (j, x_scale_slice) in x.chunks_exact(16).enumerate() { + for (l, (w_elem, x_elem)) in weights.iter_mut().zip(x_scale_slice).enumerate() { + let imatrix_row = sblk_idx % (n_per_row / QK_K); + let imatrix_w = imatrix_weights[imatrix_row * QK_K + 16 * j + l]; + *w_elem = imatrix_w * (sigma2 + x_elem * x_elem).sqrt(); + } + let sumw = weights.iter().sum::(); + sw[j] = sumw; + scales[j] = unsafe { + make_qx_quants( + 16, + 4, + x_scale_slice.as_ptr(), + l.as_mut_ptr().add(16 * j), + 1, + weights.as_ptr(), + ) + }; + } + + block.scales.fill(0); + let d_block = unsafe { + make_qx_quants( + QK_K / 16, + 32, + scales.as_ptr(), + ls.as_mut_ptr(), + 1, + sw.as_ptr(), + ) + }; + block.d = f16::from_f32(d_block); + for (j, l) in ls.iter().enumerate().take(QK_K / 16) { + if j < 8 { + block.scales[j] = (l & 0xF) as u8; + } else { + block.scales[j - 8] |= ((l & 0xF) << 4) as u8; + } + let l = l >> 4; + block.scales[j % 4 + 8] |= (l << (2 * (j / 4))) as u8; + } + + for j in 0..QK_K / 16 { + let sc = if j < 8 { + block.scales[j] & 0xF + } else { + block.scales[j - 8] >> 4 + }; + let sc = (sc | (((block.scales[8 + j % 4] >> (2 * (j / 4))) & 3) << 4)) as i8 - 32; + let d = block.d.to_f32() * sc as f32; + if d != 0.0 { + for ii in 0..16 { + let l_val = nearest_int(x[16 * j + ii] / d); + l[16 * j + ii] = (l_val.clamp(-4, 3) + 4) as i8; + } + } + } + + block.hmask.fill(0); + let mut m = 0; + let mut hm = 1; + + for ll in l.iter_mut() { + if *ll > 3 { + block.hmask[m] |= hm; + *ll -= 4; + } + m += 1; + if m == QK_K / 8 { + m = 0; + hm <<= 1; + } + } + + for j in (0..QK_K).step_by(128) { + for l_val in 0..32 { + block.qs[j / 4 + l_val] = (l[j + l_val] + | (l[j + l_val + 32] << 2) + | (l[j + l_val + 64] << 4) + | (l[j + l_val + 96] << 6)) + as u8; + } + } + } + Ok(()) + } + // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533 fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { const KMASK1: u32 = 0x03030303; @@ -1309,6 +1491,79 @@ impl GgmlType for BlockQ4K { } Ok(()) } + + fn from_float_imatrix( + xs: &[f32], + ys: &mut [Self], + imatrix_weights: &[f32], + n_per_row: usize, + ) -> Result<()> { + for (sblk_idx, (block, x)) in group_for_quantization(xs, ys)?.into_iter().enumerate() { + let mut mins: [f32; QK_K / 32] = [0.0; QK_K / 32]; + let mut scales: [f32; QK_K / 32] = [0.0; QK_K / 32]; + let mut weights: [f32; 32] = [0.0; 32]; + let mut sw: [f32; QK_K / 32] = [0.0; QK_K / 32]; + let mut ls: [u8; QK_K / 32] = [0; QK_K / 32]; + let mut lm: [u8; QK_K / 32] = [0; QK_K / 32]; + + let sum_x2 = x.iter().map(|x| x * x).sum::(); + let sigma2 = 2. 
* sum_x2 / QK_K as f32; + // let av_x = sigma2.sqrt(); + + for (j, x_scale_slice) in x.chunks_exact(32).enumerate() { + for (l, (w_elem, x_elem)) in weights.iter_mut().zip(x_scale_slice).enumerate() { + let imatrix_row = sblk_idx % (n_per_row / QK_K); + let imatrix_w = imatrix_weights[imatrix_row * QK_K + 32 * j + l]; + *w_elem = imatrix_w * (sigma2 + x_elem * x_elem).sqrt(); + } + let sumw = weights.iter().sum::(); + sw[j] = sumw; + (scales[j], mins[j]) = + make_qkx3_quants(15, x_scale_slice, Some(&weights), -0.9, 0.05, 36, false); + } + + let d_block = make_qp_quants(QK_K / 32, 63, &scales, &mut ls, &sw); + let m_block = make_qp_quants(QK_K / 32, 63, &mins, &mut lm, &sw); + for j in 0..QK_K / 32 { + let ls_val = ls[j]; + let lm_val = lm[j]; + if j < 4 { + block.scales[j] = ls_val; + block.scales[j + 4] = lm_val; + } else { + block.scales[j + 4] = (ls_val & 0xF) | ((lm_val & 0xF) << 4); + block.scales[j - 4] |= (ls_val >> 4) << 6; + block.scales[j] |= (lm_val >> 4) << 6; + } + } + + block.d = f16::from_f32(d_block); + block.dmin = f16::from_f32(m_block); + + let mut l: [u8; QK_K] = [0; QK_K]; + for j in 0..QK_K / 32 { + let (sc, m) = get_scale_min_k4(j, &block.scales); + let d = block.d.to_f32() * sc as f32; + if d != 0.0 { + let dm = block.dmin.to_f32() * m as f32; + for ii in 0..32 { + let l_val = nearest_int((x[32 * j + ii] + dm) / d); + l[32 * j + ii] = l_val.clamp(0, 15) as u8; + } + } + } + + let q = &mut block.qs; + for j in (0..QK_K).step_by(64) { + for l_val in 0..32 { + let offset_index = (j / 64) * 32 + l_val; + q[offset_index] = l[j + l_val] | (l[j + l_val + 32] << 4); + } + } + } + Ok(()) + } + // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L735 fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { for (block, y) in group_for_dequantization(xs, ys)? { @@ -1524,6 +1779,95 @@ impl GgmlType for BlockQ5K { Ok(()) } + fn from_float_imatrix( + xs: &[f32], + ys: &mut [Self], + imatrix_weights: &[f32], + n_per_row: usize, + ) -> Result<()> { + for (sblk_idx, (block, x)) in group_for_quantization(xs, ys)?.into_iter().enumerate() { + let mut mins: [f32; QK_K / 32] = [0.0; QK_K / 32]; + let mut scales: [f32; QK_K / 32] = [0.0; QK_K / 32]; + let mut weights: [f32; 32] = [0.0; 32]; + let mut sw: [f32; QK_K / 32] = [0.0; QK_K / 32]; + let mut ls: [u8; QK_K / 32] = [0; QK_K / 32]; + let mut lm: [u8; QK_K / 32] = [0; QK_K / 32]; + + let sum_x2 = x.iter().map(|x| x * x).sum::(); + let sigma2 = 2. 
* sum_x2 / QK_K as f32; + // let av_x = sigma2.sqrt(); + + for (j, x_scale_slice) in x.chunks_exact(32).enumerate() { + for (l, (w_elem, x_elem)) in weights.iter_mut().zip(x_scale_slice).enumerate() { + let imatrix_row = sblk_idx % (n_per_row / QK_K); + let imatrix_w = imatrix_weights[imatrix_row * QK_K + 32 * j + l]; + *w_elem = imatrix_w * (sigma2 + x_elem * x_elem).sqrt(); + } + let sumw = weights.iter().sum::(); + sw[j] = sumw; + (scales[j], mins[j]) = + make_qkx3_quants(31, x_scale_slice, Some(&weights), -0.9, 0.05, 36, false); + } + + let d_block = make_qp_quants(QK_K / 32, 63, &scales, &mut ls, &sw); + let m_block = make_qp_quants(QK_K / 32, 63, &mins, &mut lm, &sw); + for j in 0..QK_K / 32 { + let ls_val = ls[j].min(63); + let lm_val = lm[j].min(63); + if j < 4 { + block.scales[j] = ls_val; + block.scales[j + 4] = lm_val; + } else { + block.scales[j + 4] = (ls_val & 0xF) | ((lm_val & 0xF) << 4); + block.scales[j - 4] |= (ls_val >> 4) << 6; + block.scales[j] |= (lm_val >> 4) << 6; + } + } + + block.d = f16::from_f32(d_block); + block.dmin = f16::from_f32(m_block); + + let mut l: [u8; QK_K] = [0; QK_K]; + for j in 0..QK_K / 32 { + let (sc, m) = get_scale_min_k4(j, &block.scales); + let d = block.d.to_f32() * sc as f32; + if d != 0.0 { + let dm = block.dmin.to_f32() * m as f32; + for ii in 0..32 { + let l_val = nearest_int((x[32 * j + ii] + dm) / d); + l[32 * j + ii] = l_val.clamp(0, 31) as u8; + } + } + } + + let qh = &mut block.qh; + let ql = &mut block.qs; + qh.fill(0); + + let mut m1 = 1; + let mut m2 = 2; + for n in (0..QK_K).step_by(64) { + let offset = (n / 64) * 32; + for j in 0..32 { + let mut l1 = l[n + j]; + if l1 > 15 { + l1 -= 16; + qh[j] |= m1; + } + let mut l2 = l[n + j + 32]; + if l2 > 15 { + l2 -= 16; + qh[j] |= m2; + } + ql[offset + j] = l1 | (l2 << 4); + } + m1 <<= 2; + m2 <<= 2; + } + } + Ok(()) + } + // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L928 fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { for (block, y) in group_for_dequantization(xs, ys)? { @@ -1658,7 +2002,94 @@ impl GgmlType for BlockQ6K { let mut max_scale = 0f32; let mut max_abs_scale = 0f32; for (ib, scale_) in scales.iter_mut().enumerate() { - let scale = make_qx_quants(16, 32, x.add(16 * ib), l.add(16 * ib), 1); + let scale = + make_qx_quants(16, 32, x.add(16 * ib), l.add(16 * ib), 1, std::ptr::null()); + *scale_ = scale; + let abs_scale = scale.abs(); + if abs_scale > max_abs_scale { + max_abs_scale = abs_scale; + max_scale = scale + } + } + + let iscale = -128f32 / max_scale; + y.d = f16::from_f32(1.0 / iscale); + + for (y_scale, scale) in y.scales.iter_mut().zip(scales.iter()) { + *y_scale = nearest_int(iscale * scale).min(127) as i8 + } + + for (j, &y_scale) in y.scales.iter().enumerate() { + let d = y.d.to_f32() * y_scale as f32; + if d == 0. 
{ + continue; + } + for ii in 0..16 { + let ll = nearest_int(*x.add(16 * j + ii) / d).clamp(-32, 31); + *l.add(16 * j + ii) = (ll + 32) as i8 + } + } + + let mut ql = y.ql.as_mut_ptr(); + let mut qh = y.qh.as_mut_ptr(); + + for j in (0..QK_K).step_by(128) { + for l_idx in 0..32 { + let q1 = *l.add(j + l_idx) & 0xF; + let q2 = *l.add(j + l_idx + 32) & 0xF; + let q3 = *l.add(j + l_idx + 64) & 0xF; + let q4 = *l.add(j + l_idx + 96) & 0xF; + *ql.add(l_idx) = (q1 | (q3 << 4)) as u8; + *ql.add(l_idx + 32) = (q2 | (q4 << 4)) as u8; + *qh.add(l_idx) = ((*l.add(j + l_idx) >> 4) + | ((*l.add(j + l_idx + 32) >> 4) << 2) + | ((*l.add(j + l_idx + 64) >> 4) << 4) + | ((*l.add(j + l_idx + 96) >> 4) << 6)) + as u8; + } + ql = ql.add(64); + qh = qh.add(32); + } + + x = x.add(QK_K) + } + } + Ok(()) + } + + fn from_float_imatrix( + xs: &[f32], + ys: &mut [Self], + imatrix_weights: &[f32], + n_per_row: usize, + ) -> Result<()> { + if xs.len() != ys.len() * Self::BLCK_SIZE { + crate::bail!( + "quantize_row_q6k: size mismatch {} {} {}", + xs.len(), + ys.len(), + Self::BLCK_SIZE + ) + } + let mut l = [0i8; QK_K]; + let mut scales = [0f32; QK_K / 16]; + let mut x = xs.as_ptr(); + let imatrix_weights = imatrix_weights.as_ptr(); + let l = l.as_mut_ptr(); + unsafe { + for (sblk_idx, y) in ys.iter_mut().enumerate() { + let mut max_scale = 0f32; + let mut max_abs_scale = 0f32; + for (ib, scale_) in scales.iter_mut().enumerate() { + let imatrix_row = sblk_idx % (n_per_row / QK_K); + let scale = make_qx_quants( + 16, + 32, + x.add(16 * ib), + l.add(16 * ib), + 1, + imatrix_weights.add(QK_K * imatrix_row + 16 * ib), + ); *scale_ = scale; let abs_scale = scale.abs(); if abs_scale > max_abs_scale { @@ -1882,6 +2313,52 @@ pub fn matmul( Ok(()) } +pub fn matmul_f16( + mkn: (usize, usize, usize), + lhs: &[f16], + rhs_t: &[T], + dst: &mut [f16], +) -> Result<()> { + let (m, k, n) = mkn; + if m * k != lhs.len() { + crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len()); + } + + let k_in_lhs_blocks = k.div_ceil(T::BLCK_SIZE); + let k_in_rhs_blocks = k.div_ceil(T::VecDotType::BLCK_SIZE); + // TODO: Do not make this copy if the DotType is f32. + // TODO: Pre-allocate this. + let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks]; + for row_idx in 0..m { + let lhs_b = &mut lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; + let lhs = &lhs[row_idx * k..(row_idx + 1) * k]; + T::VecDotType::from_float( + &lhs.into_iter().map(|&x| x.to_f32()).collect::>(), + lhs_b, + )? 
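        // As in `matmul` above, each f16 activation row is widened to f32 and
        // packed into `T::VecDotType` blocks once, so the per-column loop below
        // can reuse the quantized dot-product kernels; results are narrowed back
        // to f16 when written into `dst`.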
+ } + let lhs_b = lhs_b.as_slice(); + + for row_idx in 0..m { + let lhs_row = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; + let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n]; + + let result: Result> = dst_row + .into_par_iter() + .enumerate() + .with_min_len(128) + .with_max_len(512) + .map(|(col_idx, dst)| { + let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks]; + T::vec_dot(k, rhs_col, lhs_row).map(|value| *dst = f16::from_f32(value)) + }) + .collect(); + + result?; + } + Ok(()) +} + impl GgmlType for f32 { const DTYPE: GgmlDType = GgmlDType::F32; const BLCK_SIZE: usize = 1; @@ -1963,3 +2440,47 @@ impl GgmlType for f16 { Ok(()) } } + +impl GgmlType for bf16 { + const DTYPE: GgmlDType = GgmlDType::BF16; + const BLCK_SIZE: usize = 1; + type VecDotType = bf16; + + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + if xs.len() < n { + crate::bail!("size mismatch {} < {n}", xs.len()) + } + if ys.len() < n { + crate::bail!("size mismatch {} < {n}", ys.len()) + } + let mut res = 0f32; + unsafe { crate::cpu::vec_dot_bf16(xs.as_ptr(), ys.as_ptr(), &mut res, n) }; + Ok(res) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + if xs.len() != ys.len() { + crate::bail!("size mismatch {} {}", xs.len(), ys.len()); + } + // TODO: vectorize + for (x, y) in xs.iter().zip(ys.iter_mut()) { + *y = bf16::from_f32(*x) + } + Ok(()) + } + + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + if xs.len() != ys.len() { + crate::bail!("size mismatch {} {}", xs.len(), ys.len()); + } + // TODO: vectorize + for (x, y) in xs.iter().zip(ys.iter_mut()) { + *y = x.to_f32() + } + Ok(()) + } +} diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs index f7f5b68ac2..d06b0674a7 100644 --- a/candle-core/src/quantized/metal.rs +++ b/candle-core/src/quantized/metal.rs @@ -1,6 +1,6 @@ use super::{GgmlDType, QStorage}; use crate::backend::BackendStorage; -use crate::{DType, MetalDevice, MetalStorage, Result, Shape}; +use crate::{DType, MetalDevice, MetalStorage, Result, Shape, D}; use metal::Buffer; use std::sync::Arc; @@ -55,6 +55,10 @@ impl QMetalStorage { let vec: Vec = read_to_vec(&buffer, block_len); half::f16::to_float(&vec, &mut out)?; } + GgmlDType::BF16 => { + let vec: Vec = read_to_vec(&buffer, block_len); + half::bf16::to_float(&vec, &mut out)?; + } GgmlDType::Q4_0 => { let vec: Vec = read_to_vec(&buffer, block_len); crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?; @@ -126,11 +130,65 @@ impl QMetalStorage { Ok(()) } + pub fn quantize_imatrix( + &mut self, + src: &MetalStorage, + imatrix_weights: &[f32], + n_per_row: usize, + ) -> Result<()> { + // Quantization only happens on CPU for now. + let src = src.to_cpu::()?; + let elem_count = src.len(); + let src = crate::Storage::Cpu(crate::CpuStorage::F32(src)); + let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?; + qcpu_storage.quantize_imatrix(&src, imatrix_weights, n_per_row)?; + let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?; + self.buffer = buffer; + Ok(()) + } + + pub fn quantize_imatrix_onto( + &mut self, + src: &crate::CpuStorage, + imatrix_weights: &[f32], + n_per_row: usize, + ) -> Result<()> { + // Quantization only happens on CPU for now. 
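        // Same flow as the plain `quantize_onto` defined below: quantize on the
        // host with the imatrix-weighted `from_float_imatrix`, then upload the raw
        // block bytes of the resulting CPU `QStorage` into a fresh Metal buffer.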
+ let elem_count = src.as_slice::()?.len(); + let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?; + + if let QStorage::Cpu(storage) = &mut qcpu_storage { + storage.from_float_imatrix(src.as_slice::()?, imatrix_weights, n_per_row)?; + } else { + unreachable!() + } + + let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?; + self.buffer = buffer; + Ok(()) + } + + pub fn quantize_onto(&mut self, src: &crate::CpuStorage) -> Result<()> { + // Quantization only happens on CPU for now. + let elem_count = src.as_slice::()?.len(); + let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?; + + if let QStorage::Cpu(storage) = &mut qcpu_storage { + storage.from_float(src.as_slice::()?)?; + } else { + unreachable!() + } + + let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?; + self.buffer = buffer; + Ok(()) + } + pub fn storage_size_in_bytes(&self) -> usize { self.buffer.length() as usize } - pub fn fwd( + fn fwd_mv( &self, self_shape: &Shape, storage: &MetalStorage, @@ -186,6 +244,112 @@ impl QMetalStorage { let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32); Ok((dst_storage, dst_shape)) } + + pub fn fwd( + &self, + self_shape: &Shape, + storage: &MetalStorage, + layout: &crate::Layout, + ) -> Result<(MetalStorage, Shape)> { + use crate::MetalError; + + if !layout.is_contiguous() { + crate::bail!("input tensor is not contiguous {layout:?}") + } + let src_shape = layout.shape(); + // self is transposed so n is first then k. + if src_shape.rank() < 2 { + crate::bail!("input tensor has only one dimension {layout:?}") + } + let n = self_shape.dim(D::Minus2)?; + let k = self_shape.dim(D::Minus1)?; + let mut dst_shape = src_shape.dims().to_vec(); + + if src_shape.rank() < self_shape.rank() { + crate::bail!( + "input rank ({}) must be >= weight rank ({})", + src_shape.rank(), + self_shape.rank() + ) + } + + if src_shape.dim(D::Minus2)? 
== 1 { + return self.fwd_mv(self_shape, storage, layout); + } + + let last_k = dst_shape.pop().unwrap(); + if last_k != k { + crate::bail!("input tensor {layout:?} incompatible with {:?}", self_shape) + } + dst_shape.push(n); + let dst_shape = Shape::from(dst_shape); + let device = storage.device().clone(); + let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?; + let command_buffer = device.command_buffer()?; + + assert_eq!(storage.dtype(), DType::F32); + + if self_shape.rank() > 4 { + crate::bail!("weight rank ({}) must be <= 4", self_shape.rank()) + } + let src0_l = crate::Layout::contiguous( + [vec![1; 4 - self_shape.rank()], self_shape.dims().to_vec()].concat(), + ); + let src0_stride = src0_l + .stride() + .iter() + .map(|x| { + (*x as f32 * (self.dtype.type_size() as f32 / self.dtype.block_size() as f32)) + as usize + }) + .collect::>(); + + if src_shape.rank() > 4 { + crate::bail!("weight rank ({}) must be <= 4", src_shape.rank()) + } + let src1_l = crate::Layout::contiguous( + [vec![1; 4 - src_shape.rank()], src_shape.dims().to_vec()].concat(), + ); + + candle_metal_kernels::call_quantized_matmul_mm_t( + device.device(), + &command_buffer, + device.kernels(), + self.dtype.into(), + src0_l.dims(), + &src0_stride, + &self.buffer, + src1_l.dims(), + &src1_l + .stride() + .iter() + .map(|x| x * DType::F32.size_in_bytes()) + .collect::>(), + storage.buffer(), + src1_l.start_offset() * storage.dtype().size_in_bytes(), + dst_shape.dims(), + 0, + &dst, + ) + .map_err(MetalError::from)?; + + let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32); + Ok((dst_storage, dst_shape)) + } + + pub fn data(&self) -> Result> { + let buffer = self.device.new_buffer_managed(self.buffer.length())?; + { + let command_buffer = self.device.command_buffer()?; + command_buffer.set_label("to_cpu"); + let blit = command_buffer.new_blit_command_encoder(); + blit.set_label("blit_to_cpu"); + blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); + blit.end_encoding(); + } + self.device.wait_until_completed()?; + Ok(read_to_vec::(&buffer, self.buffer.length() as usize)) + } } pub fn load_quantized( @@ -225,6 +389,7 @@ impl From for candle_metal_kernels::GgmlDType { GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K, GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16, GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32, + GgmlDType::BF16 => candle_metal_kernels::GgmlDType::F16, } } } diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 802c5691f0..a39f864ded 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -1,5 +1,6 @@ -//! 
Code for GGML and GGUF files -use crate::{Context, CpuStorage, DType, Device, Result, Shape, Storage, Tensor}; +use crate::{ + backend::BackendStorage, CpuStorage, DType, Device, Result, Shape, Storage, Tensor, D, +}; use k_quants::*; use std::borrow::Cow; @@ -9,6 +10,7 @@ mod dummy_cuda; mod dummy_metal; pub mod ggml_file; pub mod gguf_file; +pub mod imatrix_file; pub mod k_quants; #[cfg(feature = "metal")] pub mod metal; @@ -28,10 +30,26 @@ pub mod neon; #[cfg(target_feature = "simd128")] pub mod simd128; pub mod utils; -use half::f16; +use half::{bf16, f16}; pub use k_quants::GgmlType; +fn as_t_slice(data: Cow<'_, [u8]>) -> &[T] { + let size = std::mem::size_of::(); + assert_eq!( + data.len() % size, + 0, + "Data length must be a multiple of T's size" + ); + let ptr = data.as_ptr(); + assert_eq!( + (ptr as usize) % std::mem::align_of::(), + 0, + "Data pointer must be aligned to T's alignment" + ); + unsafe { std::slice::from_raw_parts(ptr as *const T, data.len() / size) } +} + pub struct QTensor { storage: QStorage, shape: Shape, @@ -63,6 +81,46 @@ pub enum QStorage { } impl QStorage { + pub fn from_data(data: Cow<'_, [u8]>, device: &Device, dtype: GgmlDType) -> Result { + match device { + Device::Cpu => Ok(Self::Cpu(dtype.from_data(data))), + Device::Metal(d) => match dtype { + GgmlDType::F32 => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::F16 => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q4_0 => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q4_1 => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q5_0 => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q5_1 => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q8_0 => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q8_1 => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q2K => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q3K => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q4K => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q5K => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q6K => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q8K => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::BF16 => metal::load_quantized(d, as_t_slice::(data)), + }, + Device::Cuda(d) => match dtype { + GgmlDType::F32 => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::F16 => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q4_0 => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q4_1 => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q5_0 => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q5_1 => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q8_0 => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q8_1 => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q2K => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q3K => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q4K => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q5K => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q6K => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q8K => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::BF16 => cuda::load_quantized(d, as_t_slice::(data)), + }, + } + } + fn block_size(&self) -> usize { match self { QStorage::Cpu(storage) => storage.block_size(), @@ -102,7 +160,61 @@ impl QStorage { } (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?, 
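            // `quantize` requires `self` and `src` to live on the same device; the
            // `*_onto` variants below instead take a CPU `src` and upload the
            // quantized blocks to the destination device.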
(QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?, - _ => crate::bail!("Invalid dequantize storage locations do not match"), + _ => crate::bail!("Invalid quantize storage locations do not match"), + } + Ok(()) + } + + fn quantize_imatrix( + &mut self, + src: &Storage, + imatrix_weights: &[f32], + n_per_row: usize, + ) -> Result<()> { + match (self, src) { + (QStorage::Cpu(storage), Storage::Cpu(src)) => { + storage.from_float_imatrix(src.as_slice::()?, imatrix_weights, n_per_row)?; + } + (QStorage::Metal(storage), Storage::Metal(src)) => { + storage.quantize_imatrix(src, imatrix_weights, n_per_row)? + } + (QStorage::Cuda(storage), Storage::Cuda(src)) => { + storage.quantize_imatrix(src, imatrix_weights, n_per_row)? + } + _ => crate::bail!("Invalid quantize storage locations do not match"), + } + Ok(()) + } + + fn quantize_onto(&mut self, src: &Storage) -> Result<()> { + match (self, src) { + (QStorage::Cpu(storage), Storage::Cpu(src)) => { + storage.from_float(src.as_slice::()?)?; + } + (QStorage::Metal(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?, + (QStorage::Cuda(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?, + _ => crate::bail!("Invalid quantize source storage locations: not on cpu"), + } + Ok(()) + } + + fn quantize_imatrix_onto( + &mut self, + src: &Storage, + imatrix_weights: &[f32], + n_per_row: usize, + ) -> Result<()> { + match (self, src) { + (QStorage::Cpu(storage), Storage::Cpu(src)) => { + storage.from_float_imatrix(src.as_slice::()?, imatrix_weights, n_per_row)?; + } + (QStorage::Metal(storage), Storage::Cpu(src)) => { + storage.quantize_imatrix_onto(src, imatrix_weights, n_per_row)? + } + (QStorage::Cuda(storage), Storage::Cpu(src)) => { + storage.quantize_imatrix_onto(src, imatrix_weights, n_per_row)? 
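                // `imatrix_weights` carries one importance weight per column of the
                // weight matrix (length `n_per_row`), typically read from a
                // llama.cpp-style imatrix file (see `imatrix_file`); the k-quant
                // routines fold it into the per-element weight as
                //     w = importance * sqrt(sigma2 + x * x)
                // with sigma2 a per-super-block mean-square term, so columns that
                // matter more to the calibration activations are quantized more
                // accurately.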
+ } + _ => crate::bail!("Invalid quantize storage locations do not match"), } Ok(()) } @@ -123,9 +235,8 @@ impl QStorage { let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) }; Ok(Cow::from(data)) } - QStorage::Metal(_) | QStorage::Cuda(_) => { - crate::bail!("not implemented"); - } + QStorage::Cuda(storage) => Ok(Cow::from(storage.data()?)), + QStorage::Metal(storage) => Ok(Cow::from(storage.data()?)), } } } @@ -134,6 +245,7 @@ impl QStorage { pub enum GgmlDType { F32, F16, + BF16, Q4_0, Q4_1, Q5_0, @@ -165,6 +277,8 @@ impl GgmlDType { 13 => Self::Q5K, 14 => Self::Q6K, 15 => Self::Q8K, + // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389 + 30 => Self::BF16, _ => crate::bail!("unknown dtype for tensor {u}"), }; Ok(dtype) @@ -186,6 +300,8 @@ impl GgmlDType { Self::Q5K => 13, Self::Q6K => 14, Self::Q8K => 15, + // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389 + Self::BF16 => 30, } } @@ -206,14 +322,36 @@ impl GgmlDType { Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]), Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]), Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]), + Self::BF16 => Box::new(vec![bf16::zeros(); elem_count]), + } + } + + pub fn from_data(&self, data: Cow<'_, [u8]>) -> Box { + match self { + Self::F32 => Box::new(as_t_slice::(data).to_vec()), + Self::F16 => Box::new(as_t_slice::(data).to_vec()), + Self::Q4_0 => Box::new(as_t_slice::(data).to_vec()), + Self::Q4_1 => Box::new(as_t_slice::(data).to_vec()), + Self::Q5_0 => Box::new(as_t_slice::(data).to_vec()), + Self::Q5_1 => Box::new(as_t_slice::(data).to_vec()), + Self::Q8_0 => Box::new(as_t_slice::(data).to_vec()), + Self::Q8_1 => Box::new(as_t_slice::(data).to_vec()), + Self::Q2K => Box::new(as_t_slice::(data).to_vec()), + Self::Q3K => Box::new(as_t_slice::(data).to_vec()), + Self::Q4K => Box::new(as_t_slice::(data).to_vec()), + Self::Q5K => Box::new(as_t_slice::(data).to_vec()), + Self::Q6K => Box::new(as_t_slice::(data).to_vec()), + Self::Q8K => Box::new(as_t_slice::(data).to_vec()), + Self::BF16 => Box::new(as_t_slice::(data).to_vec()), } } + /// The type size for blocks in bytes. 
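    /// For example `Q4_0` stores 32 weights in an 18-byte block (an f16 scale plus
    /// 16 bytes of 4-bit quants), while `Q4K` stores a 256-weight super-block in
    /// 144 bytes.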
pub fn type_size(&self) -> usize { use k_quants::*; match self { Self::F32 => 4, - Self::F16 => 2, + Self::F16 | Self::BF16 => 2, Self::Q4_0 => std::mem::size_of::(), Self::Q4_1 => std::mem::size_of::(), Self::Q5_0 => std::mem::size_of::(), @@ -234,7 +372,7 @@ impl GgmlDType { pub fn block_size(&self) -> usize { match self { Self::F32 => 1, - Self::F16 => 1, + Self::F16 | Self::BF16 => 1, Self::Q4_0 => k_quants::QK4_0, Self::Q4_1 => k_quants::QK4_1, Self::Q5_0 => k_quants::QK5_0, @@ -250,12 +388,20 @@ impl GgmlDType { pub trait QuantizedType: Send + Sync { fn dtype(&self) -> GgmlDType; fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>; + fn matmul_t_f16(&self, mkn: (usize, usize, usize), lhs: &[f16], dst: &mut [f16]) -> Result<()>; fn dequantize(&self, elem_count: usize) -> Result; fn storage_size_in_bytes(&self) -> usize; fn as_ptr(&self) -> *const u8; fn block_size(&self) -> usize; #[allow(clippy::wrong_self_convention)] fn from_float(&mut self, xs: &[f32]) -> Result<()>; + #[allow(clippy::wrong_self_convention)] + fn from_float_imatrix( + &mut self, + xs: &[f32], + imatrix_weights: &[f32], + n_per_row: usize, + ) -> Result<()>; fn size(&self) -> usize; } @@ -263,6 +409,9 @@ impl QuantizedType for Vec { fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> { k_quants::matmul(mkn, lhs, self.as_slice(), dst) } + fn matmul_t_f16(&self, mkn: (usize, usize, usize), lhs: &[f16], dst: &mut [f16]) -> Result<()> { + k_quants::matmul_f16(mkn, lhs, self.as_slice(), dst) + } fn size(&self) -> usize { self.len() * core::mem::size_of::() @@ -272,6 +421,15 @@ impl QuantizedType for Vec { T::from_float(xs, self) } + fn from_float_imatrix( + &mut self, + xs: &[f32], + imatrix_weights: &[f32], + n_per_row: usize, + ) -> Result<()> { + T::from_float_imatrix(xs, self, imatrix_weights, n_per_row) + } + fn dtype(&self) -> GgmlDType { T::DTYPE } @@ -342,6 +500,112 @@ impl QTensor { }) } + pub fn quantize_imatrix( + src: &Tensor, + imatrix_weights: &[f32], + dtype: GgmlDType, + ) -> Result { + // (n_per_row/QK_K-1)*QK_K+(QK_K/32-1)*32+32=n_per_row + // Size of imatrix == last dim of tensor + let n_per_row = src.dim(D::Minus1)?; + if imatrix_weights.len() != n_per_row { + crate::bail!( + "imatrix weights must have the same length {} as the last dim of src {}", + imatrix_weights.len(), + src.dim(D::Minus1)? 
+ ); + } + + let shape = src.shape(); + let block_size = dtype.block_size(); + check_shape(shape, block_size)?; + let src = src.to_dtype(crate::DType::F32)?.flatten_all()?; + let elem_count = shape.elem_count(); + if elem_count % block_size != 0 { + crate::bail!( + "tensor size ({shape:?}) is not divisible by block size {}", + block_size + ); + } + let mut storage = src.device().qzeros(elem_count, dtype)?; + storage.quantize_imatrix(&src.storage(), imatrix_weights, n_per_row)?; + Ok(Self { + storage, + shape: shape.clone(), + }) + } + + /// Quantize `src` (currently on the CPU) to a QTensor on `dev` + pub fn quantize_imatrix_onto( + src: &Tensor, + imatrix_weights: &[f32], + dtype: GgmlDType, + dev: &Device, + ) -> Result { + if !src.device().is_cpu() { + crate::bail!( + "`quantize_onto` expects a `src` to be on the cpu, got {:?}.", + src.device() + ) + } + // (n_per_row/QK_K-1)*QK_K+(QK_K/32-1)*32+32=n_per_row + // Size of imatrix == last dim of tensor + let n_per_row = src.dim(D::Minus1)?; + if imatrix_weights.len() != n_per_row { + crate::bail!( + "imatrix weights must have the same length {} as the last dim of src {}", + imatrix_weights.len(), + src.dim(D::Minus1)? + ); + } + let shape = src.shape(); + let block_size = dtype.block_size(); + check_shape(shape, block_size)?; + let src = src.to_dtype(crate::DType::F32)?.flatten_all()?; + let elem_count = shape.elem_count(); + if elem_count % block_size != 0 { + crate::bail!( + "tensor size ({shape:?}) is not divisible by block size {}", + block_size + ) + } + // storage is on the `dev`, src is on `cpu` + let mut storage = dev.qzeros(elem_count, dtype)?; + storage.quantize_imatrix_onto(&src.storage(), imatrix_weights, n_per_row)?; + Ok(Self { + storage, + shape: shape.clone(), + }) + } + + /// Quantize `src` (currently on the CPU) to a QTensor on `dev` + pub fn quantize_onto(src: &Tensor, dtype: GgmlDType, dev: &Device) -> Result { + if !src.device().is_cpu() { + crate::bail!( + "`quantize_onto` expects a `src` to be on the cpu, got {:?}.", + src.device() + ) + } + let shape = src.shape(); + let block_size = dtype.block_size(); + check_shape(shape, block_size)?; + let src = src.to_dtype(crate::DType::F32)?.flatten_all()?; + let elem_count = shape.elem_count(); + if elem_count % block_size != 0 { + crate::bail!( + "tensor size ({shape:?}) is not divisible by block size {}", + block_size + ) + } + // storage is on the `dev`, src is on `cpu` + let mut storage = dev.qzeros(elem_count, dtype)?; + storage.quantize_onto(&src.storage())?; + Ok(Self { + storage, + shape: shape.clone(), + }) + } + pub fn dtype(&self) -> GgmlDType { self.storage.dtype() } @@ -422,7 +686,7 @@ thread_local! 
{ impl QMatMul { pub fn from_arc(qtensor: std::sync::Arc) -> Result { let dequantize = match qtensor.dtype() { - GgmlDType::F32 | GgmlDType::F16 => true, + GgmlDType::F32 | GgmlDType::F16 | GgmlDType::BF16 => true, _ => DEQUANTIZE_ALL.with(|b| *b), }; let t = if dequantize { @@ -481,7 +745,7 @@ impl crate::CustomOp1 for QTensor { crate::bail!("input tensor has only one dimension {layout:?}") } let mut dst_shape = src_shape.dims().to_vec(); - let last_k = dst_shape.pop().context("empty dst_shape")?; + let last_k = dst_shape.pop().unwrap(); if last_k != k { crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape) } @@ -492,11 +756,33 @@ impl crate::CustomOp1 for QTensor { QStorage::Cpu(storage) => storage, QStorage::Metal(_) | QStorage::Cuda(_) => crate::bail!("Invalid storage"), }; - let slice = storage.as_slice::()?; - let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()]; - let mut dst_storage = vec![0f32; dst_shape.elem_count()]; - self_storage.matmul_t((dst_shape.elem_count() / n, k, n), slice, &mut dst_storage)?; - Ok((crate::CpuStorage::F32(dst_storage), dst_shape)) + match storage.dtype() { + DType::F32 => { + let slice = storage.as_slice::()?; + let slice = + &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()]; + let mut dst_storage = vec![0f32; dst_shape.elem_count()]; + self_storage.matmul_t( + (dst_shape.elem_count() / n, k, n), + slice, + &mut dst_storage, + )?; + Ok((crate::CpuStorage::F32(dst_storage), dst_shape)) + } + DType::F16 => { + let slice = storage.as_slice::()?; + let slice = + &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()]; + let mut dst_storage = vec![f16::ZERO; dst_shape.elem_count()]; + self_storage.matmul_t_f16( + (dst_shape.elem_count() / n, k, n), + slice, + &mut dst_storage, + )?; + Ok((crate::CpuStorage::F16(dst_storage), dst_shape)) + } + _ => crate::bail!("Expected f32/f16"), + } } fn metal_fwd( diff --git a/candle-core/src/quantized/utils.rs b/candle-core/src/quantized/utils.rs index fa6eff51d3..0a087cddbb 100644 --- a/candle-core/src/quantized/utils.rs +++ b/candle-core/src/quantized/utils.rs @@ -64,6 +64,7 @@ pub(super) unsafe fn make_qx_quants( x: *const f32, ls: *mut i8, rmse_type: i32, + qw: *const f32, ) -> f32 { let mut max = 0f32; let mut amax = 0f32; @@ -99,7 +100,13 @@ pub(super) unsafe fn make_qx_quants( let l = nearest_int(iscale * x); let l = l.clamp(-nmax, nmax - 1); *ls.add(i) = (l + nmax) as i8; - let w = if weight_type == 1 { x * x } else { 1.0 }; + let w = if !qw.is_null() { + *qw.add(i) + } else if weight_type == 1 { + x * x + } else { + 1.0 + }; let l = l as f32; sumlx += w * x * l; suml2 += w * l * l; @@ -118,7 +125,13 @@ pub(super) unsafe fn make_qx_quants( if l + nmax != *ls.add(i) as i32 { changed = true; } - let w = if weight_type == 1 { x * x } else { 1f32 }; + let w = if !qw.is_null() { + *qw.add(i) + } else if weight_type == 1 { + x * x + } else { + 1.0 + }; let l = l as f32; slx += w * x * l; sl2 += w * l * l; @@ -140,7 +153,13 @@ pub(super) unsafe fn make_qx_quants( let mut n_changed = 0; for i in 0..n { let x = *x.add(i); - let w = if weight_type == 1 { x * x } else { 1. }; + let w = if !qw.is_null() { + *qw.add(i) + } else if weight_type == 1 { + x * x + } else { + 1.0 + }; let l = *ls.add(i) as i32 - nmax; let mut slx = sumlx - w * x * l as f32; if slx > 0. 
{ @@ -179,7 +198,13 @@ pub(super) unsafe fn make_qx_quants( let x = *x.add(i); let l = nearest_int(iscale * x); let l = l.clamp(-nmax, nmax - 1); - let w = if weight_type == 1 { x * x } else { 1. }; + let w = if !qw.is_null() { + *qw.add(i) + } else if weight_type == 1 { + x * x + } else { + 1.0 + }; let l = l as f32; sumlx += w * x * l; suml2 += w * l * l; @@ -198,6 +223,7 @@ pub(super) unsafe fn make_qx_quants( } // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L224 +/// (scale, min) pub(super) fn make_qkx1_quants(nmax: i32, ntry: usize, x: &[f32]) -> (f32, f32) { let n = x.len(); let mut l = vec![0; n]; @@ -324,3 +350,213 @@ pub(super) fn make_q3_quants(x: &[f32], nmax: i32, do_rmse: bool) -> f32 { } 1.0 / iscale } + +// https://github.com/ggerganov/llama.cpp/blob/678d7994f4da0af3d29046be99950ac999ee9762/ggml/src/ggml-quants.c#L744 +/// (scale, min) +pub(super) fn make_qkx3_quants( + nmax: i32, + x: &[f32], + weights: Option<&[f32]>, + rmin: f32, + rdelta: f32, + nstep: usize, + use_mad: bool, +) -> (f32, f32) { + let n = x.len(); + let mut l: [u8; 32] = [0; 32]; + let mut l_aux: [u8; 32] = [0; 32]; + + let mut min_val = x[0]; + let mut max_val = x[0]; + let mut sum_w = match weights { + Some(w) => w[0], + None => x[0] * x[0], + }; + let mut sum_x = sum_w * x[0]; + + for i in 1..n { + if x[i] < min_val { + min_val = x[i]; + } + if x[i] > max_val { + max_val = x[i]; + } + let w = match weights { + Some(w) => w[i], + None => x[i] * x[i], + }; + sum_w += w; + sum_x += w * x[i]; + } + + if min_val > 0.0 { + min_val = 0.0; + } + + if max_val <= min_val { + return (0.0, -min_val); + } + + let mut iscale = nmax as f32 / (max_val - min_val); + let mut scale = 1.0 / iscale; + let mut best_mad = 0.0; + + for i in 0..n { + let l_val = nearest_int(iscale * (x[i] - min_val)).clamp(0, nmax) as u8; + l[i] = l_val; + let diff = scale * (l_val as f32) + min_val - x[i]; + let diff = if use_mad { diff.abs() } else { diff * diff }; + let w = match weights { + Some(w) => w[i], + None => x[i] * x[i], + }; + best_mad += w * diff; + } + + if nstep < 1 { + return (scale, -min_val); + } + + for is in 0..=nstep { + iscale = (rmin + rdelta * is as f32 + nmax as f32) / (max_val - min_val); + let (mut sum_l, mut sum_l2, mut sum_xl) = (0.0, 0.0, 0.0); + + for i in 0..n { + let l_val = nearest_int(iscale * (x[i] - min_val)).clamp(0, nmax) as u8; + l_aux[i] = l_val; + let w = match weights { + Some(w) => w[i], + None => x[i] * x[i], + }; + sum_l += w * l_val as f32; + sum_l2 += w * (l_val as f32).powi(2); + sum_xl += w * l_val as f32 * x[i]; + } + + let d = sum_w * sum_l2 - sum_l * sum_l; + if d > 0.0 { + let mut this_scale = (sum_w * sum_xl - sum_x * sum_l) / d; + let mut this_min = (sum_l2 * sum_x - sum_l * sum_xl) / d; + + if this_min > 0.0 { + this_min = 0.0; + this_scale = sum_xl / sum_l2; + } + + let mut mad = 0.0; + for i in 0..n { + let diff = this_scale * (l_aux[i] as f32) + this_min - x[i]; + let diff = if use_mad { diff.abs() } else { diff * diff }; + let w = match weights { + Some(w) => w[i], + None => x[i] * x[i], + }; + mad += w * diff; + } + + if mad < best_mad { + l.copy_from_slice(&l_aux); + best_mad = mad; + scale = this_scale; + min_val = this_min; + } + } + } + + (scale, -min_val) +} + +// https://github.com/ggerganov/llama.cpp/blob/678d7994f4da0af3d29046be99950ac999ee9762/ggml/src/ggml-quants.c#L827 +pub(super) fn make_qp_quants( + n: usize, + nmax: u8, + x: &[f32], + l: &mut [u8], + quant_weights: &[f32], +) -> f32 { + assert_eq!(x.len(), n); + 
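    // Fits a single non-negative scale to the (already non-negative) sub-block
    // scales/mins of the k-quants: start from `nmax / max`, try a few nearby
    // candidate scales and keep the one with the lowest `quant_weights`-weighted
    // squared error, then greedily retune individual quants while the ratio
    // sumlx^2 / suml2 improves. The returned scale is the weighted least-squares
    // solution sumlx / suml2.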
assert_eq!(l.len(), n); + assert_eq!(quant_weights.len(), n); + + let max = x.iter().copied().fold(0.0, f32::max); + if max == 0.0 { + l.iter_mut().for_each(|li| *li = 0); + return 0.0; + } + + let mut iscale = nmax as f32 / max; + for (xi, li) in x.iter().zip(l.iter_mut()) { + *li = nearest_int(iscale * xi) as u8; + } + + let scale = 1.0 / iscale; + let mut best_mse = x + .iter() + .zip(l.iter()) + .zip(quant_weights.iter()) + .map(|((&xi, &li), &w)| { + let diff = xi - scale * li as f32; + w * diff * diff + }) + .sum::(); + + for is in -4..=4 { + if is == 0 { + continue; + } + let iscale_is = (0.1 * is as f32 + nmax as f32) / max; + let scale_is = 1.0 / iscale_is; + + let mse = x + .iter() + .zip(quant_weights.iter()) + .map(|(&xi, &w)| { + let mut li = nearest_int(iscale_is * xi) as u8; + li = li.min(nmax); + let diff = xi - scale_is * li as f32; + w * diff * diff + }) + .sum::(); + + if mse < best_mse { + best_mse = mse; + iscale = iscale_is; + } + } + + let mut sumlx = 0.0; + let mut suml2 = 0.0; + for ((xi, li), &w) in x.iter().zip(l.iter_mut()).zip(quant_weights.iter()) { + let mut li_new = (iscale * xi).round() as u8; + li_new = li_new.min(nmax); + *li = li_new; + sumlx += w * xi * li_new as f32; + suml2 += w * (li_new as f32).powi(2); + } + + for _ in 0..5 { + let mut n_changed = 0; + for ((xi, li), &w) in x.iter().zip(l.iter_mut()).zip(quant_weights.iter()) { + let mut slx = sumlx - w * xi * *li as f32; + let mut sl2 = suml2 - w * (*li as f32).powi(2); + if slx > 0.0 && sl2 > 0.0 { + let new_li = (nearest_int(xi * sl2 / slx) as u8).min(nmax); + if new_li != *li { + slx += w * xi * new_li as f32; + sl2 += w * (new_li as f32).powi(2); + if slx.powi(2) * suml2 > sumlx.powi(2) * sl2 { + *li = new_li; + sumlx = slx; + suml2 = sl2; + n_changed += 1; + } + } + } + } + if n_changed == 0 { + break; + } + } + + sumlx / suml2 +} diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index d402d6b8e0..e1f77b927d 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -9,6 +9,9 @@ //! Tensors can also be serialized to safetensor format using the `save` function or //! `Tensor::save_safetensors` method. //! 
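// Minimal usage sketch for this module's helpers (the file name below is
// hypothetical): serialize one tensor, read it back, and check its shape.
#[allow(dead_code)]
fn _save_load_roundtrip() -> crate::Result<()> {
    use crate::{Device, Tensor};
    let t = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
    // `save_safetensors` writes a single named tensor to a .safetensors file.
    t.save_safetensors("t", "example.safetensors")?;
    // `load` reads the file back and rebuilds a name -> Tensor map on the device.
    let tensors = load("example.safetensors", &Device::Cpu)?;
    assert_eq!(tensors["t"].dims(), &[2, 3]);
    Ok(())
}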
+use crate::op::BackpropOp; +use crate::storage::Storage; +use crate::tensor::from_storage; use crate::{DType, Device, Error, Result, Tensor, WithDType}; use safetensors::tensor as st; use safetensors::tensor::SafeTensors; @@ -21,11 +24,18 @@ impl From for st::Dtype { match value { DType::U8 => st::Dtype::U8, DType::U32 => st::Dtype::U32, + DType::I16 => st::Dtype::I16, + DType::I32 => st::Dtype::I32, DType::I64 => st::Dtype::I64, DType::BF16 => st::Dtype::BF16, DType::F16 => st::Dtype::F16, DType::F32 => st::Dtype::F32, DType::F64 => st::Dtype::F64, + DType::F8E4M3 => st::Dtype::F8_E4M3, + DType::F6E2M3 => st::Dtype::F6_E2M3, + DType::F6E3M2 => st::Dtype::F6_E3M2, + DType::F4 => st::Dtype::F4, + DType::F8E8M0 => st::Dtype::F8_E8M0, } } } @@ -36,11 +46,18 @@ impl TryFrom for DType { match value { st::Dtype::U8 => Ok(DType::U8), st::Dtype::U32 => Ok(DType::U32), + st::Dtype::I16 => Ok(DType::I16), + st::Dtype::I32 => Ok(DType::I32), st::Dtype::I64 => Ok(DType::I64), st::Dtype::BF16 => Ok(DType::BF16), st::Dtype::F16 => Ok(DType::F16), st::Dtype::F32 => Ok(DType::F32), st::Dtype::F64 => Ok(DType::F64), + st::Dtype::F8_E4M3 => Ok(DType::F8E4M3), + st::Dtype::F6_E2M3 => Ok(DType::F6E2M3), + st::Dtype::F6_E3M2 => Ok(DType::F6E3M2), + st::Dtype::F4 => Ok(DType::F4), + st::Dtype::F8_E8M0 => Ok(DType::F8E8M0), dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)), } } @@ -91,7 +108,7 @@ impl st::View for &Tensor { impl Tensor { pub fn save_safetensors>(&self, name: &str, filename: P) -> Result<()> { let data = [(name, self.clone())]; - Ok(st::serialize_to_file(data, &None, filename.as_ref())?) + Ok(st::serialize_to_file(data, None, filename.as_ref())?) } } @@ -198,11 +215,70 @@ impl Tensor { match dtype { DType::U8 => convert_slice::(data, shape, device), DType::U32 => convert_slice::(data, shape, device), + DType::I16 => convert_slice::(data, shape, device), + DType::I32 => convert_slice::(data, shape, device), DType::I64 => convert_slice::(data, shape, device), DType::BF16 => convert_slice::(data, shape, device), DType::F16 => convert_slice::(data, shape, device), DType::F32 => convert_slice::(data, shape, device), DType::F64 => convert_slice::(data, shape, device), + DType::F8E4M3 => convert_slice::(data, shape, device), + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + // For dummy types, create storage with raw bytes + let storage = match device { + Device::Cpu => { + let cpu_storage = match dtype { + DType::F6E2M3 => crate::cpu_backend::CpuStorage::F6E2M3(data.to_vec()), + DType::F6E3M2 => crate::cpu_backend::CpuStorage::F6E3M2(data.to_vec()), + DType::F4 => crate::cpu_backend::CpuStorage::F4(data.to_vec()), + DType::F8E8M0 => crate::cpu_backend::CpuStorage::F8E8M0(data.to_vec()), + _ => unreachable!(), + }; + Storage::Cpu(cpu_storage) + } + #[cfg(feature = "cuda")] + Device::Cuda(device) => { + let mut slice = unsafe { device.alloc::(data.len())? 
}; + device.memcpy_htod(data, &mut slice)?; + + let slice = match dtype { + DType::F6E2M3 => crate::cuda_backend::CudaStorageSlice::F6E2M3(slice), + DType::F6E3M2 => crate::cuda_backend::CudaStorageSlice::F6E3M2(slice), + DType::F4 => crate::cuda_backend::CudaStorageSlice::F4(slice), + DType::F8E8M0 => crate::cuda_backend::CudaStorageSlice::F8E8M0(slice), + _ => unreachable!(), + }; + let storage = crate::cuda_backend::CudaStorage { + slice, + device: device.clone(), + }; + Storage::Cuda(storage) + } + #[cfg(not(feature = "cuda"))] + Device::Cuda(_) => { + return Err(Error::Msg("CUDA support not compiled".to_string())); + } + #[cfg(feature = "metal")] + Device::Metal(device) => { + let buffer = device.new_buffer_with_data(data)?; + + let storage = crate::metal_backend::MetalStorage::new( + buffer, + device.clone(), + data.len(), + dtype, + ); + Storage::Metal(storage) + } + #[cfg(not(feature = "metal"))] + Device::Metal(_) => { + return Err(Error::Msg("Metal support not compiled".to_string())); + } + }; + + let op = BackpropOp::none(); + Ok(from_storage(storage, shape, op, false)) + } } } } @@ -215,30 +291,109 @@ fn convert(view: &st::TensorView<'_>, device: &Device) -> Result { convert_with_cast_::(view, device, conv) } st::Dtype::U32 => convert_::(view, device), - st::Dtype::I32 => { - let conv = |x| Ok(i64::from(x)); - convert_with_cast_::(view, device, conv) - } + st::Dtype::I16 => convert_::(view, device), + st::Dtype::I32 => convert_::(view, device), st::Dtype::I64 => convert_::(view, device), st::Dtype::BF16 => convert_::(view, device), st::Dtype::F16 => convert_::(view, device), st::Dtype::F32 => convert_::(view, device), st::Dtype::F64 => convert_::(view, device), + st::Dtype::F8_E4M3 => convert_::(view, device), + st::Dtype::F6_E2M3 | st::Dtype::F6_E3M2 | st::Dtype::F4 | st::Dtype::F8_E8M0 => { + // For dummy types, we need to handle loading by creating a dummy tensor + // Since these types don't have actual data representation, we'll create + // a tensor that indicates it's a dummy type + convert_dummy(view, device) + } dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)), } } +fn convert_dummy(view: &st::TensorView<'_>, device: &Device) -> Result { + // For dummy types, we'll create the appropriate storage variant that preserves + // both the raw data and the correct dtype + let (dtype, _dtype_name) = match view.dtype() { + st::Dtype::F6_E2M3 => (DType::F6E2M3, "F6_E2M3 (MX6)"), + st::Dtype::F6_E3M2 => (DType::F6E3M2, "F6_E3M2 (MX6)"), + st::Dtype::F4 => (DType::F4, "F4 (MX4)"), + st::Dtype::F8_E8M0 => (DType::F8E8M0, "F8_E8M0"), + _ => unreachable!("convert_dummy called with non-dummy dtype"), + }; + + // Load the raw bytes + let data = view.data(); + let shape = view.shape(); + + // Create storage with the appropriate dummy type variant + let storage = match device { + Device::Cpu => { + let cpu_storage = match dtype { + DType::F6E2M3 => crate::cpu_backend::CpuStorage::F6E2M3(data.to_vec()), + DType::F6E3M2 => crate::cpu_backend::CpuStorage::F6E3M2(data.to_vec()), + DType::F4 => crate::cpu_backend::CpuStorage::F4(data.to_vec()), + DType::F8E8M0 => crate::cpu_backend::CpuStorage::F8E8M0(data.to_vec()), + _ => unreachable!(), + }; + Storage::Cpu(cpu_storage) + } + #[cfg(feature = "cuda")] + Device::Cuda(device) => { + let mut slice = unsafe { device.alloc::(data.len())? 
}; + device.memcpy_htod(data, &mut slice)?; + + let slice = match dtype { + DType::F6E2M3 => crate::cuda_backend::CudaStorageSlice::F6E2M3(slice), + DType::F6E3M2 => crate::cuda_backend::CudaStorageSlice::F6E3M2(slice), + DType::F4 => crate::cuda_backend::CudaStorageSlice::F4(slice), + DType::F8E8M0 => crate::cuda_backend::CudaStorageSlice::F8E8M0(slice), + _ => unreachable!(), + }; + let storage = crate::cuda_backend::CudaStorage { + slice, + device: device.clone(), + }; + Storage::Cuda(storage) + } + #[cfg(not(feature = "cuda"))] + Device::Cuda(_) => { + return Err(Error::Msg("CUDA support not compiled".to_string())); + } + #[cfg(feature = "metal")] + Device::Metal(device) => { + let buffer = device.new_buffer_with_data(data)?; + + let storage = + crate::metal_backend::MetalStorage::new(buffer, device.clone(), data.len(), dtype); + Storage::Metal(storage) + } + #[cfg(not(feature = "metal"))] + Device::Metal(_) => { + return Err(Error::Msg("Metal support not compiled".to_string())); + } + }; + + // Create tensor with correct dtype + let op = BackpropOp::none(); + Ok(from_storage(storage, shape, op, false)) +} + fn convert_back(tensor: &Tensor) -> Result> { // TODO: This makes an unnecessary copy when the tensor is on the cpu. let tensor = tensor.flatten_all()?; match tensor.dtype() { DType::U8 => Ok(convert_back_::(tensor.to_vec1()?)), DType::U32 => Ok(convert_back_::(tensor.to_vec1()?)), + DType::I16 => Ok(convert_back_::(tensor.to_vec1()?)), + DType::I32 => Ok(convert_back_::(tensor.to_vec1()?)), DType::I64 => Ok(convert_back_::(tensor.to_vec1()?)), DType::F16 => Ok(convert_back_::(tensor.to_vec1()?)), DType::BF16 => Ok(convert_back_::(tensor.to_vec1()?)), DType::F32 => Ok(convert_back_::(tensor.to_vec1()?)), DType::F64 => Ok(convert_back_::(tensor.to_vec1()?)), + DType::F8E4M3 => Ok(convert_back_::(tensor.to_vec1()?)), + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + Err(Error::Msg("Internal error: dtype mismatch in storage".to_string()).bt()) + } } } @@ -259,7 +414,7 @@ pub fn save + Ord + std::fmt::Display, P: AsRef>( tensors: &HashMap, filename: P, ) -> Result<()> { - Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?) + Ok(st::serialize_to_file(tensors, None, filename.as_ref())?) } #[derive(yoke::Yokeable)] diff --git a/candle-core/src/scalar.rs b/candle-core/src/scalar.rs index b86d885fa0..5c512c03b9 100644 --- a/candle-core/src/scalar.rs +++ b/candle-core/src/scalar.rs @@ -1,17 +1,21 @@ //! TensorScalar Enum and Trait //! 
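// Minimal usage sketch (method names as shown by the hunks below): every
// `Scalar` variant knows its `DType` and can be widened to f64 for host-side
// arithmetic.
#[allow(dead_code)]
fn _scalar_example() {
    let s = Scalar::F32(2.5);
    assert_eq!(s.dtype(), crate::DType::F32);
    assert_eq!(s.to_f64(), 2.5);
}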
use crate::{DType, Result, Tensor, WithDType}; +use float8::F8E4M3 as f8e4m3; use half::{bf16, f16}; #[derive(Debug, Clone, Copy, PartialEq)] pub enum Scalar { U8(u8), U32(u32), + I16(i16), + I32(i32), I64(i64), BF16(bf16), F16(f16), F32(f32), F64(f64), + F8E4M3(f8e4m3), } impl From for Scalar { @@ -25,11 +29,17 @@ impl Scalar { match dtype { DType::U8 => Scalar::U8(0), DType::U32 => Scalar::U32(0), + DType::I16 => Scalar::I16(0), + DType::I32 => Scalar::I32(0), DType::I64 => Scalar::I64(0), DType::BF16 => Scalar::BF16(bf16::ZERO), DType::F16 => Scalar::F16(f16::ZERO), DType::F32 => Scalar::F32(0.0), DType::F64 => Scalar::F64(0.0), + DType::F8E4M3 => Scalar::F8E4M3(f8e4m3::ZERO), + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + panic!("Cannot create zero scalar for dummy type {dtype:?}") + } } } @@ -37,11 +47,17 @@ impl Scalar { match dtype { DType::U8 => Scalar::U8(1), DType::U32 => Scalar::U32(1), + DType::I16 => Scalar::I16(1), + DType::I32 => Scalar::I32(1), DType::I64 => Scalar::I64(1), DType::BF16 => Scalar::BF16(bf16::ONE), DType::F16 => Scalar::F16(f16::ONE), DType::F32 => Scalar::F32(1.0), DType::F64 => Scalar::F64(1.0), + DType::F8E4M3 => Scalar::F8E4M3(f8e4m3::ONE), + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + panic!("Cannot create one scalar for dummy type {dtype:?}") + } } } @@ -49,11 +65,14 @@ impl Scalar { match self { Scalar::U8(_) => DType::U8, Scalar::U32(_) => DType::U32, + Scalar::I16(_) => DType::I16, + Scalar::I32(_) => DType::I32, Scalar::I64(_) => DType::I64, Scalar::BF16(_) => DType::BF16, Scalar::F16(_) => DType::F16, Scalar::F32(_) => DType::F32, Scalar::F64(_) => DType::F64, + Scalar::F8E4M3(_) => DType::F8E4M3, } } @@ -61,11 +80,14 @@ impl Scalar { match self { Scalar::U8(v) => *v as f64, Scalar::U32(v) => *v as f64, + Scalar::I16(v) => *v as f64, + Scalar::I32(v) => *v as f64, Scalar::I64(v) => *v as f64, Scalar::BF16(v) => v.to_f64(), Scalar::F16(v) => v.to_f64(), Scalar::F32(v) => *v as f64, Scalar::F64(v) => *v, + Scalar::F8E4M3(v) => v.to_f64(), } } } diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs index af53661773..14ace645da 100644 --- a/candle-core/src/sort.rs +++ b/candle-core/src/sort.rs @@ -114,11 +114,33 @@ impl crate::CustomOp1 for ArgSort { let sort_indexes = match storage { crate::CpuStorage::U8(vs) => self.asort(vs, layout), crate::CpuStorage::U32(vs) => self.asort(vs, layout), + crate::CpuStorage::I16(vs) => self.asort(vs, layout), + crate::CpuStorage::I32(vs) => self.asort(vs, layout), crate::CpuStorage::I64(vs) => self.asort(vs, layout), crate::CpuStorage::BF16(vs) => self.asort(vs, layout), crate::CpuStorage::F16(vs) => self.asort(vs, layout), crate::CpuStorage::F32(vs) => self.asort(vs, layout), crate::CpuStorage::F64(vs) => self.asort(vs, layout), + crate::CpuStorage::F8E4M3(vs) => self.asort(vs, layout), + // Dummy types don't support sorting + crate::CpuStorage::F6E2M3(_) => { + return Err( + crate::Error::UnsupportedDTypeForOp(crate::DType::F6E2M3, "argsort").bt(), + ) + } + crate::CpuStorage::F6E3M2(_) => { + return Err( + crate::Error::UnsupportedDTypeForOp(crate::DType::F6E3M2, "argsort").bt(), + ) + } + crate::CpuStorage::F4(_) => { + return Err(crate::Error::UnsupportedDTypeForOp(crate::DType::F4, "argsort").bt()) + } + crate::CpuStorage::F8E8M0(_) => { + return Err( + crate::Error::UnsupportedDTypeForOp(crate::DType::F8E8M0, "argsort").bt(), + ) + } }; let sort_indexes = crate::CpuStorage::U32(sort_indexes); Ok((sort_indexes, layout.shape().into())) @@ -159,7 +181,15 @@ impl 
crate::CustomOp1 for ArgSort { DType::F64 => "asort_asc_f64", DType::U8 => "asort_asc_u8", DType::U32 => "asort_asc_u32", + DType::I16 => "asort_asc_i16", + DType::I32 => "asort_asc_i32", DType::I64 => "asort_asc_i64", + DType::F8E4M3 => "asort_asc_f8e4m3", + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + return Err( + crate::Error::UnsupportedDTypeForOp(storage.dtype(), "argsort").bt(), + ) + } } } else { match storage.dtype() { @@ -169,7 +199,15 @@ impl crate::CustomOp1 for ArgSort { DType::F64 => "asort_desc_f64", DType::U8 => "asort_desc_u8", DType::U32 => "asort_desc_u32", + DType::I16 => "asort_desc_i16", + DType::I32 => "asort_desc_i32", DType::I64 => "asort_desc_i64", + DType::F8E4M3 => "asort_desc_f8e4m3", + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + return Err( + crate::Error::UnsupportedDTypeForOp(storage.dtype(), "argsort").bt(), + ) + } } } }; diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 952374c2e6..fe641d3788 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -176,6 +176,22 @@ pub(crate) fn from_storage>( Tensor(Arc::new(tensor_)) } +/// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides. This has a BackpropOp:none(). +pub fn from_storage_no_op>(storage: Storage, shape: S, is_variable: bool) -> Tensor { + let dtype = storage.dtype(); + let device = storage.device(); + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: Arc::new(RwLock::new(storage)), + layout: Layout::contiguous(shape), + op: BackpropOp::none(), + is_variable, + dtype, + device, + }; + Tensor(Arc::new(tensor_)) +} + impl Tensor { pub(crate) fn ones_impl>( shape: S, @@ -270,6 +286,51 @@ impl Tensor { Tensor::zeros(self.shape(), self.dtype(), self.device()) } + // Do not expose outside of the crate, the `is_variable=true` case should only be accessed from + // the variable module. + pub(crate) unsafe fn empty_impl>( + shape: S, + dtype: DType, + device: &Device, + is_variable: bool, + ) -> Result { + let none = BackpropOp::none(); + let shape = shape.into(); + let storage = device.alloc_uninit(&shape, dtype)?; + Ok(from_storage(storage, shape, none, is_variable)) + } + + /// Creates a new tensor filled with uninitialized memory. + /// + /// # Safety + /// This returns uninitialized memory. + /// + /// ```rust + /// use candle_core::{Tensor, DType, Device}; + /// let a = unsafe { Tensor::empty((2, 3), DType::F32, &Device::Cpu)? }; + /// // a == b + /// # Ok::<(), candle_core::Error>(()) + /// ``` + pub unsafe fn empty>(shape: S, dtype: DType, device: &Device) -> Result { + Self::empty_impl(shape, dtype, device, false) + } + + /// Creates a new tensor filled with uninitialized memory of the same shape, dtype, and device as the other + /// tensor. + /// + /// # Safety + /// This returns uninitialized memory. + /// + /// ```rust + /// use candle_core::{Tensor, DType, Device}; + /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// let b = unsafe { a.empty_like()? }; + /// # Ok::<(), candle_core::Error>(()) + /// ``` + pub unsafe fn empty_like(&self) -> Result { + Tensor::empty(self.shape(), self.dtype(), self.device()) + } + pub(crate) fn rand_impl, T: crate::FloatDType>( lo: T, up: T, @@ -2754,6 +2815,49 @@ impl Tensor { } Ok(result) } + + /// Returns a view of which contains all slices of size `size` from self tensor in the dimension + /// `dim` and stepped by `step`. 
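    ///
    /// The step may be smaller than the window size, in which case consecutive
    /// slices overlap. A small 1-D example:
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::arange(0f32, 8f32, &Device::Cpu)?;
    /// // windows of length 3 taken every 2 elements: [0,1,2], [2,3,4], [4,5,6]
    /// let w = t.unfold(0, 3, 2)?;
    /// assert_eq!(w.dims(), &[3, 3]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```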
+ pub fn unfold(&self, dim: D, size: usize, step: usize) -> Result { + // https://github.com/pytorch/pytorch/blob/75b0720a97ac5d82e8a7a1a6ae7c5f7a87d7183d/aten/src/ATen/native/TensorShape.cpp#L3785-L3804 + let mut sizes = self.dims().to_vec(); + let mut strides = self.stride().to_vec(); + + let dim = dim.to_index(self.shape(), "unfold")?; + + let max_len = if self.dims().is_empty() { + 1 + } else { + sizes[dim] + }; + if size > max_len { + bail!( + "unsqueeze: maximum size for tensor at dimension {dim} is {max_len} but size is {size}" + ) + } + sizes.push(size); + strides.push(if self.dims().is_empty() { + 1 + } else { + strides[dim] + }); + + if !self.dims().is_empty() { + sizes[dim] = ((sizes[dim] as f32 - size as f32) / step as f32 + 1.) as usize; + strides[dim] *= step; + } + + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: self.storage.clone(), + layout: Layout::new(sizes.into(), strides, self.layout.start_offset()), + op: BackpropOp::new1(self, Op::Reshape), + is_variable: false, + dtype: self.dtype, + device: self.device.clone(), + }; + Ok(Tensor(Arc::new(tensor_))) + } } macro_rules! bin_trait { diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index 9aa15e9d50..7700ea2af1 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -378,12 +378,7 @@ fn compare_with_error(values: &[f32], expected: &[f32], tolerance: f32) { assert!( difference < tolerance, - "Error at index {}: value = {}, expected = {}. Difference = {} exceeds tolerance = {}.", - i, - value, - expected_value, - difference, - tolerance + "Error at index {i}: value = {value}, expected = {expected_value}. Difference = {difference} exceeds tolerance = {tolerance}." ); } } diff --git a/candle-examples/examples/clip/main.rs b/candle-examples/examples/clip/main.rs index 273edb6a0a..e38249ce41 100644 --- a/candle-examples/examples/clip/main.rs +++ b/candle-examples/examples/clip/main.rs @@ -95,7 +95,7 @@ pub fn main() -> anyhow::Result<()> { let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?; let softmax_image = softmax(&logits_per_image, 1)?; let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::()?; - println!("softmax_image_vec: {:?}", softmax_image_vec); + println!("softmax_image_vec: {softmax_image_vec:?}"); let probability_vec = softmax_image_vec .iter() .map(|v| v * 100.0) @@ -105,7 +105,7 @@ pub fn main() -> anyhow::Result<()> { let start = i * probability_per_image; let end = start + probability_per_image; let prob = &probability_vec[start..end]; - println!("\n\nResults for image: {}\n", img); + println!("\n\nResults for image: {img}\n"); for (i, p) in prob.iter().enumerate() { println!("Probability: {:.4}% Text: {} ", p, vec_seq[i]); } diff --git a/candle-examples/examples/codegeex4-9b/main.rs b/candle-examples/examples/codegeex4-9b/main.rs index 3848082f5f..dd854b0c05 100644 --- a/candle-examples/examples/codegeex4-9b/main.rs +++ b/candle-examples/examples/codegeex4-9b/main.rs @@ -69,7 +69,7 @@ impl TextGeneration { let start_gen = std::time::Instant::now(); println!("\n start_gen"); - println!("samplelen {}", sample_len); + println!("samplelen {sample_len}"); let mut count = 0; let mut result = vec![]; for index in 0..sample_len { @@ -101,10 +101,7 @@ impl TextGeneration { .decode(&[next_token], true) .expect("Token error"); if self.verbose { - println!( - "[Count: {}] [Raw Token: {}] [Decode Token: {}]", - count, next_token, token - ); + println!("[Count: {count}] [Raw Token: {next_token}] 
[Decode Token: {token}]"); } result.push(token); std::io::stdout().flush()?; diff --git a/candle-examples/examples/csm/main.rs b/candle-examples/examples/csm/main.rs index feadd6872c..3ace0fbbb1 100644 --- a/candle-examples/examples/csm/main.rs +++ b/candle-examples/examples/csm/main.rs @@ -207,7 +207,7 @@ fn main() -> Result<()> { for (turn_idx, prompt) in args.prompt.split('|').enumerate() { println!("{prompt:?}"); let speaker_idx = turn_idx % 2; - let prompt = format!("[{speaker_idx}]{}<|end_of_text|>", prompt); + let prompt = format!("[{speaker_idx}]{prompt}<|end_of_text|>"); let prompt = tokenizer.encode(prompt, true).map_err(E::msg)?; let (mut tokens, mut mask) = model.text_tokens_and_mask(prompt.get_ids())?; diff --git a/candle-examples/examples/debertav2/main.rs b/candle-examples/examples/debertav2/main.rs index 2f5f3ff2ca..61535d8f4e 100644 --- a/candle-examples/examples/debertav2/main.rs +++ b/candle-examples/examples/debertav2/main.rs @@ -320,7 +320,7 @@ fn main() -> Result<()> { results.push(current_row_result); } - println!("\n{:?}", results); + println!("\n{results:?}"); } TaskType::TextClassification(classification_model) => { @@ -344,7 +344,7 @@ fn main() -> Result<()> { }); } - println!("\n{:?}", results); + println!("\n{results:?}"); } } Ok(()) diff --git a/candle-examples/examples/distilbert/main.rs b/candle-examples/examples/distilbert/main.rs index 7f9df7cff3..06d29eb511 100644 --- a/candle-examples/examples/distilbert/main.rs +++ b/candle-examples/examples/distilbert/main.rs @@ -243,7 +243,7 @@ fn process_masked_output( for (token_idx, &token_id) in input_ids_vec[0].iter().enumerate() { if token_id == mask_token_id { - println!("Predictions for [MASK] at position {}:", token_idx); + println!("Predictions for [MASK] at position {token_idx}:"); let pos_logits = output.get(0)?.get(token_idx)?; let probs = candle_nn::ops::softmax(&pos_logits, 0)?; diff --git a/candle-examples/examples/efficientvit/main.rs b/candle-examples/examples/efficientvit/main.rs index efbf813c52..8d65968a6e 100644 --- a/candle-examples/examples/efficientvit/main.rs +++ b/candle-examples/examples/efficientvit/main.rs @@ -30,7 +30,7 @@ impl Which { Self::M4 => "m4", Self::M5 => "m5", }; - format!("timm/efficientvit_{}.r224_in1k", name) + format!("timm/efficientvit_{name}.r224_in1k") } fn config(&self) -> efficientvit::Config { diff --git a/candle-examples/examples/fastvit/main.rs b/candle-examples/examples/fastvit/main.rs index 520fd0aed3..a5c9d1c39d 100644 --- a/candle-examples/examples/fastvit/main.rs +++ b/candle-examples/examples/fastvit/main.rs @@ -32,7 +32,7 @@ impl Which { Self::SA36 => "sa36", Self::MA36 => "ma36", }; - format!("timm/fastvit_{}.apple_in1k", name) + format!("timm/fastvit_{name}.apple_in1k") } fn config(&self) -> fastvit::Config { diff --git a/candle-examples/examples/glm4/main.rs b/candle-examples/examples/glm4/main.rs index c4a300cf3a..3c547b59f5 100644 --- a/candle-examples/examples/glm4/main.rs +++ b/candle-examples/examples/glm4/main.rs @@ -89,8 +89,7 @@ impl TextGeneration { .expect("token decode error"); if args.verbose { println!( - "[Count: {}] [Raw Token: {}] [Decode Token: {}]", - generated_tokens, next_token, token + "[Count: {generated_tokens}] [Raw Token: {next_token}] [Decode Token: {token}]" ); } else { print!("{token}"); diff --git a/candle-examples/examples/hiera/main.rs b/candle-examples/examples/hiera/main.rs index 55bb1d54e1..06a95c2ad2 100644 --- a/candle-examples/examples/hiera/main.rs +++ b/candle-examples/examples/hiera/main.rs @@ -30,7 +30,7 @@ impl 
Which { Self::Large => "large", Self::Huge => "huge", }; - format!("timm/hiera_{}_224.mae_in1k_ft_in1k", name) + format!("timm/hiera_{name}_224.mae_in1k_ft_in1k") } fn config(&self) -> hiera::Config { diff --git a/candle-examples/examples/llava/main.rs b/candle-examples/examples/llava/main.rs index cb8093002f..b18ca4cb84 100644 --- a/candle-examples/examples/llava/main.rs +++ b/candle-examples/examples/llava/main.rs @@ -206,10 +206,8 @@ fn main() -> Result<()> { let llava: LLaVA = LLaVA::load(vb, &llava_config, clip_vision_config)?; println!("generating conv template"); - let image_token_se = format!( - "{}{}{}", - DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_END_TOKEN - ); + let image_token_se = + format!("{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}"); let qs = if args.prompt.contains(IMAGE_PLACEHOLDER) { if llava_config.mm_use_im_start_end { args.prompt.replace(IMAGE_PLACEHOLDER, &image_token_se) diff --git a/candle-examples/examples/mamba-minimal/main.rs b/candle-examples/examples/mamba-minimal/main.rs index 5e8968c039..2c8c53b300 100644 --- a/candle-examples/examples/mamba-minimal/main.rs +++ b/candle-examples/examples/mamba-minimal/main.rs @@ -123,7 +123,7 @@ enum Which { impl std::fmt::Display for Which { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) + write!(f, "{self:?}") } } diff --git a/candle-examples/examples/mamba/main.rs b/candle-examples/examples/mamba/main.rs index b8c8bb70f6..5caf2e9fad 100644 --- a/candle-examples/examples/mamba/main.rs +++ b/candle-examples/examples/mamba/main.rs @@ -135,7 +135,7 @@ enum Which { impl std::fmt::Display for Which { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) + write!(f, "{self:?}") } } diff --git a/candle-examples/examples/mobileclip/main.rs b/candle-examples/examples/mobileclip/main.rs index d9615c43b8..68d6bb32ab 100644 --- a/candle-examples/examples/mobileclip/main.rs +++ b/candle-examples/examples/mobileclip/main.rs @@ -25,7 +25,7 @@ impl Which { Self::S1 => "S1", Self::S2 => "S2", }; - format!("apple/MobileCLIP-{}-OpenCLIP", name) + format!("apple/MobileCLIP-{name}-OpenCLIP") } fn config(&self) -> mobileclip::MobileClipConfig { @@ -107,7 +107,7 @@ pub fn main() -> anyhow::Result<()> { let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?; let softmax_image = softmax(&logits_per_image, 1)?; let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::()?; - println!("softmax_image_vec: {:?}", softmax_image_vec); + println!("softmax_image_vec: {softmax_image_vec:?}"); let probability_vec = softmax_image_vec .iter() .map(|v| v * 100.0) @@ -118,7 +118,7 @@ pub fn main() -> anyhow::Result<()> { let start = i * probability_per_image; let end = start + probability_per_image; let prob = &probability_vec[start..end]; - println!("\n\nResults for image: {}\n", img); + println!("\n\nResults for image: {img}\n"); for (i, p) in prob.iter().enumerate() { println!("Probability: {:.4}% Text: {}", p, vec_seq[i]); diff --git a/candle-examples/examples/mobilenetv4/main.rs b/candle-examples/examples/mobilenetv4/main.rs index c31b91e6e4..b71b9ef61c 100644 --- a/candle-examples/examples/mobilenetv4/main.rs +++ b/candle-examples/examples/mobilenetv4/main.rs @@ -28,7 +28,7 @@ impl Which { Self::Large => "conv_large.e600_r384", Self::HybridLarge => "hybrid_large.ix_e600_r384", }; - format!("timm/mobilenetv4_{}_in1k", name) + format!("timm/mobilenetv4_{name}_in1k") } fn resolution(&self) -> u32 { diff 
--git a/candle-examples/examples/mobileone/main.rs b/candle-examples/examples/mobileone/main.rs index 76533fe3d5..7e0b0d448b 100644 --- a/candle-examples/examples/mobileone/main.rs +++ b/candle-examples/examples/mobileone/main.rs @@ -28,7 +28,7 @@ impl Which { Self::S3 => "s3", Self::S4 => "s4", }; - format!("timm/mobileone_{}.apple_in1k", name) + format!("timm/mobileone_{name}.apple_in1k") } fn config(&self) -> mobileone::Config { diff --git a/candle-examples/examples/moondream/main.rs b/candle-examples/examples/moondream/main.rs index 86ea83043e..e8e84a2e52 100644 --- a/candle-examples/examples/moondream/main.rs +++ b/candle-examples/examples/moondream/main.rs @@ -106,7 +106,7 @@ impl TextGeneration { } }; load_t = start_gen.elapsed(); - println!("load_t: {:?}", load_t); + println!("load_t: {load_t:?}"); logits }; let logits = logits.squeeze(0)?.to_dtype(DType::F32)?; diff --git a/candle-examples/examples/orpheus/main.rs b/candle-examples/examples/orpheus/main.rs index 706e08cab9..adf31c90d9 100644 --- a/candle-examples/examples/orpheus/main.rs +++ b/candle-examples/examples/orpheus/main.rs @@ -247,7 +247,7 @@ impl Model { } fn run(&mut self, prompt: &str) -> Result<()> { - println!("running the model on '{}'", prompt); + println!("running the model on '{prompt}'"); let device = &self.device; let prompt = format!("{voice}: {prompt}", voice = self.voice.as_str()); let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?; @@ -259,7 +259,7 @@ impl Model { ] .concat(); if self.verbose_prompt { - println!("{:?}", tokens); + println!("{tokens:?}"); } let mut cache = self.cache.clone(); diff --git a/candle-examples/examples/paligemma/main.rs b/candle-examples/examples/paligemma/main.rs index 9ce5011bc2..2412f17531 100644 --- a/candle-examples/examples/paligemma/main.rs +++ b/candle-examples/examples/paligemma/main.rs @@ -253,7 +253,7 @@ fn main() -> Result<()> { .to_device(&device)? .to_dtype(dtype)? .unsqueeze(0)?; - println!("loaded image with shape {:?}", image); + println!("loaded image with shape {image:?}"); let start = std::time::Instant::now(); let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; let model = Model::new(&config, vb)?; diff --git a/candle-examples/examples/pixtral/main.rs b/candle-examples/examples/pixtral/main.rs index 79f438686f..4697eefe26 100644 --- a/candle-examples/examples/pixtral/main.rs +++ b/candle-examples/examples/pixtral/main.rs @@ -295,7 +295,7 @@ fn main() -> Result<()> { )? }; let image = image.to_device(&device)?.unsqueeze(0)?; - println!("loaded image with shape {:?}", image); + println!("loaded image with shape {image:?}"); let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; if args.vision_only { diff --git a/candle-examples/examples/quantized-gemma/main.rs b/candle-examples/examples/quantized-gemma/main.rs index 48f4b1dc67..98ce7bd41e 100644 --- a/candle-examples/examples/quantized-gemma/main.rs +++ b/candle-examples/examples/quantized-gemma/main.rs @@ -92,12 +92,12 @@ impl Args { None => { let api = hf_hub::api::sync::Api::new()?; let repo = "google/gemma-3-4b-it"; - println!("DEBUG: Downloading tokenizer from {}", repo); + println!("DEBUG: Downloading tokenizer from {repo}"); let api = api.model(repo.to_string()); api.get("tokenizer.json")? 
} }; - println!("DEBUG: Loading tokenizer from {:?}", tokenizer_path); + println!("DEBUG: Loading tokenizer from {tokenizer_path:?}"); let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?; Ok(tokenizer) @@ -128,7 +128,7 @@ impl Args { fn format_size(size_in_bytes: usize) -> String { if size_in_bytes < 1_000 { - format!("{}B", size_in_bytes) + format!("{size_in_bytes}B") } else if size_in_bytes < 1_000_000 { format!("{:.2}KB", size_in_bytes as f64 / 1e3) } else if size_in_bytes < 1_000_000_000 { diff --git a/candle-examples/examples/quantized-phi/main.rs b/candle-examples/examples/quantized-phi/main.rs index a776e989e5..7ec13e4f80 100644 --- a/candle-examples/examples/quantized-phi/main.rs +++ b/candle-examples/examples/quantized-phi/main.rs @@ -148,7 +148,7 @@ impl Args { fn format_size(size_in_bytes: usize) -> String { if size_in_bytes < 1_000 { - format!("{}B", size_in_bytes) + format!("{size_in_bytes}B") } else if size_in_bytes < 1_000_000 { format!("{:.2}KB", size_in_bytes as f64 / 1e3) } else if size_in_bytes < 1_000_000_000 { diff --git a/candle-examples/examples/quantized-qwen2-instruct/main.rs b/candle-examples/examples/quantized-qwen2-instruct/main.rs index ff6ebe900b..a4dd5b0848 100644 --- a/candle-examples/examples/quantized-qwen2-instruct/main.rs +++ b/candle-examples/examples/quantized-qwen2-instruct/main.rs @@ -159,7 +159,7 @@ impl Args { fn format_size(size_in_bytes: usize) -> String { if size_in_bytes < 1_000 { - format!("{}B", size_in_bytes) + format!("{size_in_bytes}B") } else if size_in_bytes < 1_000_000 { format!("{:.2}KB", size_in_bytes as f64 / 1e3) } else if size_in_bytes < 1_000_000_000 { diff --git a/candle-examples/examples/quantized-qwen3/main.rs b/candle-examples/examples/quantized-qwen3/main.rs index b57466be85..b4b63beda0 100644 --- a/candle-examples/examples/quantized-qwen3/main.rs +++ b/candle-examples/examples/quantized-qwen3/main.rs @@ -143,7 +143,7 @@ impl Args { fn format_size(size_in_bytes: usize) -> String { if size_in_bytes < 1_000 { - format!("{}B", size_in_bytes) + format!("{size_in_bytes}B") } else if size_in_bytes < 1_000_000 { format!("{:.2}KB", size_in_bytes as f64 / 1e3) } else if size_in_bytes < 1_000_000_000 { diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs index abd4b38907..eb7e348a05 100644 --- a/candle-examples/examples/quantized/main.rs +++ b/candle-examples/examples/quantized/main.rs @@ -423,7 +423,7 @@ impl Args { fn format_size(size_in_bytes: usize) -> String { if size_in_bytes < 1_000 { - format!("{}B", size_in_bytes) + format!("{size_in_bytes}B") } else if size_in_bytes < 1_000_000 { format!("{:.2}KB", size_in_bytes as f64 / 1e3) } else if size_in_bytes < 1_000_000_000 { diff --git a/candle-examples/examples/repvgg/main.rs b/candle-examples/examples/repvgg/main.rs index 7cc90ba16b..5b3521243b 100644 --- a/candle-examples/examples/repvgg/main.rs +++ b/candle-examples/examples/repvgg/main.rs @@ -38,7 +38,7 @@ impl Which { Self::B2G4 => "b2g4", Self::B3G4 => "b3g4", }; - format!("timm/repvgg_{}.rvgg_in1k", name) + format!("timm/repvgg_{name}.rvgg_in1k") } fn config(&self) -> repvgg::Config { diff --git a/candle-examples/examples/rwkv/main.rs b/candle-examples/examples/rwkv/main.rs index 8fb2c0d41f..aa5a406cb0 100644 --- a/candle-examples/examples/rwkv/main.rs +++ b/candle-examples/examples/rwkv/main.rs @@ -134,7 +134,7 @@ enum Which { impl std::fmt::Display for Which { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", 
self) + write!(f, "{self:?}") } } diff --git a/candle-examples/examples/segformer/main.rs b/candle-examples/examples/segformer/main.rs index 16db62fc01..152f5b8d45 100644 --- a/candle-examples/examples/segformer/main.rs +++ b/candle-examples/examples/segformer/main.rs @@ -57,16 +57,16 @@ enum Commands { } fn get_vb_and_config(model_name: String, device: &Device) -> anyhow::Result<(VarBuilder, Config)> { - println!("loading model {} via huggingface hub", model_name); + println!("loading model {model_name} via huggingface hub"); let api = hf_hub::api::sync::Api::new()?; let api = api.model(model_name.clone()); let model_file = api.get("model.safetensors")?; - println!("model {} downloaded and loaded", model_name); + println!("model {model_name} downloaded and loaded"); let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], candle::DType::F32, device)? }; let config = std::fs::read_to_string(api.get("config.json")?)?; let config: Config = serde_json::from_str(&config)?; - println!("{:?}", config); + println!("{config:?}"); Ok((vb, config)) } @@ -138,7 +138,7 @@ fn classification_task(args: ClassificationArgs, device: &Device) -> anyhow::Res classification.to_vec1::()? ); let label_id = classification.argmax(0)?.to_scalar::()?; - let label_id = format!("{}", label_id); + let label_id = format!("{label_id}"); println!("label: {}", config.id2label[&label_id]); Ok(()) } diff --git a/candle-examples/examples/siglip/main.rs b/candle-examples/examples/siglip/main.rs index a78ed7f5d3..d20746717a 100644 --- a/candle-examples/examples/siglip/main.rs +++ b/candle-examples/examples/siglip/main.rs @@ -146,7 +146,7 @@ pub fn main() -> anyhow::Result<()> { let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?; let softmax_image = softmax(&logits_per_image, 1)?; let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::()?; - println!("softmax_image_vec: {:?}", softmax_image_vec); + println!("softmax_image_vec: {softmax_image_vec:?}"); let probability_vec = softmax_image_vec .iter() .map(|v| v * 100.0) @@ -156,7 +156,7 @@ pub fn main() -> anyhow::Result<()> { let start = i * probability_per_image; let end = start + probability_per_image; let prob = &probability_vec[start..end]; - println!("\n\nResults for image: {}\n", img); + println!("\n\nResults for image: {img}\n"); for (i, p) in prob.iter().enumerate() { println!("Probability: {:.4}% Text: {} ", p, vec_seq[i]); } diff --git a/candle-examples/examples/splade/main.rs b/candle-examples/examples/splade/main.rs index aa4c60ac41..738b624b7f 100644 --- a/candle-examples/examples/splade/main.rs +++ b/candle-examples/examples/splade/main.rs @@ -73,7 +73,7 @@ fn main() -> Result<()> { Err(_) => match repo.get("pytorch_model.bin") { Ok(pytorch_model) => pytorch_model, Err(e) => { - return Err(anyhow::Error::msg(format!("Model weights not found. The weights should either be a `model.safetensors` or `pytorch_model.bin` file. Error: {}", e))); + return Err(anyhow::Error::msg(format!("Model weights not found. The weights should either be a `model.safetensors` or `pytorch_model.bin` file. Error: {e}"))); } }, }, diff --git a/candle-examples/examples/trocr/main.rs b/candle-examples/examples/trocr/main.rs index f857295c78..63ee3c1bef 100644 --- a/candle-examples/examples/trocr/main.rs +++ b/candle-examples/examples/trocr/main.rs @@ -93,7 +93,7 @@ pub fn main() -> anyhow::Result<()> { .get("model.safetensors")? 
} }; - println!("model: {:?}", model); + println!("model: {model:?}"); unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? } }; diff --git a/candle-examples/examples/xlm-roberta/main.rs b/candle-examples/examples/xlm-roberta/main.rs index c1f759164e..8bf5af6b88 100644 --- a/candle-examples/examples/xlm-roberta/main.rs +++ b/candle-examples/examples/xlm-roberta/main.rs @@ -117,7 +117,7 @@ fn main() -> Result<()> { Err(_) => match repo.get("pytorch_model.bin") { Ok(pytorch_model) => pytorch_model, Err(e) => { - return Err(anyhow::Error::msg(format!("Model weights not found. The weights should either be a `model.safetensors` or `pytorch_model.bin` file. Error: {}", e))); + return Err(anyhow::Error::msg(format!("Model weights not found. The weights should either be a `model.safetensors` or `pytorch_model.bin` file. Error: {e}"))); } }, }, diff --git a/candle-flash-attn-v3/.gitignore b/candle-flash-attn-v3/.gitignore new file mode 100644 index 0000000000..fc378cabab --- /dev/null +++ b/candle-flash-attn-v3/.gitignore @@ -0,0 +1,7 @@ +.idea +target +Cargo.lock +.venv +hkernel/build/* +__pycache__ +*.egg-info \ No newline at end of file diff --git a/candle-flash-attn-v3/.gitmodules b/candle-flash-attn-v3/.gitmodules new file mode 100644 index 0000000000..2b822e9a55 --- /dev/null +++ b/candle-flash-attn-v3/.gitmodules @@ -0,0 +1,3 @@ +[submodule "cutlass"] + path = cutlass + url = https://github.com/NVIDIA/cutlass.git \ No newline at end of file diff --git a/candle-flash-attn-v3/Cargo.toml b/candle-flash-attn-v3/Cargo.toml new file mode 100644 index 0000000000..ce3de8e4c8 --- /dev/null +++ b/candle-flash-attn-v3/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "candle-flash-attn-v3" +version = "0.9.1" +edition = "2021" + +description = "Flash attention V3 layer for the candle ML framework." +keywords = ["blas", "tensor", "machine-learning"] +categories = ["science"] +license = "MIT OR Apache-2.0" +readme = "README.md" +authors = ["Michael Feil"] +repository = "https://github.com/michaelfeil/candle-flash-attn-v3" +exclude = ["cutlass/docs/**", "cutlass/test/**", "cutlass/examples/**", "cutlass/tools/**", "cutlass/media/**"] + +[dependencies] +candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.1" } +half = { version = "2.3.1", features = ["num-traits"] } + +[build-dependencies] +anyhow = { version = "1", features = ["backtrace"] } +num_cpus = "1.15.0" +rayon = "1.7.0" + +[dev-dependencies] +anyhow = { version = "1", features = ["backtrace"] } +candle-nn = { path = "../candle-nn", features = ["cuda"] } +rstest = "0.23" \ No newline at end of file diff --git a/candle-flash-attn-v3/LICENSE-APACHE b/candle-flash-attn-v3/LICENSE-APACHE new file mode 100644 index 0000000000..f49a4e16e6 --- /dev/null +++ b/candle-flash-attn-v3/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
\ No newline at end of file diff --git a/candle-flash-attn-v3/LICENSE-MIT b/candle-flash-attn-v3/LICENSE-MIT new file mode 100644 index 0000000000..468cd79a8f --- /dev/null +++ b/candle-flash-attn-v3/LICENSE-MIT @@ -0,0 +1,23 @@ +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/candle-flash-attn-v3/README.md b/candle-flash-attn-v3/README.md new file mode 100644 index 0000000000..4ffa16bf6f --- /dev/null +++ b/candle-flash-attn-v3/README.md @@ -0,0 +1,40 @@ +# Candle Flash Attention v3 Layer + +Flash Attention v3 Layer for Hopper (compatible nvidia `sm90a` arch) and the candle framework. + +Work supported by Baseten (https://github.com/basetenlabs) +If you are working on the intersection of CUDA / LLMs and Inference already, feel free to reach out, [we are hiring.](https://www.baseten.co/careers/) + +### Usage + +```rust +use baseten_candle_flash_attn_v3; +use anyhow::Result; +use candle::{DType, Device, IndexOp, Tensor, D}; + +fn flash_attn_acausal() -> Result<()> { + let device = Device::new_cuda(0)?; + let q = Tensor::arange(0u32, 3 * 2 * 64, &device)? + .to_dtype(DType::F16)? + .reshape((1, 3, 2, 64))?; // batch, head, seqlen, hidden_dim + let k = (&q / 400.)?; + let v = (&q / 500.)?; + let q = (&q / 300.)?; + + let att = { + let q = q.transpose(1, 2)?; + let k = k.transpose(1, 2)?; + let v = v.transpose(1, 2)?; + baseten_candle_flash_attn_v3::flash_attn(&q, &k, &v, 0.5, false, false)?.transpose(1, 2)? 
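+        // note: the transposes above feed flash_attn inputs laid out as
+        // (batch, seqlen, heads, head_dim), and the trailing transpose moves the
+        // result back to (batch, heads, seqlen, head_dim).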
+ }; +``` + +### Install instructions + +``` +[dependencies] +candle = { version = "*", package = "candle-core", default-features = false } +candle-nn = { version = "*" } +candle-transformers = { version = "*" } +baseten-candle-flash-attn-v3 = { git = "https://github.com/michaelfeil/candle-flash-attn-v3", rev = "main", optional = true } +```` \ No newline at end of file diff --git a/candle-flash-attn-v3/build.rs b/candle-flash-attn-v3/build.rs new file mode 100644 index 0000000000..d33f2937cf --- /dev/null +++ b/candle-flash-attn-v3/build.rs @@ -0,0 +1,344 @@ +// build.rs +use anyhow::{anyhow, Context, Result}; +use rayon::prelude::*; +use std::path::PathBuf; +use std::str::FromStr; + +const CUDA_NVCC_FLAGS: Option<&'static str> = option_env!("CUDA_NVCC_FLAGS"); + +const KERNEL_FILES: &[&str] = &[ + "flash_api.cu", + "flash_fwd_hdim64_fp16_sm90.cu", + "flash_fwd_hdim64_bf16_sm90.cu", + "flash_fwd_hdim128_fp16_sm90.cu", + "flash_fwd_hdim128_bf16_sm90.cu", + "flash_fwd_hdim256_fp16_sm90.cu", + "flash_fwd_hdim256_bf16_sm90.cu", + // "flash_bwd_hdim64_fp16_sm90.cu", + // "flash_bwd_hdim96_fp16_sm90.cu", + // "flash_bwd_hdim128_fp16_sm90.cu", + // commented out in main repo: // "flash_bwd_hdim256_fp16_sm90.cu", + // "flash_bwd_hdim64_bf16_sm90.cu", + // "flash_bwd_hdim96_bf16_sm90.cu", + // "flash_bwd_hdim128_bf16_sm90.cu", + // "flash_fwd_hdim64_e4m3_sm90.cu", + // "flash_fwd_hdim128_e4m3_sm90.cu", + // "flash_fwd_hdim256_e4m3_sm90.cu", + "flash_fwd_hdim64_fp16_gqa2_sm90.cu", + "flash_fwd_hdim64_fp16_gqa4_sm90.cu", + "flash_fwd_hdim64_fp16_gqa8_sm90.cu", + "flash_fwd_hdim64_fp16_gqa16_sm90.cu", + "flash_fwd_hdim64_fp16_gqa32_sm90.cu", + "flash_fwd_hdim128_fp16_gqa2_sm90.cu", + "flash_fwd_hdim128_fp16_gqa4_sm90.cu", + "flash_fwd_hdim128_fp16_gqa8_sm90.cu", + "flash_fwd_hdim128_fp16_gqa16_sm90.cu", + "flash_fwd_hdim128_fp16_gqa32_sm90.cu", + "flash_fwd_hdim256_fp16_gqa2_sm90.cu", + "flash_fwd_hdim256_fp16_gqa4_sm90.cu", + "flash_fwd_hdim256_fp16_gqa8_sm90.cu", + "flash_fwd_hdim256_fp16_gqa16_sm90.cu", + "flash_fwd_hdim256_fp16_gqa32_sm90.cu", + "flash_fwd_hdim64_bf16_gqa2_sm90.cu", + "flash_fwd_hdim64_bf16_gqa4_sm90.cu", + "flash_fwd_hdim64_bf16_gqa8_sm90.cu", + "flash_fwd_hdim64_bf16_gqa16_sm90.cu", + "flash_fwd_hdim64_bf16_gqa32_sm90.cu", + "flash_fwd_hdim128_bf16_gqa2_sm90.cu", + "flash_fwd_hdim128_bf16_gqa4_sm90.cu", + "flash_fwd_hdim128_bf16_gqa8_sm90.cu", + "flash_fwd_hdim128_bf16_gqa16_sm90.cu", + "flash_fwd_hdim128_bf16_gqa32_sm90.cu", + "flash_fwd_hdim256_bf16_gqa2_sm90.cu", + "flash_fwd_hdim256_bf16_gqa4_sm90.cu", + "flash_fwd_hdim256_bf16_gqa8_sm90.cu", + "flash_fwd_hdim256_bf16_gqa16_sm90.cu", + "flash_fwd_hdim256_bf16_gqa32_sm90.cu", + // "flash_fwd_hdim64_e4m3_gqa2_sm90.cu", + // "flash_fwd_hdim64_e4m3_gqa4_sm90.cu", + // "flash_fwd_hdim64_e4m3_gqa8_sm90.cu", + // "flash_fwd_hdim64_e4m3_gqa16_sm90.cu", + // "flash_fwd_hdim64_e4m3_gqa32_sm90.cu", + // "flash_fwd_hdim128_e4m3_gqa2_sm90.cu", + // "flash_fwd_hdim128_e4m3_gqa4_sm90.cu", + // "flash_fwd_hdim128_e4m3_gqa8_sm90.cu", + // "flash_fwd_hdim128_e4m3_gqa16_sm90.cu", + // "flash_fwd_hdim128_e4m3_gqa32_sm90.cu", + // "flash_fwd_hdim256_e4m3_gqa2_sm90.cu", + // "flash_fwd_hdim256_e4m3_gqa4_sm90.cu", + // "flash_fwd_hdim256_e4m3_gqa8_sm90.cu", + // "flash_fwd_hdim256_e4m3_gqa16_sm90.cu", + // "flash_fwd_hdim256_e4m3_gqa32_sm90.cu", +]; + +fn main() -> Result<()> { + // Use RAYON_NUM_THREADS or else default to the number of physical CPUs + let num_cpus = std::env::var("RAYON_NUM_THREADS").map_or_else( + |_| num_cpus::get_physical(), 
+ |s| usize::from_str(&s).unwrap_or_else(|_| num_cpus::get_physical()), + ); + // limit to 16 cpus to not use to much ram on large servers + let num_cpus = num_cpus.min(16); + + rayon::ThreadPoolBuilder::new() + .num_threads(num_cpus) + .build_global() + .unwrap(); + + // Telling Cargo that if any of these files changes, rebuild. + println!("cargo:rerun-if-changed=build.rs"); + println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP"); + println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN"); + + for file in KERNEL_FILES { + println!("cargo:rerun-if-changed=hkernel/{file}"); + } + println!("cargo:rerun-if-changed=kernels/**.h"); + println!("cargo:rerun-if-changed=kernels/**.hpp"); + println!("cargo:rerun-if-changed=kernels/**.cpp"); + + let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?); + // You can optionally allow an environment variable to cache the compiled artifacts. + // If not found, we compile into the standard OUT_DIR. + let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") { + Err(_) => out_dir.clone(), + Ok(build_dir) => { + let path = PathBuf::from(build_dir); + path.canonicalize().map_err(|_| { + anyhow!( + "Directory doesn't exist: {} (the current directory is {})", + path.display(), + std::env::current_dir().unwrap().display() + ) + })? + } + }; + + // Ensure we set CUDA_INCLUDE_DIR for our crates that might rely on it. + set_cuda_include_dir()?; + + // If set, pass along the custom compiler for NVCC + let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN").ok(); + + // Determine the GPU architecture we’re targeting, e.g. 90 for `sm_90`. + let compute_cap = compute_cap()?; + // assert compute cap is sm90 + assert!(compute_cap == 90, "Compute capability must be 90 (90a)"); + + // Our final library name + let out_file = build_dir.join("libflashattentionv3.a"); + + // Construct the list of (input_file -> output_object_file) + let kernel_dir = PathBuf::from("hkernel"); + let cu_files: Vec<(PathBuf, PathBuf)> = KERNEL_FILES + .iter() + .map(|f| { + let mut obj_file = out_dir.join(f); + obj_file.set_extension("o"); + (kernel_dir.join(f), obj_file) + }) + .collect(); + + // Decide whether to skip recompile if outputs are up to date. + // This is a simplistic approach, + // so feel free to refine if you need more robust up-to-date checks. 
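+    // Rebuild when the archive is missing, or when any kernel source was modified at or
+    // after the time the archive was last written.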
+ let out_modified = out_file + .metadata() + .and_then(|m| m.modified()) + .ok() + .unwrap_or_else(|| std::time::SystemTime::UNIX_EPOCH); + let should_compile = !out_file.exists() + || cu_files.iter().any(|(input, _)| { + let input_modified = input + .metadata() + .and_then(|m| m.modified()) + .unwrap_or(std::time::SystemTime::UNIX_EPOCH); + input_modified.duration_since(out_modified).is_ok() // True if input_modified >= out_modified + }); + + if should_compile { + // 1) Compile each .cu/.cpp -> .o + cu_files + .par_iter() + .try_for_each(|(input, obj)| -> Result<()> { + let mut command = std::process::Command::new("nvcc"); + + // Optimization and standard + command.arg("-O3"); + command.arg("-std=c++17"); + + // GPU architecture, hard code sm_90a instead of sm90 + command.arg(format!("--gpu-architecture={}", "sm_90a")); + + // Compile to object file + command.arg("-c"); + command.args(["-o", obj.to_str().unwrap()]); + + // Default stream per-thread + command.args(["--default-stream", "per-thread"]); + + // Include path + command.arg("-Icutlass/include"); + + // Undefine CUDA “no half/bfloat” macros + command.arg("-U__CUDA_NO_HALF_OPERATORS__"); + command.arg("-U__CUDA_NO_HALF_CONVERSIONS__"); + command.arg("-U__CUDA_NO_BFLOAT16_OPERATORS__"); + command.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__"); + command.arg("-U__CUDA_NO_BFLOAT162_OPERATORS__"); + command.arg("-U__CUDA_NO_BFLOAT162_CONVERSIONS__"); + + // Enable relaxed/extended lambda and fast math + command.arg("--expt-relaxed-constexpr"); + command.arg("--expt-extended-lambda"); + command.arg("--use_fast_math"); + + // PTXAS options: verbose output, register usage info, etc. + command.arg("--ptxas-options=-v"); + command.arg("--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage"); + + // Additional debug/performance flags + command.arg("-lineinfo"); + command.arg("-DCUTLASS_DEBUG_TRACE_LEVEL=0"); + command.arg("-DNDEBUG"); + + // https://github.com/EricLBuehler/mistral.rs/issues/941 + command.arg("-D_USE_MATH_DEFINES"); + + if let Some(ccbin_path) = &ccbin_env { + command.arg("-allow-unsupported-compiler"); + command.args(["-ccbin", ccbin_path]); + } + + // Add the source file + command.arg(input); + + // https://github.com/EricLBuehler/mistral.rs/issues/286 + if let Some(cuda_nvcc_flags_env) = CUDA_NVCC_FLAGS { + command.arg("--compiler-options"); + command.arg(cuda_nvcc_flags_env); + } + + let output = command + .spawn() + .with_context(|| format!("Failed to spawn nvcc for {input:?}"))? + .wait_with_output() + .with_context(|| format!("Failed during nvcc invocation for {input:?}"))?; + + if !output.status.success() { + return Err(anyhow!( + "nvcc error:\nCommand: {:?}\nstdout:\n{}\nstderr:\n{}", + command, + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + )); + } + + Ok(()) + })?; + + // 2) Create static library from the .o files + let obj_files = cu_files + .iter() + .map(|(_, obj)| obj.clone()) + .collect::>(); + + let mut command = std::process::Command::new("nvcc"); + command.arg("--lib"); + command.args(["-o", out_file.to_str().unwrap()]); + command.args(obj_files); + + let output = command + .spawn() + .context("Failed spawning nvcc to archive .o files")? 
+ .wait_with_output() + .context("Failed during nvcc archive step")?; + + if !output.status.success() { + return Err(anyhow!( + "nvcc error (archiving):\nCommand: {:?}\nstdout:\n{}\nstderr:\n{}", + command, + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + )); + } + } + + // Finally, instruct cargo to link your library + println!("cargo:rustc-link-search={}", build_dir.display()); + println!("cargo:rustc-link-lib=static=flashattentionv3"); + + // Link required system libs + println!("cargo:rustc-link-lib=dylib=cudart"); + println!("cargo:rustc-link-lib=dylib=stdc++"); + + Ok(()) +} + +/// This function attempts to find a CUDA toolkit root that contains `include/cuda.h`, +/// and prints that path as `CUDA_INCLUDE_DIR`. +fn set_cuda_include_dir() -> Result<()> { + // Adapted from cudarc build.rs + let env_vars = [ + "CUDA_PATH", + "CUDA_ROOT", + "CUDA_TOOLKIT_ROOT_DIR", + "CUDNN_LIB", + ]; + let env_vars = env_vars + .into_iter() + .filter_map(|v| std::env::var(v).ok()) + .map(Into::::into); + + let common_roots = [ + "/usr", + "/usr/local/cuda", + "/opt/cuda", + "/usr/lib/cuda", + "C:/Program Files/NVIDIA GPU Computing Toolkit", + "C:/CUDA", + ]; + let candidates = env_vars.chain(common_roots.into_iter().map(Into::into)); + + let root = candidates + .filter(|path| path.join("include").join("cuda.h").is_file()) + .next() + .ok_or_else(|| anyhow!("Cannot find a valid CUDA root with include/cuda.h"))?; + + println!( + "cargo:rustc-env=CUDA_INCLUDE_DIR={}", + root.join("include").display() + ); + Ok(()) +} + +/// Determine the compute capability we should target. +/// If the user sets `CUDA_COMPUTE_CAP` we trust that. +/// Otherwise, we attempt to parse it from `nvidia-smi`. +fn compute_cap() -> Result { + if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") { + let cc = compute_cap_str + .parse::() + .context("Failed to parse CUDA_COMPUTE_CAP")?; + Ok(cc) + } else { + // parse from nvidia-smi + let output = std::process::Command::new("nvidia-smi") + .args(["--query-gpu=compute_cap", "--format=csv"]) + .output() + .context("Failed to run nvidia-smi. Make sure it's in PATH.")?; + let stdout = String::from_utf8_lossy(&output.stdout); + let mut lines = stdout.lines(); + if lines.next().unwrap_or("") != "compute_cap" { + return Err(anyhow!("Unexpected output from nvidia-smi: {stdout}")); + } + if let Some(cap_line) = lines.next() { + // e.g. "9.0" -> "90" + let cc_str = cap_line.trim().replace('.', ""); + let cc = cc_str.parse::()?; + Ok(cc) + } else { + Err(anyhow!("nvidia-smi did not return a compute_cap line")) + } + } +} diff --git a/candle-flash-attn-v3/cutlass b/candle-flash-attn-v3/cutlass new file mode 160000 index 0000000000..4c42f73fda --- /dev/null +++ b/candle-flash-attn-v3/cutlass @@ -0,0 +1 @@ +Subproject commit 4c42f73fdab5787e3bb57717f35a8cb1b3c0dc6d diff --git a/candle-flash-attn-v3/hkernel/combine.h b/candle-flash-attn-v3/hkernel/combine.h new file mode 100644 index 0000000000..c26f7ea562 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/combine.h @@ -0,0 +1,248 @@ + +#pragma once + +#include + +#include +#include "cutlass/layout/layout.h" +#include +#include + +#include "kernel_traits.h" +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SharedStorageLSE { + cute::array_aligned> smem_lse; + cute::array_aligned> smem_valid_splits; +}; + +// DONT use Kernel_traits here to avoid redundant compilation. 
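+// This kernel recombines the per-split partial results of split-KV ("seq-k parallel")
+// attention: it merges the per-split LSE values and rescales/accumulates the per-split
+// partial outputs into the final O, with each thread block handling kBlockM rows.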
+// template +template +__global__ void combine_attn_seqk_parallel(Params const params) { + // using Element = typename Kernel_traits::OutputType; + // using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = int64_t; // Kernel_traits::index_t + constexpr int kMaxSplits = 1 << Log_max_splits; + // constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNThreads = 128; //Kernel_traits::kNThreads; + + static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); + static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32"); + static_assert(kNThreads == 128, "We assume that each block has 128 threads"); + + // Shared memory. + // kBlockM + 1 instead of kBlockM to reduce bank conflicts. + //__shared__ __align__(16) ElementAccum sLSE[kMaxSplits][kBlockM+1]; + extern __shared__ char smem_[]; + using SharedStorage = SharedStorageLSE, Int>, Shape>>; + SharedStorage &shared_storage = + *reinterpret_cast(smem_); + Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse.data()), Shape, Int>{}); + Tensor sValidSplits = make_tensor(make_smem_ptr(shared_storage.smem_valid_splits.data()), Shape>{}); + + // The thread and block index. + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + + const index_t lse_size = params.b * params.h * params.seqlen_q; + //if (cute::thread0()) print ("final %d %d %d %d\n", params.b, params.h, params.seqlen_q, params.b * params.h * params.seqlen_q); + + const index_t row_offset_lse = bidx * kBlockM; + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lse), + Shape, Int>{}, + make_stride(lse_size, _1{})); + + // LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile. + // This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}. + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + + // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}. + Layout flat_layout = make_layout(lse_size); + Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b)); + auto transposed_stride = params.seqlenq_ngroups_swapped ? make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q); + Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride); + Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout)); + + Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), final_layout); + + constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; + + // Read the LSE values from gmem and store them in shared memory, then transpose them. + constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; + #pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadLSE + tidx / kBlockM; + const int col = tidx % kBlockM; + ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY; + if (row < kMaxSplits) { sLSE(row,col) = lse; } + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); } + } + __syncthreads(); + + // Reduce along the kBlockM dimension to determine valid splits (store in SMEM) + // One thread per split. 
Know NumThreads = 128 >= NumMaxSplits + if (tidx < kMaxSplits) { + bool is_valid_split = false; + #pragma unroll + for (int col = 0; col < kBlockM; ++col) { + if(sLSE(tidx,col) != -INFINITY) { + is_valid_split = true; + } + } + sValidSplits(tidx) = is_valid_split; + } + __syncthreads(); + // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); } + + Tensor lse_accum = make_tensor(Shape>{}); + constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits); + // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits + // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads, + // kBlockM rows, so each time we load we can load 128 / kBlockM rows). + // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose; + // static_assert(kThreadsPerSplit <= 32); + static_assert(kRowsPerLoadTranspose <= 32); + static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits); + #pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + //if (bidx == 0 && tidx < 128) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } + lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE(row,col) : -INFINITY; + + } + //return; + + // Compute the logsumexp of the LSE along the split dimension. + ElementAccum lse_max = lse_accum(0); + #pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); } + MaxOp max_op; + lse_max = Allreduce::run(lse_max, max_op); + lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf + float lse_sum = expf(lse_accum(0) - lse_max); + #pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); } + SumOp sum_op; + lse_sum = Allreduce::run(lse_sum, sum_op); + // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise + // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. + ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max; + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } + if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { + if (params.unpadded_lse) { + const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose; + if (lse_offset < lse_size) { + gLSE_unpadded(lse_offset) = lse_logsum; + } + } else { + gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; + } + } + //if (cute::thread0()) printf ("lse_logsum = %f\n", lse_logsum); + + // Store the scales exp(lse - lse_logsum) in shared memory. 
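+    // i.e. the recombination below computes O = sum_l exp(lse_l - lse) * O_l with
+    // lse = lse_max + log(sum_l exp(lse_l - lse_max)), so each split's partial output
+    // is weighted by exp(lse_l - lse).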
+ #pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + if (row < params.num_splits && col < kBlockM) { sLSE(row,col) = expf(lse_accum(l) - lse_logsum); } + } + __syncthreads(); + + const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), + Shape, Int>{}, + Stride, _1>{}); + constexpr int kBlockN = kNThreads / kBlockM; + using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; + using GmemTiledCopyOaccum = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); + Tensor tOrO = make_tensor(shape(tOgOaccum)); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + clear(tOrO); + + // Predicates + Tensor cOaccum = make_identity_tensor(Shape, Int>{}); + //if (cute::thread0()) print_tensor (cOaccum); + // Repeat the partitioning with identity layouts + Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum); + Tensor tOpOaccum = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; } + } + // Load Oaccum in then scale and accumulate to O + for (int split = 0; split < params.num_splits; ++split) { + // DONT copy in Oaccum if lse(split) = -inf for all kBlockM. + if(sValidSplits(split)) { + flash::copy( + gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOrOaccum); ++m) { + int row = get<0>(tOcOaccum(0, m, 0)); + ElementAccum lse_scale = sLSE(split,row); + if (lse_scale != 0.f) { + #pragma unroll + for (int k = 0; k < size<2>(tOrOaccum); ++k) { + #pragma unroll + for (int i = 0; i < size<0>(tOrOaccum); ++i) { + tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k); + //tOrO(i, m, k) += tOrOaccum(i, m, k); + } + } + } + //if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE(split, 0), sLSE(split, 1)); print_tensor(tOrOaccum); } + } + } + tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded; + } + //if (cute::thread0()) { print_tensor(tOrO); } + + Tensor rO = flash::convert_type(tOrO); + // Write to gO + #pragma unroll + for (int m = 0; m < size<1>(rO); ++m) { + const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0)); + //if (cute::thread0()) print ("final %d %d %d %d %d\n", idx, params.b, params.h, params.seqlen_q, params.b * params.h * params.seqlen_q); + if (idx < params.b * params.h * params.seqlen_q) { + //print ("final2\n"); + const int batch_idx = idx / (params.h * params.seqlen_q); + const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q; + // The index to the rows of Q + const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q; + auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride + + head_idx * params.o_head_stride + row * params.o_row_stride; + #pragma unroll + for (int k = 0; k < size<2>(rO); ++k) { + if (Is_even_K || tOpOaccum(k)) { + const int col = get<1>(tOcOaccum(0, m, k)); + Tensor gO = 
make_tensor(make_gmem_ptr(o_ptr + col), + Shape(rO))::value>>{}, Stride<_1>{}); + // TODO: Should check if this is using vectorized store, but it seems pretty fast + copy(rO(_, m, k), gO); + //if (cute::thread0()) { print ("final\n"); print_tensor(gO); } + // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); } + // reinterpret_cast(o_ptr)[col / 4] = recast(rO)(0, m, k); + } + } + } + } +} + +} // namespace flash diff --git a/candle-flash-attn-v3/hkernel/copy_paged_sm90_tma.hpp b/candle-flash-attn-v3/hkernel/copy_paged_sm90_tma.hpp new file mode 100644 index 0000000000..218a7c3850 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/copy_paged_sm90_tma.hpp @@ -0,0 +1,8 @@ +#pragma once +#include + +#if CUTLASS_VERSION >= 360 +#include "copy_paged_sm90_tma_cutlass36.hpp" +#else +#include "copy_paged_sm90_tma_cutlass35.hpp" +#endif diff --git a/candle-flash-attn-v3/hkernel/copy_paged_sm90_tma_cutlass35.hpp b/candle-flash-attn-v3/hkernel/copy_paged_sm90_tma_cutlass35.hpp new file mode 100644 index 0000000000..6c467a2eb4 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/copy_paged_sm90_tma_cutlass35.hpp @@ -0,0 +1,402 @@ + +#pragma once + +#include +#include +#include + +static_assert(CUTLASS_VERSION < 360, "CUTLASS 3.5.x is required for this file due to incompatible API changes in Cutlass. Cutlass 3.5 does not have the cache_hint argument to SM90_TMA_LOAD ops."); + + +struct PagedCopyArgs { + + CUTE_HOST_DEVICE + PagedCopyArgs() : block_table_batch_stride{0}, page_block_size(0), block_table(nullptr) { + }; + + CUTE_HOST_DEVICE + PagedCopyArgs(int64_t const block_table_batch_stride_, int const page_block_size_, const int32_t *const block_table_) : block_table_batch_stride{block_table_batch_stride_}, page_block_size(page_block_size_), block_table(block_table_) { + }; + + const int64_t block_table_batch_stride; // The stride between block tables for different batches + const int page_block_size; // The size of a page block in number of elements + const int32_t *const block_table; // The block table, must be properly sized or a nullptr +}; + +namespace cute { + + struct SM90_TMA_LOAD_PAGED + { + using COPY_OP = SM90_TMA_LOAD; // The underlying copy operation that we delegate work to + + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& crd0) + { + CUTE_INVALID_CONTROL_PATH("PAGED_COPY_OP not implemented for 1D"); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + PagedCopyArgs const* pca, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + CUTE_INVALID_CONTROL_PATH("PAGED_COPY_OP not implemented for 2D"); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + PagedCopyArgs const* pca, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + // WARNING: Do not place anything else here, or a performance regression will occur + // look out for ptxas build warnings like "Potential Performance Loss: wgmma.mma_async instructions are serialized" + // asserts that pca==nullptr, but even an assert would kill performance + return SM90_TMA_LOAD_3D::copy(desc_ptr, mbar_ptr, smem_ptr, crd0, crd1, crd2); + } + + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + PagedCopyArgs const* pca, + void * smem_ptr, + // Index order reordered for TMA from 
PagedSeqLenTraits::get_kv_gmem_layout() + // via cute::make_tma_copy_atom ( see detail::construct_tma_gbasis ) + // and detail::make_tma_copy_desc to create a TMA descriptor. + // The same reordering is aplied prior to calling via cute::tma_partition. + + // Final order determined experimentally. + int32_t const& crdK, // embedding dim + int32_t const& crdM, // sequence dim + int32_t const& crdH, // head dim + int32_t const& crdB) // batch dim + { + //auto log = pca.debug_log->nextline(); + //log.append_threadinfo(); + //log.snprintf("SM_90_TMA_LOAD_PAGED::copy(%d, %d, %d, %d) ", (int)crdM, (int)crdK, (int)crdH, (int)crdB); + if (pca == nullptr) { + return SM90_TMA_LOAD_4D::copy(desc_ptr, mbar_ptr, smem_ptr, crdK, crdM, crdH, crdB); + } + auto const page_block_size = pca->page_block_size; + int32_t const page_idx_offset = crdM / page_block_size; // page index within the batch entry + int32_t const seq_pos_offset = crdM - page_idx_offset * page_block_size; // == crd1 % page_block_size_ -> sequence position within the page + int32_t const page_idx = pca->block_table[page_idx_offset + crdB*pca->block_table_batch_stride]; // The page index for the given batch and sequence position + //if (cute::thread0()) { + // printf("SM90_TMA_LOAD_PAGED::copy crdM=%d, crdB=%d, crdK=%d, crdH=%d, page_idx=%d, seq_pos_offset=%d, ptr=%p\n", (int)crdM, (int)crdB, (int) crdK, (int) crdH, (int)page_idx, (int)seq_pos_offset, (void*)desc_ptr); + //} + + return SM90_TMA_LOAD_4D::copy(desc_ptr, mbar_ptr, smem_ptr, crdK, seq_pos_offset, crdH, page_idx); + + } + + + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + CUTE_INVALID_CONTROL_PATH("PAGED_COPY_OP not implemented for 5D"); + } + + }; + +struct SM90_TMA_LOAD_MULTICAST_PAGED +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& crd0) + { + CUTE_INVALID_CONTROL_PATH("not implemented"); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + PagedCopyArgs const* pca, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + CUTE_INVALID_CONTROL_PATH("not implemented"); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + PagedCopyArgs const* pca, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + // WARNING: Do not place anything else here, or a performance regression will occur + // look out for ptxas build warnings like "Potential Performance Loss: wgmma.mma_async instructions are serialized" + // asserts that pca==nullptr, but even an assert would kill performance + return SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, mbar_ptr, multicast_mask, smem_ptr, crd0, crd1, crd2); + } + + + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + PagedCopyArgs const* pca, + void * smem_ptr, + // Index order reordered for TMA from PagedSeqLenTraits::get_kv_gmem_layout() + // via cute::make_tma_copy_atom ( see detail::construct_tma_gbasis ) + // and detail::make_tma_copy_desc to create a TMA descriptor. + // The same reordering is aplied prior to calling via cute::tma_partition. + + // Final order determined experimentally. 
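+        // Illustrative example (hypothetical numbers): with page_block_size = 256 and
+        // crdM = 700, the paged copies in this file resolve page_idx_offset = 700 / 256 = 2,
+        // seq_pos_offset = 700 - 2 * 256 = 188, and load from
+        // block_table[2 + crdB * block_table_batch_stride] at in-page row 188.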
+ int32_t const& crdK, // embedding dim + int32_t const& crdM, // sequence dim + int32_t const& crdH, // head dim + int32_t const& crdB) // batch dim + { + if (pca == nullptr) { + return SM90_TMA_LOAD_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, smem_ptr, crdK, crdM, crdH, crdB); + } + auto const page_block_size = pca->page_block_size; + int32_t const page_idx_offset = crdM / page_block_size; // page index within the batch entry + int32_t const seq_pos_offset = crdM - page_idx_offset*page_block_size; // == crd1 % page_block_size_ -> sequence position within the page + int32_t const page_idx = pca->block_table[page_idx_offset + crdB*pca->block_table_batch_stride]; // The page index for the given batch and sequence position + //if (cute::thread0()) { + // printf("SM90_TMA_LOAD_MULTICAST_PAGED::copy crdM=%d, crdB=%d, crdK=%d, crdH=%d, page_idx=%d, seq_pos_offset=%d, ptr=%p\n", (int)crdM, (int)crdB, (int) crdK, (int) crdH, (int)page_idx, (int)seq_pos_offset, (void*)desc_ptr); + //} + return SM90_TMA_LOAD_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, smem_ptr, crdK, seq_pos_offset, crdH, page_idx); + + } + +}; + + + +// We also need to specialize Copy_Traits for PAGED_COPY_OP, we can do this by inheriting from the traits of the underlying copy op + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_LOAD /////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_PAGED_OP : SM90_TMA_LOAD_PAGED {}; + +// The non-executable SM90_TMA_LOAD with tma_desc and no tma_mbar +// Use .with(tma_mbar) to construct an executable version +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Construct an executable SM90_TMA_LOAD with tma_mbar + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {{}, {&tma_desc_, &tma_mbar, nullptr }}; + } + + // Construct an executable SM90_TMA_LOAD with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(TmaDescriptor const* new_tma_desc, uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {{}, {new_tma_desc, &tma_mbar, nullptr }}; + } + + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask, PagedCopyArgs const & paged_copy_args ) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {{}, {&tma_desc_, &tma_mbar, (paged_copy_args.block_table==nullptr) ? nullptr : &paged_copy_args }}; + } + + // Construct an executable SM90_TMA_LOAD with tma_mbar (temp. 
overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(TmaDescriptor const* new_tma_desc, uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask,PagedCopyArgs const &paged_copy_args ) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {{}, {new_tma_desc, &tma_mbar, (paged_copy_args.block_table==nullptr) ? nullptr : &paged_copy_args }}; + } + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + // Don't try to execute a copy with SM90_TMA_LOAD before calling .with() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +// The executable SM90_TMA_LOAD with tma_desc and tma_mbar +template +struct Copy_Traits + : TMA_LOAD_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD arguments + tuple< + TmaDescriptor const*, + uint64_t*, // smem mbarrier + PagedCopyArgs const* + > const opargs_; +}; + + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_LOAD_MULTICAST ///////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_MULTICAST_PAGED_OP : SM90_TMA_LOAD_MULTICAST_PAGED {}; + +// The non-executable SM90_TMA_LOAD_MULTICAST with tma_desc and no tma_mbar +// Use .with(tma_mbar, multicast_mask) to construct an executable version +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD_MULTICAST arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Construct an executable SM90_TMA_LOAD_MULTICAST with tma_mbar + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& tma_load_mbar, uint16_t const& multicast_mask) const { + return {{}, {&tma_desc_, &tma_load_mbar, multicast_mask, nullptr }}; + } + + // Construct an executable SM90_TMA_LOAD_MULTICAST_OP with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(TmaDescriptor const* new_tma_desc, uint64_t& tma_load_mbar, uint16_t const& multicast_mask) const { + return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask, nullptr }}; + } + + // Construct an executable SM90_TMA_LOAD_MULTICAST with tma_mbar + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& tma_load_mbar, uint16_t const& multicast_mask, PagedCopyArgs const& paged_copy_args) const { + return {{}, {&tma_desc_, &tma_load_mbar, multicast_mask, (paged_copy_args.block_table==nullptr) ? nullptr : &paged_copy_args }}; + } + + // Construct an executable SM90_TMA_LOAD_MULTICAST_OP with tma_mbar (temp. 
overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(TmaDescriptor const* new_tma_desc, uint64_t& tma_load_mbar, uint16_t const& multicast_mask, PagedCopyArgs const& paged_copy_args) const { + return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask, (paged_copy_args.block_table==nullptr) ? nullptr : &paged_copy_args }}; + } + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + // Don't try to execute a copy with SM90_TMA_LOAD_MULTICAST before calling .with() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +// The executable SM90_TMA_LOAD_MULTICAST with tma_desc and tma_mbar and multicast_mask +template +struct Copy_Traits + : TMA_LOAD_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD_MULTICAST arguments + tuple< + TmaDescriptor const*, + uint64_t*, // smem mbarrier + uint16_t, // multicast mask + PagedCopyArgs const* + > const opargs_; +}; + + +template +CUTE_HOST_RTC +auto +make_virtualized_tma_copy(CopyOp const& copy_op, + Tensor const& gtensor, + VShape const &virtual_shape, + SLayout const slayout, + CTA_Tiler const& cta_tiler, + Cluster_Size const& cluster_size) +{ + /** + Variant of cute::make_tma_copy which allows to separate a virtual tensor coordinate space and + a physical TMA tensor coordinate space. Used for Paged Attention with TMA. + */ + auto cta_v_tile = make_identity_layout(virtual_shape).compose(cta_tiler); + auto cta_t_tile = make_layout(cluster_size); + //cute::print("\nVirtual Shape:"); cute::print(virtual_shape); + //cute::print("\nPhysical Shape:"); cute::print(gtensor.layout().shape()); cute::print("\n"); + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_tiled(copy_op, + gtensor, slayout, + cta_t_tile, cta_v_tile); + +} + +} diff --git a/candle-flash-attn-v3/hkernel/copy_paged_sm90_tma_cutlass36.hpp b/candle-flash-attn-v3/hkernel/copy_paged_sm90_tma_cutlass36.hpp new file mode 100644 index 0000000000..6d6717f932 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/copy_paged_sm90_tma_cutlass36.hpp @@ -0,0 +1,401 @@ + +#pragma once + +#include +#include +#include + +static_assert(CUTLASS_VERSION >= 360, "CUTLASS 3.6.x is required for this file due to incompatible API changes in Cutlass. 
Cutlass < 3.6 does not have the cache_hint argument to SM90_TMA_LOAD ops."); + +struct PagedCopyArgs { + + CUTE_HOST_DEVICE + PagedCopyArgs() : block_table_batch_stride{0}, page_block_size(0), block_table(nullptr) { + }; + + CUTE_HOST_DEVICE + PagedCopyArgs(int64_t const block_table_batch_stride_, int const page_block_size_, const int32_t *const block_table_) : block_table_batch_stride{block_table_batch_stride_}, page_block_size(page_block_size_), block_table(block_table_) { + }; + + const int64_t block_table_batch_stride; // The stride between block tables for different batches + const int page_block_size; // The size of a page block in number of elements + const int32_t *const block_table; // The block table, must be properly sized or a nullptr +}; + +namespace cute { + + struct SM90_TMA_LOAD_PAGED + { + using COPY_OP = SM90_TMA_LOAD; // The underlying copy operation that we delegate work to + + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& crd0) + { + CUTE_INVALID_CONTROL_PATH("PAGED_COPY_OP not implemented for 1D"); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + PagedCopyArgs const* pca, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + CUTE_INVALID_CONTROL_PATH("PAGED_COPY_OP not implemented for 2D"); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + PagedCopyArgs const* pca, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + // WARNING: Do not place anything else here, or a performance regression will occur + // look out for ptxas build warnings like "Potential Performance Loss: wgmma.mma_async instructions are serialized" + // asserts that pca==nullptr, but even an assert would kill performance + return SM90_TMA_LOAD_3D::copy(desc_ptr, mbar_ptr, static_cast(TMA::CacheHintSm90::EVICT_NORMAL), smem_ptr, crd0, crd1, crd2); + } + + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + PagedCopyArgs const* pca, + void * smem_ptr, + // Index order reordered for TMA from PagedSeqLenTraits::get_kv_gmem_layout() + // via cute::make_tma_copy_atom ( see detail::construct_tma_gbasis ) + // and detail::make_tma_copy_desc to create a TMA descriptor. + // The same reordering is aplied prior to calling via cute::tma_partition. + + // Final order determined experimentally. 
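+        // Note: unlike the CUTLASS 3.5 variant of this header, each delegated
+        // SM90_TMA_LOAD_*D::copy call below passes an explicit cache hint
+        // (TMA::CacheHintSm90::EVICT_NORMAL); this is the extra argument referred to
+        // by the static_assert at the top of the file.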
+ int32_t const& crdK, // embedding dim + int32_t const& crdM, // sequence dim + int32_t const& crdH, // head dim + int32_t const& crdB) // batch dim + { + //auto log = pca.debug_log->nextline(); + //log.append_threadinfo(); + //log.snprintf("SM_90_TMA_LOAD_PAGED::copy(%d, %d, %d, %d) ", (int)crdM, (int)crdK, (int)crdH, (int)crdB); + if (pca == nullptr) { + return SM90_TMA_LOAD_4D::copy(desc_ptr, mbar_ptr, static_cast(TMA::CacheHintSm90::EVICT_NORMAL), smem_ptr, crdK, crdM, crdH, crdB); + } + auto const page_block_size = pca->page_block_size; + int32_t const page_idx_offset = crdM / page_block_size; // page index within the batch entry + int32_t const seq_pos_offset = crdM - page_idx_offset * page_block_size; // == crd1 % page_block_size_ -> sequence position within the page + int32_t const page_idx = pca->block_table[page_idx_offset + crdB*pca->block_table_batch_stride]; // The page index for the given batch and sequence position + //if (cute::thread0()) { + // printf("SM90_TMA_LOAD_PAGED::copy crdM=%d, crdB=%d, crdK=%d, crdH=%d, page_idx=%d, seq_pos_offset=%d, ptr=%p\n", (int)crdM, (int)crdB, (int) crdK, (int) crdH, (int)page_idx, (int)seq_pos_offset, (void*)desc_ptr); + //} + + return SM90_TMA_LOAD_4D::copy(desc_ptr, mbar_ptr, static_cast(TMA::CacheHintSm90::EVICT_NORMAL), smem_ptr, crdK, seq_pos_offset, crdH, page_idx); + + } + + + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + CUTE_INVALID_CONTROL_PATH("PAGED_COPY_OP not implemented for 5D"); + } + + }; + +struct SM90_TMA_LOAD_MULTICAST_PAGED +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& crd0) + { + CUTE_INVALID_CONTROL_PATH("not implemented"); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + PagedCopyArgs const* pca, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + CUTE_INVALID_CONTROL_PATH("not implemented"); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + PagedCopyArgs const* pca, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + // WARNING: Do not place anything else here, or a performance regression will occur + // look out for ptxas build warnings like "Potential Performance Loss: wgmma.mma_async instructions are serialized" + // asserts that pca==nullptr, but even an assert would kill performance + return SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, mbar_ptr, multicast_mask, static_cast(TMA::CacheHintSm90::EVICT_NORMAL), smem_ptr, crd0, crd1, crd2); + } + + + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + PagedCopyArgs const* pca, + void * smem_ptr, + // Index order reordered for TMA from PagedSeqLenTraits::get_kv_gmem_layout() + // via cute::make_tma_copy_atom ( see detail::construct_tma_gbasis ) + // and detail::make_tma_copy_desc to create a TMA descriptor. + // The same reordering is aplied prior to calling via cute::tma_partition. + + // Final order determined experimentally. 
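+        // The multicast mask is forwarded unchanged; only the virtual coordinates
+        // (crdM, crdB) are remapped to (seq_pos_offset, page_idx) before delegating
+        // to SM90_TMA_LOAD_MULTICAST_4D below.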
+ int32_t const& crdK, // embedding dim + int32_t const& crdM, // sequence dim + int32_t const& crdH, // head dim + int32_t const& crdB) // batch dim + { + if (pca == nullptr) { + return SM90_TMA_LOAD_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, static_cast(TMA::CacheHintSm90::EVICT_NORMAL), smem_ptr, crdK, crdM, crdH, crdB); + } + auto const page_block_size = pca->page_block_size; + int32_t const page_idx_offset = crdM / page_block_size; // page index within the batch entry + int32_t const seq_pos_offset = crdM - page_idx_offset*page_block_size; // == crd1 % page_block_size_ -> sequence position within the page + int32_t const page_idx = pca->block_table[page_idx_offset + crdB*pca->block_table_batch_stride]; // The page index for the given batch and sequence position + //if (cute::thread0()) { + // printf("SM90_TMA_LOAD_MULTICAST_PAGED::copy crdM=%d, crdB=%d, crdK=%d, crdH=%d, page_idx=%d, seq_pos_offset=%d, ptr=%p\n", (int)crdM, (int)crdB, (int) crdK, (int) crdH, (int)page_idx, (int)seq_pos_offset, (void*)desc_ptr); + //} + return SM90_TMA_LOAD_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, static_cast(TMA::CacheHintSm90::EVICT_NORMAL), smem_ptr, crdK, seq_pos_offset, crdH, page_idx); + + } + +}; + + + +// We also need to specialize Copy_Traits for PAGED_COPY_OP, we can do this by inheriting from the traits of the underlying copy op + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_LOAD /////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_PAGED_OP : SM90_TMA_LOAD_PAGED {}; + +// The non-executable SM90_TMA_LOAD with tma_desc and no tma_mbar +// Use .with(tma_mbar) to construct an executable version +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Construct an executable SM90_TMA_LOAD with tma_mbar + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {{}, {&tma_desc_, &tma_mbar, nullptr}}; + } + + // Construct an executable SM90_TMA_LOAD with tma_mbar (temp. 
overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(TmaDescriptor const* new_tma_desc, uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {{}, {new_tma_desc, &tma_mbar, nullptr }}; + } + + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask, PagedCopyArgs const & paged_copy_args, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {{}, {&tma_desc_, &tma_mbar, (paged_copy_args.block_table==nullptr) ? nullptr : &paged_copy_args }}; + } + + // Construct an executable SM90_TMA_LOAD with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(TmaDescriptor const* new_tma_desc, uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask, PagedCopyArgs const & paged_copy_args, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {{}, {new_tma_desc, &tma_mbar, (paged_copy_args.block_table==nullptr) ? nullptr : &paged_copy_args }}; + } + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + // Don't try to execute a copy with SM90_TMA_LOAD before calling .with() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +// The executable SM90_TMA_LOAD with tma_desc and tma_mbar +template +struct Copy_Traits + : TMA_LOAD_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD arguments + tuple< + TmaDescriptor const*, + uint64_t*, // smem mbarrier + PagedCopyArgs const* + > const opargs_; +}; + + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_LOAD_MULTICAST ///////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_MULTICAST_PAGED_OP : SM90_TMA_LOAD_MULTICAST_PAGED {}; + +// The non-executable SM90_TMA_LOAD_MULTICAST with tma_desc and no tma_mbar +// Use .with(tma_mbar, multicast_mask) to construct an executable version +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD_MULTICAST arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Construct an executable SM90_TMA_LOAD_MULTICAST with tma_mbar + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& tma_load_mbar, 
uint16_t const& multicast_mask, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + return {{}, {&tma_desc_, &tma_load_mbar, multicast_mask, nullptr }}; + } + + // Construct an executable SM90_TMA_LOAD_MULTICAST_OP with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(TmaDescriptor const* new_tma_desc, uint64_t& tma_load_mbar, uint16_t const& multicast_mask, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask, nullptr }}; + } + + // Construct an executable SM90_TMA_LOAD_MULTICAST with tma_mbar + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& tma_load_mbar, uint16_t const& multicast_mask, PagedCopyArgs const & paged_copy_args, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + return {{}, {&tma_desc_, &tma_load_mbar, multicast_mask, (paged_copy_args.block_table==nullptr) ? nullptr : &paged_copy_args }}; + } + + // Construct an executable SM90_TMA_LOAD_MULTICAST_OP with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(TmaDescriptor const* new_tma_desc, uint64_t& tma_load_mbar, uint16_t const& multicast_mask, PagedCopyArgs const& paged_copy_args, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask, (paged_copy_args.block_table==nullptr) ? nullptr : &paged_copy_args }}; + } + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + // Don't try to execute a copy with SM90_TMA_LOAD_MULTICAST before calling .with() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +// The executable SM90_TMA_LOAD_MULTICAST with tma_desc and tma_mbar and multicast_mask +template +struct Copy_Traits + : TMA_LOAD_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD_MULTICAST arguments + tuple< + TmaDescriptor const*, + uint64_t*, // smem mbarrier + uint16_t, // multicast mask + PagedCopyArgs const* + > const opargs_; +}; + + +template +CUTE_HOST_RTC +auto +make_virtualized_tma_copy(CopyOp const& copy_op, + Tensor const& gtensor, + VShape const &virtual_shape, + SLayout const slayout, + CTA_Tiler const& cta_tiler, + Cluster_Size const& cluster_size) +{ + /** + Variant of cute::make_tma_copy which allows to separate a virtual tensor coordinate space and + a physical TMA tensor coordinate space. Used for Paged Attention with TMA. + */ + auto cta_v_tile = make_identity_layout(virtual_shape).compose(cta_tiler); + auto cta_t_tile = make_layout(cluster_size); + //cute::print("\nVirtual Shape:"); cute::print(virtual_shape); + //cute::print("\nPhysical Shape:"); cute::print(gtensor.layout().shape()); cute::print("\n"); + // Prefer TmaInternalType if specified. 
Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_tiled(copy_op, + gtensor, slayout, + cta_t_tile, cta_v_tile); + +} + +} diff --git a/candle-flash-attn-v3/hkernel/epilogue_fwd_sm90_tma.hpp b/candle-flash-attn-v3/hkernel/epilogue_fwd_sm90_tma.hpp new file mode 100644 index 0000000000..26664c1041 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/epilogue_fwd_sm90_tma.hpp @@ -0,0 +1,417 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include "cute/tensor.hpp" + +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "named_barrier.hpp" +#include "utils.h" + +namespace flash { + +using namespace cute; + +// template +template +struct CollectiveEpilogueFwd { + + using InputType = typename Ktraits::Element; + using Element = typename Ktraits::OutputType; + static constexpr int kBlockM = Ktraits::kBlockM; + static constexpr int kBlockN = Ktraits::kBlockN; + static constexpr int kBlockH = Ktraits::kBlockH; + static constexpr int kHeadDim = Ktraits::kHeadDim; + using TileShape_MNK = Shape, Int, Int>; + + static constexpr int kNWarps = Ktraits::kNWarps; + static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp; + static constexpr bool Is_WS = Ktraits::Is_WS; + + static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup; + static constexpr int NumMmaThreads = kNThreads - NumCopyThreads; + + static constexpr bool Is_split = Ktraits::Is_split; + static constexpr bool No_smem_O = Ktraits::No_smem_O; + +#ifndef NO_FP8_COLUMN_PERMUTE + static constexpr bool epi_column_permute = is_same_v; +#else + static constexpr bool epi_column_permute = false; +#endif + + using GmemShapeOT = std::conditional_t< + Is_split, + typename Seqlen_traits::ShapeOAccumT, + typename Seqlen_traits::ShapeT + >; + using GmemStrideOT = std::conditional_t< + Is_split, + typename Seqlen_traits::StrideOAccumT, + typename Seqlen_traits::StrideT + >; + using GmemLayoutOT = std::conditional_t< + Is_split, + typename Seqlen_traits::LayoutOAccumT, + typename Seqlen_traits::LayoutT + >; + + using GmemLayoutLseT = std::conditional_t< + Is_split, + typename Seqlen_traits::LayoutLseAccumT, + typename Seqlen_traits::LayoutLseT + >; + + using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{}))); + using SmemLayoutOCopy = typename Ktraits::SmemLayoutOCopy; + using TileShapeOCopy = typename Ktraits::TileShapeOCopy; + + using SmemCopyAtomO = std::conditional_t, Element>, Copy_Atom>; + using SharedStorage = cute::array_aligned>; + + using GmemTiledCopyOTMA = cute::SM90_TMA_STORE; + using TMA_O = decltype(make_tma_copy( + GmemTiledCopyOTMA{}, + make_tensor( + make_gmem_ptr(static_cast(nullptr)), + GmemShapeOT{}, + GmemStrideOT{} + ), + SmemLayoutOCopy{}, + TileShapeOCopy{}, + _1{})); // no mcast for O + + // These are for storing the output tensor without TMA (e.g., for setting output to zero and var-seq-len) + static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v); + static_assert(kHeadDim % kNumVecElem == 0); + static constexpr int kNumThreadsPerRow = 
kHeadDim / kNumVecElem; + static_assert(NumMmaThreads % kNumThreadsPerRow == 0); + static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow; + using TiledCopyOAtom = cute::Copy_Atom, Element>; + using TiledCopyOThrLayout = decltype(cute::make_layout( + cute::make_shape(Int{}, Int{}), + LayoutRight{})); + using TiledCopyOValLayout = decltype(cute::make_layout( + cute::make_shape(_1{}, Int{}), + LayoutRight{})); + using TiledCopyO = decltype(make_tiled_copy( + TiledCopyOAtom{}, + TiledCopyOThrLayout{}, // Thr layout + TiledCopyOValLayout{} // Val layout + )); + + // used for rmem -> smem O copy in fp8 kernel to undo column permutation + using ThreadLayoutrO = Layout, _4, _1>, + Stride<_4, _32, _1, _0>>; + using ValueLayoutrO = Layout, Int>, + Stride<_0, _2, Stride<_4, _1>, _8>>; + using TiledCopyrO = decltype(make_tiled_copy(Copy_Atom, Element>{}, + ThreadLayoutrO{}, ValueLayoutrO{})); + using TiledCopyShaperO = Shape<_8, Int, _16, Int>; + using SmemLayoutrO = decltype(composition(SmemLayoutO{}, Layout{})); + + // Host side kernel arguments + struct Arguments { + Element* ptr_O; + GmemLayoutOT const layout_O; + float* ptr_LSE; + GmemLayoutLseT const layout_LSE; + }; + + // Device side kernel params + struct Params { + Element* ptr_O; + GmemLayoutOT const layout_O; + float* ptr_LSE; + GmemLayoutLseT const layout_LSE; + TMA_O tma_store_O; + }; + + static Params + to_underlying_arguments(Arguments const& args) { + Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.layout_O); + TMA_O tma_store_O = make_tma_copy( + GmemTiledCopyOTMA{}, + mO, + SmemLayoutOCopy{}, + TileShapeOCopy{}, + _1{}); // no mcast for O + return {args.ptr_O, args.layout_O, args.ptr_LSE, args.layout_LSE, tma_store_O}; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& epilogue_params) { + if constexpr (!Seqlen_traits::UseVarSeqLen && !No_smem_O) { + cute::prefetch_tma_descriptor(epilogue_params.tma_store_O.get_tma_descriptor()); + } + } + + template + CUTLASS_DEVICE void + store(Params const& epilogue_params, + FrgTensorO const& tOrO, + FrgTensorLSE const& lse, + SharedStorage& shared_storage, + TiledMma tiled_mma, + int thread_idx, + cute::tuple const& block_coord, + const Seqlen_traits& seqlen_traits_q, + const cutlass::FastDivmod& qhead_per_khead_divmod + ) { + + auto [m_block, n_split_idx, bidh, bidb] = block_coord; + const int bidh_kv = qhead_per_khead_divmod.divide(bidh); + const int h_block = bidh % int(qhead_per_khead_divmod); + + Tensor tOrO_out = flash::convert_type(tOrO); + if constexpr(!No_smem_O) { + if constexpr (!epi_column_permute) { + Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{}); + auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); + + Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // Make sure all WGs have finished reading V + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::ValueEmpty) /*id*/); + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } else { + TiledCopyrO 
rmem_tiled_copy_O; + Tensor sOacc = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutrO{}); + auto rmem_thr_copy_O = rmem_tiled_copy_O.get_thread_slice(thread_idx); + + Tensor taccOsO = rmem_thr_copy_O.partition_D(sOacc); + Tensor taccOrO = make_tensor(tOrO_out.data(), shape(taccOsO)); + + // Make sure all WGs have finished reading V + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::ValueEmpty) /*id*/); + cute::copy(rmem_tiled_copy_O, taccOrO, taccOsO); + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } + } + + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE); + Tensor caccO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})); + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0, 0>(taccOcO))::value == 2); + static_assert(decltype(size<0, 1>(taccOcO))::value == 2); + // taccOcO has shape ((2, 2, V), MMA_M, MMA_K), we only take only the row indices. + Tensor taccOcO_row = taccOcO(make_coord(_0{}, _, _0{}), _, _0{}); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // 2 * MMA_M + + if constexpr(!Seqlen_traits::UseGQAPacking) { + Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor( + mLSE, Shape>{}, bidh, bidb, n_split_idx)(_, m_block); + if (get<1>(taccOcO_row(_0{})) == 0) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < seqlen_traits_q.actual_seq_len - m_block * kBlockM) { + gLSE(row) = lse(mi); + } + } + } + } else { + // shape<1>(epilogue_params.layout_O) == h/h_k + // In common case where ceil_div(h/h_k, kBlockH) == 1, + // int(qhead_per_khead_divmod) == 1, bidh_kv == bidh, h_block == 0 + const int h_offset = shape<1>(epilogue_params.layout_O) * bidh_kv + + h_block * kBlockH; + const int m_bound = seqlen_traits_q.actual_seq_len - m_block * (kBlockM/kBlockH); + const int h_bound = shape<1>(epilogue_params.layout_O) - h_block * kBlockH; + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + const int h_local = row % kBlockH; + const int m_local = row/kBlockH; + if(h_local < h_bound && m_local < m_bound) { + Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor(mLSE, + Shape>{}, h_offset + h_local, bidb, n_split_idx) + (_, m_block); + gLSE(m_local) = lse(mi); + } + } + } + + if constexpr (No_smem_O) { + flash::write_rmem_to_gmem( + tOrO_out, epilogue_params.ptr_O, epilogue_params.layout_O, TileShapeOCopy{}, + m_block, h_block, bidh, bidh_kv, bidb, n_split_idx, + tiled_mma, seqlen_traits_q, thread_idx); + } else { + int write_warp_idx = kNWarps - 1; + if (cutlass::canonical_warp_idx_sync() == write_warp_idx) { + cutlass::arch::NamedBarrier::sync( + NumMmaThreads + cutlass::NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier + ); + } + TiledCopyO gmem_tiled_copy_O; + Tensor sO_out = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutOCopy{}); + if constexpr(!Seqlen_traits::UseGQAPacking) { + flash::write_O( + epilogue_params.ptr_O, epilogue_params.tma_store_O, gmem_tiled_copy_O, + epilogue_params.layout_O, TileShapeOCopy{}, sO_out, + m_block, bidh, bidb, n_split_idx, seqlen_traits_q, write_warp_idx, tiled_mma, tOrO_out + ); + } else { + Tensor 
mO = epilogue_params.tma_store_O.get_tma_tensor(epilogue_params.layout_O.shape()); + Tensor gO = seqlen_traits_q.get_o_local_tile_tensor( + mO, TileShapeOCopy{}, bidh_kv, bidb, n_split_idx) + (_, _, _, m_block, h_block); // (bM/bH, bH, K) + auto block_tma_O = epilogue_params.tma_store_O.get_slice(_0{}); + Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K) + Tensor tOsO = block_tma_O.partition_S(sO_out); // (TMA, TMA_M, TMA_K) + int const lane_predicate = cute::elect_one_sync(); + int const warp_idx = cutlass::canonical_warp_idx_sync(); + if (warp_idx == write_warp_idx && lane_predicate) { + cute::copy(epilogue_params.tma_store_O, tOsO, tOgO); + tma_store_arrive(); + } + } + } + } + + CUTLASS_DEVICE void + store_tail() { + if constexpr(!No_smem_O) { tma_store_wait<0>(); } + } + + // Write 0 to output and -inf to LSE + template + CUTLASS_DEVICE void + store_zero( + Params const& epilogue_params, + SharedStorage& shared_storage, + int thread_idx, + cute::tuple const& block_coord, + const Seqlen_traits& seqlen_traits_q + ) { + static_assert(!Seqlen_traits::UseGQAPacking, "Don't call store_zero for gqa packed layouts."); + auto [m_block, n_split_idx, bidh, bidb] = block_coord; + + if constexpr(!Is_split) { + Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.ptr_O), epilogue_params.layout_O); + Tensor gO = seqlen_traits_q.get_o_local_tile_tensor( + mO, select<0, 2>(TileShape_MNK{}), bidh, bidb, n_split_idx + )(_, _, m_block); // (M, K) + + TiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_fragment_like(tOgO); + clear(tOrO); + // Construct identity layout for sO + Tensor cO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.layout_O.shape()); } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_traits_q.actual_seq_len - m_block * kBlockM + ); + } + + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE); + Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor( + mLSE, Shape>{}, bidh, bidb, n_split_idx)(_, m_block); + static_assert(kBlockM <= NumMmaThreads); + if (thread_idx < min(kBlockM, seqlen_traits_q.actual_seq_len - m_block * kBlockM)) { + gLSE(thread_idx) = !Is_split ? 
INFINITY : -INFINITY; + } + } + + // Write 0 to output and -inf to LSE + template + CUTLASS_DEVICE void + store_zero_gqa( + Params const& epilogue_params, + SharedStorage& shared_storage, + int thread_idx, + cute::tuple const& block_coord, + const Seqlen_traits& seqlen_traits_q, + const cutlass::FastDivmod& qhead_per_khead_divmod + ) { + static_assert(Seqlen_traits::UseGQAPacking, "Special store_zero method for GQA packed layouts."); + auto [m_block, n_split_idx, bidh, bidb] = block_coord; + const int bidh_kv = qhead_per_khead_divmod.divide(bidh); + const int h_block = bidh % int(qhead_per_khead_divmod); + const int h_bound = min(shape<1>(epilogue_params.layout_O) - h_block * kBlockH, kBlockH); + const int m_bound = min(seqlen_traits_q.actual_seq_len - m_block * (kBlockM/kBlockH), kBlockM/kBlockH); + + if constexpr(!Is_split) { + Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.ptr_O), epilogue_params.layout_O); + Tensor gO = seqlen_traits_q.get_o_local_tile_tensor( + mO, TileShapeOCopy{}, bidh_kv, bidb, n_split_idx) + (_, _, _, m_block, h_block); // (bM/bH, bH, K) + TiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + if constexpr(kNumRows <= kBlockH) { + // slice into bM/bH and write out zero tiles (bH, K) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO(0,_,_)); + Tensor tOrO = make_fragment_like(tOgO); + clear(tOrO); + Tensor cO = cute::make_identity_tensor(select<1, 2>(TileShapeOCopy{})); + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + // dummy predicate, unused since Is_even_K=true + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + #pragma unroll + for(int m = 0; m < m_bound; ++m) { + tOgO = gmem_thr_copy_O.partition_D(gO(m,_,_)); + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, h_bound + ); + } + } else { + // slice into bH and write out zero tiles (bM/bH, K) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO(_,0,_)); + Tensor tOrO = make_fragment_like(tOgO); + clear(tOrO); + Tensor cO = cute::make_identity_tensor(select<0, 2>(TileShapeOCopy{})); + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + // dummy predicate, unused since Is_even_K=true + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + #pragma unroll + for(int h = 0; h < h_bound; ++h) { + tOgO = gmem_thr_copy_O.partition_D(gO(_,h,_)); + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, m_bound + ); + } + } + } + + const int h_offset = shape<1>(epilogue_params.layout_O) * bidh_kv + h_block * kBlockH; + const int thread_idx_h = thread_idx % kBlockH; + const int thread_idx_m = thread_idx / kBlockH; + + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE); + Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor( + mLSE, Shape>{}, h_offset + thread_idx_h, bidb, n_split_idx)(_, m_block); + if(thread_idx_h < h_bound && thread_idx_m < m_bound) { + gLSE(thread_idx_m) = !Is_split ? INFINITY : -INFINITY; + } + } + +}; + +} // namespace flash diff --git a/candle-flash-attn-v3/hkernel/flash.h b/candle-flash-attn-v3/hkernel/flash.h new file mode 100644 index 0000000000..0b5adb267e --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash.h @@ -0,0 +1,198 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include +#include + +#include "cutlass/fast_math.h" // For cutlass::FastDivmod + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Qkv_params { + using index_t = int64_t; + // The QKV matrices. + void *__restrict__ q_ptr; + void *__restrict__ k_ptr; + void *__restrict__ v_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + + // The number of heads. + int h, h_k; + // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be + // different from nheads (query). + int h_h_k_ratio; // precompute h / h_k, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_params : public Qkv_params { + + // The O matrix (output). + void * __restrict__ o_ptr; + void * __restrict__ oaccum_ptr; + + // The stride between rows of O. + index_t o_batch_stride; + index_t o_row_stride; + index_t o_head_stride; + + // The stride between rows of Oaccum. + index_t oaccum_batch_stride; + index_t oaccum_row_stride; + index_t oaccum_head_stride; + index_t oaccum_split_stride; + + // The pointer to the P matrix. + void * __restrict__ p_ptr; + + // The pointer to the softmax sum. + void * __restrict__ softmax_lse_ptr; + void * __restrict__ softmax_lseaccum_ptr; + + // The dimensions. + int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q, total_k; + int b_k; + + // The scaling factors for the kernel. + float scale_softmax; + float scale_softmax_log2; + uint32_t scale_softmax_log2_half2; + + // array of length b+1 holding starting offset of each sequence. + int * __restrict__ cu_seqlens_q; + int * __restrict__ cu_seqlens_k; + + // If provided, the actual length of each q / o sequence. + int * __restrict__ seqused_q; + // If provided, the actual length of each k / v sequence. + int * __restrict__ seqused_k; + + int *__restrict__ blockmask; + + // The K_new and V_new matrices. + void * __restrict__ knew_ptr; + void * __restrict__ vnew_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t knew_batch_stride; + index_t vnew_batch_stride; + index_t knew_row_stride; + index_t vnew_row_stride; + index_t knew_head_stride; + index_t vnew_head_stride; + + // The cos and sin matrices for rotary embedding. + void * __restrict__ rotary_cos_ptr; + void * __restrict__ rotary_sin_ptr; + + // The indices to index into the KV cache. + int * __restrict__ cache_batch_idx; + + // Paged KV cache + int * __restrict__ block_table; + index_t block_table_batch_stride; + int page_block_size; + int page_num_blocks; + + // The dropout probability (probability of keeping an activation). + float p_dropout; + // uint32_t p_dropout_in_uint; + // uint16_t p_dropout_in_uint16_t; + uint8_t p_dropout_in_uint8_t; + + // Scale factor of 1 / (1 - p_dropout). + float rp_dropout; + float scale_softmax_rp_dropout; + + // Local window size + int window_size_left, window_size_right; + + // Pointer to the RNG seed (idx 0) and offset (idx 1). 
+ uint64_t * rng_state; + + bool is_bf16; + bool is_e4m3; + bool is_causal; + bool is_local; + bool is_kv_cache; + bool use_gqa_packing; + + bool is_rotary_interleaved; + + int num_splits; // For split-KV version + + void * __restrict__ alibi_slopes_ptr; + index_t alibi_slopes_batch_stride; + + bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q]. + bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d). + + int * __restrict__ tile_count_semaphore; + float * __restrict__ descale_q_ptr; + float * __restrict__ descale_k_ptr; + float * __restrict__ descale_v_ptr; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// struct Flash_bwd_params : public Flash_fwd_params { + +// // The dO and dQKV matrices. +// void *__restrict__ do_ptr; +// void *__restrict__ dq_ptr; +// void *__restrict__ dk_ptr; +// void *__restrict__ dv_ptr; + +// // To accumulate dQ +// void *__restrict__ dq_accum_ptr; +// void *__restrict__ dk_accum_ptr; +// void *__restrict__ dv_accum_ptr; + +// // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q +// // dimension void *__restrict__ dk_accum_ptr; void *__restrict__ +// // dv_accum_ptr; + +// // The stride between rows of the dO, dQ, dK and dV matrices. +// // TD [2022-04-16]: We're using 32-bit indexing to save registers. +// // The code probably won't work for arrays larger than 2GB. +// index_t do_batch_stride; +// index_t do_row_stride; +// index_t do_head_stride; +// index_t dq_batch_stride; +// index_t dk_batch_stride; +// index_t dv_batch_stride; +// index_t dq_row_stride; +// index_t dk_row_stride; +// index_t dv_row_stride; +// index_t dq_head_stride; +// index_t dk_head_stride; +// index_t dv_head_stride; + +// // The pointer to the softmax d sum. +// void *__restrict__ dsoftmax_sum; +// void *__restrict__ softmax_lse_log2_ptr; + +// int *__restrict__ dq_semaphore; + +// bool deterministic; +// index_t dq_accum_split_stride; +// }; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream); +// template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); diff --git a/candle-flash-attn-v3/hkernel/flash_api.cpp b/candle-flash-attn-v3/hkernel/flash_api.cpp new file mode 100644 index 0000000000..d79f5211e0 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_api.cpp @@ -0,0 +1,1745 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers. +#include +#include +#include +#include + +#include + +#include "flash.h" +#include "static_switch.h" + +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) 
TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +#include +#include +#include // For __half and __half2float +#include // For cudaMemcpy, cudaMemcpyDeviceToHost + +// Helper to read/print small FP16 arrays from device +void read_and_print_fp16(const void* dev_ptr, size_t num_elements, const char* name) { + if (!dev_ptr) { + printf(" %s is null.\n", name); + return; + } + // Allocate host array + std::vector<__half> host_data(num_elements); + // Copy from GPU -> CPU + cudaMemcpy(host_data.data(), dev_ptr, sizeof(__half) * num_elements, cudaMemcpyDeviceToHost); + + printf(" %s first %zu FP16 elements:\n ", name, num_elements); + for (size_t i = 0; i < num_elements; i++) { + float val = __half2float(host_data[i]); + printf("%9.6f ", val); + } + printf("\n"); +} + +// Helper to read/print small int32 arrays from device +void read_and_print_int32(const int32_t* dev_ptr, size_t num_elements, const char* name) { + if (!dev_ptr) { + printf(" %s is null.\n", name); + return; + } + std::vector host_data(num_elements); + cudaMemcpy(host_data.data(), dev_ptr, sizeof(int32_t) * num_elements, cudaMemcpyDeviceToHost); + + printf(" %s first %zu int32 values:\n ", name, num_elements); + for (size_t i = 0; i < num_elements; i++) { + printf("%d ", host_data[i]); + } + printf("\n"); +} + +void print_params(const Flash_fwd_params &p) { + printf("\n===== Flash_fwd_params Dump =====\n"); + + // Basic geometry + printf(" b = %lu\n", p.b); + printf(" b_k = %lu\n", p.b_k); + printf(" h = %lu\n", p.h); + printf(" h_k = %lu\n", p.h_k); + printf(" d = %lu\n", p.d); + printf(" d_rounded = %lu\n", p.d_rounded); + printf(" h_h_k_ratio = %lu\n", p.h_h_k_ratio); + + // Sequence lengths + printf(" seqlen_q = %lu\n", p.seqlen_q); + printf(" seqlen_k = %lu\n", p.seqlen_k); + printf(" seqlen_q_rounded = %lu\n", p.seqlen_q_rounded); + printf(" seqlen_k_rounded = %lu\n", p.seqlen_k_rounded); + printf(" total_q = %u\n", p.total_q); + printf(" total_k = %u\n", p.total_k); + + // Strides + printf("\n Strides:\n"); + printf(" q_batch_stride = %lu\n", (unsigned long)p.q_batch_stride); + printf(" q_row_stride = %lu\n", (unsigned long)p.q_row_stride); + printf(" q_head_stride = %lu\n", (unsigned long)p.q_head_stride); + printf(" k_batch_stride = %lu\n", (unsigned long)p.k_batch_stride); + printf(" k_row_stride = %lu\n", (unsigned long)p.k_row_stride); + printf(" k_head_stride = %lu\n", (unsigned long)p.k_head_stride); + printf(" v_batch_stride = %lu\n", (unsigned long)p.v_batch_stride); + printf(" v_row_stride = %lu\n", (unsigned long)p.v_row_stride); + printf(" v_head_stride = %lu\n", (unsigned long)p.v_head_stride); + printf(" o_batch_stride = %lu\n", (unsigned long)p.o_batch_stride); + printf(" o_row_stride = %lu\n", (unsigned long)p.o_row_stride); + printf(" o_head_stride = %lu\n", (unsigned long)p.o_head_stride); + + // Pointer addresses + printf("\n Pointer addresses:\n"); + printf(" q_ptr = %p\n", p.q_ptr); + printf(" k_ptr = %p\n", p.k_ptr); + printf(" v_ptr = %p\n", p.v_ptr); + printf(" o_ptr = %p\n", p.o_ptr); + printf(" p_ptr = %p\n", p.p_ptr); + printf(" softmax_lse_ptr = %p\n", p.softmax_lse_ptr); + printf(" alibi_slopes_ptr= %p\n", p.alibi_slopes_ptr); + printf(" descale_q_ptr = %p\n", p.descale_q_ptr); + printf(" descale_k_ptr = %p\n", p.descale_k_ptr); + printf(" descale_v_ptr = %p\n", p.descale_v_ptr); + + // (varlen / kv-cache) pointer addresses + printf(" cu_seqlens_q = %p\n", 
p.cu_seqlens_q); + printf(" cu_seqlens_k = %p\n", p.cu_seqlens_k); + printf(" seqused_q = %p\n", p.seqused_q); + printf(" seqused_k = %p\n", p.seqused_k); + printf(" block_table = %p\n", p.block_table); + printf(" tile_count_semaphore = %p\n", p.tile_count_semaphore); + + // Additional KV cache / GQA + printf("\n GQA / KV cache details:\n"); + printf(" page_block_size = %d\n", p.page_block_size); + printf(" page_num_blocks = %d\n", p.page_num_blocks); + printf(" use_gqa_packing = %d\n", p.use_gqa_packing); + printf(" num_splits = %d\n", p.num_splits); + + // Softmax & dropout scales + printf("\n Softmax / dropout:\n"); + printf(" scale_softmax = %f\n", p.scale_softmax); + printf(" scale_softmax_log2 = %f\n", p.scale_softmax_log2); + printf(" scale_softmax_log2_half2 = 0x%08x (raw bits)\n", p.scale_softmax_log2_half2); + printf(" p_dropout = %f\n", p.p_dropout); + printf(" p_dropout_in_uint8_t = %u\n", p.p_dropout_in_uint8_t); + printf(" rp_dropout = %f\n", p.rp_dropout); + printf(" scale_softmax_rp_dropout = %f\n", p.scale_softmax_rp_dropout); + + // Booleans / flags + printf("\n Flags:\n"); + printf(" is_bf16 = %d\n", p.is_bf16); + printf(" is_e4m3 = %d\n", p.is_e4m3); + printf(" is_causal = %d\n", p.is_causal); + printf(" is_local = %d\n", p.is_local); + printf(" is_kv_cache = %d\n", p.is_kv_cache); + printf(" seqlenq_ngroups_swapped = %d\n", p.seqlenq_ngroups_swapped); + printf(" unpadded_lse = %d\n", p.unpadded_lse); + + // Window / block sizes + printf(" window_size_left = %d\n", p.window_size_left); + printf(" window_size_right = %d\n", p.window_size_right); + + printf("===== End of Flash_fwd_params Dump =====\n\n"); + + // Optional: read small data from pointers. + // Adjust "4" or "2" to however many elements you need to debug. + if (p.q_ptr) { + read_and_print_fp16(p.q_ptr, 4, "q_ptr"); + } + if (p.k_ptr) { + read_and_print_fp16(p.k_ptr, 4, "k_ptr"); + } + if (p.v_ptr) { + read_and_print_fp16(p.v_ptr, 4, "v_ptr"); + } + if (p.o_ptr) { + read_and_print_fp16(p.o_ptr, 4, "o_ptr"); + } + if (p.softmax_lse_ptr) { + read_and_print_fp16(p.softmax_lse_ptr, 4, "softmax_lse_ptr"); + } + + // For cu_seqlens_q and cu_seqlens_k, read 2 int32_t elements, for example + if (p.cu_seqlens_q) { + read_and_print_int32(static_cast(p.cu_seqlens_q), 2, "cu_seqlens_q"); + } + if (p.cu_seqlens_k) { + read_and_print_int32(static_cast(p.cu_seqlens_k), 2, "cu_seqlens_k"); + } +} + +void set_params_fprop(Flash_fwd_params ¶ms, + // sizes + const size_t b, + const size_t b_k, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t seqlen_q_rounded, + const size_t seqlen_k_rounded, + const size_t h, + const size_t h_k, + const size_t d, + const size_t d_rounded, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + at::Tensor out, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *seqused_q, + void *seqused_k, + void *p_d, + void *softmax_lse_d, + float p_dropout, + float softmax_scale, + int window_size_left, + int window_size_right, + bool seqlenq_ngroups_swapped=false, + bool unpadded_lse=false) { + + // Reset the parameters + params = {}; + + params.is_bf16 = q.dtype() == torch::kBFloat16; + params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn; + params.is_kv_cache = false; + params.page_num_blocks = 0; + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = k.data_ptr(); + params.v_ptr = v.data_ptr(); + // All stride are in elements, not bytes. 
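+    // For example, for a contiguous (batch, seqlen, nheads, headdim) tensor the strides
+    // read below are: row stride = nheads * headdim, head stride = headdim, and
+    // batch stride = seqlen * nheads * headdim.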
+    params.q_row_stride = q.stride(-3);
+    params.k_row_stride = k.stride(-3);
+    params.v_row_stride = v.stride(-3);
+    params.q_head_stride = q.stride(-2);
+    params.k_head_stride = k.stride(-2);
+    params.v_head_stride = v.stride(-2);
+    params.o_ptr = out.data_ptr();
+    params.o_row_stride = out.stride(-3);
+    params.o_head_stride = out.stride(-2);
+
+    if (cu_seqlens_q_d == nullptr) {
+        params.q_batch_stride = q.stride(0);
+        params.k_batch_stride = k.stride(0);
+        params.v_batch_stride = v.stride(0);
+        params.o_batch_stride = out.stride(0);
+        if (seqlenq_ngroups_swapped) {
+             params.q_batch_stride *= seqlen_q;
+             params.o_batch_stride *= seqlen_q;
+        }
+    }
+
+    params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
+    params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
+    params.seqused_q = static_cast<int *>(seqused_q);
+    params.seqused_k = static_cast<int *>(seqused_k);
+
+    TORCH_CHECK(
+        bool(params.cu_seqlens_q) == bool(params.cu_seqlens_k),
+        "cu_seqlens_q and cu_seqlens_k must be both null or non-null"
+    );
+
+    // P = softmax(QK^T)
+    params.p_ptr = p_d;
+
+    // Softmax sum
+    params.softmax_lse_ptr = softmax_lse_d;
+
+    // Set the dimensions.
+    params.b = b;
+    params.b_k = b_k;
+    params.h = h;
+    params.h_k = h_k;
+    params.h_h_k_ratio = h / h_k;
+    params.seqlen_q = seqlen_q;
+    params.seqlen_k = seqlen_k;
+    params.seqlen_q_rounded = seqlen_q_rounded;
+    params.seqlen_k_rounded = seqlen_k_rounded;
+    params.d = d;
+    params.d_rounded = d_rounded;
+
+    // Set the different scale values.
+    params.scale_softmax = softmax_scale;
+    params.scale_softmax_log2 = softmax_scale * M_LOG2E;
+    __half scale_softmax_log2_half = __float2half(params.scale_softmax_log2);
+    __half2 scale_softmax_log2_half2 = __half2(scale_softmax_log2_half, scale_softmax_log2_half);
+    params.scale_softmax_log2_half2 = reinterpret_cast<uint32_t&>(scale_softmax_log2_half2);
+
+    // Set this to the probability of keeping an element to simplify things.
+    params.p_dropout = 1.f - p_dropout;
+    // Convert p from float to int so we don't have to convert the random uint to float to compare.
+    // [Minor] We want to round down since when we do the comparison we use <= instead of <
+    // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
+    // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
+    params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
+    params.rp_dropout = 1.f / params.p_dropout;
+    params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
+    TORCH_CHECK(p_dropout < 1.f);
+    #ifdef FLASHATTENTION_DISABLE_DROPOUT
+        TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
+    #endif
+
+    // Causal is the special case where window_size_right == 0 and window_size_left < 0.
+    // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
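+    // Illustration (an added example derived only from the clamping logic just
+    // below, not from the original comments): a caller asking for plain causal
+    // attention would pass (window_size_left, window_size_right) = (-1, 0); the
+    // -1 is clamped up to seqlen_k, so is_causal becomes true. A causal sliding
+    // window covering the current key plus the previous 127 could be expressed
+    // as (127, 0): assuming seqlen_k > 127, window_size_left stays below
+    // seqlen_k, so is_causal stays false and is_local is set instead.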
+ window_size_left = std::min(int(seqlen_k), window_size_left); + window_size_right = std::min(int(seqlen_k), window_size_right); + if (window_size_left < 0) { window_size_left = seqlen_k; } + if (window_size_right < 0) { window_size_right = seqlen_k; } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + + params.is_causal = window_size_left == int(seqlen_k) && window_size_right == 0; + if ((window_size_left < int(seqlen_k) || window_size_right < int(seqlen_k)) && !params.is_causal) { + params.is_local = true; + } + + #ifdef FLASHATTENTION_DISABLE_LOCAL + TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0), + "This flash attention build does not support local attention."); + #endif + + #ifdef FLASHATTENTION_DISABLE_UNEVEN_K + TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32."); + #endif + + params.unpadded_lse = unpadded_lse; + params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped; +} + +void set_params_dgrad(Flash_bwd_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t seqlen_q_rounded, + const size_t seqlen_k_rounded, + const size_t h, + const size_t h_k, + const size_t d, + const size_t d_rounded, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + const at::Tensor out, + const at::Tensor dout, + at::Tensor dq, + at::Tensor dk, + at::Tensor dv, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *seqused_q, + void *seqused_k, + void *dq_accum_d, + void *dk_accum_d, + void *dv_accum_d, + void *softmax_lse_d, + void *dsoftmax_sum_d, + float p_dropout, + float softmax_scale, + int window_size_left, + int window_size_right, + bool deterministic) { + + set_params_fprop(params, + b, b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, + q, k, v, out, + cu_seqlens_q_d, + cu_seqlens_k_d, + seqused_q, + seqused_k, + nullptr, + softmax_lse_d, + p_dropout, + softmax_scale, + window_size_left, + window_size_right); + + // Set the pointers and strides. + params.do_ptr = dout.data_ptr(); + params.do_row_stride = dout.stride(-3); + params.do_head_stride = dout.stride(-2); + params.dq_ptr = dq.data_ptr(); + params.dk_ptr = dk.data_ptr(); + params.dv_ptr = dv.data_ptr(); + params.page_num_blocks = 0; + params.dq_row_stride = dq.stride(-3); + params.dk_row_stride = dk.stride(-3); + params.dv_row_stride = dv.stride(-3); + params.dq_head_stride = dq.stride(-2); + params.dk_head_stride = dk.stride(-2); + params.dv_head_stride = dv.stride(-2); + + if (cu_seqlens_q_d == nullptr) { + params.do_batch_stride = dout.stride(0); + params.dq_batch_stride = dq.stride(0); + params.dk_batch_stride = dk.stride(0); + params.dv_batch_stride = dv.stride(0); + } + + params.dq_accum_ptr = dq_accum_d; + params.dk_accum_ptr = dk_accum_d; + params.dv_accum_ptr = dv_accum_d; + + // Softmax sum + params.dsoftmax_sum = dsoftmax_sum_d; + + params.deterministic = deterministic; +} + + +// Find the number of splits that maximizes the occupancy. For example, if we have +// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is +// better than having 3 splits (efficiency = 0.67). However, we also don't want too many +// splits as that would incur more HBM reads/writes. +// So we find the best efficiency, then find the smallest number of splits that gets 80% +// of the best efficiency. 
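+// Worked numbers for the example above (an added illustration, assuming one
+// m-block per (batch, head) and using the formula from the loop below,
+// eff = n_waves / ceil(n_waves) with n_waves = blocks * num_splits / num_SMs):
+//   1 split :  48 / 108 = 0.44 waves -> eff 0.44
+//   2 splits:  96 / 108 = 0.89 waves -> eff 0.89
+//   3 splits: 144 / 108 = 1.33 waves -> rounds up to 2 waves -> eff 0.67
+// The second loop below then returns the smallest split count whose efficiency
+// clears the threshold (a fraction of the best efficiency found).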
+inline int num_splits_heuristic(int batch_nheads_mblocks, int batch_nheads, int num_SMs, int num_n_blocks, + int max_splits, int head_size, bool use_one_mma_wg) { + // Goal of the starting threshold is to determine whether to split or not. + // Empirically, the efficiency threshold can be much lower than 80% depending on num_n_blocks. + int num_m_blocks = batch_nheads_mblocks/batch_nheads; + float start_threshold; + float num_n_blocksf = float(num_n_blocks); + if (head_size == 128) { + if (std::log2f(num_n_blocksf) <= 4) { // 2048 -- .25 + start_threshold = .20f + (std::log2f(num_n_blocksf) - 3) * .05f; + } else if (std::log2f(num_n_blocksf) <= 5) { // 4096 -- .25 + start_threshold = .25f; + } else if (std::log2f(num_n_blocksf) <= 6) { // 8192 -- .36 + start_threshold = .28f + (std::log2f(num_n_blocksf) - 5) * .08f; + } else if (std::log2f(num_n_blocksf) <= 7) { // 16K -- .42 + start_threshold = .36f + (std::log2f(num_n_blocksf) - 6) * .06f; + } else { + // Just split freely + start_threshold = .8f; + } + if (num_m_blocks > 1 && start_threshold < .5f) + start_threshold += .05f * (std::log2f(num_n_blocksf) - 2); + } else if (head_size == 256) { + // TODO for hdim 256 + if (num_n_blocks <= 40) { + start_threshold = .24f; + } else if (std::log2f(num_n_blocksf) <= 8) { + start_threshold = .33f + std::max(0.f, (std::log2f(num_n_blocksf) - std::log2f(50)) * 0.02971f); + } else { + // Just split freely + start_threshold = .8f; + } + } else if (head_size == 64) { + if (use_one_mma_wg) { + if (std::log2f(num_n_blocksf) <= 4) { // 2K -- .33 + start_threshold = .33f; + } else if (std::log2f(num_n_blocksf) <= 5) { // 4K -- .37 + start_threshold = .33f + (std::log2f(num_n_blocksf) - 4) * .04f; + } else if (std::log2f(num_n_blocksf) <= 6) { // 8K -- .40 + start_threshold = .37f + (std::log2f(num_n_blocksf) - 5) * .03f; + } else if (std::log2f(num_n_blocksf) <= 7) { // 16K -- .43 + start_threshold = .4f + (std::log2f(num_n_blocksf) - 6) * .03f; + } else if (std::log2f(num_n_blocksf) <= 8) { // 32K -- .46 + start_threshold = .43f + (std::log2f(num_n_blocksf) - 7) * .03f; + } else { + start_threshold = .8f; + } + } else { + if (std::log2f(num_n_blocksf) <= 6) { // 8K -- .5 + start_threshold = .5f; + } else { + start_threshold = .8f; + } + } + } else { + // placeholder for other hdims + start_threshold = .8f; + } + + float first_wave = float(batch_nheads_mblocks) / num_SMs; + // printf("Start threshold and wave = %f, %f.\n", start_threshold, first_wave); + // Only use start_threshold if initial work doesn't exceed one wave + if ((first_wave/ceil(first_wave) > start_threshold && first_wave <= 1.f) || + (first_wave/ceil(first_wave) > .8f)) { + return 1; + } + // if (first_wave_batch_nheads > start_threshold) { return 1; } + // if (first_wave_batch_nheads > start_threshold || first_wave > .8f) { return 1; } + // if (float(batch_nheads)/num_SMs > start_threshold) { return 1; } + + // If num_n_blocks is too small, use 1 split + // For example, we never split for hdim = 128 and seqlen_k = 512, + // or for hdim = 128, seqlen_k = 1024, and one MMA warpgroup. + if (num_n_blocks < 8 || (use_one_mma_wg && num_n_blocks < 10)) { return 1; } + + max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + float max_efficiency = 0.f; + std::vector efficiency; + efficiency.reserve(max_splits); + + // NOTE: disable split eligibility check for FA3 since we have dynamic tile scheduler + // for exiting splits with no work early, and check leads to efficiency quantization issues. 
+ // Comment from FA2: + // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, + // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks + // (i.e. it's 11 splits anyway). + // So we check if the number of blocks per split is the same as the previous num_splits. + // auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + // auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { + // return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); + // }; + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + // if (!is_split_eligible(num_splits)) { + // efficiency.push_back(0.f); + // } else { + float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, n_waves = %f, ceil(n_waves) = %f, eff = %f\n", num_splits, n_waves, ceil(n_waves), eff); + if (eff > max_efficiency) { max_efficiency = eff; } + efficiency.push_back(eff); + // } + } + // Correct for excessive splitting with e.g. 1 bsz*nheads*mblocks + // Empirically, efficiency threshold in these cases is about 40% for 64K seqlen_k + float threshold = num_m_blocks == 1 ? std::min(0.3f + batch_nheads * 0.1f, 0.8f) : 0.8f; + threshold = threshold * max_efficiency; + // printf("Max efficiency = %f. Threshold = %f.\n", max_efficiency, threshold); + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + // if (!is_split_eligible(num_splits)) { continue; } + if (efficiency[num_splits - 1] > threshold) { + // printf("num_splits chosen = %d, threshold = %f, efficiency = %f.\n", num_splits, threshold, efficiency[num_splits - 1]); + return num_splits; + } + } + return 1; +} + +std::tuple set_params_splitkv(Flash_fwd_params ¶ms, const int batch_size, + const int num_heads, const int num_heads_k, const int head_size, const int max_seqlen_k, const int max_seqlen_q, + const int head_size_rounded, const float p_dropout, + const int num_splits, cudaDeviceProp *dprops, bool use_gqa_packing, bool is_causal, struct c10::TensorOptions opts) { + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + + params.num_splits = num_splits; + at::Tensor softmax_lse_accum; + at::Tensor out_accum; + + if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout + if (num_splits < 1) { + const int gqa_ratio = num_heads / num_heads_k; + const int block_h = 1 << static_cast(std::ceil(std::log2(std::clamp(gqa_ratio, 1, 32)))); + const int block_m = head_size == 64 ? 192 : 128; + const bool use_one_mma_wg = max_seqlen_q <= 64/block_h; + + int block_n = 128; + if (head_size == 128 && !is_causal) { + block_n = 176; + } else if (head_size == 256) { + block_n = use_one_mma_wg ? 96 : 80; + } + const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n; + const int batch_nheads = use_gqa_packing ? batch_size * num_heads_k : batch_size * num_heads; + const int batch_nheads_mblocks = use_gqa_packing + ? 
ceildiv(max_seqlen_q, block_m / block_h) * batch_nheads
+                : ceildiv(max_seqlen_q, block_m) * batch_nheads;
+            params.num_splits = num_splits_heuristic(batch_nheads_mblocks, batch_nheads,
+                dprops->multiProcessorCount, num_n_blocks, 128, head_size, use_one_mma_wg);
+            // printf("Num splits heuristic = %d.\n", params.num_splits);
+        }
+        if (params.num_splits > 1) {
+            softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
+            out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
+            params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
+            params.oaccum_ptr = out_accum.data_ptr();
+            params.oaccum_row_stride = out_accum.stride(-2);
+            params.oaccum_head_stride = out_accum.stride(-3);
+            params.oaccum_batch_stride = out_accum.stride(-4);
+            params.oaccum_split_stride = out_accum.stride(0);
+        }
+        TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
+    }
+
+    return std::make_tuple(softmax_lse_accum, out_accum);
+}
+
+
+void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
+
+  int dtype = 1;
+  if (params.is_bf16) { dtype = 2; }
+  else if (params.is_e4m3) { dtype = 3; }
+  PREC_SWITCH(dtype, Element, [&] {
+    HEADDIM_SWITCH(params.d, kHeadSize, [&] {
+      if(!params.use_gqa_packing) {
+        run_mha_fwd_<Element, kHeadSize>(params, stream);
+      } else {
+        QUERYHEAD_SWITCH(params.h_h_k_ratio, kBlockH, [&] {
+          run_mha_fwd_gqa_<Element, kHeadSize, kBlockH>(params, stream);
+        });
+      }
+    });
+  });
+
+#if 0
+  if (!params.is_e4m3) {
+    if (params.is_bf16) {
+      if (params.d == 64) {
+        run_mha_fwd_<cutlass::bfloat16_t, 64>(params, stream);
+      } else if (params.d == 128) {
+        run_mha_fwd_<cutlass::bfloat16_t, 128>(params, stream);
+      } else {
+        run_mha_fwd_<cutlass::bfloat16_t, 256>(params, stream);
+      }
+    } else {
+      if (params.d == 64) {
+        run_mha_fwd_<cutlass::half_t, 64>(params, stream);
+      } else if (params.d == 128) {
+        run_mha_fwd_<cutlass::half_t, 128>(params, stream);
+      } else {
+        run_mha_fwd_<cutlass::half_t, 256>(params, stream);
+      }
+    }
+  } else {
+    if (params.d == 64) {
+      run_mha_fwd_<cutlass::float_e4m3_t, 64>(params, stream);
+    } else if (params.d == 128) {
+      run_mha_fwd_<cutlass::float_e4m3_t, 128>(params, stream);
+    } else if (params.d == 256) {
+      run_mha_fwd_<cutlass::float_e4m3_t, 256>(params, stream);
+    }
+  }
+#endif
+}
+
+std::vector<at::Tensor>
+mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
+        const at::Tensor &k,   // batch_size x seqlen_k x num_heads_k x head_size
+        const at::Tensor &v,   // batch_size x seqlen_k x num_heads_k x head_size
+        c10::optional<at::Tensor> &out_,       // batch_size x seqlen_q x num_heads x head_size
+        const float softmax_scale,
+        c10::optional<at::Tensor> &descale_q_, // 1
+        c10::optional<at::Tensor> &descale_k_, // 1
+        c10::optional<at::Tensor> &descale_v_, // 1
+        bool is_causal,
+        int window_size_left,
+        int window_size_right,
+        bool use_gqa_packing = false
+        ) {
+
+    auto dprops = at::cuda::getCurrentDeviceProperties();
+    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+    TORCH_CHECK(is_sm90, "FlashAttention-3 only supports Hopper GPUs or newer.");
+
+    auto q_dtype = q.dtype();
+    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16 || q_dtype == at::ScalarType::Float8_e4m3fn,
+                "FlashAttention-3 only support fp16, bf16, or fp8 e4m3 data type");
+    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
+    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
+
+    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
+
+    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have
contiguous last dimension"); + + const auto sizes = q.sizes(); + + const int batch_size = sizes[0]; + int seqlen_q = sizes[1]; + int num_heads = sizes[2]; + const int head_size_og = sizes[3]; + const int seqlen_k = k.size(1); + const int num_heads_k = k.size(2); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + // Guard against mistaken setting of gqa flag + if (num_heads == num_heads_k) { use_gqa_packing = false; } + + TORCH_CHECK(head_size_og == 64 || head_size_og == 128 || head_size_og == 256, "Only support head size 64, 128, and 256 for now"); + + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og); + + at::Tensor q_padded, k_padded, v_padded; + if (head_size_og % 8 != 0) { + q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } else { + q_padded = q; + k_padded = k; + v_padded = v; + } + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + // TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + TORCH_CHECK(q_dtype == at::ScalarType::Float8_e4m3fn + ? (out.dtype() == at::kBFloat16) + : (out.dtype() == q_dtype), + "Output must have the same dtype as input dtype if dtype is " + "not fp8, or fp16 for fp8 input."); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og); + if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } + } else { + if (q_dtype == at::ScalarType::Float8_e4m3fn) + out = torch::empty_like(q_padded, at::kBFloat16); + else + out = torch::empty_like(q_padded); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = round_multiple(head_size_og, 8); + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + if (is_causal) { window_size_right = 0; } + + // Otherwise the kernel will be launched from cuda:0 device + at::cuda::CUDAGuard device_guard{q.device()}; + + auto opts = q.options(); + + auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor p; + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q_padded, k_padded, v_padded, out, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + /*seqused_q=*/nullptr, + /*seqused_k=*/nullptr, + nullptr, + softmax_lse.data_ptr(), + /*p_dropout=*/0.f, + softmax_scale, + /*window_size_left=*/window_size_left, + /*window_size_right=*/window_size_right); + + auto tile_count_semaphore = is_causal || params.is_local + ? 
torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32)); + params.tile_count_semaphore = tile_count_semaphore.data_ptr(); + + at::Tensor descale_q, descale_k, descale_v; + if(q_dtype == at::ScalarType::Float8_e4m3fn) { + if (descale_q_.has_value()) { + descale_q = descale_q_.value(); + CHECK_DEVICE(descale_q); + CHECK_SHAPE(descale_q, 1); + } else { descale_q = torch::ones({1}, opts.dtype(at::kFloat)); } + if (descale_k_.has_value()) { + descale_k = descale_k_.value(); + CHECK_DEVICE(descale_k); + CHECK_SHAPE(descale_k, 1); + } else { descale_k = torch::ones({1}, opts.dtype(at::kFloat)); } + if (descale_v_.has_value()) { + descale_v = descale_v_.value(); + CHECK_DEVICE(descale_v); + CHECK_SHAPE(descale_v, 1); + } else { descale_v = torch::ones({1}, opts.dtype(at::kFloat)); } + params.descale_q_ptr = descale_q.data_ptr(); + params.descale_k_ptr = descale_k.data_ptr(); + params.descale_v_ptr = descale_v.data_ptr(); + } else { + params.descale_q_ptr = nullptr; + params.descale_k_ptr = nullptr; + params.descale_v_ptr = nullptr; + } + + params.use_gqa_packing = use_gqa_packing; + + if (seqlen_k > 0) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd(params, stream); + } else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + at::Tensor out_padded = out; + if (head_size_og % 8 != 0) { + out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + if (out_.has_value()) { out_.value().copy_(out); } + } + + return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p}; +} + +std::vector +mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + c10::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + c10::optional &seqused_q, // b. If given, only this many elements of each batch element's queries and outputs are used. + c10::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. 
+ std::optional &block_table_, // batch_size x max_num_blocks_per_seq + int max_seqlen_q, + const int max_seqlen_k, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right) { + + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer."); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(cu_seqlens_q); + CHECK_DEVICE(cu_seqlens_k); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); + + at::Tensor block_table; + const bool paged_KV = block_table_.has_value(); + if (paged_KV) { + block_table = block_table_.value(); + CHECK_DEVICE(block_table); + TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); + TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); + } + + const auto sizes = q.sizes(); + + const int batch_size = cu_seqlens_q.numel() - 1; + int num_heads = sizes[1]; + const int head_size_og = sizes[2]; + const int num_heads_k = paged_KV ? k.size(2) : k.size(1); + + void *cu_seqlens_q_d = cu_seqlens_q.data_ptr(); + + const int total_q = q.sizes()[0]; + + const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); + const int num_blocks = !paged_KV ? 0 : k.size(0); + const int page_block_size = !paged_KV ? 
-1 : k.size(1); + TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); + + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + CHECK_SHAPE(q, total_q, num_heads, head_size_og); + const int total_k = k.size(0); + + if (!paged_KV) { + CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_og); + } else { + CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og); + CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og); + CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); + } + + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + if (seqused_q.has_value()){ + auto seqused_q_ = seqused_q.value(); + TORCH_CHECK(seqused_q_.dtype() == torch::kInt32, "seqused_q must have dtype int32"); + TORCH_CHECK(seqused_q_.is_cuda(), "seqused_q must be on CUDA device"); + TORCH_CHECK(seqused_q_.is_contiguous(), "seqused_q must be contiguous"); + CHECK_SHAPE(seqused_q_, batch_size); + } + + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + if (seqused_k.has_value()){ + auto seqused_k_ = seqused_k.value(); + TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32"); + TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device"); + TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous"); + CHECK_SHAPE(seqused_k_, batch_size); + } + + at::Tensor q_padded, k_padded, v_padded; + if (head_size_og % 8 != 0) { + q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } else { + q_padded = q; + k_padded = k; + v_padded = v; + } + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og); + if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } + } else { + out = torch::empty_like(q_padded); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = round_multiple(head_size_og, 8); + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + + if (is_causal) { window_size_right = 0; } + + // Otherwise the kernel will be launched from cuda:0 device + at::cuda::CUDAGuard device_guard{q.device()}; + + auto opts = q.options(); + auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, batch_size, + max_seqlen_q, max_seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q_padded, k_padded, v_padded, out, + cu_seqlens_q_d, + cu_seqlens_k.data_ptr(), + seqused_q.has_value() ? seqused_q.value().data_ptr() : nullptr, + seqused_k.has_value() ? 
seqused_k.value().data_ptr() : nullptr, + /*p_d=*/nullptr, + softmax_lse.data_ptr(), + /*p_dropout=*/0.f, + softmax_scale, + window_size_left, + window_size_right, + /*seqlenq_ngroups_swapped=*/false, + /*unpadded_lse=*/true); + params.total_q = total_q; + params.total_k = total_k; + + if (paged_KV) { + params.block_table = block_table.data_ptr(); + params.block_table_batch_stride = block_table.stride(0); + params.k_batch_stride = k.stride(0); + params.v_batch_stride = v.stride(0); + params.page_num_blocks = k.size(0); + } + params.page_block_size = page_block_size; + params.page_num_blocks = num_blocks; + + //printf("mha_varlen_fwd: params.seqlen_k=%d, max_seqlen_k=%d, params.page_num_blocks=%d\n", (int)params.seqlen_k, (int)max_seqlen_k, (int)params.page_num_blocks); + if (max_seqlen_k > 0) { + // print_params(params); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd(params, stream); + } else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + at::Tensor out_padded = out; + if (head_size_og % 8 != 0) { + out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + if (out_.has_value()) { out_.value().copy_(out); } + } + return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse}; +} + +void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { + // FP16_SWITCH(!params.is_bf16, [&] { + // HEADDIM_SWITCH(params.d, [&] { + // run_mha_bwd_(params, stream); + // }); + // }); + if (!params.is_bf16) { + if (params.d <= 64) { + run_mha_bwd_(params, stream); + } else if (params.d <= 96) { + run_mha_bwd_(params, stream); + } else { + run_mha_bwd_(params, stream); + } + } else { + if (params.d <= 64) { + run_mha_bwd_(params, stream); + } else if (params.d <= 96) { + run_mha_bwd_(params, stream); + } else { + run_mha_bwd_(params, stream); + } + } +} + +std::vector +mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + c10::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + c10::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic) { + + #ifdef FLASHATTENTION_DISABLE_BACKWARD + TORCH_CHECK(false, "This flash attention build does not support backward."); + #endif + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm9x = dprops->major == 9 && dprops->minor >= 0; + TORCH_CHECK(is_sm9x, "FlashAttentionHopper only supports Hopper GPUs or newer."); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); + 
TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + + const auto sizes = q.sizes(); + + const int batch_size = sizes[0]; + const int seqlen_q = sizes[1]; + const int num_heads = sizes[2]; + const int head_size_og = dout.size(3); + const int head_size = sizes[3]; + const int seqlen_k = k.size(1); + const int num_heads_k = k.size(2); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); + TORCH_CHECK(head_size <= 128, "FlashAttention backward only supports head dimension at most 128"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = head_size <= 64 ? 64 : round_multiple(head_size, 32); + // This should match the kernel configs + const int kBlockM = head_size <= 64 ? 128 : (head_size < 256 ? 64 : 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, kBlockM); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8"); + + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og); + + at::Tensor dq, dk, dv; + if (dq_.has_value()) { + dq = dq_.value(); + TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); + CHECK_DEVICE(dq); + TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size); + } else { + dq = torch::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); + CHECK_DEVICE(dk); + TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size); + } else { + dk = torch::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); + CHECK_DEVICE(dv); + TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size); + } else { + dv = torch::empty_like(v); + } + + at::Tensor dout_padded; + if (head_size_og % 8 != 0) { + dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } else { + dout_padded = dout; + } + + // Otherwise the kernel will be launched from cuda:0 device + at::cuda::CUDAGuard device_guard{q.device()}; + + auto opts = q.options(); + // Need softmax_d to have seqlen_q_rounded since we want its address to 
be aligned by 16/8 bytes for TMA / LDG.64 + auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat)); + auto softmax_lse_log2 = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat)); + at::Tensor dq_accum; + at::Tensor dk_accum, dv_accum; + dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded, head_size_rounded}, opts.dtype(at::kFloat)); + // dk_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat)); + // dv_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat)); + + at::Tensor dk_expanded, dv_expanded; + if (num_heads_k != num_heads) { // MQA / GQA + dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts); + dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts); + } else { + dk_expanded = dk; + dv_expanded = dv; + } + + if (is_causal) { window_size_right = 0; } + + Flash_bwd_params params; + + set_params_dgrad(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + dout_padded, dq, dk_expanded, dv_expanded, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + /*seqused_q=*/nullptr, + /*seqused_k=*/nullptr, + dq_accum.data_ptr(), + // loop ? dk_accum.data_ptr() : nullptr, + // loop ? dv_accum.data_ptr() : nullptr, + nullptr, + nullptr, + softmax_lse.data_ptr(), + softmax_d.data_ptr(), + /*p_dropout=*/0.f, + softmax_scale, + /*window_size_left=*/window_size_left, + /*window_size_right=*/window_size_right, + deterministic); + params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr(); + + // Will be zero'ed out in the backward preprocess kernel + at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32)); + params.dq_semaphore = dq_semaphore.data_ptr(); + // printf("dq_semaphore: %p, [%d, %d, %d]\n", params.dq_semaphore, (seqlen_q + 64 - 1) / 64, batch_size, num_heads); + + if (seqlen_q > 0) { + run_mha_bwd(params, stream); + } else { + // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. 
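+        // dk_expanded / dv_expanded (which alias dk / dv when num_heads equals
+        // num_heads_k) are zero-filled so that dk and dv end up as well-defined
+        // zeros for the empty case, either directly or through the MQA/GQA
+        // at::sum_out reduction over the group dimension further below.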
+ dk_expanded.zero_(); + dv_expanded.zero_(); + softmax_d.zero_(); + } + + // For MQA/GQA we need to sum dK and dV across the groups + if (num_heads_k != num_heads) { + at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); + at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); + } + + if (head_size_og % 8 != 0) { + dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + } + + return { dq, dk, dv, softmax_d, dq_accum}; +} + +std::vector +mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + c10::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + c10::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + c10::optional &seqused_q, // b. If given, only this many elements of each batch element's queries and outputs are used. + c10::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic) { + + #ifdef FLASHATTENTION_DISABLE_BACKWARD + TORCH_CHECK(false, "This flash attention build does not support backward."); + #endif + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm9x = dprops->major == 9 && dprops->minor >= 0; + TORCH_CHECK(is_sm9x, "FlashAttentionHopper only supports Hopper GPUs or newer."); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); + TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last 
dimension"); + TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); + + const auto sizes = q.sizes(); + + const int total_q = sizes[0]; + const int batch_size = cu_seqlens_q.numel() - 1; + const int num_heads = sizes[1]; + const int head_size_og = dout.size(2); + const int head_size = sizes[2]; + const int total_k = k.size(0); + const int num_heads_k = k.size(1); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); + TORCH_CHECK(head_size <= 128, "FlashAttention backward only supports head dimension at most 128"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = head_size <= 64 ? 64 : round_multiple(head_size, 32); + // This should match the kernel configs + const int kBlockM = head_size <= 64 ? 128 : (head_size < 256 ? 64 : 32); + const int seqlen_q_rounded = round_multiple(max_seqlen_q, kBlockM); + const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + int const total_q_padded_rounded = round_multiple(total_q + batch_size * 128, 128); + + TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8"); + + CHECK_SHAPE(q, total_q, num_heads, head_size_og); + CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(out, total_q, num_heads, head_size); + CHECK_SHAPE(dout, total_q, num_heads, head_size_og); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + if (seqused_q.has_value()){ + auto seqused_q_ = seqused_q.value(); + TORCH_CHECK(seqused_q_.dtype() == torch::kInt32, "seqused_q must have dtype int32"); + TORCH_CHECK(seqused_q_.is_cuda(), "seqused_q must be on CUDA device"); + TORCH_CHECK(seqused_q_.is_contiguous(), "seqused_q must be contiguous"); + CHECK_SHAPE(seqused_q_, batch_size); + } + + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + if (seqused_k.has_value()){ + auto seqused_k_ = seqused_k.value(); + TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32"); + TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device"); + TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous"); + CHECK_SHAPE(seqused_k_, batch_size); + } + + at::Tensor dq, dk, dv; + if (dq_.has_value()) { + dq = dq_.value(); + TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); + CHECK_DEVICE(dq); + TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + CHECK_SHAPE(dq, total_q, num_heads, head_size); + } else { + dq = torch::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); + CHECK_DEVICE(dk); + TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + CHECK_SHAPE(dk, total_k, num_heads_k, head_size); + } else { + dk = torch::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); + CHECK_DEVICE(dv); + TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + CHECK_SHAPE(dv, total_k, num_heads_k, head_size); + } else { + dv = torch::empty_like(v); + } + + at::Tensor dout_padded; + if (head_size_og % 8 != 0) { + dout_padded = 
torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } else { + dout_padded = dout; + } + + if (is_causal) { window_size_right = 0; } + + // Otherwise the kernel will be launched from cuda:0 device + at::cuda::CUDAGuard device_guard{q.device()}; + + auto opts = q.options(); + // Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64 + auto softmax_d = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat)); + auto softmax_lse_log2 = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat)); + at::Tensor dq_accum; + at::Tensor dk_accum, dv_accum; + dq_accum = torch::empty({num_heads, total_q_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat)); + // dk_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat)); + // dv_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat)); + + at::Tensor dk_expanded, dv_expanded; + if (num_heads_k != num_heads) { // MQA / GQA + dk_expanded = torch::empty({total_k, num_heads, head_size}, opts); + dv_expanded = torch::empty({total_k, num_heads, head_size}, opts); + } else { + dk_expanded = dk; + dv_expanded = dv; + } + + Flash_bwd_params params; + + set_params_dgrad(params, + batch_size, + max_seqlen_q, max_seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + dout_padded, dq, dk_expanded, dv_expanded, + cu_seqlens_q.data_ptr(), + cu_seqlens_k.data_ptr(), + seqused_q.has_value() ? seqused_q.value().data_ptr() : nullptr, + seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr, + dq_accum.data_ptr(), + // loop ? dk_accum.data_ptr() : nullptr, + // loop ? dv_accum.data_ptr() : nullptr, + nullptr, + nullptr, + softmax_lse.data_ptr(), + softmax_d.data_ptr(), + /*p_dropout=*/0.f, + softmax_scale, + /*window_size_left=*/window_size_left, + /*window_size_right=*/window_size_right, + deterministic); + params.total_q = total_q; + params.total_k = total_k; + params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr(); + + // Will be zero'ed out in the backward preprocess kernel + at::Tensor dq_semaphore = torch::empty({(max_seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32)); + params.dq_semaphore = dq_semaphore.data_ptr(); + + if (max_seqlen_q > 0) { + run_mha_bwd(params, stream); + } else { + // If max_seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. 
+ dk_expanded.zero_(); + dv_expanded.zero_(); + softmax_d.zero_(); + } + + // For MQA/GQA we need to sum dK and dV across the groups + if (num_heads_k != num_heads) { + at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); + at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); + } + + if (head_size_og % 8 != 0) { + dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + } + + return { dq, dk, dv, softmax_d, dq_accum, softmax_lse_log2 }; +} + +std::vector +mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + c10::optional &k_, // batch_size x seqlen_knew x num_heads_k x head_size + c10::optional &v_, // batch_size x seqlen_knew x num_heads_k x head_size + c10::optional &seqlens_k_, // batch_size + c10::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) + c10::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) + c10::optional &cache_batch_idx_, // indices to index into the KV cache + c10::optional &leftpad_k_, // batch_size + c10::optional &block_table_, // batch_size x max_num_blocks_per_seq + c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads + c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size + const float softmax_scale, + c10::optional &descale_q_, // 1 + c10::optional &descale_k_, // 1 + c10::optional &descale_v_, // 1 + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 + int num_splits, + int max_seqlen_k_hint, + bool use_gqa_packing + ) { + + auto dprops = at::cuda::getCurrentDeviceProperties(); + // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; + // bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + TORCH_CHECK(is_sm90, "FlashAttention-3 only supports Hopper GPUs or newer."); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16 || q_dtype == at::ScalarType::Float8_e4m3fn, + "FlashAttention-3 only support fp16, bf16, or fp8 e4m3 data type"); + TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + at::Tensor block_table; + const bool paged_KV = block_table_.has_value(); + if (paged_KV) { + TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support cache_batch_idx"); + block_table = block_table_.value(); + CHECK_DEVICE(block_table); + 
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); + TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); + } + + const auto sizes = q.sizes(); + + const int batch_size = sizes[0]; + int seqlen_q = sizes[1]; + int num_heads = sizes[2]; + const int head_size_og = sizes[3]; + + const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); + const int num_blocks = !paged_KV ? 0 : kcache.size(0); + const int page_block_size = !paged_KV ? 1 : kcache.size(1); + TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); + const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size; + const int num_heads_k = kcache.size(2); + const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size; + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + // Guard against mistaken setting of gqa flag + if (num_heads == num_heads_k) { use_gqa_packing = false; } + + // causal=true is the same as causal=false in this case + if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } + if (is_causal) { window_size_right = 0; } + + // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case + // H/t Daniel Haziza + const int seqlenq_ngroups_swapped = + seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && + window_size_right < 0 && head_size_og % 8 == 0 && + !alibi_slopes_.has_value() && !use_gqa_packing; + if (seqlenq_ngroups_swapped) { + const int ngroups = num_heads / num_heads_k; + q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); + seqlen_q = ngroups; + num_heads = num_heads_k; + } + + if (window_size_left >= seqlen_k) { window_size_left = -1; } + if (window_size_right >= seqlen_k) { window_size_right = -1; } + + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); + if (!paged_KV) { + CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og); + CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og); + } else { + CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og); + CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og); + CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); + } + + at::Tensor q_padded, kcache_padded, vcache_padded; + if (head_size_og % 8 != 0) { + q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } else { + q_padded = q; + kcache_padded = kcache; + vcache_padded = vcache; + } + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + // TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + TORCH_CHECK(q_dtype == at::ScalarType::Float8_e4m3fn + ? 
(out.dtype() == at::kBFloat16) + : (out.dtype() == q_dtype), + "Output must have the same dtype as input dtype if dtype is " + "not fp8, or fp16 for fp8 input."); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og); + if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } + } else { + if (q_dtype == at::ScalarType::Float8_e4m3fn) { + out = torch::empty_like(q_padded, at::kBFloat16); + } + else + out = torch::empty_like(q_padded); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = round_multiple(head_size_og, 8); + const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256; + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + // Otherwise the kernel will be launched from cuda:0 device + at::cuda::CUDAGuard device_guard{q.device()}; + + auto opts = q.options(); + + auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, batch_size_c, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q_padded, kcache_padded, vcache_padded, out, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + /*seqused_q=*/nullptr, + /*seqused_k=*/nullptr, + /*p_ptr=*/nullptr, + softmax_lse.data_ptr(), + /*p_dropout=*/0.f, + softmax_scale, + window_size_left, + window_size_right + ); + + at::Tensor descale_q, descale_k, descale_v; + if(q_dtype == at::ScalarType::Float8_e4m3fn) { + if (descale_q_.has_value()) { + descale_q = descale_q_.value(); + CHECK_DEVICE(descale_q); + CHECK_SHAPE(descale_q, 1); + } else { descale_q = torch::ones({1}, opts.dtype(at::kFloat)); } + if (descale_k_.has_value()) { + descale_k = descale_k_.value(); + CHECK_DEVICE(descale_k); + CHECK_SHAPE(descale_k, 1); + } else { descale_k = torch::ones({1}, opts.dtype(at::kFloat)); } + if (descale_v_.has_value()) { + descale_v = descale_v_.value(); + CHECK_DEVICE(descale_v); + CHECK_SHAPE(descale_v, 1); + } else { descale_v = torch::ones({1}, opts.dtype(at::kFloat)); } + params.descale_q_ptr = descale_q.data_ptr(); + params.descale_k_ptr = descale_k.data_ptr(); + params.descale_v_ptr = descale_v.data_ptr(); + } else { + params.descale_q_ptr = nullptr; + params.descale_k_ptr = nullptr; + params.descale_v_ptr = nullptr; + } + + params.is_kv_cache = true; + + params.use_gqa_packing = use_gqa_packing; + + at::Tensor k, v, k_padded, v_padded; + if (k_.has_value()) { + TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in"); + TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in"); + TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache"); + k = k_.value(); + v = v_.value(); + TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query"); + TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query"); + CHECK_DEVICE(k); CHECK_DEVICE(v); + TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension"); + int seqlen_knew = k.size(1); + CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og); + CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, 
head_size_og); + if (head_size_og % 8 != 0) { + k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } else { + k_padded = k; + v_padded = v; + } + params.seqlen_knew = seqlen_knew; + params.knew_ptr = k_padded.data_ptr(); + params.vnew_ptr = v_padded.data_ptr(); + // All stride are in elements, not bytes. + params.knew_batch_stride = k_padded.stride(0); + params.vnew_batch_stride = v_padded.stride(0); + params.knew_row_stride = k_padded.stride(-3); + params.vnew_row_stride = v_padded.stride(-3); + params.knew_head_stride = k_padded.stride(-2); + params.vnew_head_stride = v_padded.stride(-2); + } + + if (seqlens_k_.has_value()) { + auto seqlens_k = seqlens_k_.value(); + TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); + CHECK_DEVICE(seqlens_k); + CHECK_CONTIGUOUS(seqlens_k); + CHECK_SHAPE(seqlens_k, batch_size); + params.seqused_k = static_cast(seqlens_k.data_ptr()); + } + if (leftpad_k_.has_value()) { + TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet"); + auto leftpad_k = leftpad_k_.value(); + TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); + CHECK_DEVICE(leftpad_k); + CHECK_CONTIGUOUS(leftpad_k); + CHECK_SHAPE(leftpad_k, batch_size); + TORCH_CHECK(false, "Left Padding K is not supported"); + //params.leftpad_k = static_cast(leftpad_k.data_ptr()); + } + + if (rotary_cos_.has_value()) { + TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided"); + auto rotary_cos = rotary_cos_.value(); + CHECK_DEVICE(rotary_cos); + params.rotary_dim = rotary_cos.size(1) * 2; + TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim"); + TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); + const int seqlen_ro = rotary_cos.size(0); + TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache"); + CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2); + CHECK_CONTIGUOUS(rotary_cos); + TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query"); + + TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); + auto rotary_sin = rotary_sin_.value(); + CHECK_DEVICE(rotary_sin); + CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2); + CHECK_CONTIGUOUS(rotary_sin); + TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query"); + params.rotary_cos_ptr = rotary_cos.data_ptr(); + params.rotary_sin_ptr = rotary_sin.data_ptr(); + params.is_rotary_interleaved = is_rotary_interleaved; + } else { + params.rotary_dim = 0; + } + + if (cache_batch_idx_.has_value()) { + auto cache_batch_idx = cache_batch_idx_.value(); + CHECK_DEVICE(cache_batch_idx); + CHECK_CONTIGUOUS(cache_batch_idx); + TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32"); + params.cache_batch_idx = reinterpret_cast(cache_batch_idx.data_ptr()); + } + + // Keep references to these tensors to extend their lifetime + at::Tensor softmax_lse_accum, out_accum; + std::tie(softmax_lse_accum, out_accum) = set_params_splitkv( + params, batch_size, num_heads, num_heads_k, head_size, max_seqlen_k_hint, seqlen_q, + head_size_rounded, /*dropout*/ 
0.f, num_splits, dprops, use_gqa_packing, is_causal, opts); + + auto tile_count_semaphore = is_causal || params.is_local || params.num_splits != 1 + ? torch::zeros({1}, opts.dtype(torch::kInt32)) + : torch::empty({1}, opts.dtype(torch::kInt32)); + params.tile_count_semaphore = tile_count_semaphore.data_ptr(); + + if (paged_KV) { + params.block_table = block_table.data_ptr(); + params.block_table_batch_stride = block_table.stride(0); + } + params.page_block_size = page_block_size; + + TORCH_CHECK(!alibi_slopes_.has_value(), "Alibi Slopes are not supported yet"); + //set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx, + // or paged KV cache + //run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value() || paged_KV); + run_mha_fwd(params, stream); + + if (head_size_og % 8 != 0) { + out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + if (out_.has_value()) { out_.value().copy_(out); } + if (k_.has_value()) { + // It's expensive to copy the KV cache here for the case where head size not divisible by 8, + // but we don't expect to get this case in practice. This is just so that the code works for that case. + kcache.copy_(kcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)})); + vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)})); + } + } + + if (seqlenq_ngroups_swapped) { + out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og}); + softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); + } + + return {out, softmax_lse}; +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "FlashAttention"; + m.def("fwd", &mha_fwd, "Forward pass"); + m.def("bwd", &mha_bwd, "Backward pass"); + m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)"); + m.def("varlen_bwd", &mha_varlen_bwd, "Varlen backward pass"); + m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache"); +} diff --git a/candle-flash-attn-v3/hkernel/flash_api.cu b/candle-flash-attn-v3/hkernel/flash_api.cu new file mode 100644 index 0000000000..2452140daa --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_api.cu @@ -0,0 +1,315 @@ +#include "flash_fwd_launch_template.h" +#include "flash.h" +#include "static_switch.h" + + +// Helper to read/print small FP16 arrays from device +void read_and_print_fp16(const void* dev_ptr, size_t num_elements, const char* name) { + if (!dev_ptr) { + printf(" %s is null.\n", name); + return; + } + // We copy `num_elements` __half from GPU -> CPU + std::vector<__half> host_data(num_elements); + cudaMemcpy(host_data.data(), dev_ptr, + sizeof(__half) * num_elements, cudaMemcpyDeviceToHost); + + printf(" %s first %zu FP16 elements:\n ", name, num_elements); + for (size_t i = 0; i < num_elements; i++) { + // Convert each __half to float for printing + float val = __half2float(host_data[i]); + printf("%9.6f ", val); + } + printf("\n"); +} + +// Helper to read/print small int32 arrays from device +void read_and_print_int32(const int32_t* dev_ptr, size_t num_elements, const char* name) { + if (!dev_ptr) { + printf(" %s is null.\n", name); + return; + } + std::vector host_data(num_elements); + cudaMemcpy(host_data.data(), dev_ptr, + sizeof(int32_t) * num_elements, cudaMemcpyDeviceToHost); + + printf(" 
%s first %zu int32 values:\n ", name, num_elements); + for (size_t i = 0; i < num_elements; i++) { + printf("%d ", host_data[i]); + } + printf("\n"); +} + +// Prints all fields from Flash_fwd_params, plus optionally reads small data from pointers +void print_params(const Flash_fwd_params &p) { + printf("\n===== Flash_fwd_params Dump =====\n"); + + // Basic geometry + printf(" b = %lu\n", p.b); + printf(" b_k = %lu\n", p.b_k); + printf(" h = %lu\n", p.h); + printf(" h_k = %lu\n", p.h_k); + printf(" d = %lu\n", p.d); + printf(" d_rounded = %lu\n", p.d_rounded); + printf(" h_h_k_ratio = %lu\n", p.h_h_k_ratio); + + // Sequence lengths + printf(" seqlen_q = %lu\n", p.seqlen_q); + printf(" seqlen_k = %lu\n", p.seqlen_k); + printf(" seqlen_q_rounded = %lu\n", p.seqlen_q_rounded); + printf(" seqlen_k_rounded = %lu\n", p.seqlen_k_rounded); + printf(" total_q = %u\n", p.total_q); + printf(" total_k = %u\n", p.total_k); + + // Strides + printf(" q_batch_stride = %lu\n", (unsigned long)p.q_batch_stride); + printf(" q_row_stride = %lu\n", (unsigned long)p.q_row_stride); + printf(" q_head_stride = %lu\n", (unsigned long)p.q_head_stride); + printf(" k_batch_stride = %lu\n", (unsigned long)p.k_batch_stride); + printf(" k_row_stride = %lu\n", (unsigned long)p.k_row_stride); + printf(" k_head_stride = %lu\n", (unsigned long)p.k_head_stride); + printf(" v_batch_stride = %lu\n", (unsigned long)p.v_batch_stride); + printf(" v_row_stride = %lu\n", (unsigned long)p.v_row_stride); + printf(" v_head_stride = %lu\n", (unsigned long)p.v_head_stride); + printf(" o_batch_stride = %lu\n", (unsigned long)p.o_batch_stride); + printf(" o_row_stride = %lu\n", (unsigned long)p.o_row_stride); + printf(" o_head_stride = %lu\n", (unsigned long)p.o_head_stride); + + // Pointer addresses + printf("\n Pointer addresses:\n"); + printf(" q_ptr = %p\n", p.q_ptr); + printf(" k_ptr = %p\n", p.k_ptr); + printf(" v_ptr = %p\n", p.v_ptr); + printf(" o_ptr = %p\n", p.o_ptr); + printf(" p_ptr = %p\n", p.p_ptr); + printf(" softmax_lse_ptr = %p\n", p.softmax_lse_ptr); + printf(" alibi_slopes_ptr= %p\n", p.alibi_slopes_ptr); + printf(" descale_q_ptr = %p\n", p.descale_q_ptr); + printf(" descale_k_ptr = %p\n", p.descale_k_ptr); + printf(" descale_v_ptr = %p\n", p.descale_v_ptr); + + // (varlen / kv-cache) pointer addresses + printf(" cu_seqlens_q = %p\n", p.cu_seqlens_q); + printf(" cu_seqlens_k = %p\n", p.cu_seqlens_k); + printf(" seqused_q = %p\n", p.seqused_q); + printf(" seqused_k = %p\n", p.seqused_k); + printf(" block_table = %p\n", p.block_table); + printf(" tile_count_semaphore = %p\n", p.tile_count_semaphore); + + // Additional KV cache / GQA + printf(" page_block_size = %d\n", p.page_block_size); + printf(" page_num_blocks = %d\n", p.page_num_blocks); + printf(" use_gqa_packing = %d\n", p.use_gqa_packing); + printf(" num_splits = %d\n", p.num_splits); + + // Softmax & dropout scales + printf("\n Softmax / dropout:\n"); + printf(" scale_softmax = %f\n", p.scale_softmax); + printf(" scale_softmax_log2 = %f\n", p.scale_softmax_log2); + printf(" scale_softmax_log2_half2 = 0x%08x (raw bits)\n", p.scale_softmax_log2_half2); + printf(" p_dropout = %f\n", p.p_dropout); + printf(" p_dropout_in_uint8_t = %u\n", p.p_dropout_in_uint8_t); + printf(" rp_dropout = %f\n", p.rp_dropout); + printf(" scale_softmax_rp_dropout = %f\n", p.scale_softmax_rp_dropout); + + // Booleans / flags + printf("\n Flags:\n"); + printf(" is_bf16 = %d\n", p.is_bf16); + printf(" is_e4m3 = %d\n", p.is_e4m3); + printf(" is_causal = %d\n", p.is_causal); + printf(" is_local 
= %d\n", p.is_local); + printf(" is_kv_cache = %d\n", p.is_kv_cache); + printf(" seqlenq_ngroups_swapped = %d\n", p.seqlenq_ngroups_swapped); + printf(" unpadded_lse = %d\n", p.unpadded_lse); + + // Window / block sizes + printf(" window_size_left = %d\n", p.window_size_left); + printf(" window_size_right = %d\n", p.window_size_right); + + printf("===== End of Flash_fwd_params Dump =====\n\n"); + + // Optional: read small data from pointers. + // Adjust the "4" or "2" below for however many elements you want to debug. + + // For example, if q_ptr is not null, try reading 4 elements as FP16 + if (p.q_ptr) { + read_and_print_fp16(p.q_ptr, 4, "q_ptr"); + } + if (p.k_ptr) { + read_and_print_fp16(p.k_ptr, 4, "k_ptr"); + } + if (p.v_ptr) { + read_and_print_fp16(p.v_ptr, 4, "v_ptr"); + } + if (p.o_ptr) { + read_and_print_fp16(p.o_ptr, 4, "o_ptr"); + } + if (p.softmax_lse_ptr) { + read_and_print_fp16(p.softmax_lse_ptr, 4, "softmax_lse_ptr"); + } + + // For cu_seqlens_q and cu_seqlens_k, read 2 int32_t elements, for example + if (p.cu_seqlens_q) { + read_and_print_int32(p.cu_seqlens_q, 2, "cu_seqlens_q"); + } + if (p.cu_seqlens_k) { + read_and_print_int32(p.cu_seqlens_k, 2, "cu_seqlens_k"); + } +} + + +void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + // Select a numeric code for precision: + // 3 = cutlass::float_e4m3_t (fp8) + // 2 = cutlass::bfloat16_t (bf16) + // 1 = cutlass::half_t (fp16) + int prec_type = 1; // default = fp16 + if (params.is_e4m3) { + prec_type = 3; + } else if (params.is_bf16) { + prec_type = 2; + } + // TODO: no GQA switch + PREC_SWITCH(prec_type, elem_type, [&] { + HEADDIM_SWITCH(params.d, kHeadDim, [&] { + // run_mha_fwd_(params, stream); + if(!params.use_gqa_packing) { + run_mha_fwd_(params, stream); + } else { + QUERYHEAD_SWITCH(params.h_h_k_ratio, kBlockH, [&] { + run_mha_fwd_gqa_(params, stream); + }); + } + }); + + }); +} + +extern "C" void run_mha( + void *q_ptr, + void *k_ptr, + void *v_ptr, + void *o_ptr, + void *softmax_lse_ptr, + void *alibi_slopes_ptr, + + int32_t *cu_seqlens_q_ptr, + int32_t *cu_seqlens_k_ptr, + + uint32_t q_batch_stride, + uint32_t k_batch_stride, + uint32_t v_batch_stride, + uint32_t o_batch_stride, + uint32_t alibi_slopes_batch_stride, + + uint32_t q_row_stride, + uint32_t k_row_stride, + uint32_t v_row_stride, + uint32_t o_row_stride, + + uint32_t q_head_stride, + uint32_t k_head_stride, + uint32_t v_head_stride, + uint32_t o_head_stride, + + uint32_t b, + uint32_t h, + uint32_t h_k, + uint32_t d, + uint32_t d_rounded, + float softmax_scale, + + uint32_t seqlen_q, + uint32_t seqlen_k, + uint32_t seqlen_q_rounded, + uint32_t seqlen_k_rounded, + + int is_bf16, + int is_causal, + int unpadded_lse, + int use_gqa_packing, + + int window_size_left, + int window_size_right, + + uint32_t total_q, + uint32_t total_k +) { + Flash_fwd_params params; + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + // Set the pointers and strides. + params.q_ptr = q_ptr; + params.k_ptr = k_ptr; + params.v_ptr = v_ptr; + params.o_ptr = o_ptr; + + params.softmax_lse_ptr = softmax_lse_ptr; + params.alibi_slopes_ptr = alibi_slopes_ptr; + + // All stride are in elements, not bytes. 
+ params.q_batch_stride = q_batch_stride; + params.k_batch_stride = k_batch_stride; + params.v_batch_stride = v_batch_stride; + params.o_batch_stride = o_batch_stride; + params.alibi_slopes_batch_stride = alibi_slopes_batch_stride; + + params.q_row_stride = q_row_stride; + params.k_row_stride = k_row_stride; + params.v_row_stride = v_row_stride; + params.o_row_stride = o_row_stride; + params.q_head_stride = q_head_stride; + params.k_head_stride = k_head_stride; + params.v_head_stride = v_head_stride; + params.o_head_stride = o_head_stride; + + // Set the dimensions. + params.b = b; + params.b_k = b; + params.h = h; + params.h_k = h_k; + params.h_h_k_ratio = h / h_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = d; + params.d_rounded = d_rounded; + + // Set the different scale values. + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + __half scale_softmax_log2_half = __float2half(params.scale_softmax_log2); + __half2 scale_softmax_log2_half2 = __half2(scale_softmax_log2_half, scale_softmax_log2_half); + params.scale_softmax_log2_half2 = reinterpret_cast(scale_softmax_log2_half2); + + params.p_dropout = 1.; // probability to keep + params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); + params.rp_dropout = 1.f / params.p_dropout; + params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; + params.is_bf16 = is_bf16; + params.cu_seqlens_q = cu_seqlens_q_ptr; + params.cu_seqlens_k = cu_seqlens_k_ptr; + params.p_ptr = nullptr; // used for `return_softmax`. + params.seqused_q = nullptr; + params.seqused_k = nullptr; + + params.is_causal = is_causal; + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + + params.num_splits = 0; + params.page_block_size = -1; + + params.total_q = total_q; + params.total_k = total_k; + + params.unpadded_lse = unpadded_lse; + params.use_gqa_packing = use_gqa_packing; + + // print_params(params); + + cudaStream_t stream = 0; // Use the default stream. + run_mha_fwd(params, stream); +} \ No newline at end of file diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_bf16_gqa16_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_bf16_gqa16_sm90.cu new file mode 100644 index 0000000000..d839721b19 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_bf16_gqa16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_bf16_gqa2_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_bf16_gqa2_sm90.cu new file mode 100644 index 0000000000..85d328151b --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_bf16_gqa2_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
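For readers wiring this up from the host side, the sketch below shows one way the element-stride and rounded-length arguments expected by the extern "C" run_mha entry point above could be derived for contiguous (batch, seqlen, heads, head_dim) tensors. The struct and helper names are illustrative only and not part of the patch; the sequence rounding mirrors the 128-element rounding used in mha_fwd_kvcache, and non-contiguous layouts only need the last dimension to stay contiguous.

#include <cstdint>

// Illustrative helper, not part of the patch: element strides and rounded sequence
// lengths for contiguous (batch, seqlen, heads, head_dim) Q/K/V tensors.
struct MhaStrides {
    uint32_t batch_stride, row_stride, head_stride;
};

static uint32_t round_up(uint32_t x, uint32_t m) { return (x + m - 1) / m * m; }

static MhaStrides contiguous_strides(uint32_t seqlen, uint32_t num_heads, uint32_t head_dim) {
    // Strides are in elements, not bytes, matching the convention in run_mha.
    return MhaStrides{ seqlen * num_heads * head_dim, num_heads * head_dim, head_dim };
}

// Example: Q of shape (b=2, seqlen_q=512, h=16, d=128) and K of shape (2, 2048, h_k=4, 128):
//   contiguous_strides(512, 16, 128)  -> q_batch/row/head strides for run_mha
//   contiguous_strides(2048, 4, 128)  -> k_batch/row/head strides
//   round_up(512, 128) == 512 and round_up(2048, 128) == 2048 give the *_rounded lengths.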
+ +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_bf16_gqa32_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_bf16_gqa32_sm90.cu new file mode 100644 index 0000000000..4bf5525c7c --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_bf16_gqa32_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_bf16_gqa4_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_bf16_gqa4_sm90.cu new file mode 100644 index 0000000000..486c762ff5 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_bf16_gqa4_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_bf16_gqa8_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_bf16_gqa8_sm90.cu new file mode 100644 index 0000000000..157081389c --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_bf16_gqa8_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_bf16_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_bf16_sm90.cu new file mode 100644 index 0000000000..11bb9ddecc --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_e4m3_gqa16_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_e4m3_gqa16_sm90.cu new file mode 100644 index 0000000000..45ce0357da --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_e4m3_gqa16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_fp8_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_e4m3_gqa2_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_e4m3_gqa2_sm90.cu new file mode 100644 index 0000000000..1941fe4a20 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_e4m3_gqa2_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_fp8_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_e4m3_gqa32_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_e4m3_gqa32_sm90.cu new file mode 100644 index 0000000000..c3c2d5e2fc --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_e4m3_gqa32_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_fp8_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_e4m3_gqa4_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_e4m3_gqa4_sm90.cu new file mode 100644 index 0000000000..8341090702 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_e4m3_gqa4_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_fp8_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_e4m3_gqa8_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_e4m3_gqa8_sm90.cu new file mode 100644 index 0000000000..98cdac6767 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_e4m3_gqa8_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_fp8_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_e4m3_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_e4m3_sm90.cu new file mode 100644 index 0000000000..04b431f10b --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_e4m3_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_fp8(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_fp16_gqa16_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_fp16_gqa16_sm90.cu new file mode 100644 index 0000000000..988041bf62 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_fp16_gqa16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_fp16_gqa2_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_fp16_gqa2_sm90.cu new file mode 100644 index 0000000000..92936c1d77 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_fp16_gqa2_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_fp16_gqa32_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_fp16_gqa32_sm90.cu new file mode 100644 index 0000000000..1039313497 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_fp16_gqa32_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_fp16_gqa4_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_fp16_gqa4_sm90.cu new file mode 100644 index 0000000000..2d369fcb34 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_fp16_gqa4_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_fp16_gqa8_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_fp16_gqa8_sm90.cu new file mode 100644 index 0000000000..e556921af8 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_fp16_gqa8_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_fp16_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_fp16_sm90.cu new file mode 100644 index 0000000000..176c38eddc --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim128_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_bf16_gqa16_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_bf16_gqa16_sm90.cu new file mode 100644 index 0000000000..2c9c356523 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_bf16_gqa16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_bf16_gqa2_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_bf16_gqa2_sm90.cu new file mode 100644 index 0000000000..5e72b41c4c --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_bf16_gqa2_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_bf16_gqa32_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_bf16_gqa32_sm90.cu new file mode 100644 index 0000000000..90ae2162a7 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_bf16_gqa32_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_bf16_gqa4_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_bf16_gqa4_sm90.cu new file mode 100644 index 0000000000..b7c6345b26 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_bf16_gqa4_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_bf16_gqa8_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_bf16_gqa8_sm90.cu new file mode 100644 index 0000000000..566760319d --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_bf16_gqa8_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_bf16_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_bf16_sm90.cu new file mode 100644 index 0000000000..06d0df617b --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_e4m3_gqa16_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_e4m3_gqa16_sm90.cu new file mode 100644 index 0000000000..9c0f7d626b --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_e4m3_gqa16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_fp8_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_e4m3_gqa2_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_e4m3_gqa2_sm90.cu new file mode 100644 index 0000000000..c41ac3d4e9 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_e4m3_gqa2_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_fp8_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_e4m3_gqa32_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_e4m3_gqa32_sm90.cu new file mode 100644 index 0000000000..b486e1a393 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_e4m3_gqa32_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_fp8_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_e4m3_gqa4_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_e4m3_gqa4_sm90.cu new file mode 100644 index 0000000000..2b97017868 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_e4m3_gqa4_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_fp8_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_e4m3_gqa8_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_e4m3_gqa8_sm90.cu new file mode 100644 index 0000000000..ebe0f92cae --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_e4m3_gqa8_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_fp8_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_e4m3_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_e4m3_sm90.cu new file mode 100644 index 0000000000..78884313ec --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_e4m3_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_fp8(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_fp16_gqa16_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_fp16_gqa16_sm90.cu new file mode 100644 index 0000000000..91fc6200e0 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_fp16_gqa16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_fp16_gqa2_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_fp16_gqa2_sm90.cu new file mode 100644 index 0000000000..21a81044ae --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_fp16_gqa2_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_fp16_gqa32_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_fp16_gqa32_sm90.cu new file mode 100644 index 0000000000..502a66281f --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_fp16_gqa32_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_fp16_gqa4_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_fp16_gqa4_sm90.cu new file mode 100644 index 0000000000..e6dc49dc67 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_fp16_gqa4_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_fp16_gqa8_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_fp16_gqa8_sm90.cu new file mode 100644 index 0000000000..046c9e304c --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_fp16_gqa8_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_fp16_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_fp16_sm90.cu new file mode 100644 index 0000000000..0cc26c7910 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim256_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_bf16_gqa16_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_bf16_gqa16_sm90.cu new file mode 100644 index 0000000000..0381c601ee --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_bf16_gqa16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_bf16_gqa2_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_bf16_gqa2_sm90.cu new file mode 100644 index 0000000000..6be1d9c588 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_bf16_gqa2_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_bf16_gqa32_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_bf16_gqa32_sm90.cu new file mode 100644 index 0000000000..154efcac54 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_bf16_gqa32_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_bf16_gqa4_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_bf16_gqa4_sm90.cu new file mode 100644 index 0000000000..b8fe56a321 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_bf16_gqa4_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_bf16_gqa8_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_bf16_gqa8_sm90.cu new file mode 100644 index 0000000000..cda356c268 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_bf16_gqa8_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_bf16_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_bf16_sm90.cu new file mode 100644 index 0000000000..d3839898f2 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_e4m3_gqa16_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_e4m3_gqa16_sm90.cu new file mode 100644 index 0000000000..74e61967a4 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_e4m3_gqa16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_fp8_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_e4m3_gqa2_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_e4m3_gqa2_sm90.cu new file mode 100644 index 0000000000..ff8213c055 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_e4m3_gqa2_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_fp8_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_e4m3_gqa32_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_e4m3_gqa32_sm90.cu new file mode 100644 index 0000000000..22ce8ed06d --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_e4m3_gqa32_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_fp8_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_e4m3_gqa4_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_e4m3_gqa4_sm90.cu new file mode 100644 index 0000000000..b0f09e7808 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_e4m3_gqa4_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_fp8_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_e4m3_gqa8_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_e4m3_gqa8_sm90.cu new file mode 100644 index 0000000000..16775723d0 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_e4m3_gqa8_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_fp8_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_e4m3_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_e4m3_sm90.cu new file mode 100644 index 0000000000..471a5037a1 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_e4m3_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_fp8(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_fp16_gqa16_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_fp16_gqa16_sm90.cu new file mode 100644 index 0000000000..cbe5159d17 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_fp16_gqa16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_fp16_gqa2_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_fp16_gqa2_sm90.cu new file mode 100644 index 0000000000..f18c68b231 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_fp16_gqa2_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_fp16_gqa32_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_fp16_gqa32_sm90.cu new file mode 100644 index 0000000000..a4cf2813de --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_fp16_gqa32_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_fp16_gqa4_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_fp16_gqa4_sm90.cu new file mode 100644 index 0000000000..8e9932dbd1 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_fp16_gqa4_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_fp16_gqa8_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_fp16_gqa8_sm90.cu new file mode 100644 index 0000000000..79cbce7d01 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_fp16_gqa8_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_gqa_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64_gqa(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_fp16_sm90.cu b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_fp16_sm90.cu new file mode 100644 index 0000000000..c6eac53520 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_hdim64_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_kernel.h b/candle-flash-attn-v3/hkernel/flash_fwd_kernel.h new file mode 100644 index 0000000000..4c5a109ad5 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_kernel.h @@ -0,0 +1,420 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include +#include +#include +#include +#include +#include "cutlass/pipeline/pipeline.hpp" + +#include "flash.h" +#include "utils.h" +#include "softmax.h" +#include "tile_scheduler.hpp" +#include "mainloop_fwd_sm90_tma_gmma_ws.hpp" +#include "epilogue_fwd_sm90_tma.hpp" + +namespace flash { + +using namespace cute; + +template +__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) + compute_attn_ws(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd::Params const mainloop_params, + CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd::Params const epilogue_params, + CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params, + Seqlen_traits_Q seqlen_traits_q, Seqlen_traits seqlen_traits_k + ) { + + using Element = typename Ktraits::Element; + using TileShape_MNK = typename Ktraits::TileShape_MNK; + using ClusterShape = typename Ktraits::ClusterShape_MNK; + + static_assert(Ktraits::Is_WS); + static constexpr bool Is_WS = Ktraits::Is_WS; + static constexpr bool No_smem_O = Ktraits::No_smem_O; + + static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{}); + static constexpr int NumCopyThreads = !Is_WS ? 
0 : cutlass::NumThreadsPerWarpGroup; + static constexpr int kBlockM = Ktraits::kBlockM; + static constexpr int kBlockH = Ktraits::kBlockH; + // static constexpr int kBlockN = Ktraits::kBlockN; + // static constexpr int kHeadDim = Ktraits::kHeadDim; + + using CollectiveMainloop = CollectiveMainloopFwd; + using CollectiveEpilogue = CollectiveEpilogueFwd; + + using MainloopPipeline = typename Ktraits::MainloopPipeline; + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename MainloopPipeline::PipelineState; + + extern __shared__ char shared_memory[]; + auto &shared_storage = *reinterpret_cast(shared_memory); + + int const lane_predicate = cute::elect_one_sync(); + int const warp_idx = cutlass::canonical_warp_idx_sync(); + + // Issue Tma Descriptor Prefetch from a single thread + if (warp_idx == 0 && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(mainloop_params); + CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params); + } + + // Obtain warp index + int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; + + PipelineParams pipeline_params; + pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK; + int warp_group_idx = cutlass::canonical_warp_group_idx(); + pipeline_params.role = warp_group_idx == 0 + ? MainloopPipeline::ThreadCategory::Producer + : MainloopPipeline::ThreadCategory::Consumer; + pipeline_params.is_leader = warp_group_thread_idx == 0; + pipeline_params.num_consumers = NumMmaThreads; + + if (warp_idx == 0 && lane_predicate) { + shared_storage.barrier_Q.init(1 /*numThreads*/); + if constexpr (!No_smem_O) { shared_storage.barrier_O.init(size(ClusterShape{}) /*numThreads*/); } + } + // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init(); + MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{}); + MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{}); + + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue; + + // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + } else { + __syncthreads(); + } + + // static_assert(Ktraits::kNWarps == 12 || Ktraits::kNWarps == 16); + static_assert(Ktraits::kNWarps == 8 || Ktraits::kNWarps == 12 || Ktraits::kNWarps == 16); + if (warp_group_idx == 0) { // Producer + cutlass::arch::warpgroup_reg_dealloc(); + + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + if (warp_idx_in_warpgroup == 0) { // Load Q, K, V + PipelineState smem_pipe_write_k = cutlass::make_producer_start_state(); + PipelineState smem_pipe_write_v = cutlass::make_producer_start_state(); + + int work_idx = 0; + + TileScheduler scheduler(&shared_storage.tile_count_semaphore); + for (auto work_tile_info = scheduler.get_initial_work(); + work_tile_info.is_valid(scheduler_params); + work_tile_info = scheduler.template get_next_work(scheduler_params, work_tile_info)) { + auto block_coord = work_tile_info.get_block_coord(scheduler_params); + auto [m_block, n_split_idx, bidh, bidb] = block_coord; + + seqlen_traits_q.init(bidb); + seqlen_traits_k.init(bidb); + if constexpr(seqlen_traits_q.UseVarSeqLen) { + // NOTE: to support in future with gqa packed layouts, changed kBlockM to kBlockM/kBlockH + if (m_block * (kBlockM/kBlockH) >= seqlen_traits_q.actual_seq_len) { + continue; + 
} + } + int n_block_min = 0, n_block_max; + collective_mainloop.get_n_block_min_max( + mainloop_params, m_block, n_split_idx, seqlen_traits_q, seqlen_traits_k, + n_block_min, n_block_max); + if constexpr (Is_causal || Is_local || seqlen_traits_k.UseVarSeqLen || Ktraits::Is_split) { + if(n_block_max <= n_block_min) { + scheduler.prefetch_next_work(scheduler_params, work_tile_info); + scheduler.broadcast_next_work(work_tile_info); + continue; + } + } + collective_mainloop.load( + mainloop_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v, + shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx, + seqlen_traits_q, seqlen_traits_k, n_block_min, n_block_max); + ++work_idx; + } + collective_mainloop.load_tail(pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v); + } + } else { // Consumer + cutlass::arch::warpgroup_reg_alloc(); + + TileScheduler scheduler(&shared_storage.tile_count_semaphore); + // Initialize matmul objects. + typename Ktraits::TiledMma1 tiled_mma1; + + PipelineState smem_pipe_read_k, smem_pipe_read_v; + // We don't need separate variables smem_pipe_release_k and smem_pipe_release_v + // (like in Cutlass's gemm) because the read and release pipeline states are always the same. + + collective_mainloop.mma_init(); + scheduler.init_consumer(); + + int work_idx = 0; + CUTLASS_PRAGMA_NO_UNROLL + for (auto work_tile_info = scheduler.get_initial_work(); + work_tile_info.is_valid(scheduler_params); + work_tile_info = scheduler.template get_next_work(scheduler_params, work_tile_info)) { + // Attention output (GEMM-II) accumulator. + Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{})); + flash::Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax(mainloop_params.softmax_scale_log2); + + auto block_coord = work_tile_info.get_block_coord(scheduler_params); + auto [m_block, n_split_idx, bidh, bidb] = block_coord; + + seqlen_traits_q.init(bidb); + seqlen_traits_k.init(bidb); + if constexpr(seqlen_traits_q.UseVarSeqLen) { + // NOTE: to support in future with gqa packed layouts, changed kBlockM to kBlockM/kBlockH + if (m_block * (kBlockM/kBlockH) >= seqlen_traits_q.actual_seq_len) { + continue; + } + } + int n_block_max, n_block_min = 0; + collective_mainloop.get_n_block_min_max( + mainloop_params, m_block, n_split_idx, seqlen_traits_q, seqlen_traits_k, + n_block_min, n_block_max); + if constexpr (Is_causal || Is_local || seqlen_traits_k.UseVarSeqLen || Ktraits::Is_split) { + if(n_block_max <= n_block_min) { // We exit early and write 0 to gO and -inf to gLSE. 
+ if constexpr(!Seqlen_traits_Q::UseGQAPacking) { + collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, + block_coord, seqlen_traits_q); + } else { + collective_epilogue.store_zero_gqa(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, + block_coord, seqlen_traits_q, mainloop_params.qhead_per_khead_divmod); + } + continue; + } + } + + collective_mainloop.mma( + mainloop_params, pipeline_k, pipeline_v, smem_pipe_read_k, smem_pipe_read_v, + tOrO, softmax, n_block_min, n_block_max, threadIdx.x - NumCopyThreads, work_idx, + m_block, shared_storage, seqlen_traits_q, seqlen_traits_k); + // tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads + (work_idx >> 30), work_idx, shared_storage); + collective_epilogue.store( + epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1, + threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q, mainloop_params.qhead_per_khead_divmod); + + ++work_idx; + } + collective_epilogue.store_tail(); + } + +} + +template +__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) + compute_attn_ws_fp8(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd::Params const mainloop_params, + CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd::Params const epilogue_params, + CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params, + Seqlen_traits_Q seqlen_traits_q, Seqlen_traits seqlen_traits_k + ) { + + using Element = typename Ktraits::Element; + static_assert(cutlass::sizeof_bits_v == 8); + using TileShape_MNK = typename Ktraits::TileShape_MNK; + using ClusterShape = typename Ktraits::ClusterShape_MNK; + + static_assert(Ktraits::Is_WS); + static constexpr bool Is_WS = Ktraits::Is_WS; + static constexpr bool No_smem_O = Ktraits::No_smem_O; + + static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{}); + static constexpr int NumCopyThreads = !Is_WS ? 
0 : cutlass::NumThreadsPerWarpGroup; + static constexpr int kBlockM = Ktraits::kBlockM; + static constexpr int kBlockH = Ktraits::kBlockH; + // static constexpr int kBlockN = Ktraits::kBlockN; + // static constexpr int kHeadDim = Ktraits::kHeadDim; + static constexpr bool Delay_V_release = Is_causal && Ktraits::kHeadDim == 128 && Ktraits::kNWarps != 8; + static constexpr bool Use_max_offset = true; + + using CollectiveMainloop = CollectiveMainloopFwd; + using CollectiveEpilogue = CollectiveEpilogueFwd; + + using MainloopPipeline = typename Ktraits::MainloopPipeline; + using MainloopPipelineVt = typename Ktraits::MainloopPipelineNoTMA; + using PipelineParams = typename MainloopPipeline::Params; + using PipelineParamsVt = typename MainloopPipelineVt::Params; + using PipelineState = typename MainloopPipeline::PipelineState; + + extern __shared__ char shared_memory[]; + auto &shared_storage = *reinterpret_cast(shared_memory); + + int const lane_predicate = cute::elect_one_sync(); + int const warp_idx = cutlass::canonical_warp_idx_sync(); + + // Issue Tma Descriptor Prefetch from a single thread + if (warp_idx == 0 && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(mainloop_params); + CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params); + } + + // Obtain warp index + int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; + + // additional pipeline to synchronize out-of-place smem transpose of V + PipelineParamsVt pipeline_params_vt; + pipeline_params_vt.producer_arv_count = NumCopyThreads; + pipeline_params_vt.consumer_arv_count = NumMmaThreads; + MainloopPipelineVt pipeline_vt(shared_storage.pipeline_vt, pipeline_params_vt); + + PipelineParams pipeline_params; + pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK; + int warp_group_idx = cutlass::canonical_warp_group_idx(); + pipeline_params.role = warp_group_idx == 0 + ? 
MainloopPipeline::ThreadCategory::Producer + : MainloopPipeline::ThreadCategory::Consumer; + pipeline_params.is_leader = warp_group_thread_idx == 0; + pipeline_params.num_consumers = NumMmaThreads; + + if (warp_idx == 0 && lane_predicate) { + shared_storage.barrier_Q.init(1 /*numThreads*/); + if constexpr (!No_smem_O) { shared_storage.barrier_O.init(size(ClusterShape{}) /*numThreads*/); } + } + // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init(); + MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{}); + // pipeline_v has producer warpgroup for its consumer in fp8 kernel + pipeline_params.num_consumers = NumCopyThreads; + pipeline_params.role = MainloopPipeline::ThreadCategory::ProducerConsumer; + MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{}); + + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue; + + float descale_q = *mainloop_params.descale_q_ptr; + float descale_k = *mainloop_params.descale_k_ptr; + float descale_v = *mainloop_params.descale_v_ptr; + shared_storage.softmax_scale_qk_log2 = mainloop_params.softmax_scale_log2 * descale_q * descale_k; + shared_storage.descale_v = descale_v; + shared_storage.seqlen_init_k = seqlen_traits_k.UseVarSeqLen || bool(seqlen_traits_k.seq_used); + + // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + } else { + __syncthreads(); + } + + static_assert(Ktraits::kNWarps == 8 || Ktraits::kNWarps == 12 || Ktraits::kNWarps == 16); + if (warp_group_idx == 0) { // Producer + cutlass::arch::warpgroup_reg_dealloc(); + + PipelineState smem_pipe_write = cutlass::make_producer_start_state(); + PipelineState smem_pipe_read, smem_pipe_release; + + int work_idx = 0; + + TileScheduler scheduler(&shared_storage.tile_count_semaphore); + for (auto work_tile_info = scheduler.get_initial_work(); + work_tile_info.is_valid(scheduler_params); + work_tile_info = scheduler.template get_next_work(scheduler_params, work_tile_info)) { + auto block_coord = work_tile_info.get_block_coord(scheduler_params); + auto [m_block, n_split_idx, bidh, bidb] = block_coord; + + if constexpr (seqlen_traits_q.UseVarSeqLen) { seqlen_traits_q.init(bidb); } + if (shared_storage.seqlen_init_k) { seqlen_traits_k.init_no_guard(bidb); } + if constexpr(seqlen_traits_q.UseVarSeqLen) { + // NOTE: to support in future with gqa packed layout, changed kBlockM to kBlockM/kBlockH + if (m_block * (kBlockM/kBlockH) >= seqlen_traits_q.actual_seq_len) { + continue; + } + } + int n_block_min = 0, n_block_max; + collective_mainloop.get_n_block_min_max( + mainloop_params, m_block, n_split_idx, seqlen_traits_q, seqlen_traits_k, + n_block_min, n_block_max); + if constexpr (Is_causal || Is_local ||seqlen_traits_k.UseVarSeqLen || Ktraits::Is_split) { + if(n_block_max <= n_block_min) { + scheduler.prefetch_next_work(scheduler_params, work_tile_info); + scheduler.broadcast_next_work(work_tile_info); + // need to sync producer warpgroup + cutlass::arch::NamedBarrier::sync(NumCopyThreads, static_cast(FwdNamedBarriers::ProducerWG) /*id*/); + continue; + } + } + collective_mainloop.load_fp8( + mainloop_params, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, smem_pipe_read, + shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx, + seqlen_traits_q, seqlen_traits_k, n_block_min, n_block_max); + 
++work_idx; + // don't need to sync producer warpgroup here + // if constexpr (Is_causal) { + // cutlass::arch::NamedBarrier::sync(NumCopyThreads, static_cast(FwdNamedBarriers::ProducerWG) /*id*/); } + } + collective_mainloop.load_tail_one_write(pipeline_k, pipeline_v, smem_pipe_write); + } else { // Consumer + cutlass::arch::warpgroup_reg_alloc(); + + TileScheduler scheduler(&shared_storage.tile_count_semaphore); + // Initialize matmul objects. + typename Ktraits::TiledMma1 tiled_mma1; + PipelineState smem_pipe_read; + PipelineState smem_pipe_release; + + collective_mainloop.mma_init(); + scheduler.init_consumer(); + + int work_idx = 0; + + CUTLASS_PRAGMA_NO_UNROLL + for (auto work_tile_info = scheduler.get_initial_work(); + work_tile_info.is_valid(scheduler_params); + work_tile_info = scheduler.template get_next_work(scheduler_params, work_tile_info)) { + // Attention output (GEMM-II) accumulator. + Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{})); + flash::Softmax<2 * (2 * kBlockM / NumMmaThreads), Use_max_offset> softmax(shared_storage.softmax_scale_qk_log2); + + auto block_coord = work_tile_info.get_block_coord(scheduler_params); + auto [m_block, n_split_idx, bidh, bidb] = block_coord; + + if constexpr (seqlen_traits_q.UseVarSeqLen) { seqlen_traits_q.init(bidb); } + if (shared_storage.seqlen_init_k) { seqlen_traits_k.init_no_guard(bidb); } + if constexpr(seqlen_traits_q.UseVarSeqLen) { + // NOTE: to support in future with gqa packed layout, changed kBlockM to kBlockM/kBlockH + if (m_block * (kBlockM/kBlockH) >= seqlen_traits_q.actual_seq_len) { + continue; + } + } + int n_block_max, n_block_min = 0; + collective_mainloop.get_n_block_min_max( + mainloop_params, m_block, n_split_idx, seqlen_traits_q, seqlen_traits_k, + n_block_min, n_block_max); + if constexpr (Is_causal || Is_local || seqlen_traits_k.UseVarSeqLen || Ktraits::Is_split) { + if(n_block_max <= n_block_min) { // We exit early and write 0 to gO and -inf to gLSE. + if constexpr(!Seqlen_traits_Q::UseGQAPacking) { + collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, + block_coord, seqlen_traits_q); + } else { + collective_epilogue.store_zero_gqa(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, + block_coord, seqlen_traits_q, mainloop_params.qhead_per_khead_divmod); + } + continue; + } + } + + collective_mainloop.mma_fp8( + mainloop_params, pipeline_k, pipeline_vt, smem_pipe_read, smem_pipe_release, + tOrO, softmax, n_block_min, n_block_max, threadIdx.x - NumCopyThreads, work_idx, m_block, + shared_storage, seqlen_traits_q, seqlen_traits_k); + + collective_epilogue.store( + epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1, + threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q, mainloop_params.qhead_per_khead_divmod); + + ++work_idx; + } + collective_epilogue.store_tail(); + } + +} + +} // namespace flash diff --git a/candle-flash-attn-v3/hkernel/flash_fwd_launch_template.h b/candle-flash-attn-v3/hkernel/flash_fwd_launch_template.h new file mode 100644 index 0000000000..b91c74a2df --- /dev/null +++ b/candle-flash-attn-v3/hkernel/flash_fwd_launch_template.h @@ -0,0 +1,561 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/cluster_launch.hpp" + +#include "static_switch.h" +#include "flash.h" +#include "tile_scheduler.hpp" +#include "flash_fwd_kernel.h" +#include "kernel_traits.h" +#include "seq_len.h" +#include "utils.h" +#include "combine.h" + +template +void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time."); + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using OutputType = typename Kernel_traits::OutputType; + using TileShape_MNK = typename Kernel_traits::TileShape_MNK; + using ClusterShape = typename Kernel_traits::ClusterShape_MNK; + + constexpr static bool Is_split = Kernel_traits::Is_split; + static_assert(Seqlen_traits_Q::UseGQAPacking == (Kernel_traits::kBlockH > 1), "If kBlockH > 1, use gqa packed layouts"); + static_assert(!(Is_split && Seqlen_traits::UseVarSeqLen), "Split KV not yet supported for variable seqlen."); + + using CollectiveMainloop = flash::CollectiveMainloopFwd; + using CollectiveEpilogue = flash::CollectiveEpilogueFwd; + using Scheduler = std::conditional_t< + Seqlen_traits::UseVarSeqLen, + flash::SingleTileScheduler, + std::conditional_t, + flash::DynamicPersistentTileScheduler< + Kernel_traits::kNThreads - cutlass::NumThreadsPerWarpGroup, + Kernel_traits::NumProducerThreads, + Is_split + > + >>; + // using Scheduler = flash::SingleTileScheduler; + Seqlen_traits_Q seqlen_traits_q( + params.total_q, params.seqlen_q, params.cu_seqlens_q, params.seqused_q); + Seqlen_traits seqlen_traits_k( + params.total_k, params.seqlen_k, params.cu_seqlens_k, params.seqused_k); + + typename CollectiveMainloop::Params mainloop_params = + CollectiveMainloop::to_underlying_arguments({ + static_cast(params.q_ptr), + seqlen_traits_q.get_gmem_layout( + params.seqlen_q, params.d, params.h_k, params.b, params.h_h_k_ratio, + params.q_row_stride, params.q_head_stride, params.q_batch_stride + ), // layout_Q + static_cast(params.k_ptr), + seqlen_traits_k.get_gmem_layout( + params.seqlen_k, params.d, params.h_k, params.b_k, + params.k_row_stride, params.k_head_stride, params.k_batch_stride, + params.page_block_size, params.page_num_blocks + ), // layout_K + static_cast(params.v_ptr), + seqlen_traits_k.get_gmem_layout( + params.seqlen_k, params.d, params.h_k, params.b_k, + params.v_row_stride, params.v_head_stride, params.v_batch_stride, + params.page_block_size, params.page_num_blocks + ), // layout_V + seqlen_traits_k.get_virtual_shape(params.seqlen_k, params.d, params.h_k, params.b, params.h_h_k_ratio, false), + params.scale_softmax_log2, + params.descale_q_ptr, + params.descale_k_ptr, + params.descale_v_ptr, + params.window_size_left, + params.window_size_right, + ceil_div(params.h_h_k_ratio, Kernel_traits::kBlockH), + params.cache_batch_idx, + Is_split ? params.num_splits : 1, + params.block_table, + params.block_table_batch_stride, + params.page_block_size, + (params.page_block_size > 0) ? 
params.b*params.seqlen_k/params.page_block_size : 0 + }); + typename CollectiveEpilogue::Params epilogue_params = [&] { + if constexpr(!Is_split) { + return CollectiveEpilogue::to_underlying_arguments({ + static_cast(params.o_ptr), + seqlen_traits_q.get_gmem_layout( + params.seqlen_q, params.d, params.h_k, params.b, params.h_h_k_ratio, + params.o_row_stride, params.o_head_stride, params.o_batch_stride + ), // layout_O + static_cast(params.softmax_lse_ptr), + seqlen_traits_q.get_lse_gmem_layout( + params.seqlen_q, params.h, params.b + ) // layout_LSE + }); + } else { + return CollectiveEpilogue::to_underlying_arguments({ + static_cast(params.oaccum_ptr), + seqlen_traits_q.get_oaccum_gmem_layout( + params.seqlen_q, params.d, params.h_k, params.b, params.h_h_k_ratio, params.num_splits, + params.oaccum_row_stride, params.oaccum_head_stride, params.oaccum_batch_stride, + params.oaccum_split_stride + ), // layout_O + static_cast(params.softmax_lseaccum_ptr), + seqlen_traits_q.get_lseaccum_gmem_layout( + params.seqlen_q, params.h, params.b, params.num_splits + ), // layout_LSE + }); + } + }(); + + int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM/Kernel_traits::kBlockH); + num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{}); + int num_blocks_h = params.h_k * ceil_div(params.h_h_k_ratio, Kernel_traits::kBlockH); + typename Scheduler::Arguments scheduler_args = + {num_blocks_m, Is_split ? params.num_splits : 1, num_blocks_h, params.b, params.tile_count_semaphore}; + typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args); + + // Get the ptr to kernel function. + void *kernel; + if constexpr(cutlass::sizeof_bits_v == 8) + kernel = (void *)flash::compute_attn_ws_fp8; + else + kernel = (void *)flash::compute_attn_ws; + if (params.block_table != nullptr) { + if ((params.page_block_size % Kernel_traits::kBlockN) != 0) { + fprintf(stderr, "Sequence length in N (%d) dimension must divide page block size (%d) if block table is used\n", (int) Kernel_traits::kBlockN, (int) params.page_block_size); + exit(1); + } + } + int smem_size = sizeof(typename Kernel_traits::SharedStorage); + // int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q)); + // int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k)); + // int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v)); + // int smem_size_o = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_o)); + // printf("smem_size = %d, q = %d, k = %d, v = %d, o = %d.\n", smem_size, smem_size_q, smem_size_k, smem_size_v, smem_size_o); + if (smem_size >= 48 * 1024) { + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + + int device; + cudaGetDevice(&device); + int multiprocessor_count; + CHECK_CUDA(cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device)); + dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count); + static constexpr int ctaSize = Kernel_traits::kNWarps * 32; + dim3 block_dims(ctaSize); + if constexpr(size(ClusterShape{}) > 1) { + dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{})); + cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream}; + cutlass::launch_kernel_on_cluster( + launch_params, kernel, mainloop_params, epilogue_params, + scheduler_params, 
seqlen_traits_q, seqlen_traits_k); + } else { + if constexpr(cutlass::sizeof_bits_v == 8) { + flash::compute_attn_ws_fp8 + <<>> + (mainloop_params, epilogue_params, scheduler_params, seqlen_traits_q, seqlen_traits_k); + } else { + flash::compute_attn_ws + <<>> + (mainloop_params, epilogue_params, scheduler_params, seqlen_traits_q, seqlen_traits_k); + } + + } + CHECK_CUDA_KERNEL_LAUNCH(); + + if constexpr (Is_split) { + using FinalOutputType = typename Kernel_traits::FinalOutputType; + static_assert(is_same_v, "Assume OutputType of main kernel is float."); + static_assert(is_same_v, "ElementAccum must be float."); + // We want kBlockM to be as small as possible for more parallelism. + // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4. + // If headdim is divisible by 64, then we set kBlockM = 8, etc. + constexpr static int kHeadDim = Kernel_traits::kHeadDim; + constexpr static int kBlockM = kHeadDim % 128 == 0 ? 4 : (kHeadDim % 64 == 0 ? 8 : 16); + constexpr static bool Is_even_K = true; // always true for our current setting + void *kernel_combine; + int smem_size_combine; + NUM_SPLITS_SWITCH(params.num_splits, kLogMaxSplits, [&] { + constexpr static int kMaxSplits = 1 << kLogMaxSplits; + kernel_combine = (void *) flash::combine_attn_seqk_parallel< + FinalOutputType, ElementAccum, kHeadDim, kBlockM, kLogMaxSplits, Is_even_K, Flash_fwd_params>; + smem_size_combine = sizeof( + flash::SharedStorageLSE, Int>, Shape>>); + }); + if (smem_size_combine >= 48 * 1024) { + CHECK_CUDA(cudaFuncSetAttribute(kernel_combine, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_combine)); + } + dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); + dim3 block_dims_combine(128); + dim3 cluster_dims_combine(1, 1, 1); + cutlass::ClusterLaunchParams launch_params_combine{ + grid_combine, block_dims_combine, cluster_dims_combine, smem_size_combine, stream}; + cutlass::launch_kernel_on_cluster(launch_params_combine, kernel_combine, params); + CHECK_CUDA_KERNEL_LAUNCH(); + } +} + +template +void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 64; + constexpr static bool UseCluster = false; + + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(params.is_local, Is_local, [&] { + MMA_3WG_SWITCH(params.seqlen_q, kNumMmaWGs, [&] { + SEQLEN_SWITCH(params, Seqlen_traits, Seqlen_traits_Q, [&] { + BOOL_SWITCH(params.num_splits > 1, Is_split, [&] { + // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 192) % 2 == 0 && !Is_causal && !Is_local && !Is_split + // && kNumMmaWGs == 3 && !Seqlen_traits::UseVarSeqLen, UseCluster, [&] { + run_flash_fwd< + Flash_fwd_kernel_traits, + Is_causal, + Is_local && !Is_causal, + Seqlen_traits, + Seqlen_traits_Q + >(params, stream); + // }); + }); + }); + }); + }); + }); +} + +template +void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 128; + BOOL_SWITCH(params.block_table!=nullptr, UseBlockTable, [&] { + MMA_2WG_SWITCH(params.seqlen_q, kNumMmaWGs, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(params.is_local, Is_local, [&] { + SEQLEN_SWITCH(params, Seqlen_traits, Seqlen_traits_Q, [&] { + BOOL_SWITCH(params.num_splits > 1, Is_split, [&] { + // Only use Cluster if number of tiles along seqlen_q is even + // and not Is_causal, Is_split, or varseqlen + BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Is_split + && kNumMmaWGs == 2 && 
!Seqlen_traits::UseVarSeqLen, UseCluster, [&] { + run_flash_fwd< + Flash_fwd_kernel_traits, + Is_causal, + Is_local && !Is_causal, + Seqlen_traits, + Seqlen_traits_Q + >(params, stream); + }); + + }); + }); + }); + }); + }); + }); +} + + + +template +void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 256; + BOOL_SWITCH(params.block_table!=nullptr, UseBlockTable, [&] { + MMA_2WG_SWITCH(params.seqlen_q, kNumMmaWGs, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(params.is_local, Is_local, [&] { + SEQLEN_SWITCH(params, Seqlen_traits, Seqlen_traits_Q, [&] { + BOOL_SWITCH(params.num_splits > 1, Is_split, [&] { + // Only use Cluster if number of tiles along seqlen_q is even + // and not Is_causal, Is_split, or varseqlen + BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Is_split + && kNumMmaWGs == 2 && !Seqlen_traits::UseVarSeqLen, UseCluster, [&] { + run_flash_fwd< + Flash_fwd_kernel_traits, + Is_causal, + Is_local && !Is_causal, + Seqlen_traits, + Seqlen_traits_Q + >(params, stream); + }); + }); + }); + }); + }); + }); + }); +} + +// template +// void run_mha_fwd_hdim64_fp8(Flash_fwd_params ¶ms, cudaStream_t stream) { +// constexpr static int Headdim = 64; +// constexpr static int kBlockN = 128; +// constexpr static int kStages = 4; +// // constexpr static bool UseCluster = false; +// // constexpr static int kBlockM = 192; +// // constexpr static int kNWarps = 4 + kBlockM/16; +// using Seqlen_traits = flash::FixedSeqLenTraits; + +// MMA_3WG_SWITCH(params.seqlen_q, kNumMmaWGs, [&] { +// BOOL_SWITCH(params.is_causal, Is_causal, [&] { +// BOOL_SWITCH(params.is_local, Is_local, [&] { +// BOOL_SWITCH(params.num_splits > 1, Is_split, [&] { +// BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 192) % 2 == 0 && !Is_causal && !Is_local && !Is_split +// && kNumMmaWGs == 3, UseCluster, [&] { +// run_flash_fwd< +// Flash_fwd_kernel_traits_fp8, +// Is_causal, +// Is_local && !Is_causal, +// Seqlen_traits +// >(params, stream); +// }); +// }); +// }); +// }); +// }); +// } + +// template +// void run_mha_fwd_hdim128_fp8(Flash_fwd_params ¶ms, cudaStream_t stream) { +// constexpr static int Headdim = 128; +// constexpr static int kBlockN = 256; +// constexpr static int kStages = 2; +// // constexpr static int kBlockM = 128; +// // constexpr static int kNWarps = 4 + kBlockM/16; +// using Seqlen_traits = flash::FixedSeqLenTraits; + +// MMA_2WG_SWITCH(params.seqlen_q, kNumMmaWGs, [&] { +// BOOL_SWITCH(params.is_causal, Is_causal, [&] { +// BOOL_SWITCH(params.is_local, Is_local, [&] { +// BOOL_SWITCH(params.num_splits > 1, Is_split, [&] { +// BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Is_split +// && kNumMmaWGs == 2, UseCluster, [&] { +// run_flash_fwd< +// Flash_fwd_kernel_traits_fp8, +// Is_causal, +// Is_local && !Is_causal, +// Seqlen_traits +// >(params, stream); +// }); +// }); +// }); +// }); +// }); +// } + +// template +// void run_mha_fwd_hdim256_fp8(Flash_fwd_params ¶ms, cudaStream_t stream) { +// constexpr static int Headdim = 256; +// constexpr static int kBlockN = 128; +// constexpr static int kStages = 2; +// // constexpr static int kBlockM = 128; +// // constexpr static int kNWarps = 4 + kBlockM/16; +// using Seqlen_traits = flash::FixedSeqLenTraits; + +// MMA_2WG_SWITCH(params.seqlen_q, kNumMmaWGs, [&] { +// BOOL_SWITCH(params.is_causal, Is_causal, [&] { +// BOOL_SWITCH(params.is_local, Is_local, [&] { +// BOOL_SWITCH(params.num_splits > 
1, Is_split, [&] { +// BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Is_split +// && kNumMmaWGs == 2, UseCluster, [&] { +// run_flash_fwd< +// Flash_fwd_kernel_traits_fp8, +// Is_causal, +// Is_local && !Is_causal, +// Seqlen_traits +// >(params, stream); +// }); +// }); +// }); +// }); +// }); +// } + +/* +** GQA methods +*/ + +template +void run_mha_fwd_hdim64_gqa(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 64; + constexpr static bool UseCluster = false; + using Seqlen_traits = flash::FixedSeqLenTraits; + using Seqlen_traits_Q = flash::FixedGQASeqLenTraits; + + MMA_3WG_SWITCH(kBlockH * params.seqlen_q, kNumMmaWGs, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(params.is_local, Is_local, [&] { + BOOL_SWITCH(params.num_splits > 1, Is_split, [&] { + // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 192/kBlockH) % 2 == 0 && !Is_causal && !Is_local && !Is_split + // && kNumMmaWGs == 3, UseCluster, [&] { + run_flash_fwd< + Flash_fwd_kernel_traits, + Is_causal, + Is_local && !Is_causal, + Seqlen_traits, + Seqlen_traits_Q + >(params, stream); + // }); + }); + }); + }); + }); +} + +template +void run_mha_fwd_hdim128_gqa(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 128; + constexpr static bool UseCluster = false; + using Seqlen_traits = flash::FixedSeqLenTraits; + using Seqlen_traits_Q = flash::FixedGQASeqLenTraits; + + MMA_2WG_SWITCH(kBlockH * params.seqlen_q, kNumMmaWGs, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(params.is_local, Is_local, [&] { + BOOL_SWITCH(params.num_splits > 1, Is_split, [&] { + // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128/kBlockH) % 2 == 0 && !Is_causal && !Is_local && !Is_split + // && kNumMmaWGs == 2, UseCluster, [&] { + run_flash_fwd< + Flash_fwd_kernel_traits, + Is_causal, + Is_local && !Is_causal, + Seqlen_traits, + Seqlen_traits_Q + >(params, stream); + // }); + }); + }); + }); + }); +} + +template +void run_mha_fwd_hdim256_gqa(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 256; + constexpr static bool UseCluster = false; + using Seqlen_traits = flash::FixedSeqLenTraits; + using Seqlen_traits_Q = flash::FixedGQASeqLenTraits; + + MMA_2WG_SWITCH(kBlockH * params.seqlen_q, kNumMmaWGs, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(params.is_local, Is_local, [&] { + BOOL_SWITCH(params.num_splits > 1, Is_split, [&] { + // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128/kBlockH) % 2 == 0 && !Is_causal && !Is_local && !Is_split + // && kNumMmaWGs == 2, UseCluster, [&] { + run_flash_fwd< + Flash_fwd_kernel_traits, + Is_causal, + Is_local && !Is_causal, + Seqlen_traits, + Seqlen_traits_Q + >(params, stream); + // }); + }); + }); + }); + }); +} + +// template +// void run_mha_fwd_hdim64_fp8_gqa(Flash_fwd_params ¶ms, cudaStream_t stream) { +// constexpr static int Headdim = 64; +// constexpr static int kBlockN = 128; +// constexpr static int kStages = 4; +// constexpr static bool UseCluster = false; +// using Seqlen_traits = flash::FixedSeqLenTraits; +// using Seqlen_traits_Q = flash::FixedGQASeqLenTraits; + +// MMA_3WG_SWITCH(kBlockH * params.seqlen_q, kNumMmaWGs, [&] { +// BOOL_SWITCH(params.is_causal, Is_causal, [&] { +// BOOL_SWITCH(params.is_local, Is_local, [&] { +// BOOL_SWITCH(params.num_splits > 1, Is_split, [&] { +// // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 192/kBlockH) % 2 == 0 && !Is_causal && !Is_local && !Is_split +// // && kNumMmaWGs 
== 3, UseCluster, [&] { +// run_flash_fwd< +// Flash_fwd_kernel_traits_fp8, +// Is_causal, +// Is_local && !Is_causal, +// Seqlen_traits, +// Seqlen_traits_Q +// >(params, stream); +// // }); +// }); +// }); +// }); +// }); +// } + +// template +// void run_mha_fwd_hdim128_fp8_gqa(Flash_fwd_params ¶ms, cudaStream_t stream) { +// constexpr static int Headdim = 128; +// constexpr static int kBlockN = 256; +// constexpr static int kStages = 2; +// constexpr static bool UseCluster = false; +// using Seqlen_traits = flash::FixedSeqLenTraits; +// using Seqlen_traits_Q = flash::FixedGQASeqLenTraits; + +// MMA_2WG_SWITCH(kBlockH * params.seqlen_q, kNumMmaWGs, [&] { +// BOOL_SWITCH(params.is_causal, Is_causal, [&] { +// BOOL_SWITCH(params.is_local, Is_local, [&] { +// BOOL_SWITCH(params.num_splits > 1, Is_split, [&] { +// // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128/kBlockH) % 2 == 0 && !Is_causal && !Is_local && !Is_split +// // && kNumMmaWGs == 2, UseCluster, [&] { +// run_flash_fwd< +// Flash_fwd_kernel_traits_fp8, +// Is_causal, +// Is_local && !Is_causal, +// Seqlen_traits, +// Seqlen_traits_Q +// >(params, stream); +// // }); +// }); +// }); +// }); +// }); +// } + +// template +// void run_mha_fwd_hdim256_fp8_gqa(Flash_fwd_params ¶ms, cudaStream_t stream) { +// constexpr static int Headdim = 256; +// constexpr static int kBlockN = 128; +// constexpr static int kStages = 2; +// constexpr static bool UseCluster = false; +// using Seqlen_traits = flash::FixedSeqLenTraits; +// using Seqlen_traits_Q = flash::FixedGQASeqLenTraits; + +// MMA_2WG_SWITCH(kBlockH * params.seqlen_q, kNumMmaWGs, [&] { +// BOOL_SWITCH(params.is_causal, Is_causal, [&] { +// BOOL_SWITCH(params.is_local, Is_local, [&] { +// BOOL_SWITCH(params.num_splits > 1, Is_split, [&] { +// // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128/kBlockH) % 2 == 0 && !Is_causal && !Is_local && !Is_split +// // && kNumMmaWGs == 2, UseCluster, [&] { +// run_flash_fwd< +// Flash_fwd_kernel_traits_fp8, +// Is_causal, +// Is_local && !Is_causal, +// Seqlen_traits, +// Seqlen_traits_Q +// >(params, stream); +// // }); +// }); +// }); +// }); +// }); +// } diff --git a/candle-flash-attn-v3/hkernel/kernel_traits.h b/candle-flash-attn-v3/hkernel/kernel_traits.h new file mode 100644 index 0000000000..b7ef43f5de --- /dev/null +++ b/candle-flash-attn-v3/hkernel/kernel_traits.h @@ -0,0 +1,1085 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include "cute/algorithm/copy.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" + +using namespace cute; + +template +struct SharedStorageQKVO { + cute::array_aligned> smem_q; + cute::array_aligned> smem_k; + union { + cute::array_aligned> smem_v; + cute::array_aligned> smem_o; + }; + struct { + cutlass::arch::ClusterTransactionBarrier barrier_Q; + cutlass::arch::ClusterBarrier barrier_O; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_k; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_v; + int tile_count_semaphore; + }; +}; + +// Use if Oaccum is too large for SharedStorageQKVO +template +struct SharedStorageQKVOaccum { + cute::array_aligned> smem_q; + union { + struct { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + cute::array_aligned> smem_o; + }; + struct { + cutlass::arch::ClusterTransactionBarrier barrier_Q; + cutlass::arch::ClusterBarrier barrier_O; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_k; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_v; + int tile_count_semaphore; + }; +}; + +// SharedStorage struct with no smem for O +template +struct SharedStorageQKV { + cute::array_aligned> smem_q; + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + struct { + cutlass::arch::ClusterTransactionBarrier barrier_Q; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_k; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_v; + int tile_count_semaphore; + }; +}; + +template +struct SharedStorageQKVOVt { + struct { + cute::array_aligned> smem_q; + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + union { + cute::array_aligned> smem_v_out; + cute::array_aligned> smem_o; + }; + }; + struct { + cutlass::arch::ClusterTransactionBarrier barrier_Q; + cutlass::arch::ClusterBarrier barrier_O; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_k; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_v; + typename cutlass::PipelineAsync::SharedStorage pipeline_vt; + int tile_count_semaphore; + float softmax_scale_qk_log2; + float descale_v; + bool seqlen_init_k; + }; +}; + +// Use if Oaccum is too large for SharedStorageQKVOVt +template +struct SharedStorageQKVOVtaccum { + struct { + cute::array_aligned> smem_q; + cute::array_aligned> smem_k; + union { + struct { + cute::array_aligned> smem_v; + cute::array_aligned> smem_v_out; + }; + cute::array_aligned> smem_o; + }; + }; + struct { + cutlass::arch::ClusterTransactionBarrier barrier_Q; + cutlass::arch::ClusterBarrier barrier_O; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_k; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_v; + typename cutlass::PipelineAsync::SharedStorage pipeline_vt; + int tile_count_semaphore; + float softmax_scale_qk_log2; + float descale_v; + bool seqlen_init_k; + }; +}; + +template +struct SharedStorageQKVVt { + struct { + cute::array_aligned> smem_q; + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + cute::array_aligned> smem_v_out; + }; + struct { + cutlass::arch::ClusterTransactionBarrier barrier_Q; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_k; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_v; + typename 
cutlass::PipelineAsync::SharedStorage pipeline_vt; + int tile_count_semaphore; + float softmax_scale_qk_log2; + float descale_v; + bool seqlen_init_k; + }; +}; + +// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true +template +struct Flash_fwd_kernel_traits { + using Element = elem_type; + using ElementAccum = float; + using FinalOutputType = elem_type; + using OutputType = std::conditional_t; + using index_t = int64_t; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp; + static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarp; + + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_; + static_assert(kNWarps_ == 8 || kNWarps_ == 12 || kNWarps_ == 16); + static constexpr bool Is_WS = true; + static_assert(!(Is_WS && Is_Q_in_regs), "Warp-specialization does not support Q in registers"); + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kBlockH = kBlockH_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static_assert(kBlockM % kBlockH == 0); + using TileShape_MNK = Shape, Int, Int>; + + static constexpr int kClusterM = kClusterM_; + using ClusterShape_MNK = Shape, _1, _1>; + + static constexpr int kStages = kStages_; + + static constexpr bool Is_split = Is_split_; + static constexpr bool No_smem_O = Is_split; + + using AtomLayoutMNK = Layout, _1, _1>>; + using TiledMma0 = decltype(cute::make_tiled_mma( + std::conditional_t< + Is_Q_in_regs, + decltype(cute::GMMA::rs_op_selector()), + decltype(cute::GMMA::ss_op_selector()) + >{}, + AtomLayoutMNK{})); + using TiledMma1 = decltype(cute::make_tiled_mma( + cute::GMMA::rs_op_selector(TileShape_MNK{})), + GMMA::Major::K, GMMA::Major::MN>(), + AtomLayoutMNK{})); + + using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{}))); + + // for gmem -> smem Q copy + using FactoringLayoutQ = Layout, Int, Int>, + Stride, _1, Int>>; + using TileShapeQCopy = std::conditional_t<(kBlockH > 1), + decltype(shape(FactoringLayoutQ{})), decltype(select<0, 2>(TileShape_MNK{}))>; + using SmemLayoutQCopy = std::conditional_t<(kBlockH > 1), + decltype(composition(SmemLayoutQ{}, FactoringLayoutQ{})), SmemLayoutQ>; + + using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutK = + decltype(tile_to_shape(SmemLayoutAtomK{}, + make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + + using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutV = + decltype(tile_to_shape(SmemLayoutAtomV{}, + make_shape(get<1>(TileShape_MNK{}), get<2>(TileShape_MNK{}), Int{}))); + + // Note this is the transpose in terms of the view, not in terms of memory. 
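+    // The composition below swaps the N and K modes of SmemLayoutV, so the second GEMM can read the same
+    // shared-memory bytes as V^T without any extra copy or data movement.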
+ using SmemLayoutVt = + decltype(composition(SmemLayoutV{}, + make_ordered_layout( + make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int{}), + Step<_2, _1, _3>{}))); + + using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{}))); + // for smem -> gmem O copy + using TileShapeOCopy = TileShapeQCopy; + using SmemLayoutOCopy = std::conditional_t<(kBlockH > 1), + decltype(composition(SmemLayoutO{}, FactoringLayoutQ{})), SmemLayoutO>; + + using SmemCopyAtomQ = Copy_Atom; + + using SharedStorage = std::conditional_t, + SharedStorageQKV>; + + using MainloopPipeline = typename cutlass::PipelineTmaAsync; + using MainloopPipelineNoTMA = typename cutlass::PipelineAsync; + using PipelineState = typename cutlass::PipelineState; + // using BarrierType = typename MainloopPipeline::ProducerBarrierType; + +}; + +// Traits struct for fp8 kernel with in-kernel transpose +// template +// struct Flash_fwd_kernel_traits_fp8 { +// using Element = elem_type; +// static_assert(cutlass::sizeof_bits_v == 8); +// using ElementAccum = float; +// using FinalOutputType = cutlass::bfloat16_t; +// using OutputType = std::conditional_t; +// using index_t = int64_t; + +// static constexpr bool Is_split = Is_split_; +// static constexpr bool No_smem_O = false; +// // NOTE: not using smem for epilogue degrades perf substantially. +// // static constexpr bool No_smem_O = Is_split; + +// // The number of threads. +// static constexpr int kNWarps = kNWarps_; +// static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp; +// static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup; + +// static constexpr bool Is_Q_in_regs = Is_Q_in_regs_; +// static_assert(kNWarps_ == 8 || kNWarps_ == 12 || kNWarps_ == 16); +// static constexpr bool Is_WS = true; +// static_assert(!Is_Q_in_regs, "Warp-specialization does not support Q in registers"); + +// static constexpr int kBlockM = kBlockM_; +// static constexpr int kBlockN = kBlockN_; +// static constexpr int kBlockH = kBlockH_; +// static constexpr int kHeadDim = kHeadDim_; +// static_assert(kHeadDim % 32 == 0); +// static_assert(kBlockM % kBlockH == 0); +// using TileShape_MNK = Shape, Int, Int>; + +// static constexpr int kClusterM = kClusterM_; +// using ClusterShape_MNK = Shape, _1, _1>; + +// static constexpr int kStages = kStages_; +// static_assert(kStages > 1); + +// // Use this to save enough smem when writing out in float precision. 
+// static constexpr bool VO_union_all = Is_split && (kBlockM != 64) && (kHeadDim == 256); + +// using AtomLayoutMNK = Layout, _1, _1>>; +// using TiledMma0 = decltype(cute::make_tiled_mma( +// cute::GMMA::ss_op_selector(), +// AtomLayoutMNK{})); + +// using TiledMma1 = decltype(cute::make_tiled_mma( +// cute::GMMA::rs_op_selector(TileShape_MNK{}))>(), +// AtomLayoutMNK{})); + +// using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); +// using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{}))); + +// // for gmem -> smem Q copy +// using FactoringLayoutQ = Layout, Int, Int>, +// Stride, _1, Int>>; +// using TileShapeQCopy = std::conditional_t<(kBlockH > 1), +// decltype(shape(FactoringLayoutQ{})), decltype(select<0, 2>(TileShape_MNK{}))>; +// using SmemLayoutQCopy = std::conditional_t<(kBlockH > 1), +// decltype(composition(SmemLayoutQ{}, FactoringLayoutQ{})), SmemLayoutQ>; + +// using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); +// using SmemLayoutK = +// decltype(tile_to_shape(SmemLayoutAtomK{}, +// make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + +// using TransposeShapeAtomV = Shape<_64, _64>; +// using SmemLayoutAtomV = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom{}, TransposeShapeAtomV{})); +// using SmemLayoutV = +// decltype(tile_to_shape(SmemLayoutAtomV{}, +// make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + +// // for fp8 in-kernel transpose -- src layout +// using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{})); +// using SmemShapeLDSM = Shape, Shape<_16, _4>>; +// using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{}, +// shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{}), shape<3>(SmemLayoutDivideV{}))); +// using SmemLayoutTransposeV = decltype(composition(SmemLayoutDivideV{}, make_layout(FactoringShapeV{}))); + +// // For fp8, this is the memory transpose. 
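+// (Unlike the 16-bit path above, where V^T is only a re-indexed view of smem_v, the fp8 path materializes the
+// transposed tile into a separate buffer (smem_v_out), with producer/consumer hand-off through pipeline_vt.)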
+// using SmemLayoutAtomVt = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom{}, TransposeShapeAtomV{})); +// using SmemLayoutVt = +// decltype(tile_to_shape(SmemLayoutAtomVt{}, +// make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int{}))); + +// // for fp8 in-kernel transpose -- dst layout +// using SmemLayoutVtTrans = +// decltype(composition(SmemLayoutVt{}, +// make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1, _3>{}))); +// using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{})); +// #ifndef NO_FP8_COLUMN_PERMUTE +// using SmemShapeSTSM = Shape, Shape<_8, _8>>; +// #else +// using SmemShapeSTSM = Shape, Shape<_16, _4>>; +// #endif +// using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{}, +// shape<1>(SmemLayoutDivideVt{}), shape<2>(SmemLayoutDivideVt{}), shape<3>(SmemLayoutDivideVt{}))); +// using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{}))); + +// using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); +// using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{}))); +// // for smem -> gmem O copy +// using TileShapeOCopy = TileShapeQCopy; +// using SmemLayoutOCopy = std::conditional_t<(kBlockH > 1), +// decltype(composition(SmemLayoutO{}, FactoringLayoutQ{})), SmemLayoutO>; + +// // used for rmem -> smem O copy in fp8 kernel to undo column permutation +// using ThreadLayoutrO = Layout, _4, _1>, +// Stride<_4, _32, _1, _0>>; +// using ValueLayoutrO = Layout, Int>, +// Stride<_0, _2, Stride<_4, _1>, _8>>; +// using TiledCopyrO = decltype(make_tiled_copy(Copy_Atom, OutputType>{}, +// ThreadLayoutrO{}, ValueLayoutrO{})); + +// using TiledCopyShaperO = Shape<_8, Int, _16, Int>; +// using SmemLayoutrO = decltype(composition(SmemLayoutO{}, Layout{})); + +// using SmemCopyAtomQ = Copy_Atom; + +// using SharedStorage = std::conditional_t, +// SharedStorageQKVOVtaccum>, +// SharedStorageQKVVt>; + +// using MainloopPipeline = typename cutlass::PipelineTmaAsync; +// using MainloopPipelineNoTMA = typename cutlass::PipelineAsync; +// using PipelineState = typename cutlass::PipelineState; +// // using BarrierType = typename MainloopPipeline::ProducerBarrierType; +// }; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SharedStorageQKVdOdKV; + +template +struct SharedStorageQKVdOdKV { + struct { + cute::array_aligned> smem_q; + cute::array_aligned> smem_do; + union { + struct { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + struct { + cute::array_aligned> smem_dk; + cute::array_aligned> smem_dv; + }; + }; + cute::array_aligned> smem_p; + cute::array_aligned> smem_ds; + }; + struct { + cute::uint64_t tma_load_mbar[8]; // 8 TMA barrier pre-allcoated for usage. 
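+        // barrier_K / barrier_V synchronize the TMA loads of K and V, while Q and dO are pipelined
+        // through pipeline_q / pipeline_do.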
+ cutlass::arch::ClusterTransactionBarrier barrier_K; + cutlass::arch::ClusterTransactionBarrier barrier_V; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_q; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_do; + }; +}; + +template +struct SharedStorageQKVdOdKV { + struct { + cute::array_aligned> smem_q; + cute::array_aligned> smem_do; + union { + struct { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + struct { + cute::array_aligned> smem_dk; + cute::array_aligned> smem_dv; + }; + }; + union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used. + cute::array_aligned> smem_p; + cute::array_aligned> smem_ds; + }; + }; + struct { + cute::uint64_t tma_load_mbar[8]; // 8 TMA barrier pre-allcoated for usage. + cutlass::arch::ClusterTransactionBarrier barrier_K; + cutlass::arch::ClusterTransactionBarrier barrier_V; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_q; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_do; + }; +}; + +template +struct SharedStorageQKVdOdKVWS; + +template +struct SharedStorageQKVdOdKVWS { + struct { + cute::array_aligned> smem_q; + cute::array_aligned> smem_do; + union { + struct { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + struct { + cute::array_aligned> smem_dk; + cute::array_aligned> smem_dv; + }; + }; + cute::array_aligned> smem_p; + cute::array_aligned> smem_ds; + cute::array_aligned> smem_dqacc; + cute::array_aligned smem_lse; + cute::array_aligned smem_dpsum; + }; + struct { + cute::uint64_t tma_load_mbar[8]; // 8 TMA barrier pre-allcoated for usage. + cutlass::arch::ClusterTransactionBarrier barrier_K; + cutlass::arch::ClusterTransactionBarrier barrier_V; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_q; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_do; + }; +}; + +template +struct SharedStorageQKVdOdKVWS { + struct { + cute::array_aligned> smem_q; + cute::array_aligned> smem_do; + union { + struct { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + struct { + cute::array_aligned> smem_dk; + cute::array_aligned> smem_dv; + }; + }; + union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used. + cute::array_aligned> smem_p; + cute::array_aligned> smem_ds; + }; + cute::array_aligned> smem_dqacc; + cute::array_aligned smem_lse; + cute::array_aligned smem_dpsum; + }; + struct { + cute::uint64_t tma_load_mbar[8]; // 8 TMA barrier pre-allcoated for usage. + cutlass::arch::ClusterTransactionBarrier barrier_K; + cutlass::arch::ClusterTransactionBarrier barrier_V; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_q; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_do; + }; +}; + +template +struct SharedStorageQKVdOdKVSeqqPar; + +template +struct SharedStorageQKVdOdKVSeqqPar { + struct { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + union { + struct { + cute::array_aligned> smem_q; + cute::array_aligned> smem_do; + }; + struct { + cute::array_aligned> smem_dq; + }; + }; + cute::array_aligned> smem_p; + cute::array_aligned> smem_ds; + }; + struct { + cute::uint64_t tma_load_mbar[8]; // 8 TMA barrier pre-allcoated for usage. 
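+        // In the seqlen-q-parallel variant the roles are reversed: barrier_Q / barrier_dO synchronize the
+        // TMA loads of Q and dO, while K and V are pipelined through pipeline_k / pipeline_v.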
+ cutlass::arch::ClusterTransactionBarrier barrier_Q; + cutlass::arch::ClusterTransactionBarrier barrier_dO; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_k; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_v; + }; +}; + +template +struct SharedStorageQKVdOdKVSeqqPar { + struct { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + union { + struct { + cute::array_aligned> smem_q; + cute::array_aligned> smem_do; + }; + struct { + cute::array_aligned> smem_dq; + }; + }; + union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used. + cute::array_aligned> smem_p; + cute::array_aligned> smem_ds; + }; + }; + struct { + cute::uint64_t tma_load_mbar[8]; // 8 TMA barrier pre-allcoated for usage. + cutlass::arch::ClusterTransactionBarrier barrier_Q; + cutlass::arch::ClusterTransactionBarrier barrier_dO; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_k; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_v; + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// template +// struct Flash_bwd_kernel_traits { +// using Element = elem_type; +// using ElementAccum = float; +// using index_t = int64_t; + +// // The number of threads. +// static constexpr int kNWarps = kNWarps_; +// static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp; +// static constexpr int kNThreadsNonWS = 8 * cutlass::NumThreadsPerWarp; +// // static constexpr int kNThreadsdQ = cutlass::NumThreadsPerWarpGroup; +// static constexpr int kNThreadsdQ = 2 * cutlass::NumThreadsPerWarpGroup; + +// static_assert(kNWarps_ == 8 || kNWarps_ == 12); + +// static constexpr bool Is_WS = kNWarps_ >= 12; + +// static constexpr int kBlockM = kBlockM_; +// static constexpr int kBlockN = kBlockN_; +// static constexpr int kHeadDim = kHeadDim_; +// static_assert(kHeadDim % 32 == 0); +// using TileShape_MNK = Shape, Int, Int>; + +// static constexpr int kClusterN = kClusterN_; +// using ClusterShape_MNK = Shape<_1, Int, _1>; + +// static constexpr int kStages = 2; + +// static constexpr bool SdP_swapAB = SdP_swapAB_; +// static constexpr bool dKV_swapAB = dKV_swapAB_; +// static constexpr bool dQ_swapAB = dQ_swapAB_; +// static_assert(!(SdP_swapAB && dKV_swapAB)); // If SdP_swapAB, then we don't swap for dKV + +// static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == 2 && AtomLayoutMdQ == 2 && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS + +// using TileShapeAtomSdP = std::conditional_t< +// !SdP_swapAB, +// Shape, Int, Int>, +// Shape, Int, Int> +// >; +// using AtomLayoutSdP = std::conditional_t< +// !SdP_swapAB, +// Layout, Int<2 / AtomLayoutMSdP>, _1>>, +// Layout, Int, _1>> +// >; +// using TiledMmaSdP = decltype(cute::make_tiled_mma( +// cute::GMMA::ss_op_selector(), +// AtomLayoutSdP{})); + +// using TileShapeAtomdKV = std::conditional_t< +// !dKV_swapAB, +// Shape, Int, Int>, +// Shape, Int, Int> +// >; +// using AtomLayoutdKV = std::conditional_t< +// !dKV_swapAB, +// Layout, Int<2 / AtomLayoutNdKV>, _1>>, +// Layout, Int, _1>> +// >; +// using TiledMmadKV = decltype(cute::make_tiled_mma( +// std::conditional_t< +// !SdP_swapAB, +// decltype(cute::GMMA::ss_op_selector()), +// decltype(cute::GMMA::rs_op_selector()) +// >{}, +// AtomLayoutdKV{})); + +// using TileShapeAtomdQ = std::conditional_t< +// !dQ_swapAB, +// Shape, Int, Int>, +// Shape, Int, Int> +// // Shape, Int, Int>, +// // Shape, Int, Int> +// >; +// using AtomLayoutdQ = 
std::conditional_t< +// !dQ_swapAB, +// Layout, Int<2 / AtomLayoutMdQ>, _1>>, +// Layout, Int, _1>> +// // Layout, Int<1>, _1>>, +// // Layout, Int<1>, _1>> +// >; +// static constexpr GMMA::Major MmadQMajorA = !dQ_swapAB ? GMMA::Major::K : GMMA::Major::MN; +// static constexpr GMMA::Major MmadQMajorB = !dQ_swapAB ? GMMA::Major::MN : GMMA::Major::K; +// using TiledMmadQ = decltype(cute::make_tiled_mma( +// std::conditional_t< +// !dQ_swapAB, +// std::conditional_t< +// Mma_dQ_is_RS, +// decltype(cute::GMMA::rs_op_selector()), +// decltype(cute::GMMA::ss_op_selector()) +// >, +// decltype(cute::GMMA::ss_op_selector()) +// >{}, +// AtomLayoutdQ{})); + +// using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); +// using GmemTiledCopyKV = cute::SM90_TMA_LOAD; +// using GmemTiledCopydKV = cute::SM90_TMA_STORE; + +// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +// static constexpr bool Has_cp_async = true; +// #else +// static constexpr bool Has_cp_async = false; +// #endif +// // For the dot_do_o preprocessing kernel +// using Gmem_copy_struct = std::conditional_t< +// Has_cp_async, +// SM80_CP_ASYNC_CACHEGLOBAL, +// DefaultCopy +// >; +// static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; +// static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); +// static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); +// // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem +// // to affect speed in practice. +// static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; +// static_assert(kNThreadsNonWS % kGmemThreadsPerRow == 0, "kNThreadsNonWS must be a multiple of kGmemThreadsPerRow"); +// using GmemLayoutAtom = Layout, Int>, +// Stride, _1>>; +// using GmemLayoutAtomdQ = Layout, Int>, +// Stride, _1>>; +// using GmemTiledCopydO = decltype( +// make_tiled_copy(Copy_Atom{}, +// GmemLayoutAtom{}, +// Layout>{})); // Val layout, 8 vals per store +// using GmemTiledCopydQ = decltype( +// make_tiled_copy(Copy_Atom{}, +// GmemLayoutAtomdQ{}, +// Layout>{})); // Val layout, 8 vals per store +// using GmemLayoutAtomdQaccum = std::conditional_t< +// kBlockKSmem == 32, +// Layout, _8>, // Thread layout, 8 threads per row +// Stride< _8, _1>>, +// Layout, _16>, // Thread layout, 16 threads per row +// Stride< _16, _1>> +// >; +// using GmemTiledCopydQaccum = decltype( +// make_tiled_copy(Copy_Atom{}, +// GmemLayoutAtomdQaccum{}, +// Layout>{})); // Val layout, 4 vals per store + +// using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); +// using SmemLayoutQ = +// decltype(tile_to_shape(SmemLayoutAtomQ{}, +// make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); +// using SmemLayoutdO = SmemLayoutQ; + +// using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); +// using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{}))); + +// using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); +// using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, select<1, 2>(TileShape_MNK{}))); + +// using SmemLayoutAtomP = 
decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); +// using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{}))); +// using SmemLayoutAtomdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); +// using SmemLayoutdS = decltype(tile_to_shape(SmemLayoutAtomdS{}, select<0, 1>(TileShape_MNK{}))); + +// // using SmemLayoutAtomdQacc = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); +// // using SmemLayoutdQacc = decltype(tile_to_shape(SmemLayoutAtomdQacc{}, select<0, 2>(TileShape_MNK{}))); + +// // Note this is the transpose in terms of the view, not in terms of memory. +// using SmemLayoutQt = +// decltype(cute::composition(SmemLayoutQ{}, +// make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int{}), +// make_stride(Int{}, _1{}, Int{})))); +// using SmemLayoutdOt = +// decltype(cute::composition(SmemLayoutdO{}, +// make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int{}), +// make_stride(Int{}, _1{}, Int{})))); +// using SmemLayoutKt = +// decltype(cute::composition(SmemLayoutK{}, +// make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})), +// make_stride(Int{}, _1{})))); +// using SmemLayoutPt = +// decltype(cute::composition(SmemLayoutP{}, +// make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})), +// make_stride(Int{}, _1{})))); +// using SmemLayoutdSt = +// decltype(cute::composition(SmemLayoutdS{}, +// make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})), +// make_stride(Int{}, _1{})))); + +// // using SmemLayoutdQacct = +// // decltype(cute::composition(SmemLayoutdQacc{}, +// // make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})), +// // make_stride(Int{}, _1{})))); + +// using SmemLayoutdK = SmemLayoutK; +// using SmemLayoutdV = SmemLayoutV; +// using SmemLayoutdKt = SmemLayoutKt; +// using SmemLayoutdVt = SmemLayoutKt; + +// static constexpr int kSwizzle = kBlockKSmem == 32 ? 
2 : 3; +// using SmemLayoutAtomdQ = decltype( +// // composition(Swizzle{}, +// composition(Swizzle<3, 3, 3>{}, +// Layout, Int<32>>, +// Stride, _1>>{})); +// using SmemLayoutdQ = decltype(tile_to_shape( +// SmemLayoutAtomdQ{}, +// make_shape(Int{}, Int{}))); +// using SmemLayoutdQt = +// decltype(cute::composition(SmemLayoutdQ{}, +// make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})), +// make_stride(Int{}, _1{})))); +// static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element); + +// using SmemLayoutAtomdQaccTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); +// using SmemLayoutdQaccTMA = decltype(tile_to_shape(SmemLayoutAtomdQaccTMA{}, select<0, 2>(TileShape_MNK{}))); +// using SmemLayoutdQacc = SmemLayoutdQ; +// using SmemLayoutdQacct = SmemLayoutdQt; +// using SmemLayoutdQacc2 = decltype(tile_to_shape( +// SmemLayoutAtomdQ{}, +// make_shape(Int{}, Int{}, _2{}))); +// // using SmemLayoutdQacc = decltype(tile_to_shape(SmemLayoutAtomdQacc{}, select<0, 2>(TileShape_MNK{}))); +// // using SmemLayoutdQacct = +// // decltype(cute::composition(SmemLayoutdQacc{}, +// // make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})), +// // make_stride(Int{}, _1{})))); +// using RmemTiledCopydQacc = decltype( +// make_tiled_copy(Copy_Atom{}, +// GmemLayoutAtomdQaccum{}, +// Layout>{})); // Val layout, 4 vals per store + +// // using SmemCopyAtomQ = Copy_Atom; +// using SmemCopyAtomPdS = Copy_Atom< +// std::conditional_t, +// Element>; +// using SmemCopyAtomdKV = Copy_Atom< +// std::conditional_t, +// Element>; +// using SmemCopyAtomdQ = Copy_Atom< +// std::conditional_t, +// Element>; + +// using SharedStorage = std::conditional_t< +// !Is_WS, +// SharedStorageQKVdOdKV, +// SharedStorageQKVdOdKVWS +// // SmemLayoutK, SmemLayoutV, SmemLayoutdS, SmemLayoutdQacc2, SmemLayoutdK, SmemLayoutdV> +// >; + +// // using MainloopPipeline = typename cutlass::PipelineTmaAsync; +// // using PipelineState = typename cutlass::PipelineState; +// using MainloopPipeline = typename cutlass::PipelineTmaAsync; + +// }; + +// //////////////////////////////////////////////////////////////////////////////////////////////////// + +// template +// struct Flash_bwd_seqqpar_kernel_traits { +// using Element = elem_type; +// using ElementAccum = float; +// using index_t = int64_t; + +// // The number of threads. 
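For orientation, a minimal sketch of the thread arithmetic these (commented-out) seq-q-parallel traits encode, assuming the 8-warp configuration asserted just below; the names are hypothetical and purely illustrative:

// Illustrative arithmetic only; mirrors the constants declared next.
constexpr int kWarpSizeEx      = 32;   // cutlass::NumThreadsPerWarp
constexpr int kWarpGroupSizeEx = 128;  // cutlass::NumThreadsPerWarpGroup
constexpr int kNWarpsEx        = 8;    // matches the static_assert(kNWarps_ == 8) below
constexpr int kNThreadsEx      = kNWarpsEx * kWarpSizeEx;        // 256 threads
constexpr int kNumWarpGroupsEx = kNThreadsEx / kWarpGroupSizeEx; // 2 warpgroups
static_assert(kNThreadsEx == 256 && kNumWarpGroupsEx == 2, "8 warps -> two warpgroups");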
+// static constexpr int kNWarps = kNWarps_; +// static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp; + +// static_assert(kNWarps_ == 8); + +// static constexpr int kBlockM = kBlockM_; +// static constexpr int kBlockN = kBlockN_; +// static constexpr int kHeadDim = kHeadDim_; +// static_assert(kHeadDim % 32 == 0); +// using TileShape_MNK = Shape, Int, Int>; + +// static constexpr int kClusterN = kClusterN_; +// using ClusterShape_MNK = Shape<_1, Int, _1>; + +// static constexpr int kStages = 2; + +// static constexpr bool SdP_swapAB = SdP_swapAB_; +// static constexpr bool dKV_swapAB = dKV_swapAB_; +// static constexpr bool dQ_swapAB = dQ_swapAB_; +// static_assert(!(SdP_swapAB && dKV_swapAB)); // If SdP_swapAB, then we don't swap for dKV + +// static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == 2 && AtomLayoutMdQ == 2 && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS + +// using TileShapeAtomSdP = std::conditional_t< +// !SdP_swapAB, +// Shape, Int, Int>, +// Shape, Int, Int> +// >; +// using AtomLayoutSdP = std::conditional_t< +// !SdP_swapAB, +// Layout, Int<2 / AtomLayoutMSdP>, _1>>, +// Layout, Int, _1>> +// >; +// using TiledMmaSdP = decltype(cute::make_tiled_mma( +// cute::GMMA::ss_op_selector(), +// AtomLayoutSdP{})); + +// using TileShapeAtomdKV = std::conditional_t< +// !dKV_swapAB, +// Shape, Int, Int>, +// Shape, Int, Int> +// >; +// using AtomLayoutdKV = std::conditional_t< +// !dKV_swapAB, +// Layout, Int<2 / AtomLayoutNdKV>, _1>>, +// Layout, Int, _1>> +// >; +// using TiledMmadKV = decltype(cute::make_tiled_mma( +// std::conditional_t< +// !SdP_swapAB, +// decltype(cute::GMMA::ss_op_selector()), +// decltype(cute::GMMA::rs_op_selector()) +// >{}, +// AtomLayoutdKV{})); + +// using TileShapeAtomdQ = std::conditional_t< +// !dQ_swapAB, +// Shape, Int, Int>, +// Shape, Int, Int> +// >; +// using AtomLayoutdQ = std::conditional_t< +// !dQ_swapAB, +// Layout, Int<2 / AtomLayoutMdQ>, _1>>, +// Layout, Int, _1>> +// >; +// static constexpr GMMA::Major MmadQMajorA = !dQ_swapAB ? GMMA::Major::K : GMMA::Major::MN; +// static constexpr GMMA::Major MmadQMajorB = !dQ_swapAB ? GMMA::Major::MN : GMMA::Major::K; +// using TiledMmadQ = decltype(cute::make_tiled_mma( +// std::conditional_t< +// !dQ_swapAB, +// std::conditional_t< +// Mma_dQ_is_RS, +// decltype(cute::GMMA::rs_op_selector()), +// decltype(cute::GMMA::ss_op_selector()) +// >, +// decltype(cute::GMMA::ss_op_selector()) +// >{}, +// AtomLayoutdQ{})); + +// using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); +// using GmemTiledCopyKV = cute::SM90_TMA_LOAD; +// using GmemTiledCopydKV = cute::SM90_TMA_STORE; + +// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +// static constexpr bool Has_cp_async = true; +// #else +// static constexpr bool Has_cp_async = false; +// #endif +// // For the dot_do_o preprocessing kernel +// using Gmem_copy_struct = std::conditional_t< +// Has_cp_async, +// SM80_CP_ASYNC_CACHEGLOBAL, +// DefaultCopy +// >; +// static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; +// static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); +// static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); +// // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem +// // to affect speed in practice. 
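As a rough worked example of the global-memory copy geometry computed around this point, assuming 2-byte elements (fp16/bf16) and a head dimension that is a multiple of 64; other element widths change the numbers:

#include <cstdint>

// Hypothetical fp16 stand-in; the real trait uses the kernel's Element type.
using ElementEx = std::uint16_t;                                    // 2 bytes, like fp16
constexpr int kHeadDimEx           = 128;                           // any multiple of 64
constexpr int kBlockKSmemEx        = kHeadDimEx % 64 == 0 ? 64 : 32;            // 64
constexpr int kGmemElemsPerLoadEx  = 16 / static_cast<int>(sizeof(ElementEx));  // 128-bit loads -> 8 elems
constexpr int kGmemThreadsPerRowEx = kBlockKSmemEx / kGmemElemsPerLoadEx;       // 8 threads per 64-wide row
static_assert(kHeadDimEx % kGmemElemsPerLoadEx == 0, "vector width must tile the head dim");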
+// static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; +// static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); +// using GmemLayoutAtom = Layout, Int>, +// Stride, _1>>; +// using GmemTiledCopydO = decltype( +// make_tiled_copy(Copy_Atom{}, +// GmemLayoutAtom{}, +// Layout>{})); // Val layout, 8 vals per store +// using GmemTiledCopydQ = decltype( +// make_tiled_copy(Copy_Atom{}, +// GmemLayoutAtom{}, +// Layout>{})); // Val layout, 8 vals per store +// using GmemLayoutAtomdQaccum = std::conditional_t< +// kBlockKSmem == 32, +// Layout, // Thread layout, 8 threads per row +// Stride< _8, _1>>, +// Layout, // Thread layout, 16 threads per row +// Stride< _16, _1>> +// >; +// using GmemTiledCopydQaccum = decltype( +// make_tiled_copy(Copy_Atom{}, +// GmemLayoutAtomdQaccum{}, +// Layout>{})); // Val layout, 4 vals per store + +// using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); +// using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{}))); +// using SmemLayoutdO = SmemLayoutQ; + +// using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); +// using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, +// make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + +// using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); +// using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, +// make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + +// using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); +// using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{}))); +// using SmemLayoutAtomdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); +// using SmemLayoutdS = decltype(tile_to_shape(SmemLayoutAtomdS{}, select<0, 1>(TileShape_MNK{}))); + +// // Note this is the transpose in terms of the view, not in terms of memory. 
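The "transpose of the view" idea behind the cute::composition layouts that follow can be pictured without CuTe: the composed layout feeds swapped coordinates into the same offset map, so the identical shared-memory bytes are addressed under (K, M) indexing instead of (M, K). A minimal plain-C++ sketch with hypothetical names:

// Toy row-major "smem" layout: element (m, k) lives at linear offset m * kK + k.
constexpr int kK = 8;
constexpr int layout_mk(int m, int k) { return m * kK + k; }
// Transposed *view*: feed swapped coordinates into the same offset function.
// No bytes move; (k, m) in the view aliases (m, k) in memory.
constexpr int layout_km(int k, int m) { return layout_mk(m, k); }
static_assert(layout_km(3, 1) == layout_mk(1, 3), "same storage, transposed indexing");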
+// using SmemLayoutQt = +// decltype(cute::composition(SmemLayoutQ{}, +// make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})), +// make_stride(Int{}, _1{})))); +// using SmemLayoutdOt = +// decltype(cute::composition(SmemLayoutdO{}, +// make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})), +// make_stride(Int{}, _1{})))); +// using SmemLayoutKt = +// decltype(cute::composition(SmemLayoutK{}, +// make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int{}), +// make_stride(Int{}, _1{}, Int{})))); +// using SmemLayoutPt = +// decltype(cute::composition(SmemLayoutP{}, +// make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})), +// make_stride(Int{}, _1{})))); +// using SmemLayoutdSt = +// decltype(cute::composition(SmemLayoutdS{}, +// make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})), +// make_stride(Int{}, _1{})))); + +// using SmemLayoutdK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{}))); +// using SmemLayoutdV = SmemLayoutdK; +// using SmemLayoutdKt = SmemLayoutKt; +// using SmemLayoutdVt = SmemLayoutKt; +// using SmemLayoutdQTMA = decltype(tile_to_shape(SmemLayoutAtomK{}, select<0, 2>(TileShape_MNK{}))); + +// static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; +// using SmemLayoutAtomdQ = decltype( +// composition(Swizzle{}, +// Layout>, +// Stride, _1>>{})); +// using SmemLayoutdQ = decltype(tile_to_shape( +// SmemLayoutAtomdQ{}, +// make_shape(Int{}, Int{}))); +// using SmemLayoutdQt = +// decltype(cute::composition(SmemLayoutdQ{}, +// make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})), +// make_stride(Int{}, _1{})))); +// static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element); + +// using SmemLayoutAtomdKV = decltype( +// composition(Swizzle{}, +// Layout>, +// Stride, _1>>{})); +// using SmemLayoutdKV = decltype(tile_to_shape( +// SmemLayoutAtomdKV{}, +// make_shape(Int{}, Int{}))); +// using SmemLayoutdKVt = +// decltype(cute::composition(SmemLayoutdKV{}, +// make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})), +// make_stride(Int{}, _1{})))); +// static constexpr int kSmemdKVSize = size(SmemLayoutdKV{}) * sizeof(Element) * 2; + +// // using SmemCopyAtomQ = Copy_Atom; +// using SmemCopyAtomPdS = Copy_Atom< +// std::conditional_t, +// Element>; +// using SmemCopyAtomdKV = Copy_Atom< +// std::conditional_t, +// Element>; +// using SmemCopyAtomdQ = Copy_Atom< +// std::conditional_t, +// Element>; + +// using SharedStorage = SharedStorageQKVdOdKVSeqqPar; + +// // using MainloopPipeline = typename cutlass::PipelineTmaAsync; +// // using PipelineState = typename cutlass::PipelineState; +// using MainloopPipeline = typename cutlass::PipelineTmaAsync; + +// }; + +// //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/candle-flash-attn-v3/hkernel/mainloop_fwd_sm90_tma_gmma_ws.hpp b/candle-flash-attn-v3/hkernel/mainloop_fwd_sm90_tma_gmma_ws.hpp new file mode 100644 index 0000000000..27db336b5c --- /dev/null +++ b/candle-flash-attn-v3/hkernel/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -0,0 +1,1145 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include "cutlass/pipeline/pipeline.hpp" + +#include "cute/tensor.hpp" + +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "named_barrier.hpp" +#include "utils.h" +#include "copy_paged_sm90_tma.hpp" + +namespace flash { + +using namespace cute; + +// 4 warps +struct SmemTransposeFp8_64x64 { + + using Element = cutlass::float_e4m3_t; + + using ldsm_thread_shape = Shape<_4, _1, _8, _4>; + using ldsm_value_shape = Shape<_2, _8, _2, _1>; + using ldsm_value_stride = Stride<_2, _4, _1, _0>; + using TiledCopyLDSM = decltype(make_tiled_copy( + Copy_Atom{}, Layout{}, + Layout{})); + TiledCopyLDSM tiled_copy_ldsm; + + using stsm_thread_shape = Shape<_4, _1, _8, _4>; + // using stsm_thread_stride = Stride<_1, _0, _4, _32>; +#ifndef NO_FP8_COLUMN_PERMUTE + using stsm_value_shape = Shape<_4, _4, _1, _2>; + using stsm_value_stride = Stride<_1, _8, _0, _4>; +#else + using stsm_value_shape = Shape<_4, _4, _2, _1>; + using stsm_value_stride = Stride<_1, _8, _4, _0>; +#endif + + using TiledCopySTSM = + decltype(make_tiled_copy(Copy_Atom{}, + Layout{}, + Layout{})); + TiledCopySTSM tiled_copy_stsm; + + template + CUTLASS_DEVICE void operator()(SmemTensor &&s_in, SmemTensorOut &&s_out) { + using namespace cute; + + auto tid = threadIdx.x; + auto thr_copy_ldsm = tiled_copy_ldsm.get_thread_slice(tid); + auto thr_copy_stsm = tiled_copy_stsm.get_thread_slice(tid); + + auto tXsX = thr_copy_ldsm.partition_S(s_in); + auto tXrX = make_tensor(shape(tXsX)); + auto tXsX_out = thr_copy_stsm.partition_D(s_out); + + cute::copy(tiled_copy_ldsm, tXsX, tXrX); + + auto data = tXrX.data(); + // size(tXrX) == 32 + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size(tXrX); n += 8) { + uint32_t *data_32bit = reinterpret_cast(&data[n]); + auto upper = data_32bit[0]; + auto lower = data_32bit[1]; + data_32bit[0] = __byte_perm(upper, lower, 0x6420); + data_32bit[1] = __byte_perm(upper, lower, 0x7531); + } + + cute::copy(tiled_copy_stsm, tXrX, tXsX_out); + } +}; + +template +struct CollectiveMainloopFwd { + + using Element = typename Ktraits::Element; + using TileShape_MNK = typename Ktraits::TileShape_MNK; + using ClusterShape = typename Ktraits::ClusterShape_MNK; + + static constexpr int kStages = Ktraits::kStages; + static constexpr int kHeadDim = Ktraits::kHeadDim; + // static constexpr int kBlockM = Ktraits::kBlockM; + // static constexpr int kBlockN = Ktraits::kBlockN; + // static constexpr int kBlockH = Ktraits::kBlockH; + static constexpr bool Is_split = Ktraits::Is_split; + static constexpr bool No_smem_O = Ktraits::No_smem_O; + + using GmemTiledCopyQ = cute::SM90_TMA_LOAD; + using GmemTiledCopyKVNopage = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{}))); + + // use SM90_TMA_LOAD_MULTICAST_PAGED if we would use SM90_TMA_LOAD_MULTICAST in unpaged scenario, otherwise use SM90_TMA_LOAD_PAGED + using GmemTiledCopyKV = typename std::conditional< + std::is_same::value, + SM90_TMA_LOAD_MULTICAST_PAGED, + SM90_TMA_LOAD_PAGED>::type; + + using SmemLayoutQ = typename Ktraits::SmemLayoutQ; + using SmemLayoutQCopy = typename Ktraits::SmemLayoutQCopy; + using TileShapeQCopy = typename Ktraits::TileShapeQCopy; + using SmemLayoutK = typename Ktraits::SmemLayoutK; + using SmemLayoutV = typename Ktraits::SmemLayoutV; + using SmemLayoutVt = typename Ktraits::SmemLayoutVt; + + using TMA_Q = decltype(make_tma_copy( + GmemTiledCopyQ{}, + 
make_tensor( + make_gmem_ptr(static_cast(nullptr)), + repeat_like(typename Seqlen_traits_Q::StrideT{}, int32_t(0)), + typename Seqlen_traits_Q::StrideT{} + ), + SmemLayoutQCopy{}, + TileShapeQCopy{}, + _1{})); // no mcast for Q + + using TMA_K = decltype(make_virtualized_tma_copy( + GmemTiledCopyKV{}, + make_tensor( + make_gmem_ptr(static_cast(nullptr)), + repeat_like(typename Seqlen_traits::StrideT{}, int32_t(0)), + typename Seqlen_traits::StrideT{} + ), + typename Seqlen_traits::ShapeT{}, + take<0, 2>(SmemLayoutK{}), + select<1, 2>(TileShape_MNK{}), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + + // TMA_V may differ from TMA_K for fp8 kernel (e.g. swizzling mode) + using TMA_V = decltype(make_virtualized_tma_copy( + GmemTiledCopyKV{}, + make_tensor( + make_gmem_ptr(static_cast(nullptr)), + repeat_like(typename Seqlen_traits::StrideT{}, int32_t(0)), + typename Seqlen_traits::StrideT{} + ), + typename Seqlen_traits::ShapeT{}, + take<0, 2>(SmemLayoutV{}), + select<1, 2>(TileShape_MNK{}), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + + static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{}); + using MainloopPipeline = typename Ktraits::MainloopPipeline; + using MainloopPipelineNoTMA = typename Ktraits::MainloopPipelineNoTMA; + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename MainloopPipeline::PipelineState; + + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + static constexpr uint32_t TmaTransactionBytesQ = static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesK = static_cast(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v / 8); + + // static constexpr bool UseSchedulerBarrier = kHeadDim <= 128; + static constexpr bool UseSchedulerBarrier = Ktraits::kNWarps >= 12 && + (cutlass::sizeof_bits_v == 8 ? 
kHeadDim >= 128 : kHeadDim <= 128); + + // Host side kernel arguments + struct Arguments { + Element const* ptr_Q; + typename Seqlen_traits_Q::LayoutT layout_Q; + Element const* ptr_K; + typename Seqlen_traits::LayoutT layout_K; + Element const* ptr_V; + typename Seqlen_traits::LayoutT layout_V; + typename Seqlen_traits::ShapeT shape_KV; + float const softmax_scale_log2; + float const* descale_q_ptr; + float const* descale_k_ptr; + float const* descale_v_ptr; + int window_size_left; + int window_size_right; + int const qhead_per_khead; + int const* cache_batch_idx; + int const num_splits; + // Paged Attention block table data + int * block_table; // may be nullptr if not paged + int64_t block_table_batch_stride; + int page_block_size; + int num_blocks; + }; + + // Device side kernel params + struct Params { + typename Seqlen_traits_Q::LayoutT layout_Q; + typename Seqlen_traits::LayoutT layout_K; + typename Seqlen_traits::LayoutT layout_V; + typename Seqlen_traits::ShapeT shape_KV; + cutlass::FastDivmod qhead_per_khead_divmod; + TMA_Q tma_load_Q; + TMA_K tma_load_K; + TMA_V tma_load_V; + float const softmax_scale_log2; + float const* descale_q_ptr; + float const* descale_k_ptr; + float const* descale_v_ptr; + int window_size_left; + int window_size_right; + int const* cache_batch_idx; + cutlass::FastDivmod num_splits_divmod; + // Paged Attention block table data + const PagedCopyArgs paged_copy_args; + }; + + static Params + to_underlying_arguments(Arguments const& args) { + Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.layout_Q); + TMA_Q tma_load_Q = make_tma_copy( + GmemTiledCopyQ{}, + mQ, + SmemLayoutQCopy{}, + TileShapeQCopy{}, + _1{}); // no mcast for Q + Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K); + TMA_K tma_load_K = make_virtualized_tma_copy( + GmemTiledCopyKV{}, + mK, + args.shape_KV, + SmemLayoutK{}(_, _, _0{}), + select<1, 2>(TileShape_MNK{}), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V); + TMA_V tma_load_V = make_virtualized_tma_copy( + GmemTiledCopyKV{}, + mV, + args.shape_KV, + SmemLayoutV{}(_, _, _0{}), + select<1, 2>(TileShape_MNK{}), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + return {args.layout_Q, args.layout_K, args.layout_V, args.shape_KV, + cutlass::FastDivmod(args.qhead_per_khead), + + tma_load_Q, tma_load_K, tma_load_V, + args.softmax_scale_log2, + args.descale_q_ptr, args.descale_k_ptr, args.descale_v_ptr, + args.window_size_left, args.window_size_right, + args.cache_batch_idx, + cutlass::FastDivmod(args.num_splits), + {args.block_table_batch_stride, args.page_block_size, args.block_table }}; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_K.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_V.get_tma_descriptor()); + } + + CUTLASS_DEVICE + void get_n_block_min_max( + Params const& mainloop_params, + int m_block, + int n_split_idx, + const Seqlen_traits_Q& seqlen_traits_q, + const Seqlen_traits& seqlen_traits_k, + int& n_block_min, + int& n_block_max + ) { + // static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + static constexpr int kBlockM_div_H 
= get<0>(TileShape_MNK{})/Ktraits::kBlockH; + int const seqlen_q = seqlen_traits_q.actual_seq_len; + int const seqlen_k = seqlen_traits_k.actual_seq_len; + n_block_max = cute::ceil_div(seqlen_k, kBlockN); + + if constexpr(Is_split) { + int const n_blocks_per_split + = mainloop_params.num_splits_divmod.divide(n_block_max + int(mainloop_params.num_splits_divmod) - 1); + n_block_min = n_split_idx * n_blocks_per_split; + n_block_max = std::min(n_block_max, (n_split_idx + 1) * n_blocks_per_split); + } + + if constexpr (Is_causal) { + n_block_max = std::min( + n_block_max, + cute::ceil_div((m_block + 1) * kBlockM_div_H + seqlen_k - seqlen_q, kBlockN)); + } else if constexpr (Is_local) { + n_block_max = std::min( + n_block_max, + cute::ceil_div((m_block + 1) * kBlockM_div_H + seqlen_k - seqlen_q + mainloop_params.window_size_right, kBlockN)); + n_block_min = std::max( + n_block_min, + (m_block * kBlockM_div_H + seqlen_k - seqlen_q - mainloop_params.window_size_left) / kBlockN); + } + } + + CUTLASS_DEVICE + void get_n_block_max( + Params const& mainloop_params, + int m_block, + const Seqlen_traits_Q& seqlen_traits_q, + const Seqlen_traits& seqlen_traits_k, + int& n_block_max + ) { + // static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + static constexpr int kBlockM_div_H = get<0>(TileShape_MNK{})/Ktraits::kBlockH; + int const seqlen_q = seqlen_traits_q.actual_seq_len; + int const seqlen_k = seqlen_traits_k.actual_seq_len; + n_block_max = cute::ceil_div(seqlen_k, kBlockN); + if constexpr (Is_causal) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM_div_H + seqlen_k - seqlen_q, kBlockN)); + } + } + + template + CUTLASS_DEVICE void + load(Params const& mainloop_params, + MainloopPipeline pipeline_k, + MainloopPipeline pipeline_v, + PipelineState& smem_pipe_write_k, + PipelineState& smem_pipe_write_v, + SharedStorage &shared_storage, + Scheduler& scheduler, + typename Scheduler::Params const& scheduler_params, + typename Scheduler::WorkTileInfo& work_tile_info, + cute::tuple block_coord, + int work_idx, + const Seqlen_traits_Q& seqlen_traits_q, + const Seqlen_traits& seqlen_traits_k, + int n_block_min, + int n_block_max + ) { + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQCopy{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{}); + + Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape()); + Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.shape_KV); + Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.shape_KV); + + auto [m_block, n_split_idx, bidh, bidb] = block_coord; + const int bidb_cache = mainloop_params.cache_batch_idx == nullptr ? 
bidb : mainloop_params.cache_batch_idx[bidb]; + const int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh); + + // Prepare the TMA loads + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + Tensor gQ = [&] { + // Need this inside lambda to capture structured binding + auto [m_block, n_split_idx, bidh, bidb] = block_coord; + if constexpr(Seqlen_traits_Q::UseGQAPacking) { + return seqlen_traits_q.get_local_tile_tensor( + mQ, TileShapeQCopy{}, bidh_kv, bidb) + (_, _, _, m_block, bidh % int(mainloop_params.qhead_per_khead_divmod)); // (M/H, H, K) + } else { + return seqlen_traits_q.get_local_tile_tensor( + mQ, TileShapeQCopy{}, bidh, bidb)(_, _, m_block); // (M, K) + } + }(); + Tensor gK = seqlen_traits_k.get_local_tile_tensor( + mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb_cache); // (N, K, _) + Tensor gV = seqlen_traits_k.get_local_tile_tensor( + mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb_cache); // (N, K, _) + + Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{})); + Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{})); + auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{}, + group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x)); // (TMA), (TMA) + auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, block_rank_in_cluster, Layout{}, + group_modes<0, 2>(sK), group_modes<0, 2>(gK)); // (TMA, k), (TMA, PIPE) + auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, block_rank_in_cluster, Layout{}, + group_modes<0, 2>(sV), group_modes<0, 2>(gV)); // (TMA, k), (TMA, PIPE) + + uint16_t mcast_mask_kv = 0; + if constexpr (cute::is_same_v || cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{})); + } + } + + int n_block = n_block_max - 1; + + int lane_predicate = cute::elect_one_sync(); + if (lane_predicate) { + pipeline_k.producer_acquire(smem_pipe_write_k); + copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv, mainloop_params.paged_copy_args), + tKgK(_, n_block), tKsK(_, smem_pipe_write_k.index())); + ++smem_pipe_write_k; + } + + // Wait for the MMA warpgroups to say that smem_q is ready + cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + + if (lane_predicate) { + shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); + copy(mainloop_params.tma_load_Q.with(reinterpret_cast(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ); + } + + // Wait for warp 1 to signal that smem_v are ready and V can be copied from gmem + // Need ClusterBarrier, not just NamedBarrier. Otherwise we might have CTA 0 finishing the + // TMA store on O first, call TMA multicast load on V, before CTA 1 can finishing TMA store on O. 
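The wait just below relies on the usual phase-bit convention for reusable barriers: the expected phase alternates with work_idx, so the same barrier object serves every tile without re-initialization. A simplified single-flag model of that bookkeeping, making no claims about the actual CUTLASS barrier internals:

#include <atomic>

// Toy reusable phase-bit barrier: each arrival flips the phase, and a waiter
// passes the phase it last observed and spins until the bit changes.
// (Conceptual only; the real cluster/mbarrier semantics carry more state.)
struct ToyPhaseBarrier {
    std::atomic<int> phase{0};
    void arrive()                 { phase.fetch_xor(1, std::memory_order_release); }
    void wait(int observed_phase) {
        while (phase.load(std::memory_order_acquire) == observed_phase) { /* spin */ }
    }
};
// Per work tile the expected bit alternates, which is why the kernel passes
// (work_idx + 1) % 2 rather than re-initializing the barrier between tiles.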
+ if constexpr (!No_smem_O) { shared_storage.barrier_O.wait((work_idx + 1) % 2); } + if (lane_predicate) { + // CUTLASS_PRAGMA_NO_UNROLL + #pragma unroll 2 + for (; n_block > n_block_min; --n_block) { + pipeline_k.producer_acquire(smem_pipe_write_k); + copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv, mainloop_params.paged_copy_args), + tKgK(_, n_block - 1), tKsK(_, smem_pipe_write_k.index())); + ++smem_pipe_write_k; + pipeline_v.producer_acquire(smem_pipe_write_v); + copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv, mainloop_params.paged_copy_args), + tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index())); + ++smem_pipe_write_v; + } + } + + scheduler.prefetch_next_work(scheduler_params, work_tile_info); + if (lane_predicate) { + pipeline_v.producer_acquire(smem_pipe_write_v); + copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv, mainloop_params.paged_copy_args), + tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index())); + ++smem_pipe_write_v; + } + scheduler.broadcast_next_work(work_tile_info); + + } + + template + CUTLASS_DEVICE void + load_fp8(Params const& mainloop_params, + MainloopPipeline pipeline_k, + MainloopPipeline pipeline_v, + MainloopPipelineNoTMA pipeline_vt, + PipelineState& smem_pipe_write, + PipelineState& smem_pipe_read, + SharedStorage &shared_storage, + Scheduler& scheduler, + typename Scheduler::Params const& scheduler_params, + typename Scheduler::WorkTileInfo& work_tile_info, + cute::tuple block_coord, + int work_idx, + const Seqlen_traits_Q& seqlen_traits_q, + const Seqlen_traits& seqlen_traits_k, + int n_block_min, + int n_block_max + ) { + + using SmemLayoutTransposeV = typename Ktraits::SmemLayoutTransposeV; + using SmemLayoutTransposeVt = typename Ktraits::SmemLayoutTransposeVt; + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQCopy{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{}); + + Tensor sV_divide = as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutTransposeV{})); + Tensor sVt_divide = as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_v_out.data()), SmemLayoutTransposeVt{})); + + auto smem_transpose_V = SmemTransposeFp8_64x64(); + auto do_transpose_V = [&](int stage) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < shape<2>(SmemLayoutTransposeV{}); ++j) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < shape<1>(SmemLayoutTransposeV{}); ++i) { + smem_transpose_V(flatten(sV_divide(_, i, j, stage)), + flatten(sVt_divide(_, i, j, stage))); + } + } + cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::ProducerWG) /*id*/); + }; + + Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape()); + Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.shape_KV); + Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.shape_KV); + + auto [m_block, split_idx, bidh, bidb] = block_coord; + const int bidb_cache = mainloop_params.cache_batch_idx == nullptr ? 
bidb : mainloop_params.cache_batch_idx[bidb]; + const int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh); + + // Prepare the TMA loads + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + Tensor gQ = [&] { + // Need this inside lambda to capture structured binding + auto [m_block, n_split_idx, bidh, bidb] = block_coord; + if constexpr(Seqlen_traits_Q::UseGQAPacking) { + return seqlen_traits_q.get_local_tile_tensor( + mQ, TileShapeQCopy{}, bidh_kv, bidb) + (_, _, _, m_block, bidh % int(mainloop_params.qhead_per_khead_divmod)); // (M/H, H, K) + } else { + return seqlen_traits_q.get_local_tile_tensor( + mQ, TileShapeQCopy{}, bidh, bidb)(_, _, m_block); // (M, K) + } + }(); + Tensor gK = seqlen_traits_k.get_local_tile_tensor( + mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb_cache); // (N, K, _) + Tensor gV = seqlen_traits_k.get_local_tile_tensor( + mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb_cache); // (N, K, _) + + Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{})); + Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{})); + auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{}, + group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x)); // (TMA), (TMA) + auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, block_rank_in_cluster, Layout{}, + group_modes<0, 2>(sK), group_modes<0, 2>(gK)); // (TMA, k), (TMA, PIPE) + auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, block_rank_in_cluster, Layout{}, + group_modes<0, 2>(sV), group_modes<0, 2>(gV)); // (TMA, k), (TMA, PIPE) + + uint16_t mcast_mask_kv = 0; + if constexpr (cute::is_same_v || cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{})); + } + } + + int n_block = n_block_max - 1; + + int lane_predicate = cute::elect_one_sync(); + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + if (warp_idx_in_warpgroup == 0 && lane_predicate) { + pipeline_k.producer_acquire(smem_pipe_write); + copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv, mainloop_params.paged_copy_args), + tKgK(_, n_block), tKsK(_, smem_pipe_write.index())); + } + + // Wait for the MMA warpgroups to say that smem_q is ready + // for fp8, change from NumThreadsPerWarp to NumThreadsPerWarpGroup + cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + + if (warp_idx_in_warpgroup == 0 && lane_predicate) { + shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); + copy(mainloop_params.tma_load_Q.with(reinterpret_cast(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ); + if constexpr(!Ktraits::VO_union_all) { + pipeline_v.producer_acquire(smem_pipe_write); + copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv, mainloop_params.paged_copy_args), + tVgV(_, n_block), tVsV(_, smem_pipe_write.index())); + } + + } + // With fp8 kernel, smem_o is in union with smem_v_out, + // except for split kernel + hdim 256, + // so could use NamedBarrier instead of ClusterBarrier. 
+ // But, this doesn't appear to have any benefit. + if constexpr (!No_smem_O) { shared_storage.barrier_O.wait((work_idx + 1) % 2); } + + if constexpr(Ktraits::VO_union_all) { + if (warp_idx_in_warpgroup == 0 && lane_predicate) { + pipeline_v.producer_acquire(smem_pipe_write); + copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv, mainloop_params.paged_copy_args), + tVgV(_, n_block), tVsV(_, smem_pipe_write.index())); + } + } + + #pragma unroll 2 + for (; n_block > n_block_min; --n_block) { + pipeline_v.consumer_wait(smem_pipe_read); + pipeline_vt.producer_acquire(smem_pipe_write); + do_transpose_V(smem_pipe_read.index()); + pipeline_vt.producer_commit(smem_pipe_write); + pipeline_v.consumer_release(smem_pipe_read); + + ++smem_pipe_write; + ++smem_pipe_read; + + if (warp_idx_in_warpgroup == 0 && lane_predicate) { + pipeline_k.producer_acquire(smem_pipe_write); + copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv, mainloop_params.paged_copy_args), + tKgK(_, n_block-1), tKsK(_, smem_pipe_write.index())); + pipeline_v.producer_acquire(smem_pipe_write); + copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv, mainloop_params.paged_copy_args), + tVgV(_, n_block-1), tVsV(_, smem_pipe_write.index())); + } + } + + scheduler.prefetch_next_work(scheduler_params, work_tile_info); + scheduler.broadcast_next_work(work_tile_info); + + pipeline_v.consumer_wait(smem_pipe_read); + pipeline_vt.producer_acquire(smem_pipe_write); + do_transpose_V(smem_pipe_read.index()); + pipeline_vt.producer_commit(smem_pipe_write); + pipeline_v.consumer_release(smem_pipe_read); + + ++smem_pipe_write; + ++smem_pipe_read; + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v, + PipelineState& smem_pipe_write_k, PipelineState& smem_pipe_write_v) { + int lane_predicate = cute::elect_one_sync(); + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + // Issue the epilogue waits + if (warp_idx_in_warpgroup == 0 && lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was still inverted from make_producer_start_state + */ + pipeline_k.producer_tail(smem_pipe_write_k); + pipeline_v.producer_tail(smem_pipe_write_v); + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail_one_write(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v, + PipelineState& smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + // Issue the epilogue waits + if (warp_idx_in_warpgroup == 0 && lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was still inverted from make_producer_start_state + */ + pipeline_k.producer_tail(smem_pipe_write); + pipeline_v.producer_tail(smem_pipe_write); + } + } + + CUTLASS_DEVICE void + warp_scheduler_barrier_sync() { + if constexpr (UseSchedulerBarrier) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, 
static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + cutlass::canonical_warp_group_idx() /*id*/); + } + } + + CUTLASS_DEVICE void + warp_scheduler_barrier_arrive() { + if constexpr (!UseSchedulerBarrier) { + return; + } else { + static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup); + if constexpr (NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (3 - cutlass::canonical_warp_group_idx()) /*id*/); + } else { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 2 ? cutlass::canonical_warp_group_idx() + 1 : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 1 ? cutlass::canonical_warp_group_idx() + 2 : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/); + } + } + } + + CUTLASS_DEVICE void + mma_init() { + // Tell producer (warp 0) that smem_q is ready + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + Ktraits::NumProducerThreads, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + if constexpr (!UseSchedulerBarrier) { + return; + } else { + static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup); + if (cutlass::canonical_warp_group_idx() > 1) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 1 /*id*/); + } + if constexpr (NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup) { + if (cutlass::canonical_warp_group_idx() > 2) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 2 /*id*/); + } + } + } + } + + template + CUTLASS_DEVICE void + mma(Params const& mainloop_params, + MainloopPipeline pipeline_k, + MainloopPipeline pipeline_v, + PipelineState& smem_pipe_read_k, + PipelineState& smem_pipe_read_v, + FrgTensorO& tOrO, + Softmax& softmax, + int n_block_min, + int n_block_max, + int thread_idx, + int work_idx, + int m_block, + SharedStorage& shared_storage, + const Seqlen_traits_Q& seqlen_traits_q, + const Seqlen_traits& seqlen_traits_k + ) { + static_assert(is_rmem::value, "O tensor must be rmem resident."); + + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + static constexpr int kBlockH = Ktraits::kBlockH; + static constexpr int kBlockM_div_H = get<0>(TileShape_MNK{}) / kBlockH; + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); + Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{}); + + typename Ktraits::TiledMma0 tiled_mma0; + typename Ktraits::TiledMma1 tiled_mma1; + auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx); + auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx); + + // Allocate "fragments/descriptors" for first matmul. + Tensor tSrQ = threadMma0.partition_fragment_A(sQ); + Tensor tSrK = threadMma0.partition_fragment_B(sK); + // Allocate "fragments/descriptors" for second matmul. + // Note: S becomes P. 
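The note above (S becomes P) is the key to the consumer loop that follows: each KV block yields one tile of scores S = Q * K^T, which is exponentiated in place into probabilities P and immediately multiplied by V, while a running row max and row sum let the already-accumulated output be rescaled instead of recomputed. A scalar sketch of that streaming-softmax bookkeeping, illustrative only; the kernel does the same per register fragment with the softmax scale folded in as log2:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// One query row, head dim collapsed to 1: per KV block, update the running max m
// and sum l, rescale what is already in o, exponentiate the new scores, fold in V.
inline float online_softmax_row(const std::vector<std::vector<float>>& score_blocks,
                                const std::vector<std::vector<float>>& value_blocks) {
    float m = -INFINITY, l = 0.0f, o = 0.0f;
    for (std::size_t b = 0; b < score_blocks.size(); ++b) {
        float m_new = m;
        for (float s : score_blocks[b]) m_new = std::max(m_new, s);
        const float scale = std::exp(m - m_new);   // "scores_scale": shrink old contributions
        o *= scale;                                // softmax.rescale_o
        l *= scale;
        for (std::size_t j = 0; j < score_blocks[b].size(); ++j) {
            const float p = std::exp(score_blocks[b][j] - m_new);  // S becomes P
            l += p;
            o += p * value_blocks[b][j];           // the P @ V gemm, collapsed to a scalar
        }
        m = m_new;
    }
    return o / l;                                  // deferred normalization (softmax.finalize)
}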
+ Tensor tOrV = threadMma1.partition_fragment_B(sVt); + + auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + }; + + tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero; + int const seqlen_q = seqlen_traits_q.actual_seq_len; + int const seqlen_k = seqlen_traits_k.actual_seq_len; + int n_block = n_block_max - 1; + + cutlass::ConsumerToken barrier_token = static_cast(shared_storage.barrier_Q.try_wait(work_idx % 2)); + if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(work_idx % 2); } + + Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + + consumer_wait(pipeline_k, smem_pipe_read_k); + warp_scheduler_barrier_sync(); + flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); + warp_scheduler_barrier_arrive(); + if constexpr (!No_smem_O) { + if (work_idx != 0) { + int lane_predicate = cute::elect_one_sync(); + if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) { + tma_store_wait<0>(); + #pragma unroll + for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { + shared_storage.barrier_O.arrive(cta_id, lane_predicate); + } + } + } + } + warpgroup_wait<0>(); + pipeline_k.consumer_release(smem_pipe_read_k); + ++smem_pipe_read_k; + + auto col_limit_right = [&](int row, int n_block) { + int col_limit_base = row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM_div_H; + if constexpr(Is_local) + return col_limit_base + mainloop_params.window_size_right; + else + return col_limit_base; + }; + auto col_limit_left = [&](int row, int n_block) { + return std::max( + 0, + row + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM_div_H - mainloop_params.window_size_left + ); + }; + { + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{})); + Tensor tScS = threadMma0.partition_C(cS); + #pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + if constexpr (!Is_causal && !Is_local) { // Just masking based on col + if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; } + } else { // mask based on both row and col + // using std::min is faster than doing col >= limit0 or col >= limit1 + // Need to cast get<1>(tScS(i)) to (signed) int since by default it's unsigned, and the + // right hand side can be negative and might be converted to a very large unsigned integer. + int row = int(get<0>(tScS(i))) / kBlockH; + if (int(get<1>(tScS(i))) >= std::min(seqlen_k - n_block * kBlockN, col_limit_right(row, n_block))) { + tSrS(i) = -INFINITY; + } else if constexpr(Is_local) { + if (int(get<1>(tScS(i))) < col_limit_left(row, n_block)) { + tSrS(i) = -INFINITY; + } + } + } + } + } + + softmax.template online_softmax(tSrS); + + Tensor tOrP = make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs(tSrS.layout())); + Tensor scores_scale = make_fragment_like(softmax.row_max); + clear(scores_scale); + + constexpr int n_masking_steps = !Is_causal ? 
1 : cute::ceil_div(kBlockM_div_H, kBlockN) + 1; + // Only go through these if Is_causal, since n_masking_steps = 1 when !Is_causal + #pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps - 1 && n_block > n_block_min; ++masking_step, --n_block) { + Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + consumer_wait(pipeline_k, smem_pipe_read_k); + warp_scheduler_barrier_sync(); + flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); + if (masking_step > 0) { softmax.rescale_o(tOrO, scores_scale); } + consumer_wait(pipeline_v, smem_pipe_read_v); + flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); + warp_scheduler_barrier_arrive(); + warpgroup_wait<1>(); + pipeline_k.consumer_release(smem_pipe_read_k); // release K + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{})); + Tensor tScS = threadMma0.partition_C(cS); + #pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + int row = int(get<0>(tScS(i))) / kBlockH; + if (int(get<1>(tScS(i))) >= col_limit_right(row, n_block - 1)) { + tSrS(i) = -INFINITY; + } + } + cute::copy(softmax.template max(tSrS), scores_scale); + softmax.template online_softmax(tSrS); + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V + ++smem_pipe_read_k; + ++smem_pipe_read_v; + cute::copy(make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs(tSrS.layout())), tOrP); + } + + #pragma unroll 1 + for (; n_block > n_block_min; --n_block) { + Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + consumer_wait(pipeline_k, smem_pipe_read_k); + warp_scheduler_barrier_sync(); + flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); + softmax.rescale_o(tOrO, scores_scale); + consumer_wait(pipeline_v, smem_pipe_read_v); + flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); + warp_scheduler_barrier_arrive(); + warpgroup_wait<1>(); + pipeline_k.consumer_release(smem_pipe_read_k); // release K + + if constexpr(Is_local) { + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{})); + Tensor tScS = threadMma0.partition_C(cS); + #pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + int row = int(get<0>(tScS(i))) / kBlockH; + if ( + int(get<1>(tScS(i))) >= col_limit_right(row, n_block - 1) || + int(get<1>(tScS(i))) < col_limit_left(row, n_block - 1) + ) { + tSrS(i) = -INFINITY; + } + } + } + // auto scores_scale = softmax.template max(tSrS); + cute::copy(softmax.template max(tSrS), scores_scale); + softmax.template online_softmax(tSrS); + + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V + ++smem_pipe_read_k; + ++smem_pipe_read_v; + // softmax.rescale_o(tOrO, scores_scale); + cute::copy(make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs(tSrS.layout())), tOrP); + } + // Tell warp 0 that smem_q is ready + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + softmax.rescale_o(tOrO, scores_scale); + consumer_wait(pipeline_v, smem_pipe_read_v); + flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); + cute::copy(softmax.template finalize(tSrS), scores_scale); + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V, otherwise producers will hang + ++smem_pipe_read_v; + softmax.rescale_o(tOrO, scores_scale); + return; + } + + template + CUTLASS_DEVICE void + 
mma_fp8(Params const& mainloop_params, + MainloopPipeline pipeline_k, + MainloopPipelineNoTMA pipeline_vt, + PipelineState& smem_pipe_read, + PipelineState& smem_pipe_release, + FrgTensorO& tOrO, + Softmax& softmax, + int n_block_min, + int n_block_max, + int thread_idx, + int work_idx, + int m_block, + SharedStorage& shared_storage, + const Seqlen_traits_Q& seqlen_traits_q, + const Seqlen_traits& seqlen_traits_k + ) { + static_assert(is_rmem::value, "O tensor must be rmem resident."); + + // static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + static constexpr int kBlockH = Ktraits::kBlockH; + static constexpr int kBlockM_div_H = get<0>(TileShape_MNK{}) / kBlockH; + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); + Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v_out.data()), SmemLayoutVt{}); + + typename Ktraits::TiledMma0 tiled_mma0; + typename Ktraits::TiledMma1 tiled_mma1; + auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx); + auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx); + + // Allocate "fragments/descriptors" for first matmul. + Tensor tSrQ = threadMma0.partition_fragment_A(sQ); + Tensor tSrK = threadMma0.partition_fragment_B(sK); + // Allocate "fragments/descriptors" for second matmul. + Tensor tOrV = threadMma1.partition_fragment_B(sVt); + + auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + }; + + tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero; + int const seqlen_q = seqlen_traits_q.actual_seq_len; + int const seqlen_k = seqlen_traits_k.actual_seq_len; + int n_block = n_block_max - 1; + + cutlass::ConsumerToken barrier_token = static_cast(shared_storage.barrier_Q.try_wait(work_idx % 2)); + if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(work_idx % 2); } + + Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + + consumer_wait(pipeline_k, smem_pipe_read); + warp_scheduler_barrier_sync(); + flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + if constexpr (!No_smem_O) { + if (work_idx != 0) { + int lane_predicate = cute::elect_one_sync(); + if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) { + tma_store_wait<0>(); + #pragma unroll + for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { + shared_storage.barrier_O.arrive(cta_id, lane_predicate); + } + } + } + } + warpgroup_wait<0>(); + warp_scheduler_barrier_arrive(); + pipeline_k.consumer_release(smem_pipe_read); + + auto col_limit_right = [&](int row, int n_block) { + int col_limit_base = row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM_div_H; + if constexpr(Is_local) + return col_limit_base + mainloop_params.window_size_right; + else + return col_limit_base; + }; + auto col_limit_left = [&](int row, int n_block) { + return std::max( + 0, + row + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM_div_H - mainloop_params.window_size_left + ); + }; + { + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{})); + Tensor tScS = threadMma0.partition_C(cS); + #pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + if constexpr (!Is_causal && !Is_local) { // Just masking based on col + if 
(int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; } + } else { // mask based on both row and col + int row = int(get<0>(tScS(i))) / kBlockH; + if (int(get<1>(tScS(i))) >= std::min(seqlen_k - n_block * kBlockN, col_limit_right(row, n_block))) { + tSrS(i) = -INFINITY; + } else if constexpr(Is_local) { + if (int(get<1>(tScS(i))) < col_limit_left(row, n_block)) { + tSrS(i) = -INFINITY; + } + } + } + } + } + + softmax.template online_softmax(tSrS); + Tensor tOrP = make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout())); + permute_regs_A_to_C(tOrP); + + Tensor scores_scale = make_fragment_like(softmax.row_max); + clear(scores_scale); + + consumer_wait(pipeline_vt, smem_pipe_read); + flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + if constexpr(!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); } + + ++smem_pipe_read; + --n_block; + constexpr int extra_iterations = !Is_causal ? kStages - 1 : cute::ceil_div(kBlockM_div_H, kBlockN); + + if constexpr(Is_causal) { + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < extra_iterations && n_block >= n_block_min; ++iter, --n_block) { + Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + consumer_wait(pipeline_k, smem_pipe_read); + warp_scheduler_barrier_sync(); + flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{})); + Tensor tScS = threadMma0.partition_C(cS); + #pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + int row = int(get<0>(tScS(i))) / kBlockH; + if (int(get<1>(tScS(i))) >= col_limit_right(row, n_block)) { + tSrS(i) = -INFINITY; + } + } + + warp_scheduler_barrier_arrive(); + pipeline_k.consumer_release(smem_pipe_read); + if constexpr(Delay_V_release) { + pipeline_vt.consumer_release(smem_pipe_release); + ++smem_pipe_release; + } + consumer_wait(pipeline_vt, smem_pipe_read); + + cute::copy(softmax.template max(tSrS), scores_scale); + softmax.rescale_o(tOrO, scores_scale); + softmax.template online_softmax(tSrS); + Tensor tOrP = make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout())); + permute_regs_A_to_C(tOrP); + + flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + if constexpr(!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); } + ++smem_pipe_read; + } + } else if constexpr(!Is_local) { + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < extra_iterations && n_block >= n_block_min; ++iter, --n_block) { + Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + consumer_wait(pipeline_k, smem_pipe_read); + if constexpr(Delay_V_release) { + pipeline_vt.consumer_release(smem_pipe_release); + ++smem_pipe_release; + } + warp_scheduler_barrier_sync(); + flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + warp_scheduler_barrier_arrive(); + if constexpr(!Delay_V_release) { pipeline_k.consumer_release(smem_pipe_read); } + else { consumer_wait(pipeline_vt, smem_pipe_read); } + + cute::copy(softmax.template max(tSrS), scores_scale); + softmax.rescale_o(tOrO, scores_scale); + softmax.template online_softmax(tSrS); + Tensor tOrP = make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout())); + permute_regs_A_to_C(tOrP); + + if constexpr (Delay_V_release) { pipeline_k.consumer_release(smem_pipe_read); } + else { consumer_wait(pipeline_vt, smem_pipe_read); } + flash::gemm(tiled_mma1, tOrP, 
tOrV(_, _, _, smem_pipe_read.index()), tOrO); + if constexpr(!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); } + ++smem_pipe_read; + } + } + + if constexpr(Delay_V_release) { + warp_scheduler_barrier_sync(); + CUTLASS_PRAGMA_NO_UNROLL + for (; n_block >= n_block_min; --n_block) { + Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + consumer_wait(pipeline_k, smem_pipe_read); + flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + + if constexpr(Is_local) { + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{})); + Tensor tScS = threadMma0.partition_C(cS); + #pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + int row = int(get<0>(tScS(i))) / kBlockH; + if ( + int(get<1>(tScS(i))) >= col_limit_right(row, n_block) || + int(get<1>(tScS(i))) < col_limit_left(row, n_block) + ) { + tSrS(i) = -INFINITY; + } + } + } + + warp_scheduler_barrier_arrive(); + pipeline_k.consumer_release(smem_pipe_read); + pipeline_vt.consumer_release(smem_pipe_release); + + cute::copy(softmax.template max(tSrS), scores_scale); + softmax.rescale_o(tOrO, scores_scale); + softmax.template online_softmax(tSrS); + Tensor tOrP = make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout())); + permute_regs_A_to_C(tOrP); + + consumer_wait(pipeline_vt, smem_pipe_read); + flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + warp_scheduler_barrier_sync(); + ++smem_pipe_read; + ++smem_pipe_release; + } + warp_scheduler_barrier_arrive(); + pipeline_vt.consumer_release(smem_pipe_release); + ++smem_pipe_release; + } else { + if constexpr (kHeadDim == 128) { warp_scheduler_barrier_sync(); } + CUTLASS_PRAGMA_NO_UNROLL + for (; n_block >= n_block_min; --n_block) { + Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + consumer_wait(pipeline_k, smem_pipe_read); + if constexpr (kHeadDim == 256) { warp_scheduler_barrier_sync(); } + flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + + if constexpr(Is_local) { + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{})); + Tensor tScS = threadMma0.partition_C(cS); + #pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + int row = int(get<0>(tScS(i))) / kBlockH; + if ( + int(get<1>(tScS(i))) >= col_limit_right(row, n_block) || + int(get<1>(tScS(i))) < col_limit_left(row, n_block) + ) { + tSrS(i) = -INFINITY; + } + } + } + + warp_scheduler_barrier_arrive(); + pipeline_k.consumer_release(smem_pipe_read); + + cute::copy(softmax.template max(tSrS), scores_scale); + softmax.rescale_o(tOrO, scores_scale); + softmax.template online_softmax(tSrS); + Tensor tOrP = make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout())); + permute_regs_A_to_C(tOrP); + + consumer_wait(pipeline_vt, smem_pipe_read); + if constexpr (kHeadDim == 128) { warp_scheduler_barrier_sync(); } + flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + pipeline_vt.consumer_release(smem_pipe_read); + ++smem_pipe_read; + } + if constexpr (kHeadDim == 128) { warp_scheduler_barrier_arrive(); } + } + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cute::copy(softmax.template finalize(tSrS, shared_storage.descale_v), scores_scale); + softmax.rescale_o(tOrO, scores_scale); + return; + } + +}; + +} // namespace flash diff --git a/candle-flash-attn-v3/hkernel/named_barrier.hpp 
b/candle-flash-attn-v3/hkernel/named_barrier.hpp new file mode 100644 index 0000000000..efdd0fafdc --- /dev/null +++ b/candle-flash-attn-v3/hkernel/named_barrier.hpp @@ -0,0 +1,41 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "cutlass/arch/barrier.h" + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Enumerates the reserved named barriers to avoid potential conflicts + +enum class FwdNamedBarriers { + QueryEmpty = 0, + ValueEmpty = 1, + TileCountSmemEmpty = 2, + TileCountSmemFull = 3, + WarpSchedulerWG1 = 4, + WarpSchedulerWG2 = 5, + WarpSchedulerWG3 = 6, + ProducerWG = 7 +}; + +// enum class BwdNamedBarriers { +// QueryEmpty = 0, +// KVEmpty = 1, +// TileCountSmemEmpty = 2, +// TileCountSmemFull = 3, +// // WarpSchedulerWG1 = 4, +// // WarpSchedulerWG2 = 5, +// dQEmptyWG1 = 4, +// dQEmptyWG2 = 5, +// dSFull = 6, +// // dSEmptyWG1 = 7, +// // dSEmptyWG2 = 8, +// dQEmpty = 7, +// dQFull = 8, +// }; + +} // namespace flash diff --git a/candle-flash-attn-v3/hkernel/seq_len.h b/candle-flash-attn-v3/hkernel/seq_len.h new file mode 100644 index 0000000000..5085fa16e2 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/seq_len.h @@ -0,0 +1,451 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include + +#include +#include + +namespace flash { + +static constexpr int kMaxTileSize = 128; + +template class SeqLenTraits { +public: + static_assert((!UsePagedKV_) || (UseVarSeqLen_ && UsePagedKV_), "PagedKV is only supported for VarSeqLen."); + static_assert(!(UseVarSeqLen_ && UseGQAPacking_), + "Variable sequence length with GQA parallelization not implemented yet."); + + // Total number of queries / keys. Unpadded. + int sum_s = 0; + // seq len offsets. + int *cu_seq_len = nullptr; + // actual seq len array. + int *seq_used = nullptr; + // seq len of the current batch. + int actual_seq_len = -1; + + // Whether this is for fixed-seq-len or var-seq-len. 
+ static constexpr bool UseVarSeqLen = UseVarSeqLen_; + static constexpr bool UseGQAPacking = UseGQAPacking_; + static constexpr bool UsePagedKV = UsePagedKV_; + + using ShapeT = std::conditional_t< + UseVarSeqLen, + std::conditional_t< + !UsePagedKV, + cute::Shape, + cute::Shape>, + std::conditional_t< + UseGQAPacking, + cute::Shape, + cute::Shape + > + >; + using VirtualShapeT = std::conditional_t< + UsePagedKV, + cute::Shape, + ShapeT + >; + + using StrideT = std::conditional_t< + UseVarSeqLen, + std::conditional_t< + !UsePagedKV, + cute::Shape, + cute::Shape>, + std::conditional_t< + UseGQAPacking, + cute::Shape, + cute::Shape + > + >; + using LayoutT = cute::Layout; + + using ShapeLseT = std::conditional_t< + UseVarSeqLen, + cute::Shape, + cute::Shape + >; + using StrideLseT = std::conditional_t< + UseVarSeqLen, + cute::Shape, + cute::Shape + >; + using LayoutLseT = cute::Layout; + + // Not used for varseqlen + using ShapeOAccumT = std::conditional_t< + UseGQAPacking, + cute::Shape, + cute::Shape + >; + using StrideOAccumT = std::conditional_t< + UseGQAPacking, + cute::Shape, + cute::Shape + >; + using LayoutOAccumT = cute::Layout; + + using ShapeLseAccumT = cute::Shape; + using StrideLseAccumT = cute::Shape; + using LayoutLseAccumT = cute::Layout; + + CUTLASS_HOST SeqLenTraits() {} + + CUTLASS_HOST SeqLenTraits( + int sum_s, int max_seq_len, int *cu_seq_len = nullptr, int *seq_used = nullptr): + sum_s(sum_s), cu_seq_len(cu_seq_len), seq_used(seq_used), actual_seq_len(max_seq_len) {} + + CUTLASS_DEVICE void init(int bidb) { + // TODO: add leftpad, seqlen_new for kv cache support + if (seq_used) { + actual_seq_len = seq_used[bidb]; + } + } + + CUTLASS_DEVICE void init_no_guard(int bidb) { + actual_seq_len = seq_used[bidb]; + } + + // Returns the layout of a tensor in MKHB format in global memory. + // padded: only useful for var-seq-len for dq_accum and softmax_d. + CUTLASS_HOST_DEVICE auto get_gmem_layout( + int m, int k, int h, int b, + int64_t m_stride, int64_t h_stride, int64_t b_stride, + int page_block_size, int num_blocks, + bool padded = false) const { + static_assert(!UseVarSeqLen, "Specialize default implementation for VarSeqLen."); + // static_assert(!UseGQAPacking, "Specialize default implementation for UseGQAPacking."); + return make_layout(make_shape(m, k, h, b), + make_stride(m_stride, cute::_1{}, h_stride, b_stride)); + } + + + // Returns the layout of a tensor in MKHB format in virtual memory space + // that is mapped to the global memory via the block table when paged attention is used + CUTLASS_HOST_DEVICE VirtualShapeT get_virtual_shape( + int m, int k, int h_k, int b, int h_h_k_ratio, bool padded) const { + return make_shape(m, k, h_k, b); + } + + // Returns the layout of a tensor in MKHB format in global memory. + // padded: only useful for var-seq-len for dq_accum and softmax_d. + // Overload that separates h into h_k and h/h_k. + CUTLASS_HOST_DEVICE auto get_gmem_layout( + int m, int k, int h_k, int b, int h_h_k_ratio, + int64_t m_stride, int64_t h_stride, int64_t b_stride, + bool padded = false) const { + static_assert(!UseVarSeqLen, "Specialize default implementation for VarSeqLen."); + static_assert(!UseGQAPacking, "Specialize default implementation for UseGQAPacking."); + return make_layout(make_shape(m, k, h_k * h_h_k_ratio, b), + make_stride(m_stride, cute::_1{}, h_stride, b_stride)); + } + + // Returns the layout of a tensor in MKHBT format in global memory, + // where T is number of splits. 
+ CUTLASS_HOST_DEVICE auto get_oaccum_gmem_layout( + int m, int k, int h, int b, int num_splits, + int64_t m_stride, int64_t h_stride, int64_t b_stride, int64_t split_stride, + bool padded = false) const { + return make_layout(make_shape(m, k, h, b, num_splits), + make_stride(m_stride, cute::_1{}, h_stride, b_stride, split_stride)); + } + + // Returns the layout of a tensor in MKHBT format in global memory, + // where T is number of splits. + // Overload that separates h into h_k and h/h_k. + CUTLASS_HOST_DEVICE auto get_oaccum_gmem_layout( + int m, int k, int h_k, int b, int h_h_k_ratio, int num_splits, + int64_t m_stride, int64_t h_stride, int64_t b_stride, int64_t split_stride, + bool padded = false) const { + return make_layout(make_shape(m, k, h_k * h_h_k_ratio, b, num_splits), + make_stride(m_stride, cute::_1{}, h_stride, b_stride, split_stride)); + } + + // Returns the layout of lse tensor in BHM format in global memory. + // padded: only useful for var-seq-len for dq_accum and softmax_d. + CUTLASS_HOST_DEVICE auto get_lse_gmem_layout( + int m, int h, int b, bool padded = false) const { + static_assert(!UseVarSeqLen, "Specialize default implementation for VarSeqLen."); + return make_layout(make_shape(b, h, m), + make_stride(int64_t(h * m), int64_t(m), cute::_1())); + } + + // Returns the layout of lse tensor in TBHM format in global memory, + // where T is number of splits. + CUTLASS_HOST_DEVICE auto get_lseaccum_gmem_layout( + int m, int h, int b, int num_splits, bool padded = false) const { + return make_layout(make_shape(num_splits, b, h, m), + make_stride(int64_t(b * h * m), int64_t(h * m), int64_t(m), cute::_1())); + } + + template + CUTLASS_DEVICE auto get_local_tile_tensor( + const MTensor &m_tensor, const Shape &tile_shape, + int bidh, int bidb, bool padded = false) const { + auto g_tensor = local_tile( + m_tensor(_, _, bidh, bidb), tile_shape, make_coord(_, _0{})); + return g_tensor; + } + + template + CUTLASS_DEVICE auto get_lse_local_tile_tensor( + const MTensor &m_tensor, const Shape &tile_shape, + int bidh, int bidb, int n_split_idx, bool padded = false) const { + // m_tensor has shape (B, H, M) or (splits, B, H, M) + // Expect tile shape (bM) + // Returns g_tensor of shape = (bM, ceil_div(M,bM)) + if constexpr(!Is_split) { + auto g_tensor = local_tile(m_tensor(bidb, bidh, _), tile_shape, make_coord(_)); + return g_tensor; + } else { + auto g_tensor = local_tile(m_tensor(n_split_idx, bidb, bidh, _), tile_shape, make_coord(_)); + return g_tensor; + } + } + + template + CUTLASS_DEVICE auto get_o_local_tile_tensor( + const MTensor &m_tensor, const Shape &tile_shape, + int bidh, int bidb, int split_idx, bool padded = false) const { + // static_assert(!UseVarSeqLen, "Don't use get_o_local_tile_tensor with VarSeqLen."); + // m_tensor has shape (M, K, H, B) or (M, K, H, B, splits) + // Expect tile shape (bM, K) + // Returns g_tensor of shape = (bM, K, ceil_div(M,bM)) + if constexpr(!Is_split) { + auto g_tensor = local_tile( + m_tensor(_, _, bidh, bidb), tile_shape, make_coord(_, _0{})); + return g_tensor; + } else { + auto g_tensor = local_tile( + m_tensor(_, _, bidh, bidb, split_idx), tile_shape, make_coord(_, _0{})); + return g_tensor; + } + } + +}; + +using FixedSeqLenTraits = SeqLenTraits; +using VarSeqLenTraits = SeqLenTraits; +using PagedSeqLenTraits = SeqLenTraits; +using FixedGQASeqLenTraits = SeqLenTraits; + +template <> +CUTLASS_DEVICE void VarSeqLenTraits::init(int bidb) { + actual_seq_len = + seq_used ? 
seq_used[bidb] : (cu_seq_len[bidb + 1] - cu_seq_len[bidb]); +} + +template <> +CUTLASS_DEVICE void FixedGQASeqLenTraits::init(int bidb) { + // no op +} + +// Returns the static layout of a var-seq-len tensor in global memory based on +// max_seq_len and max_batch_size. +// padded: only useful for var-seq-len for dq_accum and softmax_d. +// When padded is True, use B_M + kMaxTileSize * B as the total B_M. +template <> +CUTLASS_HOST_DEVICE auto VarSeqLenTraits::get_gmem_layout( + int m, int k, int h, int b, + int64_t m_stride, int64_t h_stride, int64_t b_stride, + int page_block_size, int num_blocks, + bool padded) const { + return make_layout( + make_shape(sum_s + (padded ? kMaxTileSize * b : 0), k, h), + make_stride(m_stride, cute::_1{}, h_stride)); +} + +template <> +CUTLASS_HOST_DEVICE auto VarSeqLenTraits::get_gmem_layout( + int m, int k, int h_k, int b, int h_h_k_ratio, + int64_t m_stride, int64_t h_stride, int64_t b_stride, + bool padded) const { + return make_layout( + make_shape(sum_s + (padded ? kMaxTileSize * b : 0), k, h_k * h_h_k_ratio), + make_stride(m_stride, cute::_1{}, h_stride)); +} + + +template <> + CUTLASS_HOST_DEVICE VarSeqLenTraits::VirtualShapeT VarSeqLenTraits::get_virtual_shape( + int m, int k, int h, int b, int h_h_k_ratio, + bool padded) const { + return make_shape(sum_s + (padded ? kMaxTileSize * b : 0), k, h); + } + + +// padded: only useful for var-seq-len for dq_accum and softmax_d. +// When padded is True, use B_M + kMaxTileSize * B as the total B_M. +//template <> +template <> +CUTLASS_HOST_DEVICE auto VarSeqLenTraits::get_lse_gmem_layout( + int m, int h, int b, bool padded) const { + return make_layout( + make_shape(h, sum_s + (padded ? kMaxTileSize * b : 0)), + make_stride(int64_t(sum_s + (padded ? kMaxTileSize * b : 0)), cute::_1())); +} + +template <> +template +CUTLASS_DEVICE auto VarSeqLenTraits::get_local_tile_tensor( + const MTensor &m_tensor, const Shape &tile_shape, + int bidh, int bidb, bool padded) const { + auto g_offset = local_tile( + m_tensor(_, _, bidh), + cute::make_shape(1, get<1>(tile_shape)), + make_coord(cu_seq_len[bidb] + (padded ? kMaxTileSize * bidb : 0), _0{})); + auto g_sequence = make_tensor( + g_offset.data(), + make_layout( + cute::make_shape(actual_seq_len, get<1>(tile_shape)), + g_offset.stride() + )); + auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{})); + return g_tensor; +} + +// TODO: restructure to not duplicate code +template <> +template +CUTLASS_DEVICE auto VarSeqLenTraits::get_o_local_tile_tensor( + const MTensor &m_tensor, const Shape &tile_shape, + int bidh, int bidb, int n_split_idx, bool padded) const { + static_assert(!Is_split, "Don't currently support split kv kernel with VarSeqLenTraits"); + auto g_offset = local_tile( + m_tensor(_, _, bidh), + cute::make_shape(1, get<1>(tile_shape)), + make_coord(cu_seq_len[bidb] + (padded ? 
kMaxTileSize * bidb : 0), _0{})); + auto g_sequence = make_tensor( + g_offset.data(), + make_layout( + cute::make_shape(actual_seq_len, get<1>(tile_shape)), + g_offset.stride() + )); + auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{})); + return g_tensor; +} + + +template <> +template +CUTLASS_DEVICE auto VarSeqLenTraits::get_lse_local_tile_tensor( + const MTensor &m_tensor, const Shape &tile_shape, + int bidh, int bidb, int n_split_idx, bool padded) const { + static_assert(!Is_split, "Don't currently support split kv kernel with VarSeqLenTraits"); + auto g_offset = local_tile( + m_tensor(bidh, _), cute::make_shape(_1{}), + make_coord(cu_seq_len[bidb] + (padded ? kMaxTileSize * bidb : 0))); + auto g_sequence = make_tensor( + g_offset.data(), + make_layout(cute::make_shape(actual_seq_len), cute::make_shape(_1{}))); + auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_)); + return g_tensor; +} + +// Returns layout of QO tensor in (M,H/HK,K,HK,B) format in global memory. +template <> +CUTLASS_HOST_DEVICE auto FixedGQASeqLenTraits::get_gmem_layout( + int m, int k, int h_k, int b, int h_h_k_ratio, + int64_t m_stride, int64_t h_stride, int64_t b_stride, bool padded) const { + return make_layout(make_shape(m, h_h_k_ratio, k, h_k, b), + make_stride(m_stride, h_stride, cute::_1{}, + h_stride * h_h_k_ratio, b_stride)); +} + +template <> + CUTLASS_HOST_DEVICE FixedGQASeqLenTraits::VirtualShapeT FixedGQASeqLenTraits::get_virtual_shape( + int m, int k, int h_k, int b, int h_h_k_ratio, + bool padded) const { + return make_shape(m, h_h_k_ratio, k, h_k, b); + } + + +// Returns layout of Oaccum tensor in (M,H/HK,K,HK,B,T) format in global memory. +template <> +CUTLASS_HOST_DEVICE auto FixedGQASeqLenTraits::get_oaccum_gmem_layout( + int m, int k, int h_k, int b, int h_h_k_ratio, int num_splits, + int64_t m_stride, int64_t h_stride, int64_t b_stride, int64_t split_stride, + bool padded) const { + return make_layout(make_shape(m, h_h_k_ratio, k, h_k, b, num_splits), + make_stride(m_stride, h_stride, cute::_1{}, + h_stride * h_h_k_ratio, b_stride, + split_stride)); +} + +template <> +template +CUTLASS_DEVICE auto FixedGQASeqLenTraits::get_local_tile_tensor( + const MTensor &m_tensor, const Shape &tile_shape, + int bidh_kv, int bidb, bool padded) const { + // m_tensor has shape (M, H/H_K, K, H_K, B) + // Expect tile_shape (bM/bH, bH, K) + // Returns g_tensor of shape (bM/bH, bH, K, ceil_div(M,bM/bH), ceil_div(H/H_K,bH)) + auto g_tensor = local_tile( + m_tensor(_, _, _, bidh_kv, bidb), tile_shape, make_coord(_, _, _0{})); + return g_tensor; +} + +template <> +template +CUTLASS_DEVICE auto FixedGQASeqLenTraits::get_o_local_tile_tensor( + const MTensor &m_tensor, const Shape &tile_shape, + int bidh_kv, int bidb, int split_idx, bool padded) const { + // m_tensor has shape (M, H/H_K, K, H_K, B) or (M, H/H_K, K, H_K, B, splits) + // Expect tile_shape (bM/bH, bH, K) + // Returns g_tensor of shape (bM/bH, bH, K, ceil_div(M,bM/bH), ceil_div(H/H_K,bH)) + if constexpr(!Is_split) { + auto g_tensor = local_tile( + m_tensor(_, _, _, bidh_kv, bidb), tile_shape, make_coord(_, _, _0{})); + return g_tensor; + } else { + auto g_tensor = local_tile( + m_tensor(_, _, _, bidh_kv, bidb, split_idx), tile_shape, make_coord(_, _, _0{})); + return g_tensor; + } +} + +/////////////// PagedSeqLenTraits ///////////////// + + // Returns the layout of a tensor in MKHB format in global memory. + // padded: only useful for var-seq-len for dq_accum and softmax_d. 
+template<> +CUTLASS_HOST_DEVICE auto PagedSeqLenTraits::get_gmem_layout( + int m, int k, int h, int b, + int64_t m_stride, int64_t h_stride, int64_t b_stride, + int page_block_size, int num_blocks, + bool padded) const { + return static_cast(make_layout(make_shape((int)page_block_size, k, h, (int)num_blocks), + make_stride(m_stride, cute::_1{}, h_stride, b_stride))); +} + +template <> +CUTLASS_DEVICE void PagedSeqLenTraits::init(int bidb) { + actual_seq_len = + seq_used ? seq_used[bidb] : (cu_seq_len[bidb + 1] - cu_seq_len[bidb]); +} + +template <> +template +CUTLASS_DEVICE auto PagedSeqLenTraits::get_local_tile_tensor( + const MTensor &m_tensor, const Shape &tile_shape, + int bidh, int bidb, bool padded) const { + + auto g_slice = m_tensor(_, _, bidh, bidb); // = m_tensor[:,:, head_idx, batch_idx] + auto g_seq_slice = make_tensor( // m_tensor[:actual_seq_len,:, head_idx, batch_idx] + g_slice.data(), + make_layout(cute::make_shape(actual_seq_len, get<1>(g_slice.layout().shape())), g_slice.layout().stride())); + // slice up into tiles + auto g_tensor = local_tile( + g_seq_slice, tile_shape, make_coord(_, _0{})); + return g_tensor; + } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/candle-flash-attn-v3/hkernel/softmax.h b/candle-flash-attn-v3/hkernel/softmax.h new file mode 100644 index 0000000000..1125cb33b0 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/softmax.h @@ -0,0 +1,235 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include + +#include + +#include "utils.h" + +#include "cutlass/fast_math.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init ? 
tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); + #pragma unroll + for (int i = 0; i < size(dst); i++){ + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){ + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ + SumOp sum_op; + thread_reduce_(tensor, sum, sum_op); + if constexpr (warp_reduce) { quad_allreduce_(sum, sum, sum_op); } +} + +__forceinline__ __device__ __half2 half_exp(__half2 x) { + uint32_t tmp_out, tmp_in; + tmp_in = reinterpret_cast(x); + asm ("ex2.approx.f16x2 %0, %1;\n" + : "=r"(tmp_out) + : "r"(tmp_in)); + __half2 out = reinterpret_cast<__half2&>(tmp_out); + return out; +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + MaxOp max_op; + max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + max(mi) = max_op(max(mi), tensor(mi, ni)); + } + max(mi) = Allreduce<4>::run(max(mi), max_op); + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + sum(mi) = 0; + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + sum(mi) += tensor(mi, ni); + } + } +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { + constexpr static float max_offset = Use_max_offset ? 8.0f : 0.0f; + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + // If we don't have float around M_LOG2E the multiplication is done in fp64. + const float max_scaled = Check_inf + ? (max(mi) == -INFINITY ? 0.f : (!Scale_max ? max(mi) : max(mi) * scale) - max_offset) + : (!Scale_max ? 
max(mi) : max(mi) * scale) - max_offset; + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax { + constexpr static bool Use_max_offset = Use_max_offset_; + // constexpr static float max_offset = Use_max_offset ? 8.0f : 0.0f; + // constexpr static float max_offset_E = max_offset * float(M_LN2); + + using TensorT = decltype(make_tensor(Shape>{})); + TensorT row_max, row_sum; + const float softmax_scale_log2; + + CUTLASS_DEVICE Softmax(float scale_ = 1.f) : softmax_scale_log2(scale_) {}; + + template + __forceinline__ __device__ TensorT max(Tensor0 &acc_s) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + TensorT scores_scale; + if constexpr (Is_first) { + flash::template reduce_max(scores, row_max); + cute::fill(scores_scale, 1.f); + } else { + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + flash::template reduce_max(scores, row_max); + #pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = !Check_inf + ? row_max(mi) + : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); + scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + row_sum(mi) *= scores_scale(mi); + } + } + return scores_scale; + }; + + template + __forceinline__ __device__ TensorT online_softmax(Tensor0 &acc_s) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + TensorT scores_scale; + if constexpr (Is_first) { + flash::template reduce_max(scores, row_max); + flash::template scale_apply_exp2(scores, row_max, softmax_scale_log2); + flash::reduce_sum(scores, row_sum); + cute::fill(scores_scale, 1.f); + // if (cute::thread0()) { print_tensor(scores); printf("\n scale = %f\n", softmax_scale_log2); print_tensor(row_sum); } + } else { + // Tensor scores_max_prev = make_fragment_like(row_max); + // cute::copy(row_max, scores_max_prev); + // flash::template reduce_max(scores, row_max); + // // if (cute::thread0()) { print_tensor(scores); printf("\n"); print_tensor(row_max); printf("\n"); } + // #pragma unroll + // for (int mi = 0; mi < size(row_max); ++mi) { + // float scores_max_cur = !Check_inf + // ? row_max(mi) + // : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); + // scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + // row_sum(mi) *= scores_scale(mi); + // } + flash::template scale_apply_exp2(scores, row_max, softmax_scale_log2); + // We don't do the reduce across threads here since we don't need to use the row_sum. + // We do that reduce at the end when we need to normalize the softmax. 
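+        // (For Is_first == false, the running row max and the rescaling of the old
+        // row_sum have already been handled by max() above — see the commented-out
+        // block — so only the freshly exp2-scaled scores are accumulated here.)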
+ flash::reduce_sum(scores, row_sum); + } + return scores_scale; + }; + + template + __forceinline__ __device__ TensorT finalize(Tensor0 &acc_s, float descale_v = 1.f, float rp_dropout=1.f) { + constexpr static float max_offset_E = Use_max_offset ? 8.f * float(M_LN2) : 0.f; + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); + TensorT scores_scale; + #pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float sum = row_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 0.f : descale_v / sum; + row_sum(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : (row_max(mi) * softmax_scale_log2) * float(M_LN2) - max_offset_E + __logf(sum); + scores_scale(mi) = !Is_dropout ? inv_sum : inv_sum * rp_dropout; + } + return scores_scale; + }; + + template + __forceinline__ __device__ void rescale_o(Tensor1 &acc_o, TensorT const &scores_scale) { + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); + #pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale(mi); } + } + }; + +}; + +} // namespace flash diff --git a/candle-flash-attn-v3/hkernel/static_switch.h b/candle-flash-attn-v3/hkernel/static_switch.h new file mode 100644 index 0000000000..e85758e62c --- /dev/null +++ b/candle-flash-attn-v3/hkernel/static_switch.h @@ -0,0 +1,168 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +// + +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +// if (PRECTYPE == 3) { +// using NAME = cutlass::float_e4m3_t; +// return __VA_ARGS__(); +// } else // removed this for dropped fp8 support +#define PREC_SWITCH(PRECTYPE, NAME, ...) \ + [&] { \ + if (PRECTYPE == 2) { \ + using NAME = cutlass::bfloat16_t; \ + return __VA_ARGS__(); \ + } else { \ + using NAME = cutlass::half_t; \ + return __VA_ARGS__(); \ + } \ + }() + +#define HEADDIM_SWITCH(HEADDIM, CONST_NAME, ...) \ + [&] { \ + if (HEADDIM == 64) { \ + constexpr static int CONST_NAME = 64; \ + return __VA_ARGS__(); \ + } else if (HEADDIM == 128) { \ + constexpr static int CONST_NAME = 128; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static int CONST_NAME = 256; \ + return __VA_ARGS__(); \ + } \ + }() + +#define SEQLEN_SWITCH(PARAMS, NAME, NAME_Q, ...) 
\ + [&] { \ + const bool useSeqLen = PARAMS.cu_seqlens_q; \ + const bool usePagedKV = PARAMS.page_block_size>0; \ + if (useSeqLen) { \ + if (usePagedKV) { \ + using NAME = flash::PagedSeqLenTraits; \ + using NAME_Q = flash::VarSeqLenTraits; \ + return __VA_ARGS__(); \ + } else { \ + using NAME = flash::VarSeqLenTraits; \ + using NAME_Q = flash::VarSeqLenTraits; \ + return __VA_ARGS__(); \ + } \ + } else { \ + using NAME = flash::FixedSeqLenTraits; \ + using NAME_Q = flash::FixedSeqLenTraits; \ + return __VA_ARGS__(); \ + } \ + }() + +#define SEQLEN_SWITCH_FWD(VAR_SEQ_LEN_Q, SEQ_USED_K, NAME_Q, NAME_K, ...) \ + [&] { \ + bool useVarSeqLenQ = VAR_SEQ_LEN_Q; \ + bool useSeqUsedK = SEQ_USED_K; \ + if (useVarSeqLenQ) { \ + using NAME_Q = flash::VarSeqLenTraits; \ + using NAME_K = flash::VarSeqLenTraits; \ + return __VA_ARGS__(); \ + } else if (useSeqUsedK) { \ + using NAME_Q = flash::FixedSeqLenTraits; \ + using NAME_K = flash::FixedSeqLenTraitsDynamic; \ + return __VA_ARGS__(); \ + } else { \ + using NAME_Q = flash::FixedSeqLenTraits; \ + using NAME_K = flash::FixedSeqLenTraits; \ + return __VA_ARGS__(); \ + } \ + }() + +#define QUERYHEAD_SWITCH(QUERYHEADS, CONST_NAME, ...) \ + [&] { \ + if (QUERYHEADS <= 2) { \ + constexpr static int CONST_NAME = 2; \ + return __VA_ARGS__(); \ + } else if (QUERYHEADS <= 4) { \ + constexpr static int CONST_NAME = 4; \ + return __VA_ARGS__(); \ + } else if (QUERYHEADS <= 8) { \ + constexpr static int CONST_NAME = 8; \ + return __VA_ARGS__(); \ + } else if (QUERYHEADS <= 16) { \ + constexpr static int CONST_NAME = 16; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static int CONST_NAME = 32; \ + return __VA_ARGS__(); \ + } \ + }() + +#define MMA_3WG_SWITCH(QLEN, CONST_NAME, ...) \ + [&] { \ + if (QLEN <= 64) { \ + constexpr static int CONST_NAME = 1; \ + return __VA_ARGS__(); \ + } else if (QLEN <= 128) { \ + constexpr static int CONST_NAME = 2; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static int CONST_NAME = 3; \ + return __VA_ARGS__(); \ + } \ + }() + +#define MMA_2WG_SWITCH(QLEN, CONST_NAME, ...) \ + [&] { \ + if (QLEN <= 64) { \ + constexpr static int CONST_NAME = 1; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static int CONST_NAME = 2; \ + return __VA_ARGS__(); \ + } \ + }() + +#define NUM_SPLITS_SWITCH(NUM_SPLITS, LOG_MAX_SPLITS, ...) \ + [&] { \ + if (NUM_SPLITS <= 2) { \ + constexpr static int LOG_MAX_SPLITS = 1; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 4) { \ + constexpr static int LOG_MAX_SPLITS = 2; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 8) { \ + constexpr static int LOG_MAX_SPLITS = 3; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 16) { \ + constexpr static int LOG_MAX_SPLITS = 4; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 32) { \ + constexpr static int LOG_MAX_SPLITS = 5; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 64) { \ + constexpr static int LOG_MAX_SPLITS = 6; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static int LOG_MAX_SPLITS = 7; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/candle-flash-attn-v3/hkernel/tile_scheduler.hpp b/candle-flash-attn-v3/hkernel/tile_scheduler.hpp new file mode 100644 index 0000000000..9375aa1e41 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/tile_scheduler.hpp @@ -0,0 +1,301 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include "cutlass/fast_math.h" +#include "cutlass/arch/barrier.h" + +#include "named_barrier.hpp" + +namespace flash { + +/////////////////////////////////////////////////////////////////////////////// + +struct SingleTileScheduler { + +public: + + // Host side kernel arguments + struct Arguments { + int const num_blocks_m, num_splits, num_head, num_batch; + int* const tile_count_semaphore = nullptr; + }; + + // Device side kernel params + struct Params {}; + + static Params + to_underlying_arguments(Arguments const& args) { + return {}; + } + + static dim3 + get_grid_dim(Arguments const& args, int num_sm) { + return {uint32_t(args.num_blocks_m), uint32_t(args.num_head), uint32_t(args.num_batch)}; + } + + struct WorkTileInfo { + int M_idx = 0; + int H_idx = 0; + int B_idx = 0; + bool is_valid_tile = false; + + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + return is_valid_tile; + } + + CUTLASS_DEVICE + cute::tuple + get_block_coord(Params const& params) const { + return {M_idx, 1, H_idx, B_idx}; + } + + }; + + CUTLASS_DEVICE + SingleTileScheduler(int* tile_count_smem_) { } + + CUTLASS_DEVICE + WorkTileInfo + get_initial_work() const { + return {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), true}; + } + + CUTLASS_DEVICE + void + init_consumer() const {} + + CUTLASS_DEVICE + void + prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} + + CUTLASS_DEVICE + void + broadcast_next_work(WorkTileInfo& current_work) const {} + + template + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params, WorkTileInfo const& current_work) const { + return {-1, -1, -1, false}; + } + +}; + +/////////////////////////////////////////////////////////////////////////////// + +template +class StaticPersistentTileScheduler { + +public: + + // Host side kernel arguments + struct Arguments { + int const num_blocks_m, num_splits, num_head, num_batch; + int* const tile_count_semaphore = nullptr; + }; + + // Device side kernel params + struct Params { + int const total_blocks; + cutlass::FastDivmod const m_block_divmod, split_divmod, head_divmod; + }; + + static Params + to_underlying_arguments(Arguments const& args) { + // return {args.num_blocks_m * args.num_head * args.num_batch, + // cutlass::FastDivmod(args.num_blocks_m), cutlass::FastDivmod(args.num_head)}; + return {args.num_blocks_m * args.num_splits * args.num_head * args.num_batch, + cutlass::FastDivmod(args.num_blocks_m), + cutlass::FastDivmod(args.num_splits), + cutlass::FastDivmod(args.num_head)}; + } + + static dim3 + get_grid_dim(Arguments const& args, int num_sm) { + return {uint32_t(num_sm)}; + } + + struct WorkTileInfo { + int tile_idx; + + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + return tile_idx < params.total_blocks; + } + + CUTLASS_DEVICE + cute::tuple + get_block_coord(Params const& params) const { + int m_block, split_idx, bidh, bidb; + if constexpr(!Is_split) { + bidb = params.head_divmod.divmod(bidh, + params.m_block_divmod.divmod(m_block, tile_idx)); + return {m_block, 1, bidh, bidb}; + } else { + bidb = params.head_divmod.divmod(bidh, + params.split_divmod.divmod(split_idx, + params.m_block_divmod.divmod(m_block, tile_idx))); + return {m_block, split_idx, bidh, bidb}; + } + } + + }; + + CUTLASS_DEVICE + StaticPersistentTileScheduler(int* tile_count_smem_) {}; + + CUTLASS_DEVICE + WorkTileInfo + get_initial_work() const { + return {int(blockIdx.x)}; + } + + 
CUTLASS_DEVICE + void + init_consumer() const {} + + CUTLASS_DEVICE + void + prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} + + CUTLASS_DEVICE + void + broadcast_next_work(WorkTileInfo& current_work) const {} + + template + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params, WorkTileInfo const& current_work) const { + return {current_work.tile_idx + int(gridDim.x)}; + } + +}; + +template +class DynamicPersistentTileScheduler { + +protected: + int* const tile_count_smem; + +public: + + // Host side kernel arguments + struct Arguments { + int const num_blocks_m, num_splits, num_head, num_batch; + int* const tile_count_semaphore; + }; + + // Device side kernel params + struct Params { + int const total_blocks; + cutlass::FastDivmod const m_block_divmod, split_divmod, head_divmod; + int* const tile_count_semaphore; + }; + + static Params + to_underlying_arguments(Arguments const& args) { + // return {args.num_blocks_m * args.num_head * args.num_batch, + // cutlass::FastDivmod(args.num_blocks_m), cutlass::FastDivmod(args.num_head), + // args.tile_count_semaphore}; + return {args.num_blocks_m * args.num_splits * args.num_head * args.num_batch, + cutlass::FastDivmod(args.num_blocks_m), + cutlass::FastDivmod(args.num_splits), + cutlass::FastDivmod(args.num_head), + args.tile_count_semaphore}; + } + + static dim3 + get_grid_dim(Arguments const& args, int num_sm) { + return {uint32_t(num_sm)}; + } + + struct WorkTileInfo { + int tile_idx; + + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + return tile_idx < params.total_blocks; + } + + CUTLASS_DEVICE + cute::tuple + get_block_coord(Params const& params) const { + int m_block, split_idx, bidh, bidb; + if constexpr(!Is_split) { + bidb = params.head_divmod.divmod(bidh, + params.m_block_divmod.divmod(m_block, tile_idx)); + return {m_block, 1, bidh, bidb}; + } else { + bidb = params.head_divmod.divmod(bidh, + params.split_divmod.divmod(split_idx, + params.m_block_divmod.divmod(m_block, tile_idx))); + return {m_block, split_idx, bidh, bidb}; + } + } + + }; + + CUTLASS_DEVICE + DynamicPersistentTileScheduler(int* tile_count_smem_) : tile_count_smem(tile_count_smem_) {}; + + CUTLASS_DEVICE + WorkTileInfo + get_initial_work() const { + return {int(blockIdx.x)}; + } + + CUTLASS_DEVICE + void + init_consumer() const { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + } + + CUTLASS_DEVICE + void + prefetch_next_work(Params const& params, WorkTileInfo& current_work) const { + if (threadIdx.x % NumProducerThreads == 0) { + current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x); + } + } + + CUTLASS_DEVICE + void + broadcast_next_work(WorkTileInfo& current_work) const { + cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + if (threadIdx.x % NumProducerThreads == 0) { + *tile_count_smem = current_work.tile_idx; + } + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + } + + template + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params, WorkTileInfo const& current_work) const { + if constexpr (IsProducer && NumProducerThreads == cutlass::NumThreadsPerWarp) { + // thread 0 already has the right tile_idx, just need to broadcast to the rest of the producer threads (warp 0) + return {__shfl_sync(0xffffffff, 
current_work.tile_idx, 0 /*lane*/)}; + } else if constexpr (IsProducer && NumProducerThreads == cutlass::NumThreadsPerWarpGroup) { + // TODO: investigate optimal synchronize + int tile_idx = *tile_count_smem; + return {tile_idx}; + } else { + cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + int tile_idx = *tile_count_smem; + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + return {tile_idx}; + } + } + +}; + +} // namespace flash diff --git a/candle-flash-attn-v3/hkernel/utils.h b/candle-flash-attn-v3/hkernel/utils.h new file mode 100644 index 0000000000..c27524c056 --- /dev/null +++ b/candle-flash-attn-v3/hkernel/utils.h @@ -0,0 +1,448 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + +#include +#include // For cute::elect_one_sync() + +#include +#include +#include +#include + +#define CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ + } while(0) + +#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) + + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? 
x : y; } +}; + +template <> +struct MaxOp { +// This is slightly faster +__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ __forceinline__ T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Allreduce<2> { +template +static __device__ __forceinline__ T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { + if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = acc_layout; + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l))); + } else { // SM80 + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// For SM90, convert acc_layout from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_transposed_rowcol(Layout acc_layout) { + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = acc_layout; + return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l))); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. 
+// For SM90, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) { + using X = Underscore; + if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + static_assert(decltype(rank(get<0>(acc_layout)))::value == 3); + auto l = logical_divide(get<0>(acc_layout), Shape{}); // (2, 2, (2, N / 16))) + return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), get<1>(acc_layout), make_layout(get<2, 1>(l), get<2>(acc_layout))); + } else { // SM80 + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + if constexpr (mma_shape_K == 8) { + return acc_layout; + } else { + auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); + } + } +}; + +// Convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_Aregs_fp8(Layout acc_layout) { + using X = Underscore; + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + static_assert(decltype(rank(get<0>(acc_layout)))::value == 3); + auto l = logical_divide(get<0>(acc_layout), Shape{}); // (2, 2, (2, N / 32))) + return make_layout(make_layout(Shape<_4, _2, _2>{}), + get<1>(acc_layout), + make_layout(get<2, 1>(l), get<2>(acc_layout))); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Byte permute for fp8 kernel +template +CUTLASS_DEVICE void permute_regs_A_to_C(Fragment &accum) { + + auto data = accum.data(); + + #pragma unroll + for (int n = 0; n < size(accum); n += 8) { + uint32_t *data_32bit = reinterpret_cast(&data[n]); + auto upper = data_32bit[0]; + auto lower = data_32bit[1]; + data_32bit[0] = __byte_perm(upper, lower, 0x5410); + data_32bit[1] = __byte_perm(upper, lower, 0x7632); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ auto convert_type(Tensor const &tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast *>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); + // Tensor out = make_tensor_like(tensor); + // cute::copy(make_tensor(make_rmem_ptr(&frag), tensor.layout()), out); + // return out; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) { + constexpr bool Is_RS = !cute::is_base_of::value; + // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const + if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); } 
+ warpgroup_fence_operand(tCrC); + if constexpr (arrive) { + warpgroup_arrive(); + } + if constexpr (zero_init) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } else { + // cute::gemm(tiled_mma, tCrA, tCrB, tCrC); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } + if constexpr (commit) { + warpgroup_commit_batch(); + } + if constexpr (wg_wait >= 0) { warpgroup_wait(); } + warpgroup_fence_operand(tCrC); + if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor const &S, + Tensor &D, Tensor const &identity_MN, + Tensor const &predicate_K, const int max_MN=0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void write_tma( + ElemO* O, const TMACopyO& tma_store_O, + const LayoutO& layout_O, const TileShapeO& tile_shape_O, + const SMemO& sO, int m_block, int bidh, int bidb, int n_split_idx, + const SeqLenTraits& seqlen_traits_o, int write_warp_idx) { + Tensor mO = tma_store_O.get_tma_tensor(layout_O.shape()); + Tensor gO = seqlen_traits_o.get_o_local_tile_tensor( + mO, tile_shape_O, bidh, bidb, n_split_idx + )(_, _, m_block); // (M, K) + auto block_tma_O = tma_store_O.get_slice(_0{}); + Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K) + Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K) + + int const lane_predicate = cute::elect_one_sync(); + int const warp_idx = cutlass::canonical_warp_idx_sync(); + if (warp_idx == write_warp_idx && lane_predicate) { + cute::copy(tma_store_O, tOsO, tOgO); + tma_store_arrive(); + } + // Note: no wait here. + // tma_store_wait<0>(); +} + +// Epilogue that copies RMEM -> GMEM directly for GQA enabled. 
+// Reports as uncoalesced stores by the profiler +template +__forceinline__ __device__ void write_rmem_to_gmem( + TensorO &tOrO, OutputType *O, const LayoutO& layout_O, TileShapeO tile_shape_O, + int m_block, int h_block, int bidh, int bidh_kv, int bidb, int n_split_idx, + TiledMma& tiled_mma, const SeqLenTraits& seqlen_traits_o, int thread_idx) { + static_assert(is_same_v, "rmem dtype must be float"); + Tensor mO = make_tensor(make_gmem_ptr(O), layout_O); + Tensor gO = [&] { + if constexpr(Use_gqa_layout) { + return seqlen_traits_o.get_o_local_tile_tensor( + mO, tile_shape_O, bidh_kv, bidb, n_split_idx + )(_, _, _, m_block, h_block); // (bM/bH, bH, K) + } else { + return seqlen_traits_o.get_o_local_tile_tensor( + mO, tile_shape_O, bidh, bidb, n_split_idx + )(_, _, m_block); // (bM, bK) + } + }(); + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + auto tile_shape_mnk = cute::tile_shape(tiled_mma); + Tensor cO = cute::make_identity_tensor(select<0, 1>(tile_shape_mnk)); + Tensor tOcO = thread_mma.partition_C(cO); + // tOcO has shape ((2, 2, V), MMA_M, MMA_N), we only take only the row indices. + Tensor tOcO_row = tOcO(make_coord(_0{}, _, _0{}), _, _0{}); + // reshape from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout())); + const int m_bound = seqlen_traits_o.actual_seq_len - m_block * size<0>(gO); + // hardcoded col_idx to circumvent reg spilling with counting tensor + const int col_start_idx = !Column_permute_fp8 ? 2 * (thread_idx % 4) : 4 * (thread_idx % 4); + + if constexpr (Use_gqa_layout) { + static constexpr int kBlockH = size<1>(gO); + const int h_bound = shape<1>(layout_O) - h_block * kBlockH; + #pragma unroll + for(int nrow = 0; nrow < size<0>(tOrO_rowcol); ++nrow) { + const int row = int(get<0>(tOcO_row(nrow))); + const int h_local = row % kBlockH; + const int m_local = row / kBlockH; + if(h_local < h_bound && m_local < m_bound) { + if constexpr(!Column_permute_fp8) { + Tensor tOrO_nrow_float2 = recast(tOrO_rowcol(nrow, _)); + #pragma unroll + for (int ncol = 0; ncol < size<1>(tOrO_rowcol)/2; ++ncol) { + *reinterpret_cast(&(gO(m_local, h_local, col_start_idx + 8 * ncol))) = + tOrO_nrow_float2(ncol); + } + } else { + Tensor tOrO_nrow = tOrO_rowcol(nrow, _); + #pragma unroll + for (int ncol = 0; ncol < size<1>(tOrO_rowcol); ncol += 4) { + gO(m_local, h_local, col_start_idx + 4 * ncol) = tOrO_nrow(ncol); + gO(m_local, h_local, col_start_idx + 4 * ncol + 2) = tOrO_nrow(ncol + 1); + gO(m_local, h_local, col_start_idx + 4 * ncol + 1) = tOrO_nrow(ncol + 2); + gO(m_local, h_local, col_start_idx + 4 * ncol + 3) = tOrO_nrow(ncol + 3); + } + } + } + } + } else { + #pragma unroll + for(int nrow = 0; nrow < size<0>(tOrO_rowcol); ++nrow) { + const int row = int(get<0>(tOcO_row(nrow))); + if(row < m_bound) { + if constexpr(!Column_permute_fp8) { + Tensor tOrO_nrow_float2 = recast(tOrO_rowcol(nrow, _)); + #pragma unroll + for (int ncol = 0; ncol < size<1>(tOrO_rowcol)/2; ++ncol) { + *reinterpret_cast(&(gO(row, col_start_idx + 8 * ncol))) = + tOrO_nrow_float2(ncol); + } + } else { + Tensor tOrO_nrow = tOrO_rowcol(nrow, _); + #pragma unroll + for (int ncol = 0; ncol < size<1>(tOrO_rowcol); ncol += 4) { + gO(row, col_start_idx + 4 * ncol) = tOrO_nrow(ncol); + gO(row, col_start_idx + 4 * ncol + 2) = tOrO_nrow(ncol + 1); + gO(row, col_start_idx + 4 * ncol + 1) = tOrO_nrow(ncol + 2); + gO(row, col_start_idx + 4 * ncol + 3) = tOrO_nrow(ncol + 3); + } + } + } + } + } +} + 
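+// Minimal illustrative sketch (not used by the kernels above): the
+// Column_permute_fp8 path of write_rmem_to_gmem stores each thread's four
+// consecutive register columns back to global memory in the order (0, 2, 1, 3)
+// within every group of four output columns, i.e. columns 1 and 2 swap while
+// 0 and 3 stay in place. The helper below just spells out that index mapping.
+__host__ __device__ constexpr int column_permute_fp8_offset(int reg_col) {
+    constexpr int map[4] = {0, 2, 1, 3};  // register column -> output column offset
+    return map[reg_col & 3];
+}
+static_assert(column_permute_fp8_offset(1) == 2 && column_permute_fp8_offset(2) == 1,
+              "fp8 epilogue swaps the middle two columns of each group of four");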
+template +__forceinline__ __device__ void write_tiled( + ElemO* O, const TiledCopyO& tiled_copy_O, + const LayoutO& layout_O, const TileShapeO& tile_shape_O, + const SMemO& sO, int m_block, int bidh, int bidb, + const SeqLenTraits& seqlen_traits_o) { + Tensor mO = make_tensor(make_gmem_ptr(O), layout_O); + Tensor gO = seqlen_traits_o.get_local_tile_tensor( + mO, tile_shape_O, bidh, bidb + )(_, _, m_block); // (M, K) + + ThrCopy thr_copy_O = tiled_copy_O.get_slice(threadIdx.x - NumCopyThreads); + Tensor tOgO = thr_copy_O.partition_D(gO); // (CPY,CPY_M,CPY_K,k) + Tensor tOsO = thr_copy_O.partition_S(sO); // (CPY,CPY_M,CPY_K) + + // Prepare for TiledCopy. + // Grouping is needed because cute::copy_if() does group_modes<1, R> for src and dst. + // After grouping, the first dim is number of elements to read together. + Tensor tOsOFlatten = cute::flatten(tOsO); + Tensor tOsOGroup = cute::group_modes<1, rank(tOsOFlatten)>(tOsOFlatten); + Tensor tOgOFlatten = cute::flatten(tOgO); + Tensor tOgOGroup = cute::group_modes<1, rank(tOgOFlatten)>(tOgOFlatten); + + // Get thread coords to global index mapping. + Tensor gOCounting = cute::make_identity_tensor(gO.shape()); + Tensor tSgOCounting = thr_copy_O.partition_D(gOCounting); + Tensor tSgOCountingFlatten = cute::flatten(tSgOCounting); + Tensor tSgOCountingGrouped = + cute::group_modes<1, rank(tSgOCountingFlatten)>(tSgOCountingFlatten); + + // Write out to GMEM. + const int kNumMsPerTile = get<0>(tile_shape_O); + int cta_m = std::min( + seqlen_traits_o.actual_seq_len - m_block * kNumMsPerTile, kNumMsPerTile + ); + if (cta_m == kNumMsPerTile) { + copy(tiled_copy_O, tOsOGroup, tOgOGroup); + } else { + auto predicate_fn = [&](auto coords) { + auto s_coords = tSgOCountingGrouped(_0{}, coords); + return elem_less(get<0>(s_coords), cta_m); + }; + copy_if(tiled_copy_O, predicate_fn, tOsOGroup, tOgOGroup); + } +} + +template +__forceinline__ __device__ void write_O( + ElemO* O, const TMACopyO& tma_copy_O, const TiledCopyO& tiled_copy_O, + const LayoutO& layout_O, const TileShapeO& tile_shape_O, + const SMemO& sO, int m_block, int bidh, int bidb, int n_split_idx, + const SeqLenTraits& seqlen_traits_o, int write_warp_idx, TiledMma & tiledMma1, TensorO & tOrO) { + + if constexpr (IsRegToGmem) { + static_assert(Is_split, "use write_rmem_to_gmem with split kv kernel only"); + write_rmem_to_gmem(tOrO, O, layout_O, tile_shape_O, m_block, bidh, bidb, n_split_idx, + tiledMma1, seqlen_traits_o, threadIdx.x - NumCopyThreads); + } else if constexpr (IsTMACopy) { + write_tma(O, tma_copy_O, layout_O, tile_shape_O, sO, m_block, bidh, bidb, + n_split_idx, seqlen_traits_o, write_warp_idx); + } else { + static_assert(!Is_split, "Don't use write_tiled with split kv kernel"); + write_tiled(O, tiled_copy_O, layout_O, tile_shape_O, sO, m_block, bidh, bidb, seqlen_traits_o); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/candle-flash-attn-v3/src/ffi.rs b/candle-flash-attn-v3/src/ffi.rs new file mode 100644 index 0000000000..02bf43f697 --- /dev/null +++ b/candle-flash-attn-v3/src/ffi.rs @@ -0,0 +1,55 @@ +use core::ffi::{c_int, c_void}; + +extern "C" { + pub(crate) fn run_mha( + q_ptr: *const c_void, + k_ptr: *const c_void, + v_ptr: *const c_void, + o_ptr: *const c_void, + softmax_lse_ptr: *const c_void, + alibi_slopes_ptr: *const c_void, + + cu_seqlens_q_ptr: *const i32, + cu_seqlens_k_ptr: *const i32, + + q_batch_stride: u32, + k_batch_stride: u32, + v_batch_stride: u32, + 
o_batch_stride: u32, + alibi_slopes_batch_stride: u32, + + q_row_stride: u32, + k_row_stride: u32, + v_row_stride: u32, + o_row_stride: u32, + + q_head_stride: u32, + k_head_stride: u32, + v_head_stride: u32, + o_head_stride: u32, + + b: u32, + h: u32, + h_k: u32, + d: u32, + d_rounded: u32, + softmax_scale: f32, + + seqlen_q: u32, + seqlen_k: u32, + seqlen_q_rounded: u32, + seqlen_k_rounded: u32, + + is_bf16: c_int, + is_causal: c_int, + unpadded_lse: c_int, + use_gqa_packing: c_int, + + window_size_left: c_int, + window_size_right: c_int, + + total_q: u32, + total_k: u32, + ); + +} diff --git a/candle-flash-attn-v3/src/lib.rs b/candle-flash-attn-v3/src/lib.rs new file mode 100644 index 0000000000..91b7c63f30 --- /dev/null +++ b/candle-flash-attn-v3/src/lib.rs @@ -0,0 +1,918 @@ +mod ffi; + +use candle::backend::BackendStorage; +use candle::cuda_backend::cudarc::driver::DevicePtr; +use candle::cuda_backend::WrapErr; +use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor}; +use half::{bf16, f16}; + +fn round_multiple(x: usize, m: usize) -> usize { + (x + m - 1) / m * m +} + +pub struct FlashAttn { + pub softmax_scale: f32, + pub alibi_slopes: Option, + pub window_size_left: Option, + pub window_size_right: Option, + pub use_gqa_packing: bool, +} + +impl FlashAttn { + fn cuda_fwd_t< + T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr, + >( + &self, + q: &candle::CudaStorage, + q_l: &Layout, + k: &candle::CudaStorage, + k_l: &Layout, + v: &candle::CudaStorage, + v_l: &Layout, + is_bf16: bool, + ) -> Result<(candle::CudaStorage, Shape)> { + // https://github.com/Dao-AILab/flash-attention/blob/0dfb28174333d9eefb7c1dd4292690a8458d1e89/hopper/flash_api.cpp + let dev = q.device(); + let out_shape = q_l.shape().clone(); + let out_l = Layout::contiguous(&out_shape); + + let q = q.as_cuda_slice::()?; + let k = k.as_cuda_slice::()?; + let v = v.as_cuda_slice::()?; + let q = q.slice(q_l.start_offset()..); + let k = k.slice(k_l.start_offset()..); + let v = v.slice(v_l.start_offset()..); + + let q_stride = q_l.stride(); + let k_stride = k_l.stride(); + let v_stride = v_l.stride(); + let o_stride = out_l.stride(); + + let q_rank = q_stride.len(); + let k_rank = k_stride.len(); + let v_rank = v_stride.len(); + let o_rank = o_stride.len(); + + if q_rank != 4 || k_rank != 4 || v_rank != 4 { + candle::bail!( + "flash-attn expects input tensors of rank 4 (q: {q_rank}, k: {k_rank}, v: {v_rank}" + ) + } + if q_stride[q_rank - 1] != 1 { + candle::bail!("the last dim of q must be contiguous {q_stride:?}") + } + if k_stride[k_rank - 1] != 1 { + candle::bail!("the last dim of k must be contiguous {k_stride:?}") + } + if v_stride[v_rank - 1] != 1 { + candle::bail!("the last dim of v must be contiguous {v_stride:?}") + } + + let (b_sz, seqlen_q, num_heads, head_size_og) = q_l.shape().dims4()?; + let (_b_sz, seqlen_k, num_heads_k, _head_size_og) = k_l.shape().dims4()?; + let expected_kv = (b_sz, seqlen_k, num_heads_k, head_size_og); + if expected_kv != k_l.shape().dims4()? { + candle::bail!("shape mismatch q {:?} and k {:?}", q_l.shape(), k_l.shape()) + } + if expected_kv != v_l.shape().dims4()? 
{ + candle::bail!("shape mismatch q {:?} and v {:?}", q_l.shape(), v_l.shape()) + } + if head_size_og > 256 { + candle::bail!("only supports head dimension at most 256 (got {head_size_og})") + } + if !(head_size_og == 256 || head_size_og == 128 || head_size_og == 64) { + candle::bail!("only supports head dimension 64, 128 and 256 (got {head_size_og})") + } + if head_size_og % 8 != 0 { + // TODO: Handle head sizes that are not a multiple of 8 via some padding. + candle::bail!("only supports head sizes that are a multiple of 8 (got {head_size_og})") + } + if num_heads % num_heads_k != 0 { + candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}") + } + let use_gqa_packing = match num_heads_k / num_heads { + 2 | 4 | 8 | 16 | 32 => self.use_gqa_packing as i32, + _ => 0, + }; + + let stream = dev.cuda_stream(); + let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes { + if alibi_slopes.dtype() != DType::F32 { + candle::bail!( + "DType mismatch alibi_slopes {:?}, expected {:?}", + alibi_slopes.dtype(), + DType::F32 + ); + } + + let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout(); + + if num_heads != alibi_slopes_layout.shape().dims1()? { + candle::bail!( + "shape mismatch alibi_slopes {:?}, expected {:?}", + alibi_slopes_layout.shape(), + (num_heads) + ); + } + + let alibi_slopes = match &*alibi_slopes { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("alibi_slopes must be a cuda tensor"), + }; + + let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..); + + // Dropping the guard here doesn't seem very safe. + let (ptr, _guard) = alibi_slopes.device_ptr(&stream); + ptr as *const core::ffi::c_void + } else { + std::ptr::null() + }; + + // if window_size_left > self.max_seqlen_k or None => -1 + let mut window_size_left = self + .window_size_left + .filter(|v| v <= &seqlen_k) + .map(|v| v as i32) + .unwrap_or(-1); + + // if window_size_right > self.max_seqlen_k or None => -1 + let mut window_size_right = self + .window_size_right + .filter(|v| v <= &seqlen_k) + .map(|v| v as i32) + .unwrap_or(-1); + + let head_size = round_multiple(head_size_og, 8); + let head_size_rounded = round_multiple(head_size, 32); + let seqlen_q_rounded = round_multiple(seqlen_q, 128); + let seqlen_k_rounded = round_multiple(seqlen_k, 128); + + let elem_count = out_shape.elem_count(); + let dst = unsafe { dev.alloc::(elem_count) }?; + let softmax_lse = dev + .alloc_zeros::(b_sz * 128 * num_heads * seqlen_q)?; + + let is_bf16 = if is_bf16 { 1 } else { 0 }; + + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
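+        // Negative window sizes mean "unbounded" on that side; when exactly one side is
+        // unbounded it is widened to the full key length below, before calling the kernel.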
+ let is_causal = if window_size_left < 0 && window_size_right == 0 { + 1 + } else { + 0 + }; + if window_size_left < 0 && window_size_right >= 0 { + window_size_left = seqlen_k as i32; + } + if window_size_left >= 0 && window_size_right < 0 { + window_size_right = seqlen_k as i32; + } + + unsafe { + let (q_ptr, _guard) = q.device_ptr(&stream); + let (k_ptr, _guard) = k.device_ptr(&stream); + let (v_ptr, _guard) = v.device_ptr(&stream); + let (dst_ptr, _guard) = dst.device_ptr(&stream); + let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream); + ffi::run_mha( + q_ptr as *const core::ffi::c_void, + k_ptr as *const core::ffi::c_void, + v_ptr as *const core::ffi::c_void, + dst_ptr as *const core::ffi::c_void, + softmax_lse_ptr as *const core::ffi::c_void, + /* alibi_slopes_ptr */ alibi_slopes_ptr, + /* cu_seqlens_q_ptr */ std::ptr::null(), + /* cu_seqlens_k_ptr */ std::ptr::null(), + /* q_batch_stride */ q_stride[0] as u32, + /* k_batch_stride */ k_stride[0] as u32, + /* v_batch_stride */ v_stride[0] as u32, + /* o_batch_stride */ o_stride[0] as u32, + /* alibi_slopes_batch_stride */ 0, + /* q_row_stride */ q_stride[q_rank - 3] as u32, + /* k_row_stride */ k_stride[k_rank - 3] as u32, + /* v_row_stride */ v_stride[v_rank - 3] as u32, + /* o_row_stride */ o_stride[o_rank - 3] as u32, + /* q_head_stride */ q_stride[q_rank - 2] as u32, + /* k_head_stride */ k_stride[k_rank - 2] as u32, + /* v_head_stride */ v_stride[v_rank - 2] as u32, + /* o_head_stride */ o_stride[o_rank - 2] as u32, + /* b */ b_sz as u32, + /* h */ num_heads as u32, + /* h_k */ num_heads_k as u32, + /* d */ head_size as u32, + /* d_rounded */ head_size_rounded as u32, + /* softmax_scale*/ self.softmax_scale, + /* seqlen_q */ seqlen_q as u32, + /* seqlen_k */ seqlen_k as u32, + /* seqlen_q_rounded */ seqlen_q_rounded as u32, + /* seqlen_k_rounded */ seqlen_k_rounded as u32, + /* is_bf16 */ is_bf16, + /* is_causal */ is_causal, + /* unpadded_lse */ 0, + /* use_gqa_packing */ use_gqa_packing, + /* window_size_left */ window_size_left, + /* window_size_right */ window_size_right, + /* total_q, dummy */ 0u32, + /* total_k, dummy */ 0u32, + ) + } + + let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev.clone()); + Ok((dst, out_shape)) + } +} + +impl candle::CustomOp3 for FlashAttn { + fn name(&self) -> &'static str { + "flash-attn" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle::bail!("no cpu support for flash-attn") + } + + fn cuda_fwd( + &self, + q: &candle::CudaStorage, + q_l: &Layout, + k: &candle::CudaStorage, + k_l: &Layout, + v: &candle::CudaStorage, + v_l: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + match q.dtype() { + candle::DType::F16 => self.cuda_fwd_t::(q, q_l, k, k_l, v, v_l, false), + candle::DType::BF16 => self.cuda_fwd_t::(q, q_l, k, k_l, v, v_l, true), + dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"), + } + } +} + +/// Flash-attention v3 layer. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. 
+/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `use_gqa_packing` - enables dedicated kernels for GQA packing if head sizes are compatible. + +/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`. +pub fn flash_attn( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + causal: bool, + use_gqa_packing: bool, +) -> Result { + let window_size_left = None; + let window_size_right = if causal { Some(0) } else { None }; + + let op = FlashAttn { + softmax_scale, + alibi_slopes: None, + window_size_left, + window_size_right, + use_gqa_packing, + }; + q.apply_op3(k, v, op) +} + +/// Flash-attention v3 layer. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `window_size_left` - Limit left attention to value tokens. +/// * `window_size_right` - Limit right attention to value tokens. +/// * `use_gqa_packing` - enables dedicated kernels for GQA packing if head sizes are compatible. +/// +/// # Causal mask +/// +/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result +/// of `Q @ K^T` +/// +/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`. +pub fn flash_attn_windowed( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + window_size_left: Option, + window_size_right: Option, + use_gqa_packing: bool, +) -> Result { + let op = FlashAttn { + softmax_scale, + alibi_slopes: None, + window_size_left, + window_size_right, + use_gqa_packing, + }; + q.apply_op3(k, v, op) +} + +/// Flash-attention v3 layer. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`. +/// * `use_gqa_packing` - enables dedicated kernels for GQA packing if head sizes are compatible. + +/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`. +pub fn flash_attn_alibi( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + softmax_scale: f32, + causal: bool, + use_gqa_packing: bool, +) -> Result { + let window_size_left = None; + let window_size_right = if causal { Some(0) } else { None }; + + let op = FlashAttn { + softmax_scale, + alibi_slopes: Some(alibi_slopes.clone()), + window_size_left, + window_size_right, + use_gqa_packing, + }; + q.apply_op3(k, v, op) +} + +/// Flash-attention v3 layer. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. 
+/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`. +/// * `window_size_left` - Limit left attention to value tokens. +/// * `window_size_right` - Limit right attention to value tokens. +/// * `use_gqa_packing` - enables dedicated kernels for GQA packing if head sizes are compatible. +/// +/// # Causal mask +/// +/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result +/// of `Q @ K^T` +/// +/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`. +pub fn flash_attn_alibi_windowed( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + softmax_scale: f32, + window_size_left: Option, + window_size_right: Option, + use_gqa_packing: bool, +) -> Result { + let op = FlashAttn { + softmax_scale, + alibi_slopes: Some(alibi_slopes.clone()), + window_size_left, + window_size_right, + use_gqa_packing, + }; + q.apply_op3(k, v, op) +} + +struct FlashAttnVarLen { + pub softmax_scale: f32, + pub max_seqlen_q: usize, + pub max_seqlen_k: usize, + pub seqlens_q: Tensor, + pub seqlens_k: Tensor, + pub alibi_slopes: Option, + pub window_size_left: Option, + pub window_size_right: Option, + pub use_gqa_packing: bool, +} + +impl FlashAttnVarLen { + fn cuda_fwd_t< + T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr, + >( + &self, + q: &candle::CudaStorage, + q_l: &Layout, + k: &candle::CudaStorage, + k_l: &Layout, + v: &candle::CudaStorage, + v_l: &Layout, + is_bf16: bool, + ) -> Result<(candle::CudaStorage, Shape)> { + // https://github.com/Dao-AILab/flash-attention/blob/0dfb28174333d9eefb7c1dd4292690a8458d1e89/hopper/flash_api.cpp + let dev = q.device(); + let out_shape = q_l.shape().clone(); + let out_l = Layout::contiguous(&out_shape); + + let (seqlens_q, seqlens_q_layout) = self.seqlens_q.storage_and_layout(); + let seqlens_q = match &*seqlens_q { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, // Should be i32! + _ => candle::bail!("seqlens_q must be a cuda tensor"), + }; + let seqlens_q = match seqlens_q_layout.contiguous_offsets() { + Some((o1, o2)) => seqlens_q.slice(o1..o2), + None => candle::bail!("seqlens_q has to be contiguous"), + }; + + let (seqlens_k, seqlens_k_layout) = self.seqlens_k.storage_and_layout(); + let seqlens_k = match &*seqlens_k { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, // Should be i32! 
+ _ => candle::bail!("seqlens_k must be a cuda tensor"), + }; + let seqlens_k = match seqlens_k_layout.contiguous_offsets() { + Some((o1, o2)) => seqlens_k.slice(o1..o2), + None => candle::bail!("seqlens_k has to be contiguous"), + }; + + let q = q.as_cuda_slice::()?; + let k = k.as_cuda_slice::()?; + let v = v.as_cuda_slice::()?; + let q = q.slice(q_l.start_offset()..); + let k = k.slice(k_l.start_offset()..); + let v = v.slice(v_l.start_offset()..); + + let q_stride = q_l.stride(); + let k_stride = k_l.stride(); + let v_stride = v_l.stride(); + let o_stride = out_l.stride(); + + let q_rank = q_stride.len(); + let k_rank = k_stride.len(); + let v_rank = v_stride.len(); + let o_rank = o_stride.len(); + + if q_rank != 3 || k_rank != 3 || v_rank != 3 { + candle::bail!( + "flash-attn-varlen expects input tensors of rank 3 (q: {q_rank}, k: {k_rank}, v: {v_rank}" + ) + } + if q_stride[q_rank - 1] != 1 { + candle::bail!("the last dim of q must be contiguous {q_stride:?}") + } + if k_stride[k_rank - 1] != 1 { + candle::bail!("the last dim of k must be contiguous {k_stride:?}") + } + if v_stride[v_rank - 1] != 1 { + candle::bail!("the last dim of v must be contiguous {v_stride:?}") + } + + let (total_q, num_heads, head_size_og) = q_l.shape().dims3()?; + let (total_k, num_heads_k, _head_size_og) = k_l.shape().dims3()?; + let expected_kv = (total_k, num_heads_k, head_size_og); + if expected_kv != k_l.shape().dims3()? { + candle::bail!("shape mismatch q {:?} and k {:?}", q_l.shape(), k_l.shape()) + } + if expected_kv != v_l.shape().dims3()? { + candle::bail!("shape mismatch q {:?} and v {:?}", q_l.shape(), v_l.shape()) + } + if head_size_og > 256 { + candle::bail!("only supports head dimension at most 256 (got {head_size_og})") + } + if !(head_size_og == 256 || head_size_og == 128 || head_size_og == 64) { + candle::bail!("only supports head dimension 64, 128 and 256 (got {head_size_og})") + } + if head_size_og % 8 != 0 { + // TODO: Handle head sizes that are not a multiple of 8 via some padding. + candle::bail!("only supports head sizes that are a multiple of 8 (got {head_size_og})") + } + if num_heads % num_heads_k != 0 { + candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}") + } + let use_gqa_packing = match num_heads_k / num_heads { + 2 | 4 | 8 | 16 | 32 => self.use_gqa_packing as i32, + _ => 0, + }; + + let nseqlens_q = seqlens_q_layout.shape().dims1()?; + if nseqlens_q < 2 { + candle::bail!("seqlens_q should have a len >= 2 {nseqlens_q}") + } + let nseqlens_k = seqlens_k_layout.shape().dims1()?; + if nseqlens_k != nseqlens_q { + candle::bail!("seqlens_q and seqlens_k should have the same number of elements {nseqlens_q} <> {nseqlens_k}") + } + + let batch_size = nseqlens_q - 1; + + let stream = dev.cuda_stream(); + let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes { + if alibi_slopes.dtype() != DType::F32 { + candle::bail!( + "DType mismatch alibi_slopes {:?}, expected {:?}", + alibi_slopes.dtype(), + DType::F32 + ); + } + + let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout(); + + if num_heads != alibi_slopes_layout.shape().dims1()? 
{ + candle::bail!( + "shape mismatch alibi_slopes {:?}, expected {:?}", + alibi_slopes_layout.shape(), + (num_heads) + ); + } + + let alibi_slopes = match &*alibi_slopes { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("alibi_slopes must be a cuda tensor"), + }; + + let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..); + + // Dropping the guard here doesn't seem very safe. + let (ptr, _guard) = alibi_slopes.device_ptr(&stream); + ptr as *const core::ffi::c_void + } else { + std::ptr::null() + }; + + // if window_size_left > self.max_seqlen_k or None => -1 + let mut window_size_left = self + .window_size_left + .filter(|v| v <= &self.max_seqlen_k) + .map(|v| v as i32) + .unwrap_or(-1); + if window_size_left < self.max_seqlen_k as i32 { + window_size_left = self.max_seqlen_k.clone() as i32; + } + + // if window_size_right > self.max_seqlen_k or None => -1 + let mut window_size_right = self + .window_size_right + .filter(|v| v <= &self.max_seqlen_k) + .map(|v| v as i32) + .unwrap_or(-1); + if window_size_right < self.max_seqlen_k as i32 { + window_size_right = self.max_seqlen_k.clone() as i32; + } + + let head_size = round_multiple(head_size_og, 8); + let head_size_rounded = round_multiple(head_size, 32); + let seqlen_q_rounded = round_multiple(self.max_seqlen_q, 128); + let seqlen_k_rounded = round_multiple(self.max_seqlen_k, 128); + + let elem_count = out_shape.elem_count(); + let dst = unsafe { dev.alloc::(elem_count) }?; + let softmax_lse = dev.alloc_zeros::(num_heads * total_q)?; + + let is_bf16 = if is_bf16 { 1 } else { 0 }; + + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
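+        // In the varlen path the batch is packed: batch strides are passed as 0, the
+        // cu_seqlens buffers drive the indexing, and the LSE buffer is unpadded (one
+        // entry per head and token), hence `unpadded_lse = 1` in the FFI call below.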
+ let is_causal = if window_size_left < 0 && window_size_right == 0 { + 1 + } else { + 0 + }; + if window_size_left < 0 && window_size_right >= 0 { + window_size_left = self.max_seqlen_k as i32; + } + if window_size_left >= 0 && window_size_right < 0 { + window_size_right = self.max_seqlen_k as i32; + } + unsafe { + let (q_ptr, _guard) = q.device_ptr(&stream); + let (k_ptr, _guard) = k.device_ptr(&stream); + let (v_ptr, _guard) = v.device_ptr(&stream); + let (dst_ptr, _guard) = dst.device_ptr(&stream); + let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream); + let (seqlens_q_ptr, _guard) = seqlens_q.device_ptr(&stream); + let (seqlens_k_ptr, _guard) = seqlens_k.device_ptr(&stream); + ffi::run_mha( + q_ptr as *const core::ffi::c_void, + k_ptr as *const core::ffi::c_void, + v_ptr as *const core::ffi::c_void, + dst_ptr as *const core::ffi::c_void, + softmax_lse_ptr as *const core::ffi::c_void, + /* alibi_slopes_ptr */ alibi_slopes_ptr, + /* cu_seqlens_q_ptr */ seqlens_q_ptr as *const i32, + /* cu_seqlens_k_ptr */ seqlens_k_ptr as *const i32, + /* q_batch_stride */ 0, + /* k_batch_stride */ 0, + /* v_batch_stride */ 0, + /* o_batch_stride */ 0, + /* alibi_slopes_batch_stride */ 0, + /* q_row_stride */ q_stride[q_rank - 3] as u32, + /* k_row_stride */ k_stride[k_rank - 3] as u32, + /* v_row_stride */ v_stride[v_rank - 3] as u32, + /* o_row_stride */ o_stride[o_rank - 3] as u32, + /* q_head_stride */ q_stride[q_rank - 2] as u32, + /* k_head_stride */ k_stride[k_rank - 2] as u32, + /* v_head_stride */ v_stride[v_rank - 2] as u32, + /* o_head_stride */ o_stride[o_rank - 2] as u32, + /* b */ batch_size as u32, + /* h */ num_heads as u32, + /* h_k */ num_heads_k as u32, + /* d */ head_size as u32, + /* d_rounded */ head_size_rounded as u32, + /* softmax_scale*/ self.softmax_scale, + /* seqlen_q */ self.max_seqlen_q as u32, + /* seqlen_k */ self.max_seqlen_k as u32, + /* seqlen_q_rounded */ seqlen_q_rounded as u32, + /* seqlen_k_rounded */ seqlen_k_rounded as u32, + /* is_bf16 */ is_bf16, + /* is_causal */ is_causal, + /* unpadded_lse */ 1, + /* use_gqa_packing */ use_gqa_packing, + /* window_size_left */ window_size_left, + /* window_size_right */ window_size_right, + /* total_q */ total_q as u32, + /* total_k */ total_k as u32, + ) + } + + let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev.clone()); + Ok((dst, out_shape)) + } +} + +impl candle::CustomOp3 for FlashAttnVarLen { + fn name(&self) -> &'static str { + "flash-attn-varlen" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle::bail!("no cpu support for flash-attn") + } + + fn cuda_fwd( + &self, + q: &candle::CudaStorage, + q_l: &Layout, + k: &candle::CudaStorage, + k_l: &Layout, + v: &candle::CudaStorage, + v_l: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + match q.dtype() { + candle::DType::F16 => self.cuda_fwd_t::(q, q_l, k, k_l, v, v_l, false), + candle::DType::BF16 => self.cuda_fwd_t::(q, q_l, k, k_l, v, v_l, true), + dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"), + } + } +} + +#[allow(clippy::too_many_arguments)] +/// Flash-attention v3 layer with variable-length batching. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. 
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
+/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
+/// * `max_seqlen_q` - The maximum sequence length for q in the batch.
+/// * `max_seqlen_k` - The maximum sequence length for k and v in the batch.
+/// * `use_gqa_packing` - enables dedicated kernels for GQA packing if head sizes are compatible.
+///
+/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
+/// `seqlen_1 + seqlen_2`, etc.
+///
+/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
+pub fn flash_attn_varlen(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    seqlens_q: &Tensor,
+    seqlens_k: &Tensor,
+    max_seqlen_q: usize,
+    max_seqlen_k: usize,
+    softmax_scale: f32,
+    causal: bool,
+    use_gqa_packing: bool,
+) -> Result<Tensor> {
+    let window_size_left = None;
+    let window_size_right = if causal { Some(0) } else { None };
+
+    let op = FlashAttnVarLen {
+        softmax_scale,
+        max_seqlen_q,
+        max_seqlen_k,
+        seqlens_q: seqlens_q.clone(),
+        seqlens_k: seqlens_k.clone(),
+        alibi_slopes: None,
+        window_size_left,
+        window_size_right,
+        use_gqa_packing,
+    };
+    q.apply_op3(k, v, op)
+}
+
+#[allow(clippy::too_many_arguments)]
+/// Flash-attention v3 layer with variable-length batching.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
+/// than q; the number of heads in q has to be divisible by the number of heads in k and v.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
+/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
+/// * `max_seqlen_q` - The maximum sequence length for q in the batch.
+/// * `max_seqlen_k` - The maximum sequence length for k and v in the batch.
+/// * `window_size_left` - Limit left attention to value tokens.
+/// * `window_size_right` - Limit right attention to value tokens.
+/// * `use_gqa_packing` - enables dedicated kernels for GQA packing if head sizes are compatible.
+///
+/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
+/// `seqlen_1 + seqlen_2`, etc.
+///
+/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
+/// +/// # Causal mask +/// +/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result +/// of `Q @ K^T` +pub fn flash_attn_varlen_windowed( + q: &Tensor, + k: &Tensor, + v: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + window_size_left: Option, + window_size_right: Option, + use_gqa_packing: bool, +) -> Result { + let op = FlashAttnVarLen { + softmax_scale, + max_seqlen_q, + max_seqlen_k, + seqlens_q: seqlens_q.clone(), + seqlens_k: seqlens_k.clone(), + alibi_slopes: None, + window_size_left, + window_size_right, + use_gqa_packing, + }; + q.apply_op3(k, v, op) +} + +#[allow(clippy::too_many_arguments)] +/// Flash-attention v3 layer with variable-length batching. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`. +/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q. +/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v. +/// * `max_seqlen_q` - The maximum query sequence length for q in the batch. +/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch. +/// * `use_gqa_packing` - enables dedicated kernels for GQA packing if head sizes are compatible. +/// +/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`, +/// `seqlen_1 + seqlen_2`, etc. +/// +/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`. +pub fn flash_attn_varlen_alibi( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + causal: bool, + use_gqa_packing: bool, +) -> Result { + let window_size_left = None; + let window_size_right = if causal { Some(0) } else { None }; + + let op = FlashAttnVarLen { + softmax_scale, + max_seqlen_q, + max_seqlen_k, + seqlens_q: seqlens_q.clone(), + seqlens_k: seqlens_k.clone(), + alibi_slopes: Some(alibi_slopes.clone()), + window_size_left, + window_size_right, + use_gqa_packing, + }; + q.apply_op3(k, v, op) +} + +#[allow(clippy::too_many_arguments)] +/// Flash-attention v3 layer with variable-length batching. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`. +/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q. 
+/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v. +/// * `max_seqlen_q` - The maximum query sequence length for q in the batch. +/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch. +/// * `window_size_left` - Limit left attention to value tokens. +/// * `window_size_right` - Limit right attention to value tokens. +/// * `use_gqa_packing` - enables dedicated kernels for GQA packing if head sizes are compatible. +/// +/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`, +/// `seqlen_1 + seqlen_2`, etc. +/// +/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`. +/// +/// # Causal mask +/// +/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result +/// of `Q @ K^T` +pub fn flash_attn_varlen_alibi_windowed( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + window_size_left: Option, + window_size_right: Option, + use_gqa_packing: bool, +) -> Result { + let op = FlashAttnVarLen { + softmax_scale, + max_seqlen_q, + max_seqlen_k, + seqlens_q: seqlens_q.clone(), + seqlens_k: seqlens_k.clone(), + alibi_slopes: Some(alibi_slopes.clone()), + window_size_left, + window_size_right, + use_gqa_packing, + }; + q.apply_op3(k, v, op) +} diff --git a/candle-flash-attn-v3/tests/flash_attn_tests.rs b/candle-flash-attn-v3/tests/flash_attn_tests.rs new file mode 100644 index 0000000000..7366096be8 --- /dev/null +++ b/candle-flash-attn-v3/tests/flash_attn_tests.rs @@ -0,0 +1,394 @@ +use anyhow::Result; +use candle::{DType, Device, IndexOp, Tensor, D}; +use baseten_candle_flash_attn_v3; +use rstest::rstest; + +fn to_vec3_round(t: Tensor, digits: i32) -> Result>>> { + let b = 10f32.powi(digits); + let t = t.to_vec3::()?; + let t = t + .iter() + .map(|t| { + t.iter() + .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect()) + .collect() + }) + .collect(); + Ok(t) +} + +fn fa_acausal(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result { + let in_dtype = q.dtype(); + let q = q.to_dtype(DType::F32)?; + let k = k.to_dtype(DType::F32)?; + let v = v.to_dtype(DType::F32)?; + let att = (q.matmul(&k.t()?)? * softmax_scale as f64)?; + let att = candle_nn::ops::softmax(&att, D::Minus1)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + let output = att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?; + Ok(output) +} + +#[test] +fn flash_attn_acausal() -> Result<()> { + let device = Device::new_cuda(0)?; + let q = Tensor::arange(0u32, 3 * 2 * 64, &device)? + .to_dtype(DType::F16)? + .reshape((1, 3, 2, 64))?; + let k = (&q / 400.)?; + let v = (&q / 500.)?; + let q = (&q / 300.)?; + + let ys1 = fa_acausal(&q, &k, &v, 0.5)?; + let ys1 = ys1.i(0)?.to_dtype(DType::F32)?; + let ys2 = { + let q = q.transpose(1, 2)?; + let k = k.transpose(1, 2)?; + let v = v.transpose(1, 2)?; + baseten_candle_flash_attn_v3::flash_attn(&q, &k, &v, 0.5, false, false)?.transpose(1, 2)? 
+ }; + let ys2 = ys2.i(0)?.to_dtype(DType::F32)?; + let diff = ys1.sub(&ys2)?.abs()?.flatten_all()?.max(0)?; + + assert_eq!(ys2.dims(), &[3, 2, 64]); + assert_eq!( + to_vec3_round(ys2, 4)?, + &[ + [ + [ + 0.0808, 0.0828, 0.0848, 0.0869, 0.0889, 0.0908, 0.0928, 0.0948, 0.0969, 0.0989, + 0.1008, 0.1028, 0.1049, 0.1069, 0.1088, 0.1108, 0.1129, 0.1149, 0.1168, 0.1188, + 0.1208, 0.1229, 0.1249, 0.1268, 0.1288, 0.1309, 0.1328, 0.1349, 0.1368, 0.1388, + 0.1409, 0.1428, 0.1449, 0.1469, 0.1488, 0.1509, 0.1528, 0.1548, 0.1569, 0.1588, + 0.1609, 0.1628, 0.1648, 0.1669, 0.1688, 0.1709, 0.1729, 0.1748, 0.1769, 0.1788, + 0.1809, 0.1829, 0.1848, 0.1869, 0.1888, 0.1908, 0.1929, 0.1948, 0.1969, 0.1989, + 0.2008, 0.2029, 0.205, 0.2069 + ], + [ + 0.1071, 0.1091, 0.1111, 0.113, 0.1151, 0.1171, 0.1191, 0.1211, 0.123, 0.1251, + 0.1271, 0.129, 0.1311, 0.1331, 0.135, 0.1371, 0.139, 0.1411, 0.1431, 0.145, + 0.1471, 0.149, 0.1511, 0.1531, 0.155, 0.1571, 0.1591, 0.1611, 0.1631, 0.165, + 0.1671, 0.1691, 0.1711, 0.1731, 0.175, 0.1771, 0.1791, 0.181, 0.1831, 0.1851, + 0.1871, 0.1891, 0.191, 0.1931, 0.1951, 0.1971, 0.1991, 0.201, 0.2031, 0.2051, + 0.2072, 0.2091, 0.2111, 0.2131, 0.2151, 0.217, 0.2191, 0.2211, 0.2231, 0.2251, + 0.2271, 0.229, 0.2312, 0.2332 + ] + ], + [ + [ + 0.3765, 0.3784, 0.3804, 0.3823, 0.3843, 0.3862, 0.3884, 0.3904, 0.3923, 0.3943, + 0.3962, 0.3984, 0.4004, 0.4023, 0.4043, 0.4063, 0.4084, 0.4104, 0.4124, 0.4143, + 0.4163, 0.4185, 0.4204, 0.4224, 0.4243, 0.4263, 0.4285, 0.4304, 0.4324, 0.4343, + 0.4363, 0.4385, 0.4404, 0.4424, 0.4443, 0.4463, 0.4485, 0.4504, 0.4524, 0.4543, + 0.4563, 0.4585, 0.4604, 0.4624, 0.4644, 0.4663, 0.4683, 0.4705, 0.4724, 0.4744, + 0.4763, 0.4783, 0.4805, 0.4824, 0.4844, 0.4863, 0.4883, 0.4905, 0.4922, 0.4946, + 0.4966, 0.4985, 0.5005, 0.5024 + ], + [ + 0.3816, 0.3835, 0.3855, 0.3875, 0.3894, 0.3914, 0.3936, 0.3955, 0.3975, 0.3994, + 0.4014, 0.4036, 0.4055, 0.4075, 0.4094, 0.4114, 0.4136, 0.4155, 0.4175, 0.4194, + 0.4214, 0.4236, 0.4255, 0.4275, 0.4294, 0.4314, 0.4336, 0.4355, 0.4375, 0.4395, + 0.4414, 0.4436, 0.4456, 0.4475, 0.4495, 0.4514, 0.4536, 0.4556, 0.4575, 0.4595, + 0.4614, 0.4636, 0.4656, 0.4675, 0.4695, 0.4714, 0.4734, 0.4756, 0.4775, 0.4795, + 0.4814, 0.4834, 0.4856, 0.4875, 0.4895, 0.4915, 0.4934, 0.4956, 0.4973, 0.4998, + 0.5015, 0.5034, 0.5054, 0.5073 + ] + ], + [ + [ + 0.6392, 0.6411, 0.6431, 0.6455, 0.6475, 0.6494, 0.6514, 0.6533, 0.6553, 0.6572, + 0.6592, 0.6611, 0.6631, 0.6655, 0.6675, 0.6694, 0.6714, 0.6733, 0.6753, 0.6772, + 0.6792, 0.6812, 0.6831, 0.6851, 0.6875, 0.6895, 0.6914, 0.6934, 0.6953, 0.6973, + 0.6992, 0.7012, 0.7031, 0.7051, 0.7075, 0.7095, 0.7114, 0.7134, 0.7153, 0.7173, + 0.7192, 0.7212, 0.7231, 0.7251, 0.7275, 0.7295, 0.7314, 0.7334, 0.7354, 0.7373, + 0.7393, 0.7412, 0.7432, 0.7451, 0.7476, 0.7495, 0.7515, 0.7534, 0.7554, 0.7573, + 0.7593, 0.7612, 0.7632, 0.7651 + ], + [ + 0.6396, 0.6416, 0.6436, 0.646, 0.6479, 0.6499, 0.6519, 0.6538, 0.6558, 0.6577, + 0.6597, 0.6616, 0.6636, 0.666, 0.668, 0.6699, 0.6719, 0.6738, 0.6758, 0.6777, + 0.6797, 0.6816, 0.6836, 0.6855, 0.688, 0.6899, 0.6919, 0.6938, 0.6958, 0.6978, + 0.6997, 0.7017, 0.7036, 0.7056, 0.708, 0.71, 0.7119, 0.7139, 0.7158, 0.7178, + 0.7197, 0.7217, 0.7236, 0.7256, 0.728, 0.73, 0.7319, 0.7339, 0.7358, 0.7378, + 0.7397, 0.7417, 0.7437, 0.7456, 0.748, 0.75, 0.752, 0.7539, 0.7559, 0.7578, + 0.7598, 0.7617, 0.7637, 0.7656 + ] + ] + ] + ); + assert!(diff.to_vec0::()?.abs() < 1e-5); + Ok(()) +} + +#[test] +fn flash_attn_acausal_gqa() -> Result<()> { + let device = Device::new_cuda(0)?; + let 
n_h = 4usize; + let n_h_k = 1usize; + + let q = Tensor::arange(0u32, (n_h * 2 * 64) as u32, &device)? + .to_dtype(DType::F16)? + .reshape((1, n_h, 2, 64))?; + let gqa = q.clone().i((.., ..n_h_k, .., ..))?; + assert_eq!(gqa.dims(), &[1, n_h_k, 2, 64]); + + let q = (q.clone() / 1000.)?; + let k_gqa = (&gqa / 400.)?; + let v_gqa = (&gqa / 500.)?; + + // let gqa_repeat = gqa.repeat((1, (n_h / n_h_k) as usize, 1, 1))?; + // assert_eq!(gqa_repeat.dims(), &[1, n_h, 2, 64]); + // let k = (&gqa_repeat / 400.)?; + // let v = (&gqa_repeat / 500.)?; + + // let ys1 = fa_acausal(&q, &k, &v, 0.5)?; + // let ys1 = ys1.i(0)?.to_dtype(DType::F32)?; + // assert_eq!(ys1.dims(), &[n_h, 2, 64]); + + let ys2 = { + let q = q.transpose(1, 2)?; + let k_gqa = k_gqa.transpose(1, 2)?; + let v_gqa = v_gqa.transpose(1, 2)?; + baseten_candle_flash_attn_v3::flash_attn(&q, &k_gqa, &v_gqa, 0.125, false, true)?.transpose(1, 2)? + }; + let ys2 = ys2.i(0)?.to_dtype(DType::F32)?; + assert_eq!(ys2.dims(), &[n_h, 2, 64]); + + assert_eq!( + to_vec3_round(ys2.clone(), 4)?, + &[ + [ + [ + 0.0653, 0.0673, 0.0693, 0.0713, 0.0734, 0.0753, 0.0773, 0.0793, 0.0813, 0.0834, + 0.0853, 0.0873, 0.0894, 0.0913, 0.0933, 0.0953, 0.0973, 0.0994, 0.1013, 0.1033, + 0.1053, 0.1073, 0.1094, 0.1113, 0.1133, 0.1154, 0.1173, 0.1194, 0.1213, 0.1233, + 0.1254, 0.1273, 0.1294, 0.1313, 0.1333, 0.1354, 0.1373, 0.1393, 0.1414, 0.1433, + 0.1454, 0.1473, 0.1493, 0.1514, 0.1533, 0.1554, 0.1573, 0.1593, 0.1614, 0.1633, + 0.1654, 0.1674, 0.1693, 0.1714, 0.1733, 0.1753, 0.1774, 0.1793, 0.1814, 0.1833, + 0.1853, 0.1874, 0.1895, 0.1914 + ], + [ + 0.0679, 0.0699, 0.072, 0.0739, 0.076, 0.0779, 0.0799, 0.082, 0.0839, 0.086, + 0.088, 0.0899, 0.092, 0.0939, 0.0959, 0.098, 0.0999, 0.102, 0.1039, 0.106, + 0.108, 0.1099, 0.112, 0.114, 0.1159, 0.118, 0.1199, 0.122, 0.124, 0.126, + 0.1279, 0.13, 0.132, 0.134, 0.136, 0.1379, 0.14, 0.142, 0.144, 0.146, 0.1479, + 0.1499, 0.152, 0.1539, 0.1559, 0.158, 0.1599, 0.162, 0.1639, 0.1659, 0.168, + 0.1699, 0.172, 0.174, 0.1759, 0.178, 0.1799, 0.182, 0.184, 0.1859, 0.188, + 0.1899, 0.192, 0.194 + ] + ], + [ + [ + 0.0706, 0.0725, 0.0746, 0.0765, 0.0786, 0.0806, 0.0825, 0.0846, 0.0865, 0.0886, + 0.0906, 0.0925, 0.0946, 0.0966, 0.0985, 0.1006, 0.1025, 0.1046, 0.1066, 0.1085, + 0.1106, 0.1125, 0.1146, 0.1166, 0.1185, 0.1206, 0.1226, 0.1246, 0.1266, 0.1285, + 0.1306, 0.1326, 0.1346, 0.1366, 0.1385, 0.1406, 0.1426, 0.1445, 0.1466, 0.1486, + 0.1506, 0.1526, 0.1545, 0.1566, 0.1586, 0.1606, 0.1626, 0.1646, 0.1666, 0.1686, + 0.1707, 0.1726, 0.1746, 0.1766, 0.1786, 0.1805, 0.1826, 0.1846, 0.1866, 0.1886, + 0.1906, 0.1925, 0.1947, 0.1967 + ], + [ + 0.0731, 0.0751, 0.0771, 0.0791, 0.0812, 0.0831, 0.0851, 0.0872, 0.0891, 0.0912, + 0.0931, 0.0951, 0.0972, 0.0991, 0.1011, 0.1031, 0.1051, 0.1072, 0.1091, 0.1111, + 0.1132, 0.1151, 0.1172, 0.1191, 0.1212, 0.1232, 0.1251, 0.1272, 0.1292, 0.1311, + 0.1332, 0.1351, 0.1372, 0.1392, 0.1411, 0.1432, 0.1451, 0.1471, 0.1492, 0.1511, + 0.1532, 0.1552, 0.1571, 0.1592, 0.1611, 0.1632, 0.1652, 0.1671, 0.1692, 0.1711, + 0.1732, 0.1752, 0.1771, 0.1792, 0.1812, 0.1831, 0.1852, 0.1871, 0.1892, 0.1912, + 0.1931, 0.1951, 0.1973, 0.1992 + ] + ], + [ + [ + 0.0757, 0.0776, 0.0797, 0.0817, 0.0837, 0.0857, 0.0876, 0.0897, 0.0917, 0.0938, + 0.0957, 0.0977, 0.0997, 0.1017, 0.1036, 0.1057, 0.1077, 0.1097, 0.1117, 0.1136, + 0.1157, 0.1177, 0.1198, 0.1217, 0.1237, 0.1257, 0.1277, 0.1298, 0.1317, 0.1337, + 0.1357, 0.1377, 0.1398, 0.1417, 0.1437, 0.1458, 0.1477, 0.1497, 0.1517, 0.1537, + 0.1558, 0.1577, 0.1597, 0.1617, 0.1637, 0.1658, 
0.1677, 0.1697, 0.1718, 0.1737, + 0.1758, 0.1777, 0.1797, 0.1818, 0.1837, 0.1857, 0.1877, 0.1897, 0.1918, 0.1937, + 0.1957, 0.1976, 0.1998, 0.2018 + ], + [ + 0.0782, 0.0802, 0.0822, 0.0842, 0.0862, 0.0882, 0.0902, 0.0922, 0.0942, 0.0963, + 0.0982, 0.1002, 0.1022, 0.1042, 0.1062, 0.1082, 0.1102, 0.1122, 0.1142, 0.1162, + 0.1182, 0.1202, 0.1223, 0.1242, 0.1262, 0.1283, 0.1302, 0.1322, 0.1343, 0.1362, + 0.1383, 0.1403, 0.1422, 0.1443, 0.1462, 0.1482, 0.1503, 0.1522, 0.1543, 0.1563, + 0.1582, 0.1603, 0.1622, 0.1643, 0.1663, 0.1682, 0.1703, 0.1722, 0.1743, 0.1763, + 0.1782, 0.1803, 0.1823, 0.1843, 0.1863, 0.1882, 0.1903, 0.1923, 0.1943, 0.1963, + 0.1982, 0.2002, 0.2023, 0.2043 + ] + ], + [ + [ + 0.0807, 0.0826, 0.0847, 0.0867, 0.0887, 0.0907, 0.0927, 0.0947, 0.0967, 0.0987, + 0.1007, 0.1027, 0.1047, 0.1067, 0.1086, 0.1107, 0.1127, 0.1147, 0.1167, 0.1187, + 0.1207, 0.1227, 0.1247, 0.1267, 0.1287, 0.1307, 0.1327, 0.1348, 0.1367, 0.1387, + 0.1407, 0.1427, 0.1448, 0.1467, 0.1487, 0.1508, 0.1527, 0.1547, 0.1567, 0.1587, + 0.1608, 0.1627, 0.1647, 0.1667, 0.1687, 0.1708, 0.1727, 0.1747, 0.1768, 0.1787, + 0.1808, 0.1827, 0.1847, 0.1868, 0.1887, 0.1907, 0.1927, 0.1947, 0.1968, 0.1987, + 0.2007, 0.2026, 0.2048, 0.2068 + ], + [ + 0.0831, 0.0851, 0.0871, 0.0891, 0.0911, 0.0931, 0.0951, 0.0971, 0.0991, 0.1011, + 0.1031, 0.1051, 0.1071, 0.1091, 0.1111, 0.1131, 0.1151, 0.1171, 0.1191, 0.1211, + 0.1231, 0.1251, 0.1271, 0.1292, 0.1311, 0.1332, 0.1351, 0.1371, 0.1392, 0.1411, + 0.1432, 0.1451, 0.1471, 0.1492, 0.1511, 0.1531, 0.1552, 0.1571, 0.1592, 0.1611, + 0.1631, 0.1652, 0.1671, 0.1692, 0.1711, 0.1731, 0.1752, 0.1771, 0.1792, 0.1812, + 0.1831, 0.1852, 0.1871, 0.1891, 0.1912, 0.1931, 0.1952, 0.1971, 0.1991, 0.2012, + 0.2031, 0.2051, 0.2072, 0.2092 + ] + ] + ] + ); + Ok(()) +} + +#[test] +fn flash_attn_varlen() -> Result<()> { + let device = Device::new_cuda(0)?; + let q = Tensor::arange(0u32, 3 * 2 * 64, &device)? + .to_dtype(DType::F16)? + .reshape((3, 2, 64))?; + let k = (&q / 400.)?; + let v = (&q / 500.)?; + let q = (&q / 300.)?; + + let seqlens_q = Tensor::new(&[0u32, 2u32], &device)?; + // let seqlens_k: Tensor = Tensor::new(&[0u32, 3u32], &device)?; + + let ys = { + let q = q.transpose(0, 1)?; + let k = k.transpose(0, 1)?; + let v = v.transpose(0, 1)?; + baseten_candle_flash_attn_v3::flash_attn_varlen( + &q, &k, &v, &seqlens_q, &seqlens_q, 2, 2, 0.5, false, false, + )? + .transpose(0, 1)? 
+ }; + let ys = ys.to_dtype(DType::F32)?; + + assert_eq!(ys.dims(), &[3, 2, 64]); + assert_eq!( + to_vec3_round(ys, 4)?, + &[ + [ + [ + 0.0808, 0.0828, 0.0848, 0.0869, 0.0889, 0.0908, 0.0928, 0.0948, 0.0969, 0.0989, + 0.1008, 0.1028, 0.1049, 0.1069, 0.1088, 0.1108, 0.1129, 0.1149, 0.1168, 0.1188, + 0.1208, 0.1229, 0.1249, 0.1268, 0.1288, 0.1309, 0.1328, 0.1349, 0.1368, 0.1388, + 0.1409, 0.1428, 0.1449, 0.1469, 0.1488, 0.1509, 0.1528, 0.1548, 0.1569, 0.1588, + 0.1609, 0.1628, 0.1648, 0.1669, 0.1688, 0.1709, 0.1729, 0.1748, 0.1769, 0.1788, + 0.1809, 0.1829, 0.1848, 0.1869, 0.1888, 0.1908, 0.1929, 0.1948, 0.1969, 0.1989, + 0.2008, 0.2029, 0.205, 0.2069 + ], + [ + 0.1071, 0.1091, 0.1111, 0.113, 0.1151, 0.1171, 0.1191, 0.1211, 0.123, 0.1251, + 0.1271, 0.129, 0.1311, 0.1331, 0.135, 0.1371, 0.139, 0.1411, 0.1431, 0.145, + 0.1471, 0.149, 0.1511, 0.1531, 0.155, 0.1571, 0.1591, 0.1611, 0.1631, 0.165, + 0.1671, 0.1691, 0.1711, 0.1731, 0.175, 0.1771, 0.1791, 0.181, 0.1831, 0.1851, + 0.1871, 0.1891, 0.191, 0.1931, 0.1951, 0.1971, 0.1991, 0.201, 0.2031, 0.2051, + 0.2072, 0.2091, 0.2111, 0.2131, 0.2151, 0.217, 0.2191, 0.2211, 0.2231, 0.2251, + 0.2271, 0.229, 0.2312, 0.2332 + ] + ], + [ + [ + 0.3765, 0.3784, 0.3804, 0.3823, 0.3843, 0.3862, 0.3884, 0.3904, 0.3923, 0.3943, + 0.3962, 0.3984, 0.4004, 0.4023, 0.4043, 0.4063, 0.4084, 0.4104, 0.4124, 0.4143, + 0.4163, 0.4185, 0.4204, 0.4224, 0.4243, 0.4263, 0.4285, 0.4304, 0.4324, 0.4343, + 0.4363, 0.4385, 0.4404, 0.4424, 0.4443, 0.4463, 0.4485, 0.4504, 0.4524, 0.4543, + 0.4563, 0.4585, 0.4604, 0.4624, 0.4644, 0.4663, 0.4683, 0.4705, 0.4724, 0.4744, + 0.4763, 0.4783, 0.4805, 0.4824, 0.4844, 0.4863, 0.4883, 0.4905, 0.4922, 0.4946, + 0.4966, 0.4985, 0.5005, 0.5024 + ], + [ + 0.3816, 0.3835, 0.3855, 0.3875, 0.3894, 0.3914, 0.3936, 0.3955, 0.3975, 0.3994, + 0.4014, 0.4036, 0.4055, 0.4075, 0.4094, 0.4114, 0.4136, 0.4155, 0.4175, 0.4194, + 0.4214, 0.4236, 0.4255, 0.4275, 0.4294, 0.4314, 0.4336, 0.4355, 0.4375, 0.4395, + 0.4414, 0.4436, 0.4456, 0.4475, 0.4495, 0.4514, 0.4536, 0.4556, 0.4575, 0.4595, + 0.4614, 0.4636, 0.4656, 0.4675, 0.4695, 0.4714, 0.4734, 0.4756, 0.4775, 0.4795, + 0.4814, 0.4834, 0.4856, 0.4875, 0.4895, 0.4915, 0.4934, 0.4956, 0.4973, 0.4998, + 0.5015, 0.5034, 0.5054, 0.5073 + ] + ], + [ + [ + 0.6392, 0.6411, 0.6431, 0.6455, 0.6475, 0.6494, 0.6514, 0.6533, 0.6553, 0.6572, + 0.6592, 0.6611, 0.6631, 0.6655, 0.6675, 0.6694, 0.6714, 0.6733, 0.6753, 0.6772, + 0.6792, 0.6812, 0.6831, 0.6851, 0.6875, 0.6895, 0.6914, 0.6934, 0.6953, 0.6973, + 0.6992, 0.7012, 0.7031, 0.7051, 0.7075, 0.7095, 0.7114, 0.7134, 0.7153, 0.7173, + 0.7192, 0.7212, 0.7231, 0.7251, 0.7275, 0.7295, 0.7314, 0.7334, 0.7354, 0.7373, + 0.7393, 0.7412, 0.7432, 0.7451, 0.7476, 0.7495, 0.7515, 0.7534, 0.7554, 0.7573, + 0.7593, 0.7612, 0.7632, 0.7651 + ], + [ + 0.6396, 0.6416, 0.6436, 0.646, 0.6479, 0.6499, 0.6519, 0.6538, 0.6558, 0.6577, + 0.6597, 0.6616, 0.6636, 0.666, 0.668, 0.6699, 0.6719, 0.6738, 0.6758, 0.6777, + 0.6797, 0.6816, 0.6836, 0.6855, 0.688, 0.6899, 0.6919, 0.6938, 0.6958, 0.6978, + 0.6997, 0.7017, 0.7036, 0.7056, 0.708, 0.71, 0.7119, 0.7139, 0.7158, 0.7178, + 0.7197, 0.7217, 0.7236, 0.7256, 0.728, 0.73, 0.7319, 0.7339, 0.7358, 0.7378, + 0.7397, 0.7417, 0.7437, 0.7456, 0.748, 0.75, 0.752, 0.7539, 0.7559, 0.7578, + 0.7598, 0.7617, 0.7637, 0.7656 + ] + ] + ] + ); + Ok(()) +} + +#[rstest( + head_dim => [64, 128, 256], + seq_len => [2, 4, 9], + use_gqa_packing => [false], // true does not make sense, as its reset to falser in the function +)] +fn flash_attn_varlen_param(head_dim: 
usize, seq_len: usize, use_gqa_packing: bool) -> Result<()> { + let device = Device::new_cuda(0)?; + + // Adjust the shape so it reflects seq_len. + // Here, we make q of shape (3, seq_len, head_dim). + let q = Tensor::arange(0u32, (3 * seq_len * head_dim) as u32, &device)? + .to_dtype(DType::F16)? + .reshape((3, seq_len, head_dim))?; + // divide by max value to have expected magnitude of error. + let k = (&q / ((head_dim * seq_len) as f64 * 4.))?; + let v = (&q / ((head_dim * seq_len) as f64 * 2.))?; + let q = (&q / ((head_dim * seq_len) as f64 * 3.))?; + + // For varlen, we need start/end offsets for each “batch element.” + // In this test, we have only 1 “batch element,” so let's do `[0, seq_len]`. + let seqlens_q = Tensor::new(&[0u32, seq_len as u32], &device)?; + let seqlens_k = Tensor::new(&[0u32, seq_len as u32], &device)?; + + let ys = { + let q = q.transpose(0, 1)?; + let k = k.transpose(0, 1)?; + let v = v.transpose(0, 1)?; + baseten_candle_flash_attn_v3::flash_attn_varlen( + &q, + &k, + &v, + &seqlens_q, + &seqlens_k, + seq_len, // max_seqlen_q + seq_len, // max_seqlen_k + 0.5, // softmax scale + false, // causal + use_gqa_packing, // use_gqa_packing + )? + .transpose(0, 1)? // bring it back to (3, seq_len, head_dim) + }; + let ys = ys.to_dtype(DType::F32)?; + + assert_eq!(ys.dims(), &[3, seq_len, head_dim]); + let ys2 = { + // reference implementation + let q = q.unsqueeze(0)?; + let k = k.unsqueeze(0)?; + let v = v.unsqueeze(0)?; + let y = fa_acausal(&q, &k, &v, 0.5)?; + y.i(0)?.to_dtype(DType::F32)? + }; + + let diff = ys.sub(&ys2)?.abs()?.flatten_all()?.max(0)?; + assert!(diff.to_vec0::()?.abs() < 5e-3); + Ok(()) +} diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs index 643783b350..3f90ec3a47 100644 --- a/candle-flash-attn/src/lib.rs +++ b/candle-flash-attn/src/lib.rs @@ -487,9 +487,9 @@ impl FlashAttnVarLen { None => candle::bail!("seqlens_k has to be contiguous"), }; - let q = q.as_cuda_slice::()?; - let k = k.as_cuda_slice::()?; - let v = v.as_cuda_slice::()?; + let q = q.as_cuda_slice::()?; + let k = k.as_cuda_slice::()?; + let v = v.as_cuda_slice::()?; let q = q.slice(q_l.start_offset()..); let k = k.slice(k_l.start_offset()..); let v = v.slice(v_l.start_offset()..); @@ -604,7 +604,7 @@ impl FlashAttnVarLen { let seqlen_k_rounded = round_multiple(self.max_seqlen_k, 128); let elem_count = out_shape.elem_count(); - let dst = unsafe { dev.alloc::(elem_count)? }; + let dst = unsafe { dev.alloc::(elem_count)? 
}; let softmax_lse = dev.alloc_zeros::(num_heads * total_q)?; let is_bf16 = if is_bf16 { 1 } else { 0 }; diff --git a/candle-kernels/src/cast.cu b/candle-kernels/src/cast.cu index 90f5e7ba48..9259b3e1c7 100644 --- a/candle-kernels/src/cast.cu +++ b/candle-kernels/src/cast.cu @@ -139,3 +139,58 @@ CAST_OP(double, uint32_t, cast_f64_u32) CAST_OP(double, int64_t, cast_f64_i64 ) CAST_OP(double, float, cast_f64_f32) CAST_OP(double, double, cast_f64_f64) + +// I32 casts +CAST_OP(int32_t, uint8_t, cast_i32_u8 ) +CAST_OP(int32_t, uint32_t, cast_i32_u32) +CAST_OP(int32_t, int32_t, cast_i32_i32) +CAST_OP(int32_t, int64_t, cast_i32_i64) +CAST_OP(int32_t, float, cast_i32_f32) +CAST_OP(int32_t, double, cast_i32_f64) +#if __CUDA_ARCH__ >= 530 +CAST_OP(int32_t, __half, cast_i32_f16) +#endif +#if __CUDA_ARCH__ >= 800 +CAST_OP(int32_t, __nv_bfloat16, cast_i32_bf16) +#endif + +CAST_OP(uint8_t, int32_t, cast_u8_i32 ) +CAST_OP(uint32_t, int32_t, cast_u32_i32) +CAST_OP(int64_t, int32_t, cast_i64_i32) +CAST_OP(float, int32_t, cast_f32_i32) +CAST_OP(double, int32_t, cast_f64_i32) +#if __CUDA_ARCH__ >= 530 +CAST_OP(__half, int32_t, cast_f16_i32) +#endif +#if __CUDA_ARCH__ >= 800 +CAST_THROUGH_OP(__nv_bfloat16, int32_t, float, cast_bf16_i32) +#endif + +// I16 casts +CAST_OP(int16_t, uint8_t, cast_i16_u8 ) +CAST_OP(int16_t, uint32_t, cast_i16_u32) +CAST_OP(int16_t, int16_t, cast_i16_i16) +CAST_OP(int16_t, int32_t, cast_i16_i32) +CAST_OP(int16_t, int64_t, cast_i16_i64) +CAST_OP(int16_t, float, cast_i16_f32) +CAST_OP(int16_t, double, cast_i16_f64) +#if __CUDA_ARCH__ >= 530 +CAST_OP(int16_t, __half, cast_i16_f16) +#endif +#if __CUDA_ARCH__ >= 800 +CAST_OP(int16_t, __nv_bfloat16, cast_i16_bf16) +#endif + +CAST_OP(uint8_t, int16_t, cast_u8_i16 ) +CAST_OP(uint32_t, int16_t, cast_u32_i16) +CAST_OP(int32_t, int16_t, cast_i32_i16) +CAST_OP(int64_t, int16_t, cast_i64_i16) +CAST_OP(float, int16_t, cast_f32_i16) +CAST_OP(double, int16_t, cast_f64_i16) +#if __CUDA_ARCH__ >= 530 +CAST_OP(__half, int16_t, cast_f16_i16) +#endif +#if __CUDA_ARCH__ >= 800 +CAST_THROUGH_OP(__nv_bfloat16, int16_t, float, cast_bf16_i16) +#endif + diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 939990da9d..c03b1c1370 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1631,7 +1631,6 @@ impl ConstantValues { f } } - #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] pub enum SdpaDType { BF16, @@ -1652,174 +1651,208 @@ pub fn call_sdpa_full( kernels: &Kernels, q_offset: usize, q_shape: &[usize], + q_strides: &[usize], q_buffer: &Buffer, k_offset: usize, + k_shape: &[usize], + k_strides: &[usize], k_buffer: &Buffer, v_offset: usize, v_buffer: &Buffer, + v_strides: &[usize], + mask_type: Option, + mask_buffer: Option<&Buffer>, + m_strides: Option<&[usize]>, output: &Buffer, - alpha: f32, - softcapping: f32, + o_strides: &[usize], + scale: f32, + do_causal: bool, itype: SdpaDType, ) -> Result<(), MetalKernelError> { #[derive(Debug)] #[repr(C)] - struct MLXFastAttentionParams { - m: i32, - n: i32, - k: i32, - - ldq: i32, // ldq == ldo - ldk: i32, - ldv: i32, - lds: i32, - ldo: i32, - - tiles_n: i32, - tiles_m: i32, - - batch_stride_q: i32, - batch_stride_k: i32, - batch_stride_v: i32, - batch_stride_o: i32, - - swizzle_log: i32, - gemm_n_iterations_aligned: i32, - gemm_k_iterations_aligned: i32, - gemm_sv_m_block_iterations: i32, - - batch_ndim: i32, - alpha: f32, - softcapping: f32, + struct AttnParams { + b: i32, + h: i32, + d: i32, + ql: i32, + kl: i32, + gqa_factor: i32, + scale: f32, 
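+        // Tiling metadata: nq/nk count the BQ- and bk-sized tiles along the query/key
+        // sequence, the *_aligned fields count only the full tiles, the *_rem fields
+        // hold the leftover rows/columns, and ql_off is kl - ql.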
+ nq: i32, + nk: i32, + nq_aligned: i32, + nk_aligned: i32, + ql_rem: i32, + kl_rem: i32, + ql_off: i32, + q_strides: [i64; 3], + k_strides: [i64; 3], + v_strides: [i64; 3], + o_strides: [i64; 3], } - let bk = q_shape.last().unwrap(); + #[derive(Debug)] + #[repr(C)] + struct AttnMaskParams { + m_strides: [i64; 3], + } - const BN: usize = 16; - const BM: usize = 16; - const WM: usize = 2; - const WN: usize = 2; + const WM: usize = 4; + const WN: usize = 1; + + const BQ: usize = 32; + let bd = q_shape[q_shape.len() - 1]; + if ![32, 64, 72, 80, 96, 128, 256].contains(&bd) { + return Err(MetalKernelError::SdpaHeadSizeMismatch { + variation: "full", + got: bd, + expected: vec![32, 64, 72, 80, 96, 128, 256], + }); + }; + let bk = if bd < 128 { 32 } else { 16 }; - let name = match (bk, itype) { - (32, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_half", - (64, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_half", - (96, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_half", - (128, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_half", - (256, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_half", - (32, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_float", - (64, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_float", - (96, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_float", - (128, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_float", - (256, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_float", - (other, SdpaDType::F16 | SdpaDType::F32) => { - return Err(MetalKernelError::SdpaHeadSizeMismatch { - variation: "full", - got: *other, - expected: vec![32, 64, 96, 128, 256], - }) - } - (_, SdpaDType::BF16) => { - return Err(MetalKernelError::SdpaHeadDTypeMismatch { - variation: "full", - got: SdpaDType::BF16, - }) - } + let b = q_shape[0]; + let h = q_shape[1]; + let d = q_shape[3]; + let gqa_factor = q_shape[1] / k_shape[1]; + + let ql = q_shape[2]; + let kl = k_shape[2]; + + let align_q = (ql % BQ) == 0; + let align_k = (kl % bk) == 0; + let has_mask = mask_buffer.is_some(); + + let itype_repr = match itype { + SdpaDType::BF16 => "bfloat16", + SdpaDType::F16 => "float16", + SdpaDType::F32 => "float32", + }; + let mask_repr = match mask_type { + Some(SdpaDType::BF16) => "bfloat16", + Some(SdpaDType::F16) => "float16", + Some(SdpaDType::F32) => "float32", + None => itype_repr, }; + let name = + format!("steel_attention_{itype_repr}_bq{BQ}_bk{bk}_bd{bd}_wm{WM}_wn{WN}_mask{mask_repr}"); - let pipeline = kernels.load_pipeline(device, Source::Sdpa, name)?; + let constants = Some(ConstantValues::new(vec![ + (200, Value::Bool(/* align_Q */ align_q)), + (201, Value::Bool(/* align_K */ align_k)), + (300, Value::Bool(/* has_mask */ has_mask)), + (301, Value::Bool(/* do_causal */ do_causal)), + ])); + + let pipeline = kernels.load_pipeline_with_constants(device, Source::Sdpa, name, constants)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); - // q = (bs, qhead, seq, hidden) - // k/v = (bs, kv_head, seq, hidden) - - let qseq = q_shape[q_shape.len() - 2]; - - let m = q_shape[q_shape.len() - 2]; - let n = m; - let k = q_shape[q_shape.len() - 1]; - let bs_out = q_shape[0] * q_shape[1]; - - let batch_shape = [q_shape[0] * q_shape[1]]; - let dk = q_shape[q_shape.len() - 1]; - let ldq = dk; - let ldk = dk; - let ldv = dk; - let lds = BN; - let 
ldo = dk; - - let tn = 1; - let tm = m.div_ceil(BM); - - let b_stride_q = dk * qseq; - let b_stride_k = dk * qseq; - let b_stride_v = dk * qseq; - let b_stride_o = dk * qseq; - let swizzle_log = 0; - let gemm_n_iterations_aligned = n.div_ceil(BN); - let gemm_k_iterations_aligned = k.div_ceil(*bk); - let gemm_sv_m_block_iterations = m.div_ceil(BM); - let batch_ndim = batch_shape.len(); - - let alpha = if softcapping != 1. { - alpha / softcapping - } else { - alpha + let nq = (ql + BQ - 1) / BQ; + let nk = (kl + bk - 1) / bk; + + let nq_aligned = ql / BQ; + let nk_aligned = kl / bk; + + let params = AttnParams { + b: b as i32, + h: h as i32, + d: d as i32, + ql: ql as i32, + kl: kl as i32, + gqa_factor: gqa_factor as i32, + scale, + nq: nq as i32, + nk: nk as i32, + nq_aligned: nq_aligned as i32, + nk_aligned: nk_aligned as i32, + ql_rem: ql.wrapping_sub(nq_aligned * BQ) as i32, + kl_rem: kl.wrapping_sub(nk_aligned * bk) as i32, + ql_off: kl.wrapping_sub(ql) as i32, + q_strides: [ + q_strides[0] as i64, + q_strides[1] as i64, + q_strides[2] as i64, + ], + k_strides: [ + k_strides[0] as i64, + k_strides[1] as i64, + k_strides[2] as i64, + ], + v_strides: [ + v_strides[0] as i64, + v_strides[1] as i64, + v_strides[2] as i64, + ], + o_strides: [ + o_strides[0] as i64, + o_strides[1] as i64, + o_strides[2] as i64, + ], }; - let params = MLXFastAttentionParams { - m: m as i32, - n: n as i32, - k: k as i32, - ldq: ldq as i32, - ldk: ldk as i32, - ldv: ldv as i32, - lds: lds as i32, - ldo: ldo as i32, - tiles_n: tn, - tiles_m: tm as i32, - batch_stride_q: b_stride_q as i32, - batch_stride_k: b_stride_k as i32, - batch_stride_v: b_stride_v as i32, - batch_stride_o: b_stride_o as i32, - swizzle_log, - gemm_n_iterations_aligned: gemm_n_iterations_aligned as i32, - gemm_k_iterations_aligned: gemm_k_iterations_aligned as i32, - gemm_sv_m_block_iterations: gemm_sv_m_block_iterations as i32, - batch_ndim: batch_ndim as i32, - alpha, - softcapping, - }; - let batch_strides = [b_stride_q, b_stride_k, b_stride_v, b_stride_o]; + impl EncoderParam for AttnParams { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_bytes( + position, + core::mem::size_of::() as u64, + &data as *const AttnParams as *const c_void, + ); + } + } - impl EncoderParam for MLXFastAttentionParams { + impl EncoderParam for AttnMaskParams { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { encoder.set_bytes( position, - core::mem::size_of::() as u64, - &data as *const MLXFastAttentionParams as *const c_void, + core::mem::size_of::() as u64, + &data as *const AttnMaskParams as *const c_void, ); } } - set_params!( - encoder, - ( - (q_buffer, q_offset), - (k_buffer, k_offset), - (v_buffer, v_offset), - output, - params, - &batch_shape[..], - &batch_strides[..] 
- ) - ); + if let Some(mask) = mask_buffer { + let mask_strides = m_strides.unwrap(); + let mask_params = AttnMaskParams { + m_strides: [ + mask_strides[0] as i64, + mask_strides[1] as i64, + mask_strides[2] as i64, + ], + }; + encoder.use_resource(mask, metal::MTLResourceUsage::Read); + + set_params!( + encoder, + ( + (q_buffer, q_offset), + (k_buffer, k_offset), + (v_buffer, v_offset), + output, + params, + mask_params, + mask + ) + ); + } else { + set_params!( + encoder, + ( + (q_buffer, q_offset), + (k_buffer, k_offset), + (v_buffer, v_offset), + output, + params + ) + ); + } let grid_dims = MTLSize { - width: 1, - height: tm as u64, - depth: bs_out as u64, + width: nq as u64, + height: h as u64, + depth: b as u64, }; let group_dims = MTLSize { width: 32, @@ -1831,10 +1864,11 @@ pub fn call_sdpa_full( encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(grid_dims, group_dims); + Ok(()) } -/// SDPA full is supported when: +/// SDPA vector is supported when: /// - q head dim == 64, 96, 128 /// - no mask /// - q,k,v are contiguous @@ -1869,16 +1903,22 @@ pub fn call_sdpa_vector( let name = match (bk, itype) { (32, SdpaDType::F16) => "sdpa_vector_float16_t_32", (64, SdpaDType::F16) => "sdpa_vector_float16_t_64", + (72, SdpaDType::F16) => "sdpa_vector_float16_t_72", + (80, SdpaDType::F16) => "sdpa_vector_float16_t_80", (96, SdpaDType::F16) => "sdpa_vector_float16_t_96", (128, SdpaDType::F16) => "sdpa_vector_float16_t_128", (256, SdpaDType::F16) => "sdpa_vector_float16_t_256", (32, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_32", (64, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_64", + (72, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_72", + (80, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_80", (96, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_96", (128, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_128", (256, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_256", (32, SdpaDType::F32) => "sdpa_vector_float_32", (64, SdpaDType::F32) => "sdpa_vector_float_64", + (72, SdpaDType::F32) => "sdpa_vector_float_72", + (80, SdpaDType::F32) => "sdpa_vector_float_80", (96, SdpaDType::F32) => "sdpa_vector_float_96", (128, SdpaDType::F32) => "sdpa_vector_float_128", (256, SdpaDType::F32) => "sdpa_vector_float_256", @@ -1886,7 +1926,7 @@ pub fn call_sdpa_vector( return Err(MetalKernelError::SdpaHeadSizeMismatch { variation: "vector", got: *other, - expected: vec![32, 64, 96, 128, 256], + expected: vec![32, 64, 72, 80, 96, 128, 256], }) } }; @@ -1929,7 +1969,7 @@ pub fn call_sdpa_vector( let grid_dims = MTLSize { width: 1, height: b as u64, - depth: 1_u64, + depth: 1 as u64, }; let group_dims = MTLSize { width: 1024, @@ -2341,6 +2381,7 @@ pub enum GgmlDType { Q8K, F16, F32, + BF16, } #[allow(clippy::too_many_arguments)] @@ -2418,7 +2459,7 @@ pub fn call_quantized_matmul_mv_t( let align = 2; (nth0, nth1, align) } - GgmlDType::F16 | GgmlDType::Q8K => { + GgmlDType::F16 | GgmlDType::BF16 | GgmlDType::Q8K => { // Original implem uses rows let nth0 = 32; let nth1 = 1; @@ -2456,6 +2497,7 @@ pub fn call_quantized_matmul_mv_t( GgmlDType::Q6K => "kernel_mul_mv_q6_K_f32", GgmlDType::Q8K => "kernel_mul_mv_q8_K_f32", GgmlDType::F16 => "kernel_mul_mv_f16_f32", + GgmlDType::BF16 => "kernel_mul_mv_bf16_f32", GgmlDType::F32 => "kernel_mul_mv_f32_f32", }; @@ -2496,6 +2538,114 @@ pub fn call_quantized_matmul_mv_t( Ok(()) } +/// - src0 is usually weight +/// - src1 is usually xs +#[allow(clippy::too_many_arguments)] +pub fn 
call_quantized_matmul_mm_t( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + dtype: GgmlDType, + src0_shape: &[usize], + src0_stride: &[usize], + src0: &Buffer, + src1_shape: &[usize], + src1_stride: &[usize], + src1: &Buffer, + src1_offset: usize, + dst_shape: &[usize], + dst_offset: usize, + dst: &Buffer, +) -> Result<(), MetalKernelError> { + // Everything is in reverse + let ne00 = src0_shape[src0_shape.len() - 1] as i64; + let ne01 = src0_shape[src0_shape.len() - 2] as i64; + let ne02 = src0_shape[src0_shape.len() - 3] as i64; + let ne03 = src0_shape[src0_shape.len() - 4] as i64; + + let nb01 = src0_stride[src0_stride.len() - 2] as i64; + let nb02 = src0_stride[src0_stride.len() - 3] as i64; + let nb03 = src0_stride[src0_stride.len() - 4] as i64; + + let ne11 = src1_shape[src1_shape.len() - 2] as i64; + let ne12 = src1_shape[src1_shape.len() - 3] as i64; + let ne13 = src1_shape[src1_shape.len() - 4] as i64; + + let nb10 = src1_stride[src1_stride.len() - 1] as i64; + let nb11 = src1_stride[src1_stride.len() - 2] as i64; + let nb12 = src1_stride[src1_stride.len() - 3] as i64; + let nb13 = src1_stride[src1_stride.len() - 4] as i64; + + let ne0 = dst_shape[dst_shape.len() - 1] as i64; + let ne1 = dst_shape[dst_shape.len() - 2] as i64; + let r2 = (ne12 / ne02) as u32; + let r3 = (ne13 / ne03) as u32; + + let thread_groups_count = MTLSize { + width: divide(ne11 as usize, 32), + height: divide(ne01 as usize, 64), + depth: (ne12 * ne13) as u64, + }; + let threads_per_threadgroup = MTLSize { + width: 128, + height: 1, + depth: 1, + }; + let name = match dtype { + GgmlDType::Q4_0 => "kernel_mul_mm_q4_0_f32", + GgmlDType::Q4_1 => "kernel_mul_mm_q4_1_f32", + GgmlDType::Q5_0 => "kernel_mul_mm_q5_0_f32", + GgmlDType::Q5_1 => "kernel_mul_mm_q5_1_f32", + GgmlDType::Q8_0 => "kernel_mul_mm_q8_0_f32", + GgmlDType::Q8_1 => "kernel_mul_mm_q8_1_f32", + GgmlDType::Q2K => "kernel_mul_mm_q2_K_f32", + GgmlDType::Q3K => "kernel_mul_mm_q3_K_f32", + GgmlDType::Q4K => "kernel_mul_mm_q4_K_f32", + GgmlDType::Q5K => "kernel_mul_mm_q5_K_f32", + GgmlDType::Q6K => "kernel_mul_mm_q6_K_f32", + GgmlDType::Q8K => "kernel_mul_mm_q8_K_f32", + GgmlDType::F16 => "kernel_mul_mm_f16_f32", + GgmlDType::BF16 => "kernel_mul_mm_bf16_f32", + GgmlDType::F32 => "kernel_mul_mm_f32_f32", + }; + + let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + src0, + (src1, src1_offset), + (dst, dst_offset), + ne00, + ne02, + nb01, + nb02, + nb03, + ne12, + nb10, + nb11, + nb12, + nb13, + ne0, + ne1, + r2, + r3 + ) + ); + encoder.use_resource(src0, metal::MTLResourceUsage::Read); + encoder.use_resource(src1, metal::MTLResourceUsage::Read); + encoder.use_resource(dst, metal::MTLResourceUsage::Write); + + encoder.set_threadgroup_memory_length(0, 8192); + + encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup); + Ok(()) +} + fn divide(m: usize, b: usize) -> NSUInteger { m.div_ceil(b) as NSUInteger } diff --git a/candle-metal-kernels/src/quantized.metal b/candle-metal-kernels/src/quantized.metal index fef6ac54f8..b463144f39 100644 --- a/candle-metal-kernels/src/quantized.metal +++ b/candle-metal-kernels/src/quantized.metal @@ -1,55 +1,1434 @@ -// Imported from https://github.com/ggerganov/llama.cpp/blob/master/ggml-metal.metal #include using namespace metal; +#if defined(__HAVE_BFLOAT__) + +typedef bfloat bfloat16_t; + 
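Stepping back to the call_quantized_matmul_mm_t launcher above: its thread-group geometry appears to map one 128-thread threadgroup to a 64 x 32 tile of the output (64 rows taken from src0/weights, 32 rows taken from src1/xs) and to batch the grid depth over ne12 * ne13. A minimal sketch of that arithmetic in plain Rust; the mm_dispatch_dims helper name is illustrative and not part of the patch:

    // Sketch: thread-group counts used by call_quantized_matmul_mm_t above.
    // One 128-thread threadgroup covers a 64 x 32 tile of the output, and the
    // grid depth runs over the ne12 * ne13 batch dimensions.
    fn mm_dispatch_dims(ne01: u64, ne11: u64, ne12: u64, ne13: u64) -> (u64, u64, u64) {
        let divide = |m: u64, b: u64| m.div_ceil(b);
        (
            divide(ne11, 32), // width:  src1 rows handled per group
            divide(ne01, 64), // height: src0 rows handled per group
            ne12 * ne13,      // depth:  batch
        )
    }

For example, a 4096 x 4096 weight multiplied by 32 input rows with no extra batch dims gives mm_dispatch_dims(4096, 32, 1, 1) = (1, 64, 1).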
+#else + +///////////////////////////////////////////////////////////////////////////// +// Helpers +///////////////////////////////////////////////////////////////////////////// + +constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) { + // Check for nan + if ((as_type(x) & ~_fp_encoding_traits::sign_mask) > + _fp_encoding_traits::inf_mask) { + return uint16_t(as_type(0x7FC0)); + } + // Take bits + uint32_t float_bits = as_type(x); + + // Round to nearest even + float_bits += ((float_bits >> 16) & 1) + as_type(0x7FFF); + + // Take upper 16 bits + return float_bits >> 16; +} + +constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) { + // Upper 16 bits are the data and lower 16 bits are 0s + return as_type((uint32_t)x << 16); +} + +struct _MLX_BFloat16; + +template +static constexpr constant bool can_convert_to_bfloat = + !is_same_v && is_convertible_v; + +template +static constexpr constant bool can_convert_from_bfloat = + !is_same_v && is_convertible_v; + +///////////////////////////////////////////////////////////////////////////// +// Bfloat struct +///////////////////////////////////////////////////////////////////////////// + +struct _MLX_BFloat16 { + ///////////////////////////////////////////////////////////////////////////// + // Constructors + uint16_t bits_; + _MLX_BFloat16() thread = default; + _MLX_BFloat16() threadgroup = default; + _MLX_BFloat16() device = default; + _MLX_BFloat16() constant = default; + + struct bits_to_bfloat_struct {}; + static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() { + return bits_to_bfloat_struct(); + } + constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct) + : bits_(bits) {} + + ///////////////////////////////////////////////////////////////////////////// + // Conversions to bfloat + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) thread + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) device + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) constant + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + ///////////////////////////////////////////////////////////////////////////// + // Conversions from bfloat + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const thread { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const threadgroup { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const device { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const constant { + return static_cast(bfloat_bits_to_float(bits_)); + } +}; + +///////////////////////////////////////////////////////////////////////////// +// Bfloat operators +///////////////////////////////////////////////////////////////////////////// + 
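When __HAVE_BFLOAT__ is not defined, the helpers above (float_to_bfloat_bits / bfloat_bits_to_float) emulate bfloat16 on top of uint16_t: the f32 is reinterpreted as bits, ((bits >> 16) & 1) + 0x7FFF is added so that truncating keeps the nearest-even result, and the upper 16 bits are stored; the reverse direction shifts the stored 16 bits back into the high half of an f32 pattern. A minimal Rust sketch of the same bit trick, illustrative only and not part of the patch:

    // f32 -> bf16 bits with round-to-nearest-even, mirroring float_to_bfloat_bits above.
    fn f32_to_bf16_bits(x: f32) -> u16 {
        if x.is_nan() {
            return 0x7FC0; // canonical quiet NaN
        }
        let bits = x.to_bits();
        // Rounding bias: 0x7FFF plus the lowest retained bit, so ties go to even.
        let rounded = bits.wrapping_add(((bits >> 16) & 1) + 0x7FFF);
        (rounded >> 16) as u16
    }

    // bf16 bits -> f32: the 16 stored bits become the high half of the f32 pattern.
    fn bf16_bits_to_f32(x: u16) -> f32 {
        f32::from_bits((x as u32) << 16)
    }

For instance, f32_to_bf16_bits(1.0) is 0x3F80, and bf16_bits_to_f32(0x3F80) is exactly 1.0.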
+///////////////////////////////////////////////////////////////////////////// +// Unary ops +constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) { + return -static_cast(x); +} + +///////////////////////////////////////////////////////////////////////////// +// Binary operators +#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ + constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \ + constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +///////////////////////////////////////////////////////////////////////////// +// Arithmetic Operators +#define bfloat_binop(_op_, _operator_) \ + bfloat_binop_base( \ + _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \ + bfloat_binop_helper(_op_, _operator_, float, float, float); \ + bfloat_binop_helper(_op_, _operator_, float, half, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float); + +bfloat_binop(+, operator+); +bfloat_binop(-, operator-); +bfloat_binop(*, operator*); +bfloat_binop(/, operator/); + +///////////////////////////////////////////////////////////////////////////// +// Comparison ops +#define bfloat_compop(__op__, __operator__) \ + bfloat_binop_base( \ + __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \ + bfloat_binop_helper(__op__, __operator__, bool, float, float); \ + bfloat_binop_helper(__op__, __operator__, bool, half, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float); + +bfloat_compop(>, operator>); +bfloat_compop(<, operator<); +bfloat_compop(>=, operator>=); +bfloat_compop(<=, operator<=); +bfloat_compop(==, operator==); +bfloat_compop(!=, operator!=); + +#undef bfloat_compop +#undef bfloat_binop_base +#undef bfloat_binop_helper +#undef bfloat_binop + +///////////////////////////////////////////////////////////////////////////// +// Inplace Operators +#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \ + constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \ + addr_space _MLX_BFloat16& lhs, itype rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } \ + constexpr METAL_FUNC addr_space itype& __operator__( \ + addr_space itype& lhs, _MLX_BFloat16 rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } + +#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \ + bfloat_inplace_op_helper(__op__, __operator__, itype, device); \ + bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \ + bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup); + +#define bfloat_inplace_op(itype) \ + bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \ + bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \ + 
bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \ + bfloat_inplace_op_addr_space_helper(/, operator/=, itype); + +bfloat_inplace_op(float); +bfloat_inplace_op(half); +bfloat_inplace_op(int16_t); +bfloat_inplace_op(int32_t); +bfloat_inplace_op(int64_t); +bfloat_inplace_op(uint16_t); +bfloat_inplace_op(uint32_t); +bfloat_inplace_op(uint64_t); + +#undef bfloat_inplace_op_helper +#undef bfloat_inplace_op_addr_space_helper +#undef bfloat_inplace_op + +#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \ + constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \ + addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } + +#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \ + bfloat_inplace_op_helper(__op__, __operator__, device); \ + bfloat_inplace_op_helper(__op__, __operator__, thread); \ + bfloat_inplace_op_helper(__op__, __operator__, threadgroup); + +bfloat_inplace_op_addr_space_helper(+, operator+=); +bfloat_inplace_op_addr_space_helper(-, operator-=); +bfloat_inplace_op_addr_space_helper(*, operator*=); +bfloat_inplace_op_addr_space_helper(/, operator/=); + +#undef bfloat_inplace_op_helper +#undef bfloat_inplace_op_addr_space_helper + +///////////////////////////////////////////////////////////////////////////// +// Bfloat typedef +///////////////////////////////////////////////////////////////////////////// + +typedef struct _MLX_BFloat16 bfloat16_t; + +#endif + #define MAX(x, y) ((x) > (y) ? (x) : (y)) #define MIN(x, y) ((x) < (y) ? (x) : (y)) #define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; } +#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 + +typedef matrix bfloat4x4; + +// QK = number of values after dequantization +// QK_K = super-block size + +#define QK_K 256 +#define K_SCALE_SIZE 12 + #define QK4_0 32 -#define QR4_0 2 typedef struct { - half d; // delta + half d; // delta uint8_t qs[QK4_0 / 2]; // nibbles / quants } block_q4_0; +static_assert(sizeof(block_q4_0) == sizeof(half) + QK4_0 / 2, "wrong q4_0 block size/padding"); #define QK4_1 32 typedef struct { - half d; // delta - half m; // min - uint8_t qs[QK4_1 / 2]; // nibbles / quants + union { + struct { + half d; // delta + half m; // min + }; + half2 dm; + }; + uint8_t qs[QK4_1 / 2]; // nibbles / quants } block_q4_1; +static_assert(sizeof(block_q4_1) == 2 * sizeof(half) + QK4_1 / 2, "wrong q4_1 block size/padding"); #define QK5_0 32 typedef struct { - half d; // delta + half d; // delta uint8_t qh[4]; // 5-th bit of quants uint8_t qs[QK5_0 / 2]; // nibbles / quants } block_q5_0; +static_assert(sizeof(block_q5_0) == sizeof(half) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); #define QK5_1 32 typedef struct { - half d; // delta - half m; // min - uint8_t qh[4]; // 5-th bit of quants - uint8_t qs[QK5_1 / 2]; // nibbles / quants + union { + struct { + half d; // delta + half m; // min + }; + half2 dm; + }; + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_1 / 2]; // nibbles / quants } block_q5_1; +static_assert(sizeof(block_q5_1) == 2 * sizeof(half) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); #define QK8_0 32 typedef struct { - half d; // delta + half d; // delta int8_t qs[QK8_0]; // quants } block_q8_0; +static_assert(sizeof(block_q8_0) == sizeof(half) + QK8_0, "wrong q8_0 block size/padding"); + +#define QK8_1 32 +typedef struct { + union { + struct { + half d; // delta + half s; // d * sum(qs[i]) + }; + half2 ds; + }; + int8_t qs[QK8_1]; // 
quants +} block_q8_1; +static_assert(sizeof(block_q8_1) == 2*sizeof(half) + QK8_1, "wrong q8_1 block size/padding"); + +typedef struct { + half d[4]; // deltas for 4 q4_0 blocks + uint8_t qs[QK4_0 * 2]; // nibbles / quants for 4 q4_0 blocks +} block_q4_0x4; +static_assert(sizeof(block_q4_0x4) == 4 * sizeof(half) + QK4_0 * 2, "wrong q4_0x4 block size/padding"); + +typedef struct { + half d[8]; // deltas for 8 q4_0 blocks + uint8_t qs[QK4_0 * 4]; // nibbles / quants for 8 q4_0 blocks +} block_q4_0x8; +static_assert(sizeof(block_q4_0x8) == 8 * sizeof(half) + QK4_0 * 4, "wrong q4_0x8 block size/padding"); + +typedef struct { + half d[4]; // deltas for 4 q8_0 blocks + int8_t qs[QK8_0 * 4]; // quants for 4 q8_0 blocks +} block_q8_0x4; +static_assert(sizeof(block_q8_0x4) == 4 * sizeof(half) + QK8_0 * 4, "wrong q8_0x4 block size/padding"); + +typedef struct { + half d[8]; // deltas for 8 q8_0 blocks + int8_t qs[QK8_0 * 8]; // quants for 8 q8_0 blocks +} block_q8_0x8; +static_assert(sizeof(block_q8_0x8) == 8 * sizeof(half) + QK8_0 * 8, "wrong q8_0x8 block size/padding"); + +// +// Ternary quantization +// + +// 1.6875 bpw +typedef struct { + uint8_t qs[(QK_K - 4 * QK_K / 64) / 5]; // 5 elements per byte (3^5 = 243 < 256) + uint8_t qh[QK_K/64]; // 4 elements per byte + half d; +} block_tq1_0; +static_assert(sizeof(block_tq1_0) == sizeof(half) + QK_K / 64 + (QK_K - 4 * QK_K / 64) / 5, "wrong tq1_0 block size/padding"); + +// 2.0625 bpw +typedef struct { + uint8_t qs[QK_K/4]; // 2 bits per element + half d; +} block_tq2_0; +static_assert(sizeof(block_tq2_0) == sizeof(half) + QK_K / 4, "wrong tq2_0 block size/padding"); + +// +// Super-block quantization structures +// + +// 2-bit quantization +// weight is represented as x = a * q + b +// 16 blocks of 16 elements each +// Effectively 2.625 bits per weight +typedef struct { + uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits + uint8_t qs[QK_K/4]; // quants + union { + struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + }; + half2 dm; + }; +} block_q2_K; +static_assert(sizeof(block_q2_K) == 2*sizeof(half) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); + +// 3-bit quantization +// weight is represented as x = a * q +// 16 blocks of 16 elements each +// Effectively 3.4375 bits per weight +typedef struct { + uint8_t hmask[QK_K/8]; // quants - high bit + uint8_t qs[QK_K/4]; // quants - low 2 bits + uint8_t scales[12]; // scales, quantized with 6 bits + half d; // super-block scale +} block_q3_K; +static_assert(sizeof(block_q3_K) == sizeof(half) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding"); + +// 4-bit quantization +// 8 blocks of 32 elements each +// weight is represented as x = a * q + b +// Effectively 4.5 bits per weight +typedef struct { + union { + struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + }; + half2 dm; + }; + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == 2*sizeof(half) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding"); + +// 5-bit quantization +// 8 blocks of 32 elements each +// weight is represented as x = a * q + b +// Effectively 5.5 bits per weight +typedef struct { + union { + struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + }; + half2 dm; + }; + uint8_t scales[K_SCALE_SIZE]; // 
scales and mins, quantized with 6 bits + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == 2*sizeof(half) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); + +// 6-bit quantization +// weight is represented as x = a * q +// 16 blocks of 16 elements each +// Effectively 6.5625 bits per weight +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales, quantized with 8 bits + half d; // super-block scale +} block_q6_K; +static_assert(sizeof(block_q6_K) == sizeof(half) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding"); + +// This is only used for intermediate quantization and dot products +typedef struct { + float d; // delta + int8_t qs[QK_K]; // quants + int16_t bsums[QK_K/16]; // sum of quants in groups of 16 +} block_q8_K; +static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding"); + +// (Almost) "true" 2-bit quantization. +// Due to the need to use blocks as per ggml design, it ends up using +// 2.0625 bpw because of the 16-bit scale for each block of 256. +typedef struct { + half d; + uint16_t qs[QK_K/8]; +} block_iq2_xxs; +static_assert(sizeof(block_iq2_xxs) == sizeof(half) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding"); + +// 2.3125 bpw quants +typedef struct { + half d; + uint16_t qs[QK_K/8]; + uint8_t scales[QK_K/32]; +} block_iq2_xs; +static_assert(sizeof(block_iq2_xs) == sizeof(half) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding"); + +// 2.5625 bpw quants +typedef struct { + half d; + uint8_t qs[QK_K/4]; + uint8_t qh[QK_K/32]; + uint8_t scales[QK_K/32]; +} block_iq2_s; +static_assert(sizeof(block_iq2_s) == sizeof(half) + QK_K/4 + QK_K/16, "wrong iq2_s block size/padding"); + +// (Almost) "true" 3-bit quantization. +// Due to the need to use blocks as per ggml design, it ends up using +// 3.0625 bpw because of the 16-bit scale for each block of 256. 
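The "effectively N bits per weight" figures in the comments above, and the 2.0625 / 3.0625 bpw notes for the iq2_xxs / iq3_xxs blocks, follow directly from the struct sizes over a 256-element super-block: bpw = 8 * sizeof(block) / QK_K. A quick standalone Rust check of that arithmetic, using the byte counts from the static_asserts above (not part of the patch):

    // bits per weight = 8 * sizeof(block) / QK_K, with QK_K = 256 weights per super-block.
    const QK_K: usize = 256;
    const K_SCALE_SIZE: usize = 12;

    fn bpw(block_bytes: usize) -> f64 {
        (block_bytes * 8) as f64 / QK_K as f64
    }

    fn main() {
        assert_eq!(bpw(4 + QK_K / 16 + QK_K / 4), 2.625);             // block_q2_K
        assert_eq!(bpw(2 + QK_K / 8 + QK_K / 4 + 12), 3.4375);        // block_q3_K
        assert_eq!(bpw(4 + K_SCALE_SIZE + QK_K / 2), 4.5);            // block_q4_K
        assert_eq!(bpw(4 + K_SCALE_SIZE + QK_K / 2 + QK_K / 8), 5.5); // block_q5_K
        assert_eq!(bpw(2 + QK_K / 16 + 3 * QK_K / 4), 6.5625);        // block_q6_K
        assert_eq!(bpw(2 + QK_K / 8 * 2), 2.0625);                    // block_iq2_xxs
        assert_eq!(bpw(2 + 3 * QK_K / 8), 3.0625);                    // block_iq3_xxs
    }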
+typedef struct { + half d; + uint8_t qs[3*QK_K/8]; +} block_iq3_xxs; +static_assert(sizeof(block_iq3_xxs) == sizeof(half) + 3*(QK_K/8), "wrong iq3_xxs block size/padding"); + +// 3.4375 bpw +#define IQ3S_N_SCALE QK_K/64 +typedef struct { + half d; + uint8_t qs[QK_K/4]; + uint8_t qh[QK_K/32]; + uint8_t signs[QK_K/8]; + uint8_t scales[IQ3S_N_SCALE]; +} block_iq3_s; +static_assert(sizeof(block_iq3_s) == sizeof(half) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding"); + +// 1.5625 bpw +typedef struct { + half d; + uint8_t qs[QK_K/8]; + uint16_t qh[QK_K/32]; +} block_iq1_s; +static_assert(sizeof(block_iq1_s) == sizeof(half) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding"); + +// 1.75 bpw +typedef struct { + uint8_t qs[QK_K/8]; // grid index, low 8 bits + uint8_t qh[QK_K/16]; // grid index, high 3 bits + grid shift bit (for two groups of 8) + uint8_t scales[QK_K/32]; // 3-bit block scales (4-bit if QK_K == 64) +} block_iq1_m; +static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m block size/padding"); + +// Used by IQ1_M quants +typedef union { + half f16; + uint16_t u16; +} iq1m_scale_t; + +// Non-linear quants +#define QK4_NL 32 +typedef struct { + half d; + uint8_t qs[QK4_NL/2]; +} block_iq4_nl; +static_assert(sizeof(block_iq4_nl) == sizeof(half) + QK4_NL/2, "wrong iq4_nl block size/padding"); + +typedef struct { + half d; + uint16_t scales_h; + uint8_t scales_l[QK_K/64]; + uint8_t qs[QK_K/2]; +} block_iq4_xs; +static_assert(sizeof(block_iq4_xs) == sizeof(half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); + +#define GGML_TABLE_BEGIN(type, name, size) static const constant type name[size] = { +#define GGML_TABLE_END() }; + +GGML_TABLE_BEGIN(uint8_t, kmask_iq2xs, 8) + 1, 2, 4, 8, 16, 32, 64, 128 +GGML_TABLE_END() + +GGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128) + 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15, + 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159, + 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175, + 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63, + 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207, + 80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95, + 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111, + 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255, +GGML_TABLE_END() + +//#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics +GGML_TABLE_BEGIN(uint64_t, ksigns64, 128) + 0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff, + 0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff, + 0xff000000ff000000, 0x00000000ff0000ff, 0x00000000ff00ff00, 0xff000000ff00ffff, + 0x00000000ffff0000, 0xff000000ffff00ff, 0xff000000ffffff00, 0x00000000ffffffff, + 0xff0000ff00000000, 0x000000ff000000ff, 0x000000ff0000ff00, 0xff0000ff0000ffff, + 0x000000ff00ff0000, 0xff0000ff00ff00ff, 0xff0000ff00ffff00, 0x000000ff00ffffff, + 0x000000ffff000000, 0xff0000ffff0000ff, 0xff0000ffff00ff00, 0x000000ffff00ffff, + 0xff0000ffffff0000, 0x000000ffffff00ff, 0x000000ffffffff00, 0xff0000ffffffffff, + 0xff00ff0000000000, 0x0000ff00000000ff, 0x0000ff000000ff00, 0xff00ff000000ffff, + 0x0000ff0000ff0000, 0xff00ff0000ff00ff, 0xff00ff0000ffff00, 0x0000ff0000ffffff, + 0x0000ff00ff000000, 0xff00ff00ff0000ff, 0xff00ff00ff00ff00, 0x0000ff00ff00ffff, + 0xff00ff00ffff0000, 
0x0000ff00ffff00ff, 0x0000ff00ffffff00, 0xff00ff00ffffffff, + 0x0000ffff00000000, 0xff00ffff000000ff, 0xff00ffff0000ff00, 0x0000ffff0000ffff, + 0xff00ffff00ff0000, 0x0000ffff00ff00ff, 0x0000ffff00ffff00, 0xff00ffff00ffffff, + 0xff00ffffff000000, 0x0000ffffff0000ff, 0x0000ffffff00ff00, 0xff00ffffff00ffff, + 0x0000ffffffff0000, 0xff00ffffffff00ff, 0xff00ffffffffff00, 0x0000ffffffffffff, + 0xffff000000000000, 0x00ff0000000000ff, 0x00ff00000000ff00, 0xffff00000000ffff, + 0x00ff000000ff0000, 0xffff000000ff00ff, 0xffff000000ffff00, 0x00ff000000ffffff, + 0x00ff0000ff000000, 0xffff0000ff0000ff, 0xffff0000ff00ff00, 0x00ff0000ff00ffff, + 0xffff0000ffff0000, 0x00ff0000ffff00ff, 0x00ff0000ffffff00, 0xffff0000ffffffff, + 0x00ff00ff00000000, 0xffff00ff000000ff, 0xffff00ff0000ff00, 0x00ff00ff0000ffff, + 0xffff00ff00ff0000, 0x00ff00ff00ff00ff, 0x00ff00ff00ffff00, 0xffff00ff00ffffff, + 0xffff00ffff000000, 0x00ff00ffff0000ff, 0x00ff00ffff00ff00, 0xffff00ffff00ffff, + 0x00ff00ffffff0000, 0xffff00ffffff00ff, 0xffff00ffffffff00, 0x00ff00ffffffffff, + 0x00ffff0000000000, 0xffffff00000000ff, 0xffffff000000ff00, 0x00ffff000000ffff, + 0xffffff0000ff0000, 0x00ffff0000ff00ff, 0x00ffff0000ffff00, 0xffffff0000ffffff, + 0xffffff00ff000000, 0x00ffff00ff0000ff, 0x00ffff00ff00ff00, 0xffffff00ff00ffff, + 0x00ffff00ffff0000, 0xffffff00ffff00ff, 0xffffff00ffffff00, 0x00ffff00ffffffff, + 0xffffffff00000000, 0x00ffffff000000ff, 0x00ffffff0000ff00, 0xffffffff0000ffff, + 0x00ffffff00ff0000, 0xffffffff00ff00ff, 0xffffffff00ffff00, 0x00ffffff00ffffff, + 0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff, + 0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff, +GGML_TABLE_END() +//#endif + + +GGML_TABLE_BEGIN(uint64_t, iq2xxs_grid, 256) + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808, + 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819, + 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819, + 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b, + 0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808, + 0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08, + 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b, + 0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819, + 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08, + 0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, + 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08, + 0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808, + 0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808, + 0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919, + 0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819, + 0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08, + 0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908, + 0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819, + 0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808, + 0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808, + 0x0819081919190808, 0x0819081919192b2b, 
0x081908192b080808, 0x0819082b082b1908, + 0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808, + 0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08, + 0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819, + 0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819, + 0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819, + 0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908, + 0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19, + 0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819, + 0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b, + 0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808, + 0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908, + 0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08, + 0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08, + 0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908, + 0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819, + 0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808, + 0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808, + 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19, + 0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819, + 0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, + 0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b, + 0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08, + 0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808, + 0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908, + 0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b, + 0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819, + 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08, + 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08, + 0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808, + 0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b, + 0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b, + 0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908, + 0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819, + 0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808, + 0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908, + 0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b, + 0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808, + 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b, + 0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b, + 0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808, + 0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19, + 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908, +GGML_TABLE_END() + +GGML_TABLE_BEGIN(uint64_t, iq2xs_grid, 512) + 
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b, + 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919, + 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b, + 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919, + 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808, + 0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819, + 0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819, + 0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, + 0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b, + 0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b, + 0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908, + 0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908, + 0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919, + 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808, + 0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919, + 0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908, + 0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, + 0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, + 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08, + 0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808, + 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808, + 0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819, + 0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908, + 0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819, + 0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808, + 0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b, + 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819, + 0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819, + 0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808, + 0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908, + 0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19, + 0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b, + 0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b, + 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919, + 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808, + 0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819, + 0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819, + 0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b, + 0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908, + 0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808, + 0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819, + 0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808, + 0x0819082b192b0808, 
0x0819190808080808, 0x081919080808082b, 0x0819190808081919, + 0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808, + 0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808, + 0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908, + 0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908, + 0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808, + 0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b, + 0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819, + 0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, + 0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908, + 0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808, + 0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908, + 0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919, + 0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08, + 0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19, + 0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b, + 0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b, + 0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808, + 0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08, + 0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b, + 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908, + 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b, + 0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908, + 0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, + 0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808, + 0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808, + 0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08, + 0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819, + 0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919, + 0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808, + 0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808, + 0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819, + 0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819, + 0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908, + 0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908, + 0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b, + 0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908, + 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908, + 0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908, + 0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808, + 0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, + 0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819, + 0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819, + 0x1919081908081908, 0x1919081908190808, 
0x1919081908191919, 0x1919081919080808, + 0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b, + 0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819, + 0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819, + 0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08, + 0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808, + 0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19, + 0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919, + 0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, + 0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19, + 0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b, + 0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808, + 0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b, + 0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b, + 0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, + 0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b, + 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808, + 0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819, + 0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808, + 0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808, + 0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08, + 0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b, + 0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19, + 0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08, + 0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919, + 0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08, + 0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08, + 0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908, + 0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908, + 0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b, + 0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908, + 0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808, + 0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b, + 0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808, + 0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808, + 0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19, + 0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08, + 0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808, + 0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b, + 0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808, + 0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b, + 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b, +GGML_TABLE_END() + +GGML_TABLE_BEGIN(uint64_t, iq2s_grid, 1024) + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b, + 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919, + 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b, + 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919, + 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b, + 0x08080808192b2b19, 0x080808082b080808, 0x080808082b08082b, 0x080808082b081919, + 0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b2b0808, + 0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819, 0x0808081908081908, + 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b, + 0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908, + 0x0808081919080808, 0x080808191908082b, 0x0808081919081919, 0x0808081919082b08, + 0x0808081919190819, 0x0808081919191908, 0x080808191919192b, 0x0808081919192b19, + 0x08080819192b0808, 0x08080819192b1919, 0x08080819192b2b08, 0x080808192b080819, + 0x080808192b081908, 0x080808192b190808, 0x080808192b19082b, 0x080808192b191919, + 0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b, + 0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, + 0x0808082b082b0808, 0x0808082b082b2b2b, 0x0808082b19080819, 0x0808082b19081908, + 0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808, 0x0808082b19191919, + 0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908, + 0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b, + 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, 0x0808190808191919, + 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, 0x08081908082b192b, + 0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, + 0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908, + 0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b, + 0x08081908192b1919, 0x080819082b080819, 0x080819082b081908, 0x080819082b08192b, + 0x080819082b082b19, 0x080819082b190808, 0x080819082b191919, 0x080819082b192b08, + 0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, + 0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819, + 0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808, + 0x08081919082b1919, 0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908, + 0x080819191908192b, 0x0808191919082b19, 0x0808191919190808, 0x080819191919082b, + 0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819, 0x08081919192b1908, + 0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08, + 0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819, + 0x0808192b08081908, 0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808, + 0x0808192b08191919, 0x0808192b19080808, 0x0808192b19081919, 0x0808192b19082b08, + 0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808, 0x0808192b2b080819, + 0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b, + 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908, + 0x08082b080819192b, 0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919, + 0x08082b08082b2b2b, 0x08082b0819080819, 0x08082b0819081908, 0x08082b081908192b, + 0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b, 0x08082b0819191919, + 0x08082b0819192b08, 
0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808, + 0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819, + 0x08082b1908081908, 0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919, + 0x08082b1908192b08, 0x08082b19082b0819, 0x08082b1919080808, 0x08082b1919081919, + 0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908, 0x08082b19192b0808, + 0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819, + 0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b, + 0x08082b2b19190808, 0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908, + 0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, + 0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, + 0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919, + 0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b, + 0x0819080819192b19, 0x08190808192b0808, 0x08190808192b082b, 0x08190808192b1919, + 0x08190808192b2b08, 0x081908082b080819, 0x081908082b081908, 0x081908082b08192b, + 0x081908082b190808, 0x081908082b191919, 0x081908082b192b08, 0x081908082b2b0819, + 0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919, + 0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908, + 0x081908190819192b, 0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b, + 0x08190819082b1919, 0x08190819082b2b08, 0x0819081919080819, 0x0819081919081908, + 0x081908191908192b, 0x0819081919082b19, 0x0819081919190808, 0x081908191919082b, + 0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908, + 0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08, + 0x081908192b190819, 0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908, + 0x0819082b08082b19, 0x0819082b08190808, 0x0819082b08191919, 0x0819082b082b0819, + 0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919, 0x0819082b19190819, + 0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808, + 0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08, + 0x0819190808190819, 0x0819190808191908, 0x081919080819192b, 0x0819190808192b19, + 0x08191908082b0808, 0x08191908082b1919, 0x08191908082b2b08, 0x0819190819080819, + 0x0819190819081908, 0x081919081908192b, 0x0819190819082b19, 0x0819190819190808, + 0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819, + 0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919, + 0x081919082b082b08, 0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808, + 0x0819191908080819, 0x0819191908081908, 0x081919190808192b, 0x0819191908082b19, + 0x0819191908190808, 0x081919190819082b, 0x0819191908191919, 0x0819191908192b08, + 0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b, + 0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908, + 0x08191919192b0808, 0x081919192b080819, 0x081919192b081908, 0x081919192b190808, + 0x0819192b08080808, 0x0819192b08081919, 0x0819192b08082b08, 0x0819192b08190819, + 0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819, 0x0819192b19081908, + 0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 0x08192b0808080819, + 0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808, + 0x08192b0808191919, 0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808, + 0x08192b081908082b, 0x08192b0819081919, 
0x08192b0819082b08, 0x08192b0819190819, + 0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819, 0x08192b082b081908, + 0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08, + 0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819, + 0x08192b1919081908, 0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b, + 0x08192b2b08081908, 0x08192b2b08190808, 0x08192b2b19080808, 0x08192b2b1919192b, + 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08, + 0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19, + 0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819, + 0x082b080819081908, 0x082b080819190808, 0x082b08081919082b, 0x082b080819191919, + 0x082b0808192b1908, 0x082b08082b080808, 0x082b08082b082b2b, 0x082b08082b191908, + 0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908, 0x082b081908190808, + 0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808, + 0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908, + 0x082b0819192b0808, 0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808, + 0x082b082b08080808, 0x082b082b08082b2b, 0x082b082b082b082b, 0x082b082b082b2b08, + 0x082b082b082b2b2b, 0x082b082b19081908, 0x082b082b19190808, 0x082b082b2b082b08, + 0x082b082b2b082b2b, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908, + 0x082b19080808192b, 0x082b190808082b19, 0x082b190808190808, 0x082b190808191919, + 0x082b190808192b08, 0x082b1908082b0819, 0x082b1908082b1908, 0x082b190819080808, + 0x082b19081908082b, 0x082b190819081919, 0x082b190819082b08, 0x082b190819190819, + 0x082b190819191908, 0x082b1908192b0808, 0x082b19082b080819, 0x082b19082b081908, + 0x082b19082b190808, 0x082b191908080808, 0x082b191908081919, 0x082b191908082b08, + 0x082b191908190819, 0x082b191908191908, 0x082b1919082b0808, 0x082b191919080819, + 0x082b191919081908, 0x082b191919190808, 0x082b1919192b192b, 0x082b19192b080808, + 0x082b192b08080819, 0x082b192b08081908, 0x082b192b08190808, 0x082b192b19080808, + 0x082b192b19192b19, 0x082b2b0808080808, 0x082b2b0808081919, 0x082b2b0808190819, + 0x082b2b0808191908, 0x082b2b0819080819, 0x082b2b0819081908, 0x082b2b0819190808, + 0x082b2b082b082b2b, 0x082b2b082b2b2b2b, 0x082b2b1908080819, 0x082b2b1908081908, + 0x082b2b1908190808, 0x082b2b192b191919, 0x082b2b2b08082b2b, 0x082b2b2b082b082b, + 0x082b2b2b192b1908, 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, + 0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, + 0x190808080819082b, 0x1908080808191919, 0x1908080808192b08, 0x1908080808192b2b, + 0x19080808082b0819, 0x19080808082b1908, 0x19080808082b192b, 0x1908080819080808, + 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b, + 0x1908080819190819, 0x1908080819191908, 0x190808081919192b, 0x1908080819192b19, + 0x19080808192b0808, 0x19080808192b082b, 0x19080808192b1919, 0x190808082b080819, + 0x190808082b081908, 0x190808082b190808, 0x190808082b191919, 0x190808082b192b08, + 0x190808082b2b0819, 0x190808082b2b1908, 0x1908081908080808, 0x190808190808082b, + 0x1908081908081919, 0x1908081908082b08, 0x1908081908190819, 0x1908081908191908, + 0x190808190819192b, 0x1908081908192b19, 0x19080819082b0808, 0x19080819082b082b, + 0x19080819082b1919, 0x1908081919080819, 0x1908081919081908, 0x190808191908192b, + 0x1908081919082b19, 0x1908081919190808, 0x190808191919082b, 0x1908081919191919, + 0x1908081919192b08, 0x19080819192b0819, 0x19080819192b1908, 
0x190808192b080808, + 0x190808192b08082b, 0x190808192b081919, 0x190808192b082b08, 0x190808192b190819, + 0x190808192b191908, 0x190808192b2b0808, 0x1908082b08080819, 0x1908082b08081908, + 0x1908082b08190808, 0x1908082b0819082b, 0x1908082b08191919, 0x1908082b08192b08, + 0x1908082b082b1908, 0x1908082b19080808, 0x1908082b19081919, 0x1908082b19082b08, + 0x1908082b19190819, 0x1908082b19191908, 0x1908082b192b0808, 0x1908082b2b080819, + 0x1908082b2b081908, 0x1908190808080808, 0x190819080808082b, 0x1908190808081919, + 0x1908190808082b08, 0x1908190808082b2b, 0x1908190808190819, 0x1908190808191908, + 0x190819080819192b, 0x1908190808192b19, 0x19081908082b0808, 0x19081908082b082b, + 0x19081908082b1919, 0x19081908082b2b08, 0x1908190819080819, 0x1908190819081908, + 0x190819081908192b, 0x1908190819082b19, 0x1908190819190808, 0x190819081919082b, + 0x1908190819191919, 0x1908190819192b08, 0x19081908192b0819, 0x19081908192b1908, + 0x190819082b080808, 0x190819082b08082b, 0x190819082b081919, 0x190819082b082b08, + 0x190819082b190819, 0x190819082b191908, 0x190819082b2b0808, 0x1908191908080819, + 0x1908191908081908, 0x190819190808192b, 0x1908191908082b19, 0x1908191908190808, + 0x190819190819082b, 0x1908191908191919, 0x1908191908192b08, 0x19081919082b0819, + 0x19081919082b1908, 0x1908191919080808, 0x190819191908082b, 0x1908191919081919, + 0x1908191919082b08, 0x1908191919190819, 0x1908191919191908, 0x19081919192b0808, + 0x19081919192b2b2b, 0x190819192b080819, 0x190819192b081908, 0x190819192b190808, + 0x1908192b08080808, 0x1908192b0808082b, 0x1908192b08081919, 0x1908192b08082b08, + 0x1908192b08190819, 0x1908192b08191908, 0x1908192b082b0808, 0x1908192b19080819, + 0x1908192b19081908, 0x1908192b19190808, 0x1908192b2b080808, 0x1908192b2b2b1919, + 0x19082b0808080819, 0x19082b0808081908, 0x19082b0808082b19, 0x19082b0808190808, + 0x19082b080819082b, 0x19082b0808191919, 0x19082b0808192b08, 0x19082b08082b0819, + 0x19082b08082b1908, 0x19082b0819080808, 0x19082b081908082b, 0x19082b0819081919, + 0x19082b0819082b08, 0x19082b0819190819, 0x19082b0819191908, 0x19082b08192b0808, + 0x19082b082b081908, 0x19082b082b190808, 0x19082b1908080808, 0x19082b190808082b, + 0x19082b1908081919, 0x19082b1908082b08, 0x19082b1908190819, 0x19082b1908191908, + 0x19082b19082b0808, 0x19082b1919080819, 0x19082b1919081908, 0x19082b1919190808, + 0x19082b192b080808, 0x19082b192b19192b, 0x19082b2b08080819, 0x19082b2b08081908, + 0x19082b2b08190808, 0x19082b2b19080808, 0x1919080808080808, 0x191908080808082b, + 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, 0x1919080808191908, + 0x191908080819192b, 0x1919080808192b19, 0x19190808082b0808, 0x19190808082b082b, + 0x19190808082b1919, 0x19190808082b2b08, 0x1919080819080819, 0x1919080819081908, + 0x191908081908192b, 0x1919080819082b19, 0x1919080819190808, 0x191908081919082b, + 0x1919080819191919, 0x1919080819192b08, 0x19190808192b0819, 0x19190808192b1908, + 0x191908082b080808, 0x191908082b08082b, 0x191908082b081919, 0x191908082b082b08, + 0x191908082b190819, 0x191908082b191908, 0x1919081908080819, 0x1919081908081908, + 0x191908190808192b, 0x1919081908082b19, 0x1919081908190808, 0x191908190819082b, + 0x1919081908191919, 0x1919081908192b08, 0x19190819082b0819, 0x19190819082b1908, + 0x1919081919080808, 0x191908191908082b, 0x1919081919081919, 0x1919081919082b08, + 0x1919081919190819, 0x1919081919191908, 0x19190819192b0808, 0x191908192b080819, + 0x191908192b081908, 0x191908192b190808, 0x1919082b08080808, 0x1919082b08081919, + 0x1919082b08082b08, 0x1919082b08190819, 0x1919082b08191908, 0x1919082b082b0808, + 
0x1919082b19080819, 0x1919082b19081908, 0x1919082b19190808, 0x1919082b192b2b19, + 0x1919082b2b080808, 0x1919190808080819, 0x1919190808081908, 0x191919080808192b, + 0x1919190808082b19, 0x1919190808190808, 0x191919080819082b, 0x1919190808191919, + 0x1919190808192b08, 0x19191908082b0819, 0x19191908082b1908, 0x1919190819080808, + 0x191919081908082b, 0x1919190819081919, 0x1919190819082b08, 0x1919190819190819, + 0x1919190819191908, 0x19191908192b0808, 0x191919082b080819, 0x191919082b081908, + 0x191919082b190808, 0x1919191908080808, 0x191919190808082b, 0x1919191908081919, + 0x1919191908082b08, 0x1919191908190819, 0x1919191908191908, 0x19191919082b0808, + 0x1919191919080819, 0x1919191919081908, 0x1919191919190808, 0x191919192b080808, + 0x1919192b08080819, 0x1919192b08081908, 0x1919192b08190808, 0x1919192b082b192b, + 0x1919192b19080808, 0x19192b0808080808, 0x19192b080808082b, 0x19192b0808081919, + 0x19192b0808082b08, 0x19192b0808190819, 0x19192b0808191908, 0x19192b08082b0808, + 0x19192b0819080819, 0x19192b0819081908, 0x19192b0819190808, 0x19192b0819192b2b, + 0x19192b082b080808, 0x19192b1908080819, 0x19192b1908081908, 0x19192b1908190808, + 0x19192b1919080808, 0x19192b2b08080808, 0x19192b2b08192b19, 0x19192b2b2b081919, + 0x19192b2b2b2b2b08, 0x192b080808080819, 0x192b080808081908, 0x192b08080808192b, + 0x192b080808190808, 0x192b08080819082b, 0x192b080808191919, 0x192b080808192b08, + 0x192b0808082b0819, 0x192b0808082b1908, 0x192b080819080808, 0x192b080819081919, + 0x192b080819082b08, 0x192b080819190819, 0x192b080819191908, 0x192b0808192b0808, + 0x192b08082b081908, 0x192b08082b190808, 0x192b081908080808, 0x192b08190808082b, + 0x192b081908081919, 0x192b081908082b08, 0x192b081908190819, 0x192b081908191908, + 0x192b0819082b0808, 0x192b081919080819, 0x192b081919081908, 0x192b081919190808, + 0x192b08192b080808, 0x192b08192b192b19, 0x192b082b08081908, 0x192b082b08190808, + 0x192b082b19080808, 0x192b082b1919192b, 0x192b082b2b2b0819, 0x192b190808080808, + 0x192b190808081919, 0x192b190808082b08, 0x192b190808190819, 0x192b190808191908, + 0x192b1908082b0808, 0x192b190819080819, 0x192b190819081908, 0x192b190819190808, + 0x192b19082b080808, 0x192b191908080819, 0x192b191908081908, 0x192b191908190808, + 0x192b191919080808, 0x192b191919082b2b, 0x192b1919192b2b08, 0x192b19192b19082b, + 0x192b192b08080808, 0x192b192b2b191908, 0x192b2b0808080819, 0x192b2b0808081908, + 0x192b2b0808190808, 0x192b2b08192b1919, 0x192b2b082b192b08, 0x192b2b1908080808, + 0x192b2b19082b2b2b, 0x192b2b2b1908082b, 0x192b2b2b2b2b0819, 0x2b08080808080808, + 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819, + 0x2b08080808191908, 0x2b08080808192b19, 0x2b080808082b0808, 0x2b080808082b1919, + 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808081919082b, + 0x2b08080819191919, 0x2b08080819192b08, 0x2b080808192b0819, 0x2b0808082b080808, + 0x2b0808082b081919, 0x2b0808082b190819, 0x2b0808082b191908, 0x2b08081908080819, + 0x2b08081908081908, 0x2b08081908082b19, 0x2b08081908190808, 0x2b0808190819082b, + 0x2b08081908191919, 0x2b08081908192b08, 0x2b080819082b0819, 0x2b080819082b1908, + 0x2b08081919080808, 0x2b0808191908082b, 0x2b08081919081919, 0x2b08081919082b08, + 0x2b08081919190819, 0x2b08081919191908, 0x2b0808192b080819, 0x2b0808192b081908, + 0x2b0808192b190808, 0x2b0808192b2b2b19, 0x2b08082b08080808, 0x2b08082b08081919, + 0x2b08082b08082b2b, 0x2b08082b08190819, 0x2b08082b08191908, 0x2b08082b19080819, + 0x2b08082b19081908, 0x2b08082b19190808, 0x2b08190808080819, 0x2b08190808081908, + 0x2b0819080808192b, 
0x2b08190808082b19, 0x2b08190808190808, 0x2b0819080819082b, + 0x2b08190808191919, 0x2b08190808192b08, 0x2b081908082b0819, 0x2b08190819080808, + 0x2b0819081908082b, 0x2b08190819081919, 0x2b08190819082b08, 0x2b08190819190819, + 0x2b08190819191908, 0x2b081908192b0808, 0x2b0819082b080819, 0x2b0819082b081908, + 0x2b0819082b190808, 0x2b08191908080808, 0x2b0819190808082b, 0x2b08191908081919, + 0x2b08191908082b08, 0x2b08191908190819, 0x2b08191908191908, 0x2b081919082b0808, + 0x2b08191919080819, 0x2b08191919081908, 0x2b08191919190808, 0x2b0819192b080808, + 0x2b0819192b082b2b, 0x2b08192b08080819, 0x2b08192b08081908, 0x2b08192b08190808, + 0x2b08192b082b2b19, 0x2b08192b19080808, 0x2b082b0808080808, 0x2b082b0808081919, + 0x2b082b0808190819, 0x2b082b0808191908, 0x2b082b0819080819, 0x2b082b0819081908, + 0x2b082b0819190808, 0x2b082b082b2b082b, 0x2b082b1908080819, 0x2b082b1908081908, + 0x2b082b1919080808, 0x2b082b19192b1919, 0x2b082b2b082b082b, 0x2b082b2b19192b08, + 0x2b082b2b19192b2b, 0x2b082b2b2b08082b, 0x2b082b2b2b2b082b, 0x2b19080808080819, + 0x2b19080808081908, 0x2b19080808082b19, 0x2b19080808190808, 0x2b1908080819082b, + 0x2b19080808191919, 0x2b19080808192b08, 0x2b190808082b1908, 0x2b19080819080808, + 0x2b1908081908082b, 0x2b19080819081919, 0x2b19080819082b08, 0x2b19080819190819, + 0x2b19080819191908, 0x2b190808192b0808, 0x2b1908082b080819, 0x2b1908082b081908, + 0x2b1908082b190808, 0x2b19081908080808, 0x2b19081908081919, 0x2b19081908190819, + 0x2b19081908191908, 0x2b19081919080819, 0x2b19081919081908, 0x2b19081919190808, + 0x2b19081919192b2b, 0x2b19082b08080819, 0x2b19082b08081908, 0x2b19082b08190808, + 0x2b19082b19080808, 0x2b19082b2b2b192b, 0x2b19190808080808, 0x2b1919080808082b, + 0x2b19190808081919, 0x2b19190808082b08, 0x2b19190808190819, 0x2b19190808191908, + 0x2b191908082b0808, 0x2b19190819080819, 0x2b19190819081908, 0x2b19190819190808, + 0x2b1919082b080808, 0x2b1919082b19192b, 0x2b19191908080819, 0x2b19191908081908, + 0x2b19191908190808, 0x2b19191919080808, 0x2b1919192b192b08, 0x2b1919192b2b0819, + 0x2b19192b08080808, 0x2b19192b1908192b, 0x2b19192b192b1908, 0x2b192b0808080819, + 0x2b192b0808081908, 0x2b192b0808190808, 0x2b192b08082b192b, 0x2b192b0819080808, + 0x2b192b082b2b2b19, 0x2b192b1908080808, 0x2b192b1919082b19, 0x2b192b191919082b, + 0x2b192b2b2b190808, 0x2b2b080808080808, 0x2b2b080808081919, 0x2b2b080808082b2b, + 0x2b2b080808191908, 0x2b2b0808082b082b, 0x2b2b0808082b2b2b, 0x2b2b080819080819, + 0x2b2b080819081908, 0x2b2b080819190808, 0x2b2b08082b2b082b, 0x2b2b08082b2b2b2b, + 0x2b2b081919080808, 0x2b2b0819192b1919, 0x2b2b082b0808082b, 0x2b2b082b08082b2b, + 0x2b2b082b082b082b, 0x2b2b082b082b2b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b08082b, + 0x2b2b082b2b082b08, 0x2b2b082b2b082b2b, 0x2b2b082b2b2b2b08, 0x2b2b190808080819, + 0x2b2b190808081908, 0x2b2b190808190808, 0x2b2b190819080808, 0x2b2b19082b082b19, + 0x2b2b19082b2b1908, 0x2b2b191908080808, 0x2b2b191908192b19, 0x2b2b192b19190819, + 0x2b2b2b0808082b2b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b082b, 0x2b2b2b1919191908, + 0x2b2b2b192b08192b, 0x2b2b2b2b08082b08, 0x2b2b2b2b08082b2b, 0x2b2b2b2b082b0808, + 0x2b2b2b2b082b082b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b2b082b08, 0x2b2b2b2b2b2b2b2b, +GGML_TABLE_END() + +GGML_TABLE_BEGIN(uint32_t, iq3xxs_grid, 256) + 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414, + 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, + 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404, + 0x04140414, 
0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e, + 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c, + 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c, + 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34, + 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c, + 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c, + 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04, + 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c, + 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414, + 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434, + 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c, + 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e, + 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24, + 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24, + 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c, + 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c, + 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14, + 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414, + 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e, + 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404, + 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c, + 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c, + 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14, + 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c, + 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c, + 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14, + 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14, + 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c, + 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, +GGML_TABLE_END() + +GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512) + 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305, + 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905, + 0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09, + 0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b, + 0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b, + 0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d, + 0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 
0x0105090f, 0x01050b03, + 0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505, + 0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03, + 0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901, + 0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d, + 0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303, + 0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501, + 0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105, + 0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505, + 0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101, + 0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707, + 0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b, + 0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01, + 0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f, + 0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305, + 0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103, + 0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509, + 0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503, + 0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b, + 0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f, + 0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f, + 0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f, + 0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109, + 0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f, + 0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509, + 0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501, + 0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303, + 0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f, + 0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907, + 0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703, + 0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03, + 0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01, + 0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01, + 0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903, + 0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505, + 0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b, + 0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107, + 
0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509, + 0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303, + 0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103, + 0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05, + 0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b, + 0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f, + 0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701, + 0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909, + 0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305, + 0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d, + 0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b, + 0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d, + 0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307, + 0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09, + 0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309, + 0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709, + 0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f, + 0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303, + 0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503, + 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b, + 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101, +GGML_TABLE_END() + +#define NGRID_IQ1S 2048 +#define IQ1S_DELTA 0.125f +#define IQ1M_DELTA 0.125f +GGML_TABLE_BEGIN(uint32_t, iq1s_grid_gpu, NGRID_IQ1S) + 0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000, + 0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101, + 0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200, + 0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212, + 0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011, + 0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111, + 0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220, + 0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022, + 0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220, + 0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101, + 0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110, + 0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111, + 0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010, + 0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 
0x02011011, 0x02011111, 0x02011210, + 0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221, + 0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021, + 0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002, + 0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101, + 0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101, + 0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211, + 0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110, + 0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022, + 0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121, + 0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220, + 0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001, + 0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101, + 0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102, + 0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012, + 0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010, + 0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111, + 0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122, + 0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222, + 0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001, + 0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102, + 0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101, + 0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000, + 0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101, + 0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112, + 0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110, + 0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211, + 0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012, + 0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111, + 0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120, + 0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122, + 0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121, + 0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221, + 0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001, + 0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101, + 0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101, + 0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 
0x00122011, + 0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111, + 0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011, + 0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122, + 0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121, + 0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222, + 0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101, + 0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000, + 0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200, + 0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110, + 0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112, + 0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222, + 0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021, + 0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121, + 0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201, + 0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200, + 0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101, + 0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011, + 0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010, + 0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211, + 0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121, + 0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000, + 0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202, + 0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202, + 0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211, + 0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112, + 0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020, + 0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121, + 0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222, + 0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102, + 0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100, + 0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110, + 0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011, + 0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111, + 0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110, + 0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121, + 0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222, + 0x11020120, 
0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201, + 0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102, + 0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201, + 0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012, + 0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010, + 0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010, + 0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110, + 0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011, + 0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212, + 0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021, + 0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021, + 0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021, + 0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101, + 0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101, + 0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100, + 0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010, + 0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111, + 0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010, + 0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111, + 0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 0x11002120, + 0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120, + 0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101, + 0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001, + 0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201, + 0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210, + 0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211, + 0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111, + 0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112, + 0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211, + 0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010, + 0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021, + 0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122, + 0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221, + 0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102, + 0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100, + 0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101, + 0x11111102, 0x11111200, 0x11111201, 
0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101, + 0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101, + 0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012, + 0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110, + 0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112, + 0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210, + 0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210, + 0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210, + 0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010, + 0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110, + 0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122, + 0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020, + 0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021, + 0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022, + 0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120, + 0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222, + 0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221, + 0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001, + 0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102, + 0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201, + 0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012, + 0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111, + 0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012, + 0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110, + 0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110, + 0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121, + 0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221, + 0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220, + 0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222, + 0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000, + 0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201, + 0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012, + 0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011, + 0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212, + 0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221, + 0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121, + 0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 
0x10211101, 0x10211102, 0x10211202, + 0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202, + 0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002, + 0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101, + 0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210, + 0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112, + 0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011, + 0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011, + 0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 0x12221112, 0x12221210, + 0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020, + 0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220, + 0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222, + 0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222, + 0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001, + 0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010, + 0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111, + 0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010, + 0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110, + 0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221, + 0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122, + 0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202, + 0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100, + 0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101, + 0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112, + 0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111, + 0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211, + 0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222, + 0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221, + 0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022, + 0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101, + 0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211, + 0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111, + 0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111, + 0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010, + 0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121, + 0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222, + 0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 
0x20002000, + 0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202, + 0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000, + 0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202, + 0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110, + 0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110, + 0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222, + 0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120, + 0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022, + 0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101, + 0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202, + 0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110, + 0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110, + 0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111, + 0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111, + 0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120, + 0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121, + 0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001, + 0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202, + 0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001, + 0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200, + 0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011, + 0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212, + 0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012, + 0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110, + 0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012, + 0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111, + 0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020, + 0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121, + 0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222, + 0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102, + 0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102, + 0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101, + 0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212, + 0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210, + 0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111, + 0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212, + 0x22122011, 
0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221, + 0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121, + 0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002, + 0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000, + 0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202, + 0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112, + 0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111, + 0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020, + 0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221, + 0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022, + 0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100, + 0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201, + 0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112, + 0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211, + 0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012, + 0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121, + 0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020, + 0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120, + 0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200, + 0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200, + 0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110, + 0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011, + 0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222, + 0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020, + 0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222, +GGML_TABLE_END() -#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 enum ggml_sort_order { - GGML_SORT_ASC, - GGML_SORT_DESC, + GGML_SORT_ORDER_ASC, + GGML_SORT_ORDER_DESC, }; -// general-purpose kernel for addition, multiplication and division of two tensors +// general-purpose kernel for addition, subtraction, multiplication and division of two tensors // pros: works for non-contiguous tensors, supports broadcast across all dims // cons: not very efficient kernel void kernel_add( @@ -102,6 +1481,56 @@ kernel void kernel_add( } } +kernel void kernel_sub( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + 
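For reference, in these binary-kernel signatures ne00..ne03 / ne10..ne13 are element counts per dimension of src0 / src1 and nb00..nb03 / nb10..nb13 are the matching byte strides; broadcasting src1 into src0 is done by wrapping each src1 index with %, as in the i13 = i03 % ne13 and i10 = i0 % ne10 lines of the kernel body. A rough CPU-side C++ sketch of the same indexing scheme, illustrative only and not part of the patch (all names here are made up):

    #include <cstdint>
    #include <cstdio>

    // Reference loop: dst = src0 - src1 with src1 broadcast into src0's 4-D shape.
    // ne*[d] are element counts per dimension, nb*[d] are byte strides per dimension.
    static void binary_broadcast_sub(const char* src0, const char* src1, char* dst,
                                     const int64_t ne0[4], const uint64_t nb0[4],
                                     const int64_t ne1[4], const uint64_t nb1[4],
                                     const uint64_t nbd[4]) {
        for (int64_t i3 = 0; i3 < ne0[3]; ++i3)
        for (int64_t i2 = 0; i2 < ne0[2]; ++i2)
        for (int64_t i1 = 0; i1 < ne0[1]; ++i1)
        for (int64_t i0 = 0; i0 < ne0[0]; ++i0) {
            const float a = *(const float*)(src0 + i3*nb0[3] + i2*nb0[2] + i1*nb0[1] + i0*nb0[0]);
            // wrap each src1 index so smaller tensors repeat along the larger dims
            const float b = *(const float*)(src1 + (i3 % ne1[3])*nb1[3] + (i2 % ne1[2])*nb1[2]
                                                 + (i1 % ne1[1])*nb1[1] + (i0 % ne1[0])*nb1[0]);
            *(float*)(dst + i3*nbd[3] + i2*nbd[2] + i1*nbd[1] + i0*nbd[0]) = a - b;
        }
    }

    int main() {
        // subtract a length-4 row (broadcast) from a contiguous 2x4 f32 matrix
        float a[8] = {1, 2, 3, 4, 5, 6, 7, 8}, b[4] = {1, 1, 1, 1}, d[8];
        const int64_t  ne0[4] = {4, 2, 1, 1}, ne1[4] = {4, 1, 1, 1};
        const uint64_t nb0[4] = {4, 16, 32, 32}, nb1[4] = {4, 4, 4, 4}, nbd[4] = {4, 16, 32, 32};
        binary_broadcast_sub((const char*)a, (const char*)b, (char*)d, ne0, nb0, ne1, nb1, nbd);
        std::printf("%g %g\n", d[0], d[7]); // prints "0 7"
        return 0;
    }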
constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int64_t & offs, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int i10 = i0 % ne10; + *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10)); + } +} + kernel void kernel_mul( device const char * src0, device const char * src1, @@ -200,6 +1629,53 @@ kernel void kernel_div( } } +template +kernel void kernel_repeat( + device const char * src0, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3 % ne03; + const int64_t i02 = i2 % ne02; + const int64_t i01 = i1 % ne01; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int i00 = i0 % ne00; + *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00)); + } +} + +typedef decltype(kernel_repeat) kernel_repeat_t; + +template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat; + // assumption: src1 is a row // broadcast src1 into src0 kernel void kernel_add_row( @@ -211,6 +1687,15 @@ kernel void kernel_add_row( dst[tpig] = src0[tpig] + src1[tpig % nb]; } +kernel void kernel_sub_row( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + constant uint64_t & nb [[buffer(28)]], + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] - src1[tpig % nb]; +} + kernel void kernel_mul_row( device const float4 * src0, device const float4 * src1, @@ -245,6 +1730,15 @@ kernel void kernel_scale_4( dst[tpig] = src0[tpig] * scale; } +kernel void kernel_clamp( + device const float * src0, + device float * dst, + constant float & min, + constant float & max, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? 
max : src0[tpig]); +} + kernel void kernel_relu( device const float * src0, device float * dst, @@ -252,6 +1746,13 @@ kernel void kernel_relu( dst[tpig] = max(0.0f, src0[tpig]); } +kernel void kernel_sigmoid( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig])); +} + kernel void kernel_tanh( device const float * src0, device float * dst, @@ -265,6 +1766,15 @@ constant float GELU_QUICK_COEF = -1.702f; constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; kernel void kernel_gelu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -278,6 +1788,15 @@ kernel void kernel_gelu( } kernel void kernel_gelu_quick( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + +kernel void kernel_gelu_quick_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -287,6 +1806,14 @@ kernel void kernel_gelu_quick( } kernel void kernel_silu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = x / (1.0f + exp(-x)); +} + +kernel void kernel_silu_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -301,6 +1828,27 @@ kernel void kernel_sqr( dst[tpig] = src0[tpig] * src0[tpig]; } +kernel void kernel_sqrt( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = sqrt(src0[tpig]); +} + +kernel void kernel_sin( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = sin(src0[tpig]); +} + +kernel void kernel_cos( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = cos(src0[tpig]); +} + kernel void kernel_sum_rows( device const float * src0, device float * dst, @@ -349,15 +1897,20 @@ kernel void kernel_sum_rows( dst_row[0] = row_sum; } +template kernel void kernel_soft_max( - device const float * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, constant float & scale, - threadgroup float * buf [[threadgroup(0)]], + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], @@ -367,15 +1920,27 @@ kernel void kernel_soft_max( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - device const float * pmask = src1 != src0 ? 
src1 + i01*ne00 : nullptr; - device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr; + device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + const int64_t h = i02; + + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } // parallel max float lmax = -INFINITY; for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)); + lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); } // find the max value in the block @@ -400,7 +1965,7 @@ kernel void kernel_soft_max( // parallel sum float lsum = 0.0f; for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val); + const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val); lsum += exp_psrc0; pdst[i00] = exp_psrc0; } @@ -435,15 +2000,20 @@ kernel void kernel_soft_max( } } +template kernel void kernel_soft_max_4( - device const float * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, constant float & scale, - threadgroup float * buf [[threadgroup(0)]], + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], @@ -453,15 +2023,26 @@ kernel void kernel_soft_max_4( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr; - device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr; + device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + + float slope = 1.0f; + + if (max_bias > 0.0f) { + const int64_t h = i02; + + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } // parallel max float4 lmax4 = -INFINITY; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)); + lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))); } const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); @@ -487,7 +2068,7 @@ kernel void kernel_soft_max_4( // parallel sum float4 lsum4 = 0.0f; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? 
pmask[i00] : 0.0f)) - max_val); + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val); lsum4 += exp_psrc4; pdst4[i00] = exp_psrc4; } @@ -524,6 +2105,14 @@ kernel void kernel_soft_max_4( } } +typedef decltype(kernel_soft_max) kernel_soft_max_t; +typedef decltype(kernel_soft_max_4) kernel_soft_max_4_t; + +template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max; +template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max; +template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; +template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; + kernel void kernel_diag_mask_inf( device const float * src0, device float * dst, @@ -569,12 +2158,133 @@ kernel void kernel_diag_mask_inf_8( } } -kernel void kernel_norm( +// ref: ggml.c:ggml_compute_forward_ssm_conv_f32 +// TODO: optimize +kernel void kernel_ssm_conv_f32( device const void * src0, + device const void * src1, device float * dst, constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, constant uint64_t & nb01, - constant float & eps, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t ir = tgpig.x; + const int64_t i2 = tgpig.y; + const int64_t i3 = tgpig.z; + + const int64_t nc = ne10; + const int64_t ncs = ne00; + const int64_t nr = ne01; + const int64_t n_t = ne1; + const int64_t n_s = ne2; + + device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02); + device const float * c = (device const float *) ((device const char *) src1 + ir*nb11); + device float * x = (device float *) ((device char *) dst + ir*nb0 + i2*nb1 + i3*nb2); + + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc; ++i0) { + sumf += s[i0] * c[i0]; + } + + x[0] = sumf; +} + +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32 +// TODO: optimize +kernel void kernel_ssm_scan_f32( + device const void * src0, + device const void * src1, + device const void * src2, + device const void * src3, + device const void * src4, + device const void * src5, + device float * dst, + constant int64_t & d_state, + constant int64_t & d_inner, + constant int64_t & n_seq_tokens, + constant int64_t & n_seqs, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant uint64_t & nb20, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb30, + constant uint64_t & nb31, + constant uint64_t & nb40, + constant uint64_t & nb41, + constant uint64_t & nb42, + constant uint64_t & nb50, + constant uint64_t & nb51, + constant uint64_t & nb52, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t ir = tgpig.x; + const int64_t i3 = tgpig.y; + + const int64_t nc = d_state; + const int64_t nr = d_inner; + const int64_t n_t = n_seq_tokens; + const int64_t n_s = n_seqs; + + 
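The loop that follows is the selective state-space scan recurrence (as used in Mamba-style models), applied independently per row and per sequence: dt is passed through a softplus, the state is updated as s = exp(dt*A)*s + B*(x*dt), and each output is the dot product of the new state with C. A scalar C++ sketch of one row's scan, illustrative only and not part of the patch (the container-based names are made up):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // CPU reference for the per-row recurrence computed by kernel_ssm_scan_f32.
    // Returns the final state; y receives one output value per time step.
    static std::vector<float> ssm_scan_row(std::vector<float> s,                      // state, size d_state
                                           const std::vector<float>& x,               // input per time step
                                           const std::vector<float>& dt,              // raw delta per time step
                                           const std::vector<float>& A,               // size d_state
                                           const std::vector<std::vector<float>>& B,  // [n_t][d_state]
                                           const std::vector<std::vector<float>>& C,  // [n_t][d_state]
                                           std::vector<float>& y) {
        const std::size_t n_t = x.size();
        const std::size_t nc  = s.size();
        y.assign(n_t, 0.0f);
        for (std::size_t t = 0; t < n_t; ++t) {
            // softplus with the same large-value shortcut as the kernel
            const float dt_sp = dt[t] <= 20.0f ? std::log(1.0f + std::exp(dt[t])) : dt[t];
            const float x_dt  = x[t] * dt_sp;
            float sum = 0.0f;
            for (std::size_t i = 0; i < nc; ++i) {
                const float state = s[i] * std::exp(dt_sp * A[i]) + B[t][i] * x_dt;
                sum  += state * C[t][i];
                s[i]  = state;
            }
            y[t] = sum;
        }
        return s;
    }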
for (int64_t i2 = 0; i2 < n_t; ++i2) { + device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02); + device const float * x = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12); + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); + device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); + device const float * B = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42); + device const float * C = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52); + device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides + device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13); + + if (i2 > 0) { + s0 = s; + } + + // i1 == 0 + float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; + float x_dt = x[0] * dt_soft_plus; + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc; ++i0) { + int64_t i = i0; + float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt); + sumf += state * C[i0]; + s[i] = state; + } + + y[0] = sumf; + } +} + +kernel void kernel_norm( + device const void * src0, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant float & eps, threadgroup float * sum [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], @@ -863,6 +2573,7 @@ void mul_vec_q_n_f32_impl( int64_t ne1, uint r2, uint r3, + threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, uint sgitg) { const int nb = ne00/QK4_0; @@ -939,7 +2650,7 @@ kernel void kernel_mul_mv_q4_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q4_1_f32( @@ -965,7 +2676,7 @@ kernel void kernel_mul_mv_q4_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q5_0_f32( @@ -991,7 +2702,7 @@ kernel void kernel_mul_mv_q5_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q5_1_f32( @@ -1017,7 +2728,7 @@ kernel void kernel_mul_mv_q5_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } @@ -1027,18 +2738,19 @@ void kernel_mul_mv_q8_0_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - 
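As with the mul_vec_q_n_f32_impl change above, this hunk turns the _impl helper's constant-address-space kernel arguments into plain by-value parameters and threads an optional threadgroup scratch pointer through it, presumably so the same impl can be reused from several entry points; the q4_0..q8_0 entry kernels now just forward their bound arguments and pass nullptr for the unused scratch space. A minimal C++ sketch of that wrapper/impl split, illustrative only and not part of the patch (names are hypothetical):

    #include <cstdint>

    // Impl takes plain values plus an optional scratch pointer, so it can be
    // called directly from more than one entry point.
    static void mul_mv_impl(const float* src0, const float* src1, float* dst,
                            int64_t ne00, int64_t ne01,
                            int8_t* shared_values /* may be null */) {
        (void)shared_values; // only some quantization types need scratch space
        for (int64_t r = 0; r < ne01; ++r) {
            float sum = 0.0f;
            for (int64_t i = 0; i < ne00; ++i) {
                sum += src0[r*ne00 + i] * src1[i];
            }
            dst[r] = sum;
        }
    }

    // Entry point analogous to kernel_mul_mv_q8_0_f32: no scratch needed, forward nullptr.
    static void mul_mv_entry(const float* src0, const float* src1, float* dst,
                             int64_t ne00, int64_t ne01) {
        mul_mv_impl(src0, src1, dst, ne00, ne01, /*shared_values=*/nullptr);
    }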
constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nr = N_DST; const int nsg = N_SIMDGROUP; const int nw = N_SIMDWIDTH; @@ -1116,36 +2828,36 @@ kernel void kernel_mul_mv_q8_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } -#define N_F32_F32 4 +#define N_MV_T_T 4 -void kernel_mul_mv_f32_f32_impl( +template +void kernel_mul_mv_impl( device const char * src0, device const char * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + uint3 tgpig, + uint tiisg) { const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_F32_F32; + const int64_t rb = tgpig.y*N_MV_T_T; const int64_t im = tgpig.z; const uint i12 = im%ne12; @@ -1153,20 +2865,20 @@ void kernel_mul_mv_f32_f32_impl( const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - device const float * x = (device const float *) (src0 + offset0); + device const T0 * x = (device const T0 *) (src0 + offset0); if (ne00 < 128) { - for (int row = 0; row < N_F32_F32; ++row) { + for (int row = 0; row < N_MV_T_T; ++row) { int r1 = rb + row; if (r1 >= ne11) { break; } - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12); float sumf = 0; for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; + sumf += (T0) x[i] * (T1) y[i]; } float all_sum = simd_sum(sumf); @@ -1175,32 +2887,32 @@ void kernel_mul_mv_f32_f32_impl( } } } else { - device const float4 * x4 = (device const float4 *)x; - for (int row = 0; row < N_F32_F32; ++row) { + device const T04 * x4 = (device const T04 *) x; + for (int row = 0; row < N_MV_T_T; ++row) { int r1 = rb + row; if (r1 >= ne11) { break; } - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - device const float4 * y4 = (device const float4 *) y; + device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12); + device const T14 * y4 = (device const T14 *) y; float sumf = 0; for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 
0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; + for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]); dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; } } } } -[[host_name("kernel_mul_mv_f32_f32")]] -kernel void kernel_mul_mv_f32_f32( +template +kernel void kernel_mul_mv( device const char * src0, device const char * src1, device float * dst, @@ -1222,90 +2934,38 @@ kernel void kernel_mul_mv_f32_f32( constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); + kernel_mul_mv_impl( + src0, + src1, + dst, + ne00, + ne01, + ne02, + nb00, + nb01, + nb02, + ne10, + ne11, + ne12, + nb10, + nb11, + nb12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg); } -#define N_F16_F16 4 - -kernel void kernel_mul_mv_f16_f16( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - - const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_F16_F16; - const int64_t im = tgpig.z; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - - device const half * x = (device const half *) (src0 + offset0); - - if (ne00 < 128) { - for (int row = 0; row < N_F16_F16; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - sumf += (half) x[i] * (half) y[i]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - device const half4 * x4 = (device const half4 *)x; - for (int row = 0; row < N_F16_F16; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12); - device const half4 * y4 = (device const half4 *) y; - - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k]; - } +typedef decltype(kernel_mul_mv) mul_mv_t; - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} +template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv; +template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv; +template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv; -void kernel_mul_mv_f16_f32_1row_impl( +template +kernel void kernel_mul_mv_1row( device const char * src0, device const char * src1, device float * dst, @@ -1337,7 +2997,7 @@ void kernel_mul_mv_f16_f32_1row_impl( const uint 
offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - device const half * x = (device const half *) (src0 + offset0); + device const T * x = (device const T *) (src0 + offset0); device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); float sumf = 0; @@ -1350,48 +3010,29 @@ void kernel_mul_mv_f16_f32_1row_impl( dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; } } else { - device const half4 * x4 = (device const half4 *) x; + device const T4 * x4 = (device const T4 *) x; device const float4 * y4 = (device const float4 *) y; + for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k]; + for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); } + float all_sum = simd_sum(sumf); + if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]); dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; } } } -[[host_name("kernel_mul_mv_f16_f32_1row")]] -kernel void kernel_mul_mv_f16_f32_1row( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); -} +typedef decltype(kernel_mul_mv_1row) mul_mv_1row_t; -#define N_F16_F32 4 +template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row; -void kernel_mul_mv_f16_f32_impl( +// Assumes row size (ne00) is a multiple of 4 +template +kernel void kernel_mul_mv_l4( device const char * src0, device const char * src1, device float * dst, @@ -1414,8 +3055,8 @@ void kernel_mul_mv_f16_f32_impl( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { + const int nrows = ne11; const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_F16_F32; const int64_t im = tgpig.z; const uint i12 = im%ne12; @@ -1423,193 +3064,37 @@ void kernel_mul_mv_f16_f32_impl( const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - device const half * x = (device const half *) (src0 + offset0); - - if (ne00 < 128) { - for (int row = 0; row < N_F16_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } + device const T4 * x4 = (device const T4 *) (src0 + offset0); - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + for (int r1 = 0; r1 < nrows; ++r1) { + device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12); - float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; - } + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); + } - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; } - } else { - device const half4 * x4 = (device const half4 *)x; - 
for (int row = 0; row < N_F16_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } + } +} - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - device const float4 * y4 = (device const float4 *) y; +typedef decltype(kernel_mul_mv_l4) mul_mv_l4_t; - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; - } +template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4; - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -[[host_name("kernel_mul_mv_f16_f32")]] -kernel void kernel_mul_mv_f16_f32( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); -} - -// Assumes row size (ne00) is a multiple of 4 -kernel void kernel_mul_mv_f16_f32_l4( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - - const int nrows = ne11; - const int64_t r0 = tgpig.x; - const int64_t im = tgpig.z; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - - device const half4 * x4 = (device const half4 *) (src0 + offset0); - - for (int r1 = 0; r1 < nrows; ++r1) { - device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } -} - -kernel void kernel_alibi_f32( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant float & m0, - constant float & m1, - constant int & n_heads_log2_floor, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 
tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - const int64_t k = i3*ne3 + i2; - - float m_k; - if (k < n_heads_log2_floor) { - m_k = pow(m0, k + 1); - } else { - m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1); - } - - device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1; - device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01; - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - const float src_v = *(device float *)(src_row + i00*nb00); - device float * dst_v = (device float *)(dst_row + i00*nb0); - *dst_v = i00 * m_k + src_v; - } -} - -static float rope_yarn_ramp(const float low, const float high, const int i0) { - const float y = (i0 / 2 - low) / max(0.001f, high - low); - return 1.0f - min(1.0f, max(0.0f, y)); +static float rope_yarn_ramp(const float low, const float high, const int i0) { + const float y = (i0 / 2 - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); } // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. static void rope_yarn( float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, - thread float * cos_theta, thread float * sin_theta -) { + thread float * cos_theta, thread float * sin_theta) { // Get n-d rotational scaling corrected for extrapolation float theta_interp = freq_scale * theta_extrap; float theta = theta_interp; @@ -1626,21 +3111,23 @@ static void rope_yarn( // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` -static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) { - return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base)); +static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) { + return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base)); } static void rope_yarn_corr_dims( - int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2] + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2] ) { // start and end correction dims - dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base))); - dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base))); + dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))); + dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base))); } -typedef void (rope_t)( +template +kernel void kernel_rope_norm( device const void * src0, device const int32_t * src1, + device const float * src2, device float * dst, constant int64_t & ne00, constant int64_t & ne01, @@ -1660,8 +3147,7 @@ typedef void (rope_t)( constant uint64_t & nb3, constant int & n_past, constant int & n_dims, - constant int & mode, - constant int & n_orig_ctx, + constant int & 
n_ctx_orig, constant float & freq_base, constant float & freq_scale, constant float & ext_factor, @@ -1670,12 +3156,55 @@ typedef void (rope_t)( constant float & beta_slow, uint tiitg[[thread_index_in_threadgroup]], uint3 tptg[[threads_per_threadgroup]], - uint3 tgpig[[threadgroup_position_in_grid]]); + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int64_t i3 = tgpig[2]; + const int64_t i2 = tgpig[1]; + const int64_t i1 = tgpig[0]; + + float corr_dims[2]; + rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + device const int32_t * pos = src1; + + const float theta_base = (float) pos[i2]; + const float inv_ndims = -1.f/n_dims; + + float cos_theta; + float sin_theta; + + for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { + if (i0 < n_dims) { + const int64_t ic = i0/2; + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = src[0]; + const float x1 = src[1]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[1] = x0*sin_theta + x1*cos_theta; + } else { + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} template -kernel void kernel_rope( +kernel void kernel_rope_neox( device const void * src0, device const int32_t * src1, + device const float * src2, device float * dst, constant int64_t & ne00, constant int64_t & ne01, @@ -1695,8 +3224,7 @@ kernel void kernel_rope( constant uint64_t & nb3, constant int & n_past, constant int & n_dims, - constant int & mode, - constant int & n_orig_ctx, + constant int & n_ctx_orig, constant float & freq_base, constant float & freq_scale, constant float & ext_factor, @@ -1710,75 +3238,77 @@ kernel void kernel_rope( const int64_t i2 = tgpig[1]; const int64_t i1 = tgpig[0]; - const bool is_neox = mode & 2; - float corr_dims[2]; - rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); + rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); device const int32_t * pos = src1; - const int64_t p = pos[i2]; - - const float theta_0 = (float)p; + const float theta_base = (float) pos[i2]; const float inv_ndims = -1.f/n_dims; - if (!is_neox) { - for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { - - const float theta = theta_0 * pow(freq_base, inv_ndims*i0); - float cos_theta, sin_theta; - rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); - - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const T x0 = src[0]; - const T x1 = src[1]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[1] = x0*sin_theta + x1*cos_theta; - } - } else { - for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) { - if (ic < n_dims) { - const int64_t ib = 0; + float cos_theta; + float sin_theta; - // simplified from `(ib * n_dims + ic) * inv_ndims` - const float cur_rot 
= inv_ndims*ic - ib; + for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { + if (i0 < n_dims) { + const int64_t ic = i0/2; - const float theta = theta_0 * pow(freq_base, cur_rot); - float cos_theta, sin_theta; - rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); + const float theta = theta_base * pow(freq_base, inv_ndims*i0); - const int64_t i0 = ib*n_dims + ic/2; + const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); - const float x0 = src[0]; - const float x1 = src[n_dims/2]; + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; - } else { - const int64_t i0 = ic; + const float x0 = src[0]; + const float x1 = src[n_dims/2]; - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } else { + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } + dst_data[0] = src[0]; + dst_data[1] = src[1]; } } } -template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope; -template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope; +typedef decltype(kernel_rope_norm) kernel_rope_norm_t; +typedef decltype(kernel_rope_neox) kernel_rope_neox_t; -kernel void kernel_im2col_f16( +template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm; +template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm; + +template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox; +template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox; + +typedef void (im2col_t)( + device const float * x, + device char * dst, + constant int32_t & ofs0, + constant int32_t & ofs1, + constant int32_t & IW, + constant int32_t & IH, + constant int32_t & CHW, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int32_t & d0, + constant int32_t & d1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]); + +template +kernel void kernel_im2col( device const float * x, - device half * dst, + device char * dst, constant int32_t & ofs0, constant int32_t & ofs1, constant int32_t & IW, @@ -1801,14 +3331,98 @@ kernel void kernel_im2col_f16( (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW + (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]); + device T * pdst = (device T *) (dst); + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = 0.0f; + pdst[offset_dst] 
= 0.0f; } else { const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1; - dst[offset_dst] = x[offset_src + iih * IW + iiw]; + pdst[offset_dst] = x[offset_src + iih * IW + iiw]; + } +} + +template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col; +template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col; + +typedef void (im2col_ext_t)( + device const float * x, + device char * dst, + constant int32_t & ofs0, + constant int32_t & ofs1, + constant int32_t & IW, + constant int32_t & IH, + constant int32_t & CHW, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int32_t & d0, + constant int32_t & d1, + constant int32_t & N, + constant int32_t & KH, + constant int32_t & KW, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]); + +template +kernel void kernel_im2col_ext( + device const float * x, + device char * dst, + constant int32_t & ofs0, + constant int32_t & ofs1, + constant int32_t & IW, + constant int32_t & IH, + constant int32_t & CHW, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int32_t & d0, + constant int32_t & d1, + constant int32_t & N, + constant int32_t & KH, + constant int32_t & KW, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] + const int32_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2] + + const int32_t d = tgpig[0] / CHW; + const int32_t chw = tgpig[0] % CHW; + const int32_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) + const int32_t HW = tgpig[0] % KHW; + + const int32_t tpitg_0 = (d * ntg[0]) + tpitg[0]; + if (tpitg_0 >= N) { + return; + } + + const int32_t tpitg_1 = HW / KW; + const int32_t tpitg_2 = HW % KW; + + const int32_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0; + const int32_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1; + + const int32_t offset_dst = + (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW + + (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2); + + device T * pdst = (device T *) (dst); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + pdst[offset_dst] = 0.0f; + } else { + const int32_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1; + pdst[offset_dst] = x[offset_src + iih * IW + iiw]; } } +template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext; +template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext; + kernel void kernel_upscale_f32( device const char * src0, device char * dst, @@ -1828,7 +3442,10 @@ kernel void kernel_upscale_f32( constant uint64_t & nb1, constant uint64_t & nb2, constant uint64_t & nb3, - constant int32_t & sf, + constant float & sf0, + constant float & sf1, + constant float & sf2, + constant float & sf3, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -1837,15 +3454,17 @@ kernel void kernel_upscale_f32( const int64_t i2 = tgpig.y; const int64_t i1 = tgpig.x; - const int64_t i03 = i3; - const int64_t i02 = i2; - const int64_t i01 = i1/sf; - - device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); - device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + 
i1*nb1); + const int64_t i03 = i3/sf3; + const int64_t i02 = i2/sf2; + const int64_t i01 = i1/sf1; for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - dst_ptr[i0] = src0_ptr[i0/sf]; + const int64_t i00 = i0/sf0; + + device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_ptr[0] = src0_ptr[0]; } } @@ -1900,46 +3519,100 @@ kernel void kernel_pad_f32( } } -// bitonic sort implementation following the CUDA kernels as reference -typedef void (argsort_t)( - device const float * x, - device int32_t * dst, - constant int64_t & ncols, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]]); - -template -kernel void kernel_argsort_f32_i32( - device const float * x, - device int32_t * dst, - constant int64_t & ncols, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]]) { - // bitonic sort - int col = tpitg[0]; - int row = tgpig[1]; - - if (col >= ncols) return; +kernel void kernel_arange_f32( + device char * dst, + constant int64_t & ne0, + constant float & start, + constant float & step, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { - device const float * x_row = x + row * ncols; - device int32_t * dst_row = dst + row * ncols; + device float * dst_ptr = (device float *) dst; - // initialize indices - if (col < ncols) { - dst_row[col] = col; + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + dst_ptr[i0] = start + step * i0; + } +} + +kernel void kernel_timestep_embedding_f32( + device const char * src0, + device char * dst, + constant uint64_t & nb1, + constant int & dim, + constant int & max_period, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + int i = tgpig.x; + device float * embed_data = (device float *)(dst + i*nb1); + + int half_ = dim / 2; + for (int j = tpitg.x; j < half_; j += ntg.x) { + float timestep = ((device float *)src0)[i]; + float freq = (float)exp(-log((float)max_period) * j / half_); + float arg = timestep * freq; + embed_data[j ] = cos(arg); + embed_data[j + half_] = sin(arg); + } + + if (dim % 2 != 0 && tpitg.x == 0) { + embed_data[dim] = 0.f; } +} + +// bitonic sort implementation following the CUDA kernels as reference +typedef void (argsort_t)( + device const float * x, + device int32_t * dst, + constant int64_t & ncols, + constant int64_t & ncols_pad, + threadgroup int32_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]); + +template +kernel void kernel_argsort_f32_i32( + device const float * x, + device int32_t * dst, + constant int64_t & ncols, + constant int64_t & ncols_pad, + threadgroup int32_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]) { + // bitonic sort + int col = tpitg[0]; + int row = tgpig[1]; + + if (col >= ncols_pad) return; + + device const float * x_row = x + row * ncols; + threadgroup int32_t * dst_row = shared_values; + + // initialize indices + dst_row[col] = col; + threadgroup_barrier(mem_flags::mem_threadgroup); - for (int k = 2; k <= ncols; k *= 2) { + for (int k = 2; k <= ncols_pad; k *= 2) { for (int j = k / 2; j > 0; j /= 2) { int ixj = col ^ j; if (ixj > col) { if ((col & k) 
== 0) { - if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) { + if (dst_row[col] >= ncols || + (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] > x_row[dst_row[ixj]] : + x_row[dst_row[col]] < x_row[dst_row[ixj]])) + ) { SWAP(dst_row[col], dst_row[ixj]); } } else { - if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) { + if (dst_row[ixj] >= ncols || + (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] < x_row[dst_row[ixj]] : + x_row[dst_row[col]] > x_row[dst_row[ixj]])) + ) { SWAP(dst_row[col], dst_row[ixj]); } } @@ -1947,10 +3620,15 @@ kernel void kernel_argsort_f32_i32( threadgroup_barrier(mem_flags::mem_threadgroup); } } + + // copy the result to dst without the padding + if (col < ncols) { + dst[row * ncols + col] = dst_row[col]; + } } -template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; -template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; +template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; +template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; kernel void kernel_leaky_relu_f32( device const float * src0, @@ -1960,229 +3638,763 @@ kernel void kernel_leaky_relu_f32( dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; } -kernel void kernel_cpy_f16_f16( - device const half * src0, - device half * dst, - constant int64_t & ne00, +typedef void (flash_attn_ext_f16_t)( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, constant int64_t & ne01, constant int64_t & ne02, constant int64_t & ne03, - constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, - constant int64_t & ne0, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb23, + constant uint64_t & nb31, constant int64_t & ne1, constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f16_f32( - device const half * src0, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + constant float & logit_softcap, + threadgroup half * shared, + uint3 
tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]); + +// ref: https://arxiv.org/pdf/2307.08691.pdf +template // head size, queries per threadgroup, cache items per threadgroup +kernel void kernel_flash_attn_ext_f16( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, device float * dst, - constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, constant int64_t & ne03, - constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, - constant int64_t & ne0, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb23, + constant uint64_t & nb31, constant int64_t & ne1, constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + constant float & logit_softcap, + threadgroup half * shared [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const short iq3 = tgpig[2]; + const short iq2 = tgpig[1]; + const short iq1 = tgpig[0]*Q; + + const short D4 = D/4; + const short D8 = D/8; + //const short Q8 = Q/8; + const short NW = N_SIMDWIDTH; + const short SH = (C + Q); // shared memory per simdgroup in (half) + + const short T = D + 2*nsg*SH; // shared memory size per query in (half) + const short TF = T/2; // shared memory size per query in (float) + const short T4 = T/4; // shared memory size per query in (half4) + + threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + simdgroup_half8x8 lo[D8]; + + // load heads from Q to shared memory + for (short j = sgitg; j < Q; j += nsg) { + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + + for (short i = tiisg; i < D4; i += NW) { + if (iq1 + j < ne01) { + sq4[j*T4 + i] = (half4) q4[i]; + } else { + sq4[j*T4 + i] = 0.0h; + } + } + } - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + // zero out lo + for (short i = 0; i < D8; ++i) { + lo[i] = make_filled_simdgroup_matrix(0.0h); + } - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - 
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + // zero out shared memory SH + for (short j = 0; j < Q; ++j) { + for (short i = tiisg; i < SH; i += NW) { + ss[j*TF + i] = 0.0f; + } + } - device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + threadgroup_barrier(mem_flags::mem_threadgroup); - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; - } -} + { + float S[Q] = { [0 ... Q-1] = 0.0h }; + float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 }; -kernel void kernel_cpy_f32_f16( - device const float * src0, - device half * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; + // assume K and V are same shape + const short ne22 = ne12; + const short ne23 = ne13; - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + // broadcast + const short rk2 = ne02/ne12; + const short rk3 = ne03/ne13; - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + const short rv2 = ne02/ne22; + const short rv3 = ne03/ne23; - device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + // k indices + const short ik2 = iq2/rk2; + const short ik3 = iq3/rk3; - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + // v indices + const short iv2 = iq2/rv2; + const short iv3 = iq3/rv3; - dst_data[i00] = src[0]; - } -} + // load the queries from shared memory into local memory + simdgroup_half8x8 mq[D8]; -kernel void kernel_cpy_f32_f32( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; + for (short i = 0; i < D8; ++i) { + simdgroup_load(mq[i], sq + i*8, T); + } - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + // pointer to the mask + device const half * mp = (device const half *) (mask + iq1*nb31); - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = 
(n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + float slope = 1.0f; - device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + // ALiBi + if (max_bias > 0.0f) { + const uint32_t h = iq2; - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + const float base = h < n_head_log2 ? m0 : m1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - dst_data[i00] = src[0]; - } -} + slope = pow(base, exph); + } -kernel void kernel_cpy_f32_q8_0( - device const float * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= ne11) { + break; + } - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + // Q*K^T + { + for (short cc = 0; cc < C/8; ++cc) { + simdgroup_float8x8 mqk = make_filled_simdgroup_matrix(0.h); - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0; + device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); - device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + for (short i = 0; i < D8; ++i) { + simdgroup_half8x8 mk; + simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose - for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); + } - float amax = 0.0f; // absolute max + simdgroup_store(mqk, ss + 8*cc, TF, 0, false); + } + } - for (int j = 0; j < QK8_0; j++) { - const float v = src[j]; - amax = MAX(amax, fabs(v)); - } + // used to detect blocks full of -INF + float smax = -INFINITY; - const float d = amax / ((1 << 7) - 1); - const float id = d ? 
1.0f/d : 0.0f; + // online softmax + { + float ms[Q]; - dst_data[i00/QK8_0].d = d; + for (short j = 0; j < Q; ++j) { + const float m = M[j]; - for (int j = 0; j < QK8_0; ++j) { - const float x0 = src[j]*id; + // scale and apply the logitcap / mask + float s = ss[j*TF + tiisg]*scale; - dst_data[i00/QK8_0].qs[j] = round(x0); + if (logit_softcap != 0.0f) { + s = logit_softcap*precise::tanh(s); + } + + if (mask != q) { + // mqk = mqk + mask*slope + s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; + } + + smax = simd_max(max(smax, s)); + M[j] = simd_max(max(M[j], s)); + + ms[j] = exp(m - M[j]); + const float vs = exp(s - M[j]); + + S[j] = S[j]*ms[j] + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[j*TF + tiisg] = vs; + } + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg < Q) { + ss[tiisg*TF + C + tiisg] = ms[tiisg]; + } + } + + // skip -INF blocks + if (smax == -INFINITY) { + continue; + } + + // O = diag(ms)*O + { + simdgroup_float8x8 mm; + simdgroup_load(mm, ss + C, TF, 0, false); + + for (short i = 0; i < D8; ++i) { + simdgroup_multiply(lo[i], mm, lo[i]); + } + } + + // O = O + (Q*K^T)*V + { + for (short cc = 0; cc < C/8; ++cc) { + device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + + for (short i = 0; i < D8; ++i) { + simdgroup_half8x8 mk; + simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); + + simdgroup_float8x8 mv; + simdgroup_load(mv, ss + 8*cc, TF, 0, false); + + simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]); + } + } + } } - } -} + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + for (short j = 0; j < Q; ++j) { + if (tiisg == 0) { + ss[j*TF + 0] = S[j]; + ss[j*TF + 1] = M[j]; + } + } + } + + // reduce the warps sequentially + for (short sg = 1; sg < nsg; ++sg) { + float S = { 0.0h }; + float M = { -FLT_MAX/2 }; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // each simdgroup stores its output to shared memory, reusing sq + if (sgitg == sg) { + for (short i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // the first simdgroup accumulates the results from the other simdgroups + if (sgitg == 0) { + for (short j = 0; j < Q; ++j) { + const float S0 = ss[j*TF + 0]; + const float S1 = ss[j*TF + sg*SH + 0]; + + const float M0 = ss[j*TF + 1]; + const float M1 = ss[j*TF + sg*SH + 1]; + + M = max(M0, M1); + + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); + + S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[j*TF + 0] = S; + ss[j*TF + 1] = M; + + ss[j*TF + C + j ] = ms0; + ss[j*TF + C + j + sg*SH] = ms1; + } + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + { + simdgroup_half8x8 t; + simdgroup_float8x8 ms0; + simdgroup_float8x8 ms1; + + simdgroup_load(ms0, ss + C, TF, 0, false); + simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false); + + for (short i = 0; i < D8; ++i) { + simdgroup_load (t, sq + i*8, T, 0, false); + simdgroup_multiply(t, ms1, t); + + simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); + } + } + } + } + + // store result to shared memory (reuse sq) + if (sgitg == 0) { + for (short i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); + } + } + + device float4 * dst4 = (device float4 *) dst; + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + for (short j = 0; j < Q && iq1 + j < ne01; ++j) { + const float S = ss[j*TF + 0]; + + for (short i = tiisg; i < D4; i += 
NW) { + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; + } + } + } +} + +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; +template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>; +template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; +//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>; + +template // head size, queries per threadgroup, cache items per threadgroup +kernel void kernel_flash_attn_ext_vec_f16( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb23, + constant uint64_t & nb31, + constant int64_t & ne1, + constant int64_t & ne2, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + constant float & logit_softcap, + threadgroup half * shared [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const short iq3 = tgpig[2]; + const short iq2 = tgpig[1]; + const short iq1 = tgpig[0]; + + const short D4 = D/4; + const short NW = N_SIMDWIDTH; + const short SH = (C + Q); // shared memory per simdgroup in (half) + + const short T = D + 2*nsg*SH; // shared memory size per query in (half) + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + const uint32_t h = iq2; + + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? 
h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4 + threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + half4 lo[D4/NW]; + + // load heads from Q to shared memory + device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); + + for (short i = tiisg; i < D4; i += NW) { + if (iq1 < ne01) { + sq4[i] = (half4) q4[i]; + } else { + sq4[i] = 0.0h; + } + } + + // zero out lo + for (short i = tiisg; i < D4; i += NW) { + lo[i/NW] = 0.0h; + } + + // zero out shared memory SH + for (short i = tiisg; i < SH/4; i += NW) { + ss4[i] = 0.0h; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + float S = { 0.0h }; + float M = { -FLT_MAX/2 }; + + // assume K and V are same shape + const short ne22 = ne12; + const short ne23 = ne13; + + // broadcast + const short rk2 = ne02/ne12; + const short rk3 = ne03/ne13; + + const short rv2 = ne02/ne22; + const short rv3 = ne03/ne23; + + // k indices + const short ik2 = iq2 / rk2; + const short ik3 = iq3 / rk3; + + // v indices + const short iv2 = iq2 / rv2; + const short iv3 = iq3 / rv3; + + // load the queries from shared memory into local memory + float4 mq[D4]; + + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + mq[i] = (float4) sq4[i]; + } + + // pointer to the mask + device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= ne11) { + break; + } + + // Q*K^T + { +#pragma unroll + for (short cc = 0; cc < C/4; ++cc) { + float4 mqk = { 0.0h }; + + device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); + +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + + float4x4 mk; + mk[0] = (float4) pk4[i + 0*(nb11/8)]; + mk[1] = (float4) pk4[i + 1*(nb11/8)]; + mk[2] = (float4) pk4[i + 2*(nb11/8)]; + mk[3] = (float4) pk4[i + 3*(nb11/8)]; + + mqk += (float4) (mq[i] * mk); + } + + // reduce the results from the threads in the simdgroup + mqk += simd_shuffle_down(mqk, 16); + mqk += simd_shuffle_down(mqk, 8); + mqk += simd_shuffle_down(mqk, 4); + mqk += simd_shuffle_down(mqk, 2); + mqk += simd_shuffle_down(mqk, 1); + + // mqk = mqk*scale + mask*slope + if (tiisg == 0) { + mqk *= scale; + + if (logit_softcap != 0.0f) { + mqk = logit_softcap*precise::tanh(mqk); + } + + mqk += (mask != q) ? 
((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f; + + ss4[cc] = mqk; + } + } + } + + // online softmax + { + const short p = tiisg; + + const float m = M; + const float s = ss[p]; + + M = simd_max(max(M, s)); + + const float ms = exp(m - M); + const float vs = exp(s - M); + + S = S*ms + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[p] = vs; + + // O = diag(ms)*O +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + lo[i/NW] *= ms; + } + } + + // O = O + (Q*K^T)*V + { +#pragma unroll + for (short cc = 0; cc < C/4; ++cc) { + device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23)); + +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + + lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; + lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; + lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; + lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3]; + } + } + } + + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + } + + // store results to shared memory + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + sr4[i] = lo[ii/NW]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // parallel reduce + for (short r = nsg/2; r > 0; r >>= 1) { + if (sgitg < r) { + const float S0 = ss[ 0]; + const float S1 = ss[r*SH + 0]; + + const float M0 = ss[ 1]; + const float M1 = ss[r*SH + 1]; + + const float M = max(M0, M1); + + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); + + const float S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + device float4 * dst4 = (device float4 *) dst; + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + const float S = ss[0]; + + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S; + } + } +} + +template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>; +//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; + +template +kernel void kernel_cpy( + device const void * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - 
i1*ne0); + + device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { + device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = (T1) src[0]; + } +} + +typedef decltype(kernel_cpy) kernel_cpy_t; + +template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy; + +kernel void kernel_cpy_f32_q8_0( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0; + + device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = src[j]; + amax = MAX(amax, fabs(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 
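+ // id is the inverse scale: with d = amax/127 each value is encoded as round(x * id), an int8 in [-127, 127]; an all-zero block gives d == 0 and falls back to id = 0 so the quants are written as zero.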
1.0f/d : 0.0f; + + dst_data[i00/QK8_0].d = d; + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = src[j]*id; + + dst_data[i00/QK8_0].qs[j] = round(x0); + } + } +} kernel void kernel_cpy_f32_q4_0( device const float * src0, @@ -2317,13 +4529,249 @@ kernel void kernel_cpy_f32_q4_1( } } -kernel void kernel_concat( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, +kernel void kernel_cpy_f32_q5_0( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0; + + device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK5_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / -16; + const float id = d ? 
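+ // d = max / -16 maps the largest-magnitude value to quant level 0, so (int8_t)(x*id + 16.5f), clamped to 31, is a 5-bit code: the low 4 bits go into the qs nibbles and the high bit is packed into qh below.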
1.0f/d : 0.0f; + + dst_data[i00/QK5_0].d = d; + + uint32_t qh = 0; + for (int j = 0; j < QK5_0/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK5_0/2 + j]*id; + + const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f)); + const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f)); + + dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); + } + thread const uint8_t * qh8 = (thread const uint8_t *)&qh; + for (int j = 0; j < 4; ++j) { + dst_data[i00/QK5_0].qh[j] = qh8[j]; + } + } +} + +kernel void kernel_cpy_f32_q5_1( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1; + + device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float max = src[0]; + float min = src[0]; + + for (int j = 1; j < QK5_1; j++) { + const float v = src[j]; + min = v < min ? v : min; + max = v > max ? v : max; + } + + const float d = (max - min) / 31; + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK5_1].d = d; + dst_data[i00/QK5_1].m = min; + + uint32_t qh = 0; + for (int j = 0; j < QK5_1/2; ++j) { + const float x0 = (src[0 + j] - min)*id; + const float x1 = (src[QK5_1/2 + j] - min)*id; + + const uint8_t xi0 = (uint8_t)(x0 + 0.5f); + const uint8_t xi1 = (uint8_t)(x1 + 0.5f); + + dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2); + } + thread const uint8_t * qh8 = (thread const uint8_t *)&qh; + for (int j = 0; j < 4; ++j) { + dst_data[i00/QK5_1].qh[j] = qh8[j]; + } + } +} + +static inline int best_index_int8(int n, constant float * val, float x) { + if (x <= val[0]) return 0; + if (x >= val[n-1]) return n-1; + int ml = 0, mu = n-1; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (x < val[mav]) mu = mav; else ml = mav; + } + return x - val[mu-1] < val[mu] - x ? 
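+ // after the bisection ml == mu-1 and x lies between val[mu-1] and val[mu]: return whichever of the two neighbouring codebook entries is closer.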
mu-1 : mu; +} + +constexpr constant static float kvalues_iq4nl_f[16] = { + -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f +}; + +kernel void kernel_cpy_f32_iq4_nl( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL; + + device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK4_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / kvalues_iq4nl_f[0]; + const float id = d ? 1.0f/d : 0.0f; + + float sumqx = 0, sumq2 = 0; + for (int j = 0; j < QK4_NL/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK4_NL/2 + j]*id; + + const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0); + const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1); + + dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4); + + const float v0 = kvalues_iq4nl_f[xi0]; + const float v1 = kvalues_iq4nl_f[xi1]; + const float w0 = src[0 + j]*src[0 + j]; + const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j]; + sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j]; + sumq2 += w0*v0*v0 + w1*v1*v1; + + } + + dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d; + + } +} + +kernel void kernel_concat( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, constant int64_t & ne03, constant uint64_t & nb00, constant uint64_t & nb01, @@ -2345,126 +4793,50 @@ kernel void kernel_concat( constant uint64_t & nb1, constant uint64_t & nb2, constant uint64_t & nb3, + constant int32_t & dim, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; + int64_t o[4] = {0, 0, 0, 0}; + o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? 
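+ // o[dim] is src0's extent along the concatenation axis: output positions inside src0's bounds are copied from src0, the rest from src1 with the index along `dim` shifted back by o[dim].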
ne02 : ne03)); - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0; + device const float * x; for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - if (i02 < ne02) { - ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0]; - src0_ptr += ntg.x*nb00; + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); } else { - ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0]; - src1_ptr += ntg.x*nb10; + x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); } - dst_ptr += ntg.x*nb0; - } -} - -//============================================ k-quants ====================================================== - -#ifndef QK_K -#define QK_K 256 -#else -static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64"); -#endif - -#if QK_K == 256 -#define K_SCALE_SIZE 12 -#else -#define K_SCALE_SIZE 4 -#endif - -typedef struct { - uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits - uint8_t qs[QK_K/4]; // quants - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins -} block_q2_K; -// 84 bytes / block - -typedef struct { - uint8_t hmask[QK_K/8]; // quants - high bit - uint8_t qs[QK_K/4]; // quants - low 2 bits -#if QK_K == 64 - uint8_t scales[2]; -#else - uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits -#endif - half d; // super-block scale -} block_q3_K; - -#if QK_K == 64 -typedef struct { - half d[2]; // super-block scales/mins - uint8_t scales[2]; - uint8_t qs[QK_K/2]; // 4-bit quants -} block_q4_K; -#else -typedef struct { - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins - uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits - uint8_t qs[QK_K/2]; // 4--bit quants -} block_q4_K; -#endif - -#if QK_K == 64 -typedef struct { - half d; // super-block scales/mins - int8_t scales[QK_K/16]; // 8-bit block scales - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_K; -#else -typedef struct { - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins - uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_K; -// 176 bytes / block -#endif -typedef struct { - uint8_t ql[QK_K/2]; // quants, lower 4 bits - uint8_t qh[QK_K/4]; // quants, upper 2 bits - int8_t scales[QK_K/16]; // scales, quantized with 8 bits - half d; // super-block scale -} block_q6_K; -// 210 bytes / block + device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); -//====================================== dot products ========================= + *y = *x; + } +} void kernel_mul_mv_q2_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + 
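+ // note: the mul-mv implementation helpers now take plain values plus a threadgroup scratch pointer (shared_values); the k-quant wrappers below pass nullptr, while the iq*-quant kernels stage their grid/sign lookup tables in it.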
int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; const int r0 = tgpig.x; @@ -2487,7 +4859,6 @@ void kernel_mul_mv_q2_K_f32_impl( const int step = sizeof(block_q2_K) * nb; -#if QK_K == 256 const int ix = tiisg/8; // 0...3 const int it = tiisg%8; // 0...7 const int iq = it/4; // 0 or 1 @@ -2539,65 +4910,14 @@ void kernel_mul_mv_q2_K_f32_impl( y4 += 4 * QK_K; } -#else - const int ix = tiisg/2; // 0...15 - const int it = tiisg%2; // 0...1 - device const float * y4 = y + ix * QK_K + 8 * it; - - for (int ib = ix; ib < nb; ib += 16) { - - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; - yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8]; - yl[i+16] = y4[i+32]; sumy[2] += yl[i+16]; - yl[i+24] = y4[i+48]; sumy[3] += yl[i+24]; - } - - device const uint8_t * sc = (device const uint8_t *)x[ib].scales; - device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; - device const half * dh = &x[ib].d; - - for (int row = 0; row < N_DST; row++) { - - float4 acc1 = {0.f, 0.f, 0.f, 0.f}; - float4 acc2 = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); - acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); - acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); - acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); - acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); - acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); - acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); - acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); - } - - float dall = dh[0]; - float dmin = dh[1]; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + - (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f + - (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f + - (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) - - dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4)); - - qs += step/2; - sc += step; - dh += step/2; - } - - y4 += 16 * QK_K; - } -#endif - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} [[host_name("kernel_mul_mv_q2_K_f32")]] kernel void kernel_mul_mv_q2_K_f32( @@ -2624,26 +4944,26 @@ kernel void kernel_mul_mv_q2_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); + kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } -#if QK_K == 256 void kernel_mul_mv_q3_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t 
ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; @@ -2785,83 +5105,6 @@ void kernel_mul_mv_q3_K_f32_impl( } } } -#else -void kernel_mul_mv_q3_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; - - const int row = 2 * r0 + sgitg; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - const int ix = tiisg/4; - const int il = 4 * (tiisg%4);// 0, 4, 8, 12 - const int iq = il/8; // 0, 0, 1, 1 - const int in = il%8; // 0, 4, 0, 4 - - float2 sum = {0.f, 0.f}; - - for (int i = ix; i < nb; i += 8) { - - const float d_all = (float)(x[i].d); - - device const uint16_t * q = (device const uint16_t *)(x[i].qs + il); - device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in); - device const uint16_t * s = (device const uint16_t *)(x[i].scales); - device const float * y = yy + i * QK_K + il; - - const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8); - const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f; - const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f; - const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f; - - for (int l = 0; l < 4; l += 2) { - const uint16_t hm = h[l/2] >> iq; - sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4)) - + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16)) - + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64)) - + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256)); - sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024)) - + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096)) - + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384)) - + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 
0 : 65536)); - } - - } - const float sumf = sum[0] + sum[1] * 1.f/256.f; - - const float tot = simd_sum(sumf); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + row] = tot; - } - -} -#endif [[host_name("kernel_mul_mv_q3_K_f32")]] kernel void kernel_mul_mv_q3_K_f32( @@ -2888,26 +5131,26 @@ kernel void kernel_mul_mv_q3_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); + kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } -#if QK_K == 256 void kernel_mul_mv_q4_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; @@ -3004,102 +5247,6 @@ void kernel_mul_mv_q4_K_f32_impl( } } } -#else -void kernel_mul_mv_q4_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const int ix = tiisg/4; // 0...7 - const int it = tiisg%4; // 0...3 - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - const int first_row = r0 * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[8]; - float yh[8]; - float sumf[N_DST]={0.f}, all_sum; - - const int step = sizeof(block_q4_K) * nb / 2; - - device const float * y4 = y + ix * QK_K + 8 * it; - - uint16_t sc16[4]; - - for (int ib = ix; ib < nb; ib += 8) { - - float2 sumy = {0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i] = y4[i+ 0]; sumy[0] += yl[i]; - yh[i] = y4[i+32]; sumy[1] += yh[i]; - } - - device const uint16_t * sc = (device const uint16_t *)x[ib].scales; - device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; - device const half * dh = x[ib].d; - - for (int row = 0; row < N_DST; row++) { - - sc16[0] = sc[0] & 0x000f; - sc16[1] = sc[0] & 0x0f00; - sc16[2] = sc[0] & 0x00f0; - sc16[3] = sc[0] & 0xf000; - - float2 acc1 = {0.f, 0.f}; - float2 acc2 = {0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+0] * (qs[i/2] & 0x000F); - acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00); - acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0); - acc2[1] += yh[i+1] * (qs[i/2] & 0xF000); - } - - float dall = dh[0]; - 
float dmin = dh[1]; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] + - (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) - - dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f); - - qs += step; - sc += step; - dh += step; - } - - y4 += 8 * QK_K; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} -#endif [[host_name("kernel_mul_mv_q4_K_f32")]] kernel void kernel_mul_mv_q4_K_f32( @@ -3126,25 +5273,26 @@ kernel void kernel_mul_mv_q4_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); + kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q5_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; @@ -3166,8 +5314,6 @@ void kernel_mul_mv_q5_K_f32_impl( const int step = sizeof(block_q5_K) * nb; -#if QK_K == 256 -# float yl[16], yh[16]; const uint16_t kmask1 = 0x3f3f; @@ -3250,54 +5396,6 @@ void kernel_mul_mv_q5_K_f32_impl( y1 += 4 * QK_K; } -#else - float yl[8], yh[8]; - - const int il = 4 * (tiisg/8); // 0, 4, 8, 12 - const int ix = tiisg%8; - const int iq = il/8; // 0, 0, 1, 1 - const int in = il%8; // 0, 4, 0, 4 - - device const float * y = yy + ix*QK_K + il; - - for (int i = ix; i < nb; i += 8) { - - for (int l = 0; l < 4; ++l) { - yl[l+0] = y[l+ 0]; - yl[l+4] = y[l+16]; - yh[l+0] = y[l+32]; - yh[l+4] = y[l+48]; - } - - device const half * dh = &x[i].d; - device const uint8_t * q = x[i].qs + il; - device const uint8_t * h = x[i].qh + in; - device const int8_t * s = x[i].scales; - - for (int row = 0; row < 2; ++row) { - - const float d = dh[0]; - - float2 acc = {0.f, 0.f}; - for (int l = 0; l < 4; ++l) { - const uint8_t hl = h[l] >> iq; - acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16)) - + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16)); - acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256)) - + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 
0 : 256)); - } - sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]); - - q += step; - h += step; - s += step; - dh += step/2; - - } - - y += 8 * QK_K; - } -#endif for (int row = 0; row < 2; ++row) { const float tot = simd_sum(sumf[row]); @@ -3332,27 +5430,28 @@ kernel void kernel_mul_mv_q5_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); + kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q6_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const uint8_t kmask1 = 0x03; + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const uint8_t kmask1 = 0x03; const uint8_t kmask2 = 0x0C; const uint8_t kmask3 = 0x30; const uint8_t kmask4 = 0xC0; @@ -3375,7 +5474,6 @@ void kernel_mul_mv_q6_K_f32_impl( float sumf = 0; -#if QK_K == 256 const int tid = tiisg/2; const int ix = tiisg%2; const int ip = tid/8; // 0 or 1 @@ -3411,30 +5509,6 @@ void kernel_mul_mv_q6_K_f32_impl( } -#else - const int ix = tiisg/4; - const int il = 4*(tiisg%4); - - for (int i = ix; i < nb; i += 8) { - device const float * y = yy + i * QK_K + il; - device const uint8_t * ql = x[i].ql + il; - device const uint8_t * qh = x[i].qh + il; - device const int8_t * s = x[i].scales; - - const float d = x[i].d; - - float4 sums = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < 4; ++l) { - sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); - sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); - sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32); - sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); - } - sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]); - } - -#endif - const float tot = simd_sum(sumf); if (tiisg == 0) { dst[r1*ne0 + im*ne0*ne1 + row] = tot; @@ -3466,640 +5540,1708 @@ kernel void kernel_mul_mv_q6_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); + kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } -//============================= templates and their specializations ============================= +// ======================= "True" 2-bit -// NOTE: this is not dequantizing - we are simply fitting the template -template -void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { - float4x4 temp = *(((device float4x4 *)src)); - for (int i = 0; i < 16; i++){ - reg[i/4][i%4] = temp[i/4][i%4]; - } -} +void kernel_mul_mv_iq2_xxs_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + 
int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { -template -void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { - half4x4 temp = *(((device half4x4 *)src)); - for (int i = 0; i < 16; i++){ - reg[i/4][i%4] = temp[i/4][i%4]; - } -} + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; -template -void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 1); - const float d1 = il ? (xb->d / 16.h) : xb->d; - const float d2 = d1 / 256.f; - const float md = -8.h * xb->d; - const ushort mask0 = il ? 0x00F0 : 0x000F; - const ushort mask1 = mask0 << 8; + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; - for (int i=0;i<8;i++) { - reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md; - reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md; - } -} + const uint i12 = im%ne12; + const uint i13 = im/ne12; -template -void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 2); - const float d1 = il ? (xb->d / 16.h) : xb->d; - const float d2 = d1 / 256.f; - const float m = xb->m; - const ushort mask0 = il ? 0x00F0 : 0x000F; - const ushort mask1 = mask0 << 8; + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - for (int i=0;i<8;i++) { - reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m; - reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m; - } -} + device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; -template -void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 3); - const float d = xb->d; - const float md = -16.h * xb->d; - const ushort mask = il ? 0x00F0 : 0x000F; + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; - const uint32_t qh = *((device const uint32_t *)xb->qh); + const int nb32 = nb * (QK_K / 32); - const int x_mv = il ? 4 : 0; + threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; + threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); + { + int nval = 4; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i]; + nval = 2; + pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } - const int gh_mv = il ? 12 : 0; - const int gh_bk = il ? 
0 : 4; + const int ix = tiisg; - for (int i = 0; i < 8; i++) { - // extract the 5-th bits for x0 and x1 - const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; - const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + device const float * y4 = y + 32 * ix; - // combine the 4-bits from qs with the 5th bit - const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); - const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - reg[i/2][2*(i%2)+0] = d * x0 + md; - reg[i/2][2*(i%2)+1] = d * x1 + md; - } -} + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } -template -void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 4); - const float d = xb->d; - const float m = xb->m; - const ushort mask = il ? 0x00F0 : 0x000F; + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); - const uint32_t qh = *((device const uint32_t *)xb->qh); + device const block_iq2_xxs * xr = x + ibl; + device const uint16_t * q2 = xr->qs + 4 * ib; + device const half * dh = &xr->d; - const int x_mv = il ? 4 : 0; + for (int row = 0; row < N_DST; row++) { - const int gh_mv = il ? 12 : 0; - const int gh_bk = il ? 0 : 4; + const float db = dh[0]; + device const uint8_t * aux8 = (device const uint8_t *)q2; + const uint32_t aux32 = q2[2] | (q2[3] << 16); + const float d = db * (0.5f + (aux32 >> 28)); - for (int i = 0; i < 8; i++) { - // extract the 5-th bits for x0 and x1 - const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; - const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + float sum = 0; + for (int l = 0; l < 4; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]); + const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; + for (int j = 0; j < 8; ++j) { + sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? 
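+ // aux8[l] selects an 8-entry row of the iq2_xxs codebook (staged in threadgroup memory above) and the 7-bit field of aux32 selects the sign pattern applied per element.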
-1.f : 1.f); + } + } + sumf[row] += d * sum; - // combine the 4-bits from qs with the 5th bit - const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); - const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + dh += nb*sizeof(block_iq2_xxs)/2; + q2 += nb*sizeof(block_iq2_xxs)/2; + } - reg[i/2][2*(i%2)+0] = d * x0 + m; - reg[i/2][2*(i%2)+1] = d * x1 + m; + y4 += 32 * 32; } -} -template -void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { - device const int8_t * qs = ((device const int8_t *)xb->qs); - const half d = xb->d; - - for (int i = 0; i < 16; i++) { - reg[i/4][i%4] = (qs[i + 16*il] * d); + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + } } } -template -void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { - const float d = xb->d; - const float min = xb->dmin; - device const uint8_t * q = (device const uint8_t *)xb->qs; - float dl, ml; - uint8_t sc = xb->scales[il]; +[[host_name("kernel_mul_mv_iq2_xxs_f32")]] +kernel void kernel_mul_mv_iq2_xxs_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { -#if QK_K == 256 - q = q + 32*(il/8) + 16*(il&1); - il = (il/2)%4; -#endif - half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); - uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4); - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * (q[i] & mask) - ml; - } + kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } -template -void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) { - const half d_all = xb->d; - device const uint8_t * q = (device const uint8_t *)xb->qs; - device const uint8_t * h = (device const uint8_t *)xb->hmask; - device const int8_t * scales = (device const int8_t *)xb->scales; +void kernel_mul_mv_iq2_xs_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { -#if QK_K == 256 - q = q + 32 * (il/8) + 16 * (il&1); - h = h + 16 * (il&1); - uint8_t m = 1 << (il/2); - uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \ - ((il/4)>0 ? 12 : 3); - uint16_t kmask2 = il/8 ? 0xF0 : 0x0F; - uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4]; - int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) - : (scale_2&kmask2) | ((scale_1&kmask1) << 4); - half dl = il<8 ? 
d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h); - const half ml = 4.h * dl; + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; - il = (il/2) & 3; - const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); - const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - dl *= coef; + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml); - } -#else - float kcoef = il&1 ? 1.f/16.f : 1.f; - uint16_t kmask = il&1 ? 0xF0 : 0x0F; - float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8); - float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); - uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - uint8_t m = 1<<(il*2); - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef)); + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; + threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512); + { + int nval = 8; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i]; + nval = 2; + pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + threadgroup_barrier(mem_flags::mem_threadgroup); } -#endif -} -static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) { - return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)} - : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))}; -} + const int ix = tiisg; -template -void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) { - device const uchar * q = xb->qs; + device const float * y4 = y + 32 * ix; -#if QK_K == 256 - short is = (il/4) * 2; - q = q + (il/4) * 32 + 16 * (il&1); - il = il & 3; - const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); - const float d = il < 2 ? xb->d : xb->d / 16.h; - const float min = xb->dmin; - const float dl = d * sc[0]; - const float ml = min * sc[1]; -#else - q = q + 16 * (il&1); - device const uint8_t * s = xb->scales; - device const half2 * dh = (device const half2 *)xb->d; - const float2 d = (float2)dh[0]; - const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h; - const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4); -#endif - const ushort mask = il<2 ? 0x0F : 0xF0; - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * (q[i] & mask) - ml; - } -} + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { -template -void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) { - device const uint8_t * q = xb->qs; - device const uint8_t * qh = xb->qh; + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } -#if QK_K == 256 - short is = (il/4) * 2; - q = q + 32 * (il/4) + 16 * (il&1); - qh = qh + 16 * (il&1); - uint8_t ul = 1 << (il/2); - il = il & 3; - const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); - const float d = il < 2 ? 
xb->d : xb->d / 16.h; - const float min = xb->dmin; - const float dl = d * sc[0]; - const float ml = min * sc[1]; + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); - const ushort mask = il<2 ? 0x0F : 0xF0; - const float qh_val = il<2 ? 16.f : 256.f; - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml; - } -#else - q = q + 16 * (il&1); - device const int8_t * s = xb->scales; - const float dl = xb->d * s[il]; - uint8_t m = 1<<(il*2); - const float coef = il<2 ? 1.f : 1.f/16.f; - const ushort mask = il<2 ? 0x0F : 0xF0; - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef)); - } -#endif -} + device const block_iq2_xs * xr = x + ibl; + device const uint16_t * q2 = xr->qs + 4 * ib; + device const uint8_t * sc = xr->scales + ib; + device const half * dh = &xr->d; -template -void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) { - const half d_all = xb->d; - device const uint8_t * ql = (device const uint8_t *)xb->ql; - device const uint8_t * qh = (device const uint8_t *)xb->qh; - device const int8_t * scales = (device const int8_t *)xb->scales; + for (int row = 0; row < N_DST; row++) { -#if QK_K == 256 - ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1); - qh = qh + 32*(il/8) + 16*(il&1); - half sc = scales[(il%2) + 2 * ((il/2))]; - il = (il/2) & 3; -#else - ql = ql + 16 * (il&1); - half sc = scales[il]; -#endif - const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F; - const half coef = il>1 ? 1.f/16.h : 1.h; - const half ml = d_all * sc * 32.h; - const half dl = d_all * sc * coef; - for (int i = 0; i < 16; ++i) { - const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2)) - : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4)); - reg[i/4][i%4] = dl * q - ml; + const float db = dh[0]; + const uint8_t ls1 = sc[0] & 0xf; + const uint8_t ls2 = sc[0] >> 4; + const float d1 = db * (0.5f + ls1); + const float d2 = db * (0.5f + ls2); + + float sum1 = 0, sum2 = 0; + for (int l = 0; l < 2; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); + const uint8_t signs = shared_signs[(q2[l] >> 9)]; + for (int j = 0; j < 8; ++j) { + sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + } + for (int l = 2; l < 4; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); + const uint8_t signs = shared_signs[(q2[l] >> 9)]; + for (int j = 0; j < 8; ++j) { + sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? 
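+ // unlike iq2_xxs, each q2[l] carries its own 9-bit grid index and 7-bit sign index; sum2 covers the second half of the 32-value block and is scaled by the high-nibble d2.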
-1.f : 1.f); + } + } + sumf[row] += d1 * sum1 + d2 * sum2; + + dh += nb*sizeof(block_iq2_xs)/2; + q2 += nb*sizeof(block_iq2_xs)/2; + sc += nb*sizeof(block_iq2_xs); + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + } } } -template -kernel void kernel_get_rows( +[[host_name("kernel_mul_mv_iq2_xs_f32")]] +kernel void kernel_mul_mv_iq2_xs_f32( device const void * src0, - device const char * src1, + device const float * src1, device float * dst, constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - //const int64_t i = tgpig; - //const int64_t r = ((device int32_t *) src1)[i]; + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; + kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_iq3_xxs_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { - const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; - const int64_t i02 = i11; + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; - for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) { - float4x4 temp; - dequantize_func( - ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp); - *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq3_xxs * x = (device const block_iq3_xxs *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; + threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); + { + int nval = 4; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i]; + nval = 2; + pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + threadgroup_barrier(mem_flags::mem_threadgroup); } -} -kernel void kernel_get_rows_f32( - device const void * src0, - device const char * src1, - device float 
* dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; + const int ix = tiisg; - const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + device const float * y4 = y + 32 * ix; - const int64_t i02 = i11; + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - for (int ind = tiitg; ind < ne00; ind += tptg.x) { - ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = - ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq3_xxs * xr = x + ibl; + device const uint8_t * q3 = xr->qs + 8 * ib; + device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + const uint32_t aux32 = gas[0] | (gas[1] << 16); + const float d = db * (0.5f + (aux32 >> 28)); + + float2 sum = {0}; + for (int l = 0; l < 4; ++l) { + const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]); + const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]); + const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; + for (int j = 0; j < 4; ++j) { + sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); + sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? 
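+ // q3[2*l] and q3[2*l+1] each index 4 magnitudes of the iq3_xxs grid; the gas word supplies one 7-bit sign index per group of 8 values and its top 4 bits set the block scale.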
-1.f : 1.f); + } + } + sumf[row] += d * (sum[0] + sum[1]); + + dh += nb*sizeof(block_iq3_xxs)/2; + q3 += nb*sizeof(block_iq3_xxs); + gas += nb*sizeof(block_iq3_xxs)/2; + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f; + } } } -kernel void kernel_get_rows_f16( +[[host_name("kernel_mul_mv_iq3_xxs_f32")]] +kernel void kernel_mul_mv_iq3_xxs_f32( device const void * src0, - device const char * src1, + device const float * src1, device float * dst, constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} - const int64_t i02 = i11; +void kernel_mul_mv_iq3_s_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { - for (int ind = tiitg; ind < ne00; ind += tptg.x) { - ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = - ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; - } -} + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; -#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A -#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B -#define BLOCK_SIZE_K 32 -#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A -#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B -#define THREAD_PER_BLOCK 128 -#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers -#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers -#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8 -#define SG_MAT_ROW 8 + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; -// each block_q contains 16*nl weights -template -void kernel_mul_mm_impl(device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 
tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + const uint i12 = im%ne12; + const uint i13 = im/ne12; - threadgroup half * sa = (threadgroup half *)(shared_memory); - threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - const uint r0 = tgpig.y; - const uint r1 = tgpig.x; - const uint im = tgpig.z; + device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - // if this block is of 64x32 shape or smaller - short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; - short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; - // a thread shouldn't load data outside of the matrix - short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; - short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; + const int nb32 = nb * (QK_K / 32); - simdgroup_half8x8 ma[4]; - simdgroup_float8x8 mb[2]; - simdgroup_float8x8 c_res[8]; - for (int i = 0; i < 8; i++){ - c_res[i] = make_filled_simdgroup_matrix(0.f); + threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; + { + int nval = 8; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i]; + threadgroup_barrier(mem_flags::mem_threadgroup); } - short il = (tiitg % THREAD_PER_ROW); + const int ix = tiisg; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + device const float * y4 = y + 32 * ix; - uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02); - ushort offset1 = il/nl; + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; - device const float * y = (device const float *)(src1 - + nb12 * im - + nb11 * (r1 * BLOCK_SIZE_N + thread_col) - + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } - for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { - // load data and store to threadgroup memory - half4x4 temp_a; - dequantize_func(x, il, temp_a); - threadgroup_barrier(mem_flags::mem_threadgroup); + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); - #pragma unroll(16) - for (int i = 0; i < 16; i++) { - *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ - + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ - + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; - } - - *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); - - il = (il + 2 < nl) ? il + 2 : il % 2; - x = (il < 2) ? 
x + (2+nl-1)/nl : x; - y += BLOCK_SIZE_K; + device const block_iq3_s * xr = x + ibl; + device const uint8_t * qs = xr->qs + 8 * ib; + device const uint8_t * qh = xr->qh + ib; + device const uint8_t * sc = xr->scales + (ib/2); + device const uint8_t * signs = xr->signs + 4 * ib; + device const half * dh = &xr->d; - threadgroup_barrier(mem_flags::mem_threadgroup); + for (int row = 0; row < N_DST; row++) { - // load matrices from threadgroup memory and conduct outer products - threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); - threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); + const float db = dh[0]; + const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf)); - #pragma unroll(4) - for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { - #pragma unroll(4) - for (int i = 0; i < 4; i++) { - simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i); - } - simdgroup_barrier(mem_flags::mem_none); - #pragma unroll(2) - for (int i = 0; i < 2; i++) { - simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i); + float2 sum = {0}; + for (int l = 0; l < 4; ++l) { + const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values; + const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values; + const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]); + const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]); + for (int j = 0; j < 4; ++j) { + sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); + sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); + } } + sumf[row] += d * (sum[0] + sum[1]); - lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; - lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; - - #pragma unroll(8) - for (int i = 0; i < 8; i++){ - simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); - } + dh += nb*sizeof(block_iq3_s)/2; + qs += nb*sizeof(block_iq3_s); + qh += nb*sizeof(block_iq3_s); + sc += nb*sizeof(block_iq3_s); + signs += nb*sizeof(block_iq3_s); } + + y4 += 32 * 32; } - if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { - device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \ - + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0; - for (int i = 0; i < 8; i++) { - simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); - } - } else { - // block is smaller than 64x32, we should avoid writing data outside of the matrix - threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ - + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; - for (int i = 0; i < 8; i++) { - simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; } + } +} - threadgroup_barrier(mem_flags::mem_threadgroup); +[[host_name("kernel_mul_mv_iq3_s_f32")]] +kernel void kernel_mul_mv_iq3_s_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + 
constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { - device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0; - if (sgitg == 0) { - for (int i = 0; i < n_rows; i++) { - for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); - } - } - } - } + kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } -// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids -template -void kernel_mul_mm_id_impl( - device const uchar * src0, - device const uchar * src1, - thread short * src1ids, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - int64_t ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar * shared_memory, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { +void kernel_mul_mv_iq2_s_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { - threadgroup half * sa = (threadgroup half *)(shared_memory); - threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; - const uint r0 = tgpig.y; - const uint r1 = tgpig.x; - const uint im = tgpig.z; + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; - if (r1 * BLOCK_SIZE_N >= ne1) return; + const uint i12 = im%ne12; + const uint i13 = im/ne12; - // if this block is of 64x32 shape or smaller - short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; - short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - // a thread shouldn't load data outside of the matrix - short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; - short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? 
((short)tiitg/THREAD_PER_COL) : n_cols - 1; + device const block_iq2_s * x = (device const block_iq2_s *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - simdgroup_half8x8 ma[4]; - simdgroup_float8x8 mb[2]; - simdgroup_float8x8 c_res[8]; - for (int i = 0; i < 8; i++){ - c_res[i] = make_filled_simdgroup_matrix(0.f); - } + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; - short il = (tiitg % THREAD_PER_ROW); + const int nb32 = nb * (QK_K / 32); - const uint i12 = im%ne12; - const uint i13 = im/ne12; + //threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; + //{ + // int nval = 32; + // int pos = (32*sgitg + tiisg)*nval; + // for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i]; + // threadgroup_barrier(mem_flags::mem_threadgroup); + //} - uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02); - ushort offset1 = il/nl; + const int ix = tiisg; - device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; - device const float * y = (device const float *)(src1 - + nb12 * im - + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col] - + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + device const float * y4 = y + 32 * ix; - for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { - // load data and store to threadgroup memory - half4x4 temp_a; - dequantize_func(x, il, temp_a); - threadgroup_barrier(mem_flags::mem_threadgroup); + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - for (int i = 0; i < 16; i++) { - *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ - + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ - + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; } - *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); - - il = (il + 2 < nl) ? il + 2 : il % 2; - x = (il < 2) ? 
x + (2+nl-1)/nl : x; - y += BLOCK_SIZE_K; + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); - threadgroup_barrier(mem_flags::mem_threadgroup); + device const block_iq2_s * xr = x + ibl; + device const uint8_t * qs = xr->qs + 4 * ib; + device const uint8_t * qh = xr->qh + ib; + device const uint8_t * sc = xr->scales + ib; + device const uint8_t * signs = qs + QK_K/8; + device const half * dh = &xr->d; - // load matrices from threadgroup memory and conduct outer products - threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); - threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); + for (int row = 0; row < N_DST; row++) { - for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { - for (int i = 0; i < 4; i++) { - simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i); - } - simdgroup_barrier(mem_flags::mem_none); - for (int i = 0; i < 2; i++) { - simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i); + const float db = dh[0]; + const float d1 = db * (0.5f + (sc[0] & 0xf)); + const float d2 = db * (0.5f + (sc[0] >> 4)); + + float2 sum = {0}; + for (int l = 0; l < 2; ++l) { + //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); + //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); + constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); + constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); + for (int j = 0; j < 8; ++j) { + sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]); + sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]); + } } + sumf[row] += d1 * sum[0] + d2 * sum[1]; - lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; - lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; + dh += nb*sizeof(block_iq2_s)/2; + qs += nb*sizeof(block_iq2_s); + qh += nb*sizeof(block_iq2_s); + sc += nb*sizeof(block_iq2_s); + signs += nb*sizeof(block_iq2_s); + } - for (int i = 0; i < 8; i++){ - simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); - } + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; } } +} - { - threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ - + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; - for (int i = 0; i < 8; i++) { - simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); +[[host_name("kernel_mul_mv_iq2_s_f32")]] +kernel void kernel_mul_mv_iq2_s_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, 
ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_iq1_s_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_value, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + float sumy = 0; + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + sumy += yl[i]; } - threadgroup_barrier(mem_flags::mem_threadgroup); + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); - device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0; - if (sgitg == 0) { - for (int i = 0; i < n_rows; i++) { - for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); - } + device const block_iq1_s * xr = x + ibl; + device const uint8_t * qs = xr->qs + 4 * ib; + device const uint16_t * qh = xr->qh + ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700))); + constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700))); + constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700))); + + float sum = 0; + for (int j = 0; j < 4; ++j) { + sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4) + + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) + + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4); } + sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? 
-1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1); + + dh += nb*sizeof(block_iq1_s)/2; + qs += nb*sizeof(block_iq1_s); + qh += nb*sizeof(block_iq1_s)/2; } + + y4 += 32 * 32; } -} -template + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +void kernel_mul_mv_iq1_m_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_value, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + iq1m_scale_t scale; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + float4 sumy = {0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+16]; sumy[2] += yl[i+16]; + yl[i+24] = y4[i+24]; sumy[3] += yl[i+24]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq1_m * xr = x + ibl; + device const uint8_t * qs = xr->qs + 4 * ib; + device const uint8_t * qh = xr->qh + 2 * ib; + device const uint16_t * sc = (device const uint16_t *)xr->scales; + + for (int row = 0; row < N_DST; row++) { + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700))); + constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700))); + constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700))); + + float2 sum = {0.f}; + for (int j = 0; j < 4; ++j) { + sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4); + sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) + + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4); + } + const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? 
-1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + + sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) + + (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1)); + + sc += nb*sizeof(block_iq1_m)/2; + qs += nb*sizeof(block_iq1_m); + qh += nb*sizeof(block_iq1_m); + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +void kernel_mul_mv_iq4_nl_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values_i8, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + threadgroup float * shared_values = (threadgroup float *)shared_values_i8; + const int nb = ne00/QK4_NL; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * 2 + sgitg) * 2; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + const int ix = tiisg/2; // 0...15 + const int it = tiisg%2; // 0 or 1 + + shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float4 yl[4]; + float sumf[2]={0.f}, all_sum; + + device const float * yb = y + ix * QK4_NL + it * 8; + + uint32_t aux32[2]; + thread const uint8_t * q8 = (thread const uint8_t *)aux32; + + float4 qf1, qf2; + + for (int ib = ix; ib < nb; ib += 16) { + + device const float4 * y4 = (device const float4 *)yb; + yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + + for (int row = 0; row < 2 && first_row + row < ne01; ++row) { + + device const block_iq4_nl & xb = x[row*nb + ib]; + device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); + + float4 acc1 = {0.f}, acc2 = {0.f}; + + aux32[0] = q4[0] | (q4[1] << 16); + aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; + aux32[0] &= 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[0] * qf1; + acc2 += yl[1] * qf2; + + aux32[0] = q4[2] | (q4[3] << 16); + aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; + aux32[0] &= 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[2] * qf1; + acc2 += yl[3] * qf2; + + acc1 += acc2; + + sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); + + } + + yb += 16 * QK4_NL; + } + + for (int row = 0; row < 2 && first_row + row < ne01; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +void kernel_mul_mv_iq4_xs_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values_i8, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + threadgroup float * 
shared_values = (threadgroup float *)shared_values_i8; + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * 2 + sgitg) * 2; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + const int ix = tiisg/16; // 0 or 1 + const int it = tiisg%16; // 0...15 + const int ib = it/2; + const int il = it%2; + + shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float4 yl[4]; + float sumf[2]={0.f}, all_sum; + + device const float * yb = y + ix * QK_K + ib * 32 + il * 8; + + uint32_t aux32[2]; + thread const uint8_t * q8 = (thread const uint8_t *)aux32; + + float4 qf1, qf2; + + for (int ibl = ix; ibl < nb; ibl += 2) { + + device const float4 * y4 = (device const float4 *)yb; + yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + + for (int row = 0; row < 2; ++row) { + + device const block_iq4_xs & xb = x[row*nb + ibl]; + device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); + + float4 acc1 = {0.f}, acc2 = {0.f}; + + aux32[0] = q4[0] & 0x0f0f0f0f; + aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[0] * qf1; + acc2 += yl[1] * qf2; + + aux32[0] = q4[1] & 0x0f0f0f0f; + aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[2] * qf1; + acc2 += yl[3] * qf2; + + acc1 += acc2; + + const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32; + sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); + + } + + yb += 2 * QK_K; + } + + for (int row = 0; row < 2; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_iq1_s_f32")]] +kernel void kernel_mul_mv_iq1_s_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + +[[host_name("kernel_mul_mv_iq1_m_f32")]] +kernel void kernel_mul_mv_iq1_m_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant 
uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + +[[host_name("kernel_mul_mv_iq4_nl_f32")]] +kernel void kernel_mul_mv_iq4_nl_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +[[host_name("kernel_mul_mv_iq4_xs_f32")]] +kernel void kernel_mul_mv_iq4_xs_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +//============================= templates and their specializations ============================= + +// NOTE: this is not dequantizing - we are simply fitting the template +template +void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { + float4x4 temp = *(((device float4x4 *)src)); + for (int i = 0; i < 16; i++){ + reg[i/4][i%4] = temp[i/4][i%4]; + } +} + +template +void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { + half4x4 temp = *(((device half4x4 *)src)); + for (int i = 0; i < 16; i++){ + reg[i/4][i%4] = temp[i/4][i%4]; + } +} + +template +void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) { + reg = (type4x4)(*src); +} + +template +void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 1); + const float d1 = il ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float md = -8.h * xb->d; + const ushort mask0 = il ? 
0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + for (int i=0;i<8;i++) { + reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md; + reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md; + } +} + +template +void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 2); + const float d1 = il ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float m = xb->m; + const ushort mask0 = il ? 0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + for (int i=0;i<8;i++) { + reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m; + reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m; + } +} + +template +void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 3); + const float d = xb->d; + const float md = -16.h * xb->d; + const ushort mask = il ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = il ? 4 : 0; + + const int gh_mv = il ? 12 : 0; + const int gh_bk = il ? 0 : 4; + + for (int i = 0; i < 8; i++) { + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[i/2][2*(i%2)+0] = d * x0 + md; + reg[i/2][2*(i%2)+1] = d * x1 + md; + } +} + +template +void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 4); + const float d = xb->d; + const float m = xb->m; + const ushort mask = il ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = il ? 4 : 0; + + const int gh_mv = il ? 12 : 0; + const int gh_bk = il ? 0 : 4; + + for (int i = 0; i < 8; i++) { + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[i/2][2*(i%2)+0] = d * x0 + m; + reg[i/2][2*(i%2)+1] = d * x1 + m; + } +} + +template +void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { + device const int8_t * qs = ((device const int8_t *)xb->qs); + const half d = xb->d; + + for (int i = 0; i < 16; i++) { + reg[i/4][i%4] = (qs[i + 16*il] * d); + } +} + +template +void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { + const float d = xb->d; + const float min = xb->dmin; + device const uint8_t * q = (device const uint8_t *)xb->qs; + float dl, ml; + uint8_t sc = xb->scales[il]; + + q = q + 32*(il/8) + 16*(il&1); + il = (il/2)%4; + + half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 
12 : 3); + dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4); + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - ml; + } +} + +template +void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) { + const half d_all = xb->d; + device const uint8_t * q = (device const uint8_t *)xb->qs; + device const uint8_t * h = (device const uint8_t *)xb->hmask; + device const int8_t * scales = (device const int8_t *)xb->scales; + + q = q + 32 * (il/8) + 16 * (il&1); + h = h + 16 * (il&1); + uint8_t m = 1 << (il/2); + uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \ + ((il/4)>0 ? 12 : 3); + uint16_t kmask2 = il/8 ? 0xF0 : 0x0F; + uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4]; + int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) + : (scale_2&kmask2) | ((scale_1&kmask1) << 4); + float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f); + const float ml = 4.f * dl; + + il = (il/2) & 3; + const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + dl *= coef; + + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml); + } +} + +static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) { + return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)} + : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))}; +} + +template +void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) { + device const uchar * q = xb->qs; + + short is = (il/4) * 2; + q = q + (il/4) * 32 + 16 * (il&1); + il = il & 3; + const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); + const float d = il < 2 ? xb->d : xb->d / 16.h; + const float min = xb->dmin; + const float dl = d * sc[0]; + const float ml = min * sc[1]; + + const ushort mask = il<2 ? 0x0F : 0xF0; + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - ml; + } +} + +template +void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) { + device const uint8_t * q = xb->qs; + device const uint8_t * qh = xb->qh; + + short is = (il/4) * 2; + q = q + 32 * (il/4) + 16 * (il&1); + qh = qh + 16 * (il&1); + uint8_t ul = 1 << (il/2); + il = il & 3; + const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); + const float d = il < 2 ? xb->d : xb->d / 16.f; + const float min = xb->dmin; + const float dl = d * sc[0]; + const float ml = min * sc[1]; + + const ushort mask = il<2 ? 0x0F : 0xF0; + const float qh_val = il<2 ? 16.f : 256.f; + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml; + } +} + +template +void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) { + const half d_all = xb->d; + device const uint8_t * ql = (device const uint8_t *)xb->ql; + device const uint8_t * qh = (device const uint8_t *)xb->qh; + device const int8_t * scales = (device const int8_t *)xb->scales; + + ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1); + qh = qh + 32*(il/8) + 16*(il&1); + float sc = scales[(il%2) + 2 * ((il/2))]; + il = (il/2) & 3; + + const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F; + const float coef = il>1 ? 1.f/16.f : 1.f; + const float ml = d_all * sc * 32.f; + const float dl = d_all * sc * coef; + for (int i = 0; i < 16; ++i) { + const half q = il&1 ? 
((ql[i] & kmask2) | ((qh[i] & kmask1) << 2)) + : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4)); + reg[i/4][i%4] = dl * q - ml; + } +} + +template +void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's. + device const uint16_t * q2 = xb->qs + 4*ib32; + const uint32_t aux32_g = q2[0] | (q2[1] << 16); + const uint32_t aux32_s = q2[2] | (q2[3] << 16); + thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g; + const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f; + constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]); + uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127]; + for (int i = 0; i < 8; ++i) { + reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } + grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]); + signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127]; + for (int i = 0; i < 8; ++i) { + reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } +} + +template +void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint16_t * q2 = xb->qs + 4*ib32; + const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; + constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511)); + uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9]; + for (int i = 0; i < 8; ++i) { + reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } + grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511)); + signs = ksigns_iq2xs[q2[2*il+1] >> 9]; + for (int i = 0; i < 8; ++i) { + reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } +} + +template +void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint8_t * q3 = xb->qs + 8*ib32; + device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32; + const uint32_t aux32 = gas[0] | (gas[1] << 16); + const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f; + constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]); + constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]); + uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127]; + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f); + reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f); + } + grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]); + grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]); + signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127]; + for (int i = 0; i < 4; ++i) { + reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f); + reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? 
-1.f : 1.f); + } +} + +template +void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint8_t * qs = xb->qs + 8*ib32; + device const uint8_t * signs = xb->signs + 4*ib32 + 2*il; + const uint8_t qh = xb->qh[ib32] >> 4*il; + const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)); + constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256))); + constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]); + reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]); + } + grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256))); + grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256))); + for (int i = 0; i < 4; ++i) { + reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]); + reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]); + } +} + +template +void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint8_t * signs = qs + QK_K/8; + const uint8_t qh = xb->qh[ib32] >> 4*il; + const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; + constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300))); + constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300))); + for (int i = 0; i < 8; ++i) { + reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]); + reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]); + } +} + +template +void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; + const float d = xb->d; + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint16_t * qh = xb->qh; + const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1); + const float ml = dl * (qh[ib32] & 0x8000 ? 
-1 - IQ1S_DELTA : -1 + IQ1S_DELTA); + const uint16_t h = qh[ib32] >> 6*il; + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * (grid1[i] & 0xf) + ml; + reg[1][i] = dl * (grid1[i] >> 4) + ml; + reg[2][i] = dl * (grid2[i] & 0xf) + ml; + reg[3][i] = dl * (grid2[i] >> 4) + ml; + } +} + +template +void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; + device const uint16_t * sc = (device const uint16_t *)xb->scales; + + iq1m_scale_t scale; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + const float d = scale.f16; + + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint8_t * qh = xb->qh + 2*ib32 + il; + + const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1); + const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * (grid1[i] & 0xf) + ml1; + reg[1][i] = dl * (grid1[i] >> 4) + ml1; + reg[2][i] = dl * (grid2[i] & 0xf) + ml2; + reg[3][i] = dl * (grid2[i] >> 4) + ml2; + } +} + +template +void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) { + device const uint16_t * q4 = (device const uint16_t *)xb->qs; + const float d = xb->d; + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + for (int i = 0; i < 4; ++i) { + aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f; + reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; + reg[i][1] = d * kvalues_iq4nl_f[q8[1]]; + reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; + reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; + } +} + +template +void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. 
il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32; + const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4); + const float d = (float)xb->d * (ls - 32); + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + for (int i = 0; i < 4; ++i) { + aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f; + reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; + reg[i][1] = d * kvalues_iq4nl_f[q8[1]]; + reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; + reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; + } +} + +template +kernel void kernel_get_rows_q( + device const void * src0, + device const void * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0]; + + const int64_t i02 = i11; + + for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) { + float4x4 temp; + dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp); + *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; + } +} + +template +kernel void kernel_get_rows_f( + device const void * src0, + device const void * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0]; + + const int64_t i02 = i11; + + for (int ind = tiitg; ind < ne00; ind += tptg.x) { + (( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] = + ((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind]; + } +} + +kernel void kernel_get_rows_i32( + device const void * src0, + device const void * src1, + device int32_t * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0]; + + const int64_t i02 = i11; + + for (int ind = tiitg; ind < ne00; ind += tptg.x) { + (( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] = + ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind]; + } +} + + +#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A +#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B +#define BLOCK_SIZE_K 32 +#define THREAD_MAT_M 4 // each thread take 
4 simdgroup matrices from matrix A +#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B +#define THREAD_PER_BLOCK 128 +#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers +#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers +#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8 +#define SG_MAT_ROW 8 + +// each block_q contains 16*nl weights +template kernel void kernel_mul_mm(device const uchar * src0, device const uchar * src1, device float * dst, @@ -4107,10 +7249,12 @@ kernel void kernel_mul_mm(device const uchar * src0, constant int64_t & ne02, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -4119,629 +7263,267 @@ kernel void kernel_mul_mm(device const uchar * src0, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mm_impl( - src0, - src1, - dst, - ne00, - ne02, - nb01, - nb02, - ne12, - nb10, - nb11, - nb12, - ne0, - ne1, - r2, - r3, - shared_memory, - tgpig, - tiitg, - sgitg); -} -template -kernel void kernel_mul_mm_id( - device const uchar * ids, - device const uchar * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const uchar * src00, - device const uchar * src01, - device const uchar * src02, - device const uchar * src03, - device const uchar * src04, - device const uchar * src05, - device const uchar * src06, - device const uchar * src07, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + threadgroup T * sa = (threadgroup T *)(shared_memory); + threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); + + const uint r0 = tgpig.y; + const uint r1 = tgpig.x; + const uint im = tgpig.z; - // expert id - const int32_t id = tgpig.z/(ne12*ne13); + // if this block is of 64x32 shape or smaller + short n_rows = (ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; + short n_cols = (ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; - tgpig.z = tgpig.z%(ne12*ne13); + // a thread shouldn't load data outside of the matrix + short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; + short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? 
((short)tiitg/THREAD_PER_COL) : n_cols - 1; - // row indices of src1 for expert id - int64_t _ne1 = 0; - short src1ids[512]; + simdgroup_T8x8 ma[4]; + simdgroup_float8x8 mb[2]; + simdgroup_float8x8 mc[8]; - for (int64_t i1 = 0; i1 < ne1; i1++) { - if (((device int32_t *) (ids + i1*nbi1))[idx] == id) { - src1ids[_ne1++] = i1; - } + for (short i = 0; i < 8; i++){ + mc[i] = make_filled_simdgroup_matrix(0.f); } - kernel_mul_mm_id_impl( - src0s[id], - src1, - src1ids, - dst, - ne00, - ne02, - nb01, - nb02, - ne12, - nb10, - nb11, - nb12, - ne0, - _ne1, - r2, - r3, - shared_memory, - tgpig, - tiitg, - sgitg); -} + short il = (tiitg % THREAD_PER_ROW); -#if QK_K == 256 -#define QK_NL 16 -#else -#define QK_NL 4 -#endif + const uint i12 = im%ne12; + const uint i13 = im/ne12; -// -// get rows -// + uint offset0 = (i12/r2)*nb02 + (i13/r3)*nb03; + ushort offset1 = il/nl; -typedef void (get_rows_t)( - device const void * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3, uint, uint3); - -//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows; -//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows; + device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*nb01 + offset0) + offset1; + device const float * y = (device const float *)(src1 + + nb13 * i13 + + nb12 * i12 + + nb11 * (r1 * BLOCK_SIZE_N + thread_col) + + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); -// -// matrix-matrix multiplication -// + for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { + // load data and store to threadgroup memory + T4x4 temp_a; + dequantize_func(x, il, temp_a); + threadgroup_barrier(mem_flags::mem_threadgroup); -typedef void (mat_mm_t)( - device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar *, - uint3, uint, uint); - -template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_1_f32")]] 
kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; + #pragma unroll(16) + for (short i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ + + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ + + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4]; + } -// -// indirect matrix-matrix multiplication -// + *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL)*8*32 + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y); -typedef void (mat_mm_id_t)( - device const uchar * ids, - device const uchar * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const uchar * src00, - device const uchar * src01, - device const uchar * src02, - device const uchar * src03, - device const uchar * src04, - device const uchar * src05, - device const uchar * src06, - device const uchar * src07, - threadgroup uchar *, - uint3, uint, uint); - -template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; + il = (il + 2 < nl) ? il + 2 : il % 2; + x = (il < 2) ? 
x + (2+nl-1)/nl : x; + y += BLOCK_SIZE_K; -// -// matrix-vector multiplication -// + threadgroup_barrier(mem_flags::mem_threadgroup); -[[host_name("kernel_mul_mv_id_f32_f32")]] -kernel void kernel_mul_mv_id_f32_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + // load matrices from threadgroup memory and conduct outer products + threadgroup T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2)); + threadgroup float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2)); - const int64_t bid = tgpig.z/(ne12*ne13); + #pragma unroll(4) + for (short ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { + #pragma unroll(4) + for (short i = 0; i < 4; i++) { + simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); + } + simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(2) + for (short i = 0; i < 2; i++) { + simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); + } - tgpig.z = tgpig.z%(ne12*ne13); + lsma += BLOCK_SIZE_M/SG_MAT_ROW * SG_MAT_SIZE; + lsmb += BLOCK_SIZE_N/SG_MAT_ROW * SG_MAT_SIZE; - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + #pragma unroll(8) + for (short i = 0; i < 8; i++){ + simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); + } + } + } - kernel_mul_mv_f32_f32_impl( - src0[id], - src1 + bid*nb11, - dst + bid*ne0, - ne00, - ne01, - ne02, - nb00, - nb01, - nb02, - ne10, - ne11, - ne12, - nb10, - nb11, - nb12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg); -} + if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { + device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \ + + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0; + for (short i = 0; i < 8; i++) { + simdgroup_store(mc[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); + } + } else { + // block is smaller than 64x32, we should avoid writing data outside of the matrix + threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup float * temp_str = ((threadgroup float *) shared_memory) \ + + 32 * (sgitg&1) + (16 * (sgitg>>1))*BLOCK_SIZE_M; + for (short i = 0; i < 8; i++) { + simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M); + } -[[host_name("kernel_mul_mv_id_f16_f32")]] -kernel void kernel_mul_mv_id_f16_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant 
int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + threadgroup_barrier(mem_flags::mem_threadgroup); - const int64_t bid = tgpig.z/(ne12*ne13); + if (sgitg == 0) { + for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { + device float * D = dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*ne0 + im*ne1*ne0; + device float4 * D4 = (device float4 *) D; - tgpig.z = tgpig.z%(ne12*ne13); + threadgroup float * C = temp_str + (j*BLOCK_SIZE_M); + threadgroup float4 * C4 = (threadgroup float4 *) C; - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + int i = 0; + for (; i < n_rows/4; i++) { + *(D4 + i) = *(C4 + i); + } - kernel_mul_mv_f16_f32_impl( - src0[id], - src1 + bid*nb11, - dst + bid*ne0, - ne00, - ne01, - ne02, - nb00, - nb01, - nb02, - ne10, - ne11, - ne12, - nb10, - nb11, - nb12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg); + i *= 4; + for (; i < n_rows; i++) { + *(D + i) = *(C + i); + } + } + } + } } -[[host_name("kernel_mul_mv_id_q8_0_f32")]] -kernel void kernel_mul_mv_id_q8_0_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; +// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids +template +void kernel_mul_mm_id_impl( + device const uchar * src0, + device const uchar * src1, + threadgroup ushort2 * rowids, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + int64_t ne1, + int64_t ne0ne1, + threadgroup uchar * shared_memory, + uint3 tgpig[[threadgroup_position_in_grid]], + uint 
tiitg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t bid = tgpig.z/(ne12*ne13); + threadgroup half * sa = (threadgroup half *)(shared_memory); + threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); - tgpig.z = tgpig.z%(ne12*ne13); + const uint r0 = tgpig.y; + const uint r1 = tgpig.x; - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + if (r1 * BLOCK_SIZE_N >= ne1) return; - kernel_mul_mv_q8_0_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} + // if this block is of 64x32 shape or smaller + short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; + short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; -[[host_name("kernel_mul_mv_id_q4_0_f32")]] -kernel void kernel_mul_mv_id_q4_0_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + // a thread shouldn't load data outside of the matrix + short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; + short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? 
((short)tiitg/THREAD_PER_COL) : n_cols - 1; - const int64_t bid = tgpig.z/(ne12*ne13); + simdgroup_half8x8 ma[4]; + simdgroup_float8x8 mb[2]; + simdgroup_float8x8 c_res[8]; + for (int i = 0; i < 8; i++){ + c_res[i] = make_filled_simdgroup_matrix(0.f); + } + short il = (tiitg % THREAD_PER_ROW); - tgpig.z = tgpig.z%(ne12*ne13); + ushort offset1 = il/nl; - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col]; - mul_vec_q_n_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} + device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1; + device const float * y = (device const float *)(src1 + + nb12 * id[1] + + nb11 * (id[0] % ne11) + + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); -[[host_name("kernel_mul_mv_id_q4_1_f32")]] -kernel void kernel_mul_mv_id_q4_1_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { + // load data and store to threadgroup memory + half4x4 temp_a; + dequantize_func(x, il, temp_a); + threadgroup_barrier(mem_flags::mem_threadgroup); - const int64_t bid = tgpig.z/(ne12*ne13); + for (int i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ + + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ + + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; + } - tgpig.z = tgpig.z%(ne12*ne13); + *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + il = (il + 2 < nl) ? il + 2 : il % 2; + x = (il < 2) ? 
x + (2+nl-1)/nl : x; + y += BLOCK_SIZE_K; - mul_vec_q_n_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} + threadgroup_barrier(mem_flags::mem_threadgroup); -[[host_name("kernel_mul_mv_id_q5_0_f32")]] -kernel void kernel_mul_mv_id_q5_0_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + // load matrices from threadgroup memory and conduct outer products + threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); + threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); + + for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { + for (int i = 0; i < 4; i++) { + simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); + } + simdgroup_barrier(mem_flags::mem_none); + for (int i = 0; i < 2; i++) { + simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); + } - const int64_t bid = tgpig.z/(ne12*ne13); + lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; + lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; - tgpig.z = tgpig.z%(ne12*ne13); + for (int i = 0; i < 8; i++){ + simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); + } + } + } - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + { + threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ + + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; + for (int i = 0; i < 8; i++) { + simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); + } - mul_vec_q_n_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); + threadgroup_barrier(mem_flags::mem_threadgroup); + + device float * C = dst + (BLOCK_SIZE_M * r0); + if (sgitg == 0) { + for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { + threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j]; + int joff = jid[0] * ne0 + jid[1] * ne0ne1; + for (int i = 0; i < n_rows; i++) { + *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M); + } + } + } + } } -[[host_name("kernel_mul_mv_id_q5_1_f32")]] -kernel void kernel_mul_mv_id_q5_1_f32( - device const char * ids, - device const char * src1, +template +kernel void kernel_mul_mm_id( + device const uchar * src0s, + device const uchar * src1, device float * dst, + device const uchar * ids, + constant int64_t & nei0, + constant 
int64_t & nei1, constant uint64_t & nbi1, constant int64_t & ne00, - constant int64_t & ne01, constant int64_t & ne02, - constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, - constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant int64_t & ne13, @@ -4751,178 +7533,266 @@ kernel void kernel_mul_mv_id_q5_1_f32( constant int64_t & ne0, constant int64_t & ne1, constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, + threadgroup uchar * shared_memory [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); + const int32_t i02 = tgpig.z; + tgpig.z = 0; + + device const uchar * src0 = src0s + i02*nb02; - tgpig.z = tgpig.z%(ne12*ne13); + // row indices + threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192); + + // TODO: parallelize this loop + int64_t _ne1 = 0; + for (ushort ii1 = 0; ii1 < nei1; ii1++) { + for (ushort ii0 = 0; ii0 < nei0; ii0++) { + int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0]; + if (id == i02) { + //if (tiitg == 0) { + rowids[_ne1] = ushort2(ii0, ii1); + //} + _ne1++; + } + } + } - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + threadgroup_barrier(mem_flags::mem_threadgroup); - mul_vec_q_n_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, + kernel_mul_mm_id_impl( + src0, + src1, + rowids, + dst, ne00, - ne01, ne02, - ne10, + nb01, + nb02, + ne11, ne12, + nb10, + nb11, + nb12, ne0, - ne1, - r2, - r3, + _ne1, + ne0*ne1, + shared_memory, tgpig, - tiisg, + tiitg, sgitg); } -[[host_name("kernel_mul_mv_id_q2_K_f32")]] -kernel void kernel_mul_mv_id_q2_K_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; +#define QK_NL 16 - const int64_t bid = tgpig.z/(ne12*ne13); +// +// get rows +// - tgpig.z = tgpig.z%(ne12*ne13); +typedef decltype(kernel_get_rows_f) get_rows_f_t; + +template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f; +template 
[[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f; + +typedef decltype(kernel_get_rows_q) get_rows_q_t; + +template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q; - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; +// +// matrix-matrix multiplication +// - kernel_mul_mv_q2_K_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} +typedef decltype(kernel_mul_mm) mat_mm_t; + +template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; +template 
[[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; -[[host_name("kernel_mul_mv_id_q3_K_f32")]] -kernel void kernel_mul_mv_id_q3_K_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; +// +// indirect matrix-matrix multiplication +// - const int64_t bid = tgpig.z/(ne12*ne13); +typedef decltype(kernel_mul_mm_id) mat_mm_id_t; + +template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; 
+template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; - tgpig.z = tgpig.z%(ne12*ne13); +// +// matrix-vector multiplication +// - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; +typedef void (kernel_mul_mv_impl_t)( + device const char * src0, + device const char * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + uint3 tgpig, + uint tiisg); - kernel_mul_mv_q3_K_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); +typedef void (kernel_mul_mv2_impl_t)( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg); + +template +void mmv_fn( + device const char * src0, + device const char * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + int64_t ne13, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint64_t nb1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiitg, + uint tiisg, + uint sgitg) { + impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg); } -[[host_name("kernel_mul_mv_id_q4_K_f32")]] -kernel void kernel_mul_mv_id_q4_K_f32( - device const char * ids, +template +void mmv_fn( + device const char * src0, device const char * src1, device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + int64_t ne13, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint64_t nb1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiitg, + uint tiisg, + uint sgitg) { + impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); +} + +typedef decltype(mmv_fn>) mul_mv_impl_fn_t; + +template +kernel void kernel_mul_mv_id( + device const char * src0s, + device const char * src1, + device float * dst, + device const char * ids, + constant int64_t & nei0, + constant int64_t & nei1, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -4940,169 +7810,176 @@ kernel void kernel_mul_mv_id_q4_K_f32( constant int64_t & ne0, constant int64_t & ne1, constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - 
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - kernel_mul_mv_q4_K_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, + const int iid1 = tgpig.z/nei0; + const int idx = tgpig.z%nei0; + + tgpig.z = 0; + + const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx]; + + const int64_t i11 = idx % ne11; + const int64_t i12 = iid1; + + const int64_t i1 = idx; + const int64_t i2 = i12; + + device const char * src0_cur = src0s + i02*nb02; + device const char * src1_cur = src1 + i11*nb11 + i12*nb12; + device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0; + + impl_fn( + /* src0 */ src0_cur, + /* src1 */ src1_cur, + /* dst */ dst_cur, + /* ne00 */ ne00, + /* ne01 */ ne01, + /* ne02 */ 1,//ne02, + /* nb00 */ nb00, + /* nb01 */ nb01, + /* nb02 */ nb02, + /* ne10 */ ne10, + /* ne11 */ 1,//ne11, + /* ne12 */ 1,//ne12, + /* ne13 */ 1,//ne13, + /* nb10 */ nb10, + /* nb11 */ nb11, + /* nb12 */ nb12, + /* ne0 */ ne0, + /* ne1 */ 1,//ne1, + /* nb1 */ nb1, + /* r2 */ 1, + /* r3 */ 1, + shared_values, tgpig, + tiitg, tiisg, sgitg); } -[[host_name("kernel_mul_mv_id_q5_K_f32")]] -kernel void kernel_mul_mv_id_q5_K_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; +typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t; + +template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t 
kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; + +kernel void kernel_pool_2d_max_f32( + device const float * src0, + device float * dst, + constant int32_t & k0, + constant int32_t & k1, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int64_t & IH, + constant int64_t & IW, + constant int64_t & OH, + constant int64_t & OW, + constant int64_t & parallel_elements, + uint gid[[thread_position_in_grid]]) { + + if (gid >= parallel_elements) { + return; + } - const int64_t bid = tgpig.z/(ne12*ne13); + const int idx = gid; + const int I_HW = IH * IW; + const int O_HW = OH * OW; + const int nc = idx / O_HW; + const int cur_oh = idx % O_HW / OW; + const int cur_ow = idx % O_HW % OW; - tgpig.z = tgpig.z%(ne12*ne13); + device const float * i_ptr = src0 + nc * I_HW; + device float * o_ptr = dst + nc * O_HW; - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + const int start_h = cur_oh * s1 - p1; + const int bh = MAX(0, start_h); + const int eh = MIN(IH, start_h + k1); + const int start_w = cur_ow * s0 - p0; + const int bw = MAX(0, start_w); + const int ew = MIN(IW, start_w + k0); - kernel_mul_mv_q5_K_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} + float res = -INFINITY; -[[host_name("kernel_mul_mv_id_q6_K_f32")]] -kernel void kernel_mul_mv_id_q6_K_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint 
sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + for (int i = bh; i < eh; i += 1) { + for (int j = bw; j < ew; j += 1) { + res = MAX(res, i_ptr[i * IW + j]); + } + } - const int64_t bid = tgpig.z/(ne12*ne13); + o_ptr[cur_oh * OW + cur_ow] = res; +} - tgpig.z = tgpig.z%(ne12*ne13); +kernel void kernel_pool_2d_avg_f32( + device const float * src0, + device float * dst, + constant int32_t & k0, + constant int32_t & k1, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int64_t & IH, + constant int64_t & IW, + constant int64_t & OH, + constant int64_t & OW, + constant int64_t & parallel_elements, + uint gid[[thread_position_in_grid]]) { + + if (gid >= parallel_elements) { + return; + } - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + const int idx = gid; + const int I_HW = IH * IW; + const int O_HW = OH * OW; + const int nc = idx / O_HW; + const int cur_oh = idx % O_HW / OW; + const int cur_ow = idx % O_HW % OW; + + device const float * i_ptr = src0 + nc * I_HW; + device float * o_ptr = dst + nc * O_HW; + + const int start_h = cur_oh * s1 - p1; + const int bh = MAX(0, start_h); + const int eh = MIN(IH, start_h + k1); + const int start_w = cur_ow * s0 - p0; + const int bw = MAX(0, start_w); + const int ew = MIN(IW, start_w + k0); + // const float scale = 1. / ((eh - bh) * (ew - bw)); + const float scale = 1. / (k0 * k1); + + float res = 0; + + for (int i = bh; i < eh; i += 1) { + for (int j = bw; j < ew; j += 1) { + float cur = i_ptr[i * IW + j]; + res += cur * scale; + } + } - kernel_mul_mv_q6_K_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); + o_ptr[cur_oh * OW + cur_ow] = res; } diff --git a/candle-metal-kernels/src/scaled_dot_product_attention.metal b/candle-metal-kernels/src/scaled_dot_product_attention.metal index ab129d13a1..eb3d2d7326 100644 --- a/candle-metal-kernels/src/scaled_dot_product_attention.metal +++ b/candle-metal-kernels/src/scaled_dot_product_attention.metal @@ -5,6 +5,262 @@ using namespace metal; +#define STEEL_CONST static constant constexpr const +#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") + +#if defined(__HAVE_BFLOAT__) + +typedef bfloat bfloat16_t; +typedef half float16_t; + +#else + +///////////////////////////////////////////////////////////////////////////// +// Helpers +///////////////////////////////////////////////////////////////////////////// + +constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) { + // Check for nan + if ((as_type(x) & ~_fp_encoding_traits::sign_mask) > + _fp_encoding_traits::inf_mask) { + return uint16_t(as_type(0x7FC0)); + } + // Take bits + uint32_t float_bits = as_type(x); + + // Round to nearest even + float_bits += ((float_bits >> 16) & 1) + as_type(0x7FFF); + + // Take upper 16 bits + return float_bits >> 16; +} + +constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) { + // Upper 16 bits are the data and lower 16 bits are 0s + return as_type((uint32_t)x << 16); +} + +struct _MLX_BFloat16; + +template +static constexpr constant bool can_convert_to_bfloat = + !is_same_v && is_convertible_v; + +template +static constexpr constant bool can_convert_from_bfloat = + !is_same_v && is_convertible_v; + +///////////////////////////////////////////////////////////////////////////// +// Bfloat struct 
+///////////////////////////////////////////////////////////////////////////// + +struct _MLX_BFloat16 { + ///////////////////////////////////////////////////////////////////////////// + // Constructors + uint16_t bits_; + _MLX_BFloat16() thread = default; + _MLX_BFloat16() threadgroup = default; + _MLX_BFloat16() device = default; + _MLX_BFloat16() constant = default; + + struct bits_to_bfloat_struct {}; + static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() { + return bits_to_bfloat_struct(); + } + constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct) + : bits_(bits) {} + + ///////////////////////////////////////////////////////////////////////////// + // Conversions to bfloat + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) thread + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) device + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) constant + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + ///////////////////////////////////////////////////////////////////////////// + // Conversions from bfloat + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const thread { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const threadgroup { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const device { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const constant { + return static_cast(bfloat_bits_to_float(bits_)); + } +}; + +///////////////////////////////////////////////////////////////////////////// +// Bfloat operators +///////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////// +// Unary ops +constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) { + return -static_cast(x); +} + +///////////////////////////////////////////////////////////////////////////// +// Binary operators +#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ + constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \ + constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +///////////////////////////////////////////////////////////////////////////// +// Arithmetic Operators +#define bfloat_binop(_op_, _operator_) \ + bfloat_binop_base( \ + _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \ + bfloat_binop_helper(_op_, _operator_, float, float, float); 
\ + bfloat_binop_helper(_op_, _operator_, float, half, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float); + +bfloat_binop(+, operator+); +bfloat_binop(-, operator-); +bfloat_binop(*, operator*); +bfloat_binop(/, operator/); + +///////////////////////////////////////////////////////////////////////////// +// Comparison ops +#define bfloat_compop(__op__, __operator__) \ + bfloat_binop_base( \ + __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \ + bfloat_binop_helper(__op__, __operator__, bool, float, float); \ + bfloat_binop_helper(__op__, __operator__, bool, half, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float); + +bfloat_compop(>, operator>); +bfloat_compop(<, operator<); +bfloat_compop(>=, operator>=); +bfloat_compop(<=, operator<=); +bfloat_compop(==, operator==); +bfloat_compop(!=, operator!=); + +#undef bfloat_compop +#undef bfloat_binop_base +#undef bfloat_binop_helper +#undef bfloat_binop + +///////////////////////////////////////////////////////////////////////////// +// Inplace Operators +#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \ + constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \ + addr_space _MLX_BFloat16& lhs, itype rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } \ + constexpr METAL_FUNC addr_space itype& __operator__( \ + addr_space itype& lhs, _MLX_BFloat16 rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } + +#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \ + bfloat_inplace_op_helper(__op__, __operator__, itype, device); \ + bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \ + bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup); + +#define bfloat_inplace_op(itype) \ + bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \ + bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \ + bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \ + bfloat_inplace_op_addr_space_helper(/, operator/=, itype); + +bfloat_inplace_op(float); +bfloat_inplace_op(half); +bfloat_inplace_op(int16_t); +bfloat_inplace_op(int32_t); +bfloat_inplace_op(int64_t); +bfloat_inplace_op(uint16_t); +bfloat_inplace_op(uint32_t); +bfloat_inplace_op(uint64_t); + +#undef bfloat_inplace_op_helper +#undef bfloat_inplace_op_addr_space_helper +#undef bfloat_inplace_op + +#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \ + constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \ + addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } + +#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \ + bfloat_inplace_op_helper(__op__, __operator__, device); \ + bfloat_inplace_op_helper(__op__, __operator__, thread); \ + bfloat_inplace_op_helper(__op__, __operator__, threadgroup); + +bfloat_inplace_op_addr_space_helper(+, operator+=); +bfloat_inplace_op_addr_space_helper(-, operator-=); +bfloat_inplace_op_addr_space_helper(*, operator*=); 
+bfloat_inplace_op_addr_space_helper(/, operator/=); + +#undef bfloat_inplace_op_helper +#undef bfloat_inplace_op_addr_space_helper + +///////////////////////////////////////////////////////////////////////////// +// Bfloat typedef +///////////////////////////////////////////////////////////////////////////// + +typedef struct _MLX_BFloat16 bfloat16_t; + +#endif + // ============ "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h" struct MLXFastAttentionParams { @@ -140,6 +396,9 @@ template // Move the pointers to the next kv keys += stride; values += stride; + if (sdpa_vector_has_mask) { + mask += BN * mask_seq_stride; + } } // Each thread has a partial part of the output so we need to combine them. @@ -275,6 +534,43 @@ template mask += BN * blocks * mask_seq_stride; } } + + // Each thread has a partial part of the output so we need to combine them. + + // First let's communicate the max and sum_exp + if (simd_lid == 0) { + max_scores[simd_gid] = max_score; + sum_exp_scores[simd_gid] = sum_exp_score; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9; + U new_max = simd_max(max_score); + U factor = fast::exp(max_score - new_max); + sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0; + sum_exp_score = simd_sum(sum_exp_score * factor); + + // Write the sum and new max + if (simd_gid == 0) { + sums[0] = sum_exp_score; + maxs[0] = new_max; + } + + // Now we need to aggregate all the outputs + for (int i = 0; i < elem_per_thread; i++) { + outputs[simd_lid * BN + simd_gid] = + o[i] * fast::exp(max_scores[simd_gid] - new_max); + threadgroup_barrier(mem_flags::mem_threadgroup); + + // And write the output + if (simd_gid == 0) { + U output = outputs[simd_lid * BN]; + for (int j = 1; j < BN; j++) { + output += outputs[simd_lid * BN + j]; + } + out[i] = static_cast(output); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } } template @@ -329,114 +625,55 @@ template } } -// ============ "mlx/backend/metal/kernels/steel/defines.h" - -#define STEEL_CONST static constant constexpr const -#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") - -// ============ "mlx/backend/metal/kernels/steel/gemm/transforms.h" - -template -struct TransformNone { - static METAL_FUNC OutT apply(InT x) { - return static_cast(x); - } - - static METAL_FUNC OutT apply(InT x, OutT) { - return static_cast(x); - } -}; - -template -struct TransformAdd { - TransformAdd(const float, const float) {} - - static METAL_FUNC OutT apply(InT x) { - return static_cast(x); - } - - static METAL_FUNC OutT apply(InT x, OutT c) { - return static_cast(x) + c; - } -}; - -template -struct TransformAxpby { - const float alpha; - const float beta; - - TransformAxpby(const float alpha_, const float beta_) - : alpha(alpha_), beta(beta_) {} - - static METAL_FUNC OutT apply(InT x) { - return static_cast(x); - } - - METAL_FUNC OutT apply(InT x, OutT c) const { - return static_cast(x * alpha + (beta * c)); - } -}; - -template -struct AccumHelper { - typedef float accum_type; -}; +// ============ "mlx/backend/metal/kernels/utils.h" -struct BlockSwizzle { - static METAL_FUNC int2 - swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { - const int tid_x = (tid.x) >> swizzle_log; - const int tid_y = - ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); - return int2(tid_x, tid_y); - } +template +struct Limits { + static const constant U max = metal::numeric_limits::max(); + static const constant U min = 
metal::numeric_limits::min(); + static const constant U finite_max = metal::numeric_limits::max(); + static const constant U finite_min = metal::numeric_limits::min(); }; -// ============ "mlx/backend/metal/kernels/utils.h" +#define instantiate_default_limit(type) \ + template <> \ + struct Limits { \ + static constexpr constant type max = metal::numeric_limits::max(); \ + static constexpr constant type min = metal::numeric_limits::min(); \ + static constexpr constant type finite_max = \ + metal::numeric_limits::max(); \ + static constexpr constant type finite_min = \ + metal::numeric_limits::min(); \ + }; -#if defined(__HAVE_BFLOAT__) -typedef bfloat bfloat16_t; -#endif -typedef half float16_t; +instantiate_default_limit(uint8_t); +instantiate_default_limit(uint16_t); +instantiate_default_limit(uint32_t); +instantiate_default_limit(uint64_t); +instantiate_default_limit(int8_t); +instantiate_default_limit(int16_t); +instantiate_default_limit(int32_t); +instantiate_default_limit(int64_t); + +#define instantiate_float_limit(type) \ + template <> \ + struct Limits { \ + static constexpr constant type max = \ + metal::numeric_limits::infinity(); \ + static constexpr constant type min = \ + -metal::numeric_limits::infinity(); \ + static constexpr constant type finite_max = \ + metal::numeric_limits::max(); \ + static constexpr constant type finite_min = \ + -metal::numeric_limits::max(); \ + }; -METAL_FUNC ulong2 elem_to_loc_broadcast( - uint elem, - constant const int* shape, - constant const size_t* a_strides, - constant const size_t* b_strides, - int ndim) { - ulong loc_a{0}; - ulong loc_b{0}; - for (int i = ndim - 1; i >= 0 && elem > 0; --i) { - int pos_in_dim = (elem % shape[i]); - elem /= shape[i]; - loc_a += pos_in_dim * a_strides[i]; - loc_b += pos_in_dim * b_strides[i]; - } - return ulong2(loc_a, loc_b); -} +instantiate_float_limit(half); +instantiate_float_limit(float); +instantiate_float_limit(bfloat16_t); -METAL_FUNC ulong3 elem_to_loc_broadcast( - uint elem, - constant const int* shape, - constant const size_t* a_strides, - constant const size_t* b_strides, - constant const size_t* c_strides, - int ndim) { - ulong loc_a{0}; - ulong loc_b{0}; - ulong loc_c{0}; - for (int i = ndim - 1; i >= 0 && elem > 0; --i) { - int pos_in_dim = (elem % shape[i]); - elem /= shape[i]; - loc_a += pos_in_dim * a_strides[i]; - loc_b += pos_in_dim * b_strides[i]; - loc_c += pos_in_dim * c_strides[i]; - } - return ulong3(loc_a, loc_b, loc_c); -} -// ============ "mlx/backend/metal/kernels/scaled_dot_product_attention_params.metal" +// ============ "mlx/backend/metal/kernels/steel/attn/loader.h" template < typename T, @@ -449,7 +686,7 @@ template < short n_reads = (BCOLS * BROWS) / (tgp_size), short TCOLS = BCOLS / n_reads, short TROWS = tgp_size / TCOLS> -struct BlockLoaderFA { +struct BlockLoader { STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; STEEL_CONST short vec_size = n_reads; @@ -471,7 +708,7 @@ struct BlockLoaderFA { }; /* Constructor */ - METAL_FUNC BlockLoaderFA( + METAL_FUNC BlockLoader( const device T* src_, const int src_ld_, threadgroup T* dst_, @@ -485,6 +722,18 @@ struct BlockLoaderFA { dst(dst_ + bi * dst_ld + bj), src(src_ + bi * src_ld + bj) {} + /* Apply operation to threadgroup without bound checking */ + template + METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]); + } + } + } 
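+  /* apply_inplace_op runs an elementwise functor over the tile already resident in threadgroup memory,
+     touching the same (i, j) elements each thread writes in load_unsafe / load_safe below; presumably
+     this is how a scalar transform (e.g. the softmax scale on a loaded Q tile) is folded in before the
+     simdgroup matrix loads consume the data. */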
+ /* Load from device memory into threadgroup memory - without bound checking */ METAL_FUNC void load_unsafe() const { STEEL_PRAGMA_UNROLL @@ -546,242 +795,925 @@ struct BlockLoaderFA { METAL_FUNC void next() { src += tile_stride; } - METAL_FUNC void next(short n) { - src += n * tile_stride; - } }; -template -struct LoopAlignment {}; +template +struct CShape { + STEEL_CONST int kRows = R; + STEEL_CONST int kCols = C; +}; template < typename T, - typename U, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - short lda_tgp, - short ldb_tgp, - typename AccumType = float, - typename Epilogue = TransformNone> -struct BlockMMAFA { - // Warp tile simdgroup matrix strides along M - STEEL_CONST short TM_stride = 8 * WM; - // Warp tile simdgroup matrix strides along M - STEEL_CONST short TN_stride = 8 * WN; - - // Warp tile size along M - STEEL_CONST short TM = BM / TM_stride; - // Warp tile size along N - STEEL_CONST short TN = BN / TN_stride; - - // Strides of A, B along reduction axis - STEEL_CONST short simd_stride_a = { - transpose_a ? TM_stride : TM_stride * lda_tgp}; - STEEL_CONST short simd_stride_b = { - transpose_b ? TN_stride * ldb_tgp : TN_stride}; - - // Jump between elements - STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1}; - STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1}; - - STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8}; - STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp}; - - // Simdgroup matrices - simdgroup_matrix Asimd[TM]; - simdgroup_matrix Bsimd[TN]; - simdgroup_matrix results[TM * TN] = { - simdgroup_matrix(0)}; - - // Offsets within threadgroup - const short tm; - const short tn; + short BROWS, + short BCOLS, + short kDstStrRow, + short kDstStrCol, + short reduction_dim, + short tgp_size, + short n_reads = (BCOLS * BROWS) / (tgp_size), + short TCOLS = BCOLS / n_reads, + short TROWS = tgp_size / TCOLS> +struct BlockLoaderT { + STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; + STEEL_CONST short vec_size = n_reads; - short sm; - short sn; + // Leading dimension for src + const int src_ld; + const int tile_stride; - ushort sid; - ushort slid; + // Thread location indices + const short thread_idx; + const short bi; + const short bj; - short As_offset; - short Bs_offset; + // threadgroup and device memory + threadgroup T* dst; + const device T* src; /* Constructor */ - METAL_FUNC BlockMMAFA( + METAL_FUNC BlockLoaderT( + const device T* src_, + const int src_ld_, + threadgroup T* dst_, ushort simd_group_id [[simdgroup_index_in_threadgroup]], ushort simd_lane_id [[thread_index_in_simdgroup]]) - : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) { - // Determine thread position in simdgroup matrix - short qid = simd_lane_id / 4; - slid = simd_lane_id; - sid = simd_group_id; - - sm = (qid & 4) + (simd_lane_id / 2) % 4; - sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; - - // Determine thread and simdgroup offset - As_offset = - transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp); - Bs_offset = - transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn)); - } - - /* (BM, BK) X (BK, BN) multiply accumulate function */ - METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { - // Adjust for simdgroup and thread location - As += As_offset; - Bs += Bs_offset; + : src_ld(src_ld_), + tile_stride(reduction_dim ? 
BCOLS : BROWS * src_ld), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * kDstStrRow + bj * kDstStrCol), + src(src_ + bi * src_ld + bj) {} - // Iterate over BK in blocks of 8 + /* Apply operation to threadgroup without bound checking */ + template + METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const { STEEL_PRAGMA_UNROLL - for (short kk = 0; kk < BK; kk += 8) { - simdgroup_barrier(mem_flags::mem_none); - - // Load elements from threadgroup A as simdgroup matrices + for (short i = 0; i < BROWS; i += TROWS) { STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - Asimd[i].thread_elements()[0] = - static_cast(As[i * simd_stride_a + 0]); - Asimd[i].thread_elements()[1] = - static_cast(As[i * simd_stride_a + jump_a]); + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = + op.apply(dst[i * kDstStrRow + j * kDstStrCol]); } + } + } - simdgroup_barrier(mem_flags::mem_none); - - // Load elements from threadgroup B as simdgroup matrices + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - Bsimd[j].thread_elements()[0] = - static_cast(Bs[j * simd_stride_b + 0]); - Bsimd[j].thread_elements()[1] = - static_cast(Bs[j * simd_stride_b + jump_b]); + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = src[i * src_ld + j]; } + } + } - simdgroup_barrier(mem_flags::mem_none); + /* Load from device memory into threadgroup memory - with bound checking */ + METAL_FUNC void load_safe(short2 src_tile_dim) const { + src_tile_dim = src_tile_dim - short2(bj, bi); - // Multiply and accumulate into result simdgroup matrices + // Skip loading if thread has no valid reads + if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { + for (short i = 0; i < BROWS; i += TROWS) { STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - short j_serp = (i % 2) ? (TN - 1 - j) : j; - - simdgroup_multiply_accumulate( - results[i * TN + j_serp], - Asimd[i], - Bsimd[j_serp], - results[i * TN + j_serp]); + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = T(0); } } - - // Progress to next simdgroup tile - As += tile_stride_a; - Bs += tile_stride_b; + return; } - } - METAL_FUNC void rescale_output(const threadgroup float* Corrections) { - // Loop over all simdgroup tiles + // Use fast thread memory for bound checks + bool tmp_idx[vec_size]; + T tmp_val[vec_size]; STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - short row = sm + tm + i * TM_stride; - float scale_value = Corrections[row]; + for (short i = 0; i < BROWS; i += TROWS) { + // Make sure tmp_idx only contains valid indices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); + } + + // Read valid indices into tmp_val + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; + } + + // Zero out uneeded values + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = tmp_idx[j] ? 
tmp_val[j] : T(0); + } + + // Copy values to threadgroup memory + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = tmp_val[j]; + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + src += tile_stride; + } +}; + +// ============ "mlx/backend/metal/kernels/steel/utils/type_traits.h" + +template +struct make_void { + typedef void type; +}; + +template +using void_t = typename make_void::type; + +template +struct pointer_element {}; + +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; + +template +using pointer_element_t = typename pointer_element>::type; + +// ============ "mlx/backend/metal/kernels/steel/utils/integral_constant.h" + +/////////////////////////////////////////////////////////////////////////////// +// Integral constant with casting +/////////////////////////////////////////////////////////////////////////////// + +template +using Int = integral_constant; + +/////////////////////////////////////////////////////////////////////////////// +// Binary Operators on Integral constants +/////////////////////////////////////////////////////////////////////////////// + +#define integral_const_binop(__op__, __operator__) \ + template \ + METAL_FUNC constexpr auto __operator__( \ + integral_constant, integral_constant) { \ + constexpr auto res = tv __op__ uv; \ + return integral_constant{}; \ + } + +integral_const_binop(+, operator+); +integral_const_binop(-, operator-); +integral_const_binop(*, operator*); +integral_const_binop(/, operator/); + +integral_const_binop(==, operator==); +integral_const_binop(!=, operator!=); +integral_const_binop(<, operator<); +integral_const_binop(>, operator>); +integral_const_binop(<=, operator<=); +integral_const_binop(>=, operator>=); + +integral_const_binop(&&, operator&&); +integral_const_binop(||, operator||); + +#undef integral_const_binop + +/////////////////////////////////////////////////////////////////////////////// +// Reduction operators +/////////////////////////////////////////////////////////////////////////////// + +template +METAL_FUNC constexpr T sum(T x) { + return x; +} + +template +METAL_FUNC constexpr auto sum(T x, Us... 
us) { + return x + sum(us...); +} + +// ============ "mlx/backend/metal/kernels/steel/gemm/transforms.h" + +template +struct TransformNone { + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT) { + return static_cast(x); + } +}; + +template +struct TransformAdd { + TransformAdd(const float, const float) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT c) { + return static_cast(x) + c; + } +}; + +template +struct TransformAxpby { + const float alpha; + const float beta; + + TransformAxpby(const float alpha_, const float beta_) + : alpha(alpha_), beta(beta_) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + METAL_FUNC OutT apply(InT x, OutT c) const { + return static_cast(x * alpha + (beta * c)); + } +}; + +template +struct AccumHelper { + typedef float accum_type; +}; + +struct BlockSwizzle { + static METAL_FUNC int2 + swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { + const int tid_x = (tid.x) >> swizzle_log; + const int tid_y = + ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); + return int2(tid_x, tid_y); + } +}; + +// ============ "mlx/backend/metal/kernels/steel/attn/mma.h" + +template +struct Shape2D { + RInt r; + CInt c; + + Shape2D(RInt r_, CInt c_) : r(r_), c(c_) {} +}; + +template +struct Layout2D { + Shape shape; + Layout layout; +}; + +template +struct BaseMMAFrag { + static_assert( + kFragRows_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); + static_assert( + kFragCols_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); +}; + +template +struct BaseMMAFrag { + STEEL_CONST int kFragRows = 8; + STEEL_CONST int kFragCols = 8; + + STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32; + + STEEL_CONST int kElemRows = 1; + STEEL_CONST int kElemCols = 2; + + static_assert( + kElemRows * kElemCols == kElemsPerFrag, + "MMAFrag shape is not consistent with MMAFrag size"); + + typedef metal::simdgroup_matrix mat_type; + typedef metal::vec frag_type; + typedef metal::vec row_frag_type; + typedef metal::vec col_frag_type; + + template + using dtype_mat_t = typename metal::simdgroup_matrix; + + template + using dtype_frag_t = typename metal::vec; + + METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id + [[thread_index_in_simdgroup]]) { + const short qid = simd_lane_id / 4; + const short fm = (qid & 4) + ((simd_lane_id / 2) % 4); + const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + return short2{fn, fm}; + } + + template + METAL_FUNC static constexpr void + load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[i * str_x.value + j * str_y.value]); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void load_safe( + thread frag_type& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < lim_x && (off_y + j) < lim_y) { + dst[i * kElemCols + j] = + static_cast(src[(off_x + i) * str_x + (off_y + j) * 
str_y.value]); + } else { + dst[i * kElemCols + j] = T(0); + } + } + } + } + + template + METAL_FUNC static constexpr void + store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * str_x + j * str_y.value] = static_cast(src[i * kElemCols + j]); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void store_safe( + const thread frag_type& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < lim_x && (off_y + j) < lim_y) { + dst[(off_x + i) * str_x + (off_y + j) * str_y.value] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + template + METAL_FUNC static constexpr void mma( + thread frag_type& D, + thread dtype_frag_t& A, + thread dtype_frag_t& B, + thread dtype_frag_t& C) { + mat_type D_mat; + dtype_mat_t A_mat; + dtype_mat_t B_mat; + dtype_mat_t C_mat; + + reinterpret_cast&>(A_mat.thread_elements()) = A; + reinterpret_cast&>(B_mat.thread_elements()) = B; + reinterpret_cast&>(C_mat.thread_elements()) = C; + + mma(D_mat, A_mat, B_mat, C_mat); + + D = reinterpret_cast(D_mat.thread_elements()); + } + + template + METAL_FUNC static constexpr void mma( + thread mat_type& D, + thread dtype_mat_t& A, + thread dtype_mat_t& B, + thread dtype_mat_t& C) { + simdgroup_multiply_accumulate(D, A, B, C); + } + + template + METAL_FUNC static constexpr void row_reduce( + thread const frag_type& inp_vals, + thread T* reduced_vals) { + T thr_reduce = Op::apply(inp_vals.x, inp_vals.y); + + T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1)); + qgr_reduce = Op::apply(thr_reduce, qgr_reduce); + + T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8)); + sgr_reduce = Op::apply(qgr_reduce, sgr_reduce); + + reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce); + } + + template + METAL_FUNC static constexpr void row_bin_op( + thread frag_type& inp_vals, + thread T* row_vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + inp_vals[i * kElemCols + j] = + Op::apply(inp_vals[i * kElemCols + j], row_vals[i]); + } + } + } +}; + +template < + typename T, + int kTileRows_, + int kTileCols_, + class MMAFrag_ = BaseMMAFrag> +struct MMATile { + using MMAFrag_t = MMAFrag_; + using elem_type = T; + STEEL_CONST int kFragRows = MMAFrag_t::kFragRows; + STEEL_CONST int kFragCols = MMAFrag_t::kFragCols; + STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag; + + STEEL_CONST int kTileRows = kTileRows_; + STEEL_CONST int kTileCols = kTileCols_; + + STEEL_CONST int kRows = kTileRows * kFragRows; + STEEL_CONST int kCols = kTileCols * kFragCols; + + STEEL_CONST int kNumFrags = kTileRows * kTileCols; + STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag; + + STEEL_CONST int kRowsPerThread = kTileRows * MMAFrag_t::kElemRows; + STEEL_CONST int kColsPerThread = kTileCols * MMAFrag_t::kElemCols; + + typedef typename MMAFrag_t::mat_type mat_type; + typedef typename MMAFrag_t::frag_type frag_type; + + frag_type val_frags[kNumFrags]; // = {frag_type(0)}; + + 
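// `val_frags` is the per-thread register backing for a kTileRows x kTileCols
// grid of 8x8 simdgroup fragments, stored row-major: frag_at(i, j) indexes
// val_frags[i * kTileCols + j]. With 32 threads per simdgroup, each thread
// owns kElemsPerFrag = 2 adjacent row elements of every fragment, which is
// why elems() can reinterpret the whole tile as kElemsPerTile contiguous scalars.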
METAL_FUNC MMATile() thread {} + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kNumFrags; ++i) { + val_frags[i] = frag_type(0); + } + } + + METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { + return val_frags[i * kTileCols + j]; + } + + METAL_FUNC constexpr const thread frag_type& frag_at( + const short i, + const short j) const { + return val_frags[i * kTileCols + j]; + } + + METAL_FUNC mat_type mat_at(const short i, const short j) { + mat_type val_mat; + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < kElemsPerFrag; ++ii) { + val_mat.thread_elements()[ii] = frag_at(i, j)[ii]; + } + return val_mat; + } + + METAL_FUNC thread elem_type* elems() { + return reinterpret_cast(val_frags); + } + + METAL_FUNC const thread elem_type* elems() const { + return reinterpret_cast(val_frags); + } + + template + METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::template row_reduce( + frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]); + } + } + } + + template + METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::template row_bin_op( + frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]); + } + } + } + + template + METAL_FUNC void load(const threadgroup U* src) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &( + src[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + + template + METAL_FUNC void store(threadgroup U* dst) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &( + dst[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + + template + METAL_FUNC void load(const device U* src, const int ld) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + + template + METAL_FUNC void store(device U* dst, const int ld) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + + template + METAL_FUNC void + load_safe(const device U* src, const int ld, const short2 src_tile_dims) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::load_safe( + frag_at(i, j), + src, + ld, + Int<1>{}, + src_tile_dims.y, + src_tile_dims.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } + + template + METAL_FUNC void + store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::store_safe( + frag_at(i, j), + dst, + ld, + Int<1>{}, + dst_tile_dims.y, + dst_tile_dims.x, + (i * 
kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } +}; + +template < + typename Dtype, + typename Atype, + typename Btype, + typename Ctype, + int M, + int N, + int K, + class MMAFragD, + class MMAFragA, + class MMAFragB, + class MMAFragC> +METAL_FUNC void tile_matmad( + thread MMATile& D, + thread MMATile& A, + thread MMATile& B, + thread MMATile& C) { + STEEL_PRAGMA_UNROLL + for (short m = 0; m < M; ++m) { + STEEL_PRAGMA_UNROLL + for (short n = 0; n < N; ++n) { + short m_serp = m; //(n % 2) ? (M - 1 - m) : m; + short n_serp = (m % 2) ? (N - 1 - n) : n; + + STEEL_PRAGMA_UNROLL + for (short k = 0; k < K; ++k) { + MMAFragD::mma( + D.frag_at(m_serp, n_serp), + A.frag_at(m_serp, k), + B.frag_at(k, n_serp), + C.frag_at(m_serp, n_serp)); + } + } + } +} + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + short lda_tgp, + short ldb_tgp, + typename AccumType = float, + typename Epilogue = TransformNone> +struct BlockMMA { + // MMAFrag size + STEEL_CONST short kFragSize = 8; + using MMAFrag_acc_t = BaseMMAFrag; + + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TM_stride = kFragSize * WM; + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TN_stride = kFragSize * WN; + + // Warp tile size along M + STEEL_CONST short TM = BM / TM_stride; + // Warp tile size along N + STEEL_CONST short TN = BN / TN_stride; + + // Threadgroup A strides + STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M + STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K + + // Threadgroup B strides + STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K + STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N + + // Threadgroup strides along K + STEEL_CONST short tile_stride_a = kFragSize * A_str_k; + STEEL_CONST short tile_stride_b = kFragSize * B_str_k; + + // Simdgroup matrices + MMATile Atile; + MMATile Btile; + MMATile Ctile; + + // Offsets within threadgroup + short sm; + short sn; + + short As_offset; + short Bs_offset; - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread auto& accum = results[i * TN + j].thread_elements(); - // int offset = (i * TM_stride) * ldc + (j * TN_stride); - accum[0] *= scale_value; - accum[1] *= scale_value; - } + /* Constructor */ + METAL_FUNC BlockMMA( + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) { + // Determine thread position in simdgroup matrix + short tm = kFragSize * (simd_group_id / WN); + short tn = kFragSize * (simd_group_id % WN); + + short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + sm = simd_coord.y; + sn = simd_coord.x; + + // Determine thread and simdgroup offset + As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K + Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N + + sm += tm; + sn += tn; + } + + /* (BM, BK) X (BK, BN) multiply accumulate function */ + METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { + // Adjust for simdgroup and thread location + As += As_offset; + Bs += Bs_offset; + + // Iterate over BK in blocks of kFragSize + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < BK; kk += kFragSize) { + simdgroup_barrier(mem_flags::mem_none); + + Atile.template load(As); + + simdgroup_barrier(mem_flags::mem_none); + + Btile.template load(Bs); + + simdgroup_barrier(mem_flags::mem_none); + + tile_matmad(Ctile, Atile, Btile, Ctile); + + // Progress to next 
simdgroup tile + As += tile_stride_a; + Bs += tile_stride_b; } } /* Store results from simdgroup_matrix results into device memory */ - METAL_FUNC void store_result(device U* C, const int ldc) const { + METAL_FUNC void store_result(device U* D, const int ldd) { + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } + // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + tn + sn; + D += sm * ldd + sn; - // Loop over all simdgroup tiles + Ctile.template store(D, ldd); + } + + METAL_FUNC void + store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { + // Apply epilogue STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); - int offset = (i * TM_stride) * ldc + (j * TN_stride); + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } - // Apply epilogue - U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])}; + // Adjust for simdgroup and thread location + D += sm * ldd + sn; + dst_tile_dims -= short2(sn, sm); - // Write out C - C[offset] = outs[0]; - C[offset + 1] = outs[1]; - } + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + Ctile.template store_safe(D, ldd, dst_tile_dims); + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) { + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]); } } - METAL_FUNC void store_result_to_tgp_memory( - threadgroup U* C, + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue( + const device U* C, const int ldc, - short2 dst_tile_dims) const { + const int fdc, + thread const BinaryEpilogue& epilogue_op) { // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + (tn + sn); - dst_tile_dims -= short2(tn + sn, sm + tm); + C += (sm)*ldc + (sn)*fdc; + // Loop over all simdgroup tiles STEEL_PRAGMA_UNROLL - for (int i = 0; i < TM; i++) { - if (i * TM_stride < dst_tile_dims.y) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); - int offset = (i * TM_stride) * ldc + (j * TN_stride); - - // Apply epilogue and output C - if (j * TN_stride < dst_tile_dims.x) { - C[offset] = Epilogue::apply(accum[0]); - } + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; - if (j * TN_stride + 1 < dst_tile_dims.x) { - C[offset + 1] = Epilogue::apply(accum[1]); - } + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) { + accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); } } } } - METAL_FUNC void - store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const { + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue_safe( + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const BinaryEpilogue& epilogue_op) { // Adjust for simdgroup and 
thread location - C += (sm + tm) * ldc + (tn + sn); - dst_tile_dims -= short2(tn + sn, sm + tm); + C += (sm)*ldc + (sn)*fdc; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + // Loop over all simdgroup tiles STEEL_PRAGMA_UNROLL - for (int i = 0; i < TM; i++) { - if (i * TM_stride < dst_tile_dims.y) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); - int offset = (i * TM_stride) * ldc + (j * TN_stride); + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; - // Apply epilogue and output C - if (j * TN_stride < dst_tile_dims.x) { - C[offset] = Epilogue::apply(accum[0]); - } + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; + + // Read C + U c_elems[kelems] = {0}; - if (j * TN_stride + 1 < dst_tile_dims.x) { - C[offset + 1] = Epilogue::apply(accum[1]); + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + c_elems[k] = C[offset_c + k * fdc]; } } + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + accum[k] = epilogue_op.apply(accum[k], c_elems[k]); + } } } } @@ -795,8 +1727,10 @@ struct BlockMMAFA { const int fdc, thread const Epilogue& epilogue_op) const { // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + (tn + sn) * fdc; - D += (sm + tm) * ldd + tn + sn; + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; // Loop over all simdgroup tiles STEEL_PRAGMA_UNROLL @@ -804,18 +1738,15 @@ struct BlockMMAFA { STEEL_PRAGMA_UNROLL for (short j = 0; j < TN; j++) { // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); + thread const auto& accum = Ctile.frag_at(i, j); int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; int offset_d = (i * TM_stride) * ldd + (j * TN_stride); // Apply epilogue - U outs[2] = { - epilogue_op.apply(accum[0], C[offset_c]), - epilogue_op.apply(accum[1], C[offset_c + fdc])}; - - // Write out D - D[offset_d] = outs[0]; - D[offset_d + 1] = outs[1]; + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } } } } @@ -829,9 +1760,14 @@ struct BlockMMAFA { short2 dst_tile_dims, thread const Epilogue& epilogue_op) const { // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + (tn + sn) * fdc; - D += (sm + tm) * ldd + tn + sn; - dst_tile_dims -= short2(tn + sn, sm + tm); + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; STEEL_PRAGMA_UNROLL for (int i = 0; i < TM; i++) { @@ -839,556 +1775,551 @@ struct BlockMMAFA { STEEL_PRAGMA_UNROLL for (int j = 0; j < TN; j++) { // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); + thread const auto& accum = Ctile.frag_at(i, j); int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; int offset_d = (i * TM_stride) * ldd + (j * TN_stride); - // Apply epilogue and output C - if (j * TN_stride < dst_tile_dims.x) 
{ - D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]); - } - - if (j * TN_stride + 1 < dst_tile_dims.x) { - D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + D[offset_d + k] = + epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } } } } } } +}; - METAL_FUNC void clear_results() { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - results[i * TN + j] = simdgroup_matrix(0); - } - } +// ============ "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h" + +struct AttnParams { + int B; ///< Batch Size + int H; ///< Heads + int D; ///< Head Dim + + int qL; ///< Query Sequence Length + int kL; ///< Key Sequence Length + + int gqa_factor; ///< Group Query factor + float scale; ///< Attention scale + + int NQ; ///< Number of query blocks + int NK; ///< Number of key/value blocks + + int NQ_aligned; ///< Number of full query blocks + int NK_aligned; ///< Number of full key/value blocks + + int qL_rem; ///< Remainder in last query block + int kL_rem; ///< Remainder in last key/value block + int qL_off; ///< Offset in query sequence start + + int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1) + int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1) + int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1) + int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1) +}; + +struct AttnMaskParams { + int64_t M_strides[3]; ///< Mask strides (B, H, qL, kL = 1) +}; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +constant bool align_Q [[function_constant(200)]]; +constant bool align_K [[function_constant(201)]]; + +constant bool has_mask [[function_constant(300)]]; +constant bool do_causal [[function_constant(301)]]; + +template +struct TransformScale { + T scale; + METAL_FUNC TransformScale(T scale_) : scale(scale_) {} + + METAL_FUNC T apply(T x) const { + return scale * x; } }; +struct MaxOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return metal::max(x, y); + } +}; + +struct SumOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x + y; + } +}; + +struct MulOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x * y; + } +}; + +struct SubOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x - y; + } +}; + +struct ExpSubOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return fast::exp2(x - y); + } +}; + +struct DivOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x / y; + } +}; + +// clang-format off template < typename T, - typename U, - int BM, - int BN, + int BQ, int BK, + int BD, int WM, int WN, - bool transpose_q, - bool transpose_k, - bool transpose_v, - bool MN_aligned, - bool K_aligned, - typename AccumType = typename AccumHelper::accum_type, - typename Epilogue = TransformNone> -struct FastAttentionKernel { - STEEL_CONST short tgp_padding = 16 / sizeof(T); - STEEL_CONST short float_padding = 16 / sizeof(float); - STEEL_CONST short tgp_mem_size_q = - transpose_q ? BK * (BM + tgp_padding) : BM * (BK + tgp_padding); - STEEL_CONST short tgp_mem_size_k = - transpose_k ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding); - STEEL_CONST short tgp_mem_size_v = - transpose_v ? 
BK * (BN + tgp_padding) : BN * (BK + tgp_padding); - STEEL_CONST short tgp_mem_size_s = BM * (BN + tgp_padding); - - // maxes, rowsums, rescale - STEEL_CONST short tgp_mem_size_corrections = - 4 * (BM * sizeof(float) + float_padding); - - STEEL_CONST bool share_kv_smem = transpose_k != transpose_v; - - STEEL_CONST short tgp_mem_size = share_kv_smem - ? tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s + - tgp_mem_size_corrections - : tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s + - tgp_mem_size_corrections + tgp_mem_size_v; - - STEEL_CONST short tgp_size = WM * WN * 32; - - static_assert(transpose_q == false, "Expected Q not transposed."); - static_assert(transpose_k == true, "Expected K transposed."); - static_assert(transpose_v == false, "Expected V not transposed."); - static_assert(tgp_mem_size <= 32768, "Excessive tgp memory requested."); - - using loader_q_t = BlockLoaderFA< - T, - transpose_q ? BK : BM, - transpose_q ? BM : BK, - transpose_q ? BM + tgp_padding : BK + tgp_padding, - !transpose_q, - tgp_size>; - - using loader_k_t = BlockLoaderFA< - T, - transpose_k ? BN : BK, - transpose_k ? BK : BN, - transpose_k ? BK + tgp_padding : BN + tgp_padding, - transpose_k, - tgp_size>; - - using loader_v_t = BlockLoaderFA< - T, - transpose_v ? BK : BN, - transpose_v ? BN : BK, - transpose_v ? BN + tgp_padding : BK + tgp_padding, - transpose_v, - tgp_size>; - - using mma_qk_t = BlockMMAFA< - T, - U, - BM, - BN, - BK, - WM, - WN, - transpose_q, - transpose_k, - transpose_q ? BM + tgp_padding : BK + tgp_padding, - transpose_k ? BK + tgp_padding : BN + tgp_padding, - AccumType, - Epilogue>; - - using mma_sv_t = BlockMMAFA< - T, - U, - BM, - BK, - BN, - WM, - WN, - false, - transpose_v, - BN + tgp_padding, - BK + tgp_padding, - AccumType, - Epilogue>; - - /* Main kernel function */ - template - static METAL_FUNC void gemm_loop( - threadgroup T* As [[threadgroup(0)]], - threadgroup T* Bs [[threadgroup(1)]], - const int gemm_k_iterations, - thread loader_k_t& loader_b, - thread mma_qk_t& mma_op, - thread const short& tgp_bm, - thread const short& tgp_bn, - LoopAlignment l = {}) { - // Appease the compiler - (void)l; - (void)tgp_bm; - - short2 tile_dims_B = transpose_k ? 
short2(BK, tgp_bn) : short2(tgp_bn, BK); - - // not valid for gemm_k_iterations > 1 (so, BK == d_k) - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (N_aligned) { - loader_b.load_unsafe(); - } else { - loader_b.load_safe(tile_dims_B); - } + typename MaskType = float, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention( + const device T* Q [[buffer(0)]], + const device T* K [[buffer(1)]], + const device T* V [[buffer(2)]], + device T* O [[buffer(3)]], + const constant AttnParams* params [[buffer(4)]], + const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]], + const device MaskType* mask [[buffer(6), function_constant(has_mask)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on - threadgroup_barrier(mem_flags::mem_threadgroup); + // Pacifying compiler + (void)lid; - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - } + // Move to correct block + ulong3 tidl{tid.x, tid.y, tid.z}; + + Q += tidl.z * params->Q_strides[0] + // Batch + tidl.y * params->Q_strides[1] + // Head + tidl.x * BQ * params->Q_strides[2]; // Seqeunce + + ulong kv_head_idx = int(tid.y) / params->gqa_factor; + K += tidl.z * params->K_strides[0] + // Batch + kv_head_idx * params->K_strides[1]; // Head + + V += tidl.z * params->V_strides[0] + // Batch + kv_head_idx * params->V_strides[1]; // Head + + O += tidl.z * params->O_strides[0] + // Batch + tidl.y * params->O_strides[1] + // Head + tidl.x * BQ * params->O_strides[2]; // Seqeunce + + if (has_mask) { + mask += tidl.z * mask_params->M_strides[0] + // Batch + tidl.y * mask_params->M_strides[1]; // Head + } + + // Prepare threadgroup memory + constexpr short padQ = 16 / sizeof(T); + constexpr short padK = 16 / sizeof(T); + constexpr short padV = 16 / sizeof(T); + + constexpr short LDQ_tgp = BD + padQ; + constexpr short LDK_tgp = BK + padK; + constexpr short LDV_tgp = BD + padV; + + constexpr short tgp_mem_0 = (BK + padK) * (BD); + constexpr short tgp_mem_1 = BK * (BD + padV); + constexpr short tgp_mem_s = tgp_mem_0 > tgp_mem_1 ? 
tgp_mem_0 : tgp_mem_1; + + threadgroup T Q_smem[BQ * (BD + padQ)]; + threadgroup T KV_smem[tgp_mem_s]; + + threadgroup T* Qs = Q_smem; + threadgroup T* Ks = KV_smem; + threadgroup T* Vs = KV_smem; + + // Prepare block loaders + using QBlockLoader = BlockLoaderT< + /* typename T = */ T, + /* short BROWS = */ BQ, + /* short BCOLS = */ BD, + /* short kDstStrRow = */ LDQ_tgp, + /* short kDstStrCol = */ 1, + /* short reduction_dim = */ 1, + /* short tgp_size = */ WM * WN * 32>; + + // K is loaded in transposed + using KBlockLoader = BlockLoaderT< + /* typename T = */ T, + /* short BROWS = */ BK, + /* short BCOLS = */ BD, + /* short kDstStrRow = */ 1, + /* short kDstStrCol = */ LDK_tgp, + /* short reduction_dim = */ 0, + /* short tgp_size = */ WM * WN * 32>; + + using VBlockLoader = BlockLoaderT< + /* typename T = */ T, + /* short BROWS = */ BK, + /* short BCOLS = */ BD, + /* short kDstStrRow = */ LDV_tgp, + /* short kDstStrCol = */ 1, + /* short reduction_dim = */ 0, + /* short tgp_size = */ WM * WN * 32>; + + QBlockLoader loader_q( + Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id); + KBlockLoader loader_k( + K, params->K_strides[2], Ks, simd_group_id, simd_lane_id); + VBlockLoader loader_v( + V, params->V_strides[2], Vs, simd_group_id, simd_lane_id); + + TransformScale ts(static_cast(params->scale * 1.44269504089)); + + // Prepare MMA tiles + constexpr short kFragSize = 8; // MMAFrag size + using MMAFrag_acc_t = BaseMMAFrag; + + constexpr int kNWarps = WM * WN; + static_assert( + BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0, + "Each simdgroup must host atleast 1 simdgroup matrix along Q sequence."); + + // Q seq frags per warp + constexpr int TQ = BQ / (kNWarps * kFragSize); + // KV sequence frags (all warps load the same frags) + constexpr int TK = BK / kFragSize; + // HeadDim frags (all warps load the same frags) + constexpr int TD = BD / kFragSize; + + static_assert(TQ == 1, "Check TQ"); + + MMATile Qtile; + MMATile Ktile; + MMATile Stile; + MMATile Vtile; + MMATile Otile; + + Otile.clear(); + + // Prepare mma tile offsets + const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + const short sm = simd_coord.y; + const short sn = simd_coord.x; + const short tm = kFragSize * TQ * simd_group_id; + + const short Qs_offset = (tm + sm) * LDQ_tgp + sn; + const short Ks_offset = sm * LDK_tgp + sn; + const short Vs_offset = sm * LDV_tgp + sn; + + constexpr short Qs_tile_stride = kFragSize; + constexpr short Ks_tile_stride = kFragSize * LDK_tgp; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load Q blocks apply scale + if (!align_Q && int(tid.x) == (params->NQ_aligned)) { + loader_q.load_safe(short2(BD, params->qL_rem)); + } else { + loader_q.load_unsafe(); + } + loader_q.apply_inplace_op(ts); + + // Init row reduction variables + constexpr short kRowsPT = decltype(Stile)::kRowsPerThread; + + AccumType max_score[kRowsPT]; + AccumType sum_score[kRowsPT] = {0}; + + // Init to -Inf + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + max_score[i] = Limits::min; + } + + int kb_lim = params->NK; + + if (do_causal) { + int q_max = (tid.x + 1) * BQ + params->qL_off; + kb_lim = (q_max + BK - 1) / BK; } - static METAL_FUNC void initialize_corrections( - threadgroup float* C, - uint simd_lane_id, - uint simd_group_id) { - if (simd_group_id == 0) { - threadgroup float* maxes = C; - threadgroup float* sums = C + (BM + float_padding); - threadgroup float* o_rescale = sums + (BM + float_padding); - threadgroup float* output_rescale = o_rescale + (BM + 
float_padding); - - if (simd_lane_id < BM) { - maxes[simd_lane_id] = -INFINITY; // m_i - sums[simd_lane_id] = 0.f; // l_i - o_rescale[simd_lane_id] = 1.f; // li * exp(mi - mi_new) - output_rescale[simd_lane_id] = 1.f; // 1.0 / l_i + // Loop over KV seq length + for (int kb = 0; kb < kb_lim; kb++) { + // Load K block and apply scale + threadgroup_barrier(mem_flags::mem_threadgroup); + if (!align_K && kb == (params->NK_aligned)) { + loader_k.load_safe(short2(BD, params->kL_rem)); + } else { + loader_k.load_unsafe(); + } + + // Do S = Q @ K.T + Stile.clear(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_UNROLL + for (short dd = 0; dd < TD; dd++) { + simdgroup_barrier(mem_flags::mem_none); + + Qtile.template load( + &Qs[Qs_offset + dd * Qs_tile_stride]); + Ktile.template load( + &Ks[Ks_offset + dd * Ks_tile_stride]); + + simdgroup_barrier(mem_flags::mem_none); + + tile_matmad(Stile, Qtile, Ktile, Stile); + } + + // Mask out length sequence + if (!align_K && kb == (params->NK_aligned)) { + using stile_t = decltype(Stile); + using selem_t = typename stile_t::elem_type; + constexpr auto neg_inf = -metal::numeric_limits::infinity(); + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < stile_t::kTileRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < stile_t::kTileCols; j++) { + short col_pos = sn + (j * stile_t::kFragCols); + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) { + if ((col_pos + jj) >= params->kL_rem) { + Stile.frag_at(i, j)[jj] = neg_inf; + } + } + } } } - } - static METAL_FUNC void rescale_ss( - threadgroup T* Ss, - threadgroup float* Corrections, - uint simd_group_id, - uint simd_lane_id, - short2 local_blocks, - float alpha, - float softcapping) { - if (simd_group_id == 0) { - short row_offset = BM + float_padding; - threadgroup float* maxes = Corrections; - threadgroup float* sums = Corrections + row_offset; - threadgroup float* o_rescale = sums + row_offset; - threadgroup float* output_scales = o_rescale + row_offset; - - if (simd_lane_id < uint(local_blocks.y)) { - float m_i_old = maxes[simd_lane_id]; - float l_i_old = sums[simd_lane_id]; - - float m_i_new = m_i_old; - float l_i_new = l_i_old; - - short offset = simd_lane_id * (BN + tgp_padding); - - float m_ij = -INFINITY; - - for (short j = 0; j < local_blocks.x; j++) { - float val = alpha * float(Ss[offset + j]); - if (softcapping != 1.) 
{ - val = precise::tanh(val); - val = val * softcapping; + // Mask out if causal + if (do_causal && kb >= (kb_lim - (BQ + BK - 1) / BK - int(!align_K))) { + using stile_t = decltype(Stile); + using selem_t = typename stile_t::elem_type; + constexpr auto neg_inf = -metal::numeric_limits::infinity(); + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < stile_t::kTileRows; i++) { + const int row_pos = + tid.x * BQ + params->qL_off + tm + sm + (i * stile_t::kFragRows); + STEEL_PRAGMA_UNROLL + for (short j = 0; j < stile_t::kTileCols; j++) { + const int col_pos = kb * BK + sn + (j * stile_t::kFragCols); + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) { + if (row_pos < (col_pos + jj)) { + Stile.frag_at(i, j)[jj] = neg_inf; + } } - m_ij = max(m_ij, val); } + } + } + + // Other masking as needed + if (has_mask) { + using stile_t = decltype(Stile); + using selem_t = typename stile_t::elem_type; + constexpr auto neg_inf = -metal::numeric_limits::infinity(); - m_i_new = max(m_ij, m_i_new); + constexpr bool is_bool = is_same_v; + using melem_t = typename metal::conditional_t; - float rowsum = 0.f; // lij + using MMAFrag_mask_t = BaseMMAFrag; + using frag_t = typename MMAFrag_mask_t::frag_type; - for (short j = 0; j < local_blocks.x; j++) { - float val = alpha * float(Ss[offset + j]); - if (softcapping != 1.) { - val = precise::tanh(val); - val = val * softcapping; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < stile_t::kTileRows; i++) { + const int row_pos = tid.x * BQ + tm + sm + (i * stile_t::kFragRows); + STEEL_PRAGMA_UNROLL + for (short j = 0; j < stile_t::kTileCols; j++) { + const int col_pos = kb * BK + sn + (j * stile_t::kFragCols); + + frag_t mfrag; + + MMAFrag_mask_t::load_safe( + mfrag, + mask, + int(mask_params->M_strides[2]), + Int<1>{}, + params->qL, + params->kL, + row_pos, + col_pos); + + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < stile_t::MMAFrag_t::kElemsPerFrag; jj++) { + if constexpr (is_bool) { + Stile.frag_at(i, j)[jj] = + mfrag[jj] ? 
Stile.frag_at(i, j)[jj] : neg_inf; + } else { + Stile.frag_at(i, j)[jj] += 1.44269504089 * selem_t(mfrag[jj]); + } } - float P_i_j = exp(val - m_ij); - rowsum += P_i_j; - P_i_j = P_i_j * exp(m_ij - m_i_new); - Ss[offset + j] = T(P_i_j); } - - l_i_new = - exp(m_i_old - m_i_new) * l_i_old + exp(m_ij - m_i_new) * rowsum; - maxes[simd_lane_id] = m_i_new; - sums[simd_lane_id] = l_i_new; - float rescale = l_i_old * exp(m_i_old - m_i_new); - o_rescale[simd_lane_id] = rescale; - output_scales[simd_lane_id] = 1.0 / l_i_new; } } - } - /* Main kernel function */ - static METAL_FUNC void run( - const device T* Q [[buffer(0)]], - const device T* K [[buffer(1)]], - const device T* V [[buffer(2)]], - device U* O [[buffer(3)]], - const constant MLXFastAttentionParams* params [[buffer(4)]], - threadgroup T* Qs [[threadgroup(0)]], - threadgroup T* Ks [[threadgroup(1)]], - threadgroup T* Ss [[threadgroup(2)]], - threadgroup T* Vs [[threadgroup(3)]], - threadgroup float* Corrections [[threadgroup(4)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - // Pacifying compiler - (void)lid; - - const int tid_y = ((tid.y) << params->swizzle_log) + - ((tid.x) & ((1 << params->swizzle_log) - 1)); - const int tid_x = (tid.x) >> params->swizzle_log; - - if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { - return; + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load V blocks + if (!align_K && kb == (params->NK_aligned)) { + loader_v.load_safe(short2(BD, params->kL_rem)); + } else { + loader_v.load_unsafe(); } - threadgroup_barrier(mem_flags::mem_none); - - // Find block in Q, O; and head in K, V. - const int c_row = tid_y * BM; - - Q += transpose_q ? c_row : c_row * params->ldq; - thread loader_q_t loader_q(Q, params->ldq, Qs, simd_group_id, simd_lane_id); - - short tgp_bm = min(BM, params->M - c_row); - short2 tile_dims_Q = transpose_q ? 
short2(tgp_bm, BK) : short2(BK, tgp_bm); - - loader_q.load_safe(tile_dims_Q); - - initialize_corrections(Corrections, simd_lane_id, simd_group_id); - - O += c_row * params->ldo; - - // Prepare threadgroup mma operation - thread mma_qk_t mma_qk_op(simd_group_id, simd_lane_id); - thread mma_sv_t mma_softmax_sv_op(simd_group_id, simd_lane_id); - thread loader_k_t loader_k(K, params->ldk, Ks, simd_group_id, simd_lane_id); - thread loader_v_t loader_v(V, params->ldv, Vs, simd_group_id, simd_lane_id); - - for (short n_block = 0; n_block < params->gemm_n_iterations_aligned; - n_block++) { - short c_col = BN; - - // Prepare threadgroup loading operations - short gemm_k_iterations = params->gemm_k_iterations_aligned; - short tgp_bn_qk = min(BN, params->N - c_col * n_block); - threadgroup_barrier(mem_flags::mem_none); - - /////////////////////////////////////////////////////////////////////////////// - { // Loop over K - unaligned case - - if (tgp_bm == BM && tgp_bn_qk == BN) { - gemm_loop( - Qs, - Ks, - gemm_k_iterations, - loader_k, - mma_qk_op, - tgp_bm, - tgp_bn_qk); - } else if (tgp_bn_qk == BN) { - gemm_loop( - Qs, - Ks, - gemm_k_iterations, - loader_k, - mma_qk_op, - tgp_bm, - tgp_bn_qk); - - } else if (tgp_bm == BM) { - gemm_loop( - Qs, - Ks, - gemm_k_iterations, - loader_k, - mma_qk_op, - tgp_bm, - tgp_bn_qk); + // Do softmax - } else { - gemm_loop( - Qs, - Ks, - gemm_k_iterations, - loader_k, - mma_qk_op, - tgp_bm, - tgp_bn_qk); - } - } + // Temp variables + AccumType new_max[kRowsPT]; + AccumType factor[kRowsPT]; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + new_max[i] = max_score[i]; + } - mma_qk_op.store_result_to_tgp_memory( - Ss, BN + tgp_padding, short2(BN, BM)); + // Row max + Stile.template row_reduce(new_max); - threadgroup_barrier(mem_flags::mem_threadgroup); + // exp(Si - rowmax(Si)) + Stile.template row_bin_op(new_max); - rescale_ss( - Ss, - Corrections, - simd_group_id, - simd_lane_id, - short2(tgp_bn_qk, tgp_bm), - params->alpha, - params->softcapping); + // Factor exp(rowmax(Si) - rowmax(Si-1)) + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + factor[i] = fast::exp2(max_score[i] - new_max[i]); + } + + // Save max for next iteration + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + max_score[i] = new_max[i]; + } + + // Row Sum + AccumType sum_score_tmp[kRowsPT] = {0}; + Stile.template row_reduce(sum_score_tmp); - loader_v.load_safe(short2(BK, tgp_bn_qk)); + // Update norm + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i]; + } - threadgroup_barrier(mem_flags::mem_threadgroup); + // Update O + Otile.template row_bin_op(factor); - threadgroup float* o_scales = Corrections + 2 * (BM + float_padding); - mma_softmax_sv_op.rescale_output(o_scales); + // Load V into registers + threadgroup_barrier(mem_flags::mem_threadgroup); - mma_softmax_sv_op.mma(Ss, Vs); + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short id = 0; id < TD; id++) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + if constexpr (BD == 128) { + simdgroup_barrier(mem_flags::mem_none); + } - threadgroup float* final_output_scales = - Corrections + 3 * (BM + float_padding); + const short kk = ik * kFragSize; + const short dd = id * kFragSize; - mma_softmax_sv_op.rescale_output(final_output_scales); + Vtile.template load( + &Vs[Vs_offset + kk * LDV_tgp + dd]); - loader_v.next(); - loader_k.next(BN); + if constexpr (BD == 128) { + 
simdgroup_barrier(mem_flags::mem_none); + } - mma_qk_op.clear_results(); + MMAFrag_acc_t::mma( + Otile.frag_at(iq, id), + Stile.frag_at(iq, ik), + Vtile.frag_at(0, 0), + Otile.frag_at(iq, id)); + } + } } - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_softmax_sv_op.store_result_safe(O, params->ldo, short2(BK, tgp_bm)); + // Prepare for next iteration + loader_k.next(); + loader_v.next(); } -}; -template < - typename T, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_q, - bool transpose_k, - bool transpose_v, - bool MN_aligned, - bool K_aligned> -[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void attention( - const device T* Q [[buffer(0)]], - const device T* K [[buffer(1)]], - const device T* V [[buffer(2)]], - device T* O [[buffer(3)]], - const constant MLXFastAttentionParams* params [[buffer(4)]], - const constant int* batch_shape [[buffer(6)]], - const constant size_t* batch_strides [[buffer(7)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - using attention_kernel = FastAttentionKernel< - T, - T, - BM, - BN, - BK, - WM, - WN, - transpose_q, - transpose_k, - transpose_v, - MN_aligned, - K_aligned>; - - // Adjust for batch - if (params->batch_ndim > 1) { - const constant size_t* Q_bstrides = batch_strides; - const constant size_t* KV_bstrides = batch_strides + params->batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, Q_bstrides, KV_bstrides, params->batch_ndim); - - Q += batch_offsets.x; - K += batch_offsets.y; - V += batch_offsets.y; + // Normalize output + Otile.template row_bin_op(sum_score); + threadgroup_barrier(mem_flags::mem_none); + // Store results + O += (tm + sm) * params->O_strides[2] + sn; + + if (!align_Q && int(tid.x) == (params->NQ_aligned)) { + auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm)); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + Otile.template store_safe(O, params->O_strides[2], dst_tile_dims); } else { - Q += params->batch_stride_q * tid.z; - K += params->batch_stride_k * tid.z; - V += params->batch_stride_v * tid.z; - } - - // same shape as input - O += params->batch_stride_o * tid.z; - threadgroup T Qs[attention_kernel::tgp_mem_size_q]; - threadgroup T Ss[attention_kernel::tgp_mem_size_s]; - threadgroup float Corrections[attention_kernel::tgp_mem_size_corrections]; - - if (attention_kernel::share_kv_smem) { - threadgroup T Ks[attention_kernel::tgp_mem_size_k]; - threadgroup T* Vs = Ks; //[attention_kernel::tgp_mem_size_v]; - attention_kernel::run( - Q, - K, - V, - O, - params, - Qs, - Ks, - Ss, - Vs, - Corrections, - simd_lane_id, - simd_group_id, - tid, - lid); - } else { - threadgroup T Ks[attention_kernel::tgp_mem_size_k]; - threadgroup T Vs[attention_kernel::tgp_mem_size_v]; - attention_kernel::run( - Q, - K, - V, - O, - params, - Qs, - Ks, - Ss, - Vs, - Corrections, - simd_lane_id, - simd_group_id, - tid, - lid); + Otile.template store(O, params->O_strides[2]); } } // clang-format off // SDPA full instantiations -#define instantiate_fast_inference_self_attention_kernel( \ - itype, otype, bm, bn, bk, wm, wn) \ - template [[host_name("steel_gemm_attention_bm_" #bm "_bn_" #bn "_bk_" #bk \ - "_itype_" #itype)]] [[kernel]] void \ - attention( \ - const device itype* Q [[buffer(0)]], \ - const device itype* K [[buffer(1)]], \ - const device itype* V [[buffer(2)]], \ - device otype* O 
[[buffer(3)]], \ - const constant MLXFastAttentionParams* params [[buffer(4)]], \ - const constant int* batch_shape [[buffer(5)]], \ - const constant size_t* batch_strides [[buffer(6)]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]]); - -instantiate_fast_inference_self_attention_kernel( - float, - float, - 16, - 16, - 32, - 2, - 2); -instantiate_fast_inference_self_attention_kernel( - float, - float, - 16, - 16, - 64, - 2, - 2); -instantiate_fast_inference_self_attention_kernel( - float, - float, - 16, - 16, - 96, - 2, - 2); -instantiate_fast_inference_self_attention_kernel( - float, - float, - 16, - 16, - 128, - 2, - 2); -instantiate_fast_inference_self_attention_kernel( - float, - float, - 16, - 16, - 256, - 2, - 2); -instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 32, 2, 2); -instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2); -instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 96, 2, 2); -instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2); -instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2); + +// Instantiate a templated kernel. +// Extra args are used as template parameters: +// e.g. instantiate_kernel(binary_int, binary, a, b) -> +// [[host_name(binary_int)]] [kernel] binary +#define instantiate_kernel(name, func, ...) \ + template [[host_name( \ + name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>; + +#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \ + instantiate_kernel( \ + "steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd \ + "_wm" #wm "_wn" #wn "_mask" #mname, \ + attention, dtype, bq, bk, bd, wm, wn, mtype, float) + +#define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \ + instantiate_attn(iname, itype, 32, 16, 256, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 32, 32, 96, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 32, 32, 80, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 32, 32, 72, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 32, 32, 64, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 32, 32, 32, 4, 1, mname, mtype) + +#define instantiate_attn_mask_helper(iname, itype) \ + instantiate_attn_shapes_helper(iname, itype, iname, itype) \ + instantiate_attn_shapes_helper(iname, itype, bool_, bool) + +instantiate_attn_mask_helper(float16, half); +instantiate_attn_mask_helper(bfloat16, bfloat16_t); +instantiate_attn_mask_helper(float32, float); // SDPA vector instantiations #define instantiate_sdpa_vector(type, head_dim) \ @@ -1443,13 +2374,13 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2); #define instantiate_sdpa_vector_heads(type) \ instantiate_sdpa_vector(type, 32) \ instantiate_sdpa_vector(type, 64) \ + instantiate_sdpa_vector(type, 72) \ + instantiate_sdpa_vector(type, 80) \ instantiate_sdpa_vector(type, 96) \ instantiate_sdpa_vector(type, 128) \ instantiate_sdpa_vector(type, 256) instantiate_sdpa_vector_heads(float) -#if defined(__HAVE_BFLOAT__) instantiate_sdpa_vector_heads(bfloat16_t) -#endif instantiate_sdpa_vector_heads(float16_t) // clang-format on diff --git a/candle-nn/benches/benchmarks/mod.rs b/candle-nn/benches/benchmarks/mod.rs index a34d888439..bfe04456df 100644 --- 
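The new steel_attention kernel keeps, per row of Q, a running max, a running normalizer and an output accumulator across KV blocks, with scores pre-multiplied by scale * log2(e) so exp2 can be used throughout. A scalar Rust sketch of that recurrence, under those assumptions and with illustrative types only (not the kernel's actual tiled data layout):

// Streaming-softmax reference for one query row; `score_blocks[b]` holds the
// scaled scores (already times scale * log2(e)) against KV block b, and
// `value_blocks[b]` holds that block's values row-major as (kv_len, head_dim).
fn online_softmax_row(score_blocks: &[Vec<f32>], value_blocks: &[Vec<f32>], head_dim: usize) -> Vec<f32> {
    let mut m = f32::NEG_INFINITY; // running row max
    let mut l = 0.0f32;            // running normalizer
    let mut o = vec![0.0f32; head_dim];
    for (s, v) in score_blocks.iter().zip(value_blocks) {
        let m_new = s.iter().copied().fold(m, f32::max);
        let factor = (m - m_new).exp2();           // rescale previous partial results
        let p: Vec<f32> = s.iter().map(|x| (x - m_new).exp2()).collect();
        l = l * factor + p.iter().sum::<f32>();
        o.iter_mut().for_each(|oj| *oj *= factor);
        for (i, pi) in p.iter().enumerate() {      // o += p @ v_block
            for j in 0..head_dim {
                o[j] += pi * v[i * head_dim + j];
            }
        }
        m = m_new;
    }
    o.iter().map(|x| x / l).collect()              // final 1/l normalization
}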
a/candle-nn/benches/benchmarks/mod.rs +++ b/candle-nn/benches/benchmarks/mod.rs @@ -2,6 +2,8 @@ pub(crate) mod conv; pub(crate) mod layer_norm; pub(crate) mod softmax; +#[cfg(feature = "cuda")] +use candle::backend::BackendDevice; use candle::{Device, Result}; pub(crate) trait BenchDevice { @@ -18,13 +20,13 @@ impl BenchDevice for Device { #[cfg(feature = "cuda")] return Ok(device.synchronize()?); #[cfg(not(feature = "cuda"))] - panic!("Cuda device without cuda feature enabled: {:?}", device) + panic!("Cuda device without cuda feature enabled: {device:?}") } Device::Metal(device) => { #[cfg(feature = "metal")] return Ok(device.wait_until_completed()?); #[cfg(not(feature = "metal"))] - panic!("Metal device without metal feature enabled: {:?}", device) + panic!("Metal device without metal feature enabled: {device:?}") } } } diff --git a/candle-nn/benches/benchmarks/softmax.rs b/candle-nn/benches/benchmarks/softmax.rs index 2a1ea2d547..3c0c43bc44 100644 --- a/candle-nn/benches/benchmarks/softmax.rs +++ b/candle-nn/benches/benchmarks/softmax.rs @@ -6,7 +6,7 @@ use criterion::{black_box, criterion_group, Criterion}; use std::time::Instant; fn run(input: &Tensor) { - let _ = softmax_last_dim(&input).unwrap(); + let _ = softmax_last_dim(input).unwrap(); } const B: usize = 1; @@ -16,7 +16,7 @@ const K: usize = 1024; fn run_softmax_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) { let elements = B * M * K; - let input = Tensor::rand(-1000.0f32, 1000.0f32, (B, M, K), &device) + let input = Tensor::rand(-1000.0f32, 1000.0f32, (B, M, K), device) .unwrap() .to_dtype(dtype) .unwrap(); diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 79affdae40..3f2dd27217 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -974,6 +974,8 @@ impl Module for Identity { struct Sdpa { scale: f32, softcapping: f32, + mask: Option, + do_causal: bool, } impl candle::CustomOp3 for Sdpa { @@ -1010,6 +1012,8 @@ impl candle::CustomOp3 for Sdpa { let out_dims = vec![q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, v_l.dim(3)?]; let elem_count: usize = out_dims.iter().product(); + let out_shape = Shape::from_dims(&out_dims); + let out_layout = Layout::contiguous(out_shape.clone()); let output = device.new_buffer(elem_count, q.dtype(), "sdpa_o")?; @@ -1031,16 +1035,20 @@ impl candle::CustomOp3 for Sdpa { let k_head = k_l.dim(D::Minus1)?; let q_head = q_l.dim(D::Minus1)?; let q_seq = q_l.dim(2)?; + let k_seq = k_l.dim(2)?; let mut implementation_supports_use_case = q_head == k_head; - let supported_head_dim = - q_head == 32 || q_head == 64 || q_head == 96 || q_head == 128 || q_head == 256; - - const SDPA_FULL_THRESHOLD: usize = 2; - - let supports_sdpa_full = - q_seq >= SDPA_FULL_THRESHOLD && supported_head_dim && q_head == k_head; - let supports_sdpa_vector = q_seq == 1 && supported_head_dim; + let supported_head_dim = q_head == 32 + || q_head == 64 + || q_head == 72 + || q_head == 80 + || q_head == 96 + || q_head == 128 + || q_head == 256; + + let supports_sdpa_full_mask = !self.mask.is_some() || q_seq <= k_seq; + let supports_sdpa_full = q_seq > 8 && supported_head_dim && supports_sdpa_full_mask; + let supports_sdpa_vector = q_seq <= 8 && supported_head_dim && q_seq <= k_seq; implementation_supports_use_case &= supports_sdpa_full || supports_sdpa_vector; @@ -1079,7 +1087,7 @@ impl candle::CustomOp3 for Sdpa { // Route to the 2 pass fused attention if the k seqlen is large. // https://github.com/ml-explore/mlx/pull/1597 const TWO_PASS_K_THRESHOLD: usize = 1024; - if k_l.dim(2)? 
>= TWO_PASS_K_THRESHOLD { + if k_seq >= TWO_PASS_K_THRESHOLD { let mut intermediate_shape = [ &out_dims[0..out_dims.len() - 2], &[candle_metal_kernels::SDPA_2PASS_BLOCKS], @@ -1151,27 +1159,70 @@ impl candle::CustomOp3 for Sdpa { .map_err(candle::Error::wrap)?; } } else if supports_sdpa_full { - if q_l.dim(2)? != k_l.dim(2)? { - candle::bail!( - "query and key sequence length must be equal if using full metal sdpa" - ) + command_buffer.set_label("full_attention"); + if self.softcapping != 1. { + candle::bail!("SDPA full requires softcapping to be disabled (1.0)"); } - command_buffer.set_label("full_attention"); + let mask_s_l = self.mask.as_ref().map(|m| m.storage_and_layout()); + + let (mask_type, mask_buffer, mask_strides) = if let Some(mask) = &self.mask { + let (mask_s, mask_l) = mask_s_l.as_ref().unwrap(); + + let mask_buffer = match &**mask_s { + candle::Storage::Metal(m) => m.buffer(), + _ => candle::bail!("Expected metal device for mask"), + }; + + let mask_type = match mask.dtype() { + DType::BF16 => SdpaDType::BF16, + DType::F16 => SdpaDType::F16, + DType::F32 => SdpaDType::F32, + other => candle::bail!("unsupported sdpa type {other:?}"), + }; + if mask_type != itype { + candle::bail!("Mask type {mask_type:?} must match q type {itype:?}"); + } + + if mask_l.dims() != [q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, k_seq] { + candle::bail!( + "Mask shape must be {:?} (bs, qheads, qseq, kseq), got {:?}", + [q_l.dim(0)?, q_head, q_l.dim(2)?, k_seq], + mask_l.dims() + ); + } + + ( + Some(mask_type), + Some(mask_buffer), + Some(mask_l.stride().to_vec()), + ) + } else { + (None, None, None) + }; + candle_metal_kernels::call_sdpa_full( q.device().device(), &command_buffer, q.device().kernels(), q_l.start_offset(), q_l.dims(), + q_l.stride(), q.buffer(), k_l.start_offset(), + k_l.dims(), + k_l.stride(), k.buffer(), v_l.start_offset(), v.buffer(), + v_l.stride(), + mask_type, + mask_buffer, + mask_strides.as_deref(), &output, + out_layout.stride(), self.scale, - self.softcapping, + self.do_causal, itype, ) .map_err(candle::Error::wrap)?; @@ -1180,7 +1231,7 @@ impl candle::CustomOp3 for Sdpa { } let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, q.dtype()); - Ok((newstorage, Shape::from_dims(&out_dims))) + Ok((newstorage, out_shape)) } } @@ -1192,13 +1243,15 @@ impl candle::CustomOp3 for Sdpa { /// - `q`: (bs, qhead, seq, hidden) /// - `k`: (bs, kv_head, kv_seq, hidden) /// - `k`: (bs, kv_head, kv_seq, v_hidden) +/// - `mask`: (bs, qhead, seq, kv_seq) +/// - `do_causal`: Apply causal masking. If this is true, the mask does not need to be provided. /// - `scale` is applied before softmax. /// - If `softcapping` != 1.0: /// - Computation is: softmax(tanh(qk^T*scale/cap)*cap)v /// /// **Output shape:** (bs, qhead, seq, v_hidden) /// -/// **Supported head dims:** 32, 64, 96, 128, 256. +/// Note: For Grouped Query Attention and Multi-Query Attention, the k and v inputs should not be pre-tiled to match q. /// /// ## On Metal: /// - If `seq` == 1: @@ -1206,9 +1259,27 @@ impl candle::CustomOp3 for Sdpa { /// - Supports `seq` != `kv_seq` (cross attn. support) /// - Supports GQA when `qhead` is a multiple of `kv_head` /// - Otherwise: -/// - Use an alternate kernel -/// - Requires `seq` == `kv_seq` -/// - GQA is not supported (requires `qhead` == `kv_head`) -pub fn sdpa(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32, softcapping: f32) -> Result { - q.apply_op3_no_bwd(k, v, &Sdpa { scale, softcapping }) +/// - Masking is supported +/// - Supports `seq` != `kv_seq` (cross attn. 
support) +/// - Supports GQA when `qhead` is a multiple of `kv_head` +/// - Softcapping is not supported. +pub fn sdpa( + q: &Tensor, + k: &Tensor, + v: &Tensor, + mask: Option<&Tensor>, + do_causal: bool, + scale: f32, + softcapping: f32, +) -> Result { + q.apply_op3_no_bwd( + k, + v, + &Sdpa { + scale, + softcapping, + mask: mask.cloned(), + do_causal, + }, + ) } diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index cce6050806..eb07ee3962 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -1,5 +1,3 @@ -//! A `VarBuilder` for variable retrieval from models -//! //! A `VarBuilder` is used to retrieve variables used by a model. These variables can either come //! from a pre-trained checkpoint, e.g. using `VarBuilder::from_mmaped_safetensors`, or initialized //! for training, e.g. using `VarBuilder::from_varmap`. @@ -36,7 +34,8 @@ impl Clone for VarBuilderArgs<'_, B> { pub type VarBuilder<'a> = VarBuilderArgs<'a, Box>; struct TensorData { - backend: B, + backend: Arc, + pub dtype: DType, pub device: Device, } @@ -59,6 +58,9 @@ pub trait Backend: Send + Sync { dev: &Device, ) -> Result; + /// Retrieve a tensor based on the name. + fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result; + fn contains_tensor(&self, name: &str) -> bool; } @@ -73,6 +75,9 @@ pub trait SimpleBackend: Send + Sync { dev: &Device, ) -> Result; + /// Retrieve a tensor based on the name. + fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result; + fn contains_tensor(&self, name: &str) -> bool; } @@ -89,6 +94,10 @@ impl Backend for Box { self.as_ref().get(s, name, h, dtype, dev) } + fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result { + self.as_ref().get_unchecked(name, dtype, dev) + } + fn contains_tensor(&self, name: &str) -> bool { self.as_ref().contains_tensor(name) } @@ -97,7 +106,8 @@ impl Backend for Box { impl VarBuilderArgs<'_, B> { pub fn new_with_args(backend: B, dtype: DType, dev: &Device) -> Self { let data = TensorData { - backend, + backend: Arc::new(backend), + dtype, device: dev.clone(), }; Self { @@ -194,7 +204,7 @@ impl VarBuilderArgs<'_, B> { name: &str, hints: B::Hints, ) -> Result { - self.get_with_hints_dtype(s, name, hints, self.dtype) + self.get_with_hints_dtype(s, name, hints, self.data.dtype) } /// Retrieve the tensor associated with the given name at the current path. @@ -202,6 +212,19 @@ impl VarBuilderArgs<'_, B> { self.get_with_hints(s, name, Default::default()) } + /// Retrieve the tensor associated with the given name at the current path. + pub fn get_unchecked(&self, name: &str) -> Result { + self.get_unchecked_dtype(name, self.data.dtype) + } + + /// Retrieve the tensor associated with the given name & dtype at the current path. + pub fn get_unchecked_dtype(&self, name: &str, dtype: DType) -> Result { + let name = self.path(name); + self.data + .backend + .get_unchecked(&name, dtype, &self.data.device) + } + /// Retrieve the tensor associated with the given name & dtype at the current path. pub fn get_with_hints_dtype>( &self, @@ -215,6 +238,31 @@ impl VarBuilderArgs<'_, B> { .backend .get(s.into(), &path, hints, dtype, &self.data.device) } + + /// Set the device of the VarBuilder. + pub fn set_device(self, device: Device) -> Self { + Self { + data: Arc::new(TensorData { + backend: self.data.backend.clone(), + dtype: self.data.dtype, + device, + }), + ..self + } + } + + /// Set the dtype of the VarBuilder. 
+ pub fn set_dtype(self, dtype: DType) -> Self { + Self { + data: Arc::new(TensorData { + backend: self.data.backend.clone(), + dtype, + device: self.data.device.clone(), + }), + dtype, + ..self + } + } } struct Zeros; @@ -224,6 +272,12 @@ impl SimpleBackend for Zeros { Tensor::zeros(s, dtype, dev) } + fn get_unchecked(&self, _name: &str, _dtype: DType, _dev: &Device) -> Result { + candle::bail!( + "`Zeros` requires a shape for tensor retrieval, use `get` instead of `get_unchecked`" + ) + } + fn contains_tensor(&self, _name: &str) -> bool { true } @@ -238,6 +292,19 @@ impl SimpleBackend for HashMap { dtype: DType, dev: &Device, ) -> Result { + let tensor = self.get_unchecked(name, dtype, dev)?; + if tensor.shape() != &s { + Err(candle::Error::UnexpectedShape { + msg: format!("shape mismatch for {name}"), + expected: s, + got: tensor.shape().clone(), + } + .bt())? + } + Ok(tensor) + } + + fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result { let tensor = self .get(name) .ok_or_else(|| { @@ -247,14 +314,6 @@ impl SimpleBackend for HashMap { .bt() })? .clone(); - if tensor.shape() != &s { - Err(candle::Error::UnexpectedShape { - msg: format!("shape mismatch for {name}"), - expected: s, - got: tensor.shape().clone(), - } - .bt())? - } tensor.to_device(dev)?.to_dtype(dtype) } @@ -275,6 +334,10 @@ impl SimpleBackend for VarMap { VarMap::get(self, s, name, h, dtype, dev) } + fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result { + VarMap::get_unchecked(self, name, dtype, dev) + } + fn contains_tensor(&self, name: &str) -> bool { self.data().lock().unwrap().contains_key(name) } @@ -290,11 +353,24 @@ impl SimpleBackend for SafeTensorWithRouting<'_> { fn get( &self, s: Shape, - path: &str, + name: &str, _: crate::Init, dtype: DType, dev: &Device, ) -> Result { + let tensor = self.get_unchecked(name, dtype, dev)?; + if tensor.shape() != &s { + Err(candle::Error::UnexpectedShape { + msg: format!("shape mismatch for {name}"), + expected: s, + got: tensor.shape().clone(), + } + .bt())? + } + Ok(tensor) + } + + fn get_unchecked(&self, path: &str, dtype: DType, dev: &Device) -> Result { let index = self.routing.get(path).ok_or_else(|| { Error::CannotFindTensor { path: path.to_string(), @@ -305,14 +381,6 @@ impl SimpleBackend for SafeTensorWithRouting<'_> { .tensor(path)? .load(dev)? .to_dtype(dtype)?; - if tensor.shape() != &s { - Err(candle::Error::UnexpectedShape { - msg: format!("shape mismatch for {path}"), - expected: s, - got: tensor.shape().clone(), - } - .bt())? - } Ok(tensor) } @@ -325,22 +393,15 @@ impl SimpleBackend for candle::npy::NpzTensors { fn get( &self, s: Shape, - path: &str, + name: &str, _: crate::Init, dtype: DType, dev: &Device, ) -> Result { - let tensor = match self.get(path)? { - None => Err(Error::CannotFindTensor { - path: path.to_string(), - } - .bt())?, - Some(tensor) => tensor, - }; - let tensor = tensor.to_device(dev)?.to_dtype(dtype)?; + let tensor = self.get_unchecked(name, dtype, dev)?; if tensor.shape() != &s { Err(candle::Error::UnexpectedShape { - msg: format!("shape mismatch for {path}"), + msg: format!("shape mismatch for {name}"), expected: s, got: tensor.shape().clone(), } @@ -349,6 +410,18 @@ impl SimpleBackend for candle::npy::NpzTensors { Ok(tensor) } + fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result { + let tensor = match self.get(name)? 
{ + None => Err(Error::CannotFindTensor { + path: name.to_string(), + } + .bt())?, + Some(tensor) => tensor, + }; + let tensor = tensor.to_device(dev)?.to_dtype(dtype)?; + Ok(tensor) + } + fn contains_tensor(&self, name: &str) -> bool { self.get(name).is_ok_and(|v| v.is_some()) } @@ -358,22 +431,15 @@ impl SimpleBackend for candle::pickle::PthTensors { fn get( &self, s: Shape, - path: &str, + name: &str, _: crate::Init, dtype: DType, dev: &Device, ) -> Result { - let tensor = match self.get(path)? { - None => Err(Error::CannotFindTensor { - path: path.to_string(), - } - .bt())?, - Some(tensor) => tensor, - }; - let tensor = tensor.to_device(dev)?.to_dtype(dtype)?; + let tensor = self.get_unchecked(name, dtype, dev)?; if tensor.shape() != &s { Err(candle::Error::UnexpectedShape { - msg: format!("shape mismatch for {path}"), + msg: format!("shape mismatch for {name}"), expected: s, got: tensor.shape().clone(), } @@ -382,6 +448,18 @@ impl SimpleBackend for candle::pickle::PthTensors { Ok(tensor) } + fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result { + let tensor = match self.get(name)? { + None => Err(Error::CannotFindTensor { + path: name.to_string(), + } + .bt())?, + Some(tensor) => tensor, + }; + let tensor = tensor.to_device(dev)?.to_dtype(dtype)?; + Ok(tensor) + } + fn contains_tensor(&self, name: &str) -> bool { self.get(name).is_ok_and(|v| v.is_some()) } @@ -396,7 +474,7 @@ impl SimpleBackend for candle::safetensors::MmapedSafetensors { dtype: DType, dev: &Device, ) -> Result { - let tensor = self.load(name, dev)?.to_dtype(dtype)?; + let tensor = self.get_unchecked(name, dtype, dev)?; if tensor.shape() != &s { Err(candle::Error::UnexpectedShape { msg: format!("shape mismatch for {name}"), @@ -408,6 +486,10 @@ impl SimpleBackend for candle::safetensors::MmapedSafetensors { Ok(tensor) } + fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result { + self.load(name, dev)?.to_dtype(dtype) + } + fn contains_tensor(&self, name: &str) -> bool { self.get(name).is_ok() } @@ -422,7 +504,7 @@ impl SimpleBackend for candle::safetensors::BufferedSafetensors { dtype: DType, dev: &Device, ) -> Result { - let tensor = self.load(name, dev)?.to_dtype(dtype)?; + let tensor = self.get_unchecked(name, dtype, dev)?; if tensor.shape() != &s { Err(candle::Error::UnexpectedShape { msg: format!("shape mismatch for {name}"), @@ -434,6 +516,10 @@ impl SimpleBackend for candle::safetensors::BufferedSafetensors { Ok(tensor) } + fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result { + self.load(name, dev)?.to_dtype(dtype) + } + fn contains_tensor(&self, name: &str) -> bool { self.get(name).is_ok() } @@ -448,7 +534,7 @@ impl SimpleBackend for candle::safetensors::SliceSafetensors<'_> { dtype: DType, dev: &Device, ) -> Result { - let tensor = self.load(name, dev)?.to_dtype(dtype)?; + let tensor = self.get_unchecked(name, dtype, dev)?; if tensor.shape() != &s { Err(candle::Error::UnexpectedShape { msg: format!("shape mismatch for {name}"), @@ -460,6 +546,10 @@ impl SimpleBackend for candle::safetensors::SliceSafetensors<'_> { Ok(tensor) } + fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result { + self.load(name, dev)?.to_dtype(dtype) + } + fn contains_tensor(&self, name: &str) -> bool { self.get(name).is_ok() } @@ -476,7 +566,11 @@ impl<'a> VarBuilder<'a> { dtype: DType, device: Device, ) -> Self { - let data = TensorData { backend, device }; + let data = TensorData { + backend: Arc::new(backend), + dtype, + device, + }; Self { data: 
Arc::new(data), path: vec![], @@ -544,17 +638,7 @@ impl<'a> VarBuilder<'a> { let pth = candle::pickle::PthTensors::new(p, None)?; Ok(Self::from_backend(Box::new(pth), dtype, dev.clone())) } - /// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file. - /// similar to [`from_pth`] but requires a `state_key`. - pub fn from_pth_with_state>( - p: P, - dtype: DType, - state_key: &str, - dev: &Device, - ) -> Result { - let pth = candle::pickle::PthTensors::new(p, Some(state_key))?; - Ok(Self::from_backend(Box::new(pth), dtype, dev.clone())) - } + /// Gets a VarBuilder that applies some renaming function on tensor it gets queried for before /// passing the new names to the inner VarBuilder. /// @@ -590,7 +674,11 @@ impl<'a> VarBuilder<'a> { let path = self.path.clone(); let backend = Rename::new(self, renamer); let backend: Box = Box::new(backend); - let data = TensorData { backend, device }; + let data = TensorData { + backend: Arc::new(backend), + dtype, + device, + }; Self { data: Arc::new(data), dtype, @@ -714,6 +802,10 @@ impl Backend for ShardedSafeTensors { Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype) } + fn get_unchecked(&self, _name: &str, _dtype: DType, _dev: &Device) -> Result { + candle::bail!("`get_unchecked` does not make sense for `ShardedSafeTensors`, use `get`."); + } + fn contains_tensor(&self, name: &str) -> bool { self.0.get(name).is_ok() } @@ -747,6 +839,11 @@ impl SimpleBackend for Rename<'_, R> { .to_device(dev) } + fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result { + let name = self.renamer.rename(name); + self.inner.get_unchecked_dtype(&name, dtype)?.to_device(dev) + } + fn contains_tensor(&self, name: &str) -> bool { let name = self.renamer.rename(name); self.inner.contains_tensor(&name) diff --git a/candle-nn/src/var_map.rs b/candle-nn/src/var_map.rs index ba020746b5..8c7aa53d79 100644 --- a/candle-nn/src/var_map.rs +++ b/candle-nn/src/var_map.rs @@ -1,5 +1,3 @@ -//! A `VarMap` is a store that holds named variables. -//! use candle::{DType, Device, Result, Shape, Tensor, Var}; use std::collections::HashMap; use std::sync::{Arc, Mutex}; @@ -32,7 +30,7 @@ impl VarMap { pub fn save>(&self, path: P) -> Result<()> { let tensor_data = self.data.lock().unwrap(); let data = tensor_data.iter().map(|(k, v)| (k, v.as_tensor())); - safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?; + safetensors::tensor::serialize_to_file(data, None, path.as_ref())?; Ok(()) } @@ -115,6 +113,11 @@ impl VarMap { Ok(tensor) } + /// Retrieve or add a new variable. 
+ pub fn get_unchecked(&self, _path: &str, _dtype: DType, _device: &Device) -> Result { + candle::bail!("`get_unchecked` does not make sense for `VarMap`, use `get`."); + } + pub fn data(&self) -> &Mutex> { &self.data } diff --git a/candle-nn/tests/sdpa.rs b/candle-nn/tests/sdpa.rs index f63d1f05e4..d4d218cfef 100644 --- a/candle-nn/tests/sdpa.rs +++ b/candle-nn/tests/sdpa.rs @@ -1,98 +1,178 @@ #[cfg(feature = "metal")] mod metal_sdpa_tests { - use candle::{DType, Device, Result, Shape, Tensor}; - use rand::SeedableRng; - use rand_distr::Distribution; - use std::ops::{Div, Mul}; - - fn randn>( - rng: &mut rand::rngs::StdRng, - shape: S, - dev: &Device, - ) -> Result { - let shape = shape.into(); - let elem_count = shape.elem_count(); - let normal = rand_distr::Normal::new(0.0, 1.0).unwrap(); - let vs: Vec = (0..elem_count).map(|_| normal.sample(rng)).collect(); - Tensor::from_vec(vs, &shape, dev) - } - #[test] - fn sdpa_full() -> Result<()> { + fn sdpa_full() -> candle::Result<()> { + use candle::{DType, Device, Tensor}; + // Force seqlen = 100 const BS: usize = 4; const R: usize = 4; const L: usize = 4; const DK: usize = 64; const H: usize = 3; - let scale: f64 = f64::from(DK as u32).sqrt().recip(); + let device = Device::new_metal(0)?; - let mut rng = rand::rngs::StdRng::seed_from_u64(42); - let q = randn(&mut rng, (BS, H, R, DK), &device)?; - let k = randn(&mut rng, (BS, H, L, DK), &device)?; - let v = randn(&mut rng, (BS, H, L, DK), &device)?; + + let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?; + let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let ground_truth = { let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)? .to_dtype(q.dtype())?; att.matmul(&v.clone())? }; - let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?; + + let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, None, false, scale as f32, 1.)?; + assert_eq!(ground_truth.shape(), sdpa_output.shape()); + let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? .sum_all()? .to_scalar()?; - assert!(error <= 0.0004, "{}", error); + + assert!(error <= 0.002, "{}", error); + Ok(()) } #[test] - fn sdpa_vector() -> Result<()> { + fn sdpa_vector() -> candle::Result<()> { + use candle::{DType, Device, Tensor}; + // Allow vectorized, seqlen = 1 const BS: usize = 4; const R: usize = 1; const L: usize = 1; const DK: usize = 64; const H: usize = 3; + let scale: f64 = f64::from(DK as u32).sqrt().recip(); + + let device = Device::new_metal(0)?; + + let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?; + let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + + let ground_truth = { + let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; + let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)? + .to_dtype(q.dtype())?; + att.matmul(&v.clone())? + }; + + let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, None, false, scale as f32, 1.)?; + + assert_eq!(ground_truth.shape(), sdpa_output.shape()); + + let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? + .sum_all()? 
+ .to_scalar()?; + + assert!(error <= 0.0001, "{}", error); + + Ok(()) + } + #[test] + fn sdpa_vector_2pass() -> candle::Result<()> { + use candle::{DType, Device, Tensor}; + + // Allow vectorized, seqlen = 1 but kseqlen is long (long context) + const BS: usize = 4; + const R: usize = 1; + const L: usize = 2048; + const DK: usize = 64; + const H: usize = 3; let scale: f64 = f64::from(DK as u32).sqrt().recip(); + let device = Device::new_metal(0)?; - let mut rng = rand::rngs::StdRng::seed_from_u64(4242); - let q = randn(&mut rng, (BS, H, R, DK), &device)?; - let k = randn(&mut rng, (BS, H, L, DK), &device)?; - let v = randn(&mut rng, (BS, H, L, DK), &device)?; + + let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?; + let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let ground_truth = { let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)? .to_dtype(q.dtype())?; att.matmul(&v.clone())? }; - let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?; + + let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, None, false, scale as f32, 1.)?; + assert_eq!(ground_truth.shape(), sdpa_output.shape()); + let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? .sum_all()? .to_scalar()?; - assert!(error <= 0.000, "{}", error); + + assert!(error <= 0.002, "{}", error); + Ok(()) } #[test] - fn sdpa_full_softcapping() -> Result<()> { + fn sdpa_full_masked() -> candle::Result<()> { + use candle::{DType, Device, Tensor}; + // Allow vectorized, seqlen = 1 const BS: usize = 4; const R: usize = 4; const L: usize = 4; const DK: usize = 64; const H: usize = 3; - const SOFTCAP: f64 = 50.; + let scale: f64 = f64::from(DK as u32).sqrt().recip(); + + let device = Device::new_metal(0)?; + + let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?; + let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let mask = Tensor::randn(0f32, 1f32, (BS, H, R, L), &device)?; + let ground_truth = { + let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; + let att = candle_nn::ops::softmax_last_dim(&(att.to_dtype(DType::F32)? + &mask)?)? + .to_dtype(q.dtype())?; + att.matmul(&v.clone())? + }; + + let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, Some(&mask), false, scale as f32, 1.)?; + + assert_eq!(ground_truth.shape(), sdpa_output.shape()); + + let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? + .sum_all()? 
+ .to_scalar()?; + + assert!(error <= 0.006, "{}", error); + + Ok(()) + } + + #[test] + fn sdpa_vector_softcapping() -> candle::Result<()> { + use candle::{DType, Device, Tensor}; + use std::ops::{Div, Mul}; + + // Allow vectorized, seqlen = 1 + const BS: usize = 4; + const R: usize = 1; + const L: usize = 1; + const DK: usize = 64; + const H: usize = 3; + const SOFTCAP: f64 = 50.; let scale: f64 = f64::from(DK as u32).sqrt().recip(); + let device = Device::new_metal(0)?; - let mut rng = rand::rngs::StdRng::seed_from_u64(424242); - let q = randn(&mut rng, (BS, H, R, DK), &device)?; - let k = randn(&mut rng, (BS, H, L, DK), &device)?; - let v = randn(&mut rng, (BS, H, L, DK), &device)?; + + let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?; + let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let ground_truth = { let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; let att = candle_nn::ops::softmax_last_dim( @@ -104,31 +184,41 @@ mod metal_sdpa_tests { .to_dtype(q.dtype())?; att.matmul(&v.clone())? }; - let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?; + + let sdpa_output = + candle_nn::ops::sdpa(&q, &k, &v, None, false, scale as f32, SOFTCAP as f32)?; + assert_eq!(ground_truth.shape(), sdpa_output.shape()); + let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? .sum_all()? .to_scalar()?; - assert!(error <= 0.0005, "{}", error); + + assert!(error <= 0.0001, "{}", error); + Ok(()) } #[test] - fn sdpa_vector_softcapping() -> Result<()> { - // Allow vectorized, seqlen = 1 + fn sdpa_vector_2pass_softcapping() -> candle::Result<()> { + use candle::{DType, Device, Tensor}; + use std::ops::{Div, Mul}; + + // Allow vectorized, seqlen = 1 but kseqlen is long (long context) const BS: usize = 4; const R: usize = 1; - const L: usize = 1; + const L: usize = 2048; const DK: usize = 64; const H: usize = 3; const SOFTCAP: f64 = 50.; - let scale: f64 = f64::from(DK as u32).sqrt().recip(); + let device = Device::new_metal(0)?; - let mut rng = rand::rngs::StdRng::seed_from_u64(42424242); - let q = randn(&mut rng, (BS, H, R, DK), &device)?; - let k = randn(&mut rng, (BS, H, L, DK), &device)?; - let v = randn(&mut rng, (BS, H, L, DK), &device)?; + + let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?; + let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let ground_truth = { let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; let att = candle_nn::ops::softmax_last_dim( @@ -140,42 +230,112 @@ mod metal_sdpa_tests { .to_dtype(q.dtype())?; att.matmul(&v.clone())? }; - let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?; + + let sdpa_output = + candle_nn::ops::sdpa(&q, &k, &v, None, false, scale as f32, SOFTCAP as f32)?; + assert_eq!(ground_truth.shape(), sdpa_output.shape()); + let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? .sum_all()? .to_scalar()?; - assert!(error <= 0.0001, "{}", error); + + assert!(error <= 0.0021, "{}", error); + Ok(()) } #[test] - fn sdpa_vector_cross() -> Result<()> { + fn sdpa_vector_cross() -> candle::Result<()> { + use candle::{DType, Device, Tensor}; + // Allow vectorized, seqlen = 1. 
Simulat cross attention case where R != L, R = 1 const BS: usize = 4; const R: usize = 1; const L: usize = 24; const DK: usize = 64; const H: usize = 3; - let scale: f64 = f64::from(DK as u32).sqrt().recip(); + let device = Device::new_metal(0)?; - let mut rng = rand::rngs::StdRng::seed_from_u64(4242424242); - let q = randn(&mut rng, (BS, H, R, DK), &device)?; - let k = randn(&mut rng, (BS, H, L, DK), &device)?; - let v = randn(&mut rng, (BS, H, L, DK), &device)?; + + let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?; + let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let ground_truth = { let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)? .to_dtype(q.dtype())?; att.matmul(&v.clone())? }; - let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?; + + let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, None, false, scale as f32, 1.)?; + + assert_eq!(ground_truth.shape(), sdpa_output.shape()); + + let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? + .sum_all()? + .to_scalar()?; + + assert!(error <= 0.0017, "{}", error); + + Ok(()) + } + + #[test] + fn sdpa_vector_gqa_2pass_no_mask() -> candle::Result<()> { + use candle::{DType, Device, Tensor}; + // GQA && Increase seq_len to 1024 in order to cover 2-pass code branch + + /// Repeats a key or value tensor for grouped query attention + /// The input tensor should have a shape `(batch, num_kv_heads, seq_len, head_dim)`, + fn repeat_kv(xs: Tensor, n_rep: usize) -> candle::Result { + if n_rep == 1 { + Ok(xs) + } else { + let (b_sz, n_kv_head, seq_len, head_dim) = xs.dims4()?; + // Using cat is faster than a broadcast as it avoids going through a potentially + // strided copy. + // https://github.com/huggingface/candle/pull/2043 + Tensor::cat(&vec![&xs; n_rep], 2)?.reshape(( + b_sz, + n_kv_head * n_rep, + seq_len, + head_dim, + )) + } + } + + const BS: usize = 4; + const R: usize = 1; + const L: usize = 1024; + const DK: usize = 128; + const HQ: usize = 28; + const HKV: usize = 4; + + let scale: f64 = f64::from(DK as u32).sqrt().recip(); + let device = Device::new_metal(0)?; + let q = Tensor::randn(0f32, 1f32, (BS, HQ, R, DK), &device)?; + let k = Tensor::randn(0f32, 1f32, (BS, HKV, L, DK), &device)?; + let v = Tensor::randn(0f32, 1f32, (BS, HKV, L, DK), &device)?; + + let k_aligned = repeat_kv(k.copy().unwrap(), HQ / HKV)?; + let v_aligned = repeat_kv(v.copy().unwrap(), HQ / HKV)?; + + let ground_truth = { + let att = (q.clone() * scale)?.matmul(&k_aligned.clone().t()?)?; + let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)? + .to_dtype(q.dtype())?; + att.matmul(&v_aligned.clone())? + }; + let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, None, false, scale as f32, 1.)?; assert_eq!(ground_truth.shape(), sdpa_output.shape()); let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? .sum_all()? 
.to_scalar()?; - assert!(error <= 0.0013, "{}", error); + println!("{error}"); + assert!(error <= 0.06, "{}", error); Ok(()) } } diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 3f981c99d9..d618ee88d6 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -205,6 +205,21 @@ trait MapDType { DType::F16 => self.f::(t), DType::F32 => self.f::(t), DType::F64 => self.f::(t), + DType::I16 => Err(PyErr::new::( + "i16 dtype is not supported in Python interface", + )), + DType::I32 => Err(PyErr::new::( + "i32 dtype is not supported in Python interface", + )), + DType::F8E4M3 => Err(PyErr::new::( + "f8e4m3 dtype is not supported in Python interface", + )), + DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { + Err(PyErr::new::(format!( + "Dummy dtype {:?} is not supported", + t.dtype() + ))) + } } } } @@ -518,9 +533,7 @@ impl PyTensor { // Check that the index is in range if actual_index < 0 || actual_index >= dims[current_dim] as isize { return Err(PyValueError::new_err(format!( - "index out of range for dimension '{i}' with indexer '{value}'", - i = current_dim, - value = index + "index out of range for dimension '{current_dim}' with indexer '{index}'" ))); } Ok(actual_index as usize) @@ -580,8 +593,7 @@ impl PyTensor { Ok((Indexer::Expand, current_dim)) } else { Err(PyTypeError::new_err(format!( - "unsupported indexer {}", - py_indexer + "unsupported indexer {py_indexer}" ))) } } @@ -1423,8 +1435,7 @@ fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) gguf_file::Value::Array(x) } else { return Err(PyErr::new::(format!( - "unsupported type {:?}", - v + "unsupported type {v:?}" ))); }; Ok(v) diff --git a/candle-pyo3/src/shape.rs b/candle-pyo3/src/shape.rs index b9bc67899d..4218d86186 100644 --- a/candle-pyo3/src/shape.rs +++ b/candle-pyo3/src/shape.rs @@ -56,8 +56,7 @@ impl<'source> pyo3::FromPyObject<'source> for PyShapeWithHole { let any_invalid_dimensions = dims.iter().any(|&x| x < -1 || x == 0); if negative_ones > 1 || any_invalid_dimensions { return Err(PyErr::new::(format!( - "Invalid dimension in shape: {:?}", - dims + "Invalid dimension in shape: {dims:?}" ))); } @@ -89,8 +88,7 @@ impl PyShapeWithHole { new_dims.push(elements); } else { return Err(PyErr::new::(format!( - "Invalid dimension in shape: {}", - dim + "Invalid dimension in shape: {dim}" ))); } } diff --git a/candle-transformers/src/models/chinese_clip/mod.rs b/candle-transformers/src/models/chinese_clip/mod.rs index 1edc903179..ad8f380a24 100644 --- a/candle-transformers/src/models/chinese_clip/mod.rs +++ b/candle-transformers/src/models/chinese_clip/mod.rs @@ -30,7 +30,7 @@ impl From for Activation { "gelu" => Activation::Gelu, "gelu_new" => Activation::GeluNew, "relu" => Activation::Relu, - _ => panic!("Invalid activation function: {}", value), + _ => panic!("Invalid activation function: {value}"), } } } diff --git a/candle-transformers/src/models/deepseek2.rs b/candle-transformers/src/models/deepseek2.rs index 6a418b4326..073b69075b 100644 --- a/candle-transformers/src/models/deepseek2.rs +++ b/candle-transformers/src/models/deepseek2.rs @@ -45,11 +45,33 @@ impl CustomOp1 for NonZero { let result = match storage { candle::CpuStorage::U8(vs) => self.nonzero(vs, layout), candle::CpuStorage::U32(vs) => self.nonzero(vs, layout), + candle::CpuStorage::I16(vs) => self.nonzero(vs, layout), + candle::CpuStorage::I32(vs) => self.nonzero(vs, layout), candle::CpuStorage::I64(vs) => self.nonzero(vs, layout), candle::CpuStorage::BF16(vs) => self.nonzero(vs, layout), 
candle::CpuStorage::F16(vs) => self.nonzero(vs, layout), candle::CpuStorage::F32(vs) => self.nonzero(vs, layout), candle::CpuStorage::F64(vs) => self.nonzero(vs, layout), + candle::CpuStorage::F8E4M3(vs) => self.nonzero(vs, layout), + // Dummy types don't support nonzero operation + candle::CpuStorage::F6E2M3(_) => { + return Err( + candle::Error::UnsupportedDTypeForOp(candle::DType::F6E2M3, "nonzero").bt(), + ) + } + candle::CpuStorage::F6E3M2(_) => { + return Err( + candle::Error::UnsupportedDTypeForOp(candle::DType::F6E3M2, "nonzero").bt(), + ) + } + candle::CpuStorage::F4(_) => { + return Err(candle::Error::UnsupportedDTypeForOp(candle::DType::F4, "nonzero").bt()) + } + candle::CpuStorage::F8E8M0(_) => { + return Err( + candle::Error::UnsupportedDTypeForOp(candle::DType::F8E8M0, "nonzero").bt(), + ) + } }; let index_len = layout.dims().len(); let result_len = result.len() / index_len; diff --git a/candle-transformers/src/models/mmdit/model.rs b/candle-transformers/src/models/mmdit/model.rs index 21897aa356..2cf0dc9232 100644 --- a/candle-transformers/src/models/mmdit/model.rs +++ b/candle-transformers/src/models/mmdit/model.rs @@ -181,9 +181,9 @@ impl MMDiTCore { ) -> Result { let mut joint_blocks = Vec::with_capacity(depth - 1); for i in 0..depth - 1 { - let joint_block_vb_pp = format!("joint_blocks.{}", i); + let joint_block_vb_pp = format!("joint_blocks.{i}"); let joint_block: Box = - if vb.contains_tensor(&format!("{}.x_block.attn2.qkv.weight", joint_block_vb_pp)) { + if vb.contains_tensor(&format!("{joint_block_vb_pp}.x_block.attn2.qkv.weight")) { Box::new(MMDiTXJointBlock::new( hidden_size, num_heads, diff --git a/candle-transformers/src/models/moondream.rs b/candle-transformers/src/models/moondream.rs index a9dc9b7dc2..4c0b30503e 100644 --- a/candle-transformers/src/models/moondream.rs +++ b/candle-transformers/src/models/moondream.rs @@ -204,7 +204,7 @@ impl VisionTransformer { let blocks = (0..cfg.num_blocks) .map(|i| { VitBlock::new( - vb.pp(format!("blocks.{}", i)), + vb.pp(format!("blocks.{i}")), cfg.embed_dim, cfg.num_heads, cfg, diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs index e171b54fd8..1c416b12f2 100644 --- a/candle-transformers/src/models/quantized_llama.rs +++ b/candle-transformers/src/models/quantized_llama.rs @@ -225,7 +225,15 @@ impl LayerWeights { let y = if q.device().is_metal() && seq_len == 1 { // SDPA will do MQA for us - candle_nn::ops::sdpa(&q, &k, &v, 1. / (self.head_dim as f32).sqrt(), 1.)? + candle_nn::ops::sdpa( + &q, + &k, + &v, + None, + false, + 1. / (self.head_dim as f32).sqrt(), + 1., + )? } else { // Support for MQA, useful for 70B models and mistral. 
let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?; diff --git a/candle-transformers/src/models/quantized_moondream.rs b/candle-transformers/src/models/quantized_moondream.rs index c1daffafe4..9a49598bcd 100644 --- a/candle-transformers/src/models/quantized_moondream.rs +++ b/candle-transformers/src/models/quantized_moondream.rs @@ -134,7 +134,7 @@ impl VisionTransformer { let blocks = (0..cfg.num_blocks) .map(|i| { VitBlock::new( - vb.pp(format!("blocks.{}", i)), + vb.pp(format!("blocks.{i}")), cfg.embed_dim, cfg.num_heads, cfg, diff --git a/candle-transformers/src/models/segformer.rs b/candle-transformers/src/models/segformer.rs index 6d750df224..10bdb7fba8 100644 --- a/candle-transformers/src/models/segformer.rs +++ b/candle-transformers/src/models/segformer.rs @@ -420,7 +420,7 @@ impl SegformerEncoder { stride, num_channels, hidden_size, - vb.pp(format!("patch_embeddings.{}", i)), + vb.pp(format!("patch_embeddings.{i}")), )?); let mut layers = Vec::with_capacity(config.depths[i]); for j in 0..config.depths[i] { @@ -433,14 +433,14 @@ impl SegformerEncoder { num_attention_heads, sequence_reduction_ratio, mlp_ratio, - vb.pp(format!("block.{}.{}", i, j)), + vb.pp(format!("block.{i}.{j}")), )?); } blocks.push(layers); layer_norms.push(layer_norm( hidden_size, config.layer_norm_eps, - vb.pp(format!("layer_norm.{}", i)), + vb.pp(format!("layer_norm.{i}")), )?); } Ok(Self { @@ -523,7 +523,7 @@ impl SegformerDecodeHead { linear_c.push(SegformerMLP::new( config, hidden_size, - vb.pp(format!("linear_c.{}", i)), + vb.pp(format!("linear_c.{i}")), )?); } let linear_fuse = conv2d_no_bias( diff --git a/candle-transformers/src/models/xlm_roberta.rs b/candle-transformers/src/models/xlm_roberta.rs index 6fb1268ae4..9b1cdcd5a3 100644 --- a/candle-transformers/src/models/xlm_roberta.rs +++ b/candle-transformers/src/models/xlm_roberta.rs @@ -336,7 +336,7 @@ struct XLMRobertaEncoder { impl XLMRobertaEncoder { fn new(cfg: &Config, vb: VarBuilder) -> Result { let layers = (0..cfg.num_hidden_layers) - .map(|i| XLMRobertaLayer::new(cfg, vb.pp(format!("layer.{}", i)))) + .map(|i| XLMRobertaLayer::new(cfg, vb.pp(format!("layer.{i}")))) .collect::>>()?; Ok(Self { layers }) } diff --git a/candle-wasm-examples/moondream/src/bin/m.rs b/candle-wasm-examples/moondream/src/bin/m.rs index 27cda1e788..0a924c5b0e 100644 --- a/candle-wasm-examples/moondream/src/bin/m.rs +++ b/candle-wasm-examples/moondream/src/bin/m.rs @@ -120,7 +120,7 @@ impl Model { } = serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?; let device = Device::Cpu; - let prompt = format!("\n\nQuestion: {0}\n\nAnswer:", prompt); + let prompt = format!("\n\nQuestion: {prompt}\n\nAnswer:"); match &mut self.model { SelectedModel::Moondream(m) => m.text_model.clear_kv_cache(), SelectedModel::Quantized(m) => m.text_model.clear_kv_cache(), diff --git a/candle-wasm-examples/segment-anything/src/bin/m.rs b/candle-wasm-examples/segment-anything/src/bin/m.rs index 38e9fe3b6e..5164bb9ab1 100644 --- a/candle-wasm-examples/segment-anything/src/bin/m.rs +++ b/candle-wasm-examples/segment-anything/src/bin/m.rs @@ -81,14 +81,12 @@ impl Model { for &(x, y, _bool) in &transformed_points { if !(0.0..=1.0).contains(&x) { return Err(JsError::new(&format!( - "x has to be between 0 and 1, got {}", - x + "x has to be between 0 and 1, got {x}" ))); } if !(0.0..=1.0).contains(&y) { return Err(JsError::new(&format!( - "y has to be between 0 and 1, got {}", - y + "y has to be between 0 and 1, got {y}" ))); } } diff --git 
a/candle-wasm-examples/whisper/src/app.rs b/candle-wasm-examples/whisper/src/app.rs index a2c0ddabcb..03eae9382d 100644 --- a/candle-wasm-examples/whisper/src/app.rs +++ b/candle-wasm-examples/whisper/src/app.rs @@ -184,7 +184,7 @@ impl Component for App { Ok(WorkerOutput::Decoded(segments)) => { self.status = match dt { None => "decoding succeeded!".to_string(), - Some(dt) => format!("decoding succeeded in {:.2}s", dt), + Some(dt) => format!("decoding succeeded in {dt:.2}s"), }; self.segments = segments; } diff --git a/candle-wasm-examples/yolo/src/app.rs b/candle-wasm-examples/yolo/src/app.rs index 61253fb5a8..40445da696 100644 --- a/candle-wasm-examples/yolo/src/app.rs +++ b/candle-wasm-examples/yolo/src/app.rs @@ -204,7 +204,7 @@ impl Component for App { }); self.status = match dt { None => "processing succeeded!".to_string(), - Some(dt) => format!("processing succeeded in {:.2}s", dt,), + Some(dt) => format!("processing succeeded in {dt:.2}s",), }; self.current_decode = None; if let Err(err) = draw_bboxes(bboxes) { diff --git a/tensor-tools/src/main.rs b/tensor-tools/src/main.rs index 0bda36d524..00af187057 100644 --- a/tensor-tools/src/main.rs +++ b/tensor-tools/src/main.rs @@ -352,7 +352,7 @@ fn run_ls( tensor_info.dtype, ); if verbose { - println!(" {:?}", tensor_info); + println!(" {tensor_info:?}"); } } }
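
The updated `candle_nn::ops::sdpa` signature in this diff takes an optional additive mask and a `do_causal` flag in addition to q/k/v, the scale, and the softcapping value. The sketch below shows one way to call it on the Metal full-attention path, mirroring the `sdpa_full_masked` test above; the helper name `sdpa_masked_example`, the concrete shapes, and the all-zero mask are illustrative assumptions rather than part of the patch.

use candle::{Device, Result, Tensor};

// Minimal usage sketch, assuming a Metal device and the candle/candle-nn crates from this branch.
// Shapes follow the documented layout: q (bs, qhead, seq, hidden), k/v (bs, kv_head, kv_seq, hidden),
// mask (bs, qhead, seq, kv_seq).
fn sdpa_masked_example() -> Result<Tensor> {
    // A query length > 8 and a supported head dim (64) route this onto the full-attention kernel.
    const BS: usize = 1; // batch size
    const H: usize = 8; // query heads (== kv heads here; GQA would use fewer kv heads)
    const R: usize = 16; // query sequence length
    const L: usize = 16; // key/value sequence length
    const DK: usize = 64; // head dimension

    let device = Device::new_metal(0)?;
    let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
    let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
    let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
    // Additive mask applied to the attention scores before softmax; per the new checks its
    // dtype must match q and its shape must be (bs, qhead, seq, kv_seq).
    let mask = Tensor::zeros((BS, H, R, L), q.dtype(), &device)?;
    let scale = 1f32 / (DK as f32).sqrt();
    // do_causal is false because an explicit mask is passed; softcapping of 1.0 keeps it
    // disabled, which the seq > 1 path now requires.
    candle_nn::ops::sdpa(&q, &k, &v, Some(&mask), false, scale, 1.0)
}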