diff --git a/python/setup.py b/python/setup.py index 402d993820ba..fff7a0ed3bb1 100644 --- a/python/setup.py +++ b/python/setup.py @@ -166,7 +166,21 @@ def get_package_data_files(): ], extras_require={ "test": ["pillow<7", "matplotlib"], - "extra_feature": ["tornado", "psutil", "xgboost>=1.1.0", "mypy", "orderedset"], + "extra_feature": [ + "tornado", + "psutil", + "xgboost>=1.1.0", + "mypy", + "orderedset", + ], + "tvmc": [ + "tensorflow>=2.1.0", + "tflite>=2.1.0", + "onnx>=1.7.0", + "onnxruntime>=1.0.0", + "torch>=1.4.0", + "torchvision>=0.5.0", + ], }, packages=find_packages(), package_dir={"tvm": "tvm"}, diff --git a/python/tvm/driver/tvmc/__init__.py b/python/tvm/driver/tvmc/__init__.py index 13a83393a912..cf35f189d2ba 100644 --- a/python/tvm/driver/tvmc/__init__.py +++ b/python/tvm/driver/tvmc/__init__.py @@ -14,3 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +""" +TVMC - TVM driver command-line interface +""" + +from . import compiler diff --git a/python/tvm/driver/tvmc/__main__.py b/python/tvm/driver/tvmc/__main__.py index f72e9f4df3ba..55235a6adfdd 100644 --- a/python/tvm/driver/tvmc/__main__.py +++ b/python/tvm/driver/tvmc/__main__.py @@ -18,7 +18,7 @@ TVMC - TVM driver command-line interface """ -from .main import main +from tvm.driver import tvmc if __name__ == "__main__": - main() + tvmc.main.main() diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index aa53ce7134bb..f389c81f0337 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -17,7 +17,49 @@ """ Common utility functions shared by TVMC modules. """ +from tvm import relay +from tvm import transform class TVMCException(Exception): """TVMC Exception""" + + +def convert_graph_layout(mod, desired_layout): + """Alter the layout of the input graph. + + Parameters + ---------- + mod : tvm.relay.Module + The relay module to convert. + desired_layout : str + The layout to convert to. + + Returns + ------- + mod : tvm.relay.Module + The converted module. + """ + + # Assume for the time being that graphs only have + # conv2d as heavily-sensitive operators. + desired_layouts = { + "nn.conv2d": [desired_layout, "default"], + "qnn.conv2d": [desired_layout, "default"], + } + + # Convert the layout of the graph where possible. + seq = transform.Sequential( + [ + relay.transform.RemoveUnusedFunctions(), + relay.transform.ConvertLayout(desired_layouts), + ] + ) + + with transform.PassContext(opt_level=3): + try: + return seq(mod) + except Exception as err: + raise TVMCException( + "Error converting layout to {0}: {1}".format(desired_layout, str(err)) + ) diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py new file mode 100644 index 000000000000..77703b2d06e1 --- /dev/null +++ b/python/tvm/driver/tvmc/compiler.py @@ -0,0 +1,280 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Provides support to compile networks both AOT and JIT. +""" +import logging +import os.path +import tarfile +from pathlib import Path + +import tvm +from tvm import autotvm +from tvm import relay +from tvm.contrib import cc +from tvm.contrib import util + +from . import common, frontends +from .main import register_parser + + +@register_parser +def add_compile_parser(subparsers): + """ Include parser for 'compile' subcommand """ + + parser = subparsers.add_parser("compile", help="compile a model") + parser.set_defaults(func=drive_compile) + parser.add_argument( + "--cross-compiler", + default="", + help="the cross compiler to generate target libraries, e.g. 'aarch64-linux-gnu-gcc'", + ) + parser.add_argument( + "--desired-layout", + choices=["NCHW", "NHWC"], + default=None, + help="change the data layout of the whole graph", + ) + parser.add_argument( + "--dump-code", + metavar="FORMAT", + default="", + help="comma separarated list of formats to export, e.g. 'asm,ll,relay' ", + ) + parser.add_argument( + "--model-format", + choices=frontends.get_frontend_names(), + help="specify input model format", + ) + parser.add_argument( + "-o", + "--output", + default="module.tar", + help="output the compiled module to an archive", + ) + parser.add_argument( + "--target", + help="compilation target as plain string, inline JSON or path to a JSON file", + required=True, + ) + parser.add_argument( + "--tuning-records", + metavar="PATH", + default="", + help="path to an auto-tuning log file by AutoTVM. If not presented, " + "the fallback/tophub configs will be used", + ) + parser.add_argument("-v", "--verbose", action="count", default=0, help="increase verbosity") + # TODO (@leandron) This is a path to a physical file, but + # can be improved in future to add integration with a modelzoo + # or URL, for example. + parser.add_argument("FILE", help="path to the input model file") + + +def drive_compile(args): + """Invoke tvmc.compiler module with command line arguments + + Parameters + ---------- + args: argparse.Namespace + Arguments from command line parser. + + Returns + -------- + int + Zero if successfully completed + + """ + + graph, lib, params, dumps = compile_model( + args.FILE, + args.target, + args.dump_code, + None, + args.model_format, + args.tuning_records, + args.tensor_layout, + ) + + if dumps: + save_dumps(args.output, dumps) + + save_module(args.output, graph, lib, params, args.cross_compiler) + return 0 + + +def compile_model( + path, + target, + dump_code=None, + target_host=None, + model_format=None, + tuning_records=None, + alter_layout=None, +): + """Compile a model from a supported framework into a TVM module. + + This function takes a union of the arguments of both frontends.load_model + and compiler.compile_relay. The resulting TVM module can be executed using + the graph runtime. + + Parameters + ---------- + path: str + Path to a file + target : str + The target for which to compile. Can be a plain string or + a path. + dump_code : list, optional + Dump the generated code for the specified source types, on + the requested target. + target_host : str, optional + The target of the host machine if host-side code + needs to be generated. + model_format: str, optional + A string representing a name of a frontend to be used + tuning_records: str, optional + Path to the file produced by the tuning to be used during + compilation. + alter_layout: str, optional + The layout to convert the graph to. Note, the convert layout + pass doesn't currently guarantee the whole of the graph will + be converted to the chosen layout. + + Returns + ------- + graph : str + A JSON-serialized TVM execution graph. + lib : tvm.module.Module + A TVM module containing the compiled functions. + params : dict + The parameters (weights) for the TVM module. + dumps : dict + Dictionary containing the dumps specified. + + """ + dump_code = [x.strip() for x in dump_code.split(",")] if dump_code else None + mod, params = frontends.load_model(path, model_format) + + if alter_layout: + mod = common.convert_graph_layout(mod, alter_layout) + + # Handle the case in which target is a path to a JSON file. + if os.path.exists(target): + with open(target) as target_file: + logging.info("using target input from file: %s", target) + target = "".join(target_file.readlines()) + + # TODO(@leandron) We don't have an API to collect a list of supported + # targets yet + logging.debug("creating target from input: %s", target) + tvm_target = tvm.target.Target(target) + target_host = target_host or "" + + if tuning_records and os.path.exists(tuning_records): + # TODO (@leandron) a new PR will introduce the 'tune' subcommand + # the is used to generate the tuning records file + logging.debug("tuning records file provided: %s", tuning_records) + with autotvm.apply_history_best(tuning_records): + with tvm.transform.PassContext(opt_level=3): + logging.debug("building relay graph with tuning records") + graph_module = relay.build(mod, tvm_target, params=params, target_host=tvm_target) + else: + with tvm.transform.PassContext(opt_level=3): + logging.debug("building relay graph (no tuning records provided)") + graph_module = relay.build(mod, tvm_target, params=params, target_host=tvm_target) + + # Generate output dump files with sources + dump_code = dump_code or [] + dumps = {} + for source_type in dump_code: + lib = graph_module.get_lib() + # TODO lib.get_source call have inconsistent behavior for unsupported + # formats (@leandron). + source = str(mod) if source_type == "relay" else lib.get_source(source_type) + dumps[source_type] = source + + return graph_module.get_json(), graph_module.get_lib(), graph_module.get_params(), dumps + + +def save_module(module_path, graph, lib, params, cross=None): + """ + Create a tarball containing the generated TVM graph, + exported library and parameters + + Parameters + ---------- + module_path : str + path to the target tar.gz file to be created, + including the file name + graph : str + A JSON-serialized TVM execution graph. + lib : tvm.module.Module + A TVM module containing the compiled functions. + params : dict + The parameters (weights) for the TVM module. + cross : str or callable object, optional + Function that performs the actual compilation + + """ + lib_name = "mod.so" + graph_name = "mod.json" + param_name = "mod.params" + temp = util.tempdir() + path_lib = temp.relpath(lib_name) + if not cross: + logging.debug("exporting library to %s", path_lib) + lib.export_library(path_lib) + else: + logging.debug("exporting library to %s , using cross compiler %s", path_lib, cross) + lib.export_library(path_lib, cc.cross_compiler(cross)) + + with open(temp.relpath(graph_name), "w") as graph_file: + logging.debug("writing graph to file to %s", graph_file.name) + graph_file.write(graph) + + with open(temp.relpath(param_name), "wb") as params_file: + logging.debug("writing params to file to %s", params_file.name) + params_file.write(relay.save_param_dict(params)) + + logging.debug("saving module as tar file to %s", module_path) + with tarfile.open(module_path, "w") as tar: + tar.add(path_lib, lib_name) + tar.add(temp.relpath(graph_name), graph_name) + tar.add(temp.relpath(param_name), param_name) + + +def save_dumps(module_name, dumps, dump_root="."): + """ + Serialize dump files to the disk. + + Parameters + ---------- + module_name : str + File name, referring to the module that generated + the dump contents + dumps : dict + The output contents to be saved into the files + dump_root : str, optional + Path in which dump files will be created + + """ + + for dump_format in dumps: + dump_name = module_name + "." + dump_format + with open(Path(dump_root, dump_name), "w") as f: + f.write(dumps[dump_format]) diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py new file mode 100644 index 000000000000..6275f779f778 --- /dev/null +++ b/python/tvm/driver/tvmc/frontends.py @@ -0,0 +1,413 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Provides support to parse models from different frameworks into Relay networks. + +Frontend classes do lazy-loading of modules on purpose, to reduce time spent on +loading the tool. +""" +import logging +import os +import sys +from abc import ABC +from abc import abstractmethod +from pathlib import Path + +import numpy as np + +from tvm import relay +from tvm.driver.tvmc.common import TVMCException + + +class Frontend(ABC): + """Abstract class for command line driver frontend. + + Provide a unified way to import models (as files), and deal + with any required preprocessing to create a TVM module from it.""" + + @staticmethod + @abstractmethod + def name(): + """Frontend name""" + + @staticmethod + @abstractmethod + def suffixes(): + """File suffixes (extensions) used by this frontend""" + + @abstractmethod + def load(self, path): + """Load a model from a given path. + + Parameters + ---------- + path: str + Path to a file + + Returns + ------- + mod : tvm.relay.Module + The produced relay module. + params : dict + The parameters (weights) for the relay module. + + """ + + +def import_keras(): + """ Lazy import function for Keras""" + # Keras writes the message "Using TensorFlow backend." to stderr + # Redirect stderr during the import to disable this + stderr = sys.stderr + sys.stderr = open(os.devnull, "w") + try: + # pylint: disable=C0415 + import tensorflow as tf + from tensorflow import keras + + return tf, keras + finally: + sys.stderr = stderr + + +class KerasFrontend(Frontend): + """ Keras frontend for TVMC """ + + @staticmethod + def name(): + return "keras" + + @staticmethod + def suffixes(): + return ["h5"] + + def load(self, path): + # pylint: disable=C0103 + tf, keras = import_keras() + + # tvm build currently imports keras directly instead of tensorflow.keras + try: + model = keras.models.load_model(path) + except ValueError as err: + raise TVMCException(str(err)) + + # There are two flavours of keras model, sequential and + # functional, TVM expects a functional model, so convert + # if required: + if self.is_sequential_p(model): + model = self.sequential_to_functional(model) + + in_shapes = [] + for layer in model._input_layers: + if tf.executing_eagerly(): + in_shapes.append(tuple(dim if dim is not None else 1 for dim in layer.input.shape)) + else: + in_shapes.append( + tuple(dim.value if dim.value is not None else 1 for dim in layer.input.shape) + ) + + inputs = [np.random.uniform(size=shape, low=-1.0, high=1.0) for shape in in_shapes] + shape_dict = {name: x.shape for (name, x) in zip(model.input_names, inputs)} + return relay.frontend.from_keras(model, shape_dict, layout="NHWC") + + def is_sequential_p(self, model): + _, keras = import_keras() + return isinstance(model, keras.models.Sequential) + + def sequential_to_functional(self, model): + _, keras = import_keras() + assert self.is_sequential_p(model) + input_layer = keras.layers.Input(batch_shape=model.layers[0].input_shape) + prev_layer = input_layer + for layer in model.layers: + prev_layer = layer(prev_layer) + model = keras.models.Model([input_layer], [prev_layer]) + return model + + +class OnnxFrontend(Frontend): + """ ONNX frontend for TVMC """ + + @staticmethod + def name(): + return "onnx" + + @staticmethod + def suffixes(): + return ["onnx"] + + def load(self, path): + # pylint: disable=C0415 + import onnx + + model = onnx.load(path) + + # pylint: disable=E1101 + name = model.graph.input[0].name + + # pylint: disable=E1101 + proto_shape = model.graph.input[0].type.tensor_type.shape.dim + shape = [d.dim_value for d in proto_shape] + + shape_dict = {name: shape} + + return relay.frontend.from_onnx(model, shape_dict) + + +class TensorflowFrontend(Frontend): + """ TensorFlow frontend for TVMC """ + + @staticmethod + def name(): + return "pb" + + @staticmethod + def suffixes(): + return ["pb"] + + def load(self, path): + # pylint: disable=C0415 + import tensorflow as tf + import tvm.relay.testing.tf as tf_testing + + with tf.io.gfile.GFile(path, "rb") as tf_graph: + content = tf_graph.read() + + graph_def = tf.compat.v1.GraphDef() + graph_def.ParseFromString(content) + graph_def = tf_testing.ProcessGraphDefParam(graph_def) + + logging.debug("relay.frontend.from_tensorflow") + return relay.frontend.from_tensorflow(graph_def) + + +class TFLiteFrontend(Frontend): + """ TFLite frontend for TVMC """ + + _tflite_m = { + 0: "float32", + 1: "float16", + 2: "int32", + 3: "uint8", + 4: "int64", + 5: "string", + 6: "bool", + 7: "int16", + 8: "complex64", + 9: "int8", + } + + @staticmethod + def name(): + return "tflite" + + @staticmethod + def suffixes(): + return ["tflite"] + + def load(self, path): + # pylint: disable=C0415 + import tflite.Model as model + + with open(path, "rb") as tf_graph: + content = tf_graph.read() + + # tflite.Model.Model is tflite.Model in 1.14 and 2.1.0 + try: + tflite_model = model.Model.GetRootAsModel(content, 0) + except AttributeError: + tflite_model = model.GetRootAsModel(content, 0) + + try: + version = tflite_model.Version() + logging.debug("tflite version %s", version) + except Exception: + raise TVMCException("input file not tflite") + + if version != 3: + raise TVMCException("input file not tflite version 3") + + logging.debug("tflite_input_type") + shape_dict, dtype_dict = TFLiteFrontend._input_type(tflite_model) + + # parse TFLite model and convert into Relay computation graph + logging.debug("relay.frontend.from_tflite") + mod, params = relay.frontend.from_tflite( + tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict + ) + return mod, params + + @staticmethod + def _decode_type(n): + return TFLiteFrontend._tflite_m[n] + + @staticmethod + def _input_type(model): + subgraph_count = model.SubgraphsLength() + assert subgraph_count > 0 + shape_dict = {} + dtype_dict = {} + for subgraph_index in range(subgraph_count): + subgraph = model.Subgraphs(subgraph_index) + inputs_count = subgraph.InputsLength() + assert inputs_count >= 1 + for input_index in range(inputs_count): + input_ = subgraph.Inputs(input_index) + assert subgraph.TensorsLength() > input_ + tensor = subgraph.Tensors(input_) + input_shape = tuple(tensor.ShapeAsNumpy()) + tensor_type = tensor.Type() + input_name = tensor.Name().decode("utf8") + shape_dict[input_name] = input_shape + dtype_dict[input_name] = TFLiteFrontend._decode_type(tensor_type) + + return shape_dict, dtype_dict + + +class PyTorchFrontend(Frontend): + """ PyTorch frontend for TVMC """ + + @staticmethod + def name(): + return "pytorch" + + @staticmethod + def suffixes(): + # Torch Script is a zip file, but can be named pth + return ["pth", "zip"] + + def load(self, path): + # pylint: disable=C0415 + import torch + + traced_model = torch.jit.load(path) + + inputs = list(traced_model.graph.inputs())[1:] + input_shapes = [inp.type().sizes() for inp in inputs] + + traced_model.eval() # Switch to inference mode + input_shapes = [("input{}".format(idx), shape) for idx, shape in enumerate(shapes)] + logging.debug("relay.frontend.from_pytorch") + return relay.frontend.from_pytorch(traced_model, input_shapes) + + +ALL_FRONTENDS = [ + KerasFrontend, + OnnxFrontend, + TensorflowFrontend, + TFLiteFrontend, + PyTorchFrontend, +] + + +def get_frontend_names(): + """Return the names of all supported frontends + + Returns + ------- + list : list of str + A list of frontend names as strings + + """ + return [frontend.name() for frontend in ALL_FRONTENDS] + + +def get_frontend_by_name(name): + """ + This function will try to get a frontend instance, based + on the name provided. + + Parameters + ---------- + name : str + the name of a given frontend + + Returns + ------- + frontend : tvm.driver.tvmc.Frontend + An instance of the frontend that matches with + the file extension provided in `path`. + + """ + + for frontend in ALL_FRONTENDS: + if name == frontend.name(): + return frontend() + + raise TVMCException( + "unrecognized frontend '{0}'. Choose from: {1}".format(name, get_frontend_names()) + ) + + +def guess_frontend(path): + """ + This function will try to imply which framework is being used, + based on the extension of the file provided in the path parameter. + + Parameters + ---------- + path : str + The path to the model file. + + Returns + ------- + frontend : tvm.driver.tvmc.Frontend + An instance of the frontend that matches with + the file extension provided in `path`. + + """ + + suffix = Path(path).suffix.lower() + if suffix.startswith("."): + suffix = suffix[1:] + + for frontend in ALL_FRONTENDS: + if suffix in frontend.suffixes(): + return frontend() + + raise TVMCException("failed to infer the model format. Please specify --model-format") + + +def load_model(path, model_format=None): + """Load a model from a supported framework and convert it + into an equivalent relay representation. + + Parameters + ---------- + path : str + The path to the model file. + model_format : str, optional + The underlying framework used to create the model. + If not specified, this will be inferred from the file type. + + Returns + ------- + mod : tvm.relay.Module + The produced relay module. + params : dict + The parameters (weights) for the relay module. + + """ + + if model_format is not None: + frontend = get_frontend_by_name(model_format) + else: + frontend = guess_frontend(path) + + mod, params = frontend.load(path) + + return mod, params diff --git a/tests/python/driver/tvmc/conftest.py b/tests/python/driver/tvmc/conftest.py new file mode 100644 index 000000000000..ee67cc904aac --- /dev/null +++ b/tests/python/driver/tvmc/conftest.py @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os +import pytest +import tarfile + +import tvm.driver.tvmc.compiler + +from tvm.contrib.download import download_testdata + +from tvm.driver.tvmc.common import convert_graph_layout + +# Support functions + + +def download_and_untar(model_url, model_sub_path, temp_dir): + model_tar_name = os.path.basename(model_url) + model_path = download_testdata(model_url, model_tar_name, module=["tvmc"]) + + if model_path.endswith("tgz") or model_path.endswith("gz"): + tar = tarfile.open(model_path) + tar.extractall(path=temp_dir) + tar.close() + + return os.path.join(temp_dir, model_sub_path) + + +def get_sample_compiled_module(target_dir): + """Support function that retuns a TFLite compiled module""" + base_url = "https://storage.googleapis.com/download.tensorflow.org/models" + model_url = "mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz" + model_file = download_and_untar( + "{}/{}".format(base_url, model_url), + "mobilenet_v1_1.0_224_quant.tflite", + temp_dir=target_dir, + ) + + return tvmc.compiler.compile_model(model_file, targets=["llvm"]) + + +# PyTest fixtures + + +@pytest.fixture(scope="session") +def tflite_mobilenet_v1_1_quant(tmpdir_factory): + base_url = "https://storage.googleapis.com/download.tensorflow.org/models" + model_url = "mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz" + model_file = download_and_untar( + "{}/{}".format(base_url, model_url), + "mobilenet_v1_1.0_224_quant.tflite", + temp_dir=tmpdir_factory.mktemp("data"), + ) + + return model_file + + +@pytest.fixture(scope="session") +def pb_mobilenet_v1_1_quant(tmpdir_factory): + base_url = "https://storage.googleapis.com/download.tensorflow.org/models" + model_url = "mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz" + model_file = download_and_untar( + "{}/{}".format(base_url, model_url), + "mobilenet_v1_1.0_224_frozen.pb", + temp_dir=tmpdir_factory.mktemp("data"), + ) + + return model_file + + +@pytest.fixture(scope="session") +def keras_resnet50(tmpdir_factory): + try: + from tensorflow.keras.applications.resnet50 import ResNet50 + except ImportError: + # not all environments provide TensorFlow, so skip this fixture + # if that is that case. + return "" + + model_file_name = "{}/{}".format(tmpdir_factory.mktemp("data"), "resnet50.h5") + model = ResNet50(include_top=True, weights="imagenet", input_shape=(224, 224, 3), classes=1000) + model.save(model_file_name) + + return model_file_name + + +@pytest.fixture(scope="session") +def onnx_resnet50(): + base_url = "https://github.com/onnx/models/raw/master/vision/classification/resnet/model" + file_to_download = "resnet50-v2-7.onnx" + model_file = download_testdata( + "{}/{}".format(base_url, file_to_download), file_to_download, module=["tvmc"] + ) + + return model_file + + +@pytest.fixture(scope="session") +def tflite_compiled_module_as_tarfile(tmpdir_factory): + target_dir = tmpdir_factory.mktemp("data") + graph, lib, params, _ = get_sample_compiled_module(target_dir) + + module_file = os.path.join(target_dir, "mock.tar") + tvmc.compiler.save_module(module_file, graph, lib, params) + + return module_file diff --git a/tests/python/driver/tvmc/test_common.py b/tests/python/driver/tvmc/test_common.py new file mode 100644 index 000000000000..a9a62c5ef874 --- /dev/null +++ b/tests/python/driver/tvmc/test_common.py @@ -0,0 +1,120 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import argparse +import os +from os import path + +import pytest + +import tvm +from tvm.driver import tvmc + + +def test_compile_tflite_module_nhwc_to_nchw(tflite_mobilenet_v1_1_quant): + # some CI environments wont offer TFLite, so skip in case it is not present + pytest.importorskip("tflite") + + before, _ = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) + + expected_layout = "NCHW" + after = tvmc.common.convert_graph_layout(before, expected_layout) + + layout_transform_calls = [] + + def _is_layout_transform(node): + if isinstance(node, tvm.relay.expr.Call): + layout_transform_calls.append( + node.op.name == "layout_transform" + and node.attrs.src_layout == "NHWC" + and node.attrs.dst_layout == "NCHW" + ) + + tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) + + assert any(layout_transform_calls), "Expected 'layout_transform NHWC->NCHW' not found" + + +def test_compile_onnx_module_nchw_to_nhwc(onnx_resnet50): + # some CI environments wont offer ONNX, so skip in case it is not present + pytest.importorskip("onnx") + + before, _ = tvmc.frontends.load_model(onnx_resnet50) + + expected_layout = "NHWC" + after = tvmc.common.convert_graph_layout(before, expected_layout) + + layout_transform_calls = [] + + def _is_layout_transform(node): + if isinstance(node, tvm.relay.expr.Call): + layout_transform_calls.append( + node.op.name == "layout_transform" + and node.attrs.src_layout == "NCHW" + and node.attrs.dst_layout == "NHWC" + ) + + tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) + + assert any(layout_transform_calls), "Expected 'layout_transform NCWH->NHWC' not found" + + +def test_compile_tflite_module__same_layout__nhwc_to_nhwc(tflite_mobilenet_v1_1_quant): + # some CI environments wont offer TFLite, so skip in case it is not present + pytest.importorskip("tflite") + + before, _ = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) + + expected_layout = "NHWC" + after = tvmc.common.convert_graph_layout(before, expected_layout) + + layout_transform_calls = [] + + def _is_layout_transform(node): + if isinstance(node, tvm.relay.expr.Call): + layout_transform_calls.append( + node.op.name == "layout_transform" + and node.attrs.src_layout == "NHWC" + and node.attrs.dst_layout == "NHWC" + ) + + tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) + + assert not any(layout_transform_calls), "Unexpected 'layout_transform' call" + + +def test_compile_onnx_module__same_layout__nchw_to_nchw(onnx_resnet50): + # some CI environments wont offer ONNX, so skip in case it is not present + pytest.importorskip("onnx") + + before, _ = tvmc.frontends.load_model(onnx_resnet50) + + expected_layout = "NCHW" + after = tvmc.common.convert_graph_layout(before, expected_layout) + + layout_transform_calls = [] + + def _is_layout_transform(node): + if isinstance(node, tvm.relay.expr.Call): + layout_transform_calls.append( + node.op.name == "layout_transform" + and node.attrs.src_layout == "NCHW" + and node.attrs.dst_layout == "NCHW" + ) + + tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) + + assert not any(layout_transform_calls), "Unexpected 'layout_transform' call" diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py new file mode 100644 index 000000000000..28a60b19b28e --- /dev/null +++ b/tests/python/driver/tvmc/test_compiler.py @@ -0,0 +1,152 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import argparse +import os +import shutil +from os import path + +import pytest + +import tvm + +from tvm.driver import tvmc + + +def test_save_dumps(tmpdir_factory): + tmpdir = tmpdir_factory.mktemp("data") + dump_formats = {"relay": "fake relay", "ll": "fake llvm", "asm": "fake asm"} + tvmc.compiler.save_dumps("fake_module", dump_formats, dump_root=tmpdir) + + assert path.exists("{}/{}".format(tmpdir, "fake_module.ll")) + assert path.exists("{}/{}".format(tmpdir, "fake_module.asm")) + assert path.exists("{}/{}".format(tmpdir, "fake_module.relay")) + + +# End to end tests for compilation + + +def test_compile_tflite_module(tflite_mobilenet_v1_1_quant): + pytest.importorskip("tflite") + + graph, lib, params, dumps = tvmc.compiler.compile_model( + tflite_mobilenet_v1_1_quant, + target="llvm", + dump_code="ll", + alter_layout="NCHW", + ) + + # check for output types + assert type(graph) is str + assert type(lib) is tvm.runtime.module.Module + assert type(params) is dict + assert type(dumps) is dict + + +# This test will be skipped if the AArch64 cross-compilation toolchain is not installed. +@pytest.mark.skipif( + not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed" +) +def test_cross_compile_aarch64_tflite_module(tflite_mobilenet_v1_1_quant): + pytest.importorskip("tflite") + + graph, lib, params, dumps = tvmc.compiler.compile_model( + tflite_mobilenet_v1_1_quant, + target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon", + dump_code="asm", + ) + + # check for output types + assert type(graph) is str + assert type(lib) is tvm.runtime.module.Module + assert type(params) is dict + assert type(dumps) is dict + + +def test_compile_keras__save_module(keras_resnet50, tmpdir_factory): + # some CI environments wont offer tensorflow/Keras, so skip in case it is not present + pytest.importorskip("tensorflow") + + graph, lib, params, dumps = tvmc.compiler.compile_model( + keras_resnet50, target="llvm", dump_code="ll" + ) + + expected_temp_dir = tmpdir_factory.mktemp("saved_output") + expected_file_name = "saved.tar" + module_file = os.path.join(expected_temp_dir, expected_file_name) + tvmc.compiler.save_module(module_file, graph, lib, params) + + assert os.path.exists(module_file), "output file {0} should exist".format(module_file) + + +# This test will be skipped if the AArch64 cross-compilation toolchain is not installed. +@pytest.mark.skipif( + not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed" +) +def test_cross_compile_aarch64_keras_module(keras_resnet50): + # some CI environments wont offer tensorflow/Keras, so skip in case it is not present + pytest.importorskip("tensorflow") + + graph, lib, params, dumps = tvmc.compiler.compile_model( + keras_resnet50, + target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon", + dump_code="asm", + ) + + # check for output types + assert type(graph) is str + assert type(lib) is tvm.runtime.module.Module + assert type(params) is dict + assert type(dumps) is dict + assert "asm" in dumps.keys() + + +def test_compile_onnx_module(onnx_resnet50): + # some CI environments wont offer onnx, so skip in case it is not present + pytest.importorskip("onnx") + + graph, lib, params, dumps = tvmc.compiler.compile_model( + onnx_resnet50, target="llvm", dump_code="ll" + ) + + # check for output types + assert type(graph) is str + assert type(lib) is tvm.runtime.module.Module + assert type(params) is dict + assert type(dumps) is dict + assert "ll" in dumps.keys() + + +# This test will be skipped if the AArch64 cross-compilation toolchain is not installed. +@pytest.mark.skipif( + not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed" +) +def test_cross_compile_aarch64_onnx_module(onnx_resnet50): + # some CI environments wont offer onnx, so skip in case it is not present + pytest.importorskip("onnx") + + graph, lib, params, dumps = tvmc.compiler.compile_model( + onnx_resnet50, + target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon", + dump_code="asm", + ) + + # check for output types + assert type(graph) is str + assert type(lib) is tvm.runtime.module.Module + assert type(params) is dict + assert type(dumps) is dict + assert "asm" in dumps.keys() diff --git a/tests/python/driver/tvmc/test_frontends.py b/tests/python/driver/tvmc/test_frontends.py new file mode 100644 index 000000000000..d77a17addabf --- /dev/null +++ b/tests/python/driver/tvmc/test_frontends.py @@ -0,0 +1,182 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os +import tarfile + +import pytest + +from tvm.ir.module import IRModule + +from tvm.driver import tvmc +from tvm.driver.tvmc.common import TVMCException + + +def test_get_frontends_contains_only_strings(): + sut = tvmc.frontends.get_frontend_names() + assert all([type(x) is str for x in sut]) is True + + +def test_get_frontend_by_name_valid(): + # some CI environments wont offer TensorFlow/Keras, so skip in case it is not present + pytest.importorskip("tensorflow") + + sut = tvmc.frontends.get_frontend_by_name("keras") + assert type(sut) is tvmc.frontends.KerasFrontend + + +def test_get_frontend_by_name_invalid(): + with pytest.raises(TVMCException): + tvmc.frontends.get_frontend_by_name("unsupported_thing") + + +def test_guess_frontend_tflite(): + # some CI environments wont offer TFLite, so skip in case it is not present + pytest.importorskip("tflite") + + sut = tvmc.frontends.guess_frontend("a_model.tflite") + assert type(sut) is tvmc.frontends.TFLiteFrontend + + +def test_guess_frontend_onnx(): + # some CI environments wont offer onnx, so skip in case it is not present + pytest.importorskip("onnx") + + sut = tvmc.frontends.guess_frontend("a_model.onnx") + assert type(sut) is tvmc.frontends.OnnxFrontend + + +def test_guess_frontend_pytorch(): + # some CI environments wont offer pytorch, so skip in case it is not present + pytest.importorskip("torch") + + sut = tvmc.frontends.guess_frontend("a_model.pth") + assert type(sut) is tvmc.frontends.PyTorchFrontend + + +def test_guess_frontend_keras(): + # some CI environments wont offer TensorFlow/Keras, so skip in case it is not present + pytest.importorskip("tensorflow") + + sut = tvmc.frontends.guess_frontend("a_model.h5") + assert type(sut) is tvmc.frontends.KerasFrontend + + +def test_guess_frontend_tensorflow(): + # some CI environments wont offer TensorFlow, so skip in case it is not present + pytest.importorskip("tensorflow") + + sut = tvmc.frontends.guess_frontend("a_model.pb") + assert type(sut) is tvmc.frontends.TensorflowFrontend + + +def test_guess_frontend_invalid(): + with pytest.raises(TVMCException): + tvmc.frontends.guess_frontend("not/a/file.txt") + + +def test_load_model__invalid_path__no_language(): + # some CI environments wont offer TFLite, so skip in case it is not present + pytest.importorskip("tflite") + + with pytest.raises(FileNotFoundError): + tvmc.frontends.load_model("not/a/file.tflite") + + +def test_load_model__invalid_path__with_language(): + # some CI environments wont offer onnx, so skip in case it is not present + pytest.importorskip("onnx") + + with pytest.raises(FileNotFoundError): + tvmc.frontends.load_model("not/a/file.txt", model_format="onnx") + + +def test_load_model__tflite(tflite_mobilenet_v1_1_quant): + # some CI environments wont offer TFLite, so skip in case it is not present + pytest.importorskip("tflite") + + mod, params = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) + assert type(mod) is IRModule + assert type(params) is dict + # check whether one known value is part of the params dict + assert "_param_1" in params.keys() + + +def test_load_model__keras(keras_resnet50): + # some CI environments wont offer TensorFlow/Keras, so skip in case it is not present + pytest.importorskip("tensorflow") + + mod, params = tvmc.frontends.load_model(keras_resnet50) + assert type(mod) is IRModule + assert type(params) is dict + ## check whether one known value is part of the params dict + assert "_param_1" in params.keys() + + +def test_load_model__onnx(onnx_resnet50): + # some CI environments wont offer onnx, so skip in case it is not present + pytest.importorskip("onnx") + + mod, params = tvmc.frontends.load_model(onnx_resnet50) + assert type(mod) is IRModule + assert type(params) is dict + ## check whether one known value is part of the params dict + assert "resnetv24_batchnorm0_gamma" in params.keys() + + +def test_load_model__pb(pb_mobilenet_v1_1_quant): + # some CI environments wont offer TensorFlow, so skip in case it is not present + pytest.importorskip("tensorflow") + + mod, params = tvmc.frontends.load_model(pb_mobilenet_v1_1_quant) + assert type(mod) is IRModule + assert type(params) is dict + # check whether one known value is part of the params dict + assert "MobilenetV1/Conv2d_0/weights" in params.keys() + + +def test_load_model___wrong_language__to_keras(tflite_mobilenet_v1_1_quant): + # some CI environments wont offer TensorFlow/Keras, so skip in case it is not present + pytest.importorskip("tensorflow") + + with pytest.raises(OSError): + tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, model_format="keras") + + +def test_load_model___wrong_language__to_tflite(keras_resnet50): + # some CI environments wont offer TFLite, so skip in case it is not present + pytest.importorskip("tflite") + + with pytest.raises(TVMCException): + tvmc.frontends.load_model(keras_resnet50, model_format="tflite") + + +def test_load_model___wrong_language__to_onnx(tflite_mobilenet_v1_1_quant): + # some CI environments wont offer onnx, so skip in case it is not present + pytest.importorskip("onnx") + + from google.protobuf.message import DecodeError + + with pytest.raises(DecodeError): + tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, model_format="onnx") + + +def test_load_model___wrong_language__to_pytorch(tflite_mobilenet_v1_1_quant): + # some CI environments wont offer pytorch, so skip in case it is not present + pytest.importorskip("torch") + + with pytest.raises(RuntimeError) as e: + tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, model_format="pytorch") diff --git a/tests/scripts/task_python_integration.sh b/tests/scripts/task_python_integration.sh index 35a81e508643..ef86d6917424 100755 --- a/tests/scripts/task_python_integration.sh +++ b/tests/scripts/task_python_integration.sh @@ -59,6 +59,9 @@ TVM_FFI=ctypes python3 -m pytest tests/python/contrib TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm;cuda}" TVM_FFI=ctypes python3 -m pytest tests/python/relay +# Command line driver test +TVM_FFI=ctypes python3 -m pytest tests/python/driver + # Do not enable OpenGL # TVM_FFI=cython python -m pytest tests/webgl # TVM_FFI=ctypes python3 -m pytest tests/webgl