|
1 | 1 | //! CoreML Execution Provider |
2 | 2 | //! |
3 | 3 | //! Execution provider for Apple CoreML/Metal acceleration on macOS. |
| 4 | +//! Supports CPU, GPU, and Neural Engine (ANE) on Apple Silicon. |
4 | 5 |
|
5 | | -use ort::execution_providers::{CoreMLExecutionProvider, ExecutionProviderDispatch}; |
| 6 | +use ort::ep::coreml::{ComputeUnits as OrtComputeUnits, CoreML, ModelFormat}; |
| 7 | +use ort::execution_providers::ExecutionProviderDispatch; |
6 | 8 |
|
/// Compute units for CoreML execution.
///
/// Selects which Apple hardware units CoreML may schedule work on.
/// `Hash` is derived (alongside `Eq`) so the selection can be used as a
/// set/map key, e.g. when caching sessions per configuration.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum ComputeUnits {
    /// Use all available compute units (CPU, GPU, Neural Engine)
    #[default]
    All,
    /// Use CPU and Neural Engine (ANE) - optimal for most models on Apple Silicon
    CpuAndNeuralEngine,
    /// Use CPU and GPU only (no ANE)
    CpuAndGpu,
    /// Use CPU only
    CpuOnly,
}
18 | 22 |
|
| 23 | +impl ComputeUnits { |
| 24 | + /// Convert to ort's ComputeUnits enum |
| 25 | + fn to_ort(self) -> OrtComputeUnits { |
| 26 | + match self { |
| 27 | + ComputeUnits::All => OrtComputeUnits::All, |
| 28 | + ComputeUnits::CpuAndNeuralEngine => OrtComputeUnits::CPUAndNeuralEngine, |
| 29 | + ComputeUnits::CpuAndGpu => OrtComputeUnits::CPUAndGPU, |
| 30 | + ComputeUnits::CpuOnly => OrtComputeUnits::CPUOnly, |
| 31 | + } |
| 32 | + } |
| 33 | +} |
| 34 | + |
19 | 35 | /// CoreML execution provider configuration |
20 | 36 | #[derive(Debug, Clone)] |
21 | 37 | pub struct CoreMLConfig { |
22 | 38 | /// Which compute units to use |
23 | 39 | pub compute_units: ComputeUnits, |
| 40 | + /// Enable subgraph execution (for models with control flow) |
| 41 | + pub enable_subgraphs: bool, |
24 | 42 | /// Require static input shapes |
25 | 43 | pub require_static_shapes: bool, |
26 | | - /// Enable model caching |
27 | | - pub enable_cache: bool, |
| 44 | + /// Model format (NeuralNetwork or MLProgram) |
| 45 | + pub model_format: Option<CoreMLModelFormat>, |
| 46 | + /// Cache directory for compiled models |
| 47 | + pub cache_dir: Option<String>, |
| 48 | +} |
| 49 | + |
/// CoreML model format.
///
/// `Hash` is derived alongside `Eq` for consistency with `ComputeUnits`,
/// so the format can participate in hashed configuration keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CoreMLModelFormat {
    /// NeuralNetwork format - better compatibility with older macOS/iOS
    NeuralNetwork,
    /// MLProgram format - supports more operators, potentially better performance
    MLProgram,
}
29 | 58 |
|
30 | 59 | impl Default for CoreMLConfig { |
31 | 60 | fn default() -> Self { |
32 | 61 | Self { |
33 | 62 | compute_units: ComputeUnits::All, |
| 63 | + enable_subgraphs: false, |
34 | 64 | require_static_shapes: false, |
35 | | - enable_cache: true, |
| 65 | + model_format: None, |
| 66 | + cache_dir: None, |
36 | 67 | } |
37 | 68 | } |
38 | 69 | } |
@@ -60,36 +91,77 @@ impl CoreMLProvider { |
60 | 91 | self |
61 | 92 | } |
62 | 93 |
|
63 | | - /// Set neural engine only mode |
| 94 | + /// Set Neural Engine only mode (CPU + ANE, no GPU) |
| 95 | + /// This is optimal for most inference tasks on Apple Silicon |
64 | 96 | pub fn neural_engine_only(self) -> Self { |
65 | | - self.with_compute_units(ComputeUnits::All) |
| 97 | + self.with_compute_units(ComputeUnits::CpuAndNeuralEngine) |
66 | 98 | } |
67 | 99 |
|
68 | | - /// Set GPU only mode (no ANE) |
| 100 | + /// Set GPU only mode (CPU + GPU, no ANE) |
69 | 101 | pub fn gpu_only(self) -> Self { |
70 | 102 | self.with_compute_units(ComputeUnits::CpuAndGpu) |
71 | 103 | } |
72 | 104 |
|
| 105 | + /// Set CPU only mode |
| 106 | + pub fn cpu_only(self) -> Self { |
| 107 | + self.with_compute_units(ComputeUnits::CpuOnly) |
| 108 | + } |
| 109 | + |
| 110 | + /// Enable subgraph execution for models with control flow operators |
| 111 | + pub fn with_subgraphs(mut self, enable: bool) -> Self { |
| 112 | + self.config.enable_subgraphs = enable; |
| 113 | + self |
| 114 | + } |
| 115 | + |
| 116 | + /// Require static input shapes |
| 117 | + pub fn with_static_shapes(mut self, require: bool) -> Self { |
| 118 | + self.config.require_static_shapes = require; |
| 119 | + self |
| 120 | + } |
| 121 | + |
| 122 | + /// Set model format |
| 123 | + pub fn with_model_format(mut self, format: CoreMLModelFormat) -> Self { |
| 124 | + self.config.model_format = Some(format); |
| 125 | + self |
| 126 | + } |
| 127 | + |
| 128 | + /// Set cache directory for compiled models |
| 129 | + pub fn with_cache_dir(mut self, dir: impl Into<String>) -> Self { |
| 130 | + self.config.cache_dir = Some(dir.into()); |
| 131 | + self |
| 132 | + } |
| 133 | + |
73 | 134 | /// Convert to ORT execution provider dispatch |
74 | 135 | pub fn into_dispatch(self) -> ExecutionProviderDispatch { |
75 | | - let mut provider = CoreMLExecutionProvider::default(); |
76 | | - |
77 | | - match self.config.compute_units { |
78 | | - ComputeUnits::All => { |
79 | | - // Default - uses all available compute units |
80 | | - } |
81 | | - ComputeUnits::CpuAndGpu => { |
82 | | - provider = provider.with_ane_only(); |
83 | | - } |
84 | | - ComputeUnits::CpuOnly => { |
85 | | - provider = provider.with_cpu_only(); |
86 | | - } |
| 136 | + let mut provider = CoreML::default(); |
| 137 | + |
| 138 | + // Set compute units |
| 139 | + provider = provider.with_compute_units(self.config.compute_units.to_ort()); |
| 140 | + |
| 141 | + // Enable subgraphs if requested |
| 142 | + if self.config.enable_subgraphs { |
| 143 | + provider = provider.with_subgraphs(true); |
87 | 144 | } |
88 | 145 |
|
| 146 | + // Require static shapes if requested |
89 | 147 | if self.config.require_static_shapes { |
90 | | - provider = provider.with_subgraphs(); |
| 148 | + provider = provider.with_static_input_shapes(true); |
| 149 | + } |
| 150 | + |
| 151 | + // Set model format if specified |
| 152 | + if let Some(format) = self.config.model_format { |
| 153 | + let ort_format = match format { |
| 154 | + CoreMLModelFormat::NeuralNetwork => ModelFormat::NeuralNetwork, |
| 155 | + CoreMLModelFormat::MLProgram => ModelFormat::MLProgram, |
| 156 | + }; |
| 157 | + provider = provider.with_model_format(ort_format); |
| 158 | + } |
| 159 | + |
| 160 | + // Set cache directory if specified |
| 161 | + if let Some(dir) = &self.config.cache_dir { |
| 162 | + provider = provider.with_model_cache_dir(dir); |
91 | 163 | } |
92 | 164 |
|
93 | | - provider.build().into() |
| 165 | + provider.build() |
94 | 166 | } |
95 | 167 | } |
0 commit comments