diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index fc52a8e939d4..e3c8d12a6e66 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -366,7 +366,9 @@ class VMFunctionCompiler : ExprFunctor { this->Emit(Instruction::If(test_register, target_register, 0, 0)); this->VisitExpr(if_node->true_branch); - size_t true_register = last_register_; + // It saves the result of If-Else expression. + auto merge_register = NewRegister(); + Emit(Instruction::Move(last_register_, merge_register)); Emit(Instruction::Goto(0)); // Finally store how many instructions there are in the @@ -378,7 +380,7 @@ class VMFunctionCompiler : ExprFunctor { size_t false_register = last_register_; // In else-branch, override the then-branch register - Emit(Instruction::Move(false_register, true_register)); + Emit(Instruction::Move(false_register, merge_register)); // Compute the total number of instructions // after generating false. auto after_false = this->instructions_.size(); @@ -397,7 +399,7 @@ class VMFunctionCompiler : ExprFunctor { // Patch the Goto. this->instructions_[after_true - 1].pc_offset = (after_false - after_true) + 1; - this->last_register_ = true_register; + this->last_register_ = merge_register; } void EmitShapeFunc(Function func, Array inputs, Array outputs) { diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 02f1e5b753f8..a8ac27a11c0f 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -142,6 +142,25 @@ def test_simple_if(): # diff check_result([x_data, y_data], y_data, mod=mod) +def test_multiple_ifs(): + mod = tvm.IRModule({}) + b = relay.var('b') + v0 = relay.var('v0') + v1 = relay.var('v1') + v2 = relay.var('v2') + v3 = relay.var('v3') + out = relay.Tuple([v2, v3]) + out = relay.Let(v3, relay.If(b, v1, v0), out) + out = relay.Let(v2, relay.If(b, v0, v1), out) + out = relay.Let(v1, relay.Tuple([relay.const(1)]), out) + out = relay.Let(v0, relay.Tuple([relay.const(0)]), out) + fn = relay.Function([b], out) + mod['main'] = fn + ctx = tvm.runtime.ndarray.context('llvm', 0) + vm = relay.create_executor(ctx=ctx, mod=mod, kind='vm') + res = vmobj_to_list(vm.evaluate()(False)) + assert(res == [1, 0]) + def test_simple_call(): mod = tvm.IRModule({}) sum_up = relay.GlobalVar('sum_up')