diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs
index 1ca5274820f84..e47aa3ae8423f 100644
--- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs
+++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs
@@ -21,7 +21,6 @@
 //! projection reaches a source, it can even dissappear from the plan entirely.
 
 use std::collections::HashMap;
-use std::ops::Index;
 use std::sync::Arc;
 
 use super::output_requirements::OutputRequirementExec;
@@ -44,7 +43,9 @@ use crate::physical_plan::{Distribution, ExecutionPlan};
 
 use arrow_schema::SchemaRef;
 use datafusion_common::config::ConfigOptions;
-use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion};
+use datafusion_common::tree_node::{
+    Transformed, TreeNode, TreeNodeVisitor, VisitRecursion,
+};
 use datafusion_common::{DataFusionError, JoinSide};
 use datafusion_physical_expr::expressions::{Column, Literal};
 use datafusion_physical_expr::{
@@ -130,7 +131,10 @@ pub fn remove_unnecessary_projections(
     } else if let Some(union) = input.downcast_ref::<UnionExec>() {
         try_pushdown_through_union(projection, union)?
     } else if let Some(hash_join) = input.downcast_ref::<HashJoinExec>() {
-        try_pushdown_through_hash_join(projection, hash_join)?
+        try_pushdown_through_hash_join(projection, hash_join)?.map_or_else(
+            || try_embed_to_hash_join(projection, hash_join),
+            |e| Ok(Some(e)),
+        )?
     } else if let Some(cross_join) = input.downcast_ref::<CrossJoinExec>() {
         try_swapping_with_cross_join(projection, cross_join)?
     } else if let Some(nl_join) = input.downcast_ref::<NestedLoopJoinExec>() {
@@ -525,26 +529,85 @@ fn try_pushdown_through_union(
     Ok(Some(Arc::new(UnionExec::new(new_children))))
 }
 
+/// Some projections cannot be pushed down to the left or right input of a hash join, because the join filter or the `on` condition may need columns that are not used afterwards.
+/// By embedding such projections into the hash join, we reduce the cost of `build_batch_from_indices` in the hash join (which has to call `compute::take()` for every output column).
 fn try_embed_to_hash_join(
     projection: &ProjectionExec,
     hash_join: &HashJoinExec,
 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
-    let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else {
+    let projection_index = collect_column_indices(projection.expr());
+
+    if projection_index.is_empty() {
         return Ok(None);
     };
 
-    if projection_as_columns.len() >= hash_join.schema().fields().len() {
+    if projection_index.len() >= hash_join.schema().fields().len() {
         return Ok(None);
     }
 
-    let projection_index = projection_as_columns
+    let new_hash_join = Arc::new(hash_join.with_projection(&projection_index)?);
+
+    // Build the projection expressions that `update_expr` rewrites against.
+    let embed_project_exprs = projection_index
         .iter()
-        .map(|(c, _)| c.index())
+        .zip(new_hash_join.schema().fields())
+        .map(|(index, field)| {
+            (
+                Arc::new(Column::new(field.name(), *index)) as Arc<dyn PhysicalExpr>,
+                field.name().to_owned(),
+            )
+        })
         .collect::<Vec<_>>();
 
-    let new_hash_join = hash_join.with_projection(projection_index)?;
+    let mut new_projection_exprs = Vec::with_capacity(projection.expr().len());
+
+    for (expr, alias) in projection.expr() {
+        // Update the column indices of the projection expression, since the input schema has changed.
+        let Some(expr) = update_expr(expr, embed_project_exprs.as_slice(), false)?
+        else {
+            return Ok(None);
+        };
+        new_projection_exprs.push((expr, alias.clone()));
+    }
+    let new_projection = Arc::new(ProjectionExec::try_new(
+        new_projection_exprs,
+        new_hash_join.clone(),
+    )?);
+    if is_projection_removable(&new_projection) {
+        Ok(Some(new_hash_join))
+    } else {
+        Ok(Some(new_projection))
+    }
+}
+
+/// Collect all column indices from the given projection expressions.
+fn collect_column_indices(exprs: &[(Arc<dyn PhysicalExpr>, String)]) -> Vec<usize> {
+    // Expressions such as `a + 1` contain columns below the root node, so we need to traverse the whole expression tree.
+    struct ColumnVisitor {
+        // TODO: should we use a structure that preserves insertion order here, such as `IndexMap`?
+        pub column_indices: std::collections::BTreeSet<usize>,
+    }
+    impl TreeNodeVisitor for ColumnVisitor {
+        type N = Arc<dyn PhysicalExpr>;
+
+        fn pre_visit(&mut self, node: &Self::N) -> Result<VisitRecursion> {
+            if let Some(column) = node.as_any().downcast_ref::<Column>() {
+                self.column_indices.insert(column.index());
+            }
+            Ok(VisitRecursion::Continue)
+        }
+
+        fn post_visit(&mut self, _node: &Self::N) -> Result<VisitRecursion> {
+            Ok(VisitRecursion::Continue)
+        }
+    }
 
-    Ok(Some(Arc::new(new_hash_join)))
+    let mut visitor = ColumnVisitor {
+        column_indices: Default::default(),
+    };
+    exprs.iter().for_each(|(expr, _)| {
+        let _ = expr.visit(&mut visitor);
+    });
+    visitor.column_indices.into_iter().collect::<Vec<_>>()
 }
 
 /// Tries to push `projection` down through `hash_join`. If possible, performs the
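The `collect_column_indices` helper above has to walk the entire expression tree because a `Column` can sit anywhere inside an expression such as `a + 1`, not only at the root. The sketch below models that traversal with a toy `Expr` type (a hypothetical stand-in, not DataFusion's `PhysicalExpr`), and shows why backing the visitor with a `BTreeSet` yields deduplicated, sorted indices:

```rust
use std::collections::BTreeSet;

// Toy expression tree standing in for `Arc<dyn PhysicalExpr>`.
enum Expr {
    Column { index: usize },
    Binary { left: Box<Expr>, right: Box<Expr> },
    Literal(i32),
}

// Walk the tree and record every referenced column index; the `BTreeSet`
// deduplicates and sorts, mirroring the visitor in the patch above.
fn collect_column_indices(expr: &Expr, out: &mut BTreeSet<usize>) {
    match expr {
        Expr::Column { index } => {
            out.insert(*index);
        }
        Expr::Binary { left, right } => {
            collect_column_indices(left, out);
            collect_column_indices(right, out);
        }
        Expr::Literal(_) => {}
    }
}

fn main() {
    // Models `a@0 + 1 <= c@2`: both columns sit below the root of the tree.
    let expr = Expr::Binary {
        left: Box::new(Expr::Binary {
            left: Box::new(Expr::Column { index: 0 }),
            right: Box::new(Expr::Literal(1)),
        }),
        right: Box::new(Expr::Column { index: 2 }),
    };
    let mut indices = BTreeSet::new();
    collect_column_indices(&expr, &mut indices);
    assert_eq!(indices.into_iter().collect::<Vec<_>>(), vec![0, 2]);
}
```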
@@ -1255,7 +1318,9 @@ mod tests {
         Distribution, Partitioning, PhysicalExpr, PhysicalSortExpr,
         PhysicalSortRequirement, ScalarFunctionExpr,
     };
-    use datafusion_physical_plan::joins::SymmetricHashJoinExec;
+    use datafusion_physical_plan::joins::{
+        HashJoinExec, PartitionMode, SymmetricHashJoinExec,
+    };
     use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec};
     use datafusion_physical_plan::union::UnionExec;
 
@@ -2266,6 +2331,85 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn test_hash_join_after_projection() -> Result<()> {
+        // Equivalent SQL:
+        // SELECT t1.c as c_from_left, t1.b as b_from_left, t1.a as a_from_left, t2.c as c_from_right FROM t1 JOIN t2 ON t1.b = t2.c WHERE t1.b - (1 + t2.a) <= t2.a + t1.c
+        let left_csv = create_simple_csv_exec();
+        let right_csv = create_simple_csv_exec();
+
+        let join: Arc<dyn ExecutionPlan> = Arc::new(HashJoinExec::try_new(
+            left_csv,
+            right_csv,
+            vec![(Arc::new(Column::new("b", 1)), Arc::new(Column::new("c", 2)))],
+            // Join filter: b_left - (1 + a_right) <= a_right + c_left
+            Some(JoinFilter::new(
+                Arc::new(BinaryExpr::new(
+                    Arc::new(BinaryExpr::new(
+                        Arc::new(Column::new("b_left_inter", 0)),
+                        Operator::Minus,
+                        Arc::new(BinaryExpr::new(
+                            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
+                            Operator::Plus,
+                            Arc::new(Column::new("a_right_inter", 1)),
+                        )),
+                    )),
+                    Operator::LtEq,
+                    Arc::new(BinaryExpr::new(
+                        Arc::new(Column::new("a_right_inter", 1)),
+                        Operator::Plus,
+                        Arc::new(Column::new("c_left_inter", 2)),
+                    )),
+                )),
+                vec![
+                    ColumnIndex {
+                        index: 1,
+                        side: JoinSide::Left,
+                    },
+                    ColumnIndex {
+                        index: 0,
+                        side: JoinSide::Right,
+                    },
+                    ColumnIndex {
+                        index: 2,
+                        side: JoinSide::Left,
+                    },
+                ],
+                Schema::new(vec![
+                    Field::new("b_left_inter", DataType::Int32, true),
+                    Field::new("a_right_inter", DataType::Int32, true),
+                    Field::new("c_left_inter", DataType::Int32, true),
+                ]),
+            )),
+            &JoinType::Inner,
+            PartitionMode::Auto,
+            true,
+        )?);
+        let projection: Arc<dyn ExecutionPlan> = Arc::new(ProjectionExec::try_new(
+            vec![
+                (Arc::new(Column::new("c", 2)), "c_from_left".to_string()),
+                (Arc::new(Column::new("b", 1)), "b_from_left".to_string()),
+                (Arc::new(Column::new("a", 0)), "a_from_left".to_string()),
+                (Arc::new(Column::new("c", 7)), "c_from_right".to_string()),
+            ],
+            join,
+        )?);
+        let initial = get_plan_string(&projection);
+        let expected_initial = [
+            "ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, c@7 as c_from_right]",
+            "  HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2",
+            "    CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false",
+            "    CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false",
+        ];
+        assert_eq!(initial, expected_initial);
+
+        let after_optimize =
+            ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?;
+
+        // The `HashJoinExec` now emits only the projected columns. Because the projection contains aliased columns, the `ProjectionExec` on top is not removed.
+        let expected = [
+            "ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, c@3 as c_from_right]",
+            "  HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2, projection=[a@0, b@1, c@2, c@2]",
+            "    CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false",
+            "    CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false",
+        ];
+        assert_eq!(get_plan_string(&after_optimize), expected);
+
+        Ok(())
+    }
+
     #[test]
     fn test_repartition_after_projection() -> Result<()> {
         let csv = create_simple_csv_exec();
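One detail worth calling out in the expected plans above: embedding the projection indices `[0, 1, 2, 7]` renumbers the surviving join columns consecutively, which is why the top `ProjectionExec` refers to `c@7` before optimization but to `c@3` afterwards. A minimal sketch of that renumbering (the `remap` helper is hypothetical, purely for illustration):

```rust
use std::collections::HashMap;

// Map each kept join-output index to its position in the narrowed schema.
fn remap(kept: &[usize]) -> HashMap<usize, usize> {
    kept.iter().enumerate().map(|(new, &old)| (old, new)).collect()
}

fn main() {
    // The test keeps join columns [a@0, b@1, c@2, c@7].
    let map = remap(&[0, 1, 2, 7]);
    // The top projection's `c@7` is rewritten to `c@3` ...
    assert_eq!(map[&7], 3);
    // ... while the left-side columns keep their indices.
    assert_eq!(map[&0], 0);
}
```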
(Arc::new(Column::new("c", 7)), "c_from_right".to_string()), + ], + join, + )?); + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, c@7 as c_from_right]", " HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + // HashJoinExec only returns result after projection. Because there are some alias columns in the projection, the ProjectionExec is not removed. + let expected = ["ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, c@3 as c_from_right]", " HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2, projection=[a@0, b@1, c@2, c@2]", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false"]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + #[test] fn test_repartition_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 3236d3e7ae179..72b8719592e7e 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -41,7 +41,7 @@ use crate::{ DisplayFormatType, Distribution, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, }; -use crate::{handle_state, projection, DisplayAs}; +use crate::{handle_state, DisplayAs}; use super::{ utils::{OnceAsync, OnceFut}, @@ -409,8 +409,9 @@ impl HashJoinExec { JoinSide::Right } - pub fn with_projection(&self, projection: Vec) -> Result { - let new_schema = project_schema(&self.schema, Some(&projection))?; + /// project the output of the join + pub fn with_projection(&self, projection: &Vec) -> Result { + let new_schema = project_schema(&self.schema, Some(projection))?; let new_column_indices = projection .iter() .map(|i| self.column_indices[*i].clone()) @@ -428,7 +429,6 @@ impl HashJoinExec { metrics: ExecutionPlanMetricsSet::new(), column_indices: new_column_indices, null_equals_null: self.null_equals_null, - // Todo@wy to check output_order modification output_order: self.output_order.clone(), }) } @@ -442,6 +442,7 @@ impl DisplayAs for HashJoinExec { || "".to_string(), |f| format!(", filter={}", f.expression()), ); + // If output schema is less than the schema of the join, then it means that projection is applied. let display_projections = if self.schema.fields.len() != build_join_schema( &self.left.schema(), @@ -452,7 +453,16 @@ impl DisplayAs for HashJoinExec { .fields .len() { - format!(", projection={:?}", self.schema) + format!( + ", projection=[{}]", + self.schema + .fields + .iter() + .zip(self.column_indices.iter()) + .map(|(f, index)| format!("{}@{}", f.name(), index.index)) + .collect::>() + .join(", ") + ) } else { "".to_string() };