Skip to content

Commit

Permalink
Simplify Matchless.Let around recursion (#1254)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek authored Nov 13, 2024
1 parent be453d5 commit 30b3af9
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 106 deletions.
42 changes: 27 additions & 15 deletions core/src/main/scala/org/bykn/bosatsu/Matchless.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ object Matchless {
// these hold bindings either in the code, or temporary
// local ones, note CheapExpr never trigger a side effect
sealed trait CheapExpr extends Expr
sealed abstract class FnExpr extends Expr
sealed abstract class FnExpr extends Expr {
def captures: List[Expr]
// this is set if the function is recursive
def recursiveName: Option[Bindable]
}

sealed abstract class StrPart
object StrPart {
Expand Down Expand Up @@ -78,7 +82,7 @@ object Matchless {
// name is set for recursive (but not tail recursive) methods
case class Lambda(
captures: List[Expr],
name: Option[Bindable],
recursiveName: Option[Bindable],
args: NonEmptyList[Bindable],
expr: Expr
) extends FnExpr
Expand All @@ -92,7 +96,9 @@ object Matchless {
name: Bindable,
arg: NonEmptyList[Bindable],
body: Expr
) extends FnExpr
) extends FnExpr {
val recursiveName: Option[Bindable] = Some(name)
}

case class Global(pack: PackageName, name: Bindable) extends CheapExpr

Expand All @@ -108,7 +114,7 @@ object Matchless {
// note fn is never an App
case class App(fn: Expr, arg: NonEmptyList[Expr]) extends Expr
case class Let(
arg: Either[LocalAnon, (Bindable, RecursionKind)],
arg: Either[LocalAnon, Bindable],
expr: Expr,
in: Expr
) extends Expr
Expand Down Expand Up @@ -324,12 +330,19 @@ object Matchless {
e: TypedExpr[A],
rec: RecursionKind,
slots: LambdaState
): F[Expr] = {
lazy val e0 = loop(e, if (rec.isRecursive) slots.inLet(name) else slots)
): F[Expr] =
rec match {
case RecursionKind.Recursive =>
def letrec(e: Expr): Expr =
Let(Right((name, RecursionKind.Recursive)), e, Local(name))
lazy val e0 = loop(e, slots.inLet(name))
def letrec(expr: Expr): Expr =
expr match {
case fn: FnExpr if fn.recursiveName == Some(name) => fn
case fn: FnExpr =>
// loops always have a function name
sys.error(s"expected ${fn.recursiveName} == Some($name) in ${e.repr.render(80)} which compiled to $fn")
case _ =>
sys.error(s"expected ${e.repr.render(80)} to compile to a function, but got: $expr")
}

// this could be tail recursive
if (SelfCallKind(name, e) == SelfCallKind.TailCall) {
Expand All @@ -345,20 +358,20 @@ object Matchless {
val (slots1, caps) = slots.inLet(name).lambdaFrees(frees)
loop(body, slots1)
.map { v =>
letrec(LoopFn(caps, name, args, v))
LoopFn(caps, name, args, v)
}
// $COVERAGE-OFF$
case _ =>
// TODO: I don't think this case should ever happen in real code
// but it definitely does in fuzz tests
e0.map(letrec)
// $COVERAGE-ON$
}
} else {
// otherwise let rec x = fn in x
e0.map(letrec)
}
case RecursionKind.NonRecursive => e0
case RecursionKind.NonRecursive => loop(e, slots)
}
}

def loop(te: TypedExpr[A], slots: LambdaState): F[Expr] =
te match {
Expand Down Expand Up @@ -396,7 +409,7 @@ object Matchless {
.mapN(App(_, _))
case TypedExpr.Let(a, e, in, r, _) =>
(loopLetVal(a, e, r, slots.unname), loop(in, slots))
.mapN(Let(Right((a, r)), _, _))
.mapN(Let(Right(a), _, _))
case TypedExpr.Literal(lit, _, _) => Monad[F].pure(Literal(lit))
case TypedExpr.Match(arg, branches, _) =>
(
Expand Down Expand Up @@ -766,8 +779,7 @@ object Matchless {

def lets(binds: List[(Bindable, Expr)], in: Expr): Expr =
binds.foldRight(in) { case ((b, e), r) =>
val arg = Right((b, RecursionKind.NonRecursive))
Let(arg, e, r)
Let(Right(b), e, r)
}

def checkLets(binds: List[LocalAnonMut], in: Expr): Expr =
Expand Down
89 changes: 41 additions & 48 deletions core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,9 @@ object MatchlessToValue {
()
}

def capture(it: Vector[Value], name: Option[Bindable]): Scope =
def capture(it: Vector[Value]): Scope =
Scope(
name match {
case None => Map.empty
case Some(n) => Map((n, locals(n)))
},
Map.empty,
LongMap.empty,
MLongMap(),
it
Expand Down Expand Up @@ -353,7 +350,7 @@ object MatchlessToValue {
// It may make things go faster
// if the caps are really small
// or if we can GC things sooner.
val scope1 = scope.capture(caps.map(s => s(scope)), Some(fnName))
val scope1 = scope.capture(caps.map(s => s(scope)))

FnValue { allArgs =>
var registers: NonEmptyList[Value] = allArgs
Expand Down Expand Up @@ -393,30 +390,45 @@ object MatchlessToValue {
// the locals can be recusive, so we box into Eval for laziness
def loop(me: Expr): Scoped[Value] =
me match {
case Lambda(caps, name, args, res) =>
case Lambda(Nil, None, args, res) =>
val resFn = loop(res)

if (caps.isEmpty && name.isEmpty) {
// we can allocate once if there is no closure
val scope1 = Scope.empty()
val fn = FnValue { argV =>
// we can allocate once if there is no closure
val scope1 = Scope.empty()
val fn = FnValue { argV =>
val scope2 = scope1.letAll(args, argV)
resFn(scope2)
}
Static(fn)
case Lambda(caps, None, args, res) =>
val resFn = loop(res)
val capScoped = caps.map(loop).toVector
Dynamic { scope =>
val scope1 = scope
.capture(capScoped.map(scoped => scoped(scope)))

// hopefully optimization/normalization has lifted anything
// that doesn't depend on argV above this lambda
FnValue { argV =>
val scope2 = scope1.letAll(args, argV)
resFn(scope2)
}
Static(fn)
} else {
val capScoped = caps.map(loop).toVector
Dynamic { scope =>
val scope1 = scope
.capture(capScoped.map(scoped => scoped(scope)), name)

// hopefully optimization/normalization has lifted anything
// that doesn't depend on argV above this lambda
FnValue { argV =>
val scope2 = scope1.letAll(args, argV)
resFn(scope2)
}
}
case Lambda(caps, Some(name), args, res) =>
val resFn = loop(res)
val capScoped = caps.map(loop).toVector
Dynamic { scope =>
lazy val scope1: Scope = scope
.capture(capScoped.map(scoped => scoped(scope)))
.let(name, Eval.later(fn))

// hopefully optimization/normalization has lifted anything
// that doesn't depend on argV above this lambda
lazy val fn = FnValue { argV =>
val scope2 = scope1.letAll(args, argV)
resFn(scope2)
}

fn
}
case LoopFn(caps, thisName, args, body) =>
val bodyFn = loop(body)
Expand Down Expand Up @@ -444,34 +456,15 @@ object MatchlessToValue {
Applicative[Scoped].map2(exprFn, argsFn) { (fn, args) =>
fn.applyAll(args)
}
case Let(Right((n1, r)), loopFn @ LoopFn(_, n2, _, _), Local(n3))
if (n1 === n3) && (n1 === n2) && r.isRecursive =>
// LoopFn already correctly handles recursion
loop(loopFn)
case Let(localOrBind, value, in) =>
val valueF = loop(value)
val inF = loop(in)

localOrBind match {
case Right((b, rec)) =>
if (rec.isRecursive) {

inF.withScope { scope =>
// this is the only one that should
// use lazy/Eval.later
// we use it to tie the recursive knot
lazy val scope1: Scope =
scope.let(b, vv)

lazy val vv = Eval.later(valueF(scope1))

scope1
}
} else {
inF.withScope { (scope: Scope) =>
val vv = Eval.now(valueF(scope))
scope.let(b, vv)
}
case Right(b) =>
inF.withScope { (scope: Scope) =>
val vv = Eval.now(valueF(scope))
scope.let(b, vv)
}
case Left(LocalAnon(l)) =>
inF.withScope { (scope: Scope) =>
Expand Down
37 changes: 10 additions & 27 deletions core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import org.bykn.bosatsu.{
Matchless,
Par,
Parser,
RecursionKind
}
import org.bykn.bosatsu.codegen.Idents
import org.bykn.bosatsu.rankn.Type
Expand Down Expand Up @@ -506,7 +505,7 @@ object PythonGen {

// if we have a top level let rec with the same name, handle it more cleanly
me match {
case Let(Right((n1, RecursionKind.NonRecursive)), inner, Local(n2))
case Let(Right(n1), inner, Local(n2))
if ((n1 === name) && (n2 === name)) =>
// we can just bind now at the top level
for {
Expand All @@ -516,12 +515,6 @@ object PythonGen {
case _ => ops.loop(inner, None).map(nm := _)
}
} yield res
case Let(Right((n1, RecursionKind.Recursive)), fn: FnExpr, Local(n2))
if (n1 === name) && (n2 === name) =>
for {
nm <- Env.topLevelName(name)
res <- ops.topFn(nm, fn, None)
} yield res
case fn: FnExpr =>
for {
nm <- Env.topLevelName(name)
Expand Down Expand Up @@ -1723,7 +1716,7 @@ object PythonGen {
val inF = loop(in, slotName)

localOrBind match {
case Right((b, _)) =>
case Right(b) =>
// for fn, bosatsu doesn't allow bind name
// shadowing, so the bind order of the name
// doesn't matter
Expand All @@ -1747,24 +1740,14 @@ object PythonGen {
val inF = loop(in, slotName)

localOrBind match {
case Right((b, rec)) =>
if (rec.isRecursive) {
// value b is in scope first
for {
bi <- Env.bind(b)
v <- loop(notFn, slotName)
ine <- inF
_ <- Env.unbind(b)
} yield ((bi := v).withValue(ine))
} else {
// value b is in scope after ve
for {
ve <- loop(notFn, slotName)
bi <- Env.bind(b)
ine <- inF
_ <- Env.unbind(b)
} yield ((bi := ve).withValue(ine))
}
case Right(b) =>
// value b is in scope after ve
for {
ve <- loop(notFn, slotName)
bi <- Env.bind(b)
ine <- inF
_ <- Env.unbind(b)
} yield ((bi := ve).withValue(ine))
case Left(LocalAnon(l)) =>
// anonymous names never shadow
(Env.nameForAnon(l), loop(notFn, slotName))
Expand Down
27 changes: 11 additions & 16 deletions core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{
import Identifier.{Bindable, Constructor}
import rankn.DataRepr

import cats.implicits._
import org.scalatest.funsuite.AnyFunSuite
import scala.util.Try

class MatchlessTest extends AnyFunSuite {
implicit val generatorDrivenConfig: PropertyCheckConfiguration =
Expand Down Expand Up @@ -52,28 +52,23 @@ class MatchlessTest extends AnyFunSuite {

test("matchless.fromLet is pure: f(x) == f(x)") {
forAll(genInputs) { case (b, r, t, fn) =>
def run(): Matchless.Expr =
Matchless.fromLet(b, r, t)(fn)
def run(): Option[Matchless.Expr] =
// ill-formed inputs can fail
Try(Matchless.fromLet(b, r, t)(fn)).toOption

assert(run() == run())
}
}

val genMatchlessExpr: Gen[Matchless.Expr] =
lazy val genMatchlessExpr: Gen[Matchless.Expr] =
genInputs.map { case (b, r, t, fn) =>
Matchless.fromLet(b, r, t)(fn)
// ill-formed inputs can fail
Try(Matchless.fromLet(b, r, t)(fn)).toOption
}
.flatMap {
case Some(e) => Gen.const(e)
case None => genMatchlessExpr
}

test("regressions") {
// this is illegal code, but it shouldn't throw a match error:
val name = Identifier.Name("foo")
val te = TypedExpr.Local(name, rankn.Type.IntType, ())
// this should not throw
val me = Matchless.fromLet(name, RecursionKind.Recursive, te)(
fnFromTypeEnv(rankn.TypeEnv.empty)
)
assert(me != null)
}

def genNE[A](max: Int, ga: Gen[A]): Gen[NonEmptyList[A]] =
for {
Expand Down

0 comments on commit 30b3af9

Please sign in to comment.