Skip to content

Commit

Permalink
Specialize SUM and AVG (#6842)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Aug 22, 2023
1 parent 870857a commit 1fd74c8
Show file tree
Hide file tree
Showing 7 changed files with 244 additions and 318 deletions.
7 changes: 1 addition & 6 deletions datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2452,12 +2452,7 @@ mod tests {
vec![DataType::Float64],
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| {
Ok(Box::new(AvgAccumulator::try_new(
&DataType::Float64,
&DataType::Float64,
)?))
}),
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);

Expand Down
7 changes: 1 addition & 6 deletions datafusion/core/tests/sql/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,7 @@ async fn simple_udaf() -> Result<()> {
vec![DataType::Float64],
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| {
Ok(Box::new(AvgAccumulator::try_new(
&DataType::Float64,
&DataType::Float64,
)?))
}),
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);

Expand Down
15 changes: 3 additions & 12 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -906,12 +906,7 @@ mod test {
vec![DataType::Float64],
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| {
Ok(Box::new(AvgAccumulator::try_new(
&DataType::Float64,
&DataType::Float64,
)?))
}),
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);
let udaf = Expr::AggregateUDF(expr::AggregateUDF::new(
Expand All @@ -932,12 +927,8 @@ mod test {
Arc::new(move |_| Ok(Arc::new(DataType::Float64)));
let state_type: StateTypeFunction =
Arc::new(move |_| Ok(Arc::new(vec![DataType::UInt64, DataType::Float64])));
let accumulator: AccumulatorFactoryFunction = Arc::new(|_| {
Ok(Box::new(AvgAccumulator::try_new(
&DataType::Float64,
&DataType::Float64,
)?))
});
let accumulator: AccumulatorFactoryFunction =
Arc::new(|_| Ok(Box::<AvgAccumulator>::default()));
let my_avg = AggregateUDF::new(
"MY_AVG",
&Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
Expand Down
191 changes: 124 additions & 67 deletions datafusion/physical-expr/src/aggregate/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,13 @@ use arrow::array::{AsArray, PrimitiveBuilder};
use log::debug;

use std::any::Any;
use std::convert::TryFrom;
use std::sync::Arc;

use crate::aggregate::groups_accumulator::accumulate::NullState;
use crate::aggregate::sum;
use crate::aggregate::sum::sum_batch;
use crate::aggregate::utils::calculate_result_decimal_for_avg;
use crate::aggregate::utils::down_cast_any_ref;
use crate::expressions::format_state_name;
use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr};
use arrow::compute;
use arrow::compute::sum;
use arrow::datatypes::{DataType, Decimal128Type, Float64Type, UInt64Type};
use arrow::{
array::{ArrayRef, UInt64Array},
Expand All @@ -40,9 +36,7 @@ use arrow::{
use arrow_array::{
Array, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, PrimitiveArray,
};
use datafusion_common::{
downcast_value, internal_err, not_impl_err, DataFusionError, Result, ScalarValue,
};
use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::type_coercion::aggregates::avg_return_type;
use datafusion_expr::Accumulator;

Expand Down Expand Up @@ -93,11 +87,27 @@ impl AggregateExpr for Avg {
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(AvgAccumulator::try_new(
// avg is f64 or decimal
&self.input_data_type,
&self.result_data_type,
)?))
use DataType::*;
// instantiate specialized accumulator based for the type
match (&self.input_data_type, &self.result_data_type) {
(Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
(
Decimal128(sum_precision, sum_scale),
Decimal128(target_precision, target_scale),
) => Ok(Box::new(DecimalAvgAccumulator {
sum: None,
count: 0,
sum_scale: *sum_scale,
sum_precision: *sum_precision,
target_precision: *target_precision,
target_scale: *target_scale,
})),
_ => not_impl_err!(
"AvgGroupsAccumulator for ({} --> {})",
self.input_data_type,
self.result_data_type
),
}
}

fn state_fields(&self) -> Result<Vec<Field>> {
Expand Down Expand Up @@ -128,10 +138,7 @@ impl AggregateExpr for Avg {
}

fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(AvgAccumulator::try_new(
&self.input_data_type,
&self.result_data_type,
)?))
self.create_accumulator()
}

fn groups_accumulator_supported(&self) -> bool {
Expand Down Expand Up @@ -195,91 +202,141 @@ impl PartialEq<dyn Any> for Avg {
}

/// An accumulator to compute the average
#[derive(Debug)]
#[derive(Debug, Default)]
pub struct AvgAccumulator {
// sum is used for null
sum: ScalarValue,
return_data_type: DataType,
sum: Option<f64>,
count: u64,
}

impl AvgAccumulator {
/// Creates a new `AvgAccumulator`
pub fn try_new(datatype: &DataType, return_data_type: &DataType) -> Result<Self> {
Ok(Self {
sum: ScalarValue::try_from(datatype)?,
return_data_type: return_data_type.clone(),
count: 0,
})
impl Accumulator for AvgAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(vec![
ScalarValue::from(self.count),
ScalarValue::Float64(self.sum),
])
}

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = values[0].as_primitive::<Float64Type>();
self.count += (values.len() - values.null_count()) as u64;
if let Some(x) = sum(values) {
let v = self.sum.get_or_insert(0.);
*v += x;
}
Ok(())
}

fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = values[0].as_primitive::<Float64Type>();
self.count -= (values.len() - values.null_count()) as u64;
if let Some(x) = sum(values) {
self.sum = Some(self.sum.unwrap() - x);
}
Ok(())
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
// counts are summed
self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();

// sums are summed
if let Some(x) = sum(states[1].as_primitive::<Float64Type>()) {
let v = self.sum.get_or_insert(0.);
*v += x;
}
Ok(())
}

fn evaluate(&self) -> Result<ScalarValue> {
Ok(ScalarValue::Float64(
self.sum.map(|f| f / self.count as f64),
))
}
fn supports_retract_batch(&self) -> bool {
true
}

fn size(&self) -> usize {
std::mem::size_of_val(self)
}
}

impl Accumulator for AvgAccumulator {
/// An accumulator to compute the average for decimals
#[derive(Debug)]
struct DecimalAvgAccumulator {
sum: Option<i128>,
count: u64,
sum_scale: i8,
sum_precision: u8,
target_precision: u8,
target_scale: i8,
}

impl Accumulator for DecimalAvgAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(vec![ScalarValue::from(self.count), self.sum.clone()])
Ok(vec![
ScalarValue::from(self.count),
ScalarValue::Decimal128(self.sum, self.sum_precision, self.sum_scale),
])
}

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = &values[0];
let values = values[0].as_primitive::<Decimal128Type>();

self.count += (values.len() - values.null_count()) as u64;
self.sum = self.sum.add(&sum::sum_batch(values)?)?;
if let Some(x) = sum(values) {
let v = self.sum.get_or_insert(0);
*v += x;
}
Ok(())
}

fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = &values[0];
let values = values[0].as_primitive::<Decimal128Type>();
self.count -= (values.len() - values.null_count()) as u64;
let delta = sum_batch(values)?;
self.sum = self.sum.sub(&delta)?;
if let Some(x) = sum(values) {
self.sum = Some(self.sum.unwrap() - x);
}
Ok(())
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
let counts = downcast_value!(states[0], UInt64Array);
// counts are summed
self.count += compute::sum(counts).unwrap_or(0);
self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();

// sums are summed
self.sum = self.sum.add(&sum::sum_batch(&states[1])?)?;
if let Some(x) = sum(states[1].as_primitive::<Decimal128Type>()) {
let v = self.sum.get_or_insert(0);
*v += x;
}
Ok(())
}

fn evaluate(&self) -> Result<ScalarValue> {
match self.sum {
ScalarValue::Float64(e) => {
Ok(ScalarValue::Float64(e.map(|f| f / self.count as f64)))
}
ScalarValue::Decimal128(value, _, scale) => {
match value {
None => match &self.return_data_type {
DataType::Decimal128(p, s) => {
Ok(ScalarValue::Decimal128(None, *p, *s))
}
other => internal_err!(
"Error returned data type in AvgAccumulator {other:?}"
),
},
Some(value) => {
// now the sum_type and return type is not the same, need to convert the sum type to return type
calculate_result_decimal_for_avg(
value,
self.count as i128,
scale,
&self.return_data_type,
)
}
}
}
_ => internal_err!("Sum should be f64 or decimal128 on average"),
}
let v = self
.sum
.map(|v| {
Decimal128Averager::try_new(
self.sum_scale,
self.target_precision,
self.target_scale,
)?
.avg(v, self.count as _)
})
.transpose()?;

Ok(ScalarValue::Decimal128(
v,
self.target_precision,
self.target_scale,
))
}
fn supports_retract_batch(&self) -> bool {
true
}

fn size(&self) -> usize {
std::mem::size_of_val(self) - std::mem::size_of_val(&self.sum) + self.sum.size()
std::mem::size_of_val(self)
}
}

Expand Down
Loading

0 comments on commit 1fd74c8

Please sign in to comment.