Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix use of class terms in match analysis #21848

Merged
merged 1 commit into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 69 additions & 62 deletions compiler/src/dotty/tools/dotc/core/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,67 @@ object TypeOps:
* Otherwise, return NoType.
*/
private def instantiateToSubType(tp1: NamedType, tp2: Type, mixins: List[Type])(using Context): Type = trace(i"instantiateToSubType($tp1, $tp2, $mixins)", typr) {
/** Gather GADT symbols and singletons found in `tp2`, ie. the scrutinee. */
object TraverseTp2 extends TypeTraverser:
val singletons = util.HashMap[Symbol, SingletonType]()
val gadtSyms = new mutable.ListBuffer[Symbol]

def traverse(tp: Type) = try
val tpd = tp.dealias
if tpd ne tp then traverse(tpd)
else tp match
case tp: ThisType if !singletons.contains(tp.tref.symbol) && !tp.tref.symbol.isStaticOwner =>
singletons(tp.tref.symbol) = tp
traverseChildren(tp.tref)
case tp: TermRef =>
singletons(tp.typeSymbol) = tp
traverseChildren(tp)
case tp: TypeRef if !gadtSyms.contains(tp.symbol) && tp.symbol.isAbstractOrParamType =>
gadtSyms += tp.symbol
traverseChildren(tp)
// traverse abstract type infos, to add any singletons
// for example, i16451.CanForward.scala, add `Namer.this`, from the info of the type parameter `A1`
// also, i19031.ci-reg2.scala, add `out`, from the info of the type parameter `A1` (from synthetic applyOrElse)
traverseChildren(tp.info)
case _ =>
traverseChildren(tp)
catch case ex: Throwable => handleRecursive("traverseTp2", tp.show, ex)
TraverseTp2.traverse(tp2)
val singletons = TraverseTp2.singletons
val gadtSyms = TraverseTp2.gadtSyms.toList

// Prefix inference, given `p.C.this.Child`:
// 1. return it as is, if `C.this` is found in `tp`, i.e. the scrutinee; or
// 2. replace it with `X.Child` where `X <: p.C`, stripping ThisType in `p` recursively.
//
// See tests/patmat/i3938.scala, tests/pos/i15029.more.scala, tests/pos/i16785.scala
class InferPrefixMap extends TypeMap {
var prefixTVar: Type | Null = null
def apply(tp: Type): Type = tp match {
case tp: TermRef if singletons.contains(tp.symbol) =>
prefixTVar = singletons(tp.symbol) // e.g. tests/pos/i19031.ci-reg2.scala, keep out
prefixTVar.uncheckedNN
case ThisType(tref) if !tref.symbol.isStaticOwner =>
val symbol = tref.symbol
if singletons.contains(symbol) then
prefixTVar = singletons(symbol) // e.g. tests/pos/i16785.scala, keep Outer.this
prefixTVar.uncheckedNN
else if symbol.is(Module) then
TermRef(this(tref.prefix), symbol.sourceModule)
else if (prefixTVar != null)
this(tref.applyIfParameterized(tref.typeParams.map(_ => WildcardType)))
else {
prefixTVar = WildcardType // prevent recursive call from assigning it
// e.g. tests/pos/i15029.more.scala, create a TypeVar for `Instances`' B, so we can disregard `Ints`
val tvars = tref.typeParams.map { tparam => newTypeVar(tparam.paramInfo.bounds, DepParamName.fresh(tparam.paramName)) }
val tref2 = this(tref.applyIfParameterized(tvars))
prefixTVar = newTypeVar(TypeBounds.upper(tref2), DepParamName.fresh(tref.name))
prefixTVar.uncheckedNN
}
case tp => mapOver(tp)
}
}

// In order for a child type S to qualify as a valid subtype of the parent
// T, we need to test whether it is possible S <: T.
//
Expand All @@ -788,8 +849,15 @@ object TypeOps:
// then to avoid it failing the <:<
// we'll approximate by widening to its bounds

case tp: TermRef if singletons.contains(tp.symbol) =>
singletons(tp.symbol)

case ThisType(tref: TypeRef) if !tref.symbol.isStaticOwner =>
tref
val symbol = tref.symbol
if singletons.contains(symbol) then
singletons(symbol)
else
tref

case tp: TypeRef if !tp.symbol.isClass =>
val lookup = boundTypeParams.lookup(tp)
Expand Down Expand Up @@ -840,67 +908,6 @@ object TypeOps:
}
}

/** Gather GADT symbols and singletons found in `tp2`, ie. the scrutinee. */
object TraverseTp2 extends TypeTraverser:
val singletons = util.HashMap[Symbol, SingletonType]()
val gadtSyms = new mutable.ListBuffer[Symbol]

def traverse(tp: Type) = try
val tpd = tp.dealias
if tpd ne tp then traverse(tpd)
else tp match
case tp: ThisType if !singletons.contains(tp.tref.symbol) && !tp.tref.symbol.isStaticOwner =>
singletons(tp.tref.symbol) = tp
traverseChildren(tp.tref)
case tp: TermRef if tp.symbol.is(Param) =>
singletons(tp.typeSymbol) = tp
traverseChildren(tp)
case tp: TypeRef if !gadtSyms.contains(tp.symbol) && tp.symbol.isAbstractOrParamType =>
gadtSyms += tp.symbol
traverseChildren(tp)
// traverse abstract type infos, to add any singletons
// for example, i16451.CanForward.scala, add `Namer.this`, from the info of the type parameter `A1`
// also, i19031.ci-reg2.scala, add `out`, from the info of the type parameter `A1` (from synthetic applyOrElse)
traverseChildren(tp.info)
case _ =>
traverseChildren(tp)
catch case ex: Throwable => handleRecursive("traverseTp2", tp.show, ex)
TraverseTp2.traverse(tp2)
val singletons = TraverseTp2.singletons
val gadtSyms = TraverseTp2.gadtSyms.toList

