Skip to content

Commit

Permalink
split up expr for rewriting, visiting, and simplification
Browse files Browse the repository at this point in the history
  • Loading branch information
jimexist committed Feb 7, 2022
1 parent a39a223 commit 0d90e5b
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 129 deletions.
2 changes: 1 addition & 1 deletion datafusion/src/datasource/listing/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ use log::debug;
use crate::{
error::Result,
execution::context::ExecutionContext,
logical_plan::{self, Expr, ExpressionVisitor, Recursion},
logical_plan::{self, Expr, ExprVisitable, ExpressionVisitor, Recursion},
physical_plan::functions::Volatility,
scalar::ScalarValue,
};
Expand Down
262 changes: 140 additions & 122 deletions datafusion/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -557,128 +557,13 @@ impl Expr {
nulls_first,
}
}
}

/// Performs a depth first walk of an expression and
/// its children, calling [`ExpressionVisitor::pre_visit`] and
/// `visitor.post_visit`.
///
/// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to
/// separate expression algorithms from the structure of the
/// `Expr` tree and make it easier to add new types of expressions
/// and algorithms that walk the tree.
///
/// For an expression tree such as
/// ```text
/// BinaryExpr (GT)
/// left: Column("foo")
/// right: Column("bar")
/// ```
///
/// The nodes are visited using the following order
/// ```text
/// pre_visit(BinaryExpr(GT))
/// pre_visit(Column("foo"))
/// pre_visit(Column("bar"))
/// post_visit(Column("bar"))
/// post_visit(Column("bar"))
/// post_visit(BinaryExpr(GT))
/// ```
///
/// If an Err result is returned, recursion is stopped immediately
///
/// If `Recursion::Stop` is returned on a call to pre_visit, no
/// children of that expression are visited, nor is post_visit
/// called on that expression
///
pub fn accept<V: ExpressionVisitor>(&self, visitor: V) -> Result<V> {
let visitor = match visitor.pre_visit(self)? {
Recursion::Continue(visitor) => visitor,
// If the recursion should stop, do not visit children
Recursion::Stop(visitor) => return Ok(visitor),
};

// recurse (and cover all expression types)
let visitor = match self {
Expr::Alias(expr, _)
| Expr::Not(expr)
| Expr::IsNotNull(expr)
| Expr::IsNull(expr)
| Expr::Negative(expr)
| Expr::Cast { expr, .. }
| Expr::TryCast { expr, .. }
| Expr::Sort { expr, .. }
| Expr::GetIndexedField { expr, .. } => expr.accept(visitor),
Expr::Column(_)
| Expr::ScalarVariable(_)
| Expr::Literal(_)
| Expr::Wildcard => Ok(visitor),
Expr::BinaryExpr { left, right, .. } => {
let visitor = left.accept(visitor)?;
right.accept(visitor)
}
Expr::Between {
expr, low, high, ..
} => {
let visitor = expr.accept(visitor)?;
let visitor = low.accept(visitor)?;
high.accept(visitor)
}
Expr::Case {
expr,
when_then_expr,
else_expr,
} => {
let visitor = if let Some(expr) = expr.as_ref() {
expr.accept(visitor)
} else {
Ok(visitor)
}?;
let visitor = when_then_expr.iter().try_fold(
visitor,
|visitor, (when, then)| {
let visitor = when.accept(visitor)?;
then.accept(visitor)
},
)?;
if let Some(else_expr) = else_expr.as_ref() {
else_expr.accept(visitor)
} else {
Ok(visitor)
}
}
Expr::ScalarFunction { args, .. }
| Expr::ScalarUDF { args, .. }
| Expr::AggregateFunction { args, .. }
| Expr::AggregateUDF { args, .. } => args
.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor)),
Expr::WindowFunction {
args,
partition_by,
order_by,
..
} => {
let visitor = args
.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor))?;
let visitor = partition_by
.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor))?;
let visitor = order_by
.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor))?;
Ok(visitor)
}
Expr::InList { expr, list, .. } => {
let visitor = expr.accept(visitor)?;
list.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor))
}
}?;

