Skip to content

Commit

Permalink
feat: implement search_sorted_many (#840)
Browse files Browse the repository at this point in the history
Coming out of #823, we found that `search_sorted` on `BitPackedArray` is
slow because it wastefully re-builds `BitPackedArray`.

This PR creates a new `search_sorted_many` that allows arrays to do some
up-front initialization before doing loops of repeated searches, like
`RunEndArray::find_physical_indices`.

We're still ~50% slower than #823; unpack_single + branch
mispredicts (which I think account for all of the self time in `search_sorted`)
seem to be the cause of the slowdown.

<img width="3312" alt="image"
src="https://github.com/user-attachments/assets/f28fbe65-285e-4db7-a6d4-41a35391f6ea">
  • Loading branch information
a10y authored Sep 17, 2024
1 parent 7a0dc6d commit 79f816c
Show file tree
Hide file tree
Showing 8 changed files with 361 additions and 53 deletions.
212 changes: 176 additions & 36 deletions encodings/fastlanes/src/bitpacking/compute/search_sorted.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use std::cmp;
use std::cmp::Ordering;
use std::cmp::Ordering::Greater;

use fastlanes::BitPacking;
use itertools::Itertools;
use num_traits::AsPrimitive;
use vortex::array::{PrimitiveArray, SparseArray};
use vortex::array::SparseArray;
use vortex::compute::{
search_sorted, IndexOrd, Len, SearchResult, SearchSorted, SearchSortedFn, SearchSortedSide,
search_sorted_u64, IndexOrd, Len, SearchResult, SearchSorted, SearchSortedFn, SearchSortedSide,
};
use vortex::validity::Validity;
use vortex::{ArrayDType, IntoArrayVariant};
use vortex::ArrayDType;
use vortex_dtype::{match_each_unsigned_integer_ptype, NativePType};
use vortex_error::{VortexError, VortexExpect as _, VortexResult};
use vortex_scalar::Scalar;
Expand All @@ -21,6 +23,61 @@ impl SearchSortedFn for BitPackedArray {
search_sorted_typed::<$P>(self, value, side)
})
}

/// Search for a `u64` needle without wrapping it in a `Scalar`.
///
/// The needle is narrowed to the array's unsigned ptype using a checked cast. A needle that
/// does not fit in that ptype is reported as `NotFound` at `self.len()`.
fn search_sorted_u64(&self, value: u64, side: SearchSortedSide) -> VortexResult<SearchResult> {
    match_each_unsigned_integer_ptype!(self.ptype(), |$P| {
        // `num_traits::cast` is checked: it yields `None` instead of truncating.
        match num_traits::cast::<u64, $P>(value) {
            Some(native) => search_sorted_native(self, native, side),
            // The needle is too large for the ptype, so it must lie off the right end of
            // the array.
            None => Ok(SearchResult::NotFound(self.len())),
        }
    })
}

/// Search for several `Scalar` needles, building the `BitPackedSearch` only once.
///
/// `values` and `sides` are paired positionally; each needle is cast to the array's dtype
/// and converted to the native type before searching. The first failing cast or conversion
/// aborts the whole call with that error.
fn search_sorted_many(
    &self,
    values: &[Scalar],
    sides: &[SearchSortedSide],
) -> VortexResult<Vec<SearchResult>> {
    match_each_unsigned_integer_ptype!(self.ptype(), |$P| {
        // Built once up front and shared by every lookup below.
        let search = BitPackedSearch::<'_, $P>::new(self);

        let mut results = Vec::with_capacity(values.len().min(sides.len()));
        for (scalar, side) in values.iter().zip(sides.iter().copied()) {
            // Unwrap the scalar to its native representation.
            let native: $P = scalar.cast(self.dtype())?.try_into()?;
            results.push(search.search_sorted(&native, side));
        }

        Ok(results)
    })
}

