Skip to content

Commit ac62503

Browse files
committed
feat :: coreml, bench acceleration
1 parent d495087 commit ac62503

4 files changed

Lines changed: 107 additions & 27 deletions

File tree

Lines changed: 95 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,69 @@
11
//! CoreML Execution Provider
22
//!
33
//! Execution provider for Apple CoreML/Metal acceleration on macOS.
4+
//! Supports CPU, GPU, and Neural Engine (ANE) on Apple Silicon.
45
5-
use ort::execution_providers::{CoreMLExecutionProvider, ExecutionProviderDispatch};
6+
use ort::ep::coreml::{ComputeUnits as OrtComputeUnits, CoreML, ModelFormat};
7+
use ort::execution_providers::ExecutionProviderDispatch;
68

79
/// Compute units for CoreML execution
8-
#[derive(Debug, Clone, Copy, Default)]
10+
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
911
pub enum ComputeUnits {
1012
/// Use all available compute units (CPU, GPU, Neural Engine)
1113
#[default]
1214
All,
13-
/// Use CPU and GPU only
15+
/// Use CPU and Neural Engine (ANE) - optimal for most models on Apple Silicon
16+
CpuAndNeuralEngine,
17+
/// Use CPU and GPU only (no ANE)
1418
CpuAndGpu,
1519
/// Use CPU only
1620
CpuOnly,
1721
}
1822

23+
impl ComputeUnits {
24+
/// Convert to ort's ComputeUnits enum
25+
fn to_ort(self) -> OrtComputeUnits {
26+
match self {
27+
ComputeUnits::All => OrtComputeUnits::All,
28+
ComputeUnits::CpuAndNeuralEngine => OrtComputeUnits::CPUAndNeuralEngine,
29+
ComputeUnits::CpuAndGpu => OrtComputeUnits::CPUAndGPU,
30+
ComputeUnits::CpuOnly => OrtComputeUnits::CPUOnly,
31+
}
32+
}
33+
}
34+
1935
/// CoreML execution provider configuration
2036
#[derive(Debug, Clone)]
2137
pub struct CoreMLConfig {
2238
/// Which compute units to use
2339
pub compute_units: ComputeUnits,
40+
/// Enable subgraph execution (for models with control flow)
41+
pub enable_subgraphs: bool,
2442
/// Require static input shapes
2543
pub require_static_shapes: bool,
26-
/// Enable model caching
27-
pub enable_cache: bool,
44+
/// Model format (NeuralNetwork or MLProgram)
45+
pub model_format: Option<CoreMLModelFormat>,
46+
/// Cache directory for compiled models
47+
pub cache_dir: Option<String>,
48+
}
49+
50+
/// CoreML model format
51+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
52+
pub enum CoreMLModelFormat {
53+
/// NeuralNetwork format - better compatibility with older macOS/iOS
54+
NeuralNetwork,
55+
/// MLProgram format - supports more operators, potentially better performance
56+
MLProgram,
2857
}
2958

3059
impl Default for CoreMLConfig {
3160
fn default() -> Self {
3261
Self {
3362
compute_units: ComputeUnits::All,
63+
enable_subgraphs: false,
3464
require_static_shapes: false,
35-
enable_cache: true,
65+
model_format: None,
66+
cache_dir: None,
3667
}
3768
}
3869
}
@@ -60,36 +91,77 @@ impl CoreMLProvider {
6091
self
6192
}
6293

63-
/// Set neural engine only mode
94+
/// Set Neural Engine only mode (CPU + ANE, no GPU)
95+
/// This is optimal for most inference tasks on Apple Silicon
6496
pub fn neural_engine_only(self) -> Self {
65-
self.with_compute_units(ComputeUnits::All)
97+
self.with_compute_units(ComputeUnits::CpuAndNeuralEngine)
6698
}
6799

68-
/// Set GPU only mode (no ANE)
100+
/// Set GPU only mode (CPU + GPU, no ANE)
69101
pub fn gpu_only(self) -> Self {
70102
self.with_compute_units(ComputeUnits::CpuAndGpu)
71103
}
72104

