Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OpFusion] Make the max number of fused ops configurable #6327

Merged
merged 11 commits into from
Aug 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 42 additions & 10 deletions src/relay/transforms/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ constexpr uint32_t kMaxFusedOps = 256;

static const Op& stop_fusion_op = Op::Get("annotation.stop_fusion");

TVM_REGISTER_PASS_CONFIG_OPTION("relay.FuseOps.max_depth", Integer);

/*!
* \brief Indexed data flow graph in forward direction.
* This is a temporary data structure used for operator fusion analysis.
Expand Down Expand Up @@ -496,8 +498,8 @@ DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForward
*/
class GraphPartitioner {
public:
explicit GraphPartitioner(support::Arena* arena, int opt_level)
: arena_(arena), opt_level_(opt_level) {}
explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth)
: arena_(arena), opt_level_(opt_level), max_fuse_depth_(max_fuse_depth) {}
/*!
* \brief Group as a union find data structure.
*/
Expand Down Expand Up @@ -549,6 +551,8 @@ class GraphPartitioner {
support::Arena* arena_;
/*! \brief optimization level for fuse operation. */
int opt_level_;
/*! \brief The maximum number of operations in one fused function */
size_t max_fuse_depth_;
/*! \brief The internal groups. */
std::vector<Group*> groups_;
/*! \brief internal field used for deduplication */
Expand Down Expand Up @@ -604,11 +608,11 @@ class GraphPartitioner {
* \param parent The parent group.
*/
void MergeFromTo(Group* child, Group* parent) {
// update the number of nodes of the parent group
parent->num_nodes += child->num_nodes;
child = child->FindRoot();
parent = parent->FindRoot();
if (child == parent) return;
// update the number of nodes of the parent group
parent->num_nodes += child->num_nodes;
child->parent = parent;
// update master ref and pattern
if (child->master_ref != nullptr) {
Expand Down Expand Up @@ -643,6 +647,32 @@ class GraphPartitioner {
CommitFuse_(src, sink, target);
}

// Recursively accumulate num_nodes of the groups of all nodes reached from src
// by following output edges, stopping at sink.
// - sink itself contributes 0.
// - A node already recorded in visited_ contributes 0, so a node reachable
//   through several paths is only counted once.
// This function only adds to visited_; it never clears it.
size_t CountNodesUptoSink_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) {
  if (src == sink) return 0;
  if (visited_.count(src)) return 0;
  visited_.insert(src);

  Group* gnode = groups_[src->index];
  CHECK(gnode != nullptr);

  // Start from this node's own group, then add whatever lies downstream.
  auto total = gnode->num_nodes;
  auto edge = src->outputs.head;
  while (edge != nullptr) {
    total += CountNodesUptoSink_(edge->value.node, sink);
    edge = edge->next;
  }
  return total;
}

// Compute how many nodes the fused subgraph would contain if child were
// additionally fused into the group that dom_parent belongs to.
// dom_parent is already known to be part of that subgraph. In a diamond
// structure several paths can connect child to dom_parent, so every
// intermediate node between the two is included in the count.
// dom_parent may itself be an intermediate node of the subgraph, which is why
// the size is read from the root of its group (FindRoot()) rather than from
// its own group entry.
size_t CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child,
                                   IndexedForwardGraph::Node* dom_parent) {
  Group* target = groups_[dom_parent->index];
  // Reset the visited set so that this query starts from a clean traversal state.
  visited_.clear();
  CHECK(child != dom_parent);
  Group* target_root = target->FindRoot();
  const size_t nodes_up_to_parent = CountNodesUptoSink_(child, dom_parent);
  return target_root->num_nodes + nodes_up_to_parent;
}

// Initialize the groups.
void InitGroups(const IndexedForwardGraph& graph) {
groups_.resize(graph.post_dfs_order.size());
Expand Down Expand Up @@ -675,7 +705,8 @@ class GraphPartitioner {
size_t dom_parent_gindex = dom_node->parent->gnode->index;

// refuse the fusion if too many ops are going to be fused together
if (groups_[dom_parent_gindex]->num_nodes + group_node->num_nodes > kMaxFusedOps) continue;
if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_)
continue;

if (phase == 2) {
// Fuse injective ops into intermediate tuples, if any
Expand Down Expand Up @@ -769,10 +800,10 @@ std::vector<GraphPartitioner::Group*> GraphPartitioner::Partition(
class FuseMutator : private ExprMutator {
public:
// Run the transform
Expr Transform(const Expr& body, int fuse_opt_level) {
Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth) {
// setup the group map.
auto graph = IndexedForwardGraph::Create(&arena_, body);
auto groups = GraphPartitioner(&arena_, fuse_opt_level).Partition(graph);
auto groups = GraphPartitioner(&arena_, fuse_opt_level, max_fuse_depth).Partition(graph);
for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) {
CHECK(graph.post_dfs_order[nid]->ref != nullptr);
gmap_[graph.post_dfs_order[nid]->ref] = groups[nid];
Expand Down Expand Up @@ -926,8 +957,8 @@ class FuseMutator : private ExprMutator {
}
};

Expr FuseOps(const Expr& expr, int fuse_opt_level, const IRModule& module) {
return FuseMutator().Transform(expr, fuse_opt_level);
Expr FuseOps(const Expr& expr, int fuse_opt_level, size_t max_fuse_depth, const IRModule& module) {
return FuseMutator().Transform(expr, fuse_opt_level, max_fuse_depth);
}

namespace transform {
Expand All @@ -936,7 +967,8 @@ Pass FuseOps(int fuse_opt_level) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
return Downcast<Function>(FuseOps(f, opt_level, m));
auto max_fuse_depth = pc->GetConfig("relay.FuseOps.max_depth", Integer(kMaxFusedOps));
return Downcast<Function>(FuseOps(f, opt_level, max_fuse_depth.value(), m));
};
return CreateFunctionPass(pass_func, 1, "FuseOps", {"InferType"});
}
Expand Down
66 changes: 59 additions & 7 deletions tests/python/relay/test_pass_fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,17 +587,14 @@ def test_split():

def test_fuse_max():
"""Test the constraint of number of nodes in op fusion."""
max_fused_ops = 256
# n is the number of nodes to be fused, should be less than 2*max_fused_ops
n = 300
def before():
def before(n):
x = relay.var("x", shape=(10, 20))
y = x
for i in range(n):
y = relay.exp(y)
return relay.Function([x], y)

def expected():
def expected(n, max_fused_ops):
x = relay.var("p", shape=(10, 20))
y = x
for i in range(max_fused_ops):
Expand All @@ -608,17 +605,30 @@ def expected():
z = relay.Call(f1, [x])
xx = relay.var("pp", shape=(10, 20))
yy = xx
# it is assumed that there are two fused functions
for i in range(n-max_fused_ops):
yy = relay.exp(yy)
f2 = relay.Function([xx], yy)
f2 = f2.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
zz = relay.Call(f2, [z])
return relay.Function([x], zz)

z = before()
max_fused_ops = 256
n = 300
z = before(n)
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
zz = run_opt_pass(z, transform.FuseOps())
after = run_opt_pass(expected(), transform.InferType())
after = run_opt_pass(expected(n, max_fused_ops), transform.InferType())
assert tvm.ir.structural_equal(zz, after)

max_fused_ops = 10
n = 20
z = before(n)
after = run_opt_pass(expected(n, max_fused_ops), transform.InferType())

with tvm.transform.PassContext(config={"relay.FuseOps.max_depth": max_fused_ops}):
zz = run_opt_pass(z, transform.FuseOps())

assert tvm.ir.structural_equal(zz, after)


Expand Down Expand Up @@ -722,6 +732,47 @@ def expected():
assert tvm.ir.structural_equal(m["main"], after)


def test_fuse_max_diamond():
    """Check FuseOps on a chain of diamonds when max_depth equals the op count of one diamond.

    The expected result wraps every diamond in its own primitive function.
    """

    def create_diamond(x, branch_len):
        # Two parallel chains of branch_len exp ops from the same input, joined by one add.
        left, right = x, x
        for _ in range(branch_len):
            left = relay.exp(left)
            right = relay.exp(right)
        return relay.add(left, right)

    def before(branch_len, num_diamond):
        # Unfused input: num_diamond diamonds stacked one after another.
        x = relay.var("x", shape=(10, 20))
        result = x
        for _ in range(num_diamond):
            result = create_diamond(result, branch_len)
        return relay.Function([x], result)

    def after(branch_len, num_diamond):
        # Expected output: a chain of calls, one primitive function per diamond.
        def create_diamond_func(inp):
            param = relay.var("p", shape=(10, 20))
            func = relay.Function([param], create_diamond(param, branch_len))
            func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
            return relay.Call(func, [inp])

        inp = relay.var("x", shape=(10, 20))
        result = inp
        for _ in range(num_diamond):
            result = create_diamond_func(result)
        return relay.Function([inp], result)

    branch_len = 5
    # the number of ops in one diamond: two branches of exp plus the final add
    max_fused_ops = branch_len * 2 + 1
    num_diamond = 3

    pass_config = {"relay.FuseOps.max_depth": max_fused_ops}
    with tvm.transform.PassContext(config=pass_config):
        fused = run_opt_pass(before(branch_len, num_diamond), transform.FuseOps())

    expected = run_opt_pass(after(branch_len, num_diamond), transform.InferType())
    assert tvm.ir.structural_equal(fused, expected)


if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
Expand All @@ -741,3 +792,4 @@ def expected():
test_fuse_take()
test_fuse_gather_nd()
test_fuse_bcast_reduce_scalar()
test_fuse_max_diamond()