diff --git a/crates/bevy_ecs/src/schedule/mod.rs b/crates/bevy_ecs/src/schedule/mod.rs index 5532c0f4d5c14..76ecfd4e091c0 100644 --- a/crates/bevy_ecs/src/schedule/mod.rs +++ b/crates/bevy_ecs/src/schedule/mod.rs @@ -2,6 +2,7 @@ mod executor; mod executor_parallel; mod stage; mod state; +mod state_set; mod system_container; mod system_descriptor; mod system_set; @@ -10,6 +11,7 @@ pub use executor::*; pub use executor_parallel::*; pub use stage::*; pub use state::*; +pub use state_set::*; pub use system_container::*; pub use system_descriptor::*; pub use system_set::*; diff --git a/crates/bevy_ecs/src/schedule/state_set.rs b/crates/bevy_ecs/src/schedule/state_set.rs new file mode 100644 index 0000000000000..2c72a1aa4ccf2 --- /dev/null +++ b/crates/bevy_ecs/src/schedule/state_set.rs @@ -0,0 +1,372 @@ +use std::{ + any::TypeId, + marker::PhantomData, + mem::{discriminant, Discriminant}, +}; + +use bevy_utils::HashMap; + +use crate::{ + ArchetypeComponent, IntoSystem, ResMut, Resource, ShouldRun, System, SystemDescriptor, + SystemId, SystemSet, SystemStage, TypeAccess, +}; +#[derive(Debug)] +pub struct SetState { + transition: Option>, + stack: Vec, + scheduled: Option>, +} + +#[derive(Debug)] +enum StateTransition { + ExitingToResume(T, T), + ExitingFull(T, T), + Entering(T, T), + Resuming(T, T), + Pausing(T, T), +} + +#[derive(Debug)] +pub enum ScheduledOperation { + Next(T), + Pop, + Push(T), +} + +impl SetState { + fn on_update(d: Discriminant) -> impl System { + Wrapper::::new(d) + } + + fn on_enter(d: Discriminant) -> impl System { + Wrapper::::new(d) + } + + fn on_exit(d: Discriminant) -> impl System { + Wrapper::::new(d) + } + + fn on_pause(d: Discriminant) -> impl System { + Wrapper::::new(d) + } + + fn on_resume(d: Discriminant) -> impl System { + Wrapper::::new(d) + } + + pub fn schedule_operation( + &mut self, + val: ScheduledOperation, + ) -> Option> { + self.scheduled.replace(val) + } + + pub fn new(val: T) -> Self { + Self { + stack: vec![val], + transition: None, + scheduled: None, + } + } + + pub fn current(&self) -> &T { + self.stack.last().unwrap() + } +} + +trait Comparer { + fn compare(d: Discriminant, s: &SetState) -> bool; +} + +struct OnUpdate; +impl Comparer for OnUpdate { + fn compare(d: Discriminant, s: &SetState) -> bool { + discriminant(s.stack.last().unwrap()) == d && s.transition.is_none() + } +} +struct OnEnter; +impl Comparer for OnEnter { + fn compare(d: Discriminant, s: &SetState) -> bool { + s.transition + .as_ref() + .map_or(false, |transition| match transition { + StateTransition::Entering(_, entering) => discriminant(entering) == d, + _ => false, + }) + } +} +struct OnExit; +impl Comparer for OnExit { + fn compare(d: Discriminant, s: &SetState) -> bool { + s.transition + .as_ref() + .map_or(false, |transition| match transition { + StateTransition::ExitingToResume(exiting, _) + | StateTransition::ExitingFull(exiting, _) => discriminant(exiting) == d, + _ => false, + }) + } +} +struct OnPause; +impl Comparer for OnPause { + fn compare(d: Discriminant, s: &SetState) -> bool { + s.transition + .as_ref() + .map_or(false, |transition| match transition { + StateTransition::Pausing(pausing, _) => discriminant(pausing) == d, + _ => false, + }) + } +} +struct OnResume; +impl Comparer for OnResume { + fn compare(d: Discriminant, s: &SetState) -> bool { + s.transition + .as_ref() + .map_or(false, |transition| match transition { + StateTransition::Resuming(_, resuming) => discriminant(resuming) == d, + _ => false, + }) + } +} + +impl> Wrapper { + fn new(discriminant: Discriminant) -> Self { + let mut resource_access = TypeAccess::default(); + resource_access.add_read(std::any::TypeId::of::>()); + Self { + discriminant, + exit_flag: false, + resource_access, + id: SystemId::new(), + archetype_access: Default::default(), + component_access: Default::default(), + marker: Default::default(), + } + } +} + +struct Wrapper> { + discriminant: Discriminant, + exit_flag: bool, + resource_access: TypeAccess, + id: SystemId, + archetype_access: TypeAccess, + component_access: TypeAccess, + marker: PhantomData, +} + +impl + Resource> System for Wrapper { + type In = (); + type Out = ShouldRun; + + fn name(&self) -> std::borrow::Cow<'static, str> { + std::borrow::Cow::Owned(format!( + "State checker for state {}", + std::any::type_name::() + )) + } + + fn id(&self) -> crate::SystemId { + self.id + } + + fn archetype_component_access(&self) -> &TypeAccess { + &self.archetype_access + } + + fn resource_access(&self) -> &TypeAccess { + &self.resource_access + } + + fn component_access(&self) -> &TypeAccess { + &self.component_access + } + + fn is_non_send(&self) -> bool { + false + } + + unsafe fn run_unsafe( + &mut self, + _input: Self::In, + _world: &crate::World, + resources: &crate::Resources, + ) -> Option { + let state = &*resources.get::>().unwrap(); + if state.transition.is_some() { + self.exit_flag = false; + } + if self.exit_flag { + self.exit_flag = false; + Some(ShouldRun::No) + } else { + self.exit_flag = true; + Some(if C::compare(self.discriminant, state) { + ShouldRun::YesAndCheckAgain + } else { + ShouldRun::NoAndCheckAgain + }) + } + } + + fn update_access(&mut self, _world: &crate::World) {} + + fn apply_buffers(&mut self, _world: &mut crate::World, _resources: &mut crate::Resources) {} + + fn initialize(&mut self, _world: &mut crate::World, _resources: &mut crate::Resources) {} +} + +pub struct StateSetBuilder { + on_update: HashMap, SystemSet>, + on_enter: HashMap, SystemSet>, + on_exit: HashMap, SystemSet>, + on_pause: HashMap, SystemSet>, + on_resume: HashMap, SystemSet>, +} + +impl Default for StateSetBuilder { + fn default() -> Self { + Self { + on_update: Default::default(), + on_enter: Default::default(), + on_exit: Default::default(), + on_pause: Default::default(), + on_resume: Default::default(), + } + } +} + +impl StateSetBuilder { + pub fn add_on_update(&mut self, v: T, system: impl Into) -> &mut Self { + self.on_update + .entry(discriminant(&v)) + .or_default() + .add_system(system); + self + } + + pub fn add_on_enter(&mut self, v: T, system: impl Into) -> &mut Self { + self.on_enter + .entry(discriminant(&v)) + .or_default() + .add_system(system); + self + } + + pub fn add_on_exit(&mut self, v: T, system: impl Into) -> &mut Self { + self.on_exit + .entry(discriminant(&v)) + .or_default() + .add_system(system); + self + } + + pub fn add_on_pause(&mut self, v: T, system: impl Into) -> &mut Self { + self.on_pause + .entry(discriminant(&v)) + .or_default() + .add_system(system); + self + } + + pub fn add_on_resume(&mut self, v: T, system: impl Into) -> &mut Self { + self.on_resume + .entry(discriminant(&v)) + .or_default() + .add_system(system); + self + } + + pub fn with_on_update(mut self, v: T, system: impl Into) -> Self { + self.add_on_update(v, system); + self + } + + pub fn with_on_enter(mut self, v: T, system: impl Into) -> Self { + self.add_on_enter(v, system); + self + } + + pub fn with_on_exit(mut self, v: T, system: impl Into) -> Self { + self.add_on_exit(v, system); + self + } + + pub fn with_on_pause(mut self, v: T, system: impl Into) -> Self { + self.add_on_pause(v, system); + self + } + + pub fn with_on_resume(mut self, v: T, system: impl Into) -> Self { + self.add_on_resume(v, system); + self + } + + pub fn finalize(self, stage: &mut SystemStage) { + fn state_cleaner(mut state: ResMut>) -> ShouldRun { + match state.scheduled.take() { + Some(ScheduledOperation::Next(next)) => { + if state.stack.len() == 1 { + let previous = + std::mem::replace(state.stack.last_mut().unwrap(), next.clone()); + state.transition = Some(StateTransition::ExitingFull(previous, next)); + } else { + state.scheduled = Some(ScheduledOperation::Next(next)); + match state.transition.take() { + Some(StateTransition::ExitingToResume(p, n)) => { + state.transition = Some(StateTransition::Resuming(p, n)); + } + _ => { + state.transition = Some(StateTransition::ExitingToResume( + state.stack.pop().unwrap(), + state.stack.last().unwrap().clone(), + )); + } + } + } + } + Some(ScheduledOperation::Push(next)) => { + let last = state.stack.last().unwrap().clone(); + state.stack.push(next.clone()); + state.transition = Some(StateTransition::Pausing(last, next)); + } + Some(ScheduledOperation::Pop) => { + state.transition = Some(StateTransition::ExitingToResume( + state.stack.pop().unwrap(), + state.stack.last().unwrap().clone(), + )); + } + None => match state.transition.take() { + Some(StateTransition::ExitingFull(p, n)) + | Some(StateTransition::Pausing(p, n)) => { + state.transition = Some(StateTransition::Entering(p, n)); + } + Some(StateTransition::ExitingToResume(p, n)) => { + state.transition = Some(StateTransition::Resuming(p, n)); + } + _ => return ShouldRun::Yes, + }, + }; + ShouldRun::YesAndCheckAgain + } + + for (val, set) in self.on_enter.into_iter() { + stage.add_system_set(set.with_run_criteria(SetState::::on_enter(val))); + } + for (val, set) in self.on_update.into_iter() { + stage.add_system_set(set.with_run_criteria(SetState::::on_update(val))); + } + for (val, set) in self.on_exit.into_iter() { + stage.add_system_set(set.with_run_criteria(SetState::::on_exit(val))); + } + for (val, set) in self.on_pause.into_iter() { + stage.add_system_set(set.with_run_criteria(SetState::::on_pause(val))); + } + for (val, set) in self.on_resume.into_iter() { + stage.add_system_set(set.with_run_criteria(SetState::::on_resume(val))); + } + + stage.add_system_set(SystemSet::default().with_run_criteria(state_cleaner::.system())); + } +}