Skip to content

Commit

Permalink
Pass in executor to change_dtype()
Browse files Browse the repository at this point in the history
  • Loading branch information
gussmith23 committed Oct 17, 2019
1 parent e412aa7 commit ca5c607
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions tests/python/unittest/test_custom_datatypes_change_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,13 @@ def convert_ndarray(dst_dtype, array, executor):
cast = relay.Function([x], x.astype(dst_dtype))
return executor.evaluate(cast)(array)

def change_dtype(src, dst, expr, params):
def change_dtype(src, dst, expr, params, executor):
cdtype = relay.frontend.ChangeDatatype(src, dst)
expr = cdtype.visit(expr)
expr = relay.ir_pass.infer_type(expr)
#raise "pause"
params = dict(
(p, convert_ndarray(dst, params[p])) for p in params)
(p, convert_ndarray(dst, params[p], executor)) for p in params)
return expr, params

def test_change_dtype_simple():
Expand All @@ -85,7 +85,7 @@ def test_change_dtype_simple():
# Execute the model in the new datatype.
result = ex.evaluate(func)(A, B)

func_changed, _ = change_dtype('float32', 'custom[bfloat]16', func, [])
func_changed, _ = change_dtype('float32', 'custom[bfloat]16', func, [], ex)
A_converted = convert_ndarray('custom[bfloat]16', A, ex)
B_converted = convert_ndarray('custom[bfloat]16', B, ex)
result = ex.evaluate(func_changed)(A_converted, B_converted)
Expand All @@ -100,7 +100,7 @@ def test_change_dtype_resnet():

src_dtype = 'float32'
dst_dtype = 'custom[bfloat]16' # Change me to posit.
expr, params = change_dtype(src_dtype, dst_dtype, expr, params)
expr, params = change_dtype(src_dtype, dst_dtype, expr, params, ex)

# Convert the input into the correct format.
input = tvm.nd.array(np.random.rand(3, 299, 299).astype(src_dtype))
Expand All @@ -122,7 +122,7 @@ def test_change_dtype_inception_v3():

src_dtype = 'float32'
dst_dtype = 'custom[bfloat]16' # Change me to posit.
expr, params = change_dtype(src_dtype, dst_dtype, expr, params)
expr, params = change_dtype(src_dtype, dst_dtype, expr, params, ex)

# Convert the input into the correct format.
input = tvm.nd.array(np.random.rand(3, 299, 299).astype(src_dtype))
Expand Down

0 comments on commit ca5c607

Please sign in to comment.