Skip to content

Commit

Permalink
Adds compiletime.ops.{long, float, double}, adds other ops, and fixes…
Browse files Browse the repository at this point in the history
… termref type not being considered.

fix check file

more ops wip

Added ops.float and ops.double
  • Loading branch information
soronpo committed Aug 27, 2021
1 parent 7a6cabe commit 9570a88
Show file tree
Hide file tree
Showing 16 changed files with 1,099 additions and 39 deletions.
35 changes: 31 additions & 4 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,9 @@ class Definitions {
@tu lazy val CompiletimeOpsPackage: Symbol = requiredPackage("scala.compiletime.ops")
@tu lazy val CompiletimeOpsAnyModuleClass: Symbol = requiredModule("scala.compiletime.ops.any").moduleClass
@tu lazy val CompiletimeOpsIntModuleClass: Symbol = requiredModule("scala.compiletime.ops.int").moduleClass
@tu lazy val CompiletimeOpsLongModuleClass: Symbol = requiredModule("scala.compiletime.ops.long").moduleClass
@tu lazy val CompiletimeOpsFloatModuleClass: Symbol = requiredModule("scala.compiletime.ops.float").moduleClass
@tu lazy val CompiletimeOpsDoubleModuleClass: Symbol = requiredModule("scala.compiletime.ops.double").moduleClass
@tu lazy val CompiletimeOpsStringModuleClass: Symbol = requiredModule("scala.compiletime.ops.string").moduleClass
@tu lazy val CompiletimeOpsBooleanModuleClass: Symbol = requiredModule("scala.compiletime.ops.boolean").moduleClass

Expand Down Expand Up @@ -1071,19 +1074,40 @@ class Definitions {
final def isCompiletime_S(sym: Symbol)(using Context): Boolean =
sym.name == tpnme.S && sym.owner == CompiletimeOpsIntModuleClass

private val compiletimePackageAnyTypes: Set[Name] = Set(tpnme.Equals, tpnme.NotEquals)
private val compiletimePackageIntTypes: Set[Name] = Set(
private val compiletimePackageAnyTypes: Set[Name] = Set(
tpnme.Equals, tpnme.NotEquals, tpnme.IsConst, tpnme.ToString
)
private val compiletimePackageNumericTypes: Set[Name] = Set(
tpnme.Plus, tpnme.Minus, tpnme.Times, tpnme.Div, tpnme.Mod,
tpnme.Lt, tpnme.Gt, tpnme.Ge, tpnme.Le,
tpnme.Abs, tpnme.Negate, tpnme.Min, tpnme.Max, tpnme.ToString,
tpnme.Abs, tpnme.Negate, tpnme.Min, tpnme.Max
)
private val compiletimePackageIntTypes: Set[Name] = compiletimePackageNumericTypes ++ Set[Name](
tpnme.ToString, //ToString is moved to ops.any and deprecated for ops.int
tpnme.NumberOfLeadingZeros, tpnme.ToLong, tpnme.ToFloat, tpnme.ToDouble,
tpnme.Xor, tpnme.BitwiseAnd, tpnme.BitwiseOr, tpnme.ASR, tpnme.LSL, tpnme.LSR
)
private val compiletimePackageLongTypes: Set[Name] = compiletimePackageNumericTypes ++ Set[Name](
tpnme.NumberOfLeadingZeros, tpnme.ToInt, tpnme.ToFloat, tpnme.ToDouble,
tpnme.Xor, tpnme.BitwiseAnd, tpnme.BitwiseOr, tpnme.ASR, tpnme.LSL, tpnme.LSR
)
private val compiletimePackageFloatTypes: Set[Name] = compiletimePackageNumericTypes ++ Set[Name](
tpnme.ToInt, tpnme.ToLong, tpnme.ToDouble
)
private val compiletimePackageDoubleTypes: Set[Name] = compiletimePackageNumericTypes ++ Set[Name](
tpnme.ToInt, tpnme.ToLong, tpnme.ToFloat
)
private val compiletimePackageBooleanTypes: Set[Name] = Set(tpnme.Not, tpnme.Xor, tpnme.And, tpnme.Or)
private val compiletimePackageStringTypes: Set[Name] = Set(tpnme.Plus)
private val compiletimePackageStringTypes: Set[Name] = Set(
tpnme.Plus, tpnme.Length, tpnme.Substring, tpnme.Matches
)
private val compiletimePackageOpTypes: Set[Name] =
Set(tpnme.S)
++ compiletimePackageAnyTypes
++ compiletimePackageIntTypes
++ compiletimePackageLongTypes
++ compiletimePackageFloatTypes
++ compiletimePackageDoubleTypes
++ compiletimePackageBooleanTypes
++ compiletimePackageStringTypes

Expand All @@ -1093,6 +1117,9 @@ class Definitions {
isCompiletime_S(sym)
|| sym.owner == CompiletimeOpsAnyModuleClass && compiletimePackageAnyTypes.contains(sym.name)
|| sym.owner == CompiletimeOpsIntModuleClass && compiletimePackageIntTypes.contains(sym.name)
|| sym.owner == CompiletimeOpsLongModuleClass && compiletimePackageLongTypes.contains(sym.name)
|| sym.owner == CompiletimeOpsFloatModuleClass && compiletimePackageFloatTypes.contains(sym.name)
|| sym.owner == CompiletimeOpsDoubleModuleClass && compiletimePackageDoubleTypes.contains(sym.name)
|| sym.owner == CompiletimeOpsBooleanModuleClass && compiletimePackageBooleanTypes.contains(sym.name)
|| sym.owner == CompiletimeOpsStringModuleClass && compiletimePackageStringTypes.contains(sym.name)
)
Expand Down
55 changes: 32 additions & 23 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -211,29 +211,38 @@ object StdNames {
final val IOOBException: N = "IndexOutOfBoundsException"
final val FunctionXXL: N = "FunctionXXL"

final val Abs: N = "Abs"
final val And: N = "&&"
final val BitwiseAnd: N = "BitwiseAnd"
final val BitwiseOr: N = "BitwiseOr"
final val Div: N = "/"
final val Equals: N = "=="
final val Ge: N = ">="
final val Gt: N = ">"
final val Le: N = "<="
final val Lt: N = "<"
final val Max: N = "Max"
final val Min: N = "Min"
final val Minus: N = "-"
final val Mod: N = "%"
final val Negate: N = "Negate"
final val Not: N = "!"
final val NotEquals: N = "!="
final val Or: N = "||"
final val Plus: N = "+"
final val S: N = "S"
final val Times: N = "*"
final val ToString: N = "ToString"
final val Xor: N = "^"
final val Abs: N = "Abs"
final val And: N = "&&"
final val BitwiseAnd: N = "BitwiseAnd"
final val BitwiseOr: N = "BitwiseOr"
final val Div: N = "/"
final val Equals: N = "=="
final val Ge: N = ">="
final val Gt: N = ">"
final val IsConst: N = "IsConst"
final val Le: N = "<="
final val Length: N = "Length"
final val Lt: N = "<"
final val Matches: N = "Matches"
final val Max: N = "Max"
final val Min: N = "Min"
final val Minus: N = "-"
final val Mod: N = "%"
final val Negate: N = "Negate"
final val Not: N = "!"
final val NotEquals: N = "!="
final val NumberOfLeadingZeros: N = "NumberOfLeadingZeros"
final val Or: N = "||"
final val Plus: N = "+"
final val S: N = "S"
final val Substring: N = "Substring"
final val Times: N = "*"
final val ToInt: N = "ToInt"
final val ToLong: N = "ToLong"
final val ToFloat: N = "ToFloat"
final val ToDouble: N = "ToDouble"
final val ToString: N = "ToString"
final val Xor: N = "^"

final val ClassfileAnnotation: N = "ClassfileAnnotation"
final val ClassManifest: N = "ClassManifest"
Expand Down
134 changes: 126 additions & 8 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4174,37 +4174,76 @@ object Types {

def tryCompiletimeConstantFold(using Context): Type = tycon match {
case tycon: TypeRef if defn.isCompiletimeAppliedType(tycon.symbol) =>
def constValue(tp: Type): Option[Any] = tp.dealias match {
extension (tp : Type) def fixForEvaluation : Type =
tp.normalized.dealias match {
case tp : TermRef => tp.underlying
case tp => tp
}

def constValue(tp: Type): Option[Any] = tp.fixForEvaluation match {
case ConstantType(Constant(n)) => Some(n)
case _ => None
}

def boolValue(tp: Type): Option[Boolean] = tp.dealias match {
def boolValue(tp: Type): Option[Boolean] = tp.fixForEvaluation match {
case ConstantType(Constant(n: Boolean)) => Some(n)
case _ => None
}

def intValue(tp: Type): Option[Int] = tp.dealias match {
def intValue(tp: Type): Option[Int] = tp.fixForEvaluation match {
case ConstantType(Constant(n: Int)) => Some(n)
case _ => None
}

def stringValue(tp: Type): Option[String] = tp.dealias match {
case ConstantType(Constant(n: String)) => Some(n)
def longValue(tp: Type): Option[Long] = tp.fixForEvaluation match {
case ConstantType(Constant(n: Long)) => Some(n)
case _ => None
}

def floatValue(tp: Type): Option[Float] = tp.fixForEvaluation match {
case ConstantType(Constant(n: Float)) => Some(n)
case _ => None
}

def doubleValue(tp: Type): Option[Double] = tp.fixForEvaluation match {
case ConstantType(Constant(n: Double)) => Some(n)
case _ => None
}

def stringValue(tp: Type): Option[String] = tp.fixForEvaluation match {
case ConstantType(Constant(n: String)) => Some(n)
case _ => None
}
def isConst : Option[Type] = args.head.fixForEvaluation match {
case ConstantType(_) => Some(ConstantType(Constant(true)))
case _ => Some(ConstantType(Constant(false)))
}
def natValue(tp: Type): Option[Int] = intValue(tp).filter(n => n >= 0 && n < Int.MaxValue)

def constantFold1[T](extractor: Type => Option[T], op: T => Any): Option[Type] =
extractor(args.head.normalized).map(a => ConstantType(Constant(op(a))))
extractor(args.head).map(a => ConstantType(Constant(op(a))))

def constantFold2[T](extractor: Type => Option[T], op: (T, T) => Any): Option[Type] =
constantFold2AB(extractor, extractor, op)

def constantFold2AB[TA, TB](extractorA: Type => Option[TA], extractorB: Type => Option[TB], op: (TA, TB) => Any): Option[Type] =
for {
a <- extractor(args.head.normalized)
b <- extractor(args.tail.head.normalized)
a <- extractorA(args.head)
b <- extractorB(args.last)
} yield ConstantType(Constant(op(a, b)))

def constantFold3[TA, TB, TC](
extractorA: Type => Option[TA],
extractorB: Type => Option[TB],
extractorC: Type => Option[TC],
op: (TA, TB, TC) => Any
): Option[Type] =
for {
a <- extractorA(args.head)
b <- extractorB(args(1))
c <- extractorC(args.last)
} yield ConstantType(Constant(op(a, b, c)))

trace(i"compiletime constant fold $this", typr, show = true) {
val name = tycon.symbol.name
val owner = tycon.symbol.owner
Expand All @@ -4216,10 +4255,13 @@ object Types {
} else if (owner == defn.CompiletimeOpsAnyModuleClass) name match {
case tpnme.Equals if nArgs == 2 => constantFold2(constValue, _ == _)
case tpnme.NotEquals if nArgs == 2 => constantFold2(constValue, _ != _)
case tpnme.ToString if nArgs == 1 => constantFold1(constValue, _.toString)
case tpnme.IsConst if nArgs == 1 => isConst
case _ => None
} else if (owner == defn.CompiletimeOpsIntModuleClass) name match {
case tpnme.Abs if nArgs == 1 => constantFold1(intValue, _.abs)
case tpnme.Negate if nArgs == 1 => constantFold1(intValue, x => -x)
//ToString is deprecated for ops.int, and moved to ops.any
case tpnme.ToString if nArgs == 1 => constantFold1(intValue, _.toString)
case tpnme.Plus if nArgs == 2 => constantFold2(intValue, _ + _)
case tpnme.Minus if nArgs == 2 => constantFold2(intValue, _ - _)
Expand All @@ -4244,9 +4286,85 @@ object Types {
case tpnme.LSR if nArgs == 2 => constantFold2(intValue, _ >>> _)
case tpnme.Min if nArgs == 2 => constantFold2(intValue, _ min _)
case tpnme.Max if nArgs == 2 => constantFold2(intValue, _ max _)
case tpnme.NumberOfLeadingZeros if nArgs == 1 => constantFold1(intValue, Integer.numberOfLeadingZeros(_))
case tpnme.ToLong if nArgs == 1 => constantFold1(intValue, _.toLong)
case tpnme.ToFloat if nArgs == 1 => constantFold1(intValue, _.toFloat)
case tpnme.ToDouble if nArgs == 1 => constantFold1(intValue, _.toDouble)
case _ => None
} else if (owner == defn.CompiletimeOpsLongModuleClass) name match {
case tpnme.Abs if nArgs == 1 => constantFold1(longValue, _.abs)
case tpnme.Negate if nArgs == 1 => constantFold1(longValue, x => -x)
case tpnme.Plus if nArgs == 2 => constantFold2(longValue, _ + _)
case tpnme.Minus if nArgs == 2 => constantFold2(longValue, _ - _)
case tpnme.Times if nArgs == 2 => constantFold2(longValue, _ * _)
case tpnme.Div if nArgs == 2 => constantFold2(longValue, {
case (_, 0L) => throw new TypeError("Division by 0")
case (a, b) => a / b
})
case tpnme.Mod if nArgs == 2 => constantFold2(longValue, {
case (_, 0L) => throw new TypeError("Modulo by 0")
case (a, b) => a % b
})
case tpnme.Lt if nArgs == 2 => constantFold2(longValue, _ < _)
case tpnme.Gt if nArgs == 2 => constantFold2(longValue, _ > _)
case tpnme.Ge if nArgs == 2 => constantFold2(longValue, _ >= _)
case tpnme.Le if nArgs == 2 => constantFold2(longValue, _ <= _)
case tpnme.Xor if nArgs == 2 => constantFold2(longValue, _ ^ _)
case tpnme.BitwiseAnd if nArgs == 2 => constantFold2(longValue, _ & _)
case tpnme.BitwiseOr if nArgs == 2 => constantFold2(longValue, _ | _)
case tpnme.ASR if nArgs == 2 => constantFold2(longValue, _ >> _)
case tpnme.LSL if nArgs == 2 => constantFold2(longValue, _ << _)
case tpnme.LSR if nArgs == 2 => constantFold2(longValue, _ >>> _)
case tpnme.Min if nArgs == 2 => constantFold2(longValue, _ min _)
case tpnme.Max if nArgs == 2 => constantFold2(longValue, _ max _)
case tpnme.NumberOfLeadingZeros if nArgs == 1 =>
constantFold1(longValue, java.lang.Long.numberOfLeadingZeros(_))
case tpnme.ToInt if nArgs == 1 => constantFold1(longValue, _.toInt)
case tpnme.ToFloat if nArgs == 1 => constantFold1(longValue, _.toFloat)
case tpnme.ToDouble if nArgs == 1 => constantFold1(longValue, _.toDouble)
case _ => None
} else if (owner == defn.CompiletimeOpsFloatModuleClass) name match {
case tpnme.Abs if nArgs == 1 => constantFold1(floatValue, _.abs)
case tpnme.Negate if nArgs == 1 => constantFold1(floatValue, x => -x)
case tpnme.Plus if nArgs == 2 => constantFold2(floatValue, _ + _)
case tpnme.Minus if nArgs == 2 => constantFold2(floatValue, _ - _)
case tpnme.Times if nArgs == 2 => constantFold2(floatValue, _ * _)
case tpnme.Div if nArgs == 2 => constantFold2(floatValue, _ / _)
case tpnme.Mod if nArgs == 2 => constantFold2(floatValue, _ % _)
case tpnme.Lt if nArgs == 2 => constantFold2(floatValue, _ < _)
case tpnme.Gt if nArgs == 2 => constantFold2(floatValue, _ > _)
case tpnme.Ge if nArgs == 2 => constantFold2(floatValue, _ >= _)
case tpnme.Le if nArgs == 2 => constantFold2(floatValue, _ <= _)
case tpnme.Min if nArgs == 2 => constantFold2(floatValue, _ min _)
case tpnme.Max if nArgs == 2 => constantFold2(floatValue, _ max _)
case tpnme.ToInt if nArgs == 1 => constantFold1(floatValue, _.toInt)
case tpnme.ToLong if nArgs == 1 => constantFold1(floatValue, _.toLong)
case tpnme.ToDouble if nArgs == 1 => constantFold1(floatValue, _.toDouble)
case _ => None
} else if (owner == defn.CompiletimeOpsDoubleModuleClass) name match {
case tpnme.Abs if nArgs == 1 => constantFold1(doubleValue, _.abs)
case tpnme.Negate if nArgs == 1 => constantFold1(doubleValue, x => -x)
case tpnme.Plus if nArgs == 2 => constantFold2(doubleValue, _ + _)
case tpnme.Minus if nArgs == 2 => constantFold2(doubleValue, _ - _)
case tpnme.Times if nArgs == 2 => constantFold2(doubleValue, _ * _)
case tpnme.Div if nArgs == 2 => constantFold2(doubleValue, _ / _)
case tpnme.Mod if nArgs == 2 => constantFold2(doubleValue, _ % _)
case tpnme.Lt if nArgs == 2 => constantFold2(doubleValue, _ < _)
case tpnme.Gt if nArgs == 2 => constantFold2(doubleValue, _ > _)
case tpnme.Ge if nArgs == 2 => constantFold2(doubleValue, _ >= _)
case tpnme.Le if nArgs == 2 => constantFold2(doubleValue, _ <= _)
case tpnme.Min if nArgs == 2 => constantFold2(doubleValue, _ min _)
case tpnme.Max if nArgs == 2 => constantFold2(doubleValue, _ max _)
case tpnme.ToInt if nArgs == 1 => constantFold1(doubleValue, _.toInt)
case tpnme.ToLong if nArgs == 1 => constantFold1(doubleValue, _.toLong)
case tpnme.ToFloat if nArgs == 1 => constantFold1(doubleValue, _.toFloat)
case _ => None
} else if (owner == defn.CompiletimeOpsStringModuleClass) name match {
case tpnme.Plus if nArgs == 2 => constantFold2(stringValue, _ + _)
case tpnme.Length if nArgs == 1 => constantFold1(stringValue, _.length)
case tpnme.Matches if nArgs == 2 => constantFold2(stringValue, _ matches _)
case tpnme.Substring if nArgs == 3 =>
constantFold3(stringValue, intValue, intValue, (s, b, e) => s.substring(b, e))
case _ => None
} else if (owner == defn.CompiletimeOpsBooleanModuleClass) name match {
case tpnme.Not if nArgs == 1 => constantFold1(boolValue, x => !x)
Expand Down
19 changes: 19 additions & 0 deletions library/src/scala/compiletime/ops/any.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,22 @@ object any:
* @syntax markdown
*/
type !=[X, Y] <: Boolean

/** Tests if a type is a constant.
* ```scala
* val c1: IsConst[1] = true
* val c2: IsConst["hi"] = true
* val c3: IsConst[false] = true
* ```
* @syntax markdown
*/
type IsConst[X] <: Boolean

/** String conversion of a constant singleton type.
* ```scala
* val s1: ToString[1] = "1"
* val sTrue: ToString[true] = "true"
* ```
* @syntax markdown
*/
type ToString[X] <: String
Loading

0 comments on commit 9570a88

Please sign in to comment.