Skip to content

Commit

Permalink
refactor: Allow loading stream schemas from `importlib.resources.abc.…
Browse files Browse the repository at this point in the history
…Traversable` types
  • Loading branch information
edgarrmondragon committed Jan 2, 2024
1 parent 5624ef0 commit 16d8a17
Show file tree
Hide file tree
Showing 9 changed files with 32 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ on:
- ".github/workflows/test.yml"
- ".github/workflows/constraints.txt"
push:
branches: [main]
# branches: [main]
paths:
- "cookiecutter/**"
- "samples/**"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,38 @@
{% if cookiecutter.auth_method in ("OAuth2", "JWT") -%}
import sys
{% endif -%}
from pathlib import Path
from typing import Any, Callable, Iterable

import requests
{% if cookiecutter.auth_method == "API Key" -%}
from singer_sdk.authenticators import APIKeyAuthenticator
from singer_sdk.helpers._compat import resources
from singer_sdk.helpers.jsonpath import extract_jsonpath
from singer_sdk.pagination import BaseAPIPaginator # noqa: TCH002
from singer_sdk.streams import {{ cookiecutter.stream_type }}Stream

{% elif cookiecutter.auth_method == "Bearer Token" -%}
from singer_sdk.authenticators import BearerTokenAuthenticator
from singer_sdk.helpers._compat import resources
from singer_sdk.helpers.jsonpath import extract_jsonpath
from singer_sdk.pagination import BaseAPIPaginator # noqa: TCH002
from singer_sdk.streams import {{ cookiecutter.stream_type }}Stream

{% elif cookiecutter.auth_method == "Basic Auth" -%}
from singer_sdk.authenticators import BasicAuthenticator
from singer_sdk.helpers._compat import resources
from singer_sdk.helpers.jsonpath import extract_jsonpath
from singer_sdk.pagination import BaseAPIPaginator # noqa: TCH002
from singer_sdk.streams import {{ cookiecutter.stream_type }}Stream

{% elif cookiecutter.auth_method == "Custom or N/A" -%}
from singer_sdk.helpers._compat import resources
from singer_sdk.helpers.jsonpath import extract_jsonpath
from singer_sdk.pagination import BaseAPIPaginator # noqa: TCH002
from singer_sdk.streams import {{ cookiecutter.stream_type }}Stream

{% elif cookiecutter.auth_method in ("OAuth2", "JWT") -%}
from singer_sdk.helpers._compat import resources
from singer_sdk.helpers.jsonpath import extract_jsonpath
from singer_sdk.pagination import BaseAPIPaginator # noqa: TCH002
from singer_sdk.streams import {{ cookiecutter.stream_type }}Stream
Expand All @@ -50,7 +54,9 @@
{% endif -%}

_Auth = Callable[[requests.PreparedRequest], requests.PreparedRequest]
SCHEMAS_DIR = Path(__file__).parent / Path("./schemas")

# TODO: Delete this is if not using json files for schema definition
SCHEMAS_DIR = resources.files(__package__) / "schemas"


class {{ cookiecutter.source_name }}Stream({{ cookiecutter.stream_type }}Stream):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from __future__ import annotations

import typing as t
from pathlib import Path

from singer_sdk import typing as th # JSON Schema typing helpers
from singer_sdk.helpers._compat import resources

from {{ cookiecutter.library_name }}.client import {{ cookiecutter.source_name }}Stream

# TODO: Delete this is if not using json files for schema definition
SCHEMAS_DIR = Path(__file__).parent / Path("./schemas")
SCHEMAS_DIR = resources.files(__package__) / "schemas"


{%- if cookiecutter.stream_type == "GraphQL" %}
Expand Down
4 changes: 2 additions & 2 deletions samples/sample_tap_countries/countries_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
from __future__ import annotations

import abc
from pathlib import Path

from singer_sdk import typing as th
from singer_sdk.helpers._compat import resources
from singer_sdk.streams.graphql import GraphQLStream

SCHEMAS_DIR = Path(__file__).parent / Path("./schemas")
SCHEMAS_DIR = resources.files(__package__) / "schemas"


