Skip to content

Commit

Permalink
Support customizing $type tag with key annotation.
Browse files Browse the repository at this point in the history
  • Loading branch information
mrdziuban committed May 14, 2024
1 parent 3845e9e commit b472fc9
Show file tree
Hide file tree
Showing 12 changed files with 124 additions and 82 deletions.
53 changes: 29 additions & 24 deletions upickle/core/src/upickle/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ trait Types{ types =>
abstract class Delegate[T](other: Visitor[Any, T])
extends Visitor.Delegate[Any, T](other) with ReadWriter[T]

def merge[T](rws: ReadWriter[_ <: T]*): TaggedReadWriter[T] = {
new TaggedReadWriter.Node(rws.asInstanceOf[Seq[TaggedReadWriter[T]]]:_*)
def merge[T](tagKey: String, rws: ReadWriter[_ <: T]*): TaggedReadWriter[T] = {
new TaggedReadWriter.Node(tagKey, rws.asInstanceOf[Seq[TaggedReadWriter[T]]]:_*)
}

implicit def join[T](implicit r0: Reader[T], w0: Writer[T]): ReadWriter[T] = (r0, w0) match{
Expand All @@ -45,6 +45,7 @@ trait Types{ types =>

case (r1: TaggedReader[T], w1: TaggedWriter[T]) =>
new TaggedReadWriter[T] {
private[upickle] def tagKey = r1.tagKey
override def isJsonDictKey = w0.isJsonDictKey
def findReader(s: String) = r1.findReader(s)
def findWriter(v: Any) = w1.findWriter(v)
Expand Down Expand Up @@ -104,8 +105,8 @@ trait Types{ types =>

override def visitArray(length: Int, index: Int) = super.visitArray(length, index).asInstanceOf[ArrVisitor[Any, Z]]
}
def merge[T](readers0: Reader[_ <: T]*) = {
new TaggedReader.Node(readers0.asInstanceOf[Seq[TaggedReader[T]]]:_*)
def merge[T](tagKey: String, readers0: Reader[_ <: T]*) = {
new TaggedReader.Node(tagKey, readers0.asInstanceOf[Seq[TaggedReader[T]]]:_*)
}
}

Expand Down Expand Up @@ -147,6 +148,8 @@ trait Types{ types =>
}

trait TaggedReader[T] extends SimpleReader[T]{
private[upickle] def tagKey: String

def findReader(s: String): Reader[T]

override def expectedMsg = taggedExpectedMsg
Expand All @@ -160,34 +163,34 @@ trait Types{ types =>
}
}
object TaggedReader{
class Leaf[T](tag: String, r: Reader[T]) extends TaggedReader[T]{
def findReader(s: String) = if (s == tag) r else null
class Leaf[T](private[upickle] val tagKey: String, tagValue: String, r: Reader[T]) extends TaggedReader[T]{
def findReader(s: String) = if (s == tagValue) r else null
}
class Node[T](rs: TaggedReader[_ <: T]*) extends TaggedReader[T]{
class Node[T](private[upickle] val tagKey: String, rs: TaggedReader[_ <: T]*) extends TaggedReader[T]{
def findReader(s: String) = scanChildren(rs)(_.findReader(s)).asInstanceOf[Reader[T]]
}
}

trait TaggedWriter[T] extends Writer[T]{
def findWriter(v: Any): (String, ObjectWriter[T])
def findWriter(v: Any): (String, String, ObjectWriter[T])
def write0[R](out: Visitor[_, R], v: T): R = {
val (tag, w) = findWriter(v)
taggedWrite(w, tag, out, v)
val (tagKey, tagValue, w) = findWriter(v)
taggedWrite(w, tagKey, tagValue, out, v)

}
}
object TaggedWriter{
class Leaf[T](checker: Annotator.Checker, tag: String, r: ObjectWriter[T]) extends TaggedWriter[T]{
class Leaf[T](checker: Annotator.Checker, tagKey: String, tagValue: String, r: ObjectWriter[T]) extends TaggedWriter[T]{
def findWriter(v: Any) = {
checker match{
case Annotator.Checker.Cls(c) if c.isInstance(v) => tag -> r
case Annotator.Checker.Val(v0) if v0 == v => tag -> r
case Annotator.Checker.Cls(c) if c.isInstance(v) => (tagKey, tagValue, r)
case Annotator.Checker.Val(v0) if v0 == v => (tagKey, tagValue, r)
case _ => null
}
}
}
class Node[T](rs: TaggedWriter[_ <: T]*) extends TaggedWriter[T]{
def findWriter(v: Any) = scanChildren(rs)(_.findWriter(v)).asInstanceOf[(String, ObjectWriter[T])]
def findWriter(v: Any) = scanChildren(rs)(_.findWriter(v)).asInstanceOf[(String, String, ObjectWriter[T])]
}
}

Expand All @@ -197,16 +200,16 @@ trait Types{ types =>

}
object TaggedReadWriter{
class Leaf[T](c: ClassTag[_], tag: String, r: ObjectWriter[T] with Reader[T]) extends TaggedReadWriter[T]{
def findReader(s: String) = if (s == tag) r else null
class Leaf[T](c: ClassTag[_], private[upickle] val tagKey: String, tagValue: String, r: ObjectWriter[T] with Reader[T]) extends TaggedReadWriter[T]{
def findReader(s: String) = if (s == tagValue) r else null
def findWriter(v: Any) = {
if (c.runtimeClass.isInstance(v)) (tag -> r)
if (c.runtimeClass.isInstance(v)) (tagKey, tagValue, r)
else null
}
}
class Node[T](rs: TaggedReadWriter[_ <: T]*) extends TaggedReadWriter[T]{
class Node[T](private[upickle] val tagKey: String, rs: TaggedReadWriter[_ <: T]*) extends TaggedReadWriter[T]{
def findReader(s: String) = scanChildren(rs)(_.findReader(s)).asInstanceOf[Reader[T]]
def findWriter(v: Any) = scanChildren(rs)(_.findWriter(v)).asInstanceOf[(String, ObjectWriter[T])]
def findWriter(v: Any) = scanChildren(rs)(_.findWriter(v)).asInstanceOf[(String, String, ObjectWriter[T])]
}
}

Expand All @@ -216,7 +219,7 @@ trait Types{ types =>

def taggedObjectContext[T](taggedReader: TaggedReader[T], index: Int): ObjVisitor[Any, T] = throw new Abort(taggedExpectedMsg)

def taggedWrite[T, R](w: ObjectWriter[T], tag: String, out: Visitor[_, R], v: T): R
def taggedWrite[T, R](w: ObjectWriter[T], tagKey: String, tagValue: String, out: Visitor[_, R], v: T): R

private[this] def scanChildren[T, V](xs: Seq[T])(f: T => V) = {
var x: V = null.asInstanceOf[V]
Expand Down Expand Up @@ -249,12 +252,14 @@ class CurrentlyDeriving[T]
* for `.equals` equality during writes to determine which tag to use.
*/
trait Annotator { this: Types =>
def annotate[V](rw: Reader[V], n: String): TaggedReader[V]
def annotate[V](rw: ObjectWriter[V], n: String, checker: Annotator.Checker): TaggedWriter[V]
def annotate[V](rw: ObjectWriter[V], n: String)(implicit ct: ClassTag[V]): TaggedWriter[V] =
annotate(rw, n, Annotator.Checker.Cls(ct.runtimeClass))
def annotate[V](rw: Reader[V], key: String, value: String): TaggedReader[V]
def annotate[V](rw: ObjectWriter[V], key: String, value: String, checker: Annotator.Checker): TaggedWriter[V]
def annotate[V](rw: ObjectWriter[V], key: String, value: String)(implicit ct: ClassTag[V]): TaggedWriter[V] =
annotate(rw, key, value, Annotator.Checker.Cls(ct.runtimeClass))
}
object Annotator{
def defaultTagKey = "$type"

sealed trait Checker
object Checker{
case class Cls(c: Class[_]) extends Checker
Expand Down
18 changes: 10 additions & 8 deletions upickle/implicits/src-2/upickle/implicits/internal/Macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import scala.annotation.{nowarn, StaticAnnotation}
import scala.language.experimental.macros
import compat._
import acyclic.file
import upickle.core.Annotator
import upickle.implicits.key

import language.higherKinds
Expand Down Expand Up @@ -102,7 +103,7 @@ object Macros {
annotate(tpe)(wrapObject(mod2))

}
def mergeTrait(subtrees: Seq[Tree], subtypes: Seq[Type], targetType: c.Type): Tree
def mergeTrait(tagKey: String, subtrees: Seq[Tree], subtypes: Seq[Type], targetType: c.Type): Tree

def derive(tpe: c.Type) = {
if (tpe.typeSymbol.asClass.isTrait || (tpe.typeSymbol.asClass.isAbstractClass && !tpe.typeSymbol.isJava)) {
Expand All @@ -125,11 +126,12 @@ object Macros {
"https://com-lihaoyi.github.io/upickle/#ManualSealedTraitPicklers"
fail(tpe, msg)
}else{
val tagKey = customKey(clsSymbol).getOrElse(Annotator.defaultTagKey)
val subTypes = fleshedOutSubtypes(tpe).toSeq.sortBy(_.typeSymbol.fullName)
// println("deriveTrait")
val subDerives = subTypes.map(subCls => q"implicitly[${typeclassFor(subCls)}]")
// println(Console.GREEN + "subDerives " + Console.RESET + subDrivess)
val merged = mergeTrait(subDerives, subTypes, tpe)
val merged = mergeTrait(tagKey, subDerives, subTypes, tpe)
merged
}
}
Expand Down Expand Up @@ -205,10 +207,10 @@ object Macros {
def annotate(tpe: c.Type)(derived: c.universe.Tree) = {
val sealedParent = tpe.baseClasses.find(_.asClass.isSealed)
sealedParent.fold(derived) { parent =>
val tagKey = customKey(parent).getOrElse(Annotator.defaultTagKey)
val tagValue = customKey(tpe.typeSymbol).getOrElse(TypeName(tpe.typeSymbol.fullName).decodedName.toString)

val index = customKey(tpe.typeSymbol).getOrElse(TypeName(tpe.typeSymbol.fullName).decodedName.toString)

q"${c.prefix}.annotate($derived, $index)"
q"${c.prefix}.annotate($derived, $tagKey, $tagValue)"
}
}

Expand Down Expand Up @@ -329,8 +331,8 @@ object Macros {
}
"""
}
def mergeTrait(subtrees: Seq[Tree], subtypes: Seq[Type], targetType: c.Type): Tree = {
q"${c.prefix}.Reader.merge[$targetType](..$subtrees)"
def mergeTrait(tagKey: String, subtrees: Seq[Tree], subtypes: Seq[Type], targetType: c.Type): Tree = {
q"${c.prefix}.Reader.merge[$targetType]($tagKey, ..$subtrees)"
}
}

Expand Down Expand Up @@ -401,7 +403,7 @@ object Macros {
}
"""
}
def mergeTrait(subtree: Seq[Tree], subtypes: Seq[Type], targetType: c.Type): Tree = {
def mergeTrait(tagKey: String, subtree: Seq[Tree], subtypes: Seq[Type], targetType: c.Type): Tree = {
q"${c.prefix}.Writer.merge[$targetType](..$subtree)"
}
}
Expand Down
9 changes: 5 additions & 4 deletions upickle/implicits/src-3/upickle/implicits/Readers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ trait ReadersVersionSpecific
}

inline if macros.isSingleton[T] then
annotate[T](SingletonReader[T](macros.getSingleton[T]), macros.tagName[T])
annotate[T](SingletonReader[T](macros.getSingleton[T]), macros.tagKey[T], macros.tagName[T])
else if macros.isMemberOfSealedHierarchy[T] then
annotate[T](reader, macros.tagName[T])
annotate[T](reader, macros.tagKey[T], macros.tagName[T])
else reader

case m: Mirror.SumOf[T] =>
Expand All @@ -87,7 +87,7 @@ trait ReadersVersionSpecific
.toList
.asInstanceOf[List[Reader[_ <: T]]]

Reader.merge[T](readers: _*)
Reader.merge[T](macros.tagKey[T], readers: _*)
}

inline def macroRAll[T](using m: Mirror.Of[T]): Reader[T] = inline m match {
Expand All @@ -99,8 +99,9 @@ trait ReadersVersionSpecific
inline given superTypeReader[T: Mirror.ProductOf, V >: T : Reader : Mirror.SumOf]
(using NotGiven[CurrentlyDeriving[V]]): Reader[T] = {
val actual = implicitly[Reader[V]].asInstanceOf[TaggedReader[T]]
val tagKey = macros.tagKey[T]
val tagName = macros.tagName[T]
new TaggedReader.Leaf(tagName, actual.findReader(tagName))
new TaggedReader.Leaf(tagKey, tagName, actual.findReader(tagName))
}

// see comment in MacroImplicits as to why Dotty's extension methods aren't used here
Expand Down
14 changes: 12 additions & 2 deletions upickle/implicits/src-3/upickle/implicits/Writers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,19 @@ trait WritersVersionSpecific
}

inline if macros.isSingleton[T] then
annotate[T](SingletonWriter[T](null.asInstanceOf[T]), macros.tagName[T], Annotator.Checker.Val(macros.getSingleton[T]))
annotate[T](
SingletonWriter[T](null.asInstanceOf[T]),
macros.tagKey[T],
macros.tagName[T],
Annotator.Checker.Val(macros.getSingleton[T]),
)
else if macros.isMemberOfSealedHierarchy[T] then
annotate[T](writer, macros.tagName[T], Annotator.Checker.Cls(implicitly[ClassTag[T]].runtimeClass))
annotate[T](
writer,
macros.tagKey[T],
macros.tagName[T],
Annotator.Checker.Cls(implicitly[ClassTag[T]].runtimeClass),
)
else writer

case _: Mirror.SumOf[T] =>
Expand Down
16 changes: 13 additions & 3 deletions upickle/implicits/src-3/upickle/implicits/macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,23 @@ def writeSnippetsImpl[R, T, WS <: Tuple](thisOuter: Expr[upickle.core.Types with
'{()}
)

inline def isMemberOfSealedHierarchy[T]: Boolean = ${ isMemberOfSealedHierarchyImpl[T] }
def isMemberOfSealedHierarchyImpl[T](using Quotes, Type[T]): Expr[Boolean] =
private def sealedHierarchyParent[T](using Quotes, Type[T]): Option[quotes.reflect.Symbol] =
import quotes.reflect._

val parents = TypeRepr.of[T].baseClasses

Expr(parents.exists { p => p.flags.is(Flags.Sealed) })
// TODO - what if there are multiple?
parents.find(_.flags.is(Flags.Sealed))

inline def isMemberOfSealedHierarchy[T]: Boolean = ${ isMemberOfSealedHierarchyImpl[T] }
def isMemberOfSealedHierarchyImpl[T](using Quotes, Type[T]): Expr[Boolean] =
Expr(sealedHierarchyParent[T].isDefined)

inline def tagKey[T]: String = ${ tagKeyImpl[T] }
def tagKeyImpl[T](using Quotes, Type[T]): Expr[String] =
import quotes.reflect._

Expr(sealedHierarchyParent[T].flatMap(extractKey).getOrElse(upickle.core.Annotator.defaultTagKey))

inline def tagName[T]: String = ${ tagNameImpl[T] }
def tagNameImpl[T](using Quotes, Type[T]): Expr[String] =
Expand Down
41 changes: 20 additions & 21 deletions upickle/src/upickle/Api.scala
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,10 @@ object default extends AttributeTagged{
*/
object legacy extends LegacyApi
trait LegacyApi extends Api with Annotator{
def annotate[V](rw: Reader[V], n: String) = new TaggedReader.Leaf[V](n, rw)
def annotate[V](rw: Reader[V], key: String, value: String) = new TaggedReader.Leaf[V](key, value, rw)

def annotate[V](rw: ObjectWriter[V], n: String, checker: Annotator.Checker): TaggedWriter[V] = {
new TaggedWriter.Leaf[V](checker, n, rw)
def annotate[V](rw: ObjectWriter[V], key: String, value: String, checker: Annotator.Checker): TaggedWriter[V] = {
new TaggedWriter.Leaf[V](checker, key, value, rw)
}

def taggedExpectedMsg = "expected sequence"
Expand Down Expand Up @@ -287,9 +287,9 @@ trait LegacyApi extends Api with Annotator{
}

}
def taggedWrite[T, R](w: ObjectWriter[T], tag: String, out: Visitor[_, R], v: T): R = {
def taggedWrite[T, R](w: ObjectWriter[T], tagKey: String, tagValue: String, out: Visitor[_, R], v: T): R = {
val ctx = out.asInstanceOf[Visitor[Any, R]].visitArray(2, -1)
ctx.visitValue(ctx.subVisitor.visitString(objectTypeKeyWriteMap(tag), -1), -1)
ctx.visitValue(ctx.subVisitor.visitString(objectTypeKeyWriteMap(tagValue), -1), -1)

ctx.visitValue(w.write(ctx.subVisitor, v), -1)

Expand All @@ -303,19 +303,18 @@ trait LegacyApi extends Api with Annotator{
* of the attribute is.
*/
trait AttributeTagged extends Api with Annotator{
def tagName = "$type"
def annotate[V](rw: Reader[V], n: String) = {
new TaggedReader.Leaf[V](n, rw)
def annotate[V](rw: Reader[V], key: String, value: String) = {
new TaggedReader.Leaf[V](key, value, rw)
}

def annotate[V](rw: ObjectWriter[V], n: String, checker: Annotator.Checker): TaggedWriter[V] = {
new TaggedWriter.Leaf[V](checker, n, rw)
def annotate[V](rw: ObjectWriter[V], key: String, value: String, checker: Annotator.Checker): TaggedWriter[V] = {
new TaggedWriter.Leaf[V](checker, key, value, rw)
}

def taggedExpectedMsg = "expected dictionary"
private def isTagName(i: Any) = i match{
case s: BufferedValue.Str => s.value0.toString == tagName
case s: CharSequence => s.toString == tagName
private def isTagName(tagKey: String, i: Any) = i match{
case s: BufferedValue.Str => s.value0.toString == tagKey
case s: CharSequence => s.toString == tagKey
case _ => false
}
override def taggedObjectContext[T](taggedReader: TaggedReader[T], index: Int) = {
Expand All @@ -333,7 +332,7 @@ trait AttributeTagged extends Api with Annotator{
def visitKeyValue(s: Any): Unit = {
if (context != null) context.visitKeyValue(s)
else {
if (isTagName(s)) () //do nothing
if (isTagName(taggedReader.tagKey, s)) () //do nothing
else {
// otherwise, go slow path
val slowCtx = BufferedValue.Builder.visitObject(-1, true, index).narrow
Expand All @@ -359,12 +358,12 @@ trait AttributeTagged extends Api with Annotator{
}
}
def visitEnd(index: Int) = {
def missingKeyMsg = s"""Missing key "$tagName" for tagged dictionary"""
def missingKeyMsg = s"""Missing key "${taggedReader.tagKey}" for tagged dictionary"""
if (context == null) throw new Abort(missingKeyMsg)
else if (fastPath) context.visitEnd(index).asInstanceOf[T]
else{
val x = context.visitEnd(index).asInstanceOf[BufferedValue.Obj]
val keyAttr = x.value0.find(t => isTagName(t._1))
val keyAttr = x.value0.find(t => isTagName(taggedReader.tagKey, t._1))
.getOrElse(throw new Abort(missingKeyMsg))
._2
val key = keyAttr.asInstanceOf[BufferedValue.Str].value0.toString
Expand All @@ -376,7 +375,7 @@ trait AttributeTagged extends Api with Annotator{
for (p <- x.value0) {
val (k0, v) = p
val k = k0
if (!isTagName(k)){
if (!isTagName(taggedReader.tagKey, k)){
val keyVisitor = ctx2.visitKey(-1)

ctx2.visitKeyValue(BufferedValue.transform(k, keyVisitor))
Expand All @@ -389,15 +388,15 @@ trait AttributeTagged extends Api with Annotator{

}
}
def taggedWrite[T, R](w: ObjectWriter[T], tag: String, out: Visitor[_, R], v: T): R = {
def taggedWrite[T, R](w: ObjectWriter[T], tagKey: String, tagValue: String, out: Visitor[_, R], v: T): R = {

if (w.isInstanceOf[SingletonWriter[_]]) out.visitString(tag, -1)
if (w.isInstanceOf[SingletonWriter[_]]) out.visitString(tagValue, -1)
else {
val ctx = out.asInstanceOf[Visitor[Any, R]].visitObject(w.length(v) + 1, true, -1)
val keyVisitor = ctx.visitKey(-1)

ctx.visitKeyValue(keyVisitor.visitString(tagName, -1))
ctx.visitValue(ctx.subVisitor.visitString(objectTypeKeyWriteMap(tag), -1), -1)
ctx.visitKeyValue(keyVisitor.visitString(tagKey, -1))
ctx.visitValue(ctx.subVisitor.visitString(objectTypeKeyWriteMap(tagValue), -1), -1)
w.writeToObject(ctx, v)
val res = ctx.visitEnd(-1)
res
Expand Down
Loading

0 comments on commit b472fc9

Please sign in to comment.