diff --git a/include/tvm/relay/vm/vm.h b/include/tvm/relay/vm/vm.h index 2421564cc42e3..ccfb025a67b31 100644 --- a/include/tvm/relay/vm/vm.h +++ b/include/tvm/relay/vm/vm.h @@ -80,6 +80,7 @@ enum struct Opcode { Invoke, InvokePacked, AllocTensor, + GetField, If, LoadConst, Goto, @@ -113,6 +114,10 @@ struct Instruction { struct { size_t pc_offset; }; + struct { + size_t object_offset; + size_t field_index; + }; }; Instruction(); @@ -124,6 +129,7 @@ Instruction Push(size_t stack_index); Instruction Ret(); Instruction InvokePacked(size_t stack_index); Instruction AllocTensor(std::vector shape, std::string dtype); +Instruction GetField(size_t object_offset, size_t field_index); struct VMFunction { diff --git a/src/relay/vm/vm.cc b/src/relay/vm/vm.cc index 4b889886f3e2f..9d7f4b0f3b305 100644 --- a/src/relay/vm/vm.cc +++ b/src/relay/vm/vm.cc @@ -47,6 +47,10 @@ Instruction::Instruction(const Instruction& instr) { case Opcode::LoadConst: this->const_index = instr.const_index; return; + case Opcode::GetField: + this->object_offset = instr.object_offset; + this->field_index = instr.field_index; + return; case Opcode::Goto: this->pc_offset = instr.pc_offset; return; @@ -90,6 +94,14 @@ Instruction AllocTensor(const std::vector shape, DLDataType dtype) { return instr; } +Instruction GetField(size_t object_offset, size_t field_index) { + Instruction instr; + instr.op = Opcode::GetField; + instr.object_offset = object_offset; + instr.field_index = field_index; + return instr; +} + Instruction If(size_t true_branch, size_t false_branch) { Instruction instr; instr.op = Opcode::If; @@ -162,6 +174,12 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { << instr.const_index; break; } + case Opcode::GetField: { + os << "get_field " + << instr.object_offset << " " + << instr.field_index; + break; + } case Opcode::Goto: { os << "goto " << instr.pc_offset; @@ -259,6 +277,12 @@ struct VMCompiler : ExprFunctor { this->VisitExpr(let_node->body); } + void VisitExpr_(const TupleGetItemNode* get_node) { + auto get = GetRef(get_node); + this->VisitExpr(get->tuple); + Emit(GetField(this->stack_index++, get->index)); + } + void VisitExpr_(const GlobalVarNode* gvar) { auto global = GetRef(gvar); auto it = this->context->global_map.find(global); @@ -558,6 +582,15 @@ void VirtualMachine::Run() { pc++; goto main_loop; } + case Opcode::GetField: { + auto object = stack[bp + instr.object_offset]; + CHECK(object->tag == VMObjectTag::kDatatype) << "Object is not data type object"; + auto tuple = std::dynamic_pointer_cast(object); + auto field = tuple->fields[instr.field_index]; + stack.push_back(field); + pc++; + goto main_loop; + } case Opcode::Goto: { pc += instr.pc_offset + 1; goto main_loop; @@ -656,6 +689,10 @@ Value ConvertVMToValue(VMObject obj) { case VMObjectTag::kTensor: { return TensorValueNode::make(ToNDArray(obj)); } + case VMObjectTag::kDatatype: { + LOG(FATAL) << "unsupported return value: data type"; + return Value(); + } default: LOG(FATAL) << "unsupported return value"; return Value(); diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index ec00de05d6114..cbb5de2f9e007 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -115,11 +115,20 @@ def test_tuple_fst(): ttype = relay.TupleType([relay.TensorType((1,)), relay.TensorType((10,))]) tup = relay.var('tup', type_annotation=ttype) f = relay.Function([tup], relay.TupleGetItem(tup, 0)) - i_data = np.random.rand(1).astype('float32') + i_data = np.random.rand(41).astype('float32') j_data = np.random.rand(10).astype('float32') result = eval_vm(f, tvm.cpu(), (i_data, j_data)) tvm.testing.assert_allclose(result.asnumpy(), i_data) +def test_tuple_second(): + ttype = relay.TupleType([relay.TensorType((1,)), relay.TensorType((10,))]) + tup = relay.var('tup', type_annotation=ttype) + f = relay.Function([tup], relay.TupleGetItem(tup, 1)) + i_data = np.random.rand(41).astype('float32') + j_data = np.random.rand(10).astype('float32') + result = eval_vm(f, tvm.cpu(), (i_data, j_data)) + tvm.testing.assert_allclose(result.asnumpy(), j_data) + def test_let_tensor(): sb = relay.ScopeBuilder() shape = (1,)