Skip to content

Commit

Permalink
Rework the python bindings using conversion traits from arrow-rs
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Aug 13, 2021
1 parent 4ddd2f5 commit 89bbaea
Show file tree
Hide file tree
Showing 30 changed files with 356 additions and 589 deletions.
9 changes: 6 additions & 3 deletions datafusion/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,13 @@ simd = ["arrow/simd"]
crypto_expressions = ["md-5", "sha2"]
regex_expressions = ["regex", "lazy_static"]
unicode_expressions = ["unicode-segmentation"]
pyarrow = ["pyo3", "libc", "arrow/pyarrow"]

[dependencies]
ahash = "0.7"
hashbrown = "0.11"
arrow = { version = "5.1", features = ["prettyprint"] }
parquet = { version = "5.1", features = ["arrow"] }
hashbrown = { version = "0.11", features = ["raw"] }
arrow = { path = "../../arrow-rs/arrow", features = ["prettyprint"] }
parquet = { path = "../../arrow-rs/parquet", features = ["arrow"] }
sqlparser = "0.9.0"
paste = "^1.0"
num_cpus = "1.13.0"
Expand All @@ -66,6 +67,8 @@ regex = { version = "^1.4.3", optional = true }
lazy_static = { version = "^1.4.0", optional = true }
smallvec = { version = "1.6", features = ["union"] }
rand = "0.8"
pyo3 = { version = "0.14", optional = true }
libc = { version = "0.2", optional = true }

[dev-dependencies]
criterion = "0.3"
Expand Down
3 changes: 3 additions & 0 deletions datafusion/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,9 @@ pub mod variable;
pub use arrow;
pub use parquet;

#[cfg(feature = "pyarrow")]
pub mod pyarrow;

