diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index 697e52bf6373..bf31f1fb42af 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -31,11 +31,12 @@ #include #include +#include #include -#include #include -#include +#include #include +#include namespace tvm { namespace runtime { @@ -78,6 +79,11 @@ void GraphRuntime::Init(const std::string& graph_json, ctxs_ = ctxs; this->SetupStorage(); this->SetupOpExecs(); + for (size_t i = 0; i < input_nodes_.size(); i++) { + const uint32_t nid = input_nodes_[i]; + std::string& name = nodes_[nid].name; + input_map_[name] = i; + } } /*! * \brief Get the input index given the name of input. @@ -85,11 +91,9 @@ void GraphRuntime::Init(const std::string& graph_json, * \return The index of input. */ int GraphRuntime::GetInputIndex(const std::string& name) { - for (size_t i = 0; i< input_nodes_.size(); ++i) { - uint32_t nid = input_nodes_[i]; - if (nodes_[nid].name == name) { - return static_cast(i); - } + auto it = input_map_.find(name); + if (it != input_map_.end()) { + return it->second; } LOG(WARNING) << "Warning: cannot find \"" << name << "\" among input"; return -1; @@ -152,16 +156,8 @@ void GraphRuntime::SetInputZeroCopy(int index, DLTensor* data_ref) { } // Update the data pointer for each argument of each op - for (auto& op_arg : op_args_) { - if (op_arg) { - const auto it = op_arg->input_entry_ids.find(eid); - if (it != op_arg->input_entry_ids.end()) { - for (const auto i : it->second) { - DLTensor* t = static_cast(op_arg->arg_values[i].v_handle); - t->data = data_ref->data; - } - } - } + for (DLTensor* t : input_dltensors_[eid]) { + t->data = data_ref->data; } } /*! @@ -350,17 +346,21 @@ void GraphRuntime::SetupStorage() { void GraphRuntime::SetupOpExecs() { op_execs_.resize(this->GetNumOfNodes()); - op_args_.resize(this->GetNumOfNodes()); + input_dltensors_.resize(num_node_entries()); + std::unordered_set input_node_eids; + for (size_t i = 0; i < input_nodes_.size(); i++) { + uint32_t nid = input_nodes_[i]; + input_node_eids.insert(entry_id(nid, 0)); + } + // setup the array and requirements. for (uint32_t nid = 0; nid < this->GetNumOfNodes(); ++nid) { const auto& inode = nodes_[nid]; if (inode.op_type == "null") continue; std::vector args; - std::vector input_entry_ids; for (const auto& e : inode.inputs) { uint32_t eid = this->entry_id(e); args.push_back(*(data_entry_[eid].operator->())); - input_entry_ids.push_back(eid); } for (uint32_t index = 0; index < inode.param.num_outputs; ++index) { uint32_t eid = this->entry_id(nid, index); @@ -368,16 +368,16 @@ void GraphRuntime::SetupOpExecs() { } if (inode.op_type == "tvm_op") { - std::tie(op_execs_[nid], op_args_[nid]) = + std::shared_ptr op_args = nullptr; + std::tie(op_execs_[nid], op_args) = CreateTVMOp(inode.param, args, inode.inputs.size()); - auto& entry_to_input_pos = op_args_[nid]->input_entry_ids; - for (uint32_t i = 0; i < input_entry_ids.size(); ++i) { - const auto eid = input_entry_ids[i]; - auto it = entry_to_input_pos.find(eid); - if (it == entry_to_input_pos.end()) { - entry_to_input_pos.emplace(eid, std::vector{i}); - } else { - it->second.push_back(i); + + for (size_t i = 0; i < inode.inputs.size(); i++) { + uint32_t eid = this->entry_id(inode.inputs[i]); + // check if op input is model input + if (input_node_eids.count(eid) > 0) { + input_dltensors_[eid].push_back( + static_cast(op_args->arg_values[i].v_handle)); } } } else if (inode.op_type == "_tensorrt_subgraph_op") { diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h index cf0f1b0f2c8c..91908f2241c0 100644 --- a/src/runtime/graph/graph_runtime.h +++ b/src/runtime/graph/graph_runtime.h @@ -75,7 +75,6 @@ struct TVMOpParam { class GraphRuntime : public ModuleNode { struct OpArgs { std::vector args; - std::unordered_map > input_entry_ids; std::vector arg_values; std::vector arg_tcodes; std::vector shape_data; @@ -442,6 +441,10 @@ class GraphRuntime : public ModuleNode { std::vector nodes_; /*! \brief The argument nodes. */ std::vector input_nodes_; + /*! \brief Map of input names to input indices. */ + std::unordered_map input_map_; + /*! \brief Used for quick node input DLTensor* lookup given an input eid. */ + std::vector> input_dltensors_; /*! \brief Used for quick entry indexing. */ std::vector node_row_ptr_; /*! \brief Output entries. */