Skip to content

Commit

Permalink
Add MobileNet
Browse files Browse the repository at this point in the history
  • Loading branch information
gussmith23 committed Oct 17, 2019
1 parent ca5c607 commit d949ed5
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions tests/python/unittest/test_custom_datatypes_change_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from tvm import relay
from tvm.relay.testing.inception_v3 import get_workload as get_inception
from tvm.relay.testing.resnet import get_workload as get_resnet
from tvm.relay.testing.mobilenet import get_workload as get_mobilenet

tgt = "llvm"

Expand Down Expand Up @@ -137,9 +138,31 @@ def print_info(node):
# Execute the model in the new datatype.
result = ex.evaluate(expr)(input, **params)

def test_change_dtype_mobilenet():
expr, params = get_mobilenet()

ex = relay.create_executor("graph")

src_dtype = 'float32'
dst_dtype = 'custom[bfloat]16' # Change me to posit.
expr, params = change_dtype(src_dtype, dst_dtype, expr, params, ex)

# Convert the input into the correct format.
input = tvm.nd.array(np.random.rand(3, 224, 224).astype(src_dtype))
input = convert_ndarray(dst_dtype, input, ex)

# def print_info(node):
# if not isinstance(node, relay.op.op.Op):
# if ("custom[bfloat]32" not in str(node.checked_type())):
# print(node.checked_type())
# relay.ir_pass.post_order_visit(expr, print_info)

# Execute the model in the new datatype.
result = ex.evaluate(expr)(input, **params)

if __name__ == "__main__":
setup()
# test_change_dtype_inception_v3()
test_change_dtype_simple()
test_change_dtype_resnet()
test_change_dtype_mobilenet()

0 comments on commit d949ed5

Please sign in to comment.