-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLF] Add support for multiple modules in Model Library Format #11464
Conversation
@@ -449,47 +498,53 @@ def _eval_shape(param_name, buffer_shape): | |||
return memory_map | |||
|
|||
|
|||
def _export_operator_model_library_format(mod: build_module.OperatorModule, tempdir): | |||
def _export_operator_model_library_format( | |||
mods: typing.List[build_module.OperatorModule], tempdir: pathlib.Path |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i wonder if for operator MLF, we should just not allow multiple module export? if they don't have a mod_name, then we can't really identify them or namespace them properly. wdyt?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah makes sense.
6a94c64
to
8665d73
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks @mehrdadh !
all_module_names = [] | ||
for name in metadata["modules"].keys(): | ||
all_module_names.append(name) | ||
assert len(all_module_names) == 1, "Multiple modules is not supported." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i think you could simplify to just len(metadata["modules"])
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
@@ -24,6 +24,7 @@ | |||
import re | |||
import tarfile | |||
import typing | |||
from typing import Union |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i think we should unify on style here (either also import List, etc) or just typing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.
def _populate_codegen_dir(mod, codegen_dir: str, module_name: str = None): | ||
def _populate_codegen_dir( | ||
mods: Union[ | ||
typing.List[executor_factory.ExecutorFactoryModule], typing.List[tvm.runtime.Module] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how come tvm.runtime.Module is allowed? also, do we want List[Union[ExecutorFactoryModule, tvm.runtime.Module]]
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I changed it to build_module.OperatorModule
I think Union of List is correct because the input could be Union of two different list, but not a list of two different type
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually since we are passing modules like build_module.OperatorModule
and tvm.support.FrontendTestModule
I changed it back to tvm.runtime.Module
with pytest.raises(RuntimeError) as exc: | ||
micro.export_model_library_format([mod, mod], mlf_tar_path) | ||
|
||
assert str(exc.exception) == ("Multiple operator is not supported.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it seems like in this case the error is that you've passed duplicate mod
, right? maybe we should be checking for that or de-duping?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No I wanted to simply check the case where we pass multiple operator modules. I can add another module with different name to make it more obvious. thoughts?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah can you add another? it seems like passing duplicate modules is a separate error
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure, will do a follow up PR. thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@areusch thanks for the review! PTAL.
all_module_names = [] | ||
for name in metadata["modules"].keys(): | ||
all_module_names.append(name) | ||
assert len(all_module_names) == 1, "Multiple modules is not supported." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
@@ -24,6 +24,7 @@ | |||
import re | |||
import tarfile | |||
import typing | |||
from typing import Union |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.
def _populate_codegen_dir(mod, codegen_dir: str, module_name: str = None): | ||
def _populate_codegen_dir( | ||
mods: Union[ | ||
typing.List[executor_factory.ExecutorFactoryModule], typing.List[tvm.runtime.Module] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I changed it to build_module.OperatorModule
I think Union of List is correct because the input could be Union of two different list, but not a list of two different type
with pytest.raises(RuntimeError) as exc: | ||
micro.export_model_library_format([mod, mod], mlf_tar_path) | ||
|
||
assert str(exc.exception) == ("Multiple operator is not supported.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No I wanted to simply check the case where we pass multiple operator modules. I can add another module with different name to make it more obvious. thoughts?
@mkatanbaf could you try this out and verify it's working for you? |
@areusch sure, working on it now! |
f284b60
to
06c9b51
Compare
06c9b51
to
36e38f9
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks @mehrdadh ! we can address the comment in a follow-up PR.
for the record @mkatanbaf says this has unblocked him in porting multi-model Corstone-300 tests to Project API.
with pytest.raises(RuntimeError) as exc: | ||
micro.export_model_library_format([mod, mod], mlf_tar_path) | ||
|
||
assert str(exc.exception) == ("Multiple operator is not supported.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah can you add another? it seems like passing duplicate modules is a separate error
This PR implements this RFC: apache/tvm-rfcs#76
cc @areusch @gromero