Move validity filtering onto `Validity` and add a filter kernel for `StructArray`.

* The free function `filter_validity(validity, predicate)` becomes the method
  `Validity::filter(&self, predicate)`; the bool and primitive filter kernels
  are switched over to the method.
* `StructArray` gains a `FilterFn` impl: every child field and the validity are
  filtered with the same predicate, and the new length is the predicate's true
  count.

diff --git a/vortex-array/src/array/bool/compute/filter.rs b/vortex-array/src/array/bool/compute/filter.rs
index 1225bb645..91490e87d 100644
--- a/vortex-array/src/array/bool/compute/filter.rs
+++ b/vortex-array/src/array/bool/compute/filter.rs
@@ -3,7 +3,6 @@ use vortex_error::{vortex_err, VortexResult};
 
 use crate::array::BoolArray;
 use crate::compute::FilterFn;
-use crate::validity::filter_validity;
 use crate::variants::BoolArrayTrait;
 use crate::{Array, IntoArray};
 
@@ -15,7 +14,7 @@ impl FilterFn for BoolArray {
 
 fn filter_select_bool(arr: &BoolArray, predicate: &Array) -> VortexResult<BoolArray> {
     predicate.with_dyn(|b| {
-        let validity = filter_validity(arr.validity(), predicate)?;
+        let validity = arr.validity().filter(predicate)?;
         let predicate = b.as_bool_array().ok_or(vortex_err!(
             NotImplemented: "as_bool_array",
             predicate.encoding().id()
diff --git a/vortex-array/src/array/primitive/compute/filter.rs b/vortex-array/src/array/primitive/compute/filter.rs
index ba332a7c6..cc0cb12e9 100644
--- a/vortex-array/src/array/primitive/compute/filter.rs
+++ b/vortex-array/src/array/primitive/compute/filter.rs
@@ -3,7 +3,6 @@ use vortex_error::{vortex_err, VortexResult};
 
 use crate::array::primitive::PrimitiveArray;
 use crate::compute::FilterFn;
-use crate::validity::filter_validity;
 use crate::variants::BoolArrayTrait;
 use crate::{Array, IntoArray};
 
@@ -18,7 +17,7 @@ fn filter_select_primitive(
     predicate: &Array,
 ) -> VortexResult<PrimitiveArray> {
     predicate.with_dyn(|b| {
-        let validity = filter_validity(arr.validity(), predicate)?;
+        let validity = arr.validity().filter(predicate)?;
         let predicate = b.as_bool_array().ok_or_else(||vortex_err!(
             NotImplemented: "as_bool_array",
             predicate.encoding().id()
diff --git a/vortex-array/src/array/struct_/compute.rs b/vortex-array/src/array/struct_/compute.rs
index ad28679e6..e7fbcf5fb 100644
--- a/vortex-array/src/array/struct_/compute.rs
+++ b/vortex-array/src/array/struct_/compute.rs
@@ -1,10 +1,11 @@
 use itertools::Itertools;
-use vortex_error::VortexResult;
+use vortex_error::{vortex_err, VortexResult};
 use vortex_scalar::Scalar;
 
 use crate::array::struct_::StructArray;
 use crate::compute::unary::{scalar_at, scalar_at_unchecked, ScalarAtFn};
-use crate::compute::{slice, take, ArrayCompute, SliceFn, TakeFn};
+use crate::compute::{filter, slice, take, ArrayCompute, FilterFn, SliceFn, TakeFn};
+use crate::stats::ArrayStatistics;
 use crate::variants::StructArrayTrait;
 use crate::{Array, ArrayDType, IntoArray};
 
@@ -20,6 +21,10 @@ impl ArrayCompute for StructArray {
     fn take(&self) -> Option<&dyn TakeFn> {
         Some(self)
     }
+
+    fn filter(&self) -> Option<&dyn FilterFn> {
+        Some(self)
+    }
 }
 
 impl ScalarAtFn for StructArray {
@@ -71,3 +76,25 @@ impl SliceFn for StructArray {
         .map(|a| a.into_array())
     }
 }
+
+impl FilterFn for StructArray {
+    fn filter(&self, predicate: &Array) -> VortexResult<Array> {
+        let fields = self
+            .children()
+            .map(|field| filter(&field, predicate))
+            .try_collect()?;
+
+        let predicate_true_count = predicate
+            .statistics()
+            .compute_true_count()
+            .ok_or_else(|| vortex_err!("Predicate should always be a boolean array"))?;
+
+        Self::try_new(
+            self.names().clone(),
+            fields,
+            predicate_true_count,
+            self.validity().filter(predicate)?,
+        )
+        .map(|a| a.into_array())
+    }
+}
diff --git a/vortex-array/src/validity.rs b/vortex-array/src/validity.rs
index 0d11fe256..6cae93de6 100644
--- a/vortex-array/src/validity.rs
+++ b/vortex-array/src/validity.rs
@@ -120,6 +120,15 @@ impl Validity {
         }
     }
 
+    pub fn filter(&self, predicate: &Array) -> VortexResult<Validity> {
+        match self {
+            v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
+                Ok(v.clone())
+            }
+            Validity::Array(arr) => Ok(Validity::Array(filter(arr, predicate)?)),
+        }
+    }
+
     pub fn to_logical(&self, length: usize) -> LogicalValidity {
         match self {
             Self::NonNullable => LogicalValidity::AllValid(length),
@@ -328,10 +337,3 @@ impl IntoArray for LogicalValidity {
         }
     }
 }
-
-pub fn filter_validity(validity: Validity, predicate: &Array) -> VortexResult<Validity> {
-    match validity {
-        v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => Ok(v),
-        Validity::Array(arr) => Ok(Validity::Array(filter(&arr, predicate)?)),
-    }
-}