Skip to content

Commit

Permalink
[Relay] Use LowerTEPass in VM (apache#9483)
Browse files Browse the repository at this point in the history
We replace use of the TECompiler::{Lower,LowerShapeFunc} methods from the VM's
compiler.cc with LowerTEPass. This clears the way for performing post-lowering
IRModule->IRModule transformations which combine Relay and TIR analysis. In particular,
it will allow us to use the PlanDevices pass to propagate memory scope constraints
across PrimFuncs.

We run LowerTEPass fairly early in the pipeline, which required quite a few passes
to become 'post-lowering friendly'. In particular, ManifestAlloc is now run after
rather than before lowering, and so must now work in a mixed Function/PrimFunc world.

The "vm.shape_func" operator has been removed since a) lowering has already generated
the necessary dynamic shape function, and b) the call to that function can be
represented by an 'ordinary' vm.invoke_tvm_op call.

We worked our way through the following glitches:
 - Dynamic shape functions are now given their true type (rather than the type of
   the primitive function they are paired with).
 - Lowering was choosing definitional GlobalVars which were not pointer-equal to the
   referential GlobalVars left behind in the rewritten Calls. We fixed that in
   te_compiler.cc, though better would be to push GlobalVars deeper into the
   lowering machinery.
 - device_copy was rewritten to a call to @__copy without any definition. Though we
   tried adding it as a global this (obviously in retrospect...) won't typecheck if
   there are multiple device_copies in the program. Instead leave device_copy unchanged
   during lowering and update each executor codegen to look for them specially.
 - Calls to already-compiled BYOC functions were indistinguishable from calls
   to (non-primitive) Relay functions. We move them into the call_lowered calling
   convention, and leave behind a Function tagged with "ExternalSymbol". Better would
   be a first-class representatn for externals in the IRModule but one step at a time.
 - Functions with dynamic shapes tagged for BYOC compilation were not tracking their
   connection to their dynamic shape function. We now use exactly the same attributes
   as for non-BYOC primitives.
 - VerilatorRuntime can legitimately be deleted before initialized.
 - IRModule attributes must be preserved. In particular, since LowerTEPass can
   be invoked more than once we need to be careful to preserve any existing external
   modules and other attributes gatherd from an earlier LowerTEPass.
 - GetUniqueName accounts for existing definitions in the module, but is not used
   for external functions since their intended names are communicated to the codegen
   toolchain via the already fixed "global_symbol" attribute.
  • Loading branch information
mbs-octoml authored and ylc committed Jan 7, 2022
1 parent 0ce3f0f commit 35efe66
Show file tree
Hide file tree
Showing 50 changed files with 1,150 additions and 1,193 deletions.
14 changes: 14 additions & 0 deletions include/tvm/relay/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,20 @@ Function WithFields(Function function, Optional<Array<Var>> opt_params = Optiona
Optional<DictAttrs> opt_attrs = Optional<DictAttrs>(),
Optional<Span> opt_span = Optional<Span>());

/*
* \brief Returns the Relay FunctionNode represented by base_func if it should be optimized,
* otherwise returns nullptr.
*
* This means returns nullptr:
* - For PrimFuncs, since not Relay Functions.
* - For Functions marked for external compilation (with "Compiler").
* - For Functions marked as already having an external definition (with "ExternalSymbol").
* - For Functions marked as not to be optimized (with "SkipOptimization").
*
* TODO(mbs): Audit all enumerations of IRModule::functions to use this or some family of such.
*/
const FunctionNode* AsOptimizableFunctionNode(const BaseFunc& base_func);

