Skip to content

Commit

Permalink
Support for casting StringViewArray to DecimalArray (#6720)
Browse files Browse the repository at this point in the history
* [arrow-cast] Improve support for casting string view to numeric

Signed-off-by: Tai Le Manh <[email protected]>

* fix fmt

* resolve clippy warning

* [arrow-cast] Improve support for casting string view to numeric

Signed-off-by: Tai Le Manh <[email protected]>

---------

Signed-off-by: Tai Le Manh <[email protected]>
  • Loading branch information
tlm365 authored Nov 16, 2024
1 parent f955193 commit 8b33f96
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 17 deletions.
36 changes: 36 additions & 0 deletions arrow-array/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ mod list_view_array;

pub use list_view_array::*;

use crate::iterator::ArrayIter;

/// An array in the [arrow columnar format](https://arrow.apache.org/docs/format/Columnar.html)
pub trait Array: std::fmt::Debug + Send + Sync {
/// Returns the array as [`Any`] so that it can be
Expand Down Expand Up @@ -570,6 +572,40 @@ pub trait ArrayAccessor: Array {
unsafe fn value_unchecked(&self, index: usize) -> Self::Item;
}

/// A trait for Arrow String Arrays, currently three types are supported:
/// - `StringArray`
/// - `LargeStringArray`
/// - `StringViewArray`
///
/// This trait helps to abstract over the different types of string arrays
/// so that we don't need to duplicate the implementation for each type.
pub trait StringArrayType<'a>: ArrayAccessor<Item = &'a str> + Sized {
/// Returns true if all data within this string array is ASCII
fn is_ascii(&self) -> bool;

/// Constructs a new iterator
fn iter(&self) -> ArrayIter<Self>;
}

impl<'a, O: OffsetSizeTrait> StringArrayType<'a> for &'a GenericStringArray<O> {
fn is_ascii(&self) -> bool {
GenericStringArray::<O>::is_ascii(self)
}

fn iter(&self) -> ArrayIter<Self> {
GenericStringArray::<O>::iter(self)
}
}
impl<'a> StringArrayType<'a> for &'a StringViewArray {
fn is_ascii(&self) -> bool {
StringViewArray::is_ascii(self)
}

fn iter(&self) -> ArrayIter<Self> {
StringViewArray::iter(self)
}
}

impl PartialEq for dyn Array + '_ {
fn eq(&self, other: &Self) -> bool {
self.to_data().eq(&other.to_data())
Expand Down
68 changes: 58 additions & 10 deletions arrow-cast/src/cast/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,15 +323,16 @@ where
})
}

pub(crate) fn string_to_decimal_cast<T, Offset: OffsetSizeTrait>(
from: &GenericStringArray<Offset>,
pub(crate) fn generic_string_to_decimal_cast<'a, T, S>(
from: &'a S,
precision: u8,
scale: i8,
cast_options: &CastOptions,
) -> Result<PrimitiveArray<T>, ArrowError>
where
T: DecimalType,
T::Native: DecimalCast + ArrowNativeTypeOp,
&'a S: StringArrayType<'a>,
{
if cast_options.safe {
let iter = from.iter().map(|v| {
Expand Down Expand Up @@ -375,6 +376,37 @@ where
}
}

pub(crate) fn string_to_decimal_cast<T, Offset: OffsetSizeTrait>(
from: &GenericStringArray<Offset>,
precision: u8,
scale: i8,
cast_options: &CastOptions,
) -> Result<PrimitiveArray<T>, ArrowError>
where
T: DecimalType,
T::Native: DecimalCast + ArrowNativeTypeOp,
{
generic_string_to_decimal_cast::<T, GenericStringArray<Offset>>(
from,
precision,
scale,
cast_options,
)
}

pub(crate) fn string_view_to_decimal_cast<T>(
from: &StringViewArray,
precision: u8,
scale: i8,
cast_options: &CastOptions,
) -> Result<PrimitiveArray<T>, ArrowError>
where
T: DecimalType,
T::Native: DecimalCast + ArrowNativeTypeOp,
{
generic_string_to_decimal_cast::<T, StringViewArray>(from, precision, scale, cast_options)
}

/// Cast Utf8 to decimal
pub(crate) fn cast_string_to_decimal<T, Offset: OffsetSizeTrait>(
from: &dyn Array,
Expand All @@ -399,14 +431,30 @@ where
)));
}

