Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework modifyAll for Scala 3 #82

Merged
merged 6 commits into from
Mar 9, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ object QuicklensMacros {
def modifyLensApplyImpl[T, U](path: Expr[T => U])(using Quotes, Type[T], Type[U]): Expr[PathLazyModify[T, U]] = '{
PathLazyModify { (t, mod) =>
${
toPathModify('t, modifyImpl('t, path))
toPathModify('t, modifyImpl('t, Seq(path)))
}.using(mod)
}
}
Expand All @@ -37,24 +37,18 @@ object QuicklensMacros {
import quotes.reflect.*

val focuses = focusesExpr match {
case Varargs(args) => args
case Varargs(args) => focus +: args
}

val modF1 = modifyImpl(obj, focus)
val modF = { (mod: Expr[A => A]) =>
focuses.foldLeft(from[(A => A), S](modF1).apply(mod)) { case (objAcc, focus) =>
val modCur = modifyImpl(objAcc, focus)
from[(A => A), S](modCur).apply(mod)
}
}
val modF = modifyImpl(obj, focuses)

toPathModify(obj, to(modF))
toPathModify(obj, modF)
}

def toPathModifyFromFocus[S: Type, A: Type](obj: Expr[S], focus: Expr[S => A])(using Quotes): Expr[PathModify[S, A]] =
toPathModify(obj, modifyImpl(obj, focus))
toPathModify(obj, modifyImpl(obj, Seq(focus)))

