diff --git a/crates/wasmtime/src/store.rs b/crates/wasmtime/src/store.rs index 9a322fd6a525..8b7f5f3e6175 100644 --- a/crates/wasmtime/src/store.rs +++ b/crates/wasmtime/src/store.rs @@ -425,11 +425,13 @@ enum OutOfGas { /// What to do when the engine epoch reaches the deadline for a Store /// during execution of a function using that store. +#[derive(Default)] enum EpochDeadline { /// Return early with a trap. + #[default] Trap, /// Call a custom deadline handler. - Callback(Box Result + Send + Sync>), + Callback(Box) -> Result + Send + Sync>), /// Extend the deadline by the specified number of ticks after /// yielding to the async executor loop. #[cfg(feature = "async")] @@ -932,7 +934,7 @@ impl Store { /// for an introduction to epoch-based interruption. pub fn epoch_deadline_callback( &mut self, - callback: impl FnMut(&mut T) -> Result + Send + Sync + 'static, + callback: impl FnMut(StoreContextMut) -> Result + Send + Sync + 'static, ) { self.inner.epoch_deadline_callback(Box::new(callback)); } @@ -1975,10 +1977,13 @@ unsafe impl wasmtime_runtime::Store for StoreInner { } fn new_epoch(&mut self) -> Result { - return match &mut self.epoch_deadline_behavior { + // Temporarily take the configured behavior to avoid mutably borrowing + // multiple times. + let mut behavior = std::mem::take(&mut self.epoch_deadline_behavior); + let delta_result = match &mut behavior { EpochDeadline::Trap => Err(Trap::Interrupt.into()), EpochDeadline::Callback(callback) => { - let delta = callback(&mut self.data)?; + let delta = callback((&mut *self).as_context_mut())?; // Set a new deadline and return the new epoch deadline so // the Wasm code doesn't have to reload it. self.set_epoch_deadline(delta); @@ -1998,6 +2003,10 @@ unsafe impl wasmtime_runtime::Store for StoreInner { Ok(self.get_epoch_deadline()) } }; + + // Put back the original behavior which was replaced by `take`. + self.epoch_deadline_behavior = behavior; + delta_result } } @@ -2022,7 +2031,7 @@ impl StoreInner { fn epoch_deadline_callback( &mut self, - callback: Box Result + Send + Sync>, + callback: Box) -> Result + Send + Sync>, ) { self.epoch_deadline_behavior = EpochDeadline::Callback(callback); } diff --git a/tests/all/epoch_interruption.rs b/tests/all/epoch_interruption.rs index c34a8b7f755d..a340ef658341 100644 --- a/tests/all/epoch_interruption.rs +++ b/tests/all/epoch_interruption.rs @@ -32,7 +32,7 @@ fn make_env(engine: &Engine) -> Linker { enum InterruptMode { Trap, - Callback(fn(&mut usize) -> Result), + Callback(fn(StoreContextMut) -> Result), Yield(u64), } @@ -334,7 +334,8 @@ async fn epoch_callback_continue() { (func $subfunc)) ", 1, - InterruptMode::Callback(|s| { + InterruptMode::Callback(|mut cx| { + let s = cx.data_mut(); *s += 1; Ok(1) }),