From 9941c2ad9d86705e26404517b29566ca799d62cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rados=C5=82aw=20Wa=C5=9Bko?= Date: Wed, 8 Apr 2020 18:07:08 +0200 Subject: [PATCH] Fix lampepfl#7044: Added GADT recovery with fallback to default error. --- .../dotty/tools/dotc/typer/Inferencing.scala | 29 +++----- .../dotty/tools/dotc/typer/ProtoTypes.scala | 10 ++- .../src/dotty/tools/dotc/typer/Typer.scala | 71 ++++++++++--------- tests/neg/boundspropagation.scala | 13 ---- tests/pos/boundspropagation.scala | 13 ++++ tests/pos/gadt-infer-ascription.scala | 15 ++++ tests/pos/i7044.scala | 14 ++++ 7 files changed, 95 insertions(+), 70 deletions(-) create mode 100644 tests/pos/gadt-infer-ascription.scala create mode 100644 tests/pos/i7044.scala diff --git a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala index f731eb1d4892..c6b241458ee1 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala @@ -19,7 +19,7 @@ import collection.mutable import scala.annotation.internal.sharable import scala.annotation.threadUnsafe -import config.Printers.debug +import config.Printers.gadts object Inferencing { @@ -165,14 +165,16 @@ object Inferencing { } def approximateGADT(tp: Type)(implicit ctx: Context): Type = { - val map = new IsFullyDefinedAccumulator2 + val map = new ApproximateGadtAccumulator val res = map(tp) assert(!map.failed) - debug.println(i"approximateGADT( $tp ) = $res // {${tp.toString}}") res } - private class IsFullyDefinedAccumulator2(implicit ctx: Context) extends TypeMap { + /** This class is mostly based on IsFullyDefinedAccumulator. + * It tries to approximate the given type based on the available GADT constraints. + */ + private class ApproximateGadtAccumulator(implicit ctx: Context) extends TypeMap { var failed = false @@ -200,9 +202,7 @@ object Inferencing { val sym = tp.symbol val res = ctx.gadt.approximation(sym, fromBelow = variance < 0) - - debug.println(i"approximated $tp ~~ $res") - + gadts.println(i"approximated $tp ~~ $res") res case _: WildcardType | _: ProtoType => @@ -213,21 +213,8 @@ object Inferencing { mapOver(tp) } - // private class UpperInstantiator(implicit ctx: Context) extends TypeAccumulator[Unit] { - // def apply(x: Unit, tp: Type): Unit = { - // tp match { - // case tvar: TypeVar if !tvar.isInstantiated => - // instantiate(tvar, fromBelow = false) - // case _ => - // } - // foldOver(x, tp) - // } - // } - def process(tp: Type): Type = { - val res = apply(tp) - // if (res && toMaximize) new UpperInstantiator().apply((), tp) - res + apply(tp) } } diff --git a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala index b62e1bc7f0e7..19f15caa7223 100644 --- a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala +++ b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala @@ -13,6 +13,7 @@ import Decorators._ import Uniques._ import config.Printers.typr import util.SourceFile +import util.Property import scala.annotation.internal.sharable @@ -684,7 +685,14 @@ object ProtoTypes { /** Dummy tree to be used as an argument of a FunProto or ViewProto type */ object dummyTreeOfType { - def apply(tp: Type)(implicit src: SourceFile): Tree = untpd.Literal(Constant(null)) withTypeUnchecked tp + /* + * A property indicating that the given tree was created with dummyTreeOfType. + * It is sometimes necessary to detect the dummy trees to avoid unwanted readaptations on them. + */ + val IsDummyTree = new Property.Key[Unit] + + def apply(tp: Type)(implicit src: SourceFile): Tree = + (untpd.Literal(Constant(null)) withTypeUnchecked tp).withAttachment(IsDummyTree, ()) def unapply(tree: untpd.Tree): Option[Type] = tree match { case Literal(Constant(null)) => Some(tree.typeOpt) case _ => None diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 5d665ac4d904..b3a9ec84a1f8 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -44,7 +44,6 @@ import transform.TypeUtils._ import reporting.trace import Nullables.{NotNullInfo, given _} import NullOpsDecorator._ -import config.Printers.debug object Typer { @@ -2774,7 +2773,7 @@ class Typer extends Namer */ def adapt(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean = true)(using Context): Tree = { val last = Thread.currentThread.getStackTrace()(2).toString; - trace/*.force*/(i"adapting (tryGadtHealing=$tryGadtHealing) $tree to $pt\n{callsite: $last}", typr, show = true) { + trace(i"adapting (tryGadtHealing=$tryGadtHealing) $tree to $pt\n{callsite: $last}", typr, show = true) { record("adapt") adapt1(tree, pt, locked, tryGadtHealing) } @@ -2784,7 +2783,6 @@ class Typer extends Namer adapt(tree, pt, ctx.typerState.ownedVars) private def adapt1(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean)(using Context): Tree = { - // assert(pt.exists && !pt.isInstanceOf[ExprType]) assert(pt.exists && !pt.isInstanceOf[ExprType] || ctx.reporter.errorsReported) def methodStr = err.refStr(methPart(tree).tpe) @@ -3243,19 +3241,15 @@ class Typer extends Namer } def adaptToSubType(wtp: Type): Tree = { - debug.println("adaptToSubType") - debug.println("// try converting a constant to the target type") // try converting a constant to the target type val folded = ConstFold(tree, pt) if (folded ne tree) return adaptConstant(folded, folded.tpe.asInstanceOf[ConstantType]) - debug.println("// Try to capture wildcards in type") val captured = captureWildcards(wtp) if (captured `ne` wtp) return readapt(tree.cast(captured)) - debug.println("// drop type if prototype is Unit") // drop type if prototype is Unit if (pt isRef defn.UnitClass) { // local adaptation makes sure every adapted tree conforms to its pt @@ -3265,7 +3259,6 @@ class Typer extends Namer return tpd.Block(tree1 :: Nil, Literal(Constant(()))) } - debug.println("// convert function literal to SAM closure") // convert function literal to SAM closure tree match { case closure(Nil, id @ Ident(nme.ANON_FUN), _) @@ -3283,28 +3276,28 @@ class Typer extends Namer case _ => } - debug.println("// try GADT approximation") - val foo = Inferencing.approximateGADT(wtp) - debug.println( - i""" - foo = $foo + val approximation = Inferencing.approximateGADT(wtp) + gadts.println( + i"""GADT approximation { + approximation = $approximation pt.isInstanceOf[SelectionProto] = ${pt.isInstanceOf[SelectionProto]} ctx.gadt.nonEmpty = ${ctx.gadt.nonEmpty} + ctx.gadt = ${ctx.gadt.debugBoundsDescription} pt.isMatchedBy = ${ if (pt.isInstanceOf[SelectionProto]) - pt.asInstanceOf[SelectionProto].isMatchedBy(foo).toString + pt.asInstanceOf[SelectionProto].isMatchedBy(approximation).toString else "" } + } """ ) pt match { - case pt: SelectionProto if ctx.gadt.nonEmpty && pt.isMatchedBy(foo) => - return tpd.Typed(tree, TypeTree(foo)) + case pt: SelectionProto if ctx.gadt.nonEmpty && pt.isMatchedBy(approximation) => + return tpd.Typed(tree, TypeTree(approximation)) case _ => ; } - debug.println("// try an extension method in scope") // try an extension method in scope pt match { case SelectionProto(name, mbrType, _, _) => @@ -3322,33 +3315,41 @@ class Typer extends Namer val app = tryExtension(using nestedCtx) if (!app.isEmpty && !nestedCtx.reporter.hasErrors) { nestedCtx.typerState.commit() - debug.println("returning ext meth in scope") return ExtMethodApply(app) } case _ => } - debug.println("// try an implicit conversion") // try an implicit conversion val prevConstraint = ctx.typerState.constraint - def recover(failure: SearchFailureType) = - { - debug.println("recover") + def recover(failure: SearchFailureType) = { + def canTryGADTHealing: Boolean = { + val isDummy = tree.hasAttachment(dummyTreeOfType.IsDummyTree) + tryGadtHealing // allow GADT healing only once to avoid a loop + && ctx.gadt.nonEmpty // GADT healing only makes sense if there are GADT constraints present + && !isDummy // avoid healing a dummy tree as it can lead to an error in a very specific case + } + if (isFullyDefined(wtp, force = ForceDegree.all) && ctx.typerState.constraint.ne(prevConstraint)) readapt(tree) - // else if ({ - // debug.println(i"tryGadtHealing=$tryGadtHealing && \n\tctx.gadt.nonEmpty=${ctx.gadt.nonEmpty}") - // tryGadtHealing && ctx.gadt.nonEmpty - // }) - // { - // debug.println("here") - // readapt( - // tree = tpd.Typed(tree, TypeTree(Inferencing.approximateGADT(wtp))), - // shouldTryGadtHealing = false, - // ) - // } - else err.typeMismatch(tree, pt, failure) - } + else if (canTryGADTHealing) { + // try recovering with a GADT approximation + val nestedCtx = ctx.fresh.setNewTyperState() + val res = + readapt( + tree = tpd.Typed(tree, TypeTree(Inferencing.approximateGADT(wtp))), + shouldTryGadtHealing = false, + )(using nestedCtx) + if (!nestedCtx.reporter.hasErrors) { + // GADT recovery successful + nestedCtx.typerState.commit() + res + } else { + // otherwise fail with the error that would have been reported without the GADT recovery + err.typeMismatch(tree, pt, failure) + } + } else err.typeMismatch(tree, pt, failure) + } if ctx.mode.is(Mode.ImplicitsEnabled) && tree.typeOpt.isValueType then if pt.isRef(defn.AnyValClass) || pt.isRef(defn.ObjectClass) then ctx.error(em"the result of an implicit conversion must be more specific than $pt", tree.sourcePos) diff --git a/tests/neg/boundspropagation.scala b/tests/neg/boundspropagation.scala index 63a0d1c359ba..8adf73d33453 100644 --- a/tests/neg/boundspropagation.scala +++ b/tests/neg/boundspropagation.scala @@ -25,19 +25,6 @@ object test3 { } } -// Example contributed by Jason. -object test4 { - class Base { - type N - - class Tree[-S, -T >: Option[S]] - - def g(x: Any): Tree[_, _ <: Option[N]] = x match { - case y: Tree[_, _] => y // error -- used to work (because of capture conversion?) - } - } -} - class Test5 { "": ({ type U = this.type })#U // error } diff --git a/tests/pos/boundspropagation.scala b/tests/pos/boundspropagation.scala index 78366c3a1196..e3a42711128c 100644 --- a/tests/pos/boundspropagation.scala +++ b/tests/pos/boundspropagation.scala @@ -29,3 +29,16 @@ object test2 { } } */ + +// Example contributed by Jason. +object test2 { + class Base { + type N + + class Tree[-S, -T >: Option[S]] + + def g(x: Any): Tree[_, _ <: Option[N]] = x match { + case y: Tree[_, _] => y + } + } +} diff --git a/tests/pos/gadt-infer-ascription.scala b/tests/pos/gadt-infer-ascription.scala new file mode 100644 index 000000000000..773b1ee17eca --- /dev/null +++ b/tests/pos/gadt-infer-ascription.scala @@ -0,0 +1,15 @@ +// test based on an example code by @Blaisorblade +object GadtAscription { + enum Var[G, A] { + case Z[G, A]() extends Var[(A, G), A] + case S[G, A, B](x: Var[G, A]) extends Var[(B, G), A] + } + + import Var._ + def evalVar[G, A](x: Var[G, A])(rho: G): A = x match { + case _: Z[g, a] => + rho(0) + case s: S[g, a, b] => + evalVar(s.x)(rho(1)) + } +} diff --git a/tests/pos/i7044.scala b/tests/pos/i7044.scala new file mode 100644 index 000000000000..a18d87244643 --- /dev/null +++ b/tests/pos/i7044.scala @@ -0,0 +1,14 @@ +object i7044 { + case class Seg[T](pat:Pat[T], body:T) + + trait Pat[T] + object Pat { + case class Expr() extends Pat[Int] + case class Opt[S](el:Pat[S]) extends Pat[Option[S]] + } + + def test[T](s:Seg[T]):Int = s match { + case Seg(Pat.Expr(),body) => body + 1 + case Seg(Pat.Opt(Pat.Expr()),body) => body.get + } +}