Skip to content

Commit

Permalink
Add support for traits with generic parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
mendess committed Dec 27, 2023
1 parent 6a5e7ab commit 9e159d1
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 11 deletions.
8 changes: 8 additions & 0 deletions trait-variant/examples/variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,12 @@ fn spawn_task(factory: impl IntFactory + 'static) {
});
}

#[trait_variant::make(GenericTrait: Send)]
pub trait LocalGenericTrait<'x, S: Sync, Y, const X: usize> {
const CONST: usize = 3;
type F;

async fn take(&self, s: S);
}

fn main() {}
64 changes: 53 additions & 11 deletions trait-variant/src/variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
use std::iter;

use proc_macro2::TokenStream;
use quote::quote;
use quote::{quote, ToTokens};
use syn::{
parse::{Parse, ParseStream},
parse_macro_input,
punctuated::Punctuated,
token::Plus,
Error, FnArg, Generics, Ident, ItemTrait, Pat, PatType, Result, ReturnType, Signature, Token,
TraitBound, TraitItem, TraitItemConst, TraitItemFn, TraitItemType, Type, TypeImplTrait,
TypeParamBound,
token::{Comma, Plus},
Error, FnArg, GenericParam, Generics, Ident, ItemTrait, Lifetime, Pat, PatType, Result,
ReturnType, Signature, Token, TraitBound, TraitItem, TraitItemConst, TraitItemFn,
TraitItemType, Type, TypeImplTrait, TypeParamBound,
};

struct Attrs {
Expand Down Expand Up @@ -162,16 +162,58 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {

fn mk_blanket_impl(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
let orig = &tr.ident;
let generics = &tr.generics.params;
let mut generic_names = tr
.generics
.params
.iter()
.map(|generic| match generic {
GenericParam::Lifetime(lt) => GenericParamName::Lifetime(&lt.lifetime),
GenericParam::Type(ty) => GenericParamName::Type(&ty.ident),
GenericParam::Const(co) => GenericParamName::Const(&co.ident),
})
.collect::<Punctuated<_, Comma>>();
let trailing_comma = if !generic_names.is_empty() {
generic_names.push_punct(Comma::default());
quote! { , }
} else {
quote! {}
};
let variant = &attrs.variant.name;
let items = tr.items.iter().map(|item| blanket_impl_item(item, variant));
let items = tr
.items
.iter()
.map(|item| blanket_impl_item(item, variant, &generic_names));
quote! {
impl<T> #orig for T where T: #variant {
impl<#generics #trailing_comma T> #orig<#generic_names> for T
where T: #variant<#generic_names>
{
#(#items)*
}
}
}

fn blanket_impl_item(item: &TraitItem, variant: &Ident) -> TokenStream {
enum GenericParamName<'s> {
Lifetime(&'s Lifetime),
Type(&'s Ident),
Const(&'s Ident),
}

impl ToTokens for GenericParamName<'_> {
fn to_tokens(&self, tokens: &mut TokenStream) {
match self {
GenericParamName::Lifetime(lt) => lt.to_tokens(tokens),
GenericParamName::Type(ty) => ty.to_tokens(tokens),
GenericParamName::Const(co) => co.to_tokens(tokens),
}
}
}

fn blanket_impl_item(
item: &TraitItem,
variant: &Ident,
generic_names: &Punctuated<GenericParamName<'_>, Comma>,
) -> TokenStream {
// impl<T> IntFactory for T where T: SendIntFactory {
// const NAME: &'static str = <Self as SendIntFactory>::NAME;
// type MyFut<'a> = <Self as SendIntFactory>::MyFut<'a> where Self: 'a;
Expand All @@ -187,7 +229,7 @@ fn blanket_impl_item(item: &TraitItem, variant: &Ident) -> TokenStream {
..
}) => {
quote! {
const #ident #generics: #ty = <Self as #variant>::#ident;
const #ident #generics: #ty = <Self as #variant<#generic_names>>::#ident;
}
}
TraitItem::Fn(TraitItemFn { sig, .. }) => {
Expand All @@ -207,7 +249,7 @@ fn blanket_impl_item(item: &TraitItem, variant: &Ident) -> TokenStream {
};
quote! {
#sig {
<Self as #variant>::#ident(#(#args),*)#maybe_await
<Self as #variant<#generic_names>>::#ident(#(#args),*)#maybe_await
}
}
}
Expand All @@ -222,7 +264,7 @@ fn blanket_impl_item(item: &TraitItem, variant: &Ident) -> TokenStream {
..
}) => {
quote! {
type #ident<#params> = <Self as #variant>::#ident<#params> #where_clause;
type #ident<#params> = <Self as #variant<#generic_names>>::#ident<#params> #where_clause;
}
}
_ => Error::new_spanned(item, "unsupported item type").into_compile_error(),
Expand Down

0 comments on commit 9e159d1

Please sign in to comment.