From 5463ce94a71852b914afee38c2539cc132e64f73 Mon Sep 17 00:00:00 2001
From: Yao Wang
Date: Tue, 28 Apr 2020 23:55:20 +0000
Subject: [PATCH 1/4] Improve Shape Func handling for Tuple inputs

---
 src/relay/backend/compile_engine.cc |  7 +++++
 src/relay/op/memory/memory.cc       | 17 +++++++++++-
 tests/python/relay/test_any.py      | 42 +++++++++++++++++++++++++++++
 3 files changed, 65 insertions(+), 1 deletion(-)

diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index ce0a314f265b..a3e950472e7b 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -552,6 +552,13 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
     return fields;
   }
 
+  Array<te::Tensor> VisitExpr_(const TupleGetItemNode* op) final {
+    Array<te::Tensor> input_shapes = VisitExpr(op->tuple);
+    Array<te::Tensor> out;
+    out.push_back(input_shapes[op->index]);
+    return out;
+  }
+
  private:
   /*! \brief String stream for function name */
   std::ostringstream readable_name_stream_;
diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc
index c7ffc95c05d5..494b7812d9a4 100644
--- a/src/relay/op/memory/memory.cc
+++ b/src/relay/op/memory/memory.cc
@@ -360,12 +360,27 @@ bool ShapeFuncRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   auto tuple = TupleType(func_type->arg_types);
   auto in_types = FlattenTupleType(tuple);
   auto out_types = FlattenTupleType(func_type->ret_type);
