diff --git a/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala b/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala index 62847032d..c09c9d0d4 100644 --- a/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala +++ b/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala @@ -6,6 +6,7 @@ import org.typelevel.paiges.Doc import cats.implicits._ import Identifier.Bindable +import org.bykn.bosatsu.Declaration.Lambda /** * Recursion in bosatsu is only allowed on a substructural match @@ -112,7 +113,7 @@ object DefRecursionCheck { case TopLevel => Nil case InDef(outer, n, _, _) => n :: outer.outerDefNames case InDefRecurred(id, _, _, _, _) => id.outerDefNames - case InRecurBranch(ir, _) => ir.outerDefNames + case InRecurBranch(ir, _, _) => ir.outerDefNames } final def defNamesContain(n: Bindable): Boolean = @@ -120,7 +121,7 @@ object DefRecursionCheck { case TopLevel => false case InDef(outer, dn, _, _) => (dn == n) || outer.defNamesContain(n) case InDefRecurred(id, _, _, _, _) => id.defNamesContain(n) - case InRecurBranch(ir, _) => ir.defNamesContain(n) + case InRecurBranch(ir, _, _) => ir.defNamesContain(n) } def inDef(fnname: Bindable, args: NonEmptyList[NonEmptyList[Pattern.Parsed]]): InDef = @@ -131,7 +132,7 @@ object DefRecursionCheck { this match { case InDef(_, defname, _, _) => defname case InDefRecurred(ir, _, _, _, _) => ir.defname - case InRecurBranch(InDefRecurred(ir, _, _, _, _), _) => ir.defname + case InRecurBranch(InDefRecurred(ir, _, _, _, _), _, _) => ir.defname } } case object TopLevel extends State @@ -146,7 +147,7 @@ object DefRecursionCheck { case class InDefRecurred(inRec: InDef, group: Int, index: Int, recur: Declaration.Match, recCount: Int) extends InDefState { def incRecCount: InDefRecurred = copy(recCount = recCount + 1) } - case class InRecurBranch(inRec: InDefRecurred, branch: Pattern.Parsed) extends InDefState { + case class InRecurBranch(inRec: InDefRecurred, branch: Pattern.Parsed, allowedNames: Set[Bindable]) extends InDefState { def incRecCount: InRecurBranch = copy(inRec = inRec.incRecCount) } @@ -184,10 +185,10 @@ object DefRecursionCheck { * Check that decl is a strict substructure of pat. We do this by making sure decl is a Var * and that var is one of the strict substrutures of the pattern. */ - def strictSubstructure(fnname: Bindable, pat: Pattern.Parsed, decl: Declaration): Res = + def allowedRecursion(fnname: Bindable, pat: Pattern.Parsed, names: Set[Bindable], decl: Declaration): Res = decl match { - case v@Declaration.Var(nm) => - if (pat.substructures.contains(nm)) unitValid + case v@Declaration.Var(nm: Bindable) => + if (names.contains(nm)) unitValid else Validated.invalidNel(RecursionNotSubstructural(fnname, pat, v)) case _ => // we can only recur with vars @@ -256,12 +257,27 @@ object DefRecursionCheck { argsOnDefName(fn1, args :: groups) case _ => None } + + private def setNames[A](newNames: Set[Bindable])(in: St[A]): St[A] = + getSt.flatMap { + case start @ InRecurBranch(inrec, branch, names) => + (setSt(InRecurBranch(inrec, branch, newNames)) *> in, getSt) + .flatMapN { + case (a, InRecurBranch(ir1, b1, _)) => + setSt(InRecurBranch(ir1, b1, names)).as(a) + case (_, unexpected) => + sys.error(s"invariant violation expected InRecurBranch: start = $start, end = $unexpected") + } + case notRecur => + sys.error(s"called setNames on $notRecur with names: $newNames") + } + def checkApply(fn: Declaration, args: NonEmptyList[Declaration], region: Region): St[Unit] = getSt.flatMap { case TopLevel => // without any recursion, normal typechecking will detect bad states: checkDecl(fn) *> args.traverse_(checkDecl) - case irb@InRecurBranch(inrec, branch) => + case irb@InRecurBranch(inrec, branch, names) => argsOnDefName(fn, NonEmptyList.one(args)) match { case Some((nm, groups)) => @@ -273,13 +289,22 @@ object DefRecursionCheck { // not enough args to check recursion failSt(InvalidRecursion(nm, region)) case Some(arg) => - toSt(strictSubstructure(irb.defname, branch, arg)) *> + toSt(allowedRecursion(irb.defname, branch, names, arg)) *> setSt(irb.incRecCount) // we have recurred again } } else if (irb.defNamesContain(nm)) { failSt(InvalidRecursion(nm, region)) } + else if (names.contains(nm)) { + // we are calling a reachable function. Any lambda args are new names: + args.traverse_[St, Unit] { + case Lambda(args, body) => + val names1 = names ++ args.toList.iterator.flatMap(_.names) + setNames(names1)(checkDecl(body)) + case notLambda => checkDecl(notLambda) + } + } else { // not a recursive call args.traverse_(checkDecl) @@ -349,7 +374,7 @@ object DefRecursionCheck { case recur@Match(RecursionKind.Recursive, _, cases) => // this is a state change getSt.flatMap { - case TopLevel | InRecurBranch(_, _) | InDefRecurred(_, _, _, _, _) => + case TopLevel | InRecurBranch(_, _, _) | InDefRecurred(_, _, _, _, _) => failSt(UnexpectedRecur(recur)) case InDef(_, defname, args, locals) => toSt(getRecurIndex(defname, args, recur, locals)).flatMap { idx => @@ -361,7 +386,7 @@ object DefRecursionCheck { val rec = ir.setRecur(idx, recur) setSt(rec) *> beginBranch(pat) case irr@InDefRecurred(_, _, _, _, _) => - setSt(InRecurBranch(irr, pat)) + setSt(InRecurBranch(irr, pat, pat.substructures.toSet)) case illegal => // $COVERAGE-OFF$ this should be unreachable sys.error(s"unreachable: $pat -> $illegal") @@ -370,7 +395,7 @@ object DefRecursionCheck { val endBranch: St[Unit] = getSt.flatMap { - case InRecurBranch(irr, _) => setSt(irr) + case InRecurBranch(irr, _, _) => setSt(irr) case illegal => // $COVERAGE-OFF$ this should be unreachable sys.error(s"unreachable end state: $illegal") diff --git a/core/src/test/scala/org/bykn/bosatsu/DefRecursionCheckTest.scala b/core/src/test/scala/org/bykn/bosatsu/DefRecursionCheckTest.scala index 29192f870..09eeef1df 100644 --- a/core/src/test/scala/org/bykn/bosatsu/DefRecursionCheckTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/DefRecursionCheckTest.scala @@ -409,4 +409,29 @@ def len(lst): case [_, *t]: id(len(t)) """) } + + test("tree example") { + allowed("""# +struct Tree(x: Int, items: List[Tree]) + +def sum_all(t): + recur t: + case []: 0 + case [Tree(x, children), *tail]: x + sum_all(children) + sum_all(tail) +""") + } + + test("we can recur on cont") { + allowed("""# +enum Cont: + Item(a: Int) + Next(use: (Cont -> Int) -> Int) + +def loop(box: Cont) -> Int: + recur box: + case Item(a): a + case Next(cont_fn): + cont_fn(cont -> loop(cont)) + """) + } } diff --git a/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala b/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala index f0f55754d..6eb32fbdb 100644 --- a/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala @@ -3016,4 +3016,72 @@ test = Assertion(True, "") """) , "PolyRec", 1) } + + test("recursion on continuations") { + evalTest( + List(""" +package A +enum Cont: + Item(a: Int) + Next(use: (Cont -> Int) -> Int) + +def map(ca: Cont, fn: Int -> Int) -> Cont: + Next(cont -> fn(cont(ca))) + +b = Item(1).map(x -> x.add(1)) + +def loop(box: Cont) -> Int: + recur box: + case Item(a): a + case Next(cont_fn): + cont_fn(cont -> loop(cont)) + +v = loop(b) +main = v +"""), "A", VInt(2)) + + // Generic version + evalTest( + List(""" +package A +enum Cont[a: *]: + Item(a: a) + Next(use: (Cont[a] -> a) -> a) + +def map[a](ca: Cont[a], fn: a -> a) -> Cont[a]: + Next(cont -> fn(cont(ca))) + +def loop[a](box: Cont[a]) -> a: + recur box: + case Item(a): a + case Next(cont_fn): + cont_fn(cont -> loop(cont)) + +loopgen: forall a. Cont[a] -> a = loop +b: Cont[Int] = Item(1).map(x -> x.add(1)) +main: Int = loop(b) +"""), "A", VInt(2)) + + // this example also exercises polymorphic recursion + evalTest( + List(""" +package A +enum Box[a: +*]: + Item(a: a) + Next(fn: forall res. (forall b. (Box[b], b -> a) -> res) -> res) + +def map[a, b](box: Box[a], fn: a -> b) -> Box[b]: + Next(cont -> cont(box, fn)) + +b = Item(1) + +def loop[a](box: Box[a]) -> a: + recur box: + case Item(a): a + case Next(cont): cont((box, fn) -> fn(loop(box))) + +v = loop(b) +main = v +"""), "A", VInt(1)) + } }