-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support GroupsAccumulator accumulator for udaf #8892
Conversation
cc @alamb |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you so much @guojidan -- this is a very nice PR as it follows the pattern of the existing APIs and is really nicely written/ Bravo 👏
Prior to merge, I think this PR needs:
- Hook up the new trait to
AggregateExpr
(I left comments about what I think is needed inline) as I don't think it quite works yet - A test that shows the GroupsAccumulator being called (see below)
I had a few comments about how we could potentially improve the example with comments and simplification, but we could do that as a follow on PR
Testing
For testing, I suggest a test that all the plumbing is hooked up correctly (which I don't think it is yet in this PR)
For example,
- Add a new
AggregateUDF
here: https://github.com/apache/arrow-datafusion/blob/main/datafusion/core/tests/user_defined/user_defined_aggregates.rs - The new
AggregateUDF
returns a basicGroupsAccumulator
(can simply return a constant or something -- it doesn't have to have any logic) - The new
AggregateUDF
should error / panic if the normalAccumulator
path is invoked - Run a query that shows the
Accumulator
is invoked not theAccumulator
/// If the aggregate expression has a specialized | ||
/// [`GroupsAccumulator`] implementation. If this returns true, | ||
/// `[Self::create_groups_accumulator`] will be called. | ||
fn groups_accumulator_supported(&self) -> bool { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I may have missed it, but I don't see this function being called anywhere
I think we need to implement AggregateExpr::groups_accumulator_supported
and groups_accumulator_supported
for the implementation of AggregateExpr
(different trait) for AggregateUDF
, here :
@@ -82,6 +88,16 @@ impl AggregateUDFImpl for GeoMeanUdf { | |||
fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> { | |||
Ok(vec![DataType::Float64, DataType::UInt32]) | |||
} | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I recommend we add a note to accumulator()
above about when this is used. Now that I write this maybe we should also put some of this information on the docstrings for AggregateUDF::groups_accumulator
- /// This is the accumulator factory; DataFusion uses it to create new accumulators.
+ /// This is the accumulator factory for row wise accumulation; Even when `GroupsAccumulator`
+ /// is supported, DataFusion will use this row oriented
+ /// accumulator when the aggregate function is used as a window function
+ /// or when there are only aggregates (no GROUP BY columns) in the plan.
@@ -82,6 +88,16 @@ impl AggregateUDFImpl for GeoMeanUdf { | |||
fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> { | |||
Ok(vec![DataType::Float64, DataType::UInt32]) | |||
} | |||
|
|||
fn groups_accumulator_supported(&self) -> bool { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be good to add some context annotating this function for the example:
fn groups_accumulator_supported(&self) -> bool { | |
/// Tell DataFusion that this aggregate supports the more performant `GroupsAccumulator` | |
/// which is used for cases when there are grouping columns in the query | |
fn groups_accumulator_supported(&self) -> bool { |
@@ -194,12 +210,196 @@ fn create_context() -> Result<SessionContext> { | |||
Ok(ctx) | |||
} | |||
|
|||
struct GeometricMeanGroupsAccumulator<F> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
struct GeometricMeanGroupsAccumulator<F> | |
/// Define a `GroupsAccumulator` for GeometricMean | |
/// which handles accumulator state for multiple groups at once. | |
/// This API is significantly more complicated than `Accumulator`, which manages | |
/// the state for a single group, but for queries with a large number of groups | |
/// can be significantly faster. See the `GroupsAccumulator` documentation for | |
/// more information. | |
struct GeometricMeanGroupsAccumulator<F> |
/// Count per group (use u64 to make UInt64Array) | ||
counts: Vec<u64>, | ||
|
||
/// product per group, stored as the native type |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/// product per group, stored as the native type | |
/// product per group, stored as the native type (not `ScalarValue`) |
) -> Result<<Float64Type as ArrowPrimitiveType>::Native> | ||
+ Send, | ||
{ | ||
fn update_batch( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fn update_batch( | |
/// Updates the accumulator state given input. DataFusion provides `group_indices`, the groups that each | |
/// row in `values` belongs to as well as an optional filter of which rows passed. | |
fn update_batch( |
self.counts.resize(total_num_groups, 0); | ||
self.prods | ||
.resize(total_num_groups, Float64Type::default_value()); | ||
self.null_state.accumulate( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.null_state.accumulate( | |
/// Use the `NullState` structure to generate specialized code for null / non null input elements | |
self.null_state.accumulate( |
Ok(()) | ||
} | ||
|
||
fn merge_batch( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fn merge_batch( | |
/// Merge the results from previous invocations of `evaluate` into this accumulator's state | |
fn merge_batch( |
Ok(()) | ||
} | ||
|
||
fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> { | |
/// Generate output, as specififed by `emit_to` and update the intermediate state | |
fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> { |
null_state: NullState, | ||
|
||
/// Function that computes the final geometric mean (value / count) | ||
geo_mean_fn: F, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the example would be simpler if you removed the generics and simply inlined the definition of geo_mean_fn
into the callsite in evaluate
. The generics are needed for GroupsAccumulators that are specialized on type (e.g. a special one for Float32, Float64, etc).
@guojidan -- thanks again for your work on this PR. I am quite interested in getting this PR merged, so if you don't think you'll have a chance to work on it this week I would be happy to try and help as well |
sorry, I have been very busy the other day, but I am come back 😄 |
No worries -- thank you for working on this |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you very much @guojidan -- this looks great
|
||
fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> { | ||
// should use groups accumulator | ||
panic!("accumulator shouldn't invoke"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
Which issue does this PR close?
Closes #8793 .
What changes are included in this PR?
Support GroupsAccumulator accumulator for udaf