Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions candle-core/src/cuda_backend/device.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use crate::backend::BackendDevice;
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
use crate::dtype::AllocatedBuffers;
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape, WithDType};
pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
use cudarc::driver::{
result, CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig,
};
use float8::F8E4M3;
use half::{bf16, f16};
use std::sync::{Arc, Mutex, RwLock};
Expand Down Expand Up @@ -32,6 +35,7 @@ pub struct CudaDevice {
pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
curand: Arc<Mutex<CudaRng>>,
seed_value: Arc<RwLock<u64>>,
buffers: Arc<RwLock<AllocatedBuffers>>,
}

impl std::fmt::Debug for CudaDevice {
Expand All @@ -49,6 +53,29 @@ impl std::ops::Deref for CudaDevice {
}

impl CudaDevice {
/// Returns a device buffer holding `len` elements of `T`, going through the per-device
/// buffer cache.
///
/// The cache list for `T` is scanned for an entry of exactly `len` elements. On a hit, a
/// clone of that entry is returned. On a miss, a new buffer is allocated from the underlying
/// device, a clone of it is stored in the cache, and the new buffer is returned.
///
/// NOTE(review): per the cudarc documentation, cloning a `CudaSlice` allocates fresh device
/// memory and copies the contents over — confirm. If so, neither path avoids an allocation,
/// and cached entries are never evicted, so the cache only grows.
///
/// # Errors
/// Returns the driver error if allocating or cloning a buffer fails.
///
/// # Panics
/// Panics if the buffer-cache lock is poisoned.
///
/// # Safety
/// This is unsafe because the device memory is unset after this call.
pub unsafe fn alloc<T: DeviceRepr + WithDType>(
    &self,
    len: usize,
) -> std::result::Result<CudaSlice<T>, result::DriverError> {
    let mut buffers = self.buffers.write().unwrap();

    // Reuse path: hand out a copy of the first cached buffer with a matching length.
    if let Some(cached) = T::get_buffers(&buffers)
        .iter()
        .find(|buf| buf.len() == len)
    {
        // `try_clone` reports a failed allocation/copy as an error; `clone` would panic.
        return cached.try_clone();
    }

    // Default to plain
    let buffer = self.device.alloc::<T>(len)?;
    T::cache_buffer(&mut buffers, buffer.try_clone()?);
    Ok(buffer)
}

/// Borrows the `CudaBlas` handle stored in this device's `blas` field.
pub fn cublas_handle(&self) -> &cudarc::cublas::CudaBlas {
    self.blas.as_ref()
}
Expand Down Expand Up @@ -185,6 +212,7 @@ impl CudaDevice {
blas: Arc::new(blas),
curand: Arc::new(Mutex::new(CudaRng(curand))),
seed_value: Arc::new(RwLock::new(299792458)),
buffers: Arc::new(RwLock::new(AllocatedBuffers::default())),
})
}
}
Expand All @@ -202,6 +230,7 @@ impl BackendDevice for CudaDevice {
blas: Arc::new(blas),
curand: Arc::new(Mutex::new(CudaRng(curand))),
seed_value: Arc::new(RwLock::new(299792458)),
buffers: Arc::new(RwLock::new(AllocatedBuffers::default())),
})
}

Expand Down
2 changes: 1 addition & 1 deletion candle-core/src/cuda_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1025,7 +1025,7 @@ pub struct CudaStorage {
pub device: CudaDevice,
}

pub trait CudaDType: Sized {
/// Element types that can be borrowed out of, or wrapped into, a [`CudaStorage`].
pub trait CudaDType: Sized + WithDType {
/// Borrows the `CudaSlice<Self>` held by `s`.
///
/// Fallible; presumably errors when `s` stores a different element type — confirm
/// against the implementations.
fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice<Self>>;
/// Builds a [`CudaStorage`] from the slice `s` and the device `dev`.
fn wrap_cuda_slice(s: CudaSlice<Self>, dev: CudaDevice) -> CudaStorage;
}
Expand Down
187 changes: 187 additions & 0 deletions candle-core/src/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,192 @@
}
}

