diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs index 5e0f41bc81eb..378d2548effc 100644 --- a/datafusion-examples/examples/simple_udaf.rs +++ b/datafusion-examples/examples/simple_udaf.rs @@ -23,6 +23,7 @@ use datafusion::arrow::{ }; use datafusion::from_slice::FromSlice; +use datafusion::logical_expr::AggregateState; use datafusion::{error::Result, logical_plan::create_udaf, physical_plan::Accumulator}; use datafusion::{logical_expr::Volatility, prelude::*, scalar::ScalarValue}; use std::sync::Arc; @@ -107,10 +108,10 @@ impl Accumulator for GeometricMean { // This function serializes our state to `ScalarValue`, which DataFusion uses // to pass this state between execution stages. // Note that this can be arbitrary data. - fn state(&self) -> Result> { + fn state(&self) -> Result> { Ok(vec![ - ScalarValue::from(self.prod), - ScalarValue::from(self.n), + AggregateState::Scalar(ScalarValue::from(self.prod)), + AggregateState::Scalar(ScalarValue::from(self.n)), ]) } diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index e7873de4d03a..33d6af087d4c 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -45,4 +45,5 @@ object_store = { version = "0.3", optional = true } ordered-float = "3.0" parquet = { version = "19.0.0", features = ["arrow"], optional = true } pyo3 = { version = "0.16", optional = true } +serde_json = "1.0" sqlparser = "0.19" diff --git a/datafusion/core/src/physical_plan/aggregates/hash.rs b/datafusion/core/src/physical_plan/aggregates/hash.rs index c21109495e20..54806d37fcac 100644 --- a/datafusion/core/src/physical_plan/aggregates/hash.rs +++ b/datafusion/core/src/physical_plan/aggregates/hash.rs @@ -428,8 +428,10 @@ fn create_batch_from_map( AggregateMode::Partial => { let res = ScalarValue::iter_to_array( accumulators.group_states.iter().map(|group_state| { - let x = group_state.accumulator_set[x].state().unwrap(); - x[y].clone() + 
group_state.accumulator_set[x] + .state() + .and_then(|x| x[y].as_scalar().map(|v| v.clone())) + .expect("unexpected accumulator state in hash aggregate") }), )?; diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 02d4b3a4d142..eb0e07f84291 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -221,7 +221,7 @@ async fn csv_query_stddev_6() -> Result<()> { } #[tokio::test] -async fn csv_query_median_1() -> Result<()> { +async fn csv_query_approx_median_1() -> Result<()> { let ctx = SessionContext::new(); register_aggregate_csv(&ctx).await?; let sql = "SELECT approx_median(c2) FROM aggregate_test_100"; @@ -232,7 +232,7 @@ async fn csv_query_median_1() -> Result<()> { } #[tokio::test] -async fn csv_query_median_2() -> Result<()> { +async fn csv_query_approx_median_2() -> Result<()> { let ctx = SessionContext::new(); register_aggregate_csv(&ctx).await?; let sql = "SELECT approx_median(c6) FROM aggregate_test_100"; @@ -243,7 +243,7 @@ async fn csv_query_median_2() -> Result<()> { } #[tokio::test] -async fn csv_query_median_3() -> Result<()> { +async fn csv_query_approx_median_3() -> Result<()> { let ctx = SessionContext::new(); register_aggregate_csv(&ctx).await?; let sql = "SELECT approx_median(c12) FROM aggregate_test_100"; @@ -253,6 +253,189 @@ async fn csv_query_median_3() -> Result<()> { Ok(()) } +#[tokio::test] +async fn csv_query_median_1() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT median(c2) FROM aggregate_test_100"; + let actual = execute(&ctx, sql).await; + let expected = vec![vec!["3"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_median_2() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT median(c6) FROM aggregate_test_100"; + let actual = execute(&ctx, sql).await; + let expected = 
vec![vec!["1125553990140691277"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_median_3() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT median(c12) FROM aggregate_test_100"; + let actual = execute(&ctx, sql).await; + let expected = vec![vec!["0.5513900544385053"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn median_i8() -> Result<()> { + median_test( + "median", + DataType::Int8, + Arc::new(Int8Array::from(vec![i8::MIN, i8::MIN, 100, i8::MAX])), + "-14", + ) + .await +} + +#[tokio::test] +async fn median_i16() -> Result<()> { + median_test( + "median", + DataType::Int16, + Arc::new(Int16Array::from(vec![i16::MIN, i16::MIN, 100, i16::MAX])), + "-16334", + ) + .await +} + +#[tokio::test] +async fn median_i32() -> Result<()> { + median_test( + "median", + DataType::Int32, + Arc::new(Int32Array::from(vec![i32::MIN, i32::MIN, 100, i32::MAX])), + "-1073741774", + ) + .await +} + +#[tokio::test] +async fn median_i64() -> Result<()> { + median_test( + "median", + DataType::Int64, + Arc::new(Int64Array::from(vec![i64::MIN, i64::MIN, 100, i64::MAX])), + "-4611686018427388000", + ) + .await +} + +#[tokio::test] +async fn median_u8() -> Result<()> { + median_test( + "median", + DataType::UInt8, + Arc::new(UInt8Array::from(vec![u8::MIN, u8::MIN, 100, u8::MAX])), + "50", + ) + .await +} + +#[tokio::test] +async fn median_u16() -> Result<()> { + median_test( + "median", + DataType::UInt16, + Arc::new(UInt16Array::from(vec![u16::MIN, u16::MIN, 100, u16::MAX])), + "50", + ) + .await +} + +#[tokio::test] +async fn median_u32() -> Result<()> { + median_test( + "median", + DataType::UInt32, + Arc::new(UInt32Array::from(vec![u32::MIN, u32::MIN, 100, u32::MAX])), + "50", + ) + .await +} + +#[tokio::test] +async fn median_u64() -> Result<()> { + median_test( + "median", + DataType::UInt64, + Arc::new(UInt64Array::from(vec![u64::MIN, u64::MIN, 
100, u64::MAX])), + "50", + ) + .await +} + +#[tokio::test] +async fn median_f32() -> Result<()> { + median_test( + "median", + DataType::Float32, + Arc::new(Float32Array::from(vec![1.1, 4.4, 5.5, 3.3, 2.2])), + "3.3", + ) + .await +} + +#[tokio::test] +async fn median_f64() -> Result<()> { + median_test( + "median", + DataType::Float64, + Arc::new(Float64Array::from(vec![1.1, 4.4, 5.5, 3.3, 2.2])), + "3.3", + ) + .await +} + +#[tokio::test] +async fn median_f64_nan() -> Result<()> { + median_test( + "median", + DataType::Float64, + Arc::new(Float64Array::from(vec![1.1, f64::NAN, f64::NAN, f64::NAN])), + "NaN", // probably not the desired behavior? - see https://github.com/apache/arrow-datafusion/issues/3039 + ) + .await +} + +#[tokio::test] +async fn approx_median_f64_nan() -> Result<()> { + median_test( + "approx_median", + DataType::Float64, + Arc::new(Float64Array::from(vec![1.1, f64::NAN, f64::NAN, f64::NAN])), + "NaN", // probably not the desired behavior? - see https://github.com/apache/arrow-datafusion/issues/3039 + ) + .await +} + +async fn median_test( + func: &str, + data_type: DataType, + values: ArrayRef, + expected: &str, +) -> Result<()> { + let ctx = SessionContext::new(); + let schema = Arc::new(Schema::new(vec![Field::new("a", data_type, false)])); + let batch = RecordBatch::try_new(schema.clone(), vec![values])?; + let table = Arc::new(MemTable::try_new(schema, vec![vec![batch]])?); + ctx.register_table("t", table)?; + let sql = format!("SELECT {}(a) FROM t", func); + let actual = execute(&ctx, &sql).await; + let expected = vec![vec![expected.to_owned()]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + #[tokio::test] async fn csv_query_external_table_count() { let ctx = SessionContext::new(); diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 847f19cf8dd0..5481161d0c27 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -128,7 +128,11 @@ where 
            l.as_ref().parse::<f64>().unwrap(),
             r.as_str().parse::<f64>().unwrap(),
         );
