Skip to content

Commit

Permalink
Use AccumulatorArgs::is_reversed in NthValueAgg (apache#11669)
Browse files Browse the repository at this point in the history
* Refactor: use `AccumulatorArgs::is_reversed`

* Minor: fixes comment
  • Loading branch information
jcsherin authored Jul 27, 2024
1 parent 01dc3f9 commit 204e1bc
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 16 deletions.
18 changes: 3 additions & 15 deletions datafusion/functions-aggregate/src/nth_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValu
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
Accumulator, AggregateUDF, AggregateUDFImpl, Expr, ReversedUDAF, Signature,
Volatility,
Accumulator, AggregateUDFImpl, Expr, ReversedUDAF, Signature, Volatility,
};
use datafusion_physical_expr_common::aggregate::merge_arrays::merge_ordered_arrays;
use datafusion_physical_expr_common::aggregate::utils::ordering_fields;
Expand All @@ -53,24 +52,15 @@ make_udaf_expr_and_func!(
#[derive(Debug)]
pub struct NthValueAgg {
signature: Signature,
/// Determines whether `N` is relative to the beginning or the end
/// of the aggregation. When set to `true`, then `N` is from the end.
reversed: bool,
}

impl NthValueAgg {
/// Create a new `NthValueAgg` aggregate function
pub fn new() -> Self {
Self {
signature: Signature::any(2, Volatility::Immutable),
reversed: false,
}
}

pub fn with_reversed(mut self, reversed: bool) -> Self {
self.reversed = reversed;
self
}
}

impl Default for NthValueAgg {
Expand Down Expand Up @@ -99,7 +89,7 @@ impl AggregateUDFImpl for NthValueAgg {
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
let n = match acc_args.input_exprs[1] {
Expr::Literal(ScalarValue::Int64(Some(value))) => {
if self.reversed {
if acc_args.is_reversed {
Ok(-value)
} else {
Ok(value)
Expand Down Expand Up @@ -154,9 +144,7 @@ impl AggregateUDFImpl for NthValueAgg {
}

fn reverse_expr(&self) -> ReversedUDAF {
ReversedUDAF::Reversed(Arc::from(AggregateUDF::from(
Self::new().with_reversed(!self.reversed),
)))
ReversedUDAF::Reversed(nth_value_udaf())
}
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-expr-common/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ impl AggregateFunctionExpr {
self.ignore_nulls
}

/// Return if the aggregation is distinct
/// Return if the aggregation is reversed
pub fn is_reversed(&self) -> bool {
self.is_reversed
}
Expand Down

0 comments on commit 204e1bc

Please sign in to comment.