diff --git a/build.sbt b/build.sbt index 1ce7d97..48c86be 100644 --- a/build.sbt +++ b/build.sbt @@ -68,6 +68,16 @@ lazy val zioQuill = (project in file("modules/neotype-zio-quill")) ) .dependsOn(core) +lazy val zio = (project in file("modules/neotype-zio")) + .settings( + name := "neotype-zio", + sharedSettings, + libraryDependencies ++= Seq( + "dev.zio" %% "zio" % zioVersion + ) + ) + .dependsOn(core) + lazy val examples = (project in file("examples")) .settings( name := "neotype-examples", diff --git a/examples/src/main/scala/neotype/examples/Main.scala b/examples/src/main/scala/neotype/examples/Main.scala index 18a5716..4df7e58 100644 --- a/examples/src/main/scala/neotype/examples/Main.scala +++ b/examples/src/main/scala/neotype/examples/Main.scala @@ -3,6 +3,7 @@ package neotype.examples import neotype.* object Main extends App: + Email("kit@gmail.com") // OK FourSeasons("Spring") // OK NonEmptyString("Good") // OK FiveElements("REALLY LONG ELEMENT") // OK diff --git a/examples/src/main/scala/neotype/examples/Newtypes.scala b/examples/src/main/scala/neotype/examples/Newtypes.scala index b35b0c5..0a74ff8 100644 --- a/examples/src/main/scala/neotype/examples/Newtypes.scala +++ b/examples/src/main/scala/neotype/examples/Newtypes.scala @@ -2,15 +2,23 @@ package neotype.examples import neotype.* +type NonEmptyString = NonEmptyString.Type given NonEmptyString: Newtype[String] with inline def validate(value: String): Boolean = - value.nonEmpty + value.reverse.nonEmpty +type Email = Email.Type +given Email: Newtype[String] with + inline def validate(value: String): Boolean = + value.contains("@") && value.contains(".") + +type FourSeasons = FourSeasons.Type given FourSeasons: Newtype[String] with inline def validate(value: String): Boolean = val seasons = Set("Spring", "Summer", "Autumn", "Winter") seasons.contains(value) +type FiveElements = FiveElements.Type given FiveElements: Newtype[String] with inline def validate(value: String): Boolean = value match diff --git a/modules/core/src/main/scala/neotype/Calc.scala b/modules/core/src/main/scala/neotype/Calc.scala index 3f7f3f8..b3c1f63 100644 --- a/modules/core/src/main/scala/neotype/Calc.scala +++ b/modules/core/src/main/scala/neotype/Calc.scala @@ -4,97 +4,44 @@ import scala.annotation.tailrec import StringFormatting.* import scala.quoted.* import scala.util.matching.Regex +import CustomFromExpr.given enum Calc[A]: case Constant(value: A) // Comparisons - case BinaryOp[A, B](lhs: Calc[A], rhs: Calc[B], op: (A, B) => Boolean, name: String) extends Calc[Boolean] - - // Boolean Algebra - case And(lhs: Calc[Boolean], rhs: Calc[Boolean]) extends Calc[Boolean] - case Or(lhs: Calc[Boolean], rhs: Calc[Boolean]) extends Calc[Boolean] - case Not(calc: Calc[Boolean]) extends Calc[Boolean] - - // Numeric Operations - case Add[Num](lhs: Calc[Num], rhs: Calc[Num])(using val numeric: Numeric[Num]) extends Calc[Num] - case ToDouble[Num](num: Calc[Num])(using val numeric: Numeric[Num]) extends Calc[Double] - - // String Operations - case Length(str: Calc[String]) extends Calc[Int] - case Substring(str: Calc[String], start: Calc[Int], end: Calc[Int]) extends Calc[String] - case ToUpper(str: Calc[String]) extends Calc[String] - case ToLower(str: Calc[String]) extends Calc[String] - case StartsWith(str: Calc[String], prefix: Calc[String]) extends Calc[Boolean] - case EndsWith(str: Calc[String], suffix: Calc[String]) extends Calc[Boolean] - case Contains(str: Calc[String], substr: Calc[String]) extends Calc[Boolean] - case MatchesRegex(str: Calc[String], regex: Calc[String]) extends Calc[Boolean] - case RegexMatches(regex: Calc[Regex], str: Calc[String]) extends Calc[Boolean] - case StringNonEmpty(str: Calc[String]) extends Calc[Boolean] - case StringIsEmpty(str: Calc[String]) extends Calc[Boolean] - case StringApply(str: Calc[String], index: Calc[Int]) extends Calc[Char] - case StringTrim(str: Calc[String]) extends Calc[String] - - // Set Operations - case SetContains[A](set: Calc[Set[A]], elem: Calc[A]) extends Calc[Boolean] + case BinaryOp[A, B, C](lhs: Calc[A], rhs: Calc[B], op: (A, B) => C, render: (String, String) => String) + extends Calc[C] + case UnaryOp[A, B](calc: Calc[A], op: A => B, render: String => String) extends Calc[B] + case TernaryOp[A, B, C, D]( + cond: Calc[A], + lhs: Calc[B], + rhs: Calc[C], + op: (A, B, C) => D, + render: (String, String, String) => String + ) extends Calc[D] // Custom - case WithMessage(calc: Calc[A], message: String) extends Calc[A] - case Block(defs: List[CalcDef[?]], calc: Calc[A]) extends Calc[A] - case Reference[A](name: String) extends Calc[A] + case Block(defs: List[CalcValDef[?]], calc: Calc[A]) extends Calc[A] + case Reference[A](name: String) extends Calc[A] case MatchExpr[A](expr: Calc[A], cases: List[CalcMatchCase[A]]) extends Calc[A] - def renderConstant(value: Any): String = - value match - case s: String => s""""$s"""".green - case c: Char => s"'$c'".green - case set: Set[?] => - val elems = set.map(renderConstant).mkString(", ".reset) - s"Set(".reset + elems + ")".reset - case regex: Regex => - s"""${renderConstant(regex.toString)}.r""" - case _ => value.toString.cyan - def render(using ctx: Map[String, String]): String = this match - case Constant(value) => renderConstant(value) + case Constant(value) => Calc.renderConstant(value) // Comparisons - case BinaryOp(lhs, rhs, _, name) => s"${lhs.render} $name ${rhs.render}" - - // Boolean Operations - case And(lhs, rhs) => s"${lhs.render} && ${rhs.render}" - case Or(lhs, rhs) => s"${lhs.render} || ${rhs.render}" - case Not(calc) => s"!${calc.render}" - - case Add(lhs, rhs) => s"${lhs.render} + ${rhs.render}" - case ToDouble(num) => s"toDouble(${num.render})" - - // String Operations - case Length(str) => s"${str.render}.length" - case Substring(str, start, end) => s"${str.render}.substring(${start.render}, ${end.render})" - case ToUpper(str) => s"${str.render}.toUpperCase" - case ToLower(str) => s"${str.render}.toLowerCase" - case MatchesRegex(str, regex) => s"${str.render}.matches(${regex.render})" - case StartsWith(str, prefix) => s"${str.render}.startsWith(${prefix.render})" - case EndsWith(str, suffix) => s"${str.render}.endsWith(${suffix.render})" - case Contains(str, substr) => s"${str.render}.contains(${substr.render})" - case RegexMatches(regex, str) => s"${regex.render}.matches(${str.render})" - case StringNonEmpty(str) => s"${str.render}.nonEmpty" - case StringIsEmpty(str) => s"${str.render}.isEmpty" - case StringApply(str, index) => s"${str.render}(${index.render})" - case StringTrim(str) => s"${str.render}.trim" - - // Set Operations - case SetContains(set, elem) => s"${set.render}.contains(${elem.render})" + case UnaryOp(calc, _, show) => show(calc.render) + case BinaryOp(lhs, rhs, _, show) => show(lhs.render, rhs.render) + case TernaryOp(cond, lhs, rhs, _, show) => show(cond.render, lhs.render, rhs.render) - case WithMessage(calc, message) => s"${calc.render} // $message" case Block(defs, calc) => val newCtx = defs.foldLeft(ctx) { (ctx, defn) => ctx + (defn.name -> defn.calc.render(using ctx)) } calc.render(using newCtx) + case Reference(name) => ctx(name) case MatchExpr(expr, cases) => @@ -107,42 +54,11 @@ enum Calc[A]: def result(using context: Map[String, Any], q: Quotes): A = this match - case Constant(value) => value - - // Comparisons - case BinaryOp(lhs, rhs, op, name) => - val cool = op(lhs.result, rhs.result).asInstanceOf[A] - cool - - // Boolean Operations - case And(lhs, rhs) => (lhs.result && rhs.result).asInstanceOf[A] - case Or(lhs, rhs) => (lhs.result || rhs.result).asInstanceOf[A] - case Not(calc) => (!calc.result).asInstanceOf[A] - - case add @ Add(lhs, rhs): Add[num] => - import add.numeric - summon[Numeric[num]].plus(lhs.result, rhs.result).asInstanceOf[A] - case calc @ ToDouble(num): ToDouble[num] => - import calc.numeric - summon[Numeric[num]].toDouble(num.result).asInstanceOf[A] - - // String Operations - case Length(str) => str.result.length - case Substring(str, start, end) => str.result.substring(start.result, end.result) - case ToUpper(str) => str.result.toUpperCase - case ToLower(str) => str.result.toLowerCase - case MatchesRegex(str, regex) => str.result.matches(regex.result) - case StartsWith(str, prefix) => str.result.startsWith(prefix.result) - case EndsWith(str, suffix) => str.result.endsWith(suffix.result) - case Contains(str, substr) => str.result.contains(substr.result) - case RegexMatches(regex, str) => regex.result.matches(str.result) - case StringNonEmpty(str) => str.result.nonEmpty - case StringIsEmpty(str) => str.result.isEmpty - case StringApply(str, index) => str.result(index.result) - case StringTrim(str) => str.result.trim - - // Set Operations - case SetContains(set, elem) => set.result.contains(elem.result) + case Constant(value) => value + case Reference(name) => context(name).asInstanceOf[A] + case UnaryOp(calc, op, _) => op(calc.result).asInstanceOf[A] + case BinaryOp(lhs, rhs, op, _) => op(lhs.result, rhs.result).asInstanceOf[A] + case TernaryOp(a, b, c, op, _) => op(a.result, b.result, c.result).asInstanceOf[A] case MatchExpr(expr, cases) => val exprResult = expr.result @@ -156,15 +72,13 @@ enum Calc[A]: q.reflect.report.errorAndAbort(s"CalcMatchCase not found for $exprResult") throw MatchError(exprResult) - case WithMessage(calc, _) => - calc.result case Block(defs, calc) => val newContext = defs.foldLeft(context) { (ctx, defn) => - ctx + (defn.name -> defn.calc.result) + ctx + (defn.name -> defn.calc.result(using ctx)) } calc.result(using newContext) - case Reference(name) => - context(name).asInstanceOf[A] + end result +end Calc case class BinaryOpMatch(name: String): def unapply(using Quotes)(expr: Expr[Any]): Option[(Calc[Any], Calc[Any])] = @@ -172,164 +86,160 @@ case class BinaryOpMatch(name: String): expr.asTerm match case Apply(Select(lhs, `name`), List(rhs)) => (lhs.asExpr, rhs.asExpr) match - case (Calc(lhs), Calc(rhs)) => + case (Calc[Any](lhs), Calc[Any](rhs)) => Some((lhs, rhs)) - case _ => None + case _ => + None -val MatchBinEq = BinaryOpMatch("==") -val MatchBinLt = BinaryOpMatch("<") -val MatchBinGt = BinaryOpMatch(">") -val MatchBinLte = BinaryOpMatch("<=") -val MatchBinGte = BinaryOpMatch(">=") +val MatchBinEq = BinaryOpMatch("==") +val MatchBinLt = BinaryOpMatch("<") +val MatchBinGt = BinaryOpMatch(">") +val MatchBinLte = BinaryOpMatch("<=") +val MatchBinGte = BinaryOpMatch(">=") +val MatchBinMinus = BinaryOpMatch("-") +val MatchBinPlus = BinaryOpMatch("+") +val MatchBinTimes = BinaryOpMatch("*") +val MatchBinDivide = BinaryOpMatch("/") +val MatchBinMod = BinaryOpMatch("%") object Calc: - def unapply[A: Type](expr: Expr[A])(using Quotes): Option[Calc[A]] = + def renderConstant(value: Any): String = + value match + case s: String => s""""$s"""".green + case c: Char => s"'$c'".green + case set: Set[?] => + val elems = set.map(renderConstant).mkString(", ".reset) + s"Set(".reset + elems + ")".reset + case list: List[?] => + val elems = list.map(renderConstant).mkString(", ".reset) + s"List(".reset + elems + ")".reset + case regex: Regex => + s"""${renderConstant(regex.toString)}.r""" + case long: Long => s"${long}L".cyan + case _ => value.toString.cyan + + def infix(op: String) = (a: String, b: String) => s"$a $op $b" + def call(op: String) = (a: String, b: String) => s"$a.$op($b)" + def call2(op: String) = (a: String, b: String, c: String) => s"$a.$op($b, $c)" + def nullary(op: String) = (a: String) => s"$a.$op" + def prefix(op: String) = (a: String) => s"$op$a" + + def unapply[A](expr: Expr[Any])(using Quotes): Option[Calc[A]] = import quotes.reflect.* - expr match - case '{ ${ Expr(int) }: Int } => Some(Calc.Constant(int).asInstanceOf[Calc[A]]) - case '{ ${ Expr(string) }: String } => Some(Calc.Constant(string).asInstanceOf[Calc[A]]) - case '{ ${ Expr(bool) }: Boolean } => Some(Calc.Constant(bool).asInstanceOf[Calc[A]]) - case '{ ${ Expr(long) }: Long } => Some(Calc.Constant(long).asInstanceOf[Calc[A]]) - case '{ ${ Expr(double) }: Double } => Some(Calc.Constant(double).asInstanceOf[Calc[A]]) - case '{ ${ Expr(float) }: Float } => Some(Calc.Constant(float).asInstanceOf[Calc[A]]) - case '{ ${ Expr(char) }: Char } => Some(Calc.Constant(char).asInstanceOf[Calc[A]]) - case '{ ${ Expr(byte) }: Byte } => Some(Calc.Constant(byte).asInstanceOf[Calc[A]]) - case '{ ${ Expr(short) }: Short } => Some(Calc.Constant(short).asInstanceOf[Calc[A]]) - case '{ () } => Some(Calc.Constant(()).asInstanceOf[Calc[A]]) - case '{ BigInt(${ Expr(string) }: String) } => Some(Calc.Constant(BigInt(string)).asInstanceOf[Calc[A]]) - case '{ BigDecimal(${ Expr(string) }: String) } => Some(Calc.Constant(BigDecimal(string)).asInstanceOf[Calc[A]]) - case '{ (${ Expr(string) }: String).r } => Some(Calc.Constant(string.r).asInstanceOf[Calc[A]]) - // Set - case '{ ${ Expr[Set[String]](set) }: Set[String] } => Some(Calc.Constant(set).asInstanceOf[Calc[A]]) - case Unseal(Ident(name)) => Some(Calc.Reference(name).asInstanceOf[Calc[A]]) + import quotes.reflect as r + val result: Option[Calc[?]] = expr match + // BASIC TYPES + case Unseal(r.Literal(constant)) => Some(Calc.Constant(constant.value)) + case '{ BigInt(${ Expr(string) }: String) } => Some(Calc.Constant(BigInt(string))) + case '{ BigDecimal(${ Expr(string) }: String) } => Some(Calc.Constant(BigDecimal(string))) + case '{ (${ Expr(string) }: String).r } => Some(Calc.Constant(string.r)) + + // CONTAINER TYPES + case '{ type a; ${ Expr[Set[`a`]](set) }: Set[`a`] } => Some(Calc.Constant(set)) + case '{ type a; ${ Expr[List[`a`]](list) }: List[`a`] } => Some(Calc.Constant(list)) + + case Unseal(Ident(name)) => Some(Calc.Reference(name)) // Boolean Operations - case '{ (${ Calc(lhs) }: Boolean) && (${ Calc(rhs) }: Boolean) } => - Some(Calc.And(lhs, rhs).asInstanceOf[Calc[A]]) - case '{ (${ Calc(lhs) }: Boolean) || (${ Calc(rhs) }: Boolean) } => - Some(Calc.Or(lhs, rhs).asInstanceOf[Calc[A]]) - case '{ !(${ Calc(calc) }: Boolean) } => - Some(Calc.Not(calc).asInstanceOf[Calc[A]]) + case '{ (${ Calc[Boolean](lhs) }: Boolean) && (${ Calc[Boolean](rhs) }: Boolean) } => + Some(Calc.BinaryOp(lhs, rhs, _ && _, infix("&&"))) + case '{ (${ Calc[Boolean](lhs) }: Boolean) || (${ Calc[Boolean](rhs) }: Boolean) } => + Some(Calc.BinaryOp(lhs, rhs, _ || _, infix("||"))) + case '{ !(${ Calc[Boolean](calc) }: Boolean) } => + Some(Calc.UnaryOp(calc, !_, prefix("!"))) // Numeric Operations - case '{ (${ Calc(lhs) }: Int) + (${ Calc(rhs) }: Int) } => - Some(Calc.Add(lhs, rhs).asInstanceOf[Calc[A]]) - case '{ (${ Calc(num) }: Int).toDouble } => - Some(Calc.ToDouble(num).asInstanceOf[Calc[A]]) + case '{ (${ Calc[Int](num) }: Int).toDouble } => + Some(Calc.UnaryOp(num, _.toDouble, nullary("toDouble"))) // String Operations - case '{ (${ Calc(string) }: String).length } => - Some(Calc.Length(string).asInstanceOf[Calc[A]]) - case '{ (${ Calc(string) }: String).length() } => - Some(Calc.Length(string).asInstanceOf[Calc[A]]) - case '{ (${ Calc(string) }: String).substring(${ Calc(start) }: Int, ${ Calc(end) }: Int) } => - Some(Calc.Substring(string, start, end).asInstanceOf[Calc[A]]) - case '{ (${ Calc(string) }: String).toUpperCase } => - Some(Calc.ToUpper(string).asInstanceOf[Calc[A]]) - case '{ (${ Calc(string) }: String).toLowerCase } => - Some(Calc.ToLower(string).asInstanceOf[Calc[A]]) - case '{ (${ Calc(str) }: String).startsWith(${ Calc(prefix) }: String) } => - Some(Calc.StartsWith(str, prefix).asInstanceOf[Calc[A]]) - case '{ (${ Calc(str) }: String).endsWith(${ Calc(suffix) }: String) } => - Some(Calc.EndsWith(str, suffix).asInstanceOf[Calc[A]]) - case '{ (${ Calc(str) }: String).contains(${ Calc(substr) }: String) } => - Some(Calc.Contains(str, substr).asInstanceOf[Calc[A]]) - case '{ (${ Calc(str) }: String).matches(${ Calc(regex) }: String) } => - Some(Calc.MatchesRegex(str, regex).asInstanceOf[Calc[A]]) - case '{ (${ Calc(regex) }: Regex).matches(${ Calc(str) }: String) } => - Some(Calc.RegexMatches(regex, str).asInstanceOf[Calc[A]]) - case '{ (${ Calc(str) }: String).nonEmpty } => - Some(Calc.StringNonEmpty(str).asInstanceOf[Calc[A]]) - case '{ (${ Calc(str) }: String).isEmpty } => - Some(Calc.StringIsEmpty(str).asInstanceOf[Calc[A]]) - case '{ (${ Calc(str) }: String).apply(${ Calc(index) }: Int) } => - Some(Calc.StringApply(str, index).asInstanceOf[Calc[A]]) - case '{ (${ Calc(str) }: String)(${ Calc(index) }: Int) } => - Some(Calc.StringApply(str, index).asInstanceOf[Calc[A]]) - case '{ (${ Calc(str) }: String).trim } => - Some(Calc.StringTrim(str).asInstanceOf[Calc[A]]) - - // Set Operations - case '{ (${ Calc(set) }: Set[String]).contains(${ Calc(elem) }: String) } => - Some(Calc.SetContains(set, elem).asInstanceOf[Calc[A]]) - - case Unseal(quotes.reflect.Block(stats, Seal(Calc(expr)))) => - val defs = stats.collect { case ValDef(name, _, Some(Seal(Calc(calc)))) => - CalcDef(name, calc) - } -// report.errorAndAbort(s"Defs: ${defs} stats: ${stats} expr: ${expr}") - Some(Calc.Block(defs, expr).asInstanceOf[Calc[A]]) - + case '{ (${ Calc[String](string) }: String).length } => + Some(Calc.UnaryOp(string, _.length, nullary("length"))) + case '{ (${ Calc[String](string) }: String).substring(${ Calc[Int](start) }: Int, ${ Calc[Int](end) }: Int) } => + Some(Calc.TernaryOp(string, start, end, _.substring(_, _), call2("substring"))) + case '{ (${ Calc[String](string) }: String).toUpperCase } => + Some(Calc.UnaryOp(string, _.toUpperCase, nullary("toUpperCase"))) + case '{ (${ Calc[String](string) }: String).toLowerCase } => + Some(Calc.UnaryOp(string, _.toLowerCase, nullary("toLowerCase"))) + case '{ (${ Calc[String](str) }: String).startsWith(${ Calc[String](prefix) }: String) } => + Some(Calc.BinaryOp(str, prefix, _.startsWith(_), call("startsWith"))) + case '{ (${ Calc[String](str) }: String).endsWith(${ Calc[String](suffix) }: String) } => + Some(Calc.BinaryOp(str, suffix, _.endsWith(_), call("endsWith"))) + case '{ (${ Calc[String](str) }: String).contains(${ Calc[String](substr) }: String) } => + Some(Calc.BinaryOp(str, substr, _.contains(_), call("contains"))) + case '{ (${ Calc[String](str) }: String).matches(${ Calc[String](regex) }: String) } => + Some(Calc.BinaryOp(str, regex, _.matches(_), call("matches"))) + case '{ (${ Calc[Regex](regex) }: Regex).matches(${ Calc[String](str) }: String) } => + Some(Calc.BinaryOp(regex, str, _.matches(_), call("matches"))) + case '{ (${ Calc[String](str) }: String).nonEmpty } => + Some(Calc.UnaryOp(str, _.nonEmpty, nullary("nonEmpty"))) + case '{ (${ Calc[String](str) }: String).isEmpty } => + Some(Calc.UnaryOp(str, _.isEmpty, nullary("isEmpty"))) + case '{ (${ Calc[String](str) }: String)(${ Calc[Int](index) }: Int) } => + Some(Calc.BinaryOp(str, index, _.apply(_), (l, r) => s"$l($r)")) + case '{ (${ Calc[String](str) }: String).trim } => + Some(Calc.UnaryOp(str, _.trim, nullary("trim"))) + case '{ scala.Predef.identity[a](${ Calc[Any](calc) }) } => + Some(calc) case MatchBinEq(lhs, rhs) => - Some(Calc.BinaryOp(lhs, rhs, _ == _, "==").asInstanceOf[Calc[A]]) + Some(Calc.BinaryOp(lhs, rhs, _ == _, infix("=="))) case MatchBinLt(lhs, rhs) => - Some(Calc.BinaryOp(lhs, rhs, Operations.lessThan, "<").asInstanceOf[Calc[A]]) + Some(Calc.BinaryOp(lhs, rhs, Operations.lessThan, infix("<"))) case MatchBinGt(lhs, rhs) => - Some(Calc.BinaryOp(lhs, rhs, Operations.greaterThan, ">").asInstanceOf[Calc[A]]) + Some(Calc.BinaryOp(lhs, rhs, Operations.greaterThan, infix(">"))) case MatchBinLte(lhs, rhs) => - Some(Calc.BinaryOp(lhs, rhs, Operations.lessThanOrEqual, "<=").asInstanceOf[Calc[A]]) + Some(Calc.BinaryOp(lhs, rhs, Operations.lessThanOrEqual, infix("<="))) case MatchBinGte(lhs, rhs) => - Some(Calc.BinaryOp(lhs, rhs, Operations.greaterThanOrEqual, ">=").asInstanceOf[Calc[A]]) + Some(Calc.BinaryOp(lhs, rhs, Operations.greaterThanOrEqual, infix(">="))) + case MatchBinMinus(lhs, rhs) => + Some(Calc.BinaryOp(lhs, rhs, Operations.minus, infix("-"))) + case MatchBinPlus(lhs, rhs) => + Some(Calc.BinaryOp(lhs, rhs, Operations.plus, infix("+"))) + case MatchBinDivide(lhs, rhs) => + Some(Calc.BinaryOp(lhs, rhs, Operations.divide, infix("/"))) + case MatchBinMod(lhs, rhs) => + Some(Calc.BinaryOp(lhs, rhs, Operations.mod, infix("%"))) + case MatchBinTimes(lhs, rhs) => + Some(Calc.BinaryOp(lhs, rhs, Operations.times, infix("*"))) + + case Unseal(quotes.reflect.Block(stats, Seal(Calc[Any](expr)))) => + val defs = stats.collect { case ValDef(name, _, Some(Seal(Calc[Any](calc)))) => + CalcValDef(name, calc) + } + Some(Calc.Block(defs, expr)) // parse match expression - case Unseal(Match(Seal(Calc(expr)), caseDefs)) => + case Unseal(Match(Seal(Calc[Any](expr)), caseDefs)) => val calcCaseDefs = caseDefs.map(CalcMatchCase.parse) -// report.errorAndAbort(s"Match: ${calcCaseDefs}") - Some(MatchExpr(expr, calcCaseDefs).asInstanceOf[Calc[A]]) + Some(MatchExpr(expr, calcCaseDefs)) case Unseal(Typed(t, _)) => - unapply(t.asExprOf[A]) - case Unseal( - Apply(Apply(Apply(Ident("??"), List(_)), List(Seal(Calc(body)))), List(Literal(StringConstant(msg)))) - ) => - Some(Calc.WithMessage(body, msg).asInstanceOf[Calc[A]]) + unapply(t.asExpr) + + case Unseal(Uninlined(t)) => + unapply(t.asExpr) + + // Set Operations + case '{ type a; (${ Calc[Set[`a`]](set) }: Set[`a`]).contains(${ Calc[`a`](elem) }: `a`) } => + Some(Calc.BinaryOp(set, elem, _.contains(_), call("contains"))) + + // List Operations + case '{ type a; (${ Calc[List[`a`]](list) }: List[`a`]).:+(${ Calc[`a`](elem) }: `a`) } => + Some(Calc.BinaryOp(list, elem, _ :+ _, infix(":+"))) + case '{ type a; (${ Calc[List[`a`]](list) }: List[`a`]).::[`a`](${ Calc[`a`](elem) }: `a`) } => + Some(Calc.BinaryOp(elem, list, _ :: _, infix("::"))) // Fixing case other => -// report.errorAndAbort(s"Calc unapply failed to parse: ${other.show}\n${other.asTerm.underlyingArgument}") +// report.errorAndAbort( +// s"CALC PARSE FAIL: ${other.show}\n${other.asTerm.tpe.show}\n${other.asTerm.underlyingArgument}" +// ) None -object Operations: - // define an any that works for any combination of - // Int < Int - // Long < Long - // Int < Long - // Long < Int - // etc - def lessThan(lhs: Any, rhs: Any): Boolean = - compare(lhs, rhs) < 0 + result.asInstanceOf[Option[Calc[A]]] - def greaterThan(lhs: Any, rhs: Any): Boolean = - compare(lhs, rhs) > 0 +case class CalcValDef[A](name: String, calc: Calc[A]) - def greaterThanOrEqual(lhs: Any, rhs: Any): Boolean = - compare(lhs, rhs) >= 0 - - def lessThanOrEqual(lhs: Any, rhs: Any): Boolean = - compare(lhs, rhs) <= 0 - - def compare(lhs: Any, rhs: Any): Int = - (lhs, rhs) match - case (lhs: String, rhs: String) => lhs.compare(rhs) - case _ => - val ln = numericFor(lhs).asInstanceOf[Numeric[Any]] - val rn = numericFor(rhs).asInstanceOf[Numeric[Any]] - ln.toDouble(lhs).compare(rn.toDouble(rhs)) - - def numericFor(any: Any): Numeric[?] = - any match - case _: Int => summon[Numeric[Int]] - case _: Long => summon[Numeric[Long]] - case _: Short => summon[Numeric[Short]] - case _: Char => summon[Numeric[Char]] - case _: Byte => summon[Numeric[Byte]] - case _: Double => summon[Numeric[Double]] - case _: Float => summon[Numeric[Float]] - case _: BigInt => summon[Numeric[BigInt]] - case _: BigDecimal => summon[Numeric[BigDecimal]] - case _ => throw new IllegalArgumentException(s"Cannot find numeric for ${any}") - -case class CalcDef[A](name: String, calc: Calc[A]) case class CalcMatchCase[A](pattern: CalcPattern[A], guard: Option[Calc[Boolean]], calc: Calc[A]): def render(using Map[String, String]): String = s"${pattern.render} => ${calc.render}" @@ -348,7 +258,7 @@ object CalcMatchCase: def parse(using Quotes)(caseDef: quotes.reflect.CaseDef): CalcMatchCase[Any] = import quotes.reflect.* caseDef match - case CaseDef(pattern, guard, Seal(Calc(calc))) => + case CaseDef(pattern, guard, Seal(Calc[Any](calc))) => val guardCalc = guard.map { case Seal(Calc[Boolean](guardCalc)) => guardCalc } CalcMatchCase(CalcPattern.parse(pattern), guardCalc, calc) case other => @@ -391,20 +301,20 @@ object CalcPattern: term match case r.Wildcard() => CalcPattern.Wildcard() case r.Bind(name, r.Wildcard()) => CalcPattern.Variable(name) - case Seal(Calc(constant)) => CalcPattern.Constant(constant) + case Seal(Calc[Any](constant)) => CalcPattern.Constant(constant) case r.Alternatives(patterns) => CalcPattern.Alternative(patterns.map(parse)) object Unseal: def unapply(expr: Expr[?])(using Quotes): Option[quotes.reflect.Term] = import quotes.reflect.* - Uninlined.unapply(expr.asTerm.underlyingArgument) + Some(expr.asTerm) object Uninlined: def unapply(using Quotes)(term: quotes.reflect.Term): Option[quotes.reflect.Term] = import quotes.reflect.* term match case Inlined(_, bindings, t) => Some(quotes.reflect.Block(bindings, t)) - case t => Some(t) + case t => None object Seal: // turn term into expr diff --git a/modules/core/src/main/scala/neotype/CustomFromExpr.scala b/modules/core/src/main/scala/neotype/CustomFromExpr.scala new file mode 100644 index 0000000..4d119d7 --- /dev/null +++ b/modules/core/src/main/scala/neotype/CustomFromExpr.scala @@ -0,0 +1,46 @@ +package neotype + +import scala.quoted.* + +object CustomFromExpr: + given [A]: FromExpr[Set[A]] with + def unapply(x: Expr[Set[A]])(using Quotes) = + import quotes.reflect.* + val aType = x.asTerm.tpe.widen.typeArgs.head.asType + given FromExpr[A] = fromExprForType(aType).asInstanceOf[FromExpr[A]] + given Type[A] = aType.asInstanceOf[Type[A]] + x match + case '{ Set[A](${ Varargs(Exprs(elems)) }*) } => Some(elems.toSet) + case '{ Set.empty[A] } => Some(Set.empty[A]) + case '{ scala.collection.immutable.Set[A](${ Varargs(Exprs(elems)) }*) } => Some(elems.toSet) + case '{ scala.collection.immutable.Set.empty[A] } => Some(Set.empty[A]) + case _ => +// report.warning(s"Cannot unapply Set from ${x}\n${x.asTerm}") + None + + given [A]: FromExpr[List[A]] with + def unapply(x: Expr[List[A]])(using Quotes) = + import quotes.reflect.* + val aType = x.asTerm.tpe.widen.typeArgs.head.asType + given FromExpr[A] = fromExprForType(aType).asInstanceOf[FromExpr[A]] + given Type[A] = aType.asInstanceOf[Type[A]] + x match + case '{ List[A](${ Varargs(Exprs(elems)) }*) } => Some(elems.toList) + case '{ List.empty[A] } => Some(List.empty[A]) + case '{ scala.collection.immutable.List[A](${ Varargs(Exprs(elems)) }*) } => Some(elems.toList) + case '{ scala.collection.immutable.List.empty[A] } => Some(List.empty[A]) + case _ => +// report.warning(s"Cannot unapply List from ${x.show}") + None + + def fromExprForType(using Quotes)(tpe: Type[?]) = + tpe match + case '[String] => summon[FromExpr[String]] + case '[Int] => summon[FromExpr[Int]] + case '[Long] => summon[FromExpr[Long]] + case '[Short] => summon[FromExpr[Short]] + case '[Char] => summon[FromExpr[Char]] + case '[Byte] => summon[FromExpr[Byte]] + case '[Double] => summon[FromExpr[Double]] + case '[Float] => summon[FromExpr[Float]] + case '[Boolean] => summon[FromExpr[Boolean]] diff --git a/modules/core/src/main/scala/neotype/ErrorMessages.scala b/modules/core/src/main/scala/neotype/ErrorMessages.scala index 6161447..33e21b5 100644 --- a/modules/core/src/main/scala/neotype/ErrorMessages.scala +++ b/modules/core/src/main/scala/neotype/ErrorMessages.scala @@ -4,8 +4,6 @@ import scala.quoted.* import StringFormatting.* private[neotype] object ErrorMessages: - // contiguous ASCII dash symbol: U+2015 looks like: — - val header = "—— Newtype Error ——————————————————————————————————————————————————————————".red val footer = @@ -13,11 +11,11 @@ private[neotype] object ErrorMessages: /** An error message for when the input to a Newtype's apply method is not known at compile time. */ - def inputNotKnownAtCompileTime(using Quotes)(input: Expr[Any], nt: quotes.reflect.TypeRepr) = + def inputParseFailureMessage(using Quotes)(input: Expr[Any], nt: quotes.reflect.TypeRepr): String = import quotes.reflect.* val inputTpe = input.asTerm.tpe.widenTermRefByName - val example = examples(inputTpe) + val example = Calc.renderConstant(examples(input)) val newTypeNameString = nt.typeSymbol.name.replaceAll("\\$$", "").green.bold val valueExprString = input.asTerm.pos.sourceCode.getOrElse(input.show).blue @@ -28,7 +26,7 @@ private[neotype] object ErrorMessages: | | 🤠 ${"Possible Solutions".bold} | ${"1.".dim} Try passing a literal $inputTypeString: - | $newTypeNameString(${example.show.green}) + | $newTypeNameString(${example}) | ${"2.".dim} Call the ${"make".green} method, which returns a runtime-validated ${"Either".yellow}: | $newTypeNameString.${"make".green}(${valueExprString}) | ${"3.".dim} If you are sure the input is valid, use the ${"unsafe".green} method: @@ -40,9 +38,9 @@ private[neotype] object ErrorMessages: /** An error message for when the compile-time validation of a Newtype's apply method fails. */ - def validationFailed(using + def compileTimeValidationFailureMessage(using Quotes - )(input: Expr[Any], nt: quotes.reflect.TypeRepr, source: Option[String], failureMessage: String) = + )(input: Expr[Any], nt: quotes.reflect.TypeRepr, source: Option[String], failureMessage: String): String = import quotes.reflect.* val isDefaultFailureMessage = failureMessage == "Validation Failed" @@ -61,7 +59,7 @@ private[neotype] object ErrorMessages: | $footer |""".stripMargin - def validateIsNotInline(using Quotes)(input: Expr[Any], nt: quotes.reflect.TypeRepr): String = + def validateIsNotInlineMessage(using Quotes)(input: Expr[Any], nt: quotes.reflect.TypeRepr): String = import quotes.reflect.* val inputTpe = input.asTerm.tpe.widenTermRefByName val newTypeNameString = nt.typeSymbol.name.replaceAll("\\$$", "").green.bold @@ -77,7 +75,7 @@ private[neotype] object ErrorMessages: | $footer |""".stripMargin - def failedToParseCustomErrorMessage(using Quotes)(nt: quotes.reflect.TypeRepr) = + def failedToParseCustomErrorMessage(using Quotes)(nt: quotes.reflect.TypeRepr): String = val newTypeNameString = nt.typeSymbol.name.replaceAll("\\$$", "").green.bold s""" $header | 😭 I've ${"FAILED".red} to parse $newTypeNameString's ${"failureMessage".green}! @@ -90,18 +88,11 @@ private[neotype] object ErrorMessages: | $footer |""".stripMargin - def indent(str: String) = - str.linesIterator - .map { line => - s" $line".blue - } - .mkString("\n") - def failedToParseValidateMethod(using Quotes )(input: Expr[Any], nt: quotes.reflect.TypeRepr, source: Option[String], isBodyInline: Option[Boolean]): String = import quotes.reflect.* - if isBodyInline.contains(false) then return validateIsNotInline(input, nt) + if isBodyInline.contains(false) then return validateIsNotInlineMessage(input, nt) val newTypeNameString = nt.typeSymbol.name.replaceAll("\\$$", "").green.bold val sourceExpr = source.fold("") { s => @@ -129,16 +120,27 @@ private[neotype] object ErrorMessages: | $footer |""".stripMargin + private def indent(str: String) = + str.linesIterator + .map { line => + s" $line".blue + } + .mkString("\n") + // Create a map from various input types to examples of the given type of statically known inputs - def examples(using Quotes)(tpe: quotes.reflect.TypeRepr) = + def examples(using Quotes)(input: Expr[Any]): Any = import quotes.reflect.* - val examples = Map( - TypeRepr.of[String] -> '{ "foo" }, - TypeRepr.of[Int] -> '{ 1 }, - TypeRepr.of[Long] -> '{ 1L }, - TypeRepr.of[Float] -> '{ 1.0f }, - TypeRepr.of[Double] -> '{ 1.0 } - ) - - examples.find { case (k, _) => k <:< tpe }.get._2 + input match + case '{ ($_): String } => "foo" + case '{ ($_): Int } => 1 + case '{ ($_): Long } => 1L + case '{ ($_): Float } => 1.0f + case '{ ($_): Double } => 1.0 + case '{ ($_): Boolean } => true + case '{ ($_): Char } => 'a' + case '{ ($_): Byte } => 1.toByte + case '{ ($_): Short } => 1.toShort + case '{ ($_): Unit } => () + case '{ ($_): Null } => null + case _ => "input" diff --git a/modules/core/src/main/scala/neotype/Macros.scala b/modules/core/src/main/scala/neotype/Macros.scala index 36413a0..e17423e 100644 --- a/modules/core/src/main/scala/neotype/Macros.scala +++ b/modules/core/src/main/scala/neotype/Macros.scala @@ -23,21 +23,15 @@ private[neotype] object Macros: case _ => None - lazy val treeSource = scala.util - .Try { + lazy val treeSource = + try nt.typeSymbol .methodMember("validate") .headOption .flatMap { - _.tree match - case body => - body.pos.sourceCode - case _ => - None + _.tree.pos.sourceCode } - } - .toOption - .flatten + catch case _: Throwable => None val isBodyInline = nt.typeSymbol .methodMember("validate") @@ -48,25 +42,28 @@ private[neotype] object Macros: case Calc[A](calc) => scala.util.Try(calc.result(using Map.empty)) match case Failure(_) => - report.errorAndAbort(ErrorMessages.inputNotKnownAtCompileTime(a, nt)) + report.errorAndAbort(ErrorMessages.inputParseFailureMessage(a, nt)) case Success(_) => () case _ => - report.errorAndAbort(ErrorMessages.inputNotKnownAtCompileTime(a, nt)) + report.errorAndAbort(ErrorMessages.inputParseFailureMessage(a, nt)) val validateApplied = Expr.betaReduce('{ $validate($a) }) validateApplied match - case Calc(calc) => + case Calc[A](calc) => scala.util.Try(calc.result(using Map.empty)) match case Failure(exception) => - report.errorAndAbort(s"Failed to execute parsed validation: $exception") +// report.errorAndAbort(s"Failed to execute parsed validation: $exception") + report.errorAndAbort(ErrorMessages.failedToParseValidateMethod(a, nt, treeSource, isBodyInline)) case Success(true) => a.asExprOf[T] case Success(false) => val failureMessageValue = failureMessage match case Expr(str: String) => str case _ => "Validation Failed" - report.errorAndAbort(ErrorMessages.validationFailed(a, nt, expressionSource, failureMessageValue)) + report.errorAndAbort( + ErrorMessages.compileTimeValidationFailureMessage(a, nt, expressionSource, failureMessageValue) + ) case _ => report.errorAndAbort(ErrorMessages.failedToParseValidateMethod(a, nt, treeSource, isBodyInline)) @@ -89,3 +86,38 @@ private[neotype] object Macros: processArgs(args.asInstanceOf[Seq[Expr[A]]]) case other => report.errorAndAbort(s"Could not parse input at compile time: ${other.show}") + +private[neotype] object TestMacros: + inline def eval[A](inline expr: A): A = ${ evalImpl[A]('expr) } + inline def evalDebug[A](inline expr: A): A = ${ evalDebugImpl[A]('expr) } + + def evalDebugImpl[A: Type](using Quotes)(expr: Expr[A]): Expr[A] = + import quotes.reflect.* + report.info(s"expr: ${expr.show}\nterm: ${expr.asTerm.underlyingArgument}") + evalImpl(expr) + + def evalImpl[A: Type](using Quotes)(expr: Expr[A]): Expr[A] = + import quotes.reflect.* + expr match + case Calc[A](calc) => + val result = calc.result(using Map.empty) + given ToExpr[A] = toExprInstance(result).asInstanceOf[ToExpr[A]] + Expr(result) + case _ => + report.errorAndAbort(s"Could not parse input at compile time: ${expr.show}\n\n${expr.asTerm.toString.blue}") + ??? + + def toExprInstance(using Quotes)(any: Any): ToExpr[?] = + import quotes.reflect.* + any match + case _: Int => summon[ToExpr[Int]] + case _: String => summon[ToExpr[String]] + case _: Boolean => summon[ToExpr[Boolean]] + case _: Long => summon[ToExpr[Long]] + case _: Double => summon[ToExpr[Double]] + case _: Float => summon[ToExpr[Float]] + case _: Char => summon[ToExpr[Char]] + case _: Byte => summon[ToExpr[Byte]] + case _: Short => summon[ToExpr[Short]] + case _: Set[Int] => summon[ToExpr[Set[Int]]] + case _: List[Int] => summon[ToExpr[List[Int]]] diff --git a/modules/core/src/main/scala/neotype/Operations.scala b/modules/core/src/main/scala/neotype/Operations.scala new file mode 100644 index 0000000..069515d --- /dev/null +++ b/modules/core/src/main/scala/neotype/Operations.scala @@ -0,0 +1,248 @@ +package neotype + +private[neotype] object Operations: + + def minus(lhs: Any, rhs: Any): Any = + (lhs, rhs) match + case (lhs: Set[Any], rhs: Any) => lhs - rhs + case _ => performNumericBinOp(NumericBinOp.Minus, lhs, rhs) + def plus(lhs: Any, rhs: Any): Any = + (lhs, rhs) match + case (lhs: String, rhs: String) => lhs + rhs + case (lhs: Set[Any], rhs: Any) => lhs + rhs + case _ => performNumericBinOp(NumericBinOp.Plus, lhs, rhs) + def times(lhs: Any, rhs: Any): Any = performNumericBinOp(NumericBinOp.Times, lhs, rhs) + def divide(lhs: Any, rhs: Any): Any = performNumericBinOp(NumericBinOp.Divide, lhs, rhs) + def mod(lhs: Any, rhs: Any): Any = performNumericBinOp(NumericBinOp.Mod, lhs, rhs) + def pow(lhs: Any, rhs: Any): Any = performNumericBinOp(NumericBinOp.Pow, lhs, rhs) + def min(lhs: Any, rhs: Any): Any = performNumericBinOp(NumericBinOp.Min, lhs, rhs) + def max(lhs: Any, rhs: Any): Any = performNumericBinOp(NumericBinOp.Max, lhs, rhs) + def lessThan(lhs: Any, rhs: Any): Any = + (lhs, rhs) match + case (lhs: String, rhs: String) => lhs < rhs + case _ => performNumericBinOp(NumericBinOp.LessThan, lhs, rhs) + + def lessThanOrEqual(lhs: Any, rhs: Any): Any = + (lhs, rhs) match + case (lhs: String, rhs: String) => lhs <= rhs + case _ => performNumericBinOp(NumericBinOp.LessThanOrEqual, lhs, rhs) + + def greaterThan(lhs: Any, rhs: Any): Any = + (lhs, rhs) match + case (lhs: String, rhs: String) => lhs > rhs + case _ => performNumericBinOp(NumericBinOp.GreaterThan, lhs, rhs) + + def greaterThanOrEqual(lhs: Any, rhs: Any): Any = + (lhs, rhs) match + case (lhs: String, rhs: String) => lhs >= rhs + case _ => performNumericBinOp(NumericBinOp.GreaterThanOrEqual, lhs, rhs) + + // Adding (1: Byte) to (1: Int) will fail with a ClassCastException, + // so we need to convert the Byte to an Int before adding. + // We need to always convert to the larger type, so we need to know + // the type of the larger type. + private def performNumericBinOp(op: NumericBinOp, lhs: Any, rhs: Any): Any = + val lhsType = NumericType.forNumber(lhs) + val rhsType = NumericType.forNumber(rhs) + val widestType = lhsType.widest(rhsType) + val f = NumericBinOp.ops((widestType, op)) + f(widestType.widen(lhs), widestType.widen(rhs)) + +private enum NumericBinOp: + case Plus, Minus, Times, Divide, Mod, Min, Max, Pow, LessThan, LessThanOrEqual, GreaterThan, GreaterThanOrEqual + +private object NumericBinOp: + val ops: Map[(NumericType, NumericBinOp), (Any, Any) => Any] = + Map( + // byte ops + (NumericType.ByteT, NumericBinOp.Plus) -> ((a: Byte, b: Byte) => a + b), + (NumericType.ByteT, NumericBinOp.Minus) -> ((a: Byte, b: Byte) => a - b), + (NumericType.ByteT, NumericBinOp.Times) -> ((a: Byte, b: Byte) => a * b), + (NumericType.ByteT, NumericBinOp.Divide) -> ((a: Byte, b: Byte) => a / b), + (NumericType.ByteT, NumericBinOp.Mod) -> ((a: Byte, b: Byte) => a % b), + (NumericType.ByteT, NumericBinOp.Min) -> ((a: Byte, b: Byte) => a.min(b)), + (NumericType.ByteT, NumericBinOp.Max) -> ((a: Byte, b: Byte) => a.max(b)), + (NumericType.ByteT, NumericBinOp.Pow) -> ((a: Byte, b: Byte) => Math.pow(a, b)), + (NumericType.ByteT, NumericBinOp.LessThan) -> ((a: Byte, b: Byte) => a < b), + (NumericType.ByteT, NumericBinOp.LessThanOrEqual) -> ((a: Byte, b: Byte) => a <= b), + (NumericType.ByteT, NumericBinOp.GreaterThan) -> ((a: Byte, b: Byte) => a > b), + (NumericType.ByteT, NumericBinOp.GreaterThanOrEqual) -> ((a: Byte, b: Byte) => a >= b), + + // short ops + (NumericType.ShortT, NumericBinOp.Plus) -> ((a: Short, b: Short) => a + b), + (NumericType.ShortT, NumericBinOp.Minus) -> ((a: Short, b: Short) => a - b), + (NumericType.ShortT, NumericBinOp.Times) -> ((a: Short, b: Short) => a * b), + (NumericType.ShortT, NumericBinOp.Divide) -> ((a: Short, b: Short) => a / b), + (NumericType.ShortT, NumericBinOp.Mod) -> ((a: Short, b: Short) => a % b), + (NumericType.ShortT, NumericBinOp.Min) -> ((a: Short, b: Short) => a.min(b)), + (NumericType.ShortT, NumericBinOp.Max) -> ((a: Short, b: Short) => a.max(b)), + (NumericType.ShortT, NumericBinOp.Pow) -> ((a: Short, b: Short) => Math.pow(a, b)), + (NumericType.ShortT, NumericBinOp.LessThan) -> ((a: Short, b: Short) => a < b), + (NumericType.ShortT, NumericBinOp.LessThanOrEqual) -> ((a: Short, b: Short) => a <= b), + (NumericType.ShortT, NumericBinOp.GreaterThan) -> ((a: Short, b: Short) => a > b), + (NumericType.ShortT, NumericBinOp.GreaterThanOrEqual) -> ((a: Short, b: Short) => a >= b), + // char ops + (NumericType.CharT, NumericBinOp.Plus) -> ((a: Char, b: Char) => a + b), + (NumericType.CharT, NumericBinOp.Minus) -> ((a: Char, b: Char) => a - b), + (NumericType.CharT, NumericBinOp.Times) -> ((a: Char, b: Char) => a * b), + (NumericType.CharT, NumericBinOp.Divide) -> ((a: Char, b: Char) => a / b), + (NumericType.CharT, NumericBinOp.Mod) -> ((a: Char, b: Char) => a % b), + (NumericType.CharT, NumericBinOp.Min) -> ((a: Char, b: Char) => a.min(b)), + (NumericType.CharT, NumericBinOp.Max) -> ((a: Char, b: Char) => a.max(b)), + (NumericType.CharT, NumericBinOp.Pow) -> ((a: Char, b: Char) => Math.pow(a, b)), + (NumericType.CharT, NumericBinOp.LessThan) -> ((a: Char, b: Char) => a < b), + (NumericType.CharT, NumericBinOp.LessThanOrEqual) -> ((a: Char, b: Char) => a <= b), + (NumericType.CharT, NumericBinOp.GreaterThan) -> ((a: Char, b: Char) => a > b), + (NumericType.CharT, NumericBinOp.GreaterThanOrEqual) -> ((a: Char, b: Char) => a >= b), + // int ops + (NumericType.IntT, NumericBinOp.Plus) -> ((a: Int, b: Int) => a + b), + (NumericType.IntT, NumericBinOp.Minus) -> ((a: Int, b: Int) => a - b), + (NumericType.IntT, NumericBinOp.Times) -> ((a: Int, b: Int) => a * b), + (NumericType.IntT, NumericBinOp.Divide) -> ((a: Int, b: Int) => a / b), + (NumericType.IntT, NumericBinOp.Mod) -> ((a: Int, b: Int) => a % b), + (NumericType.IntT, NumericBinOp.Min) -> ((a: Int, b: Int) => a.min(b)), + (NumericType.IntT, NumericBinOp.Max) -> ((a: Int, b: Int) => a.max(b)), + (NumericType.IntT, NumericBinOp.Pow) -> ((a: Int, b: Int) => Math.pow(a, b)), + (NumericType.IntT, NumericBinOp.LessThan) -> ((a: Int, b: Int) => a < b), + (NumericType.IntT, NumericBinOp.LessThanOrEqual) -> ((a: Int, b: Int) => a <= b), + (NumericType.IntT, NumericBinOp.GreaterThan) -> ((a: Int, b: Int) => a > b), + (NumericType.IntT, NumericBinOp.GreaterThanOrEqual) -> ((a: Int, b: Int) => a >= b), + // long ops + (NumericType.LongT, NumericBinOp.Plus) -> ((a: Long, b: Long) => a + b), + (NumericType.LongT, NumericBinOp.Minus) -> ((a: Long, b: Long) => a - b), + (NumericType.LongT, NumericBinOp.Times) -> ((a: Long, b: Long) => a * b), + (NumericType.LongT, NumericBinOp.Divide) -> ((a: Long, b: Long) => a / b), + (NumericType.LongT, NumericBinOp.Mod) -> ((a: Long, b: Long) => a % b), + (NumericType.LongT, NumericBinOp.Min) -> ((a: Long, b: Long) => a.min(b)), + (NumericType.LongT, NumericBinOp.Max) -> ((a: Long, b: Long) => a.max(b)), + (NumericType.LongT, NumericBinOp.Pow) -> ((a: Long, b: Long) => Math.pow(a, b)), + (NumericType.LongT, NumericBinOp.LessThan) -> ((a: Long, b: Long) => a < b), + (NumericType.LongT, NumericBinOp.LessThanOrEqual) -> ((a: Long, b: Long) => a <= b), + (NumericType.LongT, NumericBinOp.GreaterThan) -> ((a: Long, b: Long) => a > b), + (NumericType.LongT, NumericBinOp.GreaterThanOrEqual) -> ((a: Long, b: Long) => a >= b), + // float ops + (NumericType.FloatT, NumericBinOp.Plus) -> ((a: Float, b: Float) => a + b), + (NumericType.FloatT, NumericBinOp.Minus) -> ((a: Float, b: Float) => a - b), + (NumericType.FloatT, NumericBinOp.Times) -> ((a: Float, b: Float) => a * b), + (NumericType.FloatT, NumericBinOp.Divide) -> ((a: Float, b: Float) => a / b), + (NumericType.FloatT, NumericBinOp.Mod) -> ((a: Float, b: Float) => a % b), + (NumericType.FloatT, NumericBinOp.Min) -> ((a: Float, b: Float) => a.min(b)), + (NumericType.FloatT, NumericBinOp.Max) -> ((a: Float, b: Float) => a.max(b)), + (NumericType.FloatT, NumericBinOp.Pow) -> ((a: Float, b: Float) => Math.pow(a, b)), + (NumericType.FloatT, NumericBinOp.LessThan) -> ((a: Float, b: Float) => a < b), + (NumericType.FloatT, NumericBinOp.LessThanOrEqual) -> ((a: Float, b: Float) => a <= b), + (NumericType.FloatT, NumericBinOp.GreaterThan) -> ((a: Float, b: Float) => a > b), + (NumericType.FloatT, NumericBinOp.GreaterThanOrEqual) -> ((a: Float, b: Float) => a >= b), + // double ops + (NumericType.DoubleT, NumericBinOp.Plus) -> ((a: Double, b: Double) => a + b), + (NumericType.DoubleT, NumericBinOp.Minus) -> ((a: Double, b: Double) => a - b), + (NumericType.DoubleT, NumericBinOp.Times) -> ((a: Double, b: Double) => a * b), + (NumericType.DoubleT, NumericBinOp.Divide) -> ((a: Double, b: Double) => a / b), + (NumericType.DoubleT, NumericBinOp.Mod) -> ((a: Double, b: Double) => a % b), + (NumericType.DoubleT, NumericBinOp.Min) -> ((a: Double, b: Double) => a.min(b)), + (NumericType.DoubleT, NumericBinOp.Max) -> ((a: Double, b: Double) => a.max(b)), + (NumericType.DoubleT, NumericBinOp.Pow) -> ((a: Double, b: Double) => Math.pow(a, b)), + (NumericType.DoubleT, NumericBinOp.LessThan) -> ((a: Double, b: Double) => a < b), + (NumericType.DoubleT, NumericBinOp.LessThanOrEqual) -> ((a: Double, b: Double) => a <= b), + (NumericType.DoubleT, NumericBinOp.GreaterThan) -> ((a: Double, b: Double) => a > b), + (NumericType.DoubleT, NumericBinOp.GreaterThanOrEqual) -> ((a: Double, b: Double) => a >= b), + // big int ops + (NumericType.BigIntT, NumericBinOp.Plus) -> ((a: BigInt, b: BigInt) => a + b), + (NumericType.BigIntT, NumericBinOp.Minus) -> ((a: BigInt, b: BigInt) => a - b), + (NumericType.BigIntT, NumericBinOp.Times) -> ((a: BigInt, b: BigInt) => a * b), + (NumericType.BigIntT, NumericBinOp.Divide) -> ((a: BigInt, b: BigInt) => a / b), + (NumericType.BigIntT, NumericBinOp.Mod) -> ((a: BigInt, b: BigInt) => a % b), + (NumericType.BigIntT, NumericBinOp.Min) -> ((a: BigInt, b: BigInt) => a.min(b)), + (NumericType.BigIntT, NumericBinOp.Max) -> ((a: BigInt, b: BigInt) => a.max(b)), + (NumericType.BigIntT, NumericBinOp.Pow) -> ((a: BigInt, b: BigInt) => a.pow(b.toInt)), + (NumericType.BigIntT, NumericBinOp.LessThan) -> ((a: BigInt, b: BigInt) => a < b), + (NumericType.BigIntT, NumericBinOp.LessThanOrEqual) -> ((a: BigInt, b: BigInt) => a <= b), + (NumericType.BigIntT, NumericBinOp.GreaterThan) -> ((a: BigInt, b: BigInt) => a > b), + (NumericType.BigIntT, NumericBinOp.GreaterThanOrEqual) -> ((a: BigInt, b: BigInt) => a >= b), + // big decimal ops + (NumericType.BigDecimalT, NumericBinOp.Plus) -> ((a: BigDecimal, b: BigDecimal) => a + b), + (NumericType.BigDecimalT, NumericBinOp.Minus) -> ((a: BigDecimal, b: BigDecimal) => a - b), + (NumericType.BigDecimalT, NumericBinOp.Times) -> ((a: BigDecimal, b: BigDecimal) => a * b), + (NumericType.BigDecimalT, NumericBinOp.Divide) -> ((a: BigDecimal, b: BigDecimal) => a / b), + (NumericType.BigDecimalT, NumericBinOp.Mod) -> ((a: BigDecimal, b: BigDecimal) => a % b), + (NumericType.BigDecimalT, NumericBinOp.Min) -> ((a: BigDecimal, b: BigDecimal) => a.min(b)), + (NumericType.BigDecimalT, NumericBinOp.Max) -> ((a: BigDecimal, b: BigDecimal) => a.max(b)), + (NumericType.BigDecimalT, NumericBinOp.Pow) -> ((a: BigDecimal, b: BigDecimal) => a.pow(b.toInt)), + (NumericType.BigDecimalT, NumericBinOp.LessThan) -> ((a: BigDecimal, b: BigDecimal) => a < b), + (NumericType.BigDecimalT, NumericBinOp.LessThanOrEqual) -> ((a: BigDecimal, b: BigDecimal) => a <= b), + (NumericType.BigDecimalT, NumericBinOp.GreaterThan) -> ((a: BigDecimal, b: BigDecimal) => a > b), + (NumericType.BigDecimalT, NumericBinOp.GreaterThanOrEqual) -> ((a: BigDecimal, b: BigDecimal) => a >= b) + ).asInstanceOf[Map[(NumericType, NumericBinOp), (Any, Any) => Any]] + +private enum NumericType: + // from narrowest to widest + case ByteT, ShortT, CharT, IntT, LongT, FloatT, DoubleT, BigIntT, BigDecimalT + + def widest(that: NumericType): NumericType = + if this.ordinal > that.ordinal then this else that + + def widen(any: Any): Any = + (any, this) match + case (any: Byte, ByteT) => any + case (any: Byte, ShortT) => any.toShort + case (any: Byte, CharT) => any.toChar + case (any: Byte, IntT) => any.toInt + case (any: Byte, LongT) => any.toLong + case (any: Byte, FloatT) => any.toFloat + case (any: Byte, DoubleT) => any.toDouble + case (any: Byte, BigIntT) => BigInt(any) + case (any: Byte, BigDecimalT) => BigDecimal(any) + + case (any: Short, ShortT) => any + case (any: Short, CharT) => any.toChar + case (any: Short, IntT) => any.toInt + case (any: Short, LongT) => any.toLong + case (any: Short, FloatT) => any.toFloat + case (any: Short, DoubleT) => any.toDouble + case (any: Short, BigIntT) => BigInt(any) + case (any: Short, BigDecimalT) => BigDecimal(any) + + case (any: Char, CharT) => any + case (any: Char, IntT) => any.toInt + case (any: Char, LongT) => any.toLong + case (any: Char, FloatT) => any.toFloat + case (any: Char, DoubleT) => any.toDouble + case (any: Char, BigIntT) => BigInt(any) + case (any: Char, BigDecimalT) => BigDecimal(any) + + case (any: Int, IntT) => any + case (any: Int, LongT) => any.toLong + case (any: Int, FloatT) => any.toFloat + case (any: Int, DoubleT) => any.toDouble + case (any: Int, BigIntT) => BigInt(any) + case (any: Int, BigDecimalT) => BigDecimal(any) + + case (any: Long, LongT) => any + case (any: Long, FloatT) => any.toFloat + case (any: Long, DoubleT) => any.toDouble + case (any: Long, BigIntT) => BigInt(any) + case (any: Long, BigDecimalT) => BigDecimal(any) + + case (any: Float, FloatT) => any + case (any: Float, DoubleT) => any.toDouble + case (any: Float, BigDecimalT) => BigDecimal(any) + + case (any: Double, DoubleT) => any + case (any: Double, BigDecimalT) => BigDecimal(any) + + case (any: BigInt, BigIntT) => any + case (any: BigInt, BigDecimalT) => BigDecimal(any) + +private object NumericType: + def forNumber(any: Any): NumericType = + any match + case _: Int => NumericType.IntT + case _: Long => NumericType.LongT + case _: Short => NumericType.ShortT + case _: Char => NumericType.CharT + case _: Byte => NumericType.ByteT + case _: Double => NumericType.DoubleT + case _: Float => NumericType.FloatT + case _: BigInt => NumericType.BigIntT + case _: BigDecimal => NumericType.BigDecimalT + case _ => throw new IllegalArgumentException(s"Cannot find numeric for ${any}") diff --git a/modules/core/src/main/scala/neotype/package.scala b/modules/core/src/main/scala/neotype/package.scala index d825c5d..6521332 100644 --- a/modules/core/src/main/scala/neotype/package.scala +++ b/modules/core/src/main/scala/neotype/package.scala @@ -49,6 +49,8 @@ abstract class Newtype[A](using fromExpr: FromExpr[A]) extends ValidatedWrapper[ extension (inline input: Type) // inline def unwrap: A = input + inline def unsafeWrapF[F[_]](inline input: F[A]): F[Type] = input + object Newtype: type WithType[A, B] = Newtype[A] { type Type = B } @@ -61,6 +63,8 @@ object Newtype: inline def applyF[F[_]](inline input: F[A]): F[Type] = input + inline def unsafeWrapF[F[_]](inline input: F[A]): F[Type] = input + object Simple: type WithType[A, B] = Newtype.Simple[A] { type Type = B } @@ -78,6 +82,9 @@ abstract class Subtype[A](using fromExpr: FromExpr[A]) extends ValidatedWrapper[ inline def cast(inline input: Type): A = input inline def castF[F[_]](inline input: F[Type]): F[A] = input + inline def unsafeWrap(inline input: A): Type = input + inline def unsafeWrapF[F[_]](inline input: F[A]): F[Type] = input + object Subtype: type WithType[A, B <: A] = Subtype[A] { type Type = B } @@ -93,5 +100,8 @@ object Subtype: inline def cast(inline input: A): Type = input inline def castF[F[_]](inline input: F[A]): F[Type] = input + inline def unsafeWrap(inline input: A): Type = input + inline def unsafeWrapF[F[_]](inline input: F[A]): F[Type] = input + object Simple: type WithType[A, B <: A] = Subtype.Simple[A] { type Type = B } diff --git a/modules/core/src/test/scala/neotype/EvalSpec.scala b/modules/core/src/test/scala/neotype/EvalSpec.scala new file mode 100644 index 0000000..6e5d89b --- /dev/null +++ b/modules/core/src/test/scala/neotype/EvalSpec.scala @@ -0,0 +1,70 @@ +package neotype + +import neotype.TestMacros.* +import zio.test.* + +object EvalSpec extends ZIOSpecDefault: + val spec = + suite("EvalSpec")( + evalTests.map { case (actual, expected) => + test(s"eval($actual) == $expected") { + assertTrue(actual == expected) + } + } + ) + +/** Tests various ways of evaluating various expressions + */ +val evalTests = + List( + // numeric expressions + // int + eval(identity(3) * 2) -> 6, + eval(identity(1) + 1) -> 2, + eval(identity(10) - 5) -> 5, + eval(identity(10) / 2) -> 5, + eval(identity(20) % 3) -> 2, + + // long + eval(identity(3L) * 2L) -> 6L, + eval(identity(1L) + 1L) -> 2L, + eval(identity(10L) - 5L) -> 5L, + eval(identity(10L) / 2L) -> 5L, + eval(identity(20L) % 3L) -> 2L, + + // double + eval(identity(3.0) * 2.0) -> 6.0, + eval(identity(1.5) + 1.5) -> 3.0, + eval(identity(10.0) - 5.5) -> 4.5, + eval(identity(10.0) / 2.0) -> 5.0, +// eval(identity(20.0) % 3.0) -> 2.0, + + // string expressions + eval(identity("Hello, ") + "world!") -> "Hello, world!", + eval("Scala is good".toUpperCase) -> "SCALA IS GOOD", + eval("SCALA IS GOOD".toLowerCase) -> "scala is good", + eval("Scala is good".toUpperCase()) -> "SCALA IS GOOD", + eval("SCALA IS GOOD".toLowerCase()) -> "scala is good", + eval("myemail@gmail.com".matches(".*@gmail.com")) -> true, + + // boolean expressions + eval(identity(true) && false) -> false, + eval(identity(true) || false) -> true, + eval(!identity(true)) -> false, + eval(!identity(false)) -> true, + eval(identity(true) == false) -> false, + + // set expressions + eval(Set(1, 2, 3)) -> Set(1, 2, 3), + eval(Set(1, 2, 3) + 4) -> Set(1, 2, 3, 4), + eval(Set(1, 2, 3) - 2) -> Set(1, 3), + eval(Set(1, 2, 3).contains(2)) -> true, + eval(Set(1, 2, 3).contains(5)) -> false, +// eval(Set(1, 2, 3)(5)) -> false, + + // list expressions + eval(List(1, 2, 3)) -> List(1, 2, 3), + eval(List(1, 2, 3) :+ 4) -> List(1, 2, 3, 4), + eval(5 :: List(1, 2, 3)) -> List(5, 1, 2, 3) +// eval(List(1, 2, 3).head) -> 1 + ) diff --git a/modules/core/src/test/scala/neotype/Use.scala b/modules/core/src/test/scala/neotype/Use.scala index 72ce550..6c961cf 100644 --- a/modules/core/src/test/scala/neotype/Use.scala +++ b/modules/core/src/test/scala/neotype/Use.scala @@ -3,3 +3,12 @@ package neotype object Use extends App: EqualityParsingNewtype("secret string") LessThanParsingNewtype("sectret string") + +//given ArithmeticNewtype: Newtype[Double] with +// inline def validate(input: Double) = +// val y = input +// y - 10 == 0 +// +//object Testing extends App: +// val int = 10.5 +// ArithmeticNewtype(int) // ok diff --git a/modules/neotype-zio-json/src/test/scala/neotype/ziojson/ZioJsonNewtypeSpec.scala b/modules/neotype-zio-json/src/test/scala/neotype/ziojson/ZioJsonNewtypeSpec.scala index 8a50783..6f9d8ff 100644 --- a/modules/neotype-zio-json/src/test/scala/neotype/ziojson/ZioJsonNewtypeSpec.scala +++ b/modules/neotype-zio-json/src/test/scala/neotype/ziojson/ZioJsonNewtypeSpec.scala @@ -24,6 +24,13 @@ given SimpleNewtype: Newtype.Simple[Int] with {} type SimpleSubtype = SimpleSubtype.Type given SimpleSubtype: Subtype.Simple[String] with {} +object LayerTest: + import zio.* + given [A, B](using newType: Newtype.WithType[A, B], tag: Tag[A]): Tag[B] = + newType.unsafeWrapF(tag) + + val layer = ZLayer.succeed(NonEmptyString("Hello")) + object ZioJsonSpec extends ZIOSpecDefault: def spec = suite("ZioJsonSpec")( suite("NonEmptyString")( diff --git a/modules/neotype-zio/src/main/scala/neotype/zio/Main.scala b/modules/neotype-zio/src/main/scala/neotype/zio/Main.scala new file mode 100644 index 0000000..b9dc3d6 --- /dev/null +++ b/modules/neotype-zio/src/main/scala/neotype/zio/Main.scala @@ -0,0 +1,20 @@ +package neotype.zio + +import neotype.* +import _root_.zio.* + +// Newtype +given [A, B](using newType: Newtype.WithType[A, B], tag: Tag[A]): Tag[B] = + newType.unsafeWrapF(tag) + +// Newtype.Simple +given [A, B](using newType: Newtype.Simple.WithType[A, B], tag: Tag[A]): Tag[B] = + newType.unsafeWrapF(tag) + +// Subtype +given [A, B <: A](using subType: Subtype.WithType[A, B], tag: Tag[A]): Tag[B] = + subType.unsafeWrapF(tag) + +// Subtype.Simple +given [A, B <: A](using subType: Subtype.Simple.WithType[A, B], tag: Tag[A]): Tag[B] = + subType.unsafeWrapF(tag) diff --git a/modules/neotype-zio/src/test/scala/neotype/zio/ZioJsonNewtypeSpec.scala b/modules/neotype-zio/src/test/scala/neotype/zio/ZioJsonNewtypeSpec.scala new file mode 100644 index 0000000..bd59463 --- /dev/null +++ b/modules/neotype-zio/src/test/scala/neotype/zio/ZioJsonNewtypeSpec.scala @@ -0,0 +1,68 @@ +package neotype.zio + +import neotype.{Newtype, Subtype} +import zio.test.* +import zio.* + +type MyNewtype = MyNewtype.Type +given MyNewtype: Newtype[String] with + inline def validate(value: String): Boolean = + value.nonEmpty + + override inline def failureMessage = "String must not be empty" + +type MySubtype = MySubtype.Type +given MySubtype: Subtype[String] with + inline def validate(value: String): Boolean = + value.length > 10 + + override inline def failureMessage = "String must be longer than 10 characters" + +type SimpleNewtype = SimpleNewtype.Type +given SimpleNewtype: Newtype.Simple[Int] with {} + +type SimpleSubtype = SimpleSubtype.Type +given SimpleSubtype: Subtype.Simple[String] with {} + +final case class SomeService( + myNewtype: MyNewtype, + simpleNewtype: SimpleNewtype +): + def showAll = ZIO.succeed(s""" + |myNewtype: $myNewtype + |simpleNewtype: $simpleNewtype + |""".stripMargin.trim) + +object SomeService: + val showAll = ZIO.serviceWithZIO[SomeService](_.showAll) + val layer = ZLayer.fromFunction(SomeService.apply _) + +object ZioSpec extends ZIOSpecDefault: + + def spec = + val newtypeLayer: ULayer[MyNewtype] = ZLayer.succeed(MyNewtype("Hello")) + val simpleNewtypeLayer: ULayer[SimpleNewtype] = ZLayer.succeed(SimpleNewtype(1)) +// val subtypeLayer: ULayer[MySubtype] = ZLayer.succeed(MySubtype("Hello World")) +// val combined: ZLayer[Any, Nothing, MyNewtype & SimpleNewtype & MySubtype] = +// newtypeLayer ++ simpleNewtypeLayer ++ subtypeLayer + suite("ZioSpec")( + suite("MyNewtype")( + test("tag materialization") { + SomeService.showAll.map { str => + assertTrue( + str == + """ + |myNewtype: Hello + |simpleNewtype: 1 + |""".stripMargin.trim + ) + } + } + ).provide( + SomeService.layer, + newtypeLayer ++ simpleNewtypeLayer + // TODO: Does not work! +// newtypeLayer, +// simpleNewtypeLayer + ) + )