Implement 1-op models for EPs #1895
base: main
Conversation
Pull request overview
This PR implements efficient Cast operator support for the WebGPU backend by dynamically generating minimal ONNX models and caching inference sessions. The implementation enables type conversion operations to be performed on WebGPU devices without requiring external ONNX library dependencies.
- Adds manual protobuf-based ONNX model generation for Cast operations (a sketch of the encoding idea follows this list)
- Implements thread-safe session caching to avoid redundant model creation
- Provides WebGPU-specific Cast method using ONNX Runtime's IOBinding
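For readers unfamiliar with the approach, here is a minimal sketch of how a protobuf submessage can be emitted by hand, without linking the ONNX library. The helper names are hypothetical and not taken from cast_model_builder.cpp, but the wire format (tag byte = (field_number << 3) | wire_type, varints, length-delimited payloads) and the AttributeProto field numbers match onnx.proto:

```cpp
#include <cstdint>
#include <vector>

// Protobuf varints store 7 bits per byte, least-significant group first;
// the high bit of each byte marks that more bytes follow.
static void AppendVarint(std::vector<uint8_t>& out, uint64_t value) {
  while (value >= 0x80) {
    out.push_back(static_cast<uint8_t>(value) | 0x80);
    value >>= 7;
  }
  out.push_back(static_cast<uint8_t>(value));
}

// Every field starts with a tag: (field_number << 3) | wire_type.
static void AppendTag(std::vector<uint8_t>& out, uint32_t field, uint32_t wire_type) {
  AppendVarint(out, (static_cast<uint64_t>(field) << 3) | wire_type);
}

// Example: encode AttributeProto { name: "to", i: <to>, type: INT } for Cast.
// Field numbers from onnx.proto: name = 1 (string), i = 3 (int64), type = 20.
std::vector<uint8_t> EncodeCastToAttribute(int64_t to) {
  std::vector<uint8_t> attr;
  AppendTag(attr, /*name=*/1, /*wire_type=*/2);  // wire type 2 = length-delimited
  const char* name = "to";
  AppendVarint(attr, 2);                         // string length
  attr.insert(attr.end(), name, name + 2);
  AppendTag(attr, /*i=*/3, /*wire_type=*/0);     // wire type 0 = varint
  AppendVarint(attr, static_cast<uint64_t>(to));
  AppendTag(attr, /*type=*/20, /*wire_type=*/0);
  AppendVarint(attr, 2);                         // AttributeProto::INT = 2
  return attr;
}
```

Nesting such submessages (NodeProto inside GraphProto inside ModelProto) with the same length-delimited scheme is enough to produce a complete 1-op model as raw bytes.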
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| src/webgpu/cast_model_builder.h | Declares the function to create ONNX Cast model bytes from input/output types |
| src/webgpu/cast_model_builder.cpp | Implements manual protobuf encoding to generate minimal ONNX Cast operator models without ONNX library dependency |
| src/webgpu/interface.cpp | Adds CastSessionCache for thread-safe session reuse and implements Cast method with element size helper for tensor creation |
@kunal-vaishnavi @fs-eire @guschmue In the latest commit, I moved the cached cast_sessions_ from …
Pull request overview
Copilot reviewed 8 out of 8 changed files in this pull request and generated 8 comments.
one_op_model_executor.cpp:

```cpp
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "one_op_model_executor.h"
#include "one_op_model_builder.h"
#include "../generators.h"
#include <cstring>     // std::memcpy
#include <functional>
#include <iostream>    // std::cerr
#include <mutex>
#include <unordered_map>

namespace Generators {

// Global cache for 1-op model sessions
// Stored in OrtGlobals to ensure proper cleanup before OrtEnv destruction
struct OneOpSessionCache {
  std::unordered_map<uint64_t, std::unique_ptr<OrtSession>> sessions_;
  std::mutex mutex_;
};

// Get the global session cache (stored in OrtGlobals)
static OneOpSessionCache& GetOneOpSessionCache() {
  static OneOpSessionCache cache;
  return cache;
}

// Generate a cache key from the model configuration and EP name
uint64_t OneOpModelExecutor::GenerateCacheKey(const OneOpModelConfig& config, const std::string& ep_name) {
  // Simple hash combining op_type, input/output types, and EP name.
  // The mixing step is boost-style hash_combine; 0x9e3779b9 is the 32-bit
  // golden-ratio constant. For more complex operators with attributes,
  // we'd need a more sophisticated hash.
  std::hash<std::string> hasher;
  uint64_t key = hasher(config.op_type);

  // Hash EP name
  key ^= hasher(ep_name) + 0x9e3779b9 + (key << 6) + (key >> 2);

  // Hash input types
  for (const auto& input : config.inputs) {
    key ^= static_cast<uint64_t>(input.elem_type) + 0x9e3779b9 + (key << 6) + (key >> 2);
  }

  // Hash output types
  for (const auto& output : config.outputs) {
    key ^= static_cast<uint64_t>(output.elem_type) + 0x9e3779b9 + (key << 6) + (key >> 2);
  }

  // Hash attributes
  for (const auto& attr : config.attributes) {
    key ^= hasher(attr.name) + 0x9e3779b9 + (key << 6) + (key >> 2);

    switch (attr.type) {
      case AttributeType::INT:
        key ^= static_cast<uint64_t>(attr.int_value) + 0x9e3779b9 + (key << 6) + (key >> 2);
        break;
      case AttributeType::FLOAT: {
        uint32_t float_bits;
        std::memcpy(&float_bits, &attr.float_value, sizeof(float));
        key ^= static_cast<uint64_t>(float_bits) + 0x9e3779b9 + (key << 6) + (key >> 2);
        break;
      }
      case AttributeType::STRING:
        key ^= hasher(attr.string_value) + 0x9e3779b9 + (key << 6) + (key >> 2);
        break;
      case AttributeType::INTS:
        for (auto val : attr.ints_value) {
          key ^= static_cast<uint64_t>(val) + 0x9e3779b9 + (key << 6) + (key >> 2);
        }
        break;
      case AttributeType::FLOATS:
        for (auto val : attr.floats_value) {
          uint32_t float_bits;
          std::memcpy(&float_bits, &val, sizeof(float));
          key ^= static_cast<uint64_t>(float_bits) + 0x9e3779b9 + (key << 6) + (key >> 2);
        }
        break;
      case AttributeType::STRINGS:
        for (const auto& val : attr.strings_value) {
          key ^= hasher(val) + 0x9e3779b9 + (key << 6) + (key >> 2);
        }
        break;
    }
  }

  return key;
}

// Create a new session for the given model and EP
std::unique_ptr<OrtSession> OneOpModelExecutor::CreateSession(
    const std::vector<uint8_t>& model_bytes,
    const std::string& ep_name,
    const std::vector<const char*>& session_config_keys,
    const std::vector<const char*>& session_config_values) {
  auto session_options = OrtSessionOptions::Create();
  session_options->SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);

  // Apply session configuration entries
  for (size_t i = 0; i < session_config_keys.size(); i++) {
    session_options->AddConfigEntry(session_config_keys[i], session_config_values[i]);
  }

  // Append execution provider
  if (!ep_name.empty()) {
    session_options->AppendExecutionProvider(ep_name.c_str(), nullptr, nullptr, 0);
  }

  return OrtSession::Create(GetOrtEnv(), model_bytes.data(), model_bytes.size(), session_options.get());
}

// Get or create a cached session
OrtSession* OneOpModelExecutor::GetOrCreateSession(
    const OneOpModelConfig& config,
    const std::string& ep_name,
    const std::vector<const char*>& session_config_keys,
    const std::vector<const char*>& session_config_values) {
  auto& cache = GetOneOpSessionCache();
  uint64_t key = GenerateCacheKey(config, ep_name);

  std::lock_guard<std::mutex> lock(cache.mutex_);

  auto it = cache.sessions_.find(key);
  if (it != cache.sessions_.end()) {
    return it->second.get();
  }

  // Create new session
  auto model_bytes = OneOpModelBuilder::Build(config);
  auto session = CreateSession(model_bytes, ep_name, session_config_keys, session_config_values);

  OrtSession* session_ptr = session.get();
  cache.sessions_[key] = std::move(session);

  return session_ptr;
}

// Execute a 1-op model
bool OneOpModelExecutor::Execute(
    const OneOpModelConfig& model_config,
    const OneOpExecutionParams& exec_params) {
  try {
    // Get or create session
    OrtSession* session = GetOrCreateSession(
        model_config,
        exec_params.execution_provider_name,
        exec_params.session_config_keys,
        exec_params.session_config_values);

    // Create IOBinding for efficient execution
    auto io_binding = OrtIoBinding::Create(*session);

    // Bind inputs
    for (size_t i = 0; i < exec_params.inputs.size(); i++) {
      const auto& input_spec = exec_params.inputs[i];
      const auto& input_config = model_config.inputs[i];

      auto input_tensor = OrtValue::CreateTensor(
          *exec_params.memory_info,
          input_spec.data,
          input_spec.size_in_bytes,
          input_spec.shape,
          input_spec.elem_type);

      io_binding->BindInput(input_config.name.c_str(), *input_tensor);
    }

    // Bind outputs
    for (size_t i = 0; i < exec_params.outputs.size(); i++) {
      const auto& output_spec = exec_params.outputs[i];
      const auto& output_config = model_config.outputs[i];

      auto output_tensor = OrtValue::CreateTensor(
          *exec_params.memory_info,
          output_spec.data,
          output_spec.size_in_bytes,
          output_spec.shape,
          output_spec.elem_type);

      io_binding->BindOutput(output_config.name.c_str(), *output_tensor);
    }

    // Run inference
    session->Run(nullptr, *io_binding);

    return true;
  } catch (const std::exception& e) {
    // Log error or handle as needed
    std::cerr << "OneOpModelExecutor::Execute failed: " << e.what() << std::endl;
    return false;
  }
}

// Clear all cached sessions
void OneOpModelExecutor::ClearCache() {
  auto& cache = GetOneOpSessionCache();
  std::lock_guard<std::mutex> lock(cache.mutex_);
  cache.sessions_.clear();
}

// Helper function for Cast operation
bool ExecuteCastOp(
    void* input_data,
    void* output_data,
    ONNXTensorElementDataType input_type,
    ONNXTensorElementDataType output_type,
    size_t element_count,
    const std::string& execution_provider_name,
    const OrtMemoryInfo* memory_info,
    const std::vector<const char*>& session_config_keys,
    const std::vector<const char*>& session_config_values) {
  // Build Cast model configuration with dynamic shape (-1) to support any element count
  OneOpModelConfig config("Cast");
  config.inputs.push_back(TensorConfig("input", input_type, {-1}));
  config.outputs.push_back(TensorConfig("output", output_type, {-1}));
  config.attributes.push_back(AttributeValue::Int("to", static_cast<int64_t>(output_type)));

  // Build execution parameters
  OneOpExecutionParams params(execution_provider_name, memory_info);
  params.inputs.push_back(OneOpTensorSpec(
      input_data,
      input_type,
      {static_cast<int64_t>(element_count)},
      element_count * Ort::SizeOf(input_type)));
  params.outputs.push_back(OneOpTensorSpec(
      output_data,
      output_type,
      {static_cast<int64_t>(element_count)},
      element_count * Ort::SizeOf(output_type)));

  // Apply session config entries if provided
  params.session_config_keys = session_config_keys;
  params.session_config_values = session_config_values;

  return OneOpModelExecutor::Execute(config, params);
}

}  // namespace Generators
```
Copilot AI · Jan 7, 2026
The new OneOpModelExecutor and OneOpModelBuilder classes lack test coverage. Given that this is core infrastructure for WebGPU Cast operations (and potentially other operations in the future), unit tests should be added to verify correct ONNX model generation, cache behavior, and session execution across different data types and operators.
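A minimal sketch of the kind of coverage this comment asks for, assuming GoogleTest and that the OneOpModelExecutor members shown in the diff above are accessible to a test target; the repo's actual test scaffolding may differ:

```cpp
#include <gtest/gtest.h>
#include <vector>
#include "one_op_model_executor.h"

using namespace Generators;

TEST(OneOpModelExecutorTest, CacheKeyIsDeterministicAndEpSensitive) {
  OneOpModelConfig config("Cast");
  config.inputs.push_back(TensorConfig("input", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, {-1}));
  config.outputs.push_back(TensorConfig("output", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, {-1}));
  config.attributes.push_back(AttributeValue::Int("to", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16));

  // Same config + EP must produce the same key; a different EP should not.
  EXPECT_EQ(OneOpModelExecutor::GenerateCacheKey(config, "WebGPU"),
            OneOpModelExecutor::GenerateCacheKey(config, "WebGPU"));
  EXPECT_NE(OneOpModelExecutor::GenerateCacheKey(config, "WebGPU"),
            OneOpModelExecutor::GenerateCacheKey(config, "CUDA"));
}

TEST(OneOpModelExecutorTest, SameConfigReusesCachedSession) {
  OneOpModelConfig config("Cast");
  config.inputs.push_back(TensorConfig("input", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, {-1}));
  config.outputs.push_back(TensorConfig("output", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, {-1}));
  config.attributes.push_back(AttributeValue::Int("to", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32));

  std::vector<const char*> no_keys, no_values;  // empty EP name -> default CPU provider
  OrtSession* first = OneOpModelExecutor::GetOrCreateSession(config, "", no_keys, no_values);
  OrtSession* second = OneOpModelExecutor::GetOrCreateSession(config, "", no_keys, no_values);
  EXPECT_EQ(first, second);  // second lookup should hit the cache, not rebuild the model
}
```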
```cpp
bool Cast(void* input, void* output, ONNXTensorElementDataType input_type, ONNXTensorElementDataType output_type, size_t element_count) override {
  if (!ort_allocator_) {
    throw std::runtime_error("WebGPU allocator not initialized");
  }

  // Get WebGPU allocator's memory info
  const OrtMemoryInfo* webgpu_mem_info = nullptr;
  Ort::ThrowOnError(Ort::api->AllocatorGetInfo(ort_allocator_, &webgpu_mem_info));

  // WebGPU-specific session configuration
  static const char* webgpu_config_key = "ep.webgpuexecutionprovider.registerInt64Ops";
  static const char* webgpu_config_value = "1";
  std::vector<const char*> session_config_keys = {webgpu_config_key};
  std::vector<const char*> session_config_values = {webgpu_config_value};

  // Use the generalized ExecuteCastOp helper with WebGPU session config
  return ExecuteCastOp(
      input,
      output,
      input_type,
      output_type,
      element_count,
      "WebGPU",
      webgpu_mem_info,
      session_config_keys,
      session_config_values);
}
```
Copilot AI · Jan 7, 2026
The new Cast method implementation for WebGPU lacks test coverage. Consider adding tests to verify that the Cast operation works correctly for various type conversions (e.g., float to float16, int32 to int64) on the WebGPU execution provider.
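A hedged sketch of such a type-conversion test, exercising ExecuteCastOp on the default CPU provider (empty EP name) with host buffers; the WebGPU path would instead need device buffers and the WebGPU allocator's OrtMemoryInfo, and OrtMemoryInfo::CreateCpu is assumed from this repo's ORT wrapper:

```cpp
#include <gtest/gtest.h>
#include <vector>
#include "one_op_model_executor.h"

TEST(CastOpTest, FloatToInt32TruncatesTowardZero) {
  std::vector<float> input = {0.0f, 1.5f, -2.9f, 42.0f};
  std::vector<int32_t> output(input.size(), -1);

  auto cpu_info = OrtMemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  ASSERT_TRUE(Generators::ExecuteCastOp(
      input.data(), output.data(),
      ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
      ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
      input.size(), /*execution_provider_name=*/"", cpu_info.get(), {}, {}));

  EXPECT_EQ(output[1], 1);   // ONNX Cast discards the fractional part: 1.5 -> 1
  EXPECT_EQ(output[2], -2);  // truncation toward zero: -2.9 -> -2
  EXPECT_EQ(output[3], 42);
}
```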
```cpp
// Default is 17, which is widely supported and has been validated with this infrastructure.
// Can be overridden if a specific opset is required, but ensure the ONNX Runtime build
// supports it and the operator exists in that opset version.
int opset_version{17};
```
Can we use opset 21 to match the opset used in the model builder?
This pull request introduces a new utility for building, executing, and caching single-operator (1-op) ONNX models, which can be leveraged by different execution providers (EPs) for efficient operator execution. The changes add a complete infrastructure for dynamic 1-op model creation, session management, and execution, along with proper resource cleanup. The most important changes are grouped below.
1-op Model Infrastructure
Added one_op_model_builder.h/cpp and one_op_model_executor.h/cpp, implementing utilities for constructing ONNX protobuf models for a single operator and managing model configuration, encoding, and execution, including session caching and helpers for common ops like Cast. [1] [2] [3] [4]
Integration and Resource Management
Included one_op_model_executor.h in generators.cpp and added an explicit destructor to OrtGlobals to ensure cached sessions are cleared before ONNX environment destruction. [1] [2] [3]
With this change, phi4 with graph capture improves from 116.2 tps to 130.4 tps on an NV 5080.
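A minimal sketch of the cleanup ordering described above, with member names assumed rather than copied from the diff: cached OrtSessions must not outlive the OrtEnv, so OrtGlobals gains an explicit destructor that clears the 1-op session cache before its environment member is destroyed.

```cpp
// Hypothetical shape of the OrtGlobals change; env_ is assumed.
struct OrtGlobals {
  ~OrtGlobals() {
    Generators::OneOpModelExecutor::ClearCache();  // drop cached sessions first
  }
  std::unique_ptr<OrtEnv> env_;  // destroyed after the cache is cleared
};
```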