diff --git a/cratedb_toolkit/testing/testcontainers/cratedb.py b/cratedb_toolkit/testing/testcontainers/cratedb.py index 3b7433e1..031e3d3d 100644 --- a/cratedb_toolkit/testing/testcontainers/cratedb.py +++ b/cratedb_toolkit/testing/testcontainers/cratedb.py @@ -24,7 +24,6 @@ from cratedb_toolkit.testing.testcontainers.util import KeepaliveContainer, asbool from cratedb_toolkit.util import DatabaseAdapter -from cratedb_toolkit.util.database import quote_table_name logger = logging.getLogger(__name__) @@ -189,7 +188,9 @@ def reset(self, tables: Optional[list] = None): """ if tables and self.database: for reset_table in tables: - self.database.connection.exec_driver_sql(f"DROP TABLE IF EXISTS {quote_table_name(reset_table)};") + self.database.connection.exec_driver_sql( + f"DROP TABLE IF EXISTS {self.database.quote_relation_name(reset_table)};" + ) def get_connection_url(self, *args, **kwargs): """ diff --git a/cratedb_toolkit/util/database.py b/cratedb_toolkit/util/database.py index 1a40061d..e44f62c2 100644 --- a/cratedb_toolkit/util/database.py +++ b/cratedb_toolkit/util/database.py @@ -33,6 +33,38 @@ def __init__(self, dburi: str, echo: bool = False): self.engine = sa.create_engine(self.dburi, echo=echo) self.connection = self.engine.connect() + def quote_relation_name(self, ident: str) -> str: + """ + Quote the given, possibly full-qualified, relation name if needed. + + In: foo + Out: foo + + In: Foo + Out: "Foo" + + In: "Foo" + Out: "Foo" + + In: foo.bar + Out: "foo"."bar" + + In: "foo.bar" + Out: "foo.bar" + """ + if ident[0] == '"' and ident[len(ident) - 1] == '"': + return ident + if "." in ident: + parts = ident.split(".") + if len(parts) > 2: + raise ValueError(f"Invalid relation name {ident}") + return ( + self.engine.dialect.identifier_preparer.quote_schema(parts[0]) + + "." + + self.engine.dialect.identifier_preparer.quote(parts[1]) + ) + return self.engine.dialect.identifier_preparer.quote(ident=ident) + def run_sql(self, sql: t.Union[str, Path, io.IOBase], records: bool = False, ignore: str = None): """ Run SQL statement, and return results, optionally ignoring exceptions. @@ -82,7 +114,7 @@ def count_records(self, name: str, errors: Literal["raise", "ignore"] = "raise") """ Return number of records in table. """ - sql = f"SELECT COUNT(*) AS count FROM {quote_table_name(name)};" # noqa: S608 + sql = f"SELECT COUNT(*) AS count FROM {self.quote_relation_name(name)};" # noqa: S608 try: results = self.run_sql(sql=sql) except ProgrammingError as ex: @@ -96,7 +128,7 @@ def table_exists(self, name: str) -> bool: """ Check whether given table exists. """ - sql = f"SELECT 1 FROM {quote_table_name(name)} LIMIT 1;" # noqa: S608 + sql = f"SELECT 1 FROM {self.quote_relation_name(name)} LIMIT 1;" # noqa: S608 try: self.run_sql(sql=sql) return True @@ -107,7 +139,7 @@ def refresh_table(self, name: str): """ Run a `REFRESH TABLE ...` command. """ - sql = f"REFRESH TABLE {quote_table_name(name)};" # noqa: S608 + sql = f"REFRESH TABLE {self.quote_relation_name(name)};" # noqa: S608 self.run_sql(sql=sql) return True @@ -115,7 +147,7 @@ def prune_table(self, name: str, errors: Literal["raise", "ignore"] = "raise"): """ Run a `DELETE FROM ...` command. """ - sql = f"DELETE FROM {quote_table_name(name)};" # noqa: S608 + sql = f"DELETE FROM {self.quote_relation_name(name)};" # noqa: S608 try: self.run_sql(sql=sql) except ProgrammingError as ex: @@ -129,7 +161,7 @@ def drop_table(self, name: str): """ Run a `DROP TABLE ...` command. """ - sql = f"DROP TABLE IF EXISTS {quote_table_name(name)};" # noqa: S608 + sql = f"DROP TABLE IF EXISTS {self.quote_relation_name(name)};" # noqa: S608 self.run_sql(sql=sql) return True @@ -332,21 +364,3 @@ def decode_database_table(url: str) -> t.Tuple[str, str]: if url_.scheme == "crate" and not database: database = url_.query_params.get("schema") return database, table - - -def quote_table_name(name: str) -> str: - """ - Quote table name if not happened already. - - In: foo - Out: "foo" - - In: "foo" - Out: "foo" - - In: foo.bar - Out: foo.bar - """ - if '"' not in name and "." not in name: - name = f'"{name}"' - return name diff --git a/tests/util/database.py b/tests/util/database.py new file mode 100644 index 00000000..3d8cd60c --- /dev/null +++ b/tests/util/database.py @@ -0,0 +1,23 @@ +import pytest + +from cratedb_toolkit.util import DatabaseAdapter + + +def test_quote_relation_name(): + database = DatabaseAdapter(dburi="crate://localhost") + assert database.quote_relation_name("my_table") == "my_table" + assert database.quote_relation_name("my-table") == '"my-table"' + assert database.quote_relation_name("MyTable") == '"MyTable"' + assert database.quote_relation_name('"MyTable"') == '"MyTable"' + assert database.quote_relation_name("my_schema.my_table") == "my_schema.my_table" + assert database.quote_relation_name("my-schema.my_table") == '"my-schema".my_table' + assert database.quote_relation_name('"wrong-quoted-fqn.my_table"') == '"wrong-quoted-fqn.my_table"' + assert database.quote_relation_name('"my_schema"."my_table"') == '"my_schema"."my_table"' + # reserved keyword must be quoted + assert database.quote_relation_name("table") == '"table"' + + +def test_quote_relation_name_with_invalid_fqn(): + database = DatabaseAdapter(dburi="crate://localhost") + with pytest.raises(ValueError): + database.quote_relation_name("my-db.my-schema.my-table")