Skip to content

Commit

Permalink
Implement exact median, add AggregateState (#3009)
Browse files Browse the repository at this point in the history
* Implement exact median

* revert some changes

* toml format

* add median to protobuf

* remove some unwraps

* remove some unwraps

* remove some unwraps

* fix

* clippy

* reduce code duplication

* reduce code duplication

* more tests

* move tests to simplify github diff

* Update datafusion/expr/src/accumulator.rs

Co-authored-by: Andrew Lamb <[email protected]>

* refactor to make it more obvious that empty arrays are being created

* partially address feedback

* Update datafusion/physical-expr/src/aggregate/count_distinct.rs

Co-authored-by: Andrew Lamb <[email protected]>

* add more tests

* more docs

* clippy

* avoid a clone

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
andygrove and alamb authored Aug 5, 2022
1 parent 581934d commit 245def0
Show file tree
Hide file tree
Showing 32 changed files with 645 additions and 107 deletions.
7 changes: 4 additions & 3 deletions datafusion-examples/examples/simple_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use datafusion::arrow::{
};

use datafusion::from_slice::FromSlice;
use datafusion::logical_expr::AggregateState;
use datafusion::{error::Result, logical_plan::create_udaf, physical_plan::Accumulator};
use datafusion::{logical_expr::Volatility, prelude::*, scalar::ScalarValue};
use std::sync::Arc;
Expand Down Expand Up @@ -107,10 +108,10 @@ impl Accumulator for GeometricMean {
// This function serializes our state to `ScalarValue`, which DataFusion uses
// to pass this state between execution stages.
// Note that this can be arbitrary data.
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&self) -> Result<Vec<AggregateState>> {
Ok(vec![
ScalarValue::from(self.prod),
ScalarValue::from(self.n),
AggregateState::Scalar(ScalarValue::from(self.prod)),
AggregateState::Scalar(ScalarValue::from(self.n)),
])
}

Expand Down
1 change: 1 addition & 0 deletions datafusion/common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,5 @@ object_store = { version = "0.3", optional = true }
ordered-float = "3.0"
parquet = { version = "19.0.0", features = ["arrow"], optional = true }
pyo3 = { version = "0.16", optional = true }
serde_json = "1.0"
sqlparser = "0.19"
6 changes: 4 additions & 2 deletions datafusion/core/src/physical_plan/aggregates/hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -428,8 +428,10 @@ fn create_batch_from_map(
AggregateMode::Partial => {
let res = ScalarValue::iter_to_array(
accumulators.group_states.iter().map(|group_state| {
let x = group_state.accumulator_set[x].state().unwrap();
x[y].clone()
group_state.accumulator_set[x]
.state()
.and_then(|x| x[y].as_scalar().map(|v| v.clone()))
.expect("unexpected accumulator state in hash aggregate")
}),
)?;

Expand Down
189 changes: 186 additions & 3 deletions datafusion/core/tests/sql/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ async fn csv_query_stddev_6() -> Result<()> {
}

#[tokio::test]
async fn csv_query_median_1() -> Result<()> {
async fn csv_query_approx_median_1() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT approx_median(c2) FROM aggregate_test_100";
Expand All @@ -232,7 +232,7 @@ async fn csv_query_median_1() -> Result<()> {
}

#[tokio::test]
async fn csv_query_median_2() -> Result<()> {
async fn csv_query_approx_median_2() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT approx_median(c6) FROM aggregate_test_100";
Expand All @@ -243,7 +243,7 @@ async fn csv_query_median_2() -> Result<()> {
}

#[tokio::test]
async fn csv_query_median_3() -> Result<()> {
async fn csv_query_approx_median_3() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT approx_median(c12) FROM aggregate_test_100";
Expand All @@ -253,6 +253,189 @@ async fn csv_query_median_3() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn csv_query_median_1() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT median(c2) FROM aggregate_test_100";
let actual = execute(&ctx, sql).await;
let expected = vec![vec!["3"]];
assert_float_eq(&expected, &actual);
Ok(())
}

#[tokio::test]
async fn csv_query_median_2() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT median(c6) FROM aggregate_test_100";
let actual = execute(&ctx, sql).await;
let expected = vec![vec!["1125553990140691277"]];
assert_float_eq(&expected, &actual);
Ok(())
}

#[tokio::test]
async fn csv_query_median_3() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT median(c12) FROM aggregate_test_100";
let actual = execute(&ctx, sql).await;
let expected = vec![vec!["0.5513900544385053"]];
assert_float_eq(&expected, &actual);
Ok(())
}

#[tokio::test]
async fn median_i8() -> Result<()> {
median_test(
"median",
DataType::Int8,
Arc::new(Int8Array::from(vec![i8::MIN, i8::MIN, 100, i8::MAX])),
"-14",
)
.await
}

#[tokio::test]
async fn median_i16() -> Result<()> {
median_test(
"median",
DataType::Int16,
Arc::new(Int16Array::from(vec![i16::MIN, i16::MIN, 100, i16::MAX])),
"-16334",
)
.await
}

#[tokio::test]
async fn median_i32() -> Result<()> {
median_test(
"median",
DataType::Int32,
Arc::new(Int32Array::from(vec![i32::MIN, i32::MIN, 100, i32::MAX])),
"-1073741774",
)
.await
}

#[tokio::test]
async fn median_i64() -> Result<()> {
median_test(
"median",
DataType::Int64,
Arc::new(Int64Array::from(vec![i64::MIN, i64::MIN, 100, i64::MAX])),
"-4611686018427388000",
)
.await
}

#[tokio::test]
async fn median_u8() -> Result<()> {
median_test(
"median",
DataType::UInt8,
Arc::new(UInt8Array::from(vec![u8::MIN, u8::MIN, 100, u8::MAX])),
"50",
)
.await
}

#[tokio::test]
async fn median_u16() -> Result<()> {
median_test(
"median",
DataType::UInt16,
Arc::new(UInt16Array::from(vec![u16::MIN, u16::MIN, 100, u16::MAX])),
"50",
)
.await
}

#[tokio::test]
async fn median_u32() -> Result<()> {
median_test(
"median",
DataType::UInt32,
Arc::new(UInt32Array::from(vec![u32::MIN, u32::MIN, 100, u32::MAX])),
"50",
)
.await
}

#[tokio::test]
async fn median_u64() -> Result<()> {
median_test(
"median",
DataType::UInt64,
Arc::new(UInt64Array::from(vec![u64::MIN, u64::MIN, 100, u64::MAX])),
"50",
)
.await
}

#[tokio::test]
async fn median_f32() -> Result<()> {
median_test(
"median",
DataType::Float32,
Arc::new(Float32Array::from(vec![1.1, 4.4, 5.5, 3.3, 2.2])),
"3.3",
)
.await
}

#[tokio::test]
async fn median_f64() -> Result<()> {
median_test(
"median",
DataType::Float64,
Arc::new(Float64Array::from(vec![1.1, 4.4, 5.5, 3.3, 2.2])),
"3.3",
)
.await
}

#[tokio::test]
async fn median_f64_nan() -> Result<()> {
median_test(
"median",
DataType::Float64,
Arc::new(Float64Array::from(vec![1.1, f64::NAN, f64::NAN, f64::NAN])),
"NaN", // probably not the desired behavior? - see https://github.com/apache/arrow-datafusion/issues/3039
)
.await
}

#[tokio::test]
async fn approx_median_f64_nan() -> Result<()> {
median_test(
"approx_median",
DataType::Float64,
Arc::new(Float64Array::from(vec![1.1, f64::NAN, f64::NAN, f64::NAN])),
"NaN", // probably not the desired behavior? - see https://github.com/apache/arrow-datafusion/issues/3039
)
.await
}

async fn median_test(
func: &str,
data_type: DataType,
values: ArrayRef,
expected: &str,
) -> Result<()> {
let ctx = SessionContext::new();
let schema = Arc::new(Schema::new(vec![Field::new("a", data_type, false)]));
let batch = RecordBatch::try_new(schema.clone(), vec![values])?;
let table = Arc::new(MemTable::try_new(schema, vec![vec![batch]])?);
ctx.register_table("t", table)?;
let sql = format!("SELECT {}(a) FROM t", func);
let actual = execute(&ctx, &sql).await;
let expected = vec![vec![expected.to_owned()]];
assert_float_eq(&expected, &actual);
Ok(())
}

#[tokio::test]
async fn csv_query_external_table_count() {
let ctx = SessionContext::new();
Expand Down
6 changes: 5 additions & 1 deletion datafusion/core/tests/sql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,11 @@ where
l.as_ref().parse::<f64>().unwrap(),
r.as_str().parse::<f64>().unwrap(),
);
assert!((l - r).abs() <= 2.0 * f64::EPSILON);
if l.is_nan() || r.is_nan() {
assert!(l.is_nan() && r.is_nan());
} else if (l - r).abs() > 2.0 * f64::EPSILON {
panic!("{} != {}", l, r)
}
});
}

