Skip to content

Commit

Permalink
Reuse beta-reduction logic from BetaReduce
Browse files Browse the repository at this point in the history
Fixes a bug with when betareducing inlined code. In some situations
the beta-reduction did not bind mutable variables.
  • Loading branch information
nicolasstucki committed Nov 25, 2022
1 parent c9ace66 commit 2643ea2
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 34 deletions.
2 changes: 1 addition & 1 deletion community-build/community-projects/spire
22 changes: 5 additions & 17 deletions compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import NameKinds.{InlineAccessorName, InlineBinderName, InlineScrutineeName}
import config.Printers.inlining
import util.SimpleIdentityMap

import dotty.tools.dotc.transform.BetaReduce

import collection.mutable

/** A utility class offering methods for rewriting inlined code */
Expand Down Expand Up @@ -163,26 +165,12 @@ class InlineReducer(inliner: Inliner)(using Context):
*/
def betaReduce(tree: Tree)(using Context): Tree = tree match {
case Apply(Select(cl, nme.apply), args) if defn.isFunctionType(cl.tpe) =>
val bindingsBuf = new DefBuffer
val bindingsBuf = new mutable.ListBuffer[ValDef]
def recur(cl: Tree): Option[Tree] = cl match
case Block((ddef : DefDef) :: Nil, closure: Closure) if ddef.symbol == closure.meth.symbol =>
ddef.tpe.widen match
case mt: MethodType if ddef.paramss.head.length == args.length =>
val argSyms = mt.paramNames.lazyZip(mt.paramInfos).lazyZip(args).map { (name, paramtp, arg) =>
arg.tpe.dealias match {
case ref @ TermRef(NoPrefix, _) => ref.symbol
case _ =>
paramBindingDef(name, paramtp, arg, bindingsBuf)(
using ctx.withSource(cl.source)
).symbol
}
}
val expander = new TreeTypeMap(
oldOwners = ddef.symbol :: Nil,
newOwners = ctx.owner :: Nil,
substFrom = ddef.paramss.head.map(_.symbol),
substTo = argSyms)
Some(expander.transform(ddef.rhs))
Some(BetaReduce.reduceApplication(ddef, args, bindingsBuf))
case _ => None
case Block(stats, expr) if stats.forall(isPureBinding) =>
recur(expr).map(cpy.Block(cl)(stats, _))
Expand All @@ -193,7 +181,7 @@ class InlineReducer(inliner: Inliner)(using Context):
case _ => None
recur(cl) match
case Some(reduced) =>
Block(bindingsBuf.toList, reduced).withSpan(tree.span)
seq(bindingsBuf.result(), reduced).withSpan(tree.span)
case None =>
tree
case _ =>
Expand Down
20 changes: 13 additions & 7 deletions compiler/src/dotty/tools/dotc/transform/BetaReduce.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import Symbols._, Contexts._, Types._, Decorators._
import StdNames.nme
import ast.TreeTypeMap

import scala.collection.mutable.ListBuffer

/** Rewrite an application
*
* (((x1, ..., xn) => b): T)(y1, ..., yn)
Expand Down Expand Up @@ -70,9 +72,15 @@ object BetaReduce:
original
end apply

/** Beta-reduces a call to `ddef` with arguments `argSyms` */
/** Beta-reduces a call to `ddef` with arguments `args` */
def apply(ddef: DefDef, args: List[Tree])(using Context) =
val bindings = List.newBuilder[ValDef]
val bindings = new ListBuffer[ValDef]()
val expansion1 = reduceApplication(ddef, args, bindings)
val bindings1 = bindings.result()
seq(bindings1, expansion1)

/** Beta-reduces a call to `ddef` with arguments `args` and registers new bindings */
def reduceApplication(ddef: DefDef, args: List[Tree], bindings: ListBuffer[ValDef])(using Context): Tree =
val vparams = ddef.termParamss.iterator.flatten.toList
assert(args.hasSameLengthAs(vparams))
val argSyms =
Expand All @@ -84,7 +92,8 @@ object BetaReduce:
val flags = Synthetic | (param.symbol.flags & Erased)
val tpe = if arg.tpe.dealias.isInstanceOf[ConstantType] then arg.tpe.dealias else arg.tpe.widen
val binding = ValDef(newSymbol(ctx.owner, param.name, flags, tpe, coord = arg.span), arg).withSpan(arg.span)
bindings += binding
if !(tpe.isInstanceOf[ConstantType] && isPureExpr(arg)) then
bindings += binding
binding.symbol

val expansion = TreeTypeMap(
Expand All @@ -99,8 +108,5 @@ object BetaReduce:
case ConstantType(const) if isPureExpr(tree) => cpy.Literal(tree)(const)
case _ => super.transform(tree)
}.transform(expansion)
val bindings1 =
bindings.result().filterNot(vdef => vdef.tpt.tpe.isInstanceOf[ConstantType] && isPureExpr(vdef.rhs))

seq(bindings1, expansion1)
end apply
expansion1
10 changes: 1 addition & 9 deletions compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -600,15 +600,7 @@ class InlineBytecodeTests extends DottyBytecodeTest {
val instructions = instructionsFromMethod(fun)
val expected = // TODO room for constant folding
List(
Op(ICONST_2),
VarOp(ISTORE, 1),
Op(ICONST_1),
VarOp(ISTORE, 2),
Op(ICONST_2),
VarOp(ILOAD, 2),
Op(IADD),
Op(ICONST_3),
Op(IADD),
IntOp(BIPUSH, 6),
Op(IRETURN),
)
assert(instructions == expected,
Expand Down
23 changes: 23 additions & 0 deletions tests/run/i16390.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
inline def cfor(inline body: Int => Unit): Unit =
var index = 0
while index < 3 do
body(index)
index = index + 1

@main def Test =
assert(test1() == test2(), (test1(), test2()))

def test1() =
val b = collection.mutable.ArrayBuffer.empty[() => Int]
cfor { x =>
b += (() => x)
}
b.map(_.apply()).toList

def test2() =
val b = collection.mutable.ArrayBuffer.empty[() => Int]
var index = 0
while index < 3 do
((x: Int) => b += (() => x)).apply(index)
index = index + 1
b.map(_.apply()).toList

0 comments on commit 2643ea2

Please sign in to comment.