diff --git a/arrow-data/Cargo.toml b/arrow-data/Cargo.toml index ca50d8a12aee..a1938af4b194 100644 --- a/arrow-data/Cargo.toml +++ b/arrow-data/Cargo.toml @@ -42,6 +42,11 @@ bench = false # this is not enabled by default as it is too computationally expensive # but is run as part of our CI checks force_validate = [] +# Enable ffi support +ffi = ["arrow-schema/ffi"] + +[package.metadata.docs.rs] +features = ["ffi"] [dependencies] diff --git a/arrow-data/src/ffi.rs b/arrow-data/src/ffi.rs new file mode 100644 index 000000000000..e506653bb59b --- /dev/null +++ b/arrow-data/src/ffi.rs @@ -0,0 +1,285 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Contains declarations to bind to the [C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html). + +use crate::{layout, ArrayData}; +use arrow_buffer::Buffer; +use arrow_schema::DataType; +use std::ffi::c_void; + +/// ABI-compatible struct for ArrowArray from C Data Interface +/// See +/// +/// ``` +/// # use arrow_data::ArrayData; +/// # use arrow_data::ffi::FFI_ArrowArray; +/// fn export_array(array: &ArrayData) -> FFI_ArrowArray { +/// FFI_ArrowArray::new(array) +/// } +/// ``` +#[repr(C)] +#[derive(Debug)] +pub struct FFI_ArrowArray { + length: i64, + null_count: i64, + offset: i64, + n_buffers: i64, + n_children: i64, + buffers: *mut *const c_void, + children: *mut *mut FFI_ArrowArray, + dictionary: *mut FFI_ArrowArray, + release: Option, + // When exported, this MUST contain everything that is owned by this array. + // for example, any buffer pointed to in `buffers` must be here, as well + // as the `buffers` pointer itself. + // In other words, everything in [FFI_ArrowArray] must be owned by + // `private_data` and can assume that they do not outlive `private_data`. + private_data: *mut c_void, +} + +impl Drop for FFI_ArrowArray { + fn drop(&mut self) { + match self.release { + None => (), + Some(release) => unsafe { release(self) }, + }; + } +} + +unsafe impl Send for FFI_ArrowArray {} +unsafe impl Sync for FFI_ArrowArray {} + +// callback used to drop [FFI_ArrowArray] when it is exported +unsafe extern "C" fn release_array(array: *mut FFI_ArrowArray) { + if array.is_null() { + return; + } + let array = &mut *array; + + // take ownership of `private_data`, therefore dropping it` + let private = Box::from_raw(array.private_data as *mut ArrayPrivateData); + for child in private.children.iter() { + let _ = Box::from_raw(*child); + } + if !private.dictionary.is_null() { + let _ = Box::from_raw(private.dictionary); + } + + array.release = None; +} + +struct ArrayPrivateData { + #[allow(dead_code)] + buffers: Vec>, + buffers_ptr: Box<[*const c_void]>, + children: Box<[*mut FFI_ArrowArray]>, + dictionary: *mut FFI_ArrowArray, +} + +impl FFI_ArrowArray { + /// creates a new `FFI_ArrowArray` from existing data. + /// # Memory Leaks + /// This method releases `buffers`. Consumers of this struct *must* call `release` before + /// releasing this struct, or contents in `buffers` leak. + pub fn new(data: &ArrayData) -> Self { + let data_layout = layout(data.data_type()); + + let buffers = if data_layout.can_contain_null_mask { + // * insert the null buffer at the start + // * make all others `Option`. + std::iter::once(data.null_buffer().cloned()) + .chain(data.buffers().iter().map(|b| Some(b.clone()))) + .collect::>() + } else { + data.buffers().iter().map(|b| Some(b.clone())).collect() + }; + + // `n_buffers` is the number of buffers by the spec. + let n_buffers = { + data_layout.buffers.len() + { + // If the layout has a null buffer by Arrow spec. + // Note that even the array doesn't have a null buffer because it has + // no null value, we still need to count 1 here to follow the spec. + usize::from(data_layout.can_contain_null_mask) + } + } as i64; + + let buffers_ptr = buffers + .iter() + .flat_map(|maybe_buffer| match maybe_buffer { + // note that `raw_data` takes into account the buffer's offset + Some(b) => Some(b.as_ptr() as *const c_void), + // This is for null buffer. We only put a null pointer for + // null buffer if by spec it can contain null mask. + None if data_layout.can_contain_null_mask => Some(std::ptr::null()), + None => None, + }) + .collect::>(); + + let empty = vec![]; + let (child_data, dictionary) = match data.data_type() { + DataType::Dictionary(_, _) => ( + empty.as_slice(), + Box::into_raw(Box::new(FFI_ArrowArray::new(&data.child_data()[0]))), + ), + _ => (data.child_data(), std::ptr::null_mut()), + }; + + let children = child_data + .iter() + .map(|child| Box::into_raw(Box::new(FFI_ArrowArray::new(child)))) + .collect::>(); + let n_children = children.len() as i64; + + // create the private data owning everything. + // any other data must be added here, e.g. via a struct, to track lifetime. + let mut private_data = Box::new(ArrayPrivateData { + buffers, + buffers_ptr, + children, + dictionary, + }); + + Self { + length: data.len() as i64, + null_count: data.null_count() as i64, + offset: data.offset() as i64, + n_buffers, + n_children, + buffers: private_data.buffers_ptr.as_mut_ptr(), + children: private_data.children.as_mut_ptr(), + dictionary, + release: Some(release_array), + private_data: Box::into_raw(private_data) as *mut c_void, + } + } + + /// create an empty `FFI_ArrowArray`, which can be used to import data into + pub fn empty() -> Self { + Self { + length: 0, + null_count: 0, + offset: 0, + n_buffers: 0, + n_children: 0, + buffers: std::ptr::null_mut(), + children: std::ptr::null_mut(), + dictionary: std::ptr::null_mut(), + release: None, + private_data: std::ptr::null_mut(), + } + } + + /// the length of the array + #[inline] + pub fn len(&self) -> usize { + self.length as usize + } + + /// whether the array is empty + #[inline] + pub fn is_empty(&self) -> bool { + self.length == 0 + } + + /// Whether the array has been released + #[inline] + pub fn is_released(&self) -> bool { + self.release.is_none() + } + + /// the offset of the array + #[inline] + pub fn offset(&self) -> usize { + self.offset as usize + } + + /// the null count of the array + #[inline] + pub fn null_count(&self) -> usize { + self.null_count as usize + } + + /// Returns the buffer at the provided index + /// + /// # Panic + /// Panics if index exceeds the number of buffers or the buffer is not correctly aligned + #[inline] + pub fn buffer(&self, index: usize) -> *const u8 { + assert!(!self.buffers.is_null()); + assert!(index < self.num_buffers()); + // SAFETY: + // If buffers is not null must be valid for reads up to num_buffers + unsafe { std::ptr::read_unaligned((self.buffers as *mut *const u8).add(index)) } + } + + /// Returns the number of buffers + #[inline] + pub fn num_buffers(&self) -> usize { + self.n_buffers as _ + } + + /// Returns the child at the provided index + #[inline] + pub fn child(&self, index: usize) -> &FFI_ArrowArray { + assert!(!self.children.is_null()); + assert!(index < self.num_children()); + // Safety: + // If children is not null must be valid for reads up to num_children + unsafe { + let child = std::ptr::read_unaligned(self.children.add(index)); + child.as_ref().unwrap() + } + } + + /// Returns the number of children + #[inline] + pub fn num_children(&self) -> usize { + self.n_children as _ + } + + /// Returns the dictionary if any + #[inline] + pub fn dictionary(&self) -> Option<&Self> { + // Safety: + // If dictionary is not null should be valid for reads of `Self` + unsafe { self.dictionary.as_ref() } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // More tests located in top-level arrow crate + + #[test] + fn null_array_n_buffers() { + let data = ArrayData::new_null(&DataType::Null, 10); + + let ffi_array = FFI_ArrowArray::new(&data); + assert_eq!(0, ffi_array.n_buffers); + + let private_data = + unsafe { Box::from_raw(ffi_array.private_data as *mut ArrayPrivateData) }; + + assert_eq!(0, private_data.buffers_ptr.len()); + + Box::into_raw(private_data); + } +} diff --git a/arrow-data/src/lib.rs b/arrow-data/src/lib.rs index 58571e181176..b37a8c5da72f 100644 --- a/arrow-data/src/lib.rs +++ b/arrow-data/src/lib.rs @@ -28,3 +28,6 @@ pub mod transform; pub mod bit_iterator; pub mod bit_mask; pub mod decimal; + +#[cfg(feature = "ffi")] +pub mod ffi; diff --git a/arrow-schema/Cargo.toml b/arrow-schema/Cargo.toml index 1a25c1022195..e4e7d0082eb8 100644 --- a/arrow-schema/Cargo.toml +++ b/arrow-schema/Cargo.toml @@ -39,9 +39,14 @@ bench = false [dependencies] serde = { version = "1.0", default-features = false, features = ["derive", "std"], optional = true } +bitflags = { version = "1.2.1", default-features = false, optional = true } [features] -default = [] +# Enable ffi support +ffi = ["bitflags"] + +[package.metadata.docs.rs] +features = ["ffi"] [dev-dependencies] serde_json = "1.0" diff --git a/arrow-schema/src/ffi.rs b/arrow-schema/src/ffi.rs new file mode 100644 index 000000000000..8e58e3158c8b --- /dev/null +++ b/arrow-schema/src/ffi.rs @@ -0,0 +1,703 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Contains declarations to bind to the [C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html). +//! +//! ``` +//! # use arrow_schema::{DataType, Field, Schema}; +//! # use arrow_schema::ffi::FFI_ArrowSchema; +//! +//! // Create from data type +//! let ffi_data_type = FFI_ArrowSchema::try_from(&DataType::LargeUtf8).unwrap(); +//! let back = DataType::try_from(&ffi_data_type).unwrap(); +//! assert_eq!(back, DataType::LargeUtf8); +//! +//! // Create from schema +//! let schema = Schema::new(vec![Field::new("foo", DataType::Int64, false)]); +//! let ffi_schema = FFI_ArrowSchema::try_from(&schema).unwrap(); +//! let back = Schema::try_from(&ffi_schema).unwrap(); +//! +//! assert_eq!(schema, back); +//! ``` + +use crate::{ArrowError, DataType, Field, Schema, TimeUnit, UnionMode}; +use bitflags::bitflags; +use std::ffi::{c_char, c_void, CStr, CString}; + +bitflags! { + pub struct Flags: i64 { + const DICTIONARY_ORDERED = 0b00000001; + const NULLABLE = 0b00000010; + const MAP_KEYS_SORTED = 0b00000100; + } +} + +/// ABI-compatible struct for `ArrowSchema` from C Data Interface +/// See +/// +/// ``` +/// # use arrow_schema::DataType; +/// # use arrow_schema::ffi::FFI_ArrowSchema; +/// fn array_schema(data_type: &DataType) -> FFI_ArrowSchema { +/// FFI_ArrowSchema::try_from(data_type).unwrap() +/// } +/// ``` +/// +#[repr(C)] +#[derive(Debug)] +pub struct FFI_ArrowSchema { + format: *const c_char, + name: *const c_char, + metadata: *const c_char, + flags: i64, + n_children: i64, + children: *mut *mut FFI_ArrowSchema, + dictionary: *mut FFI_ArrowSchema, + release: Option, + private_data: *mut c_void, +} + +struct SchemaPrivateData { + children: Box<[*mut FFI_ArrowSchema]>, + dictionary: *mut FFI_ArrowSchema, +} + +// callback used to drop [FFI_ArrowSchema] when it is exported. +unsafe extern "C" fn release_schema(schema: *mut FFI_ArrowSchema) { + if schema.is_null() { + return; + } + let schema = &mut *schema; + + // take ownership back to release it. + drop(CString::from_raw(schema.format as *mut c_char)); + if !schema.name.is_null() { + drop(CString::from_raw(schema.name as *mut c_char)); + } + if !schema.private_data.is_null() { + let private_data = Box::from_raw(schema.private_data as *mut SchemaPrivateData); + for child in private_data.children.iter() { + drop(Box::from_raw(*child)) + } + if !private_data.dictionary.is_null() { + drop(Box::from_raw(private_data.dictionary)); + } + + drop(private_data); + } + + schema.release = None; +} + +impl FFI_ArrowSchema { + /// create a new [`FFI_ArrowSchema`]. This fails if the fields' + /// [`DataType`] is not supported. + pub fn try_new( + format: &str, + children: Vec, + dictionary: Option, + ) -> Result { + let mut this = Self::empty(); + + let children_ptr = children + .into_iter() + .map(Box::new) + .map(Box::into_raw) + .collect::>(); + + this.format = CString::new(format).unwrap().into_raw(); + this.release = Some(release_schema); + this.n_children = children_ptr.len() as i64; + + let dictionary_ptr = dictionary + .map(|d| Box::into_raw(Box::new(d))) + .unwrap_or(std::ptr::null_mut()); + + let mut private_data = Box::new(SchemaPrivateData { + children: children_ptr, + dictionary: dictionary_ptr, + }); + + // intentionally set from private_data (see https://github.com/apache/arrow-rs/issues/580) + this.children = private_data.children.as_mut_ptr(); + + this.dictionary = dictionary_ptr; + + this.private_data = Box::into_raw(private_data) as *mut c_void; + + Ok(this) + } + + pub fn with_name(mut self, name: &str) -> Result { + self.name = CString::new(name).unwrap().into_raw(); + Ok(self) + } + + pub fn with_flags(mut self, flags: Flags) -> Result { + self.flags = flags.bits(); + Ok(self) + } + + pub fn empty() -> Self { + Self { + format: std::ptr::null_mut(), + name: std::ptr::null_mut(), + metadata: std::ptr::null_mut(), + flags: 0, + n_children: 0, + children: std::ptr::null_mut(), + dictionary: std::ptr::null_mut(), + release: None, + private_data: std::ptr::null_mut(), + } + } + + /// returns the format of this schema. + pub fn format(&self) -> &str { + assert!(!self.format.is_null()); + // safe because the lifetime of `self.format` equals `self` + unsafe { CStr::from_ptr(self.format) } + .to_str() + .expect("The external API has a non-utf8 as format") + } + + /// returns the name of this schema. + pub fn name(&self) -> &str { + assert!(!self.name.is_null()); + // safe because the lifetime of `self.name` equals `self` + unsafe { CStr::from_ptr(self.name) } + .to_str() + .expect("The external API has a non-utf8 as name") + } + + pub fn flags(&self) -> Option { + Flags::from_bits(self.flags) + } + + pub fn child(&self, index: usize) -> &Self { + assert!(index < self.n_children as usize); + unsafe { self.children.add(index).as_ref().unwrap().as_ref().unwrap() } + } + + pub fn children(&self) -> impl Iterator { + (0..self.n_children as usize).map(move |i| self.child(i)) + } + + pub fn nullable(&self) -> bool { + (self.flags / 2) & 1 == 1 + } + + pub fn dictionary(&self) -> Option<&Self> { + unsafe { self.dictionary.as_ref() } + } + + pub fn map_keys_sorted(&self) -> bool { + self.flags & 0b00000100 != 0 + } + + pub fn dictionary_ordered(&self) -> bool { + self.flags & 0b00000001 != 0 + } +} + +impl Drop for FFI_ArrowSchema { + fn drop(&mut self) { + match self.release { + None => (), + Some(release) => unsafe { release(self) }, + }; + } +} + +impl TryFrom<&FFI_ArrowSchema> for DataType { + type Error = ArrowError; + + /// See [CDataInterface docs](https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings) + fn try_from(c_schema: &FFI_ArrowSchema) -> Result { + let mut dtype = match c_schema.format() { + "n" => DataType::Null, + "b" => DataType::Boolean, + "c" => DataType::Int8, + "C" => DataType::UInt8, + "s" => DataType::Int16, + "S" => DataType::UInt16, + "i" => DataType::Int32, + "I" => DataType::UInt32, + "l" => DataType::Int64, + "L" => DataType::UInt64, + "e" => DataType::Float16, + "f" => DataType::Float32, + "g" => DataType::Float64, + "z" => DataType::Binary, + "Z" => DataType::LargeBinary, + "u" => DataType::Utf8, + "U" => DataType::LargeUtf8, + "tdD" => DataType::Date32, + "tdm" => DataType::Date64, + "tts" => DataType::Time32(TimeUnit::Second), + "ttm" => DataType::Time32(TimeUnit::Millisecond), + "ttu" => DataType::Time64(TimeUnit::Microsecond), + "ttn" => DataType::Time64(TimeUnit::Nanosecond), + "tDs" => DataType::Duration(TimeUnit::Second), + "tDm" => DataType::Duration(TimeUnit::Millisecond), + "tDu" => DataType::Duration(TimeUnit::Microsecond), + "tDn" => DataType::Duration(TimeUnit::Nanosecond), + "+l" => { + let c_child = c_schema.child(0); + DataType::List(Box::new(Field::try_from(c_child)?)) + } + "+L" => { + let c_child = c_schema.child(0); + DataType::LargeList(Box::new(Field::try_from(c_child)?)) + } + "+s" => { + let fields = c_schema.children().map(Field::try_from); + DataType::Struct(fields.collect::, ArrowError>>()?) + } + "+m" => { + let c_child = c_schema.child(0); + let map_keys_sorted = c_schema.map_keys_sorted(); + DataType::Map(Box::new(Field::try_from(c_child)?), map_keys_sorted) + } + // Parametrized types, requiring string parse + other => { + match other.splitn(2, ':').collect::>().as_slice() { + // FixedSizeBinary type in format "w:num_bytes" + ["w", num_bytes] => { + let parsed_num_bytes = num_bytes.parse::().map_err(|_| { + ArrowError::CDataInterface( + "FixedSizeBinary requires an integer parameter representing number of bytes per element".to_string()) + })?; + DataType::FixedSizeBinary(parsed_num_bytes) + }, + // FixedSizeList type in format "+w:num_elems" + ["+w", num_elems] => { + let c_child = c_schema.child(0); + let parsed_num_elems = num_elems.parse::().map_err(|_| { + ArrowError::CDataInterface( + "The FixedSizeList type requires an integer parameter representing number of elements per list".to_string()) + })?; + DataType::FixedSizeList(Box::new(Field::try_from(c_child)?), parsed_num_elems) + }, + // Decimal types in format "d:precision,scale" or "d:precision,scale,bitWidth" + ["d", extra] => { + match extra.splitn(3, ',').collect::>().as_slice() { + [precision, scale] => { + let parsed_precision = precision.parse::().map_err(|_| { + ArrowError::CDataInterface( + "The decimal type requires an integer precision".to_string(), + ) + })?; + let parsed_scale = scale.parse::().map_err(|_| { + ArrowError::CDataInterface( + "The decimal type requires an integer scale".to_string(), + ) + })?; + DataType::Decimal128(parsed_precision, parsed_scale) + }, + [precision, scale, bits] => { + if *bits != "128" && *bits != "256" { + return Err(ArrowError::CDataInterface("Only 128/256 bit wide decimal is supported in the Rust implementation".to_string())); + } + let parsed_precision = precision.parse::().map_err(|_| { + ArrowError::CDataInterface( + "The decimal type requires an integer precision".to_string(), + ) + })?; + let parsed_scale = scale.parse::().map_err(|_| { + ArrowError::CDataInterface( + "The decimal type requires an integer scale".to_string(), + ) + })?; + if *bits == "128" { + DataType::Decimal128(parsed_precision, parsed_scale) + } else { + DataType::Decimal256(parsed_precision, parsed_scale) + } + } + _ => { + return Err(ArrowError::CDataInterface(format!( + "The decimal pattern \"d:{extra:?}\" is not supported in the Rust implementation" + ))) + } + } + } + // DenseUnion + ["+ud", extra] => { + let type_ids = extra.split(',').map(|t| t.parse::().map_err(|_| { + ArrowError::CDataInterface( + "The Union type requires an integer type id".to_string(), + ) + })).collect::, ArrowError>>()?; + let mut fields = Vec::with_capacity(type_ids.len()); + for idx in 0..c_schema.n_children { + let c_child = c_schema.child(idx as usize); + let field = Field::try_from(c_child)?; + fields.push(field); + } + + if fields.len() != type_ids.len() { + return Err(ArrowError::CDataInterface( + "The Union type requires same number of fields and type ids".to_string(), + )); + } + + DataType::Union(fields, type_ids, UnionMode::Dense) + } + // SparseUnion + ["+us", extra] => { + let type_ids = extra.split(',').map(|t| t.parse::().map_err(|_| { + ArrowError::CDataInterface( + "The Union type requires an integer type id".to_string(), + ) + })).collect::, ArrowError>>()?; + let mut fields = Vec::with_capacity(type_ids.len()); + for idx in 0..c_schema.n_children { + let c_child = c_schema.child(idx as usize); + let field = Field::try_from(c_child)?; + fields.push(field); + } + + if fields.len() != type_ids.len() { + return Err(ArrowError::CDataInterface( + "The Union type requires same number of fields and type ids".to_string(), + )); + } + + DataType::Union(fields, type_ids, UnionMode::Sparse) + } + + // Timestamps in format "tts:" and "tts:America/New_York" for no timezones and timezones resp. + ["tss", ""] => DataType::Timestamp(TimeUnit::Second, None), + ["tsm", ""] => DataType::Timestamp(TimeUnit::Millisecond, None), + ["tsu", ""] => DataType::Timestamp(TimeUnit::Microsecond, None), + ["tsn", ""] => DataType::Timestamp(TimeUnit::Nanosecond, None), + ["tss", tz] => { + DataType::Timestamp(TimeUnit::Second, Some(tz.to_string())) + } + ["tsm", tz] => { + DataType::Timestamp(TimeUnit::Millisecond, Some(tz.to_string())) + } + ["tsu", tz] => { + DataType::Timestamp(TimeUnit::Microsecond, Some(tz.to_string())) + } + ["tsn", tz] => { + DataType::Timestamp(TimeUnit::Nanosecond, Some(tz.to_string())) + } + _ => { + return Err(ArrowError::CDataInterface(format!( + "The datatype \"{other:?}\" is still not supported in Rust implementation" + ))) + } + } + } + }; + + if let Some(dict_schema) = c_schema.dictionary() { + let value_type = Self::try_from(dict_schema)?; + dtype = DataType::Dictionary(Box::new(dtype), Box::new(value_type)); + } + + Ok(dtype) + } +} + +impl TryFrom<&FFI_ArrowSchema> for Field { + type Error = ArrowError; + + fn try_from(c_schema: &FFI_ArrowSchema) -> Result { + let dtype = DataType::try_from(c_schema)?; + let field = Field::new(c_schema.name(), dtype, c_schema.nullable()); + Ok(field) + } +} + +impl TryFrom<&FFI_ArrowSchema> for Schema { + type Error = ArrowError; + + fn try_from(c_schema: &FFI_ArrowSchema) -> Result { + // interpret it as a struct type then extract its fields + let dtype = DataType::try_from(c_schema)?; + if let DataType::Struct(fields) = dtype { + Ok(Schema::new(fields)) + } else { + Err(ArrowError::CDataInterface( + "Unable to interpret C data struct as a Schema".to_string(), + )) + } + } +} + +impl TryFrom<&DataType> for FFI_ArrowSchema { + type Error = ArrowError; + + /// See [CDataInterface docs](https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings) + fn try_from(dtype: &DataType) -> Result { + let format = get_format_string(dtype)?; + // allocate and hold the children + let children = match dtype { + DataType::List(child) + | DataType::LargeList(child) + | DataType::FixedSizeList(child, _) + | DataType::Map(child, _) => { + vec![FFI_ArrowSchema::try_from(child.as_ref())?] + } + DataType::Union(fields, _, _) => fields + .iter() + .map(FFI_ArrowSchema::try_from) + .collect::, ArrowError>>()?, + DataType::Struct(fields) => fields + .iter() + .map(FFI_ArrowSchema::try_from) + .collect::, ArrowError>>()?, + _ => vec![], + }; + let dictionary = if let DataType::Dictionary(_, value_data_type) = dtype { + Some(Self::try_from(value_data_type.as_ref())?) + } else { + None + }; + + let flags = match dtype { + DataType::Map(_, true) => Flags::MAP_KEYS_SORTED, + _ => Flags::empty(), + }; + + FFI_ArrowSchema::try_new(&format, children, dictionary)?.with_flags(flags) + } +} + +fn get_format_string(dtype: &DataType) -> Result { + match dtype { + DataType::Null => Ok("n".to_string()), + DataType::Boolean => Ok("b".to_string()), + DataType::Int8 => Ok("c".to_string()), + DataType::UInt8 => Ok("C".to_string()), + DataType::Int16 => Ok("s".to_string()), + DataType::UInt16 => Ok("S".to_string()), + DataType::Int32 => Ok("i".to_string()), + DataType::UInt32 => Ok("I".to_string()), + DataType::Int64 => Ok("l".to_string()), + DataType::UInt64 => Ok("L".to_string()), + DataType::Float16 => Ok("e".to_string()), + DataType::Float32 => Ok("f".to_string()), + DataType::Float64 => Ok("g".to_string()), + DataType::Binary => Ok("z".to_string()), + DataType::LargeBinary => Ok("Z".to_string()), + DataType::Utf8 => Ok("u".to_string()), + DataType::LargeUtf8 => Ok("U".to_string()), + DataType::FixedSizeBinary(num_bytes) => Ok(format!("w:{num_bytes}")), + DataType::FixedSizeList(_, num_elems) => Ok(format!("+w:{num_elems}")), + DataType::Decimal128(precision, scale) => Ok(format!("d:{precision},{scale}")), + DataType::Decimal256(precision, scale) => { + Ok(format!("d:{precision},{scale},256")) + } + DataType::Date32 => Ok("tdD".to_string()), + DataType::Date64 => Ok("tdm".to_string()), + DataType::Time32(TimeUnit::Second) => Ok("tts".to_string()), + DataType::Time32(TimeUnit::Millisecond) => Ok("ttm".to_string()), + DataType::Time64(TimeUnit::Microsecond) => Ok("ttu".to_string()), + DataType::Time64(TimeUnit::Nanosecond) => Ok("ttn".to_string()), + DataType::Timestamp(TimeUnit::Second, None) => Ok("tss:".to_string()), + DataType::Timestamp(TimeUnit::Millisecond, None) => Ok("tsm:".to_string()), + DataType::Timestamp(TimeUnit::Microsecond, None) => Ok("tsu:".to_string()), + DataType::Timestamp(TimeUnit::Nanosecond, None) => Ok("tsn:".to_string()), + DataType::Timestamp(TimeUnit::Second, Some(tz)) => Ok(format!("tss:{tz}")), + DataType::Timestamp(TimeUnit::Millisecond, Some(tz)) => Ok(format!("tsm:{tz}")), + DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => Ok(format!("tsu:{tz}")), + DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) => Ok(format!("tsn:{tz}")), + DataType::Duration(TimeUnit::Second) => Ok("tDs".to_string()), + DataType::Duration(TimeUnit::Millisecond) => Ok("tDm".to_string()), + DataType::Duration(TimeUnit::Microsecond) => Ok("tDu".to_string()), + DataType::Duration(TimeUnit::Nanosecond) => Ok("tDn".to_string()), + DataType::List(_) => Ok("+l".to_string()), + DataType::LargeList(_) => Ok("+L".to_string()), + DataType::Struct(_) => Ok("+s".to_string()), + DataType::Map(_, _) => Ok("+m".to_string()), + DataType::Dictionary(key_data_type, _) => get_format_string(key_data_type), + DataType::Union(_, type_ids, mode) => { + let formats = type_ids.iter().map(|t| t.to_string()).collect::>(); + match mode { + UnionMode::Dense => Ok(format!("{}:{}", "+ud", formats.join(","))), + UnionMode::Sparse => Ok(format!("{}:{}", "+us", formats.join(","))), + } + } + other => Err(ArrowError::CDataInterface(format!( + "The datatype \"{other:?}\" is still not supported in Rust implementation" + ))), + } +} + +impl TryFrom<&Field> for FFI_ArrowSchema { + type Error = ArrowError; + + fn try_from(field: &Field) -> Result { + let mut flags = if field.is_nullable() { + Flags::NULLABLE + } else { + Flags::empty() + }; + + if let Some(true) = field.dict_is_ordered() { + flags |= Flags::DICTIONARY_ORDERED; + } + + FFI_ArrowSchema::try_from(field.data_type())? + .with_name(field.name())? + .with_flags(flags) + } +} + +impl TryFrom<&Schema> for FFI_ArrowSchema { + type Error = ArrowError; + + fn try_from(schema: &Schema) -> Result { + let dtype = DataType::Struct(schema.fields().clone()); + let c_schema = FFI_ArrowSchema::try_from(&dtype)?; + Ok(c_schema) + } +} + +impl TryFrom for FFI_ArrowSchema { + type Error = ArrowError; + + fn try_from(dtype: DataType) -> Result { + FFI_ArrowSchema::try_from(&dtype) + } +} + +impl TryFrom for FFI_ArrowSchema { + type Error = ArrowError; + + fn try_from(field: Field) -> Result { + FFI_ArrowSchema::try_from(&field) + } +} + +impl TryFrom for FFI_ArrowSchema { + type Error = ArrowError; + + fn try_from(schema: Schema) -> Result { + FFI_ArrowSchema::try_from(&schema) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn round_trip_type(dtype: DataType) { + let c_schema = FFI_ArrowSchema::try_from(&dtype).unwrap(); + let restored = DataType::try_from(&c_schema).unwrap(); + assert_eq!(restored, dtype); + } + + fn round_trip_field(field: Field) { + let c_schema = FFI_ArrowSchema::try_from(&field).unwrap(); + let restored = Field::try_from(&c_schema).unwrap(); + assert_eq!(restored, field); + } + + fn round_trip_schema(schema: Schema) { + let c_schema = FFI_ArrowSchema::try_from(&schema).unwrap(); + let restored = Schema::try_from(&c_schema).unwrap(); + assert_eq!(restored, schema); + } + + #[test] + fn test_type() { + round_trip_type(DataType::Int64); + round_trip_type(DataType::UInt64); + round_trip_type(DataType::Float64); + round_trip_type(DataType::Date64); + round_trip_type(DataType::Time64(TimeUnit::Nanosecond)); + round_trip_type(DataType::FixedSizeBinary(12)); + round_trip_type(DataType::FixedSizeList( + Box::new(Field::new("a", DataType::Int64, false)), + 5, + )); + round_trip_type(DataType::Utf8); + round_trip_type(DataType::List(Box::new(Field::new( + "a", + DataType::Int16, + false, + )))); + round_trip_type(DataType::Struct(vec![Field::new( + "a", + DataType::Utf8, + true, + )])); + } + + #[test] + fn test_field() { + let dtype = DataType::Struct(vec![Field::new("a", DataType::Utf8, true)]); + round_trip_field(Field::new("test", dtype, true)); + } + + #[test] + fn test_schema() { + let schema = Schema::new(vec![ + Field::new("name", DataType::Utf8, false), + Field::new("address", DataType::Utf8, false), + Field::new("priority", DataType::UInt8, false), + ]); + round_trip_schema(schema); + + // test that we can interpret struct types as schema + let dtype = DataType::Struct(vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::Int16, false), + ]); + let c_schema = FFI_ArrowSchema::try_from(&dtype).unwrap(); + let schema = Schema::try_from(&c_schema).unwrap(); + assert_eq!(schema.fields().len(), 2); + + // test that we assert the input type + let c_schema = FFI_ArrowSchema::try_from(&DataType::Float64).unwrap(); + let result = Schema::try_from(&c_schema); + assert!(result.is_err()); + } + + #[test] + fn test_map_keys_sorted() { + let keys = Field::new("keys", DataType::Int32, false); + let values = Field::new("values", DataType::UInt32, false); + let entry_struct = DataType::Struct(vec![keys, values]); + + // Construct a map array from the above two + let map_data_type = + DataType::Map(Box::new(Field::new("entries", entry_struct, true)), true); + + let arrow_schema = FFI_ArrowSchema::try_from(map_data_type).unwrap(); + assert!(arrow_schema.map_keys_sorted()); + } + + #[test] + fn test_dictionary_ordered() { + let schema = Schema::new(vec![Field::new_dict( + "dict", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + 0, + true, + )]); + + let arrow_schema = FFI_ArrowSchema::try_from(schema).unwrap(); + assert!(arrow_schema.child(0).dictionary_ordered()); + } +} diff --git a/arrow-schema/src/lib.rs b/arrow-schema/src/lib.rs index c2b1aba3b926..6bc2329dbd36 100644 --- a/arrow-schema/src/lib.rs +++ b/arrow-schema/src/lib.rs @@ -26,6 +26,9 @@ pub use field::*; mod schema; pub use schema::*; +#[cfg(feature = "ffi")] +pub mod ffi; + /// Options that define the sort order of a given column #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub struct SortOptions { diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml index 814ca14c8058..ef89e5a81232 100644 --- a/arrow/Cargo.toml +++ b/arrow/Cargo.toml @@ -62,7 +62,6 @@ arrow-string = { version = "33.0.0", path = "../arrow-string" } rand = { version = "0.8", default-features = false, features = ["std", "std_rng"], optional = true } comfy-table = { version = "6.0", optional = true, default-features = false } pyo3 = { version = "0.18", default-features = false, optional = true } -bitflags = { version = "1.2.1", default-features = false, optional = true } [package.metadata.docs.rs] features = ["prettyprint", "ipc_compression", "dyn_cmp_dict", "dyn_arith_dict", "ffi", "pyarrow"] @@ -86,7 +85,7 @@ pyarrow = ["pyo3", "ffi"] # but is run as part of our CI checks force_validate = ["arrow-data/force_validate"] # Enable ffi support -ffi = ["bitflags"] +ffi = ["arrow-schema/ffi", "arrow-data/ffi"] # Enable dyn-comparison of dictionary arrays with other arrays # Note: this does not impact comparison against scalars dyn_cmp_dict = ["arrow-string/dyn_cmp_dict", "arrow-ord/dyn_cmp_dict"] diff --git a/arrow/src/datatypes/ffi.rs b/arrow/src/datatypes/ffi.rs index 58cad3d08a4e..b248758bc120 100644 --- a/arrow/src/datatypes/ffi.rs +++ b/arrow/src/datatypes/ffi.rs @@ -14,505 +14,3 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. - -use arrow_schema::UnionMode; -use std::convert::TryFrom; - -use crate::datatypes::DataType::Map; -use crate::{ - datatypes::{DataType, Field, Schema, TimeUnit}, - error::{ArrowError, Result}, - ffi::{FFI_ArrowSchema, Flags}, -}; - -impl TryFrom<&FFI_ArrowSchema> for DataType { - type Error = ArrowError; - - /// See [CDataInterface docs](https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings) - fn try_from(c_schema: &FFI_ArrowSchema) -> Result { - let mut dtype = match c_schema.format() { - "n" => DataType::Null, - "b" => DataType::Boolean, - "c" => DataType::Int8, - "C" => DataType::UInt8, - "s" => DataType::Int16, - "S" => DataType::UInt16, - "i" => DataType::Int32, - "I" => DataType::UInt32, - "l" => DataType::Int64, - "L" => DataType::UInt64, - "e" => DataType::Float16, - "f" => DataType::Float32, - "g" => DataType::Float64, - "z" => DataType::Binary, - "Z" => DataType::LargeBinary, - "u" => DataType::Utf8, - "U" => DataType::LargeUtf8, - "tdD" => DataType::Date32, - "tdm" => DataType::Date64, - "tts" => DataType::Time32(TimeUnit::Second), - "ttm" => DataType::Time32(TimeUnit::Millisecond), - "ttu" => DataType::Time64(TimeUnit::Microsecond), - "ttn" => DataType::Time64(TimeUnit::Nanosecond), - "tDs" => DataType::Duration(TimeUnit::Second), - "tDm" => DataType::Duration(TimeUnit::Millisecond), - "tDu" => DataType::Duration(TimeUnit::Microsecond), - "tDn" => DataType::Duration(TimeUnit::Nanosecond), - "+l" => { - let c_child = c_schema.child(0); - DataType::List(Box::new(Field::try_from(c_child)?)) - } - "+L" => { - let c_child = c_schema.child(0); - DataType::LargeList(Box::new(Field::try_from(c_child)?)) - } - "+s" => { - let fields = c_schema.children().map(Field::try_from); - DataType::Struct(fields.collect::>>()?) - } - "+m" => { - let c_child = c_schema.child(0); - let map_keys_sorted = c_schema.map_keys_sorted(); - DataType::Map(Box::new(Field::try_from(c_child)?), map_keys_sorted) - } - // Parametrized types, requiring string parse - other => { - match other.splitn(2, ':').collect::>().as_slice() { - // FixedSizeBinary type in format "w:num_bytes" - ["w", num_bytes] => { - let parsed_num_bytes = num_bytes.parse::().map_err(|_| { - ArrowError::CDataInterface( - "FixedSizeBinary requires an integer parameter representing number of bytes per element".to_string()) - })?; - DataType::FixedSizeBinary(parsed_num_bytes) - }, - // FixedSizeList type in format "+w:num_elems" - ["+w", num_elems] => { - let c_child = c_schema.child(0); - let parsed_num_elems = num_elems.parse::().map_err(|_| { - ArrowError::CDataInterface( - "The FixedSizeList type requires an integer parameter representing number of elements per list".to_string()) - })?; - DataType::FixedSizeList(Box::new(Field::try_from(c_child)?), parsed_num_elems) - }, - // Decimal types in format "d:precision,scale" or "d:precision,scale,bitWidth" - ["d", extra] => { - match extra.splitn(3, ',').collect::>().as_slice() { - [precision, scale] => { - let parsed_precision = precision.parse::().map_err(|_| { - ArrowError::CDataInterface( - "The decimal type requires an integer precision".to_string(), - ) - })?; - let parsed_scale = scale.parse::().map_err(|_| { - ArrowError::CDataInterface( - "The decimal type requires an integer scale".to_string(), - ) - })?; - DataType::Decimal128(parsed_precision, parsed_scale) - }, - [precision, scale, bits] => { - if *bits != "128" && *bits != "256" { - return Err(ArrowError::CDataInterface("Only 128/256 bit wide decimal is supported in the Rust implementation".to_string())); - } - let parsed_precision = precision.parse::().map_err(|_| { - ArrowError::CDataInterface( - "The decimal type requires an integer precision".to_string(), - ) - })?; - let parsed_scale = scale.parse::().map_err(|_| { - ArrowError::CDataInterface( - "The decimal type requires an integer scale".to_string(), - ) - })?; - if *bits == "128" { - DataType::Decimal128(parsed_precision, parsed_scale) - } else { - DataType::Decimal256(parsed_precision, parsed_scale) - } - } - _ => { - return Err(ArrowError::CDataInterface(format!( - "The decimal pattern \"d:{extra:?}\" is not supported in the Rust implementation" - ))) - } - } - } - // DenseUnion - ["+ud", extra] => { - let type_ids = extra.split(',').map(|t| t.parse::().map_err(|_| { - ArrowError::CDataInterface( - "The Union type requires an integer type id".to_string(), - ) - })).collect::>>()?; - let mut fields = Vec::with_capacity(type_ids.len()); - for idx in 0..c_schema.n_children { - let c_child = c_schema.child(idx as usize); - let field = Field::try_from(c_child)?; - fields.push(field); - } - - if fields.len() != type_ids.len() { - return Err(ArrowError::CDataInterface( - "The Union type requires same number of fields and type ids".to_string(), - )); - } - - DataType::Union(fields, type_ids, UnionMode::Dense) - } - // SparseUnion - ["+us", extra] => { - let type_ids = extra.split(',').map(|t| t.parse::().map_err(|_| { - ArrowError::CDataInterface( - "The Union type requires an integer type id".to_string(), - ) - })).collect::>>()?; - let mut fields = Vec::with_capacity(type_ids.len()); - for idx in 0..c_schema.n_children { - let c_child = c_schema.child(idx as usize); - let field = Field::try_from(c_child)?; - fields.push(field); - } - - if fields.len() != type_ids.len() { - return Err(ArrowError::CDataInterface( - "The Union type requires same number of fields and type ids".to_string(), - )); - } - - DataType::Union(fields, type_ids, UnionMode::Sparse) - } - - // Timestamps in format "tts:" and "tts:America/New_York" for no timezones and timezones resp. - ["tss", ""] => DataType::Timestamp(TimeUnit::Second, None), - ["tsm", ""] => DataType::Timestamp(TimeUnit::Millisecond, None), - ["tsu", ""] => DataType::Timestamp(TimeUnit::Microsecond, None), - ["tsn", ""] => DataType::Timestamp(TimeUnit::Nanosecond, None), - ["tss", tz] => { - DataType::Timestamp(TimeUnit::Second, Some(tz.to_string())) - } - ["tsm", tz] => { - DataType::Timestamp(TimeUnit::Millisecond, Some(tz.to_string())) - } - ["tsu", tz] => { - DataType::Timestamp(TimeUnit::Microsecond, Some(tz.to_string())) - } - ["tsn", tz] => { - DataType::Timestamp(TimeUnit::Nanosecond, Some(tz.to_string())) - } - _ => { - return Err(ArrowError::CDataInterface(format!( - "The datatype \"{other:?}\" is still not supported in Rust implementation" - ))) - } - } - } - }; - - if let Some(dict_schema) = c_schema.dictionary() { - let value_type = Self::try_from(dict_schema)?; - dtype = DataType::Dictionary(Box::new(dtype), Box::new(value_type)); - } - - Ok(dtype) - } -} - -impl TryFrom<&FFI_ArrowSchema> for Field { - type Error = ArrowError; - - fn try_from(c_schema: &FFI_ArrowSchema) -> Result { - let dtype = DataType::try_from(c_schema)?; - let field = Field::new(c_schema.name(), dtype, c_schema.nullable()); - Ok(field) - } -} - -impl TryFrom<&FFI_ArrowSchema> for Schema { - type Error = ArrowError; - - fn try_from(c_schema: &FFI_ArrowSchema) -> Result { - // interpret it as a struct type then extract its fields - let dtype = DataType::try_from(c_schema)?; - if let DataType::Struct(fields) = dtype { - Ok(Schema::new(fields)) - } else { - Err(ArrowError::CDataInterface( - "Unable to interpret C data struct as a Schema".to_string(), - )) - } - } -} - -impl TryFrom<&DataType> for FFI_ArrowSchema { - type Error = ArrowError; - - /// See [CDataInterface docs](https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings) - fn try_from(dtype: &DataType) -> Result { - let format = get_format_string(dtype)?; - // allocate and hold the children - let children = match dtype { - DataType::List(child) - | DataType::LargeList(child) - | DataType::FixedSizeList(child, _) - | DataType::Map(child, _) => { - vec![FFI_ArrowSchema::try_from(child.as_ref())?] - } - DataType::Union(fields, _, _) => fields - .iter() - .map(FFI_ArrowSchema::try_from) - .collect::>>()?, - DataType::Struct(fields) => fields - .iter() - .map(FFI_ArrowSchema::try_from) - .collect::>>()?, - _ => vec![], - }; - let dictionary = if let DataType::Dictionary(_, value_data_type) = dtype { - Some(Self::try_from(value_data_type.as_ref())?) - } else { - None - }; - - let flags = match dtype { - Map(_, true) => Flags::MAP_KEYS_SORTED, - _ => Flags::empty(), - }; - - FFI_ArrowSchema::try_new(&format, children, dictionary)?.with_flags(flags) - } -} - -fn get_format_string(dtype: &DataType) -> Result { - match dtype { - DataType::Null => Ok("n".to_string()), - DataType::Boolean => Ok("b".to_string()), - DataType::Int8 => Ok("c".to_string()), - DataType::UInt8 => Ok("C".to_string()), - DataType::Int16 => Ok("s".to_string()), - DataType::UInt16 => Ok("S".to_string()), - DataType::Int32 => Ok("i".to_string()), - DataType::UInt32 => Ok("I".to_string()), - DataType::Int64 => Ok("l".to_string()), - DataType::UInt64 => Ok("L".to_string()), - DataType::Float16 => Ok("e".to_string()), - DataType::Float32 => Ok("f".to_string()), - DataType::Float64 => Ok("g".to_string()), - DataType::Binary => Ok("z".to_string()), - DataType::LargeBinary => Ok("Z".to_string()), - DataType::Utf8 => Ok("u".to_string()), - DataType::LargeUtf8 => Ok("U".to_string()), - DataType::FixedSizeBinary(num_bytes) => Ok(format!("w:{num_bytes}")), - DataType::FixedSizeList(_, num_elems) => Ok(format!("+w:{num_elems}")), - DataType::Decimal128(precision, scale) => Ok(format!("d:{precision},{scale}")), - DataType::Decimal256(precision, scale) => { - Ok(format!("d:{precision},{scale},256")) - } - DataType::Date32 => Ok("tdD".to_string()), - DataType::Date64 => Ok("tdm".to_string()), - DataType::Time32(TimeUnit::Second) => Ok("tts".to_string()), - DataType::Time32(TimeUnit::Millisecond) => Ok("ttm".to_string()), - DataType::Time64(TimeUnit::Microsecond) => Ok("ttu".to_string()), - DataType::Time64(TimeUnit::Nanosecond) => Ok("ttn".to_string()), - DataType::Timestamp(TimeUnit::Second, None) => Ok("tss:".to_string()), - DataType::Timestamp(TimeUnit::Millisecond, None) => Ok("tsm:".to_string()), - DataType::Timestamp(TimeUnit::Microsecond, None) => Ok("tsu:".to_string()), - DataType::Timestamp(TimeUnit::Nanosecond, None) => Ok("tsn:".to_string()), - DataType::Timestamp(TimeUnit::Second, Some(tz)) => Ok(format!("tss:{tz}")), - DataType::Timestamp(TimeUnit::Millisecond, Some(tz)) => Ok(format!("tsm:{tz}")), - DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => Ok(format!("tsu:{tz}")), - DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) => Ok(format!("tsn:{tz}")), - DataType::Duration(TimeUnit::Second) => Ok("tDs".to_string()), - DataType::Duration(TimeUnit::Millisecond) => Ok("tDm".to_string()), - DataType::Duration(TimeUnit::Microsecond) => Ok("tDu".to_string()), - DataType::Duration(TimeUnit::Nanosecond) => Ok("tDn".to_string()), - DataType::List(_) => Ok("+l".to_string()), - DataType::LargeList(_) => Ok("+L".to_string()), - DataType::Struct(_) => Ok("+s".to_string()), - DataType::Map(_, _) => Ok("+m".to_string()), - DataType::Dictionary(key_data_type, _) => get_format_string(key_data_type), - DataType::Union(_, type_ids, mode) => { - let formats = type_ids.iter().map(|t| t.to_string()).collect::>(); - match mode { - UnionMode::Dense => Ok(format!("{}:{}", "+ud", formats.join(","))), - UnionMode::Sparse => Ok(format!("{}:{}", "+us", formats.join(","))), - } - } - other => Err(ArrowError::CDataInterface(format!( - "The datatype \"{other:?}\" is still not supported in Rust implementation" - ))), - } -} - -impl TryFrom<&Field> for FFI_ArrowSchema { - type Error = ArrowError; - - fn try_from(field: &Field) -> Result { - let mut flags = if field.is_nullable() { - Flags::NULLABLE - } else { - Flags::empty() - }; - - if let Some(true) = field.dict_is_ordered() { - flags |= Flags::DICTIONARY_ORDERED; - } - - FFI_ArrowSchema::try_from(field.data_type())? - .with_name(field.name())? - .with_flags(flags) - } -} - -impl TryFrom<&Schema> for FFI_ArrowSchema { - type Error = ArrowError; - - fn try_from(schema: &Schema) -> Result { - let dtype = DataType::Struct(schema.fields().clone()); - let c_schema = FFI_ArrowSchema::try_from(&dtype)?; - Ok(c_schema) - } -} - -impl TryFrom for FFI_ArrowSchema { - type Error = ArrowError; - - fn try_from(dtype: DataType) -> Result { - FFI_ArrowSchema::try_from(&dtype) - } -} - -impl TryFrom for FFI_ArrowSchema { - type Error = ArrowError; - - fn try_from(field: Field) -> Result { - FFI_ArrowSchema::try_from(&field) - } -} - -impl TryFrom for FFI_ArrowSchema { - type Error = ArrowError; - - fn try_from(schema: Schema) -> Result { - FFI_ArrowSchema::try_from(&schema) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::datatypes::{DataType, Field, TimeUnit}; - use crate::error::Result; - use std::convert::TryFrom; - - fn round_trip_type(dtype: DataType) -> Result<()> { - let c_schema = FFI_ArrowSchema::try_from(&dtype)?; - let restored = DataType::try_from(&c_schema)?; - assert_eq!(restored, dtype); - Ok(()) - } - - fn round_trip_field(field: Field) -> Result<()> { - let c_schema = FFI_ArrowSchema::try_from(&field)?; - let restored = Field::try_from(&c_schema)?; - assert_eq!(restored, field); - Ok(()) - } - - fn round_trip_schema(schema: Schema) -> Result<()> { - let c_schema = FFI_ArrowSchema::try_from(&schema)?; - let restored = Schema::try_from(&c_schema)?; - assert_eq!(restored, schema); - Ok(()) - } - - #[test] - fn test_type() -> Result<()> { - round_trip_type(DataType::Int64)?; - round_trip_type(DataType::UInt64)?; - round_trip_type(DataType::Float64)?; - round_trip_type(DataType::Date64)?; - round_trip_type(DataType::Time64(TimeUnit::Nanosecond))?; - round_trip_type(DataType::FixedSizeBinary(12))?; - round_trip_type(DataType::FixedSizeList( - Box::new(Field::new("a", DataType::Int64, false)), - 5, - ))?; - round_trip_type(DataType::Utf8)?; - round_trip_type(DataType::List(Box::new(Field::new( - "a", - DataType::Int16, - false, - ))))?; - round_trip_type(DataType::Struct(vec![Field::new( - "a", - DataType::Utf8, - true, - )]))?; - Ok(()) - } - - #[test] - fn test_field() -> Result<()> { - let dtype = DataType::Struct(vec![Field::new("a", DataType::Utf8, true)]); - round_trip_field(Field::new("test", dtype, true))?; - Ok(()) - } - - #[test] - fn test_schema() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("name", DataType::Utf8, false), - Field::new("address", DataType::Utf8, false), - Field::new("priority", DataType::UInt8, false), - ]); - round_trip_schema(schema)?; - - // test that we can interpret struct types as schema - let dtype = DataType::Struct(vec![ - Field::new("a", DataType::Utf8, true), - Field::new("b", DataType::Int16, false), - ]); - let c_schema = FFI_ArrowSchema::try_from(&dtype)?; - let schema = Schema::try_from(&c_schema)?; - assert_eq!(schema.fields().len(), 2); - - // test that we assert the input type - let c_schema = FFI_ArrowSchema::try_from(&DataType::Float64)?; - let result = Schema::try_from(&c_schema); - assert!(result.is_err()); - Ok(()) - } - - #[test] - fn test_map_keys_sorted() -> Result<()> { - let keys = Field::new("keys", DataType::Int32, false); - let values = Field::new("values", DataType::UInt32, false); - let entry_struct = DataType::Struct(vec![keys, values]); - - // Construct a map array from the above two - let map_data_type = - DataType::Map(Box::new(Field::new("entries", entry_struct, true)), true); - - let arrow_schema = FFI_ArrowSchema::try_from(map_data_type)?; - assert!(arrow_schema.map_keys_sorted()); - - Ok(()) - } - - #[test] - fn test_dictionary_ordered() -> Result<()> { - let schema = Schema::new(vec![Field::new_dict( - "dict", - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), - false, - 0, - true, - )]); - - let arrow_schema = FFI_ArrowSchema::try_from(schema)?; - assert!(arrow_schema.child(0).dictionary_ordered()); - - Ok(()) - } -} diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs index 0f0f94c7a6b8..4d62b9e7cf61 100644 --- a/arrow/src/ffi.rs +++ b/arrow/src/ffi.rs @@ -104,19 +104,9 @@ To import an array, unsafely create an `ArrowArray` from two pointers using [Arr To export an array, create an `ArrowArray` using [ArrowArray::try_new]. */ -use std::{ - convert::TryFrom, - ffi::CStr, - ffi::CString, - iter, - mem::size_of, - os::raw::{c_char, c_void}, - ptr::{self, NonNull}, - sync::Arc, -}; +use std::{mem::size_of, ptr::NonNull, sync::Arc}; use arrow_schema::UnionMode; -use bitflags::bitflags; use crate::array::{layout, ArrayData}; use crate::buffer::{Buffer, MutableBuffer}; @@ -124,194 +114,11 @@ use crate::datatypes::DataType; use crate::error::{ArrowError, Result}; use crate::util::bit_util; -bitflags! { - pub struct Flags: i64 { - const DICTIONARY_ORDERED = 0b00000001; - const NULLABLE = 0b00000010; - const MAP_KEYS_SORTED = 0b00000100; - } -} - -/// ABI-compatible struct for `ArrowSchema` from C Data Interface -/// See -/// -/// ``` -/// # use arrow::ffi::FFI_ArrowSchema; -/// # use arrow_data::ArrayData; -/// fn array_schema(data: &ArrayData) -> FFI_ArrowSchema { -/// FFI_ArrowSchema::try_from(data.data_type()).unwrap() -/// } -/// ``` -/// -#[repr(C)] -#[derive(Debug)] -pub struct FFI_ArrowSchema { - pub(crate) format: *const c_char, - pub(crate) name: *const c_char, - pub(crate) metadata: *const c_char, - pub(crate) flags: i64, - pub(crate) n_children: i64, - pub(crate) children: *mut *mut FFI_ArrowSchema, - pub(crate) dictionary: *mut FFI_ArrowSchema, - pub(crate) release: Option, - pub(crate) private_data: *mut c_void, -} - -struct SchemaPrivateData { - children: Box<[*mut FFI_ArrowSchema]>, - dictionary: *mut FFI_ArrowSchema, -} - -// callback used to drop [FFI_ArrowSchema] when it is exported. -unsafe extern "C" fn release_schema(schema: *mut FFI_ArrowSchema) { - if schema.is_null() { - return; - } - let schema = &mut *schema; - - // take ownership back to release it. - drop(CString::from_raw(schema.format as *mut c_char)); - if !schema.name.is_null() { - drop(CString::from_raw(schema.name as *mut c_char)); - } - if !schema.private_data.is_null() { - let private_data = Box::from_raw(schema.private_data as *mut SchemaPrivateData); - for child in private_data.children.iter() { - drop(Box::from_raw(*child)) - } - if !private_data.dictionary.is_null() { - drop(Box::from_raw(private_data.dictionary)); - } - - drop(private_data); - } - - schema.release = None; -} - -impl FFI_ArrowSchema { - /// create a new [`FFI_ArrowSchema`]. This fails if the fields' - /// [`DataType`] is not supported. - pub fn try_new( - format: &str, - children: Vec, - dictionary: Option, - ) -> Result { - let mut this = Self::empty(); - - let children_ptr = children - .into_iter() - .map(Box::new) - .map(Box::into_raw) - .collect::>(); - - this.format = CString::new(format).unwrap().into_raw(); - this.release = Some(release_schema); - this.n_children = children_ptr.len() as i64; - - let dictionary_ptr = dictionary - .map(|d| Box::into_raw(Box::new(d))) - .unwrap_or(std::ptr::null_mut()); - - let mut private_data = Box::new(SchemaPrivateData { - children: children_ptr, - dictionary: dictionary_ptr, - }); - - // intentionally set from private_data (see https://github.com/apache/arrow-rs/issues/580) - this.children = private_data.children.as_mut_ptr(); - - this.dictionary = dictionary_ptr; - - this.private_data = Box::into_raw(private_data) as *mut c_void; - - Ok(this) - } - - pub fn with_name(mut self, name: &str) -> Result { - self.name = CString::new(name).unwrap().into_raw(); - Ok(self) - } - - pub fn with_flags(mut self, flags: Flags) -> Result { - self.flags = flags.bits(); - Ok(self) - } - - pub fn empty() -> Self { - Self { - format: std::ptr::null_mut(), - name: std::ptr::null_mut(), - metadata: std::ptr::null_mut(), - flags: 0, - n_children: 0, - children: ptr::null_mut(), - dictionary: std::ptr::null_mut(), - release: None, - private_data: std::ptr::null_mut(), - } - } - - /// returns the format of this schema. - pub fn format(&self) -> &str { - assert!(!self.format.is_null()); - // safe because the lifetime of `self.format` equals `self` - unsafe { CStr::from_ptr(self.format) } - .to_str() - .expect("The external API has a non-utf8 as format") - } - - /// returns the name of this schema. - pub fn name(&self) -> &str { - assert!(!self.name.is_null()); - // safe because the lifetime of `self.name` equals `self` - unsafe { CStr::from_ptr(self.name) } - .to_str() - .expect("The external API has a non-utf8 as name") - } - - pub fn flags(&self) -> Option { - Flags::from_bits(self.flags) - } - - pub fn child(&self, index: usize) -> &Self { - assert!(index < self.n_children as usize); - unsafe { self.children.add(index).as_ref().unwrap().as_ref().unwrap() } - } - - pub fn children(&self) -> impl Iterator { - (0..self.n_children as usize).map(move |i| self.child(i)) - } - - pub fn nullable(&self) -> bool { - (self.flags / 2) & 1 == 1 - } - - pub fn dictionary(&self) -> Option<&Self> { - unsafe { self.dictionary.as_ref() } - } - - pub fn map_keys_sorted(&self) -> bool { - self.flags & 0b00000100 != 0 - } - - pub fn dictionary_ordered(&self) -> bool { - self.flags & 0b00000001 != 0 - } -} - -impl Drop for FFI_ArrowSchema { - fn drop(&mut self) { - match self.release { - None => (), - Some(release) => unsafe { release(self) }, - }; - } -} +pub use arrow_data::ffi::FFI_ArrowArray; +pub use arrow_schema::ffi::{FFI_ArrowSchema, Flags}; // returns the number of bits that buffer `i` (in the C data interface) is expected to have. // This is set by the Arrow specification -#[allow(clippy::manual_bits)] fn bit_width(data_type: &DataType, i: usize) -> Result { if let Some(primitive) = data_type.primitive_width() { return match i { @@ -332,7 +139,7 @@ fn bit_width(data_type: &DataType, i: usize) -> Result { "The datatype \"{data_type:?}\" expects 2 buffers, but requested {i}. Please verify that the C data interface is correctly implemented." ))) } - (DataType::FixedSizeBinary(num_bytes), 1) => size_of::() * (*num_bytes as usize) * 8, + (DataType::FixedSizeBinary(num_bytes), 1) => *num_bytes as usize * u8::BITS as usize, (DataType::FixedSizeList(f, num_elems), 1) => { let child_bit_width = bit_width(f.data_type(), 1)?; child_bit_width * (*num_elems as usize) @@ -345,8 +152,8 @@ fn bit_width(data_type: &DataType, i: usize) -> Result { // Variable-size list and map have one i32 buffer. // Variable-sized binaries: have two buffers. // "small": first buffer is i32, second is in bytes - (DataType::Utf8, 1) | (DataType::Binary, 1) | (DataType::List(_), 1) | (DataType::Map(_, _), 1) => size_of::() * 8, - (DataType::Utf8, 2) | (DataType::Binary, 2) => size_of::() * 8, + (DataType::Utf8, 1) | (DataType::Binary, 1) | (DataType::List(_), 1) | (DataType::Map(_, _), 1) => i32::BITS as _, + (DataType::Utf8, 2) | (DataType::Binary, 2) => u8::BITS as _, (DataType::List(_), _) | (DataType::Map(_, _), _) => { return Err(ArrowError::CDataInterface(format!( "The datatype \"{data_type:?}\" expects 2 buffers, but requested {i}. Please verify that the C data interface is correctly implemented." @@ -359,17 +166,17 @@ fn bit_width(data_type: &DataType, i: usize) -> Result { } // Variable-sized binaries: have two buffers. // LargeUtf8: first buffer is i64, second is in bytes - (DataType::LargeUtf8, 1) | (DataType::LargeBinary, 1) | (DataType::LargeList(_), 1) => size_of::() * 8, - (DataType::LargeUtf8, 2) | (DataType::LargeBinary, 2) | (DataType::LargeList(_), 2)=> size_of::() * 8, + (DataType::LargeUtf8, 1) | (DataType::LargeBinary, 1) | (DataType::LargeList(_), 1) => i64::BITS as _, + (DataType::LargeUtf8, 2) | (DataType::LargeBinary, 2) | (DataType::LargeList(_), 2)=> u8::BITS as _, (DataType::LargeUtf8, _) | (DataType::LargeBinary, _) | (DataType::LargeList(_), _)=> { return Err(ArrowError::CDataInterface(format!( "The datatype \"{data_type:?}\" expects 3 buffers, but requested {i}. Please verify that the C data interface is correctly implemented." ))) } // type ids. UnionArray doesn't have null bitmap so buffer index begins with 0. - (DataType::Union(_, _, _), 0) => size_of::() * 8, + (DataType::Union(_, _, _), 0) => i8::BITS as _, // Only DenseUnion has 2nd buffer - (DataType::Union(_, _, UnionMode::Dense), 1) => size_of::() * 8, + (DataType::Union(_, _, UnionMode::Dense), 1) => i32::BITS as _, (DataType::Union(_, _, UnionMode::Sparse), _) => { return Err(ArrowError::CDataInterface(format!( "The datatype \"{data_type:?}\" expects 1 buffer, but requested {i}. Please verify that the C data interface is correctly implemented." @@ -395,190 +202,6 @@ fn bit_width(data_type: &DataType, i: usize) -> Result { }) } -/// ABI-compatible struct for ArrowArray from C Data Interface -/// See -/// -/// ``` -/// # use arrow::ffi::FFI_ArrowArray; -/// # use arrow_array::Array; -/// fn export_array(array: &dyn Array) -> FFI_ArrowArray { -/// FFI_ArrowArray::new(array.data()) -/// } -/// ``` -#[repr(C)] -#[derive(Debug)] -pub struct FFI_ArrowArray { - pub(crate) length: i64, - pub(crate) null_count: i64, - pub(crate) offset: i64, - pub(crate) n_buffers: i64, - pub(crate) n_children: i64, - pub(crate) buffers: *mut *const c_void, - pub(crate) children: *mut *mut FFI_ArrowArray, - pub(crate) dictionary: *mut FFI_ArrowArray, - pub(crate) release: Option, - // When exported, this MUST contain everything that is owned by this array. - // for example, any buffer pointed to in `buffers` must be here, as well - // as the `buffers` pointer itself. - // In other words, everything in [FFI_ArrowArray] must be owned by - // `private_data` and can assume that they do not outlive `private_data`. - pub(crate) private_data: *mut c_void, -} - -impl Drop for FFI_ArrowArray { - fn drop(&mut self) { - match self.release { - None => (), - Some(release) => unsafe { release(self) }, - }; - } -} - -unsafe impl Send for FFI_ArrowArray {} -unsafe impl Sync for FFI_ArrowArray {} - -// callback used to drop [FFI_ArrowArray] when it is exported -unsafe extern "C" fn release_array(array: *mut FFI_ArrowArray) { - if array.is_null() { - return; - } - let array = &mut *array; - - // take ownership of `private_data`, therefore dropping it` - let private = Box::from_raw(array.private_data as *mut ArrayPrivateData); - for child in private.children.iter() { - let _ = Box::from_raw(*child); - } - if !private.dictionary.is_null() { - let _ = Box::from_raw(private.dictionary); - } - - array.release = None; -} - -struct ArrayPrivateData { - #[allow(dead_code)] - buffers: Vec>, - buffers_ptr: Box<[*const c_void]>, - children: Box<[*mut FFI_ArrowArray]>, - dictionary: *mut FFI_ArrowArray, -} - -impl FFI_ArrowArray { - /// creates a new `FFI_ArrowArray` from existing data. - /// # Memory Leaks - /// This method releases `buffers`. Consumers of this struct *must* call `release` before - /// releasing this struct, or contents in `buffers` leak. - pub fn new(data: &ArrayData) -> Self { - let data_layout = layout(data.data_type()); - - let buffers = if data_layout.can_contain_null_mask { - // * insert the null buffer at the start - // * make all others `Option`. - iter::once(data.null_buffer().cloned()) - .chain(data.buffers().iter().map(|b| Some(b.clone()))) - .collect::>() - } else { - data.buffers().iter().map(|b| Some(b.clone())).collect() - }; - - // `n_buffers` is the number of buffers by the spec. - let n_buffers = { - data_layout.buffers.len() + { - // If the layout has a null buffer by Arrow spec. - // Note that even the array doesn't have a null buffer because it has - // no null value, we still need to count 1 here to follow the spec. - usize::from(data_layout.can_contain_null_mask) - } - } as i64; - - let buffers_ptr = buffers - .iter() - .flat_map(|maybe_buffer| match maybe_buffer { - // note that `raw_data` takes into account the buffer's offset - Some(b) => Some(b.as_ptr() as *const c_void), - // This is for null buffer. We only put a null pointer for - // null buffer if by spec it can contain null mask. - None if data_layout.can_contain_null_mask => Some(std::ptr::null()), - None => None, - }) - .collect::>(); - - let empty = vec![]; - let (child_data, dictionary) = match data.data_type() { - DataType::Dictionary(_, _) => ( - empty.as_slice(), - Box::into_raw(Box::new(FFI_ArrowArray::new(&data.child_data()[0]))), - ), - _ => (data.child_data(), std::ptr::null_mut()), - }; - - let children = child_data - .iter() - .map(|child| Box::into_raw(Box::new(FFI_ArrowArray::new(child)))) - .collect::>(); - let n_children = children.len() as i64; - - // create the private data owning everything. - // any other data must be added here, e.g. via a struct, to track lifetime. - let mut private_data = Box::new(ArrayPrivateData { - buffers, - buffers_ptr, - children, - dictionary, - }); - - Self { - length: data.len() as i64, - null_count: data.null_count() as i64, - offset: data.offset() as i64, - n_buffers, - n_children, - buffers: private_data.buffers_ptr.as_mut_ptr(), - children: private_data.children.as_mut_ptr(), - dictionary, - release: Some(release_array), - private_data: Box::into_raw(private_data) as *mut c_void, - } - } - - /// create an empty `FFI_ArrowArray`, which can be used to import data into - pub fn empty() -> Self { - Self { - length: 0, - null_count: 0, - offset: 0, - n_buffers: 0, - n_children: 0, - buffers: std::ptr::null_mut(), - children: std::ptr::null_mut(), - dictionary: std::ptr::null_mut(), - release: None, - private_data: std::ptr::null_mut(), - } - } - - /// the length of the array - pub fn len(&self) -> usize { - self.length as usize - } - - /// whether the array is empty - pub fn is_empty(&self) -> bool { - self.length == 0 - } - - /// the offset of the array - pub fn offset(&self) -> usize { - self.offset as usize - } - - /// the null count of the array - pub fn null_count(&self) -> usize { - self.null_count as usize - } -} - /// returns a new buffer corresponding to the index `i` of the FFI array. It may not exist (null pointer). /// `bits` is the number of bits that the native type of this buffer has. /// The size of the buffer will be `ceil(self.length * bits, 8)`. @@ -592,38 +215,13 @@ unsafe fn create_buffer( index: usize, len: usize, ) -> Option { - if array.buffers.is_null() || array.n_buffers == 0 { + if array.num_buffers() == 0 { return None; } - let buffers = array.buffers as *mut *const u8; - - assert!(index < array.n_buffers as usize); - let ptr = *buffers.add(index); - - NonNull::new(ptr as *mut u8) + NonNull::new(array.buffer(index) as _) .map(|ptr| Buffer::from_custom_allocation(ptr, len, owner)) } -fn create_child( - owner: Arc, - array: &FFI_ArrowArray, - schema: &FFI_ArrowSchema, - index: usize, -) -> ArrowArrayChild<'static> { - assert!(index < array.n_children as usize); - assert!(!array.children.is_null()); - assert!(!array.children.is_null()); - unsafe { - let arr_ptr = *array.children.add(index); - let schema_ptr = *schema.children.add(index); - assert!(!arr_ptr.is_null()); - assert!(!schema_ptr.is_null()); - let arr_ptr = &*arr_ptr; - let schema_ptr = &*schema_ptr; - ArrowArrayChild::from_raw(arr_ptr, schema_ptr, owner) - } -} - pub trait ArrowArrayRef { fn to_data(&self) -> Result { let data_type = self.data_type()?; @@ -640,7 +238,7 @@ pub trait ArrowArrayRef { None }; - let mut child_data: Vec = (0..self.array().n_children as usize) + let mut child_data: Vec = (0..self.array().num_children()) .map(|i| { let child = self.child(i); child.to_data() @@ -673,11 +271,9 @@ pub trait ArrowArrayRef { /// in the spec of the type) fn buffers(&self, can_contain_null_mask: bool) -> Result> { // + 1: skip null buffer - let buffer_begin = can_contain_null_mask as i64; - (buffer_begin..self.array().n_buffers) + let buffer_begin = can_contain_null_mask as usize; + (buffer_begin..self.array().num_buffers()) .map(|index| { - let index = index as usize; - let len = self.buffer_len(index)?; match unsafe { @@ -711,7 +307,7 @@ pub trait ArrowArrayRef { // `ffi::ArrowArray` records array offset, we need to add it back to the // buffer length to get the actual buffer length. - let length = self.array().length as usize + self.array().offset as usize; + let length = self.array().len() + self.array().offset(); // Inner type is not important for buffer length. Ok(match (&data_type, i) { @@ -733,9 +329,7 @@ pub trait ArrowArrayRef { // first buffer is the null buffer => add(1) // we assume that pointer is aligned for `i32`, as Utf8 uses `i32` offsets. #[allow(clippy::cast_ptr_alignment)] - let offset_buffer = unsafe { - *(self.array().buffers as *mut *const u8).add(1) as *const i32 - }; + let offset_buffer = self.array().buffer(1) as *const i32; // get last offset (unsafe { *offset_buffer.add(len / size_of::() - 1) }) as usize } @@ -745,9 +339,7 @@ pub trait ArrowArrayRef { // first buffer is the null buffer => add(1) // we assume that pointer is aligned for `i64`, as Large uses `i64` offsets. #[allow(clippy::cast_ptr_alignment)] - let offset_buffer = unsafe { - *(self.array().buffers as *mut *const u8).add(1) as *const i64 - }; + let offset_buffer = self.array().buffer(1) as *const i64; // get last offset (unsafe { *offset_buffer.add(len / size_of::() - 1) }) as usize } @@ -766,14 +358,18 @@ pub trait ArrowArrayRef { // similar to `self.buffer_len(0)`, but without `Result`. // `ffi::ArrowArray` records array offset, we need to add it back to the // buffer length to get the actual buffer length. - let length = self.array().length as usize + self.array().offset as usize; + let length = self.array().len() + self.array().offset(); let buffer_len = bit_util::ceil(length, 8); unsafe { create_buffer(self.owner().clone(), self.array(), 0, buffer_len) } } fn child(&self, index: usize) -> ArrowArrayChild { - create_child(self.owner().clone(), self.array(), self.schema(), index) + ArrowArrayChild { + array: self.array().child(index), + schema: self.schema().child(index), + owner: self.owner(), + } } fn owner(&self) -> &Arc; @@ -781,18 +377,14 @@ pub trait ArrowArrayRef { fn schema(&self) -> &FFI_ArrowSchema; fn data_type(&self) -> Result; fn dictionary(&self) -> Option { - unsafe { - assert!(!(self.array().dictionary.is_null() ^ self.schema().dictionary.is_null()), - "Dictionary should both be set or not set in FFI_ArrowArray and FFI_ArrowSchema"); - if !self.array().dictionary.is_null() { - Some(ArrowArrayChild::from_raw( - &*self.array().dictionary, - &*self.schema().dictionary, - self.owner().clone(), - )) - } else { - None - } + match (self.array().dictionary(), self.schema().dictionary()) { + (Some(array), Some(schema)) => Some(ArrowArrayChild { + array, + schema, + owner: self.owner(), + }), + (None, None) => None, + _ => panic!("Dictionary should both be set or not set in FFI_ArrowArray and FFI_ArrowSchema") } } } @@ -827,7 +419,7 @@ pub struct ArrowArray { pub struct ArrowArrayChild<'a> { array: &'a FFI_ArrowArray, schema: &'a FFI_ArrowSchema, - owner: Arc, + owner: &'a Arc, } impl ArrowArrayRef for ArrowArray { @@ -864,7 +456,7 @@ impl<'a> ArrowArrayRef for ArrowArrayChild<'a> { } fn owner(&self) -> &Arc { - &self.owner + self.owner } } @@ -936,20 +528,6 @@ impl ArrowArray { } } -impl<'a> ArrowArrayChild<'a> { - fn from_raw( - array: &'a FFI_ArrowArray, - schema: &'a FFI_ArrowSchema, - owner: Arc, - ) -> Self { - Self { - array, - schema, - owner, - } - } -} - #[cfg(test)] mod tests { use super::*; @@ -957,7 +535,7 @@ mod tests { export_array_into_raw, make_array, Array, ArrayData, BooleanArray, Decimal128Array, DictionaryArray, DurationSecondArray, FixedSizeBinaryArray, FixedSizeListArray, GenericBinaryArray, GenericListArray, GenericStringArray, - Int32Array, MapArray, NullArray, OffsetSizeTrait, Time32MillisecondArray, + Int32Array, MapArray, OffsetSizeTrait, Time32MillisecondArray, TimestampMillisecondArray, UInt32Array, }; use crate::compute::kernels; @@ -1004,8 +582,9 @@ mod tests { // We can read them back to memory // SAFETY: // Pointers are aligned and valid - let array = - unsafe { ArrowArray::new(ptr::read(array_ptr), ptr::read(schema_ptr)) }; + let array = unsafe { + ArrowArray::new(std::ptr::read(array_ptr), std::ptr::read(schema_ptr)) + }; let array = Int32Array::from(ArrayData::try_from(array).unwrap()); assert_eq!(array, Int32Array::from(vec![1, 2, 3])); @@ -1526,24 +1105,6 @@ mod tests { Ok(()) } - #[test] - fn null_array_n_buffers() -> Result<()> { - let array = NullArray::new(10); - let data = array.data(); - - let ffi_array = FFI_ArrowArray::new(data); - assert_eq!(0, ffi_array.n_buffers); - - let private_data = - unsafe { Box::from_raw(ffi_array.private_data as *mut ArrayPrivateData) }; - - assert_eq!(0, private_data.buffers_ptr.len()); - - Box::into_raw(private_data); - - Ok(()) - } - #[test] fn test_map_array() -> Result<()> { let keys = vec!["a", "b", "c", "d", "e", "f", "g", "h"]; diff --git a/arrow/src/ffi_stream.rs b/arrow/src/ffi_stream.rs index 4313eaaaf34f..b1046d142f32 100644 --- a/arrow/src/ffi_stream.rs +++ b/arrow/src/ffi_stream.rs @@ -60,6 +60,7 @@ //! } //! ``` +use std::ptr::addr_of; use std::{ convert::TryFrom, ffi::CString, @@ -203,11 +204,11 @@ impl ExportedArrayStream { let schema = FFI_ArrowSchema::try_from(reader.schema().as_ref()); match schema { - Ok(mut schema) => unsafe { - std::ptr::copy(&schema as *const FFI_ArrowSchema, out, 1); - schema.release = None; + Ok(schema) => { + unsafe { std::ptr::copy(addr_of!(schema), out, 1) }; + std::mem::forget(schema); 0 - }, + } Err(ref err) => { private_data.last_error = err.to_string(); get_error_code(err) @@ -222,21 +223,17 @@ impl ExportedArrayStream { let ret_code = match reader.next() { None => { // Marks ArrowArray released to indicate reaching the end of stream. - unsafe { - (*out).release = None; - } + unsafe { std::ptr::write(out, FFI_ArrowArray::empty()) } 0 } Some(next_batch) => { if let Ok(batch) = next_batch { let struct_array = StructArray::from(batch); - let mut array = FFI_ArrowArray::new(struct_array.data()); + let array = FFI_ArrowArray::new(struct_array.data()); - unsafe { - std::ptr::copy(&array as *const FFI_ArrowArray, out, 1); - array.release = None; - 0 - } + unsafe { std::ptr::copy(addr_of!(array), out, 1) }; + std::mem::forget(array); + 0 } else { let err = &next_batch.unwrap_err(); private_data.last_error = err.to_string(); @@ -362,7 +359,9 @@ impl Iterator for ArrowArrayStreamReader { let ffi_array = unsafe { Arc::from_raw(array_ptr) }; // The end of stream has been reached - ffi_array.release?; + if ffi_array.is_released() { + return None; + } let schema_ref = self.schema(); let schema = FFI_ArrowSchema::try_from(schema_ref.as_ref()).ok()?; @@ -482,7 +481,7 @@ mod tests { // The end of stream has been reached let ffi_array = unsafe { Arc::from_raw(array_ptr) }; - if ffi_array.release.is_none() { + if ffi_array.is_released() { break; }