Skip to content

Commit

Permalink
Avoid incorrect simplifications when updating bounds in the constraint (
Browse files Browse the repository at this point in the history
#16410)

When combining an old and a new bound, we use `Type#&`/`Type#|` which
perform simplifications. This is usually fine, but if the new bounds
refer to the parameter currently being updated, we can run into cyclic
reasoning issues which make the simplifications invalid after the
update.

We already have logic for handling self-references in parameter bounds:
`updateEntry` calls `ensureNonCyclic` which sanitizes the type, but at
this point the simplifications have already occured. This commit simply
moves the logic out of `updateEntry` so that we can sanitize the new
bounds before simplification.

More precisely, we rename `ensureNonCyclic` to `validBoundsFor` which
calls `validBoundFor` (singular). Both are used to sanitize bounds where
needed in `addOneBound` and `unify`.

Since all calls to `updateEntry` now have sanitized bounds, we no longer
need to sanitize them in `updateEntry` itself, we document this change
by adding a pre-condition to `updateEntry`.

For the record, here's how `ConstraintsTest#validBoundsInit` used to
fail. It defines a method:

    def foo[S >: T <: T | Int, T <: String]: Any

Before this commit, when `foo` was added to the current constraints, the
constraint `S <: T | Int` was propagated to the lower bound `T` of `S`.
The updated upper bound of `T` was thus set to:

    String & (T | Int)

But because `Type#&` performs simplifications, this became

    T | (String & Int)

by relying on the fact that at this point, `T <: String`. But in fact
this simplified bound no longer ensures that `T <: String`! The
self-reference was then replaced by `Any` in
`OrderingConstraint#ensureNonCyclic`. After this commit, the problematic
simplification no longer occurs since the new `T | Int` is sanitized to
`Any` before being intersected with the old bound.
  • Loading branch information
smarter authored Dec 1, 2022
2 parents 845105a + 50eb0e9 commit e842810
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 35 deletions.
19 changes: 19 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Constraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ abstract class Constraint extends Showable {
* - Another type, indicating a solution for the parameter
*
* @pre `this contains param`.
* @pre `tp` does not contain top-level references to `param`
* (see `validBoundsFor`)
*/
def updateEntry(param: TypeParamRef, tp: Type)(using Context): This

Expand Down Expand Up @@ -172,6 +174,23 @@ abstract class Constraint extends Showable {
*/
def occursAtToplevel(param: TypeParamRef, tp: Type)(using Context): Boolean

/** Sanitize `bound` to make it either a valid upper or lower bound for
* `param` depending on `isUpper`.
*
* Toplevel references to `param`, are replaced by `Any` if `isUpper` is true
* and `Nothing` otherwise.
*
* @see `occursAtTopLevel` for a definition of "toplevel"
* @see `validBoundsFor` to sanitize both the lower and upper bound at once.
*/
def validBoundFor(param: TypeParamRef, bound: Type, isUpper: Boolean)(using Context): Type

/** Sanitize `bounds` to make them valid constraints for `param`.
*
* @see `validBoundFor` for details.
*/
def validBoundsFor(param: TypeParamRef, bounds: TypeBounds)(using Context): Type

/** A string that shows the reverse dependencies maintained by this constraint
* (coDeps and contraDeps for OrderingConstraints).
*/
Expand Down
10 changes: 6 additions & 4 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ trait ConstraintHandling {
end LevelAvoidMap

/** Approximate `rawBound` if needed to make it a legal bound of `param` by
* avoiding wildcards and types with a level strictly greater than its
* avoiding cycles, wildcards and types with a level strictly greater than its
* `nestingLevel`.
*
* Note that level-checking must be performed here and cannot be delayed
Expand All @@ -283,7 +283,7 @@ trait ConstraintHandling {
// This is necessary for i8900-unflip.scala to typecheck.
val v = if necessaryConstraintsOnly then -this.variance else this.variance
atVariance(v)(super.legalVar(tp))
approx(rawBound)
constraint.validBoundFor(param, approx(rawBound), isUpper)
end legalBound

protected def addOneBound(param: TypeParamRef, rawBound: Type, isUpper: Boolean)(using Context): Boolean =
Expand Down Expand Up @@ -413,8 +413,10 @@ trait ConstraintHandling {

constraint = constraint.addLess(p2, p1, direction = if pKept eq p1 then KeepParam2 else KeepParam1)

val boundKept = constraint.nonParamBounds(pKept).substParam(pRemoved, pKept)
var boundRemoved = constraint.nonParamBounds(pRemoved).substParam(pRemoved, pKept)
val boundKept = constraint.validBoundsFor(pKept,
constraint.nonParamBounds( pKept).substParam(pRemoved, pKept).bounds)
var boundRemoved = constraint.validBoundsFor(pKept,
constraint.nonParamBounds(pRemoved).substParam(pRemoved, pKept).bounds)

if level1 != level2 then
boundRemoved = LevelAvoidMap(-1, math.min(level1, level2))(boundRemoved)
Expand Down
53 changes: 22 additions & 31 deletions compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -525,20 +525,11 @@ class OrderingConstraint(private val boundsMap: ParamBounds,

// ---------- Updates ------------------------------------------------------------

/** If `inst` is a TypeBounds, make sure it does not contain toplevel
* references to `param` (see `Constraint#occursAtToplevel` for a definition
* of "toplevel").
* Any such references are replaced by `Nothing` in the lower bound and `Any`
* in the upper bound.
* References can be direct or indirect through instantiations of other
* parameters in the constraint.
*/
private def ensureNonCyclic(param: TypeParamRef, inst: Type)(using Context): Type =

def recur(tp: Type, fromBelow: Boolean): Type = tp match
def validBoundFor(param: TypeParamRef, bound: Type, isUpper: Boolean)(using Context): Type =
def recur(tp: Type): Type = tp match
case tp: AndOrType =>
val r1 = recur(tp.tp1, fromBelow)
val r2 = recur(tp.tp2, fromBelow)
val r1 = recur(tp.tp1)
val r2 = recur(tp.tp2)
if (r1 eq tp.tp1) && (r2 eq tp.tp2) then tp
else tp.match
case tp: OrType =>
Expand All @@ -547,35 +538,34 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
r1 & r2
case tp: TypeParamRef =>
if tp eq param then
if fromBelow then defn.NothingType else defn.AnyType
if isUpper then defn.AnyType else defn.NothingType
else entry(tp) match
case NoType => tp
case TypeBounds(lo, hi) => if lo eq hi then recur(lo, fromBelow) else tp
case inst => recur(inst, fromBelow)
case TypeBounds(lo, hi) => if lo eq hi then recur(lo) else tp
case inst => recur(inst)
case tp: TypeVar =>
val underlying1 = recur(tp.underlying, fromBelow)
val underlying1 = recur(tp.underlying)
if underlying1 ne tp.underlying then underlying1 else tp
case CapturingType(parent, refs) =>
val parent1 = recur(parent, fromBelow)
val parent1 = recur(parent)
if parent1 ne parent then tp.derivedCapturingType(parent1, refs) else tp
case tp: AnnotatedType =>
val parent1 = recur(tp.parent, fromBelow)
val parent1 = recur(tp.parent)
if parent1 ne tp.parent then tp.derivedAnnotatedType(parent1, tp.annot) else tp
case _ =>
val tp1 = tp.dealiasKeepAnnots
if tp1 ne tp then
val tp2 = recur(tp1, fromBelow)
val tp2 = recur(tp1)
if tp2 ne tp1 then tp2 else tp
else tp

inst match
case bounds: TypeBounds =>
bounds.derivedTypeBounds(
recur(bounds.lo, fromBelow = true),
recur(bounds.hi, fromBelow = false))
case _ =>
inst
end ensureNonCyclic
recur(bound)
end validBoundFor

def validBoundsFor(param: TypeParamRef, bounds: TypeBounds)(using Context): Type =
bounds.derivedTypeBounds(
validBoundFor(param, bounds.lo, isUpper = false),
validBoundFor(param, bounds.hi, isUpper = true))

/** Add the fact `param1 <: param2` to the constraint `current` and propagate
* `<:<` relationships between parameters ("edges") but not bounds.
Expand Down Expand Up @@ -658,9 +648,8 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
current1
}

/** The public version of `updateEntry`. Guarantees that there are no cycles */
def updateEntry(param: TypeParamRef, tp: Type)(using Context): This =
updateEntry(this, param, ensureNonCyclic(param, tp)).checkWellFormed()
updateEntry(this, param, tp).checkWellFormed()

def addLess(param1: TypeParamRef, param2: TypeParamRef, direction: UnificationDirection)(using Context): This =
order(this, param1, param2, direction).checkWellFormed()
Expand Down Expand Up @@ -703,7 +692,9 @@ class OrderingConstraint(private val boundsMap: ParamBounds,

def replaceParamIn(other: TypeParamRef) =
val oldEntry = current.entry(other)
val newEntry = current.ensureNonCyclic(other, oldEntry.substParam(param, replacement))
val newEntry = oldEntry.substParam(param, replacement) match
case tp: TypeBounds => validBoundsFor(other, tp)
case tp => tp
current = boundsLens.update(this, current, other, newEntry)
var oldDepEntry = oldEntry
var newDepEntry = newEntry
Expand Down
24 changes: 24 additions & 0 deletions compiler/test/dotty/tools/dotc/core/ConstraintsTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,27 @@ class ConstraintsTest:
i"Merging constraints `?S <: ?T` and `Int <: ?S` should result in `Int <:< ?T`: ${ctx.typerState.constraint}")
}
end mergeBoundsTransitivity

@Test def validBoundsInit: Unit = inCompilerContext(
TestConfiguration.basicClasspath,
scalaSources = "trait A { def foo[S >: T <: T | Int, T <: String]: Any }") {
val tvars = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda], EmptyTree, alwaysAddTypeVars = true)._2
val List(s, t) = tvars.tpes

val TypeBounds(lo, hi) = ctx.typerState.constraint.entry(t.asInstanceOf[TypeVar].origin): @unchecked
assert(lo =:= defn.NothingType, i"Unexpected lower bound $lo for $t: ${ctx.typerState.constraint}")
assert(hi =:= defn.StringType, i"Unexpected upper bound $hi for $t: ${ctx.typerState.constraint}") // used to be Any
}

@Test def validBoundsUnify: Unit = inCompilerContext(
TestConfiguration.basicClasspath,
scalaSources = "trait A { def foo[S >: T <: T | Int, T <: String | Int]: Any }") {
val tvars = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda], EmptyTree, alwaysAddTypeVars = true)._2
val List(s, t) = tvars.tpes

s <:< t

val TypeBounds(lo, hi) = ctx.typerState.constraint.entry(t.asInstanceOf[TypeVar].origin): @unchecked
assert(lo =:= defn.NothingType, i"Unexpected lower bound $lo for $t: ${ctx.typerState.constraint}")
assert(hi =:= (defn.StringType | defn.IntType), i"Unexpected upper bound $hi for $t: ${ctx.typerState.constraint}")
}

0 comments on commit e842810

Please sign in to comment.