Skip to content

Commit

Permalink
fix: Transformers that return multiple items are not properly handled
Browse files Browse the repository at this point in the history
  • Loading branch information
zprobst committed Aug 17, 2023
1 parent 0ee9ca8 commit d5477df
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 8 deletions.
10 changes: 5 additions & 5 deletions nodestream/pipeline/transformers/transformer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Any, AsyncGenerator, Generator
from typing import Any, AsyncGenerator

from ..flush import Flush
from ..step import Step
Expand All @@ -19,12 +19,12 @@ async def handle_async_record_stream(
if record is Flush:
yield record
else:
val_or_gen = await self.transform_record(record)
if isinstance(val_or_gen, Generator):
for result in val_or_gen:
val_or_gen = self.transform_record(record)
if isinstance(val_or_gen, AsyncGenerator):
async for result in val_or_gen:
yield result
else:
yield val_or_gen
yield await val_or_gen

@abstractmethod
async def transform_record(self, record: Any) -> Any:
Expand Down
8 changes: 6 additions & 2 deletions tests/unit/pipeline/transformers/test_expand_json_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,9 @@
@pytest.mark.parametrize("input,output,path", [(SIMPLE_INPUT, SIMPLE_OUTPUT, "b")])
async def test_expand_json_fields(input, output, path):
subject = ExpandJsonField.from_file_data(path)
result = await subject.transform_record(input)
assert_that(result, equal_to(output))

async def upstream():
yield input

results = [r async for r in subject.handle_async_record_stream(upstream())]
assert_that(results, equal_to([output]))
6 changes: 5 additions & 1 deletion tests/unit/pipeline/transformers/test_value_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,9 @@ async def test_value_projection_transform_record():
subject = ValueProjection(
projection=JmespathValueProvider.from_string_expression("items[*]")
)
results = [r async for r in subject.transform_record(record={"items": [1, 2, 3]})]

async def record():
yield {"items": [1, 2, 3]}

results = [r async for r in subject.handle_async_record_stream(record())]
assert_that(results, equal_to([1, 2, 3]))

0 comments on commit d5477df

Please sign in to comment.