fix(pyspark): try USE [CATALOG|DATABASE] before pyspark api
The Unity Catalog on Databricks has weird permissioning where it doesn't
allow (at least one) user to run `setCurrentCatalog` or `setCurrentDatabase`.

It _does_ allow setting those via `USE CATALOG mycat;` and `USE DATABASE
mydb;`.

These statements, however, are not part of standard Spark SQL; they are
Databricks-specific Spark SQL.

So we try the Databricks-specific SQL first, and if that throws a parser
error, we fall back to the PySpark API methods.
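As a standalone sketch, the fallback pattern looks like the following (the
`use_catalog` helper and its `session`, `raw_sql`, and `dialect` parameters
are hypothetical stand-ins for the backend's attributes, not part of this
change):

import sqlglot as sg
from pyspark.errors import ParseException as PySparkParseException


def use_catalog(session, raw_sql, dialect, catalog: str) -> None:
    # Prefer the Databricks-only `USE CATALOG` statement; Unity Catalog
    # permits it even when `setCurrentCatalog` is forbidden.
    try:
        catalog_sql = sg.to_identifier(catalog).sql(dialect)
        raw_sql(f"USE CATALOG {catalog_sql}")
    except PySparkParseException:
        # Stock Spark SQL does not parse `USE CATALOG`, so the resulting
        # parser error routes us to the pyspark catalog API instead.
        session.catalog.setCurrentCatalog(catalog)

On Databricks the SQL path succeeds; elsewhere the unsupported syntax raises a
`ParseException` and the pyspark API call runs instead.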
gforsyth committed Jul 18, 2024
1 parent 4e3fb6a commit 42b8f06
Showing 2 changed files with 63 additions and 14 deletions.
44 changes: 38 additions & 6 deletions ibis/backends/pyspark/__init__.py
@@ -11,6 +11,7 @@
 import sqlglot.expressions as sge
 from packaging.version import parse as vparse
 from pyspark import SparkConf
+from pyspark.errors import ParseException as PySparkParseException
 from pyspark.sql import SparkSession
 from pyspark.sql.types import BooleanType, DoubleType, LongType, StringType

@@ -219,15 +220,38 @@ def _active_catalog_database(self, catalog: str | None, db: str | None):
         # 2. set database
         # 3. set catalog to previous
         # 4. set database to previous
+        #
+        # Unity catalog has special handling for "USE CATALOG" and "USE DATABASE"
+        # and also has weird permissioning around using `setCurrentCatalog` and
+        # `setCurrentDatabase`.
+        #
+        # We attempt to use the Unity-specific Spark SQL to set CATALOG and DATABASE
+        # and if that causes a parser exception we fall back to using the catalog API.
         try:
             if catalog is not None:
-                self._session.catalog.setCurrentCatalog(catalog)
-            self._session.catalog.setCurrentDatabase(db)
+                try:
+                    catalog_sql = sg.to_identifier(catalog).sql(self.dialect)
+                    self.raw_sql(f"USE CATALOG {catalog_sql}")
+                except PySparkParseException:
+                    self._session.catalog.setCurrentCatalog(catalog)
+            try:
+                db_sql = sg.to_identifier(db).sql(self.dialect)
+                self.raw_sql(f"USE DATABASE {db_sql}")
+            except PySparkParseException:
+                self._session.catalog.setCurrentDatabase(db)
             yield
         finally:
             if catalog is not None:
-                self._session.catalog.setCurrentCatalog(current_catalog)
-            self._session.catalog.setCurrentDatabase(current_db)
+                try:
+                    catalog_sql = sg.to_identifier(current_catalog).sql(self.dialect)
+                    self.raw_sql(f"USE CATALOG {catalog_sql}")
+                except PySparkParseException:
+                    self._session.catalog.setCurrentCatalog(current_catalog)
+            try:
+                db_sql = sg.to_identifier(current_db).sql(self.dialect)
+                self.raw_sql(f"USE DATABASE {db_sql}")
+            except PySparkParseException:
+                self._session.catalog.setCurrentDatabase(current_db)
 
     @contextlib.contextmanager
     def _active_catalog(self, name: str | None):
@@ -236,10 +260,18 @@ def _active_catalog(self, name: str | None):
             return
         current = self.current_catalog
         try:
-            self._session.catalog.setCurrentCatalog(name)
+            try:
+                catalog_sql = sg.to_identifier(name).sql(self.dialect)
+                self.raw_sql(f"USE CATALOG {catalog_sql};")
+            except PySparkParseException:
+                self._session.catalog.setCurrentCatalog(name)
             yield
         finally:
-            self._session.catalog.setCurrentCatalog(current)
+            try:
+                catalog_sql = sg.to_identifier(current).sql(self.dialect)
+                self.raw_sql(f"USE CATALOG {catalog_sql};")
+            except PySparkParseException:
+                self._session.catalog.setCurrentCatalog(current)
 
     def list_catalogs(self, like: str | None = None) -> list[str]:
         catalogs = [res.catalog for res in self._session.sql("SHOW CATALOGS").collect()]
33 changes: 25 additions & 8 deletions ibis/backends/pyspark/tests/test_client.py
@@ -10,25 +10,41 @@ def test_catalog_db_args(con, monkeypatch):
     monkeypatch.setattr(ibis.options, "default_backend", con)
     t = ibis.memtable({"epoch": [1712848119, 1712848121, 1712848155]})
 
     assert con.current_catalog == "spark_catalog"
     assert con.current_database == "ibis_testing"
+
+    con.create_database("toot", catalog="local")
 
     # create a table in specified catalog and db
-    con.create_table(
-        "t2", database=(con.current_catalog, "default"), obj=t, overwrite=True
-    )
+    con.create_table("t2", database=("local", "toot"), obj=t, overwrite=True)
+    con.create_table("t3", database=("spark_catalog", "default"), obj=t, overwrite=True)
 
     assert con.current_database == "ibis_testing"
 
     assert "t2" not in con.list_tables()
-    assert "t2" in con.list_tables(database="default")
-    assert "t2" in con.list_tables(database="spark_catalog.default")
-    assert "t2" in con.list_tables(database=("spark_catalog", "default"))
+    assert "t2" in con.list_tables(database="local.toot")
+    assert "t2" in con.list_tables(database=("local", "toot"))
 
-    con.drop_table("t2", database="spark_catalog.default")
+    assert "t3" not in con.list_tables()
+    assert "t3" in con.list_tables(database="default")
+    assert "t3" in con.list_tables(database="spark_catalog.default")
 
-    assert "t2" not in con.list_tables(database="default")
+    con.drop_table("t2", database="local.toot")
+    con.drop_table("t3", database="spark_catalog.default")
+
+    assert "t2" not in con.list_tables(database="local.toot")
+    assert "t3" not in con.list_tables(database="spark_catalog.default")
+
+    con.drop_database("toot", catalog="local")
+
+    assert con.current_catalog == "spark_catalog"
+    assert con.current_database == "ibis_testing"
 
 
 def test_create_table_no_catalog(con, monkeypatch):
     monkeypatch.setattr(ibis.options, "default_backend", con)
     t = ibis.memtable({"epoch": [1712848119, 1712848121, 1712848155]})
 
     assert con.current_database != "default"
     # create a table in specified catalog and db
     con.create_table("t2", database=("default"), obj=t, overwrite=True)

@@ -39,3 +55,4 @@ def test_create_table_no_catalog(con, monkeypatch):
     con.drop_table("t2", database="default")
 
     assert "t2" not in con.list_tables(database="default")
+    assert con.current_database != "default"
