Skip to content

Commit

Permalink
ARROW-10402: [Rust] Refactor array equality
Browse files Browse the repository at this point in the history
This is a major refactor of the `equal.rs` module.

The rational for this change is many fold:

* currently array comparison requires downcasting the array ref to its concrete types. This is painful and not very ergonomics, as the user must "guess" what to downcast for comparison. We can see this in the hacks around `sort`, `take` and `concatenate` kernel's tests, and some of the tests of the builders.
* the code in array comparison is difficult to follow given the amount of calls that they perform around offsets.
* The implementation currently indirectly uses many of the `unsafe` APIs that we have (via pointer aritmetics), which makes it risky to operate and mutate.
* Some code is being repeated.

This PR:

1. adds `impl PartialEq for dyn Array`, to allow `Array` comparison based on `Array::data` (main change)
2. Makes array equality to only depend on `ArrayData`, i.e. it no longer depends on concrete array types (such as `PrimitiveArray` and related API) to perform comparisons.
3. Significantly reduces the risk of panics and UB when composite arrays are of different types, by checking the types on `range` comparison
4. Makes array equality be statically dispatched, via `match datatype`.
5. DRY the code around array equality
6. Fixes an error in equality of dictionary with equal values
7. Added tests to equalities that were not tested (fixed binary, some edge cases of dictionaries)
8. splits `equal.rs` in smaller, more manageable files.
9. Removes `ArrayListOps`, since it it no longer needed
10. Moves Json equality to its own module, for clarity.
11. removes the need to have two functions per type to compare arrays.
12. Adds the number of buffers and their respective width to datatypes from the specification. This was backported from #8401
13. adds a benchmark for array equality

Note that this does not implement `PartialEq` for `ArrayData`, only `dyn Array`, as different data does not imply a different array (due to nullability). That implementation is being worked on #8200.

IMO this PR significantly simplifies the code around array comparison, to the point where many implementations are 5 lines long.

This also improves performance by 10-40%.

<details>
 <summary>Benchmark results</summary>

```
Previous HEAD position was 3dd3c69 Added bench for equality.
Switched to branch 'equal'
Your branch is up to date with 'origin/equal'.
   Compiling arrow v3.0.0-SNAPSHOT (/Users/jorgecarleitao/projects/arrow/rust/arrow)
    Finished bench [optimized] target(s) in 51.28s
     Running /Users/jorgecarleitao/projects/arrow/rust/target/release/deps/equal-176c3cb11360bd12
Gnuplot not found, using plotters backend
equal_512               time:   [36.861 ns 36.894 ns 36.934 ns]
                        change: [-43.752% -43.400% -43.005%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 12 outliers among 100 measurements (12.00%)
  7 (7.00%) high mild
  5 (5.00%) high severe

equal_nulls_512         time:   [2.3271 us 2.3299 us 2.3331 us]
                        change: [-10.846% -9.0877% -7.7336%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 11 outliers among 100 measurements (11.00%)
  4 (4.00%) high mild
  7 (7.00%) high severe

equal_string_512        time:   [49.219 ns 49.347 ns 49.517 ns]
                        change: [-30.789% -30.538% -30.235%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 9 outliers among 100 measurements (9.00%)
  3 (3.00%) high mild
  6 (6.00%) high severe

equal_string_nulls_512  time:   [3.7873 us 3.7939 us 3.8013 us]
                        change: [-8.2944% -7.0636% -5.4266%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 12 outliers among 100 measurements (12.00%)
  4 (4.00%) high mild
  8 (8.00%) high severe
```

</details>

All tests are there, plus new tests for some of the edge cases and untested arrays.

This change is backward incompatible `array1.equals(&array2)` no longer works: use `array1 == array2` instead, which is the idiomatic way of comparing structs and trait objects in rust.

Closes #8541 from jorgecarleitao/equal

Authored-by: Jorge C. Leitao <[email protected]>
Signed-off-by: Neville Dipale <[email protected]>
  • Loading branch information
jorgecarleitao authored and nevi-me committed Nov 7, 2020
1 parent eb42c50 commit a04a15a
Show file tree
Hide file tree
Showing 22 changed files with 2,727 additions and 2,363 deletions.
4 changes: 4 additions & 0 deletions rust/arrow/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,7 @@ harness = false
[[bench]]
name = "csv_writer"
harness = false

[[bench]]
name = "equal"
harness = false
85 changes: 85 additions & 0 deletions rust/arrow/benches/equal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#[macro_use]
extern crate criterion;
use criterion::Criterion;

