Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[relay][vm] move vm opt passes to pass manager #3323

Merged
merged 1 commit into from
Jun 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 33 additions & 19 deletions python/tvm/relay/backend/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,45 @@

Implements a Python interface to compiling and executing on the Relay VM.
"""
import numpy as np

import tvm
from tvm._ffi.function import Object
import numpy as np
from .. import ir_pass
from .. import transform
from ..backend.interpreter import Executor
from ..expr import GlobalVar, Function, Expr
from ..expr import GlobalVar, Expr
from . import _vm

Object = Object

def optimize(expr, mod=None):
    """Run the interim VM optimization sequence over an expression.

    TODO: fold this logic into the optimizer/pass manager.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The expression to optimize.
    mod : tvm.relay.Module, optional
        The module used for type inference and fusion.

    Returns
    -------
    tvm.relay.Expr
        The type-checked, simplified, and operator-fused expression.
    """
    typed = ir_pass.infer_type(expr, mod=mod)
    simplified = ir_pass.simplify_inference(typed)
    simplified = ir_pass.infer_type(simplified, mod=mod)
    fused = ir_pass.fuse_ops(simplified, mod=mod)
    return ir_pass.infer_type(fused, mod=mod)
def optimize(mod):
    """Perform several optimizations on a module before executing it in the
    Relay virtual machine.

    Parameters
    ----------
    mod : tvm.relay.Module
        The module to optimize.

    Returns
    -------
    ret : tvm.relay.Module
        The optimized module.
    """
    entry = mod[mod.entry_func]

    passes = []
    # Eta-expand an entry function that is a bare reference to another global.
    if not entry.params and isinstance(entry.body, GlobalVar):
        passes.append(transform.EtaExpand())

    passes += [
        transform.SimplifyInference(),
        transform.FuseOps(),
        transform.InferType(),
    ]

    return transform.Sequential(passes)(mod)

def _convert(arg, cargs):
if isinstance(arg, np.ndarray):
Expand Down Expand Up @@ -76,15 +97,8 @@ def _eval_vm(mod, ctx, *args):
args: List[tvm.NDArray, np.ndarray]
The arguments to evaluate.
"""
main_func = mod[mod.entry_func]

if not main_func.params and isinstance(main_func.body, GlobalVar):
main_func = ir_pass.eta_expand(main_func.body, mod)

assert isinstance(main_func, Function)
main_func = optimize(mod[mod.entry_func], mod)
mod[mod.entry_func] = main_func

mod = optimize(mod)
args = list(args)
assert isinstance(args, list)
cargs = convert(args)
Expand Down
24 changes: 17 additions & 7 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/logging.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <iostream>
#include <unordered_map>
Expand All @@ -38,15 +38,22 @@

