-
Notifications
You must be signed in to change notification settings - Fork 301
/
main.cpp
247 lines (207 loc) · 9.36 KB
/
main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "pch.h"
#include <dxcore_interface.h>
#include <dxcore.h>
#include "onnxruntime_cxx_api.h"
#include "dml_provider_factory.h"
#include "TensorHelper.h"
using Microsoft::WRL::ComPtr;
// Reads a DXCore adapter property as raw bytes into `outputValue`.
//
// adapter      - the DXCore adapter to query.
// prop         - the property to read (e.g. DXCoreAdapterProperty::DriverDescription).
// outputValue  - receives the property bytes, with any trailing NUL characters removed.
//
// Returns false, leaving `outputValue` untouched, when the adapter reports the property as
// unsupported; returns true once `outputValue` has been filled.
// Throws (THROW_IF_FAILED) if querying the property size or value fails.
bool TryGetProperty(IDXCoreAdapter* adapter, DXCoreAdapterProperty prop, std::string& outputValue)
{
    if (!adapter->IsPropertySupported(prop))
    {
        return false;
    }

    size_t byteCount;
    THROW_IF_FAILED(adapter->GetPropertySize(prop, &byteCount));

    outputValue.resize(byteCount);
    THROW_IF_FAILED(adapter->GetProperty(prop, byteCount, outputValue.data()));

    // The buffer may end in one or more NUL characters; cut the string right after the last
    // non-NUL character (or clear it entirely when every byte is NUL).
    const auto lastVisible = outputValue.find_last_not_of('\0');
    outputValue.erase(lastVisible == std::string::npos ? 0 : lastVisible + 1);
    return true;
}
// Finds the first adapter in `adapterList` (in list order) that does NOT report support for
// DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS.
//
// adapterList - the DXCore adapter list to scan.
// outAdapter  - on a match, receives an owning reference to that adapter (the caller is
//               responsible for releasing it, e.g. by passing a ComPtr's GetAddressOf());
//               when no adapter matches, *outAdapter is set to nullptr.
//
// Throws (THROW_IF_FAILED) if an adapter cannot be retrieved from the list; in that case
// *outAdapter is left unmodified.
void GetNonGraphicsAdapter(IDXCoreAdapterList* adapterList, IDXCoreAdapter** outAdapter)
{
    const uint32_t count = adapterList->GetAdapterCount();
    uint32_t index = 0;
    while (index < count)
    {
        ComPtr<IDXCoreAdapter> candidate;
        THROW_IF_FAILED(adapterList->GetAdapter(index, IID_PPV_ARGS(&candidate)));

        const bool supportsGraphics = candidate->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS);
        if (!supportsGraphics)
        {
            // Hand our reference straight to the caller instead of AddRef + Release.
            *outAdapter = candidate.Detach();
            return;
        }
        ++index;
    }

    // Every adapter in the list supports graphics (or the list is empty).
    *outAdapter = nullptr;
}
// Selects an NPU-style (non-graphics) adapter and creates the D3D12 device, compute command queue
// and DirectML device needed by the ONNX Runtime DirectML execution provider.
//
// Adapter selection uses DXCore:
//   1. adapters supporting DXCORE_ADAPTER_ATTRIBUTE_D3D12_GENERIC_ML without D3D12_GRAPHICS,
//      created with minimum feature level D3D_FEATURE_LEVEL_1_0_GENERIC;
//   2. otherwise adapters supporting DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE without
//      D3D12_GRAPHICS, created with minimum feature level D3D_FEATURE_LEVEL_1_0_CORE.
//
// d3dDeviceOut / commandQueueOut / dmlDeviceOut each receive an owning reference (AddRef'd via
// ComPtr::CopyTo), or nullptr when that object was not created: a DLL or export was missing,
// no matching adapter exists, or an earlier stage was skipped. Callers should check
// *dmlDeviceOut before use. HRESULT failures of the adapter-list, device, and queue creation
// calls throw via THROW_IF_FAILED.
void InitializeDirectML(ID3D12Device1** d3dDeviceOut, ID3D12CommandQueue** commandQueueOut, IDMLDevice** dmlDeviceOut)
{
    // Signatures of the entry points resolved at runtime with GetProcAddress. The trailing object
    // out-parameter is void** in the SDK declarations of all three functions. Spelling it as void*
    // still compiles, because IID_PPV_ARGS's void** decays silently, but the cast would no longer
    // match the real export.
    using DXCoreCreateAdapterFactoryFn = HRESULT(WINAPI*)(REFIID, void**);
    using D3D12CreateDeviceFn = HRESULT(WINAPI*)(IUnknown*, D3D_FEATURE_LEVEL, REFIID, void**);
    using DMLCreateDevice1Fn = HRESULT(WINAPI*)(ID3D12Device*, DML_CREATE_DEVICE_FLAGS, DML_FEATURE_LEVEL, REFIID, void**);

    // Create Adapter Factory
    ComPtr<IDXCoreAdapterFactory> factory;
    // Note: this module is intentionally never freed in this sample. Outside of sample usage it should
    // be freed (e.g. with FreeLibrary or wil::unique_hmodule) once nothing created from it is in use.
    HMODULE dxCoreModule = LoadLibraryW(L"DXCore.dll");
    if (dxCoreModule)
    {
        auto dxcoreCreateAdapterFactory = reinterpret_cast<DXCoreCreateAdapterFactoryFn>(
            GetProcAddress(dxCoreModule, "DXCoreCreateAdapterFactory"));
        if (dxcoreCreateAdapterFactory)
        {
            // Not fatal: with no factory the function simply reports "no device" (null outputs).
            if (FAILED(dxcoreCreateAdapterFactory(IID_PPV_ARGS(&factory))))
            {
                factory.Reset();
            }
        }
    }

    // Create the DXCore Adapter. For the purposes of selecting an NPU we look for
    // (!GRAPHICS && (GENERIC_ML || CORE_COMPUTE)).
    ComPtr<IDXCoreAdapter> adapter;
    ComPtr<IDXCoreAdapterList> adapterList;
    D3D_FEATURE_LEVEL featureLevel = D3D_FEATURE_LEVEL_1_0_GENERIC;
    if (factory)
    {
        THROW_IF_FAILED(factory->CreateAdapterList(1, &DXCORE_ADAPTER_ATTRIBUTE_D3D12_GENERIC_ML, IID_PPV_ARGS(&adapterList)));
        if (adapterList->GetAdapterCount() > 0)
        {
            GetNonGraphicsAdapter(adapterList.Get(), adapter.GetAddressOf());
        }
        if (!adapter)
        {
            // Fall back to compute-only adapters, which are created at the CORE feature level.
            featureLevel = D3D_FEATURE_LEVEL_1_0_CORE;
            THROW_IF_FAILED(factory->CreateAdapterList(1, &DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE, IID_PPV_ARGS(&adapterList)));
            GetNonGraphicsAdapter(adapterList.Get(), adapter.GetAddressOf());
        }
    }

    if (adapter)
    {
        std::string adapterName;
        if (TryGetProperty(adapter.Get(), DXCoreAdapterProperty::DriverDescription, adapterName))
        {
            printf("Successfully found adapter %s\n", adapterName.c_str());
        }
        else
        {
            printf("Failed to get adapter description.\n");
        }
    }

    // Create the D3D12 Device
    ComPtr<ID3D12Device1> d3dDevice;
    if (adapter)
    {
        // Note: this module is intentionally never freed in this sample. Outside of sample usage it should
        // be freed (e.g. with FreeLibrary or wil::unique_hmodule) once nothing created from it is in use.
        HMODULE d3d12Module = LoadLibraryW(L"d3d12.dll");
        if (d3d12Module)
        {
            auto d3d12CreateDevice = reinterpret_cast<D3D12CreateDeviceFn>(
                GetProcAddress(d3d12Module, "D3D12CreateDevice"));
            if (d3d12CreateDevice)
            {
                // Request the minimum feature level that matches how the adapter was found:
                // 1_0_GENERIC for a GENERIC_ML adapter, 1_0_CORE for the CORE_COMPUTE fallback.
                THROW_IF_FAILED(d3d12CreateDevice(adapter.Get(), featureLevel, IID_PPV_ARGS(&d3dDevice)));
            }
        }
    }

    // Create the DML Device and D3D12 Command Queue
    ComPtr<IDMLDevice> dmlDevice;
    ComPtr<ID3D12CommandQueue> commandQueue;
    if (d3dDevice)
    {
        D3D12_COMMAND_QUEUE_DESC queueDesc = {};
        queueDesc.Type = D3D12_COMMAND_LIST_TYPE_COMPUTE;
        THROW_IF_FAILED(d3dDevice->CreateCommandQueue(
            &queueDesc,
            IID_PPV_ARGS(commandQueue.ReleaseAndGetAddressOf())));

        // Note: this module is intentionally never freed in this sample. Outside of sample usage it should
        // be freed (e.g. with FreeLibrary or wil::unique_hmodule) once nothing created from it is in use.
        HMODULE dmlModule = LoadLibraryW(L"DirectML.dll");
        if (dmlModule)
        {
            auto dmlCreateDevice = reinterpret_cast<DMLCreateDevice1Fn>(
                GetProcAddress(dmlModule, "DMLCreateDevice1"));
            if (dmlCreateDevice)
            {
                THROW_IF_FAILED(dmlCreateDevice(d3dDevice.Get(), DML_CREATE_DEVICE_FLAG_NONE, DML_FEATURE_LEVEL_5_0, IID_PPV_ARGS(dmlDevice.ReleaseAndGetAddressOf())));
            }
        }
    }

    // CopyTo AddRefs; a null ComPtr writes nullptr to the corresponding out-parameter.
    d3dDevice.CopyTo(d3dDeviceOut);
    commandQueue.CopyTo(commandQueueOut);
    dmlDevice.CopyTo(dmlDeviceOut);
}
void main()
{
ComPtr<ID3D12Device1> d3dDevice;
ComPtr<IDMLDevice> dmlDevice;
ComPtr<ID3D12CommandQueue> commandQueue;
InitializeDirectML(d3dDevice.GetAddressOf(), commandQueue.GetAddressOf(), dmlDevice.GetAddressOf());
// Add the DML execution provider to ORT using the DML Device and D3D12 Command Queue created above.
if (!dmlDevice)
{
printf("No NPU device found\n");
return;
}
const OrtApi& ortApi = Ort::GetApi();
static Ort::Env s_OrtEnv{ nullptr };
s_OrtEnv = Ort::Env(Ort::ThreadingOptions{});
s_OrtEnv.DisableTelemetryEvents();
auto sessionOptions = Ort::SessionOptions{};
sessionOptions.DisableMemPattern();
sessionOptions.DisablePerSessionThreads();
sessionOptions.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
const OrtDmlApi* ortDmlApi = nullptr;
Ort::ThrowOnError(ortApi.GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast<const void**>(&ortDmlApi)));
Ort::ThrowOnError(ortDmlApi->SessionOptionsAppendExecutionProvider_DML1(sessionOptions, dmlDevice.Get(), commandQueue.Get()));
// Create the session
auto session = Ort::Session(s_OrtEnv, L"mobilenetv2-7-fp16.onnx", sessionOptions);
const char* inputName = "input";
const char* outputName = "output";
// Create input tensor
Ort::TypeInfo type_info = session.GetInputTypeInfo(0);
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
auto input = CreateDmlValue(tensor_info, commandQueue.Get());
auto inputTensor = std::move(input.first);
const auto memoryInfo = inputTensor.GetTensorMemoryInfo();
Ort::Allocator allocator(session, memoryInfo);
// Get the inputResource and populate!
ComPtr<ID3D12Resource> inputResource;
Ort::ThrowOnError(ortDmlApi->GetD3D12ResourceFromAllocation(allocator, inputTensor.GetTensorMutableData<void*>(), &inputResource));
// Create output tensor
type_info = session.GetOutputTypeInfo(0);
tensor_info = type_info.GetTensorTypeAndShapeInfo();
auto output = CreateDmlValue(tensor_info, commandQueue.Get());
auto outputTensor = std::move(output.first);
// Run warmup
session.Run(Ort::RunOptions{ nullptr }, &inputName, &inputTensor, 1, &outputName, &outputTensor, 1);
// Queue fence, and wait for completion
ComPtr<ID3D12Fence> fence;
THROW_IF_FAILED(d3dDevice->CreateFence(0, D3D12_FENCE_FLAG_NONE, IID_PPV_ARGS(fence.GetAddressOf())));
THROW_IF_FAILED(commandQueue->Signal(fence.Get(), 1));
wil::unique_handle fenceEvent(CreateEvent(nullptr, FALSE, FALSE, nullptr));
THROW_IF_FAILED(fence->SetEventOnCompletion(1, fenceEvent.get()));
THROW_HR_IF(E_FAIL, WaitForSingleObject(fenceEvent.get(), INFINITE) != WAIT_OBJECT_0);
// Record start
auto start = std::chrono::high_resolution_clock::now();
// Run performance test
constexpr int fenceValueStart = 2;
constexpr int numIterations = 100;
for (int i = fenceValueStart; i < (numIterations + fenceValueStart); i++)
{
session.Run(Ort::RunOptions{ nullptr }, &inputName, &inputTensor, 1, &outputName, &outputTensor, 1);
{
// Synchronize with CPU before queuing more inference runs
THROW_IF_FAILED(commandQueue->Signal(fence.Get(), i));
THROW_HR_IF(E_FAIL, ResetEvent(fenceEvent.get()) == 0);
THROW_IF_FAILED(fence->SetEventOnCompletion(i, fenceEvent.get()));
THROW_HR_IF(E_FAIL, WaitForSingleObject(fenceEvent.get(), INFINITE) != WAIT_OBJECT_0);
}
}
// Record end and calculate duration
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double, std::micro> duration = end - start;
printf("Evaluate Took: %fus\n", float(duration.count())/100);
// Read results
ComPtr<ID3D12Resource> outputResource;
Ort::ThrowOnError(ortDmlApi->GetD3D12ResourceFromAllocation(allocator, outputTensor.GetTensorMutableData<void*>(), &outputResource));
}