From 39b251ba24bf76fb3ebd3759b522c02a7ab30163 Mon Sep 17 00:00:00 2001 From: Sam Hughes Date: Sun, 24 Nov 2024 23:40:48 -0800 Subject: [PATCH 1/2] perf: Use flattened_group_by_values to accumulate group keys for output --- datafusion/src/physical_plan/hash_aggregate.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs index a27fe6ec6a2c..895c34082fc9 100644 --- a/datafusion/src/physical_plan/hash_aggregate.rs +++ b/datafusion/src/physical_plan/hash_aggregate.rs @@ -463,16 +463,15 @@ pub(crate) fn group_aggregate_batch( // 1.2 .or_insert_with(|| { batch_keys.append_value(&key).expect("must not fail"); + // Note that we still use plain String objects in GroupByScalar. Thus flattened_group_by_values isn't that great. let _ = create_group_by_values(&group_values, row, &mut group_by_values); - let mut taken_values = - smallvec![GroupByScalar::UInt32(0); group_values.len()]; - std::mem::swap(&mut taken_values, &mut group_by_values); + accumulation_state.flattened_group_by_values.extend( + group_by_values.iter_mut().map(|x| std::mem::replace(x, GroupByScalar::UInt32(0)))); let group_index = accumulation_state.next_group_index; accumulation_state.next_group_index += 1; ( key.clone(), AccumulationGroupState { - group_by_values: taken_values, indices: smallvec![row as u32], group_index, }, @@ -884,7 +883,6 @@ pub type Accumulators = HashMap; #[allow(missing_docs)] pub struct AccumulationGroupState { - group_by_values: SmallVec<[GroupByScalar; 2]>, indices: SmallVec<[u32; 4]>, group_index: usize, } @@ -893,6 +891,8 @@ pub struct AccumulationGroupState { #[derive(Default)] pub struct AccumulationState { accumulators: HashMap, + // Of length accumulators.len() * N where N is the number of group by columns. + flattened_group_by_values: Vec, groups_accumulators: Vec>, // For now, always equal to accumulators.len() next_group_index: usize, @@ -905,6 +905,7 @@ impl AccumulationState { ) -> AccumulationState { AccumulationState { accumulators: HashMap::new(), + flattened_group_by_values: Vec::new(), groups_accumulators, next_group_index: 0, } @@ -1174,12 +1175,13 @@ pub(crate) fn create_batch_from_map( for ( _, AccumulationGroupState { - group_by_values, group_index, .. }, ) in &accumulation_state.accumulators { + let group_by_values: &[GroupByScalar] = &accumulation_state.flattened_group_by_values[num_group_expr * group_index..num_group_expr * (group_index + 1)]; + // 2 and 3. 
write_group_result_row_with_groups_accumulator( *mode, From 594d1ffd6ecd7a9189ec2884708585bc6641fe69 Mon Sep 17 00:00:00 2001 From: Sam Hughes Date: Sun, 10 Nov 2024 12:32:47 -0800 Subject: [PATCH 2/2] perf: Make Sum use PrimitiveGroupsAccumulator --- datafusion/src/cube_ext/joinagg.rs | 5 +- datafusion/src/physical_plan/aggregates.rs | 2 +- .../src/physical_plan/expressions/sum.rs | 180 +++- .../src/physical_plan/groups_accumulator.rs | 10 - .../groups_accumulator_adapter.rs | 9 - .../groups_accumulator_flat_adapter.rs | 8 - .../groups_accumulator_prim_op.rs | 343 +++++++ .../src/physical_plan/hash_aggregate.rs | 174 ++-- datafusion/src/physical_plan/mod.rs | 2 + datafusion/src/physical_plan/null_state.rs | 949 ++++++++++++++++++ 10 files changed, 1549 insertions(+), 133 deletions(-) create mode 100644 datafusion/src/physical_plan/groups_accumulator_prim_op.rs create mode 100644 datafusion/src/physical_plan/null_state.rs diff --git a/datafusion/src/cube_ext/joinagg.rs b/datafusion/src/cube_ext/joinagg.rs index 2324398bcb46..4dac971ca260 100644 --- a/datafusion/src/cube_ext/joinagg.rs +++ b/datafusion/src/cube_ext/joinagg.rs @@ -245,7 +245,8 @@ impl ExecutionPlan for CrossJoinAggExec { &AggregateMode::Full, self.group_expr.len(), )?; - let mut accumulators = create_accumulation_state(&self.agg_expr)?; + let mut accumulators: hash_aggregate::AccumulationState = + create_accumulation_state(&self.agg_expr)?; for partition in 0..self.join.right.output_partitioning().partition_count() { let mut batches = self.join.right.execute(partition).await?; while let Some(right) = batches.next().await { @@ -273,7 +274,7 @@ impl ExecutionPlan for CrossJoinAggExec { let out_schema = self.schema.clone(); let r = hash_aggregate::create_batch_from_map( &AggregateMode::Full, - &accumulators, + accumulators, self.group_expr.len(), &out_schema, )?; diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index ed9a67d6625b..de86f4c97c90 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -144,7 +144,7 @@ pub fn create_aggregate_expr( )) } (AggregateFunction::Sum, false) => { - Arc::new(expressions::Sum::new(arg, name, return_type)) + Arc::new(expressions::Sum::new(arg, name, return_type, &arg_types[0])) } (AggregateFunction::Sum, true) => { return Err(DataFusionError::NotImplemented( diff --git a/datafusion/src/physical_plan/expressions/sum.rs b/datafusion/src/physical_plan/expressions/sum.rs index 958817ccef6d..aaa3c859450e 100644 --- a/datafusion/src/physical_plan/expressions/sum.rs +++ b/datafusion/src/physical_plan/expressions/sum.rs @@ -24,6 +24,7 @@ use std::sync::Arc; use crate::error::{DataFusionError, Result}; use crate::physical_plan::groups_accumulator::GroupsAccumulator; use crate::physical_plan::groups_accumulator_flat_adapter::GroupsAccumulatorFlatAdapter; +use crate::physical_plan::groups_accumulator_prim_op::PrimitiveGroupsAccumulator; use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; use crate::scalar::ScalarValue; use arrow::compute; @@ -49,6 +50,7 @@ use smallvec::SmallVec; pub struct Sum { name: String, data_type: DataType, + input_data_type: DataType, expr: Arc, nullable: bool, } @@ -80,11 +82,16 @@ impl Sum { expr: Arc, name: impl Into, data_type: DataType, + input_data_type: &DataType, ) -> Self { + // Note: data_type = sum_return_type(input_data_type) in the actual caller, so we don't + // really need two params. 
But, we keep the four params to break symmetry with other + // accumulators and any code that might use 3 params, such as the generic_test_op macro. Self { name: name.into(), expr, data_type, + input_data_type: input_data_type.clone(), nullable: true, } } @@ -127,12 +134,147 @@ impl AggregateExpr for Sum { fn create_groups_accumulator( &self, ) -> arrow::error::Result>> { - let data_type = self.data_type.clone(); - Ok(Some(Box::new( - GroupsAccumulatorFlatAdapter::::new(move || { - SumAccumulator::try_new(&data_type) - }), - ))) + use arrow::datatypes::ArrowPrimitiveType; + + macro_rules! make_accumulator { + ($T:ty, $U:ty) => { + Box::new(PrimitiveGroupsAccumulator::<$T, $U, _, _>::new( + &<$T as ArrowPrimitiveType>::DATA_TYPE, + |x: &mut <$T as ArrowPrimitiveType>::Native, + y: <$U as ArrowPrimitiveType>::Native| { + *x = *x + (y as <$T as ArrowPrimitiveType>::Native); + }, + |x: &mut <$T as ArrowPrimitiveType>::Native, + y: <$T as ArrowPrimitiveType>::Native| { + *x = *x + y; + }, + )) + }; + } + + // Note that upstream uses x.add_wrapping(y) for the sum functions -- but here we just mimic + // the current datafusion Sum accumulator implementation using native +. (That native + + // specifically is the one in the expressions *x = *x + ... above.) + Ok(Some(match (&self.data_type, &self.input_data_type) { + (DataType::Int64, DataType::Int64) => make_accumulator!( + arrow::datatypes::Int64Type, + arrow::datatypes::Int64Type + ), + (DataType::Int64, DataType::Int32) => make_accumulator!( + arrow::datatypes::Int64Type, + arrow::datatypes::Int32Type + ), + (DataType::Int64, DataType::Int16) => make_accumulator!( + arrow::datatypes::Int64Type, + arrow::datatypes::Int16Type + ), + (DataType::Int64, DataType::Int8) => { + make_accumulator!(arrow::datatypes::Int64Type, arrow::datatypes::Int8Type) + } + + (DataType::Int96, DataType::Int96) => make_accumulator!( + arrow::datatypes::Int96Type, + arrow::datatypes::Int96Type + ), + + (DataType::Int64Decimal(0), DataType::Int64Decimal(0)) => make_accumulator!( + arrow::datatypes::Int64Decimal0Type, + arrow::datatypes::Int64Decimal0Type + ), + (DataType::Int64Decimal(1), DataType::Int64Decimal(1)) => make_accumulator!( + arrow::datatypes::Int64Decimal1Type, + arrow::datatypes::Int64Decimal1Type + ), + (DataType::Int64Decimal(2), DataType::Int64Decimal(2)) => make_accumulator!( + arrow::datatypes::Int64Decimal2Type, + arrow::datatypes::Int64Decimal2Type + ), + (DataType::Int64Decimal(3), DataType::Int64Decimal(3)) => make_accumulator!( + arrow::datatypes::Int64Decimal3Type, + arrow::datatypes::Int64Decimal3Type + ), + (DataType::Int64Decimal(4), DataType::Int64Decimal(4)) => make_accumulator!( + arrow::datatypes::Int64Decimal4Type, + arrow::datatypes::Int64Decimal4Type + ), + (DataType::Int64Decimal(5), DataType::Int64Decimal(5)) => make_accumulator!( + arrow::datatypes::Int64Decimal5Type, + arrow::datatypes::Int64Decimal5Type + ), + (DataType::Int64Decimal(10), DataType::Int64Decimal(10)) => { + make_accumulator!( + arrow::datatypes::Int64Decimal10Type, + arrow::datatypes::Int64Decimal10Type + ) + } + + (DataType::Int96Decimal(0), DataType::Int96Decimal(0)) => make_accumulator!( + arrow::datatypes::Int96Decimal0Type, + arrow::datatypes::Int96Decimal0Type + ), + (DataType::Int96Decimal(1), DataType::Int96Decimal(1)) => make_accumulator!( + arrow::datatypes::Int96Decimal1Type, + arrow::datatypes::Int96Decimal1Type + ), + (DataType::Int96Decimal(2), DataType::Int96Decimal(2)) => make_accumulator!( + arrow::datatypes::Int96Decimal2Type, + 
arrow::datatypes::Int96Decimal2Type + ), + (DataType::Int96Decimal(3), DataType::Int96Decimal(3)) => make_accumulator!( + arrow::datatypes::Int96Decimal3Type, + arrow::datatypes::Int96Decimal3Type + ), + (DataType::Int96Decimal(4), DataType::Int96Decimal(4)) => make_accumulator!( + arrow::datatypes::Int96Decimal4Type, + arrow::datatypes::Int96Decimal4Type + ), + (DataType::Int96Decimal(5), DataType::Int96Decimal(5)) => make_accumulator!( + arrow::datatypes::Int96Decimal5Type, + arrow::datatypes::Int96Decimal5Type + ), + (DataType::Int96Decimal(10), DataType::Int96Decimal(10)) => { + make_accumulator!( + arrow::datatypes::Int96Decimal10Type, + arrow::datatypes::Int96Decimal10Type + ) + } + + (DataType::UInt64, DataType::UInt64) => make_accumulator!( + arrow::datatypes::UInt64Type, + arrow::datatypes::UInt64Type + ), + (DataType::UInt64, DataType::UInt32) => make_accumulator!( + arrow::datatypes::UInt64Type, + arrow::datatypes::UInt32Type + ), + (DataType::UInt64, DataType::UInt16) => make_accumulator!( + arrow::datatypes::UInt64Type, + arrow::datatypes::UInt16Type + ), + (DataType::UInt64, DataType::UInt8) => make_accumulator!( + arrow::datatypes::UInt64Type, + arrow::datatypes::UInt8Type + ), + + (DataType::Float32, DataType::Float32) => make_accumulator!( + arrow::datatypes::Float32Type, + arrow::datatypes::Float32Type + ), + (DataType::Float64, DataType::Float64) => make_accumulator!( + arrow::datatypes::Float64Type, + arrow::datatypes::Float64Type + ), + + _ => { + // This case should never be reached because we've handled all sum_return_type + // arg_type values. Nonetheless: + let data_type = self.data_type.clone(); + + Box::new(GroupsAccumulatorFlatAdapter::::new( + move || SumAccumulator::try_new(&data_type), + )) + } + })) } fn name(&self) -> &str { @@ -416,13 +558,27 @@ mod tests { use arrow::datatypes::*; use arrow::record_batch::RecordBatch; + // A wrapper to make Sum::new, which now has an input_type argument, work with + // generic_test_op!. + struct SumTestStandin; + impl SumTestStandin { + fn new( + expr: Arc, + name: impl Into, + data_type: DataType, + ) -> Sum { + Sum::new(expr, name, data_type.clone(), &data_type) + } + } + #[test] fn sum_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); + generic_test_op!( a, DataType::Int32, - Sum, + SumTestStandin, ScalarValue::from(15i64), DataType::Int64 ) @@ -440,7 +596,7 @@ mod tests { generic_test_op!( a, DataType::Int32, - Sum, + SumTestStandin, ScalarValue::from(13i64), DataType::Int64 ) @@ -452,7 +608,7 @@ mod tests { generic_test_op!( a, DataType::Int32, - Sum, + SumTestStandin, ScalarValue::Int64(None), DataType::Int64 ) @@ -465,7 +621,7 @@ mod tests { generic_test_op!( a, DataType::UInt32, - Sum, + SumTestStandin, ScalarValue::from(15u64), DataType::UInt64 ) @@ -478,7 +634,7 @@ mod tests { generic_test_op!( a, DataType::Float32, - Sum, + SumTestStandin, ScalarValue::from(15_f32), DataType::Float32 ) @@ -491,7 +647,7 @@ mod tests { generic_test_op!( a, DataType::Float64, - Sum, + SumTestStandin, ScalarValue::from(15_f64), DataType::Float64 ) diff --git a/datafusion/src/physical_plan/groups_accumulator.rs b/datafusion/src/physical_plan/groups_accumulator.rs index dfd8aaa5493c..795ba3bcc519 100644 --- a/datafusion/src/physical_plan/groups_accumulator.rs +++ b/datafusion/src/physical_plan/groups_accumulator.rs @@ -18,9 +18,7 @@ //! 
Vectorized [`GroupsAccumulator`] use crate::error::{DataFusionError, Result}; -use crate::scalar::ScalarValue; use arrow::array::{ArrayRef, BooleanArray}; -use smallvec::SmallVec; /// From upstream: This replaces a datafusion_common::{not_impl_err} import. macro_rules! not_impl_err { @@ -194,10 +192,6 @@ pub trait GroupsAccumulator: Send { /// `n`. See [`EmitTo::First`] for more details. fn evaluate(&mut self, emit_to: EmitTo) -> Result; - // TODO: Remove this? - /// evaluate for a particular group index. - fn peek_evaluate(&self, group_index: usize) -> Result; - /// Returns the intermediate aggregate state for this accumulator, /// used for multi-phase grouping, resetting its internal state. /// @@ -216,10 +210,6 @@ pub trait GroupsAccumulator: Send { /// [`Accumulator::state`]: crate::accumulator::Accumulator::state fn state(&mut self, emit_to: EmitTo) -> Result>; - // TODO: Remove this? - /// Looks at the state for a particular group index. - fn peek_state(&self, group_index: usize) -> Result>; - /// Merges intermediate state (the output from [`Self::state`]) /// into this accumulator's current state. /// diff --git a/datafusion/src/physical_plan/groups_accumulator_adapter.rs b/datafusion/src/physical_plan/groups_accumulator_adapter.rs index 5b2f62f8c9e5..ceccc46bbaf5 100644 --- a/datafusion/src/physical_plan/groups_accumulator_adapter.rs +++ b/datafusion/src/physical_plan/groups_accumulator_adapter.rs @@ -33,7 +33,6 @@ use arrow::{ compute, datatypes::UInt32Type, }; -use smallvec::SmallVec; /// An adapter that implements [`GroupsAccumulator`] for any [`Accumulator`] /// @@ -345,10 +344,6 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { result } - fn peek_evaluate(&self, group_index: usize) -> Result { - self.states[group_index].accumulator.evaluate() - } - // filtered_null_mask(opt_filter, &values); fn state(&mut self, emit_to: EmitTo) -> Result> { let vec_size_pre = self.states.allocated_size(); @@ -385,10 +380,6 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { Ok(arrays) } - fn peek_state(&self, group_index: usize) -> Result> { - self.states[group_index].accumulator.state() - } - fn merge_batch( &mut self, values: &[ArrayRef], diff --git a/datafusion/src/physical_plan/groups_accumulator_flat_adapter.rs b/datafusion/src/physical_plan/groups_accumulator_flat_adapter.rs index 611e3c259e98..a659aafb9472 100644 --- a/datafusion/src/physical_plan/groups_accumulator_flat_adapter.rs +++ b/datafusion/src/physical_plan/groups_accumulator_flat_adapter.rs @@ -387,10 +387,6 @@ impl GroupsAccumulator result } - fn peek_evaluate(&self, group_index: usize) -> Result { - self.accumulators[group_index].evaluate() - } - // filtered_null_mask(opt_filter, &values); fn state(&mut self, emit_to: EmitTo) -> Result> { let vec_size_pre = self.accumulators.allocated_size(); @@ -428,10 +424,6 @@ impl GroupsAccumulator Ok(arrays) } - fn peek_state(&self, group_index: usize) -> Result> { - self.accumulators[group_index].state() - } - fn merge_batch( &mut self, values: &[ArrayRef], diff --git a/datafusion/src/physical_plan/groups_accumulator_prim_op.rs b/datafusion/src/physical_plan/groups_accumulator_prim_op.rs new file mode 100644 index 000000000000..358a91a285a6 --- /dev/null +++ b/datafusion/src/physical_plan/groups_accumulator_prim_op.rs @@ -0,0 +1,343 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! PrimitiveGroupsAccumulator
+
+use std::any::type_name;
+use std::marker::PhantomData;
+use std::mem::size_of;
+use std::sync::Arc;
+
+use crate::error::{DataFusionError, Result};
+use crate::physical_plan::groups_accumulator::{EmitTo, GroupsAccumulator};
+use arrow::array::{
+    ArrayData, ArrayRef, BooleanArray, BooleanBufferBuilder, PrimitiveArray,
+};
+use arrow::bitmap::Bitmap;
+use arrow::buffer::Buffer;
+use arrow::datatypes::ArrowPrimitiveType;
+use arrow::datatypes::DataType;
+
+use crate::physical_plan::null_state::NullState;
+
+/// An accumulator that implements a single operation over
+/// [`ArrowPrimitiveType`], accumulating inputs of primitive type `U`
+/// into per-group states of primitive type `T` (such as `Sum`, which
+/// may for example sum `Int32` inputs into an `Int64` state)
+///
+/// F: The update function to apply to two elements. The first argument is
+/// the existing state value and should be updated with the second (input)
+/// value (e.g. [`BitAndAssign`] style). G is the analogous merge function,
+/// applied to two `T::Native` state values.
+///
+/// [`BitAndAssign`]: std::ops::BitAndAssign
+#[derive(Debug)]
+pub struct PrimitiveGroupsAccumulator<T, U, F, G>
+where
+    T: ArrowPrimitiveType + Send,
+    U: ArrowPrimitiveType + Send,
+    F: Fn(&mut T::Native, U::Native) + Send + Sync + Copy,
+    G: Fn(&mut T::Native, T::Native) + Send + Sync + Copy,
+{
+    /// values per group, stored as the native type
+    values: Vec<T::Native>,
+
+    /// The output type (needed for Decimal precision and scale)
+    data_type: DataType,
+
+    /// The starting value for new groups
+    starting_value: T::Native,
+
+    /// Track nulls in the input / filters
+    null_state: NullState,
+
+    /// Function that computes the update result from input values
+    update_fn: F,
+
+    /// Function that computes the merge result from state values
+    merge_fn: G,
+
+    _marker: std::marker::PhantomData<U>,
+}
+
+impl<T, U, F, G> PrimitiveGroupsAccumulator<T, U, F, G>
+where
+    T: ArrowPrimitiveType + Send,
+    U: ArrowPrimitiveType + Send,
+    F: Fn(&mut T::Native, U::Native) + Send + Sync + Copy,
+    G: Fn(&mut T::Native, T::Native) + Send + Sync + Copy,
+{
+    #[allow(missing_docs)]
+    pub fn new(data_type: &DataType, update_fn: F, merge_fn: G) -> Self {
+        Self {
+            values: vec![],
+            data_type: data_type.clone(),
+            null_state: NullState::new(),
+            starting_value: T::default_value(),
+            update_fn,
+            merge_fn,
+            _marker: PhantomData,
+        }
+    }
+
+    /// Set the starting values for new groups
+    pub fn with_starting_value(mut self, starting_value: T::Native) -> Self {
+        self.starting_value = starting_value;
+        self
+    }
+
+    /// Helper for update_batch and merge_batch -- (V, H) is (U, F) for
+    /// update_batch and (T, G) for merge_batch.
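+    ///
+    /// A sketch of the two instantiations (illustrative; these mirror the
+    /// actual call sites in update_batch and merge_batch below):
+    ///
+    /// ```ignore
+    /// // update_batch folds U-typed input values with the update fn:
+    /// self.update_or_merge_batch::<U, F>(values, groups, filter, n, self.update_fn)
+    /// // merge_batch folds T-typed partial states with the merge fn:
+    /// self.update_or_merge_batch::<T, G>(values, groups, filter, n, self.merge_fn)
+    /// ```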
+    fn update_or_merge_batch<
+        V: ArrowPrimitiveType + Send,
+        H: Fn(&mut T::Native, V::Native) + Send + Sync,
+    >(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+        binary_fn: H,
+    ) -> Result<()> {
+        assert_eq!(values.len(), 1, "single argument to update_batch / merge_batch");
+        let values = match values[0].as_any().downcast_ref::<PrimitiveArray<V>>() {
+            Some(x) => x,
+            None => {
+                panic!(
+                    "values[0] is of unexpected type {:?} while we are of type {:?} (T = {}, U = {})",
+                    values[0].data_type(),
+                    self.data_type,
+                    type_name::<T>(),
+                    type_name::<U>(),
+                );
+            }
+        };
+
+        // update values
+        self.values.resize(total_num_groups, self.starting_value);
+
+        // NullState dispatches / handles tracking nulls and groups that saw no values
+        self.null_state.accumulate(
+            group_indices,
+            values,
+            opt_filter,
+            total_num_groups,
+            |group_index, new_value| {
+                let value = &mut self.values[group_index];
+                (binary_fn)(value, new_value);
+            },
+        );
+
+        Ok(())
+    }
+}
+
+impl<T, U, F, G> GroupsAccumulator for PrimitiveGroupsAccumulator<T, U, F, G>
+where
+    T: ArrowPrimitiveType + Send,
+    U: ArrowPrimitiveType + Send,
+    F: Fn(&mut T::Native, U::Native) + Send + Sync + Copy,
+    G: Fn(&mut T::Native, T::Native) + Send + Sync + Copy,
+{
+    fn update_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        // update / merge are almost the same except we're adding U's vs. adding T's.
+        self.update_or_merge_batch::<U, F>(
+            values,
+            group_indices,
+            opt_filter,
+            total_num_groups,
+            self.update_fn,
+        )
+    }
+
+    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
+        let values = emit_to.take_needed(&mut self.values);
+        let nulls = self.null_state.build(emit_to);
+
+        let buffers = vec![Buffer::from_slice_ref(&values)]; // TODO: This copies. Ideally, don't. Note: Avoiding this memcpy has minimal performance impact.
+
+        let data = ArrayData::new(
+            self.data_type.clone(),
+            values.len(),
+            None,
+            Some(nulls.into_buffer()),
+            0, /* offset */
+            buffers,
+            vec![],
+        );
+        Ok(Arc::new(PrimitiveArray::<T>::from(data)))
+    }
+
+    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
+        self.evaluate(emit_to).map(|arr| vec![arr])
+    }
+
+    fn merge_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        // update / merge are almost the same except we're adding T's vs. adding U's.
+        self.update_or_merge_batch::<T, G>(
+            values,
+            group_indices,
+            opt_filter,
+            total_num_groups,
+            self.merge_fn,
+        )
+    }
+
+    /// Converts an input batch directly to a state batch
+    ///
+    /// The state is:
+    /// - `self.update_fn` folded into the starting value for all non null, non filtered values
+    /// - null otherwise
+    ///
+    fn convert_to_state(
+        &self,
+        values: &[ArrayRef],
+        opt_filter: Option<&BooleanArray>,
+    ) -> Result<Vec<ArrayRef>> {
+        let values: PrimitiveArray<U> = values[0]
+            .as_any()
+            .downcast_ref::<PrimitiveArray<U>>()
+            .unwrap()
+            .clone();
+
+        // Initializing state with starting values
+        let initial_state =
+            PrimitiveArray::<T>::from_value(self.starting_value, values.len());
+
+        // Recalculating values in case there is filter
+        let values = match opt_filter {
+            None => values,
+            Some(filter) => {
+                let (filter_values, filter_nulls, filter_offset, filter_length) =
+                    filter.clone().into_parts();
+                // Calculating filter mask as a result of bitand of filter, and converting it to null buffer
+                let filter_bool: Buffer = match filter_nulls {
+                    Some(filter_nulls) => (&filter_nulls.into_buffer() & &filter_values)?,
+                    None => filter_values,
+                };
+                let filter_nulls: Bitmap = Bitmap::from(filter_bool);
+
+                // Rebuilding input values with a new nulls mask, which is equal to
+                // the union of original nulls and filter mask
+                let values_data: ArrayData = values.into_data();
+                let dt = values_data.data_type().clone();
+                let (values_buf, original_nulls, values_offset, values_length) =
+                    values_data.into_1_dimensional_parts();
+                let nulls_buf = null_buffer_union(
+                    original_nulls,
+                    values_offset,
+                    values_length,
+                    filter_nulls,
+                    filter_offset,
+                    filter_length,
+                )?;
+
+                let data = ArrayData::new(
+                    dt,
+                    values_length,
+                    None,
+                    Some(nulls_buf.into_buffer()),
+                    values_offset,
+                    vec![values_buf],
+                    vec![],
+                );
+                PrimitiveArray::<U>::from(data)
+            }
+        };
+
+        // TODO: Use a function math_op_mut, like upstream, with initial_state passed by value.
+        let state_values = arrow::compute::math_op_with_data_type(
+            self.data_type.clone(),
+            &initial_state,
+            &values,
+            |mut x, y| {
+                (self.update_fn)(&mut x, y);
+                x
+            },
+        );
+        let state_values: PrimitiveArray<T> =
+            state_values.map_err(DataFusionError::from)?;
+
+        Ok(vec![Arc::new(state_values)])
+    }
+
+    fn supports_convert_to_state(&self) -> bool {
+        true
+    }
+
+    fn size(&self) -> usize {
+        self.values.capacity() * size_of::<T::Native>() + self.null_state.size()
+    }
+}
+
+/// Returns a bitmap whose offset is lhs_offset.
+fn null_buffer_union(
+    lhs: Option<Bitmap>,
+    lhs_offset: usize,
+    lhs_len: usize,
+    rhs: Bitmap,
+    rhs_offset: usize,
+    rhs_len: usize,
+) -> Result<Bitmap> {
+    assert_eq!(lhs_len, rhs_len);
+    match lhs {
+        Some(lhs) => {
+            if lhs_offset == rhs_offset && lhs.len() == rhs.len() {
+                // TODO: Do &= instead.
+                // TODO: We shouldn't require lhs.len() == rhs.len(); we only require it because
+                // that makes it convenient to use the Bitmap & operator -- and the lengths are
+                // equal in the happy path anyway.
+
+                // The bitmaps are true for non-null "valid" entries -- hence the "union" of the
+                // null masks is a bitwise AND of the validity bitmaps.
+                let new_bitmap: Bitmap = (&lhs & &rhs)?;
+                Ok(new_bitmap)
+            } else {
+                // TODO: Dog-awful performance.
+                let mut ret = BooleanBufferBuilder::new(lhs_offset + lhs_len);
+                for _ in 0..lhs_offset {
+                    ret.append(false);
+                }
+                for i in 0..lhs_len {
+                    ret.append(lhs.is_set(i + lhs_offset) & rhs.is_set(i + rhs_offset));
+                }
+                Ok(Bitmap::from(ret.finish()))
+            }
+        }
+        None => {
+            if lhs_offset == rhs_offset {
+                Ok(rhs)
+            } else {
+                // TODO: Dog-awful performance.
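+                // What this fallback does: emit lhs_offset leading `false` bits, then copy
+                // rhs's validity bits (rhs_offset..rhs_offset + rhs_len) one by one, so the
+                // returned bitmap is indexed from lhs_offset like the branches above.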
+ let mut ret = BooleanBufferBuilder::new(lhs_offset + lhs_len); + for _ in 0..lhs_offset { + ret.append(false); + } + for i in 0..lhs_len { + ret.append(rhs.is_set(i + rhs_offset)); + } + Ok(Bitmap::from(ret.finish())) + } + } + } +} diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs index 895c34082fc9..d549855b937c 100644 --- a/datafusion/src/physical_plan/hash_aggregate.rs +++ b/datafusion/src/physical_plan/hash_aggregate.rs @@ -71,7 +71,7 @@ use arrow::array::{ }; use async_trait::async_trait; -use super::groups_accumulator::GroupsAccumulator; +use super::groups_accumulator::{EmitTo, GroupsAccumulator}; use super::groups_accumulator_flat_adapter::GroupsAccumulatorFlatAdapter; use super::{ expressions::Column, group_scalar::GroupByScalar, RecordBatchStream, @@ -466,7 +466,10 @@ pub(crate) fn group_aggregate_batch( // Note that we still use plain String objects in GroupByScalar. Thus flattened_group_by_values isn't that great. let _ = create_group_by_values(&group_values, row, &mut group_by_values); accumulation_state.flattened_group_by_values.extend( - group_by_values.iter_mut().map(|x| std::mem::replace(x, GroupByScalar::UInt32(0)))); + group_by_values + .iter_mut() + .map(|x| std::mem::replace(x, GroupByScalar::UInt32(0))), + ); let group_index = accumulation_state.next_group_index; accumulation_state.next_group_index += 1; ( @@ -821,7 +824,7 @@ async fn compute_grouped_hash_aggregate( .map_err(DataFusionError::into_arrow_external_error)?; } - create_batch_from_map(&mode, &accumulators, group_expr.len(), &schema) + create_batch_from_map(&mode, accumulators, group_expr.len(), &schema) } impl GroupedHashAggregateStream { @@ -1157,50 +1160,40 @@ impl RecordBatchStream for HashAggregateStream { /// Create a RecordBatch with all group keys and accumulator' states or values. pub(crate) fn create_batch_from_map( mode: &AggregateMode, - accumulation_state: &AccumulationState, + accumulation_state: AccumulationState, num_group_expr: usize, output_schema: &Schema, ) -> ArrowResult { if accumulation_state.accumulators.is_empty() { return Ok(RecordBatch::new_empty(Arc::new(output_schema.to_owned()))); } - // 1. for each key + // Fake instructions as we do aggregations columnarly. + // 1. for each group index // 2. create single-row ArrayRef with all group expressions - // 3. create single-row ArrayRef with all aggregate states or values + // 3. create single-row ArrayRef with all aggregate states or values (accumulators) // 4. collect all in a vector per key of vec, vec[i][j] // 5. concatenate the arrays over the second index [j] into a single vec. - let mut key_columns: Vec> = Vec::with_capacity(num_group_expr); - let mut value_columns = Vec::new(); - for ( - _, - AccumulationGroupState { - group_index, - .. - }, - ) in &accumulation_state.accumulators - { - let group_by_values: &[GroupByScalar] = &accumulation_state.flattened_group_by_values[num_group_expr * group_index..num_group_expr * (group_index + 1)]; - - // 2 and 3. 
- write_group_result_row_with_groups_accumulator( - *mode, - group_by_values, - &accumulation_state.groups_accumulators, - *group_index, - &output_schema.fields()[0..num_group_expr], - &mut key_columns, - &mut value_columns, - ) - .map_err(DataFusionError::into_arrow_external_error)?; - } + let key_columns: Vec> = write_group_result_rows_for_keys( + &accumulation_state.flattened_group_by_values, + accumulation_state.next_group_index, + &output_schema.fields()[0..num_group_expr], + ) + .map_err(DataFusionError::into_arrow_external_error)?; + // 3. + let value_columns = finalize_aggregation_into_with_groups_accumulators( + accumulation_state.groups_accumulators, + *mode, + ) + .map_err(DataFusionError::into_arrow_external_error)?; + // 4. let batch = if !key_columns.is_empty() || !value_columns.is_empty() { // 5. let columns = key_columns .into_iter() - .chain(value_columns) - .map(|mut b| b.finish()); + .map(|mut b| b.finish()) + .chain(value_columns); // cast output if needed (e.g. for types like Dictionary where // the intermediate GroupByScalar type was not the same as the @@ -1252,46 +1245,61 @@ pub fn write_group_result_row( finalize_aggregation_into(&accumulator_set, &mode, value_columns) } -// TODO: Dedup with write_group_result_row. #[allow(missing_docs)] -pub fn write_group_result_row_with_groups_accumulator( - mode: AggregateMode, - group_by_values: &[GroupByScalar], - groups_accumulators: &[Box], - group_index: usize, +pub fn write_group_result_rows_for_keys( + flattened_group_by_values: &[GroupByScalar], + num_groups: usize, key_fields: &[Field], - key_columns: &mut Vec>, - value_columns: &mut Vec>, -) -> Result<()> { - let add_key_columns = key_columns.is_empty(); - for i in 0..group_by_values.len() { +) -> Result>> { + // The caller must early exit and it does. Why? Because previous code did so, and it used + // create_builder from a ScalarValue at index 0, and we avoid changing that to minimize risk. + assert!(num_groups > 0); + + let num_group_expr = key_fields.len(); + let mut key_columns: Vec> = Vec::with_capacity(num_group_expr); + + for i in 0..num_group_expr { + // For clarity, we're operating with the first row, group_index 0. + let group_by_values = &flattened_group_by_values[0..num_group_expr]; + // TODO: We could probably do (GroupByValue::Null).to_scalar(...) if create_builder on a + // scalar is even the best way to create a builder. This code with the Utf8 branch and the + // v.to_scalar(...) exists solely as a rearrangement of existing logic, to minimize + // probability of breakage. match &group_by_values[i] { - // Optimization to avoid allocation on conversion to ScalarValue. - GroupByScalar::Utf8(str) => { - if add_key_columns { - key_columns.push(Box::new(StringBuilder::new(0))); - } - key_columns[i] - .as_any_mut() - .downcast_mut::() - .unwrap() - .append_value(str)?; + GroupByScalar::Utf8(_) => { + key_columns.push(Box::new(StringBuilder::new(0))); } v => { - let scalar = v.to_scalar(key_fields[i].data_type()); - if add_key_columns { - key_columns.push(create_builder(&scalar)); + let scalar: ScalarValue = v.to_scalar(key_fields[i].data_type()); + key_columns.push(create_builder(&scalar)); + } + } + } + + // Note that we MUST process groups in ascending group_index order as that's the same order as used for + // accumulator columns. 
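+    //
+    // Illustration (assuming two group columns a and b): flattened_group_by_values is laid
+    // out row-major as [a0, b0, a1, b1, ...], so group i's keys are the slice
+    // [num_group_expr * i .. num_group_expr * (i + 1)] taken below.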
+ for group_index in 0..num_groups { + let group_by_values: &[GroupByScalar] = &flattened_group_by_values + [num_group_expr * group_index..num_group_expr * (group_index + 1)]; + + for i in 0..group_by_values.len() { + match &group_by_values[i] { + // Optimization to avoid allocation on conversion to ScalarValue. + GroupByScalar::Utf8(str) => { + key_columns[i] + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_value(str)?; + } + v => { + let scalar = v.to_scalar(key_fields[i].data_type()); + append_value(&mut *key_columns[i], &scalar)?; } - append_value(&mut *key_columns[i], &scalar)?; } } } - finalize_aggregation_into_with_groups_accumulators( - groups_accumulators, - group_index, - &mode, - value_columns, - ) + Ok(key_columns) } #[allow(missing_docs)] @@ -1463,44 +1471,28 @@ fn finalize_aggregation_into( Ok(()) } -/// adds aggregation results into columns, creating the required builders when necessary. +/// Returns aggregation results in columns. /// final value (mode = Final) or states (mode = Partial) fn finalize_aggregation_into_with_groups_accumulators( - groups_accumulators: &[Box], - group_index: usize, - mode: &AggregateMode, - columns: &mut Vec>, -) -> Result<()> { - let add_columns = columns.is_empty(); + mut groups_accumulators: Vec>, + mode: AggregateMode, +) -> Result>> { + let mut columns = Vec::new(); match mode { AggregateMode::Partial => { - let mut col_i = 0; - for ga in groups_accumulators.iter() { - let state = ga.peek_state(group_index)?; - // build the vector of states - for v in state { - if add_columns { - columns.push(create_builder(&v)); - assert_eq!(col_i + 1, columns.len()); - } - append_value(&mut *columns[col_i], &v)?; - col_i += 1; - } + for ga in &mut groups_accumulators { + let state = ga.state(EmitTo::All)?; + columns.extend(state.into_iter()); } } AggregateMode::Final | AggregateMode::FinalPartitioned | AggregateMode::Full => { - for (i, ga) in groups_accumulators.iter().enumerate() { - // merge the state to the final value - let v: ScalarValue = ga.peek_evaluate(group_index)?; - if add_columns { - columns.push(create_builder(&v)); - assert_eq!(i + 1, columns.len()); - } - append_value(&mut *columns[i], &v)?; + for ga in &mut groups_accumulators { + let value = ga.evaluate(EmitTo::All)?; + columns.push(value); } } } - Ok(()) + Ok(columns) } /// returns a vector of ArrayRefs, where each entry corresponds to either the diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index a9ed176b3642..1853a77e89cc 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -686,6 +686,7 @@ pub mod group_scalar; pub mod groups_accumulator; pub mod groups_accumulator_adapter; pub mod groups_accumulator_flat_adapter; +pub mod groups_accumulator_prim_op; pub mod hash_aggregate; pub mod hash_join; pub mod hash_utils; @@ -696,6 +697,7 @@ pub mod memory; pub mod merge; pub mod merge_join; pub mod merge_sort; +pub mod null_state; pub mod parquet; pub mod planner; pub mod projection; diff --git a/datafusion/src/physical_plan/null_state.rs b/datafusion/src/physical_plan/null_state.rs new file mode 100644 index 000000000000..ff1723fb166e --- /dev/null +++ b/datafusion/src/physical_plan/null_state.rs @@ -0,0 +1,949 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`GroupsAccumulator`] helpers: [`NullState`] and [`accumulate_indices`]
+//!
+//! [`GroupsAccumulator`]: crate::physical_plan::groups_accumulator::GroupsAccumulator

+use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray};
+use arrow::bitmap::Bitmap;
+use arrow::buffer::{Buffer, MutableBuffer};
+// use arrow::buffer::{BooleanBuffer, NullBuffer};
+use arrow::datatypes::ArrowPrimitiveType;
+
+use crate::physical_plan::groups_accumulator::EmitTo;
+
+/// Track the accumulator null state per row: if any values for that
+/// group were null and if any values have been seen at all for that group.
+///
+/// This is part of the inner loop for many [`GroupsAccumulator`]s,
+/// and thus the performance is critical and so there are multiple
+/// specialized implementations, invoked depending on the specific
+/// combinations of the input.
+///
+/// Typically there are 4 potential combinations of inputs that must be
+/// special cased for performance:
+///
+/// * With / Without filter
+/// * With / Without nulls in the input
+///
+/// If the input has nulls, then the accumulator must potentially
+/// handle each input null value specially (e.g. for `SUM` to mark the
+/// corresponding sum as null)
+///
+/// If there are filters present, `NullState` tracks if it has seen
+/// *any* value for that group (as some values may be filtered
+/// out). Without a filter, the accumulator is only passed groups that
+/// had at least one value to accumulate so they do not need to track
+/// if they have seen values for a particular group.
+///
+/// [`GroupsAccumulator`]: crate::physical_plan::groups_accumulator::GroupsAccumulator
+#[derive(Debug)]
+pub struct NullState {
+    /// Have we seen any non-filtered input values for `group_index`?
+    ///
+    /// If `seen_values[i]` is true, have seen at least one non null
+    /// value for group `i`
+    ///
+    /// If `seen_values[i]` is false, have not seen any values that
+    /// pass the filter yet for group `i`
+    seen_values: BooleanBufferBuilder,
+}
+
+impl Default for NullState {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl NullState {
+    #[allow(missing_docs)]
+    pub fn new() -> Self {
+        Self {
+            seen_values: BooleanBufferBuilder::new(0),
+        }
+    }
+
+    /// return the size of all buffers allocated by this null state, not including self
+    pub fn size(&self) -> usize {
+        // capacity is in bits, so convert to bytes
+        self.seen_values.capacity() / 8
+    }
+
+    /// Gets `seen_values[index]`, i.e. looks up the bitmap. Unused but exists for
+    /// debugging or documentation.
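+    ///
+    /// Illustrative (hypothetical) use:
+    ///
+    /// ```ignore
+    /// // after accumulate() has run, a group that received at least one
+    /// // non-null, unfiltered value reports true here:
+    /// let seen: bool = null_state.get_is_valid(group_index);
+    /// ```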
+    pub fn get_is_valid(&self, index: usize) -> bool {
+        self.seen_values.get_bit(index)
+    }
+
+    /// Invokes `value_fn(group_index, value)` for each non null, non
+    /// filtered value of `value`, while tracking which groups have
+    /// seen null inputs and which groups have seen any inputs if necessary
+    ///
+    /// # Arguments:
+    ///
+    /// * `values`: the input arguments to the accumulator
+    /// * `group_indices`: To which groups do the rows in `values` belong, (aka group_index)
+    /// * `opt_filter`: if present, only rows for which is Some(true) are included
+    /// * `value_fn`: function invoked for (group_index, value) where value is non null
+    ///
+    /// See [`accumulate`], for more details on how value_fn is called
+    ///
+    /// When value_fn is called it also sets
+    ///
+    /// 1. `self.seen_values[group_index]` to true for all rows that had a non null value
+    pub fn accumulate<T, F>(
+        &mut self,
+        group_indices: &[usize],
+        values: &PrimitiveArray<T>,
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+        mut value_fn: F,
+    ) where
+        T: ArrowPrimitiveType + Send,
+        F: FnMut(usize, T::Native) + Send,
+    {
+        // ensure the seen_values is big enough (start everything at
+        // "not seen" valid)
+        let seen_values =
+            initialize_builder(&mut self.seen_values, total_num_groups, false);
+        accumulate(group_indices, values, opt_filter, |group_index, value| {
+            seen_values.set_bit(group_index, true);
+            value_fn(group_index, value);
+        });
+    }
+
+    /// Invokes `value_fn(group_index, value)` for each non null, non
+    /// filtered value in `values`, while tracking which groups have
+    /// seen null inputs and which groups have seen any inputs, for
+    /// [`BooleanArray`]s.
+    ///
+    /// Since `BooleanArray` is not a [`PrimitiveArray`] it must be
+    /// handled specially.
+    ///
+    /// See [`Self::accumulate`], which handles `PrimitiveArray`s, for
+    /// more details on other arguments.
+    pub fn accumulate_boolean<F>(
+        &mut self,
+        group_indices: &[usize],
+        values: &BooleanArray,
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+        mut value_fn: F,
+    ) where
+        F: FnMut(usize, bool) + Send,
+    {
+        // Upstream this is a BooleanBuffer. To minimize code changes, we work with a buffer (of packed bits) directly here.
+        let data: &arrow::buffer::Buffer = values.values();
+        assert_eq!(data.len(), group_indices.len().div_ceil(8));
+
+        // ensure the seen_values is big enough (start everything at
+        // "not seen" valid)
+        let seen_values =
+            initialize_builder(&mut self.seen_values, total_num_groups, false);
+
+        // TODO: Using values.value(i), values.is_null(i), etc., is probably bad performance, but we
+        // were already using fairly similar iterators.
+
+        // These could be made more performant by iterating in chunks of 64 bits at a time
+        match (values.null_count() > 0, opt_filter) {
+            // no nulls, no filter,
+            (false, None) => {
+                // if we have previously seen nulls, ensure the null
+                // buffer is big enough (start everything at valid)
+                for (i, &group_index) in group_indices.iter().enumerate() {
+                    let new_value = values.value(i);
+                    seen_values.set_bit(group_index, true);
+                    value_fn(group_index, new_value);
+                }
+            }
+            // nulls, no filter
+            (true, None) => {
+                // let nulls = values.nulls().unwrap(); // TODO: Try faster usage with direct access to null bitmaps.
+                for (i, &group_index) in group_indices.iter().enumerate() {
+                    let new_value = values.value(i);
+                    let is_valid = !values.is_null(i);
+                    if is_valid {
+                        seen_values.set_bit(group_index, true);
+                        value_fn(group_index, new_value);
+                    }
+                }
+            }
+            // no nulls, but a filter
+            (false, Some(filter)) => {
+                assert_eq!(filter.len(), group_indices.len());
+                for (i, &group_index) in group_indices.iter().enumerate() {
+                    let new_value = values.value(i);
+                    let filter_value = !filter.is_null(i) && filter.value(i);
+                    if filter_value {
+                        seen_values.set_bit(group_index, true);
+                        value_fn(group_index, new_value);
+                    }
+                }
+            }
+            // both null values and filters
+            (true, Some(filter)) => {
+                assert_eq!(filter.len(), group_indices.len());
+                for (i, &group_index) in group_indices.iter().enumerate() {
+                    let new_value = values.value(i);
+                    let new_value_is_valid = !values.is_null(i);
+                    let filter_value = !filter.is_null(i) & filter.value(i);
+                    if filter_value {
+                        if new_value_is_valid {
+                            seen_values.set_bit(group_index, true);
+                            value_fn(group_index, new_value)
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    /// Creates a [`Bitmap`] representing which group_indices
+    /// should have null values (because they never saw any values)
+    /// for the `emit_to` rows.
+    ///
+    /// resets the internal state appropriately
+    pub fn build(&mut self, emit_to: EmitTo) -> Bitmap {
+        let original_len = self.seen_values.len();
+        let nulls: Buffer = self.seen_values.finish();
+
+        let nulls: Buffer = match emit_to {
+            EmitTo::All => nulls,
+            EmitTo::First(n) => {
+                // split off the first N values in seen_values
+
+                let new_len = original_len - n;
+
+                //
+                // TODO make this more efficient rather than two
+                // copies and bitwise manipulation (upstream df comment)
+                // TODO: Of course use 64 bits, at least. Or simd or such. Sigh.
+                let num_bytes = n.div_ceil(8);
+                let nulls_slice = nulls.as_slice();
+                let first_n_null: Buffer = Buffer::from(&nulls_slice[0..num_bytes]);
+                if n % 8 == 0 {
+                    let mut mutable_buffer =
+                        MutableBuffer::with_capacity(nulls_slice.len());
+                    mutable_buffer.extend_from_slice(&nulls_slice[num_bytes..]);
+                    self.seen_values =
+                        BooleanBufferBuilder::new_from_buffer(mutable_buffer, new_len);
+                } else {
+                    let mut mutable_buffer =
+                        MutableBuffer::with_capacity(nulls_slice.len());
+
+                    // TODO: We could find a fast library function for this... or add an offset field to self.seen_values. At least do 64 bit.
+                    let misalignment = n % 8;
+                    // Because n % 8 != 0, we know num_bytes > 0.
+                    let mut leftover = nulls_slice[num_bytes - 1] >> misalignment;
+                    for seen in &nulls_slice[num_bytes..]
{ + let seen: u8 = *seen; + let rejoined = leftover | (seen << (8 - misalignment)); + mutable_buffer.push(rejoined); + leftover = seen >> misalignment; + } + mutable_buffer.push(leftover); + mutable_buffer.resize(new_len.div_ceil(8), 0); + self.seen_values = + BooleanBufferBuilder::new_from_buffer(mutable_buffer, new_len); + } + + first_n_null + } + }; + Bitmap::from(nulls) + } +} + +/// Invokes `value_fn(group_index, value)` for each non null, non +/// filtered value of `value`, +/// +/// # Arguments: +/// +/// * `group_indices`: To which groups do the rows in `values` belong, (aka group_index) +/// * `values`: the input arguments to the accumulator +/// * `opt_filter`: if present, only rows for which is Some(true) are included +/// * `value_fn`: function invoked for (group_index, value) where value is non null +/// +/// # Example +/// +/// ```text +/// ┌─────────┐ ┌─────────┐ ┌ ─ ─ ─ ─ ┐ +/// │ ┌─────┐ │ │ ┌─────┐ │ ┌─────┐ +/// │ │ 2 │ │ │ │ 200 │ │ │ │ t │ │ +/// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ +/// │ │ 2 │ │ │ │ 100 │ │ │ │ f │ │ +/// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ +/// │ │ 0 │ │ │ │ 200 │ │ │ │ t │ │ +/// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ +/// │ │ 1 │ │ │ │ 200 │ │ │ │NULL │ │ +/// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ +/// │ │ 0 │ │ │ │ 300 │ │ │ │ t │ │ +/// │ └─────┘ │ │ └─────┘ │ └─────┘ +/// └─────────┘ └─────────┘ └ ─ ─ ─ ─ ┘ +/// +/// group_indices values opt_filter +/// ``` +/// +/// In the example above, `value_fn` is invoked for each (group_index, +/// value) pair where `opt_filter[i]` is true and values is non null +/// +/// ```text +/// value_fn(2, 200) +/// value_fn(0, 200) +/// value_fn(0, 300) +/// ``` +pub fn accumulate( + group_indices: &[usize], + values: &PrimitiveArray, + opt_filter: Option<&BooleanArray>, + mut value_fn: F, +) where + T: ArrowPrimitiveType + Send, + F: FnMut(usize, T::Native) + Send, +{ + let data: &[T::Native] = values.values(); + assert_eq!(data.len(), group_indices.len()); + + match (values.null_count() > 0, opt_filter) { + // no nulls, no filter, + (false, None) => { + let iter = group_indices.iter().zip(data.iter()); + for (&group_index, &new_value) in iter { + value_fn(group_index, new_value); + } + } + // nulls, no filter + (true, None) => { + let valids: &Bitmap = values.data().null_bitmap().as_ref().unwrap(); + // This is based on (ahem, COPY/PASTE) arrow::compute::aggregate::sum + // iterate over in chunks of 64 bits for more efficient null checking + let group_indices_chunks = group_indices.chunks_exact(64); + let data_chunks = data.chunks_exact(64); + let bit_chunks = valids + .buffer_ref() + .bit_chunks(values.offset(), values.len()); + + let group_indices_remainder = group_indices_chunks.remainder(); + let data_remainder = data_chunks.remainder(); + + group_indices_chunks + .zip(data_chunks) + .zip(bit_chunks.iter()) + .for_each(|((group_index_chunk, data_chunk), mask)| { + // index_mask has value 1 << i in the loop + let mut index_mask = 1; + group_index_chunk.iter().zip(data_chunk.iter()).for_each( + |(&group_index, &new_value)| { + // valid bit was set, real value + let is_valid = (mask & index_mask) != 0; + if is_valid { + value_fn(group_index, new_value); + } + index_mask <<= 1; + }, + ) + }); + + // handle any remaining bits (after the initial 64) + let remainder_bits = bit_chunks.remainder_bits(); + group_indices_remainder + .iter() + .zip(data_remainder.iter()) + .enumerate() + .for_each(|(i, (&group_index, &new_value))| { + let is_valid = remainder_bits & (1 << i) != 0; + if is_valid { + value_fn(group_index, new_value); + } + }); 
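+            // Worked illustration of the chunked mask walk above: for a 64-row chunk whose
+            // validity mask is ...0101, index_mask takes values 1, 2, 4, ... and value_fn
+            // fires only for chunk rows 0 and 2 (mask & index_mask != 0).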
+ } + // no nulls, but a filter + (false, Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + // The performance with a filter could be improved by + // iterating over the filter in chunks, rather than a single + // iterator. TODO file a ticket (upstream df comment) + group_indices + .iter() + .zip(data.iter()) + .zip(filter.iter()) + .for_each(|((&group_index, &new_value), filter_value)| { + if let Some(true) = filter_value { + value_fn(group_index, new_value); + } + }) + } + // both null values and filters + (true, Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + // The performance with a filter could be improved by + // iterating over the filter in chunks, rather than using + // iterators. TODO file a ticket (upstream df comment) + filter + .iter() + .zip(group_indices.iter()) + .zip(values.iter()) + .for_each(|((filter_value, &group_index), new_value)| { + if let Some(true) = filter_value { + if let Some(new_value) = new_value { + value_fn(group_index, new_value) + } + } + }) + } + } +} + +/// This function is called to update the accumulator state per row +/// when the value is not needed (e.g. COUNT) +/// +/// `F`: Invoked like `value_fn(group_index) for all non null values +/// passing the filter. Note that no tracking is done for null inputs +/// or which groups have seen any values +/// +/// See [`NullState::accumulate`], for more details on other +/// arguments. +pub fn accumulate_indices( + group_indices: &[usize], + nulls: Option<(&Bitmap, usize, usize)>, + opt_filter: Option<&BooleanArray>, + mut index_fn: F, +) where + F: FnMut(usize) + Send, +{ + match (nulls, opt_filter) { + (None, None) => { + for &group_index in group_indices.iter() { + index_fn(group_index) + } + } + (None, Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + // The performance with a filter could be improved by + // iterating over the filter in chunks, rather than a single + // iterator. 
TODO file a ticket (upstream df comment) + let iter = group_indices.iter().zip(filter.iter()); + for (&group_index, filter_value) in iter { + if let Some(true) = filter_value { + index_fn(group_index) + } + } + } + (Some((valids, valids_offset, valids_len)), None) => { + assert_eq!(valids_len, group_indices.len()); + // This is based on (ahem, COPY/PASTA) arrow::compute::aggregate::sum + // iterate over in chunks of 64 bits for more efficient null checking + let group_indices_chunks = group_indices.chunks_exact(64); + let bit_chunks = valids.buffer_ref().bit_chunks(valids_offset, valids_len); + + let group_indices_remainder = group_indices_chunks.remainder(); + + group_indices_chunks.zip(bit_chunks.iter()).for_each( + |(group_index_chunk, mask)| { + // index_mask has value 1 << i in the loop + let mut index_mask = 1; + group_index_chunk.iter().for_each(|&group_index| { + // valid bit was set, real vale + let is_valid = (mask & index_mask) != 0; + if is_valid { + index_fn(group_index); + } + index_mask <<= 1; + }) + }, + ); + + // handle any remaining bits (after the initial 64) + let remainder_bits = bit_chunks.remainder_bits(); + group_indices_remainder + .iter() + .enumerate() + .for_each(|(i, &group_index)| { + let is_valid = remainder_bits & (1 << i) != 0; + if is_valid { + index_fn(group_index) + } + }); + } + + (Some((valids, valids_offset, valids_len)), Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + assert_eq!(valids_len, group_indices.len()); + + // The performance with a filter could likely be improved by + // iterating over the filter in chunks, rather than using + // iterators. TODO file a ticket (upstream df comment) + filter + .iter() + .zip(group_indices.iter()) + .zip(valids.make_iter(valids_offset, valids_len)) + .for_each(|((filter_value, &group_index), is_valid)| { + if let (Some(true), true) = (filter_value, is_valid) { + index_fn(group_index) + } + }) + } + } +} + +/// Ensures that `builder` contains a `BooleanBufferBuilder with at +/// least `total_num_groups`. +/// +/// All new entries are initialized to `default_value` +fn initialize_builder( + builder: &mut BooleanBufferBuilder, + total_num_groups: usize, + default_value: bool, +) -> &mut BooleanBufferBuilder { + if builder.len() < total_num_groups { + let new_groups = total_num_groups - builder.len(); + builder.append_n(new_groups, default_value); + } + builder +} + +#[cfg(test)] +mod test { + use super::*; + + use arrow::array::UInt32Array; + use rand::{rngs::ThreadRng, Rng}; + use std::collections::HashSet; + + #[test] + fn accumulate() { + let group_indices = (0..100).collect(); + let values = (0..100).map(|i| (i + 1) * 10).collect(); + let values_with_nulls = (0..100) + .map(|i| if i % 3 == 0 { None } else { Some((i + 1) * 10) }) + .collect(); + + // default to every fifth value being false, every even + // being null + let filter: BooleanArray = (0..100) + .map(|i| { + let is_even = i % 2 == 0; + let is_fifth = i % 5 == 0; + if is_even { + None + } else if is_fifth { + Some(false) + } else { + Some(true) + } + }) + .collect(); + + Fixture { + group_indices, + values, + values_with_nulls, + filter, + } + .run() + } + + #[test] + fn accumulate_fuzz() { + let mut rng = rand::thread_rng(); + for _ in 0..100 { + Fixture::new_random(&mut rng).run(); + } + } + + /// Values for testing (there are enough values to exercise the 64 bit chunks + struct Fixture { + /// 100..0 + group_indices: Vec, + + /// 10, 20, ... 
1010 + values: Vec, + + /// same as values, but every third is null: + /// None, Some(20), Some(30), None ... + values_with_nulls: Vec>, + + /// filter (defaults to None) + filter: BooleanArray, + } + + impl Fixture { + fn new_random(rng: &mut ThreadRng) -> Self { + // Number of input values in a batch + let num_values: usize = rng.gen_range(1..200); + // number of distinct groups + let num_groups: usize = rng.gen_range(2..1000); + let max_group = num_groups - 1; + + let group_indices: Vec = (0..num_values) + .map(|_| rng.gen_range(0..max_group)) + .collect(); + + let values: Vec = (0..num_values).map(|_| rng.gen()).collect(); + + // 10% chance of false + // 10% change of null + // 80% chance of true + let filter: BooleanArray = (0..num_values) + .map(|_| { + let filter_value = rng.gen_range(0.0..1.0); + if filter_value < 0.1 { + Some(false) + } else if filter_value < 0.2 { + None + } else { + Some(true) + } + }) + .collect(); + + // random values with random number and location of nulls + // random null percentage + let null_pct: f32 = rng.gen_range(0.0..1.0); + let values_with_nulls: Vec> = (0..num_values) + .map(|_| { + let is_null = null_pct < rng.gen_range(0.0..1.0); + if is_null { + None + } else { + Some(rng.gen()) + } + }) + .collect(); + + Self { + group_indices, + values, + values_with_nulls, + filter, + } + } + + /// returns `Self::values` an Array + fn values_array(&self) -> UInt32Array { + UInt32Array::from(self.values.clone()) + } + + /// returns `Self::values_with_nulls` as an Array + fn values_with_nulls_array(&self) -> UInt32Array { + UInt32Array::from(self.values_with_nulls.clone()) + } + + /// Calls `NullState::accumulate` and `accumulate_indices` + /// with all combinations of nulls and filter values + fn run(&self) { + let total_num_groups = *self.group_indices.iter().max().unwrap() + 1; + + let group_indices = &self.group_indices; + let values_array = self.values_array(); + let values_with_nulls_array = self.values_with_nulls_array(); + let filter = &self.filter; + + // no null, no filters + Self::accumulate_test(group_indices, &values_array, None, total_num_groups); + + // nulls, no filters + Self::accumulate_test( + group_indices, + &values_with_nulls_array, + None, + total_num_groups, + ); + + // no nulls, filters + Self::accumulate_test( + group_indices, + &values_array, + Some(filter), + total_num_groups, + ); + + // nulls, filters + Self::accumulate_test( + group_indices, + &values_with_nulls_array, + Some(filter), + total_num_groups, + ); + } + + /// Calls `NullState::accumulate` and `accumulate_indices` to + /// ensure it generates the correct values. 
+ /// + fn accumulate_test( + group_indices: &[usize], + values: &UInt32Array, + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) { + Self::accumulate_values_test( + group_indices, + values, + opt_filter, + total_num_groups, + ); + Self::accumulate_indices_test( + group_indices, + values + .data() + .null_bitmap() + .as_ref() + .map(|bitmap| (bitmap, values.data().offset(), values.data().len())), + opt_filter, + ); + + // Convert values into a boolean array (anything above the + // average is true, otherwise false) + let avg: usize = values.iter().filter_map(|v| v.map(|v| v as usize)).sum(); + let boolean_values: BooleanArray = + values.iter().map(|v| v.map(|v| v as usize > avg)).collect(); + Self::accumulate_boolean_test( + group_indices, + &boolean_values, + opt_filter, + total_num_groups, + ); + } + + /// This is effectively a different implementation of + /// accumulate that we compare with the above implementation + fn accumulate_values_test( + group_indices: &[usize], + values: &UInt32Array, + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) { + let mut accumulated_values = vec![]; + let mut null_state = NullState::new(); + + null_state.accumulate( + group_indices, + values, + opt_filter, + total_num_groups, + |group_index, value| { + accumulated_values.push((group_index, value)); + }, + ); + + // Figure out the expected values + let mut expected_values = vec![]; + let mut mock = MockNullState::new(); + + match opt_filter { + None => group_indices.iter().zip(values.iter()).for_each( + |(&group_index, value)| { + if let Some(value) = value { + mock.saw_value(group_index); + expected_values.push((group_index, value)); + } + }, + ), + Some(filter) => { + group_indices + .iter() + .zip(values.iter()) + .zip(filter.iter()) + .for_each(|((&group_index, value), is_included)| { + // if value passed filter + if let Some(true) = is_included { + if let Some(value) = value { + mock.saw_value(group_index); + expected_values.push((group_index, value)); + } + } + }); + } + } + + assert_eq!(accumulated_values, expected_values, + "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"); + let seen_values = Bitmap::from(null_state.seen_values.finish_cloned()); + mock.validate_seen_values((&seen_values, 0, null_state.seen_values.len())); + + // Validate the final buffer (one value per group) + let expected_null_buffer = mock.expected_null_buffer(total_num_groups); + + let null_buffer = null_state.build(EmitTo::All); + + assert_eq!(null_buffer, expected_null_buffer); + } + + // Calls `accumulate_indices` + // and opt_filter and ensures it calls the right values + fn accumulate_indices_test( + group_indices: &[usize], + nulls: Option<(&Bitmap, usize, usize)>, + opt_filter: Option<&BooleanArray>, + ) { + let mut accumulated_values = vec![]; + + accumulate_indices(group_indices, nulls, opt_filter, |group_index| { + accumulated_values.push(group_index); + }); + + // Figure out the expected values + let mut expected_values = vec![]; + + match (nulls, opt_filter) { + (None, None) => group_indices.iter().for_each(|&group_index| { + expected_values.push(group_index); + }), + (Some(nulls), None) => group_indices + .iter() + .zip(nulls.0.make_iter(nulls.1, nulls.2)) + .for_each(|(&group_index, is_valid)| { + if is_valid { + expected_values.push(group_index); + } + }), + (None, Some(filter)) => group_indices.iter().zip(filter.iter()).for_each( + |(&group_index, is_included)| { + if let Some(true) = is_included { + expected_values.push(group_index); 
+ } + }, + ), + (Some(nulls), Some(filter)) => { + group_indices + .iter() + .zip(nulls.0.make_iter(nulls.1, nulls.2)) + .zip(filter.iter()) + .for_each(|((&group_index, is_valid), is_included)| { + // if value passed filter + if let (true, Some(true)) = (is_valid, is_included) { + expected_values.push(group_index); + } + }); + } + } + + assert_eq!(accumulated_values, expected_values, + "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"); + } + + /// This is effectively a different implementation of + /// accumulate_boolean that we compare with the above implementation + fn accumulate_boolean_test( + group_indices: &[usize], + values: &BooleanArray, + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) { + let mut accumulated_values = vec![]; + let mut null_state = NullState::new(); + + null_state.accumulate_boolean( + group_indices, + values, + opt_filter, + total_num_groups, + |group_index, value| { + accumulated_values.push((group_index, value)); + }, + ); + + // Figure out the expected values + let mut expected_values = vec![]; + let mut mock = MockNullState::new(); + + match opt_filter { + None => group_indices.iter().zip(values.iter()).for_each( + |(&group_index, value)| { + if let Some(value) = value { + mock.saw_value(group_index); + expected_values.push((group_index, value)); + } + }, + ), + Some(filter) => { + group_indices + .iter() + .zip(values.iter()) + .zip(filter.iter()) + .for_each(|((&group_index, value), is_included)| { + // if value passed filter + if let Some(true) = is_included { + if let Some(value) = value { + mock.saw_value(group_index); + expected_values.push((group_index, value)); + } + } + }); + } + } + + assert_eq!(accumulated_values, expected_values, + "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"); + + let seen_values = Bitmap::from(null_state.seen_values.finish_cloned()); + mock.validate_seen_values((&seen_values, 0, null_state.seen_values.len())); + + // Validate the final buffer (one value per group) + let expected_null_buffer = mock.expected_null_buffer(total_num_groups); + + let null_buffer = null_state.build(EmitTo::All); + + assert_eq!(null_buffer, expected_null_buffer); + } + } + + /// Parallel implementation of NullState to check expected values + #[derive(Debug, Default)] + struct MockNullState { + /// group indices that had values that passed the filter + seen_values: HashSet, + } + + impl MockNullState { + fn new() -> Self { + Default::default() + } + + fn saw_value(&mut self, group_index: usize) { + self.seen_values.insert(group_index); + } + + /// did this group index see any input? + fn expected_seen(&self, group_index: usize) -> bool { + self.seen_values.contains(&group_index) + } + + /// Validate that the seen_values matches self.seen_values + fn validate_seen_values(&self, seen_values: (&Bitmap, usize, usize)) { + for (group_index, is_seen) in seen_values + .0 + .make_iter(seen_values.1, seen_values.2) + .enumerate() + { + let expected_seen = self.expected_seen(group_index); + assert_eq!( + expected_seen, is_seen, + "mismatch at for group {group_index}" + ); + } + } + + /// Create the expected null buffer based on if the input had nulls and a filter + fn expected_null_buffer(&self, total_num_groups: usize) -> Bitmap { + let buf: Buffer = (0..total_num_groups) + .map(|group_index| self.expected_seen(group_index)) + .collect(); + Bitmap::from(buf) + } + } +}