From fcd945f27c0c5b971b0cd8d7baf18ad194b8907d Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Fri, 1 Dec 2023 09:39:21 +0100 Subject: [PATCH] simplify the code --- .../physical-expr/src/array_expressions.rs | 104 ++++++++++-------- 1 file changed, 58 insertions(+), 46 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index e053bd7cd2dfa..3f7845b2210a3 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1759,39 +1759,67 @@ pub fn array_ndims(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } -/// general internal function for array_has_all and array_has_all -/// if is_all is true, then it is array_has_all, otherwise it is array_has_any +/// Represents the type of comparison for array_has. +#[derive(Debug, PartialEq)] +enum ComparisonType { + // array_has_all + All, + // array_has_any + Any, + // array_has + Single, +} + fn general_array_has_dispatch( array: &ArrayRef, sub_array: &ArrayRef, - is_all: bool, + comparison_type: ComparisonType, ) -> Result { - check_datatypes("array_has", &[array, sub_array])?; - - let array = as_generic_list_array::(array)?; - let sub_array = as_generic_list_array::(&sub_array)?; + let array = if comparison_type == ComparisonType::Single { + let arr = as_generic_list_array::(array)?; + check_datatypes("array_has", &[arr.values(), sub_array])?; + arr + } else { + check_datatypes("array_has", &[array, sub_array])?; + as_generic_list_array::(array)? + }; let mut boolean_builder = BooleanArray::builder(array.len()); let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; - for (arr, sub_arr) in array.iter().zip(sub_array.iter()) { + + let element = sub_array.clone(); + let sub_array = if comparison_type != ComparisonType::Single { + as_generic_list_array::(sub_array)? + } else { + array + }; + + for (row_idx, (arr, sub_arr)) in array.iter().zip(sub_array.iter()).enumerate() { if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { let arr_values = converter.convert_columns(&[arr])?; - let sub_arr_values = converter.convert_columns(&[sub_arr])?; + let sub_arr_values = if comparison_type != ComparisonType::Single { + converter.convert_columns(&[sub_arr])? + } else { + converter.convert_columns(&[element.clone()])? + }; - let mut res = if is_all { - sub_arr_values + let mut res = match comparison_type { + ComparisonType::All => sub_arr_values .iter() .dedup() - .all(|elem| arr_values.iter().dedup().any(|x| x == elem)) - } else { - sub_arr_values + .all(|elem| arr_values.iter().dedup().any(|x| x == elem)), + ComparisonType::Any => sub_arr_values .iter() .dedup() - .any(|elem| arr_values.iter().dedup().any(|x| x == elem)) + .any(|elem| arr_values.iter().dedup().any(|x| x == elem)), + ComparisonType::Single => arr_values + .iter() + .dedup() + .any(|x| x == sub_arr_values.row(row_idx)), }; - if is_all { + if comparison_type == ComparisonType::Any { res |= res; } @@ -1808,34 +1836,14 @@ pub fn array_has(args: &[ArrayRef]) -> Result { let element = &args[1]; match array_type { - DataType::List(_) => array_has_dispatch::(array, element), - DataType::LargeList(_) => array_has_dispatch::(array, element), - _ => internal_err!("array_has does not support type '{array_type:?}'."), - } -} - -fn array_has_dispatch( - array: &ArrayRef, - element: &ArrayRef, -) -> Result { - let array = as_generic_list_array::(array)?; - - check_datatypes("array_has", &[array.values(), element])?; - let mut boolean_builder = BooleanArray::builder(array.len()); - - let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; - let r_values = converter.convert_columns(&[element.clone()])?; - for (row_idx, arr) in array.iter().enumerate() { - if let Some(arr) = arr { - let arr_values = converter.convert_columns(&[arr])?; - let res = arr_values - .iter() - .dedup() - .any(|x| x == r_values.row(row_idx)); - boolean_builder.append_value(res); + DataType::List(_) => { + general_array_has_dispatch::(array, element, ComparisonType::Single) + } + DataType::LargeList(_) => { + general_array_has_dispatch::(array, element, ComparisonType::Single) } + _ => internal_err!("array_has does not support type '{array_type:?}'."), } - Ok(Arc::new(boolean_builder.finish())) } /// Array_has_any SQL function @@ -1845,9 +1853,11 @@ pub fn array_has_any(args: &[ArrayRef]) -> Result { let sub_array = &args[1]; match array_type { - DataType::List(_) => general_array_has_dispatch::(array, sub_array, false), + DataType::List(_) => { + general_array_has_dispatch::(array, sub_array, ComparisonType::Any) + } DataType::LargeList(_) => { - general_array_has_dispatch::(array, sub_array, false) + general_array_has_dispatch::(array, sub_array, ComparisonType::Any) } _ => internal_err!("array_has_any does not support type '{array_type:?}'."), } @@ -1860,9 +1870,11 @@ pub fn array_has_all(args: &[ArrayRef]) -> Result { let sub_array = &args[1]; match array_type { - DataType::List(_) => general_array_has_dispatch::(array, sub_array, true), + DataType::List(_) => { + general_array_has_dispatch::(array, sub_array, ComparisonType::All) + } DataType::LargeList(_) => { - general_array_has_dispatch::(array, sub_array, true) + general_array_has_dispatch::(array, sub_array, ComparisonType::All) } _ => internal_err!("array_has_all does not support type '{array_type:?}'."), }