105+
/// Set CPU only mode
106+
pub fn cpu_only(self) -> Self {
107+
self.with_compute_units(ComputeUnits::CpuOnly)
108+
}
109+
110+
/// Enable subgraph execution for models with control flow operators
111+
pub fn with_subgraphs(mut self, enable: bool) -> Self {
112+
self.config.enable_subgraphs = enable;
113+
self
114+
}
115+
116+
/// Require static input shapes
117+
pub fn with_static_shapes(mut self, require: bool) -> Self {
118+
self.config.require_static_shapes = require;
119+
self
120+
}
121+
122+
/// Set model format
123+
pub fn with_model_format(mut self, format: CoreMLModelFormat) -> Self {
124+
self.config.model_format = Some(format);
125+
self
126+
}
127+
128+
/// Set cache directory for compiled models
129+
pub fn with_cache_dir(mut self, dir: impl Into<String>) -> Self {
130+
self.config.cache_dir = Some(dir.into());
131+
self
132+
}
133+
73134
/// Convert to ORT execution provider dispatch
74135
pub fn into_dispatch(self) -> ExecutionProviderDispatch {
75-
let mut provider = CoreMLExecutionProvider::default();
76-
77-
match self.config.compute_units {
78-
ComputeUnits::All => {
79-
// Default - uses all available compute units
80-
}
81-
ComputeUnits::CpuAndGpu => {
82-
provider = provider.with_ane_only();
83-
}
84-
ComputeUnits::CpuOnly => {
85-
provider = provider.with_cpu_only();
86-
}
136+
let mut provider = CoreML::default();
137+
138+
// Set compute units
139+
provider = provider.with_compute_units(self.config.compute_units.to_ort());
140+
141+
// Enable subgraphs if requested
142+
if self.config.enable_subgraphs {
143+
provider = provider.with_subgraphs(true);
87144
}
88145

146+
// Require static shapes if requested
89147
if self.config.require_static_shapes {
90-
provider = provider.with_subgraphs();
148+
provider = provider.with_static_input_shapes(true);
149+
}
150+
151+
// Set model format if specified
152+
if let Some(format) = self.config.model_format {
153+
let ort_format = match format {
154+
CoreMLModelFormat::NeuralNetwork => ModelFormat::NeuralNetwork,
155+
CoreMLModelFormat::MLProgram => ModelFormat::MLProgram,
156+
};
157+
provider = provider.with_model_format(ort_format);
158+
}
159+
160+
// Set cache directory if specified
161+
if let Some(dir) = &self.config.cache_dir {
162+
provider = provider.with_model_cache_dir(dir);
91163
}
92164

93-
provider.build().into()
165+
provider.build()
94166
}
95167
}

crates/airml-providers/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ pub use cpu::CpuProvider;
1111
pub use ort::execution_providers::ExecutionProviderDispatch;
1212

1313
#[cfg(feature = "coreml")]
14-
pub use coreml::{CoreMLConfig, CoreMLProvider};
14+
pub use coreml::{ComputeUnits, CoreMLConfig, CoreMLModelFormat, CoreMLProvider};
1515

1616
/// Available execution providers
1717
#[derive(Debug, Clone, Copy, PartialEq, Eq)]

src/commands/bench.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ pub fn execute(args: &BenchArgs) -> Result<()> {
2222
"cpu" => vec![airml_providers::CpuProvider::default().into_dispatch()],
2323
#[cfg(feature = "coreml")]
2424
"coreml" => vec![airml_providers::CoreMLProvider::default().into_dispatch()],
25+
#[cfg(feature = "coreml")]
26+
"neural-engine" => vec![airml_providers::CoreMLProvider::default()
27+
.neural_engine_only()
28+
.into_dispatch()],
2529
_ => auto_select_providers(),
2630
};
2731

src/commands/run.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,13 @@ fn select_providers(provider_name: &str) -> Result<Vec<airml_providers::Executio
7272
"auto" => Ok(auto_select_providers()),
7373
"cpu" => Ok(vec![airml_providers::CpuProvider::default().into_dispatch()]),
7474
#[cfg(feature = "coreml")]
75-
"coreml" | "neural-engine" => {
76-
Ok(vec![airml_providers::CoreMLProvider::default().into_dispatch()])
77-
}
75+
"coreml" => Ok(vec![airml_providers::CoreMLProvider::default().into_dispatch()]),
76+
#[cfg(feature = "coreml")]
77+
"neural-engine" => Ok(vec![
78+
airml_providers::CoreMLProvider::default()
79+
.neural_engine_only()
80+
.into_dispatch(),
81+
]),
7882
_ => {
7983
println!("Warning: Unknown provider '{}', using auto-selection", provider_name);
8084
Ok(auto_select_providers())

0 commit comments

Comments
 (0)