Python wrapper classes for all user interfaces (#750)
* Expose missing functions to python

* Initial commit for creating wrapper classes and functions for all user facing python features

* Remove extra level of python path that is no longer required

* Move import to only happen for type checking for hints

* Comment out classes from __all__ in the top level that are not currently exposed.

* Add license comments

* Add missing import

* Functions now only has one level of depth

* Applying google docstring formatting

* Addressing PR request to add google formatted docstrings

* Small docstring for ruff

* Linting

* Add docstring format checking to pre-commit stage

* Set explicit return types on UDFs

* Add option of passing either a path or a string

* Switch to google docstring style

* Update unit tests to include registering via path or string

* Add py.typed file

* Resolve deprecation warnings in unit tests

* Add path to unit test

* Expose an option in write_csv to include header and add unit test

* Update write_parquet unit test to include paths or strings

* Add unit test for write_json

* Add unit test for substrait serialization to a file

* Add unit tests for runtime config

* Setting return type to typing_extensions.Self per PR recommendation

* Correcting __next__ to not return None since it will raise an exception instead.

* Add optional parameter of decimal places to round and add unit test

* Improve docstrings

* Set default to None instead of empty dict

* User request to allow passing multiple arguments to filter()

* Enhance Expr comparison operators to accept any python value and attempt to convert it to a literal

* Expose overlay and add unit test

* Allow select() to take either str for column names or a full expr

* Update comments on regexp and add unit tests

* Remove TODO markings no longer applicable

* Update udf documentation

* Docstring formatting

* Updating docstring formatting

* Updating docstring formatting

* Updating docstring formatting

* Updating docstring formatting

* Updating docstring formatting

* Cleaning up docstring line lengths

* Add pre-commit check of docstring line length

* Do not emit doc entry for __init__ of some classes

* Correct errors on code blocks generating in sphinx

* Resolve conflict with

* Add license info to py.typed

* Clean up some docstring too long errors in CI

* Correct ruff complaint in unit tests

* Temporarily install google test to get clippy to pass

* Adding gmock to build step due to upstream error

* Add type_extensions to conda meta file

* Small comment suggestions from PR
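
Taken together, the user-facing changes above (select() accepting plain column names, filter() accepting multiple predicates, comparison operators converting Python values to literals) might be exercised roughly as follows. This is a sketch for illustration only, not code from the PR; the CSV file and column names are hypothetical.

from datafusion import SessionContext, col, lit

ctx = SessionContext()
df = ctx.read_csv("example.csv")  # hypothetical input file

# filter() now accepts multiple predicates; 4 and 0 are converted to literals
df = df.filter(col("a") > 4, col("b") != 0)

# select() now accepts plain column names alongside full expressions
df = df.select("a", col("b") * lit(2))

df.show()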
timsaucer authored Jul 18, 2024
1 parent faa5a3f commit aa8aa9c
Showing 40 changed files with 4,441 additions and 288 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/build.yml
@@ -89,6 +89,10 @@ jobs:
name: python-wheel-license
path: .

# To remove once https://github.com/MaterializeInc/rust-protobuf-native/issues/20 is resolved
- name: Install gtest
uses: MarkusJx/[email protected]

- name: Install Protoc
uses: arduino/setup-protoc@v1
with:
4 changes: 4 additions & 0 deletions .github/workflows/test.yaml
@@ -55,6 +55,10 @@ jobs:
version: '3.20.2'
repo-token: ${{ secrets.GITHUB_TOKEN }}

# To remove once https://github.com/MaterializeInc/rust-protobuf-native/issues/20 is resolved
- name: Install gtest
uses: MarkusJx/[email protected]

- name: Setup Python
uses: actions/setup-python@v5
with:
3 changes: 2 additions & 1 deletion benchmarks/db-benchmark/join-datafusion.py
@@ -74,7 +74,8 @@ def ans_shape(batches):
ctx = df.SessionContext()
print(ctx)

# TODO we should be applying projections to these table reads to crete relations of different sizes
# TODO we should be applying projections to these table reads to create relations
# of different sizes

x_data = pacsv.read_csv(
src_jn_x, convert_options=pacsv.ConvertOptions(auto_dict_encode=True)
1 change: 1 addition & 0 deletions conda/recipes/meta.yaml
@@ -51,6 +51,7 @@ requirements:
run:
- python
- pyarrow >=11.0.0
- typing_extensions

test:
imports:
2 changes: 1 addition & 1 deletion docs/source/api/functions.rst
@@ -24,4 +24,4 @@ Functions
.. autosummary::
:toctree: ../generated/

functions.functions
functions
21 changes: 21 additions & 0 deletions docs/source/conf.py
@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.

"""Documenation generation."""

# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
@@ -78,6 +80,25 @@

autosummary_generate = True


def autodoc_skip_member(app, what, name, obj, skip, options):
exclude_functions = "__init__"
exclude_classes = ("Expr", "DataFrame")

class_name = ""
if hasattr(obj, "__qualname__"):
if obj.__qualname__ is not None:
class_name = obj.__qualname__.split(".")[0]

should_exclude = name in exclude_functions and class_name in exclude_classes

return True if should_exclude else None


def setup(app):
app.connect("autodoc-skip-member", autodoc_skip_member)


# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
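
For illustration only (not part of the diff): the autodoc_skip_member hook above returns True to force-skip a member and None to defer to Sphinx's default behavior. A rough sketch of how it responds, assuming the function above is in scope and using dummy objects in place of the real documented members:

class _DataFrameInit:
    __qualname__ = "DataFrame.__init__"

class _ContextInit:
    __qualname__ = "SessionContext.__init__"

# DataFrame is in exclude_classes, so its __init__ is skipped
assert autodoc_skip_member(None, "method", "__init__", _DataFrameInit, False, None) is True

# SessionContext is not excluded, so the hook defers to the default (None)
assert autodoc_skip_member(None, "method", "__init__", _ContextInit, False, None) is None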
15 changes: 5 additions & 10 deletions examples/substrait.py
@@ -18,16 +18,13 @@
from datafusion import SessionContext
from datafusion import substrait as ss


# Create a DataFusion context
ctx = SessionContext()

# Register table with context
ctx.register_csv("aggregate_test_data", "./testing/data/csv/aggregate_test_100.csv")

substrait_plan = ss.substrait.serde.serialize_to_plan(
"SELECT * FROM aggregate_test_data", ctx
)
substrait_plan = ss.Serde.serialize_to_plan("SELECT * FROM aggregate_test_data", ctx)
# type(substrait_plan) -> <class 'datafusion.substrait.plan'>

# Encode it to bytes
@@ -38,17 +35,15 @@
# Alternative serialization approaches
# type(substrait_bytes) -> <class 'bytes'>, at this point the bytes can be distributed to file, network, etc safely
# where they could subsequently be deserialized on the receiving end.
substrait_bytes = ss.substrait.serde.serialize_bytes(
"SELECT * FROM aggregate_test_data", ctx
)
substrait_bytes = ss.Serde.serialize_bytes("SELECT * FROM aggregate_test_data", ctx)

# Imagine here bytes would be read from network, file, etc ... for example brevity this is omitted and variable is simply reused
# type(substrait_plan) -> <class 'datafusion.substrait.plan'>
substrait_plan = ss.substrait.serde.deserialize_bytes(substrait_bytes)
substrait_plan = ss.Serde.deserialize_bytes(substrait_bytes)

# type(df_logical_plan) -> <class 'substrait.LogicalPlan'>
df_logical_plan = ss.substrait.consumer.from_substrait_plan(ctx, substrait_plan)
df_logical_plan = ss.Consumer.from_substrait_plan(ctx, substrait_plan)

# Back to Substrait Plan just for demonstration purposes
# type(substrait_plan) -> <class 'datafusion.substrait.plan'>
substrait_plan = ss.substrait.producer.to_substrait_plan(df_logical_plan)
substrait_plan = ss.Producer.to_substrait_plan(df_logical_plan)
5 changes: 3 additions & 2 deletions examples/tpch/_tests.py
@@ -96,8 +96,9 @@ def test_tpch_query_vs_answer_file(query_code: str, answer_file: str):
module = import_module(query_code)
df = module.df

# Treat q17 as a special case. The answer file does not match the spec. Running at
# scale factor 1, we have manually verified this result does match the expected value.
# Treat q17 as a special case. The answer file does not match the spec.
# Running at scale factor 1, we have manually verified this result does
# match the expected value.
if answer_file == "q17":
return check_q17(df)

18 changes: 18 additions & 0 deletions pyproject.toml
@@ -64,3 +64,21 @@ exclude = [".github/**", "ci/**", ".asf.yaml"]
# Require Cargo.lock is up to date
locked = true
features = ["substrait"]

# Enable docstring linting using the google style guide
[tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "D", "W"]

[tool.ruff.lint.pydocstyle]
convention = "google"

[tool.ruff.lint.pycodestyle]
max-doc-length = 88

# Disable docstring checking for these directories
[tool.ruff.lint.per-file-ignores]
"python/datafusion/tests/*" = ["D"]
"examples/*" = ["D", "W505"]
"dev/*" = ["D"]
"benchmarks/*" = ["D", "F"]
"docs/*" = ["D"]
170 changes: 19 additions & 151 deletions python/datafusion/__init__.py
@@ -15,206 +15,74 @@
# specific language governing permissions and limitations
# under the License.

from abc import ABCMeta, abstractmethod
from typing import List
"""DataFusion python package.
This is a Python library that binds to Apache Arrow in-memory query engine DataFusion.
See https://datafusion.apache.org/python for more information.
"""

try:
import importlib.metadata as importlib_metadata
except ImportError:
import importlib_metadata

import pyarrow as pa

from ._internal import (
AggregateUDF,
Config,
DataFrame,
from .context import (
SessionContext,
SessionConfig,
RuntimeConfig,
ScalarUDF,
SQLOptions,
)

# The following imports are okay to remain as opaque to the user.
from ._internal import Config

from .udf import ScalarUDF, AggregateUDF, Accumulator

from .common import (
DFSchema,
)

from .dataframe import DataFrame

from .expr import (
Alias,
Analyze,
Expr,
Filter,
Limit,
Like,
ILike,
Projection,
SimilarTo,
ScalarVariable,
Sort,
TableScan,
Not,
IsNotNull,
IsTrue,
IsFalse,
IsUnknown,
IsNotTrue,
IsNotFalse,
IsNotUnknown,
Negative,
InList,
Exists,
Subquery,
InSubquery,
ScalarSubquery,
GroupingSet,
Placeholder,
Case,
Cast,
TryCast,
Between,
Explain,
CreateMemoryTable,
SubqueryAlias,
Extension,
CreateView,
Distinct,
DropTable,
Repartition,
Partitioning,
Window,
WindowFrame,
)

__version__ = importlib_metadata.version(__name__)

__all__ = [
"Accumulator",
"Config",
"DataFrame",
"SessionContext",
"SessionConfig",
"SQLOptions",
"RuntimeConfig",
"Expr",
"AggregateUDF",
"ScalarUDF",
"Window",
"WindowFrame",
"column",
"literal",
"TableScan",
"Projection",
"DFSchema",
"DFField",
"Analyze",
"Sort",
"Limit",
"Filter",
"Like",
"ILike",
"SimilarTo",
"ScalarVariable",
"Alias",
"Not",
"IsNotNull",
"IsTrue",
"IsFalse",
"IsUnknown",
"IsNotTrue",
"IsNotFalse",
"IsNotUnknown",
"Negative",
"ScalarFunction",
"BuiltinScalarFunction",
"InList",
"Exists",
"Subquery",
"InSubquery",
"ScalarSubquery",
"GroupingSet",
"Placeholder",
"Case",
"Cast",
"TryCast",
"Between",
"Explain",
"SubqueryAlias",
"Extension",
"CreateMemoryTable",
"CreateView",
"Distinct",
"DropTable",
"Repartition",
"Partitioning",
]


class Accumulator(metaclass=ABCMeta):
@abstractmethod
def state(self) -> List[pa.Scalar]:
pass

@abstractmethod
def update(self, values: pa.Array) -> None:
pass

@abstractmethod
def merge(self, states: pa.Array) -> None:
pass

@abstractmethod
def evaluate(self) -> pa.Scalar:
pass


def column(value):
def column(value: str):
"""Create a column expression."""
return Expr.column(value)


col = column


def literal(value):
if not isinstance(value, pa.Scalar):
value = pa.scalar(value)
"""Create a literal expression."""
return Expr.literal(value)


lit = literal

udf = ScalarUDF.udf

def udf(func, input_types, return_type, volatility, name=None):
"""
Create a new User Defined Function
"""
if not callable(func):
raise TypeError("`func` argument must be callable")
if name is None:
name = func.__qualname__.lower()
return ScalarUDF(
name=name,
func=func,
input_types=input_types,
return_type=return_type,
volatility=volatility,
)


def udaf(accum, input_type, return_type, state_type, volatility, name=None):
"""
Create a new User Defined Aggregate Function
"""
if not issubclass(accum, Accumulator):
raise TypeError("`accum` must implement the abstract base class Accumulator")
if name is None:
name = accum.__qualname__.lower()
if isinstance(input_type, pa.lib.DataType):
input_type = [input_type]
return AggregateUDF(
name=name,
accumulator=accum,
input_type=input_type,
return_type=return_type,
state_type=state_type,
volatility=volatility,
)
udaf = AggregateUDF.udaf
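
With the module-level helpers now delegating to ScalarUDF.udf and AggregateUDF.udaf, defining and using a scalar UDF might look roughly like the sketch below. The signature is assumed to mirror the previous helper (func, input_types, return_type, volatility, optional name); the data and column name are hypothetical.

import pyarrow as pa
from datafusion import SessionContext, col, udf

def is_null(array: pa.Array) -> pa.Array:
    # Operates on a whole Arrow array and returns a boolean array
    return array.is_null()

# Assumed signature: (func, input_types, return_type, volatility, name=None)
is_null_udf = udf(is_null, [pa.int64()], pa.bool_(), "stable")

ctx = SessionContext()
df = ctx.from_pydict({"a": [1, None, 3]})  # hypothetical in-memory table
df.select(is_null_udf(col("a"))).show()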