From 08b1e9abf5ffb5f4c28804dc154a0b2eda9f48e6 Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Tue, 20 Feb 2024 09:31:34 +0100 Subject: [PATCH] Lift all non trivial prefixes for default parameters Checking if the prefix is pure is not enough to know if we need to list the prefix. In the case of default parameters, the prefix tree might be used several times to compute the default values. This expression should only be computed once and therefore it should be lifted if there is some computation/allocation involved. Furthermore, if the prefix contains a local definition, it must be lifted to avoid duplicating the definition. A similar situation could happen with dependent default parameters. This currently works as expected. --- .../dotty/tools/dotc/typer/EtaExpansion.scala | 12 ++++- .../backend/jvm/DottyBytecodeTests.scala | 52 +++++++++++++++++++ tests/run/i15315.scala | 5 ++ 3 files changed, 67 insertions(+), 2 deletions(-) create mode 100644 tests/run/i15315.scala diff --git a/compiler/src/dotty/tools/dotc/typer/EtaExpansion.scala b/compiler/src/dotty/tools/dotc/typer/EtaExpansion.scala index 2c441c2f915e..8be346874700 100644 --- a/compiler/src/dotty/tools/dotc/typer/EtaExpansion.scala +++ b/compiler/src/dotty/tools/dotc/typer/EtaExpansion.scala @@ -136,10 +136,18 @@ abstract class Lifter { * val x0 = pre * x0.f(...) * - * unless `pre` is idempotent. + * unless `pre` is idempotent reference, a `this` reference, a literal value, or a or the prefix of an `init` (`New` tree). + * + * Note that default arguments will refer to the prefix, we do not want + * to re-evaluate a complex expression each time we access a getter. */ def liftPrefix(defs: mutable.ListBuffer[Tree], tree: Tree)(using Context): Tree = - if (isIdempotentExpr(tree)) tree else lift(defs, tree) + tree match + case tree: Literal => tree + case tree: This => tree + case tree: New => tree // prefix of call + case tree: RefTree if isIdempotentExpr(tree) => tree + case _ => lift(defs, tree) } /** No lifting at all */ diff --git a/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala b/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala index 94d42952a6eb..51390e35b527 100644 --- a/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala +++ b/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala @@ -1733,6 +1733,58 @@ class DottyBytecodeTests extends DottyBytecodeTest { assertSameCode(instructions, expected) } } + + @Test def newInPrefixesOfDefaultParam = { + val source = + s"""class A: + | def f(x: Int = 1): Int = x + | + |class Test: + | def meth1() = (new A).f() + | def meth2() = { val a = new A; a.f() } + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Test.class", directory = false).input + val clsNode = loadClassNode(clsIn) + val meth1 = getMethod(clsNode, "meth1") + val meth2 = getMethod(clsNode, "meth2") + + val instructions1 = instructionsFromMethod(meth1) + val instructions2 = instructionsFromMethod(meth2) + + assert(instructions1 == instructions2, + "`assert` was not properly inlined in `meth1`\n" + + diffInstructions(instructions1, instructions2)) + } + } + + @Test def newInDependentOfDefaultParam = { + val source = + s"""class A: + | def i: Int = 1 + | + |class Test: + | def f(a: A)(x: Int = a.i): Int = x + | def meth1() = f(new A)() + | def meth2() = { val a = new A; f(a)() } + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Test.class", directory = false).input + val clsNode = loadClassNode(clsIn) + val meth1 = getMethod(clsNode, "meth1") + val meth2 = getMethod(clsNode, "meth2") + + val instructions1 = instructionsFromMethod(meth1) + val instructions2 = instructionsFromMethod(meth2) + + assert(instructions1 == instructions2, + "`assert` was not properly inlined in `meth1`\n" + + diffInstructions(instructions1, instructions2)) + } + } + } object invocationReceiversTestCode { diff --git a/tests/run/i15315.scala b/tests/run/i15315.scala new file mode 100644 index 000000000000..d9cab7b87b81 --- /dev/null +++ b/tests/run/i15315.scala @@ -0,0 +1,5 @@ +class A: + def f(x: Int = 1): Int = x + +@main def Test() = + (new A{}).f()