Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions candle-core/src/metal_backend/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ use std::sync::{Arc, Mutex, RwLock};

use super::MetalError;

// Resource options used wherever this backend previously hard-coded
// `StorageModeManaged`. iOS and macOS need different storage modes for these
// buffers because the two platforms manage CPU/GPU memory differently, so the
// value is selected at compile time from the target OS.

/// Storage mode for CPU/GPU shared buffers on iOS: `StorageModeShared`.
#[cfg(target_os = "ios")]
pub const SHARED_BUFFER_STORAGE_MODE: MTLResourceOptions = MTLResourceOptions::StorageModeShared;
/// Storage mode for CPU/GPU shared buffers on every non-iOS target: `StorageModeManaged`.
#[cfg(not(target_os = "ios"))]
pub const SHARED_BUFFER_STORAGE_MODE: MTLResourceOptions = MTLResourceOptions::StorageModeManaged;

/// Unique identifier for metal devices.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DeviceId(usize);
Expand Down Expand Up @@ -255,7 +262,7 @@ impl MetalDevice {
/// synchronization when the CPU memory is modified
/// Used as a bridge to gather data back from the GPU
pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> {
    // Use the per-platform storage mode instead of hard-coding
    // `StorageModeManaged`: the constant resolves to `StorageModeShared` on
    // iOS and to `StorageModeManaged` everywhere else. The "managed" label is
    // kept unchanged so whatever `allocate_buffer` does with it is unaffected.
    self.allocate_buffer(size, SHARED_BUFFER_STORAGE_MODE, "managed")
}

/// Creates a new buffer from data.
Expand All @@ -268,12 +275,12 @@ impl MetalDevice {
let new_buffer = self.device.new_buffer_with_data(
data.as_ptr().cast(),
size,
MTLResourceOptions::StorageModeManaged,
SHARED_BUFFER_STORAGE_MODE,
);
let mut buffers = self.buffers.write().map_err(MetalError::from)?;

let subbuffers = buffers
.entry((size, MTLResourceOptions::StorageModeManaged))
.entry((size, SHARED_BUFFER_STORAGE_MODE))
.or_insert(vec![]);

let new_buffer = Arc::new(new_buffer);
Expand Down Expand Up @@ -347,7 +354,8 @@ impl MetalDevice {
}

/// Rounds a requested buffer size up to the next power of two.
///
/// The result is never smaller than `size`: a size that is already a power of
/// two is returned unchanged and `0` maps to `1`. The previous formula,
/// `size.saturating_sub(1).next_power_of_two()`, could return a value smaller
/// than the request (e.g. `5 -> 4`).
///
/// If the next power of two does not fit in `NSUInteger`, `size` itself is
/// returned instead of overflowing.
fn buf_size(size: NSUInteger) -> NSUInteger {
    size.checked_next_power_of_two().unwrap_or(size)
}

fn find_available_buffer(
Expand Down
4 changes: 2 additions & 2 deletions candle-core/src/metal_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::ffi::c_void;
use std::sync::{Arc, Mutex, PoisonError, RwLock, TryLockError};

mod device;
pub use device::{DeviceId, MetalDevice};
pub use device::{DeviceId, MetalDevice, SHARED_BUFFER_STORAGE_MODE};

pub fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> {
BufferOffset {
Expand Down Expand Up @@ -2064,7 +2064,7 @@ impl BackendDevice for MetalDevice {
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
[299792458].as_ptr() as *const c_void,
4,
MTLResourceOptions::StorageModeManaged,
SHARED_BUFFER_STORAGE_MODE,
)));
let commands = device::Commands::new(command_queue)?;
Ok(Self {
Expand Down
3 changes: 3 additions & 0 deletions candle-metal-kernels/examples/metal_benchmarks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ fn run_gemm(f32: bool, n: usize) -> Result<()> {
let (b, m, n, k) = (1, n, n, n);
let kernels = candle_metal_kernels::Kernels::new();
let command_queue = device.new_command_queue();
#[cfg(target_os = "ios")]
let options = metal::MTLResourceOptions::StorageModeShared;
#[cfg(not(target_os = "ios"))]
let options = metal::MTLResourceOptions::StorageModeManaged;

let (lhs, rhs) = if f32 {
Expand Down
5 changes: 5 additions & 0 deletions candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,11 @@ pub fn call_binary_strided(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);

encoder.set_compute_pipeline_state(&pipeline);

let dummy = &[0usize];
let shape = if num_dims == 0 { dummy } else { shape };
let left_strides = if num_dims == 0 { dummy } else { left_strides };
let right_strides = if num_dims == 0 { dummy } else { right_strides };
set_params!(
encoder,
(
Expand Down
20 changes: 10 additions & 10 deletions candle-metal-kernels/src/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,11 @@ pub fn multi_block_sort(
&src,
&mut dev_vals_0,
&mut dev_idxs_0,
/* size_sorted_axis */ ncols as i32,
/* stride_sorted_axis */ 1i32,
/* nc_dim */ 1i32,
/* nc_shape */ nrows as i32,
/* nc_str */ ncols as i32
/* size_sorted_axis */ ncols as i64,
/* stride_sorted_axis */ 1i64,
/* nc_dim */ 1i64,
/* nc_shape */ nrows as i64,
/* nc_str */ ncols as i64
)
);
let thread_group_count = MTLSize {
Expand Down Expand Up @@ -243,11 +243,11 @@ pub fn block_sort(
(
&src,
dst,
ncols as i32,
1i32,
1i32,
ncols as i32,
ncols as i32
ncols as i64,
1i64,
1i64,
ncols as i64,
ncols as i64
)
);
let thread_group_count = MTLSize {
Expand Down
63 changes: 63 additions & 0 deletions candle-metal-kernels/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
}

fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
#[cfg(target_os = "ios")]
let options = MTLResourceOptions::StorageModeShared;
#[cfg(not(target_os = "ios"))]
let options = MTLResourceOptions::StorageModeManaged;
let ptr = data.as_ptr() as *const c_void;
let size = std::mem::size_of_val(data) as u64;
Expand Down Expand Up @@ -69,6 +72,9 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
#[cfg(target_os = "ios")]
let options = MTLResourceOptions::StorageModeShared;
#[cfg(not(target_os = "ios"))]
let options = MTLResourceOptions::StorageModeManaged;
let left = new_buffer(&device, x);
let right = new_buffer(&device, y);
Expand Down Expand Up @@ -311,6 +317,9 @@ fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
#[cfg(target_os = "ios")]
let options = MTLResourceOptions::StorageModeShared;
#[cfg(not(target_os = "ios"))]
let options = MTLResourceOptions::StorageModeManaged;
let size = (v.len() * std::mem::size_of::<U>()) as u64;
let output = device.new_buffer(size, options);
Expand Down Expand Up @@ -516,6 +525,45 @@ fn cast_i64() {
assert_eq!(results, v_u8);
}

// Regression test for casting a single-element (scalar) buffer from f32 to
// bf16, where the output buffer length differs from the cast output size.
#[test]
fn cast_scalar() {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();

    let input_data = &[1.0f32];
    let input_buffer = new_buffer(&device, input_data);
    let input = BufferOffset::zero_offset(&input_buffer);

    #[cfg(target_os = "ios")]
    let options = MTLResourceOptions::StorageModeShared;
    #[cfg(not(target_os = "ios"))]
    let options = MTLResourceOptions::StorageModeManaged;

    // The output buffer is deliberately sized from the INPUT dtype (f32,
    // 4 bytes) rather than the OUTPUT dtype (bf16, 2 bytes). It is therefore
    // larger than strictly required, never smaller.
    // NOTE(review): the original comments said this call "should fail the
    // Metal validation", yet the result is unwrapped, so the test only passes
    // when encoding succeeds. Presumably it reproduces a buffer-length
    // validation error that has since been fixed -- confirm.
    let output_size = std::mem::size_of::<f32>() as u64;
    let output_buffer = device.new_buffer(output_size, options);

    call_cast_contiguous(
        &device,
        command_buffer,
        &kernels,
        "cast_f32_bf16",
        1, // el_count: a single scalar element
        input,
        &output_buffer,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    // Check the value that was actually written. bf16 keeps the upper 16 bits
    // of the f32 encoding, and 1.0f32 is 0x3F80_0000, so the single output
    // element must be 0x3F80. Reading raw u16 bits avoids needing a bf16 type.
    let result: Vec<u16> = read_to_vec(&output_buffer, 1);
    assert_eq!(result, vec![0x3F80u16]);
}

fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
Expand Down Expand Up @@ -874,6 +922,9 @@ fn run_reduce<T, U: Clone>(
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);

#[cfg(target_os = "ios")]
let options = MTLResourceOptions::StorageModeShared;
#[cfg(not(target_os = "ios"))]
let options = MTLResourceOptions::StorageModeManaged;
let output = device.new_buffer((out_length * core::mem::size_of::<U>()) as u64, options);
let shape = vec![in_length];
Expand Down Expand Up @@ -1188,6 +1239,9 @@ fn run_where_cond<I: Clone, T: Clone>(
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
#[cfg(target_os = "ios")]
let options = MTLResourceOptions::StorageModeShared;
#[cfg(not(target_os = "ios"))]
let options = MTLResourceOptions::StorageModeManaged;

let length = cond.len();
Expand Down Expand Up @@ -1299,6 +1353,9 @@ fn run_mlx_gemm<T: Clone>(
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
#[cfg(target_os = "ios")]
let options = MTLResourceOptions::StorageModeShared;
#[cfg(not(target_os = "ios"))]
let options = MTLResourceOptions::StorageModeManaged;

let lhs = device.new_buffer_with_data(
Expand Down Expand Up @@ -1444,6 +1501,9 @@ fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b:
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();

#[cfg(target_os = "ios")]
let options = MTLResourceOptions::StorageModeShared;
#[cfg(not(target_os = "ios"))]
let options = MTLResourceOptions::StorageModeManaged;
let output = device.new_buffer((length * core::mem::size_of::<T>()) as NSUInteger, options);

Expand Down Expand Up @@ -1570,6 +1630,9 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
#[cfg(target_os = "ios")]
let options = MTLResourceOptions::StorageModeShared;
#[cfg(not(target_os = "ios"))]
let options = MTLResourceOptions::StorageModeManaged;
let input_buffer = new_buffer(&device, input);
let ids_buffer = new_buffer(&device, ids);
Expand Down
3 changes: 3 additions & 0 deletions candle-metal-kernels/tmp/affine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ fn main() {

fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {
let command_queue = device.new_command_queue();
#[cfg(target_os = "ios")]
let options = MTLResourceOptions::StorageModeShared;
#[cfg(not(target_os = "ios"))]
let options = MTLResourceOptions::StorageModeManaged;

let iterations = 10000;
Expand Down
3 changes: 3 additions & 0 deletions candle-metal-kernels/tmp/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ fn run_binary_bench<T: Clone>(
strided: [binary::strided::Kernel; 4],
) {
let command_queue = device.new_command_queue();
#[cfg(target_os = "ios")]
let options = MTLResourceOptions::StorageModeShared;
#[cfg(not(target_os = "ios"))]
let options = MTLResourceOptions::StorageModeManaged;

let iterations = 1000;
Expand Down
3 changes: 3 additions & 0 deletions candle-metal-kernels/tmp/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ fn run_cast_bench<T: Clone>(
contiguous: &[&'static str],
) {
let command_queue = device.new_command_queue();
#[cfg(target_os = "ios")]
let options = MTLResourceOptions::StorageModeShared;
#[cfg(not(target_os = "ios"))]
let options = MTLResourceOptions::StorageModeManaged;

let iterations = 1000;
Expand Down
3 changes: 3 additions & 0 deletions candle-metal-kernels/tmp/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ fn run_unary_bench<T: Clone>(
strided: [unary::strided::Kernel; 7],
) {
let command_queue = device.new_command_queue();
#[cfg(target_os = "ios")]
let options = MTLResourceOptions::StorageModeShared;
#[cfg(not(target_os = "ios"))]
let options = MTLResourceOptions::StorageModeManaged;

let iterations = 10000;
Expand Down
Loading