From 27f5c73abb10ef6a86e7d4a75118c256a36490c5 Mon Sep 17 00:00:00 2001 From: Michael-J-Ward Date: Wed, 24 Jul 2024 14:04:13 -0500 Subject: [PATCH] migrate regr_* functions to UDAF Ref: https://github.com/apache/datafusion/pull/10898 --- python/datafusion/functions.py | 18 +++---- src/functions.rs | 99 ++++++++++++++++++++++++++++++---- 2 files changed, 99 insertions(+), 18 deletions(-) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 2f9e6d61..b79a01b4 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -1398,7 +1398,7 @@ def regr_avgx(y: Expr, x: Expr, distinct: bool = False) -> Expr: Only non-null pairs of the inputs are evaluated. """ - return Expr(f.regr_avgx[y.expr, x.expr], distinct) + return Expr(f.regr_avgx(y.expr, x.expr, distinct)) def regr_avgy(y: Expr, x: Expr, distinct: bool = False) -> Expr: @@ -1406,42 +1406,42 @@ def regr_avgy(y: Expr, x: Expr, distinct: bool = False) -> Expr: Only non-null pairs of the inputs are evaluated. """ - return Expr(f.regr_avgy[y.expr, x.expr], distinct) + return Expr(f.regr_avgy(y.expr, x.expr, distinct)) def regr_count(y: Expr, x: Expr, distinct: bool = False) -> Expr: """Counts the number of rows in which both expressions are not null.""" - return Expr(f.regr_count[y.expr, x.expr], distinct) + return Expr(f.regr_count(y.expr, x.expr, distinct)) def regr_intercept(y: Expr, x: Expr, distinct: bool = False) -> Expr: """Computes the intercept from the linear regression.""" - return Expr(f.regr_intercept[y.expr, x.expr], distinct) + return Expr(f.regr_intercept(y.expr, x.expr, distinct)) def regr_r2(y: Expr, x: Expr, distinct: bool = False) -> Expr: """Computes the R-squared value from linear regression.""" - return Expr(f.regr_r2[y.expr, x.expr], distinct) + return Expr(f.regr_r2(y.expr, x.expr, distinct)) def regr_slope(y: Expr, x: Expr, distinct: bool = False) -> Expr: """Computes the slope from linear regression.""" - return Expr(f.regr_slope[y.expr, x.expr], distinct) + return Expr(f.regr_slope(y.expr, x.expr, distinct)) def regr_sxx(y: Expr, x: Expr, distinct: bool = False) -> Expr: """Computes the sum of squares of the independent variable `x`.""" - return Expr(f.regr_sxx[y.expr, x.expr], distinct) + return Expr(f.regr_sxx(y.expr, x.expr, distinct)) def regr_sxy(y: Expr, x: Expr, distinct: bool = False) -> Expr: """Computes the sum of products of pairs of numbers.""" - return Expr(f.regr_sxy[y.expr, x.expr], distinct) + return Expr(f.regr_sxy(y.expr, x.expr, distinct)) def regr_syy(y: Expr, x: Expr, distinct: bool = False) -> Expr: """Computes the sum of squares of the dependent variable `y`.""" - return Expr(f.regr_syy[y.expr, x.expr], distinct) + return Expr(f.regr_syy(y.expr, x.expr, distinct)) def first_value( diff --git a/src/functions.rs b/src/functions.rs index 74513011..ae32a485 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -191,6 +191,96 @@ pub fn var_pop(expression: PyExpr, distinct: bool) -> PyResult { } } +#[pyfunction] +pub fn regr_avgx(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult { + let expr = functions_aggregate::expr_fn::regr_avgx(expr_y.expr, expr_x.expr); + if distinct { + Ok(expr.distinct().build()?.into()) + } else { + Ok(expr.into()) + } +} + +#[pyfunction] +pub fn regr_avgy(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult { + let expr = functions_aggregate::expr_fn::regr_avgy(expr_y.expr, expr_x.expr); + if distinct { + Ok(expr.distinct().build()?.into()) + } else { + Ok(expr.into()) + } +} + +#[pyfunction] +pub fn regr_count(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult { + let expr = functions_aggregate::expr_fn::regr_count(expr_y.expr, expr_x.expr); + if distinct { + Ok(expr.distinct().build()?.into()) + } else { + Ok(expr.into()) + } +} + +#[pyfunction] +pub fn regr_intercept(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult { + let expr = functions_aggregate::expr_fn::regr_intercept(expr_y.expr, expr_x.expr); + if distinct { + Ok(expr.distinct().build()?.into()) + } else { + Ok(expr.into()) + } +} + +#[pyfunction] +pub fn regr_r2(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult { + let expr = functions_aggregate::expr_fn::regr_r2(expr_y.expr, expr_x.expr); + if distinct { + Ok(expr.distinct().build()?.into()) + } else { + Ok(expr.into()) + } +} + +#[pyfunction] +pub fn regr_slope(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult { + let expr = functions_aggregate::expr_fn::regr_slope(expr_y.expr, expr_x.expr); + if distinct { + Ok(expr.distinct().build()?.into()) + } else { + Ok(expr.into()) + } +} + +#[pyfunction] +pub fn regr_sxx(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult { + let expr = functions_aggregate::expr_fn::regr_sxx(expr_y.expr, expr_x.expr); + if distinct { + Ok(expr.distinct().build()?.into()) + } else { + Ok(expr.into()) + } +} + +#[pyfunction] +pub fn regr_sxy(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult { + let expr = functions_aggregate::expr_fn::regr_sxy(expr_y.expr, expr_x.expr); + if distinct { + Ok(expr.distinct().build()?.into()) + } else { + Ok(expr.into()) + } +} + +#[pyfunction] +pub fn regr_syy(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult { + let expr = functions_aggregate::expr_fn::regr_syy(expr_y.expr, expr_x.expr); + if distinct { + Ok(expr.distinct().build()?.into()) + } else { + Ok(expr.into()) + } +} + #[pyfunction] #[pyo3(signature = (expr, distinct = false, filter = None, order_by = None, null_treatment = None))] pub fn first_value( @@ -847,15 +937,6 @@ array_fn!(range, start stop step); aggregate_function!(array_agg, ArrayAgg); aggregate_function!(max, Max); aggregate_function!(min, Min); -aggregate_function!(regr_avgx, RegrAvgx); -aggregate_function!(regr_avgy, RegrAvgy); -aggregate_function!(regr_count, RegrCount); -aggregate_function!(regr_intercept, RegrIntercept); -aggregate_function!(regr_r2, RegrR2); -aggregate_function!(regr_slope, RegrSlope); -aggregate_function!(regr_sxx, RegrSXX); -aggregate_function!(regr_sxy, RegrSXY); -aggregate_function!(regr_syy, RegrSYY); aggregate_function!(bit_and, BitAnd); aggregate_function!(bit_or, BitOr); aggregate_function!(bit_xor, BitXor);