Skip to content

Commit

Permalink
Cut the Gordian Knot: Don't widen unions to transparent
Browse files Browse the repository at this point in the history
The idea is that some unions usually make more sense than others. For instance,
if `Apply` and `Ident` are case classes that extend `Tree`, it makes sense to
widen `Apply | Ident` to `Tree`. But it makes less sense to widen `String | Int`
to `Matchable`.

Making sense means: (1) Matches our intuitive understanding, and (2) choosing not to
widen would usually not cause errors.

To explain (2): In the `Tree` case it might well be that we define an implicits on `Inv[Tree]` for
invariant class `Inv`, and then we would not find the implicit for `Inv[Apply | Ident]`.
But it's much less likely that we are looking for an implicit of type `Inv[Any]`.

This commit does two things:

 - add logic not to widen a union if the result is a product of only transparent traits or classes.
 - treat `Any`, `AnyVal`, `Object`, and `Matchable` as transparent.
  • Loading branch information
odersky committed Nov 8, 2022
1 parent 275cfa8 commit f19de96
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 15 deletions.
17 changes: 15 additions & 2 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,10 @@ trait ConstraintHandling {
inst
end approximation

private def isTransparent(tp: Type)(using Context): Boolean = tp match
case AndType(tp1, tp2) => isTransparent(tp1) && isTransparent(tp2)
case _ => tp.typeSymbol.isTransparentTrait && !tp.isLambdaSub

/** If `tp` is an intersection such that some operands are transparent trait instances
* and others are not, replace as many transparent trait instances as possible with Any
* as long as the result is still a subtype of `bound`. But fall back to the
Expand All @@ -563,7 +567,7 @@ trait ConstraintHandling {

def dropOneTransparentTrait(tp: Type): Type =
val tpd = tp.dealias
if tpd.typeSymbol.isTransparentTrait && !tpd.isLambdaSub && !kept.contains(tpd) then
if isTransparent(tpd) && !kept.contains(tpd) then
dropped = tpd :: dropped
defn.AnyType
else tpd match
Expand Down Expand Up @@ -648,7 +652,16 @@ trait ConstraintHandling {

val wideInst =
if isSingleton(bound) then inst
else dropTransparentTraits(widenIrreducible(widenOr(widenSingle(inst))), bound)
else
val widenedFromSingle = widenSingle(inst)
val widenedFromUnion = widenOr(widenedFromSingle)
val widened =
if (widenedFromUnion ne widenedFromSingle) && isTransparent(widenedFromUnion) then
widenedFromSingle
else
dropTransparentTraits(widenedFromUnion, bound)
widenIrreducible(widened)

wideInst match
case wideInst: TypeRef if wideInst.symbol.is(Module) =>
TermRef(wideInst.prefix, wideInst.symbol.sourceModule)
Expand Down
6 changes: 5 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1839,7 +1839,11 @@ class Definitions {
requiredClass("scala.collection.generic.IsMap"),
requiredClass("scala.collection.generic.IsSeq"),
requiredClass("scala.collection.generic.Subtractable"),
requiredClass("scala.collection.immutable.StrictOptimizedSeqOps")
requiredClass("scala.collection.immutable.StrictOptimizedSeqOps"),
AnyClass,
AnyValClass,
ObjectClass,
MatchableClass
)

// ----- primitive value class machinery ------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions tests/neg/harmonize.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ object Test {
val a4 = ArrayBuffer(1.0f, 1L)
val b4: ArrayBuffer[Double] = a4 // error: no widening
val a5 = ArrayBuffer(1.0f, 1L, f())
val b5: ArrayBuffer[AnyVal] = a5
val b5: ArrayBuffer[Float | Long | Int] = a5
val a6 = ArrayBuffer(1.0f, 1234567890)
val b6: ArrayBuffer[AnyVal] = a6
val b6: ArrayBuffer[Float | Int] = a6

def totalDuration(results: List[Long], cond: Boolean): Long =
results.map(r => if (cond) r else 0).sum
Expand Down
11 changes: 6 additions & 5 deletions tests/neg/supertraits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,20 @@ class C extends A, S
val x = if ??? then B() else C()
val x1: S = x // error

case object a
case object b
class Top
case object a extends Top
case object b extends Top
val y = if ??? then a else b
val y1: Product = y // error
val y2: Serializable = y // error

enum Color {
enum Color extends Top {
case Red, Green, Blue
}

enum Nucleobase {
enum Nucleobase extends Top {
case A, C, G, T
}

val z = if ??? then Color.Red else Nucleobase.G
val z1: reflect.Enum = z // error: Found: (z : Object) Required: reflect.Enum
val z1: reflect.Enum = z // error: Found: (z : Top) Required: reflect.Enum
5 changes: 3 additions & 2 deletions tests/neg/union.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ object Test {
}

object O {
class A
class B
class Top
class A extends Top
class B extends Top
def f[T](x: T, y: T): T = x

val x: A = f(new A { }, new A)
Expand Down
File renamed without changes.
7 changes: 7 additions & 0 deletions tests/pos/unions.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
object Test:

def test =
val x = if ??? then "" else 1
val _: String | Int = x


6 changes: 3 additions & 3 deletions tests/run/weak-conformance.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ object Test extends App {
locally {
def f(): Int = b + 1
val x1 = ArrayBuffer(b, 33, 5.5) ; x1: ArrayBuffer[Double] // b is an inline val
val x2 = ArrayBuffer(f(), 33, 5.5) ; x2: ArrayBuffer[AnyVal] // f() is not a constant
val x2 = ArrayBuffer(f(), 33, 5.5) ; x2: ArrayBuffer[Int | Double] // f() is not a constant
val x3 = ArrayBuffer(5, 11L) ; x3: ArrayBuffer[Long]
val x4 = ArrayBuffer(5, 11L, 5.5) ; x4: ArrayBuffer[AnyVal] // Long and Double found
val x4 = ArrayBuffer(5, 11L, 5.5) ; x4: ArrayBuffer[Int | Long | Double] // Long and Double found
val x5 = ArrayBuffer(1.0f, 2) ; x5: ArrayBuffer[Float]
val x6 = ArrayBuffer(1.0f, 1234567890); x6: ArrayBuffer[AnyVal] // loss of precision
val x6 = ArrayBuffer(1.0f, 1234567890); x6: ArrayBuffer[Float | Int] // loss of precision
val x7 = ArrayBuffer(b, 33, 'a') ; x7: ArrayBuffer[Char]
val x8 = ArrayBuffer(5.toByte, 11) ; x8: ArrayBuffer[Byte]

Expand Down

0 comments on commit f19de96

Please sign in to comment.