Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check method arguments with parametricity when static #14916

Merged
merged 1 commit into from
May 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 38 additions & 10 deletions compiler/src/dotty/tools/dotc/transform/init/Semantic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -410,8 +410,8 @@ object Semantic {
def select(f: Symbol, source: Tree): Contextual[Result] =
value.select(f, source) ++ errors

def call(meth: Symbol, args: List[ArgInfo], superType: Type, source: Tree): Contextual[Result] =
value.call(meth, args, superType, source) ++ errors
def call(meth: Symbol, args: List[ArgInfo], receiver: Type, superType: Type, source: Tree): Contextual[Result] =
value.call(meth, args, receiver, superType, source) ++ errors

def callConstructor(ctor: Symbol, args: List[ArgInfo], source: Tree): Contextual[Result] =
value.callConstructor(ctor, args, source) ++ errors
Expand Down Expand Up @@ -587,7 +587,7 @@ object Semantic {
}
}

def call(meth: Symbol, args: List[ArgInfo], superType: Type, source: Tree, needResolve: Boolean = true): Contextual[Result] = log("call " + meth.show + ", args = " + args, printer, (_: Result).show) {
def call(meth: Symbol, args: List[ArgInfo], receiver: Type, superType: Type, source: Tree, needResolve: Boolean = true): Contextual[Result] = log("call " + meth.show + ", args = " + args, printer, (_: Result).show) {
def checkArgs = args.flatMap(_.promote)

def isSyntheticApply(meth: Symbol) =
Expand All @@ -600,6 +600,27 @@ object Semantic {
|| (meth eq defn.Object_ne)
|| (meth eq defn.Any_isInstanceOf)

def checkArgsWithParametricity() =
val methodType = atPhaseBeforeTransforms { meth.info.stripPoly }
var allArgsPromote = true
val allParamTypes = methodType.paramInfoss.flatten.map(_.repeatedToSingle)
val errors = allParamTypes.zip(args).flatMap { (info, arg) =>
val errors = arg.promote
allArgsPromote = allArgsPromote && errors.isEmpty
info match
case typeParamRef: TypeParamRef =>
val bounds = typeParamRef.underlying.bounds
val isWithinBounds = bounds.lo <:< defn.NothingType && defn.AnyType <:< bounds.hi
def otherParamContains = allParamTypes.exists { param => param != info && param.typeSymbol != defn.ClassTagClass && info.occursIn(param) }
// A non-hot method argument is allowed if the corresponding parameter type is a
// type parameter T with Any as its upper bound and Nothing as its lower bound.
// the other arguments should either correspond to a parameter type that is T
// or that does not contain T as a component.
if isWithinBounds && !otherParamContains then Nil else errors
case _ => errors
}
(errors, allArgsPromote)

// fast track if the current object is already initialized
if promoted.isCurrentObjectPromoted then Result(Hot, Nil)
else if isAlwaysSafe(meth) then Result(Hot, Nil)
Expand All @@ -610,7 +631,14 @@ object Semantic {
val klass = meth.owner.companionClass.asClass
instantiate(klass, klass.primaryConstructor, args, source)
else
Result(Hot, checkArgs)
if receiver.typeSymbol.isStaticOwner then
val (errors, allArgsPromote) = checkArgsWithParametricity()
if allArgsPromote || errors.nonEmpty then
Result(Hot, errors)
else
Result(Cold, errors)
else
Result(Hot, checkArgs)

case Cold =>
val error = CallCold(meth, source, trace.toVector)
Expand Down Expand Up @@ -666,7 +694,7 @@ object Semantic {
}

case RefSet(refs) =>
val resList = refs.map(_.call(meth, args, superType, source))
val resList = refs.map(_.call(meth, args, receiver, superType, source))
val value2 = resList.map(_.value).join
val errors = resList.flatMap(_.errors)
Result(value2, errors)
Expand Down Expand Up @@ -946,7 +974,7 @@ object Semantic {
locally {
given Trace = trace2
val args = member.info.paramInfoss.flatten.map(_ => ArgInfo(Hot, EmptyTree))
val res = warm.call(member, args, superType = NoType, source = member.defTree)
val res = warm.call(member, args, receiver = NoType, superType = NoType, source = member.defTree)
buffer ++= res.ensureHot(msg, source).errors
}
else
Expand Down Expand Up @@ -1126,14 +1154,14 @@ object Semantic {
case Select(supert: Super, _) =>
val SuperType(thisTp, superTp) = supert.tpe
val thisValue2 = resolveThis(thisTp.classSymbol.asClass, thisV, klass, ref)
Result(thisValue2, errors).call(ref.symbol, args, superTp, expr)
Result(thisValue2, errors).call(ref.symbol, args, thisTp, superTp, expr)

case Select(qual, _) =>
val res = eval(qual, thisV, klass) ++ errors
if ref.symbol.isConstructor then
res.callConstructor(ref.symbol, args, source = expr)
else
res.call(ref.symbol, args, superType = NoType, source = expr)
res.call(ref.symbol, args, receiver = qual.tpe, superType = NoType, source = expr)

case id: Ident =>
id.tpe match
Expand All @@ -1142,13 +1170,13 @@ object Semantic {
val enclosingClass = id.symbol.owner.enclosingClass.asClass
val thisValue2 = resolveThis(enclosingClass, thisV, klass, id)
// local methods are not a member, but we can reuse the method `call`
thisValue2.call(id.symbol, args, superType = NoType, expr, needResolve = false)
thisValue2.call(id.symbol, args, receiver = NoType, superType = NoType, expr, needResolve = false)
case TermRef(prefix, _) =>
val res = cases(prefix, thisV, klass, id) ++ errors
if id.symbol.isConstructor then
res.callConstructor(id.symbol, args, source = expr)
else
res.call(id.symbol, args, superType = NoType, source = expr)
res.call(id.symbol, args, receiver = prefix, superType = NoType, source = expr)

case Select(qualifier, name) =>
val qualRes = eval(qualifier, thisV, klass)
Expand Down
4 changes: 2 additions & 2 deletions tests/init/neg/early-promote.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class Y {

val n = 10
val x = new X
List(x.b) // unsafe promotion
println(x.b) // unsafe promotion

}

Expand All @@ -24,7 +24,7 @@ class A { // checking A
def c = new C
}
val b = new B()
List(b) // error: the checker simply issue warnings for objects that contain inner classes
println(b) // error: the checker simply issue warnings for objects that contain inner classes
val af = 42
}

Expand Down
11 changes: 0 additions & 11 deletions tests/init/neg/enum-desugared.check

This file was deleted.

9 changes: 0 additions & 9 deletions tests/init/neg/enum.check

This file was deleted.

4 changes: 2 additions & 2 deletions tests/init/neg/inner-case.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ class Foo {
}

val a = Inner(5) // ok
println(a) // error
println(a) // error

var count = 0
println(a) // ok
println(a) // ok
}
6 changes: 3 additions & 3 deletions tests/init/neg/inner-new.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ class Foo {
def f() = count + 1
}

val a = new Inner // ok
println(a) // error
val a = new Inner // ok
println(a) // error

var count = 0
println(a) // ok
println(a) // ok
}
8 changes: 8 additions & 0 deletions tests/init/neg/insert-cold-subtype-to-array.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
object A:
def foo[T, S <: T](x: S, array: Array[T]): Unit = array(0) = x

class B:
var a = new Array[B](2)
A.foo(this, a) // error
println(a(0).i)
val i = 99
13 changes: 9 additions & 4 deletions tests/init/neg/leak-warm.check
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
-- Error: tests/init/neg/leak-warm.scala:18:26 -------------------------------------------------------------------------
18 | val l: List[A] = List(c, d) // error
| ^^^^
| Cannot prove that the value is fully initialized. Only initialized values may be used as arguments.
-- Error: tests/init/neg/leak-warm.scala:19:18 -------------------------------------------------------------------------
19 | val l2 = l.map(_.m()) // error
| ^^^^^^^^^^^^
| Call method leakWarm.l.map[leakWarm.A#B](
| {
| def $anonfun(_$1: leakWarm.A): leakWarm.A#B = _$1.m()
| closure($anonfun)
| }
| ) on a value with an unknown initialization.
4 changes: 2 additions & 2 deletions tests/init/neg/leak-warm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ object leakWarm {
}
val c = new C(1, 2)
val d = new D(3, 4)
val l: List[A] = List(c, d) // error
val l2 = l.map(_.m())
val l: List[A] = List(c, d)
val l2 = l.map(_.m()) // error
}
2 changes: 2 additions & 0 deletions tests/init/neg/some-this.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class X:
val some = Some(this) // error
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ object ErrorMessageID {
final val NoExplanationID = $new(1, "NoExplanationID")

private[this] val $values: Array[ErrorMessageID] =
Array(this.LazyErrorId, this.NoExplanationID) // error
Array(this.LazyErrorId, this.NoExplanationID)

def values: Array[ErrorMessageID] = $values.clone()

Expand Down
File renamed without changes.
3 changes: 3 additions & 0 deletions tests/init/pos/inner-enum-multi-variant.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
class Outer:
enum Color:
case Red, Blue
4 changes: 4 additions & 0 deletions tests/init/pos/inner-enum.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
class Outer:
enum MyEnum {
case Case
}