Skip to content

Commit

Permalink
f
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Aug 13, 2021
1 parent 90318df commit 5d3eb09
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 100 deletions.
134 changes: 55 additions & 79 deletions arrow-pyarrow-integration-testing/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
//! This library demonstrates a minimal usage of Rust's C data interface to pass
//! arrays from and to Python.
use std::convert::TryFrom;
use std::error;
use std::fmt;
use std::sync::Arc;
Expand All @@ -28,15 +27,12 @@ use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
//use libc::uintptr_t;

use arrow::array::{Array, ArrayData, ArrayRef, Int32Array, Int64Array};
use arrow::array::{ArrayData, ArrayRef, Int64Array};
use arrow::compute::kernels;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
//use arrow::ffi;
use arrow::ffi::FFI_ArrowSchema;
use arrow::pyarrow::PyArrowConvert;

use pyo3::FromPyObject;
use arrow::record_batch::RecordBatch;

/// an error that bridges ArrowError with a Python error
#[derive(Debug)]
Expand Down Expand Up @@ -75,39 +71,6 @@ impl From<PyO3ArrowError> for PyErr {
}
}

#[pyclass]
struct PyDataType {
inner: DataType,
}

#[pyclass]
struct PyField {
inner: Field,
}

#[pyclass]
struct PySchema {
inner: Schema,
}

// impl<'source> FromPyObject<'source> for PyDataType {
// fn extract(value: &'source PyAny) -> PyResult<Self> {
// PyDataType::from_pyarrow(value)
// }
// }

// impl<'source> FromPyObject<'source> for PyField {
// fn extract(value: &'source PyAny) -> PyResult<Self> {
// PyField::from_pyarrow(value)
// }
// }

// impl<'source> FromPyObject<'source> for PySchema {
// fn extract(value: &'source PyAny) -> PyResult<Self> {
// PySchema::from_pyarrow(value)
// }
// }

