diff --git a/nnvm/include/nnvm/c_api.h b/nnvm/include/nnvm/c_api.h index 3122e26b7038..7f12561e9cdd 100644 --- a/nnvm/include/nnvm/c_api.h +++ b/nnvm/include/nnvm/c_api.h @@ -329,16 +329,16 @@ NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle, const char* key, SymbolHandle list); /*! - * \brief Apply pass on the src graph. + * \brief Apply passes on the src graph. * \param src The source graph handle. * \param num_pass The number of pass to be applied. * \param pass_names The names of the pass. * \param dst The result graph. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNGraphApplyPass(GraphHandle src, - nn_uint num_pass, - const char** pass_names, - GraphHandle *dst); +NNVM_DLL int NNGraphApplyPasses(GraphHandle src, + nn_uint num_pass, + const char** pass_names, + GraphHandle *dst); #endif // NNVM_C_API_H_ diff --git a/nnvm/include/nnvm/graph.h b/nnvm/include/nnvm/graph.h index 9c08467218e9..19abfa406c92 100644 --- a/nnvm/include/nnvm/graph.h +++ b/nnvm/include/nnvm/graph.h @@ -179,11 +179,11 @@ class IndexedGraph { * \param other The source graph. */ explicit IndexedGraph(const Graph& other); - // node pointers in CSR structure. + // Node pointers in CSR structure. std::vector nodes_; - // index all to input nodes + // Index to all input nodes. std::vector input_nodes_; - // index to mutable input nodes + // Index to all mutable input nodes. 
std::unordered_set mutable_input_nodes_; // space to store the outputs entries std::vector outputs_; diff --git a/nnvm/include/nnvm/graph_attr_types.h b/nnvm/include/nnvm/graph_attr_types.h index 6ac12e72d0fa..64894ec58de7 100644 --- a/nnvm/include/nnvm/graph_attr_types.h +++ b/nnvm/include/nnvm/graph_attr_types.h @@ -18,7 +18,7 @@ namespace nnvm { * \note Stored under ret.attrs["json"], provided by Pass "SaveJSON" * \code - * Graph ret = ApplyPass(src_graph, {"SaveJSON"}); + * Graph ret = ApplyPass(src_graph, "SaveJSON"); * const JSONString& json = ret.GetAttr("shape"); * \endcode */ @@ -29,7 +29,7 @@ using JSONString = std::string; * \note Stored under graph.attrs["shape"], provided by Pass "InferShape" * * \code - * Graph g = ApplyPass(src_graph, {"InferShape"}); + * Graph g = ApplyPass(src_graph, "InferShape"); * const ShapeVector& shapes = g.GetAttr("shape"); * // get shape by entry id * TShape entry_shape = shapes[g.indexed_graph().entry_id(my_entry)]; @@ -44,7 +44,7 @@ using ShapeVector = std::vector; * \note Stored under graph.attrs["dtype"], provided by Pass "InferType" * * \code - * Graph g = ApplyPass(src_graph, {"InferType"}); + * Graph g = ApplyPass(src_graph, "InferType"); * const DTypeVector& types = g.GetAttr("dtype"); * // get shape by entry id * int entry_type = dtypes[g.indexed_graph().entry_id(my_entry)]; @@ -59,7 +59,7 @@ using DTypeVector = std::vector; * \note Stored under graph.attrs["device"], provided by Pass "PlaceDevice" * * \code - * Graph g = ApplyPass(src_graph, {"PlaceDevice"}); + * Graph g = ApplyPass(src_graph, "PlaceDevice"); * const &device = g.GetAttr("device"); * // get device by node_id * int device_type = device[g.indexed_graph().node_id(my_node)]; @@ -83,7 +83,7 @@ using DeviceAssignMap = std::unordered_map; * If the storage id is -1 then the storage is not assigned. 
* * \code - * Graph g = ApplyPass(src_graph, {"PlanMemory"}); + * Graph g = ApplyPass(src_graph, "PlanMemory"); * const &storage = g.GetAttr("storage"); * // get storage id by entry * int storage_id = storage[g.indexed_graph().entry_id(my_entry)]; diff --git a/nnvm/include/nnvm/pass.h b/nnvm/include/nnvm/pass.h index b743d2d5eba0..8b731bbd0f9f 100644 --- a/nnvm/include/nnvm/pass.h +++ b/nnvm/include/nnvm/pass.h @@ -29,11 +29,22 @@ typedef std::function PassFunction; /*! * \brief Apply a series of pass transformations on the input graph. * \param src The graph to be transformed. + * \param passes A list of pass names to be applied. + * \return The transformed graph. + */ +Graph ApplyPasses(Graph src, + const std::vector& passes); + +/*! + * \brief Apply one pass to the graph. + * \param src The graph to be transformed. * \param pass The name of pass to be applied. * \return The transformed graph. */ -Graph ApplyPass(Graph src, - const std::vector& pass); +inline Graph ApplyPass(Graph src, const std::string& pass) { + return ApplyPasses(std::move(src), {pass}); +} + /*! * \brief Registry entry for DataIterator factory functions. diff --git a/nnvm/include/nnvm/pass_functions.h b/nnvm/include/nnvm/pass_functions.h index b94528496e51..68125ebf24bd 100644 --- a/nnvm/include/nnvm/pass_functions.h +++ b/nnvm/include/nnvm/pass_functions.h @@ -28,7 +28,7 @@ namespace pass { inline Graph LoadJSON(const std::string& json_str) { Graph ret; ret.attrs["json"] = std::make_shared(json_str); - return ApplyPass(ret, {"LoadJSON"}); + return ApplyPass(ret, "LoadJSON"); } /*! @@ -37,7 +37,7 @@ inline Graph LoadJSON(const std::string& json_str) { * \return The json string. */ inline std::string SaveJSON(Graph graph) { - Graph ret = ApplyPass(std::move(graph), {"SaveJSON"}); + Graph ret = ApplyPass(std::move(graph), "SaveJSON"); return ret.GetAttr("json"); } @@ -52,7 +52,7 @@ inline std::string SaveJSON(Graph graph) { * \return A graph with proper control flow dependencies added. 
*/ inline Graph OrderMutation(Graph src) { - return ApplyPass(std::move(src), {"OrderMutation"}); + return ApplyPass(std::move(src), "OrderMutation"); } /*! @@ -73,7 +73,7 @@ inline Graph InferShape(Graph graph, if (shape_attr_key.length() != 0) { graph.attrs["shape_attr_key"] = std::make_shared(std::move(shape_attr_key)); } - return ApplyPass(std::move(graph), {"InferShape"}); + return ApplyPass(std::move(graph), "InferShape"); } /*! @@ -94,7 +94,7 @@ inline Graph InferType(Graph graph, if (dtype_attr_key.length() != 0) { graph.attrs["dtype_attr_key"] = std::make_shared(std::move(dtype_attr_key)); } - return ApplyPass(std::move(graph), {"InferType"}); + return ApplyPass(std::move(graph), "InferType"); } /*! @@ -118,7 +118,7 @@ inline Graph PlaceDevice(Graph graph, graph.attrs["device_group_attr_key"] = std::make_shared(std::move(device_group_attr_key)); graph.attrs["device_assign_map"] = std::make_shared(std::move(device_assign_map)); graph.attrs["device_copy_op"] = std::make_shared(std::move(device_copy_op)); - return ApplyPass(std::move(graph), {"PlaceDevice"}); + return ApplyPass(std::move(graph), "PlaceDevice"); } /*! 
@@ -149,7 +149,7 @@ inline Graph Gradient( graph.attrs["grad_mirror_fun"] = std::make_shared(mirror_fun); } - return ApplyPass(std::move(graph), {"Gradient"}); + return ApplyPass(std::move(graph), "Gradient"); } } // namespace pass diff --git a/nnvm/python/nnvm/graph.py b/nnvm/python/nnvm/graph.py index e3e857eecbf3..0f2997d4bf3f 100644 --- a/nnvm/python/nnvm/graph.py +++ b/nnvm/python/nnvm/graph.py @@ -113,7 +113,7 @@ def apply(self, passes): cpass = c_array(ctypes.c_char_p, [c_str(key) for key in passes]) ghandle = GraphHandle() npass = nn_uint(len(passes)) - check_call(_LIB.NNGraphApplyPass(self.handle, npass, cpass, ctypes.byref(ghandle))) + check_call(_LIB.NNGraphApplyPasses(self.handle, npass, cpass, ctypes.byref(ghandle))) return Graph(ghandle) diff --git a/nnvm/src/c_api/c_api_graph.cc b/nnvm/src/c_api/c_api_graph.cc index d3dd1d3e49aa..831aaec33e8c 100644 --- a/nnvm/src/c_api/c_api_graph.cc +++ b/nnvm/src/c_api/c_api_graph.cc @@ -82,17 +82,17 @@ int NNGraphGetJSONAttr(GraphHandle handle, API_END(); } -int NNGraphApplyPass(GraphHandle src, - nn_uint num_pass, - const char** pass_names, - GraphHandle *dst) { +int NNGraphApplyPasses(GraphHandle src, + nn_uint num_pass, + const char** pass_names, + GraphHandle *dst) { Graph* g = new Graph(); API_BEGIN(); std::vector vpass; for (nn_uint i = 0; i < num_pass; ++i) { vpass.emplace_back(std::string(pass_names[i])); } - *g = ApplyPass(*static_cast(src), vpass); + *g = ApplyPasses(*static_cast(src), vpass); *dst = g; API_END_HANDLE_ERROR(delete g); } diff --git a/nnvm/src/core/pass.cc b/nnvm/src/core/pass.cc index 5c4aeb2e0232..d72d4af00e65 100644 --- a/nnvm/src/core/pass.cc +++ b/nnvm/src/core/pass.cc @@ -22,8 +22,8 @@ const PassFunctionReg* FindPassDep(const std::string&attr_name) { return nullptr; } -Graph ApplyPass(Graph g, - const std::vector& pass) { +Graph ApplyPasses(Graph g, + const std::vector& pass) { std::vector fpass; for (auto& name : pass) { auto* reg = dmlc::Registry::Find(name); diff --git 
a/nnvm/src/pass/infer_shape_type.cc b/nnvm/src/pass/infer_shape_type.cc index bb50e98b5ede..3f9eca0a8c4b 100644 --- a/nnvm/src/pass/infer_shape_type.cc +++ b/nnvm/src/pass/infer_shape_type.cc @@ -13,7 +13,7 @@ namespace { template Graph InferAttr(Graph &&ret, - const AttrType def_value, + const AttrType default_val, const char* infer_name, const char* input_name, const char* attr_key_name, @@ -23,16 +23,16 @@ Graph InferAttr(Graph &&ret, using AttrVector = std::vector; const IndexedGraph& idx = ret.indexed_graph(); static auto& finfer_shape = - Op::GetAttr >(infer_name); + Op::GetAttr>(infer_name); static auto& backward_map = Op::GetAttr("FBackwardOutToInIndex"); // reshape shape vector - AttrVector rshape(idx.num_node_entries(), def_value); + AttrVector rshape(idx.num_node_entries(), default_val); if (ret.attrs.count(input_name) != 0) { const AttrVector& shape_args = ret.GetAttr(input_name); CHECK_LE(shape_args.size(), idx.input_nodes().size()) - << "shape args is more than number of arguments"; + << "More provided shapes than number of arguments."; for (size_t i = 0; i < shape_args.size(); ++i) { rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i]; } @@ -46,36 +46,41 @@ Graph InferAttr(Graph &&ret, ret.attrs.erase(attr_key_name); } - // temp space for shape inference. + // Temp space for shape inference. std::vector ishape, oshape; // number of completed nodes size_t num_unknown = 0; for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { const auto& inode = idx[nid]; - uint32_t num_inputs = inode.inputs.size(); - uint32_t num_outputs = inode.source->num_outputs(); + const uint32_t num_inputs = inode.inputs.size(); + const uint32_t num_outputs = inode.source->num_outputs(); if (inode.source->is_variable()) { - if (shape_attr_key.length() != 0 && fis_none(rshape[idx.entry_id(nid, 0)])) { + // Variable node. No operator. Only one output entry. 
+ CHECK(inode.source->op() == nullptr); + CHECK_EQ(num_outputs, 1); + const uint32_t out_ent_id = idx.entry_id(nid, 0); + if (shape_attr_key.length() != 0 && fis_none(rshape[out_ent_id])) { auto it = inode.source->attrs.dict.find(shape_attr_key); if (it != inode.source->attrs.dict.end()) { - CHECK_EQ(num_outputs, 1); std::istringstream is(it->second); - CHECK(is >> rshape[idx.entry_id(nid, 0)]) << "Invalid attribute"; + CHECK(is >> rshape[out_ent_id]) << "Invalid attribute"; } } - continue; - } - if (finfer_shape.count(inode.source->op())) { - ishape.resize(num_inputs, def_value); + } else if (finfer_shape.count(inode.source->op())) { + // Forward operator inference. + ishape.resize(num_inputs, default_val); for (uint32_t i = 0; i < ishape.size(); ++i) { ishape[i] = rshape[idx.entry_id(inode.inputs[i])]; } - oshape.resize(num_outputs, def_value); + oshape.resize(num_outputs, default_val); for (uint32_t i = 0; i < oshape.size(); ++i) { oshape[i] = rshape[idx.entry_id(nid, i)]; } - num_unknown += - !(finfer_shape[inode.source->op()](inode.source->attrs, &ishape, &oshape)); + // Call inference function of the operator. + bool forward_known = finfer_shape[inode.source->op()]( + inode.source->attrs, &ishape, &oshape); + num_unknown += !forward_known; + // Save to the result map. for (uint32_t i = 0; i < num_inputs; ++i) { rshape[idx.entry_id(inode.inputs[i])] = ishape[i]; } @@ -83,10 +88,12 @@ Graph InferAttr(Graph &&ret, rshape[idx.entry_id(nid, i)] = oshape[i]; } } else if (backward_map.count(inode.source->op())) { - // backward operator inference. + // Backward operator inference. CHECK_GE(inode.control_deps.size(), 1) << "BackwardOp need to have control_deps to its forward op"; - const auto& fnode = idx[inode.control_deps[0]]; + const IndexedGraph::Node& fnode = idx[inode.control_deps[0]]; + // Infer the outputs of the backward operator (equal to the inputs + // of its corresponding forward operator). 
std::vector out_map = backward_map[inode.source->op()](inode.source->attrs); bool known = true;