diff --git a/core/src/main/scala/org/bykn/bosatsu/Expr.scala b/core/src/main/scala/org/bykn/bosatsu/Expr.scala index 161f8ba8f..8a2ac6dd6 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Expr.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Expr.scala @@ -38,6 +38,11 @@ object Expr { case None => expr case Some(nel) => expr match { + case Annotation(expr, tpe, tag) => + val tpeFrees = Type.freeBoundTyVars(tpe :: Nil).toSet + // these are the frees that are also in tpeArgs + val freeArgs = tpeArgs.filter { case (n, _) => tpeFrees(n) } + Annotation(forAll(tpeArgs, expr), Type.forAll(freeArgs, tpe), tag) case Generic(typeVars, in) => Generic(nel ::: typeVars, in) case notAnn => Generic(nel, notAnn) diff --git a/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala b/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala index ace8a1fa1..c6afb9a29 100644 --- a/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala +++ b/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala @@ -131,7 +131,7 @@ final class SourceConverter( case Some(k) => k }) } - val gen = Expr.Generic(bs, lambda) + val gen = Expr.forAll(bs.toList, lambda) val freeVarsList = Expr.freeBoundTyVars(lambda) val freeVars = freeVarsList.toSet val notFreeDecl = bs.exists { case (a, _) => !freeVars(a) } @@ -172,10 +172,10 @@ final class SourceConverter( } } - private def apply(decl: Declaration, bound: Set[Bindable], topBound: Set[Bindable]): Result[Expr[Declaration]] = { + private def fromDecl(decl: Declaration, bound: Set[Bindable], topBound: Set[Bindable]): Result[Expr[Declaration]] = { implicit val parAp = SourceConverter.parallelIor - def loop(decl: Declaration) = apply(decl, bound, topBound) - def withBound(decl: Declaration, newB: Iterable[Bindable]) = apply(decl, bound ++ newB, topBound) + def loop(decl: Declaration) = fromDecl(decl, bound, topBound) + def withBound(decl: Declaration, newB: Iterable[Bindable]) = fromDecl(decl, bound ++ newB, topBound) decl match { case Annotation(term, tpe) => @@ -1168,7 +1168,7 @@ final class SourceConverter( stmt match { case Right(Right((nm, decl))) => - val r = apply(decl, Set.empty, topBound).map((nm, RecursionKind.NonRecursive, _) :: Nil) + val r = fromDecl(decl, Set.empty, topBound).map((nm, RecursionKind.NonRecursive, _) :: Nil) // make sure all the free types are Generic // we have to do this at the top level because in Declaration => Expr // we allow closing over type variables defined at a higher level @@ -1188,7 +1188,7 @@ final class SourceConverter( d.region, success(defstmt.result.get))( { (res: OptIndent[Declaration]) => - apply(res.get, argGroups.flatten.iterator.flatMap(_.names).toSet + boundName, topBound1) + fromDecl(res.get, argGroups.flatten.iterator.flatMap(_.names).toSet + boundName, topBound1) }) val r = lam.map { (l: Expr[Declaration]) => diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala index 61fcd030a..e3c55b2e5 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala @@ -846,24 +846,35 @@ object Infer { // After we typecheck we see if this is truly recursive so // compilers/evaluation can possibly optimize non-recursive // cases differently - newMetaType(Kind.Type) // the kind of a let value is a Type - .flatMap { rhsTpe => - extendEnv(name, rhsTpe) { - for { - // the type variable needs to be unified with varT - // note, varT could be a sigma type, it is not a Tau or Rho - typedRhs <- inferSigmaMeta(rhs, Some((name, rhsTpe, region(rhs)))) - varT = typedRhs.getType - // we need to overwrite the metavariable now with the full type - typedBody <- extendEnv(name, varT)(typeCheckRho(body, expect)) - // TODO: a more efficient algorithm would do this top down - // for each top level TypedExpr and build it bottom up. - // we could do this after all typechecking is done - frees = TypedExpr.freeVars(typedRhs :: Nil) - isRecursive = RecursionKind.recursive(frees.contains(name)) - } yield TypedExpr.Let(name, typedRhs, typedBody, isRecursive, tag) - } + val rhsBody = rhs match { + case Annotation(expr, tpe, tag) => + extendEnv(name, tpe) { + checkSigma(expr, tpe).product(typeCheckRho(body, expect)) + } + case _ => + newMetaType(Kind.Type) // the kind of a let value is a Type + .flatMap { rhsTpe => + extendEnv(name, rhsTpe) { + for { + // the type variable needs to be unified with varT + // note, varT could be a sigma type, it is not a Tau or Rho + typedRhs <- inferSigmaMeta(rhs, Some((name, rhsTpe, region(rhs)))) + varT = typedRhs.getType + // we need to overwrite the metavariable now with the full type + typedBody <- extendEnv(name, varT)(typeCheckRho(body, expect)) + } yield (typedRhs, typedBody) + } + } } + + rhsBody.map { case (rhs, body) => + // TODO: a more efficient algorithm would do this top down + // for each top level TypedExpr and build it bottom up. + // we could do this after all typechecking is done + val frees = TypedExpr.freeVars(rhs :: Nil) + val isRecursive = RecursionKind.recursive(frees.contains(name)) + TypedExpr.Let(name, rhs, body, isRecursive, tag) + } } else { // In this branch, we typecheck the rhs *without* name in the environment @@ -1335,8 +1346,13 @@ object Infer { private def recursiveTypeCheck[A: HasRegion](name: Bindable, expr: Expr[A]): Infer[TypedExpr[A]] = // values are of kind Type - newMetaType(Kind.Type).flatMap { tpe => - extendEnv(name, tpe)(typeCheckMeta(expr, Some((name, tpe, region(expr))))) + expr match { + case Expr.Annotation(e, tpe, _) => + extendEnv(name, tpe)(checkSigma(e, tpe)) + case _ => + newMetaType(Kind.Type).flatMap { tpe => + extendEnv(name, tpe)(typeCheckMeta(expr, Some((name, tpe, region(expr))))) + } } diff --git a/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala b/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala index ec3c668cc..f0f55754d 100644 --- a/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala @@ -2973,4 +2973,47 @@ def last[a](nel: NEList[a]) -> a: test = Assertion(last(One(True)), "") """), "Generic", 1) } + + test("support polymorphic recursion") { + runBosatsuTest( + List(""" +package PolyRec + +enum Nat: NZero, NSucc(n: Nat) + +def poly_rec(count: Nat, a: a) -> a: + recur count: + case NZero: a + case NSucc(prev): + # make a call with a different type + (_, b) = poly_rec(prev, ("foo", a)) + b + +test = Assertion(True, "") +""") + , "PolyRec", 1) + + runBosatsuTest( + List(""" +package PolyRec + +enum Nat: NZero, NSucc(n: Nat) + +def call(a): + # TODO it's weird that removing the [a] breaks this + # if a type isn't mentioned in an outer scope, we should assume it's local + def poly_rec[a](count: Nat, a: a) -> a: + recur count: + case NZero: a + case NSucc(prev): + # make a call with a different type + (_, b) = poly_rec(prev, ("foo", a)) + b + # call a polymorphic recursion internally to exercise different code paths + poly_rec(NZero, a) + +test = Assertion(True, "") +""") + , "PolyRec", 1) + } }