diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h new file mode 100644 index 0000000000000..c0fcc7629c517 --- /dev/null +++ b/include/tvm/relay/adt.h @@ -0,0 +1,228 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/adt.h + * \brief Algebraic data types for Relay + */ +#ifndef TVM_RELAY_ADT_H_ +#define TVM_RELAY_ADT_H_ + +#include +#include +#include +#include "./base.h" +#include "./type.h" +#include "./expr.h" + +namespace tvm { +namespace relay { + +/*! \brief Base type for declaring relay pattern. */ +class PatternNode : public RelayNode { + public: + static constexpr const char* _type_key = "relay.Pattern"; + TVM_DECLARE_BASE_NODE_INFO(PatternNode, Node); +}; + +/*! + * \brief Pattern is the base type for an ADT match pattern in Relay. + * + * Given an ADT value, a pattern might accept it and bind the pattern variable to some value + * (typically a subnode of the input or the input). Otherwise, the pattern rejects the value. + * + * ADT pattern matching thus takes a list of values and bings to the first that accepts the value. + */ +class Pattern : public NodeRef { + public: + Pattern() {} + explicit Pattern(NodePtr p) : NodeRef(p) {} + + using ContainerType = PatternNode; +}; + +/*! \brief A wildcard pattern: Accepts all input and binds nothing. */ +class PatternWildcard; +/*! \brief PatternWildcard container node */ +class PatternWildcardNode : public PatternNode { + public: + PatternWildcardNode() {} + + TVM_DLL static PatternWildcard make(); + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("span", &span); + } + + static constexpr const char* _type_key = "relay.PatternWildcard"; + TVM_DECLARE_NODE_TYPE_INFO(PatternWildcardNode, PatternNode); +}; + +RELAY_DEFINE_NODE_REF(PatternWildcard, PatternWildcardNode, Pattern); + +/*! \brief A var pattern. Accept all input and bind to a var. */ +class PatternVar; +/*! \brief PatternVar container node */ +class PatternVarNode : public PatternNode { + public: + PatternVarNode() {} + + tvm::relay::Var var; + + TVM_DLL static PatternVar make(tvm::relay::Var var); + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("var", &var); + v->Visit("span", &span); + } + + static constexpr const char* _type_key = "relay.PatternVar"; + TVM_DECLARE_NODE_TYPE_INFO(PatternVarNode, PatternNode); +}; + +RELAY_DEFINE_NODE_REF(PatternVar, PatternVarNode, Pattern); + +/*! + * \brief ADT constructor. + * Constructors compare by pointer equality. + */ +class Constructor; +/*! \brief Constructor container node. */ +class ConstructorNode : public ExprNode { + public: + /*! \brief The name (only a hint) */ + std::string name_hint; + /*! \brief Input to the constructor. */ + tvm::Array inp; + /*! \brief The datatype the constructor will construct. */ + GlobalTypeVar belong_to; + mutable int tag = -1; + + ConstructorNode() {} + + TVM_DLL static Constructor make(std::string name_hint, + tvm::Array inp, + GlobalTypeVar belong_to); + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("name_hint", &name_hint); + v->Visit("inp", &inp); + v->Visit("belong_to", &belong_to); + v->Visit("tag", &tag); + v->Visit("span", &span); + v->Visit("_checked_type_", &checked_type_); + } + + static constexpr const char* _type_key = "relay.Constructor"; + TVM_DECLARE_NODE_TYPE_INFO(ConstructorNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(Constructor, ConstructorNode, Expr); + +/*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */ +class PatternConstructor; +/*! \brief PatternVar container node */ +class PatternConstructorNode : public PatternNode { + public: + Constructor con; + tvm::Array pat; + + PatternConstructorNode() {} + + TVM_DLL static PatternConstructor make(Constructor con, tvm::Array var); + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("con", &con); + v->Visit("pat", &pat); + v->Visit("span", &span); + } + + static constexpr const char* _type_key = "relay.PatternConstructor"; + TVM_DECLARE_NODE_TYPE_INFO(PatternConstructorNode, PatternNode); +}; + +RELAY_DEFINE_NODE_REF(PatternConstructor, PatternConstructorNode, Pattern); + +/*! + * \brief Stores all data for an Algebraic Data Type (ADT). + */ +class TypeData; +/*! \brief TypeData container node */ +class TypeDataNode : public TypeNode { + public: + /*! + * \brief The header is simply the name of the ADT. + * We adopt nominal typing for ADT definitions; + * that is, differently-named ADT definitions with same constructors + * have different types. + */ + GlobalTypeVar header; + /*! \brief The type variables (to allow for polymorphism). */ + tvm::Array tv; + /*! \brief The constructors. */ + tvm::Array constructors; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("header", &header); + v->Visit("tv", &tv); + v->Visit("constructors", &constructors); + v->Visit("span", &span); + } + + TVM_DLL static TypeData make(GlobalTypeVar header, + tvm::Array tv, + tvm::Array constructors); + + static constexpr const char* _type_key = "relay.TypeData"; + TVM_DECLARE_NODE_TYPE_INFO(TypeDataNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(TypeData, TypeDataNode, Type); + +/*! \brief A clause in a match expression. */ +class Clause; +/*! \brief Clause container node. */ +class ClauseNode : public Node { + public: + Pattern lhs; + Expr rhs; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("lhs", &lhs); + v->Visit("rhs", &rhs); + } + + TVM_DLL static Clause make(Pattern lhs, Expr rhs); + + static constexpr const char* _type_key = "relay.Clause"; + TVM_DECLARE_NODE_TYPE_INFO(ClauseNode, Node); +}; + +RELAY_DEFINE_NODE_REF(Clause, ClauseNode, NodeRef); + +/*! \brief ADT pattern matching exression. */ +class Match; +/*! \brief Match container node. */ +class MatchNode : public ExprNode { + public: + Expr data; + + tvm::Array pattern; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("data", &data); + v->Visit("pattern", &pattern); + v->Visit("span", &span); + v->Visit("_checked_type_", &checked_type_); + } + + TVM_DLL static Match make(Expr data, tvm::Array pattern); + + static constexpr const char* _type_key = "relay.Match"; + TVM_DECLARE_NODE_TYPE_INFO(MatchNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(Match, MatchNode, Expr); + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_ADT_H_ diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 60b18218a3131..bf7f04e1453d7 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -10,6 +10,7 @@ #include #include #include "./expr.h" +#include "./adt.h" #include "./op.h" #include "./error.h" @@ -89,6 +90,8 @@ class ExprFunctor { virtual R VisitExpr_(const OpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Node* op, Args...) { throw Error(std::string("Do not have a default for ") + op->type_key()); } @@ -108,6 +111,8 @@ class ExprFunctor { RELAY_EXPR_FUNCTOR_DISPATCH(IfNode); RELAY_EXPR_FUNCTOR_DISPATCH(OpNode); RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode); + RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode); + RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode); return vtable; } }; @@ -133,7 +138,11 @@ class ExprVisitor void VisitExpr_(const IfNode* op) override; void VisitExpr_(const OpNode* op) override; void VisitExpr_(const TupleGetItemNode* op) override; + void VisitExpr_(const ConstructorNode* op) override; + void VisitExpr_(const MatchNode* op) override; virtual void VisitType(const Type& t); + virtual void VisitClause(const Clause& c); + virtual void VisitPattern(const Pattern& c); protected: // Internal visiting counter @@ -168,6 +177,9 @@ class ExprMutator Expr VisitExpr_(const LetNode* op) override; Expr VisitExpr_(const IfNode* op) override; Expr VisitExpr_(const TupleGetItemNode* op) override; + Expr VisitExpr_(const ConstructorNode* op) override; + Expr VisitExpr_(const MatchNode* op) override; + /*! * \brief Used to visit the types inside of expressions. * @@ -176,6 +188,8 @@ class ExprMutator * visitor for types which transform them appropriately. */ virtual Type VisitType(const Type& t); + virtual Clause VisitClause(const Clause& c); + virtual Pattern VisitPattern(const Pattern& c); protected: /*! \brief Internal map used for memoization. */ diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index 1099ef0f3cfdd..bcfaad86385ea 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -140,6 +140,28 @@ struct TensorValueNode : ValueNode { RELAY_DEFINE_NODE_REF(TensorValue, TensorValueNode, Value); +/*! \brief An ADT constructor value. */ +class ConValue; + +struct ConValueNode : ValueNode { + Constructor con; + + tvm::Array fields; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("con", &con); + v->Visit("fields", &fields); + } + + TVM_DLL static ConValue make(Constructor con, + tvm::Array fields); + + static constexpr const char* _type_key = "relay.ConValue"; + TVM_DECLARE_NODE_TYPE_INFO(ConValueNode, ValueNode); +}; + +RELAY_DEFINE_NODE_REF(ConValue, ConValueNode, Value); + } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 8d302c09d959c..b9bd8782f70e1 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -35,32 +36,43 @@ struct Module; * * The functional style allows users to construct custom * environments easily, for example each thread can store - * an Module while auto-tuning. + * a Module while auto-tuning. * */ class ModuleNode : public RelayNode { public: /*! \brief A map from ids to all global functions. */ tvm::Map functions; + tvm::Map type_definitions; ModuleNode() {} void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("functions", &functions); + v->Visit("type_definitions", &type_definitions); v->Visit("global_var_map_", &global_var_map_); + v->Visit("global_type_var_map_", &global_type_var_map_); } - TVM_DLL static Module make(tvm::Map global_funcs); + TVM_DLL static Module make(tvm::Map global_funcs, + tvm::Map global_type_defs); /*! * \brief Add a function to the global environment. - * \param var The name of the global function. + * \param var The var of the global function. * \param func The function. * \param update Controls whether you can replace a definition in the * environment. */ void Add(const GlobalVar& var, const Function& func, bool update = false); + /*! + * \brief Add a type-level definition to the global environment. + * \param var The var of the global type definition. + * \param type The type definition. + */ + void AddDef(const GlobalTypeVar& var, const TypeData& type); + /*! * \brief Add a function to the global environment. * \param var The name of the global function. @@ -90,6 +102,13 @@ class ModuleNode : public RelayNode { */ GlobalVar GetGlobalVar(const std::string& str); + /*! + * \brief Look up a global function by its name. + * \param str The unique string specifying the global variable. + * \returns The global variable. + */ + GlobalTypeVar GetGlobalTypeVar(const std::string& str); + /*! * \brief Lookup a global function by its variable. * \param var The global var to lookup. @@ -104,6 +123,20 @@ class ModuleNode : public RelayNode { */ Function Lookup(const std::string& name); + /*! + * \brief Lookup a global type definition by its variable. + * \param var The var of the global type definition. + * \return The type definition. + */ + TypeData LookupDef(const GlobalTypeVar& var); + + /*! + * \brief Lookup a global type definition by its name. + * \param var The name of the global type definition. + * \return The type definition. + */ + TypeData LookupDef(const std::string& var); + /*! * \brief Update the functions inside this environment by * functions in another environment. @@ -119,6 +152,7 @@ class ModuleNode : public RelayNode { * ensures global uniqueness. */ tvm::Map global_var_map_; + tvm::Map global_type_var_map_; }; struct Module : public NodeRef { diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index d3c5edd314615..c2f93bcb52c40 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -422,7 +422,7 @@ inline OpRegistry& OpRegistry::add_type_rel( std::string input_name_prefix = "in"; for (int i = 0; i < get()->num_inputs; i++) { auto name = input_name_prefix + std::to_string(i); - auto param = TypeVarNode::make(name, TypeVarNode::Kind::kType); + auto param = TypeVarNode::make(name, Kind::kType); type_params.push_back(param); arg_types.push_back(param); } @@ -430,7 +430,7 @@ inline OpRegistry& OpRegistry::add_type_rel( Array ty_call_args = arg_types; // Add output type. - auto out_param = TypeVarNode::make("out", TypeVarNode::Kind::kType); + auto out_param = TypeVarNode::make("out", Kind::kType); type_params.push_back(out_param); // this will trigger copy on write. ty_call_args.push_back(out_param); diff --git a/include/tvm/relay/pattern_functor.h b/include/tvm/relay/pattern_functor.h new file mode 100644 index 0000000000000..f9833201ea0a4 --- /dev/null +++ b/include/tvm/relay/pattern_functor.h @@ -0,0 +1,143 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/pattern_functor.h + * \brief A more powerful visitor on ADT patterns that enables defining + * arbitrary function signatures with type-based dispatch on first argument. + */ +#ifndef TVM_RELAY_PATTERN_FUNCTOR_H_ +#define TVM_RELAY_PATTERN_FUNCTOR_H_ + +#include +#include +#include "./expr.h" +#include "./op.h" +#include "./error.h" +#include "./adt.h" + +namespace tvm { +namespace relay { + +/*! + * \brief A dynamical functor on ADT patterns that dispatches on its first argument. + * You can use this as a more powerful visitor, since it allows you to + * define the types of further arguments to VisitPattern. + * + * \sa tvm/ir_functor.h + * + * \tparam FType function signiture + * This type is only defined for FType with function signature R(const Pattern&, + * Args...) + */ +template +class PatternFunctor; + +// functions to be overriden. +#define PATTERN_FUNCTOR_DEFAULT \ + { return VisitPatternDefault_(op, std::forward(args)...); } + +#define RELAY_PATTERN_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch( \ + [](const NodeRef& n, TSelf* self, Args... args) { \ + return self->VisitPattern_(static_cast(n.node_.get()), \ + std::forward(args)...); \ + }); + +template +class PatternFunctor { + private: + using TSelf = PatternFunctor; + using FType = tvm::IRFunctor; + + public: + /*! \brief the result type of this functor */ + using result_type = R; + /*! \brief virtual destructor */ + virtual ~PatternFunctor() {} + /*! + * \brief Same as call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + R operator()(const Pattern& n, Args... args) { + return VisitPattern(n, std::forward(args)...); + } + /*! + * \brief The functor call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + virtual R VisitPattern(const Pattern& n, Args... args) { + static FType vtable = InitVTable(); + return vtable(n, this, std::forward(args)...); + } + // Functions that can be overriden by subclass + virtual R VisitPattern_(const PatternWildcardNode* op, + Args... args) PATTERN_FUNCTOR_DEFAULT; + virtual R VisitPattern_(const PatternVarNode* op, + Args... args) PATTERN_FUNCTOR_DEFAULT; + virtual R VisitPattern_(const PatternConstructorNode* op, + Args... args) PATTERN_FUNCTOR_DEFAULT; + virtual R VisitPatternDefault_(const Node* op, Args...) { + throw Error(std::string("Do not have a default for ") + op->type_key()); + } + + private: + // initialize the vtable. + static FType InitVTable() { + FType vtable; + // Set dispatch + RELAY_PATTERN_FUNCTOR_DISPATCH(PatternWildcardNode); + RELAY_PATTERN_FUNCTOR_DISPATCH(PatternVarNode); + RELAY_PATTERN_FUNCTOR_DISPATCH(PatternConstructorNode); + return vtable; + } +}; + +/*! \brief A simple visitor wrapper around PatternFunctor. + * + * Exposes two visitors with default traversal strategies, one + * which doesn't compute a result but can mutate internal state, + * and another which functionally builds a new pattern. + */ +class PatternVisitor : public ::tvm::relay::PatternFunctor { + public: + void VisitPattern_(const PatternWildcardNode* op) override; + void VisitPattern_(const PatternVarNode* op) override; + void VisitPattern_(const PatternConstructorNode* op) override; + virtual void VisitType(const Type& t); + virtual void VisitVar(const Var& v); + virtual void VisitConstructor(const Constructor& c); +}; + +/*! \brief A wrapper around ExprFunctor which functionally updates the AST. + * + * ExprMutator uses memoization and self return in order to amortize + * the cost of using functional updates. + */ +class PatternMutator + : public ::tvm::relay::PatternFunctor { + public: + Pattern Mutate(const Pattern& pat); + Pattern VisitPattern_(const PatternWildcardNode* op) override; + Pattern VisitPattern_(const PatternVarNode* op) override; + Pattern VisitPattern_(const PatternConstructorNode* op) override; + /*! \brief Used to visit the types inside of patterns. + * + * Can be overloaded to transform the types in arbitrary + * ways, one way would be to define a sub-class of type + * visitor for types which transform them appropriately. + */ + virtual Type VisitType(const Type& t); + /*! \brief Used to visit the vars inside of patterns. */ + virtual Var VisitVar(const Var& v); + /*! \brief Used to visit the vars inside of patterns. */ + virtual Constructor VisitConstructor(const Constructor& c); + private: + std::unordered_map var_map_; +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_PATTERN_FUNCTOR_H_ diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 69a8a4fb0bd7d..a00356e1f222d 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -98,6 +98,15 @@ class TensorTypeNode : public BaseTensorTypeNode { RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type); +/*! \brief possible kinds of Type */ +enum Kind : int { + /*! \brief template variable in shape expression */ + kType = 0, + kShapeVar = 1, + kBaseType = 2, + kShape = 3 +}; + /*! * \brief Type parameter in the function. * This can be viewed as template parameter in c++ template function. @@ -119,14 +128,6 @@ class TypeVar; /*! \brief TypeVar container node */ class TypeVarNode : public TypeNode { public: - /*! \brief possible kinds of TypeVar */ - enum Kind : int { - /*! \brief template variable in shape expression */ - kType = 0, - kShapeVar = 1, - kBaseType = 2, - kShape = 3 - }; /*! * \brief The variable itself is only meaningful when * kind is ShapeVar, otherwise, we only use the name. @@ -149,6 +150,63 @@ class TypeVarNode : public TypeNode { RELAY_DEFINE_NODE_REF(TypeVar, TypeVarNode, Type); +/*! + * \brief A global type variable that is used for defining new types or type aliases. + */ +class GlobalTypeVar; +/*! \brief GlobalTypeVar container node */ +class GlobalTypeVarNode : public TypeNode { + public: + /*! + * \brief The variable itself is only meaningful when + * kind is ShapeVar; otherwise, we only use the name. + */ + tvm::Var var; + /*! \brief The kind of type parameter */ + Kind kind; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("var", &var); + v->Visit("kind", &kind); + v->Visit("span", &span); + } + + TVM_DLL static GlobalTypeVar make(std::string name, Kind kind); + + static constexpr const char* _type_key = "relay.GlobalTypeVar"; + TVM_DECLARE_NODE_TYPE_INFO(GlobalTypeVarNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(GlobalTypeVar, GlobalTypeVarNode, Type); + +/*! + * \brief Type application. + */ +class TypeCall; +/*! \brief TypeCall container node */ +class TypeCallNode : public TypeNode { + public: + /*! + * \brief The type-level function. + */ + Type func; + /*! \brief The arguments. */ + tvm::Array args; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("func", &func); + v->Visit("args", &args); + v->Visit("span", &span); + } + + TVM_DLL static TypeCall make(Type func, tvm::Array args); + + static constexpr const char* _type_key = "relay.TypeCall"; + TVM_DECLARE_NODE_TYPE_INFO(TypeCallNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(TypeCall, TypeCallNode, Type); + /*! * \brief IncompleteType. * This is intermediate values that is used during type inference. @@ -162,14 +220,14 @@ class IncompleteType; /*! \brief IncompleteType container node */ class IncompleteTypeNode : public TypeNode { public: - TypeVarNode::Kind kind; + Kind kind; void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("kind", &kind); v->Visit("span", &span); } - TVM_DLL static IncompleteType make(TypeVarNode::Kind kind); + TVM_DLL static IncompleteType make(Kind kind); static constexpr const char* _type_key = "relay.IncompleteType"; TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode); diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 44d819ea78a36..f1592d18df827 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -7,8 +7,10 @@ from . import expr from . import expr_functor from . import module +from . import adt from . import ir_pass from .build_module import build, build_config, create_executor +from . import prelude from . import parser from . import debug @@ -43,6 +45,8 @@ TypeRelation = ty.TypeRelation IncompleteType = ty.IncompleteType scalar_type = ty.scalar_type +GlobalTypeVar = ty.GlobalTypeVar +TypeCall = ty.TypeCall # Expr Expr = expr.Expr @@ -60,6 +64,15 @@ ExprFunctor = expr_functor.ExprFunctor ExprMutator = expr_functor.ExprMutator +# ADT +PatternWildcard = adt.PatternWildcard +PatternVar = adt.PatternVar +PatternConstructor = adt.PatternConstructor +Constructor = adt.Constructor +TypeData = adt.TypeData +Clause = adt.Clause +Match = adt.Match + # helper functions var = expr.var const = expr.const diff --git a/python/tvm/relay/adt.py b/python/tvm/relay/adt.py new file mode 100644 index 0000000000000..7e25ec16fb522 --- /dev/null +++ b/python/tvm/relay/adt.py @@ -0,0 +1,166 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +"""Algebraic data types in Relay.""" +from .base import RelayNode, register_relay_node, NodeBase +from . import _make +from .ty import Type +from .expr import Expr + + +class Pattern(RelayNode): + """Base type for pattern matching constructs.""" + pass + +@register_relay_node +class PatternWildcard(Pattern): + """Wildcard pattern in Relay: Matches any ADT and binds nothing.""" + + def __init__(self): + """Constructs a wildcard pattern. + + Parameters + ---------- + None + + Returns + ------- + wildcard: PatternWildcard + a wildcard pattern. + """ + self.__init_handle_by_constructor__(_make.PatternWildcard) + + +@register_relay_node +class PatternVar(Pattern): + """Variable pattern in Relay: Matches anything and binds it to the variable.""" + + def __init__(self, var): + """Construct a variable pattern. + + Parameters + ---------- + var: tvm.relay.Var + + Returns + ------- + pv: PatternVar + A variable pattern. + """ + self.__init_handle_by_constructor__(_make.PatternVar, var) + + +@register_relay_node +class PatternConstructor(Pattern): + """Constructor pattern in Relay: Matches an ADT of the given constructor, binds recursively.""" + + def __init__(self, con, pat=None): + """Construct a constructor pattern. + + Parameters + ---------- + con: Constructor + The constructor. + pat: Optional[List[Pattern]] + Optional subpatterns: for each field of the constructor, + match to the given subpattern (treated as a variable pattern by default). + + Returns + ------- + wildcard: PatternWildcard + a wildcard pattern. + """ + if pat is None: + pat = [] + self.__init_handle_by_constructor__(_make.PatternConstructor, con, pat) + + +@register_relay_node +class Constructor(Expr): + """Relay ADT constructor.""" + + def __init__(self, name_hint, inp, belong_to): + """Defines an ADT constructor. + + Parameters + ---------- + name_hint : str + Name of constructor (only a hint). + inp : List[Type] + Input types. + belong_to : tvm.relay.GlobalTypeVar + Denotes which ADT the constructor belongs to. + + Returns + ------- + con: Constructor + A constructor. + """ + self.__init_handle_by_constructor__(_make.Constructor, name_hint, inp, belong_to) + + +@register_relay_node +class TypeData(Type): + """Stores the definition for an Algebraic Data Type (ADT) in Relay.""" + + def __init__(self, header, tv, constructors): + """Defines a TypeData object. + + Parameters + ---------- + header: tvm.relay.GlobalTypeVar + The name of the ADT. + ADTs with the same constructors but different names are + treated as different types. + tv: List[TypeVar] + Type variables that appear in constructors. + constructors: List[tvm.relay.Constructor] + The constructors for the ADT. + + Returns + ------- + type_data: TypeData + The adt declaration. + """ + self.__init_handle_by_constructor__(_make.TypeData, header, tv, constructors) + + +@register_relay_node +class Clause(NodeBase): + """Clause for pattern matching in Relay.""" + + def __init__(self, lhs, rhs): + """Construct a clause. + + Parameters + ---------- + lhs: tvm.relay.Pattern + Left-hand side of match clause. + rhs: tvm.relay.Expr + Right-hand side of match clause. + + Returns + ------- + clause: Clause + The Clause. + """ + self.__init_handle_by_constructor__(_make.Clause, lhs, rhs) + + +@register_relay_node +class Match(Expr): + """Pattern matching expression in Relay.""" + + def __init__(self, data, pattern): + """Construct a Match. + + Parameters + ---------- + data: tvm.relay.Expr + The value being deconstructed and matched. + pattern: [tvm.relay.Clause] + The pattern match clauses. + Returns + ------- + match: tvm.relay.Expr + The match expression. + """ + self.__init_handle_by_constructor__(_make.Match, data, pattern) diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index 4a5ddcd8270c5..f1d9f7797deca 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -51,6 +51,13 @@ class Closure(Value): pass +@register_relay_node +class ConValue(Value): + def __init__(self, con, fields, types): + self.__init_handle_by_constructor__( + _make.ConValue, con, fields, types) + + @register_relay_node class TensorValue(Value): """A Tensor value produced by the interpreter.""" diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py index eafe5f09309ff..1e8cd0814df3b 100644 --- a/python/tvm/relay/expr_functor.py +++ b/python/tvm/relay/expr_functor.py @@ -2,6 +2,7 @@ """The expression functor of Relay.""" from .expr import Function, Call, Let, Var, GlobalVar, If, Tuple, TupleGetItem, Constant +from .adt import Constructor, Match, Clause from .op import Op class ExprFunctor: @@ -41,6 +42,10 @@ def visit(self, expr): res = self.visit_constant(expr) elif isinstance(expr, Op): res = self.visit_op(expr) + elif isinstance(expr, Constructor): + res = self.visit_constructor(expr) + elif isinstance(expr, Match): + res = self.visit_match(expr) else: raise Exception("warning unhandled case: {0}".format(type(expr))) @@ -81,6 +86,12 @@ def visit_op(self, _): def visit_constant(self, _): raise NotImplementedError() + def visit_constructor(self, _): + raise NotImplementedError() + + def visit_match(self, _): + raise NotImplementedError() + class ExprMutator(ExprFunctor): """ diff --git a/python/tvm/relay/module.py b/python/tvm/relay/module.py index 024c6baf70126..5812d9a2481b0 100644 --- a/python/tvm/relay/module.py +++ b/python/tvm/relay/module.py @@ -6,6 +6,7 @@ from . import _module from . import expr as _expr +from . import ty as _ty @register_relay_node class Module(RelayNode): @@ -20,7 +21,7 @@ class Module(RelayNode): functions : dict, optional. Map of global var to Function """ - def __init__(self, functions=None): + def __init__(self, functions=None, type_definitions=None): if functions is None: functions = {} elif isinstance(functions, dict): @@ -32,28 +33,45 @@ def __init__(self, functions=None): raise TypeError("Expect functions to be Dict[GlobalVar, Function]") mapped_funcs[k] = v functions = mapped_funcs - self.__init_handle_by_constructor__(_make.Module, functions) - - def __setitem__(self, var, func): - """Add a function to the module. + if type_definitions is None: + type_definitions = {} + elif isinstance(type_definitions, dict): + mapped_type_defs = {} + for k, v in type_definitions.items(): + if isinstance(k, _base.string_types): + k = _ty.GlobalTypeVar(k) + if not isinstance(k, _ty.GlobalTypeVar): + raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]") + type_definitions = mapped_type_defs + self.__init_handle_by_constructor__(_make.Module, functions, type_definitions) + + + def __setitem__(self, var, val): + """Add a mapping to the module. Parameters --------- var: GlobalVar - The global variable which names the function. + The global variable. - func: Function - The function. + val: Union[Function, Type] + The value. """ - return self._add(var, func) + return self._add(var, val) - def _add(self, var, func, update=False): - if isinstance(var, _base.string_types): - var = _expr.GlobalVar(var) - return _module.Module_Add(self, var, func, update) + def _add(self, var, val, update=False): + if isinstance(val, _expr.Function): + if isinstance(var, _base.string_types): + var = _expr.GlobalVar(var) + _make.Module_Add(self, var, val, update) + else: + assert isinstance(val, _ty.Type) + if isinstance(var, _base.string_types): + var = _ty.GlobalTypeVar(var) + _module.Module_AddDef(self, var, val) def __getitem__(self, var): - """Lookup a global function by name or by variable. + """Lookup a global definition by name or by variable. Parameters ---------- @@ -62,13 +80,15 @@ def __getitem__(self, var): Returns ------- - func: Function - The function referenced by :code:`var`. + val: Union[Function, Type] + The definition referenced by :code:`var` (either a function or type). """ if isinstance(var, _base.string_types): return _module.Module_Lookup_str(self, var) - else: + elif isinstance(var, _expr.GlobalVar): return _module.Module_Lookup(self, var) + else: + return _module.Module_LookupDef(self, var) def update(self, other): """Insert functions in another Module to current one. @@ -100,3 +120,22 @@ def get_global_var(self, name): tvm.TVMError if we cannot find corresponding global var. """ return _module.Module_GetGlobalVar(self, name) + + def get_global_type_var(self, name): + """Get a global type variable in the function by name. + + Parameters + ---------- + name: str + The name of the global type variable. + + Returns + ------- + global_type_var: GlobalTypeVar + The global variable mapped to :code:`name`. + + Raises + ------ + tvm.TVMError if we cannot find corresponding global type var. + """ + return _module.Module_GetGlobalTypeVar(self, name) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py new file mode 100644 index 0000000000000..29fcc40b93613 --- /dev/null +++ b/python/tvm/relay/prelude.py @@ -0,0 +1,114 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +"""Include some preloaded term/type definitions.""" +from .ty import GlobalTypeVar, TypeVar, FuncType +from .expr import Var, Function, GlobalVar +from .adt import Constructor, TypeData, Clause, Match +from .adt import PatternConstructor, PatternVar, PatternWildcard + +class Prelude: + """Contain standard definitions.""" + def __init__(self, mod): + self.mod = mod + self.nat = GlobalTypeVar("nat") + self.z = Constructor("z", [], self.nat) + self.s = Constructor("s", [self.nat()], self.nat) + mod[self.nat] = TypeData(self.nat, [], [self.z, self.s]) + + self.double = GlobalVar("double") + x = Var("x", self.nat()) + y = Var("y") + z_case = Clause(PatternConstructor(self.z), self.z()) + s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), self.s(self.s(self.double(y)))) + mod[self.double] = Function([x], Match(x, [z_case, s_case])) + + self.add = GlobalVar("add") + x = Var("x", self.nat()) + y = Var("y", self.nat()) + a = Var("a") + z_case = Clause(PatternConstructor(self.z), y) + s_case = Clause(PatternConstructor(self.s, [PatternVar(a)]), self.s(self.add(a, y))) + mod[self.add] = Function([x, y], Match(x, [z_case, s_case])) + + self.l = GlobalTypeVar("list") + a = TypeVar("a") + self.nil = Constructor("nil", [], self.l) + self.cons = Constructor("cons", [a, self.l(a)], self.l) + mod[self.l] = TypeData(self.l, [a], [self.nil, self.cons]) + + self.length = GlobalVar("length") + a = TypeVar("a") + x = Var("x", self.l(a)) + y = Var("y") + nil_case = Clause(PatternConstructor(self.nil), self.z()) + cons_case = Clause(PatternConstructor(self.cons, [PatternWildcard(), PatternVar(y)]), + self.s(self.length(y))) + mod[self.length] = Function([x], Match(x, [nil_case, cons_case]), None, [a]) + + self.map = GlobalVar("map") + a = TypeVar("a") + b = TypeVar("b") + f = Var("f", FuncType([a], b)) + x = Var("x", self.l(a)) + y = Var("y") + z = Var("z") + nil_case = Clause(PatternConstructor(self.nil), self.nil()) + cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), + self.cons(f(y), self.map(f, z))) + mod[self.map] = Function([f, x], Match(x, [nil_case, cons_case]), None, [a, b]) + + self.foldl = GlobalVar("foldl") + a = TypeVar("a") + b = TypeVar("b") + f = Var("f", FuncType([a, b], a)) + av = Var("av", a) + bv = Var("bv", self.l(b)) + y = Var("y") + z = Var("z") + nil_case = Clause(PatternConstructor(self.nil), av) + cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), + self.foldl(f, f(av, y), z)) + mod[self.foldl] = Function([f, av, bv], Match(bv, [nil_case, cons_case]), None, [a, b]) + + self.tree = GlobalTypeVar("tree") + a = TypeVar("a") + self.rose = Constructor("rose", [a, self.l(self.tree(a))], self.tree) + mod[self.tree] = TypeData(self.tree, [a], [self.rose]) + + self.foldr = GlobalVar("foldr") + a = TypeVar("a") + b = TypeVar("b") + f = Var("f", FuncType([a, b], b)) + av = Var("av", self.l(a)) + bv = Var("bv", b) + y = Var("y") + z = Var("z") + nil_case = Clause(PatternConstructor(self.nil), bv) + cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), + f(y, self.foldr(f, bv, z))) + mod[self.foldr] = Function([f, bv, av], Match(av, [nil_case, cons_case]), None, [a, b]) + + self.sum = GlobalVar("sum") + a = Var("a", self.l(self.nat())) + mod[self.sum] = Function([a], self.foldl(self.add, self.z(), a)) + + self.tmap = GlobalVar("tmap") + a = TypeVar("a") + b = TypeVar("b") + t = Var("t", self.tree(a)) + f = Var("f", FuncType([a], b)) + x = Var("x", self.tree(a)) + y = Var("y") + z = Var("z") + rose_case = Clause(PatternConstructor(self.rose, [PatternVar(y), PatternVar(z)]), + self.rose(f(y), self.map(Function([x], self.tmap(f, x)), z))) + mod[self.tmap] = Function([f, t], Match(t, [rose_case]), self.tree(b), [a, b]) + + self.size = GlobalVar("size") + a = TypeVar("a") + t = Var("t", self.tree(a)) + x = Var("x", self.tree(a)) + z = Var("z") + rose_case = Clause(PatternConstructor(self.rose, [PatternWildcard(), PatternVar(z)]), + self.s(self.sum(self.map(Function([x], self.size(x)), z)))) + mod[self.size] = Function([t], Match(t, [rose_case]), self.nat(), [a]) + # cannot infer return type here diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py index 96dde5acb4dfe..b7650d01826ab 100644 --- a/python/tvm/relay/ty.py +++ b/python/tvm/relay/ty.py @@ -21,6 +21,19 @@ def same_as(self, other): """Compares two Relay types by referential equality.""" return super().__eq__(other) + def __call__(self, *args): + """Create a type call from this type. + + Parameters + ---------- + args: List[relay.Type] + The arguments to the type call. + + Returns + ------- + call: relay.TypeCall + """ + return TypeCall(self, args) @register_relay_node class TensorType(Type): @@ -106,6 +119,51 @@ def __init__(self, var, kind=Kind.Type): self.__init_handle_by_constructor__(_make.TypeVar, var, kind) +@register_relay_node +class GlobalTypeVar(Type): + """A global type variable in Relay. + GlobalTypeVar is used to refer to the global type-level definitions + stored in the environment. + """ + + def __init__(self, var, kind=Kind.Type): + """Construct a GlobalTypeVar. + + Parameters + ---------- + var: tvm.Var + The tvm.Var which backs the type parameter. + kind: Kind, optional + The kind of the type parameter, Kind.Type by default. + + Returns + ------- + type_var: GlobalTypeVar + The global type variable. + """ + self.__init_handle_by_constructor__(_make.GlobalTypeVar, var, kind) + + +@register_relay_node +class TypeCall(Type): + """Type-level function application in Relay.""" + + def __init__(self, func, args): + """Construct a TypeCall. + Parameters + ---------- + func: tvm.relay.Type + The function. + args: List[tvm.expr.Type] + The arguments. + Returns + ------- + type_call: TypeCall + The type function application. + """ + self.__init_handle_by_constructor__(_make.TypeCall, func, args) + + @register_relay_node class TypeConstraint(Type): """Abstract class representing a type constraint.""" diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 734180c537596..b577f2d9d131a 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -75,6 +76,26 @@ TVM_REGISTER_API("relay._make.TensorValue") *ret = TensorValueNode::make(data); }); +ConValue ConValueNode::make(Constructor con, + tvm::Array fields) { + NodePtr n = make_node(); + n->con = con; + n->fields = fields; + return ConValue(n); +} + +TVM_REGISTER_API("relay._make.ConValue") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = ConValueNode::make(args[0], args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const ConValueNode* node, + tvm::IRPrinter* p) { + p->stream << "ConValueNode(" << node->con + << node->fields << ")"; +}); + /*! * \brief A stack frame in the Relay interpreter. * @@ -168,7 +189,8 @@ InterpreterState InterpreterStateNode::make(Expr current_expr, Stack stack) { // // Conversion to ANF is recommended before running the interpretation. class Interpreter : - public ExprFunctor { + public ExprFunctor, + PatternFunctor { public: Interpreter(Module mod, DLContext context, @@ -192,7 +214,7 @@ class Interpreter : } Value Eval(const Expr& expr) { - return (*this)(expr); + return VisitExpr(expr); } Value VisitExpr(const Expr& expr) final { @@ -384,6 +406,9 @@ class Interpreter : << "; operators should be removed by future passes; try " "fusing and lowering"; } + if (auto con = call->op.as()) { + return ConValueNode::make(GetRef(con), args); + } // Now we just evaluate and expect to find a closure. Value fn_val = Eval(call->op); if (const ClosureNode* closure_node = fn_val.as()) { @@ -432,7 +457,45 @@ class Interpreter : } } - InterpreterState get_state(Expr e = Expr()) const { + Value VisitExpr_(const MatchNode* op) final { + Value v = Eval(op->data); + for (const Clause& c : op->pattern) { + if (VisitPattern(c->lhs, v)) { + return VisitExpr(c->rhs); + } + } + LOG(FATAL) << "did not find any match"; + return Value(); + } + + bool VisitPattern_(const PatternConstructorNode* op, const Value& v) final { + const ConValueNode* cvn = v.as(); + CHECK(cvn) << "need to be a constructor for match"; + CHECK_NE(op->con->tag, -1); + CHECK_NE(cvn->con->tag, -1); + if (op->con->tag == cvn->con->tag) { + // todo(M.K.): should use ptr equality but it is broken + CHECK(op->pat.size() == cvn->fields.size()); + for (size_t i = 0; i < op->pat.size(); ++i) { + if (!VisitPattern(op->pat[i], cvn->fields[i])) { + return false; + } + } + return true; + } + return false; + } + + bool VisitPattern_(const PatternWildcardNode* op, const Value& v) final { + return true; + } + + bool VisitPattern_(const PatternVarNode* op, const Value& v) final { + extend(op->var, v); + return true; + } + + InterpreterState get_state(Expr e = Expr()) const { InterpreterStateNode::Stack stack; for (auto fr : this->stack_.frames) { InterpreterStateNode::Frame frame = fr.locals; @@ -443,14 +506,14 @@ class Interpreter : } private: - // module + // Module Module mod_; // For simplicity we only run the interpreter on a single context. // Context to run the interpreter on. DLContext context_; // Target parameter being used by the interpreter. Target target_; - // value stack. + // Value stack. Stack stack_; // Backend compile engine. CompileEngine engine_; diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc new file mode 100644 index 0000000000000..4bcbc5f27e4bd --- /dev/null +++ b/src/relay/ir/adt.cc @@ -0,0 +1,161 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file src/tvm/ir/adt.cc + * \brief AST nodes for Relay algebraic data types (ADTs). + */ +#include +#include + +namespace tvm { +namespace relay { + +PatternWildcard PatternWildcardNode::make() { + NodePtr n = make_node(); + return PatternWildcard(n); +} + +TVM_REGISTER_NODE_TYPE(PatternWildcardNode); + +TVM_REGISTER_API("relay._make.PatternWildcard") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = PatternWildcardNode::make(); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const PatternWildcardNode* node, + tvm::IRPrinter* p) { + p->stream << "PatternWildcardNode()"; +}); + +PatternVar PatternVarNode::make(tvm::relay::Var var) { + NodePtr n = make_node(); + n->var = std::move(var); + return PatternVar(n); +} + +TVM_REGISTER_NODE_TYPE(PatternVarNode); + +TVM_REGISTER_API("relay._make.PatternVar") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = PatternVarNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const PatternVarNode* node, + tvm::IRPrinter* p) { + p->stream << "PatternVarNode(" << node->var << ")"; +}); + +PatternConstructor PatternConstructorNode::make(Constructor con, tvm::Array pat) { + NodePtr n = make_node(); + n->con = std::move(con); + n->pat = std::move(pat); + return PatternConstructor(n); +} + +TVM_REGISTER_NODE_TYPE(PatternConstructorNode); + +TVM_REGISTER_API("relay._make.PatternConstructor") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = PatternConstructorNode::make(args[0], args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const PatternConstructorNode* node, + tvm::IRPrinter* p) { + p->stream << "PatternConstructorNode(" << node->con + << ", " << node->pat << ")"; +}); + +Constructor ConstructorNode::make(std::string name_hint, + tvm::Array inp, + GlobalTypeVar belong_to) { + NodePtr n = make_node(); + n->name_hint = std::move(name_hint); + n->inp = std::move(inp); + n->belong_to = std::move(belong_to); + return Constructor(n); +} + +TVM_REGISTER_NODE_TYPE(ConstructorNode); + +TVM_REGISTER_API("relay._make.Constructor") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = ConstructorNode::make(args[0], args[1], args[2]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const ConstructorNode* node, + tvm::IRPrinter* p) { + p->stream << "ConstructorNode(" << node->name_hint << ", " + << node->inp << ", " << node->belong_to << ")"; +}); + +TypeData TypeDataNode::make(GlobalTypeVar header, + tvm::Array tv, + tvm::Array constructors) { + NodePtr n = make_node(); + n->header = std::move(header); + n->tv = std::move(tv); + n->constructors = std::move(constructors); + return TypeData(n); +} + +TVM_REGISTER_NODE_TYPE(TypeDataNode); + +TVM_REGISTER_API("relay._make.TypeData") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = TypeDataNode::make(args[0], args[1], args[2]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const TypeDataNode* node, + tvm::IRPrinter* p) { + p->stream << "TypeDataNode(" << node->header << ", " << node->tv << ", " + << node->constructors << ")"; +}); + +Clause ClauseNode::make(Pattern lhs, Expr rhs) { + NodePtr n = make_node(); + n->lhs = std::move(lhs); + n->rhs = std::move(rhs); + return Clause(n); +} + +TVM_REGISTER_NODE_TYPE(ClauseNode); + +TVM_REGISTER_API("relay._make.Clause") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = ClauseNode::make(args[0], args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const ClauseNode* node, + tvm::IRPrinter* p) { + p->stream << "ClauseNode(" << node->lhs << ", " + << node->rhs << ")"; + }); + +Match MatchNode::make(Expr data, tvm::Array pattern) { + NodePtr n = make_node(); + n->data = std::move(data); + n->pattern = std::move(pattern); + return Match(n); +} + +TVM_REGISTER_NODE_TYPE(MatchNode); + +TVM_REGISTER_API("relay._make.Match") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = MatchNode::make(args[0], args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const MatchNode* node, + tvm::IRPrinter* p) { + p->stream << "MatchNode(" << node->data << ", " + << node->pattern << ")"; +}); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index 064343c834ea0..c7fc520c1b8ec 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -5,6 +5,7 @@ */ #include #include +#include #include #include #include "type_functor.h" @@ -17,7 +18,8 @@ namespace relay { class AlphaEqualHandler: public AttrsEqualHandler, public TypeFunctor, - public ExprFunctor { + public ExprFunctor, + public PatternFunctor { public: explicit AlphaEqualHandler(bool map_free_var) : map_free_var_(map_free_var) {} @@ -160,7 +162,7 @@ class AlphaEqualHandler: } equal_map_[lhs->type_params[i]] = rhs->type_params[i]; // set up type parameter equal - if (lhs->type_params[i]->kind == TypeVarNode::Kind::kShapeVar) { + if (lhs->type_params[i]->kind == Kind::kShapeVar) { // map variable equal_map_[lhs->type_params[i]->var] = rhs->type_params[i]->var; } @@ -207,6 +209,27 @@ class AlphaEqualHandler: return false; } } + + bool VisitType_(const GlobalTypeVarNode* op, const Type& t2) final { + return GetRef(op) == t2; + } + + bool VisitType_(const TypeCallNode* op, const Type& t2) final { + const TypeCallNode* pt = t2.as(); + if (pt == nullptr + || op->args.size() != pt->args.size() + || !TypeEqual(op->func, pt->func)) { + return false; + } + + for (size_t i = 0; i < op->args.size(); ++i) { + if (!TypeEqual(op->args[i], pt->args[i])) { + return false; + } + } + return true; + } + // Expr equal checking. bool NDArrayEqual(const runtime::NDArray& lhs, const runtime::NDArray& rhs) { @@ -253,11 +276,9 @@ class AlphaEqualHandler: bool VisitExpr_(const GlobalVarNode* lhs, const Expr& other) final { if (const GlobalVarNode* rhs = other.as()) { // use name equality for global var for now. - if (lhs->name_hint != rhs->name_hint) return false; - return true; - } else { - return false; + return lhs->name_hint == rhs->name_hint; } + return false; } bool VisitExpr_(const TupleNode* lhs, const Expr& other) final { @@ -361,6 +382,62 @@ class AlphaEqualHandler: } } + bool VisitExpr_(const ConstructorNode* op, const Expr& e2) final { + return GetRef(op) == e2; + } + + bool ClauseEqual(const Clause& l, const Clause& r) { + return PatternEqual(l->lhs, r->lhs) && ExprEqual(l->rhs, r->rhs); + } + + bool PatternEqual(const Pattern& l, const Pattern& r) { + return VisitPattern(l, r); + } + + bool VisitPattern_(const PatternWildcardNode* op, const Pattern& r) { + return r.as(); + } + + bool VisitPattern_(const PatternVarNode* op, const Pattern& e2) { + if (const auto* r = e2.as()) { + return ExprEqual(op->var, r->var); + } + return false; + } + + bool VisitPattern_(const PatternConstructorNode* op, const Pattern& e2) { + const auto* r = e2.as(); + if (r == nullptr + || !ExprEqual(op->con, r->con) + || op->pat.size() != r->pat.size()) { + return false; + } + + for (size_t i = 0; i < op->pat.size(); i++) { + if (!PatternEqual(op->pat[i], r->pat[i])) { + return false; + } + } + return true; + } + + bool VisitExpr_(const MatchNode* op, const Expr& e2) final { + const MatchNode* r = e2.as(); + + if (r == nullptr + || !ExprEqual(op->data, r->data) + || op->pattern.size() != r->pattern.size()) { + return false; + } + + for (size_t i = 0; i < op->pattern.size(); ++i) { + if (!ClauseEqual(op->pattern[i], r->pattern[i])) { + return false; + } + } + return true; + } + private: // whether to map open terms. bool map_free_var_{false}; diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index c1719e81a6c6c..010e84da8533d 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -157,6 +157,24 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) { } } +Expr ExprMutator::VisitExpr_(const ConstructorNode* c) { + return GetRef(c); +} + +Expr ExprMutator::VisitExpr_(const MatchNode* m) { + std::vector pattern; + for (const Clause& p : m->pattern) { + pattern.push_back(VisitClause(p)); + } + return MatchNode::make(VisitExpr(m->data), pattern); +} + +Clause ExprMutator::VisitClause(const Clause& c) { + return ClauseNode::make(VisitPattern(c->lhs), VisitExpr(c->rhs)); +} + +Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; } + Type ExprMutator::VisitType(const Type& t) { return t; } void ExprVisitor::VisitExpr(const Expr& expr) { @@ -226,6 +244,27 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { this->VisitExpr(op->tuple); } +void ExprVisitor::VisitExpr_(const ConstructorNode* op) { + for (const Type& t : op->inp) { + this->VisitType(t); + } + this->VisitType(op->belong_to); +} + +void ExprVisitor::VisitExpr_(const MatchNode* op) { + this->VisitExpr(op->data); + for (const Clause& c : op->pattern) { + this->VisitClause(c); + } +} + +void ExprVisitor::VisitClause(const Clause& op) { + this->VisitPattern(op->lhs); + this->VisitExpr(op->rhs); +} + +void ExprVisitor::VisitPattern(const Pattern& p) { return; } + void ExprVisitor::VisitType(const Type& t) { return; } diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index d7a8df98fa3fb..479a7b8d4552e 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -5,6 +5,7 @@ */ #include #include +#include #include #include #include @@ -16,9 +17,10 @@ namespace relay { // Hash handler for Relay. class RelayHashHandler: - public AttrsHashHandler, - public TypeFunctor, - public ExprFunctor { + public AttrsHashHandler, + public TypeFunctor, + public ExprFunctor, + public PatternFunctor { public: explicit RelayHashHandler() {} @@ -195,7 +197,7 @@ class RelayHashHandler: hash_map_[var] = hash; const auto* ty_param = var.as(); - if (ty_param && ty_param->kind == TypeVarNode::Kind::kShapeVar) { + if (ty_param && ty_param->kind == Kind::kShapeVar) { hash_map_[ty_param->var] = hash; } return hash; @@ -230,7 +232,7 @@ class RelayHashHandler: } hash = Combine(hash, TypeHash(func->ret_type)); - hash = Combine(hash, ExprHash(func->body)); + hash = Combine(hash, ExprHash(func->body)); return hash; } @@ -243,6 +245,10 @@ class RelayHashHandler: hash = Combine(hash, ExprHash(arg)); } + for (auto t : call->type_args) { + hash = Combine(hash, TypeHash(t)); + } + hash = Combine(hash, AttrHash(call->attrs)); return hash; @@ -280,6 +286,71 @@ class RelayHashHandler: return hash; } + size_t VisitExpr_(const MatchNode* mn) final { + size_t hash = std::hash()(MatchNode::_type_key); + hash = Combine(hash, ExprHash(mn->data)); + for (const auto& c : mn->pattern) { + hash = Combine(hash, PatternHash(c->lhs)); + hash = Combine(hash, ExprHash(c->rhs)); + } + return hash; + } + + size_t VisitExpr_(const ConstructorNode* cn) final { + size_t hash = std::hash()(ConstructorNode::_type_key); + hash = Combine(hash, std::hash()(cn->name_hint)); + return hash; + } + + size_t VisitType_(const TypeCallNode* tcn) final { + size_t hash = std::hash()(TypeCallNode::_type_key); + hash = Combine(hash, TypeHash(tcn->func)); + for (const auto& t : tcn->args) { + hash = Combine(hash, TypeHash(t)); + } + return hash; + } + + size_t VisitType_(const TypeDataNode* tdn) final { + size_t hash = std::hash()(TypeDataNode::_type_key); + hash = Combine(hash, TypeHash(tdn->header)); + for (const auto& tv : tdn->tv) { + hash = Combine(hash, TypeHash(tv)); + } + for (const auto& cn : tdn->constructors) { + hash = Combine(hash, ExprHash(cn)); + } + return hash; + } + + size_t VisitType_(const GlobalTypeVarNode* tvn) final { + return BindVar(GetRef(tvn)); + } + + size_t PatternHash(const Pattern& p) { + return VisitPattern(p); + } + + size_t VisitPattern_(const PatternConstructorNode* pcn) final { + size_t hash = std::hash()(PatternConstructorNode::_type_key); + hash = Combine(hash, ExprHash(pcn->con)); + for (const auto& p : pcn->pat) { + hash = Combine(hash, PatternHash(p)); + } + return hash; + } + + size_t VisitPattern_(const PatternVarNode* pvn) final { + size_t hash = std::hash()(PatternVarNode::_type_key); + hash = Combine(hash, ExprHash(pvn->var)); + return hash; + } + + size_t VisitPattern_(const PatternWildcardNode* pwn) final { + size_t hash = std::hash()(PatternWildcardNode::_type_key); + return hash; + } + private: // renaming of NodeRef to indicate two nodes equals to each other std::unordered_map hash_map_; diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index cbb0b77680043..7cfec8c48715c 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -13,16 +13,24 @@ namespace relay { using tvm::IRPrinter; using namespace runtime; -Module ModuleNode::make(tvm::Map global_funcs) { +Module ModuleNode::make(tvm::Map global_funcs, + tvm::Map global_type_defs) { auto n = make_node(); n->functions = std::move(global_funcs); + n->type_definitions = std::move(global_type_defs); for (const auto& kv : n->functions) { - // set gloval var map + // set global var map CHECK(!n->global_var_map_.count(kv.first->name_hint)) - << "Duplicate global function name " << kv.first->name_hint; + << "Duplicate global function name " << kv.first->name_hint; n->global_var_map_.Set(kv.first->name_hint, kv.first); } + for (const auto& kv : n->type_definitions) { + // set global typevar map + CHECK(!n->global_type_var_map_.count(kv.first->var->name_hint)) + << "Duplicate global type definition name " << kv.first->var->name_hint; + n->global_type_var_map_.Set(kv.first->var->name_hint, kv.first); + } return Module(n); } @@ -49,6 +57,13 @@ void ModuleNode::AddUnchecked(const GlobalVar& var, global_var_map_.Set(var->name_hint, var); } +GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) { + auto it = global_type_var_map_.find(name); + CHECK(it != global_type_var_map_.end()) + << "Cannot find global type var " << name << " in the Module"; + return (*it).second; +} + void ModuleNode::Add(const GlobalVar& var, const Function& func, bool update) { @@ -67,6 +82,19 @@ void ModuleNode::Add(const GlobalVar& var, AddUnchecked(var, checked_func); } +void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) { + // kind checker is broken, not checking them rn. + // TODO(slyubomirsky, MarisaKirisame): fix the kind checker. + this->type_definitions.Set(var, type); + // set global type var map + CHECK(!global_type_var_map_.count(var->var->name_hint)) + << "Duplicate global type definition name " << var->var->name_hint; + global_type_var_map_.Set(var->var->name_hint, var); + for (size_t i = 0; i < type->constructors.size(); ++i) { + type->constructors[i]->tag = i; + } + } + void ModuleNode::Update(const GlobalVar& var, const Function& func) { this->Add(var, func, true); } @@ -90,6 +118,18 @@ Function ModuleNode::Lookup(const std::string& name) { return this->Lookup(id); } +TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) { + auto it = type_definitions.find(var); + CHECK(it != type_definitions.end()) + << "There is no definition of " << var->var->name_hint; + return (*it).second; +} + +TypeData ModuleNode::LookupDef(const std::string& name) { + GlobalTypeVar id = this->GetGlobalTypeVar(name); + return this->LookupDef(id); +} + void ModuleNode::Update(const Module& mod) { for (auto pair : mod->functions) { this->Update(pair.first, pair.second); @@ -100,21 +140,33 @@ TVM_REGISTER_NODE_TYPE(ModuleNode); TVM_REGISTER_API("relay._make.Module") .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = ModuleNode::make(args[0]); + *ret = ModuleNode::make(args[0], args[1]); }); -TVM_REGISTER_API("relay._module.Module_Add") +TVM_REGISTER_API("relay._make.Module_Add") .set_body([](TVMArgs args, TVMRetValue *ret) { Module mod = args[0]; mod->Add(args[1], args[2], args[3]); }); +TVM_REGISTER_API("relay._module.Module_AddDef") +.set_body([](TVMArgs args, TVMRetValue *ret) { + Module mod = args[0]; + mod->AddDef(args[1], args[2]); + }); + TVM_REGISTER_API("relay._module.Module_GetGlobalVar") .set_body([](TVMArgs args, TVMRetValue *ret) { Module mod = args[0]; *ret = mod->GetGlobalVar(args[1]); }); +TVM_REGISTER_API("relay._module.Module_GetGlobalTypeVar") +.set_body([](TVMArgs args, TVMRetValue *ret) { + Module mod = args[0]; + *ret = mod->GetGlobalTypeVar(args[1]); + }); + TVM_REGISTER_API("relay._module.Module_Lookup") .set_body([](TVMArgs args, TVMRetValue *ret) { Module mod = args[0]; @@ -126,8 +178,21 @@ TVM_REGISTER_API("relay._module.Module_Lookup_str") .set_body([](TVMArgs args, TVMRetValue *ret) { Module mod = args[0]; std::string var_name = args[1]; - auto var = mod->GetGlobalVar(var_name); - *ret = mod->Lookup(var); + *ret = mod->Lookup(var_name); + }); + +TVM_REGISTER_API("relay._module.Module_LookupDef") +.set_body([](TVMArgs args, TVMRetValue *ret) { + Module mod = args[0]; + GlobalTypeVar var = args[1]; + *ret = mod->LookupDef(var); + }); + +TVM_REGISTER_API("relay._module.Module_LookupDef_str") +.set_body([](TVMArgs args, TVMRetValue *ret) { + Module mod = args[0]; + std::string var_name = args[1]; + *ret = mod->LookupDef(var_name); }); TVM_REGISTER_API("relay._module.Module_Update") diff --git a/src/relay/ir/pattern_functor.cc b/src/relay/ir/pattern_functor.cc new file mode 100644 index 0000000000000..71002058fe499 --- /dev/null +++ b/src/relay/ir/pattern_functor.cc @@ -0,0 +1,75 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file src/tvm/relay/pattern_functor.cc + * \brief Implementations of visitors and mutators for ADT patterns. + */ + +#include + +namespace tvm { +namespace relay { + +Pattern PatternMutator::Mutate(const Pattern& pat) { + return (*this)(pat); +} + +Pattern PatternMutator::VisitPattern_(const PatternWildcardNode* op) { + return GetRef(op); +} + +Pattern PatternMutator::VisitPattern_(const PatternVarNode* op) { + return PatternVarNode::make(VisitVar(op->var)); +} + +Pattern PatternMutator::VisitPattern_(const PatternConstructorNode* op) { + std::vector pat; + for (const auto& p : op->pat) { + pat.push_back(VisitPattern(p)); + } + return PatternConstructorNode::make(VisitConstructor(op->con), pat); +} + +Type PatternMutator::VisitType(const Type& t) { + return t; +} + +Var PatternMutator::VisitVar(const Var& v) { + if (var_map_.count(v) == 0) { + var_map_.insert(std::pair(v, + VarNode::make(v->name_hint(), + VisitType(v->type_annotation)))); + } + return var_map_.at(v); +} + +Constructor PatternMutator::VisitConstructor(const Constructor& v) { + return v; +} + +void PatternVisitor::VisitPattern_(const PatternWildcardNode* op) { } + +void PatternVisitor::VisitPattern_(const PatternVarNode* op) { + VisitVar(op->var); +} + +void PatternVisitor::VisitPattern_(const PatternConstructorNode* op) { + VisitConstructor(op->con); + for (const auto& p : op->pat) { + VisitPattern(p); + } +} + +void PatternVisitor::VisitType(const Type& t) { } + +void PatternVisitor::VisitVar(const Var& v) { + VisitType(v->type_annotation); +} + +void PatternVisitor::VisitConstructor(const Constructor& c) { + for (const auto& inp : c->inp) { + VisitType(inp); + } +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index 46b0d25b3d7de..31db0d8c493fa 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -5,6 +5,7 @@ */ #include #include +#include #include #include "type_functor.h" #include "../../lang/attr_functor.h" @@ -23,6 +24,12 @@ struct TextValue { TextValue() {} // constructor explicit TextValue(std::string name) : name(name) {} + TextValue operator+(const TextValue& rhs) const { + return TextValue(name + rhs.name); + } + TextValue operator+(const std::string& str) const { + return TextValue(name + str); + } }; // operator overloading @@ -127,6 +134,7 @@ class TextMetaDataContext { class TextPrinter : public ExprFunctor, + public PatternFunctor, public TypeFunctor, // NOLINT(*) public AttrFunctor { // NOLINT(*) public: @@ -212,6 +220,9 @@ class TextPrinter : memo_[expr] = val; return val; } + TextValue GetValue(const Pattern& p) { + return this->VisitPattern(p); + } //------------------------------------ // Overload of Expr printing functions //------------------------------------ @@ -362,6 +373,36 @@ class TextPrinter : return id; } + TextValue VisitExpr_(const MatchNode* op) final { + TextValue data = GetValue(op->data); + this->PrintIndent(); + TextValue id = this->AllocTempVar(); + stream_ << id << " = " << "Match " << data << " with"; + this->PrintEndInst("\n"); + for (const auto& c : op->pattern) { + this->PrintIndent(); + stream_ << GetValue(c->lhs) << " to " << GetValue(c->rhs); + this->PrintEndInst("\n"); + } + return id; + } + + TextValue VisitPattern_(const PatternConstructorNode* p) final { + TextValue ret(p->con->name_hint + "("); + for (const Pattern& pat : p->pat) { + ret = ret + " " + GetValue(pat); + } + return ret + ")"; + } + + TextValue VisitPattern_(const PatternVarNode* pv) final { + return GetValue(pv->var); + } + + TextValue VisitExpr_(const ConstructorNode* n) final { + return TextValue(n->name_hint); + } + /*! * \brief Print the type to os * \param type The type to be printed. @@ -404,6 +445,18 @@ class TextPrinter : os << "]"; } + void VisitType_(const TypeCallNode* node, std::ostream& os) final { + os << node->func << "(" << node->args << ")"; + } + + void VisitType_(const GlobalTypeVarNode* node, std::ostream& os) final { + VisitTypeDefault_(node, os); + } + + void VisitType_(const TypeDataNode* node, std::ostream& os) final { + VisitTypeDefault_(node, os); + } + void VisitTypeDefault_(const Node* node, std::ostream& os) final { // NOLINT(*) // by default always print as meta-data os << meta_.GetMetaNode(GetRef(node)); diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index bbe6472609dfd..bde9e53dc3297 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -48,7 +48,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")"; }); -TypeVar TypeVarNode::make(std::string name, TypeVarNode::Kind kind) { +TypeVar TypeVarNode::make(std::string name, Kind kind) { NodePtr n = make_node(); n->var = tvm::Var(name); n->kind = std::move(kind); @@ -61,7 +61,7 @@ TVM_REGISTER_API("relay._make.TypeVar") .set_body([](TVMArgs args, TVMRetValue* ret) { int kind = args[1]; *ret = - TypeVarNode::make(args[0], static_cast(kind)); + TypeVarNode::make(args[0], static_cast(kind)); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) @@ -71,7 +71,50 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) << node->kind << ")"; }); -IncompleteType IncompleteTypeNode::make(TypeVarNode::Kind kind) { +GlobalTypeVar GlobalTypeVarNode::make(std::string name, Kind kind) { + NodePtr n = make_node(); + n->var = tvm::Var(name); + n->kind = std::move(kind); + return GlobalTypeVar(n); +} + +TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode); + +TVM_REGISTER_API("relay._make.GlobalTypeVar") +.set_body([](TVMArgs args, TVMRetValue* ret) { + int kind = args[1]; + *ret = GlobalTypeVarNode::make(args[0], static_cast(kind)); +}); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const GlobalTypeVarNode *node, + tvm::IRPrinter *p) { + p->stream << "GlobalTypeVarNode(" << node->var->name_hint << ", " + << node->kind << ")"; +}); + +TypeCall TypeCallNode::make(Type func, tvm::Array args) { + NodePtr n = make_node(); + n->func = std::move(func); + n->args = std::move(args); + return TypeCall(n); +} + +TVM_REGISTER_NODE_TYPE(TypeCallNode); + +TVM_REGISTER_API("relay._make.TypeCall") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = TypeCallNode::make(args[0], args[1]); +}); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const TypeCallNode* node, + tvm::IRPrinter* p) { + p->stream << "TypeCallNode(" << node->func << ", " + << node->args << ")"; +}); + +IncompleteType IncompleteTypeNode::make(Kind kind) { auto n = make_node(); n->kind = std::move(kind); return IncompleteType(n); @@ -82,7 +125,7 @@ TVM_REGISTER_NODE_TYPE(IncompleteTypeNode); TVM_REGISTER_API("relay._make.IncompleteType") .set_body([](TVMArgs args, TVMRetValue* ret) { int kind = args[0]; - *ret = IncompleteTypeNode::make(static_cast(kind)); + *ret = IncompleteTypeNode::make(static_cast(kind)); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) diff --git a/src/relay/ir/type_functor.cc b/src/relay/ir/type_functor.cc index 0ef1743cbbc4a..2267f3f1604ef 100644 --- a/src/relay/ir/type_functor.cc +++ b/src/relay/ir/type_functor.cc @@ -44,6 +44,23 @@ void TypeVisitor::VisitType_(const TypeRelationNode* op) { } } +void TypeVisitor::VisitType_(const GlobalTypeVarNode* op) { +} + +void TypeVisitor::VisitType_(const TypeCallNode* op) { + this->VisitType(op->func); + for (const Type& t : op->args) { + this->VisitType(t); + } +} + +void TypeVisitor::VisitType_(const TypeDataNode* op) { + this->VisitType(op->header); + for (const auto& v : op->tv) { + this->VisitType(v); + } + // TODO(slyubomirsky, MarisaKirisame): visit constructors +} // Type Mutator. Array TypeMutator::MutateArray(Array arr) { @@ -131,6 +148,22 @@ Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) { } } +Type TypeMutator::VisitType_(const GlobalTypeVarNode* op) { + return GetRef(op); +} + +Type TypeMutator::VisitType_(const TypeCallNode* op) { + std::vector args; + for (const auto& a : op->args) { + args.push_back(VisitType(a)); + } + return TypeCallNode::make(VisitType(op->func), args); +} + +Type TypeMutator::VisitType_(const TypeDataNode* op) { + return GetRef(op); +} + // Implements bind. class TypeBinder : public TypeMutator { public: diff --git a/src/relay/ir/type_functor.h b/src/relay/ir/type_functor.h index e8dfd2b7cd7cd..9a6f18a09cc7f 100644 --- a/src/relay/ir/type_functor.h +++ b/src/relay/ir/type_functor.h @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ -68,6 +69,9 @@ class TypeFunctor { virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const GlobalTypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeDataNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitTypeDefault_(const Node* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->type_key(); @@ -86,6 +90,9 @@ class TypeFunctor { RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode); RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode); RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode); + RELAY_TYPE_FUNCTOR_DISPATCH(GlobalTypeVarNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeCallNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeDataNode); return vtable; } }; @@ -101,6 +108,9 @@ class TypeVisitor : public TypeFunctor { void VisitType_(const FuncTypeNode* op) override; void VisitType_(const TupleTypeNode* op) override; void VisitType_(const TypeRelationNode* op) override; + void VisitType_(const GlobalTypeVarNode* op) override; + void VisitType_(const TypeCallNode* op) override; + void VisitType_(const TypeDataNode* op) override; }; // Mutator that transform a type to another one. @@ -112,6 +122,9 @@ class TypeMutator : public TypeFunctor { Type VisitType_(const FuncTypeNode* op) override; Type VisitType_(const TupleTypeNode* op) override; Type VisitType_(const TypeRelationNode* type_rel) override; + Type VisitType_(const GlobalTypeVarNode* op) override; + Type VisitType_(const TypeCallNode* op) override; + Type VisitType_(const TypeDataNode* op) override; private: Array MutateArray(Array arr); diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index fad4fb781b5a8..46c9fece321bc 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -275,6 +275,15 @@ class IndexedForwardGraph::Creator : private ExprVisitor { ExprVisitor::VisitExpr_(op); this->AddNode(op); } + + void VisitExpr_(const MatchNode* op) final { + this->Update(op->data, nullptr, kOpaque); + for (const Clause& c : op->pattern) { + this->Update(c->rhs, nullptr, kOpaque); + } + ExprVisitor::VisitExpr_(op); + this->AddNode(op); + } }; IndexedForwardGraph IndexedForwardGraph::Create( diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index 7253a600dabfb..5b0599d96d6e3 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -20,7 +20,6 @@ namespace tvm { namespace relay { using namespace tvm::runtime; -using Kind = TypeVarNode::Kind; struct KindChecker : TypeVisitor { bool valid; @@ -107,7 +106,7 @@ bool KindCheck(const Type& t, const Module& mod) { TVM_REGISTER_API("relay._ir_pass.check_kind") .set_body([](TVMArgs args, TVMRetValue* ret) { if (args.size() == 1) { - *ret = KindCheck(args[0], ModuleNode::make({})); + *ret = KindCheck(args[0], ModuleNode::make({}, {})); } else { *ret = KindCheck(args[0], args[1]); } diff --git a/src/relay/pass/let_list.h b/src/relay/pass/let_list.h index 904ceab36c3d4..9e913ec842816 100644 --- a/src/relay/pass/let_list.h +++ b/src/relay/pass/let_list.h @@ -61,7 +61,7 @@ class LetList { * \return a Var that hold the inserted expr. */ Var Push(Expr expr) { - return Push(IncompleteTypeNode::make(TypeVarNode::kType), expr); + return Push(IncompleteTypeNode::make(Kind::kType), expr); } /*! diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index ee1b5ab101482..3e600165fb241 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -22,7 +22,9 @@ #include #include +#include #include +#include "./pass_util.h" #include "type_solver.h" #include "../ir/type_functor.h" @@ -99,7 +101,8 @@ struct ResolvedTypeInfo { // - Solve the constraints (solver_.Solve) // - Recreate expression with the resolved checked_type (Resolver.VisitExpr) // -class TypeInferencer : private ExprFunctor { +class TypeInferencer : private ExprFunctor, + private PatternFunctor { public: // constructors TypeInferencer() { @@ -158,7 +161,7 @@ class TypeInferencer : private ExprFunctor { if (op->type_annotation.defined()) { return op->type_annotation; } else { - return IncompleteTypeNode::make(TypeVarNode::kType); + return IncompleteTypeNode::make(Kind::kType); } } @@ -183,7 +186,7 @@ class TypeInferencer : private ExprFunctor { for (Expr field : op->fields) { types.push_back(GetType(field)); } - Type rtype = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + Type rtype = IncompleteTypeNode::make(Kind::kType); types.push_back(rtype); solver_.AddConstraint(TypeRelationNode::make( make_tuple_rel_, types, op->fields.size(), Attrs())); @@ -196,7 +199,7 @@ class TypeInferencer : private ExprFunctor { EnvFunc::Get("tvm.relay.type_relation.TupleGetItem").node_); } Type tuple_type = GetType(op->tuple); - Type rtype = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + Type rtype = IncompleteTypeNode::make(Kind::kType); auto attrs = make_node(); attrs->index = op->index; solver_.AddConstraint(TypeRelationNode::make( @@ -204,6 +207,45 @@ class TypeInferencer : private ExprFunctor { return rtype; } + void VisitPattern_(const PatternConstructorNode* con, const Type& t) { + CHECK(mod_.defined()) + << "Cannot do type inference without a environment:" + << con->con->name_hint; + TypeData td = mod_->type_definitions.at(con->con->belong_to); + auto* tc = t.as(); + CHECK(tc) << "must be type call"; + CHECK_EQ(td->header, tc->func); + CHECK(td->tv.size() == tc->args.size()) << "both side must be equal"; + std::unordered_map type_var_map_; + for (size_t i = 0; i < td->tv.size(); ++i) { + type_var_map_[td->tv[i]] = tc->args[i]; + } + CHECK(con->con->inp.size() == con->pat.size()) << "not enough pattern"; + for (size_t i = 0; i < con->con->inp.size(); ++i) { + VisitPattern(con->pat[i], Bind(con->con->inp[i], type_var_map_)); + } + } + + void VisitPattern_(const PatternVarNode* pv, const Type& t) { + type_map_[pv->var] = ResolvedTypeInfo(t, {}); + } + + void VisitPattern_(const PatternWildcardNode* wc, const Type& t) { } + + Type VisitExpr_(const MatchNode* op) final { + Type dtype = GetType(op->data); + for (const auto& c : op->pattern) { + VisitPattern(c->lhs, dtype); + } + Type rtype = IncompleteTypeNode::make(Kind::kType); + for (const auto& c : op->pattern) { + rtype = this->Unify(rtype, + GetType(c->rhs), + op->span); + } + return rtype; + } + Type VisitExpr_(const OpNode* op) final { return op->op_type; } @@ -244,7 +286,7 @@ class TypeInferencer : private ExprFunctor { for (size_t i = 0; i < op->type_params.size(); ++i) { if (!op->type_params[i].same_as(rel->args[i])) return Type(); } - Type rtype = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + Type rtype = IncompleteTypeNode::make(Kind::kType); arg_types.push_back(rtype); // we can do simple replacement here solver_.AddConstraint(TypeRelationNode::make( @@ -272,7 +314,7 @@ class TypeInferencer : private ExprFunctor { // This is a temporary work around to check recursive functions whose // return type is not yet known. if (!ret_type.defined()) { - ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + ret_type = IncompleteTypeNode::make(Kind::kType); } Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, @@ -374,9 +416,21 @@ class TypeInferencer : private ExprFunctor { // do not support constraint lifting for now. return FuncTypeNode::make(arg_types, rtype, f->type_params, {}); } + + Type VisitExpr_(const ConstructorNode* c) final { + CHECK(mod_.defined()) + << "Cannot do type inference without a environment:" + << c->name_hint; + TypeData td = mod_->type_definitions.at(c->belong_to); + std::vector types; + for (const auto & t : td->tv) { + types.push_back(t); + } + return FuncTypeNode::make(c->inp, TypeCallNode::make(c->belong_to, types), td->tv, {}); + } }; -class TypeInferencer::Resolver : public ExprMutator { +class TypeInferencer::Resolver : public ExprMutator, PatternMutator { public: Resolver(const std::unordered_map& tmap, TypeSolver* solver) @@ -384,7 +438,7 @@ class TypeInferencer::Resolver : public ExprMutator { } Expr VisitExpr_(const VarNode* op) final { - return AttachCheckedType(op); + return VisitVar(GetRef(op)); } Expr VisitExpr_(const ConstantNode* op) final { @@ -423,6 +477,25 @@ class TypeInferencer::Resolver : public ExprMutator { return AttachCheckedType(op); } + Expr VisitExpr_(const ConstructorNode* op) final { + return AttachCheckedType(op); + } + + Expr VisitExpr_(const MatchNode* op) final { + return AttachCheckedType(op); + } + + Pattern VisitPattern(const Pattern& p) final { + return PatternMutator::VisitPattern(p); + } + + Var VisitVar(const Var& v) final { + if (vmap_.count(v) == 0) { + vmap_[v] = GetRef(AttachCheckedType(v.as()).as()); + } + return vmap_.at(v); + } + // attach checked type to the mutated node. template Expr AttachCheckedType(const T* op) { @@ -512,6 +585,7 @@ class TypeInferencer::Resolver : public ExprMutator { } private: + std::unordered_map vmap_; const std::unordered_map& tmap_; TypeSolver* solver_; // whether attach the checked type as type_annotation @@ -531,10 +605,23 @@ Expr TypeInferencer::Infer(Expr expr) { return resolved_expr; } +struct AllCheckTypePopulated : ExprVisitor { + void VisitExpr(const Expr& e) { + if (e.as()) { return; } + if (e.as()) { return; } + CHECK(e->checked_type_.defined()) << "Expression: " << e; + return ExprVisitor::VisitExpr(e); + } +}; + +void EnsureCheckedType(const Expr& e) { + AllCheckTypePopulated().VisitExpr(e); +} Expr InferType(const Expr& expr, const Module& mod) { auto e = TypeInferencer(mod).Infer(expr); CHECK(WellFormed(e)); + EnsureCheckedType(e); return e; } diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index b99d975135bea..75f97e1a53a76 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -7,6 +7,7 @@ */ #include #include +#include #include "../ir/type_functor.h" namespace tvm { diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc index 385bde9740149..ff62847c35489 100644 --- a/tests/cpp/relay_pass_type_infer_test.cc +++ b/tests/cpp/relay_pass_type_infer_test.cc @@ -6,12 +6,14 @@ TEST(Relay, SelfReference) { using namespace tvm; - auto type_a = relay::TypeVarNode::make("a", relay::TypeVarNode::kType); - auto type_b = relay::TypeVarNode::make("b", relay::TypeVarNode::kType); + auto type_a = relay::TypeVarNode::make("a", relay::kType); + auto type_b = relay::TypeVarNode::make("b", relay::kType); auto x = relay::VarNode::make("x", type_a); auto f = relay::FunctionNode::make(tvm::Array{ x }, x, type_b, Array{}); auto fx = relay::CallNode::make(f, Array{ x }); - auto type_fx = relay::InferType(fx, relay::ModuleNode::make(Map{})); + Map func_def = {}; + Map type_def = {}; + auto type_fx = relay::InferType(fx, relay::ModuleNode::make(func_def, type_def)); CHECK_EQ(type_fx->checked_type(), type_a); } diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py new file mode 100644 index 0000000000000..2f2d1f98906e5 --- /dev/null +++ b/tests/python/relay/test_adt.py @@ -0,0 +1,138 @@ +import tvm +from tvm import relay +from tvm.relay.ir_pass import infer_type +from tvm.relay.backend.interpreter import Value, TupleValue, ConValue +from tvm.relay import testing, create_executor +from tvm.relay.prelude import Prelude + +mod = relay.Module() +p = Prelude(mod) +ctx = tvm.context("llvm", 0) +intrp = create_executor(mod=mod, ctx=ctx, target="llvm") + +z = p.z +s = p.s +nat = p.nat +double = p.double +add = p.add + +nil = p.nil +cons = p.cons +l = p.l +length = p.length +map = p.map +foldl = p.foldl +foldr = p.foldr +sum = p.sum + +tree = p.tree +rose = p.rose +tmap = p.tmap +size = p.size + +# this is an example of using the adt value in python side +def count(n): + assert isinstance(n, ConValue) + if n.con.name_hint == 's': + return 1 + count(n.fields[0]) + else: + assert n.con.name_hint == 'z' + return 0 + +# this is an example of creating the adt value in python side +def make_nat(n): + if n != 0: + return ConValue(s, [make_nat(n - 1)], []) + else: + return ConValue(z, [], []) + + +def test_nat_value(): + assert count(make_nat(10)) == 10 + + +def test_nat_constructor(): + assert relay.ir_pass.infer_type(z(), mod).checked_type == nat() + assert relay.ir_pass.infer_type(s, mod).checked_type == relay.FuncType([nat()], nat()) + assert relay.ir_pass.infer_type(s(z()), mod).checked_type == nat() + + +def test_double(): + assert mod[double].checked_type == relay.FuncType([nat()], nat()) + + +def test_add(): + assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat()) + res = intrp.evaluate(add(s(z()), s(z()))) + assert count(res) == 2 + + +def test_list_constructor(): + a = relay.TypeVar("a") + assert relay.ir_pass.infer_type(nil, mod).checked_type == relay.FuncType([], l(a), [a]) + assert relay.ir_pass.infer_type(cons(z(), nil()), mod).checked_type == l(nat()) + + +def test_length(): + a = relay.TypeVar("a") + assert mod[length].checked_type == relay.FuncType([l(a)], nat(), [a]) + + +def test_map(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + lhs = mod[map].checked_type + rhs = relay.FuncType([relay.FuncType([a], b), l(a)], l(b), [a, b]) + assert lhs == rhs + + +def test_foldl(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + lhs = mod[foldl].checked_type + rhs = relay.FuncType([relay.FuncType([a, b], a), a, l(b)], a, [a, b]) + assert lhs == rhs + + +def test_foldr(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + lhs = mod[foldr].checked_type + rhs = relay.FuncType([relay.FuncType([a, b], b), b, l(a)], b, [a, b]) + assert lhs == rhs + + +def test_sum(): + assert mod[sum].checked_type == relay.FuncType([l(nat())], nat()) + + +def test_tmap(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + # cannot infer return type of tmap! + lhs = mod[tmap].checked_type + rhs = relay.FuncType([relay.FuncType([a], b), tree(a)], tree(b), [a, b]) + # print(lhs) + # print(rhs) + # assert lhs == rhs + # this is broken, need some way to add type annotation + +def test_size(): + a = relay.TypeVar("a") + lhs = mod[size].checked_type + rhs = relay.FuncType([tree(a)], nat(), [a]) + assert lhs == rhs + + +if __name__ == "__main__": + test_nat_constructor() + test_double() + test_add() + test_list_constructor() + test_length() + test_map() + test_foldl() + test_foldr() + test_sum() + test_tmap() + test_size() diff --git a/tests/python/relay/test_typecall.py b/tests/python/relay/test_typecall.py new file mode 100644 index 0000000000000..8f556c0af67e7 --- /dev/null +++ b/tests/python/relay/test_typecall.py @@ -0,0 +1,25 @@ +from tvm import relay +from tvm.relay.ir_pass import infer_type + +def test_dup_type(): + a = relay.TypeVar("a") + av = relay.Var("av", a) + make_id = relay.Function([av], relay.Tuple([av, av]), None, [a]) + t = relay.scalar_type("float32") + b = relay.Var("b", t) + assert relay.ir_pass.infer_type(make_id(b)).checked_type == relay.TupleType([t, t]) + + +def test_id_type(): + mod = relay.Module() + id_type = relay.TypeVar("id") + a = relay.TypeVar("a") + make_id = relay.Var("make_id", relay.FuncType([a], id_type(a), [a])) + t = relay.scalar_type("float32") + b = relay.Var("b", t) + assert relay.ir_pass.infer_type(make_id(b), mod).checked_type == id_type(t) + + +if __name__ == "__main__": + test_dup_type() + test_id_type()