Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
gussmith23 committed Oct 17, 2019
1 parent 79bf21f commit e412aa7
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion tests/python/unittest/test_custom_datatypes_change_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import numpy as np
from tvm import relay
from tvm.relay.testing.inception_v3 import get_workload as get_inception
from tvm.relay.testing.resnet import get_workload as get_resnet

tgt = "llvm"

Expand Down Expand Up @@ -92,8 +93,29 @@ def test_change_dtype_simple():
result_converted = convert_ndarray('float32', result, ex)
print(result_converted)

def test_change_dtype_inception_v3():
def test_change_dtype_resnet():
expr, params = get_resnet()

ex = relay.create_executor("graph")

src_dtype = 'float32'
dst_dtype = 'custom[bfloat]16' # Change me to posit.
expr, params = change_dtype(src_dtype, dst_dtype, expr, params)

# Convert the input into the correct format.
input = tvm.nd.array(np.random.rand(3, 299, 299).astype(src_dtype))
input = convert_ndarray(dst_dtype, input, ex)

def print_info(node):
if not isinstance(node, relay.op.op.Op):
if ("custom[bfloat]32" not in str(node.checked_type())):
print(node.checked_type())
relay.ir_pass.post_order_visit(expr, print_info)

# Execute the model in the new datatype.
result = ex.evaluate(expr)(input, **params)

def test_change_dtype_inception_v3():
expr, params = get_inception()

ex = relay.create_executor("graph")
Expand All @@ -120,3 +142,4 @@ def print_info(node):
setup()
# test_change_dtype_inception_v3()
test_change_dtype_simple()
test_change_dtype_resnet()

0 comments on commit e412aa7

Please sign in to comment.