From 37b843bc42a3f18c47dcf076dc5d929cb4a54ec1 Mon Sep 17 00:00:00 2001 From: Matthew Turner Date: Sun, 2 Jan 2022 07:08:24 -0500 Subject: [PATCH] Add gt eq dyn scalar kernel (#1117) * Add lt_dyn_scalar and tests * Add lt_eq_dyn_scalar kernel * Add gt_dyn_scalar kernel * Add gt_eq_dyn_scalar kernel * Add kernel to err message Co-authored-by: Andrew Lamb --- arrow/src/compute/kernels/comparison.rs | 65 +++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index c4c9fa1c3d41..2f7ddec07948 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -1229,6 +1229,42 @@ where } } +/// Perform `left >= right` operation on an array and a numeric scalar +/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values +pub fn gt_eq_dyn_scalar(left: Arc, right: T) -> Result +where + T: TryInto + Copy + std::fmt::Debug, +{ + match left.data_type() { + DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => {dyn_compare_scalar!(&left, right, key_type, gt_eq_scalar)} + _ => Err(ArrowError::ComputeError( + "gt_eq_dyn_scalar only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), + )) + } + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => { + dyn_compare_scalar!(&left, right, gt_eq_scalar) + } + _ => Err(ArrowError::ComputeError( + "gt_eq_dyn_scalar only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), + )) + } +} + /// Perform `left == right` operation on an array and a numeric scalar /// value. Supports StringArrays, and DictionaryArrays that have string values pub fn eq_dyn_utf8_scalar(left: Arc, right: &str) -> Result { @@ -3209,6 +3245,35 @@ mod tests { ); } + #[test] + fn test_gt_eq_dyn_scalar() { + let array = Int32Array::from(vec![6, 7, 8, 8, 10]); + let array = Arc::new(array); + let a_eq = gt_eq_dyn_scalar(array, 8).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from( + vec![Some(false), Some(false), Some(true), Some(true), Some(true)] + ) + ); + } + + #[test] + fn test_gt_eq_dyn_scalar_with_dict() { + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(22).unwrap(); + builder.append_null().unwrap(); + builder.append(23).unwrap(); + let array = Arc::new(builder.finish()); + let a_eq = gt_eq_dyn_scalar(array, 23).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(false), None, Some(true)]) + ); + } + #[test] fn test_eq_dyn_utf8_scalar() { let array = StringArray::from(vec!["abc", "def", "xyz"]);