Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce reverse_expr for UDAF #10214

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

95 changes: 91 additions & 4 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,18 @@
//! This module contains end to end demonstrations of creating
//! user defined aggregate functions

use std::fmt::Debug;

use arrow::{array::AsArray, datatypes::Fields};
use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray};
use arrow_schema::Schema;
use datafusion_physical_plan::udaf::create_aggregate_expr;
use sqlparser::ast::NullTreatment;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};

use datafusion::datasource::MemTable;
use datafusion::test_util::plan_and_collect;
use datafusion::{
arrow::{
array::{ArrayRef, Float64Array, TimestampNanosecondArray},
Expand All @@ -43,10 +45,11 @@ use datafusion::{
prelude::SessionContext,
scalar::ScalarValue,
};
use datafusion::{datasource::MemTable, test_util::plan_and_collect};
use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err};
use datafusion_expr::{
create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
SimpleAggregateUDF,
create_udaf, expr::AggregateFunction, function::AccumulatorArgs, AggregateUDFImpl,
GroupsAccumulator, ReversedExpr, SimpleAggregateUDF,
};
use datafusion_physical_expr::expressions::AvgAccumulator;

Expand Down Expand Up @@ -795,3 +798,87 @@ impl GroupsAccumulator for TestGroupsAccumulator {
std::mem::size_of::<u64>()
}
}

/// Minimal UDAF used by the test in this file to exercise `reverse_expr`.
/// It stores only the signature that its `AggregateUDFImpl::signature`
/// implementation hands back.
#[derive(Clone)]
struct TestReverseUDAF {
signature: Signature,
}

impl Debug for TestReverseUDAF {
    /// Formats the UDAF as a struct named `TestReverseUDAF` with two fields:
    /// `name` (the value of `self.name()`) and `signature`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let mut repr = f.debug_struct("TestReverseUDAF");
        repr.field("name", &self.name());
        repr.field("signature", self.signature());
        repr.finish()
    }
}

impl AggregateUDFImpl for TestReverseUDAF {
    /// Exposes `self` as `Any`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Fixed function name of this test UDAF.
    fn name(&self) -> &str {
        "test_reverse"
    }

    /// Returns the signature supplied when the UDAF was constructed.
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Always reports `Float64`, regardless of the argument types.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Float64)
    }

    /// Intentionally left unimplemented; calling it panics via `todo!`.
    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        todo!("no need")
    }

    /// This UDAF declares no state fields.
    fn state_fields(
        &self,
        _name: &str,
        _value_type: DataType,
        _ordering_fields: Vec<Field>,
    ) -> Result<Vec<Field>> {
        Ok(Vec::new())
    }

    /// Reports a reversed expression: the same UDAF with no arguments, not
    /// distinct, no filter, no ordering, and an explicit `RespectNulls`
    /// null treatment.
    fn reverse_expr(&self) -> ReversedExpr {
        let udf = Arc::new(self.clone().into());
        let null_treatment = Some(NullTreatment::RespectNulls);
        let reversed =
            AggregateFunction::new_udf(udf, Vec::new(), false, None, None, null_treatment);
        ReversedExpr::Reversed(reversed)
    }
}

/// Tests that reversing a UDAF-backed aggregate expression picks up the null treatment returned by the UDAF's `reverse_expr`
#[tokio::test]
async fn test_reverse_udaf() -> Result<()> {
let my_reverse = AggregateUDF::from(TestReverseUDAF {
signature: Signature::exact(vec![], Volatility::Immutable),
});

let empty_schema = Schema::empty();
let e = create_aggregate_expr(
&my_reverse,
&[],
&[],
&[],
&empty_schema,
"test_reverse_udaf",
true,
)?;

// TODO: We don't have a nice way to test the change without introducing many other things.
// We check with the output string. `ignore_nulls` is expected to be false.
Copy link
Contributor Author

@jayzhan211 jayzhan211 Apr 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can even remove fn test_reverse_udaf if first/last are moved to UDAF based.

let res = e.reverse_expr();
let res_str = format!("{:?}", res.unwrap());

assert_eq!(&res_str, "AggregateFunctionExpr { fun: AggregateUDF { inner: TestReverseUDAF { name: \"test_reverse\", signature: Signature { type_signature: Exact([]), volatility: Immutable } } }, args: [], data_type: Float64, name: \"test_reverse_udaf\", schema: Schema { fields: [], metadata: {} }, sort_exprs: [], ordering_req: [], ignore_nulls: false, ordering_fields: [] }");

Ok(())
}
2 changes: 1 addition & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ pub use signature::{
TIMEZONE_WILDCARD,
};
pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
pub use udaf::{AggregateUDF, AggregateUDFImpl};
pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedExpr};
pub use udf::{ScalarUDF, ScalarUDFImpl};
pub use udwf::{WindowUDF, WindowUDFImpl};
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
Expand Down
24 changes: 24 additions & 0 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! [`AggregateUDF`]: User Defined Aggregate Functions

