Skip to content

Commit

Permalink
feat: Stream sync context is now available to all instances methods a…
Browse files Browse the repository at this point in the history
…s a `Stream.context` attribute (#2529)
  • Loading branch information
edgarrmondragon authored Jul 10, 2024
1 parent 6256fe5 commit 5500e45
Show file tree
Hide file tree
Showing 11 changed files with 107 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Iterable
from typing import TYPE_CHECKING, Iterable

import requests # noqa: TCH002
from singer_sdk.streams import {{ cookiecutter.stream_type }}Stream
Expand All @@ -12,6 +12,9 @@
from {{ cookiecutter.library_name }}.auth import {{ cookiecutter.source_name }}Authenticator
{%- endif %}

if TYPE_CHECKING:
from singer_sdk.helpers.types import Context


class {{ cookiecutter.source_name }}Stream({{ cookiecutter.stream_type }}Stream):
"""{{ cookiecutter.source_name }} stream class."""
Expand Down Expand Up @@ -67,7 +70,7 @@ def parse_response(self, response: requests.Response) -> Iterable[dict]:
def post_process(
self,
row: dict,
context: dict | None = None, # noqa: ARG002
context: Context | None = None, # noqa: ARG002
) -> dict | None:
"""As needed, append or transform raw data to match expected structure.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,20 @@

from __future__ import annotations

from typing import Iterable
from typing import TYPE_CHECKING, Iterable

from singer_sdk.streams import Stream

if TYPE_CHECKING:
from singer_sdk.helpers.types import Context


class {{ cookiecutter.source_name }}Stream(Stream):
"""Stream class for {{ cookiecutter.source_name }} streams."""

def get_records(
self,
context: dict | None, # noqa: ARG002
context: Context | None, # noqa: ARG002
) -> Iterable[dict]:
"""Return a generator of record-type dictionary objects.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
{%- if cookiecutter.auth_method in ("OAuth2", "JWT") %}
from functools import cached_property
{%- endif %}
from typing import Any, Callable, Iterable
from typing import TYPE_CHECKING, Any, Callable, Iterable

import requests
{% if cookiecutter.auth_method == "API Key" -%}
Expand Down Expand Up @@ -46,6 +46,10 @@
else:
import importlib_resources

if TYPE_CHECKING:
from singer_sdk.helpers.types import Context


_Auth = Callable[[requests.PreparedRequest], requests.PreparedRequest]

# TODO: Delete this is if not using json files for schema definition
Expand Down Expand Up @@ -157,7 +161,7 @@ def get_new_paginator(self) -> BaseAPIPaginator:

def get_url_params(
self,
context: dict | None, # noqa: ARG002
context: Context | None, # noqa: ARG002
next_page_token: Any | None, # noqa: ANN401
) -> dict[str, Any]:
"""Return a dictionary of values to be used in URL parameterization.
Expand All @@ -179,7 +183,7 @@ def get_url_params(

def prepare_request_payload(
self,
context: dict | None, # noqa: ARG002
context: Context | None, # noqa: ARG002
next_page_token: Any | None, # noqa: ARG002, ANN401
) -> dict | None:
"""Prepare the data payload for the REST API request.
Expand Down Expand Up @@ -211,7 +215,7 @@ def parse_response(self, response: requests.Response) -> Iterable[dict]:
def post_process(
self,
row: dict,
context: dict | None = None, # noqa: ARG002
context: Context | None = None, # noqa: ARG002
) -> dict | None:
"""As needed, append or transform raw data to match expected structure.
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ omit = [
"tests/*",
"samples/*",
"singer_sdk/helpers/_compat.py",
"singer_sdk/helpers/types.py",
]

[tool.coverage.report]
Expand Down
12 changes: 7 additions & 5 deletions singer_sdk/helpers/_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
if t.TYPE_CHECKING:
import datetime

from singer_sdk.helpers import types

_T = t.TypeVar("_T", datetime.datetime, str, int, float)

PROGRESS_MARKERS = "progress_markers"
Expand Down Expand Up @@ -70,7 +72,7 @@ def get_state_partitions_list(tap_state: dict, tap_stream_id: str) -> list[dict]

def _find_in_partitions_list(
partitions: list[dict],
state_partition_context: dict,
state_partition_context: types.Context,
) -> dict | None:
found = [
partition_state
Expand All @@ -88,7 +90,7 @@ def _find_in_partitions_list(

def _create_in_partitions_list(
partitions: list[dict],
state_partition_context: dict,
state_partition_context: types.Context,
) -> dict:
# Existing partition not found. Creating new state entry in partitions list...
new_partition_state = {"context": state_partition_context}
Expand All @@ -99,7 +101,7 @@ def _create_in_partitions_list(
def get_writeable_state_dict(
tap_state: dict,
tap_stream_id: str,
state_partition_context: dict | None = None,
state_partition_context: types.Context | None = None,
) -> dict:
"""Return the stream or partition state, creating a new one if it does not exist.
Expand Down Expand Up @@ -283,8 +285,8 @@ def log_sort_error(
ex: Exception,
log_fn: t.Callable,
stream_name: str,
current_context: dict | None,
state_partition_context: dict | None,
current_context: types.Context | None,
state_partition_context: types.Context | None,
record_count: int,
partition_record_count: int,
) -> None:
Expand Down
24 changes: 24 additions & 0 deletions singer_sdk/helpers/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Type aliases for use in the SDK."""

from __future__ import annotations

import sys
import typing as t

if sys.version_info < (3, 9):
from typing import Mapping # noqa: ICN003
else:
from collections.abc import Mapping

if sys.version_info < (3, 10):
from typing_extensions import TypeAlias
else:
from typing import TypeAlias # noqa: ICN003

__all__ = [
"Context",
"Record",
]

Context: TypeAlias = Mapping
Record: TypeAlias = t.Dict[str, t.Any]
6 changes: 4 additions & 2 deletions singer_sdk/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
if t.TYPE_CHECKING:
from types import TracebackType

from singer_sdk.helpers import types
from singer_sdk.helpers._compat import Traversable


DEFAULT_LOG_INTERVAL = 60.0
METRICS_LOGGER_NAME = __name__
METRICS_LOG_LEVEL_SETTING = "metrics_log_level"
Expand Down Expand Up @@ -117,7 +119,7 @@ def __init__(self, metric: Metric, tags: dict | None = None) -> None:
self.logger = get_metrics_logger()

@property
def context(self) -> dict | None:
def context(self) -> types.Context | None:
"""Get the context for this meter.
Returns:
Expand All @@ -126,7 +128,7 @@ def context(self) -> dict | None:
return self.tags.get(Tag.CONTEXT)

@context.setter
def context(self, value: dict | None) -> None:
def context(self, value: types.Context | None) -> None:
"""Set the context for this meter.
Args:
Expand Down
Loading

0 comments on commit 5500e45

Please sign in to comment.