Skip to content

Commit

Permalink
feat(pyspark): provide a mode option to manage both batch and streami…
Browse files Browse the repository at this point in the history
…ng connections
  • Loading branch information
chloeh13q authored May 29, 2024
1 parent 7a39bd3 commit e425ad5
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 16 deletions.
84 changes: 68 additions & 16 deletions ibis/backends/pyspark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import contextlib
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal

import pyspark
import sqlglot as sg
Expand Down Expand Up @@ -38,6 +38,8 @@

PYSPARK_LT_34 = vparse(pyspark.__version__) < vparse("3.4")

ConnectionMode = Literal["streaming", "batch"]


def normalize_filenames(source_list):
# Promote to list
Expand Down Expand Up @@ -132,13 +134,20 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._cached_dataframes = {}

def do_connect(self, session: SparkSession | None = None) -> None:
def do_connect(
self, session: SparkSession | None = None, mode: ConnectionMode | None = None
) -> None:
"""Create a PySpark `Backend` for use with Ibis.
Parameters
----------
session
A SparkSession instance
mode
Can be either "batch" or "streaming". If "batch", every source, sink, and
query executed within this connection will be interpreted as a batch
workload. If "streaming", every source, sink, and query executed within
this connection will be interpreted as a streaming workload.
Examples
--------
Expand All @@ -154,6 +163,13 @@ def do_connect(self, session: SparkSession | None = None) -> None:

session = SparkSession.builder.getOrCreate()

mode = mode or "batch"
if mode not in ("batch", "streaming"):
raise com.IbisInputError(
f"Invalid connection mode: {mode}, must be `streaming` or `batch`"
)
self._mode = mode

self._session = session

# Spark internally stores timestamps as UTC values, and timestamp data
Expand All @@ -171,6 +187,10 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
struct_dtype = PySparkType.to_ibis(df.schema)
return sch.Schema(struct_dtype)

@property
def mode(self) -> ConnectionMode:
return self._mode

@property
def version(self):
return pyspark.__version__
Expand Down Expand Up @@ -624,19 +644,19 @@ def _clean_up_cached_table(self, op):

def read_delta(
self,
source: str | Path,
path: str | Path,
table_name: str | None = None,
**kwargs: Any,
) -> ir.Table:
"""Register a Delta Lake table as a table in the current database.
Parameters
----------
source
path
The path to the Delta Lake table.
table_name
An optional name to use for the created table. This defaults to
a sequentially generated name.
a random generated name.
kwargs
Additional keyword arguments passed to PySpark.
https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameReader.load.html
Expand All @@ -647,28 +667,32 @@ def read_delta(
The just-registered table
"""
source = util.normalize_filename(source)
spark_df = self._session.read.format("delta").load(source, **kwargs)
if self.mode == "streaming":
raise NotImplementedError(
"Reading a Delta Lake table in streaming mode is not supported"
)
path = util.normalize_filename(path)
spark_df = self._session.read.format("delta").load(path, **kwargs)
table_name = table_name or util.gen_name("read_delta")

spark_df.createOrReplaceTempView(table_name)
return self.table(table_name)

def read_parquet(
self,
source: str | Path,
path: str | Path,
table_name: str | None = None,
**kwargs: Any,
) -> ir.Table:
"""Register a parquet file as a table in the current database.
Parameters
----------
source
path
The data source. May be a path to a file or directory of parquet files.
table_name
An optional name to use for the created table. This defaults to
a sequentially generated name.
a random generated name.
kwargs
Additional keyword arguments passed to PySpark.
https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameReader.parquet.html
Expand All @@ -679,8 +703,13 @@ def read_parquet(
The just-registered table
"""
source = util.normalize_filename(source)
spark_df = self._session.read.parquet(source, **kwargs)
if self.mode == "streaming":
raise NotImplementedError(
"Pyspark in streaming mode does not support direction registration of parquet files. "
"Please use `read_parquet_directory` instead."
)
path = util.normalize_filename(path)
spark_df = self._session.read.parquet(path, **kwargs)
table_name = table_name or util.gen_name("read_parquet")

spark_df.createOrReplaceTempView(table_name)
Expand All @@ -701,7 +730,7 @@ def read_csv(
iterable of CSV files.
table_name
An optional name to use for the created table. This defaults to
a sequentially generated name.
a random generated name.
kwargs
Additional keyword arguments passed to PySpark loading function.
https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameReader.csv.html
Expand All @@ -712,6 +741,11 @@ def read_csv(
The just-registered table
"""
if self.mode == "streaming":
raise NotImplementedError(
"Pyspark in streaming mode does not support direction registration of CSV files. "
"Please use `read_csv_directory` instead."
)
inferSchema = kwargs.pop("inferSchema", True)
header = kwargs.pop("header", True)
source_list = normalize_filenames(source_list)
Expand All @@ -738,7 +772,7 @@ def read_json(
iterable of JSON files.
table_name
An optional name to use for the created table. This defaults to
a sequentially generated name.
a random generated name.
kwargs
Additional keyword arguments passed to PySpark loading function.
https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameReader.json.html
Expand All @@ -749,6 +783,11 @@ def read_json(
The just-registered table
"""
if self.mode == "streaming":
raise NotImplementedError(
"Pyspark in streaming mode does not support direction registration of JSON files. "
"Please use `read_json_directory` instead."
)
source_list = normalize_filenames(source_list)
spark_df = self._session.read.json(source_list, **kwargs)
table_name = table_name or util.gen_name("read_json")
Expand All @@ -775,7 +814,7 @@ def register(
parquet/csv files, or an iterable of CSV files.
table_name
An optional name to use for the created table. This defaults to
a sequentially generated name.
a random generated name.
**kwargs
Additional keyword arguments passed to PySpark loading functions for
CSV or parquet.
Expand Down Expand Up @@ -835,9 +874,14 @@ def to_delta(
The data source. A string or Path to the Delta Lake table.
**kwargs
PySpark Delta Lake table write arguments. https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.DataFrameWriter.save.html
PySpark Delta Lake table write arguments.
https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameWriter.save.html
"""
if self.mode == "streaming":
raise NotImplementedError(
"Writing to a Delta Lake table in streaming mode is not supported"
)
df = self._session.sql(expr.compile())
df.write.format("delta").save(os.fspath(path), **kwargs)

Expand All @@ -848,6 +892,10 @@ def to_pyarrow(
limit: int | str | None = None,
**kwargs: Any,
) -> pa.Table:
if self.mode == "streaming":
raise NotImplementedError(
"PySpark in streaming mode does not support to_pyarrow"
)
import pyarrow as pa
import pyarrow_hotfix # noqa: F401

Expand All @@ -870,6 +918,10 @@ def to_pyarrow_batches(
chunk_size: int = 1000000,
**kwargs: Any,
) -> pa.ipc.RecordBatchReader:
if self.mode == "streaming":
raise NotImplementedError(
"PySpark in streaming mode does not support to_pyarrow_batches"
)
pa = self._import_pyarrow()
pa_table = self.to_pyarrow(
expr.as_table(), params=params, limit=limit, **kwargs
Expand Down
70 changes: 70 additions & 0 deletions ibis/backends/pyspark/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,70 @@ def connect(*, tmpdir, worker_id, **kw):
return ibis.pyspark.connect(spark, **kw)


class TestConfForStreaming(BackendTest):
deps = ("pyspark",)

def _load_data(self, **_: Any) -> None:
s = self.connection._session
num_partitions = 4

sort_cols = {"functional_alltypes": "id"}

for name in TEST_TABLES:
path = str(self.data_dir / "directory" / "parquet" / name)
t = s.readStream.parquet(path).repartition(num_partitions)
if (sort_col := sort_cols.get(name)) is not None:
t = t.sort(sort_col)
t.createOrReplaceTempView(name)

@staticmethod
def connect(*, tmpdir, worker_id, **kw):
# Spark internally stores timestamps as UTC values, and timestamp
# data that is brought in without a specified time zone is
# converted as local time to UTC with microsecond resolution.
# https://spark.apache.org/docs/latest/sql-pyspark-pandas-with-arrow.html#timestamp-with-time-zone-semantics

from pyspark.sql import SparkSession

config = (
SparkSession.builder.appName("ibis_testing")
.master("local[1]")
.config("spark.cores.max", 1)
.config("spark.default.parallelism", 1)
.config("spark.driver.extraJavaOptions", "-Duser.timezone=GMT")
.config("spark.dynamicAllocation.enabled", False)
.config("spark.executor.extraJavaOptions", "-Duser.timezone=GMT")
.config("spark.executor.heartbeatInterval", "3600s")
.config("spark.executor.instances", 1)
.config("spark.network.timeout", "4200s")
.config("spark.rdd.compress", False)
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.config("spark.shuffle.compress", False)
.config("spark.shuffle.spill.compress", False)
.config("spark.sql.legacy.timeParserPolicy", "LEGACY")
.config("spark.sql.session.timeZone", "UTC")
.config("spark.sql.shuffle.partitions", 1)
.config("spark.storage.blockManagerSlaveTimeoutMs", "4200s")
.config("spark.ui.enabled", False)
.config("spark.ui.showConsoleProgress", False)
.config("spark.sql.execution.arrow.pyspark.enabled", False)
.config("spark.sql.streaming.schemaInference", True)
)

try:
from delta.pip_utils import configure_spark_with_delta_pip
except ImportError:
configure_spark_with_delta_pip = lambda cfg: cfg
else:
config = config.config(
"spark.sql.catalog.spark_catalog",
"org.apache.spark.sql.delta.catalog.DeltaCatalog",
).config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")

spark = configure_spark_with_delta_pip(config).getOrCreate()
return ibis.pyspark.connect(spark, mode="streaming", **kw)


@pytest.fixture(scope="session")
def con(data_dir, tmp_path_factory, worker_id):
import pyspark.sql.functions as F
Expand Down Expand Up @@ -293,6 +357,12 @@ def con(data_dir, tmp_path_factory, worker_id):
return con


@pytest.fixture(scope="session")
def con_streaming(data_dir, tmp_path_factory, worker_id):
backend_test = TestConfForStreaming.load_data(data_dir, tmp_path_factory, worker_id)
return backend_test.connection


class IbisWindow:
# Test util class to generate different types of ibis windows
def __init__(self, windows):
Expand Down
19 changes: 19 additions & 0 deletions ibis/backends/pyspark/tests/test_import_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from __future__ import annotations

from operator import methodcaller

import pytest


@pytest.mark.parametrize(
"method",
[
methodcaller("read_delta", path="test.delta"),
methodcaller("read_csv", source_list="test.csv"),
methodcaller("read_parquet", path="test.parquet"),
methodcaller("read_json", source_list="test.json"),
],
)
def test_streaming_import_not_implemented(con_streaming, method):
with pytest.raises(NotImplementedError):
method(con_streaming)

0 comments on commit e425ad5

Please sign in to comment.