[Relay, OpFusion] Better tuple fusion implementation #3092

Merged
merged 11 commits on Apr 29, 2019
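A minimal sketch of the fusion pattern this PR targets, mirroring the new tests in tests/python/relay/test_pass_fuse_ops.py further down (the graph below is illustrative only; the actual grouping is decided by the pass):

```python
from tvm import relay

# Injective ops feeding a tuple (built implicitly by concatenate) that is
# followed by more injective ops: with this change the intermediate tuple
# no longer breaks the fused group, so the whole chain can end up in a
# single fused function (cf. test_tuple_intermediate below).
dshape = (1, 16, 64, 64)
x = relay.var("x", shape=dshape)
y1 = relay.add(x, relay.const(1, "float32"))
y2 = relay.add(x, relay.const(1, "float32"))
concat = relay.concatenate((y1, y2), axis=1)
out = relay.add(concat, relay.const(1, "float32"))

func = relay.ir_pass.infer_type(relay.Function([x], out))
fused = relay.ir_pass.fuse_ops(func, opt_level=2)
relay.build(relay.ir_pass.infer_type(fused), 'llvm')
```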
3 changes: 3 additions & 0 deletions include/tvm/relay/op_attr_types.h
@@ -49,6 +49,9 @@ enum OpPatternKind {
// Complex operation, can still fuse elemwise operations into its output.
// but cannot chain another complex op
kOutEWiseFusable = 4,
// The pattern for tuple nodes. Can fuse into subsequent injective ops,
// but treated specially
kTuple = 7,
// Opaque operation, cannot fuse anything.
kOpaque = 8
};
2 changes: 2 additions & 0 deletions python/tvm/relay/op/op.py
@@ -112,6 +112,8 @@ class OpPattern(object):
COMM_REDUCE = 3
# Complex op, can still fuse ewise into it
OUT_ELEMWISE_FUSABLE = 4
# Represents tuple node
TUPLE = 7
# Not fusable opaque op
OPAQUE = 8

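For reference, TOpPattern values are normally attached to operators via register_pattern in this module; the new TUPLE value is never registered on an op but is assigned to tuple nodes directly by the fusion pass (see fuse_ops.cc below). A small hedged sketch of inspecting a declared pattern, assuming the Op.get_attr accessor available in this TVM version and that concatenate is registered as INJECTIVE (as the fusion tests below rely on):

```python
from tvm import relay
from tvm.relay.op import OpPattern

# concatenate is declared injective, which is what lets a tuple node
# (assigned OpPattern.TUPLE inside the fusion pass) fuse into it.
concat_op = relay.op.get("concatenate")
print(concat_op.get_attr("TOpPattern"))  # expected: 2 == OpPattern.INJECTIVE
print(OpPattern.TUPLE)                   # 7, used only by fuse_ops.cc
```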
54 changes: 30 additions & 24 deletions src/relay/pass/fuse_ops.cc
@@ -267,7 +267,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
void VisitExpr_(const TupleNode* op) final {
CHECK(graph_.node_map.count(op));
Node* tuple_node = graph_.node_map.at(op);
tuple_node->pattern = kInjective;
tuple_node->pattern = kTuple;
for (const Expr& field : op->fields) {
if (field->checked_type().as<TensorTypeNode>()) {
this->Update(field, tuple_node, kInjective);
@@ -661,12 +661,36 @@ class GraphPartitioner {
// no actions needed if the current node have no dominator
if (dom_node->parent == nullptr) continue;
CHECK(!graph_node->extern_ref);
// Skip if current node is already fused to the parent.
size_t dom_parent_gindex = dom_node->parent->gnode->index;

if (phase == 2) {
// Fuse injective ops into intermediate tuples, if any
if (group_node->pattern > kInjective) continue;
Group* dom_parent_group = groups_[dom_parent_gindex];
Group* dom_root_group = dom_parent_group->FindRoot();
// If dom node group has a tuple as its root, we do not fuse tuple fields into it
if (dom_root_group->pattern == kTuple) continue;
if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= kInjective) {
// Now we know the tuple has been fused into subsequent injective ops
auto fcond = [](OpPatternKind kind, bool is_sink) {
return kind <= kInjective;
};
// dom_root_group can also be tuple, as in inception layers
// CheckPath is needed to avoid fusing two intermediate tuples
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
}
}
continue;
}

// Skip if current node is already fused to the parent.
if (groups_[dom_parent_gindex] != nullptr &&
group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) {
continue;
}
// Do not fuse into tuple for now
if (groups_[dom_parent_gindex]->pattern == kTuple) continue;
// Try to fuse current node to its post-dominator.
if (group_node->pattern == kOutEWiseFusable) {
if (phase != 0) continue;
@@ -702,7 +726,7 @@ class GraphPartitioner {
CommitFuse(graph_node, dom_node->parent->gnode);
}
}
} else if (group_node->pattern == kInjective) {
} else if (group_node->pattern == kInjective || group_node->pattern == kTuple) {
// defer injective fusion to second phase.
// so conv2d always finishes fusing.
if (phase != 1) continue;
@@ -728,7 +752,7 @@ GraphPartitioner::Partition(const IndexedForwardGraph& graph) {
// get post dominator tree
auto post_dom_tree = DominatorTree::PostDom(arena_, graph);
// run fusion algorithm.
for (int phase = 0; phase < 2; ++phase) {
for (int phase = 0; phase < 3; ++phase) {
this->RunFuse(graph, post_dom_tree, phase);
}
return std::move(groups_);
@@ -821,29 +845,11 @@ class FuseMutator : private ExprMutator {

Expr VisitExpr_(const TupleNode* tuple) {
auto* ret_group = gmap_.at(tuple)->FindRoot();
Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
if (ret_group == gmap_.at(tuple)) {
// This tuple is the root of its group. Check if all fields come from other groups.
bool isolated = new_fields.size() == ginfo_[ret_group].params.size();
for (size_t i = 0; i < new_fields.size() && isolated; ++i) {
isolated &= (new_fields[i].same_as(ginfo_[ret_group].params[i]));
}
if (isolated) {
// Do not put a isolated tuple into a function
return ExprMutator::VisitExpr_(tuple);
}
// This tuple has been fused with other ops before it
for (size_t i = 0; i < new_fields.size(); i++) {
// Copy function arguments to tuple field of the output because currently graph memory
// planer doesn't support inplace operations
if (new_fields[i].as<VarNode>()) {
auto copy = Copy(new_fields[i]);
new_fields.Set(i, copy);
}
}
return MakeNewFunction(ret_group, tuple->checked_type(), TupleNode::make(new_fields));
return ExprMutator::VisitExpr_(tuple);
}
// This tuple is an intermediate node in the group
Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
return TupleNode::make(new_fields);
}

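The FuseMutator change above means a tuple that is the root of its own group is rebuilt outside of any fused function (the old path that wrapped it in a function and copied Var arguments into its fields is removed), while a tuple in the middle of a group is simply reconstructed with the group's new fields. A hedged sketch of the root-tuple case, in the spirit of the updated test_tuple_root below; the exact grouping is decided by the pass:

```python
from tvm import relay

# Tuple as the function output, with one field being a bare parameter:
# the exp call is fused on its own, while the output tuple now stays
# outside any fused function instead of forcing a copy of `b`.
shape = (16, 16)
a = relay.var("a", shape=shape)
b = relay.var("b", shape=shape)
out = relay.Tuple([relay.exp(a), b])
f = relay.ir_pass.infer_type(relay.Function([a, b], out))
fused = relay.ir_pass.fuse_ops(f, opt_level=2)
print(relay.ir_pass.infer_type(fused))
```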
10 changes: 9 additions & 1 deletion tests/python/relay/test_backend_compile_engine.py
@@ -69,8 +69,16 @@ def test_compile_injective_with_tuple():
relay.build(func, 'llvm')


def test_compile_tuple_dup():
x = relay.var("data", shape=(16, 16))
log = relay.log(x)
output = relay.Tuple([log, log])
f = relay.Function([x], output)
relay.build(f, 'llvm')


if __name__ == "__main__":
test_compile_engine()
test_compile_placeholder_bypass()
test_compile_injective_with_tuple()

test_compile_tuple_dup()
212 changes: 170 additions & 42 deletions tests/python/relay/test_pass_fuse_ops.py
@@ -176,16 +176,14 @@ def expected(dshape):
f0 = relay.Function([x], pooled)

p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
p1 = relay.var("p1", shape=(dshape[0], dshape[1], dshape[2], dshape[3]))
p1_copy = relay.copy(p1)
upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW")
out = relay.Tuple((upsampled, p1_copy))
f1 = relay.Function([p0, p1], out)
f1 = relay.Function([p0], upsampled)

x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x])
z = relay.Call(f1, [y, x])
return relay.Function([x], z)
z = relay.Call(f1, [y])
tup = relay.Tuple((z, x))
return relay.Function([x], tup)

dshape = (1, 16, 64, 64)
z = before(dshape)
Expand All @@ -199,41 +197,6 @@ def expected(dshape):
assert relay.ir_pass.alpha_equal(zz, after)


def test_tuple_strided_slice():
"""
Test fusion case where the number of fields of tuple and
the number of parameters to the function containing the tuple are different
"""

def before(dshape):
x = relay.var("x", shape=dshape)
slice1 = relay.strided_slice(x, begin=[0, 0], end=[dshape[1]//2, dshape[1]], strides=[1,1])
slice2 = relay.strided_slice(x, begin=[dshape[1]//2, 0], end=[dshape[0], dshape[1]], strides=[1,1])
out = relay.Tuple((slice1, slice2))
return relay.Function([x], out)

def expected(dshape):
x = relay.var("x", shape=dshape)
slice1 = relay.strided_slice(x, begin=[0, 0], end=[dshape[1]//2, dshape[1]], strides=[1,1])
slice2 = relay.strided_slice(x, begin=[dshape[1]//2, 0], end=[dshape[0], dshape[1]], strides=[1,1])
out = relay.Tuple((slice1, slice2))
f0 = relay.Function([x], out)

x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x])
return relay.Function([x], y)

dshape = (64, 64)
z = before(dshape)
z = relay.ir_pass.infer_type(z)
zz = relay.ir_pass.fuse_ops(z, opt_level=0)
assert not relay.ir_pass.free_vars(zz)
zz = relay.ir_pass.fuse_ops(z, opt_level=2)
zz = relay.ir_pass.infer_type(zz)
assert not relay.ir_pass.free_vars(zz)
after = relay.ir_pass.infer_type(expected(dshape))
assert relay.ir_pass.alpha_equal(zz, after)


def test_stop_fusion():
def before(dshape):
@@ -377,13 +340,178 @@ def expected(dim):
assert relay.ir_pass.alpha_equal(zz, after)


def test_tuple_intermediate():
def before(x):
inj = relay.squeeze(x)
y1 = relay.add(inj, relay.const(1, "float32"))
tmp = relay.squeeze(inj)
tmp = relay.add(tmp, relay.const(1, "float32"))
y2 = relay.add(tmp, relay.const(1, "float32"))
y3 = relay.add(inj, relay.const(1, "float32"))
concat = relay.concatenate((y1, y2, y3), axis=1)
out_inj = relay.squeeze(concat)
out = relay.add(out_inj, relay.const(1, "float32"))
return relay.Function(relay.ir_pass.free_vars(out), out)

def expected(p0):
f0 = before(p0)
x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x])
return relay.Function([x], y)

dshape = (1, 16, 64, 64)
x = relay.var("x", shape=dshape)
z = before(x)
z = relay.ir_pass.infer_type(z)
zz = relay.ir_pass.fuse_ops(z, opt_level=0)
assert not relay.ir_pass.free_vars(zz)
zz = relay.ir_pass.fuse_ops(z, opt_level=2)
relay.build(zz, 'llvm')
zz = relay.ir_pass.infer_type(zz)
assert not relay.ir_pass.free_vars(zz)
after = relay.ir_pass.infer_type(expected(x))
assert relay.ir_pass.alpha_equal(zz, after)


def test_tuple_consecutive():
def gen_intermediate_tuple(x):
y1 = relay.add(x, relay.const(1, "float32"))
y2 = relay.add(x, relay.const(1, "float32"))
y3 = relay.add(x, relay.const(1, "float32"))
concat = relay.concatenate((y1, y2, y3), axis=1)
out = relay.add(concat, relay.const(1, "float32"))
return out

def gen_consecutive_tuple(x):
y1 = gen_intermediate_tuple(x)
y2 = gen_intermediate_tuple(x)
y3 = gen_intermediate_tuple(x)
concat = relay.concatenate((y1, y2, y3), axis=1)
return concat

def before(x):
concat = gen_consecutive_tuple(x)
pooled = relay.nn.max_pool2d(concat, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
out = relay.add(pooled, relay.const(1, "float32"))
out2 = relay.add(out, relay.const(1, "float32"))
out_tup = relay.Tuple((out, out2))
return relay.Function(relay.ir_pass.free_vars(out_tup), out_tup)

def expected(dshape):
p0 = relay.var("p0", shape=dshape)
concat = gen_consecutive_tuple(p0)
f0 = relay.Function([p0], concat)

p01 = relay.var("p01", shape=(1, dshape[1]*9, dshape[2], dshape[3]))
pooled = relay.nn.max_pool2d(p01, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
out = relay.add(pooled, relay.const(1, "float32"))
f1 = relay.Function([p01], out)

p02 = relay.var("p02", shape=(1, dshape[1]*9, dshape[2]//2, dshape[3]//2))
out = relay.add(p02, relay.const(1, "float32"))
f2 = relay.Function([p02], out)

x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x])
z = relay.Call(f1, [y])
z2 = relay.Call(f2, [z])

return relay.Function([x], relay.Tuple((z, z2)))

dshape = (1, 16, 64, 64)
x = relay.var("x", shape=dshape)
z = before(x)
z = relay.ir_pass.infer_type(z)
zz = relay.ir_pass.fuse_ops(z, opt_level=0)
assert not relay.ir_pass.free_vars(zz)
zz = relay.ir_pass.fuse_ops(z, opt_level=2)
relay.build(zz, 'llvm')
zz = relay.ir_pass.infer_type(zz)
assert not relay.ir_pass.free_vars(zz)
after = relay.ir_pass.infer_type(expected(dshape))
assert relay.ir_pass.alpha_equal(zz, after)


def test_inception_like():
def conv(data):
y = relay.nn.conv2d(data, relay.var("w"),
kernel_size=(3, 3),
padding=(1, 1),
channels=16)
return relay.nn.relu(data=y)

def inception_like(data):
c0 = conv(data)
c1 = conv(data)
return relay.concatenate((c0, c1), axis=1)

def before(dshape):
x = relay.var("x", shape=dshape)
in1 = inception_like(x)
in2 = inception_like(in1)
return relay.Function(relay.ir_pass.free_vars(in2), in2)

def expected(dshape):
p0 = relay.var("p0", shape=dshape)
c = conv(p0)
f0 = relay.Function(relay.ir_pass.free_vars(c), c)

p01 = relay.var("p01", shape=dshape)
c = conv(p01)
f1 = relay.Function(relay.ir_pass.free_vars(c), c)

p02 = relay.var("p02", shape=dshape)
p12 = relay.var("p12", shape=dshape)
concat1 = relay.concatenate((p02, p12), axis=1)
f_concat1 = relay.Function([p02, p12], concat1)

dshape2 = (dshape[0], dshape[1]*2, dshape[2], dshape[3])

p03 = relay.var("p03", shape=dshape2)
c = conv(p03)
f2 = relay.Function(relay.ir_pass.free_vars(c), c)

p04 = relay.var("p04", shape=dshape2)
c = conv(p04)
f3 = relay.Function(relay.ir_pass.free_vars(c), c)

p05 = relay.var("p05", shape=dshape)
p15 = relay.var("p15", shape=dshape)
concat2 = relay.concatenate((p05, p15), axis=1)
f_concat2 = relay.Function([p05, p15], concat2)

x = relay.var("x", shape=dshape)
c1 = relay.Call(f0, [x, relay.var("w1")])
c2 = relay.Call(f1, [x, relay.var("w2")])
concat = relay.Call(f_concat1, [c1, c2])
c3 = relay.Call(f2, [concat, relay.var("w3")])
c4 = relay.Call(f3, [concat, relay.var("w4")])
out = relay.Call(f_concat2, [c3, c4])

return relay.Function(relay.ir_pass.free_vars(out), out)

dshape = (1, 16, 64, 64)
z = before(dshape)
z = relay.ir_pass.infer_type(z)
zz = relay.ir_pass.fuse_ops(z, opt_level=0)
assert not relay.ir_pass.free_vars(zz)
zz = relay.ir_pass.fuse_ops(z, opt_level=2)
relay.build(zz, 'llvm')
zz = relay.ir_pass.infer_type(zz)
assert not relay.ir_pass.free_vars(zz)
after = relay.ir_pass.infer_type(expected(dshape))
assert relay.ir_pass.alpha_equal(zz, after)


if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
test_concatenate()
test_tuple_root()
test_tuple_strided_slice()
test_stop_fusion()
test_fuse_myia_regression()
test_fuse_tuple_get_elemwise()
test_tuple_get_root()
test_tuple_intermediate()
test_tuple_consecutive()
test_inception_like()