Skip to content

Commit

Permalink
Some refactors
Browse files Browse the repository at this point in the history
  • Loading branch information
KacperFKorban committed Dec 9, 2024
1 parent 4a0c653 commit 643bab0
Showing 1 changed file with 63 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,10 @@ object QuicklensMacros {
case (symbol :: tail) => PathTree.Node(Seq(symbol -> Seq(tail.toPathTree)))

enum PathSymbol:
case Field(name: String)
case Extension(term: Term, name: String)
case FunctionDelegate(name: String, givn: Term, typeTree: TypeTree, args: List[Term])
case Field(override val name: String)
case Extension(term: Term, override val name: String)
case FunctionDelegate(override val name: String, givn: Term, typeTree: TypeTree, args: List[Term])
def name: String

def equiv(other: Any): Boolean = (this, other) match
case (Field(name1), Field(name2)) => name1 == name2
Expand All @@ -138,14 +139,9 @@ object QuicklensMacros {
/** Method call with one type parameter and using clause */
case a @ Apply(TypeApply(Apply(TypeApply(Ident(s), _), idents), typeTrees), List(givn)) if methodSupported(s) =>
idents.flatMap(toPath(_, focus)) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, List.empty)
case Apply(obj, Seq(deep)) => // this is an extension method, which is called e.g. as x(_$1)
obj match
case Ident(ident) =>
toPath(deep, focus) :+ PathSymbol.Field(ident)
case Select(term, member) =>
toPath(deep, focus) :+ PathSymbol.Extension(term, member)
case other =>
report.errorAndAbort(unsupportedShapeInfo(focus.asTerm))
/** Extension method, which is called e.g. as x(_$1) */
case Apply(obj@Select(term, member), Seq(deep)) if obj.symbol.flags.is(Flags.ExtensionMethod) =>
toPath(deep, focus) :+ PathSymbol.Extension(term, member)
/** Field access */
case Apply(deep, idents) =>
toPath(deep, focus) ++ idents.flatMap(toPath(_, focus))
Expand All @@ -167,49 +163,46 @@ object QuicklensMacros {
def widenAll: TypeRepr =
tpe.widen.dealias.poorMansLUB

def matchingTypeSymbol: Symbol = {
def recurse(tpe: TypeRepr): Symbol = {
tpe.widenAll match {
case AndType(l, r) =>
val lSym = recurse(l)
if lSym != Symbol.noSymbol then lSym else recurse(r)
case tpe if isProduct(tpe.typeSymbol) || isSum(tpe.typeSymbol) =>
tpe.typeSymbol
case tpe if isProductLike(tpe.typeSymbol) =>
tpe.typeSymbol
case _ =>
Symbol.noSymbol
}
}
val rSym = recurse(tpe)
if rSym != Symbol.noSymbol then rSym else tpe.typeSymbol // if everything else fails, try the original type, maybe it will have all we need
def matchingTypeSymbol: Symbol = tpe.widenAll match {
case AndType(l, r) =>
val lSym = l.matchingTypeSymbol
if lSym != Symbol.noSymbol then lSym else r.matchingTypeSymbol
case tpe if isProduct(tpe.typeSymbol) || isSum(tpe.typeSymbol) || isProductLike(tpe.typeSymbol) =>
tpe.typeSymbol
case _ =>
Symbol.noSymbol
}

extension (term: Term)
def appliedToIfNeeded(args: List[Term]): Term =
if args.isEmpty then term else term.appliedToArgs(args)

def symbolAccessorByNameOrError(obj: Term, name: String): Term = {
val objTpe = obj.tpe.widenAll
val objSymbol = objTpe.matchingTypeSymbol
// opaque types could find members of underlying types - do not ask them (see https://github.com/scala/scala3/issues/22143)
val mem = if !objSymbol.flags.is(Flags.Deferred) then objSymbol.fieldMember(name) else Symbol.noSymbol
if (mem != Symbol.noSymbol)
Select(obj, mem)
// opaque types can find members of underlying types - ignore them (see https://github.com/scala/scala3/issues/22143)
val fieldMemberSym = objSymbol.fieldMember(name)
if !objSymbol.flags.is(Flags.Deferred) && fieldMemberSym.exists then
Select(obj, fieldMemberSym)
else
objSymbol.methodMember(name) match
case List(m) =>
Select(obj, m)
case lst =>
reportMethodError(objSymbol, name, lst)
report.errorAndAbort(reportMethodError(objSymbol, name, lst))
}

def reportMethodError(sym: Symbol, name: String, lst: List[Symbol]): Nothing = {
lst match
case Nil => report.errorAndAbort(noSuchMember(sym.name, name))
case lst => report.errorAndAbort(multipleMatchingMethods(sym.name, name, lst))
def reportMethodError(sym: Symbol, name: String, lst: List[Symbol], maybeArgNames: Option[Iterable[String]] = None): String = {
(lst, maybeArgNames) match
case (Nil, _) => noSuchMember(sym.name, name)
case (lst, None) => multipleMatchingMethods(sym.name, name, lst)
case (lst, Some(argNames)) => noSuitableMember(sym.name, name, argNames)
}

def methodSymbolByNameOrError(sym: Symbol, name: String): Symbol = {
sym.methodMember(name) match
case List(m) => m
case lst => reportMethodError(sym, name, lst)
case lst => report.errorAndAbort(reportMethodError(sym, name, lst))
}

def filterMethodsByNameAndArgs(allMethods: List[Symbol], argsMap: Map[String, Term]): Option[Symbol] = {
Expand All @@ -233,7 +226,7 @@ object QuicklensMacros {
if !sym.flags.is(Flags.Deferred) then
val memberMethods = sym.methodMember(name)
filterMethodsByNameAndArgs(memberMethods, argsMap)
.toRight(if memberMethods.isEmpty then noSuchMember(sym.name, name) else noSuitableMember(sym.name, name, argsMap.keys))
.toRight(reportMethodError(sym, name, memberMethods, Some(argsMap.keys)))
else Left(s"Deferred type ${sym.name}")
}

Expand All @@ -242,49 +235,35 @@ object QuicklensMacros {
* on which the extension is called
* */
def callMethod(obj: Term, copy: Symbol, argsMap: List[Map[String, Term]]) = {
require(argsMap.size == 1 || argsMap.size == 2, s"argsMap.size should be either 1 or 2, got: ${argsMap.size} ($argsMap)")
val objTpe = obj.tpe.widenAll
val objSymbol = objTpe.matchingTypeSymbol

val typeParams = objTpe match {
case AppliedType(_, typeParams) => Some(typeParams)
case _ => None
}
val typeParams = objTpe.typeArgs
val copyTree: DefDef = copy.tree.asInstanceOf[DefDef]
val copyParams: List[(String, Option[Term])] = copyTree.termParamss.zip(argsMap)
.map((params, args) => params.params.map(_.name).map(name => name -> args.get(name)))
.flatten.toList

val args = copyParams.zipWithIndex.map { case ((n, v), _i) =>
val i = _i + 1
def defaultMethod =
def defaultMethod: Term =
val methodSymbol = methodSymbolByNameOrError(objSymbol, copy.name + "$default$" + i.toString)
// default values in extensions are obtained by calling a method receiving the extension parameter
val defaultMethodArgs = argsMap.dropRight(1).headOption.toList.flatMap(_.values)
if defaultMethodArgs.nonEmpty then
Apply(Select(obj, methodSymbol), defaultMethodArgs)
else
// note: this is not always correct, -Xcheck-macros shows errors here
// sometimes we should call a method with empry parameter list instead
obj.select(methodSymbol)

// for extension methods, might need sth more like this: (or probably some weird implicit conversion)
// val defaultGetter = obj.select(symbolMethodByNameOrError(objSymbol, n))
// default values in extension methods take the extension receiver as the first parameter
val defaultMethodArgs = argsMap.dropRight(1).flatMap(_.values)
obj.select(methodSymbol).appliedToIfNeeded(defaultMethodArgs)
n -> v.getOrElse(defaultMethod)
}.toMap

val argLists = copyTree.termParamss.take(argsMap.size).map(list => list.params.map(p => args(p.name)))
val argLists: List[List[Term]] = copyTree.termParamss.take(argsMap.size).map(list => list.params.map(p => args(p.name)))

if copyTree.termParamss.drop(argLists.size).exists(_.params.exists(!_.symbol.flags.is(Flags.Implicit))) then
report.errorAndAbort(
s"Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit. ${copyTree.termParamss.drop(1)}"
)

val applyOn = typeParams match {
// if the object's type is parametrised, we need to call .copy with the same type parameters
case Some(typeParams) => TypeApply(Select(obj, copy), typeParams.map(Inferred(_)))
case _ => Select(obj, copy)
}
argLists.foldLeft(applyOn)((applied, list) => Apply(applied, list))
val withTypeParamsApplied = obj.select(copy).appliedToTypes(typeParams)
argLists.foldLeft(withTypeParamsApplied)(Apply(_, _))
}

def termMethodByNameUnsafe(term: Term, name: String): Symbol = {
Expand All @@ -301,26 +280,25 @@ object QuicklensMacros {
(sym.flags.is(Flags.Sealed) && (sym.flags.is(Flags.Trait) || sym.flags.is(Flags.Abstract)))
}

def findCompanionLikeObject(objSymbol: Symbol): Option[Symbol] = {
def optSymbol(objSymbol: Symbol) = Option.when(!objSymbol.isNoSymbol)(objSymbol)
optSymbol(objSymbol.companionModule).orElse {
// for opaque types, the companion type is not found by objSymbol.companionModule
// try to find an object by name in the owner scope
optSymbol(objSymbol.owner.fieldMember(objSymbol.name)).filter(_.flags.is(Flags.Module))
}
def findCompanionLikeObject(objSymbol: Symbol): Symbol = {
if objSymbol.companionModule.exists then
objSymbol.companionModule
else
val namedFromOwnerScope = objSymbol.owner.fieldMember(objSymbol.name)
if namedFromOwnerScope.flags.is(Flags.Module) then namedFromOwnerScope
else Symbol.noSymbol
}
def findExtensionMethod(using Quotes)(sym: Symbol, methodName: String): List[(Term, Symbol)] = {
// TODO: can we check parameter types somehow?
def isExtensionMethod(sym: Symbol): Boolean = sym.isDefDef && sym.paramSymss.headOption.exists(_.sizeIs == 1)

// TODO: try to search in symbol parent scope as well, as extension methods could be located there as well
val symbols = findCompanionLikeObject(sym).filter(_ != Symbol.noSymbol).toList

symbols.flatMap(s => s.declaredMethods.map(Ref(s) -> _)).filter((_, m) => m.name == methodName && isExtensionMethod(m)).toList
def hasExtensionNamed(sym: Symbol, methodName: String): List[Symbol] = {
val companionSymbol = findCompanionLikeObject(sym)
if companionSymbol.exists then
companionSymbol.methodMember(methodName).filter(s => s.name == methodName && s.flags.is(Flags.ExtensionMethod))
else
Nil
}

def isProductLike(sym: Symbol): Boolean = {
sym.methodMember("copy").nonEmpty || findExtensionMethod(sym, "copy").nonEmpty
sym.methodMember("copy").nonEmpty || hasExtensionNamed(sym, "copy").nonEmpty
}

def caseClassCopy(
Expand Down Expand Up @@ -363,30 +341,30 @@ object QuicklensMacros {
}
} else if isProduct(objSymbol) || isProductLike(objSymbol) then {
val argsMap: Map[String, Term] = fields.map { (field, trees) =>
val (fieldMethod, name) = field match {
val fieldMethod = field match {
case PathSymbol.Field(name) =>
symbolAccessorByNameOrError (obj, name) -> name
symbolAccessorByNameOrError(obj, name)
case PathSymbol.Extension(term, name) =>
val extensionMethod = symbolAccessorByNameOrError (term, name)
Apply(extensionMethod, List(obj)) -> name
val extensionMethod = symbolAccessorByNameOrError(term, name)
Apply(extensionMethod, List(obj))
}
val resTerm: Term = trees.foldLeft[Term](fieldMethod) { (term, tree) =>
mapToCopy(owner, mod, term, tree)
}
val namedArg = NamedArg(name, resTerm)
name -> namedArg
val namedArg = NamedArg(field.name, resTerm)
field.name -> namedArg
}.toMap
methodSymbolByNameAndArgs(objSymbol, "copy", argsMap) match
case Right(copy) =>
callMethod(obj, copy, List(argsMap))
case Left(error) =>
val objCompanion = findCompanionLikeObject(objSymbol)
objCompanion.flatMap(methodSymbolByNameAndArgs(_, "copy", argsMap).toOption) match
methodSymbolByNameAndArgs(objCompanion, "copy", argsMap).toOption match
case Some(copy) =>
// now try to call the extension as a method, assume the object is its first parameter
val extensionParameter = copy.paramSymss.headOption.map(_.headOption).flatten
val argsWithObj = List(extensionParameter.map(name => name.name -> obj).toMap, argsMap)
callMethod(Ref(objCompanion.get), copy, argsWithObj)
callMethod(Ref(objCompanion), copy, argsWithObj)
case None => report.errorAndAbort(error)
} else
report.errorAndAbort(s"Unsupported source object: must be a case class, sealed trait or class with copy method, but got: $objSymbol of type ${objTpe.show} (${obj.show})")
Expand Down Expand Up @@ -430,7 +408,7 @@ object QuicklensMacros {
objTerm

case (_: (PathSymbol.Field | PathSymbol.Extension), _) :: _ =>
val (fs, funs) = pathSymbols.span(s => s._1.isInstanceOf[PathSymbol.Field] | s._1.isInstanceOf[PathSymbol.Extension])
val (fs, funs) = pathSymbols.partition((ps, _) => ps.isInstanceOf[PathSymbol.Field] || ps.isInstanceOf[PathSymbol.Extension])
val fields = fs.collect { case (p: (PathSymbol.Field | PathSymbol.Extension), trees) => p -> trees }
val withCopiedFields: Term = caseClassCopy(owner, mod, objTerm, fields)
accumulateToCopy(owner, mod, withCopiedFields, funs)
Expand Down

0 comments on commit 643bab0

Please sign in to comment.