Skip to content

Commit

Permalink
include validities in comparisons (jorgecarleitao#846)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored and Dexter Duckworth committed Mar 2, 2022
1 parent 7e7c52c commit 46f8c98
Show file tree
Hide file tree
Showing 7 changed files with 375 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ compute_bitwise = []
compute_boolean = []
compute_boolean_kleene = []
compute_cast = ["lexical-core", "compute_take"]
compute_comparison = ["compute_take"]
compute_comparison = ["compute_take", "compute_boolean"]
compute_concatenate = []
compute_contains = []
compute_filter = []
Expand Down
45 changes: 45 additions & 0 deletions src/compute/comparison/binary.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
//! Comparison functions for [`BinaryArray`]
use crate::compute::comparison::{finish_eq_validities, finish_neq_validities};
use crate::{
array::{BinaryArray, BooleanArray, Offset},
bitmap::Bitmap,
Expand Down Expand Up @@ -49,23 +50,67 @@ pub fn eq<O: Offset>(lhs: &BinaryArray<O>, rhs: &BinaryArray<O>) -> BooleanArray
compare_op(lhs, rhs, |a, b| a == b)
}

/// Perform `lhs == rhs` operation on [`BinaryArray`] and include validities in comparison.
/// # Panic
/// iff the arrays do not have the same length.
pub fn eq_and_validity<O: Offset>(lhs: &BinaryArray<O>, rhs: &BinaryArray<O>) -> BooleanArray {
let validity_lhs = lhs.validity().cloned();
let validity_rhs = rhs.validity().cloned();
let lhs = lhs.with_validity(None);
let rhs = rhs.with_validity(None);
let out = compare_op(&lhs, &rhs, |a, b| a == b);

finish_eq_validities(out, validity_lhs, validity_rhs)
}

/// Perform `lhs == rhs` operation on [`BinaryArray`] and a scalar.
pub fn eq_scalar<O: Offset>(lhs: &BinaryArray<O>, rhs: &[u8]) -> BooleanArray {
compare_op_scalar(lhs, rhs, |a, b| a == b)
}

/// Perform `lhs == rhs` operation on [`BinaryArray`] and a scalar and include validities in comparison.
pub fn eq_scalar_and_validity<O: Offset>(lhs: &BinaryArray<O>, rhs: &[u8]) -> BooleanArray {
let validity = lhs.validity().cloned();
let lhs = lhs.with_validity(None);
let out = compare_op_scalar(&lhs, rhs, |a, b| a == b);

finish_eq_validities(out, validity, None)
}

/// Perform `lhs != rhs` operation on [`BinaryArray`].
/// # Panic
/// iff the arrays do not have the same length.
pub fn neq<O: Offset>(lhs: &BinaryArray<O>, rhs: &BinaryArray<O>) -> BooleanArray {
compare_op(lhs, rhs, |a, b| a != b)
}

/// Perform `lhs != rhs` operation on [`BinaryArray`].
/// # Panic
/// iff the arrays do not have the same length and include validities in comparison.
pub fn neq_and_validity<O: Offset>(lhs: &BinaryArray<O>, rhs: &BinaryArray<O>) -> BooleanArray {
let validity_lhs = lhs.validity().cloned();
let validity_rhs = rhs.validity().cloned();
let lhs = lhs.with_validity(None);
let rhs = rhs.with_validity(None);

let out = compare_op(&lhs, &rhs, |a, b| a != b);
finish_neq_validities(out, validity_lhs, validity_rhs)
}

/// Perform `lhs != rhs` operation on [`BinaryArray`] and a scalar.
pub fn neq_scalar<O: Offset>(lhs: &BinaryArray<O>, rhs: &[u8]) -> BooleanArray {
compare_op_scalar(lhs, rhs, |a, b| a != b)
}

/// Perform `lhs != rhs` operation on [`BinaryArray`] and a scalar and include validities in comparison.
pub fn neq_scalar_and_validity<O: Offset>(lhs: &BinaryArray<O>, rhs: &[u8]) -> BooleanArray {
let validity = lhs.validity().cloned();
let lhs = lhs.with_validity(None);
let out = compare_op_scalar(&lhs, rhs, |a, b| a != b);

finish_neq_validities(out, validity, None)
}

/// Perform `lhs < rhs` operation on [`BinaryArray`].
pub fn lt<O: Offset>(lhs: &BinaryArray<O>, rhs: &BinaryArray<O>) -> BooleanArray {
compare_op(lhs, rhs, |a, b| a < b)
Expand Down
46 changes: 46 additions & 0 deletions src/compute/comparison/boolean.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
//! Comparison functions for [`BooleanArray`]
use crate::compute::comparison::{finish_eq_validities, finish_neq_validities};
use crate::{
array::BooleanArray,
bitmap::{binary, unary, Bitmap},
Expand Down Expand Up @@ -38,6 +39,17 @@ pub fn eq(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray {
compare_op(lhs, rhs, |a, b| !(a ^ b))
}

/// Perform `lhs == rhs` operation on two [`BooleanArray`]s and include validities in comparison.
pub fn eq_and_validity(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray {
let validity_lhs = lhs.validity().cloned();
let validity_rhs = rhs.validity().cloned();
let lhs = lhs.with_validity(None);
let rhs = rhs.with_validity(None);
let out = compare_op(&lhs, &rhs, |a, b| !(a ^ b));

finish_eq_validities(out, validity_lhs, validity_rhs)
}

/// Perform `lhs == rhs` operation on a [`BooleanArray`] and a scalar value.
pub fn eq_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray {
if rhs {
Expand All @@ -47,16 +59,50 @@ pub fn eq_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray {
}
}

/// Perform `lhs == rhs` operation on a [`BooleanArray`] and a scalar value and include validities in comparison.
pub fn eq_scalar_and_validity(lhs: &BooleanArray, rhs: bool) -> BooleanArray {
let validity = lhs.validity().cloned();
let lhs = lhs.with_validity(None);
if rhs {
finish_eq_validities(lhs, validity, None)
} else {
let lhs = lhs.with_validity(None);

let out = compare_op_scalar(&lhs, rhs, |a, _| !a);

finish_eq_validities(out, validity, None)
}
}

/// `lhs != rhs` for [`BooleanArray`]
pub fn neq(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray {
compare_op(lhs, rhs, |a, b| a ^ b)
}

/// `lhs != rhs` for [`BooleanArray`] and include validities in comparison.
pub fn neq_and_validity(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray {
let validity_lhs = lhs.validity().cloned();
let validity_rhs = rhs.validity().cloned();
let lhs = lhs.with_validity(None);
let rhs = rhs.with_validity(None);
let out = compare_op(&lhs, &rhs, |a, b| a ^ b);

finish_neq_validities(out, validity_lhs, validity_rhs)
}

/// Perform `left != right` operation on an array and a scalar value.
pub fn neq_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray {
eq_scalar(lhs, !rhs)
}

/// Perform `left != right` operation on an array and a scalar value.
pub fn neq_scalar_and_validity(lhs: &BooleanArray, rhs: bool) -> BooleanArray {
let validity = lhs.validity().cloned();
let lhs = lhs.with_validity(None);
let out = eq_scalar(&lhs, !rhs);
finish_neq_validities(out, validity, None)
}

/// Perform `left < right` operation on two arrays.
pub fn lt(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray {
compare_op(lhs, rhs, |a, b| !a & b)
Expand Down
95 changes: 95 additions & 0 deletions src/compute/comparison/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ mod simd;
pub use simd::{Simd8, Simd8Lanes, Simd8PartialEq, Simd8PartialOrd};

use super::take::take_boolean;
use crate::bitmap::Bitmap;
use crate::compute;
pub(crate) use primitive::{
compare_values_op as primitive_compare_values_op,
compare_values_op_scalar as primitive_compare_values_op_scalar,
Expand Down Expand Up @@ -167,6 +169,17 @@ pub fn eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray {
compare!(lhs, rhs, eq, match_eq)
}

/// `==` between two [`Array`]s and includes validities in comparison.
/// Use [`can_eq`] to check whether the operation is valid
/// # Panic
/// Panics iff either:
/// * the arrays do not have have the same logical type
/// * the arrays do not have the same length
/// * the operation is not supported for the logical type
pub fn eq_and_validity(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray {
compare!(lhs, rhs, eq_and_validity, match_eq)
}

/// Returns whether a [`DataType`] is comparable is supported by [`eq`].
pub fn can_eq(data_type: &DataType) -> bool {
can_partial_eq(data_type)
Expand All @@ -183,6 +196,17 @@ pub fn neq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray {
compare!(lhs, rhs, neq, match_eq)
}

/// `!=` between two [`Array`]s and includes validities in comparison.
/// Use [`can_neq`] to check whether the operation is valid
/// # Panic
/// Panics iff either:
/// * the arrays do not have have the same logical type
/// * the arrays do not have the same length
/// * the operation is not supported for the logical type
pub fn neq_and_validity(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray {
compare!(lhs, rhs, neq_and_validity, match_eq)
}

/// Returns whether a [`DataType`] is comparable is supported by [`neq`].
pub fn can_neq(data_type: &DataType) -> bool {
can_partial_eq(data_type)
Expand Down Expand Up @@ -320,6 +344,16 @@ pub fn eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray {
compare_scalar!(lhs, rhs, eq_scalar, match_eq)
}

/// `==` between an [`Array`] and a [`Scalar`] and includes validities in comparison.
/// Use [`can_eq_scalar`] to check whether the operation is valid
/// # Panic
/// Panics iff either:
/// * they do not have have the same logical type
/// * the operation is not supported for the logical type
pub fn eq_scalar_and_validity(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray {
compare_scalar!(lhs, rhs, eq_scalar_and_validity, match_eq)
}

/// Returns whether a [`DataType`] is supported by [`eq_scalar`].
pub fn can_eq_scalar(data_type: &DataType) -> bool {
can_partial_eq_scalar(data_type)
Expand All @@ -335,6 +369,16 @@ pub fn neq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray {
compare_scalar!(lhs, rhs, neq_scalar, match_eq)
}

/// `!=` between an [`Array`] and a [`Scalar`] and includes validities in comparison.
/// Use [`can_neq_scalar`] to check whether the operation is valid
/// # Panic
/// Panics iff either:
/// * they do not have have the same logical type
/// * the operation is not supported for the logical type
pub fn neq_scalar_and_validity(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray {
compare_scalar!(lhs, rhs, neq_scalar_and_validity, match_eq)
}

/// Returns whether a [`DataType`] is supported by [`neq_scalar`].
pub fn can_neq_scalar(data_type: &DataType) -> bool {
can_partial_eq_scalar(data_type)
Expand Down Expand Up @@ -457,3 +501,54 @@ fn can_partial_eq_scalar(data_type: &DataType) -> bool {
| DataType::Interval(IntervalUnit::MonthDayNano)
)
}

fn finish_eq_validities(
output_without_validities: BooleanArray,
validity_lhs: Option<Bitmap>,
validity_rhs: Option<Bitmap>,
) -> BooleanArray {
match (validity_lhs, validity_rhs) {
(None, None) => output_without_validities,
(Some(lhs), None) => compute::boolean::and(
&BooleanArray::from_data(DataType::Boolean, lhs, None),
&output_without_validities,
)
.unwrap(),
(None, Some(rhs)) => compute::boolean::and(
&output_without_validities,
&BooleanArray::from_data(DataType::Boolean, rhs, None),
)
.unwrap(),
(Some(lhs), Some(rhs)) => {
let lhs = BooleanArray::from_data(DataType::Boolean, lhs, None);
let rhs = BooleanArray::from_data(DataType::Boolean, rhs, None);
let eq_validities = compute::comparison::boolean::eq(&lhs, &rhs);
compute::boolean::and(&output_without_validities, &eq_validities).unwrap()
}
}
}
fn finish_neq_validities(
output_without_validities: BooleanArray,
validity_lhs: Option<Bitmap>,
validity_rhs: Option<Bitmap>,
) -> BooleanArray {
match (validity_lhs, validity_rhs) {
(None, None) => output_without_validities,
(Some(lhs), None) => {
let lhs_negated =
compute::boolean::not(&BooleanArray::from_data(DataType::Boolean, lhs, None));
compute::boolean::or(&lhs_negated, &output_without_validities).unwrap()
}
(None, Some(rhs)) => {
let rhs_negated =
compute::boolean::not(&BooleanArray::from_data(DataType::Boolean, rhs, None));
compute::boolean::or(&output_without_validities, &rhs_negated).unwrap()
}
(Some(lhs), Some(rhs)) => {
let lhs = BooleanArray::from_data(DataType::Boolean, lhs, None);
let rhs = BooleanArray::from_data(DataType::Boolean, rhs, None);
let neq_validities = compute::comparison::boolean::neq(&lhs, &rhs);
compute::boolean::or(&output_without_validities, &neq_validities).unwrap()
}
}
}
57 changes: 57 additions & 0 deletions src/compute/comparison/primitive.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
//! Comparison functions for [`PrimitiveArray`]
use crate::compute::comparison::{finish_eq_validities, finish_neq_validities};
use crate::{
array::{BooleanArray, PrimitiveArray},
bitmap::MutableBitmap,
Expand Down Expand Up @@ -99,6 +100,21 @@ where
compare_op(lhs, rhs, |a, b| a.eq(b))
}

/// Perform `lhs == rhs` operation on two arrays and include validities in comparison.
pub fn eq_and_validity<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialEq,
{
let validity_lhs = lhs.validity().cloned();
let validity_rhs = rhs.validity().cloned();
let lhs = lhs.with_validity(None);
let rhs = rhs.with_validity(None);
let out = compare_op(&lhs, &rhs, |a, b| a.eq(b));

finish_eq_validities(out, validity_lhs, validity_rhs)
}

/// Perform `left == right` operation on an array and a scalar value.
pub fn eq_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> BooleanArray
where
Expand All @@ -108,6 +124,19 @@ where
compare_op_scalar(lhs, rhs, |a, b| a.eq(b))
}

/// Perform `left == right` operation on an array and a scalar value and include validities in comparison.
pub fn eq_scalar_and_validity<T>(lhs: &PrimitiveArray<T>, rhs: T) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialEq,
{
let validity = lhs.validity().cloned();
let lhs = lhs.with_validity(None);
let out = compare_op_scalar(&lhs, rhs, |a, b| a.eq(b));

finish_eq_validities(out, validity, None)
}

/// Perform `left != right` operation on two arrays.
pub fn neq<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> BooleanArray
where
Expand All @@ -117,6 +146,21 @@ where
compare_op(lhs, rhs, |a, b| a.neq(b))
}

/// Perform `left != right` operation on two arrays and include validities in comparison.
pub fn neq_and_validity<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialEq,
{
let validity_lhs = lhs.validity().cloned();
let validity_rhs = rhs.validity().cloned();
let lhs = lhs.with_validity(None);
let rhs = rhs.with_validity(None);
let out = compare_op(&lhs, &rhs, |a, b| a.neq(b));

finish_neq_validities(out, validity_lhs, validity_rhs)
}

/// Perform `left != right` operation on an array and a scalar value.
pub fn neq_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> BooleanArray
where
Expand All @@ -126,6 +170,19 @@ where
compare_op_scalar(lhs, rhs, |a, b| a.neq(b))
}

/// Perform `left != right` operation on an array and a scalar value and include validities in comparison.
pub fn neq_scalar_and_validity<T>(lhs: &PrimitiveArray<T>, rhs: T) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialEq,
{
let validity = lhs.validity().cloned();
let lhs = lhs.with_validity(None);
let out = compare_op_scalar(&lhs, rhs, |a, b| a.neq(b));

finish_neq_validities(out, validity, None)
}

/// Perform `left < right` operation on two arrays.
pub fn lt<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> BooleanArray
where
Expand Down
Loading

0 comments on commit 46f8c98

Please sign in to comment.