private def modifyImpl[S: Type, A: Type](obj: Expr[S], focus: Expr[S => A])(using Quotes): Expr[(A => A) => S] = {
private def modifyImpl[S: Type, A: Type](obj: Expr[S], focuses: Seq[Expr[S => A]])(using Quotes): Expr[(A => A) => S] = {
import quotes.reflect.*

def unsupportedShapeInfo(tree: Tree) =
Expand All @@ -63,27 +57,71 @@ object QuicklensMacros {
def methodSupported(method: String) =
Seq("at", "each", "eachWhere", "eachRight", "eachLeft", "atOrElse", "index", "when").contains(method)

enum PathTree:
case Empty
case Node(children: Seq[(PathSymbol, Seq[PathTree])])

object PathTree:
def empty: PathTree = Empty

extension (symbols: Seq[PathSymbol])
def toPathTree: PathTree = symbols match
case Nil => PathTree.Empty
case (symbol :: tail) => PathTree.Node(Seq(symbol -> Seq(tail.toPathTree)))

extension (t1: PathTree)
def <>(symbols: Seq[PathSymbol]): PathTree = (t1, symbols) match
case (PathTree.Empty, _) =>
symbols.toPathTree
case (PathTree.Node(children), (symbol :: Nil)) =>
PathTree.Node {
if children.find(_._1 == symbol).isEmpty then
children :+ (symbol -> Seq(PathTree.Empty))
else
children.map {
case (sym, trees) if sym == symbol =>
sym -> (trees :+ PathTree.Empty)
case c => c
}
}
case (PathTree.Node(children), Nil) =>
t1
case (PathTree.Node(children), (symbol :: tail)) =>
PathTree.Node {
if children.find(_._1 == symbol).isEmpty then
children :+ (symbol -> Seq(tail.toPathTree))
else
children.map {
case (sym, trees) if sym == symbol =>
sym -> { trees.init ++ { trees.last match
case PathTree.Empty => Seq(PathTree.Empty, tail.toPathTree)
case node => Seq(node <> tail)
}}
case c => c
}
}

enum PathSymbol:
case Field(name: String)
case FunctionDelegate(name: String, givn: Term, typeTree: TypeTree, args: List[Term])
case FunctionDelegate(name: String, givn: Term, typeTree: TypeTree, args: List[Term]) //TODO probably have to override equals (and hashCode)

def toPath(tree: Tree): Seq[PathSymbol] = {
def toPath(tree: Tree, focus: Expr[S => A]): Seq[PathSymbol] = {
tree match {
/** Field access */
case Select(deep, ident) =>
toPath(deep) :+ PathSymbol.Field(ident)
toPath(deep, focus) :+ PathSymbol.Field(ident)
/** Method call with arguments and using clause */
case Apply(Apply(Apply(TypeApply(Ident(s), typeTrees), idents), args), List(givn)) if methodSupported(s) =>
idents.flatMap(toPath) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, args)
idents.flatMap(toPath(_, focus)) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, args)
/** Method call with no arguments and using clause */
case Apply(Apply(TypeApply(Ident(s), typeTrees), idents), List(givn)) if methodSupported(s) =>
idents.flatMap(toPath) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, List.empty)
idents.flatMap(toPath(_, focus)) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, List.empty)
/** Method call with one type parameter and using clause */
case a @ Apply(TypeApply(Apply(TypeApply(Ident(s), _), idents), typeTrees), List(givn)) if methodSupported(s) =>
idents.flatMap(toPath) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, List.empty)
idents.flatMap(toPath(_, focus)) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, List.empty)
/** Field access */
case Apply(deep, idents) =>
toPath(deep) ++ idents.flatMap(toPath)
toPath(deep, focus) ++ idents.flatMap(toPath(_, focus))
/** Wild card from path */
case i: Ident if i.name.startsWith("_") =>
Seq.empty
Expand All @@ -106,33 +144,41 @@ object QuicklensMacros {
owner: Symbol,
mod: Expr[A => A],
obj: Term,
field: PathSymbol.Field,
tail: Seq[PathSymbol]
): Term =
fields: Seq[(PathSymbol.Field, Seq[PathTree])]
): Term = {
val objSymbol = obj.tpe.typeSymbol
if objSymbol.flags.is(Flags.Case) then
if objSymbol.flags.is(Flags.Case) then {
val copy = termMethodByNameUnsafe(obj, "copy")
val (fieldMethod, idx) = termAccessorMethodByNameUnsafe(obj, field.name)
val namedArg = NamedArg(field.name, mapToCopy(owner, mod, Select(obj, fieldMethod), tail))
val argsMap: Map[Int, Term] = fields.map { (field, trees) =>
val (fieldMethod, idx) = termAccessorMethodByNameUnsafe(obj, field.name)
val resTerm: Term = trees.foldLeft[Term](Select(obj, fieldMethod)) { (term, tree) =>
mapToCopy(owner, mod, term, tree)
}
val namedArg = NamedArg(field.name, resTerm)
idx -> namedArg
}.toMap

val fieldsIdxs = 1.to(obj.tpe.typeSymbol.caseFields.length)
val args = fieldsIdxs.map { i =>
if i == idx then namedArg
else Select(obj, termMethodByNameUnsafe(obj, "copy$default$" + i.toString))
argsMap.getOrElse(
i,
Select(obj, termMethodByNameUnsafe(obj, "copy$default$" + i.toString))
)
}.toList

obj.tpe.widen match {
// if the object's type is parametrised, we need to call .copy with the same type parameters
case AppliedType(_, typeParams) => Apply(TypeApply(Select(obj, copy), typeParams.map(Inferred(_))), args)
case _ => Apply(Select(obj, copy), args)
}
else if objSymbol.flags.is(Flags.Enum) ||
} else if objSymbol.flags.is(Flags.Enum) ||
(objSymbol.flags.is(Flags.Sealed) && (objSymbol.flags.is(Flags.Trait) || objSymbol.flags.is(Flags.Abstract)))
then
then {
// if the source is a sealed trait / sealed abstract class / enum, generating a if-then-else with a .copy for each child (implementing case class)
val cases = obj.tpe.typeSymbol.children.map { child =>
val subtype = TypeIdent(child)
val bind = Symbol.newBind(owner, "c", Flags.EmptyFlags, subtype.tpe)
CaseDef(Bind(bind, Typed(Ref(bind), subtype)), None, caseClassCopy(owner, mod, Ref(bind), field, tail))
CaseDef(Bind(bind, Typed(Ref(bind), subtype)), None, caseClassCopy(owner, mod, Ref(bind), fields))
}

/*
Expand All @@ -146,7 +192,7 @@ object QuicklensMacros {

val ifThen = ValDef.let(owner, TypeApply(Select.unique(obj, "asInstanceOf"), List(TypeIdent(child)))) {
castToChildVal =>
caseClassCopy(owner, mod, castToChildVal, field, tail)
caseClassCopy(owner, mod, castToChildVal, fields)
}

ifCond -> ifThen
Expand All @@ -156,54 +202,87 @@ object QuicklensMacros {
ifThens.foldRight(elseThrow) { case ((ifCond, ifThen), ifElse) =>
If(ifCond, ifThen, ifElse)
}
else report.throwError(s"Unsupported source object: must be a case class or sealed trait, but got: $objSymbol")
} else
report.throwError(s"Unsupported source object: must be a case class or sealed trait, but got: $objSymbol")
}

def applyFunctionDelegate(
owner: Symbol,
mod: Expr[A => A],
objTerm: Term,
f: PathSymbol.FunctionDelegate,
tree: PathTree
): Term =
val defdefSymbol = Symbol.newMethod(
owner,
"$anonfun",
MethodType(List("x"))(_ => List(f.typeTree.tpe), _ => f.typeTree.tpe)
)
val fMethod = termMethodByNameUnsafe(f.givn, f.name)
val fun = TypeApply(
Select(f.givn, fMethod),
List(f.typeTree)
)
val defdefStatements = DefDef(
defdefSymbol,
{ case List(List(x)) => Some(mapToCopy(defdefSymbol, mod, x.asExpr.asTerm, tree)) }
)
val closure = Closure(Ref(defdefSymbol), None)
val block = Block(List(defdefStatements), closure)
Apply(fun, List(objTerm, block) ++ f.args)

def accumulateToCopy(
owner: Symbol,
mod: Expr[A => A],
objTerm: Term,
pathSymbols: Seq[(PathSymbol, Seq[PathTree])]
): Term = pathSymbols match {

def mapToCopy(owner: Symbol, mod: Expr[A => A], objTerm: Term, path: Seq[PathSymbol]): Term = path match
case Nil =>
val apply = termMethodByNameUnsafe(mod.asTerm, "apply")
Apply(Select(mod.asTerm, apply), List(objTerm))
case (field: PathSymbol.Field) :: tail =>
caseClassCopy(owner, mod, objTerm, field, tail)
objTerm

case (_: PathSymbol.Field, _) :: _ =>
val (fs, funs) = pathSymbols.span(_._1.isInstanceOf[PathSymbol.Field])
val fields = fs.collect { case (p: PathSymbol.Field, trees) => p -> trees }
val withCopiedFields: Term = caseClassCopy(owner, mod, objTerm, fields)
accumulateToCopy(owner, mod, withCopiedFields, funs)

/** For FunctionDelegate(method, givn, T, args)
*
* Generates: `givn.method[T](obj, x => mapToCopy(...), ...args)`
*/
case (f: PathSymbol.FunctionDelegate) :: tail =>
val defdefSymbol = Symbol.newMethod(
owner,
"$anonfun",
MethodType(List("x"))(_ => List(f.typeTree.tpe), _ => f.typeTree.tpe)
)
val fMethod = termMethodByNameUnsafe(f.givn, f.name)
val fun = TypeApply(
Select(f.givn, fMethod),
List(f.typeTree)
)
val defdefStatements = DefDef(
defdefSymbol,
{ case List(List(x)) =>
Some(mapToCopy(defdefSymbol, mod, x.asExpr.asTerm, tail))
}
)
val closure = Closure(Ref(defdefSymbol), None)
val block = Block(List(defdefStatements), closure)
Apply(fun, List(objTerm, block) ++ f.args)
case (f: PathSymbol.FunctionDelegate, actions: Seq[PathTree]) :: tail =>
val term = actions.foldLeft(objTerm) { (term, tree) =>
applyFunctionDelegate(owner, mod, term, f, tree)
}
accumulateToCopy(owner, mod, term, tail)
}

val focusTree: Tree = focus.asTerm
val path = focusTree match {
def mapToCopy(owner: Symbol, mod: Expr[A => A], objTerm: Term, pathTree: PathTree): Term = pathTree match {
case PathTree.Empty =>
val apply = termMethodByNameUnsafe(mod.asTerm, "apply")
Apply(Select(mod.asTerm, apply), List(objTerm))
case PathTree.Node(children) =>
accumulateToCopy(owner, mod, objTerm, children)
}

val focusesTrees: Seq[Tree] = focuses.map(_.asTerm)
val paths: Seq[Seq[PathSymbol]] = focusesTrees.zip(focuses).map { (tree, focus) => tree match
/** Single inlined path */
case Inlined(_, _, Block(List(DefDef(_, _, _, Some(p))), _)) =>
toPath(p)
toPath(p, focus)
/** One of paths from modifyAll */
case Block(List(DefDef(_, _, _, Some(p))), _) =>
toPath(p)
toPath(p, focus)
case _ =>
report.throwError(unsupportedShapeInfo(focusTree))
report.throwError(unsupportedShapeInfo(tree))
}

val pathTree: PathTree =
paths.foldLeft(PathTree.empty) { (tree, path) => tree <> path }

val res: (Expr[A => A] => Expr[S]) = (mod: Expr[A => A]) =>
mapToCopy(Symbol.spliceOwner, mod, obj.asTerm, path).asExpr.asInstanceOf[Expr[S]]
mapToCopy(Symbol.spliceOwner, mod, obj.asTerm, pathTree).asExpr.asInstanceOf[Expr[S]]
to(res)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package com.softwaremill.quicklens

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

class EnormousModifyAllTest extends AnyFlatSpec with Matchers {
import EnormousModifyAllTest._

it should "expand an enormous function" in {
val c6 = C6(1)
val c5 = C5(c6, c6, c6, c6)
val c4 = C4(c5, c5, c5, c5)
val c3 = C3(c4, c4, c4, c4)
val c2 = C2(c3, c3, c3, c3)
val c1 = C1(c2, c2, c2, c2)

val c6e = C6(2)
val c5e = C5(c6e, c6, c6, c6)
val c4e = C4(c5e, c5, c5, c5)
val c3e = C3(c4e, c4, c4, c4)
val c2e = C2(c3e, c3, c3, c3)
val c1e = C1(c2e, c2e, c2e, c2e)

val res = c1
.modifyAll(
_.a.a.a.a.a.a,
_.b.a.a.a.a.a,
_.c.a.a.a.a.a,
_.d.a.a.a.a.a
).using(_ + 1)
res should be(c1e)
}
}

object EnormousModifyAllTest {
case class C1(
a: C2,
b: C2,
c: C2,
d: C2
)

case class C2(
a: C3,
b: C3,
c: C3,
d: C3
)

case class C3(
a: C4,
b: C4,
c: C4,
d: C4
)

case class C4(
a: C5,
b: C5,
c: C5,
d: C5
)

case class C5(
a: C6,
b: C6,
c: C6,
d: C6
)

case class C6(
a: Int
)
}
Loading