Skip to content

Commit

Permalink
Address review comments.
Browse files Browse the repository at this point in the history
Main changes:
* Apply common protection against wrong number of arguments.
* Exception in an operation is converted into a type error.
  • Loading branch information
soronpo committed Oct 26, 2021
1 parent 32043aa commit d433ce8
Showing 1 changed file with 126 additions and 114 deletions.
240 changes: 126 additions & 114 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4196,6 +4196,11 @@ object Types {
case tycon: TypeRef if defn.isCompiletimeAppliedType(tycon.symbol) =>
extension (tp : Type) def fixForEvaluation : Type =
tp.normalized.dealias match {
//enable operations for constant singleton terms. E.g.:
//```
//final val one = 1
//type Two = one.type + one.type
//```
case tp : TermRef => tp.underlying
case tp => tp
}
Expand Down Expand Up @@ -4234,163 +4239,170 @@ object Types {
case ConstantType(Constant(n: String)) => Some(n)
case _ => None
}
def isConst : Option[Type] = args.head.fixForEvaluation match {

def isConst(tp : Type) : Option[Type] = tp.fixForEvaluation match {
case ConstantType(_) => Some(ConstantType(Constant(true)))
case _ => Some(ConstantType(Constant(false)))
}

def expectArgsNum(expectedNum : Int) : Unit =
//We can use assert instead of a compiler type error because this error should not
//occur since the type signature of the operation enforces the proper number of args.
assert(args.length == expectedNum, s"Type operation expects $expectedNum arguments but found ${args.length}")

def natValue(tp: Type): Option[Int] = intValue(tp).filter(n => n >= 0 && n < Int.MaxValue)

//Runs the op and returns the result as a constant type.
//If the op throws an exception, then this exception is converted into a type error.
def runConstantOp(op : => Any): Type =
val result = try {
op
} catch {
case e : Throwable =>
throw new TypeError(e.getMessage)
}
ConstantType(Constant(result))

def constantFold1[T](extractor: Type => Option[T], op: T => Any): Option[Type] =
extractor(args.head).map(a => ConstantType(Constant(op(a))))
expectArgsNum(1)
extractor(args.head).map(a => runConstantOp(op(a)))

def constantFold2[T](extractor: Type => Option[T], op: (T, T) => Any): Option[Type] =
constantFold2AB(extractor, extractor, op)

def constantFold2AB[TA, TB](extractorA: Type => Option[TA], extractorB: Type => Option[TB], op: (TA, TB) => Any): Option[Type] =
expectArgsNum(2)
for {
a <- extractorA(args.head)
b <- extractorB(args.last)
} yield ConstantType(Constant(op(a, b)))
a <- extractorA(args(0))
b <- extractorB(args(1))
} yield runConstantOp(op(a, b))

def constantFold3[TA, TB, TC](
extractorA: Type => Option[TA],
extractorB: Type => Option[TB],
extractorC: Type => Option[TC],
op: (TA, TB, TC) => Any
): Option[Type] =
expectArgsNum(3)
for {
a <- extractorA(args.head)
a <- extractorA(args(0))
b <- extractorB(args(1))
c <- extractorC(args.last)
} yield ConstantType(Constant(op(a, b, c)))
c <- extractorC(args(2))
} yield runConstantOp(op(a, b, c))

