Skip to content

Commit

Permalink
[AOT][DeviceAPI] Add Hooks for Activate/Deactivate/Open/Close
Browse files Browse the repository at this point in the history
This adds the relevant hooks into their starting places in the code
generation. As per the [C Device API
RFC](https://github.com/apache/tvm-rfcs/blob/main/rfcs/0031-devices-api.md)
  • Loading branch information
Mousius committed Nov 12, 2021
1 parent f63a0c8 commit fb75858
Show file tree
Hide file tree
Showing 8 changed files with 211 additions and 15 deletions.
9 changes: 8 additions & 1 deletion apps/microtvm/ethosu/include/tvm_ethosu_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,14 @@
#include <stddef.h>
#include <stdint.h>

int32_t TVMEthosULaunch(struct ethosu_driver* resource_handle, void* cms_data, size_t cms_data_size,
typedef void tvm_device_ethos_u_t;

int32_t TVMEthosULaunch(tvm_device_ethos_u_t* resource_handle, void* cms_data, size_t cms_data_size,
uint64_t* base_addrs, size_t* base_addrs_size, int num_tensors);

int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context);
int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context);
int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context);
int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context);

#endif // TVM_RUNTIME_CONTRIB_ETHOSU_ETHOSU_RUNTIME_H_
8 changes: 7 additions & 1 deletion apps/microtvm/ethosu/src/tvm_ethosu_runtime.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@

#include <ethosu_driver.h>

int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t cms_data_size,
int32_t TVMEthosULaunch(tvm_device_ethos_u_t* context, void* cms_data, size_t cms_data_size,
uint64_t* base_addrs, size_t* base_addrs_size, int num_tensors) {
struct ethosu_driver* driver = (struct ethosu_driver*)context;
int32_t result =
ethosu_invoke(driver, cms_data, cms_data_size, base_addrs, base_addrs_size, num_tensors);

Expand All @@ -32,3 +33,8 @@ int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t cms
}
return 0;
}

int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context) {}
int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context) {}
int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context) {}
int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context) {}
7 changes: 6 additions & 1 deletion python/tvm/relay/backend/executor_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ class AOTExecutorFactoryModule(ExecutorFactoryModule):
----------
ir_mod : :py:class:`~tvm.IRModule`
The IR module to build.
built_ir_mods : dict[Target, IRModule]
The IR modules built per Target.
target : tvm.Target
The Target used to build this module.
libmod : tvm.Module
Expand All @@ -89,8 +91,11 @@ class AOTExecutorFactoryModule(ExecutorFactoryModule):
List of devices used in the module
"""

def __init__(self, ir_mod, target, libmod, libmod_name, params, function_metadata, devices):
def __init__(
self, ir_mod, built_ir_mods, target, libmod, libmod_name, params, function_metadata, devices
):
self.ir_mod = ir_mod
self.built_ir_mods = built_ir_mods
self.target = target
self.lib = libmod
self.libmod_name = libmod_name
Expand Down
8 changes: 7 additions & 1 deletion python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(self):
self._get_params_func = self.mod["get_params"]
self._get_function_metadata = self.mod["get_function_metadata"]
self._get_devices = self.mod["get_devices"]
self._get_irmodule = self.mod["get_irmodule"]

def build(
self, mod, target=None, target_host=None, params=None, executor="graph", mod_name=None
Expand Down Expand Up @@ -244,6 +245,10 @@ def get_params(self):
ret[key] = value.data
return ret

def get_irmodule(self):
"""Returns the Target IRModule's from code generation"""
return self._get_irmodule()


@register_func("tvm.relay.module_export_library")
def _module_export(module, file_name): # fcompile, addons, kwargs?
Expand Down Expand Up @@ -364,10 +369,11 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"
)
func_metadata = bld_mod.get_function_metadata()
devices = bld_mod.get_devices()
final_ir_mods = bld_mod.get_irmodule()

if executor == "aot":
executor_factory = _executor_factory.AOTExecutorFactoryModule(
ir_mod, target, runtime_mod, mod_name, params, func_metadata, devices
ir_mod, final_ir_mods, target, runtime_mod, mod_name, params, func_metadata, devices
)
elif executor == "graph":
executor_factory = _executor_factory.GraphExecutorFactoryModule(
Expand Down
73 changes: 64 additions & 9 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -348,11 +348,19 @@ class AOTExecutorCodegen : public MixedModeVisitor {
GlobalVar global_var = call_lowered_props.lowered_func;
bool has_c_device_api_context = device_contexts_.count(global_var) != 0;
if (has_c_device_api_context) {
tir::Var context = device_contexts_.Get(global_var).value();
args.push_back(device_contexts_[global_var]);
}

tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args));
create_func_call_stmts.push_back(func_call);
tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args));
create_func_call_stmts.push_back(tir::SeqStmt({
GenerateDeviceHook(context, "Open"),
func_call,
GenerateDeviceHook(context, "Close"),
}));
} else {
tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args));
create_func_call_stmts.push_back(func_call);
}

