diff --git a/arrow-array/src/types.rs b/arrow-array/src/types.rs index e6197eed19cf..7c7a5c811550 100644 --- a/arrow-array/src/types.rs +++ b/arrow-array/src/types.rs @@ -495,6 +495,9 @@ pub trait DecimalType: const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType; const DEFAULT_TYPE: DataType; + /// "Decimal128" or "Decimal256", for use in error messages + const PREFIX: &'static str; + /// Formats the decimal value with the provided precision and scale fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String; @@ -516,6 +519,7 @@ impl DecimalType for Decimal128Type { const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = DataType::Decimal128; const DEFAULT_TYPE: DataType = DataType::Decimal128(DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE); + const PREFIX: &'static str = "Decimal128"; fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String { format_decimal_str(&value.to_string(), precision as usize, scale as usize) @@ -543,6 +547,7 @@ impl DecimalType for Decimal256Type { const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = DataType::Decimal256; const DEFAULT_TYPE: DataType = DataType::Decimal256(DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE); + const PREFIX: &'static str = "Decimal256"; fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String { format_decimal_str(&value.to_string(), precision as usize, scale as usize) diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs index 4ad8dd99e73e..b1e744d26824 100644 --- a/arrow/src/compute/kernels/cast.rs +++ b/arrow/src/compute/kernels/cast.rs @@ -309,41 +309,43 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result { cast_with_options(array, to_type, &DEFAULT_CAST_OPTIONS) } -fn cast_integer_to_decimal128( - array: &PrimitiveArray, - precision: u8, - scale: u8, -) -> Result -where - ::Native: AsPrimitive, -{ - let mul: i128 = 10_i128.pow(scale as u32); - - unary::(array, |v| v.as_() * mul) - .with_precision_and_scale(precision, scale) - .map(|a| Arc::new(a) as ArrayRef) -} - -fn cast_integer_to_decimal256( +fn cast_integer_to_decimal< + T: ArrowNumericType, + D: DecimalType + ArrowPrimitiveType, + M, +>( array: &PrimitiveArray, precision: u8, scale: u8, + base: M, + cast_options: &CastOptions, ) -> Result where - ::Native: AsPrimitive, + ::Native: AsPrimitive, + M: ArrowNativeTypeOp, { - let mul: i256 = i256::from_i128(10_i128) - .checked_pow(scale as u32) - .ok_or_else(|| { - ArrowError::CastError(format!( - "Cannot cast to Decimal256({}, {}). The scale causes overflow.", - precision, scale - )) - })?; + let mul: M = base.pow_checked(scale as u32).map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast to {:?}({}, {}). The scale causes overflow.", + D::PREFIX, + precision, + scale, + )) + })?; - unary::(array, |v| v.as_().wrapping_mul(mul)) - .with_precision_and_scale(precision, scale) - .map(|a| Arc::new(a) as ArrayRef) + if cast_options.safe { + let iter = array + .iter() + .map(|v| v.and_then(|v| v.as_().mul_checked(mul).ok())); + let casted_array = unsafe { PrimitiveArray::::from_trusted_len_iter(iter) }; + casted_array + .with_precision_and_scale(precision, scale) + .map(|a| Arc::new(a) as ArrayRef) + } else { + try_unary::(array, |v| v.as_().mul_checked(mul)) + .and_then(|a| a.with_precision_and_scale(precision, scale)) + .map(|a| Arc::new(a) as ArrayRef) + } } fn cast_floating_point_to_decimal128( @@ -562,25 +564,33 @@ pub fn cast_with_options( // cast data to decimal match from_type { // TODO now just support signed numeric to decimal, support decimal to numeric later - Int8 => cast_integer_to_decimal128( + Int8 => cast_integer_to_decimal::<_, Decimal128Type, _>( as_primitive_array::(array), *precision, *scale, + 10_i128, + cast_options, ), - Int16 => cast_integer_to_decimal128( + Int16 => cast_integer_to_decimal::<_, Decimal128Type, _>( as_primitive_array::(array), *precision, *scale, + 10_i128, + cast_options, ), - Int32 => cast_integer_to_decimal128( + Int32 => cast_integer_to_decimal::<_, Decimal128Type, _>( as_primitive_array::(array), *precision, *scale, + 10_i128, + cast_options, ), - Int64 => cast_integer_to_decimal128( + Int64 => cast_integer_to_decimal::<_, Decimal128Type, _>( as_primitive_array::(array), *precision, *scale, + 10_i128, + cast_options, ), Float32 => cast_floating_point_to_decimal128( as_primitive_array::(array), @@ -603,25 +613,33 @@ pub fn cast_with_options( // cast data to decimal match from_type { // TODO now just support signed numeric to decimal, support decimal to numeric later - Int8 => cast_integer_to_decimal256( + Int8 => cast_integer_to_decimal::<_, Decimal256Type, _>( as_primitive_array::(array), *precision, *scale, + i256::from_i128(10_i128), + cast_options, ), - Int16 => cast_integer_to_decimal256( + Int16 => cast_integer_to_decimal::<_, Decimal256Type, _>( as_primitive_array::(array), *precision, *scale, + i256::from_i128(10_i128), + cast_options, ), - Int32 => cast_integer_to_decimal256( + Int32 => cast_integer_to_decimal::<_, Decimal256Type, _>( as_primitive_array::(array), *precision, *scale, + i256::from_i128(10_i128), + cast_options, ), - Int64 => cast_integer_to_decimal256( + Int64 => cast_integer_to_decimal::<_, Decimal256Type, _>( as_primitive_array::(array), *precision, *scale, + i256::from_i128(10_i128), + cast_options, ), Float32 => cast_floating_point_to_decimal256( as_primitive_array::(array), @@ -6049,4 +6067,44 @@ mod tests { ] ); } + + #[test] + fn test_cast_numeric_to_decimal128_overflow() { + let array = Int64Array::from(vec![i64::MAX]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal128(38, 30), + &CastOptions { safe: true }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let casted_array = cast_with_options( + &array, + &DataType::Decimal128(38, 30), + &CastOptions { safe: false }, + ); + assert!(casted_array.is_err()); + } + + #[test] + fn test_cast_numeric_to_decimal256_overflow() { + let array = Int64Array::from(vec![i64::MAX]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal256(76, 76), + &CastOptions { safe: true }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let casted_array = cast_with_options( + &array, + &DataType::Decimal256(76, 76), + &CastOptions { safe: false }, + ); + assert!(casted_array.is_err()); + } } diff --git a/arrow/src/datatypes/native.rs b/arrow/src/datatypes/native.rs index bbdec14b44a0..28ef877a2fd3 100644 --- a/arrow/src/datatypes/native.rs +++ b/arrow/src/datatypes/native.rs @@ -19,6 +19,7 @@ use crate::error::{ArrowError, Result}; pub use arrow_array::ArrowPrimitiveType; pub use arrow_buffer::{i256, ArrowNativeType, ToByteSlice}; use half::f16; +use num::complex::ComplexFloat; /// Trait for [`ArrowNativeType`] that adds checked and unchecked arithmetic operations, /// and totally ordered comparison operations @@ -68,6 +69,10 @@ pub trait ArrowNativeTypeOp: ArrowNativeType { fn neg_wrapping(self) -> Self; + fn pow_checked(self, exp: u32) -> Result; + + fn pow_wrapping(self, exp: u32) -> Self; + fn is_zero(self) -> bool; fn is_eq(self, rhs: Self) -> bool; @@ -171,6 +176,16 @@ macro_rules! native_type_op { }) } + fn pow_checked(self, exp: u32) -> Result { + self.checked_pow(exp).ok_or_else(|| { + ArrowError::ComputeError(format!("Overflow happened on: {:?}", self)) + }) + } + + fn pow_wrapping(self, exp: u32) -> Self { + self.wrapping_pow(exp) + } + fn neg_wrapping(self) -> Self { self.wrapping_neg() } @@ -279,6 +294,14 @@ macro_rules! native_type_float_op { -self } + fn pow_checked(self, exp: u32) -> Result { + Ok(self.powi(exp as i32)) + } + + fn pow_wrapping(self, exp: u32) -> Self { + self.powi(exp as i32) + } + fn is_zero(self) -> bool { self == $zero }