diff --git a/tokio/src/sync/watch.rs b/tokio/src/sync/watch.rs index 96d1d164225..836a68a5c76 100644 --- a/tokio/src/sync/watch.rs +++ b/tokio/src/sync/watch.rs @@ -165,6 +165,9 @@ mod state { /// Snapshot of the state. The first bit is used as the CLOSED bit. /// The remaining bits are used as the version. + /// + /// The CLOSED bit tracks whether the Sender has been dropped. Dropping all + /// receivers does not set it. #[derive(Copy, Clone, Debug)] pub(super) struct StateSnapshot(usize); @@ -427,7 +430,7 @@ impl Sender { /// every receiver has been dropped. pub fn send(&self, value: T) -> Result<(), error::SendError> { // This is pretty much only useful as a hint anyway, so synchronization isn't critical. - if 0 == self.shared.ref_count_rx.load(Relaxed) { + if 0 == self.receiver_count() { return Err(error::SendError { inner: value }); } @@ -484,7 +487,7 @@ impl Sender { /// assert!(tx.is_closed()); /// ``` pub fn is_closed(&self) -> bool { - self.shared.ref_count_rx.load(Relaxed) == 0 + self.receiver_count() == 0 } /// Completes when all receivers have dropped. @@ -517,23 +520,81 @@ impl Sender { /// } /// ``` pub async fn closed(&self) { - let notified = self.shared.notify_tx.notified(); + while self.receiver_count() > 0 { + let notified = self.shared.notify_tx.notified(); - if self.shared.ref_count_rx.load(Relaxed) == 0 { - return; - } + if self.receiver_count() == 0 { + return; + } - notified.await; - debug_assert_eq!(0, self.shared.ref_count_rx.load(Relaxed)); + notified.await; + // The channel could have been reopened in the meantime by calling + // `subscribe`, so we loop again. + } } - cfg_signal_internal! { - pub(crate) fn subscribe(&self) -> Receiver { - let shared = self.shared.clone(); - let version = shared.state.load().version(); + /// Creates a new [`Receiver`] connected to this `Sender`. + /// + /// All messages sent before this call to `subscribe` are initially marked + /// as seen by the new `Receiver`. + /// + /// This method can be called even if there are no other receivers. In this + /// case, the channel is reopened. + /// + /// # Examples + /// + /// The new channel will receive messages sent on this `Sender`. + /// + /// ``` + /// use tokio::sync::watch; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, _rx) = watch::channel(0u64); + /// + /// tx.send(5).unwrap(); + /// + /// let rx = tx.subscribe(); + /// assert_eq!(5, *rx.borrow()); + /// + /// tx.send(10).unwrap(); + /// assert_eq!(10, *rx.borrow()); + /// } + /// ``` + /// + /// The most recent message is considered seen by the channel, so this test + /// is guaranteed to pass. + /// + /// ``` + /// use tokio::sync::watch; + /// use tokio::time::Duration; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, _rx) = watch::channel(0u64); + /// tx.send(5).unwrap(); + /// let mut rx = tx.subscribe(); + /// + /// tokio::spawn(async move { + /// // by spawning and sleeping, the message is sent after `main` + /// // hits the call to `changed`. + /// # if false { + /// tokio::time::sleep(Duration::from_millis(10)).await; + /// # } + /// tx.send(100).unwrap(); + /// }); + /// + /// rx.changed().await.unwrap(); + /// assert_eq!(100, *rx.borrow()); + /// } + /// ``` + pub fn subscribe(&self) -> Receiver { + let shared = self.shared.clone(); + let version = shared.state.load().version(); - Receiver::from_shared(version, shared) - } + // The CLOSED bit in the state tracks only whether the sender is + // dropped, so we do not need to unset it if this reopens the channel. + Receiver::from_shared(version, shared) } /// Returns the number of receivers that currently exist diff --git a/tokio/tests/sync_watch.rs b/tokio/tests/sync_watch.rs index a2a276d8beb..b7bbaf721c1 100644 --- a/tokio/tests/sync_watch.rs +++ b/tokio/tests/sync_watch.rs @@ -186,3 +186,18 @@ fn borrow_and_update() { assert_eq!(*rx.borrow_and_update(), "three"); assert_ready!(spawn(rx.changed()).poll()).unwrap_err(); } + +#[test] +fn reopened_after_subscribe() { + let (tx, rx) = watch::channel("one"); + assert!(!tx.is_closed()); + + drop(rx); + assert!(tx.is_closed()); + + let rx = tx.subscribe(); + assert!(!tx.is_closed()); + + drop(rx); + assert!(tx.is_closed()); +}