Skip to content

Commit

Permalink
Merge pull request #82 from KacperFKorban/rework-modifyAll
Browse files Browse the repository at this point in the history
Rework modifyAll for Scala 3
  • Loading branch information
adamw authored Mar 9, 2022
2 parents a2a9b88 + bb191ea commit fdae07d
Show file tree
Hide file tree
Showing 5 changed files with 674 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ object QuicklensMacros {
def modifyLensApplyImpl[T, U](path: Expr[T => U])(using Quotes, Type[T], Type[U]): Expr[PathLazyModify[T, U]] = '{
PathLazyModify { (t, mod) =>
${
toPathModify('t, modifyImpl('t, path))
toPathModify('t, modifyImpl('t, Seq(path)))
}.using(mod)
}
}
Expand All @@ -37,24 +37,18 @@ object QuicklensMacros {
import quotes.reflect.*

val focuses = focusesExpr match {
case Varargs(args) => args
case Varargs(args) => focus +: args
}

val modF1 = modifyImpl(obj, focus)
val modF = { (mod: Expr[A => A]) =>
focuses.foldLeft(from[(A => A), S](modF1).apply(mod)) { case (objAcc, focus) =>
val modCur = modifyImpl(objAcc, focus)
from[(A => A), S](modCur).apply(mod)
}
}
val modF = modifyImpl(obj, focuses)

toPathModify(obj, to(modF))
toPathModify(obj, modF)
}

def toPathModifyFromFocus[S: Type, A: Type](obj: Expr[S], focus: Expr[S => A])(using Quotes): Expr[PathModify[S, A]] =
toPathModify(obj, modifyImpl(obj, focus))
toPathModify(obj, modifyImpl(obj, Seq(focus)))

