diff --git a/datafusion/common/src/pyarrow.rs b/datafusion/common/src/pyarrow.rs index 49ea8572142e8..60dde7861104a 100644 --- a/datafusion/common/src/pyarrow.rs +++ b/datafusion/common/src/pyarrow.rs @@ -17,17 +17,13 @@ //! Conversions between PyArrow and DataFusion types -// TODO update to pyo3 new APIs -// See: https://pyo3.rs/v0.23.0/migration -#![allow(deprecated)] - use arrow::array::ArrayData; use arrow::pyarrow::{FromPyArrow, ToPyArrow}; use arrow_array::Array; use pyo3::exceptions::PyException; use pyo3::prelude::PyErr; use pyo3::types::{PyAnyMethods, PyList}; -use pyo3::{Bound, FromPyObject, IntoPy, PyAny, PyObject, PyResult, Python}; +use pyo3::{Bound, FromPyObject, IntoPyObject, PyAny, PyObject, PyResult, Python}; use crate::{DataFusionError, ScalarValue}; @@ -44,8 +40,8 @@ impl FromPyArrow for ScalarValue { let val = value.call_method0("as_py")?; // construct pyarrow array from the python value and pyarrow type - let factory = py.import_bound("pyarrow")?.getattr("array")?; - let args = PyList::new_bound(py, [val]); + let factory = py.import("pyarrow")?.getattr("array")?; + let args = PyList::new(py, [val])?; let array = factory.call1((args, typ))?; // convert the pyarrow array to rust array using C data interface @@ -73,14 +69,25 @@ impl<'source> FromPyObject<'source> for ScalarValue { } } -impl IntoPy for ScalarValue { - fn into_py(self, py: Python) -> PyObject { - self.to_pyarrow(py).unwrap() +impl<'source> IntoPyObject<'source> for ScalarValue { + type Target = PyAny; + + type Output = Bound<'source, Self::Target>; + + type Error = PyErr; + + fn into_pyobject(self, py: Python<'source>) -> Result { + let array = self.to_array()?; + // convert to pyarrow array using C data interface + let pyarray = array.to_data().to_pyarrow(py)?; + let pyarray_bound = pyarray.bind(py); + pyarray_bound.call_method1("__getitem__", (0,)) } } #[cfg(test)] mod tests { + use pyo3::ffi::c_str; use pyo3::prepare_freethreaded_python; use pyo3::py_run; use pyo3::types::PyDict; @@ -90,10 +97,12 @@ mod tests { fn init_python() { prepare_freethreaded_python(); Python::with_gil(|py| { - if py.run_bound("import pyarrow", None, None).is_err() { - let locals = PyDict::new_bound(py); - py.run_bound( - "import sys; executable = sys.executable; python_path = sys.path", + if py.run(c_str!("import pyarrow"), None, None).is_err() { + let locals = PyDict::new(py); + py.run( + c_str!( + "import sys; executable = sys.executable; python_path = sys.path" + ), None, Some(&locals), ) @@ -139,17 +148,25 @@ mod tests { } #[test] - fn test_py_scalar() { + fn test_py_scalar() -> PyResult<()> { init_python(); - Python::with_gil(|py| { + Python::with_gil(|py| -> PyResult<()> { let scalar_float = ScalarValue::Float64(Some(12.34)); - let py_float = scalar_float.into_py(py).call_method0(py, "as_py").unwrap(); + let py_float = scalar_float + .into_pyobject(py)? + .call_method0("as_py") + .unwrap(); py_run!(py, py_float, "assert py_float == 12.34"); let scalar_string = ScalarValue::Utf8(Some("Hello!".to_string())); - let py_string = scalar_string.into_py(py).call_method0(py, "as_py").unwrap(); + let py_string = scalar_string + .into_pyobject(py)? + .call_method0("as_py") + .unwrap(); py_run!(py, py_string, "assert py_string == 'Hello!'"); - }); + + Ok(()) + }) } }