Skip to content

Commit

Permalink
Support grouping aliases in the order by clause
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Jan 27, 2023
1 parent 4fb5365 commit 2cd0a0c
Show file tree
Hide file tree
Showing 8 changed files with 266 additions and 57 deletions.
95 changes: 95 additions & 0 deletions datafusion/core/tests/sqllogictests/test_files/order.slt
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,98 @@ SELECT c1, c2 FROM test ORDER BY c1 DESC, c2 ASC
query RC
SELECT c1, c2 FROM test WHERE c1 > 100000 ORDER BY c1 DESC, c2 ASC
----

#####
# Sorting and Grouping
#####
statement ok
create table foo as values (1, 2), (3, 4), (5, 6);

query ok rowsort
select * from foo
----
1 2
3 4
5 6

query ok
select column1 from foo order by column2;
----
1
3
5

query ok
select column1 from foo order by column1 + column2;
----
1
3
5

query ok
select column1 from foo order by column1 + column2;
----
1
3
5

query ok rowsort
select column1 + column2 from foo group by column1, column2;
----
11
3
7

query ok
select column1 + column2 from foo group by column1, column2 ORDER BY column2 desc;
----
11
7
3


# Cleanup
statement ok
drop table foo;


#####
# Tests for https://github.com/apache/arrow-datafusion/issues/4854
# Ordering / grouping by the same column
#####
statement ok
create or replace table t as select column1 as value, column2 as time from (select * from (values
(1, timestamp '2022-01-01 00:00:30'),
(2, timestamp '2022-01-01 01:00:10'),
(3, timestamp '2022-01-02 00:00:20')
) as sq) as sq


query I,I rowsort
select
sum(value) AS "value",
date_trunc('hour',time) AS "time"
FROM t
GROUP BY time;
----
1 2022-01-01T00:00:00
2 2022-01-01T01:00:00
3 2022-01-02T00:00:00

# should work fine
query I,I
select
sum(value) AS "value",
date_trunc('minute',time) AS "time"
FROM t
GROUP BY time
ORDER BY time;
----
1 2022-01-01T00:00:00
2 2022-01-01T01:00:00
3 2022-01-02T00:00:00


## Cleanup
statement ok
drop table t;
4 changes: 4 additions & 0 deletions datafusion/expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,7 @@ arrow = { version = "31.0.0", default-features = false }
datafusion-common = { path = "../common", version = "17.0.0" }
log = "^0.4"
sqlparser = "0.30"

