Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AoT]Add get_input_name function to AoT Module #14071

Merged
merged 4 commits into from
Feb 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/tvm/runtime/crt/aot_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,16 @@ int TVMAotExecutor_GetNumOutputs(TVMAotExecutor* executor);
*/
int TVMAotExecutor_GetInputIndex(TVMAotExecutor* executor, const char* name);

/*!
* \brief Return a pointer to name of input with the specified input index
*
* \param executor Pointer to executor instance, created by TVMAotExecutor_Create().
* \param index Input index for retrieving name.
* \param name Output for retrieving name.
* \return Pointer to input name in `name`.
*/
int TVMAotExecutor_GetInputName(TVMAotExecutor* executor, int index, char** name);

/*!
* \brief Run the generated program.
*
Expand Down
1 change: 1 addition & 0 deletions python/tvm/micro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .session import (
create_local_graph_executor,
create_local_debug_executor,
create_local_aot_executor,
Session,
SessionTerminatedError,
)
Expand Down
18 changes: 18 additions & 0 deletions python/tvm/micro/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import pathlib
import shutil
from typing import Union

from tvm.runtime.executor.aot_executor import AotModule
from ..error import register_error
from .._ffi import get_global_func, register_func
from ..contrib import graph_executor
Expand Down Expand Up @@ -259,6 +261,22 @@ def create_local_debug_executor(graph_json_str, mod, device, dump_root=None):
)


def create_local_aot_executor(session: Session):
"""Create a local AoT executor driving execution on the remote CPU device given.

Parameters
----------
session : Session
A microTVM device session.

Returns
-------
tvm.runtime.executor.aot_executor.AotModule :
A local AoT executor instance that executes on the remote device.
"""
return AotModule(session.create_aot_executor())


@register_func("tvm.micro.compile_and_create_micro_session")
def compile_and_create_micro_session(
mod_src_bytes: bytes,
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ def _make_executor(self, expr=None):
ret_type = self.mod["main"].checked_type.ret_type
if _ty.is_dynamic(ret_type):
raise ValueError("AOT Executor only supports static graphs, got output type", ret_type)
mod = build(self.mod, target=self.target)
mod = build(self.mod, target=self.target, executor=Executor("aot"))

# NOTE: Given AOT requires use of the "c" backend, must export/import to compile the
# generated code.
Expand Down
19 changes: 19 additions & 0 deletions python/tvm/runtime/executor/aot_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(self, module):
self._get_num_outputs = module["get_num_outputs"]
self._get_input_index = module["get_input_index"]
self._get_num_inputs = module["get_num_inputs"]
self._get_input_name = module["get_input_name"]

def set_input(self, key=None, value=None, **params):
"""Set inputs to the module via kwargs
Expand Down Expand Up @@ -180,3 +181,21 @@ def get_output(self, index, out=None):
return out

return self._get_output(index)

def get_input_name(self, index: int) -> str:
"""Return the name of input with index `index`"""
return self._get_input_name(index)

def get_input_info(self):
"""Return the 'shape' and 'dtype' dictionaries of the module."""
self.get_input_name(0)

shape_dict = dict()
dtype_dict = dict()
for ind in range(0, self.get_num_inputs()):
input_name = self.get_input_name(ind)
input_tensor = self.get_input(ind)
shape_dict[input_name] = input_tensor.shape
dtype_dict[input_name] = input_tensor.dtype

return shape_dict, dtype_dict
8 changes: 8 additions & 0 deletions src/runtime/aot_executor/aot_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ PackedFunc AotExecutor::GetFunction(const std::string& name,
CHECK(String::CanConvertFrom(args[0])) << "Input key is not a string";
*rv = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String()));
});
} else if (name == "get_input_name") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetInputName(args[0]); });
} else {
return PackedFunc();
}
Expand Down Expand Up @@ -191,6 +194,11 @@ int AotExecutor::GetInputIndex(const std::string& name) {
return -1;
}

std::string AotExecutor::GetInputName(int index) {
auto inputs = metadata_->inputs();
return inputs[index]->name();
}

int AotExecutor::GetOutputIndex(const std::string& name) {
auto outputs = metadata_->outputs();
for (unsigned int i = 0; i < outputs.size(); i++) {
Expand Down
7 changes: 7 additions & 0 deletions src/runtime/aot_executor/aot_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ class TVM_DLL AotExecutor : public ModuleNode {
*/
int GetInputIndex(const std::string& name);

/*!
* \brief Get the input name given the index of input.
* \param index The index of the input.
* \return The name of input.
*/
std::string GetInputName(int index);

