From 8aa319398a083cead77d84b7c9155017ef28bdc5 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Tue, 16 Jun 2020 13:48:04 -0700 Subject: [PATCH] add nested dynamic shape test --- .../relay/test_pass_dynamic_to_static.py | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py index 2ec4e29a69b5..e039439473fe 100644 --- a/tests/python/relay/test_pass_dynamic_to_static.py +++ b/tests/python/relay/test_pass_dynamic_to_static.py @@ -85,3 +85,38 @@ def verify_reshape(shape, newshape): verify_reshape((2, 3, 4), (8, 3)) verify_reshape((4, 7), (2, 7, 2)) + +def test_dynamic_to_static_quad_reshape(): + def verify_reshape(shape, newshape): + x = relay.var("x", relay.TensorType(shape, "float32")) + y = relay.var("y", relay.TensorType(newshape, "float32")) + z1 = relay.dynamic.reshape(x, relay.shape_of(y)) + z2 = relay.dynamic.reshape(z1, relay.shape_of(x)) + z3 = relay.dynamic.reshape(z2, relay.shape_of(z1)) + z4 = relay.dynamic.reshape(z3, relay.shape_of(z2)) + func = run_infer_type(relay.Function([x, y], z4)) + func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType()) + + zz = func2.body + assert isinstance(zz, relay.Call) + assert zz.op == relay.op.get("reshape") + assert "newshape=" in zz.astext() + assert zz.checked_type == relay.ty.TensorType(shape, "float32") + + x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") + y_data = np.random.uniform(low=-1, high=1, size=newshape).astype("float32") + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + mod = tvm.ir.IRModule.from_expr(func2) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x_data, y_data) + tvm.testing.assert_allclose(op_res.asnumpy(), x_data, rtol=1e-5) + + verify_reshape((2, 3, 4), (8, 3)) + verify_reshape((4, 7), (2, 7, 2)) + +if __name__=="__main__": + test_dynamic_to_static_reshape() + test_dynamic_to_static_double_reshape() + test_dynamic_to_static_quad_reshape() +