Skip to content

Commit

Permalink
Support ORDER BY an aliased column (#5067)
Browse files Browse the repository at this point in the history
* Support grouping aliases in the order by clause

* simplify the code

* Add more test coverage
  • Loading branch information
alamb authored Feb 6, 2023
1 parent b09c1ee commit 6eb0e36
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 67 deletions.
95 changes: 95 additions & 0 deletions datafusion/core/tests/sqllogictests/test_files/order.slt
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,98 @@ SELECT c1, c2 FROM test ORDER BY c1 DESC, c2 ASC
query RC
SELECT c1, c2 FROM test WHERE c1 > 100000 ORDER BY c1 DESC, c2 ASC
----

#####
# Sorting and Grouping
#####
statement ok
create table foo as values (1, 2), (3, 4), (5, 6);

query ok rowsort
select * from foo
----
1 2
3 4
5 6

query ok
select column1 from foo order by column2;
----
1
3
5

query ok
select column1 from foo order by column1 + column2;
----
1
3
5

query ok
select column1 from foo order by column1 + column2;
----
1
3
5

query ok rowsort
select column1 + column2 from foo group by column1, column2;
----
11
3
7

query ok
select column1 + column2 from foo group by column1, column2 ORDER BY column2 desc;
----
11
7
3


# Cleanup
statement ok
drop table foo;


#####
# Tests for https://github.com/apache/arrow-datafusion/issues/4854
# Ordering / grouping by the same column
#####
statement ok
create or replace table t as select column1 as value, column2 as time from (select * from (values
(1, timestamp '2022-01-01 00:00:30'),
(2, timestamp '2022-01-01 01:00:10'),
(3, timestamp '2022-01-02 00:00:20')
) as sq) as sq


query I,I rowsort
select
sum(value) AS "value",
date_trunc('hour',time) AS "time"
FROM t
GROUP BY time;
----
1 2022-01-01T00:00:00
2 2022-01-01T01:00:00
3 2022-01-02T00:00:00

# should work fine
query I,I
select
sum(value) AS "value",
date_trunc('minute',time) AS "time"
FROM t
GROUP BY time
ORDER BY time;
----
1 2022-01-01T00:00:00
2 2022-01-01T01:00:00
3 2022-01-02T00:00:00


## Cleanup
statement ok
drop table t;
4 changes: 4 additions & 0 deletions datafusion/expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,7 @@ arrow = { version = "32.0.0", default-features = false }
datafusion-common = { path = "../common", version = "17.0.0" }
log = "^0.4"
sqlparser = "0.30"

[dev-dependencies]
ctor = "0.1.22"
env_logger = "0.10"
5 changes: 5 additions & 0 deletions datafusion/expr/src/expr_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,11 @@ mod test {
use arrow::datatypes::DataType;
use datafusion_common::{DFField, DFSchema, ScalarValue};

#[ctor::ctor]
fn init() {
let _ = env_logger::try_init();
}

#[derive(Default)]
struct RecordingRewriter {
v: Vec<String>,
Expand Down
165 changes: 98 additions & 67 deletions datafusion/expr/src/expr_rewriter/order_by.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,9 @@
//! Rewrite for order by expressions

use crate::expr::Sort;
use crate::expr_rewriter::{normalize_col, ExprRewritable, ExprRewriter};
use crate::logical_plan::Aggregate;
use crate::utils::grouping_set_to_exprlist;
use crate::expr_rewriter::{normalize_col, rewrite_expr};
use crate::{Expr, ExprSchemable, LogicalPlan};
use datafusion_common::Result;
use datafusion_common::{Column, Result};

/// Rewrite sort on aggregate expressions to sort on the column of aggregate output
/// For example, `max(x)` is written to `col("MAX(x)")`
Expand Down Expand Up @@ -54,56 +52,84 @@ pub fn rewrite_sort_cols_by_aggs(
}

fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
match plan {
LogicalPlan::Aggregate(Aggregate {
input,
aggr_expr,
group_expr,
..
}) => {
struct Rewriter<'a> {
plan: &'a LogicalPlan,
input: &'a LogicalPlan,
aggr_expr: &'a Vec<Expr>,
distinct_group_exprs: &'a Vec<Expr>,
}
let plan_inputs = plan.inputs();

impl<'a> ExprRewriter for Rewriter<'a> {
fn mutate(&mut self, expr: Expr) -> Result<Expr> {
let normalized_expr = normalize_col(expr.clone(), self.plan);
if normalized_expr.is_err() {
// The expr is not based on Aggregate plan output. Skip it.
return Ok(expr);
}
let normalized_expr = normalized_expr?;
if let Some(found_agg) = self
.aggr_expr
.iter()
.chain(self.distinct_group_exprs)
.find(|a| (**a) == normalized_expr)
{
let agg = normalize_col(found_agg.clone(), self.plan)?;
let col = Expr::Column(
agg.to_field(self.input.schema())
.map(|f| f.qualified_column())?,
);
Ok(col)
} else {
Ok(expr)
}
}
}
// Joins, and Unions are not yet handled (should have a projection
// on top of them)
if plan_inputs.len() == 1 {
let proj_exprs = plan.expressions();
rewrite_in_terms_of_projection(expr, proj_exprs, plan_inputs[0])
} else {
Ok(expr)
}
}

let distinct_group_exprs = grouping_set_to_exprlist(group_expr.as_slice())?;
expr.rewrite(&mut Rewriter {
plan,
input,
aggr_expr,
distinct_group_exprs: &distinct_group_exprs,
})
/// Rewrites a sort expression in terms of the output of the previous [`LogicalPlan`]
///
/// Example:
///
/// Given an input expression such as `col(a) + col(b) + col(c)`
///
/// into `col(a) + col("b + c")`
///
/// Remember that:
/// 1. given a projection with exprs: [a, b + c]
/// 2. t produces an output schema with two columns "a", "b + c"
fn rewrite_in_terms_of_projection(
expr: Expr,
proj_exprs: Vec<Expr>,
input: &LogicalPlan,
) -> Result<Expr> {
// assumption is that each item in exprs, such as "b + c" is
// available as an output column named "b + c"
rewrite_expr(expr, |expr| {
// search for unnormalized names first such as "c1" (such as aliases)
if let Some(found) = proj_exprs.iter().find(|a| (**a) == expr) {
let col = Expr::Column(
found
.to_field(input.schema())
.map(|f| f.qualified_column())?,
);
return Ok(col);
}
LogicalPlan::Projection(_) => rewrite_sort_col_by_aggs(expr, plan.inputs()[0]),
_ => Ok(expr),

// if that doesn't work, try to match the expression as an
// output column -- however first it must be "normalized"
// (e.g. "c1" --> "t.c1") because that normalization is done
// at the input of the aggregate.

let normalized_expr = if let Ok(e) = normalize_col(expr.clone(), input) {
e
} else {
// The expr is not based on Aggregate plan output. Skip it.
return Ok(expr);
};

// expr is an actual expr like min(t.c2), but we are looking
// for a column with the same "MIN(C2)", so translate there
let name = normalized_expr.display_name()?;

let search_col = Expr::Column(Column {
relation: None,
name,
});

// look for the column named the same as this expr
if let Some(found) = proj_exprs.iter().find(|a| expr_match(&search_col, a)) {
return Ok((*found).clone());
}
Ok(expr)
})
}

/// Does the underlying expr match e?
/// so avg(c) as average will match avgc
fn expr_match(needle: &Expr, haystack: &Expr) -> bool {
// check inside aliases
if let Expr::Alias(haystack, _) = &haystack {
haystack.as_ref() == needle
} else {
haystack == needle
}
}

Expand All @@ -115,7 +141,7 @@ mod test {
use arrow::datatypes::{DataType, Field, Schema};

use crate::{
col, lit, logical_plan::builder::LogicalTableSource, min, LogicalPlanBuilder,
avg, col, lit, logical_plan::builder::LogicalTableSource, min, LogicalPlanBuilder,
};

use super::*;
Expand All @@ -136,24 +162,24 @@ mod test {

let cases = vec![
TestCase {
desc: "c1 --> t.c1",
desc: "c1 --> c1",
input: sort(col("c1")),
expected: sort(col("t.c1")),
expected: sort(col("c1")),
},
TestCase {
desc: "c1 + c2 --> t.c1 + t.c2a",
desc: "c1 + c2 --> c1 + c2",
input: sort(col("c1") + col("c1")),
expected: sort(col("t.c1") + col("t.c1")),
expected: sort(col("c1") + col("c1")),
},
TestCase {
desc: r#"min(c2) --> "MIN(t.c2)" (column *named* "min(t.c2)"!)"#,
desc: r#"min(c2) --> "min(c2)"#,
input: sort(min(col("c2"))),
expected: sort(col("MIN(t.c2)")),
expected: sort(min(col("c2"))),
},
TestCase {
desc: r#"c1 + min(c2) --> "t.c1 + MIN(t.c2)" (column *named* "min(t.c2)"!)"#,
desc: r#"c1 + min(c2) --> "c1 + min(c2)"#,
input: sort(col("c1") + min(col("c2"))),
expected: sort(col("t.c1") + col("MIN(t.c2)")),
expected: sort(col("c1") + min(col("c2"))),
},
];

Expand All @@ -168,8 +194,8 @@ mod test {
.aggregate(
// gby c1
vec![col("c1")],
// agg: min(2)
vec![min(col("c2"))],
// agg: min(c2), avg(c3)
vec![min(col("c2")), avg(col("c3"))],
)
.unwrap()
// projects out an expression "c1" that is different than the column "c1"
Expand All @@ -178,6 +204,8 @@ mod test {
col("c1").add(lit(1)).alias("c1"),
// min(c2)
min(col("c2")),
// avg("c3") as average
avg(col("c3")).alias("average"),
])
.unwrap()
.build()
Expand All @@ -187,9 +215,8 @@ mod test {
TestCase {
desc: "c1 --> c1 -- column *named* c1 that came out of the projection, (not t.c1)",
input: sort(col("c1")),
// Incorrect due to https://github.com/apache/arrow-datafusion/issues/4854
// should be "c1" not t.c1
expected: sort(col("t.c1")),
expected: sort(col("c1")),
},
TestCase {
desc: r#"min(c2) --> "MIN(c2)" -- (column *named* "min(t.c2)"!)"#,
Expand All @@ -199,10 +226,14 @@ mod test {
TestCase {
desc: r#"c1 + min(c2) --> "c1 + MIN(c2)" -- (column *named* "min(t.c2)"!)"#,
input: sort(col("c1") + min(col("c2"))),
// Incorrect due to https://github.com/apache/arrow-datafusion/issues/4854
// should be "c1" not t.c1
expected: sort(col("t.c1") + col("MIN(t.c2)")),
}
expected: sort(col("c1") + col("MIN(t.c2)")),
},
TestCase {
desc: r#"avg(c3) --> "AVG(t.c3)" as average (column *named* "AVG(t.c3)", aliased)"#,
input: sort(avg(col("c3"))),
expected: sort(col("AVG(t.c3)").alias("average")),
},
];

for case in cases {
Expand Down
3 changes: 3 additions & 0 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ impl ExprSchemable for Expr {
}

/// Returns a [arrow::datatypes::Field] compatible with this expression.
///
/// So for example, a projected expression `col(c1) + col(c2)` is
/// placed in an output field **named** col("c1 + c2")
fn to_field(&self, input_schema: &DFSchema) -> Result<DFField> {
match self {
Expr::Column(c) => Ok(DFField::new(
Expand Down
2 changes: 2 additions & 0 deletions datafusion/sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,6 @@ log = "^0.4"
sqlparser = "0.30"

[dev-dependencies]
ctor = "0.1.22"
env_logger = "0.10"
rstest = "0.16"
Loading

0 comments on commit 6eb0e36

Please sign in to comment.