use rand::distributions::Alphanumeric;
use rand::Rng;
use std::sync::Arc;

extern crate arrow;

use arrow::array::*;

fn create_string_array(size: usize, with_nulls: bool) -> ArrayRef {
// use random numbers to avoid spurious compiler optimizations wrt to branching
let mut rng = rand::thread_rng();
let mut builder = StringBuilder::new(size);

for _ in 0..size {
if with_nulls && rng.gen::<f32>() > 0.5 {
builder.append_null().unwrap();
} else {
let string = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(10)
.collect::<String>();
builder.append_value(&string).unwrap();
}
}
Arc::new(builder.finish())
}

fn create_array(size: usize, with_nulls: bool) -> ArrayRef {
// use random numbers to avoid spurious compiler optimizations wrt to branching
let mut rng = rand::thread_rng();
let mut builder = Float32Builder::new(size);

for _ in 0..size {
if with_nulls && rng.gen::<f32>() > 0.5 {
builder.append_null().unwrap();
} else {
builder.append_value(rng.gen()).unwrap();
}
}
Arc::new(builder.finish())
}

fn bench_equal(arr_a: &ArrayRef) {
criterion::black_box(arr_a == arr_a);
}

fn add_benchmark(c: &mut Criterion) {
let arr_a = create_array(512, false);
c.bench_function("equal_512", |b| b.iter(|| bench_equal(&arr_a)));

let arr_a_nulls = create_array(512, true);
c.bench_function("equal_nulls_512", |b| b.iter(|| bench_equal(&arr_a_nulls)));

let arr_a = create_string_array(512, false);
c.bench_function("equal_string_512", |b| b.iter(|| bench_equal(&arr_a)));

let arr_a_nulls = create_string_array(512, true);
c.bench_function("equal_string_nulls_512", |b| {
b.iter(|| bench_equal(&arr_a_nulls))
});
}

criterion_group!(benches, add_benchmark);
criterion_main!(benches);
80 changes: 8 additions & 72 deletions rust/arrow/src/array/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use num::Num;

use super::*;
use crate::array::builder::StringDictionaryBuilder;
use crate::array::equal::JsonEqual;
use crate::array::equal_json::JsonEqual;
use crate::buffer::{buffer_bin_or, Buffer, MutableBuffer};
use crate::datatypes::DataType::Struct;
use crate::datatypes::*;
Expand All @@ -50,7 +50,7 @@ const NANOSECONDS: i64 = 1_000_000_000;

