Skip to content

Commit

Permalink
Use sqlalchemy's 'quote' function to quote table names
Browse files Browse the repository at this point in the history
  • Loading branch information
seut committed Jun 14, 2024
1 parent 6be34c3 commit a778756
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 25 deletions.
5 changes: 3 additions & 2 deletions cratedb_toolkit/testing/testcontainers/cratedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

from cratedb_toolkit.testing.testcontainers.util import KeepaliveContainer, asbool
from cratedb_toolkit.util import DatabaseAdapter
from cratedb_toolkit.util.database import quote_table_name

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -189,7 +188,9 @@ def reset(self, tables: Optional[list] = None):
"""
if tables and self.database:
for reset_table in tables:
self.database.connection.exec_driver_sql(f"DROP TABLE IF EXISTS {quote_table_name(reset_table)};")
self.database.connection.exec_driver_sql(
f"DROP TABLE IF EXISTS {self.database.quote_relation_name(reset_table)};"
)

def get_connection_url(self, *args, **kwargs):
"""
Expand Down
60 changes: 37 additions & 23 deletions cratedb_toolkit/util/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,38 @@ def __init__(self, dburi: str, echo: bool = False):
self.engine = sa.create_engine(self.dburi, echo=echo)
self.connection = self.engine.connect()

def quote_relation_name(self, ident: str) -> str:
"""
Quote the given, possibly full-qualified, relation name if needed.
In: foo
Out: foo
In: Foo
Out: "Foo"
In: "Foo"
Out: "Foo"
In: foo.bar
Out: "foo"."bar"
In: "foo.bar"
Out: "foo.bar"
"""
if ident[0] == '"' and ident[len(ident) - 1] == '"':
return ident
if "." in ident:
parts = ident.split(".")
if len(parts) > 2:
raise ValueError(f"Invalid relation name {ident}")
return (
self.engine.dialect.identifier_preparer.quote_schema(parts[0])
+ "."
+ self.engine.dialect.identifier_preparer.quote(parts[1])
)
return self.engine.dialect.identifier_preparer.quote(ident=ident)

def run_sql(self, sql: t.Union[str, Path, io.IOBase], records: bool = False, ignore: str = None):
"""
Run SQL statement, and return results, optionally ignoring exceptions.
Expand Down Expand Up @@ -82,7 +114,7 @@ def count_records(self, name: str, errors: Literal["raise", "ignore"] = "raise")
"""
Return number of records in table.
"""
sql = f"SELECT COUNT(*) AS count FROM {quote_table_name(name)};" # noqa: S608
sql = f"SELECT COUNT(*) AS count FROM {self.quote_relation_name(name)};" # noqa: S608
try:
results = self.run_sql(sql=sql)
except ProgrammingError as ex:
Expand All @@ -96,7 +128,7 @@ def table_exists(self, name: str) -> bool:
"""
Check whether given table exists.
"""
sql = f"SELECT 1 FROM {quote_table_name(name)} LIMIT 1;" # noqa: S608
sql = f"SELECT 1 FROM {self.quote_relation_name(name)} LIMIT 1;" # noqa: S608
try:
self.run_sql(sql=sql)
return True
Expand All @@ -107,15 +139,15 @@ def refresh_table(self, name: str):
"""
Run a `REFRESH TABLE ...` command.
"""
sql = f"REFRESH TABLE {quote_table_name(name)};" # noqa: S608
sql = f"REFRESH TABLE {self.quote_relation_name(name)};" # noqa: S608
self.run_sql(sql=sql)
return True

def prune_table(self, name: str, errors: Literal["raise", "ignore"] = "raise"):
"""
Run a `DELETE FROM ...` command.
"""
sql = f"DELETE FROM {quote_table_name(name)};" # noqa: S608
sql = f"DELETE FROM {self.quote_relation_name(name)};" # noqa: S608
try:
self.run_sql(sql=sql)
except ProgrammingError as ex:
Expand All @@ -129,7 +161,7 @@ def drop_table(self, name: str):
"""
Run a `DROP TABLE ...` command.
"""
sql = f"DROP TABLE IF EXISTS {quote_table_name(name)};" # noqa: S608
sql = f"DROP TABLE IF EXISTS {self.quote_relation_name(name)};" # noqa: S608
self.run_sql(sql=sql)
return True

Expand Down Expand Up @@ -332,21 +364,3 @@ def decode_database_table(url: str) -> t.Tuple[str, str]:
if url_.scheme == "crate" and not database:
database = url_.query_params.get("schema")
return database, table


def quote_table_name(name: str) -> str:
"""
Quote table name if not happened already.
In: foo
Out: "foo"
In: "foo"
Out: "foo"
In: foo.bar
Out: foo.bar
"""
if '"' not in name and "." not in name:
name = f'"{name}"'
return name
23 changes: 23 additions & 0 deletions tests/util/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest

from cratedb_toolkit.util import DatabaseAdapter


def test_quote_relation_name():
database = DatabaseAdapter(dburi="crate://localhost")
assert database.quote_relation_name("my_table") == "my_table"
assert database.quote_relation_name("my-table") == '"my-table"'
assert database.quote_relation_name("MyTable") == '"MyTable"'
assert database.quote_relation_name('"MyTable"') == '"MyTable"'
assert database.quote_relation_name("my_schema.my_table") == "my_schema.my_table"
assert database.quote_relation_name("my-schema.my_table") == '"my-schema".my_table'
assert database.quote_relation_name('"wrong-quoted-fqn.my_table"') == '"wrong-quoted-fqn.my_table"'
assert database.quote_relation_name('"my_schema"."my_table"') == '"my_schema"."my_table"'
# reserved keyword must be quoted
assert database.quote_relation_name("table") == '"table"'


def test_quote_relation_name_with_invalid_fqn():
database = DatabaseAdapter(dburi="crate://localhost")
with pytest.raises(ValueError):
database.quote_relation_name("my-db.my-schema.my-table")

0 comments on commit a778756

Please sign in to comment.