Skip to content

Commit

Permalink
Merge pull request #15453 from dotty-staging/fix-superType
Browse files Browse the repository at this point in the history
Don't normalize in AppliedType#superType
  • Loading branch information
mbovel authored Jun 17, 2022
2 parents 140693d + a000229 commit 13d9359
Show file tree
Hide file tree
Showing 3 changed files with 279 additions and 261 deletions.
20 changes: 11 additions & 9 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
if tycon1sym == tycon2sym && tycon1sym.isAliasType then
val preConstraint = constraint
isSubArgs(args1, args2, tp1, tparams)
&& tryAlso(preConstraint, recur(tp1.superType, tp2.superType))
&& tryAlso(preConstraint, recur(tp1.superTypeNormalized, tp2.superTypeNormalized))
else
isSubArgs(args1, args2, tp1, tparams)
}
Expand Down Expand Up @@ -1177,7 +1177,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
*/
def compareLower(tycon2bounds: TypeBounds, tyconIsTypeRef: Boolean): Boolean =
if ((tycon2bounds.lo `eq` tycon2bounds.hi) && !tycon2bounds.isInstanceOf[MatchAlias])
if (tyconIsTypeRef) recur(tp1, tp2.superType)
if (tyconIsTypeRef) recur(tp1, tp2.superTypeNormalized)
else isSubApproxHi(tp1, tycon2bounds.lo.applyIfParameterized(args2))
else
fallback(tycon2bounds.lo)
Expand Down Expand Up @@ -1249,11 +1249,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling

!sym.isClass && {
defn.isCompiletimeAppliedType(sym) && compareCompiletimeAppliedType(tp1, tp2, fromBelow = false) ||
recur(tp1.superType, tp2) ||
recur(tp1.superTypeNormalized, tp2) ||
tryLiftedToThis1
}|| byGadtBounds
case tycon1: TypeProxy =>
recur(tp1.superType, tp2)
recur(tp1.superTypeNormalized, tp2)
case _ =>
false
}
Expand Down Expand Up @@ -2645,9 +2645,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
!(tp2 <:< tp1)
&& (provablyDisjoint(tp1, tp2.tp2) || provablyDisjoint(tp1, tp2.tp1))
case (tp1: NamedType, _) if gadtBounds(tp1.symbol) != null =>
provablyDisjoint(gadtBounds(tp1.symbol).uncheckedNN.hi, tp2) || provablyDisjoint(tp1.superType, tp2)
provablyDisjoint(gadtBounds(tp1.symbol).uncheckedNN.hi, tp2)
|| provablyDisjoint(tp1.superTypeNormalized, tp2)
case (_, tp2: NamedType) if gadtBounds(tp2.symbol) != null =>
provablyDisjoint(tp1, gadtBounds(tp2.symbol).uncheckedNN.hi) || provablyDisjoint(tp1, tp2.superType)
provablyDisjoint(tp1, gadtBounds(tp2.symbol).uncheckedNN.hi)
|| provablyDisjoint(tp1, tp2.superTypeNormalized)
case (tp1: TermRef, tp2: TermRef) if isEnumValueOrModule(tp1) && isEnumValueOrModule(tp2) =>
tp1.termSymbol != tp2.termSymbol
case (tp1: TermRef, tp2: TypeRef) if isEnumValue(tp1) =>
Expand All @@ -2663,11 +2665,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
case (tp1: Type, tp2: Type) if defn.isTupleNType(tp2) =>
provablyDisjoint(tp1, tp2.toNestedPairs)
case (tp1: TypeProxy, tp2: TypeProxy) =>
provablyDisjoint(tp1.superType, tp2) || provablyDisjoint(tp1, tp2.superType)
provablyDisjoint(tp1.superTypeNormalized, tp2) || provablyDisjoint(tp1, tp2.superTypeNormalized)
case (tp1: TypeProxy, _) =>
provablyDisjoint(tp1.superType, tp2)
provablyDisjoint(tp1.superTypeNormalized, tp2)
case (_, tp2: TypeProxy) =>
provablyDisjoint(tp1, tp2.superType)
provablyDisjoint(tp1, tp2.superTypeNormalized)
case _ =>
false
}
Expand Down
240 changes: 240 additions & 0 deletions compiler/src/dotty/tools/dotc/core/TypeEval.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
package dotty.tools
package dotc
package core