use crate::expr::AggregateFunction;
use crate::function::AccumulatorArgs;
use crate::groups_accumulator::GroupsAccumulator;
use crate::utils::format_state_name;
Expand Down Expand Up @@ -195,6 +196,11 @@ impl AggregateUDF {
pub fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
self.inner.create_groups_accumulator()
}

/// Returns the [`ReversedExpr`] reported by the wrapped implementation
/// (this method only delegates to `self.inner`).
///
/// See [`AggregateUDFImpl::reverse_expr`] for more details.
pub fn reverse_expr(&self) -> ReversedExpr {
self.inner.reverse_expr()
}
}

impl<F> From<F> for AggregateUDF
Expand Down Expand Up @@ -354,6 +360,24 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
fn aliases(&self) -> &[String] {
&[]
}

/// Construct an expression that calculates the aggregate in reverse.
/// Typically the "reverse" expression is itself (e.g. SUM, COUNT).
/// For aggregates that do not support calculation in reverse,
/// returns [`ReversedExpr::NotSupported`] (which is the default value).
fn reverse_expr(&self) -> ReversedExpr {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to take the owned type, but

function arguments must have a statically known size, borrowed types always have a known size: `&`

ReversedExpr::NotSupported
}
}

#[derive(Debug)]
pub enum ReversedExpr {
/// The expression is the same as the original expression, like SUM, COUNT
Identical,
/// The expression does not support reverse calculation, like ArrayAgg
NotSupported,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I spent some time trying to understand this -- is the reason that ArrayAgg doesn't support reversed calculations that the array elements would be in a different order?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, because ArrayAgg has no ordering info.

You can also take a look at this comment #9972 (comment)

Although I found that the counter-example might not be correct after a few days, ARRAY_AGG(b ORDER BY c DESC) is OrderSensitiveArrayAgg, so it is possible to produce reverse expr, but it still makes sense to me that order insensitive ArrayAgg (which is ArrayAgg) does not support reverse_expr. cc @mustafasrepo

/// The expression is different from the original expression
Reversed(AggregateFunction),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure about returning an AggregateFunction here -- this method is part of AggregateUDFImpl trait, but AggregateFunction has arguments and other fields that I don't think will be accessible to an instance of AggregateUDFImpl.

#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct AggregateFunction {
/// Name of the function
pub func_def: AggregateFunctionDefinition,
/// List of expressions to feed to the functions as arguments
pub args: Vec<Expr>,
/// Whether this is a DISTINCT aggregation or not
pub distinct: bool,
/// Optional filter
pub filter: Option<Box<Expr>>,
/// Optional ordering
pub order_by: Option<Vec<Expr>>,
pub null_treatment: Option<NullTreatment>,
}

Would it make more sense to return an AggregateUDF here:

/// Logical representation of a user-defined [aggregate function] (UDAF).
///
/// An aggregate function combines the values from multiple input rows
/// into a single output "aggregate" (summary) row. It is different
/// from a scalar function because it is stateful across batches. User
/// defined aggregate functions can be used as normal SQL aggregate
/// functions (`GROUP BY` clause) as well as window functions (`OVER`
/// clause).
///
/// `AggregateUDF` provides DataFusion the information needed to plan and call
/// aggregate functions, including name, type information, and a factory
/// function to create an [`Accumulator`] instance, to perform the actual
/// aggregation.
///
/// For more information, please see [the examples]:
///
/// 1. For simple use cases, use [`create_udaf`] (examples in [`simple_udaf.rs`]).
///
/// 2. For advanced use cases, use [`AggregateUDFImpl`] which provides full API
/// access (examples in [`advanced_udaf.rs`]).
///
/// # API Note
/// This is a separate struct from `AggregateUDFImpl` to maintain backwards
/// compatibility with the older API.
///
/// [the examples]: https://github.com/apache/datafusion/tree/main/datafusion-examples#single-process
/// [aggregate function]: https://en.wikipedia.org/wiki/Aggregate_function
/// [`Accumulator`]: crate::Accumulator
/// [`create_udaf`]: crate::expr_fn::create_udaf
/// [`simple_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs
/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs
#[derive(Debug, Clone)]
pub struct AggregateUDF {
inner: Arc<dyn AggregateUDFImpl>,
}
here instead?

