Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix exhaustivity due to separate TypeVar lambdas #18616

Merged
merged 2 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@ object Trees {
/** A type tree that represents an existing or inferred type */
case class TypeTree[+T <: Untyped]()(implicit @constructorOnly src: SourceFile)
extends DenotingTree[T] with TypTree[T] {
type ThisTree[+T <: Untyped] = TypeTree[T]
type ThisTree[+T <: Untyped] <: TypeTree[T]
override def isEmpty: Boolean = !hasType
override def toString: String =
s"TypeTree${if (hasType) s"[$typeOpt]" else ""}"
Expand All @@ -794,7 +794,8 @@ object Trees {
* - as a (result-)type of an inferred ValDef or DefDef.
* Every TypeVar is created as the type of one InferredTypeTree.
*/
class InferredTypeTree[+T <: Untyped](implicit @constructorOnly src: SourceFile) extends TypeTree[T]
class InferredTypeTree[+T <: Untyped](implicit @constructorOnly src: SourceFile) extends TypeTree[T]:
type ThisTree[+T <: Untyped] <: InferredTypeTree[T]

/** ref.type */
case class SingletonTypeTree[+T <: Untyped] private[ast] (ref: Tree[T])(implicit @constructorOnly src: SourceFile)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3253,7 +3253,7 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
def matchCase(cas: Type): MatchResult = trace(i"$scrut match ${MatchTypeTrace.caseText(cas)}", matchTypes, show = true) {
val cas1 = cas match {
case cas: HKTypeLambda =>
caseLambda = constrained(cas)
caseLambda = constrained(cas, ast.tpd.EmptyTree)._1
caseLambda.resultType
case _ =>
cas
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4941,6 +4941,9 @@ object Types {
if (inst.exists) inst else origin
}

def wrapInTypeTree(owningTree: Tree)(using Context): InferredTypeTree =
new InferredTypeTree().withSpan(owningTree.span).withType(this)

override def computeHash(bs: Binders): Int = identityHash(bs)
override def equals(that: Any): Boolean = this.eq(that.asInstanceOf[AnyRef])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -530,12 +530,9 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
(rawRef, rawInfo)
baseInfo match
case tl: PolyType =>
val (tl1, tpts) = constrained(tl, untpd.EmptyTree, alwaysAddTypeVars = true)
val targs =
for (tpt <- tpts) yield
tpt.tpe match {
case tvar: TypeVar => tvar.instantiate(fromBelow = false)
}
val tvars = constrained(tl)
val targs = for tvar <- tvars yield
tvar.instantiate(fromBelow = false)
(baseRef.appliedTo(targs), extractParams(tl.instantiate(targs)))
case methTpe =>
(baseRef, extractParams(methTpe))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ object TypeTestsCasts {
case tp: TypeProxy => underlyingLambda(tp.superType)
}
val typeLambda = underlyingLambda(tycon)
val tvars = constrained(typeLambda, untpd.EmptyTree, alwaysAddTypeVars = true)._2.map(_.tpe)
val tvars = constrained(typeLambda)
val P1 = tycon.appliedTo(tvars)

debug.println("before " + ctx.typerState.constraint.show)
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/transform/patmat/Space.scala
Original file line number Diff line number Diff line change
Expand Up @@ -476,8 +476,8 @@ object SpaceEngine {
erase(parent, inArray, isValue, isTyped)

case tref: TypeRef if tref.symbol.isPatternBound =>
if inArray then tref.underlying
else if isValue then tref.superType
if inArray then erase(tref.underlying, inArray, isValue, isTyped)
else if isValue then erase(tref.superType, inArray, isValue, isTyped)
else WildcardType

case _ => tp
Expand Down Expand Up @@ -531,7 +531,7 @@ object SpaceEngine {
val mt: MethodType = unapp.widen match {
case mt: MethodType => mt
case pt: PolyType =>
val tvars = pt.paramInfos.map(newTypeVar(_))
val tvars = constrained(pt)
val mt = pt.instantiate(tvars).asInstanceOf[MethodType]
scrutineeTp <:< mt.paramInfos(0)
// force type inference to infer a narrower type: could be singleton
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Inferencing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ object Inferencing {
def inferTypeParams(tree: Tree, pt: Type)(using Context): Tree = tree.tpe match
case tl: TypeLambda =>
val (tl1, tvars) = constrained(tl, tree)
var tree1 = AppliedTypeTree(tree.withType(tl1), tvars)
val tree1 = AppliedTypeTree(tree.withType(tl1), tvars.map(_.wrapInTypeTree(tree)))
tree1.tpe <:< pt
if isFullyDefined(tree1.tpe, force = ForceDegree.failBottom) then
tree1
Expand Down
30 changes: 14 additions & 16 deletions compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -726,41 +726,39 @@ object ProtoTypes {
tl: TypeLambda, owningTree: untpd.Tree,
alwaysAddTypeVars: Boolean,
nestingLevel: Int = ctx.nestingLevel
): (TypeLambda, List[TypeTree]) = {
): (TypeLambda, List[TypeVar]) = {
val state = ctx.typerState
val addTypeVars = alwaysAddTypeVars || !owningTree.isEmpty
if (tl.isInstanceOf[PolyType])
assert(!ctx.typerState.isCommittable || addTypeVars,
s"inconsistent: no typevars were added to committable constraint ${state.constraint}")
// hk type lambdas can be added to constraints without typevars during match reduction

def newTypeVars(tl: TypeLambda): List[TypeTree] =
for (paramRef <- tl.paramRefs)
yield {
val tt = InferredTypeTree().withSpan(owningTree.span)
def newTypeVars(tl: TypeLambda): List[TypeVar] =
for paramRef <- tl.paramRefs
yield
val tvar = TypeVar(paramRef, state, nestingLevel)
state.ownedVars += tvar
tt.withType(tvar)
}
tvar

val added = state.constraint.ensureFresh(tl)
val tvars = if (addTypeVars) newTypeVars(added) else Nil
TypeComparer.addToConstraint(added, tvars.tpes.asInstanceOf[List[TypeVar]])
val tvars = if addTypeVars then newTypeVars(added) else Nil
TypeComparer.addToConstraint(added, tvars)
(added, tvars)
}

def constrained(tl: TypeLambda, owningTree: untpd.Tree)(using Context): (TypeLambda, List[TypeTree]) =
def constrained(tl: TypeLambda, owningTree: untpd.Tree)(using Context): (TypeLambda, List[TypeVar]) =
constrained(tl, owningTree,
alwaysAddTypeVars = tl.isInstanceOf[PolyType] && ctx.typerState.isCommittable)

/** Same as `constrained(tl, EmptyTree)`, but returns just the created type lambda */
def constrained(tl: TypeLambda)(using Context): TypeLambda =
constrained(tl, EmptyTree)._1
/** Same as `constrained(tl, EmptyTree, alwaysAddTypeVars = true)`, but returns just the created type vars. */
def constrained(tl: TypeLambda)(using Context): List[TypeVar] =
constrained(tl, EmptyTree, alwaysAddTypeVars = true)._2

/** Instantiate `tl` with fresh type variables added to the constraint. */
def instantiateWithTypeVars(tl: TypeLambda)(using Context): Type =
val targs = constrained(tl, ast.tpd.EmptyTree, alwaysAddTypeVars = true)._2
tl.instantiate(targs.tpes)
val tvars = constrained(tl)
tl.instantiate(tvars)

/** A fresh type variable added to the current constraint.
* @param bounds The initial bounds of the variable
Expand All @@ -779,7 +777,7 @@ object ProtoTypes {
pt => bounds :: Nil,
pt => represents.orElse(defn.AnyType))
constrained(poly, untpd.EmptyTree, alwaysAddTypeVars = true, nestingLevel)
._2.head.tpe.asInstanceOf[TypeVar]
._2.head

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From what I can see, the changes are only about APIs, which simplifies usage in several places, the logic is the same. /cc: @odersky

/** If `param` was created using `newTypeVar(..., represents = X)`, returns X.
* This is used in:
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4336,7 +4336,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
var typeArgs = tree match
case Select(qual, nme.CONSTRUCTOR) => qual.tpe.widenDealias.argTypesLo.map(TypeTree(_))
case _ => Nil
if typeArgs.isEmpty then typeArgs = constrained(poly, tree)._2
if typeArgs.isEmpty then typeArgs = constrained(poly, tree)._2.map(_.wrapInTypeTree(tree))
convertNewGenericArray(readapt(tree.appliedToTypeTrees(typeArgs)))
case wtp =>
val isStructuralCall = wtp.isValueType && isStructuralTermSelectOrApply(tree)
Expand Down
4 changes: 2 additions & 2 deletions compiler/test/dotty/tools/SignatureTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class SignatureTest:
| def tuple2(x: Foo *: (T | Tuple) & Foo): Unit = {}
|""".stripMargin):
val cls = requiredClass("A")
val tvar = constrained(cls.requiredMethod(nme.CONSTRUCTOR).info.asInstanceOf[TypeLambda], untpd.EmptyTree, alwaysAddTypeVars = true)._2.head.tpe
val tvar = constrained(cls.requiredMethod(nme.CONSTRUCTOR).info.asInstanceOf[TypeLambda]).head
tvar <:< defn.TupleTypeRef
val prefix = cls.typeRef.appliedTo(tvar)

Expand All @@ -89,7 +89,7 @@ class SignatureTest:
| def and(x: T & Foo): Unit = {}
|""".stripMargin):
val cls = requiredClass("A")
val tvar = constrained(cls.requiredMethod(nme.CONSTRUCTOR).info.asInstanceOf[TypeLambda], untpd.EmptyTree, alwaysAddTypeVars = true)._2.head.tpe
val tvar = constrained(cls.requiredMethod(nme.CONSTRUCTOR).info.asInstanceOf[TypeLambda]).head
val prefix = cls.typeRef.appliedTo(tvar)
val ref = prefix.select(cls.requiredMethod("and")).asInstanceOf[TermRef]

Expand Down
19 changes: 7 additions & 12 deletions compiler/test/dotty/tools/dotc/core/ConstraintsTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ class ConstraintsTest:
@Test def mergeParamsTransitivity: Unit =
inCompilerContext(TestConfiguration.basicClasspath,
scalaSources = "trait A { def foo[S, T, R]: Any }") {
val tvars = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda], EmptyTree, alwaysAddTypeVars = true)._2
val List(s, t, r) = tvars.tpes
val List(s, t, r) = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda])

val innerCtx = ctx.fresh.setExploreTyperState()
inContext(innerCtx) {
Expand All @@ -38,8 +37,7 @@ class ConstraintsTest:
@Test def mergeBoundsTransitivity: Unit =
inCompilerContext(TestConfiguration.basicClasspath,
scalaSources = "trait A { def foo[S, T]: Any }") {
val tvars = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda], EmptyTree, alwaysAddTypeVars = true)._2
val List(s, t) = tvars.tpes
val List(s, t) = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda])

val innerCtx = ctx.fresh.setExploreTyperState()
inContext(innerCtx) {
Expand All @@ -57,32 +55,29 @@ class ConstraintsTest:
@Test def validBoundsInit: Unit = inCompilerContext(
TestConfiguration.basicClasspath,
scalaSources = "trait A { def foo[S >: T <: T | Int, T <: String]: Any }") {
val tvars = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda], EmptyTree, alwaysAddTypeVars = true)._2
val List(s, t) = tvars.tpes
val List(s, t) = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda])

val TypeBounds(lo, hi) = ctx.typerState.constraint.entry(t.asInstanceOf[TypeVar].origin): @unchecked
val TypeBounds(lo, hi) = ctx.typerState.constraint.entry(t.origin): @unchecked
assert(lo =:= defn.NothingType, i"Unexpected lower bound $lo for $t: ${ctx.typerState.constraint}")
assert(hi =:= defn.StringType, i"Unexpected upper bound $hi for $t: ${ctx.typerState.constraint}") // used to be Any
}

@Test def validBoundsUnify: Unit = inCompilerContext(
TestConfiguration.basicClasspath,
scalaSources = "trait A { def foo[S >: T <: T | Int, T <: String | Int]: Any }") {
val tvars = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda], EmptyTree, alwaysAddTypeVars = true)._2
val List(s, t) = tvars.tpes
val List(s, t) = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda])

s <:< t

val TypeBounds(lo, hi) = ctx.typerState.constraint.entry(t.asInstanceOf[TypeVar].origin): @unchecked
val TypeBounds(lo, hi) = ctx.typerState.constraint.entry(t.origin): @unchecked
assert(lo =:= defn.NothingType, i"Unexpected lower bound $lo for $t: ${ctx.typerState.constraint}")
assert(hi =:= (defn.StringType | defn.IntType), i"Unexpected upper bound $hi for $t: ${ctx.typerState.constraint}")
}

@Test def validBoundsReplace: Unit = inCompilerContext(
TestConfiguration.basicClasspath,
scalaSources = "trait X; trait A { def foo[S <: U | X, T, U]: Any }") {
val tvarTrees = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda], EmptyTree, alwaysAddTypeVars = true)._2
val tvars @ List(s, t, u) = tvarTrees.tpes.asInstanceOf[List[TypeVar]]
val tvars @ List(s, t, u) = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda])
s =:= t
t =:= u

Expand Down
11 changes: 11 additions & 0 deletions tests/pos/i14224.1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
//> using options -Werror

// Derived from the extensive test in the gist in i14224
// Minimising to the false positive in SealedTrait1.either

sealed trait Foo[A, A1 <: A]
final case class Bar[A, A1 <: A](value: A1) extends Foo[A, A1]

class Main:
def test[A, A1 <: A](foo: Foo[A, A1]): A1 = foo match
case Bar(v) => v