Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: teach ScalarValue and PValue is_instance_of #958

Merged
merged 2 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions vortex-array/src/array/sparse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ impl SparseArray {
if !matches!(indices.dtype(), &DType::IDX) {
vortex_bail!("Cannot use {} as indices", indices.dtype());
}
if !fill_value.is_instance_of(values.dtype()) {
vortex_bail!(
"fill value, {:?}, should be instance of values dtype, {}",
fill_value,
values.dtype(),
);
}
if indices.len() != values.len() {
vortex_bail!(
"Mismatched indices {} and values {} length",
Expand Down
30 changes: 30 additions & 0 deletions vortex-scalar/src/pvalue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ impl PValue {
}
}

/// Returns `true` when this value's primitive type, as reported by
/// [`Self::ptype`], is exactly `ptype`.
///
/// The check is plain equality on the primitive type; no widening or
/// signed/unsigned/float reinterpretation is considered a match.
pub fn is_instance_of(&self, ptype: &PType) -> bool {
    let own_ptype = self.ptype();
    own_ptype == *ptype
}

#[allow(clippy::transmute_int_to_float, clippy::transmute_float_to_int)]
pub fn reinterpret_cast(&self, ptype: PType) -> Self {
if ptype == self.ptype() {
Expand Down Expand Up @@ -262,3 +266,29 @@ impl_pvalue!(i64, I64);
impl_pvalue!(f16, F16);
impl_pvalue!(f32, F32);
impl_pvalue!(f64, F64);

#[cfg(test)]
mod test {
    use vortex_dtype::half::f16;
    use vortex_dtype::PType;

    use crate::PValue;

    /// Asserts that `value` is an instance of `expected` and of none of `others`.
    fn assert_only_instance_of(value: PValue, expected: PType, others: [PType; 3]) {
        assert!(value.is_instance_of(&expected));
        for other in others {
            assert!(!value.is_instance_of(&other));
        }
    }

    /// Each value matches only its own primitive type: not a wider type of the
    /// same kind, not the same width with different signedness, and not a
    /// type of a different kind.
    #[test]
    pub fn test_is_instance_of() {
        assert_only_instance_of(
            PValue::U8(10),
            PType::U8,
            [PType::U16, PType::I8, PType::F16],
        );

        assert_only_instance_of(
            PValue::I8(10),
            PType::I8,
            [PType::I16, PType::U8, PType::F16],
        );

        assert_only_instance_of(
            PValue::F16(f16::from_f32(10.0)),
            PType::F16,
            [PType::F32, PType::U16, PType::I16],
        );
    }
}
109 changes: 109 additions & 0 deletions vortex-scalar/src/value.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::sync::Arc;

use vortex_buffer::{Buffer, BufferString};
use vortex_dtype::DType;
use vortex_error::{vortex_err, VortexResult};

use crate::pvalue::PValue;
Expand Down Expand Up @@ -28,6 +29,26 @@ impl ScalarValue {
matches!(self, Self::Null)
}

/// Returns `true` when this scalar value is a valid instance of `dtype`.
///
/// Matching rules:
/// - `Bool`, `Buffer` and `BufferString` match `DType::Bool`, `DType::Binary`
///   and `DType::Utf8` respectively, regardless of nullability.
/// - `Primitive` matches `DType::Primitive` when the underlying `PValue` is an
///   instance of the dtype's primitive type.
/// - `List` matches `DType::List` when every element is an instance of the
///   list's element dtype.
/// - `List` matches `DType::Struct` when it holds exactly one value per struct
///   field and each value is an instance of the field dtype at the same
///   position.
/// - `Null` matches any dtype for which `is_nullable()` returns `true`.
/// - Every other combination is not a match.
pub fn is_instance_of(&self, dtype: &DType) -> bool {
    match (self, dtype) {
        (ScalarValue::Bool(_), DType::Bool(_)) => true,
        (ScalarValue::Primitive(pvalue), DType::Primitive(ptype, _)) => {
            pvalue.is_instance_of(ptype)
        }
        (ScalarValue::Buffer(_), DType::Binary(_)) => true,
        (ScalarValue::BufferString(_), DType::Utf8(_)) => true,
        (ScalarValue::List(values), DType::List(element_dtype, _)) => {
            values.iter().all(|v| v.is_instance_of(element_dtype))
        }
        // Struct scalars are carried as `ScalarValue::List`, matched positionally
        // against the struct's field dtypes. Per review discussion there is
        // intentionally no `ScalarValue::Struct` variant, to avoid duplicating the
        // field names that the dtype already stores.
        (ScalarValue::List(values), DType::Struct(structdt, _)) => {
            let field_dtypes = structdt.dtypes();
            // `zip` stops at the shorter side, so without this length check a list
            // with missing (or extra) values would be accepted as a struct instance.
            values.len() == field_dtypes.len()
                && values
                    .iter()
                    .zip(field_dtypes.iter())
                    .all(|(v, dt)| v.is_instance_of(dt))
        }
        (ScalarValue::Null, dtype) => dtype.is_nullable(),
        (..) => false,
    }
}

pub fn as_bool(&self) -> VortexResult<Option<bool>> {
match self {
Self::Null => Ok(None),
Expand Down Expand Up @@ -69,3 +90,91 @@ impl ScalarValue {
}
}
}

#[cfg(test)]
mod test {
    use vortex_dtype::{DType, Nullability, PType, StructDType};

    use crate::{PValue, ScalarValue};

    /// A boolean scalar is an instance of a boolean dtype whatever its nullability.
    #[test]
    pub fn test_is_instance_of_bool() {
        assert!(ScalarValue::Bool(true).is_instance_of(&DType::Bool(Nullability::Nullable)));
        assert!(ScalarValue::Bool(true).is_instance_of(&DType::Bool(Nullability::NonNullable)));
        assert!(ScalarValue::Bool(false).is_instance_of(&DType::Bool(Nullability::Nullable)));
        assert!(ScalarValue::Bool(false).is_instance_of(&DType::Bool(Nullability::NonNullable)));
    }

    /// A primitive scalar is an instance of the primitive dtype with the same ptype.
    #[test]
    pub fn test_is_instance_of_primitive() {
        assert!(ScalarValue::Primitive(PValue::F64(0.0))
            .is_instance_of(&DType::Primitive(PType::F64, Nullability::NonNullable)));
    }

    /// A `ScalarValue::List` is checked element-wise against a list dtype and
    /// position-wise against a struct dtype.
    #[test]
    pub fn test_is_instance_of_list_and_struct() {
        let tbool = DType::Bool(Nullability::NonNullable);
        let tboolnull = DType::Bool(Nullability::Nullable);
        let tnull = DType::Null;

        let bool_null = ScalarValue::List(vec![ScalarValue::Bool(true), ScalarValue::Null].into());
        let bool_bool =
            ScalarValue::List(vec![ScalarValue::Bool(true), ScalarValue::Bool(false)].into());

        fn tlist(element: &DType) -> DType {
            DType::List(element.clone().into(), Nullability::NonNullable)
        }

        assert!(bool_null.is_instance_of(&tlist(&tboolnull)));
        assert!(!bool_null.is_instance_of(&tlist(&tbool)));
        assert!(bool_bool.is_instance_of(&tlist(&tbool)));
        // Previously this repeated the non-nullable case; the nullable element
        // dtype is the combination that was left uncovered.
        assert!(bool_bool.is_instance_of(&tlist(&tboolnull)));

        fn tstruct(left: &DType, right: &DType) -> DType {
            DType::Struct(
                StructDType::new(
                    vec!["left".into(), "right".into()].into(),
                    vec![left.clone(), right.clone()],
                ),
                Nullability::NonNullable,
            )
        }

        assert!(bool_null.is_instance_of(&tstruct(&tboolnull, &tboolnull)));
        assert!(bool_null.is_instance_of(&tstruct(&tbool, &tboolnull)));
        assert!(!bool_null.is_instance_of(&tstruct(&tboolnull, &tbool)));
        assert!(!bool_null.is_instance_of(&tstruct(&tbool, &tbool)));

        assert!(bool_null.is_instance_of(&tstruct(&tbool, &tnull)));
        assert!(!bool_null.is_instance_of(&tstruct(&tnull, &tbool)));

        assert!(bool_bool.is_instance_of(&tstruct(&tboolnull, &tboolnull)));
        assert!(bool_bool.is_instance_of(&tstruct(&tbool, &tboolnull)));
        assert!(bool_bool.is_instance_of(&tstruct(&tboolnull, &tbool)));
        assert!(bool_bool.is_instance_of(&tstruct(&tbool, &tbool)));

        assert!(!bool_bool.is_instance_of(&tstruct(&tbool, &tnull)));
        assert!(!bool_bool.is_instance_of(&tstruct(&tnull, &tbool)));
    }

    /// `ScalarValue::Null` is an instance of nullable dtypes (and of `DType::Null`),
    /// but not of a non-nullable dtype.
    #[test]
    pub fn test_is_instance_of_null() {
        assert!(ScalarValue::Null.is_instance_of(&DType::Bool(Nullability::Nullable)));
        assert!(!ScalarValue::Null.is_instance_of(&DType::Bool(Nullability::NonNullable)));

        assert!(
            ScalarValue::Null.is_instance_of(&DType::Primitive(PType::U8, Nullability::Nullable))
        );
        assert!(ScalarValue::Null.is_instance_of(&DType::Utf8(Nullability::Nullable)));
        assert!(ScalarValue::Null.is_instance_of(&DType::Binary(Nullability::Nullable)));
        assert!(ScalarValue::Null.is_instance_of(&DType::Struct(
            StructDType::new([].into(), [].into()),
            Nullability::Nullable,
        )));
        assert!(ScalarValue::Null.is_instance_of(&DType::List(
            DType::Utf8(Nullability::NonNullable).into(),
            Nullability::Nullable
        )));
        assert!(ScalarValue::Null.is_instance_of(&DType::Null));
    }
}
Loading