diff --git a/lightning-background-processor/src/lib.rs b/lightning-background-processor/src/lib.rs index 1300a67e2a1..46849e136f6 100644 --- a/lightning-background-processor/src/lib.rs +++ b/lightning-background-processor/src/lib.rs @@ -854,8 +854,8 @@ impl BackgroundProcessor { peer_manager.onion_message_handler().process_pending_events(&event_handler), gossip_sync, logger, scorer, stop_thread.load(Ordering::Acquire), { Sleeper::from_two_futures( - channel_manager.get_event_or_persistence_needed_future(), - chain_monitor.get_update_future() + &channel_manager.get_event_or_persistence_needed_future(), + &chain_monitor.get_update_future() ).wait_timeout(Duration::from_millis(100)); }, |_| Instant::now(), |time: &Instant, dur| time.elapsed().as_secs() > dur, false, || { diff --git a/lightning/src/util/wakers.rs b/lightning/src/util/wakers.rs index d220aa02dfd..b2c9d21b998 100644 --- a/lightning/src/util/wakers.rs +++ b/lightning/src/util/wakers.rs @@ -180,16 +180,16 @@ impl Future { /// Waits until this [`Future`] completes. #[cfg(feature = "std")] - pub fn wait(self) { - Sleeper::from_single_future(self).wait(); + pub fn wait(&self) { + Sleeper::from_single_future(&self).wait(); } /// Waits until this [`Future`] completes or the given amount of time has elapsed. /// /// Returns true if the [`Future`] completed, false if the time elapsed. #[cfg(feature = "std")] - pub fn wait_timeout(self, max_wait: Duration) -> bool { - Sleeper::from_single_future(self).wait_timeout(max_wait) + pub fn wait_timeout(&self, max_wait: Duration) -> bool { + Sleeper::from_single_future(&self).wait_timeout(max_wait) } #[cfg(test)] @@ -202,6 +202,12 @@ impl Future { } } +impl Drop for Future { + fn drop(&mut self) { + self.state.lock().unwrap().std_future_callbacks.retain(|(idx, _)| *idx != self.self_idx); + } +} + use core::task::Waker; struct StdWaker(pub Waker); @@ -216,6 +222,7 @@ impl<'a> StdFuture for Future { Poll::Ready(()) } else { let waker = cx.waker().clone(); + state.std_future_callbacks.retain(|(idx, _)| *idx != self.self_idx); state.std_future_callbacks.push((self.self_idx, StdWaker(waker))); Poll::Pending } @@ -232,17 +239,17 @@ pub struct Sleeper { #[cfg(feature = "std")] impl Sleeper { /// Constructs a new sleeper from one future, allowing blocking on it. - pub fn from_single_future(future: Future) -> Self { - Self { notifiers: vec![future.state] } + pub fn from_single_future(future: &Future) -> Self { + Self { notifiers: vec![Arc::clone(&future.state)] } } /// Constructs a new sleeper from two futures, allowing blocking on both at once. // Note that this is the common case - a ChannelManager and ChainMonitor. - pub fn from_two_futures(fut_a: Future, fut_b: Future) -> Self { - Self { notifiers: vec![fut_a.state, fut_b.state] } + pub fn from_two_futures(fut_a: &Future, fut_b: &Future) -> Self { + Self { notifiers: vec![Arc::clone(&fut_a.state), Arc::clone(&fut_b.state)] } } /// Constructs a new sleeper on many futures, allowing blocking on all at once. pub fn new(futures: Vec) -> Self { - Self { notifiers: futures.into_iter().map(|f| f.state).collect() } + Self { notifiers: futures.into_iter().map(|f| Arc::clone(&f.state)).collect() } } /// Prepares to go into a wait loop body, creating a condition variable which we can block on /// and an `Arc>>` which gets set to the waking `Future`'s state prior to the @@ -447,13 +454,15 @@ mod tests { // Wait on the other thread to finish its sleep, note that the leak only happened if we // actually have to sleep here, not if we immediately return. - Sleeper::from_two_futures(future_a, future_b).wait(); + Sleeper::from_two_futures(&future_a, &future_b).wait(); join_handle.join().unwrap(); // then drop the notifiers and make sure the future states are gone. mem::drop(notifier_a); mem::drop(notifier_b); + mem::drop(future_a); + mem::drop(future_b); assert!(future_state_a.upgrade().is_none() && future_state_b.upgrade().is_none()); } @@ -655,18 +664,18 @@ mod tests { // Set both notifiers as woken without sleeping yet. notifier_a.notify(); notifier_b.notify(); - Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future()).wait(); + Sleeper::from_two_futures(¬ifier_a.get_future(), ¬ifier_b.get_future()).wait(); // One future has woken us up, but the other should still have a pending notification. - Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future()).wait(); + Sleeper::from_two_futures(¬ifier_a.get_future(), ¬ifier_b.get_future()).wait(); // However once we've slept twice, we should no longer have any pending notifications - assert!(!Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future()) + assert!(!Sleeper::from_two_futures(¬ifier_a.get_future(), ¬ifier_b.get_future()) .wait_timeout(Duration::from_millis(10))); // Test ordering somewhat more. notifier_a.notify(); - Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future()).wait(); + Sleeper::from_two_futures(¬ifier_a.get_future(), ¬ifier_b.get_future()).wait(); } #[test] @@ -684,7 +693,7 @@ mod tests { // After sleeping one future (not guaranteed which one, however) will have its notification // bit cleared. - Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future()).wait(); + Sleeper::from_two_futures(¬ifier_a.get_future(), ¬ifier_b.get_future()).wait(); // By registering a callback on the futures for both notifiers, one will complete // immediately, but one will remain tied to the notifier, and will complete once the @@ -703,8 +712,48 @@ mod tests { notifier_b.notify(); assert!(callback_a.load(Ordering::SeqCst) && callback_b.load(Ordering::SeqCst)); - Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future()).wait(); - assert!(!Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future()) + Sleeper::from_two_futures(¬ifier_a.get_future(), ¬ifier_b.get_future()).wait(); + assert!(!Sleeper::from_two_futures(¬ifier_a.get_future(), ¬ifier_b.get_future()) .wait_timeout(Duration::from_millis(10))); } + + #[test] + #[cfg(feature = "std")] + fn multi_poll_stores_single_waker() { + // When a `Future` is `poll()`ed multiple times, only the last `Waker` should be called, + // but previously we'd store all `Waker`s until they're all woken at once. This tests a few + // cases to ensure `Future`s avoid storing an endless set of `Waker`s. + let notifier = Notifier::new(); + let future_state = Arc::clone(¬ifier.get_future().state); + assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 0); + + // Test that simply polling a future twice doesn't result in two pending `Waker`s. + let mut future_a = notifier.get_future(); + assert_eq!(Pin::new(&mut future_a).poll(&mut Context::from_waker(&create_waker().1)), Poll::Pending); + assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 1); + assert_eq!(Pin::new(&mut future_a).poll(&mut Context::from_waker(&create_waker().1)), Poll::Pending); + assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 1); + + // If we poll a second future, however, that will store a second `Waker`. + let mut future_b = notifier.get_future(); + assert_eq!(Pin::new(&mut future_b).poll(&mut Context::from_waker(&create_waker().1)), Poll::Pending); + assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 2); + + // but when we drop the `Future`s, the pending Wakers will also be dropped. + mem::drop(future_a); + assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 1); + mem::drop(future_b); + assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 0); + + // Further, after polling a future twice, if the notifier is woken all Wakers are dropped. + let mut future_a = notifier.get_future(); + assert_eq!(Pin::new(&mut future_a).poll(&mut Context::from_waker(&create_waker().1)), Poll::Pending); + assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 1); + assert_eq!(Pin::new(&mut future_a).poll(&mut Context::from_waker(&create_waker().1)), Poll::Pending); + assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 1); + notifier.notify(); + assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 0); + assert_eq!(Pin::new(&mut future_a).poll(&mut Context::from_waker(&create_waker().1)), Poll::Ready(())); + assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 0); + } }