diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 2084cd4b04b1..c0435494497f 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -513,14 +513,16 @@ class Definitions { methodNames.map(getWrapVarargsArrayModule.requiredMethod(_)) }) - @tu lazy val ListClass: Symbol = requiredClass("scala.collection.immutable.List") - def ListType: TypeRef = ListClass.typeRef - @tu lazy val ListModule: Symbol = requiredModule("scala.collection.immutable.List") - @tu lazy val NilModule: Symbol = requiredModule("scala.collection.immutable.Nil") - def NilType: TermRef = NilModule.termRef - @tu lazy val ConsClass: Symbol = requiredClass("scala.collection.immutable.::") - def ConsType: TypeRef = ConsClass.typeRef - @tu lazy val SeqFactoryClass: Symbol = requiredClass("scala.collection.SeqFactory") + @tu lazy val ListClass: Symbol = requiredClass("scala.collection.immutable.List") + def ListType: TypeRef = ListClass.typeRef + @tu lazy val ListModule: Symbol = requiredModule("scala.collection.immutable.List") + @tu lazy val ListModule_apply: Symbol = ListModule.requiredMethod(nme.apply) + def ListModuleAlias: Symbol = ScalaPackageClass.requiredMethod(nme.List) + @tu lazy val NilModule: Symbol = requiredModule("scala.collection.immutable.Nil") + def NilType: TermRef = NilModule.termRef + @tu lazy val ConsClass: Symbol = requiredClass("scala.collection.immutable.::") + def ConsType: TypeRef = ConsClass.typeRef + @tu lazy val SeqFactoryClass: Symbol = requiredClass("scala.collection.SeqFactory") @tu lazy val SingletonClass: ClassSymbol = // needed as a synthetic class because Scala 2.x refers to it in classfiles @@ -530,8 +532,11 @@ class Definitions { List(AnyType), EmptyScope) @tu lazy val SingletonType: TypeRef = SingletonClass.typeRef - @tu lazy val CollectionSeqType: TypeRef = requiredClassRef("scala.collection.Seq") - @tu lazy val SeqType: TypeRef = requiredClassRef("scala.collection.immutable.Seq") + @tu lazy val CollectionSeqType: TypeRef = requiredClassRef("scala.collection.Seq") + @tu lazy val SeqType: TypeRef = requiredClassRef("scala.collection.immutable.Seq") + @tu lazy val SeqModule: Symbol = requiredModule("scala.collection.immutable.Seq") + @tu lazy val SeqModule_apply: Symbol = SeqModule.requiredMethod(nme.apply) + def SeqModuleAlias: Symbol = ScalaPackageClass.requiredMethod(nme.Seq) def SeqClass(using Context): ClassSymbol = SeqType.symbol.asClass @tu lazy val Seq_apply : Symbol = SeqClass.requiredMethod(nme.apply) @tu lazy val Seq_head : Symbol = SeqClass.requiredMethod(nme.head) @@ -539,7 +544,6 @@ class Definitions { @tu lazy val Seq_lengthCompare: Symbol = SeqClass.requiredMethod(nme.lengthCompare, List(IntType)) @tu lazy val Seq_length : Symbol = SeqClass.requiredMethod(nme.length) @tu lazy val Seq_toSeq : Symbol = SeqClass.requiredMethod(nme.toSeq) - @tu lazy val SeqModule: Symbol = requiredModule("scala.collection.immutable.Seq") @tu lazy val StringOps: Symbol = requiredClass("scala.collection.StringOps") diff --git a/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala b/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala index 6ece8ad63808..98ca8f2e2b5b 100644 --- a/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala +++ b/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala @@ -1,15 +1,12 @@ -package dotty.tools.dotc +package dotty.tools +package dotc package transform -import core.* +import ast.tpd +import core.*, Contexts.*, Decorators.*, Symbols.*, Flags.*, StdNames.* +import reporting.trace +import util.Property import MegaPhase.* -import Contexts.* -import Symbols.* -import Flags.* -import StdNames.* -import dotty.tools.dotc.ast.tpd - - /** This phase rewrites calls to `Array.apply` to a direct instantiation of the array in the bytecode. * @@ -22,27 +19,69 @@ class ArrayApply extends MiniPhase { override def description: String = ArrayApply.description - override def transformApply(tree: tpd.Apply)(using Context): tpd.Tree = + private val TransformListApplyBudgetKey = new Property.Key[Int] + private def transformListApplyBudget(using Context) = + ctx.property(TransformListApplyBudgetKey).getOrElse(8) // default is 8, as originally implemented in nsc + + override def prepareForApply(tree: Apply)(using Context): Context = tree match + case SeqApplyArgs(elems) => + ctx.fresh.setProperty(TransformListApplyBudgetKey, transformListApplyBudget - elems.length) + case _ => ctx + + override def transformApply(tree: Apply)(using Context): Tree = if isArrayModuleApply(tree.symbol) then - tree.args match { - case StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: ct :: Nil + tree.args match + case StripAscription(Apply(wrapRefArrayMeth, (seqLit: JavaSeqLiteral) :: Nil)) :: ct :: Nil if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) && elideClassTag(ct) => seqLit - case elem0 :: StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: Nil + case elem0 :: StripAscription(Apply(wrapRefArrayMeth, (seqLit: JavaSeqLiteral) :: Nil)) :: Nil if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) => - tpd.JavaSeqLiteral(elem0 :: seqLit.elems, seqLit.elemtpt) + JavaSeqLiteral(elem0 :: seqLit.elems, seqLit.elemtpt) case _ => tree - } - else tree + else tree match + case SeqApplyArgs(elems) if transformListApplyBudget > 0 || elems.isEmpty => + val consed = elems.foldRight(ref(defn.NilModule)): (elem, acc) => + New(defn.ConsType, List(elem.ensureConforms(defn.ObjectType), acc)) + consed.cast(tree.tpe) + case _ => tree private def isArrayModuleApply(sym: Symbol)(using Context): Boolean = sym.name == nme.apply && (sym.owner == defn.ArrayModuleClass || (sym.owner == defn.IArrayModuleClass && !sym.is(Extension))) + private def isListApply(tree: Tree)(using Context): Boolean = + (tree.symbol == defn.ListModule_apply || tree.symbol.name == nme.apply) && appliedCore(tree).match + case Select(qual, _) => + val sym = qual.symbol + sym == defn.ListModule + || sym == defn.ListModuleAlias + case _ => false + + private def isSeqApply(tree: Tree)(using Context): Boolean = + isListApply(tree) || tree.symbol == defn.SeqModule_apply && appliedCore(tree).match + case Select(qual, _) => + val sym = qual.symbol + sym == defn.SeqModule + || sym == defn.SeqModuleAlias + || sym == defn.CollectionSeqType.symbol.companionModule + case _ => false + + private object SeqApplyArgs: + def unapply(tree: Apply)(using Context): Option[List[Tree]] = + if isSeqApply(tree) then + tree.args match + // (a, b, c) ~> new ::(a, new ::(b, new ::(c, Nil))) but only for reference types + case StripAscription(Apply(wrapArrayMeth, List(StripAscription(rest: JavaSeqLiteral)))) :: Nil + if defn.WrapArrayMethods().contains(wrapArrayMeth.symbol) => + Some(rest.elems) + case _ => None + else None + + /** Only optimize when classtag if it is one of * - `ClassTag.apply(classOf[XYZ])` * - `ClassTag.apply(java.lang.XYZ.Type)` for boxed primitives `XYZ`` diff --git a/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala b/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala index e7cd20ba98b2..37e7d5316f9d 100644 --- a/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala +++ b/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala @@ -1,4 +1,5 @@ -package dotty.tools.backend.jvm +package dotty.tools +package backend.jvm import org.junit.Test import org.junit.Assert._ @@ -160,4 +161,153 @@ class ArrayApplyOptTest extends DottyBytecodeTest { } } + @Test def testListApplyAvoidsIntermediateArray = { + checkApplyAvoidsIntermediateArray("List"): + """import scala.collection.immutable.{ ::, Nil } + |class Foo { + | def meth1: List[String] = List("1", "2", "3") + | def meth2: List[String] = new ::("1", new ::("2", new ::("3", Nil))) + |} + """.stripMargin + } + + @Test def testSeqApplyAvoidsIntermediateArray = { + checkApplyAvoidsIntermediateArray("Seq"): + """import scala.collection.immutable.{ ::, Nil } + |class Foo { + | def meth1: Seq[String] = Seq("1", "2", "3") + | def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil))) + |} + """.stripMargin + } + + @Test def testSeqApplyAvoidsIntermediateArray2 = { + checkApplyAvoidsIntermediateArray("scala.collection.immutable.Seq"): + """import scala.collection.immutable.{ ::, Seq, Nil } + |class Foo { + | def meth1: Seq[String] = Seq("1", "2", "3") + | def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil))) + |} + """.stripMargin + } + + @Test def testSeqApplyAvoidsIntermediateArray3 = { + checkApplyAvoidsIntermediateArray("scala.collection.Seq"): + """import scala.collection.immutable.{ ::, Nil }, scala.collection.Seq + |class Foo { + | def meth1: Seq[String] = Seq("1", "2", "3") + | def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil))) + |} + """.stripMargin + } + + @Test def testListApplyAvoidsIntermediateArray_max1 = { + checkApplyAvoidsIntermediateArray_examples("max1"): + """ def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", "6", "7") + | def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::("5", new ::("6", new ::("7", Nil))))))) + """.stripMargin + } + + @Test def testListApplyAvoidsIntermediateArray_max2 = { + checkApplyAvoidsIntermediateArray_examples("max2"): + """ def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", "6", List[Object]()) + | def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::("5", new ::("6", new ::(Nil, Nil))))))) + """.stripMargin + } + + @Test def testListApplyAvoidsIntermediateArray_max3 = { + checkApplyAvoidsIntermediateArray_examples("max3"): + """ def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", List[Object]("6")) + | def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::("5", new ::(new ::("6", Nil), Nil)))))) + """.stripMargin + } + + @Test def testListApplyAvoidsIntermediateArray_max4 = { + checkApplyAvoidsIntermediateArray_examples("max4"): + """ def meth1: List[Object] = List[Object]("1", "2", "3", "4", List[Object]("5", "6")) + | def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::(new ::("5", new ::("6", Nil)), Nil))))) + """.stripMargin + } + + @Test def testListApplyAvoidsIntermediateArray_over1 = { + checkApplyAvoidsIntermediateArray_examples("over1"): + """ def meth1: List[Object] = List("1", "2", "3", "4", "5", "6", "7", "8") + | def meth2: List[Object] = List(wrapRefArray(Array("1", "2", "3", "4", "5", "6", "7", "8"))*) + """.stripMargin + } + + @Test def testListApplyAvoidsIntermediateArray_over2 = { + checkApplyAvoidsIntermediateArray_examples("over2"): + """ def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", "6", "7", List[Object]()) + | def meth2: List[Object] = List(wrapRefArray(Array[Object]("1", "2", "3", "4", "5", "6", "7", Nil))*) + """.stripMargin + } + + @Test def testListApplyAvoidsIntermediateArray_over3 = { + checkApplyAvoidsIntermediateArray_examples("over3"): + """ def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", "6", List[Object]("7")) + | def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::("5", new ::("6", new ::(List(wrapRefArray(Array[Object]("7"))*), Nil))))))) + """.stripMargin + } + + @Test def testListApplyAvoidsIntermediateArray_over4 = { + checkApplyAvoidsIntermediateArray_examples("over4"): + """ def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", List[Object]("6", "7")) + | def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::("5", new ::(List(wrapRefArray(Array[Object]("6", "7"))*), Nil)))))) + """.stripMargin + } + + @Test def testListApplyAvoidsIntermediateArray_max5 = { + checkApplyAvoidsIntermediateArray_examples("max5"): + """ def meth1: List[Object] = List[Object](List[Object](List[Object](List[Object](List[Object](List[Object](List[Object](List[Object]()))))))) + | def meth2: List[Object] = new ::(new ::(new ::(new ::(new ::(new ::(new ::(Nil, Nil), Nil), Nil), Nil), Nil), Nil), Nil) + """.stripMargin + } + + @Test def testListApplyAvoidsIntermediateArray_over5 = { + checkApplyAvoidsIntermediateArray_examples("over5"): + """ def meth1: List[Object] = List[Object](List[Object](List[Object](List[Object](List[Object](List[Object](List[Object](List[Object](List[Object]())))))))) + | def meth2: List[Object] = new ::(new ::(new ::(new ::(new ::(new ::(new ::(List[Object](wrapRefArray(Array[Object](Nil))*), Nil), Nil), Nil), Nil), Nil), Nil), Nil) + """.stripMargin + } + + @Test def testListApplyAvoidsIntermediateArray_max6 = { + checkApplyAvoidsIntermediateArray_examples("max6"): + """ def meth1: List[Object] = List[Object]("1", "2", List[Object]("3", "4", List[Object](List[Object]()))) + | def meth2: List[Object] = new ::("1", new ::("2", new ::(new ::("3", new ::("4", new ::(new ::(Nil, Nil), Nil))), Nil))) + """.stripMargin + } + + @Test def testListApplyAvoidsIntermediateArray_over6 = { + checkApplyAvoidsIntermediateArray_examples("over6"): + """ def meth1: List[Object] = List[Object]("1", "2", List[Object]("3", "4", List[Object]("5"))) + | def meth2: List[Object] = new ::("1", new ::("2", new ::(new ::("3", new ::("4", new ::(new ::("5", Nil), Nil))), Nil))) + """.stripMargin + } + + def checkApplyAvoidsIntermediateArray_examples(name: String)(body: String): Unit = { + checkApplyAvoidsIntermediateArray(s"List_$name"): + s"""import scala.collection.immutable.{ ::, Nil }, scala.runtime.ScalaRunTime.wrapRefArray + |class Foo { + |$body + |} + """.stripMargin + } + + def checkApplyAvoidsIntermediateArray(name: String)(source: String): Unit = { + checkBCode(source) { dir => + val clsIn = dir.lookupName("Foo.class", directory = false).input + val clsNode = loadClassNode(clsIn) + val meth1 = getMethod(clsNode, "meth1") + val meth2 = getMethod(clsNode, "meth2") + + val instructions1 = instructionsFromMethod(meth1).filter { case TypeOp(CHECKCAST, _) => false case _ => true } + val instructions2 = instructionsFromMethod(meth2).filter { case TypeOp(CHECKCAST, _) => false case _ => true } + + assert(instructions1 == instructions2, + s"the $name.apply method\n" + + diffInstructions(instructions1, instructions2)) + } + } + } diff --git a/tests/run/list-apply-eval.scala b/tests/run/list-apply-eval.scala new file mode 100644 index 000000000000..4cbba6d3e6c2 --- /dev/null +++ b/tests/run/list-apply-eval.scala @@ -0,0 +1,88 @@ +object Test: + + var counter = 0 + + def next = + counter += 1 + counter.toString + + def main(args: Array[String]): Unit = + //List.apply is subject to an optimisation in cleanup + //ensure that the arguments are evaluated in the currect order + // Rewritten to: + // val myList: List = new collection.immutable.::(Test.this.next(), new collection.immutable.::(Test.this.next(), new collection.immutable.::(Test.this.next(), scala.collection.immutable.Nil))); + val myList = List(next, next, next) + assert(myList == List("1", "2", "3"), myList) + + val mySeq = Seq(next, next, next) + assert(mySeq == Seq("4", "5", "6"), mySeq) + + val emptyList = List[Int]() + assert(emptyList == Nil) + + // just assert it doesn't throw CCE to List + val queue = scala.collection.mutable.Queue[String]() + + // test for the cast instruction described in checkApplyAvoidsIntermediateArray + def lub(b: Boolean): List[(String, String)] = + if b then List(("foo", "bar")) else Nil + + // from minimising CI failure in oslib + // again, the lub of :: and Nil is Product, which breaks ++ (which requires IterableOnce) + def lub2(b: Boolean): Unit = + Seq(1) ++ (if (b) Seq(2) else Nil) + + // Examples of arity and nesting arity + // to find the thresholds and reproduce the behaviour of nsc + def examples(): Unit = + val max1 = List[Object]("1", "2", "3", "4", "5", "6", "7") // 7 cons w/ 7 string heads + nil + val max2 = List[Object]("1", "2", "3", "4", "5", "6", List[Object]()) // 7 cons w/ 6 string heads + 1 nil head + nil + val max3 = List[Object]("1", "2", "3", "4", "5", List[Object]("6")) + val max4 = List[Object]("1", "2", "3", "4", List[Object]("5", "6")) + + val over1 = List[Object]("1", "2", "3", "4", "5", "6", "7", "8") // wrap 8-sized array + val over2 = List[Object]("1", "2", "3", "4", "5", "6", "7", List[Object]()) // wrap 8-sized array + val over3 = List[Object]("1", "2", "3", "4", "5", "6", List[Object]("7")) // wrap 1-sized array with 7 + val over4 = List[Object]("1", "2", "3", "4", "5", List[Object]("6", "7")) // wrap 2 + + val max5 = + List[Object]( + List[Object]( + List[Object]( + List[Object]( + List[Object]( + List[Object]( + List[Object]( + List[Object]( + )))))))) // 7 cons + 1 nil + + val over5 = + List[Object]( + List[Object]( + List[Object]( + List[Object]( + List[Object]( + List[Object]( + List[Object]( + List[Object]( List[Object]() + )))))))) // 7 cons + 1-sized array wrapping nil + + val max6 = + List[Object]( // ::( + "1", "2", List[Object]( // 1, ::(2, ::(::( + "3", "4", List[Object]( // 3, ::(4, ::(::( + List[Object]() // Nil, Nil + ) // ), Nil)) + ) // ), Nil)) + ) // ) + // 7 cons + 4 string heads + 4 nils for nested lists + + val max7 = + List[Object]( // ::( + "1", "2", List[Object]( // 1, ::(2, ::(::( + "3", "4", List[Object]( // 3, ::(4, ::(::( + "5" // 5, Nil + ) // ), Nil)) + ) // ), Nil)) + ) // ) + // 7 cons + 5 string heads + 3 nils for nested lists