Skip to content

Commit

Permalink
Support array_distinct function. (#8268)
Browse files Browse the repository at this point in the history
* implement distinct func

implement slt & proto

fix null & empty list

* add comment for slt

Co-authored-by: Alex Huang <[email protected]>

* fix largelist

* add largelist for slt

* Use collect for rows & init capcity for offsets.

* fixup: remove useless match

* fix fmt

* fix fmt

---------

Co-authored-by: Alex Huang <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
3 people authored Dec 8, 2023
1 parent 047fb33 commit cd02c40
Show file tree
Hide file tree
Showing 11 changed files with 198 additions and 11 deletions.
6 changes: 6 additions & 0 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ pub enum BuiltinScalarFunction {
ArrayPopBack,
/// array_dims
ArrayDims,
/// array_distinct
ArrayDistinct,
/// array_element
ArrayElement,
/// array_empty
Expand Down Expand Up @@ -407,6 +409,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayHasAny => Volatility::Immutable,
BuiltinScalarFunction::ArrayHas => Volatility::Immutable,
BuiltinScalarFunction::ArrayDims => Volatility::Immutable,
BuiltinScalarFunction::ArrayDistinct => Volatility::Immutable,
BuiltinScalarFunction::ArrayElement => Volatility::Immutable,
BuiltinScalarFunction::ArrayExcept => Volatility::Immutable,
BuiltinScalarFunction::ArrayLength => Volatility::Immutable,
Expand Down Expand Up @@ -586,6 +589,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayDims => {
Ok(List(Arc::new(Field::new("item", UInt64, true))))
}
BuiltinScalarFunction::ArrayDistinct => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayElement => match &input_expr_types[0] {
List(field) => Ok(field.data_type().clone()),
_ => plan_err!(
Expand Down Expand Up @@ -933,6 +937,7 @@ impl BuiltinScalarFunction {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayNdims => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayDistinct => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayPosition => {
Signature::variadic_any(self.volatility())
}
Expand Down Expand Up @@ -1570,6 +1575,7 @@ impl BuiltinScalarFunction {
&["array_concat", "array_cat", "list_concat", "list_cat"]
}
BuiltinScalarFunction::ArrayDims => &["array_dims", "list_dims"],
BuiltinScalarFunction::ArrayDistinct => &["array_distinct", "list_distinct"],
BuiltinScalarFunction::ArrayEmpty => &["empty"],
BuiltinScalarFunction::ArrayElement => &[
"array_element",
Expand Down
6 changes: 6 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,12 @@ scalar_expr!(
array,
"returns the number of dimensions of the array."
);
scalar_expr!(
ArrayDistinct,
array_distinct,
array,
"return distinct values from the array after removing duplicates."
);
scalar_expr!(
ArrayPosition,
array_position,
Expand Down
64 changes: 62 additions & 2 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ use arrow_buffer::NullBuffer;

use arrow_schema::{FieldRef, SortOptions};
use datafusion_common::cast::{
as_generic_list_array, as_generic_string_array, as_int64_array, as_list_array,
as_null_array, as_string_array,
as_generic_list_array, as_generic_string_array, as_int64_array, as_large_list_array,
as_list_array, as_null_array, as_string_array,
};
use datafusion_common::utils::{array_into_list_array, list_ndims};
use datafusion_common::{
Expand Down Expand Up @@ -2111,6 +2111,66 @@ pub fn array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}

pub fn general_array_distinct<OffsetSize: OffsetSizeTrait>(
array: &GenericListArray<OffsetSize>,
field: &FieldRef,
) -> Result<ArrayRef> {
let dt = array.value_type();
let mut offsets = Vec::with_capacity(array.len());
offsets.push(OffsetSize::usize_as(0));
let mut new_arrays = Vec::with_capacity(array.len());
let converter = RowConverter::new(vec![SortField::new(dt.clone())])?;
// distinct for each list in ListArray
for arr in array.iter().flatten() {
let values = converter.convert_columns(&[arr])?;
// sort elements in list and remove duplicates
let rows = values.iter().sorted().dedup().collect::<Vec<_>>();
let last_offset: OffsetSize = offsets.last().copied().unwrap();
offsets.push(last_offset + OffsetSize::usize_as(rows.len()));
let arrays = converter.convert_rows(rows)?;
let array = match arrays.get(0) {
Some(array) => array.clone(),
None => {
return internal_err!("array_distinct: failed to get array from rows")
}
};
new_arrays.push(array);
}
let offsets = OffsetBuffer::new(offsets.into());
let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
let values = compute::concat(&new_arrays_ref)?;
Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
field.clone(),
offsets,
values,
None,
)?))
}

/// array_distinct SQL function
/// example: from list [1, 3, 2, 3, 1, 2, 4] to [1, 2, 3, 4]
pub fn array_distinct(args: &[ArrayRef]) -> Result<ArrayRef> {
assert_eq!(args.len(), 1);

// handle null
if args[0].data_type() == &DataType::Null {
return Ok(args[0].clone());
}

// handle for list & largelist
match args[0].data_type() {
DataType::List(field) => {
let array = as_list_array(&args[0])?;
general_array_distinct(array, field)
}
DataType::LargeList(field) => {
let array = as_large_list_array(&args[0])?;
general_array_distinct(array, field)
}
_ => internal_err!("array_distinct only support list array"),
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
3 changes: 3 additions & 0 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,9 @@ pub fn create_physical_fun(
BuiltinScalarFunction::ArrayDims => {
Arc::new(|args| make_scalar_function(array_expressions::array_dims)(args))
}
BuiltinScalarFunction::ArrayDistinct => {
Arc::new(|args| make_scalar_function(array_expressions::array_distinct)(args))
}
BuiltinScalarFunction::ArrayElement => {
Arc::new(|args| make_scalar_function(array_expressions::array_element)(args))
}
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,7 @@ enum ScalarFunction {
SubstrIndex = 126;
FindInSet = 127;
ArraySort = 128;
ArrayDistinct = 129;
}

message ScalarFunctionNode {
Expand Down
3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 13 additions & 9 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ use datafusion_common::{
};
use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by};
use datafusion_expr::{
abs, acos, acosh, array, array_append, array_concat, array_dims, array_element,
array_except, array_has, array_has_all, array_has_any, array_intersect, array_length,
array_ndims, array_position, array_positions, array_prepend, array_remove,
array_remove_all, array_remove_n, array_repeat, array_replace, array_replace_all,
array_replace_n, array_slice, array_sort, array_to_string, arrow_typeof, ascii, asin,
asinh, atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil,
character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot,
current_date, current_time, date_bin, date_part, date_trunc, decode, degrees, digest,
encode, exp,
abs, acos, acosh, array, array_append, array_concat, array_dims, array_distinct,
array_element, array_except, array_has, array_has_all, array_has_any,
array_intersect, array_length, array_ndims, array_position, array_positions,
array_prepend, array_remove, array_remove_all, array_remove_n, array_repeat,
array_replace, array_replace_all, array_replace_n, array_slice, array_sort,
array_to_string, arrow_typeof, ascii, asin, asinh, atan, atan2, atanh, bit_length,
btrim, cardinality, cbrt, ceil, character_length, chr, coalesce, concat_expr,
concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin, date_part,
date_trunc, decode, degrees, digest, encode, exp,
expr::{self, InList, Sort, WindowFunction},
factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero,
lcm, left, levenshtein, ln, log, log10, log2,
Expand Down Expand Up @@ -484,6 +484,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::ArrayHasAny => Self::ArrayHasAny,
ScalarFunction::ArrayHas => Self::ArrayHas,
ScalarFunction::ArrayDims => Self::ArrayDims,
ScalarFunction::ArrayDistinct => Self::ArrayDistinct,
ScalarFunction::ArrayElement => Self::ArrayElement,
ScalarFunction::Flatten => Self::Flatten,
ScalarFunction::ArrayLength => Self::ArrayLength,
Expand Down Expand Up @@ -1467,6 +1468,9 @@ pub fn parse_expr(
ScalarFunction::ArrayDims => {
Ok(array_dims(parse_expr(&args[0], registry)?))
}
ScalarFunction::ArrayDistinct => {
Ok(array_distinct(parse_expr(&args[0], registry)?))
}
ScalarFunction::ArrayElement => Ok(array_element(
parse_expr(&args[0], registry)?,
parse_expr(&args[1], registry)?,
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1512,6 +1512,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::ArrayHasAny => Self::ArrayHasAny,
BuiltinScalarFunction::ArrayHas => Self::ArrayHas,
BuiltinScalarFunction::ArrayDims => Self::ArrayDims,
BuiltinScalarFunction::ArrayDistinct => Self::ArrayDistinct,
BuiltinScalarFunction::ArrayElement => Self::ArrayElement,
BuiltinScalarFunction::Flatten => Self::Flatten,
BuiltinScalarFunction::ArrayLength => Self::ArrayLength,
Expand Down
99 changes: 99 additions & 0 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,38 @@ AS VALUES
(make_array([[1], [2]], [[2], [3]]), make_array([1], [2]))
;

statement ok
CREATE TABLE array_distinct_table_1D
AS VALUES
(make_array(1, 1, 2, 2, 3)),
(make_array(1, 2, 3, 4, 5)),
(make_array(3, 5, 3, 3, 3))
;

statement ok
CREATE TABLE array_distinct_table_1D_UTF8
AS VALUES
(make_array('a', 'a', 'bc', 'bc', 'def')),
(make_array('a', 'bc', 'def', 'defg', 'defg')),
(make_array('defg', 'defg', 'defg', 'defg', 'defg'))
;

statement ok
CREATE TABLE array_distinct_table_2D
AS VALUES
(make_array([1,2], [1,2], [3,4], [3,4], [5,6])),
(make_array([1,2], [3,4], [5,6], [7,8], [9,10])),
(make_array([5,6], [5,6], NULL))
;

statement ok
CREATE TABLE array_distinct_table_1D_large
AS VALUES
(arrow_cast(make_array(1, 1, 2, 2, 3), 'LargeList(Int64)')),
(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')),
(arrow_cast(make_array(3, 5, 3, 3, 3), 'LargeList(Int64)'))
;

statement ok
CREATE TABLE array_intersect_table_1D
AS VALUES
Expand Down Expand Up @@ -2864,6 +2896,73 @@ select array_has_all(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_ca
----
true false true false false false true true false false true false true

query BBBBBBBBBBBBB
select array_has_all(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_cast(make_array(1,3), 'LargeList(Int64)')),
array_has_all(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,4), 'LargeList(Int64)')),
array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2]), 'LargeList(List(Int64))')),
array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,3]), 'LargeList(List(Int64))')),
array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2], [3,4], [5,6]), 'LargeList(List(Int64))')),
array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1]]), 'LargeList(List(List(Int64)))')),
array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))')),
array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,10,100), 'LargeList(Int64)')),
array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(10,100),'LargeList(Int64)')),
array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,10], [10,4]), 'LargeList(List(Int64))')),
array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([10,20], [3,4]), 'LargeList(List(Int64))')),
array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3], [4,5,6]]), 'LargeList(List(List(Int64)))')),
array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]], [[4,5,6]]), 'LargeList(List(List(Int64)))'))
;
----
true false true false false false true true false false true false true