/// Search for several `u64` needles, building the `BitPackedSearch` only once.
///
/// `values` and `sides` are paired positionally. Each needle is narrowed to the array's
/// unsigned ptype with a checked cast; a needle that does not fit in that ptype is reported
/// as `NotFound` at `self.len()`, matching `search_sorted_u64`.
fn search_sorted_u64_many(
    &self,
    values: &[u64],
    sides: &[SearchSortedSide],
) -> VortexResult<Vec<SearchResult>> {
    let len = self.len();
    match_each_unsigned_integer_ptype!(self.ptype(), |$P| {
        // NOTE(review): unlike `search_sorted_native`, this path never consults the patches
        // array for needles above the max packed value -- confirm that is intended.
        let searcher = BitPackedSearch::<'_, $P>::new(self);

        values
            .iter()
            .copied()
            .zip(sides.iter().copied())
            .map(|(value, side)| {
                // A plain `as` cast would wrap an oversized needle into an unrelated small
                // value and produce a bogus hit; use a checked cast instead.
                Ok(match num_traits::cast::<u64, $P>(value) {
                    Some(native) => searcher.search_sorted(&native, side),
                    // Too large for the ptype: the needle lies off the right end.
                    None => SearchResult::NotFound(len),
                })
            })
            .try_collect()
    })
}
}

