Skip to content

Commit

Permalink
Make SigPoly symbols resistant to withPrefix changes
Browse files Browse the repository at this point in the history
  • Loading branch information
dwijnand committed Nov 5, 2022
1 parent f3946bc commit 3ad35a8
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 20 deletions.
3 changes: 1 addition & 2 deletions compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1238,8 +1238,7 @@ class TreeUnpickler(reader: TastyReader,
val fn = readTerm()
val methType = readType()
val args = until(end)(readTerm())
val sym2 = fn.symbol.copy(info = methType) // symbol not entered (same as in simpleApply)
val fun2 = fn.withType(sym2.termRef)
val fun2 = typer.Applications.retypeSignaturePolymorphicFn(fn, methType)
tpd.Apply(fun2, args)
case TYPED =>
val expr = readTerm()
Expand Down
9 changes: 7 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,12 @@ object Applications {
val getter = findDefaultGetter(fn, n, testOnly)
if getter.isEmpty then getter
else spliceMeth(getter.withSpan(fn.span), fn)

def retypeSignaturePolymorphicFn(fun: Tree, methType: Type)(using Context): Tree =
val sym1 = fun.symbol
val flags2 = sym1.flags | NonMember // ensures Select typing doesn't let TermRef#withPrefix revert the type
val sym2 = sym1.copy(info = methType, flags = flags2) // symbol not entered, to avoid overload resolution problems
fun.withType(sym2.termRef)
}

trait Applications extends Compatibility {
Expand Down Expand Up @@ -950,8 +956,7 @@ trait Applications extends Compatibility {
case resTp if isFullyDefined(resTp, ForceDegree.all) => resTp
case _ => defn.ObjectType
val methType = MethodType(proto.typedArgs().map(_.tpe.widen), resultType)
val sym2 = funRef.symbol.copy(info = methType) // symbol not entered, to avoid overload resolution problems
val fun2 = fun1.withType(sym2.termRef)
val fun2 = Applications.retypeSignaturePolymorphicFn(fun1, methType)
simpleApply(fun2, proto)
case funRef: TermRef =>
val app = ApplyTo(tree, fun1, funRef, proto, pt)
Expand Down
18 changes: 18 additions & 0 deletions tests/explicit-nulls/pos/i11332.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// scalajs: --skip
import scala.language.unsafeNulls

import java.lang.invoke._, MethodType.methodType

// A copy of tests/run/i11332.scala
// to test the bootstrap minimisation which failed
// (because bootstrap runs under explicit nulls)
class Foo:
def neg(x: Int): Int = -x

val l = MethodHandles.lookup()
val self = new Foo()

val test = // testing as a expression tree - previously derivedSelect broke the type
l
.findVirtual(classOf[Foo], "neg", methodType(classOf[Int], classOf[Int]))
.invokeExact(self, 4): Int
33 changes: 17 additions & 16 deletions tests/run/i11332.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class Foo {
def id[T](x: T): T = x

val l = MethodHandles.lookup()
val self = new Foo()
val mhNeg = l.findVirtual(classOf[Foo], "neg", methodType(classOf[Int], classOf[Int]))
val mhRev = l.findVirtual(classOf[Foo], "rev", methodType(classOf[String], classOf[String]))
val mhOverL = l.findVirtual(classOf[Foo], "over", methodType(classOf[String], classOf[Long]))
Expand All @@ -24,25 +25,25 @@ class Foo {
val mhObj = l.findVirtual(classOf[Foo], "obj", methodType(classOf[Any], classOf[String]))
val mhCL = l.findStatic(classOf[ClassLoader], "getPlatformClassLoader", methodType(classOf[ClassLoader]))

val testNeg1 = assert(-42 == (mhNeg.invokeExact(this, 42): Int))
val testNeg2 = assert(-33 == (mhNeg.invokeExact(this, 33): Int))
val testNeg1 = assert(-42 == (mhNeg.invokeExact(self, 42): Int))
val testNeg2 = assert(-33 == (mhNeg.invokeExact(self, 33): Int))

val testRev1 = assert("oof" == (mhRev.invokeExact(this, "foo"): String))
val testRev2 = assert("rab" == (mhRev.invokeExact(this, "bar"): String))
val testRev1 = assert("oof" == (mhRev.invokeExact(self, "foo"): String))
val testRev2 = assert("rab" == (mhRev.invokeExact(self, "bar"): String))

val testOverL = assert("long" == (mhOverL.invokeExact(this, 1L): String))
val testOVerI = assert("int" == (mhOverI.invokeExact(this, 1): String))
val testOverL = assert("long" == (mhOverL.invokeExact(self, 1L): String))
val testOVerI = assert("int" == (mhOverI.invokeExact(self, 1): String))

val testNeg_tvar = assert(-3 == (id(mhNeg.invokeExact(this, 3)): Int))
val testNeg_obj = expectWrongMethod(mhNeg.invokeExact(this, 4))
val testNeg_tvar = assert(-3 == (id(mhNeg.invokeExact(self, 3)): Int))
val testNeg_obj = expectWrongMethod(mhNeg.invokeExact(self, 4))

val testUnit_exp = { mhUnit.invokeExact(this, "hi"): Unit; () }
val testUnit_val = { val hi2: Unit = mhUnit.invokeExact(this, "hi2"); assert((()) == hi2) }
val testUnit_def = { def hi3: Unit = mhUnit.invokeExact(this, "hi3"); assert((()) == hi3) }
val testUnit_exp = { mhUnit.invokeExact(self, "hi"): Unit; () }
val testUnit_val = { val hi2: Unit = mhUnit.invokeExact(self, "hi2"); assert((()) == hi2) }
val testUnit_def = { def hi3: Unit = mhUnit.invokeExact(self, "hi3"); assert((()) == hi3) }

val testObj_exp = { mhObj.invokeExact(this, "any"); () }
val testObj_val = { val any2 = mhObj.invokeExact(this, "any2"); assert("any2" == any2) }
val testObj_def = { def any3 = mhObj.invokeExact(this, "any3"); assert("any3" == any3) }
val testObj_exp = { mhObj.invokeExact(self, "any"); () }
val testObj_val = { val any2 = mhObj.invokeExact(self, "any2"); assert("any2" == any2) }
val testObj_def = { def any3 = mhObj.invokeExact(self, "any3"); assert("any3" == any3) }

val testCl_pass = assert(null != (mhCL.invoke(): ClassLoader))
val testCl_cast = assert(null != (mhCL.invoke().asInstanceOf[ClassLoader]: ClassLoader))
Expand All @@ -51,10 +52,10 @@ class Foo {

val testNeg_inline_obj = expectWrongMethod(l
.findVirtual(classOf[Foo], "neg", methodType(classOf[Int], classOf[Int]))
.invokeExact(this, 3))
.invokeExact(self, 3))
val testNeg_inline_pass = assert(-4 == (l
.findVirtual(classOf[Foo], "neg", methodType(classOf[Int], classOf[Int]))
.invokeExact(this, 4): Int))
.invokeExact(self, 4): Int))

def expectWrongMethod(op: => Any) = try {
op
Expand Down

0 comments on commit 3ad35a8

Please sign in to comment.