Skip to content

Commit

Permalink
Fix search_sorted for FoRArray, BitPacked and PrimitiveArray (#732)
Browse files Browse the repository at this point in the history
Co-authored-by: Will Manning <[email protected]>
  • Loading branch information
robert3005 and lwwmanning authored Sep 5, 2024
1 parent 807c4b4 commit 479419c
Show file tree
Hide file tree
Showing 9 changed files with 259 additions and 57 deletions.
52 changes: 41 additions & 11 deletions encodings/fastlanes/src/bitpacking/compute/search_sorted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,57 @@ use std::cmp::Ordering;
use std::cmp::Ordering::Greater;

use fastlanes::BitPacking;
use num_traits::AsPrimitive;
use vortex::array::{PrimitiveArray, SparseArray};
use vortex::compute::{
search_sorted, IndexOrd, Len, SearchResult, SearchSorted, SearchSortedFn, SearchSortedSide,
};
use vortex::validity::Validity;
use vortex::{ArrayDType, IntoArrayVariant};
use vortex_dtype::{match_each_unsigned_integer_ptype, NativePType};
use vortex_error::VortexResult;
use vortex_error::{VortexError, VortexResult};
use vortex_scalar::Scalar;

use crate::{unpack_single_primitive, BitPackedArray};

impl SearchSortedFn for BitPackedArray {
fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult<SearchResult> {
match_each_unsigned_integer_ptype!(self.ptype(), |$P| {
let unwrapped_value: $P = value.cast(self.dtype())?.try_into().unwrap();
if let Some(patches_array) = self.patches() {
if unwrapped_value as usize >= self.max_packed_value() {
search_sorted(&patches_array, value.clone(), side)
} else {
Ok(SearchSorted::search_sorted(&BitPackedSearch::new(self), &unwrapped_value, side))
}
} else {
Ok(SearchSorted::search_sorted(&BitPackedSearch::new(self), &unwrapped_value, side))
}
search_sorted_typed::<$P>(self, value, side)
})
}
}

/// Searches a bit-packed array for `value`, looking in the patches when the value cannot be among the packed
/// values.
///
/// `value` is first cast to the array's dtype and converted to `T`; a failure in either step is returned as an
/// error. When the array has patches and the converted value is strictly greater than
/// `array.max_packed_value()`, only the patches array is searched. In every other case the search goes through
/// `BitPackedSearch`.
fn search_sorted_typed<T>(
    array: &BitPackedArray,
    value: &Scalar,
    side: SearchSortedSide,
) -> VortexResult<SearchResult>
where
    T: NativePType + TryFrom<Scalar, Error = VortexError> + BitPacking + AsPrimitive<usize>,
{
    let needle: T = value.cast(array.dtype())?.try_into()?;
    match array.patches() {
        // Patches, when present, must be the last elements of the array, so a needle larger than the maximum
        // packed value can only be located among them.
        Some(patches) if needle.as_() > array.max_packed_value() => {
            search_sorted(&patches, value.clone(), side)
        }
        _ => Ok(SearchSorted::search_sorted(
            &BitPackedSearch::new(array),
            &needle,
            side,
        )),
    }
}

/// This wrapper exists so that you can't invoke `SearchSorted::search_sorted` directly on a `BitPackedArray`, since doing so would skip searching the patches.
#[derive(Debug)]
struct BitPackedSearch {
Expand All @@ -38,6 +61,7 @@ struct BitPackedSearch {
length: usize,
bit_width: usize,
min_patch_offset: Option<usize>,
validity: Validity,
}

impl BitPackedSearch {
Expand All @@ -52,6 +76,7 @@ impl BitPackedSearch {
.expect("Only Sparse patches are supported")
.min_index()
}),
validity: array.validity(),
}
}
}
Expand All @@ -63,6 +88,11 @@ impl<T: BitPacking + NativePType> IndexOrd<T> for BitPackedSearch {
return Some(Greater);
}
}