/// Returns `array + array` of an int64 array.
#[pyfunction]
fn double(array: &PyAny, py: Python) -> PyResult<PyObject> {
Expand All @@ -127,66 +90,79 @@ fn double(array: &PyAny, py: Python) -> PyResult<PyObject> {
/// calls a lambda function that receives and returns an array
/// whose result must be the array multiplied by two
#[pyfunction]
fn double_py(lambda: PyObject, py: Python) -> PyResult<bool> {
fn double_py(lambda: &PyAny, py: Python) -> PyResult<bool> {
// create
let array = Arc::new(Int64Array::from(vec![Some(1), None, Some(3)]));
let expected = Arc::new(Int64Array::from(vec![Some(2), None, Some(6)])) as ArrayRef;

// to py
let pyarray = array.into_py()?;
let pyarray = lambda.call1(py, (pyarray,))?;
let pyarray = array.to_pyarrow(py)?;
let pyarray = lambda.call1((pyarray,))?;
let array = ArrayRef::from_pyarrow(pyarray)?;

Ok(array == expected)
}

// /// Returns the substring
// #[pyfunction]
// fn substring(array: PyObject, start: i64, py: Python) -> PyResult<PyObject> {
// // import
// let array = array_to_rust(array, py)?;
/// Returns the substring
#[pyfunction]
fn substring(array: ArrayData, start: i64) -> PyResult<ArrayData> {
// import
let array = ArrayRef::from(array);

// substring
let array = kernels::substring::substring(array.as_ref(), start, &None)
.map_err(PyO3ArrowError::from)?;

Ok(array.data().to_owned())
}

/// Returns the concatenate
#[pyfunction]
fn concatenate(array: ArrayData) -> PyResult<ArrayData> {
let array = ArrayRef::from(array);

// // substring
// let array = kernels::substring::substring(array.as_ref(), start, &None)
// .map_err(PyO3ArrowError::from)?;
// concat
let array = kernels::concat::concat(&[array.as_ref(), array.as_ref()])
.map_err(PyO3ArrowError::from)?;

// // export
// array_to_py(array, py)
// }
Ok(array.data().to_owned())
}

// /// Returns the concatenate
// #[pyfunction]
// fn concatenate(array: PyObject, py: Python) -> PyResult<PyObject> {
// // import
// let array = array_to_rust(array, py)?;
#[pyfunction]
fn round_trip_type(obj: DataType) -> PyResult<DataType> {
Ok(obj)
}

// // concat
// let array = kernels::concat::concat(&[array.as_ref(), array.as_ref()])
// .map_err(PyO3ArrowError::from)?;
#[pyfunction]
fn round_trip_field(obj: Field) -> PyResult<Field> {
Ok(obj)
}

// // export
// array_to_py(array, py)
// }
#[pyfunction]
fn round_trip_schema(obj: Schema) -> PyResult<Schema> {
Ok(obj)
}

// /// Converts to rust and back to python
// #[pyfunction]
// fn round_trip(pyarray: PyObject, py: Python) -> PyResult<PyObject> {
// // import
// let array = array_to_rust(pyarray, py)?;
#[pyfunction]
fn round_trip_array(obj: ArrayData) -> PyResult<ArrayData> {
Ok(obj)
}

// // export
// array_to_py(array, py)
// }
#[pyfunction]
fn round_trip_record_batch(obj: RecordBatch) -> PyResult<RecordBatch> {
Ok(obj)
}

#[pymodule]
fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<PyDataType>()?;
m.add_class::<PyField>()?;
m.add_class::<PySchema>()?;
m.add_wrapped(wrap_pyfunction!(double))?;
// m.add_wrapped(wrap_pyfunction!(double_py))?;
// m.add_wrapped(wrap_pyfunction!(substring))?;
// m.add_wrapped(wrap_pyfunction!(concatenate))?;
// m.add_wrapped(wrap_pyfunction!(round_trip))?;
m.add_wrapped(wrap_pyfunction!(double_py))?;
m.add_wrapped(wrap_pyfunction!(substring))?;
m.add_wrapped(wrap_pyfunction!(concatenate))?;
m.add_wrapped(wrap_pyfunction!(round_trip_type))?;
m.add_wrapped(wrap_pyfunction!(round_trip_field))?;
m.add_wrapped(wrap_pyfunction!(round_trip_schema))?;
m.add_wrapped(wrap_pyfunction!(round_trip_array))?;
m.add_wrapped(wrap_pyfunction!(round_trip_record_batch))?;
Ok(())
}
26 changes: 12 additions & 14 deletions arrow-pyarrow-integration-testing/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import pyarrow as pa
import pytz

from arrow_pyarrow_integration_testing import PyDataType, PyField, PySchema
import arrow_pyarrow_integration_testing as rust


Expand Down Expand Up @@ -113,43 +112,42 @@ def assert_pyarrow_leak():

@pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str)
def test_type_roundtrip(pyarrow_type):
ty = PyDataType.from_pyarrow(pyarrow_type)
restored = ty.to_pyarrow()
restored = rust.round_trip_type(pyarrow_type)
assert restored == pyarrow_type
assert restored is not pyarrow_type


@pytest.mark.parametrize("pyarrow_type", _unsupported_pyarrow_types, ids=str)
def test_type_roundtrip_raises(pyarrow_type):
with pytest.raises(Exception):
PyDataType.from_pyarrow(pyarrow_type)
rust.round_trip_type(pyarrow_type)


def test_dictionary_type_roundtrip():
# the dictionary type conversion is incomplete
pyarrow_type = pa.dictionary(pa.int32(), pa.string())
ty = PyDataType.from_pyarrow(pyarrow_type)
assert ty.to_pyarrow() == pa.int32()
ty = rust.round_trip_type(pyarrow_type)
assert ty == pa.int32()


@pytest.mark.parametrize('pyarrow_type', _supported_pyarrow_types, ids=str)
def test_field_roundtrip(pyarrow_type):
pyarrow_field = pa.field("test", pyarrow_type, nullable=True)
field = PyField.from_pyarrow(pyarrow_field)
assert field.to_pyarrow() == pyarrow_field
field = rust.round_trip_field(pyarrow_field)
assert field == pyarrow_field

if pyarrow_type != pa.null():
# A null type field may not be non-nullable
pyarrow_field = pa.field("test", pyarrow_type, nullable=False)
field = PyField.from_pyarrow(pyarrow_field)
assert field.to_pyarrow() == pyarrow_field
field = rust.round_trip_field(pyarrow_field)
assert field == pyarrow_field


def test_schema_roundtrip():
pyarrow_fields = zip(string.ascii_lowercase, _supported_pyarrow_types)
pyarrow_schema = pa.schema(pyarrow_fields)
schema = PySchema.from_pyarrow(pyarrow_schema)
assert schema.to_pyarrow() == pyarrow_schema
schema = rust.round_trip_schema(pyarrow_schema)
assert schema == pyarrow_schema


def test_primitive_python():
Expand Down Expand Up @@ -205,7 +203,7 @@ def test_list_array():
Python -> Rust -> Python
"""
a = pa.array([[], None, [1, 2], [4, 5, 6]], pa.list_(pa.int64()))
b = rust.round_trip(a)
b = rust.round_trip_array(a)
b.validate(full=True)
assert a.to_pylist() == b.to_pylist()
assert a.type == b.type
Expand Down Expand Up @@ -261,7 +259,7 @@ def test_decimal_python():
None
]
a = pa.array(data, pa.decimal128(6, 2))
b = rust.round_trip(a)
b = rust.round_trip_array(a)
assert a == b
del a
del b
16 changes: 9 additions & 7 deletions arrow/src/pyarrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,15 @@
//! This library demonstrates a minimal usage of Rust's C data interface to pass
//! arrays from and to Python.
use std::convert::From;
use std::convert::TryFrom;
use std::convert::{From, TryFrom};
use std::sync::Arc;

use libc::uintptr_t;
use pyo3::exceptions::PyException;
use pyo3::prelude::*;
use pyo3::types::PyList;

use crate::array::{make_array, ArrayData, Array, ArrayRef};
use crate::array::{make_array, Array, ArrayData, ArrayRef};
use crate::datatypes::{DataType, Field, Schema};
use crate::error::ArrowError;
use crate::ffi;
Expand Down Expand Up @@ -151,7 +150,10 @@ impl PyArrowConvert for ArrayRef {
}
}

impl<T> PyArrowConvert for T where T: Array + From<ArrayData> {
impl<T> PyArrowConvert for T
where
T: Array + From<ArrayData>,
{
fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
Ok(ArrayData::from_pyarrow(value)?.into())
}
Expand Down Expand Up @@ -198,8 +200,8 @@ impl PyArrowConvert for RecordBatch {
}
}

macro_rules! add_conversion{
($typ:ty)=>{
macro_rules! add_conversion {
($typ:ty) => {
impl<'source> FromPyObject<'source> for $typ {
fn extract(value: &'source PyAny) -> PyResult<Self> {
Self::from_pyarrow(value)
Expand All @@ -211,7 +213,7 @@ macro_rules! add_conversion{
self.to_pyarrow(py).unwrap()
}
}
}
};
}

add_conversion!(DataType);
Expand Down

0 comments on commit 5d3eb09

Please sign in to comment.