Skip to content

Commit

Permalink
add correlation function (#1561)
Browse files Browse the repository at this point in the history
* add correlation function

* add readme

* add sql test

* fix divide by 0
  • Loading branch information
realno authored Jan 16, 2022
1 parent 1c39f5c commit b743610
Show file tree
Hide file tree
Showing 12 changed files with 768 additions and 21 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ This library currently supports many SQL constructs, including
- `CAST` to change types, including e.g. `Timestamp(Nanosecond, None)`
- Many mathematical unary and binary expressions such as `+`, `/`, `sqrt`, `tan`, `>=`.
- `WHERE` to filter
- `GROUP BY` together with one of the following aggregations: `MIN`, `MAX`, `COUNT`, `SUM`, `AVG`, `VAR`, `COVAR`, `STDDEV` (sample and population)
- `GROUP BY` together with one of the following aggregations: `MIN`, `MAX`, `COUNT`, `SUM`, `AVG`, `CORR`, `VAR`, `COVAR`, `STDDEV` (sample and population)
- `ORDER BY` together with an expression and optional `ASC` or `DESC` and also optional `NULLS FIRST` or `NULLS LAST`

## Supported Functions
Expand Down
1 change: 1 addition & 0 deletions ballista/rust/core/proto/ballista.proto
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ enum AggregateFunction {
COVARIANCE_POP=10;
STDDEV=11;
STDDEV_POP=12;
CORRELATION=13;
}

message AggregateExprNode {
Expand Down
4 changes: 4 additions & 0 deletions ballista/rust/core/src/serde/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,9 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
AggregateFunction::StddevPop => {
protobuf::AggregateFunction::StddevPop
}
AggregateFunction::Correlation => {
protobuf::AggregateFunction::Correlation
}
};

let arg = &args[0];
Expand Down Expand Up @@ -1285,6 +1288,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
AggregateFunction::CovariancePop => Self::CovariancePop,
AggregateFunction::Stddev => Self::Stddev,
AggregateFunction::StddevPop => Self::StddevPop,
AggregateFunction::Correlation => Self::Correlation,
}
}
}
Expand Down
1 change: 1 addition & 0 deletions ballista/rust/core/src/serde/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
}
protobuf::AggregateFunction::Stddev => AggregateFunction::Stddev,
protobuf::AggregateFunction::StddevPop => AggregateFunction::StddevPop,
protobuf::AggregateFunction::Correlation => AggregateFunction::Correlation,
}
}
}
Expand Down
168 changes: 165 additions & 3 deletions datafusion/src/physical_plan/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ use crate::physical_plan::distinct_expressions;
use crate::physical_plan::expressions;
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use expressions::{
avg_return_type, covariance_return_type, stddev_return_type, sum_return_type,
variance_return_type,
avg_return_type, correlation_return_type, covariance_return_type, stddev_return_type,
sum_return_type, variance_return_type,
};
use std::{fmt, str::FromStr, sync::Arc};

Expand Down Expand Up @@ -79,6 +79,8 @@ pub enum AggregateFunction {
Covariance,
/// Covariance (Population)
CovariancePop,
/// Correlation
Correlation,
}

impl fmt::Display for AggregateFunction {
Expand Down Expand Up @@ -108,6 +110,7 @@ impl FromStr for AggregateFunction {
"covar" => AggregateFunction::Covariance,
"covar_samp" => AggregateFunction::Covariance,
"covar_pop" => AggregateFunction::CovariancePop,
"corr" => AggregateFunction::Correlation,
_ => {
return Err(DataFusionError::Plan(format!(
"There is no built-in function named {}",
Expand Down Expand Up @@ -146,6 +149,7 @@ pub fn return_type(
AggregateFunction::CovariancePop => {
covariance_return_type(&coerced_data_types[0])
}
AggregateFunction::Correlation => correlation_return_type(&coerced_data_types[0]),
AggregateFunction::Stddev => stddev_return_type(&coerced_data_types[0]),
AggregateFunction::StddevPop => stddev_return_type(&coerced_data_types[0]),
AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]),
Expand Down Expand Up @@ -315,6 +319,19 @@ pub fn create_aggregate_expr(
"STDDEV_POP(DISTINCT) aggregations are not available".to_string(),
));
}
(AggregateFunction::Correlation, false) => {
Arc::new(expressions::Correlation::new(
coerced_phy_exprs[0].clone(),
coerced_phy_exprs[1].clone(),
name,
return_type,
))
}
(AggregateFunction::Correlation, true) => {
return Err(DataFusionError::NotImplemented(
"CORR(DISTINCT) aggregations are not available".to_string(),
));
}
})
}

Expand Down Expand Up @@ -370,6 +387,9 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
AggregateFunction::Covariance | AggregateFunction::CovariancePop => {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::Correlation => {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
}
}

