Skip to content

Commit

Permalink
feat: support FixedSizeList Type Coercion (#9108)
Browse files Browse the repository at this point in the history
* support FixedSizeList Type Coercion

* add allow null type coercion parameter

* support null column in FixedSizeList

* Add test

* Add tests for cardinality with fixed size lists

* chore

* fix ci

* add comment

* Fix array_element function signature

* Remove unused imports and simplify code

* Fix array function signatures and behavior

* fix conflict

* fix conflict

* add tests for FixedSizeList

* remove unreacheable null check

* simplify the code

* remove null checking

* reformat output

* simplify code

* add tests for array_dims

* Refactor type coercion functions in datafusion/expr module
  • Loading branch information
Weijun-H authored Feb 26, 2024
1 parent 3050699 commit b728232
Show file tree
Hide file tree
Showing 4 changed files with 628 additions and 83 deletions.
23 changes: 12 additions & 11 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use crate::{
};

use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit};
use datafusion_common::{internal_err, plan_err, DataFusionError, Result};
use datafusion_common::{exec_err, plan_err, DataFusionError, Result};

use strum::IntoEnumIterator;
use strum_macros::EnumIter;
Expand Down Expand Up @@ -543,10 +543,11 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Flatten => {
fn get_base_type(data_type: &DataType) -> Result<DataType> {
match data_type {
DataType::List(field) if matches!(field.data_type(), DataType::List(_)) => get_base_type(field.data_type()),
DataType::List(field) | DataType::FixedSizeList(field, _) if matches!(field.data_type(), DataType::List(_)|DataType::FixedSizeList(_,_ )) => get_base_type(field.data_type()),
DataType::LargeList(field) if matches!(field.data_type(), DataType::LargeList(_)) => get_base_type(field.data_type()),
DataType::Null | DataType::List(_) | DataType::LargeList(_) => Ok(data_type.to_owned()),
_ => internal_err!("Not reachable, data_type should be List or LargeList"),
DataType::FixedSizeList(field,_ ) => Ok(DataType::List(field.clone())),
_ => exec_err!("Not reachable, data_type should be List, LargeList or FixedSizeList"),
}
}

Expand Down Expand Up @@ -929,18 +930,18 @@ impl BuiltinScalarFunction {
// 0 or more arguments of arbitrary type
Signature::one_of(vec![VariadicEqual, Any(0)], self.volatility())
}
BuiltinScalarFunction::ArrayPopFront => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayPopBack => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayPopFront => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayPopBack => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayConcat => {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayEmpty => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayDims => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayEmpty => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayElement => {
Signature::array_and_index(self.volatility())
}
BuiltinScalarFunction::ArrayExcept => Signature::any(2, self.volatility()),
BuiltinScalarFunction::Flatten => Signature::any(1, self.volatility()),
BuiltinScalarFunction::Flatten => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayHasAll | BuiltinScalarFunction::ArrayHasAny => {
Signature::any(2, self.volatility())
}
Expand All @@ -950,8 +951,8 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayLength => {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayNdims => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayDistinct => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayNdims => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayDistinct => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayPosition => {
Signature::array_and_element_and_optional_index(self.volatility())
}
Expand Down Expand Up @@ -981,7 +982,7 @@ impl BuiltinScalarFunction {

BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayUnion => Signature::any(2, self.volatility()),
BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()),
BuiltinScalarFunction::Cardinality => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayResize => {
Signature::variadic_any(self.volatility())
}
Expand Down
17 changes: 16 additions & 1 deletion datafusion/expr/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ pub enum TypeSignature {
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ArrayFunctionSignature {
/// Specialized Signature for ArrayAppend and similar functions
/// The first argument should be List/LargeList, and the second argument should be non-list or list.
/// The first argument should be List/LargeList/FixedSizedList, and the second argument should be non-list or list.
/// The second argument's list dimension should be one dimension less than the first argument's list dimension.
/// List dimension of the List/LargeList is equivalent to the number of List.
/// List dimension of the non-list is 0.
Expand All @@ -133,9 +133,14 @@ pub enum ArrayFunctionSignature {
/// The first argument's list dimension should be one dimension less than the second argument's list dimension.
ElementAndArray,
/// Specialized Signature for Array functions of the form (List/LargeList, Index)
/// The first argument should be List/LargeList/FixedSizedList, and the second argument should be Int64.
ArrayAndIndex,
/// Specialized Signature for Array functions of the form (List/LargeList, Element, Optional Index)
ArrayAndElementAndOptionalIndex,
/// Specialized Signature for ArrayEmpty and similar functions
/// The function takes a single argument that must be a List/LargeList/FixedSizeList
/// or something that can be coerced to one of those types.
Array,
}

impl std::fmt::Display for ArrayFunctionSignature {
Expand All @@ -153,6 +158,9 @@ impl std::fmt::Display for ArrayFunctionSignature {
ArrayFunctionSignature::ArrayAndIndex => {
write!(f, "array, index")
}
ArrayFunctionSignature::Array => {
write!(f, "array")
}
}
}
}
Expand Down Expand Up @@ -325,6 +333,13 @@ impl Signature {
volatility,
}
}
/// Specialized Signature for ArrayEmpty and similar functions
pub fn array(volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::ArraySignature(ArrayFunctionSignature::Array),
volatility,
}
}
}

/// Monotonicity of the `ScalarFunctionExpr` with respect to its arguments.
Expand Down
109 changes: 59 additions & 50 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,36 @@ fn get_valid_types(
signature: &TypeSignature,
current_types: &[DataType],
) -> Result<Vec<Vec<DataType>>> {
fn array_element_and_optional_index(
current_types: &[DataType],
) -> Result<Vec<Vec<DataType>>> {
// make sure there's 2 or 3 arguments
if !(current_types.len() == 2 || current_types.len() == 3) {
return Ok(vec![vec![]]);
}

let first_two_types = &current_types[0..2];
let mut valid_types = array_append_or_prepend_valid_types(first_two_types, true)?;

// Early return if there are only 2 arguments
if current_types.len() == 2 {
return Ok(valid_types);
}

let valid_types_with_index = valid_types
.iter()
.map(|t| {
let mut t = t.clone();
t.push(DataType::Int64);
t
})
.collect::<Vec<_>>();

valid_types.extend(valid_types_with_index);

Ok(valid_types)
}

fn array_append_or_prepend_valid_types(
current_types: &[DataType],
is_append: bool,
Expand Down Expand Up @@ -111,71 +141,37 @@ fn get_valid_types(
)
})?;

let array_type = datafusion_common::utils::coerced_type_with_base_type_only(
let new_array_type = datafusion_common::utils::coerced_type_with_base_type_only(
array_type,
&new_base_type,
);

match array_type {
match new_array_type {
DataType::List(ref field)
| DataType::LargeList(ref field)
| DataType::FixedSizeList(ref field, _) => {
let elem_type = field.data_type();
let new_elem_type = field.data_type();
if is_append {
Ok(vec![vec![array_type.clone(), elem_type.clone()]])
Ok(vec![vec![new_array_type.clone(), new_elem_type.clone()]])
} else {
Ok(vec![vec![elem_type.to_owned(), array_type.clone()]])
Ok(vec![vec![new_elem_type.to_owned(), new_array_type.clone()]])
}
}
_ => Ok(vec![vec![]]),
}
}
fn array_element_and_optional_index(
current_types: &[DataType],
) -> Result<Vec<Vec<DataType>>> {
// make sure there's 2 or 3 arguments
if !(current_types.len() == 2 || current_types.len() == 3) {
return Ok(vec![vec![]]);
}

let first_two_types = &current_types[0..2];
let mut valid_types = array_append_or_prepend_valid_types(first_two_types, true)?;

// Early return if there are only 2 arguments
if current_types.len() == 2 {
return Ok(valid_types);
}

let valid_types_with_index = valid_types
.iter()
.map(|t| {
let mut t = t.clone();
t.push(DataType::Int64);
t
})
.collect::<Vec<_>>();

valid_types.extend(valid_types_with_index);

Ok(valid_types)
}
fn array_and_index(current_types: &[DataType]) -> Result<Vec<Vec<DataType>>> {
if current_types.len() != 2 {
return Ok(vec![vec![]]);
}

let array_type = &current_types[0];

fn array(array_type: &DataType) -> Option<DataType> {
match array_type {
DataType::List(_)
| DataType::LargeList(_)
| DataType::FixedSizeList(_, _) => {
let array_type = coerced_fixed_size_list_to_list(array_type);
Ok(vec![vec![array_type, DataType::Int64]])
Some(array_type)
}
_ => Ok(vec![vec![]]),
_ => None,
}
}

let valid_types = match signature {
TypeSignature::Variadic(valid_types) => valid_types
.iter()
Expand Down Expand Up @@ -211,19 +207,32 @@ fn get_valid_types(
TypeSignature::ArraySignature(ref function_signature) => match function_signature
{
ArrayFunctionSignature::ArrayAndElement => {
return array_append_or_prepend_valid_types(current_types, true)
array_append_or_prepend_valid_types(current_types, true)?
}
ArrayFunctionSignature::ArrayAndElementAndOptionalIndex => {
return array_element_and_optional_index(current_types)
ArrayFunctionSignature::ElementAndArray => {
array_append_or_prepend_valid_types(current_types, false)?
}
ArrayFunctionSignature::ArrayAndIndex => {
return array_and_index(current_types)
if current_types.len() != 2 {
return Ok(vec![vec![]]);
}
array(&current_types[0]).map_or_else(
|| vec![vec![]],
|array_type| vec![vec![array_type, DataType::Int64]],
)
}
ArrayFunctionSignature::ElementAndArray => {
return array_append_or_prepend_valid_types(current_types, false)
ArrayFunctionSignature::ArrayAndElementAndOptionalIndex => {
array_element_and_optional_index(current_types)?
}
},
ArrayFunctionSignature::Array => {
if current_types.len() != 1 {
return Ok(vec![vec![]]);
}

array(&current_types[0])
.map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
}
},
TypeSignature::Any(number) => {
if current_types.len() != *number {
return plan_err!(
Expand Down
Loading

0 comments on commit b728232

Please sign in to comment.