Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extension copy #262

Merged
merged 14 commits into from
Dec 17, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,10 @@ object QuicklensMacros {
case (symbol :: tail) => PathTree.Node(Seq(symbol -> Seq(tail.toPathTree)))

enum PathSymbol:
case Field(name: String)
case Extension(term: Term, name: String)
case FunctionDelegate(name: String, givn: Term, typeTree: TypeTree, args: List[Term])
case Field(override val name: String)
case Extension(term: Term, override val name: String)
case FunctionDelegate(override val name: String, givn: Term, typeTree: TypeTree, args: List[Term])
def name: String

def equiv(other: Any): Boolean = (this, other) match
case (Field(name1), Field(name2)) => name1 == name2
Expand All @@ -138,14 +139,9 @@ object QuicklensMacros {
/** Method call with one type parameter and using clause */
case a @ Apply(TypeApply(Apply(TypeApply(Ident(s), _), idents), typeTrees), List(givn)) if methodSupported(s) =>
idents.flatMap(toPath(_, focus)) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, List.empty)
case Apply(obj, Seq(deep)) => // this is an extension method, which is called e.g. as x(_$1)
obj match
case Ident(ident) =>
toPath(deep, focus) :+ PathSymbol.Field(ident)
case Select(term, member) =>
toPath(deep, focus) :+ PathSymbol.Extension(term, member)
case other =>
report.errorAndAbort(unsupportedShapeInfo(focus.asTerm))
/** Extension method, which is called e.g. as x(_$1) */
case Apply(obj@Select(term, member), Seq(deep)) if obj.symbol.flags.is(Flags.ExtensionMethod) =>
toPath(deep, focus) :+ PathSymbol.Extension(term, member)
/** Field access */
case Apply(deep, idents) =>
toPath(deep, focus) ++ idents.flatMap(toPath(_, focus))
Expand All @@ -167,49 +163,46 @@ object QuicklensMacros {
def widenAll: TypeRepr =
tpe.widen.dealias.poorMansLUB

def matchingTypeSymbol: Symbol = {
def recurse(tpe: TypeRepr): Symbol = {
tpe.widenAll match {
case AndType(l, r) =>
val lSym = recurse(l)
if lSym != Symbol.noSymbol then lSym else recurse(r)
case tpe if isProduct(tpe.typeSymbol) || isSum(tpe.typeSymbol) =>
tpe.typeSymbol
case tpe if isProductLike(tpe.typeSymbol) =>
tpe.typeSymbol
case _ =>
Symbol.noSymbol
}
}
val rSym = recurse(tpe)
if rSym != Symbol.noSymbol then rSym else tpe.typeSymbol // if everything else fails, try the original type, maybe it will have all we need
def matchingTypeSymbol: Symbol = tpe.widenAll match {
case AndType(l, r) =>
val lSym = l.matchingTypeSymbol
if lSym != Symbol.noSymbol then lSym else r.matchingTypeSymbol
case tpe if isProduct(tpe.typeSymbol) || isSum(tpe.typeSymbol) || isProductLike(tpe.typeSymbol) =>
tpe.typeSymbol
case _ =>
OndrejSpanel marked this conversation as resolved.
Show resolved Hide resolved
Symbol.noSymbol
}

extension (term: Term)
def appliedToIfNeeded(args: List[Term]): Term =
if args.isEmpty then term else term.appliedToArgs(args)

def symbolAccessorByNameOrError(obj: Term, name: String): Term = {
val objTpe = obj.tpe.widenAll
val objSymbol = objTpe.matchingTypeSymbol
// opaque types could find members of underlying types - do not ask them (see https://github.com/scala/scala3/issues/22143)
val mem = if !objSymbol.flags.is(Flags.Deferred) then objSymbol.fieldMember(name) else Symbol.noSymbol
if (mem != Symbol.noSymbol)
Select(obj, mem)
// opaque types can find members of underlying types - ignore them (see https://github.com/scala/scala3/issues/22143)
val fieldMemberSym = objSymbol.fieldMember(name)
if !objSymbol.flags.is(Flags.Deferred) && fieldMemberSym.exists then
Select(obj, fieldMemberSym)
else
objSymbol.methodMember(name) match
case List(m) =>
Select(obj, m)
case lst =>
reportMethodError(objSymbol, name, lst)
report.errorAndAbort(reportMethodError(objSymbol, name, lst))
}

def reportMethodError(sym: Symbol, name: String, lst: List[Symbol]): Nothing = {
lst match
case Nil => report.errorAndAbort(noSuchMember(sym.name, name))
case lst => report.errorAndAbort(multipleMatchingMethods(sym.name, name, lst))
def reportMethodError(sym: Symbol, name: String, lst: List[Symbol], maybeArgNames: Option[Iterable[String]] = None): String = {
(lst, maybeArgNames) match
case (Nil, _) => noSuchMember(sym.name, name)
case (lst, None) => multipleMatchingMethods(sym.name, name, lst)
case (lst, Some(argNames)) => noSuitableMember(sym.name, name, argNames)
}

def methodSymbolByNameOrError(sym: Symbol, name: String): Symbol = {
sym.methodMember(name) match
case List(m) => m
case lst => reportMethodError(sym, name, lst)
case lst => report.errorAndAbort(reportMethodError(sym, name, lst))
}

def filterMethodsByNameAndArgs(allMethods: List[Symbol], argsMap: Map[String, Term]): Option[Symbol] = {
Expand All @@ -233,7 +226,7 @@ object QuicklensMacros {
if !sym.flags.is(Flags.Deferred) then
val memberMethods = sym.methodMember(name)
filterMethodsByNameAndArgs(memberMethods, argsMap)
.toRight(if memberMethods.isEmpty then noSuchMember(sym.name, name) else noSuitableMember(sym.name, name, argsMap.keys))
.toRight(reportMethodError(sym, name, memberMethods, Some(argsMap.keys)))
else Left(s"Deferred type ${sym.name}")
}

Expand All @@ -242,49 +235,35 @@ object QuicklensMacros {
* on which the extension is called
* */
def callMethod(obj: Term, copy: Symbol, argsMap: List[Map[String, Term]]) = {
require(argsMap.size == 1 || argsMap.size == 2, s"argsMap.size should be either 1 or 2, got: ${argsMap.size} ($argsMap)")
val objTpe = obj.tpe.widenAll
val objSymbol = objTpe.matchingTypeSymbol

val typeParams = objTpe match {
case AppliedType(_, typeParams) => Some(typeParams)
case _ => None
}
val typeParams = objTpe.typeArgs
val copyTree: DefDef = copy.tree.asInstanceOf[DefDef]
val copyParams: List[(String, Option[Term])] = copyTree.termParamss.zip(argsMap)
.map((params, args) => params.params.map(_.name).map(name => name -> args.get(name)))
.flatten.toList

val args = copyParams.zipWithIndex.map { case ((n, v), _i) =>
val i = _i + 1
def defaultMethod =
def defaultMethod: Term =
val methodSymbol = methodSymbolByNameOrError(objSymbol, copy.name + "$default$" + i.toString)
// default values in extensions are obtained by calling a method receiving the extension parameter
val defaultMethodArgs = argsMap.dropRight(1).headOption.toList.flatMap(_.values)
if defaultMethodArgs.nonEmpty then
Apply(Select(obj, methodSymbol), defaultMethodArgs)
else
// note: this is not always correct, -Xcheck-macros shows errors here
// sometimes we should call a method with empry parameter list instead
obj.select(methodSymbol)

// for extension methods, might need sth more like this: (or probably some weird implicit conversion)
// val defaultGetter = obj.select(symbolMethodByNameOrError(objSymbol, n))
// default values in extension methods take the extension receiver as the first parameter
val defaultMethodArgs = argsMap.dropRight(1).flatMap(_.values)
obj.select(methodSymbol).appliedToIfNeeded(defaultMethodArgs)
n -> v.getOrElse(defaultMethod)
}.toMap

val argLists = copyTree.termParamss.take(argsMap.size).map(list => list.params.map(p => args(p.name)))
val argLists: List[List[Term]] = copyTree.termParamss.take(argsMap.size).map(list => list.params.map(p => args(p.name)))

if copyTree.termParamss.drop(argLists.size).exists(_.params.exists(!_.symbol.flags.is(Flags.Implicit))) then
report.errorAndAbort(
s"Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit. ${copyTree.termParamss.drop(1)}"
)

val applyOn = typeParams match {
// if the object's type is parametrised, we need to call .copy with the same type parameters
case Some(typeParams) => TypeApply(Select(obj, copy), typeParams.map(Inferred(_)))
case _ => Select(obj, copy)
}
argLists.foldLeft(applyOn)((applied, list) => Apply(applied, list))
val withTypeParamsApplied = obj.select(copy).appliedToTypes(typeParams)
argLists.foldLeft(withTypeParamsApplied)(Apply(_, _))
}

def termMethodByNameUnsafe(term: Term, name: String): Symbol = {
Expand All @@ -301,26 +280,25 @@ object QuicklensMacros {
(sym.flags.is(Flags.Sealed) && (sym.flags.is(Flags.Trait) || sym.flags.is(Flags.Abstract)))
}

def findCompanionLikeObject(objSymbol: Symbol): Option[Symbol] = {
def optSymbol(objSymbol: Symbol) = Option.when(!objSymbol.isNoSymbol)(objSymbol)
optSymbol(objSymbol.companionModule).orElse {
// for opaque types, the companion type is not found by objSymbol.companionModule
// try to find an object by name in the owner scope
optSymbol(objSymbol.owner.fieldMember(objSymbol.name)).filter(_.flags.is(Flags.Module))
}
def findCompanionLikeObject(objSymbol: Symbol): Symbol = {
if objSymbol.companionModule.exists then
objSymbol.companionModule
else
val namedFromOwnerScope = objSymbol.owner.fieldMember(objSymbol.name)
if namedFromOwnerScope.flags.is(Flags.Module) then namedFromOwnerScope
else Symbol.noSymbol
}
def findExtensionMethod(using Quotes)(sym: Symbol, methodName: String): List[(Term, Symbol)] = {
// TODO: can we check parameter types somehow?
def isExtensionMethod(sym: Symbol): Boolean = sym.isDefDef && sym.paramSymss.headOption.exists(_.sizeIs == 1)

// TODO: try to search in symbol parent scope as well, as extension methods could be located there as well
val symbols = findCompanionLikeObject(sym).filter(_ != Symbol.noSymbol).toList

symbols.flatMap(s => s.declaredMethods.map(Ref(s) -> _)).filter((_, m) => m.name == methodName && isExtensionMethod(m)).toList
def hasExtensionNamed(sym: Symbol, methodName: String): List[Symbol] = {
val companionSymbol = findCompanionLikeObject(sym)
if companionSymbol.exists then
companionSymbol.methodMember(methodName).filter(s => s.name == methodName && s.flags.is(Flags.ExtensionMethod))
else
Nil
}

def isProductLike(sym: Symbol): Boolean = {
sym.methodMember("copy").nonEmpty || findExtensionMethod(sym, "copy").nonEmpty
sym.methodMember("copy").nonEmpty || hasExtensionNamed(sym, "copy").nonEmpty
}

def caseClassCopy(
Expand Down Expand Up @@ -363,30 +341,30 @@ object QuicklensMacros {
}
} else if isProduct(objSymbol) || isProductLike(objSymbol) then {
val argsMap: Map[String, Term] = fields.map { (field, trees) =>
val (fieldMethod, name) = field match {
val fieldMethod = field match {
case PathSymbol.Field(name) =>
symbolAccessorByNameOrError (obj, name) -> name
symbolAccessorByNameOrError(obj, name)
case PathSymbol.Extension(term, name) =>
val extensionMethod = symbolAccessorByNameOrError (term, name)
Apply(extensionMethod, List(obj)) -> name
val extensionMethod = symbolAccessorByNameOrError(term, name)
Apply(extensionMethod, List(obj))
}
val resTerm: Term = trees.foldLeft[Term](fieldMethod) { (term, tree) =>
mapToCopy(owner, mod, term, tree)
}
val namedArg = NamedArg(name, resTerm)
name -> namedArg
val namedArg = NamedArg(field.name, resTerm)
field.name -> namedArg
}.toMap
methodSymbolByNameAndArgs(objSymbol, "copy", argsMap) match
case Right(copy) =>
callMethod(obj, copy, List(argsMap))
case Left(error) =>
val objCompanion = findCompanionLikeObject(objSymbol)
objCompanion.flatMap(methodSymbolByNameAndArgs(_, "copy", argsMap).toOption) match
methodSymbolByNameAndArgs(objCompanion, "copy", argsMap).toOption match
OndrejSpanel marked this conversation as resolved.
Show resolved Hide resolved
case Some(copy) =>
// now try to call the extension as a method, assume the object is its first parameter
val extensionParameter = copy.paramSymss.headOption.map(_.headOption).flatten
val argsWithObj = List(extensionParameter.map(name => name.name -> obj).toMap, argsMap)
callMethod(Ref(objCompanion.get), copy, argsWithObj)
callMethod(Ref(objCompanion), copy, argsWithObj)
case None => report.errorAndAbort(error)
} else
report.errorAndAbort(s"Unsupported source object: must be a case class, sealed trait or class with copy method, but got: $objSymbol of type ${objTpe.show} (${obj.show})")
Expand Down Expand Up @@ -430,7 +408,7 @@ object QuicklensMacros {
objTerm

case (_: (PathSymbol.Field | PathSymbol.Extension), _) :: _ =>
val (fs, funs) = pathSymbols.span(s => s._1.isInstanceOf[PathSymbol.Field] | s._1.isInstanceOf[PathSymbol.Extension])
val (fs, funs) = pathSymbols.partition((ps, _) => ps.isInstanceOf[PathSymbol.Field] || ps.isInstanceOf[PathSymbol.Extension])
KacperFKorban marked this conversation as resolved.
Show resolved Hide resolved
val fields = fs.collect { case (p: (PathSymbol.Field | PathSymbol.Extension), trees) => p -> trees }
val withCopiedFields: Term = caseClassCopy(owner, mod, objTerm, fields)
accumulateToCopy(owner, mod, withCopiedFields, funs)
Expand Down