Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions candle-core/examples/test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use anyhow::Result;
use candle_core::{Device, Tensor};

/// Metal smoke test: builds a 128x128 `f32` tensor with `Tensor::randn` on Metal device 0,
/// adds the tensor to itself and prints the result.
fn main() -> Result<()> {
    // This requires the code to be run with MTL_CAPTURE_ENABLED=1
    let metal = Device::new_metal(0)?;

    let input = Tensor::randn(0f32, 1.0, (128, 128), &metal)?;
    // `x1` keeps its name because the format string below captures it by identifier.
    let x1 = input.add(&input)?;
    println!("{x1}");

    Ok(())
}
48 changes: 41 additions & 7 deletions candle-core/src/metal_backend/device.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use crate::metal_backend::pool::MetalTensorPool;
use crate::{DType, Result};
use candle_metal_kernels::Kernels;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
use std::path::Path;
use std::sync::{Arc, Mutex, RwLock};
use std::sync::{Arc, LazyLock, Mutex, RwLock};

use super::MetalError;

Expand All @@ -12,7 +13,12 @@ use super::MetalError;
#[cfg(target_os = "ios")]
pub const SHARED_BUFFER_STORAGE_MODE: MTLResourceOptions = MTLResourceOptions::StorageModeShared;
#[cfg(not(target_os = "ios"))]
pub const SHARED_BUFFER_STORAGE_MODE: MTLResourceOptions = MTLResourceOptions::StorageModeManaged;
pub const SHARED_BUFFER_STORAGE_MODE: MTLResourceOptions = MTLResourceOptions::StorageModeShared;

/// Global registry of Metal tensor pools, holding at most one pool per `DeviceId`.
///
/// Pooling is per-device rather than a single global optional pool, so that devices neither
/// contend for nor accidentally share a pool. The map is created lazily on first access and
/// is guarded by an `RwLock`.
pub(crate) static POOLS: LazyLock<RwLock<HashMap<DeviceId, Arc<MetalTensorPool>>>> =
LazyLock::new(|| RwLock::new(HashMap::new()));

/// Unique identifier for Metal devices.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -223,6 +229,27 @@ impl MetalDevice {
commands.wait_until_completed()
}

/// Returns the tensor pool registered for this device, creating it on first use.
///
/// The pool is looked up in the global `POOLS` map under `self.id`. When no entry exists, a
/// new pool is built with `MetalTensorPool::new(self, size_in_bytes)` and stored in the map.
/// `size_in_bytes` therefore only takes effect on the call that creates the pool; later calls
/// return the existing pool regardless of the size they pass.
///
/// # Panics
///
/// Panics if the `POOLS` write lock is poisoned, or if `MetalTensorPool::new` returns an
/// error (it rejects a size of zero).
pub(crate) fn ensure_pool(&self, size_in_bytes: usize) -> Arc<MetalTensorPool> {
    // Fast path: an existing pool only needs the read lock. A poisoned read lock is not
    // reported here; the write lock below surfaces the poisoning instead.
    if let Ok(pools) = POOLS.read() {
        if let Some(pool) = pools.get(&self.id) {
            return Arc::clone(pool);
        }
    }
    // Slow path: the read guard has been released, so another thread may have inserted a
    // pool in the meantime. The entry API re-checks under the write lock before creating.
    let mut pools = POOLS.write().expect("metal tensor pool map poisoned");
    let pool = pools.entry(self.id).or_insert_with(|| {
        let created = MetalTensorPool::new(self, size_in_bytes)
            .expect("failed to create metal tensor pool");
        Arc::new(created)
    });
    Arc::clone(pool)
}

/// Returns a shared reference to the `Kernels` value stored on this device.
pub fn kernels(&self) -> &Kernels {
&self.kernels
}
Expand All @@ -243,7 +270,7 @@ impl MetalDevice {
name: &str,
) -> Result<Arc<Buffer>> {
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
self.allocate_buffer(size, MTLResourceOptions::StorageModeShared, name)
}

