From ca5c60753aeeb84419d3ae7f638365c84e8314e3 Mon Sep 17 00:00:00 2001 From: Gus Smith Date: Thu, 16 May 2019 17:27:28 -0700 Subject: [PATCH] Pass in executor to change_dtype() --- .../unittest/test_custom_datatypes_change_dtype.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/python/unittest/test_custom_datatypes_change_dtype.py b/tests/python/unittest/test_custom_datatypes_change_dtype.py index 1615f927171e..ea2551f9e368 100644 --- a/tests/python/unittest/test_custom_datatypes_change_dtype.py +++ b/tests/python/unittest/test_custom_datatypes_change_dtype.py @@ -62,13 +62,13 @@ def convert_ndarray(dst_dtype, array, executor): cast = relay.Function([x], x.astype(dst_dtype)) return executor.evaluate(cast)(array) -def change_dtype(src, dst, expr, params): +def change_dtype(src, dst, expr, params, executor): cdtype = relay.frontend.ChangeDatatype(src, dst) expr = cdtype.visit(expr) expr = relay.ir_pass.infer_type(expr) #raise "pause" params = dict( - (p, convert_ndarray(dst, params[p])) for p in params) + (p, convert_ndarray(dst, params[p], executor)) for p in params) return expr, params def test_change_dtype_simple(): @@ -85,7 +85,7 @@ def test_change_dtype_simple(): # Execute the model in the new datatype. result = ex.evaluate(func)(A, B) - func_changed, _ = change_dtype('float32', 'custom[bfloat]16', func, []) + func_changed, _ = change_dtype('float32', 'custom[bfloat]16', func, [], ex) A_converted = convert_ndarray('custom[bfloat]16', A, ex) B_converted = convert_ndarray('custom[bfloat]16', B, ex) result = ex.evaluate(func_changed)(A_converted, B_converted) @@ -100,7 +100,7 @@ def test_change_dtype_resnet(): src_dtype = 'float32' dst_dtype = 'custom[bfloat]16' # Change me to posit. - expr, params = change_dtype(src_dtype, dst_dtype, expr, params) + expr, params = change_dtype(src_dtype, dst_dtype, expr, params, ex) # Convert the input into the correct format. input = tvm.nd.array(np.random.rand(3, 299, 299).astype(src_dtype)) @@ -122,7 +122,7 @@ def test_change_dtype_inception_v3(): src_dtype = 'float32' dst_dtype = 'custom[bfloat]16' # Change me to posit. - expr, params = change_dtype(src_dtype, dst_dtype, expr, params) + expr, params = change_dtype(src_dtype, dst_dtype, expr, params, ex) # Convert the input into the correct format. input = tvm.nd.array(np.random.rand(3, 299, 299).astype(src_dtype))