Memory reservation & metrics for cross join #5339

Merged Feb 28, 2023 (3 commits)
Changes from 1 commit
4 changes: 4 additions & 0 deletions datafusion/core/src/physical_plan/common.rs
@@ -20,6 +20,7 @@
use super::{RecordBatchStream, SendableRecordBatchStream};
use crate::error::{DataFusionError, Result};
use crate::execution::context::TaskContext;
use crate::execution::memory_pool::MemoryReservation;
use crate::physical_plan::metrics::MemTrackingMetrics;
use crate::physical_plan::{displayable, ColumnStatistics, ExecutionPlan, Statistics};
use arrow::datatypes::{Schema, SchemaRef};
@@ -28,6 +29,7 @@ use arrow::record_batch::RecordBatch;
use datafusion_physical_expr::PhysicalSortExpr;
use futures::{Future, Stream, StreamExt, TryStreamExt};
use log::debug;
use parking_lot::Mutex;
use pin_project_lite::pin_project;
use std::fs;
use std::fs::{metadata, File};
@@ -37,6 +39,8 @@ use std::task::{Context, Poll};
use tokio::sync::mpsc;
use tokio::task::JoinHandle;

pub(crate) type SharedMemoryReservation = Arc<Mutex<MemoryReservation>>;

/// Stream of record batches
pub struct SizedRecordBatchStream {
schema: SchemaRef,
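
Taken together with the cross_join.rs changes below, the intended usage pattern for SharedMemoryReservation is: register a MemoryConsumer against the task's memory pool, wrap the resulting MemoryReservation in Arc<Mutex<..>>, grow it as batches are buffered, and free it when the stream finishes. The following is a minimal standalone sketch of that pattern, not part of this PR: the consumer name, the 100-byte limit, and the grow sizes are made up for illustration, paths are written as an external crate would use them, and parking_lot is assumed as a direct dependency.

use std::sync::Arc;

use datafusion::error::Result;
use datafusion::execution::memory_pool::MemoryConsumer;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::prelude::{SessionConfig, SessionContext};
use parking_lot::Mutex;

fn main() -> Result<()> {
    // A pool capped at 100 bytes, mirroring the test_overallocation test below.
    let runtime = Arc::new(RuntimeEnv::new(
        RuntimeConfig::new().with_memory_limit(100, 1.0),
    )?);
    let ctx = SessionContext::with_config_rt(SessionConfig::default(), runtime);
    let task_ctx = ctx.task_ctx();

    // SharedMemoryReservation: a registered reservation behind Arc<Mutex<..>>
    // so the build-side future and the join stream can both account against it.
    let reservation = Arc::new(Mutex::new(
        MemoryConsumer::new("ExampleConsumer").register(task_ctx.memory_pool()),
    ));

    reservation.lock().try_grow(64)?; // ok: 64 bytes fits under the 100-byte limit
    assert!(reservation.lock().try_grow(64).is_err()); // 128 bytes would exceed the limit
    reservation.lock().free(); // release everything once the stream is done
    Ok(())
}

Sharing one reservation this way lets the build side account for the buffered left input while the stream later releases it with a single free() call.
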
210 changes: 154 additions & 56 deletions datafusion/core/src/physical_plan/joins/cross_join.rs
@@ -26,6 +26,9 @@ use arrow::datatypes::{Schema, SchemaRef};
use arrow::record_batch::RecordBatch;

use crate::execution::context::TaskContext;
use crate::execution::memory_pool::MemoryConsumer;
use crate::physical_plan::common::SharedMemoryReservation;
use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use crate::physical_plan::{
coalesce_batches::concat_batches, coalesce_partitions::CoalescePartitionsExec,
ColumnStatistics, DisplayFormatType, Distribution, EquivalenceProperties,
@@ -35,12 +38,11 @@ use crate::physical_plan::{
use crate::{error::Result, scalar::ScalarValue};
use async_trait::async_trait;
use datafusion_common::DataFusionError;
use log::debug;
use std::time::Instant;
use parking_lot::Mutex;

use super::utils::{
adjust_right_output_partitioning, cross_join_equivalence_properties, OnceAsync,
OnceFut,
adjust_right_output_partitioning, cross_join_equivalence_properties,
BuildProbeJoinMetrics, OnceAsync, OnceFut,
};

/// Data of the left side
@@ -58,6 +60,8 @@ pub struct CrossJoinExec {
schema: SchemaRef,
/// Build-side data
left_fut: OnceAsync<JoinLeftData>,
/// Execution plan metrics
metrics: ExecutionPlanMetricsSet,
}

impl CrossJoinExec {
@@ -79,6 +83,7 @@ impl CrossJoinExec {
right,
schema,
left_fut: Default::default(),
metrics: ExecutionPlanMetricsSet::default(),
}
}

@@ -97,9 +102,9 @@ impl CrossJoinExec {
async fn load_left_input(
left: Arc<dyn ExecutionPlan>,
context: Arc<TaskContext>,
metrics: BuildProbeJoinMetrics,
reservation: SharedMemoryReservation,
) -> Result<JoinLeftData> {
let start = Instant::now();

// merge all left parts into a single stream
let merge = {
if left.output_partitioning().partition_count() != 1 {
@@ -111,22 +116,28 @@ async fn load_left_input(
let stream = merge.execute(0, context)?;

// Load all batches and count the rows
let (batches, num_rows) = stream
.try_fold((Vec::new(), 0usize), |mut acc, batch| async {
acc.1 += batch.num_rows();
acc.0.push(batch);
Ok(acc)
})
let (batches, num_rows, _, _) = stream
.try_fold(
(Vec::new(), 0usize, metrics, reservation),
|mut acc, batch| async {
let batch_size = batch.get_array_memory_size();
// Reserve memory for incoming batch
acc.3.lock().try_grow(batch_size)?;
// Update metrics
acc.2.build_mem_used.add(batch_size);
acc.2.build_input_batches.add(1);
acc.2.build_input_rows.add(batch.num_rows());
// Update rowcount
acc.1 += batch.num_rows();
// Push batch to output
acc.0.push(batch);
Ok(acc)
},
)
.await?;

let merged_batch = concat_batches(&left.schema(), &batches, num_rows)?;

debug!(
"Built build-side of cross join containing {} rows in {} ms",
num_rows,
start.elapsed().as_millis()
);

Ok(merged_batch)
}

@@ -143,6 +154,10 @@ impl ExecutionPlan for CrossJoinExec {
vec![self.left.clone(), self.right.clone()]
}

fn metrics(&self) -> Option<MetricsSet> {
Some(self.metrics.clone_inner())
}

/// Specifies whether this plan generates an infinite stream of records.
/// If the plan does not support pipelining, but its input(s) are
/// infinite, returns an error to indicate this.
@@ -205,21 +220,29 @@ impl ExecutionPlan for CrossJoinExec {
) -> Result<SendableRecordBatchStream> {
let stream = self.right.execute(partition, context.clone())?;

let left_fut = self
.left_fut
.once(|| load_left_input(self.left.clone(), context));
let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
let reservation = Arc::new(Mutex::new(
MemoryConsumer::new(format!("CrossJoinStream[{partition}]"))
.register(context.memory_pool()),
));

let left_fut = self.left_fut.once(|| {
load_left_input(
self.left.clone(),
context,
join_metrics.clone(),
reservation.clone(),
)
});

Ok(Box::pin(CrossJoinStream {
schema: self.schema.clone(),
left_fut,
right: stream,
right_batch: Arc::new(parking_lot::Mutex::new(None)),
left_index: 0,
num_input_batches: 0,
num_input_rows: 0,
num_output_batches: 0,
num_output_rows: 0,
join_time: 0,
join_metrics,
reservation,
}))
}

@@ -321,16 +344,10 @@ struct CrossJoinStream {
left_index: usize,
/// Current batch being processed from the right side
right_batch: Arc<parking_lot::Mutex<Option<RecordBatch>>>,
/// number of input batches
num_input_batches: usize,
/// number of input rows
num_input_rows: usize,
/// number of batches produced
num_output_batches: usize,
/// number of rows produced
num_output_rows: usize,
/// total time for joining probe-side batches to the build-side batches
join_time: usize,
/// join execution metrics
join_metrics: BuildProbeJoinMetrics,
/// memory reservation
reservation: SharedMemoryReservation,
}

impl RecordBatchStream for CrossJoinStream {
@@ -385,28 +402,30 @@ impl CrossJoinStream {
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Result<RecordBatch>>> {
let build_timer = self.join_metrics.build_time.timer();
let left_data = match ready!(self.left_fut.get(cx)) {
Ok(left_data) => left_data,
Err(e) => return Poll::Ready(Some(Err(e))),
};
build_timer.done();

if left_data.num_rows() == 0 {
return Poll::Ready(None);
}

if self.left_index > 0 && self.left_index < left_data.num_rows() {
let start = Instant::now();
let join_timer = self.join_metrics.join_time.timer();
let right_batch = {
let right_batch = self.right_batch.lock();
right_batch.clone().unwrap()
};
let result =
build_batch(self.left_index, &right_batch, left_data, &self.schema);
self.num_input_rows += right_batch.num_rows();
self.join_metrics.input_rows.add(right_batch.num_rows());
if let Ok(ref batch) = result {
self.join_time += start.elapsed().as_millis() as usize;
self.num_output_batches += 1;
self.num_output_rows += batch.num_rows();
join_timer.done();
self.join_metrics.output_batches.add(1);
self.join_metrics.output_rows.add(batch.num_rows());
}
self.left_index += 1;
return Poll::Ready(Some(result));
@@ -416,15 +435,15 @@ impl CrossJoinStream {
.poll_next_unpin(cx)
.map(|maybe_batch| match maybe_batch {
Some(Ok(batch)) => {
let start = Instant::now();
let join_timer = self.join_metrics.join_time.timer();
let result =
build_batch(self.left_index, &batch, left_data, &self.schema);
self.num_input_batches += 1;
self.num_input_rows += batch.num_rows();
self.join_metrics.input_batches.add(1);
self.join_metrics.input_rows.add(batch.num_rows());
if let Ok(ref batch) = result {
self.join_time += start.elapsed().as_millis() as usize;
self.num_output_batches += 1;
self.num_output_rows += batch.num_rows();
join_timer.done();
self.join_metrics.output_batches.add(1);
self.join_metrics.output_rows.add(batch.num_rows());
}
self.left_index = 1;

Expand All @@ -434,15 +453,7 @@ impl CrossJoinStream {
Some(result)
}
other => {
Contributor: I think it is a nice improvement that the metrics are now included in things like EXPLAIN ANALYZE (see the usage sketch after the diff).

debug!(
"Processed {} probe-side input batches containing {} rows and \
produced {} output batches containing {} rows in {} ms",
self.num_input_batches,
self.num_input_rows,
self.num_output_batches,
self.num_output_rows,
self.join_time
);
self.reservation.lock().free();
other
}
})
@@ -452,6 +463,25 @@ impl CrossJoinStream {
#[cfg(test)]
mod tests {
use super::*;
use crate::assert_batches_sorted_eq;
use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use crate::physical_plan::common;
use crate::prelude::{SessionConfig, SessionContext};
use crate::test::{build_table_scan_i32, columns};

async fn join_collect(
left: Arc<dyn ExecutionPlan>,
right: Arc<dyn ExecutionPlan>,
context: Arc<TaskContext>,
) -> Result<(Vec<String>, Vec<RecordBatch>)> {
let join = CrossJoinExec::new(left, right);
let columns_header = columns(&join.schema());

let stream = join.execute(0, context)?;
let batches = common::collect(stream).await?;

Ok((columns_header, batches))
}

#[tokio::test]
async fn test_stats_cartesian_product() {
@@ -589,4 +619,72 @@ mod tests {

assert_eq!(result, expected);
}

#[tokio::test]
async fn test_join() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();

let left = build_table_scan_i32(
("a1", &vec![1, 2, 3]),
("b1", &vec![4, 5, 6]),
("c1", &vec![7, 8, 9]),
);
let right = build_table_scan_i32(
("a2", &vec![10, 11]),
("b2", &vec![12, 13]),
("c2", &vec![14, 15]),
);

let (columns, batches) = join_collect(left, right, task_ctx).await?;

assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
let expected = vec![
"+----+----+----+----+----+----+",
"| a1 | b1 | c1 | a2 | b2 | c2 |",
"+----+----+----+----+----+----+",
"| 1 | 4 | 7 | 10 | 12 | 14 |",
"| 1 | 4 | 7 | 11 | 13 | 15 |",
"| 2 | 5 | 8 | 10 | 12 | 14 |",
"| 2 | 5 | 8 | 11 | 13 | 15 |",
"| 3 | 6 | 9 | 10 | 12 | 14 |",
"| 3 | 6 | 9 | 11 | 13 | 15 |",
"+----+----+----+----+----+----+",
];

assert_batches_sorted_eq!(expected, &batches);

Ok(())
}

#[tokio::test]
async fn test_overallocation() -> Result<()> {
let runtime_config = RuntimeConfig::new().with_memory_limit(100, 1.0);
let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
let session_ctx =
SessionContext::with_config_rt(SessionConfig::default(), runtime);
let task_ctx = session_ctx.task_ctx();

let left = build_table_scan_i32(
("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
);
let right = build_table_scan_i32(
("a2", &vec![10, 11]),
("b2", &vec![12, 13]),
("c2", &vec![14, 15]),
);

let err = join_collect(left, right, task_ctx).await.unwrap_err();

assert_eq!(
err.to_string(),
"External error: Resources exhausted: Failed to allocate \
additional 744 bytes for CrossJoinStream[0] with 0 bytes \
already allocated - maximum available is 100"
);

Ok(())
}
}
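
As the contributor comment in the diff notes, these metrics now surface through EXPLAIN ANALYZE. Below is a hedged sketch of how one might inspect them, not part of this PR: the table names and SQL are illustrative, the exact metric labels printed for CrossJoinExec may differ, and a DataFusion release containing this change plus tokio are assumed.

use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Two tiny tables whose unconditioned join is planned as a CrossJoinExec.
    ctx.sql("CREATE TABLE t1 AS VALUES (1), (2), (3)").await?;
    ctx.sql("CREATE TABLE t2 AS VALUES (10), (20)").await?;

    // EXPLAIN ANALYZE executes the plan and annotates each operator with its
    // runtime metrics; with this PR the CrossJoinExec row should report values
    // such as build_mem_used, build_input_rows, output_rows and join_time.
    let batches = ctx
        .sql("EXPLAIN ANALYZE SELECT * FROM t1 CROSS JOIN t2")
        .await?
        .collect()
        .await?;

    println!(
        "{}",
        datafusion::arrow::util::pretty::pretty_format_batches(&batches)?
    );
    Ok(())
}

The same values can also be read back programmatically through the metrics() override added to CrossJoinExec above.
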