Skip to content

Commit

Permalink
feat: Add List compile-time support
Browse files Browse the repository at this point in the history
  • Loading branch information
Iltotore committed Sep 13, 2024
1 parent 209c864 commit 6b2f8b5
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 20 deletions.
22 changes: 19 additions & 3 deletions main/src/io/github/iltotore/iron/constraint/collection.scala
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ object collection:

class LengthIterable[I <: Iterable[?], C, Impl <: Constraint[Int, C]](using Impl) extends Constraint[I, Length[C]]:

override inline def test(inline value: I): Boolean = summonInline[Impl].test(value.size)
override inline def test(inline value: I): Boolean = ${ checkIterable('value, '{ summonInline[Impl] }) }

override inline def message: String = "Length: (" + summonInline[Impl].message + ")"

Expand All @@ -111,6 +111,14 @@ object collection:

inline given lengthString[C, Impl <: Constraint[Int, C]](using inline impl: Impl): LengthString[C, Impl] = new LengthString

private def checkIterable[I <: Iterable[?] : Type, C, Impl <: Constraint[Int, C]](expr: Expr[I], constraintExpr: Expr[Impl])(using Quotes): Expr[Boolean] =
val rflUtil = reflectUtil
import rflUtil.*

expr.decode match
case Right(value) => applyConstraint(Expr(value.size), constraintExpr)
case _ => applyConstraint('{ $expr.size }, constraintExpr)

private def checkString[C, Impl <: Constraint[Int, C]](expr: Expr[String], constraintExpr: Expr[Impl])(using Quotes): Expr[Boolean] =
val rflUtil = reflectUtil
import rflUtil.*
Expand All @@ -124,16 +132,24 @@ object collection:
object Contain:
inline given [A, V <: A, I <: Iterable[A]]: Constraint[I, Contain[V]] with

override inline def test(inline value: I): Boolean = value.iterator.contains(constValue[V])
override inline def test(inline value: I): Boolean = ${ checkIterable('value, '{ constValue[V] }) }

override inline def message: String = "Should contain at most " + stringValue[V] + " elements"
override inline def message: String = "Should contain the value " + stringValue[V]

inline given [V <: String]: Constraint[String, Contain[V]] with

override inline def test(inline value: String): Boolean = ${ checkString('value, '{ constValue[V] }) }

override inline def message: String = "Should contain the string " + constValue[V]

private def checkIterable[I <: Iterable[?] : Type, V : Type](expr: Expr[I], partExpr: Expr[V])(using Quotes): Expr[Boolean] =
val rflUtil = reflectUtil
import rflUtil.*

(expr.decode, partExpr.decode) match
case (Right(value), Right(part)) => Expr(value.iterator.contains(part))
case _ => '{ ${ expr }.iterator.contains($partExpr) }

private def checkString(expr: Expr[String], partExpr: Expr[String])(using Quotes): Expr[Boolean] =
val rflUtil = reflectUtil
import rflUtil.*
Expand Down
48 changes: 31 additions & 17 deletions main/src/io/github/iltotore/iron/macros/ReflectUtil.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):

type DecodingResult[+T] = Either[DecodingFailure, T]
extension [T](result: DecodingResult[T])
private def as[U]: DecodingResult[U] = result.asInstanceOf[Either[DecodingFailure, U]]
private def as[U]: DecodingResult[U] = result.asInstanceOf[DecodingResult[U]]

extension [T: Type](expr: Expr[T])
/**
* Decode this expression.
*
* @return the value of this expression found at compile time or a [[DecodingFailure]]
*/
def decode: Either[DecodingFailure, T] = ExprDecoder.decodeTerm(expr.asTerm, Map.empty).as[T]
def decode: DecodingResult[T] = ExprDecoder.decodeTerm(expr.asTerm, Map.empty).as[T]

/**
* A decoding failure.
Expand Down Expand Up @@ -71,32 +71,32 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):
*
* @param parameters the list of decoded parameters, whether an failure or a value of unknown type
*/
case ApplyNotInlined(name: String, parameters: List[Either[DecodingFailure, ?]])
case ApplyNotInlined(name: String, parameters: List[DecodingResult[?]])

case VarArgsNotInlined(args: List[Either[DecodingFailure, ?]])
case VarArgsNotInlined(args: List[DecodingResult[?]])

/**
* A boolean OR is not inlined.
*
* @param left the left operand
* @param right the right operand
*/
case OrNotInlined(left: Either[DecodingFailure, Boolean], right: Either[DecodingFailure, Boolean])
case OrNotInlined(left: DecodingResult[Boolean], right: Either[DecodingFailure, Boolean])

/**
* A boolean AND is not inlined.
*
* @param left the left operand
* @param right the right operand
*/
case AndNotInlined(left: Either[DecodingFailure, Boolean], right: Either[DecodingFailure, Boolean])
case AndNotInlined(left: DecodingResult[Boolean], right: Either[DecodingFailure, Boolean])

/**
* Some part of the decoded String are not inlined. A more specialized version of [[ApplyNotInlined]].
*
* @param parts the parts of the String
*/
case StringPartsNotInlined(parts: List[Either[DecodingFailure, String]])
case StringPartsNotInlined(parts: List[DecodingResult[String]])

