diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 0a3b7ff130a8..b8f489c3eab4 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -4196,6 +4196,11 @@ object Types { case tycon: TypeRef if defn.isCompiletimeAppliedType(tycon.symbol) => extension (tp : Type) def fixForEvaluation : Type = tp.normalized.dealias match { + //enable operations for constant singleton terms. E.g.: + //``` + //final val one = 1 + //type Two = one.type + one.type + //``` case tp : TermRef => tp.underlying case tp => tp } @@ -4234,23 +4239,43 @@ object Types { case ConstantType(Constant(n: String)) => Some(n) case _ => None } - def isConst : Option[Type] = args.head.fixForEvaluation match { + + def isConst(tp : Type) : Option[Type] = tp.fixForEvaluation match { case ConstantType(_) => Some(ConstantType(Constant(true))) case _ => Some(ConstantType(Constant(false))) } + + def expectArgsNum(expectedNum : Int) : Unit = + //We can use assert instead of a compiler type error because this error should not + //occur since the type signature of the operation enforces the proper number of args. + assert(args.length == expectedNum, s"Type operation expects $expectedNum arguments but found ${args.length}") + def natValue(tp: Type): Option[Int] = intValue(tp).filter(n => n >= 0 && n < Int.MaxValue) + //Runs the op and returns the result as a constant type. + //If the op throws an exception, then this exception is converted into a type error. + def runConstantOp(op : => Any): Type = + val result = try { + op + } catch { + case e : Throwable => + throw new TypeError(e.getMessage) + } + ConstantType(Constant(result)) + def constantFold1[T](extractor: Type => Option[T], op: T => Any): Option[Type] = - extractor(args.head).map(a => ConstantType(Constant(op(a)))) + expectArgsNum(1) + extractor(args.head).map(a => runConstantOp(op(a))) def constantFold2[T](extractor: Type => Option[T], op: (T, T) => Any): Option[Type] = constantFold2AB(extractor, extractor, op) def constantFold2AB[TA, TB](extractorA: Type => Option[TA], extractorB: Type => Option[TB], op: (TA, TB) => Any): Option[Type] = + expectArgsNum(2) for { - a <- extractorA(args.head) - b <- extractorB(args.last) - } yield ConstantType(Constant(op(a, b))) + a <- extractorA(args(0)) + b <- extractorB(args(1)) + } yield runConstantOp(op(a, b)) def constantFold3[TA, TB, TC]( extractorA: Type => Option[TA], @@ -4258,139 +4283,126 @@ object Types { extractorC: Type => Option[TC], op: (TA, TB, TC) => Any ): Option[Type] = + expectArgsNum(3) for { - a <- extractorA(args.head) + a <- extractorA(args(0)) b <- extractorB(args(1)) - c <- extractorC(args.last) - } yield ConstantType(Constant(op(a, b, c))) + c <- extractorC(args(2)) + } yield runConstantOp(op(a, b, c)) trace(i"compiletime constant fold $this", typr, show = true) { val name = tycon.symbol.name val owner = tycon.symbol.owner - val nArgs = args.length val constantType = if (defn.isCompiletime_S(tycon.symbol)) { - if (nArgs == 1) constantFold1(natValue, _ + 1) - else None + constantFold1(natValue, _ + 1) } else if (owner == defn.CompiletimeOpsAnyModuleClass) name match { - case tpnme.Equals if nArgs == 2 => constantFold2(constValue, _ == _) - case tpnme.NotEquals if nArgs == 2 => constantFold2(constValue, _ != _) - case tpnme.ToString if nArgs == 1 => constantFold1(constValue, _.toString) - case tpnme.IsConst if nArgs == 1 => isConst + case tpnme.Equals => constantFold2(constValue, _ == _) + case tpnme.NotEquals => constantFold2(constValue, _ != _) + case tpnme.ToString => constantFold1(constValue, _.toString) + case tpnme.IsConst => isConst(args.head) case _ => None } else if (owner == defn.CompiletimeOpsIntModuleClass) name match { - case tpnme.Abs if nArgs == 1 => constantFold1(intValue, _.abs) - case tpnme.Negate if nArgs == 1 => constantFold1(intValue, x => -x) + case tpnme.Abs => constantFold1(intValue, _.abs) + case tpnme.Negate => constantFold1(intValue, x => -x) //ToString is deprecated for ops.int, and moved to ops.any - case tpnme.ToString if nArgs == 1 => constantFold1(intValue, _.toString) - case tpnme.Plus if nArgs == 2 => constantFold2(intValue, _ + _) - case tpnme.Minus if nArgs == 2 => constantFold2(intValue, _ - _) - case tpnme.Times if nArgs == 2 => constantFold2(intValue, _ * _) - case tpnme.Div if nArgs == 2 => constantFold2(intValue, { - case (_, 0) => throw new TypeError("Division by 0") - case (a, b) => a / b - }) - case tpnme.Mod if nArgs == 2 => constantFold2(intValue, { - case (_, 0) => throw new TypeError("Modulo by 0") - case (a, b) => a % b - }) - case tpnme.Lt if nArgs == 2 => constantFold2(intValue, _ < _) - case tpnme.Gt if nArgs == 2 => constantFold2(intValue, _ > _) - case tpnme.Ge if nArgs == 2 => constantFold2(intValue, _ >= _) - case tpnme.Le if nArgs == 2 => constantFold2(intValue, _ <= _) - case tpnme.Xor if nArgs == 2 => constantFold2(intValue, _ ^ _) - case tpnme.BitwiseAnd if nArgs == 2 => constantFold2(intValue, _ & _) - case tpnme.BitwiseOr if nArgs == 2 => constantFold2(intValue, _ | _) - case tpnme.ASR if nArgs == 2 => constantFold2(intValue, _ >> _) - case tpnme.LSL if nArgs == 2 => constantFold2(intValue, _ << _) - case tpnme.LSR if nArgs == 2 => constantFold2(intValue, _ >>> _) - case tpnme.Min if nArgs == 2 => constantFold2(intValue, _ min _) - case tpnme.Max if nArgs == 2 => constantFold2(intValue, _ max _) - case tpnme.NumberOfLeadingZeros if nArgs == 1 => constantFold1(intValue, Integer.numberOfLeadingZeros(_)) - case tpnme.ToLong if nArgs == 1 => constantFold1(intValue, _.toLong) - case tpnme.ToFloat if nArgs == 1 => constantFold1(intValue, _.toFloat) - case tpnme.ToDouble if nArgs == 1 => constantFold1(intValue, _.toDouble) + case tpnme.ToString => constantFold1(intValue, _.toString) + case tpnme.Plus => constantFold2(intValue, _ + _) + case tpnme.Minus => constantFold2(intValue, _ - _) + case tpnme.Times => constantFold2(intValue, _ * _) + case tpnme.Div => constantFold2(intValue, _ / _) + case tpnme.Mod => constantFold2(intValue, _ % _) + case tpnme.Lt => constantFold2(intValue, _ < _) + case tpnme.Gt => constantFold2(intValue, _ > _) + case tpnme.Ge => constantFold2(intValue, _ >= _) + case tpnme.Le => constantFold2(intValue, _ <= _) + case tpnme.Xor => constantFold2(intValue, _ ^ _) + case tpnme.BitwiseAnd => constantFold2(intValue, _ & _) + case tpnme.BitwiseOr => constantFold2(intValue, _ | _) + case tpnme.ASR => constantFold2(intValue, _ >> _) + case tpnme.LSL => constantFold2(intValue, _ << _) + case tpnme.LSR => constantFold2(intValue, _ >>> _) + case tpnme.Min => constantFold2(intValue, _ min _) + case tpnme.Max => constantFold2(intValue, _ max _) + case tpnme.NumberOfLeadingZeros => constantFold1(intValue, Integer.numberOfLeadingZeros(_)) + case tpnme.ToLong => constantFold1(intValue, _.toLong) + case tpnme.ToFloat => constantFold1(intValue, _.toFloat) + case tpnme.ToDouble => constantFold1(intValue, _.toDouble) case _ => None } else if (owner == defn.CompiletimeOpsLongModuleClass) name match { - case tpnme.Abs if nArgs == 1 => constantFold1(longValue, _.abs) - case tpnme.Negate if nArgs == 1 => constantFold1(longValue, x => -x) - case tpnme.Plus if nArgs == 2 => constantFold2(longValue, _ + _) - case tpnme.Minus if nArgs == 2 => constantFold2(longValue, _ - _) - case tpnme.Times if nArgs == 2 => constantFold2(longValue, _ * _) - case tpnme.Div if nArgs == 2 => constantFold2(longValue, { - case (_, 0L) => throw new TypeError("Division by 0") - case (a, b) => a / b - }) - case tpnme.Mod if nArgs == 2 => constantFold2(longValue, { - case (_, 0L) => throw new TypeError("Modulo by 0") - case (a, b) => a % b - }) - case tpnme.Lt if nArgs == 2 => constantFold2(longValue, _ < _) - case tpnme.Gt if nArgs == 2 => constantFold2(longValue, _ > _) - case tpnme.Ge if nArgs == 2 => constantFold2(longValue, _ >= _) - case tpnme.Le if nArgs == 2 => constantFold2(longValue, _ <= _) - case tpnme.Xor if nArgs == 2 => constantFold2(longValue, _ ^ _) - case tpnme.BitwiseAnd if nArgs == 2 => constantFold2(longValue, _ & _) - case tpnme.BitwiseOr if nArgs == 2 => constantFold2(longValue, _ | _) - case tpnme.ASR if nArgs == 2 => constantFold2(longValue, _ >> _) - case tpnme.LSL if nArgs == 2 => constantFold2(longValue, _ << _) - case tpnme.LSR if nArgs == 2 => constantFold2(longValue, _ >>> _) - case tpnme.Min if nArgs == 2 => constantFold2(longValue, _ min _) - case tpnme.Max if nArgs == 2 => constantFold2(longValue, _ max _) - case tpnme.NumberOfLeadingZeros if nArgs == 1 => + case tpnme.Abs => constantFold1(longValue, _.abs) + case tpnme.Negate => constantFold1(longValue, x => -x) + case tpnme.Plus => constantFold2(longValue, _ + _) + case tpnme.Minus => constantFold2(longValue, _ - _) + case tpnme.Times => constantFold2(longValue, _ * _) + case tpnme.Div => constantFold2(longValue, _ / _) + case tpnme.Mod => constantFold2(longValue, _ % _) + case tpnme.Lt => constantFold2(longValue, _ < _) + case tpnme.Gt => constantFold2(longValue, _ > _) + case tpnme.Ge => constantFold2(longValue, _ >= _) + case tpnme.Le => constantFold2(longValue, _ <= _) + case tpnme.Xor => constantFold2(longValue, _ ^ _) + case tpnme.BitwiseAnd => constantFold2(longValue, _ & _) + case tpnme.BitwiseOr => constantFold2(longValue, _ | _) + case tpnme.ASR => constantFold2(longValue, _ >> _) + case tpnme.LSL => constantFold2(longValue, _ << _) + case tpnme.LSR => constantFold2(longValue, _ >>> _) + case tpnme.Min => constantFold2(longValue, _ min _) + case tpnme.Max => constantFold2(longValue, _ max _) + case tpnme.NumberOfLeadingZeros => constantFold1(longValue, java.lang.Long.numberOfLeadingZeros(_)) - case tpnme.ToInt if nArgs == 1 => constantFold1(longValue, _.toInt) - case tpnme.ToFloat if nArgs == 1 => constantFold1(longValue, _.toFloat) - case tpnme.ToDouble if nArgs == 1 => constantFold1(longValue, _.toDouble) + case tpnme.ToInt => constantFold1(longValue, _.toInt) + case tpnme.ToFloat => constantFold1(longValue, _.toFloat) + case tpnme.ToDouble => constantFold1(longValue, _.toDouble) case _ => None } else if (owner == defn.CompiletimeOpsFloatModuleClass) name match { - case tpnme.Abs if nArgs == 1 => constantFold1(floatValue, _.abs) - case tpnme.Negate if nArgs == 1 => constantFold1(floatValue, x => -x) - case tpnme.Plus if nArgs == 2 => constantFold2(floatValue, _ + _) - case tpnme.Minus if nArgs == 2 => constantFold2(floatValue, _ - _) - case tpnme.Times if nArgs == 2 => constantFold2(floatValue, _ * _) - case tpnme.Div if nArgs == 2 => constantFold2(floatValue, _ / _) - case tpnme.Mod if nArgs == 2 => constantFold2(floatValue, _ % _) - case tpnme.Lt if nArgs == 2 => constantFold2(floatValue, _ < _) - case tpnme.Gt if nArgs == 2 => constantFold2(floatValue, _ > _) - case tpnme.Ge if nArgs == 2 => constantFold2(floatValue, _ >= _) - case tpnme.Le if nArgs == 2 => constantFold2(floatValue, _ <= _) - case tpnme.Min if nArgs == 2 => constantFold2(floatValue, _ min _) - case tpnme.Max if nArgs == 2 => constantFold2(floatValue, _ max _) - case tpnme.ToInt if nArgs == 1 => constantFold1(floatValue, _.toInt) - case tpnme.ToLong if nArgs == 1 => constantFold1(floatValue, _.toLong) - case tpnme.ToDouble if nArgs == 1 => constantFold1(floatValue, _.toDouble) + case tpnme.Abs => constantFold1(floatValue, _.abs) + case tpnme.Negate => constantFold1(floatValue, x => -x) + case tpnme.Plus => constantFold2(floatValue, _ + _) + case tpnme.Minus => constantFold2(floatValue, _ - _) + case tpnme.Times => constantFold2(floatValue, _ * _) + case tpnme.Div => constantFold2(floatValue, _ / _) + case tpnme.Mod => constantFold2(floatValue, _ % _) + case tpnme.Lt => constantFold2(floatValue, _ < _) + case tpnme.Gt => constantFold2(floatValue, _ > _) + case tpnme.Ge => constantFold2(floatValue, _ >= _) + case tpnme.Le => constantFold2(floatValue, _ <= _) + case tpnme.Min => constantFold2(floatValue, _ min _) + case tpnme.Max => constantFold2(floatValue, _ max _) + case tpnme.ToInt => constantFold1(floatValue, _.toInt) + case tpnme.ToLong => constantFold1(floatValue, _.toLong) + case tpnme.ToDouble => constantFold1(floatValue, _.toDouble) case _ => None } else if (owner == defn.CompiletimeOpsDoubleModuleClass) name match { - case tpnme.Abs if nArgs == 1 => constantFold1(doubleValue, _.abs) - case tpnme.Negate if nArgs == 1 => constantFold1(doubleValue, x => -x) - case tpnme.Plus if nArgs == 2 => constantFold2(doubleValue, _ + _) - case tpnme.Minus if nArgs == 2 => constantFold2(doubleValue, _ - _) - case tpnme.Times if nArgs == 2 => constantFold2(doubleValue, _ * _) - case tpnme.Div if nArgs == 2 => constantFold2(doubleValue, _ / _) - case tpnme.Mod if nArgs == 2 => constantFold2(doubleValue, _ % _) - case tpnme.Lt if nArgs == 2 => constantFold2(doubleValue, _ < _) - case tpnme.Gt if nArgs == 2 => constantFold2(doubleValue, _ > _) - case tpnme.Ge if nArgs == 2 => constantFold2(doubleValue, _ >= _) - case tpnme.Le if nArgs == 2 => constantFold2(doubleValue, _ <= _) - case tpnme.Min if nArgs == 2 => constantFold2(doubleValue, _ min _) - case tpnme.Max if nArgs == 2 => constantFold2(doubleValue, _ max _) - case tpnme.ToInt if nArgs == 1 => constantFold1(doubleValue, _.toInt) - case tpnme.ToLong if nArgs == 1 => constantFold1(doubleValue, _.toLong) - case tpnme.ToFloat if nArgs == 1 => constantFold1(doubleValue, _.toFloat) + case tpnme.Abs => constantFold1(doubleValue, _.abs) + case tpnme.Negate => constantFold1(doubleValue, x => -x) + case tpnme.Plus => constantFold2(doubleValue, _ + _) + case tpnme.Minus => constantFold2(doubleValue, _ - _) + case tpnme.Times => constantFold2(doubleValue, _ * _) + case tpnme.Div => constantFold2(doubleValue, _ / _) + case tpnme.Mod => constantFold2(doubleValue, _ % _) + case tpnme.Lt => constantFold2(doubleValue, _ < _) + case tpnme.Gt => constantFold2(doubleValue, _ > _) + case tpnme.Ge => constantFold2(doubleValue, _ >= _) + case tpnme.Le => constantFold2(doubleValue, _ <= _) + case tpnme.Min => constantFold2(doubleValue, _ min _) + case tpnme.Max => constantFold2(doubleValue, _ max _) + case tpnme.ToInt => constantFold1(doubleValue, _.toInt) + case tpnme.ToLong => constantFold1(doubleValue, _.toLong) + case tpnme.ToFloat => constantFold1(doubleValue, _.toFloat) case _ => None } else if (owner == defn.CompiletimeOpsStringModuleClass) name match { - case tpnme.Plus if nArgs == 2 => constantFold2(stringValue, _ + _) - case tpnme.Length if nArgs == 1 => constantFold1(stringValue, _.length) - case tpnme.Matches if nArgs == 2 => constantFold2(stringValue, _ matches _) - case tpnme.Substring if nArgs == 3 => + case tpnme.Plus => constantFold2(stringValue, _ + _) + case tpnme.Length => constantFold1(stringValue, _.length) + case tpnme.Matches => constantFold2(stringValue, _ matches _) + case tpnme.Substring => constantFold3(stringValue, intValue, intValue, (s, b, e) => s.substring(b, e)) case _ => None } else if (owner == defn.CompiletimeOpsBooleanModuleClass) name match { - case tpnme.Not if nArgs == 1 => constantFold1(boolValue, x => !x) - case tpnme.And if nArgs == 2 => constantFold2(boolValue, _ && _) - case tpnme.Or if nArgs == 2 => constantFold2(boolValue, _ || _) - case tpnme.Xor if nArgs == 2 => constantFold2(boolValue, _ ^ _) + case tpnme.Not => constantFold1(boolValue, x => !x) + case tpnme.And => constantFold2(boolValue, _ && _) + case tpnme.Or => constantFold2(boolValue, _ || _) + case tpnme.Xor => constantFold2(boolValue, _ ^ _) case _ => None } else None