Skip to content

Commit

Permalink
Check method arguments with parametricity when static
Browse files Browse the repository at this point in the history
When a global static is called, allow for a cold argument if the corresponding parameter is not `Matchable`.
  • Loading branch information
Xavientois committed May 16, 2022
1 parent e35b6ff commit fcf823d
Show file tree
Hide file tree
Showing 13 changed files with 63 additions and 44 deletions.
45 changes: 35 additions & 10 deletions compiler/src/dotty/tools/dotc/transform/init/Semantic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -410,8 +410,8 @@ object Semantic {
def select(f: Symbol, source: Tree): Contextual[Result] =
value.select(f, source) ++ errors

def call(meth: Symbol, args: List[ArgInfo], superType: Type, source: Tree): Contextual[Result] =
value.call(meth, args, superType, source) ++ errors
def call(meth: Symbol, args: List[ArgInfo], receiver: Type, superType: Type, source: Tree): Contextual[Result] =
value.call(meth, args, receiver, superType, source) ++ errors

def callConstructor(ctor: Symbol, args: List[ArgInfo], source: Tree): Contextual[Result] =
value.callConstructor(ctor, args, source) ++ errors
Expand Down Expand Up @@ -587,7 +587,7 @@ object Semantic {
}
}

def call(meth: Symbol, args: List[ArgInfo], superType: Type, source: Tree, needResolve: Boolean = true): Contextual[Result] = log("call " + meth.show + ", args = " + args, printer, (_: Result).show) {
def call(meth: Symbol, args: List[ArgInfo], receiver: Type, superType: Type, source: Tree, needResolve: Boolean = true): Contextual[Result] = log("call " + meth.show + ", args = " + args, printer, (_: Result).show) {
def checkArgs = args.flatMap(_.promote)

def isSyntheticApply(meth: Symbol) =
Expand All @@ -600,6 +600,24 @@ object Semantic {
|| (meth eq defn.Object_ne)
|| (meth eq defn.Any_isInstanceOf)

def checkArgsWithParametricity() =
val methodType = atPhaseBeforeTransforms { meth.info.stripPoly }
var allArgsPromote = true
var allParamTypes = methodType.paramInfoss.flatten.map(_.repeatedToSingle)
val errors = allParamTypes.zip(args).flatMap { (info, arg) =>
val isTypeParam = info.isInstanceOf[TypeParamRef]
val isWithinBounds = info.bounds.lo <:< defn.NothingType && defn.AnyType <:< info.bounds.hi
def otherParamContains = allParamTypes.exists { param => param != info && param.typeSymbol != defn.ClassTagClass && info.occursIn(param) }
val errors = arg.promote
allArgsPromote = allArgsPromote && errors.isEmpty
// A non-hot method argument is allowed if the corresponding parameter type is a
// type parameter T with Any as its upper bound and Nothing as its lower bound.
// the other arguments should either correspond to a parameter type that is T
// or that does not contain T as a component.
if isTypeParam && !isWithinBounds && !otherParamContains then Nil else errors
}
(errors, allArgsPromote)

// fast track if the current object is already initialized
if promoted.isCurrentObjectPromoted then Result(Hot, Nil)
else if isAlwaysSafe(meth) then Result(Hot, Nil)
Expand All @@ -610,7 +628,14 @@ object Semantic {
val klass = meth.owner.companionClass.asClass
instantiate(klass, klass.primaryConstructor, args, source)
else
Result(Hot, checkArgs)
if receiver.typeSymbol.isStaticOwner then
val (errors, allArgsPromote) = checkArgsWithParametricity()
if allArgsPromote || errors.nonEmpty then
Result(Hot, errors)
else
Result(Cold, errors)
else
Result(Hot, checkArgs)

case Cold =>
val error = CallCold(meth, source, trace.toVector)
Expand Down Expand Up @@ -666,7 +691,7 @@ object Semantic {
}

case RefSet(refs) =>
val resList = refs.map(_.call(meth, args, superType, source))
val resList = refs.map(_.call(meth, args, receiver, superType, source))
val value2 = resList.map(_.value).join
val errors = resList.flatMap(_.errors)
Result(value2, errors)
Expand Down Expand Up @@ -946,7 +971,7 @@ object Semantic {
locally {
given Trace = trace2
val args = member.info.paramInfoss.flatten.map(_ => ArgInfo(Hot, EmptyTree))
val res = warm.call(member, args, superType = NoType, source = member.defTree)
val res = warm.call(member, args, receiver = NoType, superType = NoType, source = member.defTree)
buffer ++= res.ensureHot(msg, source).errors
}
else
Expand Down Expand Up @@ -1126,14 +1151,14 @@ object Semantic {
case Select(supert: Super, _) =>
val SuperType(thisTp, superTp) = supert.tpe
val thisValue2 = resolveThis(thisTp.classSymbol.asClass, thisV, klass, ref)
Result(thisValue2, errors).call(ref.symbol, args, superTp, expr)
Result(thisValue2, errors).call(ref.symbol, args, thisTp, superTp, expr)

case Select(qual, _) =>
val res = eval(qual, thisV, klass) ++ errors
if ref.symbol.isConstructor then
res.callConstructor(ref.symbol, args, source = expr)
else
res.call(ref.symbol, args, superType = NoType, source = expr)
res.call(ref.symbol, args, receiver = qual.tpe, superType = NoType, source = expr)

case id: Ident =>
id.tpe match
Expand All @@ -1142,13 +1167,13 @@ object Semantic {
val enclosingClass = id.symbol.owner.enclosingClass.asClass
val thisValue2 = resolveThis(enclosingClass, thisV, klass, id)
// local methods are not a member, but we can reuse the method `call`
thisValue2.call(id.symbol, args, superType = NoType, expr, needResolve = false)
thisValue2.call(id.symbol, args, receiver = NoType, superType = NoType, expr, needResolve = false)
case TermRef(prefix, _) =>
val res = cases(prefix, thisV, klass, id) ++ errors
if id.symbol.isConstructor then
res.callConstructor(id.symbol, args, source = expr)
else
res.call(id.symbol, args, superType = NoType, source = expr)
res.call(id.symbol, args, receiver = prefix, superType = NoType, source = expr)

case Select(qualifier, name) =>
val qualRes = eval(qualifier, thisV, klass)
Expand Down
4 changes: 2 additions & 2 deletions tests/init/neg/early-promote.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class Y {

val n = 10
val x = new X
List(x.b) // unsafe promotion
println(x.b) // unsafe promotion

}

Expand All @@ -24,7 +24,7 @@ class A { // checking A
def c = new C
}
val b = new B()
List(b) // error: the checker simply issue warnings for objects that contain inner classes
println(b) // error: the checker simply issue warnings for objects that contain inner classes
val af = 42
}

Expand Down
11 changes: 0 additions & 11 deletions tests/init/neg/enum-desugared.check

This file was deleted.

9 changes: 0 additions & 9 deletions tests/init/neg/enum.check

This file was deleted.

4 changes: 2 additions & 2 deletions tests/init/neg/inner-case.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ class Foo {
}

val a = Inner(5) // ok
println(a) // error
println(a) // error

var count = 0
println(a) // ok
println(a) // ok
}
6 changes: 3 additions & 3 deletions tests/init/neg/inner-new.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ class Foo {
def f() = count + 1
}

val a = new Inner // ok
println(a) // error
val a = new Inner // ok
println(a) // error

var count = 0
println(a) // ok
println(a) // ok
}
13 changes: 9 additions & 4 deletions tests/init/neg/leak-warm.check
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
-- Error: tests/init/neg/leak-warm.scala:18:26 -------------------------------------------------------------------------
18 | val l: List[A] = List(c, d) // error
| ^^^^
| Cannot prove that the value is fully initialized. Only initialized values may be used as arguments.
-- Error: tests/init/neg/leak-warm.scala:19:18 -------------------------------------------------------------------------
19 | val l2 = l.map(_.m()) // error
| ^^^^^^^^^^^^
| Call method leakWarm.l.map[leakWarm.A#B](
| {
| def $anonfun(_$1: leakWarm.A): leakWarm.A#B = _$1.m()
| closure($anonfun)
| }
| ) on a value with an unknown initialization.
4 changes: 2 additions & 2 deletions tests/init/neg/leak-warm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ object leakWarm {
}
val c = new C(1, 2)
val d = new D(3, 4)
val l: List[A] = List(c, d) // error
val l2 = l.map(_.m())
val l: List[A] = List(c, d)
val l2 = l.map(_.m()) // error
}
2 changes: 2 additions & 0 deletions tests/init/neg/some-this.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class X:
val some = Some(this) // error
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ object ErrorMessageID {
final val NoExplanationID = $new(1, "NoExplanationID")

private[this] val $values: Array[ErrorMessageID] =
Array(this.LazyErrorId, this.NoExplanationID) // error
Array(this.LazyErrorId, this.NoExplanationID)

def values: Array[ErrorMessageID] = $values.clone()

Expand Down
File renamed without changes.
3 changes: 3 additions & 0 deletions tests/init/pos/inner-enum-multi-variant.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
class Outer:
enum Color:
case Red, Blue
4 changes: 4 additions & 0 deletions tests/init/pos/inner-enum.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
class Outer:
enum MyEnum {
case Case
}

0 comments on commit fcf823d

Please sign in to comment.