Skip to content

Commit

Permalink
Add "done_callback" to every task for debugging
Browse files Browse the repository at this point in the history
Currently, tasks that throw exceptions silently stop. This commit makes
it so that all tasks are created with a done_callback that checks
whether they stopped due to an exception.
  • Loading branch information
pinkwah committed Dec 13, 2023
1 parent c3d9e34 commit b2b2e48
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 3 deletions.
37 changes: 36 additions & 1 deletion src/ert/async_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,44 @@
from __future__ import annotations

import asyncio
import sys
from traceback import print_exception
from typing import Any, Coroutine, Generator, TypeVar, Union

_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)


def new_event_loop() -> asyncio.AbstractEventLoop:
loop = asyncio.new_event_loop()
loop.set_task_factory(_create_task)
return loop


def get_event_loop() -> asyncio.AbstractEventLoop:
try:
return asyncio.get_event_loop()
except RuntimeError:
asyncio.set_event_loop(asyncio.new_event_loop())
asyncio.set_event_loop(new_event_loop())
return asyncio.get_event_loop()


def _create_task(
loop: asyncio.AbstractEventLoop,
coro: Union[Coroutine[Any, Any, _T], Generator[Any, None, _T]],
) -> asyncio.Task[_T]:
task = asyncio.Task(coro, loop=loop)
task.add_done_callback(_done_callback)
return task


def _done_callback(task: asyncio.Task[_T_co]) -> None:
assert task.done()
try:
if (exc := task.exception()) is None:
return

print(f"Exception during {task.get_name()}", file=sys.stderr)
print_exception(exc, file=sys.stderr)
except asyncio.CancelledError:
pass
4 changes: 2 additions & 2 deletions src/ert/ensemble_evaluator/_builder/_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from cloudevents.http.event import CloudEvent

from ert.async_utils import get_event_loop
from ert.async_utils import get_event_loop, new_event_loop
from ert.config.parsing.queue_system import QueueSystem
from ert.ensemble_evaluator import identifiers
from ert.job_queue import JobQueue
Expand Down Expand Up @@ -117,7 +117,7 @@ def _evaluate(self) -> None:
a coroutine
"""
# Get a fresh eventloop
asyncio.set_event_loop(asyncio.new_event_loop())
asyncio.set_event_loop(new_event_loop())

if self._config is None:
raise ValueError("no config")
Expand Down

0 comments on commit b2b2e48

Please sign in to comment.