fn search_sorted_typed<T>(
Expand All @@ -29,69 +86,92 @@ fn search_sorted_typed<T>(
side: SearchSortedSide,
) -> VortexResult<SearchResult>
where
T: NativePType + TryFrom<Scalar, Error = VortexError> + BitPacking + AsPrimitive<usize>,
T: NativePType
+ TryFrom<Scalar, Error = VortexError>
+ BitPacking
+ AsPrimitive<usize>
+ AsPrimitive<u64>,
{
let native_value: T = value.cast(array.dtype())?.try_into()?;
search_sorted_native(array, native_value, side)
}

/// Native variant of search_sorted that operates over Rust unsigned integer types.
fn search_sorted_native<T>(
array: &BitPackedArray,
value: T,
side: SearchSortedSide,
) -> VortexResult<SearchResult>
where
T: NativePType + BitPacking + AsPrimitive<usize> + AsPrimitive<u64>,
{
let unwrapped_value: T = value.cast(array.dtype())?.try_into()?;
if let Some(patches_array) = array.patches() {
// If patches exist they must be the last elements in the array, if the value we're looking for is greater than
// max packed value just search the patches
if unwrapped_value.as_() > array.max_packed_value() {
search_sorted(&patches_array, value.clone(), side)
let usize_value: usize = value.as_();
if usize_value > array.max_packed_value() {
search_sorted_u64(&patches_array, value.as_(), side)
} else {
Ok(BitPackedSearch::new(array).search_sorted(&unwrapped_value, side))
Ok(BitPackedSearch::<'_, T>::new(array).search_sorted(&value, side))
}
} else {
Ok(BitPackedSearch::new(array).search_sorted(&unwrapped_value, side))
Ok(BitPackedSearch::<'_, T>::new(array).search_sorted(&value, side))
}
}

/// This wrapper exists, so that you can't invoke SearchSorted::search_sorted directly on BitPackedArray as it omits searching patches
#[derive(Debug)]
struct BitPackedSearch {
packed: PrimitiveArray,
struct BitPackedSearch<'a, T> {
// NOTE: caching this here is important for performance, as each call to `maybe_null_slice`
// invokes a call to DType <> PType conversion
packed_maybe_null_slice: &'a [T],
offset: usize,
length: usize,
bit_width: usize,
min_patch_offset: Option<usize>,
validity: Validity,
first_invalid_idx: usize,
}

impl BitPackedSearch {
pub fn new(array: &BitPackedArray) -> Self {
impl<'a, T: BitPacking + NativePType> BitPackedSearch<'a, T> {
pub fn new(array: &'a BitPackedArray) -> Self {
let min_patch_offset = array
.patches()
.and_then(|p| {
SparseArray::try_from(p)
.vortex_expect("Only sparse patches are supported")
.min_index()
})
.unwrap_or_else(|| array.len());
let first_null_idx = match array.validity() {
Validity::NonNullable | Validity::AllValid => array.len(),
Validity::AllInvalid => 0,
Validity::Array(varray) => {
// In sorted order, nulls come after all the non-null values.
varray.with_dyn(|a| a.as_bool_array_unchecked().true_count())
}
};

let first_invalid_idx = cmp::min(min_patch_offset, first_null_idx);

Self {
packed: array
.packed()
.into_primitive()
.vortex_expect("Failed to get packed bytes as PrimitiveArray"),
packed_maybe_null_slice: array.packed_slice::<T>(),
offset: array.offset(),
length: array.len(),
bit_width: array.bit_width(),
min_patch_offset: array.patches().and_then(|p| {
SparseArray::try_from(p)
.vortex_expect("Only sparse patches are supported")
.min_index()
}),
validity: array.validity(),
first_invalid_idx,
}
}
}

impl<T: BitPacking + NativePType> IndexOrd<T> for BitPackedSearch {
impl<T: BitPacking + NativePType> IndexOrd<T> for BitPackedSearch<'_, T> {
fn index_cmp(&self, idx: usize, elem: &T) -> Option<Ordering> {
if let Some(min_patch) = self.min_patch_offset {
if idx >= min_patch {
return Some(Greater);
}
}

if self.validity.is_null(idx) {
if idx >= self.first_invalid_idx {
return Some(Greater);
}

// SAFETY: Used in search_sorted_by which ensures that idx is within bounds
let val: T = unsafe {
unpack_single_primitive(
self.packed.maybe_null_slice::<T>(),
self.packed_maybe_null_slice,
self.bit_width,
idx + self.offset,
)
Expand All @@ -100,7 +180,7 @@ impl<T: BitPacking + NativePType> IndexOrd<T> for BitPackedSearch {
}
}

impl Len for BitPackedSearch {
impl<T> Len for BitPackedSearch<'_, T> {
fn len(&self) -> usize {
self.length
}
Expand All @@ -109,8 +189,12 @@ impl Len for BitPackedSearch {
#[cfg(test)]
mod test {
use vortex::array::PrimitiveArray;
use vortex::compute::{search_sorted, slice, SearchResult, SearchSortedSide};
use vortex::compute::{
search_sorted, search_sorted_many, slice, SearchResult, SearchSortedFn, SearchSortedSide,
};
use vortex::IntoArray;
use vortex_dtype::Nullability;
use vortex_scalar::Scalar;

use crate::BitPackedArray;

Expand Down Expand Up @@ -162,4 +246,60 @@ mod test {
SearchResult::Found(1)
);
}

#[test]
fn test_search_sorted_nulls() {
    // One valid value followed by trailing nulls.
    let values = PrimitiveArray::from_nullable_vec(vec![Some(1i64), None, None]);
    let bitpacked = BitPackedArray::encode(values.as_ref(), 2).unwrap();

    let needle = Scalar::primitive(1i64, Nullability::Nullable);
    let found = bitpacked
        .search_sorted(&needle, SearchSortedSide::Left)
        .unwrap();

    assert_eq!(found, SearchResult::Found(0));
}

#[test]
fn test_search_sorted_many() {
    // Exercise search_sorted_many on an array whose tail is made up of null values.
    let values = PrimitiveArray::from_nullable_vec(vec![
        Some(1i64),
        Some(2i64),
        Some(3i64),
        None,
        None,
        None,
        None,
    ]);
    let bitpacked = BitPackedArray::encode(values.as_ref(), 3).unwrap();

    let results = search_sorted_many(
        bitpacked.as_ref(),
        &[3i64, 2i64, 1i64],
        &[SearchSortedSide::Left; 3],
    )
    .unwrap();

    let expected = vec![
        SearchResult::Found(2),
        SearchResult::Found(1),
        SearchResult::Found(0),
    ];
    assert_eq!(results, expected);
}
}
13 changes: 12 additions & 1 deletion encodings/fastlanes/src/bitpacking/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use ::serde::{Deserialize, Serialize};
pub use compress::*;
use fastlanes::BitPacking;
use vortex::array::{PrimitiveArray, SparseArray};
use vortex::stats::{ArrayStatisticsCompute, StatsSet};
use vortex::validity::{ArrayValidity, LogicalValidity, Validity, ValidityMetadata};
use vortex::variants::{ArrayVariants, PrimitiveArrayTrait};
use vortex::visitor::{AcceptArrayVisitor, ArrayVisitor};
use vortex::{impl_encoding, Array, ArrayDType, ArrayDef, ArrayTrait, Canonical, IntoCanonical};
use vortex_dtype::{Nullability, PType};
use vortex_dtype::{NativePType, Nullability, PType};
use vortex_error::{
vortex_bail, vortex_err, vortex_panic, VortexError, VortexExpect as _, VortexResult,
};
Expand Down Expand Up @@ -122,6 +123,16 @@ impl BitPackedArray {
.vortex_expect("BitpackedArray is missing packed child bytes array")
}

/// Returns the packed child's values viewed as a slice of `T`.
///
/// The returned slice carries the lifetime of `&self`, even though it is derived from a
/// `PrimitiveArray` local to this function; see the SAFETY note in the body.
#[inline]
pub fn packed_slice<T: NativePType + BitPacking>(&self) -> &[T] {
let packed_primitive = self.packed().as_primitive();
let maybe_null_slice = packed_primitive.maybe_null_slice::<T>();
// SAFETY: maybe_null_slice points to buffer memory that outlives the lifetime of `self`.
// Unfortunately Rust cannot understand this, so we reconstruct the slice from raw parts
// to get it to reinterpret the lifetime.
// NOTE(review): `packed_primitive` is dropped when this function returns, so soundness
// relies on `self` keeping the underlying buffer alive independently -- confirm.
unsafe { std::slice::from_raw_parts(maybe_null_slice.as_ptr(), maybe_null_slice.len()) }
}

#[inline]
pub fn bit_width(&self) -> usize {
self.metadata().bit_width
Expand Down
21 changes: 14 additions & 7 deletions encodings/runend/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use vortex::compute::{filter, slice, take, ArrayCompute, SliceFn, TakeFn};
use vortex::validity::Validity;
use vortex::{Array, ArrayDType, IntoArray, IntoArrayVariant};
use vortex_dtype::match_each_integer_ptype;
use vortex_error::{vortex_bail, VortexExpect as _, VortexResult};
use vortex_error::{VortexExpect as _, VortexResult};
use vortex_scalar::Scalar;

use crate::RunEndArray;
Expand Down Expand Up @@ -39,19 +39,26 @@ impl ScalarAtFn for RunEndArray {
impl TakeFn for RunEndArray {
fn take(&self, indices: &Array) -> VortexResult<Array> {
let primitive_indices = indices.clone().into_primitive()?;
let physical_indices = match_each_integer_ptype!(primitive_indices.ptype(), |$P| {
let u64_indices = match_each_integer_ptype!(primitive_indices.ptype(), |$P| {
primitive_indices
.maybe_null_slice::<$P>()
.iter()
.map(|idx| *idx as usize)
.copied()
.map(|idx| {
if idx >= self.len() {
vortex_bail!(OutOfBounds: idx, 0, self.len())
let usize_idx = idx as usize;
if usize_idx >= self.len() {
vortex_error::vortex_bail!(OutOfBounds: usize_idx, 0, self.len());
}
self.find_physical_index(idx).map(|loc| loc as u64)

Ok(idx as u64)
})
.collect::<VortexResult<Vec<_>>>()?
.collect::<VortexResult<Vec<u64>>>()?
});
let physical_indices: Vec<u64> = self
.find_physical_indices(&u64_indices)?
.iter()
.map(|idx| *idx as u64)
.collect();
let physical_indices_array = PrimitiveArray::from(physical_indices).into_array();
let dense_values = take(&self.values(), &physical_indices_array)?;

Expand Down
19 changes: 18 additions & 1 deletion encodings/runend/src/runend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::fmt::Debug;
use serde::{Deserialize, Serialize};
use vortex::array::PrimitiveArray;
use vortex::compute::unary::scalar_at;
use vortex::compute::{search_sorted, SearchSortedSide};
use vortex::compute::{search_sorted, search_sorted_u64_many, SearchSortedSide};
use vortex::stats::{ArrayStatistics, ArrayStatisticsCompute, StatsSet};
use vortex::validity::{ArrayValidity, LogicalValidity, Validity, ValidityMetadata};
use vortex::variants::{ArrayVariants, PrimitiveArrayTrait};
Expand Down Expand Up @@ -84,6 +84,23 @@ impl RunEndArray {
.map(|s| s.to_ends_index(self.ends().len()))
}

/// Convert a batch of logical indices into indices for the values.
///
/// Every index is located with a right-sided search over the run ends, and each search
/// result is then mapped to an ends index. Returns an error if the bulk search fails.
///
/// See: [find_physical_index][Self::find_physical_index].
pub fn find_physical_indices(&self, indices: &[u64]) -> VortexResult<Vec<usize>> {
    // Fetch the ends array once instead of once per result inside the mapping closure.
    let ends = self.ends();
    let ends_len = ends.len();
    let sides = vec![SearchSortedSide::Right; indices.len()];

    let results = search_sorted_u64_many(&ends, indices, &sides)?;
    Ok(results
        .iter()
        .map(|result| result.to_ends_index(ends_len))
        .collect())
}

/// Run the array through run-end encoding.
pub fn encode(array: Array) -> VortexResult<Self> {
if let Ok(parray) = PrimitiveArray::try_from(array) {
Expand Down
Loading

0 comments on commit 79f816c

Please sign in to comment.