-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Runtime] Enable set_input_zero_copy in GraphRuntime #3416
Changes from all commits
cc6d164
4a84628
319d302
08e2694
aa3e1e3
60cae5e
8cca007
8f78fa5
0e92320
655b0fc
aed94b6
6f1e860
34120d6
ab5637c
ff39998
39eea1b
57c7043
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,9 +6,9 @@ | |
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
|
@@ -23,6 +23,7 @@ | |
*/ | ||
#include "graph_runtime.h" | ||
|
||
#include <tvm/runtime/device_api.h> | ||
#include <tvm/runtime/ndarray.h> | ||
#include <tvm/runtime/packed_func.h> | ||
#include <tvm/runtime/registry.h> | ||
|
@@ -38,6 +39,13 @@ | |
|
||
namespace tvm { | ||
namespace runtime { | ||
namespace details { | ||
inline size_t GetDataAlignment(const DLTensor& arr) { | ||
size_t align = (arr.dtype.bits / 8) * arr.dtype.lanes; | ||
if (align < kAllocAlignment) return kAllocAlignment; | ||
return align; | ||
} | ||
} // namespace details | ||
|
||
/*! | ||
* \brief Run all the operations one by one. | ||
|
@@ -96,6 +104,39 @@ void GraphRuntime::SetInput(int index, DLTensor* data_in) { | |
uint32_t eid = this->entry_id(input_nodes_[index], 0); | ||
data_entry_[eid].CopyFrom(data_in); | ||
} | ||
/*! | ||
* \brief set index-th input to the graph without copying the data. | ||
* \param index The input index. | ||
* \param data_ref The input data that is referred. | ||
*/ | ||
void GraphRuntime::SetInputZeroCopy(int index, DLTensor* data_ref) { | ||
CHECK_LT(static_cast<size_t>(index), input_nodes_.size()); | ||
uint32_t eid = this->entry_id(input_nodes_[index], 0); | ||
const DLTensor* old_t = data_entry_[eid].operator->(); | ||
|
||
// check the consistency of input | ||
CHECK_EQ(data_alignment_[eid], details::GetDataAlignment(*data_ref)); | ||
CHECK_EQ(reinterpret_cast<size_t>(data_ref->data) % kAllocAlignment, 0); | ||
CHECK_EQ(old_t->ndim, static_cast<size_t>(data_ref->ndim)); | ||
CHECK_EQ(old_t->ctx.device_type, data_ref->ctx.device_type); | ||
CHECK_EQ(old_t->ctx.device_id, data_ref->ctx.device_id); | ||
for (auto i = 0; i < data_ref->ndim; ++i) { | ||
CHECK_EQ(old_t->shape[i], data_ref->shape[i]); | ||
} | ||
|
||
// Update the data pointer for each argument of each op | ||
for (auto& op_arg : op_args_) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm thinking, maybe we can change the API to accept an array of DLTensors, update the entries in the data_entries_ for all of them, and then call There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I had an implementation like this but it turned out the be more difficult.
|
||
if (op_arg) { | ||
const auto it = op_arg->input_entry_ids.find(eid); | ||
if (it != op_arg->input_entry_ids.end()) { | ||
for (const auto i : it->second) { | ||
DLTensor* t = static_cast<DLTensor*>(op_arg->arg_values[i].v_handle); | ||
t->data = data_ref->data; | ||
} | ||
} | ||
} | ||
} | ||
} | ||
/*! | ||
* \brief Get the number of outputs | ||
* | ||
|
@@ -184,7 +225,7 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) { | |
} | ||
} | ||
|
||
void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) { | ||
void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) { | ||
uint64_t header, reserved; | ||
CHECK(strm->Read(&header)) | ||
<< "Invalid parameters file format"; | ||
|
@@ -206,6 +247,8 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) { | |
CHECK_EQ(data_entry_[eid].use_count(), 1); | ||
data_entry_[eid] = other.GetInput(GetInputIndex(names[i])); | ||
CHECK_GT(data_entry_[eid].use_count(), 1); | ||
const DLTensor* tmp = data_entry_[eid].operator->(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure if the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I met with tests failure if I didn't do this. |
||
data_alignment_[eid] = details::GetDataAlignment(*tmp); | ||
} | ||
this->SetupOpExecs(); | ||
} | ||
|
@@ -268,53 +311,65 @@ void GraphRuntime::SetupStorage() { | |
// memory assignment for each node entry. The allocated memory on each device | ||
// is mapped to this pool. | ||
data_entry_.resize(num_node_entries()); | ||
data_alignment_.resize(num_node_entries()); | ||
for (size_t i = 0; i < data_entry_.size(); ++i) { | ||
int storage_id = attrs_.storage_id[i]; | ||
CHECK_LT(static_cast<size_t>(storage_id), storage_pool_.size()); | ||
data_entry_[i] = | ||
storage_pool_[storage_id].CreateView(attrs_.shape[i], vtype[i]); | ||
const DLTensor* tmp = data_entry_[i].operator->(); | ||
data_alignment_[i] = details::GetDataAlignment(*tmp); | ||
} | ||
} | ||
|
||
void GraphRuntime::SetupOpExecs() { | ||
op_execs_.resize(this->GetNumOfNodes()); | ||
op_args_.resize(this->GetNumOfNodes()); | ||
// setup the array and requirements. | ||
for (uint32_t nid = 0; nid < this->GetNumOfNodes(); ++nid) { | ||
const auto& inode = nodes_[nid]; | ||
if (inode.op_type == "null") continue; | ||
std::vector<DLTensor> args; | ||
std::vector<uint32_t> input_entry_ids; | ||
for (const auto& e : inode.inputs) { | ||
args.push_back(*(data_entry_[this->entry_id(e)].operator->())); | ||
uint32_t eid = this->entry_id(e); | ||
args.push_back(*(data_entry_[eid].operator->())); | ||
input_entry_ids.push_back(eid); | ||
} | ||
for (uint32_t index = 0; index < inode.param.num_outputs; ++index) { | ||
uint32_t eid = this->entry_id(nid, index); | ||
args.push_back(*(data_entry_[eid].operator->())); | ||
} | ||
CHECK(inode.op_type == "tvm_op") << "Can only take tvm_op as op"; | ||
|
||
op_execs_[nid] = CreateTVMOp(inode.param, args, inode.inputs.size()); | ||
std::tie(op_execs_[nid], op_args_[nid]) = | ||
CreateTVMOp(inode.param, args, inode.inputs.size()); | ||
auto& entry_to_input_pos = op_args_[nid]->input_entry_ids; | ||
for (uint32_t i = 0; i < input_entry_ids.size(); ++i) { | ||
const auto eid = input_entry_ids[i]; | ||
auto it = entry_to_input_pos.find(eid); | ||
if (it == entry_to_input_pos.end()) { | ||
entry_to_input_pos.emplace(eid, std::vector<uint32_t>{i}); | ||
} else { | ||
it->second.push_back(i); | ||
} | ||
} | ||
} | ||
} | ||
|
||
std::function<void()> GraphRuntime::CreateTVMOp( | ||
std::pair<std::function<void()>, std::shared_ptr<GraphRuntime::OpArgs> > GraphRuntime::CreateTVMOp( | ||
const TVMOpParam& param, | ||
const std::vector<DLTensor>& args, | ||
size_t num_inputs) { | ||
struct OpArgs { | ||
std::vector<DLTensor> args; | ||
std::vector<TVMValue> arg_values; | ||
std::vector<int> arg_tcodes; | ||
std::vector<int64_t> shape_data; | ||
}; | ||
std::shared_ptr<OpArgs> arg_ptr = std::make_shared<OpArgs>(); | ||
std::shared_ptr<GraphRuntime::OpArgs> arg_ptr = std::make_shared<GraphRuntime::OpArgs>(); | ||
// setup address. | ||
arg_ptr->args = std::move(args); | ||
arg_ptr->args = args; | ||
if (param.flatten_data) { | ||
arg_ptr->shape_data.resize(arg_ptr->args.size()); | ||
} | ||
for (size_t i = 0; i < arg_ptr->args.size(); ++i) { | ||
TVMValue v; | ||
DLTensor* t = &(arg_ptr->args[i]); | ||
DLTensor* t = &arg_ptr->args[i]; | ||
v.v_handle = t; | ||
arg_ptr->arg_values.push_back(v); | ||
arg_ptr->arg_tcodes.push_back(kArrayHandle); | ||
|
@@ -327,7 +382,7 @@ std::function<void()> GraphRuntime::CreateTVMOp( | |
} | ||
|
||
if (param.func_name == "__nop") { | ||
return [](){}; | ||
return {[](){}, arg_ptr}; | ||
} else if (param.func_name == "__copy") { | ||
// Perform cross device data copy. | ||
// Directly copy data from the input to the output. | ||
|
@@ -336,7 +391,7 @@ std::function<void()> GraphRuntime::CreateTVMOp( | |
DLTensor* to = static_cast<DLTensor*>(arg_ptr->arg_values[1].v_handle); | ||
TVM_CCALL(TVMArrayCopyFromTo(from, to, nullptr)); | ||
}; | ||
return fexec; | ||
return {fexec, arg_ptr}; | ||
} | ||
|
||
// Get compiled function from the module that contains both host and device | ||
|
@@ -351,7 +406,7 @@ std::function<void()> GraphRuntime::CreateTVMOp( | |
static_cast<int>(arg_ptr->arg_values.size())); | ||
pf.CallPacked(targs, &rv); | ||
}; | ||
return fexec; | ||
return {fexec, arg_ptr}; | ||
} | ||
|
||
PackedFunc GraphRuntime::GetFunction( | ||
|
@@ -367,14 +422,23 @@ PackedFunc GraphRuntime::GetFunction( | |
this->SetInput(args[0], args[1]); | ||
} | ||
}); | ||
} else if (name == "set_input_zero_copy") { | ||
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { | ||
if (args[0].type_code() == kStr) { | ||
int in_idx = this->GetInputIndex(args[0]); | ||
if (in_idx >= 0) this->SetInputZeroCopy(in_idx, args[1]); | ||
} else { | ||
this->SetInputZeroCopy(args[0], args[1]); | ||
} | ||
}); | ||
} else if (name == "get_output") { | ||
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { | ||
if (args.num_args == 2) { | ||
this->CopyOutputTo(args[0], args[1]); | ||
} else { | ||
*rv = this->GetOutput(args[0]); | ||
} | ||
}); | ||
if (args.num_args == 2) { | ||
this->CopyOutputTo(args[0], args[1]); | ||
} else { | ||
*rv = this->GetOutput(args[0]); | ||
} | ||
}); | ||
} else if (name == "get_input") { | ||
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { | ||
int in_idx = 0; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please check if the
device_type
anddevice_id
match as well, for the heterogenous case.