Skip to content

Commit

Permalink
Support recursion in continuations
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek committed Sep 16, 2023
1 parent 493c92c commit 09d31f7
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 12 deletions.
49 changes: 37 additions & 12 deletions core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import org.typelevel.paiges.Doc
import cats.implicits._

import Identifier.Bindable
import org.bykn.bosatsu.Declaration.Lambda

/**
* Recursion in bosatsu is only allowed on a substructural match
Expand Down Expand Up @@ -112,15 +113,15 @@ object DefRecursionCheck {
case TopLevel => Nil
case InDef(outer, n, _, _) => n :: outer.outerDefNames
case InDefRecurred(id, _, _, _, _) => id.outerDefNames
case InRecurBranch(ir, _) => ir.outerDefNames
case InRecurBranch(ir, _, _) => ir.outerDefNames
}

final def defNamesContain(n: Bindable): Boolean =
this match {
case TopLevel => false
case InDef(outer, dn, _, _) => (dn == n) || outer.defNamesContain(n)
case InDefRecurred(id, _, _, _, _) => id.defNamesContain(n)
case InRecurBranch(ir, _) => ir.defNamesContain(n)
case InRecurBranch(ir, _, _) => ir.defNamesContain(n)
}

def inDef(fnname: Bindable, args: NonEmptyList[NonEmptyList[Pattern.Parsed]]): InDef =
Expand All @@ -131,7 +132,7 @@ object DefRecursionCheck {
this match {
case InDef(_, defname, _, _) => defname
case InDefRecurred(ir, _, _, _, _) => ir.defname
case InRecurBranch(InDefRecurred(ir, _, _, _, _), _) => ir.defname
case InRecurBranch(InDefRecurred(ir, _, _, _, _), _, _) => ir.defname
}
}
case object TopLevel extends State
Expand All @@ -146,7 +147,7 @@ object DefRecursionCheck {
case class InDefRecurred(inRec: InDef, group: Int, index: Int, recur: Declaration.Match, recCount: Int) extends InDefState {
def incRecCount: InDefRecurred = copy(recCount = recCount + 1)
}
case class InRecurBranch(inRec: InDefRecurred, branch: Pattern.Parsed) extends InDefState {
case class InRecurBranch(inRec: InDefRecurred, branch: Pattern.Parsed, allowedNames: Set[Bindable]) extends InDefState {
def incRecCount: InRecurBranch = copy(inRec = inRec.incRecCount)
}

Expand Down Expand Up @@ -184,10 +185,10 @@ object DefRecursionCheck {
* Check that decl is a strict substructure of pat. We do this by making sure decl is a Var
* and that var is one of the strict substrutures of the pattern.
*/
def strictSubstructure(fnname: Bindable, pat: Pattern.Parsed, decl: Declaration): Res =
def allowedRecursion(fnname: Bindable, pat: Pattern.Parsed, names: Set[Bindable], decl: Declaration): Res =
decl match {
case v@Declaration.Var(nm) =>
if (pat.substructures.contains(nm)) unitValid
case v@Declaration.Var(nm: Bindable) =>
if (names.contains(nm)) unitValid
else Validated.invalidNel(RecursionNotSubstructural(fnname, pat, v))
case _ =>
// we can only recur with vars
Expand Down Expand Up @@ -256,12 +257,27 @@ object DefRecursionCheck {
argsOnDefName(fn1, args :: groups)
case _ => None
}

private def setNames[A](newNames: Set[Bindable])(in: St[A]): St[A] =
getSt.flatMap {
case start @ InRecurBranch(inrec, branch, names) =>
(setSt(InRecurBranch(inrec, branch, newNames)) *> in, getSt)
.flatMapN {
case (a, InRecurBranch(ir1, b1, _)) =>
setSt(InRecurBranch(ir1, b1, names)).as(a)
case (_, unexpected) =>
sys.error(s"invariant violation expected InRecurBranch: start = $start, end = $unexpected")
}
case notRecur =>
sys.error(s"called setNames on $notRecur with names: $newNames")
}

def checkApply(fn: Declaration, args: NonEmptyList[Declaration], region: Region): St[Unit] =
getSt.flatMap {
case TopLevel =>
// without any recursion, normal typechecking will detect bad states:
checkDecl(fn) *> args.traverse_(checkDecl)
case irb@InRecurBranch(inrec, branch) =>
case irb@InRecurBranch(inrec, branch, names) =>

argsOnDefName(fn, NonEmptyList.one(args)) match {
case Some((nm, groups)) =>
Expand All @@ -273,13 +289,22 @@ object DefRecursionCheck {
// not enough args to check recursion
failSt(InvalidRecursion(nm, region))
case Some(arg) =>
toSt(strictSubstructure(irb.defname, branch, arg)) *>
toSt(allowedRecursion(irb.defname, branch, names, arg)) *>
setSt(irb.incRecCount) // we have recurred again
}
}
else if (irb.defNamesContain(nm)) {
failSt(InvalidRecursion(nm, region))
}
else if (names.contains(nm)) {
// we are calling a reachable function. Any lambda args are new names:
args.traverse_[St, Unit] {
case Lambda(args, body) =>
val names1 = names ++ args.toList.iterator.flatMap(_.names)
setNames(names1)(checkDecl(body))
case notLambda => checkDecl(notLambda)
}
}
else {
// not a recursive call
args.traverse_(checkDecl)
Expand Down Expand Up @@ -349,7 +374,7 @@ object DefRecursionCheck {
case recur@Match(RecursionKind.Recursive, _, cases) =>
// this is a state change
getSt.flatMap {
case TopLevel | InRecurBranch(_, _) | InDefRecurred(_, _, _, _, _) =>
case TopLevel | InRecurBranch(_, _, _) | InDefRecurred(_, _, _, _, _) =>
failSt(UnexpectedRecur(recur))
case InDef(_, defname, args, locals) =>
toSt(getRecurIndex(defname, args, recur, locals)).flatMap { idx =>
Expand All @@ -361,7 +386,7 @@ object DefRecursionCheck {
val rec = ir.setRecur(idx, recur)
setSt(rec) *> beginBranch(pat)
case irr@InDefRecurred(_, _, _, _, _) =>
setSt(InRecurBranch(irr, pat))
setSt(InRecurBranch(irr, pat, pat.substructures.toSet))
case illegal =>
// $COVERAGE-OFF$ this should be unreachable
sys.error(s"unreachable: $pat -> $illegal")
Expand All @@ -370,7 +395,7 @@ object DefRecursionCheck {

val endBranch: St[Unit] =
getSt.flatMap {
case InRecurBranch(irr, _) => setSt(irr)
case InRecurBranch(irr, _, _) => setSt(irr)
case illegal =>
// $COVERAGE-OFF$ this should be unreachable
sys.error(s"unreachable end state: $illegal")
Expand Down
25 changes: 25 additions & 0 deletions core/src/test/scala/org/bykn/bosatsu/DefRecursionCheckTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -409,4 +409,29 @@ def len(lst):
case [_, *t]: id(len(t))
""")
}

test("tree example") {
allowed("""#
struct Tree(x: Int, items: List[Tree])
def sum_all(t):
recur t:
case []: 0
case [Tree(x, children), *tail]: x + sum_all(children) + sum_all(tail)
""")
}

test("we can recur on cont") {
allowed("""#
enum Cont:
Item(a: Int)
Next(use: (Cont -> Int) -> Int)
def loop(box: Cont) -> Int:
recur box:
case Item(a): a
case Next(cont_fn):
cont_fn(cont -> loop(cont))
""")
}
}
68 changes: 68 additions & 0 deletions core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3016,4 +3016,72 @@ test = Assertion(True, "")
""")
, "PolyRec", 1)
}

test("recursion on continuations") {
evalTest(
List("""
package A
enum Cont:
Item(a: Int)
Next(use: (Cont -> Int) -> Int)
def map(ca: Cont, fn: Int -> Int) -> Cont:
Next(cont -> fn(cont(ca)))
b = Item(1).map(x -> x.add(1))
def loop(box: Cont) -> Int:
recur box:
case Item(a): a
case Next(cont_fn):
cont_fn(cont -> loop(cont))
v = loop(b)
main = v
"""), "A", VInt(2))

// Generic version
evalTest(
List("""
package A
enum Cont[a: *]:
Item(a: a)
Next(use: (Cont[a] -> a) -> a)
def map[a](ca: Cont[a], fn: a -> a) -> Cont[a]:
Next(cont -> fn(cont(ca)))
def loop[a](box: Cont[a]) -> a:
recur box:
case Item(a): a
case Next(cont_fn):
cont_fn(cont -> loop(cont))
loopgen: forall a. Cont[a] -> a = loop
b: Cont[Int] = Item(1).map(x -> x.add(1))
main: Int = loop(b)
"""), "A", VInt(2))

// this example also exercises polymorphic recursion
evalTest(
List("""
package A
enum Box[a: +*]:
Item(a: a)
Next(fn: forall res. (forall b. (Box[b], b -> a) -> res) -> res)
def map[a, b](box: Box[a], fn: a -> b) -> Box[b]:
Next(cont -> cont(box, fn))
b = Item(1)
def loop[a](box: Box[a]) -> a:
recur box:
case Item(a): a
case Next(cont): cont((box, fn) -> fn(loop(box)))
v = loop(b)
main = v
"""), "A", VInt(1))
}
}

0 comments on commit 09d31f7

Please sign in to comment.