Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Add support for TupleGetItem in op fusion #2914

Merged
merged 7 commits into from
Mar 29, 2019
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions src/relay/pass/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,9 +261,23 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
}

void VisitExpr_(const TupleGetItemNode* op) final {
CHECK(graph_.node_map.count(op));
Node* node = graph_.node_map.at(op);
this->Update(op->tuple, node, kOpaque);
auto tuple_type = op->tuple->checked_type().as<TupleTypeNode>();
CHECK(tuple_type);
bool has_reference = false;
for (auto ty : tuple_type->fields) {
if (ty.as<RefTypeNode>()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if there is a reference inside? does it get recursively handled?

Copy link
Member Author

@masahi masahi Mar 28, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tuple fields are recursively visited via call to ExprVisitor::VisitExpr_(op) below, even if it is a Reference nodes. Ref nodes are never fused with its parent TupleGetItemNode.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add the comment block to explain this so that future developers won't be confused here

has_reference = true;
break;
}
}
if (has_reference) {
this->Update(op->tuple, nullptr, kOpaque);
} else {
CHECK(graph_.node_map.count(op));
Node* node = graph_.node_map.at(op);
node->pattern = kInjective;
this->Update(op->tuple, node, kInjective);
}
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
}
Expand Down Expand Up @@ -809,6 +823,23 @@ class FuseMutator : private ExprMutator {
return TupleNode::make(new_fields);
}

Expr VisitExpr_(const TupleGetItemNode* tuple_get) {
auto* ret_group = gmap_.at(tuple_get)->FindRoot();
auto new_tuple = GetNewArguments({tuple_get->tuple}, ret_group)[0];
auto new_node = TupleGetItemNode::make(new_tuple, tuple_get->index);
if (ret_group == gmap_.at(tuple_get)) {
if (gmap_.at(tuple_get->tuple.get())->FindRoot() != ret_group) {
// Isolated. This case occurs when tuple is created by an Opaque op
// e.g. multibox_transform_loc
return ExprMutator::VisitExpr_(tuple_get);
}
// A new function whose output is a tuple field access
return MakeNewFunction(ret_group, tuple_get->checked_type(), new_node);
}
// This is an intermediate node in the group
return new_node;
masahi marked this conversation as resolved.
Show resolved Hide resolved
}

Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
const GroupInfo& ginfo = ginfo_[group];
auto func = FunctionNode::make(ginfo.params, body, ret_type, {});
Expand Down
39 changes: 39 additions & 0 deletions tests/python/relay/test_backend_graph_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.op import add
from tvm.relay.module import Module
from tvm.relay.testing.config import ctx_list

# @tq, @jr should we put this in testing ns?
def check_rts(expr, args, expected_result, mod=None):
Expand Down Expand Up @@ -127,9 +128,47 @@ def test_plan_memory():
assert len(device_types) == 1


def test_gru():
def gru(rnn_dim):
X = relay.var("X", shape=(1, rnn_dim))
W = relay.var("y", shape=(3 * rnn_dim, rnn_dim))
matmul = relay.nn.dense(X, W)
splitted = relay.split(matmul, indices_or_sections=3, axis=1)
out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
return relay.Function([X, W], out)

def sigmoid(x):
return 1 / (1 + np.exp(-x))

def gru_numpy(X, W):
prod = np.dot(X, W.transpose())
splits = np.split(prod, indices_or_sections=3, axis=1)
return sigmoid(splits[0]) + np.tanh(splits[1]) * np.exp(splits[2])

dtype = "float32"
rnn_dim = 1000
x = np.random.rand(1, rnn_dim).astype(dtype)
y = np.random.rand(3*rnn_dim, rnn_dim).astype(dtype) * 0.01 - 0.005
out_shape = (1, rnn_dim)
z = gru(rnn_dim)

for target, ctx in ctx_list():
with relay.build_config(opt_level=2):
graph, lib, params = relay.build(z, target)
m = graph_runtime.create(graph, lib, ctx)
m.set_input("X", tvm.nd.array(x.astype(dtype)))
m.set_input("y", tvm.nd.array(y.astype(dtype)))
m.set_input(**params)
m.run()
out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
ref = gru_numpy(x, y)
tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)