-        assert!((l - r).abs() <= 2.0 * f64::EPSILON);
+        if l.is_nan() || r.is_nan() {
+            assert!(l.is_nan() && r.is_nan());
+        } else if (l - r).abs() > 2.0 * f64::EPSILON {
+            panic!("{} != {}", l, r)
+        }
     });
 }
 
diff --git a/datafusion/expr/src/accumulator.rs b/datafusion/expr/src/accumulator.rs
index d59764957ef5..6c146bc30693 100644
--- a/datafusion/expr/src/accumulator.rs
+++ b/datafusion/expr/src/accumulator.rs
@@ -18,7 +18,7 @@
 //! Accumulator module contains the trait definition for aggregation function's accumulators.
 
 use arrow::array::ArrayRef;
-use datafusion_common::{Result, ScalarValue};
+use datafusion_common::{DataFusionError, Result, ScalarValue};
 use std::fmt::Debug;
 
 /// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and
@@ -26,14 +26,14 @@ use std::fmt::Debug;
 ///
 /// An accumulator knows how to:
 /// * update its state from inputs via `update_batch`
-/// * convert its internal state to a vector of scalar values
+/// * convert its internal state to a vector of aggregate values
 /// * update its state from multiple accumulators' states via `merge_batch`
 /// * compute the final value from its internal state via `evaluate`
 pub trait Accumulator: Send + Sync + Debug {
     /// Returns the state of the accumulator at the end of the accumulation.
-    // in the case of an average on which we track `sum` and `n`, this function should return a vector
-    // of two values, sum and n.
-    fn state(&self) -> Result<Vec<ScalarValue>>;
+    /// in the case of an average on which we track `sum` and `n`, this function should return a vector
+    /// of two values, sum and n.
+    fn state(&self) -> Result<Vec<AggregateState>>;
 
     /// updates the accumulator's state from a vector of arrays.
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;
@@ -44,3 +44,38 @@ pub trait Accumulator: Send + Sync + Debug {
 
     /// returns its value based on its current state.
    fn evaluate(&self) -> Result<ScalarValue>;
 }
+
+/// Representation of internal accumulator state. Accumulators can potentially have a mix of
+/// scalar and array values. It may be desirable to add custom aggregator states here as well
+/// in the future (perhaps `Custom(Box<dyn Any>)`?).
+#[derive(Debug)]
+pub enum AggregateState {
+    /// Simple scalar value. Note that `ScalarValue::List` can be used to pass multiple
+    /// values around
+    Scalar(ScalarValue),
+    /// Arrays can be used instead of `ScalarValue::List` and could potentially have better
+    /// performance with large data sets, although this has not been verified. It also allows
+    /// for use of arrow kernels with less overhead.
+    Array(ArrayRef),
+}
+
+impl AggregateState {
+    /// Access the aggregate state as a scalar value. An error will occur if the
+    /// state is not a scalar value.
+    pub fn as_scalar(&self) -> Result<&ScalarValue> {
+        match &self {
+            Self::Scalar(v) => Ok(v),
+            _ => Err(DataFusionError::Internal(
+                "AggregateState is not a scalar aggregate".to_string(),
+            )),
+        }
+    }
+
+    /// Access the aggregate state as an array value.
+ pub fn to_array(&self) -> ArrayRef { + match &self { + Self::Scalar(v) => v.to_array(), + Self::Array(array) => array.clone(), + } + } +} diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 30bf0521d7f6..09d759e56466 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -62,6 +62,8 @@ pub enum AggregateFunction { Max, /// avg Avg, + /// median + Median, /// Approximate aggregate function ApproxDistinct, /// array_agg @@ -107,6 +109,7 @@ impl FromStr for AggregateFunction { "avg" => AggregateFunction::Avg, "mean" => AggregateFunction::Avg, "sum" => AggregateFunction::Sum, + "median" => AggregateFunction::Median, "approx_distinct" => AggregateFunction::ApproxDistinct, "array_agg" => AggregateFunction::ArrayAgg, "var" => AggregateFunction::Variance, @@ -175,7 +178,9 @@ pub fn return_type( AggregateFunction::ApproxPercentileContWithWeight => { Ok(coerced_data_types[0].clone()) } - AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()), + AggregateFunction::ApproxMedian | AggregateFunction::Median => { + Ok(coerced_data_types[0].clone()) + } AggregateFunction::Grouping => Ok(DataType::Int32), } } @@ -330,6 +335,7 @@ pub fn coerce_types( } Ok(input_types.to_vec()) } + AggregateFunction::Median => Ok(input_types.to_vec()), AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), } } @@ -358,6 +364,7 @@ pub fn signature(fun: &AggregateFunction) -> Signature { | AggregateFunction::VariancePop | AggregateFunction::Stddev | AggregateFunction::StddevPop + | AggregateFunction::Median | AggregateFunction::ApproxMedian => { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index f71243610ddf..90007a8bddfc 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -53,7 +53,7 @@ pub mod utils; pub mod window_frame; pub mod window_function; -pub use 
accumulator::Accumulator; +pub use accumulator::{Accumulator, AggregateState}; pub use aggregate_function::AggregateFunction; pub use built_in_function::BuiltinScalarFunction; pub use columnar_value::{ColumnarValue, NullColumnarValue}; diff --git a/datafusion/physical-expr/src/aggregate/approx_distinct.rs b/datafusion/physical-expr/src/aggregate/approx_distinct.rs index c67d1c9d35aa..5b391ed84c90 100644 --- a/datafusion/physical-expr/src/aggregate/approx_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/approx_distinct.rs @@ -30,7 +30,7 @@ use arrow::datatypes::{ }; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, AggregateState}; use std::any::type_name; use std::any::Any; use std::convert::TryFrom; @@ -232,8 +232,8 @@ macro_rules! default_accumulator_impl { Ok(()) } - fn state(&self) -> Result> { - let value = ScalarValue::from(&self.hll); + fn state(&self) -> Result> { + let value = AggregateState::Scalar(ScalarValue::from(&self.hll)); Ok(vec![value]) } diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs index 2315ad1d540f..41c6c72db18c 100644 --- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs +++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs @@ -29,7 +29,7 @@ use arrow::{ use datafusion_common::DataFusionError; use datafusion_common::Result; use datafusion_common::ScalarValue; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, AggregateState}; use ordered_float::OrderedFloat; use std::{any::Any, iter, sync::Arc}; @@ -287,8 +287,13 @@ impl ApproxPercentileAccumulator { } impl Accumulator for ApproxPercentileAccumulator { - fn state(&self) -> Result> { - Ok(self.digest.to_scalar_state()) + fn state(&self) -> Result> { + Ok(self + .digest + .to_scalar_state() + .into_iter() + 
.map(AggregateState::Scalar) + .collect()) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs index f9874b0a5a45..40a44c3a55e1 100644 --- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs +++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs @@ -26,7 +26,7 @@ use arrow::{ use datafusion_common::Result; use datafusion_common::ScalarValue; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, AggregateState}; use std::{any::Any, sync::Arc}; @@ -114,7 +114,7 @@ impl ApproxPercentileWithWeightAccumulator { } impl Accumulator for ApproxPercentileWithWeightAccumulator { - fn state(&self) -> Result> { + fn state(&self) -> Result> { self.approx_percentile_cont_accumulator.state() } diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index eaed89390cc5..e7fd0937cc87 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -23,7 +23,7 @@ use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, AggregateState}; use std::any::Any; use std::sync::Arc; @@ -143,8 +143,8 @@ impl Accumulator for ArrayAggAccumulator { }) } - fn state(&self) -> Result> { - Ok(vec![self.evaluate()?]) + fn state(&self) -> Result> { + Ok(vec![AggregateState::Scalar(self.evaluate()?)]) } fn evaluate(&self) -> Result { diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index 44e24e93c91c..f9899379d2c9 100644 --- 
a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -29,7 +29,7 @@ use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; use datafusion_common::Result; use datafusion_common::ScalarValue; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, AggregateState}; /// Expression for a ARRAY_AGG(DISTINCT) aggregation. #[derive(Debug)] @@ -119,11 +119,11 @@ impl DistinctArrayAggAccumulator { } impl Accumulator for DistinctArrayAggAccumulator { - fn state(&self) -> Result> { - Ok(vec![ScalarValue::List( + fn state(&self) -> Result> { + Ok(vec![AggregateState::Scalar(ScalarValue::List( Some(self.values.clone().into_iter().collect()), Box::new(Field::new("item", self.datatype.clone(), true)), - )]) + ))]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index 1b1d995257b2..a55e0e35278f 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -33,7 +33,7 @@ use arrow::{ }; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, AggregateState}; use datafusion_row::accessor::RowAccessor; /// AVG aggregate expression @@ -150,8 +150,11 @@ impl AvgAccumulator { } impl Accumulator for AvgAccumulator { - fn state(&self) -> Result> { - Ok(vec![ScalarValue::from(self.count), self.sum.clone()]) + fn state(&self) -> Result> { + Ok(vec![ + AggregateState::Scalar(ScalarValue::from(self.count)), + AggregateState::Scalar(self.sum.clone()), + ]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 23d2a84d132e..8d76e35e4945 100644 --- 
a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -251,6 +251,16 @@ pub fn create_aggregate_expr( )) } (AggregateFunction::ApproxMedian, true) => { + return Err(DataFusionError::NotImplemented( + "APPROX_MEDIAN(DISTINCT) aggregations are not available".to_string(), + )); + } + (AggregateFunction::Median, false) => Arc::new(expressions::Median::new( + coerced_phy_exprs[0].clone(), + name, + return_type, + )), + (AggregateFunction::Median, true) => { return Err(DataFusionError::NotImplemented( "MEDIAN(DISTINCT) aggregations are not available".to_string(), )); diff --git a/datafusion/physical-expr/src/aggregate/correlation.rs b/datafusion/physical-expr/src/aggregate/correlation.rs index 94a820849e51..3bbea5d9b3a3 100644 --- a/datafusion/physical-expr/src/aggregate/correlation.rs +++ b/datafusion/physical-expr/src/aggregate/correlation.rs @@ -25,7 +25,7 @@ use crate::{AggregateExpr, PhysicalExpr}; use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; use datafusion_common::Result; use datafusion_common::ScalarValue; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, AggregateState}; use std::any::Any; use std::sync::Arc; @@ -133,14 +133,14 @@ impl CorrelationAccumulator { } impl Accumulator for CorrelationAccumulator { - fn state(&self) -> Result> { + fn state(&self) -> Result> { Ok(vec![ - ScalarValue::from(self.covar.get_count()), - ScalarValue::from(self.covar.get_mean1()), - ScalarValue::from(self.stddev1.get_m2()), - ScalarValue::from(self.covar.get_mean2()), - ScalarValue::from(self.stddev2.get_m2()), - ScalarValue::from(self.covar.get_algo_const()), + AggregateState::Scalar(ScalarValue::from(self.covar.get_count())), + AggregateState::Scalar(ScalarValue::from(self.covar.get_mean1())), + AggregateState::Scalar(ScalarValue::from(self.stddev1.get_m2())), + AggregateState::Scalar(ScalarValue::from(self.covar.get_mean2())), + 
AggregateState::Scalar(ScalarValue::from(self.stddev2.get_m2())), + AggregateState::Scalar(ScalarValue::from(self.covar.get_algo_const())), ]) } @@ -191,6 +191,7 @@ impl Accumulator for CorrelationAccumulator { #[cfg(test)] mod tests { use super::*; + use crate::aggregate::utils::get_accum_scalar_values_as_arrays; use crate::expressions::col; use crate::expressions::tests::aggregate; use crate::generic_test_op2; @@ -469,12 +470,7 @@ mod tests { .collect::>>()?; accum1.update_batch(&values1)?; accum2.update_batch(&values2)?; - let state2 = accum2 - .state()? - .iter() - .map(|v| vec![v.clone()]) - .map(|x| ScalarValue::iter_to_array(x).unwrap()) - .collect::>(); + let state2 = get_accum_scalar_values_as_arrays(accum2.as_ref())?; accum1.merge_batch(&state2)?; accum1.evaluate() } diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs index 2b02d03b51f6..982c1dc09ed4 100644 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ b/datafusion/physical-expr/src/aggregate/count.rs @@ -29,7 +29,7 @@ use arrow::datatypes::DataType; use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::Result; use datafusion_common::ScalarValue; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, AggregateState}; use datafusion_row::accessor::RowAccessor; use crate::expressions::format_state_name; @@ -134,8 +134,10 @@ impl Accumulator for CountAccumulator { Ok(()) } - fn state(&self) -> Result> { - Ok(vec![ScalarValue::Int64(Some(self.count))]) + fn state(&self) -> Result> { + Ok(vec![AggregateState::Scalar(ScalarValue::Int64(Some( + self.count, + )))]) } fn evaluate(&self) -> Result { diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index 744d9b90d9bb..6060ddb4dc99 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -28,7 +28,7 @@ use 
crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, AggregateState}; #[derive(Debug, PartialEq, Eq, Hash, Clone)] struct DistinctScalarValues(Vec); @@ -177,7 +177,7 @@ impl Accumulator for DistinctCountAccumulator { self.merge(&v) }) } - fn state(&self) -> Result> { + fn state(&self) -> Result> { let mut cols_out = self .state_data_types .iter() @@ -206,7 +206,7 @@ impl Accumulator for DistinctCountAccumulator { ) }); - Ok(cols_out) + Ok(cols_out.into_iter().map(AggregateState::Scalar).collect()) } fn evaluate(&self) -> Result { @@ -223,6 +223,7 @@ impl Accumulator for DistinctCountAccumulator { #[cfg(test)] mod tests { use super::*; + use crate::aggregate::utils::get_accum_scalar_values; use arrow::array::{ ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, ListArray, UInt16Array, UInt32Array, UInt64Array, @@ -341,7 +342,7 @@ mod tests { let mut accum = agg.create_accumulator()?; accum.update_batch(arrays)?; - Ok((accum.state()?, accum.evaluate()?)) + Ok((get_accum_scalar_values(accum.as_ref())?, accum.evaluate()?)) } fn run_update( @@ -372,7 +373,7 @@ mod tests { accum.update_batch(&arrays)?; - Ok((accum.state()?, accum.evaluate()?)) + Ok((get_accum_scalar_values(accum.as_ref())?, accum.evaluate()?)) } fn run_merge_batch(arrays: &[ArrayRef]) -> Result<(Vec, ScalarValue)> { @@ -390,7 +391,7 @@ mod tests { let mut accum = agg.create_accumulator()?; accum.merge_batch(arrays)?; - Ok((accum.state()?, accum.evaluate()?)) + Ok((get_accum_scalar_values(accum.as_ref())?, accum.evaluate()?)) } // Used trait to create associated constant for f32 and f64 diff --git a/datafusion/physical-expr/src/aggregate/covariance.rs b/datafusion/physical-expr/src/aggregate/covariance.rs index 1df002b489e8..9cd3191277bc 100644 --- 
a/datafusion/physical-expr/src/aggregate/covariance.rs +++ b/datafusion/physical-expr/src/aggregate/covariance.rs @@ -30,7 +30,7 @@ use arrow::{ }; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, AggregateState}; use crate::aggregate::stats::StatsType; use crate::expressions::format_state_name; @@ -237,12 +237,12 @@ impl CovarianceAccumulator { } impl Accumulator for CovarianceAccumulator { - fn state(&self) -> Result> { + fn state(&self) -> Result> { Ok(vec![ - ScalarValue::from(self.count), - ScalarValue::from(self.mean1), - ScalarValue::from(self.mean2), - ScalarValue::from(self.algo_const), + AggregateState::Scalar(ScalarValue::from(self.count)), + AggregateState::Scalar(ScalarValue::from(self.mean1)), + AggregateState::Scalar(ScalarValue::from(self.mean2)), + AggregateState::Scalar(ScalarValue::from(self.algo_const)), ]) } @@ -352,6 +352,7 @@ impl Accumulator for CovarianceAccumulator { #[cfg(test)] mod tests { use super::*; + use crate::aggregate::utils::get_accum_scalar_values_as_arrays; use crate::expressions::col; use crate::expressions::tests::aggregate; use crate::generic_test_op2; @@ -644,12 +645,7 @@ mod tests { .collect::>>()?; accum1.update_batch(&values1)?; accum2.update_batch(&values2)?; - let state2 = accum2 - .state()? - .iter() - .map(|v| vec![v.clone()]) - .map(|x| ScalarValue::iter_to_array(x).unwrap()) - .collect::>(); + let state2 = get_accum_scalar_values_as_arrays(accum2.as_ref())?; accum1.merge_batch(&state2)?; accum1.evaluate() } diff --git a/datafusion/physical-expr/src/aggregate/median.rs b/datafusion/physical-expr/src/aggregate/median.rs new file mode 100644 index 000000000000..6b68f2ec33c2 --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/median.rs @@ -0,0 +1,244 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # Median + +use crate::expressions::format_state_name; +use crate::{AggregateExpr, PhysicalExpr}; +use arrow::array::{Array, ArrayRef, PrimitiveArray, PrimitiveBuilder}; +use arrow::compute::sort; +use arrow::datatypes::{ + ArrowPrimitiveType, DataType, Field, Float32Type, Float64Type, Int16Type, Int32Type, + Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; +use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_expr::{Accumulator, AggregateState}; +use std::any::Any; +use std::sync::Arc; + +/// MEDIAN aggregate expression. This uses a lot of memory because all values need to be +/// stored in memory before a result can be computed. If an approximation is sufficient +/// then APPROX_MEDIAN provides a much more efficient solution. 
+#[derive(Debug)]
+pub struct Median {
+    name: String,
+    expr: Arc<dyn PhysicalExpr>,
+    data_type: DataType,
+}
+
+impl Median {
+    /// Create a new MEDIAN aggregate function
+    pub fn new(
+        expr: Arc<dyn PhysicalExpr>,
+        name: impl Into<String>,
+        data_type: DataType,
+    ) -> Self {
+        Self {
+            name: name.into(),
+            expr,
+            data_type,
+        }
+    }
+}
+
+impl AggregateExpr for Median {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn field(&self) -> Result<Field> {
+        Ok(Field::new(&self.name, self.data_type.clone(), true))
+    }
+
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(MedianAccumulator {
+            data_type: self.data_type.clone(),
+            all_values: vec![],
+        }))
+    }
+
+    fn state_fields(&self) -> Result<Vec<Field>> {
+        Ok(vec![Field::new(
+            &format_state_name(&self.name, "median"),
+            self.data_type.clone(),
+            true,
+        )])
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr.clone()]
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+}
+
+#[derive(Debug)]
+struct MedianAccumulator {
+    data_type: DataType,
+    all_values: Vec<ArrayRef>,
+}
+
+macro_rules! median {
+    ($SELF:ident, $TY:ty, $SCALAR_TY:ident, $TWO:expr) => {{
+        let combined = combine_arrays::<$TY>($SELF.all_values.as_slice())?;
+        if combined.is_empty() {
+            return Ok(ScalarValue::Null);
+        }
+        let sorted = sort(&combined, None)?;
+        let array = sorted
+            .as_any()
+            .downcast_ref::<PrimitiveArray<$TY>>()
+            .ok_or(DataFusionError::Internal(
+                "median! macro failed to cast array to expected type".to_string(),
+            ))?;
+        let len = sorted.len();
+        let mid = len / 2;
+        if len % 2 == 0 {
+            Ok(ScalarValue::$SCALAR_TY(Some(
+                (array.value(mid - 1) + array.value(mid)) / $TWO,
+            )))
+        } else {
+            Ok(ScalarValue::$SCALAR_TY(Some(array.value(mid))))
+        }
+    }};
+}
+
+impl Accumulator for MedianAccumulator {
+    fn state(&self) -> Result<Vec<AggregateState>> {
+        let mut vec: Vec<AggregateState> = self
+            .all_values
+            .iter()
+            .map(|v| AggregateState::Array(v.clone()))
+            .collect();
+        if vec.is_empty() {
+            match self.data_type {
+                DataType::UInt8 => vec.push(empty_array::<UInt8Type>()),
+                DataType::UInt16 => vec.push(empty_array::<UInt16Type>()),
+                DataType::UInt32 => vec.push(empty_array::<UInt32Type>()),
+                DataType::UInt64 => vec.push(empty_array::<UInt64Type>()),
+                DataType::Int8 => vec.push(empty_array::<Int8Type>()),
+                DataType::Int16 => vec.push(empty_array::<Int16Type>()),
+                DataType::Int32 => vec.push(empty_array::<Int32Type>()),
+                DataType::Int64 => vec.push(empty_array::<Int64Type>()),
+                DataType::Float32 => vec.push(empty_array::<Float32Type>()),
+                DataType::Float64 => vec.push(empty_array::<Float64Type>()),
+                _ => {
+                    return Err(DataFusionError::Execution(
+                        "unsupported data type for median".to_string(),
+                    ))
+                }
+            }
+        }
+        Ok(vec)
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let x = values[0].clone();
+        self.all_values.extend_from_slice(&[x]);
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        for array in states {
+            self.all_values.extend_from_slice(&[array.clone()]);
+        }
+        Ok(())
+    }
+
+    fn evaluate(&self) -> Result<ScalarValue> {
+        match self.all_values[0].data_type() {
+            DataType::Int8 => median!(self, arrow::datatypes::Int8Type, Int8, 2),
+            DataType::Int16 => median!(self, arrow::datatypes::Int16Type, Int16, 2),
+            DataType::Int32 => median!(self, arrow::datatypes::Int32Type, Int32, 2),
+            DataType::Int64 => median!(self, arrow::datatypes::Int64Type, Int64, 2),
+            DataType::UInt8 => median!(self, arrow::datatypes::UInt8Type, UInt8, 2),
+            DataType::UInt16 => median!(self, arrow::datatypes::UInt16Type, UInt16, 2),
+            DataType::UInt32 => median!(self, arrow::datatypes::UInt32Type, UInt32, 2),
+            DataType::UInt64 => median!(self, arrow::datatypes::UInt64Type, UInt64, 2),
+            DataType::Float32 => {
+                median!(self, arrow::datatypes::Float32Type, Float32, 2_f32)
+            }
+            DataType::Float64 => {
+                median!(self, arrow::datatypes::Float64Type, Float64, 2_f64)
+            }
+            _ => Err(DataFusionError::Execution(
+                "unsupported data type for median".to_string(),
+            )),
+        }
+    }
+}
+
+/// Create an empty array
+fn empty_array<T: ArrowPrimitiveType>() -> AggregateState {
+    AggregateState::Array(Arc::new(PrimitiveBuilder::<T>::new(0).finish()))
+}
+
+/// Combine all non-null values from provided arrays into a single array
+fn combine_arrays<T: ArrowPrimitiveType>(arrays: &[ArrayRef]) -> Result<ArrayRef> {
+    let len = arrays.iter().map(|a| a.len() - a.null_count()).sum();
+    let mut builder: PrimitiveBuilder<T> = PrimitiveBuilder::new(len);
+    for array in arrays {
+        let array = array
+            .as_any()
+            .downcast_ref::<PrimitiveArray<T>>()
+            .ok_or_else(|| {
+                DataFusionError::Internal(
+                    "combine_arrays failed to cast array to expected type".to_string(),
+                )
+            })?;
+        for i in 0..array.len() {
+            if !array.is_null(i) {
+                builder.append_value(array.value(i));
+            }
+        }
+    }
+    Ok(Arc::new(builder.finish()))
+}
+
+#[cfg(test)]
+mod test {
+    use crate::aggregate::median::combine_arrays;
+    use arrow::array::{Int32Array, UInt32Array};
+    use arrow::datatypes::{Int32Type, UInt32Type};
+    use datafusion_common::Result;
+    use std::sync::Arc;
+
+    #[test]
+    fn combine_i32_array() -> Result<()> {
+        let a = Arc::new(Int32Array::from(vec![1, 2, 3]));
+        let b = combine_arrays::<Int32Type>(&[a.clone(), a])?;
+        assert_eq!(
+            "PrimitiveArray<Int32>\n[\n  1,\n  2,\n  3,\n  1,\n  2,\n  3,\n]",
+            format!("{:?}", b)
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn combine_u32_array() -> Result<()> {
+        let a = Arc::new(UInt32Array::from(vec![1, 2, 3]));
+        let b = combine_arrays::<UInt32Type>(&[a.clone(), a])?;
+        assert_eq!(
+            "PrimitiveArray<UInt32>\n[\n  1,\n  2,\n  3,\n  1,\n  2,\n  3,\n]",
+            format!("{:?}", b)
+        );
+        Ok(())
+    }
+}
diff --git 
a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index bd56973b167b..077f4d725de3 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -36,7 +36,7 @@ use arrow::{ }; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, AggregateState}; use crate::aggregate::row_accumulator::RowAccumulator; use crate::expressions::format_state_name; @@ -538,8 +538,8 @@ impl Accumulator for MaxAccumulator { self.update_batch(states) } - fn state(&self) -> Result> { - Ok(vec![self.max.clone()]) + fn state(&self) -> Result> { + Ok(vec![AggregateState::Scalar(self.max.clone())]) } fn evaluate(&self) -> Result { @@ -691,8 +691,8 @@ impl MinAccumulator { } impl Accumulator for MinAccumulator { - fn state(&self) -> Result> { - Ok(vec![self.min.clone()]) + fn state(&self) -> Result> { + Ok(vec![AggregateState::Scalar(self.min.clone())]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 1cbd4aeea008..a8d59d71490f 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -37,6 +37,7 @@ pub(crate) mod count; pub(crate) mod count_distinct; pub(crate) mod covariance; pub(crate) mod grouping; +pub(crate) mod median; #[macro_use] pub(crate) mod min_max; pub mod build_in; @@ -47,6 +48,7 @@ pub(crate) mod stddev; pub(crate) mod sum; pub(crate) mod sum_distinct; mod tdigest; +pub mod utils; pub(crate) mod variance; /// An aggregate expression that: diff --git a/datafusion/physical-expr/src/aggregate/stddev.rs b/datafusion/physical-expr/src/aggregate/stddev.rs index 13085fee2285..77f080293e27 100644 --- a/datafusion/physical-expr/src/aggregate/stddev.rs +++ 
b/datafusion/physical-expr/src/aggregate/stddev.rs @@ -27,7 +27,7 @@ use crate::{AggregateExpr, PhysicalExpr}; use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, AggregateState}; /// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression #[derive(Debug)] @@ -180,11 +180,11 @@ impl StddevAccumulator { } impl Accumulator for StddevAccumulator { - fn state(&self) -> Result> { + fn state(&self) -> Result> { Ok(vec![ - ScalarValue::from(self.variance.get_count()), - ScalarValue::from(self.variance.get_mean()), - ScalarValue::from(self.variance.get_m2()), + AggregateState::Scalar(ScalarValue::from(self.variance.get_count())), + AggregateState::Scalar(ScalarValue::from(self.variance.get_mean())), + AggregateState::Scalar(ScalarValue::from(self.variance.get_m2())), ]) } @@ -216,6 +216,7 @@ impl Accumulator for StddevAccumulator { #[cfg(test)] mod tests { use super::*; + use crate::aggregate::utils::get_accum_scalar_values_as_arrays; use crate::expressions::col; use crate::expressions::tests::aggregate; use crate::generic_test_op; @@ -441,12 +442,7 @@ mod tests { .collect::>>()?; accum1.update_batch(&values1)?; accum2.update_batch(&values2)?; - let state2 = accum2 - .state()? 
- .iter() - .map(|v| vec![v.clone()]) - .map(|x| ScalarValue::iter_to_array(x).unwrap()) - .collect::>(); + let state2 = get_accum_scalar_values_as_arrays(accum2.as_ref())?; accum1.merge_batch(&state2)?; accum1.evaluate() } diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index 866e90f1eacb..b0a7de6c633c 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -32,7 +32,7 @@ use arrow::{ datatypes::Field, }; use datafusion_common::{DataFusionError, Result, ScalarValue}; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, AggregateState}; use crate::aggregate::row_accumulator::RowAccumulator; use crate::expressions::format_state_name; @@ -435,8 +435,8 @@ pub(crate) fn add_to_row( } impl Accumulator for SumAccumulator { - fn state(&self) -> Result> { - Ok(vec![self.sum.clone()]) + fn state(&self) -> Result> { + Ok(vec![AggregateState::Scalar(self.sum.clone())]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { diff --git a/datafusion/physical-expr/src/aggregate/sum_distinct.rs b/datafusion/physical-expr/src/aggregate/sum_distinct.rs index a64b4b497c19..d939a033e368 100644 --- a/datafusion/physical-expr/src/aggregate/sum_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/sum_distinct.rs @@ -29,7 +29,7 @@ use std::collections::HashSet; use crate::{AggregateExpr, PhysicalExpr}; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, AggregateState}; /// Expression for a SUM(DISTINCT) aggregation. #[derive(Debug)] @@ -128,7 +128,7 @@ impl DistinctSumAccumulator { } impl Accumulator for DistinctSumAccumulator { - fn state(&self) -> Result> { + fn state(&self) -> Result> { // 1. Stores aggregate state in `ScalarValue::List` // 2. 
Constructs `ScalarValue::List` state from distinct numeric stored in hash set let state_out = { @@ -136,10 +136,10 @@ impl Accumulator for DistinctSumAccumulator { self.hash_values .iter() .for_each(|distinct_value| distinct_values.push(distinct_value.clone())); - vec![ScalarValue::List( + vec![AggregateState::Scalar(ScalarValue::List( Some(distinct_values), Box::new(Field::new("item", self.data_type.clone(), true)), - )] + ))] }; Ok(state_out) } @@ -181,6 +181,7 @@ impl Accumulator for DistinctSumAccumulator { #[cfg(test)] mod tests { use super::*; + use crate::aggregate::utils::get_accum_scalar_values; use crate::expressions::col; use crate::expressions::tests::aggregate; use arrow::record_batch::RecordBatch; @@ -196,7 +197,7 @@ mod tests { let mut accum = agg.create_accumulator()?; accum.update_batch(arrays)?; - Ok((accum.state()?, accum.evaluate()?)) + Ok((get_accum_scalar_values(accum.as_ref())?, accum.evaluate()?)) } macro_rules! generic_test_sum_distinct { diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs new file mode 100644 index 000000000000..1cac5b98a21b --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/utils.rs @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utilities used in aggregates + +use arrow::array::ArrayRef; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::Accumulator; + +/// Extract scalar values from an accumulator. This can return an error if the accumulator +/// has any non-scalar values. +pub fn get_accum_scalar_values(accum: &dyn Accumulator) -> Result> { + accum + .state()? + .iter() + .map(|agg| agg.as_scalar().map(|v| v.clone())) + .collect::>>() +} + +/// Convert scalar values from an accumulator into arrays. This can return an error if the +/// accumulator has any non-scalar values. +pub fn get_accum_scalar_values_as_arrays( + accum: &dyn Accumulator, +) -> Result> { + accum + .state()? + .iter() + .map(|v| { + v.as_scalar() + .map(|s| vec![s.clone()]) + .and_then(ScalarValue::iter_to_array) + }) + .collect::>>() +} diff --git a/datafusion/physical-expr/src/aggregate/variance.rs b/datafusion/physical-expr/src/aggregate/variance.rs index 364936213fca..4ff4318e359c 100644 --- a/datafusion/physical-expr/src/aggregate/variance.rs +++ b/datafusion/physical-expr/src/aggregate/variance.rs @@ -32,7 +32,7 @@ use arrow::{ }; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, AggregateState}; /// VAR and VAR_SAMP aggregate expression #[derive(Debug)] @@ -210,11 +210,11 @@ impl VarianceAccumulator { } impl Accumulator for VarianceAccumulator { - fn state(&self) -> Result> { + fn state(&self) -> Result> { Ok(vec![ - ScalarValue::from(self.count), - ScalarValue::from(self.mean), - ScalarValue::from(self.m2), + AggregateState::Scalar(ScalarValue::from(self.count)), + AggregateState::Scalar(ScalarValue::from(self.mean)), + AggregateState::Scalar(ScalarValue::from(self.m2)), ]) } @@ -296,6 +296,7 @@ impl Accumulator for VarianceAccumulator { #[cfg(test)] mod tests { use super::*; 
+ use crate::aggregate::utils::get_accum_scalar_values_as_arrays; use crate::expressions::col; use crate::expressions::tests::aggregate; use crate::generic_test_op; @@ -522,12 +523,7 @@ mod tests { .collect::>>()?; accum1.update_batch(&values1)?; accum2.update_batch(&values2)?; - let state2 = accum2 - .state()? - .iter() - .map(|v| vec![v.clone()]) - .map(|x| ScalarValue::iter_to_array(x).unwrap()) - .collect::>(); + let state2 = get_accum_scalar_values_as_arrays(accum2.as_ref())?; accum1.merge_batch(&state2)?; accum1.evaluate() } diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 7a78f4603e87..6d8852e77f75 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -52,6 +52,7 @@ pub use crate::aggregate::count::Count; pub use crate::aggregate::count_distinct::DistinctCount; pub use crate::aggregate::covariance::{Covariance, CovariancePop}; pub use crate::aggregate::grouping::Grouping; +pub use crate::aggregate::median::Median; pub use crate::aggregate::min_max::{Max, Min}; pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator}; pub use crate::aggregate::stats::StatsType; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index ec816a419432..c9c1237a7f29 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -466,6 +466,7 @@ enum AggregateFunction { APPROX_MEDIAN=15; APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; GROUPING = 17; + MEDIAN=18; } message AggregateExprNode { diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index 40ea1bd02500..1f3c3955a0f4 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -504,6 +504,7 @@ impl From for AggregateFunction { } protobuf::AggregateFunction::ApproxMedian => Self::ApproxMedian, protobuf::AggregateFunction::Grouping => 
Self::Grouping, + protobuf::AggregateFunction::Median => Self::Median, } } } diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index 1d7847df2c29..88230766d907 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -68,8 +68,8 @@ mod roundtrip_tests { use datafusion_expr::expr::GroupingSet; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNode}; use datafusion_expr::{ - col, lit, Accumulator, AggregateFunction, BuiltinScalarFunction::Sqrt, Expr, - LogicalPlan, Volatility, + col, lit, Accumulator, AggregateFunction, AggregateState, + BuiltinScalarFunction::Sqrt, Expr, LogicalPlan, Volatility, }; use prost::Message; use std::any::Any; @@ -986,7 +986,7 @@ mod roundtrip_tests { struct Dummy {} impl Accumulator for Dummy { - fn state(&self) -> datafusion::error::Result> { + fn state(&self) -> datafusion::error::Result> { Ok(vec![]) } diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index 323e2186d4f6..b8ca81008453 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -354,6 +354,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { } AggregateFunction::ApproxMedian => Self::ApproxMedian, AggregateFunction::Grouping => Self::Grouping, + AggregateFunction::Median => Self::Median, } } } @@ -540,6 +541,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { protobuf::AggregateFunction::ApproxMedian } AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, + AggregateFunction::Median => protobuf::AggregateFunction::Median, }; let aggregate_expr = protobuf::AggregateExprNode {