Skip to content

Commit

Permalink
Proper kwargs forwarding on run method (#30)
Browse files: browse the repository at this point in the history
* 30min test

* fix kwargs forwarding to f and on_done issue

* v 0.2.5
  • Loading branch information
cgarciae authored Mar 4, 2020
1 parent 5ca2408 commit 9121c64
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 49 deletions.
2 changes: 1 addition & 1 deletion examples/process_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from tqdm import tqdm
import time

total = 10_000
total = 300_000


def f(x):
Expand Down
2 changes: 1 addition & 1 deletion pypeln/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,5 +151,5 @@ async def slow_gt3(x):
from . import task

__all__ = ["process", "thread", "task"]
__version__ = "0.2.4"
__version__ = "0.2.5"

20 changes: 9 additions & 11 deletions pypeln/process/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,29 +53,27 @@ def run(self, index):
if "worker_info" in inspect.getfullargspec(self.on_start).args:
on_start_kwargs["worker_info"] = utils.WorkerInfo(index=index)

f_kwargs = self.on_start(**on_start_kwargs)
kwargs = self.on_start(**on_start_kwargs)
else:
f_kwargs = {}
kwargs = {}

if f_kwargs is None:
f_kwargs = {}
if kwargs is None:
kwargs = {}

self.process(**f_kwargs)

self.output_queues.done()
self.process(**kwargs)

if self.on_done is not None:
with self.stage_lock:
self.stage_namespace.active_workers -= 1

on_done_kwargs = {}

if "stage_status" in inspect.getfullargspec(self.on_done).args:
on_done_kwargs["stage_status"] = utils.StageStatus(
kwargs["stage_status"] = utils.StageStatus(
namespace=self.stage_namespace, lock=self.stage_lock
)

self.on_done(**on_done_kwargs)
self.on_done(**kwargs)

self.output_queues.done()

except BaseException as e:
try:
Expand Down
31 changes: 11 additions & 20 deletions pypeln/task/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,39 +50,30 @@ async def run(self):
if "worker_info" in inspect.getfullargspec(self.on_start).args:
on_start_kwargs["worker_info"] = utils.WorkerInfo(index=0)

f_kwargs = self.on_start(**on_start_kwargs)
kwargs = self.on_start(**on_start_kwargs)
else:
f_kwargs = {}
kwargs = {}

if f_kwargs is None:
f_kwargs = {}
if kwargs is None:
kwargs = {}

if hasattr(f_kwargs, "__await__"):
f_kwargs = await f_kwargs
if hasattr(kwargs, "__await__"):
kwargs = await kwargs

if hasattr(self, "apply"):
async with utils.TaskPool(self.workers) as tasks:

async for x in self.input_queue:
task = self.apply(x, **f_kwargs)
await tasks.put(task)
else:
await self.process(**f_kwargs)

await self.output_queues.done()
await self.process(**kwargs)

if self.on_done is not None:

on_done_kwargs = {}

if "stage_status" in inspect.getfullargspec(self.on_done).args:
on_done_kwargs["stage_status"] = utils.StageStatus()
kwargs["stage_status"] = utils.StageStatus()

done_resp = self.on_done(**on_done_kwargs)
done_resp = self.on_done(**kwargs)

if hasattr(done_resp, "__await__"):
await done_resp

await self.output_queues.done()

except BaseException as e:
for stage in self.pipeline_stages:
await stage.input_queue.done()
Expand Down
24 changes: 9 additions & 15 deletions pypeln/thread/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,33 +55,27 @@ def run(self, index):
if "worker_info" in inspect.getfullargspec(self.on_start).args:
on_start_kwargs["worker_info"] = utils.WorkerInfo(index=index)

f_kwargs = self.on_start(**on_start_kwargs)
kwargs = self.on_start(**on_start_kwargs)
else:
f_kwargs = {}
kwargs = {}

if f_kwargs is None:
f_kwargs = {}
if kwargs is None:
kwargs = {}

if hasattr(self, "apply"):
for x in self.input_queue:
self.apply(x, **f_kwargs)
else:
self.process(**f_kwargs)

self.output_queues.done()
self.process(**kwargs)

if self.on_done is not None:
with self.stage_lock:
self.stage_namespace.active_workers -= 1

on_done_kwargs = {}

if "stage_status" in inspect.getfullargspec(self.on_done).args:
on_done_kwargs["stage_status"] = utils.StageStatus(
kwargs["stage_status"] = utils.StageStatus(
namespace=self.stage_namespace, lock=self.stage_lock
)

self.on_done(**on_done_kwargs)
self.on_done(**kwargs)

self.output_queues.done()

except BaseException as e:
try:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

[tool.poetry]
name = "pypeln"
version = "0.2.4"
version = "0.2.5"
description = ""
authors = ["Cristian Garcia <[email protected]>"]
license = "MIT"
Expand Down
23 changes: 23 additions & 0 deletions tests/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,29 @@ def on_start(worker_info):
assert nums_pl.issubset(set(range(n_workers)))


def test_kwargs():
    """Check that the kwargs returned by ``on_start`` reach ``f`` and ``on_done``.

    ``on_start`` returns ``y``; the mapped function must receive it for every
    element, and ``on_done`` must receive the same value.
    """
    letters = "abc"
    n_workers = 4
    nums = range(100)

    # The namespace object records the value on_done receives so that it can
    # be asserted after the pipeline has been consumed.
    namespace = pl.process.get_namespace()
    namespace.on_done = None

    def on_start():
        return {"y": letters}

    def on_done(y):
        namespace.on_done = y

    def pick_y(x, y):
        return y

    stage = pl.process.map(
        pick_y, nums, on_start=on_start, on_done=on_done, workers=n_workers
    )
    results = list(stage)

    assert namespace.on_done == letters
    assert results == [letters] * len(nums)


@hp.given(nums=st.lists(st.integers()))
@hp.settings(max_examples=MAX_EXAMPLES)
def test_map_square_event_end(nums):
Expand Down
40 changes: 40 additions & 0 deletions tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,46 @@ def on_start():
assert namespace.x == 1


def test_worker_info():
    """Check that ``worker_info.index`` from ``on_start`` is forwarded to ``f``.

    Every value emitted by the stage must be one of ``range(n_workers)``.
    """
    n_workers = 4
    nums = range(100)

    def on_start(worker_info):
        # Expose the worker index to the mapped function as the `index` kwarg.
        return {"index": worker_info.index}

    def _lambda(x, index):
        return index

    stage = pl.task.map(_lambda, nums, on_start=on_start, workers=n_workers)
    seen_indexes = set(stage)

    assert seen_indexes <= set(range(n_workers))


def test_kwargs():
    """Check that the kwargs returned by ``on_start`` reach ``f`` and ``on_done``.

    ``on_start`` returns ``y``; the mapped function must receive it for every
    element, and ``on_done`` must receive the same value.
    """
    letters = "abc"
    n_workers = 4
    nums = range(100)

    # The namespace object records the value on_done receives so that it can
    # be asserted after the pipeline has been consumed.
    namespace = pl.task.get_namespace()
    namespace.on_done = None

    def on_start():
        return {"y": letters}

    def on_done(y):
        namespace.on_done = y

    def pick_y(x, y):
        return y

    stage = pl.task.map(
        pick_y, nums, on_start=on_start, on_done=on_done, workers=n_workers
    )
    results = list(stage)

    assert results == [letters] * len(nums)
    assert namespace.on_done == letters


@hp.given(nums=st.lists(st.integers()))
@hp.settings(max_examples=MAX_EXAMPLES)
def test_map_square_event_end(nums):
Expand Down
23 changes: 23 additions & 0 deletions tests/test_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,29 @@ def _lambda(x, index):
assert nums_pl.issubset(set(range(n_workers)))


def test_kwargs():
    """Check that the kwargs returned by ``on_start`` reach ``f`` and ``on_done``.

    ``on_start`` returns ``y``; the mapped function must receive it for every
    element, and ``on_done`` must receive the same value.
    """
    letters = "abc"
    n_workers = 4
    nums = range(100)

    # The namespace object records the value on_done receives so that it can
    # be asserted after the pipeline has been consumed.
    namespace = pl.thread.get_namespace()
    namespace.on_done = None

    def on_start():
        return {"y": letters}

    def on_done(y):
        namespace.on_done = y

    def pick_y(x, y):
        return y

    stage = pl.thread.map(
        pick_y, nums, on_start=on_start, on_done=on_done, workers=n_workers
    )
    results = list(stage)

    assert results == [letters] * len(nums)
    assert namespace.on_done == letters


@hp.given(nums=st.lists(st.integers()))
@hp.settings(max_examples=MAX_EXAMPLES)
def test_map_square_event_end(nums):
Expand Down

0 comments on commit 9121c64

Please sign in to comment.