Skip to content

Commit

Permalink
[MetaSchedule] Refactor testing workloads (#10497)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored Mar 6, 2022
1 parent 7cfaa88 commit 085d36c
Show file tree
Hide file tree
Showing 15 changed files with 683 additions and 439 deletions.
12 changes: 9 additions & 3 deletions python/tvm/auto_scheduler/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,9 +329,9 @@ def auto_schedule_topi(func_name, outs):
"""

# pylint: disable=import-outside-toplevel
from tvm.auto_scheduler.measure import (
from tvm.auto_scheduler.measure import ( # lazily import to avoid recursive dependency
prepare_input_map,
) # lazily import to avoid recursive dependency
)

io_tensors, has_layout_free, has_complex_op = traverse_to_get_io_tensors(outs)
if not io_tensors: # The compute includes dynamic shapes which are not supported yet.
Expand Down Expand Up @@ -482,4 +482,10 @@ def is_auto_scheduler_enabled():
enabled: bool
Whether the auto-scheduler is enabled
"""
return PassContext.current().config.get("relay.backend.use_auto_scheduler", False)
return PassContext.current().config.get(
"relay.backend.use_auto_scheduler",
False,
) or PassContext.current().config.get(
"relay.backend.use_meta_schedule",
False,
)
15 changes: 9 additions & 6 deletions python/tvm/meta_schedule/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,8 @@ def extract_task_from_relay(
params: Optional[Dict[str, NDArray]] = None,
*,
opt_level: int = 3,
pass_config: Dict[str, Any] = {
"relay.backend.use_meta_schedule": True,
},
disabled_pass: List[str] = [],
pass_config: Optional[Dict[str, Any]] = None,
disabled_pass: Optional[List[str]] = None,
) -> List[ExtractedTask]:
"""Extract tuning tasks from a relay program.
Expand All @@ -221,9 +219,9 @@ def extract_task_from_relay(
The associated parameters of the program
opt_level : int
The optimization level of the compiler
pass_config : Dict[str, Any]
pass_config : Optional[Dict[str, Any]]
The pass config of the compiler
disabled_pass : List[str]
disabled_pass : Optional[List[str]]
The list of disabled passes of the compiler
Returns
Expand All @@ -250,6 +248,11 @@ def _thread_run(func: Callable[[], None]) -> None:
thread.start()
thread.join()

if disabled_pass is None:
disabled_pass = []
if pass_config is None:
pass_config = {"relay.backend.use_meta_schedule": True}

env = TaskExtraction()
if isinstance(mod, RelayFunc):
mod = IRModule.from_expr(mod)
Expand Down
3 changes: 0 additions & 3 deletions python/tvm/meta_schedule/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,3 @@
# specific language governing permissions and limitations
# under the License.
"""Testing utilities in meta schedule"""
from .byoc_trt import relay_build_with_tensorrt
from .local_rpc import LocalRPC
from .relay_workload import MODEL_TYPE, MODEL_TYPES, get_network, get_torch_model
53 changes: 0 additions & 53 deletions python/tvm/meta_schedule/testing/byoc_trt.py

This file was deleted.

2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def conv2d_winograd_cpu(
eps_1, nu_1, p_1, ci_1, r_a, r_b = T.axis.remap(
"SSSSRR", [i0_4, i1_4, i2_3, i3_3, i4, i5]
)
T.block_attr({"schedule_rule": "meta_schedule.winograd_data_pack.cpu"})
T.block_attr({"schedule_rule": "meta_schedule.winograd_data_pack.llvm"})
T.reads(
[
data_pack[eps_1, nu_1, p_1, ci_1],
Expand Down
140 changes: 140 additions & 0 deletions python/tvm/meta_schedule/testing/custom_builder_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Customized builder and runner methods"""
# pylint: disable=import-outside-toplevel

from typing import TYPE_CHECKING, Dict, List

if TYPE_CHECKING:
from tvm.ir import IRModule
from tvm.meta_schedule.runner import EvaluatorConfig
from tvm.runtime import Device, Module, NDArray
from tvm.target import Target


def build_relay(
mod: "IRModule",
target: "Target",
params: Dict[str, "NDArray"],
) -> "Module":
"""Build a Relay IRModule
Parameters
----------
mod : IRModule
The Relay IRModule to build.
target : Target
The target to build the module for.
params : Dict[str, NDArray]
The parameter dict to build the module with.
Returns
-------
mod : runtime.Module
The built module.
"""
from tvm.relay.build_module import _build_module_no_factory as relay_build
from tvm.runtime import Module

result = relay_build(mod, target=target, target_host=None, params=params)
assert isinstance(result, Module)
return result


def build_relay_with_tensorrt(
mod: "IRModule",
target: "Target",
params: Dict[str, "NDArray"],
) -> "Module":
"""Build a Relay IRModule with TensorRT BYOC
Parameters
----------
mod : IRModule
The Relay IRModule to build.
target : Target
The target to build the module for.
params : Dict[str, NDArray]
The parameter dict to build the module with.
Returns
-------
mod : runtime.Module
The built module.
"""
from tvm.ir.transform import PassContext
from tvm.relay.build_module import _build_module_no_factory as relay_build
from tvm.relay.op.contrib import tensorrt
from tvm.runtime import Module

mod, config = tensorrt.partition_for_tensorrt(mod, params)
with PassContext(
opt_level=3,
config={"relay.ext.tensorrt.options": config},
):
result = relay_build(mod, target=target, target_host=None, params=params)
assert isinstance(result, Module)
return result


def run_with_graph_executor(
rt_mod: "Module",
device: "Device",
evaluator_config: "EvaluatorConfig",
repeated_args: List["NDArray"],
) -> List[float]:
"""Run a Relay module with GraphExecutor
Parameters
----------
rt_mod : Module
The Relay module to run.
device : Device
The device to run the module on.
evaluator_config : EvaluatorConfig
The evaluator configuration to run the module with.
repeated_args : List[NDArray]
The list of repeated arguments to run the module with.
Returns
-------
results : List[float]
The list of results.
"""
import itertools

from tvm.contrib.graph_executor import GraphModule

graph_mod = GraphModule(rt_mod["default"](device))
evaluator = graph_mod.module.time_evaluator(
func_name="run",
dev=device,
number=evaluator_config.number,
repeat=evaluator_config.repeat,
min_repeat_ms=evaluator_config.min_repeat_ms,
f_preproc="cache_flush_cpu_non_first_arg"
if evaluator_config.enable_cpu_cache_flush
else "",
)
repeated_costs = []
for args in repeated_args:
profile_result = evaluator(*args)
repeated_costs.append(profile_result.results)
costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)]
return costs
Loading

0 comments on commit 085d36c

Please sign in to comment.