diff --git a/datafusion/src/physical_plan/expressions/sum.rs b/datafusion/src/physical_plan/expressions/sum.rs index 6396b92e7a2f3..aaa3c859450e6 100644 --- a/datafusion/src/physical_plan/expressions/sum.rs +++ b/datafusion/src/physical_plan/expressions/sum.rs @@ -137,50 +137,133 @@ impl AggregateExpr for Sum { use arrow::datatypes::ArrowPrimitiveType; macro_rules! make_accumulator { - ($T:ty, $U:ty) => { Box::new(PrimitiveGroupsAccumulator::< - $T, - $U, - _, - _, - >::new(& <$T as ArrowPrimitiveType>::DATA_TYPE, |x: &mut <$T as ArrowPrimitiveType>::Native, y: <$U as ArrowPrimitiveType>::Native| { - *x = *x + (y as <$T as ArrowPrimitiveType>::Native); - }, |x: &mut <$T as ArrowPrimitiveType>::Native, y: <$T as ArrowPrimitiveType>::Native| { *x = *x + y; })) }; + ($T:ty, $U:ty) => { + Box::new(PrimitiveGroupsAccumulator::<$T, $U, _, _>::new( + &<$T as ArrowPrimitiveType>::DATA_TYPE, + |x: &mut <$T as ArrowPrimitiveType>::Native, + y: <$U as ArrowPrimitiveType>::Native| { + *x = *x + (y as <$T as ArrowPrimitiveType>::Native); + }, + |x: &mut <$T as ArrowPrimitiveType>::Native, + y: <$T as ArrowPrimitiveType>::Native| { + *x = *x + y; + }, + )) + }; } // Note that upstream uses x.add_wrapping(y) for the sum functions -- but here we just mimic // the current datafusion Sum accumulator implementation using native +. (That native + // specifically is the one in the expressions *x = *x + ... above.) Ok(Some(match (&self.data_type, &self.input_data_type) { - (DataType::Int64, DataType::Int64) => make_accumulator!(arrow::datatypes::Int64Type, arrow::datatypes::Int64Type), - (DataType::Int64, DataType::Int32) => make_accumulator!(arrow::datatypes::Int64Type, arrow::datatypes::Int32Type), - (DataType::Int64, DataType::Int16) => make_accumulator!(arrow::datatypes::Int64Type, arrow::datatypes::Int16Type), - (DataType::Int64, DataType::Int8) => make_accumulator!(arrow::datatypes::Int64Type, arrow::datatypes::Int8Type), - - (DataType::Int96, DataType::Int96) => make_accumulator!(arrow::datatypes::Int96Type, arrow::datatypes::Int96Type), - - (DataType::Int64Decimal(0), DataType::Int64Decimal(0)) => make_accumulator!(arrow::datatypes::Int64Decimal0Type, arrow::datatypes::Int64Decimal0Type), - (DataType::Int64Decimal(1), DataType::Int64Decimal(1)) => make_accumulator!(arrow::datatypes::Int64Decimal1Type, arrow::datatypes::Int64Decimal1Type), - (DataType::Int64Decimal(2), DataType::Int64Decimal(2)) => make_accumulator!(arrow::datatypes::Int64Decimal2Type, arrow::datatypes::Int64Decimal2Type), - (DataType::Int64Decimal(3), DataType::Int64Decimal(3)) => make_accumulator!(arrow::datatypes::Int64Decimal3Type, arrow::datatypes::Int64Decimal3Type), - (DataType::Int64Decimal(4), DataType::Int64Decimal(4)) => make_accumulator!(arrow::datatypes::Int64Decimal4Type, arrow::datatypes::Int64Decimal4Type), - (DataType::Int64Decimal(5), DataType::Int64Decimal(5)) => make_accumulator!(arrow::datatypes::Int64Decimal5Type, arrow::datatypes::Int64Decimal5Type), - (DataType::Int64Decimal(10), DataType::Int64Decimal(10)) => make_accumulator!(arrow::datatypes::Int64Decimal10Type, arrow::datatypes::Int64Decimal10Type), - - (DataType::Int96Decimal(0), DataType::Int96Decimal(0)) => make_accumulator!(arrow::datatypes::Int96Decimal0Type, arrow::datatypes::Int96Decimal0Type), - (DataType::Int96Decimal(1), DataType::Int96Decimal(1)) => make_accumulator!(arrow::datatypes::Int96Decimal1Type, arrow::datatypes::Int96Decimal1Type), - (DataType::Int96Decimal(2), DataType::Int96Decimal(2)) => make_accumulator!(arrow::datatypes::Int96Decimal2Type, arrow::datatypes::Int96Decimal2Type), - (DataType::Int96Decimal(3), DataType::Int96Decimal(3)) => make_accumulator!(arrow::datatypes::Int96Decimal3Type, arrow::datatypes::Int96Decimal3Type), - (DataType::Int96Decimal(4), DataType::Int96Decimal(4)) => make_accumulator!(arrow::datatypes::Int96Decimal4Type, arrow::datatypes::Int96Decimal4Type), - (DataType::Int96Decimal(5), DataType::Int96Decimal(5)) => make_accumulator!(arrow::datatypes::Int96Decimal5Type, arrow::datatypes::Int96Decimal5Type), - (DataType::Int96Decimal(10), DataType::Int96Decimal(10)) => make_accumulator!(arrow::datatypes::Int96Decimal10Type, arrow::datatypes::Int96Decimal10Type), - - (DataType::UInt64, DataType::UInt64) => make_accumulator!(arrow::datatypes::UInt64Type, arrow::datatypes::UInt64Type), - (DataType::UInt64, DataType::UInt32) => make_accumulator!(arrow::datatypes::UInt64Type, arrow::datatypes::UInt32Type), - (DataType::UInt64, DataType::UInt16) => make_accumulator!(arrow::datatypes::UInt64Type, arrow::datatypes::UInt16Type), - (DataType::UInt64, DataType::UInt8) => make_accumulator!(arrow::datatypes::UInt64Type, arrow::datatypes::UInt8Type), - - (DataType::Float32, DataType::Float32) => make_accumulator!(arrow::datatypes::Float32Type, arrow::datatypes::Float32Type), - (DataType::Float64, DataType::Float64) => make_accumulator!(arrow::datatypes::Float64Type, arrow::datatypes::Float64Type), + (DataType::Int64, DataType::Int64) => make_accumulator!( + arrow::datatypes::Int64Type, + arrow::datatypes::Int64Type + ), + (DataType::Int64, DataType::Int32) => make_accumulator!( + arrow::datatypes::Int64Type, + arrow::datatypes::Int32Type + ), + (DataType::Int64, DataType::Int16) => make_accumulator!( + arrow::datatypes::Int64Type, + arrow::datatypes::Int16Type + ), + (DataType::Int64, DataType::Int8) => { + make_accumulator!(arrow::datatypes::Int64Type, arrow::datatypes::Int8Type) + } + + (DataType::Int96, DataType::Int96) => make_accumulator!( + arrow::datatypes::Int96Type, + arrow::datatypes::Int96Type + ), + + (DataType::Int64Decimal(0), DataType::Int64Decimal(0)) => make_accumulator!( + arrow::datatypes::Int64Decimal0Type, + arrow::datatypes::Int64Decimal0Type + ), + (DataType::Int64Decimal(1), DataType::Int64Decimal(1)) => make_accumulator!( + arrow::datatypes::Int64Decimal1Type, + arrow::datatypes::Int64Decimal1Type + ), + (DataType::Int64Decimal(2), DataType::Int64Decimal(2)) => make_accumulator!( + arrow::datatypes::Int64Decimal2Type, + arrow::datatypes::Int64Decimal2Type + ), + (DataType::Int64Decimal(3), DataType::Int64Decimal(3)) => make_accumulator!( + arrow::datatypes::Int64Decimal3Type, + arrow::datatypes::Int64Decimal3Type + ), + (DataType::Int64Decimal(4), DataType::Int64Decimal(4)) => make_accumulator!( + arrow::datatypes::Int64Decimal4Type, + arrow::datatypes::Int64Decimal4Type + ), + (DataType::Int64Decimal(5), DataType::Int64Decimal(5)) => make_accumulator!( + arrow::datatypes::Int64Decimal5Type, + arrow::datatypes::Int64Decimal5Type + ), + (DataType::Int64Decimal(10), DataType::Int64Decimal(10)) => { + make_accumulator!( + arrow::datatypes::Int64Decimal10Type, + arrow::datatypes::Int64Decimal10Type + ) + } + + (DataType::Int96Decimal(0), DataType::Int96Decimal(0)) => make_accumulator!( + arrow::datatypes::Int96Decimal0Type, + arrow::datatypes::Int96Decimal0Type + ), + (DataType::Int96Decimal(1), DataType::Int96Decimal(1)) => make_accumulator!( + arrow::datatypes::Int96Decimal1Type, + arrow::datatypes::Int96Decimal1Type + ), + (DataType::Int96Decimal(2), DataType::Int96Decimal(2)) => make_accumulator!( + arrow::datatypes::Int96Decimal2Type, + arrow::datatypes::Int96Decimal2Type + ), + (DataType::Int96Decimal(3), DataType::Int96Decimal(3)) => make_accumulator!( + arrow::datatypes::Int96Decimal3Type, + arrow::datatypes::Int96Decimal3Type + ), + (DataType::Int96Decimal(4), DataType::Int96Decimal(4)) => make_accumulator!( + arrow::datatypes::Int96Decimal4Type, + arrow::datatypes::Int96Decimal4Type + ), + (DataType::Int96Decimal(5), DataType::Int96Decimal(5)) => make_accumulator!( + arrow::datatypes::Int96Decimal5Type, + arrow::datatypes::Int96Decimal5Type + ), + (DataType::Int96Decimal(10), DataType::Int96Decimal(10)) => { + make_accumulator!( + arrow::datatypes::Int96Decimal10Type, + arrow::datatypes::Int96Decimal10Type + ) + } + + (DataType::UInt64, DataType::UInt64) => make_accumulator!( + arrow::datatypes::UInt64Type, + arrow::datatypes::UInt64Type + ), + (DataType::UInt64, DataType::UInt32) => make_accumulator!( + arrow::datatypes::UInt64Type, + arrow::datatypes::UInt32Type + ), + (DataType::UInt64, DataType::UInt16) => make_accumulator!( + arrow::datatypes::UInt64Type, + arrow::datatypes::UInt16Type + ), + (DataType::UInt64, DataType::UInt8) => make_accumulator!( + arrow::datatypes::UInt64Type, + arrow::datatypes::UInt8Type + ), + + (DataType::Float32, DataType::Float32) => make_accumulator!( + arrow::datatypes::Float32Type, + arrow::datatypes::Float32Type + ), + (DataType::Float64, DataType::Float64) => make_accumulator!( + arrow::datatypes::Float64Type, + arrow::datatypes::Float64Type + ), _ => { // This case should never be reached because we've handled all sum_return_type @@ -479,9 +562,11 @@ mod tests { // generic_test_op!. struct SumTestStandin; impl SumTestStandin { - fn new(expr: Arc, - name: impl Into, - data_type: DataType) -> Sum { + fn new( + expr: Arc, + name: impl Into, + data_type: DataType, + ) -> Sum { Sum::new(expr, name, data_type.clone(), &data_type) } } diff --git a/datafusion/src/physical_plan/groups_accumulator.rs b/datafusion/src/physical_plan/groups_accumulator.rs index 3ea1e0665e4c7..795ba3bcc5192 100644 --- a/datafusion/src/physical_plan/groups_accumulator.rs +++ b/datafusion/src/physical_plan/groups_accumulator.rs @@ -18,9 +18,7 @@ //! Vectorized [`GroupsAccumulator`] use crate::error::{DataFusionError, Result}; -use crate::scalar::ScalarValue; use arrow::array::{ArrayRef, BooleanArray}; -use smallvec::SmallVec; /// From upstream: This replaces a datafusion_common::{not_impl_err} import. macro_rules! not_impl_err { diff --git a/datafusion/src/physical_plan/groups_accumulator_adapter.rs b/datafusion/src/physical_plan/groups_accumulator_adapter.rs index 35b35003c9a16..ceccc46bbaf58 100644 --- a/datafusion/src/physical_plan/groups_accumulator_adapter.rs +++ b/datafusion/src/physical_plan/groups_accumulator_adapter.rs @@ -33,7 +33,6 @@ use arrow::{ compute, datatypes::UInt32Type, }; -use smallvec::SmallVec; /// An adapter that implements [`GroupsAccumulator`] for any [`Accumulator`] /// diff --git a/datafusion/src/physical_plan/groups_accumulator_prim_op.rs b/datafusion/src/physical_plan/groups_accumulator_prim_op.rs index 36e9e2a6ce964..358a91a285a63 100644 --- a/datafusion/src/physical_plan/groups_accumulator_prim_op.rs +++ b/datafusion/src/physical_plan/groups_accumulator_prim_op.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! PrimitiveGroupsAccumulator + use std::any::type_name; use std::marker::PhantomData; use std::mem::size_of; @@ -97,7 +99,10 @@ where } /// Helper for update_batch and merge_batch -- (V, H) is either (T, G) or (U, F) respectively. - fn update_or_merge_batch( + fn update_or_merge_batch< + V: ArrowPrimitiveType + Send, + H: Fn(&mut T::Native, V::Native) + Send + Sync, + >( &mut self, values: &[ArrayRef], group_indices: &[usize], @@ -153,14 +158,19 @@ where total_num_groups: usize, ) -> Result<()> { // update / merge are almost the same except we're adding T's vs. adding U's. - self.update_or_merge_batch::(values, group_indices, opt_filter, total_num_groups, self.update_fn) + self.update_or_merge_batch::( + values, + group_indices, + opt_filter, + total_num_groups, + self.update_fn, + ) } fn evaluate(&mut self, emit_to: EmitTo) -> Result { let values = emit_to.take_needed(&mut self.values); let nulls = self.null_state.build(emit_to); - let buffers = vec![Buffer::from_slice_ref(&values)]; // TODO: This copies. Ideally, don't. Note: Avoiding this memcpy has minimal performance impact. let data = ArrayData::new( @@ -168,7 +178,7 @@ where values.len(), None, Some(nulls.into_buffer()), - 0, /* offset */ + 0, /* offset */ buffers, vec![], ); @@ -187,7 +197,13 @@ where total_num_groups: usize, ) -> Result<()> { // update / merge are almost the same except we're adding T's vs. adding U's. - self.update_or_merge_batch::(values, group_indices, opt_filter, total_num_groups, self.merge_fn) + self.update_or_merge_batch::( + values, + group_indices, + opt_filter, + total_num_groups, + self.merge_fn, + ) } /// Converts an input batch directly to a state batch diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs index 2d97325a04845..d549855b937c8 100644 --- a/datafusion/src/physical_plan/hash_aggregate.rs +++ b/datafusion/src/physical_plan/hash_aggregate.rs @@ -1174,8 +1174,6 @@ pub(crate) fn create_batch_from_map( // 4. collect all in a vector per key of vec, vec[i][j] // 5. concatenate the arrays over the second index [j] into a single vec. - let mut key_columns: Vec> = Vec::with_capacity(num_group_expr); - let key_columns: Vec> = write_group_result_rows_for_keys( &accumulation_state.flattened_group_by_values, accumulation_state.next_group_index,