Skip to content

Commit

Permalink
Implement ArrayEqual for UnionArray (apache#1469)
Browse files Browse the repository at this point in the history
* init

* more

* Remove dense/sparse case

* Fix clippy

* For review

* For review
  • Loading branch information
viirya authored Mar 31, 2022
1 parent 15c87ae commit 0575a94
Show file tree
Hide file tree
Showing 4 changed files with 281 additions and 8 deletions.
5 changes: 4 additions & 1 deletion arrow/src/array/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2117,7 +2117,10 @@ impl UnionBuilder {
let mut field_data = match self.fields.remove(&type_name) {
Some(data) => data,
None => match self.value_offset_builder {
Some(_) => FieldData::new(self.fields.len() as i8, T::DATA_TYPE, None),
Some(_) => {
// For Dense Union, we don't build bitmap in individual field
FieldData::new(self.fields.len() as i8, T::DATA_TYPE, None)
}
None => {
let mut fd = FieldData::new(
self.fields.len() as i8,
Expand Down
104 changes: 101 additions & 3 deletions arrow/src/array/equal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ mod list;
mod null;
mod primitive;
mod structure;
mod union;
mod utils;
mod variable_size;

Expand All @@ -55,6 +56,7 @@ use list::list_equal;
use null::null_equal;
use primitive::primitive_equal;
use structure::struct_equal;
use union::union_equal;
use variable_size::variable_sized_equal;

impl PartialEq for dyn Array {
Expand Down Expand Up @@ -232,7 +234,9 @@ fn equal_values(
DataType::Struct(_) => {
struct_equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len)
}
DataType::Union(_, _) => unimplemented!("See ARROW-8576"),
DataType::Union(_, _) => {
union_equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len)
}
DataType::Dictionary(data_type, _) => match data_type.as_ref() {
DataType::Int8 => dictionary_equal::<i8>(
lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
Expand Down Expand Up @@ -313,11 +317,11 @@ mod tests {
array::Array, ArrayData, ArrayDataBuilder, ArrayRef, BinaryOffsetSizeTrait,
BooleanArray, FixedSizeBinaryBuilder, FixedSizeListBuilder, GenericBinaryArray,
Int32Builder, ListBuilder, NullArray, PrimitiveBuilder, StringArray,
StringDictionaryBuilder, StringOffsetSizeTrait, StructArray,
StringDictionaryBuilder, StringOffsetSizeTrait, StructArray, UnionBuilder,
};
use crate::array::{GenericStringArray, Int32Array};
use crate::buffer::Buffer;
use crate::datatypes::{Field, Int16Type, ToByteSlice};
use crate::datatypes::{Field, Int16Type, Int32Type, ToByteSlice};

use super::*;

Expand Down Expand Up @@ -1347,4 +1351,98 @@ mod tests {
// nulls in it, string1 and string2 are not equal
test_equal(string1, &string2, false);
}

#[test]
fn test_union_equal_dense() {
let mut builder = UnionBuilder::new_dense(7);
builder.append::<Int32Type>("a", 1).unwrap();
builder.append::<Int32Type>("b", 2).unwrap();
builder.append::<Int32Type>("c", 3).unwrap();
builder.append::<Int32Type>("a", 4).unwrap();
builder.append_null().unwrap();
builder.append::<Int32Type>("a", 6).unwrap();
builder.append::<Int32Type>("b", 7).unwrap();
let union1 = builder.build().unwrap();

builder = UnionBuilder::new_dense(7);
builder.append::<Int32Type>("a", 1).unwrap();
builder.append::<Int32Type>("b", 2).unwrap();
builder.append::<Int32Type>("c", 3).unwrap();
builder.append::<Int32Type>("a", 4).unwrap();
builder.append_null().unwrap();
builder.append::<Int32Type>("a", 6).unwrap();
builder.append::<Int32Type>("b", 7).unwrap();
let union2 = builder.build().unwrap();

builder = UnionBuilder::new_dense(7);
builder.append::<Int32Type>("a", 1).unwrap();
builder.append::<Int32Type>("b", 2).unwrap();
builder.append::<Int32Type>("c", 3).unwrap();
builder.append::<Int32Type>("a", 5).unwrap();
builder.append::<Int32Type>("c", 4).unwrap();
builder.append::<Int32Type>("a", 6).unwrap();
builder.append::<Int32Type>("b", 7).unwrap();
let union3 = builder.build().unwrap();

builder = UnionBuilder::new_dense(7);
builder.append::<Int32Type>("a", 1).unwrap();
builder.append::<Int32Type>("b", 2).unwrap();
builder.append::<Int32Type>("c", 3).unwrap();
builder.append::<Int32Type>("a", 4).unwrap();
builder.append_null().unwrap();
builder.append_null().unwrap();
builder.append::<Int32Type>("b", 7).unwrap();
let union4 = builder.build().unwrap();

test_equal(union1.data(), union2.data(), true);
test_equal(union1.data(), union3.data(), false);
test_equal(union1.data(), union4.data(), false);
}

#[test]
fn test_union_equal_sparse() {
let mut builder = UnionBuilder::new_sparse(7);
builder.append::<Int32Type>("a", 1).unwrap();
builder.append::<Int32Type>("b", 2).unwrap();
builder.append::<Int32Type>("c", 3).unwrap();
builder.append::<Int32Type>("a", 4).unwrap();
builder.append_null().unwrap();
builder.append::<Int32Type>("a", 6).unwrap();
builder.append::<Int32Type>("b", 7).unwrap();
let union1 = builder.build().unwrap();

builder = UnionBuilder::new_sparse(7);
builder.append::<Int32Type>("a", 1).unwrap();
builder.append::<Int32Type>("b", 2).unwrap();
builder.append::<Int32Type>("c", 3).unwrap();
builder.append::<Int32Type>("a", 4).unwrap();
builder.append_null().unwrap();
builder.append::<Int32Type>("a", 6).unwrap();
builder.append::<Int32Type>("b", 7).unwrap();
let union2 = builder.build().unwrap();

builder = UnionBuilder::new_sparse(7);
builder.append::<Int32Type>("a", 1).unwrap();
builder.append::<Int32Type>("b", 2).unwrap();
builder.append::<Int32Type>("c", 3).unwrap();
builder.append::<Int32Type>("a", 5).unwrap();
builder.append::<Int32Type>("c", 4).unwrap();
builder.append::<Int32Type>("a", 6).unwrap();
builder.append::<Int32Type>("b", 7).unwrap();
let union3 = builder.build().unwrap();

builder = UnionBuilder::new_sparse(7);
builder.append::<Int32Type>("a", 1).unwrap();
builder.append::<Int32Type>("b", 2).unwrap();
builder.append::<Int32Type>("c", 3).unwrap();
builder.append::<Int32Type>("a", 4).unwrap();
builder.append_null().unwrap();
builder.append_null().unwrap();
builder.append::<Int32Type>("b", 7).unwrap();
let union4 = builder.build().unwrap();

test_equal(union1.data(), union2.data(), true);
test_equal(union1.data(), union3.data(), false);
test_equal(union1.data(), union4.data(), false);
}
}
130 changes: 130 additions & 0 deletions arrow/src/array/equal/union.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::{
array::ArrayData, buffer::Buffer, datatypes::DataType, datatypes::UnionMode,
};

use super::{
equal_range, equal_values, utils::child_logical_null_buffer, utils::equal_nulls,
};

fn equal_dense(
lhs: &ArrayData,
rhs: &ArrayData,
lhs_type_ids: &[i8],
rhs_type_ids: &[i8],
lhs_offsets: &[i32],
rhs_offsets: &[i32],
) -> bool {
let offsets = lhs_offsets.iter().zip(rhs_offsets.iter());

lhs_type_ids
.iter()
.zip(rhs_type_ids.iter())
.zip(offsets)
.all(|((l_type_id, r_type_id), (l_offset, r_offset))| {
let lhs_values = &lhs.child_data()[*l_type_id as usize];
let rhs_values = &rhs.child_data()[*r_type_id as usize];

equal_values(
lhs_values,
rhs_values,
None,
None,
*l_offset as usize,
*r_offset as usize,
1,
)
})
}

fn equal_sparse(
lhs: &ArrayData,
rhs: &ArrayData,
lhs_nulls: Option<&Buffer>,
rhs_nulls: Option<&Buffer>,
lhs_start: usize,
rhs_start: usize,
len: usize,
) -> bool {
lhs.child_data()
.iter()
.zip(rhs.child_data())
.all(|(lhs_values, rhs_values)| {
// merge the null data
let lhs_merged_nulls = child_logical_null_buffer(lhs, lhs_nulls, lhs_values);
let rhs_merged_nulls = child_logical_null_buffer(rhs, rhs_nulls, rhs_values);
equal_range(
lhs_values,
rhs_values,
lhs_merged_nulls.as_ref(),
rhs_merged_nulls.as_ref(),
lhs_start,
rhs_start,
len,
)
})
}

pub(super) fn union_equal(
lhs: &ArrayData,
rhs: &ArrayData,
lhs_nulls: Option<&Buffer>,
rhs_nulls: Option<&Buffer>,
lhs_start: usize,
rhs_start: usize,
len: usize,
) -> bool {
let lhs_type_ids = lhs.buffer::<i8>(0);
let rhs_type_ids = rhs.buffer::<i8>(0);

let lhs_type_id_range = &lhs_type_ids[lhs_start..lhs_start + len];
let rhs_type_id_range = &rhs_type_ids[rhs_start..rhs_start + len];

match (lhs.data_type(), rhs.data_type()) {
(DataType::Union(_, UnionMode::Dense), DataType::Union(_, UnionMode::Dense)) => {
let lhs_offsets = lhs.buffer::<i32>(1);
let rhs_offsets = rhs.buffer::<i32>(1);

let lhs_offsets_range = &lhs_offsets[lhs_start..lhs_start + len];
let rhs_offsets_range = &rhs_offsets[rhs_start..rhs_start + len];

// nullness is kept in the parent UnionArray, so we compare its nulls here
lhs_type_id_range == rhs_type_id_range
&& equal_nulls(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len)
&& equal_dense(
lhs,
rhs,
lhs_type_id_range,
rhs_type_id_range,
lhs_offsets_range,
rhs_offsets_range,
)
}
(
DataType::Union(_, UnionMode::Sparse),
DataType::Union(_, UnionMode::Sparse),
) => {
lhs_type_id_range == rhs_type_id_range
&& equal_sparse(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len)
}
_ => unimplemented!(
"Logical equality not yet implemented between dense and sparse union arrays"
),
}
}
50 changes: 46 additions & 4 deletions arrow/src/array/equal/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
use crate::array::{data::count_nulls, ArrayData, OffsetSizeTrait};
use crate::bitmap::Bitmap;
use crate::buffer::{Buffer, MutableBuffer};
use crate::datatypes::DataType;
use crate::datatypes::{DataType, UnionMode};
use crate::util::bit_util;

// whether bits along the positions are equal
Expand Down Expand Up @@ -67,6 +67,9 @@ pub(super) fn equal_nulls(
#[inline]
pub(super) fn base_equal(lhs: &ArrayData, rhs: &ArrayData) -> bool {
let equal_type = match (lhs.data_type(), rhs.data_type()) {
(DataType::Union(l_fields, l_mode), DataType::Union(r_fields, r_mode)) => {
l_fields == r_fields && l_mode == r_mode
}
(DataType::Map(l_field, l_sorted), DataType::Map(r_field, r_sorted)) => {
let field_equal = match (l_field.data_type(), r_field.data_type()) {
(DataType::Struct(l_fields), DataType::Struct(r_fields))
Expand Down Expand Up @@ -189,16 +192,55 @@ pub(super) fn child_logical_null_buffer(
});
Some(buffer.into())
}
DataType::Union(_, _) => {
unimplemented!("Logical equality not yet implemented for union arrays")
}
DataType::Union(_, mode) => union_child_logical_null_buffer(
parent_data,
parent_len,
&parent_bitmap,
&self_null_bitmap,
mode,
),
DataType::Dictionary(_, _) => {
unimplemented!("Logical equality not yet implemented for nested dictionaries")
}
data_type => panic!("Data type {:?} is not a supported nested type", data_type),
}
}

pub(super) fn union_child_logical_null_buffer(
parent_data: &ArrayData,
parent_len: usize,
parent_bitmap: &Bitmap,
self_null_bitmap: &Bitmap,
mode: &UnionMode,
) -> Option<Buffer> {
match mode {
UnionMode::Sparse => {
// See the logic of `DataType::Struct` in `child_logical_null_buffer`.
let result = parent_bitmap & self_null_bitmap;
if let Ok(bitmap) = result {
return Some(bitmap.bits);
}

// slow path
let array_offset = parent_data.offset();
let mut buffer = MutableBuffer::new_null(parent_len);
let null_slice = buffer.as_slice_mut();
(0..parent_len).for_each(|index| {
if parent_bitmap.is_set(index + array_offset)
&& self_null_bitmap.is_set(index + array_offset)
{
bit_util::set_bit(null_slice, index);
}
});
Some(buffer.into())
}
UnionMode::Dense => {
// We don't keep bitmap in child data of Dense UnionArray
unimplemented!("Logical equality not yet implemented for dense union arrays")
}
}
}

// Calculate a list child's logical bitmap/buffer
#[inline]
fn logical_list_bitmap<OffsetSize: OffsetSizeTrait>(
Expand Down

0 comments on commit 0575a94

Please sign in to comment.