diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml
index bfb3a93e3249e..9bfdb26d5fdcd 100644
--- a/datafusion/Cargo.toml
+++ b/datafusion/Cargo.toml
@@ -42,12 +42,13 @@ simd = ["arrow/simd"]
 crypto_expressions = ["md-5", "sha2"]
 regex_expressions = ["regex", "lazy_static"]
 unicode_expressions = ["unicode-segmentation"]
+pyarrow = ["pyo3", "libc", "arrow/pyarrow"]
 
 [dependencies]
 ahash = "0.7"
-hashbrown = "0.11"
-arrow = { version = "5.1", features = ["prettyprint"] }
-parquet = { version = "5.1", features = ["arrow"] }
+hashbrown = { version = "0.11", features = ["raw"] }
+arrow = { path = "../../arrow-rs/arrow", features = ["prettyprint"] }
+parquet = { path = "../../arrow-rs/parquet", features = ["arrow"] }
 sqlparser = "0.9.0"
 paste = "^1.0"
 num_cpus = "1.13.0"
@@ -66,6 +67,8 @@ regex = { version = "^1.4.3", optional = true }
 lazy_static = { version = "^1.4.0", optional = true }
 smallvec = { version = "1.6", features = ["union"] }
 rand = "0.8"
+pyo3 = { version = "0.14", optional = true }
+libc = { version = "0.2", optional = true }
 
 [dev-dependencies]
 criterion = "0.3"
diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs
index d8be372dc8f08..680344d1fc16d 100644
--- a/datafusion/src/lib.rs
+++ b/datafusion/src/lib.rs
@@ -230,6 +230,9 @@ pub mod variable;
 pub use arrow;
 pub use parquet;
 
+#[cfg(feature = "pyarrow")]
+pub mod pyarrow;
+
 #[cfg(test)]
 pub mod test;
 pub mod test_util;
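The new `pyarrow` feature is strictly opt-in: it activates the optional `pyo3` and `libc` dependencies plus arrow's own `pyarrow` feature, and the `datafusion::pyarrow` module added to `lib.rs` above only exists in builds that request it. A minimal sketch of the gating pattern (the `python_interop_available` helper is hypothetical, for illustration only):

```rust
// Compiled only with `cargo build --features pyarrow`.
#[cfg(feature = "pyarrow")]
pub mod pyarrow {
    // ScalarValue <-> pyarrow conversions live here; see the new
    // datafusion/src/pyarrow.rs later in this diff.
}

// Hypothetical helper: lets callers probe the feature at compile time.
pub fn python_interop_available() -> bool {
    cfg!(feature = "pyarrow")
}
```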
diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs
index 0dfc1e7aa0480..1352ce39fe4f4 100644
--- a/datafusion/src/logical_plan/builder.rs
+++ b/datafusion/src/logical_plan/builder.rs
@@ -214,10 +214,11 @@ impl LogicalPlanBuilder {
     /// This function errors under any of the following conditions:
     /// * Two or more expressions have the same name
     /// * An invalid expression is used (e.g. a `sort` expression)
-    pub fn project(&self, expr: impl IntoIterator<Item = Expr>) -> Result<Self> {
+    pub fn project(&self, expr: impl IntoIterator<Item = impl Into<Expr>>) -> Result<Self> {
         let input_schema = self.plan.schema();
         let mut projected_expr = vec![];
         for e in expr {
+            let e = e.into();
             match e {
                 Expr::Wildcard => {
                     projected_expr.extend(expand_wildcard(input_schema, &self.plan)?)
@@ -239,8 +240,8 @@ impl LogicalPlanBuilder {
     }
 
     /// Apply a filter
-    pub fn filter(&self, expr: Expr) -> Result<Self> {
-        let expr = normalize_col(expr, &self.plan)?;
+    pub fn filter(&self, expr: impl Into<Expr>) -> Result<Self> {
+        let expr = normalize_col(expr.into(), &self.plan)?;
         Ok(Self::from(LogicalPlan::Filter {
             predicate: expr,
             input: Arc::new(self.plan.clone()),
@@ -256,7 +257,7 @@ impl LogicalPlanBuilder {
     }
 
     /// Apply a sort
-    pub fn sort(&self, exprs: impl IntoIterator<Item = Expr>) -> Result<Self> {
+    pub fn sort(&self, exprs: impl IntoIterator<Item = impl Into<Expr>>) -> Result<Self> {
         Ok(Self::from(LogicalPlan::Sort {
             expr: normalize_cols(exprs, &self.plan)?,
             input: Arc::new(self.plan.clone()),
@@ -434,8 +435,8 @@ impl LogicalPlanBuilder {
     /// value of the `group_expr`;
     pub fn aggregate(
         &self,
-        group_expr: impl IntoIterator<Item = Expr>,
-        aggr_expr: impl IntoIterator<Item = Expr>,
+        group_expr: impl IntoIterator<Item = impl Into<Expr>>,
+        aggr_expr: impl IntoIterator<Item = impl Into<Expr>>,
     ) -> Result<Self> {
         let group_expr = normalize_cols(group_expr, &self.plan)?;
         let aggr_expr = normalize_cols(aggr_expr, &self.plan)?;
diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs
index 8b0e647261da8..92c1db27add7b 100644
--- a/datafusion/src/logical_plan/expr.rs
+++ b/datafusion/src/logical_plan/expr.rs
@@ -1227,10 +1227,10 @@ fn normalize_col_with_schemas(
 /// Recursively normalize all Column expressions in a list of expression trees
 #[inline]
 pub fn normalize_cols(
-    exprs: impl IntoIterator<Item = Expr>,
+    exprs: impl IntoIterator<Item = impl Into<Expr>>,
     plan: &LogicalPlan,
 ) -> Result<Vec<Expr>> {
-    exprs.into_iter().map(|e| normalize_col(e, plan)).collect()
+    exprs.into_iter().map(|e| normalize_col(e.into(), plan)).collect()
 }
 
 /// Recursively 'unnormalize' (remove all qualifiers) from an
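Relaxing the builder and `normalize_cols` signatures from `Item = Expr` to `Item = impl Into<Expr>` is what lets the Python bindings later in this diff pass `Vec<PyExpr>` straight into `project`/`sort`/`aggregate` without mapping each wrapper to its inner `Expr` first. A self-contained sketch of the pattern (`Wrapper` is a hypothetical stand-in for `PyExpr`):

```rust
// Stand-ins for datafusion's Expr and the Python wrapper type.
struct Expr(String);

struct Wrapper {
    expr: Expr,
}

impl From<Wrapper> for Expr {
    fn from(w: Wrapper) -> Expr {
        w.expr
    }
}

// Before: fn project(exprs: impl IntoIterator<Item = Expr>)
// After: any iterator of Into<Expr> is accepted, so both calls compile.
fn project(exprs: impl IntoIterator<Item = impl Into<Expr>>) {
    for e in exprs {
        let _e: Expr = e.into(); // the conversion happens inside the builder
    }
}

fn main() {
    project(vec![Expr("a".to_string())]); // plain Expr still works
    project(vec![Wrapper { expr: Expr("b".to_string()) }]); // wrappers now work too
}
```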
diff --git a/datafusion/src/pyarrow.rs b/datafusion/src/pyarrow.rs
new file mode 100644
index 0000000000000..606a8eb9b4ce1
--- /dev/null
+++ b/datafusion/src/pyarrow.rs
@@ -0,0 +1,79 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use pyo3::exceptions::{PyException, PyNotImplementedError};
+use pyo3::prelude::*;
+
+use crate::scalar::ScalarValue;
+use crate::arrow::pyarrow::PyArrowConvert;
+use crate::error::DataFusionError;
+
+impl From<DataFusionError> for PyErr {
+    fn from(err: DataFusionError) -> PyErr {
+        PyException::new_err(err.to_string())
+    }
+}
+
+impl PyArrowConvert for ScalarValue {
+    fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+        let t = value
+            .getattr("__class__")?
+            .getattr("__name__")?
+            .extract::<&str>()?;
+
+        let p = value.call_method0("as_py")?;
+
+        Ok(match t {
+            "Int8Scalar" => ScalarValue::Int8(Some(p.extract::<i8>()?)),
+            "Int16Scalar" => ScalarValue::Int16(Some(p.extract::<i16>()?)),
+            "Int32Scalar" => ScalarValue::Int32(Some(p.extract::<i32>()?)),
+            "Int64Scalar" => ScalarValue::Int64(Some(p.extract::<i64>()?)),
+            "UInt8Scalar" => ScalarValue::UInt8(Some(p.extract::<u8>()?)),
+            "UInt16Scalar" => ScalarValue::UInt16(Some(p.extract::<u16>()?)),
+            "UInt32Scalar" => ScalarValue::UInt32(Some(p.extract::<u32>()?)),
+            "UInt64Scalar" => ScalarValue::UInt64(Some(p.extract::<u64>()?)),
+            "FloatScalar" => ScalarValue::Float32(Some(p.extract::<f32>()?)),
+            "DoubleScalar" => ScalarValue::Float64(Some(p.extract::<f64>()?)),
+            "BooleanScalar" => ScalarValue::Boolean(Some(p.extract::<bool>()?)),
+            "StringScalar" => ScalarValue::Utf8(Some(p.extract::<String>()?)),
+            "LargeStringScalar" => ScalarValue::LargeUtf8(Some(p.extract::<String>()?)),
+            other => {
+                return Err(DataFusionError::NotImplemented(format!(
+                    "Type \"{}\" not yet implemented",
+                    other
+                ))
+                .into())
+            }
+        })
+    }
+
+    fn to_pyarrow(&self, _py: Python) -> PyResult<PyObject> {
+        Err(PyNotImplementedError::new_err("Not implemented"))
+    }
+}
+
+impl<'source> FromPyObject<'source> for ScalarValue {
+    fn extract(value: &'source PyAny) -> PyResult<Self> {
+        Self::from_pyarrow(value)
+    }
+}
+
+impl IntoPy<PyObject> for ScalarValue {
+    fn into_py(self, py: Python) -> PyObject {
+        self.to_pyarrow(py).unwrap()
+    }
+}
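With `FromPyObject` implemented, pyo3 can extract a `ScalarValue` directly from a `pyarrow` scalar in any bound function signature; note that the reverse direction is still a stub here (`to_pyarrow` always errors, so the `unwrap` in `into_py` would panic if a `ScalarValue` were ever returned to Python). A hedged sketch of the extraction path (the `#[pyfunction]` is hypothetical, not part of this diff, and assumes the crate is built with the `pyarrow` feature):

```rust
#[cfg(feature = "pyarrow")]
use datafusion::scalar::ScalarValue;
#[cfg(feature = "pyarrow")]
use pyo3::prelude::*;

// pyo3 routes the argument through ScalarValue::extract, and therefore
// through from_pyarrow above, so a pyarrow scalar passed from Python
// arrives as a native ScalarValue.
#[cfg(feature = "pyarrow")]
#[pyfunction]
fn describe_scalar(v: ScalarValue) -> String {
    format!("{:?}", v)
}
```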
diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs
index ef2b63464969b..5919331cddbe5 100644
--- a/datafusion/src/sql/planner.rs
+++ b/datafusion/src/sql/planner.rs
@@ -21,6 +21,7 @@ use std::collections::HashSet;
 use std::str::FromStr;
 use std::sync::Arc;
 use std::{convert::TryInto, vec};
+use std::iter;
 
 use crate::catalog::TableReference;
 use crate::datasource::TableProvider;
@@ -766,7 +767,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
 
         let plan = if select.distinct {
             return LogicalPlanBuilder::from(plan)
-                .aggregate(select_exprs_post_aggr, vec![])?
+                .aggregate(select_exprs_post_aggr, iter::empty::<Expr>())?
                 .build();
         } else {
             plan
diff --git a/python/Cargo.toml b/python/Cargo.toml
index fe84e5234c333..8fd38ddf3f4cd 100644
--- a/python/Cargo.toml
+++ b/python/Cargo.toml
@@ -30,16 +30,16 @@ edition = "2018"
 libc = "0.2"
 tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] }
 rand = "0.7"
-pyo3 = { version = "0.14.1", features = ["extension-module"] }
-datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "4d61196dee8526998aee7e7bb10ea88422e5f9e1" }
+pyo3 = { version = "0.14.2", features = ["extension-module"] }
+datafusion = { path = "../datafusion", features = ["pyarrow"] }
 
 [lib]
-name = "datafusion"
+name = "internals"
 crate-type = ["cdylib"]
 
 [package.metadata.maturin]
+name = "datafusion.internals"
 requires-dist = ["pyarrow>=1"]
-
 classifier = [
     "Development Status :: 2 - Pre-Alpha",
     "Intended Audience :: Developers",
diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py
new file mode 100644
index 0000000000000..20bc3f22bfc5c
--- /dev/null
+++ b/python/datafusion/__init__.py
@@ -0,0 +1,28 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from .internals import PyDataFrame as DataFrame
+from .internals import PyExecutionContext as ExecutionContext
+from .internals import PyExpr as Expr
+from .internals import functions
+
+__all__ = [
+    "DataFrame",
+    "ExecutionContext",
+    "Expr",
+    "functions"
+]
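The packaging layout changes shape here: the cdylib is renamed from `datafusion` to `internals`, maturin installs it as `datafusion.internals`, and the new pure-Python `datafusion/__init__.py` re-exports the `Py`-prefixed classes under their public names. One constraint worth noting (sketch below): pyo3 derives the module's init symbol from the `#[pymodule]` function name, so it has to match the `[lib] name`; the real registration code appears in `python/src/lib.rs` further down.

```rust
use pyo3::prelude::*;

// Must be named `internals` to match `[lib] name = "internals"`;
// `[package.metadata.maturin] name = "datafusion.internals"` then places
// the compiled module inside the `datafusion` package.
#[pymodule]
fn internals(_py: Python, m: &PyModule) -> PyResult<()> {
    // Classes registered here keep their Py- prefix and are renamed on the
    // Python side, e.g. `from .internals import PyDataFrame as DataFrame`.
    m.add("__doc__", "native DataFusion bindings (illustrative sketch)")?;
    Ok(())
}
```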
diff --git a/python/tests/__init__.py b/python/datafusion/tests/__init__.py
similarity index 100%
rename from python/tests/__init__.py
rename to python/datafusion/tests/__init__.py
diff --git a/python/tests/generic.py b/python/datafusion/tests/generic.py
similarity index 100%
rename from python/tests/generic.py
rename to python/datafusion/tests/generic.py
diff --git a/python/tests/test_df.py b/python/datafusion/tests/test_df.py
similarity index 99%
rename from python/tests/test_df.py
rename to python/datafusion/tests/test_df.py
index 5b6cbddbd74ba..b04eba53f6fdc 100644
--- a/python/tests/test_df.py
+++ b/python/datafusion/tests/test_df.py
@@ -17,6 +17,7 @@
 
 import pyarrow as pa
 import pytest
+
 from datafusion import ExecutionContext
 from datafusion import functions as f
diff --git a/python/tests/test_math_functions.py b/python/datafusion/tests/test_math_functions.py
similarity index 99%
rename from python/tests/test_math_functions.py
rename to python/datafusion/tests/test_math_functions.py
index 98656b8c4f422..4e473c3de16ac 100644
--- a/python/tests/test_math_functions.py
+++ b/python/datafusion/tests/test_math_functions.py
@@ -18,6 +18,7 @@
 import numpy as np
 import pyarrow as pa
 import pytest
+
 from datafusion import ExecutionContext
 from datafusion import functions as f
diff --git a/python/tests/test_pa_types.py b/python/datafusion/tests/test_pa_types.py
similarity index 100%
rename from python/tests/test_pa_types.py
rename to python/datafusion/tests/test_pa_types.py
diff --git a/python/tests/test_sql.py b/python/datafusion/tests/test_sql.py
similarity index 99%
rename from python/tests/test_sql.py
rename to python/datafusion/tests/test_sql.py
index 669f640529eb5..d6a16f23b6c85 100644
--- a/python/tests/test_sql.py
+++ b/python/datafusion/tests/test_sql.py
@@ -20,6 +20,7 @@
 import pytest
 
 from datafusion import ExecutionContext
+
 from . import generic as helpers
diff --git a/python/tests/test_string_functions.py b/python/datafusion/tests/test_string_functions.py
similarity index 99%
rename from python/tests/test_string_functions.py
rename to python/datafusion/tests/test_string_functions.py
index ea064a6b2e9f6..4255d34805a04 100644
--- a/python/tests/test_string_functions.py
+++ b/python/datafusion/tests/test_string_functions.py
@@ -17,6 +17,7 @@
 
 import pyarrow as pa
 import pytest
+
 from datafusion import ExecutionContext
 from datafusion import functions as f
diff --git a/python/tests/test_udaf.py b/python/datafusion/tests/test_udaf.py
similarity index 99%
rename from python/tests/test_udaf.py
rename to python/datafusion/tests/test_udaf.py
index e7044d6119e38..aca1215a7cb24 100644
--- a/python/tests/test_udaf.py
+++ b/python/datafusion/tests/test_udaf.py
@@ -20,6 +20,7 @@
 import pyarrow as pa
 import pyarrow.compute as pc
 import pytest
+
 from datafusion import ExecutionContext
 from datafusion import functions as f
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 1482129897fae..ce33f58d29173 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -18,3 +18,12 @@
 [build-system]
 requires = ["maturin>=0.11,<0.12"]
 build-backend = "maturin"
+
+[project]
+name = "datafusion"
+dependencies = [
+    "pyarrow"
+]
+
+[tool.isort]
+profile = "black"
diff --git a/python/rust-toolchain b/python/rust-toolchain
index 6231a95e3036d..2bf5ad0447d33 100644
--- a/python/rust-toolchain
+++ b/python/rust-toolchain
@@ -1 +1 @@
-nightly-2021-05-10
+stable
diff --git a/python/src/context.rs b/python/src/context.rs
index 9acc14a5e2609..402e803a19c02 100644
--- a/python/src/context.rs
+++ b/python/src/context.rs
@@ -24,41 +24,37 @@ use rand::Rng;
 use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
 
+use datafusion::arrow::datatypes::{DataType, Schema};
 use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::datasource::MemTable;
-use datafusion::execution::context::ExecutionContext as _ExecutionContext;
+use datafusion::execution::context::ExecutionContext;
 use datafusion::prelude::CsvReadOptions;
 
-use crate::dataframe;
-use crate::errors;
+use crate::dataframe::PyDataFrame;
+use crate::errors::DataFusionError;
 use crate::functions;
-use crate::to_rust;
-use crate::types::PyDataType;
 
-/// `ExecutionContext` is able to plan and execute DataFusion plans.
+/// `PyExecutionContext` is able to plan and execute DataFusion plans.
 /// It has a powerful optimizer, a physical planner for local execution, and a
 /// multi-threaded execution engine to perform the execution.
 #[pyclass(unsendable)]
-pub(crate) struct ExecutionContext {
-    ctx: _ExecutionContext,
+pub(crate) struct PyExecutionContext {
+    ctx: ExecutionContext,
 }
 
 #[pymethods]
-impl ExecutionContext {
+impl PyExecutionContext {
     #[new]
     fn new() -> Self {
-        ExecutionContext {
-            ctx: _ExecutionContext::new(),
+        PyExecutionContext {
+            ctx: ExecutionContext::new(),
         }
     }
 
-    /// Returns a DataFrame whose plan corresponds to the SQL statement.
-    fn sql(&mut self, query: &str) -> PyResult<dataframe::DataFrame> {
-        let df = self
-            .ctx
-            .sql(query)
-            .map_err(|e| -> errors::DataFusionError { e.into() })?;
-        Ok(dataframe::DataFrame::new(
+    /// Returns a PyDataFrame whose plan corresponds to the SQL statement.
+    fn sql(&mut self, query: &str) -> PyResult<PyDataFrame> {
+        let df = self.ctx.sql(query).map_err(DataFusionError::from)?;
+        Ok(PyDataFrame::new(
             self.ctx.state.clone(),
             df.to_logical_plan(),
         ))
@@ -66,21 +62,10 @@ impl ExecutionContext {
 
     fn create_dataframe(
         &mut self,
-        partitions: Vec<Vec<PyObject>>,
-        py: Python,
-    ) -> PyResult<dataframe::DataFrame> {
-        let partitions: Vec<Vec<RecordBatch>> = partitions
-            .iter()
-            .map(|batches| {
-                batches
-                    .iter()
-                    .map(|batch| to_rust::to_rust_batch(batch.as_ref(py)))
-                    .collect()
-            })
-            .collect::<PyResult<_>>()?;
-
-        let table =
-            errors::wrap(MemTable::try_new(partitions[0][0].schema(), partitions))?;
+        partitions: Vec<Vec<RecordBatch>>,
+    ) -> PyResult<PyDataFrame> {
+        let table = MemTable::try_new(partitions[0][0].schema(), partitions)
+            .map_err(DataFusionError::from)?;
 
         // generate a random (unique) name for this table
         let name = rand::thread_rng()
@@ -88,15 +73,19 @@ impl ExecutionContext {
             .take(10)
             .collect::<String>();
 
-        errors::wrap(self.ctx.register_table(&*name, Arc::new(table)))?;
-        Ok(dataframe::DataFrame::new(
-            self.ctx.state.clone(),
-            errors::wrap(self.ctx.table(&*name))?.to_logical_plan(),
-        ))
+        self.ctx
+            .register_table(&*name, Arc::new(table))
+            .map_err(DataFusionError::from)?;
+        let table = self.ctx.table(&*name).map_err(DataFusionError::from)?;
+
+        let df = PyDataFrame::new(self.ctx.state.clone(), table.to_logical_plan());
+        Ok(df)
     }
 
     fn register_parquet(&mut self, name: &str, path: &str) -> PyResult<()> {
-        errors::wrap(self.ctx.register_parquet(name, path))?;
+        self.ctx
+            .register_parquet(name, path)
+            .map_err(DataFusionError::from)?;
         Ok(())
     }
 
@@ -111,7 +100,7 @@ impl ExecutionContext {
         &mut self,
         name: &str,
         path: PathBuf,
-        schema: Option<&PyAny>,
+        schema: Option<Schema>,
         has_header: bool,
         delimiter: &str,
         schema_infer_max_records: usize,
@@ -120,10 +109,6 @@ impl ExecutionContext {
         let path = path
             .to_str()
             .ok_or(PyValueError::new_err("Unable to convert path to a string"))?;
-        let schema = match schema {
-            Some(s) => Some(to_rust::to_rust_schema(s)?),
-            None => None,
-        };
         let delimiter = delimiter.as_bytes();
         if delimiter.len() != 1 {
             return Err(PyValueError::new_err(
@@ -138,7 +123,9 @@ impl ExecutionContext {
             .file_extension(file_extension);
         options.schema = schema.as_ref();
 
-        errors::wrap(self.ctx.register_csv(name, path, options))?;
+        self.ctx
+            .register_csv(name, path, options)
+            .map_err(DataFusionError::from)?;
         Ok(())
     }
 
@@ -146,12 +133,12 @@ impl ExecutionContext {
         &mut self,
         name: &str,
         func: PyObject,
-        args_types: Vec<PyDataType>,
-        return_type: PyDataType,
-    ) {
-        let function = functions::create_udf(func, args_types, return_type, name);
-
+        args_types: Vec<DataType>,
+        return_type: DataType,
+    ) -> PyResult<()> {
+        let function = functions::create_udf(func, args_types, return_type, name)?;
         self.ctx.register_udf(function.function);
+        Ok(())
     }
 
     fn tables(&self) -> HashSet<String> {
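`create_dataframe` and `register_csv` can now declare `Vec<Vec<RecordBatch>>` and `Option<Schema>` parameters directly because arrow's `pyarrow` feature supplies `FromPyObject` for those types; the hand-rolled `to_rust_batch`/`to_rust_schema` helpers (deleted later in this diff) become dead code. A hedged sketch of the same idea (hypothetical function, not in the diff):

```rust
use datafusion::arrow::{datatypes::Schema, record_batch::RecordBatch};
use pyo3::prelude::*;

// pyo3 converts the incoming pyarrow objects through arrow's FromPyObject
// impls, so the body only ever sees native Rust types.
#[pyfunction]
fn first_batch_rows(partitions: Vec<Vec<RecordBatch>>, schema: Option<Schema>) -> usize {
    let _ = schema; // optional, mirroring register_csv above
    partitions
        .first()
        .and_then(|p| p.first())
        .map(|batch| batch.num_rows())
        .unwrap_or(0)
}
```

One caveat in the diff as written: `create_dataframe` indexes `partitions[0][0]` without checking for emptiness, so an empty list passed from Python would panic rather than raise.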
diff --git a/python/src/dataframe.rs b/python/src/dataframe.rs
index 4a50262ec3292..41bf960c51b6e 100644
--- a/python/src/dataframe.rs
+++ b/python/src/dataframe.rs
@@ -15,91 +15,76 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::convert::From;
 use std::sync::{Arc, Mutex};
 
-use logical_plan::LogicalPlan;
-use pyo3::{prelude::*, types::PyTuple};
+use pyo3::prelude::*;
 use tokio::runtime::Runtime;
 
-use datafusion::execution::context::ExecutionContext as _ExecutionContext;
-use datafusion::logical_plan::{JoinType, LogicalPlanBuilder};
+use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::execution::context::{ExecutionContext, ExecutionContextState};
+use datafusion::logical_plan::{JoinType, LogicalPlan, LogicalPlanBuilder};
 use datafusion::physical_plan::collect;
-use datafusion::{execution::context::ExecutionContextState, logical_plan};
 
-use crate::{errors, to_py};
-use crate::{errors::DataFusionError, expression};
+use crate::{errors, errors::DataFusionError, expression, expression::PyExpr};
 
-/// A DataFrame is a representation of a logical plan and an API to compose statements.
+/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
 /// Use it to build a plan and `.collect()` to execute the plan and collect the result.
 /// The actual execution of a plan runs natively on Rust and Arrow on a multi-threaded environment.
 #[pyclass]
-pub(crate) struct DataFrame {
+pub(crate) struct PyDataFrame {
     ctx_state: Arc<Mutex<ExecutionContextState>>,
     plan: LogicalPlan,
 }
 
-impl DataFrame {
-    /// creates a new DataFrame
+impl PyDataFrame {
+    /// creates a new PyDataFrame
    pub fn new(ctx_state: Arc<Mutex<ExecutionContextState>>, plan: LogicalPlan) -> Self {
         Self { ctx_state, plan }
     }
 }
 
 #[pymethods]
-impl DataFrame {
-    /// Select `expressions` from the existing DataFrame.
+impl PyDataFrame {
+    /// Select `expressions` from the existing PyDataFrame.
     #[args(args = "*")]
-    fn select(&self, args: &PyTuple) -> PyResult<Self> {
-        let expressions = expression::from_tuple(args)?;
+    fn select(&self, args: Vec<PyExpr>) -> PyResult<Self> {
         let builder = LogicalPlanBuilder::from(self.plan.clone());
-        let builder =
-            errors::wrap(builder.project(expressions.into_iter().map(|e| e.expr)))?;
-        let plan = errors::wrap(builder.build())?;
+        let plan = builder.project(args)?.build()?;
 
-        Ok(DataFrame {
+        Ok(PyDataFrame {
             ctx_state: self.ctx_state.clone(),
             plan,
         })
     }
 
     /// Filter according to the `predicate` expression
-    fn filter(&self, predicate: expression::Expression) -> PyResult<Self> {
+    fn filter(&self, predicate: PyExpr) -> PyResult<Self> {
         let builder = LogicalPlanBuilder::from(self.plan.clone());
-        let builder = errors::wrap(builder.filter(predicate.expr))?;
-        let plan = errors::wrap(builder.build())?;
+        let plan = builder.filter(predicate)?.build()?;
 
-        Ok(DataFrame {
+        Ok(PyDataFrame {
             ctx_state: self.ctx_state.clone(),
             plan,
         })
     }
 
     /// Aggregates using expressions
-    fn aggregate(
-        &self,
-        group_by: Vec<expression::Expression>,
-        aggs: Vec<expression::Expression>,
-    ) -> PyResult<Self> {
+    fn aggregate(&self, group_by: Vec<PyExpr>, aggs: Vec<PyExpr>) -> PyResult<Self> {
         let builder = LogicalPlanBuilder::from(self.plan.clone());
-        let builder = errors::wrap(builder.aggregate(
-            group_by.into_iter().map(|e| e.expr),
-            aggs.into_iter().map(|e| e.expr),
-        ))?;
-        let plan = errors::wrap(builder.build())?;
+        let plan = builder.aggregate(group_by, aggs)?.build()?;
 
-        Ok(DataFrame {
+        Ok(PyDataFrame {
             ctx_state: self.ctx_state.clone(),
             plan,
         })
     }
 
     /// Sort by specified sorting expressions
-    fn sort(&self, exprs: Vec<expression::Expression>) -> PyResult<Self> {
-        let exprs = exprs.into_iter().map(|e| e.expr);
+    fn sort(&self, exprs: Vec<PyExpr>) -> PyResult<Self> {
         let builder = LogicalPlanBuilder::from(self.plan.clone());
-        let builder = errors::wrap(builder.sort(exprs))?;
-        let plan = errors::wrap(builder.build())?;
-        Ok(DataFrame {
+        let plan = builder.sort(exprs)?.build()?;
+        Ok(PyDataFrame {
             ctx_state: self.ctx_state.clone(),
             plan,
         })
@@ -108,10 +93,9 @@ impl DataFrame {
     /// Limits the plan to return at most `count` rows
     fn limit(&self, count: usize) -> PyResult<Self> {
         let builder = LogicalPlanBuilder::from(self.plan.clone());
-        let builder = errors::wrap(builder.limit(count))?;
-        let plan = errors::wrap(builder.build())?;
+        let plan = builder.limit(count)?.build()?;
 
-        Ok(DataFrame {
+        Ok(PyDataFrame {
             ctx_state: self.ctx_state.clone(),
             plan,
         })
@@ -119,28 +103,23 @@ impl DataFrame {
 
     /// Executes the plan, returning a list of `RecordBatch`es.
     /// Unless some order is specified in the plan, there is no guarantee of the order of the result
-    fn collect(&self, py: Python) -> PyResult<PyObject> {
-        let ctx = _ExecutionContext::from(self.ctx_state.clone());
-        let plan = ctx
-            .optimize(&self.plan)
-            .map_err(|e| -> errors::DataFusionError { e.into() })?;
+    fn collect(&self, py: Python) -> PyResult<Vec<RecordBatch>> {
+        let ctx = ExecutionContext::from(self.ctx_state.clone());
+        let plan = ctx.optimize(&self.plan).map_err(DataFusionError::from)?;
         let plan = ctx
             .create_physical_plan(&plan)
-            .map_err(|e| -> errors::DataFusionError { e.into() })?;
+            .map_err(DataFusionError::from)?;
 
         let rt = Runtime::new().unwrap();
         let batches = py.allow_threads(|| {
-            rt.block_on(async {
-                collect(plan)
-                    .await
-                    .map_err(|e| -> errors::DataFusionError { e.into() })
-            })
+            rt.block_on(async { collect(plan).await.map_err(DataFusionError::from) })
         })?;
-        to_py::to_py(&batches)
+
+        Ok(batches)
     }
 
-    /// Returns the join of two DataFrames `on`.
-    fn join(&self, right: &DataFrame, on: Vec<&str>, how: &str) -> PyResult<Self> {
+    /// Returns the join of two PyDataFrames `on`.
+    fn join(&self, right: &PyDataFrame, on: Vec<&str>, how: &str) -> PyResult<Self> {
         let builder = LogicalPlanBuilder::from(self.plan.clone());
 
         let join_type = match how {
@@ -159,11 +138,11 @@ impl DataFrame {
             }
         };
 
-        let builder = errors::wrap(builder.join(&right.plan, join_type, on.clone(), on))?;
-
-        let plan = errors::wrap(builder.build())?;
+        let plan = builder
+            .join_using(&right.plan, join_type, on.clone())?
+            .build()?;
 
-        Ok(DataFrame {
+        Ok(PyDataFrame {
             ctx_state: self.ctx_state.clone(),
             plan,
         })
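`collect` can return `PyResult<Vec<RecordBatch>>` because arrow's `pyarrow` feature also covers the Rust-to-Python direction (`IntoPy<PyObject>` for `RecordBatch`), so pyo3 maps the vector to a Python list of `pyarrow.RecordBatch` objects and the old `to_py::to_py` marshalling (deleted below) is no longer needed. A sketch of the same return path (hypothetical function):

```rust
use datafusion::arrow::array::Int32Array;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use pyo3::prelude::*;
use std::sync::Arc;

// The Vec<RecordBatch> return value crosses into Python as a list of
// pyarrow.RecordBatch, converted by arrow's IntoPy implementation.
#[pyfunction]
fn one_batch() -> PyResult<Vec<RecordBatch>> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1, 2, 3]))])
        .map_err(|e| pyo3::exceptions::PyException::new_err(e.to_string()))?;
    Ok(vec![batch])
}
```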
diff --git a/python/src/errors.rs b/python/src/errors.rs
index fbe98037a030f..cc181a98755d4 100644
--- a/python/src/errors.rs
+++ b/python/src/errors.rs
@@ -16,10 +16,11 @@
 // under the License.
 
 use core::fmt;
+//use std::result::Result;
 
 use datafusion::arrow::error::ArrowError;
 use datafusion::error::DataFusionError as InnerDataFusionError;
-use pyo3::{exceptions, PyErr};
+use pyo3::{exceptions::PyException, PyErr};
 
 #[derive(Debug)]
 pub enum DataFusionError {
@@ -38,9 +39,9 @@ impl fmt::Display for DataFusionError {
     }
 }
 
-impl From<DataFusionError> for PyErr {
-    fn from(err: DataFusionError) -> PyErr {
-        exceptions::PyException::new_err(err.to_string())
+impl From<ArrowError> for DataFusionError {
+    fn from(err: ArrowError) -> DataFusionError {
+        DataFusionError::ArrowError(err)
     }
 }
 
@@ -50,9 +51,9 @@ impl From<InnerDataFusionError> for DataFusionError {
     }
 }
 
-impl From<ArrowError> for DataFusionError {
-    fn from(err: ArrowError) -> DataFusionError {
-        DataFusionError::ArrowError(err)
+impl From<DataFusionError> for PyErr {
+    fn from(err: DataFusionError) -> PyErr {
+        PyException::new_err(err.to_string())
     }
 }
diff --git a/python/src/expression.rs b/python/src/expression.rs
index 4320b1d14c8b7..016dc2c94e1c1 100644
--- a/python/src/expression.rs
+++ b/python/src/expression.rs
@@ -18,91 +18,98 @@
 use pyo3::{
     basic::CompareOp, prelude::*, types::PyTuple, PyNumberProtocol, PyObjectProtocol,
 };
+use std::convert::From;
+use std::vec::Vec;
 
-use datafusion::logical_plan::Expr as _Expr;
-use datafusion::physical_plan::udaf::AggregateUDF as _AggregateUDF;
-use datafusion::physical_plan::udf::ScalarUDF as _ScalarUDF;
+use datafusion::logical_plan::Expr;
+use datafusion::physical_plan::{udaf::AggregateUDF, udf::ScalarUDF};
 
-/// An expression that can be used on a DataFrame
+/// A PyExpr that can be used on a DataFrame
 #[pyclass]
 #[derive(Debug, Clone)]
-pub(crate) struct Expression {
-    pub(crate) expr: _Expr,
+pub(crate) struct PyExpr {
+    pub(crate) expr: Expr,
+}
+
+impl From<PyExpr> for Expr {
+    fn from(expr: PyExpr) -> Expr {
+        expr.expr
+    }
 }
 
 /// converts a tuple of expressions into a vector of Expressions
-pub(crate) fn from_tuple(value: &PyTuple) -> PyResult<Vec<Expression>> {
+pub(crate) fn from_tuple(value: &PyTuple) -> PyResult<Vec<PyExpr>> {
     value
         .iter()
-        .map(|e| e.extract::<Expression>())
+        .map(|e| e.extract::<PyExpr>())
         .collect::<PyResult<Vec<_>>>()
 }
 
 #[pyproto]
-impl PyNumberProtocol for Expression {
-    fn __add__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
-        Ok(Expression {
+impl PyNumberProtocol for PyExpr {
+    fn __add__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> {
+        Ok(PyExpr {
             expr: lhs.expr + rhs.expr,
         })
     }
 
-    fn __sub__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
-        Ok(Expression {
+    fn __sub__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> {
+        Ok(PyExpr {
             expr: lhs.expr - rhs.expr,
         })
     }
 
-    fn __truediv__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
-        Ok(Expression {
+    fn __truediv__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> {
+        Ok(PyExpr {
             expr: lhs.expr / rhs.expr,
         })
     }
 
-    fn __mul__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
-        Ok(Expression {
+    fn __mul__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> {
+        Ok(PyExpr {
             expr: lhs.expr * rhs.expr,
         })
     }
 
-    fn __and__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
-        Ok(Expression {
+    fn __and__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> {
+        Ok(PyExpr {
             expr: lhs.expr.and(rhs.expr),
         })
     }
 
-    fn __or__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
-        Ok(Expression {
+    fn __or__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> {
+        Ok(PyExpr {
             expr: lhs.expr.or(rhs.expr),
         })
     }
 
-    fn __invert__(&self) -> PyResult<Expression> {
-        Ok(Expression {
+    fn __invert__(&self) -> PyResult<PyExpr> {
+        Ok(PyExpr {
             expr: self.expr.clone().not(),
         })
     }
 }
 
 #[pyproto]
-impl PyObjectProtocol for Expression {
-    fn __richcmp__(&self, other: Expression, op: CompareOp) -> Expression {
+impl PyObjectProtocol for PyExpr {
+    fn __richcmp__(&self, other: PyExpr, op: CompareOp) -> PyExpr {
         match op {
-            CompareOp::Lt => Expression {
+            CompareOp::Lt => PyExpr {
                 expr: self.expr.clone().lt(other.expr),
             },
-            CompareOp::Le => Expression {
+            CompareOp::Le => PyExpr {
                 expr: self.expr.clone().lt_eq(other.expr),
             },
-            CompareOp::Eq => Expression {
+            CompareOp::Eq => PyExpr {
                 expr: self.expr.clone().eq(other.expr),
             },
-            CompareOp::Ne => Expression {
+            CompareOp::Ne => PyExpr {
                 expr: self.expr.clone().not_eq(other.expr),
             },
-            CompareOp::Gt => Expression {
+            CompareOp::Gt => PyExpr {
                 expr: self.expr.clone().gt(other.expr),
             },
-            CompareOp::Ge => Expression {
+            CompareOp::Ge => PyExpr {
                 expr: self.expr.clone().gt_eq(other.expr),
             },
         }
@@ -110,39 +117,39 @@ impl PyObjectProtocol for Expression {
 }
 
 #[pymethods]
-impl Expression {
-    /// assign a name to the expression
-    pub fn alias(&self, name: &str) -> PyResult<Expression> {
-        Ok(Expression {
+impl PyExpr {
+    /// assign a name to the PyExpr
+    pub fn alias(&self, name: &str) -> PyResult<PyExpr> {
+        Ok(PyExpr {
             expr: self.expr.clone().alias(name),
         })
     }
 
-    /// Create a sort expression from an existing expression.
+    /// Create a sort PyExpr from an existing PyExpr.
     #[args(ascending = true, nulls_first = true)]
-    pub fn sort(&self, ascending: bool, nulls_first: bool) -> PyResult<Expression> {
-        Ok(Expression {
+    pub fn sort(&self, ascending: bool, nulls_first: bool) -> PyResult<PyExpr> {
+        Ok(PyExpr {
             expr: self.expr.clone().sort(ascending, nulls_first),
         })
     }
 }
 
-/// Represents a ScalarUDF
+/// Represents a PyScalarUDF
 #[pyclass]
 #[derive(Debug, Clone)]
-pub struct ScalarUDF {
-    pub(crate) function: _ScalarUDF,
+pub struct PyScalarUDF {
+    pub(crate) function: ScalarUDF,
 }
 
 #[pymethods]
-impl ScalarUDF {
-    /// creates a new expression with the call of the udf
+impl PyScalarUDF {
+    /// creates a new PyExpr with the call of the udf
     #[call]
     #[args(args = "*")]
-    fn __call__(&self, args: &PyTuple) -> PyResult<Expression> {
+    fn __call__(&self, args: &PyTuple) -> PyResult<PyExpr> {
         let args = from_tuple(args)?.iter().map(|e| e.expr.clone()).collect();
-        Ok(Expression {
+        Ok(PyExpr {
             expr: self.function.call(args),
         })
     }
@@ -151,19 +158,19 @@ impl ScalarUDF {
 /// Represents a AggregateUDF
 #[pyclass]
 #[derive(Debug, Clone)]
-pub struct AggregateUDF {
-    pub(crate) function: _AggregateUDF,
+pub struct PyAggregateUDF {
+    pub(crate) function: AggregateUDF,
 }
 
 #[pymethods]
-impl AggregateUDF {
-    /// creates a new expression with the call of the udf
+impl PyAggregateUDF {
+    /// creates a new PyExpr with the call of the udf
     #[call]
     #[args(args = "*")]
-    fn __call__(&self, args: &PyTuple) -> PyResult<Expression> {
+    fn __call__(&self, args: &PyTuple) -> PyResult<PyExpr> {
         let args = from_tuple(args)?.iter().map(|e| e.expr.clone()).collect();
-        Ok(Expression {
+        Ok(PyExpr {
             expr: self.function.call(args),
         })
     }
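The `PyNumberProtocol` and `PyObjectProtocol` impls simply forward Python's operators to the operator overloads DataFusion already defines on `Expr`, so an expression like `col("a") + col("b") < lit(3)` written in Python builds the same plan as its Rust equivalent. A sketch of the Rust expressions these protocols delegate to:

```rust
use datafusion::logical_plan::{col, lit, Expr};

fn main() {
    // What PyExpr.__add__ and the CompareOp::Lt branch of __richcmp__
    // construct under the hood:
    let sum: Expr = col("a") + col("b");
    let predicate: Expr = sum.lt(lit(3));
    println!("{:?}", predicate);
}
```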
diff --git a/python/src/functions.rs b/python/src/functions.rs
index 23f010a6ae45c..ad1409c9151ee 100644
--- a/python/src/functions.rs
+++ b/python/src/functions.rs
@@ -15,46 +15,46 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::udaf;
-use crate::udf;
-use crate::{expression, types::PyDataType};
+use std::sync::Arc;
+
 use datafusion::arrow::datatypes::DataType;
 use datafusion::logical_plan;
 use pyo3::{prelude::*, types::PyTuple, wrap_pyfunction};
-use std::sync::Arc;
 
-/// Expression representing a column on the existing plan.
+use crate::{
+    expression,
+    expression::{PyAggregateUDF, PyExpr, PyScalarUDF},
+    udaf, udf,
+};
+
+/// PyExpr representing a column on the existing plan.
 #[pyfunction]
 #[pyo3(text_signature = "(name)")]
-fn col(name: &str) -> expression::Expression {
-    expression::Expression {
+fn col(name: &str) -> PyExpr {
+    PyExpr {
         expr: logical_plan::col(name),
     }
 }
 
-/// Expression representing a constant value
+/// PyExpr representing a constant value
 #[pyfunction]
 #[pyo3(text_signature = "(value)")]
-fn lit(value: i32) -> expression::Expression {
-    expression::Expression {
+fn lit(value: i32) -> PyExpr {
+    PyExpr {
         expr: logical_plan::lit(value),
     }
 }
 
 #[pyfunction]
-fn array(value: Vec<expression::Expression>) -> expression::Expression {
-    expression::Expression {
+fn array(value: Vec<PyExpr>) -> PyExpr {
+    PyExpr {
         expr: logical_plan::array(value.into_iter().map(|x| x.expr).collect::<Vec<_>>()),
     }
 }
 
 #[pyfunction]
-fn in_list(
-    expr: expression::Expression,
-    value: Vec<expression::Expression>,
-    negated: bool,
-) -> expression::Expression {
-    expression::Expression {
+fn in_list(expr: PyExpr, value: Vec<PyExpr>, negated: bool) -> PyExpr {
+    PyExpr {
         expr: logical_plan::in_list(
             expr.expr,
             value.into_iter().map(|x| x.expr).collect::<Vec<_>>(),
@@ -65,8 +65,8 @@ fn in_list(
 
 /// Current date and time
 #[pyfunction]
-fn now() -> expression::Expression {
-    expression::Expression {
+fn now() -> PyExpr {
+    PyExpr {
         // here lit(0) is a stub for conform to arity
         expr: logical_plan::now(logical_plan::lit(0)),
     }
@@ -74,8 +74,8 @@ fn now() -> expression::Expression {
 
 /// Returns a random value in the range 0.0 <= x < 1.0
 #[pyfunction]
-fn random() -> expression::Expression {
-    expression::Expression {
+fn random() -> PyExpr {
+    PyExpr {
         expr: logical_plan::random(),
     }
 }
@@ -83,10 +83,10 @@ fn random() -> expression::Expression {
 /// Concatenates the text representations of all the arguments.
 /// NULL arguments are ignored.
 #[pyfunction(args = "*")]
-fn concat(args: &PyTuple) -> PyResult<expression::Expression> {
+fn concat(args: &PyTuple) -> PyResult<PyExpr> {
     let expressions = expression::from_tuple(args)?;
     let args = expressions.into_iter().map(|e| e.expr).collect::<Vec<_>>();
-    Ok(expression::Expression {
+    Ok(PyExpr {
         expr: logical_plan::concat(&args),
     })
 }
@@ -95,10 +95,10 @@ fn concat(args: &PyTuple) -> PyResult<expression::Expression> {
 /// The first argument is used as the separator string, and should not be NULL.
 /// Other NULL arguments are ignored.
 #[pyfunction(sep, args = "*")]
-fn concat_ws(sep: String, args: &PyTuple) -> PyResult<expression::Expression> {
+fn concat_ws(sep: String, args: &PyTuple) -> PyResult<PyExpr> {
     let expressions = expression::from_tuple(args)?;
     let args = expressions.into_iter().map(|e| e.expr).collect::<Vec<_>>();
-    Ok(expression::Expression {
+    Ok(PyExpr {
         expr: logical_plan::concat_ws(sep, &args),
     })
 }
@@ -107,8 +107,8 @@ macro_rules! define_unary_function {
     ($NAME: ident) => {
         #[doc = "This function is not documented yet"]
         #[pyfunction]
-        fn $NAME(value: expression::Expression) -> expression::Expression {
-            expression::Expression {
+        fn $NAME(value: PyExpr) -> PyExpr {
+            PyExpr {
                 expr: logical_plan::$NAME(value.expr),
             }
         }
     };
@@ -116,8 +116,8 @@ macro_rules! define_unary_function {
     ($NAME: ident, $DOC: expr) => {
         #[doc = $DOC]
         #[pyfunction]
-        fn $NAME(value: expression::Expression) -> expression::Expression {
-            expression::Expression {
+        fn $NAME(value: PyExpr) -> PyExpr {
+            PyExpr {
                 expr: logical_plan::$NAME(value.expr),
             }
         }
@@ -202,61 +202,55 @@ define_unary_function!(count);
 
 pub(crate) fn create_udf(
     fun: PyObject,
-    input_types: Vec<PyDataType>,
-    return_type: PyDataType,
+    input_types: Vec<DataType>,
+    return_type: DataType,
     name: &str,
-) -> expression::ScalarUDF {
-    let input_types: Vec<DataType> =
-        input_types.iter().map(|d| d.data_type.clone()).collect();
-    let return_type = Arc::new(return_type.data_type);
+) -> PyResult<PyScalarUDF> {
+    //let return_type = Arc::new(DataType::from_pyarrow(return_type)?);
 
-    expression::ScalarUDF {
+    Ok(PyScalarUDF {
         function: logical_plan::create_udf(
             name,
             input_types,
-            return_type,
+            Arc::new(return_type),
             udf::array_udf(fun),
         ),
-    }
+    })
 }
 
 /// Creates a new udf.
 #[pyfunction]
 fn udf(
     fun: PyObject,
-    input_types: Vec<PyDataType>,
-    return_type: PyDataType,
+    input_types: Vec<DataType>,
+    return_type: DataType,
     py: Python,
-) -> PyResult<expression::ScalarUDF> {
+) -> PyResult<PyScalarUDF> {
     let name = fun.getattr(py, "__qualname__")?.extract::<String>(py)?;
-    Ok(create_udf(fun, input_types, return_type, &name))
+    create_udf(fun, input_types, return_type, &name)
 }
 
 /// Creates a new udf.
 #[pyfunction]
 fn udaf(
     accumulator: PyObject,
-    input_type: PyDataType,
-    return_type: PyDataType,
-    state_type: Vec<PyDataType>,
+    input_type: DataType,
+    return_type: DataType,
+    state_type: Vec<DataType>,
     py: Python,
-) -> PyResult<expression::AggregateUDF> {
+) -> PyResult<PyAggregateUDF> {
     let name = accumulator
         .getattr(py, "__qualname__")?
         .extract::<String>(py)?;
 
-    let input_type = input_type.data_type;
-    let return_type = Arc::new(return_type.data_type);
-    let state_type = Arc::new(state_type.into_iter().map(|t| t.data_type).collect());
-
-    Ok(expression::AggregateUDF {
+    Ok(PyAggregateUDF {
         function: logical_plan::create_udaf(
             &name,
             input_type,
-            return_type,
+            Arc::new(return_type),
             udaf::array_udaf(accumulator),
-            state_type,
+            Arc::new(state_type),
         ),
     })
 }
diff --git a/python/src/lib.rs b/python/src/lib.rs
index aecfe9994cd1a..902afa87ce63e 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -22,19 +22,15 @@ mod dataframe;
 mod errors;
 mod expression;
 mod functions;
-mod scalar;
-mod to_py;
-mod to_rust;
-mod types;
 mod udaf;
 mod udf;
 
 /// DataFusion.
 #[pymodule]
-fn datafusion(py: Python, m: &PyModule) -> PyResult<()> {
-    m.add_class::<context::ExecutionContext>()?;
-    m.add_class::<dataframe::DataFrame>()?;
-    m.add_class::<expression::Expression>()?;
+fn internals(py: Python, m: &PyModule) -> PyResult<()> {
+    m.add_class::<context::PyExecutionContext>()?;
+    m.add_class::<dataframe::PyDataFrame>()?;
+    m.add_class::<expression::PyExpr>()?;
 
     let functions = PyModule::new(py, "functions")?;
     functions::init(functions)?;
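`udf` and `udaf` now take `DataType` directly: arrow's `FromPyObject` impl extracts it from a `pyarrow.DataType`, replacing the integer-id lookup table in `types.rs` (deleted below), which only covered a fixed set of primitive types. A minimal sketch of the extraction (hypothetical function):

```rust
use datafusion::arrow::datatypes::DataType;
use pyo3::prelude::*;

// A pyarrow.DataType passed from Python arrives as a native arrow DataType,
// including nested types the old id-based mapping could not express.
#[pyfunction]
fn type_name(dt: DataType) -> String {
    format!("{:?}", dt)
}
```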
diff --git a/python/src/scalar.rs b/python/src/scalar.rs
deleted file mode 100644
index 0c562a9403616..0000000000000
--- a/python/src/scalar.rs
+++ /dev/null
@@ -1,36 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use pyo3::prelude::*;
-
-use datafusion::scalar::ScalarValue as _Scalar;
-
-use crate::to_rust::to_rust_scalar;
-
-/// An expression that can be used on a DataFrame
-#[derive(Debug, Clone)]
-pub(crate) struct Scalar {
-    pub(crate) scalar: _Scalar,
-}
-
-impl<'source> FromPyObject<'source> for Scalar {
-    fn extract(ob: &'source PyAny) -> PyResult<Self> {
-        Ok(Self {
-            scalar: to_rust_scalar(ob)?,
-        })
-    }
-}
diff --git a/python/src/to_py.rs b/python/src/to_py.rs
deleted file mode 100644
index 6bc0581c8c70a..0000000000000
--- a/python/src/to_py.rs
+++ /dev/null
@@ -1,75 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use datafusion::arrow::array::ArrayRef;
-use datafusion::arrow::record_batch::RecordBatch;
-use libc::uintptr_t;
-use pyo3::prelude::*;
-use pyo3::types::PyList;
-use pyo3::PyErr;
-use std::convert::From;
-
-use crate::errors;
-
-pub fn to_py_array(array: &ArrayRef, py: Python) -> PyResult<PyObject> {
-    let (array_pointer, schema_pointer) =
-        array.to_raw().map_err(errors::DataFusionError::from)?;
-
-    let pa = py.import("pyarrow")?;
-
-    let array = pa.getattr("Array")?.call_method1(
-        "_import_from_c",
-        (array_pointer as uintptr_t, schema_pointer as uintptr_t),
-    )?;
-    Ok(array.to_object(py))
-}
-
-fn to_py_batch<'a>(
-    batch: &RecordBatch,
-    py: Python,
-    pyarrow: &'a PyModule,
-) -> Result<PyObject, PyErr> {
-    let mut py_arrays = vec![];
-    let mut py_names = vec![];
-
-    let schema = batch.schema();
-    for (array, field) in batch.columns().iter().zip(schema.fields().iter()) {
-        let array = to_py_array(array, py)?;
-
-        py_arrays.push(array);
-        py_names.push(field.name());
-    }
-
-    let record = pyarrow
-        .getattr("RecordBatch")?
-        .call_method1("from_arrays", (py_arrays, py_names))?;
-
-    Ok(PyObject::from(record))
-}
-
-/// Converts a &[RecordBatch] into a Vec represented in PyArrow
-pub fn to_py(batches: &[RecordBatch]) -> PyResult<PyObject> {
-    Python::with_gil(|py| {
-        let pyarrow = PyModule::import(py, "pyarrow")?;
-        let mut py_batches = vec![];
-        for batch in batches {
-            py_batches.push(to_py_batch(batch, py, pyarrow)?);
-        }
-        let list = PyList::new(py, py_batches);
-        Ok(PyObject::from(list))
-    })
-}
diff --git a/python/src/to_rust.rs b/python/src/to_rust.rs
deleted file mode 100644
index 7977fe4ff8ce1..0000000000000
--- a/python/src/to_rust.rs
+++ /dev/null
@@ -1,122 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use std::convert::TryFrom;
-use std::sync::Arc;
-
-use datafusion::arrow::{
-    array::{make_array_from_raw, ArrayRef},
-    datatypes::Field,
-    datatypes::Schema,
-    ffi,
-    record_batch::RecordBatch,
-};
-use datafusion::scalar::ScalarValue;
-use libc::uintptr_t;
-use pyo3::prelude::*;
-
-use crate::{errors, types::PyDataType};
-
-/// converts a pyarrow Array into a Rust Array
-pub fn to_rust(ob: &PyAny) -> PyResult<ArrayRef> {
-    // prepare a pointer to receive the Array struct
-    let (array_pointer, schema_pointer) =
-        ffi::ArrowArray::into_raw(unsafe { ffi::ArrowArray::empty() });
-
-    // make the conversion through PyArrow's private API
-    // this changes the pointer's memory and is thus unsafe. In particular, `_export_to_c` can go out of bounds
-    ob.call_method1(
-        "_export_to_c",
-        (array_pointer as uintptr_t, schema_pointer as uintptr_t),
-    )?;
-
-    let array = unsafe { make_array_from_raw(array_pointer, schema_pointer) }
-        .map_err(errors::DataFusionError::from)?;
-    Ok(array)
-}
-
-/// converts a pyarrow batch into a RecordBatch
-pub fn to_rust_batch(batch: &PyAny) -> PyResult<RecordBatch> {
-    let schema = batch.getattr("schema")?;
-    let names = schema.getattr("names")?.extract::<Vec<String>>()?;
-
-    let fields = names
-        .iter()
-        .enumerate()
-        .map(|(i, name)| {
-            let field = schema.call_method1("field", (i,))?;
-            let nullable = field.getattr("nullable")?.extract::<bool>()?;
-            let py_data_type = field.getattr("type")?;
-            let data_type = py_data_type.extract::<PyDataType>()?.data_type;
-            Ok(Field::new(name, data_type, nullable))
-        })
-        .collect::<PyResult<Vec<Field>>>()?;
-
-    let schema = Arc::new(Schema::new(fields));
-
-    let arrays = (0..names.len())
-        .map(|i| {
-            let array = batch.call_method1("column", (i,))?;
-            to_rust(array)
-        })
-        .collect::<PyResult<Vec<ArrayRef>>>()?;
-
-    let batch =
-        RecordBatch::try_new(schema, arrays).map_err(errors::DataFusionError::from)?;
-    Ok(batch)
-}
-
-/// converts a pyarrow Scalar into a Rust Scalar
-pub fn to_rust_scalar(ob: &PyAny) -> PyResult<ScalarValue> {
-    let t = ob
-        .getattr("__class__")?
-        .getattr("__name__")?
-        .extract::<&str>()?;
-
-    let p = ob.call_method0("as_py")?;
-
-    Ok(match t {
-        "Int8Scalar" => ScalarValue::Int8(Some(p.extract::<i8>()?)),
-        "Int16Scalar" => ScalarValue::Int16(Some(p.extract::<i16>()?)),
-        "Int32Scalar" => ScalarValue::Int32(Some(p.extract::<i32>()?)),
-        "Int64Scalar" => ScalarValue::Int64(Some(p.extract::<i64>()?)),
-        "UInt8Scalar" => ScalarValue::UInt8(Some(p.extract::<u8>()?)),
-        "UInt16Scalar" => ScalarValue::UInt16(Some(p.extract::<u16>()?)),
-        "UInt32Scalar" => ScalarValue::UInt32(Some(p.extract::<u32>()?)),
-        "UInt64Scalar" => ScalarValue::UInt64(Some(p.extract::<u64>()?)),
-        "FloatScalar" => ScalarValue::Float32(Some(p.extract::<f32>()?)),
-        "DoubleScalar" => ScalarValue::Float64(Some(p.extract::<f64>()?)),
-        "BooleanScalar" => ScalarValue::Boolean(Some(p.extract::<bool>()?)),
-        "StringScalar" => ScalarValue::Utf8(Some(p.extract::<String>()?)),
-        "LargeStringScalar" => ScalarValue::LargeUtf8(Some(p.extract::<String>()?)),
-        other => {
-            return Err(errors::DataFusionError::Common(format!(
-                "Type \"{}\"not yet implemented",
-                other
-            ))
-            .into())
-        }
-    })
-}
-
-pub fn to_rust_schema(ob: &PyAny) -> PyResult<Schema> {
-    let c_schema = ffi::FFI_ArrowSchema::empty();
-    let c_schema_ptr = &c_schema as *const ffi::FFI_ArrowSchema;
-    ob.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?;
-    let schema = Schema::try_from(&c_schema).map_err(errors::DataFusionError::from)?;
-    Ok(schema)
-}
diff --git a/python/src/types.rs b/python/src/types.rs
deleted file mode 100644
index bd6ef0d376e63..0000000000000
--- a/python/src/types.rs
+++ /dev/null
@@ -1,65 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use datafusion::arrow::datatypes::DataType;
-use pyo3::{FromPyObject, PyAny, PyResult};
-
-use crate::errors;
-
-/// utility struct to convert PyObj to native DataType
-#[derive(Debug, Clone)]
-pub struct PyDataType {
-    pub data_type: DataType,
-}
-
-impl<'source> FromPyObject<'source> for PyDataType {
-    fn extract(ob: &'source PyAny) -> PyResult<Self> {
-        let id = ob.getattr("id")?.extract::<i32>()?;
-        let data_type = data_type_id(&id)?;
-        Ok(PyDataType { data_type })
-    }
-}
-
-fn data_type_id(id: &i32) -> Result<DataType, errors::DataFusionError> {
-    // see https://github.com/apache/arrow/blob/3694794bdfd0677b95b8c95681e392512f1c9237/python/pyarrow/includes/libarrow.pxd
-    // this is not ideal as it does not generalize for non-basic types
-    // Find a way to get a unique name from the pyarrow.DataType
-    Ok(match id {
-        1 => DataType::Boolean,
-        2 => DataType::UInt8,
-        3 => DataType::Int8,
-        4 => DataType::UInt16,
-        5 => DataType::Int16,
-        6 => DataType::UInt32,
-        7 => DataType::Int32,
-        8 => DataType::UInt64,
-        9 => DataType::Int64,
-        10 => DataType::Float16,
-        11 => DataType::Float32,
-        12 => DataType::Float64,
-        13 => DataType::Utf8,
-        14 => DataType::Binary,
-        34 => DataType::LargeUtf8,
-        35 => DataType::LargeBinary,
-        other => {
-            return Err(errors::DataFusionError::Common(format!(
-                "The type {} is not valid",
-                other
-            )))
-        }
-    })
-}
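Everything deleted in `to_py.rs`, `to_rust.rs`, and `types.rs` was manual plumbing over the Arrow C data interface (`_export_to_c`/`_import_from_c` plus raw pointers); arrow's `pyarrow` module now drives the same FFI in one audited place. The equivalent of the deleted `to_rust()` is a one-liner (sketch, assuming the `pyarrow` feature is enabled):

```rust
use datafusion::arrow::{array::ArrayRef, pyarrow::PyArrowConvert};
use pyo3::prelude::*;

// Replacement for the deleted to_rust(): arrow performs the unsafe
// pointer exchange internally and hands back a safe ArrayRef.
fn pyarrow_array_to_rust(obj: &PyAny) -> PyResult<ArrayRef> {
    ArrayRef::from_pyarrow(obj)
}
```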
diff --git a/python/src/udaf.rs b/python/src/udaf.rs
index 83e8be05db603..756afe68c31e7 100644
--- a/python/src/udaf.rs
+++ b/python/src/udaf.rs
@@ -20,6 +20,7 @@ use std::sync::Arc;
 use pyo3::{prelude::*, types::PyTuple};
 
 use datafusion::arrow::array::ArrayRef;
+use datafusion::arrow::pyarrow::PyArrowConvert;
 use datafusion::error::Result;
 
 use datafusion::{
@@ -27,10 +28,6 @@ use datafusion::{
     scalar::ScalarValue,
 };
 
-use crate::scalar::Scalar;
-use crate::to_py::to_py_array;
-use crate::to_rust::to_rust_scalar;
-
 #[derive(Debug)]
 struct PyAccumulator {
     accum: PyObject,
@@ -43,18 +40,9 @@ impl PyAccumulator {
 }
 
 impl Accumulator for PyAccumulator {
-    fn state(&self) -> Result<Vec<ScalarValue>> {
-        Python::with_gil(|py| {
-            let state = self
-                .accum
-                .as_ref(py)
-                .call_method0("to_scalars")
-                .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?
-                .extract::<Vec<Scalar>>()
-                .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?;
-
-            Ok(state.into_iter().map(|v| v.scalar).collect::<Vec<ScalarValue>>())
-        })
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        Python::with_gil(|py| self.accum.as_ref(py).call_method0("to_scalars")?.extract())
+            .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))
     }
 
     fn update(&mut self, _values: &[ScalarValue]) -> Result<()> {
@@ -67,17 +55,9 @@ impl Accumulator for PyAccumulator {
         todo!()
     }
 
-    fn evaluate(&self) -> Result<ScalarValue> {
-        Python::with_gil(|py| {
-            let value = self
-                .accum
-                .as_ref(py)
-                .call_method0("evaluate")
-                .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?;
-
-            to_rust_scalar(value)
-                .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))
-        })
+    fn evaluate(&self) -> Result<ScalarValue> {
+        Python::with_gil(|py| self.accum.as_ref(py).call_method0("evaluate")?.extract())
+            .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))
     }
 
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
@@ -88,10 +68,7 @@ impl Accumulator for PyAccumulator {
             // 1.
             let py_args = values
                 .iter()
-                .map(|arg| {
-                    // remove unwrap
-                    to_py_array(arg, py).unwrap()
-                })
+                .map(|arg| arg.data().to_owned().to_pyarrow(py).unwrap())
                 .collect::<Vec<PyObject>>();
             let py_args = PyTuple::new(py, py_args);
 
@@ -111,7 +88,8 @@ impl Accumulator for PyAccumulator {
         // 2. merge
         let state = &states[0];
 
-        let state = to_py_array(state, py)
+        let state = state
+            .to_pyarrow(py)
             .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?;
 
         // 2.
diff --git a/python/src/udf.rs b/python/src/udf.rs
index 49a18d9932412..fa77e4ab3257b 100644
--- a/python/src/udf.rs
+++ b/python/src/udf.rs
@@ -15,15 +15,12 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use pyo3::{prelude::*, types::PyTuple};
-
-use datafusion::{arrow::array, physical_plan::functions::make_scalar_function};
-
+use datafusion::arrow::array::ArrayRef;
+use datafusion::arrow::pyarrow::PyArrowConvert;
 use datafusion::error::DataFusionError;
 use datafusion::physical_plan::functions::ScalarFunctionImplementation;
-
-use crate::to_py::to_py_array;
-use crate::to_rust::to_rust;
+use datafusion::{arrow::array, physical_plan::functions::make_scalar_function};
+use pyo3::{prelude::*, types::PyTuple};
 
 /// creates a DataFusion's UDF implementation from a python function that expects pyarrow arrays
 /// This is more efficient as it performs a zero-copy of the contents.
@@ -38,10 +35,7 @@ pub fn array_udf(func: PyObject) -> ScalarFunctionImplementation {
             // 1.
             let py_args = args
                 .iter()
-                .map(|arg| {
-                    // remove unwrap
-                    to_py_array(arg, py).unwrap()
-                })
+                .map(|arg| arg.data().to_owned().to_pyarrow(py).unwrap())
                 .collect::<Vec<PyObject>>();
             let py_args = PyTuple::new(py, py_args);
 
@@ -52,7 +46,7 @@ pub fn array_udf(func: PyObject) -> ScalarFunctionImplementation {
                 Err(error) => Err(DataFusionError::Execution(format!("{:?}", error))),
             }?;
 
-            let array = to_rust(value).unwrap();
+            let array = ArrayRef::from_pyarrow(value).unwrap();
             Ok(array)
         })
     },
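In the Rust-to-Python direction used by `udf.rs` and `udaf.rs`, an `ArrayRef`'s underlying `ArrayData` is handed to pyarrow via `to_pyarrow`, replacing the deleted `to_py_array`. The diff still unwraps the result inside the UDF closures (the old `// remove unwrap` TODO effectively survives); a fallible wrapper like this sketch avoids panicking:

```rust
use datafusion::arrow::{array::ArrayRef, pyarrow::PyArrowConvert};
use pyo3::prelude::*;

// Same conversion as `arg.data().to_owned().to_pyarrow(py)` above, but the
// error is propagated as a PyErr instead of panicking inside a UDF.
fn rust_array_to_pyarrow(py: Python, array: &ArrayRef) -> PyResult<PyObject> {
    array.data().to_owned().to_pyarrow(py)
}
```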