From fd807e6670aaaf79f1e188b9e011c7811bca3318 Mon Sep 17 00:00:00 2001
From: Pat Nadolny <patnadolny@gmail.com>
Date: Tue, 20 Jun 2023 17:26:34 -0400
Subject: [PATCH] fix: Sink schema comparison before adding metadata columns
 (#1778)

* add failing test for get_sink

* refactor test mocks to share across modules

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* reformat

* deep copy test schema so it doesn't get manipulated

* fix sink schema compare test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 singer_sdk/sinks/core.py         |  2 ++
 singer_sdk/target_base.py        |  2 +-
 tests/conftest.py                | 52 ++++++++++++++++++++++++++++++
 tests/core/test_target_base.py   | 30 ++++++++++++++++++
 tests/samples/test_target_csv.py | 54 ++------------------------------
 5 files changed, 87 insertions(+), 53 deletions(-)
 create mode 100644 tests/core/test_target_base.py

diff --git a/singer_sdk/sinks/core.py b/singer_sdk/sinks/core.py
index fbcbbb295..2510be03c 100644
--- a/singer_sdk/sinks/core.py
+++ b/singer_sdk/sinks/core.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import abc
+import copy
 import datetime
 import json
 import time
@@ -67,6 +68,7 @@ def __init__(
             "Initializing target sink for stream '%s'...",
             stream_name,
         )
+        self.original_schema = copy.deepcopy(schema)
         self.schema = schema
         if self.include_sdc_metadata_properties:
             self._add_sdc_metadata_to_schema()
diff --git a/singer_sdk/target_base.py b/singer_sdk/target_base.py
index 2fcb28314..d8320394a 100644
--- a/singer_sdk/target_base.py
+++ b/singer_sdk/target_base.py
@@ -167,7 +167,7 @@ def get_sink(
             return self.add_sink(stream_name, schema, key_properties)
 
         if (
-            existing_sink.schema != schema
+            existing_sink.original_schema != schema
             or existing_sink.key_properties != key_properties
         ):
             self.logger.info(
diff --git a/tests/conftest.py b/tests/conftest.py
index c2992328a..142e76fe1 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -9,6 +9,10 @@
 
 import pytest
 
+from singer_sdk import typing as th
+from singer_sdk.sinks import BatchSink
+from singer_sdk.target_base import Target
+
 if t.TYPE_CHECKING:
     from _pytest.config import Config
 
@@ -54,3 +58,51 @@ def outdir() -> t.Generator[str, None, None]:
 def snapshot_dir() -> pathlib.Path:
     """Return the path to the snapshot directory."""
     return pathlib.Path("tests/snapshots/")
+
+
+class BatchSinkMock(BatchSink):
+    """A mock Sink class."""
+
+    name = "batch-sink-mock"
+
+    def __init__(
+        self,
+        target: TargetMock,
+        stream_name: str,
+        schema: dict,
+        key_properties: list[str] | None,
+    ):
+        """Create the Mock batch-based sink."""
+        super().__init__(target, stream_name, schema, key_properties)
+        self.target = target
+
+    def process_record(self, record: dict, context: dict) -> None:
+        """Tracks the count of processed records."""
+        self.target.num_records_processed += 1
+        super().process_record(record, context)
+
+    def process_batch(self, context: dict) -> None:
+        """Write to mock trackers."""
+        self.target.records_written.extend(context["records"])
+        self.target.num_batches_processed += 1
+
+
+class TargetMock(Target):
+    """A mock Target class."""
+
+    name = "target-mock"
+    config_jsonschema = th.PropertiesList().to_dict()
+    default_sink_class = BatchSinkMock
+
+    def __init__(self, *args, **kwargs):
+        """Create the Mock target sync."""
+        super().__init__(*args, **kwargs)
+        self.state_messages_written: list[dict] = []
+        self.records_written: list[dict] = []
+        self.num_records_processed: int = 0
+        self.num_batches_processed: int = 0
+
+    def _write_state_message(self, state: dict):
+        """Emit the stream's latest state."""
+        super()._write_state_message(state)
+        self.state_messages_written.append(state)
diff --git a/tests/core/test_target_base.py b/tests/core/test_target_base.py
new file mode 100644
index 000000000..778fab722
--- /dev/null
+++ b/tests/core/test_target_base.py
@@ -0,0 +1,30 @@
+from __future__ import annotations
+
+import copy
+
+from tests.conftest import BatchSinkMock, TargetMock
+
+
+def test_get_sink():
+    input_schema_1 = {
+        "properties": {
+            "id": {
+                "type": ["string", "null"],
+            },
+            "col_ts": {
+                "format": "date-time",
+                "type": ["string", "null"],
+            },
+        },
+    }
+    input_schema_2 = copy.deepcopy(input_schema_1)
+    key_properties = []
+    target = TargetMock(config={"add_record_metadata": True})
+    sink = BatchSinkMock(target, "foo", input_schema_1, key_properties)
+    target._sinks_active["foo"] = sink
+    sink_returned = target.get_sink(
+        "foo",
+        schema=input_schema_2,
+        key_properties=key_properties,
+    )
+    assert sink_returned == sink
diff --git a/tests/samples/test_target_csv.py b/tests/samples/test_target_csv.py
index 54f50fc38..715edbb65 100644
--- a/tests/samples/test_target_csv.py
+++ b/tests/samples/test_target_csv.py
@@ -14,9 +14,6 @@
 from samples.sample_mapper.mapper import StreamTransform
 from samples.sample_tap_countries.countries_tap import SampleTapCountries
 from samples.sample_target_csv.csv_target import SampleTargetCSV
-from singer_sdk import typing as th
-from singer_sdk.sinks import BatchSink
-from singer_sdk.target_base import Target
 from singer_sdk.testing import (
     get_target_test_class,
     sync_end_to_end,
@@ -24,6 +21,7 @@
     tap_to_target_sync_test,
     target_sync_test,
 )
+from tests.conftest import TargetMock
 
 TEST_OUTPUT_DIR = Path(f".output/test_{uuid.uuid4()}/")
 SAMPLE_CONFIG = {"target_folder": f"{TEST_OUTPUT_DIR}/"}
@@ -55,54 +53,6 @@ def resource(self, test_output_dir):
 }
 
 
-class BatchSinkMock(BatchSink):
-    """A mock Sink class."""
-
-    name = "batch-sink-mock"
-
-    def __init__(
-        self,
-        target: TargetMock,
-        stream_name: str,
-        schema: dict,
-        key_properties: list[str] | None,
-    ):
-        """Create the Mock batch-based sink."""
-        super().__init__(target, stream_name, schema, key_properties)
-        self.target = target
-
-    def process_record(self, record: dict, context: dict) -> None:
-        """Tracks the count of processed records."""
-        self.target.num_records_processed += 1
-        super().process_record(record, context)
-
-    def process_batch(self, context: dict) -> None:
-        """Write to mock trackers."""
-        self.target.records_written.extend(context["records"])
-        self.target.num_batches_processed += 1
-
-
-class TargetMock(Target):
-    """A mock Target class."""
-
-    name = "target-mock"
-    config_jsonschema = th.PropertiesList().to_dict()
-    default_sink_class = BatchSinkMock
-
-    def __init__(self):
-        """Create the Mock target sync."""
-        super().__init__(config={})
-        self.state_messages_written: list[dict] = []
-        self.records_written: list[dict] = []
-        self.num_records_processed: int = 0
-        self.num_batches_processed: int = 0
-
-    def _write_state_message(self, state: dict):
-        """Emit the stream's latest state."""
-        super()._write_state_message(state)
-        self.state_messages_written.append(state)
-
-
 def test_countries_to_csv(csv_config: dict):
     tap = SampleTapCountries(config=SAMPLE_TAP_CONFIG, state=None)
     target = SampleTargetCSV(config=csv_config)
@@ -133,7 +83,7 @@ def test_target_batching():
     countries_record_count = 257
 
     with freeze_time(mocked_starttime):
-        target = TargetMock()
+        target = TargetMock(config={})
         target.max_parallelism = 1  # Limit unit test to 1 process
         assert target.num_records_processed == 0
         assert len(target.records_written) == 0