diff --git a/samples/sample_tap_hostile/__init__.py b/samples/sample_tap_hostile/__init__.py new file mode 100644 index 000000000..8b1750c15 --- /dev/null +++ b/samples/sample_tap_hostile/__init__.py @@ -0,0 +1,3 @@ +"""A sample tap for testing SQL target property name transformations.""" + +from .hostile_tap import SampleTapHostile diff --git a/samples/sample_tap_hostile/hostile_streams.py b/samples/sample_tap_hostile/hostile_streams.py new file mode 100644 index 000000000..cb3b619a0 --- /dev/null +++ b/samples/sample_tap_hostile/hostile_streams.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import random +import string +from typing import Iterable + +from singer_sdk import typing as th +from singer_sdk.streams import Stream + + +class HostilePropertyNamesStream(Stream): + """ + A stream with property names that are not compatible as unescaped identifiers + in common DBMS systems. + """ + + name = "hostile_property_names_stream" + schema = th.PropertiesList( + th.Property("name with spaces", th.StringType), + th.Property("NameIsCamelCase", th.StringType), + th.Property("name-with-dashes", th.StringType), + th.Property("Name-with-Dashes-and-Mixed-cases", th.StringType), + th.Property("5name_starts_with_number", th.StringType), + th.Property("6name_starts_with_number", th.StringType), + th.Property("7name_starts_with_number", th.StringType), + th.Property("name_with_emoji_😈", th.StringType), + ).to_dict() + + @staticmethod + def get_random_lowercase_string(): + return "".join(random.choice(string.ascii_lowercase) for _ in range(10)) + + def get_records(self, context: dict | None) -> Iterable[dict | tuple[dict, dict]]: + return ( + { + key: self.get_random_lowercase_string() + for key in self.schema["properties"].keys() + } + for _ in range(10) + ) diff --git a/samples/sample_tap_hostile/hostile_tap.py b/samples/sample_tap_hostile/hostile_tap.py new file mode 100644 index 000000000..131a9746b --- /dev/null +++ b/samples/sample_tap_hostile/hostile_tap.py @@ -0,0 +1,24 @@ +"""A sample tap for testing SQL target property name transformations.""" + +from typing import List + +from samples.sample_tap_hostile.hostile_streams import HostilePropertyNamesStream +from singer_sdk import Stream, Tap +from singer_sdk.typing import PropertiesList + + +class SampleTapHostile(Tap): + """Sample tap for for testing SQL target property name transformations.""" + + name: str = "sample-tap-hostile" + config_jsonschema = PropertiesList().to_dict() + + def discover_streams(self) -> List[Stream]: + """Return a list of discovered streams.""" + return [ + HostilePropertyNamesStream(tap=self), + ] + + +if __name__ == "__main__": + SampleTapHostile.cli() diff --git a/singer_sdk/exceptions.py b/singer_sdk/exceptions.py index dd3674675..6f56f8417 100644 --- a/singer_sdk/exceptions.py +++ b/singer_sdk/exceptions.py @@ -50,3 +50,10 @@ class TapStreamConnectionFailure(Exception): class TooManyRecordsException(Exception): """Exception to raise when query returns more records than max_records.""" + + +class ConformedNameClashException(Exception): + """Raised when name conforming produces clashes. + + e.g. two columns conformed to the same name + """ diff --git a/singer_sdk/helpers/_conformers.py b/singer_sdk/helpers/_conformers.py new file mode 100644 index 000000000..a5de7f4f9 --- /dev/null +++ b/singer_sdk/helpers/_conformers.py @@ -0,0 +1,42 @@ +"""Helper functions for conforming identifiers.""" +import re +from string import ascii_lowercase, digits + + +def snakecase(string: str) -> str: + """Convert string into snake case. + + Args: + string: String to convert. + + Returns: + string: Snake cased string. + """ + string = re.sub(r"[\-\.\s]", "_", string) + string = ( + string[0].lower() + + re.sub( + r"[A-Z]", lambda matched: "_" + str(matched.group(0).lower()), string[1:] + ) + if string + else string + ) + return re.sub(r"_{2,}", "_", string).rstrip("_") + + +def replace_leading_digit(string: str) -> str: + """Replace leading numeric character with equivalent letter. + + Args: + string: String to process. + + Returns: + A modified string if original starts with a number, + else the unmodified original. + """ + if string[0] in digits: + letters = list(ascii_lowercase) + numbers = [int(d) for d in digits] + digit_map = {n: letters[n] for n in numbers} + return digit_map[int(string[0])] + string[1:] + return string diff --git a/singer_sdk/sinks/core.py b/singer_sdk/sinks/core.py index 18487546c..186ca28b2 100644 --- a/singer_sdk/sinks/core.py +++ b/singer_sdk/sinks/core.py @@ -71,7 +71,7 @@ def __init__( self.latest_state: dict | None = None self._draining_state: dict | None = None self.drained_state: dict | None = None - self.key_properties = key_properties or [] + self._key_properties = key_properties or [] # Tally counters self._total_records_written: int = 0 @@ -202,6 +202,15 @@ def datetime_error_treatment(self) -> DatetimeErrorTreatmentEnum: """ return DatetimeErrorTreatmentEnum.ERROR + @property + def key_properties(self) -> list[str]: + """Return key properties. + + Returns: + A list of stream key properties. + """ + return self._key_properties + # Record processing def _add_sdc_metadata_to_record( diff --git a/singer_sdk/sinks/sql.py b/singer_sdk/sinks/sql.py index c3455d5df..3b9ad0d34 100644 --- a/singer_sdk/sinks/sql.py +++ b/singer_sdk/sinks/sql.py @@ -1,5 +1,8 @@ """Sink classes load data to SQL targets.""" +import re +from collections import defaultdict +from copy import copy from textwrap import dedent from typing import Any, Dict, Iterable, List, Optional, Type, Union @@ -8,6 +11,8 @@ from sqlalchemy.sql import Executable from sqlalchemy.sql.expression import bindparam +from singer_sdk.exceptions import ConformedNameClashException +from singer_sdk.helpers._conformers import replace_leading_digit, snakecase from singer_sdk.plugin_base import PluginBase from singer_sdk.sinks.batch import BatchSink from singer_sdk.streams import SQLConnector @@ -67,7 +72,8 @@ def table_name(self) -> str: The target table name. """ parts = self.stream_name.split("-") - return self.stream_name if len(parts) == 1 else parts[-1] + table = self.stream_name if len(parts) == 1 else parts[-1] + return self.conform_name(table, "table") @property def schema_name(self) -> Optional[str]: @@ -80,7 +86,7 @@ def schema_name(self) -> Optional[str]: if len(parts) in {2, 3}: # Stream name is a two-part or three-part identifier. # Use the second-to-last part as the schema name. - return parts[-2] + return self.conform_name(parts[-2], "schema") # Schema name not detected. return None @@ -118,6 +124,86 @@ def full_schema_name(self) -> str: schema_name=self.schema_name, db_name=self.database_name ) + def conform_name(self, name: str, object_type: Optional[str] = None) -> str: + """Conform a stream property name to one suitable for the target system. + + Transforms names to snake case by default, applicable to most common DBMSs'. + Developers may override this method to apply custom transformations + to database/schema/table/column names. + + Args: + name: Property name. + object_type: One of ``database``, ``schema``, ``table`` or ``column``. + + + Returns: + The name transformed to snake case. + """ + # strip non-alphanumeric characters, keeping - . _ and spaces + name = re.sub(r"[^a-zA-Z0-9_\-\.\s]", "", name) + # convert to snakecase + name = snakecase(name) + # replace leading digit + return replace_leading_digit(name) + + @staticmethod + def _check_conformed_names_not_duplicated( + conformed_property_names: Dict[str, str] + ) -> None: + """Check if conformed names produce duplicate keys. + + Args: + conformed_property_names: A name:conformed_name dict map. + + Raises: + ConformedNameClashException: if duplicates found. + """ + # group: {'_a': ['1_a'], 'abc': ['aBc', 'abC']} + grouped = defaultdict(list) + for k, v in conformed_property_names.items(): + grouped[v].append(k) + + # filter + duplicates = list(filter(lambda p: len(p[1]) > 1, grouped.items())) + if duplicates: + raise ConformedNameClashException( + "Duplicate stream properties produced when " + + f"conforming property names: {duplicates}" + ) + + def conform_schema(self, schema: dict) -> dict: + """Return schema dictionary with property names conformed. + + Args: + schema: JSON schema dictionary. + + Returns: + A schema dictionary with the property names conformed. + """ + conformed_schema = copy(schema) + conformed_property_names = { + key: self.conform_name(key) for key in conformed_schema["properties"].keys() + } + self._check_conformed_names_not_duplicated(conformed_property_names) + conformed_schema["properties"] = { + conformed_property_names[key]: value + for key, value in conformed_schema["properties"].items() + } + return conformed_schema + + def conform_record(self, record: dict) -> dict: + """Return record dictionary with property names conformed. + + Args: + record: Dictionary representing a single record. + + Returns: + New record dictionary with conformed column names. + """ + conformed_property_names = {key: self.conform_name(key) for key in record} + self._check_conformed_names_not_duplicated(conformed_property_names) + return {conformed_property_names[key]: value for key, value in record.items()} + def setup(self) -> None: """Set up Sink. @@ -128,11 +214,20 @@ def setup(self) -> None: self.connector.prepare_schema(self.schema_name) self.connector.prepare_table( full_table_name=self.full_table_name, - schema=self.schema, + schema=self.conform_schema(self.schema), primary_keys=self.key_properties, as_temp_table=False, ) + @property + def key_properties(self) -> List[str]: + """Return key properties, conformed to target system naming requirements. + + Returns: + A list of key properties, conformed with `self.conform_name()` + """ + return [self.conform_name(key, "column") for key in super().key_properties] + def process_batch(self, context: dict) -> None: """Process a batch with the given batch context. @@ -164,7 +259,7 @@ def generate_insert_statement( Returns: An insert statement. """ - property_names = list(schema["properties"].keys()) + property_names = list(self.conform_schema(schema)["properties"].keys()) statement = dedent( f"""\ INSERT INTO {full_table_name} @@ -172,7 +267,6 @@ def generate_insert_statement( VALUES ({", ".join([f":{name}" for name in property_names])}) """ ) - return statement.rstrip() def bulk_insert_records( @@ -203,12 +297,14 @@ def bulk_insert_records( if isinstance(insert_sql, str): insert_sql = sqlalchemy.text(insert_sql) + conformed_records = ( + [self.conform_record(record) for record in records] + if isinstance(records, list) + else (self.conform_record(record) for record in records) + ) self.logger.info("Inserting with SQL: %s", insert_sql) - self.connector.connection.execute(insert_sql, records) - if isinstance(records, list): - return len(records) # If list, we can quickly return record count. - - return None # Unknown record count. + self.connector.connection.execute(insert_sql, conformed_records) + return len(conformed_records) if isinstance(conformed_records, list) else None def merge_upsert_from_table( self, target_table_name: str, from_table_name: str, join_keys: List[str] diff --git a/tests/core/test_sqlite.py b/tests/core/test_sqlite.py index b9c7c82ca..8c714dfcd 100644 --- a/tests/core/test_sqlite.py +++ b/tests/core/test_sqlite.py @@ -12,6 +12,7 @@ import pytest import sqlalchemy +from samples.sample_tap_hostile import SampleTapHostile from samples.sample_tap_sqlite import SQLiteConnector, SQLiteTap from samples.sample_target_csv.csv_target import SampleTargetCSV from samples.sample_target_sqlite import SQLiteSink, SQLiteTarget @@ -569,3 +570,41 @@ def test_sqlite_generate_insert_statement( sink.schema, ) assert dml == expected_dml + + +def test_hostile_to_sqlite( + sqlite_sample_target: SQLTarget, sqlite_target_test_config: dict +): + tap = SampleTapHostile() + tap_to_target_sync_test(tap, sqlite_sample_target) + # check if stream table was created + db = sqlite3.connect(sqlite_target_test_config["path_to_db"]) + cursor = db.cursor() + cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") + tables = [res[0] for res in cursor.fetchall()] + assert "hostile_property_names_stream" in tables + # check if columns were conformed + cursor.execute( + dedent( + """ + SELECT + p.name as columnName + FROM sqlite_master m + left outer join pragma_table_info((m.name)) p + on m.name <> p.name + where m.name = 'hostile_property_names_stream' + ; + """ + ) + ) + columns = {res[0] for res in cursor.fetchall()} + assert columns == { + "name_with_spaces", + "name_is_camel_case", + "name_with_dashes", + "name_with_dashes_and_mixed_cases", + "gname_starts_with_number", + "fname_starts_with_number", + "hname_starts_with_number", + "name_with_emoji", + }