trace(i"compiletime constant fold $this", typr, show = true) {
val name = tycon.symbol.name
val owner = tycon.symbol.owner
val nArgs = args.length
val constantType =
if (defn.isCompiletime_S(tycon.symbol)) {
if (nArgs == 1) constantFold1(natValue, _ + 1)
else None
constantFold1(natValue, _ + 1)
} else if (owner == defn.CompiletimeOpsAnyModuleClass) name match {
case tpnme.Equals if nArgs == 2 => constantFold2(constValue, _ == _)
case tpnme.NotEquals if nArgs == 2 => constantFold2(constValue, _ != _)
case tpnme.ToString if nArgs == 1 => constantFold1(constValue, _.toString)
case tpnme.IsConst if nArgs == 1 => isConst
case tpnme.Equals => constantFold2(constValue, _ == _)
case tpnme.NotEquals => constantFold2(constValue, _ != _)
case tpnme.ToString => constantFold1(constValue, _.toString)
case tpnme.IsConst => isConst(args.head)
case _ => None
} else if (owner == defn.CompiletimeOpsIntModuleClass) name match {
case tpnme.Abs if nArgs == 1 => constantFold1(intValue, _.abs)
case tpnme.Negate if nArgs == 1 => constantFold1(intValue, x => -x)
case tpnme.Abs => constantFold1(intValue, _.abs)
case tpnme.Negate => constantFold1(intValue, x => -x)
//ToString is deprecated for ops.int, and moved to ops.any
case tpnme.ToString if nArgs == 1 => constantFold1(intValue, _.toString)
case tpnme.Plus if nArgs == 2 => constantFold2(intValue, _ + _)
case tpnme.Minus if nArgs == 2 => constantFold2(intValue, _ - _)
case tpnme.Times if nArgs == 2 => constantFold2(intValue, _ * _)
case tpnme.Div if nArgs == 2 => constantFold2(intValue, {
case (_, 0) => throw new TypeError("Division by 0")
case (a, b) => a / b
})
case tpnme.Mod if nArgs == 2 => constantFold2(intValue, {
case (_, 0) => throw new TypeError("Modulo by 0")
case (a, b) => a % b
})
case tpnme.Lt if nArgs == 2 => constantFold2(intValue, _ < _)
case tpnme.Gt if nArgs == 2 => constantFold2(intValue, _ > _)
case tpnme.Ge if nArgs == 2 => constantFold2(intValue, _ >= _)
case tpnme.Le if nArgs == 2 => constantFold2(intValue, _ <= _)
case tpnme.Xor if nArgs == 2 => constantFold2(intValue, _ ^ _)
case tpnme.BitwiseAnd if nArgs == 2 => constantFold2(intValue, _ & _)
case tpnme.BitwiseOr if nArgs == 2 => constantFold2(intValue, _ | _)
case tpnme.ASR if nArgs == 2 => constantFold2(intValue, _ >> _)
case tpnme.LSL if nArgs == 2 => constantFold2(intValue, _ << _)
case tpnme.LSR if nArgs == 2 => constantFold2(intValue, _ >>> _)
case tpnme.Min if nArgs == 2 => constantFold2(intValue, _ min _)
case tpnme.Max if nArgs == 2 => constantFold2(intValue, _ max _)
case tpnme.NumberOfLeadingZeros if nArgs == 1 => constantFold1(intValue, Integer.numberOfLeadingZeros(_))
case tpnme.ToLong if nArgs == 1 => constantFold1(intValue, _.toLong)
case tpnme.ToFloat if nArgs == 1 => constantFold1(intValue, _.toFloat)
case tpnme.ToDouble if nArgs == 1 => constantFold1(intValue, _.toDouble)
case tpnme.ToString => constantFold1(intValue, _.toString)
case tpnme.Plus => constantFold2(intValue, _ + _)
case tpnme.Minus => constantFold2(intValue, _ - _)
case tpnme.Times => constantFold2(intValue, _ * _)
case tpnme.Div => constantFold2(intValue, _ / _)
case tpnme.Mod => constantFold2(intValue, _ % _)
case tpnme.Lt => constantFold2(intValue, _ < _)
case tpnme.Gt => constantFold2(intValue, _ > _)
case tpnme.Ge => constantFold2(intValue, _ >= _)
case tpnme.Le => constantFold2(intValue, _ <= _)
case tpnme.Xor => constantFold2(intValue, _ ^ _)
case tpnme.BitwiseAnd => constantFold2(intValue, _ & _)
case tpnme.BitwiseOr => constantFold2(intValue, _ | _)
case tpnme.ASR => constantFold2(intValue, _ >> _)
case tpnme.LSL => constantFold2(intValue, _ << _)
case tpnme.LSR => constantFold2(intValue, _ >>> _)
case tpnme.Min => constantFold2(intValue, _ min _)
case tpnme.Max => constantFold2(intValue, _ max _)
case tpnme.NumberOfLeadingZeros => constantFold1(intValue, Integer.numberOfLeadingZeros(_))
case tpnme.ToLong => constantFold1(intValue, _.toLong)
case tpnme.ToFloat => constantFold1(intValue, _.toFloat)
case tpnme.ToDouble => constantFold1(intValue, _.toDouble)
case _ => None
} else if (owner == defn.CompiletimeOpsLongModuleClass) name match {
case tpnme.Abs if nArgs == 1 => constantFold1(longValue, _.abs)
case tpnme.Negate if nArgs == 1 => constantFold1(longValue, x => -x)
case tpnme.Plus if nArgs == 2 => constantFold2(longValue, _ + _)
case tpnme.Minus if nArgs == 2 => constantFold2(longValue, _ - _)
case tpnme.Times if nArgs == 2 => constantFold2(longValue, _ * _)
case tpnme.Div if nArgs == 2 => constantFold2(longValue, {
case (_, 0L) => throw new TypeError("Division by 0")
case (a, b) => a / b
})
case tpnme.Mod if nArgs == 2 => constantFold2(longValue, {
case (_, 0L) => throw new TypeError("Modulo by 0")
case (a, b) => a % b
})
case tpnme.Lt if nArgs == 2 => constantFold2(longValue, _ < _)
case tpnme.Gt if nArgs == 2 => constantFold2(longValue, _ > _)
case tpnme.Ge if nArgs == 2 => constantFold2(longValue, _ >= _)
case tpnme.Le if nArgs == 2 => constantFold2(longValue, _ <= _)
case tpnme.Xor if nArgs == 2 => constantFold2(longValue, _ ^ _)
case tpnme.BitwiseAnd if nArgs == 2 => constantFold2(longValue, _ & _)
case tpnme.BitwiseOr if nArgs == 2 => constantFold2(longValue, _ | _)
case tpnme.ASR if nArgs == 2 => constantFold2(longValue, _ >> _)
case tpnme.LSL if nArgs == 2 => constantFold2(longValue, _ << _)
case tpnme.LSR if nArgs == 2 => constantFold2(longValue, _ >>> _)
case tpnme.Min if nArgs == 2 => constantFold2(longValue, _ min _)
case tpnme.Max if nArgs == 2 => constantFold2(longValue, _ max _)
case tpnme.NumberOfLeadingZeros if nArgs == 1 =>
case tpnme.Abs => constantFold1(longValue, _.abs)
case tpnme.Negate => constantFold1(longValue, x => -x)
case tpnme.Plus => constantFold2(longValue, _ + _)
case tpnme.Minus => constantFold2(longValue, _ - _)
case tpnme.Times => constantFold2(longValue, _ * _)
case tpnme.Div => constantFold2(longValue, _ / _)
case tpnme.Mod => constantFold2(longValue, _ % _)
case tpnme.Lt => constantFold2(longValue, _ < _)
case tpnme.Gt => constantFold2(longValue, _ > _)
case tpnme.Ge => constantFold2(longValue, _ >= _)
case tpnme.Le => constantFold2(longValue, _ <= _)
case tpnme.Xor => constantFold2(longValue, _ ^ _)
case tpnme.BitwiseAnd => constantFold2(longValue, _ & _)
case tpnme.BitwiseOr => constantFold2(longValue, _ | _)
case tpnme.ASR => constantFold2(longValue, _ >> _)
case tpnme.LSL => constantFold2(longValue, _ << _)
case tpnme.LSR => constantFold2(longValue, _ >>> _)
case tpnme.Min => constantFold2(longValue, _ min _)
case tpnme.Max => constantFold2(longValue, _ max _)
case tpnme.NumberOfLeadingZeros =>
constantFold1(longValue, java.lang.Long.numberOfLeadingZeros(_))
case tpnme.ToInt if nArgs == 1 => constantFold1(longValue, _.toInt)
case tpnme.ToFloat if nArgs == 1 => constantFold1(longValue, _.toFloat)
case tpnme.ToDouble if nArgs == 1 => constantFold1(longValue, _.toDouble)
case tpnme.ToInt => constantFold1(longValue, _.toInt)
case tpnme.ToFloat => constantFold1(longValue, _.toFloat)
case tpnme.ToDouble => constantFold1(longValue, _.toDouble)
case _ => None
} else if (owner == defn.CompiletimeOpsFloatModuleClass) name match {
case tpnme.Abs if nArgs == 1 => constantFold1(floatValue, _.abs)
case tpnme.Negate if nArgs == 1 => constantFold1(floatValue, x => -x)
case tpnme.Plus if nArgs == 2 => constantFold2(floatValue, _ + _)
case tpnme.Minus if nArgs == 2 => constantFold2(floatValue, _ - _)
case tpnme.Times if nArgs == 2 => constantFold2(floatValue, _ * _)
case tpnme.Div if nArgs == 2 => constantFold2(floatValue, _ / _)
case tpnme.Mod if nArgs == 2 => constantFold2(floatValue, _ % _)
case tpnme.Lt if nArgs == 2 => constantFold2(floatValue, _ < _)
case tpnme.Gt if nArgs == 2 => constantFold2(floatValue, _ > _)
case tpnme.Ge if nArgs == 2 => constantFold2(floatValue, _ >= _)
case tpnme.Le if nArgs == 2 => constantFold2(floatValue, _ <= _)
case tpnme.Min if nArgs == 2 => constantFold2(floatValue, _ min _)
case tpnme.Max if nArgs == 2 => constantFold2(floatValue, _ max _)
case tpnme.ToInt if nArgs == 1 => constantFold1(floatValue, _.toInt)
case tpnme.ToLong if nArgs == 1 => constantFold1(floatValue, _.toLong)
case tpnme.ToDouble if nArgs == 1 => constantFold1(floatValue, _.toDouble)
case tpnme.Abs => constantFold1(floatValue, _.abs)
case tpnme.Negate => constantFold1(floatValue, x => -x)
case tpnme.Plus => constantFold2(floatValue, _ + _)
case tpnme.Minus => constantFold2(floatValue, _ - _)
case tpnme.Times => constantFold2(floatValue, _ * _)
case tpnme.Div => constantFold2(floatValue, _ / _)
case tpnme.Mod => constantFold2(floatValue, _ % _)
case tpnme.Lt => constantFold2(floatValue, _ < _)
case tpnme.Gt => constantFold2(floatValue, _ > _)
case tpnme.Ge => constantFold2(floatValue, _ >= _)
case tpnme.Le => constantFold2(floatValue, _ <= _)
case tpnme.Min => constantFold2(floatValue, _ min _)
case tpnme.Max => constantFold2(floatValue, _ max _)
case tpnme.ToInt => constantFold1(floatValue, _.toInt)
case tpnme.ToLong => constantFold1(floatValue, _.toLong)
case tpnme.ToDouble => constantFold1(floatValue, _.toDouble)
case _ => None
} else if (owner == defn.CompiletimeOpsDoubleModuleClass) name match {
case tpnme.Abs if nArgs == 1 => constantFold1(doubleValue, _.abs)
case tpnme.Negate if nArgs == 1 => constantFold1(doubleValue, x => -x)
case tpnme.Plus if nArgs == 2 => constantFold2(doubleValue, _ + _)
case tpnme.Minus if nArgs == 2 => constantFold2(doubleValue, _ - _)
case tpnme.Times if nArgs == 2 => constantFold2(doubleValue, _ * _)
case tpnme.Div if nArgs == 2 => constantFold2(doubleValue, _ / _)
case tpnme.Mod if nArgs == 2 => constantFold2(doubleValue, _ % _)
case tpnme.Lt if nArgs == 2 => constantFold2(doubleValue, _ < _)
case tpnme.Gt if nArgs == 2 => constantFold2(doubleValue, _ > _)
case tpnme.Ge if nArgs == 2 => constantFold2(doubleValue, _ >= _)
case tpnme.Le if nArgs == 2 => constantFold2(doubleValue, _ <= _)
case tpnme.Min if nArgs == 2 => constantFold2(doubleValue, _ min _)
case tpnme.Max if nArgs == 2 => constantFold2(doubleValue, _ max _)
case tpnme.ToInt if nArgs == 1 => constantFold1(doubleValue, _.toInt)
case tpnme.ToLong if nArgs == 1 => constantFold1(doubleValue, _.toLong)
case tpnme.ToFloat if nArgs == 1 => constantFold1(doubleValue, _.toFloat)
case tpnme.Abs => constantFold1(doubleValue, _.abs)
case tpnme.Negate => constantFold1(doubleValue, x => -x)
case tpnme.Plus => constantFold2(doubleValue, _ + _)
case tpnme.Minus => constantFold2(doubleValue, _ - _)
case tpnme.Times => constantFold2(doubleValue, _ * _)
case tpnme.Div => constantFold2(doubleValue, _ / _)
case tpnme.Mod => constantFold2(doubleValue, _ % _)
case tpnme.Lt => constantFold2(doubleValue, _ < _)
case tpnme.Gt => constantFold2(doubleValue, _ > _)
case tpnme.Ge => constantFold2(doubleValue, _ >= _)
case tpnme.Le => constantFold2(doubleValue, _ <= _)
case tpnme.Min => constantFold2(doubleValue, _ min _)
case tpnme.Max => constantFold2(doubleValue, _ max _)
case tpnme.ToInt => constantFold1(doubleValue, _.toInt)
case tpnme.ToLong => constantFold1(doubleValue, _.toLong)
case tpnme.ToFloat => constantFold1(doubleValue, _.toFloat)
case _ => None
} else if (owner == defn.CompiletimeOpsStringModuleClass) name match {
case tpnme.Plus if nArgs == 2 => constantFold2(stringValue, _ + _)
case tpnme.Length if nArgs == 1 => constantFold1(stringValue, _.length)
case tpnme.Matches if nArgs == 2 => constantFold2(stringValue, _ matches _)
case tpnme.Substring if nArgs == 3 =>
case tpnme.Plus => constantFold2(stringValue, _ + _)
case tpnme.Length => constantFold1(stringValue, _.length)
case tpnme.Matches => constantFold2(stringValue, _ matches _)
case tpnme.Substring =>
constantFold3(stringValue, intValue, intValue, (s, b, e) => s.substring(b, e))
case _ => None
} else if (owner == defn.CompiletimeOpsBooleanModuleClass) name match {
case tpnme.Not if nArgs == 1 => constantFold1(boolValue, x => !x)
case tpnme.And if nArgs == 2 => constantFold2(boolValue, _ && _)
case tpnme.Or if nArgs == 2 => constantFold2(boolValue, _ || _)
case tpnme.Xor if nArgs == 2 => constantFold2(boolValue, _ ^ _)
case tpnme.Not => constantFold1(boolValue, x => !x)
case tpnme.And => constantFold2(boolValue, _ && _)
case tpnme.Or => constantFold2(boolValue, _ || _)
case tpnme.Xor => constantFold2(boolValue, _ ^ _)
case _ => None
} else None

Expand Down

0 comments on commit d433ce8

Please sign in to comment.