From 30b3af9423d9ed8d151b05653120d35f0c86a5a1 Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Wed, 13 Nov 2024 10:10:18 -1000 Subject: [PATCH] Simplify Matchless.Let around recursion (#1254) --- .../scala/org/bykn/bosatsu/Matchless.scala | 42 +++++---- .../org/bykn/bosatsu/MatchlessToValue.scala | 89 +++++++++---------- .../bosatsu/codegen/python/PythonGen.scala | 37 +++----- .../org/bykn/bosatsu/MatchlessTests.scala | 27 +++--- 4 files changed, 89 insertions(+), 106 deletions(-) diff --git a/core/src/main/scala/org/bykn/bosatsu/Matchless.scala b/core/src/main/scala/org/bykn/bosatsu/Matchless.scala index 2e30c55fc..4c06e1ad5 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Matchless.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Matchless.scala @@ -13,7 +13,11 @@ object Matchless { // these hold bindings either in the code, or temporary // local ones, note CheapExpr never trigger a side effect sealed trait CheapExpr extends Expr - sealed abstract class FnExpr extends Expr + sealed abstract class FnExpr extends Expr { + def captures: List[Expr] + // this is set if the function is recursive + def recursiveName: Option[Bindable] + } sealed abstract class StrPart object StrPart { @@ -78,7 +82,7 @@ object Matchless { // name is set for recursive (but not tail recursive) methods case class Lambda( captures: List[Expr], - name: Option[Bindable], + recursiveName: Option[Bindable], args: NonEmptyList[Bindable], expr: Expr ) extends FnExpr @@ -92,7 +96,9 @@ object Matchless { name: Bindable, arg: NonEmptyList[Bindable], body: Expr - ) extends FnExpr + ) extends FnExpr { + val recursiveName: Option[Bindable] = Some(name) + } case class Global(pack: PackageName, name: Bindable) extends CheapExpr @@ -108,7 +114,7 @@ object Matchless { // note fn is never an App case class App(fn: Expr, arg: NonEmptyList[Expr]) extends Expr case class Let( - arg: Either[LocalAnon, (Bindable, RecursionKind)], + arg: Either[LocalAnon, Bindable], expr: Expr, in: Expr ) extends Expr @@ -324,12 +330,19 @@ object Matchless { e: TypedExpr[A], rec: RecursionKind, slots: LambdaState - ): F[Expr] = { - lazy val e0 = loop(e, if (rec.isRecursive) slots.inLet(name) else slots) + ): F[Expr] = rec match { case RecursionKind.Recursive => - def letrec(e: Expr): Expr = - Let(Right((name, RecursionKind.Recursive)), e, Local(name)) + lazy val e0 = loop(e, slots.inLet(name)) + def letrec(expr: Expr): Expr = + expr match { + case fn: FnExpr if fn.recursiveName == Some(name) => fn + case fn: FnExpr => + // loops always have a function name + sys.error(s"expected ${fn.recursiveName} == Some($name) in ${e.repr.render(80)} which compiled to $fn") + case _ => + sys.error(s"expected ${e.repr.render(80)} to compile to a function, but got: $expr") + } // this could be tail recursive if (SelfCallKind(name, e) == SelfCallKind.TailCall) { @@ -345,20 +358,20 @@ object Matchless { val (slots1, caps) = slots.inLet(name).lambdaFrees(frees) loop(body, slots1) .map { v => - letrec(LoopFn(caps, name, args, v)) + LoopFn(caps, name, args, v) } + // $COVERAGE-OFF$ case _ => // TODO: I don't think this case should ever happen in real code // but it definitely does in fuzz tests e0.map(letrec) + // $COVERAGE-ON$ } } else { - // otherwise let rec x = fn in x e0.map(letrec) } - case RecursionKind.NonRecursive => e0 + case RecursionKind.NonRecursive => loop(e, slots) } - } def loop(te: TypedExpr[A], slots: LambdaState): F[Expr] = te match { @@ -396,7 +409,7 @@ object Matchless { .mapN(App(_, _)) case TypedExpr.Let(a, e, in, r, _) => (loopLetVal(a, e, r, slots.unname), loop(in, slots)) - .mapN(Let(Right((a, r)), _, _)) + .mapN(Let(Right(a), _, _)) case TypedExpr.Literal(lit, _, _) => Monad[F].pure(Literal(lit)) case TypedExpr.Match(arg, branches, _) => ( @@ -766,8 +779,7 @@ object Matchless { def lets(binds: List[(Bindable, Expr)], in: Expr): Expr = binds.foldRight(in) { case ((b, e), r) => - val arg = Right((b, RecursionKind.NonRecursive)) - Let(arg, e, r) + Let(Right(b), e, r) } def checkLets(binds: List[LocalAnonMut], in: Expr): Expr = diff --git a/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala b/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala index 8ed32afc5..f2a93e3ca 100644 --- a/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala +++ b/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala @@ -94,12 +94,9 @@ object MatchlessToValue { () } - def capture(it: Vector[Value], name: Option[Bindable]): Scope = + def capture(it: Vector[Value]): Scope = Scope( - name match { - case None => Map.empty - case Some(n) => Map((n, locals(n))) - }, + Map.empty, LongMap.empty, MLongMap(), it @@ -353,7 +350,7 @@ object MatchlessToValue { // It may make things go faster // if the caps are really small // or if we can GC things sooner. - val scope1 = scope.capture(caps.map(s => s(scope)), Some(fnName)) + val scope1 = scope.capture(caps.map(s => s(scope))) FnValue { allArgs => var registers: NonEmptyList[Value] = allArgs @@ -393,30 +390,45 @@ object MatchlessToValue { // the locals can be recusive, so we box into Eval for laziness def loop(me: Expr): Scoped[Value] = me match { - case Lambda(caps, name, args, res) => + case Lambda(Nil, None, args, res) => val resFn = loop(res) - - if (caps.isEmpty && name.isEmpty) { - // we can allocate once if there is no closure - val scope1 = Scope.empty() - val fn = FnValue { argV => + // we can allocate once if there is no closure + val scope1 = Scope.empty() + val fn = FnValue { argV => + val scope2 = scope1.letAll(args, argV) + resFn(scope2) + } + Static(fn) + case Lambda(caps, None, args, res) => + val resFn = loop(res) + val capScoped = caps.map(loop).toVector + Dynamic { scope => + val scope1 = scope + .capture(capScoped.map(scoped => scoped(scope))) + + // hopefully optimization/normalization has lifted anything + // that doesn't depend on argV above this lambda + FnValue { argV => val scope2 = scope1.letAll(args, argV) resFn(scope2) } - Static(fn) - } else { - val capScoped = caps.map(loop).toVector - Dynamic { scope => - val scope1 = scope - .capture(capScoped.map(scoped => scoped(scope)), name) - - // hopefully optimization/normalization has lifted anything - // that doesn't depend on argV above this lambda - FnValue { argV => - val scope2 = scope1.letAll(args, argV) - resFn(scope2) - } + } + case Lambda(caps, Some(name), args, res) => + val resFn = loop(res) + val capScoped = caps.map(loop).toVector + Dynamic { scope => + lazy val scope1: Scope = scope + .capture(capScoped.map(scoped => scoped(scope))) + .let(name, Eval.later(fn)) + + // hopefully optimization/normalization has lifted anything + // that doesn't depend on argV above this lambda + lazy val fn = FnValue { argV => + val scope2 = scope1.letAll(args, argV) + resFn(scope2) } + + fn } case LoopFn(caps, thisName, args, body) => val bodyFn = loop(body) @@ -444,34 +456,15 @@ object MatchlessToValue { Applicative[Scoped].map2(exprFn, argsFn) { (fn, args) => fn.applyAll(args) } - case Let(Right((n1, r)), loopFn @ LoopFn(_, n2, _, _), Local(n3)) - if (n1 === n3) && (n1 === n2) && r.isRecursive => - // LoopFn already correctly handles recursion - loop(loopFn) case Let(localOrBind, value, in) => val valueF = loop(value) val inF = loop(in) localOrBind match { - case Right((b, rec)) => - if (rec.isRecursive) { - - inF.withScope { scope => - // this is the only one that should - // use lazy/Eval.later - // we use it to tie the recursive knot - lazy val scope1: Scope = - scope.let(b, vv) - - lazy val vv = Eval.later(valueF(scope1)) - - scope1 - } - } else { - inF.withScope { (scope: Scope) => - val vv = Eval.now(valueF(scope)) - scope.let(b, vv) - } + case Right(b) => + inF.withScope { (scope: Scope) => + val vv = Eval.now(valueF(scope)) + scope.let(b, vv) } case Left(LocalAnon(l)) => inF.withScope { (scope: Scope) => diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala index 459a530e5..f66b54a39 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala @@ -9,7 +9,6 @@ import org.bykn.bosatsu.{ Matchless, Par, Parser, - RecursionKind } import org.bykn.bosatsu.codegen.Idents import org.bykn.bosatsu.rankn.Type @@ -506,7 +505,7 @@ object PythonGen { // if we have a top level let rec with the same name, handle it more cleanly me match { - case Let(Right((n1, RecursionKind.NonRecursive)), inner, Local(n2)) + case Let(Right(n1), inner, Local(n2)) if ((n1 === name) && (n2 === name)) => // we can just bind now at the top level for { @@ -516,12 +515,6 @@ object PythonGen { case _ => ops.loop(inner, None).map(nm := _) } } yield res - case Let(Right((n1, RecursionKind.Recursive)), fn: FnExpr, Local(n2)) - if (n1 === name) && (n2 === name) => - for { - nm <- Env.topLevelName(name) - res <- ops.topFn(nm, fn, None) - } yield res case fn: FnExpr => for { nm <- Env.topLevelName(name) @@ -1723,7 +1716,7 @@ object PythonGen { val inF = loop(in, slotName) localOrBind match { - case Right((b, _)) => + case Right(b) => // for fn, bosatsu doesn't allow bind name // shadowing, so the bind order of the name // doesn't matter @@ -1747,24 +1740,14 @@ object PythonGen { val inF = loop(in, slotName) localOrBind match { - case Right((b, rec)) => - if (rec.isRecursive) { - // value b is in scope first - for { - bi <- Env.bind(b) - v <- loop(notFn, slotName) - ine <- inF - _ <- Env.unbind(b) - } yield ((bi := v).withValue(ine)) - } else { - // value b is in scope after ve - for { - ve <- loop(notFn, slotName) - bi <- Env.bind(b) - ine <- inF - _ <- Env.unbind(b) - } yield ((bi := ve).withValue(ine)) - } + case Right(b) => + // value b is in scope after ve + for { + ve <- loop(notFn, slotName) + bi <- Env.bind(b) + ine <- inF + _ <- Env.unbind(b) + } yield ((bi := ve).withValue(ine)) case Left(LocalAnon(l)) => // anonymous names never shadow (Env.nameForAnon(l), loop(notFn, slotName)) diff --git a/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala b/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala index deb8a7356..36beb8d2a 100644 --- a/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala +++ b/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala @@ -10,8 +10,8 @@ import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ import Identifier.{Bindable, Constructor} import rankn.DataRepr -import cats.implicits._ import org.scalatest.funsuite.AnyFunSuite +import scala.util.Try class MatchlessTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = @@ -52,28 +52,23 @@ class MatchlessTest extends AnyFunSuite { test("matchless.fromLet is pure: f(x) == f(x)") { forAll(genInputs) { case (b, r, t, fn) => - def run(): Matchless.Expr = - Matchless.fromLet(b, r, t)(fn) + def run(): Option[Matchless.Expr] = + // ill-formed inputs can fail + Try(Matchless.fromLet(b, r, t)(fn)).toOption assert(run() == run()) } } - val genMatchlessExpr: Gen[Matchless.Expr] = + lazy val genMatchlessExpr: Gen[Matchless.Expr] = genInputs.map { case (b, r, t, fn) => - Matchless.fromLet(b, r, t)(fn) + // ill-formed inputs can fail + Try(Matchless.fromLet(b, r, t)(fn)).toOption + } + .flatMap { + case Some(e) => Gen.const(e) + case None => genMatchlessExpr } - - test("regressions") { - // this is illegal code, but it shouldn't throw a match error: - val name = Identifier.Name("foo") - val te = TypedExpr.Local(name, rankn.Type.IntType, ()) - // this should not throw - val me = Matchless.fromLet(name, RecursionKind.Recursive, te)( - fnFromTypeEnv(rankn.TypeEnv.empty) - ) - assert(me != null) - } def genNE[A](max: Int, ga: Gen[A]): Gen[NonEmptyList[A]] = for {