From 2db0d3ae81c90773ff086fc8af86a17bbf2d68ff Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Wed, 24 Aug 2016 18:40:06 -0700 Subject: [PATCH] updates (#25) * [FIX] Remove extra move * [MEMORY] Add inplace index --- nnvm/src/c_api/c_api_symbolic.cc | 8 ++++---- nnvm/src/pass/order_mutation.cc | 5 +---- nnvm/src/pass/plan_memory.cc | 8 ++++++-- nnvm/src/pass/saveload_json.cc | 2 +- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/nnvm/src/c_api/c_api_symbolic.cc b/nnvm/src/c_api/c_api_symbolic.cc index aabfca4795d2..3dbb816d1729 100644 --- a/nnvm/src/c_api/c_api_symbolic.cc +++ b/nnvm/src/c_api/c_api_symbolic.cc @@ -160,7 +160,7 @@ int NNSymbolListAttrs(SymbolHandle symbol, NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); std::unordered_map attr = - std::move(s->ListAttrs(static_cast(option))); // NOLINT(*) + s->ListAttrs(static_cast(option)); // NOLINT(*) std::vector& attr_list = ret->ret_vec_str; attr_list.clear(); @@ -184,8 +184,8 @@ int NNSymbolListInputNames(SymbolHandle symbol, Symbol *s = static_cast(symbol); NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); - ret->ret_vec_str = std::move( - s->ListInputNames(Symbol::ListInputOption(option))); + ret->ret_vec_str = + s->ListInputNames(Symbol::ListInputOption(option)); ret->ret_vec_charp.clear(); for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); @@ -201,7 +201,7 @@ int NNSymbolListOutputNames(SymbolHandle symbol, Symbol *s = static_cast(symbol); NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); - ret->ret_vec_str = std::move(s->ListOutputNames()); + ret->ret_vec_str = s->ListOutputNames(); ret->ret_vec_charp.clear(); for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); diff --git a/nnvm/src/pass/order_mutation.cc b/nnvm/src/pass/order_mutation.cc index c554946c6245..3bcfd9922d53 100644 --- a/nnvm/src/pass/order_mutation.cc +++ b/nnvm/src/pass/order_mutation.cc @@ -22,10 +22,7 @@ inline T get_with_default(const std::unordered_map &map, } inline bool IsMutate(const std::vector& mutate_inputs, uint32_t i) { - if (mutate_inputs.size() == 0) return false; - auto it = std::lower_bound( - mutate_inputs.begin(), mutate_inputs.end(), i); - return (it != mutate_inputs.end()) && (*it == i); + return std::binary_search(mutate_inputs.begin(), mutate_inputs.end(), i); } Graph OrderMutation(const Graph& src) { diff --git a/nnvm/src/pass/plan_memory.cc b/nnvm/src/pass/plan_memory.cc index 2d57b5c78f6b..14a88d217de8 100644 --- a/nnvm/src/pass/plan_memory.cc +++ b/nnvm/src/pass/plan_memory.cc @@ -150,6 +150,7 @@ Graph PlanMemory(Graph ret) { } // step 2: allocate memory. StorageVector storage(idx.num_node_entries(), -1); + std::vector storage_inplace_index(idx.num_node_entries(), -1); const ShapeVector& shape_vec = ret.GetAttr("shape"); const DTypeVector& dtype_vec = ret.GetAttr("dtype"); const DeviceVector* device_vec = nullptr; @@ -173,8 +174,10 @@ Graph PlanMemory(Graph ret) { uint32_t eid_out = idx.entry_id(nid, kv.second); uint32_t eid_in = idx.entry_id(inode.inputs[kv.first]); if (ref_count[eid_in] == 1 && storage[eid_in] != GraphAllocator::kBadStorageID) { + // inplace optimization storage[eid_out] = storage[eid_in]; ref_count[eid_in] = 0; + storage_inplace_index[eid_out] = kv.first; } } } @@ -209,8 +212,8 @@ Graph PlanMemory(Graph ret) { } } } - ret.attrs["storage_id"] = std::make_shared(std::move(storage)); + ret.attrs["storage_inplace_index"] = std::make_shared(std::move(storage_inplace_index)); ret.attrs["storage_allocated_bytes"] = std::make_shared(allocator.TotalAllocBytes()); ret.attrs["storage_num_not_allocated"] = std::make_shared(num_not_allocated); return ret; @@ -222,7 +225,8 @@ NNVM_REGISTER_PASS(PlanMemory) .set_change_graph(false) .depend_graph_attr("dtype") .depend_graph_attr("shape") -.provide_graph_attr("storage_id"); +.provide_graph_attr("storage_id") +.provide_graph_attr("storage_inplace_index"); } // namespace } // namespace pass diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc index bd22d807e1fe..984a2c9905c4 100644 --- a/nnvm/src/pass/saveload_json.cc +++ b/nnvm/src/pass/saveload_json.cc @@ -89,7 +89,7 @@ struct JSONNode { } void Load(dmlc::JSONReader *reader) { - node = std::move(Node::Create()); + node = Node::Create(); control_deps.clear(); dmlc::JSONObjectReadHelper helper; std::string op_type_str;