From 81eabefce8783b724ac02a30e3bc80ab3e77c2b1 Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Tue, 24 May 2022 12:23:04 -0700 Subject: [PATCH 1/6] Add multiple module support to MLF --- python/tvm/driver/tvmc/model.py | 17 +- python/tvm/micro/model_library_format.py | 286 +++++++++++------- python/tvm/micro/testing/utils.py | 31 +- tests/micro/zephyr/test_zephyr.py | 1 - tests/python/relay/aot/test_crt_aot.py | 37 +-- .../test_micro_model_library_format.py | 263 +++++++++++++--- 6 files changed, 440 insertions(+), 195 deletions(-) diff --git a/python/tvm/driver/tvmc/model.py b/python/tvm/driver/tvmc/model.py index 04946ec9c6d0..5f40d2122312 100644 --- a/python/tvm/driver/tvmc/model.py +++ b/python/tvm/driver/tvmc/model.py @@ -391,9 +391,20 @@ def import_package(self, package_path: str): with open(temp.relpath("metadata.json")) as metadata_json: metadata = json.load(metadata_json) - has_graph_executor = "graph" in metadata["executors"] - graph = temp.relpath("executor-config/graph/graph.json") if has_graph_executor else None - params = temp.relpath(f'parameters/{metadata["model_name"]}.params') + all_module_names = [] + for name in metadata["modules"].keys(): + all_module_names.append(name) + assert len(all_module_names) == 1, "Multiple modules in MLF is not supported." + + module_name = all_module_names[0] + module_metdata = metadata["modules"][module_name] + has_graph_executor = "graph" in module_metdata["executors"] + graph = ( + temp.relpath(f"executor-config/graph/{module_name}.graph") + if has_graph_executor + else None + ) + params = temp.relpath(f"parameters/{module_name}.params") self.type = "mlf" else: diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index 1dd63b319dbd..fece8183c1ef 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -24,6 +24,7 @@ import re import tarfile import typing +from typing import Union import tvm from tvm.ir.type import TupleType @@ -39,6 +40,7 @@ # This should be kept identical to runtime::symbol::tvm_module_main MAIN_FUNC_NAME_STR = "__tvm_main__" STANDALONE_CRT_URL = "./runtime" +METADATA_FILE = "metadata.json" class UnsupportedInModelLibraryFormatError(Exception): @@ -67,56 +69,75 @@ def generate_c_interface_header( EPHEMERAL_MODULE_TYPE_KEYS = ("metadata_module",) -def _populate_codegen_dir(mod, codegen_dir: str, module_name: str = None): +def _populate_codegen_dir( + mods: Union[ + typing.List[executor_factory.ExecutorFactoryModule], typing.List[tvm.runtime.Module] + ], + codegen_dir: str, +): """Populate the codegen sub-directory as part of a Model Library Format export. Parameters ---------- - mod : tvm.runtime.Module - Module which should be written to codegen_dir. + mods : List[tvm.relay.backend.executor_factory.ExecutorFactoryModule], List[tvm.runtime.Module] + A list of the return value of tvm.relay.build, which + will be exported into Model Library Format. codegen_dir : str Path to the codegen directory on disk. module_name: Optional[str] Name used to prefix the generated source files """ - dso_modules = mod._collect_dso_modules() - non_dso_modules = mod._collect_from_import_tree(lambda m: m not in dso_modules) - - # Filter ephemeral modules which cannot be exported. - dso_modules = [m for m in dso_modules if m.type_key not in EPHEMERAL_MODULE_TYPE_KEYS] - non_dso_modules = [m for m in non_dso_modules if m.type_key not in EPHEMERAL_MODULE_TYPE_KEYS] + dso_modules = [] + for mod in mods: + if isinstance(mod, executor_factory.ExecutorFactoryModule): + lib = mod.lib + elif isinstance(mod, tvm.runtime.Module): + lib = mod + + dso_modules = lib._collect_dso_modules() + non_dso_modules = lib._collect_from_import_tree(lambda m: m not in dso_modules) + + # Filter ephemeral modules which cannot be exported. + dso_modules = [m for m in dso_modules if m.type_key not in EPHEMERAL_MODULE_TYPE_KEYS] + non_dso_modules = [ + m for m in non_dso_modules if m.type_key not in EPHEMERAL_MODULE_TYPE_KEYS + ] + + if non_dso_modules: + raise UnsupportedInModelLibraryFormatError( + f"Don't know how to export non-c or non-llvm modules; found: {non_dso_modules!r}" + ) - if non_dso_modules: - raise UnsupportedInModelLibraryFormatError( - f"Don't know how to export non-c or non-llvm modules; found: {non_dso_modules!r}" + mod_indices = {"lib": 0, "src": 0} + host_codegen_dir = os.path.join(codegen_dir, "host") + lib_name = ( + f"{mod.libmod_name}_lib" + if isinstance(mod, executor_factory.ExecutorFactoryModule) + else "lib" ) - mod_indices = {"lib": 0, "src": 0} - host_codegen_dir = os.path.join(codegen_dir, "host") - lib_name = f"{module_name}_lib" if module_name else "lib" - - for dso_mod in dso_modules: - if dso_mod.type_key == "c": - assert dso_mod.format in ["c", "cc", "cpp"] - ext = dso_mod.format - index = mod_indices["src"] - mod_indices["src"] += 1 - parent_dir = os.path.join(host_codegen_dir, "src") - file_name = os.path.join(parent_dir, f"{lib_name}{index}.{ext}") - elif dso_mod.type_key == "llvm": - index = mod_indices["lib"] - mod_indices["lib"] += 1 - parent_dir = os.path.join(host_codegen_dir, "lib") - file_name = os.path.join(parent_dir, f"{lib_name}{index}.o") - else: - assert ( - False - ), f"do not expect module with type_key={mod.type_key} from _collect_dso_modules" - - if not os.path.exists(parent_dir): - os.makedirs(parent_dir) - dso_mod.save(file_name) + for dso_mod in dso_modules: + if dso_mod.type_key == "c": + assert dso_mod.format in ["c", "cc", "cpp"] + ext = dso_mod.format + index = mod_indices["src"] + mod_indices["src"] += 1 + parent_dir = os.path.join(host_codegen_dir, "src") + file_name = os.path.join(parent_dir, f"{lib_name}{index}.{ext}") + elif dso_mod.type_key == "llvm": + index = mod_indices["lib"] + mod_indices["lib"] += 1 + parent_dir = os.path.join(host_codegen_dir, "lib") + file_name = os.path.join(parent_dir, f"{lib_name}{index}.o") + else: + assert ( + False + ), f"do not expect module with type_key={lib.type_key} from _collect_dso_modules" + + if not os.path.exists(parent_dir): + os.makedirs(parent_dir) + dso_mod.save(file_name) def _build_memory_map(mod): @@ -312,86 +333,119 @@ def reset(tarinfo): tar_f.add(get_standalone_crt_dir(), arcname=STANDALONE_CRT_URL) -_GENERATED_VERSION = 6 +_GENERATED_VERSION = 7 + + +def _is_module_names_unique(mods: typing.List[executor_factory.ExecutorFactoryModule]): + """Check if built modules have unique names. + + Parameters + ---------- + mods : List[tvm.relay.backend.executor_factory.ExecutorFactoryModule] + A list of the return value of tvm.relay.build, + which will be exported into Model Library Format. + """ + all_names = [] + for mod in mods: + all_names.append(mod.libmod_name) + + return len(set(all_names)) == len(all_names) def _export_graph_model_library_format( - mod: executor_factory.ExecutorFactoryModule, tempdir: pathlib.Path + mods: typing.List[executor_factory.ExecutorFactoryModule], tempdir: pathlib.Path ): """Export a tvm.relay.build artifact in Model Library Format. Parameters ---------- - mod : tvm.relay.backend.executor_factory.ExecutorFactoryModule - The return value of tvm.relay.build, which will be exported into Model Library Format. + mods : List[tvm.relay.backend.executor_factory.ExecutorFactoryModule] + A list of the return value of tvm.relay.build, + which will be exported into Model Library Format. tempdir : pathlib.Path Temporary directory to populate with Model Library Format contents. """ - is_aot = isinstance(mod, executor_factory.AOTExecutorFactoryModule) - executor = ["aot"] if is_aot else ["graph"] + + assert _is_module_names_unique(mods), "Multiple modules should have unique names." metadata = { "version": _GENERATED_VERSION, - "model_name": mod.libmod_name, - "export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"), - "memory": _build_memory_map(mod), - "target": [str(t) for t in mod.target], - "executors": executor, - "style": "full-model", } - - if is_aot and (str(mod.runtime) == "crt"): - standalone_crt = { - "short_name": "tvm_standalone_crt", - "url": f"{STANDALONE_CRT_URL}", - "url_type": "mlf_path", - "version_spec": f"{tvm.__version__}", + metadata["modules"] = {} + for mod in mods: + is_aot = isinstance(mod, executor_factory.AOTExecutorFactoryModule) + executor = ["aot"] if is_aot else ["graph"] + module_name = mod.libmod_name + metadata["modules"][module_name] = { + "model_name": module_name, + "export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"), + "memory": _build_memory_map(mod), + "target": [str(t) for t in mod.target], + "executors": executor, + "style": "full-model", } - external_dependencies = [standalone_crt] - metadata["external_dependencies"] = external_dependencies - with open(tempdir / "metadata.json", "w") as json_f: + if is_aot and (str(mod.runtime) == "crt"): + standalone_crt = { + "short_name": "tvm_standalone_crt", + "url": f"{STANDALONE_CRT_URL}", + "url_type": "mlf_path", + "version_spec": f"{tvm.__version__}", + } + external_dependencies = [standalone_crt] + metadata["modules"][module_name]["external_dependencies"] = external_dependencies + + with open(tempdir / METADATA_FILE, "w") as json_f: json.dump(metadata, json_f, indent=2, sort_keys=True) codegen_dir = tempdir / "codegen" codegen_dir.mkdir() - _populate_codegen_dir(mod.lib, codegen_dir, mod.libmod_name) - - if _should_generate_interface_header(mod): - include_path = codegen_dir / "host" / "include" - include_path.mkdir() - inputs, outputs = _get_inputs_and_outputs_from_module(mod) - devices = mod.get_devices() - pools = _get_pools_from_module(mod) - io_pool_allocations = _get_io_pool_allocation_from_module(mod) - workspace_size = int(metadata["memory"]["functions"]["main"][0]["workspace_size_bytes"]) - generate_c_interface_header( - mod.libmod_name, - inputs, - outputs, - pools, - io_pool_allocations, - devices, - workspace_size, - include_path, - ) + _populate_codegen_dir(mods, codegen_dir) parameters_dir = tempdir / "parameters" parameters_dir.mkdir() - param_filename = parameters_dir / f"{mod.libmod_name}.params" - with open(param_filename, "wb") as f: - f.write(param_dict.save_param_dict(mod.params)) - src_dir = tempdir / "src" src_dir.mkdir() - with open(src_dir / "relay.txt", "w") as f: - f.write(str(mod.ir_mod)) + graph_config_dir = tempdir / "executor-config" / "graph" + for mod in mods: + if _should_generate_interface_header(mod): + include_path = codegen_dir / "host" / "include" + if not include_path.exists(): + include_path.mkdir() + + inputs, outputs = _get_inputs_and_outputs_from_module(mod) + devices = mod.get_devices() + pools = _get_pools_from_module(mod) + io_pool_allocations = _get_io_pool_allocation_from_module(mod) + workspace_size = int( + metadata["modules"][mod.libmod_name]["memory"]["functions"]["main"][0][ + "workspace_size_bytes" + ] + ) + generate_c_interface_header( + mod.libmod_name, + inputs, + outputs, + pools, + io_pool_allocations, + devices, + workspace_size, + include_path, + ) + + is_aot = isinstance(mod, executor_factory.AOTExecutorFactoryModule) + param_filename = parameters_dir / f"{mod.libmod_name}.params" + with open(param_filename, "wb") as f: + f.write(param_dict.save_param_dict(mod.params)) + + with open(src_dir / f"{mod.libmod_name}.relay", "w") as f: + f.write(str(mod.ir_mod)) - if not is_aot: - graph_config_dir = tempdir / "executor-config" / "graph" - graph_config_dir.mkdir(parents=True) - with open(graph_config_dir / "graph.json", "w") as f: - f.write(mod.get_executor_config()) + if not is_aot: + if not graph_config_dir.exists(): + graph_config_dir.mkdir(parents=True) + with open(graph_config_dir / f"{mod.libmod_name}.graph", "w") as f: + f.write(mod.get_executor_config()) class NonStaticShapeError(Exception): @@ -451,14 +505,11 @@ def _eval_shape(param_name, buffer_shape): def _export_operator_model_library_format(mod: build_module.OperatorModule, tempdir): """Export the result of tvm.build() in Model Library Format. - Parameters ---------- mod : runtime.Module The Module returned from tvm.build(). - args : list of Buffer or Tensor or Var, optional - The args supplied to tvm.build(). - file_name : str + tempdir : str Path to the .tar archive to generate. """ targets = [] @@ -484,12 +535,12 @@ def _export_operator_model_library_format(mod: build_module.OperatorModule, temp "executors": [], "style": "operator", } - with open(tempdir / "metadata.json", "w") as metadata_f: + with open(tempdir / METADATA_FILE, "w") as metadata_f: json.dump(metadata, metadata_f) codegen_dir = tempdir / "codegen" codegen_dir.mkdir() - _populate_codegen_dir(mod, codegen_dir) + _populate_codegen_dir(list([mod]), codegen_dir) ExportableModule = typing.Union[ @@ -499,7 +550,10 @@ def _export_operator_model_library_format(mod: build_module.OperatorModule, temp ] -def export_model_library_format(mod: ExportableModule, file_name: typing.Union[str, pathlib.Path]): +def export_model_library_format( + mods: Union[ExportableModule, typing.List[ExportableModule]], + file_name: typing.Union[str, pathlib.Path], +): """Export the build artifact in Model Library Format. This function creates a .tar archive containing the build artifacts in a standardized @@ -508,7 +562,7 @@ def export_model_library_format(mod: ExportableModule, file_name: typing.Union[s Parameters ---------- - mod : ExportableModule + mod : ExportableModule, List[ExportableModule] The return value of tvm.build or tvm.relay.build. file_name : str Path to the .tar archive to generate. @@ -518,20 +572,36 @@ def export_model_library_format(mod: ExportableModule, file_name: typing.Union[s file_name : str The path to the generated .tar archive. """ - file_name = pathlib.Path(file_name) + modules = mods + if not isinstance(mods, list): + modules = list([mods]) + + operator_module_type = all(isinstance(mod, build_module.OperatorModule) for mod in modules) + graph_module_type = all( + isinstance( + mod, + ( + executor_factory.AOTExecutorFactoryModule, + executor_factory.GraphExecutorFactoryModule, + ), + ) + for mod in modules + ) + file_name = pathlib.Path(file_name) tempdir = utils.tempdir() - if isinstance(mod, build_module.OperatorModule): - _export_operator_model_library_format(mod, tempdir.path) - elif isinstance( - mod, - (executor_factory.AOTExecutorFactoryModule, executor_factory.GraphExecutorFactoryModule), - ): - _export_graph_model_library_format(mod, tempdir.path) + if operator_module_type: + if len(modules) != 1: + raise RuntimeError("Multiple operator is not supported.") + _export_operator_model_library_format(modules[0], tempdir.path) + elif graph_module_type: + _export_graph_model_library_format(modules, tempdir.path) else: - raise NotImplementedError(f"Don't know how to export module of type {mod.__class__!r}") + raise NotImplementedError( + f"Don't know how to export module of type {modules[0].__class__!r}" + ) - _make_tar(tempdir.path, file_name, mod) + _make_tar(tempdir.path, file_name, modules) return file_name diff --git a/python/tvm/micro/testing/utils.py b/python/tvm/micro/testing/utils.py index 81e29a92a86a..a48c8dc3230f 100644 --- a/python/tvm/micro/testing/utils.py +++ b/python/tvm/micro/testing/utils.py @@ -24,6 +24,8 @@ import time from typing import Union +import tvm +from tvm import relay from tvm.micro.project_api.server import IoTimeoutError # Timeout in seconds for AOT transport. @@ -77,9 +79,36 @@ def _read_line(transport, timeout_sec: int) -> str: def mlf_extract_workspace_size_bytes(mlf_tar_path: Union[pathlib.Path, str]) -> int: """Extract an MLF archive file and read workspace size from metadata file.""" + workspace_size = 0 with tarfile.open(mlf_tar_path, "r:*") as tar_file: tar_members = [ti.name for ti in tar_file.getmembers()] assert "./metadata.json" in tar_members with tar_file.extractfile("./metadata.json") as f: metadata = json.load(f) - return metadata["memory"]["functions"]["main"][0]["workspace_size_bytes"] + for mod_name in metadata["modules"].keys(): + workspace_size += metadata["modules"][mod_name]["memory"]["functions"]["main"][0][ + "workspace_size_bytes" + ] + return workspace_size + + +def get_conv2d_relay_module(): + """Generate a conv2d Relay module for testing.""" + data_shape = (1, 3, 64, 64) + weight_shape = (8, 3, 5, 5) + data = relay.var("data", relay.TensorType(data_shape, "int8")) + weight = relay.var("weight", relay.TensorType(weight_shape, "int8")) + y = relay.nn.conv2d( + data, + weight, + padding=(2, 2), + channels=8, + kernel_size=(5, 5), + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32", + ) + f = relay.Function([data, weight], y) + mod = tvm.IRModule.from_expr(f) + mod = relay.transform.InferType()(mod) + return mod diff --git a/tests/micro/zephyr/test_zephyr.py b/tests/micro/zephyr/test_zephyr.py index 2651435434b1..05c8daa20c21 100644 --- a/tests/micro/zephyr/test_zephyr.py +++ b/tests/micro/zephyr/test_zephyr.py @@ -17,7 +17,6 @@ import logging import os import pathlib -import sys import logging import pytest diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 1a4f23ad467a..a46e7925fe6f 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -44,6 +44,7 @@ create_relay_module_and_inputs_from_tflite_file, ) from tvm.micro.testing.aot_test_utils import AOT_DEFAULT_RUNNER, parametrize_aot_options +from tvm.micro.testing.utils import get_conv2d_relay_module def test_error_c_interface_with_packed_api(): @@ -75,23 +76,7 @@ def test_error_c_interface_with_packed_api(): @parametrize_aot_options def test_conv_with_params(interface_api, use_unpacked_api, test_runner): - """Tests compilation of convolution with parameters""" - relay_model = """ -#[version = "0.0.5"] -def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(8, 3, 5, 5), int8]) { - %1 = nn.conv2d( - %data, - %weight, - padding=[2, 2], - channels=8, - kernel_size=[5, 5], - data_layout="NCHW", - kernel_layout="OIHW", - out_dtype="int32"); - %1 -} -""" - mod = tvm.parser.fromtext(relay_model) + mod = get_conv2d_relay_module() main_func = mod["main"] shape_dict = {p.name_hint: p.checked_type.concrete_shape for p in main_func.params} type_dict = {p.name_hint: p.checked_type.dtype for p in main_func.params} @@ -576,23 +561,7 @@ def test_multiple_models(interface_api, use_unpacked_api, test_runner): params1 = None # Convolution model - relay_model = """ - #[version = "0.0.5"] - def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(8, 3, 5, 5), int8]) { - %1 = nn.conv2d( - %data, - %weight, - padding=[2, 2], - channels=8, - kernel_size=[5, 5], - data_layout="NCHW", - kernel_layout="OIHW", - out_dtype="int32"); - %1 - } - """ - - mod2 = tvm.parser.fromtext(relay_model) + mod2 = get_conv2d_relay_module() main_func = mod2["main"] shape_dict = {p.name_hint: p.checked_type.concrete_shape for p in main_func.params} type_dict = {p.name_hint: p.checked_type.dtype for p in main_func.params} diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index d707e6b4646b..3a07459ad028 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -15,15 +15,19 @@ # specific language governing permissions and limitations # under the License. +import pathlib +import sys import datetime import json import os import tarfile -import numpy +import numpy as np import pytest import platform +pytest.importorskip("tvm.micro") + import tvm import tvm.relay from tvm.relay.backend import Executor, Runtime @@ -31,12 +35,14 @@ import tvm.runtime.module import tvm.testing from tvm.contrib import utils +import tvm.micro as micro +from tvm.micro.testing.utils import get_conv2d_relay_module +import tvm.micro.model_library_format as model_library_format +from tvm.micro.model_library_format import _GENERATED_VERSION @tvm.testing.requires_micro def test_export_operator_model_library_format(): - import tvm.micro as micro - target = tvm.target.target.micro("host") with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): A = tvm.te.placeholder((2,), dtype="int8") @@ -63,7 +69,7 @@ def test_export_operator_model_library_format(): with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) - assert metadata["version"] == 6 + assert metadata["version"] == _GENERATED_VERSION assert metadata["model_name"] == "add" export_datetime = datetime.datetime.strptime( metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" @@ -95,8 +101,35 @@ def test_export_operator_model_library_format(): assert tir_f.read() == str(ir_mod) +@tvm.testing.requires_micro +def test_export_multiple_operator_model_library_format(): + target = tvm.target.target.micro("host") + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + A = tvm.te.placeholder((2,), dtype="int8") + B = tvm.te.placeholder((1,), dtype="int8") + C = tvm.te.compute(A.shape, lambda i: A[i] + B[0], name="C") + sched = tvm.te.create_schedule(C.op) + mod = tvm.build( + sched, + [A, B, C], + tvm.target.Target(target, target), + runtime=Runtime("crt", {"system-lib": True}), + name="add", + ) + + temp_dir = utils.tempdir() + mlf_tar_path = temp_dir.relpath("lib.tar") + + with pytest.raises(RuntimeError) as exc: + micro.export_model_library_format([mod, mod], mlf_tar_path) + + assert str(exc.exception) == ("Multiple operator is not supported.") + + def validate_graph_json(extract_dir, factory): - with open(os.path.join(extract_dir, "executor-config", "graph", "graph.json")) as graph_f: + with open( + os.path.join(extract_dir, "executor-config", "graph", f"{factory.libmod_name}.graph") + ) as graph_f: graph_json = graph_f.read() assert graph_json == factory.graph_json @@ -141,12 +174,11 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ executor=executor, runtime=runtime, mod_name="add", - params={"c": numpy.array([[2.0, 4.0]], dtype="float32")}, + params={"c": np.array([[2.0, 4.0]], dtype="float32")}, ) temp_dir = utils.tempdir() mlf_tar_path = temp_dir.relpath("lib.tar") - import tvm.micro as micro micro.export_model_library_format(factory, mlf_tar_path) tf = tarfile.open(mlf_tar_path) @@ -157,21 +189,22 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) - assert metadata["version"] == 6 - assert metadata["model_name"] == "add" + module_name = factory.libmod_name + assert metadata["version"] == _GENERATED_VERSION + assert metadata["modules"][module_name]["model_name"] == "add" export_datetime = datetime.datetime.strptime( - metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" + metadata["modules"][module_name]["export_datetime"], "%Y-%m-%d %H:%M:%SZ" ) assert (datetime.datetime.now() - export_datetime) < datetime.timedelta(seconds=60 * 5) - assert metadata["target"] == [str(target)] + assert metadata["modules"][module_name]["target"] == [str(target)] if str(executor) == "graph": - assert metadata["memory"]["sids"] == [ + assert metadata["modules"][module_name]["memory"]["sids"] == [ {"storage_id": 0, "size_bytes": 2, "input_binding": "a"}, {"storage_id": 1, "size_bytes": 8, "input_binding": "b"}, {"storage_id": 2, "size_bytes": 8, "input_binding": "p0"}, {"storage_id": 3, "size_bytes": 8}, ] - assert metadata["memory"]["functions"]["main"] == [ + assert metadata["modules"][module_name]["memory"]["functions"]["main"] == [ { "constants_size_bytes": json_constants_size_bytes, "device": 1, @@ -179,12 +212,14 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ "workspace_size_bytes": 0, } ] - assert metadata["memory"]["functions"]["operator_functions"][0]["workspace"] == [ - {"device": 1, "workspace_size_bytes": 0} - ] + assert metadata["modules"][module_name]["memory"]["functions"]["operator_functions"][0][ + "workspace" + ] == [{"device": 1, "workspace_size_bytes": 0}] assert ( "fused_cast_multiply_add" - in metadata["memory"]["functions"]["operator_functions"][0]["function_name"] + in metadata["modules"][module_name]["memory"]["functions"]["operator_functions"][0][ + "function_name" + ] ) assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "add_lib0.c")) @@ -196,7 +231,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ if str(executor) == "graph": validate_graph_json(extract_dir, factory) - with open(os.path.join(extract_dir, "src", "relay.txt")) as relay_f: + with open(os.path.join(extract_dir, "src", f"{module_name}.relay")) as relay_f: assert relay_f.read() == str(relay_mod) with open(os.path.join(extract_dir, "parameters", "add.params"), "rb") as params_f: @@ -227,12 +262,11 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ target, runtime=Runtime("crt", {"system-lib": True}), mod_name="add", - params={"c": numpy.array([[2.0, 4.0]], dtype="float32")}, + params={"c": np.array([[2.0, 4.0]], dtype="float32")}, ) temp_dir = utils.tempdir() mlf_tar_path = temp_dir.relpath("lib.tar") - import tvm.micro as micro micro.export_model_library_format(factory, mlf_tar_path) tf = tarfile.open(mlf_tar_path) @@ -243,20 +277,21 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) - assert metadata["version"] == 6 - assert metadata["model_name"] == "add" + module_name = factory.libmod_name + assert metadata["version"] == _GENERATED_VERSION + assert metadata["modules"][module_name]["model_name"] == "add" export_datetime = datetime.datetime.strptime( - metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" + metadata["modules"][module_name]["export_datetime"], "%Y-%m-%d %H:%M:%SZ" ) assert (datetime.datetime.now() - export_datetime) < datetime.timedelta(seconds=60 * 5) - assert metadata["target"] == [str(target)] - assert metadata["memory"]["sids"] == [ + assert metadata["modules"][module_name]["target"] == [str(target)] + assert metadata["modules"][module_name]["memory"]["sids"] == [ {"storage_id": 0, "size_bytes": 2, "input_binding": "a"}, {"storage_id": 1, "size_bytes": 8, "input_binding": "b"}, {"storage_id": 2, "size_bytes": 8, "input_binding": "p0"}, {"storage_id": 3, "size_bytes": 8}, ] - assert metadata["memory"]["functions"]["main"] == [ + assert metadata["modules"][module_name]["memory"]["functions"]["main"] == [ { "constants_size_bytes": 8, "device": 1, @@ -264,19 +299,21 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ "workspace_size_bytes": 0, } ] - assert metadata["memory"]["functions"]["operator_functions"][0]["workspace"] == [ - {"device": 1, "workspace_size_bytes": 0} - ] + assert metadata["modules"][module_name]["memory"]["functions"]["operator_functions"][0][ + "workspace" + ] == [{"device": 1, "workspace_size_bytes": 0}] assert ( "fused_cast_multiply_add" - in metadata["memory"]["functions"]["operator_functions"][0]["function_name"] + in metadata["modules"][module_name]["memory"]["functions"]["operator_functions"][0][ + "function_name" + ] ) assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "lib", "add_lib0.o")) validate_graph_json(extract_dir, factory) - with open(os.path.join(extract_dir, "src", "relay.txt")) as relay_f: + with open(os.path.join(extract_dir, "src", f"{module_name}.relay")) as relay_f: assert relay_f.read() == str(relay_mod) with open(os.path.join(extract_dir, "parameters", "add.params"), "rb") as params_f: @@ -314,7 +351,6 @@ def @main(%p0: Tensor[(1, 56, 56, 128), int16], %p1: Tensor[(3, 3, 128, 1), int1 temp_dir = utils.tempdir() mlf_tar_path = temp_dir.relpath("lib.tar") - import tvm.micro as micro micro.export_model_library_format(factory, mlf_tar_path) tf = tarfile.open(mlf_tar_path) @@ -325,14 +361,15 @@ def @main(%p0: Tensor[(1, 56, 56, 128), int16], %p1: Tensor[(3, 3, 128, 1), int1 with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) - assert metadata["version"] == 6 - assert metadata["model_name"] == "qnn_conv2d" + module_name = factory.libmod_name + assert metadata["version"] == _GENERATED_VERSION + assert metadata["modules"][module_name]["model_name"] == "qnn_conv2d" export_datetime = datetime.datetime.strptime( - metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" + metadata["modules"][module_name]["export_datetime"], "%Y-%m-%d %H:%M:%SZ" ) assert (datetime.datetime.now() - export_datetime) < datetime.timedelta(seconds=60 * 5) - assert metadata["target"] == [str(target)] - assert metadata["memory"]["functions"]["main"] == [ + assert metadata["modules"][module_name]["target"] == [str(target)] + assert metadata["modules"][module_name]["memory"]["functions"]["main"] == [ { "constants_size_bytes": 0, "device": 1, @@ -340,12 +377,14 @@ def @main(%p0: Tensor[(1, 56, 56, 128), int16], %p1: Tensor[(3, 3, 128, 1), int1 "workspace_size_bytes": 2466816, } ] - assert metadata["memory"]["functions"]["operator_functions"][0]["workspace"] == [ - {"device": 1, "workspace_size_bytes": 2466816} - ] + assert metadata["modules"][module_name]["memory"]["functions"]["operator_functions"][0][ + "workspace" + ] == [{"device": 1, "workspace_size_bytes": 2466816}] assert ( "fused_nn_conv2d_add_fixed_point_multiply_clip_cast" - in metadata["memory"]["functions"]["operator_functions"][0]["function_name"] + in metadata["modules"][module_name]["memory"]["functions"]["operator_functions"][0][ + "function_name" + ] ) @@ -354,11 +393,9 @@ def test_export_non_dso_exportable(): module = tvm.support.FrontendTestModule() temp_dir = utils.tempdir() - import tvm.micro as micro - import tvm.micro.model_library_format as model_library_format with pytest.raises(micro.UnsupportedInModelLibraryFormatError) as exc: - model_library_format._populate_codegen_dir(module, temp_dir.relpath("codegen")) + model_library_format._populate_codegen_dir([module], temp_dir.relpath("codegen")) assert str(exc.exception) == ( "Don't know how to export non-c or non-llvm modules; found: ffi_testing" @@ -408,8 +445,6 @@ def test_export_byoc_c_module(): temp_dir = utils.tempdir() mlf_tar_path = temp_dir.relpath("lib.tar") - from tvm import micro - micro.export_model_library_format(factory, mlf_tar_path) with tarfile.open(mlf_tar_path, "r:*") as tf: @@ -418,7 +453,7 @@ def test_export_byoc_c_module(): assert "./metadata.json" in tar_members with tf.extractfile("./metadata.json") as f: metadata = json.load(f) - main_md = metadata["memory"]["functions"]["main"] + main_md = metadata["modules"][factory.libmod_name]["memory"]["functions"]["main"] if platform.architecture()[0] == "64bit": assert main_md == [ { @@ -439,5 +474,137 @@ def test_export_byoc_c_module(): ] +@tvm.testing.requires_micro +def test_multiple_relay_modules_same_module_name(): + mod = get_conv2d_relay_module() + + executor = Executor("graph") + runtime = Runtime("crt") + target = tvm.target.target.micro("host") + + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + factory1 = tvm.relay.build(mod, target, runtime=runtime, executor=executor, mod_name="mod") + factory2 = tvm.relay.build(mod, target, runtime=runtime, executor=executor, mod_name="mod") + + temp_dir = utils.tempdir() + mlf_tar_path = temp_dir.relpath("lib.tar") + + with pytest.raises(AssertionError, match="Multiple modules should have unique names"): + micro.export_model_library_format([factory1, factory2], mlf_tar_path) + + +@tvm.testing.requires_micro +def test_multiple_relay_modules_graph(): + mod = get_conv2d_relay_module() + + executor = Executor("graph") + runtime = Runtime("crt") + target = tvm.target.target.micro("host") + + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + factory1 = tvm.relay.build(mod, target, runtime=runtime, executor=executor, mod_name="mod1") + factory2 = tvm.relay.build(mod, target, runtime=runtime, executor=executor, mod_name="mod2") + + temp_dir = utils.tempdir() + mlf_tar_path = temp_dir.relpath("lib.tar") + micro.export_model_library_format([factory1, factory2], mlf_tar_path) + + with tarfile.open(mlf_tar_path, "r:*") as tf: + tar_members = [ti.name for ti in tf.getmembers()] + print("tar members", tar_members) + assert "./metadata.json" in tar_members + assert "./codegen/host/src/mod1_lib0.c" in tar_members + assert "./codegen/host/src/mod2_lib0.c" in tar_members + + with tf.extractfile("./metadata.json") as f: + metadata = json.load(f) + mod2_main_md = metadata["modules"]["mod2"]["memory"]["functions"]["main"] + assert mod2_main_md == [ + { + "constants_size_bytes": 0, + "device": 1, + "io_size_bytes": 143960, + "workspace_size_bytes": 158088, + } + ] + assert metadata["modules"]["mod1"]["model_name"] == "mod1" + assert metadata["modules"]["mod2"]["model_name"] == "mod2" + + +@tvm.testing.requires_micro +def test_multiple_relay_modules_c(): + mod = get_conv2d_relay_module() + + executor = Executor("aot", {"unpacked-api": True, "interface-api": "c"}) + runtime = Runtime("crt") + target = tvm.target.target.micro("host") + + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + factory1 = tvm.relay.build(mod, target, runtime=runtime, executor=executor, mod_name="mod1") + factory2 = tvm.relay.build(mod, target, runtime=runtime, executor=executor, mod_name="mod2") + + temp_dir = utils.tempdir() + mlf_tar_path = temp_dir.relpath("lib.tar") + + micro.export_model_library_format([factory1, factory2], mlf_tar_path) + + tf = tarfile.open(mlf_tar_path) + + extract_dir = temp_dir.relpath("extract") + os.mkdir(extract_dir) + tf.extractall(extract_dir) + + assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "mod1_lib0.c")) + assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "mod1_lib1.c")) + assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "mod2_lib0.c")) + assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "mod2_lib1.c")) + + assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "include", "tvmgen_mod1.h")) + assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "include", "tvmgen_mod2.h")) + + +@tvm.testing.requires_micro +def test_multiple_relay_modules_aot_graph(): + mod = get_conv2d_relay_module() + + executor1 = Executor("graph") + executor2 = Executor("aot", {"unpacked-api": True, "interface-api": "c"}) + runtime = Runtime("crt") + target = tvm.target.target.micro("host") + + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + factory1 = tvm.relay.build( + mod, target, runtime=runtime, executor=executor1, mod_name="mod1" + ) + factory2 = tvm.relay.build( + mod, target, runtime=runtime, executor=executor2, mod_name="mod2" + ) + + temp_dir = utils.tempdir() + mlf_tar_path = temp_dir.relpath("lib.tar") + + micro.export_model_library_format([factory1, factory2], mlf_tar_path) + + tf = tarfile.open(mlf_tar_path) + extract_dir = temp_dir.relpath("extract") + os.mkdir(extract_dir) + tf.extractall(extract_dir) + + assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "mod1_lib0.c")) + assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "mod1_lib1.c")) + assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "mod1_lib2.c")) + assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "mod2_lib0.c")) + assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "mod2_lib1.c")) + + assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "include", "tvmgen_mod2.h")) + + with open(os.path.join(extract_dir, "metadata.json")) as f: + metadata = json.load(f) + + assert metadata["modules"]["mod1"]["executors"] == ["graph"] + assert metadata["modules"]["mod2"]["executors"] == ["aot"] + assert metadata["version"] == _GENERATED_VERSION + + if __name__ == "__main__": - tvm.testing.main() + sys.exit(pytest.main([__file__] + sys.argv[1:])) From 2cbfe5e63dddc89c91cce5846549d311e90d475b Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Thu, 2 Jun 2022 09:25:33 -0700 Subject: [PATCH 2/6] Fix pytest refactor --- tests/python/relay/strategy/arm_cpu/test_avg_pool.py | 4 +--- tests/python/relay/strategy/arm_cpu/test_conv1d_ncw.py | 4 +--- tests/python/relay/strategy/arm_cpu/test_conv1d_nwc.py | 4 +--- tests/python/relay/strategy/arm_cpu/test_conv2d_NCHWc.py | 4 +--- tests/python/relay/strategy/arm_cpu/test_dense_dsp.py | 4 +--- tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py | 2 -- .../relay/strategy/arm_cpu/test_depthwise_conv2d_NCHWc.py | 4 +--- tests/python/relay/strategy/arm_cpu/test_group_conv2d.py | 2 -- tests/python/relay/strategy/arm_cpu/test_max_pool.py | 5 +---- 9 files changed, 7 insertions(+), 26 deletions(-) diff --git a/tests/python/relay/strategy/arm_cpu/test_avg_pool.py b/tests/python/relay/strategy/arm_cpu/test_avg_pool.py index 31a812b38eed..3d6690a1a16f 100644 --- a/tests/python/relay/strategy/arm_cpu/test_avg_pool.py +++ b/tests/python/relay/strategy/arm_cpu/test_avg_pool.py @@ -14,9 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import sys import numpy as np -import pytest import tvm import tvm.testing from tvm import relay @@ -165,4 +163,4 @@ class TestAvgPool3d(BasicPoolTests): if __name__ == "__main__": - sys.exit(pytest.main([__file__] + sys.argv[1:])) + tvm.testing.main() diff --git a/tests/python/relay/strategy/arm_cpu/test_conv1d_ncw.py b/tests/python/relay/strategy/arm_cpu/test_conv1d_ncw.py index 0f0507cfe7d3..b1dda10c4294 100644 --- a/tests/python/relay/strategy/arm_cpu/test_conv1d_ncw.py +++ b/tests/python/relay/strategy/arm_cpu/test_conv1d_ncw.py @@ -14,9 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import sys import numpy as np -import pytest import tvm import tvm.testing from tvm import relay @@ -114,4 +112,4 @@ class TestConv1d_ncw(BasicConv1dTests): if __name__ == "__main__": - sys.exit(pytest.main([__file__] + sys.argv[1:])) + tvm.testing.main() diff --git a/tests/python/relay/strategy/arm_cpu/test_conv1d_nwc.py b/tests/python/relay/strategy/arm_cpu/test_conv1d_nwc.py index e430ade2fac1..3daed6221f68 100644 --- a/tests/python/relay/strategy/arm_cpu/test_conv1d_nwc.py +++ b/tests/python/relay/strategy/arm_cpu/test_conv1d_nwc.py @@ -14,9 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import sys import numpy as np -import pytest import tvm import tvm.testing from tvm import relay @@ -142,4 +140,4 @@ class TestConv1d_nwc(BasicConv1dTests): if __name__ == "__main__": - sys.exit(pytest.main([__file__] + sys.argv[1:])) + tvm.testing.main() diff --git a/tests/python/relay/strategy/arm_cpu/test_conv2d_NCHWc.py b/tests/python/relay/strategy/arm_cpu/test_conv2d_NCHWc.py index 3b43d37c9075..8ca132ffba75 100644 --- a/tests/python/relay/strategy/arm_cpu/test_conv2d_NCHWc.py +++ b/tests/python/relay/strategy/arm_cpu/test_conv2d_NCHWc.py @@ -14,9 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import sys import numpy as np -import pytest import tvm import tvm.testing from tvm import relay @@ -135,4 +133,4 @@ class TestConv2d_NCHWc(BasicConv2dTests): if __name__ == "__main__": - sys.exit(pytest.main([__file__] + sys.argv[1:])) + tvm.testing.main() diff --git a/tests/python/relay/strategy/arm_cpu/test_dense_dsp.py b/tests/python/relay/strategy/arm_cpu/test_dense_dsp.py index 3edffba8acaa..a69ea6c09e79 100644 --- a/tests/python/relay/strategy/arm_cpu/test_dense_dsp.py +++ b/tests/python/relay/strategy/arm_cpu/test_dense_dsp.py @@ -14,9 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import sys import numpy as np -import pytest import tvm import tvm.testing from tvm import relay @@ -87,4 +85,4 @@ class TestDense(BasicDenseTests): if __name__ == "__main__": - sys.exit(pytest.main([__file__] + sys.argv[1:])) + tvm.testing.main() diff --git a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py index 96628a6371d0..ee0d51c321f7 100644 --- a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py +++ b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py @@ -14,9 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import sys import numpy as np -import pytest import tvm import tvm.testing from tvm import relay diff --git a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d_NCHWc.py b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d_NCHWc.py index 69e9ab09e4c9..178b44edbd40 100644 --- a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d_NCHWc.py +++ b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d_NCHWc.py @@ -14,9 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import sys import numpy as np -import pytest import tvm import tvm.testing from tvm import relay @@ -118,4 +116,4 @@ class TestDepthWiseConv2d_NCHWc(BasicConv2dTests): if __name__ == "__main__": - sys.exit(pytest.main([__file__] + sys.argv[1:])) + tvm.testing.main() diff --git a/tests/python/relay/strategy/arm_cpu/test_group_conv2d.py b/tests/python/relay/strategy/arm_cpu/test_group_conv2d.py index b24c651de988..47fe6d9f74c2 100644 --- a/tests/python/relay/strategy/arm_cpu/test_group_conv2d.py +++ b/tests/python/relay/strategy/arm_cpu/test_group_conv2d.py @@ -14,9 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import sys import numpy as np -import pytest import tvm import tvm.testing from tvm import relay diff --git a/tests/python/relay/strategy/arm_cpu/test_max_pool.py b/tests/python/relay/strategy/arm_cpu/test_max_pool.py index f58a041ecb74..ee890261d1b4 100644 --- a/tests/python/relay/strategy/arm_cpu/test_max_pool.py +++ b/tests/python/relay/strategy/arm_cpu/test_max_pool.py @@ -14,10 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from pickle import FALSE -import sys import numpy as np -import pytest import tvm import tvm.testing from tvm import relay @@ -132,4 +129,4 @@ class TestMaxPool3d(BasicPoolTests): if __name__ == "__main__": - sys.exit(pytest.main([__file__] + sys.argv[1:])) + tvm.testing.main() From 4c02084d9494664cec2089fc01e7b3b8302e44df Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Thu, 2 Jun 2022 09:57:40 -0700 Subject: [PATCH 3/6] fix errors --- .../template_project/microtvm_api_server.py | 17 ++++++++++++----- python/tvm/micro/contrib/stm32/emitter.py | 12 +++++++++++- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/apps/microtvm/arduino/template_project/microtvm_api_server.py b/apps/microtvm/arduino/template_project/microtvm_api_server.py index 131f92a20829..0e922f06cb51 100644 --- a/apps/microtvm/arduino/template_project/microtvm_api_server.py +++ b/apps/microtvm/arduino/template_project/microtvm_api_server.py @@ -214,14 +214,21 @@ def _template_model_header(self, source_dir, metadata): with open(source_dir / "model.h", "r") as f: model_h_template = Template(f.read()) - assert ( - metadata["style"] == "full-model" + all_module_names = [] + for name in metadata["modules"].keys(): + all_module_names.append(name) + + assert all( + metadata["modules"][mod_name]["style"] == "full-model" for mod_name in all_module_names ), "when generating AOT, expect only full-model Model Library Format" - template_values = { - "workspace_size_bytes": metadata["memory"]["functions"]["main"][0][ + workspace_size_bytes = 0 + for mod_name in all_module_names: + workspace_size_bytes += metadata["modules"][mod_name]["memory"]["functions"]["main"][0][ "workspace_size_bytes" - ], + ] + template_values = { + "workspace_size_bytes": workspace_size_bytes, } with open(source_dir / "model.h", "w") as f: diff --git a/python/tvm/micro/contrib/stm32/emitter.py b/python/tvm/micro/contrib/stm32/emitter.py index aec5912871fd..b8c41eb402eb 100644 --- a/python/tvm/micro/contrib/stm32/emitter.py +++ b/python/tvm/micro/contrib/stm32/emitter.py @@ -482,8 +482,18 @@ def parse_library_format(self, model_library_format_path, quantization=None): with tarfile.TarFile(model_library_format_path) as f: f.extractall(extract_path) + with open(os.path.join(extract_path, "metadata.json")) as metadata_f: + metadata = json.load(metadata_f) + + all_module_names = [] + for name in metadata["modules"].keys(): + all_module_names.append(name) + assert len(all_module_names) == 1, "Multiple modules is not supported." + # Extract informations from the Model Library Format - graph_file = os.path.join(extract_path, "executor-config", "graph", "graph.json") + graph_file = os.path.join( + extract_path, "executor-config", "graph", f"{all_module_names[0]}.graph" + ) with open(graph_file, "r") as f: # returns JSON object as a dictionary graph_dict = json.load(f) From ef53c11991fde3396a88a30fb51272d940b5fbb6 Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Thu, 2 Jun 2022 16:21:52 -0700 Subject: [PATCH 4/6] fix runtime directory --- python/tvm/micro/model_library_format.py | 11 +++++++---- .../unittest/test_micro_model_library_format.py | 3 +++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index fece8183c1ef..9fa5b0ba01fb 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -318,7 +318,7 @@ def _should_generate_interface_header(mod): return "interface-api" in mod.executor and mod.executor["interface-api"] == "c" -def _make_tar(source_dir, tar_file_path, mod): +def _make_tar(source_dir, tar_file_path, modules): """Build a tar file from source_dir.""" with tarfile.open(tar_file_path, "w") as tar_f: @@ -328,9 +328,12 @@ def reset(tarinfo): return tarinfo tar_f.add(str(source_dir), arcname=".", filter=reset) - is_aot = isinstance(mod, executor_factory.AOTExecutorFactoryModule) - if is_aot and str(mod.runtime) == "crt": - tar_f.add(get_standalone_crt_dir(), arcname=STANDALONE_CRT_URL) + + for mod in modules: + is_aot = isinstance(mod, executor_factory.AOTExecutorFactoryModule) + if is_aot and str(mod.runtime) == "crt": + tar_f.add(get_standalone_crt_dir(), arcname=STANDALONE_CRT_URL) + break _GENERATED_VERSION = 7 diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index 3a07459ad028..0caae1cdd9d4 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -562,6 +562,9 @@ def test_multiple_relay_modules_c(): assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "include", "tvmgen_mod1.h")) assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "include", "tvmgen_mod2.h")) + # check CRT runtime directory + assert os.path.exists(os.path.join(extract_dir, "runtime")) + @tvm.testing.requires_micro def test_multiple_relay_modules_aot_graph(): From a7560f1efadcb591324f6801108631b345ccb655 Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Tue, 14 Jun 2022 11:24:40 -0700 Subject: [PATCH 5/6] address comments --- python/tvm/micro/contrib/stm32/emitter.py | 2 +- python/tvm/micro/model_library_format.py | 12 +++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/python/tvm/micro/contrib/stm32/emitter.py b/python/tvm/micro/contrib/stm32/emitter.py index b8c41eb402eb..814f98f1b788 100644 --- a/python/tvm/micro/contrib/stm32/emitter.py +++ b/python/tvm/micro/contrib/stm32/emitter.py @@ -488,7 +488,7 @@ def parse_library_format(self, model_library_format_path, quantization=None): all_module_names = [] for name in metadata["modules"].keys(): all_module_names.append(name) - assert len(all_module_names) == 1, "Multiple modules is not supported." + assert len(metadata["modules"]) == 1, "Multiple modules is not supported." # Extract informations from the Model Library Format graph_file = os.path.join( diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index 9fa5b0ba01fb..d13947527e81 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -24,7 +24,6 @@ import re import tarfile import typing -from typing import Union import tvm from tvm.ir.type import TupleType @@ -70,8 +69,9 @@ def generate_c_interface_header( def _populate_codegen_dir( - mods: Union[ - typing.List[executor_factory.ExecutorFactoryModule], typing.List[tvm.runtime.Module] + mods: typing.Union[ + typing.List[executor_factory.ExecutorFactoryModule], + typing.List[build_module.OperatorModule], ], codegen_dir: str, ): @@ -92,8 +92,10 @@ def _populate_codegen_dir( for mod in mods: if isinstance(mod, executor_factory.ExecutorFactoryModule): lib = mod.lib - elif isinstance(mod, tvm.runtime.Module): + elif isinstance(mod, build_module.OperatorModule): lib = mod + else: + raise RuntimeError(f"Not supported module type: {type(mod)}") dso_modules = lib._collect_dso_modules() non_dso_modules = lib._collect_from_import_tree(lambda m: m not in dso_modules) @@ -554,7 +556,7 @@ def _export_operator_model_library_format(mod: build_module.OperatorModule, temp def export_model_library_format( - mods: Union[ExportableModule, typing.List[ExportableModule]], + mods: typing.Union[ExportableModule, typing.List[ExportableModule]], file_name: typing.Union[str, pathlib.Path], ): """Export the build artifact in Model Library Format. From 36e38f9db7181e5736c7ede7ac88be8c23e0301d Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Tue, 14 Jun 2022 14:56:56 -0700 Subject: [PATCH 6/6] revert to tvm.runtime.Module --- python/tvm/micro/model_library_format.py | 4 ++-- tests/python/relay/aot/test_crt_aot.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index d13947527e81..e220fa1ca543 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -71,7 +71,7 @@ def generate_c_interface_header( def _populate_codegen_dir( mods: typing.Union[ typing.List[executor_factory.ExecutorFactoryModule], - typing.List[build_module.OperatorModule], + typing.List[tvm.runtime.Module], ], codegen_dir: str, ): @@ -92,7 +92,7 @@ def _populate_codegen_dir( for mod in mods: if isinstance(mod, executor_factory.ExecutorFactoryModule): lib = mod.lib - elif isinstance(mod, build_module.OperatorModule): + elif isinstance(mod, tvm.runtime.Module): lib = mod else: raise RuntimeError(f"Not supported module type: {type(mod)}") diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index a46e7925fe6f..987d425aa63d 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -76,6 +76,7 @@ def test_error_c_interface_with_packed_api(): @parametrize_aot_options def test_conv_with_params(interface_api, use_unpacked_api, test_runner): + """Tests compilation of convolution with parameters""" mod = get_conv2d_relay_module() main_func = mod["main"] shape_dict = {p.name_hint: p.checked_type.concrete_shape for p in main_func.params}