Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[2/3][AOT][DeviceAPI] Add Hooks for Activate/Deactivate/Open/Close #9500

Merged
merged 2 commits into from
Nov 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion apps/microtvm/ethosu/include/tvm_ethosu_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,14 @@
#include <stddef.h>
#include <stdint.h>

int32_t TVMEthosULaunch(struct ethosu_driver* resource_handle, void* cms_data, size_t cms_data_size,
typedef void tvm_device_ethos_u_t;

int32_t TVMEthosULaunch(tvm_device_ethos_u_t* resource_handle, void* cms_data, size_t cms_data_size,
uint64_t* base_addrs, size_t* base_addrs_size, int num_tensors);

int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context);
int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context);
int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context);
int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context);

#endif // TVM_RUNTIME_CONTRIB_ETHOSU_ETHOSU_RUNTIME_H_
8 changes: 7 additions & 1 deletion apps/microtvm/ethosu/src/tvm_ethosu_runtime.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@

#include <ethosu_driver.h>

int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t cms_data_size,
int32_t TVMEthosULaunch(tvm_device_ethos_u_t* context, void* cms_data, size_t cms_data_size,
uint64_t* base_addrs, size_t* base_addrs_size, int num_tensors) {
struct ethosu_driver* driver = (struct ethosu_driver*)context;
int32_t result =
ethosu_invoke(driver, cms_data, cms_data_size, base_addrs, base_addrs_size, num_tensors);

Expand All @@ -32,3 +33,8 @@ int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t cms
}
return 0;
}

int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context) {}
int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context) {}
int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context) {}
int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context) {}
24 changes: 22 additions & 2 deletions python/tvm/relay/backend/executor_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ class AOTExecutorFactoryModule(ExecutorFactoryModule):
----------
ir_mod : :py:class:`~tvm.IRModule`
The IR module to build.
lowered_ir_mods : dict[Target, IRModule]
The IR modules lowered per Target.
target : tvm.Target
The Target used to build this module.
libmod : tvm.Module
Expand All @@ -89,8 +91,19 @@ class AOTExecutorFactoryModule(ExecutorFactoryModule):
List of devices used in the module
"""

def __init__(self, ir_mod, target, libmod, libmod_name, params, function_metadata, devices):
def __init__(
self,
ir_mod,
lowered_ir_mods,
target,
libmod,
libmod_name,
params,
function_metadata,
devices,
):
self.ir_mod = ir_mod
self.lowered_ir_mods = lowered_ir_mods
self.target = target
self.lib = libmod
self.libmod_name = libmod_name
Expand Down Expand Up @@ -136,7 +149,14 @@ class GraphExecutorFactoryModule(ExecutorFactoryModule):
"""

def __init__(
self, ir_mod, target, graph_json_str, libmod, libmod_name, params, function_metadata
self,
ir_mod,
target,
graph_json_str,
libmod,
libmod_name,
params,
function_metadata,
):
assert isinstance(graph_json_str, string_types)
fcreate = get_global_func("tvm.graph_executor_factory.create")
Expand Down
15 changes: 14 additions & 1 deletion python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(self):
self._get_params_func = self.mod["get_params"]
self._get_function_metadata = self.mod["get_function_metadata"]
self._get_devices = self.mod["get_devices"]
self._get_irmodule = self.mod["get_irmodule"]
Mousius marked this conversation as resolved.
Show resolved Hide resolved

def build(
self, mod, target=None, target_host=None, params=None, executor="graph", mod_name=None
Expand Down Expand Up @@ -249,6 +250,10 @@ def get_params(self):
ret[key] = value.data
return ret

def get_irmodule(self):
"""Returns the Target IRModule's post-lowering"""
return self._get_irmodule()


@register_func("tvm.relay.module_export_library")
def _module_export(module, file_name): # fcompile, addons, kwargs?
Expand Down Expand Up @@ -376,10 +381,18 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"
)
func_metadata = bld_mod.get_function_metadata()
devices = bld_mod.get_devices()
lowered_ir_mods = bld_mod.get_irmodule()

if executor == "aot":
executor_factory = _executor_factory.AOTExecutorFactoryModule(
ir_mod, target, runtime_mod, mod_name, params, func_metadata, devices
ir_mod,
lowered_ir_mods,
target,
runtime_mod,
mod_name,
params,
func_metadata,
devices,
)
elif executor == "graph":
executor_factory = _executor_factory.GraphExecutorFactoryModule(
Expand Down
73 changes: 64 additions & 9 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -349,11 +349,19 @@ class AOTExecutorCodegen : public MixedModeVisitor {
GlobalVar global_var = call_lowered_props.lowered_func;
bool has_c_device_api_context = device_contexts_.count(global_var) != 0;
if (has_c_device_api_context) {
tir::Var context = device_contexts_.Get(global_var).value();
args.push_back(device_contexts_[global_var]);
}

tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args));
create_func_call_stmts.push_back(func_call);
tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args));
create_func_call_stmts.push_back(tir::SeqStmt({
GenerateDeviceHook(context, "Open"),
func_call,
GenerateDeviceHook(context, "Close"),
}));
} else {
tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args));
create_func_call_stmts.push_back(func_call);
}

