Skip to content

Commit

Permalink
fix invalid null handling in filter (#296)
Browse files Browse the repository at this point in the history
* fix invalid null handling in filter

* take offset into account

* remove incorrect UB warning
  • Loading branch information
ritchie46 authored May 21, 2021
1 parent e18b356 commit f042191
Showing 1 changed file with 46 additions and 2 deletions.
48 changes: 46 additions & 2 deletions arrow/src/compute/kernels/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

//! Defines miscellaneous array kernels.
use crate::buffer::buffer_bin_and;
use crate::datatypes::DataType;
use crate::error::Result;
use crate::record_batch::RecordBatch;
use crate::{array::*, util::bit_chunk_iterator::BitChunkIterator};
Expand Down Expand Up @@ -204,8 +206,7 @@ pub fn build_filter(filter: &BooleanArray) -> Result<Filter> {
}

/// Filters an [Array], returning elements matching the filter (i.e. where the values are true).
/// WARNING: the nulls of `filter` are ignored and the value on its slot is considered.
/// Therefore, it is considered undefined behavior to pass `filter` with null values.
///
/// # Example
/// ```rust
/// # use arrow::array::{Int32Array, BooleanArray};
Expand All @@ -221,6 +222,25 @@ pub fn build_filter(filter: &BooleanArray) -> Result<Filter> {
/// # }
/// ```
pub fn filter(array: &Array, filter: &BooleanArray) -> Result<ArrayRef> {
if filter.null_count() > 0 {
// this greatly simplifies subsequent filtering code
// now we only have a boolean mask to deal with
let array_data = filter.data_ref();
let null_bitmap = array_data.null_buffer().unwrap();
let mask = filter.values();
let offset = filter.offset();

let new_mask = buffer_bin_and(mask, offset, null_bitmap, offset, filter.len());

let array_data = ArrayData::builder(DataType::Boolean)
.len(filter.len())
.add_buffer(new_mask)
.build();
let filter = BooleanArray::from(array_data);
// fully qualified syntax, because we have an argument with the same name
return crate::compute::kernels::filter::filter(array, &filter);
}

let iter = SlicesIterator::new(filter);

let mut mutable =
Expand Down Expand Up @@ -249,6 +269,7 @@ pub fn filter_record_batch(
#[cfg(test)]
mod tests {
use super::*;
use crate::datatypes::Int64Type;
use crate::{
buffer::Buffer,
datatypes::{DataType, Field},
Expand Down Expand Up @@ -581,4 +602,27 @@ mod tests {
assert_eq!(chunks, vec![(1, 62), (63, 124), (125, 130)]);
assert_eq!(filter_count, 61 + 61 + 5);
}

#[test]
fn test_null_mask() -> Result<()> {
use crate::compute::kernels::comparison;
let a: PrimitiveArray<Int64Type> =
PrimitiveArray::from(vec![Some(1), Some(2), None]);
let mask0 = comparison::eq(&a, &a)?;
let out0 = filter(&a, &mask0)?;
let out_arr0 = out0
.as_any()
.downcast_ref::<PrimitiveArray<Int64Type>>()
.unwrap();

let mask1 = BooleanArray::from(vec![Some(true), Some(true), None]);
let out1 = filter(&a, &mask1)?;
let out_arr1 = out1
.as_any()
.downcast_ref::<PrimitiveArray<Int64Type>>()
.unwrap();
assert_eq!(mask0, mask1);
assert_eq!(out_arr0, out_arr1);
Ok(())
}
}

0 comments on commit f042191

Please sign in to comment.