// Prefix inference, given `p.C.this.Child`:
// 1. return it as is, if `C.this` is found in `tp`, i.e. the scrutinee; or
// 2. replace it with `X.Child` where `X <: p.C`, stripping ThisType in `p` recursively.
//
// See tests/patmat/i3938.scala, tests/pos/i15029.more.scala, tests/pos/i16785.scala
class InferPrefixMap extends TypeMap {
var prefixTVar: Type | Null = null
def apply(tp: Type): Type = tp match {
case tp: TermRef if singletons.contains(tp.symbol) =>
prefixTVar = singletons(tp.symbol) // e.g. tests/pos/i19031.ci-reg2.scala, keep out
prefixTVar.uncheckedNN
case ThisType(tref) if !tref.symbol.isStaticOwner =>
val symbol = tref.symbol
if singletons.contains(symbol) then
prefixTVar = singletons(symbol) // e.g. tests/pos/i16785.scala, keep Outer.this
prefixTVar.uncheckedNN
else if symbol.is(Module) then
TermRef(this(tref.prefix), symbol.sourceModule)
else if (prefixTVar != null)
this(tref.applyIfParameterized(tref.typeParams.map(_ => WildcardType)))
else {
prefixTVar = WildcardType // prevent recursive call from assigning it
// e.g. tests/pos/i15029.more.scala, create a TypeVar for `Instances`' B, so we can disregard `Ints`
val tvars = tref.typeParams.map { tparam => newTypeVar(tparam.paramInfo.bounds, DepParamName.fresh(tparam.paramName)) }
val tref2 = this(tref.applyIfParameterized(tvars))
prefixTVar = newTypeVar(TypeBounds.upper(tref2), DepParamName.fresh(tref.name))
prefixTVar.uncheckedNN
}
case tp => mapOver(tp)
}
}

val inferThisMap = new InferPrefixMap
val tvars = tp1.typeParams.map { tparam => newTypeVar(tparam.paramInfo.bounds, DepParamName.fresh(tparam.paramName)) }
val protoTp1 = inferThisMap.apply(tp1).appliedTo(tvars)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
recur(fn) ~ "(" ~ toTextGlobal(explicitArgs, ", ") ~ ")"
case TypeApply(fn, args) => recur(fn) ~ "[" ~ toTextGlobal(args, ", ") ~ "]"
case Select(qual, nme.CONSTRUCTOR) => recur(qual)
case id @ Ident(tpnme.BOUNDTYPE_ANNOT) => "@" ~ toText(id.symbol.name)
case New(tpt) => recur(tpt)
case _ =>
val annotSym = sym.orElse(tree.symbol.enclosingClass)
Expand Down
33 changes: 33 additions & 0 deletions tests/warn/i21845.orig.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
trait Init[ScopeType]:
sealed trait Initialize[A1]
final class Bind[S, A1](val f: S => Initialize[A1], val in: Initialize[S])
extends Initialize[A1]
final class Value[A1](val value: () => A1) extends Initialize[A1]
final class ValidationCapture[A1](val key: ScopedKey[A1], val selfRefOk: Boolean)
extends Initialize[ScopedKey[A1]]
final class TransformCapture(val f: [x] => Initialize[x] => Initialize[x])
extends Initialize[[x] => Initialize[x] => Initialize[x]]
final class Optional[S, A1](val a: Option[Initialize[S]], val f: Option[S] => A1)
extends Initialize[A1]
object StaticScopes extends Initialize[Set[ScopeType]]

sealed trait Keyed[S, A1] extends Initialize[A1]
trait KeyedInitialize[A1] extends Keyed[A1, A1]
sealed case class ScopedKey[A](scope: ScopeType, key: AttributeKey[A]) extends KeyedInitialize[A]
sealed trait AttributeKey[A]

abstract class EvaluateSettings[ScopeType]:
protected val init: Init[ScopeType]
import init._

val transform: [A] => Initialize[A] => Unit = [A] =>
(fa: Initialize[A]) =>
fa match
case k: Keyed[s, A] => ???
case b: Bind[s, A] => ???
case v: Value[A] => ???
case v: ValidationCapture[a] => ??? // unrearchable warning
case t: TransformCapture => ??? // unrearchable warning
case o: Optional[s, A] => ??? // unrearchable warning
case StaticScopes => ??? // unrearchable warning

15 changes: 15 additions & 0 deletions tests/warn/i21845.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
trait Outer[O1]:
sealed trait Foo[A1]
final class Bar[A2] extends Foo[A2]
final class Baz[A4] extends Foo[Bar[A4]]
final class Qux extends Foo[[a5] => Foo[a5] => Foo[a5]]

trait Test[O2]:
val outer: Outer[O2]
import outer.*

def test[X](fa: Foo[X]): Unit =
fa match // was: inexhaustive: fail on _: (Outer[<?>] & (Test#outer : Outer[Test#O2]))#Qux
case _: Bar[X] => ???
case _: Baz[x] => ??? // was: unrearchable warning
case _: Qux => ??? // was: unrearchable warning
Loading