From 6d9ea41724e6f46066e60db1bd4d940c5da742a2 Mon Sep 17 00:00:00 2001
From: Adam Lippai
Date: Sun, 13 Jun 2021 20:16:14 +0200
Subject: [PATCH] Add C data interface for decimal128

---
Review notes (text in this section is ignored by `git am`):
* test_decimal_python now builds its values with Decimal("123.45") instead of
  Decimal(123.45): constructing a Decimal from a float keeps the exact binary
  expansion of the double, which is not a scale-2 value for decimal128(6, 2).
  The blob id on the test_sql.py `index` line predates this change.
* Line breaks and generic type arguments (e.g. `parse::<usize>()`) were
  reconstructed from a whitespace-collapsed copy of this patch; hunk sizes
  match the hunk headers, but confirm context lines against the target tree.

 .../tests/test_sql.py                         | 24 ++++-
 arrow/src/ffi.rs                              | 93 +++++++++++++++++--
 2 files changed, 102 insertions(+), 15 deletions(-)

diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py
index c0de382057c1..653b2e68a7ed 100644
--- a/arrow-pyarrow-integration-testing/tests/test_sql.py
+++ b/arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -17,6 +17,7 @@
 # under the License.
 
 import unittest
+from decimal import Decimal
 
 import pyarrow
 import arrow_pyarrow_integration_testing
