diff --git a/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala b/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala index f98d71e3ff67..edc1d98dc60a 100644 --- a/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala +++ b/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala @@ -8,6 +8,7 @@ import MegaPhase.* import Symbols.*, Contexts.*, Types.*, Decorators.* import StdNames.nme import ast.TreeTypeMap +import Constants.Constant import scala.collection.mutable.ListBuffer @@ -133,7 +134,7 @@ object BetaReduce: else if arg.tpe.dealias.isInstanceOf[ConstantType] then arg.tpe.dealias else arg.tpe.widen val binding = ValDef(newSymbol(ctx.owner, param.name, flags, tpe, coord = arg.span), arg).withSpan(arg.span) - if !(tpe.isInstanceOf[ConstantType] && isPureExpr(arg)) then + if !((tpe.isInstanceOf[ConstantType] || tpe.derivesFrom(defn.UnitClass)) && isPureExpr(arg)) then bindings += binding binding.symbol @@ -147,6 +148,7 @@ object BetaReduce: val expansion1 = new TreeMap { override def transform(tree: Tree)(using Context) = tree.tpe.widenTermRefExpr match case ConstantType(const) if isPureExpr(tree) => cpy.Literal(tree)(const) + case tpe: TypeRef if tpe.derivesFrom(defn.UnitClass) && isPureExpr(tree) => cpy.Literal(tree)(Constant(())) case _ => super.transform(tree) }.transform(expansion) diff --git a/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala b/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala index 6173842e9ad1..fcbc738f2934 100644 --- a/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala +++ b/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala @@ -765,4 +765,24 @@ class InlineBytecodeTests extends DottyBytecodeTest { diffInstructions(instructions1, instructions2)) } } + + @Test def beta_reduce_elide_unit_binding = { + val source = """class Test: + | def test = ((u: Unit) => u).apply(()) + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Test.class", directory = false).input + val clsNode = loadClassNode(clsIn) + + val fun = getMethod(clsNode, "test") + val instructions = instructionsFromMethod(fun) + val expected = List(Op(RETURN)) + + assert(instructions == expected, + "`i was not properly beta-reduced in `test`\n" + diffInstructions(instructions, expected)) + + } + } + }