Skip to content

Commit

Permalink
Merge pull request #4 from jroesch/relay-rts-wweic-tuple
Browse files Browse the repository at this point in the history
Add tuple support in VM
  • Loading branch information
jroesch authored Feb 4, 2019
2 parents 0d06980 + bc800eb commit 5c958ec
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 1 deletion.
6 changes: 6 additions & 0 deletions include/tvm/relay/vm/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ enum struct Opcode {
Invoke,
InvokePacked,
AllocTensor,
GetField,
If,
LoadConst,
Goto,
Expand Down Expand Up @@ -113,6 +114,10 @@ struct Instruction {
struct {
size_t pc_offset;
};
struct {
size_t object_offset;
size_t field_index;
};
};

Instruction();
Expand All @@ -124,6 +129,7 @@ Instruction Push(size_t stack_index);
Instruction Ret();
Instruction InvokePacked(size_t stack_index);
Instruction AllocTensor(std::vector<int64_t> shape, std::string dtype);
Instruction GetField(size_t object_offset, size_t field_index);


struct VMFunction {
Expand Down
37 changes: 37 additions & 0 deletions src/relay/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ Instruction::Instruction(const Instruction& instr) {
case Opcode::LoadConst:
this->const_index = instr.const_index;
return;
case Opcode::GetField:
this->object_offset = instr.object_offset;
this->field_index = instr.field_index;
return;
case Opcode::Goto:
this->pc_offset = instr.pc_offset;
return;
Expand Down Expand Up @@ -90,6 +94,14 @@ Instruction AllocTensor(const std::vector<int64_t> shape, DLDataType dtype) {
return instr;
}

Instruction GetField(size_t object_offset, size_t field_index) {
Instruction instr;
instr.op = Opcode::GetField;
instr.object_offset = object_offset;
instr.field_index = field_index;
return instr;
}

Instruction If(size_t true_branch, size_t false_branch) {
Instruction instr;
instr.op = Opcode::If;
Expand Down Expand Up @@ -162,6 +174,12 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
<< instr.const_index;
break;
}
case Opcode::GetField: {
os << "get_field "
<< instr.object_offset << " "
<< instr.field_index;
break;
}
case Opcode::Goto: {
os << "goto "
<< instr.pc_offset;
Expand Down Expand Up @@ -259,6 +277,12 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
this->VisitExpr(let_node->body);
}

void VisitExpr_(const TupleGetItemNode* get_node) {
auto get = GetRef<TupleGetItem>(get_node);
this->VisitExpr(get->tuple);
Emit(GetField(this->stack_index++, get->index));
}

void VisitExpr_(const GlobalVarNode* gvar) {
auto global = GetRef<GlobalVar>(gvar);
auto it = this->context->global_map.find(global);
Expand Down Expand Up @@ -558,6 +582,15 @@ void VirtualMachine::Run() {
pc++;
goto main_loop;
}
case Opcode::GetField: {
auto object = stack[bp + instr.object_offset];
CHECK(object->tag == VMObjectTag::kDatatype) << "Object is not data type object";
auto tuple = std::dynamic_pointer_cast<VMDatatypeCell>(object);
auto field = tuple->fields[instr.field_index];
stack.push_back(field);
pc++;
goto main_loop;
}
case Opcode::Goto: {
pc += instr.pc_offset + 1;
goto main_loop;
Expand Down Expand Up @@ -656,6 +689,10 @@ Value ConvertVMToValue(VMObject obj) {
case VMObjectTag::kTensor: {
return TensorValueNode::make(ToNDArray(obj));
}
case VMObjectTag::kDatatype: {
LOG(FATAL) << "unsupported return value: data type";
return Value();
}
default:
LOG(FATAL) << "unsupported return value";
return Value();
Expand Down
11 changes: 10 additions & 1 deletion tests/python/relay/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,20 @@ def test_tuple_fst():
ttype = relay.TupleType([relay.TensorType((1,)), relay.TensorType((10,))])
tup = relay.var('tup', type_annotation=ttype)
f = relay.Function([tup], relay.TupleGetItem(tup, 0))
i_data = np.random.rand(1).astype('float32')
i_data = np.random.rand(41).astype('float32')
j_data = np.random.rand(10).astype('float32')
result = eval_vm(f, tvm.cpu(), (i_data, j_data))
tvm.testing.assert_allclose(result.asnumpy(), i_data)

def test_tuple_second():
ttype = relay.TupleType([relay.TensorType((1,)), relay.TensorType((10,))])
tup = relay.var('tup', type_annotation=ttype)
f = relay.Function([tup], relay.TupleGetItem(tup, 1))
i_data = np.random.rand(41).astype('float32')
j_data = np.random.rand(10).astype('float32')
result = eval_vm(f, tvm.cpu(), (i_data, j_data))
tvm.testing.assert_allclose(result.asnumpy(), j_data)

def test_let_tensor():
sb = relay.ScopeBuilder()
shape = (1,)
Expand Down

0 comments on commit 5c958ec

Please sign in to comment.