Skip to content

Commit

Permalink
Add stddev operator (#1525)
Browse files Browse the repository at this point in the history
* Initial implementation of variance

* get simple f64 type tests working

* add math functions to ScalarValue, some tests

* add to expressions and tests

* add more tests

* add test for ScalarValue add

* add tests for scalar arithmetic

* add test, finish variance

* fix warnings

* add more sql tests

* add stddev and tests

* add the hooks and expression

* add more tests

* fix lint and clippy

* address comments and fix test errors

* address comments

* add population and sample for variance and stddev

* address more comments

* fmt

* add test for less than 2 values

* fix inconsistency in the merge logic

* fix lint and clippy
  • Loading branch information
realno authored Jan 10, 2022
1 parent d6d90e9 commit 90de12a
Show file tree
Hide file tree
Showing 12 changed files with 1,987 additions and 5 deletions.
4 changes: 4 additions & 0 deletions ballista/rust/core/proto/ballista.proto
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ enum AggregateFunction {
COUNT = 4;
APPROX_DISTINCT = 5;
ARRAY_AGG = 6;
VARIANCE=7;
VARIANCE_POP=8;
STDDEV=9;
STDDEV_POP=10;
}

message AggregateExprNode {
Expand Down
12 changes: 12 additions & 0 deletions ballista/rust/core/src/serde/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,14 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
AggregateFunction::Sum => protobuf::AggregateFunction::Sum,
AggregateFunction::Avg => protobuf::AggregateFunction::Avg,
AggregateFunction::Count => protobuf::AggregateFunction::Count,
AggregateFunction::Variance => protobuf::AggregateFunction::Variance,
AggregateFunction::VariancePop => {
protobuf::AggregateFunction::VariancePop
}
AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev,
AggregateFunction::StddevPop => {
protobuf::AggregateFunction::StddevPop
}
};

let arg = &args[0];
Expand Down Expand Up @@ -1256,6 +1264,10 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
AggregateFunction::Count => Self::Count,
AggregateFunction::ApproxDistinct => Self::ApproxDistinct,
AggregateFunction::ArrayAgg => Self::ArrayAgg,
AggregateFunction::Variance => Self::Variance,
AggregateFunction::VariancePop => Self::VariancePop,
AggregateFunction::Stddev => Self::Stddev,
AggregateFunction::StddevPop => Self::StddevPop,
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions ballista/rust/core/src/serde/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
AggregateFunction::ApproxDistinct
}
protobuf::AggregateFunction::ArrayAgg => AggregateFunction::ArrayAgg,
protobuf::AggregateFunction::Variance => AggregateFunction::Variance,
protobuf::AggregateFunction::VariancePop => AggregateFunction::VariancePop,
protobuf::AggregateFunction::Stddev => AggregateFunction::Stddev,
protobuf::AggregateFunction::StddevPop => AggregateFunction::StddevPop,
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/src/optimizer/simplify_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ impl ConstEvaluator {
}

