From 921d084c21c2abc908f7308b92cad80579000100 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 15 Feb 2019 22:37:49 -0800 Subject: [PATCH] [Relay] Algebraic data types (#2442) * First pass on ADTs * Add doc string for tag field * Visit constructors in TypeVisitor for TypeData * Add to description of type call * Add type call to type solving and unification * Make type mutator for typecall consistent with others (only create new node if there's a change) * Ensure kindchecking can handle type calls and typedata * Fix bad nesting in module constructor * Correctly construct call in typecall test * Add call override for ordinary vars (do we want this?) * Remove generalization hack from type inference because it was breaking ADT constructors * Check that there are no free type vars in exprs after inferring type * Free var checks need module because of ADT constructors * Typecall test can't have unbound type var, make it global * Uncomment tmap test and remove comments about failing to infer ret type; those work now * Put in dummy visits for ADTs in graph runtime codegen to placate pylint * Fix Relay type infer test module constructor * Mark override for TypeCallNode in type solver * Ensure free vars check treats patern vars as bound * Run interpreter in more ADT test cases * Refactor kind check to return the kind, like typechecking * Fix invalid typecall in test * Add kind check to type inference, do not use nulls in func_type_annotation()! * Redundant whitespace * Make TypeData a separate kind * Make ADT handles a separate kind too, document calling convention better * Remove nats and tree from prelude, move to test, document prelude * Restore and document nat and tree to prelude, add more tree tests * Add alpha equality tests for match cases, fix variable binding bug * Add more kind check tests for ADTs * Add more tests for finding free or bound vars in match exprs * Add unification tests for type call * Update main() for alpha equality tests * Add simple type inference test cases for match exprs and ADT constructors * Add more ADT interpreter tests * Allow incomplete types when typechecking match cases * Type inference for pattern vars should use the type annotation if it's there * Two more specific test cases for ADT matching * Add option ADT to prelude * Fix broken reference to kind enum * Fix rebase snags * Do not attach checked types to constructors * More docstrings for module fields * Use proper wrapper for indexing into module type data * checked_type for constructors is not populated * Expand type call docstring * Rename PatternConstructor con field * Use error reporter for pattern constructor case * Condense error reporting in kind check, use error reporter * Expand docstrings and rename ADT fields * Rename 'option' ADT to 'optional' for consistency with Python * Add various list iterators and utility functions to prelude * Add smoke tests for new iterators in prelude * Add concat to prelude * Add smoke test for concat * Correct docstrings in prelude * Ensure that type defs are written in module initialization * Various requested renamings * Correct rebase snags * Add kind check tests for ref types * Update the main() for kind checking tests --- include/tvm/relay/adt.h | 244 +++++++ include/tvm/relay/expr_functor.h | 14 + include/tvm/relay/interpreter.h | 22 + include/tvm/relay/module.h | 45 +- include/tvm/relay/op.h | 4 +- include/tvm/relay/pass.h | 22 +- include/tvm/relay/pattern_functor.h | 143 +++++ include/tvm/relay/type.h | 81 ++- python/tvm/relay/__init__.py | 13 + python/tvm/relay/adt.py | 187 ++++++ .../relay/backend/graph_runtime_codegen.py | 6 + python/tvm/relay/backend/interpreter.py | 7 + python/tvm/relay/expr.py | 14 + python/tvm/relay/expr_functor.py | 12 + python/tvm/relay/ir_pass.py | 32 +- python/tvm/relay/module.py | 72 ++- python/tvm/relay/prelude.py | 379 +++++++++++ python/tvm/relay/ty.py | 63 ++ src/relay/backend/interpreter.cc | 71 ++- src/relay/ir/adt.cc | 162 +++++ src/relay/ir/alpha_equal.cc | 89 ++- src/relay/ir/expr.cc | 9 +- src/relay/ir/expr_functor.cc | 39 ++ src/relay/ir/hash.cc | 78 ++- src/relay/ir/module.cc | 86 ++- src/relay/ir/pattern_functor.cc | 75 +++ src/relay/ir/text_printer.cc | 53 ++ src/relay/ir/type.cc | 51 +- src/relay/ir/type_functor.cc | 41 ++ src/relay/ir/type_functor.h | 14 + src/relay/pass/fuse_ops.cc | 9 + src/relay/pass/kind_check.cc | 154 +++-- src/relay/pass/let_list.h | 2 +- src/relay/pass/to_anf.cc | 2 +- src/relay/pass/type_infer.cc | 181 ++++-- src/relay/pass/type_solver.cc | 26 +- src/relay/pass/util.cc | 65 +- tests/cpp/relay_pass_type_infer_test.cc | 5 +- tests/python/relay/test_adt.py | 600 ++++++++++++++++++ tests/python/relay/test_pass_alpha_equal.py | 99 +++ tests/python/relay/test_pass_check_kind.py | 129 +++- tests/python/relay/test_pass_vars.py | 36 +- tests/python/relay/test_type_infer.py | 76 +++ tests/python/relay/test_type_solver.py | 65 ++ tests/python/relay/test_typecall.py | 28 + 45 files changed, 3398 insertions(+), 207 deletions(-) create mode 100644 include/tvm/relay/adt.h create mode 100644 include/tvm/relay/pattern_functor.h create mode 100644 python/tvm/relay/adt.py create mode 100644 python/tvm/relay/prelude.py create mode 100644 src/relay/ir/adt.cc create mode 100644 src/relay/ir/pattern_functor.cc create mode 100644 tests/python/relay/test_adt.py create mode 100644 tests/python/relay/test_typecall.py diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h new file mode 100644 index 000000000000..07c05e89aa86 --- /dev/null +++ b/include/tvm/relay/adt.h @@ -0,0 +1,244 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/adt.h + * \brief Algebraic data types for Relay + */ +#ifndef TVM_RELAY_ADT_H_ +#define TVM_RELAY_ADT_H_ + +#include +#include +#include +#include "./base.h" +#include "./type.h" +#include "./expr.h" + +namespace tvm { +namespace relay { + +/*! \brief Base type for declaring relay pattern. */ +class PatternNode : public RelayNode { + public: + static constexpr const char* _type_key = "relay.Pattern"; + TVM_DECLARE_BASE_NODE_INFO(PatternNode, Node); +}; + +/*! + * \brief Pattern is the base type for an ADT match pattern in Relay. + * + * Given an ADT value, a pattern might accept it and bind the pattern variable to some value + * (typically a subnode of the input or the input). Otherwise, the pattern rejects the value. + * + * ADT pattern matching thus takes a list of values and binds to the first that accepts the value. + */ +class Pattern : public NodeRef { + public: + Pattern() {} + explicit Pattern(NodePtr p) : NodeRef(p) {} + + using ContainerType = PatternNode; +}; + +/*! \brief A wildcard pattern: Accepts all input and binds nothing. */ +class PatternWildcard; +/*! \brief PatternWildcard container node */ +class PatternWildcardNode : public PatternNode { + public: + PatternWildcardNode() {} + + TVM_DLL static PatternWildcard make(); + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("span", &span); + } + + static constexpr const char* _type_key = "relay.PatternWildcard"; + TVM_DECLARE_NODE_TYPE_INFO(PatternWildcardNode, PatternNode); +}; + +RELAY_DEFINE_NODE_REF(PatternWildcard, PatternWildcardNode, Pattern); + +/*! \brief A var pattern. Accept all input and bind to a var. */ +class PatternVar; +/*! \brief PatternVar container node */ +class PatternVarNode : public PatternNode { + public: + PatternVarNode() {} + + /*! \brief Variable that stores the matched value. */ + tvm::relay::Var var; + + TVM_DLL static PatternVar make(tvm::relay::Var var); + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("var", &var); + v->Visit("span", &span); + } + + static constexpr const char* _type_key = "relay.PatternVar"; + TVM_DECLARE_NODE_TYPE_INFO(PatternVarNode, PatternNode); +}; + +RELAY_DEFINE_NODE_REF(PatternVar, PatternVarNode, Pattern); + +/*! + * \brief ADT constructor. + * Constructors compare by pointer equality. + */ +class Constructor; +/*! \brief Constructor container node. */ +class ConstructorNode : public ExprNode { + public: + /*! \brief The name (only a hint) */ + std::string name_hint; + /*! \brief Input to the constructor. */ + tvm::Array inputs; + /*! \brief The datatype the constructor will construct. */ + GlobalTypeVar belong_to; + /*! \brief Index in the table of constructors (set when the type is registered). */ + mutable int tag = -1; + + ConstructorNode() {} + + TVM_DLL static Constructor make(std::string name_hint, + tvm::Array inputs, + GlobalTypeVar belong_to); + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("name_hint", &name_hint); + v->Visit("inputs", &inputs); + v->Visit("belong_to", &belong_to); + v->Visit("tag", &tag); + v->Visit("span", &span); + v->Visit("_checked_type_", &checked_type_); + } + + static constexpr const char* _type_key = "relay.Constructor"; + TVM_DECLARE_NODE_TYPE_INFO(ConstructorNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(Constructor, ConstructorNode, Expr); + +/*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */ +class PatternConstructor; +/*! \brief PatternVar container node */ +class PatternConstructorNode : public PatternNode { + public: + /*! Constructor matched by the pattern. */ + Constructor constructor; + /*! Sub-patterns to match against each input to the constructor. */ + tvm::Array patterns; + + PatternConstructorNode() {} + + TVM_DLL static PatternConstructor make(Constructor constructor, tvm::Array var); + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("constructor", &constructor); + v->Visit("patterns", &patterns); + v->Visit("span", &span); + } + + static constexpr const char* _type_key = "relay.PatternConstructor"; + TVM_DECLARE_NODE_TYPE_INFO(PatternConstructorNode, PatternNode); +}; + +RELAY_DEFINE_NODE_REF(PatternConstructor, PatternConstructorNode, Pattern); + +/*! + * \brief Stores all data for an Algebraic Data Type (ADT). + * + * In particular, it stores the handle (global type var) for an ADT + * and the constructors used to build it and is kept in the module. Note + * that type parameters are also indicated in the type data: this means that + * for any instance of an ADT, the type parameters must be indicated. That is, + * an ADT definition is treated as a type-level function, so an ADT handle + * must be wrapped in a TypeCall node that instantiates the type-level arguments. + * The kind checker enforces this. + */ +class TypeData; +/*! \brief TypeData container node */ +class TypeDataNode : public TypeNode { + public: + /*! + * \brief The header is simply the name of the ADT. + * We adopt nominal typing for ADT definitions; + * that is, differently-named ADT definitions with same constructors + * have different types. + */ + GlobalTypeVar header; + /*! \brief The type variables (to allow for polymorphism). */ + tvm::Array type_vars; + /*! \brief The constructors. */ + tvm::Array constructors; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("header", &header); + v->Visit("type_vars", &type_vars); + v->Visit("constructors", &constructors); + v->Visit("span", &span); + } + + TVM_DLL static TypeData make(GlobalTypeVar header, + tvm::Array type_vars, + tvm::Array constructors); + + static constexpr const char* _type_key = "relay.TypeData"; + TVM_DECLARE_NODE_TYPE_INFO(TypeDataNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(TypeData, TypeDataNode, Type); + +/*! \brief A clause in a match expression. */ +class Clause; +/*! \brief Clause container node. */ +class ClauseNode : public Node { + public: + /*! \brief The pattern the clause matches. */ + Pattern lhs; + /*! \brief The resulting value. */ + Expr rhs; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("lhs", &lhs); + v->Visit("rhs", &rhs); + } + + TVM_DLL static Clause make(Pattern lhs, Expr rhs); + + static constexpr const char* _type_key = "relay.Clause"; + TVM_DECLARE_NODE_TYPE_INFO(ClauseNode, Node); +}; + +RELAY_DEFINE_NODE_REF(Clause, ClauseNode, NodeRef); + +/*! \brief ADT pattern matching exression. */ +class Match; +/*! \brief Match container node. */ +class MatchNode : public ExprNode { + public: + /*! \brief The input being deconstructed. */ + Expr data; + + /*! \brief The match node clauses. */ + tvm::Array clauses; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("data", &data); + v->Visit("clause", &clauses); + v->Visit("span", &span); + v->Visit("_checked_type_", &checked_type_); + } + + TVM_DLL static Match make(Expr data, tvm::Array pattern); + + static constexpr const char* _type_key = "relay.Match"; + TVM_DECLARE_NODE_TYPE_INFO(MatchNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(Match, MatchNode, Expr); + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_ADT_H_ diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index e7b66bc1bbde..fd68139495b4 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -10,6 +10,7 @@ #include #include #include "./expr.h" +#include "./adt.h" #include "./op.h" #include "./error.h" @@ -92,6 +93,8 @@ class ExprFunctor { virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Node* op, Args...) { throw Error(std::string("Do not have a default for ") + op->type_key()); } @@ -114,6 +117,8 @@ class ExprFunctor { RELAY_EXPR_FUNCTOR_DISPATCH(RefCreateNode); RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode); RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode); + RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode); + RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode); return vtable; } }; @@ -142,7 +147,11 @@ class ExprVisitor void VisitExpr_(const RefCreateNode* op) override; void VisitExpr_(const RefReadNode* op) override; void VisitExpr_(const RefWriteNode* op) override; + void VisitExpr_(const ConstructorNode* op) override; + void VisitExpr_(const MatchNode* op) override; virtual void VisitType(const Type& t); + virtual void VisitClause(const Clause& c); + virtual void VisitPattern(const Pattern& c); protected: // Internal visiting counter @@ -180,6 +189,9 @@ class ExprMutator Expr VisitExpr_(const RefCreateNode* op) override; Expr VisitExpr_(const RefReadNode* op) override; Expr VisitExpr_(const RefWriteNode* op) override; + Expr VisitExpr_(const ConstructorNode* op) override; + Expr VisitExpr_(const MatchNode* op) override; + /*! * \brief Used to visit the types inside of expressions. * @@ -188,6 +200,8 @@ class ExprMutator * visitor for types which transform them appropriately. */ virtual Type VisitType(const Type& t); + virtual Clause VisitClause(const Clause& c); + virtual Pattern VisitPattern(const Pattern& c); protected: /*! \brief Internal map used for memoization. */ diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index 08aeef1827b6..42f0d4e9b0a5 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -160,6 +160,28 @@ struct RefValueNode : ValueNode { RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value); +/*! \brief An ADT constructor value. */ +class ConstructorValue; + +struct ConstructorValueNode : ValueNode { + Constructor constructor; + + tvm::Array fields; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("constructor", &constructor); + v->Visit("fields", &fields); + } + + TVM_DLL static ConstructorValue make(Constructor constructor, + tvm::Array fields); + + static constexpr const char* _type_key = "relay.ConstructorValue"; + TVM_DECLARE_NODE_TYPE_INFO(ConstructorValueNode, ValueNode); +}; + +RELAY_DEFINE_NODE_REF(ConstructorValue, ConstructorValueNode, Value); + } // namespace relay } // namespace tvm #endif // TVM_RELAY_INTERPRETER_H_ diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 45ccfe3a8089..6de3b22f6566 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -35,13 +36,15 @@ struct Module; * * The functional style allows users to construct custom * environments easily, for example each thread can store - * an Module while auto-tuning. + * a Module while auto-tuning. * */ class ModuleNode : public RelayNode { public: /*! \brief A map from ids to all global functions. */ tvm::Map functions; + /*! \brief A map from global type vars to ADT type data. */ + tvm::Map type_definitions; /*! \brief The entry function (i.e. "main"). */ GlobalVar entry_func; @@ -50,21 +53,31 @@ class ModuleNode : public RelayNode { void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("functions", &functions); + v->Visit("type_definitions", &type_definitions); v->Visit("global_var_map_", &global_var_map_); v->Visit("entry_func", &entry_func); + v->Visit("global_type_var_map_", &global_type_var_map_); } - TVM_DLL static Module make(tvm::Map global_funcs); + TVM_DLL static Module make(tvm::Map global_funcs, + tvm::Map global_type_defs); /*! * \brief Add a function to the global environment. - * \param var The name of the global function. + * \param var The var of the global function. * \param func The function. * \param update Controls whether you can replace a definition in the * environment. */ void Add(const GlobalVar& var, const Function& func, bool update = false); + /*! + * \brief Add a type-level definition to the global environment. + * \param var The var of the global type definition. + * \param type The type definition. + */ + void AddDef(const GlobalTypeVar& var, const TypeData& type); + /*! * \brief Add a function to the global environment. * \param var The name of the global function. @@ -94,6 +107,13 @@ class ModuleNode : public RelayNode { */ GlobalVar GetGlobalVar(const std::string& str); + /*! + * \brief Look up a global function by its name. + * \param str The unique string specifying the global variable. + * \returns The global variable. + */ + GlobalTypeVar GetGlobalTypeVar(const std::string& str); + /*! * \brief Lookup a global function by its variable. * \param var The global var to lookup. @@ -108,6 +128,20 @@ class ModuleNode : public RelayNode { */ Function Lookup(const std::string& name); + /*! + * \brief Lookup a global type definition by its variable. + * \param var The var of the global type definition. + * \return The type definition. + */ + TypeData LookupDef(const GlobalTypeVar& var); + + /*! + * \brief Lookup a global type definition by its name. + * \param var The name of the global type definition. + * \return The type definition. + */ + TypeData LookupDef(const std::string& var); + /*! * \brief Update the functions inside this environment by * functions in another environment. @@ -137,6 +171,11 @@ class ModuleNode : public RelayNode { * ensures global uniqueness. */ tvm::Map global_var_map_; + + /*! \brief A map from string names to global type variables (ADT names) + * that ensures global uniqueness. + */ + tvm::Map global_type_var_map_; }; struct Module : public NodeRef { diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 0fd54ff5b8fa..583491ca2613 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -422,7 +422,7 @@ inline OpRegistry& OpRegistry::add_type_rel( std::string input_name_prefix = "in"; for (int i = 0; i < get()->num_inputs; i++) { auto name = input_name_prefix + std::to_string(i); - auto param = TypeVarNode::make(name, TypeVarNode::Kind::kType); + auto param = TypeVarNode::make(name, Kind::kType); type_params.push_back(param); arg_types.push_back(param); } @@ -430,7 +430,7 @@ inline OpRegistry& OpRegistry::add_type_rel( Array ty_call_args = arg_types; // Add output type. - auto out_param = TypeVarNode::make("out", TypeVarNode::Kind::kType); + auto out_param = TypeVarNode::make("out", Kind::kType); type_params.push_back(out_param); // this will trigger copy on write. ty_call_args.push_back(out_param); diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 1558e65a6b36..b87f9319a3d3 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -56,9 +56,9 @@ TVM_DLL Function InferType(const Function& f, const Module& mod, * \param t The type to check. * \param mod The global module. * - * \return true if the rules are satisified otherwise false + * \return The kind of the passed type. */ -TVM_DLL bool KindCheck(const Type& t, const Module& mod); +TVM_DLL Kind KindCheck(const Type& t, const Module& mod); /*! \brief Compare two expressions for structural equivalence. * @@ -144,10 +144,11 @@ TVM_DLL tvm::Array AllVars(const Expr& expr); * type in the context. * * \param expr the expression. + * \param mod the module. * * \return List of free vars, in the PostDFS order visited by expr. */ -TVM_DLL tvm::Array FreeTypeVars(const Expr& expr); +TVM_DLL tvm::Array FreeTypeVars(const Expr& expr, const Module& mod); /*! \brief Get free TypeVars from type t. * @@ -155,10 +156,11 @@ TVM_DLL tvm::Array FreeTypeVars(const Expr& expr); * type in the context. * * \param t the type. + * \param mod the module. * * \return List of free type vars, in the PostDFS order visited by type. */ -TVM_DLL tvm::Array FreeTypeVars(const Type& t); +TVM_DLL tvm::Array FreeTypeVars(const Type& t, const Module& mod); /*! \brief Get all bound type variables from expression expr. * @@ -166,10 +168,11 @@ TVM_DLL tvm::Array FreeTypeVars(const Type& t); * They only have meaning inside that expr, and can only be used in it. * * \param expr the expression. + * \param mod the module. * * \return List of bound type vars, in the PostDFS order in the expression. */ -TVM_DLL tvm::Array BoundTypeVars(const Expr& expr); +TVM_DLL tvm::Array BoundTypeVars(const Expr& expr, const Module& mod); /*! \brief Get all bound type variables from type t. * @@ -177,26 +180,29 @@ TVM_DLL tvm::Array BoundTypeVars(const Expr& expr); * They only have meaning inside that type, and can only be used in it. * * \param t the type + * \param mod the module. * * \return List of bound type vars, in the PostDFS order visited by type. */ -TVM_DLL tvm::Array BoundTypeVars(const Type& t); +TVM_DLL tvm::Array BoundTypeVars(const Type& t, const Module& mod); /*! \brief Get all type variables in expression expr. * * \param expr the expression. + * \param mod the module. * * \return List of type vars, in the PostDFS order in the expression. */ -TVM_DLL tvm::Array AllTypeVars(const Expr& expr); +TVM_DLL tvm::Array AllTypeVars(const Expr& expr, const Module& mod); /*! \brief Get all type variables in type t. * * \param t the type. + * \param mod the module. * * \return List of type vars, in the PostDFS order visited by type. */ -TVM_DLL tvm::Array AllTypeVars(const Type& t); +TVM_DLL tvm::Array AllTypeVars(const Type& t, const Module& mod); /*! \brief Remove expressions which does not effect the program result. * diff --git a/include/tvm/relay/pattern_functor.h b/include/tvm/relay/pattern_functor.h new file mode 100644 index 000000000000..f9833201ea0a --- /dev/null +++ b/include/tvm/relay/pattern_functor.h @@ -0,0 +1,143 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/pattern_functor.h + * \brief A more powerful visitor on ADT patterns that enables defining + * arbitrary function signatures with type-based dispatch on first argument. + */ +#ifndef TVM_RELAY_PATTERN_FUNCTOR_H_ +#define TVM_RELAY_PATTERN_FUNCTOR_H_ + +#include +#include +#include "./expr.h" +#include "./op.h" +#include "./error.h" +#include "./adt.h" + +namespace tvm { +namespace relay { + +/*! + * \brief A dynamical functor on ADT patterns that dispatches on its first argument. + * You can use this as a more powerful visitor, since it allows you to + * define the types of further arguments to VisitPattern. + * + * \sa tvm/ir_functor.h + * + * \tparam FType function signiture + * This type is only defined for FType with function signature R(const Pattern&, + * Args...) + */ +template +class PatternFunctor; + +// functions to be overriden. +#define PATTERN_FUNCTOR_DEFAULT \ + { return VisitPatternDefault_(op, std::forward(args)...); } + +#define RELAY_PATTERN_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch( \ + [](const NodeRef& n, TSelf* self, Args... args) { \ + return self->VisitPattern_(static_cast(n.node_.get()), \ + std::forward(args)...); \ + }); + +template +class PatternFunctor { + private: + using TSelf = PatternFunctor; + using FType = tvm::IRFunctor; + + public: + /*! \brief the result type of this functor */ + using result_type = R; + /*! \brief virtual destructor */ + virtual ~PatternFunctor() {} + /*! + * \brief Same as call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + R operator()(const Pattern& n, Args... args) { + return VisitPattern(n, std::forward(args)...); + } + /*! + * \brief The functor call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + virtual R VisitPattern(const Pattern& n, Args... args) { + static FType vtable = InitVTable(); + return vtable(n, this, std::forward(args)...); + } + // Functions that can be overriden by subclass + virtual R VisitPattern_(const PatternWildcardNode* op, + Args... args) PATTERN_FUNCTOR_DEFAULT; + virtual R VisitPattern_(const PatternVarNode* op, + Args... args) PATTERN_FUNCTOR_DEFAULT; + virtual R VisitPattern_(const PatternConstructorNode* op, + Args... args) PATTERN_FUNCTOR_DEFAULT; + virtual R VisitPatternDefault_(const Node* op, Args...) { + throw Error(std::string("Do not have a default for ") + op->type_key()); + } + + private: + // initialize the vtable. + static FType InitVTable() { + FType vtable; + // Set dispatch + RELAY_PATTERN_FUNCTOR_DISPATCH(PatternWildcardNode); + RELAY_PATTERN_FUNCTOR_DISPATCH(PatternVarNode); + RELAY_PATTERN_FUNCTOR_DISPATCH(PatternConstructorNode); + return vtable; + } +}; + +/*! \brief A simple visitor wrapper around PatternFunctor. + * + * Exposes two visitors with default traversal strategies, one + * which doesn't compute a result but can mutate internal state, + * and another which functionally builds a new pattern. + */ +class PatternVisitor : public ::tvm::relay::PatternFunctor { + public: + void VisitPattern_(const PatternWildcardNode* op) override; + void VisitPattern_(const PatternVarNode* op) override; + void VisitPattern_(const PatternConstructorNode* op) override; + virtual void VisitType(const Type& t); + virtual void VisitVar(const Var& v); + virtual void VisitConstructor(const Constructor& c); +}; + +/*! \brief A wrapper around ExprFunctor which functionally updates the AST. + * + * ExprMutator uses memoization and self return in order to amortize + * the cost of using functional updates. + */ +class PatternMutator + : public ::tvm::relay::PatternFunctor { + public: + Pattern Mutate(const Pattern& pat); + Pattern VisitPattern_(const PatternWildcardNode* op) override; + Pattern VisitPattern_(const PatternVarNode* op) override; + Pattern VisitPattern_(const PatternConstructorNode* op) override; + /*! \brief Used to visit the types inside of patterns. + * + * Can be overloaded to transform the types in arbitrary + * ways, one way would be to define a sub-class of type + * visitor for types which transform them appropriately. + */ + virtual Type VisitType(const Type& t); + /*! \brief Used to visit the vars inside of patterns. */ + virtual Var VisitVar(const Var& v); + /*! \brief Used to visit the vars inside of patterns. */ + virtual Constructor VisitConstructor(const Constructor& c); + private: + std::unordered_map var_map_; +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_PATTERN_FUNCTOR_H_ diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 0ee265e5f3b0..6c164ab6bcea 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -98,6 +98,18 @@ class TensorTypeNode : public BaseTensorTypeNode { RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type); +/*! \brief possible kinds of Type */ +enum Kind : int { + /*! \brief template variable in shape expression */ + kType = 0, + kShapeVar = 1, + kBaseType = 2, + kShape = 3, + kConstraint = 4, + kAdtHandle = 5, + kTypeData = 6 +}; + /*! * \brief Type parameter in the function. * This can be viewed as template parameter in c++ template function. @@ -119,14 +131,6 @@ class TypeVar; /*! \brief TypeVar container node */ class TypeVarNode : public TypeNode { public: - /*! \brief possible kinds of TypeVar */ - enum Kind : int { - /*! \brief template variable in shape expression */ - kType = 0, - kShapeVar = 1, - kBaseType = 2, - kShape = 3 - }; /*! * \brief The variable itself is only meaningful when * kind is ShapeVar, otherwise, we only use the name. @@ -149,6 +153,63 @@ class TypeVarNode : public TypeNode { RELAY_DEFINE_NODE_REF(TypeVar, TypeVarNode, Type); +/*! + * \brief A global type variable that is used for defining new types or type aliases. + */ +class GlobalTypeVar; +/*! \brief GlobalTypeVar container node */ +class GlobalTypeVarNode : public TypeNode { + public: + /*! + * \brief The variable itself is only meaningful when + * kind is ShapeVar; otherwise, we only use the name. + */ + tvm::Var var; + /*! \brief The kind of type parameter */ + Kind kind; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("var", &var); + v->Visit("kind", &kind); + v->Visit("span", &span); + } + + TVM_DLL static GlobalTypeVar make(std::string name, Kind kind); + + static constexpr const char* _type_key = "relay.GlobalTypeVar"; + TVM_DECLARE_NODE_TYPE_INFO(GlobalTypeVarNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(GlobalTypeVar, GlobalTypeVarNode, Type); + +/*! + * \brief Type application. + */ +class TypeCall; +/*! \brief TypeCall container node */ +class TypeCallNode : public TypeNode { + public: + /*! + * \brief The type-level function (ADT that takes type params). + */ + Type func; + /*! \brief The arguments. */ + tvm::Array args; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("func", &func); + v->Visit("args", &args); + v->Visit("span", &span); + } + + TVM_DLL static TypeCall make(Type func, tvm::Array args); + + static constexpr const char* _type_key = "relay.TypeCall"; + TVM_DECLARE_NODE_TYPE_INFO(TypeCallNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(TypeCall, TypeCallNode, Type); + /*! * \brief IncompleteType. * This is intermediate values that is used during type inference. @@ -162,14 +223,14 @@ class IncompleteType; /*! \brief IncompleteType container node */ class IncompleteTypeNode : public TypeNode { public: - TypeVarNode::Kind kind; + Kind kind; void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("kind", &kind); v->Visit("span", &span); } - TVM_DLL static IncompleteType make(TypeVarNode::Kind kind); + TVM_DLL static IncompleteType make(Kind kind); static constexpr const char* _type_key = "relay.IncompleteType"; TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode); diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 0af164bc7a73..fe00877c0fb0 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -7,8 +7,10 @@ from . import expr from . import expr_functor from . import module +from . import adt from . import ir_pass from .build_module import build, build_config, create_executor, optimize +from . import prelude from . import parser from . import debug @@ -45,6 +47,8 @@ IncompleteType = ty.IncompleteType scalar_type = ty.scalar_type RefType = ty.RefType +GlobalTypeVar = ty.GlobalTypeVar +TypeCall = ty.TypeCall # Expr Expr = expr.Expr @@ -61,6 +65,15 @@ RefRead = expr.RefRead RefWrite = expr.RefWrite +# ADT +PatternWildcard = adt.PatternWildcard +PatternVar = adt.PatternVar +PatternConstructor = adt.PatternConstructor +Constructor = adt.Constructor +TypeData = adt.TypeData +Clause = adt.Clause +Match = adt.Match + # helper functions var = expr.var const = expr.const diff --git a/python/tvm/relay/adt.py b/python/tvm/relay/adt.py new file mode 100644 index 000000000000..bc516a8f3ddb --- /dev/null +++ b/python/tvm/relay/adt.py @@ -0,0 +1,187 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +"""Algebraic data types in Relay.""" +from .base import RelayNode, register_relay_node, NodeBase +from . import _make +from .ty import Type +from .expr import Expr, Call + + +class Pattern(RelayNode): + """Base type for pattern matching constructs.""" + pass + +@register_relay_node +class PatternWildcard(Pattern): + """Wildcard pattern in Relay: Matches any ADT and binds nothing.""" + + def __init__(self): + """Constructs a wildcard pattern. + + Parameters + ---------- + None + + Returns + ------- + wildcard: PatternWildcard + a wildcard pattern. + """ + self.__init_handle_by_constructor__(_make.PatternWildcard) + + +@register_relay_node +class PatternVar(Pattern): + """Variable pattern in Relay: Matches anything and binds it to the variable.""" + + def __init__(self, var): + """Construct a variable pattern. + + Parameters + ---------- + var: tvm.relay.Var + + Returns + ------- + pv: PatternVar + A variable pattern. + """ + self.__init_handle_by_constructor__(_make.PatternVar, var) + + +@register_relay_node +class PatternConstructor(Pattern): + """Constructor pattern in Relay: Matches an ADT of the given constructor, binds recursively.""" + + def __init__(self, constructor, patterns=None): + """Construct a constructor pattern. + + Parameters + ---------- + constructor: Constructor + The constructor. + patterns: Optional[List[Pattern]] + Optional subpatterns: for each field of the constructor, + match to the given subpattern (treated as a variable pattern by default). + + Returns + ------- + wildcard: PatternWildcard + a wildcard pattern. + """ + if patterns is None: + patterns = [] + self.__init_handle_by_constructor__(_make.PatternConstructor, constructor, patterns) + + +@register_relay_node +class Constructor(Expr): + """Relay ADT constructor.""" + + def __init__(self, name_hint, inputs, belong_to): + """Defines an ADT constructor. + + Parameters + ---------- + name_hint : str + Name of constructor (only a hint). + inputs : List[Type] + Input types. + belong_to : tvm.relay.GlobalTypeVar + Denotes which ADT the constructor belongs to. + + Returns + ------- + con: Constructor + A constructor. + """ + self.__init_handle_by_constructor__(_make.Constructor, name_hint, inputs, belong_to) + + def __call__(self, *args): + """Call the constructor. + + Parameters + ---------- + args: List[relay.Expr] + The arguments to the constructor. + + Returns + ------- + call: relay.Call + A call to the constructor. + """ + return Call(self, args) + + +@register_relay_node +class TypeData(Type): + """Stores the definition for an Algebraic Data Type (ADT) in Relay. + + Note that ADT definitions are treated as type-level functions because + the type parameters need to be given for an instance of the ADT. Thus, + any global type var that is an ADT header needs to be wrapped in a + type call that passes in the type params. + """ + + def __init__(self, header, type_vars, constructors): + """Defines a TypeData object. + + Parameters + ---------- + header: tvm.relay.GlobalTypeVar + The name of the ADT. + ADTs with the same constructors but different names are + treated as different types. + type_vars: List[TypeVar] + Type variables that appear in constructors. + constructors: List[tvm.relay.Constructor] + The constructors for the ADT. + + Returns + ------- + type_data: TypeData + The adt declaration. + """ + self.__init_handle_by_constructor__(_make.TypeData, header, type_vars, constructors) + + +@register_relay_node +class Clause(NodeBase): + """Clause for pattern matching in Relay.""" + + def __init__(self, lhs, rhs): + """Construct a clause. + + Parameters + ---------- + lhs: tvm.relay.Pattern + Left-hand side of match clause. + rhs: tvm.relay.Expr + Right-hand side of match clause. + + Returns + ------- + clause: Clause + The Clause. + """ + self.__init_handle_by_constructor__(_make.Clause, lhs, rhs) + + +@register_relay_node +class Match(Expr): + """Pattern matching expression in Relay.""" + + def __init__(self, data, clauses): + """Construct a Match. + + Parameters + ---------- + data: tvm.relay.Expr + The value being deconstructed and matched. + clauses: List[tvm.relay.Clause] + The pattern match clauses. + Returns + ------- + match: tvm.relay.Expr + The match expression. + """ + self.__init_handle_by_constructor__(_make.Match, data, clauses) diff --git a/python/tvm/relay/backend/graph_runtime_codegen.py b/python/tvm/relay/backend/graph_runtime_codegen.py index cc510b2290cf..fba4d11aaf72 100644 --- a/python/tvm/relay/backend/graph_runtime_codegen.py +++ b/python/tvm/relay/backend/graph_runtime_codegen.py @@ -292,6 +292,12 @@ def visit_ref_read(self, _): def visit_ref_write(self, _): raise RuntimeError("reference not supported") + def visit_constructor(self, _): + raise Exception("ADT constructor case not yet implemented") + + def visit_match(self, _): + raise Exception("match case not yet implemented") + def _get_json(self): """ Convert the sequence of nodes stored by the compiler into the diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index b21eab185c28..1d50a571a460 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -52,6 +52,13 @@ class Closure(Value): pass +@register_relay_node +class ConstructorValue(Value): + def __init__(self, constructor, fields, types): + self.__init_handle_by_constructor__( + _make.ConstructorValue, constructor, fields, types) + + @register_relay_node class TensorValue(Value): """A Tensor value produced by the interpreter.""" diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 71b89d0b4777..9257bad7dd58 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -172,6 +172,20 @@ def name_hint(self): name = self.vid.name_hint return name + def __call__(self, *args): + """Call the variable (if it represents a function). + + Parameters + ---------- + args: List[relay.Expr] + The arguments to the call. + + Returns + ------- + call: Call + A call taking the variable as a function. + """ + return Call(self, args) @register_relay_node class GlobalVar(Expr): diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py index b22a4e7562e2..199d66baa45a 100644 --- a/python/tvm/relay/expr_functor.py +++ b/python/tvm/relay/expr_functor.py @@ -2,6 +2,7 @@ """The expression functor of Relay.""" from .expr import Function, Call, Let, Var, GlobalVar, If, Tuple, TupleGetItem, Constant +from .adt import Constructor, Match, Clause from .op import Op class ExprFunctor: @@ -47,6 +48,10 @@ def visit(self, expr): res = self.visit_ref_read(expr) elif isinstance(expr, RefWrite): res = self.visit_ref_write(expr) + elif isinstance(expr, Constructor): + res = self.visit_constructor(expr) + elif isinstance(expr, Match): + res = self.visit_match(expr) else: raise Exception("warning unhandled case: {0}".format(type(expr))) @@ -96,6 +101,13 @@ def visit_ref_write(self, _): def visit_ref_read(self, _): raise NotImplementedError() + def visit_constructor(self, _): + raise NotImplementedError() + + def visit_match(self, _): + raise NotImplementedError() + + class ExprMutator(ExprFunctor): """ A functional visitor over Expr. diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index b27f030e459a..90d038ebc784 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -9,6 +9,7 @@ from . import _make from .expr import Expr from .ty import Type +from .module import Module def post_order_visit(expr, fvisit): """Recursively visit the ir in post DFS order node, @@ -107,7 +108,7 @@ def well_formed(expr): def check_kind(t, mod=None): - """Check that the type is well kinded. + """Check that the type is well kinded and return the kind. For example, this mean type cannot has tensor of tensor, or is a tuple type of 2 shapes. Parameters @@ -120,15 +121,15 @@ def check_kind(t, mod=None): Returns ------- - well_kinded : bool - whether the input type is well kinded. + kind : Kind + the kind of t Examples -------- .. code:: python - assert not check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Shape)])) - assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)])) + assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Shape)])) == Shape + assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)])) == Type """ if mod is not None: return _ir_pass.check_kind(t, mod) @@ -190,52 +191,61 @@ def all_vars(expr): return _ir_pass.all_vars(expr) -def free_type_vars(expr): +def free_type_vars(expr, mod=None): """Get free type variables from expression/type e Parameters ---------- expr: Union[tvm.relay.Expr,tvm.relay.Type] The input expression/type + mod: tvm.relay.Module, optional + The global module Returns ------- free : List[tvm.relay.TypeVar] The list of free type variables in post-DFS order """ - return _ir_pass.free_type_vars(expr) + use_mod = mod if mod is not None else Module() + return _ir_pass.free_type_vars(expr, use_mod) -def bound_type_vars(expr): +def bound_type_vars(expr, mod=None): """Get bound type variables from expression/type e Parameters ---------- expr: Union[tvm.relay.Expr,tvm.relay.Type] The input expression/type + mod: tvm.relay.Module, optional + The global module Returns ------- free : List[tvm.relay.TypeVar] The list of bound type variables in post-DFS order """ - return _ir_pass.bound_type_vars(expr) + use_mod = mod if mod is not None else Module() + return _ir_pass.bound_type_vars(expr, use_mod) -def all_type_vars(expr): +def all_type_vars(expr, mod=None): """Get all type variables from expression/type e Parameters ---------- expr: Union[tvm.relay.Expr,tvm.relay.Type] The input expression/type + mod: tvm.relay.Module, optional + The global module Returns ------- free : List[tvm.relay.TypeVar] The list of all type variables in post-DFS order """ - return _ir_pass.all_type_vars(expr) + use_mod = mod if mod is not None else Module() + return _ir_pass.all_type_vars(expr, use_mod) def simplify_inference(expr): diff --git a/python/tvm/relay/module.py b/python/tvm/relay/module.py index 024c6baf7012..ef496333d828 100644 --- a/python/tvm/relay/module.py +++ b/python/tvm/relay/module.py @@ -6,6 +6,7 @@ from . import _module from . import expr as _expr +from . import ty as _ty @register_relay_node class Module(RelayNode): @@ -20,7 +21,7 @@ class Module(RelayNode): functions : dict, optional. Map of global var to Function """ - def __init__(self, functions=None): + def __init__(self, functions=None, type_definitions=None): if functions is None: functions = {} elif isinstance(functions, dict): @@ -32,28 +33,46 @@ def __init__(self, functions=None): raise TypeError("Expect functions to be Dict[GlobalVar, Function]") mapped_funcs[k] = v functions = mapped_funcs - self.__init_handle_by_constructor__(_make.Module, functions) + if type_definitions is None: + type_definitions = {} + elif isinstance(type_definitions, dict): + mapped_type_defs = {} + for k, v in type_definitions.items(): + if isinstance(k, _base.string_types): + k = _ty.GlobalTypeVar(k) + if not isinstance(k, _ty.GlobalTypeVar): + raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]") + mapped_type_defs[k] = v + type_definitions = mapped_type_defs + self.__init_handle_by_constructor__(_make.Module, functions, type_definitions) + - def __setitem__(self, var, func): - """Add a function to the module. + def __setitem__(self, var, val): + """Add a mapping to the module. Parameters --------- var: GlobalVar - The global variable which names the function. + The global variable. - func: Function - The function. + val: Union[Function, Type] + The value. """ - return self._add(var, func) + return self._add(var, val) - def _add(self, var, func, update=False): - if isinstance(var, _base.string_types): - var = _expr.GlobalVar(var) - return _module.Module_Add(self, var, func, update) + def _add(self, var, val, update=False): + if isinstance(val, _expr.Function): + if isinstance(var, _base.string_types): + var = _expr.GlobalVar(var) + _make.Module_Add(self, var, val, update) + else: + assert isinstance(val, _ty.Type) + if isinstance(var, _base.string_types): + var = _ty.GlobalTypeVar(var) + _module.Module_AddDef(self, var, val) def __getitem__(self, var): - """Lookup a global function by name or by variable. + """Lookup a global definition by name or by variable. Parameters ---------- @@ -62,13 +81,15 @@ def __getitem__(self, var): Returns ------- - func: Function - The function referenced by :code:`var`. + val: Union[Function, Type] + The definition referenced by :code:`var` (either a function or type). """ if isinstance(var, _base.string_types): return _module.Module_Lookup_str(self, var) - else: + elif isinstance(var, _expr.GlobalVar): return _module.Module_Lookup(self, var) + else: + return _module.Module_LookupDef(self, var) def update(self, other): """Insert functions in another Module to current one. @@ -100,3 +121,22 @@ def get_global_var(self, name): tvm.TVMError if we cannot find corresponding global var. """ return _module.Module_GetGlobalVar(self, name) + + def get_global_type_var(self, name): + """Get a global type variable in the function by name. + + Parameters + ---------- + name: str + The name of the global type variable. + + Returns + ------- + global_type_var: GlobalTypeVar + The global variable mapped to :code:`name`. + + Raises + ------ + tvm.TVMError if we cannot find corresponding global type var. + """ + return _module.Module_GetGlobalTypeVar(self, name) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py new file mode 100644 index 000000000000..99b6c8d1c766 --- /dev/null +++ b/python/tvm/relay/prelude.py @@ -0,0 +1,379 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +"""Adds certain standard global functions and ADT definitions to the module.""" +from .ty import GlobalTypeVar, TypeVar, FuncType, TupleType, scalar_type +from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem +from .adt import Constructor, TypeData, Clause, Match +from .adt import PatternConstructor, PatternVar, PatternWildcard + +class Prelude: + """Contains standard definitions.""" + + def define_list_adt(self): + """Defines a LISP-style list ADT. An empty list is + represented by nil(). A member x can be appended to the + front of a list l via the constructor cons(x, l).""" + self.l = GlobalTypeVar("list") + a = TypeVar("a") + self.nil = Constructor("nil", [], self.l) + self.cons = Constructor("cons", [a, self.l(a)], self.l) + self.mod[self.l] = TypeData(self.l, [a], [self.nil, self.cons]) + + def define_list_map(self): + """Defines a function for mapping a function over a list's + elements. That is, map(f, l) returns a new list where + the ith member is f applied to the ith member of l. + + map(f, l) : fn(fn(a) -> b, list[a]) -> list[b] + """ + self.map = GlobalVar("map") + a = TypeVar("a") + b = TypeVar("b") + f = Var("f", FuncType([a], b)) + x = Var("x", self.l(a)) + y = Var("y") + z = Var("z") + nil_case = Clause(PatternConstructor(self.nil), self.nil()) + cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), + self.cons(f(y), self.map(f, z))) + self.mod[self.map] = Function([f, x], Match(x, [nil_case, cons_case]), self.l(b), [a, b]) + + def define_list_foldl(self): + """Defines a left-way fold over a list. + + foldl(f, z, l) : fn(fn(a, b) -> a, a, list[b]) -> a + + foldl(f, z, cons(a1, cons(a2, cons(a3, cons(..., nil))))) + evaluates to f(...f(f(f(z, a1), a2), a3)...) + """ + self.foldl = GlobalVar("foldl") + a = TypeVar("a") + b = TypeVar("b") + f = Var("f", FuncType([a, b], a)) + av = Var("av", a) + bv = Var("bv", self.l(b)) + y = Var("y") + z = Var("z") + nil_case = Clause(PatternConstructor(self.nil), av) + cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), + self.foldl(f, f(av, y), z)) + self.mod[self.foldl] = Function([f, av, bv], + Match(bv, [nil_case, cons_case]), a, [a, b]) + + def define_list_foldr(self): + """Defines a right-way fold over a list. + + foldr(f, l, z) : fn(fn(a, b) -> b, list[a], b) -> b + + foldr(f, cons(a1, cons(a2, cons(..., cons(an, nil)))), z) + evalutes to f(a1, f(a2, f(..., f(an, z)))...) + """ + self.foldr = GlobalVar("foldr") + a = TypeVar("a") + b = TypeVar("b") + f = Var("f", FuncType([a, b], b)) + av = Var("av", self.l(a)) + bv = Var("bv", b) + y = Var("y") + z = Var("z") + nil_case = Clause(PatternConstructor(self.nil), bv) + cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), + f(y, self.foldr(f, bv, z))) + self.mod[self.foldr] = Function([f, bv, av], + Match(av, [nil_case, cons_case]), b, [a, b]) + + def define_list_concat(self): + """Defines a function that concatenates two lists. + + concat(l1, l2) : fn(list[a], list[a]) -> list[a]""" + self.concat = GlobalVar("concat") + a = TypeVar("a") + l1 = Var("l1", self.l(a)) + l2 = Var("l2", self.l(a)) + h = Var("h") + t = Var("t") + updater = Function([h, t], self.cons(h, t)) + self.mod[self.concat] = Function([l1, l2], + self.foldr(updater, l2, l1), + self.l(a), [a]) + + def define_list_filter(self): + """Defines a function that filters a list. + + filter(f, l) : fn(fn(a) -> Tensor[(), bool], list[a]) -> list[a] + + It returns the sublist of l consisting of the elements for which f returns true. + """ + self.filter = GlobalVar("filter") + a = TypeVar("a") + f = Var("f", FuncType([a], scalar_type("bool"))) + l = Var("l", self.l(a)) + h = Var("h") + t = Var("t") + nil_case = Clause(PatternConstructor(self.nil), self.nil()) + cons_case = Clause(PatternConstructor(self.cons, [PatternVar(h), PatternVar(t)]), + If(f(h), self.cons(h, self.filter(f, t)), self.filter(f, t))) + self.mod[self.filter] = Function([f, l], Match(l, [nil_case, cons_case]), self.l(a), [a]) + + def define_list_zip(self): + """Defines a function that combines two lists into a list of tuples of their elements. + + zip(l, m) : fn(list[a], list[b]) -> list[(a, b)] + + The zipped list will be the length of the shorter list. + """ + self.zip = GlobalVar("zip") + a = TypeVar("a") + b = TypeVar("b") + nil_case = Clause(PatternConstructor(self.nil), self.nil()) + l1 = Var("l1") + l2 = Var("l2") + h1 = Var("h1") + h2 = Var("h2") + t1 = Var("t1") + t2 = Var("t2") + inner_cons_case = Clause(PatternConstructor(self.cons, [PatternVar(h2), PatternVar(t2)]), + self.cons(Tuple([h1, h2]), self.zip(t1, t2))) + outer_cons_case = Clause(PatternConstructor(self.cons, [PatternVar(h1), PatternVar(t1)]), + Match(l2, [nil_case, inner_cons_case])) + self.mod[self.zip] = Function([l1, l2], Match(l1, [nil_case, outer_cons_case]), + self.l(TupleType([a, b])), [a, b]) + + def define_list_rev(self): + """Defines a function that reverses a list. + + rev(l) : fn(list[a]) -> list[a] + """ + self.rev = GlobalVar("rev") + a = TypeVar("a") + l = Var("l", self.l(a)) + x = Var("x") + y = Var("y") + updater = Function([y, x], self.cons(x, y)) + self.mod[self.rev] = Function([l], + self.foldl(updater, self.nil(), l), + self.l(a), [a]) + + def define_list_map_accumr(self): + """Defines an accumulative map, which is a fold that simulataneously updates + an accumulator value and a list of results. + + map_accumr(f, s, l) : fn(fn(a, b) -> (a, c), a, list[b]) -> (a, list[c]) + + This map proceeds through l from right to left. + """ + self.map_accumr = GlobalVar("map_accumr") + a = TypeVar("a") + b = TypeVar("b") + c = TypeVar("c") + f = Var("f", FuncType([a, b], TupleType([a, c]))) + acc = Var("acc", a) + l = Var("l", self.l(b)) + v = Var("v", b) + p = Var("p", TupleType([a, self.l(c)])) + f_out = Var("f_out", TupleType([a, c])) + updater = Function([v, p], + Let(f_out, f(TupleGetItem(p, 0), v), + Tuple([TupleGetItem(f_out, 0), + self.cons(TupleGetItem(f_out, 1), + TupleGetItem(p, 1))])), + TupleType([a, self.l(c)])) + self.mod[self.map_accumr] = Function([f, acc, l], + self.foldr(updater, Tuple([acc, self.nil()]), l), + TupleType([a, self.l(c)]), + [a, b, c]) + + def define_list_map_accuml(self): + """Defines an accumulative map, which is a fold that simulataneously updates + an accumulator value and a list of results. + + map_accuml(f, s, l) : fn(fn(a, b) -> (a, c), a, list[b]) -> (a, list[c]) + + This map proceeds through l from left to right. + """ + self.map_accuml = GlobalVar("map_accuml") + a = TypeVar("a") + b = TypeVar("b") + c = TypeVar("c") + f = Var("f", FuncType([a, b], TupleType([a, c]))) + acc = Var("acc", a) + l = Var("l", self.l(b)) + v = Var("v", b) + p = Var("p", TupleType([a, self.l(c)])) + f_out = Var("f_out", TupleType([a, c])) + updater = Function([p, v], + Let(f_out, f(TupleGetItem(p, 0), v), + Tuple([TupleGetItem(f_out, 0), + self.cons(TupleGetItem(f_out, 1), + TupleGetItem(p, 1))])), + TupleType([a, self.l(c)])) + self.mod[self.map_accuml] = Function([f, acc, l], + self.foldl(updater, Tuple([acc, self.nil()]), l), + TupleType([a, self.l(c)]), + [a, b, c]) + + + def define_optional_adt(self): + """Defines an optional ADT, which can either contain some other + type or nothing at all.""" + self.optional = GlobalTypeVar("optional") + a = TypeVar("a") + self.some = Constructor("some", [a], self.optional) + self.none = Constructor("none", [], self.optional) + self.mod[self.optional] = TypeData(self.optional, [a], [self.some, self.none]) + + def define_list_unfoldr(self): + """Defines a function that builds up a list starting from a seed value. + + unfoldr(f, s) : fn(fn(a) -> Optional[(a, b)], a) -> list[b] + + f returns an option containing a new seed and an output value. f will + continue to be called on the new seeds until it returns None. All the + output values will be combined into a list, right to left. + """ + self.unfoldr = GlobalVar("unfoldr") + a = TypeVar("a") + b = TypeVar("b") + f = Var("f", FuncType([a], self.optional(TupleType([a, b])))) + s = Var("s", a) + p = Var("p", TupleType([a, b])) + none_case = Clause(PatternConstructor(self.none), self.nil()) + some_case = Clause(PatternConstructor(self.some, [PatternVar(p)]), + self.cons(TupleGetItem(p, 1), + self.unfoldr(f, TupleGetItem(p, 0)))) + self.mod[self.unfoldr] = Function([f, s], Match(f(s), [none_case, some_case]), + self.l(b), [a, b]) + + def define_list_unfoldl(self): + """Defines a function that builds up a list starting from a seed value. + + unfoldl(f, s) : fn(fn(a) -> Optional[(a, b)], a) -> list[b] + + f returns an option containing a new seed and an output value. f will + continue to be called on the new seeds until it returns None. All the + output values will be combined into a list, left to right. + """ + self.unfoldl = GlobalVar("unfoldl") + a = TypeVar("a") + b = TypeVar("b") + f = Var("f", FuncType([a], self.optional(TupleType([a, b])))) + s = Var("s", a) + # easiest way to implement is to do a right unfold and reverse + self.mod[self.unfoldl] = Function([f, s], + self.rev(self.unfoldr(f, s)), + self.l(b), [a, b]) + + def define_nat_adt(self): + """Defines a Peano (unary) natural number ADT. + Zero is represented by z(). s(n) adds 1 to a nat n.""" + self.nat = GlobalTypeVar("nat") + self.z = Constructor("z", [], self.nat) + self.s = Constructor("s", [self.nat()], self.nat) + self.mod[self.nat] = TypeData(self.nat, [], [self.z, self.s]) + + def define_nat_double(self): + """Defines a function that doubles a nat.""" + self.double = GlobalVar("double") + x = Var("x", self.nat()) + y = Var("y") + z_case = Clause(PatternConstructor(self.z), self.z()) + s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), + self.s(self.s(self.double(y)))) + self.mod[self.double] = Function([x], Match(x, [z_case, s_case])) + + def define_nat_add(self): + """Defines a function that adds two nats.""" + self.add = GlobalVar("add") + x = Var("x", self.nat()) + y = Var("y", self.nat()) + a = Var("a") + z_case = Clause(PatternConstructor(self.z), y) + s_case = Clause(PatternConstructor(self.s, [PatternVar(a)]), + self.s(self.add(a, y))) + self.mod[self.add] = Function([x, y], Match(x, [z_case, s_case])) + + def define_list_sum(self): + """Defines a function that computes the sum of a list of nats.""" + self.sum = GlobalVar("sum") + a = Var("a", self.l(self.nat())) + self.mod[self.sum] = Function([a], self.foldl(self.add, self.z(), a)) + + def define_list_length(self): + """Defines a function that returns the length of a list as a nat""" + self.length = GlobalVar("length") + a = TypeVar("a") + x = Var("x", self.l(a)) + y = Var("y") + nil_case = Clause(PatternConstructor(self.nil), self.z()) + cons_case = Clause(PatternConstructor(self.cons, [PatternWildcard(), PatternVar(y)]), + self.s(self.length(y))) + self.mod[self.length] = Function([x], + Match(x, [nil_case, cons_case]), None, [a]) + + def define_tree_adt(self): + """Defines a tree ADT. A tree can contain any type. + It has only one constructor, rose(x, l), where x is the content + of that point of the tree and l is a list of more trees of the + same type. A leaf is thus rose(x, nil()). + """ + self.tree = GlobalTypeVar("tree") + a = TypeVar("a") + self.rose = Constructor("rose", [a, self.l(self.tree(a))], self.tree) + self.mod[self.tree] = TypeData(self.tree, [a], [self.rose]) + + def define_tree_map(self): + """Defines a function that maps over a tree. The function + is applied to each subtree's contents. + + Signature: fn(f : fn(a) -> b, t : tree[a]) -> tree[b] + """ + self.tmap = GlobalVar("tmap") + a = TypeVar("a") + b = TypeVar("b") + t = Var("t", self.tree(a)) + f = Var("f", FuncType([a], b)) + x = Var("x", self.tree(a)) + y = Var("y") + z = Var("z") + rose_case = Clause(PatternConstructor(self.rose, [PatternVar(y), PatternVar(z)]), + self.rose(f(y), self.map(Function([x], self.tmap(f, x)), z))) + self.mod[self.tmap] = Function([f, t], + Match(t, [rose_case]), self.tree(b), [a, b]) + + def define_tree_size(self): + """Defines a function that computes the size of a tree as a nat.""" + self.size = GlobalVar("size") + a = TypeVar("a") + t = Var("t", self.tree(a)) + x = Var("x", self.tree(a)) + z = Var("z") + rose_case = Clause(PatternConstructor(self.rose, [PatternWildcard(), PatternVar(z)]), + self.s(self.sum(self.map(Function([x], self.size(x)), z)))) + self.mod[self.size] = Function([t], + Match(t, [rose_case]), self.nat(), [a]) + + def __init__(self, mod): + self.mod = mod + self.define_list_adt() + self.define_list_map() + self.define_list_foldl() + self.define_list_foldr() + self.define_list_concat() + self.define_list_filter() + self.define_list_zip() + self.define_list_rev() + self.define_list_map_accumr() + self.define_list_map_accuml() + + self.define_optional_adt() + self.define_list_unfoldr() + self.define_list_unfoldl() + + self.define_nat_adt() + self.define_nat_double() + self.define_nat_add() + self.define_list_length() + self.define_list_sum() + + self.define_tree_adt() + self.define_tree_map() + self.define_tree_size() diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py index bed293d1e3ca..1cfa96aa7213 100644 --- a/python/tvm/relay/ty.py +++ b/python/tvm/relay/ty.py @@ -21,6 +21,19 @@ def same_as(self, other): """Compares two Relay types by referential equality.""" return super().__eq__(other) + def __call__(self, *args): + """Create a type call from this type. + + Parameters + ---------- + args: List[relay.Type] + The arguments to the type call. + + Returns + ------- + call: relay.TypeCall + """ + return TypeCall(self, args) @register_relay_node class TensorType(Type): @@ -75,6 +88,9 @@ class Kind(IntEnum): ShapeVar = 1 BaseType = 2 Shape = 3 + Constraint = 4 + AdtHandle = 5 + TypeData = 6 @register_relay_node class TypeVar(Type): @@ -106,6 +122,53 @@ def __init__(self, var, kind=Kind.Type): self.__init_handle_by_constructor__(_make.TypeVar, var, kind) +@register_relay_node +class GlobalTypeVar(Type): + """A global type variable in Relay. + GlobalTypeVar is used to refer to the global type-level definitions + stored in the environment. + """ + + def __init__(self, var, kind=Kind.AdtHandle): + """Construct a GlobalTypeVar. + + Parameters + ---------- + var: tvm.Var + The tvm.Var which backs the type parameter. + kind: Kind, optional + The kind of the type parameter, Kind.AdtHandle by default. + + Returns + ------- + type_var: GlobalTypeVar + The global type variable. + """ + self.__init_handle_by_constructor__(_make.GlobalTypeVar, var, kind) + + +@register_relay_node +class TypeCall(Type): + """Type-level function application in Relay. + A type call applies argument types to a constructor (type-level function). + """ + + def __init__(self, func, args): + """Construct a TypeCall. + Parameters + ---------- + func: tvm.relay.Type + The function. + args: List[tvm.expr.Type] + The arguments. + Returns + ------- + type_call: TypeCall + The type function application. + """ + self.__init_handle_by_constructor__(_make.TypeCall, func, args) + + @register_relay_node class TypeConstraint(Type): """Abstract class representing a type constraint.""" diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 893e66b41b42..4ef893f463e9 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -92,6 +93,26 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "RefValueNode(" << node->value << ")"; }); +ConstructorValue ConstructorValueNode::make(Constructor constructor, + tvm::Array fields) { + NodePtr n = make_node(); + n->constructor = constructor; + n->fields = fields; + return ConstructorValue(n); +} + +TVM_REGISTER_API("relay._make.ConstructorValue") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = ConstructorValueNode::make(args[0], args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const ConstructorValueNode* node, + tvm::IRPrinter* p) { + p->stream << "ConstructorValueNode(" << node->constructor + << node->fields << ")"; +}); + /*! * \brief A stack frame in the Relay interpreter. * @@ -185,7 +206,8 @@ InterpreterState InterpreterStateNode::make(Expr current_expr, Stack stack) { // // Conversion to ANF is recommended before running the interpretation. class Interpreter : - public ExprFunctor { + public ExprFunctor, + PatternFunctor { public: Interpreter(Module mod, DLContext context, @@ -209,7 +231,7 @@ class Interpreter : } Value Eval(const Expr& expr) { - return (*this)(expr); + return VisitExpr(expr); } Value VisitExpr(const Expr& expr) final { @@ -401,6 +423,9 @@ class Interpreter : << "; operators should be removed by future passes; try " "fusing and lowering"; } + if (auto con = call->op.as()) { + return ConstructorValueNode::make(GetRef(con), args); + } // Now we just evaluate and expect to find a closure. Value fn_val = Eval(call->op); if (const ClosureNode* closure_node = fn_val.as()) { @@ -474,6 +499,44 @@ class Interpreter : } } + Value VisitExpr_(const MatchNode* op) final { + Value v = Eval(op->data); + for (const Clause& c : op->clauses) { + if (VisitPattern(c->lhs, v)) { + return VisitExpr(c->rhs); + } + } + LOG(FATAL) << "did not find any match"; + return Value(); + } + + bool VisitPattern_(const PatternConstructorNode* op, const Value& v) final { + const ConstructorValueNode* cvn = v.as(); + CHECK(cvn) << "need to be a constructor for match"; + CHECK_NE(op->constructor->tag, -1); + CHECK_NE(cvn->constructor->tag, -1); + if (op->constructor->tag == cvn->constructor->tag) { + // todo(M.K.): should use ptr equality but it is broken + CHECK(op->patterns.size() == cvn->fields.size()); + for (size_t i = 0; i < op->patterns.size(); ++i) { + if (!VisitPattern(op->patterns[i], cvn->fields[i])) { + return false; + } + } + return true; + } + return false; + } + + bool VisitPattern_(const PatternWildcardNode* op, const Value& v) final { + return true; + } + + bool VisitPattern_(const PatternVarNode* op, const Value& v) final { + extend(op->var, v); + return true; + } + InterpreterState get_state(Expr e = Expr()) const { InterpreterStateNode::Stack stack; for (auto fr : this->stack_.frames) { @@ -485,14 +548,14 @@ class Interpreter : } private: - // module + // Module Module mod_; // For simplicity we only run the interpreter on a single context. // Context to run the interpreter on. DLContext context_; // Target parameter being used by the interpreter. Target target_; - // value stack. + // Value stack. Stack stack_; // Backend compile engine. CompileEngine engine_; diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc new file mode 100644 index 000000000000..21d98036fb0d --- /dev/null +++ b/src/relay/ir/adt.cc @@ -0,0 +1,162 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file src/tvm/ir/adt.cc + * \brief AST nodes for Relay algebraic data types (ADTs). + */ +#include +#include + +namespace tvm { +namespace relay { + +PatternWildcard PatternWildcardNode::make() { + NodePtr n = make_node(); + return PatternWildcard(n); +} + +TVM_REGISTER_NODE_TYPE(PatternWildcardNode); + +TVM_REGISTER_API("relay._make.PatternWildcard") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = PatternWildcardNode::make(); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const PatternWildcardNode* node, + tvm::IRPrinter* p) { + p->stream << "PatternWildcardNode()"; +}); + +PatternVar PatternVarNode::make(tvm::relay::Var var) { + NodePtr n = make_node(); + n->var = std::move(var); + return PatternVar(n); +} + +TVM_REGISTER_NODE_TYPE(PatternVarNode); + +TVM_REGISTER_API("relay._make.PatternVar") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = PatternVarNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const PatternVarNode* node, + tvm::IRPrinter* p) { + p->stream << "PatternVarNode(" << node->var << ")"; +}); + +PatternConstructor PatternConstructorNode::make(Constructor constructor, + tvm::Array patterns) { + NodePtr n = make_node(); + n->constructor = std::move(constructor); + n->patterns = std::move(patterns); + return PatternConstructor(n); +} + +TVM_REGISTER_NODE_TYPE(PatternConstructorNode); + +TVM_REGISTER_API("relay._make.PatternConstructor") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = PatternConstructorNode::make(args[0], args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const PatternConstructorNode* node, + tvm::IRPrinter* p) { + p->stream << "PatternConstructorNode(" << node->constructor + << ", " << node->patterns << ")"; +}); + +Constructor ConstructorNode::make(std::string name_hint, + tvm::Array inputs, + GlobalTypeVar belong_to) { + NodePtr n = make_node(); + n->name_hint = std::move(name_hint); + n->inputs = std::move(inputs); + n->belong_to = std::move(belong_to); + return Constructor(n); +} + +TVM_REGISTER_NODE_TYPE(ConstructorNode); + +TVM_REGISTER_API("relay._make.Constructor") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = ConstructorNode::make(args[0], args[1], args[2]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const ConstructorNode* node, + tvm::IRPrinter* p) { + p->stream << "ConstructorNode(" << node->name_hint << ", " + << node->inputs << ", " << node->belong_to << ")"; +}); + +TypeData TypeDataNode::make(GlobalTypeVar header, + tvm::Array type_vars, + tvm::Array constructors) { + NodePtr n = make_node(); + n->header = std::move(header); + n->type_vars = std::move(type_vars); + n->constructors = std::move(constructors); + return TypeData(n); +} + +TVM_REGISTER_NODE_TYPE(TypeDataNode); + +TVM_REGISTER_API("relay._make.TypeData") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = TypeDataNode::make(args[0], args[1], args[2]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const TypeDataNode* node, + tvm::IRPrinter* p) { + p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", " + << node->constructors << ")"; +}); + +Clause ClauseNode::make(Pattern lhs, Expr rhs) { + NodePtr n = make_node(); + n->lhs = std::move(lhs); + n->rhs = std::move(rhs); + return Clause(n); +} + +TVM_REGISTER_NODE_TYPE(ClauseNode); + +TVM_REGISTER_API("relay._make.Clause") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = ClauseNode::make(args[0], args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const ClauseNode* node, + tvm::IRPrinter* p) { + p->stream << "ClauseNode(" << node->lhs << ", " + << node->rhs << ")"; + }); + +Match MatchNode::make(Expr data, tvm::Array clauses) { + NodePtr n = make_node(); + n->data = std::move(data); + n->clauses = std::move(clauses); + return Match(n); +} + +TVM_REGISTER_NODE_TYPE(MatchNode); + +TVM_REGISTER_API("relay._make.Match") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = MatchNode::make(args[0], args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const MatchNode* node, + tvm::IRPrinter* p) { + p->stream << "MatchNode(" << node->data << ", " + << node->clauses << ")"; +}); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index d0cc004994d4..96517f8dd445 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -5,6 +5,7 @@ */ #include #include +#include #include #include #include "type_functor.h" @@ -17,7 +18,8 @@ namespace relay { class AlphaEqualHandler: public AttrsEqualHandler, public TypeFunctor, - public ExprFunctor { + public ExprFunctor, + public PatternFunctor { public: explicit AlphaEqualHandler(bool map_free_var) : map_free_var_(map_free_var) {} @@ -160,7 +162,7 @@ class AlphaEqualHandler: } equal_map_[lhs->type_params[i]] = rhs->type_params[i]; // set up type parameter equal - if (lhs->type_params[i]->kind == TypeVarNode::Kind::kShapeVar) { + if (lhs->type_params[i]->kind == Kind::kShapeVar) { // map variable equal_map_[lhs->type_params[i]->var] = rhs->type_params[i]->var; } @@ -215,6 +217,26 @@ class AlphaEqualHandler: return false; } + bool VisitType_(const GlobalTypeVarNode* op, const Type& t2) final { + return GetRef(op) == t2; + } + + bool VisitType_(const TypeCallNode* op, const Type& t2) final { + const TypeCallNode* pt = t2.as(); + if (pt == nullptr + || op->args.size() != pt->args.size() + || !TypeEqual(op->func, pt->func)) { + return false; + } + + for (size_t i = 0; i < op->args.size(); ++i) { + if (!TypeEqual(op->args[i], pt->args[i])) { + return false; + } + } + return true; + } + // Expr equal checking. bool NDArrayEqual(const runtime::NDArray& lhs, const runtime::NDArray& rhs) { @@ -261,11 +283,9 @@ class AlphaEqualHandler: bool VisitExpr_(const GlobalVarNode* lhs, const Expr& other) final { if (const GlobalVarNode* rhs = other.as()) { // use name equality for global var for now. - if (lhs->name_hint != rhs->name_hint) return false; - return true; - } else { - return false; + return lhs->name_hint == rhs->name_hint; } + return false; } bool VisitExpr_(const TupleNode* lhs, const Expr& other) final { @@ -392,6 +412,63 @@ class AlphaEqualHandler: return false; } } + + bool VisitExpr_(const ConstructorNode* op, const Expr& e2) final { + return GetRef(op) == e2; + } + + bool ClauseEqual(const Clause& l, const Clause& r) { + return PatternEqual(l->lhs, r->lhs) && ExprEqual(l->rhs, r->rhs); + } + + bool PatternEqual(const Pattern& l, const Pattern& r) { + return VisitPattern(l, r); + } + + bool VisitPattern_(const PatternWildcardNode* op, const Pattern& r) final { + return r.as(); + } + + bool VisitPattern_(const PatternVarNode* op, const Pattern& e2) final { + if (const auto* r = e2.as()) { + return MergeVarDecl(op->var, r->var); + } + return false; + } + + bool VisitPattern_(const PatternConstructorNode* op, const Pattern& e2) final { + const auto* r = e2.as(); + if (r == nullptr + || !ExprEqual(op->constructor, r->constructor) + || op->patterns.size() != r->patterns.size()) { + return false; + } + + for (size_t i = 0; i < op->patterns.size(); i++) { + if (!PatternEqual(op->patterns[i], r->patterns[i])) { + return false; + } + } + return true; + } + + bool VisitExpr_(const MatchNode* op, const Expr& e2) final { + const MatchNode* r = e2.as(); + + if (r == nullptr + || !ExprEqual(op->data, r->data) + || op->clauses.size() != r->clauses.size()) { + return false; + } + + for (size_t i = 0; i < op->clauses.size(); ++i) { + if (!ClauseEqual(op->clauses[i], r->clauses[i])) { + return false; + } + } + return true; + } + private: // whether to map open terms. bool map_free_var_{false}; diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index bc6eee3ebc03..29fe98ba78f5 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -130,9 +130,14 @@ Function FunctionNode::make(tvm::Array params, FuncType FunctionNode::func_type_annotation() const { Array param_types; for (auto param : this->params) { - param_types.push_back(param->type_annotation); + Type param_type = (param->type_annotation.defined()) ? param->type_annotation + : IncompleteTypeNode::make(Kind::kType); + param_types.push_back(param_type); } - return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {}); + + Type ret_type = (this->ret_type.defined()) ? this->ret_type + : IncompleteTypeNode::make(Kind::kType); + return FuncTypeNode::make(param_types, ret_type, this->type_params, {}); } bool FunctionNode::IsPrimitive() const { diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 9bdfa00ce298..6265873d8310 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -185,6 +185,24 @@ Expr ExprMutator::VisitExpr_(const RefWriteNode* op) { } } +Expr ExprMutator::VisitExpr_(const ConstructorNode* c) { + return GetRef(c); +} + +Expr ExprMutator::VisitExpr_(const MatchNode* m) { + std::vector clauses; + for (const Clause& p : m->clauses) { + clauses.push_back(VisitClause(p)); + } + return MatchNode::make(VisitExpr(m->data), clauses); +} + +Clause ExprMutator::VisitClause(const Clause& c) { + return ClauseNode::make(VisitPattern(c->lhs), VisitExpr(c->rhs)); +} + +Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; } + Type ExprMutator::VisitType(const Type& t) { return t; } void ExprVisitor::VisitExpr(const Expr& expr) { @@ -267,6 +285,27 @@ void ExprVisitor::ExprVisitor::VisitExpr_(const RefWriteNode* op) { this->VisitExpr(op->value); } +void ExprVisitor::VisitExpr_(const ConstructorNode* op) { + for (const Type& t : op->inputs) { + this->VisitType(t); + } + this->VisitType(op->belong_to); +} + +void ExprVisitor::VisitExpr_(const MatchNode* op) { + this->VisitExpr(op->data); + for (const Clause& c : op->clauses) { + this->VisitClause(c); + } +} + +void ExprVisitor::VisitClause(const Clause& op) { + this->VisitPattern(op->lhs); + this->VisitExpr(op->rhs); +} + +void ExprVisitor::VisitPattern(const Pattern& p) { return; } + void ExprVisitor::VisitType(const Type& t) { return; } // visitor to implement apply diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index d984bb051e43..5e10906bec84 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -5,6 +5,7 @@ */ #include #include +#include #include #include #include @@ -18,7 +19,8 @@ namespace relay { class RelayHashHandler: public AttrsHashHandler, public TypeFunctor, - public ExprFunctor { + public ExprFunctor, + public PatternFunctor { public: explicit RelayHashHandler() {} @@ -201,7 +203,7 @@ class RelayHashHandler: hash_map_[var] = hash; const auto* ty_param = var.as(); - if (ty_param && ty_param->kind == TypeVarNode::Kind::kShapeVar) { + if (ty_param && ty_param->kind == Kind::kShapeVar) { hash_map_[ty_param->var] = hash; } return hash; @@ -236,7 +238,7 @@ class RelayHashHandler: } hash = Combine(hash, TypeHash(func->ret_type)); - hash = Combine(hash, ExprHash(func->body)); + hash = Combine(hash, ExprHash(func->body)); return hash; } @@ -249,6 +251,10 @@ class RelayHashHandler: hash = Combine(hash, ExprHash(arg)); } + for (auto t : call->type_args) { + hash = Combine(hash, TypeHash(t)); + } + hash = Combine(hash, AttrHash(call->attrs)); return hash; @@ -304,6 +310,72 @@ class RelayHashHandler: hash = Combine(hash, ExprHash(rn->value)); return hash; } + + size_t VisitExpr_(const MatchNode* mn) final { + size_t hash = std::hash()(MatchNode::_type_key); + hash = Combine(hash, ExprHash(mn->data)); + for (const auto& c : mn->clauses) { + hash = Combine(hash, PatternHash(c->lhs)); + hash = Combine(hash, ExprHash(c->rhs)); + } + return hash; + } + + size_t VisitExpr_(const ConstructorNode* cn) final { + size_t hash = std::hash()(ConstructorNode::_type_key); + hash = Combine(hash, std::hash()(cn->name_hint)); + return hash; + } + + size_t VisitType_(const TypeCallNode* tcn) final { + size_t hash = std::hash()(TypeCallNode::_type_key); + hash = Combine(hash, TypeHash(tcn->func)); + for (const auto& t : tcn->args) { + hash = Combine(hash, TypeHash(t)); + } + return hash; + } + + size_t VisitType_(const TypeDataNode* tdn) final { + size_t hash = std::hash()(TypeDataNode::_type_key); + hash = Combine(hash, TypeHash(tdn->header)); + for (const auto& tv : tdn->type_vars) { + hash = Combine(hash, TypeHash(tv)); + } + for (const auto& cn : tdn->constructors) { + hash = Combine(hash, ExprHash(cn)); + } + return hash; + } + + size_t VisitType_(const GlobalTypeVarNode* tvn) final { + return BindVar(GetRef(tvn)); + } + + size_t PatternHash(const Pattern& p) { + return VisitPattern(p); + } + + size_t VisitPattern_(const PatternConstructorNode* pcn) final { + size_t hash = std::hash()(PatternConstructorNode::_type_key); + hash = Combine(hash, ExprHash(pcn->constructor)); + for (const auto& p : pcn->patterns) { + hash = Combine(hash, PatternHash(p)); + } + return hash; + } + + size_t VisitPattern_(const PatternVarNode* pvn) final { + size_t hash = std::hash()(PatternVarNode::_type_key); + hash = Combine(hash, BindVar(pvn->var)); + return hash; + } + + size_t VisitPattern_(const PatternWildcardNode* pwn) final { + size_t hash = std::hash()(PatternWildcardNode::_type_key); + return hash; + } + private: // renaming of NodeRef to indicate two nodes equals to each other std::unordered_map hash_map_; diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 9ba5efecec80..da273265ae33 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -13,18 +13,28 @@ namespace relay { using tvm::IRPrinter; using namespace runtime; -Module ModuleNode::make(tvm::Map global_funcs) { +Module ModuleNode::make(tvm::Map global_funcs, + tvm::Map global_type_defs) { auto n = make_node(); n->functions = std::move(global_funcs); + n->type_definitions = std::move(global_type_defs); for (const auto& kv : n->functions) { - // set gloval var map + // set global var map CHECK(!n->global_var_map_.count(kv.first->name_hint)) - << "Duplicate global function name " << kv.first->name_hint; + << "Duplicate global function name " << kv.first->name_hint; n->global_var_map_.Set(kv.first->name_hint, kv.first); } n->entry_func = GlobalVarNode::make("main"); + + for (const auto& kv : n->type_definitions) { + // set global typevar map + CHECK(!n->global_type_var_map_.count(kv.first->var->name_hint)) + << "Duplicate global type definition name " << kv.first->var->name_hint; + n->global_type_var_map_.Set(kv.first->var->name_hint, kv.first); + } + return Module(n); } @@ -51,6 +61,13 @@ void ModuleNode::AddUnchecked(const GlobalVar& var, global_var_map_.Set(var->name_hint, var); } +GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) { + auto it = global_type_var_map_.find(name); + CHECK(it != global_type_var_map_.end()) + << "Cannot find global type var " << name << " in the Module"; + return (*it).second; +} + void ModuleNode::Add(const GlobalVar& var, const Function& func, bool update) { @@ -69,6 +86,22 @@ void ModuleNode::Add(const GlobalVar& var, AddUnchecked(var, checked_func); } +void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) { + this->type_definitions.Set(var, type); + // set global type var map + CHECK(!global_type_var_map_.count(var->var->name_hint)) + << "Duplicate global type definition name " << var->var->name_hint; + global_type_var_map_.Set(var->var->name_hint, var); + for (size_t i = 0; i < type->constructors.size(); ++i) { + type->constructors[i]->tag = i; + } + + // need to kind check at the end because the check can look up + // a definition potentially + CHECK(KindCheck(type, GetRef(this)) == Kind::kTypeData) + << "Invalid or malformed typedata given to module: " << type; +} + void ModuleNode::Update(const GlobalVar& var, const Function& func) { this->Add(var, func, true); } @@ -92,6 +125,18 @@ Function ModuleNode::Lookup(const std::string& name) { return this->Lookup(id); } +TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) { + auto it = type_definitions.find(var); + CHECK(it != type_definitions.end()) + << "There is no definition of " << var->var->name_hint; + return (*it).second; +} + +TypeData ModuleNode::LookupDef(const std::string& name) { + GlobalTypeVar id = this->GetGlobalTypeVar(name); + return this->LookupDef(id); +} + void ModuleNode::Update(const Module& mod) { for (auto pair : mod->functions) { this->Update(pair.first, pair.second); @@ -101,7 +146,7 @@ void ModuleNode::Update(const Module& mod) { Module ModuleNode::FromExpr( const Expr& expr, const tvm::Map& global_funcs) { - auto mod = ModuleNode::make(global_funcs); + auto mod = ModuleNode::make(global_funcs, {}); auto func_node = expr.as(); Function func; if (func_node) { @@ -117,21 +162,33 @@ TVM_REGISTER_NODE_TYPE(ModuleNode); TVM_REGISTER_API("relay._make.Module") .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = ModuleNode::make(args[0]); + *ret = ModuleNode::make(args[0], args[1]); }); -TVM_REGISTER_API("relay._module.Module_Add") +TVM_REGISTER_API("relay._make.Module_Add") .set_body([](TVMArgs args, TVMRetValue *ret) { Module mod = args[0]; mod->Add(args[1], args[2], args[3]); }); +TVM_REGISTER_API("relay._module.Module_AddDef") +.set_body([](TVMArgs args, TVMRetValue *ret) { + Module mod = args[0]; + mod->AddDef(args[1], args[2]); + }); + TVM_REGISTER_API("relay._module.Module_GetGlobalVar") .set_body([](TVMArgs args, TVMRetValue *ret) { Module mod = args[0]; *ret = mod->GetGlobalVar(args[1]); }); +TVM_REGISTER_API("relay._module.Module_GetGlobalTypeVar") +.set_body([](TVMArgs args, TVMRetValue *ret) { + Module mod = args[0]; + *ret = mod->GetGlobalTypeVar(args[1]); + }); + TVM_REGISTER_API("relay._module.Module_Lookup") .set_body([](TVMArgs args, TVMRetValue *ret) { Module mod = args[0]; @@ -143,8 +200,21 @@ TVM_REGISTER_API("relay._module.Module_Lookup_str") .set_body([](TVMArgs args, TVMRetValue *ret) { Module mod = args[0]; std::string var_name = args[1]; - auto var = mod->GetGlobalVar(var_name); - *ret = mod->Lookup(var); + *ret = mod->Lookup(var_name); + }); + +TVM_REGISTER_API("relay._module.Module_LookupDef") +.set_body([](TVMArgs args, TVMRetValue *ret) { + Module mod = args[0]; + GlobalTypeVar var = args[1]; + *ret = mod->LookupDef(var); + }); + +TVM_REGISTER_API("relay._module.Module_LookupDef_str") +.set_body([](TVMArgs args, TVMRetValue *ret) { + Module mod = args[0]; + std::string var_name = args[1]; + *ret = mod->LookupDef(var_name); }); TVM_REGISTER_API("relay._module.Module_Update") diff --git a/src/relay/ir/pattern_functor.cc b/src/relay/ir/pattern_functor.cc new file mode 100644 index 000000000000..6d2e9d296164 --- /dev/null +++ b/src/relay/ir/pattern_functor.cc @@ -0,0 +1,75 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file src/tvm/relay/pattern_functor.cc + * \brief Implementations of visitors and mutators for ADT patterns. + */ + +#include + +namespace tvm { +namespace relay { + +Pattern PatternMutator::Mutate(const Pattern& pat) { + return (*this)(pat); +} + +Pattern PatternMutator::VisitPattern_(const PatternWildcardNode* op) { + return GetRef(op); +} + +Pattern PatternMutator::VisitPattern_(const PatternVarNode* op) { + return PatternVarNode::make(VisitVar(op->var)); +} + +Pattern PatternMutator::VisitPattern_(const PatternConstructorNode* op) { + std::vector pat; + for (const auto& p : op->patterns) { + pat.push_back(VisitPattern(p)); + } + return PatternConstructorNode::make(VisitConstructor(op->constructor), pat); +} + +Type PatternMutator::VisitType(const Type& t) { + return t; +} + +Var PatternMutator::VisitVar(const Var& v) { + if (var_map_.count(v) == 0) { + var_map_.insert(std::pair(v, + VarNode::make(v->name_hint(), + VisitType(v->type_annotation)))); + } + return var_map_.at(v); +} + +Constructor PatternMutator::VisitConstructor(const Constructor& v) { + return v; +} + +void PatternVisitor::VisitPattern_(const PatternWildcardNode* op) { } + +void PatternVisitor::VisitPattern_(const PatternVarNode* op) { + VisitVar(op->var); +} + +void PatternVisitor::VisitPattern_(const PatternConstructorNode* op) { + VisitConstructor(op->constructor); + for (const auto& p : op->patterns) { + VisitPattern(p); + } +} + +void PatternVisitor::VisitType(const Type& t) { } + +void PatternVisitor::VisitVar(const Var& v) { + VisitType(v->type_annotation); +} + +void PatternVisitor::VisitConstructor(const Constructor& c) { + for (const auto& inp : c->inputs) { + VisitType(inp); + } +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index 05179d584d84..932856a2055d 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -5,6 +5,7 @@ */ #include #include +#include #include #include "type_functor.h" #include "../../lang/attr_functor.h" @@ -23,6 +24,12 @@ struct TextValue { TextValue() {} // constructor explicit TextValue(std::string name) : name(name) {} + TextValue operator+(const TextValue& rhs) const { + return TextValue(name + rhs.name); + } + TextValue operator+(const std::string& str) const { + return TextValue(name + str); + } }; // operator overloading @@ -128,6 +135,7 @@ class TextMetaDataContext { class TextPrinter : public ExprFunctor, + public PatternFunctor, public TypeFunctor, // NOLINT(*) public AttrFunctor { // NOLINT(*) public: @@ -213,6 +221,9 @@ class TextPrinter : memo_[expr] = val; return val; } + TextValue GetValue(const Pattern& p) { + return this->VisitPattern(p); + } //------------------------------------ // Overload of Expr printing functions //------------------------------------ @@ -391,6 +402,36 @@ class TextPrinter : return id; } + TextValue VisitExpr_(const MatchNode* op) final { + TextValue data = GetValue(op->data); + this->PrintIndent(); + TextValue id = this->AllocTempVar(); + stream_ << id << " = " << "Match " << data << " with"; + this->PrintEndInst("\n"); + for (const auto& c : op->clauses) { + this->PrintIndent(); + stream_ << GetValue(c->lhs) << " to " << GetValue(c->rhs); + this->PrintEndInst("\n"); + } + return id; + } + + TextValue VisitPattern_(const PatternConstructorNode* p) final { + TextValue ret(p->constructor->name_hint + "("); + for (const Pattern& pat : p->patterns) { + ret = ret + " " + GetValue(pat); + } + return ret + ")"; + } + + TextValue VisitPattern_(const PatternVarNode* pv) final { + return GetValue(pv->var); + } + + TextValue VisitExpr_(const ConstructorNode* n) final { + return TextValue(n->name_hint); + } + /*! * \brief Print the type to os * \param type The type to be printed. @@ -437,6 +478,18 @@ class TextPrinter : VisitTypeDefault_(node, os); } + void VisitType_(const TypeCallNode* node, std::ostream& os) final { + os << node->func << "(" << node->args << ")"; + } + + void VisitType_(const GlobalTypeVarNode* node, std::ostream& os) final { + VisitTypeDefault_(node, os); + } + + void VisitType_(const TypeDataNode* node, std::ostream& os) final { + VisitTypeDefault_(node, os); + } + void VisitTypeDefault_(const Node* node, std::ostream& os) final { // NOLINT(*) // by default always print as meta-data os << meta_.GetMetaNode(GetRef(node)); diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index e829d8abd63c..25b7beb5356a 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -48,7 +48,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")"; }); -TypeVar TypeVarNode::make(std::string name, TypeVarNode::Kind kind) { +TypeVar TypeVarNode::make(std::string name, Kind kind) { NodePtr n = make_node(); n->var = tvm::Var(name); n->kind = std::move(kind); @@ -61,7 +61,7 @@ TVM_REGISTER_API("relay._make.TypeVar") .set_body([](TVMArgs args, TVMRetValue* ret) { int kind = args[1]; *ret = - TypeVarNode::make(args[0], static_cast(kind)); + TypeVarNode::make(args[0], static_cast(kind)); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) @@ -71,7 +71,50 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) << node->kind << ")"; }); -IncompleteType IncompleteTypeNode::make(TypeVarNode::Kind kind) { +GlobalTypeVar GlobalTypeVarNode::make(std::string name, Kind kind) { + NodePtr n = make_node(); + n->var = tvm::Var(name); + n->kind = std::move(kind); + return GlobalTypeVar(n); +} + +TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode); + +TVM_REGISTER_API("relay._make.GlobalTypeVar") +.set_body([](TVMArgs args, TVMRetValue* ret) { + int kind = args[1]; + *ret = GlobalTypeVarNode::make(args[0], static_cast(kind)); +}); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const GlobalTypeVarNode *node, + tvm::IRPrinter *p) { + p->stream << "GlobalTypeVarNode(" << node->var->name_hint << ", " + << node->kind << ")"; +}); + +TypeCall TypeCallNode::make(Type func, tvm::Array args) { + NodePtr n = make_node(); + n->func = std::move(func); + n->args = std::move(args); + return TypeCall(n); +} + +TVM_REGISTER_NODE_TYPE(TypeCallNode); + +TVM_REGISTER_API("relay._make.TypeCall") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = TypeCallNode::make(args[0], args[1]); +}); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const TypeCallNode* node, + tvm::IRPrinter* p) { + p->stream << "TypeCallNode(" << node->func << ", " + << node->args << ")"; +}); + +IncompleteType IncompleteTypeNode::make(Kind kind) { auto n = make_node(); n->kind = std::move(kind); return IncompleteType(n); @@ -82,7 +125,7 @@ TVM_REGISTER_NODE_TYPE(IncompleteTypeNode); TVM_REGISTER_API("relay._make.IncompleteType") .set_body([](TVMArgs args, TVMRetValue* ret) { int kind = args[0]; - *ret = IncompleteTypeNode::make(static_cast(kind)); + *ret = IncompleteTypeNode::make(static_cast(kind)); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) diff --git a/src/relay/ir/type_functor.cc b/src/relay/ir/type_functor.cc index 100c633a2997..b88d0ee0e3ab 100644 --- a/src/relay/ir/type_functor.cc +++ b/src/relay/ir/type_functor.cc @@ -48,6 +48,29 @@ void TypeVisitor::VisitType_(const TypeRelationNode* op) { } } +void TypeVisitor::VisitType_(const GlobalTypeVarNode* op) { +} + +void TypeVisitor::VisitType_(const TypeCallNode* op) { + this->VisitType(op->func); + for (const Type& t : op->args) { + this->VisitType(t); + } +} + +void TypeVisitor::VisitType_(const TypeDataNode* op) { + this->VisitType(op->header); + for (const auto& v : op->type_vars) { + this->VisitType(v); + } + + for (const auto& c : op->constructors) { + this->VisitType(c->belong_to); + for (const auto& t : c->inputs) { + this->VisitType(t); + } + } +} // Type Mutator. Array TypeMutator::MutateArray(Array arr) { @@ -139,6 +162,24 @@ Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) { } } +Type TypeMutator::VisitType_(const GlobalTypeVarNode* op) { + return GetRef(op); +} + +Type TypeMutator::VisitType_(const TypeCallNode* op) { + Type new_func = VisitType(op->func); + Array new_args = MutateArray(op->args); + if (new_args.same_as(op->args) && new_func.same_as(op->func)) { + return GetRef(op); + } else { + return TypeCallNode::make(new_func, new_args); + } +} + +Type TypeMutator::VisitType_(const TypeDataNode* op) { + return GetRef(op); +} + // Implements bind. class TypeBinder : public TypeMutator { public: diff --git a/src/relay/ir/type_functor.h b/src/relay/ir/type_functor.h index 1be55e78eee6..36f77967c253 100644 --- a/src/relay/ir/type_functor.h +++ b/src/relay/ir/type_functor.h @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ -69,6 +70,10 @@ class TypeFunctor { virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const RefTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const GlobalTypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeDataNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitTypeDefault_(const Node* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->type_key(); throw; // unreachable, written to stop compiler warning @@ -87,6 +92,9 @@ class TypeFunctor { RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode); RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode); RELAY_TYPE_FUNCTOR_DISPATCH(RefTypeNode); + RELAY_TYPE_FUNCTOR_DISPATCH(GlobalTypeVarNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeCallNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeDataNode); return vtable; } }; @@ -103,6 +111,9 @@ class TypeVisitor : public TypeFunctor { void VisitType_(const TupleTypeNode* op) override; void VisitType_(const TypeRelationNode* op) override; void VisitType_(const RefTypeNode* op) override; + void VisitType_(const GlobalTypeVarNode* op) override; + void VisitType_(const TypeCallNode* op) override; + void VisitType_(const TypeDataNode* op) override; }; // Mutator that transform a type to another one. @@ -115,6 +126,9 @@ class TypeMutator : public TypeFunctor { Type VisitType_(const TupleTypeNode* op) override; Type VisitType_(const TypeRelationNode* type_rel) override; Type VisitType_(const RefTypeNode* op) override; + Type VisitType_(const GlobalTypeVarNode* op) override; + Type VisitType_(const TypeCallNode* op) override; + Type VisitType_(const TypeDataNode* op) override; private: Array MutateArray(Array arr); diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index a6298ba448f3..169aef3b6a4a 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -296,6 +296,15 @@ class IndexedForwardGraph::Creator : private ExprVisitor { ExprVisitor::VisitExpr_(op); this->AddNode(op); } + + void VisitExpr_(const MatchNode* op) final { + this->Update(op->data, nullptr, kOpaque); + for (const Clause& c : op->clauses) { + this->Update(c->rhs, nullptr, kOpaque); + } + ExprVisitor::VisitExpr_(op); + this->AddNode(op); + } }; IndexedForwardGraph IndexedForwardGraph::Create( diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index 200f5385a37a..f1e539d71d48 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -14,106 +14,160 @@ * contains a data type such as `int`, `float`, `uint`. */ #include +#include #include "../ir/type_functor.h" namespace tvm { namespace relay { using namespace tvm::runtime; -using Kind = TypeVarNode::Kind; -struct KindChecker : TypeVisitor { - bool valid; +struct KindChecker : TypeFunctor { + const Module& mod; + ErrorReporter err_reporter; - KindChecker() : valid(true) {} + explicit KindChecker(const Module& mod) : mod(mod), err_reporter() {} - // checks if t is an incomplete node of kind k or a type param of kind k - bool MatchKind(const Type& t, Kind k) { - if (const IncompleteTypeNode* tv = t.as()) { - return tv->kind == k; - } + void ReportFatalError(const Error& err) { + this->err_reporter.Report(err); + this->err_reporter.RenderErrors(mod); + } - if (const TypeVarNode* tp = t.as()) { - return tp->kind == k; + void CheckKindMatches(const Type& t, const Type& outer, + Kind expected, const std::string& description) { + Kind k = this->VisitType(t); + if (k != expected) { + ReportFatalError(RELAY_ERROR("Incorrect kind for a " << description + << ". Type " << t << " inside " << outer + << " is of kind " << k + << " but was expected to be " + << expected)); } + } - return false; + Kind VisitType_(const IncompleteTypeNode* op) override { + return op->kind; } - bool IsTypeKind(const Type& t) { - if (MatchKind(t, Kind::kType)) { - return true; - } + Kind VisitType_(const TypeVarNode* op) override { + return op->kind; + } + + Kind VisitType_(const GlobalTypeVarNode* op) override { + return op->kind; + } - return t.as_derived() || t.as() || t.as(); + Kind VisitType_(const TensorTypeNode* op) override { + return Kind::kType; } - void VisitType_(const TupleTypeNode* op) override { + Kind VisitType_(const TupleTypeNode* op) override { // tuples should only contain normal types for (const Type& t : op->fields) { - this->VisitType(t); - valid = valid && IsTypeKind(t); - if (!valid) { - return; - } + CheckKindMatches(t, GetRef(op), Kind::kType, + "tuple member"); } + return Kind::kType; } - void VisitType_(const FuncTypeNode* op) override { + Kind VisitType_(const FuncTypeNode* op) override { // Func types should only take normal types for arguments // and only return a normal type. They should also have // well-formed constraints + FuncType ft = GetRef(op); for (const Type& t : op->arg_types) { - this->VisitType(t); - valid = valid && IsTypeKind(t); - if (!valid) { - return; - } + CheckKindMatches(t, ft, Kind::kType, "function type parameter"); } + CheckKindMatches(ft->ret_type, ft, Kind::kType, "function return type"); + for (const TypeConstraint& tc : op->type_constraints) { - this->VisitType(tc); - if (!valid) { - return; - } + CheckKindMatches(tc, ft, Kind::kConstraint, "function type constraint"); } - this->VisitType(op->ret_type); - valid = valid && IsTypeKind(op->ret_type); + return Kind::kType; } - void VisitType_(const RefTypeNode* op) override { - // tuples should only contain normal types - this->VisitType(op->value); - valid = valid && IsTypeKind(op->value); + Kind VisitType_(const RefTypeNode* op) override { + // ref types should only contain normal types + RefType rt = GetRef(op); + CheckKindMatches(op->value, rt, Kind::kType, "ref contents"); + return Kind::kType; } - void VisitType_(const TypeRelationNode* op) override { + Kind VisitType_(const TypeRelationNode* op) override { // arguments to type relation should be normal types for (const Type& t : op->args) { - this->VisitType(t); - valid = valid && IsTypeKind(t); - if (!valid) { - return; + CheckKindMatches(t, GetRef(op), Kind::kType, + "argument to type relation"); + } + return Kind::kConstraint; + } + + Kind VisitType_(const TypeCallNode* op) override { + // type call func should be a global type var, args should be type + TypeCall tc = GetRef(op); + const auto* gtv = op->func.as(); + if (gtv == nullptr) { + ReportFatalError(RELAY_ERROR("The callee in " << tc + << " is not a global type var, but is " << op->func)); + } + + CheckKindMatches(op->func, tc, Kind::kAdtHandle, "type call function"); + + for (const Type& t : op->args) { + CheckKindMatches(t, tc, Kind::kType, "type call argument"); + } + + // finally we need to check the module to check the number of type params + auto var = GetRef(gtv); + auto data = mod->LookupDef(var); + if (data->type_vars.size() != op->args.size()) { + ReportFatalError(RELAY_ERROR("Expected " << data->type_vars.size() << "arguments for " << tc + << "; got " << op->args.size())); + } + return Kind::kType; + } + + Kind VisitType_(const TypeDataNode* op) override { + // Constructors can reference the header var, but no other GlobalTypeVars. + // In theory, a TypeData could be nested, so the header scope + // should be tracked recursively, but it is unclear that we need + // to support it. + TypeData td = GetRef(op); + CheckKindMatches(op->header, td, Kind::kAdtHandle, "type data header"); + + for (const auto& var : op->type_vars) { + CheckKindMatches(var, td, Kind::kType, "ADT type var"); + } + + for (const auto& con : op->constructors) { + if (!con->belong_to.same_as(op->header)) { + ReportFatalError(RELAY_ERROR(con << " has header " << con->belong_to + << " but " << op << "has header " << op->header)); + } + + for (const Type& t : con->inputs) { + CheckKindMatches(t, td, Kind::kType, "ADT constructor input"); } } + return Kind::kTypeData; } - bool Check(const Type& t) { - this->VisitType(t); - return valid; + Kind Check(const Type& t) { + return this->VisitType(t); } }; -bool KindCheck(const Type& t, const Module& mod) { - KindChecker kc; +Kind KindCheck(const Type& t, const Module& mod) { + KindChecker kc(mod); return kc.Check(t); } TVM_REGISTER_API("relay._ir_pass.check_kind") .set_body([](TVMArgs args, TVMRetValue* ret) { if (args.size() == 1) { - *ret = KindCheck(args[0], ModuleNode::make({})); + *ret = KindCheck(args[0], ModuleNode::make({}, {})); } else { *ret = KindCheck(args[0], args[1]); } diff --git a/src/relay/pass/let_list.h b/src/relay/pass/let_list.h index 2fecc8ba3727..3afbcba96ae6 100644 --- a/src/relay/pass/let_list.h +++ b/src/relay/pass/let_list.h @@ -62,7 +62,7 @@ class LetList { * \return a Var that hold the inserted expr. */ Var Push(Expr expr) { - return Push(IncompleteTypeNode::make(TypeVarNode::kType), expr); + return Push(IncompleteTypeNode::make(Kind::kType), expr); } /*! diff --git a/src/relay/pass/to_anf.cc b/src/relay/pass/to_anf.cc index a724d5f2e855..6d65fe449fb0 100644 --- a/src/relay/pass/to_anf.cc +++ b/src/relay/pass/to_anf.cc @@ -274,7 +274,7 @@ class Fill : ExprFunctor { } Expr VisitExpr(const Expr& e) { - Var v = VarNode::make(std::string("x"), IncompleteTypeNode::make(TypeVarNode::kType)); + Var v = VarNode::make(std::string("x"), IncompleteTypeNode::make(Kind::kType)); return this->VisitExpr(e, v); } diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 10ba3b127bbf..fa3cea610c68 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -22,7 +22,9 @@ #include #include +#include #include +#include "./pass_util.h" #include "type_solver.h" #include "../ir/type_functor.h" @@ -79,7 +81,8 @@ struct ResolvedTypeInfo { // - Solve the constraints (solver_.Solve) // - Recreate expression with the resolved checked_type (Resolver.VisitExpr) // -class TypeInferencer : private ExprFunctor { +class TypeInferencer : private ExprFunctor, + private PatternFunctor { public: // constructors @@ -107,10 +110,6 @@ class TypeInferencer : private ExprFunctor { // type inferencer will populate it up std::unordered_map type_map_; - // used to ensure we don't have free type vars hanging around - // (a temporary measure until we have proper generalization implemented) - Map instantiation_map_; - // The solver used by the inferencer. TypeSolver solver_; // relation function @@ -119,7 +118,7 @@ class TypeInferencer : private ExprFunctor { // Perform unification on two types and report the error at the expression // or the span of the expression. - Type Unify(const Type& t1, const Type& t2, const Expr& expr) { + Type Unify(const Type& t1, const Type& t2, const NodeRef& expr) { // TODO(tqchen, jroesch): propagate span to solver try { return solver_.Unify(t1, t2, expr); @@ -135,31 +134,6 @@ class TypeInferencer : private ExprFunctor { } } - // Substitutes every type var in t with a corresponding incomplete type. - // This is a temporary measure to ensure type vars behave until - // generalization is properly implemented. - Type Instantiate(const Type &t) { - if (!t.defined()) { - return t; - } - auto* ft = t.as(); - if (ft == nullptr) { - return Bind(t, instantiation_map_); - } - - for (auto type_param : ft->type_params) { - instantiation_map_.Set(type_param, IncompleteTypeNode::make(TypeVarNode::Kind::kType)); - } - - Type ret_type = ft->ret_type; - if (!ret_type.defined()) { - ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); - } - - auto strip_tvs = FuncTypeNode::make(ft->arg_types, ret_type, {}, ft->type_constraints); - return Bind(strip_tvs, instantiation_map_); - } - // Lazily get type for expr // expression, we will populate it now, and return the result. Type GetType(const Expr &expr) { @@ -167,13 +141,14 @@ class TypeInferencer : private ExprFunctor { if (it != type_map_.end() && it->second.checked_type.defined()) { return it->second.checked_type; } - Type ret = Instantiate(this->VisitExpr(expr)); + Type ret = this->VisitExpr(expr); + KindCheck(ret, mod_); ResolvedTypeInfo& rti = type_map_[expr]; rti.checked_type = ret; return ret; } - void ReportFatalError(const Expr& expr, const Error& err) { + void ReportFatalError(const NodeRef& expr, const Error& err) { CHECK(this->current_func_.defined()); this->err_reporter.ReportAt(this->current_func_, expr, err); this->err_reporter.RenderErrors(this->mod_); @@ -184,7 +159,7 @@ class TypeInferencer : private ExprFunctor { if (op->type_annotation.defined()) { return op->type_annotation; } else { - return IncompleteTypeNode::make(TypeVarNode::kType); + return IncompleteTypeNode::make(Kind::kType); } } @@ -219,7 +194,7 @@ class TypeInferencer : private ExprFunctor { EnvFunc::Get("tvm.relay.type_relation.TupleGetItem").node_); } Type tuple_type = GetType(op->tuple); - Type rtype = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + Type rtype = IncompleteTypeNode::make(Kind::kType); auto attrs = make_node(); attrs->index = op->index; solver_.AddConstraint(TypeRelationNode::make( @@ -227,6 +202,70 @@ class TypeInferencer : private ExprFunctor { return rtype; } + void VisitPattern_(const PatternConstructorNode* con, const Type& t) { + CHECK(mod_.defined()) + << "Cannot do type inference without a environment:" + << con->constructor->name_hint; + TypeData td = mod_->type_definitions.at(con->constructor->belong_to); + auto pc = GetRef(con); + + // we can expect a certain number of arguments + Array unknown_args; + for (size_t i = 0; i < td->type_vars.size(); i++) { + unknown_args.push_back(IncompleteTypeNode::make(Kind::kType)); + } + Type expected = TypeCallNode::make(con->constructor->belong_to, unknown_args); + Type unified = Unify(t, expected, GetRef(con)); + + auto* tc = unified.as(); + if (!tc) { + this->ReportFatalError(pc, RELAY_ERROR("Expected a type call, got " << unified)); + } + if (td->header != tc->func) { + this->ReportFatalError(pc, RELAY_ERROR("ADT headers must match, but we have " + << td->header << " and " << tc->func)); + } + if (td->type_vars.size() != tc->args.size()) { + this->ReportFatalError(pc, RELAY_ERROR("The number of type args must match" + << "the number of type vars in the type data: " + << td->type_vars.size() << " != " << tc->args.size())); + } + std::unordered_map type_var_map_; + for (size_t i = 0; i < td->type_vars.size(); ++i) { + type_var_map_[td->type_vars[i]] = tc->args[i]; + } + CHECK(con->constructor->inputs.size() == con->patterns.size()) << "not enough pattern"; + if (con->constructor->inputs.size() != con->patterns.size()) { + this->ReportFatalError(pc, RELAY_ERROR("Not enough inputs for the constructor; " + << "expected " << con->constructor->inputs.size() + << ", got " << con->patterns.size())); + } + for (size_t i = 0; i < con->constructor->inputs.size(); ++i) { + VisitPattern(con->patterns[i], Bind(con->constructor->inputs[i], type_var_map_)); + } + } + + void VisitPattern_(const PatternVarNode* pv, const Type& t) { + Type vt = GetType(pv->var); + Unify(vt, t, pv->span); + } + + void VisitPattern_(const PatternWildcardNode* wc, const Type& t) { } + + Type VisitExpr_(const MatchNode* op) final { + Type dtype = GetType(op->data); + for (const auto& c : op->clauses) { + VisitPattern(c->lhs, dtype); + } + Type rtype = IncompleteTypeNode::make(Kind::kType); + for (const auto& c : op->clauses) { + rtype = this->Unify(rtype, + GetType(c->rhs), + op->span); + } + return rtype; + } + Type VisitExpr_(const OpNode* op) final { return op->op_type; } @@ -235,7 +274,7 @@ class TypeInferencer : private ExprFunctor { // if the definition is a function literal, permit recursion bool is_functional_literal = let->value.as() != nullptr; if (is_functional_literal) { - type_map_[let->var].checked_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + type_map_[let->var].checked_type = IncompleteTypeNode::make(Kind::kType); } Type vtype = GetType(let->value); @@ -276,7 +315,7 @@ class TypeInferencer : private ExprFunctor { for (size_t i = 0; i < op->type_params.size(); ++i) { if (!op->type_params[i].same_as(rel->args[i])) return Type(); } - Type rtype = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + Type rtype = IncompleteTypeNode::make(Kind::kType); arg_types.push_back(rtype); // we can do simple replacement here solver_.AddConstraint(TypeRelationNode::make( @@ -302,7 +341,7 @@ class TypeInferencer : private ExprFunctor { // This is a temporary work around to check recursive functions whose // return type is not yet known. if (!ret_type.defined()) { - ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + ret_type = IncompleteTypeNode::make(Kind::kType); } Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, @@ -338,7 +377,7 @@ class TypeInferencer : private ExprFunctor { // incomplete type => it must be a function taking the arg types // with an unknown return type if (inc_ty_node != nullptr) { - Type ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + Type ret_type = IncompleteTypeNode::make(Kind::kType); Type func_type = FuncTypeNode::make(arg_types, ret_type, {}, {}); Type unified = this->Unify(ftype, func_type, GetRef(call)); fn_ty_node = unified.as(); @@ -347,7 +386,7 @@ class TypeInferencer : private ExprFunctor { Array type_args = call->type_args; if (type_args.size() == 0) { for (size_t i = 0; i < fn_ty_node->type_params.size(); i++) { - type_args.push_back(IncompleteTypeNode::make(TypeVarNode::Kind::kType)); + type_args.push_back(IncompleteTypeNode::make(Kind::kType)); } } @@ -428,6 +467,7 @@ class TypeInferencer : private ExprFunctor { if (f->ret_type.defined()) { rtype = this->Unify(f->ret_type, rtype, GetRef(f)); } + CHECK(rtype.defined()); auto ret = FuncTypeNode::make(arg_types, rtype, f->type_params, {}); return solver_.Resolve(ret); } @@ -437,20 +477,33 @@ class TypeInferencer : private ExprFunctor { } Type VisitExpr_(const RefReadNode* op) final { - Type it = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + Type it = IncompleteTypeNode::make(Kind::kType); this->Unify(GetType(op->ref), RefTypeNode::make(it), GetRef(op)); return it; } Type VisitExpr_(const RefWriteNode* op) final { - Type it = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + Type it = IncompleteTypeNode::make(Kind::kType); this->Unify(GetType(op->ref), RefTypeNode::make(it), GetRef(op)); this->Unify(GetType(op->value), it, GetRef(op)); return TupleTypeNode::make({}); } + + Type VisitExpr_(const ConstructorNode* c) final { + CHECK(mod_.defined()) + << "Cannot do type inference without a environment:" + << c->name_hint; + TypeData td = mod_->LookupDef(c->belong_to); + std::vector types; + for (const auto & t : td->type_vars) { + types.push_back(t); + } + return FuncTypeNode::make(c->inputs, TypeCallNode::make(c->belong_to, types), + td->type_vars, {}); + } }; -class TypeInferencer::Resolver : public ExprMutator { +class TypeInferencer::Resolver : public ExprMutator, PatternMutator { public: Resolver(const std::unordered_map& tmap, TypeSolver* solver) @@ -458,7 +511,7 @@ class TypeInferencer::Resolver : public ExprMutator { } Expr VisitExpr_(const VarNode* op) final { - return AttachCheckedType(op); + return VisitVar(GetRef(op)); } Expr VisitExpr_(const ConstantNode* op) final { @@ -509,6 +562,25 @@ class TypeInferencer::Resolver : public ExprMutator { return AttachCheckedType(op); } + Expr VisitExpr_(const ConstructorNode* op) final { + return GetRef(op); + } + + Expr VisitExpr_(const MatchNode* op) final { + return AttachCheckedType(op); + } + + Pattern VisitPattern(const Pattern& p) final { + return PatternMutator::VisitPattern(p); + } + + Var VisitVar(const Var& v) final { + if (vmap_.count(v) == 0) { + vmap_[v] = GetRef(AttachCheckedType(v.as()).as()); + } + return vmap_.at(v); + } + // attach checked type to the mutated node. template Expr AttachCheckedType(const T* op) { @@ -601,6 +673,7 @@ class TypeInferencer::Resolver : public ExprMutator { } private: + std::unordered_map vmap_; const std::unordered_map& tmap_; TypeSolver* solver_; // whether attach the checked type as type_annotation @@ -625,6 +698,19 @@ Expr TypeInferencer::Infer(Expr expr) { return resolved_expr; } +struct AllCheckTypePopulated : ExprVisitor { + void VisitExpr(const Expr& e) { + if (e.as()) { return; } + if (e.as()) { return; } + if (e.as()) { return; } + CHECK(e->checked_type_.defined()) << "Expression: " << e; + return ExprVisitor::VisitExpr(e); + } +}; + +void EnsureCheckedType(const Expr& e) { + AllCheckTypePopulated().VisitExpr(e); +} Expr InferType(const Expr& expr, const Module& mod_ref) { if (!mod_ref.defined()) { @@ -645,6 +731,10 @@ Expr InferType(const Expr& expr, const Module& mod_ref) { } else { auto e = TypeInferencer(mod_ref, mod_ref->entry_func).Infer(expr); CHECK(WellFormed(e)); + auto free_tvars = FreeTypeVars(e, mod_ref); + CHECK(free_tvars.size() == 0) + << "Found unbound type variables in " << e << ": " << free_tvars; + EnsureCheckedType(e); return e; } } @@ -658,6 +748,9 @@ Function InferType(const Function& func, Expr func_ret = TypeInferencer(mod, var).Infer(func_copy); mod->Remove(var); CHECK(WellFormed(func_ret)); + auto free_tvars = FreeTypeVars(func_ret, mod); + CHECK(free_tvars.size() == 0) + << "Found unbound type variables in " << func << ": " << free_tvars; return Downcast(func_ret); } diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index fcd39e791339..fd15c91e79f7 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -189,6 +189,20 @@ class TypeSolver::Unifier : public TypeFunctor { return RefTypeNode::make(Unify(op->value, rtn->value)); } + Type VisitType_(const TypeCallNode* op, const Type& tn) override { + const auto* tcn = tn.as(); + if (!tcn || tcn->args.size() != op->args.size()) { + return Type(); + } + + Type func = Unify(op->func, tcn->func); + tvm::Array args; + for (size_t i = 0; i < op->args.size(); i++) { + args.push_back(Unify(op->args[i], tcn->args[i])); + } + return TypeCallNode::make(func, args); + } + private: TypeSolver* solver_; }; @@ -266,6 +280,16 @@ class TypeSolver::Propagator : public TypeFunctor { } } + void VisitType_(const TypeCallNode* op) override { + TypeCall tc = GetRef(op); + UpdateRelSet(tc); + + Propagate(tc->func); + for (auto arg : tc->args) { + Propagate(arg); + } + } + private: TypeSolver* solver_; const std::unordered_set* rels_; @@ -494,7 +518,7 @@ TVM_REGISTER_API("relay._ir_pass._test_type_solver") } else if (name == "AddConstraint") { return TypedPackedFunc([solver](TypeConstraint c) { Expr e = VarNode::make("dummy_var", - IncompleteTypeNode::make(TypeVarNode::Kind::kType)); + IncompleteTypeNode::make(Kind::kType)); return solver->AddConstraint(c, e); }); } else { diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index 403863c1d757..76fc0aa1a45e 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -7,6 +7,7 @@ */ #include #include +#include #include "../ir/type_functor.h" namespace tvm { @@ -29,7 +30,7 @@ class TypeVarTVisitor : public TypeVisitor { TypeVarTVisitor( InsertionSet* type_vars, InsertionSet* bound_type_vars) - : type_vars_(type_vars), bound_type_vars_(bound_type_vars) { } + : type_vars_(type_vars), bound_type_vars_(bound_type_vars) { } void VisitType_(const TypeVarNode* tp) final { TypeVar var = GetRef(tp); @@ -51,6 +52,8 @@ class TypeVarTVisitor : public TypeVisitor { class TypeVarEVisitor : private ExprVisitor { public: + explicit TypeVarEVisitor(const Module& mod) : mod_(mod) {} + Array CollectFree() { Array ret; for (const auto& v : type_vars_.data) { @@ -115,6 +118,16 @@ class TypeVarEVisitor : private ExprVisitor { ExprVisitor::VisitExpr_(f); } + void VisitExpr_(const ConstructorNode* cn) final { + // for constructors, type vars will be bound in the module + auto data = mod_->LookupDef(cn->belong_to); + for (const auto& tv : data->type_vars) { + type_vars_.Insert(tv); + bound_type_vars_.Insert(tv); + } + ExprVisitor::VisitExpr_(cn); + } + void VisitType(const Type& t) final { TypeVarTVisitor(&type_vars_, &bound_type_vars_) .VisitType(t); @@ -123,9 +136,10 @@ class TypeVarEVisitor : private ExprVisitor { private: InsertionSet type_vars_; InsertionSet bound_type_vars_; + const Module& mod_; }; -class VarVisitor : protected ExprVisitor { +class VarVisitor : protected ExprVisitor, protected PatternVisitor { public: Array Free(const Expr& expr) { this->VisitExpr(expr); @@ -178,33 +192,41 @@ class VarVisitor : protected ExprVisitor { VisitExpr(op->body); } + void VisitPattern(const Pattern& p) final { + PatternVisitor::VisitPattern(p); + } + + void VisitPattern_(const PatternVarNode* op) final { + MarkBounded(op->var); + } + private: InsertionSet vars_; InsertionSet bound_vars_; }; -tvm::Array FreeTypeVars(const Expr& expr) { - return TypeVarEVisitor().Free(expr); +tvm::Array FreeTypeVars(const Expr& expr, const Module& mod) { + return TypeVarEVisitor(mod).Free(expr); } -tvm::Array FreeTypeVars(const Type& type) { - return TypeVarEVisitor().Free(type); +tvm::Array FreeTypeVars(const Type& type, const Module& mod) { + return TypeVarEVisitor(mod).Free(type); } -tvm::Array BoundTypeVars(const Expr& expr) { - return TypeVarEVisitor().Bound(expr); +tvm::Array BoundTypeVars(const Expr& expr, const Module& mod) { + return TypeVarEVisitor(mod).Bound(expr); } -tvm::Array BoundTypeVars(const Type& type) { - return TypeVarEVisitor().Bound(type); +tvm::Array BoundTypeVars(const Type& type, const Module& mod) { + return TypeVarEVisitor(mod).Bound(type); } -tvm::Array AllTypeVars(const Expr& expr) { - return TypeVarEVisitor().All(expr); +tvm::Array AllTypeVars(const Expr& expr, const Module& mod) { + return TypeVarEVisitor(mod).All(expr); } -tvm::Array AllTypeVars(const Type& type) { - return TypeVarEVisitor().All(type); +tvm::Array AllTypeVars(const Type& type, const Module& mod) { + return TypeVarEVisitor(mod).All(type); } tvm::Array FreeVars(const Expr& expr) { @@ -237,30 +259,33 @@ TVM_REGISTER_API("relay._ir_pass.all_vars") TVM_REGISTER_API("relay._ir_pass.free_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef x = args[0]; + Module mod = args[1]; if (x.as_derived()) { - *ret = FreeTypeVars(Downcast(x)); + *ret = FreeTypeVars(Downcast(x), mod); } else { - *ret = FreeTypeVars(Downcast(x)); + *ret = FreeTypeVars(Downcast(x), mod); } }); TVM_REGISTER_API("relay._ir_pass.bound_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef x = args[0]; + Module mod = args[1]; if (x.as_derived()) { - *ret = BoundTypeVars(Downcast(x)); + *ret = BoundTypeVars(Downcast(x), mod); } else { - *ret = BoundTypeVars(Downcast(x)); + *ret = BoundTypeVars(Downcast(x), mod); } }); TVM_REGISTER_API("relay._ir_pass.all_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef x = args[0]; + Module mod = args[1]; if (x.as_derived()) { - *ret = AllTypeVars(Downcast(x)); + *ret = AllTypeVars(Downcast(x), mod); } else { - *ret = AllTypeVars(Downcast(x)); + *ret = AllTypeVars(Downcast(x), mod); } }); diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc index 50aed4c57338..62f1d914b510 100644 --- a/tests/cpp/relay_pass_type_infer_test.cc +++ b/tests/cpp/relay_pass_type_infer_test.cc @@ -13,7 +13,10 @@ TEST(Relay, SelfReference) { auto y = relay::VarNode::make("y", tensor_type); auto call = relay::CallNode::make(f, Array{ y }); auto fx = relay::FunctionNode::make(tvm::Array{ y }, call, relay::Type(), {}); - auto type_fx = relay::InferType(fx, relay::ModuleNode::make(Map{})); + auto empty_module = + relay::ModuleNode::make(Map{}, + Map{}); + auto type_fx = relay::InferType(fx, empty_module); auto expected = relay::FuncTypeNode::make(tvm::Array{ tensor_type }, tensor_type, {}, {}); CHECK(AlphaEqual(type_fx->checked_type(), expected)); diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py new file mode 100644 index 000000000000..5acae6c70295 --- /dev/null +++ b/tests/python/relay/test_adt.py @@ -0,0 +1,600 @@ +import tvm +from tvm import relay +from tvm.relay.ir_pass import infer_type +from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue +from tvm.relay import testing, create_executor +from tvm.relay.prelude import Prelude + +mod = relay.Module() +p = Prelude(mod) +ctx = tvm.context("llvm", 0) +intrp = create_executor(mod=mod, ctx=ctx, target="llvm") + +z = p.z +s = p.s +nat = p.nat +double = p.double +add = p.add + +optional = p.optional +some = p.some +none = p.none + +nil = p.nil +cons = p.cons +l = p.l +length = p.length +map = p.map +foldl = p.foldl +foldr = p.foldr +sum = p.sum + +concat = p.concat +filter = p.filter +zip = p.zip +rev = p.rev +unfoldl = p.unfoldl +unfoldr = p.unfoldr +map_accumr = p.map_accumr +map_accuml = p.map_accuml + +tree = p.tree +rose = p.rose +tmap = p.tmap +size = p.size + +# this is an example of using the adt value in python side +def count(n): + assert isinstance(n, ConstructorValue) + if n.constructor.name_hint == 's': + return 1 + count(n.fields[0]) + else: + assert n.constructor.name_hint == 'z' + return 0 + +# this is an example of creating the adt value in python side +def make_nat(n): + if n != 0: + return ConstructorValue(s, [make_nat(n - 1)], []) + else: + return ConstructorValue(z, [], []) + +def build_nat(n): + assert n >= 0 + ret = z() + while n > 0: + ret = s(ret) + n = n - 1 + return ret + +def to_list(l): + assert isinstance(l, ConstructorValue) + val = l + ret = [] + while True: + if val.constructor.name_hint == 'cons': + ret.append(val.fields[0]) + val = val.fields[1] + else: + assert val.constructor.name_hint == 'nil' + break + return ret + +def tree_to_dict(t): + assert isinstance(t, ConstructorValue) + ret = {} + assert t.constructor.name_hint == 'rose' + ret['member'] = t.fields[0] + ret['children'] = [] + for subtree in to_list(t.fields[1]): + l = tree_to_dict(subtree) + ret['children'].append(l) + return ret + +def test_nat_value(): + assert count(make_nat(10)) == 10 + + +def test_nat_constructor(): + assert relay.ir_pass.infer_type(z(), mod).checked_type == nat() + assert relay.ir_pass.infer_type(s(z()), mod).checked_type == nat() + + +def test_double(): + assert mod[double].checked_type == relay.FuncType([nat()], nat()) + res = intrp.evaluate(double(s(z()))) + assert count(res) == 2 + + +def test_add(): + assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat()) + res = intrp.evaluate(add(s(z()), s(z()))) + assert count(res) == 2 + + +def test_list_constructor(): + a = relay.TypeVar("a") + assert relay.ir_pass.infer_type(cons(z(), nil()), mod).checked_type == l(nat()) + + +def test_length(): + a = relay.TypeVar("a") + assert mod[length].checked_type == relay.FuncType([l(a)], nat(), [a]) + res = intrp.evaluate(length(cons(z(), cons(z(), cons(z(), nil()))))) + assert count(res) == 3 + + +def test_map(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + lhs = mod[map].checked_type + rhs = relay.FuncType([relay.FuncType([a], b), l(a)], l(b), [a, b]) + assert lhs == rhs + + x = relay.Var("x") + add_one = relay.Function([x], s(x)) + res = intrp.evaluate(map(add_one, cons(z(), cons(z(), nil())))) + ones = to_list(res) + assert len(ones) == 2 + assert count(ones[0]) == 1 and count(ones[1]) == 1 + + +def test_foldl(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + lhs = mod[foldl].checked_type + rhs = relay.FuncType([relay.FuncType([a, b], a), a, l(b)], a, [a, b]) + assert lhs == rhs + + x = relay.Var("x") + y = relay.Var("y") + rev_dup = relay.Function([y, x], cons(x, cons(x, y))) + res = intrp.evaluate(foldl(rev_dup, nil(), + cons(build_nat(1), + cons(build_nat(2), + cons(build_nat(3), nil()))))) + reversed = to_list(res) + assert len(reversed) == 6 + assert count(reversed[0]) == 3 and count(reversed[1]) == 3 + assert count(reversed[2]) == 2 and count(reversed[3]) == 2 + assert count(reversed[4]) == 1 and count(reversed[5]) == 1 + + +def test_foldr(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + lhs = mod[foldr].checked_type + rhs = relay.FuncType([relay.FuncType([a, b], b), b, l(a)], b, [a, b]) + assert lhs == rhs + + x = relay.Var("x") + y = relay.Var("y") + identity = relay.Function([x, y], cons(x, y)) + res = intrp.evaluate(foldr(identity, nil(), + cons(build_nat(1), + cons(build_nat(2), + cons(build_nat(3), nil()))))) + same = to_list(res) + assert len(same) == 3 + assert count(same[0]) == 1 and count(same[1]) == 2 and count(same[2]) == 3 + + +def test_sum(): + assert mod[sum].checked_type == relay.FuncType([l(nat())], nat()) + res = intrp.evaluate(sum(cons(build_nat(1), cons(build_nat(2), nil())))) + assert count(res) == 3 + + +def test_concat(): + a = relay.TypeVar("a") + assert mod[concat].checked_type == relay.FuncType([l(a), l(a)], l(a), [a]) + + l1 = cons(build_nat(1), cons(build_nat(2), nil())) + l2 = cons(build_nat(3), cons(build_nat(4), nil())) + res = intrp.evaluate(concat(l1, l2)) + + catted = to_list(res) + assert len(catted) == 4 + assert count(catted[0]) == 1 + assert count(catted[1]) == 2 + assert count(catted[2]) == 3 + assert count(catted[3]) == 4 + + +def test_filter(): + a = relay.TypeVar("a") + expected_type = relay.FuncType([ + relay.FuncType([a], relay.scalar_type("bool")), l(a) + ], l(a), [a]) + assert mod[filter].checked_type == expected_type + + x = relay.Var("x", nat()) + greater_than_one = relay.Function( + [x], + relay.Match(x, [ + relay.Clause( + relay.PatternConstructor(s, [ + relay.PatternConstructor( + s, [relay.PatternWildcard()]) + ]), + relay.const(True)), + relay.Clause(relay.PatternWildcard(), relay.const(False)) + ])) + res = intrp.evaluate( + filter(greater_than_one, + cons(build_nat(1), + cons(build_nat(1), + cons(build_nat(3), + cons(build_nat(1), + cons(build_nat(5), + cons(build_nat(1), + nil())))))))) + filtered = to_list(res) + assert len(filtered) == 2 + assert count(filtered[0]) == 3 + assert count(filtered[1]) == 5 + + +def test_zip(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + expected_type = relay.FuncType([l(a), l(b)], + l(relay.TupleType([a, b])), [a, b]) + assert mod[zip].checked_type == expected_type + + l1 = cons(build_nat(1), cons(build_nat(2), cons(build_nat(3), nil()))) + l2 = cons(nil(), + cons(cons(nil(), nil()), + cons(cons(nil(), cons(nil(), nil())), + nil()))) + + res = intrp.evaluate(zip(l1, l2)) + zipped = to_list(res) + assert len(zipped) == 3 + assert count(zipped[0][0]) == 1 + assert len(to_list(zipped[0][1])) == 0 + assert count(zipped[1][0]) == 2 + assert len(to_list(zipped[1][1])) == 1 + assert count(zipped[2][0]) == 3 + assert len(to_list(zipped[2][1])) == 2 + + # test truncation + l3 = cons(build_nat(4), cons(build_nat(5), nil())) + shorter_res = intrp.evaluate(zip(l3, l2)) + truncated = to_list(shorter_res) + assert len(truncated) == 2 + assert count(truncated[0][0]) == 4 + assert len(to_list(truncated[0][1])) == 0 + assert count(truncated[1][0]) == 5 + assert len(to_list(truncated[1][1])) == 1 + + l4 = cons(nil(), nil()) + shortest_res = intrp.evaluate(zip(l3, l4)) + singleton = to_list(shortest_res) + assert len(singleton) == 1 + assert count(singleton[0][0]) == 4 + assert len(to_list(singleton[0][1])) == 0 + + +def test_rev(): + a = relay.TypeVar("a") + assert mod[rev].checked_type == relay.FuncType([l(a)], l(a), [a]) + + res = intrp.evaluate(rev(cons(build_nat(1), + cons(build_nat(2), + cons(build_nat(3), nil()))))) + reversed = to_list(res) + + assert len(reversed) == 3 + assert count(reversed[0]) == 3 + assert count(reversed[1]) == 2 + assert count(reversed[2]) == 1 + + +def test_unfoldr(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + expected_type = relay.FuncType([ + relay.FuncType([a], optional(relay.TupleType([a, b]))), a], + l(b), [a, b]) + + x = relay.Var("x", nat()) + n = relay.Var("n", nat()) + count_down = relay.Function( + [x], + relay.Match(x, [ + relay.Clause(relay.PatternConstructor( + s, [relay.PatternVar(n)]), + some(relay.Tuple([n, x]))), + relay.Clause(relay.PatternConstructor(z, []), none()) + ])) + + res = intrp.evaluate(unfoldr(count_down, build_nat(3))) + unfolded = to_list(res) + + assert len(unfolded) == 3 + assert count(unfolded[0]) == 3 + assert count(unfolded[1]) == 2 + assert count(unfolded[2]) == 1 + + +def test_unfoldl(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + expected_type = relay.FuncType([ + relay.FuncType([a], optional(relay.TupleType([a, b]))), a], + l(b), [a, b]) + + x = relay.Var("x", nat()) + n = relay.Var("n", nat()) + count_down = relay.Function( + [x], + relay.Match(x, [ + relay.Clause(relay.PatternConstructor( + s, [relay.PatternVar(n)]), + some(relay.Tuple([n, x]))), + relay.Clause(relay.PatternConstructor(z, []), none()) + ])) + + res = intrp.evaluate(unfoldl(count_down, build_nat(3))) + unfolded = to_list(res) + + assert len(unfolded) == 3 + assert count(unfolded[0]) == 1 + assert count(unfolded[1]) == 2 + assert count(unfolded[2]) == 3 + + +def test_map_accumr(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + c = relay.TypeVar("c") + expected_type = relay.FuncType([ + relay.FuncType([a, b], relay.TupleType([a, c])), + a, l(b) + ], relay.TupleType([a, l(c)]), [a, b, c]) + assert mod[map_accumr].checked_type == expected_type + + acc = relay.Var("acc", nat()) + x = relay.Var("x", nat()) + add_acc_to_each = relay.Function([acc, x], + relay.Tuple([add(x, acc), + add(x, acc)])) + + vals = cons(build_nat(1), cons(build_nat(2), cons(build_nat(3), nil()))) + res = intrp.evaluate(map_accumr(add_acc_to_each, z(), vals)) + + sum = count(res[0]) + new_vals = to_list(res[1]) + + assert sum == 6 + assert len(new_vals) == 3 + assert count(new_vals[0]) == 6 + assert count(new_vals[1]) == 5 + assert count(new_vals[2]) == 3 + + +def test_map_accuml(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + c = relay.TypeVar("c") + expected_type = relay.FuncType([ + relay.FuncType([a, b], relay.TupleType([a, c])), + a, l(b) + ], relay.TupleType([a, l(c)]), [a, b, c]) + assert mod[map_accuml].checked_type == expected_type + + acc = relay.Var("acc", nat()) + x = relay.Var("x", nat()) + add_to_acc = relay.Function([acc, x], + relay.Tuple([add(x, acc), x])) + + vals = cons(build_nat(1), cons(build_nat(2), cons(build_nat(3), nil()))) + res = intrp.evaluate(map_accuml(add_to_acc, z(), vals)) + + sum = count(res[0]) + new_vals = to_list(res[1]) + + assert sum == 6 + assert len(new_vals) == 3 + assert count(new_vals[0]) == 3 + assert count(new_vals[1]) == 2 + assert count(new_vals[2]) == 1 + + +def test_optional_matching(): + x = relay.Var('x') + y = relay.Var('y') + v = relay.Var('v') + condense = relay.Function( + [x, y], + relay.Match(x, [ + relay.Clause(relay.PatternConstructor(some, [relay.PatternVar(v)]), cons(v, y)), + relay.Clause(relay.PatternConstructor(none), y) + ])) + + res = intrp.evaluate(foldr(condense, nil(), cons( + some(build_nat(3)), + cons(none(), cons(some(build_nat(1)), nil()))))) + + reduced = to_list(res) + assert len(reduced) == 2 + assert count(reduced[0]) == 3 + assert count(reduced[1]) == 1 + + +def test_tmap(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + lhs = mod[tmap].checked_type + rhs = relay.FuncType([relay.FuncType([a], b), tree(a)], tree(b), [a, b]) + assert lhs == rhs + + x = relay.Var("x") + add_one = relay.Function([x], s(x)) + res = intrp.evaluate(tmap(add_one, + rose(z(), + cons(rose(z(), nil()), + cons(rose(z(), nil()), + nil()))))) + + tree_dict = tree_to_dict(res) + assert count(tree_dict['member']) == 1 + assert len(tree_dict['children']) == 2 + for subtree in tree_dict['children']: + assert count(subtree['member']) == 1 + assert len(subtree['children']) == 0 + + +def test_size(): + a = relay.TypeVar("a") + lhs = mod[size].checked_type + rhs = relay.FuncType([tree(a)], nat(), [a]) + assert lhs == rhs + + root = rose(z(), cons(rose(z(), nil()), + cons(rose(z(), nil()), + nil()))) + t = rose(z(), cons(root, cons(root, cons(root, nil())))) + res = intrp.evaluate(size(t)) + assert count(res) == 10 + + +def test_wildcard_match_solo(): + x = relay.Var('x', nat()) + copy = relay.Function([x], + relay.Match(x, [relay.Clause(relay.PatternWildcard(), x)]), + nat()) + + res = intrp.evaluate(copy(s(s(s(z()))))) + assert count(res) == 3 + + +def test_wildcard_match_order(): + x = relay.Var('x', l(nat())) + y = relay.Var('y') + a = relay.Var('a') + return_zero = relay.Function( + [x], + relay.Match(x, [ + relay.Clause(relay.PatternWildcard(), z()), + relay.Clause( + relay.PatternConstructor( + cons, [relay.PatternVar(y), relay.PatternVar(a)]), + y), + relay.Clause(relay.PatternConstructor(nil), s(z())) + ]), + nat()) + + res = intrp.evaluate(return_zero(cons(s(z()), nil()))) + # wildcard pattern is evaluated first + assert count(res) == 0 + + +def test_nested_matches(): + a = relay.TypeVar('a') + x = relay.Var('x') + y = relay.Var('y') + w = relay.Var('w') + h = relay.Var('h') + t = relay.Var('t') + flatten = relay.GlobalVar('flatten') + + # flatten could be written using a fold, but this way has nested matches + inner_match = relay.Match( + y, [ + relay.Clause(relay.PatternConstructor(nil), flatten(w)), + relay.Clause(relay.PatternConstructor( + cons, [relay.PatternVar(h), relay.PatternVar(t)]), + cons(h, flatten(cons(t, w)))) + ]) + + mod[flatten] = relay.Function( + [x], + relay.Match(x, [ + relay.Clause(relay.PatternConstructor(nil), nil()), + relay.Clause(relay.PatternConstructor( + cons, [relay.PatternVar(y), relay.PatternVar(w)]), + inner_match) + ]), l(a), [a]) + + first_list = cons(build_nat(1), cons(build_nat(2), + cons(build_nat(3), nil()))) + second_list = cons(build_nat(4), cons(build_nat(5), + cons(build_nat(6), nil()))) + final_list = cons(first_list, cons(second_list, nil())) + + res = intrp.evaluate(flatten(final_list)) + + flat = to_list(res) + assert len(flat) == 6 + for i in range(6): + assert count(flat[i]) == i + 1 + + +def test_match_full_var(): + x = relay.Var('x') + v = relay.Var('v') + id_func = relay.Function([x], + relay.Match(x, + [relay.Clause(relay.PatternVar(v), + v)])) + + res1 = intrp.evaluate(id_func(nil())) + res2 = intrp.evaluate(id_func(cons(z(), cons(z(), nil())))) + + empty = to_list(res1) + assert len(empty) == 0 + + zeroes = to_list(res2) + assert len(zeroes) == 2 + assert count(zeroes[0]) == 0 + assert count(zeroes[1]) == 0 + + +def test_nested_pattern_match(): + x = relay.Var('x', l(nat())) + h1 = relay.Var('h1') + h2 = relay.Var('h2') + t = relay.Var('t') + match = relay.Match( + x, + [relay.Clause( + relay.PatternConstructor( + cons, + [relay.PatternVar(h1), + relay.PatternConstructor( + cons, + [relay.PatternVar(h2), relay.PatternVar(t)])]), + h2), + relay.Clause(relay.PatternWildcard(), z()) + ]) + get_second = relay.Function([x], match) + + res = intrp.evaluate(get_second(cons(s(z()), + cons(s(s(z())), + nil())))) + + assert count(res) == 2 + + +if __name__ == "__main__": + test_nat_constructor() + test_double() + test_add() + test_list_constructor() + test_length() + test_map() + test_foldl() + test_foldr() + test_concat() + test_filter() + test_zip() + test_rev() + test_unfoldl() + test_unfoldr() + test_map_accumr() + test_map_accuml() + test_sum() + test_tmap() + test_size() diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index 5158d5c7cc9c..ca86aaa3313e 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -171,6 +171,29 @@ def test_type_relation_alpha_equal(): assert bigger != diff_num_inputs +def test_type_call_alpha_equal(): + h1 = relay.GlobalTypeVar("h1") + h2 = relay.GlobalTypeVar("h2") + t1 = relay.TensorType((1, 2), "float32") + t2 = relay.TensorType((1, 2, 3), "float32") + t3 = relay.TensorType((1, 2, 3, 4), "float32") + t4 = relay.TensorType((), "float32") + + tc = relay.TypeCall(h1, [t1, t2, t3]) + same = relay.TypeCall(h1, [t1, t2, t3]) + + different_func = relay.TypeCall(h2, [t1, t2, t3]) + different_arg = relay.TypeCall(h1, [t1, t2, t4]) + fewer_args = relay.TypeCall(h1, [t1, t2]) + more_args = relay.TypeCall(h1, [t1, t2, t3, t4]) + different_order_args = relay.TypeCall(h1, [t3, t2, t1]) + + assert tc == same + assert tc != different_func + assert tc != fewer_args + assert tc != more_args + assert tc != different_order_args + def test_constant_alpha_equal(): x = relay.const(1) @@ -453,6 +476,79 @@ def test_if_alpha_equal(): assert not alpha_equal(if_sample, different_false) +def test_constructor_alpha_equal(): + # smoke test: it should be pointer equality + mod = relay.Module() + p = relay.prelude.Prelude(mod) + + assert alpha_equal(p.nil, p.nil) + assert alpha_equal(p.cons, p.cons) + assert not alpha_equal(p.nil, p.cons) + + +def test_match_alpha_equal(): + mod = relay.Module() + p = relay.prelude.Prelude(mod) + + x = relay.Var('x') + y = relay.Var('y') + nil_case = relay.Clause(relay.PatternConstructor(p.nil), p.nil()) + cons_case = relay.Clause(relay.PatternConstructor(p.cons, + [relay.PatternVar(x), + relay.PatternVar(y)]), + p.cons(x, y)) + + z = relay.Var('z') + a = relay.Var('a') + equivalent_cons = relay.Clause(relay.PatternConstructor(p.cons, + [relay.PatternVar(z), + relay.PatternVar(a)]), + p.cons(z, a)) + + data = p.cons(p.z(), p.cons(p.z(), p.nil())) + + match = relay.Match(data, [nil_case, cons_case]) + equivalent = relay.Match(data, [nil_case, equivalent_cons]) + empty = relay.Match(data, []) + no_cons = relay.Match(data, [nil_case]) + no_nil = relay.Match(data, [cons_case]) + different_data = relay.Match(p.nil(), [nil_case, cons_case]) + different_order = relay.Match(data, [cons_case, nil_case]) + different_nil = relay.Match(data, [ + relay.Clause(relay.PatternConstructor(p.nil), p.cons(p.nil(), p.nil())), + cons_case + ]) + different_cons = relay.Match(data, [ + nil_case, + relay.Clause(relay.PatternConstructor(p.cons, + [relay.PatternWildcard(), + relay.PatternWildcard()]), + p.nil()) + ]) + another_case = relay.Match(data, [ + nil_case, + cons_case, + relay.Clause(relay.PatternWildcard(), p.nil()) + ]) + wrong_constructors = relay.Match(data, [ + relay.Clause(relay.PatternConstructor(p.z), p.nil()), + relay.Clause(relay.PatternConstructor(p.s, [relay.PatternVar(x)]), + p.cons(x, p.nil())) + ]) + + assert alpha_equal(match, match) + assert alpha_equal(match, equivalent) + assert not alpha_equal(match, no_cons) + assert not alpha_equal(match, no_nil) + assert not alpha_equal(match, empty) + assert not alpha_equal(match, different_data) + assert not alpha_equal(match, different_order) + assert not alpha_equal(match, different_nil) + assert not alpha_equal(match, different_cons) + assert not alpha_equal(match, another_case) + assert not alpha_equal(match, wrong_constructors) + + def test_op_alpha_equal(): # only checks names op1 = relay.op.get("add") @@ -491,6 +587,7 @@ def test_graph_equal(): test_func_type_alpha_equal() test_tuple_type_alpha_equal() test_type_relation_alpha_equal() + test_type_call_alpha_equal() test_constant_alpha_equal() test_global_var_alpha_equal() test_tuple_alpha_equal() @@ -499,6 +596,8 @@ def test_graph_equal(): test_call_alpha_equal() test_let_alpha_equal() test_if_alpha_equal() + test_constructor_alpha_equal() + test_match_alpha_equal() test_op_alpha_equal() test_var_alpha_equal() test_graph_equal() diff --git a/tests/python/relay/test_pass_check_kind.py b/tests/python/relay/test_pass_check_kind.py index 5ead501157c5..4eab59a6edd0 100644 --- a/tests/python/relay/test_pass_check_kind.py +++ b/tests/python/relay/test_pass_check_kind.py @@ -1,6 +1,19 @@ import tvm from tvm import relay from tvm.relay.ir_pass import check_kind +from nose.tools import raises + + +def test_typevar_kind(): + # returns the same kind + tp1 = relay.TypeVar('tp1', relay.Kind.Type) + tp2 = relay.TypeVar('tp2', relay.Kind.Shape) + tp3 = relay.TypeVar('tp3', relay.Kind.Constraint) + + assert check_kind(tp1) == relay.Kind.Type + assert check_kind(tp2) == relay.Kind.Shape + assert check_kind(tp3) == relay.Kind.Constraint + def test_tuple_kind(): # only contain type kinds @@ -10,7 +23,7 @@ def test_tuple_kind(): fields = tvm.convert([tp, tf, tt]) tup_ty = relay.TupleType(fields) - assert check_kind(tup_ty) + assert check_kind(tup_ty) == relay.Kind.Type def test_func_kind(): @@ -30,7 +43,20 @@ def test_func_kind(): ret_type = relay.TupleType(tvm.convert([tp2, tensor_type])) tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints) - assert check_kind(tf) + assert check_kind(tf) == relay.Kind.Type + + +def test_ref_kind(): + # only contain type kinds + tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') + ft = relay.FuncType(tvm.convert([]), tt, tvm.convert([]), tvm.convert([])) + + rt1 = relay.RefType(tt) + assert check_kind(rt1) == relay.Kind.Type + rt2 = relay.RefType(ft) + assert check_kind(rt2) == relay.Kind.Type + rt3 = relay.RefType(relay.TupleType([rt1, rt2])) + assert check_kind(rt3) == relay.Kind.Type def test_relation_kind(): @@ -41,9 +67,35 @@ def test_relation_kind(): args = tvm.convert([tf, tt, tp]) tr = relay.TypeRelation(None, args, 2, None) - assert check_kind(tr) + assert check_kind(tr) == relay.Kind.Constraint + + +def test_global_typevar_kind(): + v1 = relay.GlobalTypeVar('gtv1', relay.Kind.AdtHandle) + v2 = relay.GlobalTypeVar('gtv2', relay.Kind.Type) + + assert check_kind(v1) == relay.Kind.AdtHandle + assert check_kind(v2) == relay.Kind.Type + +def test_typecall_kind(): + gtv = relay.GlobalTypeVar('gtv') + mod = relay.Module() + data = relay.TypeData(gtv, [], []) + mod[gtv] = data + empty_call = relay.TypeCall(gtv, []) + assert check_kind(empty_call, mod) == relay.Kind.Type + + new_mod = relay.Module() + tv = relay.TypeVar('tv') + new_data = relay.TypeData(gtv, [tv], []) + new_mod[gtv] = new_data + call = relay.TypeCall(gtv, [relay.TupleType([])]) + assert check_kind(call, new_mod) == relay.Kind.Type + + +@raises(tvm._ffi.base.TVMError) def test_invalid_tuple_kind(): tp1 = relay.TypeVar('tp1', relay.Kind.Shape) tp2 = relay.TypeVar('tp2', relay.Kind.BaseType) @@ -51,9 +103,10 @@ def test_invalid_tuple_kind(): fields = tvm.convert([tp1, tp2, tp3]) tup_ty = relay.TupleType(fields) - assert not check_kind(tup_ty) + check_kind(tup_ty) +@raises(tvm._ffi.base.TVMError) def test_invalid_func_kind(): tp1 = relay.TypeVar('tp1', relay.Kind.Shape) tp2 = relay.TypeVar('tp2', relay.Kind.BaseType) @@ -65,51 +118,98 @@ def test_invalid_func_kind(): ret_type = tp3 tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints) - assert not check_kind(tf) + check_kind(tf) + +@raises(tvm._ffi.base.TVMError) +def test_invalid_ref_kind(): + tp = relay.TypeVar('tp', relay.Kind.Shape) + rt = relay.RefType(tp) + check_kind(rt) + +@raises(tvm._ffi.base.TVMError) def test_invalid_relation_kind(): tp1 = relay.TypeVar('tp1', relay.Kind.Shape) tp2 = relay.TypeVar('tp2', relay.Kind.BaseType) tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar) args = tvm.convert([tp1, tp2, tp3]) - tr = relay.TypeRelation(None, args, 2, None) - assert not check_kind(tr) + func = tvm.get_env_func("tvm.relay.type_relation.Broadcast") + tr = relay.TypeRelation(func, args, 2, None) + check_kind(tr) + + +@raises(tvm._ffi.base.TVMError) +def test_typecall_invalid_callee(): + # global type var must be an ADT handle + gtv = relay.GlobalTypeVar('v1', relay.Kind.Type) + check_kind(relay.TypeCall(gtv, [])) +@raises(tvm._ffi.base.TVMError) +def test_typecall_invalid_args(): + # args must all be type kind + mod = relay.Module() + gtv = relay.GlobalTypeVar('v1') + data = relay.TypeData(gtv, [], []) + mod[gtv] = data + + check_kind(relay.TypeCall(gtv, [data])) + + +@raises(tvm._ffi.base.TVMError) +def test_typecall_invalid_num_args(): + mod = relay.Module() + gtv = relay.GlobalTypeVar('v1') + tv = relay.TypeVar('tv') + data = relay.TypeData(gtv, [tv], []) + mod[gtv] = data + check_kind(relay.TypeCall(gtv, [])) + + +@raises(tvm._ffi.base.TVMError) def test_func_with_invalid_ret_type(): tp1 = relay.TypeVar('tp1', relay.Kind.Type) tp2 = relay.TypeVar('tp2', relay.Kind.Shape) tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([])) + check_kind(tf) + +@raises(tvm._ffi.base.TVMError) def test_func_with_invalid_arg_types(): tp1 = relay.TypeVar('tp1', relay.Kind.Shape) tp2 = relay.TypeVar('tp2', relay.Kind.Type) tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([])) + check_kind(tf) + +@raises(tvm._ffi.base.TVMError) def test_func_with_invalid_tuple(): tp1 = relay.TypeVar('tp1', relay.Kind.Shape) ret_type = relay.TupleType(tvm.convert([tp1, tp1, tp1])) tf = relay.FuncType(tvm.convert([]), ret_type, tvm.convert([tp1]), tvm.convert([])) - assert not check_kind(tf) + check_kind(tf) +@raises(tvm._ffi.base.TVMError) def test_func_with_invalid_relation(): tp1 = relay.TypeVar('tp1', relay.Kind.Type) tp2 = relay.TypeVar('tp2', relay.Kind.Shape) tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar) - tr = relay.TypeRelation(None, tvm.convert([tp2, tp3]), 1, None) + func = tvm.get_env_func("tvm.relay.type_relation.Identity") + tr = relay.TypeRelation(func, tvm.convert([tp2, tp3]), 1, None) tf = relay.FuncType(tvm.convert([tp1]), tp1, tvm.convert([tp1, tp2, tp3]), tvm.convert([tr])) - assert not check_kind(tf) + check_kind(tf) +@raises(tvm._ffi.base.TVMError) def test_tuple_with_invalid_func(): tensor_type = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') @@ -117,16 +217,23 @@ def test_tuple_with_invalid_func(): tf = relay.FuncType(tvm.convert([]), tp1, tvm.convert([tp1]), tvm.convert([])) tup_ty = relay.TupleType(tvm.convert([tensor_type, tf])) - assert not check_kind(tup_ty) + check_kind(tup_ty) if __name__ == "__main__": test_tuple_kind() test_func_kind() + test_ref_kind() test_relation_kind() + test_global_typevar_kind() + test_typecall_kind() test_invalid_tuple_kind() test_invalid_func_kind() + test_invalid_ref_kind() test_invalid_relation_kind() + test_typecall_invalid_callee() + test_typecall_invalid_args() + test_typecall_invalid_num_args() test_func_with_invalid_ret_type() test_func_with_invalid_arg_types() test_func_with_invalid_tuple() diff --git a/tests/python/relay/test_pass_vars.py b/tests/python/relay/test_pass_vars.py index c8d3d6d14992..afdaddca922a 100644 --- a/tests/python/relay/test_pass_vars.py +++ b/tests/python/relay/test_pass_vars.py @@ -65,6 +65,40 @@ def test_bound_vars(): assert_vars_match(bound_vars(f2), [x, y]) +def test_match_vars(): + mod = relay.Module() + p = relay.prelude.Prelude(mod) + + x = relay.Var('x') + y = relay.Var('y') + z = relay.Var('z') + + match1 = relay.Match(p.nil(), [ + relay.Clause(relay.PatternConstructor(p.nil), z), + relay.Clause(relay.PatternConstructor(p.cons, + [relay.PatternVar(x), + relay.PatternVar(y)]), + p.cons(x, y)) + ]) + + match2 = relay.Match(p.nil(), [ + relay.Clause(relay.PatternConstructor(p.cons, [ + relay.PatternWildcard(), + relay.PatternVar(x) + ]), + y), + relay.Clause(relay.PatternWildcard(), z) + ]) + + assert_vars_match(bound_vars(match1), [x, y]) + assert_vars_match(free_vars(match1), [z]) + assert_vars_match(all_vars(match1), [z, x, y]) + + assert_vars_match(bound_vars(match2), [x]) + assert_vars_match(free_vars(match2), [y, z]) + assert_vars_match(all_vars(match2), [x, y, z]) + + def test_bound_type_vars(): a = relay.TypeVar("a") b = relay.TypeVar("b") @@ -127,7 +161,7 @@ def test_all_type_vars(): x = relay.Var("x", a) y = relay.Var("y", b) z = relay.Var("z", c) - + f1 = relay.Function([x], y, b, [a]) assert_vars_match(all_type_vars(f1), [a, b]) diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index eeefbc6c3051..05f8b8fd22f9 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -17,6 +17,16 @@ def assert_has_type(expr, typ, mod=relay.module.Module({})): checked_type, typ)) +# initializes simple ADT for tests +def initialize_box_adt(mod): + box = relay.GlobalTypeVar('box') + tv = relay.TypeVar('tv') + constructor = relay.Constructor('constructor', [tv], box) + data = relay.TypeData(box, [tv], [constructor]) + mod[box] = data + return (box, constructor) + + def test_monomorphic_let(): "Program: let x = 1; return x" sb = relay.ScopeBuilder() @@ -190,6 +200,69 @@ def test_equal(): assert ft.checked_type == relay.FuncType([relay.scalar_type('int32')], relay.scalar_type('bool')) +def test_constructor_type(): + mod = relay.Module() + box, constructor = initialize_box_adt(mod) + + a = relay.TypeVar('a') + x = relay.Var('x', a) + ct = relay.ir_pass.infer_type( + relay.Function([x], constructor(x), box(a), [a]), mod) + expected = relay.FuncType([a], box(a), [a]) + assert ct.checked_type == expected + + +def test_constructor_call(): + mod = relay.Module() + box, constructor = initialize_box_adt(mod) + + box_unit = constructor(relay.Tuple([])) + box_constant = constructor(relay.const(0, 'float32')) + + ut = relay.ir_pass.infer_type(box_unit, mod) + ct = relay.ir_pass.infer_type(box_constant, mod) + assert ut.checked_type == box(relay.TupleType([])) + assert ct.checked_type == box(relay.TensorType((), 'float32')) + + +def test_adt_match(): + mod = relay.Module() + box, constructor = initialize_box_adt(mod) + + v = relay.Var('v', relay.TensorType((), 'float32')) + match = relay.Match(constructor(relay.const(0, 'float32')), + [relay.Clause( + relay.PatternConstructor(constructor, + [relay.PatternVar(v)]), + relay.Tuple([])), + # redundant but shouldn't matter to typechecking + relay.Clause(relay.PatternWildcard(), + relay.Tuple([]))]) + + mt = relay.ir_pass.infer_type(match, mod) + assert mt.checked_type == relay.TupleType([]) + + +def test_adt_match_type_annotations(): + mod = relay.Module() + box, constructor = initialize_box_adt(mod) + + # the only type annotation is inside the match pattern var + # but that should be enough info + tt = relay.TensorType((2, 2), 'float32') + x = relay.Var('x') + mv = relay.Var('mv', tt) + match = relay.Match(constructor(x), + [relay.Clause( + relay.PatternConstructor(constructor, + [relay.PatternVar(mv)]), + relay.Tuple([]))]) + + func = relay.Function([x], match) + ft = relay.ir_pass.infer_type(func, mod) + assert ft.checked_type == relay.FuncType([tt], relay.TupleType([])) + + if __name__ == "__main__": test_free_expr() test_dual_op() @@ -205,3 +278,6 @@ def test_equal(): test_global_var_recursion() test_equal() test_ref() + test_constructor_type() + test_constructor_call() + test_adt_match() diff --git a/tests/python/relay/test_type_solver.py b/tests/python/relay/test_type_solver.py index 1e2fed0af1f8..8bcd912f841e 100644 --- a/tests/python/relay/test_type_solver.py +++ b/tests/python/relay/test_type_solver.py @@ -62,6 +62,30 @@ def test_unify_tuple(): assert unified == tup2 +def test_unify_global_type_var(): + # should only be able to unify if they're the same + solver = make_solver() + gtv = relay.GlobalTypeVar('gtv') + unified = solver.Unify(gtv, gtv) + assert unified == gtv + + +def test_unify_typecall(): + solver = make_solver() + gtv = relay.GlobalTypeVar('gtv') + + # yeah, typecalls are shaped like tuples so the same + # tests work out + t1 = relay.ty.IncompleteType() + t2 = relay.ty.IncompleteType() + t3 = relay.ty.TensorType((10, 20), "float32") + + tc1 = relay.ty.TypeCall(gtv, [t1, t2]) + tc2 = relay.ty.TypeCall(gtv, [t3, t3]) + unified = solver.Unify(tc1, tc2) + assert unified == tc2 + + def test_unify_functype(): solver = make_solver() t1 = relay.ty.IncompleteType() @@ -205,10 +229,49 @@ def test_bad_recursive_unification(): solver.Unify(t1, relay.ty.TupleType([t1, t1])) +@raises(tvm._ffi.base.TVMError) +def test_unify_invalid_global_typevars(): + solver = make_solver() + gtv1 = relay.GlobalTypeVar('gtv1') + gtv2 = relay.GlobalTypeVar('gtv2') + solver.Unify(gtv1, gtv2) + + +@raises(tvm._ffi.base.TVMError) +def test_incompatible_typecall_var_unification(): + solver = make_solver() + gtv1 = relay.GlobalTypeVar('gtv1') + gtv2 = relay.GlobalTypeVar('gtv2') + + t1 = relay.IncompleteType() + t2 = relay.IncompleteType() + + tc1 = relay.TypeCall(gtv1, [t1]) + tc2 = relay.TypeCall(gtv2, [t2]) + solver.Unify(tc1, tc2) + + +@raises(tvm._ffi.base.TVMError) +def test_incompatible_typecall_args_unification(): + solver = make_solver() + gtv = relay.GlobalTypeVar('gtv1') + t1 = relay.IncompleteType() + t2 = relay.IncompleteType() + + tensor1 = relay.TensorType((1, 2, 3), "float32") + tensor2 = relay.TensorType((2, 3), "float32") + tensor3 = relay.TensorType((3,), "float32") + + tc1 = relay.TypeCall(gtv, [relay.TupleType([t1, t1]), t2]) + tc2 = relay.TypeCall(gtv, [relay.TupleType([tensor1, tensor2]), tensor3]) + solver.Unify(tc1, tc2) + + if __name__ == "__main__": test_bcast() test_backward_solving() test_unify_tuple() + test_unify_typecall() test_unify_functype() test_recursive_unify() test_unify_vars_under_tuples() @@ -216,3 +279,5 @@ def test_bad_recursive_unification(): test_backward_solving_after_child_update() test_incompatible_tuple_unification() test_bad_recursive_unification() + test_incompatible_typecall_var_unification() + test_incompatible_typecall_args_unification() diff --git a/tests/python/relay/test_typecall.py b/tests/python/relay/test_typecall.py new file mode 100644 index 000000000000..6f002c438a52 --- /dev/null +++ b/tests/python/relay/test_typecall.py @@ -0,0 +1,28 @@ +from tvm import relay +from tvm.relay.ir_pass import infer_type + +def test_dup_type(): + a = relay.TypeVar("a") + av = relay.Var("av", a) + make_id = relay.Function([av], relay.Tuple([av, av]), None, [a]) + t = relay.scalar_type("float32") + b = relay.Var("b", t) + assert relay.ir_pass.infer_type(make_id(b)).checked_type == relay.TupleType([t, t]) + + +def test_id_type(): + mod = relay.Module() + id_type = relay.GlobalTypeVar("id") + a = relay.TypeVar("a") + mod[id_type] = relay.TypeData(id_type, [a], []) + + b = relay.TypeVar("b") + make_id = relay.Var("make_id", relay.FuncType([b], id_type(b), [b])) + t = relay.scalar_type("float32") + b = relay.Var("b", t) + assert relay.ir_pass.infer_type(make_id(b), mod).checked_type == id_type(t) + + +if __name__ == "__main__": + test_dup_type() + test_id_type()