Skip to content

Commit

Permalink
Avoid incorrect simplifications when updating bounds in the constraint
Browse files Browse the repository at this point in the history
When combining an old and a new bound, we use `Type#&`/`Type#|` which perform
simplifications. This is usually fine, but if the new bounds refer to the
parameter currently being updated, we can run into cyclic reasoning issues which
make the simplifications invalid after the update.

We already have logic for handling self-references in parameter bounds:
`updateEntry` calls `ensureNonCyclic` which sanitizes the type, but at this point
the simplifications have already occured. This commit simply moves the logic out
of `updateEntry` so that we can sanitize the new bounds before simplification.

More precisely, we rename `ensureNonCyclic` to `validBoundsFor` which calls
`validBoundFor`. Both are used to sanitize bounds where needed in `addOneBound`
and `unify`.

Since all calls to `updateEntry` now have sanitized bounds, we no longer need to
sanitize them in `updateEntry` itself, we document this change by adding a
pre-condition to `updateEntry`.

For the record, here's how `ConstraintsTest#validBoundsInit` used to fail.
It defines a method:

    def foo[S >: T <: T | Int, T <: String]: Any

Before this commit, when `foo` was added to the current constraints, the
constraint `S <: T | Int` was propagated to the lower bound `T` of `S`. The
updated upper bound of `T` was thus set to:

    String & (T | Int)

But because `Type#&` performs simplifications, this became

    T | (String & Int)

by relying on the fact that at this point, `T <: String`. But in fact this
simplified bound no longer ensures that `T <: String`! The self-reference was
then replaced by `Any` in `OrderingConstraint#ensureNonCyclic`.
After this commit, the problematic simplification no longer occurs since the
new `T | Int` is sanitized to `Any` before being intersected with the old bound.
  • Loading branch information
