Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master' into move-arrow-native…
Browse files Browse the repository at this point in the history
…-type-op
  • Loading branch information
tustvold committed Nov 4, 2022
2 parents 2a3e2f2 + 282e7b4 commit fa89d2b
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 53 deletions.
23 changes: 23 additions & 0 deletions arrow-array/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
use arrow_buffer::{i256, ArrowNativeType};
use arrow_schema::ArrowError;
use half::f16;
use num::complex::ComplexFloat;

/// Trait for [`ArrowNativeType`] that adds checked and unchecked arithmetic operations,
/// and totally ordered comparison operations
Expand Down Expand Up @@ -67,6 +68,10 @@ pub trait ArrowNativeTypeOp: ArrowNativeType {

fn neg_wrapping(self) -> Self;

/// Raises `self` to the power `exp`, reporting failure through the returned `Result`.
///
/// The integer implementations generated by `native_type_op!` return
/// [`ArrowError::ComputeError`] when the result overflows; the floating-point
/// implementations generated by `native_type_float_op!` always return `Ok`.
fn pow_checked(self, exp: u32) -> Result<Self>;

/// Raises `self` to the power `exp` without reporting overflow.
///
/// The integer implementations generated by `native_type_op!` delegate to
/// `wrapping_pow`; the floating-point implementations delegate to `powi`.
fn pow_wrapping(self, exp: u32) -> Self;

fn is_zero(self) -> bool;

fn is_eq(self, rhs: Self) -> bool;
Expand Down Expand Up @@ -170,6 +175,16 @@ macro_rules! native_type_op {
})
}

/// Raises `self` to the power `exp`, failing instead of wrapping on overflow.
///
/// # Errors
///
/// Returns [`ArrowError::ComputeError`] when `checked_pow` yields `None`.
fn pow_checked(self, exp: u32) -> Result<Self> {
    match self.checked_pow(exp) {
        Some(value) => Ok(value),
        None => Err(ArrowError::ComputeError(format!(
            "Overflow happened on: {:?}",
            self
        ))),
    }
}

fn pow_wrapping(self, exp: u32) -> Self {
self.wrapping_pow(exp)
}

fn neg_wrapping(self) -> Self {
self.wrapping_neg()
}
Expand Down Expand Up @@ -278,6 +293,14 @@ macro_rules! native_type_float_op {
-self
}

/// Raises `self` to the integer power `exp`.
///
/// This implementation never returns `Err`; the `Result` exists only to satisfy
/// the trait signature.
fn pow_checked(self, exp: u32) -> Result<Self> {
    // `powi` takes an `i32`. A plain `exp as i32` wraps exponents above
    // `i32::MAX` to negative values, which would compute a reciprocal power.
    const MAX_POWI_EXP: u32 = i32::MAX as u32;

    let value = if exp <= MAX_POWI_EXP {
        self.powi(exp as i32)
    } else {
        // x^exp = (x^(exp / 2))^2 * x^(exp % 2); `exp / 2` always fits in `i32`.
        let half = self.powi((exp / 2) as i32);
        half * half * self.powi((exp % 2) as i32)
    };
    Ok(value)
}

/// Raises `self` to the integer power `exp`.
fn pow_wrapping(self, exp: u32) -> Self {
    // `powi` takes an `i32`. A plain `exp as i32` wraps exponents above
    // `i32::MAX` to negative values, which would compute a reciprocal power.
    const MAX_POWI_EXP: u32 = i32::MAX as u32;

    if exp <= MAX_POWI_EXP {
        self.powi(exp as i32)
    } else {
        // x^exp = (x^(exp / 2))^2 * x^(exp % 2); `exp / 2` always fits in `i32`.
        let half = self.powi((exp / 2) as i32);
        half * half * self.powi((exp % 2) as i32)
    }
}

fn is_zero(self) -> bool {
self == $zero
}
Expand Down
5 changes: 5 additions & 0 deletions arrow-array/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,9 @@ pub trait DecimalType:
const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType;
const DEFAULT_TYPE: DataType;

/// "Decimal128" or "Decimal256", for use in error messages
const PREFIX: &'static str;

/// Formats the decimal value with the provided precision and scale
fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String;

