From 663e7507a0bc0e14209ea59a11ce16f228619463 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ond=C5=99ej=20=C5=A0pan=C4=9Bl?= Date: Tue, 17 Dec 2024 16:41:14 +0100 Subject: [PATCH] Extension copy (#262) Co-authored-by: Kacper Korban --- build.sbt | 2 +- .../quicklens/QuicklensMacros.scala | 190 ++++++++++++------ .../com/softwaremill/quicklens/package.scala | 2 +- .../quicklens/test/ExplicitCopyTest.scala | 27 ++- .../quicklens/test/ExtensionCopyTest.scala | 90 +++++++++ 5 files changed, 247 insertions(+), 64 deletions(-) create mode 100644 quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ExtensionCopyTest.scala diff --git a/build.sbt b/build.sbt index b2cd19b..61042f7 100644 --- a/build.sbt +++ b/build.sbt @@ -14,7 +14,7 @@ excludeLintKeys in Global ++= Set(ideSkipProject) val commonSettings = commonSmlBuildSettings ++ ossPublishSettings ++ Seq( organization := "com.softwaremill.quicklens", updateDocs := UpdateVersionInDocs(sLog.value, organization.value, version.value, List(file("README.md"))), - scalacOptions ++= Seq("-deprecation", "-feature", "-unchecked"), // useful for debugging macros: "-Ycheck:all" + scalacOptions ++= Seq("-deprecation", "-feature", "-unchecked"), // useful for debugging macros: "-Ycheck:all", "-Xcheck-macros" ideSkipProject := (scalaVersion.value != scalaIdeaVersion) ) diff --git a/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala b/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala index 20e640e..6215ed3 100644 --- a/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala +++ b/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala @@ -58,6 +58,9 @@ object QuicklensMacros { def noSuchMember(tpeStr: String, name: String) = s"$tpeStr has no member named $name" + def noSuitableMember(tpeStr: String, name: String, argNames: Iterable[String]) = + s"$tpeStr has no member $name with parameters ${argNames.mkString("(", ", ", ")")}" + def multipleMatchingMethods(tpeStr: String, name: String, syms: Seq[Symbol]) = val symsStr = syms.map(s => s" - $s: ${s.termRef.dealias.widen.show}").mkString("\n", "\n", "") s"Multiple methods named $name found in $tpeStr: $symsStr" @@ -109,11 +112,14 @@ object QuicklensMacros { case (symbol :: tail) => PathTree.Node(Seq(symbol -> Seq(tail.toPathTree))) enum PathSymbol: - case Field(name: String) - case FunctionDelegate(name: String, givn: Term, typeTree: TypeTree, args: List[Term]) + case Field(override val name: String) + case Extension(term: Term, override val name: String) + case FunctionDelegate(override val name: String, givn: Term, typeTree: TypeTree, args: List[Term]) + def name: String def equiv(other: Any): Boolean = (this, other) match case (Field(name1), Field(name2)) => name1 == name2 + case (Extension(term1, name1), Extension(term2, name2)) => term1 == term2 && name1 == name2 case (FunctionDelegate(name1, _, typeTree1, args1), FunctionDelegate(name2, _, typeTree2, args2)) => name1 == name2 && typeTree1.tpe == typeTree2.tpe && args1 == args2 case _ => false @@ -133,6 +139,9 @@ object QuicklensMacros { /** Method call with one type parameter and using clause */ case a @ Apply(TypeApply(Apply(TypeApply(Ident(s), _), idents), typeTrees), List(givn)) if methodSupported(s) => idents.flatMap(toPath(_, focus)) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, List.empty) + /** Extension method, which is called e.g. as x(_$1) */ + case Apply(obj@Select(term, member), Seq(deep)) if obj.symbol.flags.is(Flags.ExtensionMethod) => + toPath(deep, focus) :+ PathSymbol.Extension(term, member) /** Field access */ case Apply(deep, idents) => toPath(deep, focus) ++ idents.flatMap(toPath(_, focus)) @@ -157,43 +166,104 @@ object QuicklensMacros { def matchingTypeSymbol: Symbol = tpe.widenAll match { case AndType(l, r) => val lSym = l.matchingTypeSymbol - if l.matchingTypeSymbol != Symbol.noSymbol then lSym else r.matchingTypeSymbol - case tpe if isProduct(tpe.typeSymbol) || isSum(tpe.typeSymbol) => - tpe.typeSymbol - case tpe if isProductLike(tpe.typeSymbol) => + if lSym != Symbol.noSymbol then lSym else r.matchingTypeSymbol + case tpe if isProduct(tpe.typeSymbol) || isSum(tpe.typeSymbol) || isProductLike(tpe.typeSymbol) => tpe.typeSymbol case _ => Symbol.noSymbol } - def symbolAccessorByNameOrError(sym: Symbol, name: String): Symbol = { - val mem = sym.fieldMember(name) - if mem != Symbol.noSymbol then mem - else methodSymbolByNameOrError(sym, name) + extension (term: Term) + def appliedToIfNeeded(args: List[Term]): Term = + if args.isEmpty then term else term.appliedToArgs(args) + + def symbolAccessorByNameOrError(obj: Term, name: String): Term = { + val objTpe = obj.tpe.widenAll + val objSymbol = objTpe.matchingTypeSymbol + // opaque types can find members of underlying types - ignore them (see https://github.com/scala/scala3/issues/22143) + val fieldMemberSym = objSymbol.fieldMember(name) + if !objSymbol.flags.is(Flags.Deferred) && fieldMemberSym.exists then + Select(obj, fieldMemberSym) + else + objSymbol.methodMember(name) match + case List(m) => + Select(obj, m) + case lst => + report.errorAndAbort(reportMethodError(objSymbol, name, lst)) + } + + def reportMethodError(sym: Symbol, name: String, lst: List[Symbol], maybeArgNames: Option[Iterable[String]] = None): String = { + (lst, maybeArgNames) match + case (Nil, _) => noSuchMember(sym.name, name) + case (lst, None) => multipleMatchingMethods(sym.name, name, lst) + case (lst, Some(argNames)) => noSuitableMember(sym.name, name, argNames) } def methodSymbolByNameOrError(sym: Symbol, name: String): Symbol = { sym.methodMember(name) match case List(m) => m - case Nil => report.errorAndAbort(noSuchMember(sym.name, name)) - case lst => report.errorAndAbort(multipleMatchingMethods(sym.name, name, lst)) + case lst => report.errorAndAbort(reportMethodError(sym, name, lst)) } - def methodSymbolByNameAndArgsOrError(sym: Symbol, name: String, argsMap: Map[String, Term]): Symbol = { + def filterMethodsByNameAndArgs(allMethods: List[Symbol], argsMap: Map[String, Term]): Option[Symbol] = { val argNames = argsMap.keys - sym.methodMember(name).filter{ msym => + allMethods.filter { msym => // for copy, we filter out the methods that don't have the desired parameter names val paramNames = msym.paramSymss.flatten.filter(_.isTerm).map(_.name) argNames.forall(paramNames.contains) } match - case List(m) => m - case Nil => report.errorAndAbort(noSuchMember(sym.name, name)) - case lst @ (m :: _) => + case List(m) => Some(m) + case Nil => None + case lst@(m :: _) => // if we have multiple matching copy methods, pick the synthetic one, if it exists, otherwise, pick any method val syntheticCopies = lst.filter(_.flags.is(Flags.Synthetic)) syntheticCopies match - case List(mSynth) => mSynth - case _ => m + case List(mSynth) => Some(mSynth) + case _ => Some(m) + } + + def methodSymbolByNameAndArgs(sym: Symbol, name: String, argsMap: Map[String, Term]): Either[String, Symbol] = { + if !sym.flags.is(Flags.Deferred) then + val memberMethods = sym.methodMember(name) + filterMethodsByNameAndArgs(memberMethods, argsMap) + .toRight(reportMethodError(sym, name, memberMethods, Some(argsMap.keys))) + else Left(s"Deferred type ${sym.name}") + } + + /** + * @param argsMap normal methods receive one parameter list, extensions methods two, the first one contains the value + * on which the extension is called + * */ + def callMethod(obj: Term, copy: Symbol, argsMap: List[Map[String, Term]]) = { + require(argsMap.size == 1 || argsMap.size == 2, s"argsMap.size should be either 1 or 2, got: ${argsMap.size} ($argsMap)") + val objTpe = obj.tpe.widenAll + val objSymbol = objTpe.matchingTypeSymbol + + val typeParams = objTpe.typeArgs + val copyTree: DefDef = copy.tree.asInstanceOf[DefDef] + val copyParams: List[(String, Option[Term])] = copyTree.termParamss.zip(argsMap) + .map((params, args) => params.params.map(_.name).map(name => name -> args.get(name))) + .flatten.toList + + val args = copyParams.zipWithIndex.map { case ((n, v), _i) => + val i = _i + 1 + def defaultMethod: Term = + val methodSymbol = methodSymbolByNameOrError(objSymbol, copy.name + "$default$" + i.toString) + // default values in extension methods take the extension receiver as the first parameter + val defaultMethodArgs = argsMap.dropRight(1).flatMap(_.values) + obj.select(methodSymbol).appliedToIfNeeded(defaultMethodArgs) + n -> v.getOrElse(defaultMethod) + }.toMap + + val argLists: List[List[Term]] = copyTree.termParamss.take(argsMap.size).map(list => list.params.map(p => args(p.name))) + + if copyTree.termParamss.drop(argLists.size).exists(_.params.exists(!_.symbol.flags.is(Flags.Implicit))) then + report.errorAndAbort( + s"Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit. ${copyTree.termParamss.drop(1)}" + ) + + val withTypeParamsApplied = obj.select(copy).appliedToTypes(typeParams) + argLists.foldLeft(withTypeParamsApplied)(Apply(_, _)) } def termMethodByNameUnsafe(term: Term, name: String): Symbol = { @@ -210,15 +280,32 @@ object QuicklensMacros { (sym.flags.is(Flags.Sealed) && (sym.flags.is(Flags.Trait) || sym.flags.is(Flags.Abstract))) } + def findCompanionLikeObject(objSymbol: Symbol): Symbol = { + if objSymbol.companionModule.exists then + objSymbol.companionModule + else + val namedFromOwnerScope = objSymbol.owner.fieldMember(objSymbol.name) + if namedFromOwnerScope.flags.is(Flags.Module) then namedFromOwnerScope + else Symbol.noSymbol + } + + def hasExtensionNamed(sym: Symbol, methodName: String): List[Symbol] = { + val companionSymbol = findCompanionLikeObject(sym) + if companionSymbol.exists then + companionSymbol.methodMember(methodName).filter(s => s.name == methodName && s.flags.is(Flags.ExtensionMethod)) + else + Nil + } + def isProductLike(sym: Symbol): Boolean = { - sym.methodMember("copy").size >= 1 + sym.methodMember("copy").nonEmpty || hasExtensionNamed(sym, "copy").nonEmpty } def caseClassCopy( owner: Symbol, mod: Expr[A => A], obj: Term, - fields: Seq[(PathSymbol.Field, Seq[PathTree])] + fields: Seq[(PathSymbol.Field | PathSymbol.Extension, Seq[PathTree])] ): Term = { val objTpe = obj.tpe.widenAll val objSymbol = objTpe.matchingTypeSymbol @@ -248,50 +335,39 @@ object QuicklensMacros { } val elseThrow = '{ throw new IllegalStateException() }.asTerm + ifThens.foldRight(elseThrow) { case ((ifCond, ifThen), ifElse) => If(ifCond, ifThen, ifElse) } } else if isProduct(objSymbol) || isProductLike(objSymbol) then { val argsMap: Map[String, Term] = fields.map { (field, trees) => - val fieldMethod = symbolAccessorByNameOrError(objSymbol, field.name) - val resTerm: Term = trees.foldLeft[Term](Select(obj, fieldMethod)) { (term, tree) => + val fieldMethod = field match { + case PathSymbol.Field(name) => + symbolAccessorByNameOrError(obj, name) + case PathSymbol.Extension(term, name) => + val extensionMethod = symbolAccessorByNameOrError(term, name) + Apply(extensionMethod, List(obj)) + } + val resTerm: Term = trees.foldLeft[Term](fieldMethod) { (term, tree) => mapToCopy(owner, mod, term, tree) } val namedArg = NamedArg(field.name, resTerm) field.name -> namedArg }.toMap - val copy = methodSymbolByNameAndArgsOrError(objSymbol, "copy", argsMap) - - val typeParams = objTpe match { - case AppliedType(_, typeParams) => Some(typeParams) - case _ => None - } - val copyTree: DefDef = copy.tree.asInstanceOf[DefDef] - val copyParamNames: List[String] = copyTree.termParamss.headOption.map(_.params).toList.flatten.map(_.name) - - val args = copyParamNames.zipWithIndex.map { (n, _i) => - val i = _i + 1 - val defaultMethod = obj.select(methodSymbolByNameOrError(objSymbol, "copy$default$" + i.toString)) - // for extension methods, might need sth more like this: (or probably some weird implicit conversion) - // val defaultGetter = obj.select(symbolMethodByNameOrError(objSymbol, n)) - argsMap.getOrElse( - n, - defaultMethod - ) - }.toList - - if copyTree.termParamss.drop(1).exists(_.params.exists(!_.symbol.flags.is(Flags.Implicit))) then - report.errorAndAbort( - s"Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit." - ) - - typeParams match { - // if the object's type is parametrised, we need to call .copy with the same type parameters - case Some(typeParams) => Apply(TypeApply(Select(obj, copy), typeParams.map(Inferred(_))), args) - case _ => Apply(Select(obj, copy), args) - } + methodSymbolByNameAndArgs(objSymbol, "copy", argsMap) match + case Right(copy) => + callMethod(obj, copy, List(argsMap)) + case Left(error) => + val objCompanion = findCompanionLikeObject(objSymbol) + methodSymbolByNameAndArgs(objCompanion, "copy", argsMap).toOption match + case Some(copy) => + // now try to call the extension as a method, assume the object is its first parameter + val extensionParameter = copy.paramSymss.headOption.map(_.headOption).flatten + val argsWithObj = List(extensionParameter.map(name => name.name -> obj).toMap, argsMap) + callMethod(Ref(objCompanion), copy, argsWithObj) + case None => report.errorAndAbort(error) } else - report.errorAndAbort(s"Unsupported source object: must be a case class or sealed trait, but got: $objSymbol of type ${objTpe.show} (${obj.show})") + report.errorAndAbort(s"Unsupported source object: must be a case class, sealed trait or class with copy method, but got: $objSymbol of type ${objTpe.show} (${obj.show})") } def applyFunctionDelegate( @@ -331,9 +407,9 @@ object QuicklensMacros { case Nil => objTerm - case (_: PathSymbol.Field, _) :: _ => - val (fs, funs) = pathSymbols.span(_._1.isInstanceOf[PathSymbol.Field]) - val fields = fs.collect { case (p: PathSymbol.Field, trees) => p -> trees } + case (_: (PathSymbol.Field | PathSymbol.Extension), _) :: _ => + val (fs, funs) = pathSymbols.span((ps, _) => ps.isInstanceOf[PathSymbol.Field] || ps.isInstanceOf[PathSymbol.Extension]) + val fields = fs.collect { case (p: (PathSymbol.Field | PathSymbol.Extension), trees) => p -> trees } val withCopiedFields: Term = caseClassCopy(owner, mod, objTerm, fields) accumulateToCopy(owner, mod, withCopiedFields, funs) diff --git a/quicklens/src/main/scala-3/com/softwaremill/quicklens/package.scala b/quicklens/src/main/scala-3/com/softwaremill/quicklens/package.scala index 179801b..467b481 100644 --- a/quicklens/src/main/scala-3/com/softwaremill/quicklens/package.scala +++ b/quicklens/src/main/scala-3/com/softwaremill/quicklens/package.scala @@ -154,7 +154,7 @@ package object quicklens { def map[A](fa: M[A], f: A => A): M[A] = { val mapped = fa.view.mapValues(f) (fa match { - case sfa: SortedMap[K, A] => sfa.sortedMapFactory.from(mapped)(using sfa.ordering) + case sfa: SortedMap[K, A]@unchecked => sfa.sortedMapFactory.from(mapped)(using sfa.ordering) case _ => mapped.to(fa.mapFactory) }).asInstanceOf[M[A]] } diff --git a/quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ExplicitCopyTest.scala b/quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ExplicitCopyTest.scala index f1ecce1..d5dee42 100644 --- a/quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ExplicitCopyTest.scala +++ b/quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ExplicitCopyTest.scala @@ -1,4 +1,5 @@ package com.softwaremill.quicklens +package test import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers @@ -33,7 +34,8 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers { def paths(paths: Paths): Docs = copy(paths = paths) } val docs = Docs() - docs.modify(_.paths.pathItems).using(m => m + ("a" -> PathItem())) + val r = docs.modify(_.paths.pathItems).using(m => m + ("a" -> PathItem())) + r.paths.pathItems should contain ("a" -> PathItem()) } it should "modify a case class with an additional explicit copy" in { @@ -42,7 +44,8 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers { } val f = Frozen("A", 0) - f.modify(_.state).setTo("B") + val r = f.modify(_.state).setTo("B") + r.state shouldEqual "B" } it should "modify a case class with an ambiguous additional explicit copy" in { @@ -51,7 +54,8 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers { } val f = Frozen("A", 0) - f.modify(_.state).setTo("B") + val r = f.modify(_.state).setTo("B") + r.state shouldEqual "B" } it should "modify a class with two explicit copy methods" in { @@ -61,7 +65,8 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers { } val f = new Frozen("A", 0) - f.modify(_.state).setTo("B") + val r = f.modify(_.state).setTo("B") + r.state shouldEqual "B" } it should "modify a case class with an ambiguous additional explicit copy and pick the synthetic one first" in { @@ -77,6 +82,19 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers { accessed shouldEqual 0 } + it should "not compile when modifying a field which is not present as a copy parameter" in { + """ + case class Content(x: String) + + class A(val c: Content) { + def copy(x: String = c.x): A = new A(Content(x)) + } + + val a = new A(Content("A")) + val am = a.modify(_.c).setTo(Content("B")) + """ shouldNot compile + } + // TODO: Would be nice to be able to handle this case. Based on the types, it // is obvious, that the explicit copy should be picked, but I'm not sure if we // can get that information @@ -90,5 +108,4 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers { // val f = Frozen("A", 0) // f.modify(_.state).setTo('B') // } - } diff --git a/quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ExtensionCopyTest.scala b/quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ExtensionCopyTest.scala new file mode 100644 index 0000000..5c85919 --- /dev/null +++ b/quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ExtensionCopyTest.scala @@ -0,0 +1,90 @@ +package com.softwaremill.quicklens +package test + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +object ExtensionCopyTest { + case class V(x: Double, y: Double, z: Double) + + opaque type Vec = V + + object Vec { + def apply(x: Double, y: Double): Vec = V(x, y, 0) + + extension (v: Vec) { + def x: Double = v.x + def y: Double = v.y + def copy(x: Double = v.x, y: Double = v.y): Vec = V(x, y, 0) + } + } +} + +class ExtensionCopyTest extends AnyFlatSpec with Matchers { + it should "modify a simple class with an extension copy method" in { + // this test does compile at the moment, because we search extensions in companions only + /* + class VecSimple(xp: Double, yp: Double) { + val xMember = xp + val yMember = yp + } + + object VecSimple { + def apply(x: Double, y: Double): VecSimple = new VecSimple(x, y) + } + + extension (v: VecSimple) { + def copy(x: Double = v.xMember, y: Double = v.yMember): VecSimple = new VecSimple(x, y) + } + val a = VecSimple(1, 2) + val b = a.modify(_.xMember).using(_ + 10) + b.xMember shouldEqual 11 + */ + } + + it should "modify a simple class with an extension copy method in companion" in { + class VecCompanion(xp: Double, yp: Double) { + val x = xp + val y = yp + } + + object VecCompanion { + def apply(x: Double, y: Double): VecCompanion = new VecCompanion(x, y) + extension (v: VecCompanion) { + def copy(x: Double = v.x, y: Double = v.y): VecCompanion = new VecCompanion(x, y) + } + } + + val a = VecCompanion(1, 2) + val b = a.modify(_.x).using(_ + 10) + b.x shouldEqual 11 + } + + it should "modify a class with extension methods in companion" in { + case class V(xm: Double, ym: Double) + + class VecClass(val v: V) + + object VecClass { + def apply(x: Double, y: Double): VecClass = new VecClass(V(x, y)) + + extension (v: VecClass) { + def x: Double = v.v.xm + def y: Double = v.v.ym + def copy(x: Double = v.x, y: Double = v.y): VecClass = new VecClass(V(x, y)) + } + } + + val a = VecClass(1, 2) + val b = a.modify(_.x).using(_ + 10) + b.x shouldEqual 11 + } + + it should "modify an opaque type with extension methods" in { + import ExtensionCopyTest.* + + val a = Vec(1, 2) + val b = a.modify(_.x).using(_ + 10) + b.x shouldEqual 11 + } +}