diff --git a/Cargo.lock b/Cargo.lock index d18d9f551..941818840 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4330,6 +4330,7 @@ dependencies = [ name = "vortex-scalar" version = "0.1.0" dependencies = [ + "arrow-array", "build-vortex", "datafusion-common", "flatbuffers", diff --git a/vortex-array/src/array/constant/compute.rs b/vortex-array/src/array/constant/compute.rs index 74fb74747..c7a19eb95 100644 --- a/vortex-array/src/array/constant/compute.rs +++ b/vortex-array/src/array/constant/compute.rs @@ -1,16 +1,21 @@ use std::cmp::Ordering; +use std::sync::Arc; +use arrow_array::Datum; use vortex_dtype::Nullability; use vortex_error::{vortex_bail, VortexResult}; +use vortex_expr::Operator; use vortex_scalar::Scalar; use crate::array::constant::ConstantArray; +use crate::arrow::FromArrowArray; use crate::compute::unary::{scalar_at, ScalarAtFn}; use crate::compute::{ - AndFn, ArrayCompute, OrFn, SearchResult, SearchSortedFn, SearchSortedSide, SliceFn, TakeFn, + scalar_cmp, AndFn, ArrayCompute, CompareFn, OrFn, SearchResult, SearchSortedFn, + SearchSortedSide, SliceFn, TakeFn, }; use crate::stats::{ArrayStatistics, Stat}; -use crate::{Array, ArrayDType, AsArray, IntoArray}; +use crate::{Array, ArrayDType, ArrayData, AsArray, IntoArray, IntoCanonical}; impl ArrayCompute for ConstantArray { fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { @@ -29,6 +34,10 @@ impl ArrayCompute for ConstantArray { Some(self) } + fn compare(&self) -> Option<&dyn CompareFn> { + Some(self) + } + fn and(&self) -> Option<&dyn AndFn> { Some(self) } @@ -69,6 +78,34 @@ impl SearchSortedFn for ConstantArray { } } +impl CompareFn for ConstantArray { + fn compare(&self, rhs: &Array, operator: Operator) -> VortexResult { + if let Some(true) = rhs.statistics().get_as::(Stat::IsConstant) { + let lhs = self.scalar(); + let rhs = scalar_at(rhs, 0)?; + + let scalar = scalar_cmp(lhs, &rhs, operator); + + Ok(ConstantArray::new(scalar, self.len()).into_array()) + } else { + let datum = Arc::::from(self.scalar()); + let rhs = rhs.clone().into_canonical()?.into_arrow(); + let rhs = rhs.as_ref(); + + let boolean_array = match operator { + Operator::Eq => arrow_ord::cmp::eq(datum.as_ref(), &rhs)?, + Operator::NotEq => arrow_ord::cmp::neq(datum.as_ref(), &rhs)?, + Operator::Gt => arrow_ord::cmp::gt(datum.as_ref(), &rhs)?, + Operator::Gte => arrow_ord::cmp::gt_eq(datum.as_ref(), &rhs)?, + Operator::Lt => arrow_ord::cmp::lt(datum.as_ref(), &rhs)?, + Operator::Lte => arrow_ord::cmp::lt_eq(datum.as_ref(), &rhs)?, + }; + + Ok(ArrayData::from_arrow(&boolean_array, true).into_array()) + } + } +} + impl AndFn for ConstantArray { fn and(&self, array: &Array) -> VortexResult { constant_array_bool_impl( diff --git a/vortex-array/src/array/varbin/builder.rs b/vortex-array/src/array/varbin/builder.rs index 889e6cbc2..c5dea80bb 100644 --- a/vortex-array/src/array/varbin/builder.rs +++ b/vortex-array/src/array/varbin/builder.rs @@ -49,7 +49,6 @@ impl VarBinBuilder { pub fn finish(mut self, dtype: DType) -> VarBinArray { let offsets = PrimitiveArray::from(self.offsets); let data = PrimitiveArray::from_bytes(self.data.freeze(), Validity::NonNullable); - let nulls = self.validity.finish(); let validity = if dtype.is_nullable() { diff --git a/vortex-array/src/compute/boolean.rs b/vortex-array/src/compute/boolean.rs index 585d3a10a..d104349f3 100644 --- a/vortex-array/src/compute/boolean.rs +++ b/vortex-array/src/compute/boolean.rs @@ -27,6 +27,7 @@ pub fn and(lhs: &Array, rhs: &Array) -> VortexResult { return selection; } + // If neither side implements `AndFn`, we try to expand the left-hand side into a `BoolArray`, which we know does implement it, and call into that implementation. let lhs = lhs.clone().into_bool()?; lhs.and(rhs) @@ -49,6 +50,7 @@ pub fn or(lhs: &Array, rhs: &Array) -> VortexResult { return selection; } + // If neither side implements `OrFn`, we try to expand the left-hand side into a `BoolArray`, which we know does implement it, and call into that implementation. let lhs = lhs.clone().into_bool()?; lhs.or(rhs) diff --git a/vortex-array/src/compute/compare.rs b/vortex-array/src/compute/compare.rs index 093a665a8..819bf1d31 100644 --- a/vortex-array/src/compute/compare.rs +++ b/vortex-array/src/compute/compare.rs @@ -1,21 +1,39 @@ use arrow_ord::cmp; -use vortex_error::VortexResult; +use vortex_dtype::{DType, Nullability}; +use vortex_error::{vortex_bail, VortexResult}; use vortex_expr::Operator; +use vortex_scalar::Scalar; use crate::arrow::FromArrowArray; -use crate::{Array, ArrayData, IntoArray, IntoCanonical}; +use crate::{Array, ArrayDType, ArrayData, IntoArray, IntoCanonical}; pub trait CompareFn { fn compare(&self, array: &Array, operator: Operator) -> VortexResult; } pub fn compare(left: &Array, right: &Array, operator: Operator) -> VortexResult { + if left.len() != right.len() { + vortex_bail!("Compare operations only support arrays of the same length"); + } + + // TODO(adamg): This is a placeholder until we figure out type coercion and casting + if !left.dtype().eq_ignore_nullability(right.dtype()) { + vortex_bail!("Compare operations only support arrays of the same type"); + } + if let Some(selection) = left.with_dyn(|lhs| lhs.compare().map(|lhs| lhs.compare(right, operator))) { return selection; } + if let Some(selection) = right.with_dyn(|rhs| { + rhs.compare() + .map(|rhs| rhs.compare(left, operator.inverse())) + }) { + return selection; + } + // Fallback to arrow on canonical types let lhs = left.clone().into_canonical()?.into_arrow(); let rhs = right.clone().into_canonical()?.into_arrow(); @@ -31,3 +49,20 @@ pub fn compare(left: &Array, right: &Array, operator: Operator) -> VortexResult< Ok(ArrayData::from_arrow(&array, true).into_array()) } + +pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar { + if lhs.is_null() | rhs.is_null() { + Scalar::null(DType::Bool(Nullability::Nullable)) + } else { + let b = match operator { + Operator::Eq => lhs == rhs, + Operator::NotEq => lhs != rhs, + Operator::Gt => lhs > rhs, + Operator::Gte => lhs >= rhs, + Operator::Lt => lhs < rhs, + Operator::Lte => lhs <= rhs, + }; + + Scalar::bool(b, Nullability::Nullable) + } +} diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index c6eb7efd5..4b67198f3 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -8,7 +8,7 @@ //! from Arrow. pub use boolean::{and, or, AndFn, OrFn}; -pub use compare::{compare, CompareFn}; +pub use compare::{compare, scalar_cmp, CompareFn}; pub use filter::{filter, FilterFn}; pub use filter_indices::{filter_indices, FilterIndicesFn}; pub use search_sorted::*; diff --git a/vortex-scalar/Cargo.toml b/vortex-scalar/Cargo.toml index b9902dd0b..443d70206 100644 --- a/vortex-scalar/Cargo.toml +++ b/vortex-scalar/Cargo.toml @@ -12,6 +12,7 @@ edition = { workspace = true } rust-version = { workspace = true } [dependencies] +arrow-array = { workspace = true } datafusion-common = { workspace = true, optional = true } flatbuffers = { workspace = true, optional = true } flexbuffers = { workspace = true, optional = true } @@ -42,15 +43,7 @@ flatbuffers = [ "dep:serde", "vortex-buffer/flexbuffers", "vortex-error/flexbuffers", - "vortex-dtype/flatbuffers" -] -proto = [ - "dep:prost", - "dep:prost-types", - "vortex-dtype/proto", -] -serde = [ - "dep:serde", - "serde/derive", - "vortex-dtype/serde" + "vortex-dtype/flatbuffers", ] +proto = ["dep:prost", "dep:prost-types", "vortex-dtype/proto"] +serde = ["dep:serde", "serde/derive", "vortex-dtype/serde"] diff --git a/vortex-scalar/src/arrow.rs b/vortex-scalar/src/arrow.rs new file mode 100644 index 000000000..836e8181e --- /dev/null +++ b/vortex-scalar/src/arrow.rs @@ -0,0 +1,78 @@ +use std::sync::Arc; + +use arrow_array::*; +use vortex_dtype::{DType, PType}; + +use crate::{PValue, Scalar}; + +impl From<&Scalar> for Arc { + fn from(value: &Scalar) -> Arc { + match value.dtype { + DType::Null => Arc::new(NullArray::new(1)), + DType::Bool(_) => match value.value.as_bool().expect("should be bool") { + Some(b) => Arc::new(BooleanArray::new_scalar(b)), + None => Arc::new(BooleanArray::new_null(1)), + }, + DType::Primitive(ptype, _) => { + let pvalue = value.value.as_pvalue().expect("should be pvalue"); + match pvalue { + None => match ptype { + PType::U8 => Arc::new(UInt8Array::new_null(1)), + PType::U16 => Arc::new(UInt16Array::new_null(1)), + PType::U32 => Arc::new(UInt32Array::new_null(1)), + PType::U64 => Arc::new(UInt64Array::new_null(1)), + PType::I8 => Arc::new(Int8Array::new_null(1)), + PType::I16 => Arc::new(Int16Array::new_null(1)), + PType::I32 => Arc::new(Int32Array::new_null(1)), + PType::I64 => Arc::new(Int64Array::new_null(1)), + PType::F16 => Arc::new(Float16Array::new_null(1)), + PType::F32 => Arc::new(Float32Array::new_null(1)), + PType::F64 => Arc::new(Float64Array::new_null(1)), + }, + Some(pvalue) => match pvalue { + PValue::U8(v) => Arc::new(UInt8Array::new_scalar(v)), + PValue::U16(v) => Arc::new(UInt16Array::new_scalar(v)), + PValue::U32(v) => Arc::new(UInt32Array::new_scalar(v)), + PValue::U64(v) => Arc::new(UInt64Array::new_scalar(v)), + PValue::I8(v) => Arc::new(Int8Array::new_scalar(v)), + PValue::I16(v) => Arc::new(Int16Array::new_scalar(v)), + PValue::I32(v) => Arc::new(Int32Array::new_scalar(v)), + PValue::I64(v) => Arc::new(Int64Array::new_scalar(v)), + PValue::F16(v) => Arc::new(Float16Array::new_scalar(v)), + PValue::F32(v) => Arc::new(Float32Array::new_scalar(v)), + PValue::F64(v) => Arc::new(Float64Array::new_scalar(v)), + }, + } + } + DType::Utf8(_) => { + match value + .value + .as_buffer_string() + .expect("should be buffer string") + { + Some(s) => Arc::new(StringArray::new_scalar(s.as_str())), + None => Arc::new(StringArray::new_null(1)), + } + } + DType::Binary(_) => { + match value + .value + .as_buffer_string() + .expect("should be buffer string") + { + Some(s) => Arc::new(BinaryArray::new_scalar(s.as_bytes())), + None => Arc::new(BinaryArray::new_null(1)), + } + } + DType::Struct(..) => { + todo!("struct scalar conversion") + } + DType::List(..) => { + todo!("list scalar conversion") + } + DType::Extension(..) => { + todo!("extension scalar conversion") + } + } + } +} diff --git a/vortex-scalar/src/lib.rs b/vortex-scalar/src/lib.rs index 1612839a7..023b187de 100644 --- a/vortex-scalar/src/lib.rs +++ b/vortex-scalar/src/lib.rs @@ -2,6 +2,7 @@ use std::cmp::Ordering; use vortex_dtype::DType; +mod arrow; mod binary; mod bool; mod datafusion;