Skip to content

Commit

Permalink
sync: add mpsc::WeakUnboundedSender (#5189)
Browse files Browse the repository at this point in the history
Signed-off-by: Artyom Kozhemiakin <[email protected]>
  • Loading branch information
akozhemiakin authored Nov 12, 2022
1 parent b7812c8 commit 582d512
Show file tree
Hide file tree
Showing 4 changed files with 588 additions and 272 deletions.
4 changes: 3 additions & 1 deletion tokio/src/sync/mpsc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ mod chan;
pub(super) mod list;

mod unbounded;
pub use self::unbounded::{unbounded_channel, UnboundedReceiver, UnboundedSender};
pub use self::unbounded::{
unbounded_channel, UnboundedReceiver, UnboundedSender, WeakUnboundedSender,
};

pub mod error;

Expand Down
69 changes: 68 additions & 1 deletion tokio/src/sync/mpsc/unbounded.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::{atomic::AtomicUsize, Arc};
use crate::sync::mpsc::chan;
use crate::sync::mpsc::error::{SendError, TryRecvError};

Expand All @@ -13,6 +13,40 @@ pub struct UnboundedSender<T> {
chan: chan::Tx<T, Semaphore>,
}

/// An unbounded sender that does not prevent the channel from being closed.
///
/// If all [`UnboundedSender`] instances of a channel were dropped and only
/// `WeakUnboundedSender` instances remain, the channel is closed.
///
/// In order to send messages, the `WeakUnboundedSender` needs to be upgraded using
/// [`WeakUnboundedSender::upgrade`], which returns `Option<UnboundedSender>`. It returns `None`
/// if all `UnboundedSender`s have been dropped, and otherwise it returns an `UnboundedSender`.
///
/// [`UnboundedSender`]: UnboundedSender
/// [`WeakUnboundedSender::upgrade`]: WeakUnboundedSender::upgrade
///
/// #Examples
///
/// ```
/// use tokio::sync::mpsc::unbounded_channel;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, _rx) = unbounded_channel::<i32>();
/// let tx_weak = tx.downgrade();
///
/// // Upgrading will succeed because `tx` still exists.
/// assert!(tx_weak.upgrade().is_some());
///
/// // If we drop `tx`, then it will fail.
/// drop(tx);
/// assert!(tx_weak.clone().upgrade().is_none());
/// }
/// ```
pub struct WeakUnboundedSender<T> {
chan: Arc<chan::Chan<T, Semaphore>>,
}

impl<T> Clone for UnboundedSender<T> {
fn clone(&self) -> Self {
UnboundedSender {
Expand Down Expand Up @@ -384,4 +418,37 @@ impl<T> UnboundedSender<T> {
pub fn same_channel(&self, other: &Self) -> bool {
self.chan.same_channel(&other.chan)
}

/// Converts the `UnboundedSender` to a [`WeakUnboundedSender`] that does not count
/// towards RAII semantics, i.e. if all `UnboundedSender` instances of the
/// channel were dropped and only `WeakUnboundedSender` instances remain,
/// the channel is closed.
pub fn downgrade(&self) -> WeakUnboundedSender<T> {
WeakUnboundedSender {
chan: self.chan.downgrade(),
}
}
}

impl<T> Clone for WeakUnboundedSender<T> {
fn clone(&self) -> Self {
WeakUnboundedSender {
chan: self.chan.clone(),
}
}
}

impl<T> WeakUnboundedSender<T> {
/// Tries to convert a WeakUnboundedSender into an [`UnboundedSender`].
/// This will return `Some` if there are other `Sender` instances alive and
/// the channel wasn't previously dropped, otherwise `None` is returned.
pub fn upgrade(&self) -> Option<UnboundedSender<T>> {
chan::Tx::upgrade(self.chan.clone()).map(UnboundedSender::new)
}
}

impl<T> fmt::Debug for WeakUnboundedSender<T> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("WeakUnboundedSender").finish()
}
}
274 changes: 4 additions & 270 deletions tokio/tests/sync_mpsc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,14 @@ use wasm_bindgen_test::wasm_bindgen_test as test;
#[cfg(tokio_wasm_not_wasi)]
use wasm_bindgen_test::wasm_bindgen_test as maybe_tokio_test;

use std::fmt;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::{TryRecvError, TrySendError};
#[cfg(not(tokio_wasm_not_wasi))]
use tokio::test as maybe_tokio_test;

use tokio::sync::mpsc::error::{TryRecvError, TrySendError};
use tokio::sync::mpsc::{self, channel};
use tokio::sync::oneshot;
use tokio_test::*;

use std::fmt;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::{Acquire, Release};
use std::sync::Arc;

#[cfg(not(tokio_wasm))]
mod support {
pub(crate) mod mpsc_stream;
Expand Down Expand Up @@ -662,267 +657,6 @@ fn recv_timeout_panic() {
tx.send_timeout(10, Duration::from_secs(1)).now_or_never();
}

#[tokio::test]
async fn weak_sender() {
let (tx, mut rx) = channel(11);

let tx_weak = tokio::spawn(async move {
let tx_weak = tx.clone().downgrade();

for i in 0..10 {
if tx.send(i).await.is_err() {
return None;
}
}

let tx2 = tx_weak
.upgrade()
.expect("expected to be able to upgrade tx_weak");
let _ = tx2.send(20).await;
let tx_weak = tx2.downgrade();

Some(tx_weak)
})
.await
.unwrap();

for i in 0..12 {
let recvd = rx.recv().await;

match recvd {
Some(msg) => {
if i == 10 {
assert_eq!(msg, 20);
}
}
None => {
assert_eq!(i, 11);
break;
}
}
}

let tx_weak = tx_weak.unwrap();
let upgraded = tx_weak.upgrade();
assert!(upgraded.is_none());
}

#[tokio::test]
async fn actor_weak_sender() {
pub struct MyActor {
receiver: mpsc::Receiver<ActorMessage>,
sender: mpsc::WeakSender<ActorMessage>,
next_id: u32,
pub received_self_msg: bool,
}

enum ActorMessage {
GetUniqueId { respond_to: oneshot::Sender<u32> },
SelfMessage {},
}

impl MyActor {
fn new(
receiver: mpsc::Receiver<ActorMessage>,
sender: mpsc::WeakSender<ActorMessage>,
) -> Self {
MyActor {
receiver,
sender,
next_id: 0,
received_self_msg: false,
}
}

fn handle_message(&mut self, msg: ActorMessage) {
match msg {
ActorMessage::GetUniqueId { respond_to } => {
self.next_id += 1;

// The `let _ =` ignores any errors when sending.
//
// This can happen if the `select!` macro is used
// to cancel waiting for the response.
let _ = respond_to.send(self.next_id);
}
ActorMessage::SelfMessage { .. } => {
self.received_self_msg = true;
}
}
}

async fn send_message_to_self(&mut self) {
let msg = ActorMessage::SelfMessage {};

let sender = self.sender.clone();

// cannot move self.sender here
if let Some(sender) = sender.upgrade() {
let _ = sender.send(msg).await;
self.sender = sender.downgrade();
}
}

async fn run(&mut self) {
let mut i = 0;
while let Some(msg) = self.receiver.recv().await {
self.handle_message(msg);

if i == 0 {
self.send_message_to_self().await;
}

i += 1
}

assert!(self.received_self_msg);
}
}

#[derive(Clone)]
pub struct MyActorHandle {
sender: mpsc::Sender<ActorMessage>,
}

impl MyActorHandle {
pub fn new() -> (Self, MyActor) {
let (sender, receiver) = mpsc::channel(8);
let actor = MyActor::new(receiver, sender.clone().downgrade());

(Self { sender }, actor)
}

pub async fn get_unique_id(&self) -> u32 {
let (send, recv) = oneshot::channel();
let msg = ActorMessage::GetUniqueId { respond_to: send };

// Ignore send errors. If this send fails, so does the
// recv.await below. There's no reason to check the
// failure twice.
let _ = self.sender.send(msg).await;
recv.await.expect("Actor task has been killed")
}
}

let (handle, mut actor) = MyActorHandle::new();

let actor_handle = tokio::spawn(async move { actor.run().await });

let _ = tokio::spawn(async move {
let _ = handle.get_unique_id().await;
drop(handle);
})
.await;

let _ = actor_handle.await;
}

static NUM_DROPPED: AtomicUsize = AtomicUsize::new(0);

#[derive(Debug)]
struct Msg;

impl Drop for Msg {
fn drop(&mut self) {
NUM_DROPPED.fetch_add(1, Release);
}
}

// Tests that no pending messages are put onto the channel after `Rx` was
// dropped.
//
// Note: After the introduction of `WeakSender`, which internally
// used `Arc` and doesn't call a drop of the channel after the last strong
// `Sender` was dropped while more than one `WeakSender` remains, we want to
// ensure that no messages are kept in the channel, which were sent after
// the receiver was dropped.
#[tokio::test]
async fn test_msgs_dropped_on_rx_drop() {
let (tx, mut rx) = mpsc::channel(3);

tx.send(Msg {}).await.unwrap();
tx.send(Msg {}).await.unwrap();

// This msg will be pending and should be dropped when `rx` is dropped
let sent_fut = tx.send(Msg {});

let _ = rx.recv().await.unwrap();
let _ = rx.recv().await.unwrap();

sent_fut.await.unwrap();

drop(rx);

assert_eq!(NUM_DROPPED.load(Acquire), 3);

// This msg will not be put onto `Tx` list anymore, since `Rx` is closed.
assert!(tx.send(Msg {}).await.is_err());

assert_eq!(NUM_DROPPED.load(Acquire), 4);
}

// Tests that a `WeakSender` is upgradeable when other `Sender`s exist.
#[tokio::test]
async fn downgrade_upgrade_sender_success() {
let (tx, _rx) = mpsc::channel::<i32>(1);
let weak_tx = tx.downgrade();
assert!(weak_tx.upgrade().is_some());
}

// Tests that a `WeakSender` fails to upgrade when no other `Sender` exists.
#[tokio::test]
async fn downgrade_upgrade_sender_failure() {
let (tx, _rx) = mpsc::channel::<i32>(1);
let weak_tx = tx.downgrade();
drop(tx);
assert!(weak_tx.upgrade().is_none());
}

// Tests that a `WeakSender` cannot be upgraded after a `Sender` was dropped,
// which existed at the time of the `downgrade` call.
#[tokio::test]
async fn downgrade_drop_upgrade() {
let (tx, _rx) = mpsc::channel::<i32>(1);

// the cloned `Tx` is dropped right away
let weak_tx = tx.clone().downgrade();
drop(tx);
assert!(weak_tx.upgrade().is_none());
}

// Tests that we can upgrade a weak sender with an outstanding permit
// but no other strong senders.
#[tokio::test]
async fn downgrade_get_permit_upgrade_no_senders() {
let (tx, _rx) = mpsc::channel::<i32>(1);
let weak_tx = tx.downgrade();
let _permit = tx.reserve_owned().await.unwrap();
assert!(weak_tx.upgrade().is_some());
}

// Tests that you can downgrade and upgrade a sender with an outstanding permit
// but no other senders left.
#[tokio::test]
async fn downgrade_upgrade_get_permit_no_senders() {
let (tx, _rx) = mpsc::channel::<i32>(1);
let tx2 = tx.clone();
let _permit = tx.reserve_owned().await.unwrap();
let weak_tx = tx2.downgrade();
drop(tx2);
assert!(weak_tx.upgrade().is_some());
}

// Tests that `downgrade` does not change the `tx_count` of the channel.
#[tokio::test]
async fn test_tx_count_weak_sender() {
let (tx, _rx) = mpsc::channel::<i32>(1);
let tx_weak = tx.downgrade();
let tx_weak2 = tx.downgrade();
drop(tx);

assert!(tx_weak.upgrade().is_none() && tx_weak2.upgrade().is_none());
}

// Tests that channel `capacity` changes and `max_capacity` stays the same
#[tokio::test]
async fn test_tx_capacity() {
Expand Down
Loading

0 comments on commit 582d512

Please sign in to comment.