diff --git a/candle-core/examples/metal_pool.rs b/candle-core/examples/metal_pool.rs
new file mode 100644
index 0000000000..1a62c2da17
--- /dev/null
+++ b/candle-core/examples/metal_pool.rs
@@ -0,0 +1,38 @@
+use candle_core::{DType, Device, Result, Tensor};
+
+#[cfg(feature = "metal")]
+fn run() -> Result<()> {
+    if !candle_core::utils::metal_is_available() {
+        eprintln!("Metal is not available on this system.");
+        return Ok(());
+    }
+
+    let device = Device::new_metal(0)?;
+
+    let input = Tensor::arange(0f32, 16f32, &Device::Cpu)?
+        .to_dtype(DType::F32)?
+        .to_device(&device)?;
+
+    // Allocate a 4 MB pool for intermediate activations.
+    let pooled = input.start_pool(4 * 1024 * 1024)?;
+    println!("{input}");
+    println!("{pooled}");
+    println!("{}", pooled.sin()?);
+
+    let logits = pooled.sin()?.mul(&pooled.cos()?)?;
+    let final_tensor = logits.tanh()?.leave_pool()?;
+
+    println!("final tensor: {:?}", final_tensor.to_vec1::<f32>()?);
+
+    Ok(())
+}
+
+#[cfg(not(feature = "metal"))]
+fn run() -> Result<()> {
+    eprintln!("Rebuild candle-core with the `metal` feature to run this example.");
+    Ok(())
+}
+
+fn main() -> Result<()> {
+    run()
+}
diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs
index c5cc96e208..8954f91f90 100644
--- a/candle-core/src/metal_backend/device.rs
+++ b/candle-core/src/metal_backend/device.rs
@@ -5,7 +5,7 @@ use std::collections::HashMap;
 use std::path::Path;
 use std::sync::{Arc, Mutex, RwLock};
 
-use super::MetalError;
+use super::{buffer_offset_for_output, current_pool, register_pool_allocation, MetalError};
 
 // iOS and macOS have different storage modes for shared buffers.
 // due to the GPU/CPU management differences.
@@ -242,8 +242,16 @@ impl MetalDevice {
         dtype: DType,
         name: &str,
     ) -> Result<Arc<Buffer>> {
-        let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
-        self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
+        let size_bytes = element_count * dtype.size_in_bytes();
+        if let Some(pool) = current_pool() {
+            let allocation = pool.allocate(size_bytes, dtype.size_in_bytes())?;
+            let buffer = Arc::clone(allocation.buffer());
+            register_pool_allocation(&buffer, allocation);
+            Ok(buffer)
+        } else {
+            let size = size_bytes as NSUInteger;
+            self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
+        }
     }
 
     pub fn new_buffer_private(
@@ -252,8 +260,7 @@ impl MetalDevice {
         dtype: DType,
         name: &str,
     ) -> Result<Arc<Buffer>> {
-        let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
-        self.allocate_buffer(size, metal::MTLResourceOptions::StorageModePrivate, name)
+        self.new_buffer(element_count, dtype, name)
     }
 
     /// Creates a new buffer (not necessarily zeroed).
@@ -288,6 +295,10 @@ impl MetalDevice {
         Ok(new_buffer)
     }
 
+    pub fn buffer_offset<'a>(&self, buffer: &'a Arc<Buffer>) -> super::BufferOffset<'a> {
+        buffer_offset_for_output(buffer)
+    }
+
     pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
         let buffer = self.allocate_buffer(
             size_in_bytes as NSUInteger,
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index 47c44a8711..5ce847b0b1 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -6,19 +6,85 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{CpuStorage, CpuStorageRef, DType, Error, Layout, Result, Shape};
 use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
 use metal::{Buffer, NSUInteger};
+use std::cell::RefCell;
 use std::collections::HashMap;
 use std::ffi::c_void;
-use std::sync::{Arc, Mutex, PoisonError, RwLock, TryLockError};
+use std::sync::{Arc, Mutex, OnceLock, PoisonError, RwLock, TryLockError};
 
 mod device;
+mod pool;
 pub use device::{DeviceId, MetalDevice, SHARED_BUFFER_STORAGE_MODE};
+pub use pool::{MetalPoolAllocation, MetalTensorPool};
 
-pub fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> {
+thread_local! {
+    static POOL_STACK: RefCell<Vec<Option<Arc<MetalTensorPool>>>> = RefCell::new(Vec::new());
+}
+
+pub(crate) struct PoolContextGuard;
+
+impl Drop for PoolContextGuard {
+    fn drop(&mut self) {
+        POOL_STACK.with(|stack| {
+            stack.borrow_mut().pop();
+        });
+    }
+}
+
+pub(crate) fn push_pool_context(pool: Option<Arc<MetalTensorPool>>) -> PoolContextGuard {
+    POOL_STACK.with(|stack| stack.borrow_mut().push(pool));
+    PoolContextGuard
+}
+
+fn current_pool() -> Option<Arc<MetalTensorPool>> {
+    POOL_STACK.with(|stack| stack.borrow().last().cloned().flatten())
+}
+
+fn pool_registry() -> &'static Mutex<HashMap<usize, Vec<Arc<MetalPoolAllocation>>>> {
+    static REGISTRY: OnceLock<Mutex<HashMap<usize, Vec<Arc<MetalPoolAllocation>>>>> =
+        OnceLock::new();
+    REGISTRY.get_or_init(|| Mutex::new(HashMap::new()))
+}
+
+fn register_pool_allocation(buffer: &Arc<Buffer>, allocation: Arc<MetalPoolAllocation>) {
+    let ptr = Arc::as_ptr(buffer) as usize;
+    if let Ok(mut registry) = pool_registry().lock() {
+        registry.entry(ptr).or_default().push(allocation);
+    }
+}
+
+fn take_pool_allocation(buffer: &Arc<Buffer>) -> Option<Arc<MetalPoolAllocation>> {
+    let ptr = Arc::as_ptr(buffer) as usize;
+    pool_registry().lock().ok().and_then(|mut registry| {
+        let entry = registry.get_mut(&ptr)?;
+        let allocation = entry.pop();
+        if entry.is_empty() {
+            registry.remove(&ptr);
+        }
+        allocation
+    })
+}
+
+fn peek_pool_allocation(buffer: &Arc<Buffer>) -> Option<Arc<MetalPoolAllocation>> {
+    let ptr = Arc::as_ptr(buffer) as usize;
+    pool_registry()
+        .lock()
+        .ok()
+        .and_then(|registry| registry.get(&ptr).and_then(|allocs| allocs.last().cloned()))
+}
+
+fn buffer_offset_for_output(buffer: &Arc<Buffer>) -> BufferOffset<'_> {
+    let offset = peek_pool_allocation(buffer)
+        .map(|alloc| alloc.offset())
+        .unwrap_or(0);
     BufferOffset {
         buffer,
-        offset_in_bytes: l.start_offset() * dtype.size_in_bytes(),
+        offset_in_bytes: offset,
     }
 }
+
+pub fn buffer_o<'a>(storage: &'a MetalStorage, l: &Layout) -> BufferOffset<'a> {
+    storage.buffer_offset(l)
+}
 
 /// Simple way to catch lock error without
 /// depending on T
 #[derive(thiserror::Error, Debug)]
@@ -77,6 +143,8 @@ pub struct MetalStorage {
     count: usize,
     /// The dtype is kept since buffers are untyped.
     dtype: DType,
+    /// Byte offset of this storage within its underlying buffer (nonzero for pool-backed tensors).
+    offset_bytes: usize,
+    /// Keeps the pool sub-allocation alive; the range is returned to the pool on drop.
+    pool_allocation: Option<Arc<MetalPoolAllocation>>,
 }
 
 impl BackendStorage for MetalStorage {
@@ -113,6 +181,7 @@ impl BackendStorage for MetalStorage {
     }
 
     fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
+        let _guard = Self::pool_guard(&[self])?;
         let device = self.device().clone();
 
         let shape = layout.shape();
@@ -121,7 +190,7 @@
         let buffer = device.new_buffer(el, self.dtype, "affine")?;
         let command_buffer = self.device.command_buffer()?;
-        let src = buffer_o(&self.buffer, layout, dtype);
+        let src = buffer_o(self, layout);
         if layout.is_contiguous() {
             let name = match self.dtype {
                 DType::F32 => "affine_f32",
@@ -138,7 +207,7 @@
                 name,
                 el,
                 src,
-                &buffer,
+                buffer_offset_for_output(&buffer),
                 mul as f32,
                 add as f32,
             )
@@ -158,7 +227,7 @@
                 layout.dims(),
                 src,
                 layout.stride(),
-                &buffer,
+                buffer_offset_for_output(&buffer),
                 mul as f32,
                 add as f32,
             )
@@ -168,6 +237,7 @@
     }
 
     fn powf(&self, layout: &Layout, pow: f64) -> Result<Self> {
+        let _guard = Self::pool_guard(&[self])?;
         let device = self.device().clone();
 
         let shape = layout.shape();
@@ -176,7 +246,7 @@
         let buffer = device.new_buffer(el, self.dtype, "powf")?;
         let command_buffer = self.device.command_buffer()?;
-        let src = buffer_o(&self.buffer, layout, dtype);
+        let src = buffer_o(self, layout);
         if layout.is_contiguous() {
             let name = match self.dtype {
                 DType::F32 => "powf_f32",
@@ -191,7 +261,7 @@
                 name,
                 el,
                 src,
-                &buffer,
+                buffer_offset_for_output(&buffer),
                 pow as f32,
             )
             .map_err(MetalError::from)?;
@@ -210,7 +280,7 @@
                 layout.dims(),
                 src,
                 layout.stride(),
-                &buffer,
+                buffer_offset_for_output(&buffer),
                 pow as f32,
             )
             .map_err(MetalError::from)?;
@@ -219,6 +289,7 @@
     }
 
     fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
+        let _guard = Self::pool_guard(&[self])?;
         let device = self.device().clone();
 
         let shape = layout.shape();
@@ -227,7 +298,7 @@
         let buffer = device.new_buffer(el, self.dtype, "elu")?;
         let command_buffer = self.device.command_buffer()?;
-        let src = buffer_o(&self.buffer, layout, self.dtype);
+        let src = buffer_o(self, layout);
         if layout.is_contiguous() {
             let name = match self.dtype {
                 DType::F32 => "elu_f32",
@@ -242,7 +313,7 @@
                 name,
                 el,
                 src,
-                &buffer,
+                buffer_offset_for_output(&buffer),
                 alpha as f32,
             )
             .map_err(MetalError::from)?;
@@ -261,7 +332,7 @@
                 layout.dims(),
                 src,
                 layout.stride(),
-                &buffer,
+                buffer_offset_for_output(&buffer),
                 alpha as f32,
             )
             .map_err(MetalError::from)?;
@@ -270,6 +341,7 @@
     }
 
     fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
+        let _guard = Self::pool_guard(&[self])?;
         let device = self.device.clone();
 
         let src_stride = layout.stride();
@@ -335,7 +407,7 @@
         let dtype = if return_index { DType::U32 } else { self.dtype };
         let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
         let command_buffer = self.device.command_buffer()?;
-        let src = buffer_o(&self.buffer, layout, self.dtype);
+        let src = buffer_o(self, layout);
         candle_metal_kernels::call_reduce_contiguous(
             &device.device,
             &command_buffer,
@@ -344,7 +416,7 @@
             src_dims,
             dst_el,
             src,
-            &buffer,
+            buffer_offset_for_output(&buffer),
         )
         .map_err(MetalError::from)?;
 
@@ -390,7 +462,7 @@
         let dtype = if return_index { DType::U32 } else { self.dtype };
         let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
         let command_buffer = self.device.command_buffer()?;
-        let src = buffer_o(&self.buffer, layout, self.dtype);
+        let src = buffer_o(self, layout);
         candle_metal_kernels::call_reduce_strided(
             &device.device,
             &command_buffer,
@@ -400,7 +472,7 @@
             &stride,
             dst_el,
             src,
-            &buffer,
+            buffer_offset_for_output(&buffer),
         )
         .map_err(MetalError::from)?;
 
@@ -408,6 +480,7 @@
     }
 
     fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
+        let _guard = Self::pool_guard(&[self, rhs])?;
         let name = match op {
             CmpOp::Eq => "eq",
             CmpOp::Ne => "ne",
@@ -432,7 +505,7 @@
         let el_count = shape.elem_count();
         let command_buffer = device.command_buffer()?;
         command_buffer.set_label("const-set");
-        let dst = buffer_o(&self_.buffer, l, self_.dtype);
+        let dst = buffer_o(self_, l);
 
         match (el_count % 2, dtype, l.is_contiguous()) {
             (0, DType::BF16 | DType::F16, true) => {
@@ -527,12 +600,13 @@
     }
 
     fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
+        let _guard = Self::pool_guard(&[self])?;
         let device = self.device();
         let shape = layout.shape();
         let el_count = shape.elem_count();
         let buffer = device.new_buffer(el_count, dtype, "todtype")?;
         let command_buffer = device.command_buffer()?;
-        let src = buffer_o(&self.buffer, layout, self.dtype);
+        let src = buffer_o(self, layout);
         if layout.is_contiguous() {
             let kernel_name = match (self.dtype, dtype) {
                 (DType::U32, DType::BF16) => "cast_u32_bf16",
@@ -582,7 +656,7 @@
                 kernel_name,
                 el_count,
                 src,
-                &buffer,
+                buffer_offset_for_output(&buffer),
             )
             .map_err(MetalError::from)?;
         } else {
@@ -635,7 +709,7 @@
                 layout.dims(),
                 src,
                 layout.stride(),
-                &buffer,
+                buffer_offset_for_output(&buffer),
             )
             .map_err(MetalError::from)?;
         }
@@ -644,6 +718,7 @@
     }
 
     fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
+        let _guard = Self::pool_guard(&[self])?;
         let device = self.device();
         let dtype = self.dtype;
         let shape = layout.shape();
@@ -651,7 +726,7 @@
         let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
         let command_buffer = device.command_buffer()?;
         command_buffer.set_label(B::KERNEL);
-        let src = buffer_o(&self.buffer, layout, self.dtype);
+        let src = buffer_o(self, layout);
 
         match (el_count % 2, dtype, layout.is_contiguous()) {
             (0, DType::BF16 | DType::F16, true) => {
@@ -728,7 +803,7 @@
                     kernel_name,
                     el_count,
                     src,
-                    &buffer,
+                    buffer_offset_for_output(&buffer),
                 )
                 .map_err(MetalError::from)?;
             }
@@ -804,7 +879,7 @@
                     kernel_name,
                     el_count,
                     src,
-                    &buffer,
+                    buffer_offset_for_output(&buffer),
                 )
                 .map_err(MetalError::from)?;
             }
@@ -869,7 +944,7 @@
                         crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
                     }
                 };
-                let dst = BufferOffset::zero_offset(&buffer);
+                let dst = buffer_offset_for_output(&buffer);
                 candle_metal_kernels::call_unary_strided(
                     &device.device,
                     &command_buffer,
@@ -904,6 +979,7 @@
         f: &Self,
         f_l: &Layout,
     ) -> Result<Self> {
+        let _guard = Self::pool_guard(&[self, t, f])?;
         let device = self.device.clone();
         let shape = t_l.shape();
         let dims = shape.dims();
@@ -928,9 +1004,9 @@
             (DType::U8, DType::U8) => "where_u8_u8",
             (left, right) => crate::bail!("Metal where_cond {left:?} {right:?} not implemented"),
         };
-        let src = buffer_o(&self.buffer, layout, self.dtype);
-        let t = buffer_o(&t.buffer, t_l, t.dtype);
-        let f = buffer_o(&f.buffer, f_l, f.dtype);
+        let src = buffer_o(self, layout);
+        let t = buffer_o(t, t_l);
+        let f = buffer_o(f, f_l);
         candle_metal_kernels::call_where_cond_strided(
             &device.device,
             &command_buffer,
@@ -943,7 +1019,7 @@
             t_l.stride(),
             f,
             f_l.stride(),
-            &buffer,
+            buffer_offset_for_output(&buffer),
         )
         .map_err(MetalError::from)?;
         Ok(Self::new(buffer, device, el, dtype))
     }
 
     fn conv1d(
         &self,
         layout: &Layout,
         kernel: &Self,
         kernel_l: &Layout,
         params: &ParamsConv1D,
     ) -> Result<Self> {
+        let _guard = Self::pool_guard(&[self, kernel])?;
         let device = self.device().clone();
         let shape = layout.shape();
         let dims = shape.dims();
@@ -975,7 +1052,7 @@
             DType::F32 => "im2col1d_f32",
             dtype => crate::bail!("Metal conv1d {dtype:?} not implemented"),
         };
-        let src = buffer_o(&self.buffer, layout, self.dtype);
+        let src = buffer_o(self, layout);
         candle_metal_kernels::call_im2col1d_strided(
             &self.device.device,
             &command_buffer,
@@ -985,15 +1062,10 @@
             strides,
             (k_size, stride, padding, dilation),
             src,
-            &dst,
+            buffer_offset_for_output(&dst),
         )
         .map_err(MetalError::from)?;
-        let col = Self {
-            buffer: dst,
-            device,
-            count: dst_el,
-            dtype: self.dtype,
-        };
+        let col = Self::new(dst, device, dst_el, self.dtype);
         let l_out = params.l_out();
         let b = params.b_size;
         let n = params.c_out;
@@ -1027,6 +1099,7 @@
         k_layout: &Layout,
         params: &ParamsConvTranspose1D,
     ) -> Result<Self> {
+        let _guard = Self::pool_guard(&[self, k])?;
         const USE_COL2IM_CONV1D_TR: bool = true;
 
         let can_use_col2im = k_layout.is_contiguous()
@@ -1084,8 +1157,8 @@
                 &[b_size, l_in, c_out, k_size],
                 params.k_size,
                 params.stride,
-                BufferOffset::zero_offset(&col.buffer),
-                &buffer,
+                col.base_offset(),
+                buffer_offset_for_output(&buffer),
             )
             .map_err(MetalError::from)?;
             buffer
@@ -1123,7 +1196,7 @@
                 layout.start_offset() * self.dtype.size_in_bytes(),
                 &k.buffer,
                 k_layout.start_offset() * k.dtype.size_in_bytes(),
-                &buffer,
+                buffer_offset_for_output(&buffer),
             )
             .map_err(MetalError::from)?;
             buffer
@@ -1138,6 +1211,7 @@
         kernel_l: &Layout,
         params: &ParamsConv2D,
     ) -> Result<Self> {
+        let _guard = Self::pool_guard(&[self, kernel])?;
         let device = self.device().clone();
         let shape = layout.shape();
         let dims = shape.dims();
@@ -1165,7 +1239,7 @@
             DType::U32 => "im2col_u32",
             dtype => crate::bail!("Metal conv2d {dtype:?} not implemented"),
         };
-        let src = buffer_o(&self.buffer, layout, self.dtype);
+        let src = buffer_o(self, layout);
         candle_metal_kernels::call_im2col_strided(
             &self.device.device,
             &command_buffer,
@@ -1175,15 +1249,10 @@
             layout.stride(),
             (h_k, w_k, stride, padding, dilation),
             src,
-            &dst,
+            buffer_offset_for_output(&dst),
         )
         .map_err(MetalError::from)?;
-        let col = Self {
-            buffer: dst,
-            device,
-            count: dst_el,
-            dtype: self.dtype,
-        };
+        let col = Self::new(dst, device, dst_el, self.dtype);
         let h_out = params.out_h();
         let w_out = params.out_w();
         let b = params.b_size;
@@ -1220,6 +1289,7 @@
         kernel_l: &Layout,
         params: &ParamsConvTranspose2D,
     ) -> Result<Self> {
+        let _guard = Self::pool_guard(&[self, kernel])?;
         // Kernel shape: (c_in_k, c_out, h_k, w_k)
         // Input shape: (b_size, c_in, h_in, w_in)
         let (out_w, out_h) = (params.out_w(), params.out_h());
@@ -1271,7 +1341,7 @@
             },
             &self.buffer,
             &kernel.buffer,
-            &buffer,
+            buffer_offset_for_output(&buffer),
         )
         .map_err(MetalError::from)?;
         Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
@@ -1283,6 +1353,7 @@
         (w_k, h_k): (usize, usize),
         (w_stride, h_stride): (usize, usize),
     ) -> Result<Self> {
+        let _guard = Self::pool_guard(&[self])?;
         let shape = inp_l.shape();
         let (b_size, channels, width, height) = shape.dims4()?;
         let strides = inp_l.stride();
@@ -1313,7 +1384,7 @@
             w_stride,
             h_stride,
             &self.buffer,
-            &buffer,
+            buffer_offset_for_output(&buffer),
         )
         .map_err(MetalError::from)?;
         Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
@@ -1325,6 +1396,7 @@
         (w_k, h_k): (usize, usize),
         (w_stride, h_stride): (usize, usize),
     ) -> Result<Self> {
+        let _guard = Self::pool_guard(&[self])?;
         let shape = inp_l.shape();
         let (b_size, channels, width, height) = shape.dims4()?;
         let strides = inp_l.stride();
@@ -1355,7 +1427,7 @@
             w_stride,
             h_stride,
             &self.buffer,
-            &buffer,
+            buffer_offset_for_output(&buffer),
         )
         .map_err(MetalError::from)?;
         Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
@@ -1366,6 +1438,7 @@
     }
 
     fn upsample_nearest2d(&self, inp_l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
+        let _guard = Self::pool_guard(&[self])?;
         // let inp = &inp.slice(inp_l.start_offset()..);
         let shape = inp_l.shape();
         let dims = shape.dims();
@@ -1387,7 +1460,7 @@
             .device
             .new_buffer(dst_el, self.dtype, "upsample_nearest2d")?;
         let command_buffer = self.device.command_buffer()?;
-        let src = buffer_o(&self.buffer, inp_l, self.dtype);
+        let src = buffer_o(self, inp_l);
         candle_metal_kernels::call_upsample_nearest_2d(
             &self.device.device,
             &command_buffer,
@@ -1398,13 +1471,14 @@
             out_w,
             out_h,
             src,
-            &buffer,
+            buffer_offset_for_output(&buffer),
         )
         .map_err(MetalError::from)?;
         Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
     }
 
     fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
+        let _guard = Self::pool_guard(&[self, ids])?;
         if !ids_l.is_contiguous() {
             return Err(crate::Error::RequiresContiguous { op: "gather" }.bt());
         };
@@ -1427,8 +1501,8 @@
             (left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"),
         };
         let command_buffer = self.device.command_buffer()?;
-        let src = buffer_o(&self.buffer, src_l, dtype);
-        let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
+        let src = buffer_o(self, src_l);
+        let ids = buffer_o(ids, ids_l);
         candle_metal_kernels::call_gather(
             &device.device,
             &command_buffer,
@@ -1439,7 +1513,7 @@
             dim,
             src,
             ids,
-            &buffer,
+            buffer_offset_for_output(&buffer),
         )
         .map_err(MetalError::from)?;
         Ok(Self::new(buffer, device.clone(), dst_el, dtype))
@@ -1475,9 +1549,9 @@
             })?,
         };
         let command_buffer = self.device.command_buffer()?;
-        let dst = buffer_o(&self.buffer, l, self.dtype);
-        let src = buffer_o(&src.buffer, src_l, src.dtype);
-        let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
+        let dst = buffer_o(self, l);
+        let src = buffer_o(src, src_l);
+        let ids = buffer_o(ids, ids_l);
         candle_metal_kernels::call_scatter(
             &self.device.device,
             &command_buffer,
@@ -1524,9 +1598,9 @@
             })?,
         };
         let command_buffer = self.device.command_buffer()?;
-        let dst = buffer_o(&self.buffer, l, self.dtype);
-        let src = buffer_o(&src.buffer, src_l, src.dtype);
-        let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
+        let dst = buffer_o(self, l);
+        let src = buffer_o(src, src_l);
+        let ids = buffer_o(ids, ids_l);
         candle_metal_kernels::call_scatter(
             &self.device.device,
             &command_buffer,
@@ -1544,6 +1618,7 @@
     }
 
     fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
+        let _guard = Self::pool_guard(&[self, ids])?;
         if !ids_l.is_contiguous() {
             crate::bail!("Metal index_select requires contiguous ids")
         }
@@ -1581,8 +1656,8 @@
             }
         };
         let command_buffer = self.device.command_buffer()?;
-        let src = buffer_o(&self.buffer, src_l, dtype);
-        let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
+        let src = buffer_o(self, src_l);
+        let ids = buffer_o(ids, ids_l);
         candle_metal_kernels::call_index_select(
             &device.device,
             &command_buffer,
@@ -1596,7 +1671,7 @@
             src_l.stride(),
             src,
             ids,
-            &buffer,
+            buffer_offset_for_output(&buffer),
         )
         .map_err(MetalError::from)?;
         Ok(Self::new(buffer, device.clone(), dst_el, dtype))
@@ -1645,8 +1720,8 @@
             })?,
         };
         let command_buffer = self.device.command_buffer()?;
-        let src = buffer_o(&src.buffer, src_l, src.dtype);
-        let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
+        let src = buffer_o(src, src_l);
+        let ids = buffer_o(ids, ids_l);
         candle_metal_kernels::call_index_add(
             &self.device.device,
             &command_buffer,
@@ -1658,7 +1733,7 @@
             dim,
             src,
             ids,
-            &acc.buffer,
+            buffer_o(&acc, l),
         )
         .map_err(MetalError::from)?;
         Ok(acc)
@@ -1671,6 +1746,7 @@
         lhs_l: &Layout,
         rhs_l: &Layout,
     ) -> Result<Self> {
+        let _guard = Self::pool_guard(&[self, rhs])?;
         let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
         let command_buffer = self.device.command_buffer()?;
         command_buffer.set_label("matmul");
@@ -1696,7 +1772,7 @@
             rhs_l.stride(),
             rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
             &rhs.buffer,
-            &buffer,
+            buffer_offset_for_output(&buffer),
         )
         .map_err(MetalError::from)?;
 
@@ -1760,8 +1836,8 @@
                 d2,
                 src_s,
                 dst_s,
-                src_o * self.dtype.size_in_bytes(),
-                dst_o * self.dtype.size_in_bytes(),
+                self.offset_bytes + src_o * self.dtype.size_in_bytes(),
+                dst.offset_bytes + dst_o * self.dtype.size_in_bytes(),
             )
             .map_err(MetalError::from)?;
             command_buffer.set_label("copy2d");
@@ -1775,9 +1851,11 @@
             command_buffer.set_label("copy_contiguous");
             let blit = command_buffer.new_blit_command_encoder();
             blit.set_label("copy_contiguous");
-            let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
+            let src_offset = (self.offset_bytes + src_l.start_offset() * self.dtype.size_in_bytes())
+                as NSUInteger;
             let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger;
-            let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
+            let dst_offset =
+                (dst.offset_bytes + dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
             blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length);
             blit.end_encoding();
         } else {
@@ -1795,10 +1873,10 @@
                 DType::U8 => candle_metal_kernels::unary::strided::copy::U8,
                 dtype => crate::bail!("Metal copy_strided {dtype:?} not implemented"),
             };
-            let src = buffer_o(&self.buffer, src_l, self.dtype);
+            let src = buffer_o(self, src_l);
             let dst = BufferOffset {
                 buffer: &dst.buffer,
-                offset_in_bytes: dst_offset * dst.dtype.size_in_bytes(),
+                offset_in_bytes: dst.offset_bytes + dst_offset * dst.dtype.size_in_bytes(),
             };
             candle_metal_kernels::call_unary_strided(
                 &self.device.device,
@@ -1819,14 +1897,61 @@
 
 impl MetalStorage {
     pub fn new(buffer: Arc<Buffer>, device: MetalDevice, count: usize, dtype: DType) -> Self {
+        let pool_allocation = take_pool_allocation(&buffer);
+        let offset_bytes = pool_allocation
+            .as_ref()
+            .map(|alloc| alloc.offset())
+            .unwrap_or(0);
         Self {
             buffer,
             device,
             count,
             dtype,
+            offset_bytes,
+            pool_allocation,
         }
     }
 
+    pub fn pool(&self) -> Option<Arc<MetalTensorPool>> {
+        self.pool_allocation
+            .as_ref()
+            .map(|allocation| allocation.pool())
+    }
+
+    fn base_offset(&self) -> BufferOffset<'_> {
+        BufferOffset {
+            buffer: &self.buffer,
+            offset_in_bytes: self.offset_bytes,
+        }
+    }
+
+    fn buffer_offset<'a>(&'a self, layout: &Layout) -> BufferOffset<'a> {
+        let mut base = self.base_offset();
+        base.offset_in_bytes += layout.start_offset() * self.dtype.size_in_bytes();
+        base
+    }
+
+    fn determine_pool(storages: &[&Self]) -> Result<Option<Arc<MetalTensorPool>>> {
+        let mut pool: Option<Arc<MetalTensorPool>> = None;
+        for storage in storages {
+            if let Some(candidate) = storage.pool() {
+                if let Some(existing) = &pool {
+                    if !Arc::ptr_eq(existing, &candidate) {
+                        crate::bail!("Cannot operate on tensors from different pools");
+                    }
+                } else {
+                    pool = Some(candidate);
+                }
+            }
+        }
+        Ok(pool)
+    }
+
+    pub(crate) fn pool_guard(storages: &[&Self]) -> Result<PoolContextGuard> {
+        let pool = Self::determine_pool(storages)?;
+        Ok(push_pool_context(pool))
+    }
+
     pub fn buffer(&self) -> &Buffer {
         &self.buffer
     }
@@ -1838,12 +1963,13 @@
         lhs_l: &Layout,
         rhs_l: &Layout,
     ) -> Result<Self> {
+        let _guard = Self::pool_guard(&[self, rhs])?;
         let device = self.device();
         let shape = lhs_l.shape();
         let el_count = shape.elem_count();
         let command_buffer = device.command_buffer()?;
-        let lhs = buffer_o(&self.buffer, lhs_l, self.dtype);
-        let rhs = buffer_o(&rhs.buffer, rhs_l, rhs.dtype);
+        let lhs = buffer_o(self, lhs_l);
+        let rhs = buffer_o(rhs, rhs_l);
         let (buffer, dtype) = if lhs_l.is_contiguous() && rhs_l.is_contiguous() && &op[..1] != "b" {
             use candle_metal_kernels::binary::contiguous;
 
@@ -1927,7 +2053,7 @@
                 el_count,
                 lhs,
                 rhs,
-                &buffer,
+                buffer_offset_for_output(&buffer),
             )
             .map_err(MetalError::from)?;
             (buffer, dtype)
@@ -2028,7 +2154,7 @@
                 lhs_l.stride(),
                 rhs,
                 rhs_l.stride(),
-                &buffer,
+                buffer_offset_for_output(&buffer),
             )
             .map_err(MetalError::from)?;
             (buffer, dtype)
@@ -2046,7 +2172,13 @@
             command_buffer.set_label("to_cpu");
             let blit = command_buffer.new_blit_command_encoder();
             blit.set_label("blit_to_cpu");
-            blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, size);
+            blit.copy_from_buffer(
+                &self.buffer,
+                self.offset_bytes as NSUInteger,
+                &buffer,
+                0,
+                size,
+            );
             blit.end_encoding();
         }
         self.device.wait_until_completed()?;
diff --git a/candle-core/src/metal_backend/pool.rs b/candle-core/src/metal_backend/pool.rs
new file mode 100644
index 0000000000..24d9f252d8
--- /dev/null
+++ b/candle-core/src/metal_backend/pool.rs
@@ -0,0 +1,192 @@
+use crate::Result;
+use metal::Buffer;
+use std::sync::{Arc, Mutex};
+
+use super::{MetalDevice, MetalError};
+
+const ALIGNMENT: usize = 256;
+
+fn align_up(value: usize, alignment: usize) -> usize {
+    if alignment == 0 {
+        return value;
+    }
+    ((value + alignment - 1) / alignment) * alignment
+}
+
+#[derive(Debug)]
+struct PoolState {
+    free: Vec<(usize, usize)>,
+    used: usize,
+}
+
+impl PoolState {
+    fn new(capacity: usize) -> Self {
+        Self {
+            free: vec![(0, capacity)],
+            used: 0,
+        }
+    }
+
+    fn allocate(&mut self, size: usize, alignment: usize, capacity: usize) -> Option<usize> {
+        let alignment = alignment.max(ALIGNMENT);
+        for idx in 0..self.free.len() {
+            let (start, len) = self.free[idx];
+            let aligned_start = align_up(start, alignment);
+            if aligned_start >= start + len {
+                continue;
+            }
+            let padding = aligned_start - start;
+            let available = len.saturating_sub(padding);
+            if available < size {
+                continue;
+            }
+            let remaining = available - size;
+            self.free.remove(idx);
+            if padding > 0 {
+                self.free.insert(idx, (start, padding));
+            }
+            if remaining > 0 {
+                self.free.insert(
+                    idx + (padding > 0) as usize,
+                    (aligned_start + size, remaining),
+                );
+            }
+            self.used += size;
+            debug_assert!(self.used <= capacity);
+            return Some(aligned_start);
+        }
+        None
+    }
+
+    fn free(&mut self, offset: usize, size: usize) {
+        if size == 0 {
+            return;
+        }
+        self.used = self.used.saturating_sub(size);
+        let mut insert_pos = 0;
+        while insert_pos < self.free.len() && self.free[insert_pos].0 < offset {
+            insert_pos += 1;
+        }
+        self.free.insert(insert_pos, (offset, size));
+        // Merge with previous block if adjacent.
+        if insert_pos > 0 {
+            if let Some(merged) = try_merge(self.free[insert_pos - 1], self.free[insert_pos]) {
+                self.free[insert_pos - 1] = merged;
+                self.free.remove(insert_pos);
+                insert_pos -= 1;
+            }
+        }
+        // Merge with next block if adjacent.
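+        // (The free list stays sorted by offset, so only immediate neighbors can be adjacent.)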
+        if insert_pos + 1 < self.free.len() {
+            if let Some(merged) = try_merge(self.free[insert_pos], self.free[insert_pos + 1]) {
+                self.free[insert_pos] = merged;
+                self.free.remove(insert_pos + 1);
+            }
+        }
+    }
+}
+
+fn try_merge(lhs: (usize, usize), rhs: (usize, usize)) -> Option<(usize, usize)> {
+    if lhs.0 + lhs.1 == rhs.0 {
+        Some((lhs.0, lhs.1 + rhs.1))
+    } else {
+        None
+    }
+}
+
+#[derive(Debug)]
+pub struct MetalTensorPool {
+    device: MetalDevice,
+    capacity: usize,
+    buffer: Arc<Buffer>,
+    state: Mutex<PoolState>,
+}
+
+impl MetalTensorPool {
+    pub fn new(device: MetalDevice, capacity: usize) -> Result<Arc<Self>> {
+        if capacity == 0 {
+            crate::bail!("Pool capacity must be greater than 0");
+        }
+        let buffer = Arc::new(device.device().new_buffer(
+            capacity as u64,
+            metal::MTLResourceOptions::StorageModePrivate,
+        ));
+        Ok(Arc::new(Self {
+            device,
+            capacity,
+            buffer,
+            state: Mutex::new(PoolState::new(capacity)),
+        }))
+    }
+
+    pub fn device(&self) -> &MetalDevice {
+        &self.device
+    }
+
+    pub fn buffer(&self) -> &Arc<Buffer> {
+        &self.buffer
+    }
+
+    pub fn allocate(
+        self: &Arc<Self>,
+        size_in_bytes: usize,
+        alignment: usize,
+    ) -> Result<Arc<MetalPoolAllocation>> {
+        if size_in_bytes == 0 {
+            crate::bail!("Cannot allocate zero bytes from pool");
+        }
+        let mut state = self.state.lock().map_err(MetalError::from)?;
+        if let Some(offset) = state.allocate(size_in_bytes, alignment, self.capacity) {
+            Ok(Arc::new(MetalPoolAllocation {
+                pool: Arc::clone(self),
+                offset,
+                size: size_in_bytes,
+            }))
+        } else {
+            crate::bail!(
+                "Metal tensor pool exhausted: requested {size_in_bytes} bytes, capacity {} bytes",
+                self.capacity
+            )
+        }
+    }
+
+    fn release(&self, offset: usize, size: usize) {
+        if size == 0 {
+            return;
+        }
+        if let Ok(mut state) = self.state.lock() {
+            state.free(offset, size);
+        }
+    }
+}
+
+#[derive(Debug)]
+pub struct MetalPoolAllocation {
+    pool: Arc<MetalTensorPool>,
+    offset: usize,
+    size: usize,
+}
+
+impl MetalPoolAllocation {
+    pub fn buffer(&self) -> &Arc<Buffer> {
+        &self.pool.buffer
+    }
+
+    pub fn offset(&self) -> usize {
+        self.offset
+    }
+
+    pub fn size(&self) -> usize {
+        self.size
+    }
+
+    pub fn pool(&self) -> Arc<MetalTensorPool> {
+        Arc::clone(&self.pool)
+    }
+}
+
+impl Drop for MetalPoolAllocation {
+    fn drop(&mut self) {
+        self.pool.release(self.offset, self.size);
+    }
+}
diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs
index 14ace645da..ae5c8909fb 100644
--- a/candle-core/src/sort.rs
+++ b/candle-core/src/sort.rs
@@ -217,7 +217,8 @@ impl crate::CustomOp1 for ArgSort {
         let el = layout.shape().elem_count();
         let ncols = self.last_dim;
         let nrows = el / ncols;
-        let src = crate::metal_backend::buffer_o(storage.buffer(), layout, storage.dtype());
+        let _guard = crate::MetalStorage::pool_guard(&[storage])?;
+        let src = crate::metal_backend::buffer_o(storage, layout);
         let dst = device.new_buffer(el, DType::U32, "asort")?;
         let mut ncols_pad = 1;
         while ncols_pad < ncols {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 791d5db345..7d54d6b18e 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,6 +1,8 @@
 //! Tensors are N-dimensional matrixes of elements using a single data type.
 #![allow(clippy::redundant_closure_call)]
 use crate::backend::{BackendDevice, BackendStorage};
+#[cfg(feature = "metal")]
+use crate::metal_backend::{push_pool_context, MetalTensorPool};
 use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
 use crate::scalar::TensorOrScalar;
 use crate::shape::{Dim, Dims, ShapeWithOneHole};
@@ -2232,6 +2234,94 @@ impl Tensor {
         Ok(Tensor(Arc::new(tensor_)))
     }
 
+    pub fn start_pool(&self, size_in_bytes: usize) -> Result<Self> {
+        if size_in_bytes == 0 {
+            bail!("Pool size must be greater than 0");
+        }
+        #[cfg(not(feature = "metal"))]
+        {
+            let _ = size_in_bytes;
+            bail!("start_pool requires candle to be built with the `metal` feature")
+        }
+        #[cfg(feature = "metal")]
+        {
+            let device = match &self.device {
+                Device::Metal(device) => device,
+                _ => bail!("start_pool is only supported for tensors on Metal devices"),
+            };
+
+            let pool = MetalTensorPool::new(device.clone(), size_in_bytes)?;
+            let guard = push_pool_context(Some(pool.clone()));
+            let mut pooled_storage = unsafe { device.alloc_uninit(self.shape(), self.dtype())? };
+            drop(guard);
+
+            {
+                let src_storage = self.storage();
+                match &*src_storage {
+                    Storage::Metal(src) => {
+                        src.copy_strided_src(&mut pooled_storage, 0, self.layout())?;
+                    }
+                    _ => unreachable!("Metal tensor expected"),
+                }
+            }
+
+            let storage = Storage::Metal(pooled_storage);
+            let tensor_ = Tensor_ {
+                id: TensorId::new(),
+                storage: Arc::new(RwLock::new(storage)),
+                layout: self.layout.clone(),
+                op: BackpropOp::new1(self, Op::Copy),
+                is_variable: false,
+                dtype: self.dtype,
+                device: self.device.clone(),
+            };
+            Ok(Tensor(Arc::new(tensor_)))
+        }
+    }
+
+    pub fn leave_pool(&self) -> Result<Self> {
+        #[cfg(not(feature = "metal"))]
+        {
+            return Ok(self.clone());
+        }
+        #[cfg(feature = "metal")]
+        {
+            let is_pooled =
+                matches!(&*self.storage(), Storage::Metal(storage) if storage.pool().is_some());
+            if !is_pooled {
+                return Ok(self.clone());
+            }
+
+            let device = match &self.device {
+                Device::Metal(device) => device,
+                _ => bail!("leave_pool is only supported for tensors on Metal devices"),
+            };
+
+            let mut new_storage = unsafe { device.alloc_uninit(self.shape(), self.dtype())? };
+            {
+                let src_storage = self.storage();
+                match &*src_storage {
+                    Storage::Metal(src) => {
+                        src.copy_strided_src(&mut new_storage, 0, self.layout())?;
+                    }
+                    _ => unreachable!("Metal tensor expected"),
+                }
+            }
+
+            let storage = Storage::Metal(new_storage);
+            let tensor_ = Tensor_ {
+                id: TensorId::new(),
+                storage: Arc::new(RwLock::new(storage)),
+                layout: self.layout.clone(),
+                op: BackpropOp::new1(self, Op::Copy),
+                is_variable: false,
+                dtype: self.dtype,
+                device: self.device.clone(),
+            };
+            Ok(Tensor(Arc::new(tensor_)))
+        }
+    }
+
     /// Returns a new tensor detached from the current graph, gradient are not propagated through
     /// this new node. The storage of this tensor is shared with the initial tensor.
     ///
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index c03b1c1370..758f73febb 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -503,7 +503,7 @@ pub fn call_unary_contiguous_tiled(
     kernel_name: unary::contiguous_tiled::Kernel,
     length: usize,
     input: BufferOffset,
-    output: &Buffer,
+    output: BufferOffset,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
     let encoder = ep.encoder();
@@ -513,11 +513,11 @@
     encoder.set_compute_pipeline_state(&pipeline);
 
-    set_params!(encoder, (length, &input, output));
+    set_params!(encoder, (length, &input, &output));
 
     let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles);
     encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -530,7 +530,7 @@ pub fn call_unary_contiguous(
     kernel_name: unary::contiguous::Kernel,
     length: usize,
     input: BufferOffset,
-    output: &Buffer,
+    output: BufferOffset,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
     let encoder = ep.encoder();
@@ -538,11 +538,11 @@
     encoder.set_compute_pipeline_state(&pipeline);
 
-    set_params!(encoder, (length, &input, output));
+    set_params!(encoder, (length, &input, &output));
 
     let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
     encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -583,7 +583,7 @@ pub fn call_binary_contiguous(
     length: usize,
     left: BufferOffset,
     right: BufferOffset,
-    output: &Buffer,
+    output: BufferOffset,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
 
@@ -591,13 +591,13 @@
     let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
     encoder.set_compute_pipeline_state(&pipeline);
 
-    set_params!(encoder, (length, &left, &right, output));
+    set_params!(encoder, (length, &left, &right, &output));
 
     let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
     encoder.use_resource(left.buffer, metal::MTLResourceUsage::Read);
     encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -613,7 +613,7 @@ pub fn call_binary_strided(
     left_strides: &[usize],
     right_input: BufferOffset,
     right_strides: &[usize],
-    output: &Buffer,
+    output: BufferOffset,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
 
@@ -635,12 +635,12 @@
             right_strides,
             &left_input,
             &right_input,
-            output
+            &output
         )
     );
     encoder.use_resource(left_input.buffer, metal::MTLResourceUsage::Read);
     encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
 
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
@@ -654,7 +654,7 @@ pub fn call_cast_contiguous(
     kernel_name: &'static str,
     length: usize,
     input: BufferOffset,
-    output: &Buffer,
+    output: BufferOffset,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
 
@@ -662,11 +662,11 @@
     let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
     encoder.set_compute_pipeline_state(&pipeline);
 
-    set_params!(encoder, (length, &input, output));
+    set_params!(encoder, (length, &input, &output));
 
     let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
     encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -680,7 +680,7 @@ pub fn call_cast_strided(
     shape: &[usize],
     input: BufferOffset,
     input_strides: &[usize],
-    output: &Buffer,
+    output: BufferOffset,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
 
@@ -692,13 +692,13 @@
     set_params!(
         encoder,
-        (length, shape.len(), shape, input_strides, &input, output)
+        (length, shape.len(), shape, input_strides, &input, &output)
     );
 
     let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
 
     encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -712,7 +712,7 @@ pub fn call_reduce_contiguous(
     shape: &[usize],
     out_length: usize,
     input: BufferOffset,
-    output: &Buffer,
+    output: BufferOffset,
 ) -> Result<(), MetalKernelError> {
     let length = shape.iter().product::<usize>();
     let num_dims = shape.len();
@@ -731,7 +731,7 @@
             shape,
             work_per_threadgroup,
             &input,
-            output
+            &output
         )
     );
 
@@ -753,7 +753,7 @@
     };
 
     encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -768,7 +768,7 @@ pub fn call_reduce_strided(
     strides: &[usize],
     out_length: usize,
     input: BufferOffset,
-    output: &Buffer,
+    output: BufferOffset,
 ) -> Result<(), MetalKernelError> {
     let length: usize = shape.iter().product();
     let num_dims = shape.len();
@@ -788,7 +788,7 @@
             strides,
             work_per_threadgroup,
             &input,
-            output
+            &output
         )
     );
 
@@ -809,7 +809,7 @@
         depth: 1,
     };
     encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -824,7 +824,7 @@ pub fn call_last_softmax(
     elements: usize,
     input: &Buffer,
     input_offset: usize,
-    output: &Buffer,
+    output: BufferOffset,
 ) -> Result<(), MetalKernelError> {
     let work_per_threadgroup = elements;
 
@@ -835,7 +835,7 @@
     set_params!(
         encoder,
-        (length, work_per_threadgroup, (input, input_offset), output)
+        (length, work_per_threadgroup, (input, input_offset), &output)
     );
 
     let out_length = length / work_per_threadgroup;
@@ -857,7 +857,7 @@
         depth: 1,
     };
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -875,7 +875,7 @@ pub fn call_rms_norm(
     input_offset: usize,
     alpha: &Buffer,
     alpha_offset: usize,
-    output: &Buffer,
+    output: BufferOffset,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
     let encoder = ep.encoder();
@@ -888,7 +888,7 @@
             length,
             elements_to_sum,
             (input, input_offset),
-            output,
+            &output,
             (alpha, alpha_offset),
             eps
         )
@@ -915,7 +915,7 @@
     };
 
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.set_threadgroup_memory_length(0, (width * 4).max(16) as u64);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
@@ -936,7 +936,7 @@ pub fn call_layer_norm(
     alpha_offset: usize,
     beta: &Buffer,
     beta_offset: usize,
-    output: &Buffer,
+    output: BufferOffset,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
     let encoder = ep.encoder();
@@ -949,7 +949,7 @@
             length,
             elements_to_sum,
             (input, input_offset),
-            output,
+            &output,
             (alpha, alpha_offset),
             (beta, beta_offset),
             eps
@@ -977,7 +977,7 @@
     };
 
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.set_threadgroup_memory_length(0, (width * 8).max(32) as u64);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
@@ -998,7 +998,7 @@ pub fn call_rope_i(
     cos_offset: usize,
     sin: &Buffer,
     sin_offset: usize,
-    output: &Buffer,
+    output: BufferOffset,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
     let encoder = ep.encoder();
@@ -1014,14 +1014,14 @@
             (src, src_offset),
             (cos, cos_offset),
             (sin, sin_offset),
-            output
+            &output
         )
     );
     let (thread_group_count, thread_group_size) = linear_split(&pipeline, (bh * td) / 2);
     encoder.use_resource(src, metal::MTLResourceUsage::Read);
     encoder.use_resource(cos, metal::MTLResourceUsage::Read);
     encoder.use_resource(sin, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -1043,7 +1043,7 @@ pub fn call_rope_thd(
     cos_offset: usize,
     sin: &Buffer,
     sin_offset: usize,
-    output: &Buffer,
+    output: BufferOffset,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
     let encoder = ep.encoder();
@@ -1061,14 +1061,14 @@
             (src, src_offset),
             (cos, cos_offset),
             (sin, sin_offset),
-            output
+            &output
         )
     );
     let (thread_group_count, thread_group_size) = linear_split(&pipeline, (b * t * h * d) / 2);
     encoder.use_resource(src, metal::MTLResourceUsage::Read);
     encoder.use_resource(cos, metal::MTLResourceUsage::Read);
     encoder.use_resource(sin, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -1089,7 +1089,7 @@ pub fn call_rope(
     cos_offset: usize,
     sin: &Buffer,
     sin_offset: usize,
-    output: &Buffer,
+    output: BufferOffset,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
     let encoder = ep.encoder();
@@ -1106,14 +1106,14 @@
             (src, src_offset),
             (cos, cos_offset),
             (sin, sin_offset),
-            output
+            &output
         )
     );
     let (thread_group_count, thread_group_size) = linear_split(&pipeline, (bh * td) / 2);
     encoder.use_resource(src, metal::MTLResourceUsage::Read);
     encoder.use_resource(cos, metal::MTLResourceUsage::Read);
     encoder.use_resource(sin, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -1126,7 +1126,7 @@ pub fn call_affine(
     name: &'static str,
     size: usize,
     input: BufferOffset,
-    output: &Buffer,
+    output: BufferOffset,
     mul: f32,
     add: f32,
 ) -> Result<(), MetalKernelError> {
@@ -1136,11 +1136,11 @@
     let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
     encoder.set_compute_pipeline_state(&pipeline);
 
-    set_params!(encoder, (size, mul, add, &input, output));
+    set_params!(encoder, (size, mul, add, &input, &output));
 
     let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
     encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -1154,7 +1154,7 @@ pub fn call_affine_strided(
     shape: &[usize],
     input: BufferOffset,
     input_stride: &[usize],
-    output: &Buffer,
+    output: BufferOffset,
     mul: f32,
     add: f32,
 ) -> Result<(), MetalKernelError> {
@@ -1175,13 +1175,13 @@
             mul,
             add,
             &input,
-            output
+            &output
         )
     );
 
     let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
     encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -1194,7 +1194,7 @@ pub fn call_powf(
     name: &'static str,
     size: usize,
     input: BufferOffset,
-    output: &Buffer,
+    output: BufferOffset,
     mul: f32,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
@@ -1203,11 +1203,11 @@
     let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
     encoder.set_compute_pipeline_state(&pipeline);
 
-    set_params!(encoder, (size, mul, &input, output));
+    set_params!(encoder, (size, mul, &input, &output));
 
     let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
     encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -1221,7 +1221,7 @@ pub fn call_powf_strided(
     shape: &[usize],
     input: BufferOffset,
     input_stride: &[usize],
-    output: &Buffer,
+    output: BufferOffset,
     mul: f32,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
@@ -1233,12 +1233,12 @@
     set_params!(
         encoder,
-        (size, shape.len(), shape, input_stride, mul, &input, output)
+        (size, shape.len(), shape, input_stride, mul, &input, &output)
     );
 
     let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
     encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -1251,7 +1251,7 @@ pub fn call_elu(
     name: &'static str,
     size: usize,
     input: BufferOffset,
-    output: &Buffer,
+    output: BufferOffset,
     mul: f32,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
@@ -1260,11 +1260,11 @@
     let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
     encoder.set_compute_pipeline_state(&pipeline);
 
-    set_params!(encoder, (size, mul, &input, output));
+    set_params!(encoder, (size, mul, &input, &output));
 
     let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
     encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -1278,7 +1278,7 @@ pub fn call_elu_strided(
     shape: &[usize],
     input: BufferOffset,
     input_stride: &[usize],
-    output: &Buffer,
+    output: BufferOffset,
     mul: f32,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
@@ -1290,12 +1290,12 @@
     set_params!(
         encoder,
-        (size, shape.len(), shape, input_stride, mul, &input, output)
+        (size, shape.len(), shape, input_stride, mul, &input, &output)
     );
 
     let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
     encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -1313,7 +1313,7 @@ pub fn call_where_cond_strided(
     left_stride: &[usize],
     right: BufferOffset,
     right_stride: &[usize],
-    output: &Buffer,
+    output: BufferOffset,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
 
@@ -1336,7 +1336,7 @@
             &cond,
             &left,
             &right,
-            output
+            &output
         )
     );
 
@@ -1345,7 +1345,7 @@
     encoder.use_resource(cond.buffer, metal::MTLResourceUsage::Read);
     encoder.use_resource(left.buffer, metal::MTLResourceUsage::Read);
     encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -1364,7 +1364,7 @@ pub fn call_index_select(
     src_strides: &[usize],
     input: BufferOffset,
     ids: BufferOffset,
-    output: &Buffer,
+    output: BufferOffset,
 ) -> Result<(), MetalKernelError> {
     let left_size: usize = shape[..dim].iter().product();
     let right_size: usize = shape[dim + 1..].iter().product();
@@ -1391,7 +1391,7 @@
             src_strides,
             &input,
             &ids,
-            output
+            &output
         )
     );
 
@@ -1399,7 +1399,7 @@
 
     encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
     encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -1415,7 +1415,7 @@ pub fn call_gather(
     dim: usize,
     input: BufferOffset,
     ids: BufferOffset,
-    output: &Buffer,
+    output: BufferOffset,
 ) -> Result<(), MetalKernelError> {
     let left_size: usize = shape[..dim].iter().product();
     let right_size: usize = shape[dim + 1..].iter().product();
@@ -1439,7 +1439,7 @@
             ids_size,
             &input,
             &ids,
-            output
+            &output
         )
     );
 
@@ -1447,7 +1447,7 @@
 
     encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
     encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -1513,7 +1513,7 @@ pub fn call_index_add(
     dim: usize,
     input: BufferOffset,
     ids: BufferOffset,
-    output: &Buffer,
+    output: BufferOffset,
 ) -> Result<(), MetalKernelError> {
     let left_size: usize = src_shape[..dim].iter().product();
     let right_size: usize = src_shape[dim + 1..].iter().product();
@@ -1539,7 +1539,7 @@
             ids_dim_size,
             &input,
             &ids,
-            output
+            &output
         )
     );
 
@@ -1547,7 +1547,7 @@
 
     encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
     encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -1663,7 +1663,7 @@ pub fn call_sdpa_full(
     mask_type: Option,
     mask_buffer: Option<&Buffer>,
     m_strides: Option<&[usize]>,
-    output: &Buffer,
+    output: BufferOffset,
     o_strides: &[usize],
     scale: f32,
     do_causal: bool,
@@ -1830,7 +1830,7 @@
                 (q_buffer, q_offset),
                 (k_buffer, k_offset),
                 (v_buffer, v_offset),
-                output,
+                &output,
                 params,
                 mask_params,
                 mask
@@ -1843,7 +1843,7 @@
                 (q_buffer, q_offset),
                 (k_buffer, k_offset),
                 (v_buffer, v_offset),
-                output,
+                &output,
                 params
             )
         );
@@ -1862,7 +1862,7 @@
     encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read);
     encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read);
     encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
 
     encoder.dispatch_thread_groups(grid_dims, group_dims);
     Ok(())
@@ -1887,7 +1887,7 @@ pub fn call_sdpa_vector(
     v_offset: usize,
     v_stride: &[usize],
     v_buffer: &Buffer,
-    output: &Buffer,
+    output: BufferOffset,
     alpha: f32,
     softcapping: f32,
     itype: SdpaDType,
@@ -1956,7 +1956,7 @@
             (q_buffer, q_offset),
             (k_buffer, k_offset),
             (v_buffer, v_offset),
-            output,
+            &output,
             gqa_factor,
             n,
             kstride,
@@ -1979,7 +1979,7 @@ pub fn call_sdpa_vector(
     encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read);
     encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read);
     encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(grid_dims, group_dims);
     Ok(())
 }
@@ -2005,7 +2005,7 @@ pub fn call_sdpa_vector_2pass(
     v_offset: usize,
     v_stride: &[usize],
     v_buffer: &Buffer,
-    output: &Buffer,
+    output: BufferOffset,
     intermediate: &Buffer,
     sums: &Buffer,
     maxs: &Buffer,
@@ -2143,7 +2143,7 @@
         // q = (bs, qhead, seq, hidden)
         // k/v = (bs, kv_head, kv_seq, hidden)
 
-        set_params!(encoder, (intermediate, sums, maxs, output));
+        set_params!(encoder, (intermediate, sums, maxs, &output));
 
         let grid_dims = MTLSize {
             width: 1,
@@ -2158,7 +2158,7 @@
         encoder.use_resource(intermediate, metal::MTLResourceUsage::Write);
         encoder.use_resource(sums, metal::MTLResourceUsage::Write);
         encoder.use_resource(maxs, metal::MTLResourceUsage::Write);
-        encoder.use_resource(output, metal::MTLResourceUsage::Write);
+        encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
 
         encoder.dispatch_thread_groups(grid_dims, group_dims);
     }
@@ -2175,7 +2175,7 @@ pub fn call_im2col1d_strided(
     strides: &[usize],
     (k_size, stride, padding, dilation): (usize, usize, usize, usize),
     input: BufferOffset,
-    output: &Buffer,
+    output: BufferOffset,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
     let l_out = (shape[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1;
@@ -2187,10 +2187,10 @@
     encoder.set_compute_pipeline_state(&pipeline);
     set_params!(
         encoder,
-        (dst_el, l_out, k_size, stride, padding, dilation, shape, strides, &input, output)
+        (dst_el, l_out, k_size, stride, padding, dilation, shape, strides, &input, &output)
     );
     encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -2205,7 +2205,7 @@ pub fn call_col2im1d(
     k_size: usize,
     stride: usize,
     input: BufferOffset,
-    output: &Buffer,
+    output: BufferOffset,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
     let l_in = shape[1];
@@ -2219,10 +2219,10 @@
     encoder.set_compute_pipeline_state(&pipeline);
     set_params!(
         encoder,
-        (dst_el, l_out, l_in, c_out, k_size, stride, &input, output)
+        (dst_el, l_out, l_in, c_out, k_size, stride, &input, &output)
     );
     encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -2237,7 +2237,7 @@ pub fn call_im2col_strided(
     strides: &[usize],
     (h_k, w_k, stride, padding, dilation): (usize, usize, usize, usize, usize),
     input: BufferOffset,
-    output: &Buffer,
+    output: BufferOffset,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
 
@@ -2256,11 +2256,11 @@
         encoder,
         (
             dst_el, h_out, w_out, h_k, w_k, stride, padding, dilation, shape, strides, &input,
-            output
+            &output
         )
     );
     encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -2276,7 +2276,7 @@ pub fn call_upsample_nearest_2d(
     out_w: usize,
     out_h: usize,
     input: BufferOffset,
-    output: &Buffer,
+    output: BufferOffset,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
     let dst_el = out_w * out_h * shape[0] * shape[1];
@@ -2288,10 +2288,10 @@ pub fn call_upsample_nearest_2d(
     encoder.set_compute_pipeline_state(&pipeline);
     set_params!(
         encoder,
-        (out_w, out_h, scale_w, scale_h, shape, strides, &input, output)
+        (out_w, out_h, scale_w, scale_h, shape, strides, &input, &output)
     );
     encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -2665,7 +2665,7 @@ pub fn call_pool2d(
     w_stride: usize,
     h_stride: usize,
     input: &Buffer,
-    output: &Buffer,
+    output: BufferOffset,
 ) -> Result<(), MetalKernelError> {
     let dst_el = out_w * out_h * shape[0] * shape[1];
     let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
@@ -2675,10 +2675,10 @@ pub fn call_pool2d(
     encoder.set_compute_pipeline_state(&pipeline);
     set_params!(
         encoder,
-        (w_k, h_k, w_stride, h_stride, shape, strides, input, output)
+        (w_k, h_k, w_stride, h_stride, shape, strides, input, &output)
     );
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -2704,7 +2704,7 @@ pub fn call_conv_transpose1d(
     input_offset: usize,
     kernel: &Buffer,
    kernel_offset: usize,
-    output: &Buffer,
+    output: BufferOffset,
 ) -> Result<(), MetalKernelError> {
     let dst_el = c_out * l_out * b_size;
     let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
@@ -2726,12 +2726,12 @@ pub fn call_conv_transpose1d(
             kernel_strides,
             (input, input_offset),
             (kernel, kernel_offset),
-            output
+            &output
         )
     );
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -2762,7 +2762,7 @@ pub fn call_conv_transpose2d(
     cfg: CallConvTranspose2dCfg,
     input: &Buffer,
     kernel: &Buffer,
-    output: &Buffer,
+    output: BufferOffset,
 ) -> Result<(), MetalKernelError> {
     let dst_el = cfg.c_out * cfg.out_w * cfg.out_h * cfg.b_size;
     let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
@@ -2785,12 +2785,12 @@ pub fn call_conv_transpose2d(
             cfg.kernel_stride,
             (input, cfg.input_offset),
             (kernel, cfg.kernel_offset),
-            output
+            &output
         )
     );
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
@@ -2801,16 +2801,16 @@ pub fn call_const_fill(
     kernels: &Kernels,
     name: &'static str,
     length: usize,
-    output: &Buffer,
+    output: BufferOffset,
     v: impl EncoderParam,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
     let encoder = ep.encoder();
     let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
     encoder.set_compute_pipeline_state(&pipeline);
-    set_params!(encoder, (output, v, length));
+    set_params!(encoder, (&output, v, length));
     let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     Ok(())
 }
diff --git a/candle-metal-kernels/src/mlx_gemm.rs b/candle-metal-kernels/src/mlx_gemm.rs
index ee4292c39d..0bc4ee47f8 100644
--- a/candle-metal-kernels/src/mlx_gemm.rs
+++ b/candle-metal-kernels/src/mlx_gemm.rs
@@ -1,4 +1,4 @@
-use crate::utils::EncoderProvider;
+use crate::utils::{BufferOffset, EncoderProvider};
 use crate::{ConstantValues, Kernels, MetalKernelError, Source, Value};
 use metal::{Buffer, ComputeCommandEncoderRef, Device, MTLSize, NSUInteger};
 use std::ffi::c_void;
@@ -23,7 +23,7 @@ pub fn call_mlx_gemm(
     rhs_stride: &[usize],
     rhs_offset: usize,
     rhs_buffer: &Buffer,
-    output: &Buffer,
+    output: BufferOffset,
 ) -> Result<(), MetalKernelError> {
     #[derive(Debug)]
     #[repr(C)]
@@ -145,7 +145,7 @@ pub fn call_mlx_gemm(
     encoder.set_compute_pipeline_state(&pipeline);
     encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
     encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
-    encoder.set_buffer(3, Some(output), 0);
+    encoder.set_buffer(3, Some(output.buffer), output.offset_in_bytes as NSUInteger);
     encoder.set_bytes(
         4,
         std::mem::size_of::<GemmParams>() as u64,
@@ -174,7 +174,7 @@ pub fn call_mlx_gemm(
     };
     encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
     encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(grid_size, group_size);
     Ok(())
 }
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 979effd8ef..06bcc9ac57 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -174,7 +174,7 @@ impl candle::CustomOp1 for Sigmoid {
                     kernel_name,
                     el_count,
                     src,
-                    &buffer,
+                    device.buffer_offset(&buffer),
                 )
                 .map_err(MetalError::from)?;
             }
@@ -195,7 +195,7 @@ impl candle::CustomOp1 for Sigmoid {
                     kernel_name,
                     el_count,
                     src,
-                    &buffer,
+                    device.buffer_offset(&buffer),
                 )
                 .map_err(MetalError::from)?;
             }
@@ -209,7 +209,7 @@ impl candle::CustomOp1 for Sigmoid {
                         candle::bail!("Metal strided unary sigmoid {dtype:?} not implemented")
                     }
                 };
-                let dst = candle_metal_kernels::BufferOffset::zero_offset(&buffer);
+                let dst = device.buffer_offset(&buffer);
                 candle_metal_kernels::call_unary_strided(
                     device.metal_device(),
                     &command_buffer,
@@ -437,7 +437,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
                 last_dim,
                 storage.buffer(),
                 layout.start_offset() * storage.dtype().size_in_bytes(),
-                &output,
+                device.buffer_offset(&output),
             )
             .map_err(candle::Error::wrap)?;
             let newstorage =
@@ -630,7 +630,7 @@ impl candle::CustomOp2 for RmsNorm {
                 l1.start_offset() * s1.dtype().size_in_bytes(),
                 s2.buffer(),
                 l2.start_offset() * s2.dtype().size_in_bytes(),
-                &output,
+                device.buffer_offset(&output),
             )
             .map_err(candle::Error::wrap)?;
             let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
@@ -876,7 +876,7 @@ impl candle::CustomOp3 for LayerNorm {
                 l2.start_offset() * s2.dtype().size_in_bytes(),
                 s3.buffer(),
                 l3.start_offset() * s3.dtype().size_in_bytes(),
-                &output,
+                device.buffer_offset(&output),
             )
             .map_err(candle::Error::wrap)?;
             let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
@@ -1091,7 +1091,7 @@ impl candle::CustomOp3 for Sdpa {
             other => candle::bail!("unsupported sdpa type {other:?}"),
         };
 
-        let command_buffer = q.device().command_buffer()?;
+        let command_buffer = device.command_buffer()?;
         if supports_sdpa_vector {
             // Route to the 2 pass fused attention if the k seqlen is large.
             // https://github.com/ml-explore/mlx/pull/1597
@@ -1122,9 +1122,9 @@ impl candle::CustomOp3 for Sdpa {
 
                 command_buffer.set_label("vector_attention");
                 candle_metal_kernels::call_sdpa_vector_2pass(
-                    q.device().device(),
+                    device.metal_device(),
                     &command_buffer,
-                    q.device().kernels(),
+                    device.kernels(),
                     q_l.start_offset(),
                     q_l.dims(),
                     q.buffer(),
@@ -1135,7 +1135,7 @@ impl candle::CustomOp3 for Sdpa {
                     v_l.start_offset(),
                     v_l.stride(),
                     v.buffer(),
-                    &output,
+                    device.buffer_offset(&output),
                     &intermediate,
                     &sums,
                     &maxs,
@@ -1147,9 +1147,9 @@ impl candle::CustomOp3 for Sdpa {
             } else {
                 command_buffer.set_label("vector_attention");
                 candle_metal_kernels::call_sdpa_vector(
-                    q.device().device(),
+                    device.metal_device(),
                     &command_buffer,
-                    q.device().kernels(),
+                    device.kernels(),
                     q_l.start_offset(),
                     q_l.dims(),
                     q.buffer(),
@@ -1160,7 +1160,7 @@ impl candle::CustomOp3 for Sdpa {
                     v_l.start_offset(),
                     v_l.stride(),
                     v.buffer(),
-                    &output,
+                    device.buffer_offset(&output),
                     self.scale,
                     self.softcapping,
                     itype,
@@ -1211,9 +1211,9 @@ impl candle::CustomOp3 for Sdpa {
             };
 
             candle_metal_kernels::call_sdpa_full(
-                q.device().device(),
+                device.metal_device(),
                 &command_buffer,
-                q.device().kernels(),
+                device.kernels(),
                 q_l.start_offset(),
                 q_l.dims(),
                 q_l.stride(),
@@ -1228,7 +1228,7 @@ impl candle::CustomOp3 for Sdpa {
                 mask_type,
                 mask_buffer,
                 mask_strides.as_deref(),
-                &output,
+                device.buffer_offset(&output),
                 out_layout.stride(),
                 self.scale,
                 self.do_causal,
diff --git a/candle-nn/src/rotary_emb.rs b/candle-nn/src/rotary_emb.rs
index bfb541f0c6..c610084c1d 100644
--- a/candle-nn/src/rotary_emb.rs
+++ b/candle-nn/src/rotary_emb.rs
@@ -216,7 +216,7 @@ impl candle::CustomOp3 for RotaryEmbI {
                 l_cos.start_offset() * cos.dtype().size_in_bytes(),
                 sin.buffer(),
                 l_sin.start_offset() * sin.dtype().size_in_bytes(),
-                &output,
+                device.buffer_offset(&output),
             )
             .map_err(candle::Error::wrap)?;
             let out = candle::MetalStorage::new(output, device.clone(), el, src.dtype());
@@ -499,7 +499,7 @@ impl candle::CustomOp3 for RotaryEmb {
                 l_cos.start_offset() * cos.dtype().size_in_bytes(),
                 sin.buffer(),
                 l_sin.start_offset() * sin.dtype().size_in_bytes(),
-                &output,
+                device.buffer_offset(&output),
             )
             .map_err(candle::Error::wrap)?;
             let out = candle::MetalStorage::new(output, device.clone(), el, src.dtype());
@@ -769,7 +769,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
                 l_cos.start_offset() * cos.dtype().size_in_bytes(),
                 sin.buffer(),
                 l_sin.start_offset() * sin.dtype().size_in_bytes(),
-                &output,
+                device.buffer_offset(&output),
             )
             .map_err(candle::Error::wrap)?;
             let out = candle::MetalStorage::new(output, device.clone(), el, src.dtype());
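
Usage sketch (illustrative, not part of the patch): after this change, code that
drives these kernels directly passes a BufferOffset for the destination instead
of a bare &Buffer. The sketch below assumes a MetalDevice bound to `device`, an
element count `el_count`, and a fill kernel named "fill_f32"; the variable names
and the kernel name are placeholders, not anything taken from the diff.

    // Allocate the output through the device so that an active
    // MetalTensorPool, if any, can service the allocation.
    let output = device.new_buffer(el_count, DType::F32, "example-fill")?;
    // Resolves to the pool byte offset for pooled buffers and to 0 otherwise.
    let dst = device.buffer_offset(&output);
    let command_buffer = device.command_buffer()?;
    candle_metal_kernels::call_const_fill(
        device.metal_device(),
        &command_buffer,
        device.kernels(),
        "fill_f32",
        el_count,
        dst,
        1f32,
    )
    .map_err(MetalError::from)?;

Threading the offset through BufferOffset keeps non-pooled call sites
behaviorally identical (offset 0) while letting pooled outputs land at their
reserved sub-range of the shared pool buffer.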