diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index a605eb09a924d..9d42f4fb1e0d4 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -172,8 +172,7 @@ pub trait TreeNode: Sized { /// TreeNodeRewriter::f_up(ChildNode2) /// TreeNodeRewriter::f_up(ParentNode) /// ``` - #[inline] - fn rewrite + ?Sized>( + fn rewrite>( self, rewriter: &mut R, ) -> Result> { @@ -504,7 +503,7 @@ pub trait TreeNodeVisitor: Sized { /// /// # See Also: /// * [`TreeNode::visit`] to inspect borrowed `TreeNode`s -pub trait TreeNodeRewriter { +pub trait TreeNodeRewriter: Sized { /// The node type which is rewritable. type Node: TreeNode; diff --git a/datafusion/core/tests/simplification.rs b/datafusion/core/tests/simplification.rs index dc128a041e32f..6f5a10c3c5dc5 100644 --- a/datafusion/core/tests/simplification.rs +++ b/datafusion/core/tests/simplification.rs @@ -33,9 +33,7 @@ use datafusion_expr::{ }; use datafusion_functions::{math, string}; use datafusion_optimizer::optimizer::Optimizer; -use datafusion_optimizer::simplify_expressions::{ - ExprSimplifier, RewriteCycleInfo, SimplifyExpressions, -}; +use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyExpressions}; use datafusion_optimizer::{OptimizerContext, OptimizerRule}; use std::sync::Arc; @@ -521,16 +519,12 @@ fn test_simplify_with_cycle_info( execution_props: ExecutionProps::new(), }; let simplifier = ExprSimplifier::new(info); - let ( - simplified_expr, - RewriteCycleInfo { - completed_cycles, - total_iterations, - }, - ) = simplifier + let (simplified_expr, info) = simplifier .simplify_with_cycle_info(input_expr.clone()) .expect("successfully evaluated"); + let total_iterations = info.total_iterations(); + let completed_cycles = info.completed_cycles(); assert_eq!( simplified_expr, expected_expr, "Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}" diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 6b71a9c9291ba..3e7255fde9b08 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -189,7 +189,7 @@ impl ExprSimplifier { pub fn simplify_with_cycle_info( &self, mut expr: Expr, - ) -> Result<(Expr, RewriteCycleInfo)> { + ) -> Result<(Expr, RewriteCycle)> { let mut simplifier = Simplifier::new(&self.info); let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?; // let mut shorten_in_list_simplifier = ShortenInListSimplifier::new(); @@ -198,71 +198,13 @@ impl ExprSimplifier { if self.canonicalize { expr = expr.rewrite(&mut Canonicalizer::new()).data()? } - - // let mut last_changed_count = 0; - // let mut total_iterations = 0; - // // we avoid using built-in iterators here for several reasons: - // // * iter_mut() cannot be cloned because of mutable ref restrictions, so we cannot use - // // cycle/repeat functions without cloning the underlying data or using an Arc/Rc - // // * cannot use fold() or try_fold() easily because we need to simultaneously handle errors - // // and short-circuiting control flow - // // * cannot use scan() because we need ownership of the state value in order to avoid - // // cloning in TreeNodeRewriter and OptimizerRule transformations - // for _ in 0..self.max_simplifier_cycles { - // let Transformed { - // data, transformed, .. - // } = expr.rewrite(&mut const_evaluator)?; - // if transformed { - // last_changed_count = 1; - // } else { - // last_changed_count += 1; - // } - // expr = data; - // total_iterations += 1; - // if last_changed_count >= 3 { - // break; - // } - // let Transformed { - // data, transformed, .. - // } = expr.rewrite(&mut simplifier)?; - // if transformed { - // last_changed_count = 1; - // } else { - // last_changed_count += 1; - // } - // expr = data; - // total_iterations += 1; - // if last_changed_count >= 3 { - // break; - // } - // let Transformed { - // data, transformed, .. - // } = expr.rewrite(&mut guarantee_rewriter)?; - // if transformed { - // last_changed_count = 1; - // } else { - // last_changed_count += 1; - // } - // expr = data; - // total_iterations += 1; - // if last_changed_count >= 3 { - // break; - // } - // } - // let info = RewriteCycleInfo { - // total_iterations, - // completed_cycles: total_iterations / 3, - // }; - let (out, info) = rewrite_cycle( - &mut [ - &mut const_evaluator, - &mut simplifier, - &mut guarantee_rewriter, - ], - expr, - self.max_simplifier_cycles, - )?; - expr = out; + let (mut expr, info) = + rewrite_cycle(expr, self.max_simplifier_cycles, |cycle, mut expr| { + expr = cycle.rewrite(expr, &mut const_evaluator)?; + expr = cycle.rewrite(expr, &mut simplifier)?; + expr = cycle.rewrite(expr, &mut guarantee_rewriter)?; + ControlFlow::Continue(expr) + })?; expr = expr.rewrite(&mut ShortenInListSimplifier::new()).data()?; Ok((expr, info)) } @@ -433,7 +375,7 @@ impl ExprSimplifier { /// // Expression has been rewritten to: (c = a AND b = 1) /// assert_eq!(simplified_expr, lit(true)); /// // Only 1 cycle was executed - /// assert_eq!(info.completed_cycles, 1); + /// assert_eq!(info.completed_cycles(), 1); /// /// ``` pub fn with_max_cycles(mut self, max_simplifier_cycles: usize) -> Self { @@ -1792,70 +1734,94 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result { Ok(Expr::InList(l1)) } -fn rewrite_cycle( - rewriters: &mut [&mut dyn TreeNodeRewriter], - init: Expr, - max_cycles: usize, -) -> Result<(Expr, RewriteCycleInfo)> { - let num_rewriters = rewriters.len(); - struct FoldState { - expr: Expr, - consecutive_unchanged_count: usize, - total_iterations: usize, +pub struct RewriteCycle { + consecutive_unchanged_count: usize, + total_iterations: usize, + num_rewriters: usize, +} +pub type RewriteCycleResult = ControlFlow, Expr>; + +impl RewriteCycle { + fn new() -> Self { + RewriteCycle { + // use usize::MAX as default to avoid checking null in is_done() comparison + // real value is set later by record_num_rewriters + num_rewriters: usize::MAX, + consecutive_unchanged_count: 0, + total_iterations: 0, + } } - let init = FoldState { - expr: init, - consecutive_unchanged_count: 0, - total_iterations: 0, - }; - let result = (0..max_cycles).try_fold(init, move |state, _| { - rewriters - .iter_mut() - .try_fold(state, move |mut state, rewriter| { - match state.expr.rewrite(*rewriter) { - Err(e) => ControlFlow::Break(Err(e)), - Ok(Transformed { - data: expr, - transformed, - .. - }) => { - state.expr = expr; - state.total_iterations += 1; - if transformed { - state.consecutive_unchanged_count = 0; - } else { - state.consecutive_unchanged_count += 1; - } - if state.consecutive_unchanged_count >= num_rewriters { - ControlFlow::Break(Ok(state)) - } else { - ControlFlow::Continue(state) - } - } + + pub fn completed_cycles(&self) -> usize { + // default value indicates we have not completed a cycle + if self.num_rewriters == usize::MAX { + 0 + } else { + self.total_iterations / self.num_rewriters + } + } + + pub fn total_iterations(&self) -> usize { + self.total_iterations + } + + pub(crate) fn record_num_rewriters(&mut self) { + self.num_rewriters = self.total_iterations; + } + + pub(crate) fn is_done(&self) -> bool { + self.consecutive_unchanged_count >= self.num_rewriters + } + + pub fn rewrite>( + &mut self, + node: Expr, + rewriter: &mut R, + ) -> RewriteCycleResult { + match node.rewrite(rewriter) { + Err(e) => ControlFlow::Break(Err(e)), + Ok(Transformed { + data: node, + transformed, + .. + }) => { + self.total_iterations += 1; + if transformed { + self.consecutive_unchanged_count = 0; + } else { + self.consecutive_unchanged_count += 1; } - }) - }); - let FoldState { - expr, - total_iterations, - .. - } = match result { - ControlFlow::Break(result) => result?, - ControlFlow::Continue(state) => state, - }; - Ok(( - expr, - RewriteCycleInfo { - total_iterations, - completed_cycles: total_iterations / num_rewriters, - }, - )) + if self.is_done() { + ControlFlow::Break(Ok(node)) + } else { + ControlFlow::Continue(node) + } + } + } + } } -#[derive(Debug)] -pub struct RewriteCycleInfo { - pub total_iterations: usize, - pub completed_cycles: usize, +pub fn rewrite_cycle RewriteCycleResult>( + node: Expr, + max_cycles: usize, + mut f: F, +) -> Result<(Expr, RewriteCycle)> { + let mut cycle = RewriteCycle::new(); + // run first cycle then record number of rewriters + let node = match f(&mut cycle, node) { + ControlFlow::Break(result) => return result.map(|n| (n, cycle)), + ControlFlow::Continue(node) => node, + }; + cycle.record_num_rewriters(); + if cycle.is_done() { + return Ok((node, cycle)); + } + // run remaining cycles + let node = match (1..max_cycles).try_fold(node, |node, _| f(&mut cycle, node)) { + ControlFlow::Break(result) => result?, + ControlFlow::Continue(node) => node, + }; + Ok((node, cycle)) } #[cfg(test)] @@ -3055,7 +3021,7 @@ mod tests { try_simplify(expr).unwrap() } - fn try_simplify_with_cycle_info(expr: Expr) -> Result<(Expr, RewriteCycleInfo)> { + fn try_simplify_with_cycle_info(expr: Expr) -> Result<(Expr, RewriteCycle)> { let schema = expr_test_schema(); let execution_props = ExecutionProps::new(); let simplifier = ExprSimplifier::new( @@ -3064,7 +3030,7 @@ mod tests { simplifier.simplify_with_cycle_info(expr) } - fn simplify_with_cycle_info(expr: Expr) -> (Expr, RewriteCycleInfo) { + fn simplify_with_cycle_info(expr: Expr) -> (Expr, RewriteCycle) { try_simplify_with_cycle_info(expr).unwrap() } @@ -3783,16 +3749,16 @@ mod tests { let expected = lit(true); let (expr, info) = simplify_with_cycle_info(expr); assert_eq!(expr, expected); - assert_eq!(info.completed_cycles, 1); - assert_eq!(info.total_iterations, 3); + assert_eq!(info.completed_cycles(), 1); + assert_eq!(info.total_iterations(), 3); // (true != NULL) OR (5 > 10) let expr = lit(true).not_eq(lit_bool_null()).or(lit(5).gt(lit(10))); let expected = lit_bool_null(); let (expr, info) = simplify_with_cycle_info(expr); assert_eq!(expr, expected); - assert_eq!(info.completed_cycles, 1); - assert_eq!(info.total_iterations, 4); + assert_eq!(info.completed_cycles(), 1); + assert_eq!(info.total_iterations(), 4); // NOTE: this currently does not simplify // (((c4 - 10) + 10) *100) / 100 @@ -3800,8 +3766,8 @@ mod tests { let expected = expr.clone(); let (expr, info) = simplify_with_cycle_info(expr); assert_eq!(expr, expected); - assert_eq!(info.completed_cycles, 1); - assert_eq!(info.total_iterations, 3); + assert_eq!(info.completed_cycles(), 1); + assert_eq!(info.total_iterations(), 3); // ((c4<1 or c3<2) and c3_non_null<3) and false let expr = col("c4") @@ -3812,7 +3778,7 @@ mod tests { let expected = lit(false); let (expr, info) = simplify_with_cycle_info(expr); assert_eq!(expr, expected); - assert_eq!(info.completed_cycles, 1); - assert_eq!(info.total_iterations, 5); + assert_eq!(info.completed_cycles(), 1); + assert_eq!(info.total_iterations(), 5); } }