import Types.*, Contexts.*, Symbols.*, Constants.*, Decorators.*
import config.Printers.typr
import reporting.trace
import StdNames.tpnme

object TypeEval:

def tryCompiletimeConstantFold(tp: AppliedType)(using Context): Type = tp.tycon match
case tycon: TypeRef if defn.isCompiletimeAppliedType(tycon.symbol) =>
extension (tp: Type) def fixForEvaluation: Type =
tp.normalized.dealias match
// enable operations for constant singleton terms. E.g.:
// ```
// final val one = 1
// type Two = one.type + one.type
// ```
case tp: TypeProxy if tp.underlying.isStable => tp.underlying.fixForEvaluation
case tp => tp

def constValue(tp: Type): Option[Any] = tp.fixForEvaluation match
case ConstantType(Constant(n)) => Some(n)
case _ => None

def boolValue(tp: Type): Option[Boolean] = tp.fixForEvaluation match
case ConstantType(Constant(n: Boolean)) => Some(n)
case _ => None

def intValue(tp: Type): Option[Int] = tp.fixForEvaluation match
case ConstantType(Constant(n: Int)) => Some(n)
case _ => None

def longValue(tp: Type): Option[Long] = tp.fixForEvaluation match
case ConstantType(Constant(n: Long)) => Some(n)
case _ => None

def floatValue(tp: Type): Option[Float] = tp.fixForEvaluation match
case ConstantType(Constant(n: Float)) => Some(n)
case _ => None

def doubleValue(tp: Type): Option[Double] = tp.fixForEvaluation match
case ConstantType(Constant(n: Double)) => Some(n)
case _ => None

def stringValue(tp: Type): Option[String] = tp.fixForEvaluation match
case ConstantType(Constant(n: String)) => Some(n)
case _ => None

// Returns Some(true) if the type is a constant.
// Returns Some(false) if the type is not a constant.
// Returns None if there is not enough information to determine if the type is a constant.
// The type is a constant if it is a constant type or a type operation composition of constant types.
// If we get a type reference for an argument, then the result is not yet known.
def isConst(tp: Type): Option[Boolean] = tp.dealias match
// known to be constant
case ConstantType(_) => Some(true)
// currently not a concrete known type
case TypeRef(NoPrefix,_) => None
// currently not a concrete known type
case _: TypeParamRef => None
// constant if the term is constant
case t: TermRef => isConst(t.underlying)
// an operation type => recursively check all argument compositions
case applied: AppliedType if defn.isCompiletimeAppliedType(applied.typeSymbol) =>
val argsConst = applied.args.map(isConst)
if (argsConst.exists(_.isEmpty)) None
else Some(argsConst.forall(_.get))
// all other types are considered not to be constant
case _ => Some(false)

def expectArgsNum(expectedNum: Int): Unit =
// We can use assert instead of a compiler type error because this error should not
// occur since the type signature of the operation enforces the proper number of args.
assert(tp.args.length == expectedNum, s"Type operation expects $expectedNum arguments but found ${tp.args.length}")

def natValue(tp: Type): Option[Int] = intValue(tp).filter(n => n >= 0 && n < Int.MaxValue)

// Runs the op and returns the result as a constant type.
// If the op throws an exception, then this exception is converted into a type error.
def runConstantOp(op: => Any): Type =
val result =
try op
catch case e: Throwable =>
throw new TypeError(e.getMessage.nn)
ConstantType(Constant(result))

def constantFold1[T](extractor: Type => Option[T], op: T => Any): Option[Type] =
expectArgsNum(1)
extractor(tp.args.head).map(a => runConstantOp(op(a)))

def constantFold2[T](extractor: Type => Option[T], op: (T, T) => Any): Option[Type] =
constantFold2AB(extractor, extractor, op)

