Skip to content

Commit

Permalink
Merge pull request apache#2 from wweic/vm-exp
Browse files Browse the repository at this point in the history
Tail recursion
  • Loading branch information
wweic authored Nov 3, 2019
2 parents c381ec3 + 2e27335 commit 7f0d343
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 0 deletions.
78 changes: 78 additions & 0 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,70 @@ std::vector<int64_t> ToAllocTensorShape32(NDArray shape) {
return raw_shape;
}

// Find tail recursive calls inside a function
struct TailRecursion : ExprVisitor {
// Target recursive function
const GlobalVar& var;

// Whether the function is tail recursive
bool is_tail_recursive{true};

// All tail recursive call sites
std::unordered_set<Call, NodeHash> tail_calls;

TailRecursion(const GlobalVar& var):
var(var), tail_calls() {}

void VisitExpr_(const LetNode* let_node) {
if (auto body = let_node->body.as<VarNode>()) {
auto body_var = GetRef<Var>(body);
// If body is var we check it's the same as bind var
// Example: let x = expr in x, then we recurse inside expr
// to find tail call sites.
if (let_node->var.same_as(body_var)) {
// find call sites in value
this->VisitExpr(let_node->value);
} else {
// the body var and bind var are not the same, so definitely not tail recursive
// example: let a = expr in b, this is definitely not tail recursive. But I don't
// think such form is possible to appear.
CHECK(false) << "body and bind does not match, no tail recursive\n";
}
} else {
// If it's not a var we know in the body there are other computations to compute return values
// Example: let x = expr1 in expr2
this->VisitExpr(let_node->body);
}
}

void VisitExpr_(const CallNode* call_node) {
Expr op = call_node->op;
auto call = GetRef<Call>(call_node);

if (auto callee_node = op.as<GlobalVarNode>()) {
auto callee = GetRef<GlobalVar>(callee_node);
if (!callee.same_as(var)) {
is_tail_recursive = false;
} else {
auto call = GetRef<Call>(call_node);
this->tail_calls.insert(call);
}
} else {
is_tail_recursive = false;
}
}
};

std::unordered_set<Call, NodeHash> FindTailRecursion(const GlobalVar& var, const Function& func) {
auto tail_recursion = TailRecursion(var);
tail_recursion.VisitExpr(func);
if (!tail_recursion.is_tail_recursive) {
// Clear other tail callls sine the function is not fully tail recursive
tail_recursion.tail_calls.clear();
}
return tail_recursion.tail_calls;
}

class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
public:
VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host)
Expand Down Expand Up @@ -617,6 +681,18 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
auto arity = func->params.size();
Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister()));
} else {
auto call = GetRef<Call>(call_node);
auto find_it = this->context_->tail_recursions.find(call);
if (find_it != this->context_->tail_recursions.end()) {
// Handle tail recursion
// Override parameters
for (size_t i = 0; i < func->params.size(); ++i) {
Emit(Instruction::Move(args_registers[i], var_register_map_[func->params[i]]));
}

Emit(Instruction::Goto(-(instructions_.size())));
return;
}
Emit(Instruction::Invoke(it->second, args_registers, NewRegister()));
}
} else if (auto constructor_node = op.as<ConstructorNode>()) {
Expand Down Expand Up @@ -836,6 +912,8 @@ void VMCompiler::Compile(Module mod,
for (auto named_func : context_.module->functions) {
auto gvar = named_func.first;
auto func = named_func.second;
context_.tail_recursions = FindTailRecursion(gvar, func);

VMFunctionCompiler func_compiler(&context_, targets_, target_host_);
auto vm_func = func_compiler.Compile(gvar, func);

Expand Down
2 changes: 2 additions & 0 deletions src/relay/backend/vm/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ struct VMCompilerContext {
std::vector<CachedFunc> cached_funcs;
// The functions that have been lowered.
std::unordered_map<LoweredFunc, size_t, NodeHash, NodeEqual> seen_funcs;
// Tail recursive calls
std::unordered_set<Call, NodeHash> tail_recursions;
};


Expand Down
40 changes: 40 additions & 0 deletions tests/python/relay/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,46 @@ def test_add_op_broadcast():
mod["main"] = func
check_result([x_data, y_data], x_data + y_data, mod=mod)

def create_exec(f, target="llvm", params=None):
from tvm.relay import vm as _vm
if isinstance(f, relay.Expr):
mod = relay.Module()
mod["main"] = f
executable = _vm.compile(mod, target=target, params=params)
return executable
else:
assert isinstance(f, relay.Module), "expected mod as relay.Module"
executable = _vm.compile(f, target=target, params=params)
return executable

def test_tail_recursion():
loop = relay.GlobalVar('loop')
sb = relay.ScopeBuilder()

x = relay.var('x', shape=())
y = relay.var('y', shape=())
with sb.if_scope(relay.op.less(x, relay.const(10.0))):
x1 = x + relay.const(1.0)
y1 = y * relay.const(2.0)
x2 = relay.Call(loop, [x1, y1])
sb.ret(x2)
with sb.else_scope():
sb.ret(y)

body = sb.get()
f = relay.Function([x, y], body)

# module definition
mod = relay.Module()
mod[loop] = f

print("Mod {}".format(mod))
one = relay.const(1.0)
mod['main'] = relay.Function([], relay.Call(loop, [one, one]))
res = create_exec(mod)
print("Opcode: \n{}".format(res.bytecode))
res = veval(mod)
tvm.testing.assert_allclose(res.asnumpy(), 512.0)

if __name__ == "__main__":
test_id()
Expand Down

0 comments on commit 7f0d343

Please sign in to comment.