Skip to content

Commit

Permalink
Fix lampepfl#7044: Added GADT recovery with fallback to default error.
Browse files Browse the repository at this point in the history
  • Loading branch information
radeusgd committed Apr 22, 2020
1 parent e06eb26 commit 8d8a362
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 66 deletions.
27 changes: 8 additions & 19 deletions compiler/src/dotty/tools/dotc/typer/Inferencing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import collection.mutable
import scala.annotation.internal.sharable
import scala.annotation.threadUnsafe

import config.Printers.debug
import config.Printers.gadts

object Inferencing {

Expand Down Expand Up @@ -165,14 +165,16 @@ object Inferencing {
}

def approximateGADT(tp: Type)(implicit ctx: Context): Type = {
val map = new IsFullyDefinedAccumulator2
val map = new ApproximateGadtAccumulator
val res = map(tp)
assert(!map.failed)
debug.println(i"approximateGADT( $tp ) = $res // {${tp.toString}}")
res
}

private class IsFullyDefinedAccumulator2(implicit ctx: Context) extends TypeMap {
/** This class is mostly based on IsFullyDefinedAccumulator.
* It tries to approximate the given type based on the available GADT constraints.
*/
private class ApproximateGadtAccumulator(implicit ctx: Context) extends TypeMap {

var failed = false

Expand Down Expand Up @@ -201,7 +203,7 @@ object Inferencing {
val res =
ctx.gadt.approximation(sym, fromBelow = variance < 0)

debug.println(i"approximated $tp ~~ $res")
gadts.println(i"approximated $tp ~~ $res")

res

Expand All @@ -213,21 +215,8 @@ object Inferencing {
mapOver(tp)
}

// private class UpperInstantiator(implicit ctx: Context) extends TypeAccumulator[Unit] {
// def apply(x: Unit, tp: Type): Unit = {
// tp match {
// case tvar: TypeVar if !tvar.isInstantiated =>
// instantiate(tvar, fromBelow = false)
// case _ =>
// }
// foldOver(x, tp)
// }
// }

def process(tp: Type): Type = {
val res = apply(tp)
// if (res && toMaximize) new UpperInstantiator().apply((), tp)
res
apply(tp)
}
}

Expand Down
66 changes: 32 additions & 34 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ import transform.TypeUtils._
import reporting.trace
import Nullables.{NotNullInfo, given _}
import NullOpsDecorator._
import config.Printers.debug
import config.Printers.gadts

object Typer {

Expand Down Expand Up @@ -2764,7 +2764,7 @@ class Typer extends Namer
*/
def adapt(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean = true)(using Context): Tree = {
val last = Thread.currentThread.getStackTrace()(2).toString;
trace/*.force*/(i"adapting (tryGadtHealing=$tryGadtHealing) $tree to $pt\n{callsite: $last}", typr, show = true) {
trace(i"adapting (tryGadtHealing=$tryGadtHealing) $tree to $pt\n{callsite: $last}", typr, show = true) {
record("adapt")
adapt1(tree, pt, locked, tryGadtHealing)
}
Expand All @@ -2774,7 +2774,6 @@ class Typer extends Namer
adapt(tree, pt, ctx.typerState.ownedVars)

private def adapt1(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean)(using Context): Tree = {
// assert(pt.exists && !pt.isInstanceOf[ExprType])
assert(pt.exists && !pt.isInstanceOf[ExprType] || ctx.reporter.errorsReported)
def methodStr = err.refStr(methPart(tree).tpe)

Expand Down Expand Up @@ -3224,19 +3223,15 @@ class Typer extends Namer
}

def adaptToSubType(wtp: Type): Tree = {
debug.println("adaptToSubType")
debug.println("// try converting a constant to the target type")
// try converting a constant to the target type
val folded = ConstFold(tree, pt)
if (folded ne tree)
return adaptConstant(folded, folded.tpe.asInstanceOf[ConstantType])

debug.println("// Try to capture wildcards in type")
val captured = captureWildcards(wtp)
if (captured `ne` wtp)
return readapt(tree.cast(captured))

debug.println("// drop type if prototype is Unit")
// drop type if prototype is Unit
if (pt isRef defn.UnitClass) {
// local adaptation makes sure every adapted tree conforms to its pt
Expand All @@ -3246,7 +3241,6 @@ class Typer extends Namer
return tpd.Block(tree1 :: Nil, Literal(Constant(())))
}

debug.println("// convert function literal to SAM closure")
// convert function literal to SAM closure
tree match {
case closure(Nil, id @ Ident(nme.ANON_FUN), _)
Expand All @@ -3264,28 +3258,28 @@ class Typer extends Namer
case _ =>
}

debug.println("// try GADT approximation")
val foo = Inferencing.approximateGADT(wtp)
debug.println(
i"""
foo = $foo
val approximation = Inferencing.approximateGADT(wtp)
gadts.println(
i"""GADT approximation {
approximation = $approximation
pt.isInstanceOf[SelectionProto] = ${pt.isInstanceOf[SelectionProto]}
ctx.gadt.nonEmpty = ${ctx.gadt.nonEmpty}
ctx.gadt = ${ctx.gadt.debugBoundsDescription}
pt.isMatchedBy = ${
if (pt.isInstanceOf[SelectionProto])
pt.asInstanceOf[SelectionProto].isMatchedBy(foo).toString
pt.asInstanceOf[SelectionProto].isMatchedBy(approximation).toString
else
"<not a SelectionProto>"
}
}
"""
)
pt match {
case pt: SelectionProto if ctx.gadt.nonEmpty && pt.isMatchedBy(foo) =>
return tpd.Typed(tree, TypeTree(foo))
case pt: SelectionProto if ctx.gadt.nonEmpty && pt.isMatchedBy(approximation) =>
return tpd.Typed(tree, TypeTree(approximation))
case _ => ;
}

debug.println("// try an extension method in scope")
// try an extension method in scope
pt match {
case SelectionProto(name, mbrType, _, _) =>
Expand All @@ -3303,32 +3297,36 @@ class Typer extends Namer
val app = tryExtension(using nestedCtx)
if (!app.isEmpty && !nestedCtx.reporter.hasErrors) {
nestedCtx.typerState.commit()
debug.println("returning ext meth in scope")
return ExtMethodApply(app)
}
case _ =>
}

debug.println("// try an implicit conversion")
// try an implicit conversion
val prevConstraint = ctx.typerState.constraint
def recover(failure: SearchFailureType) =
{
debug.println("recover")
if (isFullyDefined(wtp, force = ForceDegree.all) &&
ctx.typerState.constraint.ne(prevConstraint)) readapt(tree)
// else if ({
// debug.println(i"tryGadtHealing=$tryGadtHealing && \n\tctx.gadt.nonEmpty=${ctx.gadt.nonEmpty}")
// tryGadtHealing && ctx.gadt.nonEmpty
// })
// {
// debug.println("here")
// readapt(
// tree = tpd.Typed(tree, TypeTree(Inferencing.approximateGADT(wtp))),
// shouldTryGadtHealing = false,
// )
// }
else err.typeMismatch(tree, pt, failure)
if (isFullyDefined(wtp, force = ForceDegree.all) &&
ctx.typerState.constraint.ne(prevConstraint)) readapt(tree)
else if (tryGadtHealing && ctx.gadt.nonEmpty)
{
// try recovering with a GADT approximation
val nestedCtx = ctx.fresh.setNewTyperState()
val res =
readapt(
tree = tpd.Typed(tree, TypeTree(Inferencing.approximateGADT(wtp))),
shouldTryGadtHealing = false,
)(using nestedCtx)
if (!nestedCtx.reporter.hasErrors) {
// GADT recovery successful
nestedCtx.typerState.commit()
res
} else {
// otherwise fail with the error that would have been reported without the GADT recovery
err.typeMismatch(tree, pt, failure)
}
}
else err.typeMismatch(tree, pt, failure)
}
if ctx.mode.is(Mode.ImplicitsEnabled) && tree.typeOpt.isValueType then
if pt.isRef(defn.AnyValClass) || pt.isRef(defn.ObjectClass) then
Expand Down
13 changes: 0 additions & 13 deletions tests/neg/boundspropagation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,6 @@ object test3 {
}
}

// Example contributed by Jason.
object test4 {
class Base {
type N

class Tree[-S, -T >: Option[S]]

def g(x: Any): Tree[_, _ <: Option[N]] = x match {
case y: Tree[_, _] => y // error -- used to work (because of capture conversion?)
}
}
}

class Test5 {
"": ({ type U = this.type })#U // error
}
13 changes: 13 additions & 0 deletions tests/pos/boundspropagation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,16 @@ object test2 {
}
}
*/

// Example contributed by Jason.
object test2 {
class Base {
type N

class Tree[-S, -T >: Option[S]]

def g(x: Any): Tree[_, _ <: Option[N]] = x match {
case y: Tree[_, _] => y
}
}
}
15 changes: 15 additions & 0 deletions tests/pos/gadt-infer-ascription.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// test based on an example code by @Blaisorblade
object GadtAscription {
enum Var[G, A] {
case Z[G, A]() extends Var[(A, G), A]
case S[G, A, B](x: Var[G, A]) extends Var[(B, G), A]
}

import Var._
def evalVar[G, A](x: Var[G, A])(rho: G): A = x match {
case _: Z[g, a] =>
rho(0)
case s: S[g, a, b] =>
evalVar(s.x)(rho(1))
}
}
14 changes: 14 additions & 0 deletions tests/pos/i7044.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
object i7044 {
case class Seg[T](pat:Pat[T], body:T)

trait Pat[T]
object Pat {
case class Expr() extends Pat[Int]
case class Opt[S](el:Pat[S]) extends Pat[Option[S]]
}

def test[T](s:Seg[T]):Int = s match {
case Seg(Pat.Expr(),body) => body + 1
case Seg(Pat.Opt(Pat.Expr()),body) => body.get
}
}

0 comments on commit 8d8a362

Please sign in to comment.