Skip to content

Commit

Permalink
Remove binary_array_op_dyn_scalar! (apache#2512)
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb authored and ovr committed Aug 15, 2022
1 parent 909ed9c commit 5be64ab
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 79 deletions.
86 changes: 17 additions & 69 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ use arrow::compute::kernels::comparison::{
use arrow::compute::kernels::comparison::{
eq_scalar, gt_eq_scalar, gt_scalar, lt_eq_scalar, lt_scalar, neq_scalar,
};
use arrow::compute::kernels::comparison::{like_utf8, nlike_utf8, regexp_is_match_utf8};
use arrow::compute::kernels::comparison::{
like_utf8_scalar, nlike_utf8_scalar, regexp_is_match_utf8_scalar, nilike_utf8, ilike_utf8, nilike_utf8_scalar, ilike_utf8_scalar
ilike_utf8, ilike_utf8_scalar, like_utf8_scalar, nilike_utf8, nilike_utf8_scalar,
nlike_utf8_scalar, regexp_is_match_utf8_scalar,
};
use arrow::compute::kernels::comparison::{like_utf8, nlike_utf8, regexp_is_match_utf8};
use arrow::datatypes::{ArrowNumericType, DataType, Schema, TimeUnit};
use arrow::error::ArrowError::DivideByZero;
use arrow::record_batch::RecordBatch;
Expand Down Expand Up @@ -855,23 +856,6 @@ macro_rules! compute_utf8_op_dyn_scalar {
}};
}

/// Invoke a compute kernel on a boolean data array and a scalar value.
///
/// `$LEFT` is downcast to the concrete array type `$DT`, `$RIGHT` is
/// converted with `try_into()`, and both are handed to the kernel whose
/// name is `$OP` followed by `_bool_scalar`. The kernel output is wrapped
/// as `Ok(Arc::new(..))`.
///
/// The expansion contains `?`, so a failed scalar conversion or a kernel
/// error returns early from the function the macro is invoked in.
///
/// # Panics
/// Panics if `$LEFT` cannot be downcast to `$DT`.
macro_rules! compute_bool_op_scalar {
($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
use std::convert::TryInto;
let ll = $LEFT
.as_any()
.downcast_ref::<$DT>()
.expect("compute_op failed to downcast array");
// generate the kernel function name, such as eq_bool_scalar, from the $OP
// parameter (which could have a value of eq) and the suffix _bool_scalar
Ok(Arc::new(paste::expr! {[<$OP _bool_scalar>]}(
&ll,
$RIGHT.try_into()?,
)?))
}};
}

/// Invoke a compute kernel on a boolean data array and a scalar value
macro_rules! compute_bool_op_dyn_scalar {
($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{
Expand Down Expand Up @@ -1052,52 +1036,6 @@ macro_rules! binary_primitive_array_op_scalar {
}};
}

/// The binary_array_op_scalar macro includes types that extend beyond the primitive,
/// such as Utf8 strings.
///
/// Dispatches on `$LEFT.data_type()` and forwards `$LEFT`, `$RIGHT` and `$OP`
/// to the per-type scalar helper macro together with the concrete array type
/// to downcast to (decimal, integer, float, Utf8, timestamp, date and boolean
/// arrays are covered).
///
/// Evaluates to `Some(result)` where `result` is a `Result<Arc<dyn Array>>`.
/// The outer `Option` is always `Some`; an unsupported data type is reported
/// as `DataFusionError::Internal` inside the `Result`.
#[macro_export]
macro_rules! binary_array_op_scalar {
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
// The explicit type on `result` is what every arm's value is checked against.
let result: Result<Arc<dyn Array>> = match $LEFT.data_type() {
DataType::Decimal(_,_) => compute_decimal_op_scalar!($LEFT, $RIGHT, $OP, DecimalArray),
DataType::Int8 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int8Array),
DataType::Int16 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int16Array),
DataType::Int32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int32Array),
DataType::Int64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int64Array),
DataType::UInt8 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt8Array),
DataType::UInt16 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt16Array),
DataType::UInt32 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt32Array),
DataType::UInt64 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt64Array),
DataType::Float32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float32Array),
DataType::Float64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float64Array),
DataType::Utf8 => compute_utf8_op_scalar!($LEFT, $RIGHT, $OP, StringArray),
// Timestamps are matched per time unit; the timezone component is ignored.
DataType::Timestamp(TimeUnit::Nanosecond, _) => {
compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampNanosecondArray)
}
DataType::Timestamp(TimeUnit::Microsecond, _) => {
compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampMicrosecondArray)
}
DataType::Timestamp(TimeUnit::Millisecond, _) => {
compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampMillisecondArray)
}
DataType::Timestamp(TimeUnit::Second, _) => {
compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampSecondArray)
}
DataType::Date32 => {
compute_op_scalar!($LEFT, $RIGHT, $OP, Date32Array)
}
DataType::Date64 => {
compute_op_scalar!($LEFT, $RIGHT, $OP, Date64Array)
}
DataType::Boolean => compute_bool_op_scalar!($LEFT, $RIGHT, $OP, BooleanArray),
// Any other data type is an internal error rather than a panic.
other => Err(DataFusionError::Internal(format!(
"Data type {:?} not supported for scalar operation '{}' on dyn array",
other, stringify!($OP)
))),
};
// Always `Some`: failures travel in the inner `Result`.
Some(result)
}};
}

