From 328e93e449146f0ccfae80734778c84b06e717b0 Mon Sep 17 00:00:00 2001 From: ihor rudynskyi Date: Tue, 17 Dec 2024 17:37:30 +0100 Subject: [PATCH] allow custom *Authorization types be used in SecurityScheme macro --- poem-openapi-derive/src/security_scheme.rs | 43 +++++++++++++--------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/poem-openapi-derive/src/security_scheme.rs b/poem-openapi-derive/src/security_scheme.rs index e7b4298c4e..8cd81377f7 100644 --- a/poem-openapi-derive/src/security_scheme.rs +++ b/poem-openapi-derive/src/security_scheme.rs @@ -6,7 +6,7 @@ use darling::{ use http::header::HeaderName; use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; -use syn::{Attribute, DeriveInput, Error, Path}; +use syn::{Attribute, DeriveInput, Error, Path, Type}; use crate::{ error::GeneratorResult, @@ -401,7 +401,11 @@ impl SecuritySchemeArgs { Ok(ts) } - fn generate_from_request(&self, crate_name: &TokenStream) -> GeneratorResult { + fn generate_from_request( + &self, + crate_name: &TokenStream, + field: &Type, + ) -> GeneratorResult { match self.auth_type()? { AuthType::ApiKey => { let key_name = self.key_name.as_ref().unwrap().as_str(); @@ -411,21 +415,21 @@ impl SecuritySchemeArgs { ApiKeyInType::Cookie => quote!(#crate_name::registry::MetaParamIn::Cookie), }; Ok( - quote!(<#crate_name::auth::ApiKey as #crate_name::auth::ApiKeyAuthorization>::from_request(req, query, #key_name, #param_in)), + quote!(<#field as #crate_name::auth::ApiKeyAuthorization>::from_request(req, query, #key_name, #param_in)), ) } - AuthType::Basic => Ok( - quote!(<#crate_name::auth::Basic as #crate_name::auth::BasicAuthorization>::from_request(req)), - ), - AuthType::Bearer => Ok( - quote!(<#crate_name::auth::Bearer as #crate_name::auth::BearerAuthorization>::from_request(req)), - ), - AuthType::OAuth2 => Ok( - quote!(<#crate_name::auth::Bearer as #crate_name::auth::BearerAuthorization>::from_request(req)), - ), - AuthType::OpenIdConnect => Ok( - quote!(<#crate_name::auth::Bearer as #crate_name::auth::BearerAuthorization>::from_request(req)), - ), + AuthType::Basic => { + Ok(quote!(<#field as #crate_name::auth::BasicAuthorization>::from_request(req))) + } + AuthType::Bearer => { + Ok(quote!(<#field as #crate_name::auth::BearerAuthorization>::from_request(req))) + } + AuthType::OAuth2 => { + Ok(quote!(<#field as #crate_name::auth::BearerAuthorization>::from_request(req))) + } + AuthType::OpenIdConnect => { + Ok(quote!(<#field as #crate_name::auth::BearerAuthorization>::from_request(req))) + } } } } @@ -439,15 +443,20 @@ pub(crate) fn generate(args: DeriveInput) -> GeneratorResult { Data::Struct(fields) => { let oai_typename = args.rename.clone().unwrap_or_else(|| ident.to_string()); - if fields.style != Style::Tuple || fields.fields.len() != 1 { + if fields.style != Style::Tuple { return Err(Error::new_spanned(ident, "Must be a tuple of length 1.").into()); } + let field = match fields.fields.as_slice() { + [field] => field, + _ => return Err(Error::new_spanned(ident, "Must be a tuple of length 1.").into()), + }; + args.validate()?; let register_security_scheme = args.generate_register_security_scheme(&crate_name, &oai_typename)?; - let from_request = args.generate_from_request(&crate_name)?; + let from_request = args.generate_from_request(&crate_name, field)?; let path = args.checker.as_ref(); let output = match path {