if self.validity.is_null(idx) {
return Some(Greater);
}

// SAFETY: Used in search_sorted_by which ensures that idx is within bounds
let val: T = unsafe {
unpack_single_primitive(
Expand Down
2 changes: 1 addition & 1 deletion encodings/fastlanes/src/bitpacking/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ impl BitPackedArray {

#[inline]
pub fn max_packed_value(&self) -> usize {
1 << self.bit_width()
(1 << self.bit_width()) - 1
}
}

Expand Down
17 changes: 8 additions & 9 deletions encodings/fastlanes/src/for/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub fn for_compress(array: &PrimitiveArray) -> VortexResult<Array> {
if shift == <$T>::PTYPE.bit_width() as u8 {
match array.validity().to_logical(array.len()) {
LogicalValidity::AllValid(l) => {
ConstantArray::new(Scalar::zero::<i32>(array.dtype().nullability()), l).into_array()
ConstantArray::new(Scalar::zero::<$T>(array.dtype().nullability()), l).into_array()
},
LogicalValidity::AllInvalid(l) => {
ConstantArray::new(Scalar::null(array.dtype().clone()), l).into_array()
Expand Down Expand Up @@ -66,12 +66,11 @@ fn compress_primitive<T: NativePType + WrappingSub + PrimInt>(
) -> PrimitiveArray {
assert!(shift < T::PTYPE.bit_width() as u8);
let values = if shift > 0 {
let shifted_min = min >> shift as usize;
parray
.maybe_null_slice::<T>()
.iter()
.map(|&v| v >> shift as usize)
.map(|v| v.wrapping_sub(&shifted_min))
.map(|&v| v.wrapping_sub(&min))
.map(|v| v >> shift as usize)
.collect_vec()
} else {
parray
Expand All @@ -90,29 +89,29 @@ pub fn decompress(array: FoRArray) -> VortexResult<PrimitiveArray> {
let encoded = array.encoded().into_primitive()?.reinterpret_cast(ptype);
let validity = encoded.validity();
Ok(match_each_integer_ptype!(ptype, |$T| {
let reference: $T = array.reference().try_into()?;
let min: $T = array.reference().try_into()?;
PrimitiveArray::from_vec(
decompress_primitive(encoded.into_maybe_null_slice::<$T>(), reference, shift),
decompress_primitive(encoded.into_maybe_null_slice::<$T>(), min, shift),
validity,
)
}))
}

fn decompress_primitive<T: NativePType + WrappingAdd + PrimInt>(
values: Vec<T>,
reference: T,
min: T,
shift: usize,
) -> Vec<T> {
if shift > 0 {
values
.into_iter()
.map(|v| v << shift)
.map(|v| v.wrapping_add(&reference))
.map(|v| v.wrapping_add(&min))
.collect_vec()
} else {
values
.into_iter()
.map(|v| v.wrapping_add(&reference))
.map(|v| v.wrapping_add(&min))
.collect_vec()
}
}
Expand Down
164 changes: 144 additions & 20 deletions encodings/fastlanes/src/for/compute.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
use std::ops::{AddAssign, Shl, Shr};

use num_traits::{WrappingAdd, WrappingSub};
use vortex::compute::unary::{scalar_at_unchecked, ScalarAtFn};
use vortex::compute::{
search_sorted, slice, take, ArrayCompute, SearchResult, SearchSortedFn, SearchSortedSide,
SliceFn, TakeFn,
};
use vortex::{Array, ArrayDType, IntoArray};
use vortex_dtype::match_each_integer_ptype;
use vortex_error::VortexResult;
use vortex_scalar::{PrimitiveScalar, Scalar};
use vortex_dtype::{match_each_integer_ptype, NativePType};
use vortex_error::{VortexError, VortexResult};
use vortex_scalar::{PValue, PrimitiveScalar, Scalar};

use crate::FoRArray;

Expand Down Expand Up @@ -51,7 +54,6 @@ impl ScalarAtFn for FoRArray {
let reference = PrimitiveScalar::try_from(self.reference()).unwrap();

match_each_integer_ptype!(encoded.ptype(), |$P| {
use num_traits::WrappingAdd;
encoded.typed_value::<$P>().map(|v| (v << self.shift()).wrapping_add(reference.typed_value::<$P>().unwrap()))
.map(|v| Scalar::primitive::<$P>(v, encoded.dtype().nullability()))
.unwrap_or_else(|| Scalar::null(encoded.dtype().clone()))
Expand All @@ -73,32 +75,72 @@ impl SliceFn for FoRArray {
impl SearchSortedFn for FoRArray {
fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult<SearchResult> {
match_each_integer_ptype!(self.ptype(), |$P| {
let min: $P = self.reference().try_into().unwrap();
let shifted_min = min >> self.shift();
let unwrapped_value: $P = value.cast(self.dtype())?.try_into().unwrap();
let shifted_value: $P = unwrapped_value >> self.shift();
// Make sure that smaller values are still smaller and not larger than (which they would be after wrapping_sub)
if shifted_value < shifted_min {
return Ok(SearchResult::NotFound(0));
}

let translated_scalar = Scalar::primitive(
shifted_value.wrapping_sub(shifted_min),
value.dtype().nullability(),
)
.reinterpret_cast(self.ptype().to_unsigned());
search_sorted(&self.encoded(), translated_scalar, side)
search_sorted_typed::<$P>(self, value, side)
})
}
}

/// Searches a frame-of-reference encoded array for `value` by translating it into the encoded domain.
///
/// `value` is cast to the array's dtype and converted to `T`; conversion failures are returned as errors.
/// A value smaller than the array's reference is reported as `NotFound(0)` without touching the encoded child.
/// Otherwise the value is translated (`(value - reference) >> shift`) and looked up in `array.encoded()`.
/// `Found` is only returned when the value survives an encode/decode round trip and the encoded search itself
/// reports `Found`; every other outcome is `NotFound` at the index the encoded search produced.
fn search_sorted_typed<T>(
    array: &FoRArray,
    value: &Scalar,
    side: SearchSortedSide,
) -> VortexResult<SearchResult>
where
    T: NativePType
        + for<'a> TryFrom<&'a Scalar, Error = VortexError>
        + Shr<u8, Output = T>
        + Shl<u8, Output = T>
        + WrappingSub
        + WrappingAdd
        + AddAssign
        + Into<PValue>,
{
    let reference: T = array.reference().try_into()?;
    let needle: T = value.cast(array.dtype())?.as_ref().try_into()?;

    // Guard against needles below the reference: after `wrapping_sub` they would wrap around and look larger
    // than every encoded value instead of smaller.
    if needle < reference {
        return Ok(SearchResult::NotFound(0));
    }

    // With a non-zero shift, not every value of the domain has an encoded counterpart; several distinct needles
    // can collapse onto the same encoded value. Encode, then decode again to find out whether this one survives.
    let shift = array.shift();
    let encoded = needle.wrapping_sub(&reference) >> shift;
    let round_tripped = (encoded << shift).wrapping_add(&reference);
    let representable = round_tripped == needle;

    // A needle that is not representable is NotFound by definition, and its insertion index does not depend on
    // the requested side. To land after any run of repeated values, search from the left for the next
    // representable value, i.e. the encoded value incremented by one.
    let (effective_side, target) = if representable {
        (side, encoded)
    } else {
        (SearchSortedSide::Left, encoded.wrapping_add(&T::one()))
    };

    let target_scalar = Scalar::primitive(target, value.dtype().nullability())
        .reinterpret_cast(array.ptype().to_unsigned());
    let search_result = search_sorted(&array.encoded(), target_scalar, effective_side)?;

    let exact_hit = representable && matches!(search_result, SearchResult::Found(_));
    if exact_hit {
        Ok(search_result)
    } else {
        Ok(SearchResult::NotFound(search_result.to_index()))
    }
}

#[cfg(test)]
mod test {
use vortex::array::PrimitiveArray;
use vortex::compute::unary::scalar_at;
use vortex::compute::{search_sorted, SearchResult, SearchSortedSide};

use crate::for_compress;
use crate::{for_compress, FoRArray};

#[test]
fn for_scalar_at() {
Expand All @@ -124,4 +166,86 @@ mod test {
SearchResult::NotFound(0)
);
}

#[test]
fn search_with_shift_notfound() {
    // None of the probes occurs in the array, so every search must report only an insertion index.
    let for_arr = for_compress(&PrimitiveArray::from(vec![62, 114])).unwrap();

    // (probe, expected insertion index) for a left-sided search.
    let cases = [(63_i32, 1), (61_i32, 0), (113_i32, 1), (115_i32, 2)];
    for (probe, insert_at) in cases {
        assert_eq!(
            search_sorted(&for_arr, probe, SearchSortedSide::Left).unwrap(),
            SearchResult::NotFound(insert_at)
        );
    }
}

#[test]
fn search_with_shift_repeated() {
    let arr = for_compress(&PrimitiveArray::from(vec![62, 62, 114, 114])).unwrap();
    let for_array = FoRArray::try_from(arr.clone()).unwrap();

    // Pin the encoding parameters the expectations below rely on.
    let min: i32 = for_array.reference().try_into().unwrap();
    assert_eq!(min, 62);
    assert_eq!(for_array.shift(), 1);

    let search = |probe: i32, side: SearchSortedSide| search_sorted(&arr, probe, side).unwrap();

    // Probes that are absent from the array: both sides must agree on the insertion index.
    for (probe, insert_at) in [(61_i32, 0), (63_i32, 2), (113_i32, 2), (115_i32, 4)] {
        assert_eq!(
            search(probe, SearchSortedSide::Left),
            SearchResult::NotFound(insert_at)
        );
        assert_eq!(
            search(probe, SearchSortedSide::Right),
            SearchResult::NotFound(insert_at)
        );
    }

    // Probes that are present (each one repeated twice): left yields the start of the run, right yields the
    // index just past its end.
    for (probe, run_start, run_end) in [(62_i32, 0, 2), (114_i32, 2, 4)] {
        assert_eq!(
            search(probe, SearchSortedSide::Left),
            SearchResult::Found(run_start)
        );
        assert_eq!(
            search(probe, SearchSortedSide::Right),
            SearchResult::Found(run_end)
        );
    }
}
}
7 changes: 3 additions & 4 deletions fuzz/fuzz_targets/fuzz_target_1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ fuzz_target!(|fuzz_action: FuzzArrayAction| -> Corpus {
Corpus::Keep
}
Action::SearchSorted(s, side) => {
if !array_is_sorted(&array).unwrap() {
if !array_is_sorted(&array).unwrap() || s.is_null() {
return Corpus::Reject;
}

Expand Down Expand Up @@ -76,7 +76,7 @@ fn assert_take(original: &Array, taken: &Array, indices: &Array) {
let o = scalar_at(original, to_take).unwrap();
let s = scalar_at(taken, idx).unwrap();

fuzzing_scalar_cmp(o, s, original.encoding().id(), taken.encoding().id(), idx);
fuzzing_scalar_cmp(o, s, original.encoding().id(), indices.encoding().id(), idx);
}
}

Expand Down Expand Up @@ -122,8 +122,7 @@ fn fuzzing_scalar_cmp(

assert!(
equal_values,
"{l} != {r} at index {idx}, lhs is {} rhs is {}",
lhs_encoding, rhs_encoding
"{l} != {r} at index {idx}, lhs is {lhs_encoding} rhs is {rhs_encoding}",
);
assert_eq!(l.is_valid(), r.is_valid());
}
Expand Down
Loading

0 comments on commit 479419c

Please sign in to comment.