-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support ORDER BY an aliased column #5067
Changes from all commits
0c16ef1
6d85c1f
3affe92
b7bfbbf
ba563ad
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,11 +18,9 @@ | |
//! Rewrite for order by expressions | ||
|
||
use crate::expr::Sort; | ||
use crate::expr_rewriter::{normalize_col, ExprRewritable, ExprRewriter}; | ||
use crate::logical_plan::Aggregate; | ||
use crate::utils::grouping_set_to_exprlist; | ||
use crate::expr_rewriter::{normalize_col, rewrite_expr}; | ||
use crate::{Expr, ExprSchemable, LogicalPlan}; | ||
use datafusion_common::Result; | ||
use datafusion_common::{Column, Result}; | ||
|
||
/// Rewrite sort on aggregate expressions to sort on the column of aggregate output | ||
/// For example, `max(x)` is written to `col("MAX(x)")` | ||
|
@@ -54,56 +52,84 @@ pub fn rewrite_sort_cols_by_aggs( | |
} | ||
|
||
fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> { | ||
match plan { | ||
LogicalPlan::Aggregate(Aggregate { | ||
input, | ||
aggr_expr, | ||
group_expr, | ||
.. | ||
}) => { | ||
struct Rewriter<'a> { | ||
plan: &'a LogicalPlan, | ||
input: &'a LogicalPlan, | ||
aggr_expr: &'a Vec<Expr>, | ||
distinct_group_exprs: &'a Vec<Expr>, | ||
} | ||
let plan_inputs = plan.inputs(); | ||
|
||
impl<'a> ExprRewriter for Rewriter<'a> { | ||
fn mutate(&mut self, expr: Expr) -> Result<Expr> { | ||
let normalized_expr = normalize_col(expr.clone(), self.plan); | ||
if normalized_expr.is_err() { | ||
// The expr is not based on Aggregate plan output. Skip it. | ||
return Ok(expr); | ||
} | ||
let normalized_expr = normalized_expr?; | ||
if let Some(found_agg) = self | ||
.aggr_expr | ||
.iter() | ||
.chain(self.distinct_group_exprs) | ||
.find(|a| (**a) == normalized_expr) | ||
{ | ||
let agg = normalize_col(found_agg.clone(), self.plan)?; | ||
let col = Expr::Column( | ||
agg.to_field(self.input.schema()) | ||
.map(|f| f.qualified_column())?, | ||
); | ||
Ok(col) | ||
} else { | ||
Ok(expr) | ||
} | ||
} | ||
} | ||
// Joins, and Unions are not yet handled (should have a projection | ||
// on top of them) | ||
if plan_inputs.len() == 1 { | ||
let proj_exprs = plan.expressions(); | ||
rewrite_in_terms_of_projection(expr, proj_exprs, plan_inputs[0]) | ||
} else { | ||
Ok(expr) | ||
} | ||
} | ||
|
||
let distinct_group_exprs = grouping_set_to_exprlist(group_expr.as_slice())?; | ||
expr.rewrite(&mut Rewriter { | ||
plan, | ||
input, | ||
aggr_expr, | ||
distinct_group_exprs: &distinct_group_exprs, | ||
}) | ||
/// Rewrites a sort expression in terms of the output of the previous [`LogicalPlan`] | ||
/// | ||
/// Example: | ||
/// | ||
/// Given an input expression such as `col(a) + col(b) + col(c)` | ||
/// | ||
/// into `col(a) + col("b + c")` | ||
/// | ||
/// Remember that: | ||
/// 1. given a projection with exprs: [a, b + c] | ||
/// 2. t produces an output schema with two columns "a", "b + c" | ||
fn rewrite_in_terms_of_projection( | ||
expr: Expr, | ||
proj_exprs: Vec<Expr>, | ||
input: &LogicalPlan, | ||
) -> Result<Expr> { | ||
// assumption is that each item in exprs, such as "b + c" is | ||
// available as an output column named "b + c" | ||
rewrite_expr(expr, |expr| { | ||
// search for unnormalized names first such as "c1" (such as aliases) | ||
if let Some(found) = proj_exprs.iter().find(|a| (**a) == expr) { | ||
let col = Expr::Column( | ||
found | ||
.to_field(input.schema()) | ||
.map(|f| f.qualified_column())?, | ||
); | ||
return Ok(col); | ||
} | ||
LogicalPlan::Projection(_) => rewrite_sort_col_by_aggs(expr, plan.inputs()[0]), | ||
_ => Ok(expr), | ||
|
||
// if that doesn't work, try to match the expression as an | ||
// output column -- however first it must be "normalized" | ||
// (e.g. "c1" --> "t.c1") because that normalization is done | ||
// at the input of the aggregate. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't fully understand why this is needed, but several tests fail if this logic is removed and the previous logic did it as well There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe(Just to offer a possibility) it's related with In the comment:
But in fact, in the plan, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Yes I think you are correct. I think what I found was confusing is that the qualifier is added "on demand" in several places and there isn't a clear cut line between "as written" and "qualified" |
||
|
||
let normalized_expr = if let Ok(e) = normalize_col(expr.clone(), input) { | ||
e | ||
} else { | ||
// The expr is not based on Aggregate plan output. Skip it. | ||
return Ok(expr); | ||
}; | ||
|
||
// expr is an actual expr like min(t.c2), but we are looking | ||
// for a column with the same "MIN(C2)", so translate there | ||
let name = normalized_expr.display_name()?; | ||
|
||
let search_col = Expr::Column(Column { | ||
relation: None, | ||
name, | ||
}); | ||
|
||
// look for the column named the same as this expr | ||
if let Some(found) = proj_exprs.iter().find(|a| expr_match(&search_col, a)) { | ||
return Ok((*found).clone()); | ||
} | ||
Ok(expr) | ||
}) | ||
} | ||
|
||
/// Does the underlying expr match e? | ||
/// so avg(c) as average will match avgc | ||
fn expr_match(needle: &Expr, haystack: &Expr) -> bool { | ||
// check inside aliases | ||
if let Expr::Alias(haystack, _) = &haystack { | ||
haystack.as_ref() == needle | ||
} else { | ||
haystack == needle | ||
} | ||
} | ||
|
||
|
@@ -115,7 +141,7 @@ mod test { | |
use arrow::datatypes::{DataType, Field, Schema}; | ||
|
||
use crate::{ | ||
col, lit, logical_plan::builder::LogicalTableSource, min, LogicalPlanBuilder, | ||
avg, col, lit, logical_plan::builder::LogicalTableSource, min, LogicalPlanBuilder, | ||
}; | ||
|
||
use super::*; | ||
|
@@ -136,24 +162,24 @@ mod test { | |
|
||
let cases = vec![ | ||
TestCase { | ||
desc: "c1 --> t.c1", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure why direct aggregate creation doesn't need this rewrite anymore. All the other tests in DataFusion CI pass and I wrote the unit tests last week to document the existing behavior in #5088 If reviewers prefer to keep the old behavior here, that is easy as the first commit in this PR, 0c16ef1 actually keeps all these tests passing with the existing "aggregate" code. |
||
desc: "c1 --> c1", | ||
input: sort(col("c1")), | ||
expected: sort(col("t.c1")), | ||
expected: sort(col("c1")), | ||
}, | ||
TestCase { | ||
desc: "c1 + c2 --> t.c1 + t.c2a", | ||
desc: "c1 + c2 --> c1 + c2", | ||
input: sort(col("c1") + col("c1")), | ||
expected: sort(col("t.c1") + col("t.c1")), | ||
expected: sort(col("c1") + col("c1")), | ||
}, | ||
TestCase { | ||
desc: r#"min(c2) --> "MIN(t.c2)" (column *named* "min(t.c2)"!)"#, | ||
desc: r#"min(c2) --> "min(c2)"#, | ||
input: sort(min(col("c2"))), | ||
expected: sort(col("MIN(t.c2)")), | ||
expected: sort(min(col("c2"))), | ||
}, | ||
TestCase { | ||
desc: r#"c1 + min(c2) --> "t.c1 + MIN(t.c2)" (column *named* "min(t.c2)"!)"#, | ||
desc: r#"c1 + min(c2) --> "c1 + min(c2)"#, | ||
input: sort(col("c1") + min(col("c2"))), | ||
expected: sort(col("t.c1") + col("MIN(t.c2)")), | ||
expected: sort(col("c1") + min(col("c2"))), | ||
}, | ||
]; | ||
|
||
|
@@ -168,8 +194,8 @@ mod test { | |
.aggregate( | ||
// gby c1 | ||
vec![col("c1")], | ||
// agg: min(2) | ||
vec![min(col("c2"))], | ||
// agg: min(c2), avg(c3) | ||
vec![min(col("c2")), avg(col("c3"))], | ||
) | ||
.unwrap() | ||
// projects out an expression "c1" that is different than the column "c1" | ||
|
@@ -178,6 +204,8 @@ mod test { | |
col("c1").add(lit(1)).alias("c1"), | ||
// min(c2) | ||
min(col("c2")), | ||
// avg("c3") as average | ||
avg(col("c3")).alias("average"), | ||
]) | ||
.unwrap() | ||
.build() | ||
|
@@ -187,9 +215,8 @@ mod test { | |
TestCase { | ||
desc: "c1 --> c1 -- column *named* c1 that came out of the projection, (not t.c1)", | ||
input: sort(col("c1")), | ||
// Incorrect due to https://github.com/apache/arrow-datafusion/issues/4854 | ||
// should be "c1" not t.c1 | ||
expected: sort(col("t.c1")), | ||
expected: sort(col("c1")), | ||
}, | ||
TestCase { | ||
desc: r#"min(c2) --> "MIN(c2)" -- (column *named* "min(t.c2)"!)"#, | ||
|
@@ -199,10 +226,14 @@ mod test { | |
TestCase { | ||
desc: r#"c1 + min(c2) --> "c1 + MIN(c2)" -- (column *named* "min(t.c2)"!)"#, | ||
input: sort(col("c1") + min(col("c2"))), | ||
// Incorrect due to https://github.com/apache/arrow-datafusion/issues/4854 | ||
// should be "c1" not t.c1 | ||
expected: sort(col("t.c1") + col("MIN(t.c2)")), | ||
} | ||
expected: sort(col("c1") + col("MIN(t.c2)")), | ||
}, | ||
TestCase { | ||
desc: r#"avg(c3) --> "AVG(t.c3)" as average (column *named* "AVG(t.c3)", aliased)"#, | ||
input: sort(avg(col("c3"))), | ||
expected: sort(col("AVG(t.c3)").alias("average")), | ||
}, | ||
]; | ||
|
||
for case in cases { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,4 +44,6 @@ log = "^0.4" | |
sqlparser = "0.30" | ||
|
||
[dev-dependencies] | ||
ctor = "0.1.22" | ||
env_logger = "0.10" | ||
rstest = "0.16" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the new more general algorithm that works for
LogicalPlan::Projection
andLogicalPlan::Aggregation