if __name__ == "__main__":
test_plan_memory()
test_with_params()
test_add_op_scalar()
test_add_op_tensor()
test_add_op_broadcast()
test_gru()
78 changes: 77 additions & 1 deletion tests/python/relay/test_pass_fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,6 @@ def expected(dshape):
assert not relay.ir_pass.free_vars(zz)
after = relay.ir_pass.infer_type(expected(dshape))
assert relay.ir_pass.alpha_equal(zz, after)
print(zz.astext())


def test_stop_fusion():
Expand Down Expand Up @@ -287,6 +286,81 @@ def expected(dshape, dtype):
assert relay.ir_pass.alpha_equal(f, after)


def test_gru():
masahi marked this conversation as resolved.
Show resolved Hide resolved
def before(rnn_dim):
X = relay.var("X", shape=(1, rnn_dim))
W = relay.var("W", shape=(3 * rnn_dim, rnn_dim))
matmul = relay.nn.dense(X, W)
splitted = relay.split(matmul, indices_or_sections=3, axis=1)
out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
return relay.Function([X, W], out)

def expected(rnn_dim):
p0 = relay.var("p0", shape=(1, rnn_dim))
p1 = relay.var("p1", shape=(3 * rnn_dim, rnn_dim))
matmul = relay.nn.dense(p0, p1)
f0 = relay.Function([p0, p1], matmul)

p01 = relay.var("p01", shape=(1, 3 * rnn_dim))
splitted = relay.split(p01, indices_or_sections=3, axis=1)
out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
f1 = relay.Function([p01], out)

X = relay.var("X", shape=(1, rnn_dim))
W = relay.var("W", shape=(3 * rnn_dim, rnn_dim))
y = relay.Call(f0, [X, W])
z = relay.Call(f1, [y])
return relay.Function([X, W], z)

rnn_dim = 10
z = before(rnn_dim)
z = relay.ir_pass.infer_type(z)
zz = relay.ir_pass.fuse_ops(z, opt_level=0)
assert not relay.ir_pass.free_vars(zz)
zz = relay.ir_pass.fuse_ops(z, opt_level=2)
zz = relay.ir_pass.infer_type(zz)
assert not relay.ir_pass.free_vars(zz)
after = relay.ir_pass.infer_type(expected(rnn_dim))
assert relay.ir_pass.alpha_equal(zz, after)


def test_tuple_get_root():
def before(dim):
X = relay.var("X", shape=(1, 3 * dim))
W = relay.var("W", shape=(dim, dim))
splitted = relay.split(X, indices_or_sections=3, axis=1)
out = relay.nn.dense(splitted[0], W)
return relay.Function([X, W], out)

def expected(dim):
p0 = relay.var("p0", shape=(1, 3 * dim))
splitted = relay.split(p0, indices_or_sections=3, axis=1)
out = splitted[0]
f0 = relay.Function([p0], out)

p01 = relay.var("p01", shape=(1, dim))
p1 = relay.var("p1", shape=(dim, dim))
out = relay.nn.dense(p01, p1)
f1 = relay.Function([p01, p1], out)

X = relay.var("X", shape=(1, 3 * dim))
W = relay.var("W", shape=(dim, dim))
y = relay.Call(f0, [X])
z = relay.Call(f1, [y, W])
return relay.Function([X, W], z)

dim = 10
z = before(dim)
z = relay.ir_pass.infer_type(z)
zz = relay.ir_pass.fuse_ops(z, opt_level=0)
assert not relay.ir_pass.free_vars(zz)
zz = relay.ir_pass.fuse_ops(z, opt_level=2)
zz = relay.ir_pass.infer_type(zz)
assert not relay.ir_pass.free_vars(zz)
after = relay.ir_pass.infer_type(expected(dim))
assert relay.ir_pass.alpha_equal(zz, after)


if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
Expand All @@ -295,3 +369,5 @@ def expected(dshape, dtype):
test_tuple_strided_slice()
test_stop_fusion()
test_fuse_myia_regression()
test_gru()
test_tuple_get_root()