diff --git a/.github/workflows/ibis-backends.yml b/.github/workflows/ibis-backends.yml
index 931ca1eeabe4..bc4c09caa48f 100644
--- a/.github/workflows/ibis-backends.yml
+++ b/.github/workflows/ibis-backends.yml
@@ -743,6 +743,11 @@ jobs:
         if: matrix.pyspark-version == '3.5'
         run: poetry run pip install delta-spark
 
+      - name: install iceberg
+        shell: bash
+        if: matrix.pyspark-version == '3.5'
+        run: pushd "$(poetry run python -c "import pyspark; print(pyspark.__file__.rsplit('/', 1)[0])")/jars" && curl -LO https://search.maven.org/remotecontent?filepath=org/apache/iceberg/iceberg-spark-runtime-3.5_2.12/1.5.2/iceberg-spark-runtime-3.5_2.12-1.5.2.jar
+
       - name: run tests
         run: just ci-check -m pyspark
 
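The step above drops the Iceberg Spark runtime jar directly into pyspark's bundled jars/ directory, so the test session picks it up without any --packages indirection. As a quick sanity check, a minimal sketch (assuming a local pip install of pyspark; the path logic mirrors the step's python -c one-liner):

    # Sketch only: confirm the Iceberg runtime jar landed in pyspark's bundled
    # jars/ directory, the same location the CI step above downloads it to.
    from pathlib import Path

    import pyspark

    jars_dir = Path(pyspark.__file__).parent / "jars"
    iceberg_jars = sorted(jars_dir.glob("iceberg-spark-runtime-*.jar"))
    print(f"jars dir: {jars_dir}")
    print(f"iceberg jars: {[jar.name for jar in iceberg_jars]}")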
diff --git a/ibis/backends/pyspark/__init__.py b/ibis/backends/pyspark/__init__.py
index 0d6df76b8151..4f8bab7e5fb7 100644
--- a/ibis/backends/pyspark/__init__.py
+++ b/ibis/backends/pyspark/__init__.py
@@ -29,6 +29,11 @@
 from ibis.legacy.udf.vectorized import _coerce_to_series
 from ibis.util import deprecated
 
+try:
+    from pyspark.errors import ParseException as PySparkParseException
+except ImportError:
+    from pyspark.sql.utils import ParseException as PySparkParseException
+
 if TYPE_CHECKING:
     from collections.abc import Mapping, Sequence
     from urllib.parse import ParseResult
@@ -226,27 +231,62 @@ def _active_catalog_database(self, catalog: str | None, db: str | None):
         # 2. set database
         # 3. set catalog to previous
         # 4. set database to previous
+        #
+        # Unity catalog has special handling for "USE CATALOG" and "USE DATABASE"
+        # and also has weird permissioning around using `setCurrentCatalog` and
+        # `setCurrentDatabase`.
+        #
+        # We attempt to use the Unity-specific Spark SQL to set CATALOG and DATABASE
+        # and if that causes a parser exception we fall back to using the catalog API.
         try:
             if catalog is not None:
-                self._session.catalog.setCurrentCatalog(catalog)
-            self._session.catalog.setCurrentDatabase(db)
+                try:
+                    catalog_sql = sg.to_identifier(catalog).sql(self.dialect)
+                    self.raw_sql(f"USE CATALOG {catalog_sql}")
+                except PySparkParseException:
+                    self._session.catalog.setCurrentCatalog(catalog)
+            try:
+                db_sql = sg.to_identifier(db).sql(self.dialect)
+                self.raw_sql(f"USE DATABASE {db_sql}")
+            except PySparkParseException:
+                self._session.catalog.setCurrentDatabase(db)
             yield
         finally:
             if catalog is not None:
-                self._session.catalog.setCurrentCatalog(current_catalog)
-            self._session.catalog.setCurrentDatabase(current_db)
+                try:
+                    catalog_sql = sg.to_identifier(current_catalog).sql(self.dialect)
+                    self.raw_sql(f"USE CATALOG {catalog_sql}")
+                except PySparkParseException:
+                    self._session.catalog.setCurrentCatalog(current_catalog)
+            try:
+                db_sql = sg.to_identifier(current_db).sql(self.dialect)
+                self.raw_sql(f"USE DATABASE {db_sql}")
+            except PySparkParseException:
+                self._session.catalog.setCurrentDatabase(current_db)
 
     @contextlib.contextmanager
     def _active_catalog(self, name: str | None):
         if name is None or PYSPARK_LT_34:
             yield
             return
-        current = self.current_catalog
+        prev_catalog = self.current_catalog
+        prev_database = self.current_database
         try:
-            self._session.catalog.setCurrentCatalog(name)
+            try:
+                catalog_sql = sg.to_identifier(name).sql(self.dialect)
+                self.raw_sql(f"USE CATALOG {catalog_sql};")
+            except PySparkParseException:
+                self._session.catalog.setCurrentCatalog(name)
             yield
         finally:
-            self._session.catalog.setCurrentCatalog(current)
+            try:
+                catalog_sql = sg.to_identifier(prev_catalog).sql(self.dialect)
+                db_sql = sg.to_identifier(prev_database).sql(self.dialect)
+                self.raw_sql(f"USE CATALOG {catalog_sql};")
+                self.raw_sql(f"USE DATABASE {db_sql};")
+            except PySparkParseException:
+                self._session.catalog.setCurrentCatalog(prev_catalog)
+                self._session.catalog.setCurrentDatabase(prev_database)
 
     def list_catalogs(self, like: str | None = None) -> list[str]:
         catalogs = [res.catalog for res in self._session.sql("SHOW CATALOGS").collect()]
diff --git a/ibis/backends/pyspark/tests/conftest.py b/ibis/backends/pyspark/tests/conftest.py
index 5477471688bb..8c7a977d9653 100644
--- a/ibis/backends/pyspark/tests/conftest.py
+++ b/ibis/backends/pyspark/tests/conftest.py
@@ -171,6 +171,16 @@ def connect(*, tmpdir, worker_id, **kw):
             .config("spark.sql.streaming.schemaInference", True)
         )
 
+        config = (
+            config.config(
+                "spark.sql.extensions",
+                "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions",
+            )
+            .config("spark.sql.catalog.local", "org.apache.iceberg.spark.SparkCatalog")
+            .config("spark.sql.catalog.local.type", "hadoop")
+            .config("spark.sql.catalog.local.warehouse", "icehouse")
+        )
+
         try:
             from delta.pip_utils import configure_spark_with_delta_pip
         except ImportError:
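The conftest change above registers a Hadoop-backed Iceberg catalog named "local" with its warehouse in ./icehouse. A minimal standalone sketch of the same wiring (not the test suite's builder; it assumes the Iceberg runtime jar installed in the earlier steps is on the classpath):

    # Minimal sketch: a throwaway SparkSession configured like the conftest
    # change, exposing an Iceberg catalog named "local" backed by ./icehouse.
    from pyspark.sql import SparkSession

    spark = (
        SparkSession.builder.appName("iceberg-smoke-test")
        .config(
            "spark.sql.extensions",
            "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions",
        )
        .config("spark.sql.catalog.local", "org.apache.iceberg.spark.SparkCatalog")
        .config("spark.sql.catalog.local.type", "hadoop")
        .config("spark.sql.catalog.local.warehouse", "icehouse")
        .getOrCreate()
    )

    # DDL addressed to the "local" catalog is routed through Iceberg.
    spark.sql("CREATE NAMESPACE IF NOT EXISTS local.toot")
    print(spark.sql("SHOW NAMESPACES IN local").collect())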
diff --git a/ibis/backends/pyspark/tests/test_client.py b/ibis/backends/pyspark/tests/test_client.py
index bcbe74dc7e4e..6c977118cafd 100644
--- a/ibis/backends/pyspark/tests/test_client.py
+++ b/ibis/backends/pyspark/tests/test_client.py
@@ -10,25 +10,41 @@ def test_catalog_db_args(con, monkeypatch):
     monkeypatch.setattr(ibis.options, "default_backend", con)
     t = ibis.memtable({"epoch": [1712848119, 1712848121, 1712848155]})
 
+    assert con.current_catalog == "spark_catalog"
+    assert con.current_database == "ibis_testing"
+    con.create_database("toot", catalog="local")
+
     # create a table in specified catalog and db
-    con.create_table(
-        "t2", database=(con.current_catalog, "default"), obj=t, overwrite=True
-    )
+    con.create_table("t2", database=("local", "toot"), obj=t, overwrite=True)
+    con.create_table("t3", database=("spark_catalog", "default"), obj=t, overwrite=True)
+
+    assert con.current_database == "ibis_testing"
 
     assert "t2" not in con.list_tables()
-    assert "t2" in con.list_tables(database="default")
-    assert "t2" in con.list_tables(database="spark_catalog.default")
-    assert "t2" in con.list_tables(database=("spark_catalog", "default"))
+    assert "t2" in con.list_tables(database="local.toot")
+    assert "t2" in con.list_tables(database=("local", "toot"))
 
-    con.drop_table("t2", database="spark_catalog.default")
+    assert "t3" not in con.list_tables()
+    assert "t3" in con.list_tables(database="default")
+    assert "t3" in con.list_tables(database="spark_catalog.default")
 
-    assert "t2" not in con.list_tables(database="default")
+    con.drop_table("t2", database="local.toot")
+    con.drop_table("t3", database="spark_catalog.default")
+
+    assert "t2" not in con.list_tables(database="local.toot")
+    assert "t3" not in con.list_tables(database="spark_catalog.default")
+
+    con.drop_database("toot", catalog="local")
+
+    assert con.current_catalog == "spark_catalog"
+    assert con.current_database == "ibis_testing"
 
 
 def test_create_table_no_catalog(con, monkeypatch):
     monkeypatch.setattr(ibis.options, "default_backend", con)
     t = ibis.memtable({"epoch": [1712848119, 1712848121, 1712848155]})
 
+    assert con.current_database != "default"
     # create a table in specified catalog and db
     con.create_table("t2", database=("default"), obj=t, overwrite=True)
 
@@ -39,3 +55,4 @@
     con.drop_table("t2", database="default")
 
     assert "t2" not in con.list_tables(database="default")
+    assert con.current_database != "default"
diff --git a/poetry-overrides.nix b/poetry-overrides.nix
index e2e4d39a8cde..ee5829919c46 100644
--- a/poetry-overrides.nix
+++ b/poetry-overrides.nix
@@ -1 +1,16 @@
-_final: _prev: { }
+final: prev: {
+  pyspark = prev.pyspark.overridePythonAttrs (attrs:
+    let
+      icebergJarUrl = "https://search.maven.org/remotecontent?filepath=org/apache/iceberg/iceberg-spark-runtime-3.5_2.12/1.5.2/iceberg-spark-runtime-3.5_2.12-1.5.2.jar";
+      icebergJar = final.pkgs.fetchurl {
+        name = "iceberg-spark-runtime-3.5_2.12-1.5.2.jar";
+        url = icebergJarUrl;
+        sha256 = "12v1704h0bq3qr2fci0mckg9171lyr8v6983wpa83k06v1w4pv1a";
+      };
+    in
+    {
+      postInstall = attrs.postInstall or "" + ''
+        cp ${icebergJar} $out/${final.python.sitePackages}/pyspark/jars/${icebergJar.name}
+      '';
+    });
+}
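The heart of the backend change is the try-USE-CATALOG, fall-back-on-parse-error dance that the tests above exercise against both the Iceberg "local" catalog and spark_catalog. A stripped-down, illustrative version of that pattern (use_catalog is a made-up helper, not ibis API; it assumes pyspark >= 3.4 for setCurrentCatalog):

    # Illustrative sketch of the fallback used in _active_catalog: prefer the
    # Unity-style "USE CATALOG" statement, and fall back to the catalog API
    # when the SQL parser rejects it. Not part of ibis; `session` is assumed
    # to be an active SparkSession on pyspark >= 3.4.
    import sqlglot as sg

    try:
        from pyspark.errors import ParseException as PySparkParseException
    except ImportError:  # older pyspark keeps ParseException in pyspark.sql.utils
        from pyspark.sql.utils import ParseException as PySparkParseException


    def use_catalog(session, catalog: str) -> None:
        catalog_sql = sg.to_identifier(catalog).sql("spark")
        try:
            session.sql(f"USE CATALOG {catalog_sql}")
        except PySparkParseException:
            session.catalog.setCurrentCatalog(catalog)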