diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index 87eb74490dc35..c2747d1458579 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -57,12 +57,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase { CHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required."; - // Pre-allocate buffers on CPU for input and output entries. - DLContext ctx; - ctx.device_type = static_cast(kDLCPU); - ctx.device_id = 0; - AllocateInputOutputBuffer(ctx); - // Setup constants entries for weights. SetupConstants(consts); } @@ -71,7 +65,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { // Fill in the input buffers. for (size_t i = 0; i < input_nodes_.size(); ++i) { auto eid = EntryID(input_nodes_[i], 0); - // TODO(@comanic): Support other data lengths. + // TODO(@comaniac): Support other data lengths. size_t offset_in_bytes = entry_out_mem_[eid].second * 4; size_t buffer_size = GetDataSize(*data_entry_[eid]); write_to_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first, buffer_size, diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h index d716929478f76..da7d8952d1503 100644 --- a/src/runtime/contrib/json/json_runtime.h +++ b/src/runtime/contrib/json/json_runtime.h @@ -80,12 +80,10 @@ class JSONRuntimeBase : public ModuleNode { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK(this->initialized_) << "The module has not been initialized"; - // Set inputs. - this->SetInputs(args); + // Bind argument tensors to data entries. + this->SetInputOutputBuffers(args); // Execute the subgraph. this->Run(); - // Copy result to output buffer. - this->GetOutput(args); }); } else if ("__init_" + this->symbol_name_ == name) { // The function to initialize constant tensors. @@ -134,16 +132,18 @@ class JSONRuntimeBase : public ModuleNode { protected: /*! - * \brief Set up the inputs for inference. + * \brief Set up the input and output buffers by binding their DLTensor pointers to the + * corresponding data entry. * * \param args The packed args. */ - void SetInputs(const TVMArgs& args) { + void SetInputOutputBuffers(const TVMArgs& args) { CHECK_EQ(args.size(), input_var_idx_.size() + outputs_.size()) << "Found mismatch in the number of provided data entryies and required."; - for (size_t i = 0; i < input_var_idx_.size(); i++) { - auto eid = EntryID(input_var_idx_[i], 0); + for (size_t i = 0; i < static_cast(args.size()); i++) { + auto eid = i < input_var_idx_.size() ? EntryID(input_var_idx_[i], 0) + : EntryID(outputs_[i - input_var_idx_.size()]); CHECK(args[i].type_code() == kTVMNDArrayHandle || args[i].type_code() == kTVMDLTensorHandle) << "Expect NDArray or DLTensor as inputs"; @@ -155,70 +155,9 @@ class JSONRuntimeBase : public ModuleNode { arg = args[i].operator DLTensor*(); } - size_t from_size = GetDataSize(*arg); - size_t to_size = GetDataSize(*data_entry_[eid]); - CHECK_EQ(from_size, to_size); - - if (data_entry_[eid]->ctx.device_type == arg->ctx.device_type) { - // Zero copy for input because the tensor is managed by the host. - data_entry_[eid]->data = arg->data; - } else { - NDArray::CopyFromTo(arg, data_entry_[eid]); - } - } - } - - /*! - * \brief Return the results through packed args. - * - * \param args The packed args. - */ - void GetOutput(const TVMArgs& args) { - // Copy result to output buffer. - size_t arg_idx = input_var_idx_.size(); - CHECK_EQ(args.size(), arg_idx + outputs_.size()) - << "Found mismatch in the number of provided data entryies and required."; - - for (size_t i = 0; i < outputs_.size(); i++, arg_idx++) { - auto eid = EntryID(outputs_[i]); - - if (args[arg_idx].type_code() == kTVMDLTensorHandle) { - DLTensor* arg = args[arg_idx]; - NDArray::CopyFromTo(data_entry_[eid], arg); - } else { - CHECK(args[arg_idx].IsObjectRef()); - NDArray arg = args[arg_idx]; - arg.CopyFrom(data_entry_[eid]); - } - } - } - - /*! - * \brief Pre-allocate empty buffers for input and output entries. - * - * \param ctx The context for the pre-allocated buffer. - */ - void AllocateInputOutputBuffer(const DLContext& ctx) { - for (size_t i = 0; i < input_nodes_.size(); ++i) { - auto nid = input_nodes_[i]; - auto shape = nodes_[nid].GetOpShape()[0]; - auto dtype = nodes_[nid].GetOpDataType()[0]; - DLTensor* tensor; - int ret = TVMArrayAlloc(shape.data(), shape.size(), dtype.code, dtype.bits, dtype.lanes, - ctx.device_type, ctx.device_id, &tensor); - CHECK_EQ(ret, 0) << TVMGetLastError(); - data_entry_[EntryID(nid, 0)] = tensor; - } - - for (size_t i = 0; i < outputs_.size(); ++i) { - auto entry = outputs_[i]; - auto shape = nodes_[entry.id_].GetOpShape()[entry.index_]; - auto dtype = nodes_[entry.id_].GetOpDataType()[entry.index_]; - DLTensor* tensor; - int ret = TVMArrayAlloc(shape.data(), shape.size(), dtype.code, dtype.bits, dtype.lanes, - ctx.device_type, ctx.device_id, &tensor); - CHECK_EQ(ret, 0) << TVMGetLastError(); - data_entry_[EntryID(entry)] = tensor; + // Assign input/output the NDArray pointers to data entry so that we can directly + // read/write host buffers. + data_entry_[eid] = arg; } } @@ -258,14 +197,14 @@ class JSONRuntimeBase : public ModuleNode { } /*! - * \brief Set up the constants/weights for inference. + * \brief Set up the constants/weights for inference by binding their DLTensor pointer to + * the corresponding data entry. * - * \param consts The constant to be filled. + * \param consts A list of constant NDArray to be used. */ void SetupConstants(const Array& consts) { - // Initialize consts for (size_t i = 0; i < consts.size(); ++i) { - consts[i].CopyTo(data_entry_[const_idx_[i]]); + data_entry_[const_idx_[i]] = consts[i].operator->(); } } @@ -313,7 +252,7 @@ class JSONRuntimeBase : public ModuleNode { /*! \brief Output entries. */ std::vector outputs_; /*! \brief Data of that entry. */ - std::vector data_entry_; + std::vector data_entry_; /*! \brief Map the input name to index. */ std::vector input_var_idx_; /*! \brief input const index. */