Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sync: implement Weak version of mpsc::UnboundedSender #5189

Merged
merged 3 commits into from
Nov 12, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tokio/src/sync/mpsc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ mod chan;
pub(super) mod list;

mod unbounded;
pub use self::unbounded::{unbounded_channel, UnboundedReceiver, UnboundedSender};
pub use self::unbounded::{
unbounded_channel, UnboundedReceiver, UnboundedSender, WeakUnboundedSender,
};

pub mod error;

Expand Down
69 changes: 68 additions & 1 deletion tokio/src/sync/mpsc/unbounded.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::{atomic::AtomicUsize, Arc};
use crate::sync::mpsc::chan;
use crate::sync::mpsc::error::{SendError, TryRecvError};

Expand All @@ -13,6 +13,40 @@ pub struct UnboundedSender<T> {
chan: chan::Tx<T, Semaphore>,
}

/// An unbounded sender that does not prevent the channel from being closed.
///
/// If all [`UnboundedSender`] instances of a channel were dropped and only
/// `WeakUnboundedSender` instances remain, the channel is closed.
///
/// In order to send messages, the `WeakUnboundedSender` needs to be upgraded using
/// [`WeakUnboundedSender::upgrade`], which returns `Option<UnboundedSender>`. It returns `None`
/// if all `UnboundedSender`s have been dropped, and otherwise it returns an `UnboundedSender`.
///
/// [`UnboundedSender`]: UnboundedSender
/// [`WeakUnboundedSender::upgrade`]: WeakUnboundedSender::upgrade
///
/// #Examples
///
/// ```
/// use tokio::sync::mpsc::unbounded_channel;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, _rx) = unbounded_channel::<i32>();
/// let tx_weak = tx.downgrade();
///
/// // Upgrading will succeed because `tx` still exists.
/// assert!(tx_weak.upgrade().is_some());
///
/// // If we drop `tx`, then it will fail.
/// drop(tx);
/// assert!(tx_weak.clone().upgrade().is_none());
/// }
/// ```
pub struct WeakUnboundedSender<T> {
chan: Arc<chan::Chan<T, Semaphore>>,
}

