From 30846e04b780091edf484a6ceccd7346ba5d3178 Mon Sep 17 00:00:00 2001 From: jychen7 Date: Thu, 3 Jun 2021 03:14:15 +0000 Subject: [PATCH] #215 resolve aliases for group by exprs --- datafusion/src/sql/planner.rs | 39 ++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 63499aa1abe2..231f193f083e 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -582,13 +582,26 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // All of the aggregate expressions (deduplicated). let aggr_exprs = find_aggregate_exprs(&aggr_expr_haystack); + let group_by_exprs = select.group_by + .iter() + .map(|e| { + let group_by_expr = self.sql_expr_to_logical_expr(e)?; + let group_by_expr = resolve_aliases_to_exprs( + &group_by_expr, + &extract_aliases(&select_exprs), + )?; + self.validate_schema_satisfies_exprs(plan.schema(), &[group_by_expr.clone()])?; + Ok(group_by_expr) + }) + .collect::<Result<Vec<Expr>>>()?; + let (plan, select_exprs_post_aggr, having_expr_post_aggr_opt) = - if !select.group_by.is_empty() || !aggr_exprs.is_empty() { + if !group_by_exprs.is_empty() || !aggr_exprs.is_empty() { self.aggregate( &plan, &select_exprs, &having_expr_opt, - &select.group_by, + group_by_exprs, aggr_exprs, )? 
} else { @@ -691,14 +704,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { input: &LogicalPlan, select_exprs: &[Expr], having_expr_opt: &Option<Expr>, - group_by: &[SQLExpr], + group_by_exprs: Vec<Expr>, aggr_exprs: Vec<Expr>, ) -> Result<(LogicalPlan, Vec<Expr>, Option<Expr>)> { - let group_by_exprs = group_by - .iter() - .map(|e| self.sql_to_rex(e, &input.schema())) - .collect::<Result<Vec<Expr>>>()?; - let aggr_projection_exprs = group_by_exprs .iter() .chain(aggr_exprs.iter()) @@ -2285,15 +2293,12 @@ mod tests { } #[test] - fn select_simple_aggregate_with_groupby_cannot_use_alias() { - let sql = "SELECT state AS x, MAX(age) FROM person GROUP BY x"; - let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - format!( - "Plan(\"Invalid identifier \\\'x\\\' for schema {}\")", - PERSON_COLUMN_NAMES - ), - format!("{:?}", err) + fn select_simple_aggregate_with_groupby_can_use_alias() { + quick_test( + "SELECT state AS a, MIN(age) AS b FROM person GROUP BY a", + "Projection: #state AS a, #MIN(age) AS b\ + \n Aggregate: groupBy=[[#state]], aggr=[[MIN(#age)]]\ + \n TableScan: person projection=None", ); }