diff --git a/src/core/symbolic.cc b/src/core/symbolic.cc index 51cc1fa9e..0c621b55e 100644 --- a/src/core/symbolic.cc +++ b/src/core/symbolic.cc @@ -436,7 +436,6 @@ Symbol Symbol::GetInternals() const { } Symbol Symbol::GetChildren() const { - static auto& fnum_vis_output = Op::GetAttr("FNumVisibleOutputs"); Symbol ret; std::unordered_set visited; for (const auto& p : this->outputs) { diff --git a/src/pass/infer_shape_type.cc b/src/pass/infer_shape_type.cc index 75d3d150e..cef44f3b0 100644 --- a/src/pass/infer_shape_type.cc +++ b/src/pass/infer_shape_type.cc @@ -10,7 +10,6 @@ namespace nnvm { namespace pass { namespace { -// TODO(haibin) change file name to infer_attrs.cc template Graph InferAttr(Graph &&ret, const AttrType empty_val, @@ -20,7 +19,8 @@ Graph InferAttr(Graph &&ret, const char* attr_name, const char* unknown_name, IsNone fis_none, - FDefault fdefault) { + FDefault fdefault, + bool backward_identity_assign) { using AttrVector = std::vector; const IndexedGraph& idx = ret.indexed_graph(); static auto& finfer_shape = @@ -88,7 +88,8 @@ Graph InferAttr(Graph &&ret, CHECK(is >> rshape[out_ent_id]) << "Invalid attribute"; } } - } else if (is_backward.get(inode.source->op(), false) && inode.control_deps.size()) { + } else if (is_backward.get(inode.source->op(), false) && + inode.control_deps.size() && backward_identity_assign) { CHECK_GE(inode.control_deps.size(), 1U) << "BackwardOp need to have control_deps to its forward op"; const IndexedGraph::Node& fnode = idx[inode.control_deps[0]]; @@ -208,7 +209,7 @@ NNVM_REGISTER_PASS(InferShape) "FInferShape", "shape_inputs", "shape_attr_key", "shape", "shape_num_unknown_nodes", [](const TShape& s) { return s.ndim() == 0 || s.Size() == 0; }, - nullptr); + nullptr, true); }) .set_change_graph(false) .provide_graph_attr("shape"); @@ -241,16 +242,16 @@ inline bool SameType(const NodeAttrs& attrs, } // assigning default type N to both input and output attrs with value -1 -template +template inline bool DefaultType(const NodeAttrs& attrs, std::vector *iattr, std::vector *oattr) { // LOG(INFO) << "DefaultType " << N; for (int& v : *oattr) { - if (v == -1) v = N; + if (v == none) v = default_val; } for (int& v : *iattr) { - if (v == -1) v = N; + if (v == none) v = default_val; } return true; } @@ -258,12 +259,15 @@ inline bool DefaultType(const NodeAttrs& attrs, NNVM_REGISTER_PASS(InferStorageType) .describe("Infer the storage type of each node entries.") .set_body([](Graph ret) { + // for storage type, the backward attr is not necessarily the same as it's correspondence + const int none = -1; + const int kDefaultStorage = 0; return InferAttr( - std::move(ret), -1, + std::move(ret), none, "FInferStorageType", "storage_type_inputs", "storage_type_attr_key", "storage_type", "storage_type_num_unknown_nodes", - [](const int t) { return t == -1; }, - DefaultType<1>); + [](const int t) { return t == none; }, + DefaultType, false); }) .set_change_graph(false) .provide_graph_attr("storage_type"); @@ -276,7 +280,7 @@ NNVM_REGISTER_PASS(InferType) "FInferType", "dtype_inputs", "dtype_attr_key", "dtype", "dtype_num_unknown_nodes", [](const int t) { return t == -1; }, - SameType); + SameType, true); }) .set_change_graph(false) .provide_graph_attr("dtype"); diff --git a/tests/python/test_graph.py b/tests/python/test_graph.py index 59cc12fb3..0500ac091 100644 --- a/tests/python/test_graph.py +++ b/tests/python/test_graph.py @@ -114,7 +114,7 @@ def test_infer_storage_type(): jnodes = jgraph['nodes'] jnode_row_ptr = jgraph['node_row_ptr'] nindex = {n['name']: i for i, n in enumerate(jnodes)} - assert g.json_attr('storage_type')[jnode_row_ptr[nindex["add1"]]] == 1 + assert g.json_attr('storage_type')[jnode_row_ptr[nindex["add1"]]] == 0 def test_place_device(): x = sym.Variable('x', device_group="stage1")