Skip to content

Commit

Permalink
simplify the code
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Jan 30, 2023
1 parent 0c16ef1 commit 6d85c1f
Showing 1 changed file with 29 additions and 70 deletions.
99 changes: 29 additions & 70 deletions datafusion/expr/src/expr_rewriter/order_by.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@

use crate::expr::Sort;
use crate::expr_rewriter::{normalize_col, rewrite_expr};
use crate::logical_plan::Aggregate;
use crate::utils::grouping_set_to_exprlist;
use crate::{Expr, ExprSchemable, LogicalPlan, Projection};
use crate::{Expr, ExprSchemable, LogicalPlan};
use datafusion_common::{Column, Result};

/// Rewrite sort on aggregate expressions to sort on the column of aggregate output
Expand Down Expand Up @@ -54,73 +52,34 @@ pub fn rewrite_sort_cols_by_aggs(
}

fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
match plan {
LogicalPlan::Aggregate(aggregate) => {
rewrite_in_terms_of_aggregate(expr, plan, aggregate)
}
LogicalPlan::Projection(projection) => {
rewrite_in_terms_of_projection(expr, projection)
}
_ => Ok(expr),
let plan_inputs = plan.inputs();

// Joins and Unions are not yet handled (they should have a
// projection on top of them)
if plan_inputs.len() == 1 {
let proj_exprs = plan.expressions();
rewrite_in_terms_of_projection(expr, proj_exprs, plan_inputs[0])
} else {
Ok(expr)
}
}

/// rewrites a sort expression in terms of the output of an [`Aggregate`].
/// Rewrites a sort expression in terms of the output of the previous [`LogicalPlan`]
///
/// Note: the SQL planner always puts a `Projection` at the output of
/// an aggregate, but other paths such as LogicalPlanBuilder can
/// create a Sort directly above an Aggregate
fn rewrite_in_terms_of_aggregate(
    expr: Expr,
    // the `LogicalPlan::Aggregate` that wraps `aggregate`
    plan: &LogicalPlan,
    aggregate: &Aggregate,
) -> Result<Expr> {
    let input = &aggregate.input;
    let aggr_expr = &aggregate.aggr_expr;
    // Expand the group-by list into a plain list of expressions so it can be
    // searched alongside the aggregate expressions below.
    let distinct_group_exprs =
        grouping_set_to_exprlist(aggregate.group_expr.as_slice())?;

    rewrite_expr(expr, |expr| {
        // Normalize the candidate against `plan`; when that fails the
        // sub-expression is handed back unchanged rather than reported as an
        // error.
        let normalized_expr = match normalize_col(expr.clone(), plan) {
            Ok(normalized) => normalized,
            Err(_) => return Ok(expr),
        };

        // Look for an aggregate or group expression equal to the normalized
        // candidate.
        let matching = aggr_expr
            .iter()
            .chain(distinct_group_exprs.iter())
            .find(|candidate| **candidate == normalized_expr);

        match matching {
            Some(found_agg) => {
                // Replace the candidate with a column reference built from the
                // field that the matched expression resolves to against the
                // aggregate input's schema.
                let agg = normalize_col(found_agg.clone(), plan)?;
                let field = agg.to_field(input.schema())?;
                Ok(Expr::Column(field.qualified_column()))
            }
            None => Ok(expr),
        }
    })
}

/// Rewrites a sort expression in terms of the output of a [`Projection`].
/// For example, will rewrite an input expression such as
/// `a + b + c` into `col(a) + col("b + c")`
/// Example:
///
/// Given an input expression such as `col(a) + col(b) + col(c)`
///
/// it is rewritten into `col(a) + col("b + c")`
///
/// Remember that:
/// 1. given a projection with exprs: [a, b + c]
/// 2. it produces an output schema with two columns "a", "b + c"
fn rewrite_in_terms_of_projection(expr: Expr, projection: &Projection) -> Result<Expr> {
let Projection {
expr: proj_exprs,
input,
..
} = projection;

fn rewrite_in_terms_of_projection(
expr: Expr,
proj_exprs: Vec<Expr>,
input: &LogicalPlan,
) -> Result<Expr> {
// assumption is that each item in exprs, such as "b + c" is
// available as an output column named "b + c"
rewrite_expr(expr, |expr| {
Expand Down Expand Up @@ -203,24 +162,24 @@ mod test {

let cases = vec![
TestCase {
desc: "c1 --> t.c1",
desc: "c1 --> c1",
input: sort(col("c1")),
expected: sort(col("t.c1")),
expected: sort(col("c1")),
},
TestCase {
desc: "c1 + c2 --> t.c1 + t.c2a",
desc: "c1 + c2 --> c1 + c2",
input: sort(col("c1") + col("c1")),
expected: sort(col("t.c1") + col("t.c1")),
expected: sort(col("c1") + col("c1")),
},
TestCase {
desc: r#"min(c2) --> "MIN(t.c2)" (column *named* "min(t.c2)"!)"#,
desc: r#"min(c2) --> "min(c2)"#,
input: sort(min(col("c2"))),
expected: sort(col("MIN(t.c2)")),
expected: sort(min(col("c2"))),
},
TestCase {
desc: r#"c1 + min(c2) --> "t.c1 + MIN(t.c2)" (column *named* "min(t.c2)"!)"#,
desc: r#"c1 + min(c2) --> "c1 + min(c2)"#,
input: sort(col("c1") + min(col("c2"))),
expected: sort(col("t.c1") + col("MIN(t.c2)")),
expected: sort(col("c1") + min(col("c2"))),
},
];

Expand Down

0 comments on commit 6d85c1f

Please sign in to comment.