diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc
index cc37a8591016..26e1d842ed05 100644
--- a/src/runtime/graph/graph_runtime.cc
+++ b/src/runtime/graph/graph_runtime.cc
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License. You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -23,6 +23,7 @@
  */
 #include "graph_runtime.h"
 
+#include <tvm/runtime/device_api.h>
 #include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/registry.h>
@@ -38,6 +39,13 @@
 namespace tvm {
 namespace runtime {
 
+namespace details {
+inline size_t GetDataAlignment(const DLTensor& arr) {
+  size_t align = (arr.dtype.bits / 8) * arr.dtype.lanes;
+  if (align < kAllocAlignment) return kAllocAlignment;
+  return align;
+}
+}  // namespace details
 
 /*!
  * \brief Run all the operations one by one.
@@ -96,6 +104,39 @@ void GraphRuntime::SetInput(int index, DLTensor* data_in) {
   uint32_t eid = this->entry_id(input_nodes_[index], 0);
   data_entry_[eid].CopyFrom(data_in);
 }
+/*!
+ * \brief set index-th input to the graph without copying the data.
+ * \param index The input index.
+ * \param data_ref The input data that is referred.
+ */
+void GraphRuntime::SetInputZeroCopy(int index, DLTensor* data_ref) {
+  CHECK_LT(static_cast<size_t>(index), input_nodes_.size());
+  uint32_t eid = this->entry_id(input_nodes_[index], 0);
+  const DLTensor* old_t = data_entry_[eid].operator->();
+
+  // check the consistency of input
+  CHECK_EQ(data_alignment_[eid], details::GetDataAlignment(*data_ref));
+  CHECK_EQ(reinterpret_cast<size_t>(data_ref->data) % kAllocAlignment, 0);
+  CHECK_EQ(old_t->ndim, static_cast<size_t>(data_ref->ndim));
+  CHECK_EQ(old_t->ctx.device_type, data_ref->ctx.device_type);
+  CHECK_EQ(old_t->ctx.device_id, data_ref->ctx.device_id);
+  for (auto i = 0; i < data_ref->ndim; ++i) {
+    CHECK_EQ(old_t->shape[i], data_ref->shape[i]);
+  }
+
+  // Update the data pointer for each argument of each op
+  for (auto& op_arg : op_args_) {
+    if (op_arg) {
+      const auto it = op_arg->input_entry_ids.find(eid);
+      if (it != op_arg->input_entry_ids.end()) {
+        for (const auto i : it->second) {
+          DLTensor* t = static_cast<DLTensor*>(op_arg->arg_values[i].v_handle);
+          t->data = data_ref->data;
+        }
+      }
+    }
+  }
+}
 /*!
  * \brief Get the number of outputs
  *
@@ -184,7 +225,7 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
   }
 }
 
- void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) {
+void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) {
   uint64_t header, reserved;
   CHECK(strm->Read(&header))
       << "Invalid parameters file format";
@@ -206,6 +247,8 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
     CHECK_EQ(data_entry_[eid].use_count(), 1);
     data_entry_[eid] = other.GetInput(GetInputIndex(names[i]));
     CHECK_GT(data_entry_[eid].use_count(), 1);
+    const DLTensor* tmp = data_entry_[eid].operator->();
+    data_alignment_[eid] = details::GetDataAlignment(*tmp);
   }
   this->SetupOpExecs();
 }
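// ----------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the alignment rule implemented
// by details::GetDataAlignment above is the element size
// (dtype.bits / 8) * dtype.lanes, clamped from below by the allocator's
// minimum alignment. ToyDType and kMinAlign are invented for illustration;
// kMinAlign stands in for TVM's kAllocAlignment, whose real value comes from
// the runtime headers.
// ----------------------------------------------------------------------------
#include <cstddef>
#include <cstdint>

struct ToyDType { uint8_t bits; uint16_t lanes; };
constexpr size_t kMinAlign = 64;  // assumed stand-in for kAllocAlignment

size_t ToyDataAlignment(ToyDType t) {
  size_t align = (t.bits / 8) * t.lanes;         // bytes per (vector) element
  return align < kMinAlign ? kMinAlign : align;  // clamp to allocator minimum
}
// ToyDataAlignment({32, 1})  == 64   (float32: 4 bytes, clamped up)
// ToyDataAlignment({8, 128}) == 128  (int8x128: already above the minimum)
// ----------------------------------------------------------------------------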
@@ -268,23 +311,30 @@ void GraphRuntime::SetupStorage() {
   // memory assignment for each node entry. The allocated memory on each device
   // is mapped to this pool.
   data_entry_.resize(num_node_entries());
+  data_alignment_.resize(num_node_entries());
   for (size_t i = 0; i < data_entry_.size(); ++i) {
     int storage_id = attrs_.storage_id[i];
     CHECK_LT(static_cast<size_t>(storage_id), storage_pool_.size());
     data_entry_[i] = storage_pool_[storage_id].CreateView(attrs_.shape[i], vtype[i]);
+    const DLTensor* tmp = data_entry_[i].operator->();
+    data_alignment_[i] = details::GetDataAlignment(*tmp);
   }
 }
 
 void GraphRuntime::SetupOpExecs() {
   op_execs_.resize(this->GetNumOfNodes());
+  op_args_.resize(this->GetNumOfNodes());
   // setup the array and requirements.
   for (uint32_t nid = 0; nid < this->GetNumOfNodes(); ++nid) {
     const auto& inode = nodes_[nid];
     if (inode.op_type == "null") continue;
     std::vector<DLTensor> args;
+    std::vector<uint32_t> input_entry_ids;
     for (const auto& e : inode.inputs) {
-      args.push_back(*(data_entry_[this->entry_id(e)].operator->()));
+      uint32_t eid = this->entry_id(e);
+      args.push_back(*(data_entry_[eid].operator->()));
+      input_entry_ids.push_back(eid);
     }
     for (uint32_t index = 0; index < inode.param.num_outputs; ++index) {
       uint32_t eid = this->entry_id(nid, index);
@@ -292,29 +342,34 @@ void GraphRuntime::SetupOpExecs() {
     }
     CHECK(inode.op_type == "tvm_op") << "Can only take tvm_op as op";
 
-    op_execs_[nid] = CreateTVMOp(inode.param, args, inode.inputs.size());
+    std::tie(op_execs_[nid], op_args_[nid]) =
+        CreateTVMOp(inode.param, args, inode.inputs.size());
+    auto& entry_to_input_pos = op_args_[nid]->input_entry_ids;
+    for (uint32_t i = 0; i < input_entry_ids.size(); ++i) {
+      const auto eid = input_entry_ids[i];
+      auto it = entry_to_input_pos.find(eid);
+      if (it == entry_to_input_pos.end()) {
+        entry_to_input_pos.emplace(eid, std::vector<uint32_t>{i});
+      } else {
+        it->second.push_back(i);
+      }
+    }
   }
 }
 
-std::function<void()> GraphRuntime::CreateTVMOp(
+std::pair<std::function<void()>, std::shared_ptr<GraphRuntime::OpArgs> > GraphRuntime::CreateTVMOp(
     const TVMOpParam& param,
     const std::vector<DLTensor>& args,
     size_t num_inputs) {
-  struct OpArgs {
-    std::vector<DLTensor> args;
-    std::vector<TVMValue> arg_values;
-    std::vector<int> arg_tcodes;
-    std::vector<int64_t> shape_data;
-  };
-  std::shared_ptr<OpArgs> arg_ptr = std::make_shared<OpArgs>();
+  std::shared_ptr<GraphRuntime::OpArgs> arg_ptr = std::make_shared<GraphRuntime::OpArgs>();
   // setup address.
-  arg_ptr->args = std::move(args);
+  arg_ptr->args = args;
   if (param.flatten_data) {
    arg_ptr->shape_data.resize(arg_ptr->args.size());
   }
   for (size_t i = 0; i < arg_ptr->args.size(); ++i) {
     TVMValue v;
-    DLTensor* t = &(arg_ptr->args[i]);
+    DLTensor* t = &arg_ptr->args[i];
     v.v_handle = t;
     arg_ptr->arg_values.push_back(v);
     arg_ptr->arg_tcodes.push_back(kArrayHandle);
@@ -327,7 +382,7 @@ std::function<void()> GraphRuntime::CreateTVMOp(
   }
 
   if (param.func_name == "__nop") {
-    return [](){};
+    return {[](){}, arg_ptr};
   } else if (param.func_name == "__copy") {
     // Perform cross device data copy.
     // Directly copy data from the input to the output.
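// ----------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): a toy model of the bookkeeping
// SetupOpExecs builds above. For each op, input_entry_ids maps a node-entry id
// to every position that entry occupies in the op's packed argument list, so
// SetInputZeroCopy only has to patch the data pointer at those positions. All
// ids and buffers below are invented for illustration.
// ----------------------------------------------------------------------------
#include <cstdint>
#include <cstdio>
#include <unordered_map>
#include <vector>

int main() {
  // One op whose argument slots 0 and 2 are both fed by entry id 7.
  std::vector<uint32_t> input_entry_ids = {7, 3, 7};
  std::unordered_map<uint32_t, std::vector<uint32_t> > entry_to_input_pos;
  for (uint32_t i = 0; i < input_entry_ids.size(); ++i) {
    // Same result as the find/emplace sequence in the patch.
    entry_to_input_pos[input_entry_ids[i]].push_back(i);
  }
  // "Zero-copy" swap for entry 7: patch the pointer in every slot it feeds.
  float new_buffer[6] = {0};
  std::vector<void*> arg_data(3, nullptr);
  for (uint32_t pos : entry_to_input_pos[7]) {
    arg_data[pos] = new_buffer;
  }
  std::printf("slots patched: %zu\n", entry_to_input_pos[7].size());  // prints 2
  return 0;
}
// ----------------------------------------------------------------------------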
@@ -336,7 +391,7 @@ std::function<void()> GraphRuntime::CreateTVMOp(
       DLTensor* to = static_cast<DLTensor*>(arg_ptr->arg_values[1].v_handle);
       TVM_CCALL(TVMArrayCopyFromTo(from, to, nullptr));
     };
-    return fexec;
+    return {fexec, arg_ptr};
   }
 
   // Get compiled function from the module that contains both host and device
@@ -351,7 +406,7 @@ std::function<void()> GraphRuntime::CreateTVMOp(
                     static_cast<int>(arg_ptr->arg_values.size()));
     pf.CallPacked(targs, &rv);
   };
-  return fexec;
+  return {fexec, arg_ptr};
 }
 
 PackedFunc GraphRuntime::GetFunction(
@@ -367,14 +422,23 @@ PackedFunc GraphRuntime::GetFunction(
         this->SetInput(args[0], args[1]);
       }
     });
+  } else if (name == "set_input_zero_copy") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        if (args[0].type_code() == kStr) {
+          int in_idx = this->GetInputIndex(args[0]);
+          if (in_idx >= 0) this->SetInputZeroCopy(in_idx, args[1]);
+        } else {
+          this->SetInputZeroCopy(args[0], args[1]);
+        }
+      });
   } else if (name == "get_output") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-      if (args.num_args == 2) {
-        this->CopyOutputTo(args[0], args[1]);
-      } else {
-        *rv = this->GetOutput(args[0]);
-      }
-    });
+        if (args.num_args == 2) {
+          this->CopyOutputTo(args[0], args[1]);
+        } else {
+          *rv = this->GetOutput(args[0]);
+        }
+      });
   } else if (name == "get_input") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
         int in_idx = 0;
diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h
index e3f5815950f7..ddfd4a587148 100644
--- a/src/runtime/graph/graph_runtime.h
+++ b/src/runtime/graph/graph_runtime.h
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License. You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -34,6 +34,7 @@
 #include <memory>
 #include <string>
+#include <unordered_map>
 #include <utility>
 #include <vector>
 
@@ -67,6 +68,14 @@ struct TVMOpParam {
  * TVM runtime PackedFunc API.
  */
 class GraphRuntime : public ModuleNode {
+  struct OpArgs {
+    std::vector<DLTensor> args;
+    std::unordered_map<uint32_t, std::vector<uint32_t> > input_entry_ids;
+    std::vector<TVMValue> arg_values;
+    std::vector<int> arg_tcodes;
+    std::vector<int64_t> shape_data;
+  };
+
  public:
   /*!
    * \brief Get member function to front-end
@@ -111,6 +120,12 @@ class GraphRuntime : public ModuleNode {
    * \param data_in The input data.
    */
   void SetInput(int index, DLTensor* data_in);
+  /*!
+   * \brief set index-th input to the graph without copying the data
+   * \param index The input index.
+   * \param data_ref The input data that is referred.
+   */
+  void SetInputZeroCopy(int index, DLTensor* data_ref);
   /*!
    * \brief Get the number of outputs
    *
@@ -365,9 +380,9 @@ class GraphRuntime : public ModuleNode {
    * \param num_inputs Number of inputs.
    * \return The created executor.
    */
-  std::function<void()> CreateTVMOp(const TVMOpParam& attrs,
-                                    const std::vector<DLTensor>& args,
-                                    size_t num_inputs);
+  std::pair<std::function<void()>, std::shared_ptr<OpArgs> > CreateTVMOp(
+      const TVMOpParam& attrs, const std::vector<DLTensor>& args,
+      size_t num_inputs);
   // Get node entry index.
   uint32_t entry_id(uint32_t nid, uint32_t index) const {
     return node_row_ptr_[nid] + index;
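// ----------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): how a client is expected to
// drive the new "set_input_zero_copy" entry point registered in GetFunction
// above, mirroring the updated test below. The caller retains ownership of the
// buffer: it must stay alive, keep its shape/dtype/device, and satisfy the
// runtime's alignment CHECKs for as long as run() may be called. The input
// name "data" and the pre-built `run_mod` are assumptions for illustration.
// ----------------------------------------------------------------------------
#include <tvm/runtime/module.h>
#include <tvm/runtime/ndarray.h>

void RunWithZeroCopy(tvm::runtime::Module run_mod, tvm::runtime::NDArray input) {
  auto set_input_zc = run_mod.GetFunction("set_input_zero_copy", false);
  auto run = run_mod.GetFunction("run", false);
  // Hand the runtime a view of our tensor; no copy is made, so later writes
  // through `input` are visible to subsequent run() calls.
  set_input_zc("data", &input.ToDLPack()->dl_tensor);
  run();
}
// ----------------------------------------------------------------------------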
@@ -398,8 +413,12 @@
   std::vector<NDArray> storage_pool_;
   /*! \brief Data entry of each node. */
   std::vector<NDArray> data_entry_;
+  /*! \brief Data alignment of each node. */
+  std::vector<size_t> data_alignment_;
   /*! \brief Operator on each node. */
   std::vector<std::function<void()> > op_execs_;
+  /*! \brief Arg info of TVM ops */
+  std::vector<std::shared_ptr<OpArgs> > op_args_;
 };
 
 std::vector<TVMContext> GetAllContext(const TVMArgs& args);
diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc
index 39c17b8b3a81..95da100c4a7d 100644
--- a/src/runtime/ndarray.cc
+++ b/src/runtime/ndarray.cc
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License. You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc
index 3f46eed9f10e..98dd753803d7 100644
--- a/tests/cpp/relay_build_module_test.cc
+++ b/tests/cpp/relay_build_module_test.cc
@@ -84,18 +84,41 @@ TEST(Relay, BuildModule) {
   auto ctx = A->ctx;
   auto pfr = tvm::runtime::Registry::Get("tvm.graph_runtime.create");
   tvm::runtime::Module run_mod = (*pfr)(json, mod, (int)ctx.device_type, (int)ctx.device_id);
-  auto set_input_f = run_mod.GetFunction("set_input", false);
+  auto set_input_f = run_mod.GetFunction("set_input_zero_copy", false);
   auto run_f = run_mod.GetFunction("run", false);
   auto get_output_f = run_mod.GetFunction("get_output", false);
-  set_input_f("a", A);
-  set_input_f("b", B);
-  set_input_f("c", C);
+  set_input_f("a", &A.ToDLPack()->dl_tensor);
+  set_input_f("b", &B.ToDLPack()->dl_tensor);
+  set_input_f("c", &C.ToDLPack()->dl_tensor);
   run_f();
   tvm::runtime::NDArray Y = get_output_f(0);
   auto pY = (float*)Y.ToDLPack()->dl_tensor.data;
   for (int i = 0; i < 6; ++i) {
     CHECK_LT(fabs(pY[i] - (i + (i + 1) + (i + 2))), 1e-4);
   }
+  // mutate the input a bit and run it again
+  for (int i = 0; i < 6; ++i) {
+    pB[i] = i + 3;
+  }
+  run_f();
+  tvm::runtime::NDArray Y2 = get_output_f(0);
+  auto pY2 = (float*)Y2.ToDLPack()->dl_tensor.data;
+  for (int i = 0; i < 6; ++i) {
+    CHECK_LT(fabs(pY2[i] - (i + (i + 3) + (i + 2))), 1e-4);
+  }
+  // attach a different input and run it again
+  auto C2 = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
+  auto pC2 = (float*)C2.ToDLPack()->dl_tensor.data;
+  for (int i = 0; i < 6; ++i) {
+    pC2[i] = i + 4;
+  }
+  set_input_f("c", &C2.ToDLPack()->dl_tensor);
+  run_f();
+  tvm::runtime::NDArray Y3 = get_output_f(0);
+  auto pY3 = (float*)Y3.ToDLPack()->dl_tensor.data;
+  for (int i = 0; i < 6; ++i) {
+    CHECK_LT(fabs(pY3[i] - (i + (i + 3) + (i + 4))), 1e-4);
+  }
 }
 
 int main(int argc, char ** argv) {
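// ----------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): SetInputZeroCopy CHECK-fails on
// a misaligned data pointer, so callers wrapping raw host memory in a DLTensor
// (rather than using a TVM-allocated NDArray, as the test above does) may want
// to verify alignment first. A minimal guard; the value 64 is an assumption
// standing in for kAllocAlignment from the TVM runtime headers.
// ----------------------------------------------------------------------------
#include <cstdint>

constexpr std::uintptr_t kAssumedAllocAlignment = 64;  // assumption, not TVM's constant

inline bool IsZeroCopyAligned(const void* data) {
  return reinterpret_cast<std::uintptr_t>(data) % kAssumedAllocAlignment == 0;
}
// ----------------------------------------------------------------------------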