Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FoR array holds encoded values as unsigned #401

Merged
merged 10 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions encodings/fastlanes/src/for/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use vortex::compress::{CompressConfig, Compressor, EncodingCompression};
use vortex::stats::{ArrayStatistics, Stat};
use vortex::validity::ArrayValidity;
use vortex::{Array, ArrayDType, ArrayTrait, IntoArray, IntoArrayVariant};
use vortex_dtype::{match_each_integer_ptype, NativePType, PType};
use vortex_dtype::{match_each_integer_ptype, NativePType};
use vortex_error::{vortex_err, VortexResult};
use vortex_scalar::Scalar;

Expand Down Expand Up @@ -60,9 +60,16 @@ impl EncodingCompression for FoREncoding {

let child = match_each_integer_ptype!(parray.ptype(), |$T| {
if shift == <$T>::PTYPE.bit_width() as u8 {
ConstantArray::new(Scalar::zero::<$T>(parray.dtype().nullability()), parray.len()).into_array()
ConstantArray::new(
Scalar::zero::<$T>(parray.dtype().nullability())
.reinterpret_cast(parray.ptype().to_unsigned()),
parray.len(),
)
.into_array()
} else {
compress_primitive::<$T>(parray, shift, $T::try_from(&min)?).into_array()
compress_primitive::<$T>(&parray, shift, $T::try_from(&min)?)
.reinterpret_cast(parray.ptype().to_unsigned())
.into_array()
}
});
let for_like = like.map(|like_arr| FoRArray::try_from(like_arr).unwrap());
Expand All @@ -76,7 +83,7 @@ impl EncodingCompression for FoREncoding {
}

fn compress_primitive<T: NativePType + WrappingSub + PrimInt>(
parray: PrimitiveArray,
parray: &PrimitiveArray,
shift: u8,
min: T,
) -> PrimitiveArray {
Expand All @@ -102,8 +109,8 @@ fn compress_primitive<T: NativePType + WrappingSub + PrimInt>(

pub fn decompress(array: FoRArray) -> VortexResult<PrimitiveArray> {
let shift = array.shift();
let ptype: PType = array.dtype().try_into()?;
let encoded = array.encoded().into_primitive()?;
let ptype = array.ptype();
let encoded = array.encoded().into_primitive()?.reinterpret_cast(ptype);
Ok(match_each_integer_ptype!(ptype, |$T| {
let reference: $T = array.reference().try_into()?;
PrimitiveArray::from_vec(
Expand Down Expand Up @@ -202,9 +209,9 @@ mod test {
assert_eq!(i8::MIN, i8::try_from(compressed.reference()).unwrap());

let encoded = compressed.encoded().into_primitive().unwrap();
let bitcast: &[u8] = unsafe { std::mem::transmute(encoded.maybe_null_slice::<i8>()) };
let encoded_bytes: &[u8] = encoded.maybe_null_slice::<u8>();
let unsigned: Vec<u8> = (0..u8::MAX).collect_vec();
assert_eq!(bitcast, unsigned.as_slice());
assert_eq!(encoded_bytes, unsigned.as_slice());

let decompressed = compressed.array().clone().into_primitive().unwrap();
assert_eq!(
Expand Down
2 changes: 1 addition & 1 deletion encodings/fastlanes/src/for/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl TakeFn for FoRArray {

impl ScalarAtFn for FoRArray {
fn scalar_at(&self, index: usize) -> VortexResult<Scalar> {
let encoded_scalar = scalar_at(&self.encoded(), index)?;
let encoded_scalar = scalar_at(&self.encoded(), index)?.reinterpret_cast(self.ptype());
let encoded = PrimitiveScalar::try_from(&encoded_scalar)?;
let reference = PrimitiveScalar::try_from(self.reference())?;

Expand Down
23 changes: 18 additions & 5 deletions encodings/fastlanes/src/for/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use vortex::stats::ArrayStatisticsCompute;
use vortex::validity::{ArrayValidity, LogicalValidity};
use vortex::visitor::{AcceptArrayVisitor, ArrayVisitor};
use vortex::{impl_encoding, ArrayDType, Canonical, IntoCanonical};
use vortex_dtype::PType;
use vortex_error::vortex_bail;
use vortex_scalar::Scalar;

Expand All @@ -24,9 +25,13 @@ impl FoRArray {
if reference.is_null() {
vortex_bail!("Reference value cannot be null",);
}
let reference = reference.cast(child.dtype())?;
let reference = reference.cast(
&reference
.dtype()
.with_nullability(child.dtype().nullability()),
)?;
Self::try_from_parts(
child.dtype().clone(),
reference.dtype().clone(),
FoRMetadata { reference, shift },
[child].into(),
StatsSet::new(),
Expand All @@ -35,9 +40,12 @@ impl FoRArray {

/// Returns child 0 of the underlying array, panicking with "Missing FoR child" if
/// the lookup yields nothing.
///
/// NOTE(review): this span is a rendered diff hunk, so the pre-change body
/// (`child(0, self.dtype())`) and the post-change body appear back to back. The
/// post-change lines look the child up with the unsigned counterpart of
/// `self.ptype()` when that ptype is a signed integer, and with `self.dtype()`
/// otherwise.
#[inline]
pub fn encoded(&self) -> Array {
self.array()
.child(0, self.dtype())
.expect("Missing FoR child")
let dtype = if self.ptype().is_signed_int() {
&DType::Primitive(self.ptype().to_unsigned(), self.dtype().nullability())
} else {
self.dtype()
};
self.array().child(0, dtype).expect("Missing FoR child")
}

#[inline]
Expand All @@ -49,6 +57,11 @@ impl FoRArray {
/// Returns the `shift` value stored in this array's metadata.
pub fn shift(&self) -> u8 {
self.metadata().shift
}

/// Returns the primitive type obtained by converting this array's dtype.
///
/// # Panics
///
/// Panics if the array's dtype cannot be converted into a `PType`.
#[inline]
pub fn ptype(&self) -> PType {
    self.dtype()
        .try_into()
        .expect("FoRArray dtype must be convertible to a PType")
}
}

impl ArrayValidity for FoRArray {
Expand Down
22 changes: 22 additions & 0 deletions vortex-scalar/src/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,28 @@ impl Scalar {
}
}

/// Reinterprets this primitive scalar as `ptype`, keeping the scalar's nullability.
///
/// Returns a clone of `self` when the scalar already has the requested `ptype`.
/// A scalar without a primitive value yields `ScalarValue::Null` under the new dtype.
///
/// # Panics
///
/// Panics if `self` cannot be viewed as a `PrimitiveScalar`, or if `ptype` has a
/// different byte width than the scalar's current primitive type.
pub fn reinterpret_cast(&self, ptype: PType) -> Self {
    let primitive = PrimitiveScalar::try_from(self)
        .expect("reinterpret_cast can only be applied to a primitive scalar");
    if primitive.ptype() == ptype {
        return self.clone();
    }

    // Checked here as well as per value, so that a scalar with no value is
    // still rejected when the widths differ.
    assert_eq!(
        primitive.ptype().byte_width(),
        ptype.byte_width(),
        "can't reinterpret cast between integers of two different widths"
    );

    Self {
        dtype: DType::Primitive(ptype, self.dtype.nullability()),
        value: primitive.pvalue.map_or(ScalarValue::Null, |p| {
            ScalarValue::Primitive(p.reinterpret_cast(ptype))
        }),
    }
}

pub fn zero<T: NativePType + Into<PValue>>(nullability: Nullability) -> Self {
Self {
dtype: DType::Primitive(T::PTYPE, nullability),
Expand Down
64 changes: 64 additions & 0 deletions vortex-scalar/src/pvalue.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::mem;

use num_traits::NumCast;
use vortex_dtype::half::f16;
use vortex_dtype::PType;
Expand Down Expand Up @@ -35,6 +37,68 @@ impl PValue {
Self::F64(_) => PType::F64,
}
}

pub fn reinterpret_cast(&self, ptype: PType) -> Self {
if ptype == self.ptype() {
return *self;
}

assert_eq!(
ptype.byte_width(),
self.ptype().byte_width(),
"Cannot reinterpret cast between types of different widths"
);

match self {
robert3005 marked this conversation as resolved.
Show resolved Hide resolved
PValue::U8(v) => unsafe { mem::transmute::<u8, i8>(*v) }.into(),
PValue::U16(v) => match ptype {
PType::I16 => unsafe { mem::transmute::<u16, i16>(*v) }.into(),
PType::F16 => f16::from_bits(*v).into(),
_ => unreachable!("Only same width type are allowed to be reinterpreted"),
},
PValue::U32(v) => match ptype {
PType::I32 => unsafe { mem::transmute::<u32, i32>(*v) }.into(),
PType::F32 => f32::from_bits(*v).into(),
_ => unreachable!("Only same width type are allowed to be reinterpreted"),
},
PValue::U64(v) => match ptype {
PType::I64 => unsafe { mem::transmute::<u64, i64>(*v) }.into(),
PType::F64 => f64::from_bits(*v).into(),
_ => unreachable!("Only same width type are allowed to be reinterpreted"),
},
PValue::I8(v) => unsafe { mem::transmute::<i8, u8>(*v) }.into(),
PValue::I16(v) => match ptype {
PType::U16 => unsafe { mem::transmute::<i16, u16>(*v) }.into(),
PType::F16 => f16::from_bits(unsafe { mem::transmute::<i16, u16>(*v) }).into(),
_ => unreachable!("Only same width type are allowed to be reinterpreted"),
},
PValue::I32(v) => match ptype {
PType::U32 => unsafe { mem::transmute::<i32, u32>(*v) }.into(),
PType::F32 => f32::from_bits(unsafe { mem::transmute::<i32, u32>(*v) }).into(),
_ => unreachable!("Only same width type are allowed to be reinterpreted"),
},
PValue::I64(v) => match ptype {
PType::U64 => unsafe { mem::transmute::<i64, u64>(*v) }.into(),
PType::F64 => f64::from_bits(unsafe { mem::transmute::<i64, u64>(*v) }).into(),
_ => unreachable!("Only same width type are allowed to be reinterpreted"),
},
PValue::F16(v) => match ptype {
PType::U16 => v.to_bits().into(),
PType::I16 => unsafe { mem::transmute::<u16, i16>(v.to_bits()) }.into(),
_ => unreachable!("Only same width type are allowed to be reinterpreted"),
},
PValue::F32(v) => match ptype {
PType::U32 => v.to_bits().into(),
PType::I32 => unsafe { mem::transmute::<u32, i32>(v.to_bits()) }.into(),
_ => unreachable!("Only same width type are allowed to be reinterpreted"),
},
PValue::F64(v) => match ptype {
PType::U64 => v.to_bits().into(),
PType::I64 => unsafe { mem::transmute::<u64, i64>(v.to_bits()) }.into(),
_ => unreachable!("Only same width type are allowed to be reinterpreted"),
},
}
}
}

macro_rules! int_pvalue {
Expand Down
Loading