Signed dtypes are nice 👌 I've been having to pass u32s as i32s in CUDA launch code and have been worried that it would blow up in my face at some point.
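For anyone reading along, here is a minimal sketch of the hazard described above. It is not taken from the launch code in this repo; `launch_arg_i32` is a hypothetical helper name. It only shows that a plain `as` cast wraps silently while `i32::try_from` reports the overflow.

```rust
use std::num::TryFromIntError;

/// Hypothetical helper: convert a u32 launch parameter to i32,
/// returning an error instead of silently wrapping.
fn launch_arg_i32(value: u32) -> Result<i32, TryFromIntError> {
    i32::try_from(value)
}

fn main() {
    let small: u32 = 1024;
    let large: u32 = 3_000_000_000; // larger than i32::MAX (2_147_483_647)

    // A plain `as` cast reinterprets the bits, so a large count turns negative.
    println!("as cast: {}", large as i32);

    // The checked conversion accepts values that fit and rejects the rest.
    assert_eq!(launch_arg_i32(small), Ok(1024));
    assert!(launch_arg_i32(large).is_err());
}
```

With real signed dtypes the cast goes away entirely, which is the nicer outcome.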
ivarflakstad
left a comment
This is gonna be a good one! 🙌
    DType::F64 => convert_slice::<f64>(data, shape, device),
    DType::F8E4M3 => convert_slice::<F8E4M3>(data, shape, device),
    DType::F8E4M3 => convert_slice::<float8::F8E4M3>(data, shape, device),
    DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
Doesn't have to be in this PR, but I'd prefer to hoist this out into a helper fn.
Perhaps use convert_slice::<u8>(data, shape, device) and manually change the storage dtype? Might not even need a dedicated fn now that I think about it 🤔
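A rough sketch of what that suggestion could look like. Everything here is a simplified stand-in: the `Storage` and `DummyDType` enums, `convert_slice_u8`, and `retag` are hypothetical and are not the real candle types or signatures. The idea is just to load the bytes as u8 and then swap the variant tag without touching the data.

```rust
// Hypothetical, simplified stand-ins for the real dtype and storage types.
#[derive(Debug, Clone, Copy, PartialEq)]
enum DummyDType {
    F6E2M3,
    F6E3M2,
    F4,
    F8E8M0,
}

#[derive(Debug, PartialEq)]
enum Storage {
    U8(Vec<u8>),
    F6E2M3(Vec<u8>),
    F6E3M2(Vec<u8>),
    F4(Vec<u8>),
    F8E8M0(Vec<u8>),
}

/// Stand-in for `convert_slice::<u8>`: copies the raw bytes into u8 storage.
fn convert_slice_u8(data: &[u8]) -> Storage {
    Storage::U8(data.to_vec())
}

/// Retag u8 storage as a dummy dtype. The bytes are moved, not reinterpreted.
/// Returns `None` if the input is not u8 storage.
fn retag(storage: Storage, dtype: DummyDType) -> Option<Storage> {
    let bytes = match storage {
        Storage::U8(bytes) => bytes,
        _ => return None,
    };
    Some(match dtype {
        DummyDType::F6E2M3 => Storage::F6E2M3(bytes),
        DummyDType::F6E3M2 => Storage::F6E3M2(bytes),
        DummyDType::F4 => Storage::F4(bytes),
        DummyDType::F8E8M0 => Storage::F8E8M0(bytes),
    })
}

fn main() {
    let storage = retag(convert_slice_u8(&[0x01, 0x03]), DummyDType::F4);
    assert_eq!(storage, Some(Storage::F4(vec![0x01, 0x03])));
}
```

If the retagging ends up being this small, it may indeed not need a dedicated fn.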
    let shape = view.shape();

    // Create storage with the appropriate dummy type variant
    let storage = match device {
    #[test]
    fn load_i8() {
        let bytes = b"8\0\0\0\0\0\0\0{\"x\":{\"dtype\":\"I8\",\"shape\":[2],\"data_offsets\":[0,2]}}   \x01\x03";
        std::fs::write("test_i8.safetensors", bytes).unwrap();
Not related to this PR, just noting down while I'm here: we should use temp files for these kinds of tests.
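A minimal, std-only sketch of the temp-file idea, in case it helps whoever picks this up. The helper names (`i8_safetensors_bytes`, `temp_path`) are hypothetical. The payload layout assumed here is an 8-byte little-endian header length, the JSON header padded with spaces to a multiple of 8 bytes, then the tensor data, which matches the leading `8` (ASCII 56) in the test above.

```rust
use std::fs;
use std::path::PathBuf;

/// Build the same payload as the `load_i8` test, computing the header
/// length instead of hard-coding it.
fn i8_safetensors_bytes() -> Vec<u8> {
    let mut header =
        br#"{"x":{"dtype":"I8","shape":[2],"data_offsets":[0,2]}}"#.to_vec();
    // Pad the JSON header with spaces up to a multiple of 8 bytes.
    while header.len() % 8 != 0 {
        header.push(b' ');
    }
    let mut bytes = (header.len() as u64).to_le_bytes().to_vec();
    bytes.extend_from_slice(&header);
    bytes.extend_from_slice(&[0x01, 0x03]); // two i8 values: 1 and 3
    bytes
}

/// Hypothetical helper: a per-process path under the OS temp directory,
/// so test runs do not write into the working directory.
fn temp_path(name: &str) -> PathBuf {
    std::env::temp_dir().join(format!("{}_{}", std::process::id(), name))
}

fn main() -> std::io::Result<()> {
    let path = temp_path("test_i8.safetensors");
    fs::write(&path, i8_safetensors_bytes())?;
    let read_back = fs::read(&path)?;
    // Clean up before asserting so a failed assertion does not leak the file.
    fs::remove_file(&path)?;
    assert_eq!(read_back, i8_safetensors_bytes());
    Ok(())
}
```

The `tempfile` crate would handle unique naming and cleanup on drop, at the cost of a dev-dependency.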
Addressed most of the review comments; left some as unresolved for posterity.
        let data = unary_map(storage, layout, |v| v as f64);
        Ok(Self::F64(data))
    }
    (Self::I32(storage), DType::F8E4M3) => {
I have an idea for how to reduce the massive size of this match.
Adding it to the ever-growing list of things to do :)
    S::F16(s) => self.f(s, d, l, S::F16)?,
    S::F32(s) => self.f(s, d, l, S::F32)?,
    S::F64(s) => self.f(s, d, l, S::F64)?,
    S::F8E4M3(s) => self.f(s, d, l, S::F8E4M3)?,
You resolved this, but it looks the same to me?
    (S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
    (S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
    (S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?,
    (S::F8E4M3(s1), S::F8E4M3(s2)) => self.f(s1, l1, s2, l2, d)?,
candle-core/src/op.rs
Outdated
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        (crate::cpu::erf::erf_f32(v * std::f32::consts::FRAC_1_SQRT_2) + 1.) * 0.5 * v
        Self::f64(v as f64) as f32
    }

    fn get_current_seed(&self) -> Result<u64> {
        crate::bail!("cannot get the CPU rng seed with get_current_seed")
I'll look into this later.
ivarflakstad
left a comment
Lgtm! 🎉
Let's wait for CI before we merge to make sure we haven't missed anything (ignore pyo3 though).
735c517 to
75e688c
Adds support for:
i32
i16
f6e2m3
f6e3m2
f4
f8e8m0
The last four (f6e2m3, f6e3m2, f4, f8e8m0) are "dummy" dtypes: each is essentially a typed bit bucket.
CPU compiles
CUDA compiles
Metal compiles
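To illustrate what "typed bit bucket" could mean in practice, here is a small sketch. The `DummyDType` enum and `packed_len_bytes` are hypothetical and not the types added in this PR. The bit widths follow from the format names (6, 6, 4, and 8 bits); tight packing of elements is an assumption made only for this example.

```rust
// Hypothetical enum: a dummy dtype only records the format and carries bytes,
// with no arithmetic defined on it.
#[derive(Debug, Clone, Copy, PartialEq)]
enum DummyDType {
    F6E2M3,
    F6E3M2,
    F4,
    F8E8M0,
}

impl DummyDType {
    /// Width of one element in bits, as implied by the format name.
    fn bits(self) -> usize {
        match self {
            DummyDType::F6E2M3 | DummyDType::F6E3M2 => 6,
            DummyDType::F4 => 4,
            DummyDType::F8E8M0 => 8,
        }
    }
}

/// Bytes needed for `elem_count` elements, assuming tight packing
/// and rounding up to a whole byte.
fn packed_len_bytes(dtype: DummyDType, elem_count: usize) -> usize {
    (elem_count * dtype.bits() + 7) / 8
}

fn main() {
    assert_eq!(packed_len_bytes(DummyDType::F4, 3), 2); // 12 bits
    assert_eq!(packed_len_bytes(DummyDType::F6E2M3, 4), 3); // 24 bits
    assert_eq!(packed_len_bytes(DummyDType::F8E8M0, 5), 5); // 40 bits
}
```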