Expand All @@ -516,6 +519,7 @@ impl DecimalType for Decimal128Type {
const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = DataType::Decimal128;
const DEFAULT_TYPE: DataType =
DataType::Decimal128(DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE);
const PREFIX: &'static str = "Decimal128";

fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String {
format_decimal_str(&value.to_string(), precision as usize, scale as usize)
Expand Down Expand Up @@ -543,6 +547,7 @@ impl DecimalType for Decimal256Type {
const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = DataType::Decimal256;
const DEFAULT_TYPE: DataType =
DataType::Decimal256(DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE);
const PREFIX: &'static str = "Decimal256";

fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String {
format_decimal_str(&value.to_string(), precision as usize, scale as usize)
Expand Down
130 changes: 94 additions & 36 deletions arrow/src/compute/kernels/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,41 +309,43 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
cast_with_options(array, to_type, &DEFAULT_CAST_OPTIONS)
}

fn cast_integer_to_decimal128<T: ArrowNumericType>(
array: &PrimitiveArray<T>,
precision: u8,
scale: u8,
) -> Result<ArrayRef>
where
<T as ArrowPrimitiveType>::Native: AsPrimitive<i128>,
{
let mul: i128 = 10_i128.pow(scale as u32);

unary::<T, _, Decimal128Type>(array, |v| v.as_() * mul)
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
}

fn cast_integer_to_decimal256<T: ArrowNumericType>(
fn cast_integer_to_decimal<
T: ArrowNumericType,
D: DecimalType + ArrowPrimitiveType<Native = M>,
M,
>(
array: &PrimitiveArray<T>,
precision: u8,
scale: u8,
base: M,
cast_options: &CastOptions,
) -> Result<ArrayRef>
where
<T as ArrowPrimitiveType>::Native: AsPrimitive<i256>,
<T as ArrowPrimitiveType>::Native: AsPrimitive<M>,
M: ArrowNativeTypeOp,
{
let mul: i256 = i256::from_i128(10_i128)
.checked_pow(scale as u32)
.ok_or_else(|| {
ArrowError::CastError(format!(
"Cannot cast to Decimal256({}, {}). The scale causes overflow.",
precision, scale
))
})?;
// Multiplier `base^scale` in the decimal's native type; if it does not fit in
// `M`, the cast is rejected before any value is touched.
let mul: M = base.pow_checked(scale as u32).map_err(|_| {
    ArrowError::CastError(format!(
        // `{}` rather than `{:?}`: Debug-formatting a `&str` wraps it in quotes,
        // which would render as `"Decimal256"(76, 76)`.
        "Cannot cast to {}({}, {}). The scale causes overflow.",
        D::PREFIX,
        precision,
        scale,
    ))
})?;

unary::<T, _, Decimal256Type>(array, |v| v.as_().wrapping_mul(mul))
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
if cast_options.safe {
// Safe mode: a value whose product with `mul` fails `mul_checked` is mapped to
// `None` via `.ok()`, so it ends up null; values that are already null stay null.
let iter = array
.iter()
.map(|v| v.and_then(|v| v.as_().mul_checked(mul).ok()));
// SAFETY(review): `from_trusted_len_iter` presumably requires the iterator to
// report its exact length. This one is a 1:1 `map` over `array.iter()`;
// confirm against the function's documented contract.
let casted_array = unsafe { PrimitiveArray::<D>::from_trusted_len_iter(iter) };
casted_array
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
} else {
// Non-safe mode: an error from `mul_checked` is propagated as the result of
// the cast instead of being turned into a null.
try_unary::<T, _, D>(array, |v| v.as_().mul_checked(mul))
.and_then(|a| a.with_precision_and_scale(precision, scale))
.map(|a| Arc::new(a) as ArrayRef)
}
}

fn cast_floating_point_to_decimal128<T: ArrowNumericType>(
Expand Down Expand Up @@ -562,25 +564,33 @@ pub fn cast_with_options(
// cast data to decimal
match from_type {
// TODO now just support signed numeric to decimal, support decimal to numeric later
Int8 => cast_integer_to_decimal128(
Int8 => cast_integer_to_decimal::<_, Decimal128Type, _>(
as_primitive_array::<Int8Type>(array),
*precision,
*scale,
10_i128,
cast_options,
),
Int16 => cast_integer_to_decimal128(
Int16 => cast_integer_to_decimal::<_, Decimal128Type, _>(
as_primitive_array::<Int16Type>(array),
*precision,
*scale,
10_i128,
cast_options,
),
Int32 => cast_integer_to_decimal128(
Int32 => cast_integer_to_decimal::<_, Decimal128Type, _>(
as_primitive_array::<Int32Type>(array),
*precision,
*scale,
10_i128,
cast_options,
),
Int64 => cast_integer_to_decimal128(
Int64 => cast_integer_to_decimal::<_, Decimal128Type, _>(
as_primitive_array::<Int64Type>(array),
*precision,
*scale,
10_i128,
cast_options,
),
Float32 => cast_floating_point_to_decimal128(
as_primitive_array::<Float32Type>(array),
Expand All @@ -603,25 +613,33 @@ pub fn cast_with_options(
// cast data to decimal
match from_type {
// TODO now just support signed numeric to decimal, support decimal to numeric later
Int8 => cast_integer_to_decimal256(
Int8 => cast_integer_to_decimal::<_, Decimal256Type, _>(
as_primitive_array::<Int8Type>(array),
*precision,
*scale,
i256::from_i128(10_i128),
cast_options,
),
Int16 => cast_integer_to_decimal256(
Int16 => cast_integer_to_decimal::<_, Decimal256Type, _>(
as_primitive_array::<Int16Type>(array),
*precision,
*scale,
i256::from_i128(10_i128),
cast_options,
),
Int32 => cast_integer_to_decimal256(
Int32 => cast_integer_to_decimal::<_, Decimal256Type, _>(
as_primitive_array::<Int32Type>(array),
*precision,
*scale,
i256::from_i128(10_i128),
cast_options,
),
Int64 => cast_integer_to_decimal256(
Int64 => cast_integer_to_decimal::<_, Decimal256Type, _>(
as_primitive_array::<Int64Type>(array),
*precision,
*scale,
i256::from_i128(10_i128),
cast_options,
),
Float32 => cast_floating_point_to_decimal256(
as_primitive_array::<Float32Type>(array),
Expand Down Expand Up @@ -6049,4 +6067,44 @@ mod tests {
]
);
}

/// Casting `i64::MAX` to `Decimal128(38, 30)` yields a null in safe mode and an
/// error in non-safe mode.
#[test]
fn test_cast_numeric_to_decimal128_overflow() {
    let input: ArrayRef = Arc::new(Int64Array::from(vec![i64::MAX]));
    let target = DataType::Decimal128(38, 30);

    // Safe mode: the cast itself succeeds and the slot is null.
    let lenient = cast_with_options(&input, &target, &CastOptions { safe: true });
    assert!(lenient.is_ok());
    assert!(lenient.unwrap().is_null(0));

    // Non-safe mode: the cast reports an error.
    let strict = cast_with_options(&input, &target, &CastOptions { safe: false });
    assert!(strict.is_err());
}

/// Casting `i64::MAX` to `Decimal256(76, 76)` yields a null in safe mode and an
/// error in non-safe mode.
#[test]
fn test_cast_numeric_to_decimal256_overflow() {
    let input: ArrayRef = Arc::new(Int64Array::from(vec![i64::MAX]));
    let target = DataType::Decimal256(76, 76);

    // Safe mode: the cast itself succeeds and the slot is null.
    let lenient = cast_with_options(&input, &target, &CastOptions { safe: true });
    assert!(lenient.is_ok());
    assert!(lenient.unwrap().is_null(0));

    // Non-safe mode: the cast reports an error.
    let strict = cast_with_options(&input, &target, &CastOptions { safe: false });
    assert!(strict.is_err());
}
}
88 changes: 71 additions & 17 deletions arrow/src/compute/kernels/temporal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

//! Defines temporal kernels for time and date related functions.
use arrow_array::downcast_dictionary_array;
use chrono::{DateTime, Datelike, NaiveDateTime, NaiveTime, Offset, Timelike};
use std::sync::Arc;

use crate::array::*;
use crate::datatypes::*;
Expand Down Expand Up @@ -180,21 +182,74 @@ where
T: ArrowTemporalType + ArrowNumericType,
i64: From<T::Native>,
{
hour_generic::<T, _>(array)
hour_internal::<T, _>(array, array.data_type())
}

/// Extracts the hours of a given temporal array as an array of integers within
/// the range of [0, 23].
pub fn hour_generic<T, A: ArrayAccessor<Item = T::Native>>(array: A) -> Result<Int32Array>
where
T: ArrowTemporalType + ArrowNumericType,
i64: From<T::Native>,
{
/// Extracts the hours of a given array as an array of integers within
/// the range of [0, 23]. If the given array isn't temporal primitive or dictionary array,
/// an `Err` will be returned.
pub fn hour_dyn(array: &dyn Array) -> Result<ArrayRef> {
match array.data_type().clone() {
DataType::Dictionary(_, value_type) => {
hour_internal::<T, A>(array, value_type.as_ref())
DataType::Dictionary(_, _) => {
downcast_dictionary_array!(
array => {
let hour_values = hour_dyn(array.values())?;
Ok(Arc::new(array.with_values(&hour_values)))
}
dt => return_compute_error_with!("hour does not support", dt),
)
}
DataType::Time32(TimeUnit::Second) => {
let array = as_primitive_array::<Time32SecondType>(array);
hour_internal::<Time32SecondType, _>(array, array.data_type())
.map(|a| Arc::new(a) as ArrayRef)
}
// The arm downcasts to `Time32MillisecondType`, so it must match the
// millisecond unit; matching `Microsecond` here sent millisecond Time32 arrays
// to the "hour does not support" fallback.
DataType::Time32(TimeUnit::Millisecond) => {
    let array = as_primitive_array::<Time32MillisecondType>(array);
    hour_internal::<Time32MillisecondType, _>(array, array.data_type())
        .map(|a| Arc::new(a) as ArrayRef)
}
DataType::Time64(TimeUnit::Microsecond) => {
let array = as_primitive_array::<Time64MicrosecondType>(array);
hour_internal::<Time64MicrosecondType, _>(array, array.data_type())
.map(|a| Arc::new(a) as ArrayRef)
}
DataType::Time64(TimeUnit::Nanosecond) => {
let array = as_primitive_array::<Time64NanosecondType>(array);
hour_internal::<Time64NanosecondType, _>(array, array.data_type())
.map(|a| Arc::new(a) as ArrayRef)
}
dt => hour_internal::<T, A>(array, &dt),
DataType::Date32 => {
let array = as_primitive_array::<Date32Type>(array);
hour_internal::<Date32Type, _>(array, array.data_type())
.map(|a| Arc::new(a) as ArrayRef)
}
DataType::Date64 => {
let array = as_primitive_array::<Date64Type>(array);
hour_internal::<Date64Type, _>(array, array.data_type())
.map(|a| Arc::new(a) as ArrayRef)
}
DataType::Timestamp(TimeUnit::Second, _) => {
let array = as_primitive_array::<TimestampSecondType>(array);
hour_internal::<TimestampSecondType, _>(array, array.data_type())
.map(|a| Arc::new(a) as ArrayRef)
}
DataType::Timestamp(TimeUnit::Millisecond, _) => {
let array = as_primitive_array::<TimestampMillisecondType>(array);
hour_internal::<TimestampMillisecondType, _>(array, array.data_type())
.map(|a| Arc::new(a) as ArrayRef)
}
DataType::Timestamp(TimeUnit::Microsecond, _) => {
let array = as_primitive_array::<TimestampMicrosecondType>(array);
hour_internal::<TimestampMicrosecondType, _>(array, array.data_type())
.map(|a| Arc::new(a) as ArrayRef)
}
DataType::Timestamp(TimeUnit::Nanosecond, _) => {
let array = as_primitive_array::<TimestampNanosecondType>(array);
hour_internal::<TimestampNanosecondType, _>(array, array.data_type())
.map(|a| Arc::new(a) as ArrayRef)
}
dt => return_compute_error_with!("hour does not support", dt),
}
}

Expand Down Expand Up @@ -1197,13 +1252,12 @@ mod tests {
let keys = Int8Array::from_iter_values([0_i8, 0, 1, 2, 1]);
let dict = DictionaryArray::try_new(&keys, &a).unwrap();

let b = hour_generic::<TimestampSecondType, _>(
dict.downcast_dict::<TimestampSecondArray>().unwrap(),
)
.unwrap();
let b = hour_dyn(&dict).unwrap();

let expected = Int32Array::from(vec![11, 11, 21, 7, 21]);
assert_eq!(expected, b);
let expected_dict =
DictionaryArray::try_new(&keys, &Int32Array::from(vec![11, 21, 7])).unwrap();
let expected = Arc::new(expected_dict) as ArrayRef;
assert_eq!(&expected, &b);

let b = time_fraction_generic::<TimestampSecondType, _, _>(
dict.downcast_dict::<TimestampSecondArray>().unwrap(),
Expand Down
Loading

0 comments on commit fa89d2b

Please sign in to comment.