Skip to content

Commit

Permalink
Implement Take for UnionArray (apache#4883)
Browse files Browse the repository at this point in the history
Implement Take for UnionArray (apache#4883)
  • Loading branch information
avantgardnerio authored Oct 2, 2023
1 parent 39e4d94 commit 4320a75
Showing 1 changed file with 53 additions and 1 deletion.
54 changes: 53 additions & 1 deletion arrow-select/src/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use arrow_buffer::{
ScalarBuffer,
};
use arrow_data::{ArrayData, ArrayDataBuilder};
use arrow_schema::{ArrowError, DataType, FieldRef};
use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};

use num::{One, Zero};

Expand Down Expand Up @@ -223,6 +223,21 @@ fn take_impl<IndexType: ArrowPrimitiveType>(
Ok(new_null_array(&DataType::Null, indices.len()))
}
}
DataType::Union(fields, UnionMode::Sparse) => {
let mut field_type_ids = Vec::with_capacity(fields.len());
let mut children = Vec::with_capacity(fields.len());
let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
let type_ids = take_native(values.type_ids(), indices).into_inner();
for (type_id, field) in fields.iter() {
let values = values.child(type_id);
let values = take_impl(values, indices)?;
let field = (**field).clone();
children.push((field, values));
field_type_ids.push(type_id);
}
let array = UnionArray::try_new(field_type_ids.as_slice(), type_ids, None, children)?;
Ok(Arc::new(array))
}
t => unimplemented!("Take not supported for data type {:?}", t)
}
}
Expand Down Expand Up @@ -2013,4 +2028,41 @@ mod tests {
let values = r.as_string::<i32>().iter().collect::<Vec<_>>();
assert_eq!(&values, &[Some("foo"), None, None, None])
}

#[test]
fn test_take_union() {
let structs = create_test_struct(vec![
Some((Some(true), Some(42))),
Some((Some(false), Some(28))),
Some((Some(false), Some(19))),
Some((Some(true), Some(31))),
None,
]);
let strings =
StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]);
let type_ids = Buffer::from_slice_ref(vec![1i8; 5]);

let children: Vec<(Field, Arc<dyn Array>)> = vec![
(
Field::new("f1", structs.data_type().clone(), true),
Arc::new(structs),
),
(
Field::new("f2", strings.data_type().clone(), true),
Arc::new(strings),
),
];
let array = UnionArray::try_new(&[0, 1], type_ids, None, children).unwrap();

let indices = vec![0, 3, 1, 0, 2, 4];
let index = UInt32Array::from(indices.clone());
let actual = take(&array, &index, None).unwrap();
let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
let strings = actual.child(1);
let strings = strings.as_any().downcast_ref::<StringArray>().unwrap();

let actual = strings.iter().collect::<Vec<_>>();
let expected = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")];
assert_eq!(expected, actual);
}
}

0 comments on commit 4320a75

Please sign in to comment.