From dfbb78e7a31f75efca320d76e233be2ceda8eb1e Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Wed, 1 Jul 2020 13:49:56 -0700 Subject: [PATCH] add dyn tile to dynamic_to_static --- src/relay/transforms/dynamic_to_static.cc | 11 +++++++++- .../relay/test_pass_dynamic_to_static.py | 22 +++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index 7b3f1957811b1..ede5a47315e5b 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -32,7 +32,7 @@ namespace relay { class DynamicToStaticMutator : public MixedModeMutator { public: - DynamicToStaticMutator() : dyn_reshape_op_(Op::Get("dyn.reshape")) {} + DynamicToStaticMutator() : dyn_reshape_op_(Op::Get("dyn.reshape")), dyn_tile_op_(Op::Get("dyn.tile")) {} private: Expr Rewrite_(const CallNode* pre, const Expr& post) override { @@ -46,6 +46,14 @@ class DynamicToStaticMutator : public MixedModeMutator { static const Op& reshape = Op::Get("reshape"); return Call(reshape, {call_node->args[0]}, Attrs(attrs), {}); } + } else if (call_node->op == dyn_tile_op_) { + if (const ConstantNode* reps = call_node->args[1].as()) { + auto attrs = make_object(); + CHECK_EQ(reps->data->ndim, 1); + attrs->reps = ToVector(reps->data); + static const Op& op = Op::Get("tile"); + return Call(op, {call_node->args[0]}, Attrs(attrs), {}); + } } return post; } @@ -58,6 +66,7 @@ class DynamicToStaticMutator : public MixedModeMutator { } const Op& dyn_reshape_op_; + const Op& dyn_tile_op_; }; Expr DynamicToStatic(Function f, IRModule m) { diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py index 052d95cef0a76..3415ce01d5fd2 100644 --- a/tests/python/relay/test_pass_dynamic_to_static.py +++ b/tests/python/relay/test_pass_dynamic_to_static.py @@ -108,8 +108,30 @@ def verify_reshape(shape, newshape): verify_reshape((2, 3, 4), (8, 3)) verify_reshape((4, 7), (2, 7, 2)) +def test_dynamic_to_static_tile(): + def verify_tile(shape, reps, oshape): + x = relay.var("x", relay.TensorType(shape, "float32")) + y = relay.var("y", relay.TensorType(reps, "float32")) + z = relay.tile(x, relay.shape_of(y)) + func = run_infer_type(relay.Function([x, y], z)) + func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType()) + + zz = func2.body + assert isinstance(zz, relay.Call) + assert zz.op == relay.op.get("tile") + assert zz.checked_type == relay.ty.TensorType(oshape, "float32") + + x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") + y_data = np.random.uniform(low=-1, high=1, size=reps).astype("float32") + ref_res = np.tile(x_data, reps) + verify_func(func2, [x_data, y_data], ref_res) + + verify_tile((2, 3, 4), (2, 1, 5), (4, 3, 20)) + verify_tile((4, 7), (4, 2), (16, 14)) + if __name__=="__main__": test_dynamic_to_static_reshape() test_dynamic_to_static_double_reshape() test_dynamic_to_static_quad_reshape() + test_dynamic_to_static_tile()