Skip to content

Commit

Permalink
Add DateFieldExtractStyle support
Browse files Browse the repository at this point in the history
  • Loading branch information
sgrebnov committed Jul 19, 2024
1 parent 1d5fe15 commit d1af73f
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 1 deletion.
38 changes: 38 additions & 0 deletions datafusion/sql/src/unparser/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ pub trait Dialect: Send + Sync {
fn use_char_for_utf8_cast(&self) -> bool {
// Trait-level default: opt out. Dialects that need the behavior override
// this method (for example, the MySqlDialect impl returns `true`).
false
}

/// Returns the [`DateFieldExtractStyle`] this dialect uses when unparsing
/// datetime subfield extraction.
///
/// The trait-level default is [`DateFieldExtractStyle::DatePart`]; dialects
/// that need a different form override this method.
fn date_field_extract_style(&self) -> DateFieldExtractStyle {
DateFieldExtractStyle::DatePart
}
}

/// `IntervalStyle` to use for unparsing
Expand All @@ -74,6 +78,19 @@ pub enum IntervalStyle {
MySQL,
}

/// Datetime subfield extraction style for unparsing.
///
/// Different DBMSs follow different standards; popular ones are:
///
/// - `date_part('YEAR', date '2001-02-16')`
/// - `EXTRACT(YEAR FROM date '2001-02-16')`
///
/// Some DBMSs, like Postgres, support both, whereas others like MySQL require
/// `EXTRACT`.
///
/// See <https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT>.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DateFieldExtractStyle {
    /// Function-call form, e.g. `date_part('YEAR', x)`.
    DatePart,
    /// Keyword form, e.g. `EXTRACT(YEAR FROM x)`.
    Extract,
}

/// Field-less dialect type; its [`Dialect`] implementation follows this
/// declaration.
pub struct DefaultDialect {}

impl Dialect for DefaultDialect {
Expand Down Expand Up @@ -123,6 +140,10 @@ impl Dialect for MySqlDialect {
fn use_char_for_utf8_cast(&self) -> bool {
// This dialect opts in, overriding the `Dialect` trait default.
true
}

/// Selects the `EXTRACT(<field> FROM <expr>)` form for datetime subfield
/// extraction instead of the `date_part(...)` trait default.
fn date_field_extract_style(&self) -> DateFieldExtractStyle {
DateFieldExtractStyle::Extract
}
}

/// Field-less dialect type.
// NOTE(review): presumably the SQLite counterpart of the other dialect types;
// its `Dialect` impl is outside this view — confirm which defaults it overrides.
pub struct SqliteDialect {}
Expand All @@ -140,6 +161,7 @@ pub struct CustomDialect {
interval_style: IntervalStyle,
use_double_precision_for_float64: bool,
use_char_for_utf8_cast: bool,
date_subfield_extract_style: DateFieldExtractStyle,
}

impl Default for CustomDialect {
Expand All @@ -151,6 +173,7 @@ impl Default for CustomDialect {
interval_style: IntervalStyle::SQLStandard,
use_double_precision_for_float64: false,
use_char_for_utf8_cast: false,
date_subfield_extract_style: DateFieldExtractStyle::DatePart,
}
}
}
Expand Down Expand Up @@ -189,6 +212,10 @@ impl Dialect for CustomDialect {
fn use_char_for_utf8_cast(&self) -> bool {
// Returns the per-instance setting stored in the field of the same name.
self.use_char_for_utf8_cast
}

/// Returns the extraction style stored in this instance's
/// `date_subfield_extract_style` field.
fn date_field_extract_style(&self) -> DateFieldExtractStyle {
self.date_subfield_extract_style
}
}

// create a CustomDialectBuilder
Expand All @@ -199,6 +226,7 @@ pub struct CustomDialectBuilder {
interval_style: IntervalStyle,
use_double_precision_for_float64: bool,
use_char_for_utf8_cast: bool,
date_subfield_extract_style: DateFieldExtractStyle,
}

impl CustomDialectBuilder {
Expand All @@ -210,6 +238,7 @@ impl CustomDialectBuilder {
interval_style: IntervalStyle::PostgresVerbose,
use_double_precision_for_float64: false,
use_char_for_utf8_cast: false,
date_subfield_extract_style: DateFieldExtractStyle::DatePart,
}
}

Expand All @@ -221,6 +250,7 @@ impl CustomDialectBuilder {
interval_style: self.interval_style,
use_double_precision_for_float64: self.use_double_precision_for_float64,
use_char_for_utf8_cast: self.use_char_for_utf8_cast,
date_subfield_extract_style: self.date_subfield_extract_style,
}
}

