Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support LargeList in array_element #8570

Merged
merged 1 commit into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -591,8 +591,9 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayDistinct => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayElement => match &input_expr_types[0] {
List(field) => Ok(field.data_type().clone()),
LargeList(field) => Ok(field.data_type().clone()),
_ => plan_err!(
"The {self} function can only accept list as the first argument"
"The {self} function can only accept list or largelist as the first argument"
),
},
BuiltinScalarFunction::ArrayLength => Ok(UInt64),
Expand Down
82 changes: 57 additions & 25 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,56 +369,62 @@ pub fn make_array(arrays: &[ArrayRef]) -> Result<ArrayRef> {
}
}

/// array_element SQL function
///
/// There are two arguments for array_element, the first one is the array, the second one is the 1-indexed index.
/// `array_element(array, index)`
///
/// For example:
/// > array_element(\[1, 2, 3], 2) -> 2
pub fn array_element(args: &[ArrayRef]) -> Result<ArrayRef> {
let list_array = as_list_array(&args[0])?;
let indexes = as_int64_array(&args[1])?;

let values = list_array.values();
fn general_array_element<O: OffsetSizeTrait>(
array: &GenericListArray<O>,
indexes: &Int64Array,
) -> Result<ArrayRef>
where
i64: TryInto<O>,
{
let values = array.values();
let original_data = values.to_data();
let capacity = Capacities::Array(original_data.len());

// use_nulls: true, we don't construct List for array_element, so we need explicit nulls.
let mut mutable =
MutableArrayData::with_capacities(vec![&original_data], true, capacity);

fn adjusted_array_index(index: i64, len: usize) -> Option<i64> {
fn adjusted_array_index<O: OffsetSizeTrait>(index: i64, len: O) -> Result<Option<O>>
where
i64: TryInto<O>,
{
let index: O = index.try_into().map_err(|_| {
DataFusionError::Execution(format!(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: macro

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot use marco here because of map_err

"array_element got invalid index: {}",
index
))
})?;
// 0 ~ len - 1
let adjusted_zero_index = if index < 0 {
index + len as i64
let adjusted_zero_index = if index < O::usize_as(0) {
index + len
} else {
index - 1
index - O::usize_as(1)
};

if 0 <= adjusted_zero_index && adjusted_zero_index < len as i64 {
Some(adjusted_zero_index)
if O::usize_as(0) <= adjusted_zero_index && adjusted_zero_index < len {
Ok(Some(adjusted_zero_index))
} else {
// Out of bounds
None
Ok(None)
}
}

for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() {
let start = offset_window[0] as usize;
let end = offset_window[1] as usize;
for (row_index, offset_window) in array.offsets().windows(2).enumerate() {
let start = offset_window[0];
let end = offset_window[1];
let len = end - start;

// array is null
if len == 0 {
if len == O::usize_as(0) {
mutable.extend_nulls(1);
continue;
}

let index = adjusted_array_index(indexes.value(row_index), len);
let index = adjusted_array_index::<O>(indexes.value(row_index), len)?;

if let Some(index) = index {
mutable.extend(0, start + index as usize, start + index as usize + 1);
let start = start.as_usize() + index.as_usize();
mutable.extend(0, start, start + 1_usize);
} else {
// Index out of bounds
mutable.extend_nulls(1);
Expand All @@ -429,6 +435,32 @@ pub fn array_element(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(arrow_array::make_array(data))
}

/// array_element SQL function
///
/// There are two arguments for array_element, the first one is the array, the second one is the 1-indexed index.
/// `array_element(array, index)`
///
/// For example:
/// > array_element(\[1, 2, 3], 2) -> 2
pub fn array_element(args: &[ArrayRef]) -> Result<ArrayRef> {
match &args[0].data_type() {
DataType::List(_) => {
let array = as_list_array(&args[0])?;
let indexes = as_int64_array(&args[1])?;
general_array_element::<i32>(array, indexes)
}
DataType::LargeList(_) => {
let array = as_large_list_array(&args[0])?;
let indexes = as_int64_array(&args[1])?;
general_array_element::<i64>(array, indexes)
}
_ => not_impl_err!(
"array_element does not support type: {:?}",
args[0].data_type()
),
}
}

fn general_except<OffsetSize: OffsetSizeTrait>(
l: &GenericListArray<OffsetSize>,
r: &GenericListArray<OffsetSize>,
Expand Down
72 changes: 71 additions & 1 deletion datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ from arrays_values_without_nulls;
## array_element (aliases: array_extract, list_extract, list_element)

# array_element error
query error DataFusion error: Error during planning: The array_element function can only accept list as the first argument
query error DataFusion error: Error during planning: The array_element function can only accept list or largelist as the first argument
select array_element(1, 2);


Expand All @@ -729,58 +729,106 @@ select array_element(make_array(1, 2, 3, 4, 5), 2), array_element(make_array('h'
----
2 l

query IT
select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3);
----
2 l

# array_element scalar function #2 (with positive index; out of bounds)
query IT
select array_element(make_array(1, 2, 3, 4, 5), 7), array_element(make_array('h', 'e', 'l', 'l', 'o'), 11);
----
NULL NULL

query IT
select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 7), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 11);
----
NULL NULL

# array_element scalar function #3 (with zero)
query IT
select array_element(make_array(1, 2, 3, 4, 5), 0), array_element(make_array('h', 'e', 'l', 'l', 'o'), 0);
----
NULL NULL

query IT
select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0);
----
NULL NULL

# array_element scalar function #4 (with NULL)
query error
select array_element(make_array(1, 2, 3, 4, 5), NULL), array_element(make_array('h', 'e', 'l', 'l', 'o'), NULL);

query error
select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL);

# array_element scalar function #5 (with negative index)
query IT
select array_element(make_array(1, 2, 3, 4, 5), -2), array_element(make_array('h', 'e', 'l', 'l', 'o'), -3);
----
4 l

query IT
select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -2), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3);
----
4 l

# array_element scalar function #6 (with negative index; out of bounds)
query IT
select array_element(make_array(1, 2, 3, 4, 5), -11), array_element(make_array('h', 'e', 'l', 'l', 'o'), -7);
----
NULL NULL

query IT
select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -11), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -7);
----
NULL NULL

# array_element scalar function #7 (nested array)
query ?
select array_element(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 1);
----
[1, 2, 3, 4, 5]

query ?
select array_element(arrow_cast(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 'LargeList(List(Int64))'), 1);
----
[1, 2, 3, 4, 5]

# array_extract scalar function #8 (function alias `array_slice`)
query IT
select array_extract(make_array(1, 2, 3, 4, 5), 2), array_extract(make_array('h', 'e', 'l', 'l', 'o'), 3);
----
2 l

query IT
select array_extract(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3);
----
2 l

# list_element scalar function #9 (function alias `array_slice`)
query IT
select list_element(make_array(1, 2, 3, 4, 5), 2), list_element(make_array('h', 'e', 'l', 'l', 'o'), 3);
----
2 l

query IT
select list_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3);
----
2 l

# list_extract scalar function #10 (function alias `array_slice`)
query IT
select list_extract(make_array(1, 2, 3, 4, 5), 2), list_extract(make_array('h', 'e', 'l', 'l', 'o'), 3);
----
2 l

query IT
select list_extract(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3);
----
2 l

# array_element with columns
query I
select array_element(column1, column2) from slices;
Expand All @@ -793,6 +841,17 @@ NULL
NULL
55

query I
select array_element(arrow_cast(column1, 'LargeList(Int64)'), column2) from slices;
----
NULL
12
NULL
37
NULL
NULL
55

# array_element with columns and scalars
query II
select array_element(make_array(1, 2, 3, 4, 5), column2), array_element(column1, 3) from slices;
Expand All @@ -805,6 +864,17 @@ NULL 23
NULL 43
5 NULL

query II
select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), column2), array_element(arrow_cast(column1, 'LargeList(Int64)'), 3) from slices;
----
1 3
2 13
NULL 23
2 33
4 NULL
NULL 43
5 NULL

## array_pop_back (aliases: `list_pop_back`)

# array_pop_back scalar function #1
Expand Down