Skip to content

Commit

Permalink
simplify the code
Browse files Browse the repository at this point in the history
  • Loading branch information
Weijun-H committed Dec 1, 2023
1 parent 243e2ed commit fcd945f
Showing 1 changed file with 58 additions and 46 deletions.
104 changes: 58 additions & 46 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1759,39 +1759,67 @@ pub fn array_ndims(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(result) as ArrayRef)
}

/// general internal function for array_has_all and array_has_all
/// if is_all is true, then it is array_has_all, otherwise it is array_has_any
/// Represents the type of comparison for array_has.
#[derive(Debug, PartialEq)]
enum ComparisonType {
// array_has_all
All,
// array_has_any
Any,
// array_has
Single,
}

fn general_array_has_dispatch<O: OffsetSizeTrait>(
array: &ArrayRef,
sub_array: &ArrayRef,
is_all: bool,
comparison_type: ComparisonType,
) -> Result<ArrayRef> {
check_datatypes("array_has", &[array, sub_array])?;

let array = as_generic_list_array::<O>(array)?;
let sub_array = as_generic_list_array::<O>(&sub_array)?;
let array = if comparison_type == ComparisonType::Single {
let arr = as_generic_list_array::<O>(array)?;
check_datatypes("array_has", &[arr.values(), sub_array])?;
arr
} else {
check_datatypes("array_has", &[array, sub_array])?;
as_generic_list_array::<O>(array)?
};

let mut boolean_builder = BooleanArray::builder(array.len());

let converter = RowConverter::new(vec![SortField::new(array.value_type())])?;
for (arr, sub_arr) in array.iter().zip(sub_array.iter()) {

let element = sub_array.clone();
let sub_array = if comparison_type != ComparisonType::Single {
as_generic_list_array::<O>(sub_array)?
} else {
array
};

for (row_idx, (arr, sub_arr)) in array.iter().zip(sub_array.iter()).enumerate() {
if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) {
let arr_values = converter.convert_columns(&[arr])?;
let sub_arr_values = converter.convert_columns(&[sub_arr])?;
let sub_arr_values = if comparison_type != ComparisonType::Single {
converter.convert_columns(&[sub_arr])?
} else {
converter.convert_columns(&[element.clone()])?
};

let mut res = if is_all {
sub_arr_values
let mut res = match comparison_type {
ComparisonType::All => sub_arr_values
.iter()
.dedup()
.all(|elem| arr_values.iter().dedup().any(|x| x == elem))
} else {
sub_arr_values
.all(|elem| arr_values.iter().dedup().any(|x| x == elem)),
ComparisonType::Any => sub_arr_values
.iter()
.dedup()
.any(|elem| arr_values.iter().dedup().any(|x| x == elem))
.any(|elem| arr_values.iter().dedup().any(|x| x == elem)),
ComparisonType::Single => arr_values
.iter()
.dedup()
.any(|x| x == sub_arr_values.row(row_idx)),
};

if is_all {
if comparison_type == ComparisonType::Any {
res |= res;
}

Expand All @@ -1808,34 +1836,14 @@ pub fn array_has(args: &[ArrayRef]) -> Result<ArrayRef> {
let element = &args[1];

match array_type {
DataType::List(_) => array_has_dispatch::<i32>(array, element),
DataType::LargeList(_) => array_has_dispatch::<i64>(array, element),
_ => internal_err!("array_has does not support type '{array_type:?}'."),
}
}

fn array_has_dispatch<O: OffsetSizeTrait>(
array: &ArrayRef,
element: &ArrayRef,
) -> Result<ArrayRef> {
let array = as_generic_list_array::<O>(array)?;

check_datatypes("array_has", &[array.values(), element])?;
let mut boolean_builder = BooleanArray::builder(array.len());

let converter = RowConverter::new(vec![SortField::new(array.value_type())])?;
let r_values = converter.convert_columns(&[element.clone()])?;
for (row_idx, arr) in array.iter().enumerate() {
if let Some(arr) = arr {
let arr_values = converter.convert_columns(&[arr])?;
let res = arr_values
.iter()
.dedup()
.any(|x| x == r_values.row(row_idx));
boolean_builder.append_value(res);
DataType::List(_) => {
general_array_has_dispatch::<i32>(array, element, ComparisonType::Single)
}
DataType::LargeList(_) => {
general_array_has_dispatch::<i64>(array, element, ComparisonType::Single)
}
_ => internal_err!("array_has does not support type '{array_type:?}'."),
}
Ok(Arc::new(boolean_builder.finish()))
}

/// Array_has_any SQL function
Expand All @@ -1845,9 +1853,11 @@ pub fn array_has_any(args: &[ArrayRef]) -> Result<ArrayRef> {
let sub_array = &args[1];

match array_type {
DataType::List(_) => general_array_has_dispatch::<i32>(array, sub_array, false),
DataType::List(_) => {
general_array_has_dispatch::<i32>(array, sub_array, ComparisonType::Any)
}
DataType::LargeList(_) => {
general_array_has_dispatch::<i64>(array, sub_array, false)
general_array_has_dispatch::<i64>(array, sub_array, ComparisonType::Any)
}
_ => internal_err!("array_has_any does not support type '{array_type:?}'."),
}
Expand All @@ -1860,9 +1870,11 @@ pub fn array_has_all(args: &[ArrayRef]) -> Result<ArrayRef> {
let sub_array = &args[1];

match array_type {
DataType::List(_) => general_array_has_dispatch::<i32>(array, sub_array, true),
DataType::List(_) => {
general_array_has_dispatch::<i32>(array, sub_array, ComparisonType::All)
}
DataType::LargeList(_) => {
general_array_has_dispatch::<i64>(array, sub_array, true)
general_array_has_dispatch::<i64>(array, sub_array, ComparisonType::All)
}
_ => internal_err!("array_has_all does not support type '{array_type:?}'."),
}
Expand Down

0 comments on commit fcd945f

Please sign in to comment.