tir::Stmt body = tir::SeqStmt(create_func_call_stmts);
stmts_.push_back(body);
Expand Down Expand Up @@ -417,7 +425,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
device_context_var = (*pair).second;
} else {
main_signature_.push_back(device_context_var);
devices_.push_back(context_name);
devices_.Set(context_name, device_context_var);
target_contexts.Set(target_kind.value(), device_context_var);
}

Expand All @@ -426,6 +434,44 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}
}

/**
* \brief Generates a call to a given hook for all Devices found for C Device API
* \param Name of hook to generate statements for
* \return Statement with function calls for each device
*/
tir::Stmt GenerateAllDeviceHook(const String& hook) {
std::vector<tir::Stmt> device_hooks;
for (const auto& it : devices_) {
const String& device_name = it.first;
const tir::Var& context = it.second;
Array<String> sections = {"Device", device_name, hook};
String device_hook_name = ToCFunctionStyle(PrefixName(sections));

tir::Evaluate device_hook(tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
{tvm::tir::StringImm(device_hook_name), context}));
device_hooks.push_back(device_hook);
}
return tir::SeqStmt(device_hooks);
}

/**
* \brief Generates a call to a given hook for a single Device function
* \param Var Device context to call hook on
* \param Name of hook to generate statements for
* \return Statement with function call to Device API
*/
tir::Stmt GenerateDeviceHook(const tir::Var& context, const String& hook) {
const auto& it = std::find_if(std::begin(devices_), std::end(devices_), [&](const auto& it) {
return it.second->name_hint == context->name_hint;
});
const String& device_name = (*it).first;
Array<String> sections = {"Device", device_name, hook};
String device_hook = ToCFunctionStyle(PrefixName(sections));

return tir::Evaluate(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
{tvm::tir::StringImm(device_hook), context}));
}

/*!
* Utility function to string together different arguments
*/
Expand Down Expand Up @@ -587,8 +633,12 @@ class AOTExecutorCodegen : public MixedModeVisitor {
dict_attrs.Set("global_symbol", run_func_name);
dict_attrs.Set("runner_function", Bool(true));

tir::Stmt device_activations = GenerateAllDeviceHook("Activate");
tir::Stmt device_deactivations = GenerateAllDeviceHook("Deactivate");
tir::Stmt final_body = tir::SeqStmt({device_activations, body, device_deactivations});
Mousius marked this conversation as resolved.
Show resolved Hide resolved

// Make the PrimFunc
return tir::PrimFunc(main_signature_, body, VoidType(), Map<tir::Var, tir::Buffer>(),
return tir::PrimFunc(main_signature_, final_body, VoidType(), Map<tir::Var, tir::Buffer>(),
DictAttrs(dict_attrs));
}

Expand All @@ -597,8 +647,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
runtime::Module* mod_;
/*! \brief list of input expressions (i.e., variable passed by the user) */
std::vector<Var> input_vars_;
/*! \brief list of device contexts used */
std::vector<String> devices_;
/*! \brief map of device contexts variables */
Map<String, tir::Var> devices_;
Mousius marked this conversation as resolved.
Show resolved Hide resolved
/*! \brief map of GlobalVars to C Device API contexts */
Map<GlobalVar, tir::Var> device_contexts_;
/*! \brief input and output variables belonging to the main function signature */
Expand Down Expand Up @@ -779,7 +829,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
std::transform(input_vars_.begin(), input_vars_.end(), input_var_names.begin(),
[](Var input_var) -> String { return input_var->name_hint(); });

ret.metadata = runtime::Metadata(input_var_names, devices_, return_sid_.size(),
ret.metadata = runtime::Metadata(input_var_names, ListDevices(), return_sid_.size(),
runtime::kTvmExecutorAot, mod_name);
return ret;
}
Expand All @@ -788,7 +838,12 @@ class AOTExecutorCodegen : public MixedModeVisitor {
* \brief Get list of devices found
* \return List of devices
*/
Array<String> ListDevices() { return devices_; }
Array<String> ListDevices() {
std::vector<String> device_names(devices_.size());
std::transform(devices_.begin(), devices_.end(), device_names.begin(),
[](const auto& it) -> String { return it.first; });
return device_names;
}
}; // namespace backend

class AOTExecutorCodegenModule : public runtime::ModuleNode {
Expand Down
8 changes: 7 additions & 1 deletion src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@

#include <ethosu_driver.h>

int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t cms_data_size,
int32_t TVMEthosULaunch(tvm_device_ethos_u_t* context, void* cms_data, size_t cms_data_size,
uint64_t* base_addrs, size_t* base_addrs_size, int num_tensors) {
struct ethosu_driver* driver = (struct ethosu_driver*)context;
int32_t result =
ethosu_invoke(driver, cms_data, cms_data_size, base_addrs, base_addrs_size, num_tensors);

Expand All @@ -32,3 +33,8 @@ int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t cms
}
return 0;
}

int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context) {}
int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context) {}
int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context) {}
int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context) {}
9 changes: 8 additions & 1 deletion src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,14 @@
#include <stddef.h>
#include <stdint.h>

