-
Notifications
You must be signed in to change notification settings - Fork 71
/
_batch.py
226 lines (164 loc) · 5.75 KB
/
_batch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
"""Batch helpers."""
from __future__ import annotations

import enum
from contextlib import contextmanager
from dataclasses import asdict, dataclass, field
from typing import IO, TYPE_CHECKING, Any, ClassVar, Generator
from urllib.parse import ParseResult, parse_qsl, urlencode, urlparse

import fs

from singer_sdk._singerlib.messages import Message, SingerMessageType

if TYPE_CHECKING:
    from fs.base import FS
class BatchFileFormat(str, enum.Enum):
    """Enumeration of the file formats a batch can be written in.

    Members are also ``str`` instances, so they compare equal to their
    plain string values (e.g. ``BatchFileFormat.JSONL == "jsonl"``).
    """

    JSONL = "jsonl"
    """JSON Lines format."""
@dataclass
class BaseBatchFileEncoding:
    """Common base class for batch file encodings.

    Each subclass is recorded in ``registered_encodings`` under the value of
    its ``__encoding_format__`` class attribute; ``from_dict`` uses that
    registry to pick the concrete class to instantiate.
    """

    registered_encodings: ClassVar[dict[str, type[BaseBatchFileEncoding]]] = {}
    __encoding_format__: ClassVar[str] = "OVERRIDE_ME"

    # Base encoding fields
    format: str = field(init=False)
    """The format of the batch file."""

    compression: str | None = None
    """The compression of the batch file."""

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Add the newly created subclass to the encoding registry.

        Args:
            **kwargs: Keyword arguments forwarded to the parent
                ``__init_subclass__``.
        """
        super().__init_subclass__(**kwargs)
        format_name = cls.__encoding_format__
        cls.registered_encodings[format_name] = cls

    def __post_init__(self) -> None:
        """Fill ``format`` from the class-level ``__encoding_format__``."""
        self.format = self.__encoding_format__

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> BaseBatchFileEncoding:
        """Build the registered encoding described by a dictionary.

        Args:
            data: Encoding fields plus a ``"format"`` key naming a registered
                encoding. The caller's dictionary is not modified.

        Returns:
            An instance of the class registered under ``data["format"]``.

        Raises:
            KeyError: If ``"format"`` is missing or nothing is registered
                under that name.
        """
        remaining = data.copy()
        format_name = remaining.pop("format")
        encoding_cls = cls.registered_encodings[format_name]
        return encoding_cls(**remaining)
@dataclass
class JSONLinesEncoding(BaseBatchFileEncoding):
    """Batch file encoding for JSON Lines files.

    Registered under ``"jsonl"``, so ``BaseBatchFileEncoding.from_dict``
    resolves ``{"format": "jsonl", ...}`` to this class.
    """

    __encoding_format__ = "jsonl"
@dataclass
class SDKBatchMessage(Message):
    """Singer ``BATCH`` message in the Meltano SDK flavor.

    Carries the stream name, the encoding of the batch files and a manifest
    listing those files.
    """

    stream: str
    """The stream name."""

    encoding: BaseBatchFileEncoding
    """The file encoding of the batch."""

    manifest: list[str] = field(default_factory=list)
    """The manifest of files in the batch."""

    def __post_init__(self) -> None:
        """Stamp the message type and normalize ``encoding``.

        An ``encoding`` supplied as a plain ``dict`` is converted with
        ``BaseBatchFileEncoding.from_dict``; any other value is kept as-is.
        """
        self.type = SingerMessageType.BATCH
        raw_encoding = self.encoding
        if isinstance(raw_encoding, dict):
            self.encoding = BaseBatchFileEncoding.from_dict(raw_encoding)
@dataclass
class StorageTarget:
    """Storage target for batch files: a root URL plus storage parameters."""

    root: str
    """The root directory of the storage target."""

    prefix: str | None = None
    """The file prefix."""

    params: dict = field(default_factory=dict)
    """The storage parameters (encoded as the query string of ``fs_url``)."""

    def asdict(self) -> dict[str, Any]:
        """Return a dictionary representation of the storage target.

        Returns:
            A dictionary with the defined storage target fields.
        """
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> StorageTarget:
        """Create a storage target from a dictionary.

        Args:
            data: The dictionary to create the storage target from; its keys
                are passed directly as constructor arguments.

        Returns:
            The created storage target.
        """
        return cls(**data)

    @staticmethod
    def split_url(url: str) -> tuple[str, str]:
        """Split a URL into a head and tail pair.

        Args:
            url: The URL to split.

        Returns:
            A tuple of the head and tail parts of the URL, as returned by
            ``fs.path.split``.
        """
        return fs.path.split(url)

    @classmethod
    def from_url(cls, url: str) -> StorageTarget:
        """Create a storage target from a file URL.

        The query string, if any, is removed from the root and parsed into
        ``params`` so that ``fs_url`` reproduces it instead of silently
        dropping it.

        Args:
            url: The URL to create the storage target from.

        Returns:
            The created storage target.
        """
        parsed_url = urlparse(url)
        # keep_blank_values preserves options like "flag=" that urlencode
        # would otherwise be unable to round-trip.
        params = dict(parse_qsl(parsed_url.query, keep_blank_values=True))
        new_url = parsed_url._replace(query="")
        return cls(root=new_url.geturl(), params=params)

    @property
    def fs_url(self) -> ParseResult:
        """Get the storage target URL.

        Returns:
            ``root`` parsed as a URL, with ``params`` URL-encoded as its
            query string.
        """
        return urlparse(self.root)._replace(query=urlencode(self.params))

    @contextmanager
    def fs(self, **kwargs: Any) -> Generator[FS, None, None]:
        """Get a filesystem object for the storage target.

        The filesystem is opened from ``fs_url`` and is closed when the
        ``with`` block exits, including when it exits with an exception.

        Args:
            kwargs: Additional arguments to pass to ``fs.open_fs``.

        Yields:
            The filesystem object.
        """
        filesystem = fs.open_fs(self.fs_url.geturl(), **kwargs)
        try:
            yield filesystem
        finally:
            filesystem.close()

    @contextmanager
    def open(self, filename: str, mode: str = "rb") -> Generator[IO, None, None]:
        """Open a file in the storage target.

        The root is opened with ``writeable=True, create=True``. The file and
        the filesystem are both closed on exit, and the filesystem is also
        closed if opening the file fails.

        Args:
            filename: The filename to open.
            mode: The mode to open the file in.

        Yields:
            The opened file.
        """
        # NOTE(review): unlike ``fs()``, this opens ``self.root`` rather than
        # ``fs_url``, so ``self.params`` are not applied — confirm intended.
        filesystem = fs.open_fs(self.root, writeable=True, create=True)
        try:
            fo = filesystem.open(filename, mode=mode)
            try:
                yield fo
            finally:
                fo.close()
        finally:
            filesystem.close()
@dataclass
class BatchConfig:
    """Batch settings: how batch files are encoded and where they are stored."""

    encoding: BaseBatchFileEncoding
    """The encoding of the batch file."""

    storage: StorageTarget
    """The storage target of the batch file."""

    def __post_init__(self) -> None:
        """Convert ``encoding``/``storage`` given as plain dicts into objects.

        Non-dict values are left untouched.
        """
        raw_encoding, raw_storage = self.encoding, self.storage
        if isinstance(raw_encoding, dict):
            self.encoding = BaseBatchFileEncoding.from_dict(raw_encoding)
        if isinstance(raw_storage, dict):
            self.storage = StorageTarget.from_dict(raw_storage)

    def asdict(self) -> dict[str, Any]:
        """Return a dictionary representation of the batch configuration.

        Returns:
            A dictionary with the defined configuration fields, with nested
            dataclasses converted recursively by ``dataclasses.asdict``.
        """
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> BatchConfig:
        """Create a batch configuration from a dictionary.

        Args:
            data: The dictionary to create the configuration from; its keys
                are passed directly as constructor arguments.

        Returns:
            The created batch configuration.
        """
        return cls(**data)