Skip to content

Commit

Permalink
Implement #[derive(Query)] for enum types
Browse files Browse the repository at this point in the history
  • Loading branch information
Moulberry authored Nov 20, 2024
1 parent d7676a3 commit 96f8289
Show file tree
Hide file tree
Showing 14 changed files with 419 additions and 35 deletions.
7 changes: 5 additions & 2 deletions macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,15 @@ pub fn derive_dynamic_bundle_clone(input: TokenStream) -> TokenStream {
.into()
}

/// Implement `Query` for a struct
/// Implement `Query` for a struct or enum.
///
/// Queries structs can be passed to the type parameter of `World::query`. They must have exactly
/// Queries can be passed to the type parameter of `World::query`. They must have exactly
/// one lifetime parameter, and all of their fields must be queries (e.g. references) using that
/// lifetime.
///
/// For enum queries, the result will always be the first variant that matches the entity.
/// Unit variants and variants without any fields will always match an entity.
///
/// # Example
/// ```
/// # use hecs::*;
Expand Down
294 changes: 283 additions & 11 deletions macros/src/query.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
use proc_macro2::Span;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{DeriveInput, Error, Ident, Lifetime, Result, Type};
use syn::{DataEnum, DataStruct, DeriveInput, Error, Ident, Lifetime, Result, Type, Visibility};

pub fn derive(input: DeriveInput) -> Result<TokenStream2> {
let ident = input.ident;
let vis = input.vis;
let data = match input.data {
syn::Data::Struct(s) => s,

match input.data {
syn::Data::Struct(_) | syn::Data::Enum(_) => {}
_ => {
return Err(Error::new_spanned(
ident,
"derive(Query) may only be applied to structs",
"derive(Query) may only be applied to structs and enums",
))
}
};
}

let vis = input.vis;
let lifetime = input
.generics
.lifetimes()
Expand All @@ -36,6 +38,19 @@ pub fn derive(input: DeriveInput) -> Result<TokenStream2> {
));
}

match input.data {
syn::Data::Struct(data_struct) => derive_struct(ident, vis, data_struct, lifetime),
syn::Data::Enum(data_enum) => derive_enum(ident, vis, data_enum, lifetime),
_ => unreachable!(),
}
}