[dev-dependencies]
ctor = "0.1.22"
env_logger = "0.10"
5 changes: 5 additions & 0 deletions datafusion/expr/src/expr_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,11 @@ mod test {
use arrow::datatypes::DataType;
use datafusion_common::{DFField, DFSchema, ScalarValue};

// Test-only logging setup: the `ctor` attribute registers `init` to run
// automatically when the test binary is loaded, so `log` output (controlled
// via `RUST_LOG`) is available without each test having to set it up.
#[ctor::ctor]
fn init() {
// `try_init` returns an error if a global logger has already been installed
// (e.g. another initializer got there first); that is harmless here, so the
// result is deliberately discarded.
let _ = env_logger::try_init();
}

#[derive(Default)]
struct RecordingRewriter {
v: Vec<String>,
Expand Down
185 changes: 128 additions & 57 deletions datafusion/expr/src/expr_rewriter/order_by.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
//! Rewrite for order by expressions

use crate::expr::Sort;
use crate::expr_rewriter::{normalize_col, ExprRewritable, ExprRewriter};
use crate::expr_rewriter::{normalize_col, rewrite_expr};
use crate::logical_plan::Aggregate;
use crate::utils::grouping_set_to_exprlist;
use crate::{Expr, ExprSchemable, LogicalPlan};
use datafusion_common::Result;
use crate::{Expr, ExprSchemable, LogicalPlan, Projection};
use datafusion_common::{Column, Result};

/// Rewrite sort on aggregate expressions to sort on the column of aggregate output
/// For example, `max(x)` is written to `col("MAX(x)")`
Expand Down Expand Up @@ -55,55 +55,121 @@ pub fn rewrite_sort_cols_by_aggs(

fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
match plan {
LogicalPlan::Aggregate(Aggregate {
input,
aggr_expr,
group_expr,
..
}) => {
struct Rewriter<'a> {
plan: &'a LogicalPlan,
input: &'a LogicalPlan,
aggr_expr: &'a Vec<Expr>,
distinct_group_exprs: &'a Vec<Expr>,
}
LogicalPlan::Aggregate(aggregate) => {
rewrite_in_terms_of_aggregate(expr, plan, aggregate)
}
LogicalPlan::Projection(projection) => {
rewrite_in_terms_of_projection(expr, projection)
}
_ => Ok(expr),
}
}

impl<'a> ExprRewriter for Rewriter<'a> {
fn mutate(&mut self, expr: Expr) -> Result<Expr> {
let normalized_expr = normalize_col(expr.clone(), self.plan);
if normalized_expr.is_err() {
// The expr is not based on Aggregate plan output. Skip it.
return Ok(expr);
}
let normalized_expr = normalized_expr?;
if let Some(found_agg) = self
.aggr_expr
.iter()
.chain(self.distinct_group_exprs)
.find(|a| (**a) == normalized_expr)
{
let agg = normalize_col(found_agg.clone(), self.plan)?;
let col = Expr::Column(
agg.to_field(self.input.schema())
.map(|f| f.qualified_column())?,
);
Ok(col)
} else {
Ok(expr)
}
}
}
/// Rewrites a sort expression in terms of the output of an [`Aggregate`].
fn rewrite_in_terms_of_aggregate(
expr: Expr,
// the LogicalPlan::Aggregate
plan: &LogicalPlan,
aggregate: &Aggregate,
) -> Result<Expr> {
let Aggregate {
input,
aggr_expr,
group_expr,
..
} = aggregate;

let distinct_group_exprs = grouping_set_to_exprlist(group_expr.as_slice())?;
expr.rewrite(&mut Rewriter {
plan,
input,
aggr_expr,
distinct_group_exprs: &distinct_group_exprs,
})
let distinct_group_exprs = grouping_set_to_exprlist(group_expr.as_slice())?;

rewrite_expr(expr, |expr| {
// normalize in terms of the input plan
let normalized_expr = normalize_col(expr.clone(), plan);
if normalized_expr.is_err() {
// The expr is not based on Aggregate plan output. Skip it.
return Ok(expr);
}
LogicalPlan::Projection(_) => rewrite_sort_col_by_aggs(expr, plan.inputs()[0]),
_ => Ok(expr),
let normalized_expr = normalized_expr?;
if let Some(found_agg) = aggr_expr
.iter()
.chain(distinct_group_exprs.iter())
.find(|a| (**a) == normalized_expr)
{
let agg = normalize_col(found_agg.clone(), plan)?;
let col =
Expr::Column(agg.to_field(input.schema()).map(|f| f.qualified_column())?);
Ok(col)
} else {
Ok(expr)
}
})
}

/// Rewrites a sort expression so that it refers to the output columns of a
/// [`Projection`] rather than to the expressions that produced them.
///
/// For example, given a projection with exprs `[a, b + c]`, the input
/// expression `a + b + c` is rewritten into `col("a") + col("b + c")`.
///
/// Remember that:
/// 1. given a projection with exprs: `[a, b + c]`
/// 2. it produces an output schema with two columns: `"a"` and `"b + c"`
///
/// Sub-expressions that cannot be matched against any projected expression
/// are returned unchanged.
fn rewrite_in_terms_of_projection(expr: Expr, projection: &Projection) -> Result<Expr> {
    let proj_exprs = &projection.expr;
    let input = &projection.input;

    // The assumption is that each projected item, such as "b + c", is
    // available as an output column named "b + c".
    rewrite_expr(expr, |expr| {
        // First pass: look for a projected expression that is exactly equal
        // to this (un-normalized) expression, e.g. an alias such as "c1".
        if let Some(exact) = proj_exprs.iter().find(|candidate| **candidate == expr) {
            let field = exact.to_field(input.schema())?;
            return Ok(Expr::Column(field.qualified_column()));
        }

        // Second pass: try to match the expression against an output column
        // by name. The expression must first be normalized against the
        // projection's input (e.g. "c1" --> "t.c1") so that its display name
        // lines up with the names used by the projected expressions.
        let normalized = match normalize_col(expr.clone(), input) {
            Ok(normalized) => normalized,
            // The expr cannot be resolved against the projection's input,
            // so there is nothing to rewrite. Leave it as is.
            Err(_) => return Ok(expr),
        };

        // `expr` is an actual expression such as `min(t.c2)`, whereas the
        // projection refers to it as a column *named* "MIN(t.c2)". Build an
        // unqualified column with that name to search for.
        let search_col = Expr::Column(Column {
            relation: None,
            name: normalized.display_name()?,
        });

        // Use the projected expression (alias included) whose underlying
        // expression is that column; otherwise keep the original expression.
        match proj_exprs
            .iter()
            .find(|candidate| expr_match(&search_col, candidate))
        {
            Some(matched) => Ok(matched.clone()),
            None => Ok(expr),
        }
    })
}

/// Returns true if `haystack`, ignoring one level of aliasing, equals `needle`.
///
/// A `haystack` of the form `<expr> AS <alias>` matches when `<expr>` equals
/// `needle`; any other `haystack` must equal `needle` directly. For example,
/// `avg(c) AS average` matches a needle of `avg(c)`.
fn expr_match(needle: &Expr, haystack: &Expr) -> bool {
    // Look through the alias (if any) to the expression it wraps.
    let underlying: &Expr = match haystack {
        Expr::Alias(inner, _) => inner.as_ref(),
        other => other,
    };
    underlying == needle
}

Expand All @@ -115,7 +181,7 @@ mod test {
use arrow::datatypes::{DataType, Field, Schema};

use crate::{
col, lit, logical_plan::builder::LogicalTableSource, min, LogicalPlanBuilder,
avg, col, lit, logical_plan::builder::LogicalTableSource, min, LogicalPlanBuilder,
};

use super::*;
Expand Down Expand Up @@ -168,8 +234,8 @@ mod test {
.aggregate(
// gby c1
vec![col("c1")],
// agg: min(2)
vec![min(col("c2"))],
// agg: min(c2), avg(c3)
vec![min(col("c2")), avg(col("c3"))],
)
.unwrap()
// projects out an expression "c1" that is different than the column "c1"
Expand All @@ -178,6 +244,8 @@ mod test {
col("c1").add(lit(1)).alias("c1"),
// min(c2)
min(col("c2")),
// avg("c3") as average
avg(col("c3")).alias("average"),
])
.unwrap()
.build()
Expand All @@ -187,9 +255,8 @@ mod test {
TestCase {
desc: "c1 --> c1 -- column *named* c1 that came out of the projection, (not t.c1)",
input: sort(col("c1")),
// Incorrect due to https://github.com/apache/arrow-datafusion/issues/4854
// should be "c1" not t.c1
expected: sort(col("t.c1")),
expected: sort(col("c1")),
},
TestCase {
desc: r#"min(c2) --> "MIN(c2)" -- (column *named* "min(t.c2)"!)"#,
Expand All @@ -199,10 +266,14 @@ mod test {
TestCase {
desc: r#"c1 + min(c2) --> "c1 + MIN(c2)" -- (column *named* "min(t.c2)"!)"#,
input: sort(col("c1") + min(col("c2"))),
// Incorrect due to https://github.com/apache/arrow-datafusion/issues/4854
// should be "c1" not t.c1
expected: sort(col("t.c1") + col("MIN(t.c2)")),
}
expected: sort(col("c1") + col("MIN(t.c2)")),
},
TestCase {
desc: r#"avg(c3) --> "AVG(t.c3)" as average (column *named* "AVG(t.c3)", aliased)"#,
input: sort(avg(col("c3"))),
expected: sort(col("AVG(t.c3)").alias("average")),
},
];

for case in cases {
Expand Down
3 changes: 3 additions & 0 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ impl ExprSchemable for Expr {
}

/// Returns a [arrow::datatypes::Field] compatible with this expression.
///
/// For example, the projected expression `col("c1") + col("c2")` is
/// placed in an output field **named** `c1 + c2`, which can then be
/// referenced as `col("c1 + c2")`.
fn to_field(&self, input_schema: &DFSchema) -> Result<DFField> {
match self {
Expr::Column(c) => Ok(DFField::new(
Expand Down
2 changes: 2 additions & 0 deletions datafusion/sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,6 @@ log = "^0.4"
sqlparser = "0.30"

[dev-dependencies]
ctor = "0.1.22"
env_logger = "0.10"
rstest = "*"
7 changes: 7 additions & 0 deletions datafusion/sql/src/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use datafusion_expr::Expr::Alias;
use datafusion_expr::{
Expr, Filter, GroupingSet, LogicalPlan, LogicalPlanBuilder, Partitioning,
};
use log::debug;
use sqlparser::ast::{Expr as SQLExpr, WildcardAdditionalOptions};
use sqlparser::ast::{Select, SelectItem, TableWithJoins};
use std::collections::HashSet;
Expand Down Expand Up @@ -87,6 +88,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// having and group by clause may reference aliases defined in select projection
let projected_plan = self.project(plan.clone(), select_exprs.clone())?;
let mut combined_schema = (**projected_plan.schema()).clone();

debug!("AAL plan: {plan:?}");
debug!("projected_plan: {projected_plan:?}");
debug!("plan_schema: {:#?}", plan.schema());
debug!("combined_schema: {combined_schema:#?}");

combined_schema.merge(plan.schema());

// this alias map is resolved and looked up in both having exprs and group by exprs
Expand Down
Loading

0 comments on commit 2cd0a0c

Please sign in to comment.