Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gadt unification #5611

Merged
merged 8 commits into from
Dec 17, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/config/Printers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ object Printers {
val dottydoc: Printer = noPrinter
val exhaustivity: Printer = noPrinter
val gadts: Printer = noPrinter
val gadtsConstr: Printer = noPrinter
val hk: Printer = noPrinter
val implicits: Printer = noPrinter
val implicitsDetailed: Printer = noPrinter
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ trait ConstraintHandling {
def constr_println(msg: => String): Unit = constr.println(msg)
odersky marked this conversation as resolved.
Show resolved Hide resolved
def typr_println(msg: => String): Unit = typr.println(msg)

implicit val ctx: Context
implicit def ctx: Context

protected def isSubType(tp1: Type, tp2: Type): Boolean
protected def isSameType(tp1: Type, tp2: Type): Boolean

val state: TyperState
import state.constraint
protected def constraint: Constraint
protected def constraint_=(c: Constraint): Unit

private[this] var addConstraintInvocations = 0

Expand Down
203 changes: 195 additions & 8 deletions compiler/src/dotty/tools/dotc/core/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ object Contexts {
def setTyper(typer: Typer): this.type = { this.scope = typer.scope; setTypeAssigner(typer) }
def setImportInfo(importInfo: ImportInfo): this.type = { this.importInfo = importInfo; this }
def setGadt(gadt: GADTMap): this.type = { this.gadt = gadt; this }
def setFreshGADTBounds: this.type = setGadt(new GADTMap(gadt.bounds))
def setFreshGADTBounds: this.type = setGadt(gadt.fresh)
def setSearchHistory(searchHistory: SearchHistory): this.type = { this.searchHistory = searchHistory; this }
def setTypeComparerFn(tcfn: Context => TypeComparer): this.type = { this.typeComparer = tcfn(this); this }
private def setMoreProperties(moreProperties: Map[Key[Any], Any]): this.type = { this.moreProperties = moreProperties; this }
Expand Down Expand Up @@ -708,14 +708,201 @@ object Contexts {
else assert(thread == Thread.currentThread(), "illegal multithreaded access to ContextBase")
}

class GADTMap(initBounds: SimpleIdentityMap[Symbol, TypeBounds]) {
private[this] var myBounds = initBounds
def setBounds(sym: Symbol, b: TypeBounds): Unit =
myBounds = myBounds.updated(sym, b)
def bounds: SimpleIdentityMap[Symbol, TypeBounds] = myBounds
sealed abstract class GADTMap {
def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit
def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean
def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds
def contains(sym: Symbol)(implicit ctx: Context): Boolean
def debugBoundsDescription(implicit ctx: Context): String
def fresh: GADTMap
}

@sharable object EmptyGADTMap extends GADTMap(SimpleIdentityMap.Empty) {
override def setBounds(sym: Symbol, b: TypeBounds): Unit = unsupported("EmptyGADTMap.setBounds")
final class SmartGADTMap private (
private[this] var myConstraint: Constraint,
private[this] var mapping: SimpleIdentityMap[Symbol, TypeVar],
private[this] var reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol]
) extends GADTMap with ConstraintHandling {
import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr}

def this() = this(
myConstraint = new OrderingConstraint(SimpleIdentityMap.Empty, SimpleIdentityMap.Empty, SimpleIdentityMap.Empty),
mapping = SimpleIdentityMap.Empty,
reverseMapping = SimpleIdentityMap.Empty
)

// TODO: clean up this dirty kludge
private[this] var myCtx: Context = null
abgruszecki marked this conversation as resolved.
Show resolved Hide resolved
implicit override def ctx = myCtx
@forceInline private[this] final def inCtx[T](_ctx: Context)(op: => T) = {
val savedCtx = myCtx
myCtx = _ctx
try op finally myCtx = savedCtx
}

override protected def constraint = myConstraint
override protected def constraint_=(c: Constraint) = myConstraint = c

override def isSubType(tp1: Type, tp2: Type): Boolean = ctx.typeComparer.isSubType(tp1, tp2)
override def isSameType(tp1: Type, tp2: Type): Boolean = ctx.typeComparer.isSameType(tp1, tp2)


override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = tvar(sym)

override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = inCtx(ctx) {
@annotation.tailrec def stripInst(tp: Type): Type = tp match {
abgruszecki marked this conversation as resolved.
Show resolved Hide resolved
case tv: TypeVar =>
val inst = instType(tv)
if (inst.exists) stripInst(inst) else tv
case _ => tp
}

abgruszecki marked this conversation as resolved.
Show resolved Hide resolved
def cautiousSubtype(tp1: Type, tp2: Type, isSubtype: Boolean): Boolean = {
val externalizedTp1 = removeTypeVars(tp1)
val externalizedTp2 = removeTypeVars(tp2)

def descr = {
def op = s"frozen_${if (isSubtype) "<:<" else ">:>"}"
i"$tp1 $op $tp2\n\t$externalizedTp1 $op $externalizedTp2"
}
// gadts.println(descr)

val res =
// TypeComparer.explain[Boolean](gadts.println) { implicit ctx =>
if (isSubtype) externalizedTp1 frozen_<:< externalizedTp2
else externalizedTp2 frozen_<:< externalizedTp1
abgruszecki marked this conversation as resolved.
Show resolved Hide resolved
// }

gadts.println(i"$descr = $res")
res
}

def unify(tv: TypeVar, tp: Type): Unit = {
gadts.println(i"manually unifying $tv with $tp")
constraint = constraint.updateEntry(tv.origin, tp)
}

val symTvar: TypeVar = stripInst(tvar(sym)) match {
case tv: TypeVar => tv
case inst =>
gadts.println(i"instantiated: $sym -> $inst")
// this is wrong in general, but "correct" due to a subtype check in TypeComparer#narrowGadtBounds
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More explanations needed.

Copy link
Contributor Author

@abgruszecki abgruszecki Dec 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed by inlining the code from TypeComparer#narrowGADTBounds. Inlining it did not work previously, but whatever code was causing that must've changed along the way.

return true
}

val internalizedBound = insertTypeVars(bound)
val res = stripInst(internalizedBound) match {
case boundTvar: TypeVar =>
if (boundTvar eq symTvar) true
else if (isUpper) addLess(symTvar.origin, boundTvar.origin)
else addLess(boundTvar.origin, symTvar.origin)
case bound =>
if (cautiousSubtype(symTvar, bound, isSubtype = !isUpper)) { unify(symTvar, bound); true }
else if (isUpper) addUpperBound(symTvar.origin, bound)
else addLowerBound(symTvar.origin, bound)
}

gadts.println {
val descr = if (isUpper) "upper" else "lower"
val op = if (isUpper) "<:" else ">:"
i"adding $descr bound $sym $op $bound = $res\t( $symTvar $op $internalizedBound )"
}
res
}

override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = inCtx(ctx) {
mapping(sym) match {
case null => null
case tv =>
val tb = constraint.fullBounds(tv.origin)
val res = removeTypeVars(tb).asInstanceOf[TypeBounds]
abgruszecki marked this conversation as resolved.
Show resolved Hide resolved
// gadts.println(i"gadt bounds $sym: $res")
res
}
}

override def contains(sym: Symbol)(implicit ctx: Context): Boolean = mapping(sym) ne null

override def fresh: GADTMap = new SmartGADTMap(
myConstraint,
mapping,
reverseMapping
)

// ---- Private ----------------------------------------------------------

private[this] def tvar(sym: Symbol)(implicit ctx: Context): TypeVar = {
mapping(sym) match {
case tv: TypeVar =>
tv
case null =>
val res = {
import NameKinds.DepParamName
// avoid registering the TypeVar with TyperState / TyperState#constraint
// - we don't want TyperState instantiating these TypeVars
// - we don't want TypeComparer constraining these TypeVars
val poly = PolyType(DepParamName.fresh(sym.name.toTypeName) :: Nil)(
pt => TypeBounds.empty :: Nil,
pt => defn.AnyType)
new TypeVar(poly.paramRefs.head, creatorState = null)
}
gadts.println(i"GADTMap: created tvar $sym -> $res")
constraint = constraint.add(res.origin.binder, res :: Nil)
mapping = mapping.updated(sym, res)
reverseMapping = reverseMapping.updated(res.origin, sym)
res
}
}

private def insertTypeVars(tp: Type, map: TypeMap = null)(implicit ctx: Context) = tp match {
case tp: TypeRef =>
val sym = tp.typeSymbol
if (contains(sym)) tvar(sym) else tp
case _ =>
(if (map != null) map else new TypeVarInsertingMap()).mapOver(tp)
}
private final class TypeVarInsertingMap(implicit ctx: Context) extends TypeMap {
override def apply(tp: Type): Type = insertTypeVars(tp, this)
}

private def removeTypeVars(tp: Type, map: TypeMap = null)(implicit ctx: Context) = tp match {
case tpr: TypeParamRef =>
reverseMapping(tpr) match {
case null => tpr
case sym => sym.typeRef
}
case tv: TypeVar =>
reverseMapping(tv.origin) match {
case null => tv
case sym => sym.typeRef
}
case _ =>
(if (map != null) map else new TypeVarRemovingMap()).mapOver(tp)
}
private final class TypeVarRemovingMap(implicit ctx: Context) extends TypeMap {
override def apply(tp: Type): Type = removeTypeVars(tp, this)
}

// ---- Debug ------------------------------------------------------------

override def constr_println(msg: => String): Unit = gadtsConstr.println(msg)

override def debugBoundsDescription(implicit ctx: Context): String = {
val sb = new mutable.StringBuilder
sb ++= constraint.show
sb += '\n'
mapping.foreachBinding { case (sym, _) =>
sb ++= i"$sym: ${bounds(sym)}\n"
}
sb.result
}
}

@sharable object EmptyGADTMap extends GADTMap {
override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = unsupported("EmptyGADTMap.addEmptyBounds")
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = unsupported("EmptyGADTMap.addBound")
override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = null
override def contains(sym: Symbol)(implicit ctx: Context) = false
override def debugBoundsDescription(implicit ctx: Context): String = "EmptyGADTMap"
override def fresh = new SmartGADTMap
}
}
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,8 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
private def order(current: This, param1: TypeParamRef, param2: TypeParamRef)(implicit ctx: Context): This =
if (param1 == param2 || current.isLess(param1, param2)) this
else {
assert(contains(param1))
assert(contains(param2))
assert(contains(param1), i"$param1")
assert(contains(param2), i"$param2")
val newUpper = param2 :: exclusiveUpper(param2, param1)
val newLower = param1 :: exclusiveLower(param1, param2)
val current1 = (current /: newLower)(upperLens.map(this, _, _, newUpper ::: _))
Expand Down
6 changes: 5 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Symbols.scala
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,11 @@ trait Symbols { this: Context =>
*/
def newPatternBoundSymbol(name: Name, info: Type, pos: Position): Symbol = {
val sym = newSymbol(owner, name, Case, info, coord = pos)
if (name.isTypeName) gadt.setBounds(sym, info.bounds)
if (name.isTypeName) {
val bounds = info.bounds
gadt.addBound(sym, bounds.lo, isUpper = false)
gadt.addBound(sym, bounds.hi, isUpper = true)
}
sym
}

Expand Down
63 changes: 51 additions & 12 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ class TypeComparer(initctx: Context) extends ConstraintHandling {
import TypeComparer._
implicit val ctx: Context = initctx

val state: TyperState = ctx.typerState
import state.constraint
val state = ctx.typerState
def constraint: Constraint = state.constraint
def constraint_=(c: Constraint): Unit = state.constraint = c

private[this] var pendingSubTypes: mutable.Set[(Type, Type)] = null
private[this] var recCount = 0
Expand Down Expand Up @@ -105,8 +106,9 @@ class TypeComparer(initctx: Context) extends ConstraintHandling {
true
}

protected def gadtBounds(sym: Symbol)(implicit ctx: Context): TypeBounds = ctx.gadt.bounds(sym)
protected def gadtSetBounds(sym: Symbol, b: TypeBounds): Unit = ctx.gadt.setBounds(sym, b)
protected def gadtBounds(sym: Symbol)(implicit ctx: Context) = ctx.gadt.bounds(sym)
protected def gadtAddLowerBound(sym: Symbol, b: Type): Boolean = ctx.gadt.addBound(sym, b, isUpper = false)
protected def gadtAddUpperBound(sym: Symbol, b: Type): Boolean = ctx.gadt.addBound(sym, b, isUpper = true)

protected def typeVarInstance(tvar: TypeVar)(implicit ctx: Context): Type = tvar.underlying

Expand Down Expand Up @@ -136,7 +138,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling {
finally this.approx = saved
}

protected def isSubType(tp1: Type, tp2: Type): Boolean = isSubType(tp1, tp2, NoApprox)
def isSubType(tp1: Type, tp2: Type): Boolean = isSubType(tp1, tp2, NoApprox)

protected def recur(tp1: Type, tp2: Type): Boolean = trace(s"isSubType ${traceInfo(tp1, tp2)} $approx", subtyping) {

Expand Down Expand Up @@ -738,9 +740,28 @@ class TypeComparer(initctx: Context) extends ConstraintHandling {
isSubArgs(args1, args2, tp1, tparams)
case tycon1: TypeRef =>
tycon2.dealiasKeepRefiningAnnots match {
case tycon2: TypeRef if tycon1.symbol == tycon2.symbol =>
case tycon2: TypeRef =>
val tycon1sym = tycon1.symbol
val tycon2sym = tycon2.symbol

var touchedGADTs = false
def gadtBoundsContain(sym: Symbol, tp: Type): Boolean = {
touchedGADTs = true
val b = gadtBounds(sym)
b != null && inFrozenConstraint {
(b.lo =:= tp) && (b.hi =:= tp)
}
}

val res = (
tycon1sym == tycon2sym ||
gadtBoundsContain(tycon1sym, tycon2) ||
gadtBoundsContain(tycon2sym, tycon1)
) &&
isSubType(tycon1.prefix, tycon2.prefix) &&
isSubArgs(args1, args2, tp1, tparams)
if (res && touchedGADTs) GADTused = true
res
case _ =>
false
}
Expand Down Expand Up @@ -1217,7 +1238,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling {
if (isUpper) TypeBounds(oldBounds.lo, oldBounds.hi & bound)
else TypeBounds(oldBounds.lo | bound, oldBounds.hi)
isSubType(newBounds.lo, newBounds.hi) &&
{ gadtSetBounds(tparam, newBounds); true }
(if (isUpper) gadtAddUpperBound(tparam, bound) else gadtAddLowerBound(tparam, bound))
}
}
}
Expand Down Expand Up @@ -1766,6 +1787,9 @@ class TypeComparer(initctx: Context) extends ConstraintHandling {
totalCount = 0
}
}

/** Returns last check's debug mode, if explicitly enabled. */
def lastTrace(): String = ""
}

object TypeComparer {
Expand Down Expand Up @@ -1797,11 +1821,21 @@ object TypeComparer {

val NoApprox: ApproxState = new ApproxState(0)

def explain[T](say: String => Unit)(op: Context => T)(implicit ctx: Context): T = {
val (res, explanation) = underlyingExplained(op)
odersky marked this conversation as resolved.
Show resolved Hide resolved
say(explanation)
res
}

/** Show trace of comparison operations when performing `op` as result string */
def explained[T](op: Context => T)(implicit ctx: Context): String = {
underlyingExplained(op)._2
}

private def underlyingExplained[T](op: Context => T)(implicit ctx: Context): (T, String) = {
val nestedCtx = ctx.fresh.setTypeComparerFn(new ExplainingTypeComparer(_))
op(nestedCtx)
nestedCtx.typeComparer.toString
val res = op(nestedCtx)
abgruszecki marked this conversation as resolved.
Show resolved Hide resolved
(res, nestedCtx.typeComparer.lastTrace())
}
}

Expand All @@ -1825,9 +1859,14 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
super.gadtBounds(sym)
}

override def gadtSetBounds(sym: Symbol, b: TypeBounds): Unit = {
override def gadtAddLowerBound(sym: Symbol, b: Type): Boolean = {
footprint += sym.typeRef
super.gadtAddLowerBound(sym, b)
}

override def gadtAddUpperBound(sym: Symbol, b: Type): Boolean = {
footprint += sym.typeRef
super.gadtSetBounds(sym, b)
super.gadtAddUpperBound(sym, b)
}

override def typeVarInstance(tvar: TypeVar)(implicit ctx: Context): Type = {
Expand Down Expand Up @@ -1928,5 +1967,5 @@ class ExplainingTypeComparer(initctx: Context) extends TypeComparer(initctx) {

override def copyIn(ctx: Context): ExplainingTypeComparer = new ExplainingTypeComparer(ctx)

override def toString: String = "Subtype trace:" + { try b.toString finally b.clear() }
override def lastTrace(): String = "Subtype trace:" + { try b.toString finally b.clear() }
}
Loading