## array_distinct

query ?
select array_distinct(null);
----
NULL

query ?
select array_distinct([]);
----
[]

query ?
select array_distinct([[], []]);
----
[[]]

query ?
select array_distinct(column1)
from array_distinct_table_1D;
----
[1, 2, 3]
[1, 2, 3, 4, 5]
[3, 5]

query ?
select array_distinct(column1)
from array_distinct_table_1D_UTF8;
----
[a, bc, def]
[a, bc, def, defg]
[defg]

query ?
select array_distinct(column1)
from array_distinct_table_2D;
----
[[1, 2], [3, 4], [5, 6]]
[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
[, [5, 6]]

query ?
select array_distinct(column1)
from array_distinct_table_1D_large;
----
[1, 2, 3]
[1, 2, 3, 4, 5]
[3, 5]

query ???
select array_intersect(column1, column2),
array_intersect(column3, column4),
Expand Down
1 change: 1 addition & 0 deletions docs/source/user-guide/expressions.md
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ Unlike to some databases the math functions in Datafusion works the same way as
| array_has_all(array, sub-array) | Returns true if all elements of sub-array exist in array `array_has_all([1,2,3], [1,3]) -> true` |
| array_has_any(array, sub-array) | Returns true if any elements exist in both arrays `array_has_any([1,2,3], [1,4]) -> true` |
| array_dims(array) | Returns an array of the array's dimensions. `array_dims([[1, 2, 3], [4, 5, 6]]) -> [2, 3]` |
| array_distinct(array) | Returns distinct values from the array after removing duplicates. `array_distinct([1, 3, 2, 3, 1, 2, 4]) -> [1, 2, 3, 4]` |
| array_element(array, index) | Extracts the element with the index n from the array `array_element([1, 2, 3, 4], 3) -> 3` |
| flatten(array) | Converts an array of arrays to a flat array `flatten([[1], [2, 3], [4, 5, 6]]) -> [1, 2, 3, 4, 5, 6]` |
| array_length(array, dimension) | Returns the length of the array dimension. `array_length([1, 2, 3, 4, 5]) -> 5` |
Expand Down

0 comments on commit cd02c40

Please sign in to comment.