Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python blocking progress improvements #312

Draft
wants to merge 5 commits into
base: branch-0.41
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/ucxx/ucxx/_lib_async/application_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,8 @@ def continuous_ucx_progress(self, event_loop=None):
if loop in ProgressTasks:
return # Progress has already been guaranteed for the current event loop

logger.info(f"Starting progress in '{self.progress_mode}' mode")

if self.progress_mode == "thread":
task = ThreadMode(self.worker, loop, polling_mode=False)
elif self.progress_mode == "thread-polling":
Expand Down
55 changes: 35 additions & 20 deletions python/ucxx/ucxx/_lib_async/continuous_ucx_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@
from ucxx._lib.libucxx import UCXWorker


def _cancel_task(event_loop, task):
if task is not None:
try:
task.cancel()
event_loop.run_until_complete(task)
except asyncio.exceptions.CancelledError:
pass


class ProgressTask(object):
def __init__(self, worker, event_loop):
"""Creates a task that keeps calling worker.progress()
Expand All @@ -28,20 +37,20 @@ def __init__(self, worker, event_loop):
"""
self.worker = worker
self.event_loop = event_loop
self.asyncio_task = None
self.asyncio_tasks = dict()

event_loop_close_original = self.event_loop.close

def _event_loop_close(event_loop_close_original, *args, **kwargs):
if not self.event_loop.is_closed() and self.asyncio_task is not None:
try:
self.asyncio_task.cancel()
self.event_loop.run_until_complete(self.asyncio_task)
except asyncio.exceptions.CancelledError:
pass
finally:
self.asyncio_task = None
event_loop_close_original(*args, **kwargs)
if self.event_loop.is_closed():
return

try:
for task in self.asyncio_tasks.values():
_cancel_task(event_loop, task)
finally:
event_loop_close_original(*args, **kwargs)
self.asyncio_tasks = None

self.event_loop.close = partial(_event_loop_close, event_loop_close_original)

Expand Down Expand Up @@ -70,7 +79,7 @@ def __init__(self, worker, event_loop, polling_mode=False):
class PollingMode(ProgressTask):
def __init__(self, worker, event_loop):
super().__init__(worker, event_loop)
self.asyncio_task = event_loop.create_task(self._progress_task())
self.asyncio_tasks["progress"] = event_loop.create_task(self._progress_task())
self.worker.init_blocking_progress_mode()

async def _progress_task(self):
Expand Down Expand Up @@ -132,9 +141,12 @@ def __init__(
weakref.finalize(self, event_loop.remove_reader, epoll_fd)
weakref.finalize(self, self.rsock.close)

self.blocking_asyncio_task = None
self.armed = False
self.asyncio_tasks["arm"] = self.event_loop.create_task(self._arm_worker())
self.last_progress_time = time.monotonic() - self._progress_timeout
self.asyncio_task = event_loop.create_task(self._progress_with_timeout())
self.asyncio_tasks["progress"] = event_loop.create_task(
self._progress_with_timeout()
)

def _fd_reader_callback(self):
"""Schedule new progress task upon worker event.
Expand All @@ -144,10 +156,9 @@ def _fd_reader_callback(self):
"""
self.worker.progress()

# Notice, we can safely overwrite `self.blocking_asyncio_task`
# since previous arm task is finished by now.
assert self.blocking_asyncio_task is None or self.blocking_asyncio_task.done()
self.blocking_asyncio_task = self.event_loop.create_task(self._arm_worker())
assert not self.armed

self.armed = False

async def _arm_worker(self):
"""Progress the worker and rearm.
Expand All @@ -161,6 +172,9 @@ async def _arm_worker(self):
# so that the asyncio's next state is epoll wait.
# See <https://github.com/rapidsai/ucx-py/issues/413>
while True:
if self.armed:
continue

self.last_progress_time = time.monotonic()
self.worker.progress()

Expand Down Expand Up @@ -193,10 +207,11 @@ async def _progress_with_timeout(self):
# seem to respect timeout with `asyncio.wait_for`, thus we cancel
# it here instead. It will get recreated after a new event on
# `worker.epoll_file_descriptor`.
if self.blocking_asyncio_task is not None:
self.blocking_asyncio_task.cancel()
arm_task = self.asyncio_tasks["arm"]
if arm_task is not None:
arm_task.cancel()
try:
await self.blocking_asyncio_task
await arm_task
except asyncio.exceptions.CancelledError:
pass

Expand Down
Loading