Skip to content

Commit

Permalink
Improve error messages in graph tuner, graph runtime, and module load…
Browse files Browse the repository at this point in the history
…er. (#6148)

* Raise error if no operators are found in GraphTuner

* Raise error if key cannot be found in graph runtime inputs

* Detailed error message when module loader is not found
  • Loading branch information
tkonolige authored Jul 29, 2020
1 parent 44ff1f3 commit 2e93aef
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 5 deletions.
3 changes: 3 additions & 0 deletions python/tvm/autotvm/graph_tuner/base_graph_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ def __init__(self, graph, input_shapes, records, target_ops,

self._graph = graph
self._in_nodes_dict = get_in_nodes(self._node_list, self._target_ops, input_shapes.keys())
if len(self._in_nodes_dict) == 0:
raise RuntimeError("Could not find any input nodes with whose "
"operator is one of %s" % self._target_ops)
self._out_nodes_dict = get_out_nodes(self._in_nodes_dict)
self._fetch_cfg()
self._opt_out_op = OPT_OUT_OP
Expand Down
5 changes: 4 additions & 1 deletion python/tvm/contrib/graph_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,10 @@ def set_input(self, key=None, value=None, **params):
Additional arguments
"""
if key is not None:
self._get_input(key).copyfrom(value)
v = self._get_input(key)
if v is None:
raise RuntimeError("Could not find '%s' in graph's inputs" % key)
v.copyfrom(value)

if params:
# upload big arrays first to avoid memory issue in rpc mode
Expand Down
19 changes: 17 additions & 2 deletions src/runtime/library_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,24 @@ runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib) {
CHECK(stream->Read(&import_tree_row_ptr));
CHECK(stream->Read(&import_tree_child_indices));
} else {
std::string fkey = "runtime.module.loadbinary_" + tkey;
std::string loadkey = "runtime.module.loadbinary_";
std::string fkey = loadkey + tkey;
const PackedFunc* f = Registry::Get(fkey);
CHECK(f != nullptr) << "Loader of " << tkey << "(" << fkey << ") is not presented.";
if (f == nullptr) {
std::string loaders = "";
for (auto name : Registry::ListNames()) {
if (name.rfind(loadkey, 0) == 0) {
if (loaders.size() > 0) {
loaders += ", ";
}
loaders += name.substr(loadkey.size());
}
}
CHECK(f != nullptr)
<< "Binary was created using " << tkey
<< " but a loader of that name is not registered. Available loaders are " << loaders
<< ". Perhaps you need to recompile with this runtime enabled.";
}
Module m = (*f)(static_cast<void*>(stream));
modules.emplace_back(m);
}
Expand Down
19 changes: 17 additions & 2 deletions src/runtime/stackvm/stackvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,24 @@ class StackVMModuleNode : public runtime::ModuleNode {
for (uint64_t i = 0; i < num_imports; ++i) {
std::string tkey;
CHECK(strm->Read(&tkey));
std::string fkey = "runtime.module.loadbinary_" + tkey;
std::string loadkey = "runtime.module.loadbinary_";
std::string fkey = loadkey + tkey;
const PackedFunc* f = Registry::Get(fkey);
CHECK(f != nullptr) << "Loader of " << tkey << "(" << fkey << ") is not presented.";
if (f == nullptr) {
std::string loaders = "";
for (auto name : Registry::ListNames()) {
if (name.rfind(loadkey, 0) == 0) {
if (loaders.size() > 0) {
loaders += ", ";
}
loaders += name.substr(loadkey.size());
}
}
CHECK(f != nullptr)
<< "Binary was created using " << tkey
<< " but a loader of that name is not registered. Available loaders are " << loaders
<< ". Perhaps you need to recompile with this runtime enabled.";
}
Module m = (*f)(static_cast<void*>(strm));
n->imports_.emplace_back(std::move(m));
}
Expand Down

0 comments on commit 2e93aef

Please sign in to comment.