Skip to content

Commit

Permalink
fix(mysql): handle database names that must be quoted in list_tables (
Browse files Browse the repository at this point in the history
cpcloud authored Nov 10, 2024
1 parent 860b9ca commit 23c0e81
Showing 5 changed files with 14 additions and 13 deletions.
2 changes: 1 addition & 1 deletion compose.yaml
Original file line number Diff line number Diff line change
@@ -18,7 +18,7 @@ services:
mysql:
environment:
MYSQL_ALLOW_EMPTY_PASSWORD: "true"
MYSQL_DATABASE: ibis_testing
MYSQL_DATABASE: ibis-testing
MYSQL_PASSWORD: ibis
MYSQL_USER: ibis
healthcheck:
11 changes: 6 additions & 5 deletions ibis/backends/mysql/__init__.py
Original file line number Diff line number Diff line change
@@ -122,7 +122,7 @@ def do_connect(
>>> host = os.environ.get("IBIS_TEST_MYSQL_HOST", "localhost")
>>> user = os.environ.get("IBIS_TEST_MYSQL_USER", "ibis")
>>> password = os.environ.get("IBIS_TEST_MYSQL_PASSWORD", "ibis")
>>> database = os.environ.get("IBIS_TEST_MYSQL_DATABASE", "ibis_testing")
>>> database = os.environ.get("IBIS_TEST_MYSQL_DATABASE", "ibis-testing")
>>> con = ibis.mysql.connect(database=database, host=host, user=user, password=password)
>>> con.list_tables() # doctest: +ELLIPSIS
[...]
@@ -337,11 +337,12 @@ def list_tables(
the current database (`self.current_database`).
"""
if database is not None:
table_loc = database
table_loc = self._to_sqlglot_table(database)
else:
table_loc = self.current_database

table_loc = self._to_sqlglot_table(table_loc)
table_loc = sge.Table(
db=sg.to_identifier(self.current_database, quoted=self.compiler.quoted),
catalog=None,
)

conditions = [TRUE]

2 changes: 1 addition & 1 deletion ibis/backends/mysql/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@
MYSQL_PASS = os.environ.get("IBIS_TEST_MYSQL_PASSWORD", "ibis")
MYSQL_HOST = os.environ.get("IBIS_TEST_MYSQL_HOST", "localhost")
MYSQL_PORT = int(os.environ.get("IBIS_TEST_MYSQL_PORT", 3306))
IBIS_TEST_MYSQL_DB = os.environ.get("IBIS_TEST_MYSQL_DATABASE", "ibis_testing")
IBIS_TEST_MYSQL_DB = os.environ.get("IBIS_TEST_MYSQL_DATABASE", "ibis-testing")


class TestConf(ServiceBackendTest):
10 changes: 5 additions & 5 deletions ibis/backends/sql/__init__.py
Original file line number Diff line number Diff line change
@@ -239,7 +239,7 @@ def drop_table(
table_loc = self._to_sqlglot_table(database)
catalog, db = self._to_catalog_db_tuple(table_loc)

drop_stmt = sg.exp.Drop(
drop_stmt = sge.Drop(
kind="TABLE",
this=sg.table(name, db=db, catalog=catalog, quoted=self.compiler.quoted),
exists=force,
@@ -509,7 +509,7 @@ def _to_catalog_db_tuple(self, table_loc: sge.Table):
def _to_sqlglot_table(self, database):
if database is None:
# Create "table" with empty catalog and db
database = sg.exp.Table(catalog=None, db=None)
database = sge.Table(catalog=None, db=None)
elif isinstance(database, (list, tuple)):
if len(database) > 2:
raise ValueError(
@@ -528,7 +528,7 @@ def _to_sqlglot_table(self, database):
'\n("catalog", "database")'
'\n("database",)'
)
database = sg.exp.Table(
database = sge.Table(
catalog=sg.to_identifier(catalog, quoted=self.compiler.quoted),
db=sg.to_identifier(database, quoted=self.compiler.quoted),
)
@@ -538,14 +538,14 @@ def _to_sqlglot_table(self, database):
# sqlglot parsing of the string will assume that it's a Table
# so we unpack the arguments into a new sqlglot object, switching
# table (this) -> database (db) and database (db) -> catalog
table = sg.parse_one(database, into=sg.exp.Table, dialect=self.dialect)
table = sg.parse_one(database, into=sge.Table, dialect=self.dialect)
if table.args["catalog"] is not None:
raise exc.IbisInputError(
f"Overspecified table hierarchy provided: `{table.sql(self.dialect)}`"
)
catalog = table.args["db"]
db = table.args["this"]
database = sg.exp.Table(catalog=catalog, db=db)
database = sge.Table(catalog=catalog, db=db)
else:
raise ValueError(
"""Invalid database hierarchy format. Please use either dotted
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_client.py
Original file line number Diff line number Diff line change
@@ -676,7 +676,7 @@ def test_list_database_contents(con):
"flink": {"default_database"},
"impala": {"ibis_testing", "default", "_impala_builtins"},
"mssql": {"INFORMATION_SCHEMA", "dbo", "guest"},
"mysql": {"ibis_testing", "information_schema"},
"mysql": {"ibis-testing", "information_schema"},
"oracle": {"SYS", "IBIS"},
"postgres": {"public", "information_schema"},
"pyspark": set(),

0 comments on commit 23c0e81

Please sign in to comment.