From 85d53634c4d26a3b9e545878e377860c14d01d7d Mon Sep 17 00:00:00 2001 From: kamille <34352236+Rachelint@users.noreply.github.com> Date: Mon, 8 Aug 2022 21:57:49 +0800 Subject: [PATCH] feat: support double quoted literal strings for dialects(such as mysql,bigquery,spark) (#3056) * support double quoted string. * add test. * add case sensitive test case. * fix naming error. --- datafusion/core/tests/sql/select.rs | 36 ++++++++++++++++++++++++ datafusion/sql/src/planner.rs | 43 +++++++++++++++++++++++++---- 2 files changed, 74 insertions(+), 5 deletions(-) diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index 5056fa02f77d..b48f8d7cc5f8 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -1210,3 +1210,39 @@ async fn unprojected_filter() { ]; assert_batches_sorted_eq!(expected, &results); } + +#[tokio::test] +async fn case_sensitive_in_default_dialect() { + let int32_array = Int32Array::from(vec![1, 2, 3, 4, 5]); + let schema = Schema::new(vec![Field::new("INT32", DataType::Int32, false)]); + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(int32_array)]).unwrap(); + + let ctx = SessionContext::new(); + let table = MemTable::try_new(batch.schema(), vec![vec![batch]]).unwrap(); + ctx.register_table("t", Arc::new(table)).unwrap(); + + { + let sql = "select \"int32\" from t"; + let plan = ctx.create_logical_plan(sql); + assert!(plan.is_err()); + } + + { + let sql = "select \"INT32\" from t"; + let actual = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+-------+", + "| INT32 |", + "+-------+", + "| 1 |", + "| 2 |", + "| 3 |", + "| 4 |", + "| 5 |", + "+-------+", + ]; + assert_batches_eq!(expected, &actual); + } +} diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 8c615a24a3f3..dc71dc4b2f1c 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -98,7 +98,9 @@ fn plan_key(key: SQLExpr) -> Result { 
SQLExpr::Value(Value::Number(s, _)) => { ScalarValue::Int64(Some(s.parse().unwrap())) } - SQLExpr::Value(Value::SingleQuotedString(s)) => ScalarValue::Utf8(Some(s)), + SQLExpr::Value(Value::SingleQuotedString(s) | Value::DoubleQuotedString(s)) => { + ScalarValue::Utf8(Some(s)) + } _ => { return Err(DataFusionError::SQL(ParserError(format!( "Unsuported index key expression: {:?}", @@ -1596,7 +1598,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { row.into_iter() .map(|v| match v { SQLExpr::Value(Value::Number(n, _)) => parse_sql_number(&n), - SQLExpr::Value(Value::SingleQuotedString(s)) => Ok(lit(s)), + SQLExpr::Value( + Value::SingleQuotedString(s) | Value::DoubleQuotedString(s), + ) => Ok(lit(s)), SQLExpr::Value(Value::Null) => { Ok(Expr::Literal(ScalarValue::Null)) } @@ -1638,7 +1642,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { match sql { SQLExpr::Value(Value::Number(n, _)) => parse_sql_number(&n), - SQLExpr::Value(Value::SingleQuotedString(ref s)) => Ok(lit(s.clone())), + SQLExpr::Value(Value::SingleQuotedString(ref s) | Value::DoubleQuotedString(ref s)) => Ok(lit(s.clone())), SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)), SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Null)), SQLExpr::Extract { field, expr } => Ok(Expr::ScalarFunction { @@ -2219,7 +2223,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Only handle string exprs for now let value = match value { - SQLExpr::Value(Value::SingleQuotedString(s)) => s, + SQLExpr::Value( + Value::SingleQuotedString(s) | Value::DoubleQuotedString(s), + ) => s, _ => { return Err(DataFusionError::NotImplemented(format!( "Unsupported interval argument. 
Expected string literal, got: {:?}", @@ -2595,6 +2601,7 @@ fn parse_sql_number(n: &str) -> Result { mod tests { use super::*; use crate::assert_contains; + use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect}; use std::any::Any; #[test] @@ -4371,8 +4378,16 @@ mod tests { } fn logical_plan(sql: &str) -> Result { + let dialect = &GenericDialect {}; + logical_plan_with_dialect(sql, dialect) + } + + fn logical_plan_with_dialect( + sql: &str, + dialect: &dyn Dialect, + ) -> Result { let planner = SqlToRel::new(&MockContextProvider {}); - let result = DFParser::parse_sql(sql); + let result = DFParser::parse_sql_with_dialect(sql, dialect); let mut ast = result?; planner.statement_to_plan(ast.pop_front().unwrap()) } @@ -4840,6 +4855,24 @@ mod tests { quick_test(sql, expected); } + #[test] + fn test_double_quoted_literal_string() { + // Assert double quoted literal string is parsed correctly like single quoted one in specific dialect. + let dialect = &MySqlDialect {}; + let single_quoted_res = format!( + "{:?}", + logical_plan_with_dialect("SELECT '1'", dialect).unwrap() + ); + let double_quoted_res = format!( + "{:?}", + logical_plan_with_dialect("SELECT \"1\"", dialect).unwrap() + ); + assert_eq!(single_quoted_res, double_quoted_res); + + // It should return error in other dialect. + assert!(logical_plan("SELECT \"1\"").is_err()); + } + fn assert_field_not_found(err: DataFusionError, name: &str) { match err { DataFusionError::SchemaError { .. } => {