Expand All @@ -380,7 +400,8 @@ mod tests {
use crate::error::Result;
use crate::physical_plan::distinct_expressions::DistinctCount;
use crate::physical_plan::expressions::{
ApproxDistinct, ArrayAgg, Avg, Count, Max, Min, Stddev, Sum, Variance,
ApproxDistinct, ArrayAgg, Avg, Correlation, Count, Covariance, Max, Min, Stddev,
Sum, Variance,
};

#[test]
Expand Down Expand Up @@ -760,6 +781,147 @@ mod tests {
Ok(())
}

#[test]
fn test_covar_expr() -> Result<()> {
let funcs = vec![AggregateFunction::Covariance];
let data_types = vec![
DataType::UInt32,
DataType::UInt64,
DataType::Int32,
DataType::Int64,
DataType::Float32,
DataType::Float64,
];
for fun in funcs {
for data_type in &data_types {
let input_schema = Schema::new(vec![
Field::new("c1", data_type.clone(), true),
Field::new("c2", data_type.clone(), true),
]);
let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
Arc::new(
expressions::Column::new_with_schema("c1", &input_schema)
.unwrap(),
),
Arc::new(
expressions::Column::new_with_schema("c2", &input_schema)
.unwrap(),
),
];
let result_agg_phy_exprs = create_aggregate_expr(
&fun,
false,
&input_phy_exprs[0..2],
&input_schema,
"c1",
)?;
if fun == AggregateFunction::Covariance {
assert!(result_agg_phy_exprs.as_any().is::<Covariance>());
assert_eq!("c1", result_agg_phy_exprs.name());
assert_eq!(
Field::new("c1", DataType::Float64, true),
result_agg_phy_exprs.field().unwrap()
)
}
}
}
Ok(())
}

#[test]
fn test_covar_pop_expr() -> Result<()> {
let funcs = vec![AggregateFunction::CovariancePop];
let data_types = vec![
DataType::UInt32,
DataType::UInt64,
DataType::Int32,
DataType::Int64,
DataType::Float32,
DataType::Float64,
];
for fun in funcs {
for data_type in &data_types {
let input_schema = Schema::new(vec![
Field::new("c1", data_type.clone(), true),
Field::new("c2", data_type.clone(), true),
]);
let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
Arc::new(
expressions::Column::new_with_schema("c1", &input_schema)
.unwrap(),
),
Arc::new(
expressions::Column::new_with_schema("c2", &input_schema)
.unwrap(),
),
];
let result_agg_phy_exprs = create_aggregate_expr(
&fun,
false,
&input_phy_exprs[0..2],
&input_schema,
"c1",
)?;
if fun == AggregateFunction::Covariance {
assert!(result_agg_phy_exprs.as_any().is::<Covariance>());
assert_eq!("c1", result_agg_phy_exprs.name());
assert_eq!(
Field::new("c1", DataType::Float64, true),
result_agg_phy_exprs.field().unwrap()
)
}
}
}
Ok(())
}

#[test]
fn test_corr_expr() -> Result<()> {
let funcs = vec![AggregateFunction::Correlation];
let data_types = vec![
DataType::UInt32,
DataType::UInt64,
DataType::Int32,
DataType::Int64,
DataType::Float32,
DataType::Float64,
];
for fun in funcs {
for data_type in &data_types {
let input_schema = Schema::new(vec![
Field::new("c1", data_type.clone(), true),
Field::new("c2", data_type.clone(), true),
]);
let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
Arc::new(
expressions::Column::new_with_schema("c1", &input_schema)
.unwrap(),
),
Arc::new(
expressions::Column::new_with_schema("c2", &input_schema)
.unwrap(),
),
];
let result_agg_phy_exprs = create_aggregate_expr(
&fun,
false,
&input_phy_exprs[0..2],
&input_schema,
"c1",
)?;
if fun == AggregateFunction::Covariance {
assert!(result_agg_phy_exprs.as_any().is::<Correlation>());
assert_eq!("c1", result_agg_phy_exprs.name());
assert_eq!(
Field::new("c1", DataType::Float64, true),
result_agg_phy_exprs.field().unwrap()
)
}
}
}
Ok(())
}

#[test]
fn test_min_max() -> Result<()> {
let observed = return_type(&AggregateFunction::Min, &[DataType::Utf8])?;
Expand Down
14 changes: 12 additions & 2 deletions datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ use crate::arrow::datatypes::Schema;
use crate::error::{DataFusionError, Result};
use crate::physical_plan::aggregates::AggregateFunction;
use crate::physical_plan::expressions::{
is_avg_support_arg_type, is_covariance_support_arg_type, is_stddev_support_arg_type,
is_sum_support_arg_type, is_variance_support_arg_type, try_cast,
is_avg_support_arg_type, is_correlation_support_arg_type,
is_covariance_support_arg_type, is_stddev_support_arg_type, is_sum_support_arg_type,
is_variance_support_arg_type, try_cast,
};
use crate::physical_plan::functions::{Signature, TypeSignature};
use crate::physical_plan::PhysicalExpr;
Expand Down Expand Up @@ -141,6 +142,15 @@ pub(crate) fn coerce_types(
}
Ok(input_types.to_vec())
}
AggregateFunction::Correlation => {
if !is_correlation_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
Ok(input_types.to_vec())
}
}
}

Expand Down
Loading

0 comments on commit b743610

Please sign in to comment.