private def modifyImpl[S: Type, A: Type](obj: Expr[S], focus: Expr[S => A])(using Quotes): Expr[(A => A) => S] = {
private def modifyImpl[S: Type, A: Type](obj: Expr[S], focuses: Seq[Expr[S => A]])(using Quotes): Expr[(A => A) => S] = {
import quotes.reflect.*

def unsupportedShapeInfo(tree: Tree) =
Expand All @@ -63,27 +57,79 @@ object QuicklensMacros {
def methodSupported(method: String) =
Seq("at", "each", "eachWhere", "eachRight", "eachLeft", "atOrElse", "index", "when").contains(method)

enum PathTree:
case Empty
case Node(children: Seq[(PathSymbol, Seq[PathTree])])

def <>(symbols: Seq[PathSymbol]): PathTree = (this, symbols) match
case (PathTree.Empty, _) =>
symbols.toPathTree
case (PathTree.Node(children), (symbol :: Nil)) =>
PathTree.Node {
if children.find(_._1 equiv symbol).isEmpty then
children :+ (symbol -> Seq(PathTree.Empty))
else
children.map {
case (sym, trees) if sym equiv symbol =>
sym -> (trees :+ PathTree.Empty)
case c => c
}
}
case (PathTree.Node(children), Nil) =>
this
case (PathTree.Node(children), (symbol :: tail)) =>
PathTree.Node {
if children.find(_._1 equiv symbol).isEmpty then
children :+ (symbol -> Seq(tail.toPathTree))
else
children.map {
case (sym, trees) if sym equiv symbol =>
sym -> { trees.init ++ { trees.last match
case PathTree.Empty => Seq(PathTree.Empty, tail.toPathTree)
case node => Seq(node <> tail)
}}
case c => c
}
}
end PathTree

object PathTree:
def empty: PathTree = Empty

extension (symbols: Seq[PathSymbol])
def toPathTree: PathTree = symbols match
case Nil => PathTree.Empty
case (symbol :: tail) => PathTree.Node(Seq(symbol -> Seq(tail.toPathTree)))


enum PathSymbol:
case Field(name: String)
case FunctionDelegate(name: String, givn: Term, typeTree: TypeTree, args: List[Term])

def toPath(tree: Tree): Seq[PathSymbol] = {
def equiv(other: Any): Boolean = (this, other) match
case (Field(name1), Field(name2)) => name1 == name2
case (FunctionDelegate(name1, _, typeTree1, args1), FunctionDelegate(name2, _, typeTree2, args2)) =>
name1 == name2 && typeTree1.tpe == typeTree2.tpe && args1 == args2
case _ => false
end PathSymbol

def toPath(tree: Tree, focus: Expr[S => A]): Seq[PathSymbol] = {
tree match {
/** Field access */
case Select(deep, ident) =>
toPath(deep) :+ PathSymbol.Field(ident)
toPath(deep, focus) :+ PathSymbol.Field(ident)
/** Method call with arguments and using clause */
case Apply(Apply(Apply(TypeApply(Ident(s), typeTrees), idents), args), List(givn)) if methodSupported(s) =>
idents.flatMap(toPath) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, args)
idents.flatMap(toPath(_, focus)) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, args)
/** Method call with no arguments and using clause */
case Apply(Apply(TypeApply(Ident(s), typeTrees), idents), List(givn)) if methodSupported(s) =>
idents.flatMap(toPath) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, List.empty)
idents.flatMap(toPath(_, focus)) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, List.empty)
/** Method call with one type parameter and using clause */
case a @ Apply(TypeApply(Apply(TypeApply(Ident(s), _), idents), typeTrees), List(givn)) if methodSupported(s) =>
idents.flatMap(toPath) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, List.empty)
idents.flatMap(toPath(_, focus)) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, List.empty)
/** Field access */
case Apply(deep, idents) =>
toPath(deep) ++ idents.flatMap(toPath)
toPath(deep, focus) ++ idents.flatMap(toPath(_, focus))
/** Wild card from path */
case i: Ident if i.name.startsWith("_") =>
Seq.empty
Expand All @@ -106,33 +152,41 @@ object QuicklensMacros {
owner: Symbol,
mod: Expr[A => A],
obj: Term,
field: PathSymbol.Field,
tail: Seq[PathSymbol]
): Term =
fields: Seq[(PathSymbol.Field, Seq[PathTree])]
): Term = {
val objSymbol = obj.tpe.typeSymbol
if objSymbol.flags.is(Flags.Case) then
if objSymbol.flags.is(Flags.Case) then {
val copy = termMethodByNameUnsafe(obj, "copy")
val (fieldMethod, idx) = termAccessorMethodByNameUnsafe(obj, field.name)
val namedArg = NamedArg(field.name, mapToCopy(owner, mod, Select(obj, fieldMethod), tail))
val argsMap: Map[Int, Term] = fields.map { (field, trees) =>
val (fieldMethod, idx) = termAccessorMethodByNameUnsafe(obj, field.name)
val resTerm: Term = trees.foldLeft[Term](Select(obj, fieldMethod)) { (term, tree) =>
mapToCopy(owner, mod, term, tree)
}
val namedArg = NamedArg(field.name, resTerm)
idx -> namedArg
}.toMap

val fieldsIdxs = 1.to(obj.tpe.typeSymbol.caseFields.length)
val args = fieldsIdxs.map { i =>
if i == idx then namedArg
else Select(obj, termMethodByNameUnsafe(obj, "copy$default$" + i.toString))
argsMap.getOrElse(
i,
Select(obj, termMethodByNameUnsafe(obj, "copy$default$" + i.toString))
)
}.toList

obj.tpe.widen match {
// if the object's type is parametrised, we need to call .copy with the same type parameters
case AppliedType(_, typeParams) => Apply(TypeApply(Select(obj, copy), typeParams.map(Inferred(_))), args)
case _ => Apply(Select(obj, copy), args)
}
else if objSymbol.flags.is(Flags.Enum) ||
} else if objSymbol.flags.is(Flags.Enum) ||
(objSymbol.flags.is(Flags.Sealed) && (objSymbol.flags.is(Flags.Trait) || objSymbol.flags.is(Flags.Abstract)))
then
then {
// if the source is a sealed trait / sealed abstract class / enum, generating a if-then-else with a .copy for each child (implementing case class)
val cases = obj.tpe.typeSymbol.children.map { child =>
val subtype = TypeIdent(child)
val bind = Symbol.newBind(owner, "c", Flags.EmptyFlags, subtype.tpe)
CaseDef(Bind(bind, Typed(Ref(bind), subtype)), None, caseClassCopy(owner, mod, Ref(bind), field, tail))
CaseDef(Bind(bind, Typed(Ref(bind), subtype)), None, caseClassCopy(owner, mod, Ref(bind), fields))
}

/*
Expand All @@ -146,7 +200,7 @@ object QuicklensMacros {

val ifThen = ValDef.let(owner, TypeApply(Select.unique(obj, "asInstanceOf"), List(TypeIdent(child)))) {
castToChildVal =>
caseClassCopy(owner, mod, castToChildVal, field, tail)
caseClassCopy(owner, mod, castToChildVal, fields)
}

ifCond -> ifThen
Expand All @@ -156,54 +210,87 @@ object QuicklensMacros {
ifThens.foldRight(elseThrow) { case ((ifCond, ifThen), ifElse) =>
If(ifCond, ifThen, ifElse)
}
else report.throwError(s"Unsupported source object: must be a case class or sealed trait, but got: $objSymbol")
} else
report.throwError(s"Unsupported source object: must be a case class or sealed trait, but got: $objSymbol")
}

def applyFunctionDelegate(
owner: Symbol,
mod: Expr[A => A],
objTerm: Term,
f: PathSymbol.FunctionDelegate,
tree: PathTree
): Term =
val defdefSymbol = Symbol.newMethod(
owner,
"$anonfun",
MethodType(List("x"))(_ => List(f.typeTree.tpe), _ => f.typeTree.tpe)
)
val fMethod = termMethodByNameUnsafe(f.givn, f.name)
val fun = TypeApply(
Select(f.givn, fMethod),
List(f.typeTree)
)
val defdefStatements = DefDef(
defdefSymbol,
{ case List(List(x)) => Some(mapToCopy(defdefSymbol, mod, x.asExpr.asTerm, tree)) }
)
val closure = Closure(Ref(defdefSymbol), None)
val block = Block(List(defdefStatements), closure)
Apply(fun, List(objTerm, block) ++ f.args)

def accumulateToCopy(
owner: Symbol,
mod: Expr[A => A],
objTerm: Term,
pathSymbols: Seq[(PathSymbol, Seq[PathTree])]
): Term = pathSymbols match {

def mapToCopy(owner: Symbol, mod: Expr[A => A], objTerm: Term, path: Seq[PathSymbol]): Term = path match
case Nil =>
val apply = termMethodByNameUnsafe(mod.asTerm, "apply")
Apply(Select(mod.asTerm, apply), List(objTerm))
case (field: PathSymbol.Field) :: tail =>
caseClassCopy(owner, mod, objTerm, field, tail)
objTerm

case (_: PathSymbol.Field, _) :: _ =>
val (fs, funs) = pathSymbols.span(_._1.isInstanceOf[PathSymbol.Field])
val fields = fs.collect { case (p: PathSymbol.Field, trees) => p -> trees }
val withCopiedFields: Term = caseClassCopy(owner, mod, objTerm, fields)
accumulateToCopy(owner, mod, withCopiedFields, funs)

/** For FunctionDelegate(method, givn, T, args)
*
* Generates: `givn.method[T](obj, x => mapToCopy(...), ...args)`
*/
case (f: PathSymbol.FunctionDelegate) :: tail =>
val defdefSymbol = Symbol.newMethod(
owner,
"$anonfun",
MethodType(List("x"))(_ => List(f.typeTree.tpe), _ => f.typeTree.tpe)
)
val fMethod = termMethodByNameUnsafe(f.givn, f.name)
val fun = TypeApply(
Select(f.givn, fMethod),
List(f.typeTree)
)
val defdefStatements = DefDef(
defdefSymbol,
{ case List(List(x)) =>
Some(mapToCopy(defdefSymbol, mod, x.asExpr.asTerm, tail))
}
)
val closure = Closure(Ref(defdefSymbol), None)
val block = Block(List(defdefStatements), closure)
Apply(fun, List(objTerm, block) ++ f.args)
case (f: PathSymbol.FunctionDelegate, actions: Seq[PathTree]) :: tail =>
val term = actions.foldLeft(objTerm) { (term, tree) =>
applyFunctionDelegate(owner, mod, term, f, tree)
}
accumulateToCopy(owner, mod, term, tail)
}

val focusTree: Tree = focus.asTerm
val path = focusTree match {
def mapToCopy(owner: Symbol, mod: Expr[A => A], objTerm: Term, pathTree: PathTree): Term = pathTree match {
case PathTree.Empty =>
val apply = termMethodByNameUnsafe(mod.asTerm, "apply")
Apply(Select(mod.asTerm, apply), List(objTerm))
case PathTree.Node(children) =>
accumulateToCopy(owner, mod, objTerm, children)
}

val focusesTrees: Seq[Tree] = focuses.map(_.asTerm)
val paths: Seq[Seq[PathSymbol]] = focusesTrees.zip(focuses).map { (tree, focus) => tree match
/** Single inlined path */
case Inlined(_, _, Block(List(DefDef(_, _, _, Some(p))), _)) =>
toPath(p)
toPath(p, focus)
/** One of paths from modifyAll */
case Block(List(DefDef(_, _, _, Some(p))), _) =>
toPath(p)
toPath(p, focus)
case _ =>
report.throwError(unsupportedShapeInfo(focusTree))
report.throwError(unsupportedShapeInfo(tree))
}

val pathTree: PathTree =
paths.foldLeft(PathTree.empty) { (tree, path) => tree <> path }

val res: (Expr[A => A] => Expr[S]) = (mod: Expr[A => A]) =>
mapToCopy(Symbol.spliceOwner, mod, obj.asTerm, path).asExpr.asInstanceOf[Expr[S]]
mapToCopy(Symbol.spliceOwner, mod, obj.asTerm, pathTree).asExpr.asInstanceOf[Expr[S]]
to(res)
}
}
Loading

0 comments on commit fdae07d

Please sign in to comment.