Skip to content

Commit

Permalink
Add parallel run chords to facilitate chords of tasks inside other tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
hylje committed Jul 12, 2024
1 parent 1096864 commit edaae86
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 26 deletions.
88 changes: 62 additions & 26 deletions hai/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,36 +36,14 @@ def failed_task_names(self) -> Set[str]:
RT = TypeVar("RT")


class ParallelRun:
"""
Encapsulates running several functions in parallel.
Due to the GIL, this is mainly useful for IO-bound threads, such as
downloading stuff from the Internet, or writing big buffers of data.
class BaseParallelRun:
pool: ThreadPool

It's recommended to use this in a `with` block, to ensure the thread pool
gets properly shut down.
"""

def __init__(self, parallelism: Optional[int] = None) -> None:
self.pool = ThreadPool(
processes=(parallelism or (int(os.cpu_count() or 1) * 2)),
)
def __init__(self) -> None:
self.task_complete_event = threading.Event()
self.tasks: List[ApplyResult[Any]] = []
self.completed_tasks: WeakSet[ApplyResult[Any]] = WeakSet()

def __enter__(self) -> "ParallelRun":
return self

def __exit__(self, exc_type, exc_val, exc_tb) -> None: # type: ignore[no-untyped-def]
self.pool.terminate()

def __del__(self) -> None: # opportunistic cleanup
if self.pool:
self.pool.terminate()
self.pool = None # type: ignore[assignment]

def _set_task_complete_event(self, value: Any = None) -> None:
self.task_complete_event.set()

Expand Down Expand Up @@ -108,7 +86,7 @@ def wait(
self,
fail_fast: bool = True,
interval: float = 0.5,
callback: Optional[Callable[["ParallelRun"], None]] = None,
callback: Optional[Callable[["BaseParallelRun"], None]] = None,
max_wait: Optional[float] = None,
) -> List["ApplyResult[Any]"]:
"""
Expand Down Expand Up @@ -238,3 +216,61 @@ def exceptions(self) -> Dict[str, Exception]:
for t in self.tasks
if t.ready() and not t._success # type: ignore[attr-defined]
}


class ChordParallelRun(BaseParallelRun):
"""Parallel run that inherits the thread pool from another parallel run
instead of managing its own
"""

def __init__(self, main_run: "ParallelRun") -> None:
super().__init__()
self.main_run = main_run
self.pool = main_run.pool

def ready(self) -> bool:
return all(t.ready() for t in self.tasks)


class ParallelRun(BaseParallelRun):
"""
Encapsulates running several functions in parallel.
Due to the GIL, this is mainly useful for IO-bound threads, such as
downloading stuff from the Internet, or writing big buffers of data.
It's recommended to use this in a `with` block, to ensure the thread pool
gets properly shut down.
"""

def __init__(self, parallelism: Optional[int] = None) -> None:
super().__init__()
self.pool = ThreadPool(
processes=(parallelism or (int(os.cpu_count() or 1) * 2)),
)

def chord(self) -> ChordParallelRun:
"""Return a ChordParallelRun that can run a separate set of tasks using
the same thread pool as this ParallelRun. This avoids creating nested thread pools
for tasks that can themselves utilize multiple threads.
Tasks added to chords are handled separately by the chord and are not also
added to this ParallelRun's tasks. ParallelRun.wait() will not wait for them
to complete. Instead, call the wait() method on each ChordParallelRun separately
to wait for that chord to complete.
Exiting this ParallelRun will terminate the thread pool even if there are
incomplete chords.
"""
return ChordParallelRun(main_run=self)

def __enter__(self) -> "ParallelRun":
return self

def __exit__(self, exc_type, exc_val, exc_tb) -> None: # type: ignore[no-untyped-def]
self.pool.terminate()

def __del__(self) -> None: # opportunistic cleanup
if self.pool:
self.pool.terminate()
self.pool = None # type: ignore[assignment]
55 changes: 55 additions & 0 deletions hai_tests/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,26 @@ def return_true():
return True


def run_chord(parallel: ParallelRun):
chord = parallel.chord()
chord.add_task(return_true, name="task_0")
chord.add_task(return_true, name="task_1")
chord.wait()
return chord.return_values


def run_chord_with_error(parallel: ParallelRun):
chord = parallel.chord()
chord.add_task(return_true, name="task_0")
chord.add_task(agh, name="task_1")
try:
chord.wait()
except TaskFailed:
# Wait for the remaining tasks to complete
chord.wait(fail_fast=False)
return chord.return_values


def test_parallel_crash():
with ParallelRun() as parallel:
parallel.add_task(return_true)
Expand Down Expand Up @@ -101,3 +121,38 @@ def test_parallel_max_wait():
parallel.add_task(time.sleep, args=(1,))
with pytest.raises(TimeoutError):
parallel.wait(interval=0.1, max_wait=0.5)


def test_parallel_chord_task():
with ParallelRun() as parallel:
for i in range(3):
parallel.add_task(
run_chord,
kwargs={"parallel": parallel},
name="chord_%d" % i,
)
parallel.wait()
assert parallel.return_values == {
"chord_0": {"task_0": True, "task_1": True},
"chord_1": {"task_0": True, "task_1": True},
"chord_2": {"task_0": True, "task_1": True},
}


def test_parallel_chord_task_fail():
"""
Test that failures inside a chord don't interrupt other chords or tasks
"""
with ParallelRun() as parallel:
parallel.add_task(return_true)
parallel.add_task(run_chord, kwargs={"parallel": parallel}, name="chord_ok")
parallel.add_task(
run_chord_with_error,
kwargs={"parallel": parallel},
name="chord_fail",
)
parallel.wait()
assert parallel.return_values["return_true"]
assert parallel.return_values["chord_ok"] == {"task_0": True, "task_1": True}
assert parallel.return_values["chord_fail"]["task_0"]
assert isinstance(parallel.return_values["chord_fail"]["task_1"], RuntimeError)

0 comments on commit edaae86

Please sign in to comment.