diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs
index ac6ed57221..9840ec5089 100644
--- a/candle-core/src/cuda_backend/device.rs
+++ b/candle-core/src/cuda_backend/device.rs
@@ -1,8 +1,11 @@
 use crate::backend::BackendDevice;
-use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
+use crate::dtype::AllocatedBuffers;
+use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape, WithDType};
 pub use candle_kernels as kernels;
 pub use cudarc;
-use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
+use cudarc::driver::{
+    result, CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig,
+};
 use float8::F8E4M3;
 use half::{bf16, f16};
 use std::sync::{Arc, Mutex, RwLock};
@@ -32,6 +35,7 @@ pub struct CudaDevice {
     pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
     curand: Arc<Mutex<CudaRng>>,
     seed_value: Arc<RwLock<u64>>,
+    buffers: Arc<RwLock<AllocatedBuffers>>,
 }
 
 impl std::fmt::Debug for CudaDevice {
@@ -49,6 +53,29 @@ impl std::ops::Deref for CudaDevice {
     }
 }
 impl CudaDevice {
+    /// Allocates device memory and increments the reference counter of [CudaDevice].
+    ///
+    /// # Safety
+    /// This is unsafe because the device memory is unset after this call.
+    pub unsafe fn alloc<T: DeviceRepr + WithDType>(
+        &self,
+        len: usize,
+    ) -> std::result::Result<CudaSlice<T>, result::DriverError> {
+        let mut buffers = self.buffers.write().unwrap();
+
+        let bufs = T::get_buffers(&*buffers);
+        for buf in bufs {
+            if buf.len() == len {
+                return Ok(buf.clone());
+            }
+        }
+
+        // Default to plain
+        let buffer = self.device.alloc::<T>(len)?;
+        T::cache_buffer(&mut buffers, buffer.clone());
+        Ok(buffer)
+    }
+
     pub fn cublas_handle(&self) -> &cudarc::cublas::CudaBlas {
         &*self.blas
     }
@@ -185,6 +212,7 @@ impl CudaDevice {
             blas: Arc::new(blas),
             curand: Arc::new(Mutex::new(CudaRng(curand))),
             seed_value: Arc::new(RwLock::new(299792458)),
+            buffers: Arc::new(RwLock::new(AllocatedBuffers::default())),
         })
     }
 }
@@ -202,6 +230,7 @@ impl BackendDevice for CudaDevice {
             blas: Arc::new(blas),
             curand: Arc::new(Mutex::new(CudaRng(curand))),
             seed_value: Arc::new(RwLock::new(299792458)),
+            buffers: Arc::new(RwLock::new(AllocatedBuffers::default())),
         })
     }
diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs
index 51dbd1b356..16c704357d 100644
--- a/candle-core/src/cuda_backend/mod.rs
+++ b/candle-core/src/cuda_backend/mod.rs
@@ -1025,7 +1025,7 @@ pub struct CudaStorage {
     pub device: CudaDevice,
 }
 
-pub trait CudaDType: Sized {
+pub trait CudaDType: Sized + WithDType {
     fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice<Self>>;
     fn wrap_cuda_slice(s: CudaSlice<Self>, dev: CudaDevice) -> CudaStorage;
 }
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index f40ec3f7e1..872c45c86f 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -107,6 +107,192 @@ impl DType {
     }
 }
 
+// Real impl
+#[cfg(feature = "cuda")]
+mod allocated_buffers {
+    use cudarc::driver::{CudaSlice, DeviceRepr};
+    use float8::F8E4M3;
+    use half::{bf16, f16};
+
+    #[derive(Default)]
+    pub struct AllocatedBuffers {
+        buf_f8e4m3: Vec<CudaSlice<F8E4M3>>,
+        buf_u8: Vec<CudaSlice<u8>>,
+        buf_u32: Vec<CudaSlice<u32>>,
+        buf_i16: Vec<CudaSlice<i16>>,
+        buf_i32: Vec<CudaSlice<i32>>,
+        buf_i64: Vec<CudaSlice<i64>>,
+        buf_bf16: Vec<CudaSlice<bf16>>,
+        buf_f16: Vec<CudaSlice<f16>>,
+        buf_f32: Vec<CudaSlice<f32>>,
+        buf_f64: Vec<CudaSlice<f64>>,
+    }
+
+    // Trait to map from T to the correct buffer.
+    pub trait CachedBufferHelper: DeviceRepr {
+        fn get_buffers(buffers: &AllocatedBuffers) -> &Vec<CudaSlice<Self>>
+        where
+            Self: Sized;
+        fn cache_buffer(buffers: &mut AllocatedBuffers, buf: CudaSlice<Self>)
+        where
+            Self: Sized;
+    }
+
+    impl CachedBufferHelper for F8E4M3 {
+        fn get_buffers(buffers: &AllocatedBuffers) -> &Vec<CudaSlice<F8E4M3>> {
+            &buffers.buf_f8e4m3
+        }
+        fn cache_buffer(buffers: &mut AllocatedBuffers, buf: CudaSlice<F8E4M3>) {
+            buffers.buf_f8e4m3.push(buf);
+        }
+    }
+
+    impl CachedBufferHelper for u8 {
+        fn get_buffers(buffers: &AllocatedBuffers) -> &Vec<CudaSlice<u8>> {
+            &buffers.buf_u8
+        }
+        fn cache_buffer(buffers: &mut AllocatedBuffers, buf: CudaSlice<u8>) {
+            buffers.buf_u8.push(buf);
+        }
+    }
+
+    impl CachedBufferHelper for u32 {
+        fn get_buffers(buffers: &AllocatedBuffers) -> &Vec<CudaSlice<u32>> {
+            &buffers.buf_u32
+        }
+        fn cache_buffer(buffers: &mut AllocatedBuffers, buf: CudaSlice<u32>) {
+            buffers.buf_u32.push(buf);
+        }
+    }
+
+    impl CachedBufferHelper for i16 {
+        fn get_buffers(buffers: &AllocatedBuffers) -> &Vec<CudaSlice<i16>> {
+            &buffers.buf_i16
+        }
+        fn cache_buffer(buffers: &mut AllocatedBuffers, buf: CudaSlice<i16>) {
+            buffers.buf_i16.push(buf);
+        }
+    }
+
+    impl CachedBufferHelper for i32 {
+        fn get_buffers(buffers: &AllocatedBuffers) -> &Vec<CudaSlice<i32>> {
+            &buffers.buf_i32
+        }
+        fn cache_buffer(buffers: &mut AllocatedBuffers, buf: CudaSlice<i32>) {
+            buffers.buf_i32.push(buf);
+        }
+    }
+
+    impl CachedBufferHelper for i64 {
+        fn get_buffers(buffers: &AllocatedBuffers) -> &Vec<CudaSlice<i64>> {
+            &buffers.buf_i64
+        }
+        fn cache_buffer(buffers: &mut AllocatedBuffers, buf: CudaSlice<i64>) {
+            buffers.buf_i64.push(buf);
+        }
+    }
+
+    impl CachedBufferHelper for bf16 {
+        fn get_buffers(buffers: &AllocatedBuffers) -> &Vec<CudaSlice<bf16>> {
+            &buffers.buf_bf16
+        }
+        fn cache_buffer(buffers: &mut AllocatedBuffers, buf: CudaSlice<bf16>) {
+            buffers.buf_bf16.push(buf);
+        }
+    }
+
+    impl CachedBufferHelper for f16 {
+        fn get_buffers(buffers: &AllocatedBuffers) -> &Vec<CudaSlice<f16>> {
+            &buffers.buf_f16
+        }
+        fn cache_buffer(buffers: &mut AllocatedBuffers, buf: CudaSlice<f16>) {
+            buffers.buf_f16.push(buf);
+        }
+    }
+
+    impl CachedBufferHelper for f32 {
+        fn get_buffers(buffers: &AllocatedBuffers) -> &Vec<CudaSlice<f32>> {
+            &buffers.buf_f32
+        }
+        fn cache_buffer(buffers: &mut AllocatedBuffers, buf: CudaSlice<f32>) {
+            buffers.buf_f32.push(buf);
+        }
+    }
+
+    impl CachedBufferHelper for f64 {
+        fn get_buffers(buffers: &AllocatedBuffers) -> &Vec<CudaSlice<f64>> {
+            &buffers.buf_f64
+        }
+        fn cache_buffer(buffers: &mut AllocatedBuffers, buf: CudaSlice<f64>) {
+            buffers.buf_f64.push(buf);
+        }
+    }
+}
+
+// Dummy impl
+#[cfg(not(feature = "cuda"))]
+mod allocated_buffers {
+    use float8::F8E4M3;
+    use half::{bf16, f16};
+
+    #[derive(Default)]
+    pub struct AllocatedBuffers;
+
+    pub trait CachedBufferHelper {
+        fn get_buffers(buffers: &AllocatedBuffers);
+        fn cache_buffer(buffers: &mut AllocatedBuffers, buf: ());
+    }
+
+    impl CachedBufferHelper for F8E4M3 {
+        fn get_buffers(buffers: &AllocatedBuffers) {}
+        fn cache_buffer(buffers: &mut AllocatedBuffers, buf: ()) {}
+    }
+
+    impl CachedBufferHelper for u8 {
+        fn get_buffers(buffers: &AllocatedBuffers) {}
+        fn cache_buffer(buffers: &mut AllocatedBuffers, buf: ()) {}
+    }
+
+    impl CachedBufferHelper for u32 {
+        fn get_buffers(buffers: &AllocatedBuffers) {}
+        fn cache_buffer(buffers: &mut AllocatedBuffers, buf: ()) {}
+    }
+
+    impl CachedBufferHelper for i16 {
+        fn get_buffers(buffers: &AllocatedBuffers) {}
+        fn cache_buffer(buffers: &mut AllocatedBuffers, buf: ()) {}
+    }
+
+    impl CachedBufferHelper for i32 {
+        fn get_buffers(buffers: &AllocatedBuffers) {}
+        fn cache_buffer(buffers: &mut AllocatedBuffers, buf: ()) {}
+    }
+
+    impl CachedBufferHelper for i64 {
+        fn get_buffers(buffers: &AllocatedBuffers) {}
+        fn cache_buffer(buffers: &mut AllocatedBuffers, buf: ()) {}
+    }
+
+    impl CachedBufferHelper for bf16 {
+        fn get_buffers(buffers: &AllocatedBuffers) {}
+        fn cache_buffer(buffers: &mut AllocatedBuffers, buf: ()) {}
+    }
+
+    impl CachedBufferHelper for f16 {
+        fn get_buffers(buffers: &AllocatedBuffers) {}
+        fn cache_buffer(buffers: &mut AllocatedBuffers, buf: ()) {}
+    }
+
+    impl CachedBufferHelper for f32 {
+        fn get_buffers(buffers: &AllocatedBuffers) {}
+        fn cache_buffer(buffers: &mut AllocatedBuffers, buf: ()) {}
+    }
+
+    impl CachedBufferHelper for f64 {
+        fn get_buffers(buffers: &AllocatedBuffers) {}
+        fn cache_buffer(buffers: &mut AllocatedBuffers, buf: ()) {}
+    }
+}
+
+pub use allocated_buffers::*;
+
 pub trait WithDType:
     Sized
     + Copy
@@ -118,6 +304,7 @@ pub trait WithDType:
     + Sync
     + std::any::Any
    + crate::cpu::kernels::VecOps
+    + CachedBufferHelper
 {
     const DTYPE: DType;
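
For context, a minimal usage sketch of the cached allocator introduced by this patch. It is not part of the diff; it assumes the patch is applied, candle-core is built with the `cuda` feature, and that `CudaDevice` is re-exported at the crate root (if not, the `candle_core::cuda_backend::CudaDevice` path should be used instead). The element count 1024 and dtype f32 are arbitrary.

use candle_core::backend::BackendDevice;
use candle_core::CudaDevice;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let dev = CudaDevice::new(0)?;

    // First call: the per-dtype cache in `AllocatedBuffers` is empty, so a fresh
    // CudaSlice<f32> of 1024 elements is allocated, a clone of it is recorded in
    // `buf_f32`, and the allocation is returned.
    // SAFETY: `alloc` returns uninitialized device memory; it must be written
    // (e.g. by a host-to-device copy or a kernel) before it is read.
    let _a = unsafe { dev.alloc::<f32>(1024)? };

    // Second call with the same element count: the loop in `alloc` finds a cached
    // buffer whose `len()` matches and returns a clone of that CudaSlice instead of
    // going through the plain-allocation path.
    let _b = unsafe { dev.alloc::<f32>(1024)? };

    Ok(())
}

Note that the cache is keyed only by element count per dtype, so any previously cached buffer of the same length satisfies a later request; how much memory is actually shared versus copied on a cache hit depends on how the cudarc version in use implements Clone for CudaSlice.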