From 90c80bdd65933f28c7d56e64d0eb1bd4fff89f2c Mon Sep 17 00:00:00 2001 From: mendess Date: Wed, 27 Dec 2023 14:41:29 +0000 Subject: [PATCH] Add support for traits with generic parameters --- trait-variant/examples/variant.rs | 11 ++++++ trait-variant/src/variant.rs | 66 +++++++++++++++++++++++++------ 2 files changed, 66 insertions(+), 11 deletions(-) diff --git a/trait-variant/examples/variant.rs b/trait-variant/examples/variant.rs index 9c15409..428e737 100644 --- a/trait-variant/examples/variant.rs +++ b/trait-variant/examples/variant.rs @@ -29,4 +29,15 @@ fn spawn_task(factory: impl IntFactory + 'static) { }); } +#[trait_variant::make(GenericTrait: Send)] +pub trait LocalGenericTrait<'x, S: Sync, Y, const X: usize> +where + Y: Sync, +{ + const CONST: usize = 3; + type F; + + async fn take(&self, s: S); +} + fn main() {} diff --git a/trait-variant/src/variant.rs b/trait-variant/src/variant.rs index 5ff3381..c2cb9c6 100644 --- a/trait-variant/src/variant.rs +++ b/trait-variant/src/variant.rs @@ -9,15 +9,15 @@ use std::iter; use proc_macro2::TokenStream; -use quote::quote; +use quote::{quote, ToTokens}; use syn::{ parse::{Parse, ParseStream}, parse_macro_input, punctuated::Punctuated, - token::Plus, - Error, FnArg, Generics, Ident, ItemTrait, Pat, PatType, Result, ReturnType, Signature, Token, - TraitBound, TraitItem, TraitItemConst, TraitItemFn, TraitItemType, Type, TypeImplTrait, - TypeParamBound, + token::{Comma, Plus}, + Error, FnArg, GenericParam, Generics, Ident, ItemTrait, Lifetime, Pat, PatType, Result, + ReturnType, Signature, Token, TraitBound, TraitItem, TraitItemConst, TraitItemFn, + TraitItemType, Type, TypeImplTrait, TypeParamBound, }; struct Attrs { @@ -162,16 +162,60 @@ fn transform_item(item: &TraitItem, bounds: &Vec) -> TraitItem { fn mk_blanket_impl(attrs: &Attrs, tr: &ItemTrait) -> TokenStream { let orig = &tr.ident; + let generics = &tr.generics.params; + let mut generic_names = tr + .generics + .params + .iter() + .map(|generic| match generic { + GenericParam::Lifetime(lt) => GenericParamName::Lifetime(<.lifetime), + GenericParam::Type(ty) => GenericParamName::Type(&ty.ident), + GenericParam::Const(co) => GenericParamName::Const(&co.ident), + }) + .collect::>(); + let trailing_comma = if !generic_names.is_empty() { + generic_names.push_punct(Comma::default()); + quote! { , } + } else { + quote! {} + }; let variant = &attrs.variant.name; - let items = tr.items.iter().map(|item| blanket_impl_item(item, variant)); + let items = tr + .items + .iter() + .map(|item| blanket_impl_item(item, variant, &generic_names)); + let where_clauses = tr.generics.where_clause.as_ref().map(|wh| &wh.predicates); quote! { - impl #orig for T where T: #variant { + impl<#generics #trailing_comma TraitVariantBlanketType> #orig<#generic_names> + for TraitVariantBlanketType + where TraitVariantBlanketType: #variant<#generic_names>, #where_clauses + { #(#items)* } } } -fn blanket_impl_item(item: &TraitItem, variant: &Ident) -> TokenStream { +enum GenericParamName<'s> { + Lifetime(&'s Lifetime), + Type(&'s Ident), + Const(&'s Ident), +} + +impl ToTokens for GenericParamName<'_> { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + GenericParamName::Lifetime(lt) => lt.to_tokens(tokens), + GenericParamName::Type(ty) => ty.to_tokens(tokens), + GenericParamName::Const(co) => co.to_tokens(tokens), + } + } +} + +fn blanket_impl_item( + item: &TraitItem, + variant: &Ident, + generic_names: &Punctuated, Comma>, +) -> TokenStream { // impl IntFactory for T where T: SendIntFactory { // const NAME: &'static str = ::NAME; // type MyFut<'a> = ::MyFut<'a> where Self: 'a; @@ -187,7 +231,7 @@ fn blanket_impl_item(item: &TraitItem, variant: &Ident) -> TokenStream { .. }) => { quote! { - const #ident #generics: #ty = ::#ident; + const #ident #generics: #ty = >::#ident; } } TraitItem::Fn(TraitItemFn { sig, .. }) => { @@ -207,7 +251,7 @@ fn blanket_impl_item(item: &TraitItem, variant: &Ident) -> TokenStream { }; quote! { #sig { - ::#ident(#(#args),*)#maybe_await + >::#ident(#(#args),*)#maybe_await } } } @@ -222,7 +266,7 @@ fn blanket_impl_item(item: &TraitItem, variant: &Ident) -> TokenStream { .. }) => { quote! { - type #ident<#params> = ::#ident<#params> #where_clause; + type #ident<#params> = >::#ident<#params> #where_clause; } } _ => Error::new_spanned(item, "unsupported item type").into_compile_error(),