Skip to content

Commit

Permalink
sync: add broadcast::Sender::new (#5824)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcospb19 authored Jul 15, 2023
1 parent 304d140 commit e52d56e
Showing 1 changed file with 101 additions and 33 deletions.
134 changes: 101 additions & 33 deletions tokio/src/sync/broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,42 +444,13 @@ const MAX_RECEIVERS: usize = usize::MAX >> 2;
/// This will panic if `capacity` is equal to `0` or larger
/// than `usize::MAX / 2`.
#[track_caller]
pub fn channel<T: Clone>(mut capacity: usize) -> (Sender<T>, Receiver<T>) {
assert!(capacity > 0, "capacity is empty");
assert!(capacity <= usize::MAX >> 1, "requested capacity too large");

// Round to a power of two
capacity = capacity.next_power_of_two();

let mut buffer = Vec::with_capacity(capacity);

for i in 0..capacity {
buffer.push(RwLock::new(Slot {
rem: AtomicUsize::new(0),
pos: (i as u64).wrapping_sub(capacity as u64),
val: UnsafeCell::new(None),
}));
}

let shared = Arc::new(Shared {
buffer: buffer.into_boxed_slice(),
mask: capacity - 1,
tail: Mutex::new(Tail {
pos: 0,
rx_cnt: 1,
closed: false,
waiters: LinkedList::new(),
}),
num_tx: AtomicUsize::new(1),
});

pub fn channel<T: Clone>(capacity: usize) -> (Sender<T>, Receiver<T>) {
// SAFETY: In the line below we are creating one extra receiver, so there will be 1 in total.
let tx = unsafe { Sender::new_with_receiver_count(1, capacity) };
let rx = Receiver {
shared: shared.clone(),
shared: tx.shared.clone(),
next: 0,
};

let tx = Sender { shared };

(tx, rx)
}

Expand All @@ -490,6 +461,65 @@ unsafe impl<T: Send> Send for Receiver<T> {}
unsafe impl<T: Send> Sync for Receiver<T> {}

impl<T> Sender<T> {
/// Creates the sending-half of the [`broadcast`] channel.
///
/// See documentation of [`broadcast::channel`] for errors when calling this function.
///
/// [`broadcast`]: crate::sync::broadcast
/// [`broadcast::channel`]: crate::sync::broadcast
#[track_caller]
pub fn new(capacity: usize) -> Self {
// SAFETY: We don't create extra receivers, so there are 0.
unsafe { Self::new_with_receiver_count(0, capacity) }
}

/// Creates the sending-half of the [`broadcast`](self) channel, and provide the receiver
/// count.
///
/// See the documentation of [`broadcast::channel`](self::channel) for more errors when
/// calling this function.
///
/// # Safety:
///
/// The caller must ensure that the amount of receivers for this Sender is correct before
/// the channel functionalities are used, the count is zero by default, as this function
/// does not create any receivers by itself.
#[track_caller]
unsafe fn new_with_receiver_count(receiver_count: usize, mut capacity: usize) -> Self {
assert!(capacity > 0, "broadcast channel capacity cannot be zero");
assert!(
capacity <= usize::MAX >> 1,
"broadcast channel capacity exceeded `usize::MAX / 2`"
);

// Round to a power of two
capacity = capacity.next_power_of_two();

let mut buffer = Vec::with_capacity(capacity);

for i in 0..capacity {
buffer.push(RwLock::new(Slot {
rem: AtomicUsize::new(0),
pos: (i as u64).wrapping_sub(capacity as u64),
val: UnsafeCell::new(None),
}));
}

let shared = Arc::new(Shared {
buffer: buffer.into_boxed_slice(),
mask: capacity - 1,
tail: Mutex::new(Tail {
pos: 0,
rx_cnt: receiver_count,
closed: false,
waiters: LinkedList::new(),
}),
num_tx: AtomicUsize::new(1),
});

Sender { shared }
}

/// Attempts to send a value to all active [`Receiver`] handles, returning
/// it back if it could not be sent.
///
Expand Down Expand Up @@ -1370,3 +1400,41 @@ impl<'a, T> Drop for RecvGuard<'a, T> {
}

fn is_unpin<T: Unpin>() {}

#[cfg(not(loom))]
#[cfg(test)]
mod tests {
use super::*;

#[test]
fn receiver_count_on_sender_constructor() {
let sender = Sender::<i32>::new(16);
assert_eq!(sender.receiver_count(), 0);

let rx_1 = sender.subscribe();
assert_eq!(sender.receiver_count(), 1);

let rx_2 = rx_1.resubscribe();
assert_eq!(sender.receiver_count(), 2);

let rx_3 = sender.subscribe();
assert_eq!(sender.receiver_count(), 3);

drop(rx_3);
drop(rx_1);
assert_eq!(sender.receiver_count(), 1);

drop(rx_2);
assert_eq!(sender.receiver_count(), 0);
}

#[cfg(not(loom))]
#[test]
fn receiver_count_on_channel_constructor() {
let (sender, rx) = channel::<i32>(16);
assert_eq!(sender.receiver_count(), 1);

let _rx_2 = rx.resubscribe();
assert_eq!(sender.receiver_count(), 2);
}
}

0 comments on commit e52d56e

Please sign in to comment.