From 269563afd00df40653439afcf6b372438a4d0ae6 Mon Sep 17 00:00:00 2001 From: Devin D'Angelo Date: Mon, 18 Mar 2024 09:47:53 -0400 Subject: [PATCH] Improve Robustness of Unparser Testing and Implementation (#9623) * support having, simplify logic * fmt * expr rewriting and more... * lint * add license * retry windows ci * retry windows ci 2 * subquery expr support * make test even harder * retry windows * cargo fmt * retry windows * retry windows * retry windows (last try) * add comment explaining test --- datafusion/sql/src/unparser/expr.rs | 46 +++++++-- datafusion/sql/src/unparser/mod.rs | 1 + datafusion/sql/src/unparser/plan.rs | 87 +++++++--------- datafusion/sql/src/unparser/utils.rs | 84 +++++++++++++++ datafusion/sql/tests/sql_integration.rs | 130 ++++++++++++++---------- 5 files changed, 232 insertions(+), 116 deletions(-) create mode 100644 datafusion/sql/src/unparser/utils.rs diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 403a7c6193d0..9680177d736f 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -18,7 +18,7 @@ use arrow_array::{Date32Array, Date64Array}; use arrow_schema::DataType; use datafusion_common::{ - internal_datafusion_err, not_impl_err, Column, Result, ScalarValue, + internal_datafusion_err, not_impl_err, plan_err, Column, Result, ScalarValue, }; use datafusion_expr::{ expr::{AggregateFunctionDefinition, Alias, InList, ScalarFunction, WindowFunction}, @@ -40,7 +40,7 @@ use super::Unparser; /// let expr = col("a").gt(lit(4)); /// let sql = expr_to_sql(&expr).unwrap(); /// -/// assert_eq!(format!("{}", sql), "(a > 4)") +/// assert_eq!(format!("{}", sql), "(\"a\" > 4)") /// ``` pub fn expr_to_sql(expr: &Expr) -> Result { let unparser = Unparser::default(); @@ -151,6 +151,36 @@ impl Unparser<'_> { order_by: vec![], })) } + Expr::ScalarSubquery(subq) => { + let sub_statement = self.plan_to_sql(subq.subquery.as_ref())?; + let sub_query = if let 
ast::Statement::Query(inner_query) = sub_statement + { + inner_query + } else { + return plan_err!( + "Subquery must be a Query, but found {sub_statement:?}" + ); + }; + Ok(ast::Expr::Subquery(sub_query)) + } + Expr::InSubquery(insubq) => { + let inexpr = Box::new(self.expr_to_sql(insubq.expr.as_ref())?); + let sub_statement = + self.plan_to_sql(insubq.subquery.subquery.as_ref())?; + let sub_query = if let ast::Statement::Query(inner_query) = sub_statement + { + inner_query + } else { + return plan_err!( + "Subquery must be a Query, but found {sub_statement:?}" + ); + }; + Ok(ast::Expr::InSubquery { + expr: inexpr, + subquery: sub_query, + negated: insubq.negated, + }) + } _ => not_impl_err!("Unsupported expression: {expr:?}"), } } @@ -169,7 +199,7 @@ impl Unparser<'_> { pub(super) fn new_ident(&self, str: String) -> ast::Ident { ast::Ident { value: str, - quote_style: self.dialect.identifier_quote_style(), + quote_style: Some(self.dialect.identifier_quote_style().unwrap_or('"')), } } @@ -491,28 +521,28 @@ mod tests { #[test] fn expr_to_sql_ok() -> Result<()> { let tests: Vec<(Expr, &str)> = vec![ - ((col("a") + col("b")).gt(lit(4)), r#"((a + b) > 4)"#), + ((col("a") + col("b")).gt(lit(4)), r#"(("a" + "b") > 4)"#), ( Expr::Column(Column { relation: Some(TableReference::partial("a", "b")), name: "c".to_string(), }) .gt(lit(4)), - r#"(a.b.c > 4)"#, + r#"("a"."b"."c" > 4)"#, ), ( Expr::Cast(Cast { expr: Box::new(col("a")), data_type: DataType::Date64, }), - r#"CAST(a AS DATETIME)"#, + r#"CAST("a" AS DATETIME)"#, ), ( Expr::Cast(Cast { expr: Box::new(col("a")), data_type: DataType::UInt32, }), - r#"CAST(a AS INTEGER UNSIGNED)"#, + r#"CAST("a" AS INTEGER UNSIGNED)"#, ), ( Expr::Literal(ScalarValue::Date64(Some(0))), @@ -549,7 +579,7 @@ mod tests { order_by: None, null_treatment: None, }), - "SUM(a)", + r#"SUM("a")"#, ), ( Expr::AggregateFunction(AggregateFunction { diff --git a/datafusion/sql/src/unparser/mod.rs b/datafusion/sql/src/unparser/mod.rs index 
e67ebc198018..fb0285901c3f 100644 --- a/datafusion/sql/src/unparser/mod.rs +++ b/datafusion/sql/src/unparser/mod.rs @@ -18,6 +18,7 @@ mod ast; mod expr; mod plan; +mod utils; pub use expr::expr_to_sql; pub use plan::plan_to_sql; diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 8d9e0b1a6ebb..e1f5135efda9 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -17,13 +17,16 @@ use datafusion_common::{internal_err, not_impl_err, plan_err, DataFusionError, Result}; use datafusion_expr::{expr::Alias, Expr, JoinConstraint, JoinType, LogicalPlan}; -use sqlparser::ast::{self, Ident, SelectItem}; +use sqlparser::ast::{self}; + +use crate::unparser::utils::unproject_agg_exprs; use super::{ ast::{ BuilderError, DerivedRelationBuilder, QueryBuilder, RelationBuilder, SelectBuilder, TableRelationBuilder, TableWithJoinsBuilder, }, + utils::find_agg_node_within_select, Unparser, }; @@ -49,7 +52,7 @@ use super::{ /// .unwrap(); /// let sql = plan_to_sql(&plan).unwrap(); /// -/// assert_eq!(format!("{}", sql), "SELECT table.id, table.value FROM table") +/// assert_eq!(format!("{}", sql), "SELECT \"table\".\"id\", \"table\".\"value\" FROM \"table\"") /// ``` pub fn plan_to_sql(plan: &LogicalPlan) -> Result { let unparser = Unparser::default(); @@ -132,43 +135,16 @@ impl Unparser<'_> { // A second projection implies a derived tablefactor if !select.already_projected() { // Special handling when projecting an agregation plan - if let LogicalPlan::Aggregate(agg) = p.input.as_ref() { - let mut items = p - .expr - .iter() - .filter(|e| !matches!(e, Expr::AggregateFunction(_))) - .map(|e| self.select_item_to_sql(e)) - .collect::>>()?; - - let proj_aggs = p + if let Some(agg) = find_agg_node_within_select(plan, true) { + let items = p .expr .iter() - .filter(|e| matches!(e, Expr::AggregateFunction(_))) - .zip(agg.aggr_expr.iter()) - .map(|(proj, agg_exp)| { - let sql_agg_expr = 
self.select_item_to_sql(agg_exp)?; - let maybe_aliased = - if let Expr::Alias(Alias { name, .. }) = proj { - if let SelectItem::UnnamedExpr(aggregation_fun) = - sql_agg_expr - { - SelectItem::ExprWithAlias { - expr: aggregation_fun, - alias: Ident { - value: name.to_string(), - quote_style: None, - }, - } - } else { - sql_agg_expr - } - } else { - sql_agg_expr - }; - Ok(maybe_aliased) + .map(|proj_expr| { + let unproj = unproject_agg_exprs(proj_expr, agg)?; + self.select_item_to_sql(&unproj) }) .collect::>>()?; - items.extend(proj_aggs); + select.projection(items); select.group_by(ast::GroupByExpr::Expressions( agg.group_expr @@ -176,12 +152,6 @@ impl Unparser<'_> { .map(|expr| self.expr_to_sql(expr)) .collect::>>()?, )); - self.select_to_sql_recursively( - agg.input.as_ref(), - query, - select, - relation, - ) } else { let items = p .expr @@ -189,13 +159,13 @@ impl Unparser<'_> { .map(|e| self.select_item_to_sql(e)) .collect::>>()?; select.projection(items); - self.select_to_sql_recursively( - p.input.as_ref(), - query, - select, - relation, - ) } + self.select_to_sql_recursively( + p.input.as_ref(), + query, + select, + relation, + ) } else { let mut derived_builder = DerivedRelationBuilder::default(); derived_builder.lateral(false).alias(None).subquery({ @@ -213,9 +183,16 @@ impl Unparser<'_> { } } LogicalPlan::Filter(filter) => { - let filter_expr = self.expr_to_sql(&filter.predicate)?; - - select.selection(Some(filter_expr)); + if let Some(agg) = + find_agg_node_within_select(plan, select.already_projected()) + { + let unprojected = unproject_agg_exprs(&filter.predicate, agg)?; + let filter_expr = self.expr_to_sql(&unprojected)?; + select.having(Some(filter_expr)); + } else { + let filter_expr = self.expr_to_sql(&filter.predicate)?; + select.selection(Some(filter_expr)); + } self.select_to_sql_recursively( filter.input.as_ref(), @@ -249,9 +226,13 @@ impl Unparser<'_> { relation, ) } - LogicalPlan::Aggregate(_agg) => { - not_impl_err!( - "Unsupported 
aggregation plan not following a projection: {plan:?}" + LogicalPlan::Aggregate(agg) => { + // Aggregate nodes are handled simultaneously with Projection nodes + self.select_to_sql_recursively( + agg.input.as_ref(), + query, + select, + relation, ) } LogicalPlan::Distinct(_distinct) => { diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs new file mode 100644 index 000000000000..9d098c494599 --- /dev/null +++ b/datafusion/sql/src/unparser/utils.rs @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::{ + internal_err, + tree_node::{Transformed, TreeNode}, + Result, +}; +use datafusion_expr::{Aggregate, Expr, LogicalPlan}; + +/// Recursively searches children of [LogicalPlan] to find an Aggregate node if one exists +/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor). +/// If an Aggregate node is not found prior to this or at all before reaching the end +/// of the tree, None is returned. 
+pub(crate) fn find_agg_node_within_select( + plan: &LogicalPlan, + already_projected: bool, +) -> Option<&Aggregate> { + // Note that none of the nodes that have a corresponding agg node can have more + // than 1 input node. E.g. Projection / Filter always have 1 input node. + let input = plan.inputs(); + let input = if input.len() > 1 { + return None; + } else { + input.first()? + }; + if let LogicalPlan::Aggregate(agg) = input { + Some(agg) + } else if let LogicalPlan::TableScan(_) = input { + None + } else if let LogicalPlan::Projection(_) = input { + if already_projected { + None + } else { + find_agg_node_within_select(input, true) + } + } else { + find_agg_node_within_select(input, already_projected) + } +} + +/// Recursively identify all Column expressions and transform them into the appropriate +/// aggregate expression contained in agg. +/// +/// For example, if expr contains the column expr "COUNT(*)" it will be transformed +/// into an actual aggregate expression COUNT(*) as identified in the aggregate node. +pub(crate) fn unproject_agg_exprs(expr: &Expr, agg: &Aggregate) -> Result { + expr.clone() + .transform(&|sub_expr| { + if let Expr::Column(c) = sub_expr { + // find the column in the agg schema + if let Ok(n) = agg.schema.index_of_column(&c) { + let unprojected_expr = agg + .group_expr + .iter() + .chain(agg.aggr_expr.iter()) + .nth(n) + .unwrap(); + Ok(Transformed::yes(unprojected_expr.clone())) + } else { + internal_err!( + "Tried to unproject agg expr not found in provided Aggregate!" 
+ ) + } + } else { + Ok(Transformed::no(sub_expr)) + } + }) + .map(|e| e.data) +} diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 1983f1fa72e3..c9c2bdd694b5 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -2822,6 +2822,20 @@ impl ContextProvider for MockContextProvider { ), Field::new("😀", DataType::Int32, false), ])), + "person_quoted_cols" => Ok(Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("First Name", DataType::Utf8, false), + Field::new("Last Name", DataType::Utf8, false), + Field::new("Age", DataType::Int32, false), + Field::new("State", DataType::Utf8, false), + Field::new("Salary", DataType::Float64, false), + Field::new( + "Birth Date", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ), + Field::new("😀", DataType::Int32, false), + ])), "orders" => Ok(Schema::new(vec![ Field::new("order_id", DataType::UInt32, false), Field::new("customer_id", DataType::UInt32, false), @@ -4491,17 +4505,25 @@ impl TableSource for EmptyTable { #[test] fn roundtrip_expr() { let tests: Vec<(TableReference, &str, &str)> = vec![ - (TableReference::bare("person"), "age > 35", "(age > 35)"), - (TableReference::bare("person"), "id = '10'", "(id = '10')"), + ( + TableReference::bare("person"), + "age > 35", + r#"("age" > 35)"#, + ), + ( + TableReference::bare("person"), + "id = '10'", + r#"("id" = '10')"#, + ), ( TableReference::bare("person"), "CAST(id AS VARCHAR)", - "CAST(id AS VARCHAR)", + r#"CAST("id" AS VARCHAR)"#, ), ( TableReference::bare("person"), "SUM((age * 2))", - "SUM((age * 2))", + r#"SUM(("age" * 2))"#, ), ]; @@ -4528,79 +4550,77 @@ fn roundtrip_expr() { } #[test] -fn roundtrip_statement() { - let tests: Vec<(&str, &str)> = vec![ - ( +fn roundtrip_statement() -> Result<()> { + let tests: Vec<&str> = vec![ "select ta.j1_id from j1 ta;", - r#"SELECT ta.j1_id FROM j1 AS ta"#, - ), - ( "select ta.j1_id from j1 ta order by 
ta.j1_id;", - r#"SELECT ta.j1_id FROM j1 AS ta ORDER BY ta.j1_id ASC NULLS LAST"#, - ), - ( "select * from j1 ta order by ta.j1_id, ta.j1_string desc;", - r#"SELECT ta.j1_id, ta.j1_string FROM j1 AS ta ORDER BY ta.j1_id ASC NULLS LAST, ta.j1_string DESC NULLS FIRST"#, - ), - ( "select * from j1 limit 10;", - r#"SELECT j1.j1_id, j1.j1_string FROM j1 LIMIT 10"#, - ), - ( "select ta.j1_id from j1 ta where ta.j1_id > 1;", - r#"SELECT ta.j1_id FROM j1 AS ta WHERE (ta.j1_id > 1)"#, - ), - ( "select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id);", - r#"SELECT ta.j1_id, tb.j2_string FROM j1 AS ta JOIN j2 AS tb ON (ta.j1_id = tb.j2_id)"#, - ), - ( "select ta.j1_id, tb.j2_string, tc.j3_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id) join j3 tc on (ta.j1_id = tc.j3_id);", - r#"SELECT ta.j1_id, tb.j2_string, tc.j3_string FROM j1 AS ta JOIN j2 AS tb ON (ta.j1_id = tb.j2_id) JOIN j3 AS tc ON (ta.j1_id = tc.j3_id)"#, - ), - ( "select * from (select id, first_name from person)", - "SELECT person.id, person.first_name FROM (SELECT person.id, person.first_name FROM person)" - ), - ( "select * from (select id, first_name from (select * from person))", - "SELECT person.id, person.first_name FROM (SELECT person.id, person.first_name FROM (SELECT person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀 FROM person))" - ), - ( "select id, count(*) as cnt from (select id from person) group by id", - "SELECT person.id, COUNT(*) AS cnt FROM (SELECT person.id FROM person) GROUP BY person.id" - ), - ( + "select (id-1)/2, count(*) / (sum(id/10)-1) as agg_expr from (select (id-1) as id from person) group by id", + r#"select "First Name" from person_quoted_cols"#, + r#"select id, count("First Name") as cnt from (select id, "First Name" from person_quoted_cols) group by id"#, "select id, count(*) as cnt from (select p1.id as id from person p1 inner join person p2 on p1.id=p2.id) group by id", - "SELECT 
p1.id, COUNT(*) AS cnt FROM (SELECT p1.id FROM person AS p1 JOIN person AS p2 ON (p1.id = p2.id)) GROUP BY p1.id" - ), - ( "select id, count(*), first_name from person group by first_name, id", - "SELECT person.id, COUNT(*), person.first_name FROM person GROUP BY person.first_name, person.id" - ), - ]; + "select id, sum(age), first_name from person group by first_name, id", + "select id, count(*), first_name + from person + where id!=3 and first_name=='test' + group by first_name, id + having count(*)>5 and count(*)<10 + order by count(*)", + r#"select id, count("First Name") as count_first_name, "Last Name" + from person_quoted_cols + where id!=3 and "First Name"=='test' + group by "Last Name", id + having count_first_name>5 and count_first_name<10 + order by count_first_name, "Last Name""#, + r#"select p.id, count("First Name") as count_first_name, + "Last Name", sum(qp.id/p.id - (select sum(id) from person_quoted_cols) ) / (select count(*) from person) + from (select id, "First Name", "Last Name" from person_quoted_cols) qp + inner join (select * from person) p + on p.id = qp.id + where p.id!=3 and "First Name"=='test' and qp.id in + (select id from (select id, count(*) from person group by id having count(*) > 0)) + group by "Last Name", p.id + having count_first_name>5 and count_first_name<10 + order by count_first_name, "Last Name""#, + ]; - let roundtrip = |sql: &str| -> Result { + // For each test sql string, we transform as follows: + // sql -> ast::Statement (s1) -> LogicalPlan (p1) -> ast::Statement (s2) -> LogicalPlan (p2) + // We test not that s1==s2, but rather p1==p2. This ensures that unparser preserves the logical + // query information of the original sql string and disregards other differences in syntax or + // quoting. + for query in tests { let dialect = GenericDialect {}; - let statement = Parser::new(&dialect).try_with_sql(sql)?.parse_statement()?; + let statement = Parser::new(&dialect) + .try_with_sql(query)? 
+ .parse_statement()?; let context = MockContextProvider::default(); let sql_to_rel = SqlToRel::new(&context); - let plan = sql_to_rel.sql_statement_to_plan(statement)?; + let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); - println!("{}", plan.display_indent()); + let roundtrip_statement = plan_to_sql(&plan)?; - let ast = plan_to_sql(&plan)?; + let actual = format!("{}", &roundtrip_statement); + println!("roundtrip sql: {actual}"); + println!("plan {}", plan.display_indent()); - println!("{ast}"); + let plan_roundtrip = sql_to_rel + .sql_statement_to_plan(roundtrip_statement.clone()) + .unwrap(); - Ok(format!("{}", ast)) - }; - - for (query, expected) in tests { - let actual = roundtrip(query).unwrap(); - assert_eq!(actual, expected); + assert_eq!(plan, plan_roundtrip); } + + Ok(()) } #[cfg(test)]