fix(pyspark): set catalog and database with USE instead of pyspark api (#9620)
gforsyth authored Jul 19, 2024
1 parent 66e150f commit 6991f04
Showing 5 changed files with 103 additions and 16 deletions.
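The core of the change is a try-SQL-first pattern: issue Unity-style `USE CATALOG` / `USE DATABASE` statements and fall back to the pyspark catalog API when the parser rejects them. Below is a minimal standalone sketch of that pattern (the function name, backtick quoting, and the bare SparkSession are illustrative assumptions, not code from this diff; `setCurrentCatalog` needs pyspark >= 3.4):

try:  # pyspark >= 3.4 exposes the exception here
    from pyspark.errors import ParseException
except ImportError:  # older pyspark
    from pyspark.sql.utils import ParseException

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()


def switch_catalog_and_database(catalog: str, database: str) -> None:
    # Try the Unity-style SQL first; plain OSS Spark rejects it with a
    # ParseException, in which case we fall back to the catalog API.
    try:
        spark.sql(f"USE CATALOG `{catalog}`")
    except ParseException:
        spark.catalog.setCurrentCatalog(catalog)
    try:
        spark.sql(f"USE DATABASE `{database}`")
    except ParseException:
        spark.catalog.setCurrentDatabase(database)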
5 changes: 5 additions & 0 deletions .github/workflows/ibis-backends.yml
@@ -743,6 +743,11 @@ jobs:
if: matrix.pyspark-version == '3.5'
run: poetry run pip install delta-spark

- name: install iceberg
shell: bash
if: matrix.pyspark-version == '3.5'
run: pushd "$(poetry run python -c "import pyspark; print(pyspark.__file__.rsplit('/', 1)[0])")/jars" && curl -LO https://search.maven.org/remotecontent?filepath=org/apache/iceberg/iceberg-spark-runtime-3.5_2.12/1.5.2/iceberg-spark-runtime-3.5_2.12-1.5.2.jar

- name: run tests
run: just ci-check -m pyspark

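The step above drops the Iceberg Spark runtime jar into pyspark's bundled jars directory so that the Iceberg catalog configured in the test suite can load. A quick local sanity check, assuming pyspark is installed in the active environment and the jar was downloaded the same way:

from pathlib import Path

import pyspark

# pyspark ships its jars next to the package; the CI step above pushd's into
# this directory before downloading the Iceberg runtime jar.
jars_dir = Path(pyspark.__file__).parent / "jars"
print(sorted(p.name for p in jars_dir.glob("iceberg-spark-runtime-*.jar")))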
54 changes: 47 additions & 7 deletions ibis/backends/pyspark/__init__.py
@@ -29,6 +29,11 @@
from ibis.legacy.udf.vectorized import _coerce_to_series
from ibis.util import deprecated

try:
from pyspark.errors import ParseException as PySparkParseException
except ImportError:
from pyspark.sql.utils import ParseException as PySparkParseException

if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from urllib.parse import ParseResult
@@ -226,27 +231,62 @@ def _active_catalog_database(self, catalog: str | None, db: str | None):
# 2. set database
# 3. set catalog to previous
# 4. set database to previous
#
# Unity catalog has special handling for "USE CATALOG" and "USE DATABASE"
# and also has weird permissioning around using `setCurrentCatalog` and
# `setCurrentDatabase`.
#
# We attempt to use the Unity-specific Spark SQL to set CATALOG and DATABASE
# and if that causes a parser exception we fall back to using the catalog API.
try:
if catalog is not None:
self._session.catalog.setCurrentCatalog(catalog)
self._session.catalog.setCurrentDatabase(db)
try:
catalog_sql = sg.to_identifier(catalog).sql(self.dialect)
self.raw_sql(f"USE CATALOG {catalog_sql}")
except PySparkParseException:
self._session.catalog.setCurrentCatalog(catalog)
try:
db_sql = sg.to_identifier(db).sql(self.dialect)
self.raw_sql(f"USE DATABASE {db_sql}")
except PySparkParseException:
self._session.catalog.setCurrentDatabase(db)
yield
finally:
if catalog is not None:
self._session.catalog.setCurrentCatalog(current_catalog)
self._session.catalog.setCurrentDatabase(current_db)
try:
catalog_sql = sg.to_identifier(current_catalog).sql(self.dialect)
self.raw_sql(f"USE CATALOG {catalog_sql}")
except PySparkParseException:
self._session.catalog.setCurrentCatalog(current_catalog)
try:
db_sql = sg.to_identifier(current_db).sql(self.dialect)
self.raw_sql(f"USE DATABASE {db_sql}")
except PySparkParseException:
self._session.catalog.setCurrentDatabase(current_db)

@contextlib.contextmanager
def _active_catalog(self, name: str | None):
if name is None or PYSPARK_LT_34:
yield
return
current = self.current_catalog
prev_catalog = self.current_catalog
prev_database = self.current_database
try:
self._session.catalog.setCurrentCatalog(name)
try:
catalog_sql = sg.to_identifier(name).sql(self.dialect)
self.raw_sql(f"USE CATALOG {catalog_sql};")
except PySparkParseException:
self._session.catalog.setCurrentCatalog(name)
yield
finally:
self._session.catalog.setCurrentCatalog(current)
try:
catalog_sql = sg.to_identifier(prev_catalog).sql(self.dialect)
db_sql = sg.to_identifier(prev_database).sql(self.dialect)
self.raw_sql(f"USE CATALOG {catalog_sql};")
self.raw_sql(f"USE DATABASE {db_sql};")
except PySparkParseException:
self._session.catalog.setCurrentCatalog(prev_catalog)
self._session.catalog.setCurrentDatabase(prev_database)

def list_catalogs(self, like: str | None = None) -> list[str]:
catalogs = [res.catalog for res in self._session.sql("SHOW CATALOGS").collect()]
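The user-facing API is unchanged; only the mechanism for switching and restoring the active catalog and database differs. A hedged end-to-end sketch (assumes a plain local SparkSession; `spark_catalog` and `default` are Spark's built-in catalog and database):

import ibis
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
con = ibis.pyspark.connect(session=spark)

before = (con.current_catalog, con.current_database)

# Internally this switches with "USE CATALOG ..." / "USE DATABASE ..." where
# the dialect accepts them (e.g. Unity Catalog), falls back to the pyspark
# catalog API on a ParseException, and restores the previous catalog and
# database afterwards.
con.create_table(
    "example",
    obj=ibis.memtable({"x": [1, 2, 3]}),
    database=("spark_catalog", "default"),
    overwrite=True,
)

assert (con.current_catalog, con.current_database) == before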
10 changes: 10 additions & 0 deletions ibis/backends/pyspark/tests/conftest.py
@@ -171,6 +171,16 @@ def connect(*, tmpdir, worker_id, **kw):
.config("spark.sql.streaming.schemaInference", True)
)

config = (
config.config(
"spark.sql.extensions",
"org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions",
)
.config("spark.sql.catalog.local", "org.apache.iceberg.spark.SparkCatalog")
.config("spark.sql.catalog.local.type", "hadoop")
.config("spark.sql.catalog.local.warehouse", "icehouse")
)

try:
from delta.pip_utils import configure_spark_with_delta_pip
except ImportError:
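For reference, an equivalent Iceberg-enabled session outside the test harness looks roughly like this (assumptions: the iceberg-spark-runtime jar is on the classpath as installed in the CI step above; the catalog name `local` and the warehouse path are arbitrary choices):

from pyspark.sql import SparkSession

spark = (
    SparkSession.builder.config(
        "spark.sql.extensions",
        "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions",
    )
    .config("spark.sql.catalog.local", "org.apache.iceberg.spark.SparkCatalog")
    .config("spark.sql.catalog.local.type", "hadoop")
    .config("spark.sql.catalog.local.warehouse", "/tmp/icehouse")
    .getOrCreate()
)

# Create a namespace and an Iceberg table in the "local" catalog.
spark.sql("CREATE NAMESPACE IF NOT EXISTS local.toot")
spark.sql("CREATE TABLE IF NOT EXISTS local.toot.t (x INT) USING iceberg")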
33 changes: 25 additions & 8 deletions ibis/backends/pyspark/tests/test_client.py
@@ -10,25 +10,41 @@ def test_catalog_db_args(con, monkeypatch):
monkeypatch.setattr(ibis.options, "default_backend", con)
t = ibis.memtable({"epoch": [1712848119, 1712848121, 1712848155]})

assert con.current_catalog == "spark_catalog"
assert con.current_database == "ibis_testing"
con.create_database("toot", catalog="local")

# create a table in specified catalog and db
con.create_table(
"t2", database=(con.current_catalog, "default"), obj=t, overwrite=True
)
con.create_table("t2", database=("local", "toot"), obj=t, overwrite=True)
con.create_table("t3", database=("spark_catalog", "default"), obj=t, overwrite=True)

assert con.current_database == "ibis_testing"

assert "t2" not in con.list_tables()
assert "t2" in con.list_tables(database="default")
assert "t2" in con.list_tables(database="spark_catalog.default")
assert "t2" in con.list_tables(database=("spark_catalog", "default"))
assert "t2" in con.list_tables(database="local.toot")
assert "t2" in con.list_tables(database=("local", "toot"))

con.drop_table("t2", database="spark_catalog.default")
assert "t3" not in con.list_tables()
assert "t3" in con.list_tables(database="default")
assert "t3" in con.list_tables(database="spark_catalog.default")

assert "t2" not in con.list_tables(database="default")
con.drop_table("t2", database="local.toot")
con.drop_table("t3", database="spark_catalog.default")

assert "t2" not in con.list_tables(database="local.toot")
assert "t3" not in con.list_tables(database="spark_catalog.default")

con.drop_database("toot", catalog="local")

assert con.current_catalog == "spark_catalog"
assert con.current_database == "ibis_testing"


def test_create_table_no_catalog(con, monkeypatch):
monkeypatch.setattr(ibis.options, "default_backend", con)
t = ibis.memtable({"epoch": [1712848119, 1712848121, 1712848155]})

assert con.current_database != "default"
# create a table in specified catalog and db
con.create_table("t2", database=("default"), obj=t, overwrite=True)

@@ -39,3 +55,4 @@ def test_create_table_no_catalog(con, monkeypatch):
con.drop_table("t2", database="default")

assert "t2" not in con.list_tables(database="default")
assert con.current_database != "default"
17 changes: 16 additions & 1 deletion poetry-overrides.nix
@@ -1 +1,16 @@
_final: _prev: { }
final: prev: {
pyspark = prev.pyspark.overridePythonAttrs (attrs:
let
icebergJarUrl = "https://search.maven.org/remotecontent?filepath=org/apache/iceberg/iceberg-spark-runtime-3.5_2.12/1.5.2/iceberg-spark-runtime-3.5_2.12-1.5.2.jar";
icebergJar = final.pkgs.fetchurl {
name = "iceberg-spark-runtime-3.5_2.12-1.5.2.jar";
url = icebergJarUrl;
sha256 = "12v1704h0bq3qr2fci0mckg9171lyr8v6983wpa83k06v1w4pv1a";
};
in
{
postInstall = attrs.postInstall or "" + ''
cp ${icebergJar} $out/${final.python.sitePackages}/pyspark/jars/${icebergJar.name}
'';
});
}
