Skip to content

Commit

Permalink
Fix coerce in Infer.subsCheck (#1060)
Browse files Browse the repository at this point in the history
* Fix coerce in Infer.subsCheck

* remove an unreachable branch

* Add an ill-typed test

* try to increase coverage
  • Loading branch information
johnynek authored Oct 22, 2023
1 parent a746a82 commit b324264
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 19 deletions.
39 changes: 20 additions & 19 deletions core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,7 @@ object Infer {
} yield TypedExpr.coerceFn(a1s, r2, coarg, cores, ks)

/*
* If t <:< rho then coerce to rho
* invariant: second argument is in weak prenex form, which means that all
* the covariant positions have lifted the ForAlls out, e.g.
* forall a. a -> (forall b. b -> b)
Expand All @@ -507,6 +508,7 @@ object Infer {
subsCheckRho2(rhot, rho, left, right)
}

// if t <:< rho, then coerce to rho
def subsCheckRho2(t: Type.Rho, rho: Type.Rho, left: Region, right: Region): Infer[TypedExpr.Coerce] =
// get the kinds to make sure they are well kinded
kindOf(t, left).product(kindOf(rho, right)) *>
Expand Down Expand Up @@ -551,7 +553,7 @@ object Infer {
// should we coerce to t2? Seems like... but copying previous code
_ <- subsCheck(l1, l2, left, right)
ks <- checkedKinds
} yield TypedExpr.coerceRho(rho1, ks)
} yield TypedExpr.coerceRho(ta, ks)
case (ta@Type.TyApply(l1, r1), rho2) =>
for {
kl <- kindOf(l1, left)
Expand All @@ -571,8 +573,7 @@ object Infer {
}
_ <- subsCheck(l1, l2, left, right)
ks <- checkedKinds
// should we coerce to t2? Seems like... but copying previous code
} yield TypedExpr.coerceRho(ta, ks)
} yield TypedExpr.coerceRho(rho2, ks)
case (t1, t2) =>
// rule: MONO
for {
Expand Down Expand Up @@ -865,14 +866,14 @@ object Infer {
extendEnv(name, tpe) {
checkSigma(rhs, tpe).parProduct(typeCheckRho(body, expect))
}
case _ =>
case notAnnotated =>
newMetaType(Kind.Type) // the kind of a let value is a Type
.flatMap { rhsTpe =>
extendEnv(name, rhsTpe) {
for {
// the type variable needs to be unified with varT
// note, varT could be a sigma type, it is not a Tau or Rho
typedRhs <- inferSigmaMeta(rhs, Some((name, rhsTpe, region(rhs))))
typedRhs <- inferSigmaMeta(notAnnotated, Some((name, rhsTpe, region(notAnnotated))))
varT = typedRhs.getType
// we need to overwrite the metavariable now with the full type
typedBody <- extendEnv(name, varT)(typeCheckRho(body, expect))
Expand Down Expand Up @@ -949,42 +950,42 @@ object Infer {
tbranches <- branches.parTraverse { case (p, r) =>
inferBranch(p, check, r)
}
(rho, regRho, resBranches) <- narrowBranches(tbranches)
(rho, regRho, resBranches) <- widenBranches(tbranches)
_ <- infer.set((rho, regRho))
} yield TypedExpr.Match(tsigma, resBranches, tag)
}
}
}
}

def narrowBranches[A: HasRegion](branches: NonEmptyList[(Pattern, (TypedExpr.Rho[A], Type.Rho))]): Infer[(Type.Rho, Region, NonEmptyList[(Pattern, TypedExpr.Rho[A])])] = {
def widenBranches[A: HasRegion](branches: NonEmptyList[(Pattern, (TypedExpr.Rho[A], Type.Rho))]): Infer[(Type.Rho, Region, NonEmptyList[(Pattern, TypedExpr.Rho[A])])] = {

def minBy[M[_]: Monad, B](head: B, tail: List[B])(lteq: (B, B) => M[Boolean]): M[B] =
def maxBy[M[_]: Monad, B](head: B, tail: List[B])(gteq: (B, B) => M[Boolean]): M[B] =
tail match {
case Nil => Monad[M].pure(head)
case h :: tail =>
lteq(head, h)
gteq(head, h)
.flatMap { keep =>
val next = if (keep) head else h
minBy(next, tail)(lteq)
maxBy(next, tail)(gteq)
}
}

def ltEq[K](left: (TypedExpr[A], K), right: (TypedExpr[A], K)): Infer[Boolean] = {
def gtEq[K](left: (TypedExpr[A], K), right: (TypedExpr[A], K)): Infer[Boolean] = {
val leftTE = left._1
val rightTE = right._1
val lt = leftTE.getType
val lr = region(leftTE)
val rt = rightTE.getType
val rr = region(rightTE)
// right <= left if left subsumes right
subsCheck(lt, rt, lr, rr)
// left >= right if right subsumes left
subsCheck(rt, lt, rr, lr)
.peek
.flatMap {
case Right(_) => pure(true)
case Left(_) =>
// maybe the other way around
subsCheck(rt, lt, rr, lr)
subsCheck(lt, rt, lr, rr)
.peek
.flatMap {
case Right(_) =>
Expand All @@ -1000,7 +1001,7 @@ object Infer {
val withIdx = branches.zipWithIndex.map { case ((p, (te, tpe)), idx) => (te, (p, tpe, idx)) }

for {
(minRes, (minPat, resTRho, minIdx)) <- minBy(withIdx.head, withIdx.tail)((a, b) => ltEq(a, b))
(minRes, (minPat, resTRho, minIdx)) <- maxBy(withIdx.head, withIdx.tail)((a, b) => gtEq(a, b))
resRegion = region(minRes)
resBranches <- withIdx.parTraverse { case (te, (p, tpe, idx)) =>
if (idx != minIdx) {
Expand Down Expand Up @@ -1283,6 +1284,7 @@ object Infer {
def inferSigma[A: HasRegion](e: Expr[A]): Infer[TypedExpr[A]] =
inferSigmaMeta(e, None)

// invariant: if meta.isDefined then e is not Expr.Annotated
def inferSigmaMeta[A: HasRegion](e: Expr[A], meta: Option[(Identifier, Type.TyMeta, Region)]): Infer[TypedExpr[A]] = {
def unifySelf(tpe: Type.Rho): Infer[Map[Name, Type]] =
meta match {
Expand All @@ -1303,8 +1305,6 @@ object Infer {
case Some((_, tpe, rtpe)) =>
def maybeUnified(e: Expr[A]): Infer[Unit] =
e match {
case Expr.Annotation(e1, t, _) =>
unifyType(tpe, t, rtpe, region(e)) *> maybeUnified(e1)
case Expr.Lambda(args, res, _) =>
unifyFn(args.length, tpe, rtpe, region(e) - region(res)).void
case _ =>
Expand Down Expand Up @@ -1377,9 +1377,9 @@ object Infer {
expr match {
case Expr.Annotated(tpe) =>
extendEnv(name, tpe)(checkSigma(expr, tpe))
case _ =>
case notAnnotated =>
newMetaType(Kind.Type).flatMap { tpe =>
extendEnv(name, tpe)(typeCheckMeta(expr, Some((name, tpe, region(expr)))))
extendEnv(name, tpe)(typeCheckMeta(notAnnotated, Some((name, tpe, region(notAnnotated)))))
}
}

Expand All @@ -1399,6 +1399,7 @@ object Infer {
}
}

// Invariant: if optMeta.isDefined then t is not Expr.Annotated
private def typeCheckMeta[A: HasRegion](t: Expr[A], optMeta: Option[(Identifier, Type.TyMeta, Region)]): Infer[TypedExpr[A]] = {
def run(t: Expr[A]) = inferSigmaMeta(t, optMeta).flatMap(zonkTypedExpr _)

Expand Down
15 changes: 15 additions & 0 deletions core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3167,4 +3167,19 @@ test3 = Assertion(last("foo") matches Some(.'o'), "last test")
all = TestSuite("chars", [test1, test2, test3])
"""), "Foo", 3)
}

test("test universal quantified list match") {
runBosatsuTest(List("""
package Foo
empty: (forall a. List[a]) = []
res = match empty:
case []: 0
case [_, *_]: 1
test = Assertion(res matches 0, "one")
"""), "Foo", 1)

}
}
39 changes: 39 additions & 0 deletions core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,10 @@ class RankNInferTest extends AnyFunSuite {
assertTypesDisjoint("Int -> Unit", "String")
assertTypesDisjoint("Int -> Unit", "String -> a")
assertTypesUnify("forall a. Int", "Int")

// Test unbound vars
assertTypesDisjoint("a", "Int")
assertTypesDisjoint("Int", "a")
}

test("Basic inferences") {
Expand Down Expand Up @@ -1138,4 +1142,39 @@ foo = (
""", "Foo")
}

test("widening inside a match") {
parseProgram("""#
enum B: True, False
def not(b):
match b:
case True: False
case False: True
def branch(x):
match x:
case True: (x -> x): forall a. a -> a
case False: i -> not(i)
res = branch(True)(True)
""", "B")

parseProgramIllTyped("""#
enum B: True, False
def not(b):
match b:
case True: False
case False: True
def branch[a](x: B) -> (a -> a):
match x:
case True: (x -> x): forall a. a -> a
case False: i -> not(i)
res = branch(True)(True)
""")
}

}

0 comments on commit b324264

Please sign in to comment.