6 changes: 3 additions & 3 deletions plugin_execution_providers/tensorrt/CMakeLists.txt
@@ -28,7 +28,7 @@ endif()
add_definitions(-DONNX_NAMESPACE=onnx)
add_definitions(-DONNX_ML)
add_definitions(-DNOMINMAX)
-file(GLOB tensorrt_src "./*.cc" "./utils/*.cc" "./cuda/unary_elementwise_ops_impl.cu" "./*.h")
+file(GLOB tensorrt_src "./src/*.cc" "./src/utils/*.cc" "./src/cuda/unary_elementwise_ops_impl.cu" "./src/*.h")
add_library(TensorRTEp SHARED ${tensorrt_src})

if (NOT ORT_HOME)
@@ -111,7 +111,7 @@ if (WIN32) # Windows
"${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx_proto.lib")

set(TRT_EP_LIB_LINK_FLAG
"-DEF:${CMAKE_SOURCE_DIR}/tensorrt_execution_provider.def")
"-DEF:${CMAKE_SOURCE_DIR}/src/tensorrt_execution_provider.def")

else() # Linux
set(ORT_LIB "${ORT_HOME}/lib/libonnxruntime.so")
@@ -142,7 +142,7 @@ set_property(TARGET TensorRTEp APPEND_STRING PROPERTY LINK_FLAGS
${TRT_EP_LIB_LINK_FLAG})

target_include_directories(TensorRTEp PUBLIC "${ORT_HOME}/include"
"./utils"
"./src/utils"
"/usr/local/cuda/include"
"${TENSORRT_HOME}/include"
"${DEPS_PATH}/flatbuffers-src/include"
56 changes: 56 additions & 0 deletions plugin_execution_providers/tensorrt/README.md
@@ -0,0 +1,56 @@
# Plugin TensorRT EP

This folder contains:
- The plugin TRT EP implementation
- Instructions for building the plugin TRT EP
- Instructions for building a Python wheel for the plugin TRT EP
- An example of running inference with the plugin TRT EP using the Python API

The plugin TRT EP is migrated from the original TRT EP and implements the interfaces required for a plugin EP to interact with ONNX Runtime via the EP ABI (introduced in ORT 1.23.0), including `OrtEpFactory`, `OrtEp`, `OrtNodeComputeInfo`, and `OrtDataTransferImpl`.

## How to build (on Windows) ##
```bash
mkdir build; cd build
```
```bash
cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug -DTENSORRT_HOME=C:/folder/to/trt -DORT_HOME=C:/folder/to/ort
```
```bash
cmake --build ./ --config Debug
```

If the build succeeds, the TRT EP DLL is generated at:
```
C:\repos\onnxruntime-inference-examples\plugin_execution_providers\tensorrt\build> ls .\Debug

TensorRTEp.dll
```


Note: `ORT_HOME` should contain the `lib` and `include` folders, laid out as below:
```
C:\folder\to\ort
├── lib
│   ├── onnxruntime.dll
│   ├── onnxruntime.lib
│   ├── onnxruntime.pdb
│   └── ...
└── include
    ├── onnxruntime_c_api.h
    ├── onnxruntime_ep_c_api.h
    ├── onnxruntime_cxx_api.h
    ├── onnxruntime_cxx_inline.h
    └── ...
```
## How to build python wheel (on Windows) ##
```
python setup.py bdist_wheel
```
Once it's done, you will see the wheel file at:
```
C:\repos\onnxruntime-inference-examples\plugin_execution_providers\tensorrt> ls .\dist

plugin_trt_ep-0.1.0-cp312-cp312-win_amd64.whl
```
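## How to run inference with plugin TRT EP using python API ##
Install the generated wheel into the same Python environment as the `onnxruntime` package (e.g. `pip install .\dist\plugin_trt_ep-0.1.0-cp312-cp312-win_amd64.whl`), then register the plugin library and create a session with its `OrtEpDevice`. The snippet below is a condensed sketch of `example/plugin_ep_inference.py`, which contains the full, commented flow:
```python
import numpy as np
import onnxruntime as onnxrt
import plugin_trt_ep

# Register the plugin EP library under an application-chosen name.
onnxrt.register_execution_provider_library("TensorRTEp", plugin_trt_ep.get_path())

# Pick the OrtEpDevice exposed by the plugin and append it to the session options.
trt_devices = [d for d in onnxrt.get_ep_devices() if d.ep_name == "TensorRTEp"]
sess_options = onnxrt.SessionOptions()
sess_options.add_provider_for_devices(trt_devices, {"trt_engine_cache_enable": "1"})

# Run a small model (see the example for where to get mul_1.onnx).
sess = onnxrt.InferenceSession(r"C:\models\mul_1.onnx", sess_options=sess_options)
x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
print(sess.run([], {sess.get_inputs()[0].name: x}))

# Unregister only after all sessions using the library have been released.
del sess
onnxrt.unregister_execution_provider_library("TensorRTEp")
```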
52 changes: 52 additions & 0 deletions plugin_execution_providers/tensorrt/example/plugin_ep_inference.py
@@ -0,0 +1,52 @@
import onnxruntime as onnxrt
import plugin_trt_ep
import numpy as np

# Path to the plugin EP library
ep_lib_path = plugin_trt_ep.get_path()
# Registration name can be anything the application chooses
ep_registration_name = "TensorRTEp"
# EP name should match the name assigned by the EP factory when creating the EP (i.e., in the EP factory's CreateEp implementation)
ep_name = ep_registration_name

# Register plugin EP library with ONNX Runtime
onnxrt.register_execution_provider_library(ep_registration_name, ep_lib_path)

#
# Create ORT session with explicit OrtEpDevice(s)
#

# Find the OrtEpDevice for "TensorRTEp"
ep_devices = onnxrt.get_ep_devices()
trt_ep_device = None
for ep_device in ep_devices:
    if ep_device.ep_name == ep_name:
        trt_ep_device = ep_device
        break

assert trt_ep_device is not None

sess_options = onnxrt.SessionOptions()

# Equivalent to the C API's SessionOptionsAppendExecutionProvider_V2, which appends "TensorRTEp" to the ORT session options
sess_options.add_provider_for_devices([trt_ep_device], {'trt_engine_cache_enable': '1'})

assert sess_options.has_providers()

# Create ORT session with "TensorRTEp" plugin EP
sess = onnxrt.InferenceSession("C:\\models\\mul_1.onnx", sess_options=sess_options)

# Run sample model and check output
x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
input_name = sess.get_inputs()[0].name
res = sess.run([], {input_name: x})
output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32)
np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08)

# Unregister the library using the application-specified registration name.
# Must only unregister a library after all sessions that use the library have been released.
onnxrt.unregister_execution_provider_library(ep_registration_name)


# Note:
# The mul_1.onnx can be found here:
# https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/testdata/mul_1.onnx
28 changes: 28 additions & 0 deletions plugin_execution_providers/tensorrt/plugin_trt_ep/__init__.py
@@ -0,0 +1,28 @@
import os
import importlib.resources
import ctypes
import onnxruntime as ort

ort_dir = os.path.dirname(os.path.abspath(ort.__file__))
dll_path = os.path.join(ort_dir, "capi", "onnxruntime.dll")

# When the application calls ort.register_execution_provider_library() with the path to the plugin EP DLL,
# ORT internally uses LoadLibraryExW() to load that DLL. Since the plugin EP depends on onnxruntime.dll,
# the operating system will attempt to locate and load onnxruntime.dll first.
#
# On Windows, LoadLibraryExW() searches the directory containing the plugin EP DLL before searching system directories.
# Because onnxruntime.dll is not located in the plugin EP’s directory, Windows ends up loading the copy from a
# system directory instead — which is not the correct version.
#
# To ensure the plugin EP uses the correct onnxruntime.dll bundled with the ONNX Runtime package,
# we load that DLL explicitly before loading the plugin EP DLL.
ctypes.WinDLL(dll_path)
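# Note: on Python 3.8+, os.add_dll_directory(os.path.join(ort_dir, "capi")) may look like an
# alternative, but whether it takes effect depends on the search flags ONNX Runtime passes to
# LoadLibraryExW when loading the plugin EP DLL, so the explicit preload above is the more
# reliable option here.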

def get_path(filename: str = "TensorRTEp.dll") -> str:
    """
    Returns the absolute filesystem path to a DLL (or any file)
    packaged inside plugin_trt_ep/libs.
    """
    package = __name__ + ".libs"
    with importlib.resources.as_file(importlib.resources.files(package) / filename) as path:
        return str(path)
53 changes: 53 additions & 0 deletions plugin_execution_providers/tensorrt/setup.py
@@ -0,0 +1,53 @@
from setuptools import setup
from setuptools.dist import Distribution
import os
import shutil

ep_dll = "TensorRTEp.dll"
src_folder = os.path.join("build", "Debug")
dst_folder = os.path.join("plugin_trt_ep", "libs")

class BinaryDistribution(Distribution):
    # This ensures the wheel is marked as "non-pure" (contains binary files)
    def has_ext_modules(self):
        return True

def copy_ep_dll(src_folder: str, dst_folder: str, ep_dll: str = "TensorRTEp.dll"):
    """
    Copy EP DLL from src_folder to dst_folder.
    Create dst_folder if it doesn't exist.
    """
    src_dll_path = os.path.join(src_folder, ep_dll)

    # Validate source file
    if not os.path.isfile(src_dll_path):
        raise FileNotFoundError(f"Source DLL not found: {src_dll_path}")

    # Create destination folder if needed
    os.makedirs(dst_folder, exist_ok=True)

    dst_dll_path = os.path.join(dst_folder, ep_dll)

    # Copy file
    shutil.copy2(src_dll_path, dst_dll_path)

    print(f"Copied {ep_dll} to: {dst_dll_path}")

try:
    copy_ep_dll(src_folder, dst_folder, ep_dll)
except Exception as e:
    print(f"Error: {e}")

setup(
    name="plugin_trt_ep",
    version="0.1.0",
    packages=["plugin_trt_ep"],
    include_package_data=True,  # include MANIFEST.in contents
    package_data={
        "plugin_trt_ep": ["libs/*.dll"],  # include DLLs inside the wheel
    },
    distclass=BinaryDistribution,
    description="Example package including DLLs",
    author="ORT",
    python_requires=">=3.8",
)