Skip to content

Commit

Permalink
Implementation of array_intersect (#8081)
Browse files Browse the repository at this point in the history
* Initial Implementation of array_intersect

Signed-off-by: veeupup <[email protected]>

* fix comments

Signed-off-by: veeupup <[email protected]>

x

---------

Signed-off-by: veeupup <[email protected]>
  • Loading branch information
Veeupup authored Nov 11, 2023
1 parent 4e8777d commit 8966dc0
Show file tree
Hide file tree
Showing 11 changed files with 238 additions and 6 deletions.
6 changes: 6 additions & 0 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ pub enum BuiltinScalarFunction {
ArraySlice,
/// array_to_string
ArrayToString,
/// array_intersect
ArrayIntersect,
/// cardinality
Cardinality,
/// construct an array from columns
Expand Down Expand Up @@ -398,6 +400,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Flatten => Volatility::Immutable,
BuiltinScalarFunction::ArraySlice => Volatility::Immutable,
BuiltinScalarFunction::ArrayToString => Volatility::Immutable,
BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable,
BuiltinScalarFunction::Cardinality => Volatility::Immutable,
BuiltinScalarFunction::MakeArray => Volatility::Immutable,
BuiltinScalarFunction::Ascii => Volatility::Immutable,
Expand Down Expand Up @@ -577,6 +580,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayToString => Ok(Utf8),
BuiltinScalarFunction::ArrayIntersect => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::Cardinality => Ok(UInt64),
BuiltinScalarFunction::MakeArray => match input_expr_types.len() {
0 => Ok(List(Arc::new(Field::new("item", Null, true)))),
Expand Down Expand Up @@ -880,6 +884,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayToString => {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()),
BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()),
BuiltinScalarFunction::MakeArray => {
// 0 or more arguments of arbitrary type
Expand Down Expand Up @@ -1505,6 +1510,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
],
BuiltinScalarFunction::Cardinality => &["cardinality"],
BuiltinScalarFunction::MakeArray => &["make_array", "make_list"],
BuiltinScalarFunction::ArrayIntersect => &["array_intersect", "list_intersect"],

// struct functions
BuiltinScalarFunction::Struct => &["struct"],
Expand Down
6 changes: 6 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,12 @@ nary_scalar_expr!(
array,
"returns an Arrow array using the specified input expressions."
);
scalar_expr!(
ArrayIntersect,
array_intersect,
first_array second_array,
"Returns an array of the elements in the intersection of array1 and array2."
);

// string functions
scalar_expr!(Ascii, ascii, chr, "ASCII code value of the character");
Expand Down
69 changes: 63 additions & 6 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use arrow::array::*;
use arrow::buffer::OffsetBuffer;
use arrow::compute;
use arrow::datatypes::{DataType, Field, UInt64Type};
use arrow::row::{RowConverter, SortField};
use arrow_buffer::NullBuffer;

use datafusion_common::cast::{
Expand All @@ -35,6 +36,7 @@ use datafusion_common::{
DataFusionError, Result,
};

use hashbrown::HashSet;
use itertools::Itertools;

macro_rules! downcast_arg {
Expand Down Expand Up @@ -347,7 +349,7 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result<ArrayRef> {
let data_type = arrays[0].data_type();
let field = Arc::new(Field::new("item", data_type.to_owned(), true));
let elements = arrays.iter().map(|x| x.as_ref()).collect::<Vec<_>>();
let values = arrow::compute::concat(elements.as_slice())?;
let values = compute::concat(elements.as_slice())?;
let list_arr = ListArray::new(
field,
OffsetBuffer::from_lengths(array_lengths),
Expand All @@ -368,7 +370,7 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result<ArrayRef> {
.iter()
.map(|x| x as &dyn Array)
.collect::<Vec<_>>();
let values = arrow::compute::concat(elements.as_slice())?;
let values = compute::concat(elements.as_slice())?;
let list_arr = ListArray::new(
field,
OffsetBuffer::from_lengths(list_array_lengths),
Expand Down Expand Up @@ -767,7 +769,7 @@ fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> {
.collect::<Vec<&dyn Array>>();

// Concatenated array on i-th row
let concated_array = arrow::compute::concat(elements.as_slice())?;
let concated_array = compute::concat(elements.as_slice())?;
array_lengths.push(concated_array.len());
arrays.push(concated_array);
valid.append(true);
Expand All @@ -785,7 +787,7 @@ fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> {
let list_arr = ListArray::new(
Arc::new(Field::new("item", data_type, true)),
OffsetBuffer::from_lengths(array_lengths),
Arc::new(arrow::compute::concat(elements.as_slice())?),
Arc::new(compute::concat(elements.as_slice())?),
Some(NullBuffer::new(buffer)),
);
Ok(Arc::new(list_arr))
Expand Down Expand Up @@ -879,7 +881,7 @@ fn general_repeat(array: &ArrayRef, count_array: &Int64Array) -> Result<ArrayRef
}

let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
let values = arrow::compute::concat(&new_values)?;
let values = compute::concat(&new_values)?;

Ok(Arc::new(ListArray::try_new(
Arc::new(Field::new("item", data_type.to_owned(), true)),
Expand Down Expand Up @@ -947,7 +949,7 @@ fn general_list_repeat(

let lengths = new_values.iter().map(|a| a.len()).collect::<Vec<_>>();
let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
let values = arrow::compute::concat(&new_values)?;
let values = compute::concat(&new_values)?;

Ok(Arc::new(ListArray::try_new(
Arc::new(Field::new("item", data_type.to_owned(), true)),
Expand Down Expand Up @@ -1798,6 +1800,61 @@ pub fn string_to_array<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef
Ok(Arc::new(list_array) as ArrayRef)
}

/// array_intersect SQL function
pub fn array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> {
assert_eq!(args.len(), 2);

let first_array = as_list_array(&args[0])?;
let second_array = as_list_array(&args[1])?;

if first_array.value_type() != second_array.value_type() {
return internal_err!("array_intersect is not implemented for '{first_array:?}' and '{second_array:?}'");
}
let dt = first_array.value_type().clone();

let mut offsets = vec![0];
let mut new_arrays = vec![];

let converter = RowConverter::new(vec![SortField::new(dt.clone())])?;
for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) {
if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) {
let l_values = converter.convert_columns(&[first_arr])?;
let r_values = converter.convert_columns(&[second_arr])?;

let values_set: HashSet<_> = l_values.iter().collect();
let mut rows = Vec::with_capacity(r_values.num_rows());
for r_val in r_values.iter().sorted().dedup() {
if values_set.contains(&r_val) {
rows.push(r_val);
}
}

let last_offset: i32 = match offsets.last().copied() {
Some(offset) => offset,
None => return internal_err!("offsets should not be empty"),
};
offsets.push(last_offset + rows.len() as i32);
let arrays = converter.convert_rows(rows)?;
let array = match arrays.get(0) {
Some(array) => array.clone(),
None => {
return internal_err!(
"array_intersect: failed to get array from rows"
)
}
};
new_arrays.push(array);
}
}

let field = Arc::new(Field::new("item", dt, true));
let offsets = OffsetBuffer::new(offsets.into());
let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
let values = compute::concat(&new_arrays_ref)?;
let arr = Arc::new(ListArray::try_new(field, offsets, values, None)?);
Ok(arr)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
3 changes: 3 additions & 0 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,9 @@ pub fn create_physical_fun(
BuiltinScalarFunction::ArrayToString => Arc::new(|args| {
make_scalar_function(array_expressions::array_to_string)(args)
}),
BuiltinScalarFunction::ArrayIntersect => Arc::new(|args| {
make_scalar_function(array_expressions::array_intersect)(args)
}),
BuiltinScalarFunction::Cardinality => {
Arc::new(|args| make_scalar_function(array_expressions::cardinality)(args))
}
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,7 @@ enum ScalarFunction {
ArrayPopBack = 116;
StringToArray = 117;
ToTimestampNanos = 118;
ArrayIntersect = 119;
}

message ScalarFunctionNode {
Expand Down
3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll,
ScalarFunction::ArraySlice => Self::ArraySlice,
ScalarFunction::ArrayToString => Self::ArrayToString,
ScalarFunction::ArrayIntersect => Self::ArrayIntersect,
ScalarFunction::Cardinality => Self::Cardinality,
ScalarFunction::Array => Self::MakeArray,
ScalarFunction::NullIf => Self::NullIf,
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1481,6 +1481,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll,
BuiltinScalarFunction::ArraySlice => Self::ArraySlice,
BuiltinScalarFunction::ArrayToString => Self::ArrayToString,
BuiltinScalarFunction::ArrayIntersect => Self::ArrayIntersect,
BuiltinScalarFunction::Cardinality => Self::Cardinality,
BuiltinScalarFunction::MakeArray => Self::Array,
BuiltinScalarFunction::NullIf => Self::NullIf,
Expand Down
Loading

0 comments on commit 8966dc0

Please sign in to comment.