From d72ca10cbd6658057622066ae640be3626990fc7 Mon Sep 17 00:00:00 2001 From: Dom Date: Tue, 11 Jan 2022 19:49:15 +0000 Subject: [PATCH] refactor: use input type as return type Casts the calculated quantile value to the same type as the input data. --- datafusion/src/physical_plan/aggregates.rs | 4 +-- .../expressions/approx_quantile.rs | 34 ++++++++++++++----- datafusion/tests/sql/aggregates.rs | 2 +- 3 files changed, 29 insertions(+), 11 deletions(-) diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 70362728fee21..69030310804f7 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -145,7 +145,7 @@ pub fn return_type( coerced_data_types[0].clone(), true, )))), - AggregateFunction::ApproxQuantile => Ok(DataType::Float64), + AggregateFunction::ApproxQuantile => Ok(coerced_data_types[0].clone()), } } @@ -509,7 +509,7 @@ mod tests { assert!(result_agg_phy_exprs.as_any().is::()); assert_eq!("c1", result_agg_phy_exprs.name()); assert_eq!( - Field::new("c1", DataType::Float64, false), + Field::new("c1", data_type.clone(), false), result_agg_phy_exprs.field().unwrap() ); } diff --git a/datafusion/src/physical_plan/expressions/approx_quantile.rs b/datafusion/src/physical_plan/expressions/approx_quantile.rs index 50e6adf55af40..f13ed903cb486 100644 --- a/datafusion/src/physical_plan/expressions/approx_quantile.rs +++ b/datafusion/src/physical_plan/expressions/approx_quantile.rs @@ -105,7 +105,7 @@ impl AggregateExpr for ApproxQuantile { } fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, false)) + Ok(Field::new(&self.name, self.input_data_type.clone(), false)) } /// See [`TDigest::to_scalar_state()`] for a description of the serialised @@ -151,7 +151,9 @@ impl AggregateExpr for ApproxQuantile { fn create_accumulator(&self) -> Result> { let accumulator: Box = match &self.input_data_type { - DataType::UInt8 + t + @ + (DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 @@ -160,8 +162,8 @@ impl AggregateExpr for ApproxQuantile { | DataType::Int32 | DataType::Int64 | DataType::Float32 - | DataType::Float64 => { - Box::new(ApproxQuantileAccumulator::new(self.quantile)) + | DataType::Float64) => { + Box::new(ApproxQuantileAccumulator::new(self.quantile, t.clone())) } other => { return Err(DataFusionError::NotImplemented(format!( @@ -182,13 +184,15 @@ impl AggregateExpr for ApproxQuantile { pub struct ApproxQuantileAccumulator { digest: TDigest, quantile: f64, + return_type: DataType, } impl ApproxQuantileAccumulator { - pub fn new(quantile: f64) -> Self { + pub fn new(quantile: f64, return_type: DataType) -> Self { Self { digest: TDigest::new(100), quantile, + return_type, } } } @@ -283,8 +287,22 @@ impl Accumulator for ApproxQuantileAccumulator { } fn evaluate(&self) -> Result { - Ok(ScalarValue::Float64(Some( - self.digest.estimate_quantile(self.quantile), - ))) + let q = self.digest.estimate_quantile(self.quantile); + + // These acceptable return types MUST match the validation in + // ApproxQuantile::create_accumulator. + Ok(match &self.return_type { + DataType::Int8 => ScalarValue::Int8(Some(q as i8)), + DataType::Int16 => ScalarValue::Int16(Some(q as i16)), + DataType::Int32 => ScalarValue::Int32(Some(q as i32)), + DataType::Int64 => ScalarValue::Int64(Some(q as i64)), + DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)), + DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)), + DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)), + DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)), + DataType::Float32 => ScalarValue::Float32(Some(q as f32)), + DataType::Float64 => ScalarValue::Float64(Some(q as f64)), + v => unreachable!("unexpected return type {:?}", v), + }) } } diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs index 7ccbb57b03920..9510dd68b8b44 100644 --- a/datafusion/tests/sql/aggregates.rs +++ b/datafusion/tests/sql/aggregates.rs @@ -348,7 +348,7 @@ async fn csv_query_approx_quantile() -> Result<()> { // within 5% of the $actual quantile value. macro_rules! quantile_test { ($ctx:ident, column=$column:literal, quantile=$quantile:literal, actual=$actual:literal) => { - let sql = format!("SELECT (ABS(1 - approx_quantile({}, {}) / {}) < 0.05) AS q FROM aggregate_test_100", $column, $quantile, $actual); + let sql = format!("SELECT (ABS(1 - CAST(approx_quantile({}, {}) AS DOUBLE) / {}) < 0.05) AS q FROM aggregate_test_100", $column, $quantile, $actual); let actual = execute_to_batches(&mut ctx, &sql).await; // // "+------+",