Skip to content

Commit

Permalink
Refactor pipeline thread executor creation
Browse files Browse the repository at this point in the history
Move the creation of executor to builder so that later
it can reuse the executor
  • Loading branch information
mthrok committed Nov 24, 2024
1 parent bebb00a commit fe96149
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
8 changes: 6 additions & 2 deletions src/spdl/pipeline/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
Iterable,
Sequence,
)
from concurrent.futures import Executor
from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import asynccontextmanager, contextmanager
from functools import partial
from typing import TypeVar
Expand Down Expand Up @@ -893,4 +893,8 @@ def build(self, *, num_threads: int | None = None) -> Pipeline:
]
num_threads = max(concurrencies) if concurrencies else 4
assert num_threads is not None
return Pipeline(coro, queues, num_threads, desc=self._get_desc())
executor = ThreadPoolExecutor(
max_workers=num_threads,
thread_name_prefix="spdl_",
)
return Pipeline(coro, queues, executor, desc=self._get_desc())
21 changes: 10 additions & 11 deletions src/spdl/pipeline/_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import warnings
from asyncio import AbstractEventLoop, Queue as AsyncQueue
from collections.abc import Coroutine, Iterator
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from enum import IntEnum
from threading import Event as SyncEvent, Thread
Expand All @@ -36,9 +37,13 @@
# This class has a bit exessive debug logs, because it is tricky to debug
# it from the outside.
class _EventLoop:
def __init__(self, coro: Coroutine[None, None, None], num_threads: int) -> None:
def __init__(
self,
coro: Coroutine[None, None, None],
executor: ThreadPoolExecutor,
) -> None:
self._coro = coro
self._num_threads = num_threads
self._executor = executor

self._loop: AbstractEventLoop | None = None

Expand All @@ -62,14 +67,8 @@ def __str__(self) -> str:

async def _execute_task(self) -> None:
_LG.debug("The event loop thread coroutine is started.")
_LG.debug("Initializing the thread pool of size=%d.", self._num_threads)
self._loop = asyncio.get_running_loop()
self._loop.set_default_executor(
concurrent.futures.ThreadPoolExecutor(
max_workers=self._num_threads,
thread_name_prefix="spdl_",
)
)
self._loop.set_default_executor(self._executor)

_LG.debug("Starting the task.")

Expand Down Expand Up @@ -263,7 +262,7 @@ def __init__(
self,
coro: Coroutine[None, None, None],
queues: list[AsyncQueue],
num_threads: int,
executor: ThreadPoolExecutor,
*,
desc: list[str],
) -> None:
Expand All @@ -272,7 +271,7 @@ def __init__(
self._str: str = "\n".join([repr(self), *desc])

self._output_queue: AsyncQueue = queues[-1]
self._event_loop = _EventLoop(coro, num_threads)
self._event_loop = _EventLoop(coro, executor)
self._event_loop_state: _EventLoopState = _EventLoopState.NOT_STARTED

try:
Expand Down

0 comments on commit fe96149

Please sign in to comment.