Extract ReceiverStreamBuilder (#7817)
* Extract ReceiverStreamBuilder

* Docs and format

* Update datafusion/physical-plan/src/stream.rs

* fmt

* Undo changes to testing pin

---------

Co-authored-by: Andrew Lamb <[email protected]>
tustvold and alamb authored Oct 16, 2023
1 parent 26e43ac commit fa2bb6c
Showing 1 changed file with 132 additions and 100 deletions.
232 changes: 132 additions & 100 deletions datafusion/physical-plan/src/stream.rs
@@ -38,6 +38,124 @@ use tokio::task::JoinSet;
use super::metrics::BaselineMetrics;
use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};

/// Creates a stream from a collection of producing tasks, routing panics to the stream.
///
/// Note that this is similar to [`ReceiverStream` from tokio-stream], with the differences being:
///
/// 1. Methods to bound and "detach" tasks (`spawn()` and `spawn_blocking()`).
///
/// 2. Propagates panics, whereas the `tokio` version doesn't propagate panics to the receiver.
///
/// 3. Automatically cancels any outstanding tasks when the receiver stream is dropped.
///
/// [`ReceiverStream` from tokio-stream]: https://docs.rs/tokio-stream/latest/tokio_stream/wrappers/struct.ReceiverStream.html

pub(crate) struct ReceiverStreamBuilder<O> {
    tx: Sender<Result<O>>,
    rx: Receiver<Result<O>>,
    join_set: JoinSet<Result<()>>,
}

impl<O: Send + 'static> ReceiverStreamBuilder<O> {
    /// create new channels with the specified buffer size
    pub fn new(capacity: usize) -> Self {
        let (tx, rx) = tokio::sync::mpsc::channel(capacity);

        Self {
            tx,
            rx,
            join_set: JoinSet::new(),
        }
    }

    /// Get a handle for sending data to the output
    pub fn tx(&self) -> Sender<Result<O>> {
        self.tx.clone()
    }

    /// Spawn task that will be aborted if this builder (or the stream
    /// built from it) are dropped
    pub fn spawn<F>(&mut self, task: F)
    where
        F: Future<Output = Result<()>>,
        F: Send + 'static,
    {
        self.join_set.spawn(task);
    }

    /// Spawn a blocking task that will be aborted if this builder (or the stream
    /// built from it) are dropped
    ///
    /// this is often used to spawn tasks that write to the sender
    /// retrieved from `Self::tx`
    pub fn spawn_blocking<F>(&mut self, f: F)
    where
        F: FnOnce() -> Result<()>,
        F: Send + 'static,
    {
        self.join_set.spawn_blocking(f);
    }

    /// Create a stream of all data written to `tx`
    pub fn build(self) -> BoxStream<'static, Result<O>> {
        let Self {
            tx,
            rx,
            mut join_set,
        } = self;

        // don't need tx
        drop(tx);

        // future that checks the result of the join set, and propagates panic if seen
        let check = async move {
            while let Some(result) = join_set.join_next().await {
                match result {
                    Ok(task_result) => {
                        match task_result {
                            // nothing to report
                            Ok(_) => continue,
                            // This means a blocking task error
                            Err(e) => {
                                return Some(exec_err!("Spawned Task error: {e}"));
                            }
                        }
                    }
                    // This means a tokio task error, likely a panic
                    Err(e) => {
                        if e.is_panic() {
                            // resume on the main thread
                            std::panic::resume_unwind(e.into_panic());
                        } else {
                            // This should only occur if the task is
                            // cancelled, which would only occur if
                            // the JoinSet were aborted, which in turn
                            // would imply that the receiver has been
                            // dropped and this code is not running
                            return Some(internal_err!("Non Panic Task error: {e}"));
                        }
                    }
                }
            }
            None
        };

        let check_stream = futures::stream::once(check)
            // unwrap Option / only return the error
            .filter_map(|item| async move { item });

        // Convert the receiver into a stream
        let rx_stream = futures::stream::unfold(rx, |mut rx| async move {
            let next_item = rx.recv().await;
            next_item.map(|next_item| (next_item, rx))
        });

        // Merge the streams together so whichever is ready first
        // produces the batch
        futures::stream::select(rx_stream, check_stream).boxed()
    }
}
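
Editor's note: for orientation, here is a minimal, hypothetical crate-internal sketch (e.g. a test in this module) of how the new builder is driven. Only `new`, `tx`, `spawn`, `spawn_blocking`, and `build` come from the code above; the function name, producer closures, and values are illustrative. It shows the behavior described in the doc comment: a panic in any spawned task resurfaces when the stream is polled, and dropping the stream aborts any task still running.

// Editor's sketch (not part of this commit). `Result` is the DataFusion
// Result alias already imported at the top of stream.rs.
use futures::StreamExt;

async fn receiver_stream_sketch() -> Result<()> {
    let mut builder: ReceiverStreamBuilder<i32> = ReceiverStreamBuilder::new(2);

    // Async producer: send values through the handle returned by tx().
    let tx = builder.tx();
    builder.spawn(async move {
        for i in 0..3 {
            // A send error only means the receiver was dropped; stop quietly.
            if tx.send(Ok(i)).await.is_err() {
                break;
            }
        }
        Ok(())
    });

    // Blocking producer: runs on the blocking pool, so it must use
    // blocking_send rather than the async send.
    let blocking_tx = builder.tx();
    builder.spawn_blocking(move || {
        let _ = blocking_tx.blocking_send(Ok(100));
        Ok(())
    });

    // If either task panicked, the panic resurfaces from this loop via the
    // `check` future assembled in build(); dropping `stream` early would
    // abort both tasks because the JoinSet is dropped with it.
    let mut stream = builder.build();
    while let Some(value) = stream.next().await {
        println!("received {}", value?);
    }
    Ok(())
}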

/// Builder for [`RecordBatchReceiverStream`] that propagates errors
/// and panics correctly.
///
@@ -47,28 +165,22 @@ use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
///
/// This also handles propagating panics and canceling the tasks.
pub struct RecordBatchReceiverStreamBuilder {
tx: Sender<Result<RecordBatch>>,
rx: Receiver<Result<RecordBatch>>,
schema: SchemaRef,
join_set: JoinSet<Result<()>>,
inner: ReceiverStreamBuilder<RecordBatch>,
}

impl RecordBatchReceiverStreamBuilder {
/// create new channels with the specified buffer size
pub fn new(schema: SchemaRef, capacity: usize) -> Self {
let (tx, rx) = tokio::sync::mpsc::channel(capacity);

Self {
tx,
rx,
schema,
join_set: JoinSet::new(),
inner: ReceiverStreamBuilder::new(capacity),
}
}

/// Get a handle for sending [`RecordBatch`]es to the output
/// Get a handle for sending [`RecordBatch`] to the output
pub fn tx(&self) -> Sender<Result<RecordBatch>> {
self.tx.clone()
self.inner.tx()
}

/// Spawn task that will be aborted if this builder (or the stream
@@ -81,7 +193,7 @@ impl RecordBatchReceiverStreamBuilder {
F: Future<Output = Result<()>>,
F: Send + 'static,
{
self.join_set.spawn(task);
self.inner.spawn(task)
}

/// Spawn a blocking task that will be aborted if this builder (or the stream
@@ -94,7 +206,7 @@ impl RecordBatchReceiverStreamBuilder {
F: FnOnce() -> Result<()>,
F: Send + 'static,
{
self.join_set.spawn_blocking(f);
self.inner.spawn_blocking(f)
}

/// runs the input_partition of the `input` ExecutionPlan on the
@@ -110,7 +222,7 @@ impl RecordBatchReceiverStreamBuilder {
) {
let output = self.tx();

self.spawn(async move {
self.inner.spawn(async move {
let mut stream = match input.execute(partition, context) {
Err(e) => {
// If send fails, the plan being torn down, there
@@ -155,80 +267,17 @@ impl RecordBatchReceiverStreamBuilder {
});
}

/// Create a stream of all `RecordBatch`es written to `tx`
/// Create a stream of all [`RecordBatch`] written to `tx`
pub fn build(self) -> SendableRecordBatchStream {
let Self {
tx,
rx,
schema,
mut join_set,
} = self;

// don't need tx
drop(tx);

// future that checks the result of the join set, and propagates panic if seen
let check = async move {
while let Some(result) = join_set.join_next().await {
match result {
Ok(task_result) => {
match task_result {
// nothing to report
Ok(_) => continue,
// This means a blocking task error
Err(e) => {
return Some(exec_err!("Spawned Task error: {e}"));
}
}
}
// This means a tokio task error, likely a panic
Err(e) => {
if e.is_panic() {
// resume on the main thread
std::panic::resume_unwind(e.into_panic());
} else {
// This should only occur if the task is
// cancelled, which would only occur if
// the JoinSet were aborted, which in turn
// would imply that the receiver has been
// dropped and this code is not running
return Some(internal_err!("Non Panic Task error: {e}"));
}
}
}
}
None
};

let check_stream = futures::stream::once(check)
// unwrap Option / only return the error
.filter_map(|item| async move { item });

// Convert the receiver into a stream
let rx_stream = futures::stream::unfold(rx, |mut rx| async move {
let next_item = rx.recv().await;
next_item.map(|next_item| (next_item, rx))
});

// Merge the streams together so whichever is ready first
// produces the batch
let inner = futures::stream::select(rx_stream, check_stream).boxed();

Box::pin(RecordBatchReceiverStream { schema, inner })
Box::pin(RecordBatchStreamAdapter::new(
self.schema,
self.inner.build(),
))
}
}
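
Editor's note: as a usage illustration (not part of the diff), the public builder is typically obtained via `RecordBatchReceiverStream::builder` and driven just like the inner builder above. The schema, batch contents, and function name below are hypothetical, and the import paths assume the `datafusion` crate's usual re-exports. When wiring up a child `ExecutionPlan`, `run_input` (whose full signature is elided in this view) takes the place of the manual `spawn` shown here.

// Editor's sketch: producing a SendableRecordBatchStream with the
// public builder. Schema and batch values are illustrative only.
use std::sync::Arc;

use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion::physical_plan::stream::RecordBatchReceiverStream;
use datafusion::physical_plan::SendableRecordBatchStream;

fn spawned_batches() -> SendableRecordBatchStream {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

    // The internal channel buffers at most 2 batches before producers wait.
    let mut builder = RecordBatchReceiverStream::builder(Arc::clone(&schema), 2);

    let tx = builder.tx();
    let task_schema = Arc::clone(&schema);
    builder.spawn(async move {
        let batch = RecordBatch::try_new(
            task_schema,
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;
        // A send error only means the consumer hung up.
        let _ = tx.send(Ok(batch)).await;
        Ok(())
    });

    // The stream yields batches from all spawned tasks and surfaces
    // their errors and panics to whoever polls it.
    builder.build()
}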

/// A [`SendableRecordBatchStream`] that combines [`RecordBatch`]es from multiple inputs,
/// on new tokio Tasks, increasing the potential parallelism.
///
/// This structure also handles propagating panics and cancelling the
/// underlying tasks correctly.
///
/// Use [`Self::builder`] to construct one.
pub struct RecordBatchReceiverStream {
schema: SchemaRef,
inner: BoxStream<'static, Result<RecordBatch>>,
}
#[doc(hidden)]
pub struct RecordBatchReceiverStream {}

impl RecordBatchReceiverStream {
/// Create a builder with an internal buffer of capacity batches.
@@ -240,23 +289,6 @@ impl RecordBatchReceiverStream {
}
}

impl Stream for RecordBatchReceiverStream {
type Item = Result<RecordBatch>;

fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
self.inner.poll_next_unpin(cx)
}
}

impl RecordBatchStream for RecordBatchReceiverStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}

pin_project! {
/// Combines a [`Stream`] with a [`SchemaRef`] implementing
/// [`RecordBatchStream`] for the combination
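
Editor's note: for the adapter used in `build()` above, a minimal sketch (not part of the diff) of wrapping an ordinary `Stream` of `Result<RecordBatch>` in `RecordBatchStreamAdapter` so it carries a schema and can serve as a `SendableRecordBatchStream`. The schema is hypothetical and the paths assume the same `datafusion` crate re-exports as the previous sketch.

// Editor's sketch: an empty SendableRecordBatchStream that still
// reports a schema, built by wrapping futures::stream::empty().
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion::error::Result;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::SendableRecordBatchStream;

fn empty_with_schema() -> SendableRecordBatchStream {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let stream = futures::stream::empty::<Result<RecordBatch>>();
    // The adapter supplies the schema() required by RecordBatchStream;
    // the wrapped stream supplies the items.
    Box::pin(RecordBatchStreamAdapter::new(schema, stream))
}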