#[cfg(test)]
pub mod test;
pub mod test_util;
Expand Down
13 changes: 7 additions & 6 deletions datafusion/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,11 @@ impl LogicalPlanBuilder {
/// This function errors under any of the following conditions:
/// * Two or more expressions have the same name
/// * An invalid expression is used (e.g. a `sort` expression)
pub fn project(&self, expr: impl IntoIterator<Item = Expr>) -> Result<Self> {
pub fn project(&self, expr: impl IntoIterator<Item = impl Into<Expr>>) -> Result<Self> {
let input_schema = self.plan.schema();
let mut projected_expr = vec![];
for e in expr {
let e = e.into();
match e {
Expr::Wildcard => {
projected_expr.extend(expand_wildcard(input_schema, &self.plan)?)
Expand All @@ -239,8 +240,8 @@ impl LogicalPlanBuilder {
}

/// Apply a filter
pub fn filter(&self, expr: Expr) -> Result<Self> {
let expr = normalize_col(expr, &self.plan)?;
pub fn filter(&self, expr: impl Into<Expr>) -> Result<Self> {
let expr = normalize_col(expr.into(), &self.plan)?;
Ok(Self::from(LogicalPlan::Filter {
predicate: expr,
input: Arc::new(self.plan.clone()),
Expand All @@ -256,7 +257,7 @@ impl LogicalPlanBuilder {
}

/// Apply a sort
pub fn sort(&self, exprs: impl IntoIterator<Item = Expr>) -> Result<Self> {
pub fn sort(&self, exprs: impl IntoIterator<Item = impl Into<Expr>>) -> Result<Self> {
Ok(Self::from(LogicalPlan::Sort {
expr: normalize_cols(exprs, &self.plan)?,
input: Arc::new(self.plan.clone()),
Expand Down Expand Up @@ -434,8 +435,8 @@ impl LogicalPlanBuilder {
/// value of the `group_expr`;
pub fn aggregate(
&self,
group_expr: impl IntoIterator<Item = Expr>,
aggr_expr: impl IntoIterator<Item = Expr>,
group_expr: impl IntoIterator<Item = impl Into<Expr>>,
aggr_expr: impl IntoIterator<Item = impl Into<Expr>>,
) -> Result<Self> {
let group_expr = normalize_cols(group_expr, &self.plan)?;
let aggr_expr = normalize_cols(aggr_expr, &self.plan)?;
Expand Down
4 changes: 2 additions & 2 deletions datafusion/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1227,10 +1227,10 @@ fn normalize_col_with_schemas(
/// Recursively normalize all Column expressions in a list of expression trees
#[inline]
pub fn normalize_cols(
exprs: impl IntoIterator<Item = Expr>,
exprs: impl IntoIterator<Item = impl Into<Expr>>,
plan: &LogicalPlan,
) -> Result<Vec<Expr>> {
exprs.into_iter().map(|e| normalize_col(e, plan)).collect()
exprs.into_iter().map(|e| normalize_col(e.into(), plan)).collect()
}

/// Recursively 'unnormalize' (remove all qualifiers) from an
Expand Down
79 changes: 79 additions & 0 deletions datafusion/src/pyarrow.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use pyo3::exceptions::{PyException, PyNotImplementedError};
use pyo3::prelude::*;

use crate::scalar::ScalarValue;
use crate::arrow::pyarrow::PyArrowConvert;
use crate::error::DataFusionError;

impl From<DataFusionError> for PyErr {
fn from(err: DataFusionError) -> PyErr {
PyException::new_err(err.to_string())
}
}

impl PyArrowConvert for ScalarValue {
    /// Builds a [`ScalarValue`] from a `pyarrow` scalar object.
    ///
    /// The variant is selected from the Python class name of `value`
    /// (e.g. `"Int8Scalar"`), and the payload is obtained by calling the
    /// object's `as_py()` method and extracting the matching Rust type.
    ///
    /// A null scalar, for which `as_py()` returns Python `None`, is mapped to
    /// the corresponding variant holding `None` instead of raising a Python
    /// `TypeError` during extraction.
    ///
    /// # Errors
    ///
    /// Returns a `PyErr` when:
    /// * the class name cannot be read or `as_py()` fails,
    /// * the Python value cannot be extracted into the expected Rust type, or
    /// * the scalar class is not one of the supported types listed below
    ///   (reported as a `DataFusionError::NotImplemented`).
    fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
        // Dispatch on the Python class name of the scalar.
        let t = value
            .getattr("__class__")?
            .getattr("__name__")?
            .extract::<&str>()?;

        // Plain Python representation of the scalar; `None` for null scalars.
        let p = value.call_method0("as_py")?;

        // Extracting `Option<T>` maps Python `None` to `None` and any other
        // object to `Some(T)`, so null scalars round-trip as null values.
        Ok(match t {
            "Int8Scalar" => ScalarValue::Int8(p.extract::<Option<i8>>()?),
            "Int16Scalar" => ScalarValue::Int16(p.extract::<Option<i16>>()?),
            "Int32Scalar" => ScalarValue::Int32(p.extract::<Option<i32>>()?),
            "Int64Scalar" => ScalarValue::Int64(p.extract::<Option<i64>>()?),
            "UInt8Scalar" => ScalarValue::UInt8(p.extract::<Option<u8>>()?),
            "UInt16Scalar" => ScalarValue::UInt16(p.extract::<Option<u16>>()?),
            "UInt32Scalar" => ScalarValue::UInt32(p.extract::<Option<u32>>()?),
            "UInt64Scalar" => ScalarValue::UInt64(p.extract::<Option<u64>>()?),
            "FloatScalar" => ScalarValue::Float32(p.extract::<Option<f32>>()?),
            "DoubleScalar" => ScalarValue::Float64(p.extract::<Option<f64>>()?),
            "BooleanScalar" => ScalarValue::Boolean(p.extract::<Option<bool>>()?),
            "StringScalar" => ScalarValue::Utf8(p.extract::<Option<String>>()?),
            "LargeStringScalar" => {
                ScalarValue::LargeUtf8(p.extract::<Option<String>>()?)
            }
            other => {
                return Err(DataFusionError::NotImplemented(format!(
                    "Type \"{}\" not yet implemented",
                    other
                ))
                .into())
            }
        })
    }

    /// Converts this value into a `pyarrow` object.
    ///
    /// # Errors
    ///
    /// Always returns a Python `NotImplementedError`: the Rust-to-Python
    /// direction is not implemented yet.
    fn to_pyarrow(&self, _py: Python) -> PyResult<PyObject> {
        Err(PyNotImplementedError::new_err("Not implemented"))
    }
}

impl<'source> FromPyObject<'source> for ScalarValue {
fn extract(value: &'source PyAny) -> PyResult<Self> {
Self::from_pyarrow(value)
}
}

impl IntoPy<PyObject> for ScalarValue {
    /// Converts this value into a Python object via
    /// [`PyArrowConvert::to_pyarrow`].
    ///
    /// # Panics
    ///
    /// `IntoPy` is infallible by signature, so a failed conversion cannot be
    /// reported as a `PyErr` and panics instead. The `to_pyarrow`
    /// implementation for `ScalarValue` in this file currently always returns
    /// a `NotImplementedError`, so calling this will panic until that
    /// conversion is implemented.
    fn into_py(self, py: Python) -> PyObject {
        self.to_pyarrow(py)
            .expect("failed to convert ScalarValue into a pyarrow object")
    }
}
3 changes: 2 additions & 1 deletion datafusion/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use std::collections::HashSet;
use std::str::FromStr;
use std::sync::Arc;
use std::{convert::TryInto, vec};
use std::iter;

use crate::catalog::TableReference;
use crate::datasource::TableProvider;
Expand Down Expand Up @@ -766,7 +767,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {

let plan = if select.distinct {
return LogicalPlanBuilder::from(plan)
.aggregate(select_exprs_post_aggr, vec![])?
.aggregate(select_exprs_post_aggr, iter::empty::<Expr>())?
.build();
} else {
plan
Expand Down
8 changes: 4 additions & 4 deletions python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,16 @@ edition = "2018"
libc = "0.2"
tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] }
rand = "0.7"
pyo3 = { version = "0.14.1", features = ["extension-module"] }
datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "4d61196dee8526998aee7e7bb10ea88422e5f9e1" }
pyo3 = { version = "0.14.2", features = ["extension-module"] }
datafusion = { path = "../datafusion", features = ["pyarrow"] }

[lib]
name = "datafusion"
name = "internals"
crate-type = ["cdylib"]

[package.metadata.maturin]
name = "datafusion.internals"
requires-dist = ["pyarrow>=1"]

classifier = [
"Development Status :: 2 - Pre-Alpha",
"Intended Audience :: Developers",
Expand Down
28 changes: 28 additions & 0 deletions python/datafusion/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Re-export the classes of the `internals` submodule under their public
# names, i.e. without the `Py` prefix used on the `internals` side.
# NOTE(review): `internals` is presumably the native extension module built
# from the Rust crate in this directory -- confirm against the build config.
from .internals import PyDataFrame as DataFrame
from .internals import PyExecutionContext as ExecutionContext
from .internals import PyExpr as Expr
from .internals import functions

# Public API of the package: the names exported by `from datafusion import *`.
__all__ = [
"DataFrame",
"ExecutionContext",
"Expr",
"functions"
]
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import pyarrow as pa
import pytest

from datafusion import ExecutionContext
from datafusion import functions as f

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np
import pyarrow as pa
import pytest

from datafusion import ExecutionContext
from datafusion import functions as f

Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pytest

from datafusion import ExecutionContext

from . import generic as helpers


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import pyarrow as pa
import pytest

from datafusion import ExecutionContext
from datafusion import functions as f

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pyarrow as pa
import pyarrow.compute as pc
import pytest

from datafusion import ExecutionContext
from datafusion import functions as f

Expand Down
9 changes: 9 additions & 0 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,12 @@
[build-system]
requires = ["maturin>=0.11,<0.12"]
build-backend = "maturin"

[project]
name = "datafusion"
dependencies = [
"pyarrow"
]

[tool.isort]
profile = "black"
2 changes: 1 addition & 1 deletion python/rust-toolchain
Original file line number Diff line number Diff line change
@@ -1 +1 @@
nightly-2021-05-10
stable
Loading

0 comments on commit 89bbaea

Please sign in to comment.