visitor.post_visit(self)
}
pub(crate) trait ExprRewritable: Sized {
fn rewrite<R: ExprRewriter>(self, rewriter: &mut R) -> Result<Self>;
}

impl ExprRewritable for Expr {
/// Performs a depth first walk of an expression and its children
/// to rewrite an expression, consuming `self` producing a new
/// [`Expr`].
Expand Down Expand Up @@ -712,7 +597,7 @@ impl Expr {
/// children of that expression are visited, nor is mutate
/// called on that expression
///
pub fn rewrite<R>(self, rewriter: &mut R) -> Result<Self>
fn rewrite<R>(self, rewriter: &mut R) -> Result<Self>
where
R: ExprRewriter,
{
Expand Down Expand Up @@ -847,7 +732,140 @@ impl Expr {
Ok(expr)
}
}
}

pub(crate) trait ExprVisitable {
fn accept<V: ExpressionVisitor>(&self, visitor: V) -> Result<V>;
}

impl ExprVisitable for Expr {
/// Performs a depth first walk of an expression and
/// its children, calling [`ExpressionVisitor::pre_visit`] and
/// `visitor.post_visit`.
///
/// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to
/// separate expression algorithms from the structure of the
/// `Expr` tree and make it easier to add new types of expressions
/// and algorithms that walk the tree.
///
/// For an expression tree such as
/// ```text
/// BinaryExpr (GT)
/// left: Column("foo")
/// right: Column("bar")
/// ```
///
/// The nodes are visited using the following order
/// ```text
/// pre_visit(BinaryExpr(GT))
/// pre_visit(Column("foo"))
/// pre_visit(Column("bar"))
/// post_visit(Column("bar"))
/// post_visit(Column("bar"))
/// post_visit(BinaryExpr(GT))
/// ```
///
/// If an Err result is returned, recursion is stopped immediately
///
/// If `Recursion::Stop` is returned on a call to pre_visit, no
/// children of that expression are visited, nor is post_visit
/// called on that expression
///
fn accept<V: ExpressionVisitor>(&self, visitor: V) -> Result<V> {
let visitor = match visitor.pre_visit(self)? {
Recursion::Continue(visitor) => visitor,
// If the recursion should stop, do not visit children
Recursion::Stop(visitor) => return Ok(visitor),
};

// recurse (and cover all expression types)
let visitor = match self {
Expr::Alias(expr, _)
| Expr::Not(expr)
| Expr::IsNotNull(expr)
| Expr::IsNull(expr)
| Expr::Negative(expr)
| Expr::Cast { expr, .. }
| Expr::TryCast { expr, .. }
| Expr::Sort { expr, .. }
| Expr::GetIndexedField { expr, .. } => expr.accept(visitor),
Expr::Column(_)
| Expr::ScalarVariable(_)
| Expr::Literal(_)
| Expr::Wildcard => Ok(visitor),
Expr::BinaryExpr { left, right, .. } => {
let visitor = left.accept(visitor)?;
right.accept(visitor)
}
Expr::Between {
expr, low, high, ..
} => {
let visitor = expr.accept(visitor)?;
let visitor = low.accept(visitor)?;
high.accept(visitor)
}
Expr::Case {
expr,
when_then_expr,
else_expr,
} => {
let visitor = if let Some(expr) = expr.as_ref() {
expr.accept(visitor)
} else {
Ok(visitor)
}?;
let visitor = when_then_expr.iter().try_fold(
visitor,
|visitor, (when, then)| {
let visitor = when.accept(visitor)?;
then.accept(visitor)
},
)?;
if let Some(else_expr) = else_expr.as_ref() {
else_expr.accept(visitor)
} else {
Ok(visitor)
}
}
Expr::ScalarFunction { args, .. }
| Expr::ScalarUDF { args, .. }
| Expr::AggregateFunction { args, .. }
| Expr::AggregateUDF { args, .. } => args
.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor)),
Expr::WindowFunction {
args,
partition_by,
order_by,
..
} => {
let visitor = args
.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor))?;
let visitor = partition_by
.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor))?;
let visitor = order_by
.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor))?;
Ok(visitor)
}
Expr::InList { expr, list, .. } => {
let visitor = expr.accept(visitor)?;
list.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor))
}
}?;

