diff --git a/arrow/benches/arithmetic_kernels.rs b/arrow/benches/arithmetic_kernels.rs index 4be4a26933aa..10af0b5432ef 100644 --- a/arrow/benches/arithmetic_kernels.rs +++ b/arrow/benches/arithmetic_kernels.rs @@ -55,13 +55,13 @@ fn bench_multiply(arr_a: &ArrayRef, arr_b: &ArrayRef) { fn bench_divide(arr_a: &ArrayRef, arr_b: &ArrayRef) { let arr_a = arr_a.as_any().downcast_ref::().unwrap(); let arr_b = arr_b.as_any().downcast_ref::().unwrap(); - criterion::black_box(divide(arr_a, arr_b).unwrap()); + criterion::black_box(divide_checked(arr_a, arr_b).unwrap()); } fn bench_divide_unchecked(arr_a: &ArrayRef, arr_b: &ArrayRef) { let arr_a = arr_a.as_any().downcast_ref::().unwrap(); let arr_b = arr_b.as_any().downcast_ref::().unwrap(); - criterion::black_box(divide_unchecked(arr_a, arr_b).unwrap()); + criterion::black_box(divide(arr_a, arr_b).unwrap()); } fn bench_divide_scalar(array: &ArrayRef, divisor: f32) { diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index fff687e18b3c..53f48570d927 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -35,8 +35,9 @@ use crate::compute::unary_dyn; use crate::compute::util::combine_option_bitmap; use crate::datatypes; use crate::datatypes::{ - ArrowNumericType, DataType, Date32Type, Date64Type, IntervalDayTimeType, - IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, + native_op::ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, DataType, + Date32Type, Date64Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, + IntervalYearMonthType, }; use crate::datatypes::{ Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, @@ -103,6 +104,106 @@ where Ok(PrimitiveArray::::from(data)) } +/// This is similar to `math_op` as it performs given operation between two input primitive arrays. +/// But the given operation can return `None` if overflow is detected. 
In that case, this function +/// returns an `Err`. +fn math_checked_op( + left: &PrimitiveArray, + right: &PrimitiveArray, + op: F, +) -> Result> +where + LT: ArrowNumericType, + RT: ArrowNumericType, + F: Fn(LT::Native, RT::Native) -> Option, +{ + if left.len() != right.len() { + return Err(ArrowError::ComputeError( + "Cannot perform math operation on arrays of different length".to_string(), + )); + } + + let left_iter = ArrayIter::new(left); + let right_iter = ArrayIter::new(right); + + let values: Result::Native>>> = left_iter + .into_iter() + .zip(right_iter.into_iter()) + .map(|(l, r)| { + if let (Some(l), Some(r)) = (l, r) { + let result = op(l, r); + if let Some(r) = result { + Ok(Some(r)) + } else { + // Overflow + Err(ArrowError::ComputeError(format!( + "Overflow happened on: {:?}, {:?}", + l, r + ))) + } + } else { + Ok(None) + } + }) + .collect(); + + let values = values?; + + Ok(PrimitiveArray::::from_iter(values)) +} + +/// This is similar to `math_checked_op` but just for divide op. 
+fn math_checked_divide( + left: &PrimitiveArray, + right: &PrimitiveArray, + op: F, +) -> Result> +where + LT: ArrowNumericType, + RT: ArrowNumericType, + RT::Native: One + Zero, + F: Fn(LT::Native, RT::Native) -> Option, +{ + if left.len() != right.len() { + return Err(ArrowError::ComputeError( + "Cannot perform math operation on arrays of different length".to_string(), + )); + } + + let left_iter = ArrayIter::new(left); + let right_iter = ArrayIter::new(right); + + let values: Result::Native>>> = left_iter + .into_iter() + .zip(right_iter.into_iter()) + .map(|(l, r)| { + if let (Some(l), Some(r)) = (l, r) { + let result = op(l, r); + if let Some(r) = result { + Ok(Some(r)) + } else if r.is_zero() { + Err(ArrowError::ComputeError(format!( + "DivideByZero on: {:?}, {:?}", + l, r + ))) + } else { + // Overflow + Err(ArrowError::ComputeError(format!( + "Overflow happened on: {:?}, {:?}", + l, r + ))) + } + } else { + Ok(None) + } + }) + .collect(); + + let values = values?; + + Ok(PrimitiveArray::::from_iter(values)) +} + /// Helper function for operations where a valid `0` on the right array should /// result in an [ArrowError::DivideByZero], namely the division and modulo operations /// @@ -760,15 +861,34 @@ where /// Perform `left + right` operation on two arrays. If either left or right value is null /// then the result is also null. +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `add_checked` instead. pub fn add( left: &PrimitiveArray, right: &PrimitiveArray, ) -> Result> where T: ArrowNumericType, - T::Native: Add, + T::Native: ArrowNativeTypeOp, +{ + math_op(left, right, |a, b| a.add_wrapping(b)) +} + +/// Perform `left + right` operation on two arrays. If either left or right value is null +/// then the result is also null. +/// +/// This detects overflow and returns an `Err` for that. For a non-overflow-checking variant, +/// use `add` instead. 
+pub fn add_checked( + left: &PrimitiveArray, + right: &PrimitiveArray, +) -> Result> +where + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp, { - math_op(left, right, |a, b| a + b) + math_checked_op(left, right, |a, b| a.add_checked(b)) } /// Perform `left + right` operation on two arrays. If either left or right value is null @@ -856,15 +976,34 @@ where /// Perform `left - right` operation on two arrays. If either left or right value is null /// then the result is also null. +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `subtract_checked` instead. pub fn subtract( left: &PrimitiveArray, right: &PrimitiveArray, ) -> Result> where T: datatypes::ArrowNumericType, - T::Native: Sub, + T::Native: ArrowNativeTypeOp, { - math_op(left, right, |a, b| a - b) + math_op(left, right, |a, b| a.sub_wrapping(b)) +} + +/// Perform `left - right` operation on two arrays. If either left or right value is null +/// then the result is also null. +/// +/// This detects overflow and returns an `Err` for that. For a non-overflow-checking variant, +/// use `subtract` instead. +pub fn subtract_checked( + left: &PrimitiveArray, + right: &PrimitiveArray, +) -> Result> +where + T: datatypes::ArrowNumericType, + T::Native: ArrowNativeTypeOp, +{ + math_checked_op(left, right, |a, b| a.sub_checked(b)) } /// Perform `left - right` operation on two arrays. If either left or right value is null @@ -933,15 +1072,34 @@ where /// Perform `left * right` operation on two arrays. If either left or right value is null /// then the result is also null. +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `multiply_checked` instead. 
pub fn multiply( left: &PrimitiveArray, right: &PrimitiveArray, ) -> Result> where T: datatypes::ArrowNumericType, - T::Native: Mul, + T::Native: ArrowNativeTypeOp, { - math_op(left, right, |a, b| a * b) + math_op(left, right, |a, b| a.mul_wrapping(b)) +} + +/// Perform `left * right` operation on two arrays. If either left or right value is null +/// then the result is also null. +/// +/// This detects overflow and returns an `Err` for that. For a non-overflow-checking variant, +/// use `multiply` instead. +pub fn multiply_checked( + left: &PrimitiveArray, + right: &PrimitiveArray, +) -> Result> +where + T: datatypes::ArrowNumericType, + T::Native: ArrowNativeTypeOp, +{ + math_checked_op(left, right, |a, b| a.mul_checked(b)) } /// Perform `left * right` operation on two arrays. If either left or right value is null @@ -1013,18 +1171,21 @@ where /// Perform `left / right` operation on two arrays. If either left or right value is null /// then the result is also null. If any right hand value is zero then the result of this /// operation will be `Err(ArrowError::DivideByZero)`. -pub fn divide( +/// +/// When the `simd` feature is not enabled, this detects overflow and returns an `Err` for that. +/// For a non-overflow-checking variant, use `divide` instead. +pub fn divide_checked( left: &PrimitiveArray, right: &PrimitiveArray, ) -> Result> where T: datatypes::ArrowNumericType, - T::Native: Div + Zero + One, + T::Native: ArrowNativeTypeOp + Zero + One, { #[cfg(feature = "simd")] return simd_checked_divide_op(&left, &right, simd_checked_divide::, |a, b| a / b); #[cfg(not(feature = "simd"))] - return math_checked_divide_op(left, right, |a, b| a / b); + return math_checked_divide(left, right, |a, b| a.div_checked(b)); } /// Perform `left / right` operation on two arrays. 
If either left or right value is null @@ -1040,17 +1201,21 @@ pub fn divide_dyn(left: &dyn Array, right: &dyn Array) -> Result { } /// Perform `left / right` operation on two arrays without checking for division by zero. -/// The result of dividing by zero follows normal floating point rules. +/// For floating point types, the result of dividing by zero follows normal floating point +/// rules. For other numeric types, dividing by zero will panic. /// If either left or right value is null then the result is also null. If any right hand value is zero then the result of this -pub fn divide_unchecked( +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `divide_checked` instead. +pub fn divide( left: &PrimitiveArray, right: &PrimitiveArray, ) -> Result> where - T: datatypes::ArrowFloatNumericType, - T::Native: Div, + T: datatypes::ArrowNumericType, + T::Native: ArrowNativeTypeOp, { - math_op(left, right, |a, b| a / b) + math_op(left, right, |a, b| a.div_wrapping(b)) } /// Modulus every value in an array by a scalar. 
If any value in the array is null then the @@ -1769,7 +1934,7 @@ mod tests { fn test_primitive_array_divide_with_nulls() { let a = Int32Array::from(vec![Some(15), None, Some(8), Some(1), Some(9), None]); let b = Int32Array::from(vec![Some(5), Some(6), Some(8), Some(9), None, None]); - let c = divide(&a, &b).unwrap(); + let c = divide_checked(&a, &b).unwrap(); assert_eq!(3, c.value(0)); assert!(c.is_null(1)); assert_eq!(1, c.value(2)); @@ -1854,7 +2019,7 @@ mod tests { let b = b.slice(8, 6); let b = b.as_any().downcast_ref::().unwrap(); - let c = divide(a, b).unwrap(); + let c = divide_checked(a, b).unwrap(); assert_eq!(6, c.len()); assert_eq!(3, c.value(0)); assert!(c.is_null(1)); @@ -1919,6 +2084,14 @@ mod tests { #[test] #[should_panic(expected = "DivideByZero")] + fn test_primitive_array_divide_by_zero_with_checked() { + let a = Int32Array::from(vec![15]); + let b = Int32Array::from(vec![0]); + divide_checked(&a, &b).unwrap(); + } + + #[test] + #[should_panic(expected = "attempt to divide by zero")] fn test_primitive_array_divide_by_zero() { let a = Int32Array::from(vec![15]); let b = Int32Array::from(vec![0]); @@ -2019,4 +2192,57 @@ mod tests { let expected = Float64Array::from(vec![Some(1.0), None, Some(9.0)]); assert_eq!(expected, actual); } + + #[test] + fn test_primitive_add_wrapping_overflow() { + let a = Int32Array::from(vec![i32::MAX, i32::MIN]); + let b = Int32Array::from(vec![1, 1]); + + let wrapped = add(&a, &b); + let expected = Int32Array::from(vec![-2147483648, -2147483647]); + assert_eq!(expected, wrapped.unwrap()); + + let overflow = add_checked(&a, &b); + overflow.expect_err("overflow should be detected"); + } + + #[test] + fn test_primitive_subtract_wrapping_overflow() { + let a = Int32Array::from(vec![-2]); + let b = Int32Array::from(vec![i32::MAX]); + + let wrapped = subtract(&a, &b); + let expected = Int32Array::from(vec![i32::MAX]); + assert_eq!(expected, wrapped.unwrap()); + + let overflow = subtract_checked(&a, &b); + 
overflow.expect_err("overflow should be detected"); + } + + #[test] + fn test_primitive_mul_wrapping_overflow() { + let a = Int32Array::from(vec![10]); + let b = Int32Array::from(vec![i32::MAX]); + + let wrapped = multiply(&a, &b); + let expected = Int32Array::from(vec![-10]); + assert_eq!(expected, wrapped.unwrap()); + + let overflow = multiply_checked(&a, &b); + overflow.expect_err("overflow should be detected"); + } + + #[test] + #[cfg(not(feature = "simd"))] + fn test_primitive_div_wrapping_overflow() { + let a = Int32Array::from(vec![i32::MIN]); + let b = Int32Array::from(vec![-1]); + + let wrapped = divide(&a, &b); + let expected = Int32Array::from(vec![-2147483648]); + assert_eq!(expected, wrapped.unwrap()); + + let overflow = divide_checked(&a, &b); + overflow.expect_err("overflow should be detected"); + } } diff --git a/arrow/src/datatypes/native.rs b/arrow/src/datatypes/native.rs index 207e8cb40330..444f2b27dce6 100644 --- a/arrow/src/datatypes/native.rs +++ b/arrow/src/datatypes/native.rs @@ -114,6 +114,112 @@ pub trait ArrowPrimitiveType: 'static { } } +pub(crate) mod native_op { + use super::ArrowNativeType; + use std::ops::{Add, Div, Mul, Sub}; + + /// Trait for ArrowNativeType to provide overflow-checking and non-overflow-checking + /// variants for arithmetic operations. For floating point types, this provides some + /// default implementations. Integer types that need to deal with overflow can implement + /// this trait. + /// + /// The APIs with `_wrapping` suffix are the variant of non-overflow-checking. If overflow + /// occurred, they will supposedly wrap around the boundary of the type. + /// + /// The APIs with `_checked` suffix are the variant of overflow-checking which return `None` + /// if overflow occurred. 
+ pub trait ArrowNativeTypeOp: + ArrowNativeType + + Add + + Sub + + Mul + + Div + { + fn add_checked(self, rhs: Self) -> Option { + Some(self + rhs) + } + + fn add_wrapping(self, rhs: Self) -> Self { + self + rhs + } + + fn sub_checked(self, rhs: Self) -> Option { + Some(self - rhs) + } + + fn sub_wrapping(self, rhs: Self) -> Self { + self - rhs + } + + fn mul_checked(self, rhs: Self) -> Option { + Some(self * rhs) + } + + fn mul_wrapping(self, rhs: Self) -> Self { + self * rhs + } + + fn div_checked(self, rhs: Self) -> Option { + Some(self / rhs) + } + + fn div_wrapping(self, rhs: Self) -> Self { + self / rhs + } + } +} + +macro_rules! native_type_op { + ($t:tt) => { + impl native_op::ArrowNativeTypeOp for $t { + fn add_checked(self, rhs: Self) -> Option { + self.checked_add(rhs) + } + + fn add_wrapping(self, rhs: Self) -> Self { + self.wrapping_add(rhs) + } + + fn sub_checked(self, rhs: Self) -> Option { + self.checked_sub(rhs) + } + + fn sub_wrapping(self, rhs: Self) -> Self { + self.wrapping_sub(rhs) + } + + fn mul_checked(self, rhs: Self) -> Option { + self.checked_mul(rhs) + } + + fn mul_wrapping(self, rhs: Self) -> Self { + self.wrapping_mul(rhs) + } + + fn div_checked(self, rhs: Self) -> Option { + self.checked_div(rhs) + } + + fn div_wrapping(self, rhs: Self) -> Self { + self.wrapping_div(rhs) + } + } + }; +} + +native_type_op!(i8); +native_type_op!(i16); +native_type_op!(i32); +native_type_op!(i64); +native_type_op!(u8); +native_type_op!(u16); +native_type_op!(u32); +native_type_op!(u64); + +impl native_op::ArrowNativeTypeOp for f16 {} +impl native_op::ArrowNativeTypeOp for f32 {} +impl native_op::ArrowNativeTypeOp for f64 {} + impl private::Sealed for i8 {} impl ArrowNativeType for i8 { #[inline]