/// The binary_array_op macro includes types that extend beyond the primitive,
/// such as Utf8 strings.
#[macro_export]
Expand Down Expand Up @@ -1383,6 +1321,20 @@ macro_rules! binary_array_op_dyn_scalar {
}}
}

/// Tests every element of `lhs` for equality with the scalar `rhs`.
///
/// Exposed crate-wide because other kernels reuse it.
///
/// # Errors
/// Returns `DataFusionError::Internal` when the dispatch macro produces no
/// result for this array / scalar combination. An error produced by the
/// dispatch itself is returned unchanged.
pub(crate) fn array_eq_scalar(lhs: &dyn Array, rhs: &ScalarValue) -> Result<ArrayRef> {
    match binary_array_op_dyn_scalar!(lhs, rhs.clone(), eq, &DataType::Boolean) {
        // The macro already yields a `Result`; hand it back as-is.
        Some(outcome) => outcome,
        None => Err(DataFusionError::Internal(format!(
            "Data type {:?} and scalar {:?} not supported for array_eq_scalar",
            lhs.data_type(),
            rhs.get_datatype()
        ))),
    }
}

impl BinaryExpr {
/// Evaluate the expression if the left input is an array and
/// the right input is a literal - use scalar operations
Expand Down Expand Up @@ -1631,10 +1583,6 @@ fn is_not_distinct_from_null(
make_boolean_array(length, true)
}

/// Equality kernel for two `NullArray`s.
///
/// Produces a `BooleanArray` with one entry per element of `left`, every
/// entry being null.
///
/// `_right` is never read; presumably it exists so the signature lines up
/// with the other two-array comparison kernels (confirm against callers).
pub fn eq_null(left: &NullArray, _right: &NullArray) -> Result<BooleanArray> {
    // A `Range` is already an iterator, so no `.into_iter()` is needed.
    Ok((0..left.len()).map(|_| None).collect())
}

/// Builds a `BooleanArray` of `length` entries, each set to `Some(value)`.
///
/// The result is always `Ok`; the `Result` wrapper lets callers return it
/// directly from fallible kernels.
fn make_boolean_array(length: usize, value: bool) -> Result<BooleanArray> {
    // A `Range` is already an iterator, so no `.into_iter()` is needed.
    Ok((0..length).map(|_| Some(value)).collect())
}
Expand Down
18 changes: 8 additions & 10 deletions datafusion/physical-expr/src/expressions/nullif.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,17 @@

use std::sync::Arc;

use crate::expressions::binary::{eq_decimal, eq_decimal_scalar, eq_null};
use arrow::array::Array;
use arrow::array::*;
use arrow::compute::eq_dyn;
use arrow::compute::kernels::boolean::nullif;
use arrow::compute::kernels::comparison::{
eq, eq_bool, eq_bool_scalar, eq_scalar, eq_utf8, eq_utf8_scalar,
};
use arrow::datatypes::{DataType, TimeUnit};
use arrow::datatypes::DataType;
use cube_ext::nullif_func_str;
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;

use super::binary::array_eq_scalar;

/// Invoke a compute kernel on a primitive array and a Boolean Array
macro_rules! compute_bool_array_op {
($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
Expand Down Expand Up @@ -88,18 +86,18 @@ pub fn nullif_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {

match (lhs, rhs) {
(ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => {
let cond_array = binary_array_op_scalar!(lhs, rhs.clone(), eq).unwrap()?;
let cond_array = array_eq_scalar(lhs, rhs)?;

let array = primitive_bool_array_op!(lhs, *cond_array, nullif)?;

Ok(ColumnarValue::Array(array))
}
(ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)) => {
// Get args0 == args1 evaluated and produce a boolean array
let cond_array = binary_array_op!(lhs, rhs, eq)?;
let cond_array = eq_dyn(lhs, rhs)?;

// Now, invoke nullif on the result
let array = primitive_bool_array_op!(lhs, *cond_array, nullif)?;
let array = primitive_bool_array_op!(lhs, cond_array, nullif)?;
Ok(ColumnarValue::Array(array))
}
_ => Err(DataFusionError::NotImplemented(
Expand Down Expand Up @@ -130,7 +128,7 @@ pub static SUPPORTED_NULLIF_TYPES: &[DataType] = &[
#[cfg(test)]
mod tests {
use super::*;
use datafusion_common::Result;
use datafusion_common::{Result, ScalarValue};

#[test]
fn nullif_int32() -> Result<()> {
Expand Down

0 comments on commit 5be64ab

Please sign in to comment.