Skip to content

Commit

Permalink
First attempt to implement the non-trivial case
Browse files Browse the repository at this point in the history
  • Loading branch information
edmondop committed Nov 7, 2023
1 parent 48d44e3 commit fafe3db
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 23 deletions.
7 changes: 1 addition & 6 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -704,12 +704,7 @@ scalar_expr!(
array delimiter,
"converts each element to its text representation."
);
scalar_expr!(
ArrayUnion,
array_union,
array1 array2,
"returns an array of the elements in the union of array1 and array2 without duplicates."
);
nary_scalar_expr!(ArrayUnion, array_union, "returns an array of the elements in the union of array1 and array2 without duplicates.");

scalar_expr!(
Cardinality,
Expand Down
86 changes: 71 additions & 15 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ use arrow::array::*;
use arrow::buffer::OffsetBuffer;
use arrow::compute;
use arrow::datatypes::{DataType, Field, UInt64Type};
use arrow::row::{RowConverter, SortField};
use arrow_buffer::NullBuffer;
use core::any::type_name;
use datafusion_common::cast::{as_generic_string_array, as_int64_array, as_list_array};
use datafusion_common::{exec_err, internal_err, not_impl_err, plan_err, ScalarValue};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;
use itertools::Itertools;
use std::collections::HashSet;
use std::sync::Arc;

macro_rules! downcast_arg {
Expand Down Expand Up @@ -1478,28 +1480,83 @@ macro_rules! to_string {
}};
}

/// Computes the per-row set union of two list arrays that share a value type.
///
/// For each row, the output list holds the distinct elements of the left
/// list followed by the distinct elements of the right list, in
/// first-occurrence order. Elements are compared via arrow's row encoding,
/// so any value type supported by [`RowConverter`] works.
///
/// # Errors
/// Returns an error if the row converter cannot be built for the value type
/// or if converting the child values to/from rows fails.
fn union_generic_lists<OffsetSize: OffsetSizeTrait>(
    l: &GenericListArray<OffsetSize>,
    r: &GenericListArray<OffsetSize>,
) -> Result<GenericListArray<OffsetSize>, DataFusionError> {
    // Row encoding yields a hashable, totally-comparable byte form for each
    // element, which lets a HashSet deduplicate arbitrary nested values.
    let converter = RowConverter::new(vec![SortField::new(l.value_type().clone())])?;

    // A row is null in the result if it is null in either input.
    let nulls = NullBuffer::union(l.nulls(), r.nulls());
    let field = Arc::new(Field::new(
        "item",
        l.value_type().to_owned(),
        l.is_nullable(),
    ));

    let l_values = converter.convert_columns(&[l.values().clone()])?;
    let r_values = converter.convert_columns(&[r.values().clone()])?;

    // Might be worth adding an upstream OffsetBufferBuilder
    let mut offsets = Vec::<OffsetSize>::with_capacity(l.len() + 1);
    offsets.push(OffsetSize::usize_as(0));
    let mut rows = Vec::with_capacity(l_values.num_rows() + r_values.num_rows());
    let mut dedup = HashSet::new();

    for (l_w, r_w) in l.offsets().windows(2).zip(r.offsets().windows(2)) {
        // Push each element only the first time it is seen. Collecting from
        // the HashSet afterwards would make the element order within each
        // output list nondeterministic; pushing on successful insert keeps
        // deterministic first-occurrence order.
        for i in l_w[0].as_usize()..l_w[1].as_usize() {
            let row = l_values.row(i);
            if dedup.insert(row) {
                rows.push(row);
            }
        }
        for i in r_w[0].as_usize()..r_w[1].as_usize() {
            let row = r_values.row(i);
            if dedup.insert(row) {
                rows.push(row);
            }
        }
        offsets.push(OffsetSize::usize_as(rows.len()));
        // The dedup scope is a single row pair; reset for the next one.
        dedup.clear();
    }

    let values = converter.convert_rows(rows)?;
    let offsets = OffsetBuffer::new(offsets.into());
    Ok(GenericListArray::<OffsetSize>::new(
        field,
        offsets,
        values[0].clone(),
        nulls,
    ))
}

/// Array_union SQL function
pub fn array_union(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_union needs two arguments");
}
let array1 = &args[0];
let array2= &args[1];

check_datatypes("array_union", &[array1, array2])?;
let list1 = as_list_array(array1)?;
let list2 = as_list_array(array2)?;
match (list1.value_type(), list2.value_type()){
(DataType::Null, _) => {
Ok(array2.clone())
},
(_, DataType::Null) => {
Ok(array1.clone())
let array2 = &args[1];
check_datatypes("array_union", &[&array1, &array2])?;
match (array1.data_type(), array2.data_type()) {
(DataType::Null, _) => Ok(array2.clone()),
(_, DataType::Null) => Ok(array1.clone()),
(DataType::List(_), DataType::List(_)) => {
let list1 = array1.as_list::<i32>();
let list2 = array2.as_list::<i32>();
let result = union_generic_lists::<i32>(list1, list2)?;
Ok(result.values().clone())
}
(DataType::LargeList(_), DataType::LargeList(_)) => {
let list1 = array1.as_list::<i64>();
let list2 = array2.as_list::<i64>();
let result = union_generic_lists::<i64>(list1, list2)?;
Ok(result.values().clone())
}
_ => {
return internal_err!(
"array_union only support list with offsets of type int32 and int64"
);
}
(DataType::List(_), DataType::List(_)) => concat_internal(args),
_ => return exec_err!("array_union can only concatenate lists")
}
}


/// Array_to_string SQL function
pub fn array_to_string(args: &[ArrayRef]) -> Result<ArrayRef> {
let arr = &args[0];
Expand Down Expand Up @@ -1611,7 +1668,6 @@ pub fn array_to_string(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(StringArray::from(res)))
}


/// Cardinality SQL function
pub fn cardinality(args: &[ArrayRef]) -> Result<ArrayRef> {
let list_array = as_list_array(&args[0])?.clone();
Expand Down
4 changes: 3 additions & 1 deletion datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,9 @@ pub fn create_physical_fun(
BuiltinScalarFunction::MakeArray => {
Arc::new(|args| make_scalar_function(array_expressions::make_array)(args))
}

BuiltinScalarFunction::ArrayUnion => {
Arc::new(|args| make_scalar_function(array_expressions::array_union)(args))
}
// struct functions
BuiltinScalarFunction::Struct => Arc::new(struct_expressions::struct_expr),

Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1409,7 +1409,7 @@ pub fn parse_expr(
}
ScalarFunction::ArrayNdims => {
Ok(array_ndims(parse_expr(&args[0], registry)?))
},
}
ScalarFunction::ArrayUnion => Ok(array(
args.to_owned()
.iter()
Expand Down

0 comments on commit fafe3db

Please sign in to comment.