-
Notifications
You must be signed in to change notification settings - Fork 313
Expand file tree
/
Copy pathApplication.cpp
More file actions
148 lines (124 loc) · 4.45 KB
/
Application.cpp
File metadata and controls
148 lines (124 loc) · 4.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE.md in the repo root for license information.
//
// Application Implementation
//
// Implementation of main application logic for the Windows ML ABI sample.
//
#include "Application.h"
#include "ConsoleUI.h"
#include <roapi.h>
#include <iostream>
#include <stdexcept>
namespace CppAbiEPEnumerationSample
{
void Application::Run()
{
// Initialize Windows Runtime outside try/catch for proper cleanup
HRESULT hr = RoInitialize(RO_INIT_MULTITHREADED);
if (FAILED(hr))
{
std::wcout << L"Failed to initialize Windows Runtime. HRESULT: 0x"
<< std::hex << hr << std::endl;
return;
}
try
{
RunInteractiveSession();
}
catch (const std::exception& e)
{
std::wcout << L"Error: " << std::wstring(e.what(), e.what() + strlen(e.what())) << std::endl;
}
// Uninitialize Windows Runtime after all COM objects are destroyed
RoUninitialize();
}
/// Drives the console menu: prints the header, enumerates execution providers
/// once, then repeatedly lists them and handles the user's choice until the
/// user enters 0. Returns immediately if no providers were found.
void Application::RunInteractiveSession()
{
    ConsoleUI::PrintHeader();

    // The provider list is fetched a single time and reused for every menu pass.
    ExecutionProviderCatalog catalog;
    auto providers = catalog.FindAllProviders();

    if (providers.empty())
    {
        std::wcout << L"No execution providers found." << std::endl;
        return;
    }

    for (;;)
    {
        ConsoleUI::PrintProviders(providers);

        const int providerCount = static_cast<int>(providers.size());
        const int choice = ConsoleUI::GetUserSelection(providerCount);

        // 0 is the menu's exit option.
        if (choice == 0)
        {
            return;
        }

        // Menu entries are 1-based; anything outside [1, providerCount] is rejected.
        const bool withinMenu = (choice >= 1) && (choice <= providerCount);
        if (!withinMenu)
        {
            std::wcout << L"Invalid selection." << std::endl;
            continue;
        }

        ProcessProviderSelection(providers[choice - 1]);
        ConsoleUI::WaitForKeypress();
    }
}
void Application::ProcessProviderSelection(const ExecutionProvider& provider)
{
std::wcout << L"Selected: " << provider.GetName() << std::endl;
auto currentState = provider.GetReadyState();
if (currentState == ABI::Microsoft::Windows::AI::MachineLearning::ExecutionProviderReadyState_Ready)
{
std::wcout << L"Provider is already ready!" << std::endl;
TryRegisterProvider(provider);
}
else
{
std::wcout << L"Provider current state: " << provider.GetReadyStateString(currentState) << std::endl;
if (currentState == ABI::Microsoft::Windows::AI::MachineLearning::ExecutionProviderReadyState_NotPresent ||
currentState == ABI::Microsoft::Windows::AI::MachineLearning::ExecutionProviderReadyState_NotReady)
{
PrepareProvider(provider);
}
else
{
// Try to register for other states
TryRegisterProvider(provider);
}
}
}
void Application::PrepareProvider(const ExecutionProvider& provider)
{
std::wcout << L"Calling EnsureReadyAsync to prepare the provider..." << std::endl;
bool ensureReadySuccess = provider.EnsureReadyAsync();
if (ensureReadySuccess)
{
// Check the state again after EnsureReady
auto newState = provider.GetReadyState();
std::wcout << L"Provider state after EnsureReady: "
<< provider.GetReadyStateString(newState) << std::endl;
if (newState == ABI::Microsoft::Windows::AI::MachineLearning::ExecutionProviderReadyState_Ready)
{
TryRegisterProvider(provider);
}
else
{
std::wcout << L"Provider is not ready even after EnsureReadyAsync." << std::endl;
}
}
else
{
std::wcout << L"EnsureReadyAsync failed." << std::endl;
}
}
/// Attempts to register the provider and prints whether registration succeeded.
void Application::TryRegisterProvider(const ExecutionProvider& provider)
{
    const wchar_t* const outcome = provider.TryRegister()
        ? L"Provider registered successfully!"
        : L"Provider registration failed.";

    std::wcout << outcome << std::endl;
}
} // namespace CppAbiEPEnumerationSample