/*!
* \brief Get the output index given the name of output.
* \param name The name of the output.
Expand Down
6 changes: 6 additions & 0 deletions src/runtime/crt/aot_executor/aot_executor.c
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ int TVMAotExecutor_GetInputIndex(TVMAotExecutor* executor, const char* name) {
return rv;
}

int TVMAotExecutor_GetInputName(TVMAotExecutor* executor, int index, char** name) {
const TVMMetadata* md = executor->metadata;
*name = md->inputs[index].name;
return 0;
}

int TVMAotExecutor_Run(TVMAotExecutor* executor) {
const char* tvm_main_suffix = "_run";
char tvm_main_name[TVM_CRT_MAX_STRLEN_FUNCTION_NAME];
Expand Down
24 changes: 22 additions & 2 deletions src/runtime/crt/aot_executor_module/aot_executor_module.c
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,24 @@ int32_t TVMAotExecutorModule_GetInputIndex(TVMValue* args, int* tcodes, int narg
return 0;
}

int32_t TVMAotExecutorModule_GetInputName(TVMValue* args, int* tcodes, int nargs,
TVMValue* ret_values, int* ret_tcodes,
void* resource_handle) {
if (nargs != 1) {
return kTvmErrorFunctionCallNumArguments;
}

char* name;
int ret = TVMAotExecutor_GetInputName(aot_executor.executor, args[0].v_int64, &name);
if (ret < 0) {
return kTvmErrorExecutorModuleNoSuchInput;
}

ret_values[0].v_str = name;
ret_tcodes[0] = kTVMStr;
return 0;
}

int32_t TVMAotExecutorModule_GetNumInputs(TVMValue* args, int* tcodes, int nargs,
TVMValue* ret_values, int* ret_tcodes,
void* resource_handle) {
Expand Down Expand Up @@ -191,10 +209,11 @@ static const TVMBackendPackedCFunc aot_executor_registry_funcs[] = {
&TVMAotExecutorModule_Run, // run
&TVMAotExecutorModule_NotImplemented, // set_input (implemented via python wrapper)
&TVMAotExecutorModule_NotImplemented, // share_params (do not implement)
&TVMAotExecutorModule_GetInputName, // get_input_name
};

static const TVMFuncRegistry aot_executor_registry = {
"\x0a\0get_input\0"
"\x0b\0get_input\0"
"get_input_index\0"
"get_input_info\0"
"get_num_inputs\0"
Expand All @@ -203,7 +222,8 @@ static const TVMFuncRegistry aot_executor_registry = {
"load_params\0"
"run\0"
"set_input\0"
"share_params\0",
"share_params\0"
"get_input_name\0",
aot_executor_registry_funcs};

tvm_crt_error_t TVMAotExecutorModule_Register() {
Expand Down
6 changes: 6 additions & 0 deletions tests/python/relay/aot/test_cpp_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(3, 3, 5, 5),
loaded_mod = tvm.runtime.load_module(test_so_path)
runner = tvm.runtime.executor.AotModule(loaded_mod["default"](tvm.cpu(0)))
runner.set_input(**inputs)

assert runner.get_input_name(0) == "data"
shape_dict, dtype_dict = runner.get_input_info()
assert shape_dict == {"data": (1, 3, 64, 64)}
assert dtype_dict == {"data": "uint8"}

runner.run()
assert (runner.get_output(0).numpy() == list(ref_outputs.values())[0]).all()

Expand Down
31 changes: 13 additions & 18 deletions tests/python/unittest/test_crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,15 +181,18 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), uint8]) {
factory = tvm.relay.build(relay_mod, target=TARGET, runtime=runtime, executor=executor)

def do_test():
aot_executor = tvm.runtime.executor.aot_executor.AotModule(
sess._rpc.get_function("tvm.aot_executor.create")(
sess.get_system_lib(), sess.device, "default"
)
)
aot_executor = tvm.micro.create_local_aot_executor(sess)

assert aot_executor.get_input_index("a") == 0
assert aot_executor.get_input_index("b") == 1

assert aot_executor.get_input_name(0) == "a"
assert aot_executor.get_input_name(1) == "b"

shape_dict, dtype_dict = aot_executor.get_input_info()
assert shape_dict == {"a": (1, 2), "b": (1, 2)}
assert dtype_dict == {"a": "uint8", "b": "uint8"}

assert aot_executor.get_num_inputs() == 2
assert aot_executor.get_num_outputs() == 1

Expand Down Expand Up @@ -246,11 +249,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), uint8], %c : Tensor[(1

def do_test():
try:
aot_executor = tvm.runtime.executor.aot_executor.AotModule(
sess._rpc.get_function("tvm.aot_executor.create")(
sess.get_system_lib(), sess.device, "default"
)
)
aot_executor = tvm.micro.create_local_aot_executor(sess)
except tvm._ffi.base.TVMError as excpt:
raise excpt

Expand Down Expand Up @@ -408,11 +407,9 @@ def test_autotune():
lowered = tvm.relay.build(mod, target=TARGET, runtime=runtime, params=params)

temp_dir = tvm.contrib.utils.tempdir()
project = tvm.micro.generate_project(template_project_dir, lowered, temp_dir / "project")
project.build()
with tvm.micro.Session(project.transport()) as session:
with _make_session(temp_dir, lowered) as sess:
graph_mod = tvm.micro.create_local_graph_executor(
lowered.get_graph_json(), session.get_system_lib(), session.device
lowered.get_graph_json(), sess.get_system_lib(), sess.device
)
graph_mod.set_input(**lowered.get_params())
graph_mod.run(**inputs)
Expand All @@ -425,11 +422,9 @@ def test_autotune():
lowered_tuned = tvm.relay.build(mod, target=target, runtime=runtime, params=params)

temp_dir = tvm.contrib.utils.tempdir()
project = tvm.micro.generate_project(template_project_dir, lowered_tuned, temp_dir / "project")
project.build()
with tvm.micro.Session(project.transport()) as session:
with _make_session(temp_dir, lowered_tuned) as sess:
graph_mod = tvm.micro.create_local_graph_executor(
lowered_tuned.get_graph_json(), session.get_system_lib(), session.device
lowered_tuned.get_graph_json(), sess.get_system_lib(), sess.device
)
graph_mod.set_input(**lowered_tuned.get_params())
graph_mod.run(**inputs)
Expand Down