/// Trait for dealing with different types of array at runtime when the type of the
/// array is not known in advance.
pub trait Array: fmt::Debug + Send + Sync + ArrayEqual + JsonEqual {
pub trait Array: fmt::Debug + Send + Sync + JsonEqual {
/// Returns the array as [`Any`](std::any::Any) so that it can be
/// downcasted to a specific implementation.
///
Expand Down Expand Up @@ -112,10 +112,10 @@ pub trait Array: fmt::Debug + Send + Sync + ArrayEqual + JsonEqual {
/// // Make slice over the values [2, 3, 4]
/// let array_slice = array.slice(1, 3);
///
/// assert!(array_slice.equals(&Int32Array::from(vec![2, 3, 4])));
/// assert_eq!(array_slice.as_ref(), &Int32Array::from(vec![2, 3, 4]));
/// ```
fn slice(&self, offset: usize, length: usize) -> ArrayRef {
make_array(slice_data(self.data_ref(), offset, length))
make_array(Arc::new(self.data_ref().as_ref().slice(offset, length)))
}

/// Returns the length (i.e., number of elements) of this array.
Expand Down Expand Up @@ -182,8 +182,7 @@ pub trait Array: fmt::Debug + Send + Sync + ArrayEqual + JsonEqual {
/// assert_eq!(array.is_null(1), true);
/// ```
fn is_null(&self, index: usize) -> bool {
let data = self.data_ref();
data.is_null(data.offset() + index)
self.data().is_null(index)
}

/// Returns whether the element at `index` is not null.
Expand All @@ -200,8 +199,7 @@ pub trait Array: fmt::Debug + Send + Sync + ArrayEqual + JsonEqual {
/// assert_eq!(array.is_valid(1), false);
/// ```
fn is_valid(&self, index: usize) -> bool {
let data = self.data_ref();
data.is_valid(data.offset() + index)
self.data().is_valid(index)
}

/// Returns the total number of null values in this array.
Expand Down Expand Up @@ -338,33 +336,6 @@ pub fn make_array(data: ArrayDataRef) -> ArrayRef {
}
}

/// Creates a zero-copy slice of the array's data.
///
/// # Panics
///
/// Panics if `offset + length > data.len()`.
fn slice_data(data: &ArrayDataRef, mut offset: usize, length: usize) -> ArrayDataRef {
assert!((offset + length) <= data.len());

let mut new_data = data.as_ref().clone();
let len = std::cmp::min(new_data.len - offset, length);

offset += data.offset;
new_data.len = len;
new_data.offset = offset;

// Calculate the new null count based on the offset
new_data.null_count = if let Some(bitmap) = new_data.null_bitmap() {
let valid_bits = bitmap.bits.data();
len.checked_sub(bit_util::count_set_bits_offset(valid_bits, offset, length))
.unwrap()
} else {
0
};

Arc::new(new_data)
}

// creates a new MutableBuffer initializes all falsed
// this is useful to populate null bitmaps
fn make_null_buffer(len: usize) -> MutableBuffer {
Expand Down Expand Up @@ -853,11 +824,6 @@ impl<T: ArrowPrimitiveType> From<ArrayDataRef> for PrimitiveArray<T> {
}
}

/// Common operations for List types.
pub trait ListArrayOps<OffsetSize: OffsetSizeTrait> {
fn value_offset_at(&self, i: usize) -> OffsetSize;
}

/// trait declaring an offset size, relevant for i32 vs i64 array types.
pub trait OffsetSizeTrait: ArrowNativeType + Num + Ord {
fn prefix() -> &'static str;
Expand Down Expand Up @@ -1033,14 +999,6 @@ impl<OffsetSize: OffsetSizeTrait> fmt::Debug for GenericListArray<OffsetSize> {
}
}

impl<OffsetSize: OffsetSizeTrait> ListArrayOps<OffsetSize>
for GenericListArray<OffsetSize>
{
fn value_offset_at(&self, i: usize) -> OffsetSize {
self.value_offset_at(i)
}
}

/// A list array where each element is a variable-sized sequence of values with the same
/// type whose memory offsets between elements are represented by a i32.
pub type ListArray = GenericListArray<i32>;
Expand Down Expand Up @@ -1327,14 +1285,6 @@ impl<OffsetSize: BinaryOffsetSizeTrait> Array for GenericBinaryArray<OffsetSize>
}
}

impl<OffsetSize: BinaryOffsetSizeTrait> ListArrayOps<OffsetSize>
for GenericBinaryArray<OffsetSize>
{
fn value_offset_at(&self, i: usize) -> OffsetSize {
self.value_offset_at(i)
}
}

impl<OffsetSize: BinaryOffsetSizeTrait> From<ArrayDataRef>
for GenericBinaryArray<OffsetSize>
{
Expand Down Expand Up @@ -1691,14 +1641,6 @@ impl<OffsetSize: StringOffsetSizeTrait> From<ArrayDataRef>
}
}

impl<OffsetSize: StringOffsetSizeTrait> ListArrayOps<OffsetSize>
for GenericStringArray<OffsetSize>
{
fn value_offset_at(&self, i: usize) -> OffsetSize {
self.value_offset_at(i)
}
}

/// An array where each element is a variable-sized sequence of bytes representing a string
/// whose maximum length (in bytes) is represented by a i32.
pub type StringArray = GenericStringArray<i32>;
Expand Down Expand Up @@ -1794,12 +1736,6 @@ impl FixedSizeBinaryArray {
}
}

impl ListArrayOps<i32> for FixedSizeBinaryArray {
fn value_offset_at(&self, i: usize) -> i32 {
self.value_offset_at(i)
}
}

impl From<ArrayDataRef> for FixedSizeBinaryArray {
fn from(data: ArrayDataRef) -> Self {
assert_eq!(
Expand Down Expand Up @@ -1935,8 +1871,8 @@ impl From<ArrayDataRef> for StructArray {
fn from(data: ArrayDataRef) -> Self {
let mut boxed_fields = vec![];
for cd in data.child_data() {
let child_data = if data.offset != 0 || data.len != cd.len {
slice_data(&cd, data.offset, data.len)
let child_data = if data.offset() != 0 || data.len() != cd.len() {
Arc::new(cd.as_ref().slice(data.offset(), data.len()))
} else {
cd.clone()
};
Expand Down
Loading

0 comments on commit a04a15a

Please sign in to comment.