From 7890b7b3edaaeff4dcc6f4003825a878b136a67b Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Sun, 17 Sep 2023 11:42:39 -1000 Subject: [PATCH] Support recursion in continuations (#1045) * Support recursion in continuations * support eta expansion checking * Add repro of shadow bug in DefRecursionCheck * fix rebinding loophole * minor cleanups * ignore coverage of unreachable error code * minor code improvement --- .../org/bykn/bosatsu/DefRecursionCheck.scala | 147 ++++++++++++++---- .../bykn/bosatsu/DefRecursionCheckTest.scala | 115 ++++++++++++++ .../org/bykn/bosatsu/EvaluationTest.scala | 67 ++++++++ 3 files changed, 299 insertions(+), 30 deletions(-) diff --git a/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala b/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala index 62847032d..e27f0a639 100644 --- a/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala +++ b/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala @@ -66,7 +66,7 @@ object DefRecursionCheck { } case class RecursionNotSubstructural(fnname: Bindable, recurPat: Pattern.Parsed, arg: Declaration.Var) extends RecursionError { def region = arg.region - def message = s"recursion is ${fnname.sourceCodeRepr} not substructual" + def message = s"recursion in ${fnname.sourceCodeRepr} not substructual" } case class RecursiveDefNoRecur(defstmt: DefStatement[Pattern.Parsed, Declaration], recur: Declaration.Match) extends RecursionError { def region = recur.region @@ -107,33 +107,37 @@ object DefRecursionCheck { * 3. we are checking the branches of the recur match */ sealed abstract class State { - final def outerDefNames: List[Bindable] = + final def outerDefNames: Set[Bindable] = this match { - case TopLevel => Nil - case InDef(outer, n, _, _) => n :: outer.outerDefNames - case InDefRecurred(id, _, _, _, _) => id.outerDefNames - case InRecurBranch(ir, _) => ir.outerDefNames + case TopLevel => Set.empty + case ids: InDefState => + val InDef(outer, n, _, _) = ids.inDef + outer.outerDefNames + n } + @annotation.tailrec final def defNamesContain(n: Bindable): Boolean = this match { case TopLevel => false - case InDef(outer, dn, _, _) => (dn == n) || outer.defNamesContain(n) - case InDefRecurred(id, _, _, _, _) => id.defNamesContain(n) - case InRecurBranch(ir, _) => ir.defNamesContain(n) + case ids: InDefState => + val InDef(outer, dn, _, _) = ids.inDef + (dn == n) || outer.defNamesContain(n) } def inDef(fnname: Bindable, args: NonEmptyList[NonEmptyList[Pattern.Parsed]]): InDef = InDef(this, fnname, args, Set.empty) } sealed abstract class InDefState extends State { - final def defname: Bindable = + final def inDef: InDef = this match { - case InDef(_, defname, _, _) => defname - case InDefRecurred(ir, _, _, _, _) => ir.defname - case InRecurBranch(InDefRecurred(ir, _, _, _, _), _) => ir.defname + case id @ InDef(_, _, _, _) => id + case InDefRecurred(ir, _, _, _, _) => ir.inDef + case InRecurBranch(InDefRecurred(ir, _, _, _, _), _, _) => ir.inDef } + + final def defname: Bindable = inDef.fnname } + case object TopLevel extends State case class InDef(outer: State, fnname: Bindable, args: NonEmptyList[NonEmptyList[Pattern.Parsed]], localScope: Set[Bindable]) extends InDefState { @@ -142,11 +146,45 @@ object DefRecursionCheck { def setRecur(index: (Int, Int), m: Declaration.Match): InDefRecurred = InDefRecurred(this, index._1, index._2, m, 0) + + // This is eta-expansion of the function name as a lambda so we can check using the lambda rule + def asLambda(region: Region): Declaration.Lambda = { + val allNames = Iterator.iterate(0)(_ + 1).map { idx => Identifier.Name(s"a$idx") }.filterNot(_ == fnname) + + val func = cats.Functor[NonEmptyList].compose[NonEmptyList] + // we allocate the names first. There is only one name inside: fnname + val argsB = func.map(args)(_ => allNames.next()) + + val argsV: NonEmptyList[NonEmptyList[Declaration.NonBinding]] = + func.map(argsB)( + n => Declaration.Var(n)(region) + ) + + val argsP: NonEmptyList[NonEmptyList[Pattern.Parsed]] = + func.map(argsB)( + n => Pattern.Var(n) + ) + + // fn == (x, y) -> z -> f(x, y)(z) + val body = argsV.toList.foldLeft(Declaration.Var(fnname)(region): Declaration.NonBinding) { (called, group) => + Declaration.Apply(called, group, Declaration.ApplyKind.Parens)(region) + } + + def lambdify(args: NonEmptyList[NonEmptyList[Pattern.Parsed]], body: Declaration): Declaration.Lambda = { + val body1 = args.tail match { + case Nil => body + case h :: tail => lambdify(NonEmptyList(h, tail), body) + } + Declaration.Lambda(args.head, body1)(region) + } + + lambdify(argsP, body) + } } case class InDefRecurred(inRec: InDef, group: Int, index: Int, recur: Declaration.Match, recCount: Int) extends InDefState { def incRecCount: InDefRecurred = copy(recCount = recCount + 1) } - case class InRecurBranch(inRec: InDefRecurred, branch: Pattern.Parsed) extends InDefState { + case class InRecurBranch(inRec: InDefRecurred, branch: Pattern.Parsed, allowedNames: Set[Bindable]) extends InDefState { def incRecCount: InRecurBranch = copy(inRec = inRec.incRecCount) } @@ -184,10 +222,10 @@ object DefRecursionCheck { * Check that decl is a strict substructure of pat. We do this by making sure decl is a Var * and that var is one of the strict substrutures of the pattern. */ - def strictSubstructure(fnname: Bindable, pat: Pattern.Parsed, decl: Declaration): Res = + def allowedRecursion(fnname: Bindable, pat: Pattern.Parsed, names: Set[Bindable], decl: Declaration): Res = decl match { - case v@Declaration.Var(nm) => - if (pat.substructures.contains(nm)) unitValid + case v@Declaration.Var(nm: Bindable) => + if (names.contains(nm)) unitValid else Validated.invalidNel(RecursionNotSubstructural(fnname, pat, v)) case _ => // we can only recur with vars @@ -204,15 +242,16 @@ object DefRecursionCheck { def checkForIllegalBinds[A]( state: State, bs: Iterable[Bindable], - decl: Declaration)(next: ValidatedNel[RecursionError, A]): ValidatedNel[RecursionError, A] = - state.outerDefNames match { - case Nil=> next - case nonEmpty => - NonEmptyList.fromList(bs.filter(nonEmpty.toSet).toList.sorted) match { + decl: Declaration)(next: ValidatedNel[RecursionError, A]): ValidatedNel[RecursionError, A] = { + val outerSet = state.outerDefNames + if (outerSet.isEmpty) next + else { + NonEmptyList.fromList(bs.iterator.filter(outerSet).toList.sorted) match { case Some(nel) => Validated.invalid(nel.map(IllegalShadow(_, decl))) case None => next + } } } @@ -256,12 +295,44 @@ object DefRecursionCheck { argsOnDefName(fn1, args :: groups) case _ => None } + + private def unionNames[A](newNames: Iterable[Bindable])(in: St[A]): St[A] = + getSt.flatMap { + case start @ InRecurBranch(inrec, branch, names) => + (setSt(InRecurBranch(inrec, branch, names ++ newNames)) *> in, getSt) + .flatMapN { + case (a, InRecurBranch(ir1, b1, _)) => + setSt(InRecurBranch(ir1, b1, names)).as(a) + // $COVERAGE-OFF$ this should be unreachable + case (_, unexpected) => + sys.error(s"invariant violation expected InRecurBranch: start = $start, end = $unexpected") + } + case notRecur => + sys.error(s"called setNames on $notRecur with names: $newNames") + // $COVERAGE-ON$ this should be unreachable + } + + private def filterNames[A](newNames: Iterable[Bindable])(in: St[A]): St[A] = + getSt.flatMap { + case start @ InRecurBranch(inrec, branch, names) => + (setSt(InRecurBranch(inrec, branch, names -- newNames)) *> in, getSt) + .flatMapN { + case (a, InRecurBranch(ir1, b1, _)) => + setSt(InRecurBranch(ir1, b1, names)).as(a) + // $COVERAGE-OFF$ this should be unreachable + case (_, unexpected) => + sys.error(s"invariant violation expected InRecurBranch: start = $start, end = $unexpected") + // $COVERAGE-ON$ this should be unreachable + } + case _ => in + } + def checkApply(fn: Declaration, args: NonEmptyList[Declaration], region: Region): St[Unit] = getSt.flatMap { case TopLevel => // without any recursion, normal typechecking will detect bad states: checkDecl(fn) *> args.traverse_(checkDecl) - case irb@InRecurBranch(inrec, branch) => + case irb@InRecurBranch(inrec, branch, names) => argsOnDefName(fn, NonEmptyList.one(args)) match { case Some((nm, groups)) => @@ -273,14 +344,28 @@ object DefRecursionCheck { // not enough args to check recursion failSt(InvalidRecursion(nm, region)) case Some(arg) => - toSt(strictSubstructure(irb.defname, branch, arg)) *> + toSt(allowedRecursion(irb.defname, branch, names, arg)) *> setSt(irb.incRecCount) // we have recurred again } } else if (irb.defNamesContain(nm)) { failSt(InvalidRecursion(nm, region)) } + else if (names.contains(nm)) { + // we are calling a reachable function. Any lambda args are new names: + args.traverse_[St, Unit] { + case Declaration.Lambda(args, body) => + val names1 = args.toList.flatMap(_.names) + unionNames(names1)(checkDecl(body)) + case v@Declaration.Var(fn: Bindable) if irb.defname == fn => + val Declaration.Lambda(args, body) = irb.inDef.asLambda(v.region) + val names1 = args.toList.flatMap(_.names) + unionNames(names1)(checkDecl(body)) + case notLambda => checkDecl(notLambda) + } + } else { + // traverse converting Var(name) to the lambda version to use the above check // not a recursive call args.traverse_(checkDecl) } @@ -311,7 +396,7 @@ object DefRecursionCheck { case Binding(BindingStatement(pat, thisDecl, next)) => checkForIllegalBindsSt(pat.names, decl) *> checkDecl(thisDecl) *> - checkDecl(next.padded) + filterNames(pat.names)(checkDecl(next.padded)) case Comment(cs) => checkDecl(cs.on.padded) case CommentNB(cs) => @@ -335,7 +420,9 @@ object DefRecursionCheck { checkDecl(t) *> checkDecl(c) *> checkDecl(f) case Lambda(args, body) => // these args create new bindings: - checkForIllegalBindsSt(args.patternNames, decl) *> checkDecl(body) + val newBinds = args.patternNames + checkForIllegalBindsSt(newBinds, decl) *> + filterNames(newBinds)(checkDecl(body)) case Literal(_) => unitSt case Match(RecursionKind.NonRecursive, arg, cases) => @@ -343,13 +430,13 @@ object DefRecursionCheck { val argRes = checkDecl(arg) val optRes = cases.get.traverse_ { case (pat, next) => checkForIllegalBindsSt(pat.names, decl) *> - checkDecl(next.get) + filterNames(pat.names)(checkDecl(next.get)) } argRes *> optRes case recur@Match(RecursionKind.Recursive, _, cases) => // this is a state change getSt.flatMap { - case TopLevel | InRecurBranch(_, _) | InDefRecurred(_, _, _, _, _) => + case TopLevel | InRecurBranch(_, _, _) | InDefRecurred(_, _, _, _, _) => failSt(UnexpectedRecur(recur)) case InDef(_, defname, args, locals) => toSt(getRecurIndex(defname, args, recur, locals)).flatMap { idx => @@ -361,7 +448,7 @@ object DefRecursionCheck { val rec = ir.setRecur(idx, recur) setSt(rec) *> beginBranch(pat) case irr@InDefRecurred(_, _, _, _, _) => - setSt(InRecurBranch(irr, pat)) + setSt(InRecurBranch(irr, pat, pat.substructures.toSet)) case illegal => // $COVERAGE-OFF$ this should be unreachable sys.error(s"unreachable: $pat -> $illegal") @@ -370,7 +457,7 @@ object DefRecursionCheck { val endBranch: St[Unit] = getSt.flatMap { - case InRecurBranch(irr, _) => setSt(irr) + case InRecurBranch(irr, _, _) => setSt(irr) case illegal => // $COVERAGE-OFF$ this should be unreachable sys.error(s"unreachable end state: $illegal") diff --git a/core/src/test/scala/org/bykn/bosatsu/DefRecursionCheckTest.scala b/core/src/test/scala/org/bykn/bosatsu/DefRecursionCheckTest.scala index 29192f870..e6c4e6cba 100644 --- a/core/src/test/scala/org/bykn/bosatsu/DefRecursionCheckTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/DefRecursionCheckTest.scala @@ -407,6 +407,121 @@ def len(lst): recur lst: case []: 0 case [_, *t]: id(len(t)) +""") + } + + test("tree example") { + allowed("""# +struct Tree(x: Int, items: List[Tree]) + +def sum_all(t): + recur t: + case []: 0 + case [Tree(x, children), *tail]: x + sum_all(children) + sum_all(tail) +""") + } + + test("we can recur on cont") { + allowed("""# +enum Cont: + Item(a: Int) + Next(use: (Cont -> Int) -> Int) + +def loop(box: Cont) -> Int: + recur box: + case Item(a): a + case Next(cont_fn): + cont_fn(cont -> loop(cont)) + """) + + allowed("""# +enum Cont: + Item(a: Int) + Next(use: (Cont -> Int) -> Int) + +def loop(box: Cont) -> Int: + recur box: + case Item(a): a + case Next(cont_fn): + cont_fn(loop) + """) + } + + test("we can't trick the checker with a let shadow") { + disallowed("""# +struct Box(a) + +def anything0[a, b](box: Box[a]) -> b: + recur box: + case Box(b): + # shadow to trick + b = Box(b) + anything0(b) + +bottom: forall a. a = anything0(Box(1)) + +""") + } + + test("we can't trick the checker with a match shadow") { + disallowed("""# +struct Box(a) + +def anything0[a, b](box: Box[a]) -> b: + recur box: + case Box(b): + # shadow to trick + match Box(b): + b: anything0(b) + +bottom: forall a. a = anything0(Box(1)) + +""") + + disallowed("""# +struct Box(a) + +def anything0[a, b](box: Box[a]) -> b: + recur box: + case Box(b): + # shadow to trick + recur Box(b): + b: anything0(b) + +bottom: forall a. a = anything0(Box(1)) + +""") + } + + test("we can't trick the checker with a lambda-let shadow") { + disallowed("""# +struct Box(a) + +def anything0[a, b](box: Box[a]) -> b: + recur box: + case Box(b): + # shadow to trick + (b -> anything0(b))(Box(b)) + +bottom: forall a. a = anything0(Box(1)) + +""") + } + + test("we can't trick the checker with a def-let shadow") { + disallowed("""# +struct Box(a) + +def anything0[a, b](box: Box[a]) -> b: + recur box: + case Box(b): + # shadow to trick + def trick(b): anything0(b) + + trick(Box(b)) + +bottom: forall a. a = anything0(Box(1)) + """) } } diff --git a/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala b/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala index f0f55754d..a08e9156c 100644 --- a/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala @@ -3016,4 +3016,71 @@ test = Assertion(True, "") """) , "PolyRec", 1) } + + test("recursion on continuations") { + evalTest( + List(""" +package A +enum Cont: + Item(a: Int) + Next(use: (Cont -> Int) -> Int) + +def map(ca: Cont, fn: Int -> Int) -> Cont: + Next(cont -> fn(cont(ca))) + +b = Item(1).map(x -> x.add(1)) + +def loop(box: Cont) -> Int: + recur box: + case Item(a): a + case Next(cont_fn): + cont_fn(cont -> loop(cont)) + +v = loop(b) +main = v +"""), "A", VInt(2)) + + // Generic version + evalTest( + List(""" +package A +enum Cont[a: *]: + Item(a: a) + Next(use: (Cont[a] -> a) -> a) + +def map[a](ca: Cont[a], fn: a -> a) -> Cont[a]: + Next(cont -> fn(cont(ca))) + +def loop[a](box: Cont[a]) -> a: + recur box: + case Item(a): a + case Next(cont_fn): cont_fn(loop) + +loopgen: forall a. Cont[a] -> a = loop +b: Cont[Int] = Item(1).map(x -> x.add(1)) +main: Int = loop(b) +"""), "A", VInt(2)) + + // this example also exercises polymorphic recursion + evalTest( + List(""" +package A +enum Box[a: +*]: + Item(a: a) + Next(fn: forall res. (forall b. (Box[b], b -> a) -> res) -> res) + +def map[a, b](box: Box[a], fn: a -> b) -> Box[b]: + Next(cont -> cont(box, fn)) + +b = Item(1) + +def loop[a](box: Box[a]) -> a: + recur box: + case Item(a): a + case Next(cont): cont((box, fn) -> fn(loop(box))) + +v = loop(b) +main = v +"""), "A", VInt(1)) + } }