Skip to content

Commit

Permalink
Add projection to HashJoinExec.
Browse files Browse the repository at this point in the history
  • Loading branch information
my-vegetable-has-exploded committed Feb 15, 2024
1 parent e909443 commit 54f4712
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 15 deletions.
164 changes: 154 additions & 10 deletions datafusion/core/src/physical_optimizer/projection_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
//! projection reaches a source, it can even disappear from the plan entirely.

use std::collections::HashMap;
use std::ops::Index;
use std::sync::Arc;

use super::output_requirements::OutputRequirementExec;
Expand All @@ -44,7 +43,9 @@ use crate::physical_plan::{Distribution, ExecutionPlan};

use arrow_schema::SchemaRef;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion};
use datafusion_common::tree_node::{
Transformed, TreeNode, TreeNodeVisitor, VisitRecursion,
};
use datafusion_common::{DataFusionError, JoinSide};
use datafusion_physical_expr::expressions::{Column, Literal};
use datafusion_physical_expr::{
Expand Down Expand Up @@ -130,7 +131,10 @@ pub fn remove_unnecessary_projections(
} else if let Some(union) = input.downcast_ref::<UnionExec>() {
try_pushdown_through_union(projection, union)?
} else if let Some(hash_join) = input.downcast_ref::<HashJoinExec>() {
try_pushdown_through_hash_join(projection, hash_join)?
try_pushdown_through_hash_join(projection, hash_join)?.map_or_else(
|| try_embed_to_hash_join(projection, hash_join),
|e| Ok(Some(e)),
)?
} else if let Some(cross_join) = input.downcast_ref::<CrossJoinExec>() {
try_swapping_with_cross_join(projection, cross_join)?
} else if let Some(nl_join) = input.downcast_ref::<NestedLoopJoinExec>() {
Expand Down Expand Up @@ -525,26 +529,85 @@ fn try_pushdown_through_union(
Ok(Some(Arc::new(UnionExec::new(new_children))))
}

/// Some projection can't be pushed down left input or right input of hash join because filter or on need may need some columns that won't be used in later.
/// By embed those projection to hash join, we can reduce the cost of build_batch_from_indices in hash join (build_batch_from_indices need to can compute::take() for each column).
fn try_embed_to_hash_join(
projection: &ProjectionExec,
hash_join: &HashJoinExec,
) -> Result<Option<Arc<dyn ExecutionPlan>>> {
let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else {
let projection_index = collect_column_indices(projection.expr());

if projection_index.is_empty() {
return Ok(None);
};

if projection_as_columns.len() >= hash_join.schema().fields().len() {
if projection_index.len() >= hash_join.schema().fields().len() {
return Ok(None);
}

let projection_index = projection_as_columns
let new_hash_join = Arc::new(hash_join.with_projection(&projection_index)?);

// build projection expressions for update_expr
let embed_project_exprs = projection_index
.iter()
.map(|(c, _)| c.index())
.zip(new_hash_join.schema().fields())
.map(|(index, field)| {
(
Arc::new(Column::new(field.name(), *index)) as Arc<dyn PhysicalExpr>,
field.name().to_owned(),
)
})
.collect::<Vec<_>>();

let new_hash_join = hash_join.with_projection(projection_index)?;
let mut new_projection_exprs = Vec::with_capacity(projection.expr().len());

for (expr, alias) in projection.expr() {
// update column index for projection expression since the input schema has been changed.
let Some(expr) = update_expr(expr, embed_project_exprs.as_slice(), false)? else {
return Ok(None);
};
new_projection_exprs.push((expr, alias.clone()));
}
let new_projection = Arc::new(ProjectionExec::try_new(
new_projection_exprs,
new_hash_join.clone(),
)?);
if is_projection_removable(&new_projection) {
Ok(Some(new_hash_join))
} else {
Ok(Some(new_projection))
}
}

/// Collects all distinct column indices referenced by the given projection
/// expressions, in ascending index order.
///
/// Expressions may be arbitrary trees (e.g. `a + 1`), not just bare column
/// references, so the whole expression tree is traversed.
fn collect_column_indices(exprs: &[(Arc<dyn PhysicalExpr>, String)]) -> Vec<usize> {
    struct ColumnVisitor {
        // BTreeSet both de-duplicates and yields indices in sorted order.
        // Todo: should we use a structure that preserves insertion order here,
        // like indexmap?
        pub column_indices: std::collections::BTreeSet<usize>,
    }
    impl TreeNodeVisitor for ColumnVisitor {
        type N = Arc<dyn PhysicalExpr>;

        fn pre_visit(&mut self, node: &Self::N) -> Result<VisitRecursion> {
            // Record every `Column` leaf; all other node kinds are just traversed.
            if let Some(column) = node.as_any().downcast_ref::<Column>() {
                self.column_indices.insert(column.index());
            }
            Ok(VisitRecursion::Continue)
        }

        fn post_visit(&mut self, _node: &Self::N) -> Result<VisitRecursion> {
            Ok(VisitRecursion::Continue)
        }
    }

    let mut visitor = ColumnVisitor {
        column_indices: Default::default(),
    };
    exprs.iter().for_each(|(expr, _)| {
        // The visitor itself never errors, so the Result can be ignored safely.
        let _ = expr.visit(&mut visitor);
    });
    visitor.column_indices.into_iter().collect::<Vec<_>>()
}

/// Tries to push `projection` down through `hash_join`. If possible, performs the
Expand Down Expand Up @@ -1255,7 +1318,9 @@ mod tests {
Distribution, Partitioning, PhysicalExpr, PhysicalSortExpr,
PhysicalSortRequirement, ScalarFunctionExpr,
};
use datafusion_physical_plan::joins::SymmetricHashJoinExec;
use datafusion_physical_plan::joins::{
HashJoinExec, PartitionMode, SymmetricHashJoinExec,
};
use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec};
use datafusion_physical_plan::union::UnionExec;

Expand Down Expand Up @@ -2266,6 +2331,85 @@ mod tests {
Ok(())
}

#[test]
fn test_hash_join_after_projection() -> Result<()> {
    // Verifies that a projection sitting on top of a HashJoinExec gets embedded
    // into the join itself (`projection=[...]` on the join node), leaving only a
    // renaming projection on top.
    //
    // sql like
    // SELECT t1.c as c_from_left, t1.b as b_from_left, t1.a as a_from_left, t2.c as c_from_right FROM t1 JOIN t2 ON t1.b = t2.c WHERE t1.b - (1 + t2.a) <= t2.a + t1.c
    let left_csv = create_simple_csv_exec();
    let right_csv = create_simple_csv_exec();

    let join: Arc<dyn ExecutionPlan> = Arc::new(HashJoinExec::try_new(
        left_csv,
        right_csv,
        // Join keys: t1.b = t2.c
        vec![(Arc::new(Column::new("b", 1)), Arc::new(Column::new("c", 2)))],
        // Join filter: b_left-(1+a_right)<=a_right+c_left
        Some(JoinFilter::new(
            Arc::new(BinaryExpr::new(
                Arc::new(BinaryExpr::new(
                    Arc::new(Column::new("b_left_inter", 0)),
                    Operator::Minus,
                    Arc::new(BinaryExpr::new(
                        Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
                        Operator::Plus,
                        Arc::new(Column::new("a_right_inter", 1)),
                    )),
                )),
                Operator::LtEq,
                Arc::new(BinaryExpr::new(
                    Arc::new(Column::new("a_right_inter", 1)),
                    Operator::Plus,
                    Arc::new(Column::new("c_left_inter", 2)),
                )),
            )),
            // Mapping from the filter's intermediate schema to the join inputs:
            // b_left_inter -> left column 1, a_right_inter -> right column 0,
            // c_left_inter -> left column 2.
            vec![
                ColumnIndex {
                    index: 1,
                    side: JoinSide::Left,
                },
                ColumnIndex {
                    index: 0,
                    side: JoinSide::Right,
                },
                ColumnIndex {
                    index: 2,
                    side: JoinSide::Left,
                },
            ],
            Schema::new(vec![
                Field::new("b_left_inter", DataType::Int32, true),
                Field::new("a_right_inter", DataType::Int32, true),
                Field::new("c_left_inter", DataType::Int32, true),
            ]),
        )),
        &JoinType::Inner,
        PartitionMode::Auto,
        true,
    )?);
    // Projection over the join output: columns 0..4 come from the left input
    // (a, b, c, ...), column 7 is t2.c from the right input; each is aliased.
    let projection: Arc<dyn ExecutionPlan> = Arc::new(ProjectionExec::try_new(
        vec![
            (Arc::new(Column::new("c", 2)), "c_from_left".to_string()),
            (Arc::new(Column::new("b", 1)), "b_from_left".to_string()),
            (Arc::new(Column::new("a", 0)), "a_from_left".to_string()),
            (Arc::new(Column::new("c", 7)), "c_from_right".to_string()),
        ],
        join,
    )?);
    let initial = get_plan_string(&projection);
    let expected_initial = [
        "ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, c@7 as c_from_right]", " HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false"
    ];
    assert_eq!(initial, expected_initial);

    let after_optimize =
        ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?;

    // HashJoinExec only returns result after projection. Because there are some
    // alias columns in the projection, the ProjectionExec is not removed.
    let expected = ["ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, c@3 as c_from_right]", " HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2, projection=[a@0, b@1, c@2, c@2]", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false"];
    assert_eq!(get_plan_string(&after_optimize), expected);

    Ok(())
}

#[test]
fn test_repartition_after_projection() -> Result<()> {
let csv = create_simple_csv_exec();
Expand Down
20 changes: 15 additions & 5 deletions datafusion/physical-plan/src/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use crate::{
DisplayFormatType, Distribution, ExecutionPlan, Partitioning, RecordBatchStream,
SendableRecordBatchStream, Statistics,
};
use crate::{handle_state, projection, DisplayAs};
use crate::{handle_state, DisplayAs};

use super::{
utils::{OnceAsync, OnceFut},
Expand Down Expand Up @@ -409,8 +409,9 @@ impl HashJoinExec {
JoinSide::Right
}

pub fn with_projection(&self, projection: Vec<usize>) -> Result<Self> {
let new_schema = project_schema(&self.schema, Some(&projection))?;
/// project the output of the join
pub fn with_projection(&self, projection: &Vec<usize>) -> Result<Self> {
let new_schema = project_schema(&self.schema, Some(projection))?;
let new_column_indices = projection
.iter()
.map(|i| self.column_indices[*i].clone())
Expand All @@ -428,7 +429,6 @@ impl HashJoinExec {
metrics: ExecutionPlanMetricsSet::new(),
column_indices: new_column_indices,
null_equals_null: self.null_equals_null,
// Todo@wy to check output_order modification
output_order: self.output_order.clone(),
})
}
Expand All @@ -442,6 +442,7 @@ impl DisplayAs for HashJoinExec {
|| "".to_string(),
|f| format!(", filter={}", f.expression()),
);
// If output schema is less than the schema of the join, then it means that projection is applied.
let display_projections = if self.schema.fields.len()
!= build_join_schema(
&self.left.schema(),
Expand All @@ -452,7 +453,16 @@ impl DisplayAs for HashJoinExec {
.fields
.len()
{
format!(", projection={:?}", self.schema)
format!(
", projection=[{}]",
self.schema
.fields
.iter()
.zip(self.column_indices.iter())
.map(|(f, index)| format!("{}@{}", f.name(), index.index))
.collect::<Vec<_>>()
.join(", ")
)
} else {
"".to_string()
};
Expand Down

0 comments on commit 54f4712

Please sign in to comment.