Skip to content

Commit

Permalink
implement eq_dyn and neq_dyn (#858)
Browse files Browse the repository at this point in the history
  • Loading branch information
jimexist authored and alamb committed Oct 27, 2021
1 parent bfac9e5 commit 32df9b6
Showing 1 changed file with 171 additions and 17 deletions.
188 changes: 171 additions & 17 deletions arrow/src/compute/kernels/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,19 @@
//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.

use regex::Regex;
use std::collections::HashMap;

use crate::array::*;
use crate::buffer::{bitwise_bin_op_helper, buffer_unary_not, Buffer, MutableBuffer};
use crate::compute::binary_boolean_kernel;
use crate::compute::util::combine_option_bitmap;
use crate::datatypes::{ArrowNumericType, DataType};
use crate::datatypes::{
ArrowNumericType, DataType, Float32Type, Float64Type, Int16Type, Int32Type,
Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use crate::error::{ArrowError, Result};
use crate::util::bit_util;
use regex::Regex;
use std::any::type_name;
use std::collections::HashMap;

/// Helper function to perform boolean lambda function on values from two arrays, this
/// version does not attempt to use SIMD.
Expand Down Expand Up @@ -959,7 +962,142 @@ where
Ok(BooleanArray::from(data))
}

/// Perform `left == right` operation on two arrays.
macro_rules! typed_cmp {
($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident) => {{
let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| {
ArrowError::CastError(format!(
"Left array cannot be cast to {}",
type_name::<$T>()
))
})?;
let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| {
ArrowError::CastError(format!(
"Right array cannot be cast to {}",
type_name::<$T>(),
))
})?;
$OP(left, right)
}};
($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident, $TT: tt) => {{
let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| {
ArrowError::CastError(format!(
"Left array cannot be cast to {}",
type_name::<$T>()
))
})?;
let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| {
ArrowError::CastError(format!(
"Right array cannot be cast to {}",
type_name::<$T>(),
))
})?;
$OP::<$TT>(left, right)
}};
}

macro_rules! typed_compares {
($LEFT: expr, $RIGHT: expr, $OP_BOOL: ident, $OP_PRIM: ident, $OP_STR: ident) => {{
match ($LEFT.data_type(), $RIGHT.data_type()) {
(DataType::Boolean, DataType::Boolean) => {
typed_cmp!($LEFT, $RIGHT, BooleanArray, $OP_BOOL)
}
(DataType::Int8, DataType::Int8) => {
typed_cmp!($LEFT, $RIGHT, Int8Array, $OP_PRIM, Int8Type)
}
(DataType::Int16, DataType::Int16) => {
typed_cmp!($LEFT, $RIGHT, Int16Array, $OP_PRIM, Int16Type)
}
(DataType::Int32, DataType::Int32) => {
typed_cmp!($LEFT, $RIGHT, Int32Array, $OP_PRIM, Int32Type)
}
(DataType::Int64, DataType::Int64) => {
typed_cmp!($LEFT, $RIGHT, Int64Array, $OP_PRIM, Int64Type)
}
(DataType::UInt8, DataType::UInt8) => {
typed_cmp!($LEFT, $RIGHT, UInt8Array, $OP_PRIM, UInt8Type)
}
(DataType::UInt16, DataType::UInt16) => {
typed_cmp!($LEFT, $RIGHT, UInt16Array, $OP_PRIM, UInt16Type)
}
(DataType::UInt32, DataType::UInt32) => {
typed_cmp!($LEFT, $RIGHT, UInt32Array, $OP_PRIM, UInt32Type)
}
(DataType::UInt64, DataType::UInt64) => {
typed_cmp!($LEFT, $RIGHT, UInt64Array, $OP_PRIM, UInt64Type)
}
(DataType::Float32, DataType::Float32) => {
typed_cmp!($LEFT, $RIGHT, Float32Array, $OP_PRIM, Float32Type)
}
(DataType::Float64, DataType::Float64) => {
typed_cmp!($LEFT, $RIGHT, Float64Array, $OP_PRIM, Float64Type)
}
(DataType::Utf8, DataType::Utf8) => {
typed_cmp!($LEFT, $RIGHT, StringArray, $OP_STR, i32)
}
(DataType::LargeUtf8, DataType::LargeUtf8) => {
typed_cmp!($LEFT, $RIGHT, LargeStringArray, $OP_STR, i64)
}
(t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!(
"Comparing arrays of type {} is not yet implemented",
t1
))),
(t1, t2) => Err(ArrowError::CastError(format!(
"Cannot compare two arrays of different types ({} and {})",
t1, t2
))),
}
}};
}

/// Perform `left == right` operation on two (dynamic) [`Array`]s.
///
/// Only when two arrays are of the same type the comparison will happen otherwise it will err
/// with a casting error.
pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
typed_compares!(left, right, eq_bool, eq, eq_utf8)
}