def constantFold2AB[TA, TB](extractorA: Type => Option[TA], extractorB: Type => Option[TB], op: (TA, TB) => Any): Option[Type] =
expectArgsNum(2)
for
a <- extractorA(tp.args(0))
b <- extractorB(tp.args(1))
yield runConstantOp(op(a, b))

def constantFold3[TA, TB, TC](
extractorA: Type => Option[TA],
extractorB: Type => Option[TB],
extractorC: Type => Option[TC],
op: (TA, TB, TC) => Any
): Option[Type] =
expectArgsNum(3)
for
a <- extractorA(tp.args(0))
b <- extractorB(tp.args(1))
c <- extractorC(tp.args(2))
yield runConstantOp(op(a, b, c))

trace(i"compiletime constant fold $tp", typr, show = true) {
val name = tycon.symbol.name
val owner = tycon.symbol.owner
val constantType =
if defn.isCompiletime_S(tycon.symbol) then
constantFold1(natValue, _ + 1)
else if owner == defn.CompiletimeOpsAnyModuleClass then name match
case tpnme.Equals => constantFold2(constValue, _ == _)
case tpnme.NotEquals => constantFold2(constValue, _ != _)
case tpnme.ToString => constantFold1(constValue, _.toString)
case tpnme.IsConst => isConst(tp.args.head).map(b => ConstantType(Constant(b)))
case _ => None
else if owner == defn.CompiletimeOpsIntModuleClass then name match
case tpnme.Abs => constantFold1(intValue, _.abs)
case tpnme.Negate => constantFold1(intValue, x => -x)
// ToString is deprecated for ops.int, and moved to ops.any
case tpnme.ToString => constantFold1(intValue, _.toString)
case tpnme.Plus => constantFold2(intValue, _ + _)
case tpnme.Minus => constantFold2(intValue, _ - _)
case tpnme.Times => constantFold2(intValue, _ * _)
case tpnme.Div => constantFold2(intValue, _ / _)
case tpnme.Mod => constantFold2(intValue, _ % _)
case tpnme.Lt => constantFold2(intValue, _ < _)
case tpnme.Gt => constantFold2(intValue, _ > _)
case tpnme.Ge => constantFold2(intValue, _ >= _)
case tpnme.Le => constantFold2(intValue, _ <= _)
case tpnme.Xor => constantFold2(intValue, _ ^ _)
case tpnme.BitwiseAnd => constantFold2(intValue, _ & _)
case tpnme.BitwiseOr => constantFold2(intValue, _ | _)
case tpnme.ASR => constantFold2(intValue, _ >> _)
case tpnme.LSL => constantFold2(intValue, _ << _)
case tpnme.LSR => constantFold2(intValue, _ >>> _)
case tpnme.Min => constantFold2(intValue, _ min _)
case tpnme.Max => constantFold2(intValue, _ max _)
case tpnme.NumberOfLeadingZeros => constantFold1(intValue, Integer.numberOfLeadingZeros(_))
case tpnme.ToLong => constantFold1(intValue, _.toLong)
case tpnme.ToFloat => constantFold1(intValue, _.toFloat)
case tpnme.ToDouble => constantFold1(intValue, _.toDouble)
case _ => None
else if owner == defn.CompiletimeOpsLongModuleClass then name match
case tpnme.Abs => constantFold1(longValue, _.abs)
case tpnme.Negate => constantFold1(longValue, x => -x)
case tpnme.Plus => constantFold2(longValue, _ + _)
case tpnme.Minus => constantFold2(longValue, _ - _)
case tpnme.Times => constantFold2(longValue, _ * _)
case tpnme.Div => constantFold2(longValue, _ / _)
case tpnme.Mod => constantFold2(longValue, _ % _)
case tpnme.Lt => constantFold2(longValue, _ < _)
case tpnme.Gt => constantFold2(longValue, _ > _)
case tpnme.Ge => constantFold2(longValue, _ >= _)
case tpnme.Le => constantFold2(longValue, _ <= _)
case tpnme.Xor => constantFold2(longValue, _ ^ _)
case tpnme.BitwiseAnd => constantFold2(longValue, _ & _)
case tpnme.BitwiseOr => constantFold2(longValue, _ | _)
case tpnme.ASR => constantFold2(longValue, _ >> _)
case tpnme.LSL => constantFold2(longValue, _ << _)
case tpnme.LSR => constantFold2(longValue, _ >>> _)
case tpnme.Min => constantFold2(longValue, _ min _)
case tpnme.Max => constantFold2(longValue, _ max _)
case tpnme.NumberOfLeadingZeros =>
constantFold1(longValue, java.lang.Long.numberOfLeadingZeros(_))
case tpnme.ToInt => constantFold1(longValue, _.toInt)
case tpnme.ToFloat => constantFold1(longValue, _.toFloat)
case tpnme.ToDouble => constantFold1(longValue, _.toDouble)
case _ => None
else if owner == defn.CompiletimeOpsFloatModuleClass then name match
case tpnme.Abs => constantFold1(floatValue, _.abs)
case tpnme.Negate => constantFold1(floatValue, x => -x)
case tpnme.Plus => constantFold2(floatValue, _ + _)
case tpnme.Minus => constantFold2(floatValue, _ - _)
case tpnme.Times => constantFold2(floatValue, _ * _)
case tpnme.Div => constantFold2(floatValue, _ / _)
case tpnme.Mod => constantFold2(floatValue, _ % _)
case tpnme.Lt => constantFold2(floatValue, _ < _)
case tpnme.Gt => constantFold2(floatValue, _ > _)
case tpnme.Ge => constantFold2(floatValue, _ >= _)
case tpnme.Le => constantFold2(floatValue, _ <= _)
case tpnme.Min => constantFold2(floatValue, _ min _)
case tpnme.Max => constantFold2(floatValue, _ max _)
case tpnme.ToInt => constantFold1(floatValue, _.toInt)
case tpnme.ToLong => constantFold1(floatValue, _.toLong)
case tpnme.ToDouble => constantFold1(floatValue, _.toDouble)
case _ => None
else if owner == defn.CompiletimeOpsDoubleModuleClass then name match
case tpnme.Abs => constantFold1(doubleValue, _.abs)
case tpnme.Negate => constantFold1(doubleValue, x => -x)
case tpnme.Plus => constantFold2(doubleValue, _ + _)
case tpnme.Minus => constantFold2(doubleValue, _ - _)
case tpnme.Times => constantFold2(doubleValue, _ * _)
case tpnme.Div => constantFold2(doubleValue, _ / _)
case tpnme.Mod => constantFold2(doubleValue, _ % _)
case tpnme.Lt => constantFold2(doubleValue, _ < _)
case tpnme.Gt => constantFold2(doubleValue, _ > _)
case tpnme.Ge => constantFold2(doubleValue, _ >= _)
case tpnme.Le => constantFold2(doubleValue, _ <= _)
case tpnme.Min => constantFold2(doubleValue, _ min _)
case tpnme.Max => constantFold2(doubleValue, _ max _)
case tpnme.ToInt => constantFold1(doubleValue, _.toInt)
case tpnme.ToLong => constantFold1(doubleValue, _.toLong)
case tpnme.ToFloat => constantFold1(doubleValue, _.toFloat)
case _ => None
else if owner == defn.CompiletimeOpsStringModuleClass then name match
case tpnme.Plus => constantFold2(stringValue, _ + _)
case tpnme.Length => constantFold1(stringValue, _.length)
case tpnme.Matches => constantFold2(stringValue, _ matches _)
case tpnme.Substring =>
constantFold3(stringValue, intValue, intValue, (s, b, e) => s.substring(b, e))
case tpnme.CharAt =>
constantFold2AB(stringValue, intValue, _ charAt _)
case _ => None
else if owner == defn.CompiletimeOpsBooleanModuleClass then name match
case tpnme.Not => constantFold1(boolValue, x => !x)
case tpnme.And => constantFold2(boolValue, _ && _)
case tpnme.Or => constantFold2(boolValue, _ || _)
case tpnme.Xor => constantFold2(boolValue, _ ^ _)
case _ => None
else None

constantType.getOrElse(NoType)
}

case _ => NoType
end tryCompiletimeConstantFold
end TypeEval
Loading

0 comments on commit 13d9359

Please sign in to comment.