diff --git a/singer_sdk/connectors/sql.py b/singer_sdk/connectors/sql.py index f48222640..24f7be756 100644 --- a/singer_sdk/connectors/sql.py +++ b/singer_sdk/connectors/sql.py @@ -5,6 +5,7 @@ import logging import typing as t import warnings +from collections import UserString from contextlib import contextmanager from datetime import datetime from functools import lru_cache @@ -22,6 +23,89 @@ from sqlalchemy.engine.reflection import Inspector +class FullyQualifiedName(UserString): + """A fully qualified table name. + + This class provides a simple way to represent a fully qualified table name + as a single object. The string representation of this object is the fully + qualified table name, with the parts separated by periods. + + The parts of the fully qualified table name are: + - database + - schema + - table + + The database and schema are optional. If only the table name is provided, + the string representation of the object will be the table name alone. + + Example: + ``` + table_name = FullyQualifiedName("my_table", "my_schema", "my_db") + print(table_name) # my_db.my_schema.my_table + ``` + """ + + def __init__( + self, + *, + table: str = "", + schema: str | None = None, + database: str | None = None, + delimiter: str = ".", + dialect: sa.engine.Dialect, + ) -> None: + """Initialize the fully qualified table name. + + Args: + table: The name of the table. + schema: The name of the schema. Defaults to None. + database: The name of the database. Defaults to None. + delimiter: The delimiter to use between parts. Defaults to '.'. + dialect: The SQLAlchemy dialect to use for quoting. + + Raises: + ValueError: If the fully qualified name could not be generated. + """ + self.table = table + self.schema = schema + self.database = database + self.delimiter = delimiter + self.dialect = dialect + + parts = [] + if self.database: + parts.append(self.prepare_part(self.database)) + if self.schema: + parts.append(self.prepare_part(self.schema)) + if self.table: + parts.append(self.prepare_part(self.table)) + + if not parts: + raise ValueError( + "Could not generate fully qualified name: " + + ":".join( + [ + self.database or "(unknown-db)", + self.schema or "(unknown-schema)", + self.table or "(unknown-table-name)", + ], + ), + ) + + super().__init__(self.delimiter.join(parts)) + + def prepare_part(self, part: str) -> str: + """Prepare a part of the fully qualified name. + + Args: + part: The part to prepare. + + Returns: + The prepared part. + """ + return self.dialect.identifier_preparer.quote(part) + + class SQLConnector: # noqa: PLR0904 """Base class for SQLAlchemy-based connectors. @@ -238,13 +322,13 @@ def to_sql_type(jsonschema_type: dict) -> sa.types.TypeEngine: """ return th.to_sql_type(jsonschema_type) - @staticmethod def get_fully_qualified_name( + self, table_name: str | None = None, schema_name: str | None = None, db_name: str | None = None, delimiter: str = ".", - ) -> str: + ) -> FullyQualifiedName: """Concatenates a fully qualified name from the parts. Args: @@ -253,34 +337,16 @@ def get_fully_qualified_name( db_name: The name of the database. Defaults to None. delimiter: Generally: '.' for SQL names and '-' for Singer names. - Raises: - ValueError: If all 3 name parts not supplied. - Returns: The fully qualified name as a string. """ - parts = [] - - if db_name: - parts.append(db_name) - if schema_name: - parts.append(schema_name) - if table_name: - parts.append(table_name) - - if not parts: - raise ValueError( - "Could not generate fully qualified name: " - + ":".join( - [ - db_name or "(unknown-db)", - schema_name or "(unknown-schema)", - table_name or "(unknown-table-name)", - ], - ), - ) - - return delimiter.join(parts) + return FullyQualifiedName( + table=table_name, # type: ignore[arg-type] + schema=schema_name, + database=db_name, + delimiter=delimiter, + dialect=self._dialect, + ) @property def _dialect(self) -> sa.engine.Dialect: @@ -429,12 +495,7 @@ def discover_catalog_entry( `CatalogEntry` object for the given table or a view """ # Initialize unique stream name - unique_stream_id = self.get_fully_qualified_name( - db_name=None, - schema_name=schema_name, - table_name=table_name, - delimiter="-", - ) + unique_stream_id = f"{schema_name}-{table_name}" # Detect key properties possible_primary_keys: list[list[str]] = [] @@ -528,7 +589,7 @@ def discover_catalog_entries(self) -> list[dict]: def parse_full_table_name( # noqa: PLR6301 self, - full_table_name: str, + full_table_name: str | FullyQualifiedName, ) -> tuple[str | None, str | None, str]: """Parse a fully qualified table name into its parts. @@ -547,6 +608,13 @@ def parse_full_table_name( # noqa: PLR6301 A three part tuple (db_name, schema_name, table_name) with any unspecified or unused parts returned as None. """ + if isinstance(full_table_name, FullyQualifiedName): + return ( + full_table_name.database, + full_table_name.schema, + full_table_name.table, + ) + db_name: str | None = None schema_name: str | None = None @@ -560,7 +628,7 @@ def parse_full_table_name( # noqa: PLR6301 return db_name, schema_name, table_name - def table_exists(self, full_table_name: str) -> bool: + def table_exists(self, full_table_name: str | FullyQualifiedName) -> bool: """Determine if the target table already exists. Args: @@ -587,7 +655,7 @@ def schema_exists(self, schema_name: str) -> bool: def get_table_columns( self, - full_table_name: str, + full_table_name: str | FullyQualifiedName, column_names: list[str] | None = None, ) -> dict[str, sa.Column]: """Return a list of table columns. @@ -618,7 +686,7 @@ def get_table_columns( def get_table( self, - full_table_name: str, + full_table_name: str | FullyQualifiedName, column_names: list[str] | None = None, ) -> sa.Table: """Return a table object. @@ -643,7 +711,9 @@ def get_table( schema=schema_name, ) - def column_exists(self, full_table_name: str, column_name: str) -> bool: + def column_exists( + self, full_table_name: str | FullyQualifiedName, column_name: str + ) -> bool: """Determine if the target table already exists. Args: @@ -666,7 +736,7 @@ def create_schema(self, schema_name: str) -> None: def create_empty_table( self, - full_table_name: str, + full_table_name: str | FullyQualifiedName, schema: dict, primary_keys: t.Sequence[str] | None = None, partition_keys: list[str] | None = None, @@ -715,7 +785,7 @@ def create_empty_table( def _create_empty_column( self, - full_table_name: str, + full_table_name: str | FullyQualifiedName, column_name: str, sql_type: sa.types.TypeEngine, ) -> None: @@ -753,7 +823,7 @@ def prepare_schema(self, schema_name: str) -> None: def prepare_table( self, - full_table_name: str, + full_table_name: str | FullyQualifiedName, schema: dict, primary_keys: t.Sequence[str], partition_keys: list[str] | None = None, @@ -797,7 +867,7 @@ def prepare_table( def prepare_column( self, - full_table_name: str, + full_table_name: str | FullyQualifiedName, column_name: str, sql_type: sa.types.TypeEngine, ) -> None: @@ -822,7 +892,9 @@ def prepare_column( sql_type=sql_type, ) - def rename_column(self, full_table_name: str, old_name: str, new_name: str) -> None: + def rename_column( + self, full_table_name: str | FullyQualifiedName, old_name: str, new_name: str + ) -> None: """Rename the provided columns. Args: @@ -951,7 +1023,7 @@ def _get_type_sort_key( def _get_column_type( self, - full_table_name: str, + full_table_name: str | FullyQualifiedName, column_name: str, ) -> sa.types.TypeEngine: """Get the SQL type of the declared column. @@ -976,7 +1048,7 @@ def _get_column_type( def get_column_add_ddl( self, - table_name: str, + table_name: str | FullyQualifiedName, column_name: str, column_type: sa.types.TypeEngine, ) -> sa.DDL: @@ -1009,7 +1081,7 @@ def get_column_add_ddl( @staticmethod def get_column_rename_ddl( - table_name: str, + table_name: str | FullyQualifiedName, column_name: str, new_column_name: str, ) -> sa.DDL: @@ -1037,7 +1109,7 @@ def get_column_rename_ddl( @staticmethod def get_column_alter_ddl( - table_name: str, + table_name: str | FullyQualifiedName, column_name: str, column_type: sa.types.TypeEngine, ) -> sa.DDL: @@ -1096,7 +1168,7 @@ def update_collation( def _adapt_column_type( self, - full_table_name: str, + full_table_name: str | FullyQualifiedName, column_name: str, sql_type: sa.types.TypeEngine, ) -> None: @@ -1187,7 +1259,7 @@ def deserialize_json(self, json_str: str) -> object: # noqa: PLR6301 def delete_old_versions( self, *, - full_table_name: str, + full_table_name: str | FullyQualifiedName, version_column_name: str, current_version: int, ) -> None: diff --git a/singer_sdk/sinks/sql.py b/singer_sdk/sinks/sql.py index 33a741614..0f7695ef0 100644 --- a/singer_sdk/sinks/sql.py +++ b/singer_sdk/sinks/sql.py @@ -21,6 +21,7 @@ if t.TYPE_CHECKING: from sqlalchemy.sql import Executable + from singer_sdk.connectors.sql import FullyQualifiedName from singer_sdk.target_base import Target _C = t.TypeVar("_C", bound=SQLConnector) @@ -109,7 +110,7 @@ def database_name(self) -> str | None: # Assumes single-DB target context. @property - def full_table_name(self) -> str: + def full_table_name(self) -> FullyQualifiedName: """Return the fully qualified table name. Returns: @@ -122,7 +123,7 @@ def full_table_name(self) -> str: ) @property - def full_schema_name(self) -> str: + def full_schema_name(self) -> FullyQualifiedName: """Return the fully qualified schema name. Returns: @@ -269,7 +270,7 @@ def process_batch(self, context: dict) -> None: def generate_insert_statement( self, - full_table_name: str, + full_table_name: str | FullyQualifiedName, schema: dict, ) -> str | Executable: """Generate an insert statement for the given records. @@ -297,7 +298,7 @@ def generate_insert_statement( def bulk_insert_records( self, - full_table_name: str, + full_table_name: str | FullyQualifiedName, schema: dict, records: t.Iterable[dict[str, t.Any]], ) -> int | None: diff --git a/singer_sdk/streams/sql.py b/singer_sdk/streams/sql.py index 954159885..2877a505b 100644 --- a/singer_sdk/streams/sql.py +++ b/singer_sdk/streams/sql.py @@ -14,6 +14,7 @@ from singer_sdk.streams.core import REPLICATION_INCREMENTAL, Stream if t.TYPE_CHECKING: + from singer_sdk.connectors.sql import FullyQualifiedName from singer_sdk.helpers.types import Context from singer_sdk.tap_base import Tap @@ -124,7 +125,7 @@ def primary_keys(self, new_value: t.Sequence[str]) -> None: self._singer_catalog_entry.metadata.root.table_key_properties = new_value @property - def fully_qualified_name(self) -> str: + def fully_qualified_name(self) -> FullyQualifiedName: """Generate the fully qualified version of the table name. Raises: diff --git a/tests/core/test_connector_sql.py b/tests/core/test_connector_sql.py index 10ee0c0f4..6a9a5e189 100644 --- a/tests/core/test_connector_sql.py +++ b/tests/core/test_connector_sql.py @@ -7,9 +7,11 @@ import pytest import sqlalchemy as sa from sqlalchemy.dialects import registry, sqlite +from sqlalchemy.engine.default import DefaultDialect from samples.sample_duckdb import DuckDBConnector from singer_sdk.connectors import SQLConnector +from singer_sdk.connectors.sql import FullyQualifiedName from singer_sdk.exceptions import ConfigValidationError if t.TYPE_CHECKING: @@ -355,3 +357,35 @@ def create_engine(self) -> Engine: connector = CustomConnector(config={"sqlalchemy_url": "myrdbms:///"}) connector.create_engine() + + +def test_fully_qualified_name(): + dialect = DefaultDialect() + + fqn = FullyQualifiedName(table="my_table", dialect=dialect) + assert fqn == "my_table" + + fqn = FullyQualifiedName(schema="my_schema", table="my_table", dialect=dialect) + assert fqn == "my_schema.my_table" + + fqn = FullyQualifiedName( + database="my_catalog", + schema="my_schema", + table="my_table", + dialect=dialect, + ) + assert fqn == "my_catalog.my_schema.my_table" + + +def test_fully_qualified_name_with_quoting(): + dialect = DefaultDialect() + + fqn = FullyQualifiedName(table="order", schema="public", dialect=dialect) + assert fqn == 'public."order"' + + +def test_fully_qualified_name_empty_error(): + dialect = DefaultDialect() + + with pytest.raises(ValueError, match="Could not generate fully qualified name"): + FullyQualifiedName(dialect=dialect)