diff --git a/tests/python/unittest/test_custom_datatypes_change_dtype.py b/tests/python/unittest/test_custom_datatypes_change_dtype.py
index a13f912141da..1615f927171e 100644
--- a/tests/python/unittest/test_custom_datatypes_change_dtype.py
+++ b/tests/python/unittest/test_custom_datatypes_change_dtype.py
@@ -20,6 +20,7 @@
 import numpy as np
 from tvm import relay
 from tvm.relay.testing.inception_v3 import get_workload as get_inception
+from tvm.relay.testing.resnet import get_workload as get_resnet
 
 tgt = "llvm"
 
@@ -92,8 +93,29 @@ def test_change_dtype_simple():
     result_converted = convert_ndarray('float32', result, ex)
     print(result_converted)
 
-def test_change_dtype_inception_v3():
+def test_change_dtype_resnet():
+    expr, params = get_resnet()
+
+    ex = relay.create_executor("graph")
+
+    src_dtype = 'float32'
+    dst_dtype = 'custom[bfloat]16'  # Change me to posit.
+    expr, params = change_dtype(src_dtype, dst_dtype, expr, params)
+
+    # Convert the (1, 3, 224, 224) ResNet input into the correct format.
+    input = tvm.nd.array(np.random.rand(1, 3, 224, 224).astype(src_dtype))
+    input = convert_ndarray(dst_dtype, input, ex)
+    def print_info(node):  # Print types of nodes not converted to dst_dtype.
+        if not isinstance(node, relay.op.op.Op):
+            if dst_dtype not in str(node.checked_type):
+                print(node.checked_type)
+    relay.ir_pass.post_order_visit(expr, print_info)
+
+    # Execute the model in the new datatype.
+    result = ex.evaluate(expr)(input, **params)
+
+def test_change_dtype_inception_v3():
     expr, params = get_inception()
 
     ex = relay.create_executor("graph")
 
@@ -120,3 +142,4 @@ def print_info(node):
     setup()
     # test_change_dtype_inception_v3()
     test_change_dtype_simple()
+    test_change_dtype_resnet()