From 012d4568e95ddfd865899de0db8c0c40dacd0824 Mon Sep 17 00:00:00 2001 From: Martin Kucera <3159068+KuceraMartin@users.noreply.github.com> Date: Tue, 28 Mar 2023 19:08:58 +0200 Subject: [PATCH] List(...) optimization to avoid intermediate array (closes https://github.com/lampepfl/dotty/issues/17035) --- .../dotty/tools/dotc/core/Definitions.scala | 20 +++++++------- .../tools/dotc/transform/ArrayApply.scala | 27 +++++++++++++++++-- .../tools/backend/jvm/ArrayApplyOptTest.scala | 25 +++++++++++++++++ tests/run/list-apply-eval.scala | 21 +++++++++++++++ 4 files changed, 82 insertions(+), 11 deletions(-) create mode 100644 tests/run/list-apply-eval.scala diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 148b314220a8..861decbda54d 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -513,14 +513,15 @@ class Definitions { methodNames.map(getWrapVarargsArrayModule.requiredMethod(_)) }) - @tu lazy val ListClass: Symbol = requiredClass("scala.collection.immutable.List") - def ListType: TypeRef = ListClass.typeRef - @tu lazy val ListModule: Symbol = requiredModule("scala.collection.immutable.List") - @tu lazy val NilModule: Symbol = requiredModule("scala.collection.immutable.Nil") - def NilType: TermRef = NilModule.termRef - @tu lazy val ConsClass: Symbol = requiredClass("scala.collection.immutable.::") - def ConsType: TypeRef = ConsClass.typeRef - @tu lazy val SeqFactoryClass: Symbol = requiredClass("scala.collection.SeqFactory") + @tu lazy val ListClass: Symbol = requiredClass("scala.collection.immutable.List") + def ListType: TypeRef = ListClass.typeRef + @tu lazy val ListModule: Symbol = requiredModule("scala.collection.immutable.List") + @tu lazy val ListModule_apply: Symbol = ListModule.requiredMethod(nme.apply) + @tu lazy val NilModule: Symbol = requiredModule("scala.collection.immutable.Nil") + def NilType: TermRef = NilModule.termRef + @tu lazy val ConsClass: Symbol = requiredClass("scala.collection.immutable.::") + def ConsType: TypeRef = ConsClass.typeRef + @tu lazy val SeqFactoryClass: Symbol = requiredClass("scala.collection.SeqFactory") @tu lazy val SingletonClass: ClassSymbol = // needed as a synthetic class because Scala 2.x refers to it in classfiles @@ -539,7 +540,8 @@ class Definitions { @tu lazy val Seq_lengthCompare: Symbol = SeqClass.requiredMethod(nme.lengthCompare, List(IntType)) @tu lazy val Seq_length : Symbol = SeqClass.requiredMethod(nme.length) @tu lazy val Seq_toSeq : Symbol = SeqClass.requiredMethod(nme.toSeq) - @tu lazy val SeqModule: Symbol = requiredModule("scala.collection.immutable.Seq") + @tu lazy val SeqModule : Symbol = requiredModule("scala.collection.immutable.Seq") + @tu lazy val SeqModule_apply : Symbol = SeqModule.requiredMethod(nme.apply) @tu lazy val StringOps: Symbol = requiredClass("scala.collection.StringOps") diff --git a/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala b/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala index 872c7cc897de..5aa23b233bbb 100644 --- a/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala +++ b/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala @@ -22,9 +22,18 @@ class ArrayApply extends MiniPhase { override def description: String = ArrayApply.description + private var transformListApplyLimit = 8 + + private def reducingTransformListApply[A](depth: Int)(body: => A): A = { + val saved = transformListApplyLimit + transformListApplyLimit -= depth + try body + finally transformListApplyLimit = saved + } + override def transformApply(tree: tpd.Apply)(using Context): tpd.Tree = if isArrayModuleApply(tree.symbol) then - tree.args match { + tree.args match case StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: ct :: Nil if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) && elideClassTag(ct) => seqLit @@ -35,7 +44,18 @@ class ArrayApply extends MiniPhase { case _ => tree - } + + else if isListOrSeqModuleApply(tree.symbol) then + tree.args match + // (a, b, c) ~> new ::(a, new ::(b, new ::(c, Nil))) but only for reference types + case StripAscription(Apply(wrapArrayMeth, List(StripAscription(rest: tpd.JavaSeqLiteral)))) :: Nil + if defn.WrapArrayMethods().contains(wrapArrayMeth.symbol) && + rest.elems.lengthIs < transformListApplyLimit => + rest.elems.foldRight(tpd.ref(defn.NilModule)): (elem, acc) => + tpd.New(defn.ConsType, List(elem, acc)) + + case _ => + tree else tree @@ -43,6 +63,9 @@ class ArrayApply extends MiniPhase { sym.name == nme.apply && (sym.owner == defn.ArrayModuleClass || (sym.owner == defn.IArrayModuleClass && !sym.is(Extension))) + private def isListOrSeqModuleApply(sym: Symbol)(using Context): Boolean = + sym == defn.ListModule_apply || sym == defn.SeqModule_apply + /** Only optimize when classtag if it is one of * - `ClassTag.apply(classOf[XYZ])` * - `ClassTag.apply(java.lang.XYZ.Type)` for boxed primitives `XYZ`` diff --git a/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala b/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala index e7cd20ba98b2..a2d37b8399e5 100644 --- a/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala +++ b/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala @@ -160,4 +160,29 @@ class ArrayApplyOptTest extends DottyBytecodeTest { } } + @Test def testListApplyAvoidsIntermediateArray = { + val source = + """ + |class Foo { + | def meth1: List[String] = List("1", "2", "3") + | def meth2: List[String] = + | new scala.collection.immutable.::("1", new scala.collection.immutable.::("2", new scala.collection.immutable.::("3", scala.collection.immutable.Nil))).asInstanceOf[List[String]] + |} + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Foo.class", directory = false).input + val clsNode = loadClassNode(clsIn) + val meth1 = getMethod(clsNode, "meth1") + val meth2 = getMethod(clsNode, "meth2") + + val instructions1 = instructionsFromMethod(meth1) + val instructions2 = instructionsFromMethod(meth2) + + assert(instructions1 == instructions2, + "the List.apply method " + + diffInstructions(instructions1, instructions2)) + } + } + } diff --git a/tests/run/list-apply-eval.scala b/tests/run/list-apply-eval.scala new file mode 100644 index 000000000000..4e25444689cc --- /dev/null +++ b/tests/run/list-apply-eval.scala @@ -0,0 +1,21 @@ +object Test: + + var counter = 0 + + def next = + counter += 1 + counter.toString + + def main(args: Array[String]): Unit = + //List.apply is subject to an optimisation in cleanup + //ensure that the arguments are evaluated in the currect order + // Rewritten to: + // val myList: List = new collection.immutable.::(Test.this.next(), new collection.immutable.::(Test.this.next(), new collection.immutable.::(Test.this.next(), scala.collection.immutable.Nil))); + val myList = List(next, next, next) + assert(myList == List("1", "2", "3"), myList) + + val mySeq = Seq(next, next, next) + assert(mySeq == Seq("4", "5", "6"), mySeq) + + val emptyList = List[Int]() + assert(emptyList == Nil)