diff --git a/ibis/backends/pyspark/__init__.py b/ibis/backends/pyspark/__init__.py
index f2e900eea501..57a417c463b9 100644
--- a/ibis/backends/pyspark/__init__.py
+++ b/ibis/backends/pyspark/__init__.py
@@ -3,7 +3,7 @@
 import contextlib
 import os
 from pathlib import Path
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Literal
 
 import pyspark
 import sqlglot as sg
@@ -38,6 +38,8 @@
 PYSPARK_LT_34 = vparse(pyspark.__version__) < vparse("3.4")
 
+ConnectionMode = Literal["streaming", "batch"]
+
 
 def normalize_filenames(source_list):
     # Promote to list
@@ -132,13 +134,20 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._cached_dataframes = {}
 
-    def do_connect(self, session: SparkSession | None = None) -> None:
+    def do_connect(
+        self, session: SparkSession | None = None, mode: ConnectionMode | None = None
+    ) -> None:
         """Create a PySpark `Backend` for use with Ibis.
 
         Parameters
         ----------
         session
             A SparkSession instance
+        mode
+            Can be either "batch" or "streaming". If "batch", every source, sink, and
+            query executed within this connection will be interpreted as a batch
+            workload. If "streaming", every source, sink, and query executed within
+            this connection will be interpreted as a streaming workload.
 
         Examples
         --------
@@ -154,6 +163,13 @@ def do_connect(self, session: SparkSession | None = None) -> None:
             session = SparkSession.builder.getOrCreate()
 
+        mode = mode or "batch"
+        if mode not in ("batch", "streaming"):
+            raise com.IbisInputError(
+                f"Invalid connection mode: {mode}, must be `streaming` or `batch`"
+            )
+        self._mode = mode
+
         self._session = session
 
         # Spark internally stores timestamps as UTC values, and timestamp data
@@ -171,6 +187,10 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
         struct_dtype = PySparkType.to_ibis(df.schema)
         return sch.Schema(struct_dtype)
 
+    @property
+    def mode(self) -> ConnectionMode:
+        return self._mode
+
     @property
     def version(self):
         return pyspark.__version__
@@ -624,7 +644,7 @@ def _clean_up_cached_table(self, op):
 
     def read_delta(
         self,
-        source: str | Path,
+        path: str | Path,
         table_name: str | None = None,
         **kwargs: Any,
     ) -> ir.Table:
@@ -632,11 +652,11 @@
 
         Parameters
         ----------
-        source
+        path
             The path to the Delta Lake table.
         table_name
             An optional name to use for the created table. This defaults to
-            a sequentially generated name.
+            a randomly generated name.
         kwargs
             Additional keyword arguments passed to PySpark.
             https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameReader.load.html
@@ -647,8 +667,12 @@
             The just-registered table
 
         """
-        source = util.normalize_filename(source)
-        spark_df = self._session.read.format("delta").load(source, **kwargs)
+        if self.mode == "streaming":
+            raise NotImplementedError(
+                "Reading a Delta Lake table in streaming mode is not supported"
+            )
+        path = util.normalize_filename(path)
+        spark_df = self._session.read.format("delta").load(path, **kwargs)
         table_name = table_name or util.gen_name("read_delta")
 
         spark_df.createOrReplaceTempView(table_name)
@@ -656,7 +680,7 @@
 
     def read_parquet(
         self,
-        source: str | Path,
+        path: str | Path,
         table_name: str | None = None,
         **kwargs: Any,
     ) -> ir.Table:
@@ -664,11 +688,11 @@
 
         Parameters
         ----------
-        source
+        path
             The data source. May be a path to a file or directory of parquet files.
         table_name
             An optional name to use for the created table. This defaults to
-            a sequentially generated name.
+            a randomly generated name.
         kwargs
             Additional keyword arguments passed to PySpark.
             https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameReader.parquet.html
@@ -679,8 +703,13 @@
             The just-registered table
 
         """
-        source = util.normalize_filename(source)
-        spark_df = self._session.read.parquet(source, **kwargs)
+        if self.mode == "streaming":
+            raise NotImplementedError(
+                "PySpark in streaming mode does not support direct registration of parquet files. "
+                "Please use `read_parquet_directory` instead."
+            )
+        path = util.normalize_filename(path)
+        spark_df = self._session.read.parquet(path, **kwargs)
         table_name = table_name or util.gen_name("read_parquet")
 
         spark_df.createOrReplaceTempView(table_name)
@@ -701,7 +730,7 @@ def read_csv(
             iterable of CSV files.
         table_name
             An optional name to use for the created table. This defaults to
-            a sequentially generated name.
+            a randomly generated name.
         kwargs
             Additional keyword arguments passed to PySpark loading function.
             https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameReader.csv.html
@@ -712,6 +741,11 @@
             The just-registered table
 
         """
+        if self.mode == "streaming":
+            raise NotImplementedError(
+                "PySpark in streaming mode does not support direct registration of CSV files. "
+                "Please use `read_csv_directory` instead."
+            )
         inferSchema = kwargs.pop("inferSchema", True)
         header = kwargs.pop("header", True)
         source_list = normalize_filenames(source_list)
@@ -738,7 +772,7 @@ def read_json(
             iterable of JSON files.
         table_name
             An optional name to use for the created table. This defaults to
-            a sequentially generated name.
+            a randomly generated name.
         kwargs
             Additional keyword arguments passed to PySpark loading function.
             https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameReader.json.html
@@ -749,6 +783,11 @@
             The just-registered table
 
         """
+        if self.mode == "streaming":
+            raise NotImplementedError(
+                "PySpark in streaming mode does not support direct registration of JSON files. "
+                "Please use `read_json_directory` instead."
+            )
         source_list = normalize_filenames(source_list)
         spark_df = self._session.read.json(source_list, **kwargs)
         table_name = table_name or util.gen_name("read_json")
@@ -775,7 +814,7 @@ def register(
             parquet/csv files, or an iterable of CSV files.
         table_name
             An optional name to use for the created table. This defaults to
-            a sequentially generated name.
+            a randomly generated name.
         **kwargs
             Additional keyword arguments passed to PySpark loading functions for
             CSV or parquet.
@@ -835,9 +874,14 @@ def to_delta(
             The data source. A string or Path to the Delta Lake table.
         **kwargs
-            PySpark Delta Lake table write arguments. https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.DataFrameWriter.save.html
+            PySpark Delta Lake table write arguments.
+            https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameWriter.save.html
 
         """
+        if self.mode == "streaming":
+            raise NotImplementedError(
+                "Writing to a Delta Lake table in streaming mode is not supported"
+            )
         df = self._session.sql(expr.compile())
         df.write.format("delta").save(os.fspath(path), **kwargs)
@@ -848,6 +892,10 @@ def to_pyarrow(
         limit: int | str | None = None,
         **kwargs: Any,
     ) -> pa.Table:
+        if self.mode == "streaming":
+            raise NotImplementedError(
+                "PySpark in streaming mode does not support to_pyarrow"
+            )
         import pyarrow as pa
         import pyarrow_hotfix  # noqa: F401
@@ -870,6 +918,10 @@ def to_pyarrow_batches(
         chunk_size: int = 1000000,
         **kwargs: Any,
     ) -> pa.ipc.RecordBatchReader:
+        if self.mode == "streaming":
+            raise NotImplementedError(
+                "PySpark in streaming mode does not support to_pyarrow_batches"
+            )
         pa = self._import_pyarrow()
         pa_table = self.to_pyarrow(
             expr.as_table(), params=params, limit=limit, **kwargs
diff --git a/ibis/backends/pyspark/tests/conftest.py b/ibis/backends/pyspark/tests/conftest.py
index ed111377d25f..79664c5752a1 100644
--- a/ibis/backends/pyspark/tests/conftest.py
+++ b/ibis/backends/pyspark/tests/conftest.py
@@ -177,6 +177,70 @@ def connect(*, tmpdir, worker_id, **kw):
         return ibis.pyspark.connect(spark, **kw)
 
 
+class TestConfForStreaming(BackendTest):
+    deps = ("pyspark",)
+
+    def _load_data(self, **_: Any) -> None:
+        s = self.connection._session
+        num_partitions = 4
+
+        sort_cols = {"functional_alltypes": "id"}
+
+        for name in TEST_TABLES:
+            path = str(self.data_dir / "directory" / "parquet" / name)
+            t = s.readStream.parquet(path).repartition(num_partitions)
+            if (sort_col := sort_cols.get(name)) is not None:
+                t = t.sort(sort_col)
+            t.createOrReplaceTempView(name)
+
+    @staticmethod
+    def connect(*, tmpdir, worker_id, **kw):
+        # Spark internally stores timestamps as UTC values, and timestamp
+        # data that is brought in without a specified time zone is
+        # converted as local time to UTC with microsecond resolution.
+        # https://spark.apache.org/docs/latest/sql-pyspark-pandas-with-arrow.html#timestamp-with-time-zone-semantics
+
+        from pyspark.sql import SparkSession
+
+        config = (
+            SparkSession.builder.appName("ibis_testing")
+            .master("local[1]")
+            .config("spark.cores.max", 1)
+            .config("spark.default.parallelism", 1)
+            .config("spark.driver.extraJavaOptions", "-Duser.timezone=GMT")
+            .config("spark.dynamicAllocation.enabled", False)
+            .config("spark.executor.extraJavaOptions", "-Duser.timezone=GMT")
+            .config("spark.executor.heartbeatInterval", "3600s")
+            .config("spark.executor.instances", 1)
+            .config("spark.network.timeout", "4200s")
+            .config("spark.rdd.compress", False)
+            .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+            .config("spark.shuffle.compress", False)
+            .config("spark.shuffle.spill.compress", False)
+            .config("spark.sql.legacy.timeParserPolicy", "LEGACY")
+            .config("spark.sql.session.timeZone", "UTC")
+            .config("spark.sql.shuffle.partitions", 1)
+            .config("spark.storage.blockManagerSlaveTimeoutMs", "4200s")
+            .config("spark.ui.enabled", False)
+            .config("spark.ui.showConsoleProgress", False)
+            .config("spark.sql.execution.arrow.pyspark.enabled", False)
+            .config("spark.sql.streaming.schemaInference", True)
+        )
+
+        try:
+            from delta.pip_utils import configure_spark_with_delta_pip
+        except ImportError:
+            configure_spark_with_delta_pip = lambda cfg: cfg
+        else:
+            config = config.config(
+                "spark.sql.catalog.spark_catalog",
+                "org.apache.spark.sql.delta.catalog.DeltaCatalog",
+            ).config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
+
+        spark = configure_spark_with_delta_pip(config).getOrCreate()
+        return ibis.pyspark.connect(spark, mode="streaming", **kw)
+
+
 @pytest.fixture(scope="session")
 def con(data_dir, tmp_path_factory, worker_id):
     import pyspark.sql.functions as F
@@ -293,6 +357,12 @@ def con(data_dir, tmp_path_factory, worker_id):
     return con
 
 
+@pytest.fixture(scope="session")
+def con_streaming(data_dir, tmp_path_factory, worker_id):
+    backend_test = TestConfForStreaming.load_data(data_dir, tmp_path_factory, worker_id)
+    return backend_test.connection
+
+
 class IbisWindow:
     # Test util class to generate different types of ibis windows
     def __init__(self, windows):
diff --git a/ibis/backends/pyspark/tests/test_import_export.py b/ibis/backends/pyspark/tests/test_import_export.py
new file mode 100644
index 000000000000..1aed2537d830
--- /dev/null
+++ b/ibis/backends/pyspark/tests/test_import_export.py
@@ -0,0 +1,19 @@
+from __future__ import annotations
+
+from operator import methodcaller
+
+import pytest
+
+
+@pytest.mark.parametrize(
+    "method",
+    [
+        methodcaller("read_delta", path="test.delta"),
+        methodcaller("read_csv", source_list="test.csv"),
+        methodcaller("read_parquet", path="test.parquet"),
+        methodcaller("read_json", source_list="test.json"),
+    ],
+)
+def test_streaming_import_not_implemented(con_streaming, method):
+    with pytest.raises(NotImplementedError):
+        method(con_streaming)
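
Reviewer note: below is a minimal usage sketch of the connection `mode` introduced in this diff, for trying the branch locally. The SparkSession setup and the parquet path are illustrative assumptions, not part of the change; the guarded behaviour mirrors `test_streaming_import_not_implemented` above.

    import ibis
    from pyspark.sql import SparkSession

    session = SparkSession.builder.appName("ibis_streaming_example").getOrCreate()

    # Omitting `mode` keeps the existing behaviour: a batch connection.
    batch_con = ibis.pyspark.connect(session)
    assert batch_con.mode == "batch"

    # mode="streaming" makes every source, sink, and query a streaming workload;
    # any other value raises com.IbisInputError.
    streaming_con = ibis.pyspark.connect(session, mode="streaming")
    assert streaming_con.mode == "streaming"

    # Direct file registration is guarded in streaming mode and raises before
    # touching the file, so the hypothetical path below is never read.
    try:
        streaming_con.read_parquet("example.parquet")  # hypothetical path
    except NotImplementedError as err:
        print(err)  # the message points at the directory-based reader instead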