Skip to content

Commit

Permalink
Move Covariance (Population) covar_pop to be a User Defined Aggregate…
Browse files Browse the repository at this point in the history
… Function (#10418)

* move covariance

* add sqllogictest
  • Loading branch information
yyy1000 authored May 10, 2024
1 parent fe89d0b commit 9f0e016
Show file tree
Hide file tree
Showing 15 changed files with 273 additions and 424 deletions.
10 changes: 1 addition & 9 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ pub enum AggregateFunction {
Stddev,
/// Standard Deviation (Population)
StddevPop,
/// Covariance (Population)
CovariancePop,
/// Correlation
Correlation,
/// Slope from linear regression
Expand Down Expand Up @@ -126,7 +124,6 @@ impl AggregateFunction {
VariancePop => "VAR_POP",
Stddev => "STDDEV",
StddevPop => "STDDEV_POP",
CovariancePop => "COVAR_POP",
Correlation => "CORR",
RegrSlope => "REGR_SLOPE",
RegrIntercept => "REGR_INTERCEPT",
Expand Down Expand Up @@ -181,7 +178,6 @@ impl FromStr for AggregateFunction {
"string_agg" => AggregateFunction::StringAgg,
// statistical
"corr" => AggregateFunction::Correlation,
"covar_pop" => AggregateFunction::CovariancePop,
"stddev" => AggregateFunction::Stddev,
"stddev_pop" => AggregateFunction::StddevPop,
"stddev_samp" => AggregateFunction::Stddev,
Expand Down Expand Up @@ -255,9 +251,6 @@ impl AggregateFunction {
AggregateFunction::VariancePop => {
variance_return_type(&coerced_data_types[0])
}
AggregateFunction::CovariancePop => {
covariance_return_type(&coerced_data_types[0])
}
AggregateFunction::Correlation => {
correlation_return_type(&coerced_data_types[0])
}
Expand Down Expand Up @@ -349,8 +342,7 @@ impl AggregateFunction {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable),
AggregateFunction::CovariancePop
| AggregateFunction::Correlation
AggregateFunction::Correlation
| AggregateFunction::RegrSlope
| AggregateFunction::RegrIntercept
| AggregateFunction::RegrCount
Expand Down
10 changes: 0 additions & 10 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,16 +183,6 @@ pub fn coerce_types(
}
Ok(vec![Float64, Float64])
}
AggregateFunction::CovariancePop => {
if !is_covariance_support_arg_type(&input_types[0]) {
return plan_err!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun,
input_types[0]
);
}
Ok(vec![Float64, Float64])
}
AggregateFunction::Stddev | AggregateFunction::StddevPop => {
if !is_stddev_support_arg_type(&input_types[0]) {
return plan_err!(
Expand Down
81 changes: 81 additions & 0 deletions datafusion/functions-aggregate/src/covariance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ make_udaf_expr_and_func!(
covar_samp_udaf
);

make_udaf_expr_and_func!(
CovariancePopulation,
covar_pop,
y x,
"Computes the population covariance.",
covar_pop_udaf
);

pub struct CovarianceSample {
signature: Signature,
aliases: Vec<String>,
Expand Down Expand Up @@ -120,6 +128,79 @@ impl AggregateUDFImpl for CovarianceSample {
}
}

pub struct CovariancePopulation {
signature: Signature,
}

impl Debug for CovariancePopulation {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("CovariancePopulation")
.field("name", &self.name())
.field("signature", &self.signature)
.finish()
}
}

impl Default for CovariancePopulation {
fn default() -> Self {
Self::new()
}
}

impl CovariancePopulation {
pub fn new() -> Self {
Self {
signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
}
}
}

impl AggregateUDFImpl for CovariancePopulation {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
"covar_pop"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if !arg_types[0].is_numeric() {
return plan_err!("Covariance requires numeric input types");
}

Ok(DataType::Float64)
}

fn state_fields(
&self,
name: &str,
_value_type: DataType,
_ordering_fields: Vec<Field>,
) -> Result<Vec<Field>> {
Ok(vec![
Field::new(format_state_name(name, "count"), DataType::UInt64, true),
Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
Field::new(
format_state_name(name, "algo_const"),
DataType::Float64,
true,
),
])
}

fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(CovarianceAccumulator::try_new(
StatsType::Population,
)?))
}
}

/// An accumulator to compute covariance
/// The algorithm used is an online implementation and numerically stable. It is derived from the following paper
/// for calculating variance:
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
let functions: Vec<Arc<AggregateUDF>> = vec![
first_last::first_value_udaf(),
covariance::covar_samp_udaf(),
covariance::covar_pop_udaf(),
];

functions.into_iter().try_for_each(|udf| {
Expand Down
11 changes: 0 additions & 11 deletions datafusion/physical-expr/src/aggregate/build_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,17 +181,6 @@ pub fn create_aggregate_expr(
(AggregateFunction::VariancePop, true) => {
return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available");
}
(AggregateFunction::CovariancePop, false) => {
Arc::new(expressions::CovariancePop::new(
input_phy_exprs[0].clone(),
input_phy_exprs[1].clone(),
name,
data_type,
))
}
(AggregateFunction::CovariancePop, true) => {
return not_impl_err!("COVAR_POP(DISTINCT) aggregations are not available");
}
(AggregateFunction::Stddev, false) => Arc::new(expressions::Stddev::new(
input_phy_exprs[0].clone(),
name,
Expand Down
Loading

0 comments on commit 9f0e016

Please sign in to comment.