support input reordering for NestedLoopJoinExec
korowa committed Mar 18, 2024
1 parent 269563a commit 03c778c
Showing 8 changed files with 542 additions and 371 deletions.
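The change teaches the statistics-based JoinSelection rule to reorder the inputs of NestedLoopJoinExec the same way it already does for hash joins: when the left (buffered) input is estimated to be larger than the right one, the inputs are swapped and, for join types whose output contains columns from both sides, the result is wrapped into a ProjectionExec that restores the original column order. The index arithmetic behind that reverting projection can be summarized with the following dependency-free sketch (function and variable names here are illustrative, not part of the diff):

// After a swap the right input's columns come first in the join output, so the
// original left columns are found at offset `right_len` and the original right
// columns at offset 0.
fn reverting_projection_indices(left_len: usize, right_len: usize) -> Vec<usize> {
    let left_cols = (0..left_len).map(|i| right_len + i);
    let right_cols = 0..right_len;
    left_cols.chain(right_cols).collect()
}

fn main() {
    // One column on each side ("big_col" on the left, "small_col" on the right):
    // the projection reads "big_col" from index 1 and "small_col" from index 0,
    // which is what test_nl_join_with_swap below asserts.
    assert_eq!(reverting_projection_indices(1, 1), vec![1, 0]);
    assert_eq!(reverting_projection_indices(2, 3), vec![3, 4, 0, 1, 2]);
}
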
209 changes: 206 additions & 3 deletions datafusion/core/src/physical_optimizer/join_selection.rs
@@ -30,8 +30,8 @@ use crate::error::Result;
use crate::physical_optimizer::PhysicalOptimizerRule;
use crate::physical_plan::joins::utils::{ColumnIndex, JoinFilter};
use crate::physical_plan::joins::{
CrossJoinExec, HashJoinExec, PartitionMode, StreamJoinPartitionMode,
SymmetricHashJoinExec,
CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode,
StreamJoinPartitionMode, SymmetricHashJoinExec,
};
use crate::physical_plan::projection::ProjectionExec;
use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties};
@@ -199,6 +199,38 @@ fn swap_hash_join(
}
}

/// Swaps inputs of `NestedLoopJoinExec` and wraps it into `ProjectionExec` if required
fn swap_nl_join(join: &NestedLoopJoinExec) -> Result<Arc<dyn ExecutionPlan>> {
let new_filter = swap_join_filter(join.filter());
let new_join_type = &swap_join_type(*join.join_type());

let new_join = NestedLoopJoinExec::try_new(
Arc::clone(join.right()),
Arc::clone(join.left()),
new_filter,
new_join_type,
)?;

// For Semi/Anti joins the swapped join produces the same output schema
// (only the preserved side's columns are emitted), so there is no need to
// wrap it into an additional projection
let plan: Arc<dyn ExecutionPlan> = if matches!(
join.join_type(),
JoinType::LeftSemi
| JoinType::RightSemi
| JoinType::LeftAnti
| JoinType::RightAnti
) {
Arc::new(new_join)
} else {
let projection =
swap_reverting_projection(&join.left().schema(), &join.right().schema());

Arc::new(ProjectionExec::try_new(projection, Arc::new(new_join))?)
};

Ok(plan)
}
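swap_join_filter, called above, is an existing helper and is not part of this diff. Conceptually it leaves the filter expression and its intermediate schema untouched and only mirrors the JoinSide of every ColumnIndex, so after the swap the filter still reads its values from the same original inputs. A rough sketch of that idea, assuming the ColumnIndex type imported above and datafusion_common::JoinSide (the real helper may differ in details):

fn flip_filter_sides(column_indices: &[ColumnIndex]) -> Vec<ColumnIndex> {
    column_indices
        .iter()
        .map(|ci| ColumnIndex {
            index: ci.index,
            // Only the side changes; the column index within that side stays the same.
            side: match ci.side {
                JoinSide::Left => JoinSide::Right,
                JoinSide::Right => JoinSide::Left,
            },
        })
        .collect()
}
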

/// When the order of the join is changed by the optimizer, the columns in
/// the output should not be impacted. This function creates the expressions
/// that will allow to swap back the values from the original left as the first
@@ -461,6 +493,14 @@ fn statistical_join_selection_subrule(
} else {
None
}
} else if let Some(nl_join) = plan.as_any().downcast_ref::<NestedLoopJoinExec>() {
let left = nl_join.left();
let right = nl_join.right();
if should_swap_join_order(&**left, &**right)? {
swap_nl_join(nl_join).map(Some)?
} else {
None
}
} else {
None
};
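should_swap_join_order, reused here by the new NestedLoopJoinExec arm, is the same statistics heuristic the rule already applies to hash joins; it is defined elsewhere in this file. It roughly amounts to the following comparison (a simplified model using plain Option values instead of Statistics/Precision):

fn should_swap(
    left_bytes: Option<usize>,
    right_bytes: Option<usize>,
    left_rows: Option<usize>,
    right_rows: Option<usize>,
) -> bool {
    // Prefer byte-size estimates, fall back to row counts, and keep the
    // current order when neither side has usable statistics.
    match (left_bytes, right_bytes) {
        (Some(l), Some(r)) => l > r,
        _ => match (left_rows, right_rows) {
            (Some(l), Some(r)) => l > r,
            _ => false,
        },
    }
}
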
@@ -697,9 +737,12 @@ mod tests_statistical {

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{stats::Precision, JoinType, ScalarValue};
use datafusion_physical_expr::expressions::Column;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{BinaryExpr, Column};
use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef};

use rstest::rstest;

/// Return statistics for an empty table
fn empty_statistics() -> Statistics {
Statistics {
@@ -785,6 +828,35 @@
}]
}

/// Create join filter for NLJoinExec with expression `big_col > small_col`
/// where both columns are 0-indexed and come from left and right inputs respectively
fn nl_join_filter() -> Option<JoinFilter> {
let column_indices = vec![
ColumnIndex {
index: 0,
side: JoinSide::Left,
},
ColumnIndex {
index: 0,
side: JoinSide::Right,
},
];
let intermediate_schema = Schema::new(vec![
Field::new("big_col", DataType::Int32, false),
Field::new("small_col", DataType::Int32, false),
]);
let expression = Arc::new(BinaryExpr::new(
Arc::new(Column::new_with_schema("big_col", &intermediate_schema).unwrap()),
Operator::Gt,
Arc::new(Column::new_with_schema("big_col", &intermediate_schema).unwrap()),
)) as _;
Some(JoinFilter::new(
expression,
column_indices,
intermediate_schema,
))
}

/// Returns three plans with statistics of (min, max, distinct_count)
/// * big 100K rows @ (0, 50k, 50k)
/// * medium 10K rows @ (1k, 5k, 1k)
@@ -1151,6 +1223,137 @@
crosscheck_plans(join).unwrap();
}

#[rstest(
join_type,
case::inner(JoinType::Inner),
case::left(JoinType::Left),
case::right(JoinType::Right),
case::full(JoinType::Full)
)]
#[tokio::test]
async fn test_nl_join_with_swap(join_type: JoinType) {
let (big, small) = create_big_and_small();

let join = Arc::new(
NestedLoopJoinExec::try_new(
Arc::clone(&big),
Arc::clone(&small),
nl_join_filter(),
&join_type,
)
.unwrap(),
);

let optimized_join = JoinSelection::new()
.optimize(join.clone(), &ConfigOptions::new())
.unwrap();

let swapping_projection = optimized_join
.as_any()
.downcast_ref::<ProjectionExec>()
.expect("A proj is required to swap columns back to their original order");

assert_eq!(swapping_projection.expr().len(), 2);
let (col, name) = &swapping_projection.expr()[0];
assert_eq!(name, "big_col");
assert_col_expr(col, "big_col", 1);
let (col, name) = &swapping_projection.expr()[1];
assert_eq!(name, "small_col");
assert_col_expr(col, "small_col", 0);

let swapped_join = swapping_projection
.input()
.as_any()
.downcast_ref::<NestedLoopJoinExec>()
.expect("The type of the plan should not be changed");

// Assert join side of big_col swapped in filter expression
let swapped_filter = swapped_join.filter().unwrap();
let swapped_big_col_idx = swapped_filter.schema().index_of("big_col").unwrap();
let swapped_big_col_side = swapped_filter
.column_indices()
.get(swapped_big_col_idx)
.unwrap()
.side;
assert_eq!(
swapped_big_col_side,
JoinSide::Right,
"Filter column side should be swapped"
);

assert_eq!(
swapped_join.left().statistics().unwrap().total_byte_size,
Precision::Inexact(8192)
);
assert_eq!(
swapped_join.right().statistics().unwrap().total_byte_size,
Precision::Inexact(2097152)
);
crosscheck_plans(join.clone()).unwrap();
}

#[rstest(
join_type,
case::left_semi(JoinType::LeftSemi),
case::left_anti(JoinType::LeftAnti),
case::right_semi(JoinType::RightSemi),
case::right_anti(JoinType::RightAnti)
)]
#[tokio::test]
async fn test_nl_join_with_swap_no_proj(join_type: JoinType) {
let (big, small) = create_big_and_small();

let join = Arc::new(
NestedLoopJoinExec::try_new(
Arc::clone(&big),
Arc::clone(&small),
nl_join_filter(),
&join_type,
)
.unwrap(),
);

let optimized_join = JoinSelection::new()
.optimize(join.clone(), &ConfigOptions::new())
.unwrap();

let swapped_join = optimized_join
.as_any()
.downcast_ref::<NestedLoopJoinExec>()
.expect("The type of the plan should not be changed");

// Assert before/after schemas are equal
assert_eq!(
join.schema(),
swapped_join.schema(),
"Join schema should not be modified while optimization"
);

// Assert join side of big_col swapped in filter expression
let swapped_filter = swapped_join.filter().unwrap();
let swapped_big_col_idx = swapped_filter.schema().index_of("big_col").unwrap();
let swapped_big_col_side = swapped_filter
.column_indices()
.get(swapped_big_col_idx)
.unwrap()
.side;
assert_eq!(
swapped_big_col_side,
JoinSide::Right,
"Filter column side should be swapped"
);

assert_eq!(
swapped_join.left().statistics().unwrap().total_byte_size,
Precision::Inexact(8192)
);
assert_eq!(
swapped_join.right().statistics().unwrap().total_byte_size,
Precision::Inexact(2097152)
);
crosscheck_plans(join.clone()).unwrap();
}

#[tokio::test]
async fn test_swap_reverting_projection() {
let left_schema = Schema::new(vec![
89 changes: 84 additions & 5 deletions datafusion/core/tests/fuzz_cases/join_fuzz.rs
@@ -21,13 +21,19 @@ use arrow::array::{ArrayRef, Int32Array};
use arrow::compute::SortOptions;
use arrow::record_batch::RecordBatch;
use arrow::util::pretty::pretty_format_batches;
use arrow_schema::Schema;
use rand::Rng;

use datafusion::common::JoinSide;
use datafusion::logical_expr::{JoinType, Operator};
use datafusion::physical_expr::expressions::BinaryExpr;
use datafusion::physical_plan::collect;
use datafusion::physical_plan::expressions::Column;
use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec};
use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter};
use datafusion::physical_plan::joins::{
HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec,
};
use datafusion::physical_plan::memory::MemoryExec;
use datafusion_expr::JoinType;

use datafusion::prelude::{SessionConfig, SessionContext};
use test_utils::stagger_batch_with_seed;
@@ -73,7 +79,7 @@ async fn test_full_join_1k() {
}

#[tokio::test]
async fn test_semi_join_1k() {
async fn test_semi_join_10k() {
run_join_test(
make_staggered_batches(10000),
make_staggered_batches(10000),
@@ -83,7 +89,7 @@ async fn test_semi_join_1k() {
}

#[tokio::test]
async fn test_anti_join_1k() {
async fn test_anti_join_10k() {
run_join_test(
make_staggered_batches(10000),
make_staggered_batches(10000),
@@ -118,6 +124,46 @@ async fn run_join_test(
),
];

// Nested loop join uses a filter expression to match records rather than equi-join keys
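// The intermediate filter schema built below has four columns
// (left.a, left.b, right.a, right.b); each ColumnIndex maps one of those
// columns back to a column of the original left or right input, and the
// filter expression references positions in this intermediate schema,
// not in the input schemas.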
let column_indices = vec![
ColumnIndex {
index: 0,
side: JoinSide::Left,
},
ColumnIndex {
index: 1,
side: JoinSide::Left,
},
ColumnIndex {
index: 0,
side: JoinSide::Right,
},
ColumnIndex {
index: 1,
side: JoinSide::Right,
},
];
let intermediate_schema = Schema::new(vec![
schema1.field_with_name("a").unwrap().to_owned(),
schema1.field_with_name("b").unwrap().to_owned(),
schema2.field_with_name("a").unwrap().to_owned(),
schema2.field_with_name("b").unwrap().to_owned(),
]);

let equal_a = Arc::new(BinaryExpr::new(
Arc::new(Column::new("a", 0)),
Operator::Eq,
Arc::new(Column::new("a", 2)),
)) as _;
let equal_b = Arc::new(BinaryExpr::new(
Arc::new(Column::new("b", 1)),
Operator::Eq,
Arc::new(Column::new("b", 3)),
)) as _;
let expression = Arc::new(BinaryExpr::new(equal_a, Operator::And, equal_b)) as _;

let on_filter = JoinFilter::new(expression, column_indices, intermediate_schema);

// sort-merge join
let left = Arc::new(
MemoryExec::try_new(&[input1.clone()], schema1.clone(), None).unwrap(),
@@ -161,22 +207,55 @@
);
let hj_collected = collect(hj, task_ctx.clone()).await.unwrap();

// nested loop join
let left = Arc::new(
MemoryExec::try_new(&[input1.clone()], schema1.clone(), None).unwrap(),
);
let right = Arc::new(
MemoryExec::try_new(&[input2.clone()], schema2.clone(), None).unwrap(),
);
let nlj = Arc::new(
NestedLoopJoinExec::try_new(left, right, Some(on_filter), &join_type)
.unwrap(),
);
let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap();

// compare
let smj_formatted = pretty_format_batches(&smj_collected).unwrap().to_string();
let hj_formatted = pretty_format_batches(&hj_collected).unwrap().to_string();
let nlj_formatted = pretty_format_batches(&nlj_collected).unwrap().to_string();

let mut smj_formatted_sorted: Vec<&str> = smj_formatted.trim().lines().collect();
smj_formatted_sorted.sort_unstable();

let mut hj_formatted_sorted: Vec<&str> = hj_formatted.trim().lines().collect();
hj_formatted_sorted.sort_unstable();

let mut nlj_formatted_sorted: Vec<&str> = nlj_formatted.trim().lines().collect();
nlj_formatted_sorted.sort_unstable();

for (i, (smj_line, hj_line)) in smj_formatted_sorted
.iter()
.zip(&hj_formatted_sorted)
.enumerate()
{
assert_eq!((i, smj_line), (i, hj_line));
assert_eq!(
(i, smj_line),
(i, hj_line),
"SortMergeJoinExec and HashJoinExec produced different results"
);
}

for (i, (nlj_line, hj_line)) in nlj_formatted_sorted
.iter()
.zip(&hj_formatted_sorted)
.enumerate()
{
assert_eq!(
(i, nlj_line),
(i, hj_line),
"NestedLoopJoinExec and HashJoinExec produced different results"
);
}
}
}