Skip to content

Commit

Permalink
Signature for array_append and make_array
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 committed Dec 7, 2023
1 parent 9be9073 commit 3ba90b4
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 56 deletions.
38 changes: 38 additions & 0 deletions datafusion/common/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,8 @@ pub fn longest_consecutive_prefix<T: Borrow<usize>>(
count
}

/// Array Utils

/// Wrap an array into a single element `ListArray`.
/// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]`
pub fn array_into_list_array(arr: ArrayRef) -> ListArray {
Expand Down Expand Up @@ -425,6 +427,42 @@ pub fn base_type(data_type: &DataType) -> DataType {
}
}

/// Replaces the innermost (base) element type of a possibly nested `List` with `base_type`.
///
/// Every `List` level of `data_type` is rebuilt with its original field name and
/// nullability; only the innermost non-`List` type is swapped for `base_type`.
/// If `data_type` is not a `List` at all, a clone of `base_type` is returned.
///
/// # Example
/// ```
/// use arrow::datatypes::{DataType, Field};
/// use datafusion_common::utils::coerced_type_with_base_type_only;
/// use std::sync::Arc;
///
/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
/// let base_type = DataType::Float64;
/// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type);
/// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new("item", DataType::Float64, true))));
/// ```
pub fn coerced_type_with_base_type_only(
    data_type: &DataType,
    base_type: &DataType,
) -> DataType {
    match data_type {
        DataType::List(field) => {
            // Recursing covers both cases: a nested `List` is rebuilt one level
            // deeper, while any other element type falls into the catch-all arm
            // below and becomes `base_type`.
            let coerced_inner =
                coerced_type_with_base_type_only(field.data_type(), base_type);

            DataType::List(Arc::new(Field::new(
                field.name(),
                coerced_inner,
                field.is_nullable(),
            )))
        }
        // Not a `List`: this is the base type position, so substitute it.
        _ => base_type.clone(),
    }
}

