From d97a14023ef3df0ab0eb3b328519b8355dc6acb5 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Wed, 27 Mar 2019 01:46:50 +0900
Subject: [PATCH 1/7] OpFusion: add support for TupleGetItem node

---
 src/relay/pass/fuse_ops.cc               | 18 +++++-
 tests/python/relay/test_pass_fuse_ops.py | 78 +++++++++++++++++++++++-
 2 files changed, 94 insertions(+), 2 deletions(-)

diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc
index 66ff9caf4ae4..4cefb7930612 100644
--- a/src/relay/pass/fuse_ops.cc
+++ b/src/relay/pass/fuse_ops.cc
@@ -263,7 +263,12 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
   void VisitExpr_(const TupleGetItemNode* op) final {
     CHECK(graph_.node_map.count(op));
     Node* node = graph_.node_map.at(op);
-    this->Update(op->tuple, node, kOpaque);
+    node->pattern = kInjective;
+    if (op->tuple->checked_type().as<TupleTypeNode>()) {
+      this->Update(op->tuple, node, kInjective);
+    } else {
+      this->Update(op->tuple, nullptr, kOpaque);
+    }
     ExprVisitor::VisitExpr_(op);
     this->AddNode(op);
   }
@@ -809,6 +814,17 @@ class FuseMutator : private ExprMutator {
     return TupleNode::make(new_fields);
   }
 
+  Expr VisitExpr_(const TupleGetItemNode* tuple_get) {
+    auto new_node = TupleGetItemNode::make(this->Mutate(tuple_get->tuple), tuple_get->index);
+    auto* ret_group = gmap_.at(tuple_get)->FindRoot();
+    if (ret_group == gmap_.at(tuple_get)) {
+      // unlike the tuple case above, this node should never be isolated
+      return MakeNewFunction(ret_group, tuple_get->checked_type(), new_node);
+    }
+    // This is an intermediate node in the group
+    return new_node;
+  }
+
   Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
     const GroupInfo& ginfo = ginfo_[group];
     auto func = FunctionNode::make(ginfo.params, body, ret_type, {});
diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py
index 634d69bae823..558314099863 100644
--- a/tests/python/relay/test_pass_fuse_ops.py
+++ b/tests/python/relay/test_pass_fuse_ops.py
@@ -217,7 +217,6 @@ def expected(dshape):
     assert not relay.ir_pass.free_vars(zz)
     after = relay.ir_pass.infer_type(expected(dshape))
     assert relay.ir_pass.alpha_equal(zz, after)
-    print(zz.astext())
 
 
 def test_stop_fusion():
@@ -287,6 +286,81 @@ def expected(dshape, dtype):
     assert relay.ir_pass.alpha_equal(f, after)
 
 
+def test_gru():
+    def before(rnn_dim):
+        X = relay.var("X", shape=(1, rnn_dim))
+        W = relay.var("W", shape=(3 * rnn_dim, rnn_dim))
+        matmul = relay.nn.dense(X, W)
+        splitted = relay.split(matmul, indices_or_sections=3, axis=1)
+        out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
+        return relay.Function([X, W], out)
+
+    def expected(rnn_dim):
+        p0 = relay.var("p0", shape=(1, rnn_dim))
+        p1 = relay.var("p1", shape=(3 * rnn_dim, rnn_dim))
+        matmul = relay.nn.dense(p0, p1)
+        f0 = relay.Function([p0, p1], matmul)
+
+        p01 = relay.var("p01", shape=(1, 3 * rnn_dim))
+        splitted = relay.split(p01, indices_or_sections=3, axis=1)
+        out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
+        f1 = relay.Function([p01], out)
+
+        X = relay.var("X", shape=(1, rnn_dim))
+        W = relay.var("W", shape=(3 * rnn_dim, rnn_dim))
+        y = relay.Call(f0, [X, W])
+        z = relay.Call(f1, [y])
+        return relay.Function([X, W], z)
+
+    rnn_dim = 10
+    z = before(rnn_dim)
+    z = relay.ir_pass.infer_type(z)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
+    assert not relay.ir_pass.free_vars(zz)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+    zz = relay.ir_pass.infer_type(zz)
+    assert not relay.ir_pass.free_vars(zz)
+    after = relay.ir_pass.infer_type(expected(rnn_dim))
+    assert relay.ir_pass.alpha_equal(zz, after)
+
+
+def test_tuple_get_root():
+    def before(dim):
+        X = relay.var("X", shape=(1, 3 * dim))
+        W = relay.var("W", shape=(dim, dim))
+        splitted = relay.split(X, indices_or_sections=3, axis=1)
+        out = relay.nn.dense(splitted[0], W)
+        return relay.Function([X, W], out)
+
+    def expected(dim):
+        p0 = relay.var("p0", shape=(1, 3 * dim))
+        splitted = relay.split(p0, indices_or_sections=3, axis=1)
+        out = splitted[0]
+        f0 = relay.Function([p0], out)
+
+        p01 = relay.var("p01", shape=(1, dim))
+        p1 = relay.var("p1", shape=(dim, dim))
+        out = relay.nn.dense(p01, p1)
+        f1 = relay.Function([p01, p1], out)
+
+        X = relay.var("X", shape=(1, 3 * dim))
+        W = relay.var("W", shape=(dim, dim))
+        y = relay.Call(f0, [X])
+        z = relay.Call(f1, [y, W])
+        return relay.Function([X, W], z)
+
+    dim = 10
+    z = before(dim)
+    z = relay.ir_pass.infer_type(z)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
+    assert not relay.ir_pass.free_vars(zz)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+    zz = relay.ir_pass.infer_type(zz)
+    assert not relay.ir_pass.free_vars(zz)
+    after = relay.ir_pass.infer_type(expected(dim))
+    assert relay.ir_pass.alpha_equal(zz, after)
+
+
 if __name__ == "__main__":
     test_fuse_simple()
     test_conv2d_fuse()
@@ -295,3 +369,5 @@ def expected(dshape, dtype):
     test_tuple_strided_slice()
     test_stop_fusion()
     test_fuse_myia_regression()
+    test_gru()
+    test_tuple_get_root()

From d17aabbaf55a4912f888bd5a38452adef3a2d236 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Thu, 28 Mar 2019 00:20:17 +0900
Subject: [PATCH 2/7] add gru runtime test

---
 .../relay/test_backend_graph_runtime.py | 39 +++++++++++++++++++
 1 file changed, 39 insertions(+)

diff --git a/tests/python/relay/test_backend_graph_runtime.py b/tests/python/relay/test_backend_graph_runtime.py
index 434b0e6ddfa1..9a68f338edcb 100644
--- a/tests/python/relay/test_backend_graph_runtime.py
+++ b/tests/python/relay/test_backend_graph_runtime.py
@@ -7,6 +7,7 @@
 from tvm.relay.scope_builder import ScopeBuilder
 from tvm.relay.op import add
 from tvm.relay.module import Module
+from tvm.relay.testing.config import ctx_list
 
 # @tq, @jr should we put this in testing ns?
 def check_rts(expr, args, expected_result, mod=None):
@@ -127,9 +128,47 @@ def test_plan_memory():
     assert len(device_types) == 1
 
 
+def test_gru():
+    def gru(rnn_dim):
+        X = relay.var("X", shape=(1, rnn_dim))
+        W = relay.var("y", shape=(3 * rnn_dim, rnn_dim))
+        matmul = relay.nn.dense(X, W)
+        splitted = relay.split(matmul, indices_or_sections=3, axis=1)
+        out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
+        return relay.Function([X, W], out)
+
+    def sigmoid(x):
+        return 1 / (1 + np.exp(-x))
+
+    def gru_numpy(X, W):
+        prod = np.dot(X, W.transpose())
+        splits = np.split(prod, indices_or_sections=3, axis=1)
+        return sigmoid(splits[0]) + np.tanh(splits[1]) * np.exp(splits[2])
+
+    dtype = "float32"
+    rnn_dim = 1000
+    x = np.random.rand(1, rnn_dim).astype(dtype)
+    y = np.random.rand(3*rnn_dim, rnn_dim).astype(dtype) * 0.01 - 0.005
+    out_shape = (1, rnn_dim)
+    z = gru(rnn_dim)
+
+    for target, ctx in ctx_list():
+        with relay.build_config(opt_level=2):
+            graph, lib, params = relay.build(z, target)
+        m = graph_runtime.create(graph, lib, ctx)
+        m.set_input("X", tvm.nd.array(x.astype(dtype)))
+        m.set_input("y", tvm.nd.array(y.astype(dtype)))
+        m.set_input(**params)
+        m.run()
+        out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
+        ref = gru_numpy(x, y)
+        tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
+
+
 if __name__ == "__main__":
     test_plan_memory()
     test_with_params()
     test_add_op_scalar()
     test_add_op_tensor()
     test_add_op_broadcast()
+    test_gru()

From af66383f3e5d8e9725996f327be2565897c03ddd Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Thu, 28 Mar 2019 00:44:33 +0900
Subject: [PATCH 3/7] fix for reference handling and isolated cases

---
 src/relay/pass/fuse_ops.cc | 31 +++++++++++++++++++++++--------
 1 file changed, 23 insertions(+), 8 deletions(-)

diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc
index 4cefb7930612..9aef2752a3cd 100644
--- a/src/relay/pass/fuse_ops.cc
+++ b/src/relay/pass/fuse_ops.cc
@@ -261,13 +261,22 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
   }
 
   void VisitExpr_(const TupleGetItemNode* op) final {
-    CHECK(graph_.node_map.count(op));
-    Node* node = graph_.node_map.at(op);
-    node->pattern = kInjective;
-    if (op->tuple->checked_type().as<TupleTypeNode>()) {
-      this->Update(op->tuple, node, kInjective);
-    } else {
+    auto tuple_type = op->tuple->checked_type().as<TupleTypeNode>();
+    CHECK(tuple_type);
+    bool has_reference = false;
+    for (auto ty : tuple_type->fields) {
+      if (auto ref_ty = ty.as<RefTypeNode>()) {
+        has_reference = true;
+        break;
+      }
+    }
+    if (has_reference) {
       this->Update(op->tuple, nullptr, kOpaque);
+    } else {
+      CHECK(graph_.node_map.count(op));
+      Node* node = graph_.node_map.at(op);
+      node->pattern = kInjective;
+      this->Update(op->tuple, node, kInjective);
     }
     ExprVisitor::VisitExpr_(op);
     this->AddNode(op);
@@ -815,10 +824,16 @@ class FuseMutator : private ExprMutator {
   }
 
   Expr VisitExpr_(const TupleGetItemNode* tuple_get) {
-    auto new_node = TupleGetItemNode::make(this->Mutate(tuple_get->tuple), tuple_get->index);
     auto* ret_group = gmap_.at(tuple_get)->FindRoot();
+    auto new_tuple = GetNewArguments({tuple_get->tuple}, ret_group)[0];
+    auto new_node = TupleGetItemNode::make(new_tuple, tuple_get->index);
     if (ret_group == gmap_.at(tuple_get)) {
-      // unlike the tuple case above, this node should never be isolated
+      if (gmap_.at(tuple_get->tuple.get())->FindRoot() != ret_group) {
+        // Isolated. This case occurs when the tuple is created by an Opaque op,
+        // e.g. multibox_transform_loc
+        return ExprMutator::VisitExpr_(tuple_get);
+      }
+      // A new function whose output is a tuple field access
       return MakeNewFunction(ret_group, tuple_get->checked_type(), new_node);
     }
     // This is an intermediate node in the group

From abef41da9271fdc6bc2c99f9a80517d9bbc304a1 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Thu, 28 Mar 2019 07:51:16 +0900
Subject: [PATCH 4/7] suppress warning

---
 src/relay/pass/fuse_ops.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc
index 9aef2752a3cd..da566a18a099 100644
--- a/src/relay/pass/fuse_ops.cc
+++ b/src/relay/pass/fuse_ops.cc
@@ -265,7 +265,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
     CHECK(tuple_type);
     bool has_reference = false;
     for (auto ty : tuple_type->fields) {
-      if (auto ref_ty = ty.as<RefTypeNode>()) {
+      if (ty.as<RefTypeNode>()) {
         has_reference = true;
         break;
       }

From 7bd67aa2d8129ee24f8763e97d7e265ef5fcd003 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 29 Mar 2019 08:10:54 +0900
Subject: [PATCH 5/7] rename

---
 .../relay/test_backend_graph_runtime.py  | 12 ++++----
 tests/python/relay/test_pass_fuse_ops.py | 28 +++++++++----------
 2 files changed, 20 insertions(+), 20 deletions(-)

diff --git a/tests/python/relay/test_backend_graph_runtime.py b/tests/python/relay/test_backend_graph_runtime.py
index 9a68f338edcb..56da263c9b4e 100644
--- a/tests/python/relay/test_backend_graph_runtime.py
+++ b/tests/python/relay/test_backend_graph_runtime.py
@@ -128,8 +128,8 @@ def test_plan_memory():
     assert len(device_types) == 1
 
 
-def test_gru():
-    def gru(rnn_dim):
+def test_gru_like():
+    def unit(rnn_dim):
         X = relay.var("X", shape=(1, rnn_dim))
         W = relay.var("y", shape=(3 * rnn_dim, rnn_dim))
         matmul = relay.nn.dense(X, W)
@@ -140,7 +140,7 @@ def gru(rnn_dim):
     def sigmoid(x):
         return 1 / (1 + np.exp(-x))
 
-    def gru_numpy(X, W):
+    def unit_numpy(X, W):
         prod = np.dot(X, W.transpose())
         splits = np.split(prod, indices_or_sections=3, axis=1)
         return sigmoid(splits[0]) + np.tanh(splits[1]) * np.exp(splits[2])
@@ -150,7 +150,7 @@ def unit_numpy(X, W):
     x = np.random.rand(1, rnn_dim).astype(dtype)
     y = np.random.rand(3*rnn_dim, rnn_dim).astype(dtype) * 0.01 - 0.005
     out_shape = (1, rnn_dim)
-    z = gru(rnn_dim)
+    z = unit(rnn_dim)
 
     for target, ctx in ctx_list():
         with relay.build_config(opt_level=2):
@@ -161,7 +161,7 @@ def unit_numpy(X, W):
         m.set_input(**params)
         m.run()
         out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
-        ref = gru_numpy(x, y)
+        ref = unit_numpy(x, y)
         tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
 
 
@@ -171,4 +171,4 @@ def unit_numpy(X, W):
     test_add_op_scalar()
     test_add_op_tensor()
     test_add_op_broadcast()
-    test_gru()
+    test_gru_like()
diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py
index 558314099863..5df6ad7d5226 100644
--- a/tests/python/relay/test_pass_fuse_ops.py
+++ b/tests/python/relay/test_pass_fuse_ops.py
@@ -286,41 +286,41 @@ def expected(dshape, dtype):
     assert relay.ir_pass.alpha_equal(f, after)
 
 
-def test_gru():
-    def before(rnn_dim):
-        X = relay.var("X", shape=(1, rnn_dim))
-        W = relay.var("W", shape=(3 * rnn_dim, rnn_dim))
+def test_fuse_tuple_get_elemwise():
+    def before(dim):
+        X = relay.var("X", shape=(1, dim))
+        W = relay.var("W", shape=(3 * dim, dim))
         matmul = relay.nn.dense(X, W)
         splitted = relay.split(matmul, indices_or_sections=3, axis=1)
         out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
         return relay.Function([X, W], out)
 
-    def expected(rnn_dim):
-        p0 = relay.var("p0", shape=(1, rnn_dim))
-        p1 = relay.var("p1", shape=(3 * rnn_dim, rnn_dim))
+    def expected(dim):
+        p0 = relay.var("p0", shape=(1, dim))
+        p1 = relay.var("p1", shape=(3 * dim, dim))
         matmul = relay.nn.dense(p0, p1)
         f0 = relay.Function([p0, p1], matmul)
 
-        p01 = relay.var("p01", shape=(1, 3 * rnn_dim))
+        p01 = relay.var("p01", shape=(1, 3 * dim))
         splitted = relay.split(p01, indices_or_sections=3, axis=1)
         out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
         f1 = relay.Function([p01], out)
 
-        X = relay.var("X", shape=(1, rnn_dim))
-        W = relay.var("W", shape=(3 * rnn_dim, rnn_dim))
+        X = relay.var("X", shape=(1, dim))
+        W = relay.var("W", shape=(3 * dim, dim))
         y = relay.Call(f0, [X, W])
         z = relay.Call(f1, [y])
         return relay.Function([X, W], z)
 
-    rnn_dim = 10
-    z = before(rnn_dim)
+    dim = 10
+    z = before(dim)
     z = relay.ir_pass.infer_type(z)
     zz = relay.ir_pass.fuse_ops(z, opt_level=0)
     assert not relay.ir_pass.free_vars(zz)
     zz = relay.ir_pass.fuse_ops(z, opt_level=2)
     zz = relay.ir_pass.infer_type(zz)
     assert not relay.ir_pass.free_vars(zz)
-    after = relay.ir_pass.infer_type(expected(rnn_dim))
+    after = relay.ir_pass.infer_type(expected(dim))
     assert relay.ir_pass.alpha_equal(zz, after)
 
 
@@ -369,5 +369,5 @@ def expected(dim):
     test_tuple_strided_slice()
     test_stop_fusion()
     test_fuse_myia_regression()
-    test_gru()
+    test_fuse_tuple_get_elemwise()
     test_tuple_get_root()

From f64414c722f8e5afe8d09047e12382f164e3fe9d Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 29 Mar 2019 12:14:27 +0900
Subject: [PATCH 6/7] add comment on reference handling

---
 src/relay/pass/fuse_ops.cc | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc
index da566a18a099..f47fc719f8b4 100644
--- a/src/relay/pass/fuse_ops.cc
+++ b/src/relay/pass/fuse_ops.cc
@@ -263,6 +263,11 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
   void VisitExpr_(const TupleGetItemNode* op) final {
     auto tuple_type = op->tuple->checked_type().as<TupleTypeNode>();
     CHECK(tuple_type);
+    // If this tuple contains a reference type, and we fuse TupleGetItem and
+    // the reference, a fused function will have a tuple containing a reference
+    // in its parameters. But when TVM lowers a fused function, it expects all
+    // arguments to be a Tensor or a tuple containing only Tensors.
+    // To avoid modifying codegen logic, we do not allow fusing through a reference.
     bool has_reference = false;
     for (auto ty : tuple_type->fields) {
       if (ty.as<RefTypeNode>()) {

From e4383f68487a70fb532542ce40e48086d2a8ac8d Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 29 Mar 2019 12:18:57 +0900
Subject: [PATCH 7/7] add more comment

---
 src/relay/pass/fuse_ops.cc | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc
index f47fc719f8b4..c7b16da9036c 100644
--- a/src/relay/pass/fuse_ops.cc
+++ b/src/relay/pass/fuse_ops.cc
@@ -268,6 +268,8 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
     // in its parameters. But when TVM lowers a fused function, it expects all
     // arguments to be a Tensor or a tuple containing only Tensors.
     // To avoid modifying codegen logic, we do not allow fusing through a reference.
+    // The reference itself will be recursively visited via the call to
+    // ExprVisitor::VisitExpr_(op) below and the corresponding visitor methods.
     bool has_reference = false;
     for (auto ty : tuple_type->fields) {
       if (ty.as<RefTypeNode>()) {
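
Illustrative note (not part of the patch series): the sketch below reuses only the relay.ir_pass API already exercised by the tests above. It builds the dense + split + TupleGetItem + element-wise pattern that this series lets the fusion pass group, and prints the fused IR for inspection. The helper name and parameter here are ours for illustration, and the exact textual output depends on the TVM revision these patches apply to.

from tvm import relay


def show_fused_tuple_get(dim=10):
    # dense -> split -> TupleGetItem -> element-wise ops, as in the tests above.
    X = relay.var("X", shape=(1, dim))
    W = relay.var("W", shape=(3 * dim, dim))
    matmul = relay.nn.dense(X, W)
    parts = relay.split(matmul, indices_or_sections=3, axis=1)
    out = relay.sigmoid(parts[0]) + relay.tanh(parts[1]) * relay.exp(parts[2])
    func = relay.Function([X, W], out)

    func = relay.ir_pass.infer_type(func)
    fused = relay.ir_pass.fuse_ops(func, opt_level=2)
    fused = relay.ir_pass.infer_type(fused)
    # With this series applied, the split/TupleGetItem/element-wise chain is
    # grouped into a single fused function instead of being left unfused.
    print(fused.astext())


if __name__ == "__main__":
    show_fused_tuple_get()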