From 2e27335d78df07595727f34b3627bd6a788052e3 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Sat, 2 Nov 2019 15:15:39 -0700 Subject: [PATCH] Tail recursion --- src/relay/backend/vm/compiler.cc | 78 ++++++++++++++++++++++++++++++++ src/relay/backend/vm/compiler.h | 2 + tests/python/relay/test_vm.py | 40 ++++++++++++++++ 3 files changed, 120 insertions(+) diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 3cfea5c2e0db..c7f8f89135a6 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -235,6 +235,70 @@ std::vector ToAllocTensorShape32(NDArray shape) { return raw_shape; } +// Find tail recursive calls inside a function +struct TailRecursion : ExprVisitor { + // Target recursive function + const GlobalVar& var; + + // Whether the function is tail recursive + bool is_tail_recursive{true}; + + // All tail recursive call sites + std::unordered_set tail_calls; + + TailRecursion(const GlobalVar& var): + var(var), tail_calls() {} + + void VisitExpr_(const LetNode* let_node) { + if (auto body = let_node->body.as()) { + auto body_var = GetRef(body); + // If body is var we check it's the same as bind var + // Example: let x = expr in x, then we recurse inside expr + // to find tail call sites. + if (let_node->var.same_as(body_var)) { + // find call sites in value + this->VisitExpr(let_node->value); + } else { + // the body var and bind var are not the same, so definitely not tail recursive + // example: let a = expr in b, this is definitely not tail recursive. But I don't + // think such form is possible to appear. + CHECK(false) << "body and bind does not match, no tail recursive\n"; + } + } else { + // If it's not a var we know in the body there are other computations to compute return values + // Example: let x = expr1 in expr2 + this->VisitExpr(let_node->body); + } + } + + void VisitExpr_(const CallNode* call_node) { + Expr op = call_node->op; + auto call = GetRef(call_node); + + if (auto callee_node = op.as()) { + auto callee = GetRef(callee_node); + if (!callee.same_as(var)) { + is_tail_recursive = false; + } else { + auto call = GetRef(call_node); + this->tail_calls.insert(call); + } + } else { + is_tail_recursive = false; + } + } +}; + +std::unordered_set FindTailRecursion(const GlobalVar& var, const Function& func) { + auto tail_recursion = TailRecursion(var); + tail_recursion.VisitExpr(func); + if (!tail_recursion.is_tail_recursive) { + // Clear other tail callls sine the function is not fully tail recursive + tail_recursion.tail_calls.clear(); + } + return tail_recursion.tail_calls; +} + class VMFunctionCompiler : ExprFunctor { public: VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host) @@ -617,6 +681,18 @@ class VMFunctionCompiler : ExprFunctor { auto arity = func->params.size(); Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister())); } else { + auto call = GetRef(call_node); + auto find_it = this->context_->tail_recursions.find(call); + if (find_it != this->context_->tail_recursions.end()) { + // Handle tail recursion + // Override parameters + for (size_t i = 0; i < func->params.size(); ++i) { + Emit(Instruction::Move(args_registers[i], var_register_map_[func->params[i]])); + } + + Emit(Instruction::Goto(-(instructions_.size()))); + return; + } Emit(Instruction::Invoke(it->second, args_registers, NewRegister())); } } else if (auto constructor_node = op.as()) { @@ -836,6 +912,8 @@ void VMCompiler::Compile(Module mod, for (auto named_func : context_.module->functions) { auto gvar = named_func.first; auto func = named_func.second; + context_.tail_recursions = FindTailRecursion(gvar, func); + VMFunctionCompiler func_compiler(&context_, targets_, target_host_); auto vm_func = func_compiler.Compile(gvar, func); diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index 215cc12c4cdb..c2d4b494dd7a 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -78,6 +78,8 @@ struct VMCompilerContext { std::vector cached_funcs; // The functions that have been lowered. std::unordered_map seen_funcs; + // Tail recursive calls + std::unordered_set tail_recursions; }; diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index a3b251c38e00..074828ef537e 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -575,6 +575,46 @@ def test_add_op_broadcast(): mod["main"] = func check_result([x_data, y_data], x_data + y_data, mod=mod) +def create_exec(f, target="llvm", params=None): + from tvm.relay import vm as _vm + if isinstance(f, relay.Expr): + mod = relay.Module() + mod["main"] = f + executable = _vm.compile(mod, target=target, params=params) + return executable + else: + assert isinstance(f, relay.Module), "expected mod as relay.Module" + executable = _vm.compile(f, target=target, params=params) + return executable + +def test_tail_recursion(): + loop = relay.GlobalVar('loop') + sb = relay.ScopeBuilder() + + x = relay.var('x', shape=()) + y = relay.var('y', shape=()) + with sb.if_scope(relay.op.less(x, relay.const(10.0))): + x1 = x + relay.const(1.0) + y1 = y * relay.const(2.0) + x2 = relay.Call(loop, [x1, y1]) + sb.ret(x2) + with sb.else_scope(): + sb.ret(y) + + body = sb.get() + f = relay.Function([x, y], body) + + # module definition + mod = relay.Module() + mod[loop] = f + + print("Mod {}".format(mod)) + one = relay.const(1.0) + mod['main'] = relay.Function([], relay.Call(loop, [one, one])) + res = create_exec(mod) + print("Opcode: \n{}".format(res.bytecode)) + res = veval(mod) + tvm.testing.assert_allclose(res.asnumpy(), 512.0) if __name__ == "__main__": test_id()