From d0a47120c2120a97ebf72bf78167ed19b9f7e2ff Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Fri, 29 Nov 2024 16:33:10 +0100 Subject: [PATCH] Add a heuristic for picking a copy method for Scala 3, when multiple are found (#257) Co-authored-by: Adam Warski --- .gitignore | 3 +- .../quicklens/QuicklensMacros.scala | 36 +++++++++--- .../quicklens/test/ExplicitCopyTest.scala | 55 +++++++++++++++++++ 3 files changed, 84 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index c50eea7..971a1b0 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,5 @@ out .bloop/ project/**/metals.sbt .vscode/ -.bsp \ No newline at end of file +.bsp +.scala-build \ No newline at end of file diff --git a/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala b/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala index d352f21..20e640e 100644 --- a/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala +++ b/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala @@ -58,8 +58,9 @@ object QuicklensMacros { def noSuchMember(tpeStr: String, name: String) = s"$tpeStr has no member named $name" - def multipleMatchingMethods(tpeStr: String, name: String) = - s"Multiple methods named $name found in $tpeStr" + def multipleMatchingMethods(tpeStr: String, name: String, syms: Seq[Symbol]) = + val symsStr = syms.map(s => s" - $s: ${s.termRef.dealias.widen.show}").mkString("\n", "\n", "") + s"Multiple methods named $name found in $tpeStr: $symsStr" def methodSupported(method: String) = Seq("at", "each", "eachWhere", "eachRight", "eachLeft", "atOrElse", "index", "when").contains(method) @@ -168,19 +169,36 @@ object QuicklensMacros { def symbolAccessorByNameOrError(sym: Symbol, name: String): Symbol = { val mem = sym.fieldMember(name) if mem != Symbol.noSymbol then mem - else symbolMethodByNameOrError(sym, name) + else methodSymbolByNameOrError(sym, name) } - def symbolMethodByNameOrError(sym: Symbol, name: String): Symbol = { + def methodSymbolByNameOrError(sym: Symbol, name: String): Symbol = { sym.methodMember(name) match case List(m) => m case Nil => report.errorAndAbort(noSuchMember(sym.name, name)) - case _ => report.errorAndAbort(multipleMatchingMethods(sym.name, name)) + case lst => report.errorAndAbort(multipleMatchingMethods(sym.name, name, lst)) + } + + def methodSymbolByNameAndArgsOrError(sym: Symbol, name: String, argsMap: Map[String, Term]): Symbol = { + val argNames = argsMap.keys + sym.methodMember(name).filter{ msym => + // for copy, we filter out the methods that don't have the desired parameter names + val paramNames = msym.paramSymss.flatten.filter(_.isTerm).map(_.name) + argNames.forall(paramNames.contains) + } match + case List(m) => m + case Nil => report.errorAndAbort(noSuchMember(sym.name, name)) + case lst @ (m :: _) => + // if we have multiple matching copy methods, pick the synthetic one, if it exists, otherwise, pick any method + val syntheticCopies = lst.filter(_.flags.is(Flags.Synthetic)) + syntheticCopies match + case List(mSynth) => mSynth + case _ => m } def termMethodByNameUnsafe(term: Term, name: String): Symbol = { val typeSymbol = term.tpe.widenAll.typeSymbol - symbolMethodByNameOrError(typeSymbol, name) + methodSymbolByNameOrError(typeSymbol, name) } def isProduct(sym: Symbol): Boolean = { @@ -193,7 +211,7 @@ object QuicklensMacros { } def isProductLike(sym: Symbol): Boolean = { - sym.methodMember("copy").size == 1 + sym.methodMember("copy").size >= 1 } def caseClassCopy( @@ -234,7 +252,6 @@ object QuicklensMacros { If(ifCond, ifThen, ifElse) } } else if isProduct(objSymbol) || isProductLike(objSymbol) then { - val copy = symbolMethodByNameOrError(objSymbol, "copy") val argsMap: Map[String, Term] = fields.map { (field, trees) => val fieldMethod = symbolAccessorByNameOrError(objSymbol, field.name) val resTerm: Term = trees.foldLeft[Term](Select(obj, fieldMethod)) { (term, tree) => @@ -243,6 +260,7 @@ object QuicklensMacros { val namedArg = NamedArg(field.name, resTerm) field.name -> namedArg }.toMap + val copy = methodSymbolByNameAndArgsOrError(objSymbol, "copy", argsMap) val typeParams = objTpe match { case AppliedType(_, typeParams) => Some(typeParams) @@ -253,7 +271,7 @@ object QuicklensMacros { val args = copyParamNames.zipWithIndex.map { (n, _i) => val i = _i + 1 - val defaultMethod = obj.select(symbolMethodByNameOrError(objSymbol, "copy$default$" + i.toString)) + val defaultMethod = obj.select(methodSymbolByNameOrError(objSymbol, "copy$default$" + i.toString)) // for extension methods, might need sth more like this: (or probably some weird implicit conversion) // val defaultGetter = obj.select(symbolMethodByNameOrError(objSymbol, n)) argsMap.getOrElse( diff --git a/quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ExplicitCopyTest.scala b/quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ExplicitCopyTest.scala index 444fbf6..f1ecce1 100644 --- a/quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ExplicitCopyTest.scala +++ b/quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ExplicitCopyTest.scala @@ -36,4 +36,59 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers { docs.modify(_.paths.pathItems).using(m => m + ("a" -> PathItem())) } + it should "modify a case class with an additional explicit copy" in { + case class Frozen(state: String, ext: Int) { + def copy(stateC: Char): Frozen = Frozen(stateC.toString, ext) + } + + val f = Frozen("A", 0) + f.modify(_.state).setTo("B") + } + + it should "modify a case class with an ambiguous additional explicit copy" in { + case class Frozen(state: String, ext: Int) { + def copy(state: String): Frozen = Frozen(state, ext) + } + + val f = Frozen("A", 0) + f.modify(_.state).setTo("B") + } + + it should "modify a class with two explicit copy methods" in { + class Frozen(val state: String, val ext: Int) { + def copy(state: String = state, ext: Int = ext): Frozen = new Frozen(state, ext) + def copy(state: String): Frozen = new Frozen(state, ext) + } + + val f = new Frozen("A", 0) + f.modify(_.state).setTo("B") + } + + it should "modify a case class with an ambiguous additional explicit copy and pick the synthetic one first" in { + var accessed = 0 + case class Frozen(state: String, ext: Int) { + def copy(state: String): Frozen = + accessed += 1 + Frozen(state, ext) + } + + val f = Frozen("A", 0) + f.modify(_.state).setTo("B") + accessed shouldEqual 0 + } + + // TODO: Would be nice to be able to handle this case. Based on the types, it + // is obvious, that the explicit copy should be picked, but I'm not sure if we + // can get that information + + // it should "pick the correct copy method, based on the type" in { + // case class Frozen(state: String, ext: Int) { + // def copy(state: Char): Frozen = + // Frozen(state.toString, ext) + // } + + // val f = Frozen("A", 0) + // f.modify(_.state).setTo('B') + // } + }