Skip to content

Commit

Permalink
Combine cast_integer_to_decimal functions of decimal128 and decimal256
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Nov 3, 2022
1 parent 839abb1 commit 5b356ad
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 46 deletions.
11 changes: 11 additions & 0 deletions arrow-array/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,9 @@ pub trait DecimalType:
/// Formats the decimal value with the provided precision and scale
fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String;

/// Formats the decimal type with the provided precision and scale
fn format_decimal_type(precision: u8, scale: u8) -> String;

/// Validates that `value` contains no more than `precision` decimal digits
fn validate_decimal_precision(
value: Self::Native,
Expand All @@ -521,6 +524,10 @@ impl DecimalType for Decimal128Type {
format_decimal_str(&value.to_string(), precision as usize, scale as usize)
}

/// Renders the Decimal128 type name together with its parameters,
/// e.g. `Decimal128(38, 10)`.
///
/// `precision` and `scale` are written as a comma-separated pair so the
/// text cannot be mistaken for a single fractional number.
fn format_decimal_type(precision: u8, scale: u8) -> String {
    format!("Decimal128({}, {})", precision, scale)
}

fn validate_decimal_precision(num: i128, precision: u8) -> Result<(), ArrowError> {
validate_decimal_precision(num, precision)
}
Expand Down Expand Up @@ -548,6 +555,10 @@ impl DecimalType for Decimal256Type {
format_decimal_str(&value.to_string(), precision as usize, scale as usize)
}

/// Renders the Decimal256 type name together with its parameters,
/// e.g. `Decimal256(76, 10)`.
///
/// `precision` and `scale` are written as a comma-separated pair so the
/// text cannot be mistaken for a single fractional number.
fn format_decimal_type(precision: u8, scale: u8) -> String {
    format!("Decimal256({}, {})", precision, scale)
}

fn validate_decimal_precision(num: i256, precision: u8) -> Result<(), ArrowError> {
validate_decimal256_precision_with_lt_bytes(&num.to_le_bytes(), precision)
}
Expand Down
78 changes: 32 additions & 46 deletions arrow/src/compute/kernels/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,60 +306,38 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
cast_with_options(array, to_type, &DEFAULT_CAST_OPTIONS)
}

/// Casts an integer array to a Decimal128 array with the requested
/// `precision` and `scale`.
///
/// Every input value is converted to `i128` and multiplied by `10^scale`.
/// When `cast_options.safe` is set, a value whose multiplication overflows
/// becomes null; otherwise the first overflow is returned as an error.
/// Applying `precision` and `scale` to the result may also fail, and that
/// error is propagated in both modes.
fn cast_integer_to_decimal128<T: ArrowNumericType>(
    array: &PrimitiveArray<T>,
    precision: u8,
    scale: u8,
    cast_options: &CastOptions,
) -> Result<ArrayRef>
where
    <T as ArrowPrimitiveType>::Native: AsPrimitive<i128>,
{
    // Factor that shifts an integer into the decimal's unscaled representation.
    let multiplier: i128 = 10_i128.pow(u32::from(scale));

    let decimals = if cast_options.safe {
        // Overflowing products are turned into nulls instead of errors.
        let scaled = array
            .iter()
            .map(|item| item.and_then(|value| value.as_().mul_checked(multiplier).ok()));
        // SAFETY: `map` yields exactly one item per item of `array.iter()`;
        // assumes that iterator reports an exact length — TODO confirm.
        let unscaled = unsafe { Decimal128Array::from_trusted_len_iter(scaled) };
        unscaled.with_precision_and_scale(precision, scale)
    } else {
        // The first overflowing product aborts the cast with an error.
        try_unary::<T, _, Decimal128Type>(array, |value| {
            value.as_().mul_checked(multiplier)
        })
        .and_then(|unscaled| unscaled.with_precision_and_scale(precision, scale))
    };

    decimals.map(|decimal| Arc::new(decimal) as ArrayRef)
}

