Skip to content

Commit

Permalink
refactor: Allow overriding the bulk insert statement
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarrmondragon committed Aug 12, 2022
1 parent cc7e06d commit 4b5ff66
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 8 deletions.
38 changes: 31 additions & 7 deletions singer_sdk/sinks/sql.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Sink classes load data to SQL targets."""

from textwrap import dedent
from typing import Any, Dict, Iterable, List, Optional, Type

import sqlalchemy
Expand Down Expand Up @@ -162,6 +163,31 @@ def create_table_with_records(
full_table_name=full_table_name, schema=schema, records=records
)

def generate_insert_statement(
self,
full_table_name: str,
schema: dict,
) -> str:
"""Generate an insert statement for the given records.
Args:
full_table_name: the target table name.
schema: the JSON schema for the new table.
Returns:
An insert statement.
"""
property_names = list(schema["properties"].keys())
statement = dedent(
f"""\
INSERT INTO {full_table_name}
({", ".join(property_names)})
VALUES ({", ".join([f":{name}" for name in property_names])})
"""
)

return statement.rstrip()

def bulk_insert_records(
self,
full_table_name: str,
Expand All @@ -183,15 +209,13 @@ def bulk_insert_records(
Returns:
True if table exists, False if not, None if unsure or undetectable.
"""
property_names = list(schema["properties"].keys())
insert_sql = sqlalchemy.text(
f"INSERT INTO {full_table_name} "
f"({', '.join([n for n in property_names])})"
f" VALUES "
f"({', '.join([':' + n for n in property_names])})"
insert_sql = self.generate_insert_statement(
full_table_name,
schema,
)
self.logger.info("Inserting with SQL: %s", insert_sql)
self.connector.connection.execute(
insert_sql,
sqlalchemy.text(insert_sql),
records,
)
if isinstance(records, list):
Expand Down
49 changes: 48 additions & 1 deletion tests/core/test_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from copy import deepcopy
from io import StringIO
from pathlib import Path
from textwrap import dedent
from typing import Dict, cast
from uuid import uuid4

import pytest

from samples.sample_tap_sqlite import SQLiteConnector, SQLiteTap
from samples.sample_target_csv.csv_target import SampleTargetCSV
from samples.sample_target_sqlite import SQLiteTarget
from samples.sample_target_sqlite import SQLiteSink, SQLiteTarget
from singer_sdk import SQLStream
from singer_sdk import typing as th
from singer_sdk.helpers._singer import Catalog, MetadataMapping, StreamMetadata
Expand Down Expand Up @@ -431,3 +432,49 @@ def test_sqlite_column_no_morph(sqlite_sample_target: SQLTarget):
target_sync_test(sqlite_sample_target, input=StringIO(tap_output_a), finalize=True)
# Int should be inserted as string.
target_sync_test(sqlite_sample_target, input=StringIO(tap_output_b), finalize=True)


@pytest.mark.parametrize(
"stream_name,schema,key_properties,expected_dml",
[
(
"test_stream",
{
"type": "object",
"properties": {
"id": {"type": "integer"},
"name": {"type": "string"},
},
},
[],
dedent(
"""\
INSERT INTO test_stream
(id, name)
VALUES (:id, :name)"""
),
),
],
ids=[
"no_key_properties",
],
)
def test_sqlite_generate_insert_statement(
sqlite_sample_target: SQLiteTarget,
stream_name: str,
schema: dict,
key_properties: list,
expected_dml: str,
):
sink = SQLiteSink(
sqlite_sample_target,
stream_name=stream_name,
schema=schema,
key_properties=key_properties,
)

dml = sink.generate_insert_statement(
sink.full_table_name,
sink.schema,
)
assert dml == expected_dml

0 comments on commit 4b5ff66

Please sign in to comment.