/// Perform `left != right` operation on two (dynamic) [`Array`]s.
///
/// Only when two arrays are of the same type the comparison will happen otherwise it will err
/// with a casting error.
pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
typed_compares!(left, right, neq_bool, neq, neq_utf8)
}

/// Perform `left < right` operation on two (dynamic) [`Array`]s.
///
/// Only when two arrays are of the same type the comparison will happen otherwise it will err
/// with a casting error.
pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
typed_compares!(left, right, lt_bool, lt, lt_utf8)
}

/// Perform `left <= right` operation on two (dynamic) [`Array`]s.
///
/// Only when two arrays are of the same type the comparison will happen otherwise it will err
/// with a casting error.
pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
typed_compares!(left, right, lt_eq_bool, lt_eq, lt_eq_utf8)
}

/// Perform `left > right` operation on two (dynamic) [`Array`]s.
///
/// Only when two arrays are of the same type the comparison will happen otherwise it will err
/// with a casting error.
pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
typed_compares!(left, right, gt_bool, gt, gt_utf8)
}

/// Perform `left >= right` operation on two (dynamic) [`Array`]s.
///
/// Only when two arrays are of the same type the comparison will happen otherwise it will err
/// with a casting error.
pub fn gt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
typed_compares!(left, right, gt_eq_bool, gt_eq, gt_eq_utf8)
}

/// Perform `left == right` operation on two [`PrimitiveArray`]s.
pub fn eq<T>(left: &PrimitiveArray<T>, right: &PrimitiveArray<T>) -> Result<BooleanArray>
where
T: ArrowNumericType,
Expand All @@ -970,7 +1108,7 @@ where
return compare_op!(left, right, |a, b| a == b);
}

/// Perform `left == right` operation on an array and a scalar value.
/// Perform `left == right` operation on a [`PrimitiveArray`] and a scalar value.
pub fn eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
where
T: ArrowNumericType,
Expand All @@ -981,7 +1119,7 @@ where
return compare_op_scalar!(left, right, |a, b| a == b);
}

/// Perform `left != right` operation on two arrays.
/// Perform `left != right` operation on two [`PrimitiveArray`]s.
pub fn neq<T>(left: &PrimitiveArray<T>, right: &PrimitiveArray<T>) -> Result<BooleanArray>
where
T: ArrowNumericType,
Expand All @@ -992,7 +1130,7 @@ where
return compare_op!(left, right, |a, b| a != b);
}

/// Perform `left != right` operation on an array and a scalar value.
/// Perform `left != right` operation on a [`PrimitiveArray`] and a scalar value.
pub fn neq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
where
T: ArrowNumericType,
Expand All @@ -1003,7 +1141,7 @@ where
return compare_op_scalar!(left, right, |a, b| a != b);
}

/// Perform `left < right` operation on two arrays. Null values are less than non-null
/// Perform `left < right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
/// values.
pub fn lt<T>(left: &PrimitiveArray<T>, right: &PrimitiveArray<T>) -> Result<BooleanArray>
where
Expand All @@ -1015,7 +1153,7 @@ where
return compare_op!(left, right, |a, b| a < b);
}

/// Perform `left < right` operation on an array and a scalar value.
/// Perform `left < right` operation on a [`PrimitiveArray`] and a scalar value.
/// Null values are less than non-null values.
pub fn lt_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
where
Expand All @@ -1027,7 +1165,7 @@ where
return compare_op_scalar!(left, right, |a, b| a < b);
}

/// Perform `left <= right` operation on two arrays. Null values are less than non-null
/// Perform `left <= right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
/// values.
pub fn lt_eq<T>(
left: &PrimitiveArray<T>,
Expand All @@ -1042,7 +1180,7 @@ where
return compare_op!(left, right, |a, b| a <= b);
}

/// Perform `left <= right` operation on an array and a scalar value.
/// Perform `left <= right` operation on a [`PrimitiveArray`] and a scalar value.
/// Null values are less than non-null values.
pub fn lt_eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
where
Expand All @@ -1054,7 +1192,7 @@ where
return compare_op_scalar!(left, right, |a, b| a <= b);
}

/// Perform `left > right` operation on two arrays. Non-null values are greater than null
/// Perform `left > right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
/// values.
pub fn gt<T>(left: &PrimitiveArray<T>, right: &PrimitiveArray<T>) -> Result<BooleanArray>
where
Expand All @@ -1066,7 +1204,7 @@ where
return compare_op!(left, right, |a, b| a > b);
}

/// Perform `left > right` operation on an array and a scalar value.
/// Perform `left > right` operation on a [`PrimitiveArray`] and a scalar value.
/// Non-null values are greater than null values.
pub fn gt_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
where
Expand All @@ -1078,7 +1216,7 @@ where
return compare_op_scalar!(left, right, |a, b| a > b);
}