fn derive_struct(
ident: Ident,
vis: Visibility,
data: DataStruct,
lifetime: Lifetime,
) -> Result<TokenStream2> {
let (fields, queries) = match data.fields {
syn::Fields::Named(ref fields) => fields
.named
Expand All @@ -55,7 +70,7 @@ pub fn derive(input: DeriveInput) -> Result<TokenStream2> {
(
syn::Member::Unnamed(syn::Index {
index: i as u32,
span: Span::call_site(),
span: Span::mixed_site(),
}),
query_ty(&lifetime, &f.ty),
)
Expand All @@ -67,7 +82,7 @@ pub fn derive(input: DeriveInput) -> Result<TokenStream2> {
.iter()
.map(|ty| quote! { <#ty as ::hecs::Query>::Fetch })
.collect::<Vec<_>>();
let fetch_ident = Ident::new(&format!("{}Fetch", ident), Span::call_site());
let fetch_ident = Ident::new(&format!("{}Fetch", ident), Span::mixed_site());
let fetch = match data.fields {
syn::Fields::Named(_) => quote! {
#vis struct #fetch_ident {
Expand All @@ -83,7 +98,7 @@ pub fn derive(input: DeriveInput) -> Result<TokenStream2> {
#vis struct #fetch_ident;
},
};
let state_ident = Ident::new(&format!("{}State", ident), Span::call_site());
let state_ident = Ident::new(&format!("{}State", ident), Span::mixed_site());
let state = match data.fields {
syn::Fields::Named(_) => quote! {
#[derive(Clone, Copy)]
Expand All @@ -108,7 +123,7 @@ pub fn derive(input: DeriveInput) -> Result<TokenStream2> {
.map(|x| match x {
syn::Member::Named(ref ident) => ident.clone(),
syn::Member::Unnamed(ref index) => {
Ident::new(&format!("field_{}", index.index), Span::call_site())
Ident::new(&format!("field_{}", index.index), Span::mixed_site())
}
})
.collect::<Vec<_>>();
Expand Down Expand Up @@ -193,14 +208,271 @@ pub fn derive(input: DeriveInput) -> Result<TokenStream2> {
})
}

fn derive_enum(
enum_ident: Ident,
vis: Visibility,
data: DataEnum,
lifetime: Lifetime,
) -> Result<TokenStream2> {
let mut dangling_constructor = None;
let mut fetch_variants = TokenStream2::new();
let mut state_variants = TokenStream2::new();
let mut query_get_variants = TokenStream2::new();
let mut fetch_access_variants = TokenStream2::new();
let mut fetch_borrow_variants = TokenStream2::new();
let mut fetch_prepare_variants = TokenStream2::new();
let mut fetch_execute_variants = TokenStream2::new();
let mut fetch_release_variants = TokenStream2::new();
let mut fetch_for_each_borrow = TokenStream2::new();

for variant in &data.variants {
let (fields, queries) = match variant.fields {
syn::Fields::Named(ref fields) => fields
.named
.iter()
.map(|f| {
(
syn::Member::Named(f.ident.clone().unwrap()),
query_ty(&lifetime, &f.ty),
)
})
.unzip(),
syn::Fields::Unnamed(ref fields) => fields
.unnamed
.iter()
.enumerate()
.map(|(i, f)| {
(
syn::Member::Unnamed(syn::Index {
index: i as u32,
span: Span::mixed_site(),
}),
query_ty(&lifetime, &f.ty),
)
})
.unzip(),
syn::Fields::Unit => (Vec::new(), Vec::new()),
};

let ident = variant.ident.clone();

if ident == "__HecsDanglingFetch__" {
return Err(Error::new_spanned(
ident,
"derive(Query) reserves this identifier for internal use",
));
}

let named_fields = fields
.iter()
.map(|x| match x {
syn::Member::Named(ref ident) => ident.clone(),
syn::Member::Unnamed(ref index) => {
Ident::new(&format!("field_{}", index.index), Span::mixed_site())
}
})
.collect::<Vec<_>>();

let fetches = queries
.iter()
.map(|ty| quote! { <#ty as ::hecs::Query>::Fetch })
.collect::<Vec<_>>();

if dangling_constructor.is_none() && fields.is_empty() {
dangling_constructor = Some(quote! {
Self::#ident {}
});
}

fetch_variants.extend(quote! {
#ident {
#(
#named_fields: #fetches,
)*
},
});

state_variants.extend(quote! {
#ident {
#(
#named_fields: <#fetches as ::hecs::Fetch>::State,
)*
},
});

query_get_variants.extend(quote! {
Self::Fetch::#ident { #(#named_fields),* } => {
#(
let #named_fields: <#queries as ::hecs::Query>::Item<'q> = <#queries as ::hecs::Query>::get(#named_fields, n);
)*
Self::Item::#ident { #( #fields: #named_fields,)* }
},
});

fetch_access_variants.extend(quote! {
'block: {
let mut access = ::hecs::Access::Iterate;
#(
if let ::core::option::Option::Some(new_access) = #fetches::access(archetype) {
access = ::core::cmp::max(access, new_access);
} else {
break 'block;
}
)*
return ::core::option::Option::Some(access)
}
});

fetch_borrow_variants.extend(quote! {
Self::State::#ident { #(#named_fields),* } => {
#(
#fetches::borrow(archetype, #named_fields);
)*
},
});

fetch_prepare_variants.extend(quote! {
'block: {
#(
let ::core::option::Option::Some(#named_fields) = #fetches::prepare(archetype) else {
break 'block;
};
)*
return ::core::option::Option::Some(Self::State::#ident { #(#named_fields,)* });
}
});

