Skip to content

Commit

Permalink
fix some existential type unification cases (#1236)
Browse files Browse the repository at this point in the history
* Failing existential test case

* checkpoint after skolemize change

* change skolemize to only use strict covariant paths

* cleanup for PR

* add comment to skolemize

* use the real type, not just a dummy Int
  • Loading branch information
johnynek authored Oct 31, 2024
1 parent cecebd8 commit c9602f3
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 16 deletions.
36 changes: 27 additions & 9 deletions core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ object Infer {
val variances: Map[Type.Const.Defined, Kind]
) {

override def toString() = s"Env($uniq, $vars, $typeCons, $variances)"

def addVars(vt: NonEmptyList[(Name, Type)]): Env =
new Env(uniq, vars = (vars + vt.head) ++ vt.tail, typeCons, variances)

Expand Down Expand Up @@ -467,10 +469,24 @@ object Infer {
* covariant, then C[forall x. D[x]] == forall x. C[D[x]]
*
* this is always true for existential quantification I think, but for
* universal, we need that C is covariant which roughtly means C[x] either
* universal, we need that C is covariant which roughly means C[x] either
* has x in a return position of a function, or not at all, which then
* gives us that (forall x. (A(x) u B(x))) == (forall x A(x)) u (forall x
* B(x)) where A(x) and B(x) represent the union branches of the type C
*
* Here, we only float quantification above completely covariant paths,
* which includes function results returning functions, etc.
* I don't really know why this works, but it is skolemizing in a smaller set of
* cases compared to *any* covariant path (including contra * contra), yet still
* keeping functions in weak-prenex form (which isn't type checked, but is
* asserted in the inference and typechecking process).
*
* So, I guess one argument is we want to skolemize the least we can to make
* inference work, especially with function return types, and this is sufficient
* to do it. It also passes all the current tests.
*
* The paper we base the type inference on doesn't have a type system as complex
* as bosatsu, so we have to generalize it.
*/
private def skolemize(
t: Type,
Expand All @@ -480,11 +496,11 @@ object Infer {
// Invariant: if t is Rho, then result._3 is Rho
def loop(
t: Type,
path: Variance
allCo: Boolean
): Infer[(List[Type.Var.Skolem], List[Type.TyMeta], Type)] =
t match {
case q: Type.Quantified =>
if (path == Variance.co) {
if (allCo) {
val univ = q.forallList
val exists = q.existList
val ty = q.in
Expand All @@ -500,20 +516,20 @@ object Infer {
(exists.map(_._1).iterator.zip(ms) ++
univ.map(_._1).iterator.zip(sksT.iterator)).toMap
)
(sks2, ms2, ty) <- loop(ty1, path)
(sks2, ms2, ty) <- loop(ty1, allCo)
} yield (sks1 ::: sks2, ms ::: ms2, ty)
} else pure((Nil, Nil, t))

case ta @ Type.TyApply(left, right) =>
// Rule PRFUN
// we know the kind of left is k -> x, and right has kind k
// since left: Rho, we know loop(left, path)._3 is Rho
(varianceOfCons(ta, region), loop(left, path))
(varianceOfCons(ta, region), loop(left, allCo))
.flatMapN { case (consVar, (sksl, el, ltpe0)) =>
// due to loop invariant
val ltpe: Type.Rho = ltpe0.asInstanceOf[Type.Rho]
val rightPath = consVar * path
loop(right, rightPath)
val allCoRight = allCo && (consVar == Variance.co)
loop(right, allCoRight)
.map { case (sksr, er, rtpe) =>
(sksl ::: sksr, el ::: er, Type.TyApply(ltpe, rtpe))
}
Expand All @@ -523,7 +539,7 @@ object Infer {
pure((Nil, Nil, other))
}

loop(t, Variance.co).map {
loop(t, true).map {
case (skols, metas, rho: Type.Rho) =>
(skols, metas, rho)
// $COVERAGE-OFF$ this should be unreachable
Expand Down Expand Up @@ -1688,7 +1704,7 @@ object Infer {
// the length of args and varsT must be the same because of unifyFnRho
zipped = args.zip(varsT)
namesVarsT = zipped.map { case ((n, _), t) => (n, t) }
typedBody <- extendEnvList(namesVarsT.toList) {
typedBody <- extendEnvNonEmptyList(namesVarsT) {
// TODO we are ignoring the result of subsCheck here
// should we be coercing a var?
//
Expand All @@ -1697,6 +1713,8 @@ object Infer {
// indicates the testing coverage is incomplete
zipped.parTraverse_ {
case ((_, Some(tpe)), varT) =>
// since a -> b <:< c -> d means, b <:< d and c <:< a
// we check that the varT <:< tpe
subsCheck(varT, tpe, region(term), rr)
case ((_, None), _) => unit
} &>
Expand Down
6 changes: 3 additions & 3 deletions core/src/test/scala/org/bykn/bosatsu/TestUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ object TestUtils {

val testPackage: PackageName = PackageName.parts("Test")

def checkLast(
def checkLast[A](
statement: String
)(fn: TypedExpr[Declaration] => Assertion): Assertion = {
)(fn: TypedExpr[Declaration] => A): A = {
val stmts = Parser.unsafeParse(Statement.parser, statement)
Package.inferBody(testPackage, Nil, stmts).strictToValidated match {
case Validated.Invalid(errs) =>
Expand All @@ -91,7 +91,7 @@ object TestUtils {
err.message(packMap, LocationMap.Colorize.None)
}
.mkString("", "\n==========\n", "\n")
fail("inference failure: " + msg)
sys.error("inference failure: " + msg)
case Validated.Valid(program) =>
// make sure all the TypedExpr are valid
program.lets.foreach { case (_, _, te) => assertValid(te) }
Expand Down
31 changes: 27 additions & 4 deletions core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@ import TestUtils.{checkLast, testPackage}

import Identifier.Constructor

import org.scalatest.funsuite.AnyFunSuite

import cats.syntax.all._

class RankNInferTest extends AnyFunSuite {
class RankNInferTest extends munit.FunSuite {

val emptyRegion: Region = Region(0, 0)

Expand Down Expand Up @@ -254,7 +252,9 @@ class RankNInferTest extends AnyFunSuite {
assert_:<:("(exists a. a) -> Int", "Int -> Int")
assertTypesUnify("(exists a. a) -> Int", "forall a. a -> Int")
assertTypesUnify("exists a. List[a]", "List[exists a. a]")
assertTypesUnify("Int -> (exists a. a)", "exists a. (Int -> a)")
assertTypesUnify("exists a. (Int -> a)", "Int -> (exists a. a)")
assertTypesUnify("(exists a. a) -> Int", "(exists a. a) -> Int")
assert_:<:("forall a. a -> a", "(exists a. a) -> (exists a. a)")

assert_:<:("forall a. a -> Int", "(forall a. a) -> Int")
assertTypesUnify(
Expand Down Expand Up @@ -1929,4 +1929,27 @@ struct Foo
f = Foo
""")
}

test("identity function with existential") {
parseProgram(
"""
struct Prog[a: -*, e: +*, b: +*]
def pass_thru(f: Prog[exists a. a, e, b]) -> Prog[exists a. a, e, b]:
f
""",
"forall e, b. Prog[exists a. a, e, b] -> Prog[exists a. a, e, b]"
)

parseProgram(
"""
struct Foo
struct Prog[a: -*, e: +*, b: +*]
def pass_thru(f: Prog[exists a. a, Foo, Foo]) -> Prog[exists a. a, Foo, Foo]:
f
""",
"Prog[exists a. a, Foo, Foo] -> Prog[exists a. a, Foo, Foo]"
)
}
}

0 comments on commit c9602f3

Please sign in to comment.