smarter committed Nov 30, 2022
1 parent 81235b7 commit 50eb0e9
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 35 deletions.
19 changes: 19 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Constraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ abstract class Constraint extends Showable {
* - Another type, indicating a solution for the parameter
*
* @pre `this contains param`.
* @pre `tp` does not contain top-level references to `param`
* (see `validBoundsFor`)
*/
def updateEntry(param: TypeParamRef, tp: Type)(using Context): This

Expand Down Expand Up @@ -172,6 +174,23 @@ abstract class Constraint extends Showable {
*/
def occursAtToplevel(param: TypeParamRef, tp: Type)(using Context): Boolean

/** Sanitize `bound` to make it either a valid upper or lower bound for
* `param` depending on `isUpper`.
*
* Toplevel references to `param`, are replaced by `Any` if `isUpper` is true
* and `Nothing` otherwise.
*
* @see `occursAtTopLevel` for a definition of "toplevel"
* @see `validBoundsFor` to sanitize both the lower and upper bound at once.
*/
def validBoundFor(param: TypeParamRef, bound: Type, isUpper: Boolean)(using Context): Type

/** Sanitize `bounds` to make them valid constraints for `param`.
*
* @see `validBoundFor` for details.
*/
def validBoundsFor(param: TypeParamRef, bounds: TypeBounds)(using Context): Type

/** A string that shows the reverse dependencies maintained by this constraint
* (coDeps and contraDeps for OrderingConstraints).
*/
Expand Down
10 changes: 6 additions & 4 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ trait ConstraintHandling {
end LevelAvoidMap

/** Approximate `rawBound` if needed to make it a legal bound of `param` by
* avoiding wildcards and types with a level strictly greater than its
* avoiding cycles, wildcards and types with a level strictly greater than its
* `nestingLevel`.
*
* Note that level-checking must be performed here and cannot be delayed
Expand All @@ -283,7 +283,7 @@ trait ConstraintHandling {
// This is necessary for i8900-unflip.scala to typecheck.
val v = if necessaryConstraintsOnly then -this.variance else this.variance
atVariance(v)(super.legalVar(tp))
approx(rawBound)
constraint.validBoundFor(param, approx(rawBound), isUpper)
end legalBound

protected def addOneBound(param: TypeParamRef, rawBound: Type, isUpper: Boolean)(using Context): Boolean =
Expand Down Expand Up @@ -413,8 +413,10 @@ trait ConstraintHandling {

constraint = constraint.addLess(p2, p1, direction = if pKept eq p1 then KeepParam2 else KeepParam1)

val boundKept = constraint.nonParamBounds(pKept).substParam(pRemoved, pKept)
var boundRemoved = constraint.nonParamBounds(pRemoved).substParam(pRemoved, pKept)
val boundKept = constraint.validBoundsFor(pKept,
constraint.nonParamBounds( pKept).substParam(pRemoved, pKept).bounds)
var boundRemoved = constraint.validBoundsFor(pKept,
constraint.nonParamBounds(pRemoved).substParam(pRemoved, pKept).bounds)

if level1 != level2 then
boundRemoved = LevelAvoidMap(-1, math.min(level1, level2))(boundRemoved)
Expand Down
53 changes: 22 additions & 31 deletions compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -525,20 +525,11 @@ class OrderingConstraint(private val boundsMap: ParamBounds,

// ---------- Updates ------------------------------------------------------------

/** If `inst` is a TypeBounds, make sure it does not contain toplevel
* references to `param` (see `Constraint#occursAtToplevel` for a definition
* of "toplevel").
* Any such references are replaced by `Nothing` in the lower bound and `Any`
* in the upper bound.
* References can be direct or indirect through instantiations of other
* parameters in the constraint.
*/
private def ensureNonCyclic(param: TypeParamRef, inst: Type)(using Context): Type =

def recur(tp: Type, fromBelow: Boolean): Type = tp match
def validBoundFor(param: TypeParamRef, bound: Type, isUpper: Boolean)(using Context): Type =
def recur(tp: Type): Type = tp match
case tp: AndOrType =>
val r1 = recur(tp.tp1, fromBelow)
val r2 = recur(tp.tp2, fromBelow)
val r1 = recur(tp.tp1)
val r2 = recur(tp.tp2)
if (r1 eq tp.tp1) && (r2 eq tp.tp2) then tp
else tp.match
case tp: OrType =>
Expand All @@ -547,35 +538,34 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
r1 & r2
case tp: TypeParamRef =>
if tp eq param then
if fromBelow then defn.NothingType else defn.AnyType
if isUpper then defn.AnyType else defn.NothingType
else entry(tp) match
case NoType => tp
case TypeBounds(lo, hi) => if lo eq hi then recur(lo, fromBelow) else tp
case inst => recur(inst, fromBelow)
case TypeBounds(lo, hi) => if lo eq hi then recur(lo) else tp
case inst => recur(inst)
case tp: TypeVar =>
val underlying1 = recur(tp.underlying, fromBelow)
val underlying1 = recur(tp.underlying)
if underlying1 ne tp.underlying then underlying1 else tp
case CapturingType(parent, refs) =>
val parent1 = recur(parent, fromBelow)
val parent1 = recur(parent)
if parent1 ne parent then tp.derivedCapturingType(parent1, refs) else tp
case tp: AnnotatedType =>
val parent1 = recur(tp.parent, fromBelow)
val parent1 = recur(tp.parent)
if parent1 ne tp.parent then tp.derivedAnnotatedType(parent1, tp.annot) else tp
case _ =>
val tp1 = tp.dealiasKeepAnnots
if tp1 ne tp then
val tp2 = recur(tp1, fromBelow)
val tp2 = recur(tp1)
if tp2 ne tp1 then tp2 else tp
else tp

inst match
case bounds: TypeBounds =>
bounds.derivedTypeBounds(
recur(bounds.lo, fromBelow = true),
recur(bounds.hi, fromBelow = false))
case _ =>
inst
end ensureNonCyclic
recur(bound)
end validBoundFor

def validBoundsFor(param: TypeParamRef, bounds: TypeBounds)(using Context): Type =
bounds.derivedTypeBounds(
validBoundFor(param, bounds.lo, isUpper = false),
validBoundFor(param, bounds.hi, isUpper = true))

/** Add the fact `param1 <: param2` to the constraint `current` and propagate
* `<:<` relationships between parameters ("edges") but not bounds.
Expand Down Expand Up @@ -658,9 +648,8 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
current1
}

/** The public version of `updateEntry`. Guarantees that there are no cycles */
def updateEntry(param: TypeParamRef, tp: Type)(using Context): This =
updateEntry(this, param, ensureNonCyclic(param, tp)).checkWellFormed()
updateEntry(this, param, tp).checkWellFormed()

def addLess(param1: TypeParamRef, param2: TypeParamRef, direction: UnificationDirection)(using Context): This =
order(this, param1, param2, direction).checkWellFormed()
Expand Down Expand Up @@ -703,7 +692,9 @@ class OrderingConstraint(private val boundsMap: ParamBounds,

def replaceParamIn(other: TypeParamRef) =
val oldEntry = current.entry(other)
val newEntry = current.ensureNonCyclic(other, oldEntry.substParam(param, replacement))
val newEntry = oldEntry.substParam(param, replacement) match
case tp: TypeBounds => validBoundsFor(other, tp)
case tp => tp
current = boundsLens.update(this, current, other, newEntry)
var oldDepEntry = oldEntry
var newDepEntry = newEntry
Expand Down
24 changes: 24 additions & 0 deletions compiler/test/dotty/tools/dotc/core/ConstraintsTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,27 @@ class ConstraintsTest:
i"Merging constraints `?S <: ?T` and `Int <: ?S` should result in `Int <:< ?T`: ${ctx.typerState.constraint}")
}
end mergeBoundsTransitivity

@Test def validBoundsInit: Unit = inCompilerContext(
TestConfiguration.basicClasspath,
scalaSources = "trait A { def foo[S >: T <: T | Int, T <: String]: Any }") {
val tvars = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda], EmptyTree, alwaysAddTypeVars = true)._2
val List(s, t) = tvars.tpes

val TypeBounds(lo, hi) = ctx.typerState.constraint.entry(t.asInstanceOf[TypeVar].origin): @unchecked
assert(lo =:= defn.NothingType, i"Unexpected lower bound $lo for $t: ${ctx.typerState.constraint}")
assert(hi =:= defn.StringType, i"Unexpected upper bound $hi for $t: ${ctx.typerState.constraint}") // used to be Any
}

@Test def validBoundsUnify: Unit = inCompilerContext(
TestConfiguration.basicClasspath,
scalaSources = "trait A { def foo[S >: T <: T | Int, T <: String | Int]: Any }") {
val tvars = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda], EmptyTree, alwaysAddTypeVars = true)._2
val List(s, t) = tvars.tpes

s <:< t

val TypeBounds(lo, hi) = ctx.typerState.constraint.entry(t.asInstanceOf[TypeVar].origin): @unchecked
assert(lo =:= defn.NothingType, i"Unexpected lower bound $lo for $t: ${ctx.typerState.constraint}")
assert(hi =:= (defn.StringType | defn.IntType), i"Unexpected upper bound $hi for $t: ${ctx.typerState.constraint}")
}

0 comments on commit 50eb0e9

Please sign in to comment.