From a325825c8e83f18263efd899498eaadc70551de4 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Wed, 24 Apr 2024 21:14:23 +0800 Subject: [PATCH 1/2] add reverse enum Signed-off-by: jayzhan211 --- datafusion-cli/Cargo.lock | 1 + .../user_defined/user_defined_aggregates.rs | 97 ++++++++++++++++++- datafusion/expr/src/lib.rs | 2 +- datafusion/expr/src/udaf.rs | 24 +++++ datafusion/physical-expr-common/Cargo.toml | 1 + .../physical-expr-common/src/aggregate/mod.rs | 30 +++++- 6 files changed, 149 insertions(+), 6 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index ba3e68e4011f..09e80a3da98c 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1364,6 +1364,7 @@ dependencies = [ "arrow", "datafusion-common", "datafusion-expr", + "sqlparser", ] [[package]] diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 8f02fb30b013..863bb71abdec 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -18,16 +18,18 @@ //! This module contains end to end demonstrations of creating //! user defined aggregate functions +use std::fmt::Debug; + use arrow::{array::AsArray, datatypes::Fields}; use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray}; use arrow_schema::Schema; +use datafusion_physical_plan::udaf::create_aggregate_expr; +use sqlparser::ast::NullTreatment; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; -use datafusion::datasource::MemTable; -use datafusion::test_util::plan_and_collect; use datafusion::{ arrow::{ array::{ArrayRef, Float64Array, TimestampNanosecondArray}, @@ -43,10 +45,11 @@ use datafusion::{ prelude::SessionContext, scalar::ScalarValue, }; +use datafusion::{datasource::MemTable, test_util::plan_and_collect}; use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err}; use datafusion_expr::{ - create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator, - SimpleAggregateUDF, + create_udaf, expr::AggregateFunction, function::AccumulatorArgs, AggregateUDFImpl, + GroupsAccumulator, ReversedExpr, SimpleAggregateUDF, }; use datafusion_physical_expr::expressions::AvgAccumulator; @@ -795,3 +798,89 @@ impl GroupsAccumulator for TestGroupsAccumulator { std::mem::size_of::() } } + +#[derive(Clone)] +struct TestReverseUDAF { + signature: Signature, + // accumulator: AccumulatorFactoryFunction, + // state_fields: Vec, +} + +impl Debug for TestReverseUDAF { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("TestReverseUDAF") + .field("name", &self.name()) + .field("signature", self.signature()) + .finish() + } +} + +impl AggregateUDFImpl for TestReverseUDAF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "test_reverse" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + todo!("no need") + } + + fn state_fields( + &self, + _name: &str, + _value_type: DataType, + _ordering_fields: Vec, + ) -> Result> { + Ok(vec![]) + } + + fn reverse_expr(&self) -> ReversedExpr { + ReversedExpr::Reversed(AggregateFunction::new_udf( + Arc::new(self.clone().into()), + vec![], + false, + None, + None, + Some(NullTreatment::RespectNulls), + )) + } +} + +/// tests the creation, registration and usage of a UDAF +#[tokio::test] +async fn test_reverse_udaf() -> Result<()> { + let my_reverse = AggregateUDF::from(TestReverseUDAF { + signature: Signature::exact(vec![], Volatility::Immutable), + }); + + let empty_schema = Schema::empty(); + let e = create_aggregate_expr( + &my_reverse, + &[], + &[], + &[], + &empty_schema, + "test_reverse_udaf", + true, + )?; + + // TODO: We don't have a nice way to test the change without introducing many other things + // We check with the output string. `ignore nulls` is expeceted to be false. + let res = e.reverse_expr(); + let res_str = format!("{:?}", res.unwrap()); + + assert_eq!(&res_str, "AggregateFunctionExpr { fun: AggregateUDF { inner: TestReverseUDAF { name: \"test_reverse\", signature: Signature { type_signature: Exact([]), volatility: Immutable } } }, args: [], data_type: Float64, name: \"test_reverse_udaf\", schema: Schema { fields: [], metadata: {} }, sort_exprs: [], ordering_req: [], ignore_nulls: false, ordering_fields: [] }"); + + Ok(()) +} diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index de4f31029293..d31ba8f7efd1 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -81,7 +81,7 @@ pub use signature::{ TIMEZONE_WILDCARD, }; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; -pub use udaf::{AggregateUDF, AggregateUDFImpl}; +pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedExpr}; pub use udf::{ScalarUDF, ScalarUDFImpl}; pub use udwf::{WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 67c3b51ca373..fbdacf4d9510 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -17,6 +17,7 @@ //! [`AggregateUDF`]: User Defined Aggregate Functions +use crate::expr::AggregateFunction; use crate::function::AccumulatorArgs; use crate::groups_accumulator::GroupsAccumulator; use crate::utils::format_state_name; @@ -195,6 +196,11 @@ impl AggregateUDF { pub fn create_groups_accumulator(&self) -> Result> { self.inner.create_groups_accumulator() } + + /// See [`AggregateUDFImpl::reverse_expr`] for more details. + pub fn reverse_expr(&self) -> ReversedExpr { + self.inner.reverse_expr() + } } impl From for AggregateUDF @@ -354,6 +360,24 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn aliases(&self) -> &[String] { &[] } + + /// Construct an expression that calculates the aggregate in reverse. + /// Typically the "reverse" expression is itself (e.g. SUM, COUNT). + /// For aggregates that do not support calculation in reverse, + /// returns None (which is the default value). + fn reverse_expr(&self) -> ReversedExpr { + ReversedExpr::NotSupported + } +} + +#[derive(Debug)] +pub enum ReversedExpr { + /// The expression is the same as the original expression, like SUM, COUNT + Identical, + /// The expression does not support reverse calculation, like ArrayAgg + NotSupported, + /// The expression is different from the original expression + Reversed(AggregateFunction), } /// AggregateUDF that adds an alias to the underlying function. It is better to diff --git a/datafusion/physical-expr-common/Cargo.toml b/datafusion/physical-expr-common/Cargo.toml index d1202c83d526..1b50081ee375 100644 --- a/datafusion/physical-expr-common/Cargo.toml +++ b/datafusion/physical-expr-common/Cargo.toml @@ -39,3 +39,4 @@ path = "src/lib.rs" arrow = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } +sqlparser = { workspace = true } diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 448af634176a..7be39d91b432 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -19,10 +19,13 @@ pub mod utils; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{not_impl_err, Result}; +use datafusion_expr::expr::AggregateFunction; use datafusion_expr::type_coercion::aggregates::check_arg_count; +use datafusion_expr::ReversedExpr; use datafusion_expr::{ function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator, }; +use sqlparser::ast::NullTreatment; use std::fmt::Debug; use std::{any::Any, sync::Arc}; @@ -147,7 +150,7 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq { } /// Physical aggregate expression of a UDAF. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct AggregateFunctionExpr { fun: AggregateUDF, args: Vec>, @@ -273,6 +276,31 @@ impl AggregateExpr for AggregateFunctionExpr { fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { (!self.ordering_req.is_empty()).then_some(&self.ordering_req) } + + fn reverse_expr(&self) -> Option> { + match self.fun.reverse_expr() { + ReversedExpr::NotSupported => None, + ReversedExpr::Identical => Some(Arc::new(self.clone())), + ReversedExpr::Reversed(AggregateFunction { + func_def: _, + args: _, + distinct: _, + filter: _, + order_by: _, + null_treatment, + }) => { + let ignore_nulls = null_treatment.unwrap_or(NullTreatment::RespectNulls) + == NullTreatment::IgnoreNulls; + + // TODO: Do the actual conversion from logical expr + // for other fields + let mut expr = self.clone(); + expr.ignore_nulls = ignore_nulls; + + Some(Arc::new(expr)) + } + } + } } impl PartialEq for AggregateFunctionExpr { From 047e7a8ad47ca4b729c2bbd3091a64ec20dcf072 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Wed, 24 Apr 2024 21:21:33 +0800 Subject: [PATCH 2/2] cleanup Signed-off-by: jayzhan211 --- datafusion/core/tests/user_defined/user_defined_aggregates.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 863bb71abdec..4b96905cd37a 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -802,8 +802,6 @@ impl GroupsAccumulator for TestGroupsAccumulator { #[derive(Clone)] struct TestReverseUDAF { signature: Signature, - // accumulator: AccumulatorFactoryFunction, - // state_fields: Vec, } impl Debug for TestReverseUDAF {