Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix equal_to in PrimitiveGroupValueBuilder #12758

Merged
merged 4 commits into from
Oct 5, 2024
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 98 additions & 9 deletions datafusion/physical-plan/src/aggregates/group_values/group_column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,28 @@ impl<T: ArrowPrimitiveType, const NULLABLE: bool> GroupColumn
for PrimitiveGroupValueBuilder<T, NULLABLE>
{
fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool {
// Perf: skip null check (by short circuit) if input is not ullable
let null_match = if NULLABLE {
self.nulls.is_null(lhs_row) == array.is_null(rhs_row)
} else {
true
};
// Perf: skip null check (by short circuit) if input is not nullable
if NULLABLE {
// In nullable path, we should check if both `exist row` and `input row`
// are null/not null
let is_exist_null = self.nulls.is_null(lhs_row);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking about this last night and I think the byte buffer below has the same problem. I will make a follow on PR

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed it does -- #12770 to fix

let null_match = self.nulls.is_null(lhs_row) == array.is_null(rhs_row);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be faster to avoid calling is_null twice:

Suggested change
let null_match = self.nulls.is_null(lhs_row) == array.is_null(rhs_row);
let null_match = is_exist_null == array.is_null(rhs_row);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Fixed.

if !null_match {
// If `is_null`s in `exist row` and `input row` don't match, return not equal to
return false;
} else if is_exist_null {
// If `is_null`s in `exist row` and `input row` match, and they are `null`s,
// return equal to
//
// NOTICE: we should not check their values when they are `null`s, because they are
// meaningless actually, and not ensured to be same
//
return true;
}
// Otherwise, we need to check their values
}

null_match
&& self.group_values[lhs_row] == array.as_primitive::<T>().value(rhs_row)
self.group_values[lhs_row] == array.as_primitive::<T>().value(rhs_row)
}

fn append_val(&mut self, array: &ArrayRef, row: usize) {
Expand Down Expand Up @@ -373,9 +386,13 @@ where
mod tests {
use std::sync::Arc;

use arrow_array::{ArrayRef, StringArray};
use arrow::datatypes::Int64Type;
use arrow_array::{ArrayRef, Int64Array, StringArray};
use arrow_buffer::{BooleanBufferBuilder, NullBuffer};
use datafusion_physical_expr::binary_map::OutputType;

use crate::aggregates::group_values::group_column::PrimitiveGroupValueBuilder;

use super::{ByteGroupValueBuilder, GroupColumn};

#[test]
Expand Down Expand Up @@ -422,4 +439,76 @@ mod tests {
])) as ArrayRef;
assert_eq!(&output, &array);
}

#[test]
fn test_nullable_primitive_equal_to() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if there is some way to write a reproducer in an end to end test (as in .slt as well) 🤔

Copy link
Contributor Author

@Rachelint Rachelint Oct 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤔 I am trying it, but I still have no idea.
I found it through the fuzz tests in #12667 . And to be honest, I am still confused about why the null row will have the non-default value...

Copy link
Contributor Author

@Rachelint Rachelint Oct 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got it !

It may be due to the compute::take function using to generate the random dataset.
https://github.com/Rachelint/arrow-datafusion/blob/9ad971be0c6c808e77a74cdfc571a33732a0838a/test-utils/src/array_gen/primitive.rs#L48-L62

Let's see take's implementation:

// In `take_primitive`:
let values_buf = take_native(values.values(), indices);
let nulls = take_nulls(values.nulls(), indices);

// In `take_native`:
    match indices.nulls().filter(|n| n.null_count() > 0) {
        Some(n) => indices
            .values()
            .iter()
            .enumerate()
            .map(|(idx, index)| match values.get(index.as_usize()) {
                Some(v) => *v,
                None => match n.is_null(idx) {
                    true => T::default(),
                    false => panic!("Out-of-bounds index {index:?}"),
                },
            })
            .collect(),

It will still try to take value from values ranther than using default value, even the row in indicies is null...

This logic seems unreasonable actually?

Copy link
Contributor Author

@Rachelint Rachelint Oct 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may be a bit hard to produce it in end to end test.

The null row with a non-default value is only possible to exist in some special cases,
like generating through take as mentioned above, or as I remember we use it to improve filter performance in avg accumulator.

// Will cover such cases:
// - exist null, input not null
// - exist null, input null; values not equal
// - exist null, input null; values equal
// - exist not null, input null
// - exist not null, input not null; values not equal
// - exist not null, input not null; values equal

// Define PrimitiveGroupValueBuilder
let mut builder = PrimitiveGroupValueBuilder::<Int64Type, true>::new();
let builder_array = Arc::new(Int64Array::from(vec![
None,
None,
None,
Some(1),
Some(2),
Some(3),
])) as ArrayRef;
builder.append_val(&builder_array, 0);
builder.append_val(&builder_array, 1);
builder.append_val(&builder_array, 2);
builder.append_val(&builder_array, 3);
builder.append_val(&builder_array, 4);
builder.append_val(&builder_array, 5);

// Define input array
let (_, values, _) =
Int64Array::from(vec![Some(1), Some(2), None, None, Some(1), Some(3)])
.into_parts();

let mut boolean_buffer_builder = BooleanBufferBuilder::new(6);
boolean_buffer_builder.append(true);
boolean_buffer_builder.append(false);
boolean_buffer_builder.append(false);
boolean_buffer_builder.append(false);
boolean_buffer_builder.append(true);
boolean_buffer_builder.append(true);
let nulls = NullBuffer::new(boolean_buffer_builder.finish());
let input_array = Arc::new(Int64Array::new(values, Some(nulls))) as ArrayRef;

// Check
assert!(!builder.equal_to(0, &input_array, 0));
assert!(builder.equal_to(1, &input_array, 1));
assert!(builder.equal_to(2, &input_array, 2));
assert!(!builder.equal_to(3, &input_array, 3));
assert!(!builder.equal_to(4, &input_array, 4));
assert!(builder.equal_to(5, &input_array, 5));
}

#[test]
fn test_not_nullable_primitive_equal_to() {
// Will cover such cases:
// - values equal
// - values not equal

// Define PrimitiveGroupValueBuilder
let mut builder = PrimitiveGroupValueBuilder::<Int64Type, false>::new();
let builder_array =
Arc::new(Int64Array::from(vec![Some(0), Some(1)])) as ArrayRef;
builder.append_val(&builder_array, 0);
builder.append_val(&builder_array, 1);

// Define input array
let input_array = Arc::new(Int64Array::from(vec![Some(0), Some(2)])) as ArrayRef;

// Check
assert!(builder.equal_to(0, &input_array, 0));
assert!(!builder.equal_to(1, &input_array, 1));
}
}