diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index c97f770ab301..586fff30ac43 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -2452,12 +2452,7 @@ mod tests { vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_| { - Ok(Box::new(AvgAccumulator::try_new( - &DataType::Float64, - &DataType::Float64, - )?)) - }), + Arc::new(|_| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); diff --git a/datafusion/core/tests/sql/udf.rs b/datafusion/core/tests/sql/udf.rs index 86ff6ebac228..5aa3ab3bc0a8 100644 --- a/datafusion/core/tests/sql/udf.rs +++ b/datafusion/core/tests/sql/udf.rs @@ -237,12 +237,7 @@ async fn simple_udaf() -> Result<()> { vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_| { - Ok(Box::new(AvgAccumulator::try_new( - &DataType::Float64, - &DataType::Float64, - )?)) - }), + Arc::new(|_| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 2a6e08cd52c4..1ebe234840f4 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -906,12 +906,7 @@ mod test { vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_| { - Ok(Box::new(AvgAccumulator::try_new( - &DataType::Float64, - &DataType::Float64, - )?)) - }), + Arc::new(|_| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); let udaf = Expr::AggregateUDF(expr::AggregateUDF::new( @@ -932,12 +927,8 @@ mod test { Arc::new(move |_| Ok(Arc::new(DataType::Float64))); let state_type: StateTypeFunction = Arc::new(move |_| Ok(Arc::new(vec![DataType::UInt64, DataType::Float64]))); - let accumulator: AccumulatorFactoryFunction = Arc::new(|_| { - Ok(Box::new(AvgAccumulator::try_new( - &DataType::Float64, - &DataType::Float64, - )?)) - }); + let accumulator: AccumulatorFactoryFunction = + Arc::new(|_| Ok(Box::::default())); let my_avg = AggregateUDF::new( "MY_AVG", &Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable), diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index 59f72b6e3e25..9f98c6fce57f 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -21,17 +21,13 @@ use arrow::array::{AsArray, PrimitiveBuilder}; use log::debug; use std::any::Any; -use std::convert::TryFrom; use std::sync::Arc; use crate::aggregate::groups_accumulator::accumulate::NullState; -use crate::aggregate::sum; -use crate::aggregate::sum::sum_batch; -use crate::aggregate::utils::calculate_result_decimal_for_avg; use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr}; -use arrow::compute; +use arrow::compute::sum; use arrow::datatypes::{DataType, Decimal128Type, Float64Type, UInt64Type}; use arrow::{ array::{ArrayRef, UInt64Array}, @@ -40,9 +36,7 @@ use arrow::{ use arrow_array::{ Array, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, PrimitiveArray, }; -use datafusion_common::{ - downcast_value, internal_err, not_impl_err, DataFusionError, Result, ScalarValue, -}; +use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::type_coercion::aggregates::avg_return_type; use datafusion_expr::Accumulator; @@ -93,11 +87,27 @@ impl AggregateExpr for Avg { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(AvgAccumulator::try_new( - // avg is f64 or decimal - &self.input_data_type, - &self.result_data_type, - )?)) + use DataType::*; + // instantiate specialized accumulator based for the type + match (&self.input_data_type, &self.result_data_type) { + (Float64, Float64) => Ok(Box::::default()), + ( + Decimal128(sum_precision, sum_scale), + Decimal128(target_precision, target_scale), + ) => Ok(Box::new(DecimalAvgAccumulator { + sum: None, + count: 0, + sum_scale: *sum_scale, + sum_precision: *sum_precision, + target_precision: *target_precision, + target_scale: *target_scale, + })), + _ => not_impl_err!( + "AvgAccumulator for ({} --> {})", + self.input_data_type, + self.result_data_type + ), + } } fn state_fields(&self) -> Result> { @@ -128,10 +138,7 @@ impl AggregateExpr for Avg { } fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(AvgAccumulator::try_new( - &self.input_data_type, - &self.result_data_type, - )?)) + self.create_accumulator() } fn groups_accumulator_supported(&self) -> bool { @@ -195,91 +202,141 @@ impl PartialEq for Avg { } /// An accumulator to compute the average -#[derive(Debug)] +#[derive(Debug, Default)] pub struct AvgAccumulator { - // sum is used for null - sum: ScalarValue, - return_data_type: DataType, + sum: Option, count: u64, } -impl AvgAccumulator { - /// Creates a new `AvgAccumulator` - pub fn try_new(datatype: &DataType, return_data_type: &DataType) -> Result { - Ok(Self { - sum: ScalarValue::try_from(datatype)?, - return_data_type: return_data_type.clone(), - count: 0, - }) +impl Accumulator for AvgAccumulator { + fn state(&self) -> Result> { + Ok(vec![ + ScalarValue::from(self.count), + ScalarValue::Float64(self.sum), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = values[0].as_primitive::(); + self.count += (values.len() - values.null_count()) as u64; + if let Some(x) = sum(values) { + let v = self.sum.get_or_insert(0.); + *v += x; + } + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = values[0].as_primitive::(); + self.count -= (values.len() - values.null_count()) as u64; + if let Some(x) = sum(values) { + self.sum = Some(self.sum.unwrap() - x); + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // counts are summed + self.count += sum(states[0].as_primitive::()).unwrap_or_default(); + + // sums are summed + if let Some(x) = sum(states[1].as_primitive::()) { + let v = self.sum.get_or_insert(0.); + *v += x; + } + Ok(()) + } + + fn evaluate(&self) -> Result { + Ok(ScalarValue::Float64( + self.sum.map(|f| f / self.count as f64), + )) + } + fn supports_retract_batch(&self) -> bool { + true + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) } } -impl Accumulator for AvgAccumulator { +/// An accumulator to compute the average for decimals +#[derive(Debug)] +struct DecimalAvgAccumulator { + sum: Option, + count: u64, + sum_scale: i8, + sum_precision: u8, + target_precision: u8, + target_scale: i8, +} + +impl Accumulator for DecimalAvgAccumulator { fn state(&self) -> Result> { - Ok(vec![ScalarValue::from(self.count), self.sum.clone()]) + Ok(vec![ + ScalarValue::from(self.count), + ScalarValue::Decimal128(self.sum, self.sum_precision, self.sum_scale), + ]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; + let values = values[0].as_primitive::(); self.count += (values.len() - values.null_count()) as u64; - self.sum = self.sum.add(&sum::sum_batch(values)?)?; + if let Some(x) = sum(values) { + let v = self.sum.get_or_insert(0); + *v += x; + } Ok(()) } fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; + let values = values[0].as_primitive::(); self.count -= (values.len() - values.null_count()) as u64; - let delta = sum_batch(values)?; - self.sum = self.sum.sub(&delta)?; + if let Some(x) = sum(values) { + self.sum = Some(self.sum.unwrap() - x); + } Ok(()) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let counts = downcast_value!(states[0], UInt64Array); // counts are summed - self.count += compute::sum(counts).unwrap_or(0); + self.count += sum(states[0].as_primitive::()).unwrap_or_default(); // sums are summed - self.sum = self.sum.add(&sum::sum_batch(&states[1])?)?; + if let Some(x) = sum(states[1].as_primitive::()) { + let v = self.sum.get_or_insert(0); + *v += x; + } Ok(()) } fn evaluate(&self) -> Result { - match self.sum { - ScalarValue::Float64(e) => { - Ok(ScalarValue::Float64(e.map(|f| f / self.count as f64))) - } - ScalarValue::Decimal128(value, _, scale) => { - match value { - None => match &self.return_data_type { - DataType::Decimal128(p, s) => { - Ok(ScalarValue::Decimal128(None, *p, *s)) - } - other => internal_err!( - "Error returned data type in AvgAccumulator {other:?}" - ), - }, - Some(value) => { - // now the sum_type and return type is not the same, need to convert the sum type to return type - calculate_result_decimal_for_avg( - value, - self.count as i128, - scale, - &self.return_data_type, - ) - } - } - } - _ => internal_err!("Sum should be f64 or decimal128 on average"), - } + let v = self + .sum + .map(|v| { + Decimal128Averager::try_new( + self.sum_scale, + self.target_precision, + self.target_scale, + )? + .avg(v, self.count as _) + }) + .transpose()?; + + Ok(ScalarValue::Decimal128( + v, + self.target_precision, + self.target_scale, + )) } fn supports_retract_batch(&self) -> bool { true } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.sum) + self.sum.size() + std::mem::size_of_val(self) } } @@ -490,6 +547,7 @@ mod tests { assert_aggregate( array, AggregateFunction::Avg, + false, ScalarValue::Decimal128(Some(35000), 14, 4), ); } @@ -506,6 +564,7 @@ mod tests { assert_aggregate( array, AggregateFunction::Avg, + false, ScalarValue::Decimal128(Some(32500), 14, 4), ); } @@ -523,6 +582,7 @@ mod tests { assert_aggregate( array, AggregateFunction::Avg, + false, ScalarValue::Decimal128(None, 14, 4), ); } @@ -530,7 +590,7 @@ mod tests { #[test] fn avg_i32() { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - assert_aggregate(a, AggregateFunction::Avg, ScalarValue::from(3_f64)); + assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3_f64)); } #[test] @@ -542,33 +602,33 @@ mod tests { Some(4), Some(5), ])); - assert_aggregate(a, AggregateFunction::Avg, ScalarValue::from(3.25f64)); + assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3.25f64)); } #[test] fn avg_i32_all_nulls() { let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - assert_aggregate(a, AggregateFunction::Avg, ScalarValue::Float64(None)); + assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::Float64(None)); } #[test] fn avg_u32() { let a: ArrayRef = Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - assert_aggregate(a, AggregateFunction::Avg, ScalarValue::from(3.0f64)); + assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3.0f64)); } #[test] fn avg_f32() { let a: ArrayRef = Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); - assert_aggregate(a, AggregateFunction::Avg, ScalarValue::from(3_f64)); + assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3_f64)); } #[test] fn avg_f64() { let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - assert_aggregate(a, AggregateFunction::Avg, ScalarValue::from(3_f64)); + assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3_f64)); } } diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index baaebada377b..5cc8e933324e 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -18,30 +18,22 @@ //! Defines `SUM` and `SUM DISTINCT` aggregate accumulators use std::any::Any; -use std::convert::TryFrom; -use std::ops::AddAssign; use std::sync::Arc; use super::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr}; -use arrow::array::Array; -use arrow::array::Decimal128Array; -use arrow::array::Decimal256Array; -use arrow::compute; +use arrow::compute::sum; use arrow::datatypes::DataType; -use arrow::{ - array::{ArrayRef, Float64Array, Int64Array, UInt64Array}, - datatypes::Field, -}; +use arrow::{array::ArrayRef, datatypes::Field}; +use arrow_array::cast::AsArray; use arrow_array::types::{ - Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int32Type, Int64Type, - UInt32Type, UInt64Type, -}; -use datafusion_common::{ - downcast_value, internal_err, not_impl_err, DataFusionError, Result, ScalarValue, + Decimal128Type, Decimal256Type, Float64Type, Int64Type, UInt64Type, }; +use arrow_array::{Array, ArrowNativeTypeOp, ArrowNumericType}; +use arrow_buffer::ArrowNativeType; +use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::type_coercion::aggregates::sum_return_type; use datafusion_expr::Accumulator; @@ -71,18 +63,24 @@ impl Sum { } } -/// Creates a [`PrimitiveGroupsAccumulator`] with the specified -/// [`ArrowPrimitiveType`] which applies `$FN` to each element +/// Sum only supports a subset of numeric types, instead relying on type coercion +/// +/// This macro is similar to [downcast_primitive](arrow_array::downcast_primitive) /// -/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType -macro_rules! instantiate_primitive_accumulator { - ($SELF:expr, $PRIMTYPE:ident, $FN:expr) => {{ - Ok(Box::new(PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new( - &$SELF.data_type, - $FN, - ))) - }}; +/// `s` is a `Sum`, `helper` is a macro accepting (ArrowPrimitiveType, DataType) +macro_rules! downcast_sum { + ($s:ident, $helper:ident) => { + match $s.data_type { + DataType::UInt64 => $helper!(UInt64Type, $s.data_type), + DataType::Int64 => $helper!(Int64Type, $s.data_type), + DataType::Float64 => $helper!(Float64Type, $s.data_type), + DataType::Decimal128(_, _) => $helper!(Decimal128Type, $s.data_type), + DataType::Decimal256(_, _) => $helper!(Decimal256Type, $s.data_type), + _ => not_impl_err!("Sum not supported for {}: {}", $s.name, $s.data_type), + } + }; } +pub(crate) use downcast_sum; impl AggregateExpr for Sum { /// Return a reference to Any that can be used for downcasting @@ -99,7 +97,12 @@ impl AggregateExpr for Sum { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(SumAccumulator::try_new(&self.data_type)?)) + macro_rules! helper { + ($t:ty, $dt:expr) => { + Ok(Box::new(SumAccumulator::<$t>::new($dt.clone()))) + }; + } + downcast_sum!(self, helper) } fn state_fields(&self) -> Result> { @@ -123,46 +126,15 @@ impl AggregateExpr for Sum { } fn create_groups_accumulator(&self) -> Result> { - // instantiate specialized accumulator - match self.data_type { - DataType::UInt64 => { - instantiate_primitive_accumulator!(self, UInt64Type, |x, y| x - .add_assign(y)) - } - DataType::Int64 => { - instantiate_primitive_accumulator!(self, Int64Type, |x, y| x - .add_assign(y)) - } - DataType::UInt32 => { - instantiate_primitive_accumulator!(self, UInt32Type, |x, y| x - .add_assign(y)) - } - DataType::Int32 => { - instantiate_primitive_accumulator!(self, Int32Type, |x, y| x - .add_assign(y)) - } - DataType::Float32 => { - instantiate_primitive_accumulator!(self, Float32Type, |x, y| x - .add_assign(y)) - } - DataType::Float64 => { - instantiate_primitive_accumulator!(self, Float64Type, |x, y| x - .add_assign(y)) - } - DataType::Decimal128(_, _) => { - instantiate_primitive_accumulator!(self, Decimal128Type, |x, y| x - .add_assign(y)) - } - DataType::Decimal256(_, _) => { - instantiate_primitive_accumulator!(self, Decimal256Type, |x, y| *x = - *x + y) - } - _ => not_impl_err!( - "GroupsAccumulator not supported for {}: {}", - self.name, - self.data_type - ), + macro_rules! helper { + ($t:ty, $dt:expr) => { + Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new( + &$dt, + |x, y| *x = x.add_wrapping(y), + ))) + }; } + downcast_sum!(self, helper) } fn reverse_expr(&self) -> Option> { @@ -170,7 +142,12 @@ impl AggregateExpr for Sum { } fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(SlidingSumAccumulator::try_new(&self.data_type)?)) + macro_rules! helper { + ($t:ty, $dt:expr) => { + Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone()))) + }; + } + downcast_sum!(self, helper) } } @@ -189,177 +166,135 @@ impl PartialEq for Sum { } /// This accumulator computes SUM incrementally -#[derive(Debug)] -struct SumAccumulator { - sum: ScalarValue, -} - -impl SumAccumulator { - /// new sum accumulator - pub fn try_new(data_type: &DataType) -> Result { - Ok(Self { - sum: ScalarValue::try_from(data_type)?, - }) - } -} - -/// This accumulator incrementally computes sums over a sliding window -#[derive(Debug)] -struct SlidingSumAccumulator { - sum: ScalarValue, - count: u64, +struct SumAccumulator { + sum: Option, + data_type: DataType, } -impl SlidingSumAccumulator { - /// new sum accumulator - pub fn try_new(data_type: &DataType) -> Result { - Ok(Self { - // start at zero - sum: ScalarValue::try_from(data_type)?, - count: 0, - }) +impl std::fmt::Debug for SumAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "SumAccumulator({})", self.data_type) } } -/// Sums the contents of the `$VALUES` array using the arrow compute -/// kernel, and return a `ScalarValue::$SCALAR`. -/// -/// Handles nullability -macro_rules! typed_sum_delta_batch { - ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{ - let array = downcast_value!($VALUES, $ARRAYTYPE); - let delta = compute::sum(array); - ScalarValue::$SCALAR(delta) - }}; -} - -fn sum_decimal_batch(values: &ArrayRef, precision: u8, scale: i8) -> Result { - let array = downcast_value!(values, Decimal128Array); - let result = compute::sum(array); - Ok(ScalarValue::Decimal128(result, precision, scale)) -} - -fn sum_decimal256_batch( - values: &ArrayRef, - precision: u8, - scale: i8, -) -> Result { - let array = downcast_value!(values, Decimal256Array); - let result = compute::sum(array); - Ok(ScalarValue::Decimal256(result, precision, scale)) -} - -// sums the array and returns a ScalarValue of its corresponding type. -pub(crate) fn sum_batch(values: &ArrayRef) -> Result { - Ok(match values.data_type() { - DataType::Decimal128(precision, scale) => { - sum_decimal_batch(values, *precision, *scale)? - } - DataType::Decimal256(precision, scale) => { - sum_decimal256_batch(values, *precision, *scale)? - } - DataType::Float64 => typed_sum_delta_batch!(values, Float64Array, Float64), - DataType::Int64 => typed_sum_delta_batch!(values, Int64Array, Int64), - DataType::UInt64 => typed_sum_delta_batch!(values, UInt64Array, UInt64), - e => { - return internal_err!("Sum is not expected to receive the type {e:?}"); +impl SumAccumulator { + fn new(data_type: DataType) -> Self { + Self { + sum: None, + data_type, } - }) + } } -impl Accumulator for SumAccumulator { +impl Accumulator for SumAccumulator { fn state(&self) -> Result> { - Ok(vec![self.sum.clone()]) + Ok(vec![self.evaluate()?]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; - let delta = sum_batch(values)?; - self.sum = self.sum.add(&delta)?; + let values = values[0].as_primitive::(); + if let Some(x) = sum(values) { + let v = self.sum.get_or_insert(T::Native::usize_as(0)); + *v = v.add_wrapping(x); + } Ok(()) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - // sum(sum1, sum2, sum3, ...) = sum1 + sum2 + sum3 + ... self.update_batch(states) } fn evaluate(&self) -> Result { - // TODO: add the checker for overflow - // For the decimal(precision,_) data type, the absolute of value must be less than 10^precision. - Ok(self.sum.clone()) + Ok(ScalarValue::new_primitive::(self.sum, &self.data_type)) } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.sum) + self.sum.size() + std::mem::size_of_val(self) + } +} + +/// This accumulator incrementally computes sums over a sliding window +/// +/// This is separate from [`SumAccumulator`] as requires additional state +struct SlidingSumAccumulator { + sum: T::Native, + count: u64, + data_type: DataType, +} + +impl std::fmt::Debug for SlidingSumAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "SlidingSumAccumulator({})", self.data_type) + } +} + +impl SlidingSumAccumulator { + fn new(data_type: DataType) -> Self { + Self { + sum: T::Native::usize_as(0), + count: 0, + data_type, + } } } -impl Accumulator for SlidingSumAccumulator { +impl Accumulator for SlidingSumAccumulator { fn state(&self) -> Result> { - Ok(vec![self.sum.clone(), ScalarValue::from(self.count)]) + Ok(vec![self.evaluate()?, self.count.into()]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; + let values = values[0].as_primitive::(); self.count += (values.len() - values.null_count()) as u64; - let delta = sum_batch(values)?; - self.sum = self.sum.add(&delta)?; + if let Some(x) = sum(values) { + self.sum = self.sum.add_wrapping(x) + } Ok(()) } - fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; - self.count -= (values.len() - values.null_count()) as u64; - let delta = sum_batch(values)?; - self.sum = self.sum.sub(&delta)?; + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let values = states[0].as_primitive::(); + if let Some(x) = sum(values) { + self.sum = self.sum.add_wrapping(x) + } + if let Some(x) = sum(states[1].as_primitive::()) { + self.count += x; + } Ok(()) } - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - // sum(sum1, sum2, sum3, ...) = sum1 + sum2 + sum3 + ... - self.update_batch(states) + fn evaluate(&self) -> Result { + let v = (self.count != 0).then_some(self.sum); + Ok(ScalarValue::new_primitive::(v, &self.data_type)) } - fn evaluate(&self) -> Result { - // TODO: add the checker for overflow - // For the decimal(precision,_) data type, the absolute of value must be less than 10^precision. - if self.count == 0 { - ScalarValue::try_from(&self.sum.get_datatype()) - } else { - Ok(self.sum.clone()) + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = values[0].as_primitive::(); + if let Some(x) = sum(values) { + self.sum = self.sum.sub_wrapping(x) } + self.count -= (values.len() - values.null_count()) as u64; + Ok(()) } fn supports_retract_batch(&self) -> bool { true } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.sum) + self.sum.size() - } } #[cfg(test)] mod tests { use super::*; use crate::expressions::tests::assert_aggregate; - use arrow_array::{Float32Array, Int32Array, UInt32Array}; + use arrow_array::*; use datafusion_expr::AggregateFunction; #[test] fn sum_decimal() { - // test sum batch - let array: ArrayRef = Arc::new( - (1..6) - .map(Some) - .collect::() - .with_precision_and_scale(10, 0) - .unwrap(), - ); - let result = sum_batch(&array).unwrap(); - assert_eq!(ScalarValue::Decimal128(Some(15), 10, 0), result); - // test agg let array: ArrayRef = Arc::new( (1..6) @@ -372,23 +307,13 @@ mod tests { assert_aggregate( array, AggregateFunction::Sum, + false, ScalarValue::Decimal128(Some(15), 20, 0), ); } #[test] fn sum_decimal_with_nulls() { - // test with batch - let array: ArrayRef = Arc::new( - (1..6) - .map(|i| if i == 2 { None } else { Some(i) }) - .collect::() - .with_precision_and_scale(10, 0) - .unwrap(), - ); - let result = sum_batch(&array).unwrap(); - assert_eq!(ScalarValue::Decimal128(Some(13), 10, 0), result); - // test agg let array: ArrayRef = Arc::new( (1..6) @@ -401,6 +326,7 @@ mod tests { assert_aggregate( array, AggregateFunction::Sum, + false, ScalarValue::Decimal128(Some(13), 38, 0), ); } @@ -415,13 +341,12 @@ mod tests { .with_precision_and_scale(10, 0) .unwrap(), ); - let result = sum_batch(&array).unwrap(); - assert_eq!(ScalarValue::Decimal128(None, 10, 0), result); // test agg assert_aggregate( array, AggregateFunction::Sum, + false, ScalarValue::Decimal128(None, 20, 0), ); } @@ -429,7 +354,7 @@ mod tests { #[test] fn sum_i32() { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - assert_aggregate(a, AggregateFunction::Sum, ScalarValue::from(15i64)); + assert_aggregate(a, AggregateFunction::Sum, false, ScalarValue::from(15i64)); } #[test] @@ -441,33 +366,33 @@ mod tests { Some(4), Some(5), ])); - assert_aggregate(a, AggregateFunction::Sum, ScalarValue::from(13i64)); + assert_aggregate(a, AggregateFunction::Sum, false, ScalarValue::from(13i64)); } #[test] fn sum_i32_all_nulls() { let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - assert_aggregate(a, AggregateFunction::Sum, ScalarValue::Int64(None)); + assert_aggregate(a, AggregateFunction::Sum, false, ScalarValue::Int64(None)); } #[test] fn sum_u32() { let a: ArrayRef = Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - assert_aggregate(a, AggregateFunction::Sum, ScalarValue::from(15u64)); + assert_aggregate(a, AggregateFunction::Sum, false, ScalarValue::from(15u64)); } #[test] fn sum_f32() { let a: ArrayRef = Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); - assert_aggregate(a, AggregateFunction::Sum, ScalarValue::from(15_f64)); + assert_aggregate(a, AggregateFunction::Sum, false, ScalarValue::from(15_f64)); } #[test] fn sum_f64() { let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - assert_aggregate(a, AggregateFunction::Sum, ScalarValue::from(15_f64)); + assert_aggregate(a, AggregateFunction::Sum, false, ScalarValue::from(15_f64)); } } diff --git a/datafusion/physical-expr/src/aggregate/sum_distinct.rs b/datafusion/physical-expr/src/aggregate/sum_distinct.rs index 366b875c2393..c3d8d5e87068 100644 --- a/datafusion/physical-expr/src/aggregate/sum_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/sum_distinct.rs @@ -18,17 +18,21 @@ use crate::expressions::format_state_name; use arrow::datatypes::{DataType, Field}; use std::any::Any; -use std::fmt::Debug; use std::sync::Arc; use ahash::RandomState; use arrow::array::{Array, ArrayRef}; +use arrow_array::cast::AsArray; +use arrow_array::types::*; +use arrow_array::{ArrowNativeTypeOp, ArrowPrimitiveType}; +use arrow_buffer::{ArrowNativeType, ToByteSlice}; use std::collections::HashSet; +use crate::aggregate::sum::downcast_sum; use crate::aggregate::utils::down_cast_any_ref; use crate::{AggregateExpr, PhysicalExpr}; -use datafusion_common::ScalarValue; -use datafusion_common::{internal_err, DataFusionError, Result}; +use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::type_coercion::aggregates::sum_return_type; use datafusion_expr::Accumulator; /// Expression for a SUM(DISTINCT) aggregation. @@ -49,6 +53,7 @@ impl DistinctSum { name: String, data_type: DataType, ) -> Self { + let data_type = sum_return_type(&data_type).unwrap(); Self { name, data_type, @@ -84,7 +89,12 @@ impl AggregateExpr for DistinctSum { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(DistinctSumAccumulator::try_new(&self.data_type)?)) + macro_rules! helper { + ($t:ty, $dt:expr) => { + Ok(Box::new(DistinctSumAccumulator::<$t>::try_new(&$dt)?)) + }; + } + downcast_sum!(self, helper) } } @@ -106,29 +116,56 @@ impl PartialEq for DistinctSum { } } -#[derive(Debug)] -struct DistinctSumAccumulator { - hash_values: HashSet, +/// A wrapper around a type to provide hash for floats +#[derive(Copy, Clone)] +struct Hashable(T); + +impl std::hash::Hash for Hashable { + fn hash(&self, state: &mut H) { + self.0.to_byte_slice().hash(state) + } +} + +impl PartialEq for Hashable { + fn eq(&self, other: &Self) -> bool { + self.0.is_eq(other.0) + } +} + +impl Eq for Hashable {} + +struct DistinctSumAccumulator { + values: HashSet, RandomState>, data_type: DataType, } -impl DistinctSumAccumulator { + +impl std::fmt::Debug for DistinctSumAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DistinctSumAccumulator({})", self.data_type) + } +} + +impl DistinctSumAccumulator { pub fn try_new(data_type: &DataType) -> Result { Ok(Self { - hash_values: HashSet::default(), + values: HashSet::default(), data_type: data_type.clone(), }) } } -impl Accumulator for DistinctSumAccumulator { +impl Accumulator for DistinctSumAccumulator { fn state(&self) -> Result> { // 1. Stores aggregate state in `ScalarValue::List` // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set let state_out = { let mut distinct_values = Vec::new(); - self.hash_values - .iter() - .for_each(|distinct_value| distinct_values.push(distinct_value.clone())); + self.values.iter().for_each(|distinct_value| { + distinct_values.push(ScalarValue::new_primitive::( + Some(distinct_value.0), + &self.data_type, + )) + }); vec![ScalarValue::new_list( Some(distinct_values), self.data_type.clone(), @@ -142,62 +179,49 @@ impl Accumulator for DistinctSumAccumulator { return Ok(()); } - let arr = &values[0]; - (0..values[0].len()).try_for_each(|index| { - if !arr.is_null(index) { - let v = ScalarValue::try_from_array(arr, index)?; - self.hash_values.insert(v); + let array = values[0].as_primitive::(); + match array.nulls().filter(|x| x.null_count() > 0) { + Some(n) => { + for idx in n.valid_indices() { + self.values.insert(Hashable(array.value(idx))); + } } - Ok(()) - }) + None => array.values().iter().for_each(|x| { + self.values.insert(Hashable(*x)); + }), + } + Ok(()) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); + for x in states[0].as_list::().iter().flatten() { + self.update_batch(&[x])? } - - let arr = &states[0]; - (0..arr.len()).try_for_each(|index| { - let scalar = ScalarValue::try_from_array(arr, index)?; - - if let ScalarValue::List(Some(scalar), _) = scalar { - scalar.iter().for_each(|scalar| { - if !ScalarValue::is_null(scalar) { - self.hash_values.insert(scalar.clone()); - } - }); - } else { - return internal_err!("Unexpected accumulator state"); - } - Ok(()) - }) + Ok(()) } fn evaluate(&self) -> Result { - let mut sum_value = ScalarValue::try_from(&self.data_type)?; - for distinct_value in self.hash_values.iter() { - sum_value = sum_value.add(distinct_value)?; + let mut acc = T::Native::usize_as(0); + for distinct_value in self.values.iter() { + acc = acc.add_wrapping(distinct_value.0) } - Ok(sum_value) + let v = (!self.values.is_empty()).then_some(acc); + Ok(ScalarValue::new_primitive::(v, &self.data_type)) } fn size(&self) -> usize { - std::mem::size_of_val(self) + ScalarValue::size_of_hashset(&self.hash_values) - - std::mem::size_of_val(&self.hash_values) - + self.data_type.size() - - std::mem::size_of_val(&self.data_type) + std::mem::size_of_val(self) + + self.values.capacity() * std::mem::size_of::() } } #[cfg(test)] mod tests { use super::*; - use crate::expressions::col; - use crate::expressions::tests::aggregate; - use arrow::record_batch::RecordBatch; - use arrow::{array::*, datatypes::*}; + use crate::expressions::tests::assert_aggregate; + use arrow::array::*; use datafusion_common::Result; + use datafusion_expr::AggregateFunction; fn run_update_batch( return_type: DataType, @@ -211,26 +235,6 @@ mod tests { Ok((accum.state()?, accum.evaluate()?)) } - macro_rules! generic_test_sum_distinct { - ($ARRAY:expr, $DATATYPE:expr, $EXPECTED:expr) => {{ - let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); - - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; - - let agg = Arc::new(DistinctSum::new( - vec![col("a", &schema)?], - "count_distinct_a".to_string(), - $EXPECTED.get_datatype(), - )); - let actual = aggregate(&batch, agg)?; - let expected = ScalarValue::from($EXPECTED); - - assert_eq!(expected, actual); - - Ok(()) - }}; - } - #[test] fn sum_distinct_update_batch() -> Result<()> { let array_int64: ArrayRef = Arc::new(Int64Array::from(vec![1, 1, 3])); @@ -244,7 +248,7 @@ mod tests { } #[test] - fn sum_distinct_i32_with_nulls() -> Result<()> { + fn sum_distinct_i32_with_nulls() { let array = Arc::new(Int32Array::from(vec![ Some(1), Some(1), @@ -253,11 +257,11 @@ mod tests { Some(2), Some(3), ])); - generic_test_sum_distinct!(array, DataType::Int32, ScalarValue::from(6_i32)) + assert_aggregate(array, AggregateFunction::Sum, true, 6_i64.into()); } #[test] - fn sum_distinct_u32_with_nulls() -> Result<()> { + fn sum_distinct_u32_with_nulls() { let array: ArrayRef = Arc::new(UInt32Array::from(vec![ Some(1_u32), Some(1_u32), @@ -265,28 +269,30 @@ mod tests { Some(3_u32), None, ])); - generic_test_sum_distinct!(array, DataType::UInt32, ScalarValue::from(4_u32)) + assert_aggregate(array, AggregateFunction::Sum, true, 4_u64.into()); } #[test] - fn sum_distinct_f64() -> Result<()> { + fn sum_distinct_f64() { let array: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 1_f64, 3_f64, 3_f64, 3_f64])); - generic_test_sum_distinct!(array, DataType::Float64, ScalarValue::from(4_f64)) + assert_aggregate(array, AggregateFunction::Sum, true, 4_f64.into()); } #[test] - fn sum_distinct_decimal_with_nulls() -> Result<()> { + fn sum_distinct_decimal_with_nulls() { let array: ArrayRef = Arc::new( (1..6) .map(|i| if i == 2 { None } else { Some(i % 2) }) .collect::() - .with_precision_and_scale(35, 0)?, + .with_precision_and_scale(35, 0) + .unwrap(), ); - generic_test_sum_distinct!( + assert_aggregate( array, - DataType::Decimal128(35, 0), - ScalarValue::Decimal128(Some(1), 38, 0) - ) + AggregateFunction::Sum, + true, + ScalarValue::Decimal128(Some(1), 38, 0), + ); } } diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs index e86eb1dc1fc5..463d8fec189c 100644 --- a/datafusion/physical-expr/src/aggregate/utils.rs +++ b/datafusion/physical-expr/src/aggregate/utils.rs @@ -23,8 +23,7 @@ use arrow::datatypes::{MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PREC use arrow_array::cast::AsArray; use arrow_array::types::Decimal128Type; use arrow_schema::{DataType, Field}; -use datafusion_common::internal_err; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Accumulator; use std::any::Any; use std::sync::Arc; @@ -118,34 +117,6 @@ impl Decimal128Averager { } } -/// Returns `sum`/`count` for decimal values, detecting and reporting overflow. -/// -/// * sum: stored as Decimal128 with `sum_scale` scale -/// * count: stored as a i128 (*NOT* a Decimal128 value) -/// * sum_scale: the scale of `sum` -/// * target_type: the output decimal type -pub fn calculate_result_decimal_for_avg( - sum: i128, - count: i128, - sum_scale: i8, - target_type: &DataType, -) -> Result { - match target_type { - DataType::Decimal128(target_precision, target_scale) => { - let new_value = - Decimal128Averager::try_new(sum_scale, *target_precision, *target_scale)? - .avg(sum, count)?; - - Ok(ScalarValue::Decimal128( - Some(new_value), - *target_precision, - *target_scale, - )) - } - other => internal_err!("Invalid target type in AvgAccumulator {other:?}"), - } -} - /// Adjust array type metadata if needed /// /// Since `Decimal128Arrays` created from `Vec` have diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 4a6d52834dda..bce1240e50dd 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -140,6 +140,7 @@ pub(crate) mod tests { pub fn assert_aggregate( array: ArrayRef, function: AggregateFunction, + distinct: bool, expected: ScalarValue, ) { let data_type = array.data_type(); @@ -159,7 +160,7 @@ pub(crate) mod tests { let schema = Schema::new(vec![Field::new("a", coerced[0].clone(), true)]); let agg = - create_aggregate_expr(&function, false, &[input], &[], &schema, "aggregate") + create_aggregate_expr(&function, distinct, &[input], &[], &schema, "agg") .unwrap(); let result = aggregate(&batch, agg).unwrap(); diff --git a/datafusion/sqllogictest/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt index a2a1df55e54e..76743e444e2c 100644 --- a/datafusion/sqllogictest/test_files/decimal.slt +++ b/datafusion/sqllogictest/test_files/decimal.slt @@ -618,7 +618,7 @@ select a / b from foo; statement ok create table t as values (arrow_cast(123, 'Decimal256(5,2)')); -query error DataFusion error: Internal error: Operator \+ is not implemented for types Decimal256\(None,5,2\) and Decimal256\(Some\(12300\),5,2\)\. This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker +query error DataFusion error: This feature is not implemented: AvgAccumulator for \(Decimal256\(5, 2\) \-\-> Decimal256\(9, 6\)\) select AVG(column1) from t; statement ok