diff --git a/singer_sdk/io_base.py b/singer_sdk/io_base.py
index ef336ac18..07da6e63e 100644
--- a/singer_sdk/io_base.py
+++ b/singer_sdk/io_base.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import abc
+import decimal
 import json
 import logging
 import sys
@@ -49,6 +50,27 @@ def _assert_line_requires(line_dict: dict, requires: set[str]) -> None:
             msg = f"Line is missing required {', '.join(missing)} key(s): {line_dict}"
             raise Exception(msg)
 
+    def deserialize_json(self, line: str) -> dict:
+        """Deserialize a line of json.
+
+        Args:
+            line: A single line of json.
+
+        Returns:
+            A dictionary of the deserialized json, with floats as decimal.Decimal.
+
+        Raises:
+            json.decoder.JSONDecodeError: raised if the line is not valid json
+        """
+        try:
+            return json.loads(  # type: ignore[no-any-return]
+                line,
+                parse_float=decimal.Decimal,
+            )
+        except json.decoder.JSONDecodeError as exc:
+            logger.error("Unable to parse:\n%s", line, exc_info=exc)
+            raise
+
     def _process_lines(self, file_input: t.IO[str]) -> t.Counter[str]:
         """Internal method to process jsonl lines from a Singer tap.
 
@@ -57,18 +79,10 @@ def _process_lines(self, file_input: t.IO[str]) -> t.Counter[str]:
 
         Returns:
             A counter object for the processed lines.
-
-        Raises:
-            json.decoder.JSONDecodeError: raised if any lines are not valid json
         """
         stats: dict[str, int] = defaultdict(int)
         for line in file_input:
-            try:
-                line_dict = json.loads(line)
-            except json.decoder.JSONDecodeError as exc:
-                logger.error("Unable to parse:\n%s", line, exc_info=exc)
-                raise
-
+            line_dict = self.deserialize_json(line)
             self._assert_line_requires(line_dict, requires={"type"})
 
             record_type: SingerMessageType = line_dict["type"]
diff --git a/tests/core/test_io.py b/tests/core/test_io.py
new file mode 100644
index 000000000..c8de02447
--- /dev/null
+++ b/tests/core/test_io.py
@@ -0,0 +1,55 @@
+"""Test IO operations."""
+
+from __future__ import annotations
+
+import decimal
+import json
+from contextlib import nullcontext
+
+import pytest
+
+from singer_sdk.io_base import SingerReader
+
+
+class DummyReader(SingerReader):
+    def _process_activate_version_message(self, message_dict: dict) -> None:
+        pass
+
+    def _process_batch_message(self, message_dict: dict) -> None:
+        pass
+
+    def _process_record_message(self, message_dict: dict) -> None:
+        pass
+
+    def _process_schema_message(self, message_dict: dict) -> None:
+        pass
+
+    def _process_state_message(self, message_dict: dict) -> None:
+        pass
+
+
+@pytest.mark.parametrize(
+    "line,expected,exception",
+    [
+        pytest.param(
+            "not-valid-json",
+            None,
+            pytest.raises(json.decoder.JSONDecodeError),
+            id="unparsable",
+        ),
+        pytest.param(
+            '{"type": "RECORD", "stream": "users", "record": {"id": 1, "value": 1.23}}',  # noqa: E501
+            {
+                "type": "RECORD",
+                "stream": "users",
+                "record": {"id": 1, "value": decimal.Decimal("1.23")},
+            },
+            nullcontext(),
+            id="record",
+        ),
+    ],
+)
+def test_deserialize(line, expected, exception):
+    reader = DummyReader()
+    with exception:
+        assert reader.deserialize_json(line) == expected