Skip to content

Commit

Permalink
Add a heuristic for picking a copy method for Scala 3, when multiple …
Browse files Browse the repository at this point in the history
…are found (#257)

Co-authored-by: Adam Warski <[email protected]>
  • Loading branch information
KacperFKorban and adamw authored Nov 29, 2024
1 parent 69047f2 commit d0a4712
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 10 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ out
.bloop/
project/**/metals.sbt
.vscode/
.bsp
.bsp
.scala-build
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,9 @@ object QuicklensMacros {
def noSuchMember(tpeStr: String, name: String) =
s"$tpeStr has no member named $name"

def multipleMatchingMethods(tpeStr: String, name: String) =
s"Multiple methods named $name found in $tpeStr"
def multipleMatchingMethods(tpeStr: String, name: String, syms: Seq[Symbol]) =
val symsStr = syms.map(s => s" - $s: ${s.termRef.dealias.widen.show}").mkString("\n", "\n", "")
s"Multiple methods named $name found in $tpeStr: $symsStr"

def methodSupported(method: String) =
Seq("at", "each", "eachWhere", "eachRight", "eachLeft", "atOrElse", "index", "when").contains(method)
Expand Down Expand Up @@ -168,19 +169,36 @@ object QuicklensMacros {
def symbolAccessorByNameOrError(sym: Symbol, name: String): Symbol = {
val mem = sym.fieldMember(name)
if mem != Symbol.noSymbol then mem
else symbolMethodByNameOrError(sym, name)
else methodSymbolByNameOrError(sym, name)
}

def symbolMethodByNameOrError(sym: Symbol, name: String): Symbol = {
def methodSymbolByNameOrError(sym: Symbol, name: String): Symbol = {
sym.methodMember(name) match
case List(m) => m
case Nil => report.errorAndAbort(noSuchMember(sym.name, name))
case _ => report.errorAndAbort(multipleMatchingMethods(sym.name, name))
case lst => report.errorAndAbort(multipleMatchingMethods(sym.name, name, lst))
}

def methodSymbolByNameAndArgsOrError(sym: Symbol, name: String, argsMap: Map[String, Term]): Symbol = {
val argNames = argsMap.keys
sym.methodMember(name).filter{ msym =>
// for copy, we filter out the methods that don't have the desired parameter names
val paramNames = msym.paramSymss.flatten.filter(_.isTerm).map(_.name)
argNames.forall(paramNames.contains)
} match
case List(m) => m
case Nil => report.errorAndAbort(noSuchMember(sym.name, name))
case lst @ (m :: _) =>
// if we have multiple matching copy methods, pick the synthetic one, if it exists, otherwise, pick any method
val syntheticCopies = lst.filter(_.flags.is(Flags.Synthetic))
syntheticCopies match
case List(mSynth) => mSynth
case _ => m
}

def termMethodByNameUnsafe(term: Term, name: String): Symbol = {
val typeSymbol = term.tpe.widenAll.typeSymbol
symbolMethodByNameOrError(typeSymbol, name)
methodSymbolByNameOrError(typeSymbol, name)
}

def isProduct(sym: Symbol): Boolean = {
Expand All @@ -193,7 +211,7 @@ object QuicklensMacros {
}

def isProductLike(sym: Symbol): Boolean = {
sym.methodMember("copy").size == 1
sym.methodMember("copy").size >= 1
}

def caseClassCopy(
Expand Down Expand Up @@ -234,7 +252,6 @@ object QuicklensMacros {
If(ifCond, ifThen, ifElse)
}
} else if isProduct(objSymbol) || isProductLike(objSymbol) then {
val copy = symbolMethodByNameOrError(objSymbol, "copy")
val argsMap: Map[String, Term] = fields.map { (field, trees) =>
val fieldMethod = symbolAccessorByNameOrError(objSymbol, field.name)
val resTerm: Term = trees.foldLeft[Term](Select(obj, fieldMethod)) { (term, tree) =>
Expand All @@ -243,6 +260,7 @@ object QuicklensMacros {
val namedArg = NamedArg(field.name, resTerm)
field.name -> namedArg
}.toMap
val copy = methodSymbolByNameAndArgsOrError(objSymbol, "copy", argsMap)

val typeParams = objTpe match {
case AppliedType(_, typeParams) => Some(typeParams)
Expand All @@ -253,7 +271,7 @@ object QuicklensMacros {

val args = copyParamNames.zipWithIndex.map { (n, _i) =>
val i = _i + 1
val defaultMethod = obj.select(symbolMethodByNameOrError(objSymbol, "copy$default$" + i.toString))
val defaultMethod = obj.select(methodSymbolByNameOrError(objSymbol, "copy$default$" + i.toString))
// for extension methods, might need sth more like this: (or probably some weird implicit conversion)
// val defaultGetter = obj.select(symbolMethodByNameOrError(objSymbol, n))
argsMap.getOrElse(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,59 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers {
docs.modify(_.paths.pathItems).using(m => m + ("a" -> PathItem()))
}

it should "modify a case class with an additional explicit copy" in {
case class Frozen(state: String, ext: Int) {
def copy(stateC: Char): Frozen = Frozen(stateC.toString, ext)
}

val f = Frozen("A", 0)
f.modify(_.state).setTo("B")
}

it should "modify a case class with an ambiguous additional explicit copy" in {
case class Frozen(state: String, ext: Int) {
def copy(state: String): Frozen = Frozen(state, ext)
}

val f = Frozen("A", 0)
f.modify(_.state).setTo("B")
}

it should "modify a class with two explicit copy methods" in {
class Frozen(val state: String, val ext: Int) {
def copy(state: String = state, ext: Int = ext): Frozen = new Frozen(state, ext)
def copy(state: String): Frozen = new Frozen(state, ext)
}

val f = new Frozen("A", 0)
f.modify(_.state).setTo("B")
}

it should "modify a case class with an ambiguous additional explicit copy and pick the synthetic one first" in {
var accessed = 0
case class Frozen(state: String, ext: Int) {
def copy(state: String): Frozen =
accessed += 1
Frozen(state, ext)
}

val f = Frozen("A", 0)
f.modify(_.state).setTo("B")
accessed shouldEqual 0
}

// TODO: Would be nice to be able to handle this case. Based on the types, it
// is obvious, that the explicit copy should be picked, but I'm not sure if we
// can get that information

// it should "pick the correct copy method, based on the type" in {
// case class Frozen(state: String, ext: Int) {
// def copy(state: Char): Frozen =
// Frozen(state.toString, ext)
// }

// val f = Frozen("A", 0)
// f.modify(_.state).setTo('B')
// }

}

0 comments on commit d0a4712

Please sign in to comment.