Merge pull request #50 from nodestream-proj/lightspeed
Improve performance by buffering every generator in the pipeline
zprobst authored Aug 26, 2023
2 parents b8707db + f42fba5 commit 3d03c54
Showing 11 changed files with 108 additions and 81 deletions.
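The change replaces lock-step generator chaining (every record pulled through all steps one at a time) with one asyncio task per step, connected by buffering queues, so steps can make progress concurrently. A minimal, self-contained sketch of that pattern — illustrative names only (`stage`, `transform`), not the nodestream API:

import asyncio


async def stage(inbox: asyncio.Queue, outbox: asyncio.Queue, transform) -> None:
    # Pull from the upstream buffer, transform, and push into our own buffer.
    # A slow downstream stage no longer stalls the producer on every record.
    while True:
        item = await inbox.get()
        if item is None:  # sentinel: upstream finished
            await outbox.put(None)
            return
        await outbox.put(transform(item))


async def main() -> None:
    a, b, c = asyncio.Queue(), asyncio.Queue(), asyncio.Queue()
    tasks = [
        asyncio.create_task(stage(a, b, lambda x: x * 2)),
        asyncio.create_task(stage(b, c, lambda x: x + 1)),
    ]
    for i in range(5):
        await a.put(i)
    await a.put(None)
    while (item := await c.get()) is not None:
        print(item)  # 1, 3, 5, 7, 9
    await asyncio.gather(*tasks)


asyncio.run(main())

Note that asyncio.Queue(maxsize=0) — the default outbox_size in the StepExecutor below — is an unbounded queue, so the buffer grows as needed rather than applying backpressure.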
4 changes: 2 additions & 2 deletions nodestream/cli/operations/run_pipeline.py
@@ -1,6 +1,6 @@
-from ...pipeline import PipelineInitializationArguments
+from ...pipeline import PipelineInitializationArguments, PipelineProgressReporter
 from ...pipeline.meta import PipelineContext
-from ...project import PipelineProgressReporter, Project, RunRequest
+from ...project import Project, RunRequest
 from ..commands.nodestream_command import NodestreamCommand
 from .operation import Operation

2 changes: 2 additions & 0 deletions nodestream/pipeline/__init__.py
@@ -9,6 +9,7 @@
 from .meta import UNKNOWN_PIPELINE_NAME
 from .pipeline import Pipeline
 from .pipeline_file_loader import PipelineFileLoader, PipelineInitializationArguments
+from .progress_reporter import PipelineProgressReporter
 from .step import PassStep, Step
 from .transformers import Transformer
 from .writers import LoggerWriter, Writer
@@ -30,4 +31,5 @@
     "PassStep",
     "Flush",
     "UNKNOWN_PIPELINE_NAME",
+    "PipelineProgressReporter",
 )
91 changes: 76 additions & 15 deletions nodestream/pipeline/pipeline.py
@@ -1,12 +1,12 @@
 import asyncio
-from functools import reduce
-from typing import Any, AsyncGenerator, Iterable, List
+from typing import Any, AsyncGenerator, Iterable, List, Optional
 
 from ..schema.schema import (
     AggregatedIntrospectiveIngestionComponent,
     IntrospectiveIngestionComponent,
 )
 from .flush import Flush
 from .meta import get_context
+from .progress_reporter import PipelineProgressReporter
 from .step import Step

@@ -15,6 +15,68 @@ async def empty_async_generator():
         yield item
 
 
+async def enumerate_async(iterable):
+    count = 0
+
+    async for item in iterable:
+        yield count, item
+        count += 1
+
+
+class StepExecutor:
+    def __init__(
+        self,
+        upstream: Optional["StepExecutor"],
+        step: Step,
+        outbox_size: int = 0,
+        progress_reporter: Optional[PipelineProgressReporter] = None,
+    ) -> None:
+        self.outbox = asyncio.Queue(maxsize=outbox_size)
+        self.upstream = upstream
+        self.done = False
+        self.step = step
+        self.progress_reporter = progress_reporter
+        self.end_of_line = False
+
+    async def outbox_generator(self):
+        while not self.done or not self.outbox.empty():
+            yield await self.outbox.get()
+            self.outbox.task_done()
+
+    def start(self):
+        if self.progress_reporter:
+            self.progress_reporter.on_start_callback()
+
+    def set_end_of_line(self, progress_reporter: Optional[PipelineProgressReporter]):
+        self.progress_reporter = progress_reporter
+        self.end_of_line = True
+
+    async def stop(self):
+        self.done = True
+        await self.step.finish()
+
+        if self.progress_reporter:
+            self.progress_reporter.on_finish_callback(get_context())
+
+    async def work_body(self):
+        if self.upstream is None:
+            upstream = empty_async_generator()
+        else:
+            upstream = self.upstream.outbox_generator()
+
+        results = self.step.handle_async_record_stream(upstream)
+        async for index, record in enumerate_async(results):
+            if not self.end_of_line:
+                await self.outbox.put(record)
+            if self.progress_reporter:
+                self.progress_reporter.report(index, record)
+
+    async def work_loop(self):
+        self.start()
+        await self.work_body()
+        await self.stop()
+
+
 class Pipeline(AggregatedIntrospectiveIngestionComponent):
     """A pipeline is a series of steps that are executed in order."""

Expand All @@ -23,20 +85,19 @@ class Pipeline(AggregatedIntrospectiveIngestionComponent):
def __init__(self, steps: List[Step]) -> None:
self.steps = steps

async def run(self) -> AsyncGenerator[Any, Any]:
record_stream_over_all_steps = reduce(
lambda stream, step: step.handle_async_record_stream(stream),
self.steps,
empty_async_generator(),
)
async for record in record_stream_over_all_steps:
if record is not Flush:
yield record
async def run(
self, progress_reporter: Optional[PipelineProgressReporter] = None
) -> AsyncGenerator[Any, Any]:
current_executor = None
tasks = []

for step in self.steps:
current_executor = StepExecutor(current_executor, step)
tasks.append(asyncio.create_task(current_executor.work_loop()))

await self.finish()
current_executor.set_end_of_line(progress_reporter)

async def finish(self):
await asyncio.gather(*(step.finish() for step in self.steps))
await asyncio.gather(*tasks)

def all_subordinate_components(self) -> Iterable[IntrospectiveIngestionComponent]:
return (s for s in self.steps if isinstance(s, IntrospectiveIngestionComponent))
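With run() now driving the executors itself, callers await the pipeline to completion instead of iterating it; the updated call sites below (run_request.py and the tests) all take this shape. A condensed usage sketch mirroring the updated unit test:

import asyncio

from nodestream.pipeline import (
    IterableExtractor,
    Pipeline,
    PipelineProgressReporter,
)


async def main() -> None:
    pipeline = Pipeline([IterableExtractor(range(100))])
    reporter = PipelineProgressReporter(
        reporting_frequency=10,
        callback=lambda index, record: print(index),
    )
    # run() schedules one task per step and awaits them all; records flow
    # between steps through each StepExecutor's outbox queue.
    await pipeline.run(reporter)


asyncio.run(main())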
nodestream/project/pipeline_progress_reporter.py → nodestream/pipeline/progress_reporter.py
@@ -4,8 +4,7 @@
 from logging import Logger, getLogger
 from typing import Any, Callable
 
-from ..pipeline import Pipeline
-from ..pipeline.meta import PipelineContext, get_context
+from .meta import PipelineContext
 
 
 def no_op(*_, **__):
@@ -26,23 +25,6 @@ def get_max_mem_mb():
     return int(max_mem)
 
 
-async def enumerate_async(iterable):
-    """Asynchronously enumerate the given iterable.
-
-    This function is similar to the built-in `enumerate` function, but it is
-    asynchronous. It will yield a tuple of the index and the item from the
-    iterable.
-
-    Args:
-        iterable: The iterable to enumerate.
-    """
-    count = 0
-
-    async for item in iterable:
-        yield count, item
-        count += 1
-
-
 @dataclass
 class PipelineProgressReporter:
     """A `PipelineProgressReporter` is a utility that can be used to report on the progress of a pipeline."""
@@ -74,28 +56,10 @@ def for_testing(cls, results_list: list) -> "PipelineProgressReporter":
             on_finish_callback=no_op,
         )
 
-    async def execute_with_reporting(self, pipeline: Pipeline):
-        """Execute the given pipeline with progress reporting.
-
-        This method will execute the given pipeline asynchronously, reporting on the
-        progress of the pipeline every `reporting_frequency` records. The progress
-        reporting will be done via the `report` method.
-
-        Before and after the pipeline is executed, the `on_start_callback` and
-        `on_finish_callback` will be called, respectively. The `on_finish_callback`
-        will be passed the `PipelineContext` that was used to execute the pipeline.
-
-        Args:
-            pipeline: The pipeline to execute.
-        """
-        self.on_start_callback()
-        async for index, record in enumerate_async(pipeline.run()):
-            if index % self.reporting_frequency == 0:
-                self.report(index, record)
-        self.on_finish_callback(get_context())
-
     def report(self, index, record):
-        self.logger.info(
-            "Records Processed", extra={"index": index, "max_memory": get_max_mem_mb()}
-        )
-        self.callback(index, record)
+        if index % self.reporting_frequency == 0:
+            self.logger.info(
+                "Records Processed",
+                extra={"index": index, "max_memory": get_max_mem_mb()},
+            )
+            self.callback(index, record)
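The reporting_frequency gate moves from the old execute_with_reporting loop into report() itself, because the StepExecutor now calls report() on every record; net behavior is unchanged — the callback still fires only on every Nth record. A quick standalone check of that arithmetic (not nodestream code):

# With reporting_frequency=10, indices 0, 10, ..., 90 trigger the callback:
calls = sum(1 for index in range(100) if index % 10 == 0)
assert calls == 10  # matches the updated expectation in the unit test below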
2 changes: 0 additions & 2 deletions nodestream/project/__init__.py
@@ -1,12 +1,10 @@
 from .pipeline_definition import PipelineDefinition
-from .pipeline_progress_reporter import PipelineProgressReporter
 from .pipeline_scope import PipelineScope
 from .project import Project, ProjectPlugin
 from .run_request import RunRequest
 
 __all__ = (
     "PipelineDefinition",
-    "PipelineProgressReporter",
     "PipelineScope",
     "Project",
     "ProjectPlugin",
4 changes: 2 additions & 2 deletions nodestream/project/run_request.py
@@ -2,8 +2,8 @@
 
 from ..pipeline import PipelineInitializationArguments
 from ..pipeline.meta import start_context
+from ..pipeline.progress_reporter import PipelineProgressReporter
 from .pipeline_definition import PipelineDefinition
-from .pipeline_progress_reporter import PipelineProgressReporter
 
 
 @dataclass
@@ -54,4 +54,4 @@ async def execute_with_definition(self, definition: PipelineDefinition):
         """
         with start_context(self.pipeline_name):
             pipeline = definition.initialize(self.initialization_arguments)
-            await self.progress_reporter.execute_with_reporting(pipeline)
+            await pipeline.run(self.progress_reporter)
12 changes: 10 additions & 2 deletions tests/integration/test_pipeline_and_data_interpretation.py
@@ -6,7 +6,10 @@
 from pandas import Timestamp
 
 from nodestream.model import DesiredIngestion
-from nodestream.pipeline import PipelineInitializationArguments
+from nodestream.pipeline import (
+    PipelineInitializationArguments,
+    PipelineProgressReporter,
+)
 from nodestream.project import PipelineDefinition
 from nodestream.schema.printers import SchemaPrinter

@@ -26,9 +29,14 @@ def get_pipeline_fixture_file_by_name(name: str) -> Path:
 @pytest.fixture
 def drive_definition_to_completion():
     async def _drive_definition_to_completion(definition, **init_kwargs):
+        results = []
         init_args = PipelineInitializationArguments(**init_kwargs)
         pipeline = definition.initialize(init_args)
-        return [r async for r in pipeline.run() if isinstance(r, DesiredIngestion)]
+        reporter = PipelineProgressReporter(
+            reporting_frequency=1, callback=lambda _, record: results.append(record)
+        )
+        await pipeline.run(reporter)
+        return [r for r in results if isinstance(r, DesiredIngestion)]
 
     return _drive_definition_to_completion

2 changes: 1 addition & 1 deletion tests/unit/pipeline/test_pipeline.py
@@ -15,7 +15,7 @@ def pipeline(mocker):
 
 @pytest.mark.asyncio
 async def test_pipeline_run(pipeline):
-    _ = [r async for r in pipeline.run()]
+    await pipeline.run()
     for step in pipeline.steps:
         step.handle_async_record_stream.assert_called_once()
         step.finish.assert_awaited_once()
@@ -1,15 +1,14 @@
 import pytest
 from hamcrest import assert_that, equal_to
 
-from nodestream.pipeline import IterableExtractor, Pipeline
-from nodestream.project import PipelineProgressReporter
+from nodestream.pipeline import IterableExtractor, Pipeline, PipelineProgressReporter
 
 
 @pytest.mark.asyncio
 async def test_pipeline_progress_reporter_calls_with_reporting_frequency(mocker):
     pipeline = Pipeline([IterableExtractor(range(100))])
     reporter = PipelineProgressReporter(reporting_frequency=10, callback=mocker.Mock())
-    await reporter.execute_with_reporting(pipeline)
+    await pipeline.run(reporter)
     assert_that(reporter.callback.call_count, equal_to(10))


8 changes: 3 additions & 5 deletions tests/unit/project/test_pipeline_scope.py
@@ -3,13 +3,11 @@
 import pytest
 from hamcrest import assert_that, has_key, is_, not_, same_instance
 
-from nodestream.pipeline import PipelineInitializationArguments
-from nodestream.project import (
-    PipelineDefinition,
+from nodestream.pipeline import (
+    PipelineInitializationArguments,
     PipelineProgressReporter,
-    PipelineScope,
-    RunRequest,
 )
+from nodestream.project import PipelineDefinition, PipelineScope, RunRequest
 from nodestream.project.pipeline_scope import MissingExpectedPipelineError


9 changes: 3 additions & 6 deletions tests/unit/project/test_project.py
@@ -3,14 +3,11 @@
 import pytest
 from hamcrest import assert_that, equal_to, has_length, same_instance
 
-from nodestream.pipeline import PipelineInitializationArguments
-from nodestream.project import (
-    PipelineDefinition,
+from nodestream.pipeline import (
+    PipelineInitializationArguments,
     PipelineProgressReporter,
-    PipelineScope,
-    Project,
-    RunRequest,
 )
+from nodestream.project import PipelineDefinition, PipelineScope, Project, RunRequest
 from nodestream.schema.schema import GraphSchema


