diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index 5e95562033e60..9dfc238ab9e83 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -91,7 +91,7 @@ impl AnalyzerRule for MyAnalyzerRule { impl MyAnalyzerRule { fn analyze_plan(plan: LogicalPlan) -> Result { - plan.transform(&|plan| { + plan.transform_up(&|plan| { Ok(match plan { LogicalPlan::Filter(filter) => { let predicate = Self::analyze_expr(filter.predicate.clone())?; @@ -106,7 +106,7 @@ impl MyAnalyzerRule { } fn analyze_expr(expr: Expr) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { // closure is invoked for all sub expressions Ok(match expr { Expr::Literal(ScalarValue::Int64(i)) => { @@ -161,7 +161,7 @@ impl OptimizerRule for MyOptimizerRule { /// use rewrite_expr to modify the expression tree. fn my_rewrite(expr: Expr) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { // closure is invoked for all sub expressions Ok(match expr { Expr::Between(Between { diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 5da9636ffe185..39d691a9dcead 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -22,10 +22,24 @@ use std::sync::Arc; use crate::Result; -/// Defines a visitable and rewriteable a tree node. This trait is -/// implemented for plans ([`ExecutionPlan`] and [`LogicalPlan`]) as -/// well as expression trees ([`PhysicalExpr`], [`Expr`]) in -/// DataFusion +/// Defines a tree node that can have children of the same type as the parent node. The +/// implementations must provide [`TreeNode::apply_children()`] and +/// [`TreeNode::map_children()`] for visiting and changing the structure of the tree. +/// +/// [`TreeNode`] is implemented for plans ([`ExecutionPlan`] and [`LogicalPlan`]) as well +/// as expression trees ([`PhysicalExpr`], [`Expr`]) in DataFusion. +/// +/// Besides the children, each tree node can define links to embedded trees of the same +/// type. The root node of these trees are called inner children of a node. +/// +/// A logical plan of a query is a tree of [`LogicalPlan`] nodes, where each node can +/// contain multiple expression ([`Expr`]) trees. But expression tree nodes can contain +/// logical plans of subqueries, which are again trees of [`LogicalPlan`] nodes. The root +/// nodes of these subquery plans are the inner children of the containing query plan +/// node. +/// +/// Tree node implementations can provide [`TreeNode::apply_inner_children()`] for +/// visiting the structure of the inner tree. /// /// /// [`ExecutionPlan`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html @@ -33,28 +47,40 @@ use crate::Result; /// [`LogicalPlan`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/logical_plan/enum.LogicalPlan.html /// [`Expr`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/expr/enum.Expr.html pub trait TreeNode: Sized { - /// Use preorder to iterate the node on the tree so that we can - /// stop fast for some cases. - /// - /// The `op` closure can be used to collect some info from the - /// tree node or do some checking for the tree node. - fn apply(&self, op: &mut F) -> Result + /// Applies `f` to the tree node, then to its inner children and then to its children + /// depending on the result of `f` in a preorder traversal. + /// See [`TreeNodeRecursion`] for more details on how the preorder traversal can be + /// controlled. + /// If an [`Err`] result is returned, recursion is stopped immediately. + fn visit_down(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - match op(self)? { - VisitRecursion::Continue => {} - // If the recursion should skip, do not apply to its children. And let the recursion continue - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - // If the recursion should stop, do not apply to its children - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - }; - - self.apply_children(&mut |node| node.apply(op)) + // Apply `f` on self. + f(self) + // If it returns continue (not prune or stop or stop all) then continue + // traversal on inner children and children. + .and_then_on_continue(|| { + // Run the recursive `apply` on each inner children, but as they are + // unrelated root nodes of inner trees if any returns stop then continue + // with the next one. + self.apply_inner_children(&mut |c| c.visit_down(f).continue_on_stop()) + // Run the recursive `apply` on each children. + .and_then_on_continue(|| { + self.apply_children(&mut |c| c.visit_down(f)) + }) + }) + // Applying `f` on self might have returned prune, but we need to propagate + // continue. + .continue_on_prune() } - /// Visit the tree node using the given [TreeNodeVisitor] - /// It performs a depth first walk of an node and its children. + /// Uses a [`TreeNodeVisitor`] to visit the tree node, then its inner children and + /// then its children depending on the result of [`TreeNodeVisitor::pre_visit()`] and + /// [`TreeNodeVisitor::post_visit()`] in a traversal. + /// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled. + /// + /// If an [`Err`] result is returned, recursion is stopped immediately. /// /// For an node tree such as /// ```text @@ -73,45 +99,54 @@ pub trait TreeNode: Sized { /// post_visit(ParentNode) /// ``` /// - /// If an Err result is returned, recursion is stopped immediately - /// - /// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no - /// children of that node will be visited, nor is post_visit - /// called on that node. Details see [`TreeNodeVisitor`] - /// - /// If using the default [`TreeNodeVisitor::post_visit`] that does - /// nothing, [`Self::apply`] should be preferred. - fn visit>( + /// If using the default [`TreeNodeVisitor::post_visit()`] that does nothing, + /// [`Self::visit_down()`] should be preferred. + fn visit>( &self, visitor: &mut V, - ) -> Result { - match visitor.pre_visit(self)? { - VisitRecursion::Continue => {} - // If the recursion should skip, do not apply to its children. And let the recursion continue - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - // If the recursion should stop, do not apply to its children - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - }; - - match self.apply_children(&mut |node| node.visit(visitor))? { - VisitRecursion::Continue => {} - // If the recursion should skip, do not apply to its children. And let the recursion continue - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - // If the recursion should stop, do not apply to its children - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - - visitor.post_visit(self) + ) -> Result { + // Apply `pre_visit` on self. + visitor + .pre_visit(self) + // If it returns continue (not prune or stop or stop all) then continue + // traversal on inner children and children. + .and_then_on_continue(|| { + // Run the recursive `visit` on each inner children, but as they are + // unrelated subquery plans if any returns stop then continue with the + // next one. + self.apply_inner_children(&mut |c| c.visit(visitor).continue_on_stop()) + // Run the recursive `visit` on each children. + .and_then_on_continue(|| { + self.apply_children(&mut |c| c.visit(visitor)) + }) + // Apply `post_visit` on self. + .and_then_on_continue(|| visitor.post_visit(self)) + }) + // Applying `pre_visit` or `post_visit` on self might have returned prune, + // but we need to propagate continue. + .continue_on_prune() } - /// Convenience utils for writing optimizers rule: recursively apply the given `op` to the node tree. - /// When `op` does not apply to a given node, it is left unchanged. - /// The default tree traversal direction is transform_up(Postorder Traversal). - fn transform(self, op: &F) -> Result - where - F: Fn(Self) -> Result>, - { - self.transform_up(op) + fn transform>( + &mut self, + transformer: &mut T, + ) -> Result { + // Apply `pre_transform` on self. + transformer + .pre_transform(self) + // If it returns continue (not prune or stop or stop all) then continue + // traversal on inner children and children. + .and_then_on_continue(|| + // Run the recursive `transform` on each children. + self + .transform_children(&mut |c| c.transform(transformer)) + // Apply `post_transform` on new self. + .and_then_on_continue(|| { + transformer.post_transform(self) + })) + // Applying `pre_transform` or `post_transform` on self might have returned + // prune, but we need to propagate continue. + .continue_on_prune() } /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its @@ -208,54 +243,109 @@ pub trait TreeNode: Sized { } } - /// Apply the closure `F` to the node's children - fn apply_children(&self, op: &mut F) -> Result + /// Apply `f` to the node's children. + fn apply_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result; + F: FnMut(&Self) -> Result; + + /// Apply `f` to the node's inner children. + fn apply_inner_children(&self, _f: &mut F) -> Result + where + F: FnMut(&Self) -> Result, + { + Ok(TreeNodeRecursion::Continue) + } /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) fn map_children(self, transform: F) -> Result where F: FnMut(Self) -> Result; + + /// Apply `f` to the node's children. + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result; + + /// Convenience function to do a preorder traversal of the tree nodes with `f` that + /// can't fail. + fn for_each(&self, f: &mut F) + where + F: FnMut(&Self), + { + self.visit_down(&mut |n| { + f(n); + Ok(TreeNodeRecursion::Continue) + }) + .unwrap(); + } + + /// Convenience function to collect the first non-empty value that `f` returns in a + /// preorder traversal. + fn collect_first(&self, f: &mut F) -> Option + where + F: FnMut(&Self) -> Option, + { + let mut res = None; + self.visit_down(&mut |n| { + res = f(n); + if res.is_some() { + Ok(TreeNodeRecursion::StopAll) + } else { + Ok(TreeNodeRecursion::Continue) + } + }) + .unwrap(); + res + } + + /// Convenience function to collect all values that `f` returns in a preorder + /// traversal. + fn collect(&self, f: &mut F) -> Vec + where + F: FnMut(&Self) -> Vec, + { + let mut res = vec![]; + self.visit_down(&mut |n| { + res.extend(f(n)); + Ok(TreeNodeRecursion::Continue) + }) + .unwrap(); + res + } } -/// Implements the [visitor -/// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively walking [`TreeNode`]s. -/// -/// [`TreeNodeVisitor`] allows keeping the algorithms -/// separate from the code to traverse the structure of the `TreeNode` -/// tree and makes it easier to add new types of tree node and -/// algorithms. +/// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for +/// recursively walking [`TreeNode`]s. /// -/// When passed to[`TreeNode::visit`], [`TreeNodeVisitor::pre_visit`] -/// and [`TreeNodeVisitor::post_visit`] are invoked recursively -/// on an node tree. +/// [`TreeNodeVisitor`] allows keeping the algorithms separate from the code to traverse +/// the structure of the [`TreeNode`] tree and makes it easier to add new types of tree +/// node and algorithms. /// -/// If an [`Err`] result is returned, recursion is stopped -/// immediately. +/// When passed to [`TreeNode::visit()`], [`TreeNodeVisitor::pre_visit()`] and +/// [`TreeNodeVisitor::post_visit()`] are invoked recursively on an node tree. +/// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled. /// -/// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no -/// children of that tree node are visited, nor is post_visit -/// called on that tree node -/// -/// If [`VisitRecursion::Stop`] is returned on a call to post_visit, no -/// siblings of that tree node are visited, nor is post_visit -/// called on its parent tree node -/// -/// If [`VisitRecursion::Skip`] is returned on a call to pre_visit, no -/// children of that tree node are visited. +/// If an [`Err`] result is returned, recursion is stopped immediately. pub trait TreeNodeVisitor: Sized { /// The node type which is visitable. - type N: TreeNode; + type Node: TreeNode; - /// Invoked before any children of `node` are visited. - fn pre_visit(&mut self, node: &Self::N) -> Result; + /// Invoked before any inner children or children of a node are visited. + fn pre_visit(&mut self, node: &Self::Node) -> Result; - /// Invoked after all children of `node` are visited. Default - /// implementation does nothing. - fn post_visit(&mut self, _node: &Self::N) -> Result { - Ok(VisitRecursion::Continue) - } + /// Invoked after all inner children and children of a node are visited. + fn post_visit(&mut self, _node: &Self::Node) -> Result; +} + +pub trait TreeNodeTransformer: Sized { + /// The node type which is visitable. + type Node: TreeNode; + + /// Invoked before any inner children or children of a node are modified. + fn pre_transform(&mut self, node: &mut Self::Node) -> Result; + + /// Invoked after all inner children and children of a node are modified. + fn post_transform(&mut self, node: &mut Self::Node) -> Result; } /// Trait for potentially recursively transform an [`TreeNode`] node @@ -289,15 +379,108 @@ pub enum RewriteRecursion { Skip, } -/// Controls how the [`TreeNode`] recursion should proceed for [`TreeNode::visit`]. +/// Controls how the [`TreeNode`] recursion should proceed for [`TreeNode::visit_down()`] and +/// [`TreeNode::visit()`]. #[derive(Debug)] -pub enum VisitRecursion { - /// Continue the visit to this node tree. +pub enum TreeNodeRecursion { + /// Continue the visit to the next node. Continue, - /// Keep recursive but skip applying op on the children - Skip, - /// Stop the visit to this node tree. + + /// Prune the current subtree. + /// If a preorder visit of a tree node returns [`TreeNodeRecursion::Prune`] then inner + /// children and children will not be visited and postorder visit of the node will not + /// be invoked. + Prune, + + /// Stop recursion on current tree. + /// If recursion runs on an inner tree then returning [`TreeNodeRecursion::Stop`] doesn't + /// stop recursion on the outer tree. Stop, + + /// Stop recursion on all (including outer) trees. + StopAll, +} + +impl TreeNodeRecursion { + fn continue_on_prune(self) -> TreeNodeRecursion { + match self { + TreeNodeRecursion::Prune => TreeNodeRecursion::Continue, + o => o, + } + } + + fn fail_on_prune(self) -> TreeNodeRecursion { + match self { + TreeNodeRecursion::Prune => panic!("Recursion can't prune."), + o => o, + } + } + + fn continue_on_stop(self) -> TreeNodeRecursion { + match self { + TreeNodeRecursion::Stop => TreeNodeRecursion::Continue, + o => o, + } + } +} + +/// This helper trait provide functions to control recursion on +/// [`Result`]. +pub trait TreeNodeRecursionResult: Sized { + fn and_then_on_continue(self, f: F) -> Result + where + F: FnOnce() -> Result; + + fn continue_on_prune(self) -> Result; + + fn fail_on_prune(self) -> Result; + + fn continue_on_stop(self) -> Result; +} + +impl TreeNodeRecursionResult for Result { + fn and_then_on_continue(self, f: F) -> Result + where + F: FnOnce() -> Result, + { + match self? { + TreeNodeRecursion::Continue => f(), + o => Ok(o), + } + } + + fn continue_on_prune(self) -> Result { + self.map(|tnr| tnr.continue_on_prune()) + } + + fn fail_on_prune(self) -> Result { + self.map(|tnr| tnr.fail_on_prune()) + } + + fn continue_on_stop(self) -> Result { + self.map(|tnr| tnr.continue_on_stop()) + } +} + +pub trait VisitRecursionIterator: Iterator { + fn for_each_till_continue(self, f: &mut F) -> Result + where + F: FnMut(Self::Item) -> Result; +} + +impl VisitRecursionIterator for I { + fn for_each_till_continue(self, f: &mut F) -> Result + where + F: FnMut(Self::Item) -> Result, + { + for i in self { + match f(i)? { + TreeNodeRecursion::Continue => {} + o => return Ok(o), + } + } + Ok(TreeNodeRecursion::Continue) + } } pub enum Transformed { @@ -342,19 +525,11 @@ pub trait DynTreeNode { /// Blanket implementation for Arc for any tye that implements /// [`DynTreeNode`] (such as [`Arc`]) impl TreeNode for Arc { - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - for child in self.arc_children() { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + self.arc_children().iter().for_each_till_continue(f) } fn map_children(self, transform: F) -> Result @@ -371,4 +546,18 @@ impl TreeNode for Arc { Ok(self) } } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + let mut new_children = self.arc_children(); + if !new_children.is_empty() { + let tnr = new_children.iter_mut().for_each_till_continue(f)?; + *self = self.with_new_arc_children(self.clone(), new_children)?; + Ok(tnr) + } else { + Ok(TreeNodeRecursion::Continue) + } + } } diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index be74afa1f4d66..870bddbaaaa57 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -37,7 +37,7 @@ use crate::{error::Result, scalar::ScalarValue}; use super::PartitionedFile; use crate::datasource::listing::ListingTableUrl; use crate::execution::context::SessionState; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{internal_err, Column, DFField, DFSchema, DataFusionError}; use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility}; use datafusion_physical_expr::create_physical_expr; @@ -52,17 +52,18 @@ use object_store::{ObjectMeta, ObjectStore}; /// was performed pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { let mut is_applicable = true; - expr.apply(&mut |expr| { + expr.visit_down(&mut |expr| { match expr { Expr::Column(Column { ref name, .. }) => { is_applicable &= col_names.contains(name); if is_applicable { - Ok(VisitRecursion::Skip) + Ok(TreeNodeRecursion::Prune) } else { - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } - Expr::Literal(_) + Expr::Nop + | Expr::Literal(_) | Expr::Alias(_) | Expr::OuterReferenceColumn(_, _) | Expr::ScalarVariable(_, _) @@ -88,27 +89,27 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { | Expr::ScalarSubquery(_) | Expr::GetIndexedField { .. } | Expr::GroupingSet(_) - | Expr::Case { .. } => Ok(VisitRecursion::Continue), + | Expr::Case { .. } => Ok(TreeNodeRecursion::Continue), Expr::ScalarFunction(scalar_function) => { match &scalar_function.func_def { ScalarFunctionDefinition::BuiltIn(fun) => { match fun.volatility() { - Volatility::Immutable => Ok(VisitRecursion::Continue), + Volatility::Immutable => Ok(TreeNodeRecursion::Continue), // TODO: Stable functions could be `applicable`, but that would require access to the context Volatility::Stable | Volatility::Volatile => { is_applicable = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } } ScalarFunctionDefinition::UDF(fun) => { match fun.signature().volatility { - Volatility::Immutable => Ok(VisitRecursion::Continue), + Volatility::Immutable => Ok(TreeNodeRecursion::Continue), // TODO: Stable functions could be `applicable`, but that would require access to the context Volatility::Stable | Volatility::Volatile => { is_applicable = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } } @@ -128,7 +129,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { | Expr::Wildcard { .. } | Expr::Placeholder(_) => { is_applicable = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } }) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 7c3f7d9384abc..f573fd11a8c22 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -17,7 +17,7 @@ use arrow::{array::ArrayRef, datatypes::Schema}; use arrow_schema::FieldRef; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{Column, DataFusionError, Result, ScalarValue}; use parquet::file::metadata::ColumnChunkMetaData; use parquet::schema::types::SchemaDescriptor; @@ -259,7 +259,7 @@ impl BloomFilterPruningPredicate { fn get_predicate_columns(expr: &Arc) -> HashSet { let mut columns = HashSet::new(); - expr.apply(&mut |expr| { + expr.visit_down(&mut |expr| { if let Some(binary_expr) = expr.as_any().downcast_ref::() { @@ -269,7 +269,7 @@ impl BloomFilterPruningPredicate { columns.insert(column.name().to_string()); } } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) }) // no way to fail as only Ok(VisitRecursion::Continue) is returned .unwrap(); diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 58a4f08341d64..3b16efb7d089f 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -38,7 +38,7 @@ use crate::{ use datafusion_common::{ alias::AliasGenerator, exec_err, not_impl_err, plan_datafusion_err, plan_err, - tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion}, + tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}, }; use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ @@ -2098,9 +2098,9 @@ impl<'a> BadPlanVisitor<'a> { } impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, node: &Self::N) -> Result { + fn pre_visit(&mut self, node: &Self::Node) -> Result { match node { LogicalPlan::Ddl(ddl) if !self.options.allow_ddl => { plan_err!("DDL not supported: {}", ddl.name()) @@ -2114,9 +2114,13 @@ impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> { LogicalPlan::Statement(stmt) if !self.options.allow_statements => { plan_err!("Statement not supported: {}", stmt.name()) } - _ => Ok(VisitRecursion::Continue), + _ => Ok(TreeNodeRecursion::Continue), } } + + fn post_visit(&mut self, _node: &Self::Node) -> Result { + Ok(TreeNodeRecursion::Continue) + } } #[cfg(test)] diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index c50ea36b68ecb..7e159985f57cb 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -178,7 +178,7 @@ fn normalize_group_exprs(group_exprs: GroupExprsRef) -> GroupExprs { fn discard_column_index(group_expr: Arc) -> Arc { group_expr .clone() - .transform(&|expr| { + .transform_up(&|expr| { let normalized_form: Option> = match expr.as_any().downcast_ref::() { Some(column) => Some(Arc::new(Column::new(column.name(), 0))), diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index f2e04989ef661..3f2416f53f45f 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -47,7 +47,9 @@ use crate::physical_plan::{ }; use arrow::compute::SortOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, VisitRecursionIterator, +}; use datafusion_expr::logical_plan::JoinType; use datafusion_physical_expr::expressions::{Column, NoOp}; use datafusion_physical_expr::utils::map_columns_before_projection; @@ -1476,18 +1478,11 @@ impl DistributionContext { } impl TreeNode for DistributionContext { - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - for child in self.children() { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - Ok(VisitRecursion::Continue) + self.children().iter().for_each_till_continue(f) } fn map_children(self, transform: F) -> Result @@ -1505,6 +1500,23 @@ impl TreeNode for DistributionContext { DistributionContext::new_from_children_nodes(children_nodes, self.plan) } } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + let mut children = self.children(); + if children.is_empty() { + Ok(TreeNodeRecursion::Continue) + } else { + let tnr = children.iter_mut().for_each_till_continue(f)?; + *self = DistributionContext::new_from_children_nodes( + children, + self.plan.clone(), + )?; + Ok(tnr) + } + } } /// implement Display method for `DistributionContext` struct. @@ -1566,20 +1578,11 @@ impl PlanWithKeyRequirements { } impl TreeNode for PlanWithKeyRequirements { - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - let children = self.children(); - for child in children { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + self.children().iter().for_each_till_continue(f) } fn map_children(self, transform: F) -> Result @@ -1605,6 +1608,22 @@ impl TreeNode for PlanWithKeyRequirements { Ok(self) } } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + let mut children = self.children(); + if !children.is_empty() { + let tnr = children.iter_mut().for_each_till_continue(f)?; + let children_plans = children.into_iter().map(|c| c.plan).collect(); + self.plan = + with_new_children_if_necessary(self.plan.clone(), children_plans)?.into(); + Ok(tnr) + } else { + Ok(TreeNodeRecursion::Continue) + } + } } /// Since almost all of these tests explicitly use `ParquetExec` they only run with the parquet feature flag on diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 277404b301c46..4f6d159c2a1d3 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -57,7 +57,9 @@ use crate::physical_plan::{ with_new_children_if_necessary, Distribution, ExecutionPlan, InputOrderMode, }; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, VisitRecursionIterator, +}; use datafusion_common::{plan_err, DataFusionError}; use datafusion_physical_expr::{PhysicalSortExpr, PhysicalSortRequirement}; @@ -157,20 +159,11 @@ impl PlanWithCorrespondingSort { } impl TreeNode for PlanWithCorrespondingSort { - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - let children = self.children(); - for child in children { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + self.children().iter().for_each_till_continue(f) } fn map_children(self, transform: F) -> Result @@ -188,6 +181,23 @@ impl TreeNode for PlanWithCorrespondingSort { PlanWithCorrespondingSort::new_from_children_nodes(children_nodes, self.plan) } } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + let mut children = self.children(); + if children.is_empty() { + Ok(TreeNodeRecursion::Continue) + } else { + let tnr = children.iter_mut().for_each_till_continue(f)?; + *self = PlanWithCorrespondingSort::new_from_children_nodes( + children, + self.plan.clone(), + )?; + Ok(tnr) + } + } } /// This object is used within the [EnforceSorting] rule to track the closest @@ -273,20 +283,11 @@ impl PlanWithCorrespondingCoalescePartitions { } impl TreeNode for PlanWithCorrespondingCoalescePartitions { - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - let children = self.children(); - for child in children { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + self.children().iter().for_each_till_continue(f) } fn map_children(self, transform: F) -> Result @@ -307,6 +308,23 @@ impl TreeNode for PlanWithCorrespondingCoalescePartitions { ) } } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + let mut children = self.children(); + if children.is_empty() { + Ok(TreeNodeRecursion::Continue) + } else { + let tnr = children.iter_mut().for_each_till_continue(f)?; + *self = PlanWithCorrespondingCoalescePartitions::new_from_children_nodes( + children, + self.plan.clone(), + )?; + Ok(tnr) + } + } } /// The boolean flag `repartition_sorts` defined in the config indicates diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index d59248aadf056..122ce7171bd38 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -28,7 +28,9 @@ use crate::physical_plan::joins::SymmetricHashJoinExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; use datafusion_common::config::OptimizerOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, VisitRecursionIterator, +}; use datafusion_common::{plan_err, DataFusionError}; use datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supported}; @@ -94,19 +96,11 @@ impl PipelineStatePropagator { } impl TreeNode for PipelineStatePropagator { - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - for child in &self.children { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + self.children.iter().for_each_till_continue(f) } fn map_children(self, transform: F) -> Result @@ -130,6 +124,21 @@ impl TreeNode for PipelineStatePropagator { Ok(self) } } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + if !self.children.is_empty() { + let tnr = self.children.iter_mut().for_each_till_continue(f)?; + let children_plans = self.children.iter().map(|c| c.plan.clone()).collect(); + self.plan = + with_new_children_if_necessary(self.plan.clone(), children_plans)?.into(); + Ok(tnr) + } else { + Ok(TreeNodeRecursion::Continue) + } + } } /// This function propagates finiteness information and rejects any plan with diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 664afbe822ff5..bf59b6c2f80c6 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -43,7 +43,7 @@ use crate::physical_plan::{Distribution, ExecutionPlan}; use arrow_schema::SchemaRef; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::JoinSide; use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::{ @@ -255,12 +255,12 @@ fn try_unifying_projections( // Collect the column references usage in the outer projection. projection.expr().iter().for_each(|(expr, _)| { - expr.apply(&mut |expr| { + expr.visit_down(&mut |expr| { Ok({ if let Some(column) = expr.as_any().downcast_ref::() { *column_ref_map.entry(column.clone()).or_default() += 1; } - VisitRecursion::Continue + TreeNodeRecursion::Continue }) }) .unwrap(); diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index b2ba7596db8d2..2423ccc4c32e3 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -678,7 +678,7 @@ fn rewrite_column_expr( column_old: &phys_expr::Column, column_new: &phys_expr::Column, ) -> Result> { - e.transform(&|expr| { + e.transform_up(&|expr| { if let Some(column) = expr.as_any().downcast_ref::() { if column == column_old { return Ok(Transformed::Yes(Arc::new(column_new.clone()))); diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index af45df7d8474b..1ebece70c1284 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -30,7 +30,9 @@ use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; use super::utils::is_repartition; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, VisitRecursionIterator, +}; use datafusion_physical_plan::unbounded_output; /// For a given `plan`, this object carries the information one needs from its @@ -118,18 +120,11 @@ impl OrderPreservationContext { } impl TreeNode for OrderPreservationContext { - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - for child in self.children() { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - Ok(VisitRecursion::Continue) + self.children().iter().for_each_till_continue(f) } fn map_children(self, transform: F) -> Result @@ -147,6 +142,23 @@ impl TreeNode for OrderPreservationContext { OrderPreservationContext::new_from_children_nodes(children_nodes, self.plan) } } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + let mut children = self.children(); + if children.is_empty() { + Ok(TreeNodeRecursion::Continue) + } else { + let tnr = children.iter_mut().for_each_till_continue(f)?; + *self = OrderPreservationContext::new_from_children_nodes( + children, + self.plan.clone(), + )?; + Ok(tnr) + } + } } /// Calculates the updated plan by replacing executors that lose ordering diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index b9502d92ac12f..4b06218df9e98 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -28,7 +28,9 @@ use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, VisitRecursionIterator, +}; use datafusion_common::{plan_err, DataFusionError, JoinSide, Result}; use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::Column; @@ -82,20 +84,11 @@ impl SortPushDown { } impl TreeNode for SortPushDown { - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - let children = self.children(); - for child in children { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + self.children().iter().for_each_till_continue(f) } fn map_children(mut self, transform: F) -> Result @@ -118,6 +111,22 @@ impl TreeNode for SortPushDown { }; Ok(self) } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + let mut children = self.children(); + if !children.is_empty() { + let tnr = children.iter_mut().for_each_till_continue(f)?; + let children_plans = children.into_iter().map(|c| c.plan).collect(); + self.plan = + with_new_children_if_necessary(self.plan.clone(), children_plans)?.into(); + Ok(tnr) + } else { + Ok(TreeNodeRecursion::Continue) + } + } } pub(crate) fn pushdown_sorts( diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index e5816eb49ebb1..fd9e81c1b7520 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -381,6 +381,9 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { Expr::OuterReferenceColumn(_, _) => { internal_err!("Create physical name does not support OuterReferenceColumn") } + Expr::Nop => { + internal_err!("Create physical name does not support Nop expression") + } } } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index f0aab95b8f0df..5369d502113b3 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -31,10 +31,10 @@ use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{internal_err, DFSchema, OwnedTableReference}; use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue}; use std::collections::HashSet; -use std::fmt; use std::fmt::{Display, Formatter, Write}; use std::hash::{BuildHasher, Hash, Hasher}; use std::sync::Arc; +use std::{fmt, mem}; /// `Expr` is a central struct of DataFusion's query API, and /// represent logical expressions such as `A + 1`, or `CAST(c1 AS @@ -81,8 +81,10 @@ use std::sync::Arc; /// assert_eq!(binary_expr.op, Operator::Eq); /// } /// ``` -#[derive(Clone, PartialEq, Eq, Hash, Debug)] +#[derive(Clone, PartialEq, Eq, Hash, Debug, Default)] pub enum Expr { + #[default] + Nop, /// An expression with a specific name. Alias(Alias), /// A named reference to a qualified filed in a schema. @@ -784,6 +786,7 @@ impl Expr { /// Useful for non-rust based bindings pub fn variant_name(&self) -> &str { match self { + Expr::Nop { .. } => "Nop", Expr::AggregateFunction { .. } => "AggregateFunction", Expr::Alias(..) => "Alias", Expr::Between { .. } => "Between", @@ -954,11 +957,11 @@ impl Expr { } /// Remove an alias from an expression if one exists. - pub fn unalias(self) -> Expr { - match self { - Expr::Alias(alias) => alias.expr.as_ref().clone(), - _ => self, + pub fn unalias(&mut self) -> &mut Self { + if let Expr::Alias(alias) = self { + *self = mem::take(alias.expr.as_mut()); } + self } /// Return `self IN ` if `negated` is false, otherwise @@ -1147,7 +1150,7 @@ impl Expr { /// For example, gicen an expression like ` = $0` will infer `$0` to /// have type `int32`. pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result { - self.transform(&|mut expr| { + self.transform_up(&|mut expr| { // Default to assuming the arguments are the same type if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr { rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; @@ -1204,6 +1207,7 @@ macro_rules! expr_vec_fmt { impl fmt::Display for Expr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { + Expr::Nop => write!(f, "NOP"), Expr::Alias(Alias { expr, name, .. }) => write!(f, "{expr} AS {name}"), Expr::Column(c) => write!(f, "{c}"), Expr::OuterReferenceColumn(_, c) => write!(f, "outer_ref({c})"), @@ -1446,6 +1450,7 @@ fn create_function_name(fun: &str, distinct: bool, args: &[Expr]) -> Result 2)". fn create_name(e: &Expr) -> Result { match e { + Expr::Nop => Ok("NOP".to_string()), Expr::Alias(Alias { name, .. }) => Ok(name.clone()), Expr::Column(c) => Ok(c.flat_name()), Expr::OuterReferenceColumn(_, c) => Ok(format!("outer_ref({})", c.flat_name())), diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 1f04c80833f09..cbdeb16f99b23 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -20,7 +20,7 @@ use crate::expr::Alias; use crate::logical_plan::Projection; use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeTransformer}; use datafusion_common::Result; use datafusion_common::{Column, DFSchema}; use std::collections::HashMap; @@ -33,7 +33,7 @@ pub use order_by::rewrite_sort_cols_by_aggs; /// Recursively call [`Column::normalize_with_schemas`] on all [`Column`] expressions /// in the `expr` expression tree. pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = LogicalPlanBuilder::normalize(plan, c)?; @@ -57,7 +57,7 @@ pub fn normalize_col_with_schemas( schemas: &[&Arc], using_columns: &[HashSet], ) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = c.normalize_with_schemas(schemas, using_columns)?; @@ -75,7 +75,7 @@ pub fn normalize_col_with_schemas_and_ambiguity_check( schemas: &[&[&DFSchema]], using_columns: &[HashSet], ) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = @@ -102,7 +102,7 @@ pub fn normalize_cols( /// Recursively replace all [`Column`] expressions in a given expression tree with /// `Column` expressions provided by the hash map argument. pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { Ok({ if let Expr::Column(c) = &expr { match replace_map.get(c) { @@ -122,7 +122,7 @@ pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Resul /// For example, if there were expressions like `foo.bar` this would /// rewrite it to just `bar`. pub fn unnormalize_col(expr: Expr) -> Expr { - expr.transform(&|expr| { + expr.transform_up(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = Column { @@ -164,7 +164,7 @@ pub fn unnormalize_cols(exprs: impl IntoIterator) -> Vec { /// Recursively remove all the ['OuterReferenceColumn'] and return the inside Column /// in the expression tree. pub fn strip_outer_reference(expr: Expr) -> Expr { - expr.transform(&|expr| { + expr.transform_up(&|expr| { Ok({ if let Expr::OuterReferenceColumn(_, col) = expr { Transformed::Yes(Expr::Column(col)) @@ -248,12 +248,12 @@ pub fn unalias(expr: Expr) -> Expr { /// /// This is important when optimizing plans to ensure the output /// schema of plan nodes don't change after optimization -pub fn rewrite_preserving_name(expr: Expr, rewriter: &mut R) -> Result +pub fn rewrite_preserving_name(mut expr: Expr, transformer: &mut R) -> Result where - R: TreeNodeRewriter, + R: TreeNodeTransformer, { let original_name = expr.name_for_alias()?; - let expr = expr.rewrite(rewriter)?; + expr.transform(transformer)?; expr.alias_if_changed(original_name) } @@ -263,7 +263,7 @@ mod test { use crate::expr::Sort; use crate::{col, lit, Cast}; use arrow::datatypes::DataType; - use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; + use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{DFField, DFSchema, ScalarValue}; use std::ops::Add; @@ -272,17 +272,17 @@ mod test { v: Vec, } - impl TreeNodeRewriter for RecordingRewriter { - type N = Expr; + impl TreeNodeTransformer for RecordingRewriter { + type Node = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { + fn pre_transform(&mut self, expr: &mut Expr) -> Result { self.v.push(format!("Previsited {expr}")); - Ok(RewriteRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn mutate(&mut self, expr: Expr) -> Result { + fn post_transform(&mut self, expr: &mut Expr) -> Result { self.v.push(format!("Mutated {expr}")); - Ok(expr) + Ok(TreeNodeRecursion::Continue) } } @@ -305,11 +305,17 @@ mod test { }; // rewrites "foo" --> "bar" - let rewritten = col("state").eq(lit("foo")).transform(&transformer).unwrap(); + let rewritten = col("state") + .eq(lit("foo")) + .transform_up(&transformer) + .unwrap(); assert_eq!(rewritten, col("state").eq(lit("bar"))); // doesn't rewrite - let rewritten = col("state").eq(lit("baz")).transform(&transformer).unwrap(); + let rewritten = col("state") + .eq(lit("baz")) + .transform_up(&transformer) + .unwrap(); assert_eq!(rewritten, col("state").eq(lit("baz"))); } @@ -399,7 +405,8 @@ mod test { #[test] fn rewriter_visit() { let mut rewriter = RecordingRewriter::default(); - col("state").eq(lit("CO")).rewrite(&mut rewriter).unwrap(); + let mut expr = col("state").eq(lit("CO")); + expr.transform(&mut rewriter).unwrap(); assert_eq!( rewriter.v, @@ -439,22 +446,28 @@ mod test { /// rewrites `expr_from` to `rewrite_to` using /// `rewrite_preserving_name` verifying the result is `expected_expr` fn test_rewrite(expr_from: Expr, rewrite_to: Expr) { - struct TestRewriter { - rewrite_to: Expr, - } + struct TestTransformer {} + + impl TreeNodeTransformer for TestTransformer { + type Node = Expr; - impl TreeNodeRewriter for TestRewriter { - type N = Expr; + fn pre_transform( + &mut self, + _node: &mut Self::Node, + ) -> Result { + Ok(TreeNodeRecursion::Continue) + } - fn mutate(&mut self, _: Expr) -> Result { - Ok(self.rewrite_to.clone()) + fn post_transform( + &mut self, + _node: &mut Self::Node, + ) -> Result { + Ok(TreeNodeRecursion::Continue) } } - let mut rewriter = TestRewriter { - rewrite_to: rewrite_to.clone(), - }; - let expr = rewrite_preserving_name(expr_from.clone(), &mut rewriter).unwrap(); + let mut transformer = TestTransformer {}; + let expr = rewrite_preserving_name(expr_from.clone(), &mut transformer).unwrap(); let original_name = match &expr_from { Expr::Sort(Sort { expr, .. }) => expr.display_name(), diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index c87a724d5646b..1e7efcafd04df 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -83,7 +83,7 @@ fn rewrite_in_terms_of_projection( ) -> Result { // assumption is that each item in exprs, such as "b + c" is // available as an output column named "b + c" - expr.transform(&|expr| { + expr.transform_up(&|expr| { // search for unnormalized names first such as "c1" (such as aliases) if let Some(found) = proj_exprs.iter().find(|a| (**a) == expr) { let col = Expr::Column( diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index e5b0185d90e0b..71987667be4a3 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -67,6 +67,7 @@ impl ExprSchemable for Expr { /// (e.g. `[utf8] + [bool]`). fn get_type(&self, schema: &S) -> Result { match self { + Expr::Nop => Ok(DataType::Null), Expr::Alias(Alias { expr, name, .. }) => match &**expr { Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type { None => schema.data_type(&Column::from_name(name)).cloned(), @@ -251,7 +252,8 @@ impl ExprSchemable for Expr { | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } | Expr::Placeholder(_) => Ok(true), - Expr::IsNull(_) + Expr::Nop + | Expr::IsNull(_) | Expr::IsNotNull(_) | Expr::IsTrue(_) | Expr::IsFalse(_) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 88310dab82a27..eb2085ec9e155 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1220,9 +1220,10 @@ pub fn project_with_column_index( let alias_expr = expr .into_iter() .enumerate() - .map(|(i, e)| match e { + .map(|(i, mut e)| match e { Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => { - e.unalias().alias(schema.field(i).name()) + e.unalias(); + e.alias(schema.field(i).name()) } Expr::Column(Column { relation: _, diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index 112dbf74dba18..2a8c4ce5912d3 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -19,7 +19,7 @@ use crate::LogicalPlan; use arrow::datatypes::Schema; use datafusion_common::display::GraphvizBuilder; -use datafusion_common::tree_node::{TreeNodeVisitor, VisitRecursion}; +use datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeVisitor}; use datafusion_common::DataFusionError; use std::fmt; @@ -49,12 +49,12 @@ impl<'a, 'b> IndentVisitor<'a, 'b> { } impl<'a, 'b> TreeNodeVisitor for IndentVisitor<'a, 'b> { - type N = LogicalPlan; + type Node = LogicalPlan; fn pre_visit( &mut self, plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { if self.indent > 0 { writeln!(self.f)?; } @@ -69,15 +69,15 @@ impl<'a, 'b> TreeNodeVisitor for IndentVisitor<'a, 'b> { } self.indent += 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } fn post_visit( &mut self, _plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { self.indent -= 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } } @@ -171,12 +171,12 @@ impl<'a, 'b> GraphvizVisitor<'a, 'b> { } impl<'a, 'b> TreeNodeVisitor for GraphvizVisitor<'a, 'b> { - type N = LogicalPlan; + type Node = LogicalPlan; fn pre_visit( &mut self, plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { let id = self.graphviz_builder.next_id(); // Create a new graph node for `plan` such as @@ -204,18 +204,18 @@ impl<'a, 'b> TreeNodeVisitor for GraphvizVisitor<'a, 'b> { } self.parent_ids.push(id); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } fn post_visit( &mut self, _plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { // always be non-empty as pre_visit always pushes // So it should always be Ok(true) let res = self.parent_ids.pop(); res.ok_or(DataFusionError::Internal("Fail to format".to_string())) - .map(|_| VisitRecursion::Continue) + .map(|_| TreeNodeRecursion::Continue) } } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 1f3711407a149..3d8a8356f397f 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -32,8 +32,7 @@ use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::{DmlStatement, Statement}; use crate::utils::{ enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs, - grouping_set_expr_count, grouping_set_to_exprlist, inspect_expr_pre, - split_conjunction, + grouping_set_expr_count, grouping_set_to_exprlist, split_conjunction, }; use crate::{ build_join_schema, expr_vec_fmt, BinaryExpr, CreateMemoryTable, CreateView, Expr, @@ -43,8 +42,8 @@ use crate::{ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::tree_node::{ - RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter, TreeNodeVisitor, - VisitRecursion, + Transformed, TreeNode, TreeNodeRecursion, TreeNodeRecursionResult, + TreeNodeTransformer, VisitRecursionIterator, }; use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, @@ -277,9 +276,9 @@ impl LogicalPlan { /// children pub fn expressions(self: &LogicalPlan) -> Vec { let mut exprs = vec![]; - self.inspect_expressions(|e| { + self.apply_expressions(&mut |e| { exprs.push(e.clone()); - Ok(()) as Result<()> + Ok(TreeNodeRecursion::Continue) }) // closure always returns OK .unwrap(); @@ -290,13 +289,13 @@ impl LogicalPlan { /// logical plan nodes and all its descendant nodes. pub fn all_out_ref_exprs(self: &LogicalPlan) -> Vec { let mut exprs = vec![]; - self.inspect_expressions(|e| { + self.apply_expressions(&mut |e| { find_out_reference_exprs(e).into_iter().for_each(|e| { if !exprs.contains(&e) { exprs.push(e) } }); - Ok(()) as Result<(), DataFusionError> + Ok(TreeNodeRecursion::Continue) }) // closure always returns OK .unwrap(); @@ -311,37 +310,41 @@ impl LogicalPlan { exprs } - /// Calls `f` on all expressions (non-recursively) in the current - /// logical plan node. This does not include expressions in any - /// children. - pub fn inspect_expressions(self: &LogicalPlan, mut f: F) -> Result<(), E> + /// Apply `f` on expressions of the plan node. + /// `f` is not allowed to return [`TreeNodeRecursion::Prune`]. + pub fn apply_expressions(&self, f: &mut F) -> Result where - F: FnMut(&Expr) -> Result<(), E>, + F: FnMut(&Expr) -> Result, { + let f = &mut |e: &Expr| f(e).fail_on_prune(); + match self { LogicalPlan::Projection(Projection { expr, .. }) => { - expr.iter().try_for_each(f) + expr.iter().for_each_till_continue(f) } LogicalPlan::Values(Values { values, .. }) => { - values.iter().flatten().try_for_each(f) + values.iter().flatten().for_each_till_continue(f) } LogicalPlan::Filter(Filter { predicate, .. }) => f(predicate), LogicalPlan::Repartition(Repartition { partitioning_scheme, .. }) => match partitioning_scheme { - Partitioning::Hash(expr, _) => expr.iter().try_for_each(f), - Partitioning::DistributeBy(expr) => expr.iter().try_for_each(f), - Partitioning::RoundRobinBatch(_) => Ok(()), + Partitioning::Hash(expr, _) => expr.iter().for_each_till_continue(f), + Partitioning::DistributeBy(expr) => expr.iter().for_each_till_continue(f), + Partitioning::RoundRobinBatch(_) => Ok(TreeNodeRecursion::Continue), }, LogicalPlan::Window(Window { window_expr, .. }) => { - window_expr.iter().try_for_each(f) + window_expr.iter().for_each_till_continue(f) } LogicalPlan::Aggregate(Aggregate { group_expr, aggr_expr, .. - }) => group_expr.iter().chain(aggr_expr.iter()).try_for_each(f), + }) => group_expr + .iter() + .chain(aggr_expr.iter()) + .for_each_till_continue(f), // There are two part of expression for join, equijoin(on) and non-equijoin(filter). // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. // 2. the second part is non-equijoin(filter). @@ -349,22 +352,21 @@ impl LogicalPlan { on.iter() // it not ideal to create an expr here to analyze them, but could cache it on the Join itself .map(|(l, r)| Expr::eq(l.clone(), r.clone())) - .try_for_each(|e| f(&e))?; - - if let Some(filter) = filter.as_ref() { - f(filter) - } else { - Ok(()) - } + .for_each_till_continue(&mut |e| f(&e)) + .and_then_on_continue(|| filter.iter().for_each_till_continue(f)) } - LogicalPlan::Sort(Sort { expr, .. }) => expr.iter().try_for_each(f), + LogicalPlan::Sort(Sort { expr, .. }) => expr.iter().for_each_till_continue(f), LogicalPlan::Extension(extension) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs - extension.node.expressions().iter().try_for_each(f) + extension + .node + .expressions() + .iter() + .for_each_till_continue(f) } LogicalPlan::TableScan(TableScan { filters, .. }) => { - filters.iter().try_for_each(f) + filters.iter().for_each_till_continue(f) } LogicalPlan::Unnest(Unnest { column, .. }) => { f(&Expr::Column(column.clone())) @@ -378,7 +380,7 @@ impl LogicalPlan { .iter() .chain(select_expr.iter()) .chain(sort_expr.clone().unwrap_or(vec![]).iter()) - .try_for_each(f), + .for_each_till_continue(f), // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::Subquery(_) @@ -394,7 +396,7 @@ impl LogicalPlan { | LogicalPlan::Ddl(_) | LogicalPlan::Copy(_) | LogicalPlan::DescribeTable(_) - | LogicalPlan::Prepare(_) => Ok(()), + | LogicalPlan::Prepare(_) => Ok(TreeNodeRecursion::Continue), } } @@ -440,7 +442,7 @@ impl LogicalPlan { pub fn using_columns(&self) -> Result>, DataFusionError> { let mut using_columns: Vec> = vec![]; - self.apply(&mut |plan| { + self.visit_down(&mut |plan| { if let LogicalPlan::Join(Join { join_constraint: JoinConstraint::Using, on, @@ -456,7 +458,7 @@ impl LogicalPlan { })?; using_columns.push(columns); } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(using_columns) @@ -642,7 +644,7 @@ impl LogicalPlan { } LogicalPlan::Filter { .. } => { assert_eq!(1, expr.len()); - let predicate = expr.pop().unwrap(); + let mut predicate = expr.pop().unwrap(); // filter predicates should not contain aliased expressions so we remove any aliases // before this logic was added we would have aliases within filters such as for @@ -658,29 +660,39 @@ impl LogicalPlan { struct RemoveAliases {} - impl TreeNodeRewriter for RemoveAliases { - type N = Expr; + impl TreeNodeTransformer for RemoveAliases { + type Node = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { + fn pre_transform( + &mut self, + expr: &mut Expr, + ) -> Result { match expr { Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::InSubquery(_) => { // subqueries could contain aliases so we don't recurse into those - Ok(RewriteRecursion::Stop) + Ok(TreeNodeRecursion::Prune) } - Expr::Alias(_) => Ok(RewriteRecursion::Mutate), - _ => Ok(RewriteRecursion::Continue), + Expr::Alias(_) => { + expr.unalias(); + Ok(TreeNodeRecursion::Prune) + } + _ => Ok(TreeNodeRecursion::Continue), } } - fn mutate(&mut self, expr: Expr) -> Result { - Ok(expr.unalias()) + fn post_transform( + &mut self, + expr: &mut Expr, + ) -> Result { + expr.unalias(); + Ok(TreeNodeRecursion::Continue) } } let mut remove_aliases = RemoveAliases {}; - let predicate = predicate.rewrite(&mut remove_aliases)?; + predicate.transform(&mut remove_aliases)?; Filter::try_new(predicate, Arc::new(inputs[0].clone())) .map(LogicalPlan::Filter) @@ -754,10 +766,10 @@ impl LogicalPlan { // The first part of expr is equi-exprs, // and the struct of each equi-expr is like `left-expr = right-expr`. assert_eq!(expr.len(), equi_expr_count); - let new_on:Vec<(Expr,Expr)> = expr.into_iter().map(|equi_expr| { + let new_on:Vec<(Expr,Expr)> = expr.into_iter().map(|mut equi_expr| { // SimplifyExpression rule may add alias to the equi_expr. - let unalias_expr = equi_expr.clone().unalias(); - if let Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right }) = unalias_expr { + equi_expr.unalias(); + if let Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right }) = equi_expr { Ok((*left, *right)) } else { internal_err!( @@ -1126,59 +1138,27 @@ impl LogicalPlan { | LogicalPlan::Extension(_) => None, } } -} -impl LogicalPlan { - /// applies `op` to any subqueries in the plan - pub(crate) fn apply_subqueries(&self, op: &mut F) -> datafusion_common::Result<()> + /// Apply `f` on the root nodes of subquery plans of the plan node. + /// `f` is not allowed to return [`TreeNodeRecursion::Prune`]. + pub fn apply_subqueries(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> datafusion_common::Result, + F: FnMut(&Self) -> Result, { - self.inspect_expressions(|expr| { - // recursively look for subqueries - inspect_expr_pre(expr, |expr| { - match expr { - Expr::Exists(Exists { subquery, .. }) - | Expr::InSubquery(InSubquery { subquery, .. }) - | Expr::ScalarSubquery(subquery) => { - // use a synthetic plan so the collector sees a - // LogicalPlan::Subquery (even though it is - // actually a Subquery alias) - let synthetic_plan = LogicalPlan::Subquery(subquery.clone()); - synthetic_plan.apply(op)?; - } - _ => {} + self.apply_expressions(&mut |e| { + e.visit_down(&mut |e| match e { + Expr::Exists(Exists { subquery, .. }) + | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::ScalarSubquery(subquery) => { + // use a synthetic plan so the collector sees a + // LogicalPlan::Subquery (even though it is + // actually a Subquery alias) + let synthetic_plan = LogicalPlan::Subquery(subquery.clone()); + f(&synthetic_plan).fail_on_prune() } - Ok::<(), DataFusionError>(()) + _ => Ok(TreeNodeRecursion::Continue), }) - })?; - Ok(()) - } - - /// applies visitor to any subqueries in the plan - pub(crate) fn visit_subqueries(&self, v: &mut V) -> datafusion_common::Result<()> - where - V: TreeNodeVisitor, - { - self.inspect_expressions(|expr| { - // recursively look for subqueries - inspect_expr_pre(expr, |expr| { - match expr { - Expr::Exists(Exists { subquery, .. }) - | Expr::InSubquery(InSubquery { subquery, .. }) - | Expr::ScalarSubquery(subquery) => { - // use a synthetic plan so the visitor sees a - // LogicalPlan::Subquery (even though it is - // actually a Subquery alias) - let synthetic_plan = LogicalPlan::Subquery(subquery.clone()); - synthetic_plan.visit(v)?; - } - _ => {} - } - Ok::<(), DataFusionError>(()) - }) - })?; - Ok(()) + }) } /// Return a `LogicalPlan` with all placeholders (e.g $1 $2, @@ -1214,9 +1194,9 @@ impl LogicalPlan { ) -> Result>, DataFusionError> { let mut param_types: HashMap> = HashMap::new(); - self.apply(&mut |plan| { - plan.inspect_expressions(|expr| { - expr.apply(&mut |expr| { + self.visit_down(&mut |plan| { + plan.apply_expressions(&mut |expr| { + expr.visit_down(&mut |expr| { if let Expr::Placeholder(Placeholder { id, data_type }) = expr { let prev = param_types.get(id); match (prev, data_type) { @@ -1231,11 +1211,9 @@ impl LogicalPlan { _ => {} } } - Ok(VisitRecursion::Continue) - })?; - Ok::<(), DataFusionError>(()) - })?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) + }) + }) })?; Ok(param_types) @@ -1247,7 +1225,7 @@ impl LogicalPlan { expr: Expr, param_values: &ParamValues, ) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { match &expr { Expr::Placeholder(Placeholder { id, data_type }) => { let value = @@ -2762,9 +2740,9 @@ digraph { } impl TreeNodeVisitor for OkVisitor { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { let s = match plan { LogicalPlan::Projection { .. } => "pre_visit Projection", LogicalPlan::Filter { .. } => "pre_visit Filter", @@ -2775,10 +2753,10 @@ digraph { }; self.strings.push(s.into()); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn post_visit(&mut self, plan: &LogicalPlan) -> Result { let s = match plan { LogicalPlan::Projection { .. } => "post_visit Projection", LogicalPlan::Filter { .. } => "post_visit Filter", @@ -2789,7 +2767,7 @@ digraph { }; self.strings.push(s.into()); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } } @@ -2845,20 +2823,20 @@ digraph { } impl TreeNodeVisitor for StoppingVisitor { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_false_from_pre_in.dec() { - return Ok(VisitRecursion::Stop); + return Ok(TreeNodeRecursion::Stop); } self.inner.pre_visit(plan)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn post_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_false_from_post_in.dec() { - return Ok(VisitRecursion::Stop); + return Ok(TreeNodeRecursion::Stop); } self.inner.post_visit(plan) @@ -2914,9 +2892,9 @@ digraph { } impl TreeNodeVisitor for ErrorVisitor { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_error_from_pre_in.dec() { return not_impl_err!("Error in pre_visit"); } @@ -2924,7 +2902,7 @@ digraph { self.inner.pre_visit(plan) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn post_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_error_from_post_in.dec() { return not_impl_err!("Error in post_visit"); } @@ -3217,7 +3195,7 @@ digraph { // after transformation, because plan is not the same anymore, // the parent plan is built again with call to LogicalPlan::with_new_inputs -> with_new_exprs let plan = plan - .transform(&|plan| match plan { + .transform_up(&|plan| match plan { LogicalPlan::TableScan(table) => { let filter = Filter::try_new( external_filter.clone(), diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 1098842716b9e..8ec4a94204b0c 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -24,15 +24,17 @@ use crate::expr::{ }; use crate::{Expr, GetFieldAccess}; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{ + TreeNode, TreeNodeRecursion, TreeNodeRecursionResult, VisitRecursionIterator, +}; use datafusion_common::{internal_err, DataFusionError, Result}; impl TreeNode for Expr { - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - let children = match self { + match self { Expr::Alias(Alias{expr,..}) | Expr::Not(expr) | Expr::IsNotNull(expr) @@ -47,30 +49,26 @@ impl TreeNode for Expr { | Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) | Expr::Sort(Sort { expr, .. }) - | Expr::InSubquery(InSubquery{ expr, .. }) => vec![expr.as_ref().clone()], + | Expr::InSubquery(InSubquery{ expr, .. }) => f(expr), Expr::GetIndexedField(GetIndexedField { expr, field }) => { - let expr = expr.as_ref().clone(); - match field { + f(expr).and_then_on_continue(|| match field { GetFieldAccess::ListIndex {key} => { - vec![key.as_ref().clone(), expr] + f(key) }, - GetFieldAccess::ListRange {start, stop} => { - vec![start.as_ref().clone(), stop.as_ref().clone(), expr] - } - GetFieldAccess::NamedStructField {name: _name} => { - vec![expr] + GetFieldAccess::ListRange { start, stop} => { + f(start).and_then_on_continue(|| f(stop)) } - } + GetFieldAccess::NamedStructField { name: _name } => Ok(TreeNodeRecursion::Continue) + }) } Expr::GroupingSet(GroupingSet::Rollup(exprs)) - | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.clone(), - Expr::ScalarFunction (ScalarFunction{ args, .. } ) => { - args.clone() - } + | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.iter().for_each_till_continue(f), + Expr::ScalarFunction (ScalarFunction{ args, .. } ) => args.iter().for_each_till_continue(f), Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { - lists_of_exprs.clone().into_iter().flatten().collect() + lists_of_exprs.iter().flatten().for_each_till_continue(f) } - Expr::Column(_) + Expr::Nop + | Expr::Column(_) // Treat OuterReferenceColumn as a leaf expression | Expr::OuterReferenceColumn(_, _) | Expr::ScalarVariable(_, _) @@ -78,76 +76,43 @@ impl TreeNode for Expr { | Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::Wildcard {..} - | Expr::Placeholder (_) => vec![], + | Expr::Placeholder (_) => Ok(TreeNodeRecursion::Continue), Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { - vec![left.as_ref().clone(), right.as_ref().clone()] + f(left) + .and_then_on_continue(|| f(right)) } Expr::Like(Like { expr, pattern, .. }) | Expr::SimilarTo(Like { expr, pattern, .. }) => { - vec![expr.as_ref().clone(), pattern.as_ref().clone()] + f(expr) + .and_then_on_continue(|| f(pattern)) } - Expr::Between(Between { - expr, low, high, .. - }) => vec![ - expr.as_ref().clone(), - low.as_ref().clone(), - high.as_ref().clone(), - ], - Expr::Case(case) => { - let mut expr_vec = vec![]; - if let Some(expr) = case.expr.as_ref() { - expr_vec.push(expr.as_ref().clone()); - }; - for (when, then) in case.when_then_expr.iter() { - expr_vec.push(when.as_ref().clone()); - expr_vec.push(then.as_ref().clone()); - } - if let Some(else_expr) = case.else_expr.as_ref() { - expr_vec.push(else_expr.as_ref().clone()); - } - expr_vec + Expr::Between(Between { expr, low, high, .. }) => { + f(expr) + .and_then_on_continue(|| f(low)) + .and_then_on_continue(|| f(high)) + }, + Expr::Case( Case { expr, when_then_expr, else_expr }) => { + expr.as_deref().into_iter().for_each_till_continue(f) + .and_then_on_continue(|| + when_then_expr.iter().for_each_till_continue(&mut |(w, t)| f(w).and_then_on_continue(|| f(t)))) + .and_then_on_continue(|| else_expr.as_deref().into_iter().for_each_till_continue(f)) } Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) => { - let mut expr_vec = args.clone(); - - if let Some(f) = filter { - expr_vec.push(f.as_ref().clone()); - } - if let Some(o) = order_by { - expr_vec.extend(o.clone()); - } - - expr_vec + args.iter().for_each_till_continue(f) + .and_then_on_continue(|| filter.as_deref().into_iter().for_each_till_continue(f)) + .and_then_on_continue(|| order_by.iter().flatten().for_each_till_continue(f)) } - Expr::WindowFunction(WindowFunction { - args, - partition_by, - order_by, - .. - }) => { - let mut expr_vec = args.clone(); - expr_vec.extend(partition_by.clone()); - expr_vec.extend(order_by.clone()); - expr_vec + Expr::WindowFunction(WindowFunction { args, partition_by, order_by, .. }) => { + args.iter().for_each_till_continue(f) + .and_then_on_continue(|| partition_by.iter().for_each_till_continue(f)) + .and_then_on_continue(|| order_by.iter().for_each_till_continue(f)) } Expr::InList(InList { expr, list, .. }) => { - let mut expr_vec = vec![]; - expr_vec.push(expr.as_ref().clone()); - expr_vec.extend(list.clone()); - expr_vec - } - }; - - for child in children.iter() { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + f(expr) + .and_then_on_continue(|| list.iter().for_each_till_continue(f)) } } - - Ok(VisitRecursion::Continue) } fn map_children(self, transform: F) -> Result @@ -157,6 +122,7 @@ impl TreeNode for Expr { let mut transform = transform; Ok(match self { + Expr::Nop => self, Expr::Alias(Alias { expr, relation, @@ -376,6 +342,90 @@ impl TreeNode for Expr { } }) } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + match self { + Expr::Alias(Alias { expr,.. }) + | Expr::Not(expr) + | Expr::IsNotNull(expr) + | Expr::IsTrue(expr) + | Expr::IsFalse(expr) + | Expr::IsUnknown(expr) + | Expr::IsNotTrue(expr) + | Expr::IsNotFalse(expr) + | Expr::IsNotUnknown(expr) + | Expr::IsNull(expr) + | Expr::Negative(expr) + | Expr::Cast(Cast { expr, .. }) + | Expr::TryCast(TryCast { expr, .. }) + | Expr::Sort(Sort { expr, .. }) + | Expr::InSubquery(InSubquery{ expr, .. }) => f(expr), + Expr::GetIndexedField(GetIndexedField { expr, field }) => { + f(expr).and_then_on_continue(|| match field { + GetFieldAccess::ListIndex {key} => { + f(key) + }, + GetFieldAccess::ListRange { start, stop} => { + f(start).and_then_on_continue(|| f(stop)) + } + GetFieldAccess::NamedStructField { name: _name } => Ok(TreeNodeRecursion::Continue) + }) + } + Expr::GroupingSet(GroupingSet::Rollup(exprs)) + | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.iter_mut().for_each_till_continue(f), + | Expr::ScalarFunction(ScalarFunction{ args, .. }) => args.iter_mut().for_each_till_continue(f), + Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { + lists_of_exprs.iter_mut().flatten().for_each_till_continue(f) + } + Expr::Nop + | Expr::Column(_) + // Treat OuterReferenceColumn as a leaf expression + | Expr::OuterReferenceColumn(_, _) + | Expr::ScalarVariable(_, _) + | Expr::Literal(_) + | Expr::Exists { .. } + | Expr::ScalarSubquery(_) + | Expr::Wildcard {..} + | Expr::Placeholder (_) => Ok(TreeNodeRecursion::Continue), + Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { + f(left) + .and_then_on_continue(|| f(right)) + } + Expr::Like(Like { expr, pattern, .. }) + | Expr::SimilarTo(Like { expr, pattern, .. }) => { + f(expr) + .and_then_on_continue(|| f(pattern)) + } + Expr::Between(Between { expr, low, high, .. }) => { + f(expr) + .and_then_on_continue(|| f(low)) + .and_then_on_continue(|| f(high)) + }, + Expr::Case( Case { expr, when_then_expr, else_expr }) => { + expr.as_deref_mut().into_iter().for_each_till_continue(f) + .and_then_on_continue(|| + when_then_expr.iter_mut().for_each_till_continue(&mut |(w, t)| f(w).and_then_on_continue(|| f(t)))) + .and_then_on_continue(|| else_expr.as_deref_mut().into_iter().for_each_till_continue(f)) + } + Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) => { + args.iter_mut().for_each_till_continue(f) + .and_then_on_continue(|| filter.as_deref_mut().into_iter().for_each_till_continue(f)) + .and_then_on_continue(|| order_by.iter_mut().flatten().for_each_till_continue(f)) + } + Expr::WindowFunction(WindowFunction { args, partition_by, order_by, .. }) => { + args.iter_mut().for_each_till_continue(f) + .and_then_on_continue(|| partition_by.iter_mut().for_each_till_continue(f)) + .and_then_on_continue(|| order_by.iter_mut().for_each_till_continue(f)) + } + Expr::InList(InList { expr, list, .. }) => { + f(expr) + .and_then_on_continue(|| list.iter_mut().for_each_till_continue(f)) + } + } + } } fn transform_boxed(boxed_expr: Box, transform: &mut F) -> Result> diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index c7621bc178332..e85294ea5f736 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -18,92 +18,22 @@ //! Tree node implementation for logical plan use crate::LogicalPlan; -use datafusion_common::tree_node::{TreeNodeVisitor, VisitRecursion}; -use datafusion_common::{tree_node::TreeNode, Result}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, VisitRecursionIterator}; +use datafusion_common::Result; impl TreeNode for LogicalPlan { - fn apply(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - // Note, - // - // Compared to the default implementation, we need to invoke - // [`Self::apply_subqueries`] before visiting its children - match op(self)? { - VisitRecursion::Continue => {} - // If the recursion should skip, do not apply to its children. And let the recursion continue - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - // If the recursion should stop, do not apply to its children - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - }; - - self.apply_subqueries(op)?; - - self.apply_children(&mut |node| node.apply(op)) - } - - /// To use, define a struct that implements the trait [`TreeNodeVisitor`] and then invoke - /// [`LogicalPlan::visit`]. - /// - /// For example, for a logical plan like: - /// - /// ```text - /// Projection: id - /// Filter: state Eq Utf8(\"CO\")\ - /// CsvScan: employee.csv projection=Some([0, 3])"; - /// ``` - /// - /// The sequence of visit operations would be: - /// ```text - /// visitor.pre_visit(Projection) - /// visitor.pre_visit(Filter) - /// visitor.pre_visit(CsvScan) - /// visitor.post_visit(CsvScan) - /// visitor.post_visit(Filter) - /// visitor.post_visit(Projection) - /// ``` - fn visit>( - &self, - visitor: &mut V, - ) -> Result { - // Compared to the default implementation, we need to invoke - // [`Self::visit_subqueries`] before visiting its children - - match visitor.pre_visit(self)? { - VisitRecursion::Continue => {} - // If the recursion should skip, do not apply to its children. And let the recursion continue - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - // If the recursion should stop, do not apply to its children - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - }; - - self.visit_subqueries(visitor)?; - - match self.apply_children(&mut |node| node.visit(visitor))? { - VisitRecursion::Continue => {} - // If the recursion should skip, do not apply to its children. And let the recursion continue - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - // If the recursion should stop, do not apply to its children - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - - visitor.post_visit(self) + self.inputs().into_iter().for_each_till_continue(f) } - fn apply_children(&self, op: &mut F) -> Result + fn apply_inner_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - for child in self.inputs() { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + self.apply_subqueries(f) } fn map_children(self, transform: F) -> Result @@ -128,4 +58,24 @@ impl TreeNode for LogicalPlan { Ok(self) } } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + let old_children = self.inputs(); + let mut new_children = + old_children.iter().map(|&c| c.clone()).collect::>(); + let tnr = new_children.iter_mut().for_each_till_continue(f)?; + + // if any changes made, make a new child + if old_children + .iter() + .zip(new_children.iter()) + .any(|(c1, c2)| c1 != &c2) + { + *self = self.with_new_inputs(new_children.as_slice())?; + } + Ok(tnr) + } } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index abdd7f5f57f61..9d0daa5f4ca26 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -31,7 +31,7 @@ use crate::{ }; use arrow::datatypes::{DataType, TimeUnit}; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{ internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, @@ -261,15 +261,16 @@ pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result> { /// Recursively walk an expression tree, collecting the unique set of columns /// referenced in the expression pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { - inspect_expr_pre(expr, |expr| { - match expr { + expr.visit_down(&mut |e| { + match e { Expr::Column(qc) => { accum.insert(qc.clone()); } // Use explicit pattern match instead of a default // implementation, so that in the future if someone adds // new Expr types, they will check here as well - Expr::ScalarVariable(_, _) + Expr::Nop + | Expr::ScalarVariable(_, _) | Expr::Alias(_) | Expr::Literal(_) | Expr::BinaryExpr { .. } @@ -303,8 +304,9 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::Placeholder(_) | Expr::OuterReferenceColumn { .. } => {} } - Ok(()) + Ok(TreeNodeRecursion::Continue) }) + .map(|_| ()) } /// Find excluded columns in the schema, if any @@ -655,44 +657,22 @@ where F: Fn(&Expr) -> bool, { let mut exprs = vec![]; - expr.apply(&mut |expr| { + expr.visit_down(&mut |expr| { if test_fn(expr) { if !(exprs.contains(expr)) { exprs.push(expr.clone()) } // stop recursing down this expr once we find a match - return Ok(VisitRecursion::Skip); + return Ok(TreeNodeRecursion::Prune); } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) }) // pre_visit always returns OK, so this will always too .expect("no way to return error during recursion"); exprs } -/// Recursively inspect an [`Expr`] and all its children. -pub fn inspect_expr_pre(expr: &Expr, mut f: F) -> Result<(), E> -where - F: FnMut(&Expr) -> Result<(), E>, -{ - let mut err = Ok(()); - expr.apply(&mut |expr| { - if let Err(e) = f(expr) { - // save the error for later (it may not be a DataFusionError - err = Err(e); - Ok(VisitRecursion::Stop) - } else { - // keep going - Ok(VisitRecursion::Continue) - } - }) - // The closure always returns OK, so this will always too - .expect("no way to return error during recursion"); - - err -} - /// Returns a new logical plan based on the original one with inputs /// and expressions replaced. /// @@ -825,17 +805,14 @@ pub fn find_column_exprs(exprs: &[Expr]) -> Vec { .collect() } -pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec { - let mut exprs = vec![]; - inspect_expr_pre(e, |expr| { - if let Expr::Column(c) = expr { - exprs.push(c.clone()) +pub(crate) fn find_columns_referenced_by_expr(expr: &Expr) -> Vec { + expr.collect(&mut |e| { + if let Expr::Column(c) = e { + vec![c.clone()] + } else { + vec![] } - Ok(()) as Result<()> }) - // As the closure always returns Ok, this "can't" error - .expect("Unexpected error"); - exprs } /// Convert any `Expr` to an `Expr::Column`. @@ -852,26 +829,16 @@ pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result { /// Recursively walk an expression tree, collecting the column indexes /// referenced in the expression pub(crate) fn find_column_indexes_referenced_by_expr( - e: &Expr, + expr: &Expr, schema: &DFSchemaRef, ) -> Vec { - let mut indexes = vec![]; - inspect_expr_pre(e, |expr| { - match expr { - Expr::Column(qc) => { - if let Ok(idx) = schema.index_of_column(qc) { - indexes.push(idx); - } - } - Expr::Literal(_) => { - indexes.push(std::usize::MAX); - } - _ => {} + expr.collect(&mut |e| match e { + Expr::Column(qc) => schema.index_of_column(qc).into_iter().collect(), + Expr::Literal(_) => { + vec![std::usize::MAX] } - Ok(()) as Result<()> + _ => vec![], }) - .unwrap(); - indexes } /// can this data type be used in hash join equal conditions?? diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index fd84bb80160b4..17b1ad8cc73f5 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -17,14 +17,18 @@ use crate::analyzer::AnalyzerRule; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, TreeNodeTransformer, +}; use datafusion_common::Result; -use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, InSubquery}; +use datafusion_expr::expr::{ + AggregateFunction, AggregateFunctionDefinition, Exists, InSubquery, WindowFunction, +}; use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::Expr::ScalarSubquery; use datafusion_expr::{ - aggregate_function, expr, lit, window_function, Aggregate, Expr, Filter, LogicalPlan, + aggregate_function, lit, window_function, Aggregate, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Projection, Sort, Subquery, }; use std::sync::Arc; @@ -114,108 +118,69 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { struct CountWildcardRewriter {} -impl TreeNodeRewriter for CountWildcardRewriter { - type N = Expr; +impl TreeNodeTransformer for CountWildcardRewriter { + type Node = Expr; + + fn pre_transform(&mut self, _node: &mut Self::Node) -> Result { + Ok(TreeNodeRecursion::Continue) + } - fn mutate(&mut self, old_expr: Expr) -> Result { - let new_expr = match old_expr.clone() { - Expr::WindowFunction(expr::WindowFunction { + fn post_transform(&mut self, expr: &mut Expr) -> Result { + match expr { + Expr::WindowFunction(WindowFunction { fun: window_function::WindowFunction::AggregateFunction( aggregate_function::AggregateFunction::Count, ), args, - partition_by, - order_by, - window_frame, - }) if args.len() == 1 => match args[0] { - Expr::Wildcard { qualifier: None } => { - Expr::WindowFunction(expr::WindowFunction { - fun: window_function::WindowFunction::AggregateFunction( - aggregate_function::AggregateFunction::Count, - ), - args: vec![lit(COUNT_STAR_EXPANSION)], - partition_by, - order_by, - window_frame, - }) + .. + }) if args.len() == 1 => { + if let Expr::Wildcard { qualifier: None } = args[0] { + args[0] = lit(COUNT_STAR_EXPANSION) } - - _ => old_expr, - }, + } Expr::AggregateFunction(AggregateFunction { func_def: AggregateFunctionDefinition::BuiltIn( aggregate_function::AggregateFunction::Count, ), args, - distinct, - filter, - order_by, - }) if args.len() == 1 => match args[0] { - Expr::Wildcard { qualifier: None } => { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Count, - vec![lit(COUNT_STAR_EXPANSION)], - distinct, - filter, - order_by, - )) + .. + }) if args.len() == 1 => { + if let Expr::Wildcard { qualifier: None } = args[0] { + args[0] = lit(COUNT_STAR_EXPANSION) } - _ => old_expr, - }, - - ScalarSubquery(Subquery { - subquery, - outer_ref_columns, - }) => { + } + ScalarSubquery(Subquery { subquery, .. }) => { let new_plan = subquery .as_ref() .clone() .transform_down(&analyze_internal)?; - ScalarSubquery(Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns, - }) + *subquery = Arc::new(new_plan); } Expr::InSubquery(InSubquery { - expr, - subquery, - negated, + subquery: Subquery { subquery, .. }, + .. }) => { let new_plan = subquery - .subquery .as_ref() .clone() .transform_down(&analyze_internal)?; - - Expr::InSubquery(InSubquery::new( - expr, - Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, - }, - negated, - )) + *subquery = Arc::new(new_plan); } - Expr::Exists(expr::Exists { subquery, negated }) => { + Expr::Exists(Exists { + subquery: Subquery { subquery, .. }, + .. + }) => { let new_plan = subquery - .subquery .as_ref() .clone() .transform_down(&analyze_internal)?; - - Expr::Exists(expr::Exists { - subquery: Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, - }, - negated, - }) + *subquery = Arc::new(new_plan); } - _ => old_expr, + _ => {} }; - Ok(new_expr) + Ok(TreeNodeRecursion::Continue) } } diff --git a/datafusion/optimizer/src/analyzer/inline_table_scan.rs b/datafusion/optimizer/src/analyzer/inline_table_scan.rs index 90af7aec82935..a418fbf5537be 100644 --- a/datafusion/optimizer/src/analyzer/inline_table_scan.rs +++ b/datafusion/optimizer/src/analyzer/inline_table_scan.rs @@ -74,7 +74,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { Transformed::Yes(plan) } LogicalPlan::Filter(filter) => { - let new_expr = filter.predicate.transform(&rewrite_subquery)?; + let new_expr = filter.predicate.transform_up(&rewrite_subquery)?; Transformed::Yes(LogicalPlan::Filter(Filter::try_new( new_expr, filter.input, diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index 14d5ddf473786..0b2c20db39572 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -27,11 +27,10 @@ use crate::analyzer::subquery::check_subquery_expr; use crate::analyzer::type_coercion::TypeCoercion; use crate::utils::log_plan; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::expr::Exists; use datafusion_expr::expr::InSubquery; -use datafusion_expr::utils::inspect_expr_pre; use datafusion_expr::{Expr, LogicalPlan}; use log::debug; use std::sync::Arc; @@ -117,21 +116,21 @@ impl Analyzer { /// Do necessary check and fail the invalid plan fn check_plan(plan: &LogicalPlan) -> Result<()> { - plan.apply(&mut |plan: &LogicalPlan| { - for expr in plan.expressions().iter() { + plan.visit_down(&mut |plan: &LogicalPlan| { + plan.apply_expressions(&mut |e| { // recursively look for subqueries - inspect_expr_pre(expr, |expr| match expr { - Expr::Exists(Exists { subquery, .. }) - | Expr::InSubquery(InSubquery { subquery, .. }) - | Expr::ScalarSubquery(subquery) => { - check_subquery_expr(plan, &subquery.subquery, expr) + e.visit_down(&mut |e| { + match e { + Expr::Exists(Exists { subquery, .. }) + | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::ScalarSubquery(subquery) => { + check_subquery_expr(plan, &subquery.subquery, e)? + } + _ => {} } - _ => Ok(()), - })?; - } - - Ok(VisitRecursion::Continue) - })?; - - Ok(()) + Ok(TreeNodeRecursion::Continue) + }) + }) + }) + .map(|_| ()) } diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index 7c5b70b19af0a..78c630982d9b3 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -17,7 +17,7 @@ use crate::analyzer::check_plan; use crate::utils::collect_subquery_cols; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::expr_rewriter::strip_outer_reference; use datafusion_expr::utils::split_conjunction; @@ -146,7 +146,7 @@ fn check_inner_plan( LogicalPlan::Aggregate(_) => { inner_plan.apply_children(&mut |plan| { check_inner_plan(plan, is_scalar, true, can_contain_outer_ref)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -171,7 +171,7 @@ fn check_inner_plan( check_mixed_out_refer_in_window(window)?; inner_plan.apply_children(&mut |plan| { check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -188,7 +188,7 @@ fn check_inner_plan( | LogicalPlan::SubqueryAlias(_) => { inner_plan.apply_children(&mut |plan| { check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -206,7 +206,7 @@ fn check_inner_plan( is_aggregate, can_contain_outer_ref, )?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -221,7 +221,7 @@ fn check_inner_plan( JoinType::Full => { inner_plan.apply_children(&mut |plan| { check_inner_plan(plan, is_scalar, is_aggregate, false)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -281,7 +281,7 @@ fn strip_inner_query(inner_plan: &LogicalPlan) -> &LogicalPlan { fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result> { let mut exprs = vec![]; - inner_plan.apply(&mut |plan| { + inner_plan.visit_down(&mut |plan| { if let LogicalPlan::Filter(Filter { predicate, .. }) = plan { let (correlated, _): (Vec<_>, Vec<_>) = split_conjunction(predicate) .into_iter() @@ -290,9 +290,9 @@ fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result> { correlated .into_iter() .for_each(|expr| exprs.push(strip_outer_reference(expr.clone()))); - return Ok(VisitRecursion::Continue); + return Ok(TreeNodeRecursion::Continue); } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(exprs) } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 91611251d9dd9..854cd46cd89cc 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -17,12 +17,13 @@ //! Optimizer rule for type validation and coercion +use std::mem; use std::sync::Arc; use arrow::datatypes::{DataType, IntervalUnit}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; +use datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeTransformer}; use datafusion_common::{ exec_err, internal_err, plan_datafusion_err, plan_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, @@ -44,7 +45,6 @@ use datafusion_expr::type_coercion::other::{ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; use datafusion_expr::utils::merge_schema; use datafusion_expr::{ - is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, type_coercion, window_function, AggregateFunction, BuiltinScalarFunction, Expr, ExprSchemable, LogicalPlan, Operator, Projection, ScalarFunctionDefinition, Signature, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -125,40 +125,24 @@ pub(crate) struct TypeCoercionRewriter { pub(crate) schema: DFSchemaRef, } -impl TreeNodeRewriter for TypeCoercionRewriter { - type N = Expr; +impl TreeNodeTransformer for TypeCoercionRewriter { + type Node = Expr; - fn pre_visit(&mut self, _expr: &Expr) -> Result { - Ok(RewriteRecursion::Continue) + fn pre_transform(&mut self, _expr: &mut Expr) -> Result { + Ok(TreeNodeRecursion::Continue) } - fn mutate(&mut self, expr: Expr) -> Result { + fn post_transform(&mut self, expr: &mut Expr) -> Result { match expr { - Expr::ScalarSubquery(Subquery { - subquery, - outer_ref_columns, + Expr::ScalarSubquery(Subquery { subquery, .. }) + | Expr::Exists(Exists { + subquery: Subquery { subquery, .. }, + .. }) => { - let new_plan = analyze_internal(&self.schema, &subquery)?; - Ok(Expr::ScalarSubquery(Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns, - })) - } - Expr::Exists(Exists { subquery, negated }) => { - let new_plan = analyze_internal(&self.schema, &subquery.subquery)?; - Ok(Expr::Exists(Exists { - subquery: Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, - }, - negated, - })) + let new_plan = analyze_internal(&self.schema, subquery)?; + *subquery = Arc::new(new_plan); } - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => { + Expr::InSubquery(InSubquery { expr, subquery, .. }) => { let new_plan = analyze_internal(&self.schema, &subquery.subquery)?; let expr_type = expr.get_type(&self.schema)?; let subquery_type = new_plan.schema().field(0).data_type(); @@ -166,53 +150,31 @@ impl TreeNodeRewriter for TypeCoercionRewriter { "expr type {expr_type:?} can't cast to {subquery_type:?} in InSubquery" ), )?; + **expr = mem::take(expr.as_mut()).cast_to(&common_type, &self.schema)?; let new_subquery = Subquery { subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, + outer_ref_columns: mem::take(&mut subquery.outer_ref_columns), }; - Ok(Expr::InSubquery(InSubquery::new( - Box::new(expr.cast_to(&common_type, &self.schema)?), - cast_subquery(new_subquery, &common_type)?, - negated, - ))) - } - Expr::IsTrue(expr) => { - let expr = is_true(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsNotTrue(expr) => { - let expr = is_not_true(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsFalse(expr) => { - let expr = is_false(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) + *subquery = cast_subquery(new_subquery, &common_type)?; } - Expr::IsNotFalse(expr) => { - let expr = - is_not_false(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsUnknown(expr) => { - let expr = is_unknown(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsNotUnknown(expr) => { - let expr = - is_not_unknown(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) + Expr::IsTrue(expr) + | Expr::IsNotTrue(expr) + | Expr::IsFalse(expr) + | Expr::IsNotFalse(expr) + | Expr::IsUnknown(expr) + | Expr::IsNotUnknown(expr) => { + **expr = get_casted_expr_for_bool_op(expr, &self.schema)? } Expr::Like(Like { - negated, expr, pattern, - escape_char, case_insensitive, + .. }) => { let left_type = expr.get_type(&self.schema)?; let right_type = pattern.get_type(&self.schema)?; let coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| { - let op_name = if case_insensitive { + let op_name = if *case_insensitive { "ILIKE" } else { "LIKE" @@ -221,35 +183,21 @@ impl TreeNodeRewriter for TypeCoercionRewriter { "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression" ) })?; - let expr = Box::new(expr.cast_to(&coerced_type, &self.schema)?); - let pattern = Box::new(pattern.cast_to(&coerced_type, &self.schema)?); - let expr = Expr::Like(Like::new( - negated, - expr, - pattern, - escape_char, - case_insensitive, - )); - Ok(expr) + **expr = mem::take(expr.as_mut()).cast_to(&coerced_type, &self.schema)?; + **pattern = + mem::take(pattern.as_mut()).cast_to(&coerced_type, &self.schema)?; } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { let (left_type, right_type) = get_input_types( &left.get_type(&self.schema)?, - &op, + op, &right.get_type(&self.schema)?, )?; - - Ok(Expr::BinaryExpr(BinaryExpr::new( - Box::new(left.cast_to(&left_type, &self.schema)?), - op, - Box::new(right.cast_to(&right_type, &self.schema)?), - ))) + **left = mem::take(left.as_mut()).cast_to(&left_type, &self.schema)?; + **right = mem::take(right.as_mut()).cast_to(&right_type, &self.schema)?; } Expr::Between(Between { - expr, - negated, - low, - high, + expr, low, high, .. }) => { let expr_type = expr.get_type(&self.schema)?; let low_type = low.get_type(&self.schema)?; @@ -273,19 +221,13 @@ impl TreeNodeRewriter for TypeCoercionRewriter { "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression" )) })?; - let expr = Expr::Between(Between::new( - Box::new(expr.cast_to(&coercion_type, &self.schema)?), - negated, - Box::new(low.cast_to(&coercion_type, &self.schema)?), - Box::new(high.cast_to(&coercion_type, &self.schema)?), - )); - Ok(expr) + **expr = + mem::take(expr.as_mut()).cast_to(&coercion_type, &self.schema)?; + **low = mem::take(low.as_mut()).cast_to(&coercion_type, &self.schema)?; + **high = + mem::take(high.as_mut()).cast_to(&coercion_type, &self.schema)?; } - Expr::InList(InList { - expr, - list, - negated, - }) => { + Expr::InList(InList { expr, list, .. }) => { let expr_data_type = expr.get_type(&self.schema)?; let list_data_types = list .iter() @@ -296,28 +238,21 @@ impl TreeNodeRewriter for TypeCoercionRewriter { match result_type { None => plan_err!( "Can not find compatible types to compare {expr_data_type:?} with {list_data_types:?}" - ), + )?, Some(coerced_type) => { // find the coerced type - let cast_expr = expr.cast_to(&coerced_type, &self.schema)?; - let cast_list_expr = list - .into_iter() - .map(|list_expr| { - list_expr.cast_to(&coerced_type, &self.schema) - }) - .collect::>>()?; - let expr = Expr::InList(InList ::new( - Box::new(cast_expr), - cast_list_expr, - negated, - )); - Ok(expr) + **expr = mem::take(expr.as_mut()).cast_to(&coerced_type, &self.schema)?; + list.iter_mut() + .try_for_each(|list_expr| { + mem::take(list_expr).cast_to(&coerced_type, &self.schema).map(|r| *list_expr = r) + })?; } } } - Expr::Case(case) => { - let case = coerce_case_expression(case, &self.schema)?; - Ok(Expr::Case(case)) + Expr::Case(_) => { + if let Expr::Case(case) = mem::take(expr) { + *expr = Expr::Case(coerce_case_expression(case, &self.schema)?); + } } Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { ScalarFunctionDefinition::BuiltIn(fun) => { @@ -326,12 +261,9 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, &fun.signature(), )?; - let new_args = coerce_arguments_for_fun( - new_args.as_slice(), - &self.schema, - &fun, - )?; - Ok(Expr::ScalarFunction(ScalarFunction::new(fun, new_args))) + let new_args = + coerce_arguments_for_fun(new_args.as_slice(), &self.schema, fun)?; + *args = new_args } ScalarFunctionDefinition::UDF(fun) => { let new_expr = coerce_arguments_for_signature( @@ -339,30 +271,23 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, fun.signature(), )?; - Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, new_expr))) + *args = new_expr } ScalarFunctionDefinition::Name(_) => { - internal_err!("Function `Expr` with name should be resolved.") + internal_err!("Function `Expr` with name should be resolved.")? } }, Expr::AggregateFunction(expr::AggregateFunction { - func_def, - args, - distinct, - filter, - order_by, + func_def, args, .. }) => match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { let new_expr = coerce_agg_exprs_for_signature( - &fun, - &args, + fun, + args, &self.schema, &fun.signature(), )?; - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - fun, new_expr, distinct, filter, order_by, - )); - Ok(expr) + *args = new_expr } AggregateFunctionDefinition::UDF(fun) => { let new_expr = coerce_arguments_for_signature( @@ -370,48 +295,47 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, fun.signature(), )?; - let expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( - fun, new_expr, false, filter, order_by, - )); - Ok(expr) + *args = new_expr } AggregateFunctionDefinition::Name(_) => { - internal_err!("Function `Expr` with name should be resolved.") + internal_err!("Function `Expr` with name should be resolved.")? } }, - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - }) => { - let window_frame = - coerce_window_frame(window_frame, &self.schema, &order_by)?; - - let args = match &fun { - window_function::WindowFunction::AggregateFunction(fun) => { - coerce_agg_exprs_for_signature( - fun, - &args, - &self.schema, - &fun.signature(), - )? - } - _ => args, - }; - - let expr = Expr::WindowFunction(WindowFunction::new( + Expr::WindowFunction(_) => { + if let Expr::WindowFunction(WindowFunction { fun, args, partition_by, order_by, window_frame, - )); - Ok(expr) + .. + }) = mem::take(expr) + { + let window_frame = + coerce_window_frame(window_frame, &self.schema, &order_by)?; + let args = match &fun { + window_function::WindowFunction::AggregateFunction(fun) => { + coerce_agg_exprs_for_signature( + fun, + &args, + &self.schema, + &fun.signature(), + )? + } + _ => args, + }; + *expr = Expr::WindowFunction(WindowFunction::new( + fun, + args, + partition_by, + order_by, + window_frame, + )); + } } - expr => Ok(expr), + _ => {} } + Ok(TreeNodeRecursion::Continue) } } @@ -1259,7 +1183,7 @@ mod test { None, ), ))); - let expr = Expr::ScalarFunction(ScalarFunction::new( + let mut expr = Expr::ScalarFunction(ScalarFunction::new( BuiltinScalarFunction::MakeArray, vec![val.clone()], )); @@ -1274,8 +1198,8 @@ mod test { )], std::collections::HashMap::new(), )?); - let mut rewriter = TypeCoercionRewriter { schema }; - let result = expr.rewrite(&mut rewriter)?; + let mut transformer = TypeCoercionRewriter { schema }; + expr.transform(&mut transformer)?; let schema = Arc::new(DFSchema::new_with_metadata( vec![DFField::new_unqualified( @@ -1296,7 +1220,7 @@ mod test { vec![expected_casted_expr], )); - assert_eq!(result, expected); + assert_eq!(expr, expected); Ok(()) } @@ -1307,33 +1231,33 @@ mod test { vec![DFField::new_unqualified("a", DataType::Int64, true)], std::collections::HashMap::new(), )?); - let mut rewriter = TypeCoercionRewriter { schema }; - let expr = is_true(lit(12i32).gt(lit(13i64))); + let mut transformer = TypeCoercionRewriter { schema }; + let mut expr = is_true(lit(12i32).gt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?; - assert_eq!(expected, result); + expr.transform(&mut transformer)?; + assert_eq!(expected, expr); // eq let schema = Arc::new(DFSchema::new_with_metadata( vec![DFField::new_unqualified("a", DataType::Int64, true)], std::collections::HashMap::new(), )?); - let mut rewriter = TypeCoercionRewriter { schema }; - let expr = is_true(lit(12i32).eq(lit(13i64))); + let mut transformer = TypeCoercionRewriter { schema }; + let mut expr = is_true(lit(12i32).eq(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?; - assert_eq!(expected, result); + expr.transform(&mut transformer)?; + assert_eq!(expected, expr); // lt let schema = Arc::new(DFSchema::new_with_metadata( vec![DFField::new_unqualified("a", DataType::Int64, true)], std::collections::HashMap::new(), )?); - let mut rewriter = TypeCoercionRewriter { schema }; - let expr = is_true(lit(12i32).lt(lit(13i64))); + let mut transfomer = TypeCoercionRewriter { schema }; + let mut expr = is_true(lit(12i32).lt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?; - assert_eq!(expected, result); + expr.transform(&mut transfomer)?; + assert_eq!(expected, expr); Ok(()) } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 1e089257c61ad..d6cad22eb7e22 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -24,7 +24,7 @@ use crate::{utils, OptimizerConfig, OptimizerRule}; use arrow::datatypes::DataType; use datafusion_common::tree_node::{ - RewriteRecursion, TreeNode, TreeNodeRewriter, TreeNodeVisitor, VisitRecursion, + RewriteRecursion, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; use datafusion_common::{ internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, @@ -612,18 +612,18 @@ impl ExprIdentifierVisitor<'_> { } impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { - type N = Expr; + type Node = Expr; - fn pre_visit(&mut self, _expr: &Expr) -> Result { + fn pre_visit(&mut self, _expr: &Expr) -> Result { self.visit_stack .push(VisitRecord::EnterMark(self.node_count)); self.node_count += 1; // put placeholder self.id_array.push((0, "".to_string())); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn post_visit(&mut self, expr: &Expr) -> Result { + fn post_visit(&mut self, expr: &Expr) -> Result { self.series_number += 1; let (idx, sub_expr_desc) = self.pop_enter_mark(); @@ -632,7 +632,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { self.id_array[idx].0 = self.series_number; let desc = Self::desc_expr(expr); self.visit_stack.push(VisitRecord::ExprItem(desc)); - return Ok(VisitRecursion::Continue); + return Ok(TreeNodeRecursion::Continue); } let mut desc = Self::desc_expr(expr); desc.push_str(&sub_expr_desc); @@ -646,7 +646,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { .entry(desc) .or_insert_with(|| (expr.clone(), 0, data_type)) .1 += 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } } diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index b1000f042c987..a68f374d9fe68 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -370,7 +370,7 @@ fn agg_exprs_evaluation_result_on_empty_batch( expr_result_map_for_count_bug: &mut ExprResultMap, ) -> Result<()> { for e in agg_expr.iter() { - let result_expr = e.clone().transform_up(&|expr| { + let mut result_expr = e.clone().transform_up(&|expr| { let new_expr = match expr { Expr::AggregateFunction(expr::AggregateFunction { func_def, .. }) => { match func_def { @@ -396,7 +396,7 @@ fn agg_exprs_evaluation_result_on_empty_batch( Ok(new_expr) })?; - let result_expr = result_expr.unalias(); + result_expr.unalias(); let props = ExecutionProps::new(); let info = SimplifyContext::new(&props).with_schema(schema.clone()); let simplifier = ExprSimplifier::new(info); diff --git a/datafusion/optimizer/src/plan_signature.rs b/datafusion/optimizer/src/plan_signature.rs index 07f495a7262df..97a56f85ef967 100644 --- a/datafusion/optimizer/src/plan_signature.rs +++ b/datafusion/optimizer/src/plan_signature.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use std::{ collections::hash_map::DefaultHasher, hash::{Hash, Hasher}, @@ -73,9 +73,9 @@ impl LogicalPlanSignature { /// Get total number of [`LogicalPlan`]s in the plan. fn get_node_number(plan: &LogicalPlan) -> NonZeroUsize { let mut node_number = 0; - plan.apply(&mut |_plan| { + plan.visit_down(&mut |_plan| { node_number += 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) }) // Closure always return Ok .unwrap(); diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index c090fb849a823..af779b30d4634 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -17,7 +17,7 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::{ internal_err, plan_datafusion_err, Column, DFSchema, DataFusionError, Result, }; @@ -213,11 +213,11 @@ fn can_pushdown_join_predicate(predicate: &Expr, schema: &DFSchema) -> Result Result { let mut is_evaluate = true; - predicate.apply(&mut |expr| match expr { + predicate.visit_down(&mut |expr| match expr { Expr::Column(_) | Expr::Literal(_) | Expr::Placeholder(_) - | Expr::ScalarVariable(_, _) => Ok(VisitRecursion::Skip), + | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Prune), Expr::Exists { .. } | Expr::InSubquery(_) | Expr::ScalarSubquery(_) @@ -227,7 +227,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { .. }) => { is_evaluate = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } Expr::Alias(_) | Expr::BinaryExpr(_) @@ -249,8 +249,9 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::Cast(_) | Expr::TryCast(_) | Expr::ScalarFunction(..) - | Expr::InList { .. } => Ok(VisitRecursion::Continue), + | Expr::InList { .. } => Ok(TreeNodeRecursion::Continue), Expr::Sort(_) + | Expr::Nop | Expr::AggregateFunction(_) | Expr::WindowFunction(_) | Expr::Wildcard { .. } @@ -976,29 +977,29 @@ pub fn replace_cols_by_name( /// check whether the expression is volatile predicates fn is_volatile_expression(e: &Expr) -> bool { let mut is_volatile = false; - e.apply(&mut |expr| { + e.visit_down(&mut |expr| { Ok(match expr { Expr::ScalarFunction(f) => match &f.func_def { ScalarFunctionDefinition::BuiltIn(fun) if fun.volatility() == Volatility::Volatile => { is_volatile = true; - VisitRecursion::Stop + TreeNodeRecursion::Stop } ScalarFunctionDefinition::UDF(fun) if fun.signature().volatility == Volatility::Volatile => { is_volatile = true; - VisitRecursion::Stop + TreeNodeRecursion::Stop } ScalarFunctionDefinition::Name(_) => { return internal_err!( "Function `Expr` with name should be resolved." ); } - _ => VisitRecursion::Continue, + _ => TreeNodeRecursion::Continue, }, - _ => VisitRecursion::Continue, + _ => TreeNodeRecursion::Continue, }) }) .unwrap(); @@ -1008,17 +1009,17 @@ fn is_volatile_expression(e: &Expr) -> bool { /// check whether the expression uses the columns in `check_map`. fn contain(e: &Expr, check_map: &HashMap) -> bool { let mut is_contain = false; - e.apply(&mut |expr| { + e.visit_down(&mut |expr| { Ok(if let Expr::Column(c) = &expr { match check_map.get(&c.flat_name()) { Some(_) => { is_contain = true; - VisitRecursion::Stop + TreeNodeRecursion::Stop } - None => VisitRecursion::Continue, + None => TreeNodeRecursion::Continue, } } else { - VisitRecursion::Continue + TreeNodeRecursion::Continue }) }) .unwrap(); diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index e2fbd5e927a16..fe2e1345290b3 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -158,10 +158,11 @@ impl ExprSimplifier { // rather than creating an DFSchemaRef coerces rather than doing // it manually. // https://github.com/apache/arrow-datafusion/issues/3793 - pub fn coerce(&self, expr: Expr, schema: DFSchemaRef) -> Result { + pub fn coerce(&self, mut expr: Expr, schema: DFSchemaRef) -> Result { let mut expr_rewrite = TypeCoercionRewriter { schema }; - expr.rewrite(&mut expr_rewrite) + expr.transform(&mut expr_rewrite)?; + Ok(expr) } /// Input guarantees about the values of columns. @@ -330,7 +331,8 @@ impl<'a> ConstEvaluator<'a> { // at plan time match expr { // Has no runtime cost, but needed during planning - Expr::Alias(..) + Expr::Nop + | Expr::Alias(..) | Expr::AggregateFunction { .. } | Expr::ScalarVariable(_, _) | Expr::Column(_) diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 91603e82a54fc..dfe4d1fa9ab8b 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -24,17 +24,16 @@ use arrow::datatypes::{ DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; -use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; +use datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeTransformer}; use datafusion_common::{ internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, }; use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast}; use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::utils::merge_schema; -use datafusion_expr::{ - binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, -}; +use datafusion_expr::{lit, Expr, ExprSchemable, LogicalPlan, Operator}; use std::cmp::Ordering; +use std::mem; use std::sync::Arc; /// [`UnwrapCastInComparison`] attempts to remove casts from @@ -126,21 +125,19 @@ struct UnwrapCastExprRewriter { schema: DFSchemaRef, } -impl TreeNodeRewriter for UnwrapCastExprRewriter { - type N = Expr; +impl TreeNodeTransformer for UnwrapCastExprRewriter { + type Node = Expr; - fn pre_visit(&mut self, _expr: &Expr) -> Result { - Ok(RewriteRecursion::Continue) + fn pre_transform(&mut self, _expr: &mut Expr) -> Result { + Ok(TreeNodeRecursion::Continue) } - fn mutate(&mut self, expr: Expr) -> Result { - match &expr { + fn post_transform(&mut self, expr: &mut Expr) -> Result { + match expr { // For case: // try_cast/cast(expr as data_type) op literal // literal op try_cast/cast(expr as data_type) Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let left = left.as_ref().clone(); - let right = right.as_ref().clone(); let left_type = left.get_type(&self.schema)?; let right_type = right.get_type(&self.schema)?; // Because the plan has been done the type coercion, the left and right must be equal @@ -148,7 +145,7 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { && is_support_data_type(&right_type) && is_comparison_op(op) { - match (&left, &right) { + match (left.as_mut(), right.as_mut()) { ( Expr::Literal(left_lit_value), Expr::TryCast(TryCast { expr, .. }) @@ -161,11 +158,8 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { try_cast_literal_to_type(left_lit_value, &expr_type)?; if let Some(value) = casted_scalar_value { // unwrap the cast/try_cast for the right expr - return Ok(binary_expr( - lit(value), - *op, - expr.as_ref().clone(), - )); + **left = lit(value); + **right = mem::take(expr.as_mut()); } } ( @@ -180,49 +174,42 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { try_cast_literal_to_type(right_lit_value, &expr_type)?; if let Some(value) = casted_scalar_value { // unwrap the cast/try_cast for the left expr - return Ok(binary_expr( - expr.as_ref().clone(), - *op, - lit(value), - )); + **left = mem::take(expr.as_mut()); + **right = lit(value); } } (_, _) => { // do nothing } - }; + } } - // return the new binary op - Ok(binary_expr(left, *op, right)) } // For case: // try_cast/cast(expr as left_type) in (expr1,expr2,expr3) Expr::InList(InList { expr: left_expr, list, - negated, + .. }) => { - if let Some( - Expr::TryCast(TryCast { - expr: internal_left_expr, - .. - }) - | Expr::Cast(Cast { - expr: internal_left_expr, - .. - }), - ) = Some(left_expr.as_ref()) + if let Expr::TryCast(TryCast { + expr: internal_left_expr, + .. + }) + | Expr::Cast(Cast { + expr: internal_left_expr, + .. + }) = left_expr.as_ref() { let internal_left = internal_left_expr.as_ref().clone(); let internal_left_type = internal_left.get_type(&self.schema); if internal_left_type.is_err() { // error data type - return Ok(expr); + return Ok(TreeNodeRecursion::Continue); } let internal_left_type = internal_left_type?; if !is_support_data_type(&internal_left_type) { // not supported data type - return Ok(expr); + return Ok(TreeNodeRecursion::Continue); } let right_exprs = list .iter() @@ -256,19 +243,16 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { } }) .collect::>>(); - match right_exprs { - Ok(right_exprs) => { - Ok(in_list(internal_left, right_exprs, *negated)) - } - Err(_) => Ok(expr), + if let Ok(right_exprs) = right_exprs { + **left_expr = internal_left; + *list = right_exprs; } - } else { - Ok(expr) } } // TODO: handle other expr type and dfs visit them - _ => Ok(expr), + _ => {} } + Ok(TreeNodeRecursion::Continue) } } @@ -730,11 +714,12 @@ mod tests { assert_eq!(optimize_test(expr_lt, &schema), expected); } - fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr { + fn optimize_test(mut expr: Expr, schema: &DFSchemaRef) -> Expr { let mut expr_rewriter = UnwrapCastExprRewriter { schema: schema.clone(), }; - expr.rewrite(&mut expr_rewriter).unwrap() + expr.transform(&mut expr_rewriter).unwrap(); + expr } fn expr_test_schema() -> DFSchemaRef { diff --git a/datafusion/physical-expr/src/equivalence.rs b/datafusion/physical-expr/src/equivalence.rs index defd7b5786a3e..4899d69bad589 100644 --- a/datafusion/physical-expr/src/equivalence.rs +++ b/datafusion/physical-expr/src/equivalence.rs @@ -352,7 +352,7 @@ impl EquivalenceGroup { /// class it matches with (if any). pub fn normalize_expr(&self, expr: Arc) -> Arc { expr.clone() - .transform(&|expr| { + .transform_up(&|expr| { for cls in self.iter() { if cls.contains(&expr) { return Ok(Transformed::Yes(cls.canonical_expr().unwrap())); diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 52fb85657f4e4..d637cf1e54e63 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -944,7 +944,7 @@ mod tests { let expr2 = expr .clone() - .transform(&|e| { + .transform_up(&|e| { let transformed = match e.as_any().downcast_ref::() { Some(lit_value) => match lit_value.value() { diff --git a/datafusion/physical-expr/src/sort_properties.rs b/datafusion/physical-expr/src/sort_properties.rs index f513744617769..7c61e14e345a1 100644 --- a/datafusion/physical-expr/src/sort_properties.rs +++ b/datafusion/physical-expr/src/sort_properties.rs @@ -20,7 +20,8 @@ use std::{ops::Neg, sync::Arc}; use arrow_schema::SortOptions; use crate::PhysicalExpr; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; + +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, VisitRecursionIterator}; use datafusion_common::Result; /// To propagate [`SortOptions`] across the [`PhysicalExpr`], it is insufficient @@ -173,18 +174,11 @@ impl ExprOrdering { } impl TreeNode for ExprOrdering { - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - for child in &self.children { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - Ok(VisitRecursion::Continue) + self.children.iter().for_each_till_continue(f) } fn map_children(mut self, transform: F) -> Result @@ -202,4 +196,11 @@ impl TreeNode for ExprOrdering { Ok(self) } } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + self.children.iter_mut().for_each_till_continue(f) + } } diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils.rs index 71a7ff5fb7785..7a3271932780b 100644 --- a/datafusion/physical-expr/src/utils.rs +++ b/datafusion/physical-expr/src/utils.rs @@ -26,7 +26,7 @@ use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; use arrow::datatypes::SchemaRef; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeRewriter, VisitRecursion, + Transformed, TreeNode, TreeNodeRecursion, TreeNodeTransformer, VisitRecursionIterator, }; use datafusion_common::Result; use datafusion_expr::Operator; @@ -147,19 +147,11 @@ impl ExprTreeNode { } impl TreeNode for ExprTreeNode { - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - for child in self.children() { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + self.children().iter().for_each_till_continue(f) } fn map_children(mut self, transform: F) -> Result @@ -173,9 +165,16 @@ impl TreeNode for ExprTreeNode { .collect::>>()?; Ok(self) } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + self.child_nodes.iter_mut().for_each_till_continue(f) + } } -/// This struct facilitates the [TreeNodeRewriter] mechanism to convert a +/// This struct facilitates the [TreeNodeTransformer] mechanism to convert a /// [PhysicalExpr] tree into a DAEG (i.e. an expression DAG) by collecting /// identical expressions in one node. Caller specifies the node type in the /// DAEG via the `constructor` argument, which constructs nodes in the DAEG @@ -189,16 +188,21 @@ struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode) -> Result< constructor: &'a F, } -impl<'a, T, F: Fn(&ExprTreeNode) -> Result> TreeNodeRewriter +impl<'a, T, F: Fn(&ExprTreeNode) -> Result> TreeNodeTransformer for PhysicalExprDAEGBuilder<'a, T, F> { - type N = ExprTreeNode; + type Node = ExprTreeNode; + + fn pre_transform(&mut self, _node: &mut Self::Node) -> Result { + Ok(TreeNodeRecursion::Continue) + } + // This method mutates an expression node by transforming it to a physical expression // and adding it to the graph. The method returns the mutated expression node. - fn mutate( + fn post_transform( &mut self, - mut node: ExprTreeNode, - ) -> Result> { + node: &mut ExprTreeNode, + ) -> Result { // Get the expression associated with the input expression node. let expr = &node.expr; @@ -210,7 +214,7 @@ impl<'a, T, F: Fn(&ExprTreeNode) -> Result> TreeNodeRewriter // add edges to its child nodes. Add the visited expression to the vector // of visited expressions and return the newly created node index. None => { - let node_idx = self.graph.add_node((self.constructor)(&node)?); + let node_idx = self.graph.add_node((self.constructor)(node)?); for expr_node in node.child_nodes.iter() { self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0); } @@ -221,7 +225,7 @@ impl<'a, T, F: Fn(&ExprTreeNode) -> Result> TreeNodeRewriter // Set the data field of the input expression node to the corresponding node index. node.data = Some(node_idx); // Return the mutated expression node. - Ok(node) + Ok(TreeNodeRecursion::Continue) } } @@ -242,7 +246,8 @@ where constructor, }; // Use the builder to transform the expression tree node into a DAG. - let root = init.rewrite(&mut builder)?; + let mut root = init; + root.transform(&mut builder)?; // Return a tuple containing the root node index and the DAG. Ok((root.data.unwrap(), builder.graph)) } @@ -250,13 +255,13 @@ where /// Recursively extract referenced [`Column`]s within a [`PhysicalExpr`]. pub fn collect_columns(expr: &Arc) -> HashSet { let mut columns = HashSet::::new(); - expr.apply(&mut |expr| { + expr.visit_down(&mut |expr| { if let Some(column) = expr.as_any().downcast_ref::() { if !columns.iter().any(|c| c.eq(column)) { columns.insert(column.clone()); } } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) }) // pre_visit always returns OK, so this will always too .expect("no way to return error during recursion"); diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 2997d147424d8..f2d2779f99a80 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -484,6 +484,10 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { use protobuf::logical_expr_node::ExprType; let expr_node = match expr { + Expr::Nop => Err(Error::NotImplemented( + "Proto serialization error: Trying to serialize a Nop expression" + .to_string(), + ))?, Expr::Column(c) => Self { expr_type: Some(ExprType::Column(c.into())), },