From d1ddbb8441379c14348474bfe9225d801f58b35d Mon Sep 17 00:00:00 2001 From: Patrick Oscar Boykin Date: Sun, 10 Sep 2023 10:50:08 -1000 Subject: [PATCH 1/3] Support polymorphic recursion --- .../main/scala/org/bykn/bosatsu/Expr.scala | 5 +++++ .../org/bykn/bosatsu/SourceConverter.scala | 2 +- .../scala/org/bykn/bosatsu/rankn/Infer.scala | 9 +++++++-- .../org/bykn/bosatsu/EvaluationTest.scala | 20 +++++++++++++++++++ 4 files changed, 33 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/bykn/bosatsu/Expr.scala b/core/src/main/scala/org/bykn/bosatsu/Expr.scala index 161f8ba8f..8a2ac6dd6 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Expr.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Expr.scala @@ -38,6 +38,11 @@ object Expr { case None => expr case Some(nel) => expr match { + case Annotation(expr, tpe, tag) => + val tpeFrees = Type.freeBoundTyVars(tpe :: Nil).toSet + // these are the frees that are also in tpeArgs + val freeArgs = tpeArgs.filter { case (n, _) => tpeFrees(n) } + Annotation(forAll(tpeArgs, expr), Type.forAll(freeArgs, tpe), tag) case Generic(typeVars, in) => Generic(nel ::: typeVars, in) case notAnn => Generic(nel, notAnn) diff --git a/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala b/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala index ace8a1fa1..980a722fb 100644 --- a/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala +++ b/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala @@ -131,7 +131,7 @@ final class SourceConverter( case Some(k) => k }) } - val gen = Expr.Generic(bs, lambda) + val gen = Expr.forAll(bs.toList, lambda) val freeVarsList = Expr.freeBoundTyVars(lambda) val freeVars = freeVarsList.toSet val notFreeDecl = bs.exists { case (a, _) => !freeVars(a) } diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala index 61fcd030a..c6ef681bf 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala @@ -1335,8 +1335,13 @@ object Infer { private def recursiveTypeCheck[A: HasRegion](name: Bindable, expr: Expr[A]): Infer[TypedExpr[A]] = // values are of kind Type - newMetaType(Kind.Type).flatMap { tpe => - extendEnv(name, tpe)(typeCheckMeta(expr, Some((name, tpe, region(expr))))) + expr match { + case Expr.Annotation(e, tpe, _) => + extendEnv(name, tpe)(checkSigma(e, tpe)) + case _ => + newMetaType(Kind.Type).flatMap { tpe => + extendEnv(name, tpe)(typeCheckMeta(expr, Some((name, tpe, region(expr))))) + } } diff --git a/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala b/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala index ec3c668cc..90dbf8941 100644 --- a/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala @@ -2973,4 +2973,24 @@ def last[a](nel: NEList[a]) -> a: test = Assertion(last(One(True)), "") """), "Generic", 1) } + + test("support polymorphic recursion") { + runBosatsuTest( + List(""" +package PolyRec + +enum Nat: NZero, NSucc(n: Nat) + +def poly_rec(count: Nat, a: a) -> a: + recur count: + case NZero: a + case NSucc(prev): + # make a call with a different type + (_, b) = poly_rec(prev, ("foo", a)) + b + +test = Assertion(True, "") +""") + , "PolyRec", 1) + } } From c259e21b641dbed942b92cd9ee6288592121982c Mon Sep 17 00:00:00 2001 From: Patrick Oscar Boykin Date: Sat, 16 Sep 2023 07:41:59 -1000 Subject: [PATCH 2/3] add a test for local polymorphic recursion --- .../org/bykn/bosatsu/SourceConverter.scala | 10 ++-- .../scala/org/bykn/bosatsu/rankn/Infer.scala | 46 ++++++++++++------- .../org/bykn/bosatsu/EvaluationTest.scala | 23 ++++++++++ 3 files changed, 57 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala b/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala index 980a722fb..c6afb9a29 100644 --- a/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala +++ b/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala @@ -172,10 +172,10 @@ final class SourceConverter( } } - private def apply(decl: Declaration, bound: Set[Bindable], topBound: Set[Bindable]): Result[Expr[Declaration]] = { + private def fromDecl(decl: Declaration, bound: Set[Bindable], topBound: Set[Bindable]): Result[Expr[Declaration]] = { implicit val parAp = SourceConverter.parallelIor - def loop(decl: Declaration) = apply(decl, bound, topBound) - def withBound(decl: Declaration, newB: Iterable[Bindable]) = apply(decl, bound ++ newB, topBound) + def loop(decl: Declaration) = fromDecl(decl, bound, topBound) + def withBound(decl: Declaration, newB: Iterable[Bindable]) = fromDecl(decl, bound ++ newB, topBound) decl match { case Annotation(term, tpe) => @@ -1168,7 +1168,7 @@ final class SourceConverter( stmt match { case Right(Right((nm, decl))) => - val r = apply(decl, Set.empty, topBound).map((nm, RecursionKind.NonRecursive, _) :: Nil) + val r = fromDecl(decl, Set.empty, topBound).map((nm, RecursionKind.NonRecursive, _) :: Nil) // make sure all the free types are Generic // we have to do this at the top level because in Declaration => Expr // we allow closing over type variables defined at a higher level @@ -1188,7 +1188,7 @@ final class SourceConverter( d.region, success(defstmt.result.get))( { (res: OptIndent[Declaration]) => - apply(res.get, argGroups.flatten.iterator.flatMap(_.names).toSet + boundName, topBound1) + fromDecl(res.get, argGroups.flatten.iterator.flatMap(_.names).toSet + boundName, topBound1) }) val r = lam.map { (l: Expr[Declaration]) => diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala index c6ef681bf..653d1e8b3 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala @@ -846,24 +846,36 @@ object Infer { // After we typecheck we see if this is truly recursive so // compilers/evaluation can possibly optimize non-recursive // cases differently - newMetaType(Kind.Type) // the kind of a let value is a Type - .flatMap { rhsTpe => - extendEnv(name, rhsTpe) { - for { - // the type variable needs to be unified with varT - // note, varT could be a sigma type, it is not a Tau or Rho - typedRhs <- inferSigmaMeta(rhs, Some((name, rhsTpe, region(rhs)))) - varT = typedRhs.getType - // we need to overwrite the metavariable now with the full type - typedBody <- extendEnv(name, varT)(typeCheckRho(body, expect)) - // TODO: a more efficient algorithm would do this top down - // for each top level TypedExpr and build it bottom up. - // we could do this after all typechecking is done - frees = TypedExpr.freeVars(typedRhs :: Nil) - isRecursive = RecursionKind.recursive(frees.contains(name)) - } yield TypedExpr.Let(name, typedRhs, typedBody, isRecursive, tag) + rhs match { + case Annotation(expr, tpe, tag) => + extendEnv(name, tpe) { + for { + typedRhs <- checkSigma(expr, tpe) + typedBody <- typeCheckRho(body, expect) + frees = TypedExpr.freeVars(typedRhs :: Nil) + isRecursive = RecursionKind.recursive(frees.contains(name)) + } yield TypedExpr.Let(name, typedRhs, typedBody, isRecursive, tag) + } + case _ => + newMetaType(Kind.Type) // the kind of a let value is a Type + .flatMap { rhsTpe => + extendEnv(name, rhsTpe) { + for { + // the type variable needs to be unified with varT + // note, varT could be a sigma type, it is not a Tau or Rho + typedRhs <- inferSigmaMeta(rhs, Some((name, rhsTpe, region(rhs)))) + varT = typedRhs.getType + // we need to overwrite the metavariable now with the full type + typedBody <- extendEnv(name, varT)(typeCheckRho(body, expect)) + // TODO: a more efficient algorithm would do this top down + // for each top level TypedExpr and build it bottom up. + // we could do this after all typechecking is done + frees = TypedExpr.freeVars(typedRhs :: Nil) + isRecursive = RecursionKind.recursive(frees.contains(name)) + } yield TypedExpr.Let(name, typedRhs, typedBody, isRecursive, tag) + } } - } + } } else { // In this branch, we typecheck the rhs *without* name in the environment diff --git a/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala b/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala index 90dbf8941..f0f55754d 100644 --- a/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala @@ -2989,6 +2989,29 @@ def poly_rec(count: Nat, a: a) -> a: (_, b) = poly_rec(prev, ("foo", a)) b +test = Assertion(True, "") +""") + , "PolyRec", 1) + + runBosatsuTest( + List(""" +package PolyRec + +enum Nat: NZero, NSucc(n: Nat) + +def call(a): + # TODO it's weird that removing the [a] breaks this + # if a type isn't mentioned in an outer scope, we should assume it's local + def poly_rec[a](count: Nat, a: a) -> a: + recur count: + case NZero: a + case NSucc(prev): + # make a call with a different type + (_, b) = poly_rec(prev, ("foo", a)) + b + # call a polymorphic recursion internally to exercise different code paths + poly_rec(NZero, a) + test = Assertion(True, "") """) , "PolyRec", 1) From d43c6b73901a71862662da367a342a7954953f79 Mon Sep 17 00:00:00 2001 From: Patrick Oscar Boykin Date: Sat, 16 Sep 2023 08:01:47 -1000 Subject: [PATCH 3/3] simplify code a bit --- .../scala/org/bykn/bosatsu/rankn/Infer.scala | 31 +++++++++---------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala index 653d1e8b3..e3c55b2e5 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala @@ -846,15 +846,10 @@ object Infer { // After we typecheck we see if this is truly recursive so // compilers/evaluation can possibly optimize non-recursive // cases differently - rhs match { + val rhsBody = rhs match { case Annotation(expr, tpe, tag) => extendEnv(name, tpe) { - for { - typedRhs <- checkSigma(expr, tpe) - typedBody <- typeCheckRho(body, expect) - frees = TypedExpr.freeVars(typedRhs :: Nil) - isRecursive = RecursionKind.recursive(frees.contains(name)) - } yield TypedExpr.Let(name, typedRhs, typedBody, isRecursive, tag) + checkSigma(expr, tpe).product(typeCheckRho(body, expect)) } case _ => newMetaType(Kind.Type) // the kind of a let value is a Type @@ -865,16 +860,20 @@ object Infer { // note, varT could be a sigma type, it is not a Tau or Rho typedRhs <- inferSigmaMeta(rhs, Some((name, rhsTpe, region(rhs)))) varT = typedRhs.getType - // we need to overwrite the metavariable now with the full type - typedBody <- extendEnv(name, varT)(typeCheckRho(body, expect)) - // TODO: a more efficient algorithm would do this top down - // for each top level TypedExpr and build it bottom up. - // we could do this after all typechecking is done - frees = TypedExpr.freeVars(typedRhs :: Nil) - isRecursive = RecursionKind.recursive(frees.contains(name)) - } yield TypedExpr.Let(name, typedRhs, typedBody, isRecursive, tag) + // we need to overwrite the metavariable now with the full type + typedBody <- extendEnv(name, varT)(typeCheckRho(body, expect)) + } yield (typedRhs, typedBody) + } } - } + } + + rhsBody.map { case (rhs, body) => + // TODO: a more efficient algorithm would do this top down + // for each top level TypedExpr and build it bottom up. + // we could do this after all typechecking is done + val frees = TypedExpr.freeVars(rhs :: Nil) + val isRecursive = RecursionKind.recursive(frees.contains(name)) + TypedExpr.Let(name, rhs, body, isRecursive, tag) } } else {