diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 14b3cd91701c..b9a57c5d4618 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -428,12 +428,78 @@ class TupleGetItemNode : public ExprNode { TVM_DLL static TupleGetItem make(Expr tuple, int index); - static constexpr const char * _type_key = "relay.TupleGetItem"; + static constexpr const char* _type_key = "relay.TupleGetItem"; TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode); }; RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr); +/*! \brief Create a new Reference out of initial value. */ +class RefCreate; +class RefCreateNode : public ExprNode { + public: + /*! \brief The initial value of the Reference. */ + Expr value; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("value", &value); + v->Visit("span", &span); + v->Visit("_checked_type_", &checked_type_); + } + + TVM_DLL static RefCreate make(Expr value); + + static constexpr const char* _type_key = "relay.RefCreate"; + TVM_DECLARE_NODE_TYPE_INFO(RefCreateNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(RefCreate, RefCreateNode, Expr); + +/*! \brief Get value out of Reference. */ +class RefRead; +class RefReadNode : public ExprNode { + public: + /*! \brief The Reference Expression. */ + Expr ref; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("ref", &ref); + v->Visit("span", &span); + v->Visit("_checked_type_", &checked_type_); + } + + TVM_DLL static RefRead make(Expr ref); + + static constexpr const char* _type_key = "relay.RefRead"; + TVM_DECLARE_NODE_TYPE_INFO(RefReadNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(RefRead, RefReadNode, Expr); + +/*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */ +class RefWrite; +class RefWriteNode : public ExprNode { + public: + /*! \brief The Reference Expression. */ + Expr ref; + /*! \brief The value to write into. */ + Expr value; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("ref", &ref); + v->Visit("value", &value); + v->Visit("span", &span); + v->Visit("_checked_type_", &checked_type_); + } + + TVM_DLL static RefWrite make(Expr ref, Expr value); + + static constexpr const char* _type_key = "relay.RefWrite"; + TVM_DECLARE_NODE_TYPE_INFO(RefWriteNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(RefWrite, RefWriteNode, Expr); + /*! * \brief Base class of the temporary expression. * diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 60b18218a313..e7b66bc1bbde 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -89,6 +89,9 @@ class ExprFunctor { virtual R VisitExpr_(const OpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Node* op, Args...) { throw Error(std::string("Do not have a default for ") + op->type_key()); } @@ -108,6 +111,9 @@ class ExprFunctor { RELAY_EXPR_FUNCTOR_DISPATCH(IfNode); RELAY_EXPR_FUNCTOR_DISPATCH(OpNode); RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode); + RELAY_EXPR_FUNCTOR_DISPATCH(RefCreateNode); + RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode); + RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode); return vtable; } }; @@ -133,6 +139,9 @@ class ExprVisitor void VisitExpr_(const IfNode* op) override; void VisitExpr_(const OpNode* op) override; void VisitExpr_(const TupleGetItemNode* op) override; + void VisitExpr_(const RefCreateNode* op) override; + void VisitExpr_(const RefReadNode* op) override; + void VisitExpr_(const RefWriteNode* op) override; virtual void VisitType(const Type& t); protected: @@ -168,6 +177,9 @@ class ExprMutator Expr VisitExpr_(const LetNode* op) override; Expr VisitExpr_(const IfNode* op) override; Expr VisitExpr_(const TupleGetItemNode* op) override; + Expr VisitExpr_(const RefCreateNode* op) override; + Expr VisitExpr_(const RefReadNode* op) override; + Expr VisitExpr_(const RefWriteNode* op) override; /*! * \brief Used to visit the types inside of expressions. * diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index 1099ef0f3cfd..08aeef1827b6 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -140,6 +140,25 @@ struct TensorValueNode : ValueNode { RELAY_DEFINE_NODE_REF(TensorValue, TensorValueNode, Value); +/*! \brief A reference value. */ +class RefValue; + +struct RefValueNode : ValueNode { + mutable Value value; + + RefValueNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("value", &value); + } + + TVM_DLL static RefValue make(Value val); + + static constexpr const char* _type_key = "relay.RefValue"; + TVM_DECLARE_NODE_TYPE_INFO(RefValueNode, ValueNode); +}; + +RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value); } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index f3bcf2c0a1d9..0ee265e5f3b0 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -262,6 +262,33 @@ class TupleTypeNode : public TypeNode { RELAY_DEFINE_NODE_REF(TupleType, TupleTypeNode, Type); +/*! + * \brief The type of reference values. + */ +class RefType; +/*! + * \brief Reference Type in relay. + */ +class RefTypeNode : public TypeNode { + public: + /*! \brief The type of value in the Reference. */ + Type value; + + RefTypeNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("value", &value); + v->Visit("span", &span); + } + + TVM_DLL static RefType make(Type value); + + static constexpr const char* _type_key = "relay.RefType"; + TVM_DECLARE_NODE_TYPE_INFO(RefTypeNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(RefType, RefTypeNode, Type); + class TypeReporter; /*! diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index b9d4695b70f8..0af164bc7a73 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -44,6 +44,7 @@ TypeRelation = ty.TypeRelation IncompleteType = ty.IncompleteType scalar_type = ty.scalar_type +RefType = ty.RefType # Expr Expr = expr.Expr @@ -56,15 +57,18 @@ Let = expr.Let If = expr.If TupleGetItem = expr.TupleGetItem - -# ExprFunctor -ExprFunctor = expr_functor.ExprFunctor -ExprMutator = expr_functor.ExprMutator +RefCreate = expr.RefCreate +RefRead = expr.RefRead +RefWrite = expr.RefWrite # helper functions var = expr.var const = expr.const bind = expr.bind +# ExprFunctor +ExprFunctor = expr_functor.ExprFunctor +ExprMutator = expr_functor.ExprMutator + # Parser fromtext = parser.fromtext diff --git a/python/tvm/relay/backend/graph_runtime_codegen.py b/python/tvm/relay/backend/graph_runtime_codegen.py index 15e0a81226cb..cc510b2290cf 100644 --- a/python/tvm/relay/backend/graph_runtime_codegen.py +++ b/python/tvm/relay/backend/graph_runtime_codegen.py @@ -283,6 +283,15 @@ def visit_call(self, call): def visit_op(self, _): raise Exception("can not compile op in non-eta expanded form") + def visit_ref_create(self, _): + raise RuntimeError("reference not supported") + + def visit_ref_read(self, _): + raise RuntimeError("reference not supported") + + def visit_ref_write(self, _): + raise RuntimeError("reference not supported") + def _get_json(self): """ Convert the sequence of nodes stored by the compiler into the diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index 4a5ddcd8270c..b21eab185c28 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -45,6 +45,7 @@ def __repr__(self): def __iter__(self): return iter(self.fields) + @register_relay_node class Closure(Value): """A closure produced by the interpreter.""" @@ -79,6 +80,13 @@ def __str__(self): return str(self.data) +@register_relay_node +class RefValue(Value): + def __init__(self, value): + self.__init_handle_by_constructor__( + _make.RefValue, value) + + def _arg_to_ast(arg): if isinstance(arg, TensorValue): return Constant(arg.data.copyto(_nd.cpu(0))) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 38ab0064e671..71b89d0b4777 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -327,6 +327,46 @@ def __init__(self, tuple_value, index): _make.TupleGetItem, tuple_value, index) +@register_relay_node +class RefCreate(Expr): + """Create a new reference from initial value. + Parameters + ---------- + value: tvm.relay.Expr + The initial value. + """ + def __init__(self, value): + self.__init_handle_by_constructor__(_make.RefCreate, value) + + +@register_relay_node +class RefRead(Expr): + """Get the value inside the reference. + Parameters + ---------- + ref: tvm.relay.Expr + The reference. + """ + def __init__(self, ref): + self.__init_handle_by_constructor__(_make.RefRead, ref) + + +@register_relay_node +class RefWrite(Expr): + """ + Update the value inside the reference. + The whole expression will evaluate to an empty tuple. + Parameters + ---------- + ref: tvm.relay.Expr + The reference. + value: tvm.relay.Expr + The new value. + """ + def __init__(self, ref, value): + self.__init_handle_by_constructor__(_make.RefWrite, ref, value) + + class TempExpr(Expr): """Baseclass of all TempExpr. diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py index eafe5f09309f..b22a4e7562e2 100644 --- a/python/tvm/relay/expr_functor.py +++ b/python/tvm/relay/expr_functor.py @@ -41,6 +41,12 @@ def visit(self, expr): res = self.visit_constant(expr) elif isinstance(expr, Op): res = self.visit_op(expr) + elif isinstance(expr, RefCreate): + res = self.visit_ref_create(expr) + elif isinstance(expr, RefRead): + res = self.visit_ref_read(expr) + elif isinstance(expr, RefWrite): + res = self.visit_ref_write(expr) else: raise Exception("warning unhandled case: {0}".format(type(expr))) @@ -81,6 +87,14 @@ def visit_op(self, _): def visit_constant(self, _): raise NotImplementedError() + def visit_ref_create(self, _): + raise NotImplementedError() + + def visit_ref_write(self, _): + raise NotImplementedError() + + def visit_ref_read(self, _): + raise NotImplementedError() class ExprMutator(ExprFunctor): """ @@ -145,8 +159,8 @@ def visit_constructor(self, con): def visit_match(self, m): return Match(self.visit(m.data), [Clause(c.lhs, self.visit(c.rhs)) for c in m.pattern]) - def visit_ref_new(self, r): - return RefNew(self.visit(r.value)) + def visit_ref_create(self, r): + return RefCreate(self.visit(r.value)) def visit_ref_write(self, r): return RefWrite(self.visit(r.ref), self.visit(r.value)) diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py index 96dde5acb4df..bed293d1e3ca 100644 --- a/python/tvm/relay/ty.py +++ b/python/tvm/relay/ty.py @@ -210,6 +210,19 @@ def __init__(self, func, args, num_inputs, attrs): func, args, num_inputs, attrs) +@register_relay_node +class RefType(Type): + """Reference Type in relay. + + Parameters + ---------- + value: Type + The value type. + """ + def __init__(self, value): + self.__init_handle_by_constructor__(_make.RefType, value) + + def scalar_type(dtype): """Creates a scalar type. diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 396ff907951d..893e66b41b42 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -75,6 +75,23 @@ TVM_REGISTER_API("relay._make.TensorValue") *ret = TensorValueNode::make(data); }); +RefValue RefValueNode::make(Value value) { + NodePtr n = make_node(); + n->value = value; + return RefValue(n); +} + +TVM_REGISTER_API("relay._make.RefValue") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = RefValueNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const RefValueNode* node, + tvm::IRPrinter* p) { + p->stream << "RefValueNode(" << node->value << ")"; + }); + /*! * \brief A stack frame in the Relay interpreter. * @@ -432,6 +449,31 @@ class Interpreter : } } + Value VisitExpr_(const RefWriteNode* op) final { + Value r = Eval(op->ref); + if (const RefValueNode* rv = r.as()) { + rv->value = Eval(op->value); + return TupleValueNode::make({}); + } else { + LOG(FATAL) << "type error, type system should have caught this"; + return Value(); + } + } + + Value VisitExpr_(const RefCreateNode* op) final { + return RefValueNode::make(Eval(op->value)); + } + + Value VisitExpr_(const RefReadNode* op) final { + Value r = Eval(op->ref); + if (const RefValueNode* rv = r.as()) { + return rv->value; + } else { + LOG(FATAL) << "type error, type system should have caught this"; + return Value(); + } + } + InterpreterState get_state(Expr e = Expr()) const { InterpreterStateNode::Stack stack; for (auto fr : this->stack_.frames) { diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index 064343c834ea..d0cc004994d4 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -207,6 +207,14 @@ class AlphaEqualHandler: return false; } } + + bool VisitType_(const RefTypeNode* lhs, const Type& other) final { + if (const RefTypeNode* rhs = other.as()) { + return TypeEqual(lhs->value, rhs->value); + } + return false; + } + // Expr equal checking. bool NDArrayEqual(const runtime::NDArray& lhs, const runtime::NDArray& rhs) { @@ -361,6 +369,29 @@ class AlphaEqualHandler: } } + bool VisitExpr_(const RefCreateNode* op, const Expr& e2) final { + if (const RefCreateNode* nr = e2.as()) { + return ExprEqual(op->value, nr->value); + } else { + return false; + } + } + + bool VisitExpr_(const RefReadNode* op, const Expr& e2) final { + if (const RefReadNode* r = e2.as()) { + return ExprEqual(op->ref, r->ref); + } else { + return false; + } + } + + bool VisitExpr_(const RefWriteNode* op, const Expr& e2) final { + if (const RefWriteNode* r = e2.as()) { + return ExprEqual(op->ref, r->ref) && ExprEqual(op->value, r->value); + } else { + return false; + } + } private: // whether to map open terms. bool map_free_var_{false}; diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index cdb2a32a0009..bc6eee3ebc03 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -271,13 +271,59 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")"; }); +RefCreate RefCreateNode::make(Expr value) { + NodePtr n = make_node(); + n->value = std::move(value); + return RefCreate(n); +} -TVM_REGISTER_API("relay._expr.TempExprRealize") +TVM_REGISTER_API("relay._make.RefCreate").set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = RefCreateNode::make(args[0]); +}); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const RefCreateNode* node, tvm::IRPrinter* p) { + p->stream << "RefCreateNode(" << node->value << ")"; +}); + +RefRead RefReadNode::make(Expr ref) { + NodePtr n = make_node(); + n->ref = std::move(ref); + return RefRead(n); +} + +TVM_REGISTER_API("relay._make.RefRead") .set_body([](TVMArgs args, TVMRetValue* ret) { - TempExpr temp = args[0]; - *ret = temp->Realize(); + *ret = RefReadNode::make(args[0]); +}); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const RefReadNode* node, tvm::IRPrinter* p) { + p->stream << "RefReadNode(" << node->ref << ")"; }); +RefWrite RefWriteNode::make(Expr ref, Expr value) { + NodePtr n = make_node(); + n->ref = std::move(ref); + n->value = std::move(value); + return RefWrite(n); +} + +TVM_REGISTER_API("relay._make.RefWrite") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = RefWriteNode::make(args[0], args[1]); +}); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const RefWriteNode* node, tvm::IRPrinter* p) { + p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")"; +}); + +TVM_REGISTER_API("relay._expr.TempExprRealize") +.set_body([](TVMArgs args, TVMRetValue* ret) { + TempExpr temp = args[0]; + *ret = temp->Realize(); +}); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index e7b4a918c984..9bdfa00ce298 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -157,6 +157,34 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) { } } +Expr ExprMutator::VisitExpr_(const RefCreateNode* op) { + Expr value = this->Mutate(op->value); + if (value.same_as(op->value)) { + return GetRef(op); + } else { + return RefCreateNode::make(value); + } +} + +Expr ExprMutator::VisitExpr_(const RefReadNode* op) { + Expr ref = this->Mutate(op->ref); + if (ref.same_as(op->ref)) { + return GetRef(op); + } else { + return RefReadNode::make(ref); + } +} + +Expr ExprMutator::VisitExpr_(const RefWriteNode* op) { + Expr ref = this->Mutate(op->ref); + Expr value = this->Mutate(op->value); + if (ref.same_as(op->ref) && value.same_as(op->value)) { + return GetRef(op); + } else { + return RefWriteNode::make(ref, value); + } +} + Type ExprMutator::VisitType(const Type& t) { return t; } void ExprVisitor::VisitExpr(const Expr& expr) { @@ -226,6 +254,19 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { this->VisitExpr(op->tuple); } +void ExprVisitor::ExprVisitor::VisitExpr_(const RefCreateNode* op) { + this->VisitExpr(op->value); +} + +void ExprVisitor::ExprVisitor::VisitExpr_(const RefReadNode* op) { + this->VisitExpr(op->ref); +} + +void ExprVisitor::ExprVisitor::VisitExpr_(const RefWriteNode* op) { + this->VisitExpr(op->ref); + this->VisitExpr(op->value); +} + void ExprVisitor::VisitType(const Type& t) { return; } // visitor to implement apply diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index d7a8df98fa3f..d984bb051e43 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -16,9 +16,9 @@ namespace relay { // Hash handler for Relay. class RelayHashHandler: - public AttrsHashHandler, - public TypeFunctor, - public ExprFunctor { + public AttrsHashHandler, + public TypeFunctor, + public ExprFunctor { public: explicit RelayHashHandler() {} @@ -175,6 +175,12 @@ class RelayHashHandler: return hash; } + size_t VisitType_(const RefTypeNode* rtn) final { + size_t hash = std::hash()(RefTypeNode::_type_key); + hash = Combine(hash, TypeHash(rtn->value)); + return hash; + } + // Expr hashing. size_t NDArrayHash(const runtime::NDArray& array) { size_t hash = std::hash()(array->dtype.code); @@ -280,6 +286,24 @@ class RelayHashHandler: return hash; } + size_t VisitExpr_(const RefCreateNode* rn) final { + size_t hash = std::hash()(RefCreateNode::_type_key); + hash = Combine(hash, ExprHash(rn->value)); + return hash; + } + + size_t VisitExpr_(const RefReadNode* rn) final { + size_t hash = std::hash()(RefReadNode::_type_key); + hash = Combine(hash, ExprHash(rn->ref)); + return hash; + } + + size_t VisitExpr_(const RefWriteNode* rn) final { + size_t hash = std::hash()(RefWriteNode::_type_key); + hash = Combine(hash, ExprHash(rn->ref)); + hash = Combine(hash, ExprHash(rn->value)); + return hash; + } private: // renaming of NodeRef to indicate two nodes equals to each other std::unordered_map hash_map_; diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index 8f6629a14f92..05179d584d84 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -363,6 +363,34 @@ class TextPrinter : return id; } + TextValue VisitExpr_(const RefCreateNode* op) final { + TextValue value = GetValue(op->value); + TextValue id = this->AllocTempVar(); + this->PrintIndent(); + stream_ << id << " = " << "RefCreate(" << op->value << ")"; + this->PrintEndInst("\n"); + return id; + } + + TextValue VisitExpr_(const RefReadNode* op) final { + TextValue ref = GetValue(op->ref); + TextValue id = this->AllocTempVar(); + this->PrintIndent(); + stream_ << id << " = " << "RefRead(" << ref << ")"; + this->PrintEndInst("\n"); + return id; + } + + TextValue VisitExpr_(const RefWriteNode* op) final { + TextValue ref = GetValue(op->ref); + TextValue value = GetValue(op->value); + TextValue id = this->AllocTempVar(); + this->PrintIndent(); + stream_ << id << " = " << "RefWrite(" << ref << ", " << value << ")"; + this->PrintEndInst("\n"); + return id; + } + /*! * \brief Print the type to os * \param type The type to be printed. @@ -405,6 +433,10 @@ class TextPrinter : os << "]"; } + void VisitType_(const RefTypeNode* node, std::ostream& os) final { + VisitTypeDefault_(node, os); + } + void VisitTypeDefault_(const Node* node, std::ostream& os) final { // NOLINT(*) // by default always print as meta-data os << meta_.GetMetaNode(GetRef(node)); diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index bbe6472609df..e829d8abd63c 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -164,5 +164,24 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "TupleTypeNode(" << node->fields << ")"; }); +RefType RefTypeNode::make(Type value) { + NodePtr n = make_node(); + n->value = std::move(value); + return RefType(n); +} + +TVM_REGISTER_API("relay._make.RefType") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = RefTypeNode::make(args[0]); +}); + +TVM_REGISTER_NODE_TYPE(RefTypeNode); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const RefTypeNode* node, + tvm::IRPrinter* p) { + p->stream << "RefTypeNode(" << node->value << ")"; +}); + } // namespace relay } // namespace tvm diff --git a/src/relay/ir/type_functor.cc b/src/relay/ir/type_functor.cc index 0ef1743cbbc4..100c633a2997 100644 --- a/src/relay/ir/type_functor.cc +++ b/src/relay/ir/type_functor.cc @@ -38,6 +38,10 @@ void TypeVisitor::VisitType_(const TupleTypeNode* op) { } } +void TypeVisitor::VisitType_(const RefTypeNode* op) { + this->VisitType(op->value); +} + void TypeVisitor::VisitType_(const TypeRelationNode* op) { for (const Type& t : op->args) { this->VisitType(t); @@ -119,6 +123,10 @@ Type TypeMutator::VisitType_(const TupleTypeNode* op) { } } +Type TypeMutator::VisitType_(const RefTypeNode* op) { + return RefTypeNode::make(this->VisitType(op->value)); +} + Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) { Array new_args = MutateArray(type_rel->args); if (new_args.same_as(type_rel->args)) { diff --git a/src/relay/ir/type_functor.h b/src/relay/ir/type_functor.h index e8dfd2b7cd7c..1be55e78eee6 100644 --- a/src/relay/ir/type_functor.h +++ b/src/relay/ir/type_functor.h @@ -68,7 +68,7 @@ class TypeFunctor { virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; - + virtual R VisitType_(const RefTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitTypeDefault_(const Node* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->type_key(); throw; // unreachable, written to stop compiler warning @@ -86,6 +86,7 @@ class TypeFunctor { RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode); RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode); RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode); + RELAY_TYPE_FUNCTOR_DISPATCH(RefTypeNode); return vtable; } }; @@ -101,6 +102,7 @@ class TypeVisitor : public TypeFunctor { void VisitType_(const FuncTypeNode* op) override; void VisitType_(const TupleTypeNode* op) override; void VisitType_(const TypeRelationNode* op) override; + void VisitType_(const RefTypeNode* op) override; }; // Mutator that transform a type to another one. @@ -112,6 +114,7 @@ class TypeMutator : public TypeFunctor { Type VisitType_(const FuncTypeNode* op) override; Type VisitType_(const TupleTypeNode* op) override; Type VisitType_(const TypeRelationNode* type_rel) override; + Type VisitType_(const RefTypeNode* op) override; private: Array MutateArray(Array arr); diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 572c62cfab10..a6298ba448f3 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -162,6 +162,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { current->extern_ref = true; } } + void AddNode(const tvm::Node* key) { auto it = graph_.node_map.find(key); CHECK(it != graph_.node_map.end()) @@ -174,7 +175,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { } // Post order tree - void VisitExpr_(const FunctionNode* op) { + void VisitExpr_(const FunctionNode* op) final { for (auto param : op->params) { this->Update(param, nullptr, kOpaque); } @@ -182,7 +183,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { ExprVisitor::VisitExpr_(op); } - void VisitExpr_(const ConstantNode* op) { + void VisitExpr_(const ConstantNode* op) final { this->AddNode(op); Node* node = graph_.node_map.at(op); DataType dtype = TVMType2Type(op->data->dtype); @@ -202,7 +203,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { } } - void VisitExpr_(const CallNode* call) { + void VisitExpr_(const CallNode* call) final { CHECK(graph_.node_map.count(call)); Node* node = graph_.node_map.at(call); static auto fpattern = @@ -232,7 +233,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { this->AddNode(call); } - void VisitExpr_(const TupleNode* op) { + void VisitExpr_(const TupleNode* op) final { CHECK(graph_.node_map.count(op)); Node* tuple_node = graph_.node_map.at(op); tuple_node->pattern = kInjective; @@ -247,7 +248,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { this->AddNode(op); } - void VisitExpr_(const TupleGetItemNode* op) { + void VisitExpr_(const TupleGetItemNode* op) final { CHECK(graph_.node_map.count(op)); Node* node = graph_.node_map.at(op); this->Update(op->tuple, node, kOpaque); @@ -255,11 +256,11 @@ class IndexedForwardGraph::Creator : private ExprVisitor { this->AddNode(op); } - void VisitExpr_(const VarNode* op) { + void VisitExpr_(const VarNode* op) final { this->AddNode(op); } - void VisitExpr_(const LetNode* op) { + void VisitExpr_(const LetNode* op) final { // do not fuse through let. this->Update(op->var, nullptr, kOpaque); this->Update(op->value, nullptr, kOpaque); @@ -268,7 +269,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { this->AddNode(op); } - void VisitExpr_(const IfNode* op) { + void VisitExpr_(const IfNode* op) final { // do not fuse through if. this->Update(op->cond, nullptr, kOpaque); this->Update(op->true_branch, nullptr, kOpaque); @@ -276,6 +277,25 @@ class IndexedForwardGraph::Creator : private ExprVisitor { ExprVisitor::VisitExpr_(op); this->AddNode(op); } + + void VisitExpr_(const RefCreateNode* op) final { + this->Update(op->value, nullptr, kOpaque); + ExprVisitor::VisitExpr_(op); + this->AddNode(op); + } + + void VisitExpr_(const RefReadNode* op) final { + this->Update(op->ref, nullptr, kOpaque); + ExprVisitor::VisitExpr_(op); + this->AddNode(op); + } + + void VisitExpr_(const RefWriteNode* op) final { + this->Update(op->ref, nullptr, kOpaque); + this->Update(op->value, nullptr, kOpaque); + ExprVisitor::VisitExpr_(op); + this->AddNode(op); + } }; IndexedForwardGraph IndexedForwardGraph::Create( diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index 7253a600dabf..200f5385a37a 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -82,6 +82,12 @@ struct KindChecker : TypeVisitor { valid = valid && IsTypeKind(op->ret_type); } + void VisitType_(const RefTypeNode* op) override { + // tuples should only contain normal types + this->VisitType(op->value); + valid = valid && IsTypeKind(op->value); + } + void VisitType_(const TypeRelationNode* op) override { // arguments to type relation should be normal types for (const Type& t : op->args) { diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index b17c1c1f0439..10ba3b127bbf 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -431,6 +431,23 @@ class TypeInferencer : private ExprFunctor { auto ret = FuncTypeNode::make(arg_types, rtype, f->type_params, {}); return solver_.Resolve(ret); } + + Type VisitExpr_(const RefCreateNode* op) final { + return RefTypeNode::make(GetType(op->value)); + } + + Type VisitExpr_(const RefReadNode* op) final { + Type it = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + this->Unify(GetType(op->ref), RefTypeNode::make(it), GetRef(op)); + return it; + } + + Type VisitExpr_(const RefWriteNode* op) final { + Type it = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + this->Unify(GetType(op->ref), RefTypeNode::make(it), GetRef(op)); + this->Unify(GetType(op->value), it, GetRef(op)); + return TupleTypeNode::make({}); + } }; class TypeInferencer::Resolver : public ExprMutator { @@ -480,6 +497,18 @@ class TypeInferencer::Resolver : public ExprMutator { return AttachCheckedType(op); } + Expr VisitExpr_(const RefCreateNode* op) final { + return AttachCheckedType(op); + } + + Expr VisitExpr_(const RefReadNode* op) final { + return AttachCheckedType(op); + } + + Expr VisitExpr_(const RefWriteNode* op) final { + return AttachCheckedType(op); + } + // attach checked type to the mutated node. template Expr AttachCheckedType(const T* op) { diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 617aafdc712c..fcd39e791339 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -116,7 +116,7 @@ class TypeSolver::Unifier : public TypeFunctor { } // default: unify only if alpha-equal - Type VisitTypeDefault_(const Node* op, const Type& tn) override { + Type VisitTypeDefault_(const Node* op, const Type& tn) final { NodeRef nr = GetRef(op); Type t1 = GetRef(nr.as_derived()); if (!AlphaEqual(t1, tn)) { @@ -125,7 +125,7 @@ class TypeSolver::Unifier : public TypeFunctor { return t1; } - Type VisitType_(const TupleTypeNode* op, const Type& tn) override { + Type VisitType_(const TupleTypeNode* op, const Type& tn) final { const auto* ttn = tn.as(); if (!ttn || op->fields.size() != ttn->fields.size()) { return Type(nullptr); @@ -142,7 +142,7 @@ class TypeSolver::Unifier : public TypeFunctor { return TupleTypeNode::make(new_fields); } - Type VisitType_(const FuncTypeNode* op, const Type& tn) override { + Type VisitType_(const FuncTypeNode* op, const Type& tn) final { const auto* ftn = tn.as(); if (!ftn || op->arg_types.size() != ftn->arg_types.size() @@ -181,6 +181,14 @@ class TypeSolver::Unifier : public TypeFunctor { return FuncTypeNode::make(arg_types, ret_type, ft1->type_params, type_constraints); } + Type VisitType_(const RefTypeNode* op, const Type& tn) final { + const auto* rtn = tn.as(); + if (!rtn) { + return Type(nullptr); + } + return RefTypeNode::make(Unify(op->value, rtn->value)); + } + private: TypeSolver* solver_; }; diff --git a/tests/python/relay/test_backend_interpreter.py b/tests/python/relay/test_backend_interpreter.py index 38c340db424d..801b3068eff0 100644 --- a/tests/python/relay/test_backend_interpreter.py +++ b/tests/python/relay/test_backend_interpreter.py @@ -110,6 +110,22 @@ def test_loop(): check_eval(sum_up, [i_data, accum_data], sum(range(1, 11)), mod=mod) +def test_ref(): + mod = relay.Module() + three_with_ref = relay.GlobalVar('three_with_ref') + i = relay.Var('i') + iv = relay.Var('iv') + u = relay.Var('u') + uv = relay.Var('uv') + body = relay.add(iv, uv) + body = relay.Let(uv, relay.RefRead(i), body) + body = relay.Let(u, relay.RefWrite(i, relay.const(2)), body) + body = relay.Let(iv, relay.RefRead(i), body) + body = relay.Let(i, relay.RefCreate(relay.const(1)), body) + mod[three_with_ref] = relay.Function([], body) + check_eval(three_with_ref, [], 3, mod=mod) + + def test_binds(): x = relay.var("x") y = relay.add(x, x) @@ -118,6 +134,7 @@ def test_binds(): res = intrp.evaluate(y, binds={x: xx}).asnumpy() tvm.testing.assert_allclose(xx + xx, res) + def test_kwargs_params(): x = relay.var("x", shape=(1, 10)) y = relay.var("y", shape=(1, 10)) @@ -131,6 +148,7 @@ def test_kwargs_params(): res = intrp.evaluate(f)(x_data, **params).data tvm.testing.assert_allclose(res.asnumpy(), x_data + y_data + z_data) + if __name__ == "__main__": test_id() test_add_const() @@ -140,3 +158,4 @@ def test_kwargs_params(): test_loop() test_binds() test_kwargs_params() + test_ref() diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index ac4eb1b404db..eeefbc6c3051 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -131,6 +131,18 @@ def test_tuple(): relay.TupleType([tp, tp])) +def test_ref(): + x = relay.var("x", "float32") + y = relay.var("y", "float32") + r = relay.RefCreate(x) + st = relay.scalar_type("float32") + assert relay.ir_pass.infer_type(r).checked_type == relay.RefType(st) + g = relay.RefRead(r) + assert relay.ir_pass.infer_type(g).checked_type == st + w = relay.RefWrite(r, y) + assert relay.ir_pass.infer_type(w).checked_type == relay.TupleType([]) + + def test_free_expr(): x = relay.var("x", "float32") y = relay.add(x, x) @@ -187,12 +199,9 @@ def test_equal(): test_decl() test_recursion() test_tuple() - test_generalized_tuple() test_incomplete_call() - test_generalized_call() - test_call_with_type_args() test_free_expr() test_type_args() - test_self_reference() test_global_var_recursion() test_equal() + test_ref()