From 646d88502dac30ab331a431d2b6effc45899d810 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 25 Apr 2019 19:14:53 +0900 Subject: [PATCH 01/11] add post process analysis for tuple fusion --- include/tvm/relay/op_attr_types.h | 9 ++++-- src/relay/pass/fuse_ops.cc | 50 +++++++++++++++++-------------- 2 files changed, 34 insertions(+), 25 deletions(-) diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h index 464bc1cc0b64..46f5721c2f73 100644 --- a/include/tvm/relay/op_attr_types.h +++ b/include/tvm/relay/op_attr_types.h @@ -43,12 +43,15 @@ enum OpPatternKind { kBroadcast = 1, // Injective operator, can always injectively map output axis to a single input axis. // All injective operator can still be safely fused to injective and reduction. - kInjective = 2, + kTuple = 2, + kInjective = 3, // Communicative reduction operator. - kCommReduce = 3, + kCommReduce = 4, // Complex operation, can still fuse elemwise operations into its output. // but cannot chain another complex op - kOutEWiseFusable = 4, + kOutEWiseFusable = 5, + kTupleField = 6, + // Opaque operation, cannot fuse anything. kOpaque = 8 }; diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 12e3174dcade..ca1dccff287e 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -267,10 +267,10 @@ class IndexedForwardGraph::Creator : private ExprVisitor { void VisitExpr_(const TupleNode* op) final { CHECK(graph_.node_map.count(op)); Node* tuple_node = graph_.node_map.at(op); - tuple_node->pattern = kInjective; + tuple_node->pattern = kTuple; for (const Expr& field : op->fields) { if (field->checked_type().as()) { - this->Update(field, tuple_node, kInjective); + this->Update(field, tuple_node, kTupleField); } else { this->Update(field, nullptr, kOpaque); } @@ -493,6 +493,8 @@ class GraphPartitioner { OpPatternKind pattern; /*! \brief reference to the root node. */ const tvm::Node* root_ref{nullptr}; + + std::vector inputs; /*! * \brief Reference to the master node, * this field is not nullptr only if pattern is kOutEWiseFusable. @@ -700,6 +702,8 @@ class GraphPartitioner { }; if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { CommitFuse(graph_node, dom_node->parent->gnode); + } else { + groups_[dom_node->parent->gnode->index]->inputs.push_back(groups_[graph_node->index]); } } } else if (group_node->pattern == kInjective) { @@ -712,6 +716,8 @@ class GraphPartitioner { }; if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { CommitFuse(graph_node, dom_node->parent->gnode); + } else { + groups_[dom_node->parent->gnode->index]->inputs.push_back(groups_[graph_node->index]); } } else { // do nothing. @@ -731,6 +737,24 @@ GraphPartitioner::Partition(const IndexedForwardGraph& graph) { for (int phase = 0; phase < 2; ++phase) { this->RunFuse(graph, post_dom_tree, phase); } + // Fuse intermediate tuples, if any + std::unordered_set visited; + for (size_t nid = groups_.size() - 1; nid >= 0; --nid) { + Group* group = groups_[nid]; + if (visited.count(group)) continue; + visited.insert(group); + if (group->pattern > kInjective) continue; + const auto& input_groups = group->inputs; + bool fusible = std::all_of(input_groups.begin(), input_groups.end(), [](const Group* g) { + return g->pattern <= kInjective || g->pattern == kTupleField; + }); + if (fusible) { + for (Group* child_group : input_groups) { + MergeFromTo(child_group, group); + visited.insert(child_group); + } + } + } return std::move(groups_); } @@ -821,29 +845,11 @@ class FuseMutator : private ExprMutator { Expr VisitExpr_(const TupleNode* tuple) { auto* ret_group = gmap_.at(tuple)->FindRoot(); - Array new_fields = GetNewArguments(tuple->fields, ret_group); if (ret_group == gmap_.at(tuple)) { - // This tuple is the root of its group. Check if all fields come from other groups. - bool isolated = new_fields.size() == ginfo_[ret_group].params.size(); - for (size_t i = 0; i < new_fields.size() && isolated; ++i) { - isolated &= (new_fields[i].same_as(ginfo_[ret_group].params[i])); - } - if (isolated) { - // Do not put a isolated tuple into a function - return ExprMutator::VisitExpr_(tuple); - } - // This tuple has been fused with other ops before it - for (size_t i = 0; i < new_fields.size(); i++) { - // Copy function arguments to tuple field of the output because currently graph memory - // planer doesn't support inplace operations - if (new_fields[i].as()) { - auto copy = Copy(new_fields[i]); - new_fields.Set(i, copy); - } - } - return MakeNewFunction(ret_group, tuple->checked_type(), TupleNode::make(new_fields)); + return ExprMutator::VisitExpr_(tuple); } // This tuple is an intermediate node in the group + Array new_fields = GetNewArguments(tuple->fields, ret_group); return TupleNode::make(new_fields); } From 87bda6cedeaabffbe0db4c5de1d3c80fe9b49950 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 25 Apr 2019 23:45:08 +0900 Subject: [PATCH 02/11] make it work --- python/tvm/relay/op/op.py | 8 +-- src/relay/pass/fuse_ops.cc | 37 +++++++------- .../relay/test_backend_compile_engine.py | 10 +++- tests/python/relay/test_pass_fuse_ops.py | 51 +++---------------- 4 files changed, 41 insertions(+), 65 deletions(-) diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 6312f023df0d..7eea999828fb 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -106,12 +106,14 @@ class OpPattern(object): ELEMWISE = 0 # Broadcast operator BROADCAST = 1 + kTuple = 2 # Injective mapping - INJECTIVE = 2 + INJECTIVE = 3 # Communication - COMM_REDUCE = 3 + COMM_REDUCE = 4 # Complex op, can still fuse ewise into it - OUT_ELEMWISE_FUSABLE = 4 + OUT_ELEMWISE_FUSABLE = 5 + kTupleFiled = 6 # Not fusable opaque op OPAQUE = 8 diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index ca1dccff287e..3aa7a479725a 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -494,7 +494,7 @@ class GraphPartitioner { /*! \brief reference to the root node. */ const tvm::Node* root_ref{nullptr}; - std::vector inputs; + std::set inputs; /*! * \brief Reference to the master node, * this field is not nullptr only if pattern is kOutEWiseFusable. @@ -595,6 +595,9 @@ class GraphPartitioner { parent = parent->FindRoot(); if (child == parent) return; child->parent = parent; + for (Group* g : child->inputs) { + parent->inputs.insert(g); + } // update master ref and pattern if (child->master_ref != nullptr) { CHECK(parent->master_ref == nullptr); @@ -662,9 +665,13 @@ class GraphPartitioner { if (group_node->pattern == kOpaque) continue; // no actions needed if the current node have no dominator if (dom_node->parent == nullptr) continue; + size_t dom_parent_gindex = dom_node->parent->gnode->index; + if (dom_node->pattern == kTupleField) { + groups_[dom_parent_gindex]->inputs.insert(group_node); + continue; + } CHECK(!graph_node->extern_ref); // Skip if current node is already fused to the parent. - size_t dom_parent_gindex = dom_node->parent->gnode->index; if (groups_[dom_parent_gindex] != nullptr && group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) { continue; @@ -702,11 +709,9 @@ class GraphPartitioner { }; if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { CommitFuse(graph_node, dom_node->parent->gnode); - } else { - groups_[dom_node->parent->gnode->index]->inputs.push_back(groups_[graph_node->index]); } } - } else if (group_node->pattern == kInjective) { + } else if (group_node->pattern == kInjective || group_node->pattern == kTuple) { // defer injective fusion to second phase. // so conv2d always finishes fusing. if (phase != 1) continue; @@ -716,8 +721,6 @@ class GraphPartitioner { }; if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { CommitFuse(graph_node, dom_node->parent->gnode); - } else { - groups_[dom_node->parent->gnode->index]->inputs.push_back(groups_[graph_node->index]); } } else { // do nothing. @@ -739,19 +742,19 @@ GraphPartitioner::Partition(const IndexedForwardGraph& graph) { } // Fuse intermediate tuples, if any std::unordered_set visited; - for (size_t nid = groups_.size() - 1; nid >= 0; --nid) { + for (size_t i = groups_.size(); i != 0; --i) { + size_t nid = i - 1; Group* group = groups_[nid]; if (visited.count(group)) continue; visited.insert(group); - if (group->pattern > kInjective) continue; - const auto& input_groups = group->inputs; - bool fusible = std::all_of(input_groups.begin(), input_groups.end(), [](const Group* g) { - return g->pattern <= kInjective || g->pattern == kTupleField; - }); - if (fusible) { - for (Group* child_group : input_groups) { - MergeFromTo(child_group, group); - visited.insert(child_group); + Group* root_group = group->FindRoot(); + if (root_group->pattern == kTuple) continue; + if (group->pattern == kTuple && root_group->pattern <= kInjective) { + for (Group* child_group : root_group->inputs) { + if (child_group->FindRoot()->pattern <= kInjective) { + MergeFromTo(child_group, group); + visited.insert(child_group); + } } } } diff --git a/tests/python/relay/test_backend_compile_engine.py b/tests/python/relay/test_backend_compile_engine.py index 3b479b847619..ca4619c97886 100644 --- a/tests/python/relay/test_backend_compile_engine.py +++ b/tests/python/relay/test_backend_compile_engine.py @@ -69,8 +69,16 @@ def test_compile_injective_with_tuple(): relay.build(func, 'llvm') +def test_compile_tuple_dup(): + x = relay.var("data", shape=(16, 16)) + log = relay.log(x) + output = relay.Tuple([log, log]) + f = relay.Function([x], output) + relay.build(f, 'llvm') + + if __name__ == "__main__": test_compile_engine() test_compile_placeholder_bypass() test_compile_injective_with_tuple() - + test_compile_tuple_dup() diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index baafbeebd560..2d95d37e99e3 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -151,9 +151,10 @@ def expected(dshape): dshape = (1, 16, 64, 64) z = before(dshape) z = relay.ir_pass.infer_type(z) - zz = relay.ir_pass.fuse_ops(z, opt_level=0) - assert not relay.ir_pass.free_vars(zz) + # zz = relay.ir_pass.fuse_ops(z, opt_level=0) + # assert not relay.ir_pass.free_vars(zz) zz = relay.ir_pass.fuse_ops(z, opt_level=2) + print(zz) zz = relay.ir_pass.infer_type(zz) assert not relay.ir_pass.free_vars(zz) after = relay.ir_pass.infer_type(expected(dshape)) @@ -176,16 +177,14 @@ def expected(dshape): f0 = relay.Function([x], pooled) p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2)) - p1 = relay.var("p1", shape=(dshape[0], dshape[1], dshape[2], dshape[3])) - p1_copy = relay.copy(p1) upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW") - out = relay.Tuple((upsampled, p1_copy)) - f1 = relay.Function([p0, p1], out) + f1 = relay.Function([p0], upsampled) x = relay.var("x", shape=dshape) y = relay.Call(f0, [x]) - z = relay.Call(f1, [y, x]) - return relay.Function([x], z) + z = relay.Call(f1, [y]) + tup = relay.Tuple((z, x)) + return relay.Function([x], tup) dshape = (1, 16, 64, 64) z = before(dshape) @@ -199,41 +198,6 @@ def expected(dshape): assert relay.ir_pass.alpha_equal(zz, after) -def test_tuple_strided_slice(): - """ - Test fusion case where the number of fields of tuple and - the number of parameters to the function containing the tuple are different - """ - - def before(dshape): - x = relay.var("x", shape=dshape) - slice1 = relay.strided_slice(x, begin=[0, 0], end=[dshape[1]//2, dshape[1]], strides=[1,1]) - slice2 = relay.strided_slice(x, begin=[dshape[1]//2, 0], end=[dshape[0], dshape[1]], strides=[1,1]) - out = relay.Tuple((slice1, slice2)) - return relay.Function([x], out) - - def expected(dshape): - x = relay.var("x", shape=dshape) - slice1 = relay.strided_slice(x, begin=[0, 0], end=[dshape[1]//2, dshape[1]], strides=[1,1]) - slice2 = relay.strided_slice(x, begin=[dshape[1]//2, 0], end=[dshape[0], dshape[1]], strides=[1,1]) - out = relay.Tuple((slice1, slice2)) - f0 = relay.Function([x], out) - - x = relay.var("x", shape=dshape) - y = relay.Call(f0, [x]) - return relay.Function([x], y) - - dshape = (64, 64) - z = before(dshape) - z = relay.ir_pass.infer_type(z) - zz = relay.ir_pass.fuse_ops(z, opt_level=0) - assert not relay.ir_pass.free_vars(zz) - zz = relay.ir_pass.fuse_ops(z, opt_level=2) - zz = relay.ir_pass.infer_type(zz) - assert not relay.ir_pass.free_vars(zz) - after = relay.ir_pass.infer_type(expected(dshape)) - assert relay.ir_pass.alpha_equal(zz, after) - def test_stop_fusion(): def before(dshape): @@ -382,7 +346,6 @@ def expected(dim): test_conv2d_fuse() test_concatenate() test_tuple_root() - test_tuple_strided_slice() test_stop_fusion() test_fuse_myia_regression() test_fuse_tuple_get_elemwise() From 911d6b664c1a106f0d21d30c99eabdae5b01bf82 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 26 Apr 2019 00:02:54 +0900 Subject: [PATCH 03/11] add more test --- tests/python/relay/test_pass_fuse_ops.py | 46 ++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 2d95d37e99e3..b707e8a78b63 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -151,10 +151,9 @@ def expected(dshape): dshape = (1, 16, 64, 64) z = before(dshape) z = relay.ir_pass.infer_type(z) - # zz = relay.ir_pass.fuse_ops(z, opt_level=0) - # assert not relay.ir_pass.free_vars(zz) + zz = relay.ir_pass.fuse_ops(z, opt_level=0) + assert not relay.ir_pass.free_vars(zz) zz = relay.ir_pass.fuse_ops(z, opt_level=2) - print(zz) zz = relay.ir_pass.infer_type(zz) assert not relay.ir_pass.free_vars(zz) after = relay.ir_pass.infer_type(expected(dshape)) @@ -341,6 +340,46 @@ def expected(dim): assert relay.ir_pass.alpha_equal(zz, after) +def test_tuple_intermediate(): + def before(dshape): + x = relay.var("x", shape=dshape) + inj = relay.squeeze(x) + y1 = relay.add(inj, relay.const(1, "float32")) + y2 = relay.squeeze(inj) + y3 = relay.add(inj, relay.const(1, "float32")) + concat = relay.concatenate((y1, y2, y3), axis=1) + out_inj = relay.squeeze(concat) + out = relay.add(out_inj, relay.const(1, "float32")) + return relay.Function(relay.ir_pass.free_vars(out), out) + + def expected(dshape): + p0 = relay.var("p0", shape=dshape) + inj = relay.squeeze(p0) + y1 = relay.add(inj, relay.const(1, "float32")) + y2 = relay.squeeze(inj) + y3 = relay.add(inj, relay.const(1, "float32")) + concat = relay.concatenate((y1, y2, y3), axis=1) + out_inj = relay.squeeze(concat) + out = relay.add(out_inj, relay.const(1, "float32")) + f0 = relay.Function([p0], out) + + x = relay.var("x", shape=dshape) + y = relay.Call(f0, [x]) + return relay.Function([x], y) + + dshape = (1, 16, 64, 64) + z = before(dshape) + z = relay.ir_pass.infer_type(z) + zz = relay.ir_pass.fuse_ops(z, opt_level=0) + assert not relay.ir_pass.free_vars(zz) + zz = relay.ir_pass.fuse_ops(z, opt_level=2) + relay.build(zz, 'llvm') + zz = relay.ir_pass.infer_type(zz) + assert not relay.ir_pass.free_vars(zz) + after = relay.ir_pass.infer_type(expected(dshape)) + assert relay.ir_pass.alpha_equal(zz, after) + + if __name__ == "__main__": test_fuse_simple() test_conv2d_fuse() @@ -350,3 +389,4 @@ def expected(dim): test_fuse_myia_regression() test_fuse_tuple_get_elemwise() test_tuple_get_root() + test_tuple_intermediate() From 1b73bd0b4f8ac787f0c92deb6be327f341b1f81f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 26 Apr 2019 01:32:29 +0900 Subject: [PATCH 04/11] add more test --- src/relay/pass/fuse_ops.cc | 4 -- tests/python/relay/test_pass_fuse_ops.py | 79 ++++++++++++++++++++---- 2 files changed, 66 insertions(+), 17 deletions(-) diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 3aa7a479725a..c29fa2fd5813 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -741,19 +741,15 @@ GraphPartitioner::Partition(const IndexedForwardGraph& graph) { this->RunFuse(graph, post_dom_tree, phase); } // Fuse intermediate tuples, if any - std::unordered_set visited; for (size_t i = groups_.size(); i != 0; --i) { size_t nid = i - 1; Group* group = groups_[nid]; - if (visited.count(group)) continue; - visited.insert(group); Group* root_group = group->FindRoot(); if (root_group->pattern == kTuple) continue; if (group->pattern == kTuple && root_group->pattern <= kInjective) { for (Group* child_group : root_group->inputs) { if (child_group->FindRoot()->pattern <= kInjective) { MergeFromTo(child_group, group); - visited.insert(child_group); } } } diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index b707e8a78b63..d6d7bc9d479a 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -341,34 +341,86 @@ def expected(dim): def test_tuple_intermediate(): - def before(dshape): - x = relay.var("x", shape=dshape) + def before(x): inj = relay.squeeze(x) y1 = relay.add(inj, relay.const(1, "float32")) - y2 = relay.squeeze(inj) + tmp = relay.squeeze(inj) + tmp = relay.add(tmp, relay.const(1, "float32")) + y2 = relay.add(tmp, relay.const(1, "float32")) y3 = relay.add(inj, relay.const(1, "float32")) concat = relay.concatenate((y1, y2, y3), axis=1) out_inj = relay.squeeze(concat) out = relay.add(out_inj, relay.const(1, "float32")) return relay.Function(relay.ir_pass.free_vars(out), out) + def expected(p0): + f0 = before(p0) + x = relay.var("x", shape=dshape) + y = relay.Call(f0, [x]) + return relay.Function([x], y) + + dshape = (1, 16, 64, 64) + x = relay.var("x", shape=dshape) + z = before(x) + z = relay.ir_pass.infer_type(z) + zz = relay.ir_pass.fuse_ops(z, opt_level=0) + assert not relay.ir_pass.free_vars(zz) + zz = relay.ir_pass.fuse_ops(z, opt_level=2) + relay.build(zz, 'llvm') + zz = relay.ir_pass.infer_type(zz) + assert not relay.ir_pass.free_vars(zz) + after = relay.ir_pass.infer_type(expected(x)) + assert relay.ir_pass.alpha_equal(zz, after) + + +def test_tuple_consecutive(): + def gen_intermediate_tuple(x): + y1 = relay.add(x, relay.const(1, "float32")) + y2 = relay.add(x, relay.const(1, "float32")) + y3 = relay.add(x, relay.const(1, "float32")) + concat = relay.concatenate((y1, y2, y3), axis=1) + out = relay.add(concat, relay.const(1, "float32")) + return out + + def gen_consecutive_tuple(x): + y1 = gen_intermediate_tuple(x) + y2 = gen_intermediate_tuple(x) + y3 = gen_intermediate_tuple(x) + concat = relay.concatenate((y1, y2, y3), axis=1) + return concat + + def before(x): + concat = gen_consecutive_tuple(x) + pooled = relay.nn.max_pool2d(concat, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) + out = relay.add(pooled, relay.const(1, "float32")) + out2 = relay.add(out, relay.const(1, "float32")) + out_tup = relay.Tuple((out, out2)) + return relay.Function(relay.ir_pass.free_vars(out_tup), out_tup) + def expected(dshape): p0 = relay.var("p0", shape=dshape) - inj = relay.squeeze(p0) - y1 = relay.add(inj, relay.const(1, "float32")) - y2 = relay.squeeze(inj) - y3 = relay.add(inj, relay.const(1, "float32")) - concat = relay.concatenate((y1, y2, y3), axis=1) - out_inj = relay.squeeze(concat) - out = relay.add(out_inj, relay.const(1, "float32")) - f0 = relay.Function([p0], out) + concat = gen_consecutive_tuple(p0) + f0 = relay.Function([p0], concat) + + p01 = relay.var("p01", shape=(1, dshape[1]*9, dshape[2], dshape[3])) + pooled = relay.nn.max_pool2d(p01, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) + out = relay.add(pooled, relay.const(1, "float32")) + f1 = relay.Function([p01], out) + + p02 = relay.var("p02", shape=(1, dshape[1]*9, dshape[2]//2, dshape[3]//2)) + out = relay.add(p02, relay.const(1, "float32")) + f2 = relay.Function([p02], out) x = relay.var("x", shape=dshape) y = relay.Call(f0, [x]) - return relay.Function([x], y) + z = relay.Call(f1, [y]) + z2 = relay.Call(f2, [z]) + + return relay.Function([x], relay.Tuple((z, z2))) dshape = (1, 16, 64, 64) - z = before(dshape) + x = relay.var("x", shape=dshape) + z = before(x) z = relay.ir_pass.infer_type(z) zz = relay.ir_pass.fuse_ops(z, opt_level=0) assert not relay.ir_pass.free_vars(zz) @@ -390,3 +442,4 @@ def expected(dshape): test_fuse_tuple_get_elemwise() test_tuple_get_root() test_tuple_intermediate() + test_tuple_consecutive() From 0e6aa6a88a7e20d9035d0890a67633bcb6207645 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 26 Apr 2019 02:12:33 +0900 Subject: [PATCH 05/11] add comment --- include/tvm/relay/op_attr_types.h | 5 ++++- python/tvm/relay/op/op.py | 2 ++ src/relay/pass/fuse_ops.cc | 13 +++++++++++-- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h index 46f5721c2f73..d3c1d5fb09ee 100644 --- a/include/tvm/relay/op_attr_types.h +++ b/include/tvm/relay/op_attr_types.h @@ -41,15 +41,18 @@ enum OpPatternKind { // for example :code:`out[i, ax1, j, ax2] = input[i, j]`. // Note that the axis need to be in order so transpose is not a bcast operator. kBroadcast = 1, + // The pattern for tuple nodes. Can fuse into subsequent injective ops. + kTuple = 2, // Injective operator, can always injectively map output axis to a single input axis. // All injective operator can still be safely fused to injective and reduction. - kTuple = 2, kInjective = 3, // Communicative reduction operator. kCommReduce = 4, // Complex operation, can still fuse elemwise operations into its output. // but cannot chain another complex op kOutEWiseFusable = 5, + // The edge pattern between tuple and its fields. Tuple fields can fuse into + // a tuple if the tuple is intermediate node in its fusion group. kTupleField = 6, // Opaque operation, cannot fuse anything. diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 7eea999828fb..8cbb55fd4ef6 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -106,6 +106,7 @@ class OpPattern(object): ELEMWISE = 0 # Broadcast operator BROADCAST = 1 + # Represents tuple node kTuple = 2 # Injective mapping INJECTIVE = 3 @@ -113,6 +114,7 @@ class OpPattern(object): COMM_REDUCE = 4 # Complex op, can still fuse ewise into it OUT_ELEMWISE_FUSABLE = 5 + # Used to represent edge pattern between tuple and its fields kTupleFiled = 6 # Not fusable opaque op OPAQUE = 8 diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index c29fa2fd5813..c2b2c8024fef 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -493,7 +493,11 @@ class GraphPartitioner { OpPatternKind pattern; /*! \brief reference to the root node. */ const tvm::Node* root_ref{nullptr}; - + /*! + * \brief The input nodes to this group + * The group and its inputs are disjoint in the union find forest. + * This field is used to keep track of data flow between fusion groups. + */ std::set inputs; /*! * \brief Reference to the master node, @@ -666,6 +670,9 @@ class GraphPartitioner { // no actions needed if the current node have no dominator if (dom_node->parent == nullptr) continue; size_t dom_parent_gindex = dom_node->parent->gnode->index; + // If the edge pattern is kTupleField, then dom_node below is tuple + // Do not let tuple fields fuse into the tuple here, but record the data flow between + // groups so that we can fuse them later if (dom_node->pattern == kTupleField) { groups_[dom_parent_gindex]->inputs.insert(group_node); continue; @@ -715,7 +722,7 @@ class GraphPartitioner { // defer injective fusion to second phase. // so conv2d always finishes fusing. if (phase != 1) continue; - // Check if all path are injective. + // Check if all path are injective. tuple nodes can be fused if its dom_node is injective. auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; @@ -747,6 +754,8 @@ GraphPartitioner::Partition(const IndexedForwardGraph& graph) { Group* root_group = group->FindRoot(); if (root_group->pattern == kTuple) continue; if (group->pattern == kTuple && root_group->pattern <= kInjective) { + // Here, we found a tuple node that had been fused into later injective ops. + // Complete the fusion by fusing tuple fields into it. for (Group* child_group : root_group->inputs) { if (child_group->FindRoot()->pattern <= kInjective) { MergeFromTo(child_group, group); From 4d87216feae3783ad21e4eab3f3443faed9e3e06 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 26 Apr 2019 09:17:13 +0900 Subject: [PATCH 06/11] remove group inputs and kTupleFields --- include/tvm/relay/op_attr_types.h | 3 -- python/tvm/relay/op/op.py | 2 -- src/relay/pass/fuse_ops.cc | 48 ++++++++++++------------------- 3 files changed, 18 insertions(+), 35 deletions(-) diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h index d3c1d5fb09ee..45749ac894f0 100644 --- a/include/tvm/relay/op_attr_types.h +++ b/include/tvm/relay/op_attr_types.h @@ -51,9 +51,6 @@ enum OpPatternKind { // Complex operation, can still fuse elemwise operations into its output. // but cannot chain another complex op kOutEWiseFusable = 5, - // The edge pattern between tuple and its fields. Tuple fields can fuse into - // a tuple if the tuple is intermediate node in its fusion group. - kTupleField = 6, // Opaque operation, cannot fuse anything. kOpaque = 8 diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 8cbb55fd4ef6..23138580297e 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -114,8 +114,6 @@ class OpPattern(object): COMM_REDUCE = 4 # Complex op, can still fuse ewise into it OUT_ELEMWISE_FUSABLE = 5 - # Used to represent edge pattern between tuple and its fields - kTupleFiled = 6 # Not fusable opaque op OPAQUE = 8 diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index c2b2c8024fef..59e82369067f 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -270,7 +270,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { tuple_node->pattern = kTuple; for (const Expr& field : op->fields) { if (field->checked_type().as()) { - this->Update(field, tuple_node, kTupleField); + this->Update(field, tuple_node, kInjective); } else { this->Update(field, nullptr, kOpaque); } @@ -493,12 +493,6 @@ class GraphPartitioner { OpPatternKind pattern; /*! \brief reference to the root node. */ const tvm::Node* root_ref{nullptr}; - /*! - * \brief The input nodes to this group - * The group and its inputs are disjoint in the union find forest. - * This field is used to keep track of data flow between fusion groups. - */ - std::set inputs; /*! * \brief Reference to the master node, * this field is not nullptr only if pattern is kOutEWiseFusable. @@ -599,9 +593,6 @@ class GraphPartitioner { parent = parent->FindRoot(); if (child == parent) return; child->parent = parent; - for (Group* g : child->inputs) { - parent->inputs.insert(g); - } // update master ref and pattern if (child->master_ref != nullptr) { CHECK(parent->master_ref == nullptr); @@ -669,20 +660,16 @@ class GraphPartitioner { if (group_node->pattern == kOpaque) continue; // no actions needed if the current node have no dominator if (dom_node->parent == nullptr) continue; - size_t dom_parent_gindex = dom_node->parent->gnode->index; - // If the edge pattern is kTupleField, then dom_node below is tuple - // Do not let tuple fields fuse into the tuple here, but record the data flow between - // groups so that we can fuse them later - if (dom_node->pattern == kTupleField) { - groups_[dom_parent_gindex]->inputs.insert(group_node); - continue; - } CHECK(!graph_node->extern_ref); // Skip if current node is already fused to the parent. + size_t dom_parent_gindex = dom_node->parent->gnode->index; if (groups_[dom_parent_gindex] != nullptr && group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) { continue; } + // Do not fuse into tuple for now + if (groups_[dom_parent_gindex]->pattern == kTuple) continue; + // Try to fuse current node to its post-dominator. if (group_node->pattern == kOutEWiseFusable) { if (phase != 0) continue; @@ -722,7 +709,7 @@ class GraphPartitioner { // defer injective fusion to second phase. // so conv2d always finishes fusing. if (phase != 1) continue; - // Check if all path are injective. tuple nodes can be fused if its dom_node is injective. + // Check if all path are injective. auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; @@ -748,18 +735,19 @@ GraphPartitioner::Partition(const IndexedForwardGraph& graph) { this->RunFuse(graph, post_dom_tree, phase); } // Fuse intermediate tuples, if any - for (size_t i = groups_.size(); i != 0; --i) { - size_t nid = i - 1; + for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) { + auto* dom_node = post_dom_tree.nodes[nid]; Group* group = groups_[nid]; - Group* root_group = group->FindRoot(); - if (root_group->pattern == kTuple) continue; - if (group->pattern == kTuple && root_group->pattern <= kInjective) { - // Here, we found a tuple node that had been fused into later injective ops. - // Complete the fusion by fusing tuple fields into it. - for (Group* child_group : root_group->inputs) { - if (child_group->FindRoot()->pattern <= kInjective) { - MergeFromTo(child_group, group); - } + if (group->pattern == kOpaque) continue; + if (dom_node->parent == nullptr) continue; + Group* dom_parent_group = groups_[dom_node->parent->gnode->index]; + Group* dom_root_group = dom_parent_group->FindRoot(); + // If dom node group has a tuple as its root, we do not fuse tuple fields into it + if (dom_root_group->pattern == kTuple) continue; + if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= kInjective) { + // Now we know the tuple has been fused into subsequent injective ops + if (group->FindRoot()->pattern <= kInjective) { + MergeFromTo(group, dom_root_group); } } } From 55e1c0ce9de05d77f0df1e31883e9b7d1d742bf5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 26 Apr 2019 15:46:35 +0900 Subject: [PATCH 07/11] move logic to RunFuse --- python/tvm/relay/op/op.py | 2 +- src/relay/pass/fuse_ops.cc | 38 ++++++++++++++++++-------------------- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 23138580297e..e0d597525ebd 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -107,7 +107,7 @@ class OpPattern(object): # Broadcast operator BROADCAST = 1 # Represents tuple node - kTuple = 2 + TUPLE = 2 # Injective mapping INJECTIVE = 3 # Communication diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 59e82369067f..43484630589b 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -661,15 +661,30 @@ class GraphPartitioner { // no actions needed if the current node have no dominator if (dom_node->parent == nullptr) continue; CHECK(!graph_node->extern_ref); - // Skip if current node is already fused to the parent. size_t dom_parent_gindex = dom_node->parent->gnode->index; + + if (phase == 2) { + // Fuse intermediate tuples, if any + Group* dom_parent_group = groups_[dom_parent_gindex]; + Group* dom_root_group = dom_parent_group->FindRoot(); + // If dom node group has a tuple as its root, we do not fuse tuple fields into it + if (dom_root_group->pattern == kTuple) continue; + if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= kInjective) { + // Now we know the tuple has been fused into subsequent injective ops + if (group_node->FindRoot()->pattern <= kInjective) { + MergeFromTo(group_node, dom_root_group); + } + } + continue; + } + + // Skip if current node is already fused to the parent. if (groups_[dom_parent_gindex] != nullptr && group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) { continue; } // Do not fuse into tuple for now if (groups_[dom_parent_gindex]->pattern == kTuple) continue; - // Try to fuse current node to its post-dominator. if (group_node->pattern == kOutEWiseFusable) { if (phase != 0) continue; @@ -731,26 +746,9 @@ GraphPartitioner::Partition(const IndexedForwardGraph& graph) { // get post dominator tree auto post_dom_tree = DominatorTree::PostDom(arena_, graph); // run fusion algorithm. - for (int phase = 0; phase < 2; ++phase) { + for (int phase = 0; phase < 3; ++phase) { this->RunFuse(graph, post_dom_tree, phase); } - // Fuse intermediate tuples, if any - for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) { - auto* dom_node = post_dom_tree.nodes[nid]; - Group* group = groups_[nid]; - if (group->pattern == kOpaque) continue; - if (dom_node->parent == nullptr) continue; - Group* dom_parent_group = groups_[dom_node->parent->gnode->index]; - Group* dom_root_group = dom_parent_group->FindRoot(); - // If dom node group has a tuple as its root, we do not fuse tuple fields into it - if (dom_root_group->pattern == kTuple) continue; - if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= kInjective) { - // Now we know the tuple has been fused into subsequent injective ops - if (group->FindRoot()->pattern <= kInjective) { - MergeFromTo(group, dom_root_group); - } - } - } return std::move(groups_); } From 538a2d3bb8c21894f203f8f1755942c6494da18a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 26 Apr 2019 17:49:25 +0900 Subject: [PATCH 08/11] fix inception case --- src/relay/pass/fuse_ops.cc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 43484630589b..00c38ca5c059 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -664,15 +664,19 @@ class GraphPartitioner { size_t dom_parent_gindex = dom_node->parent->gnode->index; if (phase == 2) { - // Fuse intermediate tuples, if any + // Fuse injective ops into intermediate tuples, if any + if (group_node->pattern > kInjective) continue; Group* dom_parent_group = groups_[dom_parent_gindex]; Group* dom_root_group = dom_parent_group->FindRoot(); // If dom node group has a tuple as its root, we do not fuse tuple fields into it if (dom_root_group->pattern == kTuple) continue; if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= kInjective) { // Now we know the tuple has been fused into subsequent injective ops - if (group_node->FindRoot()->pattern <= kInjective) { - MergeFromTo(group_node, dom_root_group); + auto fcond = [](OpPatternKind kind, bool is_sink) { + return kind <= kInjective; + }; + if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { + CommitFuse(graph_node, dom_node->parent->gnode); } } continue; From 6535dcfc53a91dc29234f4b270f8a3af0f358038 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 26 Apr 2019 18:33:15 +0900 Subject: [PATCH 09/11] add inception test case --- src/relay/pass/fuse_ops.cc | 2 + tests/python/relay/test_pass_fuse_ops.py | 72 ++++++++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 00c38ca5c059..55d609872929 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -675,6 +675,8 @@ class GraphPartitioner { auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; + // dom_root_group can also be tuple, as in inception layers + // CheckPath is needed to avoid fusing two intermediate tuples if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { CommitFuse(graph_node, dom_node->parent->gnode); } diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index d6d7bc9d479a..bdffdf7c129f 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -432,6 +432,77 @@ def expected(dshape): assert relay.ir_pass.alpha_equal(zz, after) +def test_inception_like(): + def conv(data): + y = relay.nn.conv2d(data, relay.var("w"), + kernel_size=(3, 3), + padding=(1, 1), + channels=16) + return relay.nn.relu(data=y) + + def inception_like(data): + c0 = conv(data) + c1 = conv(data) + return relay.concatenate((c0, c1), axis=1) + + def before(dshape): + x = relay.var("x", shape=dshape) + in1 = inception_like(x) + in2 = inception_like(in1) + return relay.Function(relay.ir_pass.free_vars(in2), in2) + + def expected(dshape): + p0 = relay.var("p0", shape=dshape) + c = conv(p0) + f0 = relay.Function(relay.ir_pass.free_vars(c), c) + + p01 = relay.var("p01", shape=dshape) + c = conv(p01) + f1 = relay.Function(relay.ir_pass.free_vars(c), c) + + p02 = relay.var("p02", shape=dshape) + p12 = relay.var("p12", shape=dshape) + concat1 = relay.concatenate((p02, p12), axis=1) + f_concat1 = relay.Function([p02, p12], concat1) + + dshape2 = (dshape[0], dshape[1]*2, dshape[2], dshape[3]) + + p03 = relay.var("p03", shape=dshape2) + c = conv(p03) + f2 = relay.Function(relay.ir_pass.free_vars(c), c) + + p04 = relay.var("p04", shape=dshape2) + c = conv(p04) + f3 = relay.Function(relay.ir_pass.free_vars(c), c) + + p05 = relay.var("p05", shape=dshape) + p15 = relay.var("p15", shape=dshape) + concat2 = relay.concatenate((p05, p15), axis=1) + f_concat2 = relay.Function([p05, p15], concat2) + + x = relay.var("x", shape=dshape) + c1 = relay.Call(f0, [x, relay.var("w1")]) + c2 = relay.Call(f1, [x, relay.var("w2")]) + concat = relay.Call(f_concat1, [c1, c2]) + c3 = relay.Call(f2, [concat, relay.var("w3")]) + c4 = relay.Call(f3, [concat, relay.var("w4")]) + out = relay.Call(f_concat2, [c3, c4]) + + return relay.Function(relay.ir_pass.free_vars(out), out) + + dshape = (1, 16, 64, 64) + z = before(dshape) + z = relay.ir_pass.infer_type(z) + zz = relay.ir_pass.fuse_ops(z, opt_level=0) + assert not relay.ir_pass.free_vars(zz) + zz = relay.ir_pass.fuse_ops(z, opt_level=2) + relay.build(zz, 'llvm') + zz = relay.ir_pass.infer_type(zz) + assert not relay.ir_pass.free_vars(zz) + after = relay.ir_pass.infer_type(expected(dshape)) + assert relay.ir_pass.alpha_equal(zz, after) + + if __name__ == "__main__": test_fuse_simple() test_conv2d_fuse() @@ -443,3 +514,4 @@ def expected(dshape): test_tuple_get_root() test_tuple_intermediate() test_tuple_consecutive() + test_inception_like() From 7cb17242152418fe1e89ad27639e158a34e72b74 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 28 Apr 2019 08:15:39 +0900 Subject: [PATCH 10/11] make kTuple 7 --- include/tvm/relay/op_attr_types.h | 12 ++++++------ python/tvm/relay/op/op.py | 10 +++++----- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h index 45749ac894f0..ca7f6e5d3908 100644 --- a/include/tvm/relay/op_attr_types.h +++ b/include/tvm/relay/op_attr_types.h @@ -41,17 +41,17 @@ enum OpPatternKind { // for example :code:`out[i, ax1, j, ax2] = input[i, j]`. // Note that the axis need to be in order so transpose is not a bcast operator. kBroadcast = 1, - // The pattern for tuple nodes. Can fuse into subsequent injective ops. - kTuple = 2, // Injective operator, can always injectively map output axis to a single input axis. // All injective operator can still be safely fused to injective and reduction. - kInjective = 3, + kInjective = 2, // Communicative reduction operator. - kCommReduce = 4, + kCommReduce = 3, // Complex operation, can still fuse elemwise operations into its output. // but cannot chain another complex op - kOutEWiseFusable = 5, - + kOutEWiseFusable = 4, + // The pattern for tuple nodes. Can fuse into subsequent injective ops, + // but treated specially + kTuple = 7, // Opaque operation, cannot fuse anything. kOpaque = 8 }; diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index e0d597525ebd..9099c8128de6 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -106,14 +106,14 @@ class OpPattern(object): ELEMWISE = 0 # Broadcast operator BROADCAST = 1 - # Represents tuple node - TUPLE = 2 # Injective mapping - INJECTIVE = 3 + INJECTIVE = 2 # Communication - COMM_REDUCE = 4 + COMM_REDUCE = 3 # Complex op, can still fuse ewise into it - OUT_ELEMWISE_FUSABLE = 5 + OUT_ELEMWISE_FUSABLE = 4 + # Represents tuple node + TUPLE = 2 # Not fusable opaque op OPAQUE = 8 From b278a749a1483b60ed42c6caef9f81b9eca26ffb Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 28 Apr 2019 08:16:23 +0900 Subject: [PATCH 11/11] fix typo --- python/tvm/relay/op/op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 9099c8128de6..6ba207934d1b 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -113,7 +113,7 @@ class OpPattern(object): # Complex op, can still fuse ewise into it OUT_ELEMWISE_FUSABLE = 4 # Represents tuple node - TUPLE = 2 + TUPLE = 7 # Not fusable opaque op OPAQUE = 8