Skip to content

Commit

Permalink
mid-cycle short-circuiting with static dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
erratic-pattern committed May 6, 2024
1 parent 3ddad40 commit 7330dfe
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 150 deletions.
5 changes: 2 additions & 3 deletions datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,7 @@ pub trait TreeNode: Sized {
/// TreeNodeRewriter::f_up(ChildNode2)
/// TreeNodeRewriter::f_up(ParentNode)
/// ```
#[inline]
fn rewrite<R: TreeNodeRewriter<Node = Self> + ?Sized>(
fn rewrite<R: TreeNodeRewriter<Node = Self>>(
self,
rewriter: &mut R,
) -> Result<Transformed<Self>> {
Expand Down Expand Up @@ -504,7 +503,7 @@ pub trait TreeNodeVisitor: Sized {
///
/// # See Also:
/// * [`TreeNode::visit`] to inspect borrowed `TreeNode`s
pub trait TreeNodeRewriter {
pub trait TreeNodeRewriter: Sized {
/// The node type which is rewritable.
type Node: TreeNode;

Expand Down
14 changes: 4 additions & 10 deletions datafusion/core/tests/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@ use datafusion_expr::{
};
use datafusion_functions::{math, string};
use datafusion_optimizer::optimizer::Optimizer;
use datafusion_optimizer::simplify_expressions::{
ExprSimplifier, RewriteCycleInfo, SimplifyExpressions,
};
use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyExpressions};
use datafusion_optimizer::{OptimizerContext, OptimizerRule};
use std::sync::Arc;

Expand Down Expand Up @@ -521,16 +519,12 @@ fn test_simplify_with_cycle_info(
execution_props: ExecutionProps::new(),
};
let simplifier = ExprSimplifier::new(info);
let (
simplified_expr,
RewriteCycleInfo {
completed_cycles,
total_iterations,
},
) = simplifier
let (simplified_expr, info) = simplifier
.simplify_with_cycle_info(input_expr.clone())
.expect("successfully evaluated");

let total_iterations = info.total_iterations();
let completed_cycles = info.completed_cycles();
assert_eq!(
simplified_expr, expected_expr,
"Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}"
Expand Down
240 changes: 103 additions & 137 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
pub fn simplify_with_cycle_info(
&self,
mut expr: Expr,
) -> Result<(Expr, RewriteCycleInfo)> {
) -> Result<(Expr, RewriteCycle)> {
let mut simplifier = Simplifier::new(&self.info);
let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?;
// let mut shorten_in_list_simplifier = ShortenInListSimplifier::new();
Expand All @@ -198,71 +198,13 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
if self.canonicalize {
expr = expr.rewrite(&mut Canonicalizer::new()).data()?
}

// let mut last_changed_count = 0;
// let mut total_iterations = 0;
// // we avoid using built-in iterators here for several reasons:
// // * iter_mut() cannot be cloned because of mutable ref restrictions, so we cannot use
// // cycle/repeat functions without cloning the underlying data or using an Arc/Rc
// // * cannot use fold() or try_fold() easily because we need to simultaneously handle errors
// // and short-circuiting control flow
// // * cannot use scan() because we need ownership of the state value in order to avoid
// // cloning in TreeNodeRewriter and OptimizerRule transformations
// for _ in 0..self.max_simplifier_cycles {
// let Transformed {
// data, transformed, ..
// } = expr.rewrite(&mut const_evaluator)?;
// if transformed {
// last_changed_count = 1;
// } else {
// last_changed_count += 1;
// }
// expr = data;
// total_iterations += 1;
// if last_changed_count >= 3 {
// break;
// }
// let Transformed {
// data, transformed, ..
// } = expr.rewrite(&mut simplifier)?;
// if transformed {
// last_changed_count = 1;
// } else {
// last_changed_count += 1;
// }
// expr = data;
// total_iterations += 1;
// if last_changed_count >= 3 {
// break;
// }
// let Transformed {
// data, transformed, ..
// } = expr.rewrite(&mut guarantee_rewriter)?;
// if transformed {
// last_changed_count = 1;
// } else {
// last_changed_count += 1;
// }
// expr = data;
// total_iterations += 1;
// if last_changed_count >= 3 {
// break;
// }
// }
// let info = RewriteCycleInfo {
// total_iterations,
// completed_cycles: total_iterations / 3,
// };
let (out, info) = rewrite_cycle(
&mut [
&mut const_evaluator,
&mut simplifier,
&mut guarantee_rewriter,
],
expr,
self.max_simplifier_cycles,
)?;
expr = out;
let (mut expr, info) =
rewrite_cycle(expr, self.max_simplifier_cycles, |cycle, mut expr| {
expr = cycle.rewrite(expr, &mut const_evaluator)?;
expr = cycle.rewrite(expr, &mut simplifier)?;
expr = cycle.rewrite(expr, &mut guarantee_rewriter)?;
ControlFlow::Continue(expr)
})?;
expr = expr.rewrite(&mut ShortenInListSimplifier::new()).data()?;
Ok((expr, info))
}
Expand Down Expand Up @@ -433,7 +375,7 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
/// // Expression has been rewritten to: (c = a AND b = 1)
/// assert_eq!(simplified_expr, lit(true));
/// // Only 1 cycle was executed
/// assert_eq!(info.completed_cycles, 1);
/// assert_eq!(info.completed_cycles(), 1);
///
/// ```
pub fn with_max_cycles(mut self, max_simplifier_cycles: usize) -> Self {
Expand Down Expand Up @@ -1792,70 +1734,94 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result<Expr> {
Ok(Expr::InList(l1))
}

fn rewrite_cycle(
rewriters: &mut [&mut dyn TreeNodeRewriter<Node = Expr>],
init: Expr,
max_cycles: usize,
) -> Result<(Expr, RewriteCycleInfo)> {
let num_rewriters = rewriters.len();
struct FoldState {
expr: Expr,
consecutive_unchanged_count: usize,
total_iterations: usize,
pub struct RewriteCycle {
consecutive_unchanged_count: usize,
total_iterations: usize,
num_rewriters: usize,
}
pub type RewriteCycleResult = ControlFlow<Result<Expr>, Expr>;

impl RewriteCycle {
fn new() -> Self {
RewriteCycle {
// use usize::MAX as default to avoid checking null in is_done() comparison
// real value is set later by record_num_rewriters
num_rewriters: usize::MAX,
consecutive_unchanged_count: 0,
total_iterations: 0,
}
}
let init = FoldState {
expr: init,
consecutive_unchanged_count: 0,
total_iterations: 0,
};
let result = (0..max_cycles).try_fold(init, move |state, _| {
rewriters
.iter_mut()
.try_fold(state, move |mut state, rewriter| {
match state.expr.rewrite(*rewriter) {
Err(e) => ControlFlow::Break(Err(e)),
Ok(Transformed {
data: expr,
transformed,
..
}) => {
state.expr = expr;
state.total_iterations += 1;
if transformed {
state.consecutive_unchanged_count = 0;
} else {
state.consecutive_unchanged_count += 1;
}
if state.consecutive_unchanged_count >= num_rewriters {
ControlFlow::Break(Ok(state))
} else {
ControlFlow::Continue(state)
}
}

pub fn completed_cycles(&self) -> usize {
// default value indicates we have not completed a cycle
if self.num_rewriters == usize::MAX {
0
} else {
self.total_iterations / self.num_rewriters
}
}

pub fn total_iterations(&self) -> usize {
self.total_iterations
}

pub(crate) fn record_num_rewriters(&mut self) {
self.num_rewriters = self.total_iterations;
}

pub(crate) fn is_done(&self) -> bool {
self.consecutive_unchanged_count >= self.num_rewriters
}

pub fn rewrite<R: TreeNodeRewriter<Node = Expr>>(
&mut self,
node: Expr,
rewriter: &mut R,
) -> RewriteCycleResult {
match node.rewrite(rewriter) {
Err(e) => ControlFlow::Break(Err(e)),
Ok(Transformed {
data: node,
transformed,
..
}) => {
self.total_iterations += 1;
if transformed {
self.consecutive_unchanged_count = 0;
} else {
self.consecutive_unchanged_count += 1;
}
})
});
let FoldState {
expr,
total_iterations,
..
} = match result {
ControlFlow::Break(result) => result?,
ControlFlow::Continue(state) => state,
};
Ok((
expr,
RewriteCycleInfo {
total_iterations,
completed_cycles: total_iterations / num_rewriters,
},
))
if self.is_done() {
ControlFlow::Break(Ok(node))
} else {
ControlFlow::Continue(node)
}
}
}
}
}

#[derive(Debug)]
pub struct RewriteCycleInfo {
pub total_iterations: usize,
pub completed_cycles: usize,
pub fn rewrite_cycle<F: FnMut(&mut RewriteCycle, Expr) -> RewriteCycleResult>(
node: Expr,
max_cycles: usize,
mut f: F,
) -> Result<(Expr, RewriteCycle)> {
let mut cycle = RewriteCycle::new();
// run first cycle then record number of rewriters
let node = match f(&mut cycle, node) {
ControlFlow::Break(result) => return result.map(|n| (n, cycle)),
ControlFlow::Continue(node) => node,
};
cycle.record_num_rewriters();
if cycle.is_done() {
return Ok((node, cycle));
}
// run remaining cycles
let node = match (1..max_cycles).try_fold(node, |node, _| f(&mut cycle, node)) {
ControlFlow::Break(result) => result?,
ControlFlow::Continue(node) => node,
};
Ok((node, cycle))
}

#[cfg(test)]
Expand Down Expand Up @@ -3055,7 +3021,7 @@ mod tests {
try_simplify(expr).unwrap()
}

fn try_simplify_with_cycle_info(expr: Expr) -> Result<(Expr, RewriteCycleInfo)> {
fn try_simplify_with_cycle_info(expr: Expr) -> Result<(Expr, RewriteCycle)> {
let schema = expr_test_schema();
let execution_props = ExecutionProps::new();
let simplifier = ExprSimplifier::new(
Expand All @@ -3064,7 +3030,7 @@ mod tests {
simplifier.simplify_with_cycle_info(expr)
}

fn simplify_with_cycle_info(expr: Expr) -> (Expr, RewriteCycleInfo) {
fn simplify_with_cycle_info(expr: Expr) -> (Expr, RewriteCycle) {
try_simplify_with_cycle_info(expr).unwrap()
}

Expand Down Expand Up @@ -3783,25 +3749,25 @@ mod tests {
let expected = lit(true);
let (expr, info) = simplify_with_cycle_info(expr);
assert_eq!(expr, expected);
assert_eq!(info.completed_cycles, 1);
assert_eq!(info.total_iterations, 3);
assert_eq!(info.completed_cycles(), 1);
assert_eq!(info.total_iterations(), 3);

// (true != NULL) OR (5 > 10)
let expr = lit(true).not_eq(lit_bool_null()).or(lit(5).gt(lit(10)));
let expected = lit_bool_null();
let (expr, info) = simplify_with_cycle_info(expr);
assert_eq!(expr, expected);
assert_eq!(info.completed_cycles, 1);
assert_eq!(info.total_iterations, 4);
assert_eq!(info.completed_cycles(), 1);
assert_eq!(info.total_iterations(), 4);

// NOTE: this currently does not simplify
// (((c4 - 10) + 10) *100) / 100
let expr = (((col("c4") - lit(10)) + lit(10)) * lit(100)) / lit(100);
let expected = expr.clone();
let (expr, info) = simplify_with_cycle_info(expr);
assert_eq!(expr, expected);
assert_eq!(info.completed_cycles, 1);
assert_eq!(info.total_iterations, 3);
assert_eq!(info.completed_cycles(), 1);
assert_eq!(info.total_iterations(), 3);

// ((c4<1 or c3<2) and c3_non_null<3) and false
let expr = col("c4")
Expand All @@ -3812,7 +3778,7 @@ mod tests {
let expected = lit(false);
let (expr, info) = simplify_with_cycle_info(expr);
assert_eq!(expr, expected);
assert_eq!(info.completed_cycles, 1);
assert_eq!(info.total_iterations, 5);
assert_eq!(info.completed_cycles(), 1);
assert_eq!(info.total_iterations(), 5);
}
}

0 comments on commit 7330dfe

Please sign in to comment.