
Commit

refactor: break out default batch file writer into separate class (#1668)
Ken Payne authored May 5, 2023
1 parent 2974649 commit e029e30
Showing 6 changed files with 153 additions and 52 deletions.
8 changes: 8 additions & 0 deletions docs/classes/singer_sdk.batch.BaseBatcher.rst
@@ -0,0 +1,8 @@
singer_sdk.batch.BaseBatcher
============================

.. currentmodule:: singer_sdk.batch

.. autoclass:: BaseBatcher
    :members:
    :special-members: __init__, __call__
8 changes: 8 additions & 0 deletions docs/classes/singer_sdk.batch.JSONLinesBatcher.rst
@@ -0,0 +1,8 @@
singer_sdk.batch.JSONLinesBatcher
=================================

.. currentmodule:: singer_sdk.batch

.. autoclass:: JSONLinesBatcher
    :members:
    :special-members: __init__, __call__
10 changes: 10 additions & 0 deletions docs/reference.rst
@@ -130,3 +130,13 @@ Pagination
    pagination.BaseOffsetPaginator
    pagination.LegacyPaginatedStreamProtocol
    pagination.LegacyStreamPaginator

Batch
-----

.. autosummary::
    :toctree: classes
    :template: class.rst

    batch.BaseBatcher
    batch.JSONLinesBatcher
110 changes: 110 additions & 0 deletions singer_sdk/batch.py
@@ -0,0 +1,110 @@
"""Batching utilities for Singer SDK."""
from __future__ import annotations

import gzip
import itertools
import json
import typing as t
from abc import ABC, abstractmethod
from uuid import uuid4

if t.TYPE_CHECKING:
from singer_sdk.helpers._batch import BatchConfig

_T = t.TypeVar("_T")


def lazy_chunked_generator(
iterable: t.Iterable[_T],
chunk_size: int,
) -> t.Generator[t.Iterator[_T], None, None]:
"""Yield a generator for each chunk of the given iterable.
Args:
iterable: The iterable to chunk.
chunk_size: The size of each chunk.
Yields:
A generator for each chunk of the given iterable.
"""
iterator = iter(iterable)
while True:
chunk = list(itertools.islice(iterator, chunk_size))
if not chunk:
break
yield iter(chunk)


class BaseBatcher(ABC):
"""Base Record Batcher."""

def __init__(
self,
tap_name: str,
stream_name: str,
batch_config: BatchConfig,
) -> None:
"""Initialize the batcher.
Args:
tap_name: The name of the tap.
stream_name: The name of the stream.
batch_config: The batch configuration.
"""
self.tap_name = tap_name
self.stream_name = stream_name
self.batch_config = batch_config

@abstractmethod
def get_batches(
self,
records: t.Iterator[dict],
) -> t.Iterator[list[str]]:
"""Yield manifest of batches.
Args:
records: The records to batch.
Raises:
NotImplementedError: If the method is not implemented.
"""
raise NotImplementedError


class JSONLinesBatcher(BaseBatcher):
"""JSON Lines Record Batcher."""

def get_batches(
self,
records: t.Iterator[dict],
) -> t.Iterator[list[str]]:
"""Yield manifest of batches.
Args:
records: The records to batch.
Yields:
A list of file paths (called a manifest).
"""
sync_id = f"{self.tap_name}--{self.stream_name}-{uuid4()}"
prefix = self.batch_config.storage.prefix or ""

for i, chunk in enumerate(
lazy_chunked_generator(
records,
self.batch_config.batch_size,
),
start=1,
):
filename = f"{prefix}{sync_id}-{i}.json.gz"
with self.batch_config.storage.fs() as fs:
# TODO: Determine compression from config.
with fs.open(filename, "wb") as f, gzip.GzipFile(
fileobj=f,
mode="wb",
) as gz:
gz.writelines(
(json.dumps(record) + "\n").encode() for record in chunk
)
file_url = fs.geturl(filename)
yield [file_url]
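
For context, a minimal sketch of driving the new JSONLinesBatcher directly; the config dict shapes, the file:// storage root, and the sample records are illustrative assumptions rather than part of this diff:

from singer_sdk.batch import JSONLinesBatcher
from singer_sdk.helpers._batch import BatchConfig

# Assumed config shapes; __post_init__ converts the dicts into
# BaseBatchFileEncoding / StorageTarget instances. The storage root
# directory must already exist for this sketch to run.
config = BatchConfig(
    encoding={"format": "jsonl", "compression": "gzip"},
    storage={"root": "file:///tmp/batches", "prefix": "demo-"},
    batch_size=2,  # small on purpose: three records -> two batch files
)

batcher = JSONLinesBatcher(
    tap_name="tap-demo",
    stream_name="users",
    batch_config=config,
)

records = iter([{"id": 1}, {"id": 2}, {"id": 3}])
for manifest in batcher.get_batches(records):
    print(manifest)  # e.g. ["file:///tmp/batches/demo-tap-demo--users-<uuid>-1.json.gz"]
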
8 changes: 8 additions & 0 deletions singer_sdk/helpers/_batch.py
@@ -16,6 +16,8 @@
if t.TYPE_CHECKING:
    from fs.base import FS

DEFAULT_BATCH_SIZE = 10000


class BatchFileFormat(str, enum.Enum):
    """Batch file format."""
@@ -209,13 +211,19 @@ class BatchConfig:
    storage: StorageTarget
    """The storage target of the batch file."""

    batch_size: int = DEFAULT_BATCH_SIZE
    """The max number of records in a batch."""

    def __post_init__(self):
        if isinstance(self.encoding, dict):
            self.encoding = BaseBatchFileEncoding.from_dict(self.encoding)

        if isinstance(self.storage, dict):
            self.storage = StorageTarget.from_dict(self.storage)

        if self.batch_size is None:
            self.batch_size = DEFAULT_BATCH_SIZE

    def asdict(self):
        """Return a dictionary representation of the message.
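
A small sketch of the fallback added in __post_init__ above: passing batch_size=None (e.g. when the key is absent from a parsed settings dict) resolves to DEFAULT_BATCH_SIZE. The encoding and storage key names are assumptions, not shown in this hunk:

from singer_sdk.helpers._batch import DEFAULT_BATCH_SIZE, BatchConfig

config = BatchConfig(
    encoding={"format": "jsonl", "compression": "gzip"},  # assumed key names
    storage={"root": "file:///tmp/batches", "prefix": "demo-"},  # assumed key names
    batch_size=None,  # e.g. omitted from a tap's settings
)
assert config.batch_size == DEFAULT_BATCH_SIZE  # 10000
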
61 changes: 9 additions & 52 deletions singer_sdk/streams/core.py
@@ -5,19 +5,17 @@
import abc
import copy
import datetime
import gzip
import itertools
import json
import typing as t
from os import PathLike
from pathlib import Path
from types import MappingProxyType
from uuid import uuid4

import pendulum

import singer_sdk._singerlib as singer
from singer_sdk import metrics
from singer_sdk.batch import JSONLinesBatcher
from singer_sdk.exceptions import (
    AbortedSyncFailedException,
    AbortedSyncPausedException,
@@ -63,28 +61,6 @@
REPLICATION_LOG_BASED = "LOG_BASED"

FactoryType = t.TypeVar("FactoryType", bound="Stream")
_T = t.TypeVar("_T")


def lazy_chunked_generator(
    iterable: t.Iterable[_T],
    chunk_size: int,
) -> t.Generator[t.Iterator[_T], None, None]:
    """Yield a generator for each chunk of the given iterable.

    Args:
        iterable: The iterable to chunk.
        chunk_size: The size of each chunk.

    Yields:
        A generator for each chunk of the given iterable.
    """
    iterator = iter(iterable)
    while True:
        chunk = list(itertools.islice(iterator, chunk_size))
        if not chunk:
            break
        yield iter(chunk)


class Stream(metaclass=abc.ABCMeta):
@@ -124,10 +100,6 @@ class Stream(metaclass=abc.ABCMeta):
    # Internal API cost aggregator
    _sync_costs: dict[str, int] = {}

    # Batch attributes
    batch_size: int = 1000
    """Max number of records to write to each batch file."""

    def __init__(
        self,
        tap: Tap,
@@ -1341,29 +1313,14 @@ def get_batches(
        Yields:
            A tuple of (encoding, manifest) for each batch.
        """
        sync_id = f"{self.tap_name}--{self.name}-{uuid4()}"
        prefix = batch_config.storage.prefix or ""

        for i, chunk in enumerate(
            lazy_chunked_generator(
                self._sync_records(context, write_messages=False),
                self.batch_size,
            ),
            start=1,
        ):
            filename = f"{prefix}{sync_id}-{i}.json.gz"
            with batch_config.storage.fs() as fs:
                # TODO: Determine compression from config.
                with fs.open(filename, "wb") as f, gzip.GzipFile(
                    fileobj=f,
                    mode="wb",
                ) as gz:
                    gz.writelines(
                        (json.dumps(record) + "\n").encode() for record in chunk
                    )
                file_url = fs.geturl(filename)

            yield batch_config.encoding, [file_url]
        batcher = JSONLinesBatcher(
            tap_name=self.tap_name,
            stream_name=self.name,
            batch_config=batch_config,
        )
        records = self._sync_records(context, write_messages=False)
        for manifest in batcher.get_batches(records=records):
            yield batch_config.encoding, manifest

    def post_process(
        self,
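
With the writer factored out of the stream, alternative batchers become drop-in BaseBatcher subclasses. A hypothetical sketch (not part of this commit) of an uncompressed JSON Lines variant built on the same contract:

from __future__ import annotations

import json
import typing as t
from uuid import uuid4

from singer_sdk.batch import BaseBatcher, lazy_chunked_generator


class UncompressedJSONLinesBatcher(BaseBatcher):
    """Hypothetical batcher writing plain .jsonl files (no gzip)."""

    def get_batches(
        self,
        records: t.Iterator[dict],
    ) -> t.Iterator[list[str]]:
        sync_id = f"{self.tap_name}--{self.stream_name}-{uuid4()}"
        prefix = self.batch_config.storage.prefix or ""
        for i, chunk in enumerate(
            lazy_chunked_generator(records, self.batch_config.batch_size),
            start=1,
        ):
            filename = f"{prefix}{sync_id}-{i}.jsonl"
            with self.batch_config.storage.fs() as fs:
                with fs.open(filename, "w") as f:
                    f.writelines(json.dumps(record) + "\n" for record in chunk)
                yield [fs.geturl(filename)]

A stream could then construct this batcher inside get_batches exactly as streams/core.py now constructs JSONLinesBatcher.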