diff --git a/src/core/config/config.rs b/src/core/config/config.rs index 8c25dc08be..702715493a 100644 --- a/src/core/config/config.rs +++ b/src/core/config/config.rs @@ -8,7 +8,7 @@ use async_graphql::Positioned; use derive_setters::Setters; use serde::{Deserialize, Serialize}; use serde_json::Value; -use tailcall_macros::{DirectiveDefinition, InputDefinition}; +use tailcall_macros::{CustomResolver, DirectiveDefinition, InputDefinition}; use tailcall_typedefs_common::directive_definition::DirectiveDefinition; use tailcall_typedefs_common::input_definition::InputDefinition; use tailcall_typedefs_common::ServiceDocumentBuilder; @@ -210,53 +210,11 @@ pub struct RootSchema { /// Used to omit a field from public consumption. pub struct Omit {} -// generate Resolver with macro in order to autogenerate conversion code -// from the underlying directives. -// TODO: replace with derive macro -macro_rules! create_resolver { - ($($var:ident($ty:ty)),+$(,)?) => { - #[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, schemars::JsonSchema)] - #[serde(rename_all = "camelCase")] - pub enum Resolver { - // just specify the same variants - $($var($ty),)+ - } - - impl Resolver { - pub fn from_directives( - directives: &[Positioned], - ) -> Valid, String> { - let mut result = None; - let mut resolvable_directives = Vec::new(); - let mut valid = Valid::succeed(()); - - $( - // try to parse directive from the Resolver variant - valid = valid.and(<$ty>::from_directives(directives.iter()).map(|resolver| { - if let Some(resolver) = resolver { - // on success store it as a result and remember parsed directives - result = Some(Self::$var(resolver)); - resolvable_directives.push(<$ty>::trace_name()); - } - })); - )+ - - valid.and_then(|_| { - if resolvable_directives.len() > 1 { - Valid::fail(format!( - "Multiple resolvers detected [{}]", - resolvable_directives.join(", ") - )) - } else { - Valid::succeed(result) - } - }) - } - } - }; -} - -create_resolver! { +#[derive( + Serialize, Deserialize, Clone, Debug, PartialEq, Eq, schemars::JsonSchema, CustomResolver, +)] +#[serde(rename_all = "camelCase")] +pub enum Resolver { Http(Http), Grpc(Grpc), Graphql(GraphQL), @@ -265,30 +223,6 @@ create_resolver! { Expr(Expr), } -impl Resolver { - pub fn to_directive(&self) -> ConstDirective { - match self { - Resolver::Http(d) => d.to_directive(), - Resolver::Grpc(d) => d.to_directive(), - Resolver::Graphql(d) => d.to_directive(), - Resolver::Call(d) => d.to_directive(), - Resolver::Js(d) => d.to_directive(), - Resolver::Expr(d) => d.to_directive(), - } - } - - pub fn directive_name(&self) -> String { - match self { - Resolver::Http(_) => Http::directive_name(), - Resolver::Grpc(_) => Grpc::directive_name(), - Resolver::Graphql(_) => GraphQL::directive_name(), - Resolver::Call(_) => Call::directive_name(), - Resolver::Js(_) => JS::directive_name(), - Resolver::Expr(_) => Expr::directive_name(), - } - } -} - /// /// A field definition containing all the metadata information about resolving a /// field. diff --git a/tailcall-macros/src/lib.rs b/tailcall-macros/src/lib.rs index 76348310d2..377a4ef20e 100644 --- a/tailcall-macros/src/lib.rs +++ b/tailcall-macros/src/lib.rs @@ -5,8 +5,12 @@ use proc_macro::TokenStream; mod document_definition; mod gen; mod merge_right; +mod resolver; + use crate::document_definition::{expand_directive_definition, expand_input_definition}; use crate::merge_right::expand_merge_right_derive; +use crate::resolver::expand_resolver_derive; + #[proc_macro_derive(MergeRight, attributes(merge_right))] pub fn merge_right_derive(input: TokenStream) -> TokenStream { expand_merge_right_derive(input) @@ -43,3 +47,8 @@ pub fn gen_doc(item: TokenStream) -> TokenStream { pub fn input_definition_derive(input: TokenStream) -> TokenStream { expand_input_definition(input) } + +#[proc_macro_derive(CustomResolver)] +pub fn resolver_derive(input: TokenStream) -> TokenStream { + expand_resolver_derive(input) +} diff --git a/tailcall-macros/src/resolver.rs b/tailcall-macros/src/resolver.rs new file mode 100644 index 0000000000..cb6526af1e --- /dev/null +++ b/tailcall-macros/src/resolver.rs @@ -0,0 +1,91 @@ +use proc_macro::TokenStream; +use quote::quote; +use syn::{parse_macro_input, Data, DeriveInput, Fields}; + +pub fn expand_resolver_derive(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let name = &input.ident; + + let variants = if let Data::Enum(data_enum) = &input.data { + data_enum + .variants + .iter() + .map(|variant| { + let variant_name = &variant.ident; + let ty = match &variant.fields { + Fields::Unnamed(fields) if fields.unnamed.len() == 1 => &fields.unnamed[0].ty, + _ => panic!("Resolver variants must have exactly one unnamed field"), + }; + + (variant_name, ty) + }) + .collect::>() + } else { + panic!("Resolver can only be derived for enums"); + }; + + let variant_parsers = variants.iter().map(|(variant_name, ty)| { + quote! { + valid = valid.and(<#ty>::from_directives(directives.iter()).map(|resolver| { + if let Some(resolver) = resolver { + let directive_name = <#ty>::trace_name(); + if !resolvable_directives.contains(&directive_name) { + resolvable_directives.push(directive_name); + } + result = Some(Self::#variant_name(resolver)); + } + })); + } + }); + + let match_arms_to_directive = variants.iter().map(|(variant_name, _ty)| { + quote! { + Self::#variant_name(d) => d.to_directive(), + } + }); + + let match_arms_directive_name = variants.iter().map(|(variant_name, ty)| { + quote! { + Self::#variant_name(_) => <#ty>::directive_name(), + } + }); + + let expanded = quote! { + impl #name { + pub fn from_directives( + directives: &[Positioned], + ) -> Valid, String> { + let mut result = None; + let mut resolvable_directives = Vec::new(); + let mut valid = Valid::succeed(()); + + #(#variant_parsers)* + + valid.and_then(|_| { + if resolvable_directives.len() > 1 { + Valid::fail(format!( + "Multiple resolvers detected [{}]", + resolvable_directives.join(", ") + )) + } else { + Valid::succeed(result) + } + }) + } + + pub fn to_directive(&self) -> ConstDirective { + match self { + #(#match_arms_to_directive)* + } + } + + pub fn directive_name(&self) -> String { + match self { + #(#match_arms_directive_name)* + } + } + } + }; + + TokenStream::from(expanded) +}