From 146f16a0c14f0fe65e2bd8b7226508f27ced3f13 Mon Sep 17 00:00:00 2001 From: Oleks V Date: Sat, 26 Oct 2024 09:16:01 -0700 Subject: [PATCH] Move filtered SMJ Left Anti filtered join out of `join_partial` phase (#13111) * Move filtered SMJ Left Anti filtered join out of `join_partial` phase --- datafusion/core/tests/fuzz_cases/join_fuzz.rs | 6 +- .../src/joins/sort_merge_join.rs | 245 ++++++++++- .../test_files/sort_merge_join.slt | 383 +++++++++--------- 3 files changed, 414 insertions(+), 220 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index ca2c2bf4e438..44d34b674bbb 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -41,6 +41,7 @@ use datafusion::physical_plan::joins::{ }; use datafusion::physical_plan::memory::MemoryExec; +use crate::fuzz_cases::join_fuzz::JoinTestType::NljHj; use datafusion::prelude::{SessionConfig, SessionContext}; use test_utils::stagger_batch_with_seed; @@ -223,9 +224,6 @@ async fn test_anti_join_1k() { } #[tokio::test] -// flaky for HjSmj case, giving 1 rows difference sometimes -// https://github.com/apache/datafusion/issues/11555 -#[ignore] async fn test_anti_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -233,7 +231,7 @@ async fn test_anti_join_1k_filtered() { JoinType::LeftAnti, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::NljHj], false) + .run_test(&[JoinTestType::HjSmj, NljHj], false) .await } diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index d5134855440a..7b7b7462f7e4 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -802,6 +802,32 @@ fn get_corrected_filter_mask( Some(corrected_mask.finish()) } + JoinType::LeftAnti => { + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + + if filter_mask.value(i) { + seen_true = true; + } + + if last_index { + if !seen_true { + corrected_mask.append_value(true); + } else { + corrected_mask.append_null(); + } + + seen_true = false; + } else { + corrected_mask.append_null(); + } + } + + let null_matched = expected_size - corrected_mask.len(); + corrected_mask.extend(vec![Some(true); null_matched]); + Some(corrected_mask.finish()) + } // Only outer joins needs to keep track of processed rows and apply corrected filter mask _ => None, } @@ -835,15 +861,18 @@ impl Stream for SMJStream { JoinType::Left | JoinType::LeftSemi | JoinType::Right + | JoinType::LeftAnti ) { self.freeze_all()?; if !self.output_record_batches.batches.is_empty() - && self.buffered_data.scanning_finished() { - let out_batch = self.filter_joined_batch()?; - return Poll::Ready(Some(Ok(out_batch))); + let out_filtered_batch = + self.filter_joined_batch()?; + return Poll::Ready(Some(Ok( + out_filtered_batch, + ))); } } @@ -907,15 +936,17 @@ impl Stream for SMJStream { // because target output batch size can be hit in the middle of // filtering causing the filtering to be incomplete and causing // correctness issues - let record_batch = if !(self.filter.is_some() + if self.filter.is_some() && matches!( self.join_type, - JoinType::Left | JoinType::LeftSemi | JoinType::Right - )) { - record_batch - } else { + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right + | JoinType::LeftAnti + ) + { continue; - }; + } return Poll::Ready(Some(Ok(record_batch))); } @@ -929,7 +960,10 @@ impl Stream for SMJStream { if self.filter.is_some() && matches!( self.join_type, - JoinType::Left | JoinType::LeftSemi | JoinType::Right + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right + | JoinType::LeftAnti ) { let out = self.filter_joined_batch()?; @@ -1273,11 +1307,7 @@ impl SMJStream { }; if matches!(self.join_type, JoinType::LeftAnti) && self.filter.is_some() { - join_streamed = !self - .streamed_batch - .join_filter_matched_idxs - .contains(&(self.streamed_batch.idx as u64)) - && !self.streamed_joined; + join_streamed = !self.streamed_joined; join_buffered = join_streamed; } } @@ -1519,7 +1549,10 @@ impl SMJStream { // Push the filtered batch which contains rows passing join filter to the output if matches!( self.join_type, - JoinType::Left | JoinType::LeftSemi | JoinType::Right + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right + | JoinType::LeftAnti ) { self.output_record_batches .batches @@ -1654,7 +1687,10 @@ impl SMJStream { if !(self.filter.is_some() && matches!( self.join_type, - JoinType::Left | JoinType::LeftSemi | JoinType::Right + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right + | JoinType::LeftAnti )) { self.output_record_batches.batches.clear(); @@ -1727,7 +1763,7 @@ impl SMJStream { &self.schema, &[filtered_record_batch, null_joined_streamed_batch], )?; - } else if matches!(self.join_type, JoinType::LeftSemi) { + } else if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { let output_column_indices = (0..streamed_columns_length).collect::>(); filtered_record_batch = filtered_record_batch.project(&output_column_indices)?; @@ -3349,6 +3385,7 @@ mod tests { batch_ids: vec![], }; + // Insert already prejoined non-filtered rows batches.batches.push(RecordBatch::try_new( Arc::clone(&schema), vec![ @@ -3835,6 +3872,178 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_left_anti_join_filtered_mask() -> Result<()> { + let mut joined_batches = build_joined_record_batches()?; + let schema = joined_batches.batches.first().unwrap().schema(); + + let output = concat_batches(&schema, &joined_batches.batches)?; + let out_mask = joined_batches.filter_mask.finish(); + let out_indices = joined_batches.row_indices.finish(); + + assert_eq!( + get_corrected_filter_mask( + LeftAnti, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![true]), + 1 + ) + .unwrap(), + BooleanArray::from(vec![None]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftAnti, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![false]), + 1 + ) + .unwrap(), + BooleanArray::from(vec![Some(true)]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftAnti, + &UInt64Array::from(vec![0, 0]), + &[0usize; 2], + &BooleanArray::from(vec![true, true]), + 2 + ) + .unwrap(), + BooleanArray::from(vec![None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftAnti, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, true, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftAnti, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, false, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftAnti, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftAnti, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, true, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftAnti, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, false]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, Some(true)]) + ); + + let corrected_mask = get_corrected_filter_mask( + LeftAnti, + &out_indices, + &joined_batches.batch_ids, + &out_mask, + output.num_rows(), + ) + .unwrap(); + + assert_eq!( + corrected_mask, + BooleanArray::from(vec![ + None, + None, + None, + None, + None, + Some(true), + None, + Some(true) + ]) + ); + + let filtered_rb = filter_record_batch(&output, &corrected_mask)?; + + assert_batches_eq!( + &[ + "+---+----+---+----+", + "| a | b | x | y |", + "+---+----+---+----+", + "| 1 | 13 | 1 | 12 |", + "| 1 | 14 | 1 | 11 |", + "+---+----+---+----+", + ], + &[filtered_rb] + ); + + // output null rows + let null_mask = arrow::compute::not(&corrected_mask)?; + assert_eq!( + null_mask, + BooleanArray::from(vec![ + None, + None, + None, + None, + None, + Some(false), + None, + Some(false), + ]) + ); + + let null_joined_batch = filter_record_batch(&output, &null_mask)?; + + assert_batches_eq!( + &[ + "+---+---+---+---+", + "| a | b | x | y |", + "+---+---+---+---+", + "+---+---+---+---+", + ], + &[null_joined_batch] + ); + Ok(()) + } + /// Returns the column names on the schema fn columns(schema: &Schema) -> Vec { schema.fields().iter().map(|f| f.name().clone()).collect() diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index 051cc6dce3d4..f4cc888d6b8e 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -407,214 +407,201 @@ select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != statement ok set datafusion.execution.batch_size = 10; -# Uncomment when filtered LEFTANTI moved -#query II -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b), -#t2 as ( -# select 11 a, 13 c union all -# select 11 a, 14 c -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -#) order by 1, 2 -#---- -#11 12 - -# Uncomment when filtered LEFTANTI moved -#query III -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b, 1 c union all -# select 11 a, 13 b, 2 c), -#t2 as ( -# select 11 a, 12 b, 3 c union all -# select 11 a, 14 b, 4 c -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) -#) order by 1, 2; -#---- -#11 12 1 -#11 13 2 - -# Uncomment when filtered LEFTANTI moved -#query III -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b, 1 c union all -# select 11 a, 13 b, 2 c), -#t2 as ( -# select 11 a, 12 b, 3 c where false -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) -#) order by 1, 2; -#---- -#11 12 1 -#11 13 2 - -# Uncomment when filtered LEFTANTI moved -#query II -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b), -#t2 as ( -# select 11 a, 13 c union all -# select 11 a, 14 c union all -# select 11 a, 15 c -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -#) order by 1, 2 -#---- -#11 12 - -# Uncomment when filtered LEFTANTI moved -#query II -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b), -#t2 as ( -# select 11 a, 11 c union all -# select 11 a, 14 c union all -# select 11 a, 15 c -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -#) order by 1, 2 -#---- +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 13 c union all + select 11 a, 14 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- +11 12 -# Uncomment when filtered LEFTANTI moved -#query II -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b), -#t2 as ( -# select 11 a, 12 c union all -# select 11 a, 11 c union all -# select 11 a, 15 c -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -#) order by 1, 2 -#---- +query III +select * from ( +with +t1 as ( + select 11 a, 12 b, 1 c union all + select 11 a, 13 b, 2 c), +t2 as ( + select 11 a, 12 b, 3 c union all + select 11 a, 14 b, 4 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +) order by 1, 2; +---- +11 12 1 +11 13 2 +query III +select * from ( +with +t1 as ( + select 11 a, 12 b, 1 c union all + select 11 a, 13 b, 2 c), +t2 as ( + select 11 a, 12 b, 3 c where false + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +) order by 1, 2; +---- +11 12 1 +11 13 2 -# Uncomment when filtered LEFTANTI moved -#query II -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b), -#t2 as ( -# select 11 a, 12 c union all -# select 11 a, 14 c union all -# select 11 a, 11 c -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -#) order by 1, 2 -#---- +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 13 c union all + select 11 a, 14 c union all + select 11 a, 15 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- +11 12 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 11 c union all + select 11 a, 14 c union all + select 11 a, 15 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 12 c union all + select 11 a, 11 c union all + select 11 a, 15 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- + + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 12 c union all + select 11 a, 14 c union all + select 11 a, 11 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- # Test LEFT ANTI with cross batch data distribution statement ok set datafusion.execution.batch_size = 1; -# Uncomment when filtered LEFTANTI moved -#query II -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b), -#t2 as ( -# select 11 a, 13 c union all -# select 11 a, 14 c -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -#) order by 1, 2 -#---- -#11 12 - -# Uncomment when filtered LEFTANTI moved -#query III -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b, 1 c union all -# select 11 a, 13 b, 2 c), -#t2 as ( -# select 11 a, 12 b, 3 c union all -# select 11 a, 14 b, 4 c -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) -#) order by 1, 2; -#---- -#11 12 1 -#11 13 2 - -# Uncomment when filtered LEFTANTI moved -#query III -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b, 1 c union all -# select 11 a, 13 b, 2 c), -#t2 as ( -# select 11 a, 12 b, 3 c where false -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) -#) order by 1, 2; -#---- -#11 12 1 -#11 13 2 - -# Uncomment when filtered LEFTANTI moved -#query II -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b), -#t2 as ( -# select 11 a, 13 c union all -# select 11 a, 14 c union all -# select 11 a, 15 c -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -#) order by 1, 2 -#---- -#11 12 - -# Uncomment when filtered LEFTANTI moved -#query II -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b), -#t2 as ( -# select 11 a, 12 c union all -# select 11 a, 11 c union all -# select 11 a, 15 c -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -#) order by 1, 2 -#---- +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 13 c union all + select 11 a, 14 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- +11 12 -# Uncomment when filtered LEFTANTI moved -#query II -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b), -#t2 as ( -# select 11 a, 12 c union all -# select 11 a, 14 c union all -# select 11 a, 11 c -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -#) order by 1, 2 -#---- +query III +select * from ( +with +t1 as ( + select 11 a, 12 b, 1 c union all + select 11 a, 13 b, 2 c), +t2 as ( + select 11 a, 12 b, 3 c union all + select 11 a, 14 b, 4 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +) order by 1, 2; +---- +11 12 1 +11 13 2 + +query III +select * from ( +with +t1 as ( + select 11 a, 12 b, 1 c union all + select 11 a, 13 b, 2 c), +t2 as ( + select 11 a, 12 b, 3 c where false + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +) order by 1, 2; +---- +11 12 1 +11 13 2 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 13 c union all + select 11 a, 14 c union all + select 11 a, 15 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- +11 12 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 12 c union all + select 11 a, 11 c union all + select 11 a, 15 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 12 c union all + select 11 a, 14 c union all + select 11 a, 11 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- query IIII select * from (