Make BallistaContext::collect streaming (#535)
edrevo authored Jun 11, 2021
1 parent 3ef7f34 · commit 63e3045
Showing 1 changed file with 72 additions and 41 deletions.
113 changes: 72 additions & 41 deletions ballista/rust/client/src/context.rs
@@ -24,21 +24,27 @@ use std::{collections::HashMap, convert::TryInto};
 use std::{fs, time::Duration};
 
 use ballista_core::serde::protobuf::scheduler_grpc_client::SchedulerGrpcClient;
+use ballista_core::serde::protobuf::PartitionLocation;
 use ballista_core::serde::protobuf::{
     execute_query_params::Query, job_status, ExecuteQueryParams, GetJobStatusParams,
     GetJobStatusResult,
 };
 use ballista_core::{
-    client::BallistaClient, datasource::DfTableAdapter, memory_stream::MemoryStream,
-    utils::create_datafusion_context,
+    client::BallistaClient, datasource::DfTableAdapter, utils::create_datafusion_context,
 };
 
 use datafusion::arrow::datatypes::Schema;
+use datafusion::arrow::datatypes::SchemaRef;
+use datafusion::arrow::error::Result as ArrowResult;
+use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::catalog::TableReference;
 use datafusion::error::{DataFusionError, Result};
 use datafusion::logical_plan::LogicalPlan;
 use datafusion::physical_plan::csv::CsvReadOptions;
 use datafusion::{dataframe::DataFrame, physical_plan::RecordBatchStream};
+use futures::future;
+use futures::Stream;
+use futures::StreamExt;
 use log::{error, info};
 
 #[allow(dead_code)]
@@ -68,6 +74,32 @@ impl BallistaContextState {
     }
 }
 
+struct WrappedStream {
+    stream: Pin<Box<dyn Stream<Item = ArrowResult<RecordBatch>> + Send + Sync>>,
+    schema: SchemaRef,
+}
+
+impl RecordBatchStream for WrappedStream {
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}
+
+impl Stream for WrappedStream {
+    type Item = ArrowResult<RecordBatch>;
+
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Option<Self::Item>> {
+        self.stream.poll_next_unpin(cx)
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        self.stream.size_hint()
+    }
+}
+
 #[allow(dead_code)]
 
 pub struct BallistaContext {
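
The WrappedStream type added above pairs a boxed stream of Arrow record batches with its SchemaRef, which is all that DataFusion's RecordBatchStream trait requires on top of Stream. Below is a minimal construction sketch, not part of the commit: it assumes it sits in the same module as the code above (the struct fields are private) and wraps a single in-memory batch in place of data fetched from an executor. Only the Int32Array, DataType and Field imports are new; everything else is already imported by this file.

use datafusion::arrow::array::Int32Array;
use datafusion::arrow::datatypes::{DataType, Field};

// Hypothetical helper: expose one in-memory batch through the WrappedStream adapter.
fn example_wrapped_stream() -> Result<WrappedStream> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;

    // Any stream of ArrowResult<RecordBatch> items can be boxed here; collect()
    // below hands in the flattened per-partition streams in exactly this shape.
    let batches: Vec<ArrowResult<RecordBatch>> = vec![Ok(batch)];
    Ok(WrappedStream {
        stream: Box::pin(futures::stream::iter(batches)),
        schema,
    })
}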
@@ -155,6 +187,29 @@ impl BallistaContext {
         ctx.sql(sql)
     }
 
+    async fn fetch_partition(
+        location: PartitionLocation,
+    ) -> Result<Pin<Box<dyn RecordBatchStream + Send + Sync>>> {
+        let metadata = location.executor_meta.ok_or_else(|| {
+            DataFusionError::Internal("Received empty executor metadata".to_owned())
+        })?;
+        let partition_id = location.partition_id.ok_or_else(|| {
+            DataFusionError::Internal("Received empty partition id".to_owned())
+        })?;
+        let mut ballista_client =
+            BallistaClient::try_new(metadata.host.as_str(), metadata.port as u16)
+                .await
+                .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
+        Ok(ballista_client
+            .fetch_partition(
+                &partition_id.job_id,
+                partition_id.stage_id as usize,
+                partition_id.partition_id as usize,
+            )
+            .await
+            .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?)
+    }
+
     pub async fn collect(
         &self,
         plan: &LogicalPlan,
@@ -222,45 +277,21 @@ impl BallistaContext {
                     break Err(DataFusionError::Execution(msg));
                 }
                 job_status::Status::Completed(completed) => {
-                    // TODO: use streaming. Probably need to change the signature of fetch_partition to achieve that
-                    let mut result = vec![];
-                    for location in completed.partition_location {
-                        let metadata = location.executor_meta.ok_or_else(|| {
-                            DataFusionError::Internal(
-                                "Received empty executor metadata".to_owned(),
-                            )
-                        })?;
-                        let partition_id = location.partition_id.ok_or_else(|| {
-                            DataFusionError::Internal(
-                                "Received empty partition id".to_owned(),
-                            )
-                        })?;
-                        let mut ballista_client = BallistaClient::try_new(
-                            metadata.host.as_str(),
-                            metadata.port as u16,
-                        )
-                        .await
-                        .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
-                        let stream = ballista_client
-                            .fetch_partition(
-                                &partition_id.job_id,
-                                partition_id.stage_id as usize,
-                                partition_id.partition_id as usize,
-                            )
-                            .await
-                            .map_err(|e| {
-                                DataFusionError::Execution(format!("{:?}", e))
-                            })?;
-                        result.append(
-                            &mut datafusion::physical_plan::common::collect(stream)
-                                .await?,
-                        );
-                    }
-                    break Ok(Box::pin(MemoryStream::try_new(
-                        result,
-                        Arc::new(schema),
-                        None,
-                    )?));
+                    let result = future::join_all(
+                        completed
+                            .partition_location
+                            .into_iter()
+                            .map(BallistaContext::fetch_partition),
+                    )
+                    .await
+                    .into_iter()
+                    .collect::<Result<Vec<_>>>()?;
+
+                    let result = WrappedStream {
+                        stream: Box::pin(futures::stream::iter(result).flatten()),
+                        schema: Arc::new(schema),
+                    };
+                    break Ok(Box::pin(result));
                 }
             };
         }
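
With this change, collect no longer buffers every partition into a MemoryStream: the per-partition fetches are started concurrently with future::join_all, and batches then flow through the returned WrappedStream as the caller polls it. Below is a consumption sketch, not part of the commit, written as if it sat next to the code above so that BallistaContext, LogicalPlan, Result and the futures imports added in this diff are already in scope; the ctx and plan values are assumed to exist. Callers that still want a fully buffered Vec<RecordBatch> can drain the stream themselves, for example with datafusion::physical_plan::common::collect as the removed code did.

// Hypothetical caller: consume the distributed query result batch by batch.
async fn print_row_counts(ctx: &BallistaContext, plan: &LogicalPlan) -> Result<()> {
    let mut stream = ctx.collect(plan).await?;
    // StreamExt::next comes from the futures import added in this commit.
    while let Some(batch) = stream.next().await {
        let batch = batch?;
        println!("fetched a batch with {} rows", batch.num_rows());
    }
    Ok(())
}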
