From ec8825c72bf317c5ecadbf09f094c7460a226eec Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Tue, 21 Jan 2020 22:30:02 -0800 Subject: [PATCH] Remove run_infer_type duplicates (#4766) --- tests/python/relay/test_op_level1.py | 7 +------ tests/python/relay/test_op_level10.py | 7 +------ tests/python/relay/test_op_level3.py | 7 +------ tests/python/relay/test_op_level4.py | 7 +------ tests/python/relay/test_op_level5.py | 7 +------ tests/python/relay/test_pass_manager.py | 9 +-------- tests/python/relay/test_pass_to_cps.py | 6 +----- 7 files changed, 7 insertions(+), 43 deletions(-) diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index f73826e2d6c8..194b09564288 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -20,15 +20,10 @@ import scipy from tvm import relay from tvm.relay import transform -from tvm.relay.testing import ctx_list +from tvm.relay.testing import ctx_list, run_infer_type import topi.testing from tvm.contrib.nvcc import have_fp16 -def run_infer_type(expr): - mod = relay.Module.from_expr(expr) - mod = transform.InferType()(mod) - entry = mod["main"] - return entry if isinstance(expr, relay.Function) else entry.body def sigmoid(x): one = np.ones_like(x) diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index bb1d346ac6e0..6a6f21d9241f 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -21,15 +21,10 @@ import topi.testing from tvm import relay from tvm.relay import transform -from tvm.relay.testing import ctx_list +from tvm.relay.testing import ctx_list, run_infer_type import topi import topi.testing -def run_infer_type(expr): - mod = relay.Module.from_expr(expr) - mod = transform.InferType()(mod) - entry = mod["main"] - return entry if isinstance(expr, relay.Function) else entry.body def test_checkpoint(): dtype = "float32" diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 13f17ca6713b..9c5dfacf62a2 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -21,13 +21,8 @@ import tvm from tvm import relay from tvm.relay import create_executor, transform -from tvm.relay.testing import ctx_list, check_grad +from tvm.relay.testing import ctx_list, check_grad, run_infer_type -def run_infer_type(expr): - mod = relay.Module.from_expr(expr) - mod = transform.InferType()(mod) - entry = mod["main"] - return entry if isinstance(expr, relay.Function) else entry.body def test_zeros_ones(): for op, ref in [(relay.zeros, np.zeros), (relay.ones, np.ones)]: diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index 2b25d6a90af6..0243adc59319 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -18,14 +18,9 @@ import numpy as np from tvm import relay from tvm.relay import transform -from tvm.relay.testing import ctx_list +from tvm.relay.testing import ctx_list, run_infer_type import topi.testing -def run_infer_type(expr): - mod = relay.Module.from_expr(expr) - mod = transform.InferType()(mod) - entry = mod["main"] - return entry if isinstance(expr, relay.Function) else entry.body def test_binary_op(): def check_binary_op(opfunc, ref): diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index d4abf3d82ced..eb21f338ef07 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -21,14 +21,9 @@ import tvm from tvm import relay from tvm.relay import transform -from tvm.relay.testing import ctx_list +from tvm.relay.testing import ctx_list, run_infer_type import topi.testing -def run_infer_type(expr): - mod = relay.Module.from_expr(expr) - mod = transform.InferType()(mod) - entry = mod["main"] - return entry if isinstance(expr, relay.Function) else entry.body def test_resize_infer_type(): n, c, h, w = tvm.size_var("n"), tvm.size_var("c"), tvm.size_var("h"), tvm.size_var("w") diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index 3fb85088d9b1..e02e917dbb62 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -24,14 +24,7 @@ from tvm.relay import Function, Call from tvm.relay import analysis from tvm.relay import transform as _transform -from tvm.relay.testing import ctx_list - - -def run_infer_type(expr): - mod = relay.Module.from_expr(expr) - mod = _transform.InferType()(mod) - entry = mod["main"] - return entry if isinstance(expr, relay.Function) else entry.body +from tvm.relay.testing import ctx_list, run_infer_type def get_var_func(): diff --git a/tests/python/relay/test_pass_to_cps.py b/tests/python/relay/test_pass_to_cps.py index 045c92c929c4..1d09c0d67f5b 100644 --- a/tests/python/relay/test_pass_to_cps.py +++ b/tests/python/relay/test_pass_to_cps.py @@ -21,15 +21,11 @@ from tvm.relay.transform import to_cps, un_cps from tvm.relay.feature import Feature from tvm.relay.prelude import Prelude -from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, run_opt_pass +from tvm.relay.testing import add_nat_definitions, make_nat_expr, rand, run_infer_type, run_opt_pass from tvm.relay import create_executor from tvm.relay import Function, transform -def rand(dtype='float32', *shape): - return tvm.nd.array(np.random.rand(*shape).astype(dtype)) - - def test_id(): x = relay.var("x", shape=[]) id = run_infer_type(relay.Function([x], x))