/**
* The given String interpolator cannot be inlined.
Expand Down Expand Up @@ -185,10 +185,11 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):

object ExprDecoder:

private val enhancedDecoders: Map[TypeRepr, (Term, Map[String, ?]) => Either[DecodingFailure, ?]] = Map(
private val enhancedDecoders: Map[TypeRepr, (Term, Map[String, ?]) => DecodingResult[?]] = Map(
TypeRepr.of[Boolean] -> decodeBoolean,
TypeRepr.of[BigDecimal] -> decodeBigDecimal,
TypeRepr.of[BigInt] -> decodeBigInt,
TypeRepr.of[List[?]] -> decodeList,
TypeRepr.of[String] -> decodeString
)

Expand All @@ -199,10 +200,10 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):
* @param definitions the decoded definitions in scope
* @return the value of the given term found at compile time or a [[DecodingFailure]]
*/
def decodeTerm(tree: Term, definitions: Map[String, ?]): Either[DecodingFailure, ?] =
def decodeTerm(tree: Term, definitions: Map[String, ?]): DecodingResult[?] =
val specializedResult = enhancedDecoders
.collectFirst:
case (k, v) if k =:= tree.tpe => v
case (k, v) if tree.tpe <:< k => v
.toRight(DecodingFailure.Unknown)
.flatMap(_.apply(tree, definitions))

Expand All @@ -218,13 +219,13 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):
* @tparam T the expected type of this term used as implicit cast for convenience
* @return the value of the given term found at compile time or a [[DecodingFailure]]
*/
def decodeUnspecializedTerm(tree: Term, definitions: Map[String, ?]): Either[DecodingFailure, ?] =
def decodeUnspecializedTerm(tree: Term, definitions: Map[String, ?]): DecodingResult[?] =
tree match
case block @ Block(stats, e) => if stats.isEmpty then decodeTerm(e, definitions) else Left(DecodingFailure.HasStatements(block))

case Inlined(_, bindings, e) =>
val (failures, values) = bindings
.map[(String, Either[DecodingFailure, ?])](b => (b.name, decodeBinding(b, definitions)))
.map[(String, DecodingResult[?])](b => (b.name, decodeBinding(b, definitions)))
.partitionMap:
case (name, Right(value)) => Right((name, value))
case (name, Left(failure)) => Left((name, failure))
Expand Down Expand Up @@ -281,7 +282,7 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):
* @tparam T the expected type of this term used as implicit cast for convenience
* @return the value of the given definition found at compile time or a [[DecodingFailure]]
*/
def decodeBinding(definition: Definition, definitions: Map[String, ?]): Either[DecodingFailure, ?] = definition match
def decodeBinding(definition: Definition, definitions: Map[String, ?]): DecodingResult[?] = definition match
case ValDef(name, tpeTree, Some(term)) => decodeTerm(term, definitions)
case DefDef(name, Nil, tpeTree, Some(term)) => decodeTerm(term, definitions)
case _ => Left(DecodingFailure.DefinitionNotInlined(definition.name))
Expand All @@ -293,7 +294,7 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):
* @param definitions the decoded definitions in scope
* @return the value of the given term found at compile time or a [[DecodingFailure]]
*/
def decodeBoolean(term: Term, definitions: Map[String, ?]): Either[DecodingFailure, ?] = term match
def decodeBoolean(term: Term, definitions: Map[String, ?]): DecodingResult[?] = term match
case Apply(Select(left, "||"), List(right)) if left.tpe <:< TypeRepr.of[Boolean] && right.tpe <:< TypeRepr.of[Boolean] => // OR
(decodeTerm(left, definitions).as[Boolean], decodeTerm(right, definitions).as[Boolean]) match
case (Right(true), _) => Right(true)
Expand All @@ -317,7 +318,7 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):
* @param definitions the decoded definitions in scope
* @return the value of the given term found at compile time or a [[DecodingFailure]]
*/
def decodeString(term: Term, definitions: Map[String, ?]): Either[DecodingFailure, String] = term match
def decodeString(term: Term, definitions: Map[String, ?]): DecodingResult[String] = term match
case Apply(Select(left, "+"), List(right)) if left.tpe <:< TypeRepr.of[String] && right.tpe <:< TypeRepr.of[String] =>
(decodeTerm(left, definitions).as[String], decodeTerm(right, definitions).as[String]) match
case (Right(leftValue), Right(rightValue)) => Right(leftValue + rightValue)
Expand All @@ -338,7 +339,7 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):
* @param definitions the decoded definitions in scope
* @return the value of the given term found at compile time or a [[DecodingFailure]]
*/
def decodeBigInt(term: Term, definitions: Map[String, ?]): Either[DecodingFailure, BigInt] =
def decodeBigInt(term: Term, definitions: Map[String, ?]): DecodingResult[BigInt] =
term match
case Apply(Select(Ident("BigInt"), "apply"), List(value)) =>
decodeTerm(value, definitions).as[Int | Long].map:
Expand All @@ -353,7 +354,7 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):
* @param definitions the decoded definitions in scope
* @return the value of the given term found at compile time or a [[DecodingFailure]]
*/
def decodeBigDecimal(term: Term, definitions: Map[String, ?]): Either[DecodingFailure, BigDecimal] =
def decodeBigDecimal(term: Term, definitions: Map[String, ?]): DecodingResult[BigDecimal] =
term match
case Apply(Select(Ident("BigDecimal"), "apply"), List(value)) =>
decodeTerm(value, definitions).as[NumConstant].map:
Expand All @@ -363,3 +364,16 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):
case x: Double => BigDecimal(x)

case _ => Left(DecodingFailure.Unknown)

/**
* Decode a [[List]] term using only [[List]]-specific cases.
*
* @param term the term to decode
* @param definitions the decoded definitions in scope
* @return the value of the given term found at compile time or a [[DecodingFailure]]
*/
def decodeList(term: Term, definitions: Map[String, ?]): DecodingResult[List[?]] =
term match
case Apply(TypeApply(Select(Ident("List"), "apply"), _), List(values)) =>
decodeTerm(values, definitions).as[List[?]]
case _ => Left(DecodingFailure.Unknown)

0 comments on commit 6b2f8b5

Please sign in to comment.