int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t cms_data_size,
typedef void tvm_device_ethos_u_t;

int32_t TVMEthosULaunch(tvm_device_ethos_u_t* resource_handle, void* cms_data, size_t cms_data_size,
uint64_t* base_addrs, size_t* base_addrs_size, int num_tensors);

int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context);
int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context);
int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context);
int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context);

#endif // TVM_RUNTIME_CONTRIB_ETHOSU_BARE_METAL_TVM_ETHOSU_RUNTIME_H_
104 changes: 104 additions & 0 deletions tests/python/relay/aot/test_crt_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from collections import OrderedDict
import sys
import re

import numpy as np
import pytest
Expand Down Expand Up @@ -693,5 +694,108 @@ def @main(%data: Tensor[(1, 4, 4, 4), float32], %weight: Tensor[(4, 4, 3, 3), fl
assert source.count("TVMBackendAllocWorkspace") == 3


def test_device_api_hooks():
"""Check for Device API hooks"""

# Ideally we should have a sample Target registered here
# but we're going to re-use this for now
pytest.importorskip("ethosu.vela")
import tensorflow as tf
import tflite.Model

from tests.python.contrib.test_ethosu import infra
from tvm.relay.op.contrib.ethosu import partition_for_ethosu

def create_tflite_graph():
tf.config.run_functions_eagerly(True)

class Model(tf.Module):
@tf.function
def tf_function(self, x):
return tf.nn.max_pool(x, [1, 2], [1, 2], "SAME")

def representative_dataset():
for _ in range(100):
data = np.random.rand(*tuple([1, 3, 4, 3]))
Copy link
Contributor

@gromero gromero Nov 16, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Mousius Could the unpack here be directly from the list? Like: ...rand(*[1, 3, 4, 3])?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, nice spot @gromero, by the same token I don't think we even need to unpack the list? This is the same as rand(1, 3, 4,3), I'll clear this up 😸

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Mousius Hi! Yeah, I thought of that too (avoiding the unpack too), however I assumed you would like to keep it as [1, 3, 4, 3] just to be "more explicit" by keeping the dimensions written in a form as you pass later for example to tf.TensorSpec. Either way looks fine to me, just the form with "tuple" seems superfluous :)

yield [data.astype(np.float32)]

model = Model()
concrete_func = model.tf_function.get_concrete_function(
tf.TensorSpec([1, 3, 4, 3], dtype=tf.float32)
)

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()
return tflite_model

tflite_graph = create_tflite_graph()
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)

relay_module, params = relay.frontend.from_tflite(
tflite_model,
shape_dict={"x": [1, 3, 4, 3]},
dtype_dict={"x": "int8"},
)
mod = partition_for_ethosu(relay_module, params)

# Generate reference data
input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)

compiled_models = infra.build_source(
mod,
input_data,
output_data,
)
main_ir_module = list(compiled_models[0].executor_factory.lowered_ir_mods.values())[0]
main_func = main_ir_module["run_model"]

# Activate Device
assert (
str(main_func.body[0][0].value)
== "@tir.call_extern("
+ '"TVMDeviceEthosUActivate",'
+ " device_context_ethos_u: handle,"
+ " dtype=int32)"
)
# Open Device
assert (
str(main_func.body[1].body.body[0][0][0].value)
== "@tir.call_extern("
+ '"TVMDeviceEthosUOpen",'
+ " device_context_ethos_u: handle,"
+ " dtype=int32)"
)
# Device Call
assert (
str(main_func.body[1].body.body[0][0][1].value)
== "@tir.call_extern("
+ '"tvmgen_default_ethos_u_main_0",'
+ " input: handle, output: handle,"
+ " device_context_ethos_u: handle,"
+ " dtype=int32)"
)
# Close Device
assert (
str(main_func.body[1].body.body[0][0][2].value)
== "@tir.call_extern("
+ '"TVMDeviceEthosUClose",'
+ " device_context_ethos_u: handle,"
+ " dtype=int32)"
)
# Deactivate Device
assert (
str(main_func.body[2][0].value)
== "@tir.call_extern("
+ '"TVMDeviceEthosUDeactivate",'
+ " device_context_ethos_u: handle,"
+ " dtype=int32)"
)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))