diff --git a/compiler/src/dotty/tools/dotc/config/Printers.scala b/compiler/src/dotty/tools/dotc/config/Printers.scala index 0c888b849857..b1f2afe6e463 100644 --- a/compiler/src/dotty/tools/dotc/config/Printers.scala +++ b/compiler/src/dotty/tools/dotc/config/Printers.scala @@ -20,6 +20,7 @@ object Printers { val dottydoc: Printer = noPrinter val exhaustivity: Printer = noPrinter val gadts: Printer = noPrinter + val gadtsConstr: Printer = noPrinter val hk: Printer = noPrinter val implicits: Printer = noPrinter val implicitsDetailed: Printer = noPrinter diff --git a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala index aff9883cf294..e8e7b534dc3e 100644 --- a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala +++ b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala @@ -18,15 +18,18 @@ import config.Printers.{constr, typr} * By comparison: Constraint handlers are parts of type comparers and can use their functionality. * Constraint handlers update the current constraint as a side effect. */ -trait ConstraintHandling { +trait ConstraintHandling[AbstractContext] { - implicit val ctx: Context + def constr_println(msg: => String): Unit = constr.println(msg) + def typr_println(msg: => String): Unit = typr.println(msg) - protected def isSubType(tp1: Type, tp2: Type): Boolean - protected def isSameType(tp1: Type, tp2: Type): Boolean + implicit def ctx(implicit ac: AbstractContext): Context - val state: TyperState - import state.constraint + protected def isSubType(tp1: Type, tp2: Type)(implicit actx: AbstractContext): Boolean + protected def isSameType(tp1: Type, tp2: Type)(implicit actx: AbstractContext): Boolean + + protected def constraint: Constraint + protected def constraint_=(c: Constraint): Unit private[this] var addConstraintInvocations = 0 @@ -50,7 +53,20 @@ trait ConstraintHandling { */ protected var comparedTypeLambdas: Set[TypeLambda] = Set.empty - protected def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean): Boolean = + /** Gives for each instantiated type var that does not yet have its `inst` field + * set, the instance value stored in the constraint. Storing instances in constraints + * is done only in a temporary way for contexts that may be retracted + * without also retracting the type var as a whole. + */ + def instType(tvar: TypeVar): Type = constraint.entry(tvar.origin) match { + case _: TypeBounds => NoType + case tp: TypeParamRef => + var tvar1 = constraint.typeVarOfParam(tp) + if (tvar1.exists) tvar1 else tp + case tp => tp + } + + protected def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean)(implicit actx: AbstractContext): Boolean = !constraint.contains(param) || { def occursIn(bound: Type): Boolean = { val b = bound.dealias @@ -100,34 +116,34 @@ trait ConstraintHandling { private def location(implicit ctx: Context) = "" // i"in ${ctx.typerState.stateChainStr}" // use for debugging - protected def addUpperBound(param: TypeParamRef, bound: Type): Boolean = { + protected def addUpperBound(param: TypeParamRef, bound: Type)(implicit actx: AbstractContext): Boolean = { def description = i"constraint $param <: $bound to\n$constraint" if (bound.isRef(defn.NothingClass) && ctx.typerState.isGlobalCommittable) { def msg = s"!!! instantiated to Nothing: $param, constraint = ${constraint.show}" if (Config.failOnInstantiationToNothing) assert(false, msg) else ctx.log(msg) } - constr.println(i"adding $description$location") + constr_println(i"adding $description$location") val lower = constraint.lower(param) val res = addOneBound(param, bound, isUpper = true) && lower.forall(addOneBound(_, bound, isUpper = true)) - constr.println(i"added $description = $res$location") + constr_println(i"added $description = $res$location") res } - protected def addLowerBound(param: TypeParamRef, bound: Type): Boolean = { + protected def addLowerBound(param: TypeParamRef, bound: Type)(implicit actx: AbstractContext): Boolean = { def description = i"constraint $param >: $bound to\n$constraint" - constr.println(i"adding $description") + constr_println(i"adding $description") val upper = constraint.upper(param) val res = addOneBound(param, bound, isUpper = false) && upper.forall(addOneBound(_, bound, isUpper = false)) - constr.println(i"added $description = $res$location") + constr_println(i"added $description = $res$location") res } - protected def addLess(p1: TypeParamRef, p2: TypeParamRef): Boolean = { + protected def addLess(p1: TypeParamRef, p2: TypeParamRef)(implicit actx: AbstractContext): Boolean = { def description = i"ordering $p1 <: $p2 to\n$constraint" val res = if (constraint.isLess(p2, p1)) unify(p2, p1) @@ -136,20 +152,20 @@ trait ConstraintHandling { val up2 = p2 :: constraint.exclusiveUpper(p2, p1) val lo1 = constraint.nonParamBounds(p1).lo val hi2 = constraint.nonParamBounds(p2).hi - constr.println(i"adding $description down1 = $down1, up2 = $up2$location") + constr_println(i"adding $description down1 = $down1, up2 = $up2$location") constraint = constraint.addLess(p1, p2) down1.forall(addOneBound(_, hi2, isUpper = true)) && up2.forall(addOneBound(_, lo1, isUpper = false)) } - constr.println(i"added $description = $res$location") + constr_println(i"added $description = $res$location") res } /** Make p2 = p1, transfer all bounds of p2 to p1 * @pre less(p1)(p2) */ - private def unify(p1: TypeParamRef, p2: TypeParamRef): Boolean = { - constr.println(s"unifying $p1 $p2") + private def unify(p1: TypeParamRef, p2: TypeParamRef)(implicit actx: AbstractContext): Boolean = { + constr_println(s"unifying $p1 $p2") assert(constraint.isLess(p1, p2)) val down = constraint.exclusiveLower(p2, p1) val up = constraint.exclusiveUpper(p1, p2) @@ -163,7 +179,7 @@ trait ConstraintHandling { } - protected def isSubType(tp1: Type, tp2: Type, whenFrozen: Boolean): Boolean = { + protected def isSubType(tp1: Type, tp2: Type, whenFrozen: Boolean)(implicit actx: AbstractContext): Boolean = { if (whenFrozen) isSubTypeWhenFrozen(tp1, tp2) else @@ -182,13 +198,13 @@ trait ConstraintHandling { } } - final def isSubTypeWhenFrozen(tp1: Type, tp2: Type): Boolean = inFrozenConstraint(isSubType(tp1, tp2)) - final def isSameTypeWhenFrozen(tp1: Type, tp2: Type): Boolean = inFrozenConstraint(isSameType(tp1, tp2)) + final def isSubTypeWhenFrozen(tp1: Type, tp2: Type)(implicit actx: AbstractContext): Boolean = inFrozenConstraint(isSubType(tp1, tp2)) + final def isSameTypeWhenFrozen(tp1: Type, tp2: Type)(implicit actx: AbstractContext): Boolean = inFrozenConstraint(isSameType(tp1, tp2)) /** Test whether the lower bounds of all parameters in this * constraint are a solution to the constraint. */ - protected final def isSatisfiable: Boolean = + protected final def isSatisfiable(implicit actx: AbstractContext): Boolean = constraint.forallParams { param => val TypeBounds(lo, hi) = constraint.entry(param) isSubType(lo, hi) || { @@ -207,7 +223,7 @@ trait ConstraintHandling { * @return the instantiating type * @pre `param` is in the constraint's domain. */ - final def approximation(param: TypeParamRef, fromBelow: Boolean): Type = { + final def approximation(param: TypeParamRef, fromBelow: Boolean)(implicit actx: AbstractContext): Type = { val avoidParam = new TypeMap { override def stopAtStatic = true def avoidInArg(arg: Type): Type = @@ -247,7 +263,7 @@ trait ConstraintHandling { case _: TypeBounds => val bound = if (fromBelow) constraint.fullLowerBound(param) else constraint.fullUpperBound(param) val inst = avoidParam(bound) - typr.println(s"approx ${param.show}, from below = $fromBelow, bound = ${bound.show}, inst = ${inst.show}") + typr_println(s"approx ${param.show}, from below = $fromBelow, bound = ${bound.show}, inst = ${inst.show}") inst case inst => assert(inst.exists, i"param = $param\nconstraint = $constraint") @@ -261,7 +277,7 @@ trait ConstraintHandling { * 2. If `tp` is a union type, yet upper bound is not a union type, * approximate the union type from above by an intersection of all common base types. */ - def widenInferred(tp: Type, bound: Type): Type = { + def widenInferred(tp: Type, bound: Type)(implicit actx: AbstractContext): Type = { def isMultiSingleton(tp: Type): Boolean = tp.stripAnnots match { case tp: SingletonType => true case AndType(tp1, tp2) => isMultiSingleton(tp1) | isMultiSingleton(tp2) @@ -294,7 +310,7 @@ trait ConstraintHandling { * a lower bound instantiation can be a singleton type only if the upper bound * is also a singleton type. */ - def instanceType(param: TypeParamRef, fromBelow: Boolean): Type = { + def instanceType(param: TypeParamRef, fromBelow: Boolean)(implicit actx: AbstractContext): Type = { val inst = approximation(param, fromBelow).simplified if (fromBelow) widenInferred(inst, constraint.fullUpperBound(param)) else inst } @@ -309,7 +325,7 @@ trait ConstraintHandling { * Both `c1` and `c2` are required to derive from constraint `pre`, possibly * narrowing it with further bounds. */ - protected final def subsumes(c1: Constraint, c2: Constraint, pre: Constraint): Boolean = + protected final def subsumes(c1: Constraint, c2: Constraint, pre: Constraint)(implicit actx: AbstractContext): Boolean = if (c2 eq pre) true else if (c1 eq pre) false else { @@ -323,7 +339,7 @@ trait ConstraintHandling { } /** The current bounds of type parameter `param` */ - def bounds(param: TypeParamRef): TypeBounds = { + def bounds(param: TypeParamRef)(implicit actx: AbstractContext): TypeBounds = { val e = constraint.entry(param) if (e.exists) e.bounds else { @@ -337,7 +353,7 @@ trait ConstraintHandling { * and propagate all bounds. * @param tvars See Constraint#add */ - def addToConstraint(tl: TypeLambda, tvars: List[TypeVar]): Boolean = + def addToConstraint(tl: TypeLambda, tvars: List[TypeVar])(implicit actx: AbstractContext): Boolean = checkPropagated(i"initialized $tl") { constraint = constraint.add(tl, tvars) tl.paramRefs.forall { param => @@ -346,7 +362,7 @@ trait ConstraintHandling { val lower = constraint.lower(param) val upper = constraint.upper(param) if (lower.nonEmpty && !bounds.lo.isRef(defn.NothingClass) || - upper.nonEmpty && !bounds.hi.isRef(defn.AnyClass)) constr.println(i"INIT*** $tl") + upper.nonEmpty && !bounds.hi.isRef(defn.AnyClass)) constr_println(i"INIT*** $tl") lower.forall(addOneBound(_, bounds.hi, isUpper = true)) && upper.forall(addOneBound(_, bounds.lo, isUpper = false)) case _ => @@ -365,7 +381,7 @@ trait ConstraintHandling { * This holds if `TypeVarsMissContext` is set unless `param` is a part * of a MatchType that is currently normalized. */ - final def assumedTrue(param: TypeParamRef): Boolean = + final def assumedTrue(param: TypeParamRef)(implicit actx: AbstractContext): Boolean = ctx.mode.is(Mode.TypevarsMissContext) && (caseLambda `ne` param.binder) /** Add constraint `param <: bound` if `fromBelow` is false, `param >: bound` otherwise. @@ -375,7 +391,7 @@ trait ConstraintHandling { * not be AndTypes and lower bounds may not be OrTypes. This is assured by the * way isSubType is organized. */ - protected def addConstraint(param: TypeParamRef, bound: Type, fromBelow: Boolean): Boolean = { + protected def addConstraint(param: TypeParamRef, bound: Type, fromBelow: Boolean)(implicit actx: AbstractContext): Boolean = { def description = i"constr $param ${if (fromBelow) ">:" else "<:"} $bound:\n$constraint" //checkPropagated(s"adding $description")(true) // DEBUG in case following fails checkPropagated(s"added $description") { @@ -491,7 +507,7 @@ trait ConstraintHandling { } /** Instantiate `param` to `tp` if the constraint stays satisfiable */ - protected def tryInstantiate(param: TypeParamRef, tp: Type): Boolean = { + protected def tryInstantiate(param: TypeParamRef, tp: Type)(implicit actx: AbstractContext): Boolean = { val saved = constraint constraint = if (addConstraint(param, tp, fromBelow = true) && @@ -501,7 +517,7 @@ trait ConstraintHandling { } /** Check that constraint is fully propagated. See comment in Config.checkConstraintsPropagated */ - def checkPropagated(msg: => String)(result: Boolean): Boolean = { + def checkPropagated(msg: => String)(result: Boolean)(implicit actx: AbstractContext): Boolean = { if (Config.checkConstraintsPropagated && result && addConstraintInvocations == 0) { inFrozenConstraint { for (p <- constraint.domainParams) { diff --git a/compiler/src/dotty/tools/dotc/core/Contexts.scala b/compiler/src/dotty/tools/dotc/core/Contexts.scala index 7ea68b0ae816..7957471d3c13 100644 --- a/compiler/src/dotty/tools/dotc/core/Contexts.scala +++ b/compiler/src/dotty/tools/dotc/core/Contexts.scala @@ -480,7 +480,7 @@ object Contexts { def setTyper(typer: Typer): this.type = { this.scope = typer.scope; setTypeAssigner(typer) } def setImportInfo(importInfo: ImportInfo): this.type = { this.importInfo = importInfo; this } def setGadt(gadt: GADTMap): this.type = { this.gadt = gadt; this } - def setFreshGADTBounds: this.type = setGadt(new GADTMap(gadt.bounds)) + def setFreshGADTBounds: this.type = setGadt(gadt.fresh) def setSearchHistory(searchHistory: SearchHistory): this.type = { this.searchHistory = searchHistory; this } def setTypeComparerFn(tcfn: Context => TypeComparer): this.type = { this.typeComparer = tcfn(this); this } private def setMoreProperties(moreProperties: Map[Key[Any], Any]): this.type = { this.moreProperties = moreProperties; this } @@ -708,14 +708,204 @@ object Contexts { else assert(thread == Thread.currentThread(), "illegal multithreaded access to ContextBase") } - class GADTMap(initBounds: SimpleIdentityMap[Symbol, TypeBounds]) { - private[this] var myBounds = initBounds - def setBounds(sym: Symbol, b: TypeBounds): Unit = - myBounds = myBounds.updated(sym, b) - def bounds: SimpleIdentityMap[Symbol, TypeBounds] = myBounds + sealed abstract class GADTMap { + def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit + def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean + def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds + def contains(sym: Symbol)(implicit ctx: Context): Boolean + def debugBoundsDescription(implicit ctx: Context): String + def fresh: GADTMap } - @sharable object EmptyGADTMap extends GADTMap(SimpleIdentityMap.Empty) { - override def setBounds(sym: Symbol, b: TypeBounds): Unit = unsupported("EmptyGADTMap.setBounds") + final class SmartGADTMap private ( + private[this] var myConstraint: Constraint, + private[this] var mapping: SimpleIdentityMap[Symbol, TypeVar], + private[this] var reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol], + private[this] var boundCache: SimpleIdentityMap[Symbol, TypeBounds] + ) extends GADTMap with ConstraintHandling[Context] { + import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr} + + def this() = this( + myConstraint = new OrderingConstraint(SimpleIdentityMap.Empty, SimpleIdentityMap.Empty, SimpleIdentityMap.Empty), + mapping = SimpleIdentityMap.Empty, + reverseMapping = SimpleIdentityMap.Empty, + boundCache = SimpleIdentityMap.Empty + ) + + implicit override def ctx(implicit ctx: Context): Context = ctx + + override protected def constraint = myConstraint + override protected def constraint_=(c: Constraint) = myConstraint = c + + override def isSubType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSubType(tp1, tp2) + override def isSameType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSameType(tp1, tp2) + + override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = tvar(sym) + + override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = try { + boundCache = SimpleIdentityMap.Empty + boundAdditionInProgress = true + @annotation.tailrec def stripInternalTypeVar(tp: Type): Type = tp match { + case tv: TypeVar => + val inst = instType(tv) + if (inst.exists) stripInternalTypeVar(inst) else tv + case _ => tp + } + + def externalizedSubtype(tp1: Type, tp2: Type, isSubtype: Boolean): Boolean = { + val externalizedTp1 = removeTypeVars(tp1) + val externalizedTp2 = removeTypeVars(tp2) + + ( + if (isSubtype) externalizedTp1 frozen_<:< externalizedTp2 + else externalizedTp2 frozen_<:< externalizedTp1 + ).reporting({ res => + val descr = i"$externalizedTp1 frozen_${if (isSubtype) "<:<" else ">:>"} $externalizedTp2" + i"$descr = $res" + }, gadts) + } + + val symTvar: TypeVar = stripInternalTypeVar(tvar(sym)) match { + case tv: TypeVar => tv + case inst => + val externalizedInst = removeTypeVars(inst) + gadts.println(i"instantiated: $sym -> $externalizedInst") + return if (isUpper) isSubType(externalizedInst , bound) else isSubType(bound, externalizedInst) + } + + val internalizedBound = insertTypeVars(bound) + ( + stripInternalTypeVar(internalizedBound) match { + case boundTvar: TypeVar => + if (boundTvar eq symTvar) true + else if (isUpper) addLess(symTvar.origin, boundTvar.origin) + else addLess(boundTvar.origin, symTvar.origin) + case bound => + if (externalizedSubtype(symTvar, bound, isSubtype = !isUpper)) { + gadts.println(i"manually unifying $symTvar with $bound") + constraint = constraint.updateEntry(symTvar.origin, bound) + true + } + else if (isUpper) addUpperBound(symTvar.origin, bound) + else addLowerBound(symTvar.origin, bound) + } + ).reporting({ res => + val descr = if (isUpper) "upper" else "lower" + val op = if (isUpper) "<:" else ">:" + i"adding $descr bound $sym $op $bound = $res\t( $symTvar $op $internalizedBound )" + }, gadts) + } finally boundAdditionInProgress = false + + override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = { + mapping(sym) match { + case null => null + case tv => + def retrieveBounds: TypeBounds = { + val tb = constraint.fullBounds(tv.origin) + removeTypeVars(tb).asInstanceOf[TypeBounds] + } + ( + if (boundAdditionInProgress || ctx.mode.is(Mode.GADTflexible)) retrieveBounds + else boundCache(sym) match { + case tb: TypeBounds => tb + case null => + val bounds = retrieveBounds + boundCache = boundCache.updated(sym, bounds) + bounds + } + ).reporting({ res => + // i"gadt bounds $sym: $res" + "" + }, gadts) + } + } + + override def contains(sym: Symbol)(implicit ctx: Context): Boolean = mapping(sym) ne null + + override def fresh: GADTMap = new SmartGADTMap( + myConstraint, + mapping, + reverseMapping, + boundCache + ) + + // ---- Private ---------------------------------------------------------- + + private[this] def tvar(sym: Symbol)(implicit ctx: Context): TypeVar = { + mapping(sym) match { + case tv: TypeVar => + tv + case null => + val res = { + import NameKinds.DepParamName + // avoid registering the TypeVar with TyperState / TyperState#constraint + // - we don't want TyperState instantiating these TypeVars + // - we don't want TypeComparer constraining these TypeVars + val poly = PolyType(DepParamName.fresh(sym.name.toTypeName) :: Nil)( + pt => TypeBounds.empty :: Nil, + pt => defn.AnyType) + new TypeVar(poly.paramRefs.head, creatorState = null) + } + gadts.println(i"GADTMap: created tvar $sym -> $res") + constraint = constraint.add(res.origin.binder, res :: Nil) + mapping = mapping.updated(sym, res) + reverseMapping = reverseMapping.updated(res.origin, sym) + res + } + } + + private def insertTypeVars(tp: Type, map: TypeMap = null)(implicit ctx: Context) = tp match { + case tp: TypeRef => + val sym = tp.typeSymbol + if (contains(sym)) tvar(sym) else tp + case _ => + (if (map != null) map else new TypeVarInsertingMap()).mapOver(tp) + } + private final class TypeVarInsertingMap(implicit ctx: Context) extends TypeMap { + override def apply(tp: Type): Type = insertTypeVars(tp, this) + } + + private def removeTypeVars(tp: Type, map: TypeMap = null)(implicit ctx: Context) = tp match { + case tpr: TypeParamRef => + reverseMapping(tpr) match { + case null => tpr + case sym => sym.typeRef + } + case tv: TypeVar => + reverseMapping(tv.origin) match { + case null => tv + case sym => sym.typeRef + } + case _ => + (if (map != null) map else new TypeVarRemovingMap()).mapOver(tp) + } + private final class TypeVarRemovingMap(implicit ctx: Context) extends TypeMap { + override def apply(tp: Type): Type = removeTypeVars(tp, this) + } + + private[this] var boundAdditionInProgress = false + + // ---- Debug ------------------------------------------------------------ + + override def constr_println(msg: => String): Unit = gadtsConstr.println(msg) + + override def debugBoundsDescription(implicit ctx: Context): String = { + val sb = new mutable.StringBuilder + sb ++= constraint.show + sb += '\n' + mapping.foreachBinding { case (sym, _) => + sb ++= i"$sym: ${bounds(sym)}\n" + } + sb.result + } + } + + @sharable object EmptyGADTMap extends GADTMap { + override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = unsupported("EmptyGADTMap.addEmptyBounds") + override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = unsupported("EmptyGADTMap.addBound") + override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = null + override def contains(sym: Symbol)(implicit ctx: Context) = false + override def debugBoundsDescription(implicit ctx: Context): String = "EmptyGADTMap" + override def fresh = new SmartGADTMap } } diff --git a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala index f128f4e2022b..e310c197c132 100644 --- a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala @@ -325,8 +325,8 @@ class OrderingConstraint(private val boundsMap: ParamBounds, private def order(current: This, param1: TypeParamRef, param2: TypeParamRef)(implicit ctx: Context): This = if (param1 == param2 || current.isLess(param1, param2)) this else { - assert(contains(param1)) - assert(contains(param2)) + assert(contains(param1), i"$param1") + assert(contains(param2), i"$param2") val newUpper = param2 :: exclusiveUpper(param2, param1) val newLower = param1 :: exclusiveLower(param1, param2) val current1 = (current /: newLower)(upperLens.map(this, _, _, newUpper ::: _)) diff --git a/compiler/src/dotty/tools/dotc/core/Symbols.scala b/compiler/src/dotty/tools/dotc/core/Symbols.scala index d6b6ed1efd1c..864407e51897 100644 --- a/compiler/src/dotty/tools/dotc/core/Symbols.scala +++ b/compiler/src/dotty/tools/dotc/core/Symbols.scala @@ -214,7 +214,11 @@ trait Symbols { this: Context => */ def newPatternBoundSymbol(name: Name, info: Type, pos: Position): Symbol = { val sym = newSymbol(owner, name, Case, info, coord = pos) - if (name.isTypeName) gadt.setBounds(sym, info.bounds) + if (name.isTypeName) { + val bounds = info.bounds + gadt.addBound(sym, bounds.lo, isUpper = false) + gadt.addBound(sym, bounds.hi, isUpper = true) + } sym } diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 60ad1963ee29..a4471c5ae2cf 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -17,14 +17,20 @@ import scala.util.control.NonFatal import typer.ProtoTypes.constrained import reporting.trace +final class AbsentContext +object AbsentContext { + implicit val absentContext: AbsentContext = new AbsentContext +} + /** Provides methods to compare types. */ -class TypeComparer(initctx: Context) extends ConstraintHandling { +class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] { import TypeComparer._ - implicit val ctx: Context = initctx + implicit def ctx(implicit nc: AbsentContext): Context = initctx - val state: TyperState = ctx.typerState - import state.constraint + val state = ctx.typerState + def constraint: Constraint = state.constraint + def constraint_=(c: Constraint): Unit = state.constraint = c private[this] var pendingSubTypes: mutable.Set[(Type, Type)] = null private[this] var recCount = 0 @@ -105,8 +111,9 @@ class TypeComparer(initctx: Context) extends ConstraintHandling { true } - protected def gadtBounds(sym: Symbol)(implicit ctx: Context): TypeBounds = ctx.gadt.bounds(sym) - protected def gadtSetBounds(sym: Symbol, b: TypeBounds): Unit = ctx.gadt.setBounds(sym, b) + protected def gadtBounds(sym: Symbol)(implicit ctx: Context) = ctx.gadt.bounds(sym) + protected def gadtAddLowerBound(sym: Symbol, b: Type): Boolean = ctx.gadt.addBound(sym, b, isUpper = false) + protected def gadtAddUpperBound(sym: Symbol, b: Type): Boolean = ctx.gadt.addBound(sym, b, isUpper = true) protected def typeVarInstance(tvar: TypeVar)(implicit ctx: Context): Type = tvar.underlying @@ -136,7 +143,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling { finally this.approx = saved } - protected def isSubType(tp1: Type, tp2: Type): Boolean = isSubType(tp1, tp2, NoApprox) + def isSubType(tp1: Type, tp2: Type)(implicit nc: AbsentContext): Boolean = isSubType(tp1, tp2, NoApprox) protected def recur(tp1: Type, tp2: Type): Boolean = trace(s"isSubType ${traceInfo(tp1, tp2)} $approx", subtyping) { @@ -183,7 +190,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling { def firstTry: Boolean = tp2 match { case tp2: NamedType => def compareNamed(tp1: Type, tp2: NamedType): Boolean = { - implicit val ctx = this.ctx + implicit val ctx: Context = this.ctx tp2.info match { case info2: TypeAlias => recur(tp1, info2.alias) case _ => tp1 match { @@ -738,9 +745,28 @@ class TypeComparer(initctx: Context) extends ConstraintHandling { isSubArgs(args1, args2, tp1, tparams) case tycon1: TypeRef => tycon2.dealiasKeepRefiningAnnots match { - case tycon2: TypeRef if tycon1.symbol == tycon2.symbol => + case tycon2: TypeRef => + val tycon1sym = tycon1.symbol + val tycon2sym = tycon2.symbol + + var touchedGADTs = false + def gadtBoundsContain(sym: Symbol, tp: Type): Boolean = { + touchedGADTs = true + val b = gadtBounds(sym) + b != null && inFrozenConstraint { + (b.lo =:= tp) && (b.hi =:= tp) + } + } + + val res = ( + tycon1sym == tycon2sym || + gadtBoundsContain(tycon1sym, tycon2) || + gadtBoundsContain(tycon2sym, tycon1) + ) && isSubType(tycon1.prefix, tycon2.prefix) && isSubArgs(args1, args2, tp1, tparams) + if (res && touchedGADTs) GADTused = true + res case _ => false } @@ -830,9 +856,9 @@ class TypeComparer(initctx: Context) extends ConstraintHandling { * tp1 <:< tp2 using fourthTry (this might instantiate params in tp1) * tp1 <:< app2 using isSubType (this might instantiate params in tp2) */ - def compareLower(tycon2bounds: TypeBounds, followSuperType: Boolean): Boolean = + def compareLower(tycon2bounds: TypeBounds, tyconIsTypeRef: Boolean): Boolean = if ((tycon2bounds.lo `eq` tycon2bounds.hi) && !tycon2bounds.isInstanceOf[MatchAlias]) - if (followSuperType) recur(tp1, tp2.superType) + if (tyconIsTypeRef) recur(tp1, tp2.superType) else isSubApproxHi(tp1, tycon2bounds.lo.applyIfParameterized(args2)) else fallback(tycon2bounds.lo) @@ -841,15 +867,13 @@ class TypeComparer(initctx: Context) extends ConstraintHandling { case param2: TypeParamRef => isMatchingApply(tp1) || canConstrain(param2) && canInstantiate(param2) || - compareLower(bounds(param2), followSuperType = false) + compareLower(bounds(param2), tyconIsTypeRef = false) case tycon2: TypeRef => isMatchingApply(tp1) || defn.isTypelevel_S(tycon2.symbol) && compareS(tp2, tp1, fromBelow = true) || { tycon2.info match { case info2: TypeBounds => - val gbounds2 = ctx.gadt.bounds(tycon2.symbol) - if (gbounds2 == null) compareLower(info2, followSuperType = true) - else compareLower(gbounds2 & info2, followSuperType = false) + compareLower(info2, tyconIsTypeRef = true) case info2: ClassInfo => tycon2.name.toString.startsWith("Tuple") && defn.isTupleType(tp2) && isSubType(tp1, tp2.toNestedPairs) || @@ -885,11 +909,8 @@ class TypeComparer(initctx: Context) extends ConstraintHandling { case tycon1: TypeRef => val sym = tycon1.symbol !sym.isClass && ( - defn.isTypelevel_S(sym) && compareS(tp1, tp2, fromBelow = false) || { - val gbounds1 = ctx.gadt.bounds(tycon1.symbol) - if (gbounds1 == null) recur(tp1.superType, tp2) - else recur((gbounds1.hi & tycon1.info.bounds.hi).applyIfParameterized(args1), tp2) - }) + defn.isTypelevel_S(sym) && compareS(tp1, tp2, fromBelow = false) || + recur(tp1.superType, tp2)) case tycon1: TypeProxy => recur(tp1.superType, tp2) case _ => @@ -1216,14 +1237,8 @@ class TypeComparer(initctx: Context) extends ConstraintHandling { val tparam = tr.symbol gadts.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam)}") if (bound.isRef(tparam)) false - else { - val oldBounds = gadtBounds(tparam) - val newBounds = - if (isUpper) TypeBounds(oldBounds.lo, oldBounds.hi & bound) - else TypeBounds(oldBounds.lo | bound, oldBounds.hi) - isSubType(newBounds.lo, newBounds.hi) && - { gadtSetBounds(tparam, newBounds); true } - } + else if (isUpper) gadtAddUpperBound(tparam, bound) + else gadtAddLowerBound(tparam, bound) } } @@ -1307,7 +1322,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling { // Type equality =:= /** Two types are the same if are mutual subtypes of each other */ - def isSameType(tp1: Type, tp2: Type): Boolean = + def isSameType(tp1: Type, tp2: Type)(implicit nc: AbsentContext): Boolean = if (tp1 eq NoType) false else if (tp1 eq tp2) true else isSubType(tp1, tp2) && isSubType(tp2, tp1) @@ -1771,6 +1786,9 @@ class TypeComparer(initctx: Context) extends ConstraintHandling { totalCount = 0 } } + + /** Returns last check's debug mode, if explicitly enabled. */ + def lastTrace(): String = "" } object TypeComparer { @@ -1803,10 +1821,18 @@ object TypeComparer { val NoApprox: ApproxState = new ApproxState(0) /** Show trace of comparison operations when performing `op` as result string */ - def explained[T](op: Context => T)(implicit ctx: Context): String = { + def explaining[T](say: String => Unit)(op: Context => T)(implicit ctx: Context): T = { val nestedCtx = ctx.fresh.setTypeComparerFn(new ExplainingTypeComparer(_)) - op(nestedCtx) - nestedCtx.typeComparer.toString + val res = op(nestedCtx) + say(nestedCtx.typeComparer.lastTrace()) + res + } + + /** Like [[explaining]], but returns the trace instead */ + def explained[T](op: Context => T)(implicit ctx: Context): String = { + var trace: String = null + explaining(trace = _)(op) + trace } } @@ -1815,12 +1841,12 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) { val footprint: mutable.Set[Type] = mutable.Set[Type]() - override def bounds(param: TypeParamRef): TypeBounds = { + override def bounds(param: TypeParamRef)(implicit nc: AbsentContext): TypeBounds = { if (param.binder `ne` caseLambda) footprint += param super.bounds(param) } - override def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean): Boolean = { + override def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean)(implicit nc: AbsentContext): Boolean = { if (param.binder `ne` caseLambda) footprint += param super.addOneBound(param, bound, isUpper) } @@ -1830,9 +1856,14 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) { super.gadtBounds(sym) } - override def gadtSetBounds(sym: Symbol, b: TypeBounds): Unit = { + override def gadtAddLowerBound(sym: Symbol, b: Type): Boolean = { + footprint += sym.typeRef + super.gadtAddLowerBound(sym, b) + } + + override def gadtAddUpperBound(sym: Symbol, b: Type): Boolean = { footprint += sym.typeRef - super.gadtSetBounds(sym, b) + super.gadtAddUpperBound(sym, b) } override def typeVarInstance(tvar: TypeVar)(implicit ctx: Context): Type = { @@ -1926,12 +1957,12 @@ class ExplainingTypeComparer(initctx: Context) extends TypeComparer(initctx) { super.glb(tp1, tp2) } - override def addConstraint(param: TypeParamRef, bound: Type, fromBelow: Boolean): Boolean = + override def addConstraint(param: TypeParamRef, bound: Type, fromBelow: Boolean)(implicit nc: AbsentContext): Boolean = traceIndented(i"add constraint $param ${if (fromBelow) ">:" else "<:"} $bound $frozenConstraint, constraint = ${ctx.typerState.constraint}") { super.addConstraint(param, bound, fromBelow) } override def copyIn(ctx: Context): ExplainingTypeComparer = new ExplainingTypeComparer(ctx) - override def toString: String = "Subtype trace:" + { try b.toString finally b.clear() } + override def lastTrace(): String = "Subtype trace:" + { try b.toString finally b.clear() } } diff --git a/compiler/src/dotty/tools/dotc/core/TyperState.scala b/compiler/src/dotty/tools/dotc/core/TyperState.scala index 5a2fb0052e0d..83efc1329350 100644 --- a/compiler/src/dotty/tools/dotc/core/TyperState.scala +++ b/compiler/src/dotty/tools/dotc/core/TyperState.scala @@ -79,19 +79,6 @@ class TyperState(previous: TyperState /* | Null */) { def ownedVars: TypeVars = myOwnedVars def ownedVars_=(vs: TypeVars): Unit = myOwnedVars = vs - /** Gives for each instantiated type var that does not yet have its `inst` field - * set, the instance value stored in the constraint. Storing instances in constraints - * is done only in a temporary way for contexts that may be retracted - * without also retracting the type var as a whole. - */ - def instType(tvar: TypeVar)(implicit ctx: Context): Type = constraint.entry(tvar.origin) match { - case _: TypeBounds => NoType - case tp: TypeParamRef => - var tvar1 = constraint.typeVarOfParam(tp) - if (tvar1.exists) tvar1 else tp - case tp => tp - } - /** The closest ancestor of this typer state (including possibly this typer state itself) * which is not yet committed, or which does not have a parent. */ @@ -173,7 +160,7 @@ class TyperState(previous: TyperState /* | Null */) { val toCollect = new mutable.ListBuffer[TypeLambda] constraint foreachTypeVar { tvar => if (!tvar.inst.exists) { - val inst = instType(tvar) + val inst = ctx.typeComparer.instType(tvar) if (inst.exists && (tvar.owningState.get eq this)) { tvar.inst = inst val lam = tvar.origin.binder diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index b101ada545cb..405ffe86a8d4 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -563,7 +563,7 @@ object Types { case _ => go(tp.superType) } - case tp: ThisType => + case tp: ThisType => // ??? inline goThis(tp) case tp: RefinedType => if (name eq tp.refinedName) goRefined(tp) else go(tp.parent) @@ -3597,7 +3597,7 @@ object Types { private[core] def inst: Type = myInst private[core] def inst_=(tp: Type): Unit = { myInst = tp - if (tp.exists) { + if (tp.exists && (owningState ne null)) { owningState.get.ownedVars -= this owningState = null // no longer needed; null out to avoid a memory leak } @@ -3606,13 +3606,14 @@ object Types { /** The state owning the variable. This is at first `creatorState`, but it can * be changed to an enclosing state on a commit. */ - private[core] var owningState: WeakReference[TyperState] = new WeakReference(creatorState) + private[core] var owningState: WeakReference[TyperState] = + if (creatorState == null) null else new WeakReference(creatorState) /** The instance type of this variable, or NoType if the variable is currently * uninstantiated */ def instanceOpt(implicit ctx: Context): Type = - if (inst.exists) inst else ctx.typerState.instType(this) + if (inst.exists) inst else ctx.typeComparer.instType(this) /** Is the variable already instantiated? */ def isInstantiated(implicit ctx: Context): Boolean = instanceOpt.exists @@ -3729,7 +3730,7 @@ object Types { def isBounded(tp: Type) = tp match { case tp: TypeParamRef => - case tp: TypeRef => ctx.gadt.bounds.contains(tp.symbol) + case tp: TypeRef => ctx.gadt.contains(tp.symbol) } def contextInfo(tp: Type): Type = tp match { diff --git a/compiler/src/dotty/tools/dotc/printing/Formatting.scala b/compiler/src/dotty/tools/dotc/printing/Formatting.scala index 1c4d9c3e1283..24c13791524b 100644 --- a/compiler/src/dotty/tools/dotc/printing/Formatting.scala +++ b/compiler/src/dotty/tools/dotc/printing/Formatting.scala @@ -168,7 +168,7 @@ object Formatting { s"is a reference to a value parameter" case sym: Symbol => val info = - if (ctx.gadt.bounds.contains(sym)) + if (ctx.gadt.contains(sym)) sym.info & ctx.gadt.bounds(sym) else sym.info @@ -189,7 +189,7 @@ object Formatting { case param: TermParamRef => false case skolem: SkolemType => true case sym: Symbol => - ctx.gadt.bounds.contains(sym) && ctx.gadt.bounds(sym) != TypeBounds.empty + ctx.gadt.contains(sym) && ctx.gadt.bounds(sym) != TypeBounds.empty case _ => assert(false, "unreachable") false diff --git a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala index befc06112068..c49bf8ef0ad6 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala @@ -11,7 +11,7 @@ import NameKinds.UniqueName import util.Positions._ import util.{Stats, SimpleIdentityMap} import Decorators._ -import config.Printers.typr +import config.Printers.{gadts, typr} import annotation.tailrec import reporting._ import collection.mutable @@ -206,7 +206,9 @@ object Inferencing { } val widePt = if (ctx.scala2Mode || refinementIsInvariant(tp)) pt else widenVariantParams(pt) - tp <:< widePt + trace(i"constraining pattern type $tp <:< $widePt", gadts, res => s"$res\n${ctx.gadt.debugBoundsDescription}") { + tp <:< widePt + } } /** The list of uninstantiated type variables bound by some prefix of type `T` which diff --git a/compiler/src/dotty/tools/dotc/typer/Inliner.scala b/compiler/src/dotty/tools/dotc/typer/Inliner.scala index 5b1c9fe47ec4..0b3b29b48fca 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inliner.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inliner.scala @@ -730,7 +730,11 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { for (tpt <- tpts) boundVars = getBoundVars(boundVars, tpt) case _ => } - for (bv <- boundVars) ctx.gadt.setBounds(bv, bv.info.bounds) + for (bv <- boundVars) { + val TypeBounds(lo, hi) = bv.info.bounds + ctx.gadt.addBound(bv, lo, isUpper = false) + ctx.gadt.addBound(bv, hi, isUpper = true) + } scrut <:< tpt.tpe && { for (bv <- boundVars) { bv.info = TypeAlias(ctx.gadt.bounds(bv).lo) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index dca8c7137111..262f1ac03917 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1023,8 +1023,7 @@ class Typer extends Namer def gadtContext(gadtSyms: Set[Symbol])(implicit ctx: Context): Context = { val gadtCtx = ctx.fresh.setFreshGADTBounds for (sym <- gadtSyms) - if (!gadtCtx.gadt.bounds.contains(sym)) - gadtCtx.gadt.setBounds(sym, TypeBounds.empty) + if (!gadtCtx.gadt.contains(sym)) gadtCtx.gadt.addEmptyBounds(sym) gadtCtx } @@ -1502,8 +1501,11 @@ class Typer extends Namer // that their type parameters are aliases of the class type parameters. // See pos/i941.scala rhsCtx = ctx.fresh.setFreshGADTBounds - (tparams1, sym.owner.typeParams).zipped.foreach ((tdef, tparam) => - rhsCtx.gadt.setBounds(tdef.symbol, TypeAlias(tparam.typeRef))) + (tparams1, sym.owner.typeParams).zipped.foreach { (tdef, tparam) => + val tr = tparam.typeRef + rhsCtx.gadt.addBound(tdef.symbol, tr, isUpper = false) + rhsCtx.gadt.addBound(tdef.symbol, tr, isUpper = true) + } } if (sym.isInlineMethod) rhsCtx = rhsCtx.addMode(Mode.InlineableBody) val rhs1 = typedExpr(ddef.rhs, tpt1.tpe)(rhsCtx) diff --git a/tests/neg/gadt-banal-nested.scala b/tests/neg/gadt-banal-nested.scala new file mode 100644 index 000000000000..3a098f69739a --- /dev/null +++ b/tests/neg/gadt-banal-nested.scala @@ -0,0 +1,12 @@ +object banal { + sealed trait T[A] + final case class StrLit(v: String) extends T[String] + final case class IntLit(v: Int) extends T[Int] + + def eval[A](t: T[A]): A = t match { + case _: T[a] => t match { + case StrLit(v) => v + case IntLit(_) => "" // error + } + } +} diff --git a/tests/neg/gadt-banal.scala b/tests/neg/gadt-banal.scala new file mode 100644 index 000000000000..a822c827d917 --- /dev/null +++ b/tests/neg/gadt-banal.scala @@ -0,0 +1,17 @@ +object banal { + final case class Box[A](a: A) + + sealed trait T[A] + final case class StrLit(v: String) extends T[String] + final case class IntLit(v: Int) extends T[Int] + + def evul[A](t: T[A]): A = t match { + case StrLit(v) => (v: Any) // error + case IntLit(v) => (??? : Nothing) + } + + def noeval[A](t: T[A]): Box[A] = t match { + case StrLit(v) => Box[Any](v) // error + case IntLit(v) => Box[Nothing](???) // error + } +} diff --git a/tests/neg/gadt-i4075.scala b/tests/neg/gadt-i4075.scala new file mode 100644 index 000000000000..7037fbc7d5d3 --- /dev/null +++ b/tests/neg/gadt-i4075.scala @@ -0,0 +1,8 @@ +object i4075 { + case class One[T](fst: T) + def bad[T](e: One[T]) = e match { + case foo: One[a] => + val nok: Nothing = foo.fst // error + ??? + } +} diff --git a/tests/neg/gadt-injectivity.scala b/tests/neg/gadt-injectivity.scala new file mode 100644 index 000000000000..d4f8373e102a --- /dev/null +++ b/tests/neg/gadt-injectivity.scala @@ -0,0 +1,17 @@ +object injectivity { + sealed trait EQ[A, B] + final case class Refl[A](u: Unit) extends EQ[A, A] + + def conform[A, B, C, D](a: A, b: B, eq: EQ[(A, B), (C, D)]): C = + eq match { + case _: Refl[a] => + val ab: (A, B) = (a, b) + val cd: (C, D) = ab + val bb: (B, B) = ab // error + val cc: (C, C) = cd // error + val dd: (D, D) = cd // error + val rab: a = ab + val rcd: a = cd + a + } +} diff --git a/tests/neg/gadt-uninjectivity.scala b/tests/neg/gadt-uninjectivity.scala new file mode 100644 index 000000000000..f1d0fc59000a --- /dev/null +++ b/tests/neg/gadt-uninjectivity.scala @@ -0,0 +1,24 @@ +object uninjectivity { + sealed trait EQ[A, B] + final case class Refl[T]() extends EQ[T, T] + + def absurd1[F[_], X, Y](eq: EQ[F[X], F[Y]], x: X): Y = eq match { + case Refl() => + x // should be an error + } + + def absurd2[F[_], G[_]](eq: EQ[F[Int], G[Int]], fi: F[Int], fs: F[String]): G[Int] = eq match { + case Refl() => + val gs: G[String] = fs // error + // fi + ??? + } + + def absurd3[F[_], G[_], X, Y](eq: EQ[F[X], G[Y]], fx: F[X]): G[Y] = eq match { + case Refl() => + val gx: G[X] = fx // error + val fy: F[Y] = fx // error + // fx + ??? + } +} diff --git a/tests/pos/gadt-EQK.scala b/tests/pos/gadt-EQK.scala new file mode 100644 index 000000000000..b713de2d833b --- /dev/null +++ b/tests/pos/gadt-EQK.scala @@ -0,0 +1,33 @@ +object EQK { + sealed trait EQ[A, B] + final case class Refl[A]() extends EQ[A, A] + + sealed trait EQK[F[_], G[_]] + final case class ReflK[F[_]]() extends EQK[F, F] + + def m0[F[_], G[_], A](fa: F[A], eqk: EQK[F, G]): G[A] = + eqk match { case ReflK() => fa } + + def m1[F[_], G[_], A](fa: F[A], eq: EQ[A, Int], eqk: EQK[F, G]): G[Int] = + eqk match { + case ReflK() => eq match { + case Refl() => + val r1: F[Int] = fa + val r2: G[A] = fa + val r3: F[Int] = r2 + fa : G[Int] + } + } + + def m2[F[_], G[_], A](fa: F[A], a: A, eq: EQ[F[A], G[Int]], eqk: EQK[F, G]): Int = + eqk match { + case ReflK() => eq match { + case Refl() => + val r1: F[Int] = fa + val r2: G[A] = fa + val r3: F[Int] = r2 + a + } + } + +} diff --git a/tests/pos/gadt-GadtStlc.scala b/tests/pos/gadt-GadtStlc.scala new file mode 100644 index 000000000000..32e6f5630d6b --- /dev/null +++ b/tests/pos/gadt-GadtStlc.scala @@ -0,0 +1,127 @@ +object GadtStlc { + // creates type-level "strings" like M[M[M[W]]] + object W + type W = W.type + class M[A] + + + // variable with name A + // Var[W] + // Var[M[W]] + sealed trait Var[A] + object VarW extends Var[W] + case class VarM[A] extends Var[M[A]] + + // \s.e + sealed trait Abs[S, E] + case class AbsC[S, E](v: Var[S], b: E) extends Abs[Var[S], E] + + // e1 e2 + case class App[E1, E2](e1: E1, e2: E2) + + // T1 -> T2 + case class TyFun[T1, T2](t1: T1, t2: T2) + + // arbitrary base literal + case object Lit + type Lit = Lit.type + + // arbitrary base type + case object TyBase + type TyBase = TyBase.type + + // IN[G, (X, TY)] === evidence that binding (X, TY) is in environment G + // x: ty \in G + sealed trait IN[G, P] + case class INOne[G, X, TY]() extends IN[(G, (X,TY)), (X, TY)] + // this is wrong - we need evidence that A does not contain a binding for X + case class INShift[G0, A, X, TY](in: IN[G0, (X, TY)]) extends IN[(G0, A), (X, TY)] + + // DER[G, E, TY] === evidence that G |- E : TY + sealed trait DER[G, E, TY] + case class DVar[G, G0, X, TY]( + in: IN[(G, G0), (Var[X], TY)] + ) extends DER[(G, G0), Var[X], TY] + + case class DApp[G, E1, E2, TY1, TY2]( + d1: DER[G, E1, TyFun[TY1, TY2]], + d2: DER[G, E2, TY1] + ) extends DER[G, App[E1, E2], TY2] + + case class DAbs[G, X, E, TY1, TY2]( + d1: DER[(G, (Var[X], TY1)), E, TY2] + ) extends DER[G, Abs[Var[X], E], TyFun[TY1, TY2]] + + case class DLit[G]() extends DER[G, Lit, TyBase] + + // forall G, a. G |- \x.x : a -> a + def test1[G, TY]: DER[G, Abs[Var[W], Var[W]], TyFun[TY, TY]] = + DAbs(DVar(INOne())) + + // forall G. G |- \x.x : Lit -> Lit + def test2[G]: DER[G, App[Abs[Var[W], Var[W]], Lit], TyBase] = + DApp(DAbs(DVar(INOne())), DLit()) + + // forall G, c. G |- \x.\y. x y : (c -> c) -> c -> c + def test3[G, TY]: DER[G, + Abs[Var[W], + Abs[Var[M[W]], + App[Var[W], Var[M[W]]] + ] + ], + TyFun[TyFun[TY, TY], TyFun[TY, TY]] + ] = DAbs(DAbs(DApp(DVar(INShift(INOne())), DVar(INOne())))) + + + // evidence that E is a value + sealed trait ISVAL[E] + case class ISVAL_Abs[X, E]() extends ISVAL[Abs[Var[X], E]] + case object ISVAL_Lit extends ISVAL[Lit] + + // evidence that E1 reduces to E2 + sealed trait REDUDER[E1, E2] + case class EApp1[E1a, E1b, E2]( + ed: REDUDER[E1a, E1b] + ) extends REDUDER[App[E1a, E2], App[E1b, E2]] + + case class EApp2[V1, E2a, E2b]( + isval: ISVAL[V1], + ed: REDUDER[E2a, E2b] + ) extends REDUDER[App[V1, E2a], App[V1, E2b]] + + case class EAppAbs[X, E, V2, R]( + isval: ISVAL[V2] + // cheating - subst is hard + // , subst: SUBST[E, X, V2, R] + ) extends REDUDER[App[Abs[Var[X], E], V2], R] + + // evidence that V is a lambda + sealed trait ISLAMBDA[V] + case class ISLAMBDAC[X, E]() extends ISLAMBDA[Abs[Var[X], E]] + + // evidence that E reduces + type REDUCES[E] = REDUDER[E, _] + + def followsIsLambda[G, V, TY1, TY2]( + isval: ISVAL[V], + der: DER[G, V, TyFun[TY1, TY2]] + ): ISLAMBDA[V] = (isval, der) match { + case (_: ISVAL_Abs[x, e], _) => ISLAMBDAC[x, e]() + } + + // \empty |- E : TY ==> E is a value /\ E reduces to some E1 + def progress[E, TY](der: DER[Unit, E, TY]): Either[ISVAL[E], REDUCES[E]] = + der match { + case _: DAbs[g, a, e, ty1, ty2] => Left(ISVAL_Abs[a, e]()) + case DLit() => Left(ISVAL_Lit) + case dapp: DApp[Unit, a, b, ty1, ty2] => progress(dapp.d1) match { + case Right(r1) => Right(EApp1[E2 = b](r1)) + case Left(isv1) => progress(dapp.d2) match { + case Right(r2) => Right(EApp2(isv1, r2)) + case Left(isv2) => followsIsLambda(isv1, dapp.d1) match { + case _: ISLAMBDAC[x, e] => Right(EAppAbs[x, e, b, _](isv2)) + } + } + } + } +} diff --git a/tests/pos/gadt-TailCalls.scala b/tests/pos/gadt-TailCalls.scala new file mode 100644 index 000000000000..8899216068ae --- /dev/null +++ b/tests/pos/gadt-TailCalls.scala @@ -0,0 +1,39 @@ +// package scala.util.control +object TailCalls { + + abstract class TailRec[+A] { + + final def flatMap[B](f: A => TailRec[B]): TailRec[B] = + this match { + case Done(a) => Call(() => f(a)) + case c@Call(_) => Cont(c, f) + case c: Cont[a1, b1] => Cont(c.a, (x: a1) => c.f(x) flatMap f) + } + + @annotation.tailrec final def resume: Either[() => TailRec[A], A] = this match { + case Done(a) => Right(a) + case Call(k) => Left(k) + case Cont(a, f) => a match { + case Done(v) => f(v).resume + case Call(k) => Left(() => k().flatMap(f)) + case Cont(b, g) => b.flatMap(x => g(x) flatMap f).resume + } + } + + @annotation.tailrec final def result: A = this match { + case Done(a) => a + case Call(t) => t().result + case Cont(a, f) => a match { + case Done(v) => f(v).result + case Call(t) => t().flatMap(f).result + case Cont(b, g) => b.flatMap(x => g(x) flatMap f).result + } + } + } + + protected case class Call[A](rest: () => TailRec[A]) extends TailRec[A] + + protected case class Done[A](value: A) extends TailRec[A] + + protected case class Cont[A, B](a: TailRec[A], f: A => TailRec[B]) extends TailRec[B] +} diff --git a/tests/pos/gadt-TypeSafeLambda.scala b/tests/pos/gadt-TypeSafeLambda.scala new file mode 100644 index 000000000000..b9cbbc50757e --- /dev/null +++ b/tests/pos/gadt-TypeSafeLambda.scala @@ -0,0 +1,109 @@ +object TypeSafeLambda { + + trait Category[Arr[_, _]] { + def id[A]: Arr[A, A] + def comp[A, B, C](ab: Arr[A, B], bc: Arr[B, C]): Arr[A, C] + } + + trait Terminal[Term, Arr[_, _]] extends Category[Arr] { + def terminal[A]: Arr[A, Term] + } + + trait ProductCategory[Prod[_, _], Arr[_, _]] extends Category[Arr] { + def first[A, B]: Arr[Prod[A, B], A] + def second[A, B]: Arr[Prod[A, B], B] + def pair[A, B, C](ab: Arr[A, B], ac: Arr[A, C]): Arr[A, Prod[B, C]] + } + + trait Exponential[Exp[_, _], Prod[_, _], Arr[_, _]] + extends ProductCategory[Prod, Arr] { + def eval[A, B]: Arr[Prod[Exp[A, B], A], B] + def curry[A, B, C](a: Arr[Prod[C, A], B]): Arr[C, Exp[A, B]] + } + + trait CartesianClosed[Term, Exp[_, _], Prod[_, _], Arr[_, _]] + extends Exponential[Exp, Prod, Arr] with Terminal[Term, Arr] + + sealed trait V[Prod[_, _], Env, R] + case class Zero[Prod[_, _], Env, R]() extends V[Prod, Prod[Env, R], R] + case class Succ[Prod[_, _], Env, R, X]( + v: V[Prod, Env, R] + ) extends V[Prod, Prod[Env, X], R] + + sealed trait Lambda[Terminal, Exp[_, _], Prod[_, _], Env, R] + + case class LUnit[Terminal, Exp[_, _], Prod[_, _], Env]() + extends Lambda[Terminal, Exp, Prod, Env, Terminal] + + case class Var[Terminal, Exp[_, _], Prod[_, _], Env, R]( + v: V[Prod, Env, R] + ) extends Lambda[Terminal, Exp, Prod, Env, R] + + case class Lam[Terminal, Exp[_, _], Prod[_, _], Env, R, A]( + lam: Lambda[Terminal, Exp, Prod, Prod[Env, A], R] + ) extends Lambda[Terminal, Exp, Prod, Env, Exp[A, R]] + + case class App[Terminal, Exp[_, _], Prod[_, _], Env, R, R_]( + f: Lambda[Terminal, Exp, Prod, Env, Exp[R, R_]], + a: Lambda[Terminal, Exp, Prod, Env, R] + ) extends Lambda[Terminal, Exp, Prod, Env, R_] + + def interp[Term, Exp[_, _], Prod[_, _], Arr[_, _], S, T]( + c: CartesianClosed[Term, Exp, Prod, Arr], + exp: Lambda[Term, Exp, Prod, S, T] + ): Arr[S, T] = exp match { + case LUnit() => c.terminal + case v: Var[t, e, var_p, env, r] => v.v match { + case _: Zero[z_p, z_env, z_r] => c.second[z_env, z_r] + case s: Succ[s_prod, s_env, s_r, x] => + c.comp[A = s_prod[s_env, x], C = s_r]( + c.first, + interp(c, Var[Term, Exp, Prod, s_env, s_r](s.v)) + ) + } + case Lam(lam) => c.curry(interp(c, lam)) + case app: App[t, e, p, env, r, r_] => + c.comp( + c.pair( + interp(c, app.f), + interp(c, app.a)), + c.eval[r, r_] + ) + } + + object example { + type Term = Unit + type Prod[A, B] = (A, B) + type Exp[A, B] = A => B + type Arr[A, B] = A => B + + val c = new CartesianClosed[Term, Exp, Prod, Arr] { + def id[A]: A => A = a => a + def comp[A, B, C](f: A => B, g: B => C): A => C = f andThen g + + def terminal[A]: A => Unit = a => () + + def first[A, B]: ((A, B)) => A = { case (a, _) => a } + def second[A, B]: ((A, B)) => B = { case (_, b) => b } + def pair[A, B, C](f: A => B, g: A => C): A => (B, C) = + a => (f(a), g(a)) + + def eval[A, B]: ((A => B, A)) => B = { case (f, a) => f(a) } + def curry[A, B, C](f: ((C, A)) => B): C => A => B = + c => a => f((c, a)) + } + + type Env = Unit Prod Int Prod (Int => String) + val exp = App[Term, Exp, Prod, Env, Int, String]( + // args to Var are RHS "indices" into Env + Var(Zero()), + Var(Succ(Zero())) + ) + + val interped: (Env) => String = + interp[Term, Exp, Prod, Arr, Env, String] (c, exp) + + interped((((), 1), { i: Int => i.toString })) : String // "1" + } + +} diff --git a/tests/pos/gadt-banal.scala b/tests/pos/gadt-banal.scala new file mode 100644 index 000000000000..a0efb73412fa --- /dev/null +++ b/tests/pos/gadt-banal.scala @@ -0,0 +1,15 @@ +object banal { + sealed trait T[A] + final case class StrLit(v: String) extends T[String] + final case class IntLit(v: Int) extends T[Int] + + def eval[A](t: T[A]): A = t match { + case StrLit(v) => v + case IntLit(v) => v + } + + def evul[A](t: T[A]): A = t match { + case StrLit(_) => "" + case IntLit(_) => 0 + } +} diff --git a/tests/pos/gadt-complexEQ.scala b/tests/pos/gadt-complexEQ.scala new file mode 100644 index 000000000000..991c838ad610 --- /dev/null +++ b/tests/pos/gadt-complexEQ.scala @@ -0,0 +1,24 @@ +object complexEQ { + sealed trait EQ[A, B] + final case class Refl[A]() extends EQ[A, A] + + def m[A, B, C, D](e1: EQ[A, (B, C)], e2: EQ[A, (C, D)], d: D): A = + e1 match { + case Refl() => e2 match { + case Refl() => + val r1: (B, B) = (d, d) + val r2: (C, C) = r1 + val r3: (D, D) = r1 + r1 + } + } + + def m2[Z, A, B, C, D](e0: EQ[Z, A], e1: EQ[A, (B, C)], e2: EQ[Z, (C, D)], d: D): Z = + (e0, e1, e2) match { + case (Refl(), Refl(), Refl()) => + val r1: (B, B) = (d, d) + val r2: (C, C) = r1 + val r3: (D, D) = r1 + r1 + } +} diff --git a/tests/pos/gadt-foo.scala b/tests/pos/gadt-foo.scala new file mode 100644 index 000000000000..4966ce2cf3a5 --- /dev/null +++ b/tests/pos/gadt-foo.scala @@ -0,0 +1,11 @@ +object foo { + sealed trait Exp[T] + case class Var[T](name: String) extends Exp[T] + + def env[T](x: Var[T]): T = ??? + + def eval[S](e: Exp[S]) = e match { + case v: Var[foo] => + env(v) + } +} diff --git a/tests/pos/gadt-simpleEQ.scala b/tests/pos/gadt-simpleEQ.scala new file mode 100644 index 000000000000..03b5c0b27e6b --- /dev/null +++ b/tests/pos/gadt-simpleEQ.scala @@ -0,0 +1,13 @@ +object simpleEQ { + sealed trait EQ[A, B] + final case class Refl[A](u: Unit) extends EQ[A, A] + + def conform[A, B](a: A, eq: EQ[A, B]): B = eq match { + case Refl(()) => a + } + + def conform2[A, B, C, D](a: A, b: B, eq: EQ[(A, B), (C, D)]): (C, D) = + eq match { + case Refl(()) => (a, b) + } +} diff --git a/tests/pos/i3666-gadt.scala b/tests/pos/i3666-gadt.scala new file mode 100644 index 000000000000..16614e9e949e --- /dev/null +++ b/tests/pos/i3666-gadt.scala @@ -0,0 +1,65 @@ +object i3666 { + sealed trait Exp[T] + case class Num(n: Int) extends Exp[Int] + case class Plus(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int] + case class Var[T](name: String) extends Exp[T] + case class Lambda[T, U](x: Var[T], e: Exp[U]) extends Exp[T => U] + case class App[T, U](f: Exp[T => U], e: Exp[T]) extends Exp[U] + + abstract class Env { outer => + def apply[T](x: Var[T]): T + + def + [T](xe: (Var[T], T)) = new Env { + def apply[T](x: Var[T]): T = + if (x == xe._1) xe._2.asInstanceOf[T] + else outer(x) + } + } + + object Env { + val empty = new Env { + def apply[T](x: Var[T]): T = ??? + } + } + + object Test { + + val exp = App(Lambda(Var[Int]("x"), Plus(Var[Int]("x"), Num(1))), Var[Int]("2")) + + def eval[T](e: Exp[T])(env: Env): T = e match { + case Num(n) => n + case Plus(e1, e2) => eval(e1)(env) + eval(e2)(env) + case v: Var[_] => env(v) + case Lambda(x: Var[s], e) => ((y: s) => eval(e)(env + (x -> y))) + case App(f, e) => eval(f)(env)(eval(e)(env)) + } + + eval(exp)(Env.empty) + } +} +// A HOAS well-typed interpreter +object i3666Hoas { + sealed trait Exp[T] + case class IntLit(n: Int) extends Exp[Int] + case class BooleanLit(b: Boolean) extends Exp[Boolean] + + case class GenLit[T](t: T) extends Exp[T] + case class Plus(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int] + case class Fun[S, T](f: Exp[S] => Exp[T]) extends Exp[S => T] + case class App[T, U](f: Exp[T => U], e: Exp[T]) extends Exp[U] + + + def eval[T](e: Exp[T]): T = e match { + case IntLit(n) => n + case BooleanLit(b) => b + case GenLit(t) => t + case Plus(e1, e2) => eval(e1) + eval(e2) + case f: Fun[s, t] => + (v: s) => eval(f.f(GenLit(v))) + case App(f, e) => eval(f)(eval(e)) + } + + val exp = App(Fun[S = Int](x => Plus(x, IntLit(1))), IntLit(2)) + + eval(exp) +} diff --git a/tests/pos/i4075-gadt.scala b/tests/pos/i4075-gadt.scala new file mode 100644 index 000000000000..4c49cee2d95b --- /dev/null +++ b/tests/pos/i4075-gadt.scala @@ -0,0 +1,12 @@ +object i4075 { + case class One[T](fst: T) + def bad[T](e: One[T]) = e match { + case foo: One[a] => + val t: T = e.fst + // val nok: Nothing = t // should not compile + val ok: a = t // does compile + One(ok) + } + + val one: One[Int] = bad(One(0)) +} diff --git a/tests/pos/i4176-gadt.scala b/tests/pos/i4176-gadt.scala new file mode 100644 index 000000000000..5c7efa6993c9 --- /dev/null +++ b/tests/pos/i4176-gadt.scala @@ -0,0 +1,36 @@ +object i4176 { + sealed trait TNat + case class TZero() extends TNat + case class TSucc[N <: TNat] extends TNat + + object TNatSum { + sealed trait TSum[M, N, R] + case class TSumZero[N]() extends TSum[TZero, N, N] + case class TSumM[M <: TNat, N, R <: TNat](sum: TSum[M, N, R]) extends TSum[TSucc[M], N, TSucc[R]] + } + import TNatSum._ + + implicit def tSumZero[N]: TSum[TZero, N, N] = + TSumZero() + implicit def tSumM[M <: TNat, N, R <: TNat](implicit sum: TSum[M, N, R]): TSum[TSucc[M], N, TSucc[R]] = + TSumM(sum) + + sealed trait Vec[T, N <: TNat] + case object VNil extends Vec[Nothing, TZero] // fails but in refchecks + case class VCons[T, N <: TNat](x: T, xs: Vec[T, N]) extends Vec[T, TSucc[N]] + + def append0[T, M <: TNat, N <: TNat, R <: TNat]($this: Vec[T, M], that: Vec[T, N])(implicit tsum: TSum[M, N, R]): Vec[T, R] = + ($this, tsum) match { + case (VNil, TSumZero()) => that + case (VCons(x, xs), TSumM(sum)) => VCons(x, append0(xs, that)(sum)) + } + + def append[T, M <: TNat, N <: TNat, R <: TNat]($this: Vec[T, M], that: Vec[T, N])(implicit tsum: TSum[M, N, R]): Vec[T, R] = + tsum match { + case TSumZero() => + $this match { case VNil => that } + case TSumM(sum) => + $this match { case VCons(x, xs) => VCons(x, append(xs, that)(sum)) } + } + +} diff --git a/tests/pos/i4471-gadt.scala b/tests/pos/i4471-gadt.scala new file mode 100644 index 000000000000..ff57a6a30891 --- /dev/null +++ b/tests/pos/i4471-gadt.scala @@ -0,0 +1,31 @@ +object i4471 { + sealed trait Shuffle[A1, A2] { + def andThen[A3](that: Shuffle[A2, A3]): Shuffle[A1, A3] = AndThen(this, that) + } + + case class Id[A]() extends Shuffle[A, A] + case class Swap[A, B]() extends Shuffle[(A, B), (B, A)] + case class AssocLR[A, B, C]() extends Shuffle[((A, B), C), (A, (B, C))] + case class AssocRL[A, B, C]() extends Shuffle[(A, (B, C)), ((A, B), C)] + case class Par[A1, B1, A2, B2](_1: Shuffle[A1, B1], _2: Shuffle[A2, B2]) extends Shuffle[(A1, A2), (B1, B2)] + case class AndThen[A1, A2, A3](_1: Shuffle[A1, A2], _2: Shuffle[A2, A3]) extends Shuffle[A1, A3] + + def rewrite3[A1, A2, A3, A4]( + op1: Shuffle[A1, A2], + op2: Shuffle[A2, A3], + op3: Shuffle[A3, A4] + ): Option[Shuffle[A1, A4]] = (op1, op2, op3) match { + case ( + _: Swap[x, y], + _: AssocRL[u, v, w], + op3_ : Par[p1, q1, p2, q2] + ) => op3_ match { + case Par(_: Swap[r, s], _: Id[p2_]) => + Some( + AssocLR[v, w, u]() andThen Par(Id[v](), Swap[w, u]()) andThen AssocRL[v, u, w]() + ) + case _ => None + } + case _ => None + } +} diff --git a/tests/pos/i5068-gadt.scala b/tests/pos/i5068-gadt.scala new file mode 100644 index 000000000000..b4fba06f3c53 --- /dev/null +++ b/tests/pos/i5068-gadt.scala @@ -0,0 +1,24 @@ +object i5068 { + case class Box[F[_]](value: F[Int]) + sealed trait IsK[F[_], G[_]] + final case class ReflK[F[_]]() extends IsK[F, F] + + def foo[F[_], G[_]](r: F IsK G, a: Box[F]): Box[G] = r match { case ReflK() => a } +} + +object i5068b { + type WeirdShape[A[_], B] = A[B] + // type WeirderShape[S[_[_], _], I, M] = Any + case class Box[ F[_[_[_], _], _, _[_]] ](value: F[WeirdShape, Int, Option]) + sealed trait IsK[F[_[_[_], _], _, _[_]], G[_[_[_], _], _, _[_]]] + final case class ReflK[ F[_[_[_], _], _, _[_]] ]() extends IsK[F, F] + + def foo[F[_[_[_], _], _, _[_]], G[_[_[_], _], _, _[_]]]( + r: F IsK G, + a: Box[F] + ): Box[G] = r match { case ReflK() => a } + + // def main(args: Array[String]): Unit = { + // println(foo(ReflK(), Box[WeirderShape](???))) + // } +} diff --git a/tests/pos/injectivity-gadt.scala b/tests/pos/injectivity-gadt.scala new file mode 100644 index 000000000000..525db2ce798c --- /dev/null +++ b/tests/pos/injectivity-gadt.scala @@ -0,0 +1,14 @@ +object injectivity { + sealed trait EQ[A, B] + final case class Refl[A]() extends EQ[A, A] + + def conform[A, B, C, D](a: A, b: B, eq: EQ[(A, B), (C, D)]): C = + eq match { + case _: Refl[a] => + val ab: (A, B) = (a, b) + val cd: (C, D) = ab + val rab: a = ab + val rcd: a = cd + a + } +} diff --git a/tests/run/gadt-injectivity-unsoundness.scala b/tests/run/gadt-injectivity-unsoundness.scala new file mode 100644 index 000000000000..192a82afb539 --- /dev/null +++ b/tests/run/gadt-injectivity-unsoundness.scala @@ -0,0 +1,19 @@ +object Test { + sealed trait EQ[A, B] + final case class Refl[T]() extends EQ[T, T] + + def absurd[F[_], X, Y](eq: EQ[F[X], F[Y]], x: X): Y = eq match { + case Refl() => x + } + + var ex: Exception = _ + try { + type Unsoundness[X] = Int + val s: String = absurd[Unsoundness, Int, String](Refl(), 0) + } catch { + case e: ClassCastException => ex = e + } + + def main(args: Array[String]) = + assert(ex != null) +}