class CountriesAPIStream(GraphQLStream, metaclass=abc.ABCMeta):
Expand Down
5 changes: 2 additions & 3 deletions samples/sample_tap_gitlab/gitlab_graphql_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@

from __future__ import annotations

from pathlib import Path

from singer_sdk.helpers._compat import resources
from singer_sdk.streams import GraphQLStream

SITE_URL = "https://gitlab.com/graphql"

SCHEMAS_DIR = Path(__file__).parent / Path("./schemas")
SCHEMAS_DIR = resources.files(__package__) / "schemas"


class GitlabGraphQLStream(GraphQLStream):
Expand Down
4 changes: 2 additions & 2 deletions samples/sample_tap_gitlab/gitlab_rest_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from __future__ import annotations

import typing as t
from pathlib import Path

from singer_sdk.authenticators import SimpleAuthenticator
from singer_sdk.helpers._compat import resources
from singer_sdk.pagination import SimpleHeaderPaginator
from singer_sdk.streams.rest import RESTStream
from singer_sdk.typing import (
Expand All @@ -17,7 +17,7 @@
StringType,
)

SCHEMAS_DIR = Path(__file__).parent / Path("./schemas")
SCHEMAS_DIR = resources.files(__package__) / "schemas"

DEFAULT_URL_BASE = "https://gitlab.com/api/v4"

Expand Down
4 changes: 2 additions & 2 deletions samples/sample_tap_google_analytics/ga_tap_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@

import datetime
import typing as t
from pathlib import Path

from singer_sdk.authenticators import OAuthJWTAuthenticator
from singer_sdk.helpers._compat import resources
from singer_sdk.streams import RESTStream

GOOGLE_OAUTH_ENDPOINT = "https://oauth2.googleapis.com/token"
GA_OAUTH_SCOPES = "https://www.googleapis.com/auth/analytics.readonly"
SCHEMAS_DIR = Path(__file__).parent / Path("./schemas")
SCHEMAS_DIR = resources.files(__package__) / "schemas"


class GoogleJWTAuthenticator(OAuthJWTAuthenticator):
Expand Down
9 changes: 9 additions & 0 deletions singer_sdk/helpers/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@
else:
from importlib import resources

if sys.version_info < (3, 9):
from importlib_resources.abc import Traversable
elif sys.version_info < (3, 12):
from importlib.abc import Traversable
else:
from importlib.resources.abc import Traversable


if sys.version_info < (3, 11):
from backports.datetime_fromisoformat import MonkeyPatch

Expand All @@ -35,6 +43,7 @@
"metadata",
"final",
"resources",
"Traversable",
"entry_points",
"datetime_fromisoformat",
"date_fromisoformat",
Expand Down
7 changes: 4 additions & 3 deletions singer_sdk/streams/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
if t.TYPE_CHECKING:
import logging

from singer_sdk.helpers._compat import Traversable
from singer_sdk.tap_base import Tap

# Replication methods
Expand Down Expand Up @@ -136,7 +137,7 @@ def __init__(
self._replication_key: str | None = None
self._primary_keys: t.Sequence[str] | None = None
self._state_partitioning_keys: list[str] | None = None
self._schema_filepath: Path | None = None
self._schema_filepath: Path | Traversable | None = None
self._metadata: singer.MetadataMapping | None = None
self._mask: singer.SelectionMask | None = None
self._schema: dict
Expand All @@ -160,7 +161,7 @@ def __init__(
raise ValueError(msg)

if self.schema_filepath:
self._schema = json.loads(Path(self.schema_filepath).read_text())
self._schema = json.loads(self.schema_filepath.read_text())

if not self.schema:
msg = (
Expand Down Expand Up @@ -421,7 +422,7 @@ def get_replication_key_signpost(
return utc_now() if self.is_timestamp_replication_key else None

@property
def schema_filepath(self) -> Path | None:
def schema_filepath(self) -> Path | Traversable | None:
"""Get path to schema file.
Returns:
Expand Down

0 comments on commit 16d8a17

Please sign in to comment.