diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 96aa1be181f5..45b8f0826f9a 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -125,8 +125,6 @@ async fn test_left_join_1k() { } #[tokio::test] -// flaky for HjSmj case -// https://github.com/apache/datafusion/issues/12359 async fn test_left_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -134,7 +132,7 @@ async fn test_left_join_1k_filtered() { JoinType::Left, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::NljHj], false) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -151,8 +149,6 @@ async fn test_right_join_1k() { } #[tokio::test] -// flaky for HjSmj case -// https://github.com/apache/datafusion/issues/12359 async fn test_right_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -160,7 +156,7 @@ async fn test_right_join_1k_filtered() { JoinType::Right, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::NljHj], false) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -177,8 +173,6 @@ async fn test_full_join_1k() { } #[tokio::test] -// flaky for HjSmj case -// https://github.com/apache/datafusion/issues/12359 async fn test_full_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -186,7 +180,7 @@ async fn test_full_join_1k_filtered() { JoinType::Full, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::NljHj], false) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -227,8 +221,6 @@ async fn test_anti_join_1k() { } #[tokio::test] -// flaky for HjSmj case, giving 1 rows difference sometimes -// https://github.com/apache/datafusion/issues/11555 async fn test_anti_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -236,7 +228,7 @@ async fn test_anti_join_1k_filtered() { JoinType::LeftAnti, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::NljHj], false) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 2118c1a5266f..007e1c149ace 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -22,7 +22,7 @@ use std::any::Any; use std::cmp::Ordering; -use std::collections::{HashMap, VecDeque}; +use std::collections::VecDeque; use std::fmt::Formatter; use std::fs::File; use std::io::BufReader; @@ -38,8 +38,7 @@ use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::ArrowError; use arrow::ipc::reader::FileReader; use arrow_array::types::UInt64Type; -use futures::{Stream, StreamExt}; -use hashbrown::HashSet; +use futures::{ready, Stream, StreamExt}; use datafusion_common::{ exec_err, internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, @@ -52,6 +51,7 @@ use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::join_equivalence_properties; use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement}; use datafusion_physical_expr_common::sort_expr::LexRequirement; +use types::UInt32Type; use crate::expressions::PhysicalSortExpr; use crate::joins::utils::{ @@ -66,6 +66,13 @@ use crate::{ RecordBatchStream, SendableRecordBatchStream, Statistics, }; +use super::utils::{ + adjust_indices_by_join_type, + apply_join_filter_to_indices, + get_final_indices_from_bit_map, //apply_join_filter_to_indices, + //build_batch_from_indices +}; + /// join execution plan executes partitions in parallel and combines them into a set of /// partitions. #[derive(Debug)] @@ -470,34 +477,6 @@ enum SMJState { Exhausted, } -/// State of streamed data stream -#[derive(Debug, PartialEq, Eq)] -enum StreamedState { - /// Init polling - Init, - /// Polling one streamed row - Polling, - /// Ready to produce one streamed row - Ready, - /// No more streamed row - Exhausted, -} - -/// State of buffered data stream -#[derive(Debug, PartialEq, Eq)] -enum BufferedState { - /// Init polling - Init, - /// Polling first row in the next batch - PollingFirst, - /// Polling rest rows in the next batch - PollingRest, - /// Ready to produce one batch - Ready, - /// No more buffered batches - Exhausted, -} - /// Represents a chunk of joined data from streamed and buffered side struct StreamedJoinedChunk { /// Index of batch in buffered_data @@ -521,10 +500,10 @@ struct StreamedBatch { pub output_indices: Vec, /// Index of currently scanned batch from buffered data pub buffered_batch_idx: Option, - /// Indices that found a match for the given join filter - /// Used for semi joins to keep track the streaming index which got a join filter match - /// and already emitted to the output. - pub join_filter_matched_idxs: HashSet, + /// Alignment range start + pub alignment_range_start: usize, + /// Alignment range end + pub alignment_range_end: usize, } impl StreamedBatch { @@ -536,7 +515,8 @@ impl StreamedBatch { join_arrays, output_indices: vec![], buffered_batch_idx: None, - join_filter_matched_idxs: HashSet::new(), + alignment_range_start: 0, + alignment_range_end: 0, } } @@ -547,7 +527,8 @@ impl StreamedBatch { join_arrays: vec![], output_indices: vec![], buffered_batch_idx: None, - join_filter_matched_idxs: HashSet::new(), + alignment_range_start: 0, + alignment_range_end: 0, } } @@ -568,7 +549,7 @@ impl StreamedBatch { buffered_indices: UInt64Builder::with_capacity(1), }); self.buffered_batch_idx = buffered_batch_idx; - }; + } let current_chunk = self.output_indices.last_mut().unwrap(); // Append index of streamed batch and index of buffered batch into current chunk @@ -591,15 +572,11 @@ struct BufferedBatch { pub range: Range, /// Array refs of the join key pub join_arrays: Vec, - /// Buffered joined index (null joining buffered) - pub null_joined: Vec, /// Size estimation used for reserving / releasing memory pub size_estimation: usize, - /// The indices of buffered batch that failed the join filter. - /// This is a map between buffered row index and a boolean value indicating whether all joined row - /// of the buffered row failed the join filter. - /// When dequeuing the buffered batch, we need to produce null joined rows for these indices. - pub join_filter_failed_map: HashMap, + /// Contains `true` if buffered batch row has any matches with + /// streamed side + pub visited_indices_bitmap: BooleanBufferBuilder, /// Current buffered batch number of rows. Equal to batch.num_rows() /// but if batch is spilled to disk this property is preferable /// and less expensive @@ -634,13 +611,14 @@ impl BufferedBatch { + mem::size_of::(); let num_rows = batch.num_rows(); + let mut visited_indices_bitmap = BooleanBufferBuilder::new(num_rows); + visited_indices_bitmap.append_n(num_rows, false); BufferedBatch { batch: Some(batch), range, join_arrays, - null_joined: vec![], size_estimation, - join_filter_failed_map: HashMap::new(), + visited_indices_bitmap, num_rows, spill_file: None, } @@ -670,14 +648,10 @@ struct SMJStream { pub streamed_batch: StreamedBatch, /// Current buffered data pub buffered_data: BufferedData, - /// (used in outer join) Is current streamed row joined at least once? - pub streamed_joined: bool, - /// (used in outer join) Is current buffered batches joined at least once? - pub buffered_joined: bool, /// State of streamed - pub streamed_state: StreamedState, + pub streamed_exhausted: bool, /// State of buffered - pub buffered_state: BufferedState, + pub buffered_exhausted: bool, /// The comparison result of current streamed row and buffered batches pub current_ordering: Ordering, /// Join key columns of streamed @@ -723,77 +697,64 @@ impl Stream for SMJStream { loop { match &self.state { SMJState::Init => { - let streamed_exhausted = - self.streamed_state == StreamedState::Exhausted; - let buffered_exhausted = - self.buffered_state == BufferedState::Exhausted; - self.state = if streamed_exhausted && buffered_exhausted { - SMJState::Exhausted - } else { - match self.current_ordering { - Ordering::Less | Ordering::Equal => { - if !streamed_exhausted { - self.streamed_joined = false; - self.streamed_state = StreamedState::Init; - } - } - Ordering::Greater => { - if !buffered_exhausted { - self.buffered_joined = false; - self.buffered_state = BufferedState::Init; - } - } - } - SMJState::Polling - }; + ready!(self.poll_streamed_row(cx)?); + ready!(self.poll_buffered_batches(cx)?); + + if self.streamed_exhausted && self.buffered_exhausted { + self.state = SMJState::Exhausted; + continue; + } + + self.current_ordering = self.compare_streamed_buffered()?; + self.state = SMJState::JoinOutput; } SMJState::Polling => { - if ![StreamedState::Exhausted, StreamedState::Ready] - .contains(&self.streamed_state) - { - match self.poll_streamed_row(cx)? { - Poll::Ready(_) => {} - Poll::Pending => return Poll::Pending, + match self.current_ordering { + Ordering::Less | Ordering::Equal => { + if !self.streamed_exhausted { + ready!(self.poll_streamed_row(cx)?); + } } - } - - if ![BufferedState::Exhausted, BufferedState::Ready] - .contains(&self.buffered_state) - { - match self.poll_buffered_batches(cx)? { - Poll::Ready(_) => {} - Poll::Pending => return Poll::Pending, + Ordering::Greater => { + if !self.buffered_exhausted { + ready!(self.poll_buffered_batches(cx)?); + } } } - let streamed_exhausted = - self.streamed_state == StreamedState::Exhausted; - let buffered_exhausted = - self.buffered_state == BufferedState::Exhausted; - if streamed_exhausted && buffered_exhausted { + + if self.streamed_exhausted && self.buffered_exhausted { self.state = SMJState::Exhausted; continue; } + self.current_ordering = self.compare_streamed_buffered()?; self.state = SMJState::JoinOutput; } SMJState::JoinOutput => { - self.join_partial()?; + // For equal ordering, creating matched pairs of indices, + // if scanning finished -- can proceed with the next streamed row + // otherwise -- output + if self.current_ordering == Ordering::Equal { + self.match_against_buffered_data(); - if self.output_size < self.batch_size { + // Needs to complete buffered data scanning if self.buffered_data.scanning_finished() { self.buffered_data.scanning_reset(); - self.state = SMJState::Init; + self.state = SMJState::Polling; } - } else { - self.freeze_all()?; - if !self.output_record_batches.is_empty() { - let record_batch = self.output_record_batch_and_reset()?; - return Poll::Ready(Some(Ok(record_batch))); + + if self.output_size >= self.batch_size { + self.freeze_all()?; + let batch = self.output_record_batch_and_reset()?; + return Poll::Ready(Some(Ok(batch))); } - return Poll::Pending; } + + self.state = SMJState::Polling; } SMJState::Exhausted => { + self.streamed_batch.alignment_range_end = + self.streamed_batch.batch.num_rows(); self.freeze_all()?; if !self.output_record_batches.is_empty() { let record_batch = self.output_record_batch_and_reset()?; @@ -836,10 +797,8 @@ impl SMJStream { buffered, streamed_batch: StreamedBatch::new_empty(streamed_schema), buffered_data: BufferedData::default(), - streamed_joined: false, - buffered_joined: false, - streamed_state: StreamedState::Init, - buffered_state: BufferedState::Init, + streamed_exhausted: false, + buffered_exhausted: false, current_ordering: Ordering::Equal, on_streamed, on_buffered, @@ -856,40 +815,30 @@ impl SMJStream { /// Poll next streamed row fn poll_streamed_row(&mut self, cx: &mut Context) -> Poll>> { + if self.streamed_batch.idx + 1 < self.streamed_batch.batch.num_rows() { + self.streamed_batch.idx += 1; + return Poll::Ready(Some(Ok(()))); + } + loop { - match &self.streamed_state { - StreamedState::Init => { - if self.streamed_batch.idx + 1 < self.streamed_batch.batch.num_rows() - { - self.streamed_batch.idx += 1; - self.streamed_state = StreamedState::Ready; - return Poll::Ready(Some(Ok(()))); - } else { - self.streamed_state = StreamedState::Polling; - } - } - StreamedState::Polling => match self.streamed.poll_next_unpin(cx)? { - Poll::Pending => { - return Poll::Pending; - } - Poll::Ready(None) => { - self.streamed_state = StreamedState::Exhausted; - } - Poll::Ready(Some(batch)) => { - if batch.num_rows() > 0 { + match ready!(self.streamed.poll_next_unpin(cx)) { + Some(Ok(batch)) => { + if batch.num_rows() > 0 { + if self.streamed_batch.batch.num_rows() != 0 { + self.streamed_batch.alignment_range_end = + self.streamed_batch.batch.num_rows(); self.freeze_streamed()?; - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - self.streamed_batch = - StreamedBatch::new(batch, &self.on_streamed); - self.streamed_state = StreamedState::Ready; } + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + self.streamed_batch = + StreamedBatch::new(batch, &self.on_streamed); + return Poll::Ready(Some(Ok(()))); } - }, - StreamedState::Ready => { - return Poll::Ready(Some(Ok(()))); } - StreamedState::Exhausted => { + Some(Err(err)) => return Poll::Ready(Some(Err(err))), + None => { + self.streamed_exhausted = true; return Poll::Ready(None); } } @@ -950,115 +899,90 @@ impl SMJStream { /// Poll next buffered batches fn poll_buffered_batches(&mut self, cx: &mut Context) -> Poll>> { - loop { - match &self.buffered_state { - BufferedState::Init => { - // pop previous buffered batches - while !self.buffered_data.batches.is_empty() { - let head_batch = self.buffered_data.head_batch(); - // If the head batch is fully processed, dequeue it and produce output of it. - if head_batch.range.end == head_batch.num_rows { - self.freeze_dequeuing_buffered()?; - if let Some(buffered_batch) = - self.buffered_data.batches.pop_front() - { - self.free_reservation(buffered_batch)?; - } - } else { - // If the head batch is not fully processed, break the loop. - // Streamed batch will be joined with the head batch in the next step. - break; - } + // pop previous buffered batches + while !self.buffered_data.batches.is_empty() { + let head_batch = self.buffered_data.head_batch(); + // If the head batch is fully processed, dequeue it and produce output of it. + if head_batch.range.end == head_batch.num_rows { + self.freeze_dequeuing_buffered()?; + if let Some(buffered_batch) = self.buffered_data.batches.pop_front() { + self.free_reservation(buffered_batch)?; + } + } else { + // If the head batch is not fully processed, break the loop. + // Streamed batch will be joined with the head batch in the next step. + break; + } + } + + if self.buffered_data.batches.is_empty() { + match ready!(self.buffered.poll_next_unpin(cx)?) { + Some(batch) => { + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + + if batch.num_rows() > 0 { + let buffered_batch = + BufferedBatch::new(batch, 0..1, &self.on_buffered); + + self.allocate_reservation(buffered_batch)?; } - if self.buffered_data.batches.is_empty() { - self.buffered_state = BufferedState::PollingFirst; + } + None => { + self.buffered_exhausted = true; + return Poll::Ready(None); + } + } + } else { + let tail_batch = self.buffered_data.tail_batch_mut(); + tail_batch.range.start = tail_batch.range.end; + tail_batch.range.end += 1; + } + + loop { + if self.buffered_data.tail_batch().range.end + < self.buffered_data.tail_batch().num_rows + { + while self.buffered_data.tail_batch().range.end + < self.buffered_data.tail_batch().num_rows + { + if is_join_arrays_equal( + &self.buffered_data.head_batch().join_arrays, + self.buffered_data.head_batch().range.start, + &self.buffered_data.tail_batch().join_arrays, + self.buffered_data.tail_batch().range.end, + )? { + self.buffered_data.tail_batch_mut().range.end += 1; } else { - let tail_batch = self.buffered_data.tail_batch_mut(); - tail_batch.range.start = tail_batch.range.end; - tail_batch.range.end += 1; - self.buffered_state = BufferedState::PollingRest; + return Poll::Ready(Some(Ok(()))); } } - BufferedState::PollingFirst => match self.buffered.poll_next_unpin(cx)? { - Poll::Pending => { - return Poll::Pending; - } - Poll::Ready(None) => { - self.buffered_state = BufferedState::Exhausted; - return Poll::Ready(None); + } else { + match ready!(self.buffered.poll_next_unpin(cx)?) { + None => { + return Poll::Ready(Some(Ok(()))); } - Poll::Ready(Some(batch)) => { + Some(batch) => { + // Polling batches coming concurrently as multiple partitions self.join_metrics.input_batches.add(1); self.join_metrics.input_rows.add(batch.num_rows()); - if batch.num_rows() > 0 { let buffered_batch = - BufferedBatch::new(batch, 0..1, &self.on_buffered); - + BufferedBatch::new(batch, 0..0, &self.on_buffered); self.allocate_reservation(buffered_batch)?; - self.buffered_state = BufferedState::PollingRest; } } - }, - BufferedState::PollingRest => { - if self.buffered_data.tail_batch().range.end - < self.buffered_data.tail_batch().num_rows - { - while self.buffered_data.tail_batch().range.end - < self.buffered_data.tail_batch().num_rows - { - if is_join_arrays_equal( - &self.buffered_data.head_batch().join_arrays, - self.buffered_data.head_batch().range.start, - &self.buffered_data.tail_batch().join_arrays, - self.buffered_data.tail_batch().range.end, - )? { - self.buffered_data.tail_batch_mut().range.end += 1; - } else { - self.buffered_state = BufferedState::Ready; - return Poll::Ready(Some(Ok(()))); - } - } - } else { - match self.buffered.poll_next_unpin(cx)? { - Poll::Pending => { - return Poll::Pending; - } - Poll::Ready(None) => { - self.buffered_state = BufferedState::Ready; - } - Poll::Ready(Some(batch)) => { - // Polling batches coming concurrently as multiple partitions - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - if batch.num_rows() > 0 { - let buffered_batch = BufferedBatch::new( - batch, - 0..0, - &self.on_buffered, - ); - self.allocate_reservation(buffered_batch)?; - } - } - } - } - } - BufferedState::Ready => { - return Poll::Ready(Some(Ok(()))); } - BufferedState::Exhausted => { - return Poll::Ready(None); - } - } + }; } } /// Get comparison result of streamed row and buffered batches fn compare_streamed_buffered(&self) -> Result { - if self.streamed_state == StreamedState::Exhausted { + if self.streamed_exhausted { return Ok(Ordering::Greater); } - if !self.buffered_data.has_buffered_rows() { + if self.buffered_exhausted { return Ok(Ordering::Less); } @@ -1072,123 +996,30 @@ impl SMJStream { ); } - /// Produce join and fill output buffer until reaching target batch size - /// or the join is finished - fn join_partial(&mut self) -> Result<()> { - // Whether to join streamed rows - let mut join_streamed = false; - // Whether to join buffered rows - let mut join_buffered = false; - - // determine whether we need to join streamed/buffered rows - match self.current_ordering { - Ordering::Less => { - if matches!( - self.join_type, - JoinType::Left - | JoinType::Right - | JoinType::RightSemi - | JoinType::Full - | JoinType::LeftAnti - ) { - join_streamed = !self.streamed_joined; - } - } - Ordering::Equal => { - if matches!(self.join_type, JoinType::LeftSemi) { - // if the join filter is specified then its needed to output the streamed index - // only if it has not been emitted before - // the `join_filter_matched_idxs` keeps track on if streamed index has a successful - // filter match and prevents the same index to go into output more than once - if self.filter.is_some() { - join_streamed = !self - .streamed_batch - .join_filter_matched_idxs - .contains(&(self.streamed_batch.idx as u64)) - && !self.streamed_joined; - // if the join filter specified there can be references to buffered columns - // so buffered columns are needed to access them - join_buffered = join_streamed; - } else { - join_streamed = !self.streamed_joined; - } - } - if matches!( - self.join_type, - JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full - ) { - join_streamed = true; - join_buffered = true; - }; + fn match_against_buffered_data(&mut self) { + while !self.buffered_data.scanning_finished() + && self.output_size < self.batch_size + { + let scanning_idx = self.buffered_data.scanning_idx(); + self.streamed_batch.append_output_pair( + Some(self.buffered_data.scanning_batch_idx), + Some(scanning_idx), + ); - if matches!(self.join_type, JoinType::LeftAnti) && self.filter.is_some() { - join_streamed = !self - .streamed_batch - .join_filter_matched_idxs - .contains(&(self.streamed_batch.idx as u64)) - && !self.streamed_joined; - join_buffered = join_streamed; - } - } - Ordering::Greater => { - if matches!(self.join_type, JoinType::Full) { - join_buffered = !self.buffered_joined; - }; - } - } - if !join_streamed && !join_buffered { - // no joined data - self.buffered_data.scanning_finish(); - return Ok(()); + self.output_size += 1; + self.buffered_data.scanning_advance(); } - if join_buffered { - // joining streamed/nulls and buffered - while !self.buffered_data.scanning_finished() - && self.output_size < self.batch_size - { - let scanning_idx = self.buffered_data.scanning_idx(); - if join_streamed { - // Join streamed row and buffered row - self.streamed_batch.append_output_pair( - Some(self.buffered_data.scanning_batch_idx), - Some(scanning_idx), - ); - } else { - // Join nulls and buffered row for FULL join - self.buffered_data - .scanning_batch_mut() - .null_joined - .push(scanning_idx); - } - self.output_size += 1; - self.buffered_data.scanning_advance(); - - if self.buffered_data.scanning_finished() { - self.streamed_joined = join_streamed; - self.buffered_joined = true; - } - } - } else { - // joining streamed and nulls - let scanning_batch_idx = if self.buffered_data.scanning_finished() { - None + self.streamed_batch.alignment_range_end = + if self.buffered_data.scanning_finished() { + self.streamed_batch.idx + 1 } else { - Some(self.buffered_data.scanning_batch_idx) + self.streamed_batch.idx }; - - self.streamed_batch - .append_output_pair(scanning_batch_idx, None); - self.output_size += 1; - self.buffered_data.scanning_finish(); - self.streamed_joined = true; - } - Ok(()) } fn freeze_all(&mut self) -> Result<()> { self.freeze_streamed()?; - self.freeze_buffered(self.buffered_data.batches.len(), false)?; Ok(()) } @@ -1199,7 +1030,7 @@ impl SMJStream { fn freeze_dequeuing_buffered(&mut self) -> Result<()> { self.freeze_streamed()?; // Only freeze and produce the first batch in buffered_data as the batch is fully processed - self.freeze_buffered(1, true)?; + self.freeze_buffered(1)?; Ok(()) } @@ -1210,18 +1041,16 @@ impl SMJStream { // // If `output_not_matched_filter` is true, this will also produce record batches // for buffered rows which are joined with streamed side but don't match join filter. - fn freeze_buffered( - &mut self, - batch_count: usize, - output_not_matched_filter: bool, - ) -> Result<()> { + fn freeze_buffered(&mut self, batch_count: usize) -> Result<()> { if !matches!(self.join_type, JoinType::Full) { return Ok(()); } for buffered_batch in self.buffered_data.batches.range_mut(..batch_count) { - let buffered_indices = UInt64Array::from_iter_values( - buffered_batch.null_joined.iter().map(|&index| index as u64), + let (buffered_indices, _) = get_final_indices_from_bit_map( + &buffered_batch.visited_indices_bitmap, + self.join_type, ); + if let Some(record_batch) = produce_buffered_null_batch( &self.schema, &self.streamed_schema, @@ -1230,31 +1059,6 @@ impl SMJStream { )? { self.output_record_batches.push(record_batch); } - buffered_batch.null_joined.clear(); - - // For buffered row which is joined with streamed side rows but all joined rows - // don't satisfy the join filter - if output_not_matched_filter { - let not_matched_buffered_indices = buffered_batch - .join_filter_failed_map - .iter() - .filter_map(|(idx, failed)| if *failed { Some(*idx) } else { None }) - .collect::>(); - - let buffered_indices = UInt64Array::from_iter_values( - not_matched_buffered_indices.iter().copied(), - ); - - if let Some(record_batch) = produce_buffered_null_batch( - &self.schema, - &self.streamed_schema, - &buffered_indices, - buffered_batch, - )? { - self.output_record_batches.push(record_batch); - } - buffered_batch.join_filter_failed_map.clear(); - } } Ok(()) } @@ -1262,13 +1066,63 @@ impl SMJStream { // Produces and stages record batch for all output indices found // for current streamed batch and clears staged output indices. fn freeze_streamed(&mut self) -> Result<()> { - for chunk in self.streamed_batch.output_indices.iter_mut() { - // The row indices of joined streamed batch - let streamed_indices = chunk.streamed_indices.finish(); + let alignment_join_type = match self.join_type { + JoinType::Left => JoinType::Right, + JoinType::LeftSemi => JoinType::RightSemi, + JoinType::LeftAnti => JoinType::RightAnti, + _ => self.join_type, + }; + let buffered_side = match SortMergeJoinExec::probe_side(&self.join_type) { + JoinSide::Left => JoinSide::Right, + JoinSide::Right => JoinSide::Left, + }; - if streamed_indices.is_empty() { - continue; - } + for chunk_idx in 0..self.streamed_batch.output_indices.len() { + let chunk = &mut self.streamed_batch.output_indices[chunk_idx]; + let streamed_indices = + compute::cast(&chunk.streamed_indices.finish(), &DataType::UInt32)?; + let streamed_indices = + streamed_indices.as_primitive::().to_owned(); + let buffered_indices = chunk.buffered_indices.finish(); + + let (buffered_indices, streamed_indices) = if let Some(filter) = &self.filter + { + apply_join_filter_to_indices( + self.buffered_data.batches[chunk.buffered_batch_idx.unwrap()] + .batch + .as_ref() + .unwrap(), + &self.streamed_batch.batch, + buffered_indices, + streamed_indices, + filter, + buffered_side, + )? + } else { + (buffered_indices, streamed_indices) + }; + + let alignment_range_end = match streamed_indices.len() { + 0 => self.streamed_batch.alignment_range_start, + n => 1 + streamed_indices.value(n - 1) as usize, + }; + + let (buffered_indices, streamed_indices) = adjust_indices_by_join_type( + buffered_indices, + streamed_indices, + self.streamed_batch.alignment_range_start..alignment_range_end, + alignment_join_type, + true, + ); + + let visited_bitmap = &mut self.buffered_data.batches + [chunk.buffered_batch_idx.unwrap()] + .visited_indices_bitmap; + buffered_indices.iter().flatten().for_each(|x| { + visited_bitmap.set_bit(x as usize, true); + }); + + self.streamed_batch.alignment_range_start = alignment_range_end; let mut streamed_columns = self .streamed_batch @@ -1278,8 +1132,6 @@ impl SMJStream { .map(|column| take(column, &streamed_indices, None)) .collect::, ArrowError>>()?; - // The row indices of joined buffered batch - let buffered_indices: UInt64Array = chunk.buffered_indices.finish(); let mut buffered_columns = if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { vec![] @@ -1299,35 +1151,6 @@ impl SMJStream { .collect::>() }; - let streamed_columns_length = streamed_columns.len(); - let buffered_columns_length = buffered_columns.len(); - - // Prepare the columns we apply join filter on later. - // Only for joined rows between streamed and buffered. - let filter_columns = if chunk.buffered_batch_idx.is_some() { - if matches!(self.join_type, JoinType::Right) { - get_filter_column(&self.filter, &buffered_columns, &streamed_columns) - } else if matches!( - self.join_type, - JoinType::LeftSemi | JoinType::LeftAnti - ) { - // unwrap is safe here as we check is_some on top of if statement - let buffered_columns = get_buffered_columns( - &self.buffered_data, - chunk.buffered_batch_idx.unwrap(), - &buffered_indices, - )?; - - get_filter_column(&self.filter, &streamed_columns, &buffered_columns) - } else { - get_filter_column(&self.filter, &streamed_columns, &buffered_columns) - } - } else { - // This chunk is totally for null joined rows (outer join), we don't need to apply join filter. - // Any join filter applied only on either streamed or buffered side will be pushed already. - vec![] - }; - let columns = if matches!(self.join_type, JoinType::Right) { buffered_columns.extend(streamed_columns.clone()); buffered_columns @@ -1339,166 +1162,62 @@ impl SMJStream { let output_batch = RecordBatch::try_new(Arc::clone(&self.schema), columns.clone())?; - // Apply join filter if any - if !filter_columns.is_empty() { - if let Some(f) = &self.filter { - // Construct batch with only filter columns - let filter_batch = RecordBatch::try_new( - Arc::new(f.schema().clone()), - filter_columns, - )?; + self.output_record_batches.push(output_batch); + } - let filter_result = f - .expression() - .evaluate(&filter_batch)? - .into_array(filter_batch.num_rows())?; - - // The boolean selection mask of the join filter result - let pre_mask = - datafusion_common::cast::as_boolean_array(&filter_result)?; - - // If there are nulls in join filter result, exclude them from selecting - // the rows to output. - let mask = if pre_mask.null_count() > 0 { - compute::prep_null_mask_filter( - datafusion_common::cast::as_boolean_array(&filter_result)?, - ) - } else { - pre_mask.clone() - }; - - // For certain join types, we need to adjust the initial mask to handle the join filter. - let maybe_filtered_join_mask: Option<(BooleanArray, Vec)> = - get_filtered_join_mask( - self.join_type, - &streamed_indices, - &mask, - &self.streamed_batch.join_filter_matched_idxs, - &self.buffered_data.scanning_offset, - ); - - let mask = - if let Some(ref filtered_join_mask) = maybe_filtered_join_mask { - self.streamed_batch - .join_filter_matched_idxs - .extend(&filtered_join_mask.1); - &filtered_join_mask.0 - } else { - &mask - }; + if self.streamed_batch.alignment_range_start + < self.streamed_batch.alignment_range_end + { + let streamed_indices = UInt32Builder::new().finish(); + let buffered_indices = UInt64Builder::new().finish(); + + let (buffered_indices, streamed_indices) = adjust_indices_by_join_type( + buffered_indices, + streamed_indices, + self.streamed_batch.alignment_range_start + ..self.streamed_batch.alignment_range_end, + alignment_join_type, + true, + ); - // Push the filtered batch which contains rows passing join filter to the output - let filtered_batch = - compute::filter_record_batch(&output_batch, mask)?; - self.output_record_batches.push(filtered_batch); - - // For outer joins, we need to push the null joined rows to the output if - // all joined rows are failed on the join filter. - // I.e., if all rows joined from a streamed row are failed with the join filter, - // we need to join it with nulls as buffered side. - if matches!( - self.join_type, - JoinType::Left | JoinType::Right | JoinType::Full - ) { - // We need to get the mask for row indices that the joined rows are failed - // on the join filter. I.e., for a row in streamed side, if all joined rows - // between it and all buffered rows are failed on the join filter, we need to - // output it with null columns from buffered side. For the mask here, it - // behaves like LeftAnti join. - let null_mask: BooleanArray = get_filtered_join_mask( - // Set a mask slot as true only if all joined rows of same streamed index - // are failed on the join filter. - // The masking behavior is like LeftAnti join. - JoinType::LeftAnti, - &streamed_indices, - mask, - &self.streamed_batch.join_filter_matched_idxs, - &self.buffered_data.scanning_offset, - ) - .unwrap() - .0; - - let null_joined_batch = - compute::filter_record_batch(&output_batch, &null_mask)?; - - let mut buffered_columns = self - .buffered_schema - .fields() - .iter() - .map(|f| { - new_null_array( - f.data_type(), - null_joined_batch.num_rows(), - ) - }) - .collect::>(); - - let columns = if matches!(self.join_type, JoinType::Right) { - let streamed_columns = null_joined_batch - .columns() - .iter() - .skip(buffered_columns_length) - .cloned() - .collect::>(); - - buffered_columns.extend(streamed_columns); - buffered_columns - } else { - // Left join or full outer join - let mut streamed_columns = null_joined_batch - .columns() - .iter() - .take(streamed_columns_length) - .cloned() - .collect::>(); - - streamed_columns.extend(buffered_columns); - streamed_columns - }; + self.streamed_batch.alignment_range_start = + self.streamed_batch.alignment_range_end; + if streamed_indices.is_empty() { + return Ok(()); + } + let mut streamed_columns = self + .streamed_batch + .batch + .columns() + .iter() + .map(|column| take(column, &streamed_indices, None)) + .collect::, ArrowError>>()?; - // Push the streamed/buffered batch joined nulls to the output - let null_joined_streamed_batch = RecordBatch::try_new( - Arc::clone(&self.schema), - columns.clone(), - )?; - self.output_record_batches.push(null_joined_streamed_batch); - - // For full join, we also need to output the null joined rows from the buffered side. - // Usually this is done by `freeze_buffered`. However, if a buffered row is joined with - // streamed side, it won't be outputted by `freeze_buffered`. - // We need to check if a buffered row is joined with streamed side and output. - // If it is joined with streamed side, but doesn't match the join filter, - // we need to output it with nulls as streamed side. - if matches!(self.join_type, JoinType::Full) { - let buffered_batch = &mut self.buffered_data.batches - [chunk.buffered_batch_idx.unwrap()]; - - for i in 0..pre_mask.len() { - // If the buffered row is not joined with streamed side, - // skip it. - if buffered_indices.is_null(i) { - continue; - } - - let buffered_index = buffered_indices.value(i); - - buffered_batch.join_filter_failed_map.insert( - buffered_index, - *buffered_batch - .join_filter_failed_map - .get(&buffered_index) - .unwrap_or(&true) - && !pre_mask.value(i), - ); - } - } - } + let mut buffered_columns = + if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { + vec![] } else { - self.output_record_batches.push(output_batch); - } + // If buffered batch none, meaning it is null joined batch. + // We need to create null arrays for buffered columns to join with streamed rows. + self.buffered_schema + .fields() + .iter() + .map(|f| new_null_array(f.data_type(), buffered_indices.len())) + .collect::>() + }; + + let columns = if matches!(self.join_type, JoinType::Right) { + buffered_columns.extend(streamed_columns.clone()); + buffered_columns } else { - self.output_record_batches.push(output_batch); - } + streamed_columns.extend(buffered_columns); + streamed_columns + }; + + let output_batch = + RecordBatch::try_new(Arc::clone(&self.schema), columns.clone())?; + + self.output_record_batches.push(output_batch); } self.streamed_batch.output_indices.clear(); @@ -1525,36 +1244,6 @@ impl SMJStream { } } -/// Gets the arrays which join filters are applied on. -fn get_filter_column( - join_filter: &Option, - streamed_columns: &[ArrayRef], - buffered_columns: &[ArrayRef], -) -> Vec { - let mut filter_columns = vec![]; - - if let Some(f) = join_filter { - let left_columns = f - .column_indices() - .iter() - .filter(|col_index| col_index.side == JoinSide::Left) - .map(|i| Arc::clone(&streamed_columns[i.index])) - .collect::>(); - - let right_columns = f - .column_indices() - .iter() - .filter(|col_index| col_index.side == JoinSide::Right) - .map(|i| Arc::clone(&buffered_columns[i.index])) - .collect::>(); - - filter_columns.extend(left_columns); - filter_columns.extend(right_columns); - } - - filter_columns -} - fn produce_buffered_null_batch( schema: &SchemaRef, streamed_schema: &SchemaRef, @@ -1631,101 +1320,6 @@ fn get_buffered_columns_from_batch( } } -/// Calculate join filter bit mask considering join type specifics -/// `streamed_indices` - array of streamed datasource JOINED row indices -/// `mask` - array booleans representing computed join filter expression eval result: -/// true = the row index matches the join filter -/// false = the row index doesn't match the join filter -/// `streamed_indices` have the same length as `mask` -/// `matched_indices` array of streaming indices that already has a join filter match -/// `scanning_buffered_offset` current buffered offset across batches -/// -/// This return a tuple of: -/// - corrected mask with respect to the join type -/// - indices of rows in streamed batch that have a join filter match -fn get_filtered_join_mask( - join_type: JoinType, - streamed_indices: &UInt64Array, - mask: &BooleanArray, - matched_indices: &HashSet, - scanning_buffered_offset: &usize, -) -> Option<(BooleanArray, Vec)> { - let mut seen_as_true: bool = false; - let streamed_indices_length = streamed_indices.len(); - let mut corrected_mask: BooleanBuilder = - BooleanBuilder::with_capacity(streamed_indices_length); - - let mut filter_matched_indices: Vec = vec![]; - - #[allow(clippy::needless_range_loop)] - match join_type { - // for LeftSemi Join the filter mask should be calculated in its own way: - // if we find at least one matching row for specific streaming index - // we don't need to check any others for the same index - JoinType::LeftSemi => { - // have we seen a filter match for a streaming index before - for i in 0..streamed_indices_length { - // LeftSemi respects only first true values for specific streaming index, - // others true values for the same index must be false - let streamed_idx = streamed_indices.value(i); - if mask.value(i) - && !seen_as_true - && !matched_indices.contains(&streamed_idx) - { - seen_as_true = true; - corrected_mask.append_value(true); - filter_matched_indices.push(streamed_idx); - } else { - corrected_mask.append_value(false); - } - - // if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag - if i < streamed_indices_length - 1 - && streamed_idx != streamed_indices.value(i + 1) - { - seen_as_true = false; - } - } - Some((corrected_mask.finish(), filter_matched_indices)) - } - // LeftAnti semantics: return true if for every x in the collection the join matching filter is false. - // `filter_matched_indices` needs to be set once per streaming index - // to prevent duplicates in the output - JoinType::LeftAnti => { - // have we seen a filter match for a streaming index before - for i in 0..streamed_indices_length { - let streamed_idx = streamed_indices.value(i); - if mask.value(i) - && !seen_as_true - && !matched_indices.contains(&streamed_idx) - { - seen_as_true = true; - filter_matched_indices.push(streamed_idx); - } - - // Reset `seen_as_true` flag and calculate mask for the current streaming index - // - if within the batch it switched to next streaming index(e.g. from 0 to 1, or from 1 to 2) - // - if it is at the end of the all buffered batches for the given streaming index, 0 index comes last - if (i < streamed_indices_length - 1 - && streamed_idx != streamed_indices.value(i + 1)) - || (i == streamed_indices_length - 1 - && *scanning_buffered_offset == 0) - { - corrected_mask.append_value( - !matched_indices.contains(&streamed_idx) && !seen_as_true, - ); - seen_as_true = false; - } else { - corrected_mask.append_value(false); - } - } - - Some((corrected_mask.finish(), filter_matched_indices)) - } - _ => None, - } -} - /// Buffered data contains all buffered batches with one unique join key #[derive(Debug, Default)] struct BufferedData { @@ -1750,10 +1344,6 @@ impl BufferedData { self.batches.back_mut().unwrap() } - pub fn has_buffered_rows(&self) -> bool { - self.batches.iter().any(|batch| !batch.range.is_empty()) - } - pub fn scanning_reset(&mut self) { self.scanning_batch_idx = 0; self.scanning_offset = 0; @@ -1771,10 +1361,6 @@ impl BufferedData { &self.batches[self.scanning_batch_idx] } - pub fn scanning_batch_mut(&mut self) -> &mut BufferedBatch { - &mut self.batches[self.scanning_batch_idx] - } - pub fn scanning_idx(&self) -> usize { self.scanning_batch().range.start + self.scanning_offset } @@ -1786,11 +1372,6 @@ impl BufferedData { pub fn scanning_finished(&self) -> bool { self.scanning_batch_idx == self.batches.len() } - - pub fn scanning_finish(&mut self) { - self.scanning_batch_idx = self.batches.len(); - self.scanning_offset = 0; - } } /// Get join array refs of given batch and join columns @@ -1969,21 +1550,21 @@ mod tests { use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; - use arrow_array::{BooleanArray, UInt64Array}; - use hashbrown::HashSet; - use datafusion_common::JoinType::{LeftAnti, LeftSemi}; use datafusion_common::{ - assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result, + assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinSide, JoinType, + Result, }; use datafusion_execution::config::SessionConfig; use datafusion_execution::disk_manager::DiskManagerConfig; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_execution::TaskContext; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::BinaryExpr; + use datafusion_physical_expr::PhysicalExpr; use crate::expressions::Column; - use crate::joins::sort_merge_join::get_filtered_join_mask; - use crate::joins::utils::JoinOn; + use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn}; use crate::joins::SortMergeJoinExec; use crate::memory::MemoryExec; use crate::test::build_table_i32; @@ -2087,25 +1668,6 @@ mod tests { SortMergeJoinExec::try_new(left, right, on, None, join_type, sort_options, false) } - fn join_with_options( - left: Arc, - right: Arc, - on: JoinOn, - join_type: JoinType, - sort_options: Vec, - null_equals_null: bool, - ) -> Result { - SortMergeJoinExec::try_new( - left, - right, - on, - None, - join_type, - sort_options, - null_equals_null, - ) - } - async fn join_collect( left: Arc, right: Arc, @@ -2113,22 +1675,26 @@ mod tests { join_type: JoinType, ) -> Result<(Vec, Vec)> { let sort_options = vec![SortOptions::default(); on.len()]; - join_collect_with_options(left, right, on, join_type, sort_options, false).await + join_collect_with_options(left, right, on, None, join_type, sort_options, false) + .await } async fn join_collect_with_options( left: Arc, right: Arc, on: JoinOn, + filter: Option, join_type: JoinType, sort_options: Vec, null_equals_null: bool, ) -> Result<(Vec, Vec)> { - let task_ctx = Arc::new(TaskContext::default()); - let join = join_with_options( + let config = SessionConfig::default().with_batch_size(3); + let task_ctx = Arc::new(TaskContext::default().with_session_config(config)); + let join = SortMergeJoinExec::try_new( left, right, on, + filter, join_type, sort_options, null_equals_null, @@ -2157,6 +1723,30 @@ mod tests { Ok((columns, batches)) } + fn prepare_join_filter() -> JoinFilter { + let column_indices = vec![ + ColumnIndex { + index: 2, + side: JoinSide::Left, + }, + ColumnIndex { + index: 2, + side: JoinSide::Right, + }, + ]; + let intermediate_schema = Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Int32, true), + ]); + let filter_expression = Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 0)), + Operator::Gt, + Arc::new(Column::new("c", 1)), + )) as Arc; + + JoinFilter::new(filter_expression, column_indices, intermediate_schema) + } + #[tokio::test] async fn join_inner_one() -> Result<()> { let left = build_table( @@ -2332,6 +1922,7 @@ mod tests { left, right, on, + None, JoinType::Inner, vec![ SortOptions { @@ -2821,6 +2412,96 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_right_buffered_batch_no_match() -> Result<()> { + let left_batch_1 = build_table_i32( + ("a1", &vec![0, 1]), + ("b1", &vec![1, 3]), + ("c1", &vec![4, 5]), + ); + let left_batch_2 = build_table_i32( + ("a1", &vec![3, 5, 6]), + ("b1", &vec![6, 7, 8]), + ("c1", &vec![7, 9, 9]), + ); + let left_batch_3 = build_table_i32( + ("a1", &vec![3, 5]), + ("b1", &vec![9, 11]), + ("c1", &vec![7, 9]), + ); + let right_batch_1 = build_table_i32( + ("a2", &vec![0, 10, 20, 30, 40]), + ("b2", &vec![2, 2, 7, 10, 10]), + ("c2", &vec![50, 60, 70, 80, 90]), + ); + let left = + build_table_from_batches(vec![left_batch_1, left_batch_2, left_batch_3]); + let right = build_table_from_batches(vec![right_batch_1]); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, JoinType::Right).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| | | | 0 | 2 | 50 |", + "| | | | 10 | 2 | 60 |", + "| 5 | 7 | 9 | 20 | 7 | 70 |", + "| | | | 30 | 10 | 80 |", + "| | | | 40 | 10 | 90 |", + "+----+----+----+----+----+----+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_last_row_match() -> Result<()> { + let left_batch_1 = build_table_i32( + ("a1", &vec![0, 1]), + ("b1", &vec![1, 3]), + ("c1", &vec![4, 5]), + ); + let left_batch_2 = build_table_i32( + ("a1", &vec![3, 6]), + ("b1", &vec![6, 8]), + ("c1", &vec![7, 9]), + ); + let right_batch_1 = build_table_i32( + ("a2", &vec![0, 10]), + ("b2", &vec![2, 4]), + ("c2", &vec![50, 60]), + ); + let right_batch_2 = build_table_i32( + ("a2", &vec![20, 40]), + ("b2", &vec![5, 8]), + ("c2", &vec![70, 90]), + ); + let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); + let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, JoinType::Right).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| | | | 0 | 2 | 50 |", + "| | | | 10 | 4 | 60 |", + "| | | | 20 | 5 | 70 |", + "| 6 | 8 | 9 | 40 | 8 | 90 |", + "+----+----+----+----+----+----+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) + } + #[tokio::test] async fn join_full_multiple_batches() -> Result<()> { let left_batch_1 = build_table_i32( @@ -2872,6 +2553,53 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_right_with_filter() -> Result<()> { + let left_batch = build_table_i32( + ("a1", &vec![0, 0, 0, 0, 0, 0]), + ("b1", &vec![0, 1, 2, 2, 3, 4]), + ("c1", &vec![0, 0, 5, 0, 0, 0]), + ); + let right_batch = build_table_i32( + ("a2", &vec![0, 0, 0, 0, 0, 0]), + ("b2", &vec![0, 1, 2, 3, 4, 5]), + ("c2", &vec![1, 1, 1, 1, 1, 1]), + ); + let left = build_table_from_batches(vec![left_batch]); + let right = build_table_from_batches(vec![right_batch]); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + + let filter = prepare_join_filter(); + + let (_, batches) = join_collect_with_options( + left, + right, + on, + Some(filter), + JoinType::Right, + vec![SortOptions::default()], + false, + ) + .await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| | | | 0 | 0 | 1 |", + "| | | | 0 | 1 | 1 |", + "| 0 | 2 | 5 | 0 | 2 | 1 |", + "| | | | 0 | 3 | 1 |", + "| | | | 0 | 4 | 1 |", + "| | | | 0 | 5 | 1 |", + "+----+----+----+----+----+----+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) + } + #[tokio::test] async fn overallocation_single_batch_no_spill() -> Result<()> { let left = build_table( @@ -2912,10 +2640,11 @@ mod tests { .with_runtime(Arc::clone(&runtime)); let task_ctx = Arc::new(task_ctx); - let join = join_with_options( + let join = SortMergeJoinExec::try_new( Arc::clone(&left), Arc::clone(&right), on.clone(), + None, join_type, sort_options.clone(), false, @@ -2996,10 +2725,11 @@ mod tests { .with_session_config(session_config.clone()) .with_runtime(Arc::clone(&runtime)); let task_ctx = Arc::new(task_ctx); - let join = join_with_options( + let join = SortMergeJoinExec::try_new( Arc::clone(&left), Arc::clone(&right), on.clone(), + None, join_type, sort_options.clone(), false, @@ -3062,10 +2792,11 @@ mod tests { .with_runtime(Arc::clone(&runtime)); let task_ctx = Arc::new(task_ctx); - let join = join_with_options( + let join = SortMergeJoinExec::try_new( Arc::clone(&left), Arc::clone(&right), on.clone(), + None, *join_type, sort_options.clone(), false, @@ -3084,10 +2815,11 @@ mod tests { TaskContext::default().with_session_config(session_config.clone()); let task_ctx_no_spill = Arc::new(task_ctx_no_spill); - let join = join_with_options( + let join = SortMergeJoinExec::try_new( Arc::clone(&left), Arc::clone(&right), on.clone(), + None, *join_type, sort_options.clone(), false, @@ -3169,10 +2901,11 @@ mod tests { .with_session_config(session_config.clone()) .with_runtime(Arc::clone(&runtime)); let task_ctx = Arc::new(task_ctx); - let join = join_with_options( + let join = SortMergeJoinExec::try_new( Arc::clone(&left), Arc::clone(&right), on.clone(), + None, *join_type, sort_options.clone(), false, @@ -3190,10 +2923,11 @@ mod tests { TaskContext::default().with_session_config(session_config.clone()); let task_ctx_no_spill = Arc::new(task_ctx_no_spill); - let join = join_with_options( + let join = SortMergeJoinExec::try_new( Arc::clone(&left), Arc::clone(&right), on.clone(), + None, *join_type, sort_options.clone(), false, @@ -3213,174 +2947,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn left_semi_join_filtered_mask() -> Result<()> { - assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 1, 1]), - &BooleanArray::from(vec![true, true, false, false]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![true, false, false, false]), vec![0])) - ); - - assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![true, true]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![true, true]), vec![0, 1])) - ); - - assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![false, true]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![false, true]), vec![1])) - ); - - assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![true, false]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![true, false]), vec![0])) - ); - - assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![false, true, true, true, true, true]), - &HashSet::new(), - &0, - ), - Some(( - BooleanArray::from(vec![false, true, false, true, false, false]), - vec![0, 1] - )) - ); - - assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![false, false, false, false, false, true]), - &HashSet::new(), - &0, - ), - Some(( - BooleanArray::from(vec![false, false, false, false, false, true]), - vec![1] - )) - ); - - assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![true, false, false, false, false, true]), - &HashSet::from_iter(vec![1]), - &0, - ), - Some(( - BooleanArray::from(vec![true, false, false, false, false, false]), - vec![0] - )) - ); - - Ok(()) - } - - #[tokio::test] - async fn left_anti_join_filtered_mask() -> Result<()> { - assert_eq!( - get_filtered_join_mask( - LeftAnti, - &UInt64Array::from(vec![0, 0, 1, 1]), - &BooleanArray::from(vec![true, true, false, false]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![false, false, false, true]), vec![0])) - ); - - assert_eq!( - get_filtered_join_mask( - LeftAnti, - &UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![true, true]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![false, false]), vec![0, 1])) - ); - - assert_eq!( - get_filtered_join_mask( - LeftAnti, - &UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![false, true]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![true, false]), vec![1])) - ); - - assert_eq!( - get_filtered_join_mask( - LeftAnti, - &UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![true, false]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![false, true]), vec![0])) - ); - - assert_eq!( - get_filtered_join_mask( - LeftAnti, - &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![false, true, true, true, true, true]), - &HashSet::new(), - &0, - ), - Some(( - BooleanArray::from(vec![false, false, false, false, false, false]), - vec![0, 1] - )) - ); - - assert_eq!( - get_filtered_join_mask( - LeftAnti, - &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![false, false, false, false, false, true]), - &HashSet::new(), - &0, - ), - Some(( - BooleanArray::from(vec![false, false, true, false, false, false]), - vec![1] - )) - ); - - Ok(()) - } - /// Returns the column names on the schema fn columns(schema: &Schema) -> Vec { schema.fields().iter().map(|f| f.name().clone()).collect()