diff --git a/cli/src/test/scala/org/bykn/bosatsu/PathModuleTest.scala b/cli/src/test/scala/org/bykn/bosatsu/PathModuleTest.scala index 307411f9c..8cbd4a1b8 100644 --- a/cli/src/test/scala/org/bykn/bosatsu/PathModuleTest.scala +++ b/cli/src/test/scala/org/bykn/bosatsu/PathModuleTest.scala @@ -292,5 +292,4 @@ class PathModuleTest extends AnyFunSuite { case other => fail(s"unexpeced: $other") } } - } diff --git a/cli/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala b/cli/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala index cdb2495a4..0c1e3dbb0 100644 --- a/cli/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala +++ b/cli/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala @@ -45,7 +45,7 @@ class ClangGenTest extends munit.FunSuite { To inspect the code, change the hash, and it will print the code out */ testFilesCompilesToHash("test_workspace/Ackermann.bosatsu")( - "e7bf005ddf95ce8e62e45c6985c09e51" + "f881ad7b7bda633aae1ad1c061f1d3bd" ) } } \ No newline at end of file diff --git a/core/src/main/scala/org/bykn/bosatsu/Matchless.scala b/core/src/main/scala/org/bykn/bosatsu/Matchless.scala index 0cadc0a79..d83ad5419 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Matchless.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Matchless.scala @@ -33,18 +33,9 @@ object Matchless { body: Expr ) extends FnExpr - // this is a tail recursive function that should be compiled into a loop - // when a call to name is done inside body, that should restart the loop - // the type of this Expr a function with the arity of args that returns - // the type of body - case class LoopFn( - captures: List[Expr], - name: Bindable, - args: NonEmptyList[Bindable], - body: Expr - ) extends FnExpr { - val recursiveName: Option[Bindable] = Some(name) - } + // This is a while loop, the result of which is result and the body is evaluated + // while cond is true + case class WhileExpr(cond: BoolExpr, effectExpr: Expr, result: LocalAnonMut) extends Expr case class Global(pack: PackageName, name: Bindable) extends CheapExpr @@ -77,7 +68,18 @@ object Matchless { else Let(Left(arg), expr, in) } - case class LetMut(name: LocalAnonMut, span: Expr) extends Expr + case class LetMut(name: LocalAnonMut, span: Expr) extends Expr { + // often we have several LetMut at once, return all them + def flatten: (NonEmptyList[LocalAnonMut], Expr) = { + span match { + case next @ LetMut(_, _) => + val (anons, expr) = next.flatten + (name :: anons, expr) + case notLetMut => + (NonEmptyList.one(name), notLetMut) + } + } + } case class Literal(lit: Lit) extends CheapExpr // these result in Int values which are also used as booleans @@ -154,6 +156,21 @@ object Matchless { } } case class Always(cond: BoolExpr, thenExpr: Expr) extends Expr + object Always { + object SetChain { + // a common pattern is Always(SetMut(m, e), r) + def unapply(expr: Expr): Option[(NonEmptyList[(LocalAnonMut, Expr)], Expr)] = + expr match { + case Always(SetMut(mut, v), res) => + val pair = (mut, v) + unapply(res) match { + case None => Some((NonEmptyList.one(pair), res)) + case Some((muts, res)) => Some((pair :: muts, res)) + } + case _ => None + } + } + } def always(cond: BoolExpr, thenExpr: Expr): Expr = if (hasSideEffect(cond)) Always(cond, thenExpr) else thenExpr @@ -172,6 +189,15 @@ object Matchless { // we need to compile calls to constructors into these case class MakeEnum(variant: Int, arity: Int, famArities: List[Int]) extends ConsExpr + + private val boolFamArities = 0 :: 0 :: Nil + val FalseExpr: Expr = MakeEnum(0, 0, boolFamArities) + val TrueExpr: Expr = MakeEnum(1, 0, boolFamArities) + val UnitExpr: Expr = MakeStruct(0) + + def isTrueExpr(e: CheapExpr): BoolExpr = + CheckVariant(e, 1, 0, boolFamArities) + case class MakeStruct(arity: Int) extends ConsExpr case object ZeroNat extends ConsExpr { def arity = 0 @@ -285,6 +311,171 @@ object Matchless { def inLet(b: Bindable): LambdaState = copy(name = Some(b)) } + def translateLocalsBool(m: Map[Bindable, LocalAnonMut], e: BoolExpr): BoolExpr = + e match { + case SetMut(mut, e) => SetMut(mut, translateLocals(m, e)) + case And(b1, b2) => + And(translateLocalsBool(m, b1), translateLocalsBool(m, b2)) + case EqualsLit(x, l) => + EqualsLit(translateLocalsCheap(m, x), l) + case EqualsNat(x, n) => + EqualsNat(translateLocalsCheap(m, x), n) + case TrueConst => TrueConst + case CheckVariant(expr, expect, sz, fam) => + CheckVariant(translateLocalsCheap(m, expr), expect, sz, fam) + case ms: MatchString => + ms.copy(arg = translateLocalsCheap(m, ms.arg)) + case sl: SearchList => + sl.copy( + init = translateLocalsCheap(m, sl.init), + check = translateLocalsBool(m, sl.check) + ) + } + + def translateLocals(m: Map[Bindable, LocalAnonMut], e: Expr): Expr = + e match { + case App(fn, appArgs) => + App(translateLocals(m, fn), appArgs.map(translateLocals(m, _))) + case If(c, tcase, fcase) => + If(translateLocalsBool(m, c), translateLocals(m, tcase), translateLocals(m, fcase)) + case Always(c, e) => + Always(translateLocalsBool(m, c), translateLocals(m, e)) + case LetMut(mut, e) => + LetMut(mut, translateLocals(m, e)) + case Let(n, v, in) => + val m1 = n match { + case Right(b) => m - b + case _ => m + } + Let(n, translateLocals(m, v), translateLocals(m1, in)) + // the rest cannot have a call in tail position + case Local(n) => + m.get(n) match { + case Some(mut) => mut + case None => e + } + case PrevNat(n) => PrevNat(translateLocals(m, n)) + case ge: GetEnumElement => + ge.copy(arg = translateLocalsCheap(m, ge.arg)) + case gs: GetStructElement => + gs.copy(arg = translateLocalsCheap(m, gs.arg)) + case Lambda(c, r, as, b) => + val m1 = m -- as.toList + val b1 = translateLocals(m1, b) + Lambda(c, r, as, b1) + case WhileExpr(c, ef, r) => + WhileExpr(translateLocalsBool(m, c), translateLocals(m, ef), r) + case ClosureSlot(_) | Global(_, _) | LocalAnon(_) | LocalAnonMut(_) | + MakeEnum(_, _, _) | MakeStruct(_) | SuccNat | Literal(_) | ZeroNat => e + } + def translateLocalsCheap(m: Map[Bindable, LocalAnonMut], e: CheapExpr): CheapExpr = + translateLocals(m, e) match { + case ch: CheapExpr => ch + case notCheap => sys.error(s"invariant violation: translation didn't maintain cheap: $e => $notCheap") + } + + def loopFn( + captures: List[Expr], + name: Bindable, + args: NonEmptyList[Bindable], + body: Expr): F[Expr] = { + + def setAll(ls: List[(LocalAnonMut, Expr)], ret: Expr): Expr = + ls.foldRight(ret) { case ((l, e), r) => + Always(SetMut(l, e), r) + } + // assign any results to result and set the condition to false + // and replace any tail calls to nm(args) with assigning args to those values + case class ArgRecord(name: Bindable, tmp: LocalAnon, loopVar: LocalAnonMut) + def toWhileBody(args: NonEmptyList[ArgRecord], cond: LocalAnonMut, result: LocalAnonMut): Expr = { + + val nameExpr = Local(name) + + def returnValue(v: Expr): Expr = + setAll((cond, FalseExpr) :: (result, v) :: Nil, UnitExpr) + + // return Some(e) if this expression can be rewritten into a tail call to name, + // in instead of the call, do a bunch of SetMut on the args, and set cond to false + // else None + def loop(expr: Expr): Option[Expr] = + expr match { + case App(fn, appArgs) if fn == nameExpr => + // this is a tail call + // we know the length of appArgs must match args or the code wouldn't have compiled + // we have to first assign to the temp variables, and then assign the temp variables + // to the results to make sure we don't have any data dependency issues with the values; + val tmpAssigns = appArgs + .iterator + .zip(args.iterator) + .flatMap { case (appArg, argRecord) => + if (appArg != argRecord.loopVar) + // don't create self assignments + Iterator.single(((argRecord.tmp, appArg), (argRecord.loopVar, argRecord.tmp))) + else + Iterator.empty + } + .toList + + // there must be at least one assignment + Some(letAnons( + tmpAssigns.map(_._1), + setAll(tmpAssigns.map(_._2), UnitExpr) + )) + case If(c, tcase, fcase) => + // this can possible have tail calls inside the branches + (loop(tcase), loop(fcase)) match { + case (Some(t), Some(f)) => + Some(If(c, t, f)) + case (None, Some(f)) => + Some(If(c, returnValue(tcase), f)) + case (Some(t), None) => + Some(If(c, t, returnValue(fcase))) + case (None, None) => None + } + case Always(c, e) => + loop(e).map(Always(c, _)) + case LetMut(m, e) => + loop(e).map(LetMut(m, _)) + case Let(b, v, in) => + // in is in tail position + loop(in).map(Let(b, v, _)) + // the rest cannot have a call in tail position + case App(_, _) | ClosureSlot(_) | GetEnumElement(_, _, _, _) | GetStructElement(_, _, _) | + Global(_, _) | Lambda(_, _, _, _) | Literal(_) | Local(_) | LocalAnon(_) | LocalAnonMut(_) | + MakeEnum(_, _, _) | MakeStruct(_) | PrevNat(_) | SuccNat | WhileExpr(_, _, _) | ZeroNat => None + } + + val bodyTrans = translateLocals( + args.toList.map(a => (a.name, a.loopVar)).toMap, + body) + + loop(bodyTrans) match { + case Some(expr) => expr + case None => + sys.error("invariant violation: could not find tail calls in:" + + s"toWhileBody(name = $name, body = $body)") + } + } + val mut = makeAnon.map(LocalAnonMut(_)) + val anon = makeAnon.map(LocalAnon(_)) + for { + cond <- mut + result <- mut + args1 <- args.traverse { b => (anon, mut).mapN(ArgRecord(b, _, _)) } + whileLoop = toWhileBody(args1, cond, result) + allMuts = cond :: result :: args1.toList.map(_.loopVar) + // we don't need to set the name on the lambda because this is no longer recursive + } yield Lambda(captures, None, args, + letMutAll(allMuts, + setAll( + args1.toList.map(arg => (arg.loopVar, Local(arg.name))), + Always(SetMut(cond, TrueExpr), + WhileExpr(isTrueExpr(cond), whileLoop, result)) + ) + ) + ) + } + def loopLetVal( name: Bindable, e: TypedExpr[A], @@ -317,8 +508,8 @@ object Matchless { val frees = TypedExpr.freeVars(e :: Nil) val (slots1, caps) = slots.inLet(name).lambdaFrees(frees) loop(body, slots1) - .map { v => - LoopFn(caps, name, args, v) + .flatMap { v => + loopFn(caps, name, args, v) } // $COVERAGE-OFF$ case _ => @@ -763,8 +954,13 @@ object Matchless { Let(b, e, r) } - def checkLets(binds: List[LocalAnonMut], in: Expr): Expr = - binds.foldLeft(in) { case (rest, anon) => + def letAnons(binds: List[(LocalAnon, Expr)], in: Expr): Expr = + binds.foldRight(in) { case ((b, e), r) => + Let(Left(b), e, r) + } + + def letMutAll(binds: List[LocalAnonMut], in: Expr): Expr = + binds.foldRight(in) { case (anon, rest) => // TODO: sometimes we generate code like // LetMut(x, Always(SetMut(x, y), f)) // with no side effects in y or f @@ -797,7 +993,7 @@ object Matchless { case NonEmptyList((b0, TrueConst, binds), _) => // this is a total match, no fall through val right = lets(binds, r1) - Monad[F].pure(checkLets(b0, right)) + Monad[F].pure(letMutAll(b0, right)) case NonEmptyList((b0, cond, binds), others) => val thisBranch = lets(binds, r1) val res = others match { @@ -819,7 +1015,7 @@ object Matchless { } } - res.map(checkLets(b0, _)) + res.map(letMutAll(b0, _)) } doesMatch(arg, p1, branches.tail.isEmpty).flatMap(loop) diff --git a/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala b/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala index 5c5b40078..4d8a56c83 100644 --- a/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala +++ b/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala @@ -6,7 +6,6 @@ import cats.evidence.Is import java.math.BigInteger import org.bykn.bosatsu.pattern.StrPart import scala.collection.immutable.LongMap -import scala.collection.mutable.{LongMap => MLongMap} import Identifier.Bindable import Value._ @@ -67,16 +66,48 @@ object MatchlessToValue { case object Uninitialized val uninit: Value = ExternalValue(Uninitialized) + class DebugStr(prefix: String = "") { + private var message: String = "" + def set(msg: String): Unit = { + message = msg; + } + + def append(msg: String): Unit = { + message = message + " :: " + msg + } + + override def toString = prefix + message + + def scope(outer: String): DebugStr = { + new DebugStr(prefix + "/" + outer) + } + } + + class Cell { + private var value = uninit + def set(v: Value): Unit = { + value = v + } + + def get(): Value = value + } + final case class Scope( locals: Map[Bindable, Eval[Value]], anon: LongMap[Value], - muts: MLongMap[Value], - slots: Vector[Value] + muts: LongMap[Cell], + slots: Vector[Value], + extra: DebugStr ) { def let(b: Bindable, v: Eval[Value]): Scope = copy(locals = locals.updated(b, v)) + def letMuts(idxs: Iterator[Long]): Scope = { + val mut1 = muts ++ idxs.map(l => (l, new Cell)) + copy(muts = mut1) + } + def letAll(bs: NonEmptyList[Bindable], vs: NonEmptyList[Value]): Scope = { val b = bs.iterator val v = vs.iterator @@ -87,24 +118,31 @@ object MatchlessToValue { copy(locals = local1) } + def debugString: String = + s"local keys: ${locals.keySet}, anon keys: ${anon.keySet}, anonMut keys: ${muts.keySet}\nextra=$extra" + def updateMut(mutIdx: Long, v: Value): Unit = { - assert(muts.contains(mutIdx)) - muts.put(mutIdx, v) + if (!muts.contains(mutIdx)) { + sys.error(s"updateMut($mutIdx, _) but $mutIdx is empty: $debugString") + } + muts(mutIdx).set(v) () } - def capture(it: Vector[Value]): Scope = - Scope( - Map.empty, - LongMap.empty, - MLongMap(), - it - ) } object Scope { def empty(): Scope = - Scope(Map.empty, LongMap.empty, MLongMap(), Vector.empty) + Scope(Map.empty, LongMap.empty, LongMap.empty, Vector.empty, new DebugStr) + + def capture(it: Vector[Value], dbg: DebugStr = new DebugStr): Scope = + Scope( + Map.empty, + LongMap.empty, + LongMap.empty, + it, + dbg.scope("capture") + ) } sealed abstract class Scoped[A] { @@ -296,96 +334,6 @@ object MatchlessToValue { } } - def buildLoop( - caps: Vector[Scoped[Value]], - fnName: Bindable, - args: NonEmptyList[Bindable], - body: Scoped[Value] - ): Scoped[Value] = { - val argCount = args.length - val argNames: Array[Bindable] = args.toList.toArray - if (caps.isEmpty) { - // We only capture ourself and we put that in below - val scope1 = Scope.empty() - val fn = FnValue { allArgs => - var registers: NonEmptyList[Value] = allArgs - - // the registers are set up - // when we recur, that is a continue on the loop, - // we just update the registers and return null - val continueFn = FnValue { continueArgs => - registers = continueArgs - null - } - - val scope2 = scope1.let(fnName, Eval.now(continueFn)) - - var res: Value = null - - while (res eq null) { - // read the registers into the environment - var idx = 0 - var reg: List[Value] = registers.toList - var s: Scope = scope2 - while (idx < argCount) { - val b = argNames(idx) - val v = reg.head - reg = reg.tail - s = s.let(b, Eval.now(v)) - idx = idx + 1 - } - res = body(s) - } - - res - } - - Static(fn) - } else { - Dynamic { scope => - // TODO this maybe isn't helpful - // it doesn't matter if the scope - // is too broad for correctness. - // It may make things go faster - // if the caps are really small - // or if we can GC things sooner. - val scope1 = scope.capture(caps.map(s => s(scope))) - - FnValue { allArgs => - var registers: NonEmptyList[Value] = allArgs - - // the registers are set up - // when we recur, that is a continue on the loop, - // we just update the registers and return null - val continueFn = FnValue { continueArgs => - registers = continueArgs - null - } - - val scope2 = scope1.let(fnName, Eval.now(continueFn)) - - var res: Value = null - - while (res eq null) { - // read the registers into the environment - var idx = 0 - var reg: List[Value] = registers.toList - var s: Scope = scope2 - while (idx < argCount) { - val b = argNames(idx) - val v = reg.head - reg = reg.tail - s = s.let(b, Eval.now(v)) - idx = idx + 1 - } - res = body(s) - } - - res - } - } - } - } // the locals can be recusive, so we box into Eval for laziness def loop(me: Expr): Scoped[Value] = me match { @@ -402,8 +350,9 @@ object MatchlessToValue { val resFn = loop(res) val capScoped = caps.map(loop).toVector Dynamic { scope => - val scope1 = scope - .capture(capScoped.map(scoped => scoped(scope))) + val valuesInScope = capScoped.map(scoped => scoped(scope)) + // now we ignore the scope after reading from it + val scope1 = Scope.capture(valuesInScope, scope.extra) // hopefully optimization/normalization has lifted anything // that doesn't depend on argV above this lambda @@ -416,8 +365,10 @@ object MatchlessToValue { val resFn = loop(res) val capScoped = caps.map(loop).toVector Dynamic { scope => - lazy val scope1: Scope = scope - .capture(capScoped.map(scoped => scoped(scope))) + val valuesInScope = capScoped.map(scoped => scoped(scope)) + + lazy val scope1: Scope = Scope + .capture(valuesInScope) .let(name, Eval.later(fn)) // hopefully optimization/normalization has lifted anything @@ -429,10 +380,21 @@ object MatchlessToValue { fn } - case LoopFn(caps, thisName, args, body) => - val bodyFn = loop(body) + case WhileExpr(cond, effect, result) => + val condF = boolExpr(cond) + val effectF = loop(effect) - buildLoop(caps.map(loop).toVector, thisName, args, bodyFn) + // conditions are never static + // or a previous optimization/normalization + // has failed + Dynamic { (scope: Scope) => + var c = condF(scope) + while(c) { + effectF(scope) + c = condF(scope) + } + scope.muts(result.ident).get() + } case Global(p, n) => val res = resolve(p, n) @@ -441,10 +403,15 @@ object MatchlessToValue { Dynamic((_: Scope) => res.value) case Local(b) => Dynamic(_.locals(b).value) case LocalAnon(a) => Dynamic(_.anon(a)) - case LocalAnonMut(m) => Dynamic(_.muts(m)) + case LocalAnonMut(m) => Dynamic { s => + s.muts.get(m) match { + case Some(v) => v.get() + case None => sys.error(s"could not get: $m. ${s.debugString}") + } + } case ClosureSlot(idx) => Dynamic(_.slots(idx)) case App(expr, args) => - // TODO: App(LoopFn(.. + // TODO: App(lambda(while // can be optimized into a while // loop, but there isn't any prior optimization // that would do this.... maybe it should @@ -471,23 +438,17 @@ object MatchlessToValue { scope.copy(anon = scope.anon.updated(l, vv)) } } - case LetMut(LocalAnonMut(l), in) => - loop(in) match { - case s @ Static(_) => s - case Dynamic(inF) => - Dynamic { (scope: Scope) => - // we make sure there is - // a value that will show up - // strange in tests, - // for an optimization we could - // avoid this - scope.muts.put(l, uninit) - val res = inF(scope) - // now we can remove this from mutable scope - // we should be able to remove this - scope.muts.remove(l) - res - } + case lm @ LetMut(_, _) => + val (anonMuts, in) = lm.flatten + val inF = loop(in) + Dynamic { (scope: Scope) => + // we make sure there is + // a value that will show up + // strange in tests, + // for an optimization we could + // avoid this + val scope1 = scope.letMuts(anonMuts.iterator.map(_.ident)) + inF(scope1) } case Literal(lit) => Static(Value.fromLit(lit)) @@ -503,6 +464,18 @@ object MatchlessToValue { if (condF(scope)) thenF(scope) else elseF(scope) } + case Always.SetChain(muts, expr) => + val values = muts.map { case (m, e) => (m, loop(e)) } + val exprF = loop(expr) + + Dynamic { scope => + values.iterator.foreach { case (m, e) => + val ev = e(scope) + scope.updateMut(m.ident, ev) + } + + exprF(scope) + } case Always(cond, expr) => val condF = boolExpr(cond) val exprF = loop(expr) diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala index ede3802d4..9b1bd9e8f 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala @@ -244,69 +244,6 @@ object ClangGen { // This name has to be impossible to give out for any other purpose val slotsArgName: Code.Ident = Code.Ident("__bstsi_slot") - // assign any results to result and set the condition to false - // and replace any tail calls to nm(args) with assigning args to those values - def toWhileBody(fnName: Code.Ident, argTemp: NonEmptyList[(Code.Param, Code.Ident)], isClosure: Boolean, cond: Code.Ident, result: Code.Ident, body: Code.ValueLike): Code.Block = { - - import Code._ - - def returnValue(vl: ValueLike): Statement = - (cond := FalseLit) + - (result := vl) - - def loop(vl: ValueLike): Option[Statement] = - vl match { - case Apply(fn, appArgs) if fn == fnName => - // this is a tail call - val newArgsList = - if (isClosure) appArgs.tail - else appArgs - // we know the length of appArgs must match args or the code wouldn't have compiled - // we have to first assign to the temp variables, and then assign the temp variables - // to the results to make sure we don't have any data dependency issues with the values; - val tmpAssigns = argTemp - .iterator - .zip(newArgsList.iterator) - .flatMap { case ((Param(_, name), tmp), value) => - if (name != value) - // don't create self assignments - Iterator.single((Assignment(tmp, value), Assignment(name, tmp))) - else - Iterator.empty - } - .toList - - // there must be at least one assignment - val assignNEL = NonEmptyList.fromListUnsafe( - tmpAssigns.map(_._1) ::: tmpAssigns.map(_._2) - ) - Some(Statements(assignNEL)) - case IfElseValue(c, t, f) => - // this can possible have tail calls inside the branches - (loop(t), loop(f)) match { - case (Some(t), Some(f)) => - Some(ifThenElse(c, t, f)) - case (None, Some(f)) => - Some(ifThenElse(c, returnValue(t), f)) - case (Some(t), None) => - Some(ifThenElse(c, t, returnValue(f))) - case (None, None) => None - } - case Ternary(c, t, f) => loop(IfElseValue(c, t, f)) - case WithValue(s, vl) => loop(vl).map(s + _) - case Apply(_, _) | Cast(_, _) | BinExpr(_, _, _) | Bracket(_, _) | Ident(_) | - IntLiteral(_) | PostfixExpr(_, _) | PrefixExpr(_, _) | Select(_, _) | - StrLiteral(_) => None - } - - loop(body) match { - case Some(stmt) => block(stmt) - case None => - sys.error("invariant violation: could not find tail calls in:" + - s"toWhileBody(fnName = $fnName, body = $body)") - } - } - def bindAll[A](nel: NonEmptyList[Bindable])(in: T[A]): T[A] = bind(nel.head) { NonEmptyList.fromList(nel.tail) match { @@ -1070,6 +1007,17 @@ object ClangGen { .flatMapN { (c, thenC, elseC) => Code.ValueLike.ifThenElseV(c, thenC, elseC)(newLocalName) } + case Always.SetChain(setmuts, result) => + ( + setmuts.traverse { case (LocalAnonMut(mut), v) => + for { + name <- getAnon(mut) + vl <- innerToValue(v) + } yield (name := vl) + }, innerToValue(result) + ).mapN { (assigns, result) => + Code.Statements(assigns) +: result + } case Always(cond, thenExpr) => boolToValue(cond).flatMap { bv => bv.discardValue match { @@ -1133,6 +1081,24 @@ object ClangGen { NonEmptyList.one(argVL) )(newLocalName) } + case WhileExpr(cond, effect, res) => + (boolToValue(cond), innerToValue(effect), innerToValue(res), newLocalName("cond")) + .mapN { (cond, effect, res, condVar) => + Code.Statements( + Code.DeclareVar(Nil, Code.TypeIdent.Bool, condVar, None), + condVar := cond, + Code.While(condVar, + Code.Block( + NonEmptyList.one( + condVar := cond, + ) + .prependList( + effect.discardValue.toList + ) + ) + ) + ) +: res + } } def fnStatement(fnName: Code.Ident, fn: FnExpr): T[Code.Statement] = @@ -1157,38 +1123,6 @@ object ClangGen { } } yield Code.DeclareFn(Nil, Code.TypeIdent.BValue, fnName, allArgs.toList, Some(Code.block(fnBody))) } - case LoopFn(captures, nm, args, body) => - recursiveName(fnName, nm, isClosure = captures.nonEmpty, arity = fn.arity) { - bindAll(args) { - for { - cond <- newLocalName("cond") - res <- newLocalName("res") - bodyVL <- innerToValue(body) - argParamsTemps <- args.traverse { b => - (getBinding(b), newLocalName("loop_temp")).mapN { (i, t) => (Code.Param(Code.TypeIdent.BValue, i), t) } - } - whileBody = toWhileBody(fnName, argParamsTemps, isClosure = captures.nonEmpty, cond = cond, result = res, body = bodyVL) - declTmps = Code.Statements( - argParamsTemps.map { case (_, tmp) => - Code.DeclareVar(Nil, Code.TypeIdent.BValue, tmp, None) - } - ) - fnBody = Code.block( - declTmps, - Code.DeclareVar(Nil, Code.TypeIdent.Bool, cond, Some(Code.TrueLit)), - Code.DeclareVar(Nil, Code.TypeIdent.BValue, res, None), - Code.While(cond, whileBody), - Code.Return(Some(res)) - ) - argParams = argParamsTemps.map(_._1) - allArgs = - if (captures.isEmpty) argParams - else { - Code.Param(Code.TypeIdent.BValue.ptr, slotsArgName) :: argParams - } - } yield Code.DeclareFn(Nil, Code.TypeIdent.BValue, fnName, allArgs.toList, Some(fnBody)) - } - } } def renderTop(p: PackageName, b: Bindable, expr: Expr): T[Unit] = diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala index 988e9464f..61944d560 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala @@ -498,7 +498,7 @@ object Code { // These are the highest priority, so safe to not use a parens def unapply(e: Expression): Option[Expression] = e match { - case noPar @ (Ident(_) | Apply(_, _) | Select(_, _) | Bracket(_, _)) => Some(noPar) + case noPar @ (Ident(_) | Apply(_, _) | Select(_, _) | Bracket(_, _) | IntLiteral(_)) => Some(noPar) case _ => None } } diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala index 94a3b232b..126899c34 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala @@ -421,61 +421,6 @@ object PythonGen { loop(initBody) } - - // these are always recursive so we can use def to define them - def buildLoop( - selfName: Ident, - fnMutArgs: NonEmptyList[(Ident, Ident)], - body: ValueLike - ): Env[Statement] = { - /* - * bodyUpdate = body except App(foo, args) is replaced with - * reseting the inputs, and setting cont to True and having - * the value () - * - * def foo(a)(b)(c): - * cont = True - * res = () - * while cont: - * cont = False - * res = bodyUpdate - * return res - */ - val fnArgs = fnMutArgs.map(_._1) - val mutArgs = fnMutArgs.map(_._2) - - def assignMut(cont: Code.Ident)(args: List[Expression]): Statement = { - // do the replacement in one atomic go. otherwise - // we could mutate a variable a later expression depends on - // some times we generate code that does x = x, remove those cases - val (left, right) = - mutArgs.toList.zip(args).filter { case (x, y) => x != y }.unzip - - Code.block( - cont := Const.True, - if (left.isEmpty) Pass - else if (left.lengthCompare(1) == 0) { - left.head := right.head - } else { - (MakeTuple(left) := MakeTuple(right)) - } - ) - } - - for { - cont <- Env.newAssignableVar - ac = assignMut(cont)(fnArgs.toList) - res <- Env.newAssignableVar - ar = (res := Code.Const.Unit) - body1 <- replaceTailCallWithAssign(selfName, mutArgs.length, body)( - assignMut(cont) - ) - setRes = res := body1 - loop = While(cont, (cont := false) +: setRes) - newBody = (ac +: ar +: loop).withValue(res) - } yield makeDef(selfName, fnArgs, newBody) - } - } // we escape by prefixing by three underscores, ___ and n (for name) @@ -1708,34 +1653,13 @@ object PythonGen { .withValue(res) } - // if expr is a LoopFn or Lambda handle it + // if expr is a Lambda handle it def topFn( name: Code.Ident, expr: FnExpr, slotName: Option[Code.Ident] ): Env[Statement] = expr match { - case LoopFn(captures, _, args, b) => - // note, name is already bound - // args can use topFn - val boundA = args.traverse(Env.topLevelName) - val subsA = args.traverse { a => - for { - mut <- Env.newAssignableVar - _ <- Env.subs(a, mut) - } yield (a, mut) - } - - for { - as <- boundA - subs <- subsA - subs1 = as.zipWith(subs) { case (b, (_, m)) => (b, m) } - (binds, body) <- makeSlots(captures, slotName)(loop(b, _)) - loopRes <- Env.buildLoop(name, subs1, body) - // we have bound this name twice, once for the top and once for substitution - _ <- subs.traverse_ { case (a, _) => Env.unbind(a) } - } yield Code.blockFromList(binds.toList ::: loopRes :: Nil) - case Lambda(captures, _, args, body) => // we can ignore name because python already allows recursion // we can use topLevelName on makeDefs since they are already @@ -1744,13 +1668,13 @@ object PythonGen { args.traverse(Env.topLevelName(_)), makeSlots(captures, slotName)(loop(body, _)) ) - .mapN { case (as, (slots, body)) => - Code.blockFromList( - slots.toList ::: - Env.makeDef(name, as, body) :: - Nil - ) - } + .mapN { case (as, (slots, body)) => + Code.blockFromList( + slots.toList ::: + Env.makeDef(name, as, body) :: + Nil + ) + } } def makeSlots[A](captures: List[Expr], slotName: Option[Code.Ident])( @@ -1788,33 +1712,19 @@ object PythonGen { block = Code.blockFromList(prefix.toList ::: defn :: Nil) } yield block.withValue(defName) } - case LoopFn(captures, thisName, args, body) => - // note, thisName is already bound because LoopFn - // is a lambda, not a def - - // we can use topLeft for arg names - val boundA = args.traverse(Env.topLevelName) - val subsA = args.traverse { a => - for { - mut <- Env.newAssignableVar - _ <- Env.subs(a, mut) - } yield (a, mut) - } - - for { - nameI <- Env.bind(thisName) - as <- boundA - subs <- subsA - (prefix, body) <- makeSlots(captures, slotName)(loop(body, _)) - subs1 = as.zipWith(subs) { case (b, (_, m)) => (b, m) } - loopRes <- Env.buildLoop(nameI, subs1, body) - // we have bound the args twice: once as args, once as interal muts - _ <- subs.traverse_ { case (a, _) => Env.unbind(a) } - _ <- Env.unbind(thisName) - } yield Code - .blockFromList(prefix.toList :+ loopRes) - .withValue(nameI) - + case WhileExpr(cond, effect, res) => + (boolExpr(cond, slotName), loop(effect, slotName), loop(res, slotName), Env.newAssignableVar) + .mapN { (cond, effect, res, c) => + Code.block( + c := cond, + Code.While(c, + Code.block( + Code.always(effect), + c := cond + ) + ) + ).withValue(res) + } case PredefExternal((fn, arity)) => // make a lambda PredefExternal.makeLambda(arity)(fn) @@ -1926,6 +1836,19 @@ object PythonGen { Env.ifElse(ifs, elseV) }.flatten + case Always.SetChain(setmuts, result) => + ( + setmuts.traverse { case (LocalAnonMut(mut), v) => + Env.nameForAnon(mut).product(loop(v, slotName)) + }, loop(result, slotName) + ).mapN { (assigns, result) => + Code.blockFromList( + assigns.toList.map { case (mut, v) => + mut := v + } + ) + .withValue(result) + } case Always(cond, expr) => (boolExpr(cond, slotName).map(Code.always), loop(expr, slotName)) .mapN(_.withValue(_)) diff --git a/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala b/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala index 85ffd1a10..bf87ccba7 100644 --- a/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala @@ -63,7 +63,7 @@ BValue ___bsts_g_Bosatsu_l_Predef_l_build__List(BValue __bsts_b_fn0) { assertPredefFns("foldr_List")("""#include "bosatsu_runtime.h" BValue __bsts_t_closure__loop0(BValue* __bstsi_slot, BValue __bsts_b_list1) { - if (get_variant(__bsts_b_list1) == (0)) { + if (get_variant(__bsts_b_list1) == 0) { return __bstsi_slot[0]; } else { @@ -89,43 +89,49 @@ BValue ___bsts_g_Bosatsu_l_Predef_l_foldr__List(BValue __bsts_b_list0, test("check foldLeft and reverse_concat") { assertPredefFns("foldLeft", "reverse_concat")("""#include "bosatsu_runtime.h" -BValue __bsts_t_closure__loop0(BValue* __bstsi_slot, +BValue __bsts_t_closure0(BValue* __bstsi_slot, BValue __bsts_b_lst1, BValue __bsts_b_item1) { - BValue __bsts_l_loop__temp3; - BValue __bsts_l_loop__temp4; - _Bool __bsts_l_cond1 = 1; - BValue __bsts_l_res2; + BValue __bsts_a_0; + BValue __bsts_a_1; + BValue __bsts_a_3; + BValue __bsts_a_5; + __bsts_a_3 = __bsts_b_lst1; + __bsts_a_5 = __bsts_b_item1; + __bsts_a_0 = alloc_enum0(1); + _Bool __bsts_l_cond1; + __bsts_l_cond1 = get_variant(__bsts_a_0) == 1; while (__bsts_l_cond1) { - if (get_variant(__bsts_b_lst1) == (0)) { - __bsts_l_cond1 = 0; - __bsts_l_res2 = __bsts_b_item1; + if (get_variant(__bsts_a_3) == 0) { + __bsts_a_0 = alloc_enum0(0); + __bsts_a_1 = __bsts_a_5; } else { - BValue __bsts_b_head0 = get_enum_index(__bsts_b_lst1, 0); - BValue __bsts_b_tail0 = get_enum_index(__bsts_b_lst1, 1); - __bsts_l_loop__temp3 = __bsts_b_tail0; - __bsts_l_loop__temp4 = call_fn2(__bstsi_slot[0], - __bsts_b_item1, + BValue __bsts_b_head0 = get_enum_index(__bsts_a_3, 0); + BValue __bsts_b_tail0 = get_enum_index(__bsts_a_3, 1); + BValue __bsts_a_2 = __bsts_b_tail0; + BValue __bsts_a_4 = call_fn2(__bstsi_slot[0], + __bsts_a_5, __bsts_b_head0); - __bsts_b_lst1 = __bsts_l_loop__temp3; - __bsts_b_item1 = __bsts_l_loop__temp4; + __bsts_a_3 = __bsts_a_2; + __bsts_a_5 = __bsts_a_4; } + __bsts_l_cond1 = get_variant(__bsts_a_0) == 1; } - return __bsts_l_res2; + return __bsts_a_1; } BValue ___bsts_g_Bosatsu_l_Predef_l_foldLeft(BValue __bsts_b_lst0, BValue __bsts_b_item0, BValue __bsts_b_fn0) { - BValue __bsts_l_captures5[1] = { __bsts_b_fn0 }; + BValue __bsts_l_captures2[1] = { __bsts_b_fn0 }; BValue __bsts_b_loop0 = alloc_closure2(1, - __bsts_l_captures5, - __bsts_t_closure__loop0); + __bsts_l_captures2, + __bsts_t_closure0); return call_fn2(__bsts_b_loop0, __bsts_b_lst0, __bsts_b_item0); } -BValue __bsts_t_lambda6(BValue __bsts_b_tail0, BValue __bsts_b_h0) { +BValue __bsts_t_lambda3(BValue __bsts_b_tail0, BValue __bsts_b_h0) { return alloc_enum2(1, __bsts_b_h0, __bsts_b_tail0); } @@ -133,7 +139,7 @@ BValue ___bsts_g_Bosatsu_l_Predef_l_reverse__concat(BValue __bsts_b_front0, BValue __bsts_b_back0) { return ___bsts_g_Bosatsu_l_Predef_l_foldLeft(__bsts_b_front0, __bsts_b_back0, - alloc_boxed_pure_fn2(__bsts_t_lambda6)); + alloc_boxed_pure_fn2(__bsts_t_lambda3)); }""") }