Merged
Commits (37)
014dd66  F8E4M3 dtype (EricLBuehler, Jul 19, 2025)
0ffd361  Add test (EricLBuehler, Jul 19, 2025)
5c7e440  Fix (EricLBuehler, Jul 19, 2025)
b3fa5fb  Add i16/i32 (EricLBuehler, Jul 19, 2025)
6f75b48  Add f6e2m3, f6e3m2, f4, f8e8m0 (EricLBuehler, Jul 19, 2025)
0451ece  Add f6e2m3, f6e3m2, f4, f8e8m0 ST loading (EricLBuehler, Jul 19, 2025)
0cfb907  Compiles on metal (EricLBuehler, Jul 19, 2025)
2c6d720  Clippy and format (EricLBuehler, Jul 19, 2025)
07b0a1f  Fix cuda (EricLBuehler, Jul 19, 2025)
3841938  Cuda build fixes (EricLBuehler, Jul 20, 2025)
bf90dbf  Use float8 0.3.0 (EricLBuehler, Jul 20, 2025)
bd73230  Remove (EricLBuehler, Jul 20, 2025)
7d04b42  Updated quantized api (EricLBuehler, Jul 20, 2025)
898dd8a  Updated varbuilder api (EricLBuehler, Jul 20, 2025)
37a237e  Metal updates (EricLBuehler, Jul 20, 2025)
bfb11d6  Export dummy dtype (EricLBuehler, Jul 20, 2025)
f3d0caa  Better error type (EricLBuehler, Jul 20, 2025)
cd06160  Add unfold (EricLBuehler, Jul 20, 2025)
80d0894  Add new_buffer_private for metal (EricLBuehler, Jul 20, 2025)
461b4b6  Add empty (EricLBuehler, Jul 20, 2025)
a68db80  Add get_current_seed (EricLBuehler, Jul 20, 2025)
eacf1ab  Add get_current_seed (EricLBuehler, Jul 20, 2025)
67ad023  Add flash attn v3 (EricLBuehler, Jul 20, 2025)
88f8a50  Add flash attn v3 (EricLBuehler, Jul 20, 2025)
8e2e7e1  Update deps (EricLBuehler, Jul 20, 2025)
1177dd0  Updated v3 FA (EricLBuehler, Jul 20, 2025)
2a4690f  Fix cuda build (EricLBuehler, Jul 20, 2025)
d158e8c  Support loading new dtypes (EricLBuehler, Jul 20, 2025)
3276c80  Fix cuda (EricLBuehler, Jul 20, 2025)
0bce325  Fix cpu (EricLBuehler, Jul 20, 2025)
47bc675  Add i32 cuda dtype (EricLBuehler, Jul 20, 2025)
92070a1  Add i32 cuda dtype (EricLBuehler, Jul 20, 2025)
8937065  Expose cublas handle (EricLBuehler, Jul 20, 2025)
bca1bcd  Fix flash attn v2 (EricLBuehler, Jul 21, 2025)
6ff8c85  Add cutlass (EricLBuehler, Jul 21, 2025)
caa6549  Add cutlass (EricLBuehler, Jul 21, 2025)
f80cb89  Fix v3 build (EricLBuehler, Jul 21, 2025)

3 changes: 3 additions & 0 deletions .gitmodules
@@ -1,3 +1,6 @@
[submodule "candle-examples/examples/flash-attn/cutlass"]
path = candle-flash-attn/cutlass
url = https://github.com/NVIDIA/cutlass.git
[submodule "candle-flash-attn-v3/cutlass"]
path = candle-flash-attn-v3/cutlass
url = https://github.com/NVIDIA/cutlass.git
11 changes: 0 additions & 11 deletions .vscode/settings.json

This file was deleted.

7 changes: 5 additions & 2 deletions Cargo.toml
@@ -13,6 +13,7 @@ members = [
exclude = [
"candle-book",
"candle-flash-attn",
"candle-flash-attn-v3",
"candle-kernels",
"candle-metal-kernels",
"candle-onnx",
@@ -36,18 +37,20 @@ byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.9.1" }
candle-datasets = { path = "./candle-datasets", version = "0.9.1" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.1" }
candle-flash-attn-v3 = { path = "./candle-flash-attn-v3", version = "0.9.1" }
candle-kernels = { path = "./candle-kernels", version = "0.9.1" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.1" }
candle-nn = { path = "./candle-nn", version = "0.9.1" }
candle-onnx = { path = "./candle-onnx", version = "0.9.1" }
candle-transformers = { path = "./candle-transformers", version = "0.9.1" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.16.3", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
cudarc = { version = "0.16.6", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.4.1"
half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] }
float8 = { version = "0.3.0", features = ["num-traits"] }
hound = "3.5.1"
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
imageproc = { version = "0.24.0", default-features = false }
Expand All @@ -61,7 +64,7 @@ parquet = { version = "51.0.0" }
rand = "0.9.0"
rand_distr = "0.5.1"
rayon = "1.7.0"
safetensors = "0.4.1"
safetensors = "0.6.0"
serde = { version = "1.0.171", features = ["derive"] }
serde_plain = "1.0.2"
serde_json = "1.0.99"
4 changes: 3 additions & 1 deletion candle-core/Cargo.toml
@@ -18,6 +18,7 @@ metal = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }
float8 = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
libc = { workspace = true, optional = true }
memmap2 = { workspace = true }
@@ -43,8 +44,9 @@ criterion = { workspace = true }

[features]
default = []
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda", "float8/cuda"]
cudnn = ["cuda", "cudarc/cudnn"]
nccl = ["cuda", "cudarc/nccl"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels", "dep:ug-metal"]
6 changes: 4 additions & 2 deletions candle-core/benches/benchmarks/mod.rs
@@ -8,6 +8,8 @@ pub(crate) mod reduce;
pub(crate) mod unary;
pub(crate) mod where_cond;

#[cfg(feature = "cuda")]
use candle_core::backend::BackendDevice;
use candle_core::{Device, Result};

pub(crate) trait BenchDevice {
@@ -26,13 +28,13 @@ impl BenchDevice for Device {
.synchronize()
.map_err(|e| candle_core::Error::Cuda(Box::new(e)))?);
#[cfg(not(feature = "cuda"))]
panic!("Cuda device without cuda feature enabled: {:?}", device)
panic!("Cuda device without cuda feature enabled: {device:?}")
}
Device::Metal(device) => {
#[cfg(feature = "metal")]
return Ok(device.wait_until_completed()?);
#[cfg(not(feature = "metal"))]
panic!("Metal device without metal feature enabled: {:?}", device)
panic!("Metal device without metal feature enabled: {device:?}")
}
}
}
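The panic messages above (and the matching format! changes in qmatmul.rs and unary.rs below) switch to inlined format arguments, which is what clippy's uninlined_format_args lint asks for. A minimal illustration, not part of the diff:

```rust
// Both calls produce identical output; the second captures the identifier
// directly inside the format string (stable since Rust 1.58).
fn main() {
    let dtype = "f32";
    let old_style = format!("sqrt_{:?}", dtype);
    let new_style = format!("sqrt_{dtype:?}");
    assert_eq!(old_style, new_style);
}
```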
2 changes: 1 addition & 1 deletion candle-core/benches/benchmarks/qmatmul.rs
@@ -31,7 +31,7 @@ fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {

let flops = b * m * n * k;

let mut group = c.benchmark_group(device.bench_name(format!("qmatmul_{:?}", dtype)));
let mut group = c.benchmark_group(device.bench_name(format!("qmatmul_{dtype:?}")));
group.sample_size(200);
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
8 changes: 4 additions & 4 deletions candle-core/benches/benchmarks/reduce.rs
@@ -44,12 +44,12 @@ fn run_reduce<T: candle_core::FloatDType>(
let k = 1024;

let a = if strided {
Tensor::rand(lo, up, (b, m, k), &device)
Tensor::rand(lo, up, (b, m, k), device)
.unwrap()
.transpose(0, 2)
.unwrap()
} else {
Tensor::rand(lo, up, (b, m, k), &device).unwrap()
Tensor::rand(lo, up, (b, m, k), device).unwrap()
};

let flops = b * m * k * T::DTYPE.size_in_bytes();
@@ -105,12 +105,12 @@ fn run_arg_reduce<T: candle_core::FloatDType>(
let k = 1024;

let a = if strided {
Tensor::rand(lo, up, (b, m, k), &device)
Tensor::rand(lo, up, (b, m, k), device)
.unwrap()
.transpose(0, 2)
.unwrap()
} else {
Tensor::rand(lo, up, (b, m, k), &device).unwrap()
Tensor::rand(lo, up, (b, m, k), device).unwrap()
};

let flops = b * m * k * T::DTYPE.size_in_bytes();
2 changes: 1 addition & 1 deletion candle-core/benches/benchmarks/unary.rs
@@ -40,7 +40,7 @@ fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
for dtype in [DType::F32, DType::BF16, DType::F16] {
let name = format!("sqrt_{:?}", dtype);
let name = format!("sqrt_{dtype:?}");
run_unary_benchmark(c, &device, dtype, &name);
}
}
2 changes: 1 addition & 1 deletion candle-core/benches/benchmarks/where_cond.rs
@@ -22,7 +22,7 @@ const M: usize = 1024;
const K: usize = 1024;
const SIZE: usize = B * M * K;

const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
static DATA: [u8; SIZE] = create_cond_arr::<SIZE>();

fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), device).unwrap();
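The only change here swaps const for static on the large DATA array. A rough sketch of the distinction, under the assumption that avoiding per-use duplication of the table is the motivation (the PR does not say):

```rust
// Sketch only: a `const` item is conceptually copy-pasted into every use site,
// while a `static` is a single object with a fixed address, so a large table
// is materialized once rather than at each use.
const INLINED: [u8; 4] = [0, 1, 2, 3]; // value may be duplicated per use
static SHARED: [u8; 4] = [0, 1, 2, 3]; // one allocation with a stable address

fn main() {
    println!("{:?} {:?}", INLINED, SHARED);
}
```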
1 change: 1 addition & 0 deletions candle-core/src/backend.rs
@@ -158,6 +158,7 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;

fn set_seed(&self, _: u64) -> Result<()>;
fn get_current_seed(&self) -> Result<u64>;

/// Synchronize should block until all the operations on the device are completed.
fn synchronize(&self) -> Result<()>;
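Alongside set_seed, BackendDevice now also requires get_current_seed. A minimal sketch of how a backend could satisfy it by remembering the last seed it was handed; ToyDevice and its field are illustrative, not candle's actual CPU/CUDA/Metal internals:

```rust
use std::sync::Mutex;

// Hypothetical backend state; real backends keep their RNG state elsewhere.
struct ToyDevice {
    seed: Mutex<u64>,
}

impl ToyDevice {
    fn set_seed(&self, seed: u64) -> candle_core::Result<()> {
        *self.seed.lock().unwrap() = seed;
        Ok(())
    }

    fn get_current_seed(&self) -> candle_core::Result<u64> {
        Ok(*self.seed.lock().unwrap())
    }
}

fn main() -> candle_core::Result<()> {
    let dev = ToyDevice { seed: Mutex::new(0) };
    dev.set_seed(42)?;
    assert_eq!(dev.get_current_seed()?, 42);
    Ok(())
}
```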
21 changes: 21 additions & 0 deletions candle-core/src/convert.rs
@@ -93,6 +93,8 @@ from_tensor!(f32);
from_tensor!(f16);
from_tensor!(bf16);
from_tensor!(i64);
from_tensor!(i32);
from_tensor!(i16);
from_tensor!(u32);
from_tensor!(u8);

@@ -130,6 +132,16 @@ impl Tensor {
f.write_u32::<LittleEndian>(v)?
}
}
DType::I16 => {
for v in vs.to_vec1::<i16>()? {
f.write_i16::<LittleEndian>(v)?
}
}
DType::I32 => {
for v in vs.to_vec1::<i32>()? {
f.write_i32::<LittleEndian>(v)?
}
}
DType::I64 => {
for v in vs.to_vec1::<i64>()? {
f.write_i64::<LittleEndian>(v)?
@@ -139,6 +151,15 @@
let vs = vs.to_vec1::<u8>()?;
f.write_all(&vs)?;
}
DType::F8E4M3 => {
let vs = vs.to_vec1::<float8::F8E4M3>()?;
for v in vs {
f.write_u8(v.to_bits())?
}
}
DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
return Err(crate::Error::UnsupportedDTypeForOp(self.dtype(), "write_bytes").bt())
}
}
Ok(())
}
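write_bytes now handles the new integer widths and F8E4M3 (one raw byte per element via to_bits), while the remaining narrow formats (F6E2M3, F6E3M2, F4, F8E8M0) return UnsupportedDTypeForOp. A small round-trip sketch of that F8E4M3 layout, assuming the float8 crate's from_bits is the inverse of the to_bits call used above:

```rust
use float8::F8E4M3;

// Serialize: one raw byte per element, matching the write_bytes branch above.
fn f8_to_bytes(vs: &[F8E4M3]) -> Vec<u8> {
    vs.iter().map(|v| v.to_bits()).collect()
}

// Deserialize: assumed inverse, reinterpreting each byte as an F8E4M3 value.
fn f8_from_bytes(bytes: &[u8]) -> Vec<F8E4M3> {
    bytes.iter().map(|&b| F8E4M3::from_bits(b)).collect()
}

fn main() {
    let xs = vec![F8E4M3::from_f32(1.5), F8E4M3::from_f32(-0.25)];
    let bytes = f8_to_bytes(&xs);
    assert_eq!(f8_from_bytes(&bytes), xs);
}
```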
83 changes: 81 additions & 2 deletions candle-core/src/cpu/avx.rs
@@ -1,10 +1,10 @@
use super::{Cpu, CpuF16};
use super::{Cpu, CpuBF16, CpuF16};
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

use half::f16;
use half::{bf16, f16};

pub struct CurrentCpu {}

@@ -146,3 +146,82 @@ impl CpuF16<ARR> for CurrentCpuF16 {
*y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
}
}

pub struct CurrentCpuBF16 {}
impl CpuBF16<ARR> for CurrentCpuBF16 {
type Unit = __m256;
type Array = [__m256; ARR];

const STEP: usize = STEP;
const EPR: usize = EPR;

fn n() -> usize {
ARR
}

unsafe fn zero() -> Self::Unit {
_mm256_setzero_ps()
}

unsafe fn zero_array() -> Self::Array {
[Self::zero(); ARR]
}

unsafe fn from_f32(v: f32) -> Self::Unit {
_mm256_set1_ps(v)
}

#[cfg(target_feature = "f16c")]
unsafe fn load(mem_addr: *const bf16) -> Self::Unit {
_mm256_cvtph_ps(_mm_loadu_si128(mem_addr as *const __m128i))
}

#[cfg(not(target_feature = "f16c"))]
unsafe fn load(mem_addr: *const bf16) -> Self::Unit {
let mut tmp = [0.0f32; 8];
for i in 0..8 {
tmp[i] = (*mem_addr.add(i)).to_f32();
}
_mm256_loadu_ps(tmp.as_ptr())
}

unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
_mm256_add_ps(a, b)
}

unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
_mm256_add_ps(_mm256_mul_ps(b, c), a)
}

#[cfg(target_feature = "f16c")]
unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) {
_mm_storeu_si128(mem_addr as *mut __m128i, _mm256_cvtps_ph(a, 0))
}

#[cfg(not(target_feature = "f16c"))]
unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) {
let mut tmp = [0.0f32; 8];
_mm256_storeu_ps(tmp.as_mut_ptr(), a);
for i in 0..8 {
*mem_addr.add(i) = bf16::from_f32(tmp[i]);
}
}

unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
let mut offset = ARR >> 1;
for i in 0..offset {
x[i] = _mm256_add_ps(x[i], x[offset + i]);
}
offset >>= 1;
for i in 0..offset {
x[i] = _mm256_add_ps(x[i], x[offset + i]);
}
offset >>= 1;
for i in 0..offset {
x[i] = _mm256_add_ps(x[i], x[offset + i]);
}
let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
let t1 = _mm_hadd_ps(t0, t0);
*y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
}
}
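CurrentCpuBF16 mirrors the existing f16 AVX path: SIMD loads (with a scalar fallback when f16c is unavailable), fused multiply-add, and a horizontal reduce. A hedged sketch of how these primitives typically compose into a dot-product kernel; the actual vec_dot_bf16 lives elsewhere in candle's cpu module and may differ in detail:

```rust
// Sketch only: assumes it sits in the same module as CurrentCpuBF16, so the
// CpuBF16 trait, the ARR/STEP/EPR constants and half::bf16 are all in scope.
unsafe fn dot_bf16_sketch(a: *const bf16, b: *const bf16, len: usize) -> f32 {
    let mut sums = CurrentCpuBF16::zero_array(); // ARR independent accumulators
    let n = len - (len % CurrentCpuBF16::STEP);  // portion handled by SIMD

    let mut i = 0;
    while i < n {
        for j in 0..CurrentCpuBF16::n() {
            let x = CurrentCpuBF16::load(a.add(i + j * CurrentCpuBF16::EPR));
            let y = CurrentCpuBF16::load(b.add(i + j * CurrentCpuBF16::EPR));
            // vec_fma(acc, x, y) computes acc + x * y, per the impl above.
            sums[j] = CurrentCpuBF16::vec_fma(sums[j], x, y);
        }
        i += CurrentCpuBF16::STEP;
    }

    let mut acc = 0f32;
    CurrentCpuBF16::vec_reduce(sums, &mut acc); // horizontal sum of all lanes

    // Scalar tail for elements that do not fill a full SIMD step.
    for k in n..len {
        acc += (*a.add(k)).to_f32() * (*b.add(k)).to_f32();
    }
    acc
}
```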
41 changes: 41 additions & 0 deletions candle-core/src/cpu/kernels.rs
@@ -121,6 +121,13 @@ impl VecOps for half::bf16 {
fn max(self, other: Self) -> Self {
Self::max(self, other)
}

#[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
let mut res_f32 = 0f32;
super::vec_dot_bf16(lhs, rhs, &mut res_f32, len);
*res = half::bf16::from_f32(res_f32);
}
}
impl VecOps for u8 {
#[inline(always)]
@@ -144,6 +151,28 @@ impl VecOps for u32 {
<Self as Ord>::max(self, other)
}
}
impl VecOps for i16 {
#[inline(always)]
fn min(self, other: Self) -> Self {
<Self as Ord>::min(self, other)
}

#[inline(always)]
fn max(self, other: Self) -> Self {
<Self as Ord>::max(self, other)
}
}
impl VecOps for i32 {
#[inline(always)]
fn min(self, other: Self) -> Self {
<Self as Ord>::min(self, other)
}

#[inline(always)]
fn max(self, other: Self) -> Self {
<Self as Ord>::max(self, other)
}
}
impl VecOps for i64 {
#[inline(always)]
fn min(self, other: Self) -> Self {
@@ -156,6 +185,18 @@ impl VecOps for i64 {
}
}

impl VecOps for float8::F8E4M3 {
#[inline(always)]
fn min(self, other: Self) -> Self {
Self::min(self, other)
}

#[inline(always)]
fn max(self, other: Self) -> Self {
Self::max(self, other)
}
}

#[inline(always)]
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
if n_threads == 1 {
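The VecOps additions extend candle's CPU kernel trait to i16, i32 and float8::F8E4M3, and route the bf16 dot product through an f32 accumulator before converting back. A small sketch of the kind of generic code these impls enable; min_max is illustrative and not a function from the PR:

```rust
// Sketch only: assumes it sits next to the VecOps trait shown above.
fn min_max<T: VecOps + Copy>(xs: &[T]) -> Option<(T, T)> {
    let mut iter = xs.iter().copied();
    let first = iter.next()?;
    Some(iter.fold((first, first), |(lo, hi), x| (lo.min(x), hi.max(x))))
}

fn demo() {
    // Works for any dtype with a VecOps impl, including the new i32 one.
    let xs: Vec<i32> = vec![3, -7, 42, 0];
    assert_eq!(min_max(&xs), Some((-7, 42)));
}
```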