Skip to content

Commit

Permalink
Implement specialized min/max for GenericBinaryView (StringView a…
Browse files Browse the repository at this point in the history
…nd `BinaryView`) (apache#6089)

* implement better min/max for string view

* Apply suggestions from code review

Co-authored-by: Andrew Lamb <[email protected]>

* address review comments

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
XiangpengHao and alamb authored Jul 23, 2024
1 parent 93e4eb2 commit af40ea3
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 9 deletions.
51 changes: 47 additions & 4 deletions arrow-arith/src/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ use arrow_buffer::{ArrowNativeType, NullBuffer};
use arrow_data::bit_iterator::try_for_each_valid_idx;
use arrow_schema::*;
use std::borrow::BorrowMut;
use std::cmp::{self, Ordering};
use std::ops::{BitAnd, BitOr, BitXor};
use types::ByteViewType;

/// An accumulator for primitive numeric values.
trait NumericAccumulator<T: ArrowNativeTypeOp>: Copy + Default {
Expand Down Expand Up @@ -425,14 +427,55 @@ where
}
}

/// Helper to compute min/max of [`GenericByteViewArray<T>`].
/// The specialized min/max leverages the inlined values to compare the byte views.
/// `swap_cond` is the condition to swap current min/max with the new value.
/// For example, `Ordering::Greater` for max and `Ordering::Less` for min.
fn min_max_view_helper<T: ByteViewType>(
array: &GenericByteViewArray<T>,
swap_cond: cmp::Ordering,
) -> Option<&T::Native> {
let null_count = array.null_count();
if null_count == array.len() {
None
} else if null_count == 0 {
let target_idx = (0..array.len()).reduce(|acc, item| {
// SAFETY: array's length is correct so item is within bounds
let cmp = unsafe { GenericByteViewArray::compare_unchecked(array, item, array, acc) };
if cmp == swap_cond {
item
} else {
acc
}
});
// SAFETY: idx came from valid range `0..array.len()`
unsafe { target_idx.map(|idx| array.value_unchecked(idx)) }
} else {
let nulls = array.nulls().unwrap();

let target_idx = nulls.valid_indices().reduce(|acc_idx, idx| {
let cmp =
unsafe { GenericByteViewArray::compare_unchecked(array, idx, array, acc_idx) };
if cmp == swap_cond {
idx
} else {
acc_idx
}
});

// SAFETY: idx came from valid range `0..array.len()`
unsafe { target_idx.map(|idx| array.value_unchecked(idx)) }
}
}

/// Returns the maximum value in the binary array, according to the natural order.
pub fn max_binary<T: OffsetSizeTrait>(array: &GenericBinaryArray<T>) -> Option<&[u8]> {
min_max_helper::<&[u8], _, _>(array, |a, b| *a < *b)
}

/// Returns the maximum value in the binary view array, according to the natural order.
pub fn max_binary_view(array: &BinaryViewArray) -> Option<&[u8]> {
min_max_helper::<&[u8], _, _>(array, |a, b| *a < *b)
min_max_view_helper(array, Ordering::Greater)
}

/// Returns the minimum value in the binary array, according to the natural order.
Expand All @@ -442,7 +485,7 @@ pub fn min_binary<T: OffsetSizeTrait>(array: &GenericBinaryArray<T>) -> Option<&

/// Returns the minimum value in the binary view array, according to the natural order.
pub fn min_binary_view(array: &BinaryViewArray) -> Option<&[u8]> {
min_max_helper::<&[u8], _, _>(array, |a, b| *a > *b)
min_max_view_helper(array, Ordering::Less)
}

/// Returns the maximum value in the string array, according to the natural order.
Expand All @@ -452,7 +495,7 @@ pub fn max_string<T: OffsetSizeTrait>(array: &GenericStringArray<T>) -> Option<&

/// Returns the maximum value in the string view array, according to the natural order.
pub fn max_string_view(array: &StringViewArray) -> Option<&str> {
min_max_helper::<&str, _, _>(array, |a, b| *a < *b)
min_max_view_helper(array, Ordering::Greater)
}

/// Returns the minimum value in the string array, according to the natural order.
Expand All @@ -462,7 +505,7 @@ pub fn min_string<T: OffsetSizeTrait>(array: &GenericStringArray<T>) -> Option<&

/// Returns the minimum value in the string view array, according to the natural order.
pub fn min_string_view(array: &StringViewArray) -> Option<&str> {
min_max_helper::<&str, _, _>(array, |a, b| *a > *b)
min_max_view_helper(array, Ordering::Less)
}

/// Returns the sum of values in the array.
Expand Down
60 changes: 60 additions & 0 deletions arrow-array/src/array/byte_view_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,66 @@ impl<T: ByteViewType + ?Sized> GenericByteViewArray<T> {

builder.finish()
}

