diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index dafcaf56015a..786fe90c7b71 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -335,6 +335,295 @@ class TypeSolver::Merger : public TypeFunctor { TypeNode* dst_; }; +class TypeSolver::OccursChecker : public TypeVisitor { + public: + explicit OccursChecker(TypeSolver* solver, TypeNode* var) + : solver_(solver), var_(var), found_(false) {} + + bool Check(const Type& t) { + VisitType(t); + return found_; + } + + void VisitType_(const IncompleteTypeNode* op) override { + IncompleteType t = GetRef(op); + TypeNode* node = solver_->GetTypeNode(t); + found_ = found_ || (var_->FindRoot() == node->FindRoot()); + } + + private: + TypeSolver* solver_; + TypeNode* var_; + bool found_; +}; + +class TypeSolver::Unifier : public TypeFunctor { + public: + explicit Unifier(TypeSolver* solver) : solver_(solver) {} + + Type Unify(const Type& src, const Type& dst) { + // Known limitation + // - handle shape pattern matching + TypeNode* lhs = solver_->GetTypeNode(dst); + TypeNode* rhs = solver_->GetTypeNode(src); + + // do occur check so we don't create self-referencing structure + if (lhs->FindRoot() == rhs->FindRoot()) { + return lhs->resolved_type; + } + if (lhs->resolved_type.as()) { + CHECK(!CheckOccurs(lhs, rhs->resolved_type)) + << "Incomplete type " << lhs->resolved_type << " occurs in " + << rhs->resolved_type << ", cannot unify"; + solver_->MergeFromTo(lhs, rhs); + return rhs->resolved_type; + } else if (rhs->resolved_type.as()) { + CHECK(!CheckOccurs(rhs, lhs->resolved_type)) + << "Incomplete type " << rhs->resolved_type << " occurs in " + << lhs->resolved_type << ", cannot unify"; + solver_->MergeFromTo(rhs, lhs); + return lhs->resolved_type; + } else { + Type resolved = this->VisitType(lhs->resolved_type, rhs->resolved_type); + CHECK(resolved.defined()) + << "Unable to unify parent types: " + << lhs->resolved_type << " and " << rhs->resolved_type; + TypeNode* top = solver_->GetTypeNode(resolved); + solver_->MergeFromTo(lhs, top); + solver_->MergeFromTo(rhs, top); + return resolved; + } + } + + // Checks whether lhs (taken to be a type var) occurs in t, meaning + // there is a recursive equality constraint, which should be rejected. + // N.b.: A tautology like ?a = ?a is okay and should be checked for + // *before* calling this method + bool CheckOccurs(TypeNode* lhs, const Type& t) { + OccursChecker rc(solver_, lhs); + return rc.Check(t); + } + + // default: unify only if alpha-equal + Type VisitTypeDefault_(const Node* op, const Type& tn) override { + NodeRef nr = GetRef(op); + Type t1 = GetRef(nr.as_derived()); + if (!AlphaEqual(t1, tn)) { + return Type(nullptr); + } + return t1; + } + + Type VisitType_(const TupleTypeNode* op, const Type& tn) override { + const auto* ttn = tn.as(); + if (!ttn || op->fields.size() != ttn->fields.size()) { + return Type(nullptr); + } + + TupleType tt1 = GetRef(op); + TupleType tt2 = GetRef(ttn); + + std::vector new_fields; + for (size_t i = 0; i < tt1->fields.size(); i++) { + Type field = Unify(tt1->fields[i], tt2->fields[i]); + new_fields.push_back(field); + } + return TupleTypeNode::make(new_fields); + } + + Type VisitType_(const FuncTypeNode* op, const Type& tn) override { + const auto* ftn = tn.as(); + if (!ftn + || op->arg_types.size() != ftn->arg_types.size() + || op->type_params.size() != ftn->type_params.size() + || op->type_constraints.size() != ftn->type_constraints.size()) { + return Type(nullptr); + } + + // remap type vars so they match + Map subst_map; + for (size_t i = 0; i < op->type_params.size(); i++) { + subst_map.Set(ftn->type_params[i], op->type_params[i]); + } + + auto ft1 = GetRef(op); + auto ft2 = Downcast(Bind(GetRef(ftn), subst_map)); + + Type ret_type = Unify(ft1->ret_type, ft2->ret_type); + + std::vector arg_types; + for (size_t i = 0; i < ft1->arg_types.size(); i++) { + Type arg_type = Unify(ft1->arg_types[i], ft2->arg_types[i]); + arg_types.push_back(arg_type); + } + + std::vector type_constraints; + for (size_t i = 0; i < ft1->type_constraints.size(); i++) { + Type unified_constraint = Unify(ft1->type_constraints[i], + ft2->type_constraints[i]); + const auto* tcn = unified_constraint.as(); + CHECK(tcn) << "Two type constraints unified into a non-constraint?" + << ft1->type_constraints[i] << " and " << ft2->type_constraints[i]; + type_constraints.push_back(GetRef(tcn)); + } + + return FuncTypeNode::make(arg_types, ret_type, ft1->type_params, type_constraints); + } + + private: + TypeSolver* solver_; +}; + +class TypeSolver::Resolver : public TypeMutator { + public: + explicit Resolver(TypeSolver* solver) : solver_(solver) {} + + Type Resolve(const Type& t) { + if (!t.defined()) { + return t; + } + return VisitType(t); + } + + Type VisitType_(const IncompleteTypeNode* op) override { + auto* node = solver_->GetTypeNode(GetRef(op)); + return node->resolved_type; + } + + private: + TypeSolver* solver_; +}; + +// It ends up being more compact to simply have TypeFunctor { + public: + explicit Propagator(TypeSolver* solver, const std::unordered_set* rels) + : solver_(solver), rels_(rels) {} + + // adds the relation node to t and all child types of t + void Propagate(const Type& t) { + VisitType(t); + } + + void UpdateRelSet(const Type& t) { + TypeNode* tnode = solver_->GetTypeNode(t); + for (auto* rel : *rels_) { + tnode->rel_set.insert(rel); + } + } + + void VisitTypeDefault_(const Node* op) override { + NodeRef nr = GetRef(op); + Type t = GetRef(nr.as_derived()); + UpdateRelSet(t); + } + + void VisitType_(const TupleTypeNode* op) override { + TupleType tt = GetRef(op); + UpdateRelSet(tt); + + for (const Type& t : tt->fields) { + Propagate(t); + } + } + + void VisitType_(const FuncTypeNode* op) override { + FuncType ft = GetRef(op); + UpdateRelSet(ft); + + Propagate(ft->ret_type); + for (auto arg_type : ft->arg_types) { + Propagate(arg_type); + } + + for (auto type_param : ft->type_params) { + Propagate(type_param); + } + + for (auto type_cs : ft->type_constraints) { + Propagate(type_cs); + } + } + + private: + TypeSolver* solver_; + const std::unordered_set* rels_; +}; + +// similarly, we use TypeFunctor so we can use +// the default visitor case to avoid more overrides +class TypeSolver::Merger : public TypeFunctor { + public: + explicit Merger(TypeSolver* solver) : solver_(solver) {} + + // Merges src node to dst, ensures *all* type relations of all + // child nodes of src are transferred to dst. + void Merge(TypeNode* src, TypeNode* dst) { + if (src == dst) return; + dst_ = dst; + VisitType(src->resolved_type); + // set parent at the end so later calls to GetTypeNode go back to src + src->parent = dst; + + // now propagate relations to child nodes, since change to + // a child node should update parent too + Propagator prop(solver_, &dst->rel_set); + prop.Propagate(dst->resolved_type); + } + + // Transfers any relations linked to t to the stored dst. + // Any unresolved relations are added back to the queue, since + // there is now new information + void TransferLinks(const Type& t) { + TypeNode* src = solver_->GetTypeNode(t); + if (src == dst_) return; + for (auto* rel : src->rel_set) { + // if the relation is not yet resolved, add to queue + if (!rel->resolved) { + solver_->AddToQueue(rel); + dst_->rel_set.insert(rel); + } + } + } + + void VisitTypeDefault_(const Node* op) override { + NodeRef nr = GetRef(op); + Type t = GetRef(nr.as_derived()); + TransferLinks(t); + } + + void VisitType_(const TupleTypeNode* ttn) override { + auto tup = GetRef(ttn); + TransferLinks(tup); + + for (auto field : tup->fields) { + VisitType(field); + } + } + + void VisitType_(const FuncTypeNode* ftn) override { + auto func = GetRef(ftn); + TransferLinks(func); + + VisitType(func->ret_type); + for (auto arg : func->arg_types) { + VisitType(arg); + } + for (auto param : func->type_params) { + VisitType(param); + } + for (auto constraint : func->type_constraints) { + VisitType(constraint); + } + } + + private: + TypeSolver* solver_; + TypeNode* dst_; +}; + // constructor TypeSolver::TypeSolver(const GlobalVar ¤t_func, ErrorReporter* err_reporter) : reporter_(make_node(this)),