From e2b2e9a3ccc44aabd22d6b53d29a2affb3e85da3 Mon Sep 17 00:00:00 2001 From: Oscar Boykin Date: Wed, 11 Dec 2024 19:12:06 -1000 Subject: [PATCH] fix inlining for small lambdas --- .../bosatsu/codegen/clang/ClangGenTest.scala | 2 +- .../scala/org/bykn/bosatsu/TypedExpr.scala | 1 + .../bykn/bosatsu/TypedExprNormalization.scala | 41 ++++++++----- .../bykn/bosatsu/codegen/clang/ClangGen.scala | 51 +++++++++------- .../bosatsu/codegen/clang/ClangGenTest.scala | 60 +++++++++---------- 5 files changed, 89 insertions(+), 66 deletions(-) diff --git a/cli/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala b/cli/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala index 5498f543b..65b0a38f4 100644 --- a/cli/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala +++ b/cli/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala @@ -45,7 +45,7 @@ class ClangGenTest extends munit.FunSuite { To inspect the code, change the hash, and it will print the code out */ testFilesCompilesToHash("test_workspace/Ackermann.bosatsu")( - "260c81bc79b6232a3f174cb9afc04143" + "01b16c11c1597e46371d356111276af5" ) } } \ No newline at end of file diff --git a/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala b/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala index 06b648973..8ee2bc0b4 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala @@ -220,6 +220,7 @@ object TypedExpr { tag: T ) extends TypedExpr[T] { + // This makes sure the args don't shadow any of the items in freeSet def unshadow(freeSet: Set[Bindable]): AnnotatedLambda[T] = { val clashIdent = if (freeSet.isEmpty) Set.empty[Bindable] diff --git a/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala b/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala index 2435ca8e4..ac842082a 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala @@ -120,8 +120,10 @@ object TypedExprNormalization { if (!tpe.sameAs(expr.getType)) Annotation(expr, tpe) else expr private def appLambda[A](f1: AnnotatedLambda[A], args: NonEmptyList[TypedExpr[A]], tpe: Type, tag: A): TypedExpr[A] = { - val AnnotatedLambda(lamArgs, expr, _) = f1 - // (y -> z)(x) = let y = x in z + val freesInArgs = TypedExpr.freeVarsSet(args.toList) + val AnnotatedLambda(lamArgs, expr, _) = f1.unshadow(freesInArgs) + // Now that we certainly don't shadow we can convert this: + // ((y1, y2, ..., yn) -> z)(x1, x2, ..., xn) = let y1 = x1 in let y2 = x2 in ... z val lets = lamArgs.zip(args).map { case ((n, ltpe), arg) => (n, setType(arg, ltpe)) } @@ -237,18 +239,21 @@ object TypedExprNormalization { // or remove inlining from here unless it can hever hurt and put inlining at a // different phase. val fn1 = AnnotatedLambda(args1, body, ftag) - val e2 = App(fn1, aargs, resT, atag) - if (e1 != e2) { + //val e2 = App(fn1, aargs, resT, atag) + val applied = appLambda[A](fn1, aargs, resT, atag) + //val e3 = normalize1(None, applied, bodyScope, typeEnv).get + //if (e1 != e2) { // in this case we have inlined, vs there already being // a literal lambda being applied // by normalizing this, it will become a let binding - val e3 = normalize1(None, e2, bodyScope, typeEnv).get + //val e3 = normalize1(None, e2, bodyScope, typeEnv).get - if (e3.size <= expr.size) { + //if (e3.size <= expr.size) { + if (true) { // we haven't made the code larger normalize1( namerec, - AnnotatedLambda(lamArgs, e3, tag), + AnnotatedLambda(lamArgs, applied, tag), scope, typeEnv ) @@ -257,10 +262,12 @@ object TypedExprNormalization { if ((e1 eq expr) && (lamArgs === lamArgs0)) None else Some(AnnotatedLambda(lamArgs, e1, tag)) } + /* } else { if ((e1 eq expr) && (lamArgs === lamArgs0)) None else Some(AnnotatedLambda(lamArgs, e1, tag)) } + */ case Let(arg1, ex, in, rec, tag1) if doesntUseArgs(ex) && doesntShadow(arg1) => // x -> @@ -336,19 +343,24 @@ object TypedExprNormalization { lazy val a1 = ListUtil.mapConserveNel(args) { a => normalize1(None, a, scope, typeEnv).get } + val ws = Impl.WithScope(scope, ev.substituteCo[TypeEnv](typeEnv)) f1 match { // TODO: what if f1: Generic(_, AnnotatedLambda(_, _, _)) // we should still be able ton convert this to a let by // instantiating to the right args + case ws.ResolveToLambda(Nil, args1, body, ftag) => + val lam = AnnotatedLambda(args1, body, ftag) + val l = appLambda[A](lam, args, tpe, tag) + normalize1(namerec, l, scope, typeEnv) case lam @ AnnotatedLambda(_, _, _) => - val l = appLambda[A](lam, a1, tpe, tag) + val l = appLambda[A](lam, args, tpe, tag) normalize1(namerec, l, scope, typeEnv) case Let(arg1, ex, in, rec, tag1) if a1.forall(_.notFree(arg1)) => // (app (let x y z) w) == (let x y (app z w)) if w does not have x free normalize1( namerec, - Let(arg1, ex, App(in, a1, tpe, tag), rec, tag1), + Let(arg1, ex, App(in, args, tpe, tag), rec, tag1), scope, typeEnv ) @@ -356,9 +368,6 @@ object TypedExprNormalization { if ((f1 eq fn) && (tpe == tpe0) && (a1 eq args)) None else Some(App(f1, a1, tpe, tag)) } - case Let(arg, ex, Local(arg1, _, _), RecursionKind.NonRecursive, _) if arg1 === arg => - // (let x y x) == y - normalize1(namerec, ex, scope, typeEnv) case Let(arg, ex, in, rec, tag) => // note, Infer has already checked // to make sure rec is accurate @@ -548,6 +557,9 @@ object TypedExprNormalization { } object ResolveToLambda { + // this is a parameter that we can tune to change inlining + val MaxSize = 10 + // TODO: don't we need to worry about the type environment for locals? They // can also capture type references to outer Generics def unapply(te: TypedExpr[A]): Option[ @@ -585,7 +597,8 @@ object TypedExprNormalization { Some((Nil, args, expr, ltag)) case Global(p, n: Bindable, _, _) => scope.getGlobal(p, n).flatMap { - case (RecursionKind.NonRecursive, te, scope1) => + // + case (RecursionKind.NonRecursive, te, scope1) if te.size < MaxSize => val s1 = WithScope(scope1, typeEnv) te match { case s1.ResolveToLambda(frees, args, expr, ltag) => @@ -607,7 +620,7 @@ object TypedExprNormalization { } case Local(nm, _, _) => scope.getLocal(nm).flatMap { - case (RecursionKind.NonRecursive, te, scope1) => + case (RecursionKind.NonRecursive, te, scope1) if te.size < MaxSize => val s1 = WithScope(scope1, typeEnv) te match { case s1.ResolveToLambda(frees, args, expr, ltag) => diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala index 219381a90..77f1e0a49 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala @@ -1026,7 +1026,7 @@ object ClangGen { } def fnStatement(fnName: Code.Ident, fn: FnExpr): T[Code.Statement] = - fn match { + fn match { case Lambda(captures, name, args, expr) => val body = innerToValue(expr).map(Code.returnValue(_)) val body1 = name match { @@ -1119,6 +1119,7 @@ object ClangGen { includes: Chain[Code.Include], stmts: Chain[Code.Statement], currentTop: Option[(PackageName, Bindable)], + bindCount: Int, binds: Map[Bindable, NonEmptyList[Either[((Code.Ident, Boolean, Int), Int), Int]]], counter: Long, identCache: Map[Expr, Code.Ident] @@ -1139,7 +1140,7 @@ object ClangGen { List(Code.Include.quote("bosatsu_runtime.h")) State(allValues, externals, Set.empty ++ defaultIncludes, Chain.fromSeq(defaultIncludes), Chain.empty, - None, Map.empty, 0L, Map.empty + None, 0, Map.empty, 0L, Map.empty ) } } @@ -1168,35 +1169,43 @@ object ClangGen { Eval.now(Right((s, a))) ) + def update[A](fn: State => (State, A)): T[A] = + StateT(s => EitherT[Eval, Error, (State, A)](Eval.now(Right(fn(s))))) + + def tryUpdate[A](fn: State => Either[Error, (State, A)]): T[A] = + StateT(s => EitherT[Eval, Error, (State, A)](Eval.now(fn(s)))) + + def tryRead[A](fn: State => Either[Error, A]): T[A] = + StateT(s => EitherT[Eval, Error, (State, A)](Eval.now(fn(s).map((s, _))))) + def globalIdent(pn: PackageName, bn: Bindable): T[Code.Ident] = - StateT { s => + tryUpdate { s => s.externals(pn, bn) match { case Some((incl, ident, _)) => // TODO: suspect that we are ignoring arity here val withIncl = s.include(incl) - result(withIncl, ident) + Right((withIncl, ident)) case None => val key = (pn, bn) s.allValues.get(key) match { - case Some((_, ident)) => result(s, ident) - case None => errorRes(Error.UnknownValue(pn, bn)) + case Some((_, ident)) => Right((s, ident)) + case None => Left(Error.UnknownValue(pn, bn)) } } } def bind[A](bn: Bindable)(in: T[A]): T[A] = { - val init: T[Unit] = StateT { s => + val init: T[Unit] = update { s => + val cnt = s.bindCount val v = s.binds.get(bn) match { - case None => NonEmptyList.one(Right(0)) - case Some(items @ NonEmptyList(Right(idx), _)) => - Right(idx + 1) :: items - case Some(items @ NonEmptyList(Left((_, idx)), _)) => - Right(idx + 1) :: items + case None => NonEmptyList.one(Right(cnt)) + case Some(items) => + Right(cnt) :: items } - result(s.copy(binds = s.binds.updated(bn, v)), ()) + (s.copy(bindCount = cnt + 1, binds = s.binds.updated(bn, v)), ()) } - val uninit: T[Unit] = StateT { s => + val uninit: T[Unit] = update { s => s.binds.get(bn) match { case Some(NonEmptyList(_, tail)) => val s1 = NonEmptyList.fromList(tail) match { @@ -1205,7 +1214,7 @@ object ClangGen { case Some(prior) => s.copy(binds = s.binds.updated(bn, prior)) } - result(s1, ()) + (s1, ()) case None => sys.error(s"bindable $bn no longer in $s") } } @@ -1217,18 +1226,18 @@ object ClangGen { } yield a } def getBinding(bn: Bindable): T[Code.Ident] = - StateT { s => + tryRead { s => s.binds.get(bn) match { case Some(stack) => stack.head match { case Right(idx) => - result(s, Code.Ident(Idents.escape("__bsts_b_", bn.asString + idx.toString))) + Right(Code.Ident(Idents.escape("__bsts_b_", bn.asString + idx.toString))) case Left(((ident, _, _), _)) => // TODO: suspicious to ignore isClosure and arity here // probably need to conv - result(s, ident) + Right(ident) } - case None => errorRes(Error.Unbound(bn, s.currentTop)) + case None => Left(Error.Unbound(bn, s.currentTop)) } } def bindAnon[A](idx: Long)(in: T[A]): T[A] = @@ -1321,9 +1330,9 @@ object ClangGen { def inTop[A](p: PackageName, bn: Bindable)(ta: T[A]): T[A] = for { - _ <- StateT { (s: State) => result(s.copy(currentTop = Some((p, bn))), ())} + initBindCount <- update { (s: State) => (s.copy(bindCount = 0, currentTop = Some((p, bn))), s.bindCount)} a <- ta - _ <- StateT { (s: State) => result(s.copy(currentTop = None), ()) } + _ <- update { (s: State) => (s.copy(currentTop = None, bindCount = initBindCount), ()) } } yield a val currentTop: T[Option[(PackageName, Bindable)]] = diff --git a/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala b/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala index 8f5129d21..ce834c042 100644 --- a/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala @@ -51,8 +51,8 @@ x = 1 #include #include "gc.h" -BValue __bsts_t_lambda0(BValue __bsts_b_a0, BValue __bsts_b_b0) { - return alloc_enum2(1, __bsts_b_a0, __bsts_b_b0); +BValue __bsts_t_lambda0(BValue __bsts_b_a1, BValue __bsts_b_b2) { + return alloc_enum2(1, __bsts_b_a1, __bsts_b_b2); } BValue ___bsts_g_Bosatsu_l_Predef_l_build__List(BValue __bsts_b_fn0) { @@ -75,27 +75,27 @@ int main(int argc, char** argv) { #include #include "gc.h" -BValue __bsts_t_closure__loop0(BValue* __bstsi_slot, BValue __bsts_b_list1) { - if (get_variant(__bsts_b_list1) == 0) { +BValue __bsts_t_closure__loop0(BValue* __bstsi_slot, BValue __bsts_b_list3) { + if (get_variant(__bsts_b_list3) == 0) { return __bstsi_slot[0]; } else { - BValue __bsts_b_h0 = get_enum_index(__bsts_b_list1, 0); - BValue __bsts_b_t0 = get_enum_index(__bsts_b_list1, 1); + BValue __bsts_b_h4 = get_enum_index(__bsts_b_list3, 0); + BValue __bsts_b_t5 = get_enum_index(__bsts_b_list3, 1); return call_fn2(__bstsi_slot[1], - __bsts_b_h0, - __bsts_t_closure__loop0(__bstsi_slot, __bsts_b_t0)); + __bsts_b_h4, + __bsts_t_closure__loop0(__bstsi_slot, __bsts_b_t5)); } } BValue ___bsts_g_Bosatsu_l_Predef_l_foldr__List(BValue __bsts_b_list0, - BValue __bsts_b_fn0, - BValue __bsts_b_acc0) { - BValue __bsts_l_captures1[2] = { __bsts_b_acc0, __bsts_b_fn0 }; - BValue __bsts_b_loop0 = alloc_closure1(2, + BValue __bsts_b_fn1, + BValue __bsts_b_acc2) { + BValue __bsts_l_captures1[2] = { __bsts_b_acc2, __bsts_b_fn1 }; + BValue __bsts_b_loop6 = alloc_closure1(2, __bsts_l_captures1, __bsts_t_closure__loop0); - return call_fn1(__bsts_b_loop0, __bsts_b_list0); + return call_fn1(__bsts_b_loop6, __bsts_b_list0); } int main(int argc, char** argv) { @@ -113,14 +113,14 @@ int main(int argc, char** argv) { #include "gc.h" BValue __bsts_t_closure0(BValue* __bstsi_slot, - BValue __bsts_b_lst1, - BValue __bsts_b_item1) { + BValue __bsts_b_lst3, + BValue __bsts_b_item4) { BValue __bsts_a_0; BValue __bsts_a_1; BValue __bsts_a_3; BValue __bsts_a_5; - __bsts_a_3 = __bsts_b_lst1; - __bsts_a_5 = __bsts_b_item1; + __bsts_a_3 = __bsts_b_lst3; + __bsts_a_5 = __bsts_b_item4; __bsts_a_0 = alloc_enum0(1); _Bool __bsts_l_cond1; __bsts_l_cond1 = get_variant_value(__bsts_a_0) == 1; @@ -130,12 +130,12 @@ BValue __bsts_t_closure0(BValue* __bstsi_slot, __bsts_a_1 = __bsts_a_5; } else { - BValue __bsts_b_head0 = get_enum_index(__bsts_a_3, 0); - BValue __bsts_b_tail0 = get_enum_index(__bsts_a_3, 1); - BValue __bsts_a_2 = __bsts_b_tail0; + BValue __bsts_b_head5 = get_enum_index(__bsts_a_3, 0); + BValue __bsts_b_tail6 = get_enum_index(__bsts_a_3, 1); + BValue __bsts_a_2 = __bsts_b_tail6; BValue __bsts_a_4 = call_fn2(__bstsi_slot[0], __bsts_a_5, - __bsts_b_head0); + __bsts_b_head5); __bsts_a_3 = __bsts_a_2; __bsts_a_5 = __bsts_a_4; } @@ -145,23 +145,23 @@ BValue __bsts_t_closure0(BValue* __bstsi_slot, } BValue ___bsts_g_Bosatsu_l_Predef_l_foldl__List(BValue __bsts_b_lst0, - BValue __bsts_b_item0, - BValue __bsts_b_fn0) { - BValue __bsts_l_captures2[1] = { __bsts_b_fn0 }; - BValue __bsts_b_loop0 = alloc_closure2(1, + BValue __bsts_b_item1, + BValue __bsts_b_fn2) { + BValue __bsts_l_captures2[1] = { __bsts_b_fn2 }; + BValue __bsts_b_loop7 = alloc_closure2(1, __bsts_l_captures2, __bsts_t_closure0); - return call_fn2(__bsts_b_loop0, __bsts_b_lst0, __bsts_b_item0); + return call_fn2(__bsts_b_loop7, __bsts_b_lst0, __bsts_b_item1); } -BValue __bsts_t_lambda3(BValue __bsts_b_tail0, BValue __bsts_b_h0) { - return alloc_enum2(1, __bsts_b_h0, __bsts_b_tail0); +BValue __bsts_t_lambda3(BValue __bsts_b_tail2, BValue __bsts_b_h3) { + return alloc_enum2(1, __bsts_b_h3, __bsts_b_tail2); } BValue ___bsts_g_Bosatsu_l_Predef_l_reverse__concat(BValue __bsts_b_front0, - BValue __bsts_b_back0) { + BValue __bsts_b_back1) { return ___bsts_g_Bosatsu_l_Predef_l_foldl__List(__bsts_b_front0, - __bsts_b_back0, + __bsts_b_back1, alloc_boxed_pure_fn2(__bsts_t_lambda3)); }