From 9b2c23ffcf5e0f068793f079eafa3fe2b05a4e1b Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Thu, 4 Apr 2024 10:28:27 +0200 Subject: [PATCH] Elide unit binding when beta-reducing See https://github.com/scala/scala3/discussions/20082#discussioncomment-9006501 --- .../tools/dotc/transform/BetaReduce.scala | 2 ++ .../backend/jvm/InlineBytecodeTests.scala | 20 +++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala b/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala index 653a5e17990f..7e6c1977359b 100644 --- a/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala +++ b/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala @@ -8,6 +8,7 @@ import MegaPhase.* import Symbols.*, Contexts.*, Types.*, Decorators.* import StdNames.nme import ast.TreeTypeMap +import Constants.Constant import scala.collection.mutable.ListBuffer @@ -131,6 +132,7 @@ object BetaReduce: val tpe = if arg.tpe.isBottomType then param.tpe.widenTermRefExpr else if arg.tpe.dealias.isInstanceOf[ConstantType] then arg.tpe.dealias + else if arg.tpe.dealias =:= defn.UnitType then ConstantType(Constant(())) else arg.tpe.widen val binding = ValDef(newSymbol(ctx.owner, param.name, flags, tpe, coord = arg.span), arg).withSpan(arg.span) if !(tpe.isInstanceOf[ConstantType] && isPureExpr(arg)) then diff --git a/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala b/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala index 6173842e9ad1..fcbc738f2934 100644 --- a/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala +++ b/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala @@ -765,4 +765,24 @@ class InlineBytecodeTests extends DottyBytecodeTest { diffInstructions(instructions1, instructions2)) } } + + @Test def beta_reduce_elide_unit_binding = { + val source = """class Test: + | def test = ((u: Unit) => u).apply(()) + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Test.class", directory = false).input + val clsNode = loadClassNode(clsIn) + + val fun = getMethod(clsNode, "test") + val instructions = instructionsFromMethod(fun) + val expected = List(Op(RETURN)) + + assert(instructions == expected, + "`i was not properly beta-reduced in `test`\n" + diffInstructions(instructions, expected)) + + } + } + }