diff --git a/arrow/src/array/builder.rs b/arrow/src/array/builder.rs index 9c86716e1abe..38017b588f71 100644 --- a/arrow/src/array/builder.rs +++ b/arrow/src/array/builder.rs @@ -2117,7 +2117,10 @@ impl UnionBuilder { let mut field_data = match self.fields.remove(&type_name) { Some(data) => data, None => match self.value_offset_builder { - Some(_) => FieldData::new(self.fields.len() as i8, T::DATA_TYPE, None), + Some(_) => { + // For Dense Union, we don't build bitmap in individual field + FieldData::new(self.fields.len() as i8, T::DATA_TYPE, None) + } None => { let mut fd = FieldData::new( self.fields.len() as i8, diff --git a/arrow/src/array/equal/mod.rs b/arrow/src/array/equal/mod.rs index 4f27786c6db1..07c173b13326 100644 --- a/arrow/src/array/equal/mod.rs +++ b/arrow/src/array/equal/mod.rs @@ -40,6 +40,7 @@ mod list; mod null; mod primitive; mod structure; +mod union; mod utils; mod variable_size; @@ -55,6 +56,7 @@ use list::list_equal; use null::null_equal; use primitive::primitive_equal; use structure::struct_equal; +use union::union_equal; use variable_size::variable_sized_equal; impl PartialEq for dyn Array { @@ -232,7 +234,9 @@ fn equal_values( DataType::Struct(_) => { struct_equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) } - DataType::Union(_, _) => unimplemented!("See ARROW-8576"), + DataType::Union(_, _) => { + union_equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + } DataType::Dictionary(data_type, _) => match data_type.as_ref() { DataType::Int8 => dictionary_equal::( lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, @@ -313,11 +317,11 @@ mod tests { array::Array, ArrayData, ArrayDataBuilder, ArrayRef, BinaryOffsetSizeTrait, BooleanArray, FixedSizeBinaryBuilder, FixedSizeListBuilder, GenericBinaryArray, Int32Builder, ListBuilder, NullArray, PrimitiveBuilder, StringArray, - StringDictionaryBuilder, StringOffsetSizeTrait, StructArray, + StringDictionaryBuilder, StringOffsetSizeTrait, StructArray, UnionBuilder, }; use crate::array::{GenericStringArray, Int32Array}; use crate::buffer::Buffer; - use crate::datatypes::{Field, Int16Type, ToByteSlice}; + use crate::datatypes::{Field, Int16Type, Int32Type, ToByteSlice}; use super::*; @@ -1347,4 +1351,98 @@ mod tests { // nulls in it, string1 and string2 are not equal test_equal(string1, &string2, false); } + + #[test] + fn test_union_equal_dense() { + let mut builder = UnionBuilder::new_dense(7); + builder.append::("a", 1).unwrap(); + builder.append::("b", 2).unwrap(); + builder.append::("c", 3).unwrap(); + builder.append::("a", 4).unwrap(); + builder.append_null().unwrap(); + builder.append::("a", 6).unwrap(); + builder.append::("b", 7).unwrap(); + let union1 = builder.build().unwrap(); + + builder = UnionBuilder::new_dense(7); + builder.append::("a", 1).unwrap(); + builder.append::("b", 2).unwrap(); + builder.append::("c", 3).unwrap(); + builder.append::("a", 4).unwrap(); + builder.append_null().unwrap(); + builder.append::("a", 6).unwrap(); + builder.append::("b", 7).unwrap(); + let union2 = builder.build().unwrap(); + + builder = UnionBuilder::new_dense(7); + builder.append::("a", 1).unwrap(); + builder.append::("b", 2).unwrap(); + builder.append::("c", 3).unwrap(); + builder.append::("a", 5).unwrap(); + builder.append::("c", 4).unwrap(); + builder.append::("a", 6).unwrap(); + builder.append::("b", 7).unwrap(); + let union3 = builder.build().unwrap(); + + builder = UnionBuilder::new_dense(7); + builder.append::("a", 1).unwrap(); + builder.append::("b", 2).unwrap(); + builder.append::("c", 3).unwrap(); + builder.append::("a", 4).unwrap(); + builder.append_null().unwrap(); + builder.append_null().unwrap(); + builder.append::("b", 7).unwrap(); + let union4 = builder.build().unwrap(); + + test_equal(union1.data(), union2.data(), true); + test_equal(union1.data(), union3.data(), false); + test_equal(union1.data(), union4.data(), false); + } + + #[test] + fn test_union_equal_sparse() { + let mut builder = UnionBuilder::new_sparse(7); + builder.append::("a", 1).unwrap(); + builder.append::("b", 2).unwrap(); + builder.append::("c", 3).unwrap(); + builder.append::("a", 4).unwrap(); + builder.append_null().unwrap(); + builder.append::("a", 6).unwrap(); + builder.append::("b", 7).unwrap(); + let union1 = builder.build().unwrap(); + + builder = UnionBuilder::new_sparse(7); + builder.append::("a", 1).unwrap(); + builder.append::("b", 2).unwrap(); + builder.append::("c", 3).unwrap(); + builder.append::("a", 4).unwrap(); + builder.append_null().unwrap(); + builder.append::("a", 6).unwrap(); + builder.append::("b", 7).unwrap(); + let union2 = builder.build().unwrap(); + + builder = UnionBuilder::new_sparse(7); + builder.append::("a", 1).unwrap(); + builder.append::("b", 2).unwrap(); + builder.append::("c", 3).unwrap(); + builder.append::("a", 5).unwrap(); + builder.append::("c", 4).unwrap(); + builder.append::("a", 6).unwrap(); + builder.append::("b", 7).unwrap(); + let union3 = builder.build().unwrap(); + + builder = UnionBuilder::new_sparse(7); + builder.append::("a", 1).unwrap(); + builder.append::("b", 2).unwrap(); + builder.append::("c", 3).unwrap(); + builder.append::("a", 4).unwrap(); + builder.append_null().unwrap(); + builder.append_null().unwrap(); + builder.append::("b", 7).unwrap(); + let union4 = builder.build().unwrap(); + + test_equal(union1.data(), union2.data(), true); + test_equal(union1.data(), union3.data(), false); + test_equal(union1.data(), union4.data(), false); + } } diff --git a/arrow/src/array/equal/union.rs b/arrow/src/array/equal/union.rs new file mode 100644 index 000000000000..36cd19725b5d --- /dev/null +++ b/arrow/src/array/equal/union.rs @@ -0,0 +1,130 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{ + array::ArrayData, buffer::Buffer, datatypes::DataType, datatypes::UnionMode, +}; + +use super::{ + equal_range, equal_values, utils::child_logical_null_buffer, utils::equal_nulls, +}; + +fn equal_dense( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_type_ids: &[i8], + rhs_type_ids: &[i8], + lhs_offsets: &[i32], + rhs_offsets: &[i32], +) -> bool { + let offsets = lhs_offsets.iter().zip(rhs_offsets.iter()); + + lhs_type_ids + .iter() + .zip(rhs_type_ids.iter()) + .zip(offsets) + .all(|((l_type_id, r_type_id), (l_offset, r_offset))| { + let lhs_values = &lhs.child_data()[*l_type_id as usize]; + let rhs_values = &rhs.child_data()[*r_type_id as usize]; + + equal_values( + lhs_values, + rhs_values, + None, + None, + *l_offset as usize, + *r_offset as usize, + 1, + ) + }) +} + +fn equal_sparse( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_nulls: Option<&Buffer>, + rhs_nulls: Option<&Buffer>, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + lhs.child_data() + .iter() + .zip(rhs.child_data()) + .all(|(lhs_values, rhs_values)| { + // merge the null data + let lhs_merged_nulls = child_logical_null_buffer(lhs, lhs_nulls, lhs_values); + let rhs_merged_nulls = child_logical_null_buffer(rhs, rhs_nulls, rhs_values); + equal_range( + lhs_values, + rhs_values, + lhs_merged_nulls.as_ref(), + rhs_merged_nulls.as_ref(), + lhs_start, + rhs_start, + len, + ) + }) +} + +pub(super) fn union_equal( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_nulls: Option<&Buffer>, + rhs_nulls: Option<&Buffer>, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + let lhs_type_ids = lhs.buffer::(0); + let rhs_type_ids = rhs.buffer::(0); + + let lhs_type_id_range = &lhs_type_ids[lhs_start..lhs_start + len]; + let rhs_type_id_range = &rhs_type_ids[rhs_start..rhs_start + len]; + + match (lhs.data_type(), rhs.data_type()) { + (DataType::Union(_, UnionMode::Dense), DataType::Union(_, UnionMode::Dense)) => { + let lhs_offsets = lhs.buffer::(1); + let rhs_offsets = rhs.buffer::(1); + + let lhs_offsets_range = &lhs_offsets[lhs_start..lhs_start + len]; + let rhs_offsets_range = &rhs_offsets[rhs_start..rhs_start + len]; + + // nullness is kept in the parent UnionArray, so we compare its nulls here + lhs_type_id_range == rhs_type_id_range + && equal_nulls(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + && equal_dense( + lhs, + rhs, + lhs_type_id_range, + rhs_type_id_range, + lhs_offsets_range, + rhs_offsets_range, + ) + } + ( + DataType::Union(_, UnionMode::Sparse), + DataType::Union(_, UnionMode::Sparse), + ) => { + lhs_type_id_range == rhs_type_id_range + && equal_sparse(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + } + _ => unimplemented!( + "Logical equality not yet implemented between dense and sparse union arrays" + ), + } +} diff --git a/arrow/src/array/equal/utils.rs b/arrow/src/array/equal/utils.rs index 1bced978c1b5..cf2eb3064fa3 100644 --- a/arrow/src/array/equal/utils.rs +++ b/arrow/src/array/equal/utils.rs @@ -18,7 +18,7 @@ use crate::array::{data::count_nulls, ArrayData, OffsetSizeTrait}; use crate::bitmap::Bitmap; use crate::buffer::{Buffer, MutableBuffer}; -use crate::datatypes::DataType; +use crate::datatypes::{DataType, UnionMode}; use crate::util::bit_util; // whether bits along the positions are equal @@ -67,6 +67,9 @@ pub(super) fn equal_nulls( #[inline] pub(super) fn base_equal(lhs: &ArrayData, rhs: &ArrayData) -> bool { let equal_type = match (lhs.data_type(), rhs.data_type()) { + (DataType::Union(l_fields, l_mode), DataType::Union(r_fields, r_mode)) => { + l_fields == r_fields && l_mode == r_mode + } (DataType::Map(l_field, l_sorted), DataType::Map(r_field, r_sorted)) => { let field_equal = match (l_field.data_type(), r_field.data_type()) { (DataType::Struct(l_fields), DataType::Struct(r_fields)) @@ -189,9 +192,13 @@ pub(super) fn child_logical_null_buffer( }); Some(buffer.into()) } - DataType::Union(_, _) => { - unimplemented!("Logical equality not yet implemented for union arrays") - } + DataType::Union(_, mode) => union_child_logical_null_buffer( + parent_data, + parent_len, + &parent_bitmap, + &self_null_bitmap, + mode, + ), DataType::Dictionary(_, _) => { unimplemented!("Logical equality not yet implemented for nested dictionaries") } @@ -199,6 +206,41 @@ pub(super) fn child_logical_null_buffer( } } +pub(super) fn union_child_logical_null_buffer( + parent_data: &ArrayData, + parent_len: usize, + parent_bitmap: &Bitmap, + self_null_bitmap: &Bitmap, + mode: &UnionMode, +) -> Option { + match mode { + UnionMode::Sparse => { + // See the logic of `DataType::Struct` in `child_logical_null_buffer`. + let result = parent_bitmap & self_null_bitmap; + if let Ok(bitmap) = result { + return Some(bitmap.bits); + } + + // slow path + let array_offset = parent_data.offset(); + let mut buffer = MutableBuffer::new_null(parent_len); + let null_slice = buffer.as_slice_mut(); + (0..parent_len).for_each(|index| { + if parent_bitmap.is_set(index + array_offset) + && self_null_bitmap.is_set(index + array_offset) + { + bit_util::set_bit(null_slice, index); + } + }); + Some(buffer.into()) + } + UnionMode::Dense => { + // We don't keep bitmap in child data of Dense UnionArray + unimplemented!("Logical equality not yet implemented for dense union arrays") + } + } +} + // Calculate a list child's logical bitmap/buffer #[inline] fn logical_list_bitmap(