Skip to content

Commit

Permalink
feat: trait_variant::make supports rewriting of the original trait.
Browse files Browse the repository at this point in the history
  • Loading branch information
sargarass committed Feb 10, 2024
1 parent f1e171e commit b89d532
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 55 deletions.
14 changes: 14 additions & 0 deletions trait-variant/examples/variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,18 @@ where
fn build<T: Display>(&self, items: impl Iterator<Item = T>) -> Self::B<T>;
}

#[trait_variant::make(Send + Sync)]
pub trait GenericTraitWithBounds<'x, S: Sync, Y, const X: usize>
where
Y: Sync,
{
const CONST: usize = 3;
type F;
type A<const ANOTHER_CONST: u8>;
type B<T: Display>: FromIterator<T>;

async fn take(&self, s: S);
fn build<T: Display>(&self, items: impl Iterator<Item = T>) -> Self::B<T>;
}

fn main() {}
19 changes: 16 additions & 3 deletions trait-variant/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ mod variant;
/// fn` and/or `-> impl Trait` return types.
///
/// ```
/// #[trait_variant::make(IntFactory: Send)]
/// trait LocalIntFactory {
/// #[trait_variant::make(Send)]
/// trait IntFactory {
/// async fn make(&self) -> i32;
/// fn stream(&self) -> impl Iterator<Item = i32>;
/// fn call(&self) -> u32;
/// }
/// ```
///
/// The above example causes a second trait called `IntFactory` to be created:
/// The above example causes the trait to be rewritten as:
///
/// ```
/// # use core::future::Future;
Expand All @@ -35,6 +35,19 @@ mod variant;
///
/// Note that ordinary methods such as `call` are not affected.
///
/// If you want to preserve an original trait untouched, `make` can be used to create a new trait with bounds on `async
/// fn` and/or `-> impl Trait` return types.
///
/// ```
/// #[trait_variant::make(IntFactory: Send)]
/// trait LocalIntFactory {
/// async fn make(&self) -> i32;
/// fn stream(&self) -> impl Iterator<Item = i32>;
/// fn call(&self) -> u32;
/// }
/// ```
///
/// The example causes a second trait called `IntFactory` to be created.
/// Implementers of the trait can choose to implement the variant instead of the
/// original trait. The macro creates a blanket impl which ensures that any type
/// which implements the variant also implements the original trait.
Expand Down
111 changes: 59 additions & 52 deletions trait-variant/src/variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,33 @@ impl Parse for Attrs {
}
}

struct MakeVariant {
name: Ident,
#[allow(unused)]
colon: Token![:],
bounds: Punctuated<TraitBound, Plus>,
enum MakeVariant {
// Creates a variant of a trait under a new name with additional bounds while preserving the original trait.
Create {
name: Ident,
_colon: Token![:],
bounds: Punctuated<TraitBound, Plus>,
},
// Rewrites the original trait into a new trait with additional bounds.
Rewrite {
bounds: Punctuated<TraitBound, Plus>,
},
}

impl Parse for MakeVariant {
fn parse(input: ParseStream) -> Result<Self> {
Ok(Self {
name: input.parse()?,
colon: input.parse()?,
bounds: input.parse_terminated(TraitBound::parse, Token![+])?,
})
let variant = if input.peek(Ident) && input.peek2(Token![:]) {
MakeVariant::Create {
name: input.parse()?,
_colon: input.parse()?,
bounds: input.parse_terminated(TraitBound::parse, Token![+])?,
}
} else {
MakeVariant::Rewrite {
bounds: input.parse_terminated(TraitBound::parse, Token![+])?,
}
};
Ok(variant)
}
}

Expand All @@ -56,43 +69,51 @@ pub fn make(
let attrs = parse_macro_input!(attr as Attrs);
let item = parse_macro_input!(item as ItemTrait);

let maybe_allow_async_lint = if attrs
.variant
.bounds
.iter()
.any(|b| b.path.segments.last().unwrap().ident == "Send")
{
quote! { #[allow(async_fn_in_trait)] }
} else {
quote! {}
};
match attrs.variant {
MakeVariant::Create { name, bounds, .. } => {
let maybe_allow_async_lint = if bounds
.iter()
.any(|b| b.path.segments.last().unwrap().ident == "Send")
{
quote! { #[allow(async_fn_in_trait)] }
} else {
quote! {}
};

let variant = mk_variant(&attrs, &item);
let blanket_impl = mk_blanket_impl(&attrs, &item);
let variant = mk_variant(&name, bounds, &item);
let blanket_impl = mk_blanket_impl(&name, &item);

quote! {
#maybe_allow_async_lint
#item
quote! {
#maybe_allow_async_lint
#item

#variant
#variant

#blanket_impl
#blanket_impl
}
.into()
}
MakeVariant::Rewrite { bounds, .. } => {
let variant = mk_variant(&item.ident, bounds, &item);
quote! {
#variant
}
.into()
}
}
.into()
}

fn mk_variant(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
let MakeVariant {
ref name,
colon: _,
ref bounds,
} = attrs.variant;
let bounds: Vec<_> = bounds
fn mk_variant(
variant: &Ident,
with_bounds: Punctuated<TraitBound, Plus>,
tr: &ItemTrait,
) -> TokenStream {
let bounds: Vec<_> = with_bounds
.into_iter()
.map(|b| TypeParamBound::Trait(b.clone()))
.collect();
let variant = ItemTrait {
ident: name.clone(),
ident: variant.clone(),
supertraits: tr.supertraits.iter().chain(&bounds).cloned().collect(),
items: tr
.items
Expand All @@ -104,21 +125,8 @@ fn mk_variant(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
quote! { #variant }
}

// Transforms a one item declaration within the definition if it has `async fn` and/or `-> impl Trait` return types by adding new bounds.
fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
// #[make_variant(SendIntFactory: Send)]
// trait IntFactory {
// async fn make(&self, x: u32, y: &str) -> i32;
// fn stream(&self) -> impl Iterator<Item = i32>;
// fn call(&self) -> u32;
// }
//
// becomes:
//
// trait SendIntFactory: Send {
// fn make(&self, x: u32, y: &str) -> impl ::core::future::Future<Output = i32> + Send;
// fn stream(&self) -> impl Iterator<Item = i32> + Send;
// fn call(&self) -> u32;
// }
let TraitItem::Fn(fn_item @ TraitItemFn { sig, .. }) = item else {
return item.clone();
};
Expand Down Expand Up @@ -160,9 +168,8 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
})
}

fn mk_blanket_impl(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
fn mk_blanket_impl(variant: &Ident, tr: &ItemTrait) -> TokenStream {
let orig = &tr.ident;
let variant = &attrs.variant.name;
let (_impl, orig_ty_generics, _where) = &tr.generics.split_for_impl();
let items = tr
.items
Expand Down

0 comments on commit b89d532

Please sign in to comment.