/// Compute the number of dimensions in a list data type.
pub fn list_ndims(data_type: &DataType) -> u64 {
if let DataType::List(field) = data_type {
Expand Down
13 changes: 8 additions & 5 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -912,10 +912,17 @@ impl BuiltinScalarFunction {

// for now, the list is small, as we do not have many built-in functions.
match self {
BuiltinScalarFunction::ArrayAppend => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArraySort => {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayAppend => Signature {
type_signature: ArrayAppendLikeSignature,
volatility: self.volatility(),
},
BuiltinScalarFunction::MakeArray => {
// 0 or more arguments; when present, they are coerced to a common type
Signature::one_of(vec![VariadicCoerced, Any(0)], self.volatility())
}
BuiltinScalarFunction::ArrayPopFront => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayPopBack => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayConcat => {
Expand Down Expand Up @@ -954,10 +961,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayUnion => Signature::any(2, self.volatility()),
BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()),
BuiltinScalarFunction::MakeArray => {
// 0 or more arguments of arbitrary type
Signature::one_of(vec![VariadicAny, Any(0)], self.volatility())
}
BuiltinScalarFunction::Range => Signature::one_of(
vec![
Exact(vec![Int64]),
Expand Down
10 changes: 10 additions & 0 deletions datafusion/expr/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ pub enum TypeSignature {
VariadicEqual,
/// One or more arguments with arbitrary types
VariadicAny,
/// One or more arguments of arbitrary types that are coerced to a single common type,
/// as needed by functions such as `make_array`
VariadicCoerced,
/// fixed number of arguments of an arbitrary but equal type out of a list of valid types.
///
/// # Examples
Expand All @@ -113,6 +115,8 @@ pub enum TypeSignature {
/// Function `make_array` takes 0 or more arguments that are coerced to a common type, its `TypeSignature`
/// is `OneOf(vec![VariadicCoerced, Any(0)])`.
OneOf(Vec<TypeSignature>),
/// Specialized signature for `array_append` and similar functions that take a list
/// followed by an element, i.e. `(List<T>, T)`
ArrayAppendLikeSignature,
}

impl TypeSignature {
Expand All @@ -136,11 +140,17 @@ impl TypeSignature {
.collect::<Vec<&str>>()
.join(", ")]
}
TypeSignature::VariadicCoerced => {
vec!["CoercibleT, .., CoercibleT".to_string()]
}
TypeSignature::VariadicEqual => vec!["T, .., T".to_string()],
TypeSignature::VariadicAny => vec!["Any, .., Any".to_string()],
TypeSignature::OneOf(sigs) => {
sigs.iter().flat_map(|s| s.to_string_repr()).collect()
}
TypeSignature::ArrayAppendLikeSignature => {
vec!["ArrayAppendLikeSignature(List<T>, T)".to_string()]
}
}
}

Expand Down
76 changes: 75 additions & 1 deletion datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,19 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use crate::signature::TIMEZONE_WILDCARD;
use crate::{Signature, TypeSignature};
use arrow::datatypes::Field;
use arrow::{
compute::can_cast_types,
datatypes::{DataType, TimeUnit},
};
use datafusion_common::{plan_err, DataFusionError, Result};
use datafusion_common::utils::list_ndims;
use datafusion_common::{internal_err, plan_err, DataFusionError, Result};

use super::binary::comparison_coercion;

/// Performs type coercion for function arguments.
///
Expand Down Expand Up @@ -85,6 +91,24 @@ fn get_valid_types(
.iter()
.map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
.collect(),
TypeSignature::VariadicCoerced => {
let new_type = current_types.iter().skip(1).try_fold(
current_types.first().unwrap().clone(),
|acc, x| {
let coerced_type = comparison_coercion(&acc, x);
if let Some(coerced_type) = coerced_type {
Ok(coerced_type)
} else {
internal_err!("Coercion from {acc:?} to {x:?} failed.")
}
},
);

match new_type {
Ok(new_type) => vec![vec![new_type; current_types.len()]],
Err(e) => return Err(e),
}
}
TypeSignature::VariadicEqual => {
// one entry with the same len as current_types, whose type is `current_types[0]`.
vec![current_types
Expand All @@ -95,7 +119,48 @@ fn get_valid_types(
TypeSignature::VariadicAny => {
vec![current_types.to_vec()]
}

TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
TypeSignature::ArrayAppendLikeSignature => {
if current_types.len() != 2 {
return Ok(vec![vec![]]);
}

let array_type = &current_types[0];
let elem_type = &current_types[1];

// Special case for `array_append(Null, T)`, just return and process in physical expression step.
if array_type.eq(&DataType::Null) {
let array_type =
DataType::List(Arc::new(Field::new("item", elem_type.clone(), true)));
return Ok(vec![vec![array_type.to_owned(), elem_type.to_owned()]]);
}

// We need to find the coerced base type, mainly for cases like:
// `array_append(List(null), i64)` -> `List(i64)`
let array_base_type = datafusion_common::utils::base_type(array_type);
let elem_base_type = datafusion_common::utils::base_type(elem_type);
let new_base_type = comparison_coercion(&array_base_type, &elem_base_type);

if new_base_type.is_none() {
return internal_err!(
"Coercion from {array_base_type:?} to {elem_base_type:?} not supported."
);
}
let new_base_type = new_base_type.unwrap();

let array_type = datafusion_common::utils::coerced_type_with_base_type_only(
array_type,
&new_base_type,
);

if let DataType::List(ref field) = array_type {
let elem_type = field.data_type();
return Ok(vec![vec![array_type.clone(), elem_type.to_owned()]]);
} else {
return Ok(vec![vec![]]);
}
}
TypeSignature::Any(number) => {
if current_types.len() != *number {
return plan_err!(
Expand Down Expand Up @@ -241,6 +306,15 @@ fn coerced_from<'a>(
Utf8 | LargeUtf8 => Some(type_into.clone()),
Null if can_cast_types(type_from, type_into) => Some(type_into.clone()),

// Only accept lists with the same number of dimensions, unless the base type of the source list is Null.
// Lists with different dimensions should be handled in TypeSignature or elsewhere before reaching this point.
List(_)
if datafusion_common::utils::base_type(type_from).eq(&Null)
|| list_ndims(type_from) == list_ndims(type_into) =>
{
Some(type_into.clone())
}

Timestamp(unit, Some(tz)) if tz.as_ref() == TIMEZONE_WILDCARD => {
match type_from {
Timestamp(_, Some(from_tz)) => {
Expand Down
34 changes: 0 additions & 34 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -590,26 +590,6 @@ fn coerce_arguments_for_fun(
.collect::<Result<Vec<_>>>()?;
}

if *fun == BuiltinScalarFunction::MakeArray {
// Find the final data type for the function arguments
let current_types = expressions
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;

let new_type = current_types
.iter()
.skip(1)
.fold(current_types.first().unwrap().clone(), |acc, x| {
comparison_coercion(&acc, x).unwrap_or(acc)
});

return expressions
.iter()
.zip(current_types)
.map(|(expr, from_type)| cast_array_expr(expr, &from_type, &new_type, schema))
.collect();
}
Ok(expressions)
}

Expand All @@ -618,20 +598,6 @@ fn cast_expr(expr: &Expr, to_type: &DataType, schema: &DFSchema) -> Result<Expr>
expr.clone().cast_to(to_type, schema)
}

/// Casts the array expression `expr` to `to_type`.
///
/// An expression whose current type (`from_type`) equals `DataType::Null` is
/// returned as an unchanged clone; every other expression is passed to `cast_expr`.
fn cast_array_expr(
    expr: &Expr,
    from_type: &DataType,
    to_type: &DataType,
    schema: &DFSchema,
) -> Result<Expr> {
    // Anything that is not Null-typed goes through the regular cast path.
    if !from_type.equals_datatype(&DataType::Null) {
        return cast_expr(expr, to_type, schema);
    }

    // Null-typed expressions are left untouched.
    Ok(expr.clone())
}

/// Returns the coerced exprs for each `input_exprs`.
/// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the
/// data type of `input_exprs` need to be coerced.
Expand Down
10 changes: 8 additions & 2 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,8 @@ pub fn make_array(arrays: &[ArrayRef]) -> Result<ArrayRef> {
match data_type {
// Either an empty array or all nulls:
DataType::Null => {
let array = new_null_array(&DataType::Null, arrays.len());
let array =
new_null_array(&DataType::Null, arrays.iter().map(|a| a.len()).sum());
Ok(Arc::new(array_into_list_array(array)))
}
DataType::LargeList(..) => array_array::<i64>(arrays, data_type),
Expand Down Expand Up @@ -763,7 +764,12 @@ pub fn array_append(args: &[ArrayRef]) -> Result<ArrayRef> {
check_datatypes("array_append", &[list_array.values(), element_array])?;
let res = match list_array.value_type() {
DataType::List(_) => concat_internal(args)?,
DataType::Null => return make_array(&[element_array.to_owned()]),
DataType::Null => {
return make_array(&[
list_array.values().to_owned(),
element_array.to_owned(),
]);
}
data_type => {
return general_append_and_prepend(
list_array,
Expand Down
35 changes: 21 additions & 14 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,8 @@ AS VALUES
(make_array([28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]), [28, 29, 30], [37, 38, 39], 10)
;

query ?
query error
select [1, true, null]
----
[1, 1, ]

query error DataFusion error: This feature is not implemented: ScalarFunctions without MakeArray are not supported: now()
SELECT [now()]
Expand Down Expand Up @@ -1092,18 +1090,27 @@ select list_sort(make_array(1, 3, null, 5, NULL, -5)), list_sort(make_array(1, 3

## array_append (aliases: `list_append`, `array_push_back`, `list_push_back`)

# TODO: array_append with NULLs
# array_append scalar function #1
# query ?
# select array_append(make_array(), 4);
# ----
# [4]
# array_append with NULLs

# array_append scalar function #2
# query ??
# select array_append(make_array(), make_array()), array_append(make_array(), make_array(4));
# ----
# [[]] [[4]]
query ???????
select
array_append(null, 1),
array_append(null, [2, 3]),
array_append(null, [[4]]),
array_append(make_array(), 4),
array_append(make_array(), null),
array_append(make_array(1, null, 3), 4),
array_append(make_array(null, null), 1)
;
----
[1] [[2, 3]] [[[4]]] [4] [] [1, , 3, 4] [, , 1]

query ??
select
array_append(make_array(make_array(1, null, 3)), make_array(null)),
array_append(make_array(make_array(1, null, 3)), null);
----
[[1, , 3], []] [[1, , 3], ]

# array_append scalar function #3
query ???
Expand Down

0 comments on commit 3ba90b4

Please sign in to comment.