diff --git a/benches/filter_kernels.rs b/benches/filter_kernels.rs index 007453f9e15..4ff8b0428d3 100644 --- a/benches/filter_kernels.rs +++ b/benches/filter_kernels.rs @@ -16,9 +16,12 @@ // under the License. extern crate arrow2; +use std::sync::Arc; + use arrow2::array::*; -use arrow2::compute::filter::{build_filter, filter, Filter}; -use arrow2::datatypes::DataType; +use arrow2::compute::filter::{build_filter, filter, filter_record_batch, Filter}; +use arrow2::datatypes::{DataType, Field, Schema}; +use arrow2::record_batch::RecordBatch; use arrow2::util::bench_util::*; use criterion::{criterion_group, criterion_main, Criterion}; @@ -120,6 +123,17 @@ fn add_benchmark(c: &mut Criterion) { c.bench_function("filter context string low selectivity", |b| { b.iter(|| bench_built_filter(&sparse_filter, &data_array)) }); + + let data_array = create_primitive_array::(size, DataType::Float32, 0.0); + + let field = Field::new("c1", data_array.data_type().clone(), true); + let schema = Schema::new(vec![field]); + + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data_array)]).unwrap(); + + c.bench_function("filter single record batch", |b| { + b.iter(|| filter_record_batch(&batch, &filter_array)) + }); } criterion_group!(benches, add_benchmark); diff --git a/src/compute/filter.rs b/src/compute/filter.rs index cc3d837c7a7..9b975d8529c 100644 --- a/src/compute/filter.rs +++ b/src/compute/filter.rs @@ -254,14 +254,23 @@ pub fn filter(array: &dyn Array, filter: &BooleanArray) -> Result /// Therefore, it is considered undefined behavior to pass `filter` with null values. pub fn filter_record_batch( record_batch: &RecordBatch, - filter: &BooleanArray, + filter_values: &BooleanArray, ) -> Result { - let filter = build_filter(filter)?; - let filtered_arrays = record_batch - .columns() - .iter() - .map(|a| filter(a.as_ref()).into()) - .collect(); + let num_colums = record_batch.columns().len(); + + let filtered_arrays = match num_colums { + 1 => { + vec![filter(record_batch.columns()[0].as_ref(), filter_values)?.into()] + } + _ => { + let filter = build_filter(filter_values)?; + record_batch + .columns() + .iter() + .map(|a| filter(a.as_ref()).into()) + .collect() + } + }; RecordBatch::try_new(record_batch.schema().clone(), filtered_arrays) }