Skip to content

Commit

Permalink
Add try_filter_leaves to propagate error from filter closure (#5575)
Browse files Browse the repository at this point in the history
* Propagate error from filter closure

* Add try_filter_leaves instead
  • Loading branch information
viirya authored Apr 3, 2024
1 parent 77a3132 commit ec85145
Showing 1 changed file with 78 additions and 19 deletions.
97 changes: 78 additions & 19 deletions arrow-schema/src/fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,22 @@ impl Fields {
/// assert_eq!(filtered, expected);
/// ```
pub fn filter_leaves<F: FnMut(usize, &FieldRef) -> bool>(&self, mut filter: F) -> Self {
fn filter_field<F: FnMut(&FieldRef) -> bool>(
self.try_filter_leaves(|idx, field| Ok(filter(idx, field)))
.unwrap()
}

/// Returns a copy of this [`Fields`] containing only those [`FieldRef`] passing a predicate
/// or an error if the predicate fails.
///
/// See [`Fields::filter_leaves`] for more information.
pub fn try_filter_leaves<F: FnMut(usize, &FieldRef) -> Result<bool, ArrowError>>(
&self,
mut filter: F,
) -> Result<Self, ArrowError> {
fn filter_field<F: FnMut(&FieldRef) -> Result<bool, ArrowError>>(
f: &FieldRef,
filter: &mut F,
) -> Option<FieldRef> {
) -> Result<Option<FieldRef>, ArrowError> {
use DataType::*;

let v = match f.data_type() {
Expand All @@ -149,35 +161,72 @@ impl Fields {
d => d,
};
let d = match v {
List(child) => List(filter_field(child, filter)?),
LargeList(child) => LargeList(filter_field(child, filter)?),
Map(child, ordered) => Map(filter_field(child, filter)?, *ordered),
FixedSizeList(child, size) => FixedSizeList(filter_field(child, filter)?, *size),
List(child) => {
let fields = filter_field(child, filter)?;
if let Some(fields) = fields {
List(fields)
} else {
return Ok(None);
}
}
LargeList(child) => {
let fields = filter_field(child, filter)?;
if let Some(fields) = fields {
LargeList(fields)
} else {
return Ok(None);
}
}
Map(child, ordered) => {
let fields = filter_field(child, filter)?;
if let Some(fields) = fields {
Map(fields, *ordered)
} else {
return Ok(None);
}
}
FixedSizeList(child, size) => {
let fields = filter_field(child, filter)?;
if let Some(fields) = fields {
FixedSizeList(fields, *size)
} else {
return Ok(None);
}
}
Struct(fields) => {
let filtered: Fields = fields
let filtered: Result<Vec<_>, _> =
fields.iter().map(|f| filter_field(f, filter)).collect();
let filtered: Fields = filtered?
.iter()
.filter_map(|f| filter_field(f, filter))
.filter_map(|f| f.as_ref().cloned())
.collect();

if filtered.is_empty() {
return None;
return Ok(None);
}

Struct(filtered)
}
Union(fields, mode) => {
let filtered: UnionFields = fields
let filtered: Result<Vec<_>, _> = fields
.iter()
.map(|(id, f)| filter_field(f, filter).map(|f| f.map(|f| (id, f))))
.collect();
let filtered: UnionFields = filtered?
.iter()
.filter_map(|(id, f)| Some((id, filter_field(f, filter)?)))
.filter_map(|f| f.as_ref().cloned())
.collect();

if filtered.is_empty() {
return None;
return Ok(None);
}

Union(filtered, *mode)
}
_ => return filter(f).then(|| f.clone()),
_ => {
let filtered = filter(f)?;
return Ok(filtered.then(|| f.clone()));
}
};
let d = match f.data_type() {
Dictionary(k, _) => Dictionary(k.clone(), Box::new(d)),
Expand All @@ -186,20 +235,26 @@ impl Fields {
}
_ => d,
};
Some(Arc::new(f.as_ref().clone().with_data_type(d)))
Ok(Some(Arc::new(f.as_ref().clone().with_data_type(d))))
}

let mut leaf_idx = 0;
let mut filter = |f: &FieldRef| {
let t = filter(leaf_idx, f);
let t = filter(leaf_idx, f)?;
leaf_idx += 1;
t
Ok(t)
};

self.0
let filtered: Result<Vec<_>, _> = self
.0
.iter()
.filter_map(|f| filter_field(f, &mut filter))
.collect()
.map(|f| filter_field(f, &mut filter))
.collect();
let filtered = filtered?
.iter()
.filter_map(|f| f.as_ref().cloned())
.collect();
Ok(filtered)
}

/// Remove a field by index and return it.
Expand Down Expand Up @@ -531,5 +586,9 @@ mod tests {
let r = fields.filter_leaves(|idx, _| idx == 14 || idx == 15);
assert_eq!(r.len(), 1);
assert_eq!(r[0], fields[9]);

// Propagate error
let r = fields.try_filter_leaves(|_, _| Err(ArrowError::SchemaError("error".to_string())));
assert!(r.is_err());
}
}

0 comments on commit ec85145

Please sign in to comment.