diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs index b3526ed7e5..a208472804 100644 --- a/candle-core/src/cuda_backend/device.rs +++ b/candle-core/src/cuda_backend/device.rs @@ -1,14 +1,14 @@ +use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr}; use crate::backend::{BackendDevice, BackendStorage}; use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape}; pub use candle_kernels as kernels; pub use cudarc; use cudarc::driver::CudaFunction; +use float8::F8E4M3; use half::{bf16, f16}; use std::collections::HashMap; use std::sync::{Arc, Mutex, RwLock}; -use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr}; - /// Unique identifier for cuda devices. #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct DeviceId(usize); @@ -359,7 +359,9 @@ impl BackendDevice for CudaDevice { CudaStorageSlice::F64(data) } DType::F8E4M3 => { - return Err(CudaError::InternalError("F8E4M3 not supported in CUDA backend").into()) + let data = self.alloc_zeros::<F8E4M3>(elem_count)?; + CudaStorageSlice::F8E4M3(data) + // return Err(CudaError::InternalError("F8E4M3 not supported in CUDA backend").into()) } DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { return Err( @@ -512,7 +514,9 @@ impl BackendDevice for CudaDevice { CudaStorageSlice::F64(data) } DType::F8E4M3 => { - return Err(CudaError::InternalError("F8E4M3 not supported in CUDA backend").into()) + let data = self.alloc::<F8E4M3>(elem_count)?; + CudaStorageSlice::F8E4M3(data) + // return Err(CudaError::InternalError("F8E4M3 not supported in CUDA backend").into()) } DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { return Err(