diff --git a/tokio/src/sync/mpsc/bounded.rs b/tokio/src/sync/mpsc/bounded.rs index 3cdba3dc237..b7b1ce7f623 100644 --- a/tokio/src/sync/mpsc/bounded.rs +++ b/tokio/src/sync/mpsc/bounded.rs @@ -1409,6 +1409,16 @@ impl Sender { pub fn max_capacity(&self) -> usize { self.chan.semaphore().bound } + + /// Returns the number of [`Sender`] handles. + pub fn strong_count(&self) -> usize { + self.chan.strong_count() + } + + /// Returns the number of [`WeakSender`] handles. + pub fn weak_count(&self) -> usize { + self.chan.weak_count() + } } impl Clone for Sender { @@ -1429,12 +1439,20 @@ impl fmt::Debug for Sender { impl Clone for WeakSender { fn clone(&self) -> Self { + self.chan.increment_weak_count(); + WeakSender { chan: self.chan.clone(), } } } +impl Drop for WeakSender { + fn drop(&mut self) { + self.chan.decrement_weak_count(); + } +} + impl WeakSender { /// Tries to convert a `WeakSender` into a [`Sender`]. This will return `Some` /// if there are other `Sender` instances alive and the channel wasn't @@ -1442,6 +1460,16 @@ impl WeakSender { pub fn upgrade(&self) -> Option> { chan::Tx::upgrade(self.chan.clone()).map(Sender::new) } + + /// Returns the number of [`Sender`] handles. + pub fn strong_count(&self) -> usize { + self.chan.strong_count() + } + + /// Returns the number of [`WeakSender`] handles. + pub fn weak_count(&self) -> usize { + self.chan.weak_count() + } } impl fmt::Debug for WeakSender { diff --git a/tokio/src/sync/mpsc/chan.rs b/tokio/src/sync/mpsc/chan.rs index c05a4abb7c0..179a69f5700 100644 --- a/tokio/src/sync/mpsc/chan.rs +++ b/tokio/src/sync/mpsc/chan.rs @@ -66,6 +66,9 @@ pub(super) struct Chan { /// When this drops to zero, the send half of the channel is closed. tx_count: AtomicUsize, + /// Tracks the number of outstanding weak sender handles. + tx_weak_count: AtomicUsize, + /// Only accessed by `Rx` handle. rx_fields: UnsafeCell>, } @@ -115,6 +118,7 @@ pub(crate) fn channel(semaphore: S) -> (Tx, Rx) { semaphore, rx_waker: CachePadded::new(AtomicWaker::new()), tx_count: AtomicUsize::new(1), + tx_weak_count: AtomicUsize::new(0), rx_fields: UnsafeCell::new(RxFields { list: rx, rx_closed: false, @@ -131,7 +135,17 @@ impl Tx { Tx { inner: chan } } + pub(super) fn strong_count(&self) -> usize { + self.inner.tx_count.load(Acquire) + } + + pub(super) fn weak_count(&self) -> usize { + self.inner.tx_weak_count.load(Relaxed) + } + pub(super) fn downgrade(&self) -> Arc> { + self.inner.increment_weak_count(); + self.inner.clone() } @@ -452,6 +466,22 @@ impl Chan { // Notify the rx task self.rx_waker.wake(); } + + pub(super) fn decrement_weak_count(&self) { + self.tx_weak_count.fetch_sub(1, Relaxed); + } + + pub(super) fn increment_weak_count(&self) { + self.tx_weak_count.fetch_add(1, Relaxed); + } + + pub(super) fn strong_count(&self) -> usize { + self.tx_count.load(Acquire) + } + + pub(super) fn weak_count(&self) -> usize { + self.tx_weak_count.load(Relaxed) + } } impl Drop for Chan { diff --git a/tokio/src/sync/mpsc/unbounded.rs b/tokio/src/sync/mpsc/unbounded.rs index b87b07ba653..e5ef0adef38 100644 --- a/tokio/src/sync/mpsc/unbounded.rs +++ b/tokio/src/sync/mpsc/unbounded.rs @@ -578,16 +578,34 @@ impl UnboundedSender { chan: self.chan.downgrade(), } } + + /// Returns the number of [`UnboundedSender`] handles. + pub fn strong_count(&self) -> usize { + self.chan.strong_count() + } + + /// Returns the number of [`WeakUnboundedSender`] handles. + pub fn weak_count(&self) -> usize { + self.chan.weak_count() + } } impl Clone for WeakUnboundedSender { fn clone(&self) -> Self { + self.chan.increment_weak_count(); + WeakUnboundedSender { chan: self.chan.clone(), } } } +impl Drop for WeakUnboundedSender { + fn drop(&mut self) { + self.chan.decrement_weak_count(); + } +} + impl WeakUnboundedSender { /// Tries to convert a `WeakUnboundedSender` into an [`UnboundedSender`]. /// This will return `Some` if there are other `Sender` instances alive and @@ -595,6 +613,16 @@ impl WeakUnboundedSender { pub fn upgrade(&self) -> Option> { chan::Tx::upgrade(self.chan.clone()).map(UnboundedSender::new) } + + /// Returns the number of [`UnboundedSender`] handles. + pub fn strong_count(&self) -> usize { + self.chan.strong_count() + } + + /// Returns the number of [`WeakUnboundedSender`] handles. + pub fn weak_count(&self) -> usize { + self.chan.weak_count() + } } impl fmt::Debug for WeakUnboundedSender { diff --git a/tokio/tests/sync_mpsc_weak.rs b/tokio/tests/sync_mpsc_weak.rs index fad4c72f799..7716902f959 100644 --- a/tokio/tests/sync_mpsc_weak.rs +++ b/tokio/tests/sync_mpsc_weak.rs @@ -511,3 +511,145 @@ fn test_tx_count_weak_unbounded_sender() { assert!(tx_weak.upgrade().is_none() && tx_weak2.upgrade().is_none()); } + +#[tokio::test] +async fn sender_strong_count_when_cloned() { + let (tx, _rx) = mpsc::channel::<()>(1); + + let tx2 = tx.clone(); + + assert_eq!(tx.strong_count(), 2); + assert_eq!(tx2.strong_count(), 2); +} + +#[tokio::test] +async fn sender_weak_count_when_downgraded() { + let (tx, _rx) = mpsc::channel::<()>(1); + + let weak = tx.downgrade(); + + assert_eq!(tx.weak_count(), 1); + assert_eq!(weak.weak_count(), 1); +} + +#[tokio::test] +async fn sender_strong_count_when_dropped() { + let (tx, _rx) = mpsc::channel::<()>(1); + + let tx2 = tx.clone(); + + drop(tx2); + + assert_eq!(tx.strong_count(), 1); +} + +#[tokio::test] +async fn sender_weak_count_when_dropped() { + let (tx, _rx) = mpsc::channel::<()>(1); + + let weak = tx.downgrade(); + + drop(weak); + + assert_eq!(tx.weak_count(), 0); +} + +#[tokio::test] +async fn sender_strong_and_weak_conut() { + let (tx, _rx) = mpsc::channel::<()>(1); + + let tx2 = tx.clone(); + + let weak = tx.downgrade(); + let weak2 = tx2.downgrade(); + + assert_eq!(tx.strong_count(), 2); + assert_eq!(tx2.strong_count(), 2); + assert_eq!(weak.strong_count(), 2); + assert_eq!(weak2.strong_count(), 2); + + assert_eq!(tx.weak_count(), 2); + assert_eq!(tx2.weak_count(), 2); + assert_eq!(weak.weak_count(), 2); + assert_eq!(weak2.weak_count(), 2); + + drop(tx2); + drop(weak2); + + assert_eq!(tx.strong_count(), 1); + assert_eq!(weak.strong_count(), 1); + + assert_eq!(tx.weak_count(), 1); + assert_eq!(weak.weak_count(), 1); +} + +#[tokio::test] +async fn unbounded_sender_strong_count_when_cloned() { + let (tx, _rx) = mpsc::unbounded_channel::<()>(); + + let tx2 = tx.clone(); + + assert_eq!(tx.strong_count(), 2); + assert_eq!(tx2.strong_count(), 2); +} + +#[tokio::test] +async fn unbounded_sender_weak_count_when_downgraded() { + let (tx, _rx) = mpsc::unbounded_channel::<()>(); + + let weak = tx.downgrade(); + + assert_eq!(tx.weak_count(), 1); + assert_eq!(weak.weak_count(), 1); +} + +#[tokio::test] +async fn unbounded_sender_strong_count_when_dropped() { + let (tx, _rx) = mpsc::unbounded_channel::<()>(); + + let tx2 = tx.clone(); + + drop(tx2); + + assert_eq!(tx.strong_count(), 1); +} + +#[tokio::test] +async fn unbounded_sender_weak_count_when_dropped() { + let (tx, _rx) = mpsc::unbounded_channel::<()>(); + + let weak = tx.downgrade(); + + drop(weak); + + assert_eq!(tx.weak_count(), 0); +} + +#[tokio::test] +async fn unbounded_sender_strong_and_weak_conut() { + let (tx, _rx) = mpsc::unbounded_channel::<()>(); + + let tx2 = tx.clone(); + + let weak = tx.downgrade(); + let weak2 = tx2.downgrade(); + + assert_eq!(tx.strong_count(), 2); + assert_eq!(tx2.strong_count(), 2); + assert_eq!(weak.strong_count(), 2); + assert_eq!(weak2.strong_count(), 2); + + assert_eq!(tx.weak_count(), 2); + assert_eq!(tx2.weak_count(), 2); + assert_eq!(weak.weak_count(), 2); + assert_eq!(weak2.weak_count(), 2); + + drop(tx2); + drop(weak2); + + assert_eq!(tx.strong_count(), 1); + assert_eq!(weak.strong_count(), 1); + + assert_eq!(tx.weak_count(), 1); + assert_eq!(weak.weak_count(), 1); +}