Skip to content

Commit

Permalink
[MLF] Add support for multiple modules in Model Library Format (#11464)
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrdadh authored Jun 17, 2022
1 parent c5465d8 commit 648154d
Show file tree
Hide file tree
Showing 17 changed files with 481 additions and 229 deletions.
17 changes: 12 additions & 5 deletions apps/microtvm/arduino/template_project/microtvm_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,14 +214,21 @@ def _template_model_header(self, source_dir, metadata):
with open(source_dir / "model.h", "r") as f:
model_h_template = Template(f.read())

assert (
metadata["style"] == "full-model"
all_module_names = []
for name in metadata["modules"].keys():
all_module_names.append(name)

assert all(
metadata["modules"][mod_name]["style"] == "full-model" for mod_name in all_module_names
), "when generating AOT, expect only full-model Model Library Format"

template_values = {
"workspace_size_bytes": metadata["memory"]["functions"]["main"][0][
workspace_size_bytes = 0
for mod_name in all_module_names:
workspace_size_bytes += metadata["modules"][mod_name]["memory"]["functions"]["main"][0][
"workspace_size_bytes"
],
]
template_values = {
"workspace_size_bytes": workspace_size_bytes,
}

with open(source_dir / "model.h", "w") as f:
Expand Down
17 changes: 14 additions & 3 deletions python/tvm/driver/tvmc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,9 +391,20 @@ def import_package(self, package_path: str):
with open(temp.relpath("metadata.json")) as metadata_json:
metadata = json.load(metadata_json)

has_graph_executor = "graph" in metadata["executors"]
graph = temp.relpath("executor-config/graph/graph.json") if has_graph_executor else None
params = temp.relpath(f'parameters/{metadata["model_name"]}.params')
all_module_names = []
for name in metadata["modules"].keys():
all_module_names.append(name)
assert len(all_module_names) == 1, "Multiple modules in MLF is not supported."

module_name = all_module_names[0]
module_metdata = metadata["modules"][module_name]
has_graph_executor = "graph" in module_metdata["executors"]
graph = (
temp.relpath(f"executor-config/graph/{module_name}.graph")
if has_graph_executor
else None
)
params = temp.relpath(f"parameters/{module_name}.params")

self.type = "mlf"
else:
Expand Down
12 changes: 11 additions & 1 deletion python/tvm/micro/contrib/stm32/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,8 +482,18 @@ def parse_library_format(self, model_library_format_path, quantization=None):
with tarfile.TarFile(model_library_format_path) as f:
f.extractall(extract_path)

with open(os.path.join(extract_path, "metadata.json")) as metadata_f:
metadata = json.load(metadata_f)

all_module_names = []
for name in metadata["modules"].keys():
all_module_names.append(name)
assert len(metadata["modules"]) == 1, "Multiple modules is not supported."

# Extract informations from the Model Library Format
graph_file = os.path.join(extract_path, "executor-config", "graph", "graph.json")
graph_file = os.path.join(
extract_path, "executor-config", "graph", f"{all_module_names[0]}.graph"
)
with open(graph_file, "r") as f:
# returns JSON object as a dictionary
graph_dict = json.load(f)
Expand Down
Loading

0 comments on commit 648154d

Please sign in to comment.