@@ -69,9 +70,25 @@ def test_time32_python(self):
         Python -> Rust -> Python
         """
         old_allocated = pyarrow.total_allocated_bytes()
-        a = pyarrow.array([None, 1, 2], pyarrow.time32('s'))
+        a = pyarrow.array([None, 1, 2], pyarrow.time32("s"))
         b = arrow_pyarrow_integration_testing.concatenate(a)
-        expected = pyarrow.array([None, 1, 2] + [None, 1, 2], pyarrow.time32('s'))
+        expected = pyarrow.array([None, 1, 2] + [None, 1, 2], pyarrow.time32("s"))
+        self.assertEqual(b, expected)
+        del a
+        del b
+        del expected
+        # No leak of C++ memory
+        self.assertEqual(old_allocated, pyarrow.total_allocated_bytes())
+
+    def test_decimal_python(self):
+        """
+        Python -> Rust -> Python
+        """
+        old_allocated = pyarrow.total_allocated_bytes()
+        py_array = [None, Decimal("123.45"), Decimal("-123.45")]
+        a = pyarrow.array(py_array, pyarrow.decimal128(6, 2))
+        b = arrow_pyarrow_integration_testing.concatenate(a)
+        expected = pyarrow.array(py_array + py_array, pyarrow.decimal128(6, 2))
         self.assertEqual(b, expected)
         del a
         del b
@@ -94,6 +111,3 @@ def test_list_array(self):
         del b
         # No leak of C++ memory
         self.assertEqual(old_allocated, pyarrow.total_allocated_bytes())
-
-
-
diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs
index 42dc4407abb6..1c6e10b23600 100644
--- a/arrow/src/ffi.rs
+++ b/arrow/src/ffi.rs
@@ -271,12 +271,54 @@ fn to_field(schema: &FFI_ArrowSchema) -> Result<Field> {
                 .collect::<Result<Vec<_>>>()?;
             DataType::Struct(children)
         }
-        other => {
-            return Err(ArrowError::CDataInterface(format!(
-                "The datatype \"{:?}\" is still not supported in Rust implementation",
-                other
-            )))
-        }
+        other => match other
+            .split(|c| c == ':' || c == ',')
+            .collect::<Vec<&str>>()
+            .as_slice()
+        {
+            ["d", precision, scale] => {
+                let parsed_precision = precision.parse::<usize>().map_err(|_| {
+                    ArrowError::CDataInterface(format!(
+                        "The decimal \"{:?}\" is not supported in Rust implementation",
+                        other
+                    ))
+                })?;
+                let parsed_scale = scale.parse::<usize>().map_err(|_| {
+                    ArrowError::CDataInterface(format!(
+                        "The decimal \"{:?}\" is not supported in Rust implementation",
+                        other
+                    ))
+                })?;
+                DataType::Decimal(parsed_precision, parsed_scale)
+            }
+            ["d", precision, scale, bits] => {
+                if *bits != "128" {
+                    return Err(ArrowError::CDataInterface(format!(
+                        "The decimal \"{:?}\" is still not supported in Rust implementation",
+                        other
+                    )));
+                }
+                let parsed_precision = precision.parse::<usize>().map_err(|_| {
+                    ArrowError::CDataInterface(format!(
+                        "The decimal \"{:?}\" is not supported in Rust implementation",
+                        other
+                    ))
+                })?;
+                let parsed_scale = scale.parse::<usize>().map_err(|_| {
+                    ArrowError::CDataInterface(format!(
+                        "The decimal \"{:?}\" is not supported in Rust implementation",
+                        other
+                    ))
+                })?;
+                DataType::Decimal(parsed_precision, parsed_scale)
+            }
+            _ => {
+                return Err(ArrowError::CDataInterface(format!(
+                    "The datatype \"{:?}\" is still not supported in Rust implementation",
+                    other
+                )))
+            }
+        },
     };
     Ok(Field::new(schema.name(), data_type, schema.nullable()))
 }
@@ -301,6 +343,9 @@ fn to_format(data_type: &DataType) -> Result<String> {
         DataType::LargeBinary => "Z",
         DataType::Utf8 => "u",
         DataType::LargeUtf8 => "U",
+        DataType::Decimal(precision, scale) => {
+            return Ok(format!("d:{},{}", precision, scale))
+        }
         DataType::Date32 => "tdD",
         DataType::Date64 => "tdm",
         DataType::Time32(TimeUnit::Second) => "tts",
@@ -338,6 +383,7 @@ fn bit_width(data_type: &DataType, i: usize) -> Result<usize> {
         (DataType::Int64, 1) | (DataType::Date64, 1) | (DataType::Time64(_), 1) => size_of::<i64>() * 8,
         (DataType::Float32, 1) => size_of::<f32>() * 8,
         (DataType::Float64, 1) => size_of::<f64>() * 8,
+        (DataType::Decimal(..), 1) => size_of::<i128>() * 8,
         // primitive types have a single buffer
         (DataType::Boolean, _) |
         (DataType::UInt8, _) |
@@ -349,7 +395,8 @@ fn bit_width(data_type: &DataType, i: usize) -> Result<usize> {
         (DataType::Int32, _) | (DataType::Date32, _) | (DataType::Time32(_), _) |
         (DataType::Int64, _) | (DataType::Date64, _) | (DataType::Time64(_), _) |
         (DataType::Float32, _) |
-        (DataType::Float64, _) => {
+        (DataType::Float64, _) |
+        (DataType::Decimal(..), _) => {
             return Err(ArrowError::CDataInterface(format!(
                 "The datatype \"{:?}\" expects 2 buffers, but requested {}. Please verify that the C data interface is correctly implemented.",
                 data_type, i
@@ -829,9 +876,9 @@ impl<'a> ArrowArrayChild<'a> {
 mod tests {
     use super::*;
     use crate::array::{
-        make_array, Array, ArrayData, BinaryOffsetSizeTrait, BooleanArray,
-        GenericBinaryArray, GenericListArray, GenericStringArray, Int32Array,
-        OffsetSizeTrait, StringOffsetSizeTrait, Time32MillisecondArray,
+        make_array, Array, ArrayData, BinaryOffsetSizeTrait, BooleanArray, DecimalArray,
+        DecimalBuilder, GenericBinaryArray, GenericListArray, GenericStringArray,
+        Int32Array, OffsetSizeTrait, StringOffsetSizeTrait, Time32MillisecondArray,
     };
     use crate::compute::kernels;
     use crate::datatypes::Field;
@@ -860,6 +907,32 @@ mod tests {
         // (drop/release)
         Ok(())
     }
+
+    #[test]
+    fn test_decimal_round_trip() -> Result<()> {
+        // create an array natively
+        let mut builder = DecimalBuilder::new(5, 6, 2);
+        builder.append_value(12345_i128).unwrap();
+        builder.append_value(-12345_i128).unwrap();
+        builder.append_null().unwrap();
+        let original_array = builder.finish();
+
+        // export it
+        let array = ArrowArray::try_from(original_array.data().clone())?;
+
+        // (simulate consumer) import it
+        let data = ArrayData::try_from(array)?;
+        let array = make_array(data);
+
+        // perform some operation
+        let array = array.as_any().downcast_ref::<DecimalArray>().unwrap();
+
+        // verify
+        assert_eq!(array, &original_array);
+
+        // (drop/release)
+        Ok(())
+    }
     // case with nulls is tested in the docs, through the example on this module.
 
     fn test_generic_string<Offset: StringOffsetSizeTrait>() -> Result<()> {