From 3d913023c0abd61343210ab398a6deaf9ba4de72 Mon Sep 17 00:00:00 2001 From: ryan-jacobs1 <41557891+ryan-jacobs1@users.noreply.github.com> Date: Wed, 11 May 2022 22:19:00 -0400 Subject: [PATCH] support duration in ffi (#1689) --- .../tests/test_sql.py | 5 +- arrow/src/datatypes/ffi.rs | 8 ++++ arrow/src/ffi.rs | 46 +++++++++++++++++-- 3 files changed, 54 insertions(+), 5 deletions(-) diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py index 9d5b93679b6b..324956c9c6a6 100644 --- a/arrow-pyarrow-integration-testing/tests/test_sql.py +++ b/arrow-pyarrow-integration-testing/tests/test_sql.py @@ -55,6 +55,10 @@ def assert_pyarrow_leak(): pa.timestamp("us"), pa.timestamp("us", tz="UTC"), pa.timestamp("us", tz="Europe/Paris"), + pa.duration("s"), + pa.duration("ms"), + pa.duration("us"), + pa.duration("ns"), pa.float16(), pa.float32(), pa.float64(), @@ -86,7 +90,6 @@ def assert_pyarrow_leak(): _unsupported_pyarrow_types = [ pa.decimal256(76, 38), - pa.duration("s"), pa.map_(pa.string(), pa.int32()), pa.union( [pa.field("a", pa.binary(10)), pa.field("b", pa.string())], diff --git a/arrow/src/datatypes/ffi.rs b/arrow/src/datatypes/ffi.rs index bc274e2dc3b8..2f1b092a862b 100644 --- a/arrow/src/datatypes/ffi.rs +++ b/arrow/src/datatypes/ffi.rs @@ -52,6 +52,10 @@ impl TryFrom<&FFI_ArrowSchema> for DataType { "ttm" => DataType::Time32(TimeUnit::Millisecond), "ttu" => DataType::Time64(TimeUnit::Microsecond), "ttn" => DataType::Time64(TimeUnit::Nanosecond), + "tDs" => DataType::Duration(TimeUnit::Second), + "tDm" => DataType::Duration(TimeUnit::Millisecond), + "tDu" => DataType::Duration(TimeUnit::Microsecond), + "tDn" => DataType::Duration(TimeUnit::Nanosecond), "+l" => { let c_child = c_schema.child(0); DataType::List(Box::new(Field::try_from(c_child)?)) @@ -251,6 +255,10 @@ fn get_format_string(dtype: &DataType) -> Result<String> { DataType::Timestamp(TimeUnit::Millisecond, Some(tz)) => Ok(format!("tsm:{}", tz)), 
DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => Ok(format!("tsu:{}", tz)), DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) => Ok(format!("tsn:{}", tz)), + DataType::Duration(TimeUnit::Second) => Ok("tDs".to_string()), + DataType::Duration(TimeUnit::Millisecond) => Ok("tDm".to_string()), + DataType::Duration(TimeUnit::Microsecond) => Ok("tDu".to_string()), + DataType::Duration(TimeUnit::Nanosecond) => Ok("tDn".to_string()), DataType::List(_) => Ok("+l".to_string()), DataType::LargeList(_) => Ok("+L".to_string()), DataType::Struct(_) => Ok("+s".to_string()), diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs index 1fd706c36794..b1fa9f7bbd2a 100644 --- a/arrow/src/ffi.rs +++ b/arrow/src/ffi.rs @@ -314,6 +314,7 @@ fn bit_width(data_type: &DataType, i: usize) -> Result<usize> { (DataType::Float64, 1) => size_of::<f64>() * 8, (DataType::Decimal(..), 1) => size_of::<i128>() * 8, (DataType::Timestamp(..), 1) => size_of::<i64>() * 8, + (DataType::Duration(..), 1) => size_of::<i64>() * 8, // primitive types have a single buffer (DataType::Boolean, _) | (DataType::UInt8, _) | @@ -327,7 +328,8 @@ fn bit_width(data_type: &DataType, i: usize) -> Result<usize> { (DataType::Float32, _) | (DataType::Float64, _) | (DataType::Decimal(..), _) | - (DataType::Timestamp(..), _) => { + (DataType::Timestamp(..), _) | + (DataType::Duration(..), _) => { return Err(ArrowError::CDataInterface(format!( "The datatype \"{:?}\" expects 2 buffers, but requested {}. 
Please verify that the C data interface is correctly implemented.", data_type, i @@ -873,9 +875,9 @@ mod tests { use super::*; use crate::array::{ export_array_into_raw, make_array, Array, ArrayData, BooleanArray, DecimalArray, - DictionaryArray, FixedSizeBinaryArray, FixedSizeListArray, GenericBinaryArray, - GenericListArray, GenericStringArray, Int32Array, OffsetSizeTrait, - Time32MillisecondArray, TimestampMillisecondArray, + DictionaryArray, DurationSecondArray, FixedSizeBinaryArray, FixedSizeListArray, + GenericBinaryArray, GenericListArray, GenericStringArray, Int32Array, + OffsetSizeTrait, Time32MillisecondArray, TimestampMillisecondArray, }; use crate::compute::kernels; use crate::datatypes::{Field, Int8Type}; @@ -1358,4 +1360,40 @@ mod tests { } Ok(()) } + + #[test] + fn test_duration() -> Result<()> { + // create an array natively + let array = DurationSecondArray::from(vec![None, Some(1), Some(2)]); + + // export it + let array = ArrowArray::try_from(array.data().clone())?; + + // (simulate consumer) import it + let data = ArrayData::try_from(array)?; + let array = make_array(data); + + // perform some operation + let array = kernels::concat::concat(&[array.as_ref(), array.as_ref()]).unwrap(); + let array = array + .as_any() + .downcast_ref::<DurationSecondArray>() + .unwrap(); + + // verify + assert_eq!( + array, + &DurationSecondArray::from(vec![ + None, + Some(1), + Some(2), + None, + Some(1), + Some(2) + ]) + ); + + // (drop/release) + Ok(()) + } }