diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index 0204e5bb51462..904e4d7baf28e 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -39,6 +39,7 @@ from . import densenet from . import yolo_detection from . import temp_op_attr +from . import synthetic from .config import ctx_list from .init import create_workload diff --git a/python/tvm/relay/testing/init.py b/python/tvm/relay/testing/init.py index 352230a6150f3..6b8adf35de614 100644 --- a/python/tvm/relay/testing/init.py +++ b/python/tvm/relay/testing/init.py @@ -15,11 +15,13 @@ # specific language governing permissions and limitations # under the License. """Initializer of parameters.""" +from functools import reduce import numpy as np import tvm from tvm import relay + class Initializer(object): """The base class of an initializer.""" def __init__(self, **kwargs): @@ -128,6 +130,14 @@ def _init_weight(self, name, arr): raise ValueError("Unknown random type") +class Constant(Initializer): + """ Constant initialization of weights. Sum of weights in the matrix is 1. + """ + def _init_weight(self, name, arr): + num_elements = reduce(lambda x, y: x*y, arr.shape) + arr[:] = 1./num_elements + + def create_workload(net, initializer=None, seed=0): """Helper function to create benchmark image classification workload. diff --git a/python/tvm/relay/testing/synthetic.py b/python/tvm/relay/testing/synthetic.py new file mode 100644 index 0000000000000..e8bda77a5e8f3 --- /dev/null +++ b/python/tvm/relay/testing/synthetic.py @@ -0,0 +1,120 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Synthetic networks for testing purposes. Ideally, these networks are similar in +structure to real world networks, but are much smaller in order to make testing +faster. +""" +from __future__ import absolute_import +from tvm import relay +from .init import create_workload, Constant +from . import layers + + +def get_net(input_shape=(1, 3, 24, 12), dtype="float32", wtype=None): + """Get synthetic testing network. + + Parameters + ---------- + image_shape : tuple, optional + The input shape as (batch_size, channels, height, width). + + dtype : str, optional + The data type for the input. + + wtype : str, optional + The data type for weights. Defaults to `dtype`. + + Returns + ------- + net : relay.Function + The dataflow. + """ + if wtype is None: + wtype = dtype + data = relay.var("data", shape=input_shape, dtype=dtype) + dense_shape = [-1, input_shape[3]] + dense = relay.nn.relu( + relay.nn.dense( + relay.reshape(data, dense_shape), + relay.var( + "dense_weight", shape=[input_shape[3], dense_shape[1]], dtype=wtype + ), + ) + ) + dense = relay.reshape_like(dense, data) + conv_shape = [input_shape[1], input_shape[1], 3, 3] + conv = relay.nn.softmax( + relay.nn.conv2d( + data, + relay.var("conv_weight", shape=conv_shape, dtype=wtype), + padding=1, + kernel_size=3, + ) + ) + added = relay.add(dense, conv) + biased = layers.batch_norm_infer( + relay.nn.bias_add(added, relay.var("bias", dtype=wtype)), name="batch_norm" + ) + dense = relay.nn.relu( + relay.nn.dense( + relay.reshape(biased, dense_shape), + relay.var( + "dense2_weight", shape=[input_shape[3], dense_shape[1]], dtype=wtype + ), + ) + ) + dense = relay.reshape_like(dense, data) + conv = relay.nn.softmax( + relay.nn.conv2d( + biased, + relay.var("conv2_weight", shape=conv_shape, dtype=wtype), + padding=1, + kernel_size=3, + ) + ) + added = relay.add(dense, conv) + args = relay.analysis.free_vars(added) + return relay.Function(args, added) + + +def get_workload(input_shape=(1, 3, 24, 12), dtype="float32", wtype=None): + """Get benchmark workload for the synthetic net. + + Parameters + ---------- + image_shape : tuple, optional + The input shape as (batch_size, channels, height, width). + + dtype : str, optional + The data type for the input. + + wtype : str, optional + The data type for weights. Defaults to `dtype`. + + Returns + ------- + mod : tvm.IRModule + The relay module that contains a synthetic network. + + params : dict of str to NDArray + The parameters. + """ + return create_workload( + get_net(input_shape=input_shape, dtype=dtype, wtype=wtype), + initializer=Constant(), + ) diff --git a/tests/micro/test_runtime_micro_on_arm.py b/tests/micro/test_runtime_micro_on_arm.py index ed7d62f284843..4b76efe17581b 100644 --- a/tests/micro/test_runtime_micro_on_arm.py +++ b/tests/micro/test_runtime_micro_on_arm.py @@ -23,7 +23,6 @@ from tvm import relay import tvm.micro as micro from tvm.micro import create_micro_mod -from tvm.relay.testing import resnet # Use real micro device - an STM32F746 discovery board # SETUP: diff --git a/tests/python/relay/test_analysis_extract_fused_functions.py b/tests/python/relay/test_analysis_extract_fused_functions.py index dab481ccd290f..41f6ca060d2b8 100644 --- a/tests/python/relay/test_analysis_extract_fused_functions.py +++ b/tests/python/relay/test_analysis_extract_fused_functions.py @@ -17,7 +17,7 @@ """Test function extraction""" import tvm from tvm import relay -from tvm.relay.testing.resnet import get_workload +from tvm.relay.testing.synthetic import get_workload def get_conv_net(): @@ -106,7 +106,7 @@ def is_conv_add(func): def test_extract_resnet(): mod, _params = get_workload() items = relay.analysis.extract_fused_functions(mod) - assert len(items) == 34 + assert len(items) == 6 if __name__ == '__main__': diff --git a/tests/python/relay/test_change_batch.py b/tests/python/relay/test_change_batch.py index e53887b1c4087..42376412ef72c 100644 --- a/tests/python/relay/test_change_batch.py +++ b/tests/python/relay/test_change_batch.py @@ -17,13 +17,13 @@ import tvm from tvm import te from tvm import relay -from tvm.relay.testing import resnet +from tvm.relay.testing import synthetic from tvm.relay import transform -def test_change_batch_resnet(): - net, params = resnet.get_workload() +def test_change_batch_synthetic(): + net, params = synthetic.get_workload() new_net = transform.ChangeBatch({net["main"].params[0]: 0}, batch_size=123)(net) - assert new_net["main"].checked_type.ret_type == relay.TensorType((123, 1000)) + assert new_net["main"].checked_type.ret_type.shape[0] == 123 if __name__ == "__main__": - test_change_batch_resnet() + test_change_batch_synthetic() diff --git a/tests/python/relay/test_pass_auto_quantize.py b/tests/python/relay/test_pass_auto_quantize.py index f7427974904ce..da5291f5e927a 100644 --- a/tests/python/relay/test_pass_auto_quantize.py +++ b/tests/python/relay/test_pass_auto_quantize.py @@ -72,18 +72,19 @@ def _check_batch_flatten(node): # check if batch_flatten is quantized relay.analysis.post_order_visit(qmod["main"], _check_batch_flatten) -def get_calibration_dataset(input_name): +def get_calibration_dataset(mod, input_name): dataset = [] + input_shape = [int(x) for x in mod["main"].checked_type.arg_types[0].shape] for i in range(5): - data = np.random.uniform(size=(1, 3, 224, 224)) + data = np.random.uniform(size=input_shape) dataset.append({input_name: data}) return dataset @pytest.mark.parametrize("create_target", [True, False]) def test_calibrate_target(create_target): - mod, params = testing.resnet.get_workload(num_layers=18) - dataset = get_calibration_dataset("data") + mod, params = testing.synthetic.get_workload() + dataset = get_calibration_dataset(mod, "data") with relay.quantize.qconfig(calibrate_mode="kl_divergence"): if create_target: with tvm.target.create("llvm"): @@ -94,8 +95,8 @@ def test_calibrate_target(create_target): def test_calibrate_memory_bound(): - mod, params = testing.resnet.get_workload(num_layers=18) - dataset = get_calibration_dataset("data") + mod, params = testing.synthetic.get_workload() + dataset = get_calibration_dataset(mod, "data") import multiprocessing num_cpu = multiprocessing.cpu_count() with relay.quantize.qconfig(calibrate_mode="kl_divergence", diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index d3bb0841da8f9..e96d36258c678 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -594,7 +594,7 @@ def test_add_op_broadcast(): check_result([x_data, y_data], x_data + y_data, mod=mod) def test_vm_optimize(): - mod, params = testing.resnet.get_workload(batch_size=1, num_layers=18) + mod, params = testing.synthetic.get_workload() comp = relay.vm.VMCompiler() opt_mod, _ = comp.optimize(mod, "llvm", params) diff --git a/tests/python/relay/test_vm_serialization.py b/tests/python/relay/test_vm_serialization.py index d1bcdccc30ccb..2aae431b2a944 100644 --- a/tests/python/relay/test_vm_serialization.py +++ b/tests/python/relay/test_vm_serialization.py @@ -51,13 +51,14 @@ def get_serialized_output(mod, *data, params=None, target="llvm", def run_network(mod, params, - data_shape=(1, 3, 224, 224), dtype='float32'): def get_vm_output(mod, data, params, target, ctx, dtype='float32'): ex = relay.create_executor('vm', mod=mod, ctx=ctx) result = ex.evaluate()(data, **params) return result.asnumpy().astype(dtype) + print(mod["main"]) + data_shape = [int(x) for x in mod["main"].checked_type.arg_types[0].shape] data = np.random.uniform(size=data_shape).astype(dtype) target = "llvm" ctx = tvm.cpu(0) @@ -272,8 +273,8 @@ def test_closure(): tvm.testing.assert_allclose(res.asnumpy(), 3.0) -def test_resnet(): - mod, params = testing.resnet.get_workload(batch_size=1, num_layers=18) +def test_synthetic(): + mod, params = testing.synthetic.get_workload() run_network(mod, params) @@ -306,6 +307,6 @@ def test_vm_shape_of(): test_adt_list() test_adt_compose() test_closure() - test_resnet() + test_synthetic() test_mobilenet() test_vm_shape_of() diff --git a/tests/python/unittest/test_autotvm_graph_tuner_utils.py b/tests/python/unittest/test_autotvm_graph_tuner_utils.py index bd0ebe0cd3f56..b675df798021c 100644 --- a/tests/python/unittest/test_autotvm_graph_tuner_utils.py +++ b/tests/python/unittest/test_autotvm_graph_tuner_utils.py @@ -24,7 +24,7 @@ from tvm import te from tvm import autotvm, relay -from tvm.relay.testing import resnet +from tvm.relay.testing import synthetic from tvm.autotvm.graph_tuner.utils import has_multiple_inputs, get_direct_ancestor, get_in_nodes, \ get_out_nodes, expr2graph, bind_inputs from tvm.autotvm.graph_tuner._base import OPT_OUT_OP @@ -56,7 +56,7 @@ def test_has_multiple_inputs(): def test_expr2graph(): - mod, _ = resnet.get_workload(num_layers=50, batch_size=1) + mod, _ = synthetic.get_workload() node_dict = {} node_list = [] target_ops = [relay.op.get("nn.conv2d")] diff --git a/tests/python/unittest/test_runtime_micro.py b/tests/python/unittest/test_runtime_micro.py index 1983def99cf58..841bffbcec1de 100644 --- a/tests/python/unittest/test_runtime_micro.py +++ b/tests/python/unittest/test_runtime_micro.py @@ -23,7 +23,6 @@ from tvm import relay import tvm.micro as micro from tvm.micro import create_micro_mod -from tvm.relay.testing import resnet # # Use the host emulated micro device. DEV_CONFIG_A = micro.device.host.generate_config() diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py index 5ab4e829f2ed2..f97a4454f3bee 100644 --- a/tests/python/unittest/test_runtime_module_based_interface.py +++ b/tests/python/unittest/test_runtime_module_based_interface.py @@ -21,11 +21,14 @@ from tvm.contrib import graph_runtime from tvm.contrib.debugger import debug_runtime +def input_shape(mod): + return [int(x) for x in mod["main"].checked_type.arg_types[0].shape] + def verify(data): if not tvm.runtime.enabled("llvm"): print("Skip because llvm is not enabled") return - mod, params = relay.testing.resnet.get_workload(num_layers=18) + mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): graph, lib, graph_params = relay.build_module.build(mod, "llvm", params=params) @@ -42,10 +45,10 @@ def test_legacy_compatibility(): if not tvm.runtime.enabled("llvm"): print("Skip because llvm is not enabled") return - mod, params = relay.testing.resnet.get_workload(num_layers=18) + mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): graph, lib, graph_params = relay.build_module.build(mod, "llvm", params=params) - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32") ctx = tvm.cpu() module = graph_runtime.create(graph, lib, ctx) module.set_input("data", data) @@ -58,10 +61,10 @@ def test_cpu(): if not tvm.runtime.enabled("llvm"): print("Skip because llvm is not enabled") return - mod, params = relay.testing.resnet.get_workload(num_layers=18) + mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32") # raw api ctx = tvm.cpu() gmod = complied_graph_lib['default'](ctx) @@ -84,10 +87,10 @@ def test_gpu(): if not tvm.runtime.enabled("cuda"): print("Skip because cuda is not enabled") return - mod, params = relay.testing.resnet.get_workload(num_layers=18) + mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): complied_graph_lib = relay.build_module.build(mod, "cuda", params=params) - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32") ctx = tvm.gpu() # raw api @@ -112,7 +115,7 @@ def verify_cpu_export(obj_format): if not tvm.runtime.enabled("llvm"): print("Skip because llvm is not enabled") return - mod, params = relay.testing.resnet.get_workload(num_layers=18) + mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) @@ -133,7 +136,7 @@ def verify_cpu_export(obj_format): set_input = gmod["set_input"] run = gmod["run"] get_output = gmod["get_output"] - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32") set_input("data", tvm.nd.array(data)) run() out = get_output(0).asnumpy() @@ -150,7 +153,7 @@ def verify_gpu_export(obj_format): if not tvm.runtime.enabled("cuda"): print("Skip because cuda is not enabled") return - mod, params = relay.testing.resnet.get_workload(num_layers=18) + mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): complied_graph_lib = relay.build_module.build(mod, "cuda", params=params) @@ -164,7 +167,7 @@ def verify_gpu_export(obj_format): path_lib = temp.relpath(file_name) complied_graph_lib.export_library(path_lib) loaded_lib = tvm.runtime.load_module(path_lib) - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32") ctx = tvm.gpu() # raw api @@ -188,7 +191,7 @@ def verify_rpc_cpu_export(obj_format): if not tvm.runtime.enabled("llvm"): print("Skip because llvm is not enabled") return - mod, params = relay.testing.resnet.get_workload(num_layers=18) + mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) @@ -203,11 +206,11 @@ def verify_rpc_cpu_export(obj_format): complied_graph_lib.export_library(path_lib) from tvm import rpc - server = rpc.Server("localhost", use_popen=True) + server = rpc.Server("localhost", use_popen=True, port=9093) remote = rpc.connect(server.host, server.port) remote.upload(path_lib) loaded_lib = remote.load_module(path_lib) - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32") ctx = remote.cpu() # raw api @@ -231,7 +234,7 @@ def verify_rpc_gpu_export(obj_format): if not tvm.runtime.enabled("cuda"): print("Skip because cuda is not enabled") return - mod, params = relay.testing.resnet.get_workload(num_layers=18) + mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): complied_graph_lib = relay.build_module.build(mod, "cuda", params=params) @@ -246,11 +249,11 @@ def verify_rpc_gpu_export(obj_format): complied_graph_lib.export_library(path_lib) from tvm import rpc - server = rpc.Server("localhost", use_popen=True) + server = rpc.Server("localhost", use_popen=True, port=9094) remote = rpc.connect(server.host, server.port) remote.upload(path_lib) loaded_lib = remote.load_module(path_lib) - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32") ctx = remote.gpu() # raw api @@ -281,7 +284,7 @@ def verify_cpu_remove_package_params(obj_format): if not tvm.runtime.enabled("llvm"): print("Skip because llvm is not enabled") return - mod, params = relay.testing.resnet.get_workload(num_layers=18) + mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) @@ -298,7 +301,7 @@ def verify_cpu_remove_package_params(obj_format): with open(temp.relpath("deploy_param.params"), "wb") as fo: fo.write(relay.save_param_dict(complied_graph_lib.get_params())) loaded_lib = tvm.runtime.load_module(path_lib) - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32") ctx = tvm.cpu(0) # raw api @@ -327,7 +330,7 @@ def verify_gpu_remove_package_params(obj_format): if not tvm.runtime.enabled("cuda"): print("Skip because cuda is not enabled") return - mod, params = relay.testing.resnet.get_workload(num_layers=18) + mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): complied_graph_lib = relay.build_module.build(mod, "cuda", params=params) @@ -344,7 +347,7 @@ def verify_gpu_remove_package_params(obj_format): with open(temp.relpath("deploy_param.params"), "wb") as fo: fo.write(relay.save_param_dict(complied_graph_lib.get_params())) loaded_lib = tvm.runtime.load_module(path_lib) - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32") ctx = tvm.gpu(0) # raw api @@ -373,7 +376,7 @@ def verify_rpc_cpu_remove_package_params(obj_format): if not tvm.runtime.enabled("llvm"): print("Skip because llvm is not enabled") return - mod, params = relay.testing.resnet.get_workload(num_layers=18) + mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) @@ -392,11 +395,11 @@ def verify_rpc_cpu_remove_package_params(obj_format): fo.write(relay.save_param_dict(complied_graph_lib.get_params())) from tvm import rpc - server = rpc.Server("localhost", use_popen=True) + server = rpc.Server("localhost", use_popen=True, port=9095) remote = rpc.connect(server.host, server.port) remote.upload(path_lib) loaded_lib = remote.load_module(path_lib) - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32") ctx = remote.cpu() # raw api @@ -425,7 +428,7 @@ def verify_rpc_gpu_remove_package_params(obj_format): if not tvm.runtime.enabled("cuda"): print("Skip because cuda is not enabled") return - mod, params = relay.testing.resnet.get_workload(num_layers=18) + mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): complied_graph_lib = relay.build_module.build(mod, "cuda", params=params) @@ -444,11 +447,11 @@ def verify_rpc_gpu_remove_package_params(obj_format): fo.write(relay.save_param_dict(complied_graph_lib.get_params())) from tvm import rpc - server = rpc.Server("localhost", use_popen=True) + server = rpc.Server("localhost", use_popen=True, port=9092) remote = rpc.connect(server.host, server.port) remote.upload(path_lib) loaded_lib = remote.load_module(path_lib) - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32") ctx = remote.gpu() # raw api @@ -483,10 +486,10 @@ def test_debug_graph_runtime(): if not tvm.runtime.enabled("llvm"): print("Skip because llvm is not enabled") return - mod, params = relay.testing.resnet.get_workload(num_layers=18) + mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32") # raw api ctx = tvm.cpu() diff --git a/tests/python/unittest/test_runtime_module_export.py b/tests/python/unittest/test_runtime_module_export.py index 8ee197d643acc..9a859da39ae21 100644 --- a/tests/python/unittest/test_runtime_module_export.py +++ b/tests/python/unittest/test_runtime_module_export.py @@ -66,11 +66,11 @@ def verify_gpu_mod_export(obj_format): print("skip because %s is not enabled..." % device) return - resnet18_mod, resnet18_params = relay.testing.resnet.get_workload(num_layers=18) - resnet50_mod, resnet50_params = relay.testing.resnet.get_workload(num_layers=50) + synthetic_mod, synthetic_params = relay.testing.synthetic.get_workload() + synthetic_llvm_mod, synthetic_llvm_params = relay.testing.synthetic.get_workload() with tvm.transform.PassContext(opt_level=3): - _, resnet18_gpu_lib, _ = relay.build_module.build(resnet18_mod, "cuda", params=resnet18_params) - _, resnet50_cpu_lib, _ = relay.build_module.build(resnet50_mod, "llvm", params=resnet50_params) + _, synthetic_gpu_lib, _ = relay.build_module.build(synthetic_mod, "cuda", params=synthetic_params) + _, synthetic_llvm_cpu_lib, _ = relay.build_module.build(synthetic_llvm_mod, "llvm", params=synthetic_llvm_params) from tvm.contrib import util temp = util.tempdir() @@ -80,8 +80,8 @@ def verify_gpu_mod_export(obj_format): assert obj_format == ".tar" file_name = "deploy_lib.tar" path_lib = temp.relpath(file_name) - resnet18_gpu_lib.imported_modules[0].import_module(resnet50_cpu_lib) - resnet18_gpu_lib.export_library(path_lib) + synthetic_gpu_lib.imported_modules[0].import_module(synthetic_llvm_cpu_lib) + synthetic_gpu_lib.export_library(path_lib) loaded_lib = tvm.runtime.load_module(path_lib) assert loaded_lib.type_key == "library" assert loaded_lib.imported_modules[0].type_key == "cuda" @@ -93,9 +93,9 @@ def verify_multi_dso_mod_export(obj_format): print("skip because %s is not enabled..." % device) return - resnet18_mod, resnet18_params = relay.testing.resnet.get_workload(num_layers=18) + synthetic_mod, synthetic_params = relay.testing.synthetic.get_workload() with tvm.transform.PassContext(opt_level=3): - _, resnet18_cpu_lib, _ = relay.build_module.build(resnet18_mod, "llvm", params=resnet18_params) + _, synthetic_cpu_lib, _ = relay.build_module.build(synthetic_mod, "llvm", params=synthetic_params) A = te.placeholder((1024,), name='A') B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') @@ -109,8 +109,8 @@ def verify_multi_dso_mod_export(obj_format): assert obj_format == ".tar" file_name = "deploy_lib.tar" path_lib = temp.relpath(file_name) - resnet18_cpu_lib.import_module(f) - resnet18_cpu_lib.export_library(path_lib) + synthetic_cpu_lib.import_module(f) + synthetic_cpu_lib.export_library(path_lib) loaded_lib = tvm.runtime.load_module(path_lib) assert loaded_lib.type_key == "library" assert loaded_lib.imported_modules[0].type_key == "library" @@ -177,9 +177,9 @@ def verify_multi_c_mod_export(): print("skip because %s is not enabled..." % device) return - resnet18_mod, resnet18_params = relay.testing.resnet.get_workload(num_layers=18) + synthetic_mod, synthetic_params = relay.testing.synthetic.get_workload() with tvm.transform.PassContext(opt_level=3): - _, resnet18_cpu_lib, _ = relay.build_module.build(resnet18_mod, "llvm", params=resnet18_params) + _, synthetic_cpu_lib, _ = relay.build_module.build(synthetic_mod, "llvm", params=synthetic_params) A = te.placeholder((1024,), name='A') B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') @@ -190,10 +190,10 @@ def verify_multi_c_mod_export(): temp = util.tempdir() file_name = "deploy_lib.so" path_lib = temp.relpath(file_name) - resnet18_cpu_lib.import_module(f) - resnet18_cpu_lib.import_module(engine_module) + synthetic_cpu_lib.import_module(f) + synthetic_cpu_lib.import_module(engine_module) kwargs = {"options": ["-O2", "-std=c++14", "-I" + header_file_dir_path.relpath("")]} - resnet18_cpu_lib.export_library(path_lib, fcompile=False, **kwargs) + synthetic_cpu_lib.export_library(path_lib, fcompile=False, **kwargs) loaded_lib = tvm.runtime.load_module(path_lib) assert loaded_lib.type_key == "library" assert loaded_lib.imported_modules[0].type_key == "library" diff --git a/tests/python/unittest/test_target_codegen_blob.py b/tests/python/unittest/test_target_codegen_blob.py index 7cd579397ec82..0059083ebdcc5 100644 --- a/tests/python/unittest/test_target_codegen_blob.py +++ b/tests/python/unittest/test_target_codegen_blob.py @@ -23,14 +23,16 @@ from tvm import te import ctypes -def test_resnet18(): +def test_synthetic(): for device in ["llvm", "cuda"]: if not tvm.runtime.enabled(device): print("skip because %s is not enabled..." % device) return + input_shape = (1, 5, 23, 61) + def verify(data): - mod, params = relay.testing.resnet.get_workload(num_layers=18) + mod, params = relay.testing.synthetic.get_workload(input_shape=input_shape) with tvm.transform.PassContext(opt_level=3): graph, lib, graph_params = relay.build_module.build(mod, "llvm", params=params) ctx = tvm.cpu() @@ -41,14 +43,14 @@ def verify(data): out = module.get_output(0).asnumpy() return out - resnet18_mod, resnet18_params = relay.testing.resnet.get_workload(num_layers=18) + synthetic_mod, synthetic_params = relay.testing.synthetic.get_workload(input_shape=input_shape) with tvm.transform.PassContext(opt_level=3): - graph, resnet18_gpu_lib, graph_params = relay.build_module.build(resnet18_mod, "cuda", params=resnet18_params) + graph, synthetic_gpu_lib, graph_params = relay.build_module.build(synthetic_mod, "cuda", params=synthetic_params) from tvm.contrib import util temp = util.tempdir() path_lib = temp.relpath("deploy_lib.so") - resnet18_gpu_lib.export_library(path_lib) + synthetic_gpu_lib.export_library(path_lib) with open(temp.relpath("deploy_graph.json"), "w") as fo: fo.write(graph) with open(temp.relpath("deploy_param.params"), "wb") as fo: @@ -57,7 +59,7 @@ def verify(data): loaded_lib = tvm.runtime.load_module(path_lib) loaded_json = open(temp.relpath("deploy_graph.json")).read() loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read()) - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + data = np.random.uniform(-1, 1, size=input_shape).astype("float32") ctx = tvm.gpu() module = graph_runtime.create(loaded_json, loaded_lib, ctx) module.load_params(loaded_params) @@ -96,5 +98,5 @@ def test_cuda_lib(): if __name__ == "__main__": - test_resnet18() + test_synthetic() #test_system_lib()