diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index e9bb68ad8e93..7ff1840c9123 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -329,9 +329,9 @@ def auto_schedule_topi(func_name, outs): """ # pylint: disable=import-outside-toplevel - from tvm.auto_scheduler.measure import ( + from tvm.auto_scheduler.measure import ( # lazily import to avoid recursive dependency prepare_input_map, - ) # lazily import to avoid recursive dependency + ) io_tensors, has_layout_free, has_complex_op = traverse_to_get_io_tensors(outs) if not io_tensors: # The compute includes dynamic shapes which are not supported yet. @@ -482,4 +482,10 @@ def is_auto_scheduler_enabled(): enabled: bool Whether the auto-scheduler is enabled """ - return PassContext.current().config.get("relay.backend.use_auto_scheduler", False) + return PassContext.current().config.get( + "relay.backend.use_auto_scheduler", + False, + ) or PassContext.current().config.get( + "relay.backend.use_meta_schedule", + False, + ) diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py index 178239e7def1..26b01444e752 100644 --- a/python/tvm/meta_schedule/integration.py +++ b/python/tvm/meta_schedule/integration.py @@ -204,10 +204,8 @@ def extract_task_from_relay( params: Optional[Dict[str, NDArray]] = None, *, opt_level: int = 3, - pass_config: Dict[str, Any] = { - "relay.backend.use_meta_schedule": True, - }, - disabled_pass: List[str] = [], + pass_config: Optional[Dict[str, Any]] = None, + disabled_pass: Optional[List[str]] = None, ) -> List[ExtractedTask]: """Extract tuning tasks from a relay program. @@ -221,9 +219,9 @@ def extract_task_from_relay( The associated parameters of the program opt_level : int The optimization level of the compiler - pass_config : Dict[str, Any] + pass_config : Optional[Dict[str, Any]] The pass config of the compiler - disabled_pass : List[str] + disabled_pass : Optional[List[str]] The list of disabled passes of the compiler Returns @@ -250,6 +248,11 @@ def _thread_run(func: Callable[[], None]) -> None: thread.start() thread.join() + if disabled_pass is None: + disabled_pass = [] + if pass_config is None: + pass_config = {"relay.backend.use_meta_schedule": True} + env = TaskExtraction() if isinstance(mod, RelayFunc): mod = IRModule.from_expr(mod) diff --git a/python/tvm/meta_schedule/testing/__init__.py b/python/tvm/meta_schedule/testing/__init__.py index 85b48b35f621..5d6081fa81e4 100644 --- a/python/tvm/meta_schedule/testing/__init__.py +++ b/python/tvm/meta_schedule/testing/__init__.py @@ -15,6 +15,3 @@ # specific language governing permissions and limitations # under the License. """Testing utilities in meta schedule""" -from .byoc_trt import relay_build_with_tensorrt -from .local_rpc import LocalRPC -from .relay_workload import MODEL_TYPE, MODEL_TYPES, get_network, get_torch_model diff --git a/python/tvm/meta_schedule/testing/byoc_trt.py b/python/tvm/meta_schedule/testing/byoc_trt.py deleted file mode 100644 index d459518cdb23..000000000000 --- a/python/tvm/meta_schedule/testing/byoc_trt.py +++ /dev/null @@ -1,53 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""TensorRT-MetaSchedule integration""" -# pylint: disable=import-outside-toplevel - -from typing import List -import tvm -from tvm.runtime import Module -from tvm.meta_schedule.builder import BuilderResult -from tvm.target import Target - - -def relay_build_with_tensorrt( - mod: Module, - target: Target, - params: dict, -) -> List[BuilderResult]: - """Build a Relay IRModule with TensorRT BYOC - Parameters - ---------- - mod : IRModule - The Relay IRModule to build. - target : Target - The target to build the module for. - params : Dict[str, NDArray] - The parameter dict to build the module with. - Returns - ------- - mod : runtime.Module - The built module. - """ - from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt - - assert isinstance(target, Target) - mod, config = partition_for_tensorrt(mod, params) - with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}): - result = tvm.relay.build_module._build_module_no_factory(mod, "cuda", "llvm", params) - assert isinstance(result, Module) - return result diff --git a/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py b/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py index bfd5f4557ce8..261768c4897b 100644 --- a/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py +++ b/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py @@ -79,7 +79,7 @@ def conv2d_winograd_cpu( eps_1, nu_1, p_1, ci_1, r_a, r_b = T.axis.remap( "SSSSRR", [i0_4, i1_4, i2_3, i3_3, i4, i5] ) - T.block_attr({"schedule_rule": "meta_schedule.winograd_data_pack.cpu"}) + T.block_attr({"schedule_rule": "meta_schedule.winograd_data_pack.llvm"}) T.reads( [ data_pack[eps_1, nu_1, p_1, ci_1], diff --git a/python/tvm/meta_schedule/testing/custom_builder_runner.py b/python/tvm/meta_schedule/testing/custom_builder_runner.py new file mode 100644 index 000000000000..87bad5a61caa --- /dev/null +++ b/python/tvm/meta_schedule/testing/custom_builder_runner.py @@ -0,0 +1,140 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Customized builder and runner methods""" +# pylint: disable=import-outside-toplevel + +from typing import TYPE_CHECKING, Dict, List + +if TYPE_CHECKING: + from tvm.ir import IRModule + from tvm.meta_schedule.runner import EvaluatorConfig + from tvm.runtime import Device, Module, NDArray + from tvm.target import Target + + +def build_relay( + mod: "IRModule", + target: "Target", + params: Dict[str, "NDArray"], +) -> "Module": + """Build a Relay IRModule + + Parameters + ---------- + mod : IRModule + The Relay IRModule to build. + target : Target + The target to build the module for. + params : Dict[str, NDArray] + The parameter dict to build the module with. + + Returns + ------- + mod : runtime.Module + The built module. + """ + from tvm.relay.build_module import _build_module_no_factory as relay_build + from tvm.runtime import Module + + result = relay_build(mod, target=target, target_host=None, params=params) + assert isinstance(result, Module) + return result + + +def build_relay_with_tensorrt( + mod: "IRModule", + target: "Target", + params: Dict[str, "NDArray"], +) -> "Module": + """Build a Relay IRModule with TensorRT BYOC + + Parameters + ---------- + mod : IRModule + The Relay IRModule to build. + + target : Target + The target to build the module for. + + params : Dict[str, NDArray] + The parameter dict to build the module with. + + Returns + ------- + mod : runtime.Module + The built module. + """ + from tvm.ir.transform import PassContext + from tvm.relay.build_module import _build_module_no_factory as relay_build + from tvm.relay.op.contrib import tensorrt + from tvm.runtime import Module + + mod, config = tensorrt.partition_for_tensorrt(mod, params) + with PassContext( + opt_level=3, + config={"relay.ext.tensorrt.options": config}, + ): + result = relay_build(mod, target=target, target_host=None, params=params) + assert isinstance(result, Module) + return result + + +def run_with_graph_executor( + rt_mod: "Module", + device: "Device", + evaluator_config: "EvaluatorConfig", + repeated_args: List["NDArray"], +) -> List[float]: + """Run a Relay module with GraphExecutor + + Parameters + ---------- + rt_mod : Module + The Relay module to run. + device : Device + The device to run the module on. + evaluator_config : EvaluatorConfig + The evaluator configuration to run the module with. + repeated_args : List[NDArray] + The list of repeated arguments to run the module with. + + Returns + ------- + results : List[float] + The list of results. + """ + import itertools + + from tvm.contrib.graph_executor import GraphModule + + graph_mod = GraphModule(rt_mod["default"](device)) + evaluator = graph_mod.module.time_evaluator( + func_name="run", + dev=device, + number=evaluator_config.number, + repeat=evaluator_config.repeat, + min_repeat_ms=evaluator_config.min_repeat_ms, + f_preproc="cache_flush_cpu_non_first_arg" + if evaluator_config.enable_cpu_cache_flush + else "", + ) + repeated_costs = [] + for args in repeated_args: + profile_result = evaluator(*args) + repeated_costs.append(profile_result.results) + costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)] + return costs diff --git a/python/tvm/meta_schedule/testing/relay_workload.py b/python/tvm/meta_schedule/testing/relay_workload.py index 2f1ffdd407fa..29cc70ad3e05 100644 --- a/python/tvm/meta_schedule/testing/relay_workload.py +++ b/python/tvm/meta_schedule/testing/relay_workload.py @@ -15,154 +15,333 @@ # specific language governing permissions and limitations # under the License. 
"""Workloads in Relay IR""" -from enum import Enum -from typing import Dict, Tuple +# pylint: disable=import-outside-toplevel +import multiprocessing +import os +import pickle +from typing import Any, Dict, List, Optional, Tuple -import tvm.relay.testing # pylint: disable=unused-import +import tvm +import tvm.relay.testing from tvm import relay from tvm.ir import IRModule -from tvm.runtime import NDArray - -# Model types supported in Torchvision -class MODEL_TYPE(Enum): # pylint: disable=invalid-name - IMAGE_CLASSIFICATION = (1,) - VIDEO_CLASSIFICATION = (2,) - SEGMENTATION = (3,) - OBJECT_DETECTION = (4,) - TEXT_CLASSIFICATION = (5,) - - -# Specify the type of each model -MODEL_TYPES = { - "resnet18": MODEL_TYPE.IMAGE_CLASSIFICATION, - "mobilenet_v2": MODEL_TYPE.IMAGE_CLASSIFICATION, - "bert_base": MODEL_TYPE.TEXT_CLASSIFICATION, -} - - -def get_torch_model( - model_name: str, - input_shape: Tuple[int, ...], - output_shape: Tuple[int, int], # pylint: disable=unused-argument - dtype: str = "float32", -) -> Tuple[IRModule, Dict[str, NDArray]]: - """Load model from torch model zoo - Parameters - ---------- - model_name : str - The name of the model to load - input_shape: Tuple[int, ...] - Tuple for input shape - output_shape: Tuple[int, int] - Tuple for output shape - dtype: str - Tensor data type - """ +from tvm.meta_schedule.integration import ExtractedTask, extract_task_from_relay +from tvm.runtime import NDArray, load_param_dict, save_param_dict +from tvm.target import Target - assert dtype == "float32" - import torch # type: ignore # pylint: disable=import-error,import-outside-toplevel - from torchvision import models # type: ignore # pylint: disable=import-error,import-outside-toplevel - import transformers # type: ignore # pylint: disable=import-error,import-outside-toplevel - import os # type: ignore # pylint: disable=import-error,import-outside-toplevel +def _get_network( + args: Tuple[str, List[int]] +) -> Tuple[IRModule, bytearray, Tuple[str, List[int], str]]: + name: str + input_shape: List[int] + name, input_shape = args - def do_trace(model, inp): - model.eval() - model_trace = torch.jit.trace(model, inp) - model_trace.eval() - return model_trace + mod: IRModule + + if name in [ + "resnet_18", + "resnet_50", + "wide_resnet_50", + "resnext_50", + "mobilenet_v2", + "mobilenet_v3", + "inception_v3", + "densenet_121", + "resnet3d_18", + "vgg_16", + ]: + import torch # type: ignore + from torchvision import models # type: ignore + + if name in ["resnet_18", "resnet_50"]: + model = getattr(models, name.replace("_", ""))(pretrained=False) + elif name == "wide_resnet_50": + model = getattr(models, "wide_resnet50_2")(pretrained=False) + elif name == "resnext_50": + model = getattr(models, "resnext50_32x4d")(pretrained=False) + elif name == "mobilenet_v2": + model = getattr(models, name)(pretrained=False) + elif name == "mobilenet_v3": + model = getattr(models, name + "_large")(pretrained=False) + elif name == "inception_v3": + model = getattr(models, name)(pretrained=False, aux_logits=False) + elif name == "densenet_121": + model = getattr(models, name.replace("_", ""))(pretrained=False) + elif name == "resnet3d_18": + model = models.video.r3d_18(pretrained=False) + elif name == "vgg_16": + model = getattr(models, name.replace("_", ""))(pretrained=False) - # Load model from torchvision - if MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION: + dtype = "float32" + input_data = torch.randn(input_shape).type( # pylint: disable=no-member + { + "float32": torch.float32, # pylint: 
disable=no-member + }[dtype] + ) + scripted_model = torch.jit.trace(model, input_data).eval() + input_name = "input0" + shape_list = [(input_name, input_shape)] + mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) + with tvm.transform.PassContext(opt_level=3): + mod = tvm.transform.Sequential( + [ + relay.transform.RemoveUnusedFunctions(), + relay.transform.ConvertLayout( + { + "nn.conv2d": ["NHWC", "default"], + "nn.conv3d": ["NDHWC", "default"], + "nn.max_pool2d": ["NHWC", "default"], + "nn.avg_pool2d": ["NHWC", "default"], + } + ), + ] + )(mod) + inputs = (input_name, input_shape, dtype) + elif name in ["bert_tiny", "bert_base", "bert_medium", "bert_large"]: os.environ["TOKENIZERS_PARALLELISM"] = "false" - model = transformers.BertModel( - transformers.BertConfig( + # pip3 install transformers==3.5 torch==1.7 + import torch # type: ignore + import transformers # type: ignore + + config_dict = { + "bert_tiny": transformers.BertConfig( + num_hidden_layers=6, + hidden_size=512, + intermediate_size=2048, + num_attention_heads=8, + return_dict=False, + ), + "bert_base": transformers.BertConfig( num_hidden_layers=12, hidden_size=768, intermediate_size=3072, num_attention_heads=12, return_dict=False, - ) - ) + ), + "bert_medium": transformers.BertConfig( + num_hidden_layers=12, + hidden_size=1024, + intermediate_size=4096, + num_attention_heads=16, + return_dict=False, + ), + "bert_large": transformers.BertConfig( + num_hidden_layers=24, + hidden_size=1024, + intermediate_size=4096, + num_attention_heads=16, + return_dict=False, + ), + } + configuration = config_dict[name] + model = transformers.BertModel(configuration) + input_name = "input_ids" + input_dtype = "int64" + a = torch.randint(10000, input_shape) # pylint: disable=no-member model.eval() - input_data = torch.randint(10000, input_shape) - shape_list = [("input_ids", input_shape)] - scripted_model = torch.jit.trace(model, [input_data], strict=False) - elif MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION: - model = getattr(models, model_name)() - # Setup input - input_data = torch.randn(input_shape).type(torch.float32) - shape_list = [("input0", input_shape)] - # Get trace. Depending on the model type, wrapper may be necessary. 
- scripted_model = do_trace(model, input_data) + scripted_model = torch.jit.trace(model, [a], strict=False) + input_name = "input_ids" + shape_list = [(input_name, input_shape)] + mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) + mod = relay.transform.FastMath()(mod) + mod = relay.transform.CombineParallelBatchMatmul()(mod) + inputs = (input_name, input_shape, input_dtype) + elif name == "dcgan": + output_shape = input_shape + batch_size = output_shape[0] + oshape = output_shape[1:] + mod, params = relay.testing.dcgan.get_workload( + batch_size=batch_size, + oshape=oshape, + layout="NHWC", + ) + inputs = ("data", [100], "float32") else: - raise ValueError("Unsupported model in Torch model zoo.") + raise ValueError("Invalid name: " + name) + + params_bytearray: bytearray = save_param_dict(params) + return mod, params_bytearray, inputs + + +def _load_cache(cache_dir: Optional[str], filename: str) -> Optional[List[Any]]: + if cache_dir is None: + return None + path = os.path.join(os.path.expanduser(cache_dir), filename) + if not os.path.exists(path): + return None + print(f"Load from cache: {path}") + with open(path, "rb") as i_f: + return pickle.load(i_f) - # Convert torch model to relay module - mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) - return mod, params + +def _save_cache(cache_dir: Optional[str], filename: str, objects: List[Any]) -> None: + if cache_dir is None: + return + path = os.path.join(os.path.expanduser(cache_dir), filename) + with open(path, "wb") as o_f: + pickle.dump(objects, o_f) def get_network( name: str, - batch_size: int, - layout: str = "NHWC", - dtype: str = "float32", -) -> Tuple[IRModule, Dict[str, NDArray], Tuple[int, int, int, int], Tuple[int, int]]: - """Get the symbol definition and random weight of a network""" - # meta-schedule prefers NHWC layout - if layout == "NHWC": - image_shape = (224, 224, 3) - elif layout == "NCHW": - image_shape = (3, 224, 224) - else: - raise ValueError("Invalid layout: " + layout) + input_shape: List[int], + *, + cache_dir: Optional[str] = None, +) -> Tuple[IRModule, Dict[str, NDArray], Tuple[str, List[int], str]]: + """Get the symbol definition and random weight of a network - input_shape: Tuple[int, int, int, int] = (batch_size,) + image_shape - output_shape: Tuple[int, int] = (batch_size, 1000) + Parameters + ---------- + name : str + The name of the network. + input_shape : List[int] + The shape of the input tensor. + cache_dir : Optional[str], optional + The directory to cache the generated network. + If not specified, the cache will be disabled. 
- if name.startswith("resnet-"): - n_layer = int(name.split("-")[1]) - mod, params = relay.testing.resnet.get_workload( - num_layers=n_layer, - batch_size=batch_size, - layout=layout, - dtype=dtype, - image_shape=image_shape, - ) - elif name.startswith("resnet3d-"): - n_layer = int(name.split("-")[1]) - mod, params = relay.testing.resnet.get_workload( - num_layers=n_layer, - batch_size=batch_size, - layout=layout, - dtype=dtype, - image_shape=image_shape, - ) - elif name == "mobilenet": - mod, params = relay.testing.mobilenet.get_workload( - batch_size=batch_size, layout=layout, dtype=dtype, image_shape=image_shape - ) - elif name == "squeezenet_v1.1": - assert layout == "NCHW", "squeezenet_v1.1 only supports NCHW layout" - mod, params = relay.testing.squeezenet.get_workload( - version="1.1", - batch_size=batch_size, - dtype=dtype, - image_shape=image_shape, - ) - elif name == "inception_v3": - input_shape = (batch_size, 3, 299, 299) if layout == "NCHW" else (batch_size, 299, 299, 3) - mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype) - elif name == "mxnet": - from mxnet.gluon.model_zoo.vision import get_model # type: ignore # pylint: disable=import-outside-toplevel - - assert layout == "NCHW" - block = get_model("resnet50_v1", pretrained=True) - mod, params = relay.frontend.from_mxnet(block, shape={"data": input_shape}, dtype=dtype) - net = mod["main"] - net = relay.Function( - net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs + Returns + ------- + mod : IRModule + The IRModule representing the network. + params : Dict[str, NDArray] + The parameters of the networks. + inputs : Tuple[str, List[int], str] + The name, shape and dtype of the input tensor. + """ + + mod: IRModule + params: Dict[str, NDArray] + inputs: Tuple[str, List[int], str] + params_bytearray: bytearray + + filename = f'relay-{name}-{",".join(str(i) for i in input_shape)}.json' + cached = _load_cache(cache_dir, filename) + if cached is None: + with multiprocessing.Pool(processes=1) as pool: + result = pool.map(_get_network, [(name, input_shape)]) + ((mod, params_bytearray, inputs),) = result + cached = [mod, params_bytearray, inputs] + _save_cache(cache_dir, filename, cached) + mod, params_bytearray, inputs = cached + params = load_param_dict(params_bytearray) + return mod, params, inputs + + +def extract_from_relay( + mod: IRModule, + target: Target, + params: Optional[Dict[str, NDArray]], + name: str, + input_shape: List[int], + *, + cache_dir: Optional[str] = None, + opt_level: int = 3, + pass_config: Optional[Dict[str, Any]] = None, + disabled_pass: Optional[List[str]] = None, +) -> List[ExtractedTask]: + """Extract the tasks from a network. + + Parameters + ---------- + mod : IRModule + The IRModule representing the network. + target : Target + The target that the network will be deployed to. + params : Optional[Dict[str, NDArray]] + The parameters of the networks. + name : str + The name of the network. + input_shape : List[int] + The shape of the input tensor. + cache_dir : Optional[str] + The directory to cache the generated network. + If not specified, the cache will be disabled. + opt_level : int + The optimization level of the compiler. + pass_config : Optional[Dict[str, Any]] + The pass config of the compiler. + disabled_pass : Optional[List[str]] + The disabled pass of the compiler. + + Returns + ------- + extracted_tasks : List[ExtractedTask] + The extracted tasks. 
+ """ + filename = f'tasks-{target.kind.name}-{name}-{",".join(str(i) for i in input_shape)}.json' + extracted_tasks = _load_cache(cache_dir, filename) + if extracted_tasks is None: + extracted_tasks = extract_task_from_relay( + mod=mod, + target=target, + params=params, + opt_level=opt_level, + pass_config=pass_config, + disabled_pass=disabled_pass, ) - mod = IRModule.from_expr(net) - return mod, params, input_shape, output_shape + extracted_tasks = list(extracted_tasks) + _save_cache(cache_dir, filename, extracted_tasks) + return extracted_tasks + + +def _build_dataset() -> List[Tuple[str, List[int]]]: + network_keys = [] + for name in [ + "resnet_18", + "resnet_50", + "mobilenet_v2", + "mobilenet_v3", + "wide_resnet_50", + "resnext_50", + "densenet_121", + "vgg_16", + ]: + for batch_size in [1, 4, 8]: + for image_size in [224, 240, 256]: + network_keys.append((name, [batch_size, 3, image_size, image_size])) + # inception-v3 + for name in ["inception_v3"]: + for batch_size in [1, 2, 4]: + for image_size in [299]: + network_keys.append((name, [batch_size, 3, image_size, image_size])) + # resnet3d + for name in ["resnet3d_18"]: + for batch_size in [1, 2, 4]: + for image_size in [112, 128, 144]: + network_keys.append((name, [batch_size, 3, image_size, image_size, 16])) + # bert + for name in ["bert_tiny", "bert_base", "bert_medium", "bert_large"]: + for batch_size in [1, 2, 4]: + for seq_length in [64, 128, 256]: + network_keys.append((name, [batch_size, seq_length])) + # dcgan + for name in ["dcgan"]: + for batch_size in [1, 4, 8]: + for image_size in [64]: + network_keys.append((name, [batch_size, 3, image_size, image_size])) + + return network_keys + + +SUPPORTED = [ + # TorchVision + "resnet_18", + "resnet_50", + "mobilenet_v2", + "mobilenet_v3", + "wide_resnet_50", + "resnext_50", + "resnet3d_18", + "inception_v3", + "densenet_121", + "vgg_16", + # Transformer + "bert_tiny", + "bert_base", + "bert_medium", + "bert_large", + # Relay testing + "dcgan", +] diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 7872091f1a5d..97f7adce63ed 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -19,28 +19,31 @@ from a Relay expression. """ import warnings -import numpy as np +import numpy as np from tvm.ir import IRModule - from tvm.ir.transform import PassContext -from tvm.tir import expr as tvm_expr from tvm.target import Target -from .. import nd as _nd, autotvm, register_func +from tvm.tir import expr as tvm_expr + +from .. import autotvm +from .. import nd as _nd +from .. import register_func +from ..contrib import graph_executor as _graph_executor +from ..contrib import utils as contrib_utils from ..runtime import load_module from ..runtime.executor import aot_executor as _aot_executor from ..target import Target -from ..contrib import graph_executor as _graph_executor -from ..contrib import utils as contrib_utils from . import _build_module -from . import ty as _ty from . import expr as _expr from . import function as _function -from .transform import InferType -from .backend.utils import mangle_module_name -from .backend import executor_factory as _executor_factory, Executor, Runtime +from . 
import ty as _ty +from .backend import Executor, Runtime +from .backend import executor_factory as _executor_factory from .backend import interpreter as _interpreter +from .backend.utils import mangle_module_name from .backend.vm import VMExecutor +from .transform import InferType def build_target_by_device_type_map(target): @@ -287,13 +290,17 @@ def _module_export(module, file_name): # fcompile, addons, kwargs? @register_func("tvm.relay.build") +def _build_module_no_factory_impl(mod, target, target_host, params, mod_name): + target, target_host = Target.check_and_update_host_consist(target, target_host) + return build(mod, target, params=params, mod_name=mod_name).module + + def _build_module_no_factory(mod, target=None, target_host=None, params=None, mod_name="default"): """A wrapper around build which discards the Python GraphFactoryRuntime. This wrapper is suitable to be used from other programming languages as the runtime::Module can be freely passed between language boundaries. """ - target, target_host = Target.check_and_update_host_consist(target, target_host) - return build(mod, target, params=params, mod_name=mod_name).module + return _build_module_no_factory_impl(mod, target, target_host, params, mod_name) def _reconstruct_from_deprecated_options(deprecated_params_target): diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index b1230c0398c2..dbff3f601655 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -1052,6 +1052,12 @@ def _conv2d_winograd_nhwc_impl( ) # transform data + target = tvm.target.Target.current(allow_none=True) + if target is not None: + target_kind = "meta_schedule.winograd_data_pack." + target.kind.name + else: + target_kind = "None" + r_a = te.reduce_axis((0, alpha), "r_a") r_b = te.reduce_axis((0, alpha), "r_b") data_pack = te.compute( @@ -1062,7 +1068,7 @@ def _conv2d_winograd_nhwc_impl( name="data_pack", attrs={ "auto_scheduler_simplify_const_tensor_indices": ["eps", "nu", "r_a", "r_b"], - "schedule_rule": "meta_schedule.winograd_data_pack.cpu", + "schedule_rule": target_kind, }, # the attrs are necessary hints for the auto-scheduler ) diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc index 1ecb537d15a9..1ebec19888cd 100644 --- a/src/meta_schedule/integration.cc +++ b/src/meta_schedule/integration.cc @@ -28,17 +28,24 @@ namespace meta_schedule { /**************** Utility functions ****************/ template -bool HasOnlyOneFunction(const IRModule& mod) { +Optional GetOnlyOneFunction(const IRModule& mod) { if (mod->functions.size() != 1) { - return false; + return NullOpt; } for (const auto& kv : mod->functions) { const BaseFunc& func = kv.second; if (!func->IsInstance()) { - return false; + return NullOpt; + } else { + return Downcast(func); } } - return true; + return NullOpt; +} + +template +bool HasOnlyOneFunction(const IRModule& mod) { + return GetOnlyOneFunction(mod).defined(); } /**************** ExtractedTask ****************/ @@ -129,14 +136,17 @@ Optional ApplyHistoryBestNode::Query(runtime::String task_name, IRMod if (database->HasWorkload(prim_mod)) { Array records = database->GetTopK(database->CommitWorkload(prim_mod), 1); if (records.size() == 1) { - LOG(INFO) << "Applied history best for " << task_name << "."; + LOG(INFO) << "Applied history best for: " << task_name; tir::Schedule sch = tir::Schedule::Traced(records[0]->workload->mod, /*seed=*/-1, /*debug_mask=*/0, /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); 
records[0]->trace->ApplyToSchedule(sch, false); - return sch->mod(); + tir::PrimFunc func = GetOnlyOneFunction(sch->mod()).value(); + LOG(INFO) << "\n" << tir::AsTVMScript(func); + return func; } } + LOG(WARNING) << "Cannot find workload: " << task_name << "\n" << tir::AsTVMScript(prim_mod); return NullOpt; } diff --git a/src/meta_schedule/schedule_rule/winograd.cc b/src/meta_schedule/schedule_rule/winograd.cc index 44db6f2f404c..d8aab3a3f757 100644 --- a/src/meta_schedule/schedule_rule/winograd.cc +++ b/src/meta_schedule/schedule_rule/winograd.cc @@ -69,7 +69,7 @@ TVM_REGISTER_GLOBAL("meta_schedule.winograd_inverse") return {sch}; }); -TVM_REGISTER_GLOBAL("meta_schedule.winograd_data_pack.cpu") +TVM_REGISTER_GLOBAL("meta_schedule.winograd_data_pack.llvm") .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array { BlockRV input_tile = GetOnlyProducer(sch, data_pack); BlockRV data_pad = GetOnlyProducer(sch, input_tile); diff --git a/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py b/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py index 3b4164c40644..91e2c41b2b3c 100644 --- a/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py +++ b/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py @@ -15,37 +15,37 @@ # specific language governing permissions and limitations # under the License. """ Test Meta Schedule Builder """ +# pylint: disable=missing-docstring + import sys +from typing import List + import pytest -import itertools import tvm from tvm import relay +from tvm.meta_schedule.arg_info import TensorInfo +from tvm.meta_schedule.builder import BuilderInput, LocalBuilder +from tvm.meta_schedule.runner import EvaluatorConfig, LocalRunner, RunnerInput +from tvm.meta_schedule.testing.custom_builder_runner import ( + build_relay, + build_relay_with_tensorrt, + run_with_graph_executor, +) +from tvm.meta_schedule.testing.relay_workload import get_network from tvm.relay import testing from tvm.relay.op.contrib import tensorrt -import numpy as np -from typing import List -from tvm._ffi import register_func from tvm.target import Target -from tvm.runtime import Module -from tvm.meta_schedule.arg_info import TensorInfo -from tvm.meta_schedule.builder import BuilderInput, LocalBuilder, BuilderResult -from tvm.meta_schedule.runner import ( - EvaluatorConfig, - LocalRunner, - RunnerInput, -) - from tvm.tir import FloatImm -from tvm.meta_schedule.testing import get_network has_tensorrt_codegen = pytest.mark.skipif( - not tvm.get_global_func("relay.ext.tensorrt", True), reason="TensorRT codegen not available" + not tvm.get_global_func("relay.ext.tensorrt", True), + reason="TensorRT codegen not available", ) has_tensorrt_runtime = pytest.mark.skipif( - not tensorrt.is_tensorrt_runtime_enabled(), reason="TensorRT runtime not available" + not tensorrt.is_tensorrt_runtime_enabled(), + reason="TensorRT runtime not available", ) - # conv2d+relu network def get_conv2d_relu( data_shape, @@ -83,105 +83,52 @@ def get_conv2d_relu( def verify_meta_schedule_with_tensorrt( - mod, params, data_shape, use_meta_sched: bool = True, use_trt: bool = True, mode: str = "vm" + mod, + params, + data_shape, + use_trt: bool = True, ): - if use_meta_sched: - # With meta_schedule - dev = "cuda" - - # Build - if use_trt: - from tvm.meta_schedule.testing import relay_build_with_tensorrt - - builder = LocalBuilder(f_build=relay_build_with_tensorrt) - else: - - def relay_build_without_tensorrt( - mod: Module, - target: Target, - params: dict, - ) -> List[BuilderResult]: - return 
tvm.relay.build_module._build_module_no_factory(mod, "cuda", "llvm", params) - - builder = LocalBuilder(f_build=relay_build_without_tensorrt) - - builder_input = BuilderInput(mod, Target(dev, host="llvm"), params) - - (builder_result,) = builder.build([builder_input]) - assert builder_result.error_msg is None - assert builder_result.artifact_path is not None - - # Run - evaluator_config = EvaluatorConfig( + # Build + builder = LocalBuilder( + f_build=build_relay_with_tensorrt if use_trt else build_relay, + timeout_sec=1000, + ) + builder_input = BuilderInput(mod, Target("cuda"), params) + builder_result = builder.build([builder_input])[0] + assert builder_result.error_msg is None, builder_result.error_msg + assert builder_result.artifact_path is not None + + # Run + runner_input = RunnerInput( + builder_result.artifact_path, + device_type="cuda", + args_info=[TensorInfo("float32", data_shape)], + ) + runner = LocalRunner( + evaluator_config=EvaluatorConfig( number=5, repeat=2, min_repeat_ms=0, enable_cpu_cache_flush=False, - ) - - runner_input = RunnerInput( - builder_result.artifact_path, "cuda", [TensorInfo("float32", data_shape)] - ) - - def eval_func(rt_mod, device, evaluator_config, repeated_args): - rt_mod = tvm.contrib.graph_executor.GraphModule(rt_mod["default"](device)) - - eval = rt_mod.module.time_evaluator( - func_name="run", - dev=device, - number=evaluator_config.number, - repeat=evaluator_config.repeat, - min_repeat_ms=evaluator_config.min_repeat_ms, - f_preproc="cache_flush_cpu_non_first_arg" - if evaluator_config.enable_cpu_cache_flush - else "", - ) - repeated_costs: List[List[float]] = [] - for args in repeated_args: - profile_result = eval(*args) - repeated_costs.append(profile_result.results) - - costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)] - return costs - - runner = LocalRunner( - evaluator_config=evaluator_config, - f_run_evaluator=eval_func, - ) - - # Run the module - (runner_future,) = runner.run([runner_input]) - runner_result = runner_future.result() - assert runner_result is not None - assert runner_result.run_secs is not None - assert runner_result.error_msg is None - - for result in runner_result.run_secs: - if isinstance(result, FloatImm): - result = result.value - assert isinstance(result, float) - assert result >= 0.0 - - else: - # Without meta_schedule - if use_trt: - mod, config = tensorrt.partition_for_tensorrt(mod) - with tvm.transform.PassContext( - opt_level=3, config={"relay.ext.tensorrt.options": config} - ): - func = relay.create_executor( - mode, mod=mod, device=tvm.cuda(0), target="cuda" - ).evaluate() - else: - with tvm.transform.PassContext(opt_level=3): - func = relay.create_executor( - mode, mod=mod, device=tvm.cuda(0), target="cuda", params=params - ).evaluate() - - -@tvm.testing.requires_cuda + ), + f_run_evaluator=run_with_graph_executor, + ) + + # Run the module + runner_future = runner.run([runner_input])[0] + runner_result = runner_future.result() + assert runner_result is not None + assert runner_result.error_msg is None, runner_result.error_msg + assert runner_result.run_secs is not None + + for result in runner_result.run_secs: + if isinstance(result, FloatImm): + result = result.value + assert isinstance(result, float) + assert result >= 0.0 + + @has_tensorrt_codegen -@has_tensorrt_runtime def test_conv2d_relu(): data_shape = (1, 1280, 14, 14) out_channels = 256 @@ -206,21 +153,17 @@ def test_conv2d_relu(): verify_meta_schedule_with_tensorrt(mod, params, data_shape) -@tvm.testing.requires_cuda 
@has_tensorrt_codegen -@has_tensorrt_runtime -@pytest.mark.parametrize( - "model_name", - ["resnet-50", "mobilenet"], -) -@pytest.mark.parametrize("batch_size", [1]) -@pytest.mark.parametrize("use_meta_sched", [True]) +@pytest.mark.parametrize("model_name", ["resnet_50"]) +@pytest.mark.parametrize("input_shape", [[1, 3, 224, 224]]) @pytest.mark.parametrize("use_trt", [True, False]) -def test_relay_model(model_name: str, batch_size: int, use_meta_sched: bool, use_trt: bool): - - mod, params, input_shape, output_shape = get_network(name=model_name, batch_size=batch_size) +def test_relay_model(model_name: str, input_shape: List[int], use_trt: bool): + mod, params, _ = get_network(model_name, input_shape) verify_meta_schedule_with_tensorrt( - mod, params, input_shape, use_meta_sched=use_meta_sched, use_trt=use_trt, mode="vm" + mod, + params, + input_shape, + use_trt, ) diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py index 3676e3ad94ae..50dc9289780d 100644 --- a/tests/python/unittest/test_meta_schedule_integration.py +++ b/tests/python/unittest/test_meta_schedule_integration.py @@ -18,30 +18,29 @@ from typing import List import pytest - import tvm from tvm import meta_schedule as ms from tvm.ir.module import IRModule -from tvm.meta_schedule.utils import derived_object -from tvm.tir import Schedule -from tvm.target import Target -from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord +from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload from tvm.meta_schedule.integration import ( + ApplyHistoryBest, ExtractedTask, MetaScheduleContext, TaskExtraction, - ApplyHistoryBest, ) -from tvm.meta_schedule.testing import get_network +from tvm.meta_schedule.testing.relay_workload import get_network +from tvm.meta_schedule.utils import derived_object from tvm.script import tir as T +from tvm.target import Target +from tvm.tir import Schedule -# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring,unbalanced-tuple-unpacking +# pylint: disable=no-member,line-too-long,too-many-nested-blocks,unbalanced-tuple-unpacking,no-self-argument,missing-docstring,invalid-name @tvm.script.ir_module class MockModule: @T.prim_func - def main(a: T.handle, b: T.handle) -> None: # pylint: disable=no-self-argument + def main(a: T.handle, b: T.handle) -> None: # type: ignore T.func_attr({"global_symbol": "main", "tir.noalias": True}) A = T.match_buffer(a, (16,), "float32") B = T.match_buffer(b, (16,), "float32") @@ -51,7 +50,17 @@ def main(a: T.handle, b: T.handle) -> None: # pylint: disable=no-self-argument B[vi] = A[vi] -# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring,unbalanced-tuple-unpacking +# pylint: enable=no-member,line-too-long,too-many-nested-blocks,unbalanced-tuple-unpacking,no-self-argument + + +def _has_torch(): + import importlib.util # pylint: disable=unused-import,import-outside-toplevel + + spec = importlib.util.find_spec("torch") + return spec is not None + + +requires_torch = pytest.mark.skipif(not _has_torch(), reason="torch is not installed") def _check_mock_task(tasks: List[ExtractedTask], mod: IRModule): @@ -63,13 +72,9 @@ def _check_mock_task(tasks: List[ExtractedTask], mod: IRModule): tvm.ir.assert_structural_equal(tir_mod, MockModule) +@requires_torch def test_meta_schedule_integration_task_extraction_query(): - mod, _, _, _ = get_network( - name="resnet-18", - batch_size=1, - layout="NHWC", - 
dtype="float32", - ) + mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224]) env = TaskExtraction() env.query(task_name="mock-task", mod=mod, target=Target("llvm"), dispatched=[MockModule]) _check_mock_task(env.tasks, mod) @@ -93,13 +98,9 @@ def test_meta_schedule_integration_multiple_current(): ... +@requires_torch def test_meta_schedule_integration_query_inside_with_scope(): - mod, _, _, _ = get_network( - name="resnet-18", - batch_size=1, - layout="NHWC", - dtype="float32", - ) + mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224]) env = TaskExtraction() with env: MetaScheduleContext.query_inside_with_scope( @@ -111,17 +112,43 @@ def test_meta_schedule_integration_query_inside_with_scope(): _check_mock_task(env.tasks, mod) +@requires_torch def test_meta_schedule_integration_extract_from_resnet(): - mod, params, _, _ = get_network( - name="resnet-18", - batch_size=1, - layout="NHWC", - dtype="float32", - ) + mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224]) extracted_tasks = ms.integration.extract_task_from_relay(mod, target="llvm", params=params) - assert len(extracted_tasks) == 30 - - + expected_task_names = [ + "vm_mod_fused_" + s + for s in [ + "nn_max_pool2d", + "nn_adaptive_avg_pool2d", + "nn_dense_add", + "nn_conv2d_add", + "nn_conv2d_add_1", + "nn_conv2d_add_2", + "nn_conv2d_add_add_nn_relu", + "nn_conv2d_add_add_nn_relu_1", + "nn_conv2d_add_nn_relu", + "nn_conv2d_add_nn_relu_1", + "nn_conv2d_add_nn_relu_2", + "nn_conv2d_add_nn_relu_3", + "nn_conv2d_add_nn_relu_4", + "nn_conv2d_add_nn_relu_5", + "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu", + "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu_1", + "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu", + "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu_1", + # The two tasks below are purely spatial and are ruled out by AutoScheduler + "layout_transform", + "layout_transform_reshape_squeeze", + ] + ] + + assert len(extracted_tasks) == 20 + for t in extracted_tasks: + assert t.task_name in expected_task_names, t.task_name + + +@requires_torch def test_meta_schedule_integration_apply_history_best(): @derived_object class DummyDatabase(PyDatabase): @@ -161,12 +188,7 @@ def __len__(self) -> int: def print_results(self) -> None: print("\n".join([str(r) for r in self.records])) - mod, _, _, _ = get_network( - name="resnet-18", - batch_size=1, - layout="NHWC", - dtype="float32", - ) + mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224]) database = DummyDatabase() env = ApplyHistoryBest(database) target = Target("llvm") @@ -175,6 +197,7 @@ def print_results(self) -> None: TuningRecord(Schedule(MockModule).trace, [1.0], workload, target, []) ) mod = env.query(task_name="mock-task", mod=mod, target=target, dispatched=[MockModule]) + mod = IRModule({"main": mod}) assert tvm.ir.structural_equal(mod, workload.mod) diff --git a/tests/python/unittest/test_meta_schedule_runner.py b/tests/python/unittest/test_meta_schedule_runner.py index c2fb92ccd880..09e708f32f42 100644 --- a/tests/python/unittest/test_meta_schedule_runner.py +++ b/tests/python/unittest/test_meta_schedule_runner.py @@ -23,8 +23,8 @@ import numpy as np import pytest - import tvm +import tvm.testing from tvm._ffi import register_func from tvm.meta_schedule.arg_info import TensorInfo from tvm.meta_schedule.builder import BuilderInput, LocalBuilder @@ -37,21 +37,25 @@ RunnerFuture, RunnerInput, ) +from 
tvm.meta_schedule.runner.local_runner import ( + default_alloc_argument as local_default_alloc_argument, +) from tvm.meta_schedule.runner.rpc_runner import ( - default_alloc_argument as rpc_default_alloc_argument, T_ARG_INFO_JSON_OBJ_LIST, T_ARGUMENT_LIST, ) -from tvm.meta_schedule.runner.local_runner import ( - default_alloc_argument as local_default_alloc_argument, +from tvm.meta_schedule.runner.rpc_runner import ( + default_alloc_argument as rpc_default_alloc_argument, +) +from tvm.meta_schedule.testing.local_rpc import LocalRPC +from tvm.meta_schedule.utils import ( + derived_object, + get_global_func_with_default_on_worker, ) -from tvm.meta_schedule.testing import LocalRPC -from tvm.meta_schedule.utils import derived_object, get_global_func_with_default_on_worker from tvm.rpc import RPCSession from tvm.runtime import Device, Module from tvm.script import tir as T from tvm.target import Target -import tvm.testing from tvm.tir import FloatImm MATMUL_N = 16 @@ -886,4 +890,4 @@ def test_run_evaluator( if __name__ == "__main__": - test_meta_schedule_local_single_run() + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py index 144311058bd5..6bf59d269205 100644 --- a/tests/python/unittest/test_meta_schedule_tune_relay.py +++ b/tests/python/unittest/test_meta_schedule_tune_relay.py @@ -17,27 +17,20 @@ # pylint: disable=missing-docstring import logging import tempfile -import pytest -import numpy as np -from typing import Tuple, List - -from tvm.meta_schedule.utils import derived_object - -try: - import torch -except ModuleNotFoundError: - pass +from typing import List +import numpy as np +import pytest import tvm from tvm import relay -from tvm.ir import IRModule -from tvm.runtime.ndarray import cpu, cuda -from tvm.target.target import Target from tvm.contrib import graph_executor +from tvm.ir import IRModule from tvm.meta_schedule import ReplayTraceConfig -from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord -from tvm.meta_schedule.testing import MODEL_TYPE, MODEL_TYPES, get_torch_model +from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload +from tvm.meta_schedule.testing.relay_workload import get_network from tvm.meta_schedule.tune import tune_relay +from tvm.meta_schedule.utils import derived_object +from tvm.target.target import Target logging.basicConfig() logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) @@ -83,47 +76,33 @@ def print_results(self) -> None: @pytest.mark.skip("Integration test") -@pytest.mark.parametrize("model_name", ["resnet18", "mobilenet_v2", "bert_base"]) -@pytest.mark.parametrize("batch_size", [1]) -@pytest.mark.parametrize("target", ["llvm --num-cores=16", "nvidia/geforce-rtx-3070"]) -def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str): - if model_name == "inception_v3" and batch_size == 1: - pytest.skip("inception_v3 does not handle batch_size of 1") - - input_shape: Tuple[int, ...] 
- input_name = "input0" - dev = tvm.cpu() if str(target).startswith("llvm") else cuda() - if MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION: - seq_length = 128 - input_name = "input_ids" - input_shape = (batch_size, seq_length) +@pytest.mark.parametrize( + "model_name, input_shape, target", + [ + ("resnet_18", [1, 3, 224, 224], "llvm --num-cores=16"), + ("resnet_18", [1, 3, 224, 224], "nvidia/geforce-rtx-3070"), + ("mobilenet_v2", [1, 3, 224, 224], "llvm --num-cores=16"), + ("mobilenet_v2", [1, 3, 224, 224], "nvidia/geforce-rtx-3070"), + ("bert_base", [1, 64], "llvm --num-cores=16"), + ("bert_base", [1, 64], "nvidia/geforce-rtx-3070"), + ], +) +def test_meta_schedule_tune_relay( + model_name: str, + input_shape: List[int], + target: str, +): + dev = tvm.cpu() if str(target).startswith("llvm") else tvm.cuda() + if model_name.startswith("bert"): data = tvm.nd.array(np.random.randint(0, 30521, size=input_shape), dev) # embedding size else: - if MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION: - input_shape = (batch_size, 3, 299, 299) - elif MODEL_TYPES[model_name] == MODEL_TYPE.SEGMENTATION: - input_shape = (batch_size, 3, 299, 299) - elif MODEL_TYPES[model_name] == MODEL_TYPE.OBJECT_DETECTION: - input_shape = (1, 3, 300, 300) - elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION: - input_shape = (batch_size, 3, 3, 299, 299) - else: - raise ValueError("Unsupported model: " + model_name) data = tvm.nd.array(np.random.randn(*input_shape).astype("float32"), dev) - output_shape: Tuple[int, int] = (batch_size, 1000) - - mod, params = get_torch_model( - model_name=model_name, - input_shape=input_shape, - output_shape=output_shape, - dtype="float32", - ) - + mod, params, (input_name, _, _) = get_network(name=model_name, input_shape=input_shape) + target = Target(target) with tempfile.TemporaryDirectory() as work_dir: - target = Target(target) database = DummyDatabase() - rt_mod: tvm.module = tune_relay( + rt_mod: tvm.runtime.Module = tune_relay( mod=mod, params=params, target=target, @@ -136,7 +115,7 @@ def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str) ) # Compile without meta-scheduler for correctness check with tvm.transform.PassContext(opt_level=0): - rt_mod2 = relay.build(mod, target=target, params=params) + rt_mod2 = relay.build(mod, target=Target("llvm"), params=params) def get_output(data, lib): module = graph_executor.GraphModule(lib["default"](dev)) @@ -146,14 +125,14 @@ def get_output(data, lib): # Check correctness actual_output = get_output(data, rt_mod) - expected_output = get_output(data, rt_mod2) + expected_output = get_output(tvm.nd.array(data.numpy(), device=tvm.cpu()), rt_mod2) assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4) if __name__ == """__main__""": - test_meta_schedule_tune_relay("resnet18", 1, "llvm --num-cores=16") - test_meta_schedule_tune_relay("resnet18", 1, "nvidia/geforce-rtx-3070") - test_meta_schedule_tune_relay("mobilenet_v2", 1, "llvm --num-cores=16") - test_meta_schedule_tune_relay("mobilenet_v2", 1, "nvidia/geforce-rtx-3070") - test_meta_schedule_tune_relay("bert_base", 1, "llvm --num-cores=16") - test_meta_schedule_tune_relay("bert_base", 1, "nvidia/geforce-rtx-3070") + test_meta_schedule_tune_relay("resnet_18", [1, 3, 224, 224], "llvm --num-cores=16") + test_meta_schedule_tune_relay("resnet_18", [1, 3, 224, 224], "nvidia/geforce-rtx-3070") + test_meta_schedule_tune_relay("mobilenet_v2", [1, 3, 224, 224], "llvm --num-cores=16") + 
test_meta_schedule_tune_relay("mobilenet_v2", [1, 3, 224, 224], "nvidia/geforce-rtx-3070") + test_meta_schedule_tune_relay("bert_base", [1, 64], "llvm --num-cores=16") + test_meta_schedule_tune_relay("bert_base", [1, 64], "nvidia/geforce-rtx-3070")