From bd2a99af2c41e453824bd20367e36206dbef22cd Mon Sep 17 00:00:00 2001
From: Gil Forsyth
Date: Thu, 18 Jul 2024 16:02:47 -0400
Subject: [PATCH] fix(pyspark): restore database in catalog setter context mgr

If you call `setCurrentCatalog`, Spark will default to using the
`default` database in that catalog, so we need to note which database
was in use and switch back to it on exit.
---
 ibis/backends/pyspark/__init__.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/ibis/backends/pyspark/__init__.py b/ibis/backends/pyspark/__init__.py
index 32fa96a5e1b1..e79a4a6cf4b9 100644
--- a/ibis/backends/pyspark/__init__.py
+++ b/ibis/backends/pyspark/__init__.py
@@ -258,7 +258,8 @@ def _active_catalog(self, name: str | None):
         if name is None or PYSPARK_LT_34:
             yield
             return
-        current = self.current_catalog
+        prev_catalog = self.current_catalog
+        prev_database = self.current_database
         try:
             try:
                 catalog_sql = sg.to_identifier(name).sql(self.dialect)
@@ -268,10 +269,13 @@
             yield
         finally:
             try:
-                catalog_sql = sg.to_identifier(current).sql(self.dialect)
+                catalog_sql = sg.to_identifier(prev_catalog).sql(self.dialect)
+                db_sql = sg.to_identifier(prev_database).sql(self.dialect)
                 self.raw_sql(f"USE CATALOG {catalog_sql};")
+                self.raw_sql(f"USE DATABASE {db_sql};")
             except PySparkParseException:
-                self._session.catalog.setCurrentCatalog(current)
+                self._session.catalog.setCurrentCatalog(prev_catalog)
+                self._session.catalog.setCurrentDatabase(prev_database)
 
     def list_catalogs(self, like: str | None = None) -> list[str]:
         catalogs = [res.catalog for res in self._session.sql("SHOW CATALOGS").collect()]
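
Below is a minimal standalone sketch (not part of the patch) of the save-and-restore pattern the change implements, written against the public PySpark catalog API and assuming PySpark >= 3.4, where `currentCatalog`/`setCurrentCatalog` exist. The `active_catalog` helper and the `spark_catalog` name are illustrative assumptions, not ibis code.

```python
import contextlib

from pyspark.sql import SparkSession


@contextlib.contextmanager
def active_catalog(spark: SparkSession, name: str):
    """Temporarily switch catalogs, restoring catalog *and* database on exit."""
    # Remember both: switching catalogs makes Spark fall back to the
    # `default` database in the target catalog, so the previously active
    # database must be restored explicitly afterwards.
    prev_catalog = spark.catalog.currentCatalog()
    prev_database = spark.catalog.currentDatabase()
    spark.catalog.setCurrentCatalog(name)
    try:
        yield
    finally:
        spark.catalog.setCurrentCatalog(prev_catalog)
        spark.catalog.setCurrentDatabase(prev_database)


if __name__ == "__main__":
    spark = SparkSession.builder.getOrCreate()
    with active_catalog(spark, "spark_catalog"):
        spark.sql("SHOW TABLES").show()
    # After the block, both the previous catalog and database are active again.
```

The patch itself prefers `USE CATALOG` / `USE DATABASE` SQL and only falls back to the `setCurrentCatalog` / `setCurrentDatabase` API calls when the SQL path raises `PySparkParseException`; the sketch above shows only the API-based restore for brevity.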