/// Internal helper to evaluates an Expr
fn evaluate_to_scalar(&self, expr: Expr) -> Result<ScalarValue> {
pub(crate) fn evaluate_to_scalar(&self, expr: Expr) -> Result<ScalarValue> {
if let Expr::Literal(s) = expr {
return Ok(s);
}
Expand Down
277 changes: 274 additions & 3 deletions datafusion/src/physical_plan/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ use crate::physical_plan::coercion_rule::aggregate_rule::{coerce_exprs, coerce_t
use crate::physical_plan::distinct_expressions;
use crate::physical_plan::expressions;
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use expressions::{avg_return_type, sum_return_type};
use expressions::{
avg_return_type, stddev_return_type, sum_return_type, variance_return_type,
};
use std::{fmt, str::FromStr, sync::Arc};

/// the implementation of an aggregate function
Expand Down Expand Up @@ -64,6 +66,14 @@ pub enum AggregateFunction {
ApproxDistinct,
/// array_agg
ArrayAgg,
/// Variance (Sample)
Variance,
/// Variance (Population)
VariancePop,
/// Standard Deviation (Sample)
Stddev,
/// Standard Deviation (Population)
StddevPop,
}

impl fmt::Display for AggregateFunction {
Expand All @@ -84,6 +94,12 @@ impl FromStr for AggregateFunction {
"sum" => AggregateFunction::Sum,
"approx_distinct" => AggregateFunction::ApproxDistinct,
"array_agg" => AggregateFunction::ArrayAgg,
"var" => AggregateFunction::Variance,
"var_samp" => AggregateFunction::Variance,
"var_pop" => AggregateFunction::VariancePop,
"stddev" => AggregateFunction::Stddev,
"stddev_samp" => AggregateFunction::Stddev,
"stddev_pop" => AggregateFunction::StddevPop,
_ => {
return Err(DataFusionError::Plan(format!(
"There is no built-in function named {}",
Expand Down Expand Up @@ -116,6 +132,10 @@ pub fn return_type(
Ok(coerced_data_types[0].clone())
}
AggregateFunction::Sum => sum_return_type(&coerced_data_types[0]),
AggregateFunction::Variance => variance_return_type(&coerced_data_types[0]),
AggregateFunction::VariancePop => variance_return_type(&coerced_data_types[0]),
AggregateFunction::Stddev => stddev_return_type(&coerced_data_types[0]),
AggregateFunction::StddevPop => stddev_return_type(&coerced_data_types[0]),
AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]),
AggregateFunction::ArrayAgg => Ok(DataType::List(Box::new(Field::new(
"item",
Expand Down Expand Up @@ -212,6 +232,48 @@ pub fn create_aggregate_expr(
"AVG(DISTINCT) aggregations are not available".to_string(),
));
}
(AggregateFunction::Variance, false) => Arc::new(expressions::Variance::new(
coerced_phy_exprs[0].clone(),
name,
return_type,
)),
(AggregateFunction::Variance, true) => {
return Err(DataFusionError::NotImplemented(
"VAR(DISTINCT) aggregations are not available".to_string(),
));
}
(AggregateFunction::VariancePop, false) => {
Arc::new(expressions::VariancePop::new(
coerced_phy_exprs[0].clone(),
name,
return_type,
))
}
(AggregateFunction::VariancePop, true) => {
return Err(DataFusionError::NotImplemented(
"VAR_POP(DISTINCT) aggregations are not available".to_string(),
));
}
(AggregateFunction::Stddev, false) => Arc::new(expressions::Stddev::new(
coerced_phy_exprs[0].clone(),
name,
return_type,
)),
(AggregateFunction::Stddev, true) => {
return Err(DataFusionError::NotImplemented(
"STDDEV(DISTINCT) aggregations are not available".to_string(),
));
}
(AggregateFunction::StddevPop, false) => Arc::new(expressions::StddevPop::new(
coerced_phy_exprs[0].clone(),
name,
return_type,
)),
(AggregateFunction::StddevPop, true) => {
return Err(DataFusionError::NotImplemented(
"STDDEV_POP(DISTINCT) aggregations are not available".to_string(),
));
}
})
}

Expand Down Expand Up @@ -256,7 +318,12 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
.collect::<Vec<_>>();
Signature::uniform(1, valid, Volatility::Immutable)
}
AggregateFunction::Avg | AggregateFunction::Sum => {
AggregateFunction::Avg
| AggregateFunction::Sum
| AggregateFunction::Variance
| AggregateFunction::VariancePop
| AggregateFunction::Stddev
| AggregateFunction::StddevPop => {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
}
Expand All @@ -267,7 +334,7 @@ mod tests {
use super::*;
use crate::error::Result;
use crate::physical_plan::expressions::{
ApproxDistinct, ArrayAgg, Avg, Count, Max, Min, Sum,
ApproxDistinct, ArrayAgg, Avg, Count, Max, Min, Stddev, Sum, Variance,
};

#[test]
Expand Down Expand Up @@ -450,6 +517,158 @@ mod tests {
Ok(())
}

/// Builds a sample-variance aggregate over a nullable `c1` column of each
/// listed numeric type and checks the resulting expression's concrete type,
/// name and output field.
#[test]
fn test_variance_expr() -> Result<()> {
    let fun = AggregateFunction::Variance;
    let numeric_types = [
        DataType::UInt32,
        DataType::UInt64,
        DataType::Int32,
        DataType::Int64,
        DataType::Float32,
        DataType::Float64,
    ];
    for numeric_type in numeric_types.iter() {
        let input_schema =
            Schema::new(vec![Field::new("c1", numeric_type.clone(), true)]);
        let column = expressions::Column::new_with_schema("c1", &input_schema).unwrap();
        let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(column)];
        let agg_expr = create_aggregate_expr(
            &fun,
            false,
            &input_phy_exprs[0..1],
            &input_schema,
            "c1",
        )?;
        // The expression must be the `Variance` implementation and expose a
        // nullable Float64 field carrying the requested name.
        assert!(agg_expr.as_any().is::<Variance>());
        assert_eq!("c1", agg_expr.name());
        assert_eq!(
            Field::new("c1", DataType::Float64, true),
            agg_expr.field().unwrap()
        );
    }
    Ok(())
}

/// Builds a population-variance aggregate over a nullable `c1` column of each
/// listed numeric type and checks the resulting expression's concrete type,
/// name and output field.
#[test]
fn test_var_pop_expr() -> Result<()> {
    let funcs = vec![AggregateFunction::VariancePop];
    let data_types = vec![
        DataType::UInt32,
        DataType::UInt64,
        DataType::Int32,
        DataType::Int64,
        DataType::Float32,
        DataType::Float64,
    ];
    for fun in funcs {
        for data_type in &data_types {
            let input_schema =
                Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
            let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(
                expressions::Column::new_with_schema("c1", &input_schema).unwrap(),
            )];
            let result_agg_phy_exprs = create_aggregate_expr(
                &fun,
                false,
                &input_phy_exprs[0..1],
                &input_schema,
                "c1",
            )?;
            // Assert unconditionally: a guard comparing `fun` against
            // `AggregateFunction::Variance` can never match in this test and
            // would silently skip every check below.
            // `VariancePop` is referenced through the `expressions` module so
            // the test does not depend on an extra `use` item.
            assert!(result_agg_phy_exprs
                .as_any()
                .is::<expressions::VariancePop>());
            assert_eq!("c1", result_agg_phy_exprs.name());
            assert_eq!(
                Field::new("c1", DataType::Float64, true),
                result_agg_phy_exprs.field().unwrap()
            );
        }
    }
    Ok(())
}

/// Builds a sample-standard-deviation aggregate over a nullable `c1` column of
/// each listed numeric type and checks the resulting expression's concrete
/// type, name and output field.
#[test]
fn test_stddev_expr() -> Result<()> {
    let funcs = vec![AggregateFunction::Stddev];
    let data_types = vec![
        DataType::UInt32,
        DataType::UInt64,
        DataType::Int32,
        DataType::Int64,
        DataType::Float32,
        DataType::Float64,
    ];
    for fun in funcs {
        for data_type in &data_types {
            let input_schema =
                Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
            let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(
                expressions::Column::new_with_schema("c1", &input_schema).unwrap(),
            )];
            let result_agg_phy_exprs = create_aggregate_expr(
                &fun,
                false,
                &input_phy_exprs[0..1],
                &input_schema,
                "c1",
            )?;
            // Assert unconditionally: a guard comparing `fun` against
            // `AggregateFunction::Variance` can never match in this test and
            // would silently skip every check below.
            assert!(result_agg_phy_exprs.as_any().is::<Stddev>());
            assert_eq!("c1", result_agg_phy_exprs.name());
            assert_eq!(
                Field::new("c1", DataType::Float64, true),
                result_agg_phy_exprs.field().unwrap()
            );
        }
    }
    Ok(())
}

/// Builds a population-standard-deviation aggregate over a nullable `c1`
/// column of each listed numeric type and checks the resulting expression's
/// concrete type, name and output field.
#[test]
fn test_stddev_pop_expr() -> Result<()> {
    let funcs = vec![AggregateFunction::StddevPop];
    let data_types = vec![
        DataType::UInt32,
        DataType::UInt64,
        DataType::Int32,
        DataType::Int64,
        DataType::Float32,
        DataType::Float64,
    ];
    for fun in funcs {
        for data_type in &data_types {
            let input_schema =
                Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
            let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(
                expressions::Column::new_with_schema("c1", &input_schema).unwrap(),
            )];
            let result_agg_phy_exprs = create_aggregate_expr(
                &fun,
                false,
                &input_phy_exprs[0..1],
                &input_schema,
                "c1",
            )?;
            // Assert unconditionally: a guard comparing `fun` against
            // `AggregateFunction::Variance` can never match in this test and
            // would silently skip every check below.
            // `StddevPop` is referenced through the `expressions` module so
            // the test does not depend on an extra `use` item.
            assert!(result_agg_phy_exprs
                .as_any()
                .is::<expressions::StddevPop>());
            assert_eq!("c1", result_agg_phy_exprs.name());
            assert_eq!(
                Field::new("c1", DataType::Float64, true),
                result_agg_phy_exprs.field().unwrap()
            );
        }
    }
    Ok(())
}

#[test]
fn test_min_max() -> Result<()> {
let observed = return_type(&AggregateFunction::Min, &[DataType::Utf8])?;
Expand Down Expand Up @@ -544,4 +763,56 @@ mod tests {
let observed = return_type(&AggregateFunction::Avg, &[DataType::Utf8]);
assert!(observed.is_err());
}

/// `return_type` for `Variance` must yield `Float64` for each of the numeric
/// input types exercised below.
#[test]
fn test_variance_return_type() -> Result<()> {
    let input_types = [
        DataType::Float32,
        DataType::Float64,
        DataType::Int32,
        DataType::UInt32,
        DataType::Int64,
    ];
    for input_type in input_types.iter() {
        let observed = return_type(&AggregateFunction::Variance, &[input_type.clone()])?;
        assert_eq!(DataType::Float64, observed);
    }

    Ok(())
}

/// Requesting the `Variance` return type for a `Utf8` input must be an error.
#[test]
fn test_variance_no_utf8() {
    assert!(return_type(&AggregateFunction::Variance, &[DataType::Utf8]).is_err());
}

/// `return_type` for `Stddev` must yield `Float64` for each of the numeric
/// input types exercised below.
#[test]
fn test_stddev_return_type() -> Result<()> {
    let input_types = [
        DataType::Float32,
        DataType::Float64,
        DataType::Int32,
        DataType::UInt32,
        DataType::Int64,
    ];
    for input_type in input_types.iter() {
        let observed = return_type(&AggregateFunction::Stddev, &[input_type.clone()])?;
        assert_eq!(DataType::Float64, observed);
    }

    Ok(())
}

/// Requesting the `Stddev` return type for a `Utf8` input must be an error.
#[test]
fn test_stddev_no_utf8() {
    assert!(return_type(&AggregateFunction::Stddev, &[DataType::Utf8]).is_err());
}
}
Loading

0 comments on commit 90de12a

Please sign in to comment.