Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor TreeNode recursions #7942

Closed
wants to merge 11 commits into from
6 changes: 3 additions & 3 deletions datafusion-examples/examples/rewrite_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl AnalyzerRule for MyAnalyzerRule {

impl MyAnalyzerRule {
fn analyze_plan(plan: LogicalPlan) -> Result<LogicalPlan> {
plan.transform(&|plan| {
plan.transform_up_old(&|plan| {
Ok(match plan {
LogicalPlan::Filter(filter) => {
let predicate = Self::analyze_expr(filter.predicate.clone())?;
Expand All @@ -106,7 +106,7 @@ impl MyAnalyzerRule {
}

fn analyze_expr(expr: Expr) -> Result<Expr> {
expr.transform(&|expr| {
expr.transform_up_old(&|expr| {
// closure is invoked for all sub expressions
Ok(match expr {
Expr::Literal(ScalarValue::Int64(i)) => {
Expand Down Expand Up @@ -161,7 +161,7 @@ impl OptimizerRule for MyOptimizerRule {

/// use rewrite_expr to modify the expression tree.
fn my_rewrite(expr: Expr) -> Result<Expr> {
expr.transform(&|expr| {
expr.transform_up_old(&|expr| {
// closure is invoked for all sub expressions
Ok(match expr {
Expr::Between(Between {
Expand Down
488 changes: 381 additions & 107 deletions datafusion/common/src/tree_node.rs

Large diffs are not rendered by default.

23 changes: 12 additions & 11 deletions datafusion/core/src/datasource/listing/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ use crate::{error::Result, scalar::ScalarValue};
use super::PartitionedFile;
use crate::datasource::listing::ListingTableUrl;
use crate::execution::context::SessionState;
use datafusion_common::tree_node::{TreeNode, VisitRecursion};
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::{internal_err, Column, DFField, DFSchema, DataFusionError};
use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility};
use datafusion_physical_expr::create_physical_expr;
Expand All @@ -52,17 +52,18 @@ use object_store::{ObjectMeta, ObjectStore};
/// was performed
pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
let mut is_applicable = true;
expr.apply(&mut |expr| {
expr.visit_down(&mut |expr| {
match expr {
Expr::Column(Column { ref name, .. }) => {
is_applicable &= col_names.contains(name);
if is_applicable {
Ok(VisitRecursion::Skip)
Ok(TreeNodeRecursion::Prune)
} else {
Ok(VisitRecursion::Stop)
Ok(TreeNodeRecursion::Stop)
}
}
Expr::Literal(_)
Expr::Nop
| Expr::Literal(_)
| Expr::Alias(_)
| Expr::OuterReferenceColumn(_, _)
| Expr::ScalarVariable(_, _)
Expand All @@ -88,27 +89,27 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
| Expr::ScalarSubquery(_)
| Expr::GetIndexedField { .. }
| Expr::GroupingSet(_)
| Expr::Case { .. } => Ok(VisitRecursion::Continue),
| Expr::Case { .. } => Ok(TreeNodeRecursion::Continue),

Expr::ScalarFunction(scalar_function) => {
match &scalar_function.func_def {
ScalarFunctionDefinition::BuiltIn(fun) => {
match fun.volatility() {
Volatility::Immutable => Ok(VisitRecursion::Continue),
Volatility::Immutable => Ok(TreeNodeRecursion::Continue),
// TODO: Stable functions could be `applicable`, but that would require access to the context
Volatility::Stable | Volatility::Volatile => {
is_applicable = false;
Ok(VisitRecursion::Stop)
Ok(TreeNodeRecursion::Stop)
}
}
}
ScalarFunctionDefinition::UDF(fun) => {
match fun.signature().volatility {
Volatility::Immutable => Ok(VisitRecursion::Continue),
Volatility::Immutable => Ok(TreeNodeRecursion::Continue),
// TODO: Stable functions could be `applicable`, but that would require access to the context
Volatility::Stable | Volatility::Volatile => {
is_applicable = false;
Ok(VisitRecursion::Stop)
Ok(TreeNodeRecursion::Stop)
}
}
}
Expand All @@ -128,7 +129,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
| Expr::Wildcard { .. }
| Expr::Placeholder(_) => {
is_applicable = false;
Ok(VisitRecursion::Stop)
Ok(TreeNodeRecursion::Stop)
}
}
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

use arrow::{array::ArrayRef, datatypes::Schema};
use arrow_schema::FieldRef;
use datafusion_common::tree_node::{TreeNode, VisitRecursion};
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::{Column, DataFusionError, Result, ScalarValue};
use parquet::file::metadata::ColumnChunkMetaData;
use parquet::schema::types::SchemaDescriptor;
Expand Down Expand Up @@ -259,7 +259,7 @@ impl BloomFilterPruningPredicate {

fn get_predicate_columns(expr: &Arc<dyn PhysicalExpr>) -> HashSet<String> {
let mut columns = HashSet::new();
expr.apply(&mut |expr| {
expr.visit_down(&mut |expr| {
if let Some(binary_expr) =
expr.as_any().downcast_ref::<phys_expr::BinaryExpr>()
{
Expand All @@ -269,7 +269,7 @@ impl BloomFilterPruningPredicate {
columns.insert(column.name().to_string());
}
}
Ok(VisitRecursion::Continue)
Ok(TreeNodeRecursion::Continue)
})
// no way to fail as only Ok(TreeNodeRecursion::Continue) is returned
.unwrap();
Expand Down
12 changes: 8 additions & 4 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use crate::{
use datafusion_common::{
alias::AliasGenerator,
exec_err, not_impl_err, plan_datafusion_err, plan_err,
tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion},
tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor},
};
use datafusion_execution::registry::SerializerRegistry;
use datafusion_expr::{
Expand Down Expand Up @@ -2093,9 +2093,9 @@ impl<'a> BadPlanVisitor<'a> {
}

impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> {
type N = LogicalPlan;
type Node = LogicalPlan;

fn pre_visit(&mut self, node: &Self::N) -> Result<VisitRecursion> {
fn pre_visit(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
match node {
LogicalPlan::Ddl(ddl) if !self.options.allow_ddl => {
plan_err!("DDL not supported: {}", ddl.name())
Expand All @@ -2109,9 +2109,13 @@ impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> {
LogicalPlan::Statement(stmt) if !self.options.allow_statements => {
plan_err!("Statement not supported: {}", stmt.name())
}
_ => Ok(VisitRecursion::Continue),
_ => Ok(TreeNodeRecursion::Continue),
}
}

fn post_visit(&mut self, _node: &Self::Node) -> Result<TreeNodeRecursion> {
Ok(TreeNodeRecursion::Continue)
}
}

#[cfg(test)]
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_optimizer/coalesce_batches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ impl PhysicalOptimizerRule for CoalesceBatches {
}

let target_batch_size = config.execution.batch_size;
plan.transform_up(&|plan| {
plan.transform_up_old(&|plan| {
let plan_any = plan.as_any();
// The goal here is to detect operators that could produce small batches and only
// wrap those ones with a CoalesceBatchesExec operator. An alternate approach here
Expand Down
103 changes: 48 additions & 55 deletions datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGro
use crate::physical_plan::ExecutionPlan;

use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::{AggregateExpr, PhysicalExpr};

Expand All @@ -48,27 +48,27 @@ impl CombinePartialFinalAggregate {
impl PhysicalOptimizerRule for CombinePartialFinalAggregate {
fn optimize(
&self,
plan: Arc<dyn ExecutionPlan>,
mut plan: Arc<dyn ExecutionPlan>,
_config: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
plan.transform_down(&|plan| {
let transformed =
plan.as_any()
.downcast_ref::<AggregateExec>()
.and_then(|agg_exec| {
if matches!(
agg_exec.mode(),
AggregateMode::Final | AggregateMode::FinalPartitioned
) {
agg_exec
.input()
.as_any()
.downcast_ref::<AggregateExec>()
.and_then(|input_agg_exec| {
if matches!(
input_agg_exec.mode(),
AggregateMode::Partial
) && can_combine(
plan.transform_down(&mut |plan| {
plan.clone()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is certainly nice to avoid the clone

.as_any()
.downcast_ref::<AggregateExec>()
.into_iter()
.for_each(|agg_exec| {
if matches!(
agg_exec.mode(),
AggregateMode::Final | AggregateMode::FinalPartitioned
) {
agg_exec
.input()
.as_any()
.downcast_ref::<AggregateExec>()
.into_iter()
.for_each(|input_agg_exec| {
if matches!(input_agg_exec.mode(), AggregateMode::Partial)
&& can_combine(
(
agg_exec.group_by(),
agg_exec.aggr_expr(),
Expand All @@ -79,41 +79,34 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate {
input_agg_exec.aggr_expr(),
input_agg_exec.filter_expr(),
),
) {
let mode =
if agg_exec.mode() == &AggregateMode::Final {
AggregateMode::Single
} else {
AggregateMode::SinglePartitioned
};
AggregateExec::try_new(
mode,
input_agg_exec.group_by().clone(),
input_agg_exec.aggr_expr().to_vec(),
input_agg_exec.filter_expr().to_vec(),
input_agg_exec.input().clone(),
input_agg_exec.input_schema(),
)
.map(|combined_agg| {
combined_agg.with_limit(agg_exec.limit())
})
.ok()
.map(Arc::new)
)
{
let mode = if agg_exec.mode() == &AggregateMode::Final
{
AggregateMode::Single
} else {
None
}
})
} else {
None
}
});

Ok(if let Some(transformed) = transformed {
Transformed::Yes(transformed)
} else {
Transformed::No(plan)
})
})
AggregateMode::SinglePartitioned
};
AggregateExec::try_new(
mode,
input_agg_exec.group_by().clone(),
input_agg_exec.aggr_expr().to_vec(),
input_agg_exec.filter_expr().to_vec(),
input_agg_exec.input().clone(),
input_agg_exec.input_schema(),
)
.map(|combined_agg| {
combined_agg.with_limit(agg_exec.limit())
})
.into_iter()
.for_each(|p| *plan = Arc::new(p))
}
})
}
});
Ok(TreeNodeRecursion::Continue)
})?;
Ok(plan)
}

fn name(&self) -> &str {
Expand Down Expand Up @@ -178,7 +171,7 @@ fn normalize_group_exprs(group_exprs: GroupExprsRef) -> GroupExprs {
fn discard_column_index(group_expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
group_expr
.clone()
.transform(&|expr| {
.transform_up_old(&|expr| {
let normalized_form: Option<Arc<dyn PhysicalExpr>> =
match expr.as_any().downcast_ref::<Column>() {
Some(column) => Some(Arc::new(Column::new(column.name(), 0))),
Expand Down
Loading