Skip to content

Commit

Permalink
Derive description from scaladoc for Scala 3 (zio#646)
Browse files Browse the repository at this point in the history
  • Loading branch information
987Nabil committed Jan 30, 2024
1 parent df4bbc0 commit 244546f
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ private case class DeriveSchema()(using val ctx: Quotes) {
val reflectionUtils = ReflectionUtils(ctx)
import reflectionUtils.{MirrorType, Mirror, summonOptional}
import ctx.reflect._

case class Frame(ref: Term, tpe: TypeRepr)
case class Stack(frames: List[Frame]) {
def find(tpe: TypeRepr): Option[Term] = frames.find(_.tpe =:= tpe).map(_.ref)
Expand All @@ -33,7 +33,7 @@ private case class DeriveSchema()(using val ctx: Quotes) {

def size = frames.size

override def toString =
override def toString =
frames.map(f => s"${f.ref.show} : ${f.tpe.show}").mkString("Stack(", ", ", ")")
}

Expand All @@ -52,7 +52,7 @@ private case class DeriveSchema()(using val ctx: Quotes) {
val result = stack.find(typeRepr) match {
case Some(ref) =>
'{ Schema.defer(${ref.asExprOf[Schema[T]]}) }
case None =>
case None =>
val summoned = if (!top) Expr.summon[Schema[T]] else None
if (!top && summoned.isDefined) {
'{
Expand Down Expand Up @@ -121,9 +121,13 @@ private case class DeriveSchema()(using val ctx: Quotes) {
val selfRefSymbol = Symbol.newVal(Symbol.spliceOwner, s"derivedSchema${stack.size}", TypeRepr.of[Schema[T]], Flags.Lazy, Symbol.noSymbol)
val selfRef = Ref(selfRefSymbol)

val docAnnotationExpr = TypeRepr.of[T].typeSymbol.docstring.map { docstring =>
val docstringExpr = Expr(docstring)
'{zio.schema.annotation.description(${docstringExpr})}
}
val typeInfo = '{TypeId.parse(${Expr(TypeRepr.of[T].show)})}
val annotationExprs = TypeRepr.of[T].typeSymbol.annotations.filter (filterAnnotation).map (_.asExpr)
val annotations = '{zio.Chunk.fromIterable (${Expr.ofSeq (annotationExprs)})}
val annotations = '{zio.Chunk.fromIterable (${Expr.ofSeq (annotationExprs)}) ++ zio.Chunk.fromIterable(${Expr.ofSeq(docAnnotationExpr.toList)}) }

val constructor = '{() => ${Ref(TypeRepr.of[T].typeSymbol.companionModule).asExprOf[T]}}
val ctor = typeRprOf[T](0).typeSymbol.companionModule
Expand Down Expand Up @@ -163,8 +167,12 @@ private case class DeriveSchema()(using val ctx: Quotes) {
val paramAnns = fromConstructor(TypeRepr.of[T].typeSymbol)
val constructor = caseClassConstructor[T](mirror).asExpr

val docAnnotationExpr = TypeRepr.of[T].typeSymbol.docstring.map { docstring =>
val docstringExpr = Expr(docstring)
'{zio.schema.annotation.description(${docstringExpr})}
}
val annotationExprs = TypeRepr.of[T].typeSymbol.annotations.filter(filterAnnotation).map(_.asExpr)
val annotations = '{ zio.Chunk.fromIterable(${Expr.ofSeq(annotationExprs)}) }
val annotations = '{ zio.Chunk.fromIterable(${Expr.ofSeq(annotationExprs)}) ++ zio.Chunk.fromIterable(${Expr.ofSeq(docAnnotationExpr.toList)}) }
val typeInfo = '{TypeId.parse(${Expr(TypeRepr.of[T].show)})}

val applied = if (labels.length <= 22) {
Expand All @@ -182,7 +190,7 @@ private case class DeriveSchema()(using val ctx: Quotes) {
)

val fieldsAndFieldTypes = typesAndLabels.map { case (tpe, label) => deriveField[T](tpe, label, paramAnns.getOrElse(label, List.empty), newStack) }
val (fields, fieldTypes) = fieldsAndFieldTypes.unzip
val (fields, fieldTypes) = fieldsAndFieldTypes.unzip
val args = List(typeInfo) ++ fields ++ Seq(constructor) ++ Seq(annotations)
val terms = Expr.ofTupleFromSeq(args)

Expand Down Expand Up @@ -271,7 +279,7 @@ private case class DeriveSchema()(using val ctx: Quotes) {
}


private def fromDeclarations(from: Symbol): List[(String, List[Expr[Any]])] =
private def fromDeclarations(from: Symbol): List[(String, List[Expr[Any]])] =
from.declaredFields.map {
field =>
field.name -> field.annotations.filter(filterAnnotation).map(_.asExpr)
Expand Down Expand Up @@ -332,12 +340,16 @@ private case class DeriveSchema()(using val ctx: Quotes) {
val isSimpleEnum: Boolean = !TypeRepr.of[T].typeSymbol.children.map(_.declaredFields.length).exists( _ > numParentFields )
val hasSimpleEnumAnn: Boolean = TypeRepr.of[T].typeSymbol.hasAnnotation(TypeRepr.of[_root_.zio.schema.annotation.simpleEnum].typeSymbol)

val docAnnotationExpr = TypeRepr.of[T].typeSymbol.docstring.map { docstring =>
val docstringExpr = Expr(docstring)
'{zio.schema.annotation.description(${docstringExpr})}
}
val annotationExprs = (isSimpleEnum, hasSimpleEnumAnn) match {
case (true, false) => TypeRepr.of[T].typeSymbol.annotations.filter(filterAnnotation).map(_.asExpr).+:('{zio.schema.annotation.simpleEnum(true)})
case (false, true) => throw new Exception(s"${TypeRepr.of[T].typeSymbol.name} must be a simple Enum")
case _ => TypeRepr.of[T].typeSymbol.annotations.filter(filterAnnotation).map(_.asExpr)
}
val annotations = '{ zio.Chunk.fromIterable(${Expr.ofSeq(annotationExprs)}) }
val annotations = '{ zio.Chunk.fromIterable(${Expr.ofSeq(annotationExprs)}) ++ zio.Chunk.fromIterable(${Expr.ofSeq(docAnnotationExpr.toList)}) }

val typeInfo = '{TypeId.parse(${Expr(TypeRepr.of[T].show)})}

Expand All @@ -346,12 +358,12 @@ private case class DeriveSchema()(using val ctx: Quotes) {
val terms = Expr.ofTupleFromSeq(args)
val ctor = TypeRepr.of[Enum2[_, _, _]].typeSymbol.primaryConstructor

val typeArgs =
val typeArgs =
(types.appended(TypeRepr.of[T])).map { tpe =>
tpe.asType match
case '[tt] => TypeTree.of[tt]
}

val typeTree = enumTypeTree[T](labels.length)

Apply(
Expand All @@ -371,7 +383,7 @@ private case class DeriveSchema()(using val ctx: Quotes) {
case '{ type tt <: Schema[T]; $ex : `tt` } =>
'{
${Block(
List(lazyValDef),
List(lazyValDef),
selfRef
).asExpr}.asInstanceOf[tt]
}
Expand Down Expand Up @@ -416,16 +428,16 @@ private case class DeriveSchema()(using val ctx: Quotes) {

if (anns.nonEmpty) {
val (newName, newNameValue) = anns.collectFirst {
case ann if ann.isExprOf[fieldName] =>
case ann if ann.isExprOf[fieldName] =>
val fieldNameAnn = ann.asExprOf[fieldName]
('{${fieldNameAnn}.name}, extractFieldNameValue(fieldNameAnn))
}.getOrElse((Expr(name), name))

val f = '{ Field($newName, $schema, $chunk, $validator, $get, $set)}
addFieldName(newNameValue)(f) // TODO: we need to pass the evaluated annotation value instead of name
} else {
val f = '{ Field(${Expr(name)}, $schema, $chunk, $validator, $get, $set) }
addFieldName(name)(f)
val f = '{ Field(${Expr(name)}, $schema, $chunk, $validator, $get, $set) }
addFieldName(name)(f)
}
}
}
Expand Down Expand Up @@ -484,21 +496,25 @@ private case class DeriveSchema()(using val ctx: Quotes) {
val r = TypeRepr.of[R]
val t = TypeRepr.of[T]
val nameT = ConstantType(StringConstant(name))
val fieldWithName = withFieldName.appliedTo(List(r, nameT, t))
val fieldWithName = withFieldName.appliedTo(List(r, nameT, t))
(Select.unique(f.asTerm, "asInstanceOf").appliedToType(fieldWithName).asExprOf[F], nameT)
}


// sealed case class Case[A, Z](id: String, codec: Schema[A], unsafeDeconstruct: Z => A, annotations: Chunk[Any] = Chunk.empty) {
def deriveCase[T: Type](repr: TypeRepr, label: String, stack: Stack)(using Quotes) = {
repr.asType match { case '[t] =>
repr.asType match { case '[t] =>
val schema = deriveSchema[t](stack)
val stringExpr = Expr(label)

val docAnnotationExpr = TypeRepr.of[t].typeSymbol.docstring.map { docstring =>
val docstringExpr = Expr(docstring)
'{zio.schema.annotation.description(${docstringExpr})}
}
val annotationExprs = TypeRepr.of[t].typeSymbol.annotations.filter(filterAnnotation).map(_.asExpr)
val annotations = '{ zio.Chunk.fromIterable(${Expr.ofSeq(annotationExprs)}) }
val annotations = '{ zio.Chunk.fromIterable(${Expr.ofSeq(annotationExprs)}) ++ zio.Chunk.fromIterable(${Expr.ofSeq(docAnnotationExpr.toList)}) }

val unsafeDeconstruct = '{
val unsafeDeconstruct = '{
(z: T) => z.asInstanceOf[t]
}
val construct = '{
Expand All @@ -525,14 +541,14 @@ private case class DeriveSchema()(using val ctx: Quotes) {
a.tpe.typeSymbol.maybeOwner.isNoSymbol ||
a.tpe.typeSymbol.owner.fullName != "scala.annotation.internal"

def extractFieldNameValue(attribute: Expr[fieldName]): String =
def extractFieldNameValue(attribute: Expr[fieldName]): String =
attribute.asTerm match {
// Apply(Select(New(Ident(fieldName)),<init>),List(Literal(Constant(renamed))))
case Apply(_, List(Literal(StringConstant(name)))) =>
name
}
}

def caseClassTypeTree[T: Type](arity: Int): TypeTree =
def caseClassTypeTree[T: Type](arity: Int): TypeTree =
arity match {
case 0 => TypeTree.of[CaseClass0[T]]
case 1 => TypeTree.of[CaseClass1[_, T]]
Expand All @@ -559,7 +575,7 @@ private case class DeriveSchema()(using val ctx: Quotes) {
case 22 => TypeTree.of[CaseClass22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]]
}

def typeRprOf[T: Type](arity: Int): TypeRepr =
def typeRprOf[T: Type](arity: Int): TypeRepr =
arity match {
case 0 => TypeRepr.of[CaseClass0[T]]
case 1 => TypeRepr.of[CaseClass1[_, T]]
Expand All @@ -586,7 +602,7 @@ private case class DeriveSchema()(using val ctx: Quotes) {
case 22 => TypeRepr.of[CaseClass22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]]
}

def caseClassWithFieldsType(arity: Int): TypeRepr =
def caseClassWithFieldsType(arity: Int): TypeRepr =
arity match {
case 1 => TypeRepr.of[CaseClass1.WithFields]
case 2 => TypeRepr.of[CaseClass2.WithFields]
Expand All @@ -612,7 +628,7 @@ private case class DeriveSchema()(using val ctx: Quotes) {
case 22 => TypeRepr.of[CaseClass22.WithFields]
}

def enumTypeTree[T: Type](arity: Int): TypeTree =
def enumTypeTree[T: Type](arity: Int): TypeTree =
arity match {
case 0 => TypeTree.of[CaseClass0[T]]
case 1 => TypeTree.of[Enum1[_, T]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ package zio.schema

import zio.Chunk
import zio.test.*
import zio.schema.annotation.simpleEnum
import zio.schema.annotation.*

trait VersionSpecificDeriveSchemaSpec extends ZIOSpecDefault {
/** ObjectWithDoc doc */
object ObjectWithDoc

case class ContainerFields(field1: Option[String])

object ContainerFields {
Expand All @@ -22,14 +25,26 @@ trait VersionSpecificDeriveSchemaSpec extends ZIOSpecDefault {

final case class AutoDerives(i: Int) derives Schema

/** AutoDerives scaladoc */
final case class AutoDerivesWithDoc(i: Int) derives Schema

enum Colour(val rgb: Int) {
case Red extends Colour(0xff0000)
case Green extends Colour(0x00ff00)
case Blue extends Colour(0x0000ff)
}

/** Colour scaladoc */
@caseName("Red")
enum ColourWithDoc(val rgb: Int) {
/** Red scaladoc */
case Red extends ColourWithDoc(0xff0000)
case Green extends ColourWithDoc(0x00ff00)
case Blue extends ColourWithDoc(0x0000ff)
}

def versionSpecificSuite = Spec.labeled(
"Scala 3 specific tests",
"Scala 3 specific tests",
suite("Derivation")(
test("correctly derives case class with `derives` keyword") {
val expected: Schema[AutoDerives] = Schema.CaseClass1(
Expand All @@ -47,7 +62,19 @@ trait VersionSpecificDeriveSchemaSpec extends ZIOSpecDefault {
test("correctly assigns simpleEnum to enum") {
val derived: Schema[Colour] = DeriveSchema.gen[Colour]
assertTrue(derived.annotations == Chunk(simpleEnum(true)))
}
},
test("correctly adds scaladoc as description"){
val colourWithDoc: Schema[ColourWithDoc] = DeriveSchema.gen[ColourWithDoc]
val autoDerivesWithDoc: Schema[AutoDerivesWithDoc] = Schema[AutoDerivesWithDoc]
val objectWithDoc: Schema[ObjectWithDoc.type] = DeriveSchema.gen[ObjectWithDoc.type]
val redAnnotations = colourWithDoc.asInstanceOf[Schema.Enum[ColourWithDoc]].cases.find(_.id == "Red").get.schema.annotations.find(_.isInstanceOf[description])
assertTrue(
colourWithDoc.annotations.find(_.isInstanceOf[description]) == Some(description("/** Colour scaladoc */")),
//redAnnotations == Some(description("/** Red scaladoc */")), fix #651 to make this work
autoDerivesWithDoc.annotations.find(_.isInstanceOf[description]) == Some(description("/** AutoDerives scaladoc */")),
objectWithDoc.annotations.find(_.isInstanceOf[description]) == Some(description("/** ObjectWithDoc doc */")),
)
},
)
)
}

0 comments on commit 244546f

Please sign in to comment.