Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

List(...) optimization to avoid intermediate array #17166

Merged
merged 7 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -517,14 +517,16 @@ class Definitions {
methodNames.map(getWrapVarargsArrayModule.requiredMethod(_))
})

@tu lazy val ListClass: Symbol = requiredClass("scala.collection.immutable.List")
def ListType: TypeRef = ListClass.typeRef
@tu lazy val ListModule: Symbol = requiredModule("scala.collection.immutable.List")
@tu lazy val NilModule: Symbol = requiredModule("scala.collection.immutable.Nil")
def NilType: TermRef = NilModule.termRef
@tu lazy val ConsClass: Symbol = requiredClass("scala.collection.immutable.::")
def ConsType: TypeRef = ConsClass.typeRef
@tu lazy val SeqFactoryClass: Symbol = requiredClass("scala.collection.SeqFactory")
@tu lazy val ListClass: Symbol = requiredClass("scala.collection.immutable.List")
def ListType: TypeRef = ListClass.typeRef
@tu lazy val ListModule: Symbol = requiredModule("scala.collection.immutable.List")
@tu lazy val ListModule_apply: Symbol = ListModule.requiredMethod(nme.apply)
def ListModuleAlias: Symbol = ScalaPackageClass.requiredMethod(nme.List)
@tu lazy val NilModule: Symbol = requiredModule("scala.collection.immutable.Nil")
def NilType: TermRef = NilModule.termRef
@tu lazy val ConsClass: Symbol = requiredClass("scala.collection.immutable.::")
def ConsType: TypeRef = ConsClass.typeRef
@tu lazy val SeqFactoryClass: Symbol = requiredClass("scala.collection.SeqFactory")

@tu lazy val SingletonClass: ClassSymbol =
// needed as a synthetic class because Scala 2.x refers to it in classfiles
Expand All @@ -534,16 +536,18 @@ class Definitions {
List(AnyType), EmptyScope)
@tu lazy val SingletonType: TypeRef = SingletonClass.typeRef

@tu lazy val CollectionSeqType: TypeRef = requiredClassRef("scala.collection.Seq")
@tu lazy val SeqType: TypeRef = requiredClassRef("scala.collection.immutable.Seq")
@tu lazy val CollectionSeqType: TypeRef = requiredClassRef("scala.collection.Seq")
@tu lazy val SeqType: TypeRef = requiredClassRef("scala.collection.immutable.Seq")
@tu lazy val SeqModule: Symbol = requiredModule("scala.collection.immutable.Seq")
@tu lazy val SeqModule_apply: Symbol = SeqModule.requiredMethod(nme.apply)
def SeqModuleAlias: Symbol = ScalaPackageClass.requiredMethod(nme.Seq)
def SeqClass(using Context): ClassSymbol = SeqType.symbol.asClass
@tu lazy val Seq_apply : Symbol = SeqClass.requiredMethod(nme.apply)
@tu lazy val Seq_head : Symbol = SeqClass.requiredMethod(nme.head)
@tu lazy val Seq_drop : Symbol = SeqClass.requiredMethod(nme.drop)
@tu lazy val Seq_lengthCompare: Symbol = SeqClass.requiredMethod(nme.lengthCompare, List(IntType))
@tu lazy val Seq_length : Symbol = SeqClass.requiredMethod(nme.length)
@tu lazy val Seq_toSeq : Symbol = SeqClass.requiredMethod(nme.toSeq)
@tu lazy val SeqModule: Symbol = requiredModule("scala.collection.immutable.Seq")


@tu lazy val StringOps: Symbol = requiredClass("scala.collection.StringOps")
Expand Down
71 changes: 55 additions & 16 deletions compiler/src/dotty/tools/dotc/transform/ArrayApply.scala
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
package dotty.tools.dotc
package dotty.tools
package dotc
package transform

import core.*
import ast.tpd
import core.*, Contexts.*, Decorators.*, Symbols.*, Flags.*, StdNames.*
import reporting.trace
import util.Property
import MegaPhase.*
import Contexts.*
import Symbols.*
import Flags.*
import StdNames.*
import dotty.tools.dotc.ast.tpd



