Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move Covariance (Population) covar_pop to be a User Defined Aggregate Function #10418

Merged
merged 2 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ pub enum AggregateFunction {
Stddev,
/// Standard Deviation (Population)
StddevPop,
/// Covariance (Population)
CovariancePop,
/// Correlation
Correlation,
/// Slope from linear regression
Expand Down Expand Up @@ -126,7 +124,6 @@ impl AggregateFunction {
VariancePop => "VAR_POP",
Stddev => "STDDEV",
StddevPop => "STDDEV_POP",
CovariancePop => "COVAR_POP",
Correlation => "CORR",
RegrSlope => "REGR_SLOPE",
RegrIntercept => "REGR_INTERCEPT",
Expand Down Expand Up @@ -181,7 +178,6 @@ impl FromStr for AggregateFunction {
"string_agg" => AggregateFunction::StringAgg,
// statistical
"corr" => AggregateFunction::Correlation,
"covar_pop" => AggregateFunction::CovariancePop,
"stddev" => AggregateFunction::Stddev,
"stddev_pop" => AggregateFunction::StddevPop,
"stddev_samp" => AggregateFunction::Stddev,
Expand Down Expand Up @@ -255,9 +251,6 @@ impl AggregateFunction {
AggregateFunction::VariancePop => {
variance_return_type(&coerced_data_types[0])
}
AggregateFunction::CovariancePop => {
covariance_return_type(&coerced_data_types[0])
}
AggregateFunction::Correlation => {
correlation_return_type(&coerced_data_types[0])
}
Expand Down Expand Up @@ -349,8 +342,7 @@ impl AggregateFunction {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable),
AggregateFunction::CovariancePop
| AggregateFunction::Correlation
AggregateFunction::Correlation
| AggregateFunction::RegrSlope
| AggregateFunction::RegrIntercept
| AggregateFunction::RegrCount
Expand Down
10 changes: 0 additions & 10 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,16 +183,6 @@ pub fn coerce_types(
}
Ok(vec![Float64, Float64])
}
AggregateFunction::CovariancePop => {
if !is_covariance_support_arg_type(&input_types[0]) {
return plan_err!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun,
input_types[0]
);
}
Ok(vec![Float64, Float64])
}
AggregateFunction::Stddev | AggregateFunction::StddevPop => {
if !is_stddev_support_arg_type(&input_types[0]) {
return plan_err!(
Expand Down
81 changes: 81 additions & 0 deletions datafusion/functions-aggregate/src/covariance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ make_udaf_expr_and_func!(
covar_samp_udaf
);

make_udaf_expr_and_func!(
CovariancePopulation,
covar_pop,
y x,
"Computes the population covariance.",
covar_pop_udaf
);

pub struct CovarianceSample {
signature: Signature,
aliases: Vec<String>,
Expand Down Expand Up @@ -120,6 +128,79 @@ impl AggregateUDFImpl for CovarianceSample {
}
}

pub struct CovariancePopulation {
signature: Signature,
}

impl Debug for CovariancePopulation {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("CovariancePopulation")
.field("name", &self.name())
.field("signature", &self.signature)
.finish()
}
}

impl Default for CovariancePopulation {
fn default() -> Self {
Self::new()
}
}

impl CovariancePopulation {
pub fn new() -> Self {
Self {
signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
}
}
}

impl AggregateUDFImpl for CovariancePopulation {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
"covar_pop"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if !arg_types[0].is_numeric() {
return plan_err!("Covariance requires numeric input types");
}

Ok(DataType::Float64)
}

fn state_fields(
&self,
name: &str,
_value_type: DataType,
_ordering_fields: Vec<Field>,
) -> Result<Vec<Field>> {
Ok(vec![
Field::new(format_state_name(name, "count"), DataType::UInt64, true),
Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
Field::new(
format_state_name(name, "algo_const"),
DataType::Float64,
true,
),
])
}

fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(CovarianceAccumulator::try_new(
StatsType::Population,
)?))
}
}

/// An accumulator to compute covariance
/// The algorithm used is an online implementation and numerically stable. It is derived from the following paper
/// for calculating variance:
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
let functions: Vec<Arc<AggregateUDF>> = vec![
first_last::first_value_udaf(),
covariance::covar_samp_udaf(),
covariance::covar_pop_udaf(),
];

functions.into_iter().try_for_each(|udf| {
Expand Down
11 changes: 0 additions & 11 deletions datafusion/physical-expr/src/aggregate/build_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,17 +181,6 @@ pub fn create_aggregate_expr(
(AggregateFunction::VariancePop, true) => {
return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available");
}
(AggregateFunction::CovariancePop, false) => {
Arc::new(expressions::CovariancePop::new(
input_phy_exprs[0].clone(),
input_phy_exprs[1].clone(),
name,
data_type,
))
}
(AggregateFunction::CovariancePop, true) => {
return not_impl_err!("COVAR_POP(DISTINCT) aggregations are not available");
}
(AggregateFunction::Stddev, false) => Arc::new(expressions::Stddev::new(
input_phy_exprs[0].clone(),
name,
Expand Down
Loading