remove cuda tests until VM supports dynamic shapes
Matthew Brookhart committed Jun 18, 2020
1 parent 8aa3193 commit 9453400
Showing 2 changed files with 26 additions and 30 deletions.
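Both files replace their per-test executor loops with a shared verify_func helper that skips every non-LLVM target, so the CUDA context returned by ctx_list() is never exercised until the Relay VM can handle dynamic shapes. For reference, a consolidated sketch of the helper as it lands in the dynamic op tests is shown below; the imports are added here for readability only, and the variant added to the dynamic-to-static pass tests is identical except that it also runs the "graph" executor and does not clear the compile-engine cache.

# Consolidated sketch of the helper added in this commit (imports shown for
# completeness; the test modules already import these at module scope).
import tvm
import tvm.testing
from tvm import relay
from tvm.relay.testing import ctx_list

def verify_func(func, data, ref_res):
    assert isinstance(data, list)
    for target, ctx in ctx_list():
        # TODO(mbrookhart): enable Cuda tests once the VM supports dynamic shapes
        if "llvm" not in target:
            continue
        for kind in ["vm", "debug"]:
            # Rebuild the module for each executor kind and compare against the NumPy reference.
            mod = tvm.ir.IRModule.from_expr(func)
            intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
            op_res = intrp.evaluate()(*data)
            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
    relay.backend.compile_engine.get().clear()

Filtering on the target string, rather than removing CUDA from ctx_list() itself, keeps the skip local to these tests and easy to revert once the VM gains dynamic-shape support.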
25 changes: 13 additions & 12 deletions tests/python/relay/dynamic/test_dynamic_op_level3.py
@@ -24,6 +24,17 @@
 from tvm.relay import create_executor, transform
 from tvm.relay.testing import ctx_list, check_grad, run_infer_type
 
+def verify_func(func, data, ref_res):
+    assert isinstance(data, list)
+    for target, ctx in ctx_list():
+        # TODO(mbrookhart): enable Cuda tests once the VM supports dynamic shapes
+        if "llvm" not in target: continue
+        for kind in ["vm", "debug"]:
+            mod = tvm.ir.IRModule.from_expr(func)
+            intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+            op_res = intrp.evaluate()(*data)
+            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+    relay.backend.compile_engine.get().clear()
 
 def test_dynamic_reshape():
     def verify_reshape(shape, newshape, oshape):
@@ -34,12 +45,7 @@ def verify_reshape(shape, newshape, oshape):
         func = relay.Function([x, y], z)
         x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
         ref_res = np.reshape(x_data, oshape)
-        for target, ctx in ctx_list():
-            for kind in ["vm", "debug"]:
-                mod = tvm.ir.IRModule.from_expr(func)
-                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
-                op_res = intrp.evaluate()(x_data, np.array(newshape))
-                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+        verify_func(func, [x_data, np.array(newshape).astype("int64")], ref_res)
     verify_reshape((2, 3, 4), (8, 3), (8, 3))
     verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))
     verify_reshape((2, 3, 4), (4, 0, 2), (4, 3, 2))
@@ -60,12 +66,7 @@ def verify_reshape(shape, newshape, oshape):
         x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
         y_data = np.random.uniform(low=-1, high=1, size=newshape).astype("float32")
         ref_res = np.reshape(x_data, oshape)
-        for target, ctx in ctx_list():
-            for kind in ["vm", "debug"]:
-                mod = tvm.ir.IRModule.from_expr(func)
-                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
-                op_res = intrp.evaluate()(x_data, y_data)
-                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+        verify_func(func, [x_data, y_data], ref_res)
     verify_reshape((2, 3, 4), (8, 3), (8, 3))
     verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))
 
31 changes: 13 additions & 18 deletions tests/python/relay/test_pass_dynamic_to_static.py
@@ -31,6 +31,16 @@ def run_opt_pass(expr, opt_pass):
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
+def verify_func(func, data, ref_res):
+    assert isinstance(data, list)
+    for target, ctx in ctx_list():
+        # TODO(mbrookhart): enable Cuda tests once the VM supports dynamic shapes
+        if "llvm" not in target: continue
+        for kind in ["graph", "vm", "debug"]:
+            mod = tvm.ir.IRModule.from_expr(func)
+            intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+            op_res = intrp.evaluate()(*data)
+            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
 
 def test_dynamic_to_static_reshape():
     def verify_reshape(shape, newshape, oshape):
@@ -49,12 +59,7 @@ def verify_reshape(shape, newshape, oshape):
         x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
         y_data = np.random.uniform(low=-1, high=1, size=newshape).astype("float32")
         ref_res = np.reshape(x_data, oshape)
-        for target, ctx in ctx_list():
-            for kind in ["graph", "debug"]:
-                mod = tvm.ir.IRModule.from_expr(func2)
-                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
-                op_res = intrp.evaluate()(x_data, y_data)
-                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+        verify_func(func2, [x_data, y_data], ref_res)
 
     verify_reshape((2, 3, 4), (8, 3), (8, 3))
     verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))
@@ -76,12 +81,7 @@ def verify_reshape(shape, newshape):
 
         x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
         y_data = np.random.uniform(low=-1, high=1, size=newshape).astype("float32")
-        for target, ctx in ctx_list():
-            for kind in ["graph", "debug"]:
-                mod = tvm.ir.IRModule.from_expr(func2)
-                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
-                op_res = intrp.evaluate()(x_data, y_data)
-                tvm.testing.assert_allclose(op_res.asnumpy(), x_data, rtol=1e-5)
+        verify_func(func2, [x_data, y_data], x_data)
 
     verify_reshape((2, 3, 4), (8, 3))
     verify_reshape((4, 7), (2, 7, 2))
@@ -105,12 +105,7 @@ def verify_reshape(shape, newshape):
 
         x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
         y_data = np.random.uniform(low=-1, high=1, size=newshape).astype("float32")
-        for target, ctx in ctx_list():
-            for kind in ["graph", "debug"]:
-                mod = tvm.ir.IRModule.from_expr(func2)
-                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
-                op_res = intrp.evaluate()(x_data, y_data)
-                tvm.testing.assert_allclose(op_res.asnumpy(), x_data, rtol=1e-5)
+        verify_func(func2, [x_data, y_data], x_data)
 
     verify_reshape((2, 3, 4), (8, 3))
    verify_reshape((4, 7), (2, 7, 2))
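For context, ctx_list() from tvm.relay.testing yields (target, context) pairs for each backend present in the local build, which is why a simple substring check on the target string is enough to keep CUDA out of these runs. A small illustrative snippet (not part of the commit) that reports what the new filter would run versus skip:

# Illustrative only, not part of the commit: show which enabled targets the
# "llvm"-only filter would actually exercise.
from tvm.relay.testing import ctx_list

for target, ctx in ctx_list():
    status = "run" if "llvm" in target else "skipped until the VM supports dynamic shapes"
    print("%s -> %s" % (target, status))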
