Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix][UMA] Protect target registration (#13624) #14057

Merged
merged 1 commit into from
Feb 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gallery/tutorial/uma.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
#

######################################################################
# .. image:: https://raw.githubusercontent.com/apache/tvm-site/main/images/tutorial/uma_vanilla_block_diagram.png
# .. image:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/tutorial/uma_vanilla_block_diagram.png
# :width: 100%
# :alt: A block diagram of Vanilla
#
Expand Down
7 changes: 4 additions & 3 deletions python/tvm/relay/backend/contrib/uma/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,11 +278,12 @@ def register(self) -> None:
"""
registration_func = tvm.get_global_func("relay.backend.contrib.uma.RegisterTarget")

for name, attr in self._target_attrs:
for name, attr in self._target_attrs.items():
if attr is None:
raise ValueError("Target attribute None is not supported.")

if registration_func(self.target_name, self._target_attrs):
# skip if target is already registered
if self.target_name not in tvm.target.Target.list_kinds():
registration_func(self.target_name, self._target_attrs)
self._relay_to_relay.register()
self._relay_to_tir.register()
self._tir_to_runtime.register()
Expand Down
24 changes: 15 additions & 9 deletions src/relay/backend/contrib/uma/targets.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,23 @@ namespace tvm {
namespace relay {
namespace contrib {
namespace uma {
tvm::transform::Pass RelayToTIR(String target_name);
transform::Pass RelayToTIR(String target_name);
runtime::Module TIRToRuntime(IRModule mod, Target target);
} // namespace uma
} // namespace contrib
} // namespace relay

TVM_REGISTER_GLOBAL("relay.backend.contrib.uma.RegisterTarget")
.set_body_typed([](String target_name, Map<String, ObjectRef> attr_options) -> bool {
// @todo(cgerum): We probably should get rid of target.register rather sooner than later
// And use a proper registry for uma backends
for (const String registered_target_name : ::tvm::TargetKindRegEntry::ListTargetKinds()) {
// create only new target and init only once
for (const String registered_target_name : TargetKindRegEntry::ListTargetKinds()) {
if (registered_target_name == target_name) {
return false;
LOG(FATAL) << "TVM UMA Error: Target is already registered: " << target_name;
}
}

auto target_kind =
::tvm::TargetKindRegEntry::RegisterOrGet(target_name)
TargetKindRegEntry::RegisterOrGet(target_name)
.set_name()
.set_default_device_type(kDLCPU)
.add_attr_option<Array<String>>("keys")
Expand All @@ -58,20 +57,27 @@ TVM_REGISTER_GLOBAL("relay.backend.contrib.uma.RegisterTarget")
.add_attr_option<Array<String>>("libs")
.add_attr_option<Target>("host")
.add_attr_option<Integer>("from_device")
.set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR,
.set_attr<FTVMRelayToTIR>(attr::kRelayToTIR,
relay::contrib::uma::RelayToTIR(target_name))
.set_attr<FTVMTIRToRuntime>("TIRToRuntime", relay::contrib::uma::TIRToRuntime);

// target kind attrs inventory
auto kind = TargetKind::Get(target_name).value();
auto list_attrs = TargetKindRegEntry::ListTargetKindOptions(kind);

for (auto& attr_option : attr_options) {
auto option_name = attr_option.first;
auto default_value = attr_option.second;
if (list_attrs.find(option_name) != list_attrs.end()) {
LOG(FATAL) << "TVM UMA Error: Attribute is already registered: " << option_name;
}
if (default_value->IsInstance<StringObj>()) {
target_kind.add_attr_option<String>(option_name, Downcast<String>(default_value));
} else if (default_value->IsInstance<IntImmNode>()) {
target_kind.add_attr_option<Integer>(option_name, Downcast<Integer>(default_value));
} else {
LOG(FATAL) << "Only String, Integer, or Bool are supported. Given attribute option type: "
<< attr_option.second->GetTypeKey();
LOG(FATAL) << "TypeError: Only String, Integer, or Bool are supported. "
<< "Given attribute option type: " << attr_option.second->GetTypeKey();
}
}
return true;
Expand Down
25 changes: 22 additions & 3 deletions tests/python/contrib/test_uma/test_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,23 +63,42 @@ def test_uma_target(target_name, target_attrs, target_args):
[
("float_attr", 3.14),
("none_attr", None),
("model", "my_model"),
],
)
def test_invalid_attr_option(attr_name: str, target_attr: Union[str, int, bool, float, None]):
registration_func = tvm.get_global_func("relay.backend.contrib.uma.RegisterTarget")
if target_attr is None:
# None cannot be caught as TVMError, as it causes a SIGKILL, therefore it must be prevented to be
# entered into relay.backend.contrib.uma.RegisterTarget at Python level.
with pytest.raises(ValueError):
with pytest.raises(ValueError, match=r"Target attribute None is not supported."):
uma_backend = VanillaAcceleratorBackend()
uma_backend._target_attrs = {attr_name: target_attr}
uma_backend.register()
elif "model" in attr_name:
target_name = f"{attr_name}_{target_attr}"
target_attr = {attr_name: target_attr}
with pytest.raises(tvm.TVMError, match=r"Attribute is already registered: .*"):
registration_func(target_name, target_attr)
else:
registration_func = tvm.get_global_func("relay.backend.contrib.uma.RegisterTarget")
target_name = f"{attr_name}_{target_attr}"
target_attr = {attr_name: target_attr}
with pytest.raises(tvm.TVMError, match=r"Only String, Integer, or Bool are supported. .*"):
with pytest.raises(TypeError, match=r"Only String, Integer, or Bool are supported. .*"):
registration_func(target_name, target_attr)


@pytest.mark.parametrize(
"target_name",
[
"llvm",
"c",
],
)
def test_target_duplication(target_name: str):
with pytest.raises(tvm.TVMError, match=r"TVM UMA Error: Target is already registered: .*"):
registration_func = tvm.get_global_func("relay.backend.contrib.uma.RegisterTarget")
registration_func(target_name, {})


if __name__ == "__main__":
tvm.testing.main()