Ok(Arc::new(string_to_decimal_cast::<T, Offset>(
from.as_any()
.downcast_ref::<GenericStringArray<Offset>>()
.unwrap(),
precision,
scale,
cast_options,
)?))
let result = match from.data_type() {
DataType::Utf8View => string_view_to_decimal_cast::<T>(
from.as_any().downcast_ref::<StringViewArray>().unwrap(),
precision,
scale,
cast_options,
)?,
DataType::Utf8 | DataType::LargeUtf8 => string_to_decimal_cast::<T, Offset>(
from.as_any()
.downcast_ref::<GenericStringArray<Offset>>()
.unwrap(),
precision,
scale,
cast_options,
)?,
other => {
return Err(ArrowError::ComputeError(format!(
"Cannot cast {:?} to decimal",
other
)))
}
};

Ok(Arc::new(result))
}

pub(crate) fn cast_floating_point_to_decimal128<T: ArrowPrimitiveType>(
Expand Down
48 changes: 41 additions & 7 deletions arrow-cast/src/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Decimal128(_, _) | Decimal256(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64) => true,
// decimal to Utf8
(Decimal128(_, _) | Decimal256(_, _), Utf8 | LargeUtf8) => true,
// Utf8 to decimal
(Utf8 | LargeUtf8, Decimal128(_, _) | Decimal256(_, _)) => true,
// string to decimal
(Utf8View | Utf8 | LargeUtf8, Decimal128(_, _) | Decimal256(_, _)) => true,
(Struct(from_fields), Struct(to_fields)) => {
from_fields.len() == to_fields.len() &&
from_fields.iter().zip(to_fields.iter()).all(|(f1, f2)| {
Expand Down Expand Up @@ -230,7 +230,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
) => true,
(Utf8 | LargeUtf8, Utf8View) => true,
(BinaryView, Binary | LargeBinary | Utf8 | LargeUtf8 | Utf8View ) => true,
(Utf8 | LargeUtf8, _) => to_type.is_numeric() && to_type != &Float16,
(Utf8View | Utf8 | LargeUtf8, _) => to_type.is_numeric() && to_type != &Float16,
(_, Utf8 | LargeUtf8) => from_type.is_primitive(),

(_, Binary | LargeBinary) => from_type.is_integer(),
Expand Down Expand Up @@ -1061,7 +1061,7 @@ pub fn cast_with_options(
*scale,
cast_options,
),
Utf8 => cast_string_to_decimal::<Decimal128Type, i32>(
Utf8View | Utf8 => cast_string_to_decimal::<Decimal128Type, i32>(
array,
*precision,
*scale,
Expand Down Expand Up @@ -1150,7 +1150,7 @@ pub fn cast_with_options(
*scale,
cast_options,
),
Utf8 => cast_string_to_decimal::<Decimal256Type, i32>(
Utf8View | Utf8 => cast_string_to_decimal::<Decimal256Type, i32>(
array,
*precision,
*scale,
Expand Down Expand Up @@ -2485,12 +2485,11 @@ where

#[cfg(test)]
mod tests {
use super::*;
use arrow_buffer::{Buffer, IntervalDayTime, NullBuffer};
use chrono::NaiveDate;
use half::f16;

use super::*;

macro_rules! generate_cast_test_case {
($INPUT_ARRAY: expr, $OUTPUT_TYPE_ARRAY: ident, $OUTPUT_TYPE: expr, $OUTPUT_VALUES: expr) => {
let output =
Expand Down Expand Up @@ -3720,6 +3719,41 @@ mod tests {
assert!(!c.is_valid(4));
}

#[test]
fn test_cast_utf8view_to_i32() {
let array = StringViewArray::from(vec!["5", "6", "seven", "8", "9.1"]);
let b = cast(&array, &DataType::Int32).unwrap();
let c = b.as_primitive::<Int32Type>();
assert_eq!(5, c.value(0));
assert_eq!(6, c.value(1));
assert!(!c.is_valid(2));
assert_eq!(8, c.value(3));
assert!(!c.is_valid(4));
}

#[test]
fn test_cast_utf8view_to_f32() {
let array = StringViewArray::from(vec!["3", "4.56", "seven", "8.9"]);
let b = cast(&array, &DataType::Float32).unwrap();
let c = b.as_primitive::<Float32Type>();
assert_eq!(3.0, c.value(0));
assert_eq!(4.56, c.value(1));
assert!(!c.is_valid(2));
assert_eq!(8.9, c.value(3));
}

#[test]
fn test_cast_utf8view_to_decimal128() {
let array = StringViewArray::from(vec![None, Some("4"), Some("5.6"), Some("7.89")]);
let arr = Arc::new(array) as ArrayRef;
generate_cast_test_case!(
&arr,
Decimal128Array,
&DataType::Decimal128(4, 2),
vec![None, Some(400_i128), Some(560_i128), Some(789_i128)]
);
}

#[test]
fn test_cast_with_options_utf8_to_i32() {
let array = StringArray::from(vec!["5", "6", "seven", "8", "9.1"]);
Expand Down

0 comments on commit 8b33f96

Please sign in to comment.