/** This phase rewrites calls to `Array.apply` to a direct instantiation of the array in the bytecode.
*
Expand All @@ -22,27 +19,69 @@ class ArrayApply extends MiniPhase {

override def description: String = ArrayApply.description

override def transformApply(tree: tpd.Apply)(using Context): tpd.Tree =
private val TransformListApplyBudgetKey = new Property.Key[Int]
private def transformListApplyBudget(using Context) =
ctx.property(TransformListApplyBudgetKey).getOrElse(8) // default is 8, as originally implemented in nsc

override def prepareForApply(tree: Apply)(using Context): Context = tree match
case SeqApplyArgs(elems) =>
ctx.fresh.setProperty(TransformListApplyBudgetKey, transformListApplyBudget - elems.length)
case _ => ctx

override def transformApply(tree: Apply)(using Context): Tree =
if isArrayModuleApply(tree.symbol) then
tree.args match {
case StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: ct :: Nil
tree.args match
case StripAscription(Apply(wrapRefArrayMeth, (seqLit: JavaSeqLiteral) :: Nil)) :: ct :: Nil
if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) && elideClassTag(ct) =>
seqLit

case elem0 :: StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: Nil
case elem0 :: StripAscription(Apply(wrapRefArrayMeth, (seqLit: JavaSeqLiteral) :: Nil)) :: Nil
if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) =>
tpd.JavaSeqLiteral(elem0 :: seqLit.elems, seqLit.elemtpt)
JavaSeqLiteral(elem0 :: seqLit.elems, seqLit.elemtpt)

case _ =>
tree
}

else tree
else tree match
case SeqApplyArgs(elems) if transformListApplyBudget > 0 || elems.isEmpty =>
val consed = elems.foldRight(ref(defn.NilModule)): (elem, acc) =>
New(defn.ConsType, List(elem.ensureConforms(defn.ObjectType), acc))
consed.cast(tree.tpe)
case _ => tree

private def isArrayModuleApply(sym: Symbol)(using Context): Boolean =
sym.name == nme.apply
&& (sym.owner == defn.ArrayModuleClass || (sym.owner == defn.IArrayModuleClass && !sym.is(Extension)))

private def isListApply(tree: Tree)(using Context): Boolean =
(tree.symbol == defn.ListModule_apply || tree.symbol.name == nme.apply) && appliedCore(tree).match
case Select(qual, _) =>
val sym = qual.symbol
sym == defn.ListModule
|| sym == defn.ListModuleAlias
case _ => false

private def isSeqApply(tree: Tree)(using Context): Boolean =
isListApply(tree) || tree.symbol == defn.SeqModule_apply && appliedCore(tree).match
case Select(qual, _) =>
val sym = qual.symbol
sym == defn.SeqModule
|| sym == defn.SeqModuleAlias
|| sym == defn.CollectionSeqType.symbol.companionModule
case _ => false

private object SeqApplyArgs:
def unapply(tree: Apply)(using Context): Option[List[Tree]] =
if isSeqApply(tree) then
tree.args match
// <List or Seq>(a, b, c) ~> new ::(a, new ::(b, new ::(c, Nil))) but only for reference types
case StripAscription(Apply(wrapArrayMeth, List(StripAscription(rest: JavaSeqLiteral)))) :: Nil
if defn.WrapArrayMethods().contains(wrapArrayMeth.symbol) =>
Some(rest.elems)
case _ => None
else None


/** Only optimize when classtag if it is one of
* - `ClassTag.apply(classOf[XYZ])`
* - `ClassTag.apply(java.lang.XYZ.Type)` for boxed primitives `XYZ``
Expand Down
152 changes: 151 additions & 1 deletion compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
package dotty.tools.backend.jvm
package dotty.tools
package backend.jvm

import org.junit.Test
import org.junit.Assert._
Expand Down Expand Up @@ -160,4 +161,153 @@ class ArrayApplyOptTest extends DottyBytecodeTest {
}
}

@Test def testListApplyAvoidsIntermediateArray = {
checkApplyAvoidsIntermediateArray("List"):
"""import scala.collection.immutable.{ ::, Nil }
|class Foo {
| def meth1: List[String] = List("1", "2", "3")
| def meth2: List[String] = new ::("1", new ::("2", new ::("3", Nil)))
|}
""".stripMargin
}

@Test def testSeqApplyAvoidsIntermediateArray = {
checkApplyAvoidsIntermediateArray("Seq"):
"""import scala.collection.immutable.{ ::, Nil }
|class Foo {
| def meth1: Seq[String] = Seq("1", "2", "3")
| def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil)))
|}
""".stripMargin
}

@Test def testSeqApplyAvoidsIntermediateArray2 = {
checkApplyAvoidsIntermediateArray("scala.collection.immutable.Seq"):
"""import scala.collection.immutable.{ ::, Seq, Nil }
|class Foo {
| def meth1: Seq[String] = Seq("1", "2", "3")
| def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil)))
|}
""".stripMargin
}

@Test def testSeqApplyAvoidsIntermediateArray3 = {
checkApplyAvoidsIntermediateArray("scala.collection.Seq"):
"""import scala.collection.immutable.{ ::, Nil }, scala.collection.Seq
|class Foo {
| def meth1: Seq[String] = Seq("1", "2", "3")
| def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil)))
|}
""".stripMargin
}

dwijnand marked this conversation as resolved.
Show resolved Hide resolved
@Test def testListApplyAvoidsIntermediateArray_max1 = {
checkApplyAvoidsIntermediateArray_examples("max1"):
""" def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", "6", "7")
| def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::("5", new ::("6", new ::("7", Nil)))))))
""".stripMargin
}

@Test def testListApplyAvoidsIntermediateArray_max2 = {
checkApplyAvoidsIntermediateArray_examples("max2"):
""" def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", "6", List[Object]())
| def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::("5", new ::("6", new ::(Nil, Nil)))))))
""".stripMargin
}

@Test def testListApplyAvoidsIntermediateArray_max3 = {
checkApplyAvoidsIntermediateArray_examples("max3"):
""" def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", List[Object]("6"))
| def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::("5", new ::(new ::("6", Nil), Nil))))))
""".stripMargin
}

@Test def testListApplyAvoidsIntermediateArray_max4 = {
checkApplyAvoidsIntermediateArray_examples("max4"):
""" def meth1: List[Object] = List[Object]("1", "2", "3", "4", List[Object]("5", "6"))
| def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::(new ::("5", new ::("6", Nil)), Nil)))))
""".stripMargin
}

@Test def testListApplyAvoidsIntermediateArray_over1 = {
checkApplyAvoidsIntermediateArray_examples("over1"):
""" def meth1: List[Object] = List("1", "2", "3", "4", "5", "6", "7", "8")
| def meth2: List[Object] = List(wrapRefArray(Array("1", "2", "3", "4", "5", "6", "7", "8"))*)
""".stripMargin
}

@Test def testListApplyAvoidsIntermediateArray_over2 = {
checkApplyAvoidsIntermediateArray_examples("over2"):
""" def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", "6", "7", List[Object]())
| def meth2: List[Object] = List(wrapRefArray(Array[Object]("1", "2", "3", "4", "5", "6", "7", Nil))*)
""".stripMargin
}

@Test def testListApplyAvoidsIntermediateArray_over3 = {
checkApplyAvoidsIntermediateArray_examples("over3"):
""" def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", "6", List[Object]("7"))
| def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::("5", new ::("6", new ::(List(wrapRefArray(Array[Object]("7"))*), Nil)))))))
""".stripMargin
}

@Test def testListApplyAvoidsIntermediateArray_over4 = {
checkApplyAvoidsIntermediateArray_examples("over4"):
""" def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", List[Object]("6", "7"))
| def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::("5", new ::(List(wrapRefArray(Array[Object]("6", "7"))*), Nil))))))
""".stripMargin
}

@Test def testListApplyAvoidsIntermediateArray_max5 = {
checkApplyAvoidsIntermediateArray_examples("max5"):
""" def meth1: List[Object] = List[Object](List[Object](List[Object](List[Object](List[Object](List[Object](List[Object](List[Object]())))))))
| def meth2: List[Object] = new ::(new ::(new ::(new ::(new ::(new ::(new ::(Nil, Nil), Nil), Nil), Nil), Nil), Nil), Nil)
""".stripMargin
}

@Test def testListApplyAvoidsIntermediateArray_over5 = {
checkApplyAvoidsIntermediateArray_examples("over5"):
""" def meth1: List[Object] = List[Object](List[Object](List[Object](List[Object](List[Object](List[Object](List[Object](List[Object](List[Object]()))))))))
| def meth2: List[Object] = new ::(new ::(new ::(new ::(new ::(new ::(new ::(List[Object](wrapRefArray(Array[Object](Nil))*), Nil), Nil), Nil), Nil), Nil), Nil), Nil)
""".stripMargin
}

@Test def testListApplyAvoidsIntermediateArray_max6 = {
checkApplyAvoidsIntermediateArray_examples("max6"):
""" def meth1: List[Object] = List[Object]("1", "2", List[Object]("3", "4", List[Object](List[Object]())))
| def meth2: List[Object] = new ::("1", new ::("2", new ::(new ::("3", new ::("4", new ::(new ::(Nil, Nil), Nil))), Nil)))
""".stripMargin
}

@Test def testListApplyAvoidsIntermediateArray_over6 = {
checkApplyAvoidsIntermediateArray_examples("over6"):
""" def meth1: List[Object] = List[Object]("1", "2", List[Object]("3", "4", List[Object]("5")))
| def meth2: List[Object] = new ::("1", new ::("2", new ::(new ::("3", new ::("4", new ::(new ::("5", Nil), Nil))), Nil)))
""".stripMargin
}

def checkApplyAvoidsIntermediateArray_examples(name: String)(body: String): Unit = {
checkApplyAvoidsIntermediateArray(s"List_$name"):
s"""import scala.collection.immutable.{ ::, Nil }, scala.runtime.ScalaRunTime.wrapRefArray
|class Foo {
|$body
|}
""".stripMargin
}

def checkApplyAvoidsIntermediateArray(name: String)(source: String): Unit = {
checkBCode(source) { dir =>
val clsIn = dir.lookupName("Foo.class", directory = false).input
val clsNode = loadClassNode(clsIn)
val meth1 = getMethod(clsNode, "meth1")
val meth2 = getMethod(clsNode, "meth2")

val instructions1 = instructionsFromMethod(meth1).filter { case TypeOp(CHECKCAST, _) => false case _ => true }
val instructions2 = instructionsFromMethod(meth2).filter { case TypeOp(CHECKCAST, _) => false case _ => true }

assert(instructions1 == instructions2,
s"the $name.apply method\n" +
diffInstructions(instructions1, instructions2))
}
}

}
88 changes: 88 additions & 0 deletions tests/run/list-apply-eval.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
object Test:

var counter = 0

def next =
counter += 1
counter.toString

def main(args: Array[String]): Unit =
//List.apply is subject to an optimisation in cleanup
//ensure that the arguments are evaluated in the currect order
// Rewritten to:
// val myList: List = new collection.immutable.::(Test.this.next(), new collection.immutable.::(Test.this.next(), new collection.immutable.::(Test.this.next(), scala.collection.immutable.Nil)));
val myList = List(next, next, next)
assert(myList == List("1", "2", "3"), myList)

val mySeq = Seq(next, next, next)
assert(mySeq == Seq("4", "5", "6"), mySeq)

val emptyList = List[Int]()
assert(emptyList == Nil)

// just assert it doesn't throw CCE to List
val queue = scala.collection.mutable.Queue[String]()

// test for the cast instruction described in checkApplyAvoidsIntermediateArray
def lub(b: Boolean): List[(String, String)] =
if b then List(("foo", "bar")) else Nil

// from minimising CI failure in oslib
// again, the lub of :: and Nil is Product, which breaks ++ (which requires IterableOnce)
def lub2(b: Boolean): Unit =
Seq(1) ++ (if (b) Seq(2) else Nil)

// Examples of arity and nesting arity
// to find the thresholds and reproduce the behaviour of nsc
def examples(): Unit =
val max1 = List[Object]("1", "2", "3", "4", "5", "6", "7") // 7 cons w/ 7 string heads + nil
val max2 = List[Object]("1", "2", "3", "4", "5", "6", List[Object]()) // 7 cons w/ 6 string heads + 1 nil head + nil
val max3 = List[Object]("1", "2", "3", "4", "5", List[Object]("6"))
val max4 = List[Object]("1", "2", "3", "4", List[Object]("5", "6"))

val over1 = List[Object]("1", "2", "3", "4", "5", "6", "7", "8") // wrap 8-sized array
val over2 = List[Object]("1", "2", "3", "4", "5", "6", "7", List[Object]()) // wrap 8-sized array
val over3 = List[Object]("1", "2", "3", "4", "5", "6", List[Object]("7")) // wrap 1-sized array with 7
val over4 = List[Object]("1", "2", "3", "4", "5", List[Object]("6", "7")) // wrap 2

val max5 =
List[Object](
List[Object](
List[Object](
List[Object](
List[Object](
List[Object](
List[Object](
List[Object](
)))))))) // 7 cons + 1 nil

val over5 =
List[Object](
List[Object](
List[Object](
List[Object](
List[Object](
List[Object](
List[Object](
List[Object]( List[Object]()
)))))))) // 7 cons + 1-sized array wrapping nil

val max6 =
List[Object]( // ::(
"1", "2", List[Object]( // 1, ::(2, ::(::(
"3", "4", List[Object]( // 3, ::(4, ::(::(
List[Object]() // Nil, Nil
) // ), Nil))
) // ), Nil))
) // )
// 7 cons + 4 string heads + 4 nils for nested lists

val max7 =
List[Object]( // ::(
"1", "2", List[Object]( // 1, ::(2, ::(::(
"3", "4", List[Object]( // 3, ::(4, ::(::(
"5" // 5, Nil
) // ), Nil))
) // ), Nil))
) // )
// 7 cons + 5 string heads + 3 nils for nested lists