From 4d79459b6cf4b503e81869786ae6024afda5eb3e Mon Sep 17 00:00:00 2001 From: Dale Wijnand Date: Mon, 28 Oct 2024 10:45:08 +0000 Subject: [PATCH] Fix use of class terms in match analysis When instantiating a subclass Outer.this.Bar, such that it's a subtype of Test.this.outer.Foo[X], make sure to infer the term `outer` (even if it's not a parameter). Also make sure to use those singletons when approximating the parent (to fix Outer.this.Qux instantiating). --- .../src/dotty/tools/dotc/core/TypeOps.scala | 131 +++++++++--------- .../tools/dotc/printing/RefinedPrinter.scala | 1 + tests/warn/i21845.orig.scala | 33 +++++ tests/warn/i21845.scala | 15 ++ 4 files changed, 118 insertions(+), 62 deletions(-) create mode 100644 tests/warn/i21845.orig.scala create mode 100644 tests/warn/i21845.scala diff --git a/compiler/src/dotty/tools/dotc/core/TypeOps.scala b/compiler/src/dotty/tools/dotc/core/TypeOps.scala index bfda613d0586..fac5d262d426 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeOps.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeOps.scala @@ -767,6 +767,67 @@ object TypeOps: * Otherwise, return NoType. */ private def instantiateToSubType(tp1: NamedType, tp2: Type, mixins: List[Type])(using Context): Type = trace(i"instantiateToSubType($tp1, $tp2, $mixins)", typr) { + /** Gather GADT symbols and singletons found in `tp2`, ie. the scrutinee. */ + object TraverseTp2 extends TypeTraverser: + val singletons = util.HashMap[Symbol, SingletonType]() + val gadtSyms = new mutable.ListBuffer[Symbol] + + def traverse(tp: Type) = try + val tpd = tp.dealias + if tpd ne tp then traverse(tpd) + else tp match + case tp: ThisType if !singletons.contains(tp.tref.symbol) && !tp.tref.symbol.isStaticOwner => + singletons(tp.tref.symbol) = tp + traverseChildren(tp.tref) + case tp: TermRef => + singletons(tp.typeSymbol) = tp + traverseChildren(tp) + case tp: TypeRef if !gadtSyms.contains(tp.symbol) && tp.symbol.isAbstractOrParamType => + gadtSyms += tp.symbol + traverseChildren(tp) + // traverse abstract type infos, to add any singletons + // for example, i16451.CanForward.scala, add `Namer.this`, from the info of the type parameter `A1` + // also, i19031.ci-reg2.scala, add `out`, from the info of the type parameter `A1` (from synthetic applyOrElse) + traverseChildren(tp.info) + case _ => + traverseChildren(tp) + catch case ex: Throwable => handleRecursive("traverseTp2", tp.show, ex) + TraverseTp2.traverse(tp2) + val singletons = TraverseTp2.singletons + val gadtSyms = TraverseTp2.gadtSyms.toList + + // Prefix inference, given `p.C.this.Child`: + // 1. return it as is, if `C.this` is found in `tp`, i.e. the scrutinee; or + // 2. replace it with `X.Child` where `X <: p.C`, stripping ThisType in `p` recursively. + // + // See tests/patmat/i3938.scala, tests/pos/i15029.more.scala, tests/pos/i16785.scala + class InferPrefixMap extends TypeMap { + var prefixTVar: Type | Null = null + def apply(tp: Type): Type = tp match { + case tp: TermRef if singletons.contains(tp.symbol) => + prefixTVar = singletons(tp.symbol) // e.g. tests/pos/i19031.ci-reg2.scala, keep out + prefixTVar.uncheckedNN + case ThisType(tref) if !tref.symbol.isStaticOwner => + val symbol = tref.symbol + if singletons.contains(symbol) then + prefixTVar = singletons(symbol) // e.g. tests/pos/i16785.scala, keep Outer.this + prefixTVar.uncheckedNN + else if symbol.is(Module) then + TermRef(this(tref.prefix), symbol.sourceModule) + else if (prefixTVar != null) + this(tref.applyIfParameterized(tref.typeParams.map(_ => WildcardType))) + else { + prefixTVar = WildcardType // prevent recursive call from assigning it + // e.g. tests/pos/i15029.more.scala, create a TypeVar for `Instances`' B, so we can disregard `Ints` + val tvars = tref.typeParams.map { tparam => newTypeVar(tparam.paramInfo.bounds, DepParamName.fresh(tparam.paramName)) } + val tref2 = this(tref.applyIfParameterized(tvars)) + prefixTVar = newTypeVar(TypeBounds.upper(tref2), DepParamName.fresh(tref.name)) + prefixTVar.uncheckedNN + } + case tp => mapOver(tp) + } + } + // In order for a child type S to qualify as a valid subtype of the parent // T, we need to test whether it is possible S <: T. // @@ -788,8 +849,15 @@ object TypeOps: // then to avoid it failing the <:< // we'll approximate by widening to its bounds + case tp: TermRef if singletons.contains(tp.symbol) => + singletons(tp.symbol) + case ThisType(tref: TypeRef) if !tref.symbol.isStaticOwner => - tref + val symbol = tref.symbol + if singletons.contains(symbol) then + singletons(symbol) + else + tref case tp: TypeRef if !tp.symbol.isClass => val lookup = boundTypeParams.lookup(tp) @@ -840,67 +908,6 @@ object TypeOps: } } - /** Gather GADT symbols and singletons found in `tp2`, ie. the scrutinee. */ - object TraverseTp2 extends TypeTraverser: - val singletons = util.HashMap[Symbol, SingletonType]() - val gadtSyms = new mutable.ListBuffer[Symbol] - - def traverse(tp: Type) = try - val tpd = tp.dealias - if tpd ne tp then traverse(tpd) - else tp match - case tp: ThisType if !singletons.contains(tp.tref.symbol) && !tp.tref.symbol.isStaticOwner => - singletons(tp.tref.symbol) = tp - traverseChildren(tp.tref) - case tp: TermRef if tp.symbol.is(Param) => - singletons(tp.typeSymbol) = tp - traverseChildren(tp) - case tp: TypeRef if !gadtSyms.contains(tp.symbol) && tp.symbol.isAbstractOrParamType => - gadtSyms += tp.symbol - traverseChildren(tp) - // traverse abstract type infos, to add any singletons - // for example, i16451.CanForward.scala, add `Namer.this`, from the info of the type parameter `A1` - // also, i19031.ci-reg2.scala, add `out`, from the info of the type parameter `A1` (from synthetic applyOrElse) - traverseChildren(tp.info) - case _ => - traverseChildren(tp) - catch case ex: Throwable => handleRecursive("traverseTp2", tp.show, ex) - TraverseTp2.traverse(tp2) - val singletons = TraverseTp2.singletons - val gadtSyms = TraverseTp2.gadtSyms.toList - - // Prefix inference, given `p.C.this.Child`: - // 1. return it as is, if `C.this` is found in `tp`, i.e. the scrutinee; or - // 2. replace it with `X.Child` where `X <: p.C`, stripping ThisType in `p` recursively. - // - // See tests/patmat/i3938.scala, tests/pos/i15029.more.scala, tests/pos/i16785.scala - class InferPrefixMap extends TypeMap { - var prefixTVar: Type | Null = null - def apply(tp: Type): Type = tp match { - case tp: TermRef if singletons.contains(tp.symbol) => - prefixTVar = singletons(tp.symbol) // e.g. tests/pos/i19031.ci-reg2.scala, keep out - prefixTVar.uncheckedNN - case ThisType(tref) if !tref.symbol.isStaticOwner => - val symbol = tref.symbol - if singletons.contains(symbol) then - prefixTVar = singletons(symbol) // e.g. tests/pos/i16785.scala, keep Outer.this - prefixTVar.uncheckedNN - else if symbol.is(Module) then - TermRef(this(tref.prefix), symbol.sourceModule) - else if (prefixTVar != null) - this(tref.applyIfParameterized(tref.typeParams.map(_ => WildcardType))) - else { - prefixTVar = WildcardType // prevent recursive call from assigning it - // e.g. tests/pos/i15029.more.scala, create a TypeVar for `Instances`' B, so we can disregard `Ints` - val tvars = tref.typeParams.map { tparam => newTypeVar(tparam.paramInfo.bounds, DepParamName.fresh(tparam.paramName)) } - val tref2 = this(tref.applyIfParameterized(tvars)) - prefixTVar = newTypeVar(TypeBounds.upper(tref2), DepParamName.fresh(tref.name)) - prefixTVar.uncheckedNN - } - case tp => mapOver(tp) - } - } - val inferThisMap = new InferPrefixMap val tvars = tp1.typeParams.map { tparam => newTypeVar(tparam.paramInfo.bounds, DepParamName.fresh(tparam.paramName)) } val protoTp1 = inferThisMap.apply(tp1).appliedTo(tvars) diff --git a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala index b229c7ec29d9..f22abbd2efcb 100644 --- a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala @@ -1121,6 +1121,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { recur(fn) ~ "(" ~ toTextGlobal(explicitArgs, ", ") ~ ")" case TypeApply(fn, args) => recur(fn) ~ "[" ~ toTextGlobal(args, ", ") ~ "]" case Select(qual, nme.CONSTRUCTOR) => recur(qual) + case id @ Ident(tpnme.BOUNDTYPE_ANNOT) => "@" ~ toText(id.symbol.name) case New(tpt) => recur(tpt) case _ => val annotSym = sym.orElse(tree.symbol.enclosingClass) diff --git a/tests/warn/i21845.orig.scala b/tests/warn/i21845.orig.scala new file mode 100644 index 000000000000..a8e0893ea5ce --- /dev/null +++ b/tests/warn/i21845.orig.scala @@ -0,0 +1,33 @@ +trait Init[ScopeType]: + sealed trait Initialize[A1] + final class Bind[S, A1](val f: S => Initialize[A1], val in: Initialize[S]) + extends Initialize[A1] + final class Value[A1](val value: () => A1) extends Initialize[A1] + final class ValidationCapture[A1](val key: ScopedKey[A1], val selfRefOk: Boolean) + extends Initialize[ScopedKey[A1]] + final class TransformCapture(val f: [x] => Initialize[x] => Initialize[x]) + extends Initialize[[x] => Initialize[x] => Initialize[x]] + final class Optional[S, A1](val a: Option[Initialize[S]], val f: Option[S] => A1) + extends Initialize[A1] + object StaticScopes extends Initialize[Set[ScopeType]] + + sealed trait Keyed[S, A1] extends Initialize[A1] + trait KeyedInitialize[A1] extends Keyed[A1, A1] + sealed case class ScopedKey[A](scope: ScopeType, key: AttributeKey[A]) extends KeyedInitialize[A] + sealed trait AttributeKey[A] + +abstract class EvaluateSettings[ScopeType]: + protected val init: Init[ScopeType] + import init._ + + val transform: [A] => Initialize[A] => Unit = [A] => + (fa: Initialize[A]) => + fa match + case k: Keyed[s, A] => ??? + case b: Bind[s, A] => ??? + case v: Value[A] => ??? + case v: ValidationCapture[a] => ??? // unrearchable warning + case t: TransformCapture => ??? // unrearchable warning + case o: Optional[s, A] => ??? // unrearchable warning + case StaticScopes => ??? // unrearchable warning + diff --git a/tests/warn/i21845.scala b/tests/warn/i21845.scala new file mode 100644 index 000000000000..58590c74e1d4 --- /dev/null +++ b/tests/warn/i21845.scala @@ -0,0 +1,15 @@ +trait Outer[O1]: + sealed trait Foo[A1] + final class Bar[A2] extends Foo[A2] + final class Baz[A4] extends Foo[Bar[A4]] + final class Qux extends Foo[[a5] => Foo[a5] => Foo[a5]] + +trait Test[O2]: + val outer: Outer[O2] + import outer.* + + def test[X](fa: Foo[X]): Unit = + fa match // was: inexhaustive: fail on _: (Outer[] & (Test#outer : Outer[Test#O2]))#Qux + case _: Bar[X] => ??? + case _: Baz[x] => ??? // was: unrearchable warning + case _: Qux => ??? // was: unrearchable warning