Expand Down Expand Up @@ -262,4 +292,12 @@ impl CustomDialectBuilder {
self.use_char_for_utf8_cast = use_char_for_utf8_cast;
self
}

pub fn with_date_field_extract_style(
mut self,
date_subfield_extract_style: DateFieldExtractStyle,
) -> Self {
self.date_subfield_extract_style = date_subfield_extract_style;
self
}
}
99 changes: 98 additions & 1 deletion datafusion/sql/src/unparser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use arrow::util::display::array_value_to_string;
use arrow_array::{Date32Array, Date64Array, PrimitiveArray};
use arrow_schema::DataType;
use core::fmt;
use datafusion_expr::ScalarUDF;
use sqlparser::ast::TimezoneInfo;
use sqlparser::ast::Value::SingleQuotedString;
use sqlparser::ast::{
Expand All @@ -42,7 +43,7 @@ use datafusion_expr::{
Between, BinaryExpr, Case, Cast, Expr, GroupingSet, Like, Operator, TryCast,
};

use super::dialect::IntervalStyle;
use super::dialect::{DateFieldExtractStyle, IntervalStyle};
use super::Unparser;

/// DataFusion's Exprs can represent either an `Expr` or an `OrderByExpr`
Expand Down Expand Up @@ -122,6 +123,12 @@ impl Unparser<'_> {
Expr::ScalarFunction(ScalarFunction { func, args }) => {
let func_name = func.name();

if let Some(expr) =
self.scalar_function_to_sql_overrides(func_name, func, args)
{
return Ok(expr);
}

let args = args
.iter()
.map(|e| {
Expand Down Expand Up @@ -515,6 +522,48 @@ impl Unparser<'_> {
}
}

fn scalar_function_to_sql_overrides(
&self,
func_name: &str,
_func: &Arc<ScalarUDF>,
args: &Vec<Expr>,
) -> Option<ast::Expr> {
match func_name.to_lowercase().as_str() {
"date_part" => {
if self.dialect.date_field_extract_style()
== DateFieldExtractStyle::Extract
&& args.len() == 2
{
let Ok(date_expr) = self.expr_to_sql(&args[1]) else {
return None;
};

match &args[0] {
Expr::Literal(ScalarValue::Utf8(Some(field))) => {
let field = match field.to_lowercase().as_str() {
"year" => ast::DateTimeField::Year,
"month" => ast::DateTimeField::Month,
"day" => ast::DateTimeField::Day,
"hour" => ast::DateTimeField::Hour,
"minute" => ast::DateTimeField::Minute,
"second" => ast::DateTimeField::Second,
_ => return None,
};

return Some(ast::Expr::Extract {
field,
expr: Box::new(date_expr),
});
}
_ => return None,
}
}
None
}
_ => None,
}
}

fn ast_type_for_date64_in_cast(&self) -> ast::DataType {
if self.dialect.use_timestamp_for_date64() {
ast::DataType::Timestamp(None, ast::TimezoneInfo::None)
Expand Down Expand Up @@ -1841,4 +1890,52 @@ mod tests {
}
Ok(())
}

/// Verifies that a `date_part(<unit>, x)` call is unparsed according to the
/// `DateFieldExtractStyle` configured on a custom dialect, for the YEAR,
/// MONTH and DAY units.
#[test]
fn custom_dialect_with_date_field_extract_style() -> Result<()> {
    for unit in ["YEAR", "MONTH", "DAY"] {
        // For each unit, check the function-call form first, then EXTRACT.
        let cases = [
            (
                DateFieldExtractStyle::DatePart,
                format!("date_part('{}', x)", unit),
            ),
            (
                DateFieldExtractStyle::Extract,
                format!("EXTRACT({} FROM x)", unit),
            ),
        ];

        for (extract_style, expected) in cases {
            let dialect = CustomDialectBuilder::new()
                .with_date_field_extract_style(extract_style)
                .build();
            let unparser = Unparser::new(&dialect);

            let date_part = ScalarUDF::new_from_impl(
                datafusion_functions::datetime::date_part::DatePartFunc::new(),
            );
            let expr = date_part
                .call(vec![Expr::Literal(ScalarValue::new_utf8(unit)), col("x")]);

            let actual = unparser.expr_to_sql(&expr)?.to_string();

            assert_eq!(actual, expected);
        }
    }
    Ok(())
}
}

0 comments on commit d1af73f

Please sign in to comment.