diff --git a/singer_sdk/helpers/_batch.py b/singer_sdk/helpers/_batch.py new file mode 100644 index 0000000000..161e0097c7 --- /dev/null +++ b/singer_sdk/helpers/_batch.py @@ -0,0 +1,244 @@ +"""Batch helpers.""" + +from __future__ import annotations + +import enum +import sys +from contextlib import contextmanager +from dataclasses import asdict, dataclass, field +from typing import IO, TYPE_CHECKING, Any, ClassVar, Generator +from urllib.parse import ParseResult, parse_qs, urlencode, urlparse + +import fs +from singer.messages import Message + +from singer_sdk.helpers._singer import SingerMessageType + +if TYPE_CHECKING: + from fs.base import FS + + if sys.version_info >= (3, 8): + from typing import Literal + else: + from typing_extensions import Literal + + +class BatchFileFormat(str, enum.Enum): + """Batch file format.""" + + JSONL = "jsonl" + """JSON Lines format.""" + + +@dataclass +class BaseBatchFileEncoding: + """Base class for batch file encodings.""" + + registered_encodings: ClassVar[dict[str, type[BaseBatchFileEncoding]]] = {} + __encoding_format__: ClassVar[str] = "OVERRIDE_ME" + + # Base encoding fields + format: str = field(init=False) + """The format of the batch file.""" + + compression: str | None = None + """The compression of the batch file.""" + + def __init_subclass__(cls, **kwargs: Any) -> None: + """Register subclasses. + + Args: + **kwargs: Keyword arguments. + """ + super().__init_subclass__(**kwargs) + cls.registered_encodings[cls.__encoding_format__] = cls + + def __post_init__(self) -> None: + """Post-init hook.""" + self.format = self.__encoding_format__ + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> BaseBatchFileEncoding: + """Create an encoding from a dictionary.""" + data = data.copy() + encoding_format = data.pop("format") + encoding_cls = cls.registered_encodings[encoding_format] + return encoding_cls(**data) + + +@dataclass +class JSONLinesEncoding(BaseBatchFileEncoding): + """JSON Lines encoding for batch files.""" + + __encoding_format__ = "jsonl" + + +@dataclass +class SDKBatchMessage(Message): + """Singer batch message in the Meltano SDK flavor.""" + + type: Literal[SingerMessageType.BATCH] = field(init=False) + """The message type.""" + + stream: str + """The stream name.""" + + encoding: BaseBatchFileEncoding + """The file encoding of the batch.""" + + manifest: list[str] = field(default_factory=list) + """The manifest of files in the batch.""" + + def __post_init__(self): + if isinstance(self.encoding, dict): + self.encoding = BaseBatchFileEncoding.from_dict(self.encoding) + + self.type = SingerMessageType.BATCH + + def asdict(self): + """Return a dictionary representation of the message. + + Returns: + A dictionary with the defined message fields. + """ + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> SDKBatchMessage: + """Create an encoding from a dictionary. + + Args: + data: The dictionary to create the message from. + + Returns: + The created message. + """ + data.pop("type") + return cls(**data) + + +@dataclass +class StorageTarget: + """Storage target.""" + + root: str + """"The root directory of the storage target.""" + + prefix: str | None = None + """"The file prefix.""" + + params: dict = field(default_factory=dict) + """"The storage parameters.""" + + def asdict(self): + """Return a dictionary representation of the message. + + Returns: + A dictionary with the defined message fields. + """ + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> StorageTarget: + """Create an encoding from a dictionary. + + Args: + data: The dictionary to create the message from. + + Returns: + The created message. + """ + return cls(**data) + + @classmethod + def from_url(cls, url: ParseResult) -> StorageTarget: + """Create a storage target from a URL. + + Args: + url: The URL to create the storage target from. + + Returns: + The created storage target. + """ + new_url = url._replace(path="", query="") + return cls(root=new_url.geturl(), params=parse_qs(url.query)) + + @property + def fs_url(self) -> ParseResult: + """Get the storage target URL. + + Returns: + The storage target URL. + """ + return urlparse(self.root)._replace(query=urlencode(self.params)) + + @contextmanager + def fs(self, **kwargs: Any) -> Generator[FS, None, None]: + """Get a filesystem object for the storage target. + + Args: + kwargs: Additional arguments to pass ``f`.open_fs``. + + Returns: + The filesystem object. + """ + filesystem = fs.open_fs(self.fs_url.geturl(), **kwargs) + yield filesystem + filesystem.close() + + @contextmanager + def open(self, filename: str, mode: str = "rb") -> Generator[IO, None, None]: + """Open a file in the storage target. + + Args: + filename: The filename to open. + mode: The mode to open the file in. + + Returns: + The opened file. + """ + filesystem = fs.open_fs(self.root, writeable=True, create=True) + fo = filesystem.open(filename, mode=mode) + try: + yield fo + finally: + fo.close() + filesystem.close() + + +@dataclass +class BatchConfig: + """Batch configuration.""" + + encoding: BaseBatchFileEncoding + """The encoding of the batch file.""" + + storage: StorageTarget + """The storage target of the batch file.""" + + def __post_init__(self): + if isinstance(self.encoding, dict): + self.encoding = BaseBatchFileEncoding.from_dict(self.encoding) + + if isinstance(self.storage, dict): + self.storage = StorageTarget.from_dict(self.storage) + + def asdict(self): + """Return a dictionary representation of the message. + + Returns: + A dictionary with the defined message fields. + """ + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> BatchConfig: + """Create an encoding from a dictionary. + + Args: + data: The dictionary to create the message from. + + Returns: + The created message. + """ + return cls(**data) diff --git a/singer_sdk/helpers/_singer.py b/singer_sdk/helpers/_singer.py index 326c7d168a..c756f84a67 100644 --- a/singer_sdk/helpers/_singer.py +++ b/singer_sdk/helpers/_singer.py @@ -2,36 +2,17 @@ import enum import logging -import sys -from contextlib import contextmanager -from dataclasses import asdict, dataclass, field, fields -from typing import ( - IO, - TYPE_CHECKING, - Any, - ClassVar, - Dict, - Generator, - Iterable, - Tuple, - Union, - cast, -) - -import fs +from dataclasses import dataclass, fields +from typing import TYPE_CHECKING, Any, Dict, Iterable, Tuple, Union, cast + from singer.catalog import Catalog as BaseCatalog from singer.catalog import CatalogEntry as BaseCatalogEntry -from singer.messages import Message from singer_sdk.helpers._schema import SchemaPlus if TYPE_CHECKING: from typing_extensions import TypeAlias - if sys.version_info >= (3, 8): - from typing import Literal - else: - from typing_extensions import Literal Breadcrumb = Tuple[str, ...] @@ -316,180 +297,3 @@ def add_stream(self, entry: CatalogEntry) -> None: def get_stream(self, stream_id: str) -> CatalogEntry | None: """Retrieve a stream entry from the catalog.""" return self.get(stream_id) - - -class BatchFileFormat(str, enum.Enum): - """Batch file format.""" - - JSONL = "jsonl" - """JSON Lines format.""" - - -@dataclass -class BaseBatchFileEncoding: - """Base class for batch file encodings.""" - - registered_encodings: ClassVar[dict[str, type[BaseBatchFileEncoding]]] = {} - __encoding_format__: ClassVar[str] = "OVERRIDE_ME" - - # Base encoding fields - format: str = field(init=False) - """The format of the batch file.""" - - compression: str | None = None - """The compression of the batch file.""" - - def __init_subclass__(cls, **kwargs: Any) -> None: - """Register subclasses.""" - super().__init_subclass__(**kwargs) - cls.registered_encodings[cls.__encoding_format__] = cls - - def __post_init__(self) -> None: - self.format = self.__encoding_format__ - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> BaseBatchFileEncoding: - """Create an encoding from a dictionary.""" - data = data.copy() - encoding_format = data.pop("format") - encoding_cls = cls.registered_encodings[encoding_format] - return encoding_cls(**data) - - -@dataclass -class JSONLinesEncoding(BaseBatchFileEncoding): - """JSON Lines encoding for batch files.""" - - __encoding_format__ = "jsonl" - - -@dataclass -class SDKBatchMessage(Message): - """Singer batch message in the Meltano SDK flavor.""" - - type: Literal[SingerMessageType.BATCH] = field(init=False) - """The message type.""" - - stream: str - """The stream name.""" - - encoding: BaseBatchFileEncoding - """The file encoding of the batch.""" - - manifest: list[str] = field(default_factory=list) - """The manifest of files in the batch.""" - - def __post_init__(self): - if isinstance(self.encoding, dict): - self.encoding = BaseBatchFileEncoding.from_dict(self.encoding) - - self.type = SingerMessageType.BATCH - - def asdict(self): - """Return a dictionary representation of the message. - - Returns: - A dictionary with the defined message fields. - """ - return asdict(self) - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> SDKBatchMessage: - """Create an encoding from a dictionary. - - Args: - data: The dictionary to create the message from. - - Returns: - The created message. - """ - data.pop("type") - return cls(**data) - - -@dataclass -class StorageTarget: - """Storage target.""" - - root: str - """"The root directory of the storage target.""" - - prefix: str - """"The file prefix.""" - - def asdict(self): - """Return a dictionary representation of the message. - - Returns: - A dictionary with the defined message fields. - """ - return asdict(self) - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> StorageTarget: - """Create an encoding from a dictionary. - - Args: - data: The dictionary to create the message from. - - Returns: - The created message. - """ - return cls(**data) - - @contextmanager - def open(self, filename: str, mode: str = "rb") -> Generator[IO, None, None]: - """Open a file in the storage target. - - Args: - filename: The filename to open. - mode: The mode to open the file in. - - Returns: - The opened file. - """ - filesystem = fs.open_fs(self.root, writeable=True, create=True) - fo = filesystem.open(filename, mode=mode) - try: - yield fo - finally: - fo.close() - filesystem.close() - - -@dataclass -class BatchConfig: - """Batch configuration.""" - - encoding: BaseBatchFileEncoding - """The encoding of the batch file.""" - - storage: StorageTarget - """The storage target of the batch file.""" - - def __post_init__(self): - if isinstance(self.encoding, dict): - self.encoding = BaseBatchFileEncoding.from_dict(self.encoding) - - if isinstance(self.storage, dict): - self.storage = StorageTarget.from_dict(self.storage) - - def asdict(self): - """Return a dictionary representation of the message. - - Returns: - A dictionary with the defined message fields. - """ - return asdict(self) - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> BatchConfig: - """Create an encoding from a dictionary. - - Args: - data: The dictionary to create the message from. - - Returns: - The created message. - """ - return cls(**data) diff --git a/singer_sdk/sinks/core.py b/singer_sdk/sinks/core.py index 470b4e6a62..dfd4afce15 100644 --- a/singer_sdk/sinks/core.py +++ b/singer_sdk/sinks/core.py @@ -11,17 +11,18 @@ from logging import Logger from types import MappingProxyType from typing import IO, Any, Mapping, Sequence +from urllib.parse import urlparse from dateutil import parser from jsonschema import Draft4Validator, FormatChecker -from singer_sdk.helpers._compat import final -from singer_sdk.helpers._singer import ( +from singer_sdk.helpers._batch import ( BaseBatchFileEncoding, BatchConfig, BatchFileFormat, StorageTarget, ) +from singer_sdk.helpers._compat import final from singer_sdk.helpers._typing import ( DatetimeErrorTreatmentEnum, get_datelike_property_type, @@ -435,27 +436,35 @@ def clean_up(self) -> None: def process_batch_files( self, encoding: BaseBatchFileEncoding, - storage: StorageTarget, files: Sequence[str], ) -> None: """Process a batch file with the given batch context. Args: encoding: The batch file encoding. - storage: The storage target. files: The batch files to process. Raises: NotImplementedError: If the batch file encoding is not supported. """ file: GzipFile | IO + storage: StorageTarget | None = None + for path in files: + url = urlparse(path) + + if self.batch_config: + storage = self.batch_config.storage + else: + storage = StorageTarget.from_url(url) + if encoding.format == BatchFileFormat.JSONL: - with storage.open(path) as file: - if encoding.compression == "gzip": - file = gzip_open(file) - context = {"records": [json.loads(line) for line in file]} - self.process_batch(context) + with storage.fs(create=False) as fs: + with fs.open(url.path, mode="rb") as file: + if encoding.compression == "gzip": + file = gzip_open(file) + context = {"records": [json.loads(line) for line in file]} + self.process_batch(context) else: raise NotImplementedError( f"Unsupported batch encoding format: {encoding.format}" diff --git a/singer_sdk/streams/core.py b/singer_sdk/streams/core.py index 075aaf6858..cd78c1ffc7 100644 --- a/singer_sdk/streams/core.py +++ b/singer_sdk/streams/core.py @@ -21,17 +21,19 @@ from singer import RecordMessage, Schema, SchemaMessage, StateMessage from singer_sdk.exceptions import InvalidStreamSortException, MaxRecordsLimitException +from singer_sdk.helpers._batch import ( + BaseBatchFileEncoding, + BatchConfig, + SDKBatchMessage, +) from singer_sdk.helpers._catalog import pop_deselected_record_properties from singer_sdk.helpers._compat import final from singer_sdk.helpers._flattening import get_flattening_options from singer_sdk.helpers._schema import SchemaPlus from singer_sdk.helpers._singer import ( - BaseBatchFileEncoding, - BatchConfig, Catalog, CatalogEntry, MetadataMapping, - SDKBatchMessage, SelectionMask, ) from singer_sdk.helpers._state import ( @@ -1300,7 +1302,7 @@ def get_batches( A tuple of (encoding, manifest) for each batch. """ sync_id = f"{self.tap_name}--{self.name}-{uuid4()}" - prefix = batch_config.storage.prefix + prefix = batch_config.storage.prefix or "" for i, chunk in enumerate( lazy_chunked_generator( @@ -1310,14 +1312,16 @@ def get_batches( start=1, ): filename = f"{prefix}{sync_id}-{i}.json.gz" - with batch_config.storage.open(filename, "wb") as f: - # TODO: Determine compression from config. - with gzip.GzipFile(fileobj=f, mode="wb") as gz: - gz.writelines( - (json.dumps(record) + "\n").encode() for record in chunk - ) - - yield batch_config.encoding, [filename] + with batch_config.storage.fs() as fs: + with fs.open(filename, "wb") as f: + # TODO: Determine compression from config. + with gzip.GzipFile(fileobj=f, mode="wb") as gz: + gz.writelines( + (json.dumps(record) + "\n").encode() for record in chunk + ) + file_url = fs.geturl(filename) + + yield batch_config.encoding, [file_url] def post_process(self, row: dict, context: dict | None = None) -> dict | None: """As needed, append or transform raw data to match expected structure. diff --git a/singer_sdk/target_base.py b/singer_sdk/target_base.py index d2c5631ba6..0b21b04c53 100644 --- a/singer_sdk/target_base.py +++ b/singer_sdk/target_base.py @@ -14,9 +14,9 @@ from singer_sdk.cli import common_options from singer_sdk.exceptions import RecordsWithoutSchemaException +from singer_sdk.helpers._batch import BaseBatchFileEncoding from singer_sdk.helpers._classproperty import classproperty from singer_sdk.helpers._compat import final -from singer_sdk.helpers._singer import BaseBatchFileEncoding from singer_sdk.helpers.capabilities import CapabilitiesEnum, PluginCapabilities from singer_sdk.io_base import SingerMessageType, SingerReader from singer_sdk.mapper import PluginMapper @@ -408,21 +408,12 @@ def _process_batch_message(self, message_dict: dict) -> None: Args: message_dict: TODO - - Raises: - RuntimeError: If the batch message can not be processed. """ sink = self.get_sink(message_dict["stream"]) - if sink.batch_config is None: - raise RuntimeError( - f"Received BATCH message for stream '{sink.stream_name}' " - "but no batch config was provided." - ) encoding = BaseBatchFileEncoding.from_dict(message_dict["encoding"]) sink.process_batch_files( encoding, - sink.batch_config.storage, message_dict["manifest"], ) diff --git a/tests/core/test_batch.py b/tests/core/test_batch.py new file mode 100644 index 0000000000..cd3aff10f9 --- /dev/null +++ b/tests/core/test_batch.py @@ -0,0 +1,40 @@ +from dataclasses import asdict +from urllib.parse import urlparse + +import pytest + +from singer_sdk.helpers._batch import ( + BaseBatchFileEncoding, + JSONLinesEncoding, + StorageTarget, +) + + +@pytest.mark.parametrize( + "encoding,expected", + [ + (JSONLinesEncoding("gzip"), {"compression": "gzip", "format": "jsonl"}), + (JSONLinesEncoding(), {"compression": None, "format": "jsonl"}), + ], + ids=["jsonl-compression-gzip", "jsonl-compression-none"], +) +def test_encoding_as_dict(encoding: BaseBatchFileEncoding, expected: dict) -> None: + """Test encoding as dict.""" + assert asdict(encoding) == expected + + +def test_storage_get_url(): + storage = StorageTarget("file://root_dir") + + with storage.fs(create=True) as fs: + url = fs.geturl("prefix--file.jsonl.gz") + assert url.startswith("file://") + assert url.endswith("root_dir/prefix--file.jsonl.gz") + + +def test_storage_from_url(): + url = urlparse("s3://bucket/path/to/file?region=us-east-1") + target = StorageTarget.from_url(url) + assert target.root == "s3://bucket" + assert target.prefix is None + assert target.params == {"region": ["us-east-1"]} diff --git a/tests/core/test_singer_messages.py b/tests/core/test_singer_messages.py index 185606631e..3858731a9f 100644 --- a/tests/core/test_singer_messages.py +++ b/tests/core/test_singer_messages.py @@ -2,25 +2,8 @@ import pytest -from singer_sdk.helpers._singer import ( - BaseBatchFileEncoding, - JSONLinesEncoding, - SDKBatchMessage, - SingerMessageType, -) - - -@pytest.mark.parametrize( - "encoding,expected", - [ - (JSONLinesEncoding("gzip"), {"compression": "gzip", "format": "jsonl"}), - (JSONLinesEncoding(), {"compression": None, "format": "jsonl"}), - ], - ids=["jsonl-compression-gzip", "jsonl-compression-none"], -) -def test_encoding_as_dict(encoding: BaseBatchFileEncoding, expected: dict) -> None: - """Test encoding as dict.""" - assert asdict(encoding) == expected +from singer_sdk.helpers._batch import JSONLinesEncoding, SDKBatchMessage +from singer_sdk.helpers._singer import SingerMessageType @pytest.mark.parametrize( diff --git a/tests/core/test_sqlite.py b/tests/core/test_sqlite.py index 9b8d5555ae..b6a24fd77c 100644 --- a/tests/core/test_sqlite.py +++ b/tests/core/test_sqlite.py @@ -114,16 +114,6 @@ def sqlite_sample_target_soft_delete(sqlite_target_test_config): def sqlite_sample_target_batch(sqlite_target_test_config): """Get a sample target object with hard_delete disabled.""" conf = sqlite_target_test_config - conf["batch_config"] = { - "encoding": { - "format": "jsonl", - "compression": "gzip", - }, - "storage": { - "root": "file://tests/core/resources", - "prefix": "test-batch-", - }, - } return SQLiteTarget(conf) @@ -435,7 +425,10 @@ def test_sqlite_process_batch_message( "type": "BATCH", "stream": "users", "encoding": {"format": "jsonl", "compression": "gzip"}, - "manifest": ["batch.1.jsonl.gz", "batch.2.jsonl.gz"], + "manifest": [ + "file://tests/core/resources/batch.1.jsonl.gz", + "file://tests/core/resources/batch.2.jsonl.gz", + ], } tap_output = "\n".join([json.dumps(schema_message), json.dumps(batch_message)])