diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala index 8a292a8f5..f1bd07de1 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala @@ -492,6 +492,7 @@ object Infer { } yield TypedExpr.coerceFn(a1s, r2, coarg, cores, ks) /* + * If t <:< rho then coerce to rho * invariant: second argument is in weak prenex form, which means that all * the covariant positions have lifted the ForAlls out, e.g. * forall a. a -> (forall b. b -> b) @@ -507,6 +508,7 @@ object Infer { subsCheckRho2(rhot, rho, left, right) } + // if t <:< rho, then coerce to rho def subsCheckRho2(t: Type.Rho, rho: Type.Rho, left: Region, right: Region): Infer[TypedExpr.Coerce] = // get the kinds to make sure they are well kinded kindOf(t, left).product(kindOf(rho, right)) *> @@ -551,7 +553,7 @@ object Infer { // should we coerce to t2? Seems like... but copying previous code _ <- subsCheck(l1, l2, left, right) ks <- checkedKinds - } yield TypedExpr.coerceRho(rho1, ks) + } yield TypedExpr.coerceRho(ta, ks) case (ta@Type.TyApply(l1, r1), rho2) => for { kl <- kindOf(l1, left) @@ -571,8 +573,7 @@ object Infer { } _ <- subsCheck(l1, l2, left, right) ks <- checkedKinds - // should we coerce to t2? Seems like... but copying previous code - } yield TypedExpr.coerceRho(ta, ks) + } yield TypedExpr.coerceRho(rho2, ks) case (t1, t2) => // rule: MONO for { @@ -865,14 +866,14 @@ object Infer { extendEnv(name, tpe) { checkSigma(rhs, tpe).parProduct(typeCheckRho(body, expect)) } - case _ => + case notAnnotated => newMetaType(Kind.Type) // the kind of a let value is a Type .flatMap { rhsTpe => extendEnv(name, rhsTpe) { for { // the type variable needs to be unified with varT // note, varT could be a sigma type, it is not a Tau or Rho - typedRhs <- inferSigmaMeta(rhs, Some((name, rhsTpe, region(rhs)))) + typedRhs <- inferSigmaMeta(notAnnotated, Some((name, rhsTpe, region(notAnnotated)))) varT = typedRhs.getType // we need to overwrite the metavariable now with the full type typedBody <- extendEnv(name, varT)(typeCheckRho(body, expect)) @@ -949,7 +950,7 @@ object Infer { tbranches <- branches.parTraverse { case (p, r) => inferBranch(p, check, r) } - (rho, regRho, resBranches) <- narrowBranches(tbranches) + (rho, regRho, resBranches) <- widenBranches(tbranches) _ <- infer.set((rho, regRho)) } yield TypedExpr.Match(tsigma, resBranches, tag) } @@ -957,34 +958,34 @@ object Infer { } } - def narrowBranches[A: HasRegion](branches: NonEmptyList[(Pattern, (TypedExpr.Rho[A], Type.Rho))]): Infer[(Type.Rho, Region, NonEmptyList[(Pattern, TypedExpr.Rho[A])])] = { + def widenBranches[A: HasRegion](branches: NonEmptyList[(Pattern, (TypedExpr.Rho[A], Type.Rho))]): Infer[(Type.Rho, Region, NonEmptyList[(Pattern, TypedExpr.Rho[A])])] = { - def minBy[M[_]: Monad, B](head: B, tail: List[B])(lteq: (B, B) => M[Boolean]): M[B] = + def maxBy[M[_]: Monad, B](head: B, tail: List[B])(gteq: (B, B) => M[Boolean]): M[B] = tail match { case Nil => Monad[M].pure(head) case h :: tail => - lteq(head, h) + gteq(head, h) .flatMap { keep => val next = if (keep) head else h - minBy(next, tail)(lteq) + maxBy(next, tail)(gteq) } } - def ltEq[K](left: (TypedExpr[A], K), right: (TypedExpr[A], K)): Infer[Boolean] = { + def gtEq[K](left: (TypedExpr[A], K), right: (TypedExpr[A], K)): Infer[Boolean] = { val leftTE = left._1 val rightTE = right._1 val lt = leftTE.getType val lr = region(leftTE) val rt = rightTE.getType val rr = region(rightTE) - // right <= left if left subsumes right - subsCheck(lt, rt, lr, rr) + // left >= right if right subsumes left + subsCheck(rt, lt, rr, lr) .peek .flatMap { case Right(_) => pure(true) case Left(_) => // maybe the other way around - subsCheck(rt, lt, rr, lr) + subsCheck(lt, rt, lr, rr) .peek .flatMap { case Right(_) => @@ -1000,7 +1001,7 @@ object Infer { val withIdx = branches.zipWithIndex.map { case ((p, (te, tpe)), idx) => (te, (p, tpe, idx)) } for { - (minRes, (minPat, resTRho, minIdx)) <- minBy(withIdx.head, withIdx.tail)((a, b) => ltEq(a, b)) + (minRes, (minPat, resTRho, minIdx)) <- maxBy(withIdx.head, withIdx.tail)((a, b) => gtEq(a, b)) resRegion = region(minRes) resBranches <- withIdx.parTraverse { case (te, (p, tpe, idx)) => if (idx != minIdx) { @@ -1283,6 +1284,7 @@ object Infer { def inferSigma[A: HasRegion](e: Expr[A]): Infer[TypedExpr[A]] = inferSigmaMeta(e, None) + // invariant: if meta.isDefined then e is not Expr.Annotated def inferSigmaMeta[A: HasRegion](e: Expr[A], meta: Option[(Identifier, Type.TyMeta, Region)]): Infer[TypedExpr[A]] = { def unifySelf(tpe: Type.Rho): Infer[Map[Name, Type]] = meta match { @@ -1303,8 +1305,6 @@ object Infer { case Some((_, tpe, rtpe)) => def maybeUnified(e: Expr[A]): Infer[Unit] = e match { - case Expr.Annotation(e1, t, _) => - unifyType(tpe, t, rtpe, region(e)) *> maybeUnified(e1) case Expr.Lambda(args, res, _) => unifyFn(args.length, tpe, rtpe, region(e) - region(res)).void case _ => @@ -1377,9 +1377,9 @@ object Infer { expr match { case Expr.Annotated(tpe) => extendEnv(name, tpe)(checkSigma(expr, tpe)) - case _ => + case notAnnotated => newMetaType(Kind.Type).flatMap { tpe => - extendEnv(name, tpe)(typeCheckMeta(expr, Some((name, tpe, region(expr))))) + extendEnv(name, tpe)(typeCheckMeta(notAnnotated, Some((name, tpe, region(notAnnotated))))) } } @@ -1399,6 +1399,7 @@ object Infer { } } + // Invariant: if optMeta.isDefined then t is not Expr.Annotated private def typeCheckMeta[A: HasRegion](t: Expr[A], optMeta: Option[(Identifier, Type.TyMeta, Region)]): Infer[TypedExpr[A]] = { def run(t: Expr[A]) = inferSigmaMeta(t, optMeta).flatMap(zonkTypedExpr _) diff --git a/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala b/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala index ad35d7db5..23c98a635 100644 --- a/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala @@ -3167,4 +3167,19 @@ test3 = Assertion(last("foo") matches Some(.'o'), "last test") all = TestSuite("chars", [test1, test2, test3]) """), "Foo", 3) } + + test("test universal quantified list match") { + runBosatsuTest(List(""" +package Foo + +empty: (forall a. List[a]) = [] + +res = match empty: + case []: 0 + case [_, *_]: 1 + +test = Assertion(res matches 0, "one") +"""), "Foo", 1) + + } } diff --git a/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala b/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala index 6f02a5671..062bf1237 100644 --- a/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala @@ -216,6 +216,10 @@ class RankNInferTest extends AnyFunSuite { assertTypesDisjoint("Int -> Unit", "String") assertTypesDisjoint("Int -> Unit", "String -> a") assertTypesUnify("forall a. Int", "Int") + + // Test unbound vars + assertTypesDisjoint("a", "Int") + assertTypesDisjoint("Int", "a") } test("Basic inferences") { @@ -1138,4 +1142,39 @@ foo = ( """, "Foo") } + + test("widening inside a match") { + parseProgram("""# +enum B: True, False + +def not(b): + match b: + case True: False + case False: True + +def branch(x): + match x: + case True: (x -> x): forall a. a -> a + case False: i -> not(i) + +res = branch(True)(True) +""", "B") + + parseProgramIllTyped("""# +enum B: True, False + +def not(b): + match b: + case True: False + case False: True + +def branch[a](x: B) -> (a -> a): + match x: + case True: (x -> x): forall a. a -> a + case False: i -> not(i) + +res = branch(True)(True) +""") + } + }