namespace tvm {
namespace relay {

namespace transform {

Pass LambdaLift();
Pass InlinePrimitives();

} // namespace transform

namespace vm {

using namespace tvm::runtime;
using namespace tvm::runtime::vm;
using namespace relay::transform;

// (@jroesch): VM passes, eventually declare as passes.
bool IsClosure(const Function& func);
Module LambdaLift(const Module& module);
Module InlinePrimitives(const Module& module);

template <typename T, typename U>
using NodeMap = std::unordered_map<T, U, NodeHash, NodeEqual>;
Expand Down Expand Up @@ -560,10 +567,13 @@ VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const F
}

Module OptimizeModule(const Module& mod) {
  // Lower the module for the VM via the pass manager: normalize to ANF,
  // inline primitives, lift nested functions to globals, then inline once
  // more to clean up primitives exposed by lambda lifting.
  //
  // NOTE(review): removed the leftover pre-pass-manager statements
  // (direct calls to ToANormalForm/InlinePrimitives/LambdaLift with an
  // early return) that shadowed this pipeline and made it unreachable.
  transform::Sequential seq({transform::ToANormalForm(),
                             transform::InlinePrimitives(),
                             transform::LambdaLift(),
                             transform::InlinePrimitives()});
  auto pass_ctx = transform::PassContext::Create();
  tvm::With<relay::transform::PassContext> ctx(pass_ctx);
  return seq(mod);
}

void PopulateGlobalMap(GlobalMap* global_map, const Module& mod) {
Expand Down
92 changes: 49 additions & 43 deletions src/relay/backend/vm/inline_primitives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/logging.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <iostream>
#include <vector>
Expand All @@ -37,6 +37,21 @@ namespace tvm {
namespace relay {
namespace vm {

// TODO(@jroesch): write verifier

/* This pass will eliminate primitives which have been lifted by the ANF
* transform inlining them directly into call sites.
*
* This makes VM related code generation easier as the call target is always
* a primitive function.
*
* let prim = fn(...) { ... };
* prim(...)
*
* will become:
*
* (fn(...) { ... })(...)
*/
struct PrimitiveInliner : ExprMutator {
Module module_;
std::unordered_map<Var, Expr, NodeHash, NodeEqual> var_map;
Expand Down Expand Up @@ -92,55 +107,46 @@ struct PrimitiveInliner : ExprMutator {
}
}

Function Inline(const Function& func) {
DLOG(INFO) << "Before inlining primitives: " << std::endl
<< "func= " << AsText(func, false) << std::endl;

auto inlined = FunctionNode::make(func->params, VisitExpr(func->body), func->ret_type,
func->type_params, func->attrs);

inlined = Downcast<Function>(DeadCodeElimination(inlined));

DLOG(INFO) << "After inlining primitives" << std::endl
<< "after_func= " << AsText(inlined, false) << std::endl;
return inlined;
Module Inline() {
auto gvar_funcs = module_->functions;
for (auto pair : gvar_funcs) {
auto global = pair.first;
auto func = pair.second;
DLOG(INFO) << "Before inlining primitives: " << global
<< std::endl << AsText(func, false);

func = FunctionNode::make(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
module_->Add(global, func, true);

DLOG(INFO) << "After inlining primitives: " << global
<< std::endl << AsText(func, false);
}
return module_;
}
};

// TODO(@jroesch): write verifier

/* This pass will eliminate primitives which have been lifted by the ANF
* transform inlining them directly into call sites.
*
* This makes VM related code generation easier as the call target is always
* a primitive function.
*
* let prim = fn(...) { ... };
* prim(...)
*
* will become:
*
* (fn(...) { ... })(...)
*/
Module InlinePrimitives(const Module& module) {
PrimitiveInliner inliner(module);
} // namespace vm

tvm::Map<GlobalVar, Function> updates;
namespace transform {

// There is an ordering bug here.
for (auto pair : module->functions) {
auto global = pair.first;
auto func = pair.second;
updates.Set(global, inliner.Inline(func));
}
// Build the InlinePrimitives pass: a module pass that inlines lifted
// primitive functions, followed by dead-code elimination to drop the
// now-unused let bindings.
Pass InlinePrimitives() {
  runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
      [=](Module mod, PassContext ctx) {
        return relay::vm::PrimitiveInliner(mod).Inline();
      };
  auto inline_pass = CreateModulePass(pass_func, 1, "Inline", {});
  // Eliminate dead code for each function after inlining.
  return Sequential({inline_pass, DeadCodeElimination()}, "InlinePrimitives");
}

for (auto pair : updates) {
module->Add(pair.first, pair.second, true);
}
// Expose the pass constructor to the frontend via the FFI registry.
TVM_REGISTER_API("relay._transform.InlinePrimitives")
.set_body_typed(InlinePrimitives);

return module;
}
} // namespace transform

} // namespace vm
} // namespace relay
} // namespace tvm
80 changes: 41 additions & 39 deletions src/relay/backend/vm/lambda_lift.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/logging.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <iostream>
#include <vector>
Expand Down Expand Up @@ -54,9 +55,14 @@ Function MarkClosure(const Function& func) {
return FunctionSetAttr(func, kIsClosure, tvm::Integer(1));
}

/* The goal of this class is to lift out any nested functions into top-level
* functions.
*
 * We will lift a function out into a global which takes the set of the free
 * vars and then returns the newly created function.
*/
struct LambdaLifter : ExprMutator {
Module module_;
std::vector<std::pair<GlobalVar, Function>> lifted_;
explicit LambdaLifter(const Module& module) : module_(module) {}

Expr VisitExpr_(const FunctionNode* func_node) final {
Expand All @@ -71,16 +77,15 @@ struct LambdaLifter : ExprMutator {
auto free_type_vars = FreeTypeVars(func, module_);
auto body = Downcast<Function>(ExprMutator::VisitExpr_(func_node));

// When performing this optimization there are two
// cases.
// When performing this optimization there are two cases.
//
// The first case in which we have no free variables
// we can just lift the function into the global
// environment without needing to allocate a closure.
//
//
// The second case requires that we generate a special
// function with makes a distinction between allocating
// function which makes a distinction between allocating
// a closure, and then the code for the closure.
//
// We represent a closure allocation by lifting the
Expand All @@ -92,7 +97,7 @@ struct LambdaLifter : ExprMutator {
// function marked as a closure is used to emit allocation
// code for the closure's environment.
//
// The "inner" function is should be used to generate the
// The "inner" function should be used to generate the
// code for the closure.
Function lifted_func;
if (free_vars.size() == 0) {
Expand All @@ -107,16 +112,16 @@ struct LambdaLifter : ExprMutator {
CHECK(lifted_func.defined());

auto name = GenerateName(lifted_func);
auto global = this->module_->GetGlobalVar(name);
auto global = module_->GetGlobalVar(name);

lifted_.push_back({global, lifted_func});
// Add the lifted function to the module.
module_->Add(global, lifted_func);

if (free_vars.size() == 0) {
return std::move(global);
} else {
// If we need to allocate a closure
// we pass the variables in its environment
// here.
// If we need to allocate a closure,
// we pass the variables in its environment here.
Array<Expr> fvs;
for (auto fv : free_vars) {
fvs.push_back(fv);
Expand All @@ -125,42 +130,39 @@ struct LambdaLifter : ExprMutator {
}
}

Function Lift(const Function& func) {
DLOG(INFO) << "Lifting: " << AsText(func, false) << std::endl;
return FunctionNode::make(func->params, VisitExpr(func->body), func->ret_type,
func->type_params, func->attrs);
Module Lift() {
// There is an ordering bug here.
auto glob_funcs = module_->functions;
for (auto pair : glob_funcs) {
auto func = pair.second;
DLOG(INFO) << "Lifting " << AsText(func, false);
func = FunctionNode::make(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
module_->Add(pair.first, func, true);
}
return module_;
}
};

/* The goal of this pass is to lift out any nested functions into top-level
* functions.
*
* We will lift the functions out into globals which take the set of the free vars
 * and then return the newly created function.
*/
Module LambdaLift(const Module& module) {
LambdaLifter lifter(module);

tvm::Map<GlobalVar, Function> updates;
} // namespace vm

// There is an ordering bug here.
for (auto pair : module->functions) {
auto global = pair.first;
auto func = pair.second;
updates.Set(global, lifter.Lift(func));
}
namespace transform {

for (auto i = lifter.lifted_.begin(); i != lifter.lifted_.end(); i++) {
module->Add(i->first, i->second);
}
// Build the LambdaLift pass: a module pass that hoists nested functions
// into top-level globals.
Pass LambdaLift() {
  runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
      [=](Module mod, PassContext ctx) {
        return relay::vm::LambdaLifter(mod).Lift();
      };
  return CreateModulePass(pass_func, 1, "LambdaLift", {});
}

for (auto pair : updates) {
module->Add(pair.first, pair.second, true);
}
// Expose the pass constructor to the frontend via the FFI registry.
TVM_REGISTER_API("relay._transform.LambdaLift")
.set_body_typed(LambdaLift);

return module;
}
} // namespace transform

} // namespace vm
} // namespace relay
} // namespace tvm
Loading