diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 9d42f4fb1e0d4..0d0d7606e4257 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -172,7 +172,7 @@ pub trait TreeNode: Sized { /// TreeNodeRewriter::f_up(ChildNode2) /// TreeNodeRewriter::f_up(ParentNode) /// ``` - fn rewrite>( + fn rewrite + ?Sized>( self, rewriter: &mut R, ) -> Result> { @@ -503,7 +503,7 @@ pub trait TreeNodeVisitor: Sized { /// /// # See Also: /// * [`TreeNode::visit`] to inspect borrowed `TreeNode`s -pub trait TreeNodeRewriter: Sized { +pub trait TreeNodeRewriter { /// The node type which is rewritable. type Node: TreeNode; diff --git a/datafusion/core/tests/simplification.rs b/datafusion/core/tests/simplification.rs index bb41929834267..6f5a10c3c5dc5 100644 --- a/datafusion/core/tests/simplification.rs +++ b/datafusion/core/tests/simplification.rs @@ -508,27 +508,34 @@ fn test_simplify(input_expr: Expr, expected_expr: Expr) { "Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}" ); } -fn test_simplify_with_cycle_count( +fn test_simplify_with_cycle_info( input_expr: Expr, expected_expr: Expr, - expected_count: u32, + expected_cycle_count: usize, + expected_iteration_count: usize, ) { let info: MyInfo = MyInfo { schema: expr_test_schema(), execution_props: ExecutionProps::new(), }; let simplifier = ExprSimplifier::new(info); - let (simplified_expr, count) = simplifier - .simplify_with_cycle_count(input_expr.clone()) + let (simplified_expr, info) = simplifier + .simplify_with_cycle_info(input_expr.clone()) .expect("successfully evaluated"); + let total_iterations = info.total_iterations(); + let completed_cycles = info.completed_cycles(); assert_eq!( simplified_expr, expected_expr, "Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}" ); assert_eq!( - count, expected_count, - "Mismatch simplifier cycle count\n Expected: {expected_count}\n Got:{count}" + completed_cycles, expected_cycle_count, + "Mismatch simplifier cycle count\n Expected: {expected_cycle_count}\n Got:{completed_cycles}" + ); + assert_eq!( + total_iterations, expected_iteration_count, + "Mismatch simplifier cycle count\n Expected: {expected_iteration_count}\n Got:{total_iterations}" ); } @@ -687,5 +694,5 @@ fn test_simplify_cycles() { let expr = cast(now(), DataType::Int64) .lt(cast(to_timestamp(vec![lit(0)]), DataType::Int64) + lit(i64::MAX)); let expected = lit(true); - test_simplify_with_cycle_count(expr, expected, 3); + test_simplify_with_cycle_info(expr, expected, 2, 7); } diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 793c87f8bc0c7..2caf1bf2c2ecc 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -48,6 +48,7 @@ pub mod propagate_empty_relation; pub mod push_down_filter; pub mod push_down_limit; pub mod replace_distinct_aggregate; +pub mod rewrite_cycle; pub mod rewrite_disjunctive_predicate; pub mod scalar_subquery_to_join; pub mod simplify_expressions; diff --git a/datafusion/optimizer/src/rewrite_cycle.rs b/datafusion/optimizer/src/rewrite_cycle.rs new file mode 100644 index 0000000000000..45afe074bde2c --- /dev/null +++ b/datafusion/optimizer/src/rewrite_cycle.rs @@ -0,0 +1,303 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// An API for executing a sequence of [TreeNodeRewriter]s in multiple passes. +/// +/// See [RewriteCycle] for more information. +/// +use std::ops::ControlFlow; + +use datafusion_common::{ + tree_node::{Transformed, TreeNode, TreeNodeRewriter}, + Result, +}; + +/// A builder with methods for executing a "rewrite cycle". +/// +/// Often the results of one optimization rule can uncover more optimizations in other optimization +/// rules. A sequence of optimization rules can be ran in multiple "passes" until there are no +/// more optmizations to make. +/// +/// The [RewriteCycle] handles logic for running these multi-pass loops. +/// It applies a sequence of [TreeNodeRewriter]s to a [TreeNode] by calling +/// [TreeNode::rewrite] in a loop - passing the output of one rewrite as the input to the next +/// rewrite - until [RewriteCycle::max_cycles] is reached or until every [TreeNode::rewrite] +/// returns a [Transformed::no] result in a consecutive sequence. +/// +/// There are two methods to execute a rewrite cycle: +/// * Static dispatch API: [Self::each_cycle] +/// * Dynamic dispatch API: [Self::fold_rewrites] +#[derive(Debug)] +pub struct RewriteCycle { + max_cycles: usize, +} + +impl Default for RewriteCycle { + fn default() -> Self { + Self::new() + } +} + +impl RewriteCycle { + /// The default maximum number of completed cycles to run before terminating the rewrite loop. + /// You can override this default with [Self::with_max_cycles] + pub const DEFAULT_MAX_CYCLES: usize = 3; + + /// Creates a new [RewriteCycle] with default options. + pub fn new() -> Self { + Self { + max_cycles: Self::DEFAULT_MAX_CYCLES, + } + } + /// Sets the [Self::max_cycles] to run before terminating the rewrite loop. + pub fn with_max_cycles(mut self, max_cycles: usize) -> Self { + self.max_cycles = max_cycles; + self + } + + /// The maximum number of completed cycles to run before terminating the rewrite loop. + /// Defaults to [Self::DEFAULT_MAX_CYCLES]. + pub fn max_cycles(&self) -> usize { + self.max_cycles + } + + /// Runs a rewrite cycle on the given [TreeNode] using the given callback function to + /// explicitly handle the cycle iterations. + /// + /// This API allows for static dispatch at the cost of some extra code complexity. For a + /// simpler API that uses dynamic dispatch, use [Self::fold_rewrites] + /// + /// The callback function is given a [RewriteCycleState], which manages the short-circuiting + /// logic of the loop. The function is expected to call [RewriteCycleState::rewrite] for each + /// individual [TreeNodeRewriter] in the cycle. [RewriteCycleState::rewrite] returns a [RewriteCycleControlFlow] + /// result, indicating whether the loop should break or continue. + /// + /// ```rust + /// + /// use arrow::datatypes::{Schema, Field, DataType}; + /// use datafusion_expr::{col, lit}; + /// use datafusion_common::{DataFusionError, ToDFSchema}; + /// use datafusion_expr::execution_props::ExecutionProps; + /// use datafusion_expr::simplify::SimplifyContext; + /// use datafusion_optimizer::rewrite_cycle::RewriteCycle; + /// use datafusion_optimizer::simplify_expressions::{Simplifier, ConstEvaluator}; + /// + /// // Create the schema + /// let schema = Schema::new(vec![ + /// Field::new("c1", DataType::Int64, true), + /// Field::new("c2", DataType::UInt32, true), + /// ]).to_dfschema_ref().unwrap(); + /// + /// // Create the rewriters + /// let props = ExecutionProps::new(); + /// let context = SimplifyContext::new(&props) + /// .with_schema(schema); + /// let mut simplifier = Simplifier::new(&context); + /// let mut const_evaluator = ConstEvaluator::try_new(&props).unwrap(); + /// + /// // ((c4 < 1 or c3 < 2) and c3 < 3) and false + /// let expr = col("c2") + /// .lt(lit(1)) + /// .or(col("c1").lt(lit(2))) + /// .and(col("c1").lt(lit(3))) + /// .and(lit(false)); + /// + /// // run the rewrite cycle loop + /// let (expr, info) = RewriteCycle::new() + /// .with_max_cycles(4) + /// .each_cycle(expr, |cycle_state| { + /// cycle_state + /// .rewrite(&mut const_evaluator)? + /// .rewrite(&mut simplifier) + /// }).unwrap(); + /// assert_eq!(expr, lit(false)); + /// assert_eq!(info.completed_cycles(), 2); + /// assert_eq!(info.total_iterations(), 4); + /// ``` + /// + pub fn each_cycle< + Node: TreeNode, + F: FnMut( + RewriteCycleState, + ) -> RewriteCycleControlFlow>, + >( + &self, + node: Node, + mut f: F, + ) -> Result<(Node, RewriteCycleInfo)> { + let mut state = RewriteCycleState::new(node); + if self.max_cycles == 0 { + return state.finish(); + } + // run first cycle then record number of rewriters + state = match f(state) { + ControlFlow::Break(result) => return result?.finish(), + ControlFlow::Continue(node) => node, + }; + state.record_cycle_length(); + if state.is_done() { + return state.finish(); + } + // run remaining cycles + match (1..self.max_cycles).try_fold(state, |state, _| f(state)) { + ControlFlow::Break(result) => result?.finish(), + ControlFlow::Continue(state) => state.finish(), + } + } + + /// Runs a rewrite cycle on the given tree node by applying the given sequence of [TreeNodeRewriter] trait objects. + /// + /// [TreeNode::rewrite] will be called for each [TreeNodeRewriter], with the + /// output [TreeNode] of each call being passed as the input to the next. + /// + /// This process repeats until either [Self::max_cycles] is reached or until a consecutive + /// sequence of [Transformed::no] results is returned for each [TreeNodeRewriter], + /// indicating that there is no more work to be done. + /// + /// For an API that avoids dynamic dispatch, see [Self::each_cycle] + /// + pub fn fold_rewrites( + &self, + node: Node, + rewriters: &mut [Box>], + ) -> Result<(Node, RewriteCycleInfo)> { + let mut state = RewriteCycleState::new(node); + state.with_cycle_length(rewriters.len()); + match (0..self.max_cycles).try_fold(state, |state, _| { + rewriters + .iter_mut() + .try_fold(state, |state, rewriter| state.rewrite(rewriter.as_mut())) + }) { + ControlFlow::Break(result) => result?.finish(), + ControlFlow::Continue(state) => state.finish(), + } + } +} + +/// Iteration state of a rewrite cycle. See [RewriteCycle::each_cycle] for usage examples and information. +#[derive(Debug)] +pub struct RewriteCycleState { + node: Node, + consecutive_unchanged_count: usize, + rewrite_count: usize, + cycle_length: Option, +} + +impl RewriteCycleState { + fn new(node: Node) -> Self { + Self { + node, + cycle_length: None, + consecutive_unchanged_count: 0, + rewrite_count: 0, + } + } + + /// Explicitly specify the cycle length. Can be used to set the cycle length when the sequence of rewriters is given as + /// a slice, such as with [RewriteCycle::fold_rewrites] + fn with_cycle_length(&mut self, cycle_length: usize) { + self.cycle_length = Some(cycle_length); + } + + /// Records the rewrite cycle length based on the current iteration count + /// + /// When the total number of writers is not known upfront - such as when using + /// [RewriteCycle::each_cycle] we need to keep count of the number of [Self::rewrite] + /// calls and then record the number at the end of the first cycle. + fn record_cycle_length(&mut self) { + self.cycle_length = Some(self.rewrite_count); + } + + /// Returns true when the loop has reached the maximum cycle length or when we've received + /// consecutive unchanged tree nodes equal to the total number of rewriters. + fn is_done(&self) -> bool { + // default value indicates we have not completed a cycle + let Some(cycle_length) = self.cycle_length else { + return false; + }; + self.consecutive_unchanged_count >= cycle_length + } + + /// Finishes the iteration by consuming the state and returning a [TreeNode] and + /// [RewriteCycleInfo] + fn finish(self) -> Result<(Node, RewriteCycleInfo)> { + Ok(( + self.node, + RewriteCycleInfo { + cycle_length: self.cycle_length.unwrap_or(self.rewrite_count), + total_iterations: self.rewrite_count, + }, + )) + } + + /// Calls [TreeNode::rewrite] and determines if the rewrite cycle should break or continue + /// based on the current [RewriteCycleState]. + pub fn rewrite + ?Sized>( + mut self, + rewriter: &mut R, + ) -> RewriteCycleControlFlow { + match self.node.rewrite(rewriter) { + Err(e) => ControlFlow::Break(Err(e)), + Ok(Transformed { + data: node, + transformed, + .. + }) => { + self.node = node; + self.rewrite_count += 1; + if transformed { + self.consecutive_unchanged_count = 0; + } else { + self.consecutive_unchanged_count += 1; + } + if self.is_done() { + ControlFlow::Break(Ok(self)) + } else { + ControlFlow::Continue(self) + } + } + } + } +} + +/// Information about a rewrite cycle, such as total number of iterations and number of fully +/// completed cycles. This is useful for testing purposes to ensure that optimzation passes are +/// working as expected. +#[derive(Debug, Clone, Copy)] +pub struct RewriteCycleInfo { + total_iterations: usize, + cycle_length: usize, +} + +impl RewriteCycleInfo { + /// The total number of **fully completed** cycles. + pub fn completed_cycles(&self) -> usize { + self.total_iterations / self.cycle_length + } + + /// The total number of [TreeNode::rewrite] calls. + pub fn total_iterations(&self) -> usize { + self.total_iterations + } + + /// The number of [TreeNode::rewrite] calls within a single cycle. + pub fn cycle_length(&self) -> usize { + self.cycle_length + } +} + +pub type RewriteCycleControlFlow = ControlFlow, T>; diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 55052542a8bf9..3e51171420eb2 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -40,10 +40,13 @@ use datafusion_expr::{ use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; -use crate::analyzer::type_coercion::TypeCoercionRewriter; use crate::simplify_expressions::guarantees::GuaranteeRewriter; use crate::simplify_expressions::regex::simplify_regex_expr; use crate::simplify_expressions::SimplifyInfo; +use crate::{ + analyzer::type_coercion::TypeCoercionRewriter, + rewrite_cycle::{RewriteCycle, RewriteCycleInfo}, +}; use super::inlist_simplifier::ShortenInListSimplifier; use super::utils::*; @@ -92,11 +95,10 @@ pub struct ExprSimplifier { /// true canonicalize: bool, /// Maximum number of simplifier cycles - max_simplifier_cycles: u32, + max_simplifier_cycles: Option, } pub const THRESHOLD_INLINE_INLIST: usize = 3; -pub const DEFAULT_MAX_SIMPLIFIER_CYCLES: u32 = 3; impl ExprSimplifier { /// Create a new `ExprSimplifier` with the given `info` such as an @@ -109,7 +111,7 @@ impl ExprSimplifier { info, guarantees: vec![], canonicalize: true, - max_simplifier_cycles: DEFAULT_MAX_SIMPLIFIER_CYCLES, + max_simplifier_cycles: None, } } @@ -175,7 +177,7 @@ impl ExprSimplifier { /// assert_eq!(expr, b_lt_2); /// ``` pub fn simplify(&self, expr: Expr) -> Result { - Ok(self.simplify_with_cycle_count(expr)?.0) + Ok(self.simplify_with_cycle_info(expr)?.0) } /// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating @@ -185,36 +187,30 @@ impl ExprSimplifier { /// /// See [Self::simplify] for details and usage examples. /// - pub fn simplify_with_cycle_count(&self, mut expr: Expr) -> Result<(Expr, u32)> { + pub fn simplify_with_cycle_info( + &self, + mut expr: Expr, + ) -> Result<(Expr, RewriteCycleInfo)> { let mut simplifier = Simplifier::new(&self.info); let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?; - let mut shorten_in_list_simplifier = ShortenInListSimplifier::new(); + // let mut shorten_in_list_simplifier = ShortenInListSimplifier::new(); let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees); if self.canonicalize { expr = expr.rewrite(&mut Canonicalizer::new()).data()? } - - // Evaluating constants can enable new simplifications and - // simplifications can enable new constant evaluation - // see `Self::with_max_cycles` - let mut num_cycles = 0; - loop { - let Transformed { - data, transformed, .. - } = expr - .rewrite(&mut const_evaluator)? - .transform_data(|expr| expr.rewrite(&mut simplifier))? - .transform_data(|expr| expr.rewrite(&mut guarantee_rewriter))?; - expr = data; - num_cycles += 1; - if !transformed || num_cycles >= self.max_simplifier_cycles { - break; - } + let mut rewrite_cycle = RewriteCycle::new(); + if let Some(max_cycles) = self.max_simplifier_cycles { + rewrite_cycle = rewrite_cycle.with_max_cycles(max_cycles); } - // shorten inlist should be started after other inlist rules are applied - expr = expr.rewrite(&mut shorten_in_list_simplifier).data()?; - Ok((expr, num_cycles)) + let (mut expr, info) = rewrite_cycle.each_cycle(expr, |cycle_state| { + cycle_state + .rewrite(&mut const_evaluator)? + .rewrite(&mut simplifier)? + .rewrite(&mut guarantee_rewriter) + })?; + expr = expr.rewrite(&mut ShortenInListSimplifier::new()).data()?; + Ok((expr, info)) } /// Apply type coercion to an [`Expr`] so that it can be @@ -378,22 +374,16 @@ impl ExprSimplifier { /// // Expression: a IS NOT NULL /// let expr = col("a").is_not_null(); /// - /// // When using default maximum cycles, 2 cycles will be performed. - /// let (simplified_expr, count) = simplifier.simplify_with_cycle_count(expr.clone()).unwrap(); - /// assert_eq!(simplified_expr, lit(true)); - /// // 2 cycles were executed, but only 1 was needed - /// assert_eq!(count, 2); - /// /// // Only 1 simplification pass is necessary here, so we can set the maximum cycles to 1. - /// let (simplified_expr, count) = simplifier.with_max_cycles(1).simplify_with_cycle_count(expr.clone()).unwrap(); + /// let (simplified_expr, info) = simplifier.with_max_cycles(1).simplify_with_cycle_info(expr.clone()).unwrap(); /// // Expression has been rewritten to: (c = a AND b = 1) /// assert_eq!(simplified_expr, lit(true)); /// // Only 1 cycle was executed - /// assert_eq!(count, 1); + /// assert_eq!(info.completed_cycles(), 1); /// /// ``` - pub fn with_max_cycles(mut self, max_simplifier_cycles: u32) -> Self { - self.max_simplifier_cycles = max_simplifier_cycles; + pub fn with_max_cycles(mut self, max_simplifier_cycles: usize) -> Self { + self.max_simplifier_cycles = Some(max_simplifier_cycles); self } } @@ -447,12 +437,11 @@ impl TreeNodeRewriter for Canonicalizer { } } -#[allow(rustdoc::private_intra_doc_links)] /// Partially evaluate `Expr`s so constant subtrees are evaluated at plan time. /// /// Note it does not handle algebraic rewrites such as `(a or false)` /// --> `a`, which is handled by [`Simplifier`] -struct ConstEvaluator<'a> { +pub struct ConstEvaluator<'a> { /// `can_evaluate` is used during the depth-first-search of the /// `Expr` tree to track if any siblings (or their descendants) were /// non evaluatable (e.g. had a column reference or volatile @@ -472,9 +461,8 @@ struct ConstEvaluator<'a> { input_batch: RecordBatch, } -#[allow(dead_code)] /// The simplify result of ConstEvaluator -enum ConstSimplifyResult { +pub enum ConstSimplifyResult { // Expr was simplifed and contains the new expression Simplified(ScalarValue), // Expr was not simplified and original value is returned @@ -673,7 +661,7 @@ impl<'a> ConstEvaluator<'a> { /// * `false = true` and `true = false` to `false` /// * `!!expr` to `expr` /// * `expr = null` and `expr != null` to `null` -struct Simplifier<'a, S> { +pub struct Simplifier<'a, S> { info: &'a S, } @@ -2954,17 +2942,17 @@ mod tests { try_simplify(expr).unwrap() } - fn try_simplify_with_cycle_count(expr: Expr) -> Result<(Expr, u32)> { + fn try_simplify_with_cycle_info(expr: Expr) -> Result<(Expr, RewriteCycleInfo)> { let schema = expr_test_schema(); let execution_props = ExecutionProps::new(); let simplifier = ExprSimplifier::new( SimplifyContext::new(&execution_props).with_schema(schema), ); - simplifier.simplify_with_cycle_count(expr) + simplifier.simplify_with_cycle_info(expr) } - fn simplify_with_cycle_count(expr: Expr) -> (Expr, u32) { - try_simplify_with_cycle_count(expr).unwrap() + fn simplify_with_cycle_info(expr: Expr) -> (Expr, RewriteCycleInfo) { + try_simplify_with_cycle_info(expr).unwrap() } fn simplify_with_guarantee( @@ -3680,24 +3668,27 @@ mod tests { // TRUE let expr = lit(true); let expected = lit(true); - let (expr, num_iter) = simplify_with_cycle_count(expr); + let (expr, info) = simplify_with_cycle_info(expr); assert_eq!(expr, expected); - assert_eq!(num_iter, 1); + assert_eq!(info.completed_cycles(), 1); + assert_eq!(info.total_iterations(), 3); // (true != NULL) OR (5 > 10) let expr = lit(true).not_eq(lit_bool_null()).or(lit(5).gt(lit(10))); let expected = lit_bool_null(); - let (expr, num_iter) = simplify_with_cycle_count(expr); + let (expr, info) = simplify_with_cycle_info(expr); assert_eq!(expr, expected); - assert_eq!(num_iter, 2); + assert_eq!(info.completed_cycles(), 1); + assert_eq!(info.total_iterations(), 4); // NOTE: this currently does not simplify // (((c4 - 10) + 10) *100) / 100 let expr = (((col("c4") - lit(10)) + lit(10)) * lit(100)) / lit(100); let expected = expr.clone(); - let (expr, num_iter) = simplify_with_cycle_count(expr); + let (expr, info) = simplify_with_cycle_info(expr); assert_eq!(expr, expected); - assert_eq!(num_iter, 1); + assert_eq!(info.completed_cycles(), 1); + assert_eq!(info.total_iterations(), 3); // ((c4<1 or c3<2) and c3_non_null<3) and false let expr = col("c4") @@ -3706,10 +3697,12 @@ mod tests { .and(col("c3_non_null").lt(lit(3))) .and(lit(false)); let expected = lit(false); - let (expr, num_iter) = simplify_with_cycle_count(expr); + let (expr, info) = simplify_with_cycle_info(expr); assert_eq!(expr, expected); - assert_eq!(num_iter, 2); + assert_eq!(info.completed_cycles(), 1); + assert_eq!(info.total_iterations(), 5); } + #[test] fn test_simplify_udaf() { let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_with_simplify());