Skip to content

Commit

Permalink
Merge pull request #15632 from dotty-staging/fix-14770
Browse files Browse the repository at this point in the history
Instantiate more type variables to hard unions
  • Loading branch information
odersky authored Sep 1, 2022
2 parents 398b72e + dfcfb6b commit bf03086
Show file tree
Hide file tree
Showing 12 changed files with 140 additions and 58 deletions.
6 changes: 6 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Constraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,12 @@ abstract class Constraint extends Showable {
*/
def subst(from: TypeLambda, to: TypeLambda)(using Context): This

/** Is `tv` marked as hard in the constraint? */
def isHard(tv: TypeVar): Boolean

/** The same as this constraint, but with `tv` marked as hard. */
def withHard(tv: TypeVar)(using Context): This

/** Gives for each instantiated type var that does not yet have its `inst` field
* set, the instance value stored in the constraint. Storing instances in constraints
* is done only in a temporary way for contexts that may be retracted
Expand Down
39 changes: 31 additions & 8 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import typer.ProtoTypes.{newTypeVar, representedParamRef}
import UnificationDirection.*
import NameKinds.AvoidNameKind
import util.SimpleIdentitySet
import NullOpsDecorator.stripNull

/** Methods for adding constraints and solving them.
*
Expand Down Expand Up @@ -613,8 +614,11 @@ trait ConstraintHandling {
* 1. If `inst` is a singleton type, or a union containing some singleton types,
* widen (all) the singleton type(s), provided the result is a subtype of `bound`.
* (i.e. `inst.widenSingletons <:< bound` succeeds with satisfiable constraint)
* 2. If `inst` is a union type, approximate the union type from above by an intersection
* of all common base types, provided the result is a subtype of `bound`.
* 2a. If `inst` is a union type and `widenUnions` is true, approximate the union type
* from above by an intersection of all common base types, provided the result
* is a subtype of `bound`.
* 2b. If `inst` is a union type and `widenUnions` is false, turn it into a hard
* union type (except for unions | Null, which are kept in the state they were).
* 3. Widen some irreducible applications of higher-kinded types to wildcard arguments
* (see @widenIrreducible).
* 4. Drop transparent traits from intersections (see @dropTransparentTraits).
Expand All @@ -627,10 +631,12 @@ trait ConstraintHandling {
* At this point we also drop the @Repeated annotation to avoid inferring type arguments with it,
* as those could leak the annotation to users (see run/inferred-repeated-result).
*/
def widenInferred(inst: Type, bound: Type)(using Context): Type =
def widenInferred(inst: Type, bound: Type, widenUnions: Boolean)(using Context): Type =
def widenOr(tp: Type) =
val tpw = tp.widenUnion
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
if widenUnions then
val tpw = tp.widenUnion
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
else tp.hardenUnions

def widenSingle(tp: Type) =
val tpw = tp.widenSingletons
Expand All @@ -650,6 +656,23 @@ trait ConstraintHandling {
wideInst.dropRepeatedAnnot
end widenInferred

/** Convert all toplevel union types in `tp` to hard unions */
extension (tp: Type) private def hardenUnions(using Context): Type = tp.widen match
case tp: AndType =>
tp.derivedAndType(tp.tp1.hardenUnions, tp.tp2.hardenUnions)
case tp: RefinedType =>
tp.derivedRefinedType(tp.parent.hardenUnions, tp.refinedName, tp.refinedInfo)
case tp: RecType =>
tp.rebind(tp.parent.hardenUnions)
case tp: HKTypeLambda =>
tp.derivedLambdaType(resType = tp.resType.hardenUnions)
case tp: OrType =>
val tp1 = tp.stripNull
if tp1 ne tp then tp.derivedOrType(tp1.hardenUnions, defn.NullType)
else tp.derivedOrType(tp.tp1.hardenUnions, tp.tp2.hardenUnions, soft = false)
case _ =>
tp

/** The instance type of `param` in the current constraint (which contains `param`).
* If `fromBelow` is true, the instance type is the lub of the parameter's
* lower bounds; otherwise it is the glb of its upper bounds. However,
Expand All @@ -658,18 +681,18 @@ trait ConstraintHandling {
* The instance type is not allowed to contain references to types nested deeper
* than `maxLevel`.
*/
def instanceType(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int)(using Context): Type = {
def instanceType(param: TypeParamRef, fromBelow: Boolean, widenUnions: Boolean, maxLevel: Int)(using Context): Type = {
val approx = approximation(param, fromBelow, maxLevel).simplified
if fromBelow then
val widened = widenInferred(approx, param)
val widened = widenInferred(approx, param, widenUnions)
// Widening can add extra constraints, in particular the widened type might
// be a type variable which is now instantiated to `param`, and therefore
// cannot be used as an instantiation of `param` without creating a loop.
// If that happens, we run `instanceType` again to find a new instantation.
// (we do not check for non-toplevel occurences: those should never occur
// since `addOneBound` disallows recursive lower bounds).
if constraint.occursAtToplevel(param, widened) then
instanceType(param, fromBelow, maxLevel)
instanceType(param, fromBelow, widenUnions, maxLevel)
else
widened
else
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/core/GadtConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import Decorators._
import Contexts._
import Types._
import Symbols._
import util.SimpleIdentityMap
import util.{SimpleIdentitySet, SimpleIdentityMap}
import collection.mutable
import printing._

Expand Down Expand Up @@ -68,7 +68,7 @@ final class ProperGadtConstraint private(
import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr}

def this() = this(
myConstraint = new OrderingConstraint(SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentityMap.empty),
myConstraint = new OrderingConstraint(SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentitySet.empty),
mapping = SimpleIdentityMap.empty,
reverseMapping = SimpleIdentityMap.empty,
wasConstrained = false
Expand Down
35 changes: 23 additions & 12 deletions compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package dotc
package core

import Types._, Contexts._, Symbols._, Decorators._, TypeApplications._
import util.SimpleIdentityMap
import util.{SimpleIdentitySet, SimpleIdentityMap}
import collection.mutable
import printing.Printer
import printing.Texts._
Expand All @@ -24,12 +24,14 @@ object OrderingConstraint {
/** The type of `OrderingConstraint#lowerMap`, `OrderingConstraint#upperMap` */
type ParamOrdering = ArrayValuedMap[List[TypeParamRef]]

/** A new constraint with given maps */
private def newConstraint(boundsMap: ParamBounds, lowerMap: ParamOrdering, upperMap: ParamOrdering)(using Context) : OrderingConstraint =
/** A new constraint with given maps and given set of hard typevars */
private def newConstraint(
boundsMap: ParamBounds, lowerMap: ParamOrdering, upperMap: ParamOrdering,
hardVars: TypeVars)(using Context) : OrderingConstraint =
if boundsMap.isEmpty && lowerMap.isEmpty && upperMap.isEmpty then
empty
else
val result = new OrderingConstraint(boundsMap, lowerMap, upperMap)
val result = new OrderingConstraint(boundsMap, lowerMap, upperMap, hardVars)
if ctx.run != null then ctx.run.nn.recordConstraintSize(result, result.boundsMap.size)
result

Expand Down Expand Up @@ -91,28 +93,28 @@ object OrderingConstraint {
def entries(c: OrderingConstraint, poly: TypeLambda): Array[Type] | Null =
c.boundsMap(poly)
def updateEntries(c: OrderingConstraint, poly: TypeLambda, entries: Array[Type])(using Context): OrderingConstraint =
newConstraint(c.boundsMap.updated(poly, entries), c.lowerMap, c.upperMap)
newConstraint(c.boundsMap.updated(poly, entries), c.lowerMap, c.upperMap, c.hardVars)
def initial = NoType
}

val lowerLens: ConstraintLens[List[TypeParamRef]] = new ConstraintLens[List[TypeParamRef]] {
def entries(c: OrderingConstraint, poly: TypeLambda): Array[List[TypeParamRef]] | Null =
c.lowerMap(poly)
def updateEntries(c: OrderingConstraint, poly: TypeLambda, entries: Array[List[TypeParamRef]])(using Context): OrderingConstraint =
newConstraint(c.boundsMap, c.lowerMap.updated(poly, entries), c.upperMap)
newConstraint(c.boundsMap, c.lowerMap.updated(poly, entries), c.upperMap, c.hardVars)
def initial = Nil
}

val upperLens: ConstraintLens[List[TypeParamRef]] = new ConstraintLens[List[TypeParamRef]] {
def entries(c: OrderingConstraint, poly: TypeLambda): Array[List[TypeParamRef]] | Null =
c.upperMap(poly)
def updateEntries(c: OrderingConstraint, poly: TypeLambda, entries: Array[List[TypeParamRef]])(using Context): OrderingConstraint =
newConstraint(c.boundsMap, c.lowerMap, c.upperMap.updated(poly, entries))
newConstraint(c.boundsMap, c.lowerMap, c.upperMap.updated(poly, entries), c.hardVars)
def initial = Nil
}

@sharable
val empty = new OrderingConstraint(SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentityMap.empty)
val empty = new OrderingConstraint(SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentitySet.empty)
}

import OrderingConstraint._
Expand All @@ -134,10 +136,13 @@ import OrderingConstraint._
* @param upperMap a map from TypeLambdas to arrays. Each array entry corresponds
* to a parameter P of the type lambda; it contains all constrained parameters
* Q that are known to be greater than P, i.e. P <: Q.
* @param hardVars a set of type variables that are marked as hard and therefore will not
* undergo a `widenUnion` when instantiated to their lower bound.
*/
class OrderingConstraint(private val boundsMap: ParamBounds,
private val lowerMap : ParamOrdering,
private val upperMap : ParamOrdering) extends Constraint {
private val upperMap : ParamOrdering,
private val hardVars : TypeVars) extends Constraint {

import UnificationDirection.*

Expand Down Expand Up @@ -277,7 +282,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
val entries1 = new Array[Type](nparams * 2)
poly.paramInfos.copyToArray(entries1, 0)
tvars.copyToArray(entries1, nparams)
newConstraint(boundsMap.updated(poly, entries1), lowerMap, upperMap).init(poly)
newConstraint(boundsMap.updated(poly, entries1), lowerMap, upperMap, hardVars).init(poly)
}

/** Split dependent parameters off the bounds for parameters in `poly`.
Expand Down Expand Up @@ -478,7 +483,8 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
}
po.remove(pt).mapValuesNow(removeFromBoundss)
}
newConstraint(boundsMap.remove(pt), removeFromOrdering(lowerMap), removeFromOrdering(upperMap))
val hardVars1 = pt.paramRefs.foldLeft(hardVars)((hvs, param) => hvs - typeVarOfParam(param))
newConstraint(boundsMap.remove(pt), removeFromOrdering(lowerMap), removeFromOrdering(upperMap), hardVars1)
.checkNonCyclic()
}

Expand All @@ -505,7 +511,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
def swapKey[T](m: ArrayValuedMap[T]) =
val info = m(from)
if info == null then m else m.remove(from).updated(to, info)
var current = newConstraint(swapKey(boundsMap), swapKey(lowerMap), swapKey(upperMap))
var current = newConstraint(swapKey(boundsMap), swapKey(lowerMap), swapKey(upperMap), hardVars)
def subst[T <: Type](x: T): T = x.subst(from, to).asInstanceOf[T]
current.foreachParam {(p, i) =>
current = boundsLens.map(this, current, p, i, subst)
Expand All @@ -515,6 +521,11 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
constr.println(i"renamed $this to $current")
current.checkNonCyclic()

def isHard(tv: TypeVar) = hardVars.contains(tv)

def withHard(tv: TypeVar)(using Context) =
newConstraint(boundsMap, lowerMap, upperMap, hardVars + tv)

def instType(tvar: TypeVar): Type = entry(tvar.origin) match
case _: TypeBounds => NoType
case tp: TypeParamRef => typeVarOfParam(tp).orElse(tp)
Expand Down
69 changes: 39 additions & 30 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -485,33 +485,42 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
false
}

// If LHS is a hard union, constrain any type variables of the RHS with it as lower bound
// before splitting the LHS into its constituents. That way, the RHS variables are
// constraint by the hard union and can be instantiated to it. If we just split and add
// the two parts of the LHS separately to the constraint, the lower bound would become
// a soft union.
def constrainRHSVars(tp2: Type): Boolean = tp2.dealiasKeepRefiningAnnots match
case tp2: TypeParamRef if constraint contains tp2 => compareTypeParamRef(tp2)
case AndType(tp21, tp22) => constrainRHSVars(tp21) && constrainRHSVars(tp22)
case _ => true
/** Mark toplevel type vars in `tp2` as hard in the current constraint */
def hardenTypeVars(tp2: Type): Unit = tp2.dealiasKeepRefiningAnnots match
case tvar: TypeVar if constraint.contains(tvar.origin) =>
constraint = constraint.withHard(tvar)
case tp2: TypeParamRef if constraint.contains(tp2) =>
hardenTypeVars(constraint.typeVarOfParam(tp2))
case tp2: AndOrType =>
hardenTypeVars(tp2.tp1)
hardenTypeVars(tp2.tp2)
case _ =>

widenOK
|| joinOK
|| (tp1.isSoft || constrainRHSVars(tp2)) && recur(tp11, tp2) && recur(tp12, tp2)
|| containsAnd(tp1)
&& !joined
&& {
joined = true
try inFrozenGadt(recur(tp1.join, tp2))
finally joined = false
}
// An & on the left side loses information. We compensate by also trying the join.
// This is less ad-hoc than it looks since we produce joins in type inference,
// and then need to check that they are indeed supertypes of the original types
// under -Ycheck. Test case is i7965.scala.
// On the other hand, we could get a combinatorial explosion by applying such joins
// recursively, so we do it only once. See i14870.scala as a test case, which would
// loop for a very long time without the recursion brake.
val res = widenOK || joinOK
|| recur(tp11, tp2) && recur(tp12, tp2)
|| containsAnd(tp1)
&& !joined
&& {
joined = true
try inFrozenGadt(recur(tp1.join, tp2))
finally joined = false
}
// An & on the left side loses information. We compensate by also trying the join.
// This is less ad-hoc than it looks since we produce joins in type inference,
// and then need to check that they are indeed supertypes of the original types
// under -Ycheck. Test case is i7965.scala.
// On the other hand, we could get a combinatorial explosion by applying such joins
// recursively, so we do it only once. See i14870.scala as a test case, which would
// loop for a very long time without the recursion brake.

if res && !tp1.isSoft && state.isCommittable then
// We use a heuristic here where every toplevel type variable on the right hand side
// is marked so that it converts all soft unions in its lower bound to hard unions
// before it is instantiated. The reason is that the variable's instance type will
// be a supertype of (decomposed and reconstituted) `tp1`.
hardenTypeVars(tp2)

res

case CapturingType(parent1, refs1) =>
if subCaptures(refs1, tp2.captureSet, frozenConstraint).isOK && sameBoxed(tp1, tp2, refs1)
Expand Down Expand Up @@ -2960,8 +2969,8 @@ object TypeComparer {
def subtypeCheckInProgress(using Context): Boolean =
comparing(_.subtypeCheckInProgress)

def instanceType(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type =
comparing(_.instanceType(param, fromBelow, maxLevel))
def instanceType(param: TypeParamRef, fromBelow: Boolean, widenUnions: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type =
comparing(_.instanceType(param, fromBelow, widenUnions, maxLevel))

def approximation(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type =
comparing(_.approximation(param, fromBelow, maxLevel))
Expand All @@ -2981,8 +2990,8 @@ object TypeComparer {
def addToConstraint(tl: TypeLambda, tvars: List[TypeVar])(using Context): Boolean =
comparing(_.addToConstraint(tl, tvars))

def widenInferred(inst: Type, bound: Type)(using Context): Type =
comparing(_.widenInferred(inst, bound))
def widenInferred(inst: Type, bound: Type, widenUnions: Boolean)(using Context): Type =
comparing(_.widenInferred(inst, bound, widenUnions))

def dropTransparentTraits(tp: Type, bound: Type)(using Context): Type =
comparing(_.dropTransparentTraits(tp, bound))
Expand Down
4 changes: 3 additions & 1 deletion compiler/src/dotty/tools/dotc/core/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,9 @@ object TypeOps:
override def apply(tp: Type): Type = tp match
case tp: TypeVar if mapCtx.typerState.constraint.contains(tp) =>
val lo = TypeComparer.instanceType(
tp.origin, fromBelow = variance > 0 || variance == 0 && tp.hasLowerBound)(using mapCtx)
tp.origin,
fromBelow = variance > 0 || variance == 0 && tp.hasLowerBound,
widenUnions = tp.widenUnions)(using mapCtx)
val lo1 = apply(lo)
if (lo1 ne lo) lo1 else tp
case _ =>
Expand Down
5 changes: 4 additions & 1 deletion compiler/src/dotty/tools/dotc/core/TyperState.scala
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,14 @@ class TyperState() {
constraint.contains(tl) || other.isRemovable(tl) || {
val tvars = tl.paramRefs.map(other.typeVarOfParam(_)).collect { case tv: TypeVar => tv }
if this.isCommittable then
tvars.foreach(tvar => if !tvar.inst.exists && !isOwnedAnywhere(this, tvar) then includeVar(tvar))
tvars.foreach(tvar =>
if !tvar.inst.exists && !isOwnedAnywhere(this, tvar) then includeVar(tvar))
typeComparer.addToConstraint(tl, tvars)
}) &&
// Integrate the additional constraints on type variables from `other`
// and merge hardness markers
constraint.uninstVars.forall(tv =>
if other.isHard(tv) then constraint = constraint.withHard(tv)
val p = tv.origin
val otherLos = other.lower(p)
val otherHis = other.upper(p)
Expand Down
5 changes: 4 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4714,12 +4714,15 @@ object Types {
* is also a singleton type.
*/
def instantiate(fromBelow: Boolean)(using Context): Type =
val tp = TypeComparer.instanceType(origin, fromBelow, nestingLevel)
val tp = TypeComparer.instanceType(origin, fromBelow, widenUnions, nestingLevel)
if myInst.exists then // The line above might have triggered instantiation of the current type variable
myInst
else
instantiateWith(tp)

/** Widen unions when instantiating this variable in the current context? */
def widenUnions(using Context): Boolean = !ctx.typerState.constraint.isHard(this)

/** For uninstantiated type variables: the entry in the constraint (either bounds or
* provisional instance value)
*/
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1888,7 +1888,7 @@ class Namer { typer: Typer =>
TypeOps.simplify(tp.widenTermRefExpr,
if defaultTp.exists then TypeOps.SimplifyKeepUnchecked() else null) match
case ctp: ConstantType if sym.isInlineVal => ctp
case tp => TypeComparer.widenInferred(tp, pt)
case tp => TypeComparer.widenInferred(tp, pt, widenUnions = true)

// Replace aliases to Unit by Unit itself. If we leave the alias in
// it would be erased to BoxedUnit.
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Synthesizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
val tparams = poly.paramRefs
val variances = childClass.typeParams.map(_.paramVarianceSign)
val instanceTypes = tparams.lazyZip(variances).map((tparam, variance) =>
TypeComparer.instanceType(tparam, fromBelow = variance < 0)
TypeComparer.instanceType(tparam, fromBelow = variance < 0, widenUnions = true)
)
val instanceType = resType.substParams(poly, instanceTypes)
// this is broken in tests/run/i13332intersection.scala,
Expand Down
Loading

0 comments on commit bf03086

Please sign in to comment.