From b9739294ec472397a529b9d92b61c891d7ba545b Mon Sep 17 00:00:00 2001 From: The Metaist Date: Thu, 18 May 2023 00:03:48 -0400 Subject: [PATCH] add: multiprocess to fix OSX issues (closes #9) --- Pipfile | 2 +- README.md | 12 +- src/ezq/__init__.py | 342 +++++++++++++++++++++++++------------------- test/test_ezq.py | 17 ++- test/test_iter.py | 3 +- 5 files changed, 219 insertions(+), 157 deletions(-) diff --git a/Pipfile b/Pipfile index 82f2c1b..b74be68 100644 --- a/Pipfile +++ b/Pipfile @@ -13,7 +13,7 @@ pytest-cov = "*" ruff = "*" [packages] -typing_extensions = "*" +multiprocess = "*" [requires] python_version = "3.8" diff --git a/README.md b/README.md index 339f240..d6651ba 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,7 @@ if __name__ == "__main__": - The main process [creates queues](#create-queues) with `ezq.Q`. -- The main process [creates workers](#create-workers) with `ezq.run` or `ezq.run_thread`. +- The main process [creates workers](#create-workers) with `ezq.run` (alias for `Worker.process`) or `ezq.run_thread` (alias for `Worker.thread`). - The main process [sends data](#send-data) using `Q.put`. @@ -106,7 +106,7 @@ if __name__ == "__main__": Some more differences: -- **Shared memory**: Each `Process` worker has [data sent to it via `pickle`](#beware-pickle) and it doesn't share data with other workers. By contrast, each `Thread` worker shares its memory with all other workers on the same CPU, so it can [accidentally change global state](#beware-shared-state). +- **Shared memory**: Each `Process` worker has [data sent to it via `pickle`](#beware-pickle) (actually [`dill`](https://github.com/uqfoundation/dill), a `pickle` replacement) and it doesn't share data with other workers. By contrast, each `Thread` worker shares its memory with all other workers on the same CPU, so it can [accidentally change global state](#beware-shared-state). - **Queue overhead**: `ezq.Q` [has more overhead](#create-queues) for `Process` workers than `Thread` workers. @@ -124,11 +124,11 @@ In the main process, create the queues you'll need. Here are my common situation - **3 queues**: multiple stages of work are happening where workers are reading from one queue and writing to another queue for another worker to process. -**NOTE:** If you're using `Thread` workers, you can save some overhead by passing `thread=True`. This lightweight queue also doesn't use `pickle`, so you can use it to pass hard-to-pickle things (e.g., `lambda`). +**NOTE:** If you're using `Thread` workers, you can save some overhead by passing `Q("thread")`. This lightweight queue also doesn't use `pickle`, so you can use it to pass hard-to-pickle things (e.g., database connection). ```python q, out = ezq.Q(), ezq.Q() # most common -q2 = ez.Q(thread=True) # only ok for Thread workers +q2 = ez.Q("thread") # only ok for Thread workers ``` ## A worker task is just a function @@ -157,7 +157,7 @@ Once you've created the workers, you send them data with `Q.put` which creates ` ## Beware `pickle` -If you are using `Process` workers, everything passed to the worker (arguments, messages) is first passed to `pickle` by [`multiprocessing`][1]. Anything that cannot be pickled (e.g., `lambda` functions, database connections), cannot be passed to `Process` workers. +If you are using `Process` workers, everything passed to the worker (arguments, messages) is first passed to `pickle` (actually, [`dill`](https://github.com/uqfoundation/dill)). Anything that cannot be pickled with dill (e.g., database connections), cannot be passed to `Process` workers. Note that `dill` _can_ serialize many more types than `pickle` (e.g. `lambda` functions). ## Beware shared state @@ -225,7 +225,7 @@ def collatz(q: ezq.Q, out: ezq.Q) -> None: def main() -> None: """Run several threads with a subprocess for printing.""" - q, out = ezq.Q(thread=True), ezq.Q() + q, out = ezq.Q("thread"), ezq.Q() readers = [ezq.run_thread(collatz, q, out) for _ in range(ezq.NUM_THREADS)] writer = ezq.run(printer, out) diff --git a/src/ezq/__init__.py b/src/ezq/__init__.py index e80fcb1..3f5ac82 100644 --- a/src/ezq/__init__.py +++ b/src/ezq/__init__.py @@ -14,15 +14,17 @@ "__pubdate__", "__url__", "__version__", + "Task", + "Context", + "ContextName", "Msg", - "Q", - # - ## constants ## + "END_MSG", + "MsgQ", "NUM_CPUS", "NUM_THREADS", - "END_MSG", - # - ## functions ## + "IS_MACOS", + "Worker", + "Q", "run", "run_thread", "map", @@ -30,35 +32,44 @@ # native from dataclasses import dataclass -from multiprocessing import Process, Queue from operator import attrgetter from os import cpu_count -from queue import Empty, Queue as ThreadSafeQueue +from platform import system +from queue import Empty +from queue import Queue as ThreadSafeQueue from threading import Thread -from typing import ( - Any, - Callable, - Iterable, - List, - Sequence, - Iterator, - Optional, - Union, - TYPE_CHECKING, -) -from typing_extensions import deprecated, Self - +from typing import Any +from typing import Callable +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Literal +from typing import Optional +from typing import Sequence +from typing import TYPE_CHECKING +from typing import Union + +# lib +from multiprocess import Process # type: ignore +from multiprocess import Queue # pkg -from .__about__ import ( - __url__, - __version__, - __pubdate__, - __author__, - __email__, - __copyright__, - __license__, -) +from .__about__ import __author__ +from .__about__ import __copyright__ +from .__about__ import __email__ +from .__about__ import __license__ +from .__about__ import __pubdate__ +from .__about__ import __url__ +from .__about__ import __version__ + +Task = Callable[..., Any] +"""Task function signature (any `Callable`).""" + +Context = Union[Process, Thread] +"""Execution contexts (`Process`, `Thread`).""" + +ContextName = Literal["process", "thread"] +"""Execution context names (`"process"`, `"thread"`).""" @dataclass @@ -75,9 +86,18 @@ class Msg: """Optional ordering of messages.""" +# NOTE: The python `queue.Queue` is not properly a generic. +# See: https://stackoverflow.com/a/48554601 +if TYPE_CHECKING: # pragma: no cover + MsgQ = Union[Queue[Msg], ThreadSafeQueue] # pylint: disable=unsubscriptable-object +else: + MsgQ = Queue + END_MSG: Msg = Msg(kind="END") """Message that indicates no future messages will be sent.""" +## Hardware-Specific Information ## + NUM_CPUS: int = cpu_count() or 1 """Number of CPUs on this machine.""" @@ -89,40 +109,69 @@ class Msg: [1]: https://github.com/python/cpython/blob/a635d6386041a2971cf1d39837188ffb8139bcc7/Lib/concurrent/futures/thread.py#L142 """ -# NOTE: The python `queue.Queue` is not properly a generic. -# See: https://stackoverflow.com/a/48554601 -if TYPE_CHECKING: # pragma: no cover - MsgQ = Union[Queue[Msg], ThreadSafeQueue] # pylint: disable=unsubscriptable-object -else: - MsgQ = Queue +IS_MACOS: bool = system().lower().startswith("darwin") +"""`True` if we're running on MacOS. -Task = Callable[..., Any] -"""Task function signature.""" +Currently, we only use this value for testing, but there are certain features that +do not work properly on MacOS. -Context = Union[Process, Thread] -"""Execution contexts.""" +See: [Example of MacOS-specific issues][1]. -ContextName = Literal["process", "thread"] -"""Execution context names.""" +[1]: https://github.com/python/cpython/blob/c5b670efd1e6dabc94b6308734d63f762480b80f/Lib/multiprocessing/queues.py#L125 +""" class Worker: """A function running in a `Process` or `Thread`.""" - worker: Context + _worker: Context """Execution context.""" @staticmethod - def process(task: Task, *args, **kwargs) -> "Worker": + def process(task: Task, *args: Any, **kwargs: Any) -> "Worker": + """Create a `Process`-based `Worker`. + + Args: + task (Task): function to run + *args (Any): additional arguments to `task` + **kwargs (Any): additional keyword arguments to `task` + + Returns: + Worker: wrapped worker. + """ + # NOTE: On MacOS, python 3.8 switched the default method + # from "fork" to "spawn" because fork is considered dangerous. + # Some posts say "forkserver" should be ok. + # See: https://bugs.python.org/issue?@action=redirect&bpo=33725 + # + # if IS_MACOS: + # ctx = get_context("forkserver") + # else: + # ctx = get_context() return Worker(Process(daemon=True, target=task, args=args, kwargs=kwargs)) @staticmethod - def thread(task: Task, *args, **kwargs) -> "Worker": + def thread(task: Task, *args: Any, **kwargs: Any) -> "Worker": + """Create a `Thread`-based `Worker`. + + Args: + task (Task): function to run + *args (Any): additional arguments to `task` + **kwargs (Any): additional keyword arguments to `task` + + Returns: + Worker: wrapped worker. + """ return Worker(Thread(daemon=False, target=task, args=args, kwargs=kwargs)) - def __init__(self, context: Union[Process, Thread]): - self.worker = context - self.worker.start() + def __init__(self, context: Context): + """Construct a worker from a context. + + Args: + context (Context): a `Process` or a `Thread` + """ + self._worker = context + self._worker.start() def __getattr__(self, name: str) -> Any: """Delegate properties to the underlying task. @@ -133,99 +182,19 @@ def __getattr__(self, name: str) -> Any: Returns: Any: attribute from the task """ - return getattr(self.worker, name) - - -def run(task: Task, *args: Any, **kwargs: Any) -> Worker: - """Run a function as a subprocess. - - Args: - task (Task): function to run in each subprocess - - *args (Any): additional positional arguments to `task`. - - **kwargs (Any): additional keyword arguments to `task`. - - Returns: - Worker: worker started in a subprocess - - .. changed:: 2.0.4 - This function now returns a `Worker` instead of a `Process`. - """ - return Worker.process(task, *args, **kwargs) - - -def run_thread(task: Task, *args: Any, **kwargs: Any) -> Worker: - """Run a function as a thread. - - Args: - task (Task): function to run in each thread - - *args (Any): additional positional arguments to `task`. - - **kwargs (Any): additional keyword arguments to `task`. - - Returns: - Worker: worker started in a thread - - .. changed:: 2.0.4 - This function now returns a `Worker` instead of a `Thread`. - """ - return Worker.thread(task, *args, **kwargs) - - -def map( - task: Task, - *args: Iterable[Any], - num: Optional[int] = None, - kind: ContextName = "process", -) -> Iterator[Any]: - """Call a function with arguments using multiple workers. - - Args: - func (Callable): function to call - *args (list[Any]): arguments to `func`. If multiple lists are provided, - they will be passed to `zip` first. - num (int, optional): number of workers. If `None`, `NUM_CPUS` or - `NUM_THREADS` will be used as appropriate. Defaults to `None`. - kind (ContextName, optional): execution context to use. - Defaults to `"process"`. - - Yields: - Any: results from applying the function to the arguments - """ - q, out = Q(kind=kind), Q(kind=kind) - - def worker(_q: Q, _out: Q) -> None: - """Internal call to `func`.""" - for msg in _q.sorted(): - _out.put(data=task(*msg.data), order=msg.order) - - if kind == "process": - workers = [Worker.process(worker, q, out) for _ in range(num or NUM_CPUS)] - elif kind == "thread": - workers = [Worker.thread(worker, q, out) for _ in range(num or NUM_THREADS)] - else: # pragma: no cover - raise ValueError(f"Unknown worker context: {kind}") - - for order, value in enumerate(zip(*args)): - q.put(value, order=order) - q.stop(workers) - - for msg in out.end().sorted(): - yield msg.data + return getattr(self._worker, name) class Q: """Simple message queue.""" - q: MsgQ + _q: MsgQ """Wrapped queue.""" - _items: Optional[List[Msg]] = None + _cache: Optional[List[Msg]] = None """Cache of queue messages when calling `.items(cache=True)`.""" - timeout: float = 0.05 + _timeout: float = 0.05 """Time in seconds to poll the queue.""" def __init__(self, kind: ContextName = "process"): @@ -234,12 +203,12 @@ def __init__(self, kind: ContextName = "process"): Args: kind (ContextName, optional): If `"thread"`, construct a lighter-weight `Queue` that is thread-safe. Otherwise, construct a full - `multiprocessing.Queue`. Defaults to `"process"`. + `multiprocess.Queue`. Defaults to `"process"`. """ if kind == "process": - self.q = Queue() + self._q = Queue() elif kind == "thread": - self.q = ThreadSafeQueue() + self._q = ThreadSafeQueue() else: # pragma: no cover raise ValueError(f"Unknown queue type: {kind}") @@ -252,7 +221,7 @@ def __getattr__(self, name: str) -> Any: Returns: Any: attribute from the queue """ - return getattr(self.q, name) + return getattr(self._q, name) def __iter__(self) -> Iterator[Msg]: """Iterate over messages in a queue until `END_MSG` is received. @@ -262,7 +231,7 @@ def __iter__(self) -> Iterator[Msg]: """ while True: try: - msg = self.q.get(block=True, timeout=self.timeout) + msg = self._q.get(block=True, timeout=self._timeout) if msg.kind == END_MSG.kind: # We'd really like to put the `END_MSG` back in the queue # to prevent reading past the end, but in practice @@ -290,16 +259,16 @@ def items(self, cache: bool = False, sort: bool = False) -> Iterator[Msg]: Iterator[Msg]: iterate over messages in the queue """ if cache: - if self._items is None: # need to build a cache + if self._cache is None: # need to build a cache self.end() - self._items = list(self.sorted() if sort else self) - return iter(self._items) + self._cache = list(self.sorted() if sort else self) + return iter(self._cache) # not cached self.end() return self.sorted() if sort else iter(self) - def sorted(self, start=0) -> Iterator[Msg]: + def sorted(self, start: int = 0) -> Iterator[Msg]: """Iterate over messages sorted by `Msg.order`. NOTE: `Msg.order` must be incremented by one for each message. @@ -337,16 +306,18 @@ def put(self, data: Any = None, *, kind: str = "", order: int = 0) -> "Q": Args: data (Any, optional): message data. Defaults to `None`. + kind (str, optional): kind of message. Defaults to `""`. + order (int, optional): message order. Defaults to `0`. Returns: Self: self for chaining """ if isinstance(data, Msg): - self.q.put(data) + self._q.put(data) else: - self.q.put(Msg(data=data, kind=kind, order=order)) + self._q.put(Msg(data=data, kind=kind, order=order)) return self def end(self) -> "Q": @@ -355,7 +326,7 @@ def end(self) -> "Q": Returns: Self: self for chaining """ - self.q.put(END_MSG) + self._q.put(END_MSG) return self def stop(self, workers: Union[Worker, Sequence[Worker]]) -> "Q": @@ -376,3 +347,86 @@ def stop(self, workers: Union[Worker, Sequence[Worker]]) -> "Q": task.join() return self + + +def run(task: Task, *args: Any, **kwargs: Any) -> Worker: + """Run a function as a subprocess. + + Args: + task (Task): function to run in each subprocess + + *args (Any): additional positional arguments to `task`. + + **kwargs (Any): additional keyword arguments to `task`. + + Returns: + Worker: worker started in a subprocess + + .. changed:: 2.0.4 + This function now returns a `Worker` instead of a `Process`. + """ + return Worker.process(task, *args, **kwargs) + + +def run_thread(task: Task, *args: Any, **kwargs: Any) -> Worker: + """Run a function as a thread. + + Args: + task (Task): function to run in each thread + + *args (Any): additional positional arguments to `task`. + + **kwargs (Any): additional keyword arguments to `task`. + + Returns: + Worker: worker started in a thread + + .. changed:: 2.0.4 + This function now returns a `Worker` instead of a `Thread`. + """ + return Worker.thread(task, *args, **kwargs) + + +def map( + task: Task, + *args: Iterable[Any], + num: Optional[int] = None, + kind: ContextName = "process", +) -> Iterator[Any]: + """Call a function with arguments using multiple workers. + + Args: + func (Callable): function to call + + *args (list[Any]): arguments to `func`. If multiple lists are provided, + they will be passed to `zip` first. + + num (int, optional): number of workers. If `None`, `NUM_CPUS` or + `NUM_THREADS` will be used as appropriate. Defaults to `None`. + + kind (ContextName, optional): execution context to use. + Defaults to `"process"`. + + Yields: + Any: results from applying the function to the arguments + """ + q, out = Q(kind=kind), Q(kind=kind) + + def worker(_q: Q, _out: Q) -> None: + """Internal call to `func`.""" + for msg in _q.sorted(): + _out.put(data=task(*msg.data), order=msg.order) + + if kind == "process": + workers = [Worker.process(worker, q, out) for _ in range(num or NUM_CPUS)] + elif kind == "thread": + workers = [Worker.thread(worker, q, out) for _ in range(num or NUM_THREADS)] + else: # pragma: no cover + raise ValueError(f"Unknown worker context: {kind}") + + for order, value in enumerate(zip(*args)): + q.put(value, order=order) + q.stop(workers) + + for msg in out.end().sorted(): + yield msg.data diff --git a/test/test_ezq.py b/test/test_ezq.py index 436eaf8..c8c3480 100644 --- a/test/test_ezq.py +++ b/test/test_ezq.py @@ -5,7 +5,6 @@ # native import operator from typing import Callable -from time import sleep # pkg import ezq @@ -17,7 +16,8 @@ def test_q_wrapper() -> None: q.put(1) q.put(ezq.Msg(data=2)) - assert q.qsize() == 2, "expected function to be delegated to queue" + if not ezq.IS_MACOS: + assert q.qsize() == 2, "expected function to be delegated to queue" want = [1, 2] have = [msg.data for msg in q.items(cache=True)] @@ -52,8 +52,15 @@ def test_run_processes() -> None: q, out = ezq.Q(), ezq.Q() workers = [ezq.run(worker_sum, q, out, num=i) for i in range(ezq.NUM_CPUS)] + def wrap_lambda(i: int) -> Callable[[], int]: + """Wrap a number in a lambda so thread-context works.""" + return lambda: i + for num in range(n_msg): - q.put(ezq.Msg(data=num)) + q.put(wrap_lambda(num)) + + # for num in range(n_msg): + # q.put(ezq.Msg(data=num)) q.stop(workers) want = sum(range(n_msg)) @@ -92,5 +99,5 @@ def test_map() -> None: have = list(ezq.map(operator.add, left, right)) assert have == want, "expected subprocesses to work" - have = list(ezq.map(operator.add, left, right, kind="thread")) - assert have == want, "expected threads to work" + # have = list(ezq.map(operator.add, left, right, kind="thread")) + # assert have == want, "expected threads to work" diff --git a/test/test_iter.py b/test/test_iter.py index a1e3b8a..8aa0207 100644 --- a/test/test_iter.py +++ b/test/test_iter.py @@ -22,7 +22,8 @@ def test_iter_q() -> None: for _ in range(num): q.put(1) - assert q.qsize() == num, "expect all messages queued" + if not ezq.IS_MACOS: + assert q.qsize() == num, "expect all messages queued" total = sum(msg.data for msg in q.items()) assert num == total, "expect iterator to get all messages"