// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/graph/basic_types.h"
#include "core/framework/allocator.h"
#include "core/framework/data_types.h"
#include "core/framework/framework_common.h"
#include "core/framework/iexecutor.h"
#include "core/framework/session_state.h"
#include "core/framework/session_options.h"
namespace onnxruntime {
class ExecutionProviders;
struct FeedsFetchesInfo;
class FeedsFetchesManager;
struct MLValueCopyInfo;
class Graph;
class KernelDef;
class KernelRegistryManager;
class IExecutionProvider;
class Node;
class Tensor;
struct KernelCreateInfo;
#ifdef ENABLE_TRAINING
struct PartialGraphExecutionState;
typedef InlinedHashMap<std::string, OrtValue> OrtValueCache;
typedef std::shared_ptr<OrtValueCache> OrtValueCachePtr;
#endif
namespace logging {
class Logger;
}
namespace utils {
void* DefaultAlloc(size_t size);
void DefaultFree(void* p);
/// <summary>
/// Do placement new for std::string on a pre-allocated buffer,
/// `elements` times.
/// </summary>
/// <param name="p_data">Pointer to the pre-allocated buffer.</param>
/// <param name="elements">Number of std::string objects to construct.</param>
void ConstructStrings(void* p_data, int64_t elements);
/// <summary>
/// Destroy std::string objects in the contiguous chunk of memory
/// by explicitly invoking ~string().
/// </summary>
/// <param name="p_data">Pointer to the buffer holding the strings.</param>
/// <param name="elements">Number of std::string objects to destroy.</param>
void DestroyStrings(void* p_data, int64_t elements);
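// Illustrative usage (a sketch, not part of this header's API): the expected
// pairing of ConstructStrings/DestroyStrings with the default allocator for a
// buffer of `n` std::string elements.
//
//   int64_t n = 4;
//   void* buf = DefaultAlloc(static_cast<size_t>(n) * sizeof(std::string));
//   ConstructStrings(buf, n);  // placement-new `n` empty strings into buf
//   auto* strings = static_cast<std::string*>(buf);
//   strings[0] = "example";
//   DestroyStrings(buf, n);  // explicitly run ~string() on each element
//   DefaultFree(buf);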
const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info);
// EP used for internal testing. We define it here as it's used in ProviderIsCpuBased, but we don't want
// it to be in the public header include/onnxruntime/core/graph/constants.h as it's purely internal.
constexpr const char* kInternalTestingExecutionProvider = "InternalTestingExecutionProvider";
// Returns true if the execution provider is CPU-based (meaning no copies to device are required).
bool ProviderIsCpuBased(const std::string& provider_type);
common::Status CopyOneInputAcrossDevices(const SessionState& session_state, const std::string& input_name,
const OrtValue& orig_mlvalue, OrtValue& new_mlvalue);
// Searches the allocation plan from the session_state to find the OrtDevice for the value 'name'.
const OrtDevice& FindDeviceForValue(const SessionState& session_state, std::string_view name);
// Initialize the feed and fetch copy info using session_state.
// Determines the device on which each fed graph input will be consumed,
// and the device on which each fetched graph output will be created.
common::Status InitializeFeedFetchCopyInfo(const SessionState& session_state,
FeedsFetchesManager& feeds_fetches_manager);
// Finalize the feed and fetch copy info using session_state and the device and location information from the feeds
// and fetches that will be used in graph execution.
void FinalizeFeedFetchCopyInfo(FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtDevice> feed_locations,
gsl::span<const OrtDevice* const> fetch_alloc_info);
// Execute the main graph. The feeds_fetches_manager will be finalized based on the provided feeds and fetches.
common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
ExecutionMode execution_mode, const bool& terminate_flag, const logging::Logger& logger,
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder& device_stream_collection_holder,
#endif
bool only_execute_path_to_fetches = false,
Stream* parent_stream = nullptr);
common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
ExecutionMode execution_mode, const RunOptions& run_options,
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder& device_stream_collection_holder,
#endif
const logging::Logger& logger);
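// Illustrative call sequence (a sketch, assuming ORT_ENABLE_STREAM is not
// defined; ExecuteGraph finalizes the copy info itself from the supplied
// feeds and fetches, so only the Initialize step is needed up front):
//
//   FeedsFetchesManager& ffm = ...;  // manager created for the graph's input/output names
//   ORT_RETURN_IF_ERROR(InitializeFeedFetchCopyInfo(session_state, ffm));
//   ORT_RETURN_IF_ERROR(ExecuteGraph(session_state, ffm, feeds, fetches,
//                                    ExecutionMode::ORT_SEQUENTIAL, run_options,
//                                    logger));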
#ifdef ENABLE_TRAINING
common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
const logging::Logger& logger, PartialGraphExecutionState& state,
const OrtValueCachePtr& cache,
const bool& terminate_flag,
int32_t partial_graph_index,
Stream* parent_stream);
#endif
// Execute a subgraph. The feeds_fetches_manager should have been finalized prior to calling this function.
// See IControlFlowNode::SetupSubgraphExecutionInfo usage in the control flow kernels.
common::Status ExecuteSubgraph(const SessionState& session_state, const FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators,
ExecutionMode execution_mode, const bool& terminate_flag, const logging::Logger& logger,
Stream* parent_stream,
/*When this is enabled, we will sync the parent stream to make sure the subgraph fetches
are complete. This is mainly used when the parent kernel depends on the CPU value of
the subgraph fetches, e.g. the loop condition.*/
bool sync_subgraph_fetches = false);
// Return true if the kernel's create info ('p_kci') indicates the node's input/output at 'index'
// is expected to reside in CPU memory.
bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index);
bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index);
// Map a C++ type to the corresponding ONNXTensorElementDataType.
// Types without a specialization below map to ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED.
template <typename T>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
}
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<bool>() {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
}
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<std::string>() {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
}
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<float>() {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<double>() {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
}
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<MLFloat16>() {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
}
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<BFloat16>() {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
}
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<int8_t>() {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
}
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<uint8_t>() {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
}
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<int16_t>() {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
}
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<uint16_t>() {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
}
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<int32_t>() {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
}
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<uint32_t>() {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
}
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<int64_t>() {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
}
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<uint64_t>() {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
}
#if !defined(DISABLE_FLOAT8_TYPES)
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<Float8E4M3FN>() {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN;
}
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<Float8E4M3FNUZ>() {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ;
}
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<Float8E5M2>() {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2;
}
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<Float8E5M2FNUZ>() {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ;
}
#endif
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<Int4x2>() {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4;
}
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<UInt4x2>() {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4;
}
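// Usage sketch: the mapping is constexpr, so element types can be validated at
// compile time; types without a specialization fall back to UNDEFINED via the
// primary template above.
//
//   static_assert(GetONNXTensorElementDataType<float>() ==
//                 ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, "unexpected mapping");
//   static_assert(GetONNXTensorElementDataType<void*>() ==
//                 ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED, "unmapped type");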
// Convert an ONNXTensorElementDataType to the matching ONNX TensorProto_DataType value.
int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType);
#ifdef ENABLE_TRAINING
// Verify that the input tensors of 'context' are allocated contiguously in memory.
common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context);
#endif
} // namespace utils
} // namespace onnxruntime