Skip to content

Commit

Permalink
feat: RewriteCycle API for short-circuiting optimizer loops
Browse files Browse the repository at this point in the history
  • Loading branch information
erratic-pattern committed May 15, 2024
1 parent e859426 commit 83de911
Show file tree
Hide file tree
Showing 5 changed files with 367 additions and 63 deletions.
4 changes: 2 additions & 2 deletions datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ pub trait TreeNode: Sized {
/// TreeNodeRewriter::f_up(ChildNode2)
/// TreeNodeRewriter::f_up(ParentNode)
/// ```
fn rewrite<R: TreeNodeRewriter<Node = Self>>(
fn rewrite<R: TreeNodeRewriter<Node = Self> + ?Sized>(
self,
rewriter: &mut R,
) -> Result<Transformed<Self>> {
Expand Down Expand Up @@ -503,7 +503,7 @@ pub trait TreeNodeVisitor: Sized {
///
/// # See Also:
/// * [`TreeNode::visit`] to inspect borrowed `TreeNode`s
pub trait TreeNodeRewriter: Sized {
pub trait TreeNodeRewriter {
/// The node type which is rewritable.
type Node: TreeNode;

Expand Down
21 changes: 14 additions & 7 deletions datafusion/core/tests/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -508,27 +508,34 @@ fn test_simplify(input_expr: Expr, expected_expr: Expr) {
"Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}"
);
}
fn test_simplify_with_cycle_count(
fn test_simplify_with_cycle_info(
input_expr: Expr,
expected_expr: Expr,
expected_count: u32,
expected_cycle_count: usize,
expected_iteration_count: usize,
) {
let info: MyInfo = MyInfo {
schema: expr_test_schema(),
execution_props: ExecutionProps::new(),
};
let simplifier = ExprSimplifier::new(info);
let (simplified_expr, count) = simplifier
.simplify_with_cycle_count(input_expr.clone())
let (simplified_expr, info) = simplifier
.simplify_with_cycle_info(input_expr.clone())
.expect("successfully evaluated");

let total_iterations = info.total_iterations();
let completed_cycles = info.completed_cycles();
assert_eq!(
simplified_expr, expected_expr,
"Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}"
);
assert_eq!(
count, expected_count,
"Mismatch simplifier cycle count\n Expected: {expected_count}\n Got:{count}"
completed_cycles, expected_cycle_count,
"Mismatch simplifier cycle count\n Expected: {expected_cycle_count}\n Got:{completed_cycles}"
);
assert_eq!(
total_iterations, expected_iteration_count,
"Mismatch simplifier cycle count\n Expected: {expected_iteration_count}\n Got:{total_iterations}"
);
}

Expand Down Expand Up @@ -687,5 +694,5 @@ fn test_simplify_cycles() {
let expr = cast(now(), DataType::Int64)
.lt(cast(to_timestamp(vec![lit(0)]), DataType::Int64) + lit(i64::MAX));
let expected = lit(true);
test_simplify_with_cycle_count(expr, expected, 3);
test_simplify_with_cycle_info(expr, expected, 2, 7);
}
1 change: 1 addition & 0 deletions datafusion/optimizer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ pub mod propagate_empty_relation;
pub mod push_down_filter;
pub mod push_down_limit;
pub mod replace_distinct_aggregate;
pub mod rewrite_cycle;
pub mod rewrite_disjunctive_predicate;
pub mod scalar_subquery_to_join;
pub mod simplify_expressions;
Expand Down
303 changes: 303 additions & 0 deletions datafusion/optimizer/src/rewrite_cycle.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,303 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

/// An API for executing a sequence of [TreeNodeRewriter]s in multiple passes.
///
/// See [RewriteCycle] for more information.
///
use std::ops::ControlFlow;

use datafusion_common::{
tree_node::{Transformed, TreeNode, TreeNodeRewriter},
Result,
};

/// A builder with methods for executing a "rewrite cycle".
///
/// Often the results of one optimization rule can uncover more optimizations in other optimization
/// rules. A sequence of optimization rules can be ran in multiple "passes" until there are no
/// more optmizations to make.
///
/// The [RewriteCycle] handles logic for running these multi-pass loops.
/// It applies a sequence of [TreeNodeRewriter]s to a [TreeNode] by calling
/// [TreeNode::rewrite] in a loop - passing the output of one rewrite as the input to the next
/// rewrite - until [RewriteCycle::max_cycles] is reached or until every [TreeNode::rewrite]
/// returns a [Transformed::no] result in a consecutive sequence.
///
/// There are two methods to execute a rewrite cycle:
/// * Static dispatch API: [Self::each_cycle]
/// * Dynamic dispatch API: [Self::fold_rewrites]
#[derive(Debug)]
pub struct RewriteCycle {
max_cycles: usize,
}

impl Default for RewriteCycle {
fn default() -> Self {
Self::new()
}
}

impl RewriteCycle {
/// The default maximum number of completed cycles to run before terminating the rewrite loop.
/// You can override this default with [Self::with_max_cycles]
pub const DEFAULT_MAX_CYCLES: usize = 3;

/// Creates a new [RewriteCycle] with default options.
pub fn new() -> Self {
Self {
max_cycles: Self::DEFAULT_MAX_CYCLES,
}
}
/// Sets the [Self::max_cycles] to run before terminating the rewrite loop.
pub fn with_max_cycles(mut self, max_cycles: usize) -> Self {
self.max_cycles = max_cycles;
self
}

/// The maximum number of completed cycles to run before terminating the rewrite loop.
/// Defaults to [Self::DEFAULT_MAX_CYCLES].
pub fn max_cycles(&self) -> usize {
self.max_cycles
}

/// Runs a rewrite cycle on the given [TreeNode] using the given callback function to
/// explicitly handle the cycle iterations.
///
/// This API allows for static dispatch at the cost of some extra code complexity. For a
/// simpler API that uses dynamic dispatch, use [Self::fold_rewrites]
///
/// The callback function is given a [RewriteCycleState], which manages the short-circuiting
/// logic of the loop. The function is expected to call [RewriteCycleState::rewrite] for each
/// individual [TreeNodeRewriter] in the cycle. [RewriteCycleState::rewrite] returns a [RewriteCycleControlFlow]
/// result, indicating whether the loop should break or continue.
///
/// ```rust
///
/// use arrow::datatypes::{Schema, Field, DataType};
/// use datafusion_expr::{col, lit};
/// use datafusion_common::{DataFusionError, ToDFSchema};
/// use datafusion_expr::execution_props::ExecutionProps;
/// use datafusion_expr::simplify::SimplifyContext;
/// use datafusion_optimizer::rewrite_cycle::RewriteCycle;
/// use datafusion_optimizer::simplify_expressions::{Simplifier, ConstEvaluator};
///
/// // Create the schema
/// let schema = Schema::new(vec![
/// Field::new("c1", DataType::Int64, true),
/// Field::new("c2", DataType::UInt32, true),
/// ]).to_dfschema_ref().unwrap();
///
/// // Create the rewriters
/// let props = ExecutionProps::new();
/// let context = SimplifyContext::new(&props)
/// .with_schema(schema);
/// let mut simplifier = Simplifier::new(&context);
/// let mut const_evaluator = ConstEvaluator::try_new(&props).unwrap();
///
/// // ((c4 < 1 or c3 < 2) and c3 < 3) and false
/// let expr = col("c2")
/// .lt(lit(1))
/// .or(col("c1").lt(lit(2)))
/// .and(col("c1").lt(lit(3)))
/// .and(lit(false));
///
/// // run the rewrite cycle loop
/// let (expr, info) = RewriteCycle::new()
/// .with_max_cycles(4)
/// .each_cycle(expr, |cycle_state| {
/// cycle_state
/// .rewrite(&mut const_evaluator)?
/// .rewrite(&mut simplifier)
/// }).unwrap();
/// assert_eq!(expr, lit(false));
/// assert_eq!(info.completed_cycles(), 2);
/// assert_eq!(info.total_iterations(), 4);
/// ```
///
pub fn each_cycle<
Node: TreeNode,
F: FnMut(
RewriteCycleState<Node>,
) -> RewriteCycleControlFlow<RewriteCycleState<Node>>,
>(
&self,
node: Node,
mut f: F,
) -> Result<(Node, RewriteCycleInfo)> {
let mut state = RewriteCycleState::new(node);
if self.max_cycles == 0 {
return state.finish();
}
// run first cycle then record number of rewriters
state = match f(state) {
ControlFlow::Break(result) => return result?.finish(),
ControlFlow::Continue(node) => node,
};
state.record_cycle_length();
if state.is_done() {
return state.finish();
}
// run remaining cycles
match (1..self.max_cycles).try_fold(state, |state, _| f(state)) {
ControlFlow::Break(result) => result?.finish(),
ControlFlow::Continue(state) => state.finish(),
}
}

/// Runs a rewrite cycle on the given tree node by applying the given sequence of [TreeNodeRewriter] trait objects.
///
/// [TreeNode::rewrite] will be called for each [TreeNodeRewriter], with the
/// output [TreeNode] of each call being passed as the input to the next.
///
/// This process repeats until either [Self::max_cycles] is reached or until a consecutive
/// sequence of [Transformed::no] results is returned for each [TreeNodeRewriter],
/// indicating that there is no more work to be done.
///
/// For an API that avoids dynamic dispatch, see [Self::each_cycle]
///
pub fn fold_rewrites<Node: TreeNode>(
&self,
node: Node,
rewriters: &mut [Box<dyn TreeNodeRewriter<Node = Node>>],
) -> Result<(Node, RewriteCycleInfo)> {
let mut state = RewriteCycleState::new(node);
state.with_cycle_length(rewriters.len());
match (0..self.max_cycles).try_fold(state, |state, _| {
rewriters
.iter_mut()
.try_fold(state, |state, rewriter| state.rewrite(rewriter.as_mut()))
}) {
ControlFlow::Break(result) => result?.finish(),
ControlFlow::Continue(state) => state.finish(),
}
}
}

/// Iteration state of a rewrite cycle. See [RewriteCycle::each_cycle] for usage examples and information.
#[derive(Debug)]
pub struct RewriteCycleState<Node: TreeNode> {
node: Node,
consecutive_unchanged_count: usize,
rewrite_count: usize,
cycle_length: Option<usize>,
}

impl<Node: TreeNode> RewriteCycleState<Node> {
fn new(node: Node) -> Self {
Self {
node,
cycle_length: None,
consecutive_unchanged_count: 0,
rewrite_count: 0,
}
}

/// Explicitly specify the cycle length. Can be used to set the cycle length when the sequence of rewriters is given as
/// a slice, such as with [RewriteCycle::fold_rewrites]
fn with_cycle_length(&mut self, cycle_length: usize) {
self.cycle_length = Some(cycle_length);
}

/// Records the rewrite cycle length based on the current iteration count
///
/// When the total number of writers is not known upfront - such as when using
/// [RewriteCycle::each_cycle] we need to keep count of the number of [Self::rewrite]
/// calls and then record the number at the end of the first cycle.
fn record_cycle_length(&mut self) {
self.cycle_length = Some(self.rewrite_count);
}

/// Returns true when the loop has reached the maximum cycle length or when we've received
/// consecutive unchanged tree nodes equal to the total number of rewriters.
fn is_done(&self) -> bool {
// default value indicates we have not completed a cycle
let Some(cycle_length) = self.cycle_length else {
return false;
};
self.consecutive_unchanged_count >= cycle_length
}

/// Finishes the iteration by consuming the state and returning a [TreeNode] and
/// [RewriteCycleInfo]
fn finish(self) -> Result<(Node, RewriteCycleInfo)> {
Ok((
self.node,
RewriteCycleInfo {
cycle_length: self.cycle_length.unwrap_or(self.rewrite_count),
total_iterations: self.rewrite_count,
},
))
}

/// Calls [TreeNode::rewrite] and determines if the rewrite cycle should break or continue
/// based on the current [RewriteCycleState].
pub fn rewrite<R: TreeNodeRewriter<Node = Node> + ?Sized>(
mut self,
rewriter: &mut R,
) -> RewriteCycleControlFlow<Self> {
match self.node.rewrite(rewriter) {
Err(e) => ControlFlow::Break(Err(e)),
Ok(Transformed {
data: node,
transformed,
..
}) => {
self.node = node;
self.rewrite_count += 1;
if transformed {
self.consecutive_unchanged_count = 0;
} else {
self.consecutive_unchanged_count += 1;
}
if self.is_done() {
ControlFlow::Break(Ok(self))
} else {
ControlFlow::Continue(self)
}
}
}
}
}

/// Information about a rewrite cycle, such as total number of iterations and number of fully
/// completed cycles. This is useful for testing purposes to ensure that optimzation passes are
/// working as expected.
#[derive(Debug, Clone, Copy)]
pub struct RewriteCycleInfo {
total_iterations: usize,
cycle_length: usize,
}

impl RewriteCycleInfo {
/// The total number of **fully completed** cycles.
pub fn completed_cycles(&self) -> usize {
self.total_iterations / self.cycle_length
}

/// The total number of [TreeNode::rewrite] calls.
pub fn total_iterations(&self) -> usize {
self.total_iterations
}

/// The number of [TreeNode::rewrite] calls within a single cycle.
pub fn cycle_length(&self) -> usize {
self.cycle_length
}
}

pub type RewriteCycleControlFlow<T> = ControlFlow<Result<T>, T>;
Loading

0 comments on commit 83de911

Please sign in to comment.