Skip to content

Commit

Permalink
fix: join swap for projected semi/anti joins
Browse files Browse the repository at this point in the history
  • Loading branch information
korowa committed Oct 20, 2024
1 parent 972e3ab commit bcea434
Showing 1 changed file with 65 additions and 20 deletions.
85 changes: 65 additions & 20 deletions datafusion/core/src/physical_optimizer/join_selection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,20 +140,32 @@ fn swap_join_projection(
left_schema_len: usize,
right_schema_len: usize,
projection: Option<&Vec<usize>>,
join_type: &JoinType,
) -> Option<Vec<usize>> {
projection.map(|p| {
p.iter()
.map(|i| {
// If the index is less than the left schema length, it is from the left schema, so we add the right schema length to it.
// Otherwise, it is from the right schema, so we subtract the left schema length from it.
if *i < left_schema_len {
*i + right_schema_len
} else {
*i - left_schema_len
}
})
.collect()
})
match join_type {
// For Anti/Semi join types, projection should remain unmodified,
// since these joins output schema remains the same after swap
JoinType::LeftAnti
| JoinType::LeftSemi
| JoinType::RightAnti
| JoinType::RightSemi => projection.cloned(),

_ => projection.map(|p| {
p.iter()
.map(|i| {
// If the index is less than the left schema length, it is from
// the left schema, so we add the right schema length to it.
// Otherwise, it is from the right schema, so we subtract the left
// schema length from it.
if *i < left_schema_len {
*i + right_schema_len
} else {
*i - left_schema_len
}
})
.collect()
}),
}
}

/// This function swaps the inputs of the given join operator.
Expand All @@ -179,6 +191,7 @@ pub fn swap_hash_join(
left.schema().fields().len(),
right.schema().fields().len(),
hash_join.projection.as_ref(),
hash_join.join_type(),
),
partition_mode,
hash_join.null_equals_null(),
Expand Down Expand Up @@ -1289,27 +1302,59 @@ mod tests_statistical {
);
}

#[rstest(
join_type, projection, small_on_right,
case::inner(JoinType::Inner, vec![1], true),
case::left(JoinType::Left, vec![1], true),
case::right(JoinType::Right, vec![1], true),
case::full(JoinType::Full, vec![1], true),
case::left_anti(JoinType::LeftAnti, vec![0], false),
case::left_semi(JoinType::LeftSemi, vec![0], false),
case::right_anti(JoinType::RightAnti, vec![0], true),
case::right_semi(JoinType::RightSemi, vec![0], true),
)]
#[tokio::test]
async fn test_hash_join_swap_on_joins_with_projections() -> Result<()> {
async fn test_hash_join_swap_on_joins_with_projections(
join_type: JoinType,
projection: Vec<usize>,
small_on_right: bool,
) -> Result<()> {
let (big, small) = create_big_and_small();

let left = if small_on_right { &big } else { &small };
let right = if small_on_right { &small } else { &big };

let left_on = if small_on_right {
"big_col"
} else {
"small_col"
};
let right_on = if small_on_right {
"small_col"
} else {
"big_col"
};

let join = Arc::new(HashJoinExec::try_new(
Arc::clone(&big),
Arc::clone(&small),
Arc::clone(left),
Arc::clone(right),
vec![(
Arc::new(Column::new_with_schema("big_col", &big.schema())?),
Arc::new(Column::new_with_schema("small_col", &small.schema())?),
Arc::new(Column::new_with_schema(left_on, &left.schema())?),
Arc::new(Column::new_with_schema(right_on, &right.schema())?),
)],
None,
&JoinType::Inner,
Some(vec![1]),
&join_type,
Some(projection),
PartitionMode::Partitioned,
false,
)?);

let swapped = swap_hash_join(&join.clone(), PartitionMode::Partitioned)
.expect("swap_hash_join must support joins with projections");
let swapped_join = swapped.as_any().downcast_ref::<HashJoinExec>().expect(
"ProjectionExec won't be added above if HashJoinExec contains embedded projection",
);

assert_eq!(swapped_join.projection, Some(vec![0_usize]));
assert_eq!(swapped.schema().fields.len(), 1);
assert_eq!(swapped.schema().fields[0].name(), "small_col");
Expand Down

0 comments on commit bcea434

Please sign in to comment.