Skip to content

Commit

Permalink
Refactor the compile engine into a cleaner interface.
Browse files Browse the repository at this point in the history
Duplicate the CompileEngine interface.

Refactor the graph_runtime_codegen to invoke the new LowerTE pass

More changes

Things appear to be working

Some tracing to get Relay code to flow through too.

Disable some assertions as exp.

Tweak printing for now

Fix a few bugs: (#13)

1. Don't add relay main function to list of lowered TIR functions
2. Don't skip visiting call to relay function in graph runtime codegen

Remove debug prints.

Start refactoring

Split out shared data structures

Fix implicit duplicate decl of IsDynamic

Clean up handling of name + global prim fn

Clean up the code and debug issue introduced by previous hack

Clean up the debugging

Do C++ lint clean up

Update src/relay/backend/graph_executor_codegen.cc

Co-authored-by: Chris Sullivan <[email protected]>

Clean up handling of external functions

Add more error messages

More clean up

Update src/runtime/graph_executor/graph_executor.cc

Co-authored-by: Chris Sullivan <[email protected]>

Update src/runtime/graph_executor/graph_executor.cc

Co-authored-by: Chris Sullivan <[email protected]>

Update src/relay/backend/te_compiler.h

Co-authored-by: Haichen Shen <[email protected]>

Update src/relay/backend/te_compiler.h

Co-authored-by: Haichen Shen <[email protected]>

Fix

CR

More CR

Format

Fix lowering path for C++

Fix tests

Remove uncessary change

Clean up a few more things

CI fix

Fix the default context

Fix

Fix broken test cases

Update

Fix

WIP

Clean up storage data structures

WIP

WIP

Fix build errors

Remove TVMLower

Fix lint

Lint again

fix black

Move UpdateMainWorkspaceSize into te_compiler.cc

Fix link errors

Formatting

Change UpdateMainWorkspaceSize to return Map<String, FunctionInfo>

Workaround for GCC 5 error caused by enums in maps (GCC 5 is on i386 CI)

Testing how functions should be named

Lint

Change how function metadata is updated

Attempt to update aot_executor_codegen to use new StaticMemoryPlan instead of storage_device_map

Pass memory plan through LowerTE into UpdateMainWorkspaceSize so that we don't need to run GraphPlanMemory an extra time

Fix return in UpdateMainWorkspaceSize

Lint

Try to fix UpdateMainWorkspaceSize

Fix construction of static memory plan

Clean up code while debugging

Adding UpdateWorkspaceSize back

Add closure + call to UpdateFunctionMetadata (WIP)

UpdateFunctionMetadata builds; weird error with device ctx map though. Not sure if it came from this change or something else

Add some debugging of UpdateMainWorkspaceSize

Starting to move UpdateFunctionMetadata call to use process_fn infra

UWhat target should be passed to UpdateFunctionMetadata?

UpdateFunctionMetadata is not workinggg

Added some comments about UpdateFunctionMetadata for Jared

Fix the creation of function metadata

Try another stab at cleaning up the information

Fix

Port StorageInfo and StaticMemoryPlan data structure (apache#8297)

Restoring reshape opt

Fix tests

Caught a nasty typo from Lily, Map::Set does not mutate

Format

Disable stupid Google style warning
  • Loading branch information
jroesch committed Jun 29, 2021
1 parent 61a6ea1 commit f82e60d
Show file tree
Hide file tree
Showing 28 changed files with 2,289 additions and 1,161 deletions.
12 changes: 12 additions & 0 deletions include/tvm/relay/attrs/annotation.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,18 @@ struct CompilerAttrs : public tvm::AttrsNode<CompilerAttrs> {
}
};

/*!
* \brief Options for the operators used to annotate a compiler.
*/
struct TIRCallAttrs : public tvm::AttrsNode<TIRCallAttrs> {
/*! \brief A 3rd party compiler for code generation. */
Map<String, ObjectRef> metadata;

TVM_DECLARE_ATTRS(TIRCallAttrs, "relay.attrs.TIRCallAttrs") {
TVM_ATTR_FIELD(metadata).describe("Metadata attached to the TIR function call.");
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_ANNOTATION_H_
1 change: 1 addition & 0 deletions python/tvm/auto_scheduler/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ def add_workload_input_names(self, workload_key, input_names):
@tvm._ffi.register_func("auto_scheduler.enter_layout_rewrite")
def enter_layout_rewrite():
"""Enter layout rewrite tracing environment"""
# import pdb; pdb.set_trace()
env = TracingEnvironment(TracingMode.PREPARE_LAYOUT_REWRITE)
env.__enter__()

Expand Down
1 change: 1 addition & 0 deletions python/tvm/micro/model_library_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def _build_function_memory_map(function_metadata):
2.) A global memory requirement if all functions are executed sequentially
"""
device_max_workspace = dict()
print("TOTAL FUNCTION METADATA: ", function_metadata)
main_func_metadata = function_metadata[MAIN_FUNC_NAME_STR]
num_targets = len(main_func_metadata.workspace_sizes.items())
func_entries = []
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/backend/compile_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def dump(self):
res += "------------------------------------\n"
res += "target={}\n".format(k.target)
res += "use_count={}\n".format(v.use_count)
res += "func_name={}\n".format(v.cached_func.func_name)
res += "func_name={}\n".format(v.cached_func.prim_fn_var.name_hint)
res += "----relay function----\n"
res += k.source_func.astext() + "\n"
res += "----tir function----- \n"
Expand All @@ -444,7 +444,7 @@ def dump(self):
res += "------------------------------------\n"
res += "target={}\n".format(k.target)
res += "use_count={}\n".format(v.use_count)
res += "func_name={}\n".format(v.cached_func.func_name)
res += "func_name={}\n".format(v.cached_func.prim_fn_var.name_hint)
res += "----relay function----\n"
res += k.source_func.astext() + "\n"
res += "----tir function----- \n"
Expand Down
17 changes: 16 additions & 1 deletion python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import tvm._ffi
from tvm._ffi import base as _base
from tvm.runtime import NDArray, ndarray as _nd
from tvm.ir import RelayExpr, GlobalVar
from tvm.ir import RelayExpr, GlobalVar, Node

from .base import RelayNode
from . import _ffi_api
Expand Down Expand Up @@ -538,3 +538,18 @@ def bind(expr, binds):
The expression or function after binding.
"""
return _ffi_api.Bind(expr, binds)


@tvm._ffi.register_object("relay.StorageInfo")
class StorageInfo(Node):
@property
def storage_ids(self):
return _ffi_api.StorageInfoStorageIds(self)

@property
def device_types(self):
return _ffi_api.StorageInfoDeviceTypes(self)

@property
def storage_sizes(self):
return _ffi_api.StorageInfoStorageSizes(self)
10 changes: 7 additions & 3 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -437,14 +437,18 @@ std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target
}

if (target->kind->device_type == kDLCPU && target_host == target) {
ICHECK(mdevice->functions.empty()) << "No device code should be generated when target "
<< "and host_target are both llvm target."
<< "\n";
// TODO(@jroesch): This check is no longer true we need to figure out if we care about this.
// We need to relax this check for just TIR functions.
// ICHECK(mdevice->functions.empty()) << "No device code should be generated when target "
// << "and host_target are both llvm target."
// << "\n";
}

return {mhost, mdevice};
}

