diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 34e0449c4b7a..559c7c8a3961 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -325,7 +325,7 @@ jobs: python -m venv venv source venv/bin/activate - pip install maturin==0.8.2 toml==0.10.1 pyarrow==1.0.0 + pip install maturin==0.8.2 toml==0.10.1 pyarrow==1.0.0 pytz maturin develop python -m unittest discover tests diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py index c0de382057c1..5524c54ec178 100644 --- a/arrow-pyarrow-integration-testing/tests/test_sql.py +++ b/arrow-pyarrow-integration-testing/tests/test_sql.py @@ -17,9 +17,12 @@ # under the License. import unittest +from datetime import date, datetime +from decimal import Decimal -import pyarrow import arrow_pyarrow_integration_testing +import pyarrow +from pytz import timezone class TestCase(unittest.TestCase): @@ -69,9 +72,45 @@ def test_time32_python(self): Python -> Rust -> Python """ old_allocated = pyarrow.total_allocated_bytes() - a = pyarrow.array([None, 1, 2], pyarrow.time32('s')) + a = pyarrow.array([None, 1, 2], pyarrow.time32("s")) + b = arrow_pyarrow_integration_testing.concatenate(a) + expected = pyarrow.array([None, 1, 2] + [None, 1, 2], pyarrow.time32("s")) + self.assertEqual(b, expected) + del a + del b + del expected + # No leak of C++ memory + self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) + + def test_date32_python(self): + """ + Python -> Rust -> Python + """ + old_allocated = pyarrow.total_allocated_bytes() + py_array = [None, date(1990, 3, 9), date(2021, 6, 20)] + a = pyarrow.array(py_array, pyarrow.date32()) + b = arrow_pyarrow_integration_testing.concatenate(a) + expected = pyarrow.array(py_array + py_array, pyarrow.date32()) + self.assertEqual(b, expected) + del a + del b + del expected + # No leak of C++ memory + self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) + + def 
test_timestamp_python(self): + """ + Python -> Rust -> Python + """ + old_allocated = pyarrow.total_allocated_bytes() + py_array = [ + None, + datetime(2021, 1, 1, 1, 1, 1, 1), + datetime(2020, 3, 9, 1, 1, 1, 1), + ] + a = pyarrow.array(py_array, pyarrow.timestamp("us")) b = arrow_pyarrow_integration_testing.concatenate(a) - expected = pyarrow.array([None, 1, 2] + [None, 1, 2], pyarrow.time32('s')) + expected = pyarrow.array(py_array + py_array, pyarrow.timestamp("us")) self.assertEqual(b, expected) del a del b @@ -79,6 +118,42 @@ def test_time32_python(self): # No leak of C++ memory self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) + def test_timestamp_tz_python(self): + """ + Python -> Rust -> Python + """ + old_allocated = pyarrow.total_allocated_bytes() + py_array = [ + None, + datetime(2021, 1, 1, 1, 1, 1, 1, tzinfo=timezone("America/New_York")), + datetime(2020, 3, 9, 1, 1, 1, 1, tzinfo=timezone("America/New_York")), + ] + a = pyarrow.array(py_array, pyarrow.timestamp("us", tz="America/New_York")) + b = arrow_pyarrow_integration_testing.concatenate(a) + expected = pyarrow.array( + py_array + py_array, pyarrow.timestamp("us", tz="America/New_York") + ) + self.assertEqual(b, expected) + del a + del b + del expected + # No leak of C++ memory + self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) + + def test_decimal_python(self): + """ + Python -> Rust -> Python + """ + old_allocated = pyarrow.total_allocated_bytes() + py_array = [round(Decimal(123.45), 2), round(Decimal(-123.45), 2), None] + a = pyarrow.array(py_array, pyarrow.decimal128(6, 2)) + b = arrow_pyarrow_integration_testing.round_trip(a) + self.assertEqual(a, b) + del a + del b + # No leak of C++ memory + self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) + def test_list_array(self): """ Python -> Rust -> Python @@ -94,6 +169,3 @@ def test_list_array(self): del b # No leak of C++ memory self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) - - - diff 
--git a/arrow/src/ffi.rs b/arrow/src/ffi.rs index a16a096c00ba..b804dd2db74a 100644 --- a/arrow/src/ffi.rs +++ b/arrow/src/ffi.rs @@ -271,11 +271,75 @@ fn to_field(schema: &FFI_ArrowSchema) -> Result { .collect::>>()?; DataType::Struct(children) } + // Parametrized types, requiring string parse other => { - return Err(ArrowError::CDataInterface(format!( - "The datatype \"{:?}\" is still not supported in Rust implementation", - other - ))) + match other.splitn(2, ':').collect::>().as_slice() { + // Decimal types in format "d:precision,scale" or "d:precision,scale,bitWidth" + ["d", extra] => { + match extra.splitn(3, ',').collect::>().as_slice() { + [precision, scale] => { + let parsed_precision = precision.parse::().map_err(|_| { + ArrowError::CDataInterface( + "The decimal type requires an integer precision".to_string(), + ) + })?; + let parsed_scale = scale.parse::().map_err(|_| { + ArrowError::CDataInterface( + "The decimal type requires an integer scale".to_string(), + ) + })?; + DataType::Decimal(parsed_precision, parsed_scale) + }, + [precision, scale, bits] => { + if *bits != "128" { + return Err(ArrowError::CDataInterface("Only 128 bit wide decimal is supported in the Rust implementation".to_string())); + } + let parsed_precision = precision.parse::().map_err(|_| { + ArrowError::CDataInterface( + "The decimal type requires an integer precision".to_string(), + ) + })?; + let parsed_scale = scale.parse::().map_err(|_| { + ArrowError::CDataInterface( + "The decimal type requires an integer scale".to_string(), + ) + })?; + DataType::Decimal(parsed_precision, parsed_scale) + } + _ => { + return Err(ArrowError::CDataInterface(format!( + "The decimal pattern \"d:{:?}\" is not supported in the Rust implementation", + extra + ))) + } + } + } + + // Timestamps in format "tss:" and "tss:America/New_York" for no timezones and timezones resp.
+ ["tss", ""] => DataType::Timestamp(TimeUnit::Second, None), + ["tsm", ""] => DataType::Timestamp(TimeUnit::Millisecond, None), + ["tsu", ""] => DataType::Timestamp(TimeUnit::Microsecond, None), + ["tsn", ""] => DataType::Timestamp(TimeUnit::Nanosecond, None), + ["tss", tz] => { + DataType::Timestamp(TimeUnit::Second, Some(tz.to_string())) + } + ["tsm", tz] => { + DataType::Timestamp(TimeUnit::Millisecond, Some(tz.to_string())) + } + ["tsu", tz] => { + DataType::Timestamp(TimeUnit::Microsecond, Some(tz.to_string())) + } + ["tsn", tz] => { + DataType::Timestamp(TimeUnit::Nanosecond, Some(tz.to_string())) + } + + _ => { + return Err(ArrowError::CDataInterface(format!( + "The datatype \"{:?}\" is still not supported in Rust implementation", + other + ))) + } + } } }; Ok(Field::new(schema.name(), data_type, schema.nullable())) @@ -301,12 +365,31 @@ fn to_format(data_type: &DataType) -> Result { DataType::LargeBinary => "Z", DataType::Utf8 => "u", DataType::LargeUtf8 => "U", + DataType::Decimal(precision, scale) => { + return Ok(format!("d:{},{}", precision, scale)) + } DataType::Date32 => "tdD", DataType::Date64 => "tdm", DataType::Time32(TimeUnit::Second) => "tts", DataType::Time32(TimeUnit::Millisecond) => "ttm", DataType::Time64(TimeUnit::Microsecond) => "ttu", DataType::Time64(TimeUnit::Nanosecond) => "ttn", + DataType::Timestamp(TimeUnit::Second, None) => "tss:", + DataType::Timestamp(TimeUnit::Millisecond, None) => "tsm:", + DataType::Timestamp(TimeUnit::Microsecond, None) => "tsu:", + DataType::Timestamp(TimeUnit::Nanosecond, None) => "tsn:", + DataType::Timestamp(TimeUnit::Second, Some(tz)) => { + return Ok(format!("tss:{}", tz)) + } + DataType::Timestamp(TimeUnit::Millisecond, Some(tz)) => { + return Ok(format!("tsm:{}", tz)) + } + DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => { + return Ok(format!("tsu:{}", tz)) + } + DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) => { + return Ok(format!("tsn:{}", tz)) + } DataType::List(_) => "+l", 
DataType::LargeList(_) => "+L", DataType::Struct(_) => "+s", @@ -338,6 +421,8 @@ fn bit_width(data_type: &DataType, i: usize) -> Result { (DataType::Int64, 1) | (DataType::Date64, 1) | (DataType::Time64(_), 1) => size_of::() * 8, (DataType::Float32, 1) => size_of::() * 8, (DataType::Float64, 1) => size_of::() * 8, + (DataType::Decimal(..), 1) => size_of::() * 8, + (DataType::Timestamp(..), 1) => size_of::() * 8, // primitive types have a single buffer (DataType::Boolean, _) | (DataType::UInt8, _) | @@ -349,7 +434,9 @@ fn bit_width(data_type: &DataType, i: usize) -> Result { (DataType::Int32, _) | (DataType::Date32, _) | (DataType::Time32(_), _) | (DataType::Int64, _) | (DataType::Date64, _) | (DataType::Time64(_), _) | (DataType::Float32, _) | - (DataType::Float64, _) => { + (DataType::Float64, _) | + (DataType::Decimal(..), _) | + (DataType::Timestamp(..), _) => { return Err(ArrowError::CDataInterface(format!( "The datatype \"{:?}\" expects 2 buffers, but requested {}. Please verify that the C data interface is correctly implemented.", data_type, i @@ -829,9 +916,10 @@ impl<'a> ArrowArrayChild<'a> { mod tests { use super::*; use crate::array::{ - make_array, Array, ArrayData, BinaryOffsetSizeTrait, BooleanArray, - GenericBinaryArray, GenericListArray, GenericStringArray, Int32Array, - OffsetSizeTrait, StringOffsetSizeTrait, Time32MillisecondArray, + make_array, Array, ArrayData, BinaryOffsetSizeTrait, BooleanArray, DecimalArray, + DecimalBuilder, GenericBinaryArray, GenericListArray, GenericStringArray, + Int32Array, OffsetSizeTrait, StringOffsetSizeTrait, Time32MillisecondArray, + TimestampMillisecondArray, }; use crate::compute::kernels; use crate::datatypes::Field; @@ -859,6 +947,32 @@ mod tests { // (drop/release) Ok(()) } + + #[test] + fn test_decimal_round_trip() -> Result<()> { + // create an array natively + let mut builder = DecimalBuilder::new(5, 6, 2); + builder.append_value(12345_i128).unwrap(); + builder.append_value(-12345_i128).unwrap(); + 
builder.append_null().unwrap(); + let original_array = builder.finish(); + + // export it + let array = ArrowArray::try_from(original_array.data().clone())?; + + // (simulate consumer) import it + let data = ArrayData::try_from(array)?; + let array = make_array(data); + + // perform some operation + let array = array.as_any().downcast_ref::().unwrap(); + + // verify + assert_eq!(array, &original_array); + + // (drop/release) + Ok(()) + } // case with nulls is tested in the docs, through the example on this module. fn test_generic_string() -> Result<()> { @@ -1077,4 +1191,40 @@ mod tests { // (drop/release) Ok(()) } + + #[test] + fn test_timestamp() -> Result<()> { + // create an array natively + let array = TimestampMillisecondArray::from(vec![None, Some(1), Some(2)]); + + // export it + let array = ArrowArray::try_from(array.data().clone())?; + + // (simulate consumer) import it + let data = ArrayData::try_from(array)?; + let array = make_array(data); + + // perform some operation + let array = kernels::concat::concat(&[array.as_ref(), array.as_ref()]).unwrap(); + let array = array + .as_any() + .downcast_ref::() + .unwrap(); + + // verify + assert_eq!( + array, + &TimestampMillisecondArray::from(vec![ + None, + Some(1), + Some(2), + None, + Some(1), + Some(2) + ]) + ); + + // (drop/release) + Ok(()) + } }