impl<T> Clone for UnboundedSender<T> {
fn clone(&self) -> Self {
UnboundedSender {
Expand Down Expand Up @@ -384,4 +418,37 @@ impl<T> UnboundedSender<T> {
pub fn same_channel(&self, other: &Self) -> bool {
self.chan.same_channel(&other.chan)
}

/// Converts the `UnboundedSender` to a [`WeakUnboundedSender`] that does not count
/// towards RAII semantics, i.e. if all `UnboundedSender` instances of the
/// channel were dropped and only `WeakUnboundedSender` instances remain,
/// the channel is closed.
pub fn downgrade(&self) -> WeakUnboundedSender<T> {
WeakUnboundedSender {
chan: self.chan.downgrade(),
}
}
}

impl<T> Clone for WeakUnboundedSender<T> {
fn clone(&self) -> Self {
WeakUnboundedSender {
chan: self.chan.clone(),
}
}
}

impl<T> WeakUnboundedSender<T> {
/// Tries to convert a WeakUnboundedSender into an [`UnboundedSender`].
/// This will return `Some` if there are other `Sender` instances alive and
/// the channel wasn't previously dropped, otherwise `None` is returned.
pub fn upgrade(&self) -> Option<UnboundedSender<T>> {
chan::Tx::upgrade(self.chan.clone()).map(UnboundedSender::new)
}
}

impl<T> fmt::Debug for WeakUnboundedSender<T> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("WeakUnboundedSender").finish()
}
}
232 changes: 231 additions & 1 deletion tokio/tests/sync_mpsc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use wasm_bindgen_test::wasm_bindgen_test as maybe_tokio_test;
use tokio::test as maybe_tokio_test;

use tokio::sync::mpsc::error::{TryRecvError, TrySendError};
use tokio::sync::mpsc::{self, channel};
use tokio::sync::mpsc::{self, channel, unbounded_channel};
use tokio::sync::oneshot;
use tokio_test::*;

Expand Down Expand Up @@ -943,3 +943,233 @@ async fn test_tx_capacity() {
}

fn is_debug<T: fmt::Debug>(_: &T) {}

#[tokio::test]
async fn weak_unbounded_sender() {
let (tx, mut rx) = unbounded_channel();

let tx_weak = tokio::spawn(async move {
let tx_weak = tx.clone().downgrade();

for i in 0..10 {
if tx.send(i).is_err() {
return None;
}
}

let tx2 = tx_weak
.upgrade()
.expect("expected to be able to upgrade tx_weak");
let _ = tx2.send(20);
let tx_weak = tx2.downgrade();

Some(tx_weak)
})
.await
.unwrap();

for i in 0..12 {
let recvd = rx.recv().await;

match recvd {
Some(msg) => {
if i == 10 {
assert_eq!(msg, 20);
}
}
None => {
assert_eq!(i, 11);
break;
}
}
}

let tx_weak = tx_weak.unwrap();
let upgraded = tx_weak.upgrade();
assert!(upgraded.is_none());
}

#[tokio::test]
async fn actor_weak_unbounded_sender() {
pub struct MyActor {
receiver: mpsc::UnboundedReceiver<ActorMessage>,
sender: mpsc::WeakUnboundedSender<ActorMessage>,
next_id: u32,
pub received_self_msg: bool,
}

enum ActorMessage {
GetUniqueId { respond_to: oneshot::Sender<u32> },
SelfMessage {},
}

impl MyActor {
fn new(
receiver: mpsc::UnboundedReceiver<ActorMessage>,
sender: mpsc::WeakUnboundedSender<ActorMessage>,
) -> Self {
MyActor {
receiver,
sender,
next_id: 0,
received_self_msg: false,
}
}

fn handle_message(&mut self, msg: ActorMessage) {
match msg {
ActorMessage::GetUniqueId { respond_to } => {
self.next_id += 1;

// The `let _ =` ignores any errors when sending.
//
// This can happen if the `select!` macro is used
// to cancel waiting for the response.
let _ = respond_to.send(self.next_id);
}
ActorMessage::SelfMessage { .. } => {
self.received_self_msg = true;
}
}
}

async fn send_message_to_self(&mut self) {
let msg = ActorMessage::SelfMessage {};

let sender = self.sender.clone();

// cannot move self.sender here
if let Some(sender) = sender.upgrade() {
let _ = sender.send(msg);
self.sender = sender.downgrade();
}
}

async fn run(&mut self) {
let mut i = 0;
while let Some(msg) = self.receiver.recv().await {
self.handle_message(msg);

if i == 0 {
self.send_message_to_self().await;
}

i += 1
}

assert!(self.received_self_msg);
}
}

#[derive(Clone)]
pub struct MyActorHandle {
sender: mpsc::UnboundedSender<ActorMessage>,
}

impl MyActorHandle {
pub fn new() -> (Self, MyActor) {
let (sender, receiver) = mpsc::unbounded_channel();
let actor = MyActor::new(receiver, sender.clone().downgrade());

(Self { sender }, actor)
}

pub async fn get_unique_id(&self) -> u32 {
let (send, recv) = oneshot::channel();
let msg = ActorMessage::GetUniqueId { respond_to: send };

// Ignore send errors. If this send fails, so does the
// recv.await below. There's no reason to check the
// failure twice.
let _ = self.sender.send(msg);
recv.await.expect("Actor task has been killed")
}
}

let (handle, mut actor) = MyActorHandle::new();

let actor_handle = tokio::spawn(async move { actor.run().await });

let _ = tokio::spawn(async move {
let _ = handle.get_unique_id().await;
drop(handle);
})
.await;

let _ = actor_handle.await;
}

// Tests that no pending messages are put onto the channel after `Rx` was
// dropped.
//
// Note: After the introduction of `UnboundedWeakSender`, which internally
// used `Arc` and doesn't call a drop of the channel after the last strong
// `UnboundedSender` was dropped while more than one `UnboundedWeakSender`
// remains, we want to ensure that no messages are kept in the channel, which
// were sent after the receiver was dropped.
#[tokio::test]
async fn test_msgs_dropped_on_unbounded_rx_drop() {
let (tx, mut rx) = mpsc::unbounded_channel();

tx.send(Msg {}).unwrap();
tx.send(Msg {}).unwrap();

// This msg will be pending and should be dropped when `rx` is dropped
let sent = tx.send(Msg {});

let _ = rx.recv().await.unwrap();
let _ = rx.recv().await.unwrap();

sent.unwrap();

drop(rx);

assert_eq!(NUM_DROPPED.load(Acquire), 3);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this shared by tests? That's not so good because they run in parallel and access the same globals.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just made separate version for unbounded tests


// This msg will not be put onto `Tx` list anymore, since `Rx` is closed.
assert!(tx.send(Msg {}).is_err());

assert_eq!(NUM_DROPPED.load(Acquire), 4);
}

// Tests that an `WeakUnboundedSender` is upgradeable when other
// `UnboundedSender`s exist.
#[tokio::test]
async fn downgrade_upgrade_unbounded_sender_success() {
let (tx, _rx) = mpsc::unbounded_channel::<i32>();
let weak_tx = tx.downgrade();
assert!(weak_tx.upgrade().is_some());
}

// Tests that a `WeakUnboundedSender` fails to upgrade when no other
// `UnboundedSender` exists.
#[tokio::test]
async fn downgrade_upgrade_unbounded_sender_failure() {
let (tx, _rx) = mpsc::unbounded_channel::<i32>();
let weak_tx = tx.downgrade();
drop(tx);
assert!(weak_tx.upgrade().is_none());
}

// Tests that an `WeakUnboundedSender` cannot be upgraded after an
// `UnboundedSender` was dropped, which existed at the time of the `downgrade` call.
#[tokio::test]
async fn downgrade_drop_upgrade_unbounded() {
let (tx, _rx) = mpsc::unbounded_channel::<i32>();

// the cloned `Tx` is dropped right away
let weak_tx = tx.clone().downgrade();
drop(tx);
assert!(weak_tx.upgrade().is_none());
}

// Tests that `downgrade` does not change the `tx_count` of the channel.
#[tokio::test]
async fn test_tx_count_weak_unbounded_sender() {
let (tx, _rx) = mpsc::unbounded_channel::<i32>();
let tx_weak = tx.downgrade();
let tx_weak2 = tx.downgrade();
drop(tx);

assert!(tx_weak.upgrade().is_none() && tx_weak2.upgrade().is_none());
}