Skip to content

Commit

Permalink
Move Regr_* functions to use UDAF (#10898)
Browse files Browse the repository at this point in the history
* Move Regr_* functions to use UDAF

Closes #10883 and is part of #8708

* Format and regen

* tweak error check

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
eejbyfeldt and alamb authored Jun 13, 2024
1 parent b627ca3 commit cc60278
Show file tree
Hide file tree
Showing 15 changed files with 135 additions and 316 deletions.
56 changes: 1 addition & 55 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,24 +45,6 @@ pub enum AggregateFunction {
NthValue,
/// Correlation
Correlation,
/// Slope from linear regression
RegrSlope,
/// Intercept from linear regression
RegrIntercept,
/// Number of input rows in which both expressions are not null
RegrCount,
/// R-squared value from linear regression
RegrR2,
/// Average of the independent variable
RegrAvgx,
/// Average of the dependent variable
RegrAvgy,
/// Sum of squares of the independent variable
RegrSXX,
/// Sum of squares of the dependent variable
RegrSYY,
/// Sum of products of pairs of numbers
RegrSXY,
/// Approximate continuous percentile function
ApproxPercentileCont,
/// Approximate continuous percentile function with weight
Expand Down Expand Up @@ -93,15 +75,6 @@ impl AggregateFunction {
ArrayAgg => "ARRAY_AGG",
NthValue => "NTH_VALUE",
Correlation => "CORR",
RegrSlope => "REGR_SLOPE",
RegrIntercept => "REGR_INTERCEPT",
RegrCount => "REGR_COUNT",
RegrR2 => "REGR_R2",
RegrAvgx => "REGR_AVGX",
RegrAvgy => "REGR_AVGY",
RegrSXX => "REGR_SXX",
RegrSYY => "REGR_SYY",
RegrSXY => "REGR_SXY",
ApproxPercentileCont => "APPROX_PERCENTILE_CONT",
ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT",
Grouping => "GROUPING",
Expand Down Expand Up @@ -140,15 +113,6 @@ impl FromStr for AggregateFunction {
"string_agg" => AggregateFunction::StringAgg,
// statistical
"corr" => AggregateFunction::Correlation,
"regr_slope" => AggregateFunction::RegrSlope,
"regr_intercept" => AggregateFunction::RegrIntercept,
"regr_count" => AggregateFunction::RegrCount,
"regr_r2" => AggregateFunction::RegrR2,
"regr_avgx" => AggregateFunction::RegrAvgx,
"regr_avgy" => AggregateFunction::RegrAvgy,
"regr_sxx" => AggregateFunction::RegrSXX,
"regr_syy" => AggregateFunction::RegrSYY,
"regr_sxy" => AggregateFunction::RegrSXY,
// approximate
"approx_percentile_cont" => AggregateFunction::ApproxPercentileCont,
"approx_percentile_cont_with_weight" => {
Expand Down Expand Up @@ -200,15 +164,6 @@ impl AggregateFunction {
AggregateFunction::Correlation => {
correlation_return_type(&coerced_data_types[0])
}
AggregateFunction::RegrSlope
| AggregateFunction::RegrIntercept
| AggregateFunction::RegrCount
| AggregateFunction::RegrR2
| AggregateFunction::RegrAvgx
| AggregateFunction::RegrAvgy
| AggregateFunction::RegrSXX
| AggregateFunction::RegrSYY
| AggregateFunction::RegrSXY => Ok(DataType::Float64),
AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]),
AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new(
"item",
Expand Down Expand Up @@ -272,16 +227,7 @@ impl AggregateFunction {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable),
AggregateFunction::Correlation
| AggregateFunction::RegrSlope
| AggregateFunction::RegrIntercept
| AggregateFunction::RegrCount
| AggregateFunction::RegrR2
| AggregateFunction::RegrAvgx
| AggregateFunction::RegrAvgy
| AggregateFunction::RegrSXX
| AggregateFunction::RegrSYY
| AggregateFunction::RegrSXY => {
AggregateFunction::Correlation => {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::ApproxPercentileCont => {
Expand Down
21 changes: 0 additions & 21 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,27 +158,6 @@ pub fn coerce_types(
}
Ok(vec![Float64, Float64])
}
AggregateFunction::RegrSlope
| AggregateFunction::RegrIntercept
| AggregateFunction::RegrCount
| AggregateFunction::RegrR2
| AggregateFunction::RegrAvgx
| AggregateFunction::RegrAvgy
| AggregateFunction::RegrSXX
| AggregateFunction::RegrSYY
| AggregateFunction::RegrSXY => {
let valid_types = [NUMERICS.to_vec(), vec![Null]].concat();
let input_types_valid = // number of input already checked before
valid_types.contains(&input_types[0]) && valid_types.contains(&input_types[1]);
if !input_types_valid {
return plan_err!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun,
input_types[0]
);
}
Ok(vec![Float64, Float64])
}
AggregateFunction::ApproxPercentileCont => {
if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
return plan_err!(
Expand Down
19 changes: 19 additions & 0 deletions datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ pub mod covariance;
pub mod first_last;
pub mod hyperloglog;
pub mod median;
pub mod regr;
pub mod stddev;
pub mod sum;
pub mod variance;
Expand All @@ -85,6 +86,15 @@ pub mod expr_fn {
pub use super::first_last::first_value;
pub use super::first_last::last_value;
pub use super::median::median;
pub use super::regr::regr_avgx;
pub use super::regr::regr_avgy;
pub use super::regr::regr_count;
pub use super::regr::regr_intercept;
pub use super::regr::regr_r2;
pub use super::regr::regr_slope;
pub use super::regr::regr_sxx;
pub use super::regr::regr_sxy;
pub use super::regr::regr_syy;
pub use super::stddev::stddev;
pub use super::stddev::stddev_pop;
pub use super::sum::sum;
Expand All @@ -102,6 +112,15 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
covariance::covar_pop_udaf(),
median::median_udaf(),
count::count_udaf(),
regr::regr_slope_udaf(),
regr::regr_intercept_udaf(),
regr::regr_count_udaf(),
regr::regr_r2_udaf(),
regr::regr_avgx_udaf(),
regr::regr_avgy_udaf(),
regr::regr_sxx_udaf(),
regr::regr_syy_udaf(),
regr::regr_sxy_udaf(),
variance::var_samp_udaf(),
variance::var_pop_udaf(),
stddev::stddev_udaf(),
Expand Down
14 changes: 11 additions & 3 deletions datafusion/functions-aggregate/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
// specific language governing permissions and limitations
// under the License.

macro_rules! make_udaf_expr_and_func {
($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
macro_rules! make_udaf_expr {
($EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
// "fluent expr_fn" style function
#[doc = $DOC]
pub fn $EXPR_FN(
Expand All @@ -48,7 +48,12 @@ macro_rules! make_udaf_expr_and_func {
None,
))
}
};
}

macro_rules! make_udaf_expr_and_func {
($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
make_udaf_expr!($EXPR_FN, $($arg)*, $DOC, $AGGREGATE_UDF_FN);
create_func!($UDAF, $AGGREGATE_UDF_FN);
};
($UDAF:ty, $EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
Expand All @@ -73,6 +78,9 @@ macro_rules! make_udaf_expr_and_func {

macro_rules! create_func {
($UDAF:ty, $AGGREGATE_UDF_FN:ident) => {
create_func!($UDAF, $AGGREGATE_UDF_FN, <$UDAF>::default());
};
($UDAF:ty, $AGGREGATE_UDF_FN:ident, $CREATE:expr) => {
paste::paste! {
/// Singleton instance of [$UDAF], ensures the UDAF is only created once
/// named STATIC_$(UDAF). For example `STATIC_FirstValue`
Expand All @@ -86,7 +94,7 @@ macro_rules! create_func {
pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc<datafusion_expr::AggregateUDF> {
[< STATIC_ $UDAF >]
.get_or_init(|| {
std::sync::Arc::new(datafusion_expr::AggregateUDF::from(<$UDAF>::default()))
std::sync::Arc::new(datafusion_expr::AggregateUDF::from($CREATE))
})
.clone()
}
Expand Down
Loading

0 comments on commit cc60278

Please sign in to comment.