Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check overflow when casting integer to decimal #3009

Merged
merged 6 commits into from
Nov 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions arrow-array/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,9 @@ pub trait DecimalType:
const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType;
const DEFAULT_TYPE: DataType;

/// "Decimal128" or "Decimal256", for use in error messages
const PREFIX: &'static str;

/// Formats the decimal value with the provided precision and scale
fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String;

Expand All @@ -516,6 +519,7 @@ impl DecimalType for Decimal128Type {
const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = DataType::Decimal128;
const DEFAULT_TYPE: DataType =
DataType::Decimal128(DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE);
const PREFIX: &'static str = "Decimal128";

fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String {
format_decimal_str(&value.to_string(), precision as usize, scale as usize)
Expand Down Expand Up @@ -543,6 +547,7 @@ impl DecimalType for Decimal256Type {
const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = DataType::Decimal256;
const DEFAULT_TYPE: DataType =
DataType::Decimal256(DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE);
const PREFIX: &'static str = "Decimal256";

fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String {
format_decimal_str(&value.to_string(), precision as usize, scale as usize)
Expand Down
130 changes: 94 additions & 36 deletions arrow/src/compute/kernels/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,41 +309,43 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
cast_with_options(array, to_type, &DEFAULT_CAST_OPTIONS)
}

fn cast_integer_to_decimal128<T: ArrowNumericType>(
array: &PrimitiveArray<T>,
precision: u8,
scale: u8,
) -> Result<ArrayRef>
where
<T as ArrowPrimitiveType>::Native: AsPrimitive<i128>,
{
let mul: i128 = 10_i128.pow(scale as u32);

unary::<T, _, Decimal128Type>(array, |v| v.as_() * mul)
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
}

fn cast_integer_to_decimal256<T: ArrowNumericType>(
fn cast_integer_to_decimal<
T: ArrowNumericType,
D: DecimalType + ArrowPrimitiveType<Native = M>,
M,
>(
array: &PrimitiveArray<T>,
precision: u8,
scale: u8,
base: M,
cast_options: &CastOptions,
) -> Result<ArrayRef>
where
<T as ArrowPrimitiveType>::Native: AsPrimitive<i256>,
<T as ArrowPrimitiveType>::Native: AsPrimitive<M>,
M: ArrowNativeTypeOp,
{
let mul: i256 = i256::from_i128(10_i128)
.checked_pow(scale as u32)
.ok_or_else(|| {
ArrowError::CastError(format!(
"Cannot cast to Decimal256({}, {}). The scale causes overflow.",
precision, scale
))
})?;
let mul: M = base.pow_checked(scale as u32).map_err(|_| {
ArrowError::CastError(format!(
"Cannot cast to {:?}({}, {}). The scale causes overflow.",
D::PREFIX,
precision,
scale,
))
})?;

unary::<T, _, Decimal256Type>(array, |v| v.as_().wrapping_mul(mul))
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
if cast_options.safe {
let iter = array
.iter()
.map(|v| v.and_then(|v| v.as_().mul_checked(mul).ok()));
let casted_array = unsafe { PrimitiveArray::<D>::from_trusted_len_iter(iter) };
casted_array
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
} else {
try_unary::<T, _, D>(array, |v| v.as_().mul_checked(mul))
.and_then(|a| a.with_precision_and_scale(precision, scale))
.map(|a| Arc::new(a) as ArrayRef)
}
}

fn cast_floating_point_to_decimal128<T: ArrowNumericType>(
Expand Down Expand Up @@ -562,25 +564,33 @@ pub fn cast_with_options(
// cast data to decimal
match from_type {
// TODO now just support signed numeric to decimal, support decimal to numeric later
Int8 => cast_integer_to_decimal128(
Int8 => cast_integer_to_decimal::<_, Decimal128Type, _>(
as_primitive_array::<Int8Type>(array),
*precision,
*scale,
10_i128,
cast_options,
),
Int16 => cast_integer_to_decimal128(
Int16 => cast_integer_to_decimal::<_, Decimal128Type, _>(
as_primitive_array::<Int16Type>(array),
*precision,
*scale,
10_i128,
cast_options,
),
Int32 => cast_integer_to_decimal128(
Int32 => cast_integer_to_decimal::<_, Decimal128Type, _>(
as_primitive_array::<Int32Type>(array),
*precision,
*scale,
10_i128,
cast_options,
),
Int64 => cast_integer_to_decimal128(
Int64 => cast_integer_to_decimal::<_, Decimal128Type, _>(
as_primitive_array::<Int64Type>(array),
*precision,
*scale,
10_i128,
cast_options,
),
Float32 => cast_floating_point_to_decimal128(
as_primitive_array::<Float32Type>(array),
Expand All @@ -603,25 +613,33 @@ pub fn cast_with_options(
// cast data to decimal
match from_type {
// TODO now just support signed numeric to decimal, support decimal to numeric later
Int8 => cast_integer_to_decimal256(
Int8 => cast_integer_to_decimal::<_, Decimal256Type, _>(
as_primitive_array::<Int8Type>(array),
*precision,
*scale,
i256::from_i128(10_i128),
cast_options,
),
Int16 => cast_integer_to_decimal256(
Int16 => cast_integer_to_decimal::<_, Decimal256Type, _>(
as_primitive_array::<Int16Type>(array),
*precision,
*scale,
i256::from_i128(10_i128),
cast_options,
),
Int32 => cast_integer_to_decimal256(
Int32 => cast_integer_to_decimal::<_, Decimal256Type, _>(
as_primitive_array::<Int32Type>(array),
*precision,
*scale,
i256::from_i128(10_i128),
cast_options,
),
Int64 => cast_integer_to_decimal256(
Int64 => cast_integer_to_decimal::<_, Decimal256Type, _>(
as_primitive_array::<Int64Type>(array),
*precision,
*scale,
i256::from_i128(10_i128),
cast_options,
),
Float32 => cast_floating_point_to_decimal256(
as_primitive_array::<Float32Type>(array),
Expand Down Expand Up @@ -6049,4 +6067,44 @@ mod tests {
]
);
}

#[test]
fn test_cast_numeric_to_decimal128_overflow() {
let array = Int64Array::from(vec![i64::MAX]);
let array = Arc::new(array) as ArrayRef;
let casted_array = cast_with_options(
&array,
&DataType::Decimal128(38, 30),
&CastOptions { safe: true },
);
assert!(casted_array.is_ok());
assert!(casted_array.unwrap().is_null(0));

let casted_array = cast_with_options(
&array,
&DataType::Decimal128(38, 30),
&CastOptions { safe: false },
);
assert!(casted_array.is_err());
}

#[test]
fn test_cast_numeric_to_decimal256_overflow() {
let array = Int64Array::from(vec![i64::MAX]);
let array = Arc::new(array) as ArrayRef;
let casted_array = cast_with_options(
&array,
&DataType::Decimal256(76, 76),
&CastOptions { safe: true },
);
assert!(casted_array.is_ok());
assert!(casted_array.unwrap().is_null(0));

let casted_array = cast_with_options(
&array,
&DataType::Decimal256(76, 76),
&CastOptions { safe: false },
);
assert!(casted_array.is_err());
}
}
23 changes: 23 additions & 0 deletions arrow/src/datatypes/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use crate::error::{ArrowError, Result};
pub use arrow_array::ArrowPrimitiveType;
pub use arrow_buffer::{i256, ArrowNativeType, ToByteSlice};
use half::f16;
use num::complex::ComplexFloat;

/// Trait for [`ArrowNativeType`] that adds checked and unchecked arithmetic operations,
/// and totally ordered comparison operations
Expand Down Expand Up @@ -68,6 +69,10 @@ pub trait ArrowNativeTypeOp: ArrowNativeType {

fn neg_wrapping(self) -> Self;

fn pow_checked(self, exp: u32) -> Result<Self>;

fn pow_wrapping(self, exp: u32) -> Self;

fn is_zero(self) -> bool;

fn is_eq(self, rhs: Self) -> bool;
Expand Down Expand Up @@ -171,6 +176,16 @@ macro_rules! native_type_op {
})
}

fn pow_checked(self, exp: u32) -> Result<Self> {
self.checked_pow(exp).ok_or_else(|| {
ArrowError::ComputeError(format!("Overflow happened on: {:?}", self))
})
}

fn pow_wrapping(self, exp: u32) -> Self {
self.wrapping_pow(exp)
}

fn neg_wrapping(self) -> Self {
self.wrapping_neg()
}
Expand Down Expand Up @@ -279,6 +294,14 @@ macro_rules! native_type_float_op {
-self
}

fn pow_checked(self, exp: u32) -> Result<Self> {
Ok(self.powi(exp as i32))
}

fn pow_wrapping(self, exp: u32) -> Self {
self.powi(exp as i32)
}

fn is_zero(self) -> bool {
self == $zero
}
Expand Down