Maybe it would help guide the API design to implement this API for first_value/last_value udafs 🤔

Copy link
Contributor Author

@jayzhan211 jayzhan211 Apr 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can see the reverse expr in OrderSensitiveArrayAgg

fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> {
        Some(Arc::new(Self {
            name: self.name.to_string(),
            input_data_type: self.input_data_type.clone(),
            expr: self.expr.clone(),
            nullable: self.nullable,
            order_by_data_types: self.order_by_data_types.clone(),
            // Reverse requirement:
            ordering_req: reverse_order_bys(&self.ordering_req),
            reverse: !self.reverse,
        }))
    }

We need the ordering info, it is not contained in either AggregateUDF or AggregateUDFImpl function. Therefore, I return AggregateFunction instead. Then, we can create AggregateExec with create_physical_expr to get the physical-expr from these logical exprs.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But I wonder why does the AggregateUDFImpl need to actually reverse the ordering? Couldn't the code that calls reverse_expr() do the actual call to reverse_order_by?

As in maybe AggregateFunctionExpr::reverse_expr could call reverse_order_by and AggregateUdfImpl::reverse_expr would only return an AggregateUDF

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me take a look at the code that uses reverse_expr, I guess some kind of rewrite may help.

}

/// AggregateUDF that adds an alias to the underlying function. It is better to
Expand Down
1 change: 1 addition & 0 deletions datafusion/physical-expr-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@ path = "src/lib.rs"
arrow = { workspace = true }
datafusion-common = { workspace = true, default-features = true }
datafusion-expr = { workspace = true }
sqlparser = { workspace = true }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think ideally the physical expr shouldn't be depending on sql parser (for people who are not using SQL)

Though now I see that perhaps it is because AggregateFunction has a null treatment flag on it 🤔 That might be a nice dependency to avoid

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is easy to avoid this dependency, we can check this token in parser, then pass boolean all the way down.

Copy link
Contributor Author

@jayzhan211 jayzhan211 Apr 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let me file an issue for this

30 changes: 29 additions & 1 deletion datafusion/physical-expr-common/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@ pub mod utils;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{not_impl_err, Result};
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::type_coercion::aggregates::check_arg_count;
use datafusion_expr::ReversedExpr;
use datafusion_expr::{
function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator,
};
use sqlparser::ast::NullTreatment;
use std::fmt::Debug;
use std::{any::Any, sync::Arc};

Expand Down Expand Up @@ -147,7 +150,7 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq<dyn Any> {
}

/// Physical aggregate expression of a UDAF.
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct AggregateFunctionExpr {
fun: AggregateUDF,
args: Vec<Arc<dyn PhysicalExpr>>,
Expand Down Expand Up @@ -273,6 +276,31 @@ impl AggregateExpr for AggregateFunctionExpr {
fn order_bys(&self) -> Option<&[PhysicalSortExpr]> {
(!self.ordering_req.is_empty()).then_some(&self.ordering_req)
}

/// Builds the reversed form of this aggregate expression from what the
/// underlying UDAF reports:
///
/// * `NotSupported` yields `None`.
/// * `Identical` yields a clone of `self`.
/// * `Reversed(..)` yields a clone of `self` whose `ignore_nulls` flag is
///   set only when the returned null treatment is `IgnoreNulls`.
fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> {
    let reversed = match self.fun.reverse_expr() {
        ReversedExpr::NotSupported => return None,
        ReversedExpr::Identical => self.clone(),
        ReversedExpr::Reversed(AggregateFunction { null_treatment, .. }) => {
            // TODO: Do the actual conversion from logical expr
            // for other fields
            let mut flipped = self.clone();
            // A missing null treatment behaves like `RespectNulls`.
            flipped.ignore_nulls =
                matches!(null_treatment, Some(NullTreatment::IgnoreNulls));
            flipped
        }
    };

    Some(Arc::new(reversed))
}
}

impl PartialEq<dyn Any> for AggregateFunctionExpr {
Expand Down