Skip to content

Commit

Permalink
rewrite simplifier loop to use iterators
Browse files Browse the repository at this point in the history
  • Loading branch information
erratic-pattern committed May 5, 2024
1 parent 932c0a6 commit 3ddad40
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 54 deletions.
5 changes: 3 additions & 2 deletions datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ pub trait TreeNode: Sized {
/// TreeNodeRewriter::f_up(ChildNode2)
/// TreeNodeRewriter::f_up(ParentNode)
/// ```
fn rewrite<R: TreeNodeRewriter<Node = Self>>(
#[inline]
fn rewrite<R: TreeNodeRewriter<Node = Self> + ?Sized>(
self,
rewriter: &mut R,
) -> Result<Transformed<Self>> {
Expand Down Expand Up @@ -503,7 +504,7 @@ pub trait TreeNodeVisitor: Sized {
///
/// # See Also:
/// * [`TreeNode::visit`] to inspect borrowed `TreeNode`s
pub trait TreeNodeRewriter: Sized {
pub trait TreeNodeRewriter {
/// The node type which is rewritable.
type Node: TreeNode;

Expand Down
29 changes: 21 additions & 8 deletions datafusion/core/tests/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ use datafusion_expr::{
};
use datafusion_functions::{math, string};
use datafusion_optimizer::optimizer::Optimizer;
use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyExpressions};
use datafusion_optimizer::simplify_expressions::{
ExprSimplifier, RewriteCycleInfo, SimplifyExpressions,
};
use datafusion_optimizer::{OptimizerContext, OptimizerRule};
use std::sync::Arc;

Expand Down Expand Up @@ -508,27 +510,38 @@ fn test_simplify(input_expr: Expr, expected_expr: Expr) {
"Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}"
);
}
fn test_simplify_with_cycle_count(
fn test_simplify_with_cycle_info(
input_expr: Expr,
expected_expr: Expr,
expected_count: u32,
expected_cycle_count: usize,
expected_iteration_count: usize,
) {
let info: MyInfo = MyInfo {
schema: expr_test_schema(),
execution_props: ExecutionProps::new(),
};
let simplifier = ExprSimplifier::new(info);
let (simplified_expr, count) = simplifier
.simplify_with_cycle_count(input_expr.clone())
let (
simplified_expr,
RewriteCycleInfo {
completed_cycles,
total_iterations,
},
) = simplifier
.simplify_with_cycle_info(input_expr.clone())
.expect("successfully evaluated");

assert_eq!(
simplified_expr, expected_expr,
"Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}"
);
assert_eq!(
count, expected_count,
"Mismatch simplifier cycle count\n Expected: {expected_count}\n Got:{count}"
completed_cycles, expected_cycle_count,
"Mismatch simplifier cycle count\n Expected: {expected_cycle_count}\n Got:{completed_cycles}"
);
assert_eq!(
total_iterations, expected_iteration_count,
"Mismatch simplifier cycle count\n Expected: {expected_iteration_count}\n Got:{total_iterations}"
);
}

Expand Down Expand Up @@ -687,5 +700,5 @@ fn test_simplify_cycles() {
let expr = cast(now(), DataType::Int64)
.lt(cast(to_timestamp(vec![lit(0)]), DataType::Int64) + lit(i64::MAX));
let expected = lit(true);
test_simplify_with_cycle_count(expr, expected, 3);
test_simplify_with_cycle_info(expr, expected, 2, 7);
}
204 changes: 160 additions & 44 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

use std::borrow::Cow;
use std::collections::HashSet;
use std::ops::Not;
use std::ops::{ControlFlow, Not};

use arrow::{
array::{new_null_array, AsArray},
Expand Down Expand Up @@ -93,11 +93,11 @@ pub struct ExprSimplifier<S> {
/// true
canonicalize: bool,
/// Maximum number of simplifier cycles
max_simplifier_cycles: u32,
max_simplifier_cycles: usize,
}

pub const THRESHOLD_INLINE_INLIST: usize = 3;
pub const DEFAULT_MAX_SIMPLIFIER_CYCLES: u32 = 3;
pub const DEFAULT_MAX_SIMPLIFIER_CYCLES: usize = 3;

impl<S: SimplifyInfo> ExprSimplifier<S> {
/// Create a new `ExprSimplifier` with the given `info` such as an
Expand Down Expand Up @@ -176,7 +176,7 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
/// assert_eq!(expr, b_lt_2);
/// ```
pub fn simplify(&self, expr: Expr) -> Result<Expr> {
Ok(self.simplify_with_cycle_count(expr)?.0)
Ok(self.simplify_with_cycle_info(expr)?.0)
}

/// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating
Expand All @@ -186,33 +186,85 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
///
/// See [Self::simplify] for details and usage examples.
///
pub fn simplify_with_cycle_count(&self, mut expr: Expr) -> Result<(Expr, u32)> {
pub fn simplify_with_cycle_info(
&self,
mut expr: Expr,
) -> Result<(Expr, RewriteCycleInfo)> {
let mut simplifier = Simplifier::new(&self.info);
let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?;
let mut shorten_in_list_simplifier = ShortenInListSimplifier::new();
// let mut shorten_in_list_simplifier = ShortenInListSimplifier::new();
let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees);

if self.canonicalize {
expr = expr.rewrite(&mut Canonicalizer::new()).data()?
}

let mut num_cycles = 0;
loop {
let Transformed {
data, transformed, ..
} = expr
.rewrite(&mut const_evaluator)?
.transform_data(|expr| expr.rewrite(&mut simplifier))?
.transform_data(|expr| expr.rewrite(&mut guarantee_rewriter))?;
expr = data;
num_cycles += 1;
if !transformed || num_cycles >= self.max_simplifier_cycles {
break;
}
}
// shorten inlist should be started after other inlist rules are applied
expr = expr.rewrite(&mut shorten_in_list_simplifier).data()?;
Ok((expr, num_cycles))
// let mut last_changed_count = 0;
// let mut total_iterations = 0;
// // we avoid using built-in iterators here for several reasons:
// // * iter_mut() cannot be cloned because of mutable ref restrictions, so we cannot use
// // cycle/repeat functions without cloning the underlying data or using an Arc/Rc
// // * cannot use fold() or try_fold() easily because we need to simultaneously handle errors
// // and short-circuiting control flow
// // * cannot use scan() because we need ownership of the state value in order to avoid
// // cloning in TreeNodeRewriter and OptimizerRule transformations
// for _ in 0..self.max_simplifier_cycles {
// let Transformed {
// data, transformed, ..
// } = expr.rewrite(&mut const_evaluator)?;
// if transformed {
// last_changed_count = 1;
// } else {
// last_changed_count += 1;
// }
// expr = data;
// total_iterations += 1;
// if last_changed_count >= 3 {
// break;
// }
// let Transformed {
// data, transformed, ..
// } = expr.rewrite(&mut simplifier)?;
// if transformed {
// last_changed_count = 1;
// } else {
// last_changed_count += 1;
// }
// expr = data;
// total_iterations += 1;
// if last_changed_count >= 3 {
// break;
// }
// let Transformed {
// data, transformed, ..
// } = expr.rewrite(&mut guarantee_rewriter)?;
// if transformed {
// last_changed_count = 1;
// } else {
// last_changed_count += 1;
// }
// expr = data;
// total_iterations += 1;
// if last_changed_count >= 3 {
// break;
// }
// }
// let info = RewriteCycleInfo {
// total_iterations,
// completed_cycles: total_iterations / 3,
// };
let (out, info) = rewrite_cycle(
&mut [
&mut const_evaluator,
&mut simplifier,
&mut guarantee_rewriter,
],
expr,
self.max_simplifier_cycles,
)?;
expr = out;
expr = expr.rewrite(&mut ShortenInListSimplifier::new()).data()?;
Ok((expr, info))
}

/// Apply type coercion to an [`Expr`] so that it can be
Expand Down Expand Up @@ -376,21 +428,15 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
/// // Expression: a IS NOT NULL
/// let expr = col("a").is_not_null();
///
/// // When using default maximum cycles, 2 cycles will be performed.
/// let (simplified_expr, count) = simplifier.simplify_with_cycle_count(expr.clone()).unwrap();
/// assert_eq!(simplified_expr, lit(true));
/// // 2 cycles were executed, but only 1 was needed
/// assert_eq!(count, 2);
///
/// // Only 1 simplification pass is necessary here, so we can set the maximum cycles to 1.
/// let (simplified_expr, count) = simplifier.with_max_cycles(1).simplify_with_cycle_count(expr.clone()).unwrap();
/// let (simplified_expr, info) = simplifier.with_max_cycles(1).simplify_with_cycle_info(expr.clone()).unwrap();
/// // Expression has been rewritten to: (c = a AND b = 1)
/// assert_eq!(simplified_expr, lit(true));
/// // Only 1 cycle was executed
/// assert_eq!(count, 1);
/// assert_eq!(info.completed_cycles, 1);
///
/// ```
pub fn with_max_cycles(mut self, max_simplifier_cycles: u32) -> Self {
pub fn with_max_cycles(mut self, max_simplifier_cycles: usize) -> Self {
self.max_simplifier_cycles = max_simplifier_cycles;
self
}
Expand Down Expand Up @@ -1746,6 +1792,72 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result<Expr> {
Ok(Expr::InList(l1))
}

fn rewrite_cycle(
rewriters: &mut [&mut dyn TreeNodeRewriter<Node = Expr>],
init: Expr,
max_cycles: usize,
) -> Result<(Expr, RewriteCycleInfo)> {
let num_rewriters = rewriters.len();
struct FoldState {
expr: Expr,
consecutive_unchanged_count: usize,
total_iterations: usize,
}
let init = FoldState {
expr: init,
consecutive_unchanged_count: 0,
total_iterations: 0,
};
let result = (0..max_cycles).try_fold(init, move |state, _| {
rewriters
.iter_mut()
.try_fold(state, move |mut state, rewriter| {
match state.expr.rewrite(*rewriter) {
Err(e) => ControlFlow::Break(Err(e)),
Ok(Transformed {
data: expr,
transformed,
..
}) => {
state.expr = expr;
state.total_iterations += 1;
if transformed {
state.consecutive_unchanged_count = 0;
} else {
state.consecutive_unchanged_count += 1;
}
if state.consecutive_unchanged_count >= num_rewriters {
ControlFlow::Break(Ok(state))
} else {
ControlFlow::Continue(state)
}
}
}
})
});
let FoldState {
expr,
total_iterations,
..
} = match result {
ControlFlow::Break(result) => result?,
ControlFlow::Continue(state) => state,
};
Ok((
expr,
RewriteCycleInfo {
total_iterations,
completed_cycles: total_iterations / num_rewriters,
},
))
}

#[derive(Debug)]
pub struct RewriteCycleInfo {
pub total_iterations: usize,
pub completed_cycles: usize,
}

#[cfg(test)]
mod tests {
use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
Expand Down Expand Up @@ -2943,17 +3055,17 @@ mod tests {
try_simplify(expr).unwrap()
}

fn try_simplify_with_cycle_count(expr: Expr) -> Result<(Expr, u32)> {
fn try_simplify_with_cycle_info(expr: Expr) -> Result<(Expr, RewriteCycleInfo)> {
let schema = expr_test_schema();
let execution_props = ExecutionProps::new();
let simplifier = ExprSimplifier::new(
SimplifyContext::new(&execution_props).with_schema(schema),
);
simplifier.simplify_with_cycle_count(expr)
simplifier.simplify_with_cycle_info(expr)
}

fn simplify_with_cycle_count(expr: Expr) -> (Expr, u32) {
try_simplify_with_cycle_count(expr).unwrap()
fn simplify_with_cycle_info(expr: Expr) -> (Expr, RewriteCycleInfo) {
try_simplify_with_cycle_info(expr).unwrap()
}

fn simplify_with_guarantee(
Expand Down Expand Up @@ -3669,24 +3781,27 @@ mod tests {
// TRUE
let expr = lit(true);
let expected = lit(true);
let (expr, num_iter) = simplify_with_cycle_count(expr);
let (expr, info) = simplify_with_cycle_info(expr);
assert_eq!(expr, expected);
assert_eq!(num_iter, 1);
assert_eq!(info.completed_cycles, 1);
assert_eq!(info.total_iterations, 3);

// (true != NULL) OR (5 > 10)
let expr = lit(true).not_eq(lit_bool_null()).or(lit(5).gt(lit(10)));
let expected = lit_bool_null();
let (expr, num_iter) = simplify_with_cycle_count(expr);
let (expr, info) = simplify_with_cycle_info(expr);
assert_eq!(expr, expected);
assert_eq!(num_iter, 2);
assert_eq!(info.completed_cycles, 1);
assert_eq!(info.total_iterations, 4);

// NOTE: this currently does not simplify
// (((c4 - 10) + 10) *100) / 100
let expr = (((col("c4") - lit(10)) + lit(10)) * lit(100)) / lit(100);
let expected = expr.clone();
let (expr, num_iter) = simplify_with_cycle_count(expr);
let (expr, info) = simplify_with_cycle_info(expr);
assert_eq!(expr, expected);
assert_eq!(num_iter, 1);
assert_eq!(info.completed_cycles, 1);
assert_eq!(info.total_iterations, 3);

// ((c4<1 or c3<2) and c3_non_null<3) and false
let expr = col("c4")
Expand All @@ -3695,8 +3810,9 @@ mod tests {
.and(col("c3_non_null").lt(lit(3)))
.and(lit(false));
let expected = lit(false);
let (expr, num_iter) = simplify_with_cycle_count(expr);
let (expr, info) = simplify_with_cycle_info(expr);
assert_eq!(expr, expected);
assert_eq!(num_iter, 2);
assert_eq!(info.completed_cycles, 1);
assert_eq!(info.total_iterations, 5);
}
}

0 comments on commit 3ddad40

Please sign in to comment.