Skip to content

Commit

Permalink
[AoT]Add get_input_name function to AoT Module (apache#14071)
Browse files Browse the repository at this point in the history
* Add get_input_name to C AOT

* add get_input_name to AOT C++

* lint

* fix bug in AotExecutor
  • Loading branch information
mehrdadh authored and yongwww committed Feb 27, 2023
1 parent cb66e6c commit 279f39b
Show file tree
Hide file tree
Showing 11 changed files with 111 additions and 21 deletions.
10 changes: 10 additions & 0 deletions include/tvm/runtime/crt/aot_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,16 @@ int TVMAotExecutor_GetNumOutputs(TVMAotExecutor* executor);
*/
int TVMAotExecutor_GetInputIndex(TVMAotExecutor* executor, const char* name);

/*!
* \brief Return a pointer to name of input with the specified input index
*
* \param executor Pointer to executor instance, created by TVMAotExecutor_Create().
* \param index Input index for retrieving name.
* \param name Output for retrieving name.
* \return Pointer to input name in `name`.
*/
int TVMAotExecutor_GetInputName(TVMAotExecutor* executor, int index, char** name);

/*!
* \brief Run the generated program.
*
Expand Down
1 change: 1 addition & 0 deletions python/tvm/micro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .session import (
create_local_graph_executor,
create_local_debug_executor,
create_local_aot_executor,
Session,
SessionTerminatedError,
)
Expand Down
18 changes: 18 additions & 0 deletions python/tvm/micro/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import pathlib
import shutil
from typing import Union

from tvm.runtime.executor.aot_executor import AotModule
from ..error import register_error
from .._ffi import get_global_func, register_func
from ..contrib import graph_executor
Expand Down Expand Up @@ -259,6 +261,22 @@ def create_local_debug_executor(graph_json_str, mod, device, dump_root=None):
)


def create_local_aot_executor(session: Session):
"""Create a local AoT executor driving execution on the remote CPU device given.
Parameters
----------
session : Session
A microTVM device session.
Returns
-------
tvm.runtime.executor.aot_executor.AotModule :
A local AoT executor instance that executes on the remote device.
"""
return AotModule(session.create_aot_executor())


@register_func("tvm.micro.compile_and_create_micro_session")
def compile_and_create_micro_session(
mod_src_bytes: bytes,
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ def _make_executor(self, expr=None):
ret_type = self.mod["main"].checked_type.ret_type
if _ty.is_dynamic(ret_type):
raise ValueError("AOT Executor only supports static graphs, got output type", ret_type)
mod = build(self.mod, target=self.target)
mod = build(self.mod, target=self.target, executor=Executor("aot"))

# NOTE: Given AOT requires use of the "c" backend, must export/import to compile the
# generated code.
Expand Down
19 changes: 19 additions & 0 deletions python/tvm/runtime/executor/aot_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(self, module):
self._get_num_outputs = module["get_num_outputs"]
self._get_input_index = module["get_input_index"]
self._get_num_inputs = module["get_num_inputs"]
self._get_input_name = module["get_input_name"]

def set_input(self, key=None, value=None, **params):
"""Set inputs to the module via kwargs
Expand Down Expand Up @@ -180,3 +181,21 @@ def get_output(self, index, out=None):
return out

return self._get_output(index)

def get_input_name(self, index: int) -> str:
"""Return the name of input with index `index`"""
return self._get_input_name(index)

def get_input_info(self):
"""Return the 'shape' and 'dtype' dictionaries of the module."""
self.get_input_name(0)

shape_dict = dict()
dtype_dict = dict()
for ind in range(0, self.get_num_inputs()):
input_name = self.get_input_name(ind)
input_tensor = self.get_input(ind)
shape_dict[input_name] = input_tensor.shape
dtype_dict[input_name] = input_tensor.dtype

return shape_dict, dtype_dict
8 changes: 8 additions & 0 deletions src/runtime/aot_executor/aot_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ PackedFunc AotExecutor::GetFunction(const std::string& name,
CHECK(String::CanConvertFrom(args[0])) << "Input key is not a string";
*rv = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String()));
});
} else if (name == "get_input_name") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetInputName(args[0]); });
} else {
return PackedFunc();
}
Expand Down Expand Up @@ -191,6 +194,11 @@ int AotExecutor::GetInputIndex(const std::string& name) {
return -1;
}

std::string AotExecutor::GetInputName(int index) {
auto inputs = metadata_->inputs();
return inputs[index]->name();
}

int AotExecutor::GetOutputIndex(const std::string& name) {
auto outputs = metadata_->outputs();
for (unsigned int i = 0; i < outputs.size(); i++) {
Expand Down
7 changes: 7 additions & 0 deletions src/runtime/aot_executor/aot_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ class TVM_DLL AotExecutor : public ModuleNode {
*/
int GetInputIndex(const std::string& name);

/*!
* \brief Get the input name given the index of input.
* \param index The index of the input.
* \return The name of input.
*/
std::string GetInputName(int index);

/*!
* \brief Get the output index given the name of output.
* \param name The name of the output.
Expand Down
6 changes: 6 additions & 0 deletions src/runtime/crt/aot_executor/aot_executor.c
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ int TVMAotExecutor_GetInputIndex(TVMAotExecutor* executor, const char* name) {
return rv;
}

int TVMAotExecutor_GetInputName(TVMAotExecutor* executor, int index, char** name) {
const TVMMetadata* md = executor->metadata;
*name = md->inputs[index].name;
return 0;
}

int TVMAotExecutor_Run(TVMAotExecutor* executor) {
const char* tvm_main_suffix = "_run";
char tvm_main_name[TVM_CRT_MAX_STRLEN_FUNCTION_NAME];
Expand Down
24 changes: 22 additions & 2 deletions src/runtime/crt/aot_executor_module/aot_executor_module.c
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,24 @@ int32_t TVMAotExecutorModule_GetInputIndex(TVMValue* args, int* tcodes, int narg
return 0;
}

int32_t TVMAotExecutorModule_GetInputName(TVMValue* args, int* tcodes, int nargs,
TVMValue* ret_values, int* ret_tcodes,
void* resource_handle) {
if (nargs != 1) {
return kTvmErrorFunctionCallNumArguments;
}

char* name;
int ret = TVMAotExecutor_GetInputName(aot_executor.executor, args[0].v_int64, &name);
if (ret < 0) {
return kTvmErrorExecutorModuleNoSuchInput;
}

ret_values[0].v_str = name;
ret_tcodes[0] = kTVMStr;
return 0;
}

int32_t TVMAotExecutorModule_GetNumInputs(TVMValue* args, int* tcodes, int nargs,
TVMValue* ret_values, int* ret_tcodes,
void* resource_handle) {
Expand Down Expand Up @@ -191,10 +209,11 @@ static const TVMBackendPackedCFunc aot_executor_registry_funcs[] = {
&TVMAotExecutorModule_Run, // run
&TVMAotExecutorModule_NotImplemented, // set_input (implemented via python wrapper)
&TVMAotExecutorModule_NotImplemented, // share_params (do not implement)
&TVMAotExecutorModule_GetInputName, // get_input_name
};

static const TVMFuncRegistry aot_executor_registry = {
"\x0a\0get_input\0"
"\x0b\0get_input\0"
"get_input_index\0"
"get_input_info\0"
"get_num_inputs\0"
Expand All @@ -203,7 +222,8 @@ static const TVMFuncRegistry aot_executor_registry = {
"load_params\0"
"run\0"
"set_input\0"
"share_params\0",
"share_params\0"
"get_input_name\0",
aot_executor_registry_funcs};

tvm_crt_error_t TVMAotExecutorModule_Register() {
Expand Down
6 changes: 6 additions & 0 deletions tests/python/relay/aot/test_cpp_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(3, 3, 5, 5),
loaded_mod = tvm.runtime.load_module(test_so_path)
runner = tvm.runtime.executor.AotModule(loaded_mod["default"](tvm.cpu(0)))
runner.set_input(**inputs)

assert runner.get_input_name(0) == "data"
shape_dict, dtype_dict = runner.get_input_info()
assert shape_dict == {"data": (1, 3, 64, 64)}
assert dtype_dict == {"data": "uint8"}

runner.run()
assert (runner.get_output(0).numpy() == list(ref_outputs.values())[0]).all()

Expand Down
31 changes: 13 additions & 18 deletions tests/python/unittest/test_crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,15 +181,18 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), uint8]) {
factory = tvm.relay.build(relay_mod, target=TARGET, runtime=runtime, executor=executor)

def do_test():
aot_executor = tvm.runtime.executor.aot_executor.AotModule(
sess._rpc.get_function("tvm.aot_executor.create")(
sess.get_system_lib(), sess.device, "default"
)
)
aot_executor = tvm.micro.create_local_aot_executor(sess)

assert aot_executor.get_input_index("a") == 0
assert aot_executor.get_input_index("b") == 1

assert aot_executor.get_input_name(0) == "a"
assert aot_executor.get_input_name(1) == "b"

shape_dict, dtype_dict = aot_executor.get_input_info()
assert shape_dict == {"a": (1, 2), "b": (1, 2)}
assert dtype_dict == {"a": "uint8", "b": "uint8"}

assert aot_executor.get_num_inputs() == 2
assert aot_executor.get_num_outputs() == 1

Expand Down Expand Up @@ -246,11 +249,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), uint8], %c : Tensor[(1

def do_test():
try:
aot_executor = tvm.runtime.executor.aot_executor.AotModule(
sess._rpc.get_function("tvm.aot_executor.create")(
sess.get_system_lib(), sess.device, "default"
)
)
aot_executor = tvm.micro.create_local_aot_executor(sess)
except tvm._ffi.base.TVMError as excpt:
raise excpt

Expand Down Expand Up @@ -408,11 +407,9 @@ def test_autotune():
lowered = tvm.relay.build(mod, target=TARGET, runtime=runtime, params=params)

temp_dir = tvm.contrib.utils.tempdir()
project = tvm.micro.generate_project(template_project_dir, lowered, temp_dir / "project")
project.build()
with tvm.micro.Session(project.transport()) as session:
with _make_session(temp_dir, lowered) as sess:
graph_mod = tvm.micro.create_local_graph_executor(
lowered.get_graph_json(), session.get_system_lib(), session.device
lowered.get_graph_json(), sess.get_system_lib(), sess.device
)
graph_mod.set_input(**lowered.get_params())
graph_mod.run(**inputs)
Expand All @@ -425,11 +422,9 @@ def test_autotune():
lowered_tuned = tvm.relay.build(mod, target=target, runtime=runtime, params=params)

temp_dir = tvm.contrib.utils.tempdir()
project = tvm.micro.generate_project(template_project_dir, lowered_tuned, temp_dir / "project")
project.build()
with tvm.micro.Session(project.transport()) as session:
with _make_session(temp_dir, lowered_tuned) as sess:
graph_mod = tvm.micro.create_local_graph_executor(
lowered_tuned.get_graph_json(), session.get_system_lib(), session.device
lowered_tuned.get_graph_json(), sess.get_system_lib(), sess.device
)
graph_mod.set_input(**lowered_tuned.get_params())
graph_mod.run(**inputs)
Expand Down

0 comments on commit 279f39b

Please sign in to comment.