diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index e9f4f1f80972..f963685bdfbe 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -3496,7 +3496,9 @@ mod tests { use crate::logical_plan::table_scan; use crate::{col, exists, in_subquery, lit, placeholder, GroupingSet}; - use datafusion_common::tree_node::{TransformedResult, TreeNodeVisitor}; + use datafusion_common::tree_node::{ + TransformedResult, TreeNodeRewriter, TreeNodeVisitor, + }; use datafusion_common::{not_impl_err, Constraint, ScalarValue}; use crate::test::function_stub::count; @@ -4157,4 +4159,120 @@ digraph { .unwrap(); assert_eq!(limit, new_limit); } + + #[test] + fn test_with_subqueries_jump() { + // The plan contains a `Project` node above a `Filter` node so returning + // `TreeNodeRecursion::Jump` on `Project` should cause not visiting `Filter`. + let plan = test_plan(); + + let mut filter_found = false; + plan.apply_with_subqueries(|plan| { + match plan { + LogicalPlan::Projection(..) => return Ok(TreeNodeRecursion::Jump), + LogicalPlan::Filter(..) => filter_found = true, + _ => {} + } + Ok(TreeNodeRecursion::Continue) + }) + .unwrap(); + assert!(!filter_found); + + struct ProjectJumpVisitor { + filter_found: bool, + } + + impl ProjectJumpVisitor { + fn new() -> Self { + Self { + filter_found: false, + } + } + } + + impl<'n> TreeNodeVisitor<'n> for ProjectJumpVisitor { + type Node = LogicalPlan; + + fn f_down(&mut self, node: &'n Self::Node) -> Result { + match node { + LogicalPlan::Projection(..) => return Ok(TreeNodeRecursion::Jump), + LogicalPlan::Filter(..) => self.filter_found = true, + _ => {} + } + Ok(TreeNodeRecursion::Continue) + } + } + + let mut visitor = ProjectJumpVisitor::new(); + plan.visit_with_subqueries(&mut visitor).unwrap(); + assert!(!visitor.filter_found); + + let mut filter_found = false; + plan.clone() + .transform_down_with_subqueries(|plan| { + match plan { + LogicalPlan::Projection(..) => { + return Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)) + } + LogicalPlan::Filter(..) => filter_found = true, + _ => {} + } + Ok(Transformed::no(plan)) + }) + .unwrap(); + assert!(!filter_found); + + let mut filter_found = false; + plan.clone() + .transform_down_up_with_subqueries( + |plan| { + match plan { + LogicalPlan::Projection(..) => { + return Ok(Transformed::new( + plan, + false, + TreeNodeRecursion::Jump, + )) + } + LogicalPlan::Filter(..) => filter_found = true, + _ => {} + } + Ok(Transformed::no(plan)) + }, + |plan| Ok(Transformed::no(plan)), + ) + .unwrap(); + assert!(!filter_found); + + struct ProjectJumpRewriter { + filter_found: bool, + } + + impl ProjectJumpRewriter { + fn new() -> Self { + Self { + filter_found: false, + } + } + } + + impl TreeNodeRewriter for ProjectJumpRewriter { + type Node = LogicalPlan; + + fn f_down(&mut self, node: Self::Node) -> Result> { + match node { + LogicalPlan::Projection(..) => { + return Ok(Transformed::new(node, false, TreeNodeRecursion::Jump)) + } + LogicalPlan::Filter(..) => self.filter_found = true, + _ => {} + } + Ok(Transformed::no(node)) + } + } + + let mut rewriter = ProjectJumpRewriter::new(); + plan.rewrite_with_subqueries(&mut rewriter).unwrap(); + assert!(!rewriter.filter_found); + } } diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 6850c30f4f81..1539b69b4007 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -385,8 +385,10 @@ fn rewrite_extension_inputs Result {{ $F_DOWN? - .transform_children(|n| n.map_subqueries($F_CHILD))? - .transform_sibling(|n| n.map_children($F_CHILD))? + .transform_children(|n| { + n.map_subqueries($F_CHILD)? + .transform_sibling(|n| n.map_children($F_CHILD)) + })? .transform_parent($F_UP) }}; } @@ -675,9 +677,11 @@ impl LogicalPlan { visitor .f_down(self)? .visit_children(|| { - self.apply_subqueries(|c| c.visit_with_subqueries(visitor)) + self.apply_subqueries(|c| c.visit_with_subqueries(visitor))? + .visit_sibling(|| { + self.apply_children(|c| c.visit_with_subqueries(visitor)) + }) })? - .visit_sibling(|| self.apply_children(|c| c.visit_with_subqueries(visitor)))? .visit_parent(|| visitor.f_up(self)) } @@ -710,13 +714,12 @@ impl LogicalPlan { node: &LogicalPlan, f: &mut F, ) -> Result { - f(node)? - .visit_children(|| { - node.apply_subqueries(|c| apply_with_subqueries_impl(c, f)) - })? - .visit_sibling(|| { - node.apply_children(|c| apply_with_subqueries_impl(c, f)) - }) + f(node)?.visit_children(|| { + node.apply_subqueries(|c| apply_with_subqueries_impl(c, f))? + .visit_sibling(|| { + node.apply_children(|c| apply_with_subqueries_impl(c, f)) + }) + }) } apply_with_subqueries_impl(self, &mut f) @@ -746,13 +749,12 @@ impl LogicalPlan { node: LogicalPlan, f: &mut F, ) -> Result> { - f(node)? - .transform_children(|n| { - n.map_subqueries(|c| transform_down_with_subqueries_impl(c, f)) - })? - .transform_sibling(|n| { - n.map_children(|c| transform_down_with_subqueries_impl(c, f)) - }) + f(node)?.transform_children(|n| { + n.map_subqueries(|c| transform_down_with_subqueries_impl(c, f))? + .transform_sibling(|n| { + n.map_children(|c| transform_down_with_subqueries_impl(c, f)) + }) + }) } transform_down_with_subqueries_impl(self, &mut f)