diff --git a/Cargo.toml b/Cargo.toml index 59281163..626563ac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,6 +38,7 @@ serde_yaml = { version = "0.8", optional = true } utoipa-gen = { version = "1.0.2", path = "./utoipa-gen" } [dev-dependencies] +assert-json-diff = "2" actix-web = { version = "4" } paste = "1" chrono = { version = "0.4", features = ["serde"] } diff --git a/tests/component_derive_test.rs b/tests/component_derive_test.rs index da3c84f7..e7b0ee63 100644 --- a/tests/component_derive_test.rs +++ b/tests/component_derive_test.rs @@ -1,11 +1,12 @@ #![cfg(feature = "serde_json")] use std::{borrow::Cow, cell::RefCell, collections::HashMap, marker::PhantomData, vec}; +use assert_json_diff::assert_json_eq; #[cfg(any(feature = "chrono", feature = "chrono_with_format"))] use chrono::{Date, DateTime, Duration, Utc}; use serde::Serialize; -use serde_json::Value; +use serde_json::{json, Value}; use utoipa::{Component, OpenApi}; use crate::common::get_json_path; @@ -507,9 +508,100 @@ fn derive_with_box_and_refcell() { } #[test] -fn derive_complex_enum_with_named_and_unnamed_fields() { - struct Foo; - let complex_enum = api_doc! { +fn derive_simple_enum() { + let value: Value = api_doc! { + #[derive(Serialize)] + enum Bar { + A, + B, + C, + } + }; + + assert_json_eq!( + value, + json!({ + "enum": [ + "A", + "B", + "C", + ], + "type": "string", + }) + ); +} + +#[test] +fn derive_simple_enum_serde_tag() { + let value: Value = api_doc! { + #[derive(Serialize)] + #[serde(tag = "tag")] + enum Bar { + A, + B, + C, + } + }; + + assert_json_eq!( + value, + json!({ + "oneOf": [ + { + "type": "object", + "properties": { + "tag": { + "type": "string", + "enum": [ + "A", + ], + }, + }, + "required": [ + "tag", + ], + }, + { + "type": "object", + "properties": { + "tag": { + "type": "string", + "enum": [ + "B", + ], + }, + }, + "required": [ + "tag", + ], + }, + { + "type": "object", + "properties": { + "tag": { + "type": "string", + "enum": [ + "C", + ], + }, + }, + "required": [ + "tag", + ], + }, + ], + }) + ); +} + +/// Derive a complex enum with named and unnamed fields. +#[test] +fn derive_complex_enum() { + #[derive(Serialize)] + struct Foo(String); + + let value: Value = api_doc! { + #[derive(Serialize)] enum Bar { UnitValue, NamedFields { @@ -520,17 +612,245 @@ fn derive_complex_enum_with_named_and_unnamed_fields() { } }; - common::assert_json_array_len(complex_enum.get("oneOf").unwrap(), 3); - assert_value! {complex_enum=> - "oneOf.[0].type" = r###""string""###, "Complex enum unit value type" - "oneOf.[0].enum" = r###"["UnitValue"]"###, "Complex enum unit value enum" - "oneOf.[1].type" = r###""object""###, "Complex enum named fields type" - "oneOf.[1].properties.NamedFields.type" = r###""object""###, "Complex enum named fields object type" - "oneOf.[1].properties.NamedFields.properties.id.type" = r###""string""###, "Complex enum named fields id type" - "oneOf.[1].properties.NamedFields.properties.names.type" = r###""array""###, "Complex enum named fields names type" - "oneOf.[2].type" = r###""object""###, "Complex enum unnamed fields type" - "oneOf.[2].properties.UnnamedFields.$ref" = r###""#/components/schemas/Foo""###, "Complex enum unnamed fields type" - } + assert_json_eq!( + value, + json!({ + "oneOf": [ + { + "type": "string", + "enum": [ + "UnitValue", + ], + }, + { + "type": "object", + "properties": { + "NamedFields": { + "type": "object", + "properties": { + "id": { + "type": "string", + }, + "names": { + "type": "array", + "items": { + "type": "string", + }, + }, + }, + "required": [ + "id", + ], + }, + }, + }, + { + "type": "object", + "properties": { + "UnnamedFields": { + "$ref": "#/components/schemas/Foo", + }, + }, + }, + ], + }) + ); +} + +#[test] +fn derive_complex_enum_serde_rename_all() { + #[derive(Serialize)] + struct Foo(String); + + let value: Value = api_doc! { + #[derive(Serialize)] + #[serde(rename_all = "snake_case")] + enum Bar { + UnitValue, + NamedFields { + id: &'static str, + names: Option> + }, + UnnamedFields(Foo), + } + }; + + assert_json_eq!( + value, + json!({ + "oneOf": [ + { + "type": "string", + "enum": [ + "unit_value", + ], + }, + { + "type": "object", + "properties": { + "named_fields": { + "type": "object", + "properties": { + "id": { + "type": "string", + }, + "names": { + "type": "array", + "items": { + "type": "string", + }, + }, + }, + "required": [ + "id", + ], + }, + }, + }, + { + "type": "object", + "properties": { + "unnamed_fields": { + "$ref": "#/components/schemas/Foo", + }, + }, + }, + ], + }) + ); +} + +#[test] +fn derive_complex_enum_serde_rename_variant() { + #[derive(Serialize)] + struct Foo(String); + + let value: Value = api_doc! { + #[derive(Serialize)] + enum Bar { + #[serde(rename = "renamed_unit_value")] + UnitValue, + #[serde(rename = "renamed_named_fields")] + NamedFields { + #[serde(rename = "renamed_id")] + id: &'static str, + #[serde(rename = "renamed_names")] + names: Option> + }, + #[serde(rename = "renamed_unnamed_fields")] + UnnamedFields(Foo), + } + }; + + assert_json_eq!( + value, + json!({ + "oneOf": [ + { + "type": "string", + "enum": [ + "renamed_unit_value", + ], + }, + { + "type": "object", + "properties": { + "renamed_named_fields": { + "type": "object", + "properties": { + "renamed_id": { + "type": "string", + }, + "renamed_names": { + "type": "array", + "items": { + "type": "string", + }, + }, + }, + "required": [ + "renamed_id", + ], + }, + }, + }, + { + "type": "object", + "properties": { + "renamed_unnamed_fields": { + "$ref": "#/components/schemas/Foo", + }, + }, + }, + ], + }) + ); +} + +/// Derive a complex enum with the serde `tag` container attribute applied for internal tagging. +/// Note that tuple fields are not supported. +#[test] +fn derive_complex_enum_serde_tag() { + #[derive(Serialize)] + struct Foo(String); + + let value: Value = api_doc! { + #[derive(Serialize)] + #[serde(tag = "tag")] + enum Bar { + UnitValue, + NamedFields { + id: &'static str, + names: Option> + }, + } + }; + + assert_json_eq!( + value, + json!({ + "oneOf": [ + { + "type": "object", + "properties": { + "tag": { + "type": "string", + "enum": [ + "UnitValue", + ], + }, + }, + "required": [ + "tag", + ], + }, + { + "type": "object", + "properties": { + "id": { + "type": "string", + }, + "names": { + "type": "array", + "items": { + "type": "string", + }, + }, + "tag": { + "type": "string", + "enum": [ + "NamedFields", + ], + }, + }, + "required": [ + "id", + "tag", + ], + }, + ], + }) + ); } #[test] diff --git a/utoipa-gen/src/schema.rs b/utoipa-gen/src/schema.rs index 7d38d12e..31c1b12f 100644 --- a/utoipa-gen/src/schema.rs +++ b/utoipa-gen/src/schema.rs @@ -214,47 +214,43 @@ pub mod serde { use std::str::FromStr; - use proc_macro2::{Span, TokenTree}; + use proc_macro2::{Ident, Span, TokenTree}; use proc_macro_error::ResultExt; use syn::{buffer::Cursor, Attribute, Error}; - #[cfg_attr(feature = "debug", derive(Debug))] - pub enum Serde { - Container(SerdeContainer), - Value(SerdeValue), - } - - impl Serde { - #[inline] - fn parse_next_lit_str(next: Cursor) -> Option<(String, Span)> { - match next.token_tree() { - Some((tt, next)) => match tt { - TokenTree::Punct(punct) if punct.as_char() == '=' => { - Serde::parse_next_lit_str(next) - } - TokenTree::Literal(literal) => { - Some((literal.to_string().replace('\"', ""), literal.span())) - } - _ => None, - }, + #[inline] + fn parse_next_lit_str(next: Cursor) -> Option<(String, Span)> { + match next.token_tree() { + Some((tt, next)) => match tt { + TokenTree::Punct(punct) if punct.as_char() == '=' => parse_next_lit_str(next), + TokenTree::Literal(literal) => { + Some((literal.to_string().replace('\"', ""), literal.span())) + } _ => None, - } + }, + _ => None, } + } + + #[derive(Default)] + #[cfg_attr(feature = "debug", derive(Debug))] + pub struct SerdeValue { + pub skip: Option, + pub rename: Option, + } - fn parse_container(input: syn::parse::ParseStream) -> syn::Result { - let mut container = SerdeContainer::default(); + impl SerdeValue { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let mut value = Self::default(); input.step(|cursor| { let mut rest = *cursor; while let Some((tt, next)) = rest.token_tree() { match tt { - TokenTree::Ident(ident) if ident == "rename_all" => { - if let Some((literal, span)) = Serde::parse_next_lit_str(next) { - container.rename_all = Some( - literal - .parse::() - .map_err(|error| Error::new(span, error.to_string()))?, - ); + TokenTree::Ident(ident) if ident == "skip" => value.skip = Some(true), + TokenTree::Ident(ident) if ident == "rename" => { + if let Some((literal, _)) = parse_next_lit_str(next) { + value.rename = Some(literal) }; } _ => (), @@ -265,22 +261,51 @@ pub mod serde { Ok(((), rest)) })?; - Ok(Serde::Container(container)) + Ok(value) } + } - fn parse_value(input: syn::parse::ParseStream) -> syn::Result { - let mut value = SerdeValue::default(); + /// Attributes defined within a `#[serde(...)]` container attribute. + #[derive(Default)] + #[cfg_attr(feature = "debug", derive(Debug))] + pub struct SerdeContainer { + pub rename_all: Option, + pub tag: Option, + } + + impl SerdeContainer { + /// Parse a single serde attribute, currently `rename_all = ...` and `tag = ...` attributes + /// are supported. + fn parse_attribute(&mut self, ident: Ident, next: Cursor) -> syn::Result<()> { + match ident.to_string().as_str() { + "rename_all" => { + if let Some((literal, span)) = parse_next_lit_str(next) { + self.rename_all = Some( + literal + .parse::() + .map_err(|error| Error::new(span, error.to_string()))?, + ); + }; + } + "tag" => { + if let Some((literal, _span)) = parse_next_lit_str(next) { + self.tag = Some(literal) + } + } + _ => {} + } + Ok(()) + } + + /// Parse the attributes inside a `#[serde(...)]` container attribute. + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let mut container = Self::default(); input.step(|cursor| { let mut rest = *cursor; while let Some((tt, next)) = rest.token_tree() { match tt { - TokenTree::Ident(ident) if ident == "skip" => value.skip = Some(true), - TokenTree::Ident(ident) if ident == "rename" => { - if let Some((literal, _)) = Serde::parse_next_lit_str(next) { - value.rename = Some(literal) - }; - } + TokenTree::Ident(ident) => container.parse_attribute(ident, next)?, _ => (), } @@ -289,41 +314,28 @@ pub mod serde { Ok(((), rest)) })?; - Ok(Serde::Value(value)) + Ok(container) } } - #[derive(Default)] - #[cfg_attr(feature = "debug", derive(Debug))] - pub struct SerdeValue { - pub skip: Option, - pub rename: Option, - } - - #[derive(Default)] - #[cfg_attr(feature = "debug", derive(Debug))] - pub struct SerdeContainer { - pub rename_all: Option, - } - - pub fn parse_value(attributes: &[Attribute]) -> Option { + pub fn parse_value(attributes: &[Attribute]) -> Option { attributes .iter() .find(|attribute| attribute.path.is_ident("serde")) .map(|serde_attribute| { serde_attribute - .parse_args_with(Serde::parse_value) + .parse_args_with(SerdeValue::parse) .unwrap_or_abort() }) } - pub fn parse_container(attributes: &[Attribute]) -> Option { + pub fn parse_container(attributes: &[Attribute]) -> Option { attributes .iter() .find(|attribute| attribute.path.is_ident("serde")) .map(|serde_attribute| { serde_attribute - .parse_args_with(Serde::parse_container) + .parse_args_with(SerdeContainer::parse) .unwrap_or_abort() }) } diff --git a/utoipa-gen/src/schema/component.rs b/utoipa-gen/src/schema/component.rs index b08541f9..5a164ad6 100644 --- a/utoipa-gen/src/schema/component.rs +++ b/utoipa-gen/src/schema/component.rs @@ -1,5 +1,3 @@ -use std::mem; - use proc_macro2::{Ident, TokenStream as TokenStream2}; use proc_macro_error::{abort, ResultExt}; use quote::{quote, ToTokens}; @@ -20,7 +18,7 @@ use self::{ }; use super::{ - serde::{self, RenameRule, Serde}, + serde::{self, RenameRule, SerdeContainer, SerdeValue}, ComponentPart, GenericType, ValueType, }; @@ -391,6 +389,48 @@ struct SimpleEnum<'a> { attributes: &'a [Attribute], } +impl SimpleEnum<'_> { + /// Produce tokens that represent each variant for the situation where the serde enum tag = + /// "" attribute applies. + fn tagged_variants_tokens(tag: String, enum_values: Array) -> TokenStream2 { + let len = enum_values.len(); + let items: TokenStream2 = enum_values + .iter() + .map(|enum_value: &String| { + quote! { + utoipa::openapi::schema::ObjectBuilder::new() + .property( + #tag, + utoipa::openapi::schema::PropertyBuilder::new() + .component_type(utoipa::openapi::ComponentType::String) + .enum_values::<[&str; 1], &str>(Some([#enum_value])) + ) + .required(#tag) + } + }) + .map(|object: TokenStream2| { + quote! { + .item(#object) + } + }) + .collect(); + quote! { + Into::::into(utoipa::openapi::OneOf::with_capacity(#len)) + #items + } + } + + /// Produce tokens that represent each variant. + fn variants_tokens(enum_values: Array) -> TokenStream2 { + let len = enum_values.len(); + quote! { + utoipa::openapi::PropertyBuilder::new() + .component_type(utoipa::openapi::ComponentType::String) + .enum_values::<[&str; #len], &str>(Some(#enum_values)) + } + } +} + impl ToTokens for SimpleEnum<'_> { fn to_tokens(&self, tokens: &mut TokenStream2) { let mut container_rules = serde::parse_container(self.attributes); @@ -411,12 +451,13 @@ impl ToTokens for SimpleEnum<'_> { } }) .collect::>(); - let len = enum_values.len(); - tokens.extend(quote! { - utoipa::openapi::PropertyBuilder::new() - .component_type(utoipa::openapi::ComponentType::String) - .enum_values::<[&str; #len], &str>(Some(#enum_values)) + tokens.extend(match container_rules { + Some(serde_container) if serde_container.tag.is_some() => { + let tag = serde_container.tag.expect("Expected tag to be present"); + Self::tagged_variants_tokens(tag, enum_values) + } + _ => Self::variants_tokens(enum_values), }); let attrs = attr::parse_component_attr::>(self.attributes); @@ -441,6 +482,85 @@ struct ComplexEnum<'a> { attributes: &'a [Attribute], } +impl ComplexEnum<'_> { + fn unit_variant_tokens(variant_name: String) -> TokenStream2 { + quote! { + utoipa::openapi::PropertyBuilder::new() + .component_type(utoipa::openapi::ComponentType::String) + .enum_values::<[&str; 1], &str>(Some([#variant_name])) + } + } + /// Produce tokens that represent a variant of a [`ComplexEnum`]. + fn variant_tokens(variant_name: String, variant: &Variant) -> TokenStream2 { + match &variant.fields { + Fields::Named(named_fields) => { + let named_enum = NamedStructComponent { + attributes: &variant.attrs, + fields: &named_fields.named, + generics: None, + alias: None, + }; + + quote! { + utoipa::openapi::schema::ObjectBuilder::new() + .property(#variant_name, #named_enum) + } + } + Fields::Unnamed(unnamed_fields) => { + let unnamed_enum = UnnamedStructComponent { + attributes: &variant.attrs, + fields: &unnamed_fields.unnamed, + }; + + quote! { + utoipa::openapi::schema::ObjectBuilder::new() + .property(#variant_name, #unnamed_enum) + } + } + Fields::Unit => Self::unit_variant_tokens(variant_name), + } + } + + /// Produce tokens that represent a variant of a [`ComplexEnum`] where serde enum attribute + /// `tag = ` applies. + fn tagged_variant_tokens(tag: &str, variant_name: String, variant: &Variant) -> TokenStream2 { + match &variant.fields { + Fields::Named(named_fields) => { + let named_enum = NamedStructComponent { + attributes: &variant.attrs, + fields: &named_fields.named, + generics: None, + alias: None, + }; + + let variant_name_tokens = Self::unit_variant_tokens(variant_name); + + quote! { + #named_enum + .property(#tag, #variant_name_tokens) + .required(#tag) + } + } + Fields::Unnamed(_) => { + abort!( + variant, + "Unnamed (tuple) enum variants are unsupported for internally tagged enums using the `tag = ` serde attribute"; + + help = "Try using a different serde enum representation"; + ); + } + Fields::Unit => { + let variant_tokens = Self::unit_variant_tokens(variant_name); + quote! { + utoipa::openapi::schema::ObjectBuilder::new() + .property(#tag, #variant_tokens) + .required(#tag) + } + } + } + } +} + impl ToTokens for ComplexEnum<'_> { fn to_tokens(&self, tokens: &mut TokenStream2) { if self @@ -456,72 +576,52 @@ impl ToTokens for ComplexEnum<'_> { ); } - let capasity = self.variants.len(); - tokens.extend(quote! { - Into::::into(utoipa::openapi::OneOf::with_capacity(#capasity)) - }); + let capacity = self.variants.len(); - let mut container_rule = serde::parse_container(self.attributes); + let mut container_rules = serde::parse_container(self.attributes); + let tag: Option = if let Some(serde_container) = &mut container_rules { + serde_container.tag.take() + } else { + None + }; // serde, externally tagged format supported by now - self.variants + let items: TokenStream2 = self + .variants .iter() - .filter_map(|variant| { - let variant_rules = serde::parse_value(&variant.attrs); - if is_not_skipped(&variant_rules) { - Some((variant, variant_rules)) + .filter_map(|variant: &Variant| { + let variant_serde_rules = serde::parse_value(&variant.attrs); + if is_not_skipped(&variant_serde_rules) { + Some((variant, variant_serde_rules)) } else { None } }) - .map(|(variant, mut variant_rule)| match &variant.fields { - Fields::Named(named_fields) => { - let named_enum = NamedStructComponent { - attributes: &variant.attrs, - fields: &named_fields.named, - generics: None, - alias: None, - }; - let name = &*variant.ident.to_string(); - - let renamed = rename_variant(&mut container_rule, &mut variant_rule, name) - .unwrap_or_else(|| String::from(name)); - - quote! { - utoipa::openapi::schema::ObjectBuilder::new() - .property(#renamed, #named_enum) - } - } - Fields::Unnamed(unnamed_fields) => { - let unnamed_enum = UnnamedStructComponent { - attributes: &variant.attrs, - fields: &unnamed_fields.unnamed, - }; - let name = &*variant.ident.to_string(); - let renamed = rename_variant(&mut container_rule, &mut variant_rule, name) - .unwrap_or_else(|| String::from(name)); - - quote! { - utoipa::openapi::schema::ObjectBuilder::new() - .property(#renamed, #unnamed_enum) - } - } - Fields::Unit => { - let mut enum_values = Punctuated::::new(); - enum_values.push(variant.clone()); - - SimpleEnum { - attributes: self.attributes, - variants: &enum_values, - } - .to_token_stream() + .map(|(variant, mut variant_serde_rules)| { + let variant_name = &*variant.ident.to_string(); + let variant_name = + rename_variant(&mut container_rules, &mut variant_serde_rules, variant_name) + .unwrap_or_else(|| String::from(variant_name)); + + if let Some(tag) = &tag { + Self::tagged_variant_tokens(&tag, variant_name, variant) + } else { + Self::variant_tokens(variant_name, variant) } }) - .for_each(|inline_variant| { - tokens.extend(quote! { + .map(|inline_variant| { + quote! { .item(#inline_variant) - }) - }); + } + }) + .collect(); + + tokens.extend( + quote! { + Into::::into(utoipa::openapi::OneOf::with_capacity(#capacity)) + #items + } + ); if let Some(comment) = CommentAttributes::from_attributes(self.attributes).first() { tokens.extend(quote! { @@ -690,47 +790,57 @@ where } #[inline] -fn is_not_skipped(rule: &Option) -> bool { +fn is_not_skipped(rule: &Option) -> bool { rule.as_ref() - .map(|rule| matches!(rule, Serde::Value(value) if value.skip == None)) + .map(|value| value.skip.is_none()) .unwrap_or(true) } +/// Resolves the appropriate [`RenameRule`] to apply to the specified `struct` `field` name given a +/// `container_rule` (`struct` or `enum` level) and `field_rule` (`struct` field or `enum` variant +/// level). Returns `Some` of the result of the `rename_op` if a rename is required by the supplied +/// rules. #[inline] fn rename_field<'a>( - container_rule: &'a mut Option, - field_rule: &'a mut Option, + container_rule: &'a Option, + field_rule: &'a Option, field: &str, ) -> Option { rename(container_rule, field_rule, &|rule| rule.rename(field)) } +/// Resolves the appropriate [`RenameRule`] to apply to the specified `enum` `variant` name given a +/// `container_rule` (`struct` or `enum` level) and `field_rule` (`struct` field or `enum` variant +/// level). Returns `Some` of the result of the `rename_op` if a rename is required by the supplied +/// rules. #[inline] fn rename_variant<'a>( - container_rule: &'a mut Option, - field_rule: &'a mut Option, - field: &str, + container_rule: &'a Option, + field_rule: &'a Option, + variant: &str, ) -> Option { rename(container_rule, field_rule, &|rule| { - rule.rename_variant(field) + rule.rename_variant(variant) }) } +/// Resolves the appropriate [`RenameRule`] to apply during a `rename_op` given a `container_rule` +/// (`struct` or `enum` level) and `field_rule` (`struct` field or `enum` variant level). Returns +/// `Some` of the result of the `rename_op` if a rename is required by the supplied rules. #[inline] fn rename<'a>( - container_rule: &'a mut Option, - field_rule: &'a mut Option, + container_rule: &'a Option, + field_rule: &'a Option, rename_op: &impl Fn(&RenameRule) -> String, ) -> Option { - let rename = |rule: &mut Serde| match rule { - Serde::Container(container) => container.rename_all.as_ref().map(rename_op), - Serde::Value(ref mut value) => mem::take(&mut value.rename), - }; - field_rule - .as_mut() - .and_then(rename) - .or_else(|| container_rule.as_mut().and_then(rename)) + .as_ref() + .and_then(|value| value.rename.clone()) + .or_else(|| { + container_rule + .as_ref() + .and_then(|container| container.rename_all.as_ref().map(rename_op)) + }) } #[cfg_attr(feature = "debug", derive(Debug))]