Skip to content

Commit

Permalink
add a test for local polymorphic recursion
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek committed Sep 16, 2023
1 parent d1ddbb8 commit c259e21
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 22 deletions.
10 changes: 5 additions & 5 deletions core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,10 @@ final class SourceConverter(
}
}

private def apply(decl: Declaration, bound: Set[Bindable], topBound: Set[Bindable]): Result[Expr[Declaration]] = {
private def fromDecl(decl: Declaration, bound: Set[Bindable], topBound: Set[Bindable]): Result[Expr[Declaration]] = {
implicit val parAp = SourceConverter.parallelIor
def loop(decl: Declaration) = apply(decl, bound, topBound)
def withBound(decl: Declaration, newB: Iterable[Bindable]) = apply(decl, bound ++ newB, topBound)
def loop(decl: Declaration) = fromDecl(decl, bound, topBound)
def withBound(decl: Declaration, newB: Iterable[Bindable]) = fromDecl(decl, bound ++ newB, topBound)

decl match {
case Annotation(term, tpe) =>
Expand Down Expand Up @@ -1168,7 +1168,7 @@ final class SourceConverter(
stmt match {
case Right(Right((nm, decl))) =>

val r = apply(decl, Set.empty, topBound).map((nm, RecursionKind.NonRecursive, _) :: Nil)
val r = fromDecl(decl, Set.empty, topBound).map((nm, RecursionKind.NonRecursive, _) :: Nil)
// make sure all the free types are Generic
// we have to do this at the top level because in Declaration => Expr
// we allow closing over type variables defined at a higher level
Expand All @@ -1188,7 +1188,7 @@ final class SourceConverter(
d.region,
success(defstmt.result.get))(
{ (res: OptIndent[Declaration]) =>
apply(res.get, argGroups.flatten.iterator.flatMap(_.names).toSet + boundName, topBound1)
fromDecl(res.get, argGroups.flatten.iterator.flatMap(_.names).toSet + boundName, topBound1)
})

val r = lam.map { (l: Expr[Declaration]) =>
Expand Down
46 changes: 29 additions & 17 deletions core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -846,24 +846,36 @@ object Infer {
// After we typecheck we see if this is truly recursive so
// compilers/evaluation can possibly optimize non-recursive
// cases differently
newMetaType(Kind.Type) // the kind of a let value is a Type
.flatMap { rhsTpe =>
extendEnv(name, rhsTpe) {
for {
// the type variable needs to be unified with varT
// note, varT could be a sigma type, it is not a Tau or Rho
typedRhs <- inferSigmaMeta(rhs, Some((name, rhsTpe, region(rhs))))
varT = typedRhs.getType
// we need to overwrite the metavariable now with the full type
typedBody <- extendEnv(name, varT)(typeCheckRho(body, expect))
// TODO: a more efficient algorithm would do this top down
// for each top level TypedExpr and build it bottom up.
// we could do this after all typechecking is done
frees = TypedExpr.freeVars(typedRhs :: Nil)
isRecursive = RecursionKind.recursive(frees.contains(name))
} yield TypedExpr.Let(name, typedRhs, typedBody, isRecursive, tag)
rhs match {
case Annotation(expr, tpe, tag) =>
extendEnv(name, tpe) {
for {
typedRhs <- checkSigma(expr, tpe)
typedBody <- typeCheckRho(body, expect)
frees = TypedExpr.freeVars(typedRhs :: Nil)
isRecursive = RecursionKind.recursive(frees.contains(name))
} yield TypedExpr.Let(name, typedRhs, typedBody, isRecursive, tag)
}
case _ =>
newMetaType(Kind.Type) // the kind of a let value is a Type
.flatMap { rhsTpe =>
extendEnv(name, rhsTpe) {
for {
// the type variable needs to be unified with varT
// note, varT could be a sigma type, it is not a Tau or Rho
typedRhs <- inferSigmaMeta(rhs, Some((name, rhsTpe, region(rhs))))
varT = typedRhs.getType
// we need to overwrite the metavariable now with the full type
typedBody <- extendEnv(name, varT)(typeCheckRho(body, expect))
// TODO: a more efficient algorithm would do this top down
// for each top level TypedExpr and build it bottom up.
// we could do this after all typechecking is done
frees = TypedExpr.freeVars(typedRhs :: Nil)
isRecursive = RecursionKind.recursive(frees.contains(name))
} yield TypedExpr.Let(name, typedRhs, typedBody, isRecursive, tag)
}
}
}
}
}
else {
// In this branch, we typecheck the rhs *without* name in the environment
Expand Down
23 changes: 23 additions & 0 deletions core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2989,6 +2989,29 @@ def poly_rec(count: Nat, a: a) -> a:
(_, b) = poly_rec(prev, ("foo", a))
b
test = Assertion(True, "")
""")
, "PolyRec", 1)

runBosatsuTest(
List("""
package PolyRec
enum Nat: NZero, NSucc(n: Nat)
def call(a):
# TODO it's weird that removing the [a] breaks this
# if a type isn't mentioned in an outer scope, we should assume it's local
def poly_rec[a](count: Nat, a: a) -> a:
recur count:
case NZero: a
case NSucc(prev):
# make a call with a different type
(_, b) = poly_rec(prev, ("foo", a))
b
# call a polymorphic recursion internally to exercise different code paths
poly_rec(NZero, a)
test = Assertion(True, "")
""")
, "PolyRec", 1)
Expand Down

0 comments on commit c259e21

Please sign in to comment.