Skip to content

Commit

Permalink
get tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek committed Nov 19, 2023
1 parent 346238a commit 469fcf9
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 37 deletions.
15 changes: 8 additions & 7 deletions core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -416,20 +416,18 @@ object TypedExpr {
val forAlls = typeArgs.collect { case (nk, false) => nk }
val exists = typeArgs.collect { case (nk, true) => nk }
val q = quantVars(forallList = forAlls, existList = exists, r)
println(s"forAlls = $forAlls exists = $exists ${r.repr} quantVars = ${q.repr}")
//println(s"forAlls = $forAlls exists = $exists ${r.repr} quantVars = ${q.repr}")
q
}
}

type Name = (Option[PackageName], Identifier)

def quantifyMetas(envList: => List[Type], metas: SortedSet[Type.Meta], te: TypedExpr[A]): F[TypedExpr[A]] =
if (metas.isEmpty) Applicative[F].pure(te)
else {
for {
envTypeVars <- getMetaTyVars(envList)
localMetas = metas -- envTypeVars
_ = println(s"localMetas = $localMetas in ${te.repr}")
localMetas = metas.diff(envTypeVars)
//_ = println(s"localMetas = $localMetas in ${te.repr}")
q <- quantify0(localMetas.toList, te)
} yield q
}
Expand Down Expand Up @@ -469,10 +467,12 @@ object TypedExpr {

getMetaTyVars(te1.allTypes.toList)
.flatMap(quantifyMetas(envList, _, te1))
/*
.map { res =>
println(s"quantifyFree, teSkols=${teSkols} ${te.repr} => ${te1.repr} => ${res.repr}")
res
}
*/
}

