Add Schema::project and RecordBatch::project functions #1033

Merged: 4 commits, Dec 20, 2021
Changes from 3 commits
63 changes: 62 additions & 1 deletion arrow/src/datatypes/schema.rs
@@ -87,6 +87,23 @@ impl Schema {
Self { fields, metadata }
}

/// Returns a new schema containing only the fields at the given indices, in the order given.
/// Metadata from the parent schema is carried over unchanged.
pub fn project(&self, indices: &[usize]) -> Result<Schema> {
let new_fields = indices
.iter()
.map(|i| {
self.fields.get(*i).cloned().ok_or_else(|| {
ArrowError::SchemaError(format!(
"project index {} out of bounds, max field {}",
i,
self.fields().len()
))
})
})
.collect::<Result<Vec<_>>>()?;
Ok(Self::new_with_metadata(new_fields, self.metadata.clone()))
}

/// Merge schema into self if it is compatible. Struct fields will be merged recursively.
///
/// Example:
@@ -115,7 +132,7 @@ impl Schema {
/// ]),
/// );
/// ```
pub fn try_merge(schemas: impl IntoIterator<Item = Self>) -> Result<Self> {
pub fn try_merge(schemas: impl IntoIterator<Item=Self>) -> Result<Self> {
schemas
.into_iter()
.try_fold(Self::empty(), |mut merged, schema| {
@@ -369,4 +386,48 @@ mod tests {

assert_eq!(schema, de_schema);
}

#[test]
fn test_projection() {
let mut metadata = HashMap::new();
metadata.insert("meta".to_string(), "data".to_string());

let schema = Schema::new_with_metadata(
vec![
Field::new("name", DataType::Utf8, false),
Field::new("address", DataType::Utf8, false),
Field::new("priority", DataType::UInt8, false),
],
metadata,
);

let projected: Schema = schema.project(&[0, 2]).unwrap();

assert_eq!(projected.fields().len(), 2);
assert_eq!(projected.fields()[0].name(), "name");
assert_eq!(projected.fields()[1].name(), "priority");
assert_eq!(projected.metadata.get("meta").unwrap(), "data");
}
Contributor

Related to the above: I recommend a test for the case where an index is out of bounds, like schema.project(&[2, 3])

Contributor Author

Sure, will do


#[test]
fn test_oob_projection() {
let mut metadata = HashMap::new();
metadata.insert("meta".to_string(), "data".to_string());

let schema = Schema::new_with_metadata(
vec![
Field::new("name", DataType::Utf8, false),
Field::new("address", DataType::Utf8, false),
Field::new("priority", DataType::UInt8, false),
],
metadata,
);

let projected: Result<Schema> = schema.project(&[0, 3]);

assert!(projected.is_err());
if let Err(e) = projected {
assert_eq!(
e.to_string(),
"Schema error: project index 3 out of bounds, max field 3"
);
}
}
}
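
For readers following the schema.rs change without the arrow crate at hand, here is a minimal, std-only sketch of the same projection logic. `Field`, `Schema`, and `sample_schema` below are simplified stand-ins written for illustration, not arrow's types, and the error is a plain `String` rather than `ArrowError`. It shows that fields come back in the order the indices are given, that metadata is cloned from the parent, and that the first out-of-range index aborts the whole projection. Note that the "max field" value in the message is the number of fields, not the largest valid index.

```rust
use std::collections::HashMap;

// Hypothetical, simplified stand-ins for arrow's Field and Schema.
#[derive(Debug, Clone, PartialEq)]
struct Field {
    name: String,
    nullable: bool,
}

#[derive(Debug, Clone, PartialEq)]
struct Schema {
    fields: Vec<Field>,
    metadata: HashMap<String, String>,
}

impl Schema {
    /// Picks fields by index, in the order given, following the logic in the diff.
    /// Returns an error for the first index that is past the end.
    fn project(&self, indices: &[usize]) -> Result<Schema, String> {
        let fields = indices
            .iter()
            .map(|&i| {
                self.fields.get(i).cloned().ok_or_else(|| {
                    format!(
                        "project index {} out of bounds, max field {}",
                        i,
                        self.fields.len()
                    )
                })
            })
            .collect::<Result<Vec<_>, _>>()?;
        // Metadata is carried over from the parent schema unchanged.
        Ok(Schema {
            fields,
            metadata: self.metadata.clone(),
        })
    }
}

/// Builds a three-field schema shaped like the one used in the tests above.
fn sample_schema() -> Schema {
    let mut metadata = HashMap::new();
    metadata.insert("meta".to_string(), "data".to_string());
    let fields = ["name", "address", "priority"]
        .iter()
        .map(|n| Field {
            name: n.to_string(),
            nullable: false,
        })
        .collect();
    Schema { fields, metadata }
}
```

In this sketch, `sample_schema().project(&[2, 0])` yields the fields `priority` then `name`, and `project(&[0, 3])` returns the out-of-bounds error.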
38 changes: 38 additions & 0 deletions arrow/src/record_batch.rs
@@ -175,6 +175,25 @@ impl RecordBatch {
self.schema.clone()
}

/// Returns a new record batch containing only the columns at the given indices;
/// the schema is projected to match.
pub fn project(&self, indices: &[usize]) -> Result<RecordBatch> {
let projected_schema = self.schema.project(indices)?;
let batch_fields = indices
.iter()
.map(|f| {
self.columns.get(*f).cloned().ok_or_else(|| {
ArrowError::SchemaError(format!(
"project index {} out of bounds, max field {}",
f,
self.columns.len()
))
})
})
.collect::<Result<Vec<_>>>()?;

RecordBatch::try_new(SchemaRef::new(projected_schema), batch_fields)
}

Contributor

How about some tests?

Perhaps something like

    #[test]
    fn project() {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(1),
            None,
            Some(3),
        ]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
        let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));

        let record_batch = RecordBatch::try_from_iter(vec![("a", a.clone()), ("b", b.clone()), ("c", c.clone())])
            .expect("valid conversion");

        let expected = RecordBatch::try_from_iter(vec![("a", a), ("c", c)])
            .expect("valid conversion");

        assert_eq!(expected, record_batch.project(&vec![0, 2]).unwrap());
    }

/// Returns the number of columns in the record batch.
///
/// # Example
@@ -900,4 +919,23 @@ mod tests {

assert_ne!(batch1, batch2);
}

#[test]
fn project() {
let a: ArrayRef = Arc::new(Int32Array::from(vec![
Some(1),
None,
Some(3),
]));
let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));

let record_batch = RecordBatch::try_from_iter(vec![("a", a.clone()), ("b", b.clone()), ("c", c.clone())])
.expect("valid conversion");

let expected = RecordBatch::try_from_iter(vec![("a", a), ("c", c)])
.expect("valid conversion");

assert_eq!(expected, record_batch.project(&vec![0, 2]).unwrap());
}
}
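
As a companion to the record_batch.rs change, the following std-only sketch shows the idea of projecting column names and column data with the same index list so the two stay aligned. `Batch` and `sample_batch` are hypothetical stand-ins, not arrow's `RecordBatch`: columns here are plain `Vec<i32>` values that are copied on projection, whereas arrow's `ArrayRef` is an `Arc`, so cloning a column there only bumps a reference count.

```rust
// Hypothetical stand-in for a record batch: column names plus their data.
#[derive(Debug, Clone, PartialEq)]
struct Batch {
    names: Vec<String>,
    columns: Vec<Vec<i32>>,
}

impl Batch {
    /// Selects names and columns by the same indices, in the order given.
    /// Fails on the first index that is out of range.
    fn project(&self, indices: &[usize]) -> Result<Batch, String> {
        let mut names = Vec::with_capacity(indices.len());
        let mut columns = Vec::with_capacity(indices.len());
        for &i in indices {
            match (self.names.get(i), self.columns.get(i)) {
                (Some(name), Some(column)) => {
                    names.push(name.clone());
                    columns.push(column.clone());
                }
                _ => {
                    return Err(format!(
                        "project index {} out of bounds, max field {}",
                        i,
                        self.columns.len()
                    ))
                }
            }
        }
        Ok(Batch { names, columns })
    }
}

/// Builds a three-column batch named "a", "b", "c".
fn sample_batch() -> Batch {
    Batch {
        names: vec!["a".to_string(), "b".to_string(), "c".to_string()],
        columns: vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]],
    }
}
```

With this sketch, `sample_batch().project(&[0, 2])` keeps columns "a" and "c" and drops "b", which is the shape of the `project` test in the diff.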