Skip to content
This repository has been archived by the owner on Feb 1, 2020. It is now read-only.

Commit

Permalink
force infer storage on backward pass (#1) (#113)
Browse files Browse the repository at this point in the history
* force infer storage on backward pass (#1)

* force infer storage type for backward pass

* change none value for storage type to 0(kUndefinedStorage)

* Modify default value for storage type

* update unit test for infer storage
  • Loading branch information
eric-haibin-lin authored and piiswrong committed May 15, 2017
1 parent 319a8dd commit 7c603eb
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 13 deletions.
1 change: 0 additions & 1 deletion src/core/symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,6 @@ Symbol Symbol::GetInternals() const {
}

Symbol Symbol::GetChildren() const {
static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs");
Symbol ret;
std::unordered_set<Node*> visited;
for (const auto& p : this->outputs) {
Expand Down
26 changes: 15 additions & 11 deletions src/pass/infer_shape_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
namespace nnvm {
namespace pass {
namespace {
// TODO(haibin) change file name to infer_attrs.cc
template<typename AttrType, typename IsNone, typename FDefault>
Graph InferAttr(Graph &&ret,
const AttrType empty_val,
Expand All @@ -20,7 +19,8 @@ Graph InferAttr(Graph &&ret,
const char* attr_name,
const char* unknown_name,
IsNone fis_none,
FDefault fdefault) {
FDefault fdefault,
bool backward_identity_assign) {
using AttrVector = std::vector<AttrType>;
const IndexedGraph& idx = ret.indexed_graph();
static auto& finfer_shape =
Expand Down Expand Up @@ -88,7 +88,8 @@ Graph InferAttr(Graph &&ret,
CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
}
}
} else if (is_backward.get(inode.source->op(), false) && inode.control_deps.size()) {
} else if (is_backward.get(inode.source->op(), false) &&
inode.control_deps.size() && backward_identity_assign) {
CHECK_GE(inode.control_deps.size(), 1U)
<< "BackwardOp need to have control_deps to its forward op";
const IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
Expand Down Expand Up @@ -208,7 +209,7 @@ NNVM_REGISTER_PASS(InferShape)
"FInferShape", "shape_inputs", "shape_attr_key",
"shape", "shape_num_unknown_nodes",
[](const TShape& s) { return s.ndim() == 0 || s.Size() == 0; },
nullptr);
nullptr, true);
})
.set_change_graph(false)
.provide_graph_attr("shape");
Expand Down Expand Up @@ -241,29 +242,32 @@ inline bool SameType(const NodeAttrs& attrs,
}

// assigning default type N to both input and output attrs with value -1
template <int N>
template <int default_val, int none>
inline bool DefaultType(const NodeAttrs& attrs,
std::vector<int> *iattr,
std::vector<int> *oattr) {
// LOG(INFO) << "DefaultType " << N;
for (int& v : *oattr) {
if (v == -1) v = N;
if (v == none) v = default_val;
}
for (int& v : *iattr) {
if (v == -1) v = N;
if (v == none) v = default_val;
}
return true;
}

NNVM_REGISTER_PASS(InferStorageType)
.describe("Infer the storage type of each node entries.")
.set_body([](Graph ret) {
// for storage type, the backward attr is not necessarily the same as it's correspondence
const int none = -1;
const int kDefaultStorage = 0;
return InferAttr<int>(
std::move(ret), -1,
std::move(ret), none,
"FInferStorageType", "storage_type_inputs", "storage_type_attr_key",
"storage_type", "storage_type_num_unknown_nodes",
[](const int t) { return t == -1; },
DefaultType<1>);
[](const int t) { return t == none; },
DefaultType<kDefaultStorage, none>, false);
})
.set_change_graph(false)
.provide_graph_attr("storage_type");
Expand All @@ -276,7 +280,7 @@ NNVM_REGISTER_PASS(InferType)
"FInferType", "dtype_inputs", "dtype_attr_key",
"dtype", "dtype_num_unknown_nodes",
[](const int t) { return t == -1; },
SameType);
SameType, true);
})
.set_change_graph(false)
.provide_graph_attr("dtype");
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_infer_storage_type():
jnodes = jgraph['nodes']
jnode_row_ptr = jgraph['node_row_ptr']
nindex = {n['name']: i for i, n in enumerate(jnodes)}
assert g.json_attr('storage_type')[jnode_row_ptr[nindex["add1"]]] == 1
assert g.json_attr('storage_type')[jnode_row_ptr[nindex["add1"]]] == 0

def test_place_device():
x = sym.Variable('x', device_group="stage1")
Expand Down

0 comments on commit 7c603eb

Please sign in to comment.