Skip to content

Commit

Permalink
[OpFusion] Make the max number of fused ops configurable (#6327)
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi authored Aug 24, 2020
1 parent 37cbbd7 commit 6b5176d
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 17 deletions.
52 changes: 42 additions & 10 deletions src/relay/transforms/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ constexpr uint32_t kMaxFusedOps = 256;

static const Op& stop_fusion_op = Op::Get("annotation.stop_fusion");

TVM_REGISTER_PASS_CONFIG_OPTION("relay.FuseOps.max_depth", Integer);

/*!
* \brief Indexed data flow graph in forward direction.
* This is a temporary data structure used for operator fusion analysis.
Expand Down Expand Up @@ -496,8 +498,8 @@ DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForward
*/
class GraphPartitioner {
public:
explicit GraphPartitioner(support::Arena* arena, int opt_level)
: arena_(arena), opt_level_(opt_level) {}
explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth)
: arena_(arena), opt_level_(opt_level), max_fuse_depth_(max_fuse_depth) {}
/*!
* \brief Group as a union find data structure.
*/
Expand Down Expand Up @@ -549,6 +551,8 @@ class GraphPartitioner {
support::Arena* arena_;
/*! \brief optimization level for fuse operation. */
int opt_level_;
/*! \brief The maximum number of operations in one fused function */
size_t max_fuse_depth_;
/*! \brief The internal groups. */
std::vector<Group*> groups_;
/*! \brief internal field used for deduplication */
Expand Down Expand Up @@ -604,11 +608,11 @@ class GraphPartitioner {
* \param parent The parent group.
*/
void MergeFromTo(Group* child, Group* parent) {
// update the number of nodes of the parent group
parent->num_nodes += child->num_nodes;
child = child->FindRoot();
parent = parent->FindRoot();
if (child == parent) return;
// update the number of nodes of the parent group
parent->num_nodes += child->num_nodes;
child->parent = parent;
// update master ref and pattern
if (child->master_ref != nullptr) {
Expand Down Expand Up @@ -643,6 +647,32 @@ class GraphPartitioner {
CommitFuse_(src, sink, target);
}

size_t CountNodesUptoSink_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) {
if (src == sink || visited_.count(src)) return 0;
visited_.insert(src);
Group* gnode = groups_[src->index];
CHECK(gnode != nullptr);
auto sum = gnode->num_nodes;
for (auto link = src->outputs.head; link != nullptr; link = link->next) {
sum += CountNodesUptoSink_(link->value.node, sink);
}
return sum;
}

// Count the number of nodes in a fused subgraph if child is additionaly fused.
// dom_parent is already known to be a part of the subgraph.
// For a diamond structure, there can be multiple paths connecting child and dom_parent.
// All intermediate nodes between child and dom_parent are taken into account.
// Since dom_parent can itself be an intermediate node in the subgraph, calling FindRoot()
// is important for correct calculation.
size_t CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child,
IndexedForwardGraph::Node* dom_parent) {
Group* target = groups_[dom_parent->index];
visited_.clear();
CHECK(child != dom_parent);
return target->FindRoot()->num_nodes + CountNodesUptoSink_(child, dom_parent);
}

// Initialize the groups.
void InitGroups(const IndexedForwardGraph& graph) {
groups_.resize(graph.post_dfs_order.size());
Expand Down Expand Up @@ -675,7 +705,8 @@ class GraphPartitioner {
size_t dom_parent_gindex = dom_node->parent->gnode->index;

// refuse the fusion if too many ops are going to be fused together
if (groups_[dom_parent_gindex]->num_nodes + group_node->num_nodes > kMaxFusedOps) continue;
if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_)
continue;

if (phase == 2) {
// Fuse injective ops into intermediate tuples, if any
Expand Down Expand Up @@ -769,10 +800,10 @@ std::vector<GraphPartitioner::Group*> GraphPartitioner::Partition(
class FuseMutator : private ExprMutator {
public:
// Run the transform
Expr Transform(const Expr& body, int fuse_opt_level) {
Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth) {
// setup the group map.
auto graph = IndexedForwardGraph::Create(&arena_, body);
auto groups = GraphPartitioner(&arena_, fuse_opt_level).Partition(graph);
auto groups = GraphPartitioner(&arena_, fuse_opt_level, max_fuse_depth).Partition(graph);
for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) {
CHECK(graph.post_dfs_order[nid]->ref != nullptr);
gmap_[graph.post_dfs_order[nid]->ref] = groups[nid];
Expand Down Expand Up @@ -926,8 +957,8 @@ class FuseMutator : private ExprMutator {
}
};

Expr FuseOps(const Expr& expr, int fuse_opt_level, const IRModule& module) {
return FuseMutator().Transform(expr, fuse_opt_level);
Expr FuseOps(const Expr& expr, int fuse_opt_level, size_t max_fuse_depth, const IRModule& module) {
return FuseMutator().Transform(expr, fuse_opt_level, max_fuse_depth);
}

namespace transform {
Expand All @@ -936,7 +967,8 @@ Pass FuseOps(int fuse_opt_level) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
return Downcast<Function>(FuseOps(f, opt_level, m));
auto max_fuse_depth = pc->GetConfig("relay.FuseOps.max_depth", Integer(kMaxFusedOps));
return Downcast<Function>(FuseOps(f, opt_level, max_fuse_depth.value(), m));
};
return CreateFunctionPass(pass_func, 1, "FuseOps", {"InferType"});
}
Expand Down
66 changes: 59 additions & 7 deletions tests/python/relay/test_pass_fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,17 +587,14 @@ def test_split():

def test_fuse_max():
"""Test the constraint of number of nodes in op fusion."""
max_fused_ops = 256
# n is the number of nodes to be fused, should be less than 2*max_fused_ops
n = 300
def before():
def before(n):
x = relay.var("x", shape=(10, 20))
y = x
for i in range(n):
y = relay.exp(y)
return relay.Function([x], y)

def expected():
def expected(n, max_fused_ops):
x = relay.var("p", shape=(10, 20))
y = x
for i in range(max_fused_ops):
Expand All @@ -608,17 +605,30 @@ def expected():
z = relay.Call(f1, [x])
xx = relay.var("pp", shape=(10, 20))
yy = xx
# it is assumed that there are two fused functions
for i in range(n-max_fused_ops):
yy = relay.exp(yy)
f2 = relay.Function([xx], yy)
f2 = f2.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
zz = relay.Call(f2, [z])
return relay.Function([x], zz)

z = before()
max_fused_ops = 256
n = 300
z = before(n)
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
zz = run_opt_pass(z, transform.FuseOps())
after = run_opt_pass(expected(), transform.InferType())
after = run_opt_pass(expected(n, max_fused_ops), transform.InferType())
assert tvm.ir.structural_equal(zz, after)

max_fused_ops = 10
n = 20
z = before(n)
after = run_opt_pass(expected(n, max_fused_ops), transform.InferType())

with tvm.transform.PassContext(config={"relay.FuseOps.max_depth": max_fused_ops}):
zz = run_opt_pass(z, transform.FuseOps())

assert tvm.ir.structural_equal(zz, after)


Expand Down Expand Up @@ -722,6 +732,47 @@ def expected():
assert tvm.ir.structural_equal(m["main"], after)


def test_fuse_max_diamond():
def create_diamond(x, branch_len):
x1 = x
x2 = x
for _ in range(branch_len):
x1 = relay.exp(x1)
x2 = relay.exp(x2)
return relay.add(x1, x2)

def before(branch_len, num_diamond):
x = relay.var("x", shape=(10, 20))
out = x
for _ in range(num_diamond):
out = create_diamond(out, branch_len)
return relay.Function([x], out)

def after(branch_len, num_diamond):
def create_diamond_func(inp):
inp_var = relay.var("p", shape=(10, 20))
d = create_diamond(inp_var, branch_len)
f = relay.Function([inp_var], d)
f = f.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
return relay.Call(f, [inp])

inp = relay.var("x", shape=(10, 20))
out = inp
for _ in range(num_diamond):
out = create_diamond_func(out)
return relay.Function([inp], out)

branch_len = 5
max_fused_ops = branch_len * 2 + 1 # the number of ops in one diamond
num_diamond = 3

with tvm.transform.PassContext(config={"relay.FuseOps.max_depth": max_fused_ops}):
fused = run_opt_pass(before(branch_len, num_diamond), transform.FuseOps())

expected = run_opt_pass(after(branch_len, num_diamond), transform.InferType())
assert tvm.ir.structural_equal(fused, expected)


if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
Expand All @@ -741,3 +792,4 @@ def expected():
test_fuse_take()
test_fuse_gather_nd()
test_fuse_bcast_reduce_scalar()
test_fuse_max_diamond()

0 comments on commit 6b5176d

Please sign in to comment.