diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 4b633d4bc9e5..d466d67efa6f 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -230,6 +230,55 @@ jobs: # do not produce debug symbols to keep memory usage down RUSTFLAGS: "-C debuginfo=0" + test-datafusion-pyarrow: + needs: [linux-build-lib] + runs-on: ubuntu-latest + strategy: + matrix: + arch: [amd64] + rust: [stable] + container: + image: ${{ matrix.arch }}/rust + env: + # Disable full debug symbol generation to speed up CI build and keep memory down + # "1" means line tables only, which is useful for panic tracebacks. + RUSTFLAGS: "-C debuginfo=1" + steps: + - uses: actions/checkout@v2 + with: + submodules: true + - name: Cache Cargo + uses: actions/cache@v2 + with: + path: /github/home/.cargo + # this key equals the ones on `linux-build-lib` for re-use + key: cargo-cache- + - name: Cache Rust dependencies + uses: actions/cache@v2 + with: + path: /github/home/target + # this key equals the ones on `linux-build-lib` for re-use + key: ${{ runner.os }}-${{ matrix.arch }}-target-cache-${{ matrix.rust }} + - uses: actions/setup-python@v2 + with: + python-version: "3.8" + - name: Install PyArrow + run: | + echo "LIBRARY_PATH=$LD_LIBRARY_PATH" >> $GITHUB_ENV + python -m pip install pyarrow + - name: Setup Rust toolchain + run: | + rustup toolchain install ${{ matrix.rust }} + rustup default ${{ matrix.rust }} + rustup component add rustfmt + - name: Run tests + run: | + cd datafusion + cargo test --features=pyarrow + env: + CARGO_HOME: "/github/home/.cargo" + CARGO_TARGET_DIR: "/github/home/target" + lint: name: Lint runs-on: ubuntu-latest diff --git a/datafusion/src/pyarrow.rs b/datafusion/src/pyarrow.rs index da05d63d8c2c..d819b2b41154 100644 --- a/datafusion/src/pyarrow.rs +++ b/datafusion/src/pyarrow.rs @@ -15,10 +15,9 @@ // specific language governing permissions and limitations // under the License. -use pyo3::exceptions::{PyException, PyNotImplementedError}; +use pyo3::exceptions::PyException; use pyo3::prelude::*; use pyo3::types::PyList; -use pyo3::PyNativeType; use crate::arrow::array::ArrayData; use crate::arrow::pyarrow::PyArrowConvert; @@ -49,8 +48,13 @@ impl PyArrowConvert for ScalarValue { Ok(scalar) } - fn to_pyarrow(&self, _py: Python) -> PyResult { - Err(PyNotImplementedError::new_err("Not implemented")) + fn to_pyarrow(&self, py: Python) -> PyResult { + let array = self.to_array(); + // convert to pyarrow array using C data interface + let pyarray = array.data_ref().clone().into_py(py); + let pyscalar = pyarray.call_method1(py, "__getitem__", (0,))?; + + Ok(pyscalar) } } @@ -65,3 +69,82 @@ impl<'a> IntoPy for ScalarValue { self.to_pyarrow(py).unwrap() } } + +#[cfg(test)] +mod tests { + use super::*; + use pyo3::prepare_freethreaded_python; + use pyo3::py_run; + use pyo3::types::PyDict; + use pyo3::Python; + + fn init_python() { + prepare_freethreaded_python(); + Python::with_gil(|py| { + if let Err(err) = py.run("import pyarrow", None, None) { + let locals = PyDict::new(py); + py.run( + "import sys; executable = sys.executable; python_path = sys.path", + None, + Some(locals), + ) + .expect("Couldn't get python info"); + let executable: String = + locals.get_item("executable").unwrap().extract().unwrap(); + let python_path: Vec<&str> = + locals.get_item("python_path").unwrap().extract().unwrap(); + + Err(err).expect( + format!( + "pyarrow not found\nExecutable: {}\nPython path: {:?}\n\ + HINT: try `pip install pyarrow`\n\ + NOTE: On Mac OS, you must compile against a Framework Python \ + (default in python.org installers and brew, but not pyenv)\n\ + NOTE: On Mac OS, PYO3 might point to incorrect Python library \ + path when using virtual environments. Try \ + `export PYTHONPATH=$(python -c \"import sys; print(sys.path[-1])\")`\n", + executable, python_path + ) + .as_ref(), + ) + } + }) + } + + #[test] + fn test_roundtrip() { + init_python(); + + let example_scalars = vec![ + ScalarValue::Boolean(Some(true)), + ScalarValue::Int32(Some(23)), + ScalarValue::Float64(Some(12.34)), + ScalarValue::Utf8(Some("Hello!".to_string())), + ScalarValue::Date32(Some(1234)), + ]; + + Python::with_gil(|py| { + for scalar in example_scalars.iter() { + let result = + ScalarValue::from_pyarrow(scalar.to_pyarrow(py).unwrap().as_ref(py)) + .unwrap(); + assert_eq!(scalar, &result); + } + }); + } + + #[test] + fn test_py_scalar() { + init_python(); + + Python::with_gil(|py| { + let scalar_float = ScalarValue::Float64(Some(12.34)); + let py_float = scalar_float.into_py(py).call_method0(py, "as_py").unwrap(); + py_run!(py, py_float, "assert py_float == 12.34"); + + let scalar_string = ScalarValue::Utf8(Some("Hello!".to_string())); + let py_string = scalar_string.into_py(py).call_method0(py, "as_py").unwrap(); + py_run!(py, py_string, "assert py_string == 'Hello!'"); + }); + } +}