Skip to content

Commit

Permalink
[tvmc] command line driver 'compile' (part 2/4) (apache#6302)
Browse files Browse the repository at this point in the history
* [tvmc] command line driver 'compile' (part 2/4)

 * Add 'compile' subcommand into tvmc (tvm.driver.tvmc)
 * Add frontends: Keras, ONNX, TensorFlow, tflite, PyTorch
 * Add tests for the 'compile' subcommand
 * Enable command line driver tests as part of integration tests
 * Skip tests if the cross-compilation toolchain is not installed


Co-authored-by: Marcus Shawcroft <[email protected]>
Co-authored-by: Matthew Barrett <[email protected]>
Co-authored-by: Dmitriy Smirnov <[email protected]>
Co-authored-by: Luke Hutton <[email protected]>
Co-authored-by: Giuseppe Rossini <[email protected]>
Co-authored-by: Matthew Barrett <[email protected]>
Co-authored-by: Elen Kalda <[email protected]>
Co-authored-by: Ramana Radhakrishnan <[email protected]>
Co-authored-by: Jeremy Johnson <[email protected]>
Co-authored-by: Ina Dobreva <[email protected]>

* tvmc: adjust TODOs

* tvmc: fix linting errors

* Address code-review comments

* Adjust pytest fixture to not break when there is no tensorflow

* Fix frontend tests, to cope with different frameworks in different images

* Apply suggestions from code review

Co-authored-by: Cody Yu <[email protected]>

* Fix lint and code-review issues

* Re-format with black.

* tvmc: Move dependencies to extras_requires

Co-authored-by: Marcus Shawcroft <[email protected]>
Co-authored-by: Matthew Barrett <[email protected]>
Co-authored-by: Dmitriy Smirnov <[email protected]>
Co-authored-by: Luke Hutton <[email protected]>
Co-authored-by: Giuseppe Rossini <[email protected]>
Co-authored-by: Elen Kalda <[email protected]>
Co-authored-by: Ramana Radhakrishnan <[email protected]>
Co-authored-by: Jeremy Johnson <[email protected]>
Co-authored-by: Ina Dobreva <[email protected]>
Co-authored-by: Cody Yu <[email protected]>
  • Loading branch information
11 people authored and Tushar Dey committed Oct 15, 2020
1 parent 9c26a1b commit 52028fd
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 171 deletions.
2 changes: 0 additions & 2 deletions python/tvm/driver/tvmc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,4 @@
TVMC - TVM driver command-line interface
"""

from . import autotuner
from . import compiler
from . import runner
74 changes: 0 additions & 74 deletions python/tvm/driver/tvmc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,10 @@
"""
Common utility functions shared by TVMC modules.
"""
import logging
import os.path

from urllib.parse import urlparse

import tvm

from tvm import relay
from tvm import transform


# pylint: disable=invalid-name
logger = logging.getLogger("TVMC")


class TVMCException(Exception):
"""TVMC Exception"""

Expand Down Expand Up @@ -74,66 +63,3 @@ def convert_graph_layout(mod, desired_layout):
raise TVMCException(
"Error converting layout to {0}: {1}".format(desired_layout, str(err))
)


# TODO In a separate PR, eliminate the duplicated code here and in compiler.py (@leandron)
def target_from_cli(target):
    """
    Build a tvm.target.Target instance from a command line
    interface (CLI) string.

    Parameters
    ----------
    target : str
        compilation target as plain string,
        inline JSON or path to a JSON file

    Returns
    -------
    tvm.target.Target
        an instance of target device information
    """
    target_spec = target

    # When the argument names an existing path, the target description is
    # taken from that file's contents instead of from the argument itself.
    if os.path.exists(target):
        with open(target) as target_file:
            logger.info("using target input from file: %s", target)
            target_spec = target_file.read()

    # TODO(@leandron) We don't have an API to collect a list of supported
    # targets yet
    logger.debug("creating target from input: %s", target_spec)

    return tvm.target.Target(target_spec)


def tracker_host_port_from_cli(rpc_tracker_str):
    """Split an RPC tracker string into its hostname and port.

    Handles strings like "1.2.3.4:9090" or "4.3.2.1". Used as a helper
    function to cover the --rpc-tracker command line argument, in
    different subcommands.

    Parameters
    ----------
    rpc_tracker_str : str
        hostname (or IP address) and port of the RPC tracker,
        in the format 'hostname[:port]'.

    Returns
    -------
    rpc_hostname : str or None
        hostname or IP address, extracted from input;
        None when the input is empty.
    rpc_port : int or None
        port number extracted from input (9090 when the input
        carries no port); None when the input is empty.
    """
    # Nothing to parse: report both values as missing.
    if not rpc_tracker_str:
        return None, None

    # The leading "//" makes urlparse read the input as a network location,
    # so the hostname and port attributes are populated.
    tracker_url = urlparse("//%s" % rpc_tracker_str)
    rpc_hostname = tracker_url.hostname
    explicit_port = tracker_url.port
    rpc_port = explicit_port if explicit_port else 9090

    logger.info("RPC tracker hostname: %s", rpc_hostname)
    logger.info("RPC tracker port: %s", rpc_port)

    return rpc_hostname, rpc_port
30 changes: 13 additions & 17 deletions python/tvm/driver/tvmc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,6 @@
from .main import register_parser


# pylint: disable=invalid-name
logger = logging.getLogger("TVMC")


@register_parser
def add_compile_parser(subparsers):
""" Include parser for 'compile' subcommand """
Expand Down Expand Up @@ -111,7 +107,7 @@ def drive_compile(args):
None,
args.model_format,
args.tuning_records,
args.desired_layout,
args.tensor_layout,
)

if dumps:
Expand Down Expand Up @@ -180,24 +176,26 @@ def compile_model(
# Handle the case in which target is a path to a JSON file.
if os.path.exists(target):
with open(target) as target_file:
logger.info("using target input from file: %s", target)
logging.info("using target input from file: %s", target)
target = "".join(target_file.readlines())

# TODO(@leandron) We don't have an API to collect a list of supported
# targets yet
logger.debug("creating target from input: %s", target)
logging.debug("creating target from input: %s", target)
tvm_target = tvm.target.Target(target)
target_host = target_host or ""

if tuning_records and os.path.exists(tuning_records):
logger.debug("tuning records file provided: %s", tuning_records)
# TODO (@leandron) a new PR will introduce the 'tune' subcommand
# that is used to generate the tuning records file
logging.debug("tuning records file provided: %s", tuning_records)
with autotvm.apply_history_best(tuning_records):
with tvm.transform.PassContext(opt_level=3):
logger.debug("building relay graph with tuning records")
logging.debug("building relay graph with tuning records")
graph_module = relay.build(mod, tvm_target, params=params, target_host=tvm_target)
else:
with tvm.transform.PassContext(opt_level=3):
logger.debug("building relay graph (no tuning records provided)")
logging.debug("building relay graph (no tuning records provided)")
graph_module = relay.build(mod, tvm_target, params=params, target_host=tvm_target)

# Generate output dump files with sources
Expand All @@ -210,8 +208,6 @@ def compile_model(
source = str(mod) if source_type == "relay" else lib.get_source(source_type)
dumps[source_type] = source

# TODO we need to update this return to use the updated graph module APIs
# as these getter functions will be deprecated in the next release (@leandron)
return graph_module.get_json(), graph_module.get_lib(), graph_module.get_params(), dumps


Expand Down Expand Up @@ -241,21 +237,21 @@ def save_module(module_path, graph, lib, params, cross=None):
temp = util.tempdir()
path_lib = temp.relpath(lib_name)
if not cross:
logger.debug("exporting library to %s", path_lib)
logging.debug("exporting library to %s", path_lib)
lib.export_library(path_lib)
else:
logger.debug("exporting library to %s , using cross compiler %s", path_lib, cross)
logging.debug("exporting library to %s , using cross compiler %s", path_lib, cross)
lib.export_library(path_lib, cc.cross_compiler(cross))

with open(temp.relpath(graph_name), "w") as graph_file:
logger.debug("writing graph to file to %s", graph_file.name)
logging.debug("writing graph to file to %s", graph_file.name)
graph_file.write(graph)

with open(temp.relpath(param_name), "wb") as params_file:
logger.debug("writing params to file to %s", params_file.name)
logging.debug("writing params to file to %s", params_file.name)
params_file.write(relay.save_param_dict(params))

logger.debug("saving module as tar file to %s", module_path)
logging.debug("saving module as tar file to %s", module_path)
with tarfile.open(module_path, "w") as tar:
tar.add(path_lib, lib_name)
tar.add(temp.relpath(graph_name), graph_name)
Expand Down
17 changes: 6 additions & 11 deletions python/tvm/driver/tvmc/frontends.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,6 @@
from tvm.driver.tvmc.common import TVMCException


# pylint: disable=invalid-name
logger = logging.getLogger("TVMC")


class Frontend(ABC):
"""Abstract class for command line driver frontend.
Expand Down Expand Up @@ -158,7 +154,6 @@ def load(self, path):
# pylint: disable=C0415
import onnx

# pylint: disable=E1101
model = onnx.load(path)

# pylint: disable=E1101
Expand Down Expand Up @@ -196,7 +191,7 @@ def load(self, path):
graph_def.ParseFromString(content)
graph_def = tf_testing.ProcessGraphDefParam(graph_def)

logger.debug("parse TensorFlow model and convert into Relay computation graph")
logging.debug("relay.frontend.from_tensorflow")
return relay.frontend.from_tensorflow(graph_def)


Expand Down Expand Up @@ -239,17 +234,18 @@ def load(self, path):

try:
version = tflite_model.Version()
logger.debug("tflite version %s", version)
logging.debug("tflite version %s", version)
except Exception:
raise TVMCException("input file not tflite")

if version != 3:
raise TVMCException("input file not tflite version 3")

logger.debug("tflite_input_type")
logging.debug("tflite_input_type")
shape_dict, dtype_dict = TFLiteFrontend._input_type(tflite_model)

logger.debug("parse TFLite model and convert into Relay computation graph")
# parse TFLite model and convert into Relay computation graph
logging.debug("relay.frontend.from_tflite")
mod, params = relay.frontend.from_tflite(
tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict
)
Expand Down Expand Up @@ -305,8 +301,7 @@ def load(self, path):

traced_model.eval() # Switch to inference mode
input_shapes = [("input{}".format(idx), shape) for idx, shape in enumerate(shapes)]

logger.debug("parse Torch model and convert into Relay computation graph")
logging.debug("relay.frontend.from_pytorch")
return relay.frontend.from_pytorch(traced_model, input_shapes)


Expand Down
41 changes: 5 additions & 36 deletions tests/python/driver/tvmc/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,12 @@
import pytest
import tarfile

import numpy as np

from PIL import Image

from tvm.driver import tvmc
import tvm.driver.tvmc.compiler

from tvm.contrib.download import download_testdata

from tvm.driver.tvmc.common import convert_graph_layout

# Support functions


Expand All @@ -42,7 +40,7 @@ def download_and_untar(model_url, model_sub_path, temp_dir):


def get_sample_compiled_module(target_dir):
"""Support function that returns a TFLite compiled module"""
"""Support function that retuns a TFLite compiled module"""
base_url = "https://storage.googleapis.com/download.tensorflow.org/models"
model_url = "mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz"
model_file = download_and_untar(
Expand All @@ -51,7 +49,7 @@ def get_sample_compiled_module(target_dir):
temp_dir=target_dir,
)

return tvmc.compiler.compile_model(model_file, target="llvm")
return tvmc.compiler.compile_model(model_file, targets=["llvm"])


# PyTest fixtures
Expand Down Expand Up @@ -112,39 +110,10 @@ def onnx_resnet50():

@pytest.fixture(scope="session")
def tflite_compiled_module_as_tarfile(tmpdir_factory):
    """Session fixture giving the path of a tar file with a compiled sample module.

    Returns an empty string when the tflite package cannot be imported.
    """
    # Not every CI environment has TFLite installed, and tests relying on
    # this fixture would crash without it. pytest.importorskip is not an
    # option inside a fixture, so the import is probed by hand.
    try:
        import tflite
    except ImportError:
        print("Cannot import tflite, which is required by tflite_compiled_module_as_tarfile.")
        return ""

    work_dir = tmpdir_factory.mktemp("data")
    tarball_path = os.path.join(work_dir, "mock.tar")

    # The fourth element of the compile result is not needed here.
    graph, lib, params, _ = get_sample_compiled_module(work_dir)
    tvmc.compiler.save_module(tarball_path, graph, lib, params)

    return tarball_path


@pytest.fixture(scope="session")
def imagenet_cat(tmpdir_factory):
    """Session fixture giving the path of an .npz file with a sample cat image.

    The image is stored under the key "input" as a float32 array with an
    extra leading axis.
    """
    output_dir = tmpdir_factory.mktemp("data")

    cat_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
    downloaded_image = download_testdata(cat_url, "inputs", module=["tvmc"])

    # Resize to 224x224, convert to float32, then add a leading axis.
    pixels = np.asarray(Image.open(downloaded_image).resize((224, 224))).astype("float32")
    pixels = np.expand_dims(pixels, axis=0)

    npz_path = os.path.join(output_dir, "imagenet_cat.npz")
    np.savez(npz_path, input=pixels)

    return npz_path
31 changes: 0 additions & 31 deletions tests/python/driver/tvmc/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,34 +118,3 @@ def _is_layout_transform(node):
tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform)

assert not any(layout_transform_calls), "Unexpected 'layout_transform' call"


def test_tracker_host_port_from_cli__hostname_port():
    """A 'host:port' string is split into the host and the integer port."""
    host, port = tvmc.common.tracker_host_port_from_cli("1.2.3.4:9090")

    assert "1.2.3.4" == host
    assert 9090 == port


def test_tracker_host_port_from_cli__hostname_port__empty():
    """An empty string yields None for both the host and the port."""
    host, port = tvmc.common.tracker_host_port_from_cli("")

    assert host is None
    assert port is None


def test_tracker_host_port_from_cli__only_hostname__default_port_is_9090():
    """A host given without a port falls back to port 9090."""
    host, port = tvmc.common.tracker_host_port_from_cli("1.2.3.4")

    assert "1.2.3.4" == host
    assert 9090 == port

0 comments on commit 52028fd

Please sign in to comment.