Skip to content

Commit

Permalink
Add timeout support to supervisors (#8649)
Browse files Browse the repository at this point in the history
  • Loading branch information
zanieb authored Feb 27, 2023
1 parent 4939ae8 commit e39b8fe
Show file tree
Hide file tree
Showing 4 changed files with 568 additions and 64 deletions.
199 changes: 136 additions & 63 deletions src/prefect/_internal/concurrency/supervisors.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,17 @@
)

from prefect._internal.concurrency.event_loop import call_soon_in_loop, get_running_loop
from prefect._internal.concurrency.timeouts import (
cancel_async_after,
cancel_async_at,
cancel_sync_after,
cancel_sync_at,
get_deadline,
)
from prefect.logging import get_logger

T = TypeVar("T")
Fn = TypeVar("Fn", bound=Callable)

# Python uses duck typing for futures; asyncio/threaded futures do not share a base
AnyFuture = Any
Expand Down Expand Up @@ -109,12 +117,12 @@ def _run_sync(self):
result = self.context.run(self.fn, *self.args, **self.kwargs)
except BaseException as exc:
self.future.set_exception(exc)
logger.debug("Encountered exception in work item %r", self)
# Prevent reference cycle in `exc`
self = None
del self
else:
self.future.set_result(result)

logger.debug("Finished work item %r", self)
logger.debug("Finished work item %r", self)

async def _run_async(self):
loop = asyncio.get_running_loop()
Expand All @@ -131,11 +139,11 @@ async def _run_async(self):
except BaseException as exc:
self.future.set_exception(exc)
# Prevent reference cycle in `exc`
self = None
logger.debug("Encountered exception %s in work item %r", exc, self)
del self
else:
self.future.set_result(result)

logger.debug("Finished work item %r", self)
logger.debug("Finished work item %r", self)


class Supervisor(abc.ABC, Generic[T]):
Expand All @@ -145,12 +153,21 @@ class Supervisor(abc.ABC, Generic[T]):
Work sent to the supervisor will be executed when the owner waits for the result.
"""

def __init__(self, submit_fn: Callable[..., concurrent.futures.Future]) -> None:
def __init__(
self,
submit_fn: Callable[..., concurrent.futures.Future],
timeout: Optional[float] = None,
) -> None:
self._submit_fn = submit_fn
self._owner_thread = threading.current_thread()
self._future: AnyFuture = None
self._future_thread = concurrent.futures.Future()
self._future: Optional[concurrent.futures.Future] = None
self._future_call: Tuple[Callable, Tuple, Dict] = None
self._timeout: Optional[float] = timeout

# Delayed computation
self._worker_thread_future = concurrent.futures.Future()
self._deadline_future = concurrent.futures.Future()

logger.debug("Created supervisor %r", self)

def submit(self, __fn, *args, **kwargs) -> concurrent.futures.Future:
Expand All @@ -165,17 +182,16 @@ def submit(self, __fn, *args, **kwargs) -> concurrent.futures.Future:
if self._future:
raise RuntimeError("A supervisor can only monitor a single future")

@functools.wraps(__fn)
def _call_in_supervised_thread():
# Capture the worker thread before running the function
thread = threading.current_thread()
self._future_thread.set_result(thread)
return __fn(*args, **kwargs)

with set_supervisor(self):
future = self._submit_fn(_call_in_supervised_thread)
future = self._submit_fn(self._add_supervision(__fn), *args, **kwargs)

self._set_future(future)
logger.debug(
"Call to %r submitted to supervisor %r tracked by future %r",
__fn.__name__,
self,
future,
)
self._future = future
self._future_call = (__fn, args, kwargs)

return future
Expand All @@ -186,24 +202,50 @@ def send_call(self, __fn, *args, **kwargs) -> concurrent.futures.Future:
"""
work_item = WorkItem.from_call(__fn, *args, **kwargs)
self._put_work_item(work_item)
logger.debug("Sent work item to %r", self)
logger.debug("Sent work item to supervisor %r", self)
return work_item.future

@property
def owner_thread(self):
return self._owner_thread

@abc.abstractmethod
def _put_work_item(self, work_item: WorkItem) -> None:
def _add_supervision(self, fn: Fn) -> Fn:
"""
Add a work item to the supervisor. Used by `send_call`.
Attach supervision to a callable.
"""
raise NotImplementedError()

@functools.wraps(fn)
def _call_in_supervised_thread(*args, **kwargs):
# Capture the worker thread before running the function
thread = threading.current_thread()
self._worker_thread_future.set_result(thread)

# Set the execution deadline
self._deadline_future.set_result(get_deadline(self._timeout))

# Enforce timeouts on synchronous execution
with cancel_sync_after(self._timeout):
retval = fn(*args, **kwargs)

# Enforce timeouts on asynchronous execution
if inspect.isawaitable(retval):

