diff --git a/crates/bevy_ecs/src/query/iter.rs b/crates/bevy_ecs/src/query/iter.rs index b88b3030e58d14..026599062558a3 100644 --- a/crates/bevy_ecs/src/query/iter.rs +++ b/crates/bevy_ecs/src/query/iter.rs @@ -211,34 +211,144 @@ where { } - if !self - .query_state - .matched_archetypes - .contains(location.archetype_id.index()) - { - continue; - } +/// +/// +/// This struct is created by the [`Query::iter_join_map`](crate::system::Query::iter_join_map) and [`Query::iter_join_map_mut`](crate::system::Query::iter_join_map_mut) methods. +pub struct QueryJoinMapIter<'w, 's, Q: WorldQuery, QQ, F: WorldQuery, I: Iterator, MapFn> +where + QQ: WorldQuery, + MapFn: FnMut(&I::Item) -> Entity, +{ + list: I, + map_f: MapFn, + entities: &'w Entities, + tables: &'w Tables, + archetypes: &'w Archetypes, + fetch: QueryFetch<'w, QQ>, + filter: QueryFetch<'w, F>, + query_state: &'s QueryState, +} - let archetype = &self.archetypes[location.archetype_id]; +impl<'w, 's, Q: WorldQuery, QQ, F: WorldQuery, I: Iterator, MapFn> + QueryJoinMapIter<'w, 's, Q, QQ, F, I, MapFn> +where + QQ: WorldQuery, + MapFn: FnMut(&I::Item) -> Entity, +{ + /// # Safety + /// This does not check for mutable query correctness. To be safe, make sure mutable queries + /// have unique access to the components they query. + /// This does not validate that `world.id()` matches `query_state.world_id`. Calling this on a `world` + /// with a mismatched [`WorldId`](crate::world::WorldId) is unsound. + pub(crate) unsafe fn new>( + world: &'w World, + query_state: &'s QueryState, + last_change_tick: u32, + change_tick: u32, + entity_map: II, + map_f: MapFn, + ) -> QueryJoinMapIter<'w, 's, Q, QQ, F, I, MapFn> { + let fetch = QueryFetch::::init( + world, + &query_state.fetch_state, + last_change_tick, + change_tick, + ); + let filter = QueryFetch::::init( + world, + &query_state.filter_state, + last_change_tick, + change_tick, + ); + QueryJoinMapIter { + query_state, + entities: &world.entities, + archetypes: &world.archetypes, + tables: &world.storages.tables, + fetch, + filter, + list: entity_map.into_iter(), + map_f, + } + } - self.fetch - .set_archetype(&self.query_state.fetch_state, archetype, self.tables); - self.filter - .set_archetype(&self.query_state.filter_state, archetype, self.tables); - if self.filter.archetype_filter_fetch(location.index) { - return Some(self.fetch.archetype_fetch(location.index)); - } + /// SAFETY: + /// The lifetime here is not restrictive enough for Fetch with &mut access, + /// as calling `fetch_next_aliased_unchecked` multiple times can produce multiple + /// references to the same component, leading to unique reference aliasing. + /// + /// It is always safe for immutable borrows. + #[inline(always)] + unsafe fn fetch_next_aliased_unchecked(&mut self) -> Option<(QueryItem<'w, QQ>, I::Item)> { + for item in self.list.by_ref() { + let location = match self.entities.get((self.map_f)(&item)) { + Some(location) => location, + None => continue, + }; + + if !self + .query_state + .matched_archetypes + .contains(location.archetype_id.index()) + { + continue; } - None + + let archetype = &self.archetypes[location.archetype_id]; + + self.fetch + .set_archetype(&self.query_state.fetch_state, archetype, self.tables); + self.filter + .set_archetype(&self.query_state.filter_state, archetype, self.tables); + if self.filter.archetype_filter_fetch(location.index) { + return Some((self.fetch.archetype_fetch(location.index), item)); + } + } + None + } + + /// Get next item from the inner join + #[inline(always)] + pub fn fetch_next(&mut self) -> Option<(QueryItem<'_, QQ>, I::Item)> { + // safety: we are limiting the returned reference to self, + // making sure this method cannot be called multiple times without getting rid + // of any previously returned unique references first, thus preventing aliasing. + unsafe { + self.fetch_next_aliased_unchecked() + .map(|(q_item, item)| (QQ::shrink(q_item), item)) } } +} + +impl<'w, 's, Q: WorldQuery, QQ, F: WorldQuery, I: Iterator, MapFn> Iterator + for QueryJoinMapIter<'w, 'w, Q, QQ, F, I, MapFn> +where + QQ: ReadOnlyWorldQuery, + MapFn: FnMut(&I::Item) -> Entity, +{ + type Item = (QueryItem<'w, QQ>, I::Item); + + #[inline(always)] + fn next(&mut self) -> Option { + // SAFETY: it is safe to alias for ReadOnlyWorldQuery + unsafe { self.fetch_next_aliased_unchecked() } + } fn size_hint(&self) -> (usize, Option) { - let (_, max_size) = self.entity_iter.size_hint(); + let (_, max_size) = self.list.size_hint(); (0, max_size) } } +// This is correct as [`QueryJoinMapIter`] always returns `None` once exhausted. +impl<'w, 's, Q: WorldQuery, QQ, F: WorldQuery, I: Iterator, MapFn: FnMut(&I::Item) -> Entity> + FusedIterator for QueryJoinMapIter<'w, 'w, Q, QQ, F, I, MapFn> +where + QQ: ReadOnlyWorldQuery, + MapFn: FnMut(&I::Item) -> Entity, +{ +} + pub struct QueryCombinationIter<'w, 's, Q: WorldQuery, QQ, F: WorldQuery, const K: usize> where QQ: WorldQuery, diff --git a/crates/bevy_ecs/src/query/state.rs b/crates/bevy_ecs/src/query/state.rs index f2d99153de4013..1072dd7ae01128 100644 --- a/crates/bevy_ecs/src/query/state.rs +++ b/crates/bevy_ecs/src/query/state.rs @@ -558,6 +558,48 @@ impl QueryState { } } + /// Returns an [`Iterator`] over the inner join of the results of a query and list of items mapped to [`Entity`]'s. + /// + /// This can only be called for read-only queries, see [`Self::iter_list_mut`] for write-queries. + #[inline] + pub fn iter_join_map<'w, 's, I: IntoIterator, MapFn: FnMut(&I::Item) -> Entity>( + &'s mut self, + world: &'w World, + list: I, + map_f: MapFn, + ) -> QueryJoinMapIter<'w, 's, Q, Q::ReadOnly, F, I::IntoIter, MapFn> { + self.update_archetypes(world); + unsafe { + self.iter_join_map_unchecked_manual( + world, + world.last_change_tick(), + world.read_change_tick(), + list, + map_f, + ) + } + } + + /// Returns an iterator over the inner join of the results of a query and list of items mapped to [`Entity`]'s. + #[inline] + pub fn iter_join_map_mut<'w, 's, I: IntoIterator, MapFn: FnMut(&I::Item) -> Entity>( + &'s mut self, + world: &'w mut World, + list: I, + map_f: MapFn, + ) -> QueryJoinMapIter<'w, 's, Q, Q, F, I::IntoIter, MapFn> { + self.update_archetypes(world); + unsafe { + self.iter_join_map_unchecked_manual( + world, + world.last_change_tick(), + world.read_change_tick(), + list, + map_f, + ) + } + } + /// Returns an [`Iterator`] over all possible combinations of `K` query results without repetition. /// This can only be called for read-only queries. /// @@ -694,6 +736,30 @@ impl QueryState { QueryEntityListIter::new(world, self, last_change_tick, change_tick, entities) } + /// Returns an [`Iterator`] over each item in an inner join on [`Entity`] between the query and a list of items which are mapped to [`Entity`]'s. + /// + /// # Safety + /// + /// This does not check for mutable query correctness. To be safe, make sure mutable queries + /// have unique access to the components they query. + /// This does not validate that `world.id()` matches `self.world_id`. Calling this on a `world` + /// with a mismatched [`WorldId`] is unsound. + #[inline] + pub(crate) unsafe fn iter_join_map_unchecked_manual< + 'w, + 's, + I: IntoIterator, + MapFn: FnMut(&I::Item) -> Entity, + QQ: WorldQuery, + >( + &'s self, + world: &'w World, + last_change_tick: u32, + change_tick: u32, + list: I, + map_f: MapFn, + ) -> QueryJoinMapIter<'w, 's, Q, QQ, F, I::IntoIter, MapFn> { + QueryJoinMapIter::new(world, self, last_change_tick, change_tick, list, map_f) } /// Returns an [`Iterator`] over all possible combinations of `K` query results for the diff --git a/crates/bevy_ecs/src/system/mod.rs b/crates/bevy_ecs/src/system/mod.rs index 276f53a47dac89..0698ee910ab83f 100644 --- a/crates/bevy_ecs/src/system/mod.rs +++ b/crates/bevy_ecs/src/system/mod.rs @@ -106,6 +106,7 @@ mod tests { bundle::Bundles, component::{Component, Components}, entity::{Entities, Entity}, + event::{EventReader, Events}, prelude::AnyOf, query::{Added, Changed, Or, With, Without}, schedule::{Schedule, Stage, SystemStage}, @@ -1045,4 +1046,55 @@ mod tests { } run_system(&mut world, check_2); } + + #[test] + fn iter_join_map() { + struct DamageEvent { + target: Entity, + damage: f32, + } + + #[derive(Component, PartialEq, Debug)] + struct Health(f32); + + let mut world = World::new(); + + let e1 = world.spawn().insert(Health(200.)).id(); + let e2 = world.spawn().insert(Health(100.)).id(); + + let mut events = Events::::default(); + + events.extend([ + DamageEvent { + target: e1, + damage: 50., + }, + DamageEvent { + target: e2, + damage: 80., + }, + DamageEvent { + target: e1, + damage: 150., + }, + ]); + + world.insert_resource(events); + + fn process_damage_events( + mut health_query: Query<&mut Health>, + mut event_reader: EventReader, + ) { + let mut join = + health_query.iter_join_map_mut(event_reader.iter(), |event| event.target); + while let Some((mut health, event)) = join.fetch_next() { + health.0 -= event.damage; + } + } + run_system(&mut world, process_damage_events); + + run_system(&mut world, move |health_query: Query<&Health>| { + assert_eq!([&Health(0.), &Health(20.)], health_query.many([e1, e2])); + }); + } } diff --git a/crates/bevy_ecs/src/system/query.rs b/crates/bevy_ecs/src/system/query.rs index a1f8740d3c36b3..e809f749d96d8c 100644 --- a/crates/bevy_ecs/src/system/query.rs +++ b/crates/bevy_ecs/src/system/query.rs @@ -3,8 +3,8 @@ use crate::{ entity::Entity, query::{ NopFetch, QueryCombinationIter, QueryEntityError, QueryEntityListIter, QueryFetch, - QueryItem, QueryIter, QuerySingleError, QueryState, ROQueryFetch, ROQueryItem, - ReadOnlyWorldQuery, WorldQuery, + QueryItem, QueryIter, QueryJoinMapIter, QuerySingleError, QueryState, ROQueryFetch, + ROQueryItem, ReadOnlyWorldQuery, WorldQuery, }, world::{Mut, World}, }; @@ -487,9 +487,93 @@ impl<'w, 's, Q: WorldQuery, F: WorldQuery> Query<'w, 's, Q, F> { } } + /// Returns an [`Iterator`] over the inner join of the results of a [`Query`] and list of items mapped to [`Entity`]'s. + /// + /// # Example + /// ``` + /// # use bevy_ecs::prelude::*; + /// # #[derive(Component)] + /// # struct Health { + /// # value: f32 + /// # } + /// #[derive(Component)] + /// struct DamageEvent { + /// target: Entity, + /// damage: f32, + /// } + /// + /// fn system( + /// mut damage_events: EventReader, + /// health_query: Query<&Health>, + /// ) { + /// for (health, event) in + /// health_query.iter_join_map(damage_events.iter(), |event| event.target) + /// { + /// println!("Entity has {} health and will take {} damage!", health.value, event.damage); + /// } + /// } + /// # bevy_ecs::system::assert_is_system(system); + /// ``` + #[inline] + pub fn iter_join_map Entity>( + &self, + list: I, + map_f: MapFn, + ) -> QueryJoinMapIter<'w, 's, Q, Q::ReadOnly, F, I::IntoIter, MapFn> { + // SAFETY: system runs without conflicts with other systems. + // same-system queries have runtime borrow checks when they conflict + unsafe { + self.state.iter_join_map_unchecked_manual( + self.world, + self.last_change_tick, + self.change_tick, + list, + map_f, + ) + } + } + + /// Returns an iterator over the inner join of the results of a [`Query`] and list of items mapped to [`Entity`]'s. + /// + /// # Example + /// ``` + /// # use bevy_ecs::prelude::*; + /// # #[derive(Component)] + /// # struct Health { + /// # value: f32 + /// # } + /// #[derive(Component)] + /// struct DamageEvent { + /// target: Entity, + /// damage: f32, + /// } + /// + /// fn system( + /// mut damage_events: EventReader, + /// mut health_query: Query<&mut Health>, + /// ) { + /// let mut join = health_query.iter_join_map_mut(damage_events.iter(), |event| event.target); + /// while let Some((mut health, event)) = join.fetch_next() { + /// health.value -= event.damage; + /// } + /// } + /// # bevy_ecs::system::assert_is_system(system); + /// ``` + #[inline] + pub fn iter_join_map_mut Entity>( + &mut self, + list: I, + map_f: MapFn, + ) -> QueryJoinMapIter<'_, '_, Q, Q, F, I::IntoIter, MapFn> { + // SAFETY: system runs without conflicts with other systems. + // same-system queries have runtime borrow checks when they conflict + unsafe { + self.state.iter_join_map_unchecked_manual( self.world, self.last_change_tick, self.change_tick, + list, + map_f, ) } } diff --git a/crates/bevy_ecs_compile_fail_tests/tests/ui/system_query_iter_join_map_mut_lifetime_safety.rs b/crates/bevy_ecs_compile_fail_tests/tests/ui/system_query_iter_join_map_mut_lifetime_safety.rs new file mode 100644 index 00000000000000..5b5e665fa8daab --- /dev/null +++ b/crates/bevy_ecs_compile_fail_tests/tests/ui/system_query_iter_join_map_mut_lifetime_safety.rs @@ -0,0 +1,15 @@ +use bevy_ecs::prelude::*; + +#[derive(Component)] +struct A(usize); + +fn system(mut query: Query<&mut A>, e: Res) { + let mut results = Vec::new(); + let mut iter = query.iter_join_map_mut([*e, *e], |e| *e); + while let Some(a) = iter.fetch_next() { + // this should fail to compile + results.push(a); + } +} + +fn main() {} diff --git a/crates/bevy_ecs_compile_fail_tests/tests/ui/system_query_iter_join_map_mut_lifetime_safety.stderr b/crates/bevy_ecs_compile_fail_tests/tests/ui/system_query_iter_join_map_mut_lifetime_safety.stderr new file mode 100644 index 00000000000000..0238cf418a2ee1 --- /dev/null +++ b/crates/bevy_ecs_compile_fail_tests/tests/ui/system_query_iter_join_map_mut_lifetime_safety.stderr @@ -0,0 +1,5 @@ +error[E0499]: cannot borrow `iter` as mutable more than once at a time + --> tests/ui/system_query_iter_join_map_mut_lifetime_safety.rs:9:25 + | +9 | while let Some(a) = iter.fetch_next() { + | ^^^^^^^^^^^^^^^^^ `iter` was mutably borrowed here in the previous iteration of the loop