Skip to content

Commit

Permalink
feat: SparseArray uses ScalarValue instead of Scalar (#955)
Browse files Browse the repository at this point in the history
  • Loading branch information
danking authored Oct 1, 2024
1 parent 70d75b4 commit df77be3
Show file tree
Hide file tree
Showing 14 changed files with 108 additions and 98 deletions.
4 changes: 2 additions & 2 deletions encodings/alp/src/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use vortex::validity::Validity;
use vortex::{Array, ArrayDType, ArrayDef, IntoArray, IntoArrayVariant};
use vortex_dtype::{NativePType, PType};
use vortex_error::{vortex_bail, VortexExpect as _, VortexResult};
use vortex_scalar::Scalar;
use vortex_scalar::ScalarValue;

use crate::alp::ALPFloat;
use crate::array::ALPArray;
Expand Down Expand Up @@ -42,7 +42,7 @@ where
PrimitiveArray::from(exc_pos).into_array(),
PrimitiveArray::from_vec(exc, Validity::AllValid).into_array(),
len,
Scalar::null(values.dtype().as_nullable()),
ScalarValue::Null,
)
.vortex_expect("Failed to create SparseArray for ALP patches")
.into_array()
Expand Down
4 changes: 2 additions & 2 deletions encodings/fastlanes/src/bitpacking/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use vortex_dtype::{
match_each_integer_ptype, match_each_unsigned_integer_ptype, NativePType, PType,
};
use vortex_error::{vortex_bail, vortex_err, VortexResult, VortexUnwrap};
use vortex_scalar::Scalar;
use vortex_scalar::{Scalar, ScalarValue};

use crate::BitPackedArray;

Expand Down Expand Up @@ -131,7 +131,7 @@ pub fn bitpack_patches(
indices.into_array(),
PrimitiveArray::from_vec(values, Validity::AllValid).into_array(),
parray.len(),
Scalar::null(parray.dtype().as_nullable()),
ScalarValue::Null,
)
.vortex_unwrap()
.into_array()
Expand Down
4 changes: 2 additions & 2 deletions encodings/fastlanes/src/bitpacking/compute/scalar_at.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ mod test {
use vortex::IntoArray;
use vortex_buffer::Buffer;
use vortex_dtype::{DType, Nullability, PType};
use vortex_scalar::Scalar;
use vortex_scalar::{Scalar, ScalarValue};

use crate::BitPackedArray;

Expand All @@ -45,7 +45,7 @@ mod test {
PrimitiveArray::from(vec![1u64]).into_array(),
PrimitiveArray::from_vec(vec![999u32], Validity::AllValid).into_array(),
8,
Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable)),
ScalarValue::Null,
)
.unwrap()
.into_array(),
Expand Down
4 changes: 2 additions & 2 deletions encodings/fastlanes/src/for/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use vortex::validity::LogicalValidity;
use vortex::{Array, ArrayDType, IntoArray, IntoArrayVariant};
use vortex_dtype::{match_each_integer_ptype, NativePType};
use vortex_error::{vortex_err, VortexResult};
use vortex_scalar::Scalar;
use vortex_scalar::{Scalar, ScalarValue};

use crate::FoRArray;

Expand Down Expand Up @@ -41,7 +41,7 @@ pub fn for_compress(array: &PrimitiveArray) -> VortexResult<Array> {
ConstantArray::new(Scalar::zero::<$T>(array.dtype().nullability()), valid_len)
.into_array(),
array.len(),
Scalar::null(array.dtype().clone()),
ScalarValue::Null,
)?
.into_array()
}
Expand Down
4 changes: 2 additions & 2 deletions encodings/runend/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use vortex::validity::Validity;
use vortex::{Array, ArrayDType, IntoArray, IntoArrayVariant};
use vortex_dtype::match_each_integer_ptype;
use vortex_error::{VortexExpect as _, VortexResult};
use vortex_scalar::Scalar;
use vortex_scalar::{Scalar, ScalarValue};

use crate::RunEndArray;

Expand Down Expand Up @@ -86,7 +86,7 @@ impl TakeFn for RunEndArray {
dense_nonnull_indices,
filtered_values,
length,
Scalar::null(self.dtype().clone()),
ScalarValue::Null,
)?
.into_array()
}
Expand Down
25 changes: 11 additions & 14 deletions vortex-array/src/array/sparse/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::compute::{
search_sorted, take, ArrayCompute, FilterFn, SearchResult, SearchSortedFn, SearchSortedSide,
SliceFn, TakeFn,
};
use crate::{Array, ArrayDType, IntoArray, IntoArrayVariant};
use crate::{Array, IntoArray, IntoArrayVariant};