+  int num_types = 0;
+  std::unordered_map<int, bool> is_input_map;
+  for (size_t i = 0; i < func_type->arg_types.size(); ++i) {
+    auto const& ft = func_type->arg_types[i];
+    if (auto tupletype = ft.as<TupleTypeNode>()) {
+      for (size_t j = num_types; j < num_types + tupletype->fields.size(); ++j) {
+        is_input_map[j] = shape_func_attrs->is_input[i];
+      }
+      num_types += tupletype->fields.size();
+    }
+    else {
+      is_input_map[num_types] = shape_func_attrs->is_input[i];
+      num_types += 1;
+    }
+  }
 
   Array<Type> shape_func_ins, shape_func_outs;
   for (size_t i = 0; i < in_types.size(); i++) {
     auto in_type = in_types[i];
-    if (shape_func_attrs->is_input[i]) {
+    if (is_input_map[i]) {
       shape_func_ins.push_back(in_type);
     } else {
       auto shape = RankShape(in_type->shape);
       shape_func_ins.push_back(shape);
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 6ce59bbf1c36..df926d408390 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -667,6 +667,45 @@ def _body(i, st):
     except Exception as e:
         assert "in particular dimension 0 conflicts 2 does not match 1" in str(e)
 
+def test_tuple_get_item():
+    mod = tvm.IRModule()
+    dtype = "float32"
+    static_data_shape = (9, 4)
+    data_shape = (relay.Any(), 4)
+    indices_or_sections = 2
+    axis = 1
+    data = relay.var('data', shape=data_shape, dtype=dtype)
+    y = relay.split(data, indices_or_sections, axis)
+    y = relay.expr.TupleGetItem(y.astuple(), 0)
+    mod["main"] = relay.Function([data], y)
+    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+    ref_out_shape = (9, 2)
+    for kind in ["vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data_np)
+        assert result.asnumpy().shape == ref_out_shape, \
+            "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
+
+def test_mixed_input_type():
+    mod = tvm.IRModule()
+    dtype = "float32"
+    static_data_shape = (9, 4)
+    data_shape = (relay.Any(), 4)
+    tensor_type = relay.TensorType(data_shape, dtype)
+    data0 = relay.var("d0", type_annotation=relay.TupleType([tensor_type, tensor_type]))
+    data1 = relay.var("d1", shape=(relay.Any(), 4), dtype=dtype)
+    data_tuple = relay.expr.TupleWrapper(data0, 2)
+    y = data_tuple[0] * data_tuple[1] + data1
+    mod["main"] = relay.Function([data0, data1], y)
+    data_np0 = np.random.uniform(size=static_data_shape).astype(dtype)
+    data_np1 = np.random.uniform(size=static_data_shape).astype(dtype)
+    ref_out_shape = (9, 4)
+    for kind in ["vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()([data_np0, data_np0], data_np1)
+        assert result.asnumpy().shape == ref_out_shape, \
+            "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
+
 if __name__ == "__main__":
     test_any_full()
     test_any_broadcast()
@@ -695,3 +734,6 @@ def _body(i, st):
     test_arange_with_dynamic_shape()
     test_recursive_concat()
     test_recursive_concat_with_wrong_annotation()
+    test_tuple_get_item()
+    test_mixed_input_type()
+

From f311f8000045b10646bd923a8a4f44478a3822c5 Mon Sep 17 00:00:00 2001
From: Yao Wang
Date: Wed, 29 Apr 2020 00:35:02 +0000
Subject: [PATCH 2/4] Fix lint

---
 src/relay/op/memory/memory.cc | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc
index 494b7812d9a4..5ceeb9972ef1 100644
--- a/src/relay/op/memory/memory.cc
+++ b/src/relay/op/memory/memory.cc
@@ -369,8 +369,7 @@ bool ShapeFuncRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
         is_input_map[j] = shape_func_attrs->is_input[i];
       }
       num_types += tupletype->fields.size();
-    }
-    else {
+    } else {
      is_input_map[num_types] = shape_func_attrs->is_input[i];
       num_types += 1;
     }

From 74d70101033fe8b60a001d2cbad6e8135a6ad89b Mon Sep 17 00:00:00 2001
From: Wang Yao
Date: Sat, 2 May 2020 11:28:36 -0700
Subject: [PATCH 3/4] Improve

---
 src/relay/op/memory/memory.cc  | 21 +++++++++------------
 tests/python/relay/test_any.py |  8 +++++---
 2 files changed, 14 insertions(+), 15 deletions(-)

diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc
index 5ceeb9972ef1..ec4928580628 100644
--- a/src/relay/op/memory/memory.cc
+++ b/src/relay/op/memory/memory.cc
@@ -360,18 +360,15 @@ bool ShapeFuncRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   auto tuple = TupleType(func_type->arg_types);
   auto in_types = FlattenTupleType(tuple);
   auto out_types = FlattenTupleType(func_type->ret_type);
-  int num_types = 0;
-  std::unordered_map<int, bool> is_input_map;
+  Array<Integer> is_input;
   for (size_t i = 0; i < func_type->arg_types.size(); ++i) {
-    auto const& ft = func_type->arg_types[i];
-    if (auto tupletype = ft.as<TupleTypeNode>()) {
-      for (size_t j = num_types; j < num_types + tupletype->fields.size(); ++j) {
-        is_input_map[j] = shape_func_attrs->is_input[i];
-      }
-      num_types += tupletype->fields.size();
-    } else {
-      is_input_map[num_types] = shape_func_attrs->is_input[i];
-      num_types += 1;
+    auto const& aty = func_type->arg_types[i];
+    size_t num_types = 1;
+    if (const auto& tupletype = aty.as<TupleTypeNode>()) {
+      num_types = FlattenTupleType(aty).size();
+    }
+    for (int j = 0; j < num_types; ++j) {
+      is_input.push_back(shape_func_attrs->is_input[i]);
     }
   }
 
   Array<Type> shape_func_ins, shape_func_outs;
   for (size_t i = 0; i < in_types.size(); i++) {
     auto in_type = in_types[i];
-    if (is_input_map[i]) {
+    if (is_input[i]) {
       shape_func_ins.push_back(in_type);
     } else {
       auto shape = RankShape(in_type->shape);
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index df926d408390..24c679eaaef8 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -692,17 +692,19 @@ def test_mixed_input_type():
     static_data_shape = (9, 4)
     data_shape = (relay.Any(), 4)
     tensor_type = relay.TensorType(data_shape, dtype)
-    data0 = relay.var("d0", type_annotation=relay.TupleType([tensor_type, tensor_type]))
+    tuple_type = relay.TupleType([tensor_type, tensor_type])
+    data0 = relay.var("d0", type_annotation=relay.TupleType([tuple_type, tensor_type]))
     data1 = relay.var("d1", shape=(relay.Any(), 4), dtype=dtype)
     data_tuple = relay.expr.TupleWrapper(data0, 2)
-    y = data_tuple[0] * data_tuple[1] + data1
+    nested_data_tuple = relay.expr.TupleWrapper(data_tuple[0], 2)
+    y = nested_data_tuple[1] * data_tuple[1] + data1
     mod["main"] = relay.Function([data0, data1], y)
     data_np0 = np.random.uniform(size=static_data_shape).astype(dtype)
     data_np1 = np.random.uniform(size=static_data_shape).astype(dtype)
     ref_out_shape = (9, 4)
     for kind in ["vm"]:
         ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
-        result = ex.evaluate()([data_np0, data_np0], data_np1)
+        result = ex.evaluate()([[data_np0, data_np0], data_np0], data_np1)
         assert result.asnumpy().shape == ref_out_shape, \
             "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
 

From c4d331714cc387c20d9f8002c23df044199df295 Mon Sep 17 00:00:00 2001
From: Wang Yao
Date: Sat, 2 May 2020 11:40:21 -0700
Subject: [PATCH 4/4] Fix build

---
 src/relay/op/memory/memory.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc
index ec4928580628..ff669b01e66e 100644
--- a/src/relay/op/memory/memory.cc
+++ b/src/relay/op/memory/memory.cc
@@ -364,10 +364,10 @@ bool ShapeFuncRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   for (size_t i = 0; i < func_type->arg_types.size(); ++i) {
     auto const& aty = func_type->arg_types[i];
     size_t num_types = 1;
-    if (const auto& tupletype = aty.as<TupleTypeNode>()) {
+    if (aty.as<TupleTypeNode>()) {
      num_types = FlattenTupleType(aty).size();
     }
-    for (int j = 0; j < num_types; ++j) {
+    for (size_t j = 0; j < num_types; ++j) {
       is_input.push_back(shape_func_attrs->is_input[i]);
     }
   }