Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions candle-core/examples/metal_pool.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use candle_core::{DType, Device, Result, Tensor};

Check warning on line 1 in candle-core/examples/metal_pool.rs

View workflow job for this annotation

GitHub Actions / Test Suite (macOS-latest, stable)

unused imports: `DType`, `Device`, and `Tensor`

#[cfg(feature = "metal")]
fn run() -> Result<()> {
if !candle_core::utils::metal_is_available() {
eprintln!("Metal is not available on this system.");
return Ok(());
}

let device = Device::new_metal(0)?;

let input = Tensor::arange(0f32, 16f32, &Device::Cpu)?
.to_dtype(DType::F32)?
.to_device(&device)?;

// Allocate a 4 MB pool for intermediate activations.
let pooled = input.start_pool(4 * 1024 * 1024)?;
println!("{input}");
println!("{pooled}");
println!("{}", pooled.sin()?);

let logits = pooled.sin()?.mul(&pooled.cos()?)?;
let final_tensor = logits.tanh()?.leave_pool()?;

println!("final tensor: {:?}", final_tensor.to_vec1::<f32>()?);

Ok(())
}

#[cfg(not(feature = "metal"))]
fn run() -> Result<()> {
eprintln!("Rebuild candle-core with the `metal` feature to run this example.");
Ok(())
}

fn main() -> Result<()> {
run()
}
21 changes: 16 additions & 5 deletions candle-core/src/metal_backend/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::collections::HashMap;
use std::path::Path;
use std::sync::{Arc, Mutex, RwLock};

use super::MetalError;
use super::{buffer_offset_for_output, current_pool, register_pool_allocation, MetalError};

// iOS and macOS have different storage modes for shared buffers.
// due to the GPU/CPU management differences.
Expand Down Expand Up @@ -242,8 +242,16 @@ impl MetalDevice {
dtype: DType,
name: &str,
) -> Result<Arc<Buffer>> {
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
let size_bytes = element_count * dtype.size_in_bytes();
if let Some(pool) = current_pool() {
let allocation = pool.allocate(size_bytes, dtype.size_in_bytes())?;
let buffer = Arc::clone(allocation.buffer());
register_pool_allocation(&buffer, allocation);
Ok(buffer)
} else {
let size = size_bytes as NSUInteger;
self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
}
}

pub fn new_buffer_private(
Expand All @@ -252,8 +260,7 @@ impl MetalDevice {
dtype: DType,
name: &str,
) -> Result<Arc<Buffer>> {
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
self.allocate_buffer(size, metal::MTLResourceOptions::StorageModePrivate, name)
self.new_buffer(element_count, dtype, name)
}

/// Creates a new buffer (not necessarily zeroed).
Expand Down Expand Up @@ -288,6 +295,10 @@ impl MetalDevice {
Ok(new_buffer)
}

pub fn buffer_offset<'a>(&self, buffer: &'a Arc<Buffer>) -> super::BufferOffset<'a> {
buffer_offset_for_output(buffer)
}

pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
let buffer = self.allocate_buffer(
size_in_bytes as NSUInteger,
Expand Down
Loading
Loading