Skip to content

Commit

Permalink
Replace macro with function for array_repeat (apache#8071)
Browse files Browse the repository at this point in the history
* General array repeat

Signed-off-by: jayzhan211 <[email protected]>

* cleanup

Signed-off-by: jayzhan211 <[email protected]>

* cleanup

Signed-off-by: jayzhan211 <[email protected]>

* cleanup

Signed-off-by: jayzhan211 <[email protected]>

* add test

Signed-off-by: jayzhan211 <[email protected]>

* add test

Signed-off-by: jayzhan211 <[email protected]>

* done

Signed-off-by: jayzhan211 <[email protected]>

* remove test

Signed-off-by: jayzhan211 <[email protected]>

* add comment

Signed-off-by: jayzhan211 <[email protected]>

* fm

Signed-off-by: jayzhan211 <[email protected]>

---------

Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 authored Nov 8, 2023
1 parent 3446382 commit aefee03
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 241 deletions.
312 changes: 126 additions & 186 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -841,125 +841,6 @@ pub fn array_concat(args: &[ArrayRef]) -> Result<ArrayRef> {
concat_internal(new_args.as_slice())
}

macro_rules! general_repeat {
($ELEMENT:expr, $COUNT:expr, $ARRAY_TYPE:ident) => {{
let mut offsets: Vec<i32> = vec![0];
let mut values =
downcast_arg!(new_empty_array($ELEMENT.data_type()), $ARRAY_TYPE).clone();

let element_array = downcast_arg!($ELEMENT, $ARRAY_TYPE);
for (el, c) in element_array.iter().zip($COUNT.iter()) {
let last_offset: i32 = offsets.last().copied().ok_or_else(|| {
DataFusionError::Internal(format!("offsets should not be empty"))
})?;
match el {
Some(el) => {
let c = if c < Some(0) { 0 } else { c.unwrap() } as usize;
let repeated_array =
[Some(el.clone())].repeat(c).iter().collect::<$ARRAY_TYPE>();

values = downcast_arg!(
compute::concat(&[&values, &repeated_array])?.clone(),
$ARRAY_TYPE
)
.clone();
offsets.push(last_offset + repeated_array.len() as i32);
}
None => {
offsets.push(last_offset);
}
}
}

let field = Arc::new(Field::new("item", $ELEMENT.data_type().clone(), true));

Arc::new(ListArray::try_new(
field,
OffsetBuffer::new(offsets.into()),
Arc::new(values),
None,
)?)
}};
}

macro_rules! general_repeat_list {
($ELEMENT:expr, $COUNT:expr, $ARRAY_TYPE:ident) => {{
let mut offsets: Vec<i32> = vec![0];
let mut values =
downcast_arg!(new_empty_array($ELEMENT.data_type()), ListArray).clone();

let element_array = downcast_arg!($ELEMENT, ListArray);
for (el, c) in element_array.iter().zip($COUNT.iter()) {
let last_offset: i32 = offsets.last().copied().ok_or_else(|| {
DataFusionError::Internal(format!("offsets should not be empty"))
})?;
match el {
Some(el) => {
let c = if c < Some(0) { 0 } else { c.unwrap() } as usize;
let repeated_vec = vec![el; c];

let mut i: i32 = 0;
let mut repeated_offsets = vec![i];
repeated_offsets.extend(
repeated_vec
.clone()
.into_iter()
.map(|a| {
i += a.len() as i32;
i
})
.collect::<Vec<_>>(),
);

let mut repeated_values = downcast_arg!(
new_empty_array(&element_array.value_type()),
$ARRAY_TYPE
)
.clone();
for repeated_list in repeated_vec {
repeated_values = downcast_arg!(
compute::concat(&[&repeated_values, &repeated_list])?,
$ARRAY_TYPE
)
.clone();
}

let field = Arc::new(Field::new(
"item",
element_array.value_type().clone(),
true,
));
let repeated_array = ListArray::try_new(
field,
OffsetBuffer::new(repeated_offsets.clone().into()),
Arc::new(repeated_values),
None,
)?;

values = downcast_arg!(
compute::concat(&[&values, &repeated_array,])?.clone(),
ListArray
)
.clone();
offsets.push(last_offset + repeated_array.len() as i32);
}
None => {
offsets.push(last_offset);
}
}
}

let field = Arc::new(Field::new("item", $ELEMENT.data_type().clone(), true));

Arc::new(ListArray::try_new(
field,
OffsetBuffer::new(offsets.into()),
Arc::new(values),
None,
)?)
}};
}

/// Array_empty SQL function
pub fn array_empty(args: &[ArrayRef]) -> Result<ArrayRef> {
if args[0].as_any().downcast_ref::<NullArray>().is_some() {
Expand All @@ -978,28 +859,136 @@ pub fn array_empty(args: &[ArrayRef]) -> Result<ArrayRef> {
/// Array_repeat SQL function
pub fn array_repeat(args: &[ArrayRef]) -> Result<ArrayRef> {
let element = &args[0];
let count = as_int64_array(&args[1])?;
let count_array = as_int64_array(&args[1])?;

let res = match element.data_type() {
DataType::List(field) => {
macro_rules! array_function {
($ARRAY_TYPE:ident) => {
general_repeat_list!(element, count, $ARRAY_TYPE)
};
}
call_array_function!(field.data_type(), true)
match element.data_type() {
DataType::List(_) => {
let list_array = as_list_array(element)?;
general_list_repeat(list_array, count_array)
}
data_type => {
macro_rules! array_function {
($ARRAY_TYPE:ident) => {
general_repeat!(element, count, $ARRAY_TYPE)
};
_ => general_repeat(element, count_array),
}
}

/// For each element of `array[i]` repeat `count_array[i]` times.
///
/// Assumption for the input:
/// 1. `count[i] >= 0`
/// 2. `array.len() == count_array.len()`
///
/// For example,
/// ```text
/// array_repeat(
/// [1, 2, 3], [2, 0, 1] => [[1, 1], [], [3]]
/// )
/// ```
fn general_repeat(array: &ArrayRef, count_array: &Int64Array) -> Result<ArrayRef> {
let data_type = array.data_type();
let mut new_values = vec![];

let count_vec = count_array
.values()
.to_vec()
.iter()
.map(|x| *x as usize)
.collect::<Vec<_>>();

for (row_index, &count) in count_vec.iter().enumerate() {
let repeated_array = if array.is_null(row_index) {
new_null_array(data_type, count)
} else {
let original_data = array.to_data();
let capacity = Capacities::Array(count);
let mut mutable =
MutableArrayData::with_capacities(vec![&original_data], false, capacity);

for _ in 0..count {
mutable.extend(0, row_index, row_index + 1);
}
call_array_function!(data_type, false)
}
};

Ok(res)
let data = mutable.freeze();
arrow_array::make_array(data)
};
new_values.push(repeated_array);
}

let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
let values = arrow::compute::concat(&new_values)?;

Ok(Arc::new(ListArray::try_new(
Arc::new(Field::new("item", data_type.to_owned(), true)),
OffsetBuffer::from_lengths(count_vec),
values,
None,
)?))
}

/// Handle List version of `general_repeat`
///
/// For each element of `list_array[i]` repeat `count_array[i]` times.
///
/// For example,
/// ```text
/// array_repeat(
/// [[1, 2, 3], [4, 5], [6]], [2, 0, 1] => [[[1, 2, 3], [1, 2, 3]], [], [[6]]]
/// )
/// ```
fn general_list_repeat(
list_array: &ListArray,
count_array: &Int64Array,
) -> Result<ArrayRef> {
let data_type = list_array.data_type();
let value_type = list_array.value_type();
let mut new_values = vec![];

let count_vec = count_array
.values()
.to_vec()
.iter()
.map(|x| *x as usize)
.collect::<Vec<_>>();

for (list_array_row, &count) in list_array.iter().zip(count_vec.iter()) {
let list_arr = match list_array_row {
Some(list_array_row) => {
let original_data = list_array_row.to_data();
let capacity = Capacities::Array(original_data.len() * count);
let mut mutable = MutableArrayData::with_capacities(
vec![&original_data],
false,
capacity,
);

for _ in 0..count {
mutable.extend(0, 0, original_data.len());
}

let data = mutable.freeze();
let repeated_array = arrow_array::make_array(data);

let list_arr = ListArray::try_new(
Arc::new(Field::new("item", value_type.clone(), true)),
OffsetBuffer::from_lengths(vec![original_data.len(); count]),
repeated_array,
None,
)?;
Arc::new(list_arr) as ArrayRef
}
None => new_null_array(data_type, count),
};
new_values.push(list_arr);
}

let lengths = new_values.iter().map(|a| a.len()).collect::<Vec<_>>();
let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
let values = arrow::compute::concat(&new_values)?;

Ok(Arc::new(ListArray::try_new(
Arc::new(Field::new("item", data_type.to_owned(), true)),
OffsetBuffer::from_lengths(lengths),
values,
None,
)?))
}

macro_rules! position {
Expand Down Expand Up @@ -2925,55 +2914,6 @@ mod tests {
);
}

#[test]
fn test_array_repeat() {
// array_repeat(3, 5) = [3, 3, 3, 3, 3]
let array = array_repeat(&[
Arc::new(Int64Array::from_value(3, 1)),
Arc::new(Int64Array::from_value(5, 1)),
])
.expect("failed to initialize function array_repeat");
let result =
as_list_array(&array).expect("failed to initialize function array_repeat");

assert_eq!(result.len(), 1);
assert_eq!(
&[3, 3, 3, 3, 3],
result
.value(0)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap()
.values()
);
}

#[test]
fn test_nested_array_repeat() {
// array_repeat([1, 2, 3, 4], 3) = [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]
let element = return_array();
let array = array_repeat(&[element, Arc::new(Int64Array::from_value(3, 1))])
.expect("failed to initialize function array_repeat");
let result =
as_list_array(&array).expect("failed to initialize function array_repeat");

assert_eq!(result.len(), 1);
let data = vec![
Some(vec![Some(1), Some(2), Some(3), Some(4)]),
Some(vec![Some(1), Some(2), Some(3), Some(4)]),
Some(vec![Some(1), Some(2), Some(3), Some(4)]),
];
let expected = ListArray::from_iter_primitive::<Int64Type, _, _>(data);
assert_eq!(
expected,
result
.value(0)
.as_any()
.downcast_ref::<ListArray>()
.unwrap()
.clone()
);
}
#[test]
fn test_array_to_string() {
// array_to_string([1, 2, 3, 4], ',') = 1,2,3,4
Expand Down
Loading

0 comments on commit aefee03

Please sign in to comment.