fetch_execute_variants.extend(quote! {
Self::State::#ident { #(#named_fields),* } => {
return Self::#ident {
#(
#named_fields: #fetches::execute(archetype, #named_fields),
)*
};
},
});

fetch_release_variants.extend(quote! {
Self::State::#ident { #(#named_fields),* } => {
#(
#fetches::release(archetype, #named_fields);
)*
},
});

fetch_for_each_borrow.extend(quote! {
#(
<#fetches as ::hecs::Fetch>::for_each_borrow(&mut f);
)*
});
}

let dangling_constructor = if let Some(dangling_constructor) = dangling_constructor {
dangling_constructor
} else {
fetch_variants.extend(quote! {
__HecsDanglingFetch__,
});
query_get_variants.extend(quote! {
Self::Fetch::__HecsDanglingFetch__ => panic!("Called get() with dangling fetch"),
});
quote! {
Self::__HecsDanglingFetch__
}
};

let fetch_ident = Ident::new(&format!("{}Fetch", enum_ident), Span::mixed_site());
let fetch = quote! {
#vis enum #fetch_ident {
#fetch_variants
}
};

let state_ident = Ident::new(&format!("{}State", enum_ident), Span::mixed_site());
let state = quote! {
#vis enum #state_ident {
#state_variants
}
};

Ok(quote! {
const _: () = {
#[derive(Clone)]
#fetch

impl<'a> ::hecs::Query for #enum_ident<'a> {
type Item<'q> = #enum_ident<'q>;

type Fetch = #fetch_ident;

#[allow(unused_variables)]
unsafe fn get<'q>(fetch: &Self::Fetch, n: usize) -> Self::Item<'q> {
match fetch {
#query_get_variants
}
}
}

#[derive(Clone, Copy)]
#state

unsafe impl ::hecs::Fetch for #fetch_ident {
type State = #state_ident;

fn dangling() -> Self {
#dangling_constructor
}

#[allow(unused_variables, unused_mut, unreachable_code)]
fn access(archetype: &::hecs::Archetype) -> ::core::option::Option<::hecs::Access> {
#fetch_access_variants
::core::option::Option::None
}

#[allow(unused_variables)]
fn borrow(archetype: &::hecs::Archetype, state: Self::State) {
match state {
#fetch_borrow_variants
}
}

#[allow(unused_variables, unreachable_code)]
fn prepare(archetype: &::hecs::Archetype) -> ::core::option::Option<Self::State> {
#fetch_prepare_variants
::core::option::Option::None
}

#[allow(unused_variables)]
fn execute(archetype: &::hecs::Archetype, state: Self::State) -> Self {
match state {
#fetch_execute_variants
}
}

#[allow(unused_variables)]
fn release(archetype: &::hecs::Archetype, state: Self::State) {
match state {
#fetch_release_variants
}
}

#[allow(unused_variables, unused_mut)]
fn for_each_borrow(mut f: impl ::core::ops::FnMut(::core::any::TypeId, bool)) {
#fetch_for_each_borrow
}
}
};
})
}

fn query_ty(lifetime: &Lifetime, ty: &Type) -> TokenStream2 {
struct Visitor<'a> {
replace: &'a Lifetime,
}
impl syn::visit_mut::VisitMut for Visitor<'_> {
fn visit_lifetime_mut(&mut self, l: &mut Lifetime) {
if l == self.replace {
*l = Lifetime::new("'static", Span::call_site());
*l = Lifetime::new("'static", Span::mixed_site());
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion tests/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
fn derive() {
const TEST_DIR: &str = "tests/derive";
let t = trybuild::TestCases::new();
let failures = &["enum.rs", "union.rs", "wrong_lifetime.rs"];
let failures = &["enum_unsupported.rs", "union.rs", "wrong_lifetime.rs"];
let successes = &[
"enum_query.rs",
"unit_structs.rs",
"tuple_structs.rs",
"named_structs.rs",
Expand Down
9 changes: 0 additions & 9 deletions tests/derive/enum.rs

This file was deleted.

Loading

0 comments on commit 96f8289

Please sign in to comment.