Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sound type avoidance (hopefully!) #14026

Merged
merged 11 commits into from
Dec 14, 2021
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/Run.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
protected def rootContext(using Context): Context = {
ctx.initialize()
ctx.base.setPhasePlan(comp.phases)
val rootScope = new MutableScope
val rootScope = new MutableScope(0)
val bootstrap = ctx.fresh
.setPeriod(Period(comp.nextRunId, FirstPhaseId))
.setScope(rootScope)
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import typer.ProtoTypes
import transform.SymUtils._
import transform.TypeUtils._
import core._
import Scopes.newScope
import util.Spans._, Types._, Contexts._, Constants._, Names._, Flags._, NameOps._
import Symbols._, StdNames._, Annotations._, Trees._, Symbols._
import Decorators._, DenotTransformers._
Expand Down Expand Up @@ -344,7 +345,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
}
else parents
val cls = newNormalizedClassSymbol(owner, tpnme.ANON_CLASS, Synthetic | Final, parents1,
coord = fns.map(_.span).reduceLeft(_ union _))
newScope, coord = fns.map(_.span).reduceLeft(_ union _))
val constr = newConstructor(cls, Synthetic, Nil, Nil).entered
def forwarder(fn: TermSymbol, name: TermName) = {
val fwdMeth = fn.copy(cls, name, Synthetic | Method | Final).entered.asTerm
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/config/ScalaSettings.scala
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ private sealed trait YSettings:
val YprintSyms: Setting[Boolean] = BooleanSetting("-Yprint-syms", "When printing trees print info in symbols instead of corresponding info in trees.")
val YprintDebug: Setting[Boolean] = BooleanSetting("-Yprint-debug", "When printing trees, print some extra information useful for debugging.")
val YprintDebugOwners: Setting[Boolean] = BooleanSetting("-Yprint-debug-owners", "When printing trees, print owners of definitions.")
val YprintLevel: Setting[Boolean] = BooleanSetting("-Yprint-level", "print nesting levels of symbols and type variables.")
val YshowPrintErrors: Setting[Boolean] = BooleanSetting("-Yshow-print-errors", "Don't suppress exceptions thrown during tree printing.")
val YtestPickler: Setting[Boolean] = BooleanSetting("-Ytest-pickler", "Self-test for pickling functionality; should be used with -Ystop-after:pickler.")
val YcheckReentrant: Setting[Boolean] = BooleanSetting("-Ycheck-reentrant", "Check that compiled program does not contain vars that can be accessed from a global root.")
Expand Down
26 changes: 20 additions & 6 deletions compiler/src/dotty/tools/dotc/core/Constraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,15 @@ abstract class Constraint extends Showable {
/** A constraint that includes the relationship `p1 <: p2`.
* `<:` relationships between parameters ("edges") are propagated, but
* non-parameter bounds are left alone.
*
* @param direction Must be set to `KeepParam1` or `KeepParam2` when
* `p2 <: p1` is already true depending on which parameter
* the caller intends to keep. This will avoid propagating
* bounds that will be redundant after `p1` and `p2` are
* unified.
*/
def addLess(p1: TypeParamRef, p2: TypeParamRef)(using Context): This

/** A constraint resulting from adding p2 = p1 to this constraint, and at the same
* time transferring all bounds of p2 to p1
*/
def unify(p1: TypeParamRef, p2: TypeParamRef)(using Context): This
def addLess(p1: TypeParamRef, p2: TypeParamRef,
direction: UnificationDirection = UnificationDirection.NoUnification)(using Context): This

/** A new constraint which is derived from this constraint by removing
* the type parameter `param` from the domain and replacing all top-level occurrences
Expand Down Expand Up @@ -174,3 +176,15 @@ abstract class Constraint extends Showable {
*/
def checkConsistentVars()(using Context): Unit
}

/** When calling `Constraint#addLess(p1, p2, ...)`, the caller might end up
* unifying one parameter with the other, this enum lets `addLess` know which
* direction the unification will take.
*/
enum UnificationDirection:
/** Neither p1 nor p2 will be instantiated. */
case NoUnification
/** `p2 := p1`, p1 left uninstantiated. */
case KeepParam1
/** `p1 := p2`, p2 left uninstantiated. */
case KeepParam2
218 changes: 198 additions & 20 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ import Flags._
import config.Config
import config.Printers.typr
import reporting.trace
import typer.ProtoTypes.newTypeVar
import typer.ProtoTypes.{newTypeVar, representedParamRef}
import StdNames.tpnme
import UnificationDirection.*
import NameKinds.AvoidNameKind

/** Methods for adding constraints and solving them.
*
Expand Down Expand Up @@ -56,20 +58,68 @@ trait ConstraintHandling {
*/
protected var comparedTypeLambdas: Set[TypeLambda] = Set.empty

protected var myNecessaryConstraintsOnly = false
/** When collecting the constraints needed for a particular subtyping
* judgment to be true, we sometimes need to approximate the constraint
* set (see `TypeComparer#either` for example).
*
* Normally, this means adding extra constraints which may not be necessary
* for the subtyping judgment to be true, but if this variable is set to true
* we will instead under-approximate and keep only the constraints that must
* always be present for the subtyping judgment to hold.
*
* This is needed for GADT bounds inference to be sound, but it is also used
* when constraining a method call based on its expected type to avoid adding
* constraints that would later prevent us from typechecking method
* arguments, see or-inf.scala and and-inf.scala for examples.
*/
protected def necessaryConstraintsOnly(using Context): Boolean =
ctx.mode.is(Mode.GadtConstraintInference) || myNecessaryConstraintsOnly

def checkReset() =
assert(addConstraintInvocations == 0)
assert(frozenConstraint == false)
assert(caseLambda == NoType)
assert(homogenizeArgs == false)
assert(comparedTypeLambdas == Set.empty)

def nestingLevel(param: TypeParamRef) = constraint.typeVarOfParam(param) match
case tv: TypeVar => tv.nestingLevel
case _ => Int.MaxValue

/** If `param` is nested deeper than `maxLevel`, try to instantiate it to a
* fresh type variable of level `maxLevel` and return the new variable.
* If this isn't possible, throw a TypeError.
*/
def atLevel(maxLevel: Int, param: TypeParamRef)(using Context): TypeParamRef =
if nestingLevel(param) <= maxLevel then return param
LevelAvoidMap(0, maxLevel)(param) match
case freshVar: TypeVar => freshVar.origin
case _ => throw new TypeError(
i"Could not decrease the nesting level of ${param} from ${nestingLevel(param)} to $maxLevel in $constraint")

def nonParamBounds(param: TypeParamRef)(using Context): TypeBounds = constraint.nonParamBounds(param)

/** The full lower bound of `param` includes both the `nonParamBounds` and the
* params in the constraint known to be `<: param`, except that
* params with a `nestingLevel` higher than `param` will be instantiated
* to a fresh param at a legal level. See the documentation of `TypeVar`
* for details.
*/
def fullLowerBound(param: TypeParamRef)(using Context): Type =
constraint.minLower(param).foldLeft(nonParamBounds(param).lo)(_ | _)
val maxLevel = nestingLevel(param)
var loParams = constraint.minLower(param)
if maxLevel != Int.MaxValue then
loParams = loParams.mapConserve(atLevel(maxLevel, _))
loParams.foldLeft(nonParamBounds(param).lo)(_ | _)

/** The full upper bound of `param`, see the documentation of `fullLowerBounds` above. */
def fullUpperBound(param: TypeParamRef)(using Context): Type =
constraint.minUpper(param).foldLeft(nonParamBounds(param).hi)(_ & _)
val maxLevel = nestingLevel(param)
var hiParams = constraint.minUpper(param)
if maxLevel != Int.MaxValue then
hiParams = hiParams.mapConserve(atLevel(maxLevel, _))
hiParams.foldLeft(nonParamBounds(param).hi)(_ & _)

/** Full bounds of `param`, including other lower/upper params.
*
Expand All @@ -79,10 +129,111 @@ trait ConstraintHandling {
def fullBounds(param: TypeParamRef)(using Context): TypeBounds =
nonParamBounds(param).derivedTypeBounds(fullLowerBound(param), fullUpperBound(param))

/** If true, eliminate wildcards in bounds by avoidance, otherwise replace
* them by fresh variables.
/** An approximating map that prevents types nested deeper than maxLevel as
* well as WildcardTypes from leaking into the constraint.
* Note that level-checking is turned off after typer and in uncommitable
* TyperState since these leaks should be safe.
*/
protected def approximateWildcards: Boolean = true
class LevelAvoidMap(topLevelVariance: Int, maxLevel: Int)(using Context) extends TypeOps.AvoidMap:
variance = topLevelVariance

/** Are we allowed to refer to types of the given `level`? */
private def levelOK(level: Int): Boolean =
level <= maxLevel || ctx.isAfterTyper || !ctx.typerState.isCommittable

def toAvoid(tp: NamedType): Boolean =
tp.prefix == NoPrefix && !tp.symbol.isStatic && !levelOK(tp.symbol.nestingLevel)

/** Return a (possibly fresh) type variable of a level no greater than `maxLevel` which is:
* - lower-bounded by `tp` if variance >= 0
* - upper-bounded by `tp` if variance <= 0
* If this isn't possible, return the empty range.
*/
def legalVar(tp: TypeVar): Type =
val oldParam = tp.origin
val nameKind =
if variance > 0 then AvoidNameKind.UpperBound
else if variance < 0 then AvoidNameKind.LowerBound
else AvoidNameKind.BothBounds

/** If it exists, return the first param in the list created in a previous call to `legalVar(tp)`
* with the appropriate level and variance.
*/
def findParam(params: List[TypeParamRef]): Option[TypeParamRef] =
params.find(p =>
nestingLevel(p) <= maxLevel && representedParamRef(p) == oldParam &&
(p.paramName.is(AvoidNameKind.BothBounds) ||
variance != 0 && p.paramName.is(nameKind)))

// First, check if we can reuse an existing parameter, this is more than an optimization
// since it avoids an infinite loop in tests/pos/i8900-cycle.scala
findParam(constraint.lower(oldParam)).orElse(findParam(constraint.upper(oldParam))) match
case Some(param) =>
constraint.typeVarOfParam(param)
case _ =>
// Otherwise, try to return a fresh type variable at `maxLevel` with
// the appropriate constraints.
val name = nameKind(oldParam.paramName.toTermName).toTypeName
val freshVar = newTypeVar(TypeBounds.upper(tp.topType), name,
nestingLevel = maxLevel, represents = oldParam)
val ok =
if variance < 0 then
addLess(freshVar.origin, oldParam)
else if variance > 0 then
addLess(oldParam, freshVar.origin)
else
unify(freshVar.origin, oldParam)
if ok then freshVar else emptyRange
end legalVar

override def apply(tp: Type): Type = tp match
case tp: TypeVar if !tp.isInstantiated && !levelOK(tp.nestingLevel) =>
legalVar(tp)
// TypeParamRef can occur in tl bounds
case tp: TypeParamRef =>
constraint.typeVarOfParam(tp) match
case tvar: TypeVar =>
apply(tvar)
case _ => super.apply(tp)
case _ =>
super.apply(tp)

override def mapWild(t: WildcardType) =
if ctx.mode.is(Mode.TypevarsMissContext) then super.mapWild(t)
else
val tvar = newTypeVar(apply(t.effectiveBounds).toBounds, nestingLevel = maxLevel)
tvar
end LevelAvoidMap

/** Approximate `rawBound` if needed to make it a legal bound of `param` by
* avoiding wildcards and types with a level strictly greater than its
* `nestingLevel`.
*
* Note that level-checking must be performed here and cannot be delayed
* until instantiation because if we allow level-incorrect bounds, then we
* might end up reasoning with bad bounds outside of the scope where they are
* defined. This can lead to level-correct but unsound instantiations as
* demonstrated by tests/neg/i8900.scala.
*/
protected def legalBound(param: TypeParamRef, rawBound: Type, isUpper: Boolean)(using Context): Type =
// Over-approximate for soundness.
var variance = if isUpper then -1 else 1
// ...unless we can only infer necessary constraints, in which case we
// flip the variance to under-approximate.
if necessaryConstraintsOnly then variance = -variance

val approx = new LevelAvoidMap(variance, nestingLevel(param)):
override def legalVar(tp: TypeVar): Type =
// `legalVar` will create a type variable whose bounds depend on
// `variance`, but whether the variance is positive or negative,
// we can still infer necessary constraints since just creating a
// type variable doesn't reduce the set of possible solutions.
// Therefore, we can safely "unflip" the variance flipped above.
// This is necessary for i8900-unflip.scala to typecheck.
val v = if necessaryConstraintsOnly then -this.variance else this.variance
atVariance(v)(super.legalVar(tp))
approx(rawBound)
end legalBound

protected def addOneBound(param: TypeParamRef, rawBound: Type, isUpper: Boolean)(using Context): Boolean =
if !constraint.contains(param) then true
Expand All @@ -91,12 +242,7 @@ trait ConstraintHandling {
// so we shouldn't allow them as constraints either.
false
else
val dropWildcards = new AvoidWildcardsMap:
if !isUpper then variance = -1
override def mapWild(t: WildcardType) =
if approximateWildcards then super.mapWild(t)
else newTypeVar(apply(t.effectiveBounds).toBounds)
val bound = dropWildcards(rawBound)
val bound = legalBound(param, rawBound, isUpper)
val oldBounds @ TypeBounds(lo, hi) = constraint.nonParamBounds(param)
val equalBounds = (if isUpper then lo else hi) eq bound
if equalBounds && !bound.existsPart(_ eq param, StopAt.Static) then
Expand Down Expand Up @@ -191,19 +337,50 @@ trait ConstraintHandling {

def location(using Context) = "" // i"in ${ctx.typerState.stateChainStr}" // use for debugging

/** Make p2 = p1, transfer all bounds of p2 to p1
* @pre less(p1)(p2)
/** Unify p1 with p2: one parameter will be kept in the constraint, the
* other will be removed and its bounds transferred to the remaining one.
*
* If p1 and p2 have different `nestingLevel`, the parameter with the lowest
* level will be kept and the transferred bounds from the other parameter
* will be adjusted for level-correctness.
*/
private def unify(p1: TypeParamRef, p2: TypeParamRef)(using Context): Boolean = {
constr.println(s"unifying $p1 $p2")
assert(constraint.isLess(p1, p2))
constraint = constraint.addLess(p2, p1)
if !constraint.isLess(p1, p2) then
constraint = constraint.addLess(p1, p2)

val level1 = nestingLevel(p1)
val level2 = nestingLevel(p2)
val pKept = if level1 <= level2 then p1 else p2
val pRemoved = if level1 <= level2 then p2 else p1

constraint = constraint.addLess(p2, p1, direction = if pKept eq p1 then KeepParam2 else KeepParam1)

val boundKept = constraint.nonParamBounds(pKept).substParam(pRemoved, pKept)
var boundRemoved = constraint.nonParamBounds(pRemoved).substParam(pRemoved, pKept)

if level1 != level2 then
boundRemoved = LevelAvoidMap(-1, math.min(level1, level2))(boundRemoved)
val TypeBounds(lo, hi) = boundRemoved
// After avoidance, the interval might be empty, e.g. in
// tests/pos/i8900-promote.scala:
// >: x.type <: Singleton
// becomes:
// >: Int <: Singleton
// In that case, we can still get a legal constraint
// by replacing the lower-bound to get:
// >: Int & Singleton <: Singleton
if !isSub(lo, hi) then
boundRemoved = TypeBounds(lo & hi, hi)

val down = constraint.exclusiveLower(p2, p1)
val up = constraint.exclusiveUpper(p1, p2)
constraint = constraint.unify(p1, p2)
val bounds = constraint.nonParamBounds(p1)
val lo = bounds.lo
val hi = bounds.hi

val newBounds = (boundKept & boundRemoved).bounds
constraint = constraint.updateEntry(pKept, newBounds).replace(pRemoved, pKept)

val lo = newBounds.lo
val hi = newBounds.hi
isSub(lo, hi) &&
down.forall(addOneBound(_, hi, isUpper = true)) &&
up.forall(addOneBound(_, lo, isUpper = false))
Expand Down Expand Up @@ -256,6 +433,7 @@ trait ConstraintHandling {
final def approximation(param: TypeParamRef, fromBelow: Boolean)(using Context): Type =
constraint.entry(param) match
case entry: TypeBounds =>
val maxLevel = nestingLevel(param)
val useLowerBound = fromBelow || param.occursIn(entry.hi)
val inst = if useLowerBound then fullLowerBound(param) else fullUpperBound(param)
typr.println(s"approx ${param.show}, from below = $fromBelow, inst = ${inst.show}")
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,10 @@ object Contexts {
if owner != null && owner.isClass then owner.asClass.unforcedDecls
else scope

def nestingLevel: Int =
val sc = effectiveScope
if sc != null then sc.nestingLevel else 0

/** Sourcefile corresponding to given abstract file, memoized */
def getSource(file: AbstractFile, codec: => Codec = Codec(settings.encoding.value)) = {
util.Stats.record("Context.getSource")
Expand Down
7 changes: 5 additions & 2 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ class Definitions {
private def newPermanentClassSymbol(owner: Symbol, name: TypeName, flags: FlagSet, infoFn: ClassSymbol => Type) =
newClassSymbol(owner, name, flags | Permanent | NoInits | Open, infoFn)

private def enterCompleteClassSymbol(owner: Symbol, name: TypeName, flags: FlagSet, parents: List[TypeRef], decls: Scope = newScope) =
private def enterCompleteClassSymbol(owner: Symbol, name: TypeName, flags: FlagSet, parents: List[TypeRef]): ClassSymbol =
enterCompleteClassSymbol(owner, name, flags, parents, newScope(owner.nestingLevel + 1))

private def enterCompleteClassSymbol(owner: Symbol, name: TypeName, flags: FlagSet, parents: List[TypeRef], decls: Scope) =
newCompleteClassSymbol(owner, name, flags | Permanent | NoInits | Open, parents, decls).entered

private def enterTypeField(cls: ClassSymbol, name: TypeName, flags: FlagSet, scope: MutableScope) =
Expand Down Expand Up @@ -433,7 +436,7 @@ class Definitions {
Any_toString, Any_##, Any_getClass, Any_isInstanceOf, Any_typeTest, Object_eq, Object_ne)

@tu lazy val AnyKindClass: ClassSymbol = {
val cls = newCompleteClassSymbol(ScalaPackageClass, tpnme.AnyKind, AbstractFinal | Permanent, Nil)
val cls = newCompleteClassSymbol(ScalaPackageClass, tpnme.AnyKind, AbstractFinal | Permanent, Nil, newScope(0))
if (!ctx.settings.YnoKindPolymorphism.value)
// Enable kind-polymorphism by exposing scala.AnyKind
cls.entered
Expand Down
Loading