Skip to content

Commit

Permalink
add nested dynamic shape test
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart committed Jun 16, 2020
1 parent 273a47d commit 8aa3193
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions tests/python/relay/test_pass_dynamic_to_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,38 @@ def verify_reshape(shape, newshape):

verify_reshape((2, 3, 4), (8, 3))
verify_reshape((4, 7), (2, 7, 2))

def test_dynamic_to_static_quad_reshape():
def verify_reshape(shape, newshape):
x = relay.var("x", relay.TensorType(shape, "float32"))
y = relay.var("y", relay.TensorType(newshape, "float32"))
z1 = relay.dynamic.reshape(x, relay.shape_of(y))
z2 = relay.dynamic.reshape(z1, relay.shape_of(x))
z3 = relay.dynamic.reshape(z2, relay.shape_of(z1))
z4 = relay.dynamic.reshape(z3, relay.shape_of(z2))
func = run_infer_type(relay.Function([x, y], z4))
func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())

zz = func2.body
assert isinstance(zz, relay.Call)
assert zz.op == relay.op.get("reshape")
assert "newshape=" in zz.astext()
assert zz.checked_type == relay.ty.TensorType(shape, "float32")

x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
y_data = np.random.uniform(low=-1, high=1, size=newshape).astype("float32")
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
mod = tvm.ir.IRModule.from_expr(func2)
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x_data, y_data)
tvm.testing.assert_allclose(op_res.asnumpy(), x_data, rtol=1e-5)

verify_reshape((2, 3, 4), (8, 3))
verify_reshape((4, 7), (2, 7, 2))

if __name__=="__main__":
test_dynamic_to_static_reshape()
test_dynamic_to_static_double_reshape()
test_dynamic_to_static_quad_reshape()

0 comments on commit 8aa3193

Please sign in to comment.