Skip to content

Commit

Permalink
tests: enable filtered semi-join nlj-hj fuzz case
Browse files Browse the repository at this point in the history
  • Loading branch information
korowa committed Sep 5, 2024
1 parent d6f3f73 commit 20f69cd
Showing 1 changed file with 92 additions and 41 deletions.
133 changes: 92 additions & 41 deletions datafusion/core/tests/fuzz_cases/join_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use datafusion_common::ScalarValue;
use datafusion_physical_expr::expressions::Literal;
use datafusion_physical_expr::PhysicalExprRef;

use itertools::Itertools;
use rand::Rng;

use datafusion::common::JoinSide;
Expand Down Expand Up @@ -225,15 +226,13 @@ async fn test_semi_join_1k() {

#[tokio::test]
async fn test_semi_join_1k_filtered() {
// NLJ vs HJ gives wrong result
// Tracked in https://github.com/apache/datafusion/issues/11537
JoinFuzzTestCase::new(
make_staggered_batches(1000),
make_staggered_batches(1000),
JoinType::LeftSemi,
Some(Box::new(col_lt_col_filter)),
)
.run_test(&[JoinTestType::HjSmj], false)
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

Expand Down Expand Up @@ -292,27 +291,6 @@ impl JoinFuzzTestCase {
}
}

fn column_indices(&self) -> Vec<ColumnIndex> {
vec![
ColumnIndex {
index: 0,
side: JoinSide::Left,
},
ColumnIndex {
index: 1,
side: JoinSide::Left,
},
ColumnIndex {
index: 0,
side: JoinSide::Right,
},
ColumnIndex {
index: 1,
side: JoinSide::Right,
},
]
}

fn on_columns(&self) -> Vec<(PhysicalExprRef, PhysicalExprRef)> {
let schema1 = self.input1[0].schema();
let schema2 = self.input2[0].schema();
Expand All @@ -328,10 +306,20 @@ impl JoinFuzzTestCase {
]
}

/// Helper function for building NLJoin filter, returning intermediate
/// schema as a union of origin filter intermediate schema and
/// on-condition schema
fn intermediate_schema(&self) -> Schema {
let filter_schema = if let Some(filter) = self.join_filter() {
filter.schema().to_owned()
} else {
Schema::empty()
};

let schema1 = self.input1[0].schema();
let schema2 = self.input2[0].schema();
Schema::new(vec![

let on_schema = Schema::new(vec![
schema1
.field_with_name("a")
.unwrap()
Expand All @@ -344,7 +332,81 @@ impl JoinFuzzTestCase {
.with_nullable(true),
schema2.field_with_name("a").unwrap().to_owned(),
schema2.field_with_name("b").unwrap().to_owned(),
])
]);

Schema::new(
filter_schema
.fields
.into_iter()
.cloned()
.chain(on_schema.fields.into_iter().cloned())
.collect_vec(),
)
}

/// Helper function for building NLJoin filter, returns the union
/// of original filter expression and on-condition expression
fn composite_filter_expression(&self) -> PhysicalExprRef {
let (filter_expression, column_idx_offset) =
if let Some(filter) = self.join_filter() {
(
filter.expression().to_owned(),
filter.schema().fields().len(),
)
} else {
(Arc::new(Literal::new(ScalarValue::from(true))) as _, 0)
};

let equal_a = Arc::new(BinaryExpr::new(
Arc::new(Column::new("a", column_idx_offset)),
Operator::Eq,
Arc::new(Column::new("a", column_idx_offset + 2)),
));
let equal_b = Arc::new(BinaryExpr::new(
Arc::new(Column::new("b", column_idx_offset + 1)),
Operator::Eq,
Arc::new(Column::new("b", column_idx_offset + 3)),
));
let on_expression = Arc::new(BinaryExpr::new(equal_a, Operator::And, equal_b));

Arc::new(BinaryExpr::new(
filter_expression,
Operator::And,
on_expression,
))
}

/// Helper function for building NLJoin filter, returning the union
/// of original filter column indices and on-condition column indices.
/// Result must match intermediate schema.
fn column_indices(&self) -> Vec<ColumnIndex> {
let mut column_indices = if let Some(filter) = self.join_filter() {
filter.column_indices().to_vec()
} else {
vec![]
};

let on_column_indices = vec![
ColumnIndex {
index: 0,
side: JoinSide::Left,
},
ColumnIndex {
index: 1,
side: JoinSide::Left,
},
ColumnIndex {
index: 0,
side: JoinSide::Right,
},
ColumnIndex {
index: 1,
side: JoinSide::Right,
},
];

column_indices.extend(on_column_indices);
column_indices
}

fn left_right(&self) -> (Arc<MemoryExec>, Arc<MemoryExec>) {
Expand Down Expand Up @@ -400,26 +462,15 @@ impl JoinFuzzTestCase {

fn nested_loop_join(&self) -> Arc<NestedLoopJoinExec> {
let (left, right) = self.left_right();
// Nested loop join uses filter for joining records

let column_indices = self.column_indices();
let intermediate_schema = self.intermediate_schema();
let expression = self.composite_filter_expression();

let equal_a = Arc::new(BinaryExpr::new(
Arc::new(Column::new("a", 0)),
Operator::Eq,
Arc::new(Column::new("a", 2)),
)) as _;
let equal_b = Arc::new(BinaryExpr::new(
Arc::new(Column::new("b", 1)),
Operator::Eq,
Arc::new(Column::new("b", 3)),
)) as _;
let expression = Arc::new(BinaryExpr::new(equal_a, Operator::And, equal_b)) as _;

let on_filter = JoinFilter::new(expression, column_indices, intermediate_schema);
let filter = JoinFilter::new(expression, column_indices, intermediate_schema);

Arc::new(
NestedLoopJoinExec::try_new(left, right, Some(on_filter), &self.join_type)
NestedLoopJoinExec::try_new(left, right, Some(filter), &self.join_type)
.unwrap(),
)
}
Expand Down

0 comments on commit 20f69cd

Please sign in to comment.