-
Notifications
You must be signed in to change notification settings - Fork 784
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add overflow-checking variant for primitive arithmetic kernels and explicitly define overflow behavior #2643
Changes from 9 commits
154f8a5
02a2a80
74d8cc9
314886f
de42e04
e8218c4
06f4ce3
4aa4432
3d98aff
c930f85
34ea894
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,8 +35,9 @@ use crate::compute::unary_dyn; | |
use crate::compute::util::combine_option_bitmap; | ||
use crate::datatypes; | ||
use crate::datatypes::{ | ||
ArrowNumericType, DataType, Date32Type, Date64Type, IntervalDayTimeType, | ||
IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, | ||
ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, DataType, Date32Type, | ||
Date64Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, | ||
IntervalYearMonthType, | ||
}; | ||
use crate::datatypes::{ | ||
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, | ||
|
@@ -103,6 +104,106 @@ where | |
Ok(PrimitiveArray::<LT>::from(data)) | ||
} | ||
|
||
/// This is similar to `math_op` as it performs given operation between two input primitive arrays. | ||
/// But the given operation can return `None` if overflow is detected. For the case, this function | ||
/// returns an `Err`. | ||
fn math_checked_op<LT, RT, F>( | ||
left: &PrimitiveArray<LT>, | ||
right: &PrimitiveArray<RT>, | ||
op: F, | ||
) -> Result<PrimitiveArray<LT>> | ||
where | ||
LT: ArrowNumericType, | ||
RT: ArrowNumericType, | ||
F: Fn(LT::Native, RT::Native) -> Option<LT::Native>, | ||
{ | ||
if left.len() != right.len() { | ||
return Err(ArrowError::ComputeError( | ||
"Cannot perform math operation on arrays of different length".to_string(), | ||
)); | ||
} | ||
|
||
let left_iter = ArrayIter::new(left); | ||
let right_iter = ArrayIter::new(right); | ||
|
||
let values: Result<Vec<Option<<LT as ArrowPrimitiveType>::Native>>> = left_iter | ||
.into_iter() | ||
.zip(right_iter.into_iter()) | ||
.map(|(l, r)| { | ||
if let (Some(l), Some(r)) = (l, r) { | ||
let result = op(l, r); | ||
if let Some(r) = result { | ||
Ok(Some(r)) | ||
} else { | ||
// Overflow | ||
Err(ArrowError::ComputeError(format!( | ||
"Overflow happened on: {:?}, {:?}", | ||
l, r | ||
))) | ||
} | ||
} else { | ||
Ok(None) | ||
} | ||
}) | ||
.collect(); | ||
|
||
let values = values?; | ||
|
||
Ok(PrimitiveArray::<LT>::from_iter(values)) | ||
} | ||
|
||
/// This is similar to `math_checked_op` but just for divide op. | ||
fn math_checked_divide<LT, RT, F>( | ||
left: &PrimitiveArray<LT>, | ||
right: &PrimitiveArray<RT>, | ||
op: F, | ||
) -> Result<PrimitiveArray<LT>> | ||
where | ||
LT: ArrowNumericType, | ||
RT: ArrowNumericType, | ||
RT::Native: One + Zero, | ||
F: Fn(LT::Native, RT::Native) -> Option<LT::Native>, | ||
{ | ||
if left.len() != right.len() { | ||
return Err(ArrowError::ComputeError( | ||
"Cannot perform math operation on arrays of different length".to_string(), | ||
)); | ||
} | ||
|
||
let left_iter = ArrayIter::new(left); | ||
let right_iter = ArrayIter::new(right); | ||
|
||
let values: Result<Vec<Option<<LT as ArrowPrimitiveType>::Native>>> = left_iter | ||
.into_iter() | ||
.zip(right_iter.into_iter()) | ||
.map(|(l, r)| { | ||
if let (Some(l), Some(r)) = (l, r) { | ||
let result = op(l, r); | ||
if let Some(r) = result { | ||
Ok(Some(r)) | ||
} else if r.is_zero() { | ||
Err(ArrowError::ComputeError(format!( | ||
"DivideByZero on: {:?}, {:?}", | ||
l, r | ||
))) | ||
} else { | ||
// Overflow | ||
Err(ArrowError::ComputeError(format!( | ||
"Overflow happened on: {:?}, {:?}", | ||
l, r | ||
))) | ||
} | ||
} else { | ||
Ok(None) | ||
} | ||
}) | ||
.collect(); | ||
|
||
let values = values?; | ||
|
||
Ok(PrimitiveArray::<LT>::from_iter(values)) | ||
} | ||
|
||
/// Helper function for operations where a valid `0` on the right array should | ||
/// result in an [ArrowError::DivideByZero], namely the division and modulo operations | ||
/// | ||
|
@@ -760,15 +861,34 @@ where | |
|
||
/// Perform `left + right` operation on two arrays. If either left or right value is null | ||
/// then the result is also null. | ||
/// | ||
/// This doesn't detect overflow. Once overflowing, the result will wrap around. | ||
/// For an overflow-checking variant, use `add_checked` instead. | ||
pub fn add<T>( | ||
left: &PrimitiveArray<T>, | ||
right: &PrimitiveArray<T>, | ||
) -> Result<PrimitiveArray<T>> | ||
where | ||
T: ArrowNumericType, | ||
T::Native: Add<Output = T::Native>, | ||
T::Native: ArrowNativeTypeOp, | ||
{ | ||
math_op(left, right, |a, b| a.add_wrapping(b)) | ||
} | ||
|
||
/// Perform `left + right` operation on two arrays. If either left or right value is null | ||
/// then the result is also null. Once | ||
/// | ||
/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, | ||
/// use `add` instead. | ||
pub fn add_checked<T>( | ||
left: &PrimitiveArray<T>, | ||
right: &PrimitiveArray<T>, | ||
) -> Result<PrimitiveArray<T>> | ||
where | ||
T: ArrowNumericType, | ||
T::Native: ArrowNativeTypeOp, | ||
{ | ||
math_op(left, right, |a, b| a + b) | ||
math_checked_op(left, right, |a, b| a.add_checked(b)) | ||
} | ||
|
||
/// Perform `left + right` operation on two arrays. If either left or right value is null | ||
|
@@ -856,15 +976,34 @@ where | |
|
||
/// Perform `left - right` operation on two arrays. If either left or right value is null | ||
/// then the result is also null. | ||
/// | ||
/// This doesn't detect overflow. Once overflowing, the result will wrap around. | ||
/// For an overflow-checking variant, use `subtract_checked` instead. | ||
pub fn subtract<T>( | ||
left: &PrimitiveArray<T>, | ||
right: &PrimitiveArray<T>, | ||
) -> Result<PrimitiveArray<T>> | ||
where | ||
T: datatypes::ArrowNumericType, | ||
T::Native: Sub<Output = T::Native>, | ||
T::Native: ArrowNativeTypeOp, | ||
{ | ||
math_op(left, right, |a, b| a - b) | ||
math_op(left, right, |a, b| a.sub_wrapping(b)) | ||
} | ||
|
||
/// Perform `left - right` operation on two arrays. If either left or right value is null | ||
/// then the result is also null. | ||
/// | ||
/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, | ||
/// use `subtract` instead. | ||
pub fn subtract_checked<T>( | ||
left: &PrimitiveArray<T>, | ||
right: &PrimitiveArray<T>, | ||
) -> Result<PrimitiveArray<T>> | ||
where | ||
T: datatypes::ArrowNumericType, | ||
T::Native: ArrowNativeTypeOp, | ||
{ | ||
math_checked_op(left, right, |a, b| a.sub_checked(b)) | ||
} | ||
|
||
/// Perform `left - right` operation on two arrays. If either left or right value is null | ||
|
@@ -933,15 +1072,34 @@ where | |
|
||
/// Perform `left * right` operation on two arrays. If either left or right value is null | ||
/// then the result is also null. | ||
/// | ||
/// This doesn't detect overflow. Once overflowing, the result will wrap around. | ||
/// For an overflow-checking variant, use `multiply_check` instead. | ||
pub fn multiply<T>( | ||
left: &PrimitiveArray<T>, | ||
right: &PrimitiveArray<T>, | ||
) -> Result<PrimitiveArray<T>> | ||
where | ||
T: datatypes::ArrowNumericType, | ||
T::Native: Mul<Output = T::Native>, | ||
T::Native: ArrowNativeTypeOp, | ||
{ | ||
math_op(left, right, |a, b| a * b) | ||
math_op(left, right, |a, b| a.mul_wrapping(b)) | ||
} | ||
|
||
/// Perform `left * right` operation on two arrays. If either left or right value is null | ||
/// then the result is also null. | ||
/// | ||
/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, | ||
/// use `multiply` instead. | ||
pub fn multiply_checked<T>( | ||
left: &PrimitiveArray<T>, | ||
right: &PrimitiveArray<T>, | ||
) -> Result<PrimitiveArray<T>> | ||
where | ||
T: datatypes::ArrowNumericType, | ||
T::Native: ArrowNativeTypeOp, | ||
{ | ||
math_checked_op(left, right, |a, b| a.mul_checked(b)) | ||
} | ||
|
||
/// Perform `left * right` operation on two arrays. If either left or right value is null | ||
|
@@ -1013,18 +1171,21 @@ where | |
/// Perform `left / right` operation on two arrays. If either left or right value is null | ||
/// then the result is also null. If any right hand value is zero then the result of this | ||
/// operation will be `Err(ArrowError::DivideByZero)`. | ||
pub fn divide<T>( | ||
/// | ||
/// When `simd` feature is not enabled. This detects overflow and returns an `Err` for that. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What happens when SIMD is enabled? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Got There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Interesting, that would imply rust division is always checked 🤔 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yup - and LLVM cannot vectorize it correctly - https://rust.godbolt.org/z/T8eTGM8zn |
||
/// For an non-overflow-checking variant, use `divide` instead. | ||
pub fn divide_checked<T>( | ||
left: &PrimitiveArray<T>, | ||
right: &PrimitiveArray<T>, | ||
) -> Result<PrimitiveArray<T>> | ||
where | ||
T: datatypes::ArrowNumericType, | ||
T::Native: Div<Output = T::Native> + Zero + One, | ||
T::Native: ArrowNativeTypeOp + Zero + One, | ||
{ | ||
#[cfg(feature = "simd")] | ||
return simd_checked_divide_op(&left, &right, simd_checked_divide::<T>, |a, b| a / b); | ||
#[cfg(not(feature = "simd"))] | ||
return math_checked_divide_op(left, right, |a, b| a / b); | ||
return math_checked_divide(left, right, |a, b| a.div_checked(b)); | ||
} | ||
|
||
/// Perform `left / right` operation on two arrays. If either left or right value is null | ||
|
@@ -1040,17 +1201,21 @@ pub fn divide_dyn(left: &dyn Array, right: &dyn Array) -> Result<ArrayRef> { | |
} | ||
|
||
/// Perform `left / right` operation on two arrays without checking for division by zero. | ||
/// The result of dividing by zero follows normal floating point rules. | ||
/// For floating point types, the result of dividing by zero follows normal floating point | ||
/// rules. For other numeric types, dividing by zero will panic, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems inconsistent with the other APIs, perhaps it should saturate instead There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, |
||
/// If either left or right value is null then the result is also null. If any right hand value is zero then the result of this | ||
pub fn divide_unchecked<T>( | ||
/// | ||
/// This doesn't detect overflow. Once overflowing, the result will wrap around. | ||
/// For an overflow-checking variant, use `divide_checked` instead. | ||
pub fn divide<T>( | ||
left: &PrimitiveArray<T>, | ||
right: &PrimitiveArray<T>, | ||
) -> Result<PrimitiveArray<T>> | ||
where | ||
T: datatypes::ArrowFloatNumericType, | ||
T::Native: Div<Output = T::Native>, | ||
T: datatypes::ArrowNumericType, | ||
T::Native: ArrowNativeTypeOp, | ||
{ | ||
math_op(left, right, |a, b| a / b) | ||
math_op(left, right, |a, b| a.div_wrapping(b)) | ||
} | ||
|
||
/// Modulus every value in an array by a scalar. If any value in the array is null then the | ||
|
@@ -1769,7 +1934,7 @@ mod tests { | |
fn test_primitive_array_divide_with_nulls() { | ||
let a = Int32Array::from(vec![Some(15), None, Some(8), Some(1), Some(9), None]); | ||
let b = Int32Array::from(vec![Some(5), Some(6), Some(8), Some(9), None, None]); | ||
let c = divide(&a, &b).unwrap(); | ||
let c = divide_checked(&a, &b).unwrap(); | ||
assert_eq!(3, c.value(0)); | ||
assert!(c.is_null(1)); | ||
assert_eq!(1, c.value(2)); | ||
|
@@ -1854,7 +2019,7 @@ mod tests { | |
let b = b.slice(8, 6); | ||
let b = b.as_any().downcast_ref::<Int32Array>().unwrap(); | ||
|
||
let c = divide(a, b).unwrap(); | ||
let c = divide_checked(a, b).unwrap(); | ||
assert_eq!(6, c.len()); | ||
assert_eq!(3, c.value(0)); | ||
assert!(c.is_null(1)); | ||
|
@@ -1919,6 +2084,14 @@ mod tests { | |
|
||
#[test] | ||
#[should_panic(expected = "DivideByZero")] | ||
fn test_primitive_array_divide_by_zero_with_checked() { | ||
let a = Int32Array::from(vec![15]); | ||
let b = Int32Array::from(vec![0]); | ||
divide_checked(&a, &b).unwrap(); | ||
} | ||
|
||
#[test] | ||
#[should_panic(expected = "attempt to divide by zero")] | ||
fn test_primitive_array_divide_by_zero() { | ||
let a = Int32Array::from(vec![15]); | ||
let b = Int32Array::from(vec![0]); | ||
|
@@ -2019,4 +2192,57 @@ mod tests { | |
let expected = Float64Array::from(vec![Some(1.0), None, Some(9.0)]); | ||
assert_eq!(expected, actual); | ||
} | ||
|
||
#[test] | ||
fn test_primitive_add_wrapping_overflow() { | ||
let a = Int32Array::from(vec![i32::MAX, i32::MIN]); | ||
let b = Int32Array::from(vec![1, 1]); | ||
|
||
let wrapped = add(&a, &b); | ||
let expected = Int32Array::from(vec![-2147483648, -2147483647]); | ||
assert_eq!(expected, wrapped.unwrap()); | ||
|
||
let overflow = add_checked(&a, &b); | ||
overflow.expect_err("overflow should be detected"); | ||
} | ||
|
||
#[test] | ||
fn test_primitive_subtract_wrapping_overflow() { | ||
let a = Int32Array::from(vec![-2]); | ||
let b = Int32Array::from(vec![i32::MAX]); | ||
|
||
let wrapped = subtract(&a, &b); | ||
let expected = Int32Array::from(vec![i32::MAX]); | ||
assert_eq!(expected, wrapped.unwrap()); | ||
|
||
let overflow = subtract_checked(&a, &b); | ||
overflow.expect_err("overflow should be detected"); | ||
} | ||
|
||
#[test] | ||
fn test_primitive_mul_wrapping_overflow() { | ||
let a = Int32Array::from(vec![10]); | ||
let b = Int32Array::from(vec![i32::MAX]); | ||
|
||
let wrapped = multiply(&a, &b); | ||
let expected = Int32Array::from(vec![-10]); | ||
assert_eq!(expected, wrapped.unwrap()); | ||
|
||
let overflow = multiply_checked(&a, &b); | ||
overflow.expect_err("overflow should be detected"); | ||
} | ||
|
||
#[test] | ||
#[cfg(not(feature = "simd"))] | ||
fn test_primitive_div_wrapping_overflow() { | ||
let a = Int32Array::from(vec![i32::MIN]); | ||
let b = Int32Array::from(vec![-1]); | ||
|
||
let wrapped = divide(&a, &b); | ||
let expected = Int32Array::from(vec![-2147483648]); | ||
assert_eq!(expected, wrapped.unwrap()); | ||
|
||
let overflow = divide_checked(&a, &b); | ||
overflow.expect_err("overflow should be detected"); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the difference between this function and
math_checked_divide_op
and why do we need both of them?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Finally I hope we can just have one. Currently
math_checked_divide_op
is used bydivide_dyn
and I want to limit the range of change to primitive kernels only.