diff --git a/arrow/benches/filter_kernels.rs b/arrow/benches/filter_kernels.rs
index d5ff09c040b8..be6d9027a8db 100644
--- a/arrow/benches/filter_kernels.rs
+++ b/arrow/benches/filter_kernels.rs
@@ -18,13 +18,13 @@ extern crate arrow;
 use std::sync::Arc;
 
-use arrow::compute::{filter_record_batch, Filter};
+use arrow::compute::{filter_record_batch, FilterBuilder, FilterPredicate};
 use arrow::record_batch::RecordBatch;
 use arrow::util::bench_util::*;
 
 use arrow::array::*;
-use arrow::compute::{build_filter, filter};
-use arrow::datatypes::{Field, Float32Type, Schema, UInt8Type};
+use arrow::compute::filter;
+use arrow::datatypes::{Field, Float32Type, Int32Type, Schema, UInt8Type};
 
 use criterion::{criterion_group, criterion_main, Criterion};
 
@@ -32,8 +32,8 @@ fn bench_filter(data_array: &dyn Array, filter_array: &BooleanArray) {
     criterion::black_box(filter(data_array, filter_array).unwrap());
 }
 
-fn bench_built_filter<'a>(filter: &Filter<'a>, data: &impl Array) {
-    criterion::black_box(filter(data.data()));
+fn bench_built_filter(filter: &FilterPredicate, array: &dyn Array) {
+    criterion::black_box(filter.filter(array).unwrap());
 }
 
 fn add_benchmark(c: &mut Criterion) {
@@ -42,68 +42,145 @@ fn add_benchmark(c: &mut Criterion) {
     let dense_filter_array = create_boolean_array(size, 0.0, 1.0 - 1.0 / 1024.0);
     let sparse_filter_array = create_boolean_array(size, 0.0, 1.0 / 1024.0);
 
-    let filter = build_filter(&filter_array).unwrap();
-    let dense_filter = build_filter(&dense_filter_array).unwrap();
-    let sparse_filter = build_filter(&sparse_filter_array).unwrap();
+    let filter = FilterBuilder::new(&filter_array).optimize().build();
+    let dense_filter = FilterBuilder::new(&dense_filter_array).optimize().build();
+    let sparse_filter = FilterBuilder::new(&sparse_filter_array).optimize().build();
 
     let data_array = create_primitive_array::<UInt8Type>(size, 0.0);
 
-    c.bench_function("filter u8", |b| {
+    c.bench_function("filter optimize (kept 1/2)", |b| {
+        b.iter(|| FilterBuilder::new(&filter_array).optimize().build())
+    });
+
+    c.bench_function("filter optimize high selectivity (kept 1023/1024)", |b| {
+        b.iter(|| FilterBuilder::new(&dense_filter_array).optimize().build())
+    });
+
+    c.bench_function("filter optimize low selectivity (kept 1/1024)", |b| {
+        b.iter(|| FilterBuilder::new(&sparse_filter_array).optimize().build())
+    });
+
+    c.bench_function("filter u8 (kept 1/2)", |b| {
         b.iter(|| bench_filter(&data_array, &filter_array))
     });
-    c.bench_function("filter u8 high selectivity", |b| {
+    c.bench_function("filter u8 high selectivity (kept 1023/1024)", |b| {
         b.iter(|| bench_filter(&data_array, &dense_filter_array))
     });
-    c.bench_function("filter u8 low selectivity", |b| {
+    c.bench_function("filter u8 low selectivity (kept 1/1024)", |b| {
         b.iter(|| bench_filter(&data_array, &sparse_filter_array))
     });
 
-    c.bench_function("filter context u8", |b| {
+    c.bench_function("filter context u8 (kept 1/2)", |b| {
         b.iter(|| bench_built_filter(&filter, &data_array))
     });
-    c.bench_function("filter context u8 high selectivity", |b| {
+    c.bench_function("filter context u8 high selectivity (kept 1023/1024)", |b| {
         b.iter(|| bench_built_filter(&dense_filter, &data_array))
     });
-    c.bench_function("filter context u8 low selectivity", |b| {
+    c.bench_function("filter context u8 low selectivity (kept 1/1024)", |b| {
         b.iter(|| bench_built_filter(&sparse_filter, &data_array))
     });
 
-    let data_array = create_primitive_array::<UInt8Type>(size, 0.5);
-    c.bench_function("filter context u8 w NULLs", |b| {
-        b.iter(|| bench_built_filter(&filter, &data_array))
+    let data_array = create_primitive_array::<Int32Type>(size, 0.0);
+    c.bench_function("filter i32 (kept 1/2)", |b| {
+        b.iter(|| bench_filter(&data_array, &filter_array))
     });
-    c.bench_function("filter context u8 w NULLs high selectivity", |b| {
-        b.iter(|| bench_built_filter(&dense_filter, &data_array))
+    c.bench_function("filter i32 high selectivity (kept 1023/1024)", |b| {
+        b.iter(|| bench_filter(&data_array, &dense_filter_array))
+    });
+    c.bench_function("filter i32 low selectivity (kept 1/1024)", |b| {
+        b.iter(|| bench_filter(&data_array, &sparse_filter_array))
+    });
+
+    c.bench_function("filter context i32 (kept 1/2)", |b| {
+        b.iter(|| bench_built_filter(&filter, &data_array))
     });
-    c.bench_function("filter context u8 w NULLs low selectivity", |b| {
+    c.bench_function(
+        "filter context i32 high selectivity (kept 1023/1024)",
+        |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)),
+    );
+    c.bench_function("filter context i32 low selectivity (kept 1/1024)", |b| {
         b.iter(|| bench_built_filter(&sparse_filter, &data_array))
     });
 
+    let data_array = create_primitive_array::<Int32Type>(size, 0.5);
+    c.bench_function("filter context i32 w NULLs (kept 1/2)", |b| {
+        b.iter(|| bench_built_filter(&filter, &data_array))
+    });
+    c.bench_function(
+        "filter context i32 w NULLs high selectivity (kept 1023/1024)",
+        |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)),
+    );
+    c.bench_function(
+        "filter context i32 w NULLs low selectivity (kept 1/1024)",
+        |b| b.iter(|| bench_built_filter(&sparse_filter, &data_array)),
+    );
+
+    let data_array = create_primitive_array::<UInt8Type>(size, 0.5);
+    c.bench_function("filter context u8 w NULLs (kept 1/2)", |b| {
+        b.iter(|| bench_built_filter(&filter, &data_array))
+    });
+    c.bench_function(
+        "filter context u8 w NULLs high selectivity (kept 1023/1024)",
+        |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)),
+    );
+    c.bench_function(
+        "filter context u8 w NULLs low selectivity (kept 1/1024)",
+        |b| b.iter(|| bench_built_filter(&sparse_filter, &data_array)),
+    );
+
     let data_array = create_primitive_array::<Float32Type>(size, 0.5);
-    c.bench_function("filter f32", |b| {
+    c.bench_function("filter f32 (kept 1/2)", |b| {
         b.iter(|| bench_filter(&data_array, &filter_array))
     });
-    c.bench_function("filter context f32", |b| {
+    c.bench_function("filter context f32 (kept 1/2)", |b| {
         b.iter(|| bench_built_filter(&filter, &data_array))
     });
-    c.bench_function("filter context f32 high selectivity", |b| {
-        b.iter(|| bench_built_filter(&dense_filter, &data_array))
-    });
-    c.bench_function("filter context f32 low selectivity", |b| {
+    c.bench_function(
+        "filter context f32 high selectivity (kept 1023/1024)",
+        |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)),
+    );
+    c.bench_function("filter context f32 low selectivity (kept 1/1024)", |b| {
         b.iter(|| bench_built_filter(&sparse_filter, &data_array))
     });
 
     let data_array = create_string_array::<i32>(size, 0.5);
-    c.bench_function("filter context string", |b| {
+    c.bench_function("filter context string (kept 1/2)", |b| {
         b.iter(|| bench_built_filter(&filter, &data_array))
     });
-    c.bench_function("filter context string high selectivity", |b| {
-        b.iter(|| bench_built_filter(&dense_filter, &data_array))
-    });
-    c.bench_function("filter context string low selectivity", |b| {
+    c.bench_function(
+        "filter context string high selectivity (kept 1023/1024)",
+        |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)),
+    );
+    c.bench_function("filter context string low selectivity (kept 1/1024)", |b| {
         b.iter(|| bench_built_filter(&sparse_filter, &data_array))
     });
 
+    let data_array = create_string_dict_array::<Int32Type>(size, 0.0);
+    c.bench_function("filter context string dictionary (kept 1/2)", |b| {
+        b.iter(|| bench_built_filter(&filter, &data_array))
+    });
+    c.bench_function(
+        "filter context string dictionary high selectivity (kept 1023/1024)",
+        |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)),
+    );
+    c.bench_function(
+        "filter context string dictionary low selectivity (kept 1/1024)",
+        |b| b.iter(|| bench_built_filter(&sparse_filter, &data_array)),
+    );
+
+    let data_array = create_string_dict_array::<Int32Type>(size, 0.5);
+    c.bench_function("filter context string dictionary w NULLs (kept 1/2)", |b| {
+        b.iter(|| bench_built_filter(&filter, &data_array))
+    });
+    c.bench_function(
+        "filter context string dictionary w NULLs high selectivity (kept 1023/1024)",
+        |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)),
+    );
+    c.bench_function(
+        "filter context string dictionary w NULLs low selectivity (kept 1/1024)",
+        |b| b.iter(|| bench_built_filter(&sparse_filter, &data_array)),
+    );
+
     let data_array = create_primitive_array::<Float32Type>(size, 0.0);
 
     let field = Field::new("c1", data_array.data_type().clone(), true);
diff --git a/arrow/src/compute/kernels/filter.rs b/arrow/src/compute/kernels/filter.rs
index 041826372536..e90add475883 100644
--- a/arrow/src/compute/kernels/filter.rs
+++ b/arrow/src/compute/kernels/filter.rs
@@ -17,24 +17,60 @@
 //! Defines miscellaneous array kernels.
 
+use std::ops::AddAssign;
+use std::sync::Arc;
+
+use num::Zero;
+
+use TimeUnit::*;
+
 use crate::array::*;
-use crate::buffer::buffer_bin_and;
-use crate::datatypes::DataType;
-use crate::error::Result;
+use crate::buffer::{buffer_bin_and, Buffer, MutableBuffer};
+use crate::datatypes::*;
+use crate::error::{ArrowError, Result};
 use crate::record_batch::RecordBatch;
 use crate::util::bit_chunk_iterator::{UnalignedBitChunk, UnalignedBitChunkIterator};
+use crate::util::bit_util;
 
-/// Function that can filter arbitrary arrays
-pub type Filter<'a> = Box<dyn Fn(&ArrayData) -> ArrayData + 'a>;
+/// If the filter selects more than this fraction of rows, use
+/// [`SlicesIterator`] to copy ranges of values. Otherwise iterate
+/// over individual rows using [`IndexIterator`]
+///
+/// Threshold of 0.8 chosen based on
+///
+const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
+
+macro_rules! downcast_filter {
+    ($type: ty, $values: expr, $filter: expr) => {{
+        let values = $values
+            .as_any()
+            .downcast_ref::<PrimitiveArray<$type>>()
+            .expect("Unable to downcast to a primitive array");
+
+        Ok(Arc::new(filter_primitive::<$type>(&values, $filter)))
+    }};
+}
+
+macro_rules! downcast_dict_filter {
+    ($type: ty, $values: expr, $filter: expr) => {{
+        let values = $values
+            .as_any()
+            .downcast_ref::<DictionaryArray<$type>>()
+            .expect("Unable to downcast to a dictionary array");
+        Ok(Arc::new(filter_dict::<$type>(values, $filter)))
+    }};
+}
 
-/// An iterator of `(usize, usize)` each representing an interval `[start,end[` whose
-/// slots of a [BooleanArray] are true. Each interval corresponds to a contiguous region of memory to be
-/// "taken" from an array to be filtered.
+/// An iterator of `(usize, usize)` each representing a half-open interval `[start, end)` whose
+/// slots of a [BooleanArray] are true. Each interval corresponds to a contiguous region of memory
+/// to be "taken" from an array to be filtered.
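+///
+/// For example, the mask `[true, true, false, true]` yields the slices `(0, 2)`
+/// and `(3, 4)`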
+///
+/// This is only performant for filters that copy across long contiguous runs
 #[derive(Debug)]
 pub struct SlicesIterator<'a> {
     iter: UnalignedBitChunkIterator<'a>,
     len: usize,
-    chunk_end_offset: usize,
+    current_offset: i64,
     current_chunk: u64,
 }
 
@@ -45,13 +81,13 @@ impl<'a> SlicesIterator<'a> {
         let chunk = UnalignedBitChunk::new(values.as_slice(), filter.offset(), len);
         let mut iter = chunk.iter();
 
-        let chunk_end_offset = 64 - chunk.lead_padding();
+        let current_offset = -(chunk.lead_padding() as i64);
         let current_chunk = iter.next().unwrap_or(0);
 
         Self {
             iter,
             len,
-            chunk_end_offset,
+            current_offset,
             current_chunk,
         }
     }
@@ -59,18 +95,18 @@ impl<'a> SlicesIterator<'a> {
     /// Returns `Some((chunk_offset, bit_offset))` for the next chunk that has at
     /// least one bit set, or None if there is no such chunk.
     ///
-    /// Where `chunk_offset` is the bit offset to the current `usize`d chunk
+    /// Where `chunk_offset` is the bit offset to the current `u64` chunk
     /// and `bit_offset` is the offset of the first `1` bit in that chunk
-    fn advance_to_set_bit(&mut self) -> Option<(usize, u32)> {
+    fn advance_to_set_bit(&mut self) -> Option<(i64, u32)> {
         loop {
             if self.current_chunk != 0 {
                 // Find the index of the first 1
                 let bit_pos = self.current_chunk.trailing_zeros();
-                return Some((self.chunk_end_offset, bit_pos));
+                return Some((self.current_offset, bit_pos));
             }
 
             self.current_chunk = self.iter.next()?;
-            self.chunk_end_offset += 64;
+            self.current_offset += 64;
         }
     }
 }
@@ -98,19 +134,19 @@ impl<'a> Iterator for SlicesIterator<'a> {
                 self.current_chunk &= !((1 << end_bit) - 1);
 
                 return Some((
-                    start_chunk + start_bit as usize - 64,
-                    self.chunk_end_offset + end_bit as usize - 64,
+                    (start_chunk + start_bit as i64) as usize,
+                    (self.current_offset + end_bit as i64) as usize,
                 ));
             }
 
             match self.iter.next() {
                 Some(next) => {
                     self.current_chunk = next;
-                    self.chunk_end_offset += 64;
+                    self.current_offset += 64;
                 }
                 None => {
                     return Some((
-                        start_chunk + start_bit as usize - 64,
+                        (start_chunk + start_bit as i64) as usize,
                         std::mem::replace(&mut self.len, 0),
                     ));
                 }
@@ -119,17 +155,83 @@ impl<'a> Iterator for SlicesIterator<'a> {
     }
 }
 
+/// An iterator of the `usize` positions at which a [`BooleanArray`] is true
+///
+/// This provides the best performance on most predicates, apart from those which keep
+/// large runs and therefore favour [`SlicesIterator`]
+struct IndexIterator<'a> {
+    current_chunk: u64,
+    chunk_offset: i64,
+    remaining: usize,
+    iter: UnalignedBitChunkIterator<'a>,
+}
+
+impl<'a> IndexIterator<'a> {
+    fn new(filter: &'a BooleanArray, len: usize) -> Self {
+        assert_eq!(filter.null_count(), 0);
+        let data = filter.data();
+        let chunks =
+            UnalignedBitChunk::new(&data.buffers()[0], data.offset(), data.len());
+        let mut iter = chunks.iter();
+
+        let current_chunk = iter.next().unwrap_or(0);
+        let chunk_offset = -(chunks.lead_padding() as i64);
+
+        Self {
+            current_chunk,
+            chunk_offset,
+            remaining: len,
+            iter,
+        }
+    }
+}
+
+impl<'a> Iterator for IndexIterator<'a> {
+    type Item = usize;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        while self.remaining != 0 {
+            if self.current_chunk != 0 {
+                let bit_pos = self.current_chunk.trailing_zeros();
+                self.current_chunk ^= 1 << bit_pos;
+                self.remaining -= 1;
+                return Some((self.chunk_offset + bit_pos as i64) as usize);
+            }
+
+            // Must panic if exhausted early as trusted length iterator
+            self.current_chunk = self.iter.next().expect("IndexIterator exhausted early");
+            self.chunk_offset += 64;
+        }
+        None
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        (self.remaining, Some(self.remaining))
+    }
+}
+
+/// Counts the number of set bits in `filter`
 fn filter_count(filter: &BooleanArray) -> usize {
     filter
         .values()
         .count_set_bits_offset(filter.offset(), filter.len())
 }
 
+/// Function that can filter arbitrary arrays
+///
+/// Deprecated: Use [`FilterPredicate`] instead
+#[deprecated]
+pub type Filter<'a> = Box<dyn Fn(&ArrayData) -> ArrayData + 'a>;
+
 /// Returns a prepared function optimized to filter multiple arrays.
 /// Creating this function requires time, but using it is faster than [filter] when the
 /// same filter needs to be applied to multiple arrays (e.g. a multi-column `RecordBatch`).
 /// WARNING: the nulls of `filter` are ignored and the value on its slot is considered.
 /// Therefore, it is considered undefined behavior to pass `filter` with null values.
+///
+/// Deprecated: Use [`FilterBuilder`] instead
+#[deprecated]
+#[allow(deprecated)]
 pub fn build_filter(filter: &BooleanArray) -> Result<Filter> {
     let iter = SlicesIterator::new(filter);
     let filter_count = filter_count(filter);
@@ -185,79 +287,600 @@ pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
 /// # Ok(())
 /// # }
 /// ```
-pub fn filter(array: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef> {
-    if predicate.null_count() > 0 {
-        // this greatly simplifies subsequent filtering code
-        // now we only have a boolean mask to deal with
-        let predicate = prep_null_mask_filter(predicate);
-        return filter(array, &predicate);
+pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef> {
+    let predicate = FilterBuilder::new(predicate).build();
+    filter_array(values, &predicate)
+}
+
+/// Returns a new [RecordBatch] with arrays containing only values matching the filter.
+pub fn filter_record_batch(
+    record_batch: &RecordBatch,
+    predicate: &BooleanArray,
+) -> Result<RecordBatch> {
+    let mut filter_builder = FilterBuilder::new(predicate);
+    if record_batch.num_columns() > 1 {
+        // Only optimize if filtering more than one column
+        filter_builder = filter_builder.optimize();
+    }
+    let filter = filter_builder.build();
+
+    let filtered_arrays = record_batch
+        .columns()
+        .iter()
+        .map(|a| filter_array(a, &filter))
+        .collect::<Result<Vec<_>>>()?;
+
+    RecordBatch::try_new(record_batch.schema(), filtered_arrays)
+}
+
+/// A builder to construct [`FilterPredicate`]
+#[derive(Debug)]
+pub struct FilterBuilder {
+    filter: BooleanArray,
+    count: usize,
+    strategy: IterationStrategy,
+}
+
+impl FilterBuilder {
+    /// Create a new [`FilterBuilder`] that can be used to construct a [`FilterPredicate`]
+    pub fn new(filter: &BooleanArray) -> Self {
+        let filter = match filter.null_count() {
+            0 => BooleanArray::from(filter.data().clone()),
+            _ => prep_null_mask_filter(filter),
+        };
+
+        let count = filter_count(&filter);
+        let strategy = IterationStrategy::default_strategy(filter.len(), count);
+
+        Self {
+            filter,
+            count,
+            strategy,
+        }
     }
-    let filter_count = filter_count(predicate);
+
+    /// Compute an optimised representation of the provided `filter` mask that can be
+    /// applied to an array more quickly.
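+    ///
+    /// A minimal usage sketch, building the predicate once and applying it to several
+    /// arrays (this assumes `FilterBuilder` and `FilterPredicate` are re-exported from
+    /// `arrow::compute`, as the benchmarks in this change expect):
+    ///
+    /// ```
+    /// use arrow::array::{BooleanArray, Int32Array, StringArray};
+    /// use arrow::compute::FilterBuilder;
+    ///
+    /// let predicate = BooleanArray::from(vec![true, false, true]);
+    /// let filter = FilterBuilder::new(&predicate).optimize().build();
+    ///
+    /// let a = Int32Array::from(vec![1, 2, 3]);
+    /// let b = StringArray::from(vec!["x", "y", "z"]);
+    ///
+    /// // The same optimized predicate can be applied to multiple arrays
+    /// let filtered_a = filter.filter(&a).unwrap();
+    /// let filtered_b = filter.filter(&b).unwrap();
+    /// ```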
+    ///
+    /// Note: There is limited benefit to calling this to then filter a single array
+    /// Note: This will likely have a larger memory footprint than the original mask
+    pub fn optimize(mut self) -> Self {
+        match self.strategy {
+            IterationStrategy::SlicesIterator => {
+                let slices = SlicesIterator::new(&self.filter).collect();
+                self.strategy = IterationStrategy::Slices(slices)
+            }
+            IterationStrategy::IndexIterator => {
+                let indices = IndexIterator::new(&self.filter, self.count).collect();
+                self.strategy = IterationStrategy::Indices(indices)
+            }
+            _ => {}
+        }
+        self
+    }
 
-    match filter_count {
-        0 => {
-            // return empty
-            Ok(new_empty_array(array.data_type()))
+    /// Construct the final `FilterPredicate`
+    pub fn build(self) -> FilterPredicate {
+        FilterPredicate {
+            filter: self.filter,
+            count: self.count,
+            strategy: self.strategy,
         }
-        len if len == array.len() => {
-            // return all
-            let data = array.data().clone();
-            Ok(make_array(data))
+    }
+}
+
+/// The iteration strategy used to evaluate [`FilterPredicate`]
+#[derive(Debug)]
+enum IterationStrategy {
+    /// A lazily evaluated iterator of ranges
+    SlicesIterator,
+    /// A lazily evaluated iterator of indices
+    IndexIterator,
+    /// A precomputed list of indices
+    Indices(Vec<usize>),
+    /// A precomputed array of ranges
+    Slices(Vec<(usize, usize)>),
+    /// Select all rows
+    All,
+    /// Select no rows
+    None,
+}
+
+impl IterationStrategy {
+    /// The default [`IterationStrategy`] for a filter of length `filter_length`
+    /// and selecting `filter_count` rows
+    fn default_strategy(filter_length: usize, filter_count: usize) -> Self {
+        if filter_length == 0 || filter_count == 0 {
+            return IterationStrategy::None;
         }
-        _ => {
-            // actually filter
-            let mut mutable =
-                MutableArrayData::new(vec![array.data_ref()], false, filter_count);
-            let iter = SlicesIterator::new(predicate);
-            iter.for_each(|(start, end)| mutable.extend(0, start, end));
 
+        if filter_count == filter_length {
+            return IterationStrategy::All;
+        }
-            let data = mutable.freeze();
-            Ok(make_array(data))
+
+        // Compute the selectivity of the predicate by dividing the number of true
+        // bits in the predicate by the predicate's total length
+        //
+        // This can then be used as a heuristic for the optimal iteration strategy
+        let selectivity_frac = filter_count as f64 / filter_length as f64;
+        if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD {
+            return IterationStrategy::SlicesIterator;
         }
+        IterationStrategy::IndexIterator
     }
 }
 
-/// Returns a new [RecordBatch] with arrays containing only values matching the filter.
-pub fn filter_record_batch(
-    record_batch: &RecordBatch,
-    predicate: &BooleanArray,
-) -> Result<RecordBatch> {
-    if predicate.null_count() > 0 {
-        // this greatly simplifies subsequent filtering code
-        // now we only have a boolean mask to deal with
-        let predicate = prep_null_mask_filter(predicate);
-        return filter_record_batch(record_batch, &predicate);
+/// A filtering predicate that can be applied to an [`Array`]
+#[derive(Debug)]
+pub struct FilterPredicate {
+    filter: BooleanArray,
+    count: usize,
+    strategy: IterationStrategy,
+}
+
+impl FilterPredicate {
+    /// Selects rows from `values` based on this [`FilterPredicate`]
+    pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef> {
+        filter_array(values, self)
     }
+}
 
-    let num_columns = record_batch.columns().len();
+fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef> {
+    if predicate.filter.len() > values.len() {
+        return Err(ArrowError::InvalidArgumentError(format!(
+            "Filter predicate of length {} is larger than target array of length {}",
+            predicate.filter.len(),
+            values.len()
+        )));
+    }
 
-    let filtered_arrays = match num_columns {
-        1 => {
-            vec![filter(record_batch.columns()[0].as_ref(), predicate)?]
+    match predicate.strategy {
+        IterationStrategy::None => Ok(new_empty_array(values.data_type())),
+        IterationStrategy::All => Ok(make_array(values.data().slice(0, predicate.count))),
+        // actually filter
+        _ => match values.data_type() {
+            DataType::Boolean => {
+                let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
+                Ok(Arc::new(filter_boolean(values, predicate)))
+            }
+            DataType::Int8 => {
+                downcast_filter!(Int8Type, values, predicate)
+            }
+            DataType::Int16 => {
+                downcast_filter!(Int16Type, values, predicate)
+            }
+            DataType::Int32 => {
+                downcast_filter!(Int32Type, values, predicate)
+            }
+            DataType::Int64 => {
+                downcast_filter!(Int64Type, values, predicate)
+            }
+            DataType::UInt8 => {
+                downcast_filter!(UInt8Type, values, predicate)
+            }
+            DataType::UInt16 => {
+                downcast_filter!(UInt16Type, values, predicate)
+            }
+            DataType::UInt32 => {
+                downcast_filter!(UInt32Type, values, predicate)
+            }
+            DataType::UInt64 => {
+                downcast_filter!(UInt64Type, values, predicate)
+            }
+            DataType::Float32 => {
+                downcast_filter!(Float32Type, values, predicate)
+            }
+            DataType::Float64 => {
+                downcast_filter!(Float64Type, values, predicate)
+            }
+            DataType::Date32 => {
+                downcast_filter!(Date32Type, values, predicate)
+            }
+            DataType::Date64 => {
+                downcast_filter!(Date64Type, values, predicate)
+            }
+            DataType::Time32(Second) => {
+                downcast_filter!(Time32SecondType, values, predicate)
+            }
+            DataType::Time32(Millisecond) => {
+                downcast_filter!(Time32MillisecondType, values, predicate)
+            }
+            DataType::Time64(Microsecond) => {
+                downcast_filter!(Time64MicrosecondType, values, predicate)
+            }
+            DataType::Time64(Nanosecond) => {
+                downcast_filter!(Time64NanosecondType, values, predicate)
+            }
+            DataType::Timestamp(Second, _) => {
+                downcast_filter!(TimestampSecondType, values, predicate)
+            }
+            DataType::Timestamp(Millisecond, _) => {
+                downcast_filter!(TimestampMillisecondType, values, predicate)
+            }
+            DataType::Timestamp(Microsecond, _) => {
+                downcast_filter!(TimestampMicrosecondType, values, predicate)
+            }
+            DataType::Timestamp(Nanosecond, _) => {
+                downcast_filter!(TimestampNanosecondType, values, predicate)
+            }
+            DataType::Interval(IntervalUnit::YearMonth) => {
+                downcast_filter!(IntervalYearMonthType, values, predicate)
+            }
+            DataType::Interval(IntervalUnit::DayTime) => {
+                downcast_filter!(IntervalDayTimeType, values, predicate)
+            }
+            DataType::Interval(IntervalUnit::MonthDayNano) => {
+                downcast_filter!(IntervalMonthDayNanoType, values, predicate)
+            }
+            DataType::Duration(TimeUnit::Second) => {
+                downcast_filter!(DurationSecondType, values, predicate)
+            }
+            DataType::Duration(TimeUnit::Millisecond) => {
+                downcast_filter!(DurationMillisecondType, values, predicate)
+            }
+            DataType::Duration(TimeUnit::Microsecond) => {
+                downcast_filter!(DurationMicrosecondType, values, predicate)
+            }
+            DataType::Duration(TimeUnit::Nanosecond) => {
+                downcast_filter!(DurationNanosecondType, values, predicate)
+            }
+            DataType::Utf8 => {
+                let values = values
+                    .as_any()
+                    .downcast_ref::<GenericStringArray<i32>>()
+                    .unwrap();
+                Ok(Arc::new(filter_string::<i32>(values, predicate)))
+            }
+            DataType::LargeUtf8 => {
+                let values = values
+                    .as_any()
+                    .downcast_ref::<GenericStringArray<i64>>()
+                    .unwrap();
+                Ok(Arc::new(filter_string::<i64>(values, predicate)))
+            }
+            DataType::Dictionary(key_type, _) => match key_type.as_ref() {
+                DataType::Int8 => downcast_dict_filter!(Int8Type, values, predicate),
+                DataType::Int16 => downcast_dict_filter!(Int16Type, values, predicate),
+                DataType::Int32 => downcast_dict_filter!(Int32Type, values, predicate),
+                DataType::Int64 => downcast_dict_filter!(Int64Type, values, predicate),
+                DataType::UInt8 => downcast_dict_filter!(UInt8Type, values, predicate),
+                DataType::UInt16 => downcast_dict_filter!(UInt16Type, values, predicate),
+                DataType::UInt32 => downcast_dict_filter!(UInt32Type, values, predicate),
+                DataType::UInt64 => downcast_dict_filter!(UInt64Type, values, predicate),
+                t => {
+                    unimplemented!("Filter not supported for dictionary key type {:?}", t)
+                }
+            },
+            _ => {
+                // fallback to using MutableArrayData
+                let mut mutable = MutableArrayData::new(
+                    vec![values.data_ref()],
+                    false,
+                    predicate.count,
+                );
+
+                match &predicate.strategy {
+                    IterationStrategy::Slices(slices) => {
+                        slices
+                            .iter()
+                            .for_each(|(start, end)| mutable.extend(0, *start, *end));
+                    }
+                    _ => {
+                        let iter = SlicesIterator::new(&predicate.filter);
+                        iter.for_each(|(start, end)| mutable.extend(0, start, end));
+                    }
+                }
+
+                let data = mutable.freeze();
+                Ok(make_array(data))
+            }
+        },
+    }
+}
+
+/// Computes a new null mask for `data` based on `predicate`
+///
+/// If the predicate selected no null rows, returns `None`, otherwise returns
+/// `Some((null_count, null_buffer))` where `null_count` is the number of nulls
+/// in the filtered output, and `null_buffer` is the filtered null buffer
+///
+fn filter_null_mask(
+    data: &ArrayData,
+    predicate: &FilterPredicate,
+) -> Option<(usize, Buffer)> {
+    if data.null_count() == 0 {
+        return None;
+    }
+
+    let nulls = filter_bits(data.null_buffer()?, data.offset(), predicate);
+    // The filtered `nulls` has a length of `predicate.count` bits and
+    // therefore the null count is this minus the number of valid bits
+    let null_count = predicate.count - nulls.count_set_bits();
+
+    if null_count == 0 {
+        return None;
+    }
+
+    Some((null_count, nulls))
+}
+
+/// Filter the packed bitmask `buffer`, with `predicate` starting at bit offset `offset`
+fn filter_bits(buffer: &Buffer, offset: usize, predicate: &FilterPredicate) -> Buffer {
+    let src = buffer.as_slice();
+
+    match &predicate.strategy {
+        IterationStrategy::IndexIterator => {
+            let bits = IndexIterator::new(&predicate.filter, predicate.count)
+                .map(|src_idx| bit_util::get_bit(src, src_idx + offset));
+
+            // SAFETY: `IndexIterator` reports its size correctly
+            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
         }
-        _ => {
-            let filter = build_filter(predicate)?;
-            record_batch
-                .columns()
+        IterationStrategy::Indices(indices) => {
+            let bits = indices
                 .iter()
-                .map(|a| make_array(filter(a.data())))
-                .collect()
+                .map(|src_idx| bit_util::get_bit(src, *src_idx + offset));
+
+            // SAFETY: `Vec::iter()` reports its size correctly
+            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
+        }
+        IterationStrategy::SlicesIterator => {
+            let mut builder =
+                BooleanBufferBuilder::new(bit_util::ceil(predicate.count, 8));
+            for (start, end) in SlicesIterator::new(&predicate.filter) {
+                builder.append_packed_range(start + offset..end + offset, src)
+            }
+            builder.finish()
+        }
+        IterationStrategy::Slices(slices) => {
+            let mut builder =
+                BooleanBufferBuilder::new(bit_util::ceil(predicate.count, 8));
+            for (start, end) in slices {
+                builder.append_packed_range(*start + offset..*end + offset, src)
+            }
+            builder.finish()
+        }
+        IterationStrategy::All | IterationStrategy::None => unreachable!(),
+    }
+}
+
+/// `filter` implementation for boolean buffers
+fn filter_boolean(values: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray {
+    let data = values.data();
+    assert_eq!(data.buffers().len(), 1);
+    assert_eq!(data.child_data().len(), 0);
+
+    let values = filter_bits(&data.buffers()[0], data.offset(), predicate);
+
+    let mut builder = ArrayDataBuilder::new(DataType::Boolean)
+        .len(predicate.count)
+        .add_buffer(values);
+
+    if let Some((null_count, nulls)) = filter_null_mask(data, predicate) {
+        builder = builder.null_count(null_count).null_bit_buffer(nulls);
+    }
+
+    let data = unsafe { builder.build_unchecked() };
+    BooleanArray::from(data)
+}
+
+/// `filter` implementation for primitive arrays
+fn filter_primitive<T>(
+    values: &PrimitiveArray<T>,
+    predicate: &FilterPredicate,
+) -> PrimitiveArray<T>
+where
+    T: ArrowPrimitiveType,
+{
+    let data = values.data();
+    assert_eq!(data.buffers().len(), 1);
+    assert_eq!(data.child_data().len(), 0);
+
+    let values = data.buffer::<T::Native>(0);
+    assert!(values.len() >= predicate.filter.len());
+
+    let buffer = match &predicate.strategy {
+        IterationStrategy::SlicesIterator => {
+            let mut buffer =
+                MutableBuffer::with_capacity(predicate.count * T::get_byte_width());
+            for (start, end) in SlicesIterator::new(&predicate.filter) {
+                buffer.extend_from_slice(&values[start..end]);
+            }
+            buffer
+        }
+        IterationStrategy::Slices(slices) => {
+            let mut buffer =
+                MutableBuffer::with_capacity(predicate.count * T::get_byte_width());
+            for (start, end) in slices {
+                buffer.extend_from_slice(&values[*start..*end]);
+            }
+            buffer
+        }
+        IterationStrategy::IndexIterator => {
+            let iter =
+                IndexIterator::new(&predicate.filter, predicate.count).map(|x| values[x]);
+
+            // SAFETY: IndexIterator is trusted length
+            unsafe { MutableBuffer::from_trusted_len_iter(iter) }
+        }
+        IterationStrategy::Indices(indices) => {
+            let iter = indices.iter().map(|x| values[*x]);
+
+            // SAFETY: `Vec::iter` is trusted length
+            unsafe { MutableBuffer::from_trusted_len_iter(iter) }
         }
+        IterationStrategy::All | IterationStrategy::None => unreachable!(),
     };
 
-    RecordBatch::try_new(record_batch.schema(), filtered_arrays)
+
+    let mut builder = ArrayDataBuilder::new(data.data_type().clone())
+        .len(predicate.count)
+        .add_buffer(buffer.into());
+
+    if let Some((null_count, nulls)) = filter_null_mask(data, predicate) {
+        builder = builder.null_count(null_count).null_bit_buffer(nulls);
+    }
+
+    let data = unsafe { builder.build_unchecked() };
+    PrimitiveArray::from(data)
+}
+
+/// [`FilterString`] is created from a source [`GenericStringArray`] and can be
+/// used to build a new [`GenericStringArray`] by copying values from the source
+///
+/// TODO(raphael): Could this be used for the take kernel as well?
+struct FilterString<'a, OffsetSize> {
+    src_offsets: &'a [OffsetSize],
+    src_values: &'a [u8],
+    dst_offsets: MutableBuffer,
+    dst_values: MutableBuffer,
+    cur_offset: OffsetSize,
+}
+
+impl<'a, OffsetSize> FilterString<'a, OffsetSize>
+where
+    OffsetSize: Zero + AddAssign + StringOffsetSizeTrait,
+{
+    fn new(capacity: usize, array: &'a GenericStringArray<OffsetSize>) -> Self {
+        let num_offsets_bytes = (capacity + 1) * std::mem::size_of::<OffsetSize>();
+        let mut dst_offsets = MutableBuffer::new(num_offsets_bytes);
+        let dst_values = MutableBuffer::new(0);
+        let cur_offset = OffsetSize::zero();
+        dst_offsets.push(cur_offset);
+
+        Self {
+            src_offsets: array.value_offsets(),
+            src_values: &array.data().buffers()[1],
+            dst_offsets,
+            dst_values,
+            cur_offset,
+        }
+    }
+
+    /// Returns the byte offset at `idx`
+    #[inline]
+    fn get_value_offset(&self, idx: usize) -> usize {
+        self.src_offsets[idx].to_usize().expect("illegal offset")
+    }
+
+    /// Returns the start and end of the value at index `idx` along with its length
+    #[inline]
+    fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) {
+        // These can only fail if `array` contains invalid data
+        let start = self.get_value_offset(idx);
+        let end = self.get_value_offset(idx + 1);
+        let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
+        (start, end, len)
+    }
+
+    /// Extends the in-progress array by the indexes in the provided iterator
+    fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) {
+        for idx in iter {
+            let (start, end, len) = self.get_value_range(idx);
+            self.cur_offset += len;
+            self.dst_offsets.push(self.cur_offset);
+            self.dst_values
+                .extend_from_slice(&self.src_values[start..end]);
+        }
+    }
+
+    /// Extends the in-progress array by the ranges in the provided iterator
+    fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) {
+        for (start, end) in iter {
+            // These can only fail if `array` contains invalid data
+            for idx in start..end {
+                let (_, _, len) = self.get_value_range(idx);
+                self.cur_offset += len;
+                self.dst_offsets.push(self.cur_offset); // push_unchecked?
+            }
+
+            let value_start = self.get_value_offset(start);
+            let value_end = self.get_value_offset(end);
+            self.dst_values
+                .extend_from_slice(&self.src_values[value_start..value_end]);
+        }
+    }
+}
+
+/// `filter` implementation for string arrays
+///
+/// Note: NULLs with a non-zero slot length in `array` will have the corresponding
+/// data copied across. This allows handling the null mask separately from the data
+fn filter_string<OffsetSize>(
+    array: &GenericStringArray<OffsetSize>,
+    predicate: &FilterPredicate,
+) -> GenericStringArray<OffsetSize>
+where
+    OffsetSize: Zero + AddAssign + StringOffsetSizeTrait,
+{
+    let data = array.data();
+    assert_eq!(data.buffers().len(), 2);
+    assert_eq!(data.child_data().len(), 0);
+    let mut filter = FilterString::new(predicate.count, array);
+
+    match &predicate.strategy {
+        IterationStrategy::SlicesIterator => {
+            filter.extend_slices(SlicesIterator::new(&predicate.filter))
+        }
+        IterationStrategy::Slices(slices) => filter.extend_slices(slices.iter().cloned()),
+        IterationStrategy::IndexIterator => {
+            filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count))
+        }
+        IterationStrategy::Indices(indices) => filter.extend_idx(indices.iter().cloned()),
+        IterationStrategy::All | IterationStrategy::None => unreachable!(),
+    }
+
+    let mut builder = ArrayDataBuilder::new(data.data_type().clone())
+        .len(predicate.count)
+        .add_buffer(filter.dst_offsets.into())
+        .add_buffer(filter.dst_values.into());
+
+    if let Some((null_count, nulls)) = filter_null_mask(data, predicate) {
+        builder = builder.null_count(null_count).null_bit_buffer(nulls);
+    }
+
+    let data = unsafe { builder.build_unchecked() };
+    GenericStringArray::from(data)
+}
+
+/// `filter` implementation for dictionaries
+fn filter_dict<T>(
+    array: &DictionaryArray<T>,
+    predicate: &FilterPredicate,
+) -> DictionaryArray<T>
+where
+    T: ArrowPrimitiveType,
+    T::Native: num::Num,
+{
+    let filtered_keys = filter_primitive::<T>(array.keys(), predicate);
+    let filtered_data = filtered_keys.data_ref();
+
+    let data = unsafe {
+        ArrayData::new_unchecked(
+            array.data_type().clone(),
+            filtered_data.len(),
+            Some(filtered_data.null_count()),
+            filtered_data.null_buffer().cloned(),
+            filtered_data.offset(),
+            filtered_data.buffers().to_vec(),
+            array.data().child_data().to_vec(),
+        )
+    };
+
+    DictionaryArray::<T>::from(data)
 }
 
 #[cfg(test)]
 mod tests {
-    use super::*;
+    use rand::distributions::{Alphanumeric, Standard};
+    use rand::prelude::*;
+
+    use crate::datatypes::Int64Type;
     use crate::{
         buffer::Buffer,
         datatypes::{DataType, Field},
     };
-    use rand::prelude::*;
+
+    use super::*;
 
     macro_rules! def_temporal_test {
         ($test:ident, $array_type: ident, $data: expr) => {
@@ -682,12 +1305,15 @@ mod tests {
             .build()
             .unwrap();
 
-        let bool_array = BooleanArray::from(data);
+        let filter = BooleanArray::from(data);
 
-        let bits: Vec<_> = SlicesIterator::new(&bool_array)
+        let slice_bits: Vec<_> = SlicesIterator::new(&filter)
             .flat_map(|(start, end)| start..end)
             .collect();
 
+        let count = filter_count(&filter);
+        let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect();
+
         let expected_bits: Vec<_> = bools
             .iter()
             .skip(offset)
@@ -696,7 +1322,8 @@ mod tests {
             .flat_map(|(idx, v)| v.then(|| idx))
             .collect();
 
-        assert_eq!(bits, expected_bits);
+        assert_eq!(slice_bits, expected_bits);
+        assert_eq!(index_bits, expected_bits);
     }
 
     #[test]
@@ -720,4 +1347,141 @@ mod tests {
         test_slices_fuzz(32, 8, 8);
         test_slices_fuzz(32, 5, 9);
     }
+
+    /// Filters `values` by `predicate` using standard rust iterators
+    fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> {
+        values
+            .into_iter()
+            .zip(predicate)
+            .filter(|(_, x)| **x)
+            .map(|(a, _)| a)
+            .collect()
+    }
+
+    /// Generates an array of length `len` with `valid_percent` non-null values
+    fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>>
+    where
+        Standard: Distribution<T>,
+    {
+        let mut rng = thread_rng();
+        (0..len)
+            .map(|_| rng.gen_bool(valid_percent).then(|| rng.gen()))
+            .collect()
+    }
+
+    /// Generates an array of length `len` with `valid_percent` non-null values
+    fn gen_strings(
+        len: usize,
+        valid_percent: f64,
+        str_len_range: std::ops::Range<usize>,
+    ) -> Vec<Option<String>> {
+        let mut rng = thread_rng();
+        (0..len)
+            .map(|_| {
+                rng.gen_bool(valid_percent).then(|| {
+                    let len = rng.gen_range(str_len_range.clone());
+                    (0..len)
+                        .map(|_| char::from(rng.sample(Alphanumeric)))
+                        .collect()
+                })
+            })
+            .collect()
+    }
+
+    /// Returns an iterator that calls `Option::as_deref` on each item
+    fn as_deref<T: std::ops::Deref>(
+        src: &[Option<T>],
+    ) -> impl Iterator<Item = Option<&T::Target>> {
+        src.iter().map(|x| x.as_deref())
+    }
+
+    #[test]
+    fn fuzz_filter() {
+        let mut rng = thread_rng();
+
+        for i in 0..100 {
+            let filter_percent = match i {
+                0..=4 => 1.,
+                5..=10 => 0.,
+                _ => rng.gen_range(0.0..1.0),
+            };
+
+            let valid_percent = rng.gen_range(0.0..1.0);
+
+            let array_len = rng.gen_range(32..256);
+            let array_offset = rng.gen_range(0..10);
+
+            // Construct a predicate
+            let filter_offset = rng.gen_range(0..10);
+            let filter_truncate = rng.gen_range(0..10);
+            let bools: Vec<_> = std::iter::from_fn(|| Some(rng.gen_bool(filter_percent)))
+                .take(array_len + filter_offset - filter_truncate)
+                .collect();
+
+            let predicate = BooleanArray::from_iter(bools.iter().cloned().map(Some));
+
+            // Offset predicate
+            let predicate = predicate.slice(filter_offset, array_len - filter_truncate);
+            let predicate = predicate.as_any().downcast_ref::<BooleanArray>().unwrap();
+            let bools = &bools[filter_offset..];
+
+            // Test i32
+            let values = gen_primitive(array_len + array_offset, valid_percent);
+            let src = Int32Array::from_iter(values.iter().cloned());
+
+            let src = src.slice(array_offset, array_len);
+            let src = src.as_any().downcast_ref::<Int32Array>().unwrap();
+            let values = &values[array_offset..];
+
+            let filtered = filter(src, predicate).unwrap();
+            let array = filtered.as_any().downcast_ref::<Int32Array>().unwrap();
+            let actual: Vec<_> = array.iter().collect();
+
+            assert_eq!(actual, filter_rust(values.iter().cloned(), bools));
+
+            // Test string
+            let strings = gen_strings(array_len + array_offset, valid_percent, 0..20);
+            let src = StringArray::from_iter(as_deref(&strings));
+
+            let src = src.slice(array_offset, array_len);
+            let src = src.as_any().downcast_ref::<StringArray>().unwrap();
+
+            let filtered = filter(src, predicate).unwrap();
+            let array = filtered.as_any().downcast_ref::<StringArray>().unwrap();
+            let actual: Vec<_> = array.iter().collect();
+
+            let expected_strings = filter_rust(as_deref(&strings[array_offset..]), bools);
+            assert_eq!(actual, expected_strings);
+
+            // Test string dictionary
+            let src = DictionaryArray::<Int64Type>::from_iter(as_deref(&strings));
+
+            let src = src.slice(array_offset, array_len);
+            let src = src
+                .as_any()
+                .downcast_ref::<DictionaryArray<Int64Type>>()
+                .unwrap();
+
+            let filtered = filter(src, predicate).unwrap();
+
+            let array = filtered
+                .as_any()
+                .downcast_ref::<DictionaryArray<Int64Type>>()
+                .unwrap();
+
+            let values = array
+                .values()
+                .as_any()
+                .downcast_ref::<StringArray>()
+                .unwrap();
+
+            let actual: Vec<_> = array
+                .keys()
+                .iter()
+                .map(|key| key.map(|key| values.value(key as usize)))
+                .collect();
+
+            assert_eq!(actual, expected_strings);
+        }
+    }
 }
diff --git a/arrow/src/util/bench_util.rs b/arrow/src/util/bench_util.rs
index 40340336882b..eeb906b8e075 100644
--- a/arrow/src/util/bench_util.rs
+++ b/arrow/src/util/bench_util.rs
@@ -110,6 +110,29 @@ pub fn create_string_array(
         .collect()
 }
 
+/// Creates a random (but fixed-seeded) dictionary array of a given size and null density
+/// consisting of random 4 character alphanumeric strings
+pub fn create_string_dict_array<K: ArrowDictionaryKeyType>(
+    size: usize,
+    null_density: f32,
+) -> DictionaryArray<K> {
+    let rng = &mut seedable_rng();
+
+    let data: Vec<_> = (0..size)
+        .map(|_| {
+            if rng.gen::<f32>() < null_density {
+                None
+            } else {
+                let value = rng.sample_iter(&Alphanumeric).take(4).collect();
+                let value = String::from_utf8(value).unwrap();
+                Some(value)
+            }
+        })
+        .collect();
+
+    data.iter().map(|x| x.as_deref()).collect()
+}
+
 /// Creates an random (but fixed-seeded) binary array of a given size and null density
 pub fn create_binary_array(
     size: usize,