From 96a6a405d657c49ced8709ce8fff1fd0cec4de1b Mon Sep 17 00:00:00 2001 From: Aleksander Boruch-Gruszecki Date: Wed, 22 May 2019 18:09:15 +0200 Subject: [PATCH] Trust case class unapplies to produce checkable type tests Note that in the case where the unapply cannot match, exhaustivity checker will issue a warning. --- .../tools/dotc/transform/PatternMatcher.scala | 27 ++++++++++++------- .../tools/dotc/transform/TypeTestsCasts.scala | 3 ++- .../suppressed-type-test-warnings.scala | 27 +++++++++++++++++++ 3 files changed, 47 insertions(+), 10 deletions(-) create mode 100644 tests/neg-custom-args/fatal-warnings/suppressed-type-test-warnings.scala diff --git a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala index 2236c29462f2..204550f73657 100644 --- a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala +++ b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala @@ -16,6 +16,7 @@ import NameKinds.{PatMatStdBinderName, PatMatAltsName, PatMatResultName} import config.Printers.patmatch import reporting.diagnostic.messages._ import dotty.tools.dotc.ast._ +import util.Property.Key /** The pattern matching transform. * After this phase, the only Match nodes remaining in the code are simple switches @@ -52,6 +53,8 @@ object PatternMatcher { /** Minimal number of cases to emit a switch */ final val MinSwitchCases = 4 + object TrustedTypeTestKey extends Key[Unit] + /** Was symbol generated by pattern matcher? */ def isPatmatGenerated(sym: Symbol)(implicit ctx: Context): Boolean = sym.is(Synthetic) && sym.name.is(PatMatStdBinderName) @@ -153,24 +156,24 @@ object PatternMatcher { /** The different kinds of tests */ sealed abstract class Test - case class TypeTest(tpt: Tree) extends Test { // scrutinee.isInstanceOf[tpt] + case class TypeTest(tpt: Tree, trusted: Boolean) extends Test { // scrutinee.isInstanceOf[tpt] override def equals(that: Any): Boolean = that match { case that: TypeTest => this.tpt.tpe =:= that.tpt.tpe case _ => false } override def hashCode: Int = tpt.tpe.hash } - case class EqualTest(tree: Tree) extends Test { // scrutinee == tree + case class EqualTest(tree: Tree) extends Test { // scrutinee == tree override def equals(that: Any): Boolean = that match { case that: EqualTest => this.tree === that.tree case _ => false } override def hashCode: Int = tree.hash } - case class LengthTest(len: Int, exact: Boolean) extends Test // scrutinee (== | >=) len - case object NonEmptyTest extends Test // !scrutinee.isEmpty - case object NonNullTest extends Test // scrutinee ne null - case object GuardTest extends Test // scrutinee + case class LengthTest(len: Int, exact: Boolean) extends Test // scrutinee (== | >=) len + case object NonEmptyTest extends Test // !scrutinee.isEmpty + case object NonNullTest extends Test // scrutinee ne null + case object GuardTest extends Test // scrutinee // ------- Generating plans from trees ------------------------ @@ -352,7 +355,12 @@ object PatternMatcher { // begin patternPlan swapBind(tree) match { case Typed(pat, tpt) => - TestPlan(TypeTest(tpt), scrutinee, tree.span, + val isTrusted = pat match { + case UnApply(extractor, _, _) => + extractor.symbol.is(Synthetic) && extractor.symbol.owner.linkedClass.is(Case) + case _ => false + } + TestPlan(TypeTest(tpt, isTrusted), scrutinee, tree.span, letAbstract(ref(scrutinee).cast(tpt.tpe)) { casted => nonNull += casted patternPlan(casted, pat, onSuccess) @@ -685,7 +693,7 @@ object PatternMatcher { .select(defn.Seq_length.matchingMember(scrutinee.tpe)) .select(if (exact) defn.Int_== else defn.Int_>=) .appliedTo(Literal(Constant(len))) - case TypeTest(tpt) => + case TypeTest(tpt, trusted) => val expectedTp = tpt.tpe // An outer test is needed in a situation like `case x: y.Inner => ...` @@ -716,6 +724,7 @@ object PatternMatcher { scrutinee.isInstance(expectedTp) // will be translated to an equality test case _ => val typeTest = scrutinee.select(defn.Any_typeTest).appliedToType(expectedTp) + if (trusted) typeTest.pushAttachment(TrustedTypeTestKey, ()) if (outerTestNeeded) typeTest.and(outerTest) else typeTest } } @@ -899,7 +908,7 @@ object PatternMatcher { val seen = mutable.Set[Int]() def showTest(test: Test) = test match { case EqualTest(tree) => i"EqualTest($tree)" - case TypeTest(tpt) => i"TypeTest($tpt)" + case TypeTest(tpt, trusted) => i"TypeTest($tpt, trusted=$trusted)" case _ => test.toString } def showPlan(plan: Plan): Unit = diff --git a/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala b/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala index b9be5097ac98..362623f8b94c 100644 --- a/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala +++ b/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala @@ -298,7 +298,8 @@ object TypeTestsCasts { if (sym.isTypeTest) { val argType = tree.args.head.tpe - if (!checkable(expr.tpe, argType, tree.span)) + val isTrusted = tree.getAttachment(PatternMatcher.TrustedTypeTestKey).nonEmpty + if (!isTrusted && !checkable(expr.tpe, argType, tree.span)) ctx.warning(i"the type test for $argType cannot be checked at runtime", tree.sourcePos) transformTypeTest(expr, tree.args.head.tpe, flagUnrelated = true) } diff --git a/tests/neg-custom-args/fatal-warnings/suppressed-type-test-warnings.scala b/tests/neg-custom-args/fatal-warnings/suppressed-type-test-warnings.scala new file mode 100644 index 000000000000..175096fc6b21 --- /dev/null +++ b/tests/neg-custom-args/fatal-warnings/suppressed-type-test-warnings.scala @@ -0,0 +1,27 @@ +object Test { + sealed trait Foo[A, B] + final case class Bar[X](x: X) extends Foo[X, X] + + def foo[A, B](value: Foo[A, B], a: A => Int): B = value match { + case Bar(x) => a(x); x + } + + def bar[A, B](value: Foo[A, B], a: A => Int): B = value match { + case b: Bar[a] => b.x + } + + def err1[A, B](value: Foo[A, B], a: A => Int): B = value match { + case b: Bar[A] => // spurious // error + b.x + } + + def err2[A, B](value: Foo[A, B], a: A => Int): B = value match { + case b: Bar[B] => // spurious // error + b.x + } + + def fail[A, B](value: Foo[A, B], a: A => Int): B = value match { + case b: Bar[Int] => // error + b.x + } +}