diff --git a/include/tvm/runtime/crt/aot_executor.h b/include/tvm/runtime/crt/aot_executor.h index c6a9f022d25e..4783adec8eae 100644 --- a/include/tvm/runtime/crt/aot_executor.h +++ b/include/tvm/runtime/crt/aot_executor.h @@ -92,6 +92,16 @@ int TVMAotExecutor_GetNumOutputs(TVMAotExecutor* executor); */ int TVMAotExecutor_GetInputIndex(TVMAotExecutor* executor, const char* name); +/*! + * \brief Return a pointer to name of input with the specified input index + * + * \param executor Pointer to executor instance, created by TVMAotExecutor_Create(). + * \param index Input index for retrieving name. + * \param name Output for retrieving name. + * \return Pointer to input name in `name`. + */ +int TVMAotExecutor_GetInputName(TVMAotExecutor* executor, int index, char** name); + /*! * \brief Run the generated program. * diff --git a/python/tvm/micro/__init__.py b/python/tvm/micro/__init__.py index 22a34ef69bd1..a2dd66e07730 100644 --- a/python/tvm/micro/__init__.py +++ b/python/tvm/micro/__init__.py @@ -29,6 +29,7 @@ from .session import ( create_local_graph_executor, create_local_debug_executor, + create_local_aot_executor, Session, SessionTerminatedError, ) diff --git a/python/tvm/micro/session.py b/python/tvm/micro/session.py index 7d01baa75289..dacff9aa6d80 100644 --- a/python/tvm/micro/session.py +++ b/python/tvm/micro/session.py @@ -24,6 +24,8 @@ import pathlib import shutil from typing import Union + +from tvm.runtime.executor.aot_executor import AotModule from ..error import register_error from .._ffi import get_global_func, register_func from ..contrib import graph_executor @@ -259,6 +261,22 @@ def create_local_debug_executor(graph_json_str, mod, device, dump_root=None): ) +def create_local_aot_executor(session: Session): + """Create a local AoT executor driving execution on the remote CPU device given. + + Parameters + ---------- + session : Session + A microTVM device session. + + Returns + ------- + tvm.runtime.executor.aot_executor.AotModule : + A local AoT executor instance that executes on the remote device. + """ + return AotModule(session.create_aot_executor()) + + @register_func("tvm.micro.compile_and_create_micro_session") def compile_and_create_micro_session( mod_src_bytes: bytes, diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 112e5558fef9..f2feed9fd629 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -575,7 +575,7 @@ def _make_executor(self, expr=None): ret_type = self.mod["main"].checked_type.ret_type if _ty.is_dynamic(ret_type): raise ValueError("AOT Executor only supports static graphs, got output type", ret_type) - mod = build(self.mod, target=self.target) + mod = build(self.mod, target=self.target, executor=Executor("aot")) # NOTE: Given AOT requires use of the "c" backend, must export/import to compile the # generated code. diff --git a/python/tvm/runtime/executor/aot_executor.py b/python/tvm/runtime/executor/aot_executor.py index 9ef0d1dee894..f7b911cb1718 100644 --- a/python/tvm/runtime/executor/aot_executor.py +++ b/python/tvm/runtime/executor/aot_executor.py @@ -67,6 +67,7 @@ def __init__(self, module): self._get_num_outputs = module["get_num_outputs"] self._get_input_index = module["get_input_index"] self._get_num_inputs = module["get_num_inputs"] + self._get_input_name = module["get_input_name"] def set_input(self, key=None, value=None, **params): """Set inputs to the module via kwargs @@ -180,3 +181,21 @@ def get_output(self, index, out=None): return out return self._get_output(index) + + def get_input_name(self, index: int) -> str: + """Return the name of input with index `index`""" + return self._get_input_name(index) + + def get_input_info(self): + """Return the 'shape' and 'dtype' dictionaries of the module.""" + self.get_input_name(0) + + shape_dict = dict() + dtype_dict = dict() + for ind in range(0, self.get_num_inputs()): + input_name = self.get_input_name(ind) + input_tensor = self.get_input(ind) + shape_dict[input_name] = input_tensor.shape + dtype_dict[input_name] = input_tensor.dtype + + return shape_dict, dtype_dict diff --git a/src/runtime/aot_executor/aot_executor.cc b/src/runtime/aot_executor/aot_executor.cc index 292fe4fd64ce..39d5570030d6 100644 --- a/src/runtime/aot_executor/aot_executor.cc +++ b/src/runtime/aot_executor/aot_executor.cc @@ -156,6 +156,9 @@ PackedFunc AotExecutor::GetFunction(const std::string& name, CHECK(String::CanConvertFrom(args[0])) << "Input key is not a string"; *rv = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String())); }); + } else if (name == "get_input_name") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetInputName(args[0]); }); } else { return PackedFunc(); } @@ -191,6 +194,11 @@ int AotExecutor::GetInputIndex(const std::string& name) { return -1; } +std::string AotExecutor::GetInputName(int index) { + auto inputs = metadata_->inputs(); + return inputs[index]->name(); +} + int AotExecutor::GetOutputIndex(const std::string& name) { auto outputs = metadata_->outputs(); for (unsigned int i = 0; i < outputs.size(); i++) { diff --git a/src/runtime/aot_executor/aot_executor.h b/src/runtime/aot_executor/aot_executor.h index ccbcf8fdf3d2..cc86381624ce 100644 --- a/src/runtime/aot_executor/aot_executor.h +++ b/src/runtime/aot_executor/aot_executor.h @@ -69,6 +69,13 @@ class TVM_DLL AotExecutor : public ModuleNode { */ int GetInputIndex(const std::string& name); + /*! + * \brief Get the input name given the index of input. + * \param index The index of the input. + * \return The name of input. + */ + std::string GetInputName(int index); + /*! * \brief Get the output index given the name of output. * \param name The name of the output. diff --git a/src/runtime/crt/aot_executor/aot_executor.c b/src/runtime/crt/aot_executor/aot_executor.c index ae007037e6cc..8a47bb008bf0 100644 --- a/src/runtime/crt/aot_executor/aot_executor.c +++ b/src/runtime/crt/aot_executor/aot_executor.c @@ -82,6 +82,12 @@ int TVMAotExecutor_GetInputIndex(TVMAotExecutor* executor, const char* name) { return rv; } +int TVMAotExecutor_GetInputName(TVMAotExecutor* executor, int index, char** name) { + const TVMMetadata* md = executor->metadata; + *name = md->inputs[index].name; + return 0; +} + int TVMAotExecutor_Run(TVMAotExecutor* executor) { const char* tvm_main_suffix = "_run"; char tvm_main_name[TVM_CRT_MAX_STRLEN_FUNCTION_NAME]; diff --git a/src/runtime/crt/aot_executor_module/aot_executor_module.c b/src/runtime/crt/aot_executor_module/aot_executor_module.c index 5dd11c3dbc7e..a5c8105144f7 100644 --- a/src/runtime/crt/aot_executor_module/aot_executor_module.c +++ b/src/runtime/crt/aot_executor_module/aot_executor_module.c @@ -147,6 +147,24 @@ int32_t TVMAotExecutorModule_GetInputIndex(TVMValue* args, int* tcodes, int narg return 0; } +int32_t TVMAotExecutorModule_GetInputName(TVMValue* args, int* tcodes, int nargs, + TVMValue* ret_values, int* ret_tcodes, + void* resource_handle) { + if (nargs != 1) { + return kTvmErrorFunctionCallNumArguments; + } + + char* name; + int ret = TVMAotExecutor_GetInputName(aot_executor.executor, args[0].v_int64, &name); + if (ret < 0) { + return kTvmErrorExecutorModuleNoSuchInput; + } + + ret_values[0].v_str = name; + ret_tcodes[0] = kTVMStr; + return 0; +} + int32_t TVMAotExecutorModule_GetNumInputs(TVMValue* args, int* tcodes, int nargs, TVMValue* ret_values, int* ret_tcodes, void* resource_handle) { @@ -191,10 +209,11 @@ static const TVMBackendPackedCFunc aot_executor_registry_funcs[] = { &TVMAotExecutorModule_Run, // run &TVMAotExecutorModule_NotImplemented, // set_input (implemented via python wrapper) &TVMAotExecutorModule_NotImplemented, // share_params (do not implement) + &TVMAotExecutorModule_GetInputName, // get_input_name }; static const TVMFuncRegistry aot_executor_registry = { - "\x0a\0get_input\0" + "\x0b\0get_input\0" "get_input_index\0" "get_input_info\0" "get_num_inputs\0" @@ -203,7 +222,8 @@ static const TVMFuncRegistry aot_executor_registry = { "load_params\0" "run\0" "set_input\0" - "share_params\0", + "share_params\0" + "get_input_name\0", aot_executor_registry_funcs}; tvm_crt_error_t TVMAotExecutorModule_Register() { diff --git a/tests/python/relay/aot/test_cpp_aot.py b/tests/python/relay/aot/test_cpp_aot.py index 0c5931a55d31..3c7a3bc0ca12 100644 --- a/tests/python/relay/aot/test_cpp_aot.py +++ b/tests/python/relay/aot/test_cpp_aot.py @@ -112,6 +112,12 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(3, 3, 5, 5), loaded_mod = tvm.runtime.load_module(test_so_path) runner = tvm.runtime.executor.AotModule(loaded_mod["default"](tvm.cpu(0))) runner.set_input(**inputs) + + assert runner.get_input_name(0) == "data" + shape_dict, dtype_dict = runner.get_input_info() + assert shape_dict == {"data": (1, 3, 64, 64)} + assert dtype_dict == {"data": "uint8"} + runner.run() assert (runner.get_output(0).numpy() == list(ref_outputs.values())[0]).all() diff --git a/tests/python/unittest/test_crt.py b/tests/python/unittest/test_crt.py index e51745d08be1..83fab98cf683 100644 --- a/tests/python/unittest/test_crt.py +++ b/tests/python/unittest/test_crt.py @@ -181,15 +181,18 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), uint8]) { factory = tvm.relay.build(relay_mod, target=TARGET, runtime=runtime, executor=executor) def do_test(): - aot_executor = tvm.runtime.executor.aot_executor.AotModule( - sess._rpc.get_function("tvm.aot_executor.create")( - sess.get_system_lib(), sess.device, "default" - ) - ) + aot_executor = tvm.micro.create_local_aot_executor(sess) assert aot_executor.get_input_index("a") == 0 assert aot_executor.get_input_index("b") == 1 + assert aot_executor.get_input_name(0) == "a" + assert aot_executor.get_input_name(1) == "b" + + shape_dict, dtype_dict = aot_executor.get_input_info() + assert shape_dict == {"a": (1, 2), "b": (1, 2)} + assert dtype_dict == {"a": "uint8", "b": "uint8"} + assert aot_executor.get_num_inputs() == 2 assert aot_executor.get_num_outputs() == 1 @@ -246,11 +249,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), uint8], %c : Tensor[(1 def do_test(): try: - aot_executor = tvm.runtime.executor.aot_executor.AotModule( - sess._rpc.get_function("tvm.aot_executor.create")( - sess.get_system_lib(), sess.device, "default" - ) - ) + aot_executor = tvm.micro.create_local_aot_executor(sess) except tvm._ffi.base.TVMError as excpt: raise excpt @@ -408,11 +407,9 @@ def test_autotune(): lowered = tvm.relay.build(mod, target=TARGET, runtime=runtime, params=params) temp_dir = tvm.contrib.utils.tempdir() - project = tvm.micro.generate_project(template_project_dir, lowered, temp_dir / "project") - project.build() - with tvm.micro.Session(project.transport()) as session: + with _make_session(temp_dir, lowered) as sess: graph_mod = tvm.micro.create_local_graph_executor( - lowered.get_graph_json(), session.get_system_lib(), session.device + lowered.get_graph_json(), sess.get_system_lib(), sess.device ) graph_mod.set_input(**lowered.get_params()) graph_mod.run(**inputs) @@ -425,11 +422,9 @@ def test_autotune(): lowered_tuned = tvm.relay.build(mod, target=target, runtime=runtime, params=params) temp_dir = tvm.contrib.utils.tempdir() - project = tvm.micro.generate_project(template_project_dir, lowered_tuned, temp_dir / "project") - project.build() - with tvm.micro.Session(project.transport()) as session: + with _make_session(temp_dir, lowered_tuned) as sess: graph_mod = tvm.micro.create_local_graph_executor( - lowered_tuned.get_graph_json(), session.get_system_lib(), session.device + lowered_tuned.get_graph_json(), sess.get_system_lib(), sess.device ) graph_mod.set_input(**lowered_tuned.get_params()) graph_mod.run(**inputs)