From fbba196853e8f249686faf254010485f056de359 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Mon, 16 Aug 2021 17:51:58 +0200 Subject: [PATCH 1/4] PyO3 bridge for pyarrow interoperability --- arrow-pyarrow-integration-testing/Cargo.toml | 4 +- arrow-pyarrow-integration-testing/src/lib.rs | 206 ++++------------ .../tests/test_sql.py | 26 +-- arrow/Cargo.toml | 7 +- arrow/src/array/array.rs | 9 +- arrow/src/lib.rs | 2 + arrow/src/pyarrow.rs | 221 ++++++++++++++++++ 7 files changed, 294 insertions(+), 181 deletions(-) create mode 100644 arrow/src/pyarrow.rs diff --git a/arrow-pyarrow-integration-testing/Cargo.toml b/arrow-pyarrow-integration-testing/Cargo.toml index 59a084fdf7fc..f1d226dcb140 100644 --- a/arrow-pyarrow-integration-testing/Cargo.toml +++ b/arrow-pyarrow-integration-testing/Cargo.toml @@ -31,8 +31,8 @@ name = "arrow_pyarrow_integration_testing" crate-type = ["cdylib"] [dependencies] -arrow = { path = "../arrow", version = "6.0.0-SNAPSHOT" } -pyo3 = { version = "0.12.1", features = ["extension-module"] } +arrow = { path = "../arrow", version = "6.0.0-SNAPSHOT", features = ["pyarrow"] } +pyo3 = { version = "0.14", features = ["extension-module"] } [package.metadata.maturin] requires-dist = ["pyarrow>=1"] diff --git a/arrow-pyarrow-integration-testing/src/lib.rs b/arrow-pyarrow-integration-testing/src/lib.rs index a601654d0bcd..659cdeccb577 100644 --- a/arrow-pyarrow-integration-testing/src/lib.rs +++ b/arrow-pyarrow-integration-testing/src/lib.rs @@ -18,21 +18,21 @@ //! This library demonstrates a minimal usage of Rust's C data interface to pass //! arrays from and to Python. -use std::convert::TryFrom; use std::error; use std::fmt; use std::sync::Arc; use pyo3::exceptions::PyOSError; +use pyo3::prelude::*; use pyo3::wrap_pyfunction; -use pyo3::{libc::uintptr_t, prelude::*}; +//use libc::uintptr_t; -use arrow::array::{make_array_from_raw, ArrayRef, Int64Array}; +use arrow::array::{ArrayData, ArrayRef, Int64Array}; use arrow::compute::kernels; use arrow::datatypes::{DataType, Field, Schema}; use arrow::error::ArrowError; -use arrow::ffi; -use arrow::ffi::FFI_ArrowSchema; +use arrow::pyarrow::PyArrowConvert; +use arrow::record_batch::RecordBatch; /// an error that bridges ArrowError with a Python error #[derive(Debug)] @@ -71,216 +71,98 @@ impl From for PyErr { } } -#[pyclass] -struct PyDataType { - inner: DataType, -} - -#[pyclass] -struct PyField { - inner: Field, -} - -#[pyclass] -struct PySchema { - inner: Schema, -} - -#[pymethods] -impl PyDataType { - #[staticmethod] - fn from_pyarrow(value: &PyAny) -> PyResult { - let c_schema = FFI_ArrowSchema::empty(); - let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; - value.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?; - let dtype = DataType::try_from(&c_schema).map_err(PyO3ArrowError::from)?; - Ok(Self { inner: dtype }) - } - - fn to_pyarrow(&self, py: Python) -> PyResult { - let c_schema = - FFI_ArrowSchema::try_from(&self.inner).map_err(PyO3ArrowError::from)?; - let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; - let module = py.import("pyarrow")?; - let class = module.getattr("DataType")?; - let dtype = class.call_method1("_import_from_c", (c_schema_ptr as uintptr_t,))?; - Ok(dtype.into()) - } -} - -#[pymethods] -impl PyField { - #[staticmethod] - fn from_pyarrow(value: &PyAny) -> PyResult { - let c_schema = FFI_ArrowSchema::empty(); - let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; - value.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?; - let field = Field::try_from(&c_schema).map_err(PyO3ArrowError::from)?; - Ok(Self { inner: field }) - } - - fn to_pyarrow(&self, py: Python) -> PyResult { - let c_schema = - FFI_ArrowSchema::try_from(&self.inner).map_err(PyO3ArrowError::from)?; - let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; - let module = py.import("pyarrow")?; - let class = module.getattr("Field")?; - let dtype = class.call_method1("_import_from_c", (c_schema_ptr as uintptr_t,))?; - Ok(dtype.into()) - } -} - -#[pymethods] -impl PySchema { - #[staticmethod] - fn from_pyarrow(value: &PyAny) -> PyResult { - let c_schema = FFI_ArrowSchema::empty(); - let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; - value.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?; - let schema = Schema::try_from(&c_schema).map_err(PyO3ArrowError::from)?; - Ok(Self { inner: schema }) - } - - fn to_pyarrow(&self, py: Python) -> PyResult { - let c_schema = - FFI_ArrowSchema::try_from(&self.inner).map_err(PyO3ArrowError::from)?; - let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; - let module = py.import("pyarrow")?; - let class = module.getattr("Schema")?; - let schema = - class.call_method1("_import_from_c", (c_schema_ptr as uintptr_t,))?; - Ok(schema.into()) - } -} - -impl<'source> FromPyObject<'source> for PyDataType { - fn extract(value: &'source PyAny) -> PyResult { - PyDataType::from_pyarrow(value) - } -} - -impl<'source> FromPyObject<'source> for PyField { - fn extract(value: &'source PyAny) -> PyResult { - PyField::from_pyarrow(value) - } -} - -impl<'source> FromPyObject<'source> for PySchema { - fn extract(value: &'source PyAny) -> PyResult { - PySchema::from_pyarrow(value) - } -} - -fn array_to_rust(ob: PyObject, py: Python) -> PyResult { - // prepare a pointer to receive the Array struct - let (array_pointer, schema_pointer) = - ffi::ArrowArray::into_raw(unsafe { ffi::ArrowArray::empty() }); - - // make the conversion through PyArrow's private API - // this changes the pointer's memory and is thus unsafe. In particular, `_export_to_c` can go out of bounds - ob.call_method1( - py, - "_export_to_c", - (array_pointer as uintptr_t, schema_pointer as uintptr_t), - )?; - - let array = unsafe { make_array_from_raw(array_pointer, schema_pointer) } - .map_err(PyO3ArrowError::from)?; - Ok(array) -} - -fn array_to_py(array: ArrayRef, py: Python) -> PyResult { - let (array_pointer, schema_pointer) = array.to_raw().map_err(PyO3ArrowError::from)?; - - let pa = py.import("pyarrow")?; - - let array = pa.getattr("Array")?.call_method1( - "_import_from_c", - (array_pointer as uintptr_t, schema_pointer as uintptr_t), - )?; - Ok(array.to_object(py)) -} - /// Returns `array + array` of an int64 array. #[pyfunction] -fn double(array: PyObject, py: Python) -> PyResult { +fn double(array: &PyAny, py: Python) -> PyResult { // import - let array = array_to_rust(array, py)?; + let array = ArrayRef::from_pyarrow(array)?; // perform some operation let array = array.as_any().downcast_ref::().ok_or_else(|| { PyO3ArrowError::ArrowError(ArrowError::ParseError("Expects an int64".to_string())) })?; let array = kernels::arithmetic::add(&array, &array).map_err(PyO3ArrowError::from)?; - let array = Arc::new(array); // export - array_to_py(array, py) + array.to_pyarrow(py) } /// calls a lambda function that receives and returns an array /// whose result must be the array multiplied by two #[pyfunction] -fn double_py(lambda: PyObject, py: Python) -> PyResult { +fn double_py(lambda: &PyAny, py: Python) -> PyResult { // create let array = Arc::new(Int64Array::from(vec![Some(1), None, Some(3)])); let expected = Arc::new(Int64Array::from(vec![Some(2), None, Some(6)])) as ArrayRef; // to py - let pyarray = array_to_py(array, py)?; - let pyarray = lambda.call1(py, (pyarray,))?; - let array = array_to_rust(pyarray, py)?; + let pyarray = array.to_pyarrow(py)?; + let pyarray = lambda.call1((pyarray,))?; + let array = ArrayRef::from_pyarrow(pyarray)?; Ok(array == expected) } /// Returns the substring #[pyfunction] -fn substring(array: PyObject, start: i64, py: Python) -> PyResult { +fn substring(array: ArrayData, start: i64) -> PyResult { // import - let array = array_to_rust(array, py)?; + let array = ArrayRef::from(array); // substring let array = kernels::substring::substring(array.as_ref(), start, &None) .map_err(PyO3ArrowError::from)?; - // export - array_to_py(array, py) + Ok(array.data().to_owned()) } /// Returns the concatenate #[pyfunction] -fn concatenate(array: PyObject, py: Python) -> PyResult { - // import - let array = array_to_rust(array, py)?; +fn concatenate(array: ArrayData, py: Python) -> PyResult { + let array = ArrayRef::from(array); // concat let array = kernels::concat::concat(&[array.as_ref(), array.as_ref()]) .map_err(PyO3ArrowError::from)?; - // export - array_to_py(array, py) + array.to_pyarrow(py) } -/// Converts to rust and back to python #[pyfunction] -fn round_trip(pyarray: PyObject, py: Python) -> PyResult { - // import - let array = array_to_rust(pyarray, py)?; +fn round_trip_type(obj: DataType) -> PyResult { + Ok(obj) +} - // export - array_to_py(array, py) +#[pyfunction] +fn round_trip_field(obj: Field) -> PyResult { + Ok(obj) +} + +#[pyfunction] +fn round_trip_schema(obj: Schema) -> PyResult { + Ok(obj) +} + +#[pyfunction] +fn round_trip_array(obj: ArrayData) -> PyResult { + Ok(obj) +} + +#[pyfunction] +fn round_trip_record_batch(obj: RecordBatch) -> PyResult { + Ok(obj) } #[pymodule] fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> { - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; m.add_wrapped(wrap_pyfunction!(double))?; m.add_wrapped(wrap_pyfunction!(double_py))?; m.add_wrapped(wrap_pyfunction!(substring))?; m.add_wrapped(wrap_pyfunction!(concatenate))?; - m.add_wrapped(wrap_pyfunction!(round_trip))?; + m.add_wrapped(wrap_pyfunction!(round_trip_type))?; + m.add_wrapped(wrap_pyfunction!(round_trip_field))?; + m.add_wrapped(wrap_pyfunction!(round_trip_schema))?; + m.add_wrapped(wrap_pyfunction!(round_trip_array))?; + m.add_wrapped(wrap_pyfunction!(round_trip_record_batch))?; Ok(()) } diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py index 301eac8d2a09..ddb9d6dd67f5 100644 --- a/arrow-pyarrow-integration-testing/tests/test_sql.py +++ b/arrow-pyarrow-integration-testing/tests/test_sql.py @@ -25,7 +25,6 @@ import pyarrow as pa import pytz -from arrow_pyarrow_integration_testing import PyDataType, PyField, PySchema import arrow_pyarrow_integration_testing as rust @@ -113,8 +112,7 @@ def assert_pyarrow_leak(): @pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str) def test_type_roundtrip(pyarrow_type): - ty = PyDataType.from_pyarrow(pyarrow_type) - restored = ty.to_pyarrow() + restored = rust.round_trip_type(pyarrow_type) assert restored == pyarrow_type assert restored is not pyarrow_type @@ -122,34 +120,34 @@ def test_type_roundtrip(pyarrow_type): @pytest.mark.parametrize("pyarrow_type", _unsupported_pyarrow_types, ids=str) def test_type_roundtrip_raises(pyarrow_type): with pytest.raises(Exception): - PyDataType.from_pyarrow(pyarrow_type) + rust.round_trip_type(pyarrow_type) def test_dictionary_type_roundtrip(): # the dictionary type conversion is incomplete pyarrow_type = pa.dictionary(pa.int32(), pa.string()) - ty = PyDataType.from_pyarrow(pyarrow_type) - assert ty.to_pyarrow() == pa.int32() + ty = rust.round_trip_type(pyarrow_type) + assert ty == pa.int32() @pytest.mark.parametrize('pyarrow_type', _supported_pyarrow_types, ids=str) def test_field_roundtrip(pyarrow_type): pyarrow_field = pa.field("test", pyarrow_type, nullable=True) - field = PyField.from_pyarrow(pyarrow_field) - assert field.to_pyarrow() == pyarrow_field + field = rust.round_trip_field(pyarrow_field) + assert field == pyarrow_field if pyarrow_type != pa.null(): # A null type field may not be non-nullable pyarrow_field = pa.field("test", pyarrow_type, nullable=False) - field = PyField.from_pyarrow(pyarrow_field) - assert field.to_pyarrow() == pyarrow_field + field = rust.round_trip_field(pyarrow_field) + assert field == pyarrow_field def test_schema_roundtrip(): pyarrow_fields = zip(string.ascii_lowercase, _supported_pyarrow_types) pyarrow_schema = pa.schema(pyarrow_fields) - schema = PySchema.from_pyarrow(pyarrow_schema) - assert schema.to_pyarrow() == pyarrow_schema + schema = rust.round_trip_schema(pyarrow_schema) + assert schema == pyarrow_schema def test_primitive_python(): @@ -205,7 +203,7 @@ def test_list_array(): Python -> Rust -> Python """ a = pa.array([[], None, [1, 2], [4, 5, 6]], pa.list_(pa.int64())) - b = rust.round_trip(a) + b = rust.round_trip_array(a) b.validate(full=True) assert a.to_pylist() == b.to_pylist() assert a.type == b.type @@ -261,7 +259,7 @@ def test_decimal_python(): None ] a = pa.array(data, pa.decimal128(6, 2)) - b = rust.round_trip(a) + b = rust.round_trip_array(a) assert a == b del a del b diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml index 0c8ca76b7890..81ba040be38d 100644 --- a/arrow/Cargo.toml +++ b/arrow/Cargo.toml @@ -50,6 +50,8 @@ chrono = "0.4" flatbuffers = { version = "=2.0.0", optional = true } hex = "0.4" prettytable-rs = { version = "0.8.0", optional = true } +pyo3 = { version = "0.14", optional = true } +libc = { version = "0.2", optional = true } lexical-core = "^0.7" multiversion = "0.6.1" bitflags = "1.2.1" @@ -62,14 +64,15 @@ ipc = ["flatbuffers"] simd = ["packed_simd"] prettyprint = ["prettytable-rs"] # The test utils feature enables code used in benchmarks and tests but -# not the core arrow code itself. Be aware that `rand` must be kept as -# an optional dependency for supporting compile to wasm32-unknown-unknown +# not the core arrow code itself. Be aware that `rand` must be kept as +# an optional dependency for supporting compile to wasm32-unknown-unknown # target without assuming an environment containing JavaScript. test_utils = ["rand"] # this is only intended to be used in single-threaded programs: it verifies that # all allocated memory is being released (no memory leaks). # See README for details memory-check = [] +pyarrow = ["pyo3", "libc"] [dev-dependencies] rand = "0.8" diff --git a/arrow/src/array/array.rs b/arrow/src/array/array.rs index 4702179c7839..f4b79c49f3aa 100644 --- a/arrow/src/array/array.rs +++ b/arrow/src/array/array.rs @@ -15,9 +15,10 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; +use std::convert::{From, TryFrom}; use std::fmt; use std::sync::Arc; -use std::{any::Any, convert::TryFrom}; use super::*; use crate::array::equal_json::JsonEqual; @@ -334,6 +335,12 @@ pub fn make_array(data: ArrayData) -> ArrayRef { } } +impl From for ArrayRef { + fn from(data: ArrayData) -> Self { + make_array(data) + } +} + /// Creates a new empty array /// /// ``` diff --git a/arrow/src/lib.rs b/arrow/src/lib.rs index 1932b0d0b6dc..2c2590cb4fc3 100644 --- a/arrow/src/lib.rs +++ b/arrow/src/lib.rs @@ -156,6 +156,8 @@ pub mod ffi; #[cfg(feature = "ipc")] pub mod ipc; pub mod json; +#[cfg(feature = "pyarrow")] +pub mod pyarrow; pub mod record_batch; pub mod temporal_conversions; pub mod tensor; diff --git a/arrow/src/pyarrow.rs b/arrow/src/pyarrow.rs new file mode 100644 index 000000000000..3a188a345482 --- /dev/null +++ b/arrow/src/pyarrow.rs @@ -0,0 +1,221 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This library demonstrates a minimal usage of Rust's C data interface to pass +//! arrays from and to Python. + +use std::convert::{From, TryFrom}; +use std::sync::Arc; + +use libc::uintptr_t; +use pyo3::exceptions::PyException; +use pyo3::prelude::*; +use pyo3::types::PyList; + +use crate::array::{make_array, Array, ArrayData, ArrayRef}; +use crate::datatypes::{DataType, Field, Schema}; +use crate::error::ArrowError; +use crate::ffi; +use crate::ffi::FFI_ArrowSchema; +use crate::record_batch::RecordBatch; + +impl From for PyErr { + fn from(err: ArrowError) -> PyErr { + PyException::new_err(err.to_string()) + } +} + +pub trait PyArrowConvert: Sized { + fn from_pyarrow(value: &PyAny) -> PyResult; + fn to_pyarrow(&self, py: Python) -> PyResult; +} + +impl PyArrowConvert for DataType { + fn from_pyarrow(value: &PyAny) -> PyResult { + let c_schema = FFI_ArrowSchema::empty(); + let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; + value.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?; + let dtype = DataType::try_from(&c_schema)?; + Ok(dtype) + } + + fn to_pyarrow(&self, py: Python) -> PyResult { + let c_schema = FFI_ArrowSchema::try_from(self)?; + let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; + let module = py.import("pyarrow")?; + let class = module.getattr("DataType")?; + let dtype = class.call_method1("_import_from_c", (c_schema_ptr as uintptr_t,))?; + Ok(dtype.into()) + } +} + +impl PyArrowConvert for Field { + fn from_pyarrow(value: &PyAny) -> PyResult { + let c_schema = FFI_ArrowSchema::empty(); + let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; + value.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?; + let field = Field::try_from(&c_schema)?; + Ok(field) + } + + fn to_pyarrow(&self, py: Python) -> PyResult { + let c_schema = FFI_ArrowSchema::try_from(self)?; + let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; + let module = py.import("pyarrow")?; + let class = module.getattr("Field")?; + let dtype = class.call_method1("_import_from_c", (c_schema_ptr as uintptr_t,))?; + Ok(dtype.into()) + } +} + +impl PyArrowConvert for Schema { + fn from_pyarrow(value: &PyAny) -> PyResult { + let c_schema = FFI_ArrowSchema::empty(); + let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; + value.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?; + let schema = Schema::try_from(&c_schema)?; + Ok(schema) + } + + fn to_pyarrow(&self, py: Python) -> PyResult { + let c_schema = FFI_ArrowSchema::try_from(self)?; + let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; + let module = py.import("pyarrow")?; + let class = module.getattr("Schema")?; + let schema = + class.call_method1("_import_from_c", (c_schema_ptr as uintptr_t,))?; + Ok(schema.into()) + } +} + +impl PyArrowConvert for ArrayData { + fn from_pyarrow(value: &PyAny) -> PyResult { + // prepare a pointer to receive the Array struct + let (array_pointer, schema_pointer) = + ffi::ArrowArray::into_raw(unsafe { ffi::ArrowArray::empty() }); + + // make the conversion through PyArrow's private API + // this changes the pointer's memory and is thus unsafe. + // In particular, `_export_to_c` can go out of bounds + value.call_method1( + "_export_to_c", + (array_pointer as uintptr_t, schema_pointer as uintptr_t), + )?; + + let ffi_array = unsafe { + ffi::ArrowArray::try_from_raw(array_pointer, schema_pointer)? + }; + let data = ArrayData::try_from(ffi_array)?; + + Ok(data) + } + + fn to_pyarrow(&self, py: Python) -> PyResult { + let array = ffi::ArrowArray::try_from(self.clone())?; + let (array_pointer, schema_pointer) = ffi::ArrowArray::into_raw(array); + + let module = py.import("pyarrow")?; + let class = module.getattr("Array")?; + let array = class.call_method1( + "_import_from_c", + (array_pointer as uintptr_t, schema_pointer as uintptr_t), + )?; + Ok(array.to_object(py)) + } +} + +impl PyArrowConvert for ArrayRef { + fn from_pyarrow(value: &PyAny) -> PyResult { + Ok(make_array(ArrayData::from_pyarrow(value)?)) + } + + fn to_pyarrow(&self, py: Python) -> PyResult { + self.data().to_pyarrow(py) + } +} + +impl PyArrowConvert for T +where + T: Array + From, +{ + fn from_pyarrow(value: &PyAny) -> PyResult { + Ok(ArrayData::from_pyarrow(value)?.into()) + } + + fn to_pyarrow(&self, py: Python) -> PyResult { + self.data().to_pyarrow(py) + } +} + +impl PyArrowConvert for RecordBatch { + fn from_pyarrow(value: &PyAny) -> PyResult { + // TODO(kszucs): implement the FFI conversions in arrow-rs for RecordBatches + let schema = value.getattr("schema")?; + let schema = Arc::new(Schema::from_pyarrow(schema)?); + + let arrays = value.getattr("columns")?.downcast::()?; + let arrays = arrays + .iter() + .map(ArrayRef::from_pyarrow) + .collect::>()?; + + let batch = RecordBatch::try_new(schema, arrays)?; + Ok(batch) + } + + fn to_pyarrow(&self, py: Python) -> PyResult { + let mut py_arrays = vec![]; + let mut py_names = vec![]; + + let schema = self.schema(); + let fields = schema.fields().iter(); + let columns = self.columns().iter(); + + for (array, field) in columns.zip(fields) { + py_arrays.push(array.to_pyarrow(py)?); + py_names.push(field.name()); + } + + let module = py.import("pyarrow")?; + let class = module.getattr("RecordBatch")?; + let record = class.call_method1("from_arrays", (py_arrays, py_names))?; + + Ok(PyObject::from(record)) + } +} + +macro_rules! add_conversion { + ($typ:ty) => { + impl<'source> FromPyObject<'source> for $typ { + fn extract(value: &'source PyAny) -> PyResult { + Self::from_pyarrow(value) + } + } + + impl<'a> IntoPy for $typ { + fn into_py(self, py: Python) -> PyObject { + self.to_pyarrow(py).unwrap() + } + } + }; +} + +add_conversion!(DataType); +add_conversion!(Field); +add_conversion!(Schema); +add_conversion!(ArrayData); +add_conversion!(RecordBatch); From 8ecc65cf889cec854cec4a1a9f8cf021addaa12d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Mon, 16 Aug 2021 18:01:37 +0200 Subject: [PATCH 2/4] Fix clippy warnings --- arrow-pyarrow-integration-testing/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrow-pyarrow-integration-testing/src/lib.rs b/arrow-pyarrow-integration-testing/src/lib.rs index 659cdeccb577..2b71111af424 100644 --- a/arrow-pyarrow-integration-testing/src/lib.rs +++ b/arrow-pyarrow-integration-testing/src/lib.rs @@ -81,7 +81,7 @@ fn double(array: &PyAny, py: Python) -> PyResult { let array = array.as_any().downcast_ref::().ok_or_else(|| { PyO3ArrowError::ArrowError(ArrowError::ParseError("Expects an int64".to_string())) })?; - let array = kernels::arithmetic::add(&array, &array).map_err(PyO3ArrowError::from)?; + let array = kernels::arithmetic::add(array, array).map_err(PyO3ArrowError::from)?; // export array.to_pyarrow(py) From ca685206eaab3b0133364341be247d8735564d9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Mon, 16 Aug 2021 21:53:00 +0200 Subject: [PATCH 3/4] Simplify error handling --- arrow-pyarrow-integration-testing/src/lib.rs | 56 +++---------------- .../tests/test_sql.py | 2 +- arrow/src/pyarrow.rs | 12 ++-- 3 files changed, 16 insertions(+), 54 deletions(-) diff --git a/arrow-pyarrow-integration-testing/src/lib.rs b/arrow-pyarrow-integration-testing/src/lib.rs index 2b71111af424..2296462cd14a 100644 --- a/arrow-pyarrow-integration-testing/src/lib.rs +++ b/arrow-pyarrow-integration-testing/src/lib.rs @@ -19,10 +19,8 @@ //! arrays from and to Python. use std::error; -use std::fmt; use std::sync::Arc; -use pyo3::exceptions::PyOSError; use pyo3::prelude::*; use pyo3::wrap_pyfunction; //use libc::uintptr_t; @@ -34,43 +32,6 @@ use arrow::error::ArrowError; use arrow::pyarrow::PyArrowConvert; use arrow::record_batch::RecordBatch; -/// an error that bridges ArrowError with a Python error -#[derive(Debug)] -enum PyO3ArrowError { - ArrowError(ArrowError), -} - -impl fmt::Display for PyO3ArrowError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - PyO3ArrowError::ArrowError(ref e) => e.fmt(f), - } - } -} - -impl error::Error for PyO3ArrowError { - fn source(&self) -> Option<&(dyn error::Error + 'static)> { - match *self { - // The cause is the underlying implementation error type. Is implicitly - // cast to the trait object `&error::Error`. This works because the - // underlying type already implements the `Error` trait. - PyO3ArrowError::ArrowError(ref e) => Some(e), - } - } -} - -impl From for PyO3ArrowError { - fn from(err: ArrowError) -> PyO3ArrowError { - PyO3ArrowError::ArrowError(err) - } -} - -impl From for PyErr { - fn from(err: PyO3ArrowError) -> PyErr { - PyOSError::new_err(err.to_string()) - } -} - /// Returns `array + array` of an int64 array. #[pyfunction] fn double(array: &PyAny, py: Python) -> PyResult { @@ -78,10 +39,11 @@ fn double(array: &PyAny, py: Python) -> PyResult { let array = ArrayRef::from_pyarrow(array)?; // perform some operation - let array = array.as_any().downcast_ref::().ok_or_else(|| { - PyO3ArrowError::ArrowError(ArrowError::ParseError("Expects an int64".to_string())) - })?; - let array = kernels::arithmetic::add(array, array).map_err(PyO3ArrowError::from)?; + let array = array + .as_any() + .downcast_ref::() + .ok_or(ArrowError::ParseError("Expects an int64".to_string()))?; + let array = kernels::arithmetic::add(array, array)?; // export array.to_pyarrow(py) @@ -110,8 +72,7 @@ fn substring(array: ArrayData, start: i64) -> PyResult { let array = ArrayRef::from(array); // substring - let array = kernels::substring::substring(array.as_ref(), start, &None) - .map_err(PyO3ArrowError::from)?; + let array = kernels::substring::substring(array.as_ref(), start, &None)?; Ok(array.data().to_owned()) } @@ -122,8 +83,7 @@ fn concatenate(array: ArrayData, py: Python) -> PyResult { let array = ArrayRef::from(array); // concat - let array = kernels::concat::concat(&[array.as_ref(), array.as_ref()]) - .map_err(PyO3ArrowError::from)?; + let array = kernels::concat::concat(&[array.as_ref(), array.as_ref()])?; array.to_pyarrow(py) } @@ -154,7 +114,7 @@ fn round_trip_record_batch(obj: RecordBatch) -> PyResult { } #[pymodule] -fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> { +fn arrow_pyarrow_integration_testing(py: Python, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(double))?; m.add_wrapped(wrap_pyfunction!(double_py))?; m.add_wrapped(wrap_pyfunction!(substring))?; diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py index ddb9d6dd67f5..bacd1188ce4f 100644 --- a/arrow-pyarrow-integration-testing/tests/test_sql.py +++ b/arrow-pyarrow-integration-testing/tests/test_sql.py @@ -119,7 +119,7 @@ def test_type_roundtrip(pyarrow_type): @pytest.mark.parametrize("pyarrow_type", _unsupported_pyarrow_types, ids=str) def test_type_roundtrip_raises(pyarrow_type): - with pytest.raises(Exception): + with pytest.raises(pa.ArrowException): rust.round_trip_type(pyarrow_type) diff --git a/arrow/src/pyarrow.rs b/arrow/src/pyarrow.rs index 3a188a345482..90b60248a472 100644 --- a/arrow/src/pyarrow.rs +++ b/arrow/src/pyarrow.rs @@ -22,7 +22,7 @@ use std::convert::{From, TryFrom}; use std::sync::Arc; use libc::uintptr_t; -use pyo3::exceptions::PyException; +use pyo3::import_exception; use pyo3::prelude::*; use pyo3::types::PyList; @@ -33,9 +33,12 @@ use crate::ffi; use crate::ffi::FFI_ArrowSchema; use crate::record_batch::RecordBatch; +import_exception!(pyarrow, ArrowException); +pub type PyArrowException = ArrowException; + impl From for PyErr { fn from(err: ArrowError) -> PyErr { - PyException::new_err(err.to_string()) + PyArrowException::new_err(err.to_string()) } } @@ -116,9 +119,8 @@ impl PyArrowConvert for ArrayData { (array_pointer as uintptr_t, schema_pointer as uintptr_t), )?; - let ffi_array = unsafe { - ffi::ArrowArray::try_from_raw(array_pointer, schema_pointer)? - }; + let ffi_array = + unsafe { ffi::ArrowArray::try_from_raw(array_pointer, schema_pointer)? }; let data = ArrayData::try_from(ffi_array)?; Ok(data) From a656da29f39bc03f32556f04e543f08edc411d88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Mon, 16 Aug 2021 22:11:40 +0200 Subject: [PATCH 4/4] Fix clippy warnings --- arrow-pyarrow-integration-testing/src/lib.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/arrow-pyarrow-integration-testing/src/lib.rs b/arrow-pyarrow-integration-testing/src/lib.rs index 2296462cd14a..082a72e9e1ff 100644 --- a/arrow-pyarrow-integration-testing/src/lib.rs +++ b/arrow-pyarrow-integration-testing/src/lib.rs @@ -18,12 +18,10 @@ //! This library demonstrates a minimal usage of Rust's C data interface to pass //! arrays from and to Python. -use std::error; use std::sync::Arc; use pyo3::prelude::*; use pyo3::wrap_pyfunction; -//use libc::uintptr_t; use arrow::array::{ArrayData, ArrayRef, Int64Array}; use arrow::compute::kernels; @@ -114,7 +112,7 @@ fn round_trip_record_batch(obj: RecordBatch) -> PyResult { } #[pymodule] -fn arrow_pyarrow_integration_testing(py: Python, m: &PyModule) -> PyResult<()> { +fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(double))?; m.add_wrapped(wrap_pyfunction!(double_py))?; m.add_wrapped(wrap_pyfunction!(substring))?;