Skip to content

Commit

Permalink
Initial Implementation of array_intersect
Browse files Browse the repository at this point in the history
Signed-off-by: veeupup <[email protected]>
  • Loading branch information
Veeupup committed Nov 9, 2023
1 parent 4512805 commit 41663a9
Show file tree
Hide file tree
Showing 11 changed files with 258 additions and 14 deletions.
20 changes: 8 additions & 12 deletions datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1536,12 +1536,10 @@ mod test {
.unwrap()
.resolve(&schema)
.unwrap();
let r4 = apache_avro::to_value(serde_json::json!({
"col1": null
}))
.unwrap()
.resolve(&schema)
.unwrap();
let r4 = apache_avro::to_value(serde_json::json!({ "col1": null }))
.unwrap()
.resolve(&schema)
.unwrap();

let mut w = apache_avro::Writer::new(&schema, vec![]);
w.append(r1).unwrap();
Expand Down Expand Up @@ -1600,12 +1598,10 @@ mod test {
}"#,
)
.unwrap();
let r1 = apache_avro::to_value(serde_json::json!({
"col1": null
}))
.unwrap()
.resolve(&schema)
.unwrap();
let r1 = apache_avro::to_value(serde_json::json!({ "col1": null }))
.unwrap()
.resolve(&schema)
.unwrap();
let r2 = apache_avro::to_value(serde_json::json!({
"col1": {
"col2": "hello"
Expand Down
33 changes: 33 additions & 0 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ pub enum BuiltinScalarFunction {
ArraySlice,
/// array_to_string
ArrayToString,
/// array_intersect
ArrayIntersect,
/// cardinality
Cardinality,
/// construct an array from columns
Expand Down Expand Up @@ -398,6 +400,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Flatten => Volatility::Immutable,
BuiltinScalarFunction::ArraySlice => Volatility::Immutable,
BuiltinScalarFunction::ArrayToString => Volatility::Immutable,
BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable,
BuiltinScalarFunction::Cardinality => Volatility::Immutable,
BuiltinScalarFunction::MakeArray => Volatility::Immutable,
BuiltinScalarFunction::Ascii => Volatility::Immutable,
Expand Down Expand Up @@ -577,6 +580,34 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayToString => Ok(Utf8),
BuiltinScalarFunction::ArrayIntersect => {
if input_expr_types.len() < 2 || input_expr_types.len() > 2 {
Err(DataFusionError::Internal(format!(
"The {self} function must have two arrays as parameters"
)))
} else {
match (&input_expr_types[0], &input_expr_types[1]) {
(List(l_field), List(r_field)) => {
if !l_field.data_type().equals_datatype(r_field.data_type()) {
Err(DataFusionError::Internal(format!(
"The {self} function array data type not equal, [0]: {:?}, [1]: {:?}",
l_field.data_type(), r_field.data_type()
)))
} else {
Ok(List(Arc::new(Field::new(
"item",
l_field.data_type().clone(),
true,
))))
}
}
_ => Err(DataFusionError::Internal(format!(
"The {} parameters should be array, [0]: {:?}, [1]: {:?}",
self, input_expr_types[0], input_expr_types[1]
))),
}
}
}
BuiltinScalarFunction::Cardinality => Ok(UInt64),
BuiltinScalarFunction::MakeArray => match input_expr_types.len() {
0 => Ok(List(Arc::new(Field::new("item", Null, true)))),
Expand Down Expand Up @@ -880,6 +911,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayToString => {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()),
BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()),
BuiltinScalarFunction::MakeArray => {
// 0 or more arguments of arbitrary type
Expand Down Expand Up @@ -1505,6 +1537,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
],
BuiltinScalarFunction::Cardinality => &["cardinality"],
BuiltinScalarFunction::MakeArray => &["make_array", "make_list"],
BuiltinScalarFunction::ArrayIntersect => &["array_intersect", "list_interact"],

// struct functions
BuiltinScalarFunction::Struct => &["struct"],
Expand Down
6 changes: 6 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,12 @@ nary_scalar_expr!(
array,
"returns an Arrow array using the specified input expressions."
);
scalar_expr!(
ArrayIntersect,
array_intersect,
first_array second_array,
"Returns an array of the elements in the intersection of array1 and array2."
);

// string functions
scalar_expr!(Ascii, ascii, chr, "ASCII code value of the character");
Expand Down
65 changes: 63 additions & 2 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ use std::any::type_name;
use std::sync::Arc;

use arrow::array::*;
use arrow::buffer::OffsetBuffer;
use arrow::compute;
use arrow::buffer::{Buffer, OffsetBuffer};
use arrow::compute::{self, concat};
use arrow::datatypes::{DataType, Field, UInt64Type};
use arrow::row::{RowConverter, SortField};
use arrow_buffer::NullBuffer;

use datafusion_common::cast::{
Expand All @@ -35,6 +36,9 @@ use datafusion_common::{
DataFusionError, Result,
};

use datafusion_common::ScalarValue;
use datafusion_expr::ColumnarValue;
use hashbrown::{HashMap, HashSet};
use itertools::Itertools;

macro_rules! downcast_arg {
Expand Down Expand Up @@ -1807,6 +1811,63 @@ pub fn string_to_array<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef
Ok(Arc::new(list_array) as ArrayRef)
}

/// array_intersect SQL function
pub fn array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> {
assert_eq!(args.len(), 2);

let first_array = as_list_array(&args[0])?;
let second_array = as_list_array(&args[1])?;

if first_array.value_type() != second_array.value_type() {
return Err(DataFusionError::NotImplemented(format!(
"array_intersect is not implemented for '{first_array:?}' and '{second_array:?}'",
)));
}
let dt = first_array.value_type().clone();

let mut offsets = vec![0];
let mut tmp_values = vec![];

let mut converter = RowConverter::new(vec![SortField::new(dt.clone())])?;
for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) {
if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) {
let l_values = converter.convert_columns(&[first_arr])?;
let r_values = converter.convert_columns(&[second_arr])?;

let mut values_set: HashSet<_> = l_values.iter().collect();
let mut rows = Vec::with_capacity(r_values.num_rows());
for r_val in r_values.iter().sorted().dedup() {
if values_set.contains(&r_val) {
rows.push(r_val);
}
}

let last_offset: i32 = offsets.last().copied().ok_or_else(|| {
DataFusionError::Internal(format!("offsets should not be empty"))
})?;
offsets.push(last_offset + rows.len() as i32);
let tmp_value = converter.convert_rows(rows)?;
tmp_values.push(
tmp_value
.get(0)
.ok_or_else(|| {
DataFusionError::Internal(format!(
"array_intersect: failed to get value from rows"
))
})?
.clone(),
);
}
}

let field = Arc::new(Field::new("item", dt, true));
let offsets = OffsetBuffer::new(offsets.into());
let tmp_values_ref = tmp_values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
let values = concat(&tmp_values_ref)?;
let arr = Arc::new(ListArray::try_new(field, offsets, values, None)?);
Ok(arr)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
3 changes: 3 additions & 0 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,9 @@ pub fn create_physical_fun(
BuiltinScalarFunction::ArrayToString => Arc::new(|args| {
make_scalar_function(array_expressions::array_to_string)(args)
}),
BuiltinScalarFunction::ArrayIntersect => Arc::new(|args| {
make_scalar_function(array_expressions::array_intersect)(args)
}),
BuiltinScalarFunction::Cardinality => {
Arc::new(|args| make_scalar_function(array_expressions::cardinality)(args))
}
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,7 @@ enum ScalarFunction {
ArrayPopBack = 116;
StringToArray = 117;
ToTimestampNanos = 118;
ArrayIntersect = 119;
}

message ScalarFunctionNode {
Expand Down
2 changes: 2 additions & 0 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll,
ScalarFunction::ArraySlice => Self::ArraySlice,
ScalarFunction::ArrayToString => Self::ArrayToString,
ScalarFunction::ArrayIntersect => Self::ArrayIntersect,
ScalarFunction::Cardinality => Self::Cardinality,
ScalarFunction::Array => Self::MakeArray,
ScalarFunction::NullIf => Self::NullIf,
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1481,6 +1481,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll,
BuiltinScalarFunction::ArraySlice => Self::ArraySlice,
BuiltinScalarFunction::ArrayToString => Self::ArrayToString,
BuiltinScalarFunction::ArrayIntersect => Self::ArrayIntersect,
BuiltinScalarFunction::Cardinality => Self::Cardinality,
BuiltinScalarFunction::MakeArray => Self::Array,
BuiltinScalarFunction::NullIf => Self::NullIf,
Expand Down
Loading

0 comments on commit 41663a9

Please sign in to comment.