Skip to content

Commit

Permalink
zero copy for all data entries
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac authored and zhiics committed Jun 25, 2020
1 parent e04227d commit db5db74
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 84 deletions.
8 changes: 1 addition & 7 deletions src/runtime/contrib/dnnl/dnnl_json_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
CHECK_EQ(consts.size(), const_idx_.size())
<< "The number of input constants must match the number of required.";

// Pre-allocate buffers on CPU for input and output entries.
DLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(kDLCPU);
ctx.device_id = 0;
AllocateInputOutputBuffer(ctx);

// Setup constants entries for weights.
SetupConstants(consts);
}
Expand All @@ -71,7 +65,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
// Fill in the input buffers.
for (size_t i = 0; i < input_nodes_.size(); ++i) {
auto eid = EntryID(input_nodes_[i], 0);
// TODO(@comanic): Support other data lengths.
// TODO(@comaniac): Support other data lengths.
size_t offset_in_bytes = entry_out_mem_[eid].second * 4;
size_t buffer_size = GetDataSize(*data_entry_[eid]);
write_to_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first, buffer_size,
Expand Down
93 changes: 16 additions & 77 deletions src/runtime/contrib/json/json_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,10 @@ class JSONRuntimeBase : public ModuleNode {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK(this->initialized_) << "The module has not been initialized";

// Set inputs.
this->SetInputs(args);
// Bind argument tensors to data entries.
this->SetInputOutputBuffers(args);
// Execute the subgraph.
this->Run();
// Copy result to output buffer.
this->GetOutput(args);
});
} else if ("__init_" + this->symbol_name_ == name) {
// The function to initialize constant tensors.
Expand Down Expand Up @@ -134,16 +132,18 @@ class JSONRuntimeBase : public ModuleNode {

protected:
/*!
* \brief Set up the inputs for inference.
* \brief Set up the input and output buffers by binding their DLTensor pointers to the
* corresponding data entry.
*
* \param args The packed args.
*/
void SetInputs(const TVMArgs& args) {
void SetInputOutputBuffers(const TVMArgs& args) {
CHECK_EQ(args.size(), input_var_idx_.size() + outputs_.size())
<< "Found mismatch in the number of provided data entryies and required.";

for (size_t i = 0; i < input_var_idx_.size(); i++) {
auto eid = EntryID(input_var_idx_[i], 0);
for (size_t i = 0; i < static_cast<size_t>(args.size()); i++) {
auto eid = i < input_var_idx_.size() ? EntryID(input_var_idx_[i], 0)
: EntryID(outputs_[i - input_var_idx_.size()]);
CHECK(args[i].type_code() == kTVMNDArrayHandle || args[i].type_code() == kTVMDLTensorHandle)
<< "Expect NDArray or DLTensor as inputs";

Expand All @@ -155,70 +155,9 @@ class JSONRuntimeBase : public ModuleNode {
arg = args[i].operator DLTensor*();
}

size_t from_size = GetDataSize(*arg);
size_t to_size = GetDataSize(*data_entry_[eid]);
CHECK_EQ(from_size, to_size);

if (data_entry_[eid]->ctx.device_type == arg->ctx.device_type) {
// Zero copy for input because the tensor is managed by the host.
data_entry_[eid]->data = arg->data;
} else {
NDArray::CopyFromTo(arg, data_entry_[eid]);
}
}
}

/*!
* \brief Return the results through packed args.
*
* \param args The packed args.
*/
void GetOutput(const TVMArgs& args) {
// Copy result to output buffer.
size_t arg_idx = input_var_idx_.size();
CHECK_EQ(args.size(), arg_idx + outputs_.size())
<< "Found mismatch in the number of provided data entryies and required.";

for (size_t i = 0; i < outputs_.size(); i++, arg_idx++) {
auto eid = EntryID(outputs_[i]);

if (args[arg_idx].type_code() == kTVMDLTensorHandle) {
DLTensor* arg = args[arg_idx];
NDArray::CopyFromTo(data_entry_[eid], arg);
} else {
CHECK(args[arg_idx].IsObjectRef<NDArray>());
NDArray arg = args[arg_idx];
arg.CopyFrom(data_entry_[eid]);
}
}
}

/*!
* \brief Pre-allocate empty buffers for input and output entries.
*
* \param ctx The context for the pre-allocated buffer.
*/
void AllocateInputOutputBuffer(const DLContext& ctx) {
for (size_t i = 0; i < input_nodes_.size(); ++i) {
auto nid = input_nodes_[i];
auto shape = nodes_[nid].GetOpShape()[0];
auto dtype = nodes_[nid].GetOpDataType()[0];
DLTensor* tensor;
int ret = TVMArrayAlloc(shape.data(), shape.size(), dtype.code, dtype.bits, dtype.lanes,
ctx.device_type, ctx.device_id, &tensor);
CHECK_EQ(ret, 0) << TVMGetLastError();
data_entry_[EntryID(nid, 0)] = tensor;
}

for (size_t i = 0; i < outputs_.size(); ++i) {
auto entry = outputs_[i];
auto shape = nodes_[entry.id_].GetOpShape()[entry.index_];
auto dtype = nodes_[entry.id_].GetOpDataType()[entry.index_];
DLTensor* tensor;
int ret = TVMArrayAlloc(shape.data(), shape.size(), dtype.code, dtype.bits, dtype.lanes,
ctx.device_type, ctx.device_id, &tensor);
CHECK_EQ(ret, 0) << TVMGetLastError();
data_entry_[EntryID(entry)] = tensor;
// Assign input/output the NDArray pointers to data entry so that we can directly
// read/write host buffers.
data_entry_[eid] = arg;
}
}

Expand Down Expand Up @@ -258,14 +197,14 @@ class JSONRuntimeBase : public ModuleNode {
}

/*!
* \brief Set up the constants/weights for inference.
* \brief Set up the constants/weights for inference by binding their DLTensor pointer to
* the corresponding data entry.
*
* \param consts The constant to be filled.
* \param consts A list of constant NDArray to be used.
*/
void SetupConstants(const Array<NDArray>& consts) {
// Initialize consts
for (size_t i = 0; i < consts.size(); ++i) {
consts[i].CopyTo(data_entry_[const_idx_[i]]);
data_entry_[const_idx_[i]] = consts[i].operator->();
}
}

Expand Down Expand Up @@ -313,7 +252,7 @@ class JSONRuntimeBase : public ModuleNode {
/*! \brief Output entries. */
std::vector<JSONGraphNodeEntry> outputs_;
/*! \brief Data of that entry. */
std::vector<DLTensor*> data_entry_;
std::vector<const DLTensor*> data_entry_;
/*! \brief Map the input name to index. */
std::vector<uint32_t> input_var_idx_;
/*! \brief input const index. */
Expand Down

0 comments on commit db5db74

Please sign in to comment.