Skip to content

Commit

Permalink
Check overflow when casting integer to decimal (#3009)
Browse files Browse the repository at this point in the history
* Check overflow when casting integer to decimal

* Trigger Build

* Combine cast_integer_to_decimal functions of decimal128 and decimal256

* Fix clippy

* Trigger Build

* Use the `PREFIX` associated constant for the type name in error messages.
  • Loading branch information
viirya authored Nov 4, 2022
1 parent 8400b09 commit 766f69f
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 36 deletions.
5 changes: 5 additions & 0 deletions arrow-array/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,9 @@ pub trait DecimalType:
const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType;
const DEFAULT_TYPE: DataType;

/// "Decimal128" or "Decimal256", for use in error messages
const PREFIX: &'static str;

/// Formats the decimal value with the provided precision and scale
fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String;

Expand All @@ -516,6 +519,7 @@ impl DecimalType for Decimal128Type {
const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = DataType::Decimal128;
const DEFAULT_TYPE: DataType =
DataType::Decimal128(DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE);
const PREFIX: &'static str = "Decimal128";

fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String {
format_decimal_str(&value.to_string(), precision as usize, scale as usize)
Expand Down Expand Up @@ -543,6 +547,7 @@ impl DecimalType for Decimal256Type {
const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = DataType::Decimal256;
const DEFAULT_TYPE: DataType =
DataType::Decimal256(DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE);
const PREFIX: &'static str = "Decimal256";

fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String {
format_decimal_str(&value.to_string(), precision as usize, scale as usize)
Expand Down
130 changes: 94 additions & 36 deletions arrow/src/compute/kernels/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,41 +309,43 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
cast_with_options(array, to_type, &DEFAULT_CAST_OPTIONS)
}

/// Casts an integer array to a `Decimal128` array with the given `precision`
/// and `scale`.
///
/// Every value is converted to `i128` with `as_()` and multiplied by
/// `10^scale`. The resulting array is then tagged with `precision`/`scale`
/// through `with_precision_and_scale`, and any error from that call is
/// returned to the caller.
///
/// NOTE(review): neither the `pow` nor the `*` below is overflow-checked, so a
/// large `scale` or a large input value panics in debug builds and silently
/// wraps in release builds.
fn cast_integer_to_decimal128<T: ArrowNumericType>(
array: &PrimitiveArray<T>,
precision: u8,
scale: u8,
) -> Result<ArrayRef>
where
<T as ArrowPrimitiveType>::Native: AsPrimitive<i128>,
{
// Scaling factor that shifts an integer into the decimal's fixed-point form.
let mul: i128 = 10_i128.pow(scale as u32);

unary::<T, _, Decimal128Type>(array, |v| v.as_() * mul)
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
}

fn cast_integer_to_decimal256<T: ArrowNumericType>(
/// Casts an integer array to a decimal array of type `D` with the given
/// `precision` and `scale`.
///
/// `base` is raised to the power `scale` in `D`'s native type `M` (the callers
/// in this file pass ten), and every input value is converted to `M` and
/// multiplied by that factor.
///
/// # Errors
///
/// Returns `ArrowError::CastError` when `base^scale` overflows `M`. When
/// `cast_options.safe` is `false`, an overflowing element multiplication is
/// returned as an error as well; when it is `true`, such elements become null
/// instead. Errors from `with_precision_and_scale` are propagated in both modes.
fn cast_integer_to_decimal<
T: ArrowNumericType,
D: DecimalType + ArrowPrimitiveType<Native = M>,
M,
>(
array: &PrimitiveArray<T>,
precision: u8,
scale: u8,
base: M,
cast_options: &CastOptions,
) -> Result<ArrayRef>
where
<T as ArrowPrimitiveType>::Native: AsPrimitive<i256>,
<T as ArrowPrimitiveType>::Native: AsPrimitive<M>,
M: ArrowNativeTypeOp,
{
let mul: i256 = i256::from_i128(10_i128)
.checked_pow(scale as u32)
.ok_or_else(|| {
ArrowError::CastError(format!(
"Cannot cast to Decimal256({}, {}). The scale causes overflow.",
precision, scale
))
})?;
// `D::PREFIX` is formatted with `{}` (Display) so the type name is printed
// without surrounding quotes, e.g. `Decimal256(76, 76)`.
let mul: M = base.pow_checked(scale as u32).map_err(|_| {
ArrowError::CastError(format!(
"Cannot cast to {}({}, {}). The scale causes overflow.",
D::PREFIX,
precision,
scale,
))
})?;

unary::<T, _, Decimal256Type>(array, |v| v.as_().wrapping_mul(mul))
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
if cast_options.safe {
// Safe mode: a value whose product overflows is mapped to `None` (null).
let iter = array
.iter()
.map(|v| v.and_then(|v| v.as_().mul_checked(mul).ok()));
// SAFETY: the iterator is `array.iter()` followed by `map`, which yields
// exactly one item per element of `array`, so its reported length is exact.
let casted_array = unsafe { PrimitiveArray::<D>::from_trusted_len_iter(iter) };
casted_array
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
} else {
// Non-safe mode: the first overflowing product aborts the cast with an error.
try_unary::<T, _, D>(array, |v| v.as_().mul_checked(mul))
.and_then(|a| a.with_precision_and_scale(precision, scale))
.map(|a| Arc::new(a) as ArrayRef)
}
}

fn cast_floating_point_to_decimal128<T: ArrowNumericType>(
Expand Down Expand Up @@ -562,25 +564,33 @@ pub fn cast_with_options(
// cast data to decimal
match from_type {
// TODO now just support signed numeric to decimal, support decimal to numeric later
Int8 => cast_integer_to_decimal128(
Int8 => cast_integer_to_decimal::<_, Decimal128Type, _>(
as_primitive_array::<Int8Type>(array),
*precision,
*scale,
10_i128,
cast_options,
),
Int16 => cast_integer_to_decimal128(
Int16 => cast_integer_to_decimal::<_, Decimal128Type, _>(
as_primitive_array::<Int16Type>(array),
*precision,
*scale,
10_i128,
cast_options,
),
Int32 => cast_integer_to_decimal128(
Int32 => cast_integer_to_decimal::<_, Decimal128Type, _>(
as_primitive_array::<Int32Type>(array),
*precision,
*scale,
10_i128,
cast_options,
),
Int64 => cast_integer_to_decimal128(
Int64 => cast_integer_to_decimal::<_, Decimal128Type, _>(
as_primitive_array::<Int64Type>(array),
*precision,
*scale,
10_i128,
cast_options,
),
Float32 => cast_floating_point_to_decimal128(
as_primitive_array::<Float32Type>(array),
Expand All @@ -603,25 +613,33 @@ pub fn cast_with_options(
// cast data to decimal
match from_type {
// TODO now just support signed numeric to decimal, support decimal to numeric later
Int8 => cast_integer_to_decimal256(
Int8 => cast_integer_to_decimal::<_, Decimal256Type, _>(
as_primitive_array::<Int8Type>(array),
*precision,
*scale,
i256::from_i128(10_i128),
cast_options,
),
Int16 => cast_integer_to_decimal256(
Int16 => cast_integer_to_decimal::<_, Decimal256Type, _>(
as_primitive_array::<Int16Type>(array),
*precision,
*scale,
i256::from_i128(10_i128),
cast_options,
),
Int32 => cast_integer_to_decimal256(
Int32 => cast_integer_to_decimal::<_, Decimal256Type, _>(
as_primitive_array::<Int32Type>(array),
*precision,
*scale,
i256::from_i128(10_i128),
cast_options,
),
Int64 => cast_integer_to_decimal256(
Int64 => cast_integer_to_decimal::<_, Decimal256Type, _>(
as_primitive_array::<Int64Type>(array),
*precision,
*scale,
i256::from_i128(10_i128),
cast_options,
),
Float32 => cast_floating_point_to_decimal256(
as_primitive_array::<Float32Type>(array),
Expand Down Expand Up @@ -6049,4 +6067,44 @@ mod tests {
]
);
}

// Casting `i64::MAX` to `Decimal128(38, 30)` must yield a null in safe mode
// and an error in non-safe mode.
#[test]
fn test_cast_numeric_to_decimal128_overflow() {
    let input: ArrayRef = Arc::new(Int64Array::from(vec![i64::MAX]));
    let target_type = DataType::Decimal128(38, 30);

    // Safe cast: the call succeeds and the overflowing slot is null.
    let safe_result =
        cast_with_options(&input, &target_type, &CastOptions { safe: true });
    assert!(safe_result.is_ok());
    assert!(safe_result.unwrap().is_null(0));

    // Non-safe cast: the overflow is reported as an error.
    let strict_result =
        cast_with_options(&input, &target_type, &CastOptions { safe: false });
    assert!(strict_result.is_err());
}

// Casting `i64::MAX` to `Decimal256(76, 76)` must yield a null in safe mode
// and an error in non-safe mode.
#[test]
fn test_cast_numeric_to_decimal256_overflow() {
    let input: ArrayRef = Arc::new(Int64Array::from(vec![i64::MAX]));
    let target_type = DataType::Decimal256(76, 76);

    // Safe cast: the call succeeds and the overflowing slot is null.
    let safe_result =
        cast_with_options(&input, &target_type, &CastOptions { safe: true });
    assert!(safe_result.is_ok());
    assert!(safe_result.unwrap().is_null(0));

    // Non-safe cast: the overflow is reported as an error.
    let strict_result =
        cast_with_options(&input, &target_type, &CastOptions { safe: false });
    assert!(strict_result.is_err());
}
}
23 changes: 23 additions & 0 deletions arrow/src/datatypes/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use crate::error::{ArrowError, Result};
pub use arrow_array::ArrowPrimitiveType;
pub use arrow_buffer::{i256, ArrowNativeType, ToByteSlice};
use half::f16;
use num::complex::ComplexFloat;

/// Trait for [`ArrowNativeType`] that adds checked and unchecked arithmetic operations,
/// and totally ordered comparison operations
Expand Down Expand Up @@ -68,6 +69,10 @@ pub trait ArrowNativeTypeOp: ArrowNativeType {

fn neg_wrapping(self) -> Self;

/// Raises `self` to the power `exp`.
///
/// The `native_type_op!` implementations return an error when the result
/// overflows; the floating-point implementations always return `Ok`.
fn pow_checked(self, exp: u32) -> Result<Self>;

/// Raises `self` to the power `exp` without reporting overflow.
///
/// The `native_type_op!` implementations delegate to `wrapping_pow`; the
/// floating-point implementations delegate to `powi`.
fn pow_wrapping(self, exp: u32) -> Self;

fn is_zero(self) -> bool;

fn is_eq(self, rhs: Self) -> bool;
Expand Down Expand Up @@ -171,6 +176,16 @@ macro_rules! native_type_op {
})
}

/// Raises `self` to the power `exp`, failing instead of overflowing.
///
/// # Errors
///
/// Returns `ArrowError::ComputeError` naming both the base and the exponent
/// when `checked_pow` reports an overflow.
fn pow_checked(self, exp: u32) -> Result<Self> {
    self.checked_pow(exp).ok_or_else(|| {
        // Include the exponent: the base alone does not identify the failing operation.
        ArrowError::ComputeError(format!(
            "Overflow happened on: {:?} ^ {:?}",
            self, exp
        ))
    })
}

fn pow_wrapping(self, exp: u32) -> Self {
self.wrapping_pow(exp)
}

fn neg_wrapping(self) -> Self {
self.wrapping_neg()
}
Expand Down Expand Up @@ -279,6 +294,14 @@ macro_rules! native_type_float_op {
-self
}

fn pow_checked(self, exp: u32) -> Result<Self> {
Ok(self.powi(exp as i32))
}

/// Raises `self` to the power `exp` using `powi`.
fn pow_wrapping(self, exp: u32) -> Self {
    // `powi` takes an `i32` exponent.
    // NOTE(review): this `as` cast wraps for exponents above `i32::MAX`.
    let exponent = exp as i32;
    self.powi(exponent)
}

fn is_zero(self) -> bool {
self == $zero
}
Expand Down

0 comments on commit 766f69f

Please sign in to comment.