diff --git a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala index 9b94c5ab4c4f..7867765cff7b 100644 --- a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala @@ -11,6 +11,7 @@ import config.Config import config.Printers.constr import reflect.ClassTag import Constraint.ReverseDeps +import Substituters.SubstParamMap import annotation.tailrec import annotation.internal.sharable import cc.{CapturingType, derivedCapturingType} @@ -37,7 +38,7 @@ object OrderingConstraint { } /** The `current` constraint but with the entry for `param` updated to `entry`. - * `current` is used linearly. If it is different from `prev` it is + * `current` is used linearly. If it is different from `prev` then `current` is * known to be dead after the call. Hence it is OK to update destructively * parts of `current` which are not shared by `prev`. */ @@ -133,6 +134,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds, private val lowerMap : ParamOrdering, private val upperMap : ParamOrdering, private val hardVars : TypeVars) extends Constraint { + thisConstraint => import UnificationDirection.* @@ -243,7 +245,9 @@ class OrderingConstraint(private val boundsMap: ParamBounds, //.showing(i"outer depends on $tv with ${tvdeps.toList}%, % = $result") if co then test(coDeps, upperLens) else test(contraDeps, lowerLens) - private class Adjuster(srcParam: TypeParamRef)(using Context) extends TypeTraverser: + private class Adjuster(srcParam: TypeParamRef)(using Context) + extends TypeTraverser, ConstraintAwareTraversal: + var add: Boolean = compiletime.uninitialized val seen = util.HashSet[LazyRef]() @@ -411,7 +415,6 @@ class OrderingConstraint(private val boundsMap: ParamBounds, tvars.copyToArray(entries1, nparams) newConstraint(boundsMap = this.boundsMap.updated(poly, entries1)) .init(poly) - .adjustDeps(poly, entries1, add = true) } /** Split dependent parameters off the bounds for parameters in `poly`. @@ -433,7 +436,8 @@ class OrderingConstraint(private val boundsMap: ParamBounds, todos.dropInPlace(1) i += 1 } - current.checkWellFormed() + current.adjustDeps(poly, current.boundsMap(poly).nn, add = true) + .checkWellFormed() } // ---------- Updates ------------------------------------------------------------ @@ -591,30 +595,33 @@ class OrderingConstraint(private val boundsMap: ParamBounds, if param == replacement then this.checkWellFormed() else assert(replacement.isValueTypeOrLambda) - var current = - if isRemovable(param.binder) then remove(param.binder) - else updateEntry(this, param, replacement) - def removeParam(ps: List[TypeParamRef]) = ps.filterConserve(param ne _) + val droppedTypeVar = typeVarOfParam(param) - def replaceParam(entry: Type, atPoly: TypeLambda, atIdx: Int): Type = - val pref = atPoly.paramRefs(atIdx) - val newEntry = current.ensureNonCyclic(pref, entry.substParam(param, replacement)) - adjustDeps(newEntry, entry, pref) - newEntry + //println(i"replace $param, $droppedTypeVar with $replacement in $this") + val dropTypeVar = new TypeMap: + override def apply(t: Type): Type = + if t.exists && (t eq droppedTypeVar) then param else mapOver(t) + var current = this + + def removeParam(ps: List[TypeParamRef]) = ps.filterConserve(param ne _) for lo <- lower(param) do current = upperLens.map(this, current, lo, removeParam) for hi <- upper(param) do current = lowerLens.map(this, current, hi, removeParam) current.foreachParam { (p, i) => - current = boundsLens.map(this, current, p, i, - entry => - val newEntry = replaceParam(entry, p, i) - adjustDeps(newEntry, entry, p.paramRefs(i)) - newEntry) + val other = p.paramRefs(i) + if other != param then + val oldEntry = current.entry(other) + val newEntry = current.ensureNonCyclic(other, oldEntry.substParam(param, replacement)) + current = updateEntryNoOrdering(current, other, newEntry, dropTypeVar(oldEntry)) } + + current = + if isRemovable(param.binder) then current.remove(param.binder) + else updateEntry(current, param, replacement) current.dropDeps(param) current.checkWellFormed() end replace diff --git a/compiler/src/dotty/tools/dotc/core/Substituters.scala b/compiler/src/dotty/tools/dotc/core/Substituters.scala index 3e32340b21bd..25cdb5d057f7 100644 --- a/compiler/src/dotty/tools/dotc/core/Substituters.scala +++ b/compiler/src/dotty/tools/dotc/core/Substituters.scala @@ -193,7 +193,7 @@ object Substituters: def apply(tp: Type): Type = substRecThis(tp, from, to, this)(using mapCtx) } - final class SubstParamMap(from: ParamRef, to: Type)(using Context) extends DeepTypeMap, IdempotentCaptRefMap { + class SubstParamMap(from: ParamRef, to: Type)(using Context) extends DeepTypeMap, IdempotentCaptRefMap { def apply(tp: Type): Type = substParam(tp, from, to, this)(using mapCtx) } diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 3243bb242a56..5cc5b6ca3821 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -5497,8 +5497,21 @@ object Types { stop == StopAt.Static && tp.currentSymbol.isStatic && isStaticPrefix(tp.prefix) || stop == StopAt.Package && tp.currentSymbol.is(Package) } + + protected def tyconTypeParams(tp: AppliedType)(using Context): List[ParamInfo] = + tp.tyconTypeParams end VariantTraversal + trait ConstraintAwareTraversal extends VariantTraversal: + override def tyconTypeParams(tp: AppliedType)(using Context): List[ParamInfo] = + tp.tycon match + case tycon: TypeParamRef => + ctx.typerState.constraint.entry(tycon) match + case _: TypeBounds => + case tp1 => if tp1.typeParams.nonEmpty then return tp1.typeParams + case _ => + tp.tyconTypeParams + /** A supertrait for some typemaps that are bijections. Used for capture checking. * BiTypeMaps should map capture references to capture references. */ @@ -5614,7 +5627,7 @@ object Types { derivedSelect(tp, prefix1) case tp: AppliedType => - derivedAppliedType(tp, this(tp.tycon), mapArgs(tp.args, tp.tyconTypeParams)) + derivedAppliedType(tp, this(tp.tycon), mapArgs(tp.args, tyconTypeParams(tp))) case tp: LambdaType => mapOverLambda(tp) @@ -5941,7 +5954,7 @@ object Types { case nil => true } - if (distributeArgs(args, tp.tyconTypeParams)) + if (distributeArgs(args, tyconTypeParams(tp))) range(tp.derivedAppliedType(tycon, loBuf.toList), tp.derivedAppliedType(tycon, hiBuf.toList)) else if tycon.isLambdaSub || args.exists(isRangeOfNonTermTypes) then @@ -6087,7 +6100,7 @@ object Types { } foldArgs(acc, tparams.tail, args.tail) } - foldArgs(this(x, tycon), tp.tyconTypeParams, args) + foldArgs(this(x, tycon), tyconTypeParams(tp), args) case _: BoundType | _: ThisType => x