diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index c454a9781eda..e642dae06e4f 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -960,7 +960,10 @@ impl BuiltinScalarFunction { Signature::variadic_any(self.volatility()) } BuiltinScalarFunction::ArrayPositions => Signature::any(2, self.volatility()), - BuiltinScalarFunction::ArrayPrepend => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayPrepend => Signature { + type_signature: ElementAndArray, + volatility: self.volatility(), + }, BuiltinScalarFunction::ArrayRepeat => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayRemove => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayRemoveN => Signature::any(3, self.volatility()), diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index 3f07c300e196..729131bd95e1 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -122,6 +122,10 @@ pub enum TypeSignature { /// List dimension of the List/LargeList is equivalent to the number of List. /// List dimension of the non-list is 0. ArrayAndElement, + /// Specialized Signature for ArrayPrepend and similar functions + /// The first argument should be non-list or list, and the second argument should be List/LargeList. + /// The first argument's list dimension should be one dimension less than the second argument's list dimension. 
+ ElementAndArray, } impl TypeSignature { @@ -155,6 +159,9 @@ impl TypeSignature { TypeSignature::ArrayAndElement => { vec!["ArrayAndElement(List<T>, T)".to_string()] } + TypeSignature::ElementAndArray => { + vec!["ElementAndArray(T, List<T>)".to_string()] + } } } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index f95a30e025b4..fa47c92762bf 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -79,6 +79,55 @@ fn get_valid_types( signature: &TypeSignature, current_types: &[DataType], ) -> Result<Vec<Vec<DataType>>> { + fn array_append_or_prepend_valid_types( + current_types: &[DataType], + is_append: bool, + ) -> Result<Vec<Vec<DataType>>> { + if current_types.len() != 2 { + return Ok(vec![vec![]]); + } + + let (array_type, elem_type) = if is_append { + (&current_types[0], &current_types[1]) + } else { + (&current_types[1], &current_types[0]) + }; + + // We follow Postgres on `array_append(Null, T)`, which is not valid. + if array_type.eq(&DataType::Null) { + return Ok(vec![vec![]]); + } + + // We need to find the coerced base type, mainly for cases like: + // `array_append(List(null), i64)` -> `List(i64)` + let array_base_type = datafusion_common::utils::base_type(array_type); + let elem_base_type = datafusion_common::utils::base_type(elem_type); + let new_base_type = comparison_coercion(&array_base_type, &elem_base_type); + + if new_base_type.is_none() { + return internal_err!( + "Coercion from {array_base_type:?} to {elem_base_type:?} not supported." 
+ ); + } + let new_base_type = new_base_type.unwrap(); + + let array_type = datafusion_common::utils::coerced_type_with_base_type_only( + array_type, + &new_base_type, + ); + + if let DataType::List(ref field) = array_type { + let elem_type = field.data_type(); + if is_append { + Ok(vec![vec![array_type.clone(), elem_type.to_owned()]]) + } else { + Ok(vec![vec![elem_type.to_owned(), array_type.clone()]]) + } + } else { + Ok(vec![vec![]]) + } + } + let valid_types = match signature { TypeSignature::Variadic(valid_types) => valid_types .iter() @@ -112,42 +161,10 @@ fn get_valid_types( TypeSignature::Exact(valid_types) => vec![valid_types.clone()], TypeSignature::ArrayAndElement => { - if current_types.len() != 2 { - return Ok(vec![vec![]]); - } - - let array_type = &current_types[0]; - let elem_type = &current_types[1]; - - // We follow Postgres on `array_append(Null, T)`, which is not valid. - if array_type.eq(&DataType::Null) { - return Ok(vec![vec![]]); - } - - // We need to find the coerced base type, mainly for cases like: - // `array_append(List(null), i64)` -> `List(i64)` - let array_base_type = datafusion_common::utils::base_type(array_type); - let elem_base_type = datafusion_common::utils::base_type(elem_type); - let new_base_type = comparison_coercion(&array_base_type, &elem_base_type); - - if new_base_type.is_none() { - return internal_err!( - "Coercion from {array_base_type:?} to {elem_base_type:?} not supported." 
- ); - } - let new_base_type = new_base_type.unwrap(); - - let array_type = datafusion_common::utils::coerced_type_with_base_type_only( - array_type, - &new_base_type, - ); - - if let DataType::List(ref field) = array_type { - let elem_type = field.data_type(); - return Ok(vec![vec![array_type.clone(), elem_type.to_owned()]]); - } else { - return Ok(vec![vec![]]); - } + return array_append_or_prepend_valid_types(current_types, true) + } + TypeSignature::ElementAndArray => { + return array_append_or_prepend_valid_types(current_types, false) } TypeSignature::Any(number) => { if current_types.len() != *number { diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index b8d89edb49b1..6dab3b3084a9 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1618,18 +1618,58 @@ select array_append(column1, make_array(1, 11, 111)), array_append(make_array(ma ## array_prepend (aliases: `list_prepend`, `array_push_front`, `list_push_front`) -# TODO: array_prepend with NULLs -# array_prepend scalar function #1 -# query ? -# select array_prepend(4, make_array()); -# ---- -# [4] +# array_prepend with NULLs + +# DuckDB: [4] +# ClickHouse: Null +# Since they dont have the same result, we just follow Postgres, return error +query error +select array_prepend(4, NULL); + +query ? +select array_prepend(4, []); +---- +[4] + +query ? +select array_prepend(4, [null]); +---- +[4, ] + +# DuckDB: [null] +# ClickHouse: [null] +query ? +select array_prepend(null, []); +---- +[] + +query ? +select array_prepend(null, [1]); +---- +[, 1] + +query ? +select array_prepend(null, [[1,2,3]]); +---- +[, [1, 2, 3]] + +# DuckDB: [[]] +# ClickHouse: [[]] +# TODO: We may also return [[]] +query error +select array_prepend([], []); + +# DuckDB: [null] +# ClickHouse: [null] +# TODO: We may also return [null] +query error +select array_prepend(null, null); + +query ? 
+select array_append([], null); +---- +[] -# array_prepend scalar function #2 -# query ?? -# select array_prepend(make_array(), make_array()), array_prepend(make_array(4), make_array()); -# ---- -# [[]] [[4]] # array_prepend scalar function #3 query ???