diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index 7f2dc61c07bf..a10428a7224e 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -338,6 +338,8 @@ pub fn longest_consecutive_prefix>( count } +/// Array Utils + /// Wrap an array into a single element `ListArray`. /// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` pub fn array_into_list_array(arr: ArrayRef) -> ListArray { @@ -425,6 +427,42 @@ pub fn base_type(data_type: &DataType) -> DataType { } } +/// A helper function to coerce base type in List. +/// +/// Example +/// ``` +/// use arrow::datatypes::{DataType, Field}; +/// use datafusion_common::utils::coerced_type_with_base_type_only; +/// use std::sync::Arc; +/// +/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); +/// let base_type = DataType::Float64; +/// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type); +/// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new("item", DataType::Float64, true)))); +pub fn coerced_type_with_base_type_only( + data_type: &DataType, + base_type: &DataType, +) -> DataType { + match data_type { + DataType::List(field) => { + let data_type = match field.data_type() { + DataType::List(_) => { + coerced_type_with_base_type_only(field.data_type(), base_type) + } + _ => base_type.to_owned(), + }; + + DataType::List(Arc::new(Field::new( + field.name(), + data_type, + field.is_nullable(), + ))) + } + + _ => base_type.clone(), + } +} + /// Compute the number of dimensions in a list data type. pub fn list_ndims(data_type: &DataType) -> u64 { if let DataType::List(field) = data_type { diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 44fbf45525d4..135074eecfd3 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -912,10 +912,17 @@ impl BuiltinScalarFunction { // for now, the list is small, as we do not have many built-in functions. match self { - BuiltinScalarFunction::ArrayAppend => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArraySort => { Signature::variadic_any(self.volatility()) } + BuiltinScalarFunction::ArrayAppend => Signature { + type_signature: ArrayAppendLikeSignature, + volatility: self.volatility(), + }, + BuiltinScalarFunction::MakeArray => { + // 0 or more arguments of arbitrary type + Signature::one_of(vec![VariadicCoerced, Any(0)], self.volatility()) + } BuiltinScalarFunction::ArrayPopFront => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayPopBack => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayConcat => { @@ -954,10 +961,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayUnion => Signature::any(2, self.volatility()), BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()), - BuiltinScalarFunction::MakeArray => { - // 0 or more arguments of arbitrary type - Signature::one_of(vec![VariadicAny, Any(0)], self.volatility()) - } BuiltinScalarFunction::Range => Signature::one_of( vec![ Exact(vec![Int64]), diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index 685601523f9b..2e0e4ca731c4 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -95,6 +95,8 @@ pub enum TypeSignature { VariadicEqual, /// One or more arguments with arbitrary types VariadicAny, + /// A function such as `make_array` should be coerced to the same type + VariadicCoerced, /// fixed number of arguments of an arbitrary but equal type out of a list of valid types. /// /// # Examples @@ -113,6 +115,8 @@ pub enum TypeSignature { /// Function `make_array` takes 0 or more arguments with arbitrary types, its `TypeSignature` /// is `OneOf(vec![Any(0), VariadicAny])`. OneOf(Vec), + /// Specialized Signature for ArrayAppend and similar functions + ArrayAppendLikeSignature, } impl TypeSignature { @@ -136,11 +140,17 @@ impl TypeSignature { .collect::>() .join(", ")] } + TypeSignature::VariadicCoerced => { + vec!["CoercibleT, .., CoercibleT".to_string()] + } TypeSignature::VariadicEqual => vec!["T, .., T".to_string()], TypeSignature::VariadicAny => vec!["Any, .., Any".to_string()], TypeSignature::OneOf(sigs) => { sigs.iter().flat_map(|s| s.to_string_repr()).collect() } + TypeSignature::ArrayAppendLikeSignature => { + vec!["ArrayAppendLikeSignature(List, T)".to_string()] + } } } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 79b574238495..29a53181d584 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -15,13 +15,19 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use crate::signature::TIMEZONE_WILDCARD; use crate::{Signature, TypeSignature}; +use arrow::datatypes::Field; use arrow::{ compute::can_cast_types, datatypes::{DataType, TimeUnit}, }; -use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_common::utils::list_ndims; +use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; + +use super::binary::comparison_coercion; /// Performs type coercion for function arguments. /// @@ -85,6 +91,24 @@ fn get_valid_types( .iter() .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) .collect(), + TypeSignature::VariadicCoerced => { + let new_type = current_types.iter().skip(1).try_fold( + current_types.first().unwrap().clone(), + |acc, x| { + let coerced_type = comparison_coercion(&acc, x); + if let Some(coerced_type) = coerced_type { + Ok(coerced_type) + } else { + internal_err!("Coercion from {acc:?} to {x:?} failed.") + } + }, + ); + + match new_type { + Ok(new_type) => vec![vec![new_type; current_types.len()]], + Err(e) => return Err(e), + } + } TypeSignature::VariadicEqual => { // one entry with the same len as current_types, whose type is `current_types[0]`. vec![current_types @@ -95,7 +119,48 @@ fn get_valid_types( TypeSignature::VariadicAny => { vec![current_types.to_vec()] } + TypeSignature::Exact(valid_types) => vec![valid_types.clone()], + TypeSignature::ArrayAppendLikeSignature => { + if current_types.len() != 2 { + return Ok(vec![vec![]]); + } + + let array_type = ¤t_types[0]; + let elem_type = ¤t_types[1]; + + // Special case for `array_append(Null, T)`, just return and process in physical expression step. + if array_type.eq(&DataType::Null) { + let array_type = + DataType::List(Arc::new(Field::new("item", elem_type.clone(), true))); + return Ok(vec![vec![array_type.to_owned(), elem_type.to_owned()]]); + } + + // We need to find the coerced base type, mainly for cases like: + // `array_append(List(null), i64)` -> `List(i64)` + let array_base_type = datafusion_common::utils::base_type(array_type); + let elem_base_type = datafusion_common::utils::base_type(elem_type); + let new_base_type = comparison_coercion(&array_base_type, &elem_base_type); + + if new_base_type.is_none() { + return internal_err!( + "Coercion from {array_base_type:?} to {elem_base_type:?} not supported." + ); + } + let new_base_type = new_base_type.unwrap(); + + let array_type = datafusion_common::utils::coerced_type_with_base_type_only( + array_type, + &new_base_type, + ); + + if let DataType::List(ref field) = array_type { + let elem_type = field.data_type(); + return Ok(vec![vec![array_type.clone(), elem_type.to_owned()]]); + } else { + return Ok(vec![vec![]]); + } + } TypeSignature::Any(number) => { if current_types.len() != *number { return plan_err!( @@ -241,6 +306,15 @@ fn coerced_from<'a>( Utf8 | LargeUtf8 => Some(type_into.clone()), Null if can_cast_types(type_from, type_into) => Some(type_into.clone()), + // Only accept list with the same number of dimensions unless the type is Null. + // List with different dimensions should be handled in TypeSignature or other places before this. + List(_) + if datafusion_common::utils::base_type(type_from).eq(&Null) + || list_ndims(type_from) == list_ndims(type_into) => + { + Some(type_into.clone()) + } + Timestamp(unit, Some(tz)) if tz.as_ref() == TIMEZONE_WILDCARD => { match type_from { Timestamp(_, Some(from_tz)) => { diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 91611251d9dd..c5e1180b9f97 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -590,26 +590,6 @@ fn coerce_arguments_for_fun( .collect::>>()?; } - if *fun == BuiltinScalarFunction::MakeArray { - // Find the final data type for the function arguments - let current_types = expressions - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - - let new_type = current_types - .iter() - .skip(1) - .fold(current_types.first().unwrap().clone(), |acc, x| { - comparison_coercion(&acc, x).unwrap_or(acc) - }); - - return expressions - .iter() - .zip(current_types) - .map(|(expr, from_type)| cast_array_expr(expr, &from_type, &new_type, schema)) - .collect(); - } Ok(expressions) } @@ -618,20 +598,6 @@ fn cast_expr(expr: &Expr, to_type: &DataType, schema: &DFSchema) -> Result expr.clone().cast_to(to_type, schema) } -/// Cast array `expr` to the specified type, if possible -fn cast_array_expr( - expr: &Expr, - from_type: &DataType, - to_type: &DataType, - schema: &DFSchema, -) -> Result { - if from_type.equals_datatype(&DataType::Null) { - Ok(expr.clone()) - } else { - cast_expr(expr, to_type, schema) - } -} - /// Returns the coerced exprs for each `input_exprs`. /// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the /// data type of `input_exprs` need to be coerced. diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 269bbf7dcf10..e16d60e3531d 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -362,7 +362,8 @@ pub fn make_array(arrays: &[ArrayRef]) -> Result { match data_type { // Either an empty array or all nulls: DataType::Null => { - let array = new_null_array(&DataType::Null, arrays.len()); + let array = + new_null_array(&DataType::Null, arrays.iter().map(|a| a.len()).sum()); Ok(Arc::new(array_into_list_array(array))) } DataType::LargeList(..) => array_array::(arrays, data_type), @@ -763,7 +764,12 @@ pub fn array_append(args: &[ArrayRef]) -> Result { check_datatypes("array_append", &[list_array.values(), element_array])?; let res = match list_array.value_type() { DataType::List(_) => concat_internal(args)?, - DataType::Null => return make_array(&[element_array.to_owned()]), + DataType::Null => { + return make_array(&[ + list_array.values().to_owned(), + element_array.to_owned(), + ]); + } data_type => { return general_append_and_prepend( list_array, diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 3c23dd369ae5..afd6042c776e 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -265,10 +265,8 @@ AS VALUES (make_array([28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]), [28, 29, 30], [37, 38, 39], 10) ; -query ? +query error select [1, true, null] ----- -[1, 1, ] query error DataFusion error: This feature is not implemented: ScalarFunctions without MakeArray are not supported: now() SELECT [now()] @@ -1092,18 +1090,27 @@ select list_sort(make_array(1, 3, null, 5, NULL, -5)), list_sort(make_array(1, 3 ## array_append (aliases: `list_append`, `array_push_back`, `list_push_back`) -# TODO: array_append with NULLs -# array_append scalar function #1 -# query ? -# select array_append(make_array(), 4); -# ---- -# [4] +# array_append with NULLs -# array_append scalar function #2 -# query ?? -# select array_append(make_array(), make_array()), array_append(make_array(), make_array(4)); -# ---- -# [[]] [[4]] +query ??????? +select + array_append(null, 1), + array_append(null, [2, 3]), + array_append(null, [[4]]), + array_append(make_array(), 4), + array_append(make_array(), null), + array_append(make_array(1, null, 3), 4), + array_append(make_array(null, null), 1) +; +---- +[1] [[2, 3]] [[[4]]] [4] [] [1, , 3, 4] [, , 1] + +query ?? +select + array_append(make_array(make_array(1, null, 3)), make_array(null)), + array_append(make_array(make_array(1, null, 3)), null); +---- +[[1, , 3], []] [[1, , 3], ] # array_append scalar function #3 query ???