diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs index 7821c22279cc..b46d7c642362 100644 --- a/arrow/src/compute/kernels/cast.rs +++ b/arrow/src/compute/kernels/cast.rs @@ -69,10 +69,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { match (from_type, to_type) { // TODO now just support signed numeric to decimal, support decimal to numeric later - // support one decimal data type to another decimal data type - // or UTF-8 to decimal - // numeric to decimal - (Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal(_, _), Decimal(_, _)) + (Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal(_, _)) | ( Null, Boolean @@ -245,6 +242,45 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result { cast_with_options(array, to_type, &DEFAULT_CAST_OPTIONS) } +// cast the integer array to defined decimal data type array +macro_rules! cast_integer_to_decimal { + ($ARRAY: expr, $ARRAY_TYPE: ident, $PRECISION : ident, $SCALE : ident) => {{ + let mut decimal_builder = DecimalBuilder::new($ARRAY.len(), *$PRECISION, *$SCALE); + let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); + let mul: i128 = 10_i128.pow(*$SCALE as u32); + for i in 0..array.len() { + if array.is_null(i) { + decimal_builder.append_null()?; + } else { + // convert i128 first + let v = array.value(i) as i128; + // if the input value is overflow, it will throw an error. + decimal_builder.append_value(mul * v)?; + } + } + Ok(Arc::new(decimal_builder.finish())) + }}; +} + +// cast the floating-point array to defined decimal data type array +macro_rules! cast_floating_point_to_decimal { + ($ARRAY: expr, $ARRAY_TYPE: ident, $PRECISION : ident, $SCALE : ident) => {{ + let mut decimal_builder = DecimalBuilder::new($ARRAY.len(), *$PRECISION, *$SCALE); + let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); + let mul = 10_f64.powi(*$SCALE as i32); + for i in 0..array.len() { + if array.is_null(i) { + decimal_builder.append_null()?; + } else { + let v = ((array.value(i) as f64) * mul) as i128; + // if the input value is overflow, it will throw an error. + decimal_builder.append_value(v)?; + } + } + Ok(Arc::new(decimal_builder.finish())) + }}; +} + /// Cast `array` to the provided data type and return a new Array with /// type `to_type`, if possible. It accepts `CastOptions` to allow consumers /// to configure cast behavior. @@ -279,6 +315,34 @@ pub fn cast_with_options( return Ok(array.clone()); } match (from_type, to_type) { + (_, Decimal(precision, scale)) => { + // cast data to decimal + match from_type { + // TODO now just support signed numeric to decimal, support decimal to numeric later + Int8 => { + cast_integer_to_decimal!(array, Int8Array, precision, scale) + } + Int16 => { + cast_integer_to_decimal!(array, Int16Array, precision, scale) + } + Int32 => { + cast_integer_to_decimal!(array, Int32Array, precision, scale) + } + Int64 => { + cast_integer_to_decimal!(array, Int64Array, precision, scale) + } + Float32 => { + cast_floating_point_to_decimal!(array, Float32Array, precision, scale) + } + Float64 => { + cast_floating_point_to_decimal!(array, Float64Array, precision, scale) + } + _ => Err(ArrowError::CastError(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type + ))), + } + } ( Null, Boolean @@ -1257,7 +1321,7 @@ fn cast_string_to_date64( if string_array.is_null(i) { Ok(None) } else { - let string = string_array + let string = string_array .value(i); let result = string @@ -1476,7 +1540,7 @@ fn dictionary_cast( return Err(ArrowError::CastError(format!( "Unsupported type {:?} for dictionary index", to_index_type - ))) + ))); } }; @@ -1842,6 +1906,125 @@ where mod tests { use super::*; use crate::{buffer::Buffer, util::display::array_value_to_string}; + use num::traits::Pow; + + #[test] + fn test_cast_numeric_to_decimal() { + let data_types = vec![ + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + ]; + let decimal_type = DataType::Decimal(38, 6); + for data_type in data_types { + assert!(can_cast_types(&data_type, &decimal_type)) + } + assert!(!can_cast_types(&DataType::UInt64, &decimal_type)); + + // test i8 to decimal type + let array = Int8Array::from(vec![1, 2, 3, 4, 5]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast(&array, &decimal_type).unwrap(); + let decimal_array = casted_array + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(&decimal_type, decimal_array.data_type()); + for i in 0..array.len() { + assert_eq!( + 10_i128.pow(6) * (i as i128 + 1), + decimal_array.value(i as usize) + ); + } + // test i8 to decimal type with overflow the result type + // the 100 will be converted to 1000_i128, but it is out of range for max value in the precision 3. + let array = Int8Array::from(vec![1, 2, 3, 4, 100]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast(&array, &DataType::Decimal(3, 1)); + assert!(casted_array.is_err()); + assert_eq!("Invalid argument error: The value of 1000 i128 is not compatible with Decimal(3,1)", casted_array.unwrap_err().to_string()); + + // test i16 to decimal type + let array = Int16Array::from(vec![1, 2, 3, 4, 5]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast(&array, &decimal_type).unwrap(); + let decimal_array = casted_array + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(&decimal_type, decimal_array.data_type()); + for i in 0..array.len() { + assert_eq!( + 10_i128.pow(6) * (i as i128 + 1), + decimal_array.value(i as usize) + ); + } + + // test i32 to decimal type + let array = Int32Array::from(vec![1, 2, 3, 4, 5]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast(&array, &decimal_type).unwrap(); + let decimal_array = casted_array + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(&decimal_type, decimal_array.data_type()); + for i in 0..array.len() { + assert_eq!( + 10_i128.pow(6) * (i as i128 + 1), + decimal_array.value(i as usize) + ); + } + + // test i64 to decimal type + let array = Int64Array::from(vec![1, 2, 3, 4, 5]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast(&array, &decimal_type).unwrap(); + let decimal_array = casted_array + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(&decimal_type, decimal_array.data_type()); + for i in 0..array.len() { + assert_eq!( + 10_i128.pow(6) * (i as i128 + 1), + decimal_array.value(i as usize) + ); + } + + // test f32 to decimal type + let f_data: Vec = vec![1.1, 2.2, 4.4, 1.1234567891234]; + let array = Float32Array::from(f_data.clone()); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast(&array, &decimal_type).unwrap(); + let decimal_array = casted_array + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(&decimal_type, decimal_array.data_type()); + for i in 0..array.len() { + let left = (f_data[i] as f64) * 10_f64.pow(6); + assert_eq!(left as i128, decimal_array.value(i as usize)); + } + + // test f64 to decimal type + let f_data: Vec = vec![1.1, 2.2, 4.4, 1.1234567891234]; + let array = Float64Array::from(f_data.clone()); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast(&array, &decimal_type).unwrap(); + let decimal_array = casted_array + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(&decimal_type, decimal_array.data_type()); + for i in 0..array.len() { + let left = (f_data[i] as f64) * 10_f64.pow(6); + assert_eq!(left as i128, decimal_array.value(i as usize)); + } + } #[test] fn test_cast_i32_to_f64() {