Skip to content

Commit

Permalink
Support SortMerge spilling
Browse files Browse the repository at this point in the history
  • Loading branch information
comphead committed Jul 8, 2024
1 parent 8aa0bf6 commit 9c16696
Show file tree
Hide file tree
Showing 4 changed files with 325 additions and 60 deletions.
15 changes: 14 additions & 1 deletion datafusion/core/tests/memory_limit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,20 @@ async fn sort_merge_join_no_spill() {

#[tokio::test]
async fn sort_merge_join_spill() {
todo!()
// Planner chooses MergeJoin only if number of partitions > 1
let config = SessionConfig::new()
.with_target_partitions(2)
.set_bool("datafusion.optimizer.prefer_hash_join", false);

TestCase::new()
.with_query(
"select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time",
)
.with_memory_limit(1_000)
.with_config(config)
.with_disk_manager_config(DiskManagerConfig::NewOs)
.run()
.await
}

#[tokio::test]
Expand Down
17 changes: 16 additions & 1 deletion datafusion/execution/src/memory_pool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
//! [`MemoryPool`] for memory management during query execution, [`proxy]` for
//! help with allocation accounting.
use datafusion_common::Result;
use datafusion_common::{internal_err, Result};
use std::{cmp::Ordering, sync::Arc};

mod pool;
Expand Down Expand Up @@ -220,6 +220,21 @@ impl MemoryReservation {
self.size = new_size
}

/// Tries to free `capacity` bytes from this reservation
/// if `capacity` does not exceed [`Self::size`]
pub fn try_shrink(&mut self, capacity: usize) -> Result<()> {
if let Some(new_size) = self.size.checked_sub(capacity) {
self.registration.pool.shrink(self, capacity);
self.size = new_size;
Ok(())
} else {
internal_err!(
"Cannot free the capacity {capacity} out of allocated size {}",
self.size
)
}
}

/// Sets the size of this reservation to `capacity`
pub fn resize(&mut self, capacity: usize) {
match capacity.cmp(&self.size) {
Expand Down
250 changes: 196 additions & 54 deletions datafusion/physical-plan/src/joins/sort_merge_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,44 +24,45 @@ use std::any::Any;
use std::cmp::Ordering;
use std::collections::VecDeque;
use std::fmt::Formatter;
use std::fs::File;
use std::io::BufReader;
use std::mem;
use std::ops::Range;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use crate::expressions::PhysicalSortExpr;
use crate::joins::utils::{
build_join_schema, check_join_is_valid, estimate_join_statistics,
symmetric_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef,
};
use crate::metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use crate::{
execution_mode_from_children, metrics, read_spill_as_stream, spill_record_batches,
DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
PhysicalExpr, PlanProperties, RecordBatchStream, SendableRecordBatchStream,
Statistics,
};

use arrow::array::*;
use arrow::compute::{self, concat_batches, take, SortOptions};
use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
use arrow::error::ArrowError;
use arrow::ipc::reader::FileReader;
use arrow_array::types::UInt64Type;
use futures::{Stream, StreamExt};
use hashbrown::HashSet;

use datafusion_common::{
exec_err, internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType,
Result,
internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result,
};
use datafusion_execution::disk_manager::RefCountedTempFile;
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_execution::TaskContext;
use datafusion_physical_expr::equivalence::join_equivalence_properties;
use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement};

use datafusion_execution::disk_manager::RefCountedTempFile;
use datafusion_execution::runtime_env::RuntimeEnv;
use futures::{Stream, StreamExt};
use hashbrown::HashSet;
use crate::expressions::PhysicalSortExpr;
use crate::joins::utils::{
build_join_schema, check_join_is_valid, estimate_join_statistics,
symmetric_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef,
};
use crate::metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use crate::{
execution_mode_from_children, metrics, spill_record_batch_by_size, DisplayAs,
DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
PhysicalExpr, PlanProperties, RecordBatchStream, SendableRecordBatchStream,
Statistics,
};

/// join execution plan executes partitions in parallel and combines them into a set of
/// partitions.
Expand Down Expand Up @@ -638,27 +639,22 @@ impl BufferedBatch {
&mut self,
path: RefCountedTempFile,
buffered_schema: SchemaRef,
batch_size: usize,
) -> Result<()> {
let batch = std::mem::replace(
&mut self.batch,
RecordBatch::new_empty(buffered_schema.clone()),
);
let _ = spill_record_batches(vec![batch], path.path().into(), buffered_schema)?;
let _ = spill_record_batch_by_size(
batch,
path.path().into(),
buffered_schema,
batch_size,
);
self.spill_file = Some(path);

dbg!(&self.spill_file);
Ok(())
}

fn read_spilled_from_disk(
&mut self,
schema: SchemaRef,
) -> Result<SendableRecordBatchStream> {
if let Some(f) = mem::take(&mut self.spill_file) {
read_spill_as_stream(f, schema, 2)
} else {
exec_err!("Cannot read data batch from disk. Use `spill_to_disk` to spill")
}
}
}

/// Sort-merge join stream that consumes streamed and buffered data stream
Expand Down Expand Up @@ -925,8 +921,11 @@ impl SMJStream {
if let Some(buffered_batch) =
self.buffered_data.batches.pop_front()
{
println!("shrink\n");
self.reservation.shrink(buffered_batch.size_estimation);
// Noop on shrink complaints, this might happen
// on spilled batches
self.reservation
.try_shrink(buffered_batch.size_estimation)
.unwrap_or(());
}
} else {
// If the head batch is not fully processed, break the loop.
Expand Down Expand Up @@ -988,6 +987,7 @@ impl SMJStream {
buffered_batch.spill_to_disk(
spill_file,
self.buffered_schema.clone(),
self.batch_size,
)?;

// update metrics to display spill
Expand Down Expand Up @@ -1614,30 +1614,46 @@ fn get_buffered_columns_from_batch(
buffered_batch: &mut BufferedBatch,
buffered_indices: &UInt64Array,
) -> Result<Vec<ArrayRef>, ArrowError> {
if buffered_batch.spill_file.is_none() {
if let Some(spill_file) = mem::take(&mut buffered_batch.spill_file) {
// if spilled read as a stream
let mut buffered_cols: Vec<ArrayRef> = Vec::with_capacity(buffered_indices.len());
// let mut stream =
// read_spill_as_stream(spill_file, buffered_batch.batch.schema(), 2)?;

let file = BufReader::new(File::open(spill_file.path())?);
let reader = FileReader::try_new(file, None)?;

for batch in reader {
let batch = batch?;
batch.columns().iter().for_each(|column| {
buffered_cols.extend(take(column, &buffered_indices, None))
});
}

// let _ = futures::stream::once(async {
// dbg!("in");
// while let Some(batch) = stream.next().await {
// dbg!("stream spilled batch");
//
// let batch = batch?;
// batch.columns().iter().for_each(|column| {
// buffered_cols.extend(take(column, &buffered_indices, None))
// });
// }
//
// Ok::<(), ArrowError>(())
// });

dbg!(&buffered_cols);

Ok(buffered_cols)
} else {
buffered_batch
.batch
.columns()
.iter()
.map(|column| take(column, &buffered_indices, None))
.collect::<Result<Vec<_>, ArrowError>>()
} else {
// if spilled read as a stream
let mut buffered_cols: Vec<ArrayRef> = Vec::with_capacity(buffered_indices.len());
let mut stream =
buffered_batch.read_spilled_from_disk(buffered_batch.batch.schema())?;
let _ = futures::stream::once(async {
while let Some(batch) = stream.next().await {
let batch = batch?;
batch.columns().iter().for_each(|column| {
buffered_cols.extend(take(column, &buffered_indices, None))
});
}

Ok::<(), DataFusionError>(())
});

Ok(buffered_cols)
}
}

Expand Down Expand Up @@ -1674,6 +1690,7 @@ fn get_filtered_join_mask(
// we don't need to check any others for the same index
JoinType::LeftSemi => {
// have we seen a filter match for a streaming index before
// have we seen a filter match for are streaming index before
for i in 0..streamed_indices_length {
// LeftSemi respects only first true values for specific streaming index,
// others true values for the same index must be false
Expand Down Expand Up @@ -3025,11 +3042,136 @@ mod tests {

#[tokio::test]
async fn overallocation_single_batch_spill() -> Result<()> {
todo!()
let left = build_table(
("a1", &vec![0, 1, 2, 3, 4, 5]),
("b1", &vec![1, 2, 3, 4, 5, 6]),
("c1", &vec![4, 5, 6, 7, 8, 9]),
);
let right = build_table(
("a2", &vec![0, 10, 20, 30, 40]),
("b2", &vec![1, 3, 4, 6, 8]),
("c2", &vec![50, 60, 70, 80, 90]),
);
let on = vec![(
Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
)];
let sort_options = vec![SortOptions::default(); on.len()];

let join_types = vec![
JoinType::Inner,
JoinType::Left,
JoinType::Right,
//JoinType::Full,
JoinType::LeftSemi,
JoinType::LeftAnti,
];

// Enable DiskManager to allow spilling
let runtime_config = RuntimeConfig::new()
.with_memory_limit(100, 1.0)
.with_disk_manager(DiskManagerConfig::NewOs);
let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
let session_config = SessionConfig::default().with_batch_size(50);

for join_type in join_types {
let task_ctx = TaskContext::default()
.with_session_config(session_config.clone())
.with_runtime(runtime.clone());
let task_ctx = Arc::new(task_ctx);

println!("{join_type}");

let join = join_with_options(
left.clone(),
right.clone(),
on.clone(),
join_type,
sort_options.clone(),
false,
)?;

let stream = join.execute(0, task_ctx)?;
let _ = common::collect(stream).await.unwrap();
}

Ok(())
}

#[tokio::test]
async fn overallocation_multi_batch_spill() -> Result<()> {
todo!()
let left_batch_1 = build_table_i32(
("a1", &vec![0, 1]),
("b1", &vec![1, 1]),
("c1", &vec![4, 5]),
);
let left_batch_2 = build_table_i32(
("a1", &vec![2, 3]),
("b1", &vec![1, 1]),
("c1", &vec![6, 7]),
);
let left_batch_3 = build_table_i32(
("a1", &vec![4, 5]),
("b1", &vec![1, 1]),
("c1", &vec![8, 9]),
);
let right_batch_1 = build_table_i32(
("a2", &vec![0, 10]),
("b2", &vec![1, 1]),
("c2", &vec![50, 60]),
);
let right_batch_2 = build_table_i32(
("a2", &vec![20, 30]),
("b2", &vec![1, 1]),
("c2", &vec![70, 80]),
);
let right_batch_3 =
build_table_i32(("a2", &vec![40]), ("b2", &vec![1]), ("c2", &vec![90]));
let left =
build_table_from_batches(vec![left_batch_1, left_batch_2, left_batch_3]);
let right =
build_table_from_batches(vec![right_batch_1, right_batch_2, right_batch_3]);
let on = vec![(
Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
)];
let sort_options = vec![SortOptions::default(); on.len()];

let join_types = vec![
JoinType::Inner,
JoinType::Left,
JoinType::Right,
//JoinType::Full,
JoinType::LeftSemi,
JoinType::LeftAnti,
];

// Enable DiskManager to allow spilling
let runtime_config = RuntimeConfig::new()
.with_memory_limit(100, 1.0)
.with_disk_manager(DiskManagerConfig::NewOs);
let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
let session_config = SessionConfig::default().with_batch_size(50);

for join_type in join_types {
let task_ctx = TaskContext::default()
.with_session_config(session_config.clone())
.with_runtime(runtime.clone());
let task_ctx = Arc::new(task_ctx);
let join = join_with_options(
left.clone(),
right.clone(),
on.clone(),
join_type,
sort_options.clone(),
false,
)?;

let stream = join.execute(0, task_ctx)?;
let _ = common::collect(stream).await.unwrap();
}

Ok(())
}

#[tokio::test]
Expand Down
Loading

0 comments on commit 9c16696

Please sign in to comment.