diff --git a/compiler/src/dotty/tools/dotc/core/Constraint.scala b/compiler/src/dotty/tools/dotc/core/Constraint.scala index fb87aed77c41..b849c7aa7093 100644 --- a/compiler/src/dotty/tools/dotc/core/Constraint.scala +++ b/compiler/src/dotty/tools/dotc/core/Constraint.scala @@ -88,6 +88,8 @@ abstract class Constraint extends Showable { * - Another type, indicating a solution for the parameter * * @pre `this contains param`. + * @pre `tp` does not contain top-level references to `param` + * (see `validBoundsFor`) */ def updateEntry(param: TypeParamRef, tp: Type)(using Context): This @@ -172,6 +174,23 @@ abstract class Constraint extends Showable { */ def occursAtToplevel(param: TypeParamRef, tp: Type)(using Context): Boolean + /** Sanitize `bound` to make it either a valid upper or lower bound for + * `param` depending on `isUpper`. + * + * Toplevel references to `param`, are replaced by `Any` if `isUpper` is true + * and `Nothing` otherwise. + * + * @see `occursAtTopLevel` for a definition of "toplevel" + * @see `validBoundsFor` to sanitize both the lower and upper bound at once. + */ + def validBoundFor(param: TypeParamRef, bound: Type, isUpper: Boolean)(using Context): Type + + /** Sanitize `bounds` to make them valid constraints for `param`. + * + * @see `validBoundFor` for details. + */ + def validBoundsFor(param: TypeParamRef, bounds: TypeBounds)(using Context): Type + /** A string that shows the reverse dependencies maintained by this constraint * (coDeps and contraDeps for OrderingConstraints). */ diff --git a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala index c7c005b5220f..6207e0a3d728 100644 --- a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala +++ b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala @@ -257,7 +257,7 @@ trait ConstraintHandling { end LevelAvoidMap /** Approximate `rawBound` if needed to make it a legal bound of `param` by - * avoiding wildcards and types with a level strictly greater than its + * avoiding cycles, wildcards and types with a level strictly greater than its * `nestingLevel`. * * Note that level-checking must be performed here and cannot be delayed @@ -283,7 +283,7 @@ trait ConstraintHandling { // This is necessary for i8900-unflip.scala to typecheck. val v = if necessaryConstraintsOnly then -this.variance else this.variance atVariance(v)(super.legalVar(tp)) - approx(rawBound) + constraint.validBoundFor(param, approx(rawBound), isUpper) end legalBound protected def addOneBound(param: TypeParamRef, rawBound: Type, isUpper: Boolean)(using Context): Boolean = @@ -413,8 +413,10 @@ trait ConstraintHandling { constraint = constraint.addLess(p2, p1, direction = if pKept eq p1 then KeepParam2 else KeepParam1) - val boundKept = constraint.nonParamBounds(pKept).substParam(pRemoved, pKept) - var boundRemoved = constraint.nonParamBounds(pRemoved).substParam(pRemoved, pKept) + val boundKept = constraint.validBoundsFor(pKept, + constraint.nonParamBounds( pKept).substParam(pRemoved, pKept).bounds) + var boundRemoved = constraint.validBoundsFor(pKept, + constraint.nonParamBounds(pRemoved).substParam(pRemoved, pKept).bounds) if level1 != level2 then boundRemoved = LevelAvoidMap(-1, math.min(level1, level2))(boundRemoved) diff --git a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala index 1f65fa324147..603d7a3cb0e3 100644 --- a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala @@ -525,20 +525,11 @@ class OrderingConstraint(private val boundsMap: ParamBounds, // ---------- Updates ------------------------------------------------------------ - /** If `inst` is a TypeBounds, make sure it does not contain toplevel - * references to `param` (see `Constraint#occursAtToplevel` for a definition - * of "toplevel"). - * Any such references are replaced by `Nothing` in the lower bound and `Any` - * in the upper bound. - * References can be direct or indirect through instantiations of other - * parameters in the constraint. - */ - private def ensureNonCyclic(param: TypeParamRef, inst: Type)(using Context): Type = - - def recur(tp: Type, fromBelow: Boolean): Type = tp match + def validBoundFor(param: TypeParamRef, bound: Type, isUpper: Boolean)(using Context): Type = + def recur(tp: Type): Type = tp match case tp: AndOrType => - val r1 = recur(tp.tp1, fromBelow) - val r2 = recur(tp.tp2, fromBelow) + val r1 = recur(tp.tp1) + val r2 = recur(tp.tp2) if (r1 eq tp.tp1) && (r2 eq tp.tp2) then tp else tp.match case tp: OrType => @@ -547,35 +538,34 @@ class OrderingConstraint(private val boundsMap: ParamBounds, r1 & r2 case tp: TypeParamRef => if tp eq param then - if fromBelow then defn.NothingType else defn.AnyType + if isUpper then defn.AnyType else defn.NothingType else entry(tp) match case NoType => tp - case TypeBounds(lo, hi) => if lo eq hi then recur(lo, fromBelow) else tp - case inst => recur(inst, fromBelow) + case TypeBounds(lo, hi) => if lo eq hi then recur(lo) else tp + case inst => recur(inst) case tp: TypeVar => - val underlying1 = recur(tp.underlying, fromBelow) + val underlying1 = recur(tp.underlying) if underlying1 ne tp.underlying then underlying1 else tp case CapturingType(parent, refs) => - val parent1 = recur(parent, fromBelow) + val parent1 = recur(parent) if parent1 ne parent then tp.derivedCapturingType(parent1, refs) else tp case tp: AnnotatedType => - val parent1 = recur(tp.parent, fromBelow) + val parent1 = recur(tp.parent) if parent1 ne tp.parent then tp.derivedAnnotatedType(parent1, tp.annot) else tp case _ => val tp1 = tp.dealiasKeepAnnots if tp1 ne tp then - val tp2 = recur(tp1, fromBelow) + val tp2 = recur(tp1) if tp2 ne tp1 then tp2 else tp else tp - inst match - case bounds: TypeBounds => - bounds.derivedTypeBounds( - recur(bounds.lo, fromBelow = true), - recur(bounds.hi, fromBelow = false)) - case _ => - inst - end ensureNonCyclic + recur(bound) + end validBoundFor + + def validBoundsFor(param: TypeParamRef, bounds: TypeBounds)(using Context): Type = + bounds.derivedTypeBounds( + validBoundFor(param, bounds.lo, isUpper = false), + validBoundFor(param, bounds.hi, isUpper = true)) /** Add the fact `param1 <: param2` to the constraint `current` and propagate * `<:<` relationships between parameters ("edges") but not bounds. @@ -658,9 +648,8 @@ class OrderingConstraint(private val boundsMap: ParamBounds, current1 } - /** The public version of `updateEntry`. Guarantees that there are no cycles */ def updateEntry(param: TypeParamRef, tp: Type)(using Context): This = - updateEntry(this, param, ensureNonCyclic(param, tp)).checkWellFormed() + updateEntry(this, param, tp).checkWellFormed() def addLess(param1: TypeParamRef, param2: TypeParamRef, direction: UnificationDirection)(using Context): This = order(this, param1, param2, direction).checkWellFormed() @@ -703,7 +692,9 @@ class OrderingConstraint(private val boundsMap: ParamBounds, def replaceParamIn(other: TypeParamRef) = val oldEntry = current.entry(other) - val newEntry = current.ensureNonCyclic(other, oldEntry.substParam(param, replacement)) + val newEntry = oldEntry.substParam(param, replacement) match + case tp: TypeBounds => validBoundsFor(other, tp) + case tp => tp current = boundsLens.update(this, current, other, newEntry) var oldDepEntry = oldEntry var newDepEntry = newEntry diff --git a/compiler/test/dotty/tools/dotc/core/ConstraintsTest.scala b/compiler/test/dotty/tools/dotc/core/ConstraintsTest.scala index 5ab162b9f05c..ad8578fa3e61 100644 --- a/compiler/test/dotty/tools/dotc/core/ConstraintsTest.scala +++ b/compiler/test/dotty/tools/dotc/core/ConstraintsTest.scala @@ -53,3 +53,27 @@ class ConstraintsTest: i"Merging constraints `?S <: ?T` and `Int <: ?S` should result in `Int <:< ?T`: ${ctx.typerState.constraint}") } end mergeBoundsTransitivity + + @Test def validBoundsInit: Unit = inCompilerContext( + TestConfiguration.basicClasspath, + scalaSources = "trait A { def foo[S >: T <: T | Int, T <: String]: Any }") { + val tvars = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda], EmptyTree, alwaysAddTypeVars = true)._2 + val List(s, t) = tvars.tpes + + val TypeBounds(lo, hi) = ctx.typerState.constraint.entry(t.asInstanceOf[TypeVar].origin): @unchecked + assert(lo =:= defn.NothingType, i"Unexpected lower bound $lo for $t: ${ctx.typerState.constraint}") + assert(hi =:= defn.StringType, i"Unexpected upper bound $hi for $t: ${ctx.typerState.constraint}") // used to be Any + } + + @Test def validBoundsUnify: Unit = inCompilerContext( + TestConfiguration.basicClasspath, + scalaSources = "trait A { def foo[S >: T <: T | Int, T <: String | Int]: Any }") { + val tvars = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda], EmptyTree, alwaysAddTypeVars = true)._2 + val List(s, t) = tvars.tpes + + s <:< t + + val TypeBounds(lo, hi) = ctx.typerState.constraint.entry(t.asInstanceOf[TypeVar].origin): @unchecked + assert(lo =:= defn.NothingType, i"Unexpected lower bound $lo for $t: ${ctx.typerState.constraint}") + assert(hi =:= (defn.StringType | defn.IntType), i"Unexpected upper bound $hi for $t: ${ctx.typerState.constraint}") + }