Skip to content

Commit

Permalink
Support casting from decimal256 to float (#3267)
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya authored Dec 5, 2022
1 parent 1640fd1 commit 06e1111
Showing 1 changed file with 128 additions and 26 deletions.
154 changes: 128 additions & 26 deletions arrow-cast/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Decimal256(_, _), UInt8 | UInt16 | UInt32 | UInt64) |
// decimal to signed numeric
(Decimal128(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64) |
(Decimal256(_, _), Null | Int8 | Int16 | Int32 | Int64) => true,
(Decimal256(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64) => true,
(Decimal128(_, _), _) => false,
(_, Decimal128(_, _)) => false,
(Decimal256(_, _), _) => false,
Expand Down Expand Up @@ -496,23 +496,16 @@ where
}

// cast the decimal array to floating-point array
macro_rules! cast_decimal_to_float {
($ARRAY:expr, $SCALE : ident, $VALUE_BUILDER: ident, $NATIVE_TYPE : ty) => {{
let array = $ARRAY.as_any().downcast_ref::<Decimal128Array>().unwrap();
let div = 10_f64.powi(*$SCALE as i32);
let mut value_builder = $VALUE_BUILDER::with_capacity(array.len());
for i in 0..array.len() {
if array.is_null(i) {
value_builder.append_null();
} else {
// The range of f32 or f64 is larger than i128, we don't need to check overflow.
// cast the i128 to f64 will lose precision, for example the `112345678901234568` will be as `112345678901234560`.
let v = (array.value(i) as f64 / div) as $NATIVE_TYPE;
value_builder.append_value(v);
}
}
Ok(Arc::new(value_builder.finish()))
}};
fn cast_decimal_to_float<D: DecimalType, T: ArrowPrimitiveType, F>(
array: &ArrayRef,
op: F,
) -> Result<ArrayRef, ArrowError>
where
F: Fn(D::Native) -> T::Native,
{
let array = array.as_any().downcast_ref::<PrimitiveArray<D>>().unwrap();
let array = array.unary::<_, T>(op);
Ok(Arc::new(array))
}

// cast the List array to Utf8 array
Expand Down Expand Up @@ -796,10 +789,14 @@ pub fn cast_with_options(
cast_options,
),
Float32 => {
cast_decimal_to_float!(array, scale, Float32Builder, f32)
cast_decimal_to_float::<Decimal128Type, Float32Type, _>(array, |x| {
(x as f64 / 10_f64.powi(*scale as i32)) as f32
})
}
Float64 => {
cast_decimal_to_float!(array, scale, Float64Builder, f64)
cast_decimal_to_float::<Decimal128Type, Float64Type, _>(array, |x| {
(x as f64 / 10_f64.powi(*scale as i32)) as f64
})
}
Null => Ok(new_null_array(to_type, array.len())),
_ => Err(ArrowError::CastError(format!(
Expand Down Expand Up @@ -859,6 +856,16 @@ pub fn cast_with_options(
*scale,
cast_options,
),
Float32 => {
cast_decimal_to_float::<Decimal256Type, Float32Type, _>(array, |x| {
(x.to_f64().unwrap() / 10_f64.powi(*scale as i32)) as f32
})
}
Float64 => {
cast_decimal_to_float::<Decimal256Type, Float64Type, _>(array, |x| {
(x.to_f64().unwrap() / 10_f64.powi(*scale as i32)) as f64
})
}
Null => Ok(new_null_array(to_type, array.len())),
_ => Err(ArrowError::CastError(format!(
"Casting from {:?} to {:?} not supported",
Expand Down Expand Up @@ -3735,16 +3742,28 @@ mod tests {
// f32
generate_cast_test_case!(
&array,
Int64Array,
&DataType::Int64,
vec![Some(1_i64), Some(2_i64), Some(3_i64), None, Some(5_i64)]
Float32Array,
&DataType::Float32,
vec![
Some(1.25_f32),
Some(2.25_f32),
Some(3.25_f32),
None,
Some(5.25_f32)
]
);
// f64
generate_cast_test_case!(
&array,
Int64Array,
&DataType::Int64,
vec![Some(1_i64), Some(2_i64), Some(3_i64), None, Some(5_i64)]
Float64Array,
&DataType::Float64,
vec![
Some(1.25_f64),
Some(2.25_f64),
Some(3.25_f64),
None,
Some(5.25_f64)
]
);

// overflow test: out of range of max u8
Expand Down Expand Up @@ -3904,6 +3923,32 @@ mod tests {
&DataType::Int64,
vec![Some(1_i64), Some(2_i64), Some(3_i64), None, Some(5_i64)]
);
// f32
generate_cast_test_case!(
&array,
Float32Array,
&DataType::Float32,
vec![
Some(1.25_f32),
Some(2.25_f32),
Some(3.25_f32),
None,
Some(5.25_f32)
]
);
// f64
generate_cast_test_case!(
&array,
Float64Array,
&DataType::Float64,
vec![
Some(1.25_f64),
Some(2.25_f64),
Some(3.25_f64),
None,
Some(5.25_f64)
]
);

// overflow test: out of range of max i8
let value_array: Vec<Option<i256>> = vec![Some(i256::from_i128(24400))];
Expand All @@ -3920,6 +3965,63 @@ mod tests {
cast_with_options(&array, &DataType::Int8, &CastOptions { safe: true });
assert!(casted_array.is_ok());
assert!(casted_array.unwrap().is_null(0));

// loss the precision: convert decimal to f32、f64
// f32
// 112345678_f32 and 112345679_f32 are same, so the 112345679_f32 will lose precision.
let value_array: Vec<Option<i256>> = vec![
Some(i256::from_i128(125)),
Some(i256::from_i128(225)),
Some(i256::from_i128(325)),
None,
Some(i256::from_i128(525)),
Some(i256::from_i128(112345678)),
Some(i256::from_i128(112345679)),
];
let decimal_array = create_decimal256_array(value_array, 76, 2).unwrap();
let array = Arc::new(decimal_array) as ArrayRef;
generate_cast_test_case!(
&array,
Float32Array,
&DataType::Float32,
vec![
Some(1.25_f32),
Some(2.25_f32),
Some(3.25_f32),
None,
Some(5.25_f32),
Some(1_123_456.7_f32),
Some(1_123_456.7_f32)
]
);

// f64
// 112345678901234568_f64 and 112345678901234560_f64 are same, so the 112345678901234568_f64 will lose precision.
let value_array: Vec<Option<i256>> = vec![
Some(i256::from_i128(125)),
Some(i256::from_i128(225)),
Some(i256::from_i128(325)),
None,
Some(i256::from_i128(525)),
Some(i256::from_i128(112345678901234568)),
Some(i256::from_i128(112345678901234560)),
];
let decimal_array = create_decimal256_array(value_array, 76, 2).unwrap();
let array = Arc::new(decimal_array) as ArrayRef;
generate_cast_test_case!(
&array,
Float64Array,
&DataType::Float64,
vec![
Some(1.25_f64),
Some(2.25_f64),
Some(3.25_f64),
None,
Some(5.25_f64),
Some(1_123_456_789_012_345.6_f64),
Some(1_123_456_789_012_345.6_f64),
]
);
}

#[test]
Expand Down

0 comments on commit 06e1111

Please sign in to comment.