/*!
* \brief namespace of the attributes that can be attached to a relay::Function.
*/
Expand Down
11 changes: 5 additions & 6 deletions include/tvm/runtime/vm/executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,21 +133,20 @@ class Executable : public ModuleNode {

/*!
* \brief Returns a description of all the constants in the executable in human-readable
* format. Not intended to be machine readable, but rather to help with debugging and
* diffing generated code.
* format. Intended for debugging and diff-testing.
*/
std::string GetConstants() const;

/*!
* \brief Returns a description of all the (virtual) devices in the executable in human-readable
* format.
* format. Intended for debugging and diff-testing.
*/
std::string GetVirtualDevices() const;

/*!
* \brief Returns a description of all the 'primitive' (ie PackedFuncs) in the executable.
* These correspond to eithed PrimFuncs we've compiled locally, or functions compiled by
* a BYOC external codegen.
* \brief Returns a description of all the 'primitive' (ie PackedFuncs) in the executable in
* human-readable format. These correspond either to PrimFuncs we've compiled locally, or
* functions compiled by a BYOC external codegen. Intended for debugging and diff-testing.
*/
std::string GetPrimitives() const;

Expand Down
4 changes: 2 additions & 2 deletions include/tvm/target/se_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,8 @@ class SEScope : public ObjectRef {

/*! \brief Returns the \p SEScope for \p target. */
static SEScope ForTarget(Target target) {
return SEScope(static_cast<DLDeviceType>(target->kind->device_type), /*virtual_device_id=*/0,
std::move(target));
DLDeviceType device_type = static_cast<DLDeviceType>(target->kind->device_type);
return SEScope(device_type, /*virtual_device_id=*/0, std::move(target));
}

/*! \brief Returns the \p SEScope for \p device, \p target and \p memory_scope. */
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/runtime/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,8 @@ def virtual_devices(self):
return self._get_virtual_devices()

@property
def primitive(self):
"""Returns a human-readable dencription of all the primitives (ie PackedFuncs) in the
def primitives(self):
"""Returns a human-readable description of all the primitives (ie PackedFuncs) in the
executable"""
return self._get_primitives()

Expand Down
26 changes: 17 additions & 9 deletions src/relay/analysis/call_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include "call_graph.h"

#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/object.h>

Expand All @@ -33,6 +34,8 @@
#include <unordered_set>
#include <vector>

#include "../op/call/call.h"

namespace tvm {
namespace relay {

Expand Down Expand Up @@ -64,9 +67,19 @@ void CallGraphNode::AddToCallGraph(const GlobalVar& gv, const Function& func) {
// post-order visitor will visit each AST node of the current function to
// figure out the dependencies between functions.
PostOrderVisit(func, [&](const Expr& expr) {
if (const GlobalVarNode* gvn = expr.as<GlobalVarNode>()) {
auto callee = GetRef<GlobalVar>(gvn);
cg_node->AddCalledGlobal(LookupGlobalVar(callee));
// TODO(mbs): Cleanup shapes functions.
if (const auto* call_node = expr.as<CallNode>()) {
CallLoweredProps props = GetCallLoweredProps(call_node);
if (props.lowered_func.defined() && props.attrs.metadata.count("prim_shape_fn_var")) {
// We are implicitly calling the shape function *in addition to* the call target.
CallGraphEntry* callee_cg_node =
LookupGlobalVar(Downcast<GlobalVar>(props.attrs.metadata["prim_shape_fn_var"]));
cg_node->AddCalledGlobal(callee_cg_node);
}
} else if (const auto* global_var_node = expr.as<GlobalVarNode>()) {
auto callee = GetRef<GlobalVar>(global_var_node);
CallGraphEntry* callee_cg_node = LookupGlobalVar(callee);
cg_node->AddCalledGlobal(callee_cg_node);
}
});
}
Expand All @@ -88,21 +101,16 @@ CallGraphEntry* CallGraphNode::operator[](const GlobalVar& gv) {
BaseFunc CallGraphNode::GetGlobalFunction(const GlobalVar& var) const {
ICHECK(module->ContainGlobalVar(var->name_hint))
<< "GlobalVar " << var->name_hint << " not found in the current ir module";
return module->Lookup(var);
return module->Lookup(var->name_hint);
}

// Query the existence of a GlobalVar in the call graph. It creates an entry if
// there is no such node available.
CallGraphEntry* CallGraphNode::LookupGlobalVar(const GlobalVar& gv) {
ICHECK(gv.defined());

// This inserts an element to the call graph if it is not there yet.
auto& call_graph_node = call_graph_[gv];
if (call_graph_node) return call_graph_node.get();

ICHECK(module->ContainGlobalVar(gv->name_hint))
<< "GlobalVar " << gv->name_hint << " not found in the current ir module";

// Create the node for the inserted entry.
call_graph_node = std::unique_ptr<CallGraphEntry>(new CallGraphEntry(gv));
return call_graph_node.get();
Expand Down
66 changes: 37 additions & 29 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

#include "../op/annotation/annotation.h"
#include "../op/call/call.h"
#include "../op/memory/device_copy.h"
#include "../transforms/device_aware_visitors.h"
#include "./name_transforms.h"
#include "./te_compiler.h"
Expand All @@ -52,7 +53,6 @@ namespace tvm {
namespace relay {
namespace backend {

using IntegerArray = Array<Integer>;
using StorageMap =
std::unordered_map<Expr, StorageInfo, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;

Expand All @@ -71,6 +71,8 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {

StorageMap GetStorageMap() const { return storage_device_map_; }

using ExprVisitor::VisitExpr_;

void VisitExpr_(const ConstantNode* op) final {
CreateStorage(op);
AssignReturnSid(GetRef<Expr>(op));
Expand Down Expand Up @@ -225,7 +227,7 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
/*!
* \brief Create storage to hold the result of evaluating \p expr in \p se_scope.
*/
void CreateStorage(const Expr& expr, SEScope se_scope) {
void CreateStorage(const Expr& expr, const SEScope& se_scope) {
ICHECK(!se_scope->IsFullyUnconstrained()) << "invalid SEScope for expr:" << std::endl
<< PrettyPrint(expr);
std::vector<int64_t> storage_ids;
Expand Down Expand Up @@ -314,9 +316,10 @@ class AOTExecutorCodegen : public MixedModeVisitor {
/*!
* brief Create a function call
* \param call_lowered_props The lowered function and the arguments to call it with
* \param call The call we got func and args from
* \param result_expr The call we got func and args from (so as to recover the storage
* ids to hold the result).
*/
void CreateFuncCall(CallLoweredProps call_lowered_props, Call call) {
void CreateFuncCall(CallLoweredProps call_lowered_props, const Expr& result_expr) {
std::string func_name = call_lowered_props.lowered_func->name_hint;
tvm::Array<PrimExpr> args{tvm::tir::StringImm(func_name)};
std::vector<tir::Stmt> create_func_call_stmts;
Expand All @@ -335,9 +338,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}
}

auto ret_expr = Downcast<Expr>(call);
// Pack the return(s) value. A call node can produce multiple outputs
for (const auto& var : PackSid(ret_expr)) {
for (const auto& var : PackSid(result_expr)) {
args.push_back(var);
}

Expand Down Expand Up @@ -507,23 +509,25 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}

void VisitExpr_(const CallNode* call_node) override {
// Descend the call tree
CallLoweredProps call_lowered_props;
if (const auto* gvn = call_node->op.as<GlobalVarNode>()) { // Lowered extern function
ICHECK(!(call_node->attrs.defined())) << "Extern functions should have null attributes.";
for (const auto& arg : call_node->args) {
VisitExpr(arg);
}
call_lowered_props = CallLoweredProps{GetRef<GlobalVar>(gvn), call_node->args, {}};
} else {
call_lowered_props = GetCallLoweredProps(call_node);
ICHECK(call_lowered_props.lowered_func.defined())
<< "Operators should be transformed away; Try "
"applying the fuse_ops transformation to the "
"expression.";
for (const auto& arg : call_lowered_props.arguments) {
VisitExpr(arg);
}
DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node);
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);

if (device_copy_props.body.defined()) {
// TODO(mbs): device_copy cleaunp
// Suspect treating as no-op is better since already built into the StorageInfo?
LOG(FATAL) << "The AOT executor does not currently support device_copy";
return;
}

// At this point we should only see calls of the form call_lowered(@callee, (args...)),
// where @callee can be a PrimFunc we've compiled or an external function supplied via
// some other mechanism.
ICHECK(call_lowered_props.lowered_func.defined())
<< "AOT does not support calling Relay functions. Attempting to call:" << std::endl
<< PrettyPrint(GetRef<Call>(call_node));
for (const auto& arg : call_lowered_props.arguments) {
// Evaluate the args
VisitExpr(arg);
}
CreateFuncCall(call_lowered_props, GetRef<Call>(call_node));
}
Expand Down Expand Up @@ -722,18 +726,21 @@ class AOTExecutorCodegen : public MixedModeVisitor {
for (const auto& kv : targets_) {
VLOG(1) << "target: " << kv.second->ToDebugString();
}
if (target_host_.defined()) {
VLOG(1) << "target host: " << target_host_->ToDebugString();
}
ICHECK(target_host_.defined()) << "require a target_host to be given for AOT codegen";
VLOG(1) << "target host: " << target_host_->ToDebugString();

Executor executor_config = mod->GetAttr<Executor>(tvm::attr::kExecutor).value();
String interface_api = executor_config->GetAttr<String>("interface-api").value_or("packed");
Integer workspace_byte_alignment =
executor_config->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
use_unpacked_api_ = executor_config->GetAttr<Bool>("unpacked-api").value_or(Bool(false));

IRModule lowered_mod =
tec::LowerTEPass(mod_name, [this, workspace_byte_alignment](BaseFunc func) {
// TODO(mbs): Plumb from compiler config
SEScope host_se_scope = SEScope::ForTarget(target_host_);

IRModule lowered_mod = tec::LowerTEPass(
mod_name,
[this, workspace_byte_alignment](BaseFunc func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
// allows us to process each function as we lower it.
Expand All @@ -745,7 +752,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
// execute as a further pass, instead writing data to the
// lowering process directly.
tec::UpdateFunctionMetadata(func, this->function_metadata_, workspace_byte_alignment);
})(mod);
},
host_se_scope)(mod);

auto lowered_main = lowered_mod->Lookup("main");
auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());
Expand Down
1 change: 1 addition & 0 deletions src/relay/backend/contrib/verilator/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.verilator.options", VerilatorOptions)
* compile it into a Verilator runtime module.
*/
runtime::Module VerilatorBackend(const ObjectRef& ref) {
VLOG(0) << "compiling for verilator runtime";
CHECK(ref->IsInstance<FunctionNode>());
auto func = Downcast<Function>(ref);
auto func_name = GetExtSymbol(func);
Expand Down
Loading

0 comments on commit 35efe66

Please sign in to comment.