From 5500e451569ebbc00b5c348b9a5ddc64ae76b713 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Ram=C3=ADrez=20Mondrag=C3=B3n?= <16805946+edgarrmondragon@users.noreply.github.com> Date: Wed, 10 Jul 2024 17:20:11 -0600 Subject: [PATCH] feat: Stream sync context is now available to all instances methods as a `Stream.context` attribute (#2529) --- .../graphql-client.py | 7 +- .../other-client.py | 7 +- .../rest-client.py | 12 ++- pyproject.toml | 1 + singer_sdk/helpers/_state.py | 12 ++- singer_sdk/helpers/types.py | 24 +++++ singer_sdk/metrics.py | 6 +- singer_sdk/streams/core.py | 96 ++++++++++--------- singer_sdk/streams/graphql.py | 2 +- singer_sdk/streams/rest.py | 2 +- singer_sdk/streams/sql.py | 2 +- 11 files changed, 107 insertions(+), 64 deletions(-) create mode 100644 singer_sdk/helpers/types.py diff --git a/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/graphql-client.py b/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/graphql-client.py index 66505556d..4e878cb23 100644 --- a/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/graphql-client.py +++ b/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/graphql-client.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Iterable +from typing import TYPE_CHECKING, Iterable import requests # noqa: TCH002 from singer_sdk.streams import {{ cookiecutter.stream_type }}Stream @@ -12,6 +12,9 @@ from {{ cookiecutter.library_name }}.auth import {{ cookiecutter.source_name }}Authenticator {%- endif %} +if TYPE_CHECKING: + from singer_sdk.helpers.types import Context + class {{ cookiecutter.source_name }}Stream({{ cookiecutter.stream_type }}Stream): """{{ cookiecutter.source_name }} stream class.""" @@ -67,7 +70,7 @@ def parse_response(self, response: requests.Response) -> Iterable[dict]: def post_process( self, row: dict, - context: dict | None = None, # noqa: ARG002 + context: Context | None = None, # noqa: ARG002 ) -> dict | None: """As needed, append or transform raw data to match expected structure. diff --git a/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/other-client.py b/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/other-client.py index c2def6322..1952579d7 100644 --- a/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/other-client.py +++ b/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/other-client.py @@ -2,17 +2,20 @@ from __future__ import annotations -from typing import Iterable +from typing import TYPE_CHECKING, Iterable from singer_sdk.streams import Stream +if TYPE_CHECKING: + from singer_sdk.helpers.types import Context + class {{ cookiecutter.source_name }}Stream(Stream): """Stream class for {{ cookiecutter.source_name }} streams.""" def get_records( self, - context: dict | None, # noqa: ARG002 + context: Context | None, # noqa: ARG002 ) -> Iterable[dict]: """Return a generator of record-type dictionary objects. diff --git a/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/rest-client.py b/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/rest-client.py index c7ea7b5ce..f4edca913 100644 --- a/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/rest-client.py +++ b/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/rest-client.py @@ -6,7 +6,7 @@ {%- if cookiecutter.auth_method in ("OAuth2", "JWT") %} from functools import cached_property {%- endif %} -from typing import Any, Callable, Iterable +from typing import TYPE_CHECKING, Any, Callable, Iterable import requests {% if cookiecutter.auth_method == "API Key" -%} @@ -46,6 +46,10 @@ else: import importlib_resources +if TYPE_CHECKING: + from singer_sdk.helpers.types import Context + + _Auth = Callable[[requests.PreparedRequest], requests.PreparedRequest] # TODO: Delete this is if not using json files for schema definition @@ -157,7 +161,7 @@ def get_new_paginator(self) -> BaseAPIPaginator: def get_url_params( self, - context: dict | None, # noqa: ARG002 + context: Context | None, # noqa: ARG002 next_page_token: Any | None, # noqa: ANN401 ) -> dict[str, Any]: """Return a dictionary of values to be used in URL parameterization. @@ -179,7 +183,7 @@ def get_url_params( def prepare_request_payload( self, - context: dict | None, # noqa: ARG002 + context: Context | None, # noqa: ARG002 next_page_token: Any | None, # noqa: ARG002, ANN401 ) -> dict | None: """Prepare the data payload for the REST API request. @@ -211,7 +215,7 @@ def parse_response(self, response: requests.Response) -> Iterable[dict]: def post_process( self, row: dict, - context: dict | None = None, # noqa: ARG002 + context: Context | None = None, # noqa: ARG002 ) -> dict | None: """As needed, append or transform raw data to match expected structure. diff --git a/pyproject.toml b/pyproject.toml index 95895ba5a..8734abe9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -206,6 +206,7 @@ omit = [ "tests/*", "samples/*", "singer_sdk/helpers/_compat.py", + "singer_sdk/helpers/types.py", ] [tool.coverage.report] diff --git a/singer_sdk/helpers/_state.py b/singer_sdk/helpers/_state.py index ed3d345eb..a910bb71e 100644 --- a/singer_sdk/helpers/_state.py +++ b/singer_sdk/helpers/_state.py @@ -11,6 +11,8 @@ if t.TYPE_CHECKING: import datetime + from singer_sdk.helpers import types + _T = t.TypeVar("_T", datetime.datetime, str, int, float) PROGRESS_MARKERS = "progress_markers" @@ -70,7 +72,7 @@ def get_state_partitions_list(tap_state: dict, tap_stream_id: str) -> list[dict] def _find_in_partitions_list( partitions: list[dict], - state_partition_context: dict, + state_partition_context: types.Context, ) -> dict | None: found = [ partition_state @@ -88,7 +90,7 @@ def _find_in_partitions_list( def _create_in_partitions_list( partitions: list[dict], - state_partition_context: dict, + state_partition_context: types.Context, ) -> dict: # Existing partition not found. Creating new state entry in partitions list... new_partition_state = {"context": state_partition_context} @@ -99,7 +101,7 @@ def _create_in_partitions_list( def get_writeable_state_dict( tap_state: dict, tap_stream_id: str, - state_partition_context: dict | None = None, + state_partition_context: types.Context | None = None, ) -> dict: """Return the stream or partition state, creating a new one if it does not exist. @@ -283,8 +285,8 @@ def log_sort_error( ex: Exception, log_fn: t.Callable, stream_name: str, - current_context: dict | None, - state_partition_context: dict | None, + current_context: types.Context | None, + state_partition_context: types.Context | None, record_count: int, partition_record_count: int, ) -> None: diff --git a/singer_sdk/helpers/types.py b/singer_sdk/helpers/types.py new file mode 100644 index 000000000..783dbe914 --- /dev/null +++ b/singer_sdk/helpers/types.py @@ -0,0 +1,24 @@ +"""Type aliases for use in the SDK.""" + +from __future__ import annotations + +import sys +import typing as t + +if sys.version_info < (3, 9): + from typing import Mapping # noqa: ICN003 +else: + from collections.abc import Mapping + +if sys.version_info < (3, 10): + from typing_extensions import TypeAlias +else: + from typing import TypeAlias # noqa: ICN003 + +__all__ = [ + "Context", + "Record", +] + +Context: TypeAlias = Mapping +Record: TypeAlias = t.Dict[str, t.Any] diff --git a/singer_sdk/metrics.py b/singer_sdk/metrics.py index 990285ae0..8a7efe51e 100644 --- a/singer_sdk/metrics.py +++ b/singer_sdk/metrics.py @@ -20,8 +20,10 @@ if t.TYPE_CHECKING: from types import TracebackType + from singer_sdk.helpers import types from singer_sdk.helpers._compat import Traversable + DEFAULT_LOG_INTERVAL = 60.0 METRICS_LOGGER_NAME = __name__ METRICS_LOG_LEVEL_SETTING = "metrics_log_level" @@ -117,7 +119,7 @@ def __init__(self, metric: Metric, tags: dict | None = None) -> None: self.logger = get_metrics_logger() @property - def context(self) -> dict | None: + def context(self) -> types.Context | None: """Get the context for this meter. Returns: @@ -126,7 +128,7 @@ def context(self) -> dict | None: return self.tags.get(Tag.CONTEXT) @context.setter - def context(self, value: dict | None) -> None: + def context(self, value: types.Context | None) -> None: """Set the context for this meter. Args: diff --git a/singer_sdk/streams/core.py b/singer_sdk/streams/core.py index 0b4d6eef6..c588a9729 100644 --- a/singer_sdk/streams/core.py +++ b/singer_sdk/streams/core.py @@ -6,7 +6,6 @@ import copy import datetime import json -import sys import typing as t from os import PathLike from pathlib import Path @@ -50,14 +49,10 @@ from singer_sdk.helpers._util import utc_now from singer_sdk.mapper import RemoveRecordTransform, SameRecordTransform, StreamMap -if sys.version_info < (3, 10): - from typing_extensions import TypeAlias -else: - from typing import TypeAlias # noqa: ICN003 - if t.TYPE_CHECKING: import logging + from singer_sdk.helpers import types from singer_sdk.helpers._compat import Traversable from singer_sdk.tap_base import Tap @@ -66,13 +61,15 @@ REPLICATION_INCREMENTAL = "INCREMENTAL" REPLICATION_LOG_BASED = "LOG_BASED" -FactoryType = t.TypeVar("FactoryType", bound="Stream") -Record: TypeAlias = t.Dict[str, t.Any] -Context: TypeAlias = t.Dict - class Stream(metaclass=abc.ABCMeta): # noqa: PLR0904 - """Abstract base class for tap streams.""" + """Abstract base class for tap streams. + + :ivar context: Stream partition or context dictionary. + + .. versionadded:: 0.39.0 + The ``context`` attribute. + """ STATE_MSG_FREQUENCY = 10000 """Number of records between state messages.""" @@ -134,6 +131,8 @@ def __init__( self.logger: logging.Logger = tap.logger.getChild(self.name) self.metrics_logger = tap.metrics_logger self.tap_name: str = tap.name + self.context: types.Context | None = None + self._config: dict = dict(tap.config) self._tap = tap self._tap_state = tap.state @@ -234,7 +233,7 @@ def is_timestamp_replication_key(self) -> bool: def get_starting_replication_key_value( self, - context: Context | None, + context: types.Context | None, ) -> t.Any | None: # noqa: ANN401 """Get starting replication key. @@ -260,7 +259,8 @@ def get_starting_replication_key_value( ) def get_starting_timestamp( - self, context: Context | None + self, + context: types.Context | None, ) -> datetime.datetime | None: """Get starting replication timestamp. @@ -340,7 +340,7 @@ def descendent_streams(self) -> list[Stream]: def _write_replication_key_signpost( self, - context: Context | None, + context: types.Context | None, value: datetime.datetime | str | int | float, ) -> None: """Write the signpost value, if available. @@ -381,7 +381,7 @@ def compare_start_date(self, value: str, start_date_value: str) -> str: return value - def _write_starting_replication_value(self, context: Context | None) -> None: + def _write_starting_replication_value(self, context: types.Context | None) -> None: """Write the starting replication value, if available. Args: @@ -409,7 +409,7 @@ def _write_starting_replication_value(self, context: Context | None) -> None: def get_replication_key_signpost( self, - context: Context | None, # noqa: ARG002 + context: types.Context | None, # noqa: ARG002 ) -> datetime.datetime | t.Any | None: # noqa: ANN401 """Get the replication signpost. @@ -656,7 +656,7 @@ def tap_state(self) -> dict: """ return self._tap_state - def get_context_state(self, context: Context | None) -> dict: + def get_context_state(self, context: types.Context | None) -> dict: """Return a writable state dict for the given context. Gives a partitioned context state if applicable; else returns stream state. @@ -711,7 +711,7 @@ def stream_state(self) -> dict: # Partitions @property - def partitions(self) -> list[Context] | None: + def partitions(self) -> list[types.Context] | None: """Get stream partitions. Developers may override this property to provide a default partitions list. @@ -722,7 +722,7 @@ def partitions(self) -> list[Context] | None: Returns: A list of partition key dicts (if applicable), otherwise `None`. """ - result: list[dict] = [ + result: list[types.Mapping] = [ partition_state["context"] for partition_state in ( get_state_partitions_list(self.tap_state, self.name) or [] @@ -734,9 +734,9 @@ def partitions(self) -> list[Context] | None: def _increment_stream_state( self, - latest_record: Record, + latest_record: types.Record, *, - context: Context | None = None, + context: types.Context | None = None, ) -> None: """Update state of stream or partition with data from the provided record. @@ -827,7 +827,7 @@ def mask(self) -> singer.SelectionMask: def _generate_record_messages( self, - record: Record, + record: types.Record, ) -> t.Generator[singer.RecordMessage, None, None]: """Write out a RECORD message. @@ -856,7 +856,7 @@ def _generate_record_messages( time_extracted=utc_now(), ) - def _write_record_message(self, record: Record) -> None: + def _write_record_message(self, record: types.Record) -> None: """Write out a RECORD message. Args: @@ -973,7 +973,7 @@ def reset_state_progress_markers(self, state: dict | None = None) -> None: state: State object to promote progress markers with. """ if state is None or state == {}: - context: Context | None + context: types.Context | None for context in self.partitions or [{}]: state = self.get_context_state(context or None) reset_state_progress_markers(state) @@ -1002,7 +1002,7 @@ def finalize_state_progress_markers(self, state: dict | None = None) -> None: for child_stream in self.child_streams or []: child_stream.finalize_state_progress_markers() - context: Context | None + context: types.Context | None for context in self.partitions or [{}]: state = self.get_context_state(context or None) self._finalize_state(state) @@ -1015,9 +1015,9 @@ def finalize_state_progress_markers(self, state: dict | None = None) -> None: def _process_record( self, - record: Record, - child_context: Context | None = None, - partition_context: Context | None = None, + record: types.Record, + child_context: types.Context | None = None, + partition_context: types.Context | None = None, ) -> None: """Process a record. @@ -1042,7 +1042,7 @@ def _process_record( def _sync_records( # noqa: C901 self, - context: Context | None = None, + context: types.Context | None = None, *, write_messages: bool = True, ) -> t.Generator[dict, t.Any, t.Any]: @@ -1064,8 +1064,8 @@ def _sync_records( # noqa: C901 timer = metrics.sync_timer(self.name) record_index = 0 - context_element: Context | None - context_list: list[dict] | None + context_element: types.Context | None + context_list: list[types.Context] | None context_list = [context] if context is not None else self.partitions selected = self.selected @@ -1080,7 +1080,7 @@ def _sync_records( # noqa: C901 current_context, ) self._write_starting_replication_value(current_context) - child_context: Context | None = ( + child_context: types.Context | None = ( None if current_context is None else copy.copy(current_context) ) @@ -1141,7 +1141,7 @@ def _sync_records( # noqa: C901 def _sync_batches( self, batch_config: BatchConfig, - context: Context | None = None, + context: types.Context | None = None, ) -> None: """Sync batches, emitting BATCH messages. @@ -1158,7 +1158,7 @@ def _sync_batches( # Public methods ("final", not recommended to be overridden) @t.final - def sync(self, context: Context | None = None) -> None: + def sync(self, context: types.Context | None = None) -> None: """Sync this stream. This method is internal to the SDK and should not need to be overridden. @@ -1173,6 +1173,7 @@ def sync(self, context: Context | None = None) -> None: if context: msg += f" with context: {context}" self.logger.info("%s...", msg) + self.context = MappingProxyType(context) if context else None # Use a replication signpost, if available signpost = self.get_replication_key_signpost(context) @@ -1198,7 +1199,7 @@ def sync(self, context: Context | None = None) -> None: ) raise - def _sync_children(self, child_context: Context | None) -> None: + def _sync_children(self, child_context: types.Context | None) -> None: if child_context is None: self.logger.warning( "Context for child streams of '%s' is null, " @@ -1233,7 +1234,10 @@ def apply_catalog(self, catalog: singer.Catalog) -> None: if catalog_entry.replication_method: self.forced_replication_method = catalog_entry.replication_method - def _get_state_partition_context(self, context: Context | None) -> dict | None: + def _get_state_partition_context( + self, + context: types.Context | None, + ) -> types.Context | None: """Override state handling if Stream.state_partitioning_keys is specified. Args: @@ -1252,9 +1256,9 @@ def _get_state_partition_context(self, context: Context | None) -> dict | None: def get_child_context( self, - record: Record, - context: Context | None, - ) -> dict | None: + record: types.Record, + context: types.Context | None, + ) -> types.Context | None: """Return a child context object from the record and optional provided context. By default, will return context if provided and otherwise the record dict. @@ -1295,9 +1299,9 @@ def get_child_context( def generate_child_contexts( self, - record: Record, - context: Context | None, - ) -> t.Iterable[dict | None]: + record: types.Record, + context: types.Context | None, + ) -> t.Iterable[types.Context | None]: """Generate child contexts. Args: @@ -1314,7 +1318,7 @@ def generate_child_contexts( @abc.abstractmethod def get_records( self, - context: Context | None, + context: types.Context | None, ) -> t.Iterable[dict | tuple[dict, dict | None]]: """Abstract record generator function. Must be overridden by the child class. @@ -1360,7 +1364,7 @@ def get_batch_config(self, config: t.Mapping) -> BatchConfig | None: # noqa: PL def get_batches( self, batch_config: BatchConfig, - context: Context | None = None, + context: types.Context | None = None, ) -> t.Iterable[tuple[BaseBatchFileEncoding, list[str]]]: """Batch generator function. @@ -1385,8 +1389,8 @@ def get_batches( def post_process( # noqa: PLR6301 self, - row: Record, - context: Context | None = None, # noqa: ARG002 + row: types.Record, + context: types.Context | None = None, # noqa: ARG002 ) -> dict | None: """As needed, append or transform raw data to match expected structure. diff --git a/singer_sdk/streams/graphql.py b/singer_sdk/streams/graphql.py index 04a2e80d6..4e5455bc3 100644 --- a/singer_sdk/streams/graphql.py +++ b/singer_sdk/streams/graphql.py @@ -9,7 +9,7 @@ from singer_sdk.streams.rest import RESTStream if t.TYPE_CHECKING: - from singer_sdk.streams.core import Context + from singer_sdk.helpers.types import Context _TToken = t.TypeVar("_TToken") diff --git a/singer_sdk/streams/rest.py b/singer_sdk/streams/rest.py index a241396c2..5aff95346 100644 --- a/singer_sdk/streams/rest.py +++ b/singer_sdk/streams/rest.py @@ -32,7 +32,7 @@ from backoff.types import Details from singer_sdk._singerlib import Schema - from singer_sdk.streams.core import Context + from singer_sdk.helpers.types import Context from singer_sdk.tap_base import Tap if sys.version_info >= (3, 10): diff --git a/singer_sdk/streams/sql.py b/singer_sdk/streams/sql.py index 2b610a2a5..04a3f9d17 100644 --- a/singer_sdk/streams/sql.py +++ b/singer_sdk/streams/sql.py @@ -14,7 +14,7 @@ from singer_sdk.streams.core import Stream if t.TYPE_CHECKING: - from singer_sdk.streams.core import Context + from singer_sdk.helpers.types import Context from singer_sdk.tap_base import Tap