From a0892e04828dd05c5380596db6fa50828869522e Mon Sep 17 00:00:00 2001 From: masahi Date: Tue, 25 Aug 2020 05:15:23 +0900 Subject: [PATCH] [OpFusion] Make the max number of fused ops configurable (#6327) --- src/relay/transforms/fuse_ops.cc | 52 +++++++++++++++---- tests/python/relay/test_pass_fuse_ops.py | 66 +++++++++++++++++++++--- 2 files changed, 101 insertions(+), 17 deletions(-) diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index 01f1eeea30b36..85b74cc420631 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -83,6 +83,8 @@ constexpr uint32_t kMaxFusedOps = 256; static const Op& stop_fusion_op = Op::Get("annotation.stop_fusion"); +TVM_REGISTER_PASS_CONFIG_OPTION("relay.FuseOps.max_depth", Integer); + /*! * \brief Indexed data flow graph in forward direction. * This is a temporary data structure used for operator fusion analysis. @@ -496,8 +498,8 @@ DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForward */ class GraphPartitioner { public: - explicit GraphPartitioner(support::Arena* arena, int opt_level) - : arena_(arena), opt_level_(opt_level) {} + explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth) + : arena_(arena), opt_level_(opt_level), max_fuse_depth_(max_fuse_depth) {} /*! * \brief Group as a union find data structure. */ @@ -549,6 +551,8 @@ class GraphPartitioner { support::Arena* arena_; /*! \brief optimization level for fuse operation. */ int opt_level_; + /*! \brief The maximum number of operations in one fused function */ + size_t max_fuse_depth_; /*! \brief The internal groups. */ std::vector groups_; /*! \brief internal field used for deduplication */ @@ -604,11 +608,11 @@ class GraphPartitioner { * \param parent The parent group. 
*/ void MergeFromTo(Group* child, Group* parent) { - // update the number of nodes of the parent group - parent->num_nodes += child->num_nodes; child = child->FindRoot(); parent = parent->FindRoot(); if (child == parent) return; + // update the number of nodes of the parent group + parent->num_nodes += child->num_nodes; child->parent = parent; // update master ref and pattern if (child->master_ref != nullptr) { @@ -643,6 +647,32 @@ class GraphPartitioner { CommitFuse_(src, sink, target); } + size_t CountNodesUptoSink_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) { + if (src == sink || visited_.count(src)) return 0; + visited_.insert(src); + Group* gnode = groups_[src->index]; + CHECK(gnode != nullptr); + auto sum = gnode->num_nodes; + for (auto link = src->outputs.head; link != nullptr; link = link->next) { + sum += CountNodesUptoSink_(link->value.node, sink); + } + return sum; + } + + // Count the number of nodes in a fused subgraph if child is additionally fused. + // dom_parent is already known to be a part of the subgraph. + // For a diamond structure, there can be multiple paths connecting child and dom_parent. + // All intermediate nodes between child and dom_parent are taken into account. + // Since dom_parent can itself be an intermediate node in the subgraph, calling FindRoot() + // is important for correct calculation. + size_t CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child, + IndexedForwardGraph::Node* dom_parent) { + Group* target = groups_[dom_parent->index]; + visited_.clear(); + CHECK(child != dom_parent); + return target->FindRoot()->num_nodes + CountNodesUptoSink_(child, dom_parent); + } + // Initialize the groups. 
void InitGroups(const IndexedForwardGraph& graph) { groups_.resize(graph.post_dfs_order.size()); @@ -675,7 +705,8 @@ class GraphPartitioner { size_t dom_parent_gindex = dom_node->parent->gnode->index; // refuse the fusion if too many ops are going to be fused together - if (groups_[dom_parent_gindex]->num_nodes + group_node->num_nodes > kMaxFusedOps) continue; + if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_) + continue; if (phase == 2) { // Fuse injective ops into intermediate tuples, if any @@ -769,10 +800,10 @@ std::vector GraphPartitioner::Partition( class FuseMutator : private ExprMutator { public: // Run the transform - Expr Transform(const Expr& body, int fuse_opt_level) { + Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth) { // setup the group map. auto graph = IndexedForwardGraph::Create(&arena_, body); - auto groups = GraphPartitioner(&arena_, fuse_opt_level).Partition(graph); + auto groups = GraphPartitioner(&arena_, fuse_opt_level, max_fuse_depth).Partition(graph); for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) { CHECK(graph.post_dfs_order[nid]->ref != nullptr); gmap_[graph.post_dfs_order[nid]->ref] = groups[nid]; @@ -926,8 +957,8 @@ class FuseMutator : private ExprMutator { } }; -Expr FuseOps(const Expr& expr, int fuse_opt_level, const IRModule& module) { - return FuseMutator().Transform(expr, fuse_opt_level); +Expr FuseOps(const Expr& expr, int fuse_opt_level, size_t max_fuse_depth, const IRModule& module) { + return FuseMutator().Transform(expr, fuse_opt_level, max_fuse_depth); } namespace transform { @@ -936,7 +967,8 @@ Pass FuseOps(int fuse_opt_level) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { int opt_level = fuse_opt_level == -1 ? 
pc->opt_level : fuse_opt_level; - return Downcast(FuseOps(f, opt_level, m)); + auto max_fuse_depth = pc->GetConfig("relay.FuseOps.max_depth", Integer(kMaxFusedOps)); + return Downcast(FuseOps(f, opt_level, max_fuse_depth.value(), m)); }; return CreateFunctionPass(pass_func, 1, "FuseOps", {"InferType"}); } diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 1727429e74de8..90e80d8e673f3 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -587,17 +587,14 @@ def test_split(): def test_fuse_max(): """Test the constraint of number of nodes in op fusion.""" - max_fused_ops = 256 - # n is the number of nodes to be fused, should be less than 2*max_fused_ops - n = 300 - def before(): + def before(n): x = relay.var("x", shape=(10, 20)) y = x for i in range(n): y = relay.exp(y) return relay.Function([x], y) - def expected(): + def expected(n, max_fused_ops): x = relay.var("p", shape=(10, 20)) y = x for i in range(max_fused_ops): @@ -608,6 +605,7 @@ def expected(): z = relay.Call(f1, [x]) xx = relay.var("pp", shape=(10, 20)) yy = xx + # it is assumed that there are two fused functions for i in range(n-max_fused_ops): yy = relay.exp(yy) f2 = relay.Function([xx], yy) @@ -615,10 +613,22 @@ def expected(): zz = relay.Call(f2, [z]) return relay.Function([x], zz) - z = before() + max_fused_ops = 256 + n = 300 + z = before(n) zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) zz = run_opt_pass(z, transform.FuseOps()) - after = run_opt_pass(expected(), transform.InferType()) + after = run_opt_pass(expected(n, max_fused_ops), transform.InferType()) + assert tvm.ir.structural_equal(zz, after) + + max_fused_ops = 10 + n = 20 + z = before(n) + after = run_opt_pass(expected(n, max_fused_ops), transform.InferType()) + + with tvm.transform.PassContext(config={"relay.FuseOps.max_depth": max_fused_ops}): + zz = run_opt_pass(z, transform.FuseOps()) + assert tvm.ir.structural_equal(zz, 
after) @@ -722,6 +732,47 @@ def expected(): assert tvm.ir.structural_equal(m["main"], after) +def test_fuse_max_diamond(): + def create_diamond(x, branch_len): + x1 = x + x2 = x + for _ in range(branch_len): + x1 = relay.exp(x1) + x2 = relay.exp(x2) + return relay.add(x1, x2) + + def before(branch_len, num_diamond): + x = relay.var("x", shape=(10, 20)) + out = x + for _ in range(num_diamond): + out = create_diamond(out, branch_len) + return relay.Function([x], out) + + def after(branch_len, num_diamond): + def create_diamond_func(inp): + inp_var = relay.var("p", shape=(10, 20)) + d = create_diamond(inp_var, branch_len) + f = relay.Function([inp_var], d) + f = f.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + return relay.Call(f, [inp]) + + inp = relay.var("x", shape=(10, 20)) + out = inp + for _ in range(num_diamond): + out = create_diamond_func(out) + return relay.Function([inp], out) + + branch_len = 5 + max_fused_ops = branch_len * 2 + 1 # the number of ops in one diamond + num_diamond = 3 + + with tvm.transform.PassContext(config={"relay.FuseOps.max_depth": max_fused_ops}): + fused = run_opt_pass(before(branch_len, num_diamond), transform.FuseOps()) + + expected = run_opt_pass(after(branch_len, num_diamond), transform.InferType()) + assert tvm.ir.structural_equal(fused, expected) + + if __name__ == "__main__": test_fuse_simple() test_conv2d_fuse() @@ -741,3 +792,4 @@ def expected(): test_fuse_take() test_fuse_gather_nd() test_fuse_bcast_reduce_scalar() + test_fuse_max_diamond()