[Meta Schedule] Fix testing issues for models with more than one inputs (#11298)
Yuanjing Shi authored May 27, 2022
1 parent 903f785 commit 80d9549
Showing 4 changed files with 92 additions and 25 deletions.
8 changes: 5 additions & 3 deletions python/tvm/meta_schedule/testing/custom_builder_runner.py
@@ -145,7 +145,7 @@ def run_module_via_rpc(
     rpc_config: "RPCConfig",
     lib: "Module",
     dev_type: str,
-    args: List["np.ndarray"],
+    args: Dict[str, "np.ndarray"],
     continuation: Callable,
 ):
     """Execute a tvm.runtime.Module on RPC remote"""
@@ -166,5 +166,7 @@ def run_module_via_rpc(
     _, filename = os.path.split(filename)
     rt_mod = session.load_module(filename)
     dev = session.device(dev_type=dev_type, dev_id=0)
-    args = [ndarray.array(arg, dev) for arg in args]
-    return continuation(rt_mod, dev, *args)
+    nd_args = {}
+    for arg_key, arg_value in args.items():
+        nd_args[arg_key] = ndarray.array(arg_value, dev)
+    return continuation(rt_mod, dev, nd_args)
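
With this change, callers pass inputs keyed by parameter name, and the continuation receives one dict of device arrays instead of positional arguments. A minimal caller sketch, assuming `rpc_config` and `lib` are set up elsewhere (the names and shape here are illustrative, not from the commit):

    import numpy as np

    # Inputs are keyed by parameter name; run_module_via_rpc uploads each to the remote device.
    input_data = {"data": np.random.uniform(size=(1, 3, 224, 224)).astype("float32")}

    def f_run(rt_mod, dev, nd_args):
        # nd_args maps input names to tvm.nd.NDArray objects already living on `dev`.
        ...

    result = run_module_via_rpc(
        rpc_config=rpc_config,  # an ms.runner.RPCConfig, assumed configured elsewhere
        lib=lib,                # the built tvm.runtime.Module
        dev_type="llvm",
        args=input_data,
        continuation=f_run,
    )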
30 changes: 19 additions & 11 deletions python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py
@@ -134,10 +134,13 @@ def main():
         ARGS.input_shape,
         cache_dir=ARGS.cache_dir,
     )
+    input_info = {input_name: input_shape}
+    input_data = {}
     print(f"Workload: {ARGS.workload}")
-    print(f"  input_name: {input_name}")
-    print(f"  input_shape: {input_shape}")
-    print(f"  input_dtype: {input_dtype}")
+    for input_name, input_shape in input_info.items():
+        print(f"  input_name: {input_name}")
+        print(f"  input_shape: {input_shape}")
+        print(f"  input_dtype: {input_dtype}")
     tasks, task_weights = auto_scheduler.extract_tasks(
         mod["main"],
         params,
@@ -170,10 +173,13 @@ def main():
         params=params,
     )
     graph, rt_mod, params = lib.graph_json, lib.lib, lib.params
-    if input_dtype.startswith("float"):
-        input_data = np.random.uniform(size=input_shape).astype(input_dtype)
-    else:
-        input_data = np.random.randint(low=0, high=10000, size=input_shape, dtype=input_dtype)
+    for input_name, input_shape in input_info.items():
+        if input_dtype.startswith("float"):
+            input_data[input_name] = np.random.uniform(size=input_shape).astype(input_dtype)
+        else:
+            input_data[input_name] = np.random.randint(
+                low=0, high=10000, size=input_shape, dtype=input_dtype
+            )

     def f_timer(rt_mod, dev, input_data):
         # pylint: disable=import-outside-toplevel
@@ -182,7 +188,8 @@ def f_timer(rt_mod, dev, input_data):
         # pylint: enable=import-outside-toplevel

         mod = GraphModule(rt_mod["default"](dev))
-        mod.set_input(input_name, input_data)
+        for input_name, input_value in input_data.items():
+            mod.set_input(input_name, input_value)
         ftimer = mod.module.time_evaluator(
             "run",
             dev,
@@ -196,7 +203,7 @@ def f_timer(rt_mod, dev, input_data):
         rpc_config=ARGS.rpc_config,
         lib=lib,
         dev_type=ARGS.target.kind.name,
-        args=[input_data],
+        args=input_data,
         continuation=f_timer,
     )

@@ -206,7 +213,8 @@ def f_per_layer(rt_mod, dev, input_data):

         # pylint: enable=import-outside-toplevel
         mod = create(graph, rt_mod, dev)
-        mod.set_input(input_name, input_data)
+        for input_name, input_value in input_data.items():
+            mod.set_input(input_name, input_value)
         graph_nodes = [n["name"] for n in json.loads(graph)["nodes"]]
         graph_time = mod.run_individual(number=10, repeat=1, min_repeat_ms=5000)
         print("|graph_nodes| = ", len(graph_nodes))
@@ -219,7 +227,7 @@ def f_per_layer(rt_mod, dev, input_data):
         rpc_config=ARGS.rpc_config,
         lib=rt_mod,
         dev_type=ARGS.target.kind.name,
-        args=[input_data],
+        args=input_data,
         continuation=f_per_layer,
     )

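The script still seeds input_info from a single (input_name, input_shape) pair, but every consumer now iterates over the dict, so extending it to a genuinely multi-input model only requires more entries. A hypothetical sketch of that extension (the input names, shapes, and dtype are illustrative, not part of the commit):

    import numpy as np

    # e.g. a transformer-style workload with two integer inputs
    input_info = {"input_ids": (1, 128), "attention_mask": (1, 128)}
    input_dtype = "int64"
    input_data = {}
    for input_name, input_shape in input_info.items():
        # Same fill logic as the diff above: uniform floats or bounded random ints.
        if input_dtype.startswith("float"):
            input_data[input_name] = np.random.uniform(size=input_shape).astype(input_dtype)
        else:
            input_data[input_name] = np.random.randint(
                low=0, high=10000, size=input_shape, dtype=input_dtype
            )

Note that a single input_dtype is still shared by all inputs; per-input dtypes would need input_info to carry (shape, dtype) pairs.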
30 changes: 19 additions & 11 deletions python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py
@@ -103,10 +103,13 @@ def main():
         ARGS.input_shape,
         cache_dir=ARGS.cache_dir,
     )
+    input_info = {input_name: input_shape}
+    input_data = {}
     print(f"Workload: {ARGS.workload}")
-    print(f"  input_name: {input_name}")
-    print(f"  input_shape: {input_shape}")
-    print(f"  input_dtype: {input_dtype}")
+    for input_name, input_shape in input_info.items():
+        print(f"  input_name: {input_name}")
+        print(f"  input_shape: {input_shape}")
+        print(f"  input_dtype: {input_dtype}")
     alloc_repeat = 1
     runner = ms.runner.RPCRunner(
         rpc_config=ARGS.rpc_config,
@@ -133,10 +136,13 @@ def main():
         params=params,
     )
     graph, rt_mod, params = lib.graph_json, lib.lib, lib.params
-    if input_dtype.startswith("float"):
-        input_data = np.random.uniform(size=input_shape).astype(input_dtype)
-    else:
-        input_data = np.random.randint(low=0, high=10000, size=input_shape, dtype=input_dtype)
+    for input_name, input_shape in input_info.items():
+        if input_dtype.startswith("float"):
+            input_data[input_name] = np.random.uniform(size=input_shape).astype(input_dtype)
+        else:
+            input_data[input_name] = np.random.randint(
+                low=0, high=10000, size=input_shape, dtype=input_dtype
+            )

     def f_timer(rt_mod, dev, input_data):
         # pylint: disable=import-outside-toplevel
@@ -145,7 +151,8 @@ def f_timer(rt_mod, dev, input_data):
         # pylint: enable=import-outside-toplevel

         mod = GraphModule(rt_mod["default"](dev))
-        mod.set_input(input_name, input_data)
+        for input_name, input_value in input_data.items():
+            mod.set_input(input_name, input_value)
         ftimer = mod.module.time_evaluator(
             "run",
             dev,
@@ -159,7 +166,7 @@ def f_timer(rt_mod, dev, input_data):
         rpc_config=ARGS.rpc_config,
         lib=lib,
         dev_type=ARGS.target.kind.name,
-        args=[input_data],
+        args=input_data,
         continuation=f_timer,
     )

@@ -169,7 +176,8 @@ def f_per_layer(rt_mod, dev, input_data):

         # pylint: enable=import-outside-toplevel
         mod = create(graph, rt_mod, dev)
-        mod.set_input(input_name, input_data)
+        for input_name, input_value in input_data.items():
+            mod.set_input(input_name, input_value)
         graph_nodes = [n["name"] for n in json.loads(graph)["nodes"]]
         graph_time = mod.run_individual(number=10, repeat=1, min_repeat_ms=5000)
         print("|graph_nodes| = ", len(graph_nodes))
@@ -182,7 +190,7 @@ def f_per_layer(rt_mod, dev, input_data):
         rpc_config=ARGS.rpc_config,
         lib=rt_mod,
         dev_type=ARGS.target.kind.name,
-        args=[input_data],
+        args=input_data,
         continuation=f_per_layer,
     )

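As a design aside: tvm.contrib.graph_executor.GraphModule.set_input also accepts keyword parameters, so the per-input loop in both scripts could plausibly be collapsed into one call. A minimal sketch of that alternative, not what the commit does:

    # GraphModule.set_input(key=None, value=None, **params) binds **params by name,
    # so this is equivalent to looping set_input(name, value) over the dict entries.
    mod.set_input(**input_data)

The explicit loop used in the commit is arguably clearer and mirrors the dict-based argument convention of run_module_via_rpc.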
49 changes: 49 additions & 0 deletions tests/python/unittest/test_meta_schedule_tune_tir.py
@@ -17,9 +17,15 @@
 # pylint: disable=missing-docstring
 import logging
 import tempfile
+import numpy as np
+
 import pytest
 import tvm
+
+from tvm import meta_schedule as ms
 from tvm.meta_schedule import TuneConfig, tune_tir
+from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
+from tvm.meta_schedule.testing.local_rpc import LocalRPC
 from tvm.script import tir as T
 from tvm.target import Target
 from tvm.tir import Schedule
@@ -89,6 +95,49 @@ def test_tune_matmul_cuda():
         print(sch.trace)


+def test_tune_run_module_via_rpc():
+    target = tvm.target.Target("llvm")
+    rt_mod = tvm.build(matmul, target)
+
+    # construct the input
+    input_data = {}
+    input_shape = (128, 128)
+    input_dtype = "float32"
+    a_np = np.random.uniform(size=input_shape).astype(input_dtype)
+    b_np = np.random.uniform(size=input_shape).astype(input_dtype)
+    c_np = np.zeros(input_shape).astype(input_dtype)
+    for i in range(128):
+        for j in range(128):
+            for k in range(128):
+                c_np[i, j] = c_np[i, j] + a_np[i, k] * b_np[j, k]
+    input_data["a"] = a_np
+    input_data["b"] = b_np
+    input_data["c"] = np.zeros(input_shape).astype(input_dtype)
+
+    with LocalRPC() as rpc:
+        rpc_config = ms.runner.RPCConfig(
+            tracker_host=rpc.tracker_host,
+            tracker_port=rpc.tracker_port,
+            tracker_key=rpc.tracker_key,
+            session_priority=1,
+            session_timeout_sec=100,
+        )
+
+        def f_timer(rt_mod, dev, input_data):
+            rt_mod(input_data["a"], input_data["b"], input_data["c"])
+            return input_data["c"]
+
+        result = run_module_via_rpc(
+            rpc_config=rpc_config,
+            lib=rt_mod,
+            dev_type=target.kind.name,
+            args=input_data,
+            continuation=f_timer,
+        )
+        tvm.testing.assert_allclose(result.numpy(), c_np, rtol=1e-3)
+
+
 if __name__ == """__main__""":
     test_tune_matmul_cpu()
     test_tune_matmul_cuda()
+    test_tune_run_module_via_rpc()
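
The triple loop in the new test computes c = a · bᵀ (note the b_np[j, k] indexing, matching the matmul workload the test builds). For reference, an equivalent vectorized NumPy check would be:

    # Same reference result as the i/j/k loop above: C[i, j] = sum_k A[i, k] * B[j, k]
    c_expected = a_np @ b_np.T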
