Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Array agg groups accumulator, second attempt #11096

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions datafusion-examples/examples/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,10 +289,12 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
opt_filter,
total_num_groups,
|group_index, new_value| {
let prod = &mut self.prods[group_index];
*prod = prod.mul_wrapping(new_value);
if let Some(new_value) = new_value {
let prod = &mut self.prods[group_index];
*prod = prod.mul_wrapping(new_value);

self.counts[group_index] += 1;
self.counts[group_index] += 1;
}
},
);

Expand All @@ -319,7 +321,9 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
opt_filter,
total_num_groups,
|group_index, partial_count| {
self.counts[group_index] += partial_count;
if let Some(partial_count) = partial_count {
self.counts[group_index] += partial_count;
}
},
);

Expand All @@ -330,9 +334,12 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
partial_prods,
opt_filter,
total_num_groups,
|group_index, new_value: <Float64Type as ArrowPrimitiveType>::Native| {
let prod = &mut self.prods[group_index];
*prod = prod.mul_wrapping(new_value);
|group_index,
new_value: Option<<Float64Type as ArrowPrimitiveType>::Native>| {
if let Some(new_value) = new_value {
let prod = &mut self.prods[group_index];
*prod = prod.mul_wrapping(new_value);
}
},
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
//!
//! [`GroupsAccumulator`]: datafusion_expr::GroupsAccumulator

use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray};
use arrow::array::{
Array, ArrayRef, BooleanArray, BooleanBufferBuilder, Int64BufferBuilder, ListArray,
PrimitiveArray, StringArray,
};
use arrow::buffer::{BooleanBuffer, NullBuffer};
use arrow::datatypes::ArrowPrimitiveType;

Expand Down Expand Up @@ -59,6 +62,8 @@ pub struct NullState {
/// If `seen_values[i]` is false, have not seen any values that
/// pass the filter yet for group `i`
seen_values: BooleanBufferBuilder,

seen_nulls: Int64BufferBuilder,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maintaining NullState is typically on the inner loop of aggregate performance so I am worried about the impact to performance here

I will run some benchmark numbers to gather some data

}

impl Default for NullState {
Expand All @@ -71,13 +76,14 @@ impl NullState {
pub fn new() -> Self {
Self {
seen_values: BooleanBufferBuilder::new(0),
seen_nulls: Int64BufferBuilder::new(0),
}
}

/// return the size of all buffers allocated by this null state, not including self
pub fn size(&self) -> usize {
// capacity is in bits, so convert to bytes
self.seen_values.capacity() / 8
self.seen_values.capacity() / 8 + self.seen_nulls.capacity() / 8
}

/// Invokes `value_fn(group_index, value)` for each non null, non
Expand Down Expand Up @@ -132,7 +138,7 @@ impl NullState {
mut value_fn: F,
) where
T: ArrowPrimitiveType + Send,
F: FnMut(usize, T::Native) + Send,
F: FnMut(usize, Option<T::Native>) + Send,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the docs should be updated as well to reflect this change

{
let data: &[T::Native] = values.values();
assert_eq!(data.len(), group_indices.len());
Expand All @@ -141,14 +147,13 @@ impl NullState {
// "not seen" valid)
let seen_values =
initialize_builder(&mut self.seen_values, total_num_groups, false);

match (values.null_count() > 0, opt_filter) {
// no nulls, no filter,
(false, None) => {
let iter = group_indices.iter().zip(data.iter());
for (&group_index, &new_value) in iter {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
value_fn(group_index, Some(new_value));
}
}
// nulls, no filter
Expand All @@ -175,7 +180,9 @@ impl NullState {
let is_valid = (mask & index_mask) != 0;
if is_valid {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
value_fn(group_index, Some(new_value));
} else {
value_fn(group_index, None);
}
index_mask <<= 1;
},
Expand All @@ -192,7 +199,9 @@ impl NullState {
let is_valid = remainder_bits & (1 << i) != 0;
if is_valid {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
value_fn(group_index, Some(new_value));
} else {
value_fn(group_index, None);
}
});
}
Expand All @@ -209,7 +218,7 @@ impl NullState {
.for_each(|((&group_index, &new_value), filter_value)| {
if let Some(true) = filter_value {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
value_fn(group_index, Some(new_value));
}
})
}
Expand All @@ -227,7 +236,9 @@ impl NullState {
if let Some(true) = filter_value {
if let Some(new_value) = new_value {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value)
value_fn(group_index, Some(new_value))
} else {
value_fn(group_index, None);
}
}
})
Expand Down Expand Up @@ -324,6 +335,176 @@ impl NullState {
}
}

