From f5258a83ee421ee7c20605a5f97555f1a36dfbc5 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Tue, 22 Jan 2019 14:30:53 -0800 Subject: [PATCH] move fix test fix lint fix test add more code fix lint --- include/tvm/relay/expr.h | 68 ++++++++++++++++++- include/tvm/relay/expr_functor.h | 12 ++++ include/tvm/relay/interpreter.h | 19 ++++++ include/tvm/relay/type.h | 27 ++++++++ python/tvm/relay/__init__.py | 12 ++-- .../relay/backend/graph_runtime_codegen.py | 9 +++ python/tvm/relay/backend/interpreter.py | 8 +++ python/tvm/relay/expr.py | 40 +++++++++++ python/tvm/relay/expr_functor.py | 14 ++++ python/tvm/relay/ty.py | 13 ++++ src/relay/backend/interpreter.cc | 42 ++++++++++++ src/relay/ir/alpha_equal.cc | 31 +++++++++ src/relay/ir/expr.cc | 51 +++++++++++++- src/relay/ir/expr_functor.cc | 41 +++++++++++ src/relay/ir/hash.cc | 30 +++++++- src/relay/ir/text_printer.cc | 32 +++++++++ src/relay/ir/type.cc | 19 ++++++ src/relay/ir/type_functor.cc | 8 +++ src/relay/ir/type_functor.h | 5 +- src/relay/pass/fuse_ops.cc | 36 +++++++--- src/relay/pass/kind_check.cc | 6 ++ src/relay/pass/type_infer.cc | 39 +++++++++++ src/relay/pass/type_solver.cc | 14 +++- .../python/relay/test_backend_interpreter.py | 19 ++++++ tests/python/relay/test_type_infer.py | 17 +++-- 25 files changed, 585 insertions(+), 27 deletions(-) diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 14b3cd91701c6..3eb1bba8667fe 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -428,12 +428,78 @@ class TupleGetItemNode : public ExprNode { TVM_DLL static TupleGetItem make(Expr tuple, int index); - static constexpr const char * _type_key = "relay.TupleGetItem"; + static constexpr const char* _type_key = "relay.TupleGetItem"; TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode); }; RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr); +/*! \brief Create a new Reference out of initial value. */ +class RefNew; +class RefNewNode : public ExprNode { + public: + /*! \brief The initial value of the Reference. */ + Expr value; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("value", &value); + v->Visit("span", &span); + v->Visit("_checked_type_", &checked_type_); + } + + TVM_DLL static RefNew make(Expr value); + + static constexpr const char* _type_key = "relay.RefNew"; + TVM_DECLARE_NODE_TYPE_INFO(RefNewNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(RefNew, RefNewNode, Expr); + +/*! \brief Get value out of Reference. */ +class RefRead; +class RefReadNode : public ExprNode { + public: + /*! \brief The Reference Expression. */ + Expr ref; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("ref", &ref); + v->Visit("span", &span); + v->Visit("_checked_type_", &checked_type_); + } + + TVM_DLL static RefRead make(Expr ref); + + static constexpr const char* _type_key = "relay.RefRead"; + TVM_DECLARE_NODE_TYPE_INFO(RefReadNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(RefRead, RefReadNode, Expr); + +/*! \brief Set value of Reference. The whole expression evaluate to an Empty Tuple. */ +class RefWrite; +class RefWriteNode : public ExprNode { + public: + /*! \brief The Reference Expression. */ + Expr ref; + /*! \brief The value to write into. */ + Expr value; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("ref", &ref); + v->Visit("value", &value); + v->Visit("span", &span); + v->Visit("_checked_type_", &checked_type_); + } + + TVM_DLL static RefWrite make(Expr ref, Expr value); + + static constexpr const char* _type_key = "relay.RefWrite"; + TVM_DECLARE_NODE_TYPE_INFO(RefWriteNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(RefWrite, RefWriteNode, Expr); + /*! * \brief Base class of the temporary expression. * diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 60b18218a3131..661ba3245e643 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -89,6 +89,9 @@ class ExprFunctor { virtual R VisitExpr_(const OpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const RefNewNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Node* op, Args...) { throw Error(std::string("Do not have a default for ") + op->type_key()); } @@ -108,6 +111,9 @@ class ExprFunctor { RELAY_EXPR_FUNCTOR_DISPATCH(IfNode); RELAY_EXPR_FUNCTOR_DISPATCH(OpNode); RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode); + RELAY_EXPR_FUNCTOR_DISPATCH(RefNewNode); + RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode); + RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode); return vtable; } }; @@ -133,6 +139,9 @@ class ExprVisitor void VisitExpr_(const IfNode* op) override; void VisitExpr_(const OpNode* op) override; void VisitExpr_(const TupleGetItemNode* op) override; + void VisitExpr_(const RefNewNode* op) override; + void VisitExpr_(const RefReadNode* op) override; + void VisitExpr_(const RefWriteNode* op) override; virtual void VisitType(const Type& t); protected: @@ -168,6 +177,9 @@ class ExprMutator Expr VisitExpr_(const LetNode* op) override; Expr VisitExpr_(const IfNode* op) override; Expr VisitExpr_(const TupleGetItemNode* op) override; + Expr VisitExpr_(const RefNewNode* op) override; + Expr VisitExpr_(const RefReadNode* op) override; + Expr VisitExpr_(const RefWriteNode* op) override; /*! * \brief Used to visit the types inside of expressions. * diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index 1099ef0f3cfdd..08aeef1827b68 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -140,6 +140,25 @@ struct TensorValueNode : ValueNode { RELAY_DEFINE_NODE_REF(TensorValue, TensorValueNode, Value); +/*! \brief A reference value. */ +class RefValue; + +struct RefValueNode : ValueNode { + mutable Value value; + + RefValueNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("value", &value); + } + + TVM_DLL static RefValue make(Value val); + + static constexpr const char* _type_key = "relay.RefValue"; + TVM_DECLARE_NODE_TYPE_INFO(RefValueNode, ValueNode); +}; + +RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value); } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 69a8a4fb0bd7d..211e46beaa4f5 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -262,6 +262,33 @@ class TupleTypeNode : public TypeNode { RELAY_DEFINE_NODE_REF(TupleType, TupleTypeNode, Type); +/*! + * \brief The type of reference values. + */ +class RefType; +/*! + * \brief Reference Type in relay. + */ +class RefTypeNode : public TypeNode { + public: + /*! \brief The type of value in the Reference. */ + Type value; + + RefTypeNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("value", &value); + v->Visit("span", &span); + } + + TVM_DLL static RefType make(Type value); + + static constexpr const char* _type_key = "relay.RefType"; + TVM_DECLARE_NODE_TYPE_INFO(RefTypeNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(RefType, RefTypeNode, Type); + class TypeReporter; /*! diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 44d819ea78a36..0981767d2fa63 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -43,6 +43,7 @@ TypeRelation = ty.TypeRelation IncompleteType = ty.IncompleteType scalar_type = ty.scalar_type +RefType = ty.RefType # Expr Expr = expr.Expr @@ -55,15 +56,18 @@ Let = expr.Let If = expr.If TupleGetItem = expr.TupleGetItem - -# ExprFunctor -ExprFunctor = expr_functor.ExprFunctor -ExprMutator = expr_functor.ExprMutator +RefNew = expr.RefNew +RefRead = expr.RefRead +RefWrite = expr.RefWrite # helper functions var = expr.var const = expr.const bind = expr.bind +# ExprFunctor +ExprFunctor = expr_functor.ExprFunctor +ExprMutator = expr_functor.ExprMutator + # Parser fromtext = parser.fromtext diff --git a/python/tvm/relay/backend/graph_runtime_codegen.py b/python/tvm/relay/backend/graph_runtime_codegen.py index 15e0a81226cb0..893cbcb361635 100644 --- a/python/tvm/relay/backend/graph_runtime_codegen.py +++ b/python/tvm/relay/backend/graph_runtime_codegen.py @@ -283,6 +283,15 @@ def visit_call(self, call): def visit_op(self, _): raise Exception("can not compile op in non-eta expanded form") + def visit_ref_new(self, _): + raise RuntimeError("reference not supported") + + def visit_ref_read(self, _): + raise RuntimeError("reference not supported") + + def visit_ref_write(self, _): + raise RuntimeError("reference not supported") + def _get_json(self): """ Convert the sequence of nodes stored by the compiler into the diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index 4a5ddcd8270c5..b21eab185c28c 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -45,6 +45,7 @@ def __repr__(self): def __iter__(self): return iter(self.fields) + @register_relay_node class Closure(Value): """A closure produced by the interpreter.""" @@ -79,6 +80,13 @@ def __str__(self): return str(self.data) +@register_relay_node +class RefValue(Value): + def __init__(self, value): + self.__init_handle_by_constructor__( + _make.RefValue, value) + + def _arg_to_ast(arg): if isinstance(arg, TensorValue): return Constant(arg.data.copyto(_nd.cpu(0))) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index f510d6195127e..a428f3a5a06a0 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -327,6 +327,46 @@ def __init__(self, tuple_value, index): _make.TupleGetItem, tuple_value, index) +@register_relay_node +class RefNew(Expr): + """Create a new reference from initial value. + Parameters + ---------- + value: tvm.relay.Expr + The initial value. + """ + def __init__(self, value): + self.__init_handle_by_constructor__(_make.RefNew, value) + + +@register_relay_node +class RefRead(Expr): + """Get the value inside the reference. + Parameters + ---------- + ref: tvm.relay.Expr + The reference. + """ + def __init__(self, ref): + self.__init_handle_by_constructor__(_make.RefRead, ref) + + +@register_relay_node +class RefWrite(Expr): + """ + Update the value inside the reference. + The whole expression will evaluate to an empty tuple. + Parameters + ---------- + ref: tvm.relay.Expr + The reference. + value: tvm.relay.Expr + The new value. + """ + def __init__(self, ref, value): + self.__init_handle_by_constructor__(_make.RefWrite, ref, value) + + class TempExpr(Expr): """Baseclass of all TempExpr. diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py index eafe5f09309ff..2d7b84610ea4e 100644 --- a/python/tvm/relay/expr_functor.py +++ b/python/tvm/relay/expr_functor.py @@ -41,6 +41,12 @@ def visit(self, expr): res = self.visit_constant(expr) elif isinstance(expr, Op): res = self.visit_op(expr) + elif isinstance(expr, RefNew): + res = self.visit_ref_new(expr) + elif isinstance(expr, RefRead): + res = self.visit_ref_read(expr) + elif isinstance(expr, RefWrite): + res = self.visit_ref_write(expr) else: raise Exception("warning unhandled case: {0}".format(type(expr))) @@ -81,6 +87,14 @@ def visit_op(self, _): def visit_constant(self, _): raise NotImplementedError() + def visit_ref_new(self, _): + raise NotImplementedError() + + def visit_ref_write(self, _): + raise NotImplementedError() + + def visit_ref_read(self, _): + raise NotImplementedError() class ExprMutator(ExprFunctor): """ diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py index 96dde5acb4dfe..bed293d1e3caf 100644 --- a/python/tvm/relay/ty.py +++ b/python/tvm/relay/ty.py @@ -210,6 +210,19 @@ def __init__(self, func, args, num_inputs, attrs): func, args, num_inputs, attrs) +@register_relay_node +class RefType(Type): + """Reference Type in relay. + + Parameters + ---------- + value: Type + The value type. + """ + def __init__(self, value): + self.__init_handle_by_constructor__(_make.RefType, value) + + def scalar_type(dtype): """Creates a scalar type. diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 734180c537596..23243bf0b6580 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -75,6 +75,23 @@ TVM_REGISTER_API("relay._make.TensorValue") *ret = TensorValueNode::make(data); }); +RefValue RefValueNode::make(Value value) { + NodePtr n = make_node(); + n->value = value; + return RefValue(n); +} + +TVM_REGISTER_API("relay._make.RefValue") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = RefValueNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const RefValueNode* node, + tvm::IRPrinter* p) { + p->stream << "RefValueNode(" << node->value << ")"; + }); + /*! * \brief A stack frame in the Relay interpreter. * @@ -432,6 +449,31 @@ class Interpreter : } } + Value VisitExpr_(const RefWriteNode* op) final { + Value r = Eval(op->ref); + if (const RefValueNode* rv = r.as()) { + rv->value = Eval(op->value); + return TupleValueNode::make({}); + } else { + LOG(FATAL) << "type error, type system should have caught this"; + return Value(); + } + } + + Value VisitExpr_(const RefNewNode* op) final { + return RefValueNode::make(Eval(op->value)); + } + + Value VisitExpr_(const RefReadNode* op) final { + Value r = Eval(op->ref); + if (const RefValueNode* rv = r.as()) { + return rv->value; + } else { + LOG(FATAL) << "type error, type system should have caught this"; + return Value(); + } + } + InterpreterState get_state(Expr e = Expr()) const { InterpreterStateNode::Stack stack; for (auto fr : this->stack_.frames) { diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index 064343c834ea0..e059ad32254f0 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -207,6 +207,14 @@ class AlphaEqualHandler: return false; } } + + bool VisitType_(const RefTypeNode* lhs, const Type& other) final { + if (const RefTypeNode* rhs = other.as()) { + return TypeEqual(lhs->value, rhs->value); + } + return false; + } + // Expr equal checking. bool NDArrayEqual(const runtime::NDArray& lhs, const runtime::NDArray& rhs) { @@ -361,6 +369,29 @@ class AlphaEqualHandler: } } + bool VisitExpr_(const RefNewNode* op, const Expr& e2) final { + if (const RefNewNode* nr = e2.as()) { + return ExprEqual(op->value, nr->value); + } else { + return false; + } + } + + bool VisitExpr_(const RefReadNode* op, const Expr& e2) final { + if (const RefReadNode* r = e2.as()) { + return ExprEqual(op->ref, r->ref); + } else { + return false; + } + } + + bool VisitExpr_(const RefWriteNode* op, const Expr& e2) final { + if (const RefWriteNode* r = e2.as()) { + return ExprEqual(op->ref, r->ref) && ExprEqual(op->value, r->value); + } else { + return false; + } + } private: // whether to map open terms. bool map_free_var_{false}; diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index cdb2a32a0009b..8a5115db6cb75 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -271,13 +271,58 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")"; }); +RefNew RefNewNode::make(Expr value) { + NodePtr n = make_node(); + n->value = std::move(value); + return RefNew(n); +} -TVM_REGISTER_API("relay._expr.TempExprRealize") +TVM_REGISTER_API("relay._make.RefNew").set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = RefNewNode::make(args[0]); +}); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const RefNewNode* node, tvm::IRPrinter* p) { + p->stream << "RefNewNode(" << node->value << ")"; +}); + +RefRead RefReadNode::make(Expr ref) { + NodePtr n = make_node(); + n->ref = std::move(ref); + return RefRead(n); +} + +TVM_REGISTER_API("relay._make.RefRead") .set_body([](TVMArgs args, TVMRetValue* ret) { - TempExpr temp = args[0]; - *ret = temp->Realize(); + *ret = RefReadNode::make(args[0]); +}); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const RefReadNode* node, tvm::IRPrinter* p) { + p->stream << "RefReadNode(" << node->ref << ")"; +}); + +RefWrite RefWriteNode::make(Expr ref, Expr value) { + NodePtr n = make_node(); + n->ref = std::move(ref); + n->value = std::move(value); + return RefWrite(n); +} + +TVM_REGISTER_API("relay._make.RefWrite").set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = RefWriteNode::make(args[0], args[1]); }); +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const RefWriteNode* node, tvm::IRPrinter* p) { + p->stream << "RefWriteNode(" << node->value << ")"; +}); + +TVM_REGISTER_API("relay._expr.TempExprRealize") +.set_body([](TVMArgs args, TVMRetValue* ret) { + TempExpr temp = args[0]; + *ret = temp->Realize(); +}); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index c1719e81a6c6c..83eb882b5847d 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -157,6 +157,34 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) { } } +Expr ExprMutator::VisitExpr_(const RefNewNode* op) { + Expr value = this->Mutate(op->value); + if (value.same_as(op->value)) { + return GetRef(op); + } else { + return RefNewNode::make(value); + } +} + +Expr ExprMutator::VisitExpr_(const RefReadNode* op) { + Expr ref = this->Mutate(op->ref); + if (ref.same_as(op->ref)) { + return GetRef(op); + } else { + return RefReadNode::make(ref); + } +} + +Expr ExprMutator::VisitExpr_(const RefWriteNode* op) { + Expr ref = this->Mutate(op->ref); + Expr value = this->Mutate(op->value); + if (ref.same_as(op->ref) && value.same_as(op->value)) { + return GetRef(op); + } else { + return RefWriteNode::make(ref, value); + } +} + Type ExprMutator::VisitType(const Type& t) { return t; } void ExprVisitor::VisitExpr(const Expr& expr) { @@ -226,6 +254,19 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { this->VisitExpr(op->tuple); } +void ExprVisitor::ExprVisitor::VisitExpr_(const RefNewNode* op) { + this->VisitExpr(op->value); +} + +void ExprVisitor::ExprVisitor::VisitExpr_(const RefReadNode* op) { + this->VisitExpr(op->ref); +} + +void ExprVisitor::ExprVisitor::VisitExpr_(const RefWriteNode* op) { + this->VisitExpr(op->ref); + this->VisitExpr(op->value); +} + void ExprVisitor::VisitType(const Type& t) { return; } diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index d7a8df98fa3fb..9fb94b8cc5dc9 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -16,9 +16,9 @@ namespace relay { // Hash handler for Relay. class RelayHashHandler: - public AttrsHashHandler, - public TypeFunctor, - public ExprFunctor { + public AttrsHashHandler, + public TypeFunctor, + public ExprFunctor { public: explicit RelayHashHandler() {} @@ -175,6 +175,12 @@ class RelayHashHandler: return hash; } + size_t VisitType_(const RefTypeNode* rtn) final { + size_t hash = std::hash()(RefTypeNode::_type_key); + hash = Combine(hash, TypeHash(rtn->value)); + return hash; + } + // Expr hashing. size_t NDArrayHash(const runtime::NDArray& array) { size_t hash = std::hash()(array->dtype.code); @@ -280,6 +286,24 @@ class RelayHashHandler: return hash; } + size_t VisitExpr_(const RefNewNode* rn) final { + size_t hash = std::hash()(RefNewNode::_type_key); + hash = Combine(hash, ExprHash(rn->value)); + return hash; + } + + size_t VisitExpr_(const RefReadNode* rn) final { + size_t hash = std::hash()(RefReadNode::_type_key); + hash = Combine(hash, ExprHash(rn->ref)); + return hash; + } + + size_t VisitExpr_(const RefWriteNode* rn) final { + size_t hash = std::hash()(RefWriteNode::_type_key); + hash = Combine(hash, ExprHash(rn->ref)); + hash = Combine(hash, ExprHash(rn->value)); + return hash; + } private: // renaming of NodeRef to indicate two nodes equals to each other std::unordered_map hash_map_; diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index 46b0d25b3d7de..fed86a51f4967 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -362,6 +362,34 @@ class TextPrinter : return id; } + TextValue VisitExpr_(const RefNewNode* op) final { + TextValue value = GetValue(op->value); + TextValue id = this->AllocTempVar(); + this->PrintIndent(); + stream_ << id << " = " << "RefNew(" << op->value << ")"; + this->PrintEndInst("\n"); + return id; + } + + TextValue VisitExpr_(const RefReadNode* op) final { + TextValue ref = GetValue(op->ref); + TextValue id = this->AllocTempVar(); + this->PrintIndent(); + stream_ << id << " = " << "RefRead(" << ref << ")"; + this->PrintEndInst("\n"); + return id; + } + + TextValue VisitExpr_(const RefWriteNode* op) final { + TextValue ref = GetValue(op->ref); + TextValue value = GetValue(op->value); + TextValue id = this->AllocTempVar(); + this->PrintIndent(); + stream_ << id << " = " << "RefWrite(" << ref << ", " << value << ")"; + this->PrintEndInst("\n"); + return id; + } + /*! * \brief Print the type to os * \param type The type to be printed. @@ -404,6 +432,10 @@ class TextPrinter : os << "]"; } + void VisitType_(const RefTypeNode* node, std::ostream& os) final { + VisitTypeDefault_(node, os); + } + void VisitTypeDefault_(const Node* node, std::ostream& os) final { // NOLINT(*) // by default always print as meta-data os << meta_.GetMetaNode(GetRef(node)); diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index bbe6472609dfd..e829d8abd63c2 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -164,5 +164,24 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "TupleTypeNode(" << node->fields << ")"; }); +RefType RefTypeNode::make(Type value) { + NodePtr n = make_node(); + n->value = std::move(value); + return RefType(n); +} + +TVM_REGISTER_API("relay._make.RefType") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = RefTypeNode::make(args[0]); +}); + +TVM_REGISTER_NODE_TYPE(RefTypeNode); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const RefTypeNode* node, + tvm::IRPrinter* p) { + p->stream << "RefTypeNode(" << node->value << ")"; +}); + } // namespace relay } // namespace tvm diff --git a/src/relay/ir/type_functor.cc b/src/relay/ir/type_functor.cc index 0ef1743cbbc4a..100c633a2997b 100644 --- a/src/relay/ir/type_functor.cc +++ b/src/relay/ir/type_functor.cc @@ -38,6 +38,10 @@ void TypeVisitor::VisitType_(const TupleTypeNode* op) { } } +void TypeVisitor::VisitType_(const RefTypeNode* op) { + this->VisitType(op->value); +} + void TypeVisitor::VisitType_(const TypeRelationNode* op) { for (const Type& t : op->args) { this->VisitType(t); @@ -119,6 +123,10 @@ Type TypeMutator::VisitType_(const TupleTypeNode* op) { } } +Type TypeMutator::VisitType_(const RefTypeNode* op) { + return RefTypeNode::make(this->VisitType(op->value)); +} + Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) { Array new_args = MutateArray(type_rel->args); if (new_args.same_as(type_rel->args)) { diff --git a/src/relay/ir/type_functor.h b/src/relay/ir/type_functor.h index e8dfd2b7cd7cd..1be55e78eee64 100644 --- a/src/relay/ir/type_functor.h +++ b/src/relay/ir/type_functor.h @@ -68,7 +68,7 @@ class TypeFunctor { virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; - + virtual R VisitType_(const RefTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitTypeDefault_(const Node* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->type_key(); throw; // unreachable, written to stop compiler warning @@ -86,6 +86,7 @@ class TypeFunctor { RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode); RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode); RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode); + RELAY_TYPE_FUNCTOR_DISPATCH(RefTypeNode); return vtable; } }; @@ -101,6 +102,7 @@ class TypeVisitor : public TypeFunctor { void VisitType_(const FuncTypeNode* op) override; void VisitType_(const TupleTypeNode* op) override; void VisitType_(const TypeRelationNode* op) override; + void VisitType_(const RefTypeNode* op) override; }; // Mutator that transform a type to another one. @@ -112,6 +114,7 @@ class TypeMutator : public TypeFunctor { Type VisitType_(const FuncTypeNode* op) override; Type VisitType_(const TupleTypeNode* op) override; Type VisitType_(const TypeRelationNode* type_rel) override; + Type VisitType_(const RefTypeNode* op) override; private: Array MutateArray(Array arr); diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index fad4fb781b5a8..2f96c4b52b428 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -161,6 +161,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { current->extern_ref = true; } } + void AddNode(const tvm::Node* key) { auto it = graph_.node_map.find(key); CHECK(it != graph_.node_map.end()) @@ -173,7 +174,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { } // Post order tree - void VisitExpr_(const FunctionNode* op) { + void VisitExpr_(const FunctionNode* op) final { for (auto param : op->params) { this->Update(param, nullptr, kOpaque); } @@ -181,7 +182,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { ExprVisitor::VisitExpr_(op); } - void VisitExpr_(const ConstantNode* op) { + void VisitExpr_(const ConstantNode* op) final { this->AddNode(op); Node* node = graph_.node_map.at(op); DataType dtype = TVMType2Type(op->data->dtype); @@ -201,7 +202,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { } } - void VisitExpr_(const CallNode* call) { + void VisitExpr_(const CallNode* call) final { CHECK(graph_.node_map.count(call)); Node* node = graph_.node_map.at(call); static auto fpattern = @@ -231,7 +232,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { this->AddNode(call); } - void VisitExpr_(const TupleNode* op) { + void VisitExpr_(const TupleNode* op) final { CHECK(graph_.node_map.count(op)); Node* tuple_node = graph_.node_map.at(op); tuple_node->pattern = kInjective; @@ -246,7 +247,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { this->AddNode(op); } - void VisitExpr_(const TupleGetItemNode* op) { + void VisitExpr_(const TupleGetItemNode* op) final { CHECK(graph_.node_map.count(op)); Node* node = graph_.node_map.at(op); this->Update(op->tuple, node, kOpaque); @@ -254,11 +255,11 @@ class IndexedForwardGraph::Creator : private ExprVisitor { this->AddNode(op); } - void VisitExpr_(const VarNode* op) { + void VisitExpr_(const VarNode* op) final { this->AddNode(op); } - void VisitExpr_(const LetNode* op) { + void VisitExpr_(const LetNode* op) final { // do not fuse through let. this->Update(op->var, nullptr, kOpaque); this->Update(op->value, nullptr, kOpaque); @@ -267,7 +268,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { this->AddNode(op); } - void VisitExpr_(const IfNode* op) { + void VisitExpr_(const IfNode* op) final { // do not fuse through if. this->Update(op->cond, nullptr, kOpaque); this->Update(op->true_branch, nullptr, kOpaque); @@ -275,6 +276,25 @@ class IndexedForwardGraph::Creator : private ExprVisitor { ExprVisitor::VisitExpr_(op); this->AddNode(op); } + + void VisitExpr_(const RefNewNode* op) final { + this->Update(op->value, nullptr, kOpaque); + ExprVisitor::VisitExpr_(op); + this->AddNode(op); + } + + void VisitExpr_(const RefReadNode* op) final { + this->Update(op->ref, nullptr, kOpaque); + ExprVisitor::VisitExpr_(op); + this->AddNode(op); + } + + void VisitExpr_(const RefWriteNode* op) final { + this->Update(op->ref, nullptr, kOpaque); + this->Update(op->value, nullptr, kOpaque); + ExprVisitor::VisitExpr_(op); + this->AddNode(op); + } }; IndexedForwardGraph IndexedForwardGraph::Create( diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index 7253a600dabfb..200f5385a37a7 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -82,6 +82,12 @@ struct KindChecker : TypeVisitor { valid = valid && IsTypeKind(op->ret_type); } + void VisitType_(const RefTypeNode* op) override { + // tuples should only contain normal types + this->VisitType(op->value); + valid = valid && IsTypeKind(op->value); + } + void VisitType_(const TypeRelationNode* op) override { // arguments to type relation should be normal types for (const Type& t : op->args) { diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index af4cc6607a44a..0edfa9a11add2 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -390,6 +390,33 @@ class TypeInferencer : private ExprFunctor { auto ret = FuncTypeNode::make(arg_types, rtype, f->type_params, {}); return solver_.Resolve(ret); } + + Type VisitExpr_(const RefNewNode* op) final { + return RefTypeNode::make(GetType(op->value)); + } + + Type VisitExpr_(const RefReadNode* op) final { + // TODO(M.K.) + // handle case where ref type is not known + Type ref_type = GetType(op->ref); + auto ref_ty_node = ref_type.as(); + if (!ref_ty_node) { + LOG(FATAL) << "only expressions with ref types is accepted" << GetRef(op); + } + return ref_ty_node->value; + } + + Type VisitExpr_(const RefWriteNode* op) final { + // TODO(M.K.) + // handle case where ref type is not known + Type ref_type = GetType(op->ref); + auto ref_ty_node = ref_type.as(); + if (!ref_ty_node) { + LOG(FATAL) << "only expressions with ref types is accepted" << GetRef(op); + } + this->Unify(ref_ty_node->value, GetType(op->value), op->span); + return TupleTypeNode::make({}); + } }; class TypeInferencer::Resolver : public ExprMutator { @@ -439,6 +466,18 @@ class TypeInferencer::Resolver : public ExprMutator { return AttachCheckedType(op); } + Expr VisitExpr_(const RefNewNode* op) final { + return AttachCheckedType(op); + } + + Expr VisitExpr_(const RefReadNode* op) final { + return AttachCheckedType(op); + } + + Expr VisitExpr_(const RefWriteNode* op) final { + return AttachCheckedType(op); + } + // attach checked type to the mutated node. template Expr AttachCheckedType(const T* op) { diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index caea3755b8f9e..54e0ddf9f03a5 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -109,7 +109,7 @@ class TypeSolver::Unifier : public TypeFunctor { } // default: unify only if alpha-equal - Type VisitTypeDefault_(const Node* op, const Type& tn) override { + Type VisitTypeDefault_(const Node* op, const Type& tn) final { NodeRef nr = GetRef(op); Type t1 = GetRef(nr.as_derived()); if (!AlphaEqual(t1, tn)) { @@ -118,7 +118,7 @@ class TypeSolver::Unifier : public TypeFunctor { return t1; } - Type VisitType_(const TupleTypeNode* op, const Type& tn) override { + Type VisitType_(const TupleTypeNode* op, const Type& tn) final { const auto* ttn = tn.as(); if (!ttn || op->fields.size() != ttn->fields.size()) { return Type(nullptr); @@ -135,7 +135,7 @@ class TypeSolver::Unifier : public TypeFunctor { return TupleTypeNode::make(new_fields); } - Type VisitType_(const FuncTypeNode* op, const Type& tn) override { + Type VisitType_(const FuncTypeNode* op, const Type& tn) final { const auto* ftn = tn.as(); if (!ftn || op->arg_types.size() != ftn->arg_types.size() @@ -174,6 +174,14 @@ class TypeSolver::Unifier : public TypeFunctor { return FuncTypeNode::make(arg_types, ret_type, ft1->type_params, type_constraints); } + Type VisitType_(const RefTypeNode* op, const Type& tn) final { + const auto* rtn = tn.as(); + if (!rtn) { + return Type(nullptr); + } + return RefTypeNode::make(Unify(op->value, rtn->value)); + } + private: TypeSolver* solver_; }; diff --git a/tests/python/relay/test_backend_interpreter.py b/tests/python/relay/test_backend_interpreter.py index 38c340db424df..37d78a372d058 100644 --- a/tests/python/relay/test_backend_interpreter.py +++ b/tests/python/relay/test_backend_interpreter.py @@ -110,6 +110,22 @@ def test_loop(): check_eval(sum_up, [i_data, accum_data], sum(range(1, 11)), mod=mod) +def test_ref(): + mod = relay.Module() + three_with_ref = relay.GlobalVar('three_with_ref') + i = relay.Var('i') + iv = relay.Var('iv') + u = relay.Var('u') + uv = relay.Var('uv') + body = relay.add(iv, uv) + body = relay.Let(uv, relay.RefRead(i), body) + body = relay.Let(u, relay.RefWrite(i, relay.const(2)), body) + body = relay.Let(iv, relay.RefRead(i), body) + body = relay.Let(i, relay.RefNew(relay.const(1)), body) + mod[three_with_ref] = relay.Function([], body) + check_eval(three_with_ref, [], 3, mod=mod) + + def test_binds(): x = relay.var("x") y = relay.add(x, x) @@ -118,6 +134,7 @@ def test_binds(): res = intrp.evaluate(y, binds={x: xx}).asnumpy() tvm.testing.assert_allclose(xx + xx, res) + def test_kwargs_params(): x = relay.var("x", shape=(1, 10)) y = relay.var("y", shape=(1, 10)) @@ -131,6 +148,7 @@ def test_kwargs_params(): res = intrp.evaluate(f)(x_data, **params).data tvm.testing.assert_allclose(res.asnumpy(), x_data + y_data + z_data) + if __name__ == "__main__": test_id() test_add_const() @@ -140,3 +158,4 @@ def test_kwargs_params(): test_loop() test_binds() test_kwargs_params() + test_ref() diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index ac4eb1b404dbc..d5c0d978f424a 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -131,6 +131,18 @@ def test_tuple(): relay.TupleType([tp, tp])) +def test_ref(): + x = relay.var("x", "float32") + y = relay.var("y", "float32") + r = relay.RefNew(x) + st = relay.scalar_type("float32") + assert relay.ir_pass.infer_type(r).checked_type == relay.RefType(st) + g = relay.RefRead(r) + assert relay.ir_pass.infer_type(g).checked_type == st + w = relay.RefWrite(r, y) + assert relay.ir_pass.infer_type(w).checked_type == relay.TupleType([]) + + def test_free_expr(): x = relay.var("x", "float32") y = relay.add(x, x) @@ -187,12 +199,9 @@ def test_equal(): test_decl() test_recursion() test_tuple() - test_generalized_tuple() test_incomplete_call() - test_generalized_call() - test_call_with_type_args() test_free_expr() test_type_args() - test_self_reference() test_global_var_recursion() test_equal() + test_ref()