From ec85145f4bf486851013faf7f3af9a871e5a9d59 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 3 Apr 2024 01:37:34 -0700 Subject: [PATCH] Add `try_filter_leaves` to propagate error from filter closure (#5575) * Propagate error from filter closure * Add try_filter_leaves instead --- arrow-schema/src/fields.rs | 97 ++++++++++++++++++++++++++++++-------- 1 file changed, 78 insertions(+), 19 deletions(-) diff --git a/arrow-schema/src/fields.rs b/arrow-schema/src/fields.rs index 59b7e76c7823..5a1a6c84c256 100644 --- a/arrow-schema/src/fields.rs +++ b/arrow-schema/src/fields.rs @@ -137,10 +137,22 @@ impl Fields { /// assert_eq!(filtered, expected); /// ``` pub fn filter_leaves<F: FnMut(usize, &FieldRef) -> bool>(&self, mut filter: F) -> Self { - fn filter_field<F: FnMut(&FieldRef) -> bool>( + self.try_filter_leaves(|idx, field| Ok(filter(idx, field))) + .unwrap() + } + + /// Returns a copy of this [`Fields`] containing only those [`FieldRef`] passing a predicate + /// or an error if the predicate fails. + /// + /// See [`Fields::filter_leaves`] for more information.
+ pub fn try_filter_leaves<F: FnMut(usize, &FieldRef) -> Result<bool, ArrowError>>( + &self, + mut filter: F, + ) -> Result<Self, ArrowError> { + fn filter_field<F: FnMut(&FieldRef) -> Result<bool, ArrowError>>( f: &FieldRef, filter: &mut F, - ) -> Option<FieldRef> { + ) -> Result<Option<FieldRef>, ArrowError> { use DataType::*; let v = match f.data_type() { @@ -149,35 +161,72 @@ impl Fields { d => d, }; let d = match v { - List(child) => List(filter_field(child, filter)?), - LargeList(child) => LargeList(filter_field(child, filter)?), - Map(child, ordered) => Map(filter_field(child, filter)?, *ordered), - FixedSizeList(child, size) => FixedSizeList(filter_field(child, filter)?, *size), + List(child) => { + let fields = filter_field(child, filter)?; + if let Some(fields) = fields { + List(fields) + } else { + return Ok(None); + } + } + LargeList(child) => { + let fields = filter_field(child, filter)?; + if let Some(fields) = fields { + LargeList(fields) + } else { + return Ok(None); + } + } + Map(child, ordered) => { + let fields = filter_field(child, filter)?; + if let Some(fields) = fields { + Map(fields, *ordered) + } else { + return Ok(None); + } + } + FixedSizeList(child, size) => { + let fields = filter_field(child, filter)?; + if let Some(fields) = fields { + FixedSizeList(fields, *size) + } else { + return Ok(None); + } + } Struct(fields) => { - let filtered: Fields = fields + let filtered: Result<Vec<_>, _> = + fields.iter().map(|f| filter_field(f, filter)).collect(); + let filtered: Fields = filtered? .iter() - .filter_map(|f| filter_field(f, filter)) + .filter_map(|f| f.as_ref().cloned()) .collect(); if filtered.is_empty() { - return None; + return Ok(None); } Struct(filtered) } Union(fields, mode) => { - let filtered: UnionFields = fields + let filtered: Result<Vec<_>, _> = fields + .iter() + .map(|(id, f)| filter_field(f, filter).map(|f| f.map(|f| (id, f)))) + .collect(); + let filtered: UnionFields = filtered?
.iter() - .filter_map(|(id, f)| Some((id, filter_field(f, filter)?))) + .filter_map(|f| f.as_ref().cloned()) .collect(); if filtered.is_empty() { - return None; + return Ok(None); } Union(filtered, *mode) } - _ => return filter(f).then(|| f.clone()), + _ => { + let filtered = filter(f)?; + return Ok(filtered.then(|| f.clone())); + } }; let d = match f.data_type() { Dictionary(k, _) => Dictionary(k.clone(), Box::new(d)), @@ -186,20 +235,26 @@ impl Fields { } _ => d, }; - Some(Arc::new(f.as_ref().clone().with_data_type(d))) + Ok(Some(Arc::new(f.as_ref().clone().with_data_type(d)))) } let mut leaf_idx = 0; let mut filter = |f: &FieldRef| { - let t = filter(leaf_idx, f); + let t = filter(leaf_idx, f)?; leaf_idx += 1; - t + Ok(t) }; - self.0 + let filtered: Result<Vec<_>, _> = self + .0 .iter() - .filter_map(|f| filter_field(f, &mut filter)) - .collect() + .map(|f| filter_field(f, &mut filter)) + .collect(); + let filtered = filtered? + .iter() + .filter_map(|f| f.as_ref().cloned()) + .collect(); + Ok(filtered) } /// Remove a field by index and return it. @@ -531,5 +586,9 @@ mod tests { let r = fields.filter_leaves(|idx, _| idx == 14 || idx == 15); assert_eq!(r.len(), 1); assert_eq!(r[0], fields[9]); + + // Propagate error + let r = fields.try_filter_leaves(|_, _| Err(ArrowError::SchemaError("error".to_string()))); + assert!(r.is_err()); } }