/*
Expand All @@ -487,7 +487,7 @@ object TypedExpr {
def deepQuantify(env: Set[Type], te: TypedExpr[A]): F[TypedExpr[A]] =
quantifyFree(env, te).flatMap {
case Generic(quant, in) =>
assert(te != in, s"${te.repr} quantifyFree => ${in.repr}")
//assert(te != in, s"${te.repr} quantifyFree => ${in.repr}")
deepQuantify(env + te.getType, in).map { in1 =>
quantVars(quant.forallList, quant.existList, in1)
}
Expand Down Expand Up @@ -1191,7 +1191,8 @@ object TypedExpr {
def forAll[A](params: NonEmptyList[(Type.Var.Bound, Kind)], expr: TypedExpr[A]): TypedExpr[A] =
quantVars(forallList = params.toList, Nil, expr)

def quantVars[A](forallList: List[(Type.Var.Bound, Kind)],
def quantVars[A](
forallList: List[(Type.Var.Bound, Kind)],
existList: List[(Type.Var.Bound, Kind)],
expr: TypedExpr[A]): TypedExpr[A] = {

Expand Down
164 changes: 142 additions & 22 deletions core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ sealed abstract class Infer[+A] {
final def mapEither[B](fn: A => Either[Error, B]): Infer[B] =
Infer.Impl.MapEither(this, fn)

// $COVERAGE-OFF$ useful for debugging
final def debugError(fn: Error => Unit): Infer[A] =
peek.mapEither { err =>
err match {
case Left(err) => println(fn(err))
case _ => ()
}

err
}
// $COVERAGE-ON$

final def runVar(
v: Map[Infer.Name, Type],
tpes: Map[(PackageName, Constructor), Infer.Cons],
Expand Down Expand Up @@ -117,7 +129,9 @@ object Infer {


def getKind(t: Type, region: Region): Either[Error, Kind] =
kindCache(t).leftMap(_(region))
kindCache(t).leftMap { err =>
err(region)
}
}

object Env {
Expand Down Expand Up @@ -335,7 +349,10 @@ object Infer {

def kindOf(t: Type, r: Region): Infer[Kind] =
GetEnv.mapEither { env =>
env.getKind(t, r)
env.getKind(t, r).leftMap { err =>
//println(s"kindOf($t, $r) failed => $err")
err
}
}

private val checkedKinds: Infer[Type => Option[Kind]] = {
Expand Down Expand Up @@ -610,7 +627,7 @@ object Infer {
} yield coerce
case (rho1, ta@Type.TyApply(l2, r2)) =>
for {
kl <- kindOf(l2, right)
kl <- kindOf(l2, right)
kr <- kindOf(r2, right)
l1r1 <- unifyTyApp(rho1, kl, kr, left, right)
(l1, r1) = l1r1
Expand Down Expand Up @@ -649,6 +666,23 @@ object Infer {
_ <- subsCheck(l1, l2, left, right)
ks <- checkedKinds
} yield TypedExpr.coerceRho(rho2, ks)
case (
t1 @ Type.TyVar(Type.Var.Skolem(_, k1, true, _)),
t2 @ Type.TyMeta(m2 @ Type.Meta(k2, _, true, _)))
if Kind.leftSubsumesRight(k2, k1) =>
// maybe unify existential skolems, if neither are in the environment
// then allocate a new skolem
for {
env <- getEnv
tyVars = Type.metaTvs(env.values.toList)
_ <- if (tyVars.contains(m2)) {
// this meta already escaped into the environment
// this unify will fail
unify(t1, t2, left, right)
}
else unit
ck <- checkedKinds
} yield TypedExpr.coerceRho(t2, ck)
case (t1, t2) =>
// rule: MONO
for {
Expand Down Expand Up @@ -912,7 +946,7 @@ object Infer {
envTpes: Infer[List[Type]],
a: A,
onErr: NonEmptyList[Type] => Error)(
fn: (A, NonEmptyList[Type.TyMeta]) => A
fn: (A, NonEmptyList[Type.TyMeta]) => Infer[A]
): Infer[A] =
metas match {
case Nil => pure(a)
Expand All @@ -927,7 +961,7 @@ object Infer {

NonEmptyList.fromList(badList) match {
case None =>
pure(fn(a, NonEmptyList(h, t)))
fn(a, NonEmptyList(h, t))
case Some(badTvs) => fail(onErr(badTvs))
}
}
Expand All @@ -937,11 +971,11 @@ object Infer {
declared: Type,
region: Region,
envTpes: Infer[List[Type]])(
fn: Type.Rho => Infer[FunctionK[F, Lambda[x => G[TypedExpr[x]]]]])(
fn: (List[Type.TyMeta], Type.Rho) => Infer[FunctionK[F, Lambda[x => G[TypedExpr[x]]]]])(
onErr: NonEmptyList[Type] => Error): Infer[FunctionK[F, Lambda[x => G[TypedExpr[x]]]]] =
for {
(skols, metas, rho) <- skolemize(declared, region)
coerce <- fn(rho)
coerce <- fn(metas, rho)
// if there are no skolem variables, we can shortcut here, because empty.filter(fn) == empty
resSkols <- checkEscapeSkols(
skols,
Expand All @@ -954,15 +988,19 @@ object Infer {
declared,
envTpes,
resSkols,
// TODO maybe this function should go ahead and quantify
onErr) { (coerce, _) => coerce }
onErr) { (coerce, _) =>
// TODO maybe this function should go ahead and quantify
pure(coerce)
}
} yield res

// DEEP-SKOL rule
// note, this is identical to subsCheckRho when declared is a Rho type
def subsCheck(inferred: Type, declared: Type, left: Region, right: Region): Infer[TypedExpr.Coerce] =
subsUpper[TypedExpr, cats.Id](declared, right, pure(inferred :: Nil)) {
subsCheckRho(inferred, _, left, right)
subsUpper[TypedExpr, cats.Id](declared, right, pure(inferred :: Nil)) { (_, rho) =>
// TODO: we are ignoring the metas, but we can't easily write them
// with the current design since Coerce can't do any Meta writing
subsCheckRho(inferred, rho, left, right)
} {
Error.SubsumptionCheckFailure(inferred, declared, left, right, _)
}
Expand Down Expand Up @@ -1121,7 +1159,19 @@ object Infer {
}
}
case Annotation(term, tpe, tag) =>
(checkSigma(term, tpe), instSigma(tpe, expect, region(tag)))
val inner = term match {
case Match(arg, branches, mtag) =>
// We push the Annotation down to help with
// existential type checking where each branch
// has a different type
Match(arg, branches.map { case (p, r) =>
(p, Annotation(r, tpe, tag))
},
mtag)
case notMatch => notMatch
}

(checkSigma(inner, tpe), instSigma(tpe, expect, region(tag)))
.parFlatMapN { (typedTerm, coerce) =>
zonkTypedExpr(typedTerm).map(coerce(_))
}
Expand All @@ -1141,6 +1191,8 @@ object Infer {
// are missing here.
inferSigma(term)
.flatMap { tsigma =>
assertNoFree(tsigma.getType, s"line 1160 from $term\n\n${tsigma.repr}")

val check = Expected.Check((tsigma.getType, region(term)))

expect match {
Expand Down Expand Up @@ -1221,9 +1273,12 @@ object Infer {
}
} yield (resTRho, resRegion, resBranches)
}

/*
* we require resT in weak prenex form because we call checkRho with it
* TODO: if sigma is an existential, then maybe we should reset any
* existentials after and leave them opaque? Or maybe if we could
* avoid allocating the metas until we get into the branch
*/
def checkBranch[A: HasRegion](p: Pattern, sigma: Expected.Check[(Type, Region)], res: Expr[A], resT: Type.Rho): Infer[(Pattern, TypedExpr.Rho[A])] =
for {
Expand All @@ -1246,7 +1301,8 @@ object Infer {
*
* TODO: Pattern needs to have a region for each part
*/
def typeCheckPattern(pat: Pattern, sigma: Expected.Check[(Type, Region)], reg: Region): Infer[(Pattern, List[(Bindable, Type)])] =
def typeCheckPattern(pat: Pattern, sigma: Expected.Check[(Type, Region)], reg: Region): Infer[(Pattern, List[(Bindable, Type)])] = {
assertNoFree(sigma.value._1, "in typeCheckPattern line 1266")
pat match {
case GenPattern.WildCard => Infer.pure((pat, Nil))
case GenPattern.Literal(lit) =>
Expand Down Expand Up @@ -1309,6 +1365,7 @@ object Infer {
Infer.pure((l, (splice, lst) :: Nil))
case ListPart.Item(p) =>
// This is a non-splice
assertNoFree(inner, s"line 1329 checkItem($inner, $lst, $e)")
checkPat(p, inner, reg).map { case (p, l) => (ListPart.Item(p), l) }
}
val tpeOfList: Infer[Type] =
Expand Down Expand Up @@ -1338,6 +1395,7 @@ object Infer {
// like in the case of an annotation, we check the type, then
// instantiate a sigma type
// checkSigma(term, tpe) *> instSigma(tpe, expect)
assertNoFree(tpe, s"line 1359 $pat $reg")
for {
patBind <- checkPat(p, tpe, reg)
(p1, binds) = patBind
Expand All @@ -1351,7 +1409,10 @@ object Infer {
// if the pattern arity does not match the arity of the constructor
// but we don't want to error type-checking since we want to show
// the maximimum number of errors to the user
envs <- args.zip(params).parTraverse { case (p, t) => checkPat(p, t, reg) }
envs <- args.zip(params).parTraverse { case (p, t) =>
assertNoFree(t, s"pat: $pat line 1374")
checkPat(p, t, reg)
}
pats = envs.map(_._1)
bindings = envs.map(_._2)
} yield (GenPattern.PositionalStruct(nm, pats), bindings.flatten)
Expand All @@ -1364,6 +1425,7 @@ object Infer {
}
.flatten
}
}

// Unions have to have identical bindings in all branches
def identicalBinds(u: Pattern, binds: NonEmptyList[List[(Bindable, Type)]], reg: Region): Infer[Unit] =
Expand Down Expand Up @@ -1416,8 +1478,8 @@ object Infer {
case (_, fa: Type.Quantified) =>
// we have to instantiate a rho type
instantiate(fa)
.flatMap { case (exSkols, faRho) =>
assert(exSkols.isEmpty, s"TODO: skols found in pattern: $consName $sigma")
.flatMap { case (_, faRho) =>
// TODO: it seems like we shouldn't ignore the existential skolems
loop(revArgs, leftKind, faRho)
}
case ((v0, k) :: rest, _) =>
Expand Down Expand Up @@ -1535,28 +1597,75 @@ object Infer {
e <- env
zrho <- zonkTypedExpr(rho)
q <- TypedExpr.quantify(e, zrho, readMeta _, writeMeta _)
_ = assertNoFree(q.getType, s"line 1563 from quantify(${rho.repr}) => ${q.repr}")
} yield q

// allocate this once and reuse
private val envTail = getEnv.map(_.values.toList)

def quantifyMetas(metas: List[Type.TyMeta]): FunctionK[TypedExpr, Lambda[x => Infer[TypedExpr[x]]]] =
NonEmptyList.fromList(metas) match {
case None =>
new FunctionK[TypedExpr, Lambda[x => Infer[TypedExpr[x]]]] {
def apply[A](fa: TypedExpr[A]): Infer[TypedExpr[A]] = pure(fa)
}
case Some(nel) =>
new FunctionK[TypedExpr, Lambda[x => Infer[TypedExpr[x]]]] {
def apply[A](fa: TypedExpr[A]): Infer[TypedExpr[A]] = {
// all these metas can be set to Var
val used: Set[Type.Var.Bound] = Type.tyVarBinders(fa.allTypes.toList)
val aligned = Type.alignBinders(nel, used)
val bound = aligned.toList.traverseFilter { case (m, n) =>
val meta = m.toMeta
if (meta.existential) writeMeta(m.toMeta, Type.TyVar(n)).as(Some((n, meta.kind)))
else pure(None)
}
// we only need to zonk after doing a write:
// it isnot clear that zonkMeta correctly here because the existentials
// here have been realized to Type.Var now, and and meta pointing at them should
// become visible (no longer hidden)
val zFn = Type.zonk(
metas.iterator.map(_.toMeta).filter(_.existential).to(SortedSet),
readMeta,
writeMeta)
(bound, TypedExpr.zonkMeta(fa)(zFn))
.mapN { (typeArgs, r) =>
val q = TypedExpr.quantVars(forallList = Nil, existList = typeArgs, r)
//println(s"quantifyMetas: exists = $typeArgs ${r.repr} quantVars = ${q.repr}")
q
}
}
}
}

def checkSigma[A: HasRegion](t: Expr[A], tpe: Type): Infer[TypedExpr[A]] = {
//println(s"checkSigma($t): $tpe")
val regionT = region(t)
for {
checkRho <- subsUpper[Lambda[x => (Expr[x], HasRegion[x])], Infer](tpe, regionT, envTail) { rho =>
if (rho == tpe) {
checkRho <- subsUpper[Lambda[x => (Expr[x], HasRegion[x])], Infer](tpe, regionT, envTail) { (metas, rho) =>
if ((rho: Type) === tpe) {
// we don't need to zonk here
pure(checkRhoK(rho))
}
else {
// we need to zonk before we unskolemize because some of the metas could be skolems
pure(checkRhoK(rho).andThenFlatMap[TypedExpr](zonkTypeExprK))
pure(checkRhoK(rho)
.andThenFlatMap[TypedExpr](zonkTypeExprK)
.andThenFlatMap[TypedExpr](quantifyMetas(metas))
)
}
} { badTvs =>
Error.NotPolymorphicEnough(tpe, t, badTvs, regionT)
}
te <- checkRho((t, implicitly[HasRegion[A]]))
} yield te
// TODO: seems we should quantify here
/*
quant <- quantify(getEnv, te)
_ = if (quant != te) {
//println(s"checkSigma ${te.repr} => ${quant.repr}")
} else ()
*/
} yield te //quant
}

/**
Expand Down Expand Up @@ -1585,6 +1694,13 @@ object Infer {
case Left(err) => fail(err)
}
} yield (expr, tpe)

def assertNoFree(t: Type, msg: => String): Unit =
Type.freeBoundTyVars(t :: Nil) match {
case Nil => ()
case nel =>
sys.error(s"expected no free vars in $t, found: $nel\n$msg")
}
}

private def recursiveTypeCheck[A: HasRegion](name: Bindable, expr: Expr[A]): Infer[TypedExpr[A]] =
Expand Down Expand Up @@ -1680,7 +1796,11 @@ object Infer {
}
.flatMap { groupChain =>
val glist = groupChain.toList
extendEnvListPack(pack, glist.map { case (b, _, te) => (b, te.getType) }) {
extendEnvListPack(pack, glist.map { case (b, _, te) =>
val t = te.getType
assertNoFree(t, s"line 1722 $b => ${te.repr}")
(b, t)
}) {
run(tail)
}
.map(glist ::: _)
Expand Down
Loading

0 comments on commit 469fcf9

Please sign in to comment.