diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index 028a5ff9d1ad..c8171d2c71f5 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -61,9 +61,11 @@ enum class Opcode { AllocClosure = 8U, GetField = 9U, If = 10U, - Select = 11U, - LoadConst = 12U, - Goto = 13U + LoadConst = 11U, + Goto = 12U, + GetTag = 13U, + LoadConsti = 14U, + Fatal = 15U, }; /*! \brief A single virtual machine instruction. @@ -123,22 +125,16 @@ struct Instruction { /*! \brief The arguments to pass to the packed function. */ RegName* packed_args; }; - struct /* Select Operands */ { - /*! \brief The condition of select. */ - RegName select_cond; - /*! \brief The true branch. */ - RegName select_op1; - /*! \brief The false branch. */ - RegName select_op2; - }; struct /* If Operands */ { - /*! \brief The register containing the condition value. */ - RegName if_cond; + /*! \brief The register containing the test value. */ + RegName test; + /*! \brief The register containing the target value. */ + RegName target; /*! \brief The program counter offset for the true branch. */ Index true_offset; /*! \brief The program counter offset for the false branch. */ Index false_offset; - }; + } if_op; struct /* Invoke Operands */ { /*! \brief The function to call. */ Index func_index; @@ -151,6 +147,10 @@ struct Instruction { /* \brief The index into the constant pool. */ Index const_index; }; + struct /* LoadConsti Operands */ { + /* \brief The index into the constant pool. */ + size_t val; + } load_consti; struct /* Jump Operands */ { /*! \brief The jump offset. */ Index pc_offset; @@ -161,6 +161,10 @@ struct Instruction { /*! \brief The field to read out. */ Index field_index; }; + struct /* GetTag Operands */ { + /*! \brief The register to project from. */ + RegName object; + } get_tag; struct /* AllocDatatype Operands */ { /*! \brief The datatype's constructor tag. */ Index constructor_tag; @@ -179,19 +183,15 @@ struct Instruction { }; }; - /*! \brief Construct a select instruction. - * \param cond The condition register. - * \param op1 The true register. - * \param op2 The false register. - * \param dst The destination register. - * \return The select instruction. - */ - static Instruction Select(RegName cond, RegName op1, RegName op2, RegName dst); /*! \brief Construct a return instruction. * \param return_reg The register containing the return value. * \return The return instruction. * */ static Instruction Ret(RegName return_reg); + /*! \brief Construct a fatal instruction. + * \return The fatal instruction. + * */ + static Instruction Fatal(); /*! \brief Construct a invoke packed instruction. * \param packed_index The index of the packed function. * \param arity The arity of the function. @@ -240,13 +240,20 @@ struct Instruction { * \return The get field instruction. */ static Instruction GetField(RegName object_reg, Index field_index, RegName dst); + /*! \brief Construct a get_tag instruction. + * \param object_reg The register containing the object to project from. + * \param dst The destination register. + * \return The get_tag instruction. + */ + static Instruction GetTag(RegName object_reg, RegName dst); /*! \brief Construct an if instruction. - * \param cond_reg The register containing the condition. + * \param test The register containing the test value. + * \param target The register containing the target value. * \param true_branch The offset to the true branch. * \param false_branch The offset to the false branch. * \return The if instruction. */ - static Instruction If(RegName cond_reg, Index true_branch, Index false_branch); + static Instruction If(RegName test, RegName target, Index true_branch, Index false_branch); /*! \brief Construct a goto instruction. * \param pc_offset The offset from the current pc. * \return The goto instruction. @@ -272,6 +279,12 @@ struct Instruction { * \return The load constant instruction. */ static Instruction LoadConst(Index const_index, RegName dst); + /*! \brief Construct a load_constanti instruction. + * \param val The interger constant value. + * \param dst The destination register. + * \return The load_constanti instruction. + */ + static Instruction LoadConsti(size_t val, RegName dst); /*! \brief Construct a move instruction. * \param src The source register. * \param dst The destination register. @@ -398,6 +411,12 @@ struct VirtualMachine { */ inline Object ReadRegister(RegName reg) const; + /*! \brief Read a VM register and cast it to int32_t + * \param reg The register to read from. + * \return The read scalar. + */ + int32_t LoadScalarInt(RegName reg) const; + /*! \brief Invoke a VM function. * \param func The function. * \param args The arguments to the function. diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index dae90aad17d9..6bbfa6fbc653 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -35,6 +35,7 @@ #include #include "../../../runtime/vm/naive_allocator.h" #include "../../backend/compile_engine.h" +#include "../../pass/pass_util.h" namespace tvm { namespace relay { @@ -122,6 +123,49 @@ std::tuple LayoutConstantPool(const Module& modul void InstructionPrint(std::ostream& os, const Instruction& instr); +// Represent a runtime object that's going to be matched by pattern match expressions +struct MatchValue { + virtual ~MatchValue() {} +}; +using MatchValuePtr = std::shared_ptr; + +// A runtime object that resides in a register +struct RegisterValue : MatchValue { + // The register num + RegName rergister_num; + + explicit RegisterValue(RegName reg) : rergister_num(reg) {} + + ~RegisterValue() {} +}; + +// The value is a field of another runtime object +struct AccessField : MatchValue { + MatchValuePtr parent; + // Field index + size_t index; + // Runtime register num after compiling the access field path + RegName reg{-1}; + + AccessField(MatchValuePtr parent, size_t index) + : parent(parent), index(index) {} + + ~AccessField() {} +}; + +struct VMCompiler; + +/*! + * \brief Compile a pattern match expression + * It first converts the pattern match expression into a desicision tree, the condition + * could be object comparison or variable binding. If any of the condition fails in a clause, + * the decision tree switches to check the conditions of next clause and so on. If no clause + * matches the value, a fatal node is inserted. + * + * After the decision tree is built, we convert it into bytecodes using If/Goto. + */ +void CompileMatch(Match match, VMCompiler* compiler); + struct VMCompiler : ExprFunctor { /*! \brief Store the expression a variable points to. */ std::unordered_map expr_map; @@ -159,8 +203,9 @@ struct VMCompiler : ExprFunctor { case Opcode::AllocTensor: case Opcode::AllocTensorReg: case Opcode::GetField: + case Opcode::GetTag: case Opcode::LoadConst: - case Opcode::Select: + case Opcode::LoadConsti: case Opcode::Invoke: case Opcode::AllocClosure: case Opcode::Move: @@ -173,6 +218,7 @@ struct VMCompiler : ExprFunctor { case Opcode::If: case Opcode::Ret: case Opcode::Goto: + case Opcode::Fatal: break; } instructions.push_back(instr); @@ -211,8 +257,9 @@ struct VMCompiler : ExprFunctor { void VisitExpr_(const MatchNode* match_node) { auto match = GetRef(match_node); - LOG(FATAL) << "translation of match nodes to the VM is" - << "currently unsupported"; + + this->VisitExpr(match->data); + CompileMatch(match, this); } void VisitExpr_(const LetNode* let_node) { @@ -242,15 +289,15 @@ struct VMCompiler : ExprFunctor { void VisitExpr_(const IfNode* if_node) { this->VisitExpr(if_node->cond); - size_t cond_register = last_register; + size_t test_register = last_register; + this->Emit(Instruction::LoadConsti(1, NewRegister())); auto after_cond = this->instructions.size(); - - this->Emit(Instruction::If(cond_register, 0, 0)); + auto target_register = this->last_register; + this->Emit(Instruction::If(test_register, target_register, 0, 0)); this->VisitExpr(if_node->true_branch); size_t true_register = last_register; - Emit(Instruction::Goto(0)); // Finally store how many instructions there are in the @@ -261,6 +308,8 @@ struct VMCompiler : ExprFunctor { size_t false_register = last_register; + // In else-branch, override the then-branch register + Emit(Instruction::Move(false_register, true_register)); // Compute the total number of instructions // after generating false. auto after_false = this->instructions.size(); @@ -273,13 +322,13 @@ struct VMCompiler : ExprFunctor { // we patch up the if instruction, and goto. auto true_offset = 1; auto false_offset = after_true - after_cond; - this->instructions[after_cond].true_offset = true_offset; - this->instructions[after_cond].false_offset = false_offset; + this->instructions[after_cond].if_op.true_offset = true_offset; + this->instructions[after_cond].if_op.false_offset = false_offset; // Patch the Goto. this->instructions[after_true - 1].pc_offset = (after_false - after_true) + 1; - Emit(Instruction::Select(cond_register, true_register, false_register, NewRegister())); + this->last_register = true_register; } Instruction AllocTensorFromType(const TensorTypeNode* ttype) { @@ -464,6 +513,160 @@ struct VMCompiler : ExprFunctor { } }; +/*! + * \brief Compile a match value + * Generate byte code that compute the value specificed in val + * + * \return The register number assigned for the final value + */ +RegName CompileMatchValue(MatchValuePtr val, VMCompiler* compiler) { + if (std::dynamic_pointer_cast(val)) { + auto r = std::dynamic_pointer_cast(val); + return r->rergister_num; + } else { + auto path = std::dynamic_pointer_cast(val); + auto p = CompileMatchValue(path->parent, compiler); + compiler->Emit(Instruction::GetField(p, path->index, compiler->NewRegister())); + path->reg = compiler->last_register; + return path->reg; + } +} + +/*! + * \brief Condition in a decision tree + */ +struct ConditionNode { + virtual ~ConditionNode() {} +}; + +using ConditionNodePtr = std::shared_ptr; + +/*! + * \brief A var binding condition + */ +struct VarBinding : ConditionNode { + Var var; + MatchValuePtr val; + + VarBinding(Var var, MatchValuePtr val) + : var(var), val(val) {} + + ~VarBinding() {} +}; + +/*! + * \brief Compare the tag of the object + */ +struct TagCompare : ConditionNode { + /*! \brief The object to be examined */ + MatchValuePtr obj; + + /*! \brief The expected tag */ + int target_tag; + + TagCompare(MatchValuePtr obj, size_t target) + : obj(obj), target_tag(target) { + } + + ~TagCompare() {} +}; + +using TreeNodePtr = typename relay::TreeNode::pointer; +using TreeLeafNode = relay::TreeLeafNode; +using TreeLeafFatalNode = relay::TreeLeafFatalNode; +using TreeBranchNode = relay::TreeBranchNode; + +void CompileTreeNode(TreeNodePtr tree, VMCompiler* compiler) { + if (std::dynamic_pointer_cast(tree)) { + auto node = std::dynamic_pointer_cast(tree); + compiler->VisitExpr(node->body); + } else if (std::dynamic_pointer_cast(tree)) { + compiler->Emit(Instruction::Fatal()); + } else if (std::dynamic_pointer_cast(tree)) { + auto node = std::dynamic_pointer_cast(tree); + if (std::dynamic_pointer_cast(node->cond)) { + // For Tag compariton, generate branches + auto cond = std::dynamic_pointer_cast(node->cond); + auto r = CompileMatchValue(cond->obj, compiler); + compiler->Emit(Instruction::GetTag(r, compiler->NewRegister())); + auto operand1 = compiler->last_register; + compiler->Emit(Instruction::LoadConsti(cond->target_tag, compiler->NewRegister())); + auto operand2 = compiler->last_register; + + compiler->Emit(Instruction::If(operand1, operand2, 1, 0)); + auto cond_offset = compiler->instructions.size() - 1; + CompileTreeNode(node->then_branch, compiler); + auto if_reg = compiler->last_register; + compiler->Emit(Instruction::Goto(1)); + auto goto_offset = compiler->instructions.size() - 1; + CompileTreeNode(node->else_branch, compiler); + auto else_reg = compiler->last_register; + compiler->Emit(Instruction::Move(else_reg, if_reg)); + compiler->last_register = if_reg; + auto else_offset = compiler->instructions.size() - 1; + // Fixing offsets + compiler->instructions[cond_offset].if_op.false_offset = goto_offset - cond_offset + 1; + compiler->instructions[goto_offset].pc_offset = else_offset - goto_offset + 1; + } else { + // For other non-branch conditions, move to then_branch directly + auto cond = std::dynamic_pointer_cast(node->cond); + compiler->var_register_map[cond->var] = CompileMatchValue(cond->val, compiler); + CompileTreeNode(node->then_branch, compiler); + } + } +} + +TreeNodePtr BuildDecisionTreeFromPattern(MatchValuePtr data, + Pattern pattern, + TreeNodePtr then_branch, + TreeNodePtr else_branch) { + if (pattern.as()) { + // We ignore wildcard binding since it's not producing new vars + return then_branch; + } else if (pattern.as()) { + auto pat = pattern.as(); + auto pattern = GetRef(pat); + auto cond = std::make_shared(pattern->var, data); + return TreeBranchNode::Make(cond, then_branch, else_branch); + } else { + auto pat = pattern.as(); + auto pattern = GetRef(pat); + auto tag = pattern->constructor->tag; + + size_t field_index = 0; + for (auto& p : pattern->patterns) { + auto d = std::make_shared(data, field_index); + then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch); + field_index++; + } + auto cond = std::make_shared(data, tag); + return TreeBranchNode::Make(cond, then_branch, else_branch); + } +} + +TreeNodePtr BuildDecisionTreeFromClause(MatchValuePtr data, + Clause clause, + TreeNodePtr else_branch) { + return BuildDecisionTreeFromPattern(data, clause->lhs, + TreeLeafNode::Make(clause->rhs), else_branch); +} + +TreeNodePtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array clauses) { + // When nothing matches, the VM throws fatal error + TreeNodePtr else_branch = TreeLeafFatalNode::Make(); + // Start from the last clause + for (auto it = clauses.rbegin(); it != clauses.rend(); ++it) { + else_branch = BuildDecisionTreeFromClause(data, *it, else_branch); + } + return else_branch; +} + +void CompileMatch(Match match, VMCompiler* compiler) { + auto data = std::make_shared(compiler->last_register); + auto decision_tree = BuildDecisionTreeFromClauses(data, match->clauses); + CompileTreeNode(decision_tree, compiler); +} + void PopulatePackedFuncMap(const std::vector& lowered_funcs, std::vector* packed_funcs) { runtime::Module mod; diff --git a/src/relay/pass/pass_util.h b/src/relay/pass/pass_util.h index 386d1d889ea8..7f7fc8060e85 100644 --- a/src/relay/pass/pass_util.h +++ b/src/relay/pass/pass_util.h @@ -21,7 +21,7 @@ * Copyright (c) 2018 by Contributors. * * \file tvm/relay/pass/pass_util.h - * \brief Utilities for writing + * \brief Utilities for writing passes */ #ifndef TVM_RELAY_PASS_PASS_UTIL_H_ #define TVM_RELAY_PASS_PASS_UTIL_H_ @@ -29,6 +29,7 @@ #include #include #include +#include #include namespace tvm { @@ -108,6 +109,63 @@ inline bool IsAtomic(const Expr& e) { return e.as() || e.as() || e.as() || e.as(); } +template +struct TreeNode { + typedef std::shared_ptr> pointer; + virtual ~TreeNode() {} +}; + +template +struct TreeLeafNode : TreeNode { + using TreeNodePtr = typename TreeNode::pointer; + + Expr body; + + explicit TreeLeafNode(Expr body): body(body) {} + + static TreeNodePtr Make(Expr body) { + return std::make_shared(body); + } + + ~TreeLeafNode() {} +}; + +template +struct TreeLeafFatalNode : TreeNode { + using TreeNodePtr = typename TreeNode::pointer; + + TreeLeafFatalNode() = default; + + static TreeNodePtr Make() { + return std::make_shared(); + } + + ~TreeLeafFatalNode() {} +}; + +template +struct TreeBranchNode : TreeNode { + using TreeNodePtr = typename TreeNode::pointer; + + ConditionNodePtr cond; + TreeNodePtr then_branch; + TreeNodePtr else_branch; + + TreeBranchNode(ConditionNodePtr cond, + TreeNodePtr then_branch, + TreeNodePtr else_branch) + : cond(cond), then_branch(then_branch), else_branch(else_branch) {} + + + static TreeNodePtr Make(ConditionNodePtr cond, + TreeNodePtr then_branch, + TreeNodePtr else_branch) { + return std::make_shared(cond, then_branch, else_branch); + } + + ~TreeBranchNode() {} +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_PASS_PASS_UTIL_H_ diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 3df4047ddceb..26272d339022 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -58,10 +58,7 @@ Instruction::Instruction(const Instruction& instr) { case Opcode::Move: this->from = instr.from; return; - case Opcode::Select: - this->select_cond = instr.select_cond; - this->select_op1 = instr.select_op1; - this->select_op2 = instr.select_op2; + case Opcode::Fatal: return; case Opcode::Ret: this->result = instr.result; @@ -103,17 +100,21 @@ Instruction::Instruction(const Instruction& instr) { this->invoke_args_registers = Duplicate(instr.invoke_args_registers, instr.num_args); return; case Opcode::If: - this->if_cond = instr.if_cond; - this->true_offset = instr.true_offset; - this->false_offset = instr.false_offset; + this->if_op = instr.if_op; return; case Opcode::LoadConst: this->const_index = instr.const_index; return; + case Opcode::LoadConsti: + this->load_consti = instr.load_consti; + return; case Opcode::GetField: this->object = instr.object; this->field_index = instr.field_index; return; + case Opcode::GetTag: + this->get_tag = instr.get_tag; + return; case Opcode::Goto: this->pc_offset = instr.pc_offset; return; @@ -139,10 +140,10 @@ Instruction& Instruction::operator=(const Instruction& instr) { case Opcode::Move: this->from = instr.from; return *this; - case Opcode::Select: - this->select_cond = instr.select_cond; - this->select_op1 = instr.select_op1; - this->select_op2 = instr.select_op2; + case Opcode::Fatal: + return *this; + case Opcode::LoadConsti: + this->load_consti = instr.load_consti; return *this; case Opcode::Ret: this->result = instr.result; @@ -189,9 +190,7 @@ Instruction& Instruction::operator=(const Instruction& instr) { this->invoke_args_registers = Duplicate(instr.invoke_args_registers, instr.num_args); return *this; case Opcode::If: - this->if_cond = instr.if_cond; - this->true_offset = instr.true_offset; - this->false_offset = instr.false_offset; + this->if_op = instr.if_op; return *this; case Opcode::LoadConst: this->const_index = instr.const_index; @@ -200,6 +199,9 @@ Instruction& Instruction::operator=(const Instruction& instr) { this->object = instr.object; this->field_index = instr.field_index; return *this; + case Opcode::GetTag: + this->get_tag = instr.get_tag; + return *this; case Opcode::Goto: this->pc_offset = instr.pc_offset; return *this; @@ -213,13 +215,15 @@ Instruction& Instruction::operator=(const Instruction& instr) { Instruction::~Instruction() { switch (this->op) { case Opcode::Move: - case Opcode::Select: case Opcode::Ret: case Opcode::AllocTensorReg: case Opcode::If: case Opcode::LoadConst: case Opcode::GetField: + case Opcode::GetTag: case Opcode::Goto: + case Opcode::LoadConsti: + case Opcode::Fatal: return; case Opcode::AllocTensor: delete this->alloc_tensor.shape; @@ -252,6 +256,12 @@ Instruction Instruction::Ret(RegName result) { return instr; } +Instruction Instruction::Fatal() { + Instruction instr; + instr.op = Opcode::Fatal; + return instr; +} + Instruction Instruction::InvokePacked(Index packed_index, Index arity, Index output_size, const std::vector& args) { Instruction instr; @@ -325,22 +335,21 @@ Instruction Instruction::GetField(RegName object, Index field_index, RegName dst return instr; } -Instruction Instruction::If(RegName cond, Index true_branch, Index false_branch) { +Instruction Instruction::GetTag(RegName object, RegName dst) { Instruction instr; - instr.op = Opcode::If; - instr.if_cond = cond; - instr.true_offset = true_branch; - instr.false_offset = false_branch; + instr.op = Opcode::GetTag; + instr.dst = dst; + instr.get_tag.object = object; return instr; } -Instruction Instruction::Select(RegName cond, RegName op1, RegName op2, RegName dst) { +Instruction Instruction::If(RegName test, RegName target, Index true_branch, Index false_branch) { Instruction instr; - instr.op = Opcode::Select; - instr.dst = dst; - instr.select_cond = cond; - instr.select_op1 = op1; - instr.select_op2 = op2; + instr.op = Opcode::If; + instr.if_op.test = test; + instr.if_op.target = target; + instr.if_op.true_offset = true_branch; + instr.if_op.false_offset = false_branch; return instr; } @@ -387,6 +396,14 @@ Instruction Instruction::LoadConst(Index const_index, RegName dst) { return instr; } +Instruction Instruction::LoadConsti(size_t val, RegName dst) { + Instruction instr; + instr.op = Opcode::LoadConsti; + instr.dst = dst; + instr.load_consti.val = val; + return instr; +} + Instruction Instruction::Move(RegName src, RegName dst) { Instruction instr; instr.op = Opcode::Move; @@ -437,6 +454,10 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { os << "ret $" << instr.result; break; } + case Opcode::Fatal: { + os << "fatal"; + break; + } case Opcode::InvokePacked: { os << "invoke_packed PackedFunc[" << instr.packed_index << "](in: $" << StrJoin(instr.packed_args, 0, instr.arity - instr.output_size, ",$") @@ -471,8 +492,8 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { break; } case Opcode::If: { - os << "if " << "$" << instr.if_cond << " " << instr.true_offset << " " - << instr.false_offset; + os << "if " << "$" << instr.if_op.test << " " << instr.if_op.target << " " + << instr.if_op.true_offset << " " << instr.if_op.false_offset; break; } case Opcode::Invoke: { @@ -491,18 +512,21 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { os << "load_const $" << instr.dst << " Const[" << instr.const_index << "]"; break; } + case Opcode::LoadConsti: { + os << "load_consti $" << instr.dst << " Const[" << instr.load_consti.val << "]"; + break; + } case Opcode::GetField: { os << "get_field $" << instr.dst << " $" << instr.object << "[" << instr.field_index << "]"; break; } - case Opcode::Goto: { - os << "goto " << instr.pc_offset; + case Opcode::GetTag: { + os << "get_tag $" << instr.dst << " $" << instr.get_tag.object; break; } - case Opcode::Select: { - os << "select $" << instr.dst << " $" << instr.select_cond << " $" - << instr.select_op1 << " $" << instr.select_op2; + case Opcode::Goto: { + os << "goto " << instr.pc_offset; break; } default: @@ -617,6 +641,21 @@ inline Object VirtualMachine::ReadRegister(Index r) const { return frames.back().register_file[r]; } +inline int32_t VirtualMachine::LoadScalarInt(Index r) const { + int32_t result; + const auto& obj = ReadRegister(r); + NDArray array = ToNDArray(obj).CopyTo({kDLCPU, 0}); + + if (array->dtype.bits <= 8) { + result = reinterpret_cast(array->data)[0]; + } else if (array->dtype.bits <= 16) { + result = reinterpret_cast(array->data)[0]; + } else { + result = reinterpret_cast(array->data)[0]; + } + return result; +} + void VirtualMachine::Run() { CHECK(this->code); this->pc = 0; @@ -632,20 +671,26 @@ void VirtualMachine::Run() { switch (instr.op) { case Opcode::Move: { Object from_obj; - if (instr.from == 0) { - from_obj = return_register; - } else { - from_obj = ReadRegister(instr.from); - } + from_obj = ReadRegister(instr.from); WriteRegister(instr.dst, from_obj); pc++; goto main_loop; } + case Opcode::Fatal: { + throw std::runtime_error("VM encountered fatal error"); + } case Opcode::LoadConst: { WriteRegister(instr.dst, this->constants[instr.const_index]); pc++; goto main_loop; } + case Opcode::LoadConsti: { + auto tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, {kDLCPU, 0}); + reinterpret_cast(tensor->data)[0] = instr.load_consti.val; + WriteRegister(instr.dst, Object::Tensor(tensor)); + pc++; + goto main_loop; + } case Opcode::Invoke: { std::vector args; for (Index i = 0; i < instr.num_args; ++i) { @@ -695,25 +740,34 @@ void VirtualMachine::Run() { pc++; goto main_loop; } + case Opcode::GetTag: { + auto object = ReadRegister(instr.get_tag.object); + CHECK(object->tag == ObjectTag::kDatatype) + << "Object is not data type object, register " + << instr.get_tag.object << ", Object tag " + << static_cast(object->tag); + const auto& data = object.AsDatatype(); + auto tag = data->tag; + auto tag_tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, {kDLCPU, 0}); + reinterpret_cast(tag_tensor->data)[0] = tag; + WriteRegister(instr.dst, Object::Tensor(tag_tensor)); + pc++; + goto main_loop; + } case Opcode::Goto: { pc += instr.pc_offset; goto main_loop; } case Opcode::If: { - // How do we do this efficiently? - DLContext cpu_ctx; - cpu_ctx.device_type = kDLCPU; - cpu_ctx.device_id = 0; + int32_t test_val = LoadScalarInt(instr.if_op.test); + int32_t target_val = LoadScalarInt(instr.if_op.target); - const auto& cond = ReadRegister(instr.if_cond); - NDArray cpu_array = ToNDArray(cond).CopyTo(cpu_ctx); - // CHECK_EQ(cpu_array->dtype, Bool()); - bool branch = reinterpret_cast(cpu_array->data)[0]; - - if (branch) { - pc += instr.true_offset; + if (test_val == target_val) { + CHECK_NE(instr.if_op.true_offset, 0); + pc += instr.if_op.true_offset; } else { - pc += instr.false_offset; + CHECK_NE(instr.if_op.false_offset, 0); + pc += instr.if_op.false_offset; } goto main_loop; @@ -768,26 +822,6 @@ void VirtualMachine::Run() { pc++; goto main_loop; } - case Opcode::Select: { - DLContext cpu_ctx; - cpu_ctx.device_type = kDLCPU; - cpu_ctx.device_id = 0; - - auto cond = ReadRegister(instr.select_cond); - NDArray cpu_array = ToNDArray(cond).CopyTo(cpu_ctx); - // CHECK_EQ(TVMType2Type(cpu_array->dtype), Bool()); - bool branch = reinterpret_cast(cpu_array->data)[0]; - - if (branch) { - auto op1 = ReadRegister(instr.select_op1); - WriteRegister(instr.dst, op1); - } else { - auto op2 = ReadRegister(instr.select_op2); - WriteRegister(instr.dst, op2); - } - pc++; - goto main_loop; - } case Opcode::Ret: { // If we have hit the point from which we started // running, we should return to the caller breaking diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index f85d21255736..3b525ac58ce4 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import os -from nose.tools import nottest +from nose.tools import nottest, raises import tvm import numpy as np @@ -39,6 +39,15 @@ def veval(f, *args, ctx=tvm.cpu()): else: return ex.evaluate()(*args) +def vmobj_to_list(o): + if isinstance(o, tvm.relay.backend.interpreter.TensorValue): + return [o.data.asnumpy().tolist()] + if isinstance(o, tvm.relay.backend.interpreter.ConstructorValue): + result = [] + for f in o.fields: + result.extend(vmobj_to_list(f)) + return result + def test_split(): x = relay.var('x', shape=(12,)) y = relay.split(x, 3, axis=0).astuple() @@ -186,15 +195,6 @@ def test_tuple_second(): tvm.testing.assert_allclose(result.asnumpy(), j_data) def test_list_constructor(): - def to_list(o): - if isinstance(o, tvm.relay.backend.interpreter.TensorValue): - return [o.data.asnumpy().tolist()] - if isinstance(o, tvm.relay.backend.interpreter.ConstructorValue): - result = [] - for f in o.fields: - result.extend(to_list(f)) - return result - mod = relay.Module() p = Prelude(mod) @@ -202,11 +202,6 @@ def to_list(o): cons = p.cons l = p.l - # remove all functions to not have pattern match to pass vm compilation - # TODO(wweic): remove the hack and implement pattern match - for v, _ in mod.functions.items(): - mod[v] = relay.const(0) - one2 = cons(relay.const(1), nil()) one3 = cons(relay.const(2), one2) one4 = cons(relay.const(3), one3) @@ -215,7 +210,7 @@ def to_list(o): mod["main"] = f result = veval(mod)() - obj = to_list(result) + obj = vmobj_to_list(result) tvm.testing.assert_allclose(obj, np.array([3,2,1])) def test_let_tensor(): @@ -256,13 +251,6 @@ def test_compose(): compose = p.compose - # remove all functions to not have pattern match to pass vm compilation - # TODO(wweic): remove the hack and implement pattern match - for v, _ in mod.functions.items(): - if v.name_hint == 'compose': - continue - mod[v] = relay.const(0) - # add_one = fun x -> x + 1 sb = relay.ScopeBuilder() x = relay.var('x', 'float32') @@ -291,6 +279,215 @@ def test_compose(): tvm.testing.assert_allclose(result.asnumpy(), x_data + 2.0) +def test_list_hd(): + mod = relay.Module() + p = Prelude(mod) + + nil = p.nil + cons = p.cons + l = p.l + hd = p.hd + + one2 = cons(relay.const(1), nil()) + one3 = cons(relay.const(2), one2) + one4 = cons(relay.const(3), one3) + three = hd(one4) + f = relay.Function([], three) + + mod["main"] = f + + result = veval(mod)() + tvm.testing.assert_allclose(result.asnumpy(), 3) + +@raises(Exception) +def test_list_tl_empty_list(): + mod = relay.Module() + p = Prelude(mod) + + nil = p.nil + l = p.l + tl = p.tl + + f = relay.Function([], tl(nil())) + + mod["main"] = f + + result = veval(mod)() + print(result) + +def test_list_tl(): + mod = relay.Module() + p = Prelude(mod) + + nil = p.nil + cons = p.cons + l = p.l + tl = p.tl + + one2 = cons(relay.const(1), nil()) + one3 = cons(relay.const(2), one2) + one4 = cons(relay.const(3), one3) + + f = relay.Function([], tl(one4)) + + mod["main"] = f + + result = veval(mod)() + tvm.testing.assert_allclose(vmobj_to_list(result), np.array([2,1])) + +def test_list_nth(): + expected = list(range(10)) + + for i in range(len(expected)): + mod = relay.Module() + p = Prelude(mod) + + nil = p.nil + cons = p.cons + nth = p.nth + l = nil() + for i in reversed(expected): + l = cons(relay.const(i), l) + + f = relay.Function([], nth(l, relay.const(i))) + mod["main"] = f + result = veval(mod)() + tvm.testing.assert_allclose(result.asnumpy(), expected[i]) + +def test_list_update(): + expected = list(range(10)) + + mod = relay.Module() + p = Prelude(mod) + + nil = p.nil + cons = p.cons + update = p.update + + l = nil() + # create zero initialized list + for i in range(len(expected)): + l = cons(relay.const(0), l) + + # set value + for i, v in enumerate(expected): + l = update(l, relay.const(i), relay.const(v)) + + f = relay.Function([], l) + mod["main"] = f + result = veval(mod)() + tvm.testing.assert_allclose(vmobj_to_list(result), np.array(expected)) + +def test_list_length(): + expected = list(range(10)) + + mod = relay.Module() + p = Prelude(mod) + + nil = p.nil + cons = p.cons + length = p.length + + l = nil() + # create zero initialized list + for i in range(len(expected)): + l = cons(relay.const(0), l) + + l = length(l) + + f = relay.Function([], l) + mod["main"] = f + result = veval(mod)() + tvm.testing.assert_allclose(result.asnumpy(), 10) + +def test_list_map(): + mod = relay.Module() + p = Prelude(mod) + + x = relay.var('x', 'int32') + add_one_func = relay.Function([x], relay.const(1) + x) + + nil = p.nil + cons = p.cons + map = p.map + + l = cons(relay.const(2), cons(relay.const(1), nil())) + + f = relay.Function([], map(add_one_func, l)) + mod["main"] = f + result = veval(mod)() + tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 2])) + +def test_list_foldl(): + mod = relay.Module() + p = Prelude(mod) + + nil = p.nil + cons = p.cons + foldl = p.foldl + + x = relay.var("x") + y = relay.var("y") + rev_dup_func = relay.Function([y, x], cons(x, cons(x, y))) + + l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil()))) + f = relay.Function([], foldl(rev_dup_func, nil(), l)) + mod["main"] = f + result = veval(mod)() + tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 3, 2, 2, 1, 1])) + +def test_list_foldr(): + mod = relay.Module() + p = Prelude(mod) + + nil = p.nil + cons = p.cons + foldr = p.foldr + + x = relay.var("x") + y = relay.var("y") + identity_func = relay.Function([x, y], cons(x, y)) + + l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil()))) + f = relay.Function([], foldr(identity_func, nil(), l)) + mod["main"] = f + result = veval(mod)() + tvm.testing.assert_allclose(vmobj_to_list(result), np.array([1, 2, 3])) + +def test_list_sum(): + mod = relay.Module() + p = Prelude(mod) + + nil = p.nil + cons = p.cons + sum = p.sum + + l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil()))) + f = relay.Function([], sum(l)) + mod["main"] = f + result = veval(mod)() + tvm.testing.assert_allclose(result.asnumpy(), 6) + +def test_list_filter(): + mod = relay.Module() + p = Prelude(mod) + + nil = p.nil + cons = p.cons + filter = p.filter + + x = relay.var("x", 'int32') + greater_than_one = relay.Function([x], x > relay.const(1)) + l = cons(relay.const(1), + cons(relay.const(3), + cons(relay.const(1), + cons(relay.const(5), + cons(relay.const(1), nil()))))) + f = relay.Function([], filter(greater_than_one, l)) + mod["main"] = f + result = veval(mod)() + tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 5])) + def test_closure(): x = relay.var('x', shape=()) y = relay.var('y', shape=()) @@ -315,6 +512,15 @@ def test_closure(): test_let_tensor() test_split() test_split_no_fuse() - # TODO(@jroesch): restore when match is supported - # test_list_constructor() + test_list_constructor() + test_list_tl_empty_list() + test_list_tl() + test_list_nth() + test_list_update() + test_list_length() + test_list_map() + test_list_foldl() + test_list_foldr() + test_list_sum() + test_list_filter() test_closure()