From 80d9549190b60a9d36ef93e6d842de26515d6718 Mon Sep 17 00:00:00 2001
From: Yuanjing Shi
Date: Fri, 27 May 2022 16:41:54 -0700
Subject: [PATCH] [Meta Schedule] Fix testing issues for models with more than
 one input (#11298)

---
 .../testing/custom_builder_runner.py        |  8 +--
 .../testing/tune_relay_auto_scheduler.py    | 30 +++++++-----
 .../testing/tune_relay_meta_schedule.py     | 30 +++++++-----
 .../unittest/test_meta_schedule_tune_tir.py | 49 +++++++++++++++++++
 4 files changed, 92 insertions(+), 25 deletions(-)

diff --git a/python/tvm/meta_schedule/testing/custom_builder_runner.py b/python/tvm/meta_schedule/testing/custom_builder_runner.py
index 83bb4aab516b..3ba007d9a4d3 100644
--- a/python/tvm/meta_schedule/testing/custom_builder_runner.py
+++ b/python/tvm/meta_schedule/testing/custom_builder_runner.py
@@ -145,7 +145,7 @@ def run_module_via_rpc(
     rpc_config: "RPCConfig",
     lib: "Module",
     dev_type: str,
-    args: List["np.ndarray"],
+    args: Dict[str, "np.ndarray"],
     continuation: Callable,
 ):
     """Execute a tvm.runtime.Module on RPC remote"""
@@ -166,5 +166,7 @@ def run_module_via_rpc(
     _, filename = os.path.split(filename)
     rt_mod = session.load_module(filename)
     dev = session.device(dev_type=dev_type, dev_id=0)
-    args = [ndarray.array(arg, dev) for arg in args]
-    return continuation(rt_mod, dev, *args)
+    nd_args = {}
+    for arg_key, arg_value in args.items():
+        nd_args[arg_key] = ndarray.array(arg_value, dev)
+    return continuation(rt_mod, dev, nd_args)
diff --git a/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py b/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py
index 2a2c20868bb7..abac49c50c6e 100644
--- a/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py
+++ b/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py
@@ -134,10 +134,13 @@ def main():
         ARGS.input_shape,
         cache_dir=ARGS.cache_dir,
     )
+    input_info = {input_name: input_shape}
+    input_data = {}
     print(f"Workload: {ARGS.workload}")
-    print(f"  input_name: {input_name}")
-    print(f"  input_shape: {input_shape}")
-    print(f"  input_dtype: {input_dtype}")
+    for input_name, input_shape in input_info.items():
+        print(f"  input_name: {input_name}")
+        print(f"  input_shape: {input_shape}")
+        print(f"  input_dtype: {input_dtype}")
     tasks, task_weights = auto_scheduler.extract_tasks(
         mod["main"],
         params,
@@ -170,10 +173,13 @@ def main():
         params=params,
     )
     graph, rt_mod, params = lib.graph_json, lib.lib, lib.params
-    if input_dtype.startswith("float"):
-        input_data = np.random.uniform(size=input_shape).astype(input_dtype)
-    else:
-        input_data = np.random.randint(low=0, high=10000, size=input_shape, dtype=input_dtype)
+    for input_name, input_shape in input_info.items():
+        if input_dtype.startswith("float"):
+            input_data[input_name] = np.random.uniform(size=input_shape).astype(input_dtype)
+        else:
+            input_data[input_name] = np.random.randint(
+                low=0, high=10000, size=input_shape, dtype=input_dtype
+            )
 
     def f_timer(rt_mod, dev, input_data):
         # pylint: disable=import-outside-toplevel
@@ -182,7 +188,8 @@ def f_timer(rt_mod, dev, input_data):
         # pylint: enable=import-outside-toplevel
 
         mod = GraphModule(rt_mod["default"](dev))
-        mod.set_input(input_name, input_data)
+        for input_name, input_value in input_data.items():
+            mod.set_input(input_name, input_value)
         ftimer = mod.module.time_evaluator(
             "run",
             dev,
@@ -196,7 +203,7 @@ def f_timer(rt_mod, dev, input_data):
         rpc_config=ARGS.rpc_config,
         lib=lib,
         dev_type=ARGS.target.kind.name,
-        args=[input_data],
+        args=input_data,
         continuation=f_timer,
     )
 
@@ -206,7 +213,8 @@ def f_per_layer(rt_mod, dev, input_data):
         # pylint: enable=import-outside-toplevel
 
         mod = create(graph, rt_mod, dev)
-        mod.set_input(input_name, input_data)
+        for input_name, input_value in input_data.items():
+            mod.set_input(input_name, input_value)
         graph_nodes = [n["name"] for n in json.loads(graph)["nodes"]]
         graph_time = mod.run_individual(number=10, repeat=1, min_repeat_ms=5000)
         print("|graph_nodes| = ", len(graph_nodes))
@@ -219,7 +227,7 @@ def f_per_layer(rt_mod, dev, input_data):
         rpc_config=ARGS.rpc_config,
         lib=rt_mod,
         dev_type=ARGS.target.kind.name,
-        args=[input_data],
+        args=input_data,
         continuation=f_per_layer,
     )
 
diff --git a/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py b/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py
index 88de0c336073..bd858e0f2d36 100644
--- a/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py
+++ b/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py
@@ -103,10 +103,13 @@ def main():
         ARGS.input_shape,
         cache_dir=ARGS.cache_dir,
     )
+    input_info = {input_name: input_shape}
+    input_data = {}
     print(f"Workload: {ARGS.workload}")
-    print(f"  input_name: {input_name}")
-    print(f"  input_shape: {input_shape}")
-    print(f"  input_dtype: {input_dtype}")
+    for input_name, input_shape in input_info.items():
+        print(f"  input_name: {input_name}")
+        print(f"  input_shape: {input_shape}")
+        print(f"  input_dtype: {input_dtype}")
     alloc_repeat = 1
     runner = ms.runner.RPCRunner(
         rpc_config=ARGS.rpc_config,
@@ -133,10 +136,13 @@ def main():
         params=params,
     )
     graph, rt_mod, params = lib.graph_json, lib.lib, lib.params
-    if input_dtype.startswith("float"):
-        input_data = np.random.uniform(size=input_shape).astype(input_dtype)
-    else:
-        input_data = np.random.randint(low=0, high=10000, size=input_shape, dtype=input_dtype)
+    for input_name, input_shape in input_info.items():
+        if input_dtype.startswith("float"):
+            input_data[input_name] = np.random.uniform(size=input_shape).astype(input_dtype)
+        else:
+            input_data[input_name] = np.random.randint(
+                low=0, high=10000, size=input_shape, dtype=input_dtype
+            )
 
     def f_timer(rt_mod, dev, input_data):
         # pylint: disable=import-outside-toplevel
@@ -145,7 +151,8 @@ def f_timer(rt_mod, dev, input_data):
         # pylint: enable=import-outside-toplevel
 
         mod = GraphModule(rt_mod["default"](dev))
-        mod.set_input(input_name, input_data)
+        for input_name, input_value in input_data.items():
+            mod.set_input(input_name, input_value)
         ftimer = mod.module.time_evaluator(
             "run",
             dev,
@@ -159,7 +166,7 @@ def f_timer(rt_mod, dev, input_data):
         rpc_config=ARGS.rpc_config,
         lib=lib,
         dev_type=ARGS.target.kind.name,
-        args=[input_data],
+        args=input_data,
         continuation=f_timer,
     )
 
@@ -169,7 +176,8 @@ def f_per_layer(rt_mod, dev, input_data):
         # pylint: enable=import-outside-toplevel
 
         mod = create(graph, rt_mod, dev)
-        mod.set_input(input_name, input_data)
+        for input_name, input_value in input_data.items():
+            mod.set_input(input_name, input_value)
         graph_nodes = [n["name"] for n in json.loads(graph)["nodes"]]
         graph_time = mod.run_individual(number=10, repeat=1, min_repeat_ms=5000)
         print("|graph_nodes| = ", len(graph_nodes))
@@ -182,7 +190,7 @@ def f_per_layer(rt_mod, dev, input_data):
         rpc_config=ARGS.rpc_config,
         lib=rt_mod,
         dev_type=ARGS.target.kind.name,
-        args=[input_data],
+        args=input_data,
         continuation=f_per_layer,
     )
 
diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py
index a7806ebda28a..0e8c205230e6 100644
--- a/tests/python/unittest/test_meta_schedule_tune_tir.py
+++ b/tests/python/unittest/test_meta_schedule_tune_tir.py
@@ -17,9 +17,15 @@
 # pylint: disable=missing-docstring
 import logging
 import tempfile
 
+import numpy as np
 import pytest
+import tvm
+
+from tvm import meta_schedule as ms
 from tvm.meta_schedule import TuneConfig, tune_tir
+from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
+from tvm.meta_schedule.testing.local_rpc import LocalRPC
 from tvm.script import tir as T
 from tvm.target import Target
 from tvm.tir import Schedule
@@ -89,6 +95,49 @@ def test_tune_matmul_cuda():
             print(sch.trace)
 
 
+def test_tune_run_module_via_rpc():
+    target = tvm.target.Target("llvm")
+    rt_mod = tvm.build(matmul, target)
+
+    # construct the input
+    input_data = {}
+    input_shape = (128, 128)
+    input_dtype = "float32"
+    a_np = np.random.uniform(size=input_shape).astype(input_dtype)
+    b_np = np.random.uniform(size=input_shape).astype(input_dtype)
+    c_np = np.zeros(input_shape).astype(input_dtype)
+    for i in range(128):
+        for j in range(128):
+            for k in range(128):
+                c_np[i, j] = c_np[i, j] + a_np[i, k] * b_np[j, k]
+    input_data["a"] = a_np
+    input_data["b"] = b_np
+    input_data["c"] = np.zeros(input_shape).astype(input_dtype)
+
+    with LocalRPC() as rpc:
+        rpc_config = ms.runner.RPCConfig(
+            tracker_host=rpc.tracker_host,
+            tracker_port=rpc.tracker_port,
+            tracker_key=rpc.tracker_key,
+            session_priority=1,
+            session_timeout_sec=100,
+        )
+
+        def f_timer(rt_mod, dev, input_data):
+            rt_mod(input_data["a"], input_data["b"], input_data["c"])
+            return input_data["c"]
+
+        result = run_module_via_rpc(
+            rpc_config=rpc_config,
+            lib=rt_mod,
+            dev_type=target.kind.name,
+            args=input_data,
+            continuation=f_timer,
+        )
+        tvm.testing.assert_allclose(result.numpy(), c_np, rtol=1e-3)
+
+
 if __name__ == """__main__""":
     test_tune_matmul_cpu()
     test_tune_matmul_cuda()
+    test_tune_run_module_via_rpc()
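
Reviewer note: a minimal sketch of the dict-based calling convention this patch
introduces. After the change, run_module_via_rpc uploads the built module,
copies each named numpy array to the remote device via ndarray.array, and hands
the resulting dict to the continuation as `continuation(rt_mod, dev, nd_args)`,
so multi-input models no longer depend on positional argument order. The
snippet below is illustrative only: `lib` and the two input names are
assumptions, not part of the patch.

    # Sketch only; not part of the patch. Assumes `lib` is a runtime.Module
    # built with an embedded graph executor (e.g. the .module of a relay
    # build); the input names below are illustrative.
    import numpy as np
    from tvm import meta_schedule as ms
    from tvm.contrib.graph_executor import GraphModule
    from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
    from tvm.meta_schedule.testing.local_rpc import LocalRPC

    input_data = {
        "input_ids": np.random.randint(0, 30522, size=(1, 64), dtype="int64"),
        "segment_ids": np.zeros((1, 64), dtype="int64"),
    }

    def f_run(rt_mod, dev, nd_args):
        # nd_args arrives as a dict of device NDArrays keyed by input name,
        # matching the new continuation(rt_mod, dev, nd_args) convention.
        mod = GraphModule(rt_mod["default"](dev))
        for name, value in nd_args.items():
            mod.set_input(name, value)
        mod.run()
        return mod.get_output(0)

    with LocalRPC() as rpc:
        rpc_config = ms.runner.RPCConfig(
            tracker_host=rpc.tracker_host,
            tracker_port=rpc.tracker_port,
            tracker_key=rpc.tracker_key,
            session_priority=1,
            session_timeout_sec=100,
        )
        out = run_module_via_rpc(
            rpc_config=rpc_config,
            lib=lib,  # assumed: a graph-executor module built elsewhere
            dev_type="llvm",
            args=input_data,  # dict of named inputs, per this patch
            continuation=f_run,
        )

Passing the inputs as a dict rather than a list is what lets the continuation
bind each array to its graph input by name (mod.set_input(name, value)), which
is the behavior the new test_tune_run_module_via_rpc exercises end to end.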