From 23c0e810419c227455125c3cfd233c01665fe935 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 10 Nov 2024 11:08:03 -0500 Subject: [PATCH] fix(mysql): handle database names that must be quoted in `list_tables` (#10466) --- compose.yaml | 2 +- ibis/backends/mysql/__init__.py | 11 ++++++----- ibis/backends/mysql/tests/conftest.py | 2 +- ibis/backends/sql/__init__.py | 10 +++++----- ibis/backends/tests/test_client.py | 2 +- 5 files changed, 14 insertions(+), 13 deletions(-) diff --git a/compose.yaml b/compose.yaml index 9024dbc0efa7..67e5298f28c0 100644 --- a/compose.yaml +++ b/compose.yaml @@ -18,7 +18,7 @@ services: mysql: environment: MYSQL_ALLOW_EMPTY_PASSWORD: "true" - MYSQL_DATABASE: ibis_testing + MYSQL_DATABASE: ibis-testing MYSQL_PASSWORD: ibis MYSQL_USER: ibis healthcheck: diff --git a/ibis/backends/mysql/__init__.py b/ibis/backends/mysql/__init__.py index b27fdb4b70fe..ed1219cea95b 100644 --- a/ibis/backends/mysql/__init__.py +++ b/ibis/backends/mysql/__init__.py @@ -122,7 +122,7 @@ def do_connect( >>> host = os.environ.get("IBIS_TEST_MYSQL_HOST", "localhost") >>> user = os.environ.get("IBIS_TEST_MYSQL_USER", "ibis") >>> password = os.environ.get("IBIS_TEST_MYSQL_PASSWORD", "ibis") - >>> database = os.environ.get("IBIS_TEST_MYSQL_DATABASE", "ibis_testing") + >>> database = os.environ.get("IBIS_TEST_MYSQL_DATABASE", "ibis-testing") >>> con = ibis.mysql.connect(database=database, host=host, user=user, password=password) >>> con.list_tables() # doctest: +ELLIPSIS [...] @@ -337,11 +337,12 @@ def list_tables( the current database (`self.current_database`). """ if database is not None: - table_loc = database + table_loc = self._to_sqlglot_table(database) else: - table_loc = self.current_database - - table_loc = self._to_sqlglot_table(table_loc) + table_loc = sge.Table( + db=sg.to_identifier(self.current_database, quoted=self.compiler.quoted), + catalog=None, + ) conditions = [TRUE] diff --git a/ibis/backends/mysql/tests/conftest.py b/ibis/backends/mysql/tests/conftest.py index f343dd2aa211..b8e4f5734d34 100644 --- a/ibis/backends/mysql/tests/conftest.py +++ b/ibis/backends/mysql/tests/conftest.py @@ -17,7 +17,7 @@ MYSQL_PASS = os.environ.get("IBIS_TEST_MYSQL_PASSWORD", "ibis") MYSQL_HOST = os.environ.get("IBIS_TEST_MYSQL_HOST", "localhost") MYSQL_PORT = int(os.environ.get("IBIS_TEST_MYSQL_PORT", 3306)) -IBIS_TEST_MYSQL_DB = os.environ.get("IBIS_TEST_MYSQL_DATABASE", "ibis_testing") +IBIS_TEST_MYSQL_DB = os.environ.get("IBIS_TEST_MYSQL_DATABASE", "ibis-testing") class TestConf(ServiceBackendTest): diff --git a/ibis/backends/sql/__init__.py b/ibis/backends/sql/__init__.py index 297329eb40ce..96e96c2f4f07 100644 --- a/ibis/backends/sql/__init__.py +++ b/ibis/backends/sql/__init__.py @@ -239,7 +239,7 @@ def drop_table( table_loc = self._to_sqlglot_table(database) catalog, db = self._to_catalog_db_tuple(table_loc) - drop_stmt = sg.exp.Drop( + drop_stmt = sge.Drop( kind="TABLE", this=sg.table(name, db=db, catalog=catalog, quoted=self.compiler.quoted), exists=force, @@ -509,7 +509,7 @@ def _to_catalog_db_tuple(self, table_loc: sge.Table): def _to_sqlglot_table(self, database): if database is None: # Create "table" with empty catalog and db - database = sg.exp.Table(catalog=None, db=None) + database = sge.Table(catalog=None, db=None) elif isinstance(database, (list, tuple)): if len(database) > 2: raise ValueError( @@ -528,7 +528,7 @@ def _to_sqlglot_table(self, database): '\n("catalog", "database")' '\n("database",)' ) - database = sg.exp.Table( + database = sge.Table( catalog=sg.to_identifier(catalog, quoted=self.compiler.quoted), db=sg.to_identifier(database, quoted=self.compiler.quoted), ) @@ -538,14 +538,14 @@ def _to_sqlglot_table(self, database): # sqlglot parsing of the string will assume that it's a Table # so we unpack the arguments into a new sqlglot object, switching # table (this) -> database (db) and database (db) -> catalog - table = sg.parse_one(database, into=sg.exp.Table, dialect=self.dialect) + table = sg.parse_one(database, into=sge.Table, dialect=self.dialect) if table.args["catalog"] is not None: raise exc.IbisInputError( f"Overspecified table hierarchy provided: `{table.sql(self.dialect)}`" ) catalog = table.args["db"] db = table.args["this"] - database = sg.exp.Table(catalog=catalog, db=db) + database = sge.Table(catalog=catalog, db=db) else: raise ValueError( """Invalid database hierarchy format. Please use either dotted diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index 743ddf440613..3b1164426d60 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -676,7 +676,7 @@ def test_list_database_contents(con): "flink": {"default_database"}, "impala": {"ibis_testing", "default", "_impala_builtins"}, "mssql": {"INFORMATION_SCHEMA", "dbo", "guest"}, - "mysql": {"ibis_testing", "information_schema"}, + "mysql": {"ibis-testing", "information_schema"}, "oracle": {"SYS", "IBIS"}, "postgres": {"public", "information_schema"}, "pyspark": set(),