Skip to content

Commit

Permalink
Find extension copy in companion
Browse files Browse the repository at this point in the history
  • Loading branch information
OndrejSpanel committed Dec 2, 2024
1 parent cf75c7f commit 0ba4cf8
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 51 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ excludeLintKeys in Global ++= Set(ideSkipProject)
val commonSettings = commonSmlBuildSettings ++ ossPublishSettings ++ Seq(
organization := "com.softwaremill.quicklens",
updateDocs := UpdateVersionInDocs(sLog.value, organization.value, version.value, List(file("README.md"))),
scalacOptions ++= Seq("-deprecation", "-feature", "-unchecked"), // useful for debugging macros: "-Ycheck:all"
scalacOptions ++= Seq("-deprecation", "-feature", "-unchecked"), // useful for debugging macros: "-Ycheck:all", "-Xcheck-macros"
ideSkipProject := (scalaVersion.value != scalaIdeaVersion)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ object QuicklensMacros {
/** Method call with one type parameter and using clause */
case a @ Apply(TypeApply(Apply(TypeApply(Ident(s), _), idents), typeTrees), List(givn)) if methodSupported(s) =>
idents.flatMap(toPath(_, focus)) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, List.empty)
case Apply(Ident(ident), Seq(deep)) => // this is an extension method, which is called e.g. as x(_$1)
toPath(deep, focus) :+ PathSymbol.Field(ident)
/** Field access */
case Apply(deep, idents) =>
toPath(deep, focus) ++ idents.flatMap(toPath(_, focus))
Expand Down Expand Up @@ -179,21 +181,77 @@ object QuicklensMacros {
case lst => report.errorAndAbort(multipleMatchingMethods(sym.name, name, lst))
}

def methodSymbolByNameAndArgsOrError(sym: Symbol, name: String, argsMap: Map[String, Term]): Symbol = {
def filterMethodsByNameAndArgs(allMethods: List[Symbol], argsMap: Map[String, Term]): Option[Symbol] = {
val argNames = argsMap.keys
sym.methodMember(name).filter{ msym =>
allMethods.filter { msym =>
// for copy, we filter out the methods that don't have the desired parameter names
val paramNames = msym.paramSymss.flatten.filter(_.isTerm).map(_.name)
argNames.forall(paramNames.contains)
} match
case List(m) => m
case Nil => report.errorAndAbort(noSuchMember(sym.name, name))
case lst @ (m :: _) =>
case List(m) => Some(m)
case Nil => None
case lst@(m :: _) =>
// if we have multiple matching copy methods, pick the synthetic one, if it exists, otherwise, pick any method
val syntheticCopies = lst.filter(_.flags.is(Flags.Synthetic))
syntheticCopies match
case List(mSynth) => mSynth
case _ => m
case List(mSynth) => Some(mSynth)
case _ => Some(m)
}

def methodSymbolByNameAndArgs(sym: Symbol, name: String, argsMap: Map[String, Term]): Option[Symbol] = {
val memberMethods = sym.methodMember(name)
filterMethodsByNameAndArgs(memberMethods, argsMap)
}

/**
* @param argsMap normal methods receive one parameter list, extensions methods two, the first one contains the value
* on which the extension is called
* */
def callMethod(obj: Term, copy: Symbol, argsMap: List[Map[String, Term]]) = {
val objTpe = obj.tpe.widenAll
val objSymbol = objTpe.matchingTypeSymbol

val typeParams = objTpe match {
case AppliedType(_, typeParams) => Some(typeParams)
case _ => None
}
val copyTree: DefDef = copy.tree.asInstanceOf[DefDef]
val copyParams: List[(String, Option[Term])] = copyTree.termParamss.zip(argsMap)
.map((params, args) => params.params.map(_.name).map(name => name -> args.get(name)))
.flatten.toList

val args = copyParams.zipWithIndex.map { case ((n, v), _i) =>
val i = _i + 1
def defaultMethod =
val methodSymbol = methodSymbolByNameOrError(objSymbol, copy.name + "$default$" + i.toString)
// default values in extensions are obtained by calling a method receiving the extension parameter
val defaultMethodArgs = argsMap.dropRight(1).headOption.toList.flatMap(_.values)
//println(s"defaultMethodArgs ${obj.show} ${methodSymbol.name} $defaultMethodArgs")
if defaultMethodArgs.nonEmpty then
Apply(Select(obj, methodSymbol), defaultMethodArgs)
else
// note: this is not always correct, -Xcheck-macros shows errors here
// sometimes we should call a method with empry parameter list instead
obj.select(methodSymbol)

// for extension methods, might need sth more like this: (or probably some weird implicit conversion)
// val defaultGetter = obj.select(symbolMethodByNameOrError(objSymbol, n))
n -> v.getOrElse(defaultMethod)
}.toMap

val argLists = copyTree.termParamss.take(argsMap.size).map(list => list.params.map(p => args(p.name)))

if copyTree.termParamss.drop(argLists.size).exists(_.params.exists(!_.symbol.flags.is(Flags.Implicit))) then
report.errorAndAbort(
s"Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit. ${copyTree.termParamss.drop(1)}"
)

val applyOn = typeParams match {
// if the object's type is parametrised, we need to call .copy with the same type parameters
case Some(typeParams) => TypeApply(Select(obj, copy), typeParams.map(Inferred(_)))
case _ => Select(obj, copy)
}
argLists.foldLeft(applyOn)((applied, list) => Apply(applied, list))
}

def termMethodByNameUnsafe(term: Term, name: String): Symbol = {
Expand All @@ -210,8 +268,19 @@ object QuicklensMacros {
(sym.flags.is(Flags.Sealed) && (sym.flags.is(Flags.Trait) || sym.flags.is(Flags.Abstract)))
}

def findExtensionMethod(using Quotes)(sym: Symbol, methodName: String): List[Symbol] = {
// TODO: can we check parameter types somehow?
def isExtensionMethod(sym: Symbol): Boolean = sym.isDefDef && sym.paramSymss.headOption.exists(_.sizeIs == 1)

// TODO: try to search in symbol parent object as well
val symbols = Seq(sym.companionModule).filter(_ != Symbol.noSymbol)

symbols.flatMap(_.declaredMethods).filter(sym => sym.name == methodName).filter(isExtensionMethod).toList
}

def isProductLike(sym: Symbol): Boolean = {
sym.methodMember("copy").size >= 1
// just assume true - we can always fail if there is no copy
sym.methodMember("copy").nonEmpty || findExtensionMethod(sym, "copy").nonEmpty
}

def caseClassCopy(
Expand Down Expand Up @@ -248,6 +317,7 @@ object QuicklensMacros {
}

val elseThrow = '{ throw new IllegalStateException() }.asTerm

ifThens.foldRight(elseThrow) { case ((ifCond, ifThen), ifElse) =>
If(ifCond, ifThen, ifElse)
}
Expand All @@ -260,36 +330,18 @@ object QuicklensMacros {
val namedArg = NamedArg(field.name, resTerm)
field.name -> namedArg
}.toMap
val copy = methodSymbolByNameAndArgsOrError(objSymbol, "copy", argsMap)

val typeParams = objTpe match {
case AppliedType(_, typeParams) => Some(typeParams)
case _ => None
}
val copyTree: DefDef = copy.tree.asInstanceOf[DefDef]
val copyParamNames: List[String] = copyTree.termParamss.headOption.map(_.params).toList.flatten.map(_.name)

val args = copyParamNames.zipWithIndex.map { (n, _i) =>
val i = _i + 1
val defaultMethod = obj.select(methodSymbolByNameOrError(objSymbol, "copy$default$" + i.toString))
// for extension methods, might need sth more like this: (or probably some weird implicit conversion)
// val defaultGetter = obj.select(symbolMethodByNameOrError(objSymbol, n))
argsMap.getOrElse(
n,
defaultMethod
)
}.toList

if copyTree.termParamss.drop(1).exists(_.params.exists(!_.symbol.flags.is(Flags.Implicit))) then
report.errorAndAbort(
s"Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit."
)

typeParams match {
// if the object's type is parametrised, we need to call .copy with the same type parameters
case Some(typeParams) => Apply(TypeApply(Select(obj, copy), typeParams.map(Inferred(_))), args)
case _ => Apply(Select(obj, copy), args)
}
methodSymbolByNameAndArgs(objSymbol, "copy", argsMap) match
case Some(copy) =>
callMethod(obj, copy, List(argsMap))
case None =>
val objCompanion = objSymbol.companionModule
methodSymbolByNameAndArgs(objCompanion, "copy", argsMap) match
case Some(copy) =>
// now try to call the extension as a method, assume the object is its first parameter
val extensionParameter = copy.paramSymss.headOption.map(_.headOption).flatten
val argsWithObj = List(extensionParameter.map(name => name.name -> obj).toMap, argsMap)
callMethod(Ref(objCompanion), copy, argsWithObj)
case None => report.errorAndAbort(noSuchMember(objSymbol.name, "copy"))
} else
report.errorAndAbort(s"Unsupported source object: must be a case class or sealed trait, but got: $objSymbol of type ${objTpe.show} (${obj.show})")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,40 +11,79 @@ object ExtensionCopyTest {

object Vec {
def apply(x: Double, y: Double): Vec = V(x, y)
}

extension (v: Vec) {
def x: Double = v.x
def y: Double = v.y
def copy(x: Double = v.x, y: Double = v.y): Vec = V(x, y)
extension (v: Vec) {
def x: Double = v.x
def y: Double = v.y
def copy(x: Double = v.x, y: Double = v.y): Vec = V(x, y)
}
}
}

class ExtensionCopyTest extends AnyFlatSpec with Matchers {
/*
it should "modify a simple class with an extension copy method" in {
class VecSimple(xp: Double, yp: Double) {
val xMember = xp
val yMember = yp
}
object VecSimple {
def apply(x: Double, y: Double): VecSimple = new VecSimple(x, y)
}
extension (v: VecSimple) {
def copy(x: Double = v.xMember, y: Double = v.yMember): VecSimple = new VecSimple(x, y)
}
val a = VecSimple(1, 2)
val b = a.modify(_.xMember).using(_ + 1)
println(b)
}
*/

it should "modify a simple class with an extension copy method in companion" in {
class VecCompanion(xp: Double, yp: Double) {
val x = xp
val y = yp
}

object VecCompanion {
def apply(x: Double, y: Double): VecCompanion = new VecCompanion(x, y)
extension (v: VecCompanion) {
def copy(x: Double = v.x, y: Double = v.y): VecCompanion = new VecCompanion(x, y)
}
}

val a = VecCompanion(1, 2)
val b = a.modify(_.x).using(_ + 1)
println(b)
}
/*
it should "modify a class with an extension copy method" in {
case class V(x: Double, y: Double)
class Vec(val v: V)
class VecClass(val v: V)
object Vec {
def apply(x: Double, y: Double): Vec = new Vec(V(x, y))
object VecClass {
def apply(x: Double, y: Double): VecClass = new VecClass(V(x, y))
}
extension (v: Vec) {
extension (v: VecClass) {
def x: Double = v.v.x
def y: Double = v.v.y
def copy(x: Double = v.x, y: Double = v.y): Vec = new Vec(V(x, y))
def copy(x: Double = v.x, y: Double = v.y): VecClass = new VecClass(V(x, y))
}
val a = Vec(1, 2)
val a = VecClass(1, 2)
val b = a.modify(_.x).using(_ + 1)
println(b)
}

it should "modify an opaque type with an extension copy method" in {
import ExtensionCopyTest.*
val a = Vec(1, 2)
val b = a.modify(_.x).using(_ + 1)
println(b)
}
*/
}

0 comments on commit 0ba4cf8

Please sign in to comment.