Skip to content

Commit

Permalink
updates (apache#25)
Browse files Browse the repository at this point in the history
* [FIX] Remove extra move

* [MEMORY] Add inplace index
  • Loading branch information
tqchen committed May 26, 2018
1 parent e83aa14 commit 2db0d3a
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 11 deletions.
8 changes: 4 additions & 4 deletions nnvm/src/c_api/c_api_symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ int NNSymbolListAttrs(SymbolHandle symbol,
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
std::unordered_map<std::string, std::string> attr =
std::move(s->ListAttrs(static_cast<Symbol::ListAttrOption>(option))); // NOLINT(*)
s->ListAttrs(static_cast<Symbol::ListAttrOption>(option)); // NOLINT(*)

std::vector<std::string>& attr_list = ret->ret_vec_str;
attr_list.clear();
Expand All @@ -184,8 +184,8 @@ int NNSymbolListInputNames(SymbolHandle symbol,
Symbol *s = static_cast<Symbol*>(symbol);
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
ret->ret_vec_str = std::move(
s->ListInputNames(Symbol::ListInputOption(option)));
ret->ret_vec_str =
s->ListInputNames(Symbol::ListInputOption(option));
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
Expand All @@ -201,7 +201,7 @@ int NNSymbolListOutputNames(SymbolHandle symbol,
Symbol *s = static_cast<Symbol*>(symbol);
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
ret->ret_vec_str = std::move(s->ListOutputNames());
ret->ret_vec_str = s->ListOutputNames();
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
Expand Down
5 changes: 1 addition & 4 deletions nnvm/src/pass/order_mutation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@ inline T get_with_default(const std::unordered_map<Node*, T> &map,
}

inline bool IsMutate(const std::vector<uint32_t>& mutate_inputs, uint32_t i) {
if (mutate_inputs.size() == 0) return false;
auto it = std::lower_bound(
mutate_inputs.begin(), mutate_inputs.end(), i);
return (it != mutate_inputs.end()) && (*it == i);
return std::binary_search(mutate_inputs.begin(), mutate_inputs.end(), i);
}

Graph OrderMutation(const Graph& src) {
Expand Down
8 changes: 6 additions & 2 deletions nnvm/src/pass/plan_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ Graph PlanMemory(Graph ret) {
}
// step 2: allocate memory.
StorageVector storage(idx.num_node_entries(), -1);
std::vector<int> storage_inplace_index(idx.num_node_entries(), -1);
const ShapeVector& shape_vec = ret.GetAttr<ShapeVector>("shape");
const DTypeVector& dtype_vec = ret.GetAttr<DTypeVector>("dtype");
const DeviceVector* device_vec = nullptr;
Expand All @@ -173,8 +174,10 @@ Graph PlanMemory(Graph ret) {
uint32_t eid_out = idx.entry_id(nid, kv.second);
uint32_t eid_in = idx.entry_id(inode.inputs[kv.first]);
if (ref_count[eid_in] == 1 && storage[eid_in] != GraphAllocator::kBadStorageID) {
// inplace optimization
storage[eid_out] = storage[eid_in];
ref_count[eid_in] = 0;
storage_inplace_index[eid_out] = kv.first;
}
}
}
Expand Down Expand Up @@ -209,8 +212,8 @@ Graph PlanMemory(Graph ret) {
}
}
}

ret.attrs["storage_id"] = std::make_shared<any>(std::move(storage));
ret.attrs["storage_inplace_index"] = std::make_shared<any>(std::move(storage_inplace_index));
ret.attrs["storage_allocated_bytes"] = std::make_shared<any>(allocator.TotalAllocBytes());
ret.attrs["storage_num_not_allocated"] = std::make_shared<any>(num_not_allocated);
return ret;
Expand All @@ -222,7 +225,8 @@ NNVM_REGISTER_PASS(PlanMemory)
.set_change_graph(false)
.depend_graph_attr("dtype")
.depend_graph_attr("shape")
.provide_graph_attr("storage_id");
.provide_graph_attr("storage_id")
.provide_graph_attr("storage_inplace_index");

} // namespace
} // namespace pass
Expand Down
2 changes: 1 addition & 1 deletion nnvm/src/pass/saveload_json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ struct JSONNode {
}

void Load(dmlc::JSONReader *reader) {
node = std::move(Node::Create());
node = Node::Create();
control_deps.clear();
dmlc::JSONObjectReadHelper helper;
std::string op_type_str;
Expand Down

0 comments on commit 2db0d3a

Please sign in to comment.