diff --git a/include/tvm/relay/feature.h b/include/tvm/relay/feature.h
index a8b60e7806fe..d7b3b394c5cd 100644
--- a/include/tvm/relay/feature.h
+++ b/include/tvm/relay/feature.h
@@ -81,13 +81,13 @@ class FeatureSet {
     return ret;
   }
   /*! \brief A set that contain all the Feature. */
-  static FeatureSet AllFeature() {
+  static FeatureSet All() {
     FeatureSet fs;
     fs.bs_.flip();
     return fs;
   }
   /*! \brief The empty set. Contain no Feature. */
-  static FeatureSet NoFeature() {
+  static FeatureSet No() {
     FeatureSet fs;
     return fs;
   }
diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py
index 491720de5cda..440052f64cbd 100644
--- a/python/tvm/relay/backend/interpreter.py
+++ b/python/tvm/relay/backend/interpreter.py
@@ -280,6 +280,7 @@ def optimize(self):
         """
         seq = transform.Sequential([transform.SimplifyInference(),
                                     transform.FuseOps(0),
+                                    transform.ToANormalForm(),
                                     transform.InferType()])
         return seq(self.mod)
diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc
index e77d6a894af7..86a4ebb4ebd2 100644
--- a/src/relay/backend/interpreter.cc
+++ b/src/relay/backend/interpreter.cc
@@ -29,6 +29,7 @@
 #include
 #include
 #include
+#include <tvm/relay/feature.h>
 #include "compile_engine.h"

 namespace tvm {
@@ -761,6 +762,8 @@ CreateInterpreter(
     Target target) {
   auto intrp = std::make_shared<Interpreter>(mod, context, target);
   auto packed = [intrp](Expr expr) {
+    auto f = DetectFeature(expr);
+    CHECK(f.is_subset_of(FeatureSet::All() - fGraph));
     return intrp->Eval(expr);
   };
   return TypedPackedFunc<Value(Expr)>(packed);
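Note: the interpreter entry point above now rejects graph-form (DAG-shaped) programs, and the Python-side optimize() pipeline gains ToANormalForm() so that evaluated code satisfies that invariant. A minimal sketch of what the pass does to a shared subexpression; Module.from_expr and the "main" entry are assumptions taken from the era's test helpers, not part of this patch:

    from tvm import relay

    x = relay.var("x", shape=())
    y = x * x
    f = relay.Function([x], y + y)        # `y` is shared, so `f` is in graph form
    mod = relay.Module.from_expr(f)
    mod = relay.transform.ToANormalForm()(mod)
    print(mod["main"])                    # the shared node is now let-bound exactly once

After the pass every non-atomic subexpression is named by a let, which is exactly what the new CHECK in CreateInterpreter (FeatureSet::All() - fGraph) enforces.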
diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc
index f83e5882a473..878795d0b9f2 100644
--- a/src/relay/ir/alpha_equal.cc
+++ b/src/relay/ir/alpha_equal.cc
@@ -120,7 +120,7 @@ class AlphaEqualHandler:
    * \return the comparison result.
    */
   bool TypeEqual(const Type& lhs, const Type& rhs) {
-    auto compute = [&](){
+    auto compute = [&]() {
       if (lhs.same_as(rhs)) return true;
       if (!lhs.defined() || !rhs.defined()) return false;
       return this->VisitType(lhs, rhs);
diff --git a/src/relay/pass/feature.cc b/src/relay/pass/feature.cc
index df3a5d7ecec5..2c5e7ab3b984 100644
--- a/src/relay/pass/feature.cc
+++ b/src/relay/pass/feature.cc
@@ -34,13 +34,15 @@ namespace relay {

 FeatureSet DetectFeature(const Expr& expr) {
   if (!expr.defined()) {
-    return FeatureSet::NoFeature();
+    return FeatureSet::No();
   }
   struct FeatureDetector : ExprVisitor {
     std::unordered_set visited_;
-    FeatureSet fs = FeatureSet::NoFeature();
+    FeatureSet fs = FeatureSet::No();
+
     void VisitExpr(const Expr& expr) final {
       if (visited_.count(expr) == 0) {
+        visited_.insert(expr);
         ExprVisitor::VisitExpr(expr);
       } else {
         if (!IsAtomic(expr)) {
@@ -52,15 +54,20 @@ FeatureSet DetectFeature(const Expr& expr) {
     void VisitExpr_(const CONSTRUCT_NAME##Node* op) final { \
       STMT                                                  \
       fs += f##CONSTRUCT_NAME;                              \
-      ExprVisitor::VisitExpr_(op);                          \
     }
-#define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) DETECT_CONSTRUCT(CONSTRUCT_NAME, {})
+#define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) DETECT_CONSTRUCT(CONSTRUCT_NAME, { \
+    ExprVisitor::VisitExpr_(op);                                                    \
+  })
     DETECT_DEFAULT_CONSTRUCT(Var)
     DETECT_DEFAULT_CONSTRUCT(GlobalVar)
     DETECT_DEFAULT_CONSTRUCT(Constant)
     DETECT_DEFAULT_CONSTRUCT(Tuple)
     DETECT_DEFAULT_CONSTRUCT(TupleGetItem)
-    DETECT_DEFAULT_CONSTRUCT(Function)
+    DETECT_CONSTRUCT(Function, {
+      if (!op->IsPrimitive()) {
+        ExprVisitor::VisitExpr_(op);
+      }
+    })
     DETECT_DEFAULT_CONSTRUCT(Op)
     DETECT_DEFAULT_CONSTRUCT(Call)
     DETECT_CONSTRUCT(Let, {
@@ -69,6 +76,7 @@ FeatureSet DetectFeature(const Expr& expr) {
           fs += fLetRec;
         }
       }
+      ExprVisitor::VisitExpr_(op);
     })
     DETECT_DEFAULT_CONSTRUCT(If)
     DETECT_DEFAULT_CONSTRUCT(RefCreate)
@@ -83,7 +91,7 @@ FeatureSet DetectFeature(const Expr& expr) {
 }

 FeatureSet DetectFeature(const Module& mod) {
-  FeatureSet fs = FeatureSet::NoFeature();
+  FeatureSet fs = FeatureSet::No();
   if (mod.defined()) {
     for (const auto& f : mod->functions) {
       fs += DetectFeature(f.second);
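Note: with visited_.insert(expr) restored, the detector actually sees a non-atomic node the second time it is reached, so fGraph fires on shared subexpressions (hence the test_feature.py update further down). A small sanity check, assuming the detect_feature/Feature Python bindings that test_feature.py exercises:

    from tvm import relay
    from tvm.relay.analysis import detect_feature
    from tvm.relay.feature import Feature

    x = relay.var("x", shape=())
    y = x * x
    assert Feature.fGraph in detect_feature(y + y)    # non-atomic `y` occurs twice

    f = relay.var("f")
    letrec = relay.Let(f, relay.Function([x], f(x)), f(relay.const(1.0)))
    assert Feature.fLetRec in detect_feature(letrec)  # `f` is free in its own value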
diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc
index f7de2a927c66..e8bdc090c947 100644
--- a/src/relay/pass/type_infer.cc
+++ b/src/relay/pass/type_infer.cc
@@ -139,19 +139,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
   // Perform unification on two types and report the error at the expression
   // or the span of the expression.
   Type Unify(const Type& t1, const Type& t2, const NodeRef& expr) {
-    // TODO(tqchen, jroesch): propagate span to solver
     try {
-      // instantiate higher-order func types when unifying because
-      // we only allow polymorphism at the top level
-      Type first = t1;
-      Type second = t2;
-      if (auto* ft1 = t1.as<FuncTypeNode>()) {
-        first = InstantiateFuncType(ft1);
-      }
-      if (auto* ft2 = t2.as<FuncTypeNode>()) {
-        second = InstantiateFuncType(ft2);
-      }
-      return solver_.Unify(first, second, expr);
+      return solver_.Unify(t1, t2, expr);
     } catch (const dmlc::Error &e) {
       this->ReportFatalError(
         expr,
diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc
index 38870762d840..743a4c7774b8 100644
--- a/src/relay/pass/type_solver.cc
+++ b/src/relay/pass/type_solver.cc
@@ -289,30 +289,44 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
     const auto* ftn = tn.as<FuncTypeNode>();
     if (!ftn
         || op->arg_types.size() != ftn->arg_types.size()
-        || op->type_params.size() != ftn->type_params.size()
         || op->type_constraints.size() != ftn->type_constraints.size()) {
       return Type(nullptr);
     }

+    // without loss of generality, suppose op->type_params.size() >= ftn->type_params.size().
+    if (op->type_params.size() < ftn->type_params.size()) {
+      return VisitType_(ftn, GetRef<FuncType>(op));
+    }
+
     // remap type vars so they match
     Map<TypeVar, Type> subst_map;
-    for (size_t i = 0; i < op->type_params.size(); i++) {
-      subst_map.Set(ftn->type_params[i], op->type_params[i]);
+    tvm::Array<TypeVar> ft_type_params;
+    for (size_t i = 0; i < ftn->type_params.size(); ++i) {
+      subst_map.Set(op->type_params[i], ftn->type_params[i]);
+      ft_type_params.push_back(op->type_params[i]);
+    }
+
+    for (size_t i = ftn->type_params.size(); i < op->type_params.size(); ++i) {
+      subst_map.Set(op->type_params[i], IncompleteTypeNode::make(kType));
     }

-    auto ft1 = GetRef<FuncType>(op);
-    auto ft2 = Downcast<FuncType>(Bind(GetRef<FuncType>(ftn), subst_map));
+    FuncType ft = FuncTypeNode::make(op->arg_types,
+                                     op->ret_type,
+                                     ft_type_params,
+                                     op->type_constraints);
+    auto ft1 = Downcast<FuncType>(Bind(ft, subst_map));
+    auto ft2 = GetRef<FuncType>(ftn);

     Type ret_type = Unify(ft1->ret_type, ft2->ret_type);

     std::vector<Type> arg_types;
-    for (size_t i = 0; i < ft1->arg_types.size(); i++) {
+    for (size_t i = 0; i < ft2->arg_types.size(); ++i) {
       Type arg_type = Unify(ft1->arg_types[i], ft2->arg_types[i]);
       arg_types.push_back(arg_type);
     }

     std::vector<TypeConstraint> type_constraints;
-    for (size_t i = 0; i < ft1->type_constraints.size(); i++) {
+    for (size_t i = 0; i < ft1->type_constraints.size(); ++i) {
       Type unified_constraint = Unify(ft1->type_constraints[i],
                                       ft2->type_constraints[i]);
       const auto* tcn = unified_constraint.as<TypeConstraintNode>();
@@ -321,7 +335,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
       type_constraints.push_back(GetRef<TypeConstraint>(tcn));
     }

-    return FuncTypeNode::make(arg_types, ret_type, ft1->type_params, type_constraints);
+    return FuncTypeNode::make(arg_types, ret_type, ft2->type_params, type_constraints);
   }

   Type VisitType_(const RefTypeNode* op, const Type& tn) final {
diff --git a/tests/python/relay/test_feature.py b/tests/python/relay/test_feature.py
index 3e9e6a365e77..8f0e90de0315 100644
--- a/tests/python/relay/test_feature.py
+++ b/tests/python/relay/test_feature.py
@@ -63,7 +63,8 @@ def test_ad():
         Feature.fLet,
         Feature.fRefCreate,
         Feature.fRefRead,
-        Feature.fRefWrite
+        Feature.fRefWrite,
+        Feature.fGraph
     ])

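Note: the unifier now accepts function types with different numbers of type parameters. The side with more parameters is taken as canonical: its leading parameters are renamed onto the other side's, and the surplus becomes fresh IncompleteTypes for the solver to fill in later. A self-contained toy model of just that remapping step (plain Python with hypothetical names, not TVM API):

    import itertools

    fresh = ("?%d" % i for i in itertools.count())

    def remap_type_params(more, fewer):
        """Rename `more`'s prefix onto `fewer`; pad the surplus with holes."""
        assert len(more) >= len(fewer)
        subst = dict(zip(more, fewer))   # match the common prefix
        for v in more[len(fewer):]:      # surplus params become unification holes,
            subst[v] = next(fresh)       # standing in for IncompleteTypeNode::make(kType)
        return subst

    print(remap_type_params(["a", "b", "c"], ["x"]))
    # {'a': 'x', 'b': '?0', 'c': '?1'}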
diff --git a/tests/python/relay/test_pass_to_cps.py b/tests/python/relay/test_pass_to_cps.py
index 7bed13fc76ec..045c92c929c4 100644
--- a/tests/python/relay/test_pass_to_cps.py
+++ b/tests/python/relay/test_pass_to_cps.py
@@ -30,6 +30,20 @@ def rand(dtype='float32', *shape):
     return tvm.nd.array(np.random.rand(*shape).astype(dtype))


+def test_id():
+    x = relay.var("x", shape=[])
+    id = run_infer_type(relay.Function([x], x))
+    id_cps = run_infer_type(to_cps(id))
+
+
+def test_double():
+    t = relay.TypeVar("t")
+    x = relay.var("x", t)
+    f = relay.var("f", relay.FuncType([t], t))
+    double = run_infer_type(relay.Function([f, x], f(f(x)), t, [t]))
+    double_cps = run_infer_type(to_cps(double))
+
+
 # make sure cps work for recursion.
 def test_recursion():
     mod = relay.Module()
diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py
index e8dff7aa9981..3f6b0d2eb895 100644
--- a/tests/python/relay/test_type_infer.py
+++ b/tests/python/relay/test_type_infer.py
@@ -19,6 +19,7 @@
 """
 from tvm import relay
 from tvm.relay import op, transform, analysis
+from tvm.relay.analysis import assert_alpha_equal


 def run_infer_type(expr, mod=None):
@@ -349,6 +350,17 @@ def test_adt_match_type_annotations():
     assert ft.checked_type == relay.FuncType([tt], relay.TupleType([]))


+def test_let_polymorphism():
+    id = relay.Var("id")
+    xt = relay.TypeVar("xt")
+    x = relay.Var("x", xt)
+    body = relay.Tuple([id(relay.const(1)), id(relay.Tuple([]))])
+    body = relay.Let(id, relay.Function([x], x, xt, [xt]), body)
+    body = run_infer_type(body)
+    int32 = relay.TensorType((), "int32")
+    assert_alpha_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])]))
+
+
 if __name__ == "__main__":
     test_free_expr()
     test_dual_op()
@@ -366,3 +378,4 @@ def test_adt_match_type_annotations():
     test_constructor_type()
     test_constructor_call()
     test_adt_match()
+    test_let_polymorphism()
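Note: test_double is the case that needs the unifier change: CPS conversion makes `double` polymorphic in an extra answer type, so type-checking the transformed function unifies function types with unequal type-parameter counts. A sketch for inspecting that, mirroring the run_infer_type helper shape used by these test files (Module.from_expr and the "main" entry are assumptions of that helper):

    from tvm import relay
    from tvm.relay.transform import to_cps

    def run_infer_type(expr):
        mod = relay.Module.from_expr(expr)
        mod = relay.transform.InferType()(mod)
        return mod["main"]

    t = relay.TypeVar("t")
    x = relay.var("x", t)
    f = relay.var("f", relay.FuncType([t], t))
    double = run_infer_type(relay.Function([f, x], f(f(x)), t, [t]))
    print(double.checked_type)                          # fn<t>(fn(t) -> t, t) -> t
    print(run_infer_type(to_cps(double)).checked_type)  # gains an answer-type parameter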