/// Comparing two [`GenericByteViewArray`] at index `left_idx` and `right_idx`
///
/// Comparing two ByteView types are non-trivial.
/// It takes a bit of patience to understand why we don't just compare two &[u8] directly.
///
/// ByteView types give us the following two advantages, and we need to be careful not to lose them:
/// (1) For string/byte smaller than 12 bytes, the entire data is inlined in the view.
/// Meaning that reading one array element requires only one memory access
/// (two memory access required for StringArray, one for offset buffer, the other for value buffer).
///
/// (2) For string/byte larger than 12 bytes, we can still be faster than (for certain operations) StringArray/ByteArray,
/// thanks to the inlined 4 bytes.
/// Consider equality check:
/// If the first four bytes of the two strings are different, we can return false immediately (with just one memory access).
///
/// If we directly compare two &[u8], we materialize the entire string (i.e., make multiple memory accesses), which might be unnecessary.
/// - Most of the time (eq, ord), we only need to look at the first 4 bytes to know the answer,
/// e.g., if the inlined 4 bytes are different, we can directly return unequal without looking at the full string.
///
/// # Order check flow
/// (1) if both string are smaller than 12 bytes, we can directly compare the data inlined to the view.
/// (2) if any of the string is larger than 12 bytes, we need to compare the full string.
/// (2.1) if the inlined 4 bytes are different, we can return the result immediately.
/// (2.2) o.w., we need to compare the full string.
///
/// # Safety
/// The left/right_idx must within range of each array
pub unsafe fn compare_unchecked(
left: &GenericByteViewArray<T>,
left_idx: usize,
right: &GenericByteViewArray<T>,
right_idx: usize,
) -> std::cmp::Ordering {
let l_view = left.views().get_unchecked(left_idx);
let l_len = *l_view as u32;

let r_view = right.views().get_unchecked(right_idx);
let r_len = *r_view as u32;

if l_len <= 12 && r_len <= 12 {
let l_data = unsafe { GenericByteViewArray::<T>::inline_value(l_view, l_len as usize) };
let r_data = unsafe { GenericByteViewArray::<T>::inline_value(r_view, r_len as usize) };
return l_data.cmp(r_data);
}

// one of the string is larger than 12 bytes,
// we then try to compare the inlined data first
let l_inlined_data = unsafe { GenericByteViewArray::<T>::inline_value(l_view, 4) };
let r_inlined_data = unsafe { GenericByteViewArray::<T>::inline_value(r_view, 4) };
if r_inlined_data != l_inlined_data {
return l_inlined_data.cmp(r_inlined_data);
}

// unfortunately, we need to compare the full data
let l_full_data: &[u8] = unsafe { left.value_unchecked(left_idx).as_ref() };
let r_full_data: &[u8] = unsafe { right.value_unchecked(right_idx).as_ref() };

l_full_data.cmp(r_full_data)
}
}

impl<T: ByteViewType + ?Sized> Debug for GenericByteViewArray<T> {
Expand Down
7 changes: 4 additions & 3 deletions arrow-ord/src/cmp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -579,13 +579,13 @@ impl<'a, T: ByteViewType> ArrayOrd for &'a GenericByteViewArray<T> {
return false;
}

unsafe { compare_byte_view_unchecked(l.0, l.1, r.0, r.1).is_eq() }
unsafe { GenericByteViewArray::compare_unchecked(l.0, l.1, r.0, r.1).is_eq() }
}

fn is_lt(l: Self::Item, r: Self::Item) -> bool {
// # Safety
// The index is within bounds as it is checked in value()
unsafe { compare_byte_view_unchecked(l.0, l.1, r.0, r.1).is_lt() }
unsafe { GenericByteViewArray::compare_unchecked(l.0, l.1, r.0, r.1).is_lt() }
}

fn len(&self) -> usize {
Expand Down Expand Up @@ -626,7 +626,7 @@ pub fn compare_byte_view<T: ByteViewType>(
) -> std::cmp::Ordering {
assert!(left_idx < left.len());
assert!(right_idx < right.len());
unsafe { compare_byte_view_unchecked(left, left_idx, right, right_idx) }
unsafe { GenericByteViewArray::compare_unchecked(left, left_idx, right, right_idx) }
}

/// Comparing two [`GenericByteViewArray`] at index `left_idx` and `right_idx`
Expand Down Expand Up @@ -656,6 +656,7 @@ pub fn compare_byte_view<T: ByteViewType>(
///
/// # Safety
/// The left/right_idx must within range of each array
#[deprecated(note = "Use `GenericByteViewArray::compare_unchecked` instead")]
pub unsafe fn compare_byte_view_unchecked<T: ByteViewType>(
left: &GenericByteViewArray<T>,
left_idx: usize,
Expand Down
23 changes: 21 additions & 2 deletions arrow/benches/aggregate_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ fn add_benchmark(c: &mut Criterion) {
primitive_benchmark::<Int64Type>(c, "int64");

{
let nonnull_strings = create_string_array::<i32>(BATCH_SIZE, 0.0);
let nullable_strings = create_string_array::<i32>(BATCH_SIZE, 0.5);
let nonnull_strings = create_string_array_with_len::<i32>(BATCH_SIZE, 0.0, 16);
let nullable_strings = create_string_array_with_len::<i32>(BATCH_SIZE, 0.5, 16);
c.benchmark_group("string")
.throughput(Throughput::Elements(BATCH_SIZE as u64))
.bench_function("min nonnull", |b| b.iter(|| min_string(&nonnull_strings)))
Expand All @@ -67,6 +67,25 @@ fn add_benchmark(c: &mut Criterion) {
.bench_function("max nullable", |b| b.iter(|| max_string(&nullable_strings)));
}

{
let nonnull_strings = create_string_view_array_with_len(BATCH_SIZE, 0.0, 16, false);
let nullable_strings = create_string_view_array_with_len(BATCH_SIZE, 0.5, 16, false);
c.benchmark_group("string view")
.throughput(Throughput::Elements(BATCH_SIZE as u64))
.bench_function("min nonnull", |b| {
b.iter(|| min_string_view(&nonnull_strings))
})
.bench_function("max nonnull", |b| {
b.iter(|| max_string_view(&nonnull_strings))
})
.bench_function("min nullable", |b| {
b.iter(|| min_string_view(&nullable_strings))
})
.bench_function("max nullable", |b| {
b.iter(|| max_string_view(&nullable_strings))
});
}

{
let nonnull_bools_mixed = create_boolean_array(BATCH_SIZE, 0.0, 0.5);
let nonnull_bools_all_false = create_boolean_array(BATCH_SIZE, 0.0, 0.0);
Expand Down

0 comments on commit af40ea3

Please sign in to comment.