tir::Stmt body = tir::SeqStmt(create_func_call_stmts);
stmts_.push_back(body);
Expand Down Expand Up @@ -416,7 +424,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
device_context_var = (*pair).second;
} else {
main_signature_.push_back(device_context_var);
devices_.push_back(context_name);
devices_.Set(context_name, device_context_var);
target_contexts.Set(target_kind.value(), device_context_var);
}

Expand All @@ -425,6 +433,44 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}
}

/**
* \brief Generates a call to a given hook for all Devices found for C Device API
* \param Name of hook to generate statements for
* \return Statement with function calls for each device
*/
tir::Stmt GenerateAllDeviceHook(const String& hook) {
std::vector<tir::Stmt> device_activations;
for (const auto& it : devices_) {
const String& device_name = it.first;
const tir::Var& context = it.second;
Array<String> sections = {"Device", device_name, hook};
String device_activation = ToCFunctionStyle(PrefixName(sections));

tir::Evaluate device_hook(tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
{tvm::tir::StringImm(device_activation), context}));
device_activations.push_back(device_hook);
}
return tir::SeqStmt(device_activations);
}

/**
* \brief Generates a call to a given hook for a single Device function
* \param Var Device context to call hook on
* \param Name of hook to generate statements for
* \return Statement with function call to Device API
*/
tir::Stmt GenerateDeviceHook(const tir::Var& context, const String& hook) {
const auto& it = std::find_if(std::begin(devices_), std::end(devices_), [&](const auto& it) {
return it.second->name_hint == context->name_hint;
});
const String& device_name = (*it).first;
Array<String> sections = {"Device", device_name, hook};
String device_hook = ToCFunctionStyle(PrefixName(sections));

return tir::Evaluate(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
{tvm::tir::StringImm(device_hook), context}));
}

/*!
* Utility function to string together different arguments
*/
Expand Down Expand Up @@ -586,8 +632,12 @@ class AOTExecutorCodegen : public MixedModeVisitor {
dict_attrs.Set("global_symbol", run_func_name);
dict_attrs.Set("runner_function", Bool(true));

tir::Stmt device_activations = GenerateAllDeviceHook("Activate");
tir::Stmt device_deactivations = GenerateAllDeviceHook("Deactivate");
tir::Stmt final_body = tir::SeqStmt({device_activations, body, device_deactivations});

// Make the PrimFunc
return tir::PrimFunc(main_signature_, body, VoidType(), Map<tir::Var, tir::Buffer>(),
return tir::PrimFunc(main_signature_, final_body, VoidType(), Map<tir::Var, tir::Buffer>(),
DictAttrs(dict_attrs));
}

Expand All @@ -596,8 +646,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
runtime::Module* mod_;
/*! \brief list of input expressions (i.e., variable passed by the user) */
std::vector<Var> input_vars_;
/*! \brief list of device contexts used */
std::vector<String> devices_;
/*! \brief map of device contexts variables */
Map<String, tir::Var> devices_;
/*! \brief map of GlobalVars to C Device API contexts */
Map<GlobalVar, tir::Var> device_contexts_;
/*! \brief input and output variables belonging to the main function signature */
Expand Down Expand Up @@ -778,7 +828,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
std::transform(input_vars_.begin(), input_vars_.end(), input_var_names.begin(),
[](Var input_var) -> String { return input_var->name_hint(); });

ret.metadata = runtime::Metadata(input_var_names, devices_, return_sid_.size(),
ret.metadata = runtime::Metadata(input_var_names, ListDevices(), return_sid_.size(),
runtime::kTvmExecutorAot, mod_name);
return ret;
}
Expand All @@ -787,7 +837,12 @@ class AOTExecutorCodegen : public MixedModeVisitor {
* \brief Get list of devices found
* \return List of devices
*/
Array<String> ListDevices() { return devices_; }
Array<String> ListDevices() {
std::vector<String> device_names(devices_.size());
std::transform(devices_.begin(), devices_.end(), device_names.begin(),
[](const auto& it) -> String { return it.first; });
return device_names;
}
}; // namespace backend

class AOTExecutorCodegenModule : public runtime::ModuleNode {
Expand Down
8 changes: 7 additions & 1 deletion src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@

#include <ethosu_driver.h>

int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t cms_data_size,
int32_t TVMEthosULaunch(tvm_device_ethos_u_t* context, void* cms_data, size_t cms_data_size,
uint64_t* base_addrs, size_t* base_addrs_size, int num_tensors) {
struct ethosu_driver* driver = (struct ethosu_driver*)context;
int32_t result =
ethosu_invoke(driver, cms_data, cms_data_size, base_addrs, base_addrs_size, num_tensors);

Expand All @@ -32,3 +33,8 @@ int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t cms
}
return 0;
}

int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context) {}
int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context) {}
int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context) {}
int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context) {}
9 changes: 8 additions & 1 deletion src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,14 @@
#include <stddef.h>
#include <stdint.h>