pub fn new_buffer_private(
Expand All @@ -253,7 +280,7 @@ impl MetalDevice {
name: &str,
) -> Result<Arc<Buffer>> {
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
self.allocate_buffer(size, metal::MTLResourceOptions::StorageModePrivate, name)
self.allocate_buffer(size, metal::MTLResourceOptions::StorageModeShared, name)
}

/// Creates a new buffer (not necessarily zeroed).
Expand Down Expand Up @@ -291,7 +318,7 @@ impl MetalDevice {
pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
let buffer = self.allocate_buffer(
size_in_bytes as NSUInteger,
MTLResourceOptions::StorageModePrivate,
MTLResourceOptions::StorageModeShared,
"allocate_zeros",
)?;
let command_buffer = self.command_buffer()?;
Expand All @@ -314,8 +341,15 @@ impl MetalDevice {
&self,
size: NSUInteger,
option: MTLResourceOptions,
_name: &str,
name: &str,
) -> Result<Arc<Buffer>> {
// println!("{option:?}");
let pool = self.ensure_pool(8*1024*1024*1024);
if option == MTLResourceOptions::StorageModeShared {
// println!("{name}");
return pool.allocate_buffer(size, name, MTLResourceOptions::StorageModeShared);
}

let mut buffers = self.buffers.write().map_err(MetalError::from)?;
if let Some(b) = find_available_buffer(size, option, &buffers) {
// Cloning also ensures we increment the strong count
Expand All @@ -325,7 +359,7 @@ impl MetalDevice {
let size = buf_size(size);
let subbuffers = buffers.entry((size, option)).or_insert(vec![]);

let new_buffer = self.device.new_buffer(size as NSUInteger, option);
let new_buffer = self.device.new_buffer(size as NSUInteger, MTLResourceOptions::StorageModeShared);
let new_buffer = Arc::new(new_buffer);
subbuffers.push(new_buffer.clone());

Expand Down
1 change: 1 addition & 0 deletions candle-core/src/metal_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use std::sync::{Arc, Mutex, PoisonError, RwLock, TryLockError};

mod device;
pub use device::{DeviceId, MetalDevice, SHARED_BUFFER_STORAGE_MODE};
mod pool;

pub fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> {
BufferOffset {
Expand Down
98 changes: 98 additions & 0 deletions candle-core/src/metal_backend/pool.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};

use crate::{Error, Result};

use super::MetalDevice;
use metal::{Buffer, HeapDescriptor, MTLResourceOptions, MTLStorageMode, NSUInteger};

/// State shared by every clone of a `MetalTensorPool` handle.
#[derive(Debug)]
struct MetalPoolInner {
// Clone of the device the heap was created on.
device: MetalDevice,
// Metal heap that pool buffers are allocated from.
heap: metal::Heap,
// Size in bytes requested for the heap when the pool was created.
capacity: u64,
// Identifier drawn from a process-wide counter when the pool was created.
id: usize,
}

/// Handle to a Metal heap used as a buffer pool.
///
/// The state lives behind an `Arc`, so cloning the handle yields another reference to the
/// same pool rather than a new heap.
#[derive(Clone, Debug)]
pub struct MetalTensorPool {
// Shared pool state (device, heap, capacity and id).
inner: Arc<MetalPoolInner>,
}

impl MetalTensorPool {
    /// Creates a pool backed by one Metal heap of `size_in_bytes` bytes on `device`.
    ///
    /// The heap descriptor is configured with the requested size and shared storage mode;
    /// every other descriptor setting is left at its default. Each pool receives an id from
    /// a process-wide counter that starts at 1.
    ///
    /// # Errors
    ///
    /// Returns an error when `size_in_bytes` is zero.
    pub fn new(device: &MetalDevice, size_in_bytes: usize) -> Result<Self> {
        if size_in_bytes == 0 {
            crate::bail!("metal pool size must be greater than zero")
        }
        let descriptor = HeapDescriptor::new();
        descriptor.set_size(size_in_bytes as NSUInteger);
        descriptor.set_storage_mode(MTLStorageMode::Shared);

        // NOTE(review): the heap returned here is used without any validity check; confirm
        // how a failed heap allocation surfaces through the `metal` crate.
        let heap = device.device.new_heap(&descriptor);

        // Only uniqueness of the id matters, so relaxed ordering is sufficient.
        static NEXT_ID: AtomicUsize = AtomicUsize::new(1);
        let id = NEXT_ID.fetch_add(1, Ordering::Relaxed);

        Ok(Self {
            inner: Arc::new(MetalPoolInner {
                device: device.clone(),
                heap,
                capacity: size_in_bytes as u64,
                id,
            }),
        })
    }

    /// Returns the id assigned to this pool when it was created.
    pub fn id(&self) -> usize {
        self.inner.id
    }

    /// Returns the device this pool was created on.
    pub fn device(&self) -> &MetalDevice {
        &self.inner.device
    }

    /// Returns the heap size in bytes requested when the pool was created.
    pub fn capacity(&self) -> u64 {
        self.inner.capacity
    }

    /// Allocates a buffer of `size` bytes from the pool's heap and labels it with `label`.
    ///
    /// `options` is currently ignored: buffers are always created with
    /// `StorageModeShared | HazardTrackingModeTracked`, matching the shared storage mode the
    /// heap is created with.
    ///
    /// # Errors
    ///
    /// Returns an error when `size` exceeds the pool capacity, when the aligned size reported
    /// by the device exceeds the space currently available in the heap, or when the heap
    /// returns no buffer.
    pub fn allocate_buffer(
        &self,
        size: NSUInteger,
        label: &str,
        options: MTLResourceOptions,
    ) -> Result<Arc<Buffer>> {
        if size > self.inner.capacity {
            crate::bail!(
                "pool allocation of {size} bytes exceeds pool capacity {}",
                self.inner.capacity
            )
        }
        // The caller's options are deliberately unused; see the method documentation.
        let _ = options;
        // NOTE(review): the heap descriptor leaves hazard tracking at its default while the
        // buffers request tracked mode; confirm this combination is intended.
        let heap_options =
            MTLResourceOptions::StorageModeShared | MTLResourceOptions::HazardTrackingModeTracked;

        // Ask the device how much heap space (and which alignment) this buffer really needs,
        // then compare against what the heap can still provide at that alignment.
        let size_align = self
            .inner
            .device
            .device
            .heap_buffer_size_and_align(size, heap_options);
        let align = size_align.align.max(1);
        let available = self.inner.heap.max_available_size_with_alignment(align);
        if size_align.size > available {
            crate::bail!(
                "pool allocation of {size} bytes exceeds remaining capacity {}",
                available
            )
        }

        let buffer = self
            .inner
            .heap
            .new_buffer(size, heap_options)
            .ok_or_else(|| Error::msg("metal heap allocation returned null"))?;
        buffer.set_label(label);
        Ok(Arc::new(buffer))
    }
}
26 changes: 26 additions & 0 deletions candle-core/src/tensor_pool.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/// Backend-tagged tensor pool handle.
///
/// `Metal` is the only variant and is compiled in only when the `metal` feature is enabled;
/// without that feature the enum has no variants.
#[derive(Clone, Debug)]
pub enum TensorPool {
// Pool provided by the Metal backend.
#[cfg(feature = "metal")]
Metal(crate::metal_backend::MetalTensorPool),
}

impl TensorPool {
    /// Wraps a Metal pool in a `TensorPool`.
    #[cfg(feature = "metal")]
    pub fn from_metal(pool: crate::metal_backend::MetalTensorPool) -> Self {
        TensorPool::Metal(pool)
    }

    /// Borrows the wrapped Metal pool.
    ///
    /// `Metal` is the only variant, so this currently always returns `Some`.
    #[cfg(feature = "metal")]
    pub fn as_metal(&self) -> Option<&crate::metal_backend::MetalTensorPool> {
        let Self::Metal(inner) = self;
        Some(inner)
    }

    /// Consumes the handle and returns the wrapped Metal pool.
    #[cfg(feature = "metal")]
    pub fn into_metal(self) -> crate::metal_backend::MetalTensorPool {
        let Self::Metal(inner) = self;
        inner
    }
}
Loading