Expand Down
45 changes: 40 additions & 5 deletions datafusion/expr/src/accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,22 @@
//! Accumulator module contains the trait definition for aggregation function's accumulators.
use arrow::array::ArrayRef;
use datafusion_common::{Result, ScalarValue};
use datafusion_common::{DataFusionError, Result, ScalarValue};
use std::fmt::Debug;

/// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and
/// generically accumulates values.
///
/// An accumulator knows how to:
/// * update its state from inputs via `update_batch`
/// * convert its internal state to a vector of scalar values
/// * convert its internal state to a vector of aggregate values
/// * update its state from multiple accumulators' states via `merge_batch`
/// * compute the final value from its internal state via `evaluate`
pub trait Accumulator: Send + Sync + Debug {
/// Returns the state of the accumulator at the end of the accumulation.
// in the case of an average on which we track `sum` and `n`, this function should return a vector
// of two values, sum and n.
fn state(&self) -> Result<Vec<ScalarValue>>;
/// in the case of an average on which we track `sum` and `n`, this function should return a vector
/// of two values, sum and n.
fn state(&self) -> Result<Vec<AggregateState>>;

/// updates the accumulator's state from a vector of arrays.
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;
Expand All @@ -44,3 +44,38 @@ pub trait Accumulator: Send + Sync + Debug {
/// returns its value based on its current state.
fn evaluate(&self) -> Result<ScalarValue>;
}

/// Representation of internal accumulator state. Accumulators can potentially have a mix of
/// scalar and array values. It may be desirable to add custom aggregator states here as well
/// in the future (perhaps `Custom(Box<dyn Any>)`?).
#[derive(Debug)]
pub enum AggregateState {
/// Simple scalar value. Note that `ScalarValue::List` can be used to pass multiple
/// values around
Scalar(ScalarValue),
/// Arrays can be used instead of `ScalarValue::List` and could potentially have better
/// performance with large data sets, although this has not been verified. It also allows
/// for use of arrow kernels with less overhead.
Array(ArrayRef),
}

impl AggregateState {
/// Access the aggregate state as a scalar value. An error will occur if the
/// state is not a scalar value.
pub fn as_scalar(&self) -> Result<&ScalarValue> {
match &self {
Self::Scalar(v) => Ok(v),
_ => Err(DataFusionError::Internal(
"AggregateState is not a scalar aggregate".to_string(),
)),
}
}

/// Access the aggregate state as an array value.
pub fn to_array(&self) -> ArrayRef {
match &self {
Self::Scalar(v) => v.to_array(),
Self::Array(array) => array.clone(),
}
}
}
9 changes: 8 additions & 1 deletion datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ pub enum AggregateFunction {
Max,
/// avg
Avg,
/// median
Median,
/// Approximate aggregate function
ApproxDistinct,
/// array_agg
Expand Down Expand Up @@ -107,6 +109,7 @@ impl FromStr for AggregateFunction {
"avg" => AggregateFunction::Avg,
"mean" => AggregateFunction::Avg,
"sum" => AggregateFunction::Sum,
"median" => AggregateFunction::Median,
"approx_distinct" => AggregateFunction::ApproxDistinct,
"array_agg" => AggregateFunction::ArrayAgg,
"var" => AggregateFunction::Variance,
Expand Down Expand Up @@ -175,7 +178,9 @@ pub fn return_type(
AggregateFunction::ApproxPercentileContWithWeight => {
Ok(coerced_data_types[0].clone())
}
AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()),
AggregateFunction::ApproxMedian | AggregateFunction::Median => {
Ok(coerced_data_types[0].clone())
}
AggregateFunction::Grouping => Ok(DataType::Int32),
}
}
Expand Down Expand Up @@ -330,6 +335,7 @@ pub fn coerce_types(
}
Ok(input_types.to_vec())
}
AggregateFunction::Median => Ok(input_types.to_vec()),
AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
}
}
Expand Down Expand Up @@ -358,6 +364,7 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
| AggregateFunction::VariancePop
| AggregateFunction::Stddev
| AggregateFunction::StddevPop
| AggregateFunction::Median
| AggregateFunction::ApproxMedian => {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ pub mod utils;
pub mod window_frame;
pub mod window_function;

pub use accumulator::Accumulator;
pub use accumulator::{Accumulator, AggregateState};
pub use aggregate_function::AggregateFunction;
pub use built_in_function::BuiltinScalarFunction;
pub use columnar_value::{ColumnarValue, NullColumnarValue};
Expand Down
Loading

0 comments on commit 245def0

Please sign in to comment.