diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index d74330aa58fa7..20201701066dd 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -39,6 +39,8 @@ use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; /// This is the single-valued counter-part of arrow’s `Array`. #[derive(Clone)] pub enum ScalarValue { + /// represents `DataType::Null` (castable to/from any other type) + Null, /// true or false value Boolean(Option), /// 32bit float @@ -170,6 +172,8 @@ impl PartialEq for ScalarValue { (IntervalMonthDayNano(_), _) => false, (Struct(v1, t1), Struct(v2, t2)) => v1.eq(v2) && t1.eq(t2), (Struct(_, _), _) => false, + (Null, Null) => true, + (Null, _) => false, } } } @@ -270,6 +274,8 @@ impl PartialOrd for ScalarValue { } } (Struct(_, _), _) => None, + (Null, Null) => Some(Ordering::Equal), + (Null, _) => None, } } } @@ -325,6 +331,8 @@ impl std::hash::Hash for ScalarValue { v.hash(state); t.hash(state); } + // stable hash for Null value + Null => 1.hash(state), } } } @@ -594,6 +602,7 @@ impl ScalarValue { DataType::Interval(IntervalUnit::MonthDayNano) } ScalarValue::Struct(_, fields) => DataType::Struct(fields.as_ref().clone()), + ScalarValue::Null => DataType::Null, } } @@ -623,7 +632,8 @@ impl ScalarValue { pub fn is_null(&self) -> bool { matches!( *self, - ScalarValue::Boolean(None) + ScalarValue::Null + | ScalarValue::Boolean(None) | ScalarValue::UInt8(None) | ScalarValue::UInt16(None) | ScalarValue::UInt32(None) @@ -844,6 +854,7 @@ impl ScalarValue { ScalarValue::iter_to_decimal_array(scalars, precision, scale)?; Arc::new(decimal_array) } + DataType::Null => ScalarValue::iter_to_null_array(scalars), DataType::Boolean => build_array_primitive!(BooleanArray, Boolean), DataType::Float32 => build_array_primitive!(Float32Array, Float32), DataType::Float64 => build_array_primitive!(Float64Array, Float64), @@ -976,6 +987,17 @@ impl ScalarValue { Ok(array) } + fn iter_to_null_array(scalars: impl 
IntoIterator) -> ArrayRef { + let length = + scalars + .into_iter() + .fold(0usize, |r, element: ScalarValue| match element { + ScalarValue::Null => r + 1, + _ => unreachable!(), + }); + new_null_array(&DataType::Null, length) + } + fn iter_to_decimal_array( scalars: impl IntoIterator, precision: &usize, @@ -1249,6 +1271,7 @@ impl ScalarValue { Arc::new(StructArray::from(field_values)) } }, + ScalarValue::Null => new_null_array(&DataType::Null, size), } } @@ -1274,6 +1297,7 @@ impl ScalarValue { } Ok(match array.data_type() { + DataType::Null => ScalarValue::Null, DataType::Decimal(precision, scale) => { ScalarValue::get_decimal_value_from_array(array, index, precision, scale) } @@ -1519,6 +1543,7 @@ impl ScalarValue { eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val) } ScalarValue::Struct(_, _) => unimplemented!(), + ScalarValue::Null => array.data().is_null(index), } } @@ -1740,6 +1765,7 @@ impl TryFrom<&DataType> for ScalarValue { DataType::Struct(fields) => { ScalarValue::Struct(None, Box::new(fields.clone())) } + DataType::Null => ScalarValue::Null, _ => { return Err(DataFusionError::NotImplemented(format!( "Can't create a scalar from data_type \"{:?}\"", @@ -1832,6 +1858,7 @@ impl fmt::Display for ScalarValue { )?, None => write!(f, "NULL")?, }, + ScalarValue::Null => write!(f, "NULL")?, }; Ok(()) } @@ -1899,6 +1926,7 @@ impl fmt::Debug for ScalarValue { None => write!(f, "Struct(NULL)"), } } + ScalarValue::Null => write!(f, "NULL"), } } } diff --git a/datafusion/core/src/logical_plan/builder.rs b/datafusion/core/src/logical_plan/builder.rs index c83857f1e47de..cd6df27a08ed6 100644 --- a/datafusion/core/src/logical_plan/builder.rs +++ b/datafusion/core/src/logical_plan/builder.rs @@ -154,7 +154,7 @@ impl LogicalPlanBuilder { .iter() .enumerate() .map(|(j, expr)| { - if let Expr::Literal(ScalarValue::Utf8(None)) = expr { + if let Expr::Literal(ScalarValue::Null) = expr { nulls.push((i, j)); Ok(field_types[j].clone()) } else { diff --git 
a/datafusion/core/src/physical_plan/functions.rs b/datafusion/core/src/physical_plan/functions.rs index 4127afe79df3e..cdea2c7c16ea0 100644 --- a/datafusion/core/src/physical_plan/functions.rs +++ b/datafusion/core/src/physical_plan/functions.rs @@ -63,6 +63,7 @@ macro_rules! make_utf8_to_return_type { Ok(match arg_type { DataType::LargeUtf8 => $largeUtf8Type, DataType::Utf8 => $utf8Type, + DataType::Null => DataType::Null, _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal(format!( diff --git a/datafusion/core/src/physical_plan/hash_join.rs b/datafusion/core/src/physical_plan/hash_join.rs index 49783c6f326ab..55ad1c947730e 100644 --- a/datafusion/core/src/physical_plan/hash_join.rs +++ b/datafusion/core/src/physical_plan/hash_join.rs @@ -837,7 +837,11 @@ fn equal_rows( .iter() .zip(right_arrays) .all(|(l, r)| match l.data_type() { - DataType::Null => true, + DataType::Null => { + // lhs and rhs are both `DataType::Null`, so the equality result + // is dependent on `null_equals_null` + null_equals_null + } DataType::Boolean => { equal_rows_elem!(BooleanArray, l, r, left, right, null_equals_null) } diff --git a/datafusion/core/src/physical_plan/hash_utils.rs b/datafusion/core/src/physical_plan/hash_utils.rs index 4e503b19e7bf3..2ca1fa3df9d13 100644 --- a/datafusion/core/src/physical_plan/hash_utils.rs +++ b/datafusion/core/src/physical_plan/hash_utils.rs @@ -39,6 +39,19 @@ fn combine_hashes(l: u64, r: u64) -> u64 { hash.wrapping_mul(37).wrapping_add(r) } +fn hash_null(random_state: &RandomState, hashes_buffer: &'_ mut [u64], mul_col: bool) { + if mul_col { + hashes_buffer.iter_mut().for_each(|hash| { + // stable hash for null value + *hash = combine_hashes(i128::get_hash(&1, random_state), *hash); + }) + } else { + hashes_buffer.iter_mut().for_each(|hash| { + *hash = i128::get_hash(&1, random_state); + }) + } +} + fn hash_decimal128<'a>( array: &ArrayRef, random_state: &RandomState, @@ -284,6 +297,9 @@ pub 
fn create_hashes<'a>( for col in arrays { match col.data_type() { + DataType::Null => { + hash_null(random_state, hashes_buffer, multi_col); + } DataType::Decimal(_, _) => { hash_decimal128(col, random_state, hashes_buffer, multi_col); } diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs index 47eadf26deafa..ca9732fa9fba5 100644 --- a/datafusion/core/src/sql/planner.rs +++ b/datafusion/core/src/sql/planner.rs @@ -1609,7 +1609,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Value(Value::Number(n, _)) => parse_sql_number(&n), SQLExpr::Value(Value::SingleQuotedString(s)) => Ok(lit(s)), SQLExpr::Value(Value::Null) => { - Ok(Expr::Literal(ScalarValue::Utf8(None))) + Ok(Expr::Literal(ScalarValue::Null)) } SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)), SQLExpr::UnaryOp { op, expr } => { @@ -1635,7 +1635,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Value(Value::SingleQuotedString(ref s)) => Ok(lit(s.clone())), SQLExpr::Value(Value::EscapedStringLiteral(ref s)) => Ok(lit(s.clone())), SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)), - SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Utf8(None))), + SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Null)), SQLExpr::Extract { field, expr } => Ok(Expr::ScalarFunction { fun: functions::BuiltinScalarFunction::DatePart, args: vec![ diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs index 644148445757a..9df7b09d151f9 100644 --- a/datafusion/core/tests/sql/expr.rs +++ b/datafusion/core/tests/sql/expr.rs @@ -411,11 +411,11 @@ async fn test_string_concat_operator() -> Result<()> { let sql = "SELECT 'aa' || NULL || 'd'"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+---------------------------------------+", - "| Utf8(\"aa\") || Utf8(NULL) || Utf8(\"d\") |", - "+---------------------------------------+", - "| |", - "+---------------------------------------+", + 
"+---------------------------------+", + "| Utf8(\"aa\") || NULL || Utf8(\"d\") |", + "+---------------------------------+", + "| |", + "+---------------------------------+", ]; assert_batches_eq!(expected, &actual); @@ -433,6 +433,45 @@ async fn test_string_concat_operator() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_not_expressions() -> Result<()> { + let ctx = SessionContext::new(); + + let sql = "SELECT not(true), not(false)"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-------------------+--------------------+", + "| NOT Boolean(true) | NOT Boolean(false) |", + "+-------------------+--------------------+", + "| false | true |", + "+-------------------+--------------------+", + ]; + assert_batches_eq!(expected, &actual); + + let sql = "SELECT null, not(null)"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------+----------+", + "| NULL | NOT NULL |", + "+------+----------+", + "| | |", + "+------+----------+", + ]; + assert_batches_eq!(expected, &actual); + + let sql = "SELECT NOT('hi')"; + let result = plan_and_collect(&ctx, sql).await; + match result { + Ok(_) => panic!("expected error"), + Err(e) => { + assert_contains!(e.to_string(), + "NOT 'Literal { value: Utf8(\"hi\") }' can't be evaluated because the expression's type is Utf8, not boolean or NULL" + ); + } + } + Ok(()) +} + #[tokio::test] async fn test_boolean_expressions() -> Result<()> { test_expression!("true", "true"); diff --git a/datafusion/core/tests/sql/functions.rs b/datafusion/core/tests/sql/functions.rs index 226bb8159d788..ae86aeb174589 100644 --- a/datafusion/core/tests/sql/functions.rs +++ b/datafusion/core/tests/sql/functions.rs @@ -197,11 +197,11 @@ async fn coalesce_static_value_with_null() -> Result<()> { let sql = "SELECT COALESCE(NULL, 'test')"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+-----------------------------------+", - "| coalesce(Utf8(NULL),Utf8(\"test\")) |", - 
"+-----------------------------------+", - "| test |", - "+-----------------------------------+", + "+-----------------------------+", + "| coalesce(NULL,Utf8(\"test\")) |", + "+-----------------------------+", + "| test |", + "+-----------------------------+", ]; assert_batches_eq!(expected, &actual); Ok(()) diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 4859041579d09..007de1bfe7517 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -829,7 +829,11 @@ async fn inner_join_nulls() { let sql = "SELECT * FROM (SELECT null AS id1) t1 INNER JOIN (SELECT null AS id2) t2 ON id1 = id2"; - let expected = vec!["++", "++"]; + #[rustfmt::skip] + let expected = vec![ + "++", + "++", + ]; let ctx = create_join_context_qualified().unwrap(); let actual = execute_to_batches(&ctx, sql).await; diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index 09c0fc99bb998..1af11a494ab0a 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -398,15 +398,37 @@ async fn select_distinct_from() { 1 IS NOT DISTINCT FROM CAST(NULL as INT) as c, 1 IS NOT DISTINCT FROM 1 as d, NULL IS DISTINCT FROM NULL as e, - NULL IS NOT DISTINCT FROM NULL as f + NULL IS NOT DISTINCT FROM NULL as f, + NULL is DISTINCT FROM 1 as g, + NULL is NOT DISTINCT FROM 1 as h "; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+------+-------+-------+------+-------+------+", - "| a | b | c | d | e | f |", - "+------+-------+-------+------+-------+------+", - "| true | false | false | true | false | true |", - "+------+-------+-------+------+-------+------+", + "+------+-------+-------+------+-------+------+------+-------+", + "| a | b | c | d | e | f | g | h |", + "+------+-------+-------+------+-------+------+------+-------+", + "| true | false | false | true | false | true | true | false |", + 
"+------+-------+-------+------+-------+------+------+-------+", + ]; + assert_batches_eq!(expected, &actual); + + let sql = "select + NULL IS DISTINCT FROM NULL as a, + NULL IS NOT DISTINCT FROM NULL as b, + NULL is DISTINCT FROM 1 as c, + NULL is NOT DISTINCT FROM 1 as d, + 1 IS DISTINCT FROM CAST(NULL as INT) as e, + 1 IS DISTINCT FROM 1 as f, + 1 IS NOT DISTINCT FROM CAST(NULL as INT) as g, + 1 IS NOT DISTINCT FROM 1 as h + "; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-------+------+------+-------+------+-------+-------+------+", + "| a | b | c | d | e | f | g | h |", + "+-------+------+------+-------+------+-------+-------+------+", + "| false | true | true | false | true | false | false | true |", + "+-------+------+------+-------+------+-------+-------+------+", ]; assert_batches_eq!(expected, &actual); } diff --git a/datafusion/expr/src/type_coercion.rs b/datafusion/expr/src/type_coercion.rs new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/datafusion/physical-expr/src/coercion_rule/binary_rule.rs b/datafusion/physical-expr/src/coercion_rule/binary_rule.rs index 8ab5b6309edab..53e8ad0145151 100644 --- a/datafusion/physical-expr/src/coercion_rule/binary_rule.rs +++ b/datafusion/physical-expr/src/coercion_rule/binary_rule.rs @@ -538,6 +538,7 @@ fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { numerical_coercion(lhs_type, rhs_type) .or_else(|| dictionary_coercion(lhs_type, rhs_type)) .or_else(|| temporal_coercion(lhs_type, rhs_type)) + .or_else(|| null_coercion(lhs_type, rhs_type)) } /// Coercion rule for interval @@ -557,6 +558,28 @@ pub fn interval_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { + match (lhs_type, rhs_type) { + (DataType::Null, _) => { + if can_cast_types(&DataType::Null, rhs_type) { + Some(rhs_type.clone()) + } else { + None + } + } + (_, DataType::Null) => { + if can_cast_types(&DataType::Null, lhs_type) { + Some(lhs_type.clone()) + } else { + 
None + } + } + _ => None, + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index d93d772547d5f..3a95018b4b208 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -805,6 +805,20 @@ macro_rules! compute_decimal_op { }}; } +macro_rules! compute_null_op { + ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ + let ll = $LEFT + .as_any() + .downcast_ref::<$DT>() + .expect("compute_op failed to downcast array"); + let rr = $RIGHT + .as_any() + .downcast_ref::<$DT>() + .expect("compute_op failed to downcast array"); + Ok(Arc::new(paste::expr! {[<$OP _null>]}(&ll, &rr)?)) + }}; +} + /// Invoke a compute kernel on a pair of binary data arrays macro_rules! compute_utf8_op { ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ @@ -1109,6 +1123,7 @@ macro_rules! binary_array_op_scalar { macro_rules! binary_array_op { ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ match $LEFT.data_type() { + DataType::Null => compute_null_op!($LEFT, $RIGHT, $OP, NullArray), DataType::Decimal(_,_) => compute_decimal_op!($LEFT, $RIGHT, $OP, DecimalArray), DataType::Int8 => compute_op!($LEFT, $RIGHT, $OP, Int8Array), DataType::Int16 => compute_op!($LEFT, $RIGHT, $OP, Int16Array), @@ -1524,7 +1539,16 @@ impl BinaryExpr { Operator::GtEq => gt_eq_dyn(&left, &right), Operator::Eq => eq_dyn(&left, &right), Operator::NotEq => neq_dyn(&left, &right), - Operator::IsDistinctFrom => binary_array_op!(left, right, is_distinct_from), + Operator::IsDistinctFrom => { + match (left_data_type, right_data_type) { + // exchange lhs and rhs when lhs is Null, since `binary_array_op` is + // always try to down cast array according to $LEFT expression. 
+ (DataType::Null, _) => { + binary_array_op!(right, left, is_distinct_from) + } + _ => binary_array_op!(left, right, is_distinct_from), + } + } Operator::IsNotDistinctFrom => { binary_array_op!(left, right, is_not_distinct_from) } @@ -1601,6 +1625,27 @@ fn is_distinct_from_utf8( .collect()) } +fn is_distinct_from_null(left: &NullArray, _right: &NullArray) -> Result { + let length = left.len(); + make_boolean_array(length, false) +} + +fn is_not_distinct_from_null( + left: &NullArray, + _right: &NullArray, +) -> Result { + let length = left.len(); + make_boolean_array(length, true) +} + +pub fn eq_null(left: &NullArray, _right: &NullArray) -> Result { + Ok((0..left.len()).into_iter().map(|_| None).collect()) +} + +fn make_boolean_array(length: usize, value: bool) -> Result { + Ok((0..length).into_iter().map(|_| Some(value)).collect()) +} + fn is_not_distinct_from( left: &PrimitiveArray, right: &PrimitiveArray, diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 2aee0d87dbde3..92a9d64c14205 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -451,6 +451,10 @@ impl PhysicalExpr for InListExpr { DataType::LargeUtf8 => { self.compare_utf8::(array, list_values, self.negated) } + DataType::Null => { + let null_array = new_null_array(&DataType::Boolean, array.len()); + Ok(ColumnarValue::Array(Arc::new(null_array))) + } datatype => Result::Err(DataFusionError::NotImplemented(format!( "InList does not support datatype {:?}.", datatype diff --git a/datafusion/physical-expr/src/expressions/nullif.rs b/datafusion/physical-expr/src/expressions/nullif.rs index f3f72096600f2..7bcfa47ce815b 100644 --- a/datafusion/physical-expr/src/expressions/nullif.rs +++ b/datafusion/physical-expr/src/expressions/nullif.rs @@ -17,7 +17,7 @@ use std::sync::Arc; -use crate::expressions::binary::{eq_decimal, eq_decimal_scalar}; +use 
crate::expressions::binary::{eq_decimal, eq_decimal_scalar, eq_null}; use arrow::array::Array; use arrow::array::*; use arrow::compute::kernels::boolean::nullif;