Skip to content

Commit

Permalink
Merge pull request #6398 from dotty-staging/intersection-based-gadts
Browse files Browse the repository at this point in the history
Intersection based gadts
  • Loading branch information
abgruszecki authored Jun 6, 2019
2 parents fe749ea + 01096ff commit 49dd34d
Show file tree
Hide file tree
Showing 27 changed files with 851 additions and 137 deletions.
11 changes: 8 additions & 3 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -336,16 +336,21 @@ trait ConstraintHandling[AbstractContext] {
* L2 <: L1, and
* U1 <: U2
*
* Both `c1` and `c2` are required to derive from constraint `pre`, possibly
* narrowing it with further bounds.
* Both `c1` and `c2` are required to derive from constraint `pre`, without adding
* any new type variables but possibly narrowing already registered ones with further bounds.
*/
protected final def subsumes(c1: Constraint, c2: Constraint, pre: Constraint)(implicit actx: AbstractContext): Boolean =
if (c2 eq pre) true
else if (c1 eq pre) false
else {
val saved = constraint
try
c2.forallParams(p =>
// We iterate over params of `pre`, instead of `c2` as the documentation may suggest.
// As neither `c1` nor `c2` can have more params than `pre`, this only matters in one edge case.
// Constraint#forallParams only iterates over params that can be directly constrained.
// If `c2` has, compared to `pre`, instantiated a param and we iterated over params of `c2`,
// we could miss that param being instantiated to an incompatible type in `c1`.
pre.forallParams(p =>
c1.contains(p) &&
c2.upper(p).forall(c1.isLess(p, _)) &&
isSubTypeWhenFrozen(c1.nonParamBounds(p), c2.nonParamBounds(p)))
Expand Down
7 changes: 5 additions & 2 deletions compiler/src/dotty/tools/dotc/core/Mode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,11 @@ object Mode {
/** We are in a pattern alternative */
val InPatternAlternative: Mode = newMode(7, "InPatternAlternative")

/** Infer GADT constraints during type comparisons `A <:< B` */
val GADTflexible: Mode = newMode(8, "GADTflexible")
/** Make subtyping checks instead infer constraints necessarily following from given subtyping relation.
*
* This enables changing [[GadtConstraint]] and alters how [[TypeComparer]] approximates constraints.
*/
val GadtConstraintInference: Mode = newMode(8, "GadtConstraintInference")

/** Assume -language:strictEquality */
val StrictEquality: Mode = newMode(9, "StrictEquality")
Expand Down
208 changes: 208 additions & 0 deletions compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
package dotty.tools
package dotc
package core

import Decorators._
import Symbols._
import Types._
import Flags._
import dotty.tools.dotc.reporting.trace
import config.Printers._

trait PatternTypeConstrainer { self: TypeComparer =>

/** Derive type and GADT constraints that necessarily follow from a pattern with the given type matching
* a scrutinee of the given type.
*
* This function breaks down scrutinee and pattern types into subcomponents between which there must be
* a subtyping relationship, and derives constraints from those relationships. We have the following situation
* in case of a (dynamic) pattern match:
*
* StaticScrutineeType PatternType
* \ /
* DynamicScrutineeType
*
* In simple cases, it must hold that `PatternType <: StaticScrutineeType`:
*
* StaticScrutineeType
* | \
* | PatternType
* | /
* DynamicScrutineeType
*
* A good example of a situation where the above must hold is when static scrutinee type is the root of an enum,
* and the pattern is an unapply of a case class, or a case object literal (of that enum).
*
* In slightly more complex cases, we may need to upcast `StaticScrutineeType`:
*
* SharedPatternScrutineeSuperType
* / \
* StaticScrutineeType PatternType
* \ /
* DynamicScrutineeType
*
* This may be the case if the scrutinee is a singleton type or a path-dependent type. It is also the case
* for the following definitions:
*
* trait Expr[T]
* trait IntExpr extends Expr[T]
* trait Const[T] extends Expr[T]
*
* StaticScrutineeType = Const[T]
* PatternType = IntExpr
*
* Union and intersection types are an additional complication - if either scrutinee or pattern are a union type,
* then the above relationships only need to hold for the "leaves" of the types.
*
* Finally, if pattern type contains hk-types applied to concrete types (as opposed to type variables),
* or either scrutinee or pattern type contain type member refinements, the above relationships do not need
* to hold at all. Consider (where `T1`, `T2` are unrelated traits):
*
* StaticScrutineeType = { type T <: T1 }
* PatternType = { type T <: T2 }
*
* In the above situation, DynamicScrutineeType can equal { type T = T1 & T2 }, but there is no useful relationship
* between StaticScrutineeType and PatternType (nor any of their subcomponents). Similarly:
*
* StaticScrutineeType = Option[T1]
* PatternType = Some[T2]
*
* Again, DynamicScrutineeType may equal Some[T1 & T2], and there's no useful relationship between the static
* scrutinee and pattern types. This does not apply if the pattern type is only applied to type variables,
* in which case the subtyping relationship "heals" the type.
*/
def constrainPatternType(pat: Type, scrut: Type): Boolean = trace(i"constrainPatternType($scrut, $pat)", gadts) {

def classesMayBeCompatible: Boolean = {
import Flags._
val patClassSym = pat.widenSingleton.classSymbol
val scrutClassSym = scrut.widenSingleton.classSymbol
!patClassSym.exists || !scrutClassSym.exists || {
if (patClassSym.is(Final)) patClassSym.derivesFrom(scrutClassSym)
else if (scrutClassSym.is(Final)) scrutClassSym.derivesFrom(patClassSym)
else if (!patClassSym.is(Flags.Trait) && !scrutClassSym.is(Flags.Trait))
patClassSym.derivesFrom(scrutClassSym) || scrutClassSym.derivesFrom(patClassSym)
else true
}
}

def stripRefinement(tp: Type): Type = tp match {
case tp: RefinedOrRecType => stripRefinement(tp.parent)
case tp => tp
}

def constrainUpcasted(scrut: Type): Boolean = trace(i"constrainUpcasted($scrut)", gadts) {
val upcasted: Type = scrut match {
case scrut: TypeRef if scrut.symbol.isClass =>
// we do not infer constraints following from all parents for performance reasons
// in principle however, if `A extends B, C`, then `A` can be treated as `B & C`
scrut.firstParent
case scrut @ AppliedType(tycon: TypeRef, _) if tycon.symbol.isClass =>
val patClassSym = pat.classSymbol
// as above, we do not consider all parents for performance reasons
def firstParentSharedWithPat(tp: Type, tpClassSym: ClassSymbol): Symbol = {
var parents = tpClassSym.info.parents
parents match {
case first :: rest =>
if (first.classSymbol == defn.ObjectClass) parents = rest
case _ => ;
}
parents match {
case first :: _ =>
val firstClassSym = first.classSymbol.asClass
val res = if (patClassSym.derivesFrom(firstClassSym)) firstClassSym
else firstParentSharedWithPat(first, firstClassSym)
res
case _ => NoSymbol
}
}
val sym = firstParentSharedWithPat(tycon, tycon.symbol.asClass)
if (sym.exists) scrut.baseType(sym) else NoType
case scrut: TypeProxy => scrut.superType
case _ => NoType
}
if (upcasted.exists)
constrainSimplePatternType(pat, upcasted) || constrainUpcasted(upcasted)
else true
}

scrut.dealias match {
case OrType(scrut1, scrut2) =>
either(constrainPatternType(pat, scrut1), constrainPatternType(pat, scrut2))
case AndType(scrut1, scrut2) =>
constrainPatternType(pat, scrut1) && constrainPatternType(pat, scrut2)
case scrut: RefinedOrRecType =>
constrainPatternType(pat, stripRefinement(scrut))
case scrut => pat.dealias match {
case OrType(pat1, pat2) =>
either(constrainPatternType(pat1, scrut), constrainPatternType(pat2, scrut))
case AndType(pat1, pat2) =>
constrainPatternType(pat1, scrut) && constrainPatternType(pat2, scrut)
case scrut: RefinedOrRecType =>
constrainPatternType(stripRefinement(scrut), pat)
case pat =>
constrainSimplePatternType(pat, scrut) || classesMayBeCompatible && constrainUpcasted(scrut)
}
}
}

/** Constrain "simple" patterns (see `constrainPatternType`).
*
* This function attempts to modify pattern and scrutinee type s.t. the pattern must be a subtype of the scrutinee,
* or otherwise it cannot possibly match. In order to do that, we:
*
* 1. Rely on `constrainPatternType` to break the actual scrutinee/pattern types into subcomponents
* 2. Widen type parameters of scrutinee type that are not invariantly refined (see below) by the pattern type.
* 3. Wrap the pattern type in a skolem to avoid overconstraining top-level abstract types in scrutinee type
* 4. Check that `WidenedScrutineeType <: NarrowedPatternType`
*
* Importantly, note that the pattern type may contain type variables.
*
* ## Invariant refinement
* Essentially, we say that `D[B] extends C[B]` s.t. refines parameter `A` of `trait C[A]` invariantly if
* when `c: C[T]` and `c` is instance of `D`, then necessarily `c: D[T]`. This is violated if `A` is variant:
*
* trait C[+A]
* trait D[+B](val b: B) extends C[B]
* trait E extends D[Any](0) with C[String]
*
* `E` is a counter-example to the above - if `e: E`, then `e: C[String]` and `e` is instance of `D`, but
* it is false that `e: D[String]`! This is a problem if we're constraining a pattern like the below:
*
* def foo[T](c: C[T]): T = c match {
* case d: D[t] => d.b
* }
*
* It'd be unsound for us to say that `t <: T`, even though that follows from `D[t] <: C[T]`.
* Note, however, that if `D` was a final class, we *could* rely on that relationship.
* To support typical case classes, we also assume that this relationship holds for them and their parent traits.
* This is enforced by checking that classes inheriting from case classes do not extend the parent traits of those
* case classes without also appropriately extending the relevant case class
* (see `RefChecks#checkCaseClassInheritanceInvariant`).
*/
def constrainSimplePatternType(patternTp: Type, scrutineeTp: Type): Boolean = {
def refinementIsInvariant(tp: Type): Boolean = tp match {
case tp: ClassInfo => tp.cls.is(Final) || tp.cls.is(Case)
case tp: TypeProxy => refinementIsInvariant(tp.underlying)
case _ => false
}

def widenVariantParams = new TypeMap {
def apply(tp: Type) = mapOver(tp) match {
case tp @ AppliedType(tycon, args) =>
val args1 = args.zipWithConserve(tycon.typeParams)((arg, tparam) =>
if (tparam.paramVariance != 0) TypeBounds.empty else arg
)
tp.derivedAppliedType(tycon, args1)
case tp =>
tp
}
}

val widePt = if (ctx.scala2Mode || refinementIsInvariant(patternTp)) scrutineeTp else widenVariantParams(scrutineeTp)
val narrowTp = SkolemType(patternTp)
trace(i"constraining simple pattern type $narrowTp <:< $widePt", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") {
isSubType(narrowTp, widePt)
}
}
}
29 changes: 23 additions & 6 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ object AbsentContext {

/** Provides methods to compare types.
*/
class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] with PatternTypeConstrainer {
import TypeComparer._
implicit def ctx(implicit nc: AbsentContext): Context = initctx

Expand Down Expand Up @@ -141,6 +141,13 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
*/
private [this] var leftRoot: Type = _

/** Are we forbidden from recording GADT constraints?
*
* This flag is set when we're already in [[Mode.GadtConstraintInference]],
* to signify that we temporarily cannot record any GADT constraints.
*/
private[this] var frozenGadt = false

protected def isSubType(tp1: Type, tp2: Type, a: ApproxState): Boolean = {
val savedApprox = approx
val savedLeftRoot = leftRoot
Expand Down Expand Up @@ -840,8 +847,18 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
gadtBoundsContain(tycon1sym, tycon2) ||
gadtBoundsContain(tycon2sym, tycon1)
) &&
isSubType(tycon1.prefix, tycon2.prefix) &&
isSubArgs(args1, args2, tp1, tparams)
isSubType(tycon1.prefix, tycon2.prefix) && {
// check both tycons to deal with the case when they are equal b/c of GADT constraint
val tyconIsInjective = tycon1sym.isClass || tycon2sym.isClass
def checkSubArgs() = isSubArgs(args1, args2, tp1, tparams)
// we only record GADT constraints if tycon is guaranteed to be injective
if (tyconIsInjective) checkSubArgs()
else {
val savedFrozenGadt = frozenGadt
frozenGadt = true
try checkSubArgs() finally frozenGadt = savedFrozenGadt
}
}
if (res && touchedGADTs) GADTused = true
res
case _ =>
Expand Down Expand Up @@ -1227,8 +1244,8 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
* @see [[sufficientEither]] for the normal case
* @see [[necessaryEither]] for the GADTFlexible case
*/
private def either(op1: => Boolean, op2: => Boolean): Boolean =
if (ctx.mode.is(Mode.GADTflexible)) necessaryEither(op1, op2) else sufficientEither(op1, op2)
protected def either(op1: => Boolean, op2: => Boolean): Boolean =
if (ctx.mode.is(Mode.GadtConstraintInference)) necessaryEither(op1, op2) else sufficientEither(op1, op2)

/** Returns true iff the result of evaluating either `op1` or `op2` is true,
* trying at the same time to keep the constraint as wide as possible.
Expand Down Expand Up @@ -1476,7 +1493,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
*/
private def narrowGADTBounds(tr: NamedType, bound: Type, approx: ApproxState, isUpper: Boolean): Boolean = {
val boundImprecise = approx.high || approx.low
ctx.mode.is(Mode.GADTflexible) && !frozenConstraint && !boundImprecise && {
ctx.mode.is(Mode.GadtConstraintInference) && !frozenGadt && !frozenConstraint && !boundImprecise && {
val tparam = tr.symbol
gadts.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam)}")
if (bound.isRef(tparam)) false
Expand Down
27 changes: 5 additions & 22 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1085,21 +1085,6 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>

def fromScala2x = unapplyFn.symbol.exists && (unapplyFn.symbol.owner is Scala2x)

/** Is `subtp` a subtype of `tp` or of some generalization of `tp`?
* The generalizations of a type T are the smallest set G such that
*
* - T is in G
* - If a typeref R in G represents a class or trait, R's superclass is in G.
* - If a type proxy P is not a reference to a class, P's supertype is in G
*/
def isSubTypeOfParent(subtp: Type, tp: Type)(implicit ctx: Context): Boolean =
if (constrainPatternType(subtp, tp)) true
else tp match {
case tp: TypeRef if tp.symbol.isClass => isSubTypeOfParent(subtp, tp.firstParent)
case tp: TypeProxy => isSubTypeOfParent(subtp, tp.superType)
case _ => false
}

unapplyFn.tpe.widen match {
case mt: MethodType if mt.paramInfos.length == 1 =>
val unapplyArgType = mt.paramInfos.head
Expand All @@ -1109,17 +1094,15 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
unapp.println(i"case 1 $unapplyArgType ${ctx.typerState.constraint}")
fullyDefinedType(unapplyArgType, "pattern selector", tree.span)
selType.dropAnnot(defn.UncheckedAnnot) // need to drop @unchecked. Just because the selector is @unchecked, the pattern isn't.
} else if (isSubTypeOfParent(unapplyArgType, selType)(ctx.addMode(Mode.GADTflexible))) {
} else {
// We ignore whether constraining the pattern succeeded.
// Constraining only fails if the pattern cannot possibly match,
// but useless pattern checks detect more such cases, so we simply rely on them instead.
ctx.addMode(Mode.GadtConstraintInference).typeComparer.constrainPatternType(unapplyArgType, selType)
val patternBound = maximizeType(unapplyArgType, tree.span, fromScala2x)
if (patternBound.nonEmpty) unapplyFn = addBinders(unapplyFn, patternBound)
unapp.println(i"case 2 $unapplyArgType ${ctx.typerState.constraint}")
unapplyArgType
} else {
unapp.println("Neither sub nor super")
unapp.println(TypeComparer.explained(implicit ctx => unapplyArgType <:< selType))
errorType(
ex"Pattern type $unapplyArgType is neither a subtype nor a supertype of selector type $selType",
tree.sourcePos)
}
val dummyArg = dummyTreeOfType(ownType)
val unapplyApp = typedExpr(untpd.TypedSplice(Apply(unapplyFn, dummyArg :: Nil)))
Expand Down
Loading

0 comments on commit 49dd34d

Please sign in to comment.