Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support LargeList in array_remove #8595

Merged
merged 1 commit into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 96 additions & 18 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,14 @@ fn compare_element_to_list(
row_index: usize,
eq: bool,
) -> Result<BooleanArray> {
if list_array_row.data_type() != element_array.data_type() {
return exec_err!(
"compare_element_to_list received incompatible types: '{:?}' and '{:?}'.",
list_array_row.data_type(),
element_array.data_type()
);
}

let indices = UInt32Array::from(vec![row_index as u32]);
let element_array_row = arrow::compute::take(element_array, &indices, None)?;

Expand All @@ -126,6 +134,26 @@ fn compare_element_to_list(
})
.collect::<BooleanArray>()
}
DataType::LargeList(_) => {
// compare each element of the from array
let element_array_row_inner =
as_large_list_array(&element_array_row)?.value(0);
let list_array_row_inner = as_large_list_array(list_array_row)?;

list_array_row_inner
.iter()
// compare element by element the current row of list_array
.map(|row| {
row.map(|row| {
if eq {
row.eq(&element_array_row_inner)
} else {
row.ne(&element_array_row_inner)
}
})
})
.collect::<BooleanArray>()
}
_ => {
let element_arr = Scalar::new(element_array_row);
// use not_distinct so we can compare NULL
Expand Down Expand Up @@ -1511,14 +1539,14 @@ pub fn array_remove_n(args: &[ArrayRef]) -> Result<ArrayRef> {
/// [4, 5, 6, 5], 5, 20, 2 ==> [4, 20, 6, 20] (both 5s are replaced)
/// )
/// ```
fn general_replace(
list_array: &ListArray,
fn general_replace<O: OffsetSizeTrait>(
list_array: &GenericListArray<O>,
from_array: &ArrayRef,
to_array: &ArrayRef,
arr_n: Vec<i64>,
) -> Result<ArrayRef> {
// Build up the offsets for the final output array
let mut offsets: Vec<i32> = vec![0];
let mut offsets: Vec<O> = vec![O::usize_as(0)];
let values = list_array.values();
let original_data = values.to_data();
let to_data = to_array.to_data();
Expand All @@ -1540,8 +1568,8 @@ fn general_replace(
continue;
}

let start = offset_window[0] as usize;
let end = offset_window[1] as usize;
let start = offset_window[0];
let end = offset_window[1];

let list_array_row = list_array.value(row_index);

Expand All @@ -1550,43 +1578,56 @@ fn general_replace(
let eq_array =
compare_element_to_list(&list_array_row, &from_array, row_index, true)?;

let original_idx = 0;
let replace_idx = 1;
let original_idx = O::usize_as(0);
let replace_idx = O::usize_as(1);
let n = arr_n[row_index];
let mut counter = 0;

// All elements are false, no need to replace, just copy original data
if eq_array.false_count() == eq_array.len() {
mutable.extend(original_idx, start, end);
offsets.push(offsets[row_index] + (end - start) as i32);
mutable.extend(
original_idx.to_usize().unwrap(),
start.to_usize().unwrap(),
end.to_usize().unwrap(),
);
offsets.push(offsets[row_index] + (end - start));
valid.append(true);
continue;
}

for (i, to_replace) in eq_array.iter().enumerate() {
let i = O::usize_as(i);
if let Some(true) = to_replace {
mutable.extend(replace_idx, row_index, row_index + 1);
mutable.extend(replace_idx.to_usize().unwrap(), row_index, row_index + 1);
counter += 1;
if counter == n {
// copy original data for any matches past n
mutable.extend(original_idx, start + i + 1, end);
mutable.extend(
original_idx.to_usize().unwrap(),
(start + i).to_usize().unwrap() + 1,
end.to_usize().unwrap(),
);
break;
}
} else {
// copy original data for false / null matches
mutable.extend(original_idx, start + i, start + i + 1);
mutable.extend(
original_idx.to_usize().unwrap(),
(start + i).to_usize().unwrap(),
(start + i).to_usize().unwrap() + 1,
);
}
}

offsets.push(offsets[row_index] + (end - start) as i32);
offsets.push(offsets[row_index] + (end - start));
valid.append(true);
}

let data = mutable.freeze();

Ok(Arc::new(ListArray::try_new(
Ok(Arc::new(GenericListArray::<O>::try_new(
Arc::new(Field::new("item", list_array.value_type(), true)),
OffsetBuffer::new(offsets.into()),
OffsetBuffer::<O>::new(offsets.into()),
arrow_array::make_array(data),
Some(NullBuffer::new(valid.finish())),
)?))
Expand All @@ -1595,19 +1636,56 @@ fn general_replace(
pub fn array_replace(args: &[ArrayRef]) -> Result<ArrayRef> {
// replace at most one occurrence for each element
let arr_n = vec![1; args[0].len()];
general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n)
let array = &args[0];
match array.data_type() {
DataType::List(_) => {
let list_array = array.as_list::<i32>();
general_replace::<i32>(list_array, &args[1], &args[2], arr_n)
}
DataType::LargeList(_) => {
let list_array = array.as_list::<i64>();
general_replace::<i64>(list_array, &args[1], &args[2], arr_n)
}
array_type => exec_err!("array_replace does not support type '{array_type:?}'."),
}
}

pub fn array_replace_n(args: &[ArrayRef]) -> Result<ArrayRef> {
// replace the specified number of occurrences
let arr_n = as_int64_array(&args[3])?.values().to_vec();
general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n)
let array = &args[0];
match array.data_type() {
DataType::List(_) => {
let list_array = array.as_list::<i32>();
general_replace::<i32>(list_array, &args[1], &args[2], arr_n)
}
DataType::LargeList(_) => {
let list_array = array.as_list::<i64>();
general_replace::<i64>(list_array, &args[1], &args[2], arr_n)
}
array_type => {
exec_err!("array_replace_n does not support type '{array_type:?}'.")
}
}
}

pub fn array_replace_all(args: &[ArrayRef]) -> Result<ArrayRef> {
// replace all occurrences (up to "i64::MAX")
let arr_n = vec![i64::MAX; args[0].len()];
general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n)
let array = &args[0];
match array.data_type() {
DataType::List(_) => {
let list_array = array.as_list::<i32>();
general_replace::<i32>(list_array, &args[1], &args[2], arr_n)
}
DataType::LargeList(_) => {
let list_array = array.as_list::<i64>();
general_replace::<i64>(list_array, &args[1], &args[2], arr_n)
}
array_type => {
exec_err!("array_replace_all does not support type '{array_type:?}'.")
}
}
}

macro_rules! to_string {
Expand Down
Loading