Skip to content

Commit

Permalink
[PASS] SimplifyBatchNorm->SimplifyInference, remove dropout (apache#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed May 29, 2018
1 parent 215693d commit 40bc10f
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 11 deletions.
11 changes: 4 additions & 7 deletions nnvm/python/nnvm/compiler/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .. import runtime

OPT_PASS_LEVEL = {
"SimplifyBatchNormInference": 2,
"SimplifyInference": 2,
"PrecomputePrune": 2,
"OpFusion": 1
}
Expand Down Expand Up @@ -115,12 +115,9 @@ def optimize(graph, shape, dtype="float32"):
"""
# pylint: disable=unused-argument
cfg = BuildConfig.current
graph = graph_attr.set_shape_inputs(graph, shape)
graph = graph.apply("InferShape")
if graph.json_attr("shape_num_unknown_nodes"):
raise ValueError("InferShape fails..")
if cfg.opt_level >= OPT_PASS_LEVEL["SimplifyBatchNormInference"]:
graph = graph.apply("SimplifyBatchNormInference")
if cfg.opt_level >= OPT_PASS_LEVEL["SimplifyInference"]:
graph = graph_attr.set_shape_inputs(graph, shape)
graph = graph.apply(["InferShape", "SimplifyInference"])
return graph


Expand Down
5 changes: 5 additions & 0 deletions nnvm/python/nnvm/top/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def _compute(attrs, x, _):

_fschedule_broadcast = tvm.convert(_schedule_broadcast)

# copy
reg.register_compute("copy", _compute_unary(topi.identity))
reg.register_pattern("copy", OpPattern.ELEM_WISE)
reg.register_schedule("copy", _fschedule_broadcast)

# exp
reg.register_compute("exp", _compute_unary(topi.exp))
reg.register_pattern("exp", OpPattern.ELEM_WISE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
nnvm::NodeEntry moving_mean,
nnvm::NodeEntry moving_var,
TShape dshape) {
CHECK_NE(dshape.ndim(), 0);
CHECK(attrs.op);
static const Op* bn_op = Op::Get("batch_norm");
CHECK(attrs.op == bn_op);
Expand Down Expand Up @@ -76,13 +77,14 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
return {out, undef, undef};
}

Graph SimplifyBatchNormInference(nnvm::Graph src) {
Graph SimplifyInference(nnvm::Graph src) {
// Get attributes from the graph
const IndexedGraph& idx = src.indexed_graph();
const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
auto transform = [&](uint32_t nid, const Node* n, std::vector<NodeEntry>* ret) {
if (n->is_variable()) return false;
static const Op* bn_op = Op::Get("batch_norm");
static const Op* dropout_op = Op::Get("dropout");
if (n->op() == bn_op) {
*ret = BatchNormToInferUnpack(
n->attrs,
Expand All @@ -93,15 +95,19 @@ Graph SimplifyBatchNormInference(nnvm::Graph src) {
n->inputs[4],
shape_vec[idx.entry_id(nid, 0)]);
return true;
} else if (n->op() == dropout_op) {
NodeEntry undef = MakeNode("__undef__", "undef", {});
*ret = {n->inputs[0], undef};
return true;
} else {
return false;
}
};
return GraphTransform(src, transform);
}

NNVM_REGISTER_PASS(SimplifyBatchNormInference)
.set_body(SimplifyBatchNormInference);
NNVM_REGISTER_PASS(SimplifyInference)
.set_body(SimplifyInference);

} // namespace compiler
} // namespace nnvm
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ def check(dim, axis, nstep):
for i in range(nstep):
y1 = sym.batch_norm(
y1 + 1, gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis)
y1 = sym.dropout(y1)
y2 = simple_bn(y2 + 1, gamma, beta, moving_mean, moving_var,
epsilon=eps, axis=axis, shape=ishape["x"])
g = nnvm.graph.create(y1)
g2 = nnvm.graph.create(y2)
graph_attr.set_shape_inputs(g, ishape)
g1 = g.apply("InferShape").apply("SimplifyBatchNormInference")
g1 = g.apply("InferShape").apply("SimplifyInference")
# Some prints for debug
# print(g1.ir())
# assert graph equals as expected
Expand Down

0 comments on commit 40bc10f

Please sign in to comment.