diff --git a/nodestream/pipeline/extractors/files.py b/nodestream/pipeline/extractors/files.py index a1a98d5c1..b25bae28c 100644 --- a/nodestream/pipeline/extractors/files.py +++ b/nodestream/pipeline/extractors/files.py @@ -60,6 +60,11 @@ def read_file_from_handle(self, fp: StringIO) -> Iterable[JsonLikeDocument]: return [json.load(fp)] +class LineSeperatedJsonFileFormat(SupportedFileFormat, alias=".jsonl"): + def read_file_from_handle(self, fp: StringIO) -> Iterable[JsonLikeDocument]: + return (json.loads(line) for line in fp) + + class TextFileFormat(SupportedFileFormat, alias=".txt"): def read_file_from_handle(self, fp: StringIO) -> Iterable[JsonLikeDocument]: return ({"line": line} for line in fp) diff --git a/tests/unit/pipeline/extractors/test_files.py b/tests/unit/pipeline/extractors/test_files.py index 4c041c510..156aa28f6 100644 --- a/tests/unit/pipeline/extractors/test_files.py +++ b/tests/unit/pipeline/extractors/test_files.py @@ -26,6 +26,16 @@ def json_file(fixture_directory): yield Path(temp_file.name) +@pytest.fixture +def jsonl_file(fixture_directory): + with NamedTemporaryFile("w+", suffix=".jsonl", dir=fixture_directory) as temp_file: + json.dump(SIMPLE_RECORD, temp_file) + temp_file.write("\n") + json.dump(SIMPLE_RECORD, temp_file) + temp_file.seek(0) + yield Path(temp_file.name) + + @pytest.fixture def csv_file(fixture_directory): with NamedTemporaryFile("w+", suffix=".csv", dir=fixture_directory) as temp_file: @@ -65,9 +75,16 @@ async def test_txt_formatting(txt_file): assert_that(results, equal_to([{"line": "hello world"}])) -def test_declarative_init(fixture_directory, csv_file, json_file, txt_file): +@pytest.mark.asyncio +async def test_jsonl_formatting(jsonl_file): + subject = FileExtractor([jsonl_file]) + results = [r async for r in subject.extract_records()] + assert_that(results, equal_to([SIMPLE_RECORD, SIMPLE_RECORD])) + + +def test_declarative_init(fixture_directory, csv_file, json_file, txt_file, jsonl_file): subject = FileExtractor.from_file_data(globs=[f"{fixture_directory}/**"]) - assert_that(list(subject.paths), has_length(3)) + assert_that(list(subject.paths), has_length(4)) @pytest.mark.asyncio