Skip to content

Commit

Permalink
[relay][vm] move vm opt passes to pass manager (#3323)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics authored and jroesch committed Jun 11, 2019
1 parent 8f219b9 commit 70041c4
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 113 deletions.
52 changes: 33 additions & 19 deletions python/tvm/relay/backend/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,45 @@
Implements a Python interface to compiling and executing on the Relay VM.
"""
import numpy as np

import tvm
from tvm._ffi.function import Object
import numpy as np
from .. import ir_pass
from .. import transform
from ..backend.interpreter import Executor
from ..expr import GlobalVar, Function, Expr
from ..expr import GlobalVar, Expr
from . import _vm

Object = Object

def optimize(expr, mod=None):
# TODO: We need to move this optimization code into the optimizer/pass manager
ck_expr = ir_pass.infer_type(expr, mod=mod)
simplified_expr = ir_pass.simplify_inference(ck_expr)
simplified_expr = ir_pass.infer_type(simplified_expr, mod=mod)
fused_expr = ir_pass.fuse_ops(simplified_expr, mod=mod)
ck_fused = ir_pass.infer_type(fused_expr, mod=mod)
return ck_fused
def optimize(mod):
"""Perform several optimizations on a module before executing it in the
Relay virtual machine.
Parameters
----------
mod : tvm.relay.Module
The module to optimize.
Returns
-------
ret : tvm.relay.Module
The optimized module.
"""
main_func = mod[mod.entry_func]

opt_passes = []
if not main_func.params and isinstance(main_func.body, GlobalVar):
opt_passes.append(transform.EtaExpand())

opt_passes = opt_passes + [
transform.SimplifyInference(),
transform.FuseOps(),
transform.InferType()
]

seq = transform.Sequential(opt_passes)
return seq(mod)

def _convert(arg, cargs):
if isinstance(arg, np.ndarray):
Expand Down Expand Up @@ -76,15 +97,8 @@ def _eval_vm(mod, ctx, *args):
args: List[tvm.NDArray, np.ndarray]
The arguments to evaluate.
"""
main_func = mod[mod.entry_func]

if not main_func.params and isinstance(main_func.body, GlobalVar):
main_func = ir_pass.eta_expand(main_func.body, mod)

assert isinstance(main_func, Function)
main_func = optimize(mod[mod.entry_func], mod)
mod[mod.entry_func] = main_func

mod = optimize(mod)
args = list(args)
assert isinstance(args, list)
cargs = convert(args)
Expand Down
24 changes: 17 additions & 7 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/logging.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <iostream>
#include <unordered_map>
Expand All @@ -38,15 +38,22 @@

namespace tvm {
namespace relay {

namespace transform {

Pass LambdaLift();
Pass InlinePrimitives();

} // namespace transform

namespace vm {

using namespace tvm::runtime;
using namespace tvm::runtime::vm;
using namespace relay::transform;

// (@jroesch): VM passes, eventually declare as passes.
bool IsClosure(const Function& func);
Module LambdaLift(const Module& module);
Module InlinePrimitives(const Module& module);

template <typename T, typename U>
using NodeMap = std::unordered_map<T, U, NodeHash, NodeEqual>;
Expand Down Expand Up @@ -560,10 +567,13 @@ VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const F
}

Module OptimizeModule(const Module& mod) {
ToANormalForm(mod->entry_func, mod);
InlinePrimitives(mod);
LambdaLift(mod);
return InlinePrimitives(mod);
transform::Sequential seq({transform::ToANormalForm(),
transform::InlinePrimitives(),
transform::LambdaLift(),
transform::InlinePrimitives()});
auto pass_ctx = transform::PassContext::Create();
tvm::With<relay::transform::PassContext> ctx(pass_ctx);
return seq(mod);
}

void PopulateGlobalMap(GlobalMap* global_map, const Module& mod) {
Expand Down
92 changes: 49 additions & 43 deletions src/relay/backend/vm/inline_primitives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/logging.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <iostream>
#include <vector>
Expand All @@ -37,6 +37,21 @@ namespace tvm {
namespace relay {
namespace vm {

// TODO(@jroesch): write verifier

/* This pass will eliminate primitives which have been lifted by the ANF
* transform inlining them directly into call sites.
*
* This makes VM related code generation easier as the call target is always
* a primitive function.
*
* let prim = fn(...) { ... };
* prim(...)
*
* will become:
*
* (fn(...) { ... })(...)
*/
struct PrimitiveInliner : ExprMutator {
Module module_;
std::unordered_map<Var, Expr, NodeHash, NodeEqual> var_map;
Expand Down Expand Up @@ -92,55 +107,46 @@ struct PrimitiveInliner : ExprMutator {
}
}

Function Inline(const Function& func) {
DLOG(INFO) << "Before inlining primitives: " << std::endl
<< "func= " << AsText(func, false) << std::endl;

auto inlined = FunctionNode::make(func->params, VisitExpr(func->body), func->ret_type,
func->type_params, func->attrs);

inlined = Downcast<Function>(DeadCodeElimination(inlined));

DLOG(INFO) << "After inlining primitives" << std::endl
<< "after_func= " << AsText(inlined, false) << std::endl;
return inlined;
Module Inline() {
auto gvar_funcs = module_->functions;
for (auto pair : gvar_funcs) {
auto global = pair.first;
auto func = pair.second;
DLOG(INFO) << "Before inlining primitives: " << global
<< std::endl << AsText(func, false);

func = FunctionNode::make(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
module_->Add(global, func, true);

DLOG(INFO) << "After inlining primitives: " << global
<< std::endl << AsText(func, false);
}
return module_;
}
};

// TODO(@jroesch): write verifier

/* This pass will eliminate primitives which have been lifted by the ANF
* transform inlining them directly into call sites.
*
* This makes VM related code generation easier as the call target is always
* a primitive function.
*
* let prim = fn(...) { ... };
* prim(...)
*
* will become:
*
* (fn(...) { ... })(...)
*/
Module InlinePrimitives(const Module& module) {
PrimitiveInliner inliner(module);
} // namespace vm

tvm::Map<GlobalVar, Function> updates;
namespace transform {

// There is an ordering bug here.
for (auto pair : module->functions) {
auto global = pair.first;
auto func = pair.second;
updates.Set(global, inliner.Inline(func));
}
Pass InlinePrimitives() {
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
[=](Module m, PassContext pc) {
return relay::vm::PrimitiveInliner(m).Inline();
};
auto inline_pass = CreateModulePass(pass_func, 1, "Inline", {});
// Eliminate dead code for each function after inlining.
return Sequential({inline_pass, DeadCodeElimination()}, "InlinePrimitives");
}

for (auto pair : updates) {
module->Add(pair.first, pair.second, true);
}
TVM_REGISTER_API("relay._transform.InlinePrimitives")
.set_body_typed(InlinePrimitives);

return module;
}
} // namespace transform

} // namespace vm
} // namespace relay
} // namespace tvm
80 changes: 41 additions & 39 deletions src/relay/backend/vm/lambda_lift.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/logging.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <iostream>
#include <vector>
Expand Down Expand Up @@ -54,9 +55,14 @@ Function MarkClosure(const Function& func) {
return FunctionSetAttr(func, kIsClosure, tvm::Integer(1));
}

/* The goal of this class is to lift out any nested functions into top-level
* functions.
*
* We will lift a function out into a global which takes the set of the free
* vars and then return the new created function.
*/
struct LambdaLifter : ExprMutator {
Module module_;
std::vector<std::pair<GlobalVar, Function>> lifted_;
explicit LambdaLifter(const Module& module) : module_(module) {}

Expr VisitExpr_(const FunctionNode* func_node) final {
Expand All @@ -71,16 +77,15 @@ struct LambdaLifter : ExprMutator {
auto free_type_vars = FreeTypeVars(func, module_);
auto body = Downcast<Function>(ExprMutator::VisitExpr_(func_node));

// When performing this optimization there are two
// cases.
// When performing this optimization there are two cases.
//
// The first case in which we have no free variables
// we can just lift the function into the global
// environment without needing to allocate a closure.
//
//
// The second case requires that we generate a special
// function with makes a distinction between allocating
// function which makes a distinction between allocating
// a closure, and then the code for the closure.
//
// We represent a closure allocation by lifting the
Expand All @@ -92,7 +97,7 @@ struct LambdaLifter : ExprMutator {
// function marked as a closure is used to emit allocation
// code for the closure's environment.
//
// The "inner" function is should be used to generate the
// The "inner" function should be used to generate the
// code for the closure.
Function lifted_func;
if (free_vars.size() == 0) {
Expand All @@ -107,16 +112,16 @@ struct LambdaLifter : ExprMutator {
CHECK(lifted_func.defined());

auto name = GenerateName(lifted_func);
auto global = this->module_->GetGlobalVar(name);
auto global = module_->GetGlobalVar(name);

lifted_.push_back({global, lifted_func});
// Add the lifted function to the module.
module_->Add(global, lifted_func);

if (free_vars.size() == 0) {
return std::move(global);
} else {
// If we need to allocate a closure
// we pass the variables in its environment
// here.
// If we need to allocate a closure,
// we pass the variables in its environment here.
Array<Expr> fvs;
for (auto fv : free_vars) {
fvs.push_back(fv);
Expand All @@ -125,42 +130,39 @@ struct LambdaLifter : ExprMutator {
}
}

Function Lift(const Function& func) {
DLOG(INFO) << "Lifting: " << AsText(func, false) << std::endl;
return FunctionNode::make(func->params, VisitExpr(func->body), func->ret_type,
func->type_params, func->attrs);
Module Lift() {
// There is an ordering bug here.
auto glob_funcs = module_->functions;
for (auto pair : glob_funcs) {
auto func = pair.second;
DLOG(INFO) << "Lifting " << AsText(func, false);
func = FunctionNode::make(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
module_->Add(pair.first, func, true);
}
return module_;
}
};

/* The goal of this pass is to lift out any nested functions into top-level
* functions.
*
* We will lift the functions out into globals which take the set of the free vars
* and then return a function whcih has b
*/
Module LambdaLift(const Module& module) {
LambdaLifter lifter(module);

tvm::Map<GlobalVar, Function> updates;
} // namespace vm

// There is an ordering bug here.
for (auto pair : module->functions) {
auto global = pair.first;
auto func = pair.second;
updates.Set(global, lifter.Lift(func));
}
namespace transform {

for (auto i = lifter.lifted_.begin(); i != lifter.lifted_.end(); i++) {
module->Add(i->first, i->second);
}
Pass LambdaLift() {
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
[=](Module m, PassContext pc) {
return relay::vm::LambdaLifter(m).Lift();
};
return CreateModulePass(pass_func, 1, "LambdaLift", {});
}

for (auto pair : updates) {
module->Add(pair.first, pair.second, true);
}
TVM_REGISTER_API("relay._transform.LambdaLift")
.set_body_typed(LambdaLift);

return module;
}
} // namespace transform

} // namespace vm
} // namespace relay
} // namespace tvm
Loading

0 comments on commit 70041c4

Please sign in to comment.