Skip to content

Commit

Permalink
Add UnionFields (apache#3955) (apache#3981)
Browse files Browse the repository at this point in the history
* Add UnionFields (apache#3955)

* Fix array_cast

* Review feedback

* Clippy
  • Loading branch information
tustvold authored Mar 30, 2023
1 parent a9ac325 commit e5a1676
Show file tree
Hide file tree
Showing 27 changed files with 430 additions and 289 deletions.
16 changes: 9 additions & 7 deletions arrow-array/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ pub fn make_array(data: ArrayData) -> ArrayRef {
DataType::LargeList(_) => Arc::new(LargeListArray::from(data)) as ArrayRef,
DataType::Struct(_) => Arc::new(StructArray::from(data)) as ArrayRef,
DataType::Map(_, _) => Arc::new(MapArray::from(data)) as ArrayRef,
DataType::Union(_, _, _) => Arc::new(UnionArray::from(data)) as ArrayRef,
DataType::Union(_, _) => Arc::new(UnionArray::from(data)) as ArrayRef,
DataType::FixedSizeList(_, _) => {
Arc::new(FixedSizeListArray::from(data)) as ArrayRef
}
Expand Down Expand Up @@ -740,7 +740,7 @@ mod tests {
use crate::cast::{as_union_array, downcast_array};
use crate::downcast_run_array;
use arrow_buffer::{Buffer, MutableBuffer};
use arrow_schema::{Field, Fields, UnionMode};
use arrow_schema::{Field, Fields, UnionFields, UnionMode};

#[test]
fn test_empty_primitive() {
Expand Down Expand Up @@ -874,11 +874,13 @@ mod tests {
fn test_null_union() {
for mode in [UnionMode::Sparse, UnionMode::Dense] {
let data_type = DataType::Union(
vec![
Field::new("foo", DataType::Int32, true),
Field::new("bar", DataType::Int64, true),
],
vec![2, 1],
UnionFields::new(
vec![2, 1],
vec![
Field::new("foo", DataType::Int32, true),
Field::new("bar", DataType::Int64, true),
],
),
mode,
);
let array = new_null_array(&data_type, 4);
Expand Down
53 changes: 26 additions & 27 deletions arrow-array/src/array/union_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::{make_array, Array, ArrayRef};
use arrow_buffer::buffer::NullBuffer;
use arrow_buffer::{Buffer, ScalarBuffer};
use arrow_data::ArrayData;
use arrow_schema::{ArrowError, DataType, Field, UnionMode};
use arrow_schema::{ArrowError, DataType, Field, UnionFields, UnionMode};
/// Contains the `UnionArray` type.
///
use std::any::Any;
Expand Down Expand Up @@ -145,8 +145,7 @@ impl UnionArray {
value_offsets: Option<Buffer>,
child_arrays: Vec<(Field, ArrayRef)>,
) -> Self {
let (field_types, field_values): (Vec<_>, Vec<_>) =
child_arrays.into_iter().unzip();
let (fields, field_values): (Vec<_>, Vec<_>) = child_arrays.into_iter().unzip();
let len = type_ids.len();

let mode = if value_offsets.is_some() {
Expand All @@ -156,8 +155,7 @@ impl UnionArray {
};

let builder = ArrayData::builder(DataType::Union(
field_types,
Vec::from(field_type_ids),
UnionFields::new(field_type_ids.iter().copied(), fields),
mode,
))
.add_buffer(type_ids)
Expand Down Expand Up @@ -282,9 +280,9 @@ impl UnionArray {
/// Returns the names of the types in the union.
pub fn type_names(&self) -> Vec<&str> {
match self.data.data_type() {
DataType::Union(fields, _, _) => fields
DataType::Union(fields, _) => fields
.iter()
.map(|f| f.name().as_str())
.map(|(_, f)| f.name().as_str())
.collect::<Vec<&str>>(),
_ => unreachable!("Union array's data type is not a union!"),
}
Expand All @@ -293,7 +291,7 @@ impl UnionArray {
/// Returns whether the `UnionArray` is dense (or sparse if `false`).
fn is_dense(&self) -> bool {
match self.data.data_type() {
DataType::Union(_, _, mode) => mode == &UnionMode::Dense,
DataType::Union(_, mode) => mode == &UnionMode::Dense,
_ => unreachable!("Union array's data type is not a union!"),
}
}
Expand All @@ -307,8 +305,8 @@ impl UnionArray {

impl From<ArrayData> for UnionArray {
fn from(data: ArrayData) -> Self {
let (field_ids, mode) = match data.data_type() {
DataType::Union(_, ids, mode) => (ids, *mode),
let (fields, mode) = match data.data_type() {
DataType::Union(fields, mode) => (fields, *mode),
d => panic!("UnionArray expected ArrayData with type Union got {d}"),
};
let (type_ids, offsets) = match mode {
Expand All @@ -326,10 +324,10 @@ impl From<ArrayData> for UnionArray {
),
};

let max_id = field_ids.iter().copied().max().unwrap_or_default() as usize;
let max_id = fields.iter().map(|(i, _)| i).max().unwrap_or_default() as usize;
let mut boxed_fields = vec![None; max_id + 1];
for (cd, field_id) in data.child_data().iter().zip(field_ids) {
boxed_fields[*field_id as usize] = Some(make_array(cd.clone()));
for (cd, (field_id, _)) in data.child_data().iter().zip(fields.iter()) {
boxed_fields[field_id as usize] = Some(make_array(cd.clone()));
}
Self {
data,
Expand Down Expand Up @@ -402,19 +400,18 @@ impl std::fmt::Debug for UnionArray {
writeln!(f, "-- type id buffer:")?;
writeln!(f, "{:?}", self.type_ids)?;

let (fields, ids) = match self.data_type() {
DataType::Union(f, ids, _) => (f, ids),
_ => unreachable!(),
};

if let Some(offsets) = &self.offsets {
writeln!(f, "-- offsets buffer:")?;
writeln!(f, "{:?}", offsets)?;
}

assert_eq!(fields.len(), ids.len());
for (field, type_id) in fields.iter().zip(ids) {
let child = self.child(*type_id);
let fields = match self.data_type() {
DataType::Union(fields, _) => fields,
_ => unreachable!(),
};

for (type_id, field) in fields.iter() {
let child = self.child(type_id);
writeln!(
f,
"-- child {}: \"{}\" ({:?})",
Expand Down Expand Up @@ -1058,12 +1055,14 @@ mod tests {
#[test]
fn test_custom_type_ids() {
let data_type = DataType::Union(
vec![
Field::new("strings", DataType::Utf8, false),
Field::new("integers", DataType::Int32, false),
Field::new("floats", DataType::Float64, false),
],
vec![8, 4, 9],
UnionFields::new(
vec![8, 4, 9],
vec![
Field::new("strings", DataType::Utf8, false),
Field::new("integers", DataType::Int32, false),
Field::new("floats", DataType::Float64, false),
],
),
UnionMode::Dense,
);

Expand Down
2 changes: 1 addition & 1 deletion arrow-array/src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ mod tests {
let record_batch =
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])
.unwrap();
assert_eq!(record_batch.get_array_memory_size(), 628);
assert_eq!(record_batch.get_array_memory_size(), 564);
}

fn check_batch(record_batch: RecordBatch, num_rows: usize) {
Expand Down
14 changes: 7 additions & 7 deletions arrow-cast/src/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ fn make_formatter<'a>(
}
DataType::Struct(_) => array_format(as_struct_array(array), options),
DataType::Map(_, _) => array_format(as_map_array(array), options),
DataType::Union(_, _, _) => array_format(as_union_array(array), options),
DataType::Union(_, _) => array_format(as_union_array(array), options),
d => Err(ArrowError::NotYetImplemented(format!("formatting {d} is not yet supported"))),
}
}
Expand Down Expand Up @@ -801,16 +801,16 @@ impl<'a> DisplayIndexState<'a> for &'a UnionArray {
);

fn prepare(&self, options: &FormatOptions<'a>) -> Result<Self::State, ArrowError> {
let (fields, type_ids, mode) = match (*self).data_type() {
DataType::Union(fields, type_ids, mode) => (fields, type_ids, mode),
let (fields, mode) = match (*self).data_type() {
DataType::Union(fields, mode) => (fields, mode),
_ => unreachable!(),
};

let max_id = type_ids.iter().copied().max().unwrap_or_default() as usize;
let max_id = fields.iter().map(|(id, _)| id).max().unwrap_or_default() as usize;
let mut out: Vec<Option<FieldDisplay>> = (0..max_id + 1).map(|_| None).collect();
for (i, field) in type_ids.iter().zip(fields) {
let formatter = make_formatter(self.child(*i).as_ref(), options)?;
out[*i as usize] = Some((field.name().as_str(), formatter))
for (i, field) in fields.iter() {
let formatter = make_formatter(self.child(i).as_ref(), options)?;
out[i as usize] = Some((field.name().as_str(), formatter))
}
Ok((out, *mode))
}
Expand Down
42 changes: 25 additions & 17 deletions arrow-cast/src/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -703,11 +703,13 @@ mod tests {
let schema = Schema::new(vec![Field::new(
"Teamsters",
DataType::Union(
vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Float64, false),
],
vec![0, 1],
UnionFields::new(
vec![0, 1],
vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Float64, false),
],
),
UnionMode::Dense,
),
false,
Expand Down Expand Up @@ -743,11 +745,13 @@ mod tests {
let schema = Schema::new(vec![Field::new(
"Teamsters",
DataType::Union(
vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Float64, false),
],
vec![0, 1],
UnionFields::new(
vec![0, 1],
vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Float64, false),
],
),
UnionMode::Sparse,
),
false,
Expand Down Expand Up @@ -785,11 +789,13 @@ mod tests {
let inner_field = Field::new(
"European Union",
DataType::Union(
vec![
Field::new("b", DataType::Int32, false),
Field::new("c", DataType::Float64, false),
],
vec![0, 1],
UnionFields::new(
vec![0, 1],
vec![
Field::new("b", DataType::Int32, false),
Field::new("c", DataType::Float64, false),
],
),
UnionMode::Dense,
),
false,
Expand All @@ -809,8 +815,10 @@ mod tests {
let schema = Schema::new(vec![Field::new(
"Teamsters",
DataType::Union(
vec![Field::new("a", DataType::Int32, true), inner_field],
vec![0, 1],
UnionFields::new(
vec![0, 1],
vec![Field::new("a", DataType::Int32, true), inner_field],
),
UnionMode::Sparse,
),
false,
Expand Down
25 changes: 13 additions & 12 deletions arrow-data/src/data/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff
MutableBuffer::new(capacity * mem::size_of::<u8>()),
empty_buffer,
],
DataType::Union(_, _, mode) => {
DataType::Union(_, mode) => {
let type_ids = MutableBuffer::new(capacity * mem::size_of::<i8>());
match mode {
UnionMode::Sparse => [type_ids, empty_buffer],
Expand All @@ -162,7 +162,7 @@ pub(crate) fn into_buffers(
| DataType::Binary
| DataType::LargeUtf8
| DataType::LargeBinary => vec![buffer1.into(), buffer2.into()],
DataType::Union(_, _, mode) => {
DataType::Union(_, mode) => {
match mode {
// Based on Union's DataTypeLayout
UnionMode::Sparse => vec![buffer1.into()],
Expand Down Expand Up @@ -621,8 +621,9 @@ impl ArrayData {
vec![ArrayData::new_empty(v.as_ref())],
true,
),
DataType::Union(f, i, mode) => {
let ids = Buffer::from_iter(std::iter::repeat(i[0]).take(len));
DataType::Union(f, mode) => {
let (id, _) = f.iter().next().unwrap();
let ids = Buffer::from_iter(std::iter::repeat(id).take(len));
let buffers = match mode {
UnionMode::Sparse => vec![ids],
UnionMode::Dense => {
Expand All @@ -634,7 +635,7 @@ impl ArrayData {
let children = f
.iter()
.enumerate()
.map(|(idx, f)| match idx {
.map(|(idx, (_, f))| match idx {
0 => Self::new_null(f.data_type(), len),
_ => Self::new_empty(f.data_type()),
})
Expand Down Expand Up @@ -986,10 +987,10 @@ impl ArrayData {
}
Ok(())
}
DataType::Union(fields, _, mode) => {
DataType::Union(fields, mode) => {
self.validate_num_child_data(fields.len())?;

for (i, field) in fields.iter().enumerate() {
for (i, (_, field)) in fields.iter().enumerate() {
let field_data = self.get_valid_child_data(i, field.data_type())?;

if mode == &UnionMode::Sparse
Expand Down Expand Up @@ -1255,7 +1256,7 @@ impl ArrayData {
let child = &self.child_data[0];
self.validate_offsets_full::<i64>(child.len)
}
DataType::Union(_, _, _) => {
DataType::Union(_, _) => {
// Validate Union Array as part of implementing new Union semantics
// See comments in `ArrayData::validate()`
// https://github.com/apache/arrow-rs/issues/85
Expand Down Expand Up @@ -1568,7 +1569,7 @@ pub fn layout(data_type: &DataType) -> DataTypeLayout {
DataType::LargeList(_) => DataTypeLayout::new_fixed_width(size_of::<i64>()),
DataType::Struct(_) => DataTypeLayout::new_empty(), // all in child data,
DataType::RunEndEncoded(_, _) => DataTypeLayout::new_empty(), // all in child data,
DataType::Union(_, _, mode) => {
DataType::Union(_, mode) => {
let type_ids = BufferSpec::FixedWidth {
byte_width: size_of::<i8>(),
};
Expand Down Expand Up @@ -1823,7 +1824,7 @@ impl From<ArrayData> for ArrayDataBuilder {
#[cfg(test)]
mod tests {
use super::*;
use arrow_schema::Field;
use arrow_schema::{Field, UnionFields};

// See arrow/tests/array_data_validation.rs for test of array validation

Expand Down Expand Up @@ -2072,8 +2073,8 @@ mod tests {
#[test]
fn test_into_buffers() {
let data_types = vec![
DataType::Union(vec![], vec![], UnionMode::Dense),
DataType::Union(vec![], vec![], UnionMode::Sparse),
DataType::Union(UnionFields::empty(), UnionMode::Dense),
DataType::Union(UnionFields::empty(), UnionMode::Sparse),
];

for data_type in data_types {
Expand Down
2 changes: 1 addition & 1 deletion arrow-data/src/equal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ fn equal_values(
fixed_list_equal(lhs, rhs, lhs_start, rhs_start, len)
}
DataType::Struct(_) => struct_equal(lhs, rhs, lhs_start, rhs_start, len),
DataType::Union(_, _, _) => union_equal(lhs, rhs, lhs_start, rhs_start, len),
DataType::Union(_, _) => union_equal(lhs, rhs, lhs_start, rhs_start, len),
DataType::Dictionary(data_type, _) => match data_type.as_ref() {
DataType::Int8 => dictionary_equal::<i8>(lhs, rhs, lhs_start, rhs_start, len),
DataType::Int16 => {
Expand Down
Loading

0 comments on commit e5a1676

Please sign in to comment.