From d1af73f8a5d1c178ad4c39be9dba8cc42ff91009 Mon Sep 17 00:00:00 2001
From: sgrebnov
Date: Thu, 18 Jul 2024 20:11:51 -0700
Subject: [PATCH] Add DateFieldExtractStyle support

---
 datafusion/sql/src/unparser/dialect.rs | 38 ++++++++++
 datafusion/sql/src/unparser/expr.rs    | 99 +++++++++++++++++++++++++-
 2 files changed, 136 insertions(+), 1 deletion(-)

diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs
index 933d336e07945..bc43234206862 100644
--- a/datafusion/sql/src/unparser/dialect.rs
+++ b/datafusion/sql/src/unparser/dialect.rs
@@ -57,6 +57,10 @@ pub trait Dialect: Send + Sync {
     fn use_char_for_utf8_cast(&self) -> bool {
         false
     }
+
+    fn date_field_extract_style(&self) -> DateFieldExtractStyle {
+        DateFieldExtractStyle::DatePart
+    }
 }
 
 /// `IntervalStyle` to use for unparsing
@@ -74,6 +78,19 @@ pub enum IntervalStyle {
     MySQL,
 }
 
+/// Datetime subfield extraction style for unparsing
+///
+/// https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
+/// Different DBMSs follow different standards; popular ones are:
+/// date_part('YEAR', date '2001-02-16')
+/// EXTRACT(YEAR from date '2001-02-16')
+/// Some DBMSs, like Postgres, support both, whereas others like MySQL require EXTRACT.
+#[derive(Clone, Copy, PartialEq)]
+pub enum DateFieldExtractStyle {
+    DatePart,
+    Extract
+}
+
 pub struct DefaultDialect {}
 
 impl Dialect for DefaultDialect {
@@ -123,6 +140,10 @@ impl Dialect for MySqlDialect {
     fn use_char_for_utf8_cast(&self) -> bool {
         true
     }
+
+    fn date_field_extract_style(&self) -> DateFieldExtractStyle {
+        DateFieldExtractStyle::Extract
+    }
 }
 
 pub struct SqliteDialect {}
@@ -140,6 +161,7 @@ pub struct CustomDialect {
     interval_style: IntervalStyle,
     use_double_precision_for_float64: bool,
     use_char_for_utf8_cast: bool,
+    date_subfield_extract_style: DateFieldExtractStyle,
 }
 
 impl Default for CustomDialect {
@@ -151,6 +173,7 @@ impl Default for CustomDialect {
             interval_style: IntervalStyle::SQLStandard,
             use_double_precision_for_float64: false,
             use_char_for_utf8_cast: false,
+            date_subfield_extract_style: DateFieldExtractStyle::DatePart,
         }
     }
 }
@@ -189,6 +212,10 @@ impl Dialect for CustomDialect {
     fn use_char_for_utf8_cast(&self) -> bool {
         self.use_char_for_utf8_cast
     }
+
+    fn date_field_extract_style(&self) -> DateFieldExtractStyle {
+        self.date_subfield_extract_style
+    }
 }
 
 // create a CustomDialectBuilder
@@ -199,6 +226,7 @@ pub struct CustomDialectBuilder {
     interval_style: IntervalStyle,
     use_double_precision_for_float64: bool,
     use_char_for_utf8_cast: bool,
+    date_subfield_extract_style: DateFieldExtractStyle,
 }
 
 impl CustomDialectBuilder {
@@ -210,6 +238,7 @@ impl CustomDialectBuilder {
             interval_style: IntervalStyle::PostgresVerbose,
             use_double_precision_for_float64: false,
             use_char_for_utf8_cast: false,
+            date_subfield_extract_style: DateFieldExtractStyle::DatePart,
         }
     }
 
@@ -221,6 +250,7 @@ impl CustomDialectBuilder {
             interval_style: self.interval_style,
             use_double_precision_for_float64: self.use_double_precision_for_float64,
             use_char_for_utf8_cast: self.use_char_for_utf8_cast,
+            date_subfield_extract_style: self.date_subfield_extract_style,
         }
     }
 
@@ -262,4 +292,12 @@ impl CustomDialectBuilder {
         self.use_char_for_utf8_cast = use_char_for_utf8_cast;
         self
     }
+
+    pub fn with_date_field_extract_style(
+        mut self,
+        date_subfield_extract_style: DateFieldExtractStyle,
+    ) -> Self {
+        self.date_subfield_extract_style = date_subfield_extract_style;
+        self
+    }
 }
diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs
index c19a57f2257de..efaef94ffd6f5 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -25,6 +25,7 @@ use arrow::util::display::array_value_to_string;
 use arrow_array::{Date32Array, Date64Array, PrimitiveArray};
 use arrow_schema::DataType;
 use core::fmt;
+use datafusion_expr::ScalarUDF;
 use sqlparser::ast::TimezoneInfo;
 use sqlparser::ast::Value::SingleQuotedString;
 use sqlparser::ast::{
@@ -42,7 +43,7 @@ use datafusion_expr::{
     Between, BinaryExpr, Case, Cast, Expr, GroupingSet, Like, Operator, TryCast,
 };
 
-use super::dialect::IntervalStyle;
+use super::dialect::{DateFieldExtractStyle, IntervalStyle};
 use super::Unparser;
 
 /// DataFusion's Exprs can represent either an `Expr` or an `OrderByExpr`
@@ -122,6 +123,12 @@ impl Unparser<'_> {
             Expr::ScalarFunction(ScalarFunction { func, args }) => {
                 let func_name = func.name();
 
+                if let Some(expr) =
+                    self.scalar_function_to_sql_overrides(func_name, func, args)
+                {
+                    return Ok(expr);
+                }
+
                 let args = args
                     .iter()
                     .map(|e| {
@@ -515,6 +522,48 @@ impl Unparser<'_> {
         }
     }
 
+    fn scalar_function_to_sql_overrides(
+        &self,
+        func_name: &str,
+        _func: &Arc<ScalarUDF>,
+        args: &Vec<Expr>,
+    ) -> Option<ast::Expr> {
+        match func_name.to_lowercase().as_str() {
+            "date_part" => {
+                if self.dialect.date_field_extract_style()
+                    == DateFieldExtractStyle::Extract
+                    && args.len() == 2
+                {
+                    let Ok(date_expr) = self.expr_to_sql(&args[1]) else {
+                        return None;
+                    };
+
+                    match &args[0] {
+                        Expr::Literal(ScalarValue::Utf8(Some(field))) => {
+                            let field = match field.to_lowercase().as_str() {
+                                "year" => ast::DateTimeField::Year,
+                                "month" => ast::DateTimeField::Month,
+                                "day" => ast::DateTimeField::Day,
+                                "hour" => ast::DateTimeField::Hour,
+                                "minute" => ast::DateTimeField::Minute,
+                                "second" => ast::DateTimeField::Second,
+                                _ => return None,
+                            };
+
+                            return Some(ast::Expr::Extract {
+                                field,
+                                expr: Box::new(date_expr),
+                            });
+                        }
+                        _ => return None,
+                    }
+                }
+                None
+            }
+            _ => None,
+        }
+    }
+
     fn ast_type_for_date64_in_cast(&self) -> ast::DataType {
         if self.dialect.use_timestamp_for_date64() {
             ast::DataType::Timestamp(None, ast::TimezoneInfo::None)
@@ -1841,4 +1890,52 @@ mod tests {
         }
         Ok(())
     }
+
+    #[test]
+    fn custom_dialect_with_date_field_extract_style() -> Result<()> {
+        for (extract_style, unit, expected) in [
+            (
+                DateFieldExtractStyle::DatePart,
+                "YEAR",
+                "date_part('YEAR', x)",
+            ),
+            (
+                DateFieldExtractStyle::Extract,
+                "YEAR",
+                "EXTRACT(YEAR FROM x)",
+            ),
+            (
+                DateFieldExtractStyle::DatePart,
+                "MONTH",
+                "date_part('MONTH', x)",
+            ),
+            (
+                DateFieldExtractStyle::Extract,
+                "MONTH",
+                "EXTRACT(MONTH FROM x)",
+            ),
+            (
+                DateFieldExtractStyle::DatePart,
+                "DAY",
+                "date_part('DAY', x)",
+            ),
+            (DateFieldExtractStyle::Extract, "DAY", "EXTRACT(DAY FROM x)"),
+        ] {
+            let dialect = CustomDialectBuilder::new()
+                .with_date_field_extract_style(extract_style)
+                .build();
+
+            let unparser = Unparser::new(&dialect);
+            let expr = ScalarUDF::new_from_impl(
+                datafusion_functions::datetime::date_part::DatePartFunc::new(),
+            )
+            .call(vec![Expr::Literal(ScalarValue::new_utf8(unit)), col("x")]);
+
+            let ast = unparser.expr_to_sql(&expr)?;
+            let actual = format!("{}", ast);
+
+            assert_eq!(actual, expected);
+        }
+        Ok(())
+    }
 }