Skip to content

Commit

Permalink
Numeric, String, Boolean comparisons with literal NULL (apache#2481)
Browse files Browse the repository at this point in the history
  • Loading branch information
WinkerDu authored and MazterQyou committed May 11, 2023
1 parent a809c5e commit 1a26be2
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 36 deletions.
3 changes: 2 additions & 1 deletion datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1793,7 +1793,8 @@ mod tests {
let bool_expr = col("c1").eq(col("c1"));
let cases = vec![
// utf8 < u32
col("c1").lt(col("c2")),
// NOTE(cubesql): valid
//col("c1").lt(col("c2")),
// utf8 AND utf8
col("c1").and(col("c1")),
// u8 AND u8
Expand Down
120 changes: 120 additions & 0 deletions datafusion/core/tests/sql/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1280,3 +1280,123 @@ async fn nested_subquery() -> Result<()> {
assert_batches_eq!(expected, &actual);
Ok(())
}

/// Comparing a numeric, string, or boolean column with a literal `NULL` must
/// plan and execute, producing one NULL result per input row.
///
/// The literal shows up as `Utf8(NULL)` in every column header below. NULL
/// results are rendered as empty cells, so each data row is blank and padded
/// to the width of its header; `assert_batches_eq!` compares the rendered
/// lines exactly, which is why the padding must match.
#[tokio::test]
async fn comparisons_with_null() -> Result<()> {
    let ctx = SessionContext::new();

    // 1. Numeric (integer) comparison with NULL, one query per operator
    let sql = "select column1 < NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t";
    let actual = execute_to_batches(&ctx, sql).await;
    let expected = vec![
        "+-------------------------+",
        "| t.column1 Lt Utf8(NULL) |",
        "+-------------------------+",
        "|                         |",
        "|                         |",
        "+-------------------------+",
    ];
    assert_batches_eq!(expected, &actual);

    let sql =
        "select column1 <= NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t";
    let actual = execute_to_batches(&ctx, sql).await;
    let expected = vec![
        "+---------------------------+",
        "| t.column1 LtEq Utf8(NULL) |",
        "+---------------------------+",
        "|                           |",
        "|                           |",
        "+---------------------------+",
    ];
    assert_batches_eq!(expected, &actual);

    let sql = "select column1 > NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t";
    let actual = execute_to_batches(&ctx, sql).await;
    let expected = vec![
        "+-------------------------+",
        "| t.column1 Gt Utf8(NULL) |",
        "+-------------------------+",
        "|                         |",
        "|                         |",
        "+-------------------------+",
    ];
    assert_batches_eq!(expected, &actual);

    let sql =
        "select column1 >= NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t";
    let actual = execute_to_batches(&ctx, sql).await;
    let expected = vec![
        "+---------------------------+",
        "| t.column1 GtEq Utf8(NULL) |",
        "+---------------------------+",
        "|                           |",
        "|                           |",
        "+---------------------------+",
    ];
    assert_batches_eq!(expected, &actual);

    let sql = "select column1 = NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t";
    let actual = execute_to_batches(&ctx, sql).await;
    let expected = vec![
        "+-------------------------+",
        "| t.column1 Eq Utf8(NULL) |",
        "+-------------------------+",
        "|                         |",
        "|                         |",
        "+-------------------------+",
    ];
    assert_batches_eq!(expected, &actual);

    let sql =
        "select column1 != NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t";
    let actual = execute_to_batches(&ctx, sql).await;
    let expected = vec![
        "+----------------------------+",
        "| t.column1 NotEq Utf8(NULL) |",
        "+----------------------------+",
        "|                            |",
        "|                            |",
        "+----------------------------+",
    ];
    assert_batches_eq!(expected, &actual);

    // 1.1 Float value comparison with NULL
    let sql = "select column3 < NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t";
    let actual = execute_to_batches(&ctx, sql).await;
    let expected = vec![
        "+-------------------------+",
        "| t.column3 Lt Utf8(NULL) |",
        "+-------------------------+",
        "|                         |",
        "|                         |",
        "+-------------------------+",
    ];
    assert_batches_eq!(expected, &actual);

    // 2. String comparison with NULL
    let sql = "select column2 < NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t";
    let actual = execute_to_batches(&ctx, sql).await;
    let expected = vec![
        "+-------------------------+",
        "| t.column2 Lt Utf8(NULL) |",
        "+-------------------------+",
        "|                         |",
        "|                         |",
        "+-------------------------+",
    ];
    assert_batches_eq!(expected, &actual);

    // 3. Boolean comparison with NULL
    let sql = "select column1 < NULL from (VALUES (true), (false)) as t";
    let actual = execute_to_batches(&ctx, sql).await;
    let expected = vec![
        "+-------------------------+",
        "| t.column1 Lt Utf8(NULL) |",
        "+-------------------------+",
        "|                         |",
        "|                         |",
        "+-------------------------+",
    ];
    assert_batches_eq!(expected, &actual);

    Ok(())
}
14 changes: 14 additions & 0 deletions datafusion/physical-expr/src/coercion_rule/binary_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,10 @@ pub fn comparison_eq_coercion(
.or_else(|| string_coercion(lhs_type, rhs_type))
.or_else(|| null_coercion(lhs_type, rhs_type))
.or_else(|| string_numeric_coercion(lhs_type, rhs_type))
.or_else(|| string_boolean_coercion(lhs_type, rhs_type))
}

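// NOTE: NULL hack! An untyped NULL literal reaches coercion as `Utf8(NULL)`,
// so a string/numeric pair presumably has to coerce for a numeric column to be
// comparable with a literal NULL.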
fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
Expand All @@ -135,6 +137,15 @@ fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
}
}

/// Coercion rule for comparing a string type with `Boolean`.
///
/// When exactly one side is `Boolean` and the other is `Utf8` or `LargeUtf8`,
/// the result is that string type; every other pairing yields `None`.
fn string_boolean_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
    use arrow::datatypes::DataType::*;

    // Pick out whichever side is paired with the Boolean operand.
    let counterpart = match (lhs_type, rhs_type) {
        (Boolean, other) | (other, Boolean) => other,
        _ => return None,
    };

    // Only the two string types coerce; Boolean/Boolean and anything else fall through.
    match counterpart {
        Utf8 => Some(Utf8),
        LargeUtf8 => Some(LargeUtf8),
        _ => None,
    }
}

fn comparison_order_coercion(
lhs_type: &DataType,
rhs_type: &DataType,
Expand All @@ -149,6 +160,9 @@ fn comparison_order_coercion(
.or_else(|| string_coercion(lhs_type, rhs_type))
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
.or_else(|| temporal_coercion(lhs_type, rhs_type))
.or_else(|| null_coercion(lhs_type, rhs_type))
.or_else(|| string_numeric_coercion(lhs_type, rhs_type))
.or_else(|| string_boolean_coercion(lhs_type, rhs_type))
}

fn comparison_binary_numeric_coercion(
Expand Down
80 changes: 45 additions & 35 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -844,17 +844,15 @@ macro_rules! compute_utf8_op_scalar {

/// Invoke a compute kernel on a data array and a scalar value
macro_rules! compute_utf8_op_dyn_scalar {
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{
if let Some(string_value) = $RIGHT {
Ok(Arc::new(paste::expr! {[<$OP _dyn_utf8_scalar>]}(
$LEFT,
&string_value,
)?))
} else {
Err(DataFusionError::Internal(format!(
"compute_utf8_op_scalar for '{}' failed with literal 'none' value",
stringify!($OP),
)))
// when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE
Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len())))
}
}};
}
Expand All @@ -878,7 +876,7 @@ macro_rules! compute_bool_op_scalar {

/// Invoke a compute kernel on a boolean data array and a scalar value
macro_rules! compute_bool_op_dyn_scalar {
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{
// generate the scalar function name, such as lt_dyn_bool_scalar, from the $OP parameter
// (which could have a value of lt) and the suffix _scalar
if let Some(b) = $RIGHT {
Expand All @@ -887,10 +885,8 @@ macro_rules! compute_bool_op_dyn_scalar {
b,
)?))
} else {
Err(DataFusionError::Internal(format!(
"compute_utf8_op_scalar for '{}' failed with literal 'none' value",
stringify!($OP),
)))
// when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE
Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len())))
}
}};
}
Expand Down Expand Up @@ -938,8 +934,9 @@ macro_rules! compute_op_scalar {

/// Invoke a dyn compute kernel on a data array and a scalar value
/// LEFT is a Primitive or Dictionary array of numeric values, RIGHT is a scalar value
/// OP_TYPE is the return type of scalar function
macro_rules! compute_op_dyn_scalar {
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{
// generate the scalar function name, such as lt_dyn_scalar, from the $OP parameter
// (which could have a value of lt_dyn) and the suffix _scalar
if let Some(value) = $RIGHT {
Expand All @@ -948,10 +945,8 @@ macro_rules! compute_op_dyn_scalar {
value,
)?))
} else {
Err(DataFusionError::Internal(format!(
"compute_utf8_op_scalar for '{}' failed with literal 'none' value",
stringify!($OP),
)))
// when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE
Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len())))
}
}};
}
Expand Down Expand Up @@ -1359,22 +1354,22 @@ impl PhysicalExpr for BinaryExpr {
/// such as Utf8 strings.
#[macro_export]
macro_rules! binary_array_op_dyn_scalar {
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{
let result: Result<Arc<dyn Array>> = match $RIGHT {
ScalarValue::Boolean(b) => compute_bool_op_dyn_scalar!($LEFT, b, $OP),
ScalarValue::Boolean(b) => compute_bool_op_dyn_scalar!($LEFT, b, $OP, $OP_TYPE),
ScalarValue::Decimal128(..) => compute_decimal_op_scalar!($LEFT, $RIGHT, $OP, DecimalArray),
ScalarValue::Utf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, $OP),
ScalarValue::LargeUtf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, $OP),
ScalarValue::Int8(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
ScalarValue::Int16(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
ScalarValue::Int32(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
ScalarValue::Int64(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
ScalarValue::UInt8(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
ScalarValue::UInt16(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
ScalarValue::UInt32(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
ScalarValue::UInt64(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
ScalarValue::Float32(_) => compute_op_scalar!($LEFT, $RIGHT, $OP, Float32Array),
ScalarValue::Float64(_) => compute_op_scalar!($LEFT, $RIGHT, $OP, Float64Array),
ScalarValue::Utf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
ScalarValue::LargeUtf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
ScalarValue::Int8(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
ScalarValue::Int16(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
ScalarValue::Int32(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
ScalarValue::Int64(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
ScalarValue::UInt8(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
ScalarValue::UInt16(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
ScalarValue::UInt32(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
ScalarValue::UInt64(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
ScalarValue::Float32(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
ScalarValue::Float64(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
ScalarValue::Date32(_) => compute_op_scalar!($LEFT, $RIGHT, $OP, Date32Array),
ScalarValue::Date64(_) => compute_op_scalar!($LEFT, $RIGHT, $OP, Date64Array),
ScalarValue::TimestampSecond(..) => compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampSecondArray),
Expand All @@ -1397,22 +1392,37 @@ impl BinaryExpr {
) -> Result<Option<Result<ArrayRef>>> {
let scalar_result = match &self.op {
Operator::Lt => {
binary_array_op_dyn_scalar!(array, scalar.clone(), lt)
binary_array_op_dyn_scalar!(array, scalar.clone(), lt, &DataType::Boolean)
}
Operator::LtEq => {
binary_array_op_dyn_scalar!(array, scalar.clone(), lt_eq)
binary_array_op_dyn_scalar!(
array,
scalar.clone(),
lt_eq,
&DataType::Boolean
)
}
Operator::Gt => {
binary_array_op_dyn_scalar!(array, scalar.clone(), gt)
binary_array_op_dyn_scalar!(array, scalar.clone(), gt, &DataType::Boolean)
}
Operator::GtEq => {
binary_array_op_dyn_scalar!(array, scalar.clone(), gt_eq)
binary_array_op_dyn_scalar!(
array,
scalar.clone(),
gt_eq,
&DataType::Boolean
)
}
Operator::Eq => {
binary_array_op_dyn_scalar!(array, scalar.clone(), eq)
binary_array_op_dyn_scalar!(array, scalar.clone(), eq, &DataType::Boolean)
}
Operator::NotEq => {
binary_array_op_dyn_scalar!(array, scalar.clone(), neq)
binary_array_op_dyn_scalar!(
array,
scalar.clone(),
neq,
&DataType::Boolean
)
}
Operator::Like => {
binary_string_array_op_scalar!(array, scalar.clone(), like)
Expand Down

0 comments on commit 1a26be2

Please sign in to comment.