-
Notifications
You must be signed in to change notification settings - Fork 313
Expand file tree
/
Copy pathExecutionProviderCatalog.cpp
More file actions
77 lines (63 loc) · 2.53 KB
/
ExecutionProviderCatalog.cpp
File metadata and controls
77 lines (63 loc) · 2.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE.md in the repo root for license information.
//
// Execution Provider Catalog Implementation
//
// Implementation of Windows ML execution provider catalog ABI wrapper.
//
#include "ExecutionProviderCatalog.h"
#include <windows.h>
#include <roapi.h>
#include <combaseapi.h>
#include <wrl/wrappers/corewrappers.h>
#include <stdexcept>
#include <sstream>
#include <iomanip>
namespace CppAbiEPEnumerationSample
{
ExecutionProviderCatalog::ExecutionProviderCatalog()
{
// Get the catalog factory
Microsoft::WRL::ComPtr<ABI::Microsoft::Windows::AI::MachineLearning::IExecutionProviderCatalogStatics> factory;
HSTRING className;
HRESULT hr = WindowsCreateString(RuntimeClass_Microsoft_Windows_AI_MachineLearning_ExecutionProviderCatalog,
static_cast<UINT32>(wcslen(RuntimeClass_Microsoft_Windows_AI_MachineLearning_ExecutionProviderCatalog)),
&className);
if (SUCCEEDED(hr))
{
hr = RoGetActivationFactory(className, IID_PPV_ARGS(&factory));
WindowsDeleteString(className);
}
if (SUCCEEDED(hr))
{
hr = factory->GetDefault(&m_catalog);
}
if (FAILED(hr))
{
std::ostringstream oss;
oss << "Failed to create ExecutionProviderCatalog (HRESULT: 0x"
<< std::hex << std::uppercase << static_cast<unsigned int>(hr) << ")";
throw std::runtime_error(oss.str());
}
}
std::vector<ExecutionProvider> ExecutionProviderCatalog::FindAllProviders() const
{
std::vector<ExecutionProvider> providers;
// The method returns a C-style array, not a vector
UINT32 resultLength = 0;
ABI::Microsoft::Windows::AI::MachineLearning::IExecutionProvider** result = nullptr;
HRESULT hr = m_catalog->FindAllProviders(&resultLength, &result);
if (SUCCEEDED(hr) && result)
{
for (UINT32 i = 0; i < resultLength; i++)
{
Microsoft::WRL::ComPtr<ABI::Microsoft::Windows::AI::MachineLearning::IExecutionProvider> provider;
provider.Attach(result[i]); // Take ownership without AddRef (already ref counted from array)
providers.emplace_back(ExecutionProvider(provider));
}
// The array itself needs to be freed
CoTaskMemFree(result);
}
return providers;
}
} // namespace CppAbiEPEnumerationSample