Skip to content

Commit

Permalink
Add C data interface for decimal128
Browse files Browse the repository at this point in the history
  • Loading branch information
alippai committed Jun 13, 2021
1 parent e21f576 commit 6d9ea41
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 15 deletions.
24 changes: 19 additions & 5 deletions arrow-pyarrow-integration-testing/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.

import unittest
from decimal import Decimal

import pyarrow
import arrow_pyarrow_integration_testing
Expand Down Expand Up @@ -69,9 +70,25 @@ def test_time32_python(self):
Python -> Rust -> Python
"""
old_allocated = pyarrow.total_allocated_bytes()
a = pyarrow.array([None, 1, 2], pyarrow.time32('s'))
a = pyarrow.array([None, 1, 2], pyarrow.time32("s"))
b = arrow_pyarrow_integration_testing.concatenate(a)
expected = pyarrow.array([None, 1, 2] + [None, 1, 2], pyarrow.time32('s'))
expected = pyarrow.array([None, 1, 2] + [None, 1, 2], pyarrow.time32("s"))
self.assertEqual(b, expected)
del a
del b
del expected
# No leak of C++ memory
self.assertEqual(old_allocated, pyarrow.total_allocated_bytes())

def test_decimal_python(self):
"""
Python -> Rust -> Python
"""
old_allocated = pyarrow.total_allocated_bytes()
py_array = [None, Decimal(123.45), Decimal(-123.45)]
a = pyarrow.array(py_array, pyarrow.decimal128(6, 2))
b = arrow_pyarrow_integration_testing.concatenate(a)
expected = pyarrow.array(py_array + py_array, pyarrow.decimal128(6, 2))
self.assertEqual(b, expected)
del a
del b
Expand All @@ -94,6 +111,3 @@ def test_list_array(self):
del b
# No leak of C++ memory
self.assertEqual(old_allocated, pyarrow.total_allocated_bytes())



93 changes: 83 additions & 10 deletions arrow/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,12 +271,54 @@ fn to_field(schema: &FFI_ArrowSchema) -> Result<Field> {
.collect::<Result<Vec<_>>>()?;
DataType::Struct(children)
}
other => {
return Err(ArrowError::CDataInterface(format!(
"The datatype \"{:?}\" is still not supported in Rust implementation",
other
)))
}
other => match other
.split(|c| c == ':' || c == ',')
.collect::<Vec<&str>>()
.as_slice()
{
["d", precision, scale] => {
let parsed_precision = precision.parse::<usize>().map_err(|_| {
ArrowError::CDataInterface(format!(
"The decimal \"{:?}\" is not supported in Rust implementation",
other
))
})?;
let parsed_scale = scale.parse::<usize>().map_err(|_| {
ArrowError::CDataInterface(format!(
"The decimal \"{:?}\" is not supported in Rust implementation",
other
))
})?;
DataType::Decimal(parsed_precision, parsed_scale)
}
["d", precision, scale, bits] => {
if *bits != "128" {
return Err(ArrowError::CDataInterface(format!(
"The decimal \"{:?}\" is still not supported in Rust implementation",
other
)));
}
let parsed_precision = precision.parse::<usize>().map_err(|_| {
ArrowError::CDataInterface(format!(
"The decimal \"{:?}\" is not supported in Rust implementation",
other
))
})?;
let parsed_scale = scale.parse::<usize>().map_err(|_| {
ArrowError::CDataInterface(format!(
"The decimal \"{:?}\" is not supported in Rust implementation",
other
))
})?;
DataType::Decimal(parsed_precision, parsed_scale)
}
_ => {
return Err(ArrowError::CDataInterface(format!(
"The datatype \"{:?}\" is still not supported in Rust implementation",
other
)))
}
},
};
Ok(Field::new(schema.name(), data_type, schema.nullable()))
}
Expand All @@ -301,6 +343,9 @@ fn to_format(data_type: &DataType) -> Result<String> {
DataType::LargeBinary => "Z",
DataType::Utf8 => "u",
DataType::LargeUtf8 => "U",
DataType::Decimal(precision, scale) => {
return Ok(format!("d:{},{}", precision, scale))
}
DataType::Date32 => "tdD",
DataType::Date64 => "tdm",
DataType::Time32(TimeUnit::Second) => "tts",
Expand Down Expand Up @@ -338,6 +383,7 @@ fn bit_width(data_type: &DataType, i: usize) -> Result<usize> {
(DataType::Int64, 1) | (DataType::Date64, 1) | (DataType::Time64(_), 1) => size_of::<i64>() * 8,
(DataType::Float32, 1) => size_of::<f32>() * 8,
(DataType::Float64, 1) => size_of::<f64>() * 8,
(DataType::Decimal(..), 1) => size_of::<i128>() * 8,
// primitive types have a single buffer
(DataType::Boolean, _) |
(DataType::UInt8, _) |
Expand All @@ -349,7 +395,8 @@ fn bit_width(data_type: &DataType, i: usize) -> Result<usize> {
(DataType::Int32, _) | (DataType::Date32, _) | (DataType::Time32(_), _) |
(DataType::Int64, _) | (DataType::Date64, _) | (DataType::Time64(_), _) |
(DataType::Float32, _) |
(DataType::Float64, _) => {
(DataType::Float64, _) |
(DataType::Decimal(..), _) => {
return Err(ArrowError::CDataInterface(format!(
"The datatype \"{:?}\" expects 2 buffers, but requested {}. Please verify that the C data interface is correctly implemented.",
data_type, i
Expand Down Expand Up @@ -829,9 +876,9 @@ impl<'a> ArrowArrayChild<'a> {
mod tests {
use super::*;
use crate::array::{
make_array, Array, ArrayData, BinaryOffsetSizeTrait, BooleanArray,
GenericBinaryArray, GenericListArray, GenericStringArray, Int32Array,
OffsetSizeTrait, StringOffsetSizeTrait, Time32MillisecondArray,
make_array, Array, ArrayData, BinaryOffsetSizeTrait, BooleanArray, DecimalArray,
DecimalBuilder, GenericBinaryArray, GenericListArray, GenericStringArray,
Int32Array, OffsetSizeTrait, StringOffsetSizeTrait, Time32MillisecondArray,
};
use crate::compute::kernels;
use crate::datatypes::Field;
Expand Down Expand Up @@ -860,6 +907,32 @@ mod tests {
// (drop/release)
Ok(())
}

#[test]
fn test_decimal_round_trip() -> Result<()> {
    // Build the source array natively: two decimal values followed by a null.
    // NOTE(review): the builder arguments are presumably (capacity, precision, scale) — confirm.
    let mut decimal_builder = DecimalBuilder::new(5, 6, 2);
    for value in &[12345_i128, -12345_i128] {
        decimal_builder.append_value(*value).unwrap();
    }
    decimal_builder.append_null().unwrap();
    let original_array = decimal_builder.finish();

    // Producer side: export a clone of the array data.
    let exported = ArrowArray::try_from(original_array.data().clone())?;

    // Consumer side (simulated): import the data back and rebuild an array from it.
    let imported_data = ArrayData::try_from(exported)?;
    let imported = make_array(imported_data);

    // Downcast to the concrete decimal array type so it can be compared directly.
    let round_tripped = imported.as_any().downcast_ref::<DecimalArray>().unwrap();

    // The imported array must equal the array we started from.
    assert_eq!(round_tripped, &original_array);

    // (drop/release) happens when the locals go out of scope.
    Ok(())
}
// case with nulls is tested in the docs, through the example on this module.

fn test_generic_string<Offset: StringOffsetSizeTrait>() -> Result<()> {
Expand Down

0 comments on commit 6d9ea41

Please sign in to comment.