Skip to content

Commit

Permalink
ApplyPass -> ApplyPasses; Refactored infer pass; (#43)
Browse files Browse the repository at this point in the history
* ApplyPass -> ApplyPasses; Refactored infer pass;

* lint fix
  • Loading branch information
jermainewang authored and tqchen committed May 29, 2018
1 parent 8b2d68f commit 2e938a1
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 49 deletions.
10 changes: 5 additions & 5 deletions nnvm/include/nnvm/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -329,16 +329,16 @@ NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle,
const char* key,
SymbolHandle list);
/*!
* \brief Apply pass on the src graph.
* \brief Apply passes on the src graph.
* \param src The source graph handle.
* \param num_pass The number of pass to be applied.
* \param pass_names The names of the pass.
* \param dst The result graph.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphApplyPass(GraphHandle src,
nn_uint num_pass,
const char** pass_names,
GraphHandle *dst);
NNVM_DLL int NNGraphApplyPasses(GraphHandle src,
nn_uint num_pass,
const char** pass_names,
GraphHandle *dst);

#endif // NNVM_C_API_H_
6 changes: 3 additions & 3 deletions nnvm/include/nnvm/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,11 @@ class IndexedGraph {
* \param other The source graph.
*/
explicit IndexedGraph(const Graph& other);
// node pointers in CSR structure.
// Node pointers in CSR structure.
std::vector<Node> nodes_;
// index all to input nodes
// Index to all input nodes.
std::vector<uint32_t> input_nodes_;
// index to mutable input nodes
// Index to all mutable input nodes.
std::unordered_set<uint32_t> mutable_input_nodes_;
// space to store the outputs entries
std::vector<NodeEntry> outputs_;
Expand Down
10 changes: 5 additions & 5 deletions nnvm/include/nnvm/graph_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace nnvm {
* \note Stored under ret.attrs["json"], provided by Pass "SaveJSON"
* \code
* Graph ret = ApplyPass(src_graph, {"SaveJSON"});
* Graph ret = ApplyPass(src_graph, "SaveJSON");
* const JSONString& json = ret.GetAttr<JSONString>("shape");
* \endcode
*/
Expand All @@ -29,7 +29,7 @@ using JSONString = std::string;
* \note Stored under graph.attrs["shape"], provided by Pass "InferShape"
*
* \code
* Graph g = ApplyPass(src_graph, {"InferShape"});
* Graph g = ApplyPass(src_graph, "InferShape");
* const ShapeVector& shapes = g.GetAttr<ShapeVector>("shape");
* // get shape by entry id
* TShape entry_shape = shapes[g.indexed_graph().entry_id(my_entry)];
Expand All @@ -44,7 +44,7 @@ using ShapeVector = std::vector<TShape>;
* \note Stored under graph.attrs["dtype"], provided by Pass "InferType"
*
* \code
* Graph g = ApplyPass(src_graph, {"InferType"});
* Graph g = ApplyPass(src_graph, "InferType");
* const DTypeVector& types = g.GetAttr<DTypeVector>("dtype");
* // get shape by entry id
* int entry_type = dtypes[g.indexed_graph().entry_id(my_entry)];
Expand All @@ -59,7 +59,7 @@ using DTypeVector = std::vector<int>;
* \note Stored under graph.attrs["device"], provided by Pass "PlaceDevice"
*
* \code
* Graph g = ApplyPass(src_graph, {"PlaceDevice"});
* Graph g = ApplyPass(src_graph, "PlaceDevice");
* const &device = g.GetAttr<DeviceVector>("device");
* // get device by node_id
* int device_type = device[g.indexed_graph().node_id(my_node)];
Expand All @@ -83,7 +83,7 @@ using DeviceAssignMap = std::unordered_map<std::string, int>;
* If the storage id is -1 then the storage is not assigned.
*
* \code
* Graph g = ApplyPass(src_graph, {"PlanMemory"});
* Graph g = ApplyPass(src_graph, "PlanMemory");
* const &storage = g.GetAttr<StorageVector>("storage");
* // get storage id by entry
* int storage_id = storage[g.indexed_graph().entry_id(my_entry)];
Expand Down
15 changes: 13 additions & 2 deletions nnvm/include/nnvm/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,22 @@ typedef std::function<Graph (Graph src)> PassFunction;
/*!
* \brief Apply a series of pass transformations on the input graph.
* \param src The graph to be transformed.
* \param passes A list of pass names to be applied.
* \return The transformed graph
*/
Graph ApplyPasses(Graph src,
const std::vector<std::string>& passes);

/*!
* \brief Apply one pass to the graph.
* \param src The graph to be transformed.
* \param pass The name of pass to be applied.
* \return The transformed graph.
*/
Graph ApplyPass(Graph src,
const std::vector<std::string>& pass);
inline Graph ApplyPass(Graph src, const std::string& pass) {
return ApplyPasses(src, {pass});
}


/*!
* \brief Registry entry for DataIterator factory functions.
Expand Down
14 changes: 7 additions & 7 deletions nnvm/include/nnvm/pass_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace pass {
inline Graph LoadJSON(const std::string& json_str) {
Graph ret;
ret.attrs["json"] = std::make_shared<any>(json_str);
return ApplyPass(ret, {"LoadJSON"});
return ApplyPass(ret, "LoadJSON");
}

/*!
Expand All @@ -37,7 +37,7 @@ inline Graph LoadJSON(const std::string& json_str) {
* \return The json string.
*/
inline std::string SaveJSON(Graph graph) {
Graph ret = ApplyPass(std::move(graph), {"SaveJSON"});
Graph ret = ApplyPass(std::move(graph), "SaveJSON");
return ret.GetAttr<std::string>("json");
}

Expand All @@ -52,7 +52,7 @@ inline std::string SaveJSON(Graph graph) {
* \return A graph with proper control flow dependencies added.
*/
inline Graph OrderMutation(Graph src) {
return ApplyPass(std::move(src), {"OrderMutation"});
return ApplyPass(std::move(src), "OrderMutation");
}

/*!
Expand All @@ -73,7 +73,7 @@ inline Graph InferShape(Graph graph,
if (shape_attr_key.length() != 0) {
graph.attrs["shape_attr_key"] = std::make_shared<any>(std::move(shape_attr_key));
}
return ApplyPass(std::move(graph), {"InferShape"});
return ApplyPass(std::move(graph), "InferShape");
}

/*!
Expand All @@ -94,7 +94,7 @@ inline Graph InferType(Graph graph,
if (dtype_attr_key.length() != 0) {
graph.attrs["dtype_attr_key"] = std::make_shared<any>(std::move(dtype_attr_key));
}
return ApplyPass(std::move(graph), {"InferType"});
return ApplyPass(std::move(graph), "InferType");
}

/*!
Expand All @@ -118,7 +118,7 @@ inline Graph PlaceDevice(Graph graph,
graph.attrs["device_group_attr_key"] = std::make_shared<any>(std::move(device_group_attr_key));
graph.attrs["device_assign_map"] = std::make_shared<any>(std::move(device_assign_map));
graph.attrs["device_copy_op"] = std::make_shared<any>(std::move(device_copy_op));
return ApplyPass(std::move(graph), {"PlaceDevice"});
return ApplyPass(std::move(graph), "PlaceDevice");
}

/*!
Expand Down Expand Up @@ -149,7 +149,7 @@ inline Graph Gradient(
graph.attrs["grad_mirror_fun"] = std::make_shared<any>(mirror_fun);
}

return ApplyPass(std::move(graph), {"Gradient"});
return ApplyPass(std::move(graph), "Gradient");
}

} // namespace pass
Expand Down
2 changes: 1 addition & 1 deletion nnvm/python/nnvm/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def apply(self, passes):
cpass = c_array(ctypes.c_char_p, [c_str(key) for key in passes])
ghandle = GraphHandle()
npass = nn_uint(len(passes))
check_call(_LIB.NNGraphApplyPass(self.handle, npass, cpass, ctypes.byref(ghandle)))
check_call(_LIB.NNGraphApplyPasses(self.handle, npass, cpass, ctypes.byref(ghandle)))
return Graph(ghandle)


Expand Down
10 changes: 5 additions & 5 deletions nnvm/src/c_api/c_api_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,17 @@ int NNGraphGetJSONAttr(GraphHandle handle,
API_END();
}

int NNGraphApplyPass(GraphHandle src,
nn_uint num_pass,
const char** pass_names,
GraphHandle *dst) {
int NNGraphApplyPasses(GraphHandle src,
nn_uint num_pass,
const char** pass_names,
GraphHandle *dst) {
Graph* g = new Graph();
API_BEGIN();
std::vector<std::string> vpass;
for (nn_uint i = 0; i < num_pass; ++i) {
vpass.emplace_back(std::string(pass_names[i]));
}
*g = ApplyPass(*static_cast<Graph*>(src), vpass);
*g = ApplyPasses(*static_cast<Graph*>(src), vpass);
*dst = g;
API_END_HANDLE_ERROR(delete g);
}
4 changes: 2 additions & 2 deletions nnvm/src/core/pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ const PassFunctionReg* FindPassDep(const std::string&attr_name) {
return nullptr;
}

Graph ApplyPass(Graph g,
const std::vector<std::string>& pass) {
Graph ApplyPasses(Graph g,
const std::vector<std::string>& pass) {
std::vector<const PassFunctionReg*> fpass;
for (auto& name : pass) {
auto* reg = dmlc::Registry<PassFunctionReg>::Find(name);
Expand Down
45 changes: 26 additions & 19 deletions nnvm/src/pass/infer_shape_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace {

template<typename AttrType, typename IsNone>
Graph InferAttr(Graph &&ret,
const AttrType def_value,
const AttrType default_val,
const char* infer_name,
const char* input_name,
const char* attr_key_name,
Expand All @@ -23,16 +23,16 @@ Graph InferAttr(Graph &&ret,
using AttrVector = std::vector<AttrType>;
const IndexedGraph& idx = ret.indexed_graph();
static auto& finfer_shape =
Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
Op::GetAttr<FInferNodeEntryAttr<AttrType>>(infer_name);
static auto& backward_map =
Op::GetAttr<FBackwardOutToInIndex>("FBackwardOutToInIndex");
// reshape shape vector
AttrVector rshape(idx.num_node_entries(), def_value);
AttrVector rshape(idx.num_node_entries(), default_val);

if (ret.attrs.count(input_name) != 0) {
const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name);
CHECK_LE(shape_args.size(), idx.input_nodes().size())
<< "shape args is more than number of arguments";
<< "More provided shapes than number of arguments.";
for (size_t i = 0; i < shape_args.size(); ++i) {
rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i];
}
Expand All @@ -46,47 +46,54 @@ Graph InferAttr(Graph &&ret,
ret.attrs.erase(attr_key_name);
}

// temp space for shape inference.
// Temp space for shape inference.
std::vector<AttrType> ishape, oshape;
// number of completed nodes
size_t num_unknown = 0;
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
uint32_t num_inputs = inode.inputs.size();
uint32_t num_outputs = inode.source->num_outputs();
const uint32_t num_inputs = inode.inputs.size();
const uint32_t num_outputs = inode.source->num_outputs();
if (inode.source->is_variable()) {
if (shape_attr_key.length() != 0 && fis_none(rshape[idx.entry_id(nid, 0)])) {
// Variable node. No operator. Only one output entry.
CHECK(inode.source->op() == nullptr);
CHECK_EQ(num_outputs, 1);
const uint32_t out_ent_id = idx.entry_id(nid, 0);
if (shape_attr_key.length() != 0 && fis_none(rshape[out_ent_id])) {
auto it = inode.source->attrs.dict.find(shape_attr_key);
if (it != inode.source->attrs.dict.end()) {
CHECK_EQ(num_outputs, 1);
std::istringstream is(it->second);
CHECK(is >> rshape[idx.entry_id(nid, 0)]) << "Invalid attribute";
CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
}
}
continue;
}
if (finfer_shape.count(inode.source->op())) {
ishape.resize(num_inputs, def_value);
} else if (finfer_shape.count(inode.source->op())) {
// Forward operator inference.
ishape.resize(num_inputs, default_val);
for (uint32_t i = 0; i < ishape.size(); ++i) {
ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
}
oshape.resize(num_outputs, def_value);
oshape.resize(num_outputs, default_val);
for (uint32_t i = 0; i < oshape.size(); ++i) {
oshape[i] = rshape[idx.entry_id(nid, i)];
}
num_unknown +=
!(finfer_shape[inode.source->op()](inode.source->attrs, &ishape, &oshape));
// Call inference function of the operator.
bool forward_known = finfer_shape[inode.source->op()](
inode.source->attrs, &ishape, &oshape);
num_unknown += !forward_known;
// Save to the result map.
for (uint32_t i = 0; i < num_inputs; ++i) {
rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
}
for (uint32_t i = 0; i < num_outputs; ++i) {
rshape[idx.entry_id(nid, i)] = oshape[i];
}
} else if (backward_map.count(inode.source->op())) {
// backward operator inference.
// Backward operator inference.
CHECK_GE(inode.control_deps.size(), 1)
<< "BackwardOp need to have control_deps to its forward op";
const auto& fnode = idx[inode.control_deps[0]];
const IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
// Inference the outputs of backward operator (equal to the inputs
// of its corresponding forward operator).
std::vector<uint32_t> out_map =
backward_map[inode.source->op()](inode.source->attrs);
bool known = true;
Expand Down

0 comments on commit 2e938a1

Please sign in to comment.