diff --git a/tests/python/unittest/test_custom_datatypes_change_dtype.py b/tests/python/unittest/test_custom_datatypes_change_dtype.py index ea2551f9e368..9f665ff06065 100644 --- a/tests/python/unittest/test_custom_datatypes_change_dtype.py +++ b/tests/python/unittest/test_custom_datatypes_change_dtype.py @@ -21,6 +21,7 @@ from tvm import relay from tvm.relay.testing.inception_v3 import get_workload as get_inception from tvm.relay.testing.resnet import get_workload as get_resnet +from tvm.relay.testing.mobilenet import get_workload as get_mobilenet tgt = "llvm" @@ -137,9 +138,31 @@ def print_info(node): # Execute the model in the new datatype. result = ex.evaluate(expr)(input, **params) +def test_change_dtype_mobilenet(): + expr, params = get_mobilenet() + + ex = relay.create_executor("graph") + + src_dtype = 'float32' + dst_dtype = 'custom[bfloat]16' # Change me to posit. + expr, params = change_dtype(src_dtype, dst_dtype, expr, params, ex) + + # Convert the input into the correct format. + input = tvm.nd.array(np.random.rand(3, 224, 224).astype(src_dtype)) + input = convert_ndarray(dst_dtype, input, ex) + + # def print_info(node): + # if not isinstance(node, relay.op.op.Op): + # if ("custom[bfloat]32" not in str(node.checked_type())): + # print(node.checked_type()) + # relay.ir_pass.post_order_visit(expr, print_info) + + # Execute the model in the new datatype. + result = ex.evaluate(expr)(input, **params) if __name__ == "__main__": setup() # test_change_dtype_inception_v3() test_change_dtype_simple() test_change_dtype_resnet() + test_change_dtype_mobilenet()