diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs index f249202ddd..752a9f86f7 100644 --- a/candle-core/src/metal_backend/device.rs +++ b/candle-core/src/metal_backend/device.rs @@ -7,6 +7,13 @@ use std::sync::{Arc, Mutex, RwLock}; use super::MetalError; +/// CPU-visible buffer storage mode: managed, which Metal only offers on macOS. +#[cfg(target_os = "macos")] +pub const SHARED_BUFFER_STORAGE_MODE: MTLResourceOptions = MTLResourceOptions::StorageModeManaged; +/// CPU-visible buffer storage mode: shared, used on iOS and every other non-macOS target. +#[cfg(not(target_os = "macos"))] +pub const SHARED_BUFFER_STORAGE_MODE: MTLResourceOptions = MTLResourceOptions::StorageModeShared; + /// Unique identifier for cuda devices. #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct DeviceId(usize); @@ -255,7 +262,7 @@ impl MetalDevice { /// synchronization when the CPU memory is modified /// Used as a bridge to gather data back from the GPU pub fn new_buffer_managed(&self, size: NSUInteger) -> Result> { - self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed") + self.allocate_buffer(size, SHARED_BUFFER_STORAGE_MODE, "managed") } /// Creates a new buffer from data. 
@@ -268,12 +275,12 @@ impl MetalDevice { let new_buffer = self.device.new_buffer_with_data( data.as_ptr().cast(), size, - MTLResourceOptions::StorageModeManaged, + SHARED_BUFFER_STORAGE_MODE, ); let mut buffers = self.buffers.write().map_err(MetalError::from)?; let subbuffers = buffers - .entry((size, MTLResourceOptions::StorageModeManaged)) + .entry((size, SHARED_BUFFER_STORAGE_MODE)) .or_insert(vec![]); let new_buffer = Arc::new(new_buffer); @@ -347,7 +354,8 @@ impl MetalDevice { } fn buf_size(size: NSUInteger) -> NSUInteger { - size.saturating_sub(1).next_power_of_two() as NSUInteger + // Round up to a power of two; the old `size - 1` form turned a 2-byte request into a 1-byte buffer. + size.next_power_of_two() as NSUInteger } fn find_available_buffer( diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index a77c37168b..6ae14db888 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -11,7 +11,7 @@ use std::ffi::c_void; use std::sync::{Arc, Mutex, PoisonError, RwLock, TryLockError}; mod device; -pub use device::{DeviceId, MetalDevice}; +pub use device::{DeviceId, MetalDevice, SHARED_BUFFER_STORAGE_MODE}; pub fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> { BufferOffset { @@ -2064,7 +2064,7 @@ impl BackendDevice for MetalDevice { let seed = Arc::new(Mutex::new(device.new_buffer_with_data( [299792458].as_ptr() as *const c_void, 4, - MTLResourceOptions::StorageModeManaged, + SHARED_BUFFER_STORAGE_MODE, ))); let commands = device::Commands::new(command_queue)?; Ok(Self { diff --git a/candle-metal-kernels/examples/metal_benchmarks.rs b/candle-metal-kernels/examples/metal_benchmarks.rs index f0de21e0c2..864a239f71 100644 --- a/candle-metal-kernels/examples/metal_benchmarks.rs +++ b/candle-metal-kernels/examples/metal_benchmarks.rs @@ -13,6 +13,9 @@ fn run_gemm(f32: bool, n: usize) -> Result<()> { let (b, m, n, k) = (1, n, n, n); let kernels = candle_metal_kernels::Kernels::new(); let 
command_queue = device.new_command_queue(); + #[cfg(target_os = "ios")] + let options = metal::MTLResourceOptions::StorageModeShared; + #[cfg(not(target_os = "ios"))] let options = metal::MTLResourceOptions::StorageModeManaged; let (lhs, rhs) = if f32 { diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index c03b1c1370..5c14dd62ad 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -625,6 +625,11 @@ pub fn call_binary_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); encoder.set_compute_pipeline_state(&pipeline); + + let dummy = &[0usize]; + let shape = if num_dims == 0 { dummy } else { shape }; + let left_strides = if num_dims == 0 { dummy } else { left_strides }; + let right_strides = if num_dims == 0 { dummy } else { right_strides }; set_params!( encoder, ( diff --git a/candle-metal-kernels/src/sort.rs b/candle-metal-kernels/src/sort.rs index e4140eb38b..1aee8fbd7f 100644 --- a/candle-metal-kernels/src/sort.rs +++ b/candle-metal-kernels/src/sort.rs @@ -92,11 +92,11 @@ pub fn multi_block_sort( &src, &mut dev_vals_0, &mut dev_idxs_0, - /* size_sorted_axis */ ncols as i32, - /* stride_sorted_axis */ 1i32, - /* nc_dim */ 1i32, - /* nc_shape */ nrows as i32, - /* nc_str */ ncols as i32 + /* size_sorted_axis */ ncols as i64, + /* stride_sorted_axis */ 1i64, + /* nc_dim */ 1i64, + /* nc_shape */ nrows as i64, + /* nc_str */ ncols as i64 ) ); let thread_group_count = MTLSize { @@ -243,11 +243,11 @@ pub fn block_sort( ( &src, dst, - ncols as i32, - 1i32, - 1i32, - ncols as i32, - ncols as i32 + ncols as i64, + 1i64, + 1i64, + ncols as i64, + ncols as i64 ) ); let thread_group_count = MTLSize { diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 5934cffb32..1bf210fed9 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -13,6 +13,9 @@ fn read_to_vec(buffer: &Buffer, n: usize) -> Vec { } fn 
new_buffer(device: &Device, data: &[T]) -> Buffer { + #[cfg(target_os = "ios")] + let options = MTLResourceOptions::StorageModeShared; + #[cfg(not(target_os = "ios"))] let options = MTLResourceOptions::StorageModeManaged; let ptr = data.as_ptr() as *const c_void; let size = std::mem::size_of_val(data) as u64; @@ -69,6 +72,9 @@ fn run_binary(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); + #[cfg(target_os = "ios")] + let options = MTLResourceOptions::StorageModeShared; + #[cfg(not(target_os = "ios"))] let options = MTLResourceOptions::StorageModeManaged; let left = new_buffer(&device, x); let right = new_buffer(&device, y); @@ -311,6 +317,9 @@ fn run_cast(v: &[T], name: &'static str) -> Vec { let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); + #[cfg(target_os = "ios")] + let options = MTLResourceOptions::StorageModeShared; + #[cfg(not(target_os = "ios"))] let options = MTLResourceOptions::StorageModeManaged; let size = (v.len() * std::mem::size_of::()) as u64; let output = device.new_buffer(size, options); @@ -516,6 +525,45 @@ fn cast_i64() { assert_eq!(results, v_u8); } +// This test specifically targets the buffer size mismatch for scalar casting. 
+#[test] +fn cast_scalar() { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let input_data = &[1.0f32]; + let input_buffer = new_buffer(&device, input_data); + let input = BufferOffset::zero_offset(&input_buffer); + + #[cfg(target_os = "ios")] + let options = MTLResourceOptions::StorageModeShared; + #[cfg(not(target_os = "ios"))] + let options = MTLResourceOptions::StorageModeManaged; + + // The output holds one bf16 (2 bytes); it is read back as raw u16 bits below. + let output_size = std::mem::size_of::<u16>() as u64; + let output_buffer = device.new_buffer(output_size, options); + 
+ call_cast_contiguous( + &device, + command_buffer, + &kernels, + "cast_f32_bf16", + 1, // el_count = 1 + input, + &output_buffer, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + // 1.0f32 is 0x3F80_0000 and bf16 keeps its upper 16 bits. + let bits: Vec<u16> = read_to_vec(&output_buffer, 1); + assert_eq!(bits, vec![0x3F80u16]); +} + fn run_affine(v: &[T], mul: f64, add: f64) -> Vec { let device = device(); let kernels = Kernels::new(); @@ -874,6 +922,9 @@ fn run_reduce( let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); + #[cfg(target_os = "ios")] + let options = MTLResourceOptions::StorageModeShared; + #[cfg(not(target_os = "ios"))] let options = MTLResourceOptions::StorageModeManaged; let output = device.new_buffer((out_length * core::mem::size_of::()) as u64, options); let shape = vec![in_length]; @@ -1188,6 +1239,9 @@ fn run_where_cond( let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); + #[cfg(target_os = "ios")] + let options = MTLResourceOptions::StorageModeShared; + #[cfg(not(target_os = "ios"))] let options = MTLResourceOptions::StorageModeManaged; let length = cond.len(); @@ -1299,6 +1353,9 @@ fn run_mlx_gemm( let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); + #[cfg(target_os = "ios")] + let options = MTLResourceOptions::StorageModeShared; + #[cfg(not(target_os = "ios"))] let options = MTLResourceOptions::StorageModeManaged; let lhs = device.new_buffer_with_data( @@ -1444,6 +1501,9 @@ fn run_random(name: &'static str, seed: u32, length: usize, a: f32, b: let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); + #[cfg(target_os = "ios")] + let options = MTLResourceOptions::StorageModeShared; + #[cfg(not(target_os = "ios"))] let options = MTLResourceOptions::StorageModeManaged; let output = device.new_buffer((length * core::mem::size_of::()) as NSUInteger, options); @@ -1570,6 +1630,9 @@ fn run_scatter_add( let 
kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); + #[cfg(target_os = "ios")] + let options = MTLResourceOptions::StorageModeShared; + #[cfg(not(target_os = "ios"))] let options = MTLResourceOptions::StorageModeManaged; let input_buffer = new_buffer(&device, input); let ids_buffer = new_buffer(&device, ids); diff --git a/candle-metal-kernels/tmp/affine.rs b/candle-metal-kernels/tmp/affine.rs index cd019056c7..a6b2777d0a 100644 --- a/candle-metal-kernels/tmp/affine.rs +++ b/candle-metal-kernels/tmp/affine.rs @@ -30,6 +30,9 @@ fn main() { fn run_affine_bench(device: &Device, kernels: &Kernels, v: &[T]) { let command_queue = device.new_command_queue(); + #[cfg(target_os = "ios")] + let options = MTLResourceOptions::StorageModeShared; + #[cfg(not(target_os = "ios"))] let options = MTLResourceOptions::StorageModeManaged; let iterations = 10000; diff --git a/candle-metal-kernels/tmp/binary.rs b/candle-metal-kernels/tmp/binary.rs index af5a8bdc62..e5b7e11c54 100644 --- a/candle-metal-kernels/tmp/binary.rs +++ b/candle-metal-kernels/tmp/binary.rs @@ -94,6 +94,9 @@ fn run_binary_bench( strided: [binary::strided::Kernel; 4], ) { let command_queue = device.new_command_queue(); + #[cfg(target_os = "ios")] + let options = MTLResourceOptions::StorageModeShared; + #[cfg(not(target_os = "ios"))] let options = MTLResourceOptions::StorageModeManaged; let iterations = 1000; diff --git a/candle-metal-kernels/tmp/cast.rs b/candle-metal-kernels/tmp/cast.rs index 090f510d16..73c0dea3a6 100644 --- a/candle-metal-kernels/tmp/cast.rs +++ b/candle-metal-kernels/tmp/cast.rs @@ -37,6 +37,9 @@ fn run_cast_bench( contiguous: &[&'static str], ) { let command_queue = device.new_command_queue(); + #[cfg(target_os = "ios")] + let options = MTLResourceOptions::StorageModeShared; + #[cfg(not(target_os = "ios"))] let options = MTLResourceOptions::StorageModeManaged; let iterations = 1000; diff --git 
a/candle-metal-kernels/tmp/unary.rs b/candle-metal-kernels/tmp/unary.rs index 66cf25c0c8..64f597b55a 100644 --- a/candle-metal-kernels/tmp/unary.rs +++ b/candle-metal-kernels/tmp/unary.rs @@ -112,6 +112,9 @@ fn run_unary_bench( strided: [unary::strided::Kernel; 7], ) { let command_queue = device.new_command_queue(); + #[cfg(target_os = "ios")] + let options = MTLResourceOptions::StorageModeShared; + #[cfg(not(target_os = "ios"))] let options = MTLResourceOptions::StorageModeManaged; let iterations = 10000;