Skip to content

Commit

Permalink
[Relay] Fix VM compiler for while loop with free vars (apache#4889)
Browse files Browse the repository at this point in the history
* add additional switch to handle nested call node

* Fix VM compiler for while loop with free var
  • Loading branch information
masahi authored and alexwong committed Feb 28, 2020
1 parent afe236b commit 4484283
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,9 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
// emit invoke closure here.
VisitExpr(GetRef<Var>(var_node));
Emit(Instruction::InvokeClosure(last_register_, args_registers, NewRegister()));
} else if (auto inner_call_node = op.as<CallNode>()) {
VisitExpr(GetRef<Call>(inner_call_node));
Emit(Instruction::InvokeClosure(last_register_, args_registers, NewRegister()));
} else {
// Finally if there are any other cases this is a bug.
LOG(FATAL) << "internal error: unreachable code,"
Expand Down
27 changes: 27 additions & 0 deletions tests/python/relay/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.testing.config import ctx_list
from tvm.relay.prelude import Prelude
from tvm.relay.loops import while_loop
from tvm.relay import testing

def check_result(args, expected_result, mod=None):
Expand Down Expand Up @@ -576,5 +577,31 @@ def test_vm_optimize():
comp = relay.vm.VMCompiler()
opt_mod, _ = comp.optimize(mod, "llvm", params)

def test_loop_free_var():
x = relay.var('x', shape=(), dtype='int32')
i = relay.var('i', shape=(), dtype='int32')
s = relay.var('s', shape=(), dtype='int32')

def cond(i, _):
return i < relay.const(10, dtype='int32')

def body_no_free_var(i, acc):
incr = relay.const(1, "int32")
return i + incr, acc + i

def body_with_free_var(i, acc):
incr = relay.const(1, "int32")
return i + incr, acc + x

for args, body, expected in zip([[], [1]],
[body_no_free_var, body_with_free_var],
[45, 10]):
loop = while_loop(cond, [i, s], body)
tup = loop(relay.const(0, dtype='int32'), relay.zeros(shape=(), dtype='int32'))
ret = relay.TupleGetItem(tup, 1)
mod = tvm.IRModule()
mod["main"] = relay.Function(relay.analysis.free_vars(ret), ret)
check_result(args, expected, mod=mod)

if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 4484283

Please sign in to comment.