async def _call_in_supervised_coro():
# TODO: Technically, this should use the deadline but it is not
# clear how to mix sync/async deadlines yet
with cancel_async_after(self._timeout):
return await retval

return _call_in_supervised_coro()

return retval

return _call_in_supervised_thread

@abc.abstractmethod
def _set_future(self, future: AnyFuture) -> None:
def _put_work_item(self, work_item: WorkItem) -> None:
"""
Assign a future to the supervisor.
Add a work item to the supervisor. Used by `send_call`.
"""
raise NotImplementedError()

Expand All @@ -229,73 +271,104 @@ def __repr__(self) -> str:


class SyncSupervisor(Supervisor[T]):
def __init__(self, submit_fn: Callable[..., concurrent.futures.Future]) -> None:
super().__init__(submit_fn=submit_fn)
def __init__(
self,
submit_fn: Callable[..., concurrent.futures.Future],
timeout: Optional[float] = None,
) -> None:
super().__init__(submit_fn=submit_fn, timeout=timeout)
self._queue: queue.Queue = queue.Queue()
self._future: Optional[concurrent.futures.Future] = None

def _put_work_item(self, work_item: WorkItem):
self._queue.put_nowait(work_item)

def _set_future(self, future: concurrent.futures.Future):
future.add_done_callback(lambda _: self._queue.put_nowait(None))
self._future = future

def result(self) -> T:
if not self._future:
raise ValueError("No future being supervised.")

def _watch_for_work_items(self):
logger.debug("Watching for work sent to supervisor %r", self)
while True:
logger.debug("Watching for work sent to %r", self)
work_item: WorkItem = self._queue.get()
if work_item is None:
break

work_item.run()
del work_item

logger.debug("Retrieving result for %r", self)
def result(self) -> T:
if not self._future:
raise ValueError("No future being supervised.")

# Stop watching for work once the future is done
self._future.add_done_callback(lambda _: self._queue.put_nowait(None))

# Cancel work sent to the supervisor if the future exceeds its timeout
deadline = self._deadline_future.result()
try:
with cancel_sync_at(deadline) as ctx:
self._watch_for_work_items()
except TimeoutError:
# Timeouts will generally be raised on future result retrieval, but
# if it's not our timeout it should be re-raised
if not ctx.cancelled:
raise

logger.debug("Supervisor %r retrieving result of future %r", self, self._future)
return self._future.result()


class AsyncSupervisor(Supervisor[T]):
def __init__(self, submit_fn: Callable[..., concurrent.futures.Future]) -> None:
super().__init__(submit_fn=submit_fn)
def __init__(
self,
submit_fn: Callable[..., concurrent.futures.Future],
timeout: Optional[float] = None,
) -> None:
super().__init__(submit_fn=submit_fn, timeout=timeout)
self._queue: asyncio.Queue = asyncio.Queue()
self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
self._future: Optional[asyncio.Future] = None

def _set_future(
self, future: Union[asyncio.Future, concurrent.futures.Future]
) -> None:
# Ensure we're working with an asyncio future internally
if isinstance(future, concurrent.futures.Future):
future = asyncio.wrap_future(future)

future.add_done_callback(
lambda _: call_soon_in_loop(self._loop, self._queue.put_nowait, None)
)
self._future = future

def _put_work_item(self, work_item: WorkItem):
# We must put items in the queue from the event loop that owns it
call_soon_in_loop(self._loop, self._queue.put_nowait, work_item)

async def result(self) -> T:
if not self._future:
raise ValueError("No future being supervised.")

async def _watch_for_work_items(self):
logger.debug("Watching for work sent to %r", self)
while True:
logger.debug("Watching for work sent to %r", self)
work_item: WorkItem = await self._queue.get()
if work_item is None:
break

result = work_item.run()
if inspect.isawaitable(result):
await result
# TODO: We could use `cancel_sync_after` to guard this call as a sync call
# could block a timeout here
retval = work_item.run()

if inspect.isawaitable(retval):
await retval

del work_item

logger.debug("Retrieving result for %r", self)
return await self._future
async def result(self) -> T:
if not self._future:
raise ValueError("No future being supervised.")

future = (
# Convert to an asyncio future if necessary for non-blocking wait
asyncio.wrap_future(self._future)
if isinstance(self._future, concurrent.futures.Future)
else self._future
)

# Stop watching for work once the future is done
future.add_done_callback(
lambda _: call_soon_in_loop(self._loop, self._queue.put_nowait, None)
)

# Cancel work sent to the supervisor if the future exceeds its timeout
deadline = await asyncio.wrap_future(self._deadline_future)
try:
with cancel_async_at(deadline) as ctx:
await self._watch_for_work_items()
except TimeoutError:
# Timeouts will be re-raised on future result retrieval
if not ctx.cancelled:
raise

logger.debug("Supervisor %r retrieving result of future %r", self, future)
return await future
Loading

0 comments on commit e39b8fe

Please sign in to comment.