Skip to content

Commit

Permalink
Simplify and unify free type variable handling (#1040)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek authored Sep 10, 2023
1 parent ec40f58 commit dd86eea
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 46 deletions.
25 changes: 17 additions & 8 deletions core/src/main/scala/org/bykn/bosatsu/Expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,20 @@ object Expr {
case class Match[T](arg: Expr[T], branches: NonEmptyList[(Pattern[(PackageName, Constructor), Type], Expr[T])], tag: T) extends Expr[T]


def forAll[A](tpeArgs: List[(Type.Var.Bound, Kind)], expr: Expr[A]): Expr[A] =
NonEmptyList.fromList(tpeArgs) match {
case None => expr
case Some(nel) =>
expr match {
case Generic(typeVars, in) =>
Generic(nel ::: typeVars, in)
case notAnn => Generic(nel, notAnn)
}
}

def quantifyFrees[A](expr: Expr[A]): Expr[A] =
forAll(freeBoundTyVars(expr).map((_, Kind.Type)), expr)

/**
* Report all the Bindable names refered to in the given Expr.
* this can be used to allocate names that can never shadow
Expand Down Expand Up @@ -133,13 +147,16 @@ object Expr {
}
}

// Returns a distinct list of free bound type variables
// in the order they were encountered in traversal
def freeBoundTyVars[A](expr: Expr[A]): List[Type.Var.Bound] = {
val w = traverseType(expr, Set.empty) { (t, bound) =>
val frees = Chain.fromSeq(Type.freeBoundTyVars(t :: Nil))
Writer(frees.filterNot(bound), t)
}
w.written.iterator.toList.distinct
}

/**
* Here we substitute any free bound variables with skolem variables
*
Expand All @@ -157,14 +174,6 @@ object Expr {
* running inference, then quantifying over that skolem
* variable.
*/
def skolemizeFreeVars[F[_]: Applicative, A](expr: Expr[A])(newSkolemTyVar: (Type.Var.Bound, Kind) => F[Type.Var.Skolem]): Option[F[(NonEmptyList[Type.Var.Skolem], Expr[A])]] = {
val frees = freeBoundTyVars(expr)
NonEmptyList.fromList(frees)
.map { tvs =>
skolemizeVars[F, A](tvs.map { b => (b, Kind.Type) }, expr)(newSkolemTyVar)
}
}

def skolemizeVars[F[_]: Applicative, A](vs: NonEmptyList[(Type.Var.Bound, Kind)], expr: Expr[A])(newSkolemTyVar: (Type.Var.Bound, Kind) => F[Type.Var.Skolem]): F[(NonEmptyList[Type.Var.Skolem], Expr[A])] = {
vs.traverse { case (b, k) => newSkolemTyVar(b, k) }
.map { skVs =>
Expand Down
13 changes: 10 additions & 3 deletions core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,6 @@ final class SourceConverter(
case (_, Padding(_, in)) => withBound(in, defstmt.name :: Nil)
}
val newBindings = defstmt.name :: defstmt.args.toList.flatMap(_.patternNames)
// TODO
val lambda = toLambdaExpr(defstmt, decl.region, success(decl))({ res => withBound(res._1.get, newBindings) })

(inExpr, lambda).parMapN { (in, lam) =>
Expand Down Expand Up @@ -1170,7 +1169,11 @@ final class SourceConverter(
case Right(Right((nm, decl))) =>

val r = apply(decl, Set.empty, topBound).map((nm, RecursionKind.NonRecursive, _) :: Nil)
(topBound + nm, r)
// make sure all the free types are Generic
// we have to do this at the top level because in Declaration => Expr
// we allow closing over type variables defined at a higher level
val r1 = r.map { exs => exs.map { case (n, r, e) => (n, r, Expr.quantifyFrees(e)) } }
(topBound + nm, r1)

case Right(Left(d @ Def(defstmt@DefStatement(_, _, argGroups, _, _)))) =>
// using body for the outer here is a bummer, but not really a good outer otherwise
Expand All @@ -1193,7 +1196,11 @@ final class SourceConverter(
val rec =
if (UnusedLetCheck.freeBound(l).contains(boundName)) RecursionKind.Recursive
else RecursionKind.NonRecursive
(boundName, rec, l) :: Nil
// make sure all the free types are Generic
// we have to do this at the top level because in Declaration => Expr
// we allow closing over type variables defined at a higher level
val l1 = Expr.quantifyFrees(l)
(boundName, rec, l1) :: Nil
}
(topBound1, r)
case Left(ExternalDef(n, _, _)) =>
Expand Down
38 changes: 3 additions & 35 deletions core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1357,43 +1357,11 @@ object Infer {

private def typeCheckMeta[A: HasRegion](t: Expr[A], optMeta: Option[(Identifier, Type.TyMeta, Region)]): Infer[TypedExpr[A]] = {
def run(t: Expr[A]) = inferSigmaMeta(t, optMeta).flatMap(zonkTypedExpr _)
/*
* This is a deviation from the paper.
* We are allowing a syntax like:
*
* def identity(x: a) -> a:
* x
*
* or:
*
* def foo(x: a): x
*
* We handle this by converting a to a skolem variable,
* running inference, then quantifying over that skolem
* variable.
*
* TODO Kind we need to know the kinds of these skolems
*/

val optSkols = t match {
case Expr.Generic(vs, e) =>
Some(for {
skolsE1 <- Expr.skolemizeVars(vs, e)(newSkolemTyVar(_, _))
(skols, e1) = skolsE1
/*
* This is a bit weird, but for top-level defs, the type parameters
* only need to be a superset of free variables. You don't need
* to declare them all. On inner variables of a def we assume
* it is free in the top level def unless you declare it generic.
* Maybe worth revisiting and require ALL free variables declared
* or none of them...
*/
optMore = Expr.skolemizeFreeVars(e1)(newSkolemTyVar(_, _))
res <- optMore.fold(pure(skolsE1)) { restSkols =>
restSkols.map { case (sM, eM) => (skols ::: sM, eM) }
}
} yield res)
case notGeneric =>
Expr.skolemizeFreeVars(notGeneric)(newSkolemTyVar(_, _))
Some(Expr.skolemizeVars(vs, e)(newSkolemTyVar(_, _)))
case _ => None
}

optSkols match {
Expand Down

0 comments on commit dd86eea

Please sign in to comment.