int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t cms_data_size,
typedef void tvm_device_ethos_u_t;

int32_t TVMEthosULaunch(tvm_device_ethos_u_t* resource_handle, void* cms_data, size_t cms_data_size,
uint64_t* base_addrs, size_t* base_addrs_size, int num_tensors);

int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context);
int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context);
int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context);
int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context);

#endif // TVM_RUNTIME_CONTRIB_ETHOSU_BARE_METAL_TVM_ETHOSU_RUNTIME_H_
104 changes: 104 additions & 0 deletions tests/python/relay/aot/test_crt_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from collections import OrderedDict
import sys
import re

import numpy as np
import pytest
Expand Down Expand Up @@ -693,5 +694,108 @@ def @main(%data: Tensor[(1, 4, 4, 4), float32], %weight: Tensor[(4, 4, 3, 3), fl
assert source.count("TVMBackendAllocWorkspace") == 3


def test_device_api_hooks():
"""Check for Device API hooks"""

# Ideally we should have a sample Target registered here
# but we're going to re-use this for now
pytest.importorskip("ethosu.vela")
import tensorflow as tf
import tflite.Model

from tests.python.contrib.test_ethosu import infra
from tvm.relay.op.contrib.ethosu import partition_for_ethosu

def create_tflite_graph():
tf.config.run_functions_eagerly(True)

class Model(tf.Module):
@tf.function
def tf_function(self, x):
return tf.nn.max_pool(x, [1, 2], [1, 2], "SAME")

def representative_dataset():
for _ in range(100):
data = np.random.rand(*tuple([1, 3, 4, 3]))
yield [data.astype(np.float32)]

model = Model()
concrete_func = model.tf_function.get_concrete_function(
tf.TensorSpec([1, 3, 4, 3], dtype=tf.float32)
)

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()
return tflite_model

tflite_graph = create_tflite_graph()
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)

relay_module, params = relay.frontend.from_tflite(
tflite_model,
shape_dict={"x": [1, 3, 4, 3]},
dtype_dict={"x": "int8"},
)
mod = partition_for_ethosu(relay_module, params)

# Generate reference data
input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)

compiled_models = infra.build_source(
mod,
input_data,
output_data,
)
main_ir_module = list(compiled_models[0].executor_factory.built_ir_mods.values())[0]
main_func = main_ir_module["run_model"]

# Activate Device
assert (
str(main_func.body[0][0].value)
== "@tir.call_extern("
+ '"TVMDeviceEthosUActivate",'
+ " device_context_ethos_u: handle,"
+ " dtype=int32)"
)
# Open Device
assert (
str(main_func.body[1].body.body[0][0][0].value)
== "@tir.call_extern("
+ '"TVMDeviceEthosUOpen",'
+ " device_context_ethos_u: handle,"
+ " dtype=int32)"
)
# Device Call
assert (
str(main_func.body[1].body.body[0][0][1].value)
== "@tir.call_extern("
+ '"tvmgen_default_ethos_u_main_0",'
+ " input: handle, output: handle,"
+ " device_context_ethos_u: handle,"
+ " dtype=int32)"
)
# Close Device
assert (
str(main_func.body[1].body.body[0][0][2].value)
== "@tir.call_extern("
+ '"TVMDeviceEthosUClose",'
+ " device_context_ethos_u: handle,"
+ " dtype=int32)"
)
# Deactivate Device
assert (
str(main_func.body[2][0].value)
== "@tir.call_extern("
+ '"TVMDeviceEthosUDeactivate",'
+ " device_context_ethos_u: handle,"
+ " dtype=int32)"
)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit fb75858

Please sign in to comment.