From f2d3d4e1546709d73ec19222e06f3630bc2eb55b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ond=C5=99ej=20=C5=A0pan=C4=9Bl?= Date: Thu, 1 Sep 2022 15:52:22 +0200 Subject: [PATCH] Invert case class detection logic --- .../softwaremill/quicklens/QuicklensMacros.scala | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala b/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala index 19ed879..e5fafe1 100644 --- a/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala +++ b/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala @@ -161,7 +161,9 @@ object QuicklensMacros { fields: Seq[(PathSymbol.Field, Seq[PathTree])] ): Term = { val objSymbol = obj.tpe.typeSymbol - if objSymbol.flags.is(Flags.Case) then { + if !(objSymbol.flags.is(Flags.Enum) || + (objSymbol.flags.is(Flags.Sealed) && (objSymbol.flags.is(Flags.Trait) || objSymbol.flags.is(Flags.Abstract)))) + then { val copy = termMethodByNameUnsafe(obj, "copy") val argsMap: Map[Int, Term] = fields.map { (field, trees) => val (fieldMethod, idx) = termAccessorMethodByNameUnsafe(obj, field.name) @@ -185,11 +187,9 @@ object QuicklensMacros { case AppliedType(_, typeParams) => Apply(TypeApply(Select(obj, copy), typeParams.map(Inferred(_))), args) case _ => Apply(Select(obj, copy), args) } - } else if objSymbol.flags.is(Flags.Enum) || - (objSymbol.flags.is(Flags.Sealed) && (objSymbol.flags.is(Flags.Trait) || objSymbol.flags.is(Flags.Abstract))) - then { + } else { // if the source is a sealed trait / sealed abstract class / enum, generating a if-then-else with a .copy for each child (implementing case class) - val cases = obj.tpe.typeSymbol.children.map { child => + val cases = objSymbol.children.map { child => val subtype = TypeIdent(child) val bind = Symbol.newBind(owner, "c", Flags.EmptyFlags, subtype.tpe) CaseDef(Bind(bind, Typed(Ref(bind), subtype)), None, caseClassCopy(owner, mod, Ref(bind), fields)) @@ -201,7 +201,7 @@ object QuicklensMacros { ... else throw new IllegalStateException() */ - val ifThens = obj.tpe.typeSymbol.children.map { child => + val ifThens = objSymbol.children.map { child => val ifCond = TypeApply(Select.unique(obj, "isInstanceOf"), List(TypeIdent(child))) val ifThen = ValDef.let(owner, TypeApply(Select.unique(obj, "asInstanceOf"), List(TypeIdent(child)))) { @@ -216,8 +216,7 @@ object QuicklensMacros { ifThens.foldRight(elseThrow) { case ((ifCond, ifThen), ifElse) => If(ifCond, ifThen, ifElse) } - } else - report.throwError(s"Unsupported source object: must be a case class or sealed trait, but got: $objSymbol") + } } def applyFunctionDelegate(