visitor.post_visit(self)
}
}

pub(crate) trait ExprSimplifiable: Sized {
fn simplify<S: SimplifyInfo>(self, info: &S) -> Result<Self>;
}

impl ExprSimplifiable for Expr {
/// Simplifies this [`Expr`]`s as much as possible, evaluating
/// constants and applying algebraic simplifications
///
Expand Down Expand Up @@ -889,7 +907,7 @@ impl Expr {
/// let expr = expr.simplify(&Info::default()).unwrap();
/// assert_eq!(expr, b_lt_2);
/// ```
pub fn simplify<S: SimplifyInfo>(self, info: &S) -> Result<Self> {
fn simplify<S: SimplifyInfo>(self, info: &S) -> Result<Self> {
let mut rewriter = Simplifier::new(info);
let mut const_evaluator = ConstEvaluator::new(info.execution_props());

Expand Down
1 change: 1 addition & 0 deletions datafusion/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ pub use expr::{
Column, Expr, ExprRewriter, ExprSchema, ExpressionVisitor, Literal, Recursion,
RewriteRecursion, SimplifyInfo,
};
pub(crate) use expr::{ExprRewritable, ExprSimplifiable, ExprVisitable};
pub use extension::UserDefinedLogicalNode;
pub use operators::Operator;
pub use plan::{
Expand Down
4 changes: 2 additions & 2 deletions datafusion/src/optimizer/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ use crate::logical_plan::plan::{Filter, Projection, Window};
use crate::logical_plan::{
col,
plan::{Aggregate, Sort},
DFField, DFSchema, Expr, ExprRewriter, ExpressionVisitor, LogicalPlan, Recursion,
RewriteRecursion,
DFField, DFSchema, Expr, ExprRewritable, ExprRewriter, ExprVisitable,
ExpressionVisitor, LogicalPlan, Recursion, RewriteRecursion,
};
use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::utils;
Expand Down
4 changes: 2 additions & 2 deletions datafusion/src/optimizer/simplify_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ use arrow::record_batch::RecordBatch;
use crate::error::DataFusionError;
use crate::execution::context::ExecutionProps;
use crate::logical_plan::{
lit, DFSchema, DFSchemaRef, Expr, ExprRewriter, LogicalPlan, RewriteRecursion,
SimplifyInfo,
lit, DFSchema, DFSchemaRef, Expr, ExprRewritable, ExprRewriter, ExprSimplifiable,
LogicalPlan, RewriteRecursion, SimplifyInfo,
};
use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::utils;
Expand Down
6 changes: 4 additions & 2 deletions datafusion/src/optimizer/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ use crate::execution::context::ExecutionProps;
use crate::logical_plan::plan::{
Aggregate, Analyze, Extension, Filter, Join, Projection, Sort, Window,
};

use crate::logical_plan::{
build_join_schema, Column, CreateMemoryTable, DFSchemaRef, Expr, Limit, LogicalPlan,
LogicalPlanBuilder, Operator, Partitioning, Recursion, Repartition, Union, Values,
build_join_schema, Column, CreateMemoryTable, DFSchemaRef, Expr, ExprVisitable,
Limit, LogicalPlan, LogicalPlanBuilder, Operator, Partitioning, Recursion,
Repartition, Union, Values,
};
use crate::prelude::lit;
use crate::scalar::ScalarValue;
Expand Down
1 change: 1 addition & 0 deletions datafusion/src/sql/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
use arrow::datatypes::DataType;
use sqlparser::ast::Ident;

use crate::logical_plan::ExprVisitable;
use crate::logical_plan::{Expr, LogicalPlan};
use crate::scalar::{ScalarValue, MAX_PRECISION_FOR_DECIMAL128};
use crate::{
Expand Down

0 comments on commit 0d90e5b

Please sign in to comment.