Skip to content

Commit

Permalink
[relay][tensor_array] test tensor_array in vm (apache#4608)
Browse files Browse the repository at this point in the history
* [relay] test tensor_array in vm

* add tensor_array scatter test
  • Loading branch information
zhiics committed Mar 2, 2020
1 parent 860d1ba commit 4ac827d
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 103 deletions.
11 changes: 9 additions & 2 deletions python/tvm/relay/backend/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from tvm import autotvm
from tvm.relay import expr as _expr
from tvm._ffi.runtime_ctypes import TVMByteArray
from tvm._ffi import base as _base
from . import _vm
from . import vmobj as _obj
from .interpreter import Executor
Expand All @@ -34,7 +35,9 @@
ADT = _obj.ADT

def _convert(arg, cargs):
if isinstance(arg, _obj.Object):
if isinstance(arg, _expr.Constant):
cargs.append(_obj.Tensor(arg.data))
elif isinstance(arg, _obj.Object):
cargs.append(arg)
elif isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
cargs.append(_obj.Tensor(arg))
Expand All @@ -43,8 +46,12 @@ def _convert(arg, cargs):
for field in arg:
_convert(field, field_args)
cargs.append(_obj.tuple_object(field_args))
elif isinstance(arg, (_base.numeric_types, bool)):
dtype = "int32" if isinstance(arg, (int, bool)) else "float32"
value = _obj.Tensor(np.array(arg, dtype=dtype))
cargs.append(value)
else:
raise "Unsupported type: %s" % (type(arg))
raise TypeError("Unsupported type: %s" % (type(arg)))


def convert(args):
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ def __init__(self, prelude, dtype):
self.dtype = dtype

def get_name(self, canonical):
"""Get name corresponding to the caninical name"""
"""Get name corresponding to the canonical name"""
return self.prelude.get_name(canonical, self.dtype)

def get_var(self, canonical):
"""Get var corresponding to the caninical name"""
"""Get var corresponding to the canonical name"""
return self.prelude.get_var(canonical, self.dtype)

def define_tensor_adt(self):
Expand Down
36 changes: 12 additions & 24 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,13 @@
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <tvm/relay/attrs/memory.h>
#include <topi/tags.h>
#include <algorithm>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "../../../runtime/vm/naive_allocator.h"
#include "../../backend/compile_engine.h"
#include "../../pass/pass_util.h"
#include "../../op/op_common.h"
Expand Down Expand Up @@ -73,8 +69,6 @@ using namespace relay::transform;
// (@jroesch): VM passes, eventually declare as passes.
bool IsClosure(const Function& func);

void InstructionPrint(std::ostream& os, const Instruction& instr);

// Represent a runtime object that's going to be matched by pattern match expressions
struct MatchValue {
virtual ~MatchValue() {}
Expand Down Expand Up @@ -156,12 +150,10 @@ TreeObjectPtr BuildDecisionTreeFromPattern(MatchValuePtr data,
if (pattern.as<PatternWildcardNode>()) {
// We ignore wildcard binding since it's not producing new vars
return then_branch;
} else if (pattern.as<PatternVarNode>()) {
auto pat = pattern.as<PatternVarNode>();
auto pattern = GetRef<PatternVar>(pat);
auto cond = std::make_shared<VarBinding>(pattern->var, data);
} else if (const auto* pvn = pattern.as<PatternVarNode>()) {
auto cond = std::make_shared<VarBinding>(pvn->var, data);
return TreeBranchNode::Make(cond, then_branch, else_branch);
} else if (auto pcn = pattern.as<PatternConstructorNode>()) {
} else if (const auto* pcn = pattern.as<PatternConstructorNode>()) {
auto tag = pcn->constructor->tag;

size_t field_index = 0;
Expand All @@ -173,13 +165,12 @@ TreeObjectPtr BuildDecisionTreeFromPattern(MatchValuePtr data,
auto cond = std::make_shared<TagCompare>(data, tag);
return TreeBranchNode::Make(cond, then_branch, else_branch);
} else {
auto pt = pattern.as<PatternTupleNode>();
CHECK(pt) << "unhandled case: " << pattern;
const auto* pt = pattern.as<PatternTupleNode>();
CHECK(pt) << "unhandled case: " << AsText(pattern, false);
size_t field_index = 0;
for (auto& p : pt->patterns) {
auto d = std::make_shared<AccessField>(data, field_index);
auto d = std::make_shared<AccessField>(data, field_index++);
then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
field_index++;
}
return then_branch;
}
Expand Down Expand Up @@ -633,7 +624,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
// and emit a call to allocate the data structure.
auto constructor = GetRef<Constructor>(constructor_node);
Emit(Instruction::AllocADT(constructor->tag, call_node->args.size(), args_registers,
NewRegister()));
NewRegister()));
} else if (auto var_node = op.as<VarNode>()) {
// If we are calling a variable, it must be the case that it is a closure so we
// emit invoke closure here.
Expand Down Expand Up @@ -675,16 +666,13 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
}

void CompileTreeNode(TreeObjectPtr tree) {
if (std::dynamic_pointer_cast<TreeLeafNode>(tree)) {
auto node = std::dynamic_pointer_cast<TreeLeafNode>(tree);
if (auto node = std::dynamic_pointer_cast<TreeLeafNode>(tree)) {
VisitExpr(node->body);
} else if (std::dynamic_pointer_cast<TreeLeafFatalNode>(tree)) {
Emit(Instruction::Fatal());
} else if (std::dynamic_pointer_cast<TreeBranchNode>(tree)) {
auto node = std::dynamic_pointer_cast<TreeBranchNode>(tree);
if (std::dynamic_pointer_cast<TagCompare>(node->cond)) {
} else if (auto node = std::dynamic_pointer_cast<TreeBranchNode>(tree)) {
if (auto cond = std::dynamic_pointer_cast<TagCompare>(node->cond)) {
// For Tag compariton, generate branches
auto cond = std::dynamic_pointer_cast<TagCompare>(node->cond);
auto r = CompileMatchValue(cond->obj);
Emit(Instruction::GetTag(r, NewRegister()));
auto operand1 = last_register_;
Expand All @@ -707,8 +695,8 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
instructions_[goto_offset].pc_offset = else_offset - goto_offset + 1;
} else {
// For other non-branch conditions, move to then_branch directly
auto cond = std::dynamic_pointer_cast<VarBinding>(node->cond);
var_register_map_[cond->var] = CompileMatchValue(cond->val);
auto var_bind = std::dynamic_pointer_cast<VarBinding>(node->cond);
var_register_map_[var_bind->var] = CompileMatchValue(var_bind->val);
CompileTreeNode(node->then_branch);
}
}
Expand Down
18 changes: 11 additions & 7 deletions src/runtime/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -583,9 +583,9 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
break;
}
case Opcode::AllocStorage: {
os << "alloc_storage " <<
instr.dst << " " <<
instr.alloc_storage.allocation_size << " " <<
os << "alloc_storage $" <<
instr.dst << " $" <<
instr.alloc_storage.allocation_size << " $" <<
instr.alloc_storage.alignment << " " <<
TVMType2String(instr.alloc_storage.dtype_hint);
break;
Expand Down Expand Up @@ -771,12 +771,14 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
for (size_t fi = 0; fi < dt_cell->size; ++fi) {
auto obj = (*dt_cell)[fi];
const auto* tensor = obj.as<TensorObj>();
CHECK(tensor != nullptr);
CHECK(tensor != nullptr) << "Expect tensor object, but received: "
<< obj->GetTypeKey();
setter(idx++, tensor->data);
}
} else {
const auto* tensor = args[i].as<TensorObj>();
CHECK(tensor != nullptr);
CHECK(tensor != nullptr) << "Expect tensor object, but received: "
<< args[i]->GetTypeKey();
setter(idx++, tensor->data);
}
}
Expand Down Expand Up @@ -823,7 +825,8 @@ inline int32_t VirtualMachine::LoadScalarInt(Index r) const {
int32_t result;
const auto& obj = ReadRegister(r);
const auto* tensor = obj.as<TensorObj>();
CHECK(tensor != nullptr);
CHECK(tensor != nullptr) << "Expect tensor object, but received: "
<< obj->GetTypeKey();
NDArray array = tensor->data.CopyTo({kDLCPU, 0});

if (array->dtype.bits <= 8) {
Expand Down Expand Up @@ -984,7 +987,8 @@ void VirtualMachine::RunLoop() {
cpu_ctx.device_id = 0;
auto shape_tensor_obj = ReadRegister(instr.alloc_tensor_reg.shape_register);
const auto* tensor = shape_tensor_obj.as<TensorObj>();
CHECK(tensor != nullptr);
CHECK(tensor != nullptr) << "Expect tensor object, but received: "
<< shape_tensor_obj->GetTypeKey();
NDArray shape_tensor = tensor->data.CopyTo(cpu_ctx);
const DLTensor* dl_tensor = shape_tensor.operator->();
CHECK_EQ(dl_tensor->dtype.code, 0u);
Expand Down
Loading

0 comments on commit 4ac827d

Please sign in to comment.