Skip to content

Commit

Permalink
POC
Browse files Browse the repository at this point in the history
  • Loading branch information
peter-toth committed Dec 15, 2023
1 parent b276d47 commit 36e36fd
Show file tree
Hide file tree
Showing 40 changed files with 1,109 additions and 993 deletions.
6 changes: 3 additions & 3 deletions datafusion-examples/examples/rewrite_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl AnalyzerRule for MyAnalyzerRule {

impl MyAnalyzerRule {
fn analyze_plan(plan: LogicalPlan) -> Result<LogicalPlan> {
plan.transform(&|plan| {
plan.transform_up(&|plan| {
Ok(match plan {
LogicalPlan::Filter(filter) => {
let predicate = Self::analyze_expr(filter.predicate.clone())?;
Expand All @@ -106,7 +106,7 @@ impl MyAnalyzerRule {
}

fn analyze_expr(expr: Expr) -> Result<Expr> {
expr.transform(&|expr| {
expr.transform_up(&|expr| {
// closure is invoked for all sub expressions
Ok(match expr {
Expr::Literal(ScalarValue::Int64(i)) => {
Expand Down Expand Up @@ -161,7 +161,7 @@ impl OptimizerRule for MyOptimizerRule {

/// use rewrite_expr to modify the expression tree.
fn my_rewrite(expr: Expr) -> Result<Expr> {
expr.transform(&|expr| {
expr.transform_up(&|expr| {
// closure is invoked for all sub expressions
Ok(match expr {
Expr::Between(Between {
Expand Down
403 changes: 296 additions & 107 deletions datafusion/common/src/tree_node.rs

Large diffs are not rendered by default.

23 changes: 12 additions & 11 deletions datafusion/core/src/datasource/listing/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ use crate::{error::Result, scalar::ScalarValue};
use super::PartitionedFile;
use crate::datasource::listing::ListingTableUrl;
use crate::execution::context::SessionState;
use datafusion_common::tree_node::{TreeNode, VisitRecursion};
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::{internal_err, Column, DFField, DFSchema, DataFusionError};
use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility};
use datafusion_physical_expr::create_physical_expr;
Expand All @@ -52,17 +52,18 @@ use object_store::{ObjectMeta, ObjectStore};
/// was performed
pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
let mut is_applicable = true;
expr.apply(&mut |expr| {
expr.visit_down(&mut |expr| {
match expr {
Expr::Column(Column { ref name, .. }) => {
is_applicable &= col_names.contains(name);
if is_applicable {
Ok(VisitRecursion::Skip)
Ok(TreeNodeRecursion::Prune)
} else {
Ok(VisitRecursion::Stop)
Ok(TreeNodeRecursion::Stop)
}
}
Expr::Literal(_)
Expr::Nop
| Expr::Literal(_)
| Expr::Alias(_)
| Expr::OuterReferenceColumn(_, _)
| Expr::ScalarVariable(_, _)
Expand All @@ -88,27 +89,27 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
| Expr::ScalarSubquery(_)
| Expr::GetIndexedField { .. }
| Expr::GroupingSet(_)
| Expr::Case { .. } => Ok(VisitRecursion::Continue),
| Expr::Case { .. } => Ok(TreeNodeRecursion::Continue),

Expr::ScalarFunction(scalar_function) => {
match &scalar_function.func_def {
ScalarFunctionDefinition::BuiltIn(fun) => {
match fun.volatility() {
Volatility::Immutable => Ok(VisitRecursion::Continue),
Volatility::Immutable => Ok(TreeNodeRecursion::Continue),
// TODO: Stable functions could be `applicable`, but that would require access to the context
Volatility::Stable | Volatility::Volatile => {
is_applicable = false;
Ok(VisitRecursion::Stop)
Ok(TreeNodeRecursion::Stop)
}
}
}
ScalarFunctionDefinition::UDF(fun) => {
match fun.signature().volatility {
Volatility::Immutable => Ok(VisitRecursion::Continue),
Volatility::Immutable => Ok(TreeNodeRecursion::Continue),
// TODO: Stable functions could be `applicable`, but that would require access to the context
Volatility::Stable | Volatility::Volatile => {
is_applicable = false;
Ok(VisitRecursion::Stop)
Ok(TreeNodeRecursion::Stop)
}
}
}
Expand All @@ -128,7 +129,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
| Expr::Wildcard { .. }
| Expr::Placeholder(_) => {
is_applicable = false;
Ok(VisitRecursion::Stop)
Ok(TreeNodeRecursion::Stop)
}
}
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

use arrow::{array::ArrayRef, datatypes::Schema};
use arrow_schema::FieldRef;
use datafusion_common::tree_node::{TreeNode, VisitRecursion};
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::{Column, DataFusionError, Result, ScalarValue};
use parquet::file::metadata::ColumnChunkMetaData;
use parquet::schema::types::SchemaDescriptor;
Expand Down Expand Up @@ -259,7 +259,7 @@ impl BloomFilterPruningPredicate {

fn get_predicate_columns(expr: &Arc<dyn PhysicalExpr>) -> HashSet<String> {
let mut columns = HashSet::new();
expr.apply(&mut |expr| {
expr.visit_down(&mut |expr| {
if let Some(binary_expr) =
expr.as_any().downcast_ref::<phys_expr::BinaryExpr>()
{
Expand All @@ -269,7 +269,7 @@ impl BloomFilterPruningPredicate {
columns.insert(column.name().to_string());
}
}
Ok(VisitRecursion::Continue)
Ok(TreeNodeRecursion::Continue)
})
// no way to fail as only Ok(TreeNodeRecursion::Continue) is returned
.unwrap();
Expand Down
12 changes: 8 additions & 4 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use crate::{
use datafusion_common::{
alias::AliasGenerator,
exec_err, not_impl_err, plan_datafusion_err, plan_err,
tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion},
tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor},
};
use datafusion_execution::registry::SerializerRegistry;
use datafusion_expr::{
Expand Down Expand Up @@ -2098,9 +2098,9 @@ impl<'a> BadPlanVisitor<'a> {
}

impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> {
type N = LogicalPlan;
type Node = LogicalPlan;

fn pre_visit(&mut self, node: &Self::N) -> Result<VisitRecursion> {
fn pre_visit(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
match node {
LogicalPlan::Ddl(ddl) if !self.options.allow_ddl => {
plan_err!("DDL not supported: {}", ddl.name())
Expand All @@ -2114,9 +2114,13 @@ impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> {
LogicalPlan::Statement(stmt) if !self.options.allow_statements => {
plan_err!("Statement not supported: {}", stmt.name())
}
_ => Ok(VisitRecursion::Continue),
_ => Ok(TreeNodeRecursion::Continue),
}
}

/// Post-order hook of the visitor: performs no check on the node and always
/// answers `TreeNodeRecursion::Continue`.
fn post_visit(&mut self, _plan: &Self::Node) -> Result<TreeNodeRecursion> {
    let recursion = TreeNodeRecursion::Continue;
    Ok(recursion)
}
}

#[cfg(test)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ fn normalize_group_exprs(group_exprs: GroupExprsRef) -> GroupExprs {
fn discard_column_index(group_expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
group_expr
.clone()
.transform(&|expr| {
.transform_up(&|expr| {
let normalized_form: Option<Arc<dyn PhysicalExpr>> =
match expr.as_any().downcast_ref::<Column>() {
Some(column) => Some(Arc::new(Column::new(column.name(), 0))),
Expand Down
65 changes: 42 additions & 23 deletions datafusion/core/src/physical_optimizer/enforce_distribution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ use crate::physical_plan::{
};

use arrow::compute::SortOptions;
use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion};
use datafusion_common::tree_node::{
Transformed, TreeNode, TreeNodeRecursion, VisitRecursionIterator,
};
use datafusion_expr::logical_plan::JoinType;
use datafusion_physical_expr::expressions::{Column, NoOp};
use datafusion_physical_expr::utils::map_columns_before_projection;
Expand Down Expand Up @@ -1476,18 +1478,11 @@ impl DistributionContext {
}

impl TreeNode for DistributionContext {
fn apply_children<F>(&self, op: &mut F) -> Result<VisitRecursion>
fn apply_children<F>(&self, f: &mut F) -> Result<TreeNodeRecursion>
where
F: FnMut(&Self) -> Result<VisitRecursion>,
F: FnMut(&Self) -> Result<TreeNodeRecursion>,
{
for child in self.children() {
match op(&child)? {
VisitRecursion::Continue => {}
VisitRecursion::Skip => return Ok(VisitRecursion::Continue),
VisitRecursion::Stop => return Ok(VisitRecursion::Stop),
}
}
Ok(VisitRecursion::Continue)
self.children().iter().for_each_till_continue(f)
}

fn map_children<F>(self, transform: F) -> Result<Self>
Expand All @@ -1505,6 +1500,23 @@ impl TreeNode for DistributionContext {
DistributionContext::new_from_children_nodes(children_nodes, self.plan)
}
}

/// Runs `f` over mutable references to this node's children, then rebuilds
/// `self` from the (possibly modified) children and a clone of the current plan.
///
/// A node without children returns `TreeNodeRecursion::Continue` without
/// invoking `f`. Errors from `f` or from `new_from_children_nodes` are
/// propagated. Otherwise the value produced by `for_each_till_continue` is
/// returned.
// NOTE(review): the exact short-circuit rules live in `for_each_till_continue`,
// which is not visible here — confirm against its definition.
fn transform_children<F>(&mut self, f: &mut F) -> Result<TreeNodeRecursion>
where
    F: FnMut(&mut Self) -> Result<TreeNodeRecursion>,
{
    let mut child_nodes = self.children();
    // Leaf node: nothing to visit and nothing to rebuild.
    if child_nodes.is_empty() {
        return Ok(TreeNodeRecursion::Continue);
    }

    let recursion = child_nodes.iter_mut().for_each_till_continue(f)?;
    let plan = self.plan.clone();
    *self = DistributionContext::new_from_children_nodes(child_nodes, plan)?;
    Ok(recursion)
}
}

/// implement Display method for `DistributionContext` struct.
Expand Down Expand Up @@ -1566,20 +1578,11 @@ impl PlanWithKeyRequirements {
}

impl TreeNode for PlanWithKeyRequirements {
fn apply_children<F>(&self, op: &mut F) -> Result<VisitRecursion>
fn apply_children<F>(&self, f: &mut F) -> Result<TreeNodeRecursion>
where
F: FnMut(&Self) -> Result<VisitRecursion>,
F: FnMut(&Self) -> Result<TreeNodeRecursion>,
{
let children = self.children();
for child in children {
match op(&child)? {
VisitRecursion::Continue => {}
VisitRecursion::Skip => return Ok(VisitRecursion::Continue),
VisitRecursion::Stop => return Ok(VisitRecursion::Stop),
}
}

Ok(VisitRecursion::Continue)
self.children().iter().for_each_till_continue(f)
}

fn map_children<F>(self, transform: F) -> Result<Self>
Expand All @@ -1605,6 +1608,22 @@ impl TreeNode for PlanWithKeyRequirements {
Ok(self)
}
}

/// Runs `f` over mutable references to this node's children, then replaces
/// `self.plan` with the result of `with_new_children_if_necessary` applied to
/// a clone of the current plan and the children's plans.
///
/// A node without children returns `TreeNodeRecursion::Continue` without
/// invoking `f`. Errors from `f` or from `with_new_children_if_necessary` are
/// propagated. Otherwise the value produced by `for_each_till_continue` is
/// returned.
// NOTE(review): the exact short-circuit rules live in `for_each_till_continue`,
// which is not visible here — confirm against its definition.
fn transform_children<F>(&mut self, f: &mut F) -> Result<TreeNodeRecursion>
where
    F: FnMut(&mut Self) -> Result<TreeNodeRecursion>,
{
    let mut child_nodes = self.children();
    // Leaf node: nothing to visit and the plan stays untouched.
    if child_nodes.is_empty() {
        return Ok(TreeNodeRecursion::Continue);
    }

    let recursion = child_nodes.iter_mut().for_each_till_continue(f)?;
    let child_plans = child_nodes.into_iter().map(|node| node.plan).collect();
    self.plan = with_new_children_if_necessary(self.plan.clone(), child_plans)?.into();
    Ok(recursion)
}
}

/// Since almost all of these tests explicitly use `ParquetExec` they only run with the parquet feature flag on
Expand Down
68 changes: 43 additions & 25 deletions datafusion/core/src/physical_optimizer/enforce_sorting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ use crate::physical_plan::{
with_new_children_if_necessary, Distribution, ExecutionPlan, InputOrderMode,
};

use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion};
use datafusion_common::tree_node::{
Transformed, TreeNode, TreeNodeRecursion, VisitRecursionIterator,
};
use datafusion_common::{plan_err, DataFusionError};
use datafusion_physical_expr::{PhysicalSortExpr, PhysicalSortRequirement};

Expand Down Expand Up @@ -157,20 +159,11 @@ impl PlanWithCorrespondingSort {
}

impl TreeNode for PlanWithCorrespondingSort {
fn apply_children<F>(&self, op: &mut F) -> Result<VisitRecursion>
fn apply_children<F>(&self, f: &mut F) -> Result<TreeNodeRecursion>
where
F: FnMut(&Self) -> Result<VisitRecursion>,
F: FnMut(&Self) -> Result<TreeNodeRecursion>,
{
let children = self.children();
for child in children {
match op(&child)? {
VisitRecursion::Continue => {}
VisitRecursion::Skip => return Ok(VisitRecursion::Continue),
VisitRecursion::Stop => return Ok(VisitRecursion::Stop),
}
}

Ok(VisitRecursion::Continue)
self.children().iter().for_each_till_continue(f)
}

fn map_children<F>(self, transform: F) -> Result<Self>
Expand All @@ -188,6 +181,23 @@ impl TreeNode for PlanWithCorrespondingSort {
PlanWithCorrespondingSort::new_from_children_nodes(children_nodes, self.plan)
}
}

/// Runs `f` over mutable references to this node's children, then rebuilds
/// `self` from the (possibly modified) children and a clone of the current plan.
///
/// A node without children returns `TreeNodeRecursion::Continue` without
/// invoking `f`. Errors from `f` or from `new_from_children_nodes` are
/// propagated. Otherwise the value produced by `for_each_till_continue` is
/// returned.
// NOTE(review): the exact short-circuit rules live in `for_each_till_continue`,
// which is not visible here — confirm against its definition.
fn transform_children<F>(&mut self, f: &mut F) -> Result<TreeNodeRecursion>
where
    F: FnMut(&mut Self) -> Result<TreeNodeRecursion>,
{
    let mut child_nodes = self.children();
    // Leaf node: nothing to visit and nothing to rebuild.
    if child_nodes.is_empty() {
        return Ok(TreeNodeRecursion::Continue);
    }

    let recursion = child_nodes.iter_mut().for_each_till_continue(f)?;
    let plan = self.plan.clone();
    *self = PlanWithCorrespondingSort::new_from_children_nodes(child_nodes, plan)?;
    Ok(recursion)
}
}

/// This object is used within the [EnforceSorting] rule to track the closest
Expand Down Expand Up @@ -273,20 +283,11 @@ impl PlanWithCorrespondingCoalescePartitions {
}

impl TreeNode for PlanWithCorrespondingCoalescePartitions {
fn apply_children<F>(&self, op: &mut F) -> Result<VisitRecursion>
fn apply_children<F>(&self, f: &mut F) -> Result<TreeNodeRecursion>
where
F: FnMut(&Self) -> Result<VisitRecursion>,
F: FnMut(&Self) -> Result<TreeNodeRecursion>,
{
let children = self.children();
for child in children {
match op(&child)? {
VisitRecursion::Continue => {}
VisitRecursion::Skip => return Ok(VisitRecursion::Continue),
VisitRecursion::Stop => return Ok(VisitRecursion::Stop),
}
}

Ok(VisitRecursion::Continue)
self.children().iter().for_each_till_continue(f)
}

fn map_children<F>(self, transform: F) -> Result<Self>
Expand All @@ -307,6 +308,23 @@ impl TreeNode for PlanWithCorrespondingCoalescePartitions {
)
}
}

/// Runs `f` over mutable references to this node's children, then rebuilds
/// `self` from the (possibly modified) children and a clone of the current plan.
///
/// A node without children returns `TreeNodeRecursion::Continue` without
/// invoking `f`. Errors from `f` or from `new_from_children_nodes` are
/// propagated. Otherwise the value produced by `for_each_till_continue` is
/// returned.
// NOTE(review): the exact short-circuit rules live in `for_each_till_continue`,
// which is not visible here — confirm against its definition.
fn transform_children<F>(&mut self, f: &mut F) -> Result<TreeNodeRecursion>
where
    F: FnMut(&mut Self) -> Result<TreeNodeRecursion>,
{
    let mut child_nodes = self.children();
    // Leaf node: nothing to visit and nothing to rebuild.
    if child_nodes.is_empty() {
        return Ok(TreeNodeRecursion::Continue);
    }

    let recursion = child_nodes.iter_mut().for_each_till_continue(f)?;
    let plan = self.plan.clone();
    *self = PlanWithCorrespondingCoalescePartitions::new_from_children_nodes(
        child_nodes,
        plan,
    )?;
    Ok(recursion)
}
}

/// The boolean flag `repartition_sorts` defined in the config indicates
Expand Down
Loading

0 comments on commit 36e36fd

Please sign in to comment.