diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 8d4f4addaca9..73a6450c16ec 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -637,6 +637,9 @@ class VMFunctionCompiler : ExprFunctor { // emit invoke closure here. VisitExpr(GetRef(var_node)); Emit(Instruction::InvokeClosure(last_register_, args_registers, NewRegister())); + } else if (auto inner_call_node = op.as()) { + VisitExpr(GetRef(inner_call_node)); + Emit(Instruction::InvokeClosure(last_register_, args_registers, NewRegister())); } else { // Finally if there are any other cases this is a bug. LOG(FATAL) << "internal error: unreachable code," diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index c4cd616cdec0..8cac656ee5a1 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -23,6 +23,7 @@ from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.testing.config import ctx_list from tvm.relay.prelude import Prelude +from tvm.relay.loops import while_loop from tvm.relay import testing def check_result(args, expected_result, mod=None): @@ -576,5 +577,31 @@ def test_vm_optimize(): comp = relay.vm.VMCompiler() opt_mod, _ = comp.optimize(mod, "llvm", params) +def test_loop_free_var(): + x = relay.var('x', shape=(), dtype='int32') + i = relay.var('i', shape=(), dtype='int32') + s = relay.var('s', shape=(), dtype='int32') + + def cond(i, _): + return i < relay.const(10, dtype='int32') + + def body_no_free_var(i, acc): + incr = relay.const(1, "int32") + return i + incr, acc + i + + def body_with_free_var(i, acc): + incr = relay.const(1, "int32") + return i + incr, acc + x + + for args, body, expected in zip([[], [1]], + [body_no_free_var, body_with_free_var], + [45, 10]): + loop = while_loop(cond, [i, s], body) + tup = loop(relay.const(0, dtype='int32'), relay.zeros(shape=(), dtype='int32')) + ret = relay.TupleGetItem(tup, 1) + mod = tvm.IRModule() + mod["main"] = relay.Function(relay.analysis.free_vars(ret), ret) + check_result(args, expected, mod=mod) + if __name__ == "__main__": pytest.main([__file__])