Skip to content

Commit

Permalink
Add a compile-time option to inline non value classes which have the …
Browse files Browse the repository at this point in the history
…primary constructor with just one argument
  • Loading branch information
plokhotnyuk committed Jan 2, 2024
1 parent ad2b7c0 commit f7b9ccd
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 89 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ There are configurable options that can be set in compile-time:
- Ability to disable generation of implementation for decoding or encoding
- Ability to require fields that have defined default values
- Ability to generate smaller and more efficient codecs for classes when checking of field duplication is not needed
- Ability to inline non value classes which have the primary constructor with just one argument

List of options that change parsing and serialization in runtime:
- Serialization of strings with escaped Unicode characters to be ASCII compatible
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ final class stringified extends StaticAnnotation
* @param checkFieldDuplication a flag that turns on checking of duplicated fields during parsing of classes (turned
* on by default)
* @param scalaTransientSupport a flag that turns on support of `scala.transient` (turned off by default)
* @param inlineOneValueClasses a flag that turns on derivation of inlined codecs for non-values classes that have
* the primary constructor with just one argument (turned off by default)
*/
class CodecMakerConfig private (
val fieldNameMapper: PartialFunction[String, String],
Expand Down Expand Up @@ -120,7 +122,8 @@ class CodecMakerConfig private (
val encodingOnly: Boolean,
val requireDefaultFields: Boolean,
val checkFieldDuplication: Boolean,
val scalaTransientSupport: Boolean) {
val scalaTransientSupport: Boolean,
val inlineOneValueClasses: Boolean) {
def withFieldNameMapper(fieldNameMapper: PartialFunction[String, String]): CodecMakerConfig =
copy(fieldNameMapper = fieldNameMapper)

Expand Down Expand Up @@ -196,6 +199,9 @@ class CodecMakerConfig private (
def withScalaTransientSupport(scalaTransientSupport: Boolean): CodecMakerConfig =
copy(scalaTransientSupport = scalaTransientSupport)

def withInlineOneValueClasses(inlineOneValueClasses: Boolean): CodecMakerConfig =
copy(inlineOneValueClasses = inlineOneValueClasses)

private[this] def copy(fieldNameMapper: PartialFunction[String, String] = fieldNameMapper,
javaEnumValueNameMapper: PartialFunction[String, String] = javaEnumValueNameMapper,
adtLeafClassNameMapper: String => String = adtLeafClassNameMapper,
Expand Down Expand Up @@ -223,7 +229,8 @@ class CodecMakerConfig private (
encodingOnly: Boolean = encodingOnly,
requireDefaultFields: Boolean = requireDefaultFields,
checkFieldDuplication: Boolean = checkFieldDuplication,
scalaTransientSupport: Boolean = scalaTransientSupport): CodecMakerConfig =
scalaTransientSupport: Boolean = scalaTransientSupport,
inlineOneValueClasses: Boolean = inlineOneValueClasses): CodecMakerConfig =
new CodecMakerConfig(
fieldNameMapper = fieldNameMapper,
javaEnumValueNameMapper = javaEnumValueNameMapper,
Expand Down Expand Up @@ -252,7 +259,8 @@ class CodecMakerConfig private (
encodingOnly = encodingOnly,
requireDefaultFields = requireDefaultFields,
checkFieldDuplication = checkFieldDuplication,
scalaTransientSupport = scalaTransientSupport)
scalaTransientSupport = scalaTransientSupport,
inlineOneValueClasses = inlineOneValueClasses)
}

object CodecMakerConfig extends CodecMakerConfig(
Expand Down Expand Up @@ -283,7 +291,8 @@ object CodecMakerConfig extends CodecMakerConfig(
decodingOnly = false,
requireDefaultFields = false,
checkFieldDuplication = true,
scalaTransientSupport = false) {
scalaTransientSupport = false,
inlineOneValueClasses = false) {

/**
* Use to enable printing of codec during compilation:
Expand Down Expand Up @@ -583,8 +592,6 @@ object JsonCodecMaker {

def isTuple(tpe: Type): Boolean = tupleSymbols(tpe.typeSymbol)

def isValueClass(tpe: Type): Boolean = tpe.typeSymbol.isClass && tpe.typeSymbol.asClass.isDerivedValueClass

def valueClassValueMethod(tpe: Type): MethodSymbol = tpe.decls.head.asMethod

def decodeName(s: Symbol): String = NameTransformer.decode(s.name.toString)
Expand Down Expand Up @@ -797,6 +804,80 @@ object JsonCodecMaker {
}
}

case class FieldInfo(symbol: TermSymbol, mappedName: String, tmpName: TermName, getter: MethodSymbol,
defaultValue: Option[Tree], resolvedTpe: Type, isStringified: Boolean)

case class ClassInfo(tpe: Type, paramLists: List[List[FieldInfo]]) {
val fields: List[FieldInfo] = paramLists.flatten
}

val classInfos = new mutable.LinkedHashMap[Type, ClassInfo]

def getClassInfo(tpe: Type): ClassInfo = classInfos.getOrElseUpdate(tpe, {
case class FieldAnnotations(partiallyMappedName: String, transient: Boolean, stringified: Boolean)

def getPrimaryConstructor(tpe: Type): MethodSymbol = tpe.decls.collectFirst {
case m: MethodSymbol if m.isPrimaryConstructor => m // FIXME: sometime it cannot be accessed from the place of the `make` call
}.getOrElse(fail(s"Cannot find a primary constructor for '$tpe'"))

def hasSupportedAnnotation(m: TermSymbol): Boolean = {
m.info: Unit // to enforce the type information completeness and availability of annotations
m.annotations.exists(a => a.tree.tpe =:= typeOf[named] || a.tree.tpe =:= typeOf[transient] ||
a.tree.tpe =:= typeOf[stringified] || (cfg.scalaTransientSupport && a.tree.tpe =:= typeOf[scala.transient]))
}

def supportedTransientTypeNames: String =
if (cfg.scalaTransientSupport) s"'${typeOf[transient]}' (or '${typeOf[scala.transient]}')"
else s"'${typeOf[transient]}')"

lazy val module = companion(tpe).asModule // don't lookup for the companion when there are no default values for constructor params
val getters = tpe.members.collect { case m: MethodSymbol if m.isParamAccessor && m.isGetter => m }
val annotations = tpe.members.collect {
case m: TermSymbol if hasSupportedAnnotation(m) =>
val name = decodeName(m).trim // FIXME: Why is there a space at the end of field name?!
val named = m.annotations.filter(_.tree.tpe =:= typeOf[named])
if (named.size > 1) fail(s"Duplicated '${typeOf[named]}' defined for '$name' of '$tpe'.")
val trans = m.annotations.filter(a => a.tree.tpe =:= typeOf[transient] ||
(cfg.scalaTransientSupport && a.tree.tpe =:= typeOf[scala.transient]))
if (trans.size > 1) warn(s"Duplicated $supportedTransientTypeNames defined for '$name' of '$tpe'.")
val strings = m.annotations.filter(_.tree.tpe =:= typeOf[stringified])
if (strings.size > 1) warn(s"Duplicated '${typeOf[stringified]}' defined for '$name' of '$tpe'.")
if ((named.nonEmpty || strings.nonEmpty) && trans.nonEmpty) {
warn(s"Both $supportedTransientTypeNames and '${typeOf[named]}' or " +
s"$supportedTransientTypeNames and '${typeOf[stringified]}' defined for '$name' of '$tpe'.")
}
val partiallyMappedName = namedValueOpt(named.headOption, tpe).getOrElse(name)
(name, FieldAnnotations(partiallyMappedName, trans.nonEmpty, strings.nonEmpty))
}.toMap
ClassInfo(tpe, {
var i = 0
getPrimaryConstructor(tpe).paramLists.map(_.flatMap { p =>
i += 1
val symbol = p.asTerm
val name = decodeName(symbol)
val annotationOption = annotations.get(name)
if (annotationOption.exists(_.transient)) None
else {
val fieldNameMapper: String => String = n => cfg.fieldNameMapper.lift(n).getOrElse(n)
val mappedName = annotationOption.fold(fieldNameMapper(name))(_.partiallyMappedName)
val tmpName = TermName("_" + symbol.name)
val getter = getters.find(_.name == symbol.name).getOrElse {
fail(s"'$name' parameter of '$tpe' should be defined as 'val' or 'var' in the primary constructor.")
}
val defaultValue =
if (!cfg.requireDefaultFields && symbol.isParamWithDefault) Some(q"$module.${TermName("$lessinit$greater$default$" + i)}")
else None
val isStringified = annotationOption.exists(_.stringified)
Some(FieldInfo(symbol, mappedName, tmpName, getter, defaultValue, paramType(tpe, symbol), isStringified))
}
})
})
})

def isValueClass(tpe: Type): Boolean = !isConstType(tpe) &&
(cfg.inlineOneValueClasses && isNonAbstractScalaClass(tpe) && !isCollection(tpe) && getClassInfo(tpe).fields.size == 1 ||
tpe.typeSymbol.isClass && tpe.typeSymbol.asClass.isDerivedValueClass)

def genReadKey(types: List[Type]): Tree = {
val tpe = types.head
val implKeyCodec = findImplicitKeyCodec(types)
Expand Down Expand Up @@ -1071,76 +1152,6 @@ object JsonCodecMaker {
}
}

case class FieldInfo(symbol: TermSymbol, mappedName: String, tmpName: TermName, getter: MethodSymbol,
defaultValue: Option[Tree], resolvedTpe: Type, isStringified: Boolean)

case class ClassInfo(tpe: Type, paramLists: List[List[FieldInfo]]) {
val fields: List[FieldInfo] = paramLists.flatten
}

val classInfos = new mutable.LinkedHashMap[Type, ClassInfo]

def getClassInfo(tpe: Type): ClassInfo = classInfos.getOrElseUpdate(tpe, {
case class FieldAnnotations(partiallyMappedName: String, transient: Boolean, stringified: Boolean)

def getPrimaryConstructor(tpe: Type): MethodSymbol = tpe.decls.collectFirst {
case m: MethodSymbol if m.isPrimaryConstructor => m // FIXME: sometime it cannot be accessed from the place of the `make` call
}.getOrElse(fail(s"Cannot find a primary constructor for '$tpe'"))

def hasSupportedAnnotation(m: TermSymbol): Boolean = {
m.info: Unit // to enforce the type information completeness and availability of annotations
m.annotations.exists(a => a.tree.tpe =:= typeOf[named] || a.tree.tpe =:= typeOf[transient] ||
a.tree.tpe =:= typeOf[stringified] || (cfg.scalaTransientSupport && a.tree.tpe =:= typeOf[scala.transient]))
}

def supportedTransientTypeNames: String =
if (cfg.scalaTransientSupport) s"'${typeOf[transient]}' (or '${typeOf[scala.transient]}')"
else s"'${typeOf[transient]}')"

lazy val module = companion(tpe).asModule // don't lookup for the companion when there are no default values for constructor params
val getters = tpe.members.collect { case m: MethodSymbol if m.isParamAccessor && m.isGetter => m }
val annotations = tpe.members.collect {
case m: TermSymbol if hasSupportedAnnotation(m) =>
val name = decodeName(m).trim // FIXME: Why is there a space at the end of field name?!
val named = m.annotations.filter(_.tree.tpe =:= typeOf[named])
if (named.size > 1) fail(s"Duplicated '${typeOf[named]}' defined for '$name' of '$tpe'.")
val trans = m.annotations.filter(a => a.tree.tpe =:= typeOf[transient] ||
(cfg.scalaTransientSupport && a.tree.tpe =:= typeOf[scala.transient]))
if (trans.size > 1) warn(s"Duplicated $supportedTransientTypeNames defined for '$name' of '$tpe'.")
val strings = m.annotations.filter(_.tree.tpe =:= typeOf[stringified])
if (strings.size > 1) warn(s"Duplicated '${typeOf[stringified]}' defined for '$name' of '$tpe'.")
if ((named.nonEmpty || strings.nonEmpty) && trans.nonEmpty) {
warn(s"Both $supportedTransientTypeNames and '${typeOf[named]}' or " +
s"$supportedTransientTypeNames and '${typeOf[stringified]}' defined for '$name' of '$tpe'.")
}
val partiallyMappedName = namedValueOpt(named.headOption, tpe).getOrElse(name)
(name, FieldAnnotations(partiallyMappedName, trans.nonEmpty, strings.nonEmpty))
}.toMap
ClassInfo(tpe, {
var i = 0
getPrimaryConstructor(tpe).paramLists.map(_.flatMap { p =>
i += 1
val symbol = p.asTerm
val name = decodeName(symbol)
val annotationOption = annotations.get(name)
if (annotationOption.exists(_.transient)) None
else {
val fieldNameMapper: String => String = n => cfg.fieldNameMapper.lift(n).getOrElse(n)
val mappedName = annotationOption.fold(fieldNameMapper(name))(_.partiallyMappedName)
val tmpName = TermName("_" + symbol.name)
val getter = getters.find(_.name == symbol.name).getOrElse {
fail(s"'$name' parameter of '$tpe' should be defined as 'val' or 'var' in the primary constructor.")
}
val defaultValue =
if (!cfg.requireDefaultFields && symbol.isParamWithDefault) Some(q"$module.${TermName("$lessinit$greater$default$" + i)}")
else None
val isStringified = annotationOption.exists(_.stringified)
Some(FieldInfo(symbol, mappedName, tmpName, getter, defaultValue, paramType(tpe, symbol), isStringified))
}
})
})
})

val unexpectedFieldHandler =
if (cfg.skipUnexpectedFields) q"in.skip()"
else q"in.unexpectedKeyError(l)"
Expand Down
Loading

0 comments on commit f7b9ccd

Please sign in to comment.