Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support omitting the variant name #30

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@ trait IntFactory: Send {

Implementers can choose to implement either `LocalIntFactory` or `IntFactory` as appropriate.

If a non-`Send` variant of the trait is not needed, the name of the new variant can simply be omitted. E.g., this generates a *single* (rather than an additional) trait whose definition matches that in the expansion above:

```rust
#[trait_variant::make(Send)]
trait IntFactory {
async fn make(&self) -> i32;
fn stream(&self) -> impl Iterator<Item = i32>;
fn call(&self) -> u32;
}
```

For more details, see the docs for [`trait_variant::make`].

[`trait_variant::make`]: https://docs.rs/trait-variant/latest/trait_variant/attr.make.html
Expand Down
5 changes: 5 additions & 0 deletions trait-variant/examples/variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ fn spawn_task(factory: impl IntFactory + 'static) {
});
}

#[trait_variant::make(Send)]
pub trait TupleFactory {
async fn new() -> Self;
}

#[trait_variant::make(GenericTrait: Send)]
pub trait LocalGenericTrait<'x, S: Sync, Y, const X: usize>
where
Expand Down
13 changes: 13 additions & 0 deletions trait-variant/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,19 @@ mod variant;
/// Implementers of the trait can choose to implement the variant instead of the
/// original trait. The macro creates a blanket impl which ensures that any type
/// which implements the variant also implements the original trait.
///
/// If a non-`Send` variant of the trait is not needed, the name of
/// new variant can simply be omitted. E.g., this generates a
/// *single* (rather than an additional) trait whose definition
/// matches that in the expansion above:
///
/// #[trait_variant::make(Send)]
/// trait IntFactory {
/// async fn make(&self) -> i32;
/// fn stream(&self) -> impl Iterator<Item = i32>;
/// fn call(&self) -> u32;
/// }
/// ```
#[proc_macro_attribute]
pub fn make(
attr: proc_macro::TokenStream,
Expand Down
81 changes: 46 additions & 35 deletions trait-variant/src/variant.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// Copyright (c) 2023 Google LLC
// Copyright (c) 2023 Various contributors (see git history)
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
Expand All @@ -11,7 +12,7 @@ use std::iter;
use proc_macro2::TokenStream;
use quote::quote;
use syn::{
parse::{Parse, ParseStream},
parse::{discouraged::Speculative as _, Parse, ParseStream},
parse_macro_input, parse_quote,
punctuated::Punctuated,
token::Plus,
Expand All @@ -20,44 +21,57 @@ use syn::{
TypeImplTrait, TypeParam, TypeParamBound,
};

struct Attrs {
variant: MakeVariant,
#[derive(Clone)]
struct Variant {
name: Option<Ident>,
_colon: Option<Token![:]>,
bounds: Punctuated<TraitBound, Plus>,
}

impl Parse for Attrs {
fn parse(input: ParseStream) -> Result<Self> {
Ok(Self {
variant: MakeVariant::parse(input)?,
})
}
fn parse_bounds_only(input: ParseStream) -> Result<Option<Variant>> {
let fork = input.fork();
let colon: Option<Token![:]> = fork.parse()?;
let bounds = match fork.parse_terminated(TraitBound::parse, Token![+]) {
Ok(x) => Ok(x),
Err(e) if colon.is_some() => Err(e),
Err(_) => return Ok(None),
};
input.advance_to(&fork);
Ok(Some(Variant {
name: None,
_colon: colon,
bounds: bounds?,
}))
}

struct MakeVariant {
name: Ident,
#[allow(unused)]
colon: Token![:],
bounds: Punctuated<TraitBound, Plus>,
fn parse_fallback(input: ParseStream) -> Result<Variant> {
let name: Ident = input.parse()?;
let colon: Token![:] = input.parse()?;
let bounds = input.parse_terminated(TraitBound::parse, Token![+])?;
Ok(Variant {
name: Some(name),
_colon: Some(colon),
bounds,
})
}

impl Parse for MakeVariant {
impl Parse for Variant {
fn parse(input: ParseStream) -> Result<Self> {
Ok(Self {
name: input.parse()?,
colon: input.parse()?,
bounds: input.parse_terminated(TraitBound::parse, Token![+])?,
})
match parse_bounds_only(input)? {
Some(x) => Ok(x),
None => parse_fallback(input),
}
}
}

pub fn make(
attr: proc_macro::TokenStream,
item: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let attrs = parse_macro_input!(attr as Attrs);
let variant = parse_macro_input!(attr as Variant);
let item = parse_macro_input!(item as ItemTrait);

let maybe_allow_async_lint = if attrs
.variant
let maybe_allow_async_lint = if variant
.bounds
.iter()
.any(|b| b.path.segments.last().unwrap().ident == "Send")
Expand All @@ -67,26 +81,24 @@ pub fn make(
quote! {}
};

let variant = mk_variant(&attrs, &item);
let blanket_impl = mk_blanket_impl(&attrs, &item);

let variant_name = variant.clone().name.unwrap_or(item.clone().ident);
let variant_def = mk_variant(&variant_name, &variant.bounds, &item);
if variant_name == item.ident {
return variant_def.into();
}
let blanket_impl = Some(mk_blanket_impl(&variant_name, &item));
quote! {
#maybe_allow_async_lint
#item

#variant
#variant_def

#blanket_impl
}
.into()
}

fn mk_variant(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
let MakeVariant {
ref name,
colon: _,
ref bounds,
} = attrs.variant;
fn mk_variant(name: &Ident, bounds: &Punctuated<TraitBound, Plus>, tr: &ItemTrait) -> TokenStream {
let bounds: Vec<_> = bounds
.into_iter()
.map(|b| TypeParamBound::Trait(b.clone()))
Expand Down Expand Up @@ -160,9 +172,8 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
})
}

fn mk_blanket_impl(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
fn mk_blanket_impl(variant: &Ident, tr: &ItemTrait) -> TokenStream {
let orig = &tr.ident;
let variant = &attrs.variant.name;
let (_impl, orig_ty_generics, _where) = &tr.generics.split_for_impl();
let items = tr
.items
Expand Down
65 changes: 65 additions & 0 deletions trait-variant/tests/bounds.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright (c) 2023 Google LLC
// Copyright (c) 2023 Various contributors (see git history)
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

#[trait_variant::make(Send + Sync)]
pub trait Trait {
const CONST: &'static ();
type Gat<'a>
where
Self: 'a;
async fn assoc_async_fn_no_ret(a: (), b: ());
async fn assoc_async_method_no_ret(&self, a: (), b: ());
async fn assoc_async_fn(a: (), b: ()) -> ();
async fn assoc_async_method(&self, a: (), b: ()) -> ();
fn assoc_sync_fn_no_ret(a: (), b: ());
fn assoc_sync_method_no_ret(&self, a: (), b: ());
fn assoc_sync_fn(a: (), b: ()) -> ();
fn assoc_sync_method(&self, a: (), b: ()) -> ();
// FIXME: See #17.
//async fn dft_assoc_async_fn_no_ret(_a: (), _b: ()) {}
//async fn dft_assoc_async_method_no_ret(&self, _a: (), _b: ()) {}
//async fn dft_assoc_async_fn(_a: (), _b: ()) -> () {}
//async fn dft_assoc_async_method(&self, _a: (), _b: ()) -> () {}
fn dft_assoc_sync_fn_no_ret(_a: (), _b: ()) {}
fn dft_assoc_sync_method_no_ret(&self, _a: (), _b: ()) {}
fn dft_assoc_sync_fn(_a: (), _b: ()) -> () {}
fn dft_assoc_sync_method(&self, _a: (), _b: ()) -> () {}
}

impl Trait for () {
const CONST: &'static () = &();
type Gat<'a> = ();
async fn assoc_async_fn_no_ret(_a: (), _b: ()) {}
async fn assoc_async_method_no_ret(&self, _a: (), _b: ()) {}
async fn assoc_async_fn(_a: (), _b: ()) -> () {}
async fn assoc_async_method(&self, _a: (), _b: ()) -> () {}
fn assoc_sync_fn_no_ret(_a: (), _b: ()) {}
fn assoc_sync_method_no_ret(&self, _a: (), _b: ()) {}
fn assoc_sync_fn(_a: (), _b: ()) -> () {}
fn assoc_sync_method(&self, _a: (), _b: ()) -> () {}
}

fn is_bounded<T: Send + Sync>(_: T) {}

#[test]
fn test() {
fn inner<T: Trait>(x: T) {
let (a, b) = ((), ());
is_bounded(<T as Trait>::assoc_async_fn_no_ret(a, b));
is_bounded(<T as Trait>::assoc_async_method_no_ret(&x, a, b));
is_bounded(<T as Trait>::assoc_async_fn(a, b));
is_bounded(<T as Trait>::assoc_async_method(&x, a, b));
// FIXME: See #17.
//is_bounded(<T as Trait>::dft_assoc_async_fn_no_ret(a, b));
//is_bounded(<T as Trait>::dft_assoc_async_method_no_ret(&x, a, b));
//is_bounded(<T as Trait>::dft_assoc_async_fn(a, b));
//is_bounded(<T as Trait>::dft_assoc_async_method(&x, a, b));
}
inner(());
}
65 changes: 65 additions & 0 deletions trait-variant/tests/colon-bounds.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright (c) 2023 Google LLC
// Copyright (c) 2023 Various contributors (see git history)
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

#[trait_variant::make(: Send + Sync)]
pub trait Trait {
const CONST: &'static ();
type Gat<'a>
where
Self: 'a;
async fn assoc_async_fn_no_ret(a: (), b: ());
async fn assoc_async_method_no_ret(&self, a: (), b: ());
async fn assoc_async_fn(a: (), b: ()) -> ();
async fn assoc_async_method(&self, a: (), b: ()) -> ();
fn assoc_sync_fn_no_ret(a: (), b: ());
fn assoc_sync_method_no_ret(&self, a: (), b: ());
fn assoc_sync_fn(a: (), b: ()) -> ();
fn assoc_sync_method(&self, a: (), b: ()) -> ();
// FIXME: See #17.
//async fn dft_assoc_async_fn_no_ret(_a: (), _b: ()) {}
//async fn dft_assoc_async_method_no_ret(&self, _a: (), _b: ()) {}
//async fn dft_assoc_async_fn(_a: (), _b: ()) -> () {}
//async fn dft_assoc_async_method(&self, _a: (), _b: ()) -> () {}
fn dft_assoc_sync_fn_no_ret(_a: (), _b: ()) {}
fn dft_assoc_sync_method_no_ret(&self, _a: (), _b: ()) {}
fn dft_assoc_sync_fn(_a: (), _b: ()) -> () {}
fn dft_assoc_sync_method(&self, _a: (), _b: ()) -> () {}
}

impl Trait for () {
const CONST: &'static () = &();
type Gat<'a> = ();
async fn assoc_async_fn_no_ret(_a: (), _b: ()) {}
async fn assoc_async_method_no_ret(&self, _a: (), _b: ()) {}
async fn assoc_async_fn(_a: (), _b: ()) -> () {}
async fn assoc_async_method(&self, _a: (), _b: ()) -> () {}
fn assoc_sync_fn_no_ret(_a: (), _b: ()) {}
fn assoc_sync_method_no_ret(&self, _a: (), _b: ()) {}
fn assoc_sync_fn(_a: (), _b: ()) -> () {}
fn assoc_sync_method(&self, _a: (), _b: ()) -> () {}
}

fn is_bounded<T: Send + Sync>(_: T) {}

#[test]
fn test() {
fn inner<T: Trait>(x: T) {
let (a, b) = ((), ());
is_bounded(<T as Trait>::assoc_async_fn_no_ret(a, b));
is_bounded(<T as Trait>::assoc_async_method_no_ret(&x, a, b));
is_bounded(<T as Trait>::assoc_async_fn(a, b));
is_bounded(<T as Trait>::assoc_async_method(&x, a, b));
// FIXME: See #17.
//is_bounded(<T as Trait>::dft_assoc_async_fn_no_ret(a, b));
//is_bounded(<T as Trait>::dft_assoc_async_method_no_ret(&x, a, b));
//is_bounded(<T as Trait>::dft_assoc_async_fn(a, b));
//is_bounded(<T as Trait>::dft_assoc_async_method(&x, a, b));
}
inner(());
}
65 changes: 65 additions & 0 deletions trait-variant/tests/name-colon-bounds.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright (c) 2023 Google LLC
// Copyright (c) 2023 Various contributors (see git history)
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

#[trait_variant::make(Trait: Send + Sync)]
pub trait LocalTrait {
const CONST: &'static ();
type Gat<'a>
where
Self: 'a;
async fn assoc_async_fn_no_ret(a: (), b: ());
async fn assoc_async_method_no_ret(&self, a: (), b: ());
async fn assoc_async_fn(a: (), b: ()) -> ();
async fn assoc_async_method(&self, a: (), b: ()) -> ();
fn assoc_sync_fn_no_ret(a: (), b: ());
fn assoc_sync_method_no_ret(&self, a: (), b: ());
fn assoc_sync_fn(a: (), b: ()) -> ();
fn assoc_sync_method(&self, a: (), b: ()) -> ();
// FIXME: See #17.
//async fn dft_assoc_async_fn_no_ret(_a: (), _b: ()) {}
//async fn dft_assoc_async_method_no_ret(&self, _a: (), _b: ()) {}
//async fn dft_assoc_async_fn(_a: (), _b: ()) -> () {}
//async fn dft_assoc_async_method(&self, _a: (), _b: ()) -> () {}
fn dft_assoc_sync_fn_no_ret(_a: (), _b: ()) {}
fn dft_assoc_sync_method_no_ret(&self, _a: (), _b: ()) {}
fn dft_assoc_sync_fn(_a: (), _b: ()) -> () {}
fn dft_assoc_sync_method(&self, _a: (), _b: ()) -> () {}
}

impl Trait for () {
const CONST: &'static () = &();
type Gat<'a> = ();
async fn assoc_async_fn_no_ret(_a: (), _b: ()) {}
async fn assoc_async_method_no_ret(&self, _a: (), _b: ()) {}
async fn assoc_async_fn(_a: (), _b: ()) -> () {}
async fn assoc_async_method(&self, _a: (), _b: ()) -> () {}
fn assoc_sync_fn_no_ret(_a: (), _b: ()) {}
fn assoc_sync_method_no_ret(&self, _a: (), _b: ()) {}
fn assoc_sync_fn(_a: (), _b: ()) -> () {}
fn assoc_sync_method(&self, _a: (), _b: ()) -> () {}
}

fn is_bounded<T: Send + Sync>(_: T) {}

#[test]
fn test() {
fn inner<T: Trait>(x: T) {
let (a, b) = ((), ());
is_bounded(<T as Trait>::assoc_async_fn_no_ret(a, b));
is_bounded(<T as Trait>::assoc_async_method_no_ret(&x, a, b));
is_bounded(<T as Trait>::assoc_async_fn(a, b));
is_bounded(<T as Trait>::assoc_async_method(&x, a, b));
// FIXME: See #17.
//is_bounded(<T as Trait>::dft_assoc_async_fn_no_ret(a, b));
//is_bounded(<T as Trait>::dft_assoc_async_method_no_ret(&x, a, b));
//is_bounded(<T as Trait>::dft_assoc_async_fn(a, b));
//is_bounded(<T as Trait>::dft_assoc_async_method(&x, a, b));
}
inner(());
}
Loading
Loading