Skip to content

Commit

Permalink
fix(pyspark): restore database in catalog setter context mgr
Browse files Browse the repository at this point in the history
If you call `setCurrentCatalog`, Spark will default to using the
`default` database in that catalog.  So we need to note which database
was being used so we can switch back to it correctly.
  • Loading branch information
gforsyth committed Jul 18, 2024
1 parent d1493f7 commit bd2a99a
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions ibis/backends/pyspark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ def _active_catalog(self, name: str | None):
if name is None or PYSPARK_LT_34:
yield
return
current = self.current_catalog
prev_catalog = self.current_catalog
prev_database = self.current_database
try:
try:
catalog_sql = sg.to_identifier(name).sql(self.dialect)
Expand All @@ -268,10 +269,13 @@ def _active_catalog(self, name: str | None):
yield
finally:
try:
catalog_sql = sg.to_identifier(current).sql(self.dialect)
catalog_sql = sg.to_identifier(prev_catalog).sql(self.dialect)
db_sql = sg.to_identifier(prev_database).sql(self.dialect)
self.raw_sql(f"USE CATALOG {catalog_sql};")
self.raw_sql(f"USE DATABASE {db_sql};")
except PySparkParseException:
self._session.catalog.setCurrentCatalog(current)
self._session.catalog.setCurrentCatalog(prev_catalog)
self._session.catalog.setCurrentDatabase(prev_database)

def list_catalogs(self, like: str | None = None) -> list[str]:
catalogs = [res.catalog for res in self._session.sql("SHOW CATALOGS").collect()]
Expand Down

0 comments on commit bd2a99a

Please sign in to comment.