fn cast_integer_to_decimal256<T: ArrowNumericType>(
fn cast_integer_to_decimal<
T: ArrowNumericType,
D: DecimalType + ArrowPrimitiveType<Native = M>,
M,
>(
array: &PrimitiveArray<T>,
precision: u8,
scale: u8,
base: M,
cast_options: &CastOptions,
) -> Result<ArrayRef>
where
<T as ArrowPrimitiveType>::Native: AsPrimitive<i256>,
<T as ArrowPrimitiveType>::Native: AsPrimitive<M>,
M: ArrowNativeTypeOp,
{
let mul: i256 = i256::from_i128(10_i128)
.checked_pow(scale as u32)
.ok_or_else(|| {
ArrowError::CastError(format!(
"Cannot cast to Decimal256({}, {}). The scale causes overflow.",
precision, scale
))
})?;
let mul: M = base.pow_checked(scale as u32).map_err(|_| {
ArrowError::CastError(format!(
"Cannot cast to {:?}. The scale causes overflow.",
D::format_decimal_type(precision, scale),
))
})?;

if cast_options.safe {
let iter = array
.iter()
.map(|v| v.and_then(|v| v.as_().mul_checked(mul).ok()));
let casted_array = unsafe { Decimal256Array::from_trusted_len_iter(iter) };
let casted_array = unsafe { PrimitiveArray::<D>::from_trusted_len_iter(iter) };
casted_array
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
} else {
try_unary::<T, _, Decimal256Type>(array, |v| v.as_().mul_checked(mul))
try_unary::<T, _, D>(array, |v| v.as_().mul_checked(mul))
.and_then(|a| a.with_precision_and_scale(precision, scale))
.map(|a| Arc::new(a) as ArrayRef)
}
Expand Down Expand Up @@ -579,28 +557,32 @@ pub fn cast_with_options(
// cast data to decimal
match from_type {
// TODO now just support signed numeric to decimal, support decimal to numeric later
Int8 => cast_integer_to_decimal128(
Int8 => cast_integer_to_decimal::<_, Decimal128Type, _>(
as_primitive_array::<Int8Type>(array),
*precision,
*scale,
10_i128,
cast_options,
),
Int16 => cast_integer_to_decimal128(
Int16 => cast_integer_to_decimal::<_, Decimal128Type, _>(
as_primitive_array::<Int16Type>(array),
*precision,
*scale,
10_i128,
cast_options,
),
Int32 => cast_integer_to_decimal128(
Int32 => cast_integer_to_decimal::<_, Decimal128Type, _>(
as_primitive_array::<Int32Type>(array),
*precision,
*scale,
10_i128,
cast_options,
),
Int64 => cast_integer_to_decimal128(
Int64 => cast_integer_to_decimal::<_, Decimal128Type, _>(
as_primitive_array::<Int64Type>(array),
*precision,
*scale,
10_i128,
cast_options,
),
Float32 => cast_floating_point_to_decimal128(
Expand All @@ -624,28 +606,32 @@ pub fn cast_with_options(
// cast data to decimal
match from_type {
// TODO now just support signed numeric to decimal, support decimal to numeric later
Int8 => cast_integer_to_decimal256(
Int8 => cast_integer_to_decimal::<_, Decimal256Type, _>(
as_primitive_array::<Int8Type>(array),
*precision,
*scale,
i256::from_i128(10_i128),
cast_options,
),
Int16 => cast_integer_to_decimal256(
Int16 => cast_integer_to_decimal::<_, Decimal256Type, _>(
as_primitive_array::<Int16Type>(array),
*precision,
*scale,
i256::from_i128(10_i128),
cast_options,
),
Int32 => cast_integer_to_decimal256(
Int32 => cast_integer_to_decimal::<_, Decimal256Type, _>(
as_primitive_array::<Int32Type>(array),
*precision,
*scale,
i256::from_i128(10_i128),
cast_options,
),
Int64 => cast_integer_to_decimal256(
Int64 => cast_integer_to_decimal::<_, Decimal256Type, _>(
as_primitive_array::<Int64Type>(array),
*precision,
*scale,
i256::from_i128(10_i128),
cast_options,
),
Float32 => cast_floating_point_to_decimal256(
Expand Down
23 changes: 23 additions & 0 deletions arrow/src/datatypes/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use crate::error::{ArrowError, Result};
pub use arrow_array::ArrowPrimitiveType;
pub use arrow_buffer::{i256, ArrowNativeType, ToByteSlice};
use half::f16;
use num::complex::ComplexFloat;

/// Trait for [`ArrowNativeType`] that adds checked and unchecked arithmetic operations,
/// and totally ordered comparison operations
Expand Down Expand Up @@ -68,6 +69,10 @@ pub trait ArrowNativeTypeOp: ArrowNativeType {

fn neg_wrapping(self) -> Self;

fn pow_checked(self, exp: u32) -> Result<Self>;

fn pow_wrapping(self, exp: u32) -> Self;

fn is_zero(self) -> bool;

fn is_eq(self, rhs: Self) -> bool;
Expand Down Expand Up @@ -171,6 +176,16 @@ macro_rules! native_type_op {
})
}

/// Raises `self` to the power `exp`, returning a `ComputeError` instead
/// of wrapping when the result does not fit in the type.
///
/// The error message names both the base and the exponent so the failing
/// operation can be identified from the message alone.
fn pow_checked(self, exp: u32) -> Result<Self> {
    self.checked_pow(exp).ok_or_else(|| {
        ArrowError::ComputeError(format!(
            "Overflow happened on: {:?} ^ {:?}",
            self, exp
        ))
    })
}

/// Raises `self` to the power `exp`, wrapping around at the boundary of
/// the type on overflow.
fn pow_wrapping(self, exp: u32) -> Self {
    self.wrapping_pow(exp)
}

fn neg_wrapping(self) -> Self {
self.wrapping_neg()
}
Expand Down Expand Up @@ -279,6 +294,14 @@ macro_rules! native_type_float_op {
-self
}

/// Raises `self` to the integer power `exp` via `powi`.
///
/// This implementation never reports an error: it always returns `Ok`.
///
/// NOTE(review): `exp as i32` wraps to a negative value when `exp`
/// exceeds `i32::MAX`, so such exponents are passed to `powi` as
/// negative powers — confirm callers never pass exponents that large.
fn pow_checked(self, exp: u32) -> Result<Self> {
    Ok(self.powi(exp as i32))
}

/// Raises `self` to the integer power `exp` via `powi`.
///
/// NOTE(review): `exp as i32` wraps to a negative value when `exp`
/// exceeds `i32::MAX`, so such exponents are passed to `powi` as
/// negative powers — confirm callers never pass exponents that large.
fn pow_wrapping(self, exp: u32) -> Self {
    self.powi(exp as i32)
}

fn is_zero(self) -> bool {
self == $zero
}
Expand Down

0 comments on commit 5b356ad

Please sign in to comment.