Skip to content

Commit

Permalink
feedbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
2010YOUY01 committed Nov 28, 2024
1 parent a834fda commit 773b9c5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -385,36 +385,17 @@ pub fn accumulate_multiple<T, F>(
T: ArrowPrimitiveType + Send,
F: FnMut(usize, &[T::Native]) + Send,
{
let acc_cols: Vec<&[T::Native]> = value_columns
.iter()
.map(|arr| arr.values().as_ref())
.collect();

// Calculate `valid_indices` to accumulate, non-valid indices are ignored.
// `valid_indices` is a bit mask corresponding to the `group_indices`. An index
// is considered valid if:
// 1. All columns are non-null at this index.
// 2. Not filtered out by `opt_filter`

// Take AND from all null buffers of `value_columns`.
let mut combined_nulls: Option<NullBuffer> = None;

for arr in value_columns.iter() {
if arr.null_count() > 0 {
let nulls = arr
.nulls()
.expect("If null_count() > 0, nulls must be present");
match combined_nulls {
None => {
combined_nulls = Some(nulls.clone());
}
Some(ref mut combined) => {
let result = NullBuffer::union(Some(combined), Some(nulls)).unwrap();
*combined = result.clone();
}
}
}
}
let combined_nulls = value_columns
.iter()
.map(|arr| arr.nulls())
.fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls));

// Take AND from previous combined nulls and `opt_filter`.
let valid_indices = match (combined_nulls, opt_filter) {
Expand All @@ -427,15 +408,16 @@ pub fn accumulate_multiple<T, F>(
}
};

for col in acc_cols.iter() {
for col in value_columns.iter() {
assert_eq!(col.len(), group_indices.len());
}

match valid_indices {
None => {
for (idx, &group_idx) in group_indices.iter().enumerate() {
// Get `idx`-th row from all value(accumulate) columns
let row_values: Vec<_> = acc_cols.iter().map(|col| col[idx]).collect();
let row_values: Vec<_> =
value_columns.iter().map(|col| col.value(idx)).collect();
value_fn(group_idx, &row_values);
}
}
Expand All @@ -444,7 +426,7 @@ pub fn accumulate_multiple<T, F>(
if valid_indices.value(idx) {
// Get `idx`-th row from all value(accumulate) columns
let row_values: Vec<_> =
acc_cols.iter().map(|col| col[idx]).collect();
value_columns.iter().map(|col| col.value(idx)).collect();
value_fn(group_idx, &row_values);
}
}
Expand Down Expand Up @@ -1027,7 +1009,7 @@ mod test {
let group_indices = vec![0, 1, 0, 1];
let values1 = Int32Array::from(vec![1, 2, 3, 4]);
let values2 = Int32Array::from(vec![10, 20, 30, 40]);
let value_columns = vec![values1, values2];
let value_columns = [values1, values2];

let mut accumulated = vec![];
accumulate_multiple(
Expand All @@ -1053,7 +1035,7 @@ mod test {
let group_indices = vec![0, 1, 0, 1];
let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]);
let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]);
let value_columns = vec![values1, values2];
let value_columns = [values1, values2];

let mut accumulated = vec![];
accumulate_multiple(
Expand All @@ -1075,7 +1057,7 @@ mod test {
let group_indices = vec![0, 1, 0, 1];
let values1 = Int32Array::from(vec![1, 2, 3, 4]);
let values2 = Int32Array::from(vec![10, 20, 30, 40]);
let value_columns = vec![values1, values2];
let value_columns = [values1, values2];

let filter = BooleanArray::from(vec![true, false, true, false]);

Expand All @@ -1099,7 +1081,7 @@ mod test {
let group_indices = vec![0, 1, 0, 1];
let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]);
let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]);
let value_columns = vec![values1, values2];
let value_columns = [values1, values2];

let filter = BooleanArray::from(vec![true, true, true, false]);

Expand All @@ -1117,7 +1099,7 @@ mod test {
// 1. Filter is true
// 2. Both columns are non-null
// should be accumulated
let expected = vec![(0, vec![1, 10])];
let expected = [(0, vec![1, 10])];
assert_eq!(accumulated, expected);
}
}
10 changes: 2 additions & 8 deletions datafusion/functions-aggregate/src/correlation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ impl Accumulator for CorrelationAccumulator {
}
}

#[derive(Default)]
pub struct CorrelationGroupsAccumulator {
// Number of elements for each group
// This is also used to track nulls: if a group has 0 valid values accumulated,
Expand All @@ -303,14 +304,7 @@ pub struct CorrelationGroupsAccumulator {

impl CorrelationGroupsAccumulator {
pub fn new() -> Self {
Self {
count: Vec::new(),
sum_x: Vec::new(),
sum_y: Vec::new(),
sum_xy: Vec::new(),
sum_xx: Vec::new(),
sum_yy: Vec::new(),
}
Default::default()
}
}

Expand Down

0 comments on commit 773b9c5

Please sign in to comment.