/// Perform `left >= right` operation on two arrays. Non-null values are greater than null
/// Perform `left >= right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
/// values.
pub fn gt_eq<T>(
left: &PrimitiveArray<T>,
Expand All @@ -1093,7 +1231,7 @@ where
return compare_op!(left, right, |a, b| a >= b);
}

/// Perform `left >= right` operation on an array and a scalar value.
/// Perform `left >= right` operation on a [`PrimitiveArray`] and a scalar value.
/// Non-null values are greater than null values.
pub fn gt_eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
where
Expand Down Expand Up @@ -1245,11 +1383,17 @@ mod tests {
/// `EXPECTED` can be either `Vec<bool>` or `Vec<Option<bool>>`.
/// The main reason for this macro is that inputs and outputs align nicely after `cargo fmt`.
macro_rules! cmp_i64 {
($KERNEL:ident, $A_VEC:expr, $B_VEC:expr, $EXPECTED:expr) => {
($KERNEL:ident, $DYN_KERNEL:ident, $A_VEC:expr, $B_VEC:expr, $EXPECTED:expr) => {
let a = Int64Array::from($A_VEC);
let b = Int64Array::from($B_VEC);
let c = $KERNEL(&a, &b).unwrap();
assert_eq!(BooleanArray::from($EXPECTED), c);

// slice and test if the dynamic array works
let a = a.slice(0, a.len());
let b = b.slice(0, b.len());
let c = $DYN_KERNEL(a.as_ref(), b.as_ref()).unwrap();
assert_eq!(BooleanArray::from($EXPECTED), c);
};
}

Expand All @@ -1269,6 +1413,7 @@ mod tests {
fn test_primitive_array_eq() {
cmp_i64!(
eq,
eq_dyn,
vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
vec![false, false, true, false, false, false, false, true, false, false]
Expand Down Expand Up @@ -1315,6 +1460,7 @@ mod tests {
fn test_primitive_array_neq() {
cmp_i64!(
neq,
neq_dyn,
vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
vec![true, true, false, true, true, true, true, false, true, true]
Expand Down Expand Up @@ -1396,6 +1542,7 @@ mod tests {
fn test_primitive_array_lt() {
cmp_i64!(
lt,
lt_dyn,
vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
vec![false, false, false, true, true, false, false, false, true, true]
Expand All @@ -1416,6 +1563,7 @@ mod tests {
fn test_primitive_array_lt_nulls() {
cmp_i64!(
lt,
lt_dyn,
vec![None, None, Some(1), Some(1), None, None, Some(2), Some(2),],
vec![None, Some(1), None, Some(1), None, Some(3), None, Some(3),],
vec![None, None, None, Some(false), None, None, None, Some(true)]
Expand All @@ -1436,6 +1584,7 @@ mod tests {
fn test_primitive_array_lt_eq() {
cmp_i64!(
lt_eq,
lt_eq_dyn,
vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
vec![false, false, true, true, true, false, false, true, true, true]
Expand All @@ -1456,6 +1605,7 @@ mod tests {
fn test_primitive_array_lt_eq_nulls() {
cmp_i64!(
lt_eq,
lt_eq_dyn,
vec![None, None, Some(1), None, None, Some(1), None, None, Some(1)],
vec![None, Some(1), Some(0), None, Some(1), Some(2), None, None, Some(3)],
vec![None, None, Some(false), None, None, Some(true), None, None, Some(true)]
Expand All @@ -1476,6 +1626,7 @@ mod tests {
fn test_primitive_array_gt() {
cmp_i64!(
gt,
gt_dyn,
vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
vec![true, true, false, false, false, true, true, false, false, false]
Expand All @@ -1496,6 +1647,7 @@ mod tests {
fn test_primitive_array_gt_nulls() {
cmp_i64!(
gt,
gt_dyn,
vec![None, None, Some(1), None, None, Some(2), None, None, Some(3)],
vec![None, Some(1), Some(1), None, Some(1), Some(1), None, Some(1), Some(1)],
vec![None, None, Some(false), None, None, Some(true), None, None, Some(true)]
Expand All @@ -1516,6 +1668,7 @@ mod tests {
fn test_primitive_array_gt_eq() {
cmp_i64!(
gt_eq,
gt_eq_dyn,
vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
vec![true, true, true, false, false, true, true, true, false, false]
Expand All @@ -1536,6 +1689,7 @@ mod tests {
fn test_primitive_array_gt_eq_nulls() {
cmp_i64!(
gt_eq,
gt_eq_dyn,
vec![None, None, Some(1), None, Some(1), Some(2), None, None, Some(1)],
vec![None, Some(1), None, None, Some(1), Some(1), None, Some(2), Some(2)],
vec![None, None, None, None, Some(true), Some(true), None, None, Some(false)]
Expand Down

0 comments on commit 32df9b6

Please sign in to comment.