Support ORDER BY an aliased column #5067

Merged (5 commits) on Feb 6, 2023. Changes from all commits follow.
95 changes: 95 additions & 0 deletions datafusion/core/tests/sqllogictests/test_files/order.slt
@@ -163,3 +163,98 @@ SELECT c1, c2 FROM test ORDER BY c1 DESC, c2 ASC
query RC
SELECT c1, c2 FROM test WHERE c1 > 100000 ORDER BY c1 DESC, c2 ASC
----

#####
# Sorting and Grouping
#####
statement ok
create table foo as values (1, 2), (3, 4), (5, 6);

query ok rowsort
select * from foo
----
1 2
3 4
5 6

query ok
select column1 from foo order by column2;
----
1
3
5

query ok
select column1 from foo order by column1 + column2;
----
1
3
5

query ok
select column1 from foo order by column1 + column2;
----
1
3
5

query ok rowsort
select column1 + column2 from foo group by column1, column2;
----
11
3
7

query ok
select column1 + column2 from foo group by column1, column2 ORDER BY column2 desc;
----
11
7
3


# Cleanup
statement ok
drop table foo;


#####
# Tests for https://github.com/apache/arrow-datafusion/issues/4854
# Ordering / grouping by the same column
#####
statement ok
create or replace table t as select column1 as value, column2 as time from (select * from (values
(1, timestamp '2022-01-01 00:00:30'),
(2, timestamp '2022-01-01 01:00:10'),
(3, timestamp '2022-01-02 00:00:20')
) as sq) as sq


query I,I rowsort
select
sum(value) AS "value",
date_trunc('hour',time) AS "time"
FROM t
GROUP BY time;
----
1 2022-01-01T00:00:00
2 2022-01-01T01:00:00
3 2022-01-02T00:00:00

# should work fine
query I,I
select
sum(value) AS "value",
date_trunc('minute',time) AS "time"
FROM t
GROUP BY time
ORDER BY time;
----
1 2022-01-01T00:00:00
2 2022-01-01T01:00:00
3 2022-01-02T00:00:00


## Cleanup
statement ok
drop table t;
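The same scenario can also be driven through DataFusion's SQL API. The sketch below is illustrative glue, not part of the PR: it assumes the `datafusion` crate (~17.0, matching the Cargo.toml versions in this diff) plus a `tokio` runtime; the SQL itself mirrors the fixture above.

```rust
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // same fixture as the .slt file above
    ctx.sql(
        "create table t as select column1 as value, column2 as time \
         from (select * from (values \
             (1, timestamp '2022-01-01 00:00:30'), \
             (2, timestamp '2022-01-01 01:00:10'), \
             (3, timestamp '2022-01-02 00:00:20')) as sq) as sq",
    )
    .await?;

    // ORDER BY resolves `time` to the aliased output column (the
    // date_trunc result) -- the case this PR makes work
    let df = ctx
        .sql(
            "select sum(value) as \"value\", date_trunc('minute', time) as \"time\" \
             from t group by time order by time",
        )
        .await?;
    df.show().await?;
    Ok(())
}
```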
4 changes: 4 additions & 0 deletions datafusion/expr/Cargo.toml
@@ -40,3 +40,7 @@ arrow = { version = "32.0.0", default-features = false }
datafusion-common = { path = "../common", version = "17.0.0" }
log = "^0.4"
sqlparser = "0.30"

[dev-dependencies]
ctor = "0.1.22"
env_logger = "0.10"
5 changes: 5 additions & 0 deletions datafusion/expr/src/expr_rewriter.rs
@@ -524,6 +524,11 @@ mod test {
use arrow::datatypes::DataType;
use datafusion_common::{DFField, DFSchema, ScalarValue};

#[ctor::ctor]
fn init() {
let _ = env_logger::try_init();
}

#[derive(Default)]
struct RecordingRewriter {
v: Vec<String>,
165 changes: 98 additions & 67 deletions datafusion/expr/src/expr_rewriter/order_by.rs
@@ -18,11 +18,9 @@
//! Rewrite for order by expressions

 use crate::expr::Sort;
-use crate::expr_rewriter::{normalize_col, ExprRewritable, ExprRewriter};
-use crate::logical_plan::Aggregate;
-use crate::utils::grouping_set_to_exprlist;
+use crate::expr_rewriter::{normalize_col, rewrite_expr};
 use crate::{Expr, ExprSchemable, LogicalPlan};
-use datafusion_common::Result;
+use datafusion_common::{Column, Result};

/// Rewrite sort on aggregate expressions to sort on the column of aggregate output
/// For example, `max(x)` is written to `col("MAX(x)")`
Expand Down Expand Up @@ -54,56 +52,84 @@ pub fn rewrite_sort_cols_by_aggs(
}

 fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
-    match plan {
-        LogicalPlan::Aggregate(Aggregate {
-            input,
-            aggr_expr,
-            group_expr,
-            ..
-        }) => {
-            struct Rewriter<'a> {
-                plan: &'a LogicalPlan,
-                input: &'a LogicalPlan,
-                aggr_expr: &'a Vec<Expr>,
-                distinct_group_exprs: &'a Vec<Expr>,
-            }
-
-            impl<'a> ExprRewriter for Rewriter<'a> {
-                fn mutate(&mut self, expr: Expr) -> Result<Expr> {
-                    let normalized_expr = normalize_col(expr.clone(), self.plan);
-                    if normalized_expr.is_err() {
-                        // The expr is not based on Aggregate plan output. Skip it.
-                        return Ok(expr);
-                    }
-                    let normalized_expr = normalized_expr?;
-                    if let Some(found_agg) = self
-                        .aggr_expr
-                        .iter()
-                        .chain(self.distinct_group_exprs)
-                        .find(|a| (**a) == normalized_expr)
-                    {
-                        let agg = normalize_col(found_agg.clone(), self.plan)?;
-                        let col = Expr::Column(
-                            agg.to_field(self.input.schema())
-                                .map(|f| f.qualified_column())?,
-                        );
-                        Ok(col)
-                    } else {
-                        Ok(expr)
-                    }
-                }
-            }
-
-            let distinct_group_exprs = grouping_set_to_exprlist(group_expr.as_slice())?;
-            expr.rewrite(&mut Rewriter {
-                plan,
-                input,
-                aggr_expr,
-                distinct_group_exprs: &distinct_group_exprs,
-            })
-        }
-        LogicalPlan::Projection(_) => rewrite_sort_col_by_aggs(expr, plan.inputs()[0]),
-        _ => Ok(expr),
-    }
+    let plan_inputs = plan.inputs();
+
+    // Joins, and Unions are not yet handled (should have a projection
+    // on top of them)
+    if plan_inputs.len() == 1 {
+        let proj_exprs = plan.expressions();
+        rewrite_in_terms_of_projection(expr, proj_exprs, plan_inputs[0])
+    } else {
+        Ok(expr)
+    }
 }
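The new body above delegates the per-node work to `rewrite_expr`. As a sketch of that helper in isolation — its module path and signature are inferred from the import and the call sites in this diff, so treat both as assumptions — a closure is applied to every node of the expression tree:

```rust
use datafusion_common::Result;
use datafusion_expr::expr_rewriter::rewrite_expr;
use datafusion_expr::{col, lit, Expr};

fn main() -> Result<()> {
    // rewrite references to column "a" into the literal 1, leaving "b" alone
    let rewritten = rewrite_expr(col("a") + col("b"), |e| {
        Ok(match e {
            Expr::Column(ref c) if c.name == "a" => lit(1i64),
            other => other,
        })
    })?;
    println!("{rewritten}"); // prints something like: Int64(1) + b
    Ok(())
}
```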
[PR author]: This is the new, more general algorithm that works for LogicalPlan::Projection and LogicalPlan::Aggregate.

+/// Rewrites a sort expression in terms of the output of the previous [`LogicalPlan`]
+///
+/// Example:
+///
+/// Given an input expression such as `col(a) + col(b) + col(c)`, rewrite it
+/// into `col(a) + col("b + c")`
+///
+/// Remember that:
+/// 1. given a projection with exprs: [a, b + c]
+/// 2. it produces an output schema with two columns "a", "b + c"
+fn rewrite_in_terms_of_projection(
+    expr: Expr,
+    proj_exprs: Vec<Expr>,
+    input: &LogicalPlan,
+) -> Result<Expr> {
+    // assumption is that each item in exprs, such as "b + c", is
+    // available as an output column named "b + c"
+    rewrite_expr(expr, |expr| {
+        // search for unnormalized names first, such as aliases like "c1"
+        if let Some(found) = proj_exprs.iter().find(|a| (**a) == expr) {
+            let col = Expr::Column(
+                found
+                    .to_field(input.schema())
+                    .map(|f| f.qualified_column())?,
+            );
+            return Ok(col);
+        }

+        // if that doesn't work, try to match the expression as an
+        // output column -- however first it must be "normalized"
+        // (e.g. "c1" --> "t.c1") because that normalization is done
+        // at the input of the aggregate.
[PR author]: I don't fully understand why this is needed, but several tests fail if this logic is removed, and the previous logic did it as well.

[@jackwener (Member), Feb 2, 2023]: Maybe (just to offer a possibility) it's related to the qualifier. The doc comment says:

    /// 1. given a projection with exprs: [a, b + c]
    /// 2. it produces an output schema with two columns "a", "b + c"

But in fact, in the plan, [a, b + c] can be [table1.a, "table2.b+c"], and the output schema may carry qualifiers as well. Because of the qualifier, some equalities do not hold; e.g. 'a' == 't1.a' is false.

[PR author, replying]: Yes, I think you are correct. What I found confusing is that the qualifier is added "on demand" in several places, and there isn't a clear-cut line between "as written" and "qualified".


+        let normalized_expr = if let Ok(e) = normalize_col(expr.clone(), input) {
+            e
+        } else {
+            // The expr is not based on Aggregate plan output. Skip it.
+            return Ok(expr);
+        };
+
+        // expr is an actual expr like min(t.c2), but we are looking
+        // for a column named "MIN(t.c2)", so translate there
+        let name = normalized_expr.display_name()?;
+
+        let search_col = Expr::Column(Column {
+            relation: None,
+            name,
+        });
+
+        // look for the column named the same as this expr
+        if let Some(found) = proj_exprs.iter().find(|a| expr_match(&search_col, a)) {
+            return Ok((*found).clone());
+        }
+        Ok(expr)
+    })
+}
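Two details the closure above leans on can be checked standalone. A sketch, assuming the `datafusion-common`/`datafusion-expr` crates at the versions in this PR's Cargo.toml (~17.0):

```rust
use datafusion_common::{Column, Result};
use datafusion_expr::{col, min};

fn main() -> Result<()> {
    // 1. Qualification matters for equality ('a' != 't1.a'), which is why the
    //    sort expression is normalized against the input before comparing.
    let unqualified = Column::from_name("a");
    let qualified = Column::new(Some("t1"), "a");
    assert_ne!(unqualified, qualified);

    // 2. An aggregate's display name is the name of its output column, so
    //    min(t.c2) is looked up as a column literally named "MIN(t.c2)".
    assert_eq!(min(col("t.c2")).display_name()?, "MIN(t.c2)");
    Ok(())
}
```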

+/// Does the underlying expr match e?
+/// so avg(c) as average will match avg(c)
+fn expr_match(needle: &Expr, haystack: &Expr) -> bool {
+    // check inside aliases
+    if let Expr::Alias(haystack, _) = &haystack {
+        haystack.as_ref() == needle
+    } else {
+        haystack == needle
+    }
+}
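The alias transparency `expr_match` provides can be exercised directly with the same expression builders the tests below import. A minimal sketch (the function body is copied from the diff; the `main` is illustrative):

```rust
use datafusion_expr::{avg, col, Expr};

fn expr_match(needle: &Expr, haystack: &Expr) -> bool {
    // look through an alias so `avg(c) AS average` still matches `avg(c)`
    if let Expr::Alias(inner, _) = haystack {
        inner.as_ref() == needle
    } else {
        haystack == needle
    }
}

fn main() {
    let aliased = avg(col("c")).alias("average");
    assert!(expr_match(&avg(col("c")), &aliased));
    // a different expression still does not match through the alias
    assert!(!expr_match(&col("c"), &aliased));
}
```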

@@ -115,7 +141,7 @@ mod test {
use arrow::datatypes::{DataType, Field, Schema};

 use crate::{
-    col, lit, logical_plan::builder::LogicalTableSource, min, LogicalPlanBuilder,
+    avg, col, lit, logical_plan::builder::LogicalTableSource, min, LogicalPlanBuilder,
 };

use super::*;

     let cases = vec![
         TestCase {
-            desc: "c1 --> t.c1",
+            desc: "c1 --> c1",
             input: sort(col("c1")),
-            expected: sort(col("t.c1")),
+            expected: sort(col("c1")),
         },

[PR author]: I am not sure why direct aggregate creation doesn't need this rewrite anymore. All the other tests in DataFusion CI pass, and I wrote the unit tests last week to document the existing behavior in #5088. If reviewers prefer to keep the old behavior here, that is easy: the first commit in this PR, 0c16ef1, keeps all these tests passing with the existing "aggregate" code.

         TestCase {
-            desc: "c1 + c2 --> t.c1 + t.c2a",
+            desc: "c1 + c2 --> c1 + c2",
             input: sort(col("c1") + col("c1")),
-            expected: sort(col("t.c1") + col("t.c1")),
+            expected: sort(col("c1") + col("c1")),
         },
         TestCase {
-            desc: r#"min(c2) --> "MIN(t.c2)" (column *named* "min(t.c2)"!)"#,
+            desc: r#"min(c2) --> "min(c2)""#,
             input: sort(min(col("c2"))),
-            expected: sort(col("MIN(t.c2)")),
+            expected: sort(min(col("c2"))),
         },
         TestCase {
-            desc: r#"c1 + min(c2) --> "t.c1 + MIN(t.c2)" (column *named* "min(t.c2)"!)"#,
+            desc: r#"c1 + min(c2) --> "c1 + min(c2)""#,
             input: sort(col("c1") + min(col("c2"))),
-            expected: sort(col("t.c1") + col("MIN(t.c2)")),
+            expected: sort(col("c1") + min(col("c2"))),
         },
     ];

@@ -168,8 +194,8 @@
         .aggregate(
             // gby c1
             vec![col("c1")],
-            // agg: min(2)
-            vec![min(col("c2"))],
+            // agg: min(c2), avg(c3)
+            vec![min(col("c2")), avg(col("c3"))],
         )
         .unwrap()
// projects out an expression "c1" that is different than the column "c1"
@@ -178,6 +204,8 @@
col("c1").add(lit(1)).alias("c1"),
// min(c2)
min(col("c2")),
// avg("c3") as average
avg(col("c3")).alias("average"),
])
.unwrap()
.build()
@@ -187,9 +215,8 @@
         TestCase {
             desc: "c1 --> c1 -- column *named* c1 that came out of the projection, (not t.c1)",
             input: sort(col("c1")),
-            // Incorrect due to https://github.com/apache/arrow-datafusion/issues/4854
-            // should be "c1" not t.c1
-            expected: sort(col("t.c1")),
+            expected: sort(col("c1")),
         },
TestCase {
desc: r#"min(c2) --> "MIN(c2)" -- (column *named* "min(t.c2)"!)"#,
@@ -199,10 +226,14 @@
         TestCase {
             desc: r#"c1 + min(c2) --> "c1 + MIN(c2)" -- (column *named* "min(t.c2)"!)"#,
             input: sort(col("c1") + min(col("c2"))),
-            // Incorrect due to https://github.com/apache/arrow-datafusion/issues/4854
-            // should be "c1" not t.c1
-            expected: sort(col("t.c1") + col("MIN(t.c2)")),
-        }
+            expected: sort(col("c1") + col("MIN(t.c2)")),
+        },
+        TestCase {
+            desc: r#"avg(c3) --> "AVG(t.c3)" as average (column *named* "AVG(t.c3)", aliased)"#,
+            input: sort(avg(col("c3"))),
+            expected: sort(col("AVG(t.c3)").alias("average")),
+        },
     ];

for case in cases {
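For orientation, the `sort` helper and `TestCase` struct these cases rely on sit outside the diff's context lines; a plausible, purely hypothetical reconstruction looks like this:

```rust
use datafusion_expr::{col, Expr};

struct TestCase {
    desc: &'static str,
    input: Expr,
    expected: Expr,
}

/// Wrap an expression as an ascending, nulls-first sort key,
/// matching ORDER BY's defaults.
fn sort(expr: Expr) -> Expr {
    expr.sort(true, true)
}

fn main() {
    let case = TestCase {
        desc: "c1 --> c1",
        input: sort(col("c1")),
        expected: sort(col("c1")),
    };
    assert_eq!(case.input, case.expected, "{}", case.desc);
}
```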
3 changes: 3 additions & 0 deletions datafusion/expr/src/expr_schema.rs
@@ -241,6 +241,9 @@
}

/// Returns a [arrow::datatypes::Field] compatible with this expression.
///
/// So for example, a projected expression `col(c1) + col(c2)` is
/// placed in an output field **named** col("c1 + c2")
fn to_field(&self, input_schema: &DFSchema) -> Result<DFField> {
match self {
Expr::Column(c) => Ok(DFField::new(
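The naming rule stated in the new doc comment can be observed through `display_name`, from which the output field name is derived. A sketch under the same crate-version assumptions as above:

```rust
use datafusion_common::Result;
use datafusion_expr::col;

fn main() -> Result<()> {
    let expr = col("c1") + col("c2");
    // the projected output field is named by the expression's display form
    assert_eq!(expr.display_name()?, "c1 + c2");
    Ok(())
}
```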
2 changes: 2 additions & 0 deletions datafusion/sql/Cargo.toml
@@ -44,4 +44,6 @@ log = "^0.4"
sqlparser = "0.30"

[dev-dependencies]
ctor = "0.1.22"
env_logger = "0.10"
rstest = "0.16"