diff --git a/hai/parallel.py b/hai/parallel.py
index f20294f..4a2f30a 100644
--- a/hai/parallel.py
+++ b/hai/parallel.py
@@ -36,36 +36,14 @@ def failed_task_names(self) -> Set[str]:
 
 RT = TypeVar("RT")
 
-class ParallelRun:
-    """
-    Encapsulates running several functions in parallel.
-
-    Due to the GIL, this is mainly useful for IO-bound threads, such as
-    downloading stuff from the Internet, or writing big buffers of data.
+class BaseParallelRun:
+    pool: ThreadPool
 
-    It's recommended to use this in a `with` block, to ensure the thread pool
-    gets properly shut down.
-    """
-
-    def __init__(self, parallelism: Optional[int] = None) -> None:
-        self.pool = ThreadPool(
-            processes=(parallelism or (int(os.cpu_count() or 1) * 2)),
-        )
+    def __init__(self) -> None:
         self.task_complete_event = threading.Event()
         self.tasks: List[ApplyResult[Any]] = []
         self.completed_tasks: WeakSet[ApplyResult[Any]] = WeakSet()
 
-    def __enter__(self) -> "ParallelRun":
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb) -> None:  # type: ignore[no-untyped-def]
-        self.pool.terminate()
-
-    def __del__(self) -> None:  # opportunistic cleanup
-        if self.pool:
-            self.pool.terminate()
-            self.pool = None  # type: ignore[assignment]
-
     def _set_task_complete_event(self, value: Any = None) -> None:
         self.task_complete_event.set()
 
@@ -108,7 +86,7 @@ def wait(
         self,
         fail_fast: bool = True,
         interval: float = 0.5,
-        callback: Optional[Callable[["ParallelRun"], None]] = None,
+        callback: Optional[Callable[["BaseParallelRun"], None]] = None,
         max_wait: Optional[float] = None,
     ) -> List["ApplyResult[Any]"]:
         """
@@ -238,3 +216,61 @@ def exceptions(self) -> Dict[str, Exception]:
             for t in self.tasks
             if t.ready() and not t._success  # type: ignore[attr-defined]
         }
+
+
+class ChordParallelRun(BaseParallelRun):
+    """Parallel run that inherits the thread pool from another parallel run
+    instead of managing its own
+    """
+
+    def __init__(self, main_run: "ParallelRun") -> None:
+        super().__init__()
+        self.main_run = main_run
+        self.pool = main_run.pool
+
+    def ready(self) -> bool:
+        return all(t.ready() for t in self.tasks)
+
+
+class ParallelRun(BaseParallelRun):
+    """
+    Encapsulates running several functions in parallel.
+
+    Due to the GIL, this is mainly useful for IO-bound threads, such as
+    downloading stuff from the Internet, or writing big buffers of data.
+
+    It's recommended to use this in a `with` block, to ensure the thread pool
+    gets properly shut down.
+    """
+
+    def __init__(self, parallelism: Optional[int] = None) -> None:
+        super().__init__()
+        self.pool = ThreadPool(
+            processes=(parallelism or (int(os.cpu_count() or 1) * 2)),
+        )
+
+    def chord(self) -> ChordParallelRun:
+        """Return a ChordParallelRun that can run a separate set of tasks using
+        the same thread pool as this ParallelRun. This avoids creating nested thread pools
+        for tasks that can themselves utilize multiple threads.
+
+        Tasks added to chords are handled separately by the chord and are not also
+        added to this ParallelRun's tasks. ParallelRun.wait() will not wait for them
+        to complete. Instead, call the wait() method on each ChordParallelRun separately
+        to wait for that chord to complete.
+
+        Exiting this ParallelRun will terminate the thread pool even if there are
+        incomplete chords.
+        """
+        return ChordParallelRun(main_run=self)
+
+    def __enter__(self) -> "ParallelRun":
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:  # type: ignore[no-untyped-def]
+        self.pool.terminate()
+
+    def __del__(self) -> None:  # opportunistic cleanup
+        if self.pool:
+            self.pool.terminate()
+            self.pool = None  # type: ignore[assignment]
diff --git a/hai_tests/test_parallel.py b/hai_tests/test_parallel.py
index 4e3c443..e30e881 100644
--- a/hai_tests/test_parallel.py
+++ b/hai_tests/test_parallel.py
@@ -15,6 +15,26 @@ def return_true():
     return True
 
 
+def run_chord(parallel: ParallelRun):
+    chord = parallel.chord()
+    chord.add_task(return_true, name="task_0")
+    chord.add_task(return_true, name="task_1")
+    chord.wait()
+    return chord.return_values
+
+
+def run_chord_with_error(parallel: ParallelRun):
+    chord = parallel.chord()
+    chord.add_task(return_true, name="task_0")
+    chord.add_task(agh, name="task_1")
+    try:
+        chord.wait()
+    except TaskFailed:
+        # Wait for the remaining tasks to complete
+        chord.wait(fail_fast=False)
+    return chord.return_values
+
+
 def test_parallel_crash():
     with ParallelRun() as parallel:
         parallel.add_task(return_true)
@@ -101,3 +121,38 @@ def test_parallel_max_wait():
         parallel.add_task(time.sleep, args=(1,))
         with pytest.raises(TimeoutError):
             parallel.wait(interval=0.1, max_wait=0.5)
+
+
+def test_parallel_chord_task():
+    with ParallelRun() as parallel:
+        for i in range(3):
+            parallel.add_task(
+                run_chord,
+                kwargs={"parallel": parallel},
+                name="chord_%d" % i,
+            )
+        parallel.wait()
+    assert parallel.return_values == {
+        "chord_0": {"task_0": True, "task_1": True},
+        "chord_1": {"task_0": True, "task_1": True},
+        "chord_2": {"task_0": True, "task_1": True},
+    }
+
+
+def test_parallel_chord_task_fail():
+    """
+    Test that failures inside a chord don't interrupt other chords or tasks
+    """
+    with ParallelRun() as parallel:
+        parallel.add_task(return_true)
+        parallel.add_task(run_chord, kwargs={"parallel": parallel}, name="chord_ok")
+        parallel.add_task(
+            run_chord_with_error,
+            kwargs={"parallel": parallel},
+            name="chord_fail",
+        )
+        parallel.wait()
+    assert parallel.return_values["return_true"]
+    assert parallel.return_values["chord_ok"] == {"task_0": True, "task_1": True}
+    assert parallel.return_values["chord_fail"]["task_0"]
+    assert isinstance(parallel.return_values["chord_fail"]["task_1"], RuntimeError)