-
Notifications
You must be signed in to change notification settings - Fork 313
Expand file tree
/
Copy pathExecutionProvider.cpp
More file actions
154 lines (133 loc) · 5.32 KB
/
ExecutionProvider.cpp
File metadata and controls
154 lines (133 loc) · 5.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE.md in the repo root for license information.
//
// Execution Provider Implementation
//
// Implementation of Windows ML execution provider ABI wrapper.
//
#include "ExecutionProvider.h"
#include "AbiUtils.h"
#include "AsyncHandlers.h"
#include <wrl/module.h>
#include <iomanip>
#include <iostream>
#include <memory>
#include <sstream>
#include <type_traits>
#include <utility>
namespace CppAbiEPEnumerationSample
{
ExecutionProvider::ExecutionProvider(Microsoft::WRL::ComPtr<ABI::Microsoft::Windows::AI::MachineLearning::IExecutionProvider> provider)
: m_provider(provider)
{
}
std::wstring ExecutionProvider::GetName() const
{
HStringWrapper name;
HRESULT hr = m_provider->get_Name(name.put());
if (SUCCEEDED(hr))
{
return name.ToString();
}
return L"Unknown";
}
std::wstring ExecutionProvider::GetDescription() const
{
HStringWrapper path;
HRESULT hr = m_provider->get_LibraryPath(path.put());
if (SUCCEEDED(hr))
{
std::wstring libraryPath = path.ToString();
if (!libraryPath.empty())
{
return L"Library: " + libraryPath;
}
}
return L"(please ensure ready)";
}
ABI::Microsoft::Windows::AI::MachineLearning::ExecutionProviderReadyState ExecutionProvider::GetReadyState() const
{
ABI::Microsoft::Windows::AI::MachineLearning::ExecutionProviderReadyState state;
HRESULT hr = m_provider->get_ReadyState(&state);
if (SUCCEEDED(hr))
{
return state;
}
return ABI::Microsoft::Windows::AI::MachineLearning::ExecutionProviderReadyState_NotPresent;
}
std::wstring ExecutionProvider::GetReadyStateString(ABI::Microsoft::Windows::AI::MachineLearning::ExecutionProviderReadyState state) const
{
switch (state)
{
case ABI::Microsoft::Windows::AI::MachineLearning::ExecutionProviderReadyState_NotPresent: return L"Not Present";
case ABI::Microsoft::Windows::AI::MachineLearning::ExecutionProviderReadyState_NotReady: return L"Not Ready";
case ABI::Microsoft::Windows::AI::MachineLearning::ExecutionProviderReadyState_Ready: return L"Ready";
default: return L"Unknown";
}
}
bool ExecutionProvider::TryRegister() const
{
boolean result = false;
HRESULT hr = m_provider->TryRegister(&result);
return SUCCEEDED(hr) && result;
}
bool ExecutionProvider::EnsureReadyAsync() const
{
// Create an event for synchronization
HANDLE completionEvent = CreateEvent(nullptr, TRUE, FALSE, nullptr);
if (completionEvent == nullptr)
{
std::wcout << L"Failed to create synchronization event" << std::endl;
return false;
}
// Get the async operation
Microsoft::WRL::ComPtr<__FIAsyncOperationWithProgress_2_Microsoft__CWindows__CAI__CMachineLearning__CExecutionProviderReadyResult_double> operation;
HRESULT hr = m_provider->EnsureReadyAsync(&operation);
if (FAILED(hr))
{
std::wcout << L"Failed to start EnsureReadyAsync operation. HRESULT: 0x" << std::hex << hr << std::endl;
::CloseHandle(completionEvent);
return false;
}
// Create and attach progress handler
auto progressHandler = Microsoft::WRL::Make<ProgressHandler>([](double progress) {
std::wcout << L"Progress: " << std::fixed << std::setprecision(1) << (progress * 100.0) << L"%" << std::endl;
});
hr = operation->put_Progress(progressHandler.Get());
if (FAILED(hr))
{
std::wcout << L"Warning: Failed to attach progress handler. HRESULT: 0x" << std::hex << hr << std::endl;
}
// Create and attach completion handler
auto completionHandler = Microsoft::WRL::Make<CompletionHandler>(completionEvent);
if (completionHandler == nullptr)
{
std::wcout << L"Failed to create completion handler" << std::endl;
::CloseHandle(completionEvent);
return false;
}
hr = operation->put_Completed(completionHandler.Get());
if (FAILED(hr))
{
std::wcout << L"Failed to attach completion handler. HRESULT: 0x" << std::hex << hr << std::endl;
::CloseHandle(completionEvent);
return false;
}
// Wait for the operation to complete
::WaitForSingleObject(completionEvent, INFINITE);
// Check the result
auto status = completionHandler->GetStatus();
bool success = (status == ABI::Windows::Foundation::AsyncStatus::Completed);
if (!success)
{
if (status == ABI::Windows::Foundation::AsyncStatus::Error)
{
HRESULT errorCode = completionHandler->GetErrorCode();
std::wcout << L"EnsureReadyAsync failed with error: 0x" << std::hex << errorCode << std::endl;
}
else if (status == ABI::Windows::Foundation::AsyncStatus::Canceled)
{
std::wcout << L"EnsureReadyAsync was canceled" << std::endl;
}
}
// Cleanup
::CloseHandle(completionEvent);
return success;
}
} // namespace CppAbiEPEnumerationSample