Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Groups accumulator primitive #177

Merged
merged 2 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions datafusion/src/cube_ext/joinagg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@ impl ExecutionPlan for CrossJoinAggExec {
&AggregateMode::Full,
self.group_expr.len(),
)?;
let mut accumulators = create_accumulation_state(&self.agg_expr)?;
let mut accumulators: hash_aggregate::AccumulationState =
create_accumulation_state(&self.agg_expr)?;
for partition in 0..self.join.right.output_partitioning().partition_count() {
let mut batches = self.join.right.execute(partition).await?;
while let Some(right) = batches.next().await {
Expand Down Expand Up @@ -273,7 +274,7 @@ impl ExecutionPlan for CrossJoinAggExec {
let out_schema = self.schema.clone();
let r = hash_aggregate::create_batch_from_map(
&AggregateMode::Full,
&accumulators,
accumulators,
self.group_expr.len(),
&out_schema,
)?;
Expand Down
2 changes: 1 addition & 1 deletion datafusion/src/physical_plan/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ pub fn create_aggregate_expr(
))
}
(AggregateFunction::Sum, false) => {
Arc::new(expressions::Sum::new(arg, name, return_type))
Arc::new(expressions::Sum::new(arg, name, return_type, &arg_types[0]))
}
(AggregateFunction::Sum, true) => {
return Err(DataFusionError::NotImplemented(
Expand Down
180 changes: 168 additions & 12 deletions datafusion/src/physical_plan/expressions/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use std::sync::Arc;
use crate::error::{DataFusionError, Result};
use crate::physical_plan::groups_accumulator::GroupsAccumulator;
use crate::physical_plan::groups_accumulator_flat_adapter::GroupsAccumulatorFlatAdapter;
use crate::physical_plan::groups_accumulator_prim_op::PrimitiveGroupsAccumulator;
use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
use crate::scalar::ScalarValue;
use arrow::compute;
Expand All @@ -49,6 +50,7 @@ use smallvec::SmallVec;
// State held by the SUM aggregate expression; every field is filled in by `Sum::new`.
pub struct Sum {
// Name supplied to the constructor (accepted as `impl Into<String>`).
name: String,
// Type of the summed result. Paired with `input_data_type` when choosing which
// groups accumulator to build.
data_type: DataType,
// Type of the values being summed; the constructor stores a clone of the reference
// it is given. Paired with `data_type` when choosing which groups accumulator to build.
input_data_type: DataType,
// Expression handed to the constructor -- presumably the argument being summed;
// confirm against the callers.
expr: Arc<dyn PhysicalExpr>,
// The constructor always sets this to `true`.
nullable: bool,
}
Expand Down Expand Up @@ -80,11 +82,16 @@ impl Sum {
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
data_type: DataType,
input_data_type: &DataType,
) -> Self {
// Note: in the actual caller data_type = sum_return_type(input_data_type), so strictly
// speaking one type param would be enough. We keep both (four params in total) anyway, to
// break symmetry with other accumulators and with any code that assumes 3 params, such as
// the generic_test_op macro.
Self {
name: name.into(),
expr,
data_type,
input_data_type: input_data_type.clone(),
nullable: true,
}
}
Expand Down Expand Up @@ -127,12 +134,147 @@ impl AggregateExpr for Sum {
fn create_groups_accumulator(
&self,
) -> arrow::error::Result<Option<Box<dyn GroupsAccumulator>>> {
let data_type = self.data_type.clone();
Ok(Some(Box::new(
GroupsAccumulatorFlatAdapter::<SumAccumulator>::new(move || {
SumAccumulator::try_new(&data_type)
}),
)))
use arrow::datatypes::ArrowPrimitiveType;

macro_rules! make_accumulator {
($T:ty, $U:ty) => {
Box::new(PrimitiveGroupsAccumulator::<$T, $U, _, _>::new(
&<$T as ArrowPrimitiveType>::DATA_TYPE,
|x: &mut <$T as ArrowPrimitiveType>::Native,
y: <$U as ArrowPrimitiveType>::Native| {
*x = *x + (y as <$T as ArrowPrimitiveType>::Native);
},
|x: &mut <$T as ArrowPrimitiveType>::Native,
y: <$T as ArrowPrimitiveType>::Native| {
*x = *x + y;
},
))
};
}

// Note that upstream uses x.add_wrapping(y) for the sum functions -- but here we just mimic
// the current datafusion Sum accumulator implementation using native +. (That native +
// specifically is the one in the expressions *x = *x + ... above.)
Ok(Some(match (&self.data_type, &self.input_data_type) {
(DataType::Int64, DataType::Int64) => make_accumulator!(
arrow::datatypes::Int64Type,
arrow::datatypes::Int64Type
),
(DataType::Int64, DataType::Int32) => make_accumulator!(
arrow::datatypes::Int64Type,
arrow::datatypes::Int32Type
),
(DataType::Int64, DataType::Int16) => make_accumulator!(
arrow::datatypes::Int64Type,
arrow::datatypes::Int16Type
),
(DataType::Int64, DataType::Int8) => {
make_accumulator!(arrow::datatypes::Int64Type, arrow::datatypes::Int8Type)
}

(DataType::Int96, DataType::Int96) => make_accumulator!(
arrow::datatypes::Int96Type,
arrow::datatypes::Int96Type
),

(DataType::Int64Decimal(0), DataType::Int64Decimal(0)) => make_accumulator!(
arrow::datatypes::Int64Decimal0Type,
arrow::datatypes::Int64Decimal0Type
),
(DataType::Int64Decimal(1), DataType::Int64Decimal(1)) => make_accumulator!(
arrow::datatypes::Int64Decimal1Type,
arrow::datatypes::Int64Decimal1Type
),
(DataType::Int64Decimal(2), DataType::Int64Decimal(2)) => make_accumulator!(
arrow::datatypes::Int64Decimal2Type,
arrow::datatypes::Int64Decimal2Type
),
(DataType::Int64Decimal(3), DataType::Int64Decimal(3)) => make_accumulator!(
arrow::datatypes::Int64Decimal3Type,
arrow::datatypes::Int64Decimal3Type
),
(DataType::Int64Decimal(4), DataType::Int64Decimal(4)) => make_accumulator!(
arrow::datatypes::Int64Decimal4Type,
arrow::datatypes::Int64Decimal4Type
),
(DataType::Int64Decimal(5), DataType::Int64Decimal(5)) => make_accumulator!(
arrow::datatypes::Int64Decimal5Type,
arrow::datatypes::Int64Decimal5Type
),
(DataType::Int64Decimal(10), DataType::Int64Decimal(10)) => {
make_accumulator!(
arrow::datatypes::Int64Decimal10Type,
arrow::datatypes::Int64Decimal10Type
)
}

(DataType::Int96Decimal(0), DataType::Int96Decimal(0)) => make_accumulator!(
arrow::datatypes::Int96Decimal0Type,
arrow::datatypes::Int96Decimal0Type
),
(DataType::Int96Decimal(1), DataType::Int96Decimal(1)) => make_accumulator!(
arrow::datatypes::Int96Decimal1Type,
arrow::datatypes::Int96Decimal1Type
),
(DataType::Int96Decimal(2), DataType::Int96Decimal(2)) => make_accumulator!(
arrow::datatypes::Int96Decimal2Type,
arrow::datatypes::Int96Decimal2Type
),
(DataType::Int96Decimal(3), DataType::Int96Decimal(3)) => make_accumulator!(
arrow::datatypes::Int96Decimal3Type,
arrow::datatypes::Int96Decimal3Type
),
(DataType::Int96Decimal(4), DataType::Int96Decimal(4)) => make_accumulator!(
arrow::datatypes::Int96Decimal4Type,
arrow::datatypes::Int96Decimal4Type
),
(DataType::Int96Decimal(5), DataType::Int96Decimal(5)) => make_accumulator!(
arrow::datatypes::Int96Decimal5Type,
arrow::datatypes::Int96Decimal5Type
),
(DataType::Int96Decimal(10), DataType::Int96Decimal(10)) => {
make_accumulator!(
arrow::datatypes::Int96Decimal10Type,
arrow::datatypes::Int96Decimal10Type
)
}

(DataType::UInt64, DataType::UInt64) => make_accumulator!(
arrow::datatypes::UInt64Type,
arrow::datatypes::UInt64Type
),
(DataType::UInt64, DataType::UInt32) => make_accumulator!(
arrow::datatypes::UInt64Type,
arrow::datatypes::UInt32Type
),
(DataType::UInt64, DataType::UInt16) => make_accumulator!(
arrow::datatypes::UInt64Type,
arrow::datatypes::UInt16Type
),
(DataType::UInt64, DataType::UInt8) => make_accumulator!(
arrow::datatypes::UInt64Type,
arrow::datatypes::UInt8Type
),

(DataType::Float32, DataType::Float32) => make_accumulator!(
arrow::datatypes::Float32Type,
arrow::datatypes::Float32Type
),
(DataType::Float64, DataType::Float64) => make_accumulator!(
arrow::datatypes::Float64Type,
arrow::datatypes::Float64Type
),

_ => {
// This case should never be reached because we've handled all sum_return_type
// arg_type values. Nonetheless:
let data_type = self.data_type.clone();

Box::new(GroupsAccumulatorFlatAdapter::<SumAccumulator>::new(
move || SumAccumulator::try_new(&data_type),
))
}
}))
}

fn name(&self) -> &str {
Expand Down Expand Up @@ -416,13 +558,27 @@ mod tests {
use arrow::datatypes::*;
use arrow::record_batch::RecordBatch;

// Test-only stand-in with the three-argument `new` shape that generic_test_op! calls,
// forwarding to the four-argument Sum::new.
struct SumTestStandin;
impl SumTestStandin {
    fn new(
        expr: Arc<dyn PhysicalExpr>,
        name: impl Into<String>,
        data_type: DataType,
    ) -> Sum {
        // Only one type is available here, so it is used as both the result type and
        // the input type.
        let input_data_type = data_type.clone();
        Sum::new(expr, name, data_type, &input_data_type)
    }
}

#[test]
fn sum_i32() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));

generic_test_op!(
a,
DataType::Int32,
Sum,
SumTestStandin,
ScalarValue::from(15i64),
DataType::Int64
)
Expand All @@ -440,7 +596,7 @@ mod tests {
generic_test_op!(
a,
DataType::Int32,
Sum,
SumTestStandin,
ScalarValue::from(13i64),
DataType::Int64
)
Expand All @@ -452,7 +608,7 @@ mod tests {
generic_test_op!(
a,
DataType::Int32,
Sum,
SumTestStandin,
ScalarValue::Int64(None),
DataType::Int64
)
Expand All @@ -465,7 +621,7 @@ mod tests {
generic_test_op!(
a,
DataType::UInt32,
Sum,
SumTestStandin,
ScalarValue::from(15u64),
DataType::UInt64
)
Expand All @@ -478,7 +634,7 @@ mod tests {
generic_test_op!(
a,
DataType::Float32,
Sum,
SumTestStandin,
ScalarValue::from(15_f32),
DataType::Float32
)
Expand All @@ -491,7 +647,7 @@ mod tests {
generic_test_op!(
a,
DataType::Float64,
Sum,
SumTestStandin,
ScalarValue::from(15_f64),
DataType::Float64
)
Expand Down
10 changes: 0 additions & 10 deletions datafusion/src/physical_plan/groups_accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
//! Vectorized [`GroupsAccumulator`]

use crate::error::{DataFusionError, Result};
use crate::scalar::ScalarValue;
use arrow::array::{ArrayRef, BooleanArray};
use smallvec::SmallVec;

/// From upstream: This replaces a datafusion_common::{not_impl_err} import.
macro_rules! not_impl_err {
Expand Down Expand Up @@ -194,10 +192,6 @@ pub trait GroupsAccumulator: Send {
/// `n`. See [`EmitTo::First`] for more details.
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef>;

// TODO: Remove this?
/// evaluate for a particular group index.
fn peek_evaluate(&self, group_index: usize) -> Result<ScalarValue>;

/// Returns the intermediate aggregate state for this accumulator,
/// used for multi-phase grouping, resetting its internal state.
///
Expand All @@ -216,10 +210,6 @@ pub trait GroupsAccumulator: Send {
/// [`Accumulator::state`]: crate::accumulator::Accumulator::state
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>>;

// TODO: Remove this?
/// Looks at the state for a particular group index.
fn peek_state(&self, group_index: usize) -> Result<SmallVec<[ScalarValue; 2]>>;

/// Merges intermediate state (the output from [`Self::state`])
/// into this accumulator's current state.
///
Expand Down
9 changes: 0 additions & 9 deletions datafusion/src/physical_plan/groups_accumulator_adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ use arrow::{
compute,
datatypes::UInt32Type,
};
use smallvec::SmallVec;

/// An adapter that implements [`GroupsAccumulator`] for any [`Accumulator`]
///
Expand Down Expand Up @@ -345,10 +344,6 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter {
result
}

fn peek_evaluate(&self, group_index: usize) -> Result<ScalarValue> {
self.states[group_index].accumulator.evaluate()
}

// filtered_null_mask(opt_filter, &values);
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
let vec_size_pre = self.states.allocated_size();
Expand Down Expand Up @@ -385,10 +380,6 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter {
Ok(arrays)
}

fn peek_state(&self, group_index: usize) -> Result<SmallVec<[ScalarValue; 2]>> {
self.states[group_index].accumulator.state()
}

fn merge_batch(
&mut self,
values: &[ArrayRef],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,10 +387,6 @@ impl<AccumulatorType: Accumulator> GroupsAccumulator
result
}

fn peek_evaluate(&self, group_index: usize) -> Result<ScalarValue> {
self.accumulators[group_index].evaluate()
}

// filtered_null_mask(opt_filter, &values);
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
let vec_size_pre = self.accumulators.allocated_size();
Expand Down Expand Up @@ -428,10 +424,6 @@ impl<AccumulatorType: Accumulator> GroupsAccumulator
Ok(arrays)
}

fn peek_state(&self, group_index: usize) -> Result<SmallVec<[ScalarValue; 2]>> {
self.accumulators[group_index].state()
}

fn merge_batch(
&mut self,
values: &[ArrayRef],
Expand Down
Loading
Loading