Skip to content

Commit

Permalink
Fix filter UB and add fast path (#341) (#372)
Browse files Browse the repository at this point in the history
* fix ub in filter record_batch

* filter fast path

* add all false fast path

* use new_empty_array

* rename filter kernel argument

rename argument: 'filter' to 'predicate'
to reduce name collisions.

Co-authored-by: Ritchie Vink <[email protected]>
  • Loading branch information
alamb and ritchie46 authored May 27, 2021
1 parent 58d53cf commit 8cc9b71
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 33 deletions.
2 changes: 1 addition & 1 deletion arrow/src/array/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ impl ArrayData {
}

/// Returns a new empty [ArrayData] valid for `data_type`.
pub(super) fn new_empty(data_type: &DataType) -> Self {
pub fn new_empty(data_type: &DataType) -> Self {
let buffers = new_buffers(data_type, 0);
let [buffer1, buffer2] = buffers;
let buffers = into_buffers(data_type, buffer1, buffer2);
Expand Down
115 changes: 83 additions & 32 deletions arrow/src/compute/kernels/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,37 @@ pub fn build_filter(filter: &BooleanArray) -> Result<Filter> {
let chunks = iter.collect::<Vec<_>>();

Ok(Box::new(move |array: &ArrayData| {
let mut mutable = MutableArrayData::new(vec![array], false, filter_count);
chunks
.iter()
.for_each(|(start, end)| mutable.extend(0, *start, *end));
mutable.freeze()
match filter_count {
// return all
len if len == array.len() => array.clone(),
0 => ArrayData::new_empty(array.data_type()),
_ => {
let mut mutable = MutableArrayData::new(vec![array], false, filter_count);
chunks
.iter()
.for_each(|(start, end)| mutable.extend(0, *start, *end));
mutable.freeze()
}
}
}))
}

/// Folds the validity bitmap of `filter` into its value bits so that null slots read as `false`.
///
/// The returned [BooleanArray] has the same length as `filter` and is built from a single
/// buffer holding the bitwise AND of the value bits and the validity bits; no null buffer
/// is attached to it.
///
/// # Panics
///
/// Panics if `filter` has no null buffer. The callers visible in this file only invoke it
/// after checking `null_count() > 0`.
fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
    let len = filter.len();
    let offset = filter.offset();
    let validity = filter.data_ref().null_buffer().unwrap();

    // a bit survives only when the slot is both valid and true
    let combined = buffer_bin_and(filter.values(), offset, validity, offset, len);

    BooleanArray::from(
        ArrayData::builder(DataType::Boolean)
            .len(len)
            .add_buffer(combined)
            .build(),
    )
}

/// Filters an [Array], returning elements matching the filter (i.e. where the values are true).
///
/// # Example
Expand All @@ -221,43 +244,49 @@ pub fn build_filter(filter: &BooleanArray) -> Result<Filter> {
/// # Ok(())
/// # }
/// ```
pub fn filter(array: &Array, filter: &BooleanArray) -> Result<ArrayRef> {
if filter.null_count() > 0 {
pub fn filter(array: &Array, predicate: &BooleanArray) -> Result<ArrayRef> {
if predicate.null_count() > 0 {
// this greatly simplifies subsequent filtering code
// now we only have a boolean mask to deal with
let array_data = filter.data_ref();
let null_bitmap = array_data.null_buffer().unwrap();
let mask = filter.values();
let offset = filter.offset();

let new_mask = buffer_bin_and(mask, offset, null_bitmap, offset, filter.len());

let array_data = ArrayData::builder(DataType::Boolean)
.len(filter.len())
.add_buffer(new_mask)
.build();
let filter = BooleanArray::from(array_data);
// fully qualified syntax, because we have an argument with the same name
return crate::compute::kernels::filter::filter(array, &filter);
let predicate = prep_null_mask_filter(predicate);
return filter(array, &predicate);
}

let iter = SlicesIterator::new(filter);

let mut mutable =
MutableArrayData::new(vec![array.data_ref()], false, iter.filter_count);
iter.for_each(|(start, end)| mutable.extend(0, start, end));
let data = mutable.freeze();
Ok(make_array(data))
let iter = SlicesIterator::new(predicate);
match iter.filter_count {
0 => {
// return empty
Ok(new_empty_array(array.data_type()))
}
len if len == array.len() => {
// return all
let data = array.data().clone();
Ok(make_array(data))
}
_ => {
// actually filter
let mut mutable =
MutableArrayData::new(vec![array.data_ref()], false, iter.filter_count);
iter.for_each(|(start, end)| mutable.extend(0, start, end));
let data = mutable.freeze();
Ok(make_array(data))
}
}
}

/// Returns a new [RecordBatch] with arrays containing only values matching the filter.
/// Null values in the predicate are treated as `false`: its validity bitmap is ANDed into
/// its values before filtering, so rows whose predicate slot is null are dropped.
pub fn filter_record_batch(
record_batch: &RecordBatch,
filter: &BooleanArray,
predicate: &BooleanArray,
) -> Result<RecordBatch> {
let filter = build_filter(filter)?;
if predicate.null_count() > 0 {
// this greatly simplifies subsequent filtering code
// now we only have a boolean mask to deal with
let predicate = prep_null_mask_filter(predicate);
return filter_record_batch(record_batch, &predicate);
}

let filter = build_filter(predicate)?;
let filtered_arrays = record_batch
.columns()
.iter()
Expand Down Expand Up @@ -625,4 +654,26 @@ mod tests {
assert_eq!(out_arr0, out_arr1);
Ok(())
}

/// Exercises the two shortcut branches of `filter`: a predicate that selects every slot
/// and a predicate that selects none.
#[test]
fn test_fast_path() -> Result<()> {
    let input: PrimitiveArray<Int64Type> =
        PrimitiveArray::from(vec![Some(1), Some(2), None]);

    // every slot selected: the output must compare equal to the input, nulls included
    let select_all = BooleanArray::from(vec![true, true, true]);
    let all_rows = filter(&input, &select_all)?;
    let typed = all_rows
        .as_any()
        .downcast_ref::<PrimitiveArray<Int64Type>>()
        .unwrap();
    assert_eq!(&input, typed);

    // no slot selected: the output is empty but keeps the input's data type
    let select_none = BooleanArray::from(vec![false, false, false]);
    let no_rows = filter(&input, &select_none)?;
    assert_eq!(no_rows.len(), 0);
    assert_eq!(no_rows.data_type(), &DataType::Int64);
    Ok(())
}
}

0 comments on commit 8cc9b71

Please sign in to comment.