mod slice;
mod take;
Expand Down Expand Up @@ -38,18 +38,16 @@ impl ArrayCompute for SparseArray {

impl ScalarAtFn for SparseArray {
fn scalar_at(&self, index: usize) -> VortexResult<Scalar> {
match self.search_index(index)?.to_found() {
None => self.fill_value().clone().cast(self.dtype()),
Some(idx) => scalar_at_unchecked(&self.values(), idx).cast(self.dtype()),
}
Ok(match self.search_index(index)?.to_found() {
None => self.fill_scalar(),
Some(idx) => scalar_at_unchecked(&self.values(), idx),
})
}

fn scalar_at_unchecked(&self, index: usize) -> Scalar {
match self.search_index(index).vortex_unwrap().to_found() {
None => self.fill_value().clone().cast(self.dtype()).vortex_unwrap(),
Some(idx) => scalar_at_unchecked(&self.values(), idx)
.cast(self.dtype())
.vortex_unwrap(),
None => self.fill_scalar(),
Some(idx) => scalar_at_unchecked(&self.values(), idx),
}
}
}
Expand Down Expand Up @@ -115,8 +113,7 @@ impl FilterFn for SparseArray {
#[cfg(test)]
mod test {
use rstest::{fixture, rstest};
use vortex_dtype::{DType, Nullability, PType};
use vortex_scalar::Scalar;
use vortex_scalar::ScalarValue;

use crate::array::primitive::PrimitiveArray;
use crate::array::sparse::SparseArray;
Expand All @@ -131,7 +128,7 @@ mod test {
PrimitiveArray::from(vec![2u64, 9, 15]).into_array(),
PrimitiveArray::from_vec(vec![33_i32, 44, 55], Validity::AllValid).into_array(),
20,
Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
ScalarValue::Null,
)
.unwrap()
.into_array()
Expand Down Expand Up @@ -176,7 +173,7 @@ mod test {
PrimitiveArray::from(vec![0u64]).into_array(),
PrimitiveArray::from_vec(vec![0u8], Validity::AllValid).into_array(),
2,
Scalar::null(DType::Primitive(PType::U8, Nullability::Nullable)),
ScalarValue::Null,
)
.unwrap()
.into_array();
Expand Down Expand Up @@ -216,7 +213,7 @@ mod test {
PrimitiveArray::from(vec![0_u64, 3, 6]).into_array(),
PrimitiveArray::from_vec(vec![33_i32, 44, 55], Validity::AllValid).into_array(),
7,
Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
ScalarValue::Null,
)
.unwrap()
.into_array();
Expand Down
25 changes: 6 additions & 19 deletions vortex-array/src/array/sparse/compute/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ impl SliceFn for SparseArray {

#[cfg(test)]
mod tests {
use vortex_dtype::Nullability;
use vortex_scalar::Scalar;

use super::*;
use crate::IntoArrayVariant;

Expand All @@ -34,14 +31,9 @@ mod tests {
let values = vec![15_u32, 135, 13531, 42].into_array();
let indices = vec![10_u64, 11, 50, 100].into_array();

let sparse = SparseArray::try_new(
indices.clone(),
values,
101,
Scalar::primitive(0_u32, Nullability::NonNullable),
)
.unwrap()
.into_array();
let sparse = SparseArray::try_new(indices.clone(), values, 101, 0_u32.into())
.unwrap()
.into_array();

let sliced = slice(&sparse, 15, 100).unwrap();
assert_eq!(sliced.len(), 100 - 15);
Expand All @@ -59,14 +51,9 @@ mod tests {
let values = vec![15_u32, 135, 13531, 42].into_array();
let indices = vec![10_u64, 11, 50, 100].into_array();

let sparse = SparseArray::try_new(
indices.clone(),
values,
101,
Scalar::primitive(0_u32, Nullability::NonNullable),
)
.unwrap()
.into_array();
let sparse = SparseArray::try_new(indices.clone(), values, 101, 0_u32.into())
.unwrap()
.into_array();

let sliced = slice(&sparse, 15, 100).unwrap();
assert_eq!(sliced.len(), 100 - 15);
Expand Down
5 changes: 2 additions & 3 deletions vortex-array/src/array/sparse/compute/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ fn take_search_sorted(
#[cfg(test)]
mod test {
use itertools::Itertools;
use vortex_dtype::{DType, Nullability, PType};
use vortex_scalar::Scalar;
use vortex_scalar::ScalarValue;

use crate::array::primitive::PrimitiveArray;
use crate::array::sparse::compute::take::take_map;
Expand All @@ -100,7 +99,7 @@ mod test {
PrimitiveArray::from_vec(vec![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid)
.into_array(),
100,
Scalar::null(DType::Primitive(PType::F64, Nullability::Nullable)),
ScalarValue::Null,
)
.unwrap()
.into_array()
Expand Down
8 changes: 4 additions & 4 deletions vortex-array/src/array/sparse/flatten.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, MutableBuffer};
use itertools::Itertools;
use vortex_dtype::{match_each_native_ptype, DType, NativePType};
use vortex_error::{VortexError, VortexResult};
use vortex_scalar::Scalar;
use vortex_scalar::ScalarValue;

use crate::array::primitive::PrimitiveArray;
use crate::array::sparse::SparseArray;
Expand Down Expand Up @@ -46,7 +46,7 @@ fn canonicalize_sparse_bools(
values: BooleanBuffer,
indices: &[usize],
len: usize,
fill_value: &Scalar,
fill_value: &ScalarValue,
mut validity_buffer: BooleanBufferBuilder,
) -> VortexResult<Canonical> {
let fill_bool: bool = if fill_value.is_null() {
Expand All @@ -67,12 +67,12 @@ fn canonicalize_sparse_bools(
}

fn canonicalize_sparse_primitives<
T: NativePType + for<'a> TryFrom<&'a Scalar, Error = VortexError>,
T: NativePType + for<'a> TryFrom<&'a ScalarValue, Error = VortexError>,
>(
values: &[T],
indices: &[usize],
len: usize,
fill_value: &Scalar,
fill_value: &ScalarValue,
mut validity: BooleanBufferBuilder,
) -> VortexResult<Canonical> {
let primitive_fill = if fill_value.is_null() {
Expand Down
51 changes: 21 additions & 30 deletions vortex-array/src/array/sparse/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use ::serde::{Deserialize, Serialize};
use vortex_dtype::{match_each_integer_ptype, DType};
use vortex_error::{vortex_bail, vortex_panic, VortexExpect as _, VortexResult};
use vortex_scalar::Scalar;
use vortex_scalar::{Scalar, ScalarValue};

use crate::array::constant::ConstantArray;
use crate::compute::unary::scalar_at;
Expand All @@ -23,15 +23,15 @@ pub struct SparseMetadata {
// Offset value for patch indices as a result of slicing
indices_offset: usize,
indices_len: usize,
fill_value: Scalar,
fill_value: ScalarValue,
}

impl SparseArray {
pub fn try_new(
indices: Array,
values: Array,
len: usize,
fill_value: Scalar,
fill_value: ScalarValue,
) -> VortexResult<Self> {
Self::try_new_with_offset(indices, values, len, 0, fill_value)
}
Expand All @@ -41,18 +41,11 @@ impl SparseArray {
values: Array,
len: usize,
indices_offset: usize,
fill_value: Scalar,
fill_value: ScalarValue,
) -> VortexResult<Self> {
if !matches!(indices.dtype(), &DType::IDX) {
vortex_bail!("Cannot use {} as indices", indices.dtype());
}
if values.dtype() != fill_value.dtype() {
vortex_bail!(
"Mismatched fill value dtype {} and values dtype {}",
fill_value.dtype(),
values.dtype(),
);
}
if indices.len() != values.len() {
vortex_bail!(
"Mismatched indices {} and values {} length",
Expand Down Expand Up @@ -102,10 +95,15 @@ impl SparseArray {
}

#[inline]
pub fn fill_value(&self) -> &Scalar {
pub fn fill_value(&self) -> &ScalarValue {
&self.metadata().fill_value
}

#[inline]
pub fn fill_scalar(&self) -> Scalar {
Scalar::new(self.dtype().clone(), self.fill_value().clone())
}

/// Returns the position or the insertion point of a given index in the indices array.
fn search_index(&self, index: usize) -> VortexResult<SearchResult> {
search_sorted(
Expand Down Expand Up @@ -196,7 +194,7 @@ impl ArrayValidity for SparseArray {
#[cfg(test)]
mod test {
use itertools::Itertools;
use vortex_dtype::Nullability::{self, Nullable};
use vortex_dtype::Nullability::Nullable;
use vortex_dtype::{DType, PType};
use vortex_error::VortexError;
use vortex_scalar::Scalar;
Expand All @@ -221,9 +219,14 @@ mod test {
let mut values = vec![100i32, 200, 300].into_array();
values = try_cast(&values, fill_value.dtype()).unwrap();

SparseArray::try_new(vec![2u64, 5, 8].into_array(), values, 10, fill_value)
.unwrap()
.into_array()
SparseArray::try_new(
vec![2u64, 5, 8].into_array(),
values,
10,
fill_value.value().clone(),
)
.unwrap()
.into_array()
}

fn assert_sparse_array(sparse: &Array, values: &[Option<i32>]) {
Expand Down Expand Up @@ -372,26 +375,14 @@ mod test {
let values = vec![15_u32, 135, 13531, 42].into_array();
let indices = vec![10_u64, 11, 50, 100].into_array();

SparseArray::try_new(
indices.clone(),
values,
100,
Scalar::primitive(0_u32, Nullability::NonNullable),
)
.unwrap();
SparseArray::try_new(indices.clone(), values, 100, 0_u32.into()).unwrap();
}

#[test]
fn test_valid_length() {
let values = vec![15_u32, 135, 13531, 42].into_array();
let indices = vec![10_u64, 11, 50, 100].into_array();

SparseArray::try_new(
indices.clone(),
values,
101,
Scalar::primitive(0_u32, Nullability::NonNullable),
)
.unwrap();
SparseArray::try_new(indices.clone(), values, 101, 0_u32.into()).unwrap();
}
}
Loading

0 comments on commit df77be3

Please sign in to comment.