Skip to content

Commit

Permalink
Remove run_infer_type duplicates (#4766)
Browse files Browse the repository at this point in the history
  • Loading branch information
apivovarov authored and masahi committed Jan 22, 2020
1 parent 4dbe4d9 commit cf3e786
Show file tree
Hide file tree
Showing 7 changed files with 7 additions and 43 deletions.
7 changes: 1 addition & 6 deletions tests/python/relay/test_op_level1.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,10 @@
import scipy
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import ctx_list
from tvm.relay.testing import ctx_list, run_infer_type
import topi.testing
from tvm.contrib.nvcc import have_fp16

def run_infer_type(expr):
mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body

def sigmoid(x):
one = np.ones_like(x)
Expand Down
7 changes: 1 addition & 6 deletions tests/python/relay/test_op_level10.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,10 @@
import topi.testing
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import ctx_list
from tvm.relay.testing import ctx_list, run_infer_type
import topi
import topi.testing

def run_infer_type(expr):
mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body

def test_checkpoint():
dtype = "float32"
Expand Down
7 changes: 1 addition & 6 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,8 @@
import tvm
from tvm import relay
from tvm.relay import create_executor, transform
from tvm.relay.testing import ctx_list, check_grad
from tvm.relay.testing import ctx_list, check_grad, run_infer_type

def run_infer_type(expr):
mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body

def test_zeros_ones():
for op, ref in [(relay.zeros, np.zeros), (relay.ones, np.ones)]:
Expand Down
7 changes: 1 addition & 6 deletions tests/python/relay/test_op_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,9 @@
import numpy as np
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import ctx_list
from tvm.relay.testing import ctx_list, run_infer_type
import topi.testing

def run_infer_type(expr):
mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body

def test_binary_op():
def check_binary_op(opfunc, ref):
Expand Down
7 changes: 1 addition & 6 deletions tests/python/relay/test_op_level5.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,9 @@
import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import ctx_list
from tvm.relay.testing import ctx_list, run_infer_type
import topi.testing

def run_infer_type(expr):
mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body

def test_resize_infer_type():
n, c, h, w = tvm.size_var("n"), tvm.size_var("c"), tvm.size_var("h"), tvm.size_var("w")
Expand Down
9 changes: 1 addition & 8 deletions tests/python/relay/test_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,7 @@
from tvm.relay import Function, Call
from tvm.relay import analysis
from tvm.relay import transform as _transform
from tvm.relay.testing import ctx_list


def run_infer_type(expr):
mod = relay.Module.from_expr(expr)
mod = _transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
from tvm.relay.testing import ctx_list, run_infer_type


def get_var_func():
Expand Down
6 changes: 1 addition & 5 deletions tests/python/relay/test_pass_to_cps.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,11 @@
from tvm.relay.transform import to_cps, un_cps
from tvm.relay.feature import Feature
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, run_opt_pass
from tvm.relay.testing import add_nat_definitions, make_nat_expr, rand, run_infer_type, run_opt_pass
from tvm.relay import create_executor
from tvm.relay import Function, transform


def rand(dtype='float32', *shape):
return tvm.nd.array(np.random.rand(*shape).astype(dtype))


def test_id():
x = relay.var("x", shape=[])
id = run_infer_type(relay.Function([x], x))
Expand Down

0 comments on commit cf3e786

Please sign in to comment.