Skip to content

Commit

Permalink
Support signature polymorphic methods (MethodHandle and VarHandle) (
Browse files Browse the repository at this point in the history
#16225)

fixes #11332
  • Loading branch information
dwijnand authored Nov 22, 2022
2 parents d7e4f94 + 1b0b830 commit 439a17d
Show file tree
Hide file tree
Showing 14 changed files with 193 additions and 8 deletions.
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/backend/jvm/CoreBTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ class CoreBTypes[BTFS <: BTypesFromSymbols[_ <: DottyBackendInterface]](val bTyp

private lazy val jliCallSiteRef : ClassBType = classBTypeFromSymbol(requiredClass[java.lang.invoke.CallSite])
private lazy val jliLambdaMetafactoryRef : ClassBType = classBTypeFromSymbol(requiredClass[java.lang.invoke.LambdaMetafactory])
private lazy val jliMethodHandleRef : ClassBType = classBTypeFromSymbol(requiredClass[java.lang.invoke.MethodHandle])
private lazy val jliMethodHandlesLookupRef : ClassBType = classBTypeFromSymbol(requiredClass[java.lang.invoke.MethodHandles.Lookup])
private lazy val jliMethodHandleRef : ClassBType = classBTypeFromSymbol(defn.MethodHandleClass)
private lazy val jliMethodHandlesLookupRef : ClassBType = classBTypeFromSymbol(defn.MethodHandlesLookupClass)
private lazy val jliMethodTypeRef : ClassBType = classBTypeFromSymbol(requiredClass[java.lang.invoke.MethodType])
private lazy val jliStringConcatFactoryRef : ClassBType = classBTypeFromSymbol(requiredClass("java.lang.invoke.StringConcatFactory")) // since JDK 9
private lazy val srLambdaDeserialize : ClassBType = classBTypeFromSymbol(requiredClass[scala.runtime.LambdaDeserialize])
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,10 @@ class Definitions {
}
def JavaEnumType = JavaEnumClass.typeRef

@tu lazy val MethodHandleClass: ClassSymbol = requiredClass("java.lang.invoke.MethodHandle")
@tu lazy val MethodHandlesLookupClass: ClassSymbol = requiredClass("java.lang.invoke.MethodHandles.Lookup")
@tu lazy val VarHandleClass: ClassSymbol = requiredClass("java.lang.invoke.VarHandle")

@tu lazy val StringBuilderClass: ClassSymbol = requiredClass("scala.collection.mutable.StringBuilder")
@tu lazy val MatchErrorClass : ClassSymbol = requiredClass("scala.MatchError")
@tu lazy val ConversionClass : ClassSymbol = requiredClass("scala.Conversion").typeRef.symbol.asClass
Expand Down
20 changes: 20 additions & 0 deletions compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -960,6 +960,26 @@ object SymDenotations {

def isSkolem: Boolean = name == nme.SKOLEM

// Java language spec: https://docs.oracle.com/javase/specs/jls/se11/html/jls-15.html#jls-15.12.3
// Scala 2 spec: https://scala-lang.org/files/archive/spec/2.13/06-expressions.html#signature-polymorphic-methods
def isSignaturePolymorphic(using Context): Boolean =
containsSignaturePolymorphic
&& is(JavaDefined)
&& hasAnnotation(defn.NativeAnnot)
&& atPhase(typerPhase)(symbol.denot).paramSymss.match
case List(List(p)) => p.info.isRepeatedParam
case _ => false

def containsSignaturePolymorphic(using Context): Boolean =
maybeOwner == defn.MethodHandleClass
|| maybeOwner == defn.VarHandleClass

def originalSignaturePolymorphic(using Context): Denotation =
if containsSignaturePolymorphic && !isSignaturePolymorphic then
val d = owner.info.member(name)
if d.symbol.isSignaturePolymorphic then d else NoDenotation
else NoDenotation

def isInlineMethod(using Context): Boolean =
isAllOf(InlineMethod, butNot = Accessor)

Expand Down
7 changes: 7 additions & 0 deletions compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,13 @@ class TreePickler(pickler: TastyPickler) {
writeByte(THROW)
pickleTree(args.head)
}
else if fun.symbol.originalSignaturePolymorphic.exists then
writeByte(APPLYsigpoly)
withLength {
pickleTree(fun)
pickleType(fun.tpe.widenTermRefExpr, richTypes = true) // this widens to a MethodType, so need richTypes
args.foreach(pickleTree)
}
else {
writeByte(APPLY)
withLength {
Expand Down
6 changes: 6 additions & 0 deletions compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,12 @@ class TreeUnpickler(reader: TastyReader,
else tpd.Apply(fn, args)
case TYPEAPPLY =>
tpd.TypeApply(readTerm(), until(end)(readTpt()))
case APPLYsigpoly =>
val fn = readTerm()
val methType = readType()
val args = until(end)(readTerm())
val fun2 = typer.Applications.retypeSignaturePolymorphicFn(fn, methType)
tpd.Apply(fun2, args)
case TYPED =>
val expr = readTerm()
val tpt = readTpt()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ class SemanticSymbolBuilder:
def addOwner(owner: Symbol): Unit =
if !owner.isRoot then addSymName(b, owner)

def addOverloadIdx(sym: Symbol): Unit =
def addOverloadIdx(initSym: Symbol): Unit =
// revert from the compiler-generated overload of the signature polymorphic method
val sym = initSym.originalSignaturePolymorphic.symbol.orElse(initSym)
val decls =
val decls0 = sym.owner.info.decls.lookupAll(sym.name)
if sym.owner.isAllOf(JavaModule) then
Expand Down
4 changes: 3 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/Recheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,9 @@ abstract class Recheck extends Phase, SymTransformer:
mt.instantiate(argTypes)

def recheckApply(tree: Apply, pt: Type)(using Context): Type =
val funtpe = recheck(tree.fun)
val funTp = recheck(tree.fun)
// reuse the tree's type on signature polymorphic methods, instead of using the (wrong) rechecked one
val funtpe = if tree.fun.symbol.originalSignaturePolymorphic.exists then tree.fun.tpe else funTp
funtpe.widen match
case fntpe: MethodType =>
assert(fntpe.paramInfos.hasSameLengthAs(tree.args))
Expand Down
23 changes: 22 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import Inferencing._
import reporting._
import transform.TypeUtils._
import transform.SymUtils._
import Nullables._
import Nullables._, NullOpsDecorator.*
import config.Feature

import collection.mutable
Expand Down Expand Up @@ -340,6 +340,12 @@ object Applications {
val getter = findDefaultGetter(fn, n, testOnly)
if getter.isEmpty then getter
else spliceMeth(getter.withSpan(fn.span), fn)

def retypeSignaturePolymorphicFn(fun: Tree, methType: Type)(using Context): Tree =
val sym1 = fun.symbol
val flags2 = sym1.flags | NonMember // ensures Select typing doesn't let TermRef#withPrefix revert the type
val sym2 = sym1.copy(info = methType, flags = flags2) // symbol not entered, to avoid overload resolution problems
fun.withType(sym2.termRef)
}

trait Applications extends Compatibility {
Expand Down Expand Up @@ -936,6 +942,21 @@ trait Applications extends Compatibility {
/** Type application where arguments come from prototype, and no implicits are inserted */
def simpleApply(fun1: Tree, proto: FunProto)(using Context): Tree =
methPart(fun1).tpe match {
case funRef: TermRef if funRef.symbol.isSignaturePolymorphic =>
// synthesize a method type based on the types at the call site.
// one can imagine the original signature-polymorphic method as
// being infinitely overloaded, with each individual overload only
// being brought into existence as needed
val originalResultType = funRef.symbol.info.resultType.stripNull
val resultType =
if !originalResultType.isRef(defn.ObjectClass) then originalResultType
else AvoidWildcardsMap()(proto.resultType.deepenProtoTrans) match
case SelectionProto(nme.asInstanceOf_, PolyProto(_, resTp), _, _) => resTp
case resTp if isFullyDefined(resTp, ForceDegree.all) => resTp
case _ => defn.ObjectType
val methType = MethodType(proto.typedArgs().map(_.tpe.widen), resultType)
val fun2 = Applications.retypeSignaturePolymorphicFn(fun1, methType)
simpleApply(fun2, proto)
case funRef: TermRef =>
val app = ApplyTo(tree, fun1, funRef, proto, pt)
convertNewGenericArray(
Expand Down
7 changes: 4 additions & 3 deletions project/Build.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1823,9 +1823,10 @@ object Build {
settings(disableDocSetting).
settings(
versionScheme := Some("semver-spec"),
if (mode == Bootstrapped) {
commonMiMaSettings
} else {
if (mode == Bootstrapped) Def.settings(
commonMiMaSettings,
mimaBinaryIssueFilters ++= MiMaFilters.TastyCore,
) else {
Nil
}
)
Expand Down
3 changes: 3 additions & 0 deletions project/MiMaFilters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,7 @@ object MiMaFilters {
ProblemFilters.exclude[MissingClassProblem]("scala.caps$Pure"),
ProblemFilters.exclude[MissingClassProblem]("scala.caps$unsafe$"),
)
val TastyCore: Seq[ProblemFilter] = Seq(
ProblemFilters.exclude[MissingMethodProblem]("dotty.tools.tasty.TastyFormat.APPLYsigpoly"),
)
}
3 changes: 3 additions & 0 deletions tasty/src/dotty/tools/tasty/TastyFormat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ Standard-Section: "ASTs" TopLevelStat*
THROW throwableExpr_Term -- throw throwableExpr
NAMEDARG paramName_NameRef arg_Term -- paramName = arg
APPLY Length fn_Term arg_Term* -- fn(args)
APPLYsigpoly Length fn_Term meth_Type arg_Term* -- The application of a signature-polymorphic method
TYPEAPPLY Length fn_Term arg_Type* -- fn[args]
SUPER Length this_Term mixinTypeIdent_Tree? -- super[mixin]
TYPED Length expr_Term ascriptionType_Term -- expr: ascription
Expand Down Expand Up @@ -578,6 +579,7 @@ object TastyFormat {
// final val ??? = 178
// final val ??? = 179
final val METHODtype = 180
final val APPLYsigpoly = 181

final val MATCHtype = 190
final val MATCHtpt = 191
Expand Down Expand Up @@ -744,6 +746,7 @@ object TastyFormat {
case BOUNDED => "BOUNDED"
case APPLY => "APPLY"
case TYPEAPPLY => "TYPEAPPLY"
case APPLYsigpoly => "APPLYsigpoly"
case NEW => "NEW"
case THROW => "THROW"
case TYPED => "TYPED"
Expand Down
22 changes: 22 additions & 0 deletions tests/explicit-nulls/run/i11332.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// scalajs: --skip
import scala.language.unsafeNulls

import java.lang.invoke._, MethodType.methodType

// A copy of tests/run/i11332.scala
// to test the bootstrap minimisation which failed
// (because bootstrap runs under explicit nulls)
class Foo:
def neg(x: Int): Int = -x

object Test:
def main(args: Array[String]): Unit =
val l = MethodHandles.lookup()
val self = new Foo()

val res4 = {
l // explicit chain method call - previously derivedSelect broke the type
.findVirtual(classOf[Foo], "neg", methodType(classOf[Int], classOf[Int]))
.invokeExact(self, 4): Int
}
assert(-4 == res4)
72 changes: 72 additions & 0 deletions tests/run/i11332.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// scalajs: --skip
import scala.language.unsafeNulls

import java.lang.invoke._, MethodType.methodType

class Foo:
def neg(x: Int): Int = -x
def rev(s: String): String = s.reverse
def over(l: Long): String = "long"
def over(i: Int): String = "int"
def unit(s: String): Unit = ()
def obj(s: String): Object = s

object Test:
def main(args: Array[String]): Unit =
val l = MethodHandles.lookup()
val self = new Foo()
val mhNeg = l.findVirtual(classOf[Foo], "neg", methodType(classOf[Int], classOf[Int]))
val mhRev = l.findVirtual(classOf[Foo], "rev", methodType(classOf[String], classOf[String]))
val mhOverL = l.findVirtual(classOf[Foo], "over", methodType(classOf[String], classOf[Long]))
val mhOverI = l.findVirtual(classOf[Foo], "over", methodType(classOf[String], classOf[Int]))
val mhUnit = l.findVirtual(classOf[Foo], "unit", methodType(classOf[Unit], classOf[String]))
val mhObj = l.findVirtual(classOf[Foo], "obj", methodType(classOf[Any], classOf[String]))
val mhCL = l.findStatic(classOf[ClassLoader], "getPlatformClassLoader", methodType(classOf[ClassLoader]))

assert(-42 == (mhNeg.invokeExact(self, 42): Int))
assert(-33 == (mhNeg.invokeExact(self, 33): Int))

assert("oof" == (mhRev.invokeExact(self, "foo"): String))
assert("rab" == (mhRev.invokeExact(self, "bar"): String))

assert("long" == (mhOverL.invokeExact(self, 1L): String))
assert("int" == (mhOverI.invokeExact(self, 1): String))

assert(-3 == (id(mhNeg.invokeExact(self, 3)): Int))
expectWrongMethod(mhNeg.invokeExact(self, 4))

{ mhUnit.invokeExact(self, "hi"): Unit; () } // explicit block
val hi2: Unit = mhUnit.invokeExact(self, "hi2")
assert((()) == hi2)
def hi3: Unit = mhUnit.invokeExact(self, "hi3")
assert((()) == hi3)

{ mhObj.invokeExact(self, "any"); () } // explicit block
val any2 = mhObj.invokeExact(self, "any2")
assert("any2" == any2)
def any3 = mhObj.invokeExact(self, "any3")
assert("any3" == any3)

assert(null != (mhCL.invoke(): ClassLoader))
assert(null != (mhCL.invoke().asInstanceOf[ClassLoader]: ClassLoader))
assert(null != (mhCL.invokeExact(): ClassLoader))
assert(null != (mhCL.invokeExact().asInstanceOf[ClassLoader]: ClassLoader))

expectWrongMethod {
l // explicit chain method call
.findVirtual(classOf[Foo], "neg", methodType(classOf[Int], classOf[Int]))
.invokeExact(self, 3)
}
val res4 = {
l // explicit chain method call
.findVirtual(classOf[Foo], "neg", methodType(classOf[Int], classOf[Int]))
.invokeExact(self, 4): Int
}
assert(-4 == res4)

def id[T](x: T): T = x

def expectWrongMethod(op: => Any) = try {
op
throw new AssertionError("expected operation to fail but it didn't")
} catch case expected: WrongMethodTypeException => ()
22 changes: 22 additions & 0 deletions tests/run/t12348.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// test: -jvm 11+
// scalajs: --skip
import java.lang.invoke._
import scala.runtime.IntRef

object Test {
def main(args: Array[String]): Unit = {
val ref = new scala.runtime.IntRef(0)
val varHandle = MethodHandles.lookup()
.in(classOf[IntRef])
.findVarHandle(classOf[IntRef], "elem", classOf[Int])
assert(0 == (varHandle.getAndSet(ref, 1): Int))
assert(1 == (varHandle.getAndSet(ref, 2): Int))
assert(2 == ref.elem)

assert((()) == (varHandle.set(ref, 3): Any))
assert(3 == (varHandle.get(ref): Int))

assert(true == (varHandle.compareAndSet(ref, 3, 4): Any))
assert(4 == (varHandle.get(ref): Int))
}
}

0 comments on commit 439a17d

Please sign in to comment.