Skip to content

Commit

Permalink
Support recursion in continuations (#1045)
Browse files Browse the repository at this point in the history
* Support recursion in continuations

* support eta expansion checking

* Add repro of shadow bug in DefRecursionCheck

* fix rebinding loophole

* minor cleanups

* ignore coverage of unreachable error code

* minor code improvement
  • Loading branch information
johnynek authored Sep 17, 2023
1 parent 493c92c commit 7890b7b
Show file tree
Hide file tree
Showing 3 changed files with 299 additions and 30 deletions.
147 changes: 117 additions & 30 deletions core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ object DefRecursionCheck {
}
case class RecursionNotSubstructural(fnname: Bindable, recurPat: Pattern.Parsed, arg: Declaration.Var) extends RecursionError {
def region = arg.region
def message = s"recursion is ${fnname.sourceCodeRepr} not substructual"
def message = s"recursion in ${fnname.sourceCodeRepr} not substructual"
}
case class RecursiveDefNoRecur(defstmt: DefStatement[Pattern.Parsed, Declaration], recur: Declaration.Match) extends RecursionError {
def region = recur.region
Expand Down Expand Up @@ -107,33 +107,37 @@ object DefRecursionCheck {
* 3. we are checking the branches of the recur match
*/
sealed abstract class State {
final def outerDefNames: List[Bindable] =
final def outerDefNames: Set[Bindable] =
this match {
case TopLevel => Nil
case InDef(outer, n, _, _) => n :: outer.outerDefNames
case InDefRecurred(id, _, _, _, _) => id.outerDefNames
case InRecurBranch(ir, _) => ir.outerDefNames
case TopLevel => Set.empty
case ids: InDefState =>
val InDef(outer, n, _, _) = ids.inDef
outer.outerDefNames + n
}

@annotation.tailrec
final def defNamesContain(n: Bindable): Boolean =
this match {
case TopLevel => false
case InDef(outer, dn, _, _) => (dn == n) || outer.defNamesContain(n)
case InDefRecurred(id, _, _, _, _) => id.defNamesContain(n)
case InRecurBranch(ir, _) => ir.defNamesContain(n)
case ids: InDefState =>
val InDef(outer, dn, _, _) = ids.inDef
(dn == n) || outer.defNamesContain(n)
}

def inDef(fnname: Bindable, args: NonEmptyList[NonEmptyList[Pattern.Parsed]]): InDef =
InDef(this, fnname, args, Set.empty)
}
sealed abstract class InDefState extends State {
final def defname: Bindable =
final def inDef: InDef =
this match {
case InDef(_, defname, _, _) => defname
case InDefRecurred(ir, _, _, _, _) => ir.defname
case InRecurBranch(InDefRecurred(ir, _, _, _, _), _) => ir.defname
case id @ InDef(_, _, _, _) => id
case InDefRecurred(ir, _, _, _, _) => ir.inDef
case InRecurBranch(InDefRecurred(ir, _, _, _, _), _, _) => ir.inDef
}

final def defname: Bindable = inDef.fnname
}

case object TopLevel extends State
case class InDef(outer: State, fnname: Bindable, args: NonEmptyList[NonEmptyList[Pattern.Parsed]], localScope: Set[Bindable]) extends InDefState {

Expand All @@ -142,11 +146,45 @@ object DefRecursionCheck {

def setRecur(index: (Int, Int), m: Declaration.Match): InDefRecurred =
InDefRecurred(this, index._1, index._2, m, 0)

// This is eta-expansion of the function name as a lambda so we can check using the lambda rule
def asLambda(region: Region): Declaration.Lambda = {
val allNames = Iterator.iterate(0)(_ + 1).map { idx => Identifier.Name(s"a$idx") }.filterNot(_ == fnname)

val func = cats.Functor[NonEmptyList].compose[NonEmptyList]
// we allocate the names first. There is only one name inside: fnname
val argsB = func.map(args)(_ => allNames.next())

val argsV: NonEmptyList[NonEmptyList[Declaration.NonBinding]] =
func.map(argsB)(
n => Declaration.Var(n)(region)
)

val argsP: NonEmptyList[NonEmptyList[Pattern.Parsed]] =
func.map(argsB)(
n => Pattern.Var(n)
)

// fn == (x, y) -> z -> f(x, y)(z)
val body = argsV.toList.foldLeft(Declaration.Var(fnname)(region): Declaration.NonBinding) { (called, group) =>
Declaration.Apply(called, group, Declaration.ApplyKind.Parens)(region)
}

def lambdify(args: NonEmptyList[NonEmptyList[Pattern.Parsed]], body: Declaration): Declaration.Lambda = {
val body1 = args.tail match {
case Nil => body
case h :: tail => lambdify(NonEmptyList(h, tail), body)
}
Declaration.Lambda(args.head, body1)(region)
}

lambdify(argsP, body)
}
}
case class InDefRecurred(inRec: InDef, group: Int, index: Int, recur: Declaration.Match, recCount: Int) extends InDefState {
def incRecCount: InDefRecurred = copy(recCount = recCount + 1)
}
case class InRecurBranch(inRec: InDefRecurred, branch: Pattern.Parsed) extends InDefState {
case class InRecurBranch(inRec: InDefRecurred, branch: Pattern.Parsed, allowedNames: Set[Bindable]) extends InDefState {
def incRecCount: InRecurBranch = copy(inRec = inRec.incRecCount)
}

Expand Down Expand Up @@ -184,10 +222,10 @@ object DefRecursionCheck {
* Check that decl is a strict substructure of pat. We do this by making sure decl is a Var
* and that var is one of the strict substrutures of the pattern.
*/
def strictSubstructure(fnname: Bindable, pat: Pattern.Parsed, decl: Declaration): Res =
def allowedRecursion(fnname: Bindable, pat: Pattern.Parsed, names: Set[Bindable], decl: Declaration): Res =
decl match {
case v@Declaration.Var(nm) =>
if (pat.substructures.contains(nm)) unitValid
case v@Declaration.Var(nm: Bindable) =>
if (names.contains(nm)) unitValid
else Validated.invalidNel(RecursionNotSubstructural(fnname, pat, v))
case _ =>
// we can only recur with vars
Expand All @@ -204,15 +242,16 @@ object DefRecursionCheck {
def checkForIllegalBinds[A](
state: State,
bs: Iterable[Bindable],
decl: Declaration)(next: ValidatedNel[RecursionError, A]): ValidatedNel[RecursionError, A] =
state.outerDefNames match {
case Nil=> next
case nonEmpty =>
NonEmptyList.fromList(bs.filter(nonEmpty.toSet).toList.sorted) match {
decl: Declaration)(next: ValidatedNel[RecursionError, A]): ValidatedNel[RecursionError, A] = {
val outerSet = state.outerDefNames
if (outerSet.isEmpty) next
else {
NonEmptyList.fromList(bs.iterator.filter(outerSet).toList.sorted) match {
case Some(nel) =>
Validated.invalid(nel.map(IllegalShadow(_, decl)))
case None =>
next
}
}
}

Expand Down Expand Up @@ -256,12 +295,44 @@ object DefRecursionCheck {
argsOnDefName(fn1, args :: groups)
case _ => None
}

private def unionNames[A](newNames: Iterable[Bindable])(in: St[A]): St[A] =
getSt.flatMap {
case start @ InRecurBranch(inrec, branch, names) =>
(setSt(InRecurBranch(inrec, branch, names ++ newNames)) *> in, getSt)
.flatMapN {
case (a, InRecurBranch(ir1, b1, _)) =>
setSt(InRecurBranch(ir1, b1, names)).as(a)
// $COVERAGE-OFF$ this should be unreachable
case (_, unexpected) =>
sys.error(s"invariant violation expected InRecurBranch: start = $start, end = $unexpected")
}
case notRecur =>
sys.error(s"called setNames on $notRecur with names: $newNames")
// $COVERAGE-ON$ this should be unreachable
}

private def filterNames[A](newNames: Iterable[Bindable])(in: St[A]): St[A] =
getSt.flatMap {
case start @ InRecurBranch(inrec, branch, names) =>
(setSt(InRecurBranch(inrec, branch, names -- newNames)) *> in, getSt)
.flatMapN {
case (a, InRecurBranch(ir1, b1, _)) =>
setSt(InRecurBranch(ir1, b1, names)).as(a)
// $COVERAGE-OFF$ this should be unreachable
case (_, unexpected) =>
sys.error(s"invariant violation expected InRecurBranch: start = $start, end = $unexpected")
// $COVERAGE-ON$ this should be unreachable
}
case _ => in
}

def checkApply(fn: Declaration, args: NonEmptyList[Declaration], region: Region): St[Unit] =
getSt.flatMap {
case TopLevel =>
// without any recursion, normal typechecking will detect bad states:
checkDecl(fn) *> args.traverse_(checkDecl)
case irb@InRecurBranch(inrec, branch) =>
case irb@InRecurBranch(inrec, branch, names) =>

argsOnDefName(fn, NonEmptyList.one(args)) match {
case Some((nm, groups)) =>
Expand All @@ -273,14 +344,28 @@ object DefRecursionCheck {
// not enough args to check recursion
failSt(InvalidRecursion(nm, region))
case Some(arg) =>
toSt(strictSubstructure(irb.defname, branch, arg)) *>
toSt(allowedRecursion(irb.defname, branch, names, arg)) *>
setSt(irb.incRecCount) // we have recurred again
}
}
else if (irb.defNamesContain(nm)) {
failSt(InvalidRecursion(nm, region))
}
else if (names.contains(nm)) {
// we are calling a reachable function. Any lambda args are new names:
args.traverse_[St, Unit] {
case Declaration.Lambda(args, body) =>
val names1 = args.toList.flatMap(_.names)
unionNames(names1)(checkDecl(body))
case v@Declaration.Var(fn: Bindable) if irb.defname == fn =>
val Declaration.Lambda(args, body) = irb.inDef.asLambda(v.region)
val names1 = args.toList.flatMap(_.names)
unionNames(names1)(checkDecl(body))
case notLambda => checkDecl(notLambda)
}
}
else {
// traverse converting Var(name) to the lambda version to use the above check
// not a recursive call
args.traverse_(checkDecl)
}
Expand Down Expand Up @@ -311,7 +396,7 @@ object DefRecursionCheck {
case Binding(BindingStatement(pat, thisDecl, next)) =>
checkForIllegalBindsSt(pat.names, decl) *>
checkDecl(thisDecl) *>
checkDecl(next.padded)
filterNames(pat.names)(checkDecl(next.padded))
case Comment(cs) =>
checkDecl(cs.on.padded)
case CommentNB(cs) =>
Expand All @@ -335,21 +420,23 @@ object DefRecursionCheck {
checkDecl(t) *> checkDecl(c) *> checkDecl(f)
case Lambda(args, body) =>
// these args create new bindings:
checkForIllegalBindsSt(args.patternNames, decl) *> checkDecl(body)
val newBinds = args.patternNames
checkForIllegalBindsSt(newBinds, decl) *>
filterNames(newBinds)(checkDecl(body))
case Literal(_) =>
unitSt
case Match(RecursionKind.NonRecursive, arg, cases) =>
// the arg can't use state, but cases introduce new bindings:
val argRes = checkDecl(arg)
val optRes = cases.get.traverse_ { case (pat, next) =>
checkForIllegalBindsSt(pat.names, decl) *>
checkDecl(next.get)
filterNames(pat.names)(checkDecl(next.get))
}
argRes *> optRes
case recur@Match(RecursionKind.Recursive, _, cases) =>
// this is a state change
getSt.flatMap {
case TopLevel | InRecurBranch(_, _) | InDefRecurred(_, _, _, _, _) =>
case TopLevel | InRecurBranch(_, _, _) | InDefRecurred(_, _, _, _, _) =>
failSt(UnexpectedRecur(recur))
case InDef(_, defname, args, locals) =>
toSt(getRecurIndex(defname, args, recur, locals)).flatMap { idx =>
Expand All @@ -361,7 +448,7 @@ object DefRecursionCheck {
val rec = ir.setRecur(idx, recur)
setSt(rec) *> beginBranch(pat)
case irr@InDefRecurred(_, _, _, _, _) =>
setSt(InRecurBranch(irr, pat))
setSt(InRecurBranch(irr, pat, pat.substructures.toSet))
case illegal =>
// $COVERAGE-OFF$ this should be unreachable
sys.error(s"unreachable: $pat -> $illegal")
Expand All @@ -370,7 +457,7 @@ object DefRecursionCheck {

val endBranch: St[Unit] =
getSt.flatMap {
case InRecurBranch(irr, _) => setSt(irr)
case InRecurBranch(irr, _, _) => setSt(irr)
case illegal =>
// $COVERAGE-OFF$ this should be unreachable
sys.error(s"unreachable end state: $illegal")
Expand Down
115 changes: 115 additions & 0 deletions core/src/test/scala/org/bykn/bosatsu/DefRecursionCheckTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,121 @@ def len(lst):
recur lst:
case []: 0
case [_, *t]: id(len(t))
""")
}

test("tree example") {
allowed("""#
struct Tree(x: Int, items: List[Tree])
def sum_all(t):
recur t:
case []: 0
case [Tree(x, children), *tail]: x + sum_all(children) + sum_all(tail)
""")
}

test("we can recur on cont") {
allowed("""#
enum Cont:
Item(a: Int)
Next(use: (Cont -> Int) -> Int)
def loop(box: Cont) -> Int:
recur box:
case Item(a): a
case Next(cont_fn):
cont_fn(cont -> loop(cont))
""")

allowed("""#
enum Cont:
Item(a: Int)
Next(use: (Cont -> Int) -> Int)
def loop(box: Cont) -> Int:
recur box:
case Item(a): a
case Next(cont_fn):
cont_fn(loop)
""")
}

test("we can't trick the checker with a let shadow") {
disallowed("""#
struct Box(a)
def anything0[a, b](box: Box[a]) -> b:
recur box:
case Box(b):
# shadow to trick
b = Box(b)
anything0(b)
bottom: forall a. a = anything0(Box(1))
""")
}

test("we can't trick the checker with a match shadow") {
disallowed("""#
struct Box(a)
def anything0[a, b](box: Box[a]) -> b:
recur box:
case Box(b):
# shadow to trick
match Box(b):
b: anything0(b)
bottom: forall a. a = anything0(Box(1))
""")

disallowed("""#
struct Box(a)
def anything0[a, b](box: Box[a]) -> b:
recur box:
case Box(b):
# shadow to trick
recur Box(b):
b: anything0(b)
bottom: forall a. a = anything0(Box(1))
""")
}

test("we can't trick the checker with a lambda-let shadow") {
disallowed("""#
struct Box(a)
def anything0[a, b](box: Box[a]) -> b:
recur box:
case Box(b):
# shadow to trick
(b -> anything0(b))(Box(b))
bottom: forall a. a = anything0(Box(1))
""")
}

test("we can't trick the checker with a def-let shadow") {
disallowed("""#
struct Box(a)
def anything0[a, b](box: Box[a]) -> b:
recur box:
case Box(b):
# shadow to trick
def trick(b): anything0(b)
trick(Box(b))
bottom: forall a. a = anything0(Box(1))
""")
}
}
Loading

0 comments on commit 7890b7b

Please sign in to comment.