// Real impl
#[cfg(feature = "cuda")]
mod allocated_buffers {
    use cudarc::driver::{CudaSlice, DeviceRepr};
    use float8::F8E4M3;
    use half::{bf16, f16};

    /// Lists of cached device buffers, one list per supported element type.
    #[derive(Default)]
    pub struct AllocatedBuffers {
        buf_f8e4m3: Vec<CudaSlice<F8E4M3>>,
        buf_u8: Vec<CudaSlice<u8>>,
        buf_u32: Vec<CudaSlice<u32>>,
        buf_i16: Vec<CudaSlice<i16>>,
        buf_i32: Vec<CudaSlice<i32>>,
        buf_i64: Vec<CudaSlice<i64>>,
        buf_bf16: Vec<CudaSlice<bf16>>,
        buf_f16: Vec<CudaSlice<f16>>,
        buf_f32: Vec<CudaSlice<f32>>,
        buf_f64: Vec<CudaSlice<f64>>,
    }

    /// Maps an element type to its buffer list inside [`AllocatedBuffers`].
    pub trait CachedBufferHelper: DeviceRepr {
        /// Returns the cached buffers whose element type is `Self`.
        fn get_buffers(buffers: &AllocatedBuffers) -> &Vec<CudaSlice<Self>>
        where
            Self: Sized;

        /// Appends `buf` to the list of cached buffers whose element type is `Self`.
        fn cache_buffer(buffers: &mut AllocatedBuffers, buf: CudaSlice<Self>)
        where
            Self: Sized;
    }

    // Generates one `CachedBufferHelper` impl per `type => field` pair, so the
    // type-to-field table lives in a single place.
    macro_rules! cached_buffer_helper {
        ($($elem:ty => $field:ident),+ $(,)?) => {
            $(
                impl CachedBufferHelper for $elem {
                    fn get_buffers(buffers: &AllocatedBuffers) -> &Vec<CudaSlice<Self>> {
                        &buffers.$field
                    }

                    fn cache_buffer(buffers: &mut AllocatedBuffers, buf: CudaSlice<Self>) {
                        buffers.$field.push(buf);
                    }
                }
            )+
        };
    }

    cached_buffer_helper! {
        F8E4M3 => buf_f8e4m3,
        u8 => buf_u8,
        u32 => buf_u32,
        i16 => buf_i16,
        i32 => buf_i32,
        i64 => buf_i64,
        bf16 => buf_bf16,
        f16 => buf_f16,
        f32 => buf_f32,
        f64 => buf_f64,
    }
}

