diff --git a/singer_sdk/sinks/core.py b/singer_sdk/sinks/core.py index fbcbbb295..2510be03c 100644 --- a/singer_sdk/sinks/core.py +++ b/singer_sdk/sinks/core.py @@ -3,6 +3,7 @@ from __future__ import annotations import abc +import copy import datetime import json import time @@ -67,6 +68,7 @@ def __init__( "Initializing target sink for stream '%s'...", stream_name, ) + self.original_schema = copy.deepcopy(schema) self.schema = schema if self.include_sdc_metadata_properties: self._add_sdc_metadata_to_schema() diff --git a/singer_sdk/target_base.py b/singer_sdk/target_base.py index 2fcb28314..d8320394a 100644 --- a/singer_sdk/target_base.py +++ b/singer_sdk/target_base.py @@ -167,7 +167,7 @@ def get_sink( return self.add_sink(stream_name, schema, key_properties) if ( - existing_sink.schema != schema + existing_sink.original_schema != schema or existing_sink.key_properties != key_properties ): self.logger.info( diff --git a/tests/conftest.py b/tests/conftest.py index c2992328a..142e76fe1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,10 @@ import pytest +from singer_sdk import typing as th +from singer_sdk.sinks import BatchSink +from singer_sdk.target_base import Target + if t.TYPE_CHECKING: from _pytest.config import Config @@ -54,3 +58,51 @@ def outdir() -> t.Generator[str, None, None]: def snapshot_dir() -> pathlib.Path: """Return the path to the snapshot directory.""" return pathlib.Path("tests/snapshots/") + + +class BatchSinkMock(BatchSink): + """A mock Sink class.""" + + name = "batch-sink-mock" + + def __init__( + self, + target: TargetMock, + stream_name: str, + schema: dict, + key_properties: list[str] | None, + ): + """Create the Mock batch-based sink.""" + super().__init__(target, stream_name, schema, key_properties) + self.target = target + + def process_record(self, record: dict, context: dict) -> None: + """Tracks the count of processed records.""" + self.target.num_records_processed += 1 + super().process_record(record, context) + + def process_batch(self, context: dict) -> None: + """Write to mock trackers.""" + self.target.records_written.extend(context["records"]) + self.target.num_batches_processed += 1 + + +class TargetMock(Target): + """A mock Target class.""" + + name = "target-mock" + config_jsonschema = th.PropertiesList().to_dict() + default_sink_class = BatchSinkMock + + def __init__(self, *args, **kwargs): + """Create the Mock target sync.""" + super().__init__(*args, **kwargs) + self.state_messages_written: list[dict] = [] + self.records_written: list[dict] = [] + self.num_records_processed: int = 0 + self.num_batches_processed: int = 0 + + def _write_state_message(self, state: dict): + """Emit the stream's latest state.""" + super()._write_state_message(state) + self.state_messages_written.append(state) diff --git a/tests/core/test_target_base.py b/tests/core/test_target_base.py new file mode 100644 index 000000000..778fab722 --- /dev/null +++ b/tests/core/test_target_base.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import copy + +from tests.conftest import BatchSinkMock, TargetMock + + +def test_get_sink(): + input_schema_1 = { + "properties": { + "id": { + "type": ["string", "null"], + }, + "col_ts": { + "format": "date-time", + "type": ["string", "null"], + }, + }, + } + input_schema_2 = copy.deepcopy(input_schema_1) + key_properties = [] + target = TargetMock(config={"add_record_metadata": True}) + sink = BatchSinkMock(target, "foo", input_schema_1, key_properties) + target._sinks_active["foo"] = sink + sink_returned = target.get_sink( + "foo", + schema=input_schema_2, + key_properties=key_properties, + ) + assert sink_returned == sink diff --git a/tests/samples/test_target_csv.py b/tests/samples/test_target_csv.py index 54f50fc38..715edbb65 100644 --- a/tests/samples/test_target_csv.py +++ b/tests/samples/test_target_csv.py @@ -14,9 +14,6 @@ from samples.sample_mapper.mapper import StreamTransform from samples.sample_tap_countries.countries_tap import SampleTapCountries from samples.sample_target_csv.csv_target import SampleTargetCSV -from singer_sdk import typing as th -from singer_sdk.sinks import BatchSink -from singer_sdk.target_base import Target from singer_sdk.testing import ( get_target_test_class, sync_end_to_end, @@ -24,6 +21,7 @@ tap_to_target_sync_test, target_sync_test, ) +from tests.conftest import TargetMock TEST_OUTPUT_DIR = Path(f".output/test_{uuid.uuid4()}/") SAMPLE_CONFIG = {"target_folder": f"{TEST_OUTPUT_DIR}/"} @@ -55,54 +53,6 @@ def resource(self, test_output_dir): } -class BatchSinkMock(BatchSink): - """A mock Sink class.""" - - name = "batch-sink-mock" - - def __init__( - self, - target: TargetMock, - stream_name: str, - schema: dict, - key_properties: list[str] | None, - ): - """Create the Mock batch-based sink.""" - super().__init__(target, stream_name, schema, key_properties) - self.target = target - - def process_record(self, record: dict, context: dict) -> None: - """Tracks the count of processed records.""" - self.target.num_records_processed += 1 - super().process_record(record, context) - - def process_batch(self, context: dict) -> None: - """Write to mock trackers.""" - self.target.records_written.extend(context["records"]) - self.target.num_batches_processed += 1 - - -class TargetMock(Target): - """A mock Target class.""" - - name = "target-mock" - config_jsonschema = th.PropertiesList().to_dict() - default_sink_class = BatchSinkMock - - def __init__(self): - """Create the Mock target sync.""" - super().__init__(config={}) - self.state_messages_written: list[dict] = [] - self.records_written: list[dict] = [] - self.num_records_processed: int = 0 - self.num_batches_processed: int = 0 - - def _write_state_message(self, state: dict): - """Emit the stream's latest state.""" - super()._write_state_message(state) - self.state_messages_written.append(state) - - def test_countries_to_csv(csv_config: dict): tap = SampleTapCountries(config=SAMPLE_TAP_CONFIG, state=None) target = SampleTargetCSV(config=csv_config) @@ -133,7 +83,7 @@ def test_target_batching(): countries_record_count = 257 with freeze_time(mocked_starttime): - target = TargetMock() + target = TargetMock(config={}) target.max_parallelism = 1 # Limit unit test to 1 process assert target.num_records_processed == 0 assert len(target.records_written) == 0