diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 1897809f48b8..566d69cc6b0b 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -108,6 +108,17 @@ bool AlphaEqual(const Type& t1, const Type& t2); */ bool WellFormed(const Expr& expr); +/*! \brief Get all bound variables from expression expr. + * + * Bound variables are all variables that are declared in the expr. + * They only have meaning inside that expr, and can only be used in it. + * + * \param expr the expression. + * + * \return List of bound vars, in the PostDFS order in the expression. + */ +tvm::Array BoundVars(const Expr& expr); + /*! \brief Get free type parameters from expression expr. * * Free variables are variables that are not bound by a @@ -119,6 +130,14 @@ bool WellFormed(const Expr& expr); */ tvm::Array FreeVars(const Expr& expr); +/*! \brief Get all variables from expression expr. + * + * \param expr the expression. + * + * \return List of all vars, in the PostDFS order in the expression. + */ +tvm::Array AllVars(const Expr& expr); + /*! \brief Get free TypeVars from expression expr. * * Free type parameters are type parameters that are not bound by a function @@ -130,6 +149,55 @@ tvm::Array FreeVars(const Expr& expr); */ tvm::Array FreeTypeVars(const Expr& expr); +/*! \brief Get free TypeVars from type t. + * + * Free type parameters are type parameters that are not bound by a function + * type in the context. + * + * \param t the type. + * + * \return List of free type vars, in the PostDFS order visited by type. + */ +tvm::Array FreeTypeVars(const Type& t); + +/*! \brief Get all bound type variables from expression expr. + * + * Bound variables are all type variables that are declared in the expr. + * They only have meaning inside that expr, and can only be used in it. + * + * \param expr the expression. + * + * \return List of bound type vars, in the PostDFS order in the expression. + */ +tvm::Array BoundTypeVars(const Expr& expr); + +/*! \brief Get all bound type variables from type t. + * + * Bound variables are all type variables that are declared in the type. + * They only have meaning inside that type, and can only be used in it. + * + * \param t the type + * + * \return List of bound type vars, in the PostDFS order visited by type. + */ +tvm::Array BoundTypeVars(const Type& t); + +/*! \brief Get all type variables in expression expr. + * + * \param expr the expression. + * + * \return List of type vars, in the PostDFS order in the expression. + */ +tvm::Array AllTypeVars(const Expr& expr); + +/*! \brief Get all type variables in type t. + * + * \param t the type. + * + * \return List of type vars, in the PostDFS order visited by type. + */ +tvm::Array AllTypeVars(const Type& t); + /*! \brief Remove expressions which does not effect the program result. * * It will remove let bindings which are not referenced, and branches that will diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 1bec7ccd72d5..d5d5e9261fc7 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -158,6 +158,38 @@ def free_vars(expr): return _ir_pass.free_vars(expr) +def bound_vars(expr): + """Get bound vars from expression expr in post-DFS order. + + Parameters + ---------- + expr: tvm.relay.Expr + The input expression + + Returns + ------- + free : List[tvm.relay.Var] + The list of bound variables in post-DFS order. + """ + return _ir_pass.bound_vars(expr) + + +def all_vars(expr): + """Get all vars from expression expr in post-DFS order. + + Parameters + ---------- + expr: tvm.relay.Expr + The input expression + + Returns + ------- + free : List[tvm.relay.Var] + The list of all variables in post-DFS order. + """ + return _ir_pass.all_vars(expr) + + def free_type_vars(expr): """Get free type variables from expression/type e @@ -168,12 +200,44 @@ def free_type_vars(expr): Returns ------- - free : List[tvm.relay.TypeParam] - The list of free type variables + free : List[tvm.relay.TypeVar] + The list of free type variables in post-DFS order """ return _ir_pass.free_type_vars(expr) +def bound_type_vars(expr): + """Get bound type variables from expression/type e + + Parameters + ---------- + expr: Union[tvm.relay.Expr,tvm.relay.Type] + The input expression/type + + Returns + ------- + free : List[tvm.relay.TypeVar] + The list of bound type variables in post-DFS order + """ + return _ir_pass.bound_type_vars(expr) + + +def all_type_vars(expr): + """Get all type variables from expression/type e + + Parameters + ---------- + expr: Union[tvm.relay.Expr,tvm.relay.Type] + The input expression/type + + Returns + ------- + free : List[tvm.relay.TypeVar] + The list of all type variables in post-DFS order + """ + return _ir_pass.all_type_vars(expr) + + def simplify_inference(expr): """ Simplify the data-flow graph for inference phase. diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc index 601a09b35d1a..251d7153e4e6 100644 --- a/src/relay/pass/gradient.cc +++ b/src/relay/pass/gradient.cc @@ -205,14 +205,25 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) { }); return Pair(res.foward, grad); }); + + // if type annotations are provided, we will construct a ret type; + // otherwise, leave it to be inferred + Type ret_type = Type(); std::vector vt; + bool missing = !f->ret_type.defined(); for (const auto& p : f->params) { + if (missing || !p->type_annotation.defined()) { + missing = true; + break; + } vt.push_back(p->type_annotation); } - return FunctionNode::make(f->params, - body, - TupleTypeNode::make({f->ret_type, TupleTypeNode::make({})}), - {}); + + if (!missing) { + ret_type = TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)}); + } + + return FunctionNode::make(f->params, body, ret_type, {}); } TVM_REGISTER_API("relay._ir_pass.first_order_gradient") diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index ee1b5ab10148..af4cc6607a44 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -56,31 +56,11 @@ bool TupleGetItemRel(const Array& types, return true; } -bool MakeTupleRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(static_cast(num_inputs + 1), types.size()); - for (int i = 0; i < num_inputs; ++i) { - if (types[i].as()) return false; - } - Array fields; - for (int i = 0; i < num_inputs; ++i) { - fields.push_back(types[i]); - } - reporter->Assign(types[num_inputs], TupleTypeNode::make(fields)); - return true; -} - TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs); TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem") .set_body_typed&, int, const Attrs&, const TypeReporter&)>( TupleGetItemRel); -TVM_REGISTER_API("tvm.relay.type_relation.MakeTuple") -.set_body_typed&, int, const Attrs&, const TypeReporter&)>( - MakeTupleRel); - struct ResolvedTypeInfo { explicit ResolvedTypeInfo(Type checked_type, Array type_args) : checked_type(checked_type), type_args(type_args) {} @@ -120,6 +100,10 @@ class TypeInferencer : private ExprFunctor { // type inferencer will populate it up std::unordered_map type_map_; + // used to ensure we don't have free type vars hanging around + // (a temporary measure until we have proper generalization implemented) + Map instantiation_map_; + // The solver used by the inferencer. TypeSolver solver_; // relation function @@ -140,6 +124,32 @@ class TypeInferencer : private ExprFunctor { return Type(); } } + + // Substitutes every type var in t with a corresponding incomplete type. + // This is a temporary measure to ensure type vars behave until + // generalization is properly implemented. + Type Instantiate(const Type &t) { + if (!t.defined()) { + return t; + } + auto* ft = t.as(); + if (ft == nullptr) { + return Bind(t, instantiation_map_); + } + + for (auto type_param : ft->type_params) { + instantiation_map_.Set(type_param, IncompleteTypeNode::make(TypeVarNode::Kind::kType)); + } + + Type ret_type = ft->ret_type; + if (!ret_type.defined()) { + ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + } + + auto strip_tvs = FuncTypeNode::make(ft->arg_types, ret_type, {}, ft->type_constraints); + return Bind(strip_tvs, instantiation_map_); + } + // Lazily get type for expr // will call visit to deduce it if it is not in the type_map_ Type GetType(const Expr &expr) { @@ -147,7 +157,7 @@ class TypeInferencer : private ExprFunctor { if (it != type_map_.end() && it->second.checked_type.defined()) { return it->second.checked_type; } - Type ret = this->VisitExpr(expr); + Type ret = Instantiate(this->VisitExpr(expr)); ResolvedTypeInfo& rti = type_map_[expr]; rti.checked_type = ret; return ret; @@ -175,19 +185,11 @@ class TypeInferencer : private ExprFunctor { } Type VisitExpr_(const TupleNode* op) final { - if (!make_tuple_rel_.defined()) { - make_tuple_rel_ = TypeRelationFn( - EnvFunc::Get("tvm.relay.type_relation.MakeTuple").node_); - } Array types; for (Expr field : op->fields) { types.push_back(GetType(field)); } - Type rtype = IncompleteTypeNode::make(TypeVarNode::Kind::kType); - types.push_back(rtype); - solver_.AddConstraint(TypeRelationNode::make( - make_tuple_rel_, types, op->fields.size(), Attrs())); - return rtype; + return TupleTypeNode::make(types); } Type VisitExpr_(const TupleGetItemNode* op) final { @@ -209,11 +211,17 @@ class TypeInferencer : private ExprFunctor { } Type VisitExpr_(const LetNode* op) final { + // if the definition is a function literal, permit recursion + bool is_functional_literal = op->value.as() != nullptr; + if (is_functional_literal) { + type_map_[op->var].checked_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + } + Type vtype = GetType(op->value); if (op->var->type_annotation.defined()) { vtype = Unify(vtype, op->var->type_annotation, op->span); } - CHECK(!type_map_.count(op->var)); + CHECK(is_functional_literal || !type_map_.count(op->var)); // NOTE: no scoping is necessary because var are unique in program type_map_[op->var].checked_type = vtype; return GetType(op->body); @@ -252,16 +260,14 @@ class TypeInferencer : private ExprFunctor { return rtype; } - // instantiate the function type with fresh - FuncType Instantiate(const FuncTypeNode* fn_ty, Array* ty_args) { + // substitute the type args in the function type + FuncType InstantiateFuncType(const FuncTypeNode* fn_ty, const Array& ty_args) { tvm::Map subst_map; // Build a subsitituion map up from the function type and type arguments. // Eventually allow the type vars to be passed in. - for (auto ty_param : fn_ty->type_params) { - IncompleteType fresh = IncompleteTypeNode::make(ty_param->kind); - subst_map.Set(ty_param, fresh); - ty_args->push_back(fresh); + for (size_t i = 0; i < fn_ty->type_params.size(); i++) { + subst_map.Set(fn_ty->type_params[i], ty_args[i]); } Type ret_type = fn_ty->ret_type; @@ -296,13 +302,32 @@ class TypeInferencer : private ExprFunctor { Type GeneralCall(const CallNode* call, Array arg_types) { Type ftype = GetType(call->op); auto* fn_ty_node = ftype.as(); + auto* inc_ty_node = ftype.as(); + + CHECK(fn_ty_node != nullptr || inc_ty_node != nullptr) + << "only expressions with function types can be called, found " + << ftype << " at " << call->span; + + // incomplete type => it must be a function taking the arg types + // with an unknown return type + if (inc_ty_node != nullptr) { + Type ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + Type func_type = FuncTypeNode::make(arg_types, ret_type, {}, {}); + Type unified = this->Unify(ftype, func_type, call->span); + fn_ty_node = unified.as(); + } - CHECK(fn_ty_node != nullptr) - << "only expressions with function types can be called, found " - << ftype << " at " << call->span; - - Array type_args; - FuncType fn_ty = Instantiate(fn_ty_node, &type_args); + Array type_args = call->type_args; + if (type_args.size() == 0) { + for (size_t i = 0; i < fn_ty_node->type_params.size(); i++) { + type_args.push_back(IncompleteTypeNode::make(TypeVarNode::Kind::kType)); + } + } + CHECK(type_args.size() == fn_ty_node->type_params.size()) + << "Incorrect number of type args in " << call->span << ": " + << "Expected " << fn_ty_node->type_params.size() + << "but got " << type_args.size(); + FuncType fn_ty = InstantiateFuncType(fn_ty_node, type_args); AddTypeArgs(GetRef(call), type_args); @@ -353,26 +378,17 @@ class TypeInferencer : private ExprFunctor { } Type VisitExpr_(const FunctionNode* f) final { + solver_.Solve(); + Array arg_types; for (auto param : f->params) { - GetType(param); + arg_types.push_back(GetType(param)); } Type rtype = GetType(f->body); - // Run solver using the currently known information - solver_.Solve(); - // Trying to resolve - Array arg_types; - for (size_t i = 0; i < f->params.size(); ++i) { - Type atype = solver_.Resolve(GetType(f->params[i])); - CHECK(atype.as() == nullptr) - << "Cannot resolve type of " << i - << "-th parameter of function at" << f->span; - arg_types.push_back(atype); + if (f->ret_type.defined()) { + rtype = this->Unify(f->ret_type, rtype, f->span); } - rtype = solver_.Resolve(rtype); - CHECK(rtype.as() == nullptr) - << "Cannot resolve return type of function at" << f->span; - // do not support constraint lifting for now. - return FuncTypeNode::make(arg_types, rtype, f->type_params, {}); + auto ret = FuncTypeNode::make(arg_types, rtype, f->type_params, {}); + return solver_.Resolve(ret); } }; @@ -380,7 +396,7 @@ class TypeInferencer::Resolver : public ExprMutator { public: Resolver(const std::unordered_map& tmap, TypeSolver* solver) - : tmap_(tmap), solver_(solver) { + : tmap_(tmap), solver_(solver) { } Expr VisitExpr_(const VarNode* op) final { @@ -525,6 +541,7 @@ Expr TypeInferencer::Infer(Expr expr) { GetType(expr); // Step 1: Solve the constraints. solver_.Solve(); + // Step 2: Attach resolved types to checked_type field. auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr); CHECK(WellFormed(resolved_expr)); diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index e1efcbbdd0b9..caea3755b8f9 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -5,6 +5,7 @@ */ #include #include "type_solver.h" +#include "../ir/type_functor.h" namespace tvm { namespace relay { @@ -38,9 +39,298 @@ class TypeSolver::Reporter : public TypeReporterNode { TypeSolver* solver_; }; +class TypeSolver::OccursChecker : public TypeVisitor { + public: + explicit OccursChecker(TypeSolver* solver, TypeNode* var) + : solver_(solver), var_(var), found_(false) {} + + bool Check(const Type& t) { + VisitType(t); + return found_; + } + + void VisitType_(const IncompleteTypeNode* op) override { + IncompleteType t = GetRef(op); + TypeNode* node = solver_->GetTypeNode(t); + found_ = found_ || (var_->FindRoot() == node->FindRoot()); + } + + private: + TypeSolver* solver_; + TypeNode* var_; + bool found_; +}; + +class TypeSolver::Unifier : public TypeFunctor { + public: + explicit Unifier(TypeSolver* solver) : solver_(solver) {} + + Type Unify(const Type& src, const Type& dst) { + // Known limitation + // - handle shape pattern matching + TypeNode* lhs = solver_->GetTypeNode(dst); + TypeNode* rhs = solver_->GetTypeNode(src); + + // do occur check so we don't create self-referencing structure + if (lhs->FindRoot() == rhs->FindRoot()) { + return lhs->resolved_type; + } + if (lhs->resolved_type.as()) { + CHECK(!CheckOccurs(lhs, rhs->resolved_type)) + << "Incomplete type " << lhs->resolved_type << " occurs in " + << rhs->resolved_type << ", cannot unify"; + solver_->MergeFromTo(lhs, rhs); + return rhs->resolved_type; + } else if (rhs->resolved_type.as()) { + CHECK(!CheckOccurs(rhs, lhs->resolved_type)) + << "Incomplete type " << rhs->resolved_type << " occurs in " + << lhs->resolved_type << ", cannot unify"; + solver_->MergeFromTo(rhs, lhs); + return lhs->resolved_type; + } else { + Type resolved = this->VisitType(lhs->resolved_type, rhs->resolved_type); + CHECK(resolved.defined()) + << "Unable to unify parent types: " + << lhs->resolved_type << " and " << rhs->resolved_type; + TypeNode* top = solver_->GetTypeNode(resolved); + solver_->MergeFromTo(lhs, top); + solver_->MergeFromTo(rhs, top); + return resolved; + } + } + + // Checks whether lhs (taken to be a type var) occurs in t, meaning + // there is a recursive equality constraint, which should be rejected. + // N.b.: A tautology like ?a = ?a is okay and should be checked for + // *before* calling this method + bool CheckOccurs(TypeNode* lhs, const Type& t) { + OccursChecker rc(solver_, lhs); + return rc.Check(t); + } + + // default: unify only if alpha-equal + Type VisitTypeDefault_(const Node* op, const Type& tn) override { + NodeRef nr = GetRef(op); + Type t1 = GetRef(nr.as_derived()); + if (!AlphaEqual(t1, tn)) { + return Type(nullptr); + } + return t1; + } + + Type VisitType_(const TupleTypeNode* op, const Type& tn) override { + const auto* ttn = tn.as(); + if (!ttn || op->fields.size() != ttn->fields.size()) { + return Type(nullptr); + } + + TupleType tt1 = GetRef(op); + TupleType tt2 = GetRef(ttn); + + std::vector new_fields; + for (size_t i = 0; i < tt1->fields.size(); i++) { + Type field = Unify(tt1->fields[i], tt2->fields[i]); + new_fields.push_back(field); + } + return TupleTypeNode::make(new_fields); + } + + Type VisitType_(const FuncTypeNode* op, const Type& tn) override { + const auto* ftn = tn.as(); + if (!ftn + || op->arg_types.size() != ftn->arg_types.size() + || op->type_params.size() != ftn->type_params.size() + || op->type_constraints.size() != ftn->type_constraints.size()) { + return Type(nullptr); + } + + // remap type vars so they match + Map subst_map; + for (size_t i = 0; i < op->type_params.size(); i++) { + subst_map.Set(ftn->type_params[i], op->type_params[i]); + } + + auto ft1 = GetRef(op); + auto ft2 = Downcast(Bind(GetRef(ftn), subst_map)); + + Type ret_type = Unify(ft1->ret_type, ft2->ret_type); + + std::vector arg_types; + for (size_t i = 0; i < ft1->arg_types.size(); i++) { + Type arg_type = Unify(ft1->arg_types[i], ft2->arg_types[i]); + arg_types.push_back(arg_type); + } + + std::vector type_constraints; + for (size_t i = 0; i < ft1->type_constraints.size(); i++) { + Type unified_constraint = Unify(ft1->type_constraints[i], + ft2->type_constraints[i]); + const auto* tcn = unified_constraint.as(); + CHECK(tcn) << "Two type constraints unified into a non-constraint?" + << ft1->type_constraints[i] << " and " << ft2->type_constraints[i]; + type_constraints.push_back(GetRef(tcn)); + } + + return FuncTypeNode::make(arg_types, ret_type, ft1->type_params, type_constraints); + } + + private: + TypeSolver* solver_; +}; + +class TypeSolver::Resolver : public TypeMutator { + public: + explicit Resolver(TypeSolver* solver) : solver_(solver) {} + + Type Resolve(const Type& t) { + if (!t.defined()) { + return t; + } + return VisitType(t); + } + + Type VisitType_(const IncompleteTypeNode* op) override { + auto* node = solver_->GetTypeNode(GetRef(op)); + return node->resolved_type; + } + + private: + TypeSolver* solver_; +}; + +// It ends up being more compact to simply have TypeFunctor { + public: + explicit Propagator(TypeSolver* solver, const std::unordered_set* rels) + : solver_(solver), rels_(rels) {} + + // adds the relation node to t and all child types of t + void Propagate(const Type& t) { + VisitType(t); + } + + void UpdateRelSet(const Type& t) { + TypeNode* tnode = solver_->GetTypeNode(t); + for (auto* rel : *rels_) { + tnode->rel_set.insert(rel); + } + } + + void VisitTypeDefault_(const Node* op) override { + NodeRef nr = GetRef(op); + Type t = GetRef(nr.as_derived()); + UpdateRelSet(t); + } + + void VisitType_(const TupleTypeNode* op) override { + TupleType tt = GetRef(op); + UpdateRelSet(tt); + + for (const Type& t : tt->fields) { + Propagate(t); + } + } + + void VisitType_(const FuncTypeNode* op) override { + FuncType ft = GetRef(op); + UpdateRelSet(ft); + + Propagate(ft->ret_type); + for (auto arg_type : ft->arg_types) { + Propagate(arg_type); + } + + for (auto type_param : ft->type_params) { + Propagate(type_param); + } + + for (auto type_cs : ft->type_constraints) { + Propagate(type_cs); + } + } + + private: + TypeSolver* solver_; + const std::unordered_set* rels_; +}; + +// similarly, we use TypeFunctor so we can use +// the default visitor case to avoid more overrides +class TypeSolver::Merger : public TypeFunctor { + public: + explicit Merger(TypeSolver* solver) : solver_(solver) {} + + // Merges src node to dst, ensures *all* type relations of all + // child nodes of src are transferred to dst. + void Merge(TypeNode* src, TypeNode* dst) { + if (src == dst) return; + dst_ = dst; + VisitType(src->resolved_type); + // set parent at the end so later calls to GetTypeNode go back to src + src->parent = dst; + + // now propagate relations to child nodes, since change to + // a child node should update parent too + Propagator prop(solver_, &dst->rel_set); + prop.Propagate(dst->resolved_type); + } + + // Transfers any relations linked to t to the stored dst. + // Any unresolved relations are added back to the queue, since + // there is now new information + void TransferLinks(const Type& t) { + TypeNode* src = solver_->GetTypeNode(t); + if (src == dst_) return; + for (auto* rel : src->rel_set) { + // if the relation is not yet resolved, add to queue + if (!rel->resolved) { + solver_->AddToQueue(rel); + dst_->rel_set.insert(rel); + } + } + } + + void VisitTypeDefault_(const Node* op) override { + NodeRef nr = GetRef(op); + Type t = GetRef(nr.as_derived()); + TransferLinks(t); + } + + void VisitType_(const TupleTypeNode* ttn) override { + auto tup = GetRef(ttn); + TransferLinks(tup); + + for (auto field : tup->fields) { + VisitType(field); + } + } + + void VisitType_(const FuncTypeNode* ftn) override { + auto func = GetRef(ftn); + TransferLinks(func); + + VisitType(func->ret_type); + for (auto arg : func->arg_types) { + VisitType(arg); + } + for (auto param : func->type_params) { + VisitType(param); + } + for (auto constraint : func->type_constraints) { + VisitType(constraint); + } + } + + private: + TypeSolver* solver_; + TypeNode* dst_; +}; + // constructor TypeSolver::TypeSolver() - : reporter_(make_node(this)) { + : reporter_(make_node(this)) { } // destructor @@ -54,31 +344,16 @@ TypeSolver::~TypeSolver() { } } +// merge src type node to dst +void TypeSolver::MergeFromTo(TypeNode* src, TypeNode* dst) { + Merger merger(this); + merger.Merge(src, dst); +} + // Add equality constraint Type TypeSolver::Unify(const Type& dst, const Type& src) { - // Known limitation - // - handle composite types whose component can be unknown. - // - handle shape pattern matching - TypeNode* lhs = GetTypeNode(dst); - TypeNode* rhs = GetTypeNode(src); - - // do occur check so we don't create self-referencing structure - if (lhs->FindRoot() == rhs->FindRoot()) { - return lhs->resolved_type; - } - if (lhs->resolved_type.as()) { - MergeFromTo(lhs, rhs); - return rhs->resolved_type; - } else if (rhs->resolved_type.as()) { - MergeFromTo(rhs, lhs); - return lhs->resolved_type; - } else { - lhs->parent = rhs; - CHECK(AlphaEqual(lhs->resolved_type, rhs->resolved_type)) - << "Incompatible parent types in UF:" - << lhs->resolved_type << " and " << rhs->resolved_type; - return rhs->resolved_type; - } + Unifier unifier(this); + return unifier.Unify(dst, src); } // Add type constraint to the solver. @@ -96,9 +371,9 @@ void TypeSolver::AddConstraint(const TypeConstraint& constraint) { tlink->value = tnode; rnode->type_list.Push(tlink); // insert type->relation node - LinkNode* rlink = arena_.make >(); - rlink->value = rnode; - tnode->rel_list.Push(rlink); + std::unordered_set singleton { rnode }; + Propagator prop(this, &singleton); + prop.Propagate(tnode->resolved_type); } // add the relation to the working queue. this->AddToQueue(rnode); @@ -110,12 +385,10 @@ void TypeSolver::AddConstraint(const TypeConstraint& constraint) { // Resolve a type in the solver context. Type TypeSolver::Resolve(const Type& type) { + Resolver resolver(this); auto it = tmap_.find(type); - if (it != tmap_.end()) { - return it->second->FindRoot()->resolved_type; - } else { - return type; - } + Type t = (it != tmap_.end()) ? it->second->FindRoot()->resolved_type : type; + return resolver.Resolve(t); } bool TypeSolver::Solve() { @@ -128,7 +401,7 @@ bool TypeSolver::Solve() { // update the relation with given evidence. Array args; for (auto* tlink = rnode->type_list.head; tlink != nullptr; tlink = tlink->next) { - args.push_back(tlink->value->FindRoot()->resolved_type); + args.push_back(Resolve(tlink->value->FindRoot()->resolved_type)); CHECK_LE(args.size(), rel->args.size()); } // call the function @@ -161,8 +434,8 @@ TVM_REGISTER_API("relay._ir_pass._test_type_solver") return solver->Solve(); }); } else if (name == "Unify") { - return TypedPackedFunc([solver](Type lhs, Type rhs) { - solver->Unify(lhs, rhs); + return TypedPackedFunc([solver](Type lhs, Type rhs) { + return solver->Unify(lhs, rhs); }); } else if (name == "Resolve") { return TypedPackedFunc([solver](Type t) { diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index 2f311c9b9810..b4635fdec331 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -18,6 +18,7 @@ namespace relay { using common::LinkNode; using common::LinkedList; + /*! * \brief Interface of type solver used in type inference. * @@ -65,6 +66,11 @@ class TypeSolver { Type Unify(const Type& lhs, const Type& rhs); private: + class OccursChecker; + class Unifier; + class Resolver; + class Propagator; + class Merger; class Reporter; struct TypeNode; struct RelationNode; @@ -77,15 +83,15 @@ class TypeSolver { * that can unifies the same types to the name resolved_type. * * It also contains collection of links to related Relations, - * which is stored in rel_list. + * which is stored in rel_set. */ struct TypeNode { /*! \brief The final resolved type */ Type resolved_type; /*! \brief type node in the union find algorithm */ TypeNode* parent{nullptr}; - /*! \brief list of relations that is related to this type node */ - LinkedList rel_list; + /*! \brief set of relations that is related to this type node */ + std::unordered_set rel_set; /*! * \brief Find the root type node, perform path compression * \return The root type node. @@ -125,7 +131,7 @@ class TypeSolver { size_t num_resolved_rels_{0}; /*! \brief map from type node to types. */ std::unordered_map tmap_; - /*! \breif Internal queue to update the relation */ + /*! \brief Internal queue to update the relation */ std::queue update_queue_; /*! \brief allocator of all the internal node obhect*/ common::Arena arena_; @@ -163,22 +169,7 @@ class TypeSolver { * \param src The source operand * \param dst The dst operand. */ - void MergeFromTo(TypeNode* src, TypeNode* dst) { - if (src == dst) return; - src->parent = dst; - // move the link to the to dst - for (auto* rlink = src->rel_list.head; rlink != nullptr;) { - // store next pointer first before rlink get moved - auto* next = rlink->next; - // if the relation is not yet resolved - // send the relation to the new - if (!rlink->value->resolved) { - this->AddToQueue(rlink->value); - dst->rel_list.Push(rlink); - } - rlink = next; - } - } + void MergeFromTo(TypeNode* src, TypeNode* dst); }; } // namespace relay diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index b99d975135be..403863c1d757 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -12,105 +12,211 @@ namespace tvm { namespace relay { -// FreeTypeVar -class FreeTypeVarTVisitor : public TypeVisitor { +template +struct InsertionSet { + std::unordered_set set; + std::vector data; + void Insert(const T& t) { + if (set.count(t) == 0) { + set.insert(t); + data.push_back(t); + } + } +}; + +class TypeVarTVisitor : public TypeVisitor { public: - FreeTypeVarTVisitor( - Array* free_vars, - std::unordered_set* bound_vars) - : free_vars_(free_vars), bound_vars_(bound_vars) { } + TypeVarTVisitor( + InsertionSet* type_vars, + InsertionSet* bound_type_vars) + : type_vars_(type_vars), bound_type_vars_(bound_type_vars) { } void VisitType_(const TypeVarNode* tp) final { TypeVar var = GetRef(tp); - if (bound_vars_->count(var) == 0) { - free_vars_->push_back(var); - } + type_vars_->Insert(var); } void VisitType_(const FuncTypeNode* f) final { for (auto type_param : f->type_params) { - bound_vars_->insert(type_param); + type_vars_->Insert(type_param); + bound_type_vars_->Insert(type_param); } TypeVisitor::VisitType_(f); } private: - Array* free_vars_; - std::unordered_set* bound_vars_; + InsertionSet* type_vars_; + InsertionSet* bound_type_vars_; }; -class FreeTypeVarEVisitor : private ExprVisitor { +class TypeVarEVisitor : private ExprVisitor { public: - Array Find(const Expr& expr) { - this->VisitExpr(expr); - return free_vars_; + Array CollectFree() { + Array ret; + for (const auto& v : type_vars_.data) { + if (bound_type_vars_.set.count(v) == 0) { + ret.push_back(v); + } + } + return ret; + } + + Array CollectBound() { + Array ret; + for (const auto& v : bound_type_vars_.data) { + ret.push_back(v); + } + return ret; + } + + Array CollectAll() { + Array ret; + for (const auto& v : type_vars_.data) { + ret.push_back(v); + } + return ret; } - Array Find(const Type& type) { - this->VisitType(type); - return free_vars_; + Array Free(const Expr& expr) { + VisitExpr(expr); + return CollectFree(); + } + + Array Free(const Type& type) { + VisitType(type); + return CollectFree(); + } + + Array Bound(const Expr& expr) { + VisitExpr(expr); + return CollectBound(); + } + + Array Bound(const Type& type) { + VisitType(type); + return CollectBound(); + } + + Array All(const Expr& expr) { + VisitExpr(expr); + return CollectAll(); + } + + Array All(const Type& type) { + VisitType(type); + return CollectAll(); } void VisitExpr_(const FunctionNode* f) final { for (const auto& tp : f->type_params) { - bound_vars_.insert(tp); + type_vars_.Insert(tp); + bound_type_vars_.Insert(tp); } ExprVisitor::VisitExpr_(f); } void VisitType(const Type& t) final { - FreeTypeVarTVisitor(&free_vars_, &bound_vars_) + TypeVarTVisitor(&type_vars_, &bound_type_vars_) .VisitType(t); } private: - // The result list - Array free_vars_; - std::unordered_set bound_vars_; + InsertionSet type_vars_; + InsertionSet bound_type_vars_; }; -class FreeVarVisitor : protected ExprVisitor { +class VarVisitor : protected ExprVisitor { public: - Array Find(const Expr& expr) { + Array Free(const Expr& expr) { this->VisitExpr(expr); - return free_vars_; + Array ret; + for (const auto& v : vars_.data) { + if (bound_vars_.set.count(v) == 0) { + ret.push_back(v); + } + } + return ret; } - void VisitExpr_(const VarNode* var) final { - if (bound_vars_.count(var) == 0) { - free_vars_.push_back(GetRef(var)); + Array Bound(const Expr& expr) { + this->VisitExpr(expr); + Array ret; + for (const auto& v : bound_vars_.data) { + ret.push_back(v); } + return ret; + } + + Array All(const Expr& expr) { + this->VisitExpr(expr); + Array ret; + for (const auto& v : vars_.data) { + ret.push_back(v); + } + return ret; + } + + void MarkBounded(const Var& v) { + bound_vars_.Insert(v); + vars_.Insert(v); + } + + void VisitExpr_(const VarNode* var) final { + vars_.Insert(GetRef(var)); } void VisitExpr_(const FunctionNode* op) final { for (const auto& param : op->params) { - bound_vars_.insert(param.operator->()); + MarkBounded(param); } VisitExpr(op->body); } void VisitExpr_(const LetNode* op) final { - bound_vars_.insert(op->var.operator->()); + MarkBounded(op->var); VisitExpr(op->value); VisitExpr(op->body); } private: - // The result list - Array free_vars_; - std::unordered_set bound_vars_; + InsertionSet vars_; + InsertionSet bound_vars_; }; tvm::Array FreeTypeVars(const Expr& expr) { - return FreeTypeVarEVisitor().Find(expr); + return TypeVarEVisitor().Free(expr); } tvm::Array FreeTypeVars(const Type& type) { - return FreeTypeVarEVisitor().Find(type); + return TypeVarEVisitor().Free(type); +} + +tvm::Array BoundTypeVars(const Expr& expr) { + return TypeVarEVisitor().Bound(expr); +} + +tvm::Array BoundTypeVars(const Type& type) { + return TypeVarEVisitor().Bound(type); +} + +tvm::Array AllTypeVars(const Expr& expr) { + return TypeVarEVisitor().All(expr); +} + +tvm::Array AllTypeVars(const Type& type) { + return TypeVarEVisitor().All(type); } tvm::Array FreeVars(const Expr& expr) { - return FreeVarVisitor().Find(expr); + return VarVisitor().Free(expr); +} + +tvm::Array BoundVars(const Expr& expr) { + return VarVisitor().Bound(expr); +} + +tvm::Array AllVars(const Expr& expr) { + return VarVisitor().All(expr); } TVM_REGISTER_API("relay._ir_pass.free_vars") @@ -118,16 +224,46 @@ TVM_REGISTER_API("relay._ir_pass.free_vars") *ret = FreeVars(args[0]); }); +TVM_REGISTER_API("relay._ir_pass.bound_vars") + .set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = BoundVars(args[0]); + }); + +TVM_REGISTER_API("relay._ir_pass.all_vars") + .set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = AllVars(args[0]); + }); + TVM_REGISTER_API("relay._ir_pass.free_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef x = args[0]; - if (x.as()) { + if (x.as_derived()) { *ret = FreeTypeVars(Downcast(x)); } else { *ret = FreeTypeVars(Downcast(x)); } }); +TVM_REGISTER_API("relay._ir_pass.bound_type_vars") + .set_body([](TVMArgs args, TVMRetValue* ret) { + NodeRef x = args[0]; + if (x.as_derived()) { + *ret = BoundTypeVars(Downcast(x)); + } else { + *ret = BoundTypeVars(Downcast(x)); + } + }); + +TVM_REGISTER_API("relay._ir_pass.all_type_vars") + .set_body([](TVMArgs args, TVMRetValue* ret) { + NodeRef x = args[0]; + if (x.as_derived()) { + *ret = AllTypeVars(Downcast(x)); + } else { + *ret = AllTypeVars(Downcast(x)); + } + }); + /*! * \brief Get reference counter of each internal ExprNode in body. * \param body The body expression. diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc index 385bde974014..50aed4c57338 100644 --- a/tests/cpp/relay_pass_type_infer_test.cc +++ b/tests/cpp/relay_pass_type_infer_test.cc @@ -6,13 +6,17 @@ TEST(Relay, SelfReference) { using namespace tvm; - auto type_a = relay::TypeVarNode::make("a", relay::TypeVarNode::kType); - auto type_b = relay::TypeVarNode::make("b", relay::TypeVarNode::kType); - auto x = relay::VarNode::make("x", type_a); - auto f = relay::FunctionNode::make(tvm::Array{ x }, x, type_b, Array{}); - auto fx = relay::CallNode::make(f, Array{ x }); + auto tensor_type = relay::TensorTypeNode::make({}, ::tvm::Bool()); + auto x = relay::VarNode::make("x", relay::Type()); + auto f = relay::FunctionNode::make(tvm::Array{ x }, x, relay::Type(), {}); + + auto y = relay::VarNode::make("y", tensor_type); + auto call = relay::CallNode::make(f, Array{ y }); + auto fx = relay::FunctionNode::make(tvm::Array{ y }, call, relay::Type(), {}); auto type_fx = relay::InferType(fx, relay::ModuleNode::make(Map{})); - CHECK_EQ(type_fx->checked_type(), type_a); + + auto expected = relay::FuncTypeNode::make(tvm::Array{ tensor_type }, tensor_type, {}, {}); + CHECK(AlphaEqual(type_fx->checked_type(), expected)); } int main(int argc, char ** argv) { diff --git a/tests/python/relay/test_pass_free_vars.py b/tests/python/relay/test_pass_free_vars.py deleted file mode 100644 index 151dbe1412bc..000000000000 --- a/tests/python/relay/test_pass_free_vars.py +++ /dev/null @@ -1,41 +0,0 @@ -import tvm -from tvm import relay -from tvm.relay.ir_pass import free_vars, free_type_vars - -def test_free_vars(): - ty = relay.TensorType([], "int32") - x = relay.Var("x", ty) - fvx = free_vars(x) - assert len(fvx) == 1 - assert fvx[0] == x - v = relay.Constant(tvm.nd.array(10)) - - let = relay.Let(x, v, x) - fvx = free_vars(let) - assert len(free_vars(let)) == 0 - f = relay.Function([x], x, ty) - assert len(free_vars(f)) == 0 - - -def test_tuple(): - t = relay.Var('t') - fv = free_vars(relay.Tuple([t, t])) - assert len(fv) == 1 - assert fv[0] == t - fv = free_vars(relay.TupleGetItem(t, 123)) - assert len(fv) == 1 - assert fv[0] == t - - -def test_free_type_vars(): - tp = relay.TypeVar("") - ty = relay.TupleType([tp, relay.TensorType([], "int32")]) - x = relay.Var("x", ty) - y = relay.Var("y") - let = relay.Let(x, y, x) - fvl = free_vars(let) - assert len(fvl) == 1 - assert fvl[0] == y - ftvl = free_type_vars(let) - assert len(ftvl) == 1 - assert ftvl[0] == tp diff --git a/tests/python/relay/test_pass_vars.py b/tests/python/relay/test_pass_vars.py new file mode 100644 index 000000000000..c8d3d6d14992 --- /dev/null +++ b/tests/python/relay/test_pass_vars.py @@ -0,0 +1,144 @@ +import tvm +from tvm import relay +from tvm.relay.ir_pass import (free_vars, free_type_vars, + bound_vars, bound_type_vars, + all_vars, all_type_vars) + +def assert_vars_match(actual, expected): + assert len(actual) == len(expected) + for i in range(len(actual)): + assert actual[i] == expected[i] + + +def test_free_vars(): + ty = relay.TensorType([], "int32") + x = relay.Var("x", ty) + fvx = free_vars(x) + assert len(fvx) == 1 + assert fvx[0] == x + v = relay.Constant(tvm.nd.array(10)) + + let = relay.Let(x, v, x) + fvx = free_vars(let) + assert len(free_vars(let)) == 0 + f = relay.Function([x], x, ty) + assert len(free_vars(f)) == 0 + + +def test_free_vars_tuple(): + t = relay.Var('t') + fv = free_vars(relay.Tuple([t, t])) + assert len(fv) == 1 + assert fv[0] == t + fv = free_vars(relay.TupleGetItem(t, 123)) + assert len(fv) == 1 + assert fv[0] == t + + +def test_free_type_vars(): + tp = relay.TypeVar("") + ty = relay.TupleType([tp, relay.TensorType([], "int32")]) + x = relay.Var("x", ty) + y = relay.Var("y") + let = relay.Let(x, y, x) + fvl = free_vars(let) + assert len(fvl) == 1 + assert fvl[0] == y + ftvl = free_type_vars(let) + assert len(ftvl) == 1 + assert ftvl[0] == tp + + +def test_bound_vars(): + x = relay.Var("x") + y = relay.Var("y") + z = relay.Var("z") + a = relay.Var("a") + + f1 = relay.Function([x, y, z], relay.Let(a, x, relay.Tuple([]))) + assert_vars_match(bound_vars(f1), [x, y, z, a]) + + tup = relay.Tuple([x, y, z, a]) + assert len(bound_vars(tup)) == 0 + + f2 = relay.Function([x, y], relay.Tuple([x, y, z, a])) + assert_vars_match(bound_vars(f2), [x, y]) + + +def test_bound_type_vars(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + c = relay.TypeVar("c") + + ft1 = relay.FuncType([a], b, [a, b]) + bound_ft1 = bound_type_vars(ft1) + assert_vars_match(bound_type_vars(ft1), [a, b]) + + ft2 = relay.FuncType([], c, [a]) + assert_vars_match(bound_type_vars(ft2), [a]) + + tup_ty = relay.TupleType([a, b, c]) + assert len(bound_type_vars(tup_ty)) == 0 + + f1 = relay.Function([], relay.Tuple([]), type_params=[a, b]) + assert_vars_match(bound_type_vars(f1), [a, b]) + + f2 = relay.Function([], relay.Tuple([]), c) + assert len(bound_type_vars(f2)) == 0 + + x = relay.Var("x", a) + let1 = relay.Let(x, relay.Tuple([]), x) + assert len(bound_type_vars(let1)) == 0 + + let2 = relay.Let(x, relay.Function([], relay.Tuple([]), type_params=[b, c]), x) + assert_vars_match(bound_type_vars(let2), [b, c]) + + +def test_all_vars(): + x = relay.Var("x") + y = relay.Var("y") + z = relay.Var("z") + + f1 = relay.Function([x, y], z) + assert_vars_match(all_vars(f1), [x, y, z]) + + f2 = relay.Function([x], relay.Let(y, relay.Tuple([]), z)) + assert_vars_match(all_vars(f2), [x, y, z]) + + f3 = relay.Function([x], relay.Tuple([y, z])) + assert_vars_match(all_vars(f3), [x, y, z]) + + tup = relay.Tuple([x, y, z]) + assert_vars_match(all_vars(tup), [x, y, z]) + + +def test_all_type_vars(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + c = relay.TypeVar("c") + + ft1 = relay.FuncType([b], c, [a]) + assert_vars_match(all_type_vars(ft1), [a, b, c]) + + ft2 = relay.FuncType([], relay.TupleType([a, b, c]), []) + assert_vars_match(all_type_vars(ft2), [a, b, c]) + + w = relay.Var("w") + x = relay.Var("x", a) + y = relay.Var("y", b) + z = relay.Var("z", c) + + f1 = relay.Function([x], y, b, [a]) + assert_vars_match(all_type_vars(f1), [a, b]) + + f2 = relay.Function([x], relay.Let(y, x, z)) + assert_vars_match(all_type_vars(f2), [a, b, c]) + + f3 = relay.Function([], relay.Tuple([x, y, z]), ret_type=relay.TupleType([a, b, c])) + assert_vars_match(all_type_vars(f3), [a, b, c]) + + f4 = relay.Function([w], relay.Tuple([]), type_params=[a, b, c]) + assert_vars_match(all_type_vars(f4), [a, b, c]) + + f5 = relay.Function([w], w) + assert len(all_type_vars(f5)) == 0 diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 06cb19639dcf..ac4eb1b404db 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -23,7 +23,7 @@ def test_monomorphic_let(): x = sb.let('x', relay.const(1.0, "float64")) sb.ret(x) xchecked = relay.ir_pass.infer_type(sb.get()) - assert xchecked.checked_type == relay.scalar_type("float64") + assert xchecked.checked_type == relay.scalar_type("float64" ) def test_single_op(): @@ -41,14 +41,15 @@ def test_add_broadcast_op(): return x + y; } """ - pass - # x = relay.var('x', shape=(10, 4)) - # y = relay.var('y', shape=(5, 10, 1)) - # z = x + y - # func = relay.Function([x, y], z) - # ttype = relay.TensorType((5, 5, 5), 'float32') - # expected_ty = relay.FuncType([ttype, ttype], ttype) - # assert_has_type(func.to_func(), expected_ty) + x = relay.var('x', shape=(10, 4)) + y = relay.var('y', shape=(5, 10, 1)) + z = x + y + func = relay.Function([x, y], z) + t1 = relay.TensorType((10, 4), 'float32') + t2 = relay.TensorType((5, 10, 1), 'float32') + t3 = relay.TensorType((5, 10, 4), 'float32') + expected_ty = relay.FuncType([t1, t2], t3) + assert_has_type(func, expected_ty) def test_dual_op(): @@ -110,24 +111,17 @@ def f(n: i32, data: f32) -> f32 { assert "%3 = @f(%1, %2)" in mod.astext() assert mod[f].checked_type == relay.FuncType([ti32, tf32], tf32) -# This currently fails and should pass under the type system. -# -# This test is to illustrate problem with our weak form of -# unification. -# - def test_incomplete_call(): - sb = ScopeBuilder() - x = relay.var('x', dtype='int32') + tt = relay.scalar_type('int32') + x = relay.var('x', tt) f = relay.var('f') - func = relay.Function([x, f], relay.Call(f, [x])) + func = relay.Function([x, f], relay.Call(f, [x]), tt) + + ft = relay.ir_pass.infer_type(func) + f_type = relay.FuncType([tt], tt) + assert ft.checked_type == relay.FuncType([tt, f_type], tt) - try: - relay.ir_pass.infer_type(func) - assert False - except tvm.TVMError as e: - assert True def test_tuple(): tp = relay.TensorType((10,)) @@ -136,6 +130,7 @@ def test_tuple(): assert (relay.ir_pass.infer_type(res).checked_type == relay.TupleType([tp, tp])) + def test_free_expr(): x = relay.var("x", "float32") y = relay.add(x, x) @@ -161,38 +156,26 @@ def test_type_args(): assert sh2[1].value == 10 -def test_self_reference(): - """ - Program: - def f(x) { - return x; - } - """ - a = relay.TypeVar("a") - x = relay.var("x", a) - sb = relay.ScopeBuilder() - - f = relay.Function([x], x) - fx = relay.Call(f, [x]) - assert relay.ir_pass.infer_type(x).checked_type == a - assert relay.ir_pass.infer_type(f).checked_type == relay.FuncType([a], a) - assert relay.ir_pass.infer_type(fx).checked_type == a - - -def test_global_var_cow_issue(): +def test_global_var_recursion(): mod = relay.Module({}) gv = relay.GlobalVar("foo") x = relay.var('x', shape=[]) - func = relay.Function([x], relay.Call(gv, [x]), - relay.TensorType([], 'float32')) + tt = relay.scalar_type('float32') + + func = relay.Function([x], relay.Call(gv, [x]), tt) mod[gv] = func + ft = relay.ir_pass.infer_type(gv, mod) + assert mod[ft].checked_type == relay.FuncType([tt], tt) + def test_equal(): i = relay.var('i', shape=[], dtype='int32') eq = op.equal(i, relay.const(0, dtype='int32')) - # This should fail .... - func = relay.Function([i], eq, ret_type=relay.TensorType([], 'int32')) + func = relay.Function([i], eq) + ft = relay.ir_pass.infer_type(func) + + assert ft.checked_type == relay.FuncType([relay.scalar_type('int32')], relay.scalar_type('bool')) if __name__ == "__main__": @@ -204,8 +187,12 @@ def test_equal(): test_decl() test_recursion() test_tuple() + test_generalized_tuple() test_incomplete_call() + test_generalized_call() + test_call_with_type_args() test_free_expr() test_type_args() test_self_reference() - test_global_var_cow_issue() + test_global_var_recursion() + test_equal() diff --git a/tests/python/relay/test_type_solver.py b/tests/python/relay/test_type_solver.py index e8ff67756931..1e2fed0af1f8 100644 --- a/tests/python/relay/test_type_solver.py +++ b/tests/python/relay/test_type_solver.py @@ -1,5 +1,6 @@ import tvm from tvm import relay +from nose.tools import raises def make_rel(name, args, num_inputs=None, attrs=None): @@ -48,7 +49,170 @@ def test_backward_solving(): assert solver.Resolve(t3) == relay.ty.TensorType((10, 10, 20), "float32") +def test_unify_tuple(): + solver = make_solver() + t1 = relay.ty.IncompleteType() + t2 = relay.ty.IncompleteType() + t3 = relay.ty.TensorType((10, 20), "float32") + + tup1 = relay.ty.TupleType([t1, t2]) + tup2 = relay.ty.TupleType([t3, t3]) + + unified = solver.Unify(tup1, tup2) + assert unified == tup2 + + +def test_unify_functype(): + solver = make_solver() + t1 = relay.ty.IncompleteType() + t2 = relay.ty.IncompleteType() + t3 = relay.ty.IncompleteType() + + unit = relay.ty.TupleType([]) + tensor1 = relay.ty.TensorType((10, 20), "float32") + tensor2 = relay.ty.TensorType((10,), "float32") + + ft1 = relay.ty.FuncType([t1, t2], t3) + ft2 = relay.ty.FuncType([tensor1, tensor2], unit) + + unified = solver.Unify(ft1, ft2) + assert unified == ft2 + + +def test_recursive_unify(): + solver = make_solver() + t1 = relay.ty.IncompleteType() + t2 = relay.ty.IncompleteType() + t3 = relay.ty.IncompleteType() + + tensor1 = relay.ty.TensorType((10, 10, 20), "float32") + tensor2 = relay.ty.TensorType((10, 20), "float32") + tensor3 = relay.ty.TensorType((10,), "float32") + + tup1 = relay.ty.TupleType([relay.ty.TupleType([t1, t2]), t2]) + tup2 = relay.ty.TupleType([relay.ty.TupleType([tensor1, tensor2]), tensor2]) + + ft1 = relay.ty.FuncType([tup1, t3], t3) + ft2 = relay.ty.FuncType([tup2, tensor3], tensor3) + + unified = solver.Unify(ft1, ft2) + assert unified == ft2 + + +def test_unify_vars_under_tuples(): + solver = make_solver() + t1 = relay.ty.IncompleteType() + + tup1 = relay.ty.TupleType([t1, t1]) + unified = solver.Unify(tup1, tup1) + assert unified == tup1 + + t2 = relay.ty.IncompleteType() + tup2 = relay.ty.TupleType([t2, t2]) + + tup3 = relay.ty.TupleType([t1, t2]) + tup4 = relay.ty.TupleType([t2, t1]) + unified = solver.Unify(tup3, tup4) + assert (unified == tup1 or unified == tup2) + + +def test_binding_over_typevars(): + solver = make_solver() + + t1 = relay.ty.IncompleteType() + t2 = relay.ty.IncompleteType() + + a = relay.ty.TypeVar('a') + b = relay.ty.TypeVar('b') + c = relay.ty.TypeVar('c') + d = relay.ty.TypeVar('d') + + ft1 = relay.ty.FuncType([t1], t2, [c, d]) + ft2 = relay.ty.FuncType([a], b, [a, b]) + unified = solver.Unify(ft1, ft2) + assert (unified == solver.Resolve(ft1)) + + +def test_recursive_backward_solving(): + solver = make_solver() + + tensor1 = relay.ty.TensorType((10, 20), "float32") + tensor2 = relay.ty.TensorType((10, 1, 1), "float32") + tensor3 = relay.ty.TensorType((10,), "float32") + + t1 = relay.ty.IncompleteType() + t2 = relay.ty.IncompleteType() + t3 = relay.ty.IncompleteType() + + tup1 = relay.ty.TupleType([relay.ty.TupleType([tensor1, tensor2]), tensor3]) + tup2 = relay.ty.TupleType([relay.ty.TupleType([t1, t2]), t3]) + solver.gen_type("Identity", [tup1], out=tup2) + + assert solver.Solve() + assert solver.Resolve(tup2) == tup1 + + +def test_backward_solving_after_child_update(): + solver = make_solver() + + tensor1 = relay.ty.TensorType((10, 20), "float32") + tensor2 = relay.ty.TensorType((10, 1, 1), "float32") + + t1 = relay.ty.IncompleteType() + t2 = relay.ty.IncompleteType() + t3 = relay.ty.IncompleteType() + + tup1 = relay.ty.TupleType([t1, t2]) + tup2 = relay.ty.TupleType([t1, t3]) + + tup_concrete = relay.ty.TupleType([tensor1, tensor2]) + + t4 = solver.gen_type("Identity", [tup1]) + t5 = solver.gen_type("Identity", [tup2]) + + solver.gen_type("Identity", [t4], out=t5) + assert solver.Solve() + assert solver.Resolve(t3) == t3 or solver.Resolve(t3) == t2 + assert solver.Resolve(t4) == tup1 or solver.Resolve(t4) == tup2 + assert solver.Resolve(t5) == tup1 or solver.Resolve(t5) == tup2 + + # updating the variables *inside* tup1 and tup2 should update t4 and t5 + solver.gen_type("Identity", [t1], out=tensor1) + solver.gen_type("Identity", [t2], out=tensor2) + assert solver.Solve() + assert solver.Resolve(t4) == tup_concrete + assert solver.Resolve(t5) == tup_concrete + +@raises(tvm._ffi.base.TVMError) +def test_incompatible_tuple_unification(): + solver = make_solver() + t1 = relay.ty.IncompleteType() + t2 = relay.ty.IncompleteType() + + tensor1 = relay.ty.TensorType((1, 2, 3), "float32") + tensor2 = relay.ty.TensorType((2, 3), "float32") + tensor3 = relay.ty.TensorType((3,), "float32") + + tup1 = relay.ty.TupleType([relay.ty.TupleType([t1, t1]), t2]) + tup2 = relay.ty.TupleType([relay.ty.TupleType([tensor1, tensor2]), tensor3]) + solver.Unify(tup1, tup2) + + +@raises(tvm._ffi.base.TVMError) +def test_bad_recursive_unification(): + solver = make_solver() + t1 = relay.ty.IncompleteType() + solver.Unify(t1, relay.ty.TupleType([t1, t1])) + if __name__ == "__main__": test_bcast() test_backward_solving() + test_unify_tuple() + test_unify_functype() + test_recursive_unify() + test_unify_vars_under_tuples() + test_recursive_backward_solving() + test_backward_solving_after_child_update() + test_incompatible_tuple_unification() + test_bad_recursive_unification()