[DML EP] Add dynamic graph compilation (#17876)
Historically, DML was only able to fuse partitions when all sizes are
known in advance or when the caller overrides them at session creation
time. But in practice, it should be possible to compile partitions at
compute time if the caller knows that the dimensions won't change for
every inference (e.g. resizing a webcam window, or padding the input
to powers of 2). The compiled graph is cached and reused until the
sizes change.

This is an opt-in feature gated behind the `enable_dynamic_graph_fusion`
option, which means it will only be enabled when the caller requests it,
since the caller has more context on how the model will be called
between inferences.

This PR also adds the option to disable metacommands from the Python
API, which already existed in the C API but was missing from Python.
PatriceVignola authored Oct 26, 2023
1 parent d30d4d3 commit 538e97c
Showing 26 changed files with 1,127 additions and 143 deletions.
@@ -128,7 +128,7 @@ struct OrtDmlApi {
     /**
      * SessionOptionsAppendExecutionProvider_DML2
      * Creates a DirectML Execution Provider given the supplied device options that contain a performance preference
-     * (high power, low power, or defult) and a device filter (None, GPU, or NPU).
+     * (high power, low power, or default) and a device filter (None, GPU, or NPU).
      */
     ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_opts);
 };
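
As a quick usage sketch, a caller might retrieve the OrtDmlApi and append the provider as below. The OrtSessionOptions* is assumed to exist, and zero-initializing OrtDmlDeviceOptions to get the defaults is an assumption, not something this diff specifies:

    // Hedged sketch: fetch the DML-specific API surface and append the DML EP.
    const OrtApi* ortApi = OrtGetApiBase()->GetApi(ORT_API_VERSION);
    const OrtDmlApi* dmlApi = nullptr;
    Ort::ThrowOnError(ortApi->GetExecutionProviderApi(
        "DML", ORT_API_VERSION, reinterpret_cast<const void**>(&dmlApi)));

    OrtDmlDeviceOptions deviceOpts{}; // assumed: zero-init selects the default preference and device filter
    Ort::ThrowOnError(dmlApi->SessionOptionsAppendExecutionProvider_DML2(sessionOptions, &deviceOpts));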
@@ -3,6 +3,9 @@
 
 #pragma once
 interface IMLOperatorRegistry;
+interface IDMLDevice;
+interface ID3D12CommandQueue;
+interface ID3D12Resource;
 
 #include "core/common/status.h"
 #include "core/framework/data_transfer.h"
@@ -28,7 +31,8 @@ namespace Dml
     std::unique_ptr<onnxruntime::IExecutionProvider> CreateExecutionProvider(
         IDMLDevice* dmlDevice,
         ID3D12CommandQueue* commandQueue,
-        bool enableMetacommands = true);
+        bool enableMetacommands,
+        bool enableDynamicGraphFusion);
 
     ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr);
     void FlushContext(onnxruntime::IExecutionProvider* provider);
@@ -7,11 +7,14 @@
 #include <functional>
 #include <variant>
 #include <optional>
+#include <wrl/client.h>
 
 #include "core/framework/op_kernel.h"
+#include "core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h"
 
 struct AbstractOperatorDesc;
 interface IMLOperatorTensor;
+interface IDMLOperator;
 struct DML_INPUT_GRAPH_EDGE_DESC;
 struct DML_OUTPUT_GRAPH_EDGE_DESC;
 struct DML_INTERMEDIATE_GRAPH_EDGE_DESC;
@@ -92,6 +95,8 @@ namespace Windows::AI::MachineLearning::Adapter
         const onnxruntime::Node& node,
         MLOperatorTensorGetter& constantInputGetter,
         const void* executionHandle,
+        const EdgeShapes* inputShapesOverrides,
+        /*out*/ EdgeShapes* outputShapes,
         /*out*/ DmlGraphNodeCreateInfo* graphNodeCreateInfo
         )>;

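
To make the extended signature concrete, a conforming factory callback might look like this sketch; the lambda body is illustrative only and not part of this change:

    // Hedged sketch of a callback matching the extended factory signature above.
    auto graphNodeFactory = [](
        const onnxruntime::Node& node,
        MLOperatorTensorGetter& constantInputGetter,
        const void* executionHandle,
        const EdgeShapes* inputShapesOverrides, // new: input sizes known only at compute time
        /*out*/ EdgeShapes* outputShapes,       // new: inferred output shapes returned to the caller
        /*out*/ DmlGraphNodeCreateInfo* graphNodeCreateInfo)
    {
        // Infer output shapes from the overridden input shapes, then describe the DML node.
    };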
@@ -491,22 +491,24 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
             const onnxruntime::Node& node,
             MLOperatorTensorGetter& constantInputGetter,
             const void* executionHandle,
+            const EdgeShapes* inputShapesOverrides,
+            /*out*/ EdgeShapes* outputShapes,
             /*out*/ DmlGraphNodeCreateInfo* graphNodeCreateInfo
             )
         {
             onnxruntime::ProtoHelperNodeContext nodeContext(node);
             onnxruntime::OpNodeProtoHelper<onnxruntime::ProtoHelperNodeContext> protoHelper(&nodeContext);
 
             // Use the same list of required constant inputs for the shape inferrer and the kernel.
-            EdgeShapes outputShapes;
-            InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, nullptr, outputShapes);
+            InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, inputShapesOverrides, *outputShapes);
 
             // Create the kernel while allowing input shape and output shape queries according to options
             ComPtr<DmlGraphOpKernelInfoWrapper> kernelInfoWrapper = wil::MakeOrThrow<DmlGraphOpKernelInfoWrapper>(
                 &protoHelper,
                 executionHandle,
                 true,
-                &outputShapes,
+                inputShapesOverrides,
+                outputShapes,
                 &defaultAttributesCapture,
                 graphNodeCreateInfo,
                 constantCpuInputCapture,
@@ -0,0 +1,42 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+namespace Windows::AI::MachineLearning::Adapter
+{
+    // Non-tensor edges and unused edges have an empty array of dimensions.
+    class EdgeShapes
+    {
+    public:
+        EdgeShapes() = default;
+
+        EdgeShapes(size_t count) : m_shapes(count) {}
+
+        const std::vector<uint32_t>& GetShape(size_t edgeIndex) const
+        {
+            return m_shapes[edgeIndex];
+        }
+
+        std::vector<uint32_t>& GetMutableShape(size_t edgeIndex)
+        {
+            return m_shapes[edgeIndex];
+        }
+
+        size_t EdgeCount() const { return m_shapes.size(); }
+
+        void Reset(size_t edge_count)
+        {
+            m_shapes.clear();
+            m_shapes.resize(edge_count);
+        }
+
+        bool operator!=(const EdgeShapes& other) const noexcept
+        {
+            return (m_shapes != other.m_shapes);
+        }
+
+    private:
+        std::vector<std::vector<uint32_t>> m_shapes;
+    };
+}
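
A short usage sketch of how this type supports the caching described in the PR text: compare the current input shapes against the cached ones and recompile only on change. All names below are illustrative:

    // Hedged sketch: reuse the compiled graph until the input sizes change.
    EdgeShapes cachedInputShapes;

    void OnCompute(const EdgeShapes& currentInputShapes)
    {
        if (cachedInputShapes != currentInputShapes) // operator!= compares every edge's dimensions
        {
            cachedInputShapes = currentInputShapes;
            RecompileAndCacheGraph(); // hypothetical recompilation hook
        }
        // ... execute the cached compiled graph ...
    }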
@@ -1,7 +1,7 @@
 #pragma once
 
 #include "DmlGraphFusionHelper.h"
-
+#include "DmlRuntimeFusedGraphKernel.h"
 
 namespace Dml
 {
@@ -501,5 +501,171 @@ namespace DmlGraphFusionHelper
 
         graph.FinalizeFuseSubGraph(indexedSubGraph, fusedNode);
     }
+
+    void RegisterDynamicKernel(
+        onnxruntime::Graph& graph,
+        onnxruntime::KernelRegistry* registryForPartitionKernels,
+        const ExecutionProviderImpl* providerImpl,
+        std::unordered_map<const onnxruntime::Node*, GraphNodeProperties> graphNodePropertyMap,
+        const std::unordered_set<std::string>& dynamicCpuInputMap,
+        std::shared_ptr<const onnxruntime::IndexedSubGraph> indexedSubGraph,
+        std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>>&& isInitializerTransferable)
+    {
+        struct NodeInfo
+        {
+            std::string name;
+            std::string opType;
+            std::string description;
+            std::string domain;
+            onnxruntime::NodeAttributes attributes;
+            std::vector<onnxruntime::NodeArg*> inputDefPointers;
+            std::vector<onnxruntime::NodeArg*> outputDefPointers;
+        };
+
+        auto partitionNodePropsMap = DmlGraphFusionHelper::CreatePartitionNodePropsMap(
+            graph,
+            *indexedSubGraph,
+            std::move(graphNodePropertyMap));
+
+        auto modelPath = graph.ModelPath();
+
+        const gsl::span<const std::string> subGraphInputArgNames = indexedSubGraph->GetMetaDef()->inputs;
+        const gsl::span<const std::string> subGraphOutputArgNames = indexedSubGraph->GetMetaDef()->outputs;
+
+        std::vector<NodeInfo> nodesInfo;
+        nodesInfo.reserve(indexedSubGraph->nodes.size());
+
+        std::vector<const onnxruntime::NodeArg*> subgraphInputs;
+        subgraphInputs.reserve(subGraphInputArgNames.size());
+
+        std::vector<const onnxruntime::NodeArg*> subgraphOutputs;
+        subgraphOutputs.reserve(subGraphOutputArgNames.size());
+
+        std::vector<onnxruntime::NodeAttributes> nodeAttributes;
+        nodeAttributes.reserve(indexedSubGraph->nodes.size());
+
+        std::vector<std::shared_ptr<onnxruntime::NodeArg>> intermediateNodeArgs;
+
+        for (size_t sortedNodeIndex : indexedSubGraph->nodes)
+        {
+            auto node = graph.GetNode(sortedNodeIndex);
+
+            nodeAttributes.push_back(node->GetAttributes());
+
+            NodeInfo nodeInfo{};
+            nodeInfo.name = node->Name();
+            nodeInfo.opType = node->OpType();
+            nodeInfo.description = node->Description();
+            nodeInfo.domain = node->Domain();
+            nodeInfo.attributes = node->GetAttributes();
+            nodeInfo.inputDefPointers.reserve(node->InputDefs().size());
+            nodeInfo.outputDefPointers.reserve(node->OutputDefs().size());
+
+            for (const onnxruntime::NodeArg* inputDef : node->InputDefs())
+            {
+                intermediateNodeArgs.emplace_back(std::make_shared<onnxruntime::NodeArg>(inputDef->Name(), inputDef->TypeAsProto()));
+                nodeInfo.inputDefPointers.push_back(intermediateNodeArgs.back().get());
+            }
+
+            for (const onnxruntime::NodeArg* outputDef : node->OutputDefs())
+            {
+                intermediateNodeArgs.emplace_back(std::make_shared<onnxruntime::NodeArg>(outputDef->Name(), outputDef->TypeAsProto()));
+                nodeInfo.outputDefPointers.push_back(intermediateNodeArgs.back().get());
+            }
+
+            nodesInfo.push_back(std::move(nodeInfo));
+        }
+
+        for (const std::string& graphInputName : subGraphInputArgNames)
+        {
+            subgraphInputs.push_back(graph.GetNodeArg(graphInputName));
+        }
+
+        for (const std::string& graphOutputName : subGraphOutputArgNames)
+        {
+            subgraphOutputs.push_back(graph.GetNodeArg(graphOutputName));
+        }
+
+        // We need to keep the initializers alive since they will be freed once the nodes are removed from the graph
+        std::vector<ONNX_NAMESPACE::TensorProto> ownedInitializers;
+        ownedInitializers.reserve(isInitializerTransferable.size());
+
+        for (auto& kvp : isInitializerTransferable)
+        {
+            ONNX_NAMESPACE::TensorProto tensorProto;
+            tensorProto.set_data_type(kvp.second.first->data_type());
+            tensorProto.set_raw_data(kvp.second.first->raw_data());
+            tensorProto.set_name(kvp.second.first->name());
+
+            for (int i = 0; i < kvp.second.first->dims_size(); ++i)
+            {
+                tensorProto.add_dims(kvp.second.first->dims(i));
+            }
+            ownedInitializers.push_back(std::move(tensorProto));
+            kvp.second.first = &ownedInitializers.back();
+        }
+
+        // Lambda captures for the kernel registration
+        auto fused_kernel_func = [
+            indexedSubGraph,
+            &modelPath,
+            nodesInfo = std::move(nodesInfo),
+            intermediateNodeArgs = std::move(intermediateNodeArgs),
+            subgraphInputs = std::move(subgraphInputs),
+            subgraphOutputs = std::move(subgraphOutputs),
+            partitionNodePropsMap = std::move(partitionNodePropsMap),
+            ownedInitializers = std::move(ownedInitializers)] (onnxruntime::FuncManager& func_mgr, const onnxruntime::OpKernelInfo& info, std::unique_ptr<onnxruntime::OpKernel>& out) mutable -> onnxruntime::Status
+        {
+            std::vector<std::shared_ptr<onnxruntime::Node>> subgraphNodes;
+            subgraphNodes.reserve(nodesInfo.size());
+
+            for (const NodeInfo& nodeInfo : nodesInfo)
+            {
+                subgraphNodes.emplace_back(std::make_shared<onnxruntime::Node>(
+                    nodeInfo.name,
+                    nodeInfo.opType,
+                    nodeInfo.description,
+                    nodeInfo.inputDefPointers,
+                    nodeInfo.outputDefPointers,
+                    &nodeInfo.attributes,
+                    nodeInfo.domain));
+            }
+
+            out.reset(CreateRuntimeFusedGraphKernel(
+                info,
+                indexedSubGraph,
+                modelPath,
+                std::move(subgraphNodes),
+                std::move(subgraphInputs),
+                std::move(subgraphOutputs),
+                std::move(intermediateNodeArgs),
+                std::move(partitionNodePropsMap),
+                std::move(ownedInitializers)));
+            return Status::OK();
+        };
+
+        // Build the kernel definition on the fly and register it with the fused kernel registry.
+        onnxruntime::KernelDefBuilder builder;
+        builder.SetName(indexedSubGraph->GetMetaDef()->name)
+            .SetDomain(indexedSubGraph->GetMetaDef()->domain)
+            .SinceVersion(indexedSubGraph->GetMetaDef()->since_version)
+            .Provider(onnxruntime::kDmlExecutionProvider);
+
+        // Force the CPU inputs to be allocated on the CPU
+        for (int i = 0; i < subGraphInputArgNames.size(); ++i)
+        {
+            if (dynamicCpuInputMap.find(subGraphInputArgNames[i]) != dynamicCpuInputMap.end())
+            {
+                builder.InputMemoryType(OrtMemTypeCPUInput, i);
+            }
+        }
+
+        ORT_THROW_IF_ERROR(registryForPartitionKernels->Register(builder, fused_kernel_func));
+
+        auto& fusedNode = graph.BeginFuseSubGraph(*indexedSubGraph, indexedSubGraph->GetMetaDef()->name);
+        fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider);
+
+        graph.FinalizeFuseSubGraph(*indexedSubGraph, fusedNode);
+    }
 }
 }
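
For context, a partitioning pass might hand a dynamic partition to this helper roughly as follows; every variable in the sketch is assumed to be prepared by the surrounding partitioner, which this diff does not show:

    // Hedged sketch: registering a dynamically compiled partition.
    DmlGraphFusionHelper::RegisterDynamicKernel(
        graph,
        partitionKernelRegistry.get(),      // onnxruntime::KernelRegistry*
        providerImpl,
        std::move(graphNodePropertyMap),
        dynamicCpuInputNames,               // inputs that must stay CPU-resident
        std::move(indexedSubGraph),         // shared_ptr captured and kept alive by the fused kernel
        std::move(isInitializerTransferable));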
@@ -80,5 +80,14 @@ namespace DmlGraphFusionHelper
         std::vector<uint8_t>&& isInputsUploadedByDmlEP,
         const GraphDescBuilder::GraphDesc& graphDesc,
         Microsoft::WRL::ComPtr<IDMLCompiledOperator> compiledExecutionPlanOperator);
+
+    void RegisterDynamicKernel(
+        onnxruntime::Graph& graph,
+        onnxruntime::KernelRegistry* registryForPartitionKernels,
+        const ExecutionProviderImpl* providerImpl,
+        std::unordered_map<const onnxruntime::Node*, GraphNodeProperties> graphNodePropertyMap,
+        const std::unordered_set<std::string>& dynamicCpuInputMap,
+        std::shared_ptr<const onnxruntime::IndexedSubGraph> indexedSubGraph,
+        std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>>&& isInitializerTransferable);
 }
 }