From e0ac0441961e7fade621606b542755ec4e380ce0 Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Mon, 21 Nov 2022 17:21:44 +0100 Subject: [PATCH] Reuse beta-reduction logic from BetaReduce Fixes a bug when beta-reducing inlined code. In some situations the beta-reduction did not bind mutable variables. --- community-build/community-projects/spire | 2 +- .../tools/dotc/inlines/InlineReducer.scala | 22 ++++-------------- .../tools/dotc/transform/BetaReduce.scala | 20 ++++++++++------ .../backend/jvm/InlineBytecodeTests.scala | 10 +------- tests/run/i16390.scala | 23 +++++++++++++++++++ 5 files changed, 43 insertions(+), 34 deletions(-) create mode 100644 tests/run/i16390.scala diff --git a/community-build/community-projects/spire b/community-build/community-projects/spire index 7f630c0209e3..99ea909d086e 160000 --- a/community-build/community-projects/spire +++ b/community-build/community-projects/spire @@ -1 +1 @@ -Subproject commit 7f630c0209e327bdc782ade2210d8e4b916fddcc +Subproject commit 99ea909d086e28e85dbdf8aa78ef0a83bf873405 diff --git a/compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala b/compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala index b85454b8ba35..42e86b71eff8 100644 --- a/compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala +++ b/compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala @@ -12,6 +12,8 @@ import NameKinds.{InlineAccessorName, InlineBinderName, InlineScrutineeName} import config.Printers.inlining import util.SimpleIdentityMap +import dotty.tools.dotc.transform.BetaReduce + import collection.mutable /** A utility class offering methods for rewriting inlined code */ @@ -163,26 +165,12 @@ class InlineReducer(inliner: Inliner)(using Context): */ def betaReduce(tree: Tree)(using Context): Tree = tree match { case Apply(Select(cl, nme.apply), args) if defn.isFunctionType(cl.tpe) => - val bindingsBuf = new DefBuffer + val bindingsBuf = new mutable.ListBuffer[ValDef] def recur(cl: Tree): Option[Tree] = cl match case Block((ddef : DefDef) :: Nil, closure: Closure) if ddef.symbol == closure.meth.symbol => ddef.tpe.widen match case mt: MethodType if ddef.paramss.head.length == args.length => - val argSyms = mt.paramNames.lazyZip(mt.paramInfos).lazyZip(args).map { (name, paramtp, arg) => - arg.tpe.dealias match { - case ref @ TermRef(NoPrefix, _) => ref.symbol - case _ => - paramBindingDef(name, paramtp, arg, bindingsBuf)( - using ctx.withSource(cl.source) - ).symbol - } - } - val expander = new TreeTypeMap( - oldOwners = ddef.symbol :: Nil, - newOwners = ctx.owner :: Nil, - substFrom = ddef.paramss.head.map(_.symbol), - substTo = argSyms) - Some(expander.transform(ddef.rhs)) + Some(BetaReduce.reduceApplication(ddef, args, bindingsBuf)) case _ => None case Block(stats, expr) if stats.forall(isPureBinding) => recur(expr).map(cpy.Block(cl)(stats, _)) @@ -193,7 +181,7 @@ class InlineReducer(inliner: Inliner)(using Context): case _ => None recur(cl) match case Some(reduced) => - Block(bindingsBuf.toList, reduced).withSpan(tree.span) + seq(bindingsBuf.result(), reduced).withSpan(tree.span) case None => tree case _ => diff --git a/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala b/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala index 90c0207ebb6d..7ac3dc972ad1 100644 --- a/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala +++ b/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala @@ -9,6 +9,8 @@ import Symbols._, Contexts._, Types._, Decorators._ import StdNames.nme import ast.TreeTypeMap +import scala.collection.mutable.ListBuffer + /** Rewrite an application * * (((x1, ..., xn) => b): T)(y1, ..., yn) @@ -70,9 +72,15 @@ object BetaReduce: original end apply - /** Beta-reduces a call to `ddef` with arguments `argSyms` */ + /** Beta-reduces a call to `ddef` with arguments `args` */ def apply(ddef: DefDef, args: List[Tree])(using Context) = - val bindings = List.newBuilder[ValDef] + val bindings = new ListBuffer[ValDef]() + val expansion1 = reduceApplication(ddef, args, bindings) + val bindings1 = bindings.result() + seq(bindings1, expansion1) + + /** Beta-reduces a call to `ddef` with arguments `args` and registers new bindings */ + def reduceApplication(ddef: DefDef, args: List[Tree], bindings: ListBuffer[ValDef])(using Context): Tree = val vparams = ddef.termParamss.iterator.flatten.toList assert(args.hasSameLengthAs(vparams)) val argSyms = @@ -84,7 +92,8 @@ object BetaReduce: val flags = Synthetic | (param.symbol.flags & Erased) val tpe = if arg.tpe.dealias.isInstanceOf[ConstantType] then arg.tpe.dealias else arg.tpe.widen val binding = ValDef(newSymbol(ctx.owner, param.name, flags, tpe, coord = arg.span), arg).withSpan(arg.span) - bindings += binding + if !(tpe.isInstanceOf[ConstantType] && isPureExpr(arg)) then + bindings += binding binding.symbol val expansion = TreeTypeMap( @@ -99,8 +108,5 @@ object BetaReduce: case ConstantType(const) if isPureExpr(tree) => cpy.Literal(tree)(const) case _ => super.transform(tree) }.transform(expansion) - val bindings1 = - bindings.result().filterNot(vdef => vdef.tpt.tpe.isInstanceOf[ConstantType] && isPureExpr(vdef.rhs)) - seq(bindings1, expansion1) - end apply + expansion1 diff --git a/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala b/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala index a492e8785afc..33e898718b33 100644 --- a/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala +++ b/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala @@ -600,15 +600,7 @@ class InlineBytecodeTests extends DottyBytecodeTest { val instructions = instructionsFromMethod(fun) val expected = // TODO room for constant folding List( - Op(ICONST_2), - VarOp(ISTORE, 1), - Op(ICONST_1), - VarOp(ISTORE, 2), - Op(ICONST_2), - VarOp(ILOAD, 2), - Op(IADD), - Op(ICONST_3), - Op(IADD), + IntOp(BIPUSH, 6), Op(IRETURN), ) assert(instructions == expected, diff --git a/tests/run/i16390.scala b/tests/run/i16390.scala new file mode 100644 index 000000000000..067603fea00a --- /dev/null +++ b/tests/run/i16390.scala @@ -0,0 +1,23 @@ +inline def cfor(inline body: Int => Unit): Unit = + var index = 0 + while index < 3 do + body(index) + index = index + 1 + +@main def Test = + assert(test1() == test2(), (test1(), test2())) + +def test1() = + val b = collection.mutable.ArrayBuffer.empty[() => Int] + cfor { x => + b += (() => x) + } + b.map(_.apply()).toList + +def test2() = + val b = collection.mutable.ArrayBuffer.empty[() => Int] + var index = 0 + while index < 3 do + ((x: Int) => b += (() => x)).apply(index) + index = index + 1 + b.map(_.apply()).toList \ No newline at end of file