From 1fc52790026aacb168ee39f031f4e516b7b98eec Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Mon, 7 Feb 2022 23:36:58 +0800 Subject: [PATCH] split up expr for rewriting, visiting, and simplification --- datafusion/src/datasource/listing/helpers.rs | 2 +- datafusion/src/logical_plan/expr.rs | 764 +----------------- datafusion/src/logical_plan/expr_rewriter.rs | 589 ++++++++++++++ datafusion/src/logical_plan/expr_simplier.rs | 95 +++ datafusion/src/logical_plan/expr_visitor.rs | 174 ++++ datafusion/src/logical_plan/mod.rs | 21 +- .../src/optimizer/common_subexpr_eliminate.rs | 4 +- .../src/optimizer/simplify_expressions.rs | 5 +- datafusion/src/optimizer/utils.rs | 6 +- datafusion/src/sql/utils.rs | 1 + datafusion/tests/simplification.rs | 1 + 11 files changed, 886 insertions(+), 776 deletions(-) create mode 100644 datafusion/src/logical_plan/expr_rewriter.rs create mode 100644 datafusion/src/logical_plan/expr_simplier.rs create mode 100644 datafusion/src/logical_plan/expr_visitor.rs diff --git a/datafusion/src/datasource/listing/helpers.rs b/datafusion/src/datasource/listing/helpers.rs index 912179c36f06..8ff821082906 100644 --- a/datafusion/src/datasource/listing/helpers.rs +++ b/datafusion/src/datasource/listing/helpers.rs @@ -37,7 +37,7 @@ use log::debug; use crate::{ error::Result, execution::context::ExecutionContext, - logical_plan::{self, Expr, ExpressionVisitor, Recursion}, + logical_plan::{self, Expr, ExprVisitable, ExpressionVisitor, Recursion}, physical_plan::functions::Volatility, scalar::ScalarValue, }; diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 4b539a814551..69da346aee8d 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -20,12 +20,8 @@ pub use super::Operator; use crate::error::{DataFusionError, Result}; -use crate::execution::context::ExecutionProps; use crate::field_util::get_indexed_field; -use crate::logical_plan::{ - plan::Aggregate, window_frames, DFField, DFSchema, LogicalPlan, -}; -use crate::optimizer::simplify_expressions::{ConstEvaluator, Simplifier}; +use crate::logical_plan::{window_frames, DFField, DFSchema}; use crate::physical_plan::functions::Volatility; use crate::physical_plan::{ aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF, @@ -36,7 +32,7 @@ use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; use arrow::{compute::can_cast_types, datatypes::DataType}; pub use datafusion_common::{Column, ExprSchema}; use functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature}; -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use std::fmt; use std::hash::{BuildHasher, Hash, Hasher}; use std::ops::Not; @@ -557,348 +553,6 @@ impl Expr { nulls_first, } } - - /// Performs a depth first walk of an expression and - /// its children, calling [`ExpressionVisitor::pre_visit`] and - /// `visitor.post_visit`. - /// - /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to - /// separate expression algorithms from the structure of the - /// `Expr` tree and make it easier to add new types of expressions - /// and algorithms that walk the tree. - /// - /// For an expression tree such as - /// ```text - /// BinaryExpr (GT) - /// left: Column("foo") - /// right: Column("bar") - /// ``` - /// - /// The nodes are visited using the following order - /// ```text - /// pre_visit(BinaryExpr(GT)) - /// pre_visit(Column("foo")) - /// pre_visit(Column("bar")) - /// post_visit(Column("bar")) - /// post_visit(Column("bar")) - /// post_visit(BinaryExpr(GT)) - /// ``` - /// - /// If an Err result is returned, recursion is stopped immediately - /// - /// If `Recursion::Stop` is returned on a call to pre_visit, no - /// children of that expression are visited, nor is post_visit - /// called on that expression - /// - pub fn accept(&self, visitor: V) -> Result { - let visitor = match visitor.pre_visit(self)? { - Recursion::Continue(visitor) => visitor, - // If the recursion should stop, do not visit children - Recursion::Stop(visitor) => return Ok(visitor), - }; - - // recurse (and cover all expression types) - let visitor = match self { - Expr::Alias(expr, _) - | Expr::Not(expr) - | Expr::IsNotNull(expr) - | Expr::IsNull(expr) - | Expr::Negative(expr) - | Expr::Cast { expr, .. } - | Expr::TryCast { expr, .. } - | Expr::Sort { expr, .. } - | Expr::GetIndexedField { expr, .. } => expr.accept(visitor), - Expr::Column(_) - | Expr::ScalarVariable(_) - | Expr::Literal(_) - | Expr::Wildcard => Ok(visitor), - Expr::BinaryExpr { left, right, .. } => { - let visitor = left.accept(visitor)?; - right.accept(visitor) - } - Expr::Between { - expr, low, high, .. - } => { - let visitor = expr.accept(visitor)?; - let visitor = low.accept(visitor)?; - high.accept(visitor) - } - Expr::Case { - expr, - when_then_expr, - else_expr, - } => { - let visitor = if let Some(expr) = expr.as_ref() { - expr.accept(visitor) - } else { - Ok(visitor) - }?; - let visitor = when_then_expr.iter().try_fold( - visitor, - |visitor, (when, then)| { - let visitor = when.accept(visitor)?; - then.accept(visitor) - }, - )?; - if let Some(else_expr) = else_expr.as_ref() { - else_expr.accept(visitor) - } else { - Ok(visitor) - } - } - Expr::ScalarFunction { args, .. } - | Expr::ScalarUDF { args, .. } - | Expr::AggregateFunction { args, .. } - | Expr::AggregateUDF { args, .. } => args - .iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor)), - Expr::WindowFunction { - args, - partition_by, - order_by, - .. - } => { - let visitor = args - .iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; - let visitor = partition_by - .iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; - let visitor = order_by - .iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; - Ok(visitor) - } - Expr::InList { expr, list, .. } => { - let visitor = expr.accept(visitor)?; - list.iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor)) - } - }?; - - visitor.post_visit(self) - } - - /// Performs a depth first walk of an expression and its children - /// to rewrite an expression, consuming `self` producing a new - /// [`Expr`]. - /// - /// Implements a modified version of the [visitor - /// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to - /// separate algorithms from the structure of the `Expr` tree and - /// make it easier to write new, efficient expression - /// transformation algorithms. - /// - /// For an expression tree such as - /// ```text - /// BinaryExpr (GT) - /// left: Column("foo") - /// right: Column("bar") - /// ``` - /// - /// The nodes are visited using the following order - /// ```text - /// pre_visit(BinaryExpr(GT)) - /// pre_visit(Column("foo")) - /// mutatate(Column("foo")) - /// pre_visit(Column("bar")) - /// mutate(Column("bar")) - /// mutate(BinaryExpr(GT)) - /// ``` - /// - /// If an Err result is returned, recursion is stopped immediately - /// - /// If [`false`] is returned on a call to pre_visit, no - /// children of that expression are visited, nor is mutate - /// called on that expression - /// - pub fn rewrite(self, rewriter: &mut R) -> Result - where - R: ExprRewriter, - { - let need_mutate = match rewriter.pre_visit(&self)? { - RewriteRecursion::Mutate => return rewriter.mutate(self), - RewriteRecursion::Stop => return Ok(self), - RewriteRecursion::Continue => true, - RewriteRecursion::Skip => false, - }; - - // recurse into all sub expressions(and cover all expression types) - let expr = match self { - Expr::Alias(expr, name) => Expr::Alias(rewrite_boxed(expr, rewriter)?, name), - Expr::Column(_) => self.clone(), - Expr::ScalarVariable(names) => Expr::ScalarVariable(names), - Expr::Literal(value) => Expr::Literal(value), - Expr::BinaryExpr { left, op, right } => Expr::BinaryExpr { - left: rewrite_boxed(left, rewriter)?, - op, - right: rewrite_boxed(right, rewriter)?, - }, - Expr::Not(expr) => Expr::Not(rewrite_boxed(expr, rewriter)?), - Expr::IsNotNull(expr) => Expr::IsNotNull(rewrite_boxed(expr, rewriter)?), - Expr::IsNull(expr) => Expr::IsNull(rewrite_boxed(expr, rewriter)?), - Expr::Negative(expr) => Expr::Negative(rewrite_boxed(expr, rewriter)?), - Expr::Between { - expr, - low, - high, - negated, - } => Expr::Between { - expr: rewrite_boxed(expr, rewriter)?, - low: rewrite_boxed(low, rewriter)?, - high: rewrite_boxed(high, rewriter)?, - negated, - }, - Expr::Case { - expr, - when_then_expr, - else_expr, - } => { - let expr = rewrite_option_box(expr, rewriter)?; - let when_then_expr = when_then_expr - .into_iter() - .map(|(when, then)| { - Ok(( - rewrite_boxed(when, rewriter)?, - rewrite_boxed(then, rewriter)?, - )) - }) - .collect::>>()?; - - let else_expr = rewrite_option_box(else_expr, rewriter)?; - - Expr::Case { - expr, - when_then_expr, - else_expr, - } - } - Expr::Cast { expr, data_type } => Expr::Cast { - expr: rewrite_boxed(expr, rewriter)?, - data_type, - }, - Expr::TryCast { expr, data_type } => Expr::TryCast { - expr: rewrite_boxed(expr, rewriter)?, - data_type, - }, - Expr::Sort { - expr, - asc, - nulls_first, - } => Expr::Sort { - expr: rewrite_boxed(expr, rewriter)?, - asc, - nulls_first, - }, - Expr::ScalarFunction { args, fun } => Expr::ScalarFunction { - args: rewrite_vec(args, rewriter)?, - fun, - }, - Expr::ScalarUDF { args, fun } => Expr::ScalarUDF { - args: rewrite_vec(args, rewriter)?, - fun, - }, - Expr::WindowFunction { - args, - fun, - partition_by, - order_by, - window_frame, - } => Expr::WindowFunction { - args: rewrite_vec(args, rewriter)?, - fun, - partition_by: rewrite_vec(partition_by, rewriter)?, - order_by: rewrite_vec(order_by, rewriter)?, - window_frame, - }, - Expr::AggregateFunction { - args, - fun, - distinct, - } => Expr::AggregateFunction { - args: rewrite_vec(args, rewriter)?, - fun, - distinct, - }, - Expr::AggregateUDF { args, fun } => Expr::AggregateUDF { - args: rewrite_vec(args, rewriter)?, - fun, - }, - Expr::InList { - expr, - list, - negated, - } => Expr::InList { - expr: rewrite_boxed(expr, rewriter)?, - list: rewrite_vec(list, rewriter)?, - negated, - }, - Expr::Wildcard => Expr::Wildcard, - Expr::GetIndexedField { expr, key } => Expr::GetIndexedField { - expr: rewrite_boxed(expr, rewriter)?, - key, - }, - }; - - // now rewrite this expression itself - if need_mutate { - rewriter.mutate(expr) - } else { - Ok(expr) - } - } - - /// Simplifies this [`Expr`]`s as much as possible, evaluating - /// constants and applying algebraic simplifications - /// - /// # Example: - /// `b > 2 AND b > 2` - /// can be written to - /// `b > 2` - /// - /// ``` - /// use datafusion::logical_plan::*; - /// use datafusion::error::Result; - /// use datafusion::execution::context::ExecutionProps; - /// - /// /// Simple implementation that provides `Simplifier` the information it needs - /// #[derive(Default)] - /// struct Info { - /// execution_props: ExecutionProps, - /// }; - /// - /// impl SimplifyInfo for Info { - /// fn is_boolean_type(&self, expr: &Expr) -> Result { - /// Ok(false) - /// } - /// fn nullable(&self, expr: &Expr) -> Result { - /// Ok(true) - /// } - /// fn execution_props(&self) -> &ExecutionProps { - /// &self.execution_props - /// } - /// } - /// - /// // b < 2 - /// let b_lt_2 = col("b").gt(lit(2)); - /// - /// // (b < 2) OR (b < 2) - /// let expr = b_lt_2.clone().or(b_lt_2.clone()); - /// - /// // (b < 2) OR (b < 2) --> (b < 2) - /// let expr = expr.simplify(&Info::default()).unwrap(); - /// assert_eq!(expr, b_lt_2); - /// ``` - pub fn simplify(self, info: &S) -> Result { - let mut rewriter = Simplifier::new(info); - let mut const_evaluator = ConstEvaluator::new(info.execution_props()); - - // TODO iterate until no changes are made during rewrite - // (evaluating constants can enable new simplifications and - // simplifications can enable new constant evaluation) - // https://github.com/apache/arrow-datafusion/issues/1160 - self.rewrite(&mut const_evaluator)?.rewrite(&mut rewriter) - } } impl Not for Expr { @@ -936,103 +590,6 @@ impl std::fmt::Display for Expr { } } -#[allow(clippy::boxed_local)] -fn rewrite_boxed(boxed_expr: Box, rewriter: &mut R) -> Result> -where - R: ExprRewriter, -{ - // TODO: It might be possible to avoid an allocation (the - // Box::new) below by reusing the box. - let expr: Expr = *boxed_expr; - let rewritten_expr = expr.rewrite(rewriter)?; - Ok(Box::new(rewritten_expr)) -} - -fn rewrite_option_box( - option_box: Option>, - rewriter: &mut R, -) -> Result>> -where - R: ExprRewriter, -{ - option_box - .map(|expr| rewrite_boxed(expr, rewriter)) - .transpose() -} - -/// rewrite a `Vec` of `Expr`s with the rewriter -fn rewrite_vec(v: Vec, rewriter: &mut R) -> Result> -where - R: ExprRewriter, -{ - v.into_iter().map(|expr| expr.rewrite(rewriter)).collect() -} - -/// Controls how the visitor recursion should proceed. -pub enum Recursion { - /// Attempt to visit all the children, recursively, of this expression. - Continue(V), - /// Do not visit the children of this expression, though the walk - /// of parents of this expression will not be affected - Stop(V), -} - -/// Encode the traversal of an expression tree. When passed to -/// `Expr::accept`, `ExpressionVisitor::visit` is invoked -/// recursively on all nodes of an expression tree. See the comments -/// on `Expr::accept` for details on its use -pub trait ExpressionVisitor: Sized { - /// Invoked before any children of `expr` are visisted. - fn pre_visit(self, expr: &Expr) -> Result>; - - /// Invoked after all children of `expr` are visited. Default - /// implementation does nothing. - fn post_visit(self, _expr: &Expr) -> Result { - Ok(self) - } -} - -/// Controls how the [ExprRewriter] recursion should proceed. -pub enum RewriteRecursion { - /// Continue rewrite / visit this expression. - Continue, - /// Call [mutate()] immediately and return. - Mutate, - /// Do not rewrite / visit the children of this expression. - Stop, - /// Keep recursive but skip mutate on this expression - Skip, -} - -/// Trait for potentially recursively rewriting an [`Expr`] expression -/// tree. When passed to `Expr::rewrite`, `ExpressionVisitor::mutate` is -/// invoked recursively on all nodes of an expression tree. See the -/// comments on `Expr::rewrite` for details on its use -pub trait ExprRewriter: Sized { - /// Invoked before any children of `expr` are rewritten / - /// visited. Default implementation returns `Ok(RewriteRecursion::Continue)` - fn pre_visit(&mut self, _expr: &Expr) -> Result { - Ok(RewriteRecursion::Continue) - } - - /// Invoked after all children of `expr` have been mutated and - /// returns a potentially modified expr. - fn mutate(&mut self, expr: Expr) -> Result; -} - -/// The information necessary to apply algebraic simplification to an -/// [Expr]. See [SimplifyContext] for one implementation -pub trait SimplifyInfo { - /// returns true if this Expr has boolean type - fn is_boolean_type(&self, expr: &Expr) -> Result; - - /// returns true of this expr is nullable (could possibly be NULL) - fn nullable(&self, expr: &Expr) -> Result; - - /// Returns details needed for partial expression evaluation - fn execution_props(&self) -> &ExecutionProps; -} - /// Helper struct for building [Expr::Case] pub struct CaseBuilder { expr: Option>, @@ -1201,183 +758,6 @@ pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr { } } -/// Recursively replace all Column expressions in a given expression tree with Column expressions -/// provided by the hash map argument. -pub fn replace_col(e: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { - struct ColumnReplacer<'a> { - replace_map: &'a HashMap<&'a Column, &'a Column>, - } - - impl<'a> ExprRewriter for ColumnReplacer<'a> { - fn mutate(&mut self, expr: Expr) -> Result { - if let Expr::Column(c) = &expr { - match self.replace_map.get(c) { - Some(new_c) => Ok(Expr::Column((*new_c).to_owned())), - None => Ok(expr), - } - } else { - Ok(expr) - } - } - } - - e.rewrite(&mut ColumnReplacer { replace_map }) -} - -/// Recursively call [`Column::normalize`] on all Column expressions -/// in the `expr` expression tree. -pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { - normalize_col_with_schemas(expr, &plan.all_schemas(), &plan.using_columns()?) -} - -/// Recursively call [`Column::normalize`] on all Column expressions -/// in the `expr` expression tree. -fn normalize_col_with_schemas( - expr: Expr, - schemas: &[&Arc], - using_columns: &[HashSet], -) -> Result { - struct ColumnNormalizer<'a> { - schemas: &'a [&'a Arc], - using_columns: &'a [HashSet], - } - - impl<'a> ExprRewriter for ColumnNormalizer<'a> { - fn mutate(&mut self, expr: Expr) -> Result { - if let Expr::Column(c) = expr { - Ok(Expr::Column(c.normalize_with_schemas( - self.schemas, - self.using_columns, - )?)) - } else { - Ok(expr) - } - } - } - - expr.rewrite(&mut ColumnNormalizer { - schemas, - using_columns, - }) -} - -/// Recursively normalize all Column expressions in a list of expression trees -pub fn normalize_cols( - exprs: impl IntoIterator>, - plan: &LogicalPlan, -) -> Result> { - exprs - .into_iter() - .map(|e| normalize_col(e.into(), plan)) - .collect() -} - -/// Rewrite sort on aggregate expressions to sort on the column of aggregate output -/// For example, `max(x)` is written to `col("MAX(x)")` -pub fn rewrite_sort_cols_by_aggs( - exprs: impl IntoIterator>, - plan: &LogicalPlan, -) -> Result> { - exprs - .into_iter() - .map(|e| { - let expr = e.into(); - match expr { - Expr::Sort { - expr, - asc, - nulls_first, - } => { - let sort = Expr::Sort { - expr: Box::new(rewrite_sort_col_by_aggs(*expr, plan)?), - asc, - nulls_first, - }; - Ok(sort) - } - expr => Ok(expr), - } - }) - .collect() -} - -fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result { - match plan { - LogicalPlan::Aggregate(Aggregate { - input, aggr_expr, .. - }) => { - struct Rewriter<'a> { - plan: &'a LogicalPlan, - input: &'a LogicalPlan, - aggr_expr: &'a Vec, - } - - impl<'a> ExprRewriter for Rewriter<'a> { - fn mutate(&mut self, expr: Expr) -> Result { - let normalized_expr = normalize_col(expr.clone(), self.plan); - if normalized_expr.is_err() { - // The expr is not based on Aggregate plan output. Skip it. - return Ok(expr); - } - let normalized_expr = normalized_expr.unwrap(); - if let Some(found_agg) = - self.aggr_expr.iter().find(|a| (**a) == normalized_expr) - { - let agg = normalize_col(found_agg.clone(), self.plan)?; - let col = Expr::Column( - agg.to_field(self.input.schema()) - .map(|f| f.qualified_column())?, - ); - Ok(col) - } else { - Ok(expr) - } - } - } - - expr.rewrite(&mut Rewriter { - plan, - input, - aggr_expr, - }) - } - LogicalPlan::Projection(_) => rewrite_sort_col_by_aggs(expr, plan.inputs()[0]), - _ => Ok(expr), - } -} - -/// Recursively 'unnormalize' (remove all qualifiers) from an -/// expression tree. -/// -/// For example, if there were expressions like `foo.bar` this would -/// rewrite it to just `bar`. -pub fn unnormalize_col(expr: Expr) -> Expr { - struct RemoveQualifier {} - - impl ExprRewriter for RemoveQualifier { - fn mutate(&mut self, expr: Expr) -> Result { - if let Expr::Column(col) = expr { - //let Column { relation: _, name } = col; - Ok(Expr::Column(Column { - relation: None, - name: col.name, - })) - } else { - Ok(expr) - } - } - } - - expr.rewrite(&mut RemoveQualifier {}) - .expect("Unnormalize is infallable") -} - -/// Recursively un-normalize all Column expressions in a list of expression trees -#[inline] -pub fn unnormalize_cols(exprs: impl IntoIterator) -> Vec { - exprs.into_iter().map(unnormalize_col).collect() -} - /// Recursively un-alias an expressions #[inline] pub fn unalias(expr: Expr) -> Expr { @@ -2114,24 +1494,6 @@ mod tests { assert_eq!(expr, expected); } - #[test] - fn rewriter_visit() { - let mut rewriter = RecordingRewriter::default(); - col("state").eq(lit("CO")).rewrite(&mut rewriter).unwrap(); - - assert_eq!( - rewriter.v, - vec![ - "Previsited #state = Utf8(\"CO\")", - "Previsited #state", - "Mutated #state", - "Previsited Utf8(\"CO\")", - "Mutated Utf8(\"CO\")", - "Mutated #state = Utf8(\"CO\")" - ] - ) - } - #[test] fn filter_is_null_and_is_not_null() { let col_null = col("col1"); @@ -2143,128 +1505,6 @@ mod tests { ); } - #[derive(Default)] - struct RecordingRewriter { - v: Vec, - } - impl ExprRewriter for RecordingRewriter { - fn mutate(&mut self, expr: Expr) -> Result { - self.v.push(format!("Mutated {:?}", expr)); - Ok(expr) - } - - fn pre_visit(&mut self, expr: &Expr) -> Result { - self.v.push(format!("Previsited {:?}", expr)); - Ok(RewriteRecursion::Continue) - } - } - - #[test] - fn rewriter_rewrite() { - let mut rewriter = FooBarRewriter {}; - - // rewrites "foo" --> "bar" - let rewritten = col("state").eq(lit("foo")).rewrite(&mut rewriter).unwrap(); - assert_eq!(rewritten, col("state").eq(lit("bar"))); - - // doesn't wrewrite - let rewritten = col("state").eq(lit("baz")).rewrite(&mut rewriter).unwrap(); - assert_eq!(rewritten, col("state").eq(lit("baz"))); - } - - /// rewrites all "foo" string literals to "bar" - struct FooBarRewriter {} - impl ExprRewriter for FooBarRewriter { - fn mutate(&mut self, expr: Expr) -> Result { - match expr { - Expr::Literal(ScalarValue::Utf8(Some(utf8_val))) => { - let utf8_val = if utf8_val == "foo" { - "bar".to_string() - } else { - utf8_val - }; - Ok(lit(utf8_val)) - } - // otherwise, return the expression unchanged - expr => Ok(expr), - } - } - } - - #[test] - fn normalize_cols() { - let expr = col("a") + col("b") + col("c"); - - // Schemas with some matching and some non matching cols - let schema_a = - DFSchema::new(vec![make_field("tableA", "a"), make_field("tableA", "aa")]) - .unwrap(); - let schema_c = - DFSchema::new(vec![make_field("tableC", "cc"), make_field("tableC", "c")]) - .unwrap(); - let schema_b = DFSchema::new(vec![make_field("tableB", "b")]).unwrap(); - // non matching - let schema_f = - DFSchema::new(vec![make_field("tableC", "f"), make_field("tableC", "ff")]) - .unwrap(); - let schemas = vec![schema_c, schema_f, schema_b, schema_a] - .into_iter() - .map(Arc::new) - .collect::>(); - let schemas = schemas.iter().collect::>(); - - let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap(); - assert_eq!( - normalized_expr, - col("tableA.a") + col("tableB.b") + col("tableC.c") - ); - } - - #[test] - fn normalize_cols_priority() { - let expr = col("a") + col("b"); - // Schemas with multiple matches for column a, first takes priority - let schema_a = DFSchema::new(vec![make_field("tableA", "a")]).unwrap(); - let schema_b = DFSchema::new(vec![make_field("tableB", "b")]).unwrap(); - let schema_a2 = DFSchema::new(vec![make_field("tableA2", "a")]).unwrap(); - let schemas = vec![schema_a2, schema_b, schema_a] - .into_iter() - .map(Arc::new) - .collect::>(); - let schemas = schemas.iter().collect::>(); - - let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap(); - assert_eq!(normalized_expr, col("tableA2.a") + col("tableB.b")); - } - - #[test] - fn normalize_cols_non_exist() { - // test normalizing columns when the name doesn't exist - let expr = col("a") + col("b"); - let schema_a = DFSchema::new(vec![make_field("tableA", "a")]).unwrap(); - let schemas = vec![schema_a].into_iter().map(Arc::new).collect::>(); - let schemas = schemas.iter().collect::>(); - - let error = normalize_col_with_schemas(expr, &schemas, &[]) - .unwrap_err() - .to_string(); - assert_eq!( - error, - "Error during planning: Column #b not found in provided schemas" - ); - } - - #[test] - fn unnormalize_cols() { - let expr = col("tableA.a") + col("tableB.b"); - let unnormalized_expr = unnormalize_col(expr); - assert_eq!(unnormalized_expr, col("a") + col("b")); - } - - fn make_field(relation: &str, column: &str) -> DFField { - DFField::new(Some(relation), column, DataType::Int8, false) - } - #[test] fn test_not() { assert_eq!(lit(1).not(), !lit(1)); diff --git a/datafusion/src/logical_plan/expr_rewriter.rs b/datafusion/src/logical_plan/expr_rewriter.rs new file mode 100644 index 000000000000..67be9b7bf523 --- /dev/null +++ b/datafusion/src/logical_plan/expr_rewriter.rs @@ -0,0 +1,589 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Expression rewriter + +use super::Expr; +use crate::logical_plan::plan::Aggregate; +use crate::logical_plan::DFSchema; +use crate::logical_plan::LogicalPlan; +use datafusion_common::Column; +use datafusion_common::Result; +use std::collections::HashMap; +use std::collections::HashSet; +use std::sync::Arc; + +/// Controls how the [ExprRewriter] recursion should proceed. +pub enum RewriteRecursion { + /// Continue rewrite / visit this expression. + Continue, + /// Call [mutate()] immediately and return. + Mutate, + /// Do not rewrite / visit the children of this expression. + Stop, + /// Keep recursive but skip mutate on this expression + Skip, +} + +/// Trait for potentially recursively rewriting an [`Expr`] expression +/// tree. When passed to `Expr::rewrite`, `ExpressionVisitor::mutate` is +/// invoked recursively on all nodes of an expression tree. See the +/// comments on `Expr::rewrite` for details on its use +pub trait ExprRewriter: Sized { + /// Invoked before any children of `expr` are rewritten / + /// visited. Default implementation returns `Ok(RewriteRecursion::Continue)` + fn pre_visit(&mut self, _expr: &E) -> Result { + Ok(RewriteRecursion::Continue) + } + + /// Invoked after all children of `expr` have been mutated and + /// returns a potentially modified expr. + fn mutate(&mut self, expr: E) -> Result; +} + +pub trait ExprRewritable: Sized { + fn rewrite>(self, rewriter: &mut R) -> Result; +} + +impl ExprRewritable for Expr { + /// Performs a depth first walk of an expression and its children + /// to rewrite an expression, consuming `self` producing a new + /// [`Expr`]. + /// + /// Implements a modified version of the [visitor + /// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to + /// separate algorithms from the structure of the `Expr` tree and + /// make it easier to write new, efficient expression + /// transformation algorithms. + /// + /// For an expression tree such as + /// ```text + /// BinaryExpr (GT) + /// left: Column("foo") + /// right: Column("bar") + /// ``` + /// + /// The nodes are visited using the following order + /// ```text + /// pre_visit(BinaryExpr(GT)) + /// pre_visit(Column("foo")) + /// mutatate(Column("foo")) + /// pre_visit(Column("bar")) + /// mutate(Column("bar")) + /// mutate(BinaryExpr(GT)) + /// ``` + /// + /// If an Err result is returned, recursion is stopped immediately + /// + /// If [`false`] is returned on a call to pre_visit, no + /// children of that expression are visited, nor is mutate + /// called on that expression + /// + fn rewrite(self, rewriter: &mut R) -> Result + where + R: ExprRewriter, + { + let need_mutate = match rewriter.pre_visit(&self)? { + RewriteRecursion::Mutate => return rewriter.mutate(self), + RewriteRecursion::Stop => return Ok(self), + RewriteRecursion::Continue => true, + RewriteRecursion::Skip => false, + }; + + // recurse into all sub expressions(and cover all expression types) + let expr = match self { + Expr::Alias(expr, name) => Expr::Alias(rewrite_boxed(expr, rewriter)?, name), + Expr::Column(_) => self.clone(), + Expr::ScalarVariable(names) => Expr::ScalarVariable(names), + Expr::Literal(value) => Expr::Literal(value), + Expr::BinaryExpr { left, op, right } => Expr::BinaryExpr { + left: rewrite_boxed(left, rewriter)?, + op, + right: rewrite_boxed(right, rewriter)?, + }, + Expr::Not(expr) => Expr::Not(rewrite_boxed(expr, rewriter)?), + Expr::IsNotNull(expr) => Expr::IsNotNull(rewrite_boxed(expr, rewriter)?), + Expr::IsNull(expr) => Expr::IsNull(rewrite_boxed(expr, rewriter)?), + Expr::Negative(expr) => Expr::Negative(rewrite_boxed(expr, rewriter)?), + Expr::Between { + expr, + low, + high, + negated, + } => Expr::Between { + expr: rewrite_boxed(expr, rewriter)?, + low: rewrite_boxed(low, rewriter)?, + high: rewrite_boxed(high, rewriter)?, + negated, + }, + Expr::Case { + expr, + when_then_expr, + else_expr, + } => { + let expr = rewrite_option_box(expr, rewriter)?; + let when_then_expr = when_then_expr + .into_iter() + .map(|(when, then)| { + Ok(( + rewrite_boxed(when, rewriter)?, + rewrite_boxed(then, rewriter)?, + )) + }) + .collect::>>()?; + + let else_expr = rewrite_option_box(else_expr, rewriter)?; + + Expr::Case { + expr, + when_then_expr, + else_expr, + } + } + Expr::Cast { expr, data_type } => Expr::Cast { + expr: rewrite_boxed(expr, rewriter)?, + data_type, + }, + Expr::TryCast { expr, data_type } => Expr::TryCast { + expr: rewrite_boxed(expr, rewriter)?, + data_type, + }, + Expr::Sort { + expr, + asc, + nulls_first, + } => Expr::Sort { + expr: rewrite_boxed(expr, rewriter)?, + asc, + nulls_first, + }, + Expr::ScalarFunction { args, fun } => Expr::ScalarFunction { + args: rewrite_vec(args, rewriter)?, + fun, + }, + Expr::ScalarUDF { args, fun } => Expr::ScalarUDF { + args: rewrite_vec(args, rewriter)?, + fun, + }, + Expr::WindowFunction { + args, + fun, + partition_by, + order_by, + window_frame, + } => Expr::WindowFunction { + args: rewrite_vec(args, rewriter)?, + fun, + partition_by: rewrite_vec(partition_by, rewriter)?, + order_by: rewrite_vec(order_by, rewriter)?, + window_frame, + }, + Expr::AggregateFunction { + args, + fun, + distinct, + } => Expr::AggregateFunction { + args: rewrite_vec(args, rewriter)?, + fun, + distinct, + }, + Expr::AggregateUDF { args, fun } => Expr::AggregateUDF { + args: rewrite_vec(args, rewriter)?, + fun, + }, + Expr::InList { + expr, + list, + negated, + } => Expr::InList { + expr: rewrite_boxed(expr, rewriter)?, + list: rewrite_vec(list, rewriter)?, + negated, + }, + Expr::Wildcard => Expr::Wildcard, + Expr::GetIndexedField { expr, key } => Expr::GetIndexedField { + expr: rewrite_boxed(expr, rewriter)?, + key, + }, + }; + + // now rewrite this expression itself + if need_mutate { + rewriter.mutate(expr) + } else { + Ok(expr) + } + } +} + +#[allow(clippy::boxed_local)] +fn rewrite_boxed(boxed_expr: Box, rewriter: &mut R) -> Result> +where + R: ExprRewriter, +{ + // TODO: It might be possible to avoid an allocation (the + // Box::new) below by reusing the box. + let expr: Expr = *boxed_expr; + let rewritten_expr = expr.rewrite(rewriter)?; + Ok(Box::new(rewritten_expr)) +} + +fn rewrite_option_box( + option_box: Option>, + rewriter: &mut R, +) -> Result>> +where + R: ExprRewriter, +{ + option_box + .map(|expr| rewrite_boxed(expr, rewriter)) + .transpose() +} + +/// rewrite a `Vec` of `Expr`s with the rewriter +fn rewrite_vec(v: Vec, rewriter: &mut R) -> Result> +where + R: ExprRewriter, +{ + v.into_iter().map(|expr| expr.rewrite(rewriter)).collect() +} + +/// Rewrite sort on aggregate expressions to sort on the column of aggregate output +/// For example, `max(x)` is written to `col("MAX(x)")` +pub fn rewrite_sort_cols_by_aggs( + exprs: impl IntoIterator>, + plan: &LogicalPlan, +) -> Result> { + exprs + .into_iter() + .map(|e| { + let expr = e.into(); + match expr { + Expr::Sort { + expr, + asc, + nulls_first, + } => { + let sort = Expr::Sort { + expr: Box::new(rewrite_sort_col_by_aggs(*expr, plan)?), + asc, + nulls_first, + }; + Ok(sort) + } + expr => Ok(expr), + } + }) + .collect() +} + +fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result { + match plan { + LogicalPlan::Aggregate(Aggregate { + input, aggr_expr, .. + }) => { + struct Rewriter<'a> { + plan: &'a LogicalPlan, + input: &'a LogicalPlan, + aggr_expr: &'a Vec, + } + + impl<'a> ExprRewriter for Rewriter<'a> { + fn mutate(&mut self, expr: Expr) -> Result { + let normalized_expr = normalize_col(expr.clone(), self.plan); + if normalized_expr.is_err() { + // The expr is not based on Aggregate plan output. Skip it. + return Ok(expr); + } + let normalized_expr = normalized_expr.unwrap(); + if let Some(found_agg) = + self.aggr_expr.iter().find(|a| (**a) == normalized_expr) + { + let agg = normalize_col(found_agg.clone(), self.plan)?; + let col = Expr::Column( + agg.to_field(self.input.schema()) + .map(|f| f.qualified_column())?, + ); + Ok(col) + } else { + Ok(expr) + } + } + } + + expr.rewrite(&mut Rewriter { + plan, + input, + aggr_expr, + }) + } + LogicalPlan::Projection(_) => rewrite_sort_col_by_aggs(expr, plan.inputs()[0]), + _ => Ok(expr), + } +} + +/// Recursively call [`Column::normalize`] on all Column expressions +/// in the `expr` expression tree. +pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { + normalize_col_with_schemas(expr, &plan.all_schemas(), &plan.using_columns()?) +} + +/// Recursively call [`Column::normalize`] on all Column expressions +/// in the `expr` expression tree. +fn normalize_col_with_schemas( + expr: Expr, + schemas: &[&Arc], + using_columns: &[HashSet], +) -> Result { + struct ColumnNormalizer<'a> { + schemas: &'a [&'a Arc], + using_columns: &'a [HashSet], + } + + impl<'a> ExprRewriter for ColumnNormalizer<'a> { + fn mutate(&mut self, expr: Expr) -> Result { + if let Expr::Column(c) = expr { + Ok(Expr::Column(c.normalize_with_schemas( + self.schemas, + self.using_columns, + )?)) + } else { + Ok(expr) + } + } + } + + expr.rewrite(&mut ColumnNormalizer { + schemas, + using_columns, + }) +} + +/// Recursively normalize all Column expressions in a list of expression trees +pub fn normalize_cols( + exprs: impl IntoIterator>, + plan: &LogicalPlan, +) -> Result> { + exprs + .into_iter() + .map(|e| normalize_col(e.into(), plan)) + .collect() +} + +/// Recursively replace all Column expressions in a given expression tree with Column expressions +/// provided by the hash map argument. +pub fn replace_col(e: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { + struct ColumnReplacer<'a> { + replace_map: &'a HashMap<&'a Column, &'a Column>, + } + + impl<'a> ExprRewriter for ColumnReplacer<'a> { + fn mutate(&mut self, expr: Expr) -> Result { + if let Expr::Column(c) = &expr { + match self.replace_map.get(c) { + Some(new_c) => Ok(Expr::Column((*new_c).to_owned())), + None => Ok(expr), + } + } else { + Ok(expr) + } + } + } + + e.rewrite(&mut ColumnReplacer { replace_map }) +} + +/// Recursively 'unnormalize' (remove all qualifiers) from an +/// expression tree. +/// +/// For example, if there were expressions like `foo.bar` this would +/// rewrite it to just `bar`. +pub fn unnormalize_col(expr: Expr) -> Expr { + struct RemoveQualifier {} + + impl ExprRewriter for RemoveQualifier { + fn mutate(&mut self, expr: Expr) -> Result { + if let Expr::Column(col) = expr { + //let Column { relation: _, name } = col; + Ok(Expr::Column(Column { + relation: None, + name: col.name, + })) + } else { + Ok(expr) + } + } + } + + expr.rewrite(&mut RemoveQualifier {}) + .expect("Unnormalize is infallable") +} + +/// Recursively un-normalize all Column expressions in a list of expression trees +#[inline] +pub fn unnormalize_cols(exprs: impl IntoIterator) -> Vec { + exprs.into_iter().map(unnormalize_col).collect() +} + +#[cfg(test)] +mod test { + use super::*; + use crate::logical_plan::DFField; + use crate::prelude::{col, lit}; + use arrow::datatypes::DataType; + use datafusion_common::ScalarValue; + + #[derive(Default)] + struct RecordingRewriter { + v: Vec, + } + impl ExprRewriter for RecordingRewriter { + fn mutate(&mut self, expr: Expr) -> Result { + self.v.push(format!("Mutated {:?}", expr)); + Ok(expr) + } + + fn pre_visit(&mut self, expr: &Expr) -> Result { + self.v.push(format!("Previsited {:?}", expr)); + Ok(RewriteRecursion::Continue) + } + } + + #[test] + fn rewriter_rewrite() { + let mut rewriter = FooBarRewriter {}; + + // rewrites "foo" --> "bar" + let rewritten = col("state").eq(lit("foo")).rewrite(&mut rewriter).unwrap(); + assert_eq!(rewritten, col("state").eq(lit("bar"))); + + // doesn't wrewrite + let rewritten = col("state").eq(lit("baz")).rewrite(&mut rewriter).unwrap(); + assert_eq!(rewritten, col("state").eq(lit("baz"))); + } + + /// rewrites all "foo" string literals to "bar" + struct FooBarRewriter {} + impl ExprRewriter for FooBarRewriter { + fn mutate(&mut self, expr: Expr) -> Result { + match expr { + Expr::Literal(ScalarValue::Utf8(Some(utf8_val))) => { + let utf8_val = if utf8_val == "foo" { + "bar".to_string() + } else { + utf8_val + }; + Ok(lit(utf8_val)) + } + // otherwise, return the expression unchanged + expr => Ok(expr), + } + } + } + + #[test] + fn normalize_cols() { + let expr = col("a") + col("b") + col("c"); + + // Schemas with some matching and some non matching cols + let schema_a = + DFSchema::new(vec![make_field("tableA", "a"), make_field("tableA", "aa")]) + .unwrap(); + let schema_c = + DFSchema::new(vec![make_field("tableC", "cc"), make_field("tableC", "c")]) + .unwrap(); + let schema_b = DFSchema::new(vec![make_field("tableB", "b")]).unwrap(); + // non matching + let schema_f = + DFSchema::new(vec![make_field("tableC", "f"), make_field("tableC", "ff")]) + .unwrap(); + let schemas = vec![schema_c, schema_f, schema_b, schema_a] + .into_iter() + .map(Arc::new) + .collect::>(); + let schemas = schemas.iter().collect::>(); + + let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap(); + assert_eq!( + normalized_expr, + col("tableA.a") + col("tableB.b") + col("tableC.c") + ); + } + + #[test] + fn normalize_cols_priority() { + let expr = col("a") + col("b"); + // Schemas with multiple matches for column a, first takes priority + let schema_a = DFSchema::new(vec![make_field("tableA", "a")]).unwrap(); + let schema_b = DFSchema::new(vec![make_field("tableB", "b")]).unwrap(); + let schema_a2 = DFSchema::new(vec![make_field("tableA2", "a")]).unwrap(); + let schemas = vec![schema_a2, schema_b, schema_a] + .into_iter() + .map(Arc::new) + .collect::>(); + let schemas = schemas.iter().collect::>(); + + let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap(); + assert_eq!(normalized_expr, col("tableA2.a") + col("tableB.b")); + } + + #[test] + fn normalize_cols_non_exist() { + // test normalizing columns when the name doesn't exist + let expr = col("a") + col("b"); + let schema_a = DFSchema::new(vec![make_field("tableA", "a")]).unwrap(); + let schemas = vec![schema_a].into_iter().map(Arc::new).collect::>(); + let schemas = schemas.iter().collect::>(); + + let error = normalize_col_with_schemas(expr, &schemas, &[]) + .unwrap_err() + .to_string(); + assert_eq!( + error, + "Error during planning: Column #b not found in provided schemas" + ); + } + + #[test] + fn unnormalize_cols() { + let expr = col("tableA.a") + col("tableB.b"); + let unnormalized_expr = unnormalize_col(expr); + assert_eq!(unnormalized_expr, col("a") + col("b")); + } + + fn make_field(relation: &str, column: &str) -> DFField { + DFField::new(Some(relation), column, DataType::Int8, false) + } + + #[test] + fn rewriter_visit() { + let mut rewriter = RecordingRewriter::default(); + col("state").eq(lit("CO")).rewrite(&mut rewriter).unwrap(); + + assert_eq!( + rewriter.v, + vec![ + "Previsited #state = Utf8(\"CO\")", + "Previsited #state", + "Mutated #state", + "Previsited Utf8(\"CO\")", + "Mutated Utf8(\"CO\")", + "Mutated #state = Utf8(\"CO\")" + ] + ) + } +} diff --git a/datafusion/src/logical_plan/expr_simplier.rs b/datafusion/src/logical_plan/expr_simplier.rs new file mode 100644 index 000000000000..bc1b02e92b23 --- /dev/null +++ b/datafusion/src/logical_plan/expr_simplier.rs @@ -0,0 +1,95 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Expression simplifier + +use super::Expr; +use super::ExprRewritable; +use crate::execution::context::ExecutionProps; +use crate::optimizer::simplify_expressions::{ConstEvaluator, Simplifier}; +use datafusion_common::Result; + +/// The information necessary to apply algebraic simplification to an +/// [Expr]. See [SimplifyContext] for one implementation +pub trait SimplifyInfo { + /// returns true if this Expr has boolean type + fn is_boolean_type(&self, expr: &Expr) -> Result; + + /// returns true of this expr is nullable (could possibly be NULL) + fn nullable(&self, expr: &Expr) -> Result; + + /// Returns details needed for partial expression evaluation + fn execution_props(&self) -> &ExecutionProps; +} + +pub trait ExprSimplifiable: Sized { + fn simplify(self, info: &S) -> Result; +} + +impl ExprSimplifiable for Expr { + /// Simplifies this [`Expr`]`s as much as possible, evaluating + /// constants and applying algebraic simplifications + /// + /// # Example: + /// `b > 2 AND b > 2` + /// can be written to + /// `b > 2` + /// + /// ``` + /// use datafusion::logical_plan::*; + /// use datafusion::error::Result; + /// use datafusion::execution::context::ExecutionProps; + /// + /// /// Simple implementation that provides `Simplifier` the information it needs + /// #[derive(Default)] + /// struct Info { + /// execution_props: ExecutionProps, + /// }; + /// + /// impl SimplifyInfo for Info { + /// fn is_boolean_type(&self, expr: &Expr) -> Result { + /// Ok(false) + /// } + /// fn nullable(&self, expr: &Expr) -> Result { + /// Ok(true) + /// } + /// fn execution_props(&self) -> &ExecutionProps { + /// &self.execution_props + /// } + /// } + /// + /// // b < 2 + /// let b_lt_2 = col("b").gt(lit(2)); + /// + /// // (b < 2) OR (b < 2) + /// let expr = b_lt_2.clone().or(b_lt_2.clone()); + /// + /// // (b < 2) OR (b < 2) --> (b < 2) + /// let expr = expr.simplify(&Info::default()).unwrap(); + /// assert_eq!(expr, b_lt_2); + /// ``` + fn simplify(self, info: &S) -> Result { + let mut rewriter = Simplifier::new(info); + let mut const_evaluator = ConstEvaluator::new(info.execution_props()); + + // TODO iterate until no changes are made during rewrite + // (evaluating constants can enable new simplifications and + // simplifications can enable new constant evaluation) + // https://github.com/apache/arrow-datafusion/issues/1160 + self.rewrite(&mut const_evaluator)?.rewrite(&mut rewriter) + } +} diff --git a/datafusion/src/logical_plan/expr_visitor.rs b/datafusion/src/logical_plan/expr_visitor.rs new file mode 100644 index 000000000000..f983be51c4e6 --- /dev/null +++ b/datafusion/src/logical_plan/expr_visitor.rs @@ -0,0 +1,174 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Expression visitor + +use super::Expr; +use datafusion_common::Result; + +/// Controls how the visitor recursion should proceed. +pub enum Recursion { + /// Attempt to visit all the children, recursively, of this expression. + Continue(V), + /// Do not visit the children of this expression, though the walk + /// of parents of this expression will not be affected + Stop(V), +} + +/// Encode the traversal of an expression tree. When passed to +/// `Expr::accept`, `ExpressionVisitor::visit` is invoked +/// recursively on all nodes of an expression tree. See the comments +/// on `Expr::accept` for details on its use +pub trait ExpressionVisitor: Sized { + /// Invoked before any children of `expr` are visisted. + fn pre_visit(self, expr: &E) -> Result> + where + Self: ExpressionVisitor; + + /// Invoked after all children of `expr` are visited. Default + /// implementation does nothing. + fn post_visit(self, _expr: &E) -> Result { + Ok(self) + } +} + +pub trait ExprVisitable: Sized { + fn accept>(&self, visitor: V) -> Result; +} + +impl ExprVisitable for Expr { + /// Performs a depth first walk of an expression and + /// its children, calling [`ExpressionVisitor::pre_visit`] and + /// `visitor.post_visit`. + /// + /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to + /// separate expression algorithms from the structure of the + /// `Expr` tree and make it easier to add new types of expressions + /// and algorithms that walk the tree. + /// + /// For an expression tree such as + /// ```text + /// BinaryExpr (GT) + /// left: Column("foo") + /// right: Column("bar") + /// ``` + /// + /// The nodes are visited using the following order + /// ```text + /// pre_visit(BinaryExpr(GT)) + /// pre_visit(Column("foo")) + /// pre_visit(Column("bar")) + /// post_visit(Column("bar")) + /// post_visit(Column("bar")) + /// post_visit(BinaryExpr(GT)) + /// ``` + /// + /// If an Err result is returned, recursion is stopped immediately + /// + /// If `Recursion::Stop` is returned on a call to pre_visit, no + /// children of that expression are visited, nor is post_visit + /// called on that expression + /// + fn accept(&self, visitor: V) -> Result { + let visitor = match visitor.pre_visit(self)? { + Recursion::Continue(visitor) => visitor, + // If the recursion should stop, do not visit children + Recursion::Stop(visitor) => return Ok(visitor), + }; + + // recurse (and cover all expression types) + let visitor = match self { + Expr::Alias(expr, _) + | Expr::Not(expr) + | Expr::IsNotNull(expr) + | Expr::IsNull(expr) + | Expr::Negative(expr) + | Expr::Cast { expr, .. } + | Expr::TryCast { expr, .. } + | Expr::Sort { expr, .. } + | Expr::GetIndexedField { expr, .. } => expr.accept(visitor), + Expr::Column(_) + | Expr::ScalarVariable(_) + | Expr::Literal(_) + | Expr::Wildcard => Ok(visitor), + Expr::BinaryExpr { left, right, .. } => { + let visitor = left.accept(visitor)?; + right.accept(visitor) + } + Expr::Between { + expr, low, high, .. + } => { + let visitor = expr.accept(visitor)?; + let visitor = low.accept(visitor)?; + high.accept(visitor) + } + Expr::Case { + expr, + when_then_expr, + else_expr, + } => { + let visitor = if let Some(expr) = expr.as_ref() { + expr.accept(visitor) + } else { + Ok(visitor) + }?; + let visitor = when_then_expr.iter().try_fold( + visitor, + |visitor, (when, then)| { + let visitor = when.accept(visitor)?; + then.accept(visitor) + }, + )?; + if let Some(else_expr) = else_expr.as_ref() { + else_expr.accept(visitor) + } else { + Ok(visitor) + } + } + Expr::ScalarFunction { args, .. } + | Expr::ScalarUDF { args, .. } + | Expr::AggregateFunction { args, .. } + | Expr::AggregateUDF { args, .. } => args + .iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor)), + Expr::WindowFunction { + args, + partition_by, + order_by, + .. + } => { + let visitor = args + .iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; + let visitor = partition_by + .iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; + let visitor = order_by + .iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; + Ok(visitor) + } + Expr::InList { expr, list, .. } => { + let visitor = expr.accept(visitor)?; + list.iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor)) + } + }?; + + visitor.post_visit(self) + } +} diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index ec1aea6a72a1..085775a2eb8c 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -25,6 +25,9 @@ pub(crate) mod builder; mod dfschema; mod display; mod expr; +mod expr_rewriter; +mod expr_simplier; +mod expr_visitor; mod extension; mod operators; pub mod plan; @@ -41,14 +44,18 @@ pub use expr::{ columnize_expr, combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf, create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, floor, in_list, initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, - lower, lpad, ltrim, max, md5, min, normalize_col, normalize_cols, now, octet_length, - or, random, regexp_match, regexp_replace, repeat, replace, replace_col, reverse, - rewrite_sort_cols_by_aggs, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, - signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex, - translate, trim, trunc, unalias, unnormalize_col, unnormalize_cols, upper, when, - Column, Expr, ExprRewriter, ExprSchema, ExpressionVisitor, Literal, Recursion, - RewriteRecursion, SimplifyInfo, + lower, lpad, ltrim, max, md5, min, now, octet_length, or, random, regexp_match, + regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, + sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, + to_hex, translate, trim, trunc, unalias, upper, when, Column, Expr, ExprSchema, + Literal, }; +pub use expr_rewriter::{ + normalize_col, normalize_cols, replace_col, rewrite_sort_cols_by_aggs, + unnormalize_col, unnormalize_cols, ExprRewritable, ExprRewriter, RewriteRecursion, +}; +pub use expr_simplier::{ExprSimplifiable, SimplifyInfo}; +pub use expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion}; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; pub use plan::{ diff --git a/datafusion/src/optimizer/common_subexpr_eliminate.rs b/datafusion/src/optimizer/common_subexpr_eliminate.rs index 947073409d05..5c2219b3d99a 100644 --- a/datafusion/src/optimizer/common_subexpr_eliminate.rs +++ b/datafusion/src/optimizer/common_subexpr_eliminate.rs @@ -23,8 +23,8 @@ use crate::logical_plan::plan::{Filter, Projection, Window}; use crate::logical_plan::{ col, plan::{Aggregate, Sort}, - DFField, DFSchema, Expr, ExprRewriter, ExpressionVisitor, LogicalPlan, Recursion, - RewriteRecursion, + DFField, DFSchema, Expr, ExprRewritable, ExprRewriter, ExprVisitable, + ExpressionVisitor, LogicalPlan, Recursion, RewriteRecursion, }; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; diff --git a/datafusion/src/optimizer/simplify_expressions.rs b/datafusion/src/optimizer/simplify_expressions.rs index 5f87542491d7..f8f3df44b673 100644 --- a/datafusion/src/optimizer/simplify_expressions.rs +++ b/datafusion/src/optimizer/simplify_expressions.rs @@ -24,8 +24,8 @@ use arrow::record_batch::RecordBatch; use crate::error::DataFusionError; use crate::execution::context::ExecutionProps; use crate::logical_plan::{ - lit, DFSchema, DFSchemaRef, Expr, ExprRewriter, LogicalPlan, RewriteRecursion, - SimplifyInfo, + lit, DFSchema, DFSchemaRef, Expr, ExprRewritable, ExprRewriter, ExprSimplifiable, + LogicalPlan, RewriteRecursion, SimplifyInfo, }; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; @@ -252,6 +252,7 @@ impl SimplifyExpressions { /// /// ``` /// # use datafusion::prelude::*; +/// # use datafusion::logical_plan::ExprRewritable; /// # use datafusion::optimizer::simplify_expressions::ConstEvaluator; /// # use datafusion::execution::context::ExecutionProps; /// diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index f7ab836b398c..41d1e4bca03b 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -22,9 +22,11 @@ use crate::execution::context::ExecutionProps; use crate::logical_plan::plan::{ Aggregate, Analyze, Extension, Filter, Join, Projection, Sort, Window, }; + use crate::logical_plan::{ - build_join_schema, Column, CreateMemoryTable, DFSchemaRef, Expr, Limit, LogicalPlan, - LogicalPlanBuilder, Operator, Partitioning, Recursion, Repartition, Union, Values, + build_join_schema, Column, CreateMemoryTable, DFSchemaRef, Expr, ExprVisitable, + Limit, LogicalPlan, LogicalPlanBuilder, Operator, Partitioning, Recursion, + Repartition, Union, Values, }; use crate::prelude::lit; use crate::scalar::ScalarValue; diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs index d0cef0f3d376..cbe40d6dc51d 100644 --- a/datafusion/src/sql/utils.rs +++ b/datafusion/src/sql/utils.rs @@ -20,6 +20,7 @@ use arrow::datatypes::DataType; use sqlparser::ast::Ident; +use crate::logical_plan::ExprVisitable; use crate::logical_plan::{Expr, LogicalPlan}; use crate::scalar::{ScalarValue, MAX_PRECISION_FOR_DECIMAL128}; use crate::{ diff --git a/datafusion/tests/simplification.rs b/datafusion/tests/simplification.rs index 5edf43f5ccb2..0ce8e7685b83 100644 --- a/datafusion/tests/simplification.rs +++ b/datafusion/tests/simplification.rs @@ -18,6 +18,7 @@ //! This program demonstrates the DataFusion expression simplification API. use arrow::datatypes::{DataType, Field, Schema}; +use datafusion::logical_plan::ExprSimplifiable; use datafusion::{ error::Result, execution::context::ExecutionProps,