Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,18 @@ jobs:
if: ${{ !contains(github.event.head_commit.message, 'ci skip') && !contains(github.event.head_commit.message, 'test skip') }}
strategy:
fail-fast: false
max-parallel: 3
matrix:
os:
- 'ubuntu-latest'
- 'windows-latest'
- 'macos-latest'
python-version:
- '3.8'
- '3.9'
- '3.10'
# - '3.9'
# - '3.10'
- '3.11'
- '3.12'
# - '3.12'
- '3.13'
install:
- 'full'
Expand Down
117 changes: 102 additions & 15 deletions imgutils/utils/onnxruntime.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
"""
Overview:
Management of onnx models.
Management of ONNX models with automatic runtime detection and provider selection.

This module provides utilities for loading and managing ONNX models with support for
different execution providers (CPU, CUDA, TensorRT). It automatically handles the
installation of onnxruntime based on the system configuration and provides a
convenient interface for model inference.
"""
import logging
import os
import shutil
import warnings
from typing import Optional

from hbutils.system import pip_install
Expand All @@ -15,6 +21,14 @@


def _ensure_onnxruntime():
"""
Ensure that onnxruntime is installed on the system.

This function automatically detects if NVIDIA GPU is available and installs
the appropriate version of onnxruntime (GPU or CPU version).

:raises ImportError: If installation fails
"""
try:
import onnxruntime
except (ImportError, ModuleNotFoundError):
Expand All @@ -39,13 +53,35 @@ def _ensure_onnxruntime():

def get_onnx_provider(provider: Optional[str] = None):
"""
Overview:
Get onnx provider.
Get the appropriate ONNX execution provider based on system capabilities and user preference.

This function automatically detects available execution providers and returns the most
suitable one. It supports aliases for common providers and falls back to CPU execution
if GPU providers are not available.

:param provider: The provider for ONNX runtime. ``None`` by default and will automatically detect
if the ``CUDAExecutionProvider`` is available. If it is available, it will be used,
otherwise the default ``CPUExecutionProvider`` will be used.
:return: String of the provider.
otherwise the default ``CPUExecutionProvider`` will be used. Supported aliases include
'gpu' for CUDAExecutionProvider and 'trt' for TensorrtExecutionProvider.
:type provider: Optional[str]

:return: String name of the selected execution provider.
:rtype: str

:raises ValueError: If the specified provider is not supported or available.

Example::
>>> # Auto-detect provider
>>> provider = get_onnx_provider()
>>> print(provider) # 'CUDAExecutionProvider' or 'CPUExecutionProvider'

>>> # Explicitly request GPU provider
>>> provider = get_onnx_provider('gpu')
>>> print(provider) # 'CUDAExecutionProvider'

>>> # Request CPU provider
>>> provider = get_onnx_provider('cpu')
>>> print(provider) # 'CPUExecutionProvider'
"""
if not provider:
if "CUDAExecutionProvider" in get_available_providers():
Expand All @@ -63,34 +99,85 @@ def get_onnx_provider(provider: Optional[str] = None):
f'but unsupported provider {provider!r} found.')


def _open_onnx_model(ckpt: str, provider: str, use_cpu: bool = True,
                     cuda_device_id: Optional[int] = None) -> InferenceSession:
    """
    Internal function to create and configure an ONNX inference session.

    This function handles the low-level configuration of the ONNX runtime session,
    including graph optimization settings and provider-specific configurations
    (e.g. pinning CUDA inference to a specific device).

    :param ckpt: Path to the ONNX model file.
    :type ckpt: str
    :param provider: Name of the execution provider to use
        (e.g. ``'CPUExecutionProvider'``, ``'CUDAExecutionProvider'``).
    :type provider: str
    :param use_cpu: Whether to include CPU provider as a fallback. Defaults to True.
    :type use_cpu: bool
    :param cuda_device_id: Specific CUDA device ID to use for GPU inference.
        Only honored when ``provider`` is ``'CUDAExecutionProvider'``; otherwise
        a :class:`UserWarning` is emitted and the value is ignored.
    :type cuda_device_id: Optional[int]

    :return: Configured ONNX inference session.
    :rtype: InferenceSession
    """
    options = SessionOptions()
    options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
    if provider == "CPUExecutionProvider":
        # Use every available core for intra-op parallelism on CPU inference.
        options.intra_op_num_threads = os.cpu_count()

    if provider == 'CUDAExecutionProvider' and cuda_device_id is not None:
        # Provider entries may be (name, options) tuples; pin to the requested GPU.
        providers = [
            ('CUDAExecutionProvider', {'device_id': cuda_device_id}),
        ]
    else:
        if provider != 'CUDAExecutionProvider' and cuda_device_id is not None:
            # A device id only makes sense for CUDA; warn instead of failing.
            warnings.warn(UserWarning(
                'CUDA device ID specified but provider is not CUDAExecutionProvider. The device ID will be ignored.'))
        providers = [provider]
    if use_cpu and "CPUExecutionProvider" not in providers:
        providers.append("CPUExecutionProvider")

    logging.info(f'Model {ckpt!r} loaded with provider {provider!r}')
    return InferenceSession(ckpt, options, providers=providers)


def open_onnx_model(ckpt: str, mode: str = None, cuda_device_id: Optional[int] = None) -> InferenceSession:
    """
    Open an ONNX model and create a configured inference session.

    This function provides a high-level interface for loading ONNX models with
    automatic provider selection and optimization. It supports environment variable
    configuration for runtime provider selection.

    :param ckpt: Path to the ONNX model file to load.
    :type ckpt: str
    :param mode: Provider of the ONNX runtime. Default is ``None`` which means the provider will be auto-detected,
        see :func:`get_onnx_provider` for more details. Can also be controlled via ONNX_MODE environment variable.
    :type mode: Optional[str]
    :param cuda_device_id: Specific CUDA device ID to use for GPU inference. Only effective when using CUDA provider.
    :type cuda_device_id: Optional[int]

    :return: A loaded and configured ONNX inference session ready for prediction.
    :rtype: InferenceSession

    .. note::
        When ``mode`` is set to ``None``, it will attempt to detect the environment variable ``ONNX_MODE``.
        This means you can decide which ONNX runtime to use by setting the environment variable. For example,
        on Linux, executing ``export ONNX_MODE=cpu`` will ignore any existing CUDA and force the model inference
        to run on CPU.

    Example::
        >>> # Load model with auto-detected provider
        >>> session = open_onnx_model('model.onnx')

        >>> # Force CPU execution
        >>> session = open_onnx_model('model.onnx', mode='cpu')

        >>> # Use specific CUDA device
        >>> session = open_onnx_model('model.onnx', mode='gpu', cuda_device_id=1)
    """
    # An explicit ``mode`` wins; otherwise fall back to the ONNX_MODE env var
    # (or full auto-detection when neither is set).
    return _open_onnx_model(
        ckpt=ckpt,
        provider=get_onnx_provider(mode or os.environ.get('ONNX_MODE', None)),
        use_cpu=True,
        cuda_device_id=cuda_device_id,
    )
2 changes: 1 addition & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def text_aligner():
return TextAligner().multiple_lines()


@pytest.fixture(autouse=True, scope='module')
@pytest.fixture(autouse=True, scope='class')
def clean_hf_cache():
try:
yield
Expand Down
Loading