From 9b050db59bc58b1e747bdafe62908a6b9ca10a44 Mon Sep 17 00:00:00 2001 From: veeupup Date: Wed, 8 Nov 2023 00:11:53 +0800 Subject: [PATCH] Initial Implementation of array_intersect Signed-off-by: veeupup --- .../tests/sqllogictests/test_files/array.slt | 5 + datafusion/expr/src/built_in_function.rs | 33 +++ datafusion/expr/src/expr_fn.rs | 6 + .../physical-expr/src/array_expressions.rs | 206 ++++++++++++++++++ datafusion/physical-expr/src/functions.rs | 3 + datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/prost.rs | 2 + .../proto/src/logical_plan/from_proto.rs | 1 + datafusion/proto/src/logical_plan/to_proto.rs | 1 + docs/source/user-guide/expressions.md | 1 + 10 files changed, 259 insertions(+) diff --git a/datafusion/core/tests/sqllogictests/test_files/array.slt b/datafusion/core/tests/sqllogictests/test_files/array.slt index 2a1add0b13ba0..304033f12f093 100644 --- a/datafusion/core/tests/sqllogictests/test_files/array.slt +++ b/datafusion/core/tests/sqllogictests/test_files/array.slt @@ -1695,6 +1695,11 @@ select array_has_all(make_array(1,2,3), make_array(1,3)), ---- true false true false false false true true false false true false true +query ?? +SELECT array_intersect(make_array(1,2,3), make_array(2,3,4)), array_intersect(make_array(1,3,5), make_array(2,4,6)); +---- +[2, 3] [] + query BBBB select list_has_all(make_array(1,2,3), make_array(4,5,6)), list_has_all(make_array(1,2,3), make_array(1,2)), diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 33db0f9eb1a4e..00287ec1f9c20 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -153,6 +153,8 @@ pub enum BuiltinScalarFunction { ArrayReplaceAll, /// array_to_string ArrayToString, + /// array_intersect + ArrayIntersect, /// cardinality Cardinality, /// construct an array from columns @@ -359,6 +361,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayReplaceN => Volatility::Immutable, BuiltinScalarFunction::ArrayReplaceAll => Volatility::Immutable, BuiltinScalarFunction::ArrayToString => Volatility::Immutable, + BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable, BuiltinScalarFunction::Cardinality => Volatility::Immutable, BuiltinScalarFunction::MakeArray => Volatility::Immutable, BuiltinScalarFunction::TrimArray => Volatility::Immutable, @@ -543,6 +546,34 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayReplaceN => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayToString => Ok(Utf8), + BuiltinScalarFunction::ArrayIntersect => { + if input_expr_types.len() < 2 || input_expr_types.len() > 2 { + Err(DataFusionError::Internal(format!( + "The {self} function must have two arrays as parameters" + ))) + } else { + match (&input_expr_types[0], &input_expr_types[1]) { + (List(l_field), List(r_field)) => { + if !l_field.data_type().equals_datatype(r_field.data_type()) { + Err(DataFusionError::Internal(format!( + "The {self} function array data type not equal, [0]: {:?}, [1]: {:?}", + l_field.data_type(), r_field.data_type() + ))) + } else { + Ok(List(Arc::new(Field::new( + "item", + l_field.data_type().clone(), + true, + )))) + } + } + _ => Err(DataFusionError::Internal(format!( + "The {} parameters should be array, [0]: {:?}, [1]: {:?}", + self, input_expr_types[0], input_expr_types[1] + ))), + } + } + } BuiltinScalarFunction::Cardinality => Ok(UInt64), BuiltinScalarFunction::MakeArray => match input_expr_types.len() { 0 => Ok(List(Arc::new(Field::new("item", Null, true)))), @@ -834,6 +865,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayToString => { Signature::variadic_any(self.volatility()) } + BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()), BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()), BuiltinScalarFunction::MakeArray => { Signature::variadic_any(self.volatility()) @@ -1324,6 +1356,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { BuiltinScalarFunction::Cardinality => &["cardinality"], BuiltinScalarFunction::MakeArray => &["make_array", "make_list"], BuiltinScalarFunction::TrimArray => &["trim_array"], + BuiltinScalarFunction::ArrayIntersect => &["array_intersect", "list_interact"], } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index cb5317da4408b..a97e6773ee3e7 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -654,6 +654,12 @@ scalar_expr!( array n, "removes the last n elements from the array." ); +scalar_expr!( + ArrayIntersect, + array_intersect, + first_array second_array, + "Returns an array of the elements in the intersection of array1 and array2." +); // string functions scalar_expr!(Ascii, ascii, chr, "ASCII code value of the character"); diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 01b9ac95b4a0a..fe6ad4d18a57a 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -27,6 +27,7 @@ use datafusion_common::cast::{as_generic_string_array, as_int64_array, as_list_a use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::ColumnarValue; +use hashbrown::{HashMap, HashSet}; use itertools::Itertools; use std::sync::Arc; @@ -1820,6 +1821,211 @@ pub fn array_has_all(args: &[ArrayRef]) -> Result { Ok(Arc::new(boolean_builder.finish())) } +macro_rules! array_intersect_normal { + ($FIRST_ARRAY:expr, $SECOND_ARRAY:expr, $DATA_TYPE:expr, $ARRAY_TYPE:ident) => {{ + let mut offsets: Vec = vec![0]; + let mut values = + downcast_arg!(new_empty_array(&$DATA_TYPE), $ARRAY_TYPE).clone(); + + for (first_arr, second_arr) in $FIRST_ARRAY.iter().zip($SECOND_ARRAY.iter()) { + let last_offset: i32 = offsets.last().copied().ok_or_else(|| { + DataFusionError::Internal(format!("offsets should not be empty")) + })?; + match (first_arr, second_arr) { + (Some(first_arr), Some(second_arr)) => { + let first_arr = downcast_arg!(first_arr, $ARRAY_TYPE); + // TODO(veeupup): maybe use stack-implemented map to avoid heap memory allocation + let first_set = first_arr.iter().dedup().flatten().collect::>(); + let second_arr = downcast_arg!(second_arr, $ARRAY_TYPE); + + let mut builder = $ARRAY_TYPE::builder(first_arr.len().min(second_arr.len())); + for elem in second_arr.iter().dedup().flatten() { + if first_set.contains(&elem) { + builder.append_value(elem); + } + } + + let arr = builder.finish(); + values = downcast_arg!( + compute::concat(&[ + &values, + &arr + ])? + .clone(), + $ARRAY_TYPE + ) + .clone(); + offsets.push(last_offset + arr.len() as i32); + }, + _ => { + todo!() + } + } + } + let field = Arc::new(Field::new("item", $DATA_TYPE, true)); + + Ok(Arc::new(ListArray::try_new( + field, + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + )?)) + + }}; +} + +/// array_intersect SQL function +pub fn array_intersect(args: &[ArrayRef]) -> Result { + assert_eq!(args.len(), 2); + + let first_array = as_list_array(&args[0])?; + let second_array = as_list_array(&args[1])?; + + // write array interact method + + match (first_array.value_type(), second_array.value_type()) { + // (DataType::List(_), DataType::List(_)) => concat_internal(args)?, + // (DataType::Utf8, DataType::Utf8) => array_intersect_normal!(arr, element, StringArray), + // (DataType::LargeUtf8, DataType::LargeUtf8) => array_intersect_normal!(arr, element, LargeStringArray), + // (DataType::Boolean, DataType::Boolean) => array_intersect_normal!(arr, element, BooleanArray), + // (DataType::Float32, DataType::Float32) => array_intersect_normal!(arr, element, Float32Array), + // (DataType::Float64, DataType::Float64) => array_intersect_normal!(arr, element, Float64Array), + (DataType::Int8, DataType::Int8) => array_intersect_normal!(first_array, second_array, DataType::Int8, Int8Array), + (DataType::Int16, DataType::Int16) => array_intersect_normal!(first_array, second_array, DataType::Int16, Int16Array), + (DataType::Int32, DataType::Int32) => array_intersect_normal!(first_array, second_array, DataType::Int32, Int32Array), + (DataType::Int64, DataType::Int64) => array_intersect_normal!(first_array, second_array, DataType::Int64, Int64Array), + (DataType::UInt8, DataType::UInt8) => array_intersect_normal!(first_array, second_array, DataType::UInt8, UInt8Array), + (DataType::UInt16, DataType::UInt16) => array_intersect_normal!(first_array, second_array, DataType::UInt16, UInt16Array), + (DataType::UInt32, DataType::UInt32) => array_intersect_normal!(first_array, second_array, DataType::UInt32, UInt32Array), + (DataType::UInt64, DataType::UInt64) => array_intersect_normal!(first_array, second_array, DataType::UInt64, UInt64Array), + // (DataType::Null, _) => return Ok(array(&[ColumnarValue::Array(args[1].clone())])?.into_array(1)), + (DataType::Int64, DataType::Int64) => { + let mut offsets: Vec = vec![0]; + let mut values = + downcast_arg!(new_empty_array(&DataType::Int64), Int64Array).clone(); + + for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) { + let last_offset: i32 = offsets.last().copied().ok_or_else(|| { + DataFusionError::Internal(format!("offsets should not be empty")) + })?; + match (first_arr, second_arr) { + (Some(first_arr), Some(second_arr)) => { + let first_arr = downcast_arg!(first_arr, Int64Array); + let first_set = first_arr.iter().dedup().flatten().collect::>(); + println!("{:?}", first_set); + let second_arr = downcast_arg!(second_arr, Int64Array); + print!("{:?}", second_arr); + + let mut builder = Int64Array::builder(first_arr.len().min(second_arr.len())); + for elem in second_arr.iter().dedup().flatten() { + println!("second_arr: {:?}", elem); + if first_set.contains(&elem) { + builder.append_value(elem); + } + } + + let arr = builder.finish(); + values = downcast_arg!( + compute::concat(&[ + &values, + &arr + ])? + .clone(), + Int64Array + ) + .clone(); + offsets.push(last_offset + arr.len() as i32); + }, + _ => { + todo!() + } + } + } + let field = Arc::new(Field::new("item", DataType::Int64, true)); + + Ok(Arc::new(ListArray::try_new( + field, + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + )?)) + }, + (first_value_dt, second_value_dt) => { + Err(DataFusionError::NotImplemented(format!( + "array_intersect is not implemented for '{first_value_dt:?}' and '{second_value_dt:?}'", + ))) + } + } + // for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) { + // if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) { + // let ret: ArrayRef = match (first_arr.data_type(), second_arr.data_type()) { + // (DataType::List(_), DataType::List(_)) => { + // todo!() + // }, + // // Int64, Int32, Int16, Int8 + // // UInt64, UInt32, UInt16, UInt8 + // (DataType::Int64, DataType::Int64) => { + // // array_intersect_non_list_check!(first_arr, second_arr, Int64Array) + // let first_arr = downcast_arg!(first_array, Int64Array); + // let second_arr = downcast_arg!(second_array, Int64Array); + + // let mut offsets: Vec = vec![0]; + // let mut values = + // Int64Array::builder(first_arr.len().min(second_arr.len())); + // let first_set = first_arr.iter().dedup().flatten().collect::>(); + // for elem in second_arr.iter().dedup().flatten() { + // if first_set.contains(&elem) { + // let last_offset: i32 = offsets.last().copied().ok_or_else(|| { + // DataFusionError::Internal(format!("offsets should not be empty")) + // })?; + // values.append_value(elem); + // offsets.push(last_offset + 1); + // } + // } + + // let field = Arc::new(Field::new("item", DataType::Int64, true)); + + // Arc::new(ListArray::try_new( + // field, + // OffsetBuffer::new(offsets.into()), + // Arc::new(values.finish()), + // None, + // )?) + // }, + // (DataType::Int32, DataType::Int32) => { + // array_intersect_non_list_check!(first_arr, second_arr, Int32Array) + // } + // (DataType::Int16, DataType::Int16) => { + // array_intersect_non_list_check!(first_arr, second_arr, Int16Array) + // } + // (DataType::Int8, DataType::Int8) => { + // array_intersect_non_list_check!(first_arr, second_arr, Int8Array) + // } + // (DataType::UInt64, DataType::UInt64) => { + // array_intersect_non_list_check!(first_arr, second_arr, UInt64Array) + // } + // (DataType::UInt32, DataType::UInt32) => { + // array_intersect_non_list_check!(first_arr, second_arr, UInt32Array) + // } + // (DataType::UInt16, DataType::UInt16) => { + // array_intersect_non_list_check!(first_arr, second_arr, UInt16Array) + // } + // (DataType::UInt8, DataType::UInt8) => { + // array_intersect_non_list_check!(first_arr, second_arr, UInt8Array) + // } + + // (first_arr_type, second_arr_type) => Err(DataFusionError::NotImplemented(format!( + // "array_intersect is not implemented for '{first_arr_type:?}' and '{second_arr_type:?}'", + // )))?, + // }; + // return Ok(ret); + // } + // } + // Err(DataFusionError::Internal(format!( + // "array_intersect does not support Null type for element in sub_array" + // ))) +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 948cb4ec47b0b..a6bca73b6e7d3 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -465,6 +465,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayToString => Arc::new(|args| { make_scalar_function(array_expressions::array_to_string)(args) }), + BuiltinScalarFunction::ArrayIntersect => Arc::new(|args| { + make_scalar_function(array_expressions::array_intersect)(args) + }), BuiltinScalarFunction::Cardinality => { Arc::new(|args| make_scalar_function(array_expressions::cardinality)(args)) } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index e9ae76b25d185..a4c1d0e86e7be 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -577,6 +577,7 @@ enum ScalarFunction { ArrayReplaceN = 108; ArrayRemoveAll = 109; ArrayReplaceAll = 110; + ArrayIntersect = 111; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index c6f3a23ed65f2..b177a516b2f52 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2310,6 +2310,7 @@ pub enum ScalarFunction { ArrayReplaceN = 108, ArrayRemoveAll = 109, ArrayReplaceAll = 110, + ArrayIntersect = 111, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2428,6 +2429,7 @@ impl ScalarFunction { ScalarFunction::ArrayReplaceN => "ArrayReplaceN", ScalarFunction::ArrayRemoveAll => "ArrayRemoveAll", ScalarFunction::ArrayReplaceAll => "ArrayReplaceAll", + ScalarFunction::ArrayIntersect => "ArrayIntersect", } } /// Creates an enum from field names used in the ProtoBuf definition. diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 2591f179b99a1..84c1fde44b8f3 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -469,6 +469,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ArrayReplaceN => Self::ArrayReplaceN, ScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll, ScalarFunction::ArrayToString => Self::ArrayToString, + ScalarFunction::ArrayIntersect => Self::ArrayIntersect, ScalarFunction::Cardinality => Self::Cardinality, ScalarFunction::Array => Self::MakeArray, ScalarFunction::TrimArray => Self::TrimArray, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index df5701a282c32..226e16448ffd2 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1417,6 +1417,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ArrayReplaceN => Self::ArrayReplaceN, BuiltinScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll, BuiltinScalarFunction::ArrayToString => Self::ArrayToString, + BuiltinScalarFunction::ArrayIntersect => Self::ArrayIntersect, BuiltinScalarFunction::Cardinality => Self::Cardinality, BuiltinScalarFunction::MakeArray => Self::Array, BuiltinScalarFunction::TrimArray => Self::TrimArray, diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 139e968eccfba..fc0bc1815b1f2 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -200,6 +200,7 @@ Unlike to some databases the math functions in Datafusion works the same way as | array_replace_n(array, from, to, max) | Replaces the first `max` occurrences of the specified element with another specified element. `array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2) -> [1, 5, 5, 3, 2, 1, 4]` | | array_replace_all(array, from, to) | Replaces all occurrences of the specified element with another specified element. `array_replace_all([1, 2, 2, 3, 2, 1, 4], 2, 5) -> [1, 5, 5, 3, 5, 1, 4]` | | array_to_string(array, delimeter) | Converts each element to its text representation. `array_to_string([1, 2, 3, 4], ',') -> 1,2,3,4` | +| array_intersect(array1, array2) | Returns an array of the elements in the intersection of array1 and array2. `array_intersect([1, 2, 3, 4], [5, 6, 3, 4]) -> [3, 4]` | | cardinality(array) | Returns the total number of elements in the array. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6` | | make_array(value1, [value2 [, ...]]) | Returns an Arrow array using the specified input expressions. `make_array(1, 2, 3) -> [1, 2, 3]` | | trim_array(array, n) | Removes the last n elements from the array. |