diff --git a/nodestream/cli/operations/run_pipeline.py b/nodestream/cli/operations/run_pipeline.py
index f388f2386..457f9825e 100644
--- a/nodestream/cli/operations/run_pipeline.py
+++ b/nodestream/cli/operations/run_pipeline.py
@@ -1,6 +1,6 @@
-from ...pipeline import PipelineInitializationArguments
+from ...pipeline import PipelineInitializationArguments, PipelineProgressReporter
 from ...pipeline.meta import PipelineContext
-from ...project import PipelineProgressReporter, Project, RunRequest
+from ...project import Project, RunRequest
 from ..commands.nodestream_command import NodestreamCommand
 from .operation import Operation
 
diff --git a/nodestream/pipeline/__init__.py b/nodestream/pipeline/__init__.py
index bde4334d2..a2e3f3bd0 100644
--- a/nodestream/pipeline/__init__.py
+++ b/nodestream/pipeline/__init__.py
@@ -9,6 +9,7 @@ from .meta import UNKNOWN_PIPELINE_NAME
 from .pipeline import Pipeline
 from .pipeline_file_loader import PipelineFileLoader, PipelineInitializationArguments
+from .progress_reporter import PipelineProgressReporter
 from .step import PassStep, Step
 from .transformers import Transformer
 from .writers import LoggerWriter, Writer
 
@@ -30,4 +31,5 @@
     "PassStep",
     "Flush",
     "UNKNOWN_PIPELINE_NAME",
+    "PipelineProgressReporter",
 )
diff --git a/nodestream/pipeline/pipeline.py b/nodestream/pipeline/pipeline.py
index 46475f6b2..d0766fb6a 100644
--- a/nodestream/pipeline/pipeline.py
+++ b/nodestream/pipeline/pipeline.py
@@ -1,12 +1,12 @@
 import asyncio
-from functools import reduce
-from typing import Any, AsyncGenerator, Iterable, List
+from typing import Any, AsyncGenerator, Iterable, List, Optional
 
 from ..schema.schema import (
     AggregatedIntrospectiveIngestionComponent,
     IntrospectiveIngestionComponent,
 )
-from .flush import Flush
+from .meta import get_context
+from .progress_reporter import PipelineProgressReporter
 from .step import Step
 
 
@@ -15,6 +15,68 @@ async def empty_async_generator():
         yield item
 
 
+async def enumerate_async(iterable):
+    count = 0
+
+    async for item in iterable:
+        yield count, item
+        count += 1
+
+
+class StepExecutor:
+    def __init__(
+        self,
+        upstream: Optional["StepExecutor"],
+        step: Step,
+        outbox_size: int = 0,
+        progress_reporter: Optional[PipelineProgressReporter] = None,
+    ) -> None:
+        self.outbox = asyncio.Queue(maxsize=outbox_size)
+        self.upstream = upstream
+        self.done = False
+        self.step = step
+        self.progress_reporter = progress_reporter
+        self.end_of_line = False
+
+    async def outbox_generator(self):
+        while not self.done or not self.outbox.empty():
+            yield await self.outbox.get()
+            self.outbox.task_done()
+
+    def start(self):
+        if self.progress_reporter:
+            self.progress_reporter.on_start_callback()
+
+    def set_end_of_line(self, progress_reporter: Optional[PipelineProgressReporter]):
+        self.progress_reporter = progress_reporter
+        self.end_of_line = True
+
+    async def stop(self):
+        self.done = True
+        await self.step.finish()
+
+        if self.progress_reporter:
+            self.progress_reporter.on_finish_callback(get_context())
+
+    async def work_body(self):
+        if self.upstream is None:
+            upstream = empty_async_generator()
+        else:
+            upstream = self.upstream.outbox_generator()
+
+        results = self.step.handle_async_record_stream(upstream)
+        async for index, record in enumerate_async(results):
+            if not self.end_of_line:
+                await self.outbox.put(record)
+            if self.progress_reporter:
+                self.progress_reporter.report(index, record)
+
+    async def work_loop(self):
+        self.start()
+        await self.work_body()
+        await self.stop()
+
+
 class Pipeline(AggregatedIntrospectiveIngestionComponent):
     """A pipeline is a series of steps that are executed in order."""
 
@@ -23,20 +85,19 @@ class Pipeline(AggregatedIntrospectiveIngestionComponent):
     def __init__(self, steps: List[Step]) -> None:
         self.steps = steps
 
-    async def run(self) -> AsyncGenerator[Any, Any]:
-        record_stream_over_all_steps = reduce(
-            lambda stream, step: step.handle_async_record_stream(stream),
-            self.steps,
-            empty_async_generator(),
-        )
-        async for record in record_stream_over_all_steps:
-            if record is not Flush:
-                yield record
+    async def run(
+        self, progress_reporter: Optional[PipelineProgressReporter] = None
+    ) -> AsyncGenerator[Any, Any]:
+        current_executor = None
+        tasks = []
+
+        for step in self.steps:
+            current_executor = StepExecutor(current_executor, step)
+            tasks.append(asyncio.create_task(current_executor.work_loop()))
 
-        await self.finish()
+        current_executor.set_end_of_line(progress_reporter)
 
-    async def finish(self):
-        await asyncio.gather(*(step.finish() for step in self.steps))
+        await asyncio.gather(*tasks)
 
     def all_subordinate_components(self) -> Iterable[IntrospectiveIngestionComponent]:
         return (s for s in self.steps if isinstance(s, IntrospectiveIngestionComponent))
diff --git a/nodestream/project/pipeline_progress_reporter.py b/nodestream/pipeline/progress_reporter.py
similarity index 53%
rename from nodestream/project/pipeline_progress_reporter.py
rename to nodestream/pipeline/progress_reporter.py
index c6a87933f..674c5f506 100644
--- a/nodestream/project/pipeline_progress_reporter.py
+++ b/nodestream/pipeline/progress_reporter.py
@@ -4,8 +4,7 @@
 from logging import Logger, getLogger
 from typing import Any, Callable
 
-from ..pipeline import Pipeline
-from ..pipeline.meta import PipelineContext, get_context
+from .meta import PipelineContext
 
 
 def no_op(*_, **__):
@@ -26,23 +25,6 @@ def get_max_mem_mb():
     return int(max_mem)
 
 
-async def enumerate_async(iterable):
-    """Asynchronously enumerate the given iterable.
-
-    This function is similar to the built-in `enumerate` function, but it is
-    asynchronous. It will yield a tuple of the index and the item from the
-    iterable.
-
-    Args:
-        iterable: The iterable to enumerate.
-    """
-    count = 0
-
-    async for item in iterable:
-        yield count, item
-        count += 1
-
-
 @dataclass
 class PipelineProgressReporter:
     """A `PipelineProgressReporter` is a utility that can be used to report on the progress of a pipeline."""
@@ -74,28 +56,10 @@ def for_testing(cls, results_list: list) -> "PipelineProgressReporter":
             on_finish_callback=no_op,
         )
 
-    async def execute_with_reporting(self, pipeline: Pipeline):
-        """Execute the given pipeline with progress reporting.
-
-        This method will execute the given pipeline asynchronously, reporting on the
-        progress of the pipeline every `reporting_frequency` records. The progress
-        reporting will be done via the `report` method.
-
-        Before and after the pipeline is executed, the `on_start_callback` and
-        `on_finish_callback` will be called, respectively. The `on_finish_callback`
-        will be passed the `PipelineContext` that was used to execute the pipeline.
-
-        Args:
-            pipeline: The pipeline to execute.
-        """
-        self.on_start_callback()
-        async for index, record in enumerate_async(pipeline.run()):
-            if index % self.reporting_frequency == 0:
-                self.report(index, record)
-        self.on_finish_callback(get_context())
-
     def report(self, index, record):
-        self.logger.info(
-            "Records Processed", extra={"index": index, "max_memory": get_max_mem_mb()}
-        )
-        self.callback(index, record)
+        if index % self.reporting_frequency == 0:
+            self.logger.info(
+                "Records Processed",
+                extra={"index": index, "max_memory": get_max_mem_mb()},
+            )
+            self.callback(index, record)
diff --git a/nodestream/project/__init__.py b/nodestream/project/__init__.py
index 28985efc2..fb8234943 100644
--- a/nodestream/project/__init__.py
+++ b/nodestream/project/__init__.py
@@ -1,12 +1,10 @@
 from .pipeline_definition import PipelineDefinition
-from .pipeline_progress_reporter import PipelineProgressReporter
 from .pipeline_scope import PipelineScope
 from .project import Project, ProjectPlugin
 from .run_request import RunRequest
 
 __all__ = (
     "PipelineDefinition",
-    "PipelineProgressReporter",
     "PipelineScope",
     "Project",
     "ProjectPlugin",
diff --git a/nodestream/project/run_request.py b/nodestream/project/run_request.py
index fb5c00cb4..d840a79ae 100644
--- a/nodestream/project/run_request.py
+++ b/nodestream/project/run_request.py
@@ -2,8 +2,8 @@
 
 from ..pipeline import PipelineInitializationArguments
 from ..pipeline.meta import start_context
+from ..pipeline.progress_reporter import PipelineProgressReporter
 from .pipeline_definition import PipelineDefinition
-from .pipeline_progress_reporter import PipelineProgressReporter
 
 
 @dataclass
@@ -54,4 +54,4 @@ async def execute_with_definition(self, definition: PipelineDefinition):
         """
         with start_context(self.pipeline_name):
            pipeline = definition.initialize(self.initialization_arguments)
-            await self.progress_reporter.execute_with_reporting(pipeline)
+            await pipeline.run(self.progress_reporter)
diff --git a/tests/integration/test_pipeline_and_data_interpretation.py b/tests/integration/test_pipeline_and_data_interpretation.py
index 45114b96d..30a305d0c 100644
--- a/tests/integration/test_pipeline_and_data_interpretation.py
+++ b/tests/integration/test_pipeline_and_data_interpretation.py
@@ -6,7 +6,10 @@
 from pandas import Timestamp
 
 from nodestream.model import DesiredIngestion
-from nodestream.pipeline import PipelineInitializationArguments
+from nodestream.pipeline import (
+    PipelineInitializationArguments,
+    PipelineProgressReporter,
+)
 from nodestream.project import PipelineDefinition
 from nodestream.schema.printers import SchemaPrinter
 
@@ -26,9 +29,14 @@ def get_pipeline_fixture_file_by_name(name: str) -> Path:
 
 @pytest.fixture
 def drive_definition_to_completion():
     async def _drive_definition_to_completion(definition, **init_kwargs):
+        results = []
         init_args = PipelineInitializationArguments(**init_kwargs)
         pipeline = definition.initialize(init_args)
-        return [r async for r in pipeline.run() if isinstance(r, DesiredIngestion)]
+        reporter = PipelineProgressReporter(
+            reporting_frequency=1, callback=lambda _, record: results.append(record)
+        )
+        await pipeline.run(reporter)
+        return [r for r in results if isinstance(r, DesiredIngestion)]
 
     return _drive_definition_to_completion
diff --git a/tests/unit/pipeline/test_pipeline.py b/tests/unit/pipeline/test_pipeline.py
index 96ae73e95..ab1469592 100644
--- a/tests/unit/pipeline/test_pipeline.py
+++ b/tests/unit/pipeline/test_pipeline.py
@@ -15,7 +15,7 @@ def pipeline(mocker):
 
 @pytest.mark.asyncio
 async def test_pipeline_run(pipeline):
-    _ = [r async for r in pipeline.run()]
+    await pipeline.run()
     for step in pipeline.steps:
         step.handle_async_record_stream.assert_called_once()
         step.finish.assert_awaited_once()
diff --git a/tests/unit/project/test_pipeline_progress_reporter.py b/tests/unit/pipeline/test_pipeline_progress_reporter.py
similarity index 78%
rename from tests/unit/project/test_pipeline_progress_reporter.py
rename to tests/unit/pipeline/test_pipeline_progress_reporter.py
index 08734b3bd..364c3c9c6 100644
--- a/tests/unit/project/test_pipeline_progress_reporter.py
+++ b/tests/unit/pipeline/test_pipeline_progress_reporter.py
@@ -1,15 +1,14 @@
 import pytest
 from hamcrest import assert_that, equal_to
 
-from nodestream.pipeline import IterableExtractor, Pipeline
-from nodestream.project import PipelineProgressReporter
+from nodestream.pipeline import IterableExtractor, Pipeline, PipelineProgressReporter
 
 
 @pytest.mark.asyncio
 async def test_pipeline_progress_reporter_calls_with_reporting_frequency(mocker):
     pipeline = Pipeline([IterableExtractor(range(100))])
     reporter = PipelineProgressReporter(reporting_frequency=10, callback=mocker.Mock())
-    await reporter.execute_with_reporting(pipeline)
+    await pipeline.run(reporter)
 
     assert_that(reporter.callback.call_count, equal_to(10))
 
diff --git a/tests/unit/project/test_pipeline_scope.py b/tests/unit/project/test_pipeline_scope.py
index 9903e8a8c..2e61ecbe0 100644
--- a/tests/unit/project/test_pipeline_scope.py
+++ b/tests/unit/project/test_pipeline_scope.py
@@ -3,13 +3,11 @@
 import pytest
 from hamcrest import assert_that, has_key, is_, not_, same_instance
 
-from nodestream.pipeline import PipelineInitializationArguments
-from nodestream.project import (
-    PipelineDefinition,
+from nodestream.pipeline import (
+    PipelineInitializationArguments,
     PipelineProgressReporter,
-    PipelineScope,
-    RunRequest,
 )
+from nodestream.project import PipelineDefinition, PipelineScope, RunRequest
 from nodestream.project.pipeline_scope import MissingExpectedPipelineError
 
 
diff --git a/tests/unit/project/test_project.py b/tests/unit/project/test_project.py
index 6350bda3c..821cef260 100644
--- a/tests/unit/project/test_project.py
+++ b/tests/unit/project/test_project.py
@@ -3,14 +3,11 @@
 import pytest
 from hamcrest import assert_that, equal_to, has_length, same_instance
 
-from nodestream.pipeline import PipelineInitializationArguments
-from nodestream.project import (
-    PipelineDefinition,
+from nodestream.pipeline import (
+    PipelineInitializationArguments,
     PipelineProgressReporter,
-    PipelineScope,
-    Project,
-    RunRequest,
 )
+from nodestream.project import PipelineDefinition, PipelineScope, Project, RunRequest
 from nodestream.schema.schema import GraphSchema
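
A minimal usage sketch of the reworked reporting API after this change, not part of the change set itself. The pipeline construction, reporter arguments, and the `await pipeline.run(reporter)` call mirror the updated unit test above; the `main` wrapper and `asyncio.run` invocation are illustrative assumptions.

```python
import asyncio

from nodestream.pipeline import IterableExtractor, Pipeline, PipelineProgressReporter


async def main():
    results = []

    # A trivial single-step pipeline, as in the updated unit test.
    pipeline = Pipeline([IterableExtractor(range(100))])

    # With the revised report(), the callback fires only on every
    # reporting_frequency-th record: 100 records at frequency 10 -> 10 calls.
    reporter = PipelineProgressReporter(
        reporting_frequency=10,
        callback=lambda index, record: results.append(record),
    )

    # The reporter is now passed to Pipeline.run directly, replacing the
    # removed PipelineProgressReporter.execute_with_reporting(pipeline).
    await pipeline.run(reporter)


asyncio.run(main())
```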