Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Merged by Bors] - Added method to restart the current state #3328

Closed
wants to merge 7 commits into from
Closed
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions crates/bevy_ecs/src/schedule/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ impl<T> StateData for T where T: Send + Sync + Clone + Eq + Debug + Hash + 'stat
#[derive(Debug)]
pub struct State<T: StateData> {
transition: Option<StateTransition<T>>,
/// The current states in the stack.
///
/// There is always guaranteed to be at least one.
stack: Vec<T>,
scheduled: Option<ScheduledOperation<T>>,
end_next_loop: bool,
Expand Down Expand Up @@ -369,6 +372,26 @@ where
Ok(())
}

/// Schedule a state change that restarts the active state.
/// This will fail if there is a scheduled operation
pub fn restart(&mut self) -> Result<(), StateError> {
if self.scheduled.is_some() {
return Err(StateError::StateAlreadyQueued);
}

let state = self.stack.last().unwrap();
self.scheduled = Some(ScheduledOperation::Set(state.clone()));
Ok(())
}

/// Same as [Self::restart], but if there is already a scheduled state operation,
/// it will be overwritten instead of failing
pub fn overwrite_restart(&mut self) -> Result<(), StateError> {
MrGVSV marked this conversation as resolved.
Show resolved Hide resolved
let state = self.stack.last().unwrap();
self.scheduled = Some(ScheduledOperation::Set(state.clone()));
Ok(())
}

pub fn current(&self) -> &T {
self.stack.last().unwrap()
}
Expand Down Expand Up @@ -655,4 +678,90 @@ mod test {
stage.run(&mut world);
assert!(*world.get_resource::<bool>().unwrap(), "after test");
}

#[test]
fn restart_state_tests() {
#[derive(Clone, PartialEq, Eq, Debug, Hash)]
enum LoadState {
Load,
Finish,
}

#[derive(PartialEq, Eq, Debug)]
enum LoadStatus {
EnterLoad,
ExitLoad,
EnterFinish,
}

let mut world = World::new();
world.insert_resource(Vec::<LoadStatus>::new());
world.insert_resource(State::new(LoadState::Load));

let mut stage = SystemStage::parallel();
stage.add_system_set(State::<LoadState>::get_driver());

// Systems to track loading status
stage
.add_system_set(
State::on_enter_set(LoadState::Load)
.with_system(|mut r: ResMut<Vec<LoadStatus>>| r.push(LoadStatus::EnterLoad)),
)
.add_system_set(
State::on_exit_set(LoadState::Load)
.with_system(|mut r: ResMut<Vec<LoadStatus>>| r.push(LoadStatus::ExitLoad)),
)
.add_system_set(
State::on_enter_set(LoadState::Finish)
.with_system(|mut r: ResMut<Vec<LoadStatus>>| r.push(LoadStatus::EnterFinish)),
);

stage.run(&mut world);

// A. Restart state
let mut state = world.get_resource_mut::<State<LoadState>>().unwrap();
let result = state.restart();
assert!(matches!(result, Ok(())));
stage.run(&mut world);

// B. Restart state (overwrite schedule)
let mut state = world.get_resource_mut::<State<LoadState>>().unwrap();
state.set(LoadState::Finish).unwrap();
let result = state.overwrite_restart();
assert!(matches!(result, Ok(())));
stage.run(&mut world);

// C. Fail restart state (transition already scheduled)
let mut state = world.get_resource_mut::<State<LoadState>>().unwrap();
state.set(LoadState::Finish).unwrap();
let result = state.restart();
assert!(matches!(result, Err(StateError::StateAlreadyQueued)));
stage.run(&mut world);

const EXPECTED: &[LoadStatus] = &[
LoadStatus::EnterLoad,
// A
LoadStatus::ExitLoad,
LoadStatus::EnterLoad,
// B
LoadStatus::ExitLoad,
LoadStatus::EnterLoad,
// C
LoadStatus::ExitLoad,
LoadStatus::EnterFinish,
];

let mut collected = world.get_resource_mut::<Vec<LoadStatus>>().unwrap();
let mut count = 0;
for (found, expected) in collected.drain(..).zip(EXPECTED) {
assert_eq!(found, *expected);
count += 1;
}
// If not equal, some elements weren't executed
assert_eq!(EXPECTED.len(), count);
assert_eq!(
world.get_resource::<State<LoadState>>().unwrap().current(),
&LoadState::Finish
);
}
}