From 3ad35a8d67fa6242176e2ae539e678520383f9e0 Mon Sep 17 00:00:00 2001 From: Dale Wijnand Date: Fri, 4 Nov 2022 16:03:45 +0000 Subject: [PATCH] Make SigPoly symbols resistant to withPrefix changes --- .../tools/dotc/core/tasty/TreeUnpickler.scala | 3 +- .../dotty/tools/dotc/typer/Applications.scala | 9 +++-- tests/explicit-nulls/pos/i11332.scala | 18 ++++++++++ tests/run/i11332.scala | 33 ++++++++++--------- 4 files changed, 43 insertions(+), 20 deletions(-) create mode 100644 tests/explicit-nulls/pos/i11332.scala diff --git a/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala b/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala index 119cd6ba47ba..4c2e8bd9c9bb 100644 --- a/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala +++ b/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala @@ -1238,8 +1238,7 @@ class TreeUnpickler(reader: TastyReader, val fn = readTerm() val methType = readType() val args = until(end)(readTerm()) - val sym2 = fn.symbol.copy(info = methType) // symbol not entered (same as in simpleApply) - val fun2 = fn.withType(sym2.termRef) + val fun2 = typer.Applications.retypeSignaturePolymorphicFn(fn, methType) tpd.Apply(fun2, args) case TYPED => val expr = readTerm() diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index 5c42f7c8d9d8..18cc2d4d83b9 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -341,6 +341,12 @@ object Applications { val getter = findDefaultGetter(fn, n, testOnly) if getter.isEmpty then getter else spliceMeth(getter.withSpan(fn.span), fn) + + def retypeSignaturePolymorphicFn(fun: Tree, methType: Type)(using Context): Tree = + val sym1 = fun.symbol + val flags2 = sym1.flags | NonMember // ensures Select typing doesn't let TermRef#withPrefix revert the type + val sym2 = sym1.copy(info = methType, flags = flags2) // symbol not entered, to avoid overload resolution problems + fun.withType(sym2.termRef) } trait Applications extends Compatibility { @@ -950,8 +956,7 @@ trait Applications extends Compatibility { case resTp if isFullyDefined(resTp, ForceDegree.all) => resTp case _ => defn.ObjectType val methType = MethodType(proto.typedArgs().map(_.tpe.widen), resultType) - val sym2 = funRef.symbol.copy(info = methType) // symbol not entered, to avoid overload resolution problems - val fun2 = fun1.withType(sym2.termRef) + val fun2 = Applications.retypeSignaturePolymorphicFn(fun1, methType) simpleApply(fun2, proto) case funRef: TermRef => val app = ApplyTo(tree, fun1, funRef, proto, pt) diff --git a/tests/explicit-nulls/pos/i11332.scala b/tests/explicit-nulls/pos/i11332.scala new file mode 100644 index 000000000000..a773a55e987e --- /dev/null +++ b/tests/explicit-nulls/pos/i11332.scala @@ -0,0 +1,18 @@ +// scalajs: --skip +import scala.language.unsafeNulls + +import java.lang.invoke._, MethodType.methodType + +// A copy of tests/run/i11332.scala +// to test the bootstrap minimisation which failed +// (because bootstrap runs under explicit nulls) +class Foo: + def neg(x: Int): Int = -x + + val l = MethodHandles.lookup() + val self = new Foo() + + val test = // testing as a expression tree - previously derivedSelect broke the type + l + .findVirtual(classOf[Foo], "neg", methodType(classOf[Int], classOf[Int])) + .invokeExact(self, 4): Int diff --git a/tests/run/i11332.scala b/tests/run/i11332.scala index 20d994bf2d63..18ff61b5f660 100644 --- a/tests/run/i11332.scala +++ b/tests/run/i11332.scala @@ -16,6 +16,7 @@ class Foo { def id[T](x: T): T = x val l = MethodHandles.lookup() + val self = new Foo() val mhNeg = l.findVirtual(classOf[Foo], "neg", methodType(classOf[Int], classOf[Int])) val mhRev = l.findVirtual(classOf[Foo], "rev", methodType(classOf[String], classOf[String])) val mhOverL = l.findVirtual(classOf[Foo], "over", methodType(classOf[String], classOf[Long])) @@ -24,25 +25,25 @@ class Foo { val mhObj = l.findVirtual(classOf[Foo], "obj", methodType(classOf[Any], classOf[String])) val mhCL = l.findStatic(classOf[ClassLoader], "getPlatformClassLoader", methodType(classOf[ClassLoader])) - val testNeg1 = assert(-42 == (mhNeg.invokeExact(this, 42): Int)) - val testNeg2 = assert(-33 == (mhNeg.invokeExact(this, 33): Int)) + val testNeg1 = assert(-42 == (mhNeg.invokeExact(self, 42): Int)) + val testNeg2 = assert(-33 == (mhNeg.invokeExact(self, 33): Int)) - val testRev1 = assert("oof" == (mhRev.invokeExact(this, "foo"): String)) - val testRev2 = assert("rab" == (mhRev.invokeExact(this, "bar"): String)) + val testRev1 = assert("oof" == (mhRev.invokeExact(self, "foo"): String)) + val testRev2 = assert("rab" == (mhRev.invokeExact(self, "bar"): String)) - val testOverL = assert("long" == (mhOverL.invokeExact(this, 1L): String)) - val testOVerI = assert("int" == (mhOverI.invokeExact(this, 1): String)) + val testOverL = assert("long" == (mhOverL.invokeExact(self, 1L): String)) + val testOVerI = assert("int" == (mhOverI.invokeExact(self, 1): String)) - val testNeg_tvar = assert(-3 == (id(mhNeg.invokeExact(this, 3)): Int)) - val testNeg_obj = expectWrongMethod(mhNeg.invokeExact(this, 4)) + val testNeg_tvar = assert(-3 == (id(mhNeg.invokeExact(self, 3)): Int)) + val testNeg_obj = expectWrongMethod(mhNeg.invokeExact(self, 4)) - val testUnit_exp = { mhUnit.invokeExact(this, "hi"): Unit; () } - val testUnit_val = { val hi2: Unit = mhUnit.invokeExact(this, "hi2"); assert((()) == hi2) } - val testUnit_def = { def hi3: Unit = mhUnit.invokeExact(this, "hi3"); assert((()) == hi3) } + val testUnit_exp = { mhUnit.invokeExact(self, "hi"): Unit; () } + val testUnit_val = { val hi2: Unit = mhUnit.invokeExact(self, "hi2"); assert((()) == hi2) } + val testUnit_def = { def hi3: Unit = mhUnit.invokeExact(self, "hi3"); assert((()) == hi3) } - val testObj_exp = { mhObj.invokeExact(this, "any"); () } - val testObj_val = { val any2 = mhObj.invokeExact(this, "any2"); assert("any2" == any2) } - val testObj_def = { def any3 = mhObj.invokeExact(this, "any3"); assert("any3" == any3) } + val testObj_exp = { mhObj.invokeExact(self, "any"); () } + val testObj_val = { val any2 = mhObj.invokeExact(self, "any2"); assert("any2" == any2) } + val testObj_def = { def any3 = mhObj.invokeExact(self, "any3"); assert("any3" == any3) } val testCl_pass = assert(null != (mhCL.invoke(): ClassLoader)) val testCl_cast = assert(null != (mhCL.invoke().asInstanceOf[ClassLoader]: ClassLoader)) @@ -51,10 +52,10 @@ class Foo { val testNeg_inline_obj = expectWrongMethod(l .findVirtual(classOf[Foo], "neg", methodType(classOf[Int], classOf[Int])) - .invokeExact(this, 3)) + .invokeExact(self, 3)) val testNeg_inline_pass = assert(-4 == (l .findVirtual(classOf[Foo], "neg", methodType(classOf[Int], classOf[Int])) - .invokeExact(this, 4): Int)) + .invokeExact(self, 4): Int)) def expectWrongMethod(op: => Any) = try { op