diff --git a/build.sbt b/build.sbt index b2cd19b..61042f7 100644 --- a/build.sbt +++ b/build.sbt @@ -14,7 +14,7 @@ excludeLintKeys in Global ++= Set(ideSkipProject) val commonSettings = commonSmlBuildSettings ++ ossPublishSettings ++ Seq( organization := "com.softwaremill.quicklens", updateDocs := UpdateVersionInDocs(sLog.value, organization.value, version.value, List(file("README.md"))), - scalacOptions ++= Seq("-deprecation", "-feature", "-unchecked"), // useful for debugging macros: "-Ycheck:all" + scalacOptions ++= Seq("-deprecation", "-feature", "-unchecked"), // useful for debugging macros: "-Ycheck:all", "-Xcheck-macros" ideSkipProject := (scalaVersion.value != scalaIdeaVersion) ) diff --git a/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala b/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala index 20e640e..0a652b9 100644 --- a/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala +++ b/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala @@ -133,6 +133,8 @@ object QuicklensMacros { /** Method call with one type parameter and using clause */ case a @ Apply(TypeApply(Apply(TypeApply(Ident(s), _), idents), typeTrees), List(givn)) if methodSupported(s) => idents.flatMap(toPath(_, focus)) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, List.empty) + case Apply(Ident(ident), Seq(deep)) => // this is an extension method, which is called e.g. as x(_$1) + toPath(deep, focus) :+ PathSymbol.Field(ident) /** Field access */ case Apply(deep, idents) => toPath(deep, focus) ++ idents.flatMap(toPath(_, focus)) @@ -179,21 +181,77 @@ object QuicklensMacros { case lst => report.errorAndAbort(multipleMatchingMethods(sym.name, name, lst)) } - def methodSymbolByNameAndArgsOrError(sym: Symbol, name: String, argsMap: Map[String, Term]): Symbol = { + def filterMethodsByNameAndArgs(allMethods: List[Symbol], argsMap: Map[String, Term]): Option[Symbol] = { val argNames = argsMap.keys - sym.methodMember(name).filter{ msym => + allMethods.filter { msym => // for copy, we filter out the methods that don't have the desired parameter names val paramNames = msym.paramSymss.flatten.filter(_.isTerm).map(_.name) argNames.forall(paramNames.contains) } match - case List(m) => m - case Nil => report.errorAndAbort(noSuchMember(sym.name, name)) - case lst @ (m :: _) => + case List(m) => Some(m) + case Nil => None + case lst@(m :: _) => // if we have multiple matching copy methods, pick the synthetic one, if it exists, otherwise, pick any method val syntheticCopies = lst.filter(_.flags.is(Flags.Synthetic)) syntheticCopies match - case List(mSynth) => mSynth - case _ => m + case List(mSynth) => Some(mSynth) + case _ => Some(m) + } + + def methodSymbolByNameAndArgs(sym: Symbol, name: String, argsMap: Map[String, Term]): Option[Symbol] = { + val memberMethods = sym.methodMember(name) + filterMethodsByNameAndArgs(memberMethods, argsMap) + } + + /** + * @param argsMap normal methods receive one parameter list, extensions methods two, the first one contains the value + * on which the extension is called + * */ + def callMethod(obj: Term, copy: Symbol, argsMap: List[Map[String, Term]]) = { + val objTpe = obj.tpe.widenAll + val objSymbol = objTpe.matchingTypeSymbol + + val typeParams = objTpe match { + case AppliedType(_, typeParams) => Some(typeParams) + case _ => None + } + val copyTree: DefDef = copy.tree.asInstanceOf[DefDef] + val copyParams: List[(String, Option[Term])] = copyTree.termParamss.zip(argsMap) + .map((params, args) => params.params.map(_.name).map(name => name -> args.get(name))) + .flatten.toList + + val args = copyParams.zipWithIndex.map { case ((n, v), _i) => + val i = _i + 1 + def defaultMethod = + val methodSymbol = methodSymbolByNameOrError(objSymbol, copy.name + "$default$" + i.toString) + // default values in extensions are obtained by calling a method receiving the extension parameter + val defaultMethodArgs = argsMap.dropRight(1).headOption.toList.flatMap(_.values) + //println(s"defaultMethodArgs ${obj.show} ${methodSymbol.name} $defaultMethodArgs") + if defaultMethodArgs.nonEmpty then + Apply(Select(obj, methodSymbol), defaultMethodArgs) + else + // note: this is not always correct, -Xcheck-macros shows errors here + // sometimes we should call a method with empry parameter list instead + obj.select(methodSymbol) + + // for extension methods, might need sth more like this: (or probably some weird implicit conversion) + // val defaultGetter = obj.select(symbolMethodByNameOrError(objSymbol, n)) + n -> v.getOrElse(defaultMethod) + }.toMap + + val argLists = copyTree.termParamss.take(argsMap.size).map(list => list.params.map(p => args(p.name))) + + if copyTree.termParamss.drop(argLists.size).exists(_.params.exists(!_.symbol.flags.is(Flags.Implicit))) then + report.errorAndAbort( + s"Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit. ${copyTree.termParamss.drop(1)}" + ) + + val applyOn = typeParams match { + // if the object's type is parametrised, we need to call .copy with the same type parameters + case Some(typeParams) => TypeApply(Select(obj, copy), typeParams.map(Inferred(_))) + case _ => Select(obj, copy) + } + argLists.foldLeft(applyOn)((applied, list) => Apply(applied, list)) } def termMethodByNameUnsafe(term: Term, name: String): Symbol = { @@ -210,8 +268,19 @@ object QuicklensMacros { (sym.flags.is(Flags.Sealed) && (sym.flags.is(Flags.Trait) || sym.flags.is(Flags.Abstract))) } + def findExtensionMethod(using Quotes)(sym: Symbol, methodName: String): List[Symbol] = { + // TODO: can we check parameter types somehow? + def isExtensionMethod(sym: Symbol): Boolean = sym.isDefDef && sym.paramSymss.headOption.exists(_.sizeIs == 1) + + // TODO: try to search in symbol parent object as well + val symbols = Seq(sym.companionModule).filter(_ != Symbol.noSymbol) + + symbols.flatMap(_.declaredMethods).filter(sym => sym.name == methodName).filter(isExtensionMethod).toList + } + def isProductLike(sym: Symbol): Boolean = { - sym.methodMember("copy").size >= 1 + // just assume true - we can always fail if there is no copy + sym.methodMember("copy").nonEmpty || findExtensionMethod(sym, "copy").nonEmpty } def caseClassCopy( @@ -248,6 +317,7 @@ object QuicklensMacros { } val elseThrow = '{ throw new IllegalStateException() }.asTerm + ifThens.foldRight(elseThrow) { case ((ifCond, ifThen), ifElse) => If(ifCond, ifThen, ifElse) } @@ -260,36 +330,18 @@ object QuicklensMacros { val namedArg = NamedArg(field.name, resTerm) field.name -> namedArg }.toMap - val copy = methodSymbolByNameAndArgsOrError(objSymbol, "copy", argsMap) - - val typeParams = objTpe match { - case AppliedType(_, typeParams) => Some(typeParams) - case _ => None - } - val copyTree: DefDef = copy.tree.asInstanceOf[DefDef] - val copyParamNames: List[String] = copyTree.termParamss.headOption.map(_.params).toList.flatten.map(_.name) - - val args = copyParamNames.zipWithIndex.map { (n, _i) => - val i = _i + 1 - val defaultMethod = obj.select(methodSymbolByNameOrError(objSymbol, "copy$default$" + i.toString)) - // for extension methods, might need sth more like this: (or probably some weird implicit conversion) - // val defaultGetter = obj.select(symbolMethodByNameOrError(objSymbol, n)) - argsMap.getOrElse( - n, - defaultMethod - ) - }.toList - - if copyTree.termParamss.drop(1).exists(_.params.exists(!_.symbol.flags.is(Flags.Implicit))) then - report.errorAndAbort( - s"Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit." - ) - - typeParams match { - // if the object's type is parametrised, we need to call .copy with the same type parameters - case Some(typeParams) => Apply(TypeApply(Select(obj, copy), typeParams.map(Inferred(_))), args) - case _ => Apply(Select(obj, copy), args) - } + methodSymbolByNameAndArgs(objSymbol, "copy", argsMap) match + case Some(copy) => + callMethod(obj, copy, List(argsMap)) + case None => + val objCompanion = objSymbol.companionModule + methodSymbolByNameAndArgs(objCompanion, "copy", argsMap) match + case Some(copy) => + // now try to call the extension as a method, assume the object is its first parameter + val extensionParameter = copy.paramSymss.headOption.map(_.headOption).flatten + val argsWithObj = List(extensionParameter.map(name => name.name -> obj).toMap, argsMap) + callMethod(Ref(objCompanion), copy, argsWithObj) + case None => report.errorAndAbort(noSuchMember(objSymbol.name, "copy")) } else report.errorAndAbort(s"Unsupported source object: must be a case class or sealed trait, but got: $objSymbol of type ${objTpe.show} (${obj.show})") } diff --git a/quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ExtensionCopyTest.scala b/quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ExtensionCopyTest.scala index d85f588..a24f08e 100644 --- a/quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ExtensionCopyTest.scala +++ b/quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ExtensionCopyTest.scala @@ -11,35 +11,73 @@ object ExtensionCopyTest { object Vec { def apply(x: Double, y: Double): Vec = V(x, y) - } - extension (v: Vec) { - def x: Double = v.x - def y: Double = v.y - def copy(x: Double = v.x, y: Double = v.y): Vec = V(x, y) + extension (v: Vec) { + def x: Double = v.x + def y: Double = v.y + def copy(x: Double = v.x, y: Double = v.y): Vec = V(x, y) + } } } class ExtensionCopyTest extends AnyFlatSpec with Matchers { + /* + it should "modify a simple class with an extension copy method" in { + class VecSimple(xp: Double, yp: Double) { + val xMember = xp + val yMember = yp + } + + object VecSimple { + def apply(x: Double, y: Double): VecSimple = new VecSimple(x, y) + } + + extension (v: VecSimple) { + def copy(x: Double = v.xMember, y: Double = v.yMember): VecSimple = new VecSimple(x, y) + } + val a = VecSimple(1, 2) + val b = a.modify(_.xMember).using(_ + 1) + println(b) + } + */ + + it should "modify a simple class with an extension copy method in companion" in { + class VecCompanion(xp: Double, yp: Double) { + val x = xp + val y = yp + } + + object VecCompanion { + def apply(x: Double, y: Double): VecCompanion = new VecCompanion(x, y) + extension (v: VecCompanion) { + def copy(x: Double = v.x, y: Double = v.y): VecCompanion = new VecCompanion(x, y) + } + } + + val a = VecCompanion(1, 2) + val b = a.modify(_.x).using(_ + 1) + println(b) + } + /* + it should "modify a class with an extension copy method" in { case class V(x: Double, y: Double) - class Vec(val v: V) + class VecClass(val v: V) - object Vec { - def apply(x: Double, y: Double): Vec = new Vec(V(x, y)) + object VecClass { + def apply(x: Double, y: Double): VecClass = new VecClass(V(x, y)) } - extension (v: Vec) { + extension (v: VecClass) { def x: Double = v.v.x def y: Double = v.v.y - def copy(x: Double = v.x, y: Double = v.y): Vec = new Vec(V(x, y)) + def copy(x: Double = v.x, y: Double = v.y): VecClass = new VecClass(V(x, y)) } - val a = Vec(1, 2) + val a = VecClass(1, 2) val b = a.modify(_.x).using(_ + 1) println(b) } - it should "modify an opaque type with an extension copy method" in { import ExtensionCopyTest.* @@ -47,4 +85,5 @@ class ExtensionCopyTest extends AnyFlatSpec with Matchers { val b = a.modify(_.x).using(_ + 1) println(b) } + */ }