// Can we make this take one annotated IRModule?
//
// Build for heterogeneous execution.
runtime::Module build(const Map<Target, IRModule>& inputs_arg, const Target& target_host_arg) {
auto pass_ctx = transform::PassContext::Current();
Expand Down
100 changes: 63 additions & 37 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,26 +42,28 @@

namespace tvm {
namespace relay {
// TODO(@jroesch, @csullivan): declare directly elsewhere
backend::StaticMemoryPlan GraphPlanMemory(const Function& func);
namespace backend {

using IntegerArray = Array<Integer>;
using TargetsMap = std::unordered_map<int, Target>;

class AotReturnSidVisitor : public ExprVisitor {
public:
explicit AotReturnSidVisitor(Map<Expr, Array<IntegerArray>> storage_device_map)
: storage_device_map_{storage_device_map}, return_sid_{-1} {}
explicit AotReturnSidVisitor(Map<Expr, StorageInfo> storage_info_map)
: storage_info_map_{storage_info_map}, return_sid_{-1} {}

IntegerArray FindReturnSid(Function func) {
std::vector<int64_t> FindReturnSid(Function func) {
VisitExpr(func->body);
return return_sid_;
}

protected:
void AssignReturnSid(Expr e) {
auto iter = storage_device_map_.find(e);
if (iter != storage_device_map_.end()) {
return_sid_ = (*iter).second[0];
auto iter = storage_info_map_.find(e);
if (iter != storage_info_map_.end()) {
return_sid_ = (*iter).second->storage_ids;
}
}

Expand All @@ -88,8 +90,8 @@ class AotReturnSidVisitor : public ExprVisitor {
}

private:
Map<Expr, Array<IntegerArray>> storage_device_map_;
IntegerArray return_sid_;
Map<Expr, StorageInfo> storage_info_map_;
std::vector<int64_t> return_sid_;
};

/*! \brief Code generator for AOT executor */
Expand Down Expand Up @@ -120,15 +122,16 @@ class AOTExecutorCodegen : public ExprVisitor {
* \brief Return a vector of variables that represents the sids for the given Relay Expr
*/
std::vector<tir::Var> PackSid(Expr expr) {
Array<IntegerArray> sids = storage_device_map_[expr];
Map<Expr, StorageInfo> storage_info_map = memory_plan_->expr_to_storage_info;
StorageInfo storage_info = storage_info_map[expr];
// std::vector<int64_t> sids = storage_device_map_[expr];
std::vector<tir::Var> sid_vars;

// Note that an expression can have multiple sids associated with it
// e.g., returning multiple values from a function
for (const auto& sid : sids[0]) {
for (int64_t sid : storage_info->storage_ids) {
// Determine if an sid is an output buffer
int sid_int = static_cast<int>((sid.as<IntImmNode>())->value);
auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid_int);
auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid);
if (output_iter != return_sid_.end()) {
int output_index = std::distance(return_sid_.begin(), output_iter);
sid_vars.push_back(main_signature_[input_vars_.size() + output_index]);
Expand Down Expand Up @@ -346,6 +349,7 @@ class AOTExecutorCodegen : public ExprVisitor {
fi_node->tir_primfuncs.Set(primfunc_target, primfunc);
fi_node->relay_primfuncs.Set(primfunc_target, relay_func);
}
std::cout << "Update function metadata called" << std::endl;
function_metadata_.Set(cfunc->func_name, FunctionInfo(fi_node));
}

Expand Down Expand Up @@ -385,13 +389,14 @@ class AOTExecutorCodegen : public ExprVisitor {
UpdateConstants(func, &params_);

// Generate the TIR function call
CreateFuncCall(GetRef<Call>(op), ext_func->func_name);
CreateFuncCall(GetRef<Call>(op), ext_func->prim_fn_var->name_hint);
return;
}

ICHECK_GE(storage_device_map_.count(expr), 0);
auto& device_type = storage_device_map_[expr][1];
auto call_dev_type = device_type[0]->value;
Map<Expr, StorageInfo> storage_info_map = memory_plan_->expr_to_storage_info;
ICHECK_GE(storage_info_map.count(expr), 0);
auto& device_type = storage_info_map[expr]->device_types;
auto call_dev_type = device_type[0]; // TODO(@electriclilies): what is happening here
// Normal Relay Function
if (targets_.size() == 1) {
// homogeneous execution.
Expand Down Expand Up @@ -420,42 +425,49 @@ class AOTExecutorCodegen : public ExprVisitor {
UpdateFunctionMetadata(lowered_func, func, target);

// Generate the TIR function call
CreateFuncCall(GetRef<Call>(op), lowered_func->func_name);
CreateFuncCall(GetRef<Call>(op), lowered_func->prim_fn_var->name_hint);
}

void VisitExpr_(const VarNode* op) override {
Expr expr = GetRef<Expr>(op);

// If the Var node is an output node we need to copy the content of the variable to the output
// It's safe to check the SID here because Var StorageToken are never reallocated
Array<IntegerArray> sids = storage_device_map_[expr];
Map<Expr, StorageInfo> storage_info_map = memory_plan_->expr_to_storage_info;
std::vector<int64_t> sids = storage_info_map[expr]->storage_ids;
std::vector<DLDeviceType> device_types = storage_info_map[expr]->device_types;

auto output_iter = std::find(return_sid_.begin(), return_sid_.end(),
static_cast<int>((sids[0][0].as<IntImmNode>())->value));
// TODO(@electriclilies): Not sure if this is right
auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sids[0]);
if (output_iter != return_sid_.end()) {
int output_index = std::distance(return_sid_.begin(), output_iter);
auto var_expr = FindExpr(expr);
CopyToOutput(main_signature_[input_vars_.size() + output_index], var_expr[0], sids[2][0]);
CopyToOutput(main_signature_[input_vars_.size() + output_index], var_expr[0],
device_types[0]);
}
}

void VisitExpr_(const ConstantNode* op) override {
Expr expr = GetRef<Expr>(op);
size_t index = params_.size();
std::string name = "p" + std::to_string(index);
Map<Expr, StorageInfo> storage_info_map = memory_plan_->expr_to_storage_info;

param_storage_ids_[name] = storage_device_map_[expr][0][0]->value;
param_storage_ids_[name] = storage_info_map[expr]->storage_ids[0];
params_[name] = op->data;
params_by_expr_.Set(expr, name);

// If the Constant node is an output node we need to copy the content of the parameter to the
// output A Var node can only produce a single output
Array<IntegerArray> sids = storage_device_map_[expr];
auto output_iter = std::find(return_sid_.begin(), return_sid_.end(),
static_cast<int>((sids[0][0].as<IntImmNode>())->value));
std::vector<int64_t> sids = storage_info_map[expr]->storage_ids;
// Array<IntegerArray> sids = storage_device_map_[expr];
// TODO(@electriclilies): this might be wrong, hacked in change for now
std::vector<int64_t> storage_sizes = storage_info_map[expr]->storage_sizes_in_bytes;
auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sids[0]);
if (output_iter != return_sid_.end()) {
int output_index = std::distance(return_sid_.begin(), output_iter);
CopyToOutput(main_signature_[input_vars_.size() + output_index], PackParam(expr), sids[2][0]);
CopyToOutput(main_signature_[input_vars_.size() + output_index], PackParam(expr),
storage_sizes[0]);
}
}

Expand Down Expand Up @@ -502,19 +514,31 @@ class AOTExecutorCodegen : public ExprVisitor {
// Allocate the sids
std::unordered_map<int, bool> allocated;

for (auto kv : storage_device_map_) {
for (auto kv : memory_plan_->expr_to_storage_info) {
// Only allocate sids that are needed
auto expr = kv.first;
auto storage_info = kv.second;
auto sids = storage_info->storage_ids;
auto device_types = storage_info->device_types;
auto storage_sizes_in_bytes = storage_info->storage_sizes_in_bytes;
// sids = kv.second[0]
// devices = kv.second[1]
//

const bool is_input =
(std::find(input_vars_.begin(), input_vars_.end(), kv.first) != input_vars_.end());
const bool is_param = (params_by_expr_.find(kv.first) != params_by_expr_.end());
if (is_input || is_param) {
continue;
}

for (unsigned int i = 0; i < kv.second[0].size(); i++) {
int size = kv.second[2][i];
int sid = static_cast<int>((kv.second[0][i].as<IntImmNode>())->value);
CHECK_EQ(sids.size(), storage_sizes_in_bytes.size())
<< "the mapping between storage ids and sizes is incorrect"
<< "found " << sids.size() << " ids and " << storage_sizes_in_bytes.size() << "sizes";

for (unsigned int i = 0; i < sids.size(); i++) {
int size = storage_sizes_in_bytes.at(i);
int sid = sids.at(i);
if (std::find(return_sid_.begin(), return_sid_.end(), sid) != return_sid_.end()) {
continue;
}
Expand Down Expand Up @@ -578,7 +602,8 @@ class AOTExecutorCodegen : public ExprVisitor {
std::unordered_map<std::string, int64_t> param_storage_ids_;

/*! \brief plan memory of device result */
Map<Expr, Array<IntegerArray>> storage_device_map_;
StaticMemoryPlan memory_plan_;
// Map<Expr, Array<IntegerArray>> storage_device_map_;
std::unordered_map<int, te::Var> sids_table_;
/*! \brief lowered funcs */
std::unordered_map<std::string, IRModule> lowered_funcs_;
Expand All @@ -589,7 +614,7 @@ class AOTExecutorCodegen : public ExprVisitor {
/*! \brief the set of statements that make the program */
std::vector<tir::Stmt> stmts_;
/*! \brief the list of return sids (note that the function might return more then one output */
IntegerArray return_sid_;
std::vector<int64_t> return_sid_;
/*! \brief the module name we use to mangle the function names */
String mod_name_;

Expand All @@ -603,8 +628,7 @@ class AOTExecutorCodegen : public ExprVisitor {

LoweredOutput Codegen(relay::Function func, String mod_name) {
// Get the module, storage map and token sizes
auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
storage_device_map_ = (*pf)(func);
memory_plan_ = GraphPlanMemory(func);
mod_name_ = mod_name;

for (auto input : func->params) {
Expand All @@ -613,15 +637,17 @@ class AOTExecutorCodegen : public ExprVisitor {
}

// Define the storage allocator ids
for (auto kv : storage_device_map_) {
for (const auto& sid : kv.second[0]) {
for (auto kv : memory_plan_->expr_to_storage_info) {
auto storage_info = kv.second;
auto sids = storage_info->storage_ids;
for (const auto& sid : sids) {
te::Var sid_var(MakeString("sid_", sid), PointerType(PrimType(DataType::Int(8))));
sids_table_[sid] = sid_var;
}
}

// Find the return sid
return_sid_ = AotReturnSidVisitor(storage_device_map_).FindReturnSid(func);
return_sid_ = AotReturnSidVisitor(memory_plan_->expr_to_storage_info).FindReturnSid(func);
for (unsigned int output_index = 0; output_index < return_sid_.size(); output_index++) {
main_signature_.push_back(tir::Var("output", DataType::Handle()));
}
Expand Down
Loading

0 comments on commit f82e60d

Please sign in to comment.