From 773b9c56034b43ed6d740cbeb3ba7bef903d3d2e Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Thu, 28 Nov 2024 19:27:44 +0800 Subject: [PATCH] feedbacks --- .../groups_accumulator/accumulate.rs | 44 ++++++------------- .../functions-aggregate/src/correlation.rs | 10 +---- 2 files changed, 15 insertions(+), 39 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs index 95fa9e8bee03..38afeda8ccf8 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs @@ -385,11 +385,6 @@ pub fn accumulate_multiple( T: ArrowPrimitiveType + Send, F: FnMut(usize, &[T::Native]) + Send, { - let acc_cols: Vec<&[T::Native]> = value_columns - .iter() - .map(|arr| arr.values().as_ref()) - .collect(); - // Calculate `valid_indices` to accumulate, non-valid indices are ignored. // `valid_indices` is a bit mask corresponding to the `group_indices`. An index // is considered valid if: @@ -397,24 +392,10 @@ pub fn accumulate_multiple( // 2. Not filtered out by `opt_filter` // Take AND from all null buffers of `value_columns`. - let mut combined_nulls: Option = None; - - for arr in value_columns.iter() { - if arr.null_count() > 0 { - let nulls = arr - .nulls() - .expect("If null_count() > 0, nulls must be present"); - match combined_nulls { - None => { - combined_nulls = Some(nulls.clone()); - } - Some(ref mut combined) => { - let result = NullBuffer::union(Some(combined), Some(nulls)).unwrap(); - *combined = result.clone(); - } - } - } - } + let combined_nulls = value_columns + .iter() + .map(|arr| arr.nulls()) + .fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls)); // Take AND from previous combined nulls and `opt_filter`. let valid_indices = match (combined_nulls, opt_filter) { @@ -427,7 +408,7 @@ pub fn accumulate_multiple( } }; - for col in acc_cols.iter() { + for col in value_columns.iter() { assert_eq!(col.len(), group_indices.len()); } @@ -435,7 +416,8 @@ pub fn accumulate_multiple( None => { for (idx, &group_idx) in group_indices.iter().enumerate() { // Get `idx`-th row from all value(accumulate) columns - let row_values: Vec<_> = acc_cols.iter().map(|col| col[idx]).collect(); + let row_values: Vec<_> = + value_columns.iter().map(|col| col.value(idx)).collect(); value_fn(group_idx, &row_values); } } @@ -444,7 +426,7 @@ pub fn accumulate_multiple( if valid_indices.value(idx) { // Get `idx`-th row from all value(accumulate) columns let row_values: Vec<_> = - acc_cols.iter().map(|col| col[idx]).collect(); + value_columns.iter().map(|col| col.value(idx)).collect(); value_fn(group_idx, &row_values); } } @@ -1027,7 +1009,7 @@ mod test { let group_indices = vec![0, 1, 0, 1]; let values1 = Int32Array::from(vec![1, 2, 3, 4]); let values2 = Int32Array::from(vec![10, 20, 30, 40]); - let value_columns = vec![values1, values2]; + let value_columns = [values1, values2]; let mut accumulated = vec![]; accumulate_multiple( @@ -1053,7 +1035,7 @@ mod test { let group_indices = vec![0, 1, 0, 1]; let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]); let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]); - let value_columns = vec![values1, values2]; + let value_columns = [values1, values2]; let mut accumulated = vec![]; accumulate_multiple( @@ -1075,7 +1057,7 @@ mod test { let group_indices = vec![0, 1, 0, 1]; let values1 = Int32Array::from(vec![1, 2, 3, 4]); let values2 = Int32Array::from(vec![10, 20, 30, 40]); - let value_columns = vec![values1, values2]; + let value_columns = [values1, values2]; let filter = BooleanArray::from(vec![true, false, true, false]); @@ -1099,7 +1081,7 @@ mod test { let group_indices = vec![0, 1, 0, 1]; let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]); let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]); - let value_columns = vec![values1, values2]; + let value_columns = [values1, values2]; let filter = BooleanArray::from(vec![true, true, true, false]); @@ -1117,7 +1099,7 @@ mod test { // 1. Filter is true // 2. Both columns are non-null // should be accumulated - let expected = vec![(0, vec![1, 10])]; + let expected = [(0, vec![1, 10])]; assert_eq!(accumulated, expected); } } diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index 46e7e27dad59..8b1ef72bc3c0 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -284,6 +284,7 @@ impl Accumulator for CorrelationAccumulator { } } +#[derive(Default)] pub struct CorrelationGroupsAccumulator { // Number of elements for each group // This is also used to track nulls: if a group has 0 valid values accumulated, @@ -303,14 +304,7 @@ pub struct CorrelationGroupsAccumulator { impl CorrelationGroupsAccumulator { pub fn new() -> Self { - Self { - count: Vec::new(), - sum_x: Vec::new(), - sum_y: Vec::new(), - sum_xy: Vec::new(), - sum_xx: Vec::new(), - sum_yy: Vec::new(), - } + Default::default() } }