Skip to content

Commit

Permalink
Support polymorphic recursion (#1041)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek authored Sep 16, 2023
1 parent ec9414f commit 493c92c
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 25 deletions.
5 changes: 5 additions & 0 deletions core/src/main/scala/org/bykn/bosatsu/Expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ object Expr {
case None => expr
case Some(nel) =>
expr match {
case Annotation(expr, tpe, tag) =>
val tpeFrees = Type.freeBoundTyVars(tpe :: Nil).toSet
// these are the frees that are also in tpeArgs
val freeArgs = tpeArgs.filter { case (n, _) => tpeFrees(n) }
Annotation(forAll(tpeArgs, expr), Type.forAll(freeArgs, tpe), tag)
case Generic(typeVars, in) =>
Generic(nel ::: typeVars, in)
case notAnn => Generic(nel, notAnn)
Expand Down
12 changes: 6 additions & 6 deletions core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ final class SourceConverter(
case Some(k) => k
})
}
val gen = Expr.Generic(bs, lambda)
val gen = Expr.forAll(bs.toList, lambda)
val freeVarsList = Expr.freeBoundTyVars(lambda)
val freeVars = freeVarsList.toSet
val notFreeDecl = bs.exists { case (a, _) => !freeVars(a) }
Expand Down Expand Up @@ -172,10 +172,10 @@ final class SourceConverter(
}
}

private def apply(decl: Declaration, bound: Set[Bindable], topBound: Set[Bindable]): Result[Expr[Declaration]] = {
private def fromDecl(decl: Declaration, bound: Set[Bindable], topBound: Set[Bindable]): Result[Expr[Declaration]] = {
implicit val parAp = SourceConverter.parallelIor
def loop(decl: Declaration) = apply(decl, bound, topBound)
def withBound(decl: Declaration, newB: Iterable[Bindable]) = apply(decl, bound ++ newB, topBound)
def loop(decl: Declaration) = fromDecl(decl, bound, topBound)
def withBound(decl: Declaration, newB: Iterable[Bindable]) = fromDecl(decl, bound ++ newB, topBound)

decl match {
case Annotation(term, tpe) =>
Expand Down Expand Up @@ -1168,7 +1168,7 @@ final class SourceConverter(
stmt match {
case Right(Right((nm, decl))) =>

val r = apply(decl, Set.empty, topBound).map((nm, RecursionKind.NonRecursive, _) :: Nil)
val r = fromDecl(decl, Set.empty, topBound).map((nm, RecursionKind.NonRecursive, _) :: Nil)
// make sure all the free types are Generic
// we have to do this at the top level because in Declaration => Expr
// we allow closing over type variables defined at a higher level
Expand All @@ -1188,7 +1188,7 @@ final class SourceConverter(
d.region,
success(defstmt.result.get))(
{ (res: OptIndent[Declaration]) =>
apply(res.get, argGroups.flatten.iterator.flatMap(_.names).toSet + boundName, topBound1)
fromDecl(res.get, argGroups.flatten.iterator.flatMap(_.names).toSet + boundName, topBound1)
})

val r = lam.map { (l: Expr[Declaration]) =>
Expand Down
54 changes: 35 additions & 19 deletions core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -846,24 +846,35 @@ object Infer {
// After we typecheck we see if this is truly recursive so
// compilers/evaluation can possibly optimize non-recursive
// cases differently
newMetaType(Kind.Type) // the kind of a let value is a Type
.flatMap { rhsTpe =>
extendEnv(name, rhsTpe) {
for {
// the type variable needs to be unified with varT
// note, varT could be a sigma type, it is not a Tau or Rho
typedRhs <- inferSigmaMeta(rhs, Some((name, rhsTpe, region(rhs))))
varT = typedRhs.getType
// we need to overwrite the metavariable now with the full type
typedBody <- extendEnv(name, varT)(typeCheckRho(body, expect))
// TODO: a more efficient algorithm would do this top down
// for each top level TypedExpr and build it bottom up.
// we could do this after all typechecking is done
frees = TypedExpr.freeVars(typedRhs :: Nil)
isRecursive = RecursionKind.recursive(frees.contains(name))
} yield TypedExpr.Let(name, typedRhs, typedBody, isRecursive, tag)
}
val rhsBody = rhs match {
case Annotation(expr, tpe, tag) =>
extendEnv(name, tpe) {
checkSigma(expr, tpe).product(typeCheckRho(body, expect))
}
case _ =>
newMetaType(Kind.Type) // the kind of a let value is a Type
.flatMap { rhsTpe =>
extendEnv(name, rhsTpe) {
for {
// the type variable needs to be unified with varT
// note, varT could be a sigma type, it is not a Tau or Rho
typedRhs <- inferSigmaMeta(rhs, Some((name, rhsTpe, region(rhs))))
varT = typedRhs.getType
// we need to overwrite the metavariable now with the full type
typedBody <- extendEnv(name, varT)(typeCheckRho(body, expect))
} yield (typedRhs, typedBody)
}
}
}

rhsBody.map { case (rhs, body) =>
// TODO: a more efficient algorithm would do this top down
// for each top level TypedExpr and build it bottom up.
// we could do this after all typechecking is done
val frees = TypedExpr.freeVars(rhs :: Nil)
val isRecursive = RecursionKind.recursive(frees.contains(name))
TypedExpr.Let(name, rhs, body, isRecursive, tag)
}
}
else {
// In this branch, we typecheck the rhs *without* name in the environment
Expand Down Expand Up @@ -1335,8 +1346,13 @@ object Infer {

private def recursiveTypeCheck[A: HasRegion](name: Bindable, expr: Expr[A]): Infer[TypedExpr[A]] =
// values are of kind Type
newMetaType(Kind.Type).flatMap { tpe =>
extendEnv(name, tpe)(typeCheckMeta(expr, Some((name, tpe, region(expr)))))
expr match {
case Expr.Annotation(e, tpe, _) =>
extendEnv(name, tpe)(checkSigma(e, tpe))
case _ =>
newMetaType(Kind.Type).flatMap { tpe =>
extendEnv(name, tpe)(typeCheckMeta(expr, Some((name, tpe, region(expr)))))
}
}


Expand Down
43 changes: 43 additions & 0 deletions core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2973,4 +2973,47 @@ def last[a](nel: NEList[a]) -> a:
test = Assertion(last(One(True)), "")
"""), "Generic", 1)
}

test("support polymorphic recursion") {
runBosatsuTest(
List("""
package PolyRec
enum Nat: NZero, NSucc(n: Nat)
def poly_rec(count: Nat, a: a) -> a:
recur count:
case NZero: a
case NSucc(prev):
# make a call with a different type
(_, b) = poly_rec(prev, ("foo", a))
b
test = Assertion(True, "")
""")
, "PolyRec", 1)

runBosatsuTest(
List("""
package PolyRec
enum Nat: NZero, NSucc(n: Nat)
def call(a):
# TODO it's weird that removing the [a] breaks this
# if a type isn't mentioned in an outer scope, we should assume it's local
def poly_rec[a](count: Nat, a: a) -> a:
recur count:
case NZero: a
case NSucc(prev):
# make a call with a different type
(_, b) = poly_rec(prev, ("foo", a))
b
# call a polymorphic recursion internally to exercise different code paths
poly_rec(NZero, a)
test = Assertion(True, "")
""")
, "PolyRec", 1)
}
}

0 comments on commit 493c92c

Please sign in to comment.