From 24036e92903ba4eda1317acd4200ac0786de88e0 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 29 Aug 2022 13:02:07 -0700 Subject: [PATCH] Support SQL-compliant NaN ordering for DictionaryArray and non-DictionaryArray (#2600) --- arrow/src/compute/kernels/comparison.rs | 544 +++++++++++++++++++++++- 1 file changed, 528 insertions(+), 16 deletions(-) diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index c3ed660b2bfe..1c346a3114a1 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -2135,7 +2135,7 @@ macro_rules! typed_dict_non_dict_cmp { #[cfg(feature = "dyn_cmp_dict")] macro_rules! typed_cmp_dict_non_dict { - ($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr) => {{ + ($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr, $OP_FLOAT: expr) => {{ match ($LEFT.data_type(), $RIGHT.data_type()) { (DataType::Dictionary(left_key_type, left_value_type), right_type) => { match (left_value_type.as_ref(), right_type) { @@ -2164,10 +2164,10 @@ macro_rules! typed_cmp_dict_non_dict { typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), UInt64Type, $OP_BOOL, $OP) } (DataType::Float32, DataType::Float32) => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Float32Type, $OP_BOOL, $OP) + typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Float32Type, $OP_BOOL, $OP_FLOAT) } (DataType::Float64, DataType::Float64) => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Float64Type, $OP_BOOL, $OP) + typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Float64Type, $OP_BOOL, $OP_FLOAT) } (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( "Comparing dictionary array of type {} with array of type {} is not yet implemented", @@ -2186,7 +2186,7 @@ macro_rules! typed_cmp_dict_non_dict { #[cfg(not(feature = "dyn_cmp_dict"))] macro_rules! 
typed_cmp_dict_non_dict { - ($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr) => {{ + ($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr, $OP_FLOAT: expr) => {{ Err(ArrowError::CastError(format!( "Comparing dictionary array of type {} with array of type {} requires \"dyn_cmp_dict\" feature", $LEFT.data_type(), $RIGHT.data_type() @@ -2670,10 +2670,54 @@ pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { DataType::Dictionary(_, _) if !matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_cmp_dict_non_dict!(left, right, |a, b| a == b, |a, b| a == b) + #[cfg(not(feature = "nan_ordering"))] + return typed_cmp_dict_non_dict!( + left, + right, + |a, b| a == b, + |a, b| a == b, + |a, b| a == b + ); + + #[cfg(feature = "nan_ordering")] + return typed_cmp_dict_non_dict!( + left, + right, + |a, b| a == b, + |a, b| a == b, + |a, b| { + if is_nan(a) && is_nan(b) { + true + } else { + a == b + } + } + ); } _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_cmp_dict_non_dict!(right, left, |a, b| a == b, |a, b| a == b) + #[cfg(not(feature = "nan_ordering"))] + return typed_cmp_dict_non_dict!( + right, + left, + |a, b| a == b, + |a, b| a == b, + |a, b| a == b + ); + + #[cfg(feature = "nan_ordering")] + return typed_cmp_dict_non_dict!( + right, + left, + |a, b| a == b, + |a, b| a == b, + |a, b| { + if is_nan(a) && is_nan(b) { + true + } else { + a == b + } + } + ); } _ => { #[cfg(not(feature = "nan_ordering"))] @@ -2752,10 +2796,54 @@ pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result { DataType::Dictionary(_, _) if !matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_cmp_dict_non_dict!(left, right, |a, b| a != b, |a, b| a != b) + #[cfg(not(feature = "nan_ordering"))] + return typed_cmp_dict_non_dict!( + left, + right, + |a, b| a != b, + |a, b| a != b, + |a, b| a != b + ); + + #[cfg(feature = "nan_ordering")] + return typed_cmp_dict_non_dict!( + left, + right, + |a, b| a != b, + |a, b| a != b, + 
|a, b| { + if is_nan(a) && is_nan(b) { + false + } else { + a != b + } + } + ); } _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_cmp_dict_non_dict!(right, left, |a, b| a != b, |a, b| a != b) + #[cfg(not(feature = "nan_ordering"))] + return typed_cmp_dict_non_dict!( + right, + left, + |a, b| a != b, + |a, b| a != b, + |a, b| a != b + ); + + #[cfg(feature = "nan_ordering")] + return typed_cmp_dict_non_dict!( + right, + left, + |a, b| a != b, + |a, b| a != b, + |a, b| { + if is_nan(a) && is_nan(b) { + false + } else { + a != b + } + } + ); } _ => { #[cfg(not(feature = "nan_ordering"))] @@ -2836,10 +2924,58 @@ pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result { DataType::Dictionary(_, _) if !matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_cmp_dict_non_dict!(left, right, |a, b| a < b, |a, b| a < b) + #[cfg(not(feature = "nan_ordering"))] + return typed_cmp_dict_non_dict!( + left, + right, + |a, b| a < b, + |a, b| a < b, + |a, b| a < b + ); + + #[cfg(feature = "nan_ordering")] + return typed_cmp_dict_non_dict!( + left, + right, + |a, b| a < b, + |a, b| a < b, + |a, b| { + if is_nan(a) { + false + } else if is_nan(b) { + true + } else { + a < b + } + } + ); } _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_cmp_dict_non_dict!(right, left, |a, b| a > b, |a, b| a > b) + #[cfg(not(feature = "nan_ordering"))] + return typed_cmp_dict_non_dict!( + right, + left, + |a, b| a > b, + |a, b| a > b, + |a, b| a > b + ); + + #[cfg(feature = "nan_ordering")] + return typed_cmp_dict_non_dict!( + right, + left, + |a, b| a > b, + |a, b| a > b, + |a, b| { + if is_nan(b) { + false + } else if is_nan(a) { + true + } else { + a > b + } + } + ); } _ => { #[cfg(not(feature = "nan_ordering"))] @@ -2919,10 +3055,54 @@ pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { DataType::Dictionary(_, _) if !matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_cmp_dict_non_dict!(left, right, 
|a, b| a <= b, |a, b| a <= b) + #[cfg(not(feature = "nan_ordering"))] + return typed_cmp_dict_non_dict!( + left, + right, + |a, b| a <= b, + |a, b| a <= b, + |a, b| a <= b + ); + + #[cfg(feature = "nan_ordering")] + return typed_cmp_dict_non_dict!( + left, + right, + |a, b| a <= b, + |a, b| a <= b, + |a, b| { + if is_nan(b) { /* NaN sorts greatest: every value, NaN included, is <= NaN */ + true + } else { /* b is not NaN, so a NaN `a` correctly yields false via IEEE comparison */ + a <= b + } + } + ); } _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_cmp_dict_non_dict!(right, left, |a, b| a >= b, |a, b| a >= b) + #[cfg(not(feature = "nan_ordering"))] + return typed_cmp_dict_non_dict!( + right, + left, + |a, b| a >= b, + |a, b| a >= b, + |a, b| a >= b + ); + + #[cfg(feature = "nan_ordering")] + return typed_cmp_dict_non_dict!( + right, + left, + |a, b| a >= b, + |a, b| a >= b, + |a, b| { + if is_nan(a) { /* operands are swapped here: NaN >= anything, NaN included */ + true + } else { /* a is not NaN, so a NaN `b` correctly yields false via IEEE comparison */ + a >= b + } + } + ); } _ => { #[cfg(not(feature = "nan_ordering"))] @@ -3002,10 +3182,58 @@ pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result { DataType::Dictionary(_, _) if !matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_cmp_dict_non_dict!(left, right, |a, b| a > b, |a, b| a > b) + #[cfg(not(feature = "nan_ordering"))] + return typed_cmp_dict_non_dict!( + left, + right, + |a, b| a > b, + |a, b| a > b, + |a, b| a > b + ); + + #[cfg(feature = "nan_ordering")] + return typed_cmp_dict_non_dict!( + left, + right, + |a, b| a > b, + |a, b| a > b, + |a, b| { + if is_nan(a) { + !is_nan(b) + } else if is_nan(b) { + false + } else { + a > b + } + } + ); } _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_cmp_dict_non_dict!(right, left, |a, b| a < b, |a, b| a < b) + #[cfg(not(feature = "nan_ordering"))] + return typed_cmp_dict_non_dict!( + right, + left, + |a, b| a < b, + |a, b| a < b, + |a, b| a < b + ); + + #[cfg(feature = "nan_ordering")] + return typed_cmp_dict_non_dict!( + right, + left, + |a, b| a < b, + |a, b| a < b, + |a, b| { + if is_nan(b) { + !is_nan(a) + } else if is_nan(a) { + false + } 
else { + a < b + } + } + ); } _ => { #[cfg(not(feature = "nan_ordering"))] @@ -3084,10 +3312,54 @@ pub fn gt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { DataType::Dictionary(_, _) if !matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_cmp_dict_non_dict!(left, right, |a, b| a >= b, |a, b| a >= b) + #[cfg(not(feature = "nan_ordering"))] + return typed_cmp_dict_non_dict!( + left, + right, + |a, b| a >= b, + |a, b| a >= b, + |a, b| a >= b + ); + + #[cfg(feature = "nan_ordering")] + return typed_cmp_dict_non_dict!( + left, + right, + |a, b| a >= b, + |a, b| a >= b, + |a, b| { + if is_nan(a) { + true + } else { + a >= b + } + } + ); } _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_cmp_dict_non_dict!(right, left, |a, b| a <= b, |a, b| a <= b) + #[cfg(not(feature = "nan_ordering"))] + return typed_cmp_dict_non_dict!( + right, + left, + |a, b| a <= b, + |a, b| a <= b, + |a, b| a <= b + ); + + #[cfg(feature = "nan_ordering")] + return typed_cmp_dict_non_dict!( + right, + left, + |a, b| a <= b, + |a, b| a <= b, + |a, b| { + if is_nan(b) { + true + } else { + a <= b + } + } + ); } _ => { #[cfg(not(feature = "nan_ordering"))] @@ -6336,4 +6608,244 @@ mod tests { ), ); } + + #[test] + #[cfg(feature = "dyn_cmp_dict")] + fn test_eq_dyn_neq_dyn_dict_non_dict_float_nan() { + let array1: Float32Array = vec![f32::NAN, 7.0, 8.0, 8.0, 10.0] + .into_iter() + .map(Some) + .collect(); + let values = Float32Array::from(vec![f32::NAN, 8.0, 10.0]); + let keys = Int8Array::from_iter_values([0_i8, 0, 1, 1, 2]); + let array2 = DictionaryArray::try_new(&keys, &values).unwrap(); + + #[cfg(not(feature = "nan_ordering"))] + { + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(true), Some(true), Some(true)], + ); + assert_eq!(eq_dyn(&array1, &array2).unwrap(), expected); + } + #[cfg(feature = "nan_ordering")] + { + let expected = BooleanArray::from( + vec![Some(true), Some(false), Some(true), Some(true), Some(true)], + ); 
+ assert_eq!(eq_dyn(&array1, &array2).unwrap(), expected); + } + + #[cfg(not(feature = "nan_ordering"))] + { + let expected = BooleanArray::from( + vec![Some(true), Some(true), Some(false), Some(false), Some(false)], + ); + assert_eq!(neq_dyn(&array1, &array2).unwrap(), expected); + } + #[cfg(feature = "nan_ordering")] + { + let expected = BooleanArray::from( + vec![Some(false), Some(true), Some(false), Some(false), Some(false)], + ); + assert_eq!(neq_dyn(&array1, &array2).unwrap(), expected); + } + + let array1: Float64Array = vec![f64::NAN, 7.0, 8.0, 8.0, 10.0] + .into_iter() + .map(Some) + .collect(); + let values = Float64Array::from(vec![f64::NAN, 8.0, 10.0]); + let keys = Int8Array::from_iter_values([0_i8, 0, 1, 1, 2]); + let array2 = DictionaryArray::try_new(&keys, &values).unwrap(); + + #[cfg(not(feature = "nan_ordering"))] + { + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(true), Some(true), Some(true)], + ); + assert_eq!(eq_dyn(&array1, &array2).unwrap(), expected); + } + #[cfg(feature = "nan_ordering")] + { + let expected = BooleanArray::from( + vec![Some(true), Some(false), Some(true), Some(true), Some(true)], + ); + assert_eq!(eq_dyn(&array1, &array2).unwrap(), expected); + } + + #[cfg(not(feature = "nan_ordering"))] + { + let expected = BooleanArray::from( + vec![Some(true), Some(true), Some(false), Some(false), Some(false)], + ); + assert_eq!(neq_dyn(&array1, &array2).unwrap(), expected); + } + #[cfg(feature = "nan_ordering")] + { + let expected = BooleanArray::from( + vec![Some(false), Some(true), Some(false), Some(false), Some(false)], + ); + assert_eq!(neq_dyn(&array1, &array2).unwrap(), expected); + } + } + + #[test] + #[cfg(feature = "dyn_cmp_dict")] + fn test_lt_dyn_lt_eq_dyn_dict_non_dict_float_nan() { + let array1: Float32Array = vec![f32::NAN, 7.0, 8.0, 8.0, 11.0, f32::NAN] + .into_iter() + .map(Some) + .collect(); + let values = Float32Array::from(vec![f32::NAN, 8.0, 9.0, 10.0, 1.0]); + let keys = 
Int8Array::from_iter_values([0_i8, 0, 1, 2, 3, 4]); + let array2 = DictionaryArray::try_new(&keys, &values).unwrap(); + + #[cfg(not(feature = "nan_ordering"))] + { + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(true), Some(false), Some(false)], + ); + assert_eq!(lt_dyn(&array1, &array2).unwrap(), expected); + } + #[cfg(feature = "nan_ordering")] + { + let expected = BooleanArray::from( + vec![Some(false), Some(true), Some(false), Some(true), Some(false), Some(false)], + ); + assert_eq!(lt_dyn(&array1, &array2).unwrap(), expected); + } + + #[cfg(not(feature = "nan_ordering"))] + { + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(true), Some(true), Some(false), Some(false)], + ); + assert_eq!(lt_eq_dyn(&array1, &array2).unwrap(), expected); + } + #[cfg(feature = "nan_ordering")] + { /* second entry: 7.0 <= NaN is true because NaN sorts greatest */ + let expected = BooleanArray::from( + vec![Some(true), Some(true), Some(true), Some(true), Some(false), Some(false)], + ); + assert_eq!(lt_eq_dyn(&array1, &array2).unwrap(), expected); + } + + let array1: Float64Array = vec![f64::NAN, 7.0, 8.0, 8.0, 11.0, f64::NAN] + .into_iter() + .map(Some) + .collect(); + let values = Float64Array::from(vec![f64::NAN, 8.0, 9.0, 10.0, 1.0]); + let keys = Int8Array::from_iter_values([0_i8, 0, 1, 2, 3, 4]); + let array2 = DictionaryArray::try_new(&keys, &values).unwrap(); + + #[cfg(not(feature = "nan_ordering"))] + { + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(true), Some(false), Some(false)], + ); + assert_eq!(lt_dyn(&array1, &array2).unwrap(), expected); + } + #[cfg(feature = "nan_ordering")] + { + let expected = BooleanArray::from( + vec![Some(false), Some(true), Some(false), Some(true), Some(false), Some(false)], + ); + assert_eq!(lt_dyn(&array1, &array2).unwrap(), expected); + } + + #[cfg(not(feature = "nan_ordering"))] + { + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(true), Some(true), Some(false), 
Some(false)], + ); + assert_eq!(lt_eq_dyn(&array1, &array2).unwrap(), expected); + } + #[cfg(feature = "nan_ordering")] + { /* second entry: 7.0 <= NaN is true because NaN sorts greatest */ + let expected = BooleanArray::from( + vec![Some(true), Some(true), Some(true), Some(true), Some(false), Some(false)], + ); + assert_eq!(lt_eq_dyn(&array1, &array2).unwrap(), expected); + } + } + + #[test] + #[cfg(feature = "dyn_cmp_dict")] + fn test_gt_dyn_gt_eq_dyn_dict_non_dict_float_nan() { + let array1: Float32Array = vec![f32::NAN, 7.0, 8.0, 8.0, 11.0, f32::NAN] + .into_iter() + .map(Some) + .collect(); + let values = Float32Array::from(vec![f32::NAN, 8.0, 9.0, 10.0, 1.0]); + let keys = Int8Array::from_iter_values([0_i8, 0, 1, 2, 3, 4]); + let array2 = DictionaryArray::try_new(&keys, &values).unwrap(); + + #[cfg(not(feature = "nan_ordering"))] + { + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(false), Some(true), Some(false)], + ); + assert_eq!(gt_dyn(&array1, &array2).unwrap(), expected); + } + #[cfg(feature = "nan_ordering")] + { + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(false), Some(true), Some(true)], + ); + assert_eq!(gt_dyn(&array1, &array2).unwrap(), expected); + } + + #[cfg(not(feature = "nan_ordering"))] + { + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(true), Some(false), Some(true), Some(false)], + ); + assert_eq!(gt_eq_dyn(&array1, &array2).unwrap(), expected); + } + #[cfg(feature = "nan_ordering")] + { + let expected = BooleanArray::from( + vec![Some(true), Some(false), Some(true), Some(false), Some(true), Some(true)], + ); + assert_eq!(gt_eq_dyn(&array1, &array2).unwrap(), expected); + } + + let array1: Float64Array = vec![f64::NAN, 7.0, 8.0, 8.0, 11.0, f64::NAN] + .into_iter() + .map(Some) + .collect(); + let values = Float64Array::from(vec![f64::NAN, 8.0, 9.0, 10.0, 1.0]); + let keys = Int8Array::from_iter_values([0_i8, 0, 1, 2, 3, 4]); + let array2 = DictionaryArray::try_new(&keys, &values).unwrap(); + + 
#[cfg(not(feature = "nan_ordering"))] + { + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(false), Some(true), Some(false)], + ); + assert_eq!(gt_dyn(&array1, &array2).unwrap(), expected); + } + #[cfg(feature = "nan_ordering")] + { + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(false), Some(true), Some(true)], + ); + assert_eq!(gt_dyn(&array1, &array2).unwrap(), expected); + } + + #[cfg(not(feature = "nan_ordering"))] + { + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(true), Some(false), Some(true), Some(false)], + ); + assert_eq!(gt_eq_dyn(&array1, &array2).unwrap(), expected); + } + #[cfg(feature = "nan_ordering")] + { + let expected = BooleanArray::from( + vec![Some(true), Some(false), Some(true), Some(false), Some(true), Some(true)], + ); + assert_eq!(gt_eq_dyn(&array1, &array2).unwrap(), expected); + } + } }