Skip to content

Commit

Permalink
Use make_array to handle SQLExpr::Array.
Browse files Browse the repository at this point in the history
  • Loading branch information
jiangzhx committed Aug 27, 2023
1 parent e0961d5 commit 38bf081
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 64 deletions.
27 changes: 25 additions & 2 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,12 @@ use datafusion_expr::{
col, expr, lit, AggregateFunction, Between, BinaryExpr, BuiltinScalarFunction, Cast,
Expr, ExprSchemable, GetFieldAccess, GetIndexedField, Like, Operator, TryCast,
};
use sqlparser::ast::{ArrayAgg, Expr as SQLExpr, JsonOperator, TrimWhereField, Value};
use sqlparser::ast::{
ArrayAgg, Expr as SQLExpr, FunctionArg, FunctionArgExpr, JsonOperator,
TrimWhereField, Value,
};
use sqlparser::parser::ParserError::ParserError;
use std::str::FromStr;

impl<'a, S: ContextProvider> SqlToRel<'a, S> {
pub(crate) fn sql_expr_to_logical_expr(
Expand Down Expand Up @@ -176,7 +180,26 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
)))
}

SQLExpr::Array(arr) => self.sql_array_literal(arr.elem, schema),
SQLExpr::Array(arr) => {
if arr.elem.is_empty() {
Ok(lit(ScalarValue::new_list(None, DataType::Utf8)))
} else {
let args = arr
.elem
.iter()
.map(|expr| {
FunctionArg::Unnamed(FunctionArgExpr::Expr(expr.clone()))
})
.collect::<Vec<FunctionArg>>();

let args =
self.function_args_to_expr(args, schema, planner_context)?;
Ok(Expr::ScalarFunction(ScalarFunction::new(
BuiltinScalarFunction::from_str("make_array")?,
args,
)))
}
}
SQLExpr::Interval(interval) => {
self.sql_interval_to_expr(false, interval, schema, planner_context)
}
Expand Down
40 changes: 0 additions & 40 deletions datafusion/sql/src/expr/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ use datafusion_expr::{lit, Expr, Operator};
use log::debug;
use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value};
use sqlparser::parser::ParserError::ParserError;
use std::collections::HashSet;

impl<'a, S: ContextProvider> SqlToRel<'a, S> {
pub(crate) fn parse_value(
Expand Down Expand Up @@ -125,45 +124,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
)))
}

pub(super) fn sql_array_literal(
&self,
elements: Vec<SQLExpr>,
schema: &DFSchema,
) -> Result<Expr> {
let mut values = Vec::with_capacity(elements.len());

for element in elements {
let value = self.sql_expr_to_logical_expr(
element,
schema,
&mut PlannerContext::new(),
)?;
match value {
Expr::Literal(scalar) => {
values.push(scalar);
}
_ => {
return not_impl_err!(
"Arrays with elements other than literal are not supported: {value}"
);
}
}
}

let data_types: HashSet<DataType> =
values.iter().map(|e| e.get_datatype()).collect();

if data_types.is_empty() {
Ok(lit(ScalarValue::new_list(None, DataType::Utf8)))
} else if data_types.len() > 1 {
not_impl_err!("Arrays with different types are not supported: {data_types:?}")
} else {
let data_type = values[0].get_datatype();

Ok(lit(ScalarValue::new_list(Some(values), data_type)))
}
}

/// Convert a SQL interval expression to a DataFusion logical plan
/// expression
pub(super) fn sql_interval_to_expr(
Expand Down
31 changes: 9 additions & 22 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1364,18 +1364,6 @@ fn select_interval_out_of_range() {
);
}

#[test]
fn select_array_no_common_type() {
let sql = "SELECT [1, true, null]";
let err = logical_plan(sql).expect_err("query should have failed");

// HashSet doesn't guarantee order
assert_contains!(
err.to_string(),
r#"Arrays with different types are not supported: "#
);
}

#[test]
fn recursive_ctes() {
let sql = "
Expand All @@ -1392,16 +1380,6 @@ fn recursive_ctes() {
);
}

#[test]
fn select_array_non_literal_type() {
let sql = "SELECT [now()]";
let err = logical_plan(sql).expect_err("query should have failed");
assert_eq!(
r#"NotImplemented("Arrays with elements other than literal are not supported: now()")"#,
format!("{err:?}")
);
}

#[test]
fn select_simple_aggregate_with_groupby_and_column_is_in_aggregate_and_groupby() {
quick_test(
Expand Down Expand Up @@ -4230,6 +4208,15 @@ fn test_multi_grouping_sets() {
quick_test(sql, expected);
}

#[test]
fn test_array_with_binaryExpr() {
let sql = "select [1>0,2>1]";

let expected = "Projection: make_array(Int64(1) > Int64(0), Int64(2) > Int64(1))\
\n EmptyRelation";
quick_test(sql, expected);
}

fn assert_field_not_found(err: DataFusionError, name: &str) {
match err {
DataFusionError::SchemaError { .. } => {
Expand Down

0 comments on commit 38bf081

Please sign in to comment.