From 9121c64f9563046e79bc594e1d121944f3de3779 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Tue, 3 Mar 2020 23:55:10 -0500 Subject: [PATCH] Proper kwargs forwarding on run method (#30) * 30min test * fix kwargs forwarding to f and on_done issue * v 0.2.5 --- examples/process_error.py | 2 +- pypeln/__init__.py | 2 +- pypeln/process/stage.py | 20 +++++++++----------- pypeln/task/stage.py | 31 +++++++++++------------------- pypeln/thread/stage.py | 24 +++++++++-------------- pyproject.toml | 2 +- tests/test_process.py | 23 ++++++++++++++++++++++ tests/test_task.py | 40 +++++++++++++++++++++++++++++++++++++++ tests/test_thread.py | 23 ++++++++++++++++++++++ 9 files changed, 118 insertions(+), 49 deletions(-) diff --git a/examples/process_error.py b/examples/process_error.py index 9e3f0a7..48e81a6 100644 --- a/examples/process_error.py +++ b/examples/process_error.py @@ -2,7 +2,7 @@ from tqdm import tqdm import time -total = 10_000 +total = 300_000 def f(x): diff --git a/pypeln/__init__.py b/pypeln/__init__.py index b7c1e29..3d8d526 100644 --- a/pypeln/__init__.py +++ b/pypeln/__init__.py @@ -151,5 +151,5 @@ async def slow_gt3(x): from . import task __all__ = ["process", "thread", "task"] -__version__ = "0.2.4" +__version__ = "0.2.5" diff --git a/pypeln/process/stage.py b/pypeln/process/stage.py index 92b7294..87c4232 100644 --- a/pypeln/process/stage.py +++ b/pypeln/process/stage.py @@ -53,29 +53,27 @@ def run(self, index): if "worker_info" in inspect.getfullargspec(self.on_start).args: on_start_kwargs["worker_info"] = utils.WorkerInfo(index=index) - f_kwargs = self.on_start(**on_start_kwargs) + kwargs = self.on_start(**on_start_kwargs) else: - f_kwargs = {} + kwargs = {} - if f_kwargs is None: - f_kwargs = {} + if kwargs is None: + kwargs = {} - self.process(**f_kwargs) - - self.output_queues.done() + self.process(**kwargs) if self.on_done is not None: with self.stage_lock: self.stage_namespace.active_workers -= 1 - on_done_kwargs = {} - if "stage_status" in inspect.getfullargspec(self.on_done).args: - on_done_kwargs["stage_status"] = utils.StageStatus( + kwargs["stage_status"] = utils.StageStatus( namespace=self.stage_namespace, lock=self.stage_lock ) - self.on_done(**on_done_kwargs) + self.on_done(**kwargs) + + self.output_queues.done() except BaseException as e: try: diff --git a/pypeln/task/stage.py b/pypeln/task/stage.py index 4ba2476..e31b408 100644 --- a/pypeln/task/stage.py +++ b/pypeln/task/stage.py @@ -50,39 +50,30 @@ async def run(self): if "worker_info" in inspect.getfullargspec(self.on_start).args: on_start_kwargs["worker_info"] = utils.WorkerInfo(index=0) - f_kwargs = self.on_start(**on_start_kwargs) + kwargs = self.on_start(**on_start_kwargs) else: - f_kwargs = {} + kwargs = {} - if f_kwargs is None: - f_kwargs = {} + if kwargs is None: + kwargs = {} - if hasattr(f_kwargs, "__await__"): - f_kwargs = await f_kwargs + if hasattr(kwargs, "__await__"): + kwargs = await kwargs - if hasattr(self, "apply"): - async with utils.TaskPool(self.workers) as tasks: - - async for x in self.input_queue: - task = self.apply(x, **f_kwargs) - await tasks.put(task) - else: - await self.process(**f_kwargs) - - await self.output_queues.done() + await self.process(**kwargs) if self.on_done is not None: - on_done_kwargs = {} - if "stage_status" in inspect.getfullargspec(self.on_done).args: - on_done_kwargs["stage_status"] = utils.StageStatus() + kwargs["stage_status"] = utils.StageStatus() - done_resp = self.on_done(**on_done_kwargs) + done_resp = self.on_done(**kwargs) if hasattr(done_resp, "__await__"): await done_resp + await self.output_queues.done() + except BaseException as e: for stage in self.pipeline_stages: await stage.input_queue.done() diff --git a/pypeln/thread/stage.py b/pypeln/thread/stage.py index 461c2b4..89a82c3 100644 --- a/pypeln/thread/stage.py +++ b/pypeln/thread/stage.py @@ -55,33 +55,27 @@ def run(self, index): if "worker_info" in inspect.getfullargspec(self.on_start).args: on_start_kwargs["worker_info"] = utils.WorkerInfo(index=index) - f_kwargs = self.on_start(**on_start_kwargs) + kwargs = self.on_start(**on_start_kwargs) else: - f_kwargs = {} + kwargs = {} - if f_kwargs is None: - f_kwargs = {} + if kwargs is None: + kwargs = {} - if hasattr(self, "apply"): - for x in self.input_queue: - self.apply(x, **f_kwargs) - else: - self.process(**f_kwargs) - - self.output_queues.done() + self.process(**kwargs) if self.on_done is not None: with self.stage_lock: self.stage_namespace.active_workers -= 1 - on_done_kwargs = {} - if "stage_status" in inspect.getfullargspec(self.on_done).args: - on_done_kwargs["stage_status"] = utils.StageStatus( + kwargs["stage_status"] = utils.StageStatus( namespace=self.stage_namespace, lock=self.stage_lock ) - self.on_done(**on_done_kwargs) + self.on_done(**kwargs) + + self.output_queues.done() except BaseException as e: try: diff --git a/pyproject.toml b/pyproject.toml index 921b2b7..72df595 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "pypeln" -version = "0.2.4" +version = "0.2.5" description = "" authors = ["Cristian Garcia "] license = "MIT" diff --git a/tests/test_process.py b/tests/test_process.py index 6110bf5..b06b243 100644 --- a/tests/test_process.py +++ b/tests/test_process.py @@ -111,6 +111,29 @@ def on_start(worker_info): assert nums_pl.issubset(set(range(n_workers))) +def test_kwargs(): + + nums = range(100) + n_workers = 4 + letters = "abc" + namespace = pl.process.get_namespace() + namespace.on_done = None + + def on_start(): + return dict(y=letters) + + def on_done(y): + namespace.on_done = y + + nums_pl = pl.process.map( + lambda x, y: y, nums, on_start=on_start, on_done=on_done, workers=n_workers, + ) + nums_pl = list(nums_pl) + + assert namespace.on_done == letters + assert nums_pl == [letters] * len(nums) + + @hp.given(nums=st.lists(st.integers())) @hp.settings(max_examples=MAX_EXAMPLES) def test_map_square_event_end(nums): diff --git a/tests/test_task.py b/tests/test_task.py index 2846a2c..19f9b27 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -98,6 +98,46 @@ def on_start(): assert namespace.x == 1 +def test_worker_info(): + + nums = range(100) + n_workers = 4 + + def on_start(worker_info): + return dict(index=worker_info.index) + + def _lambda(x, index): + return index + + nums_pl = pl.task.map(_lambda, nums, on_start=on_start, workers=n_workers,) + nums_pl = set(nums_pl) + + assert nums_pl.issubset(set(range(n_workers))) + + +def test_kwargs(): + + nums = range(100) + n_workers = 4 + letters = "abc" + namespace = pl.task.get_namespace() + namespace.on_done = None + + def on_start(): + return dict(y=letters) + + def on_done(y): + namespace.on_done = y + + nums_pl = pl.task.map( + lambda x, y: y, nums, on_start=on_start, on_done=on_done, workers=n_workers, + ) + nums_pl = list(nums_pl) + + assert nums_pl == [letters] * len(nums) + assert namespace.on_done == letters + + @hp.given(nums=st.lists(st.integers())) @hp.settings(max_examples=MAX_EXAMPLES) def test_map_square_event_end(nums): diff --git a/tests/test_thread.py b/tests/test_thread.py index 2893822..af70f60 100644 --- a/tests/test_thread.py +++ b/tests/test_thread.py @@ -112,6 +112,29 @@ def _lambda(x, index): assert nums_pl.issubset(set(range(n_workers))) +def test_kwargs(): + + nums = range(100) + n_workers = 4 + letters = "abc" + namespace = pl.thread.get_namespace() + namespace.on_done = None + + def on_start(): + return dict(y=letters) + + def on_done(y): + namespace.on_done = y + + nums_pl = pl.thread.map( + lambda x, y: y, nums, on_start=on_start, on_done=on_done, workers=n_workers, + ) + nums_pl = list(nums_pl) + + assert nums_pl == [letters] * len(nums) + assert namespace.on_done == letters + + @hp.given(nums=st.lists(st.integers())) @hp.settings(max_examples=MAX_EXAMPLES) def test_map_square_event_end(nums):