// Dummy impl
#[cfg(not(feature = "cuda"))]
mod allocated_buffers {
#[derive(Default)]
pub struct AllocatedBuffers;

pub trait CachedBufferHelper: DeviceRepr {

Check failure on line 238 in candle-core/src/dtype.rs

View workflow job for this annotation

GitHub Actions / Clippy

cannot find trait `DeviceRepr` in this scope

Check failure on line 238 in candle-core/src/dtype.rs

View workflow job for this annotation

GitHub Actions / Test Suite (macOS-latest, stable)

cannot find trait `DeviceRepr` in this scope

Check failure on line 238 in candle-core/src/dtype.rs

View workflow job for this annotation

GitHub Actions / Check (macOS-latest, stable)

cannot find trait `DeviceRepr` in this scope

Check failure on line 238 in candle-core/src/dtype.rs

View workflow job for this annotation

GitHub Actions / Check (ubuntu-latest, stable)

cannot find trait `DeviceRepr` in this scope
fn get_buffers(buffers: &AllocatedBuffers);
fn cache_buffer(buffers: &mut AllocatedBuffers, buf: ());
}

impl CachedBufferHelper for F8E4M3 {

Check failure on line 243 in candle-core/src/dtype.rs

View workflow job for this annotation

GitHub Actions / Clippy

cannot find type `F8E4M3` in this scope

Check failure on line 243 in candle-core/src/dtype.rs

View workflow job for this annotation

GitHub Actions / Test Suite (macOS-latest, stable)

cannot find type `F8E4M3` in this scope

Check failure on line 243 in candle-core/src/dtype.rs

View workflow job for this annotation

GitHub Actions / Check (macOS-latest, stable)

cannot find type `F8E4M3` in this scope

Check failure on line 243 in candle-core/src/dtype.rs

View workflow job for this annotation

GitHub Actions / Check (ubuntu-latest, stable)

cannot find type `F8E4M3` in this scope
fn get_buffers(buffers: &AllocatedBuffers) {}
fn cache_buffer(buffers: &mut AllocatedBuffers, buf: ()) {}
}

impl CachedBufferHelper for u8 {
fn get_buffers(buffers: &AllocatedBuffers) {}
fn cache_buffer(buffers: &mut AllocatedBuffers, buf: ()) {}
}

impl CachedBufferHelper for u32 {
fn get_buffers(buffers: &AllocatedBuffers) {}
fn cache_buffer(buffers: &mut AllocatedBuffers, buf: ()) {}
}

impl CachedBufferHelper for i16 {
fn get_buffers(buffers: &AllocatedBuffers) {}
fn cache_buffer(buffers: &mut AllocatedBuffers, buf: ()) {}
}

impl CachedBufferHelper for i32 {
fn get_buffers(buffers: &AllocatedBuffers) {}
fn cache_buffer(buffers: &mut AllocatedBuffers, buf: ()) {}
}

impl CachedBufferHelper for i64 {
fn get_buffers(buffers: &AllocatedBuffers) {}
fn cache_buffer(buffers: &mut AllocatedBuffers, buf: ()) {}
}

impl CachedBufferHelper for bf16 {

Check failure on line 273 in candle-core/src/dtype.rs

View workflow job for this annotation

GitHub Actions / Clippy

cannot find type `bf16` in this scope

Check failure on line 273 in candle-core/src/dtype.rs

View workflow job for this annotation

GitHub Actions / Test Suite (macOS-latest, stable)

cannot find type `bf16` in this scope

Check failure on line 273 in candle-core/src/dtype.rs

View workflow job for this annotation

GitHub Actions / Check (macOS-latest, stable)

cannot find type `bf16` in this scope

Check failure on line 273 in candle-core/src/dtype.rs

View workflow job for this annotation

GitHub Actions / Check (ubuntu-latest, stable)

cannot find type `bf16` in this scope
fn get_buffers(buffers: &AllocatedBuffers) {}
fn cache_buffer(buffers: &mut AllocatedBuffers, buf: ()) {}
}

impl CachedBufferHelper for f16 {

Check failure on line 278 in candle-core/src/dtype.rs

View workflow job for this annotation

GitHub Actions / Clippy

the type `f16` is unstable

Check failure on line 278 in candle-core/src/dtype.rs

View workflow job for this annotation

GitHub Actions / Test Suite (macOS-latest, stable)

the type `f16` is unstable

Check failure on line 278 in candle-core/src/dtype.rs

View workflow job for this annotation

GitHub Actions / Check (macOS-latest, stable)

the type `f16` is unstable

Check failure on line 278 in candle-core/src/dtype.rs

View workflow job for this annotation

GitHub Actions / Check (ubuntu-latest, stable)

the type `f16` is unstable
fn get_buffers(buffers: &AllocatedBuffers) {}
fn cache_buffer(buffers: &mut AllocatedBuffers, buf: ()) {}
}

impl CachedBufferHelper for f32 {
fn get_buffers(buffers: &AllocatedBuffers) {}
fn cache_buffer(buffers: &mut AllocatedBuffers, buf: ()) {}
}

impl CachedBufferHelper for f64 {
fn get_buffers(buffers: &AllocatedBuffers) {}
fn cache_buffer(buffers: &mut AllocatedBuffers, buf: ()) {}
}
}

pub use allocated_buffers::*;

pub trait WithDType:
Sized
+ Copy
Expand All @@ -118,6 +304,7 @@
+ Sync
+ std::any::Any
+ crate::cpu::kernels::VecOps
+ CachedBufferHelper
{
const DTYPE: DType;

Expand Down
Loading