/// Invokes `value_fn(group_index, value)` for each non null, non
/// filtered value in `values`, while tracking which groups have
/// seen null inputs and which groups have seen any inputs, for
/// [`ListArray`]s.
///
/// See [`Self::accumulate`], which handles `PrimitiveArray`s, for
/// more details on other arguments.
pub fn accumulate_array<F, N>(
&mut self,
group_indices: &[usize],
values: &ListArray,
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
mut value_fn: F,
mut null_fn: N,
) where
F: FnMut(usize, ArrayRef) + Send,
N: FnMut(usize) + Send,
{
assert_eq!(values.len(), group_indices.len());

// ensure the seen_values is big enough (start everything at
// "not seen" valid)
let seen_values =
initialize_builder(&mut self.seen_values, total_num_groups, false);

match (values.null_count() > 0, opt_filter) {
// no nulls, no filter,
(false, None) => {
let iter = group_indices.iter().zip(values.iter());
for (&group_index, new_value) in iter {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value.unwrap());
}
}
// nulls, no filter
(true, None) => {
let nulls = values.nulls().unwrap();
group_indices
.iter()
.zip(values.iter())
.zip(nulls.iter())
.for_each(|((&group_index, new_value), is_valid)| {
if is_valid {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value.unwrap());
} else {
null_fn(group_index);
}
})
}
// no nulls, but a filter
(false, Some(filter)) => {
assert_eq!(filter.len(), group_indices.len());
group_indices
.iter()
.zip(values.iter())
.zip(filter.iter())
.for_each(|((&group_index, new_value), filter_value)| {
if let Some(true) = filter_value {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value.unwrap());
}
});
}
// both null values and filters
(true, Some(filter)) => {
assert_eq!(filter.len(), group_indices.len());
filter
.iter()
.zip(group_indices.iter())
.zip(values.iter())
.for_each(|((filter_value, &group_index), new_value)| {
if let Some(true) = filter_value {
if let Some(new_value) = new_value {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
} else {
null_fn(group_index);
}
}
});
}
}
}

/// Invokes `value_fn(group_index, value)` for each non-null,
/// non-filtered value in `values`, while tracking which groups have
/// seen null inputs and which groups have seen any inputs, for
/// [`ListArray`]s.
///
/// See [`Self::accumulate`], which handles `PrimitiveArray`s, for
/// more details on other arguments.
pub fn accumulate_string<F>(
&mut self,
group_indices: &[usize],
values: &StringArray,
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
mut value_fn: F,
) where
F: FnMut(usize, Option<&str>) + Send,
{
assert_eq!(values.len(), group_indices.len());

// ensure the seen_values is big enough (start everything at
// "not seen" valid)
let seen_values =
initialize_builder(&mut self.seen_values, total_num_groups, false);

match (values.null_count() > 0, opt_filter) {
// no nulls, no filter,
(false, None) => {
let iter = group_indices.iter().zip(values.iter());
for (&group_index, new_value) in iter {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
}
}
// nulls, no filter
(true, None) => {
let nulls = values.nulls().unwrap();
group_indices
.iter()
.zip(values.iter())
.zip(nulls.iter())
.for_each(|((&group_index, new_value), is_valid)| {
if is_valid {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
} else {
value_fn(group_index, None);
}
})
}
// no nulls, but a filter
(false, Some(filter)) => {
assert_eq!(filter.len(), group_indices.len());
group_indices
.iter()
.zip(values.iter())
.zip(filter.iter())
.for_each(|((&group_index, new_value), filter_value)| {
if let Some(true) = filter_value {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
}
});
}
// both null values and filters
(true, Some(filter)) => {
assert_eq!(filter.len(), group_indices.len());
filter
.iter()
.zip(group_indices.iter())
.zip(values.iter())
.for_each(|((filter_value, &group_index), new_value)| {
if let Some(true) = filter_value {
if let Some(new_value) = new_value {
seen_values.set_bit(group_index, true);
value_fn(group_index, Some(new_value));
} else {
value_fn(group_index, None);
}
}
});
}
}
}

/// Creates the a [`NullBuffer`] representing which group_indices
/// should have null values (because they never saw any values)
/// for the `emit_to` rows.
Expand Down Expand Up @@ -670,7 +851,9 @@ mod test {
opt_filter,
total_num_groups,
|group_index, value| {
accumulated_values.push((group_index, value));
if let Some(value) = value {
accumulated_values.push((group_index, value));
}
},
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,10 @@ where
opt_filter,
total_num_groups,
|group_index, new_value| {
let value = &mut self.values[group_index];
(self.prim_fn)(value, new_value);
if let Some(new_value) = new_value {
let value = &mut self.values[group_index];
(self.prim_fn)(value, new_value);
}
},
);

Expand Down
Loading