diff --git a/src/prefect/_internal/concurrency/supervisors.py b/src/prefect/_internal/concurrency/supervisors.py index 18366d4f74ab..0ca7e4ab9a1d 100644 --- a/src/prefect/_internal/concurrency/supervisors.py +++ b/src/prefect/_internal/concurrency/supervisors.py @@ -26,9 +26,17 @@ ) from prefect._internal.concurrency.event_loop import call_soon_in_loop, get_running_loop +from prefect._internal.concurrency.timeouts import ( + cancel_async_after, + cancel_async_at, + cancel_sync_after, + cancel_sync_at, + get_deadline, +) from prefect.logging import get_logger T = TypeVar("T") +Fn = TypeVar("Fn", bound=Callable) # Python uses duck typing for futures; asyncio/threaded futures do not share a base AnyFuture = Any @@ -109,12 +117,12 @@ def _run_sync(self): result = self.context.run(self.fn, *self.args, **self.kwargs) except BaseException as exc: self.future.set_exception(exc) + logger.debug("Encountered exception in work item %r", self) # Prevent reference cycle in `exc` - self = None + del self else: self.future.set_result(result) - - logger.debug("Finished work item %r", self) + logger.debug("Finished work item %r", self) async def _run_async(self): loop = asyncio.get_running_loop() @@ -131,11 +139,11 @@ async def _run_async(self): except BaseException as exc: self.future.set_exception(exc) # Prevent reference cycle in `exc` - self = None + logger.debug("Encountered exception %s in work item %r", exc, self) + del self else: self.future.set_result(result) - - logger.debug("Finished work item %r", self) + logger.debug("Finished work item %r", self) class Supervisor(abc.ABC, Generic[T]): @@ -145,12 +153,21 @@ class Supervisor(abc.ABC, Generic[T]): Work sent to the supervisor will be executed when the owner waits for the result. """ - def __init__(self, submit_fn: Callable[..., concurrent.futures.Future]) -> None: + def __init__( + self, + submit_fn: Callable[..., concurrent.futures.Future], + timeout: Optional[float] = None, + ) -> None: self._submit_fn = submit_fn self._owner_thread = threading.current_thread() - self._future: AnyFuture = None - self._future_thread = concurrent.futures.Future() + self._future: Optional[concurrent.futures.Future] = None self._future_call: Tuple[Callable, Tuple, Dict] = None + self._timeout: Optional[float] = timeout + + # Delayed computation + self._worker_thread_future = concurrent.futures.Future() + self._deadline_future = concurrent.futures.Future() + logger.debug("Created supervisor %r", self) def submit(self, __fn, *args, **kwargs) -> concurrent.futures.Future: @@ -165,17 +182,16 @@ def submit(self, __fn, *args, **kwargs) -> concurrent.futures.Future: if self._future: raise RuntimeError("A supervisor can only monitor a single future") - @functools.wraps(__fn) - def _call_in_supervised_thread(): - # Capture the worker thread before running the function - thread = threading.current_thread() - self._future_thread.set_result(thread) - return __fn(*args, **kwargs) - with set_supervisor(self): - future = self._submit_fn(_call_in_supervised_thread) + future = self._submit_fn(self._add_supervision(__fn), *args, **kwargs) - self._set_future(future) + logger.debug( + "Call to %r submitted to supervisor %r tracked by future %r", + __fn.__name__, + self, + future, + ) + self._future = future self._future_call = (__fn, args, kwargs) return future @@ -186,24 +202,50 @@ def send_call(self, __fn, *args, **kwargs) -> concurrent.futures.Future: """ work_item = WorkItem.from_call(__fn, *args, **kwargs) self._put_work_item(work_item) - logger.debug("Sent work item to %r", self) + logger.debug("Sent work item to supervisor %r", self) return work_item.future @property def owner_thread(self): return self._owner_thread - @abc.abstractmethod - def _put_work_item(self, work_item: WorkItem) -> None: + def _add_supervision(self, fn: Fn) -> Fn: """ - Add a work item to the supervisor. Used by `send_call`. + Attach supervision to a callable. """ - raise NotImplementedError() + + @functools.wraps(fn) + def _call_in_supervised_thread(*args, **kwargs): + # Capture the worker thread before running the function + thread = threading.current_thread() + self._worker_thread_future.set_result(thread) + + # Set the execution deadline + self._deadline_future.set_result(get_deadline(self._timeout)) + + # Enforce timeouts on synchronous execution + with cancel_sync_after(self._timeout): + retval = fn(*args, **kwargs) + + # Enforce timeouts on asynchronous execution + if inspect.isawaitable(retval): + + async def _call_in_supervised_coro(): + # TODO: Technnically, this should use the deadline but it is not + # clear how to mix sync/async deadlines yet + with cancel_async_after(self._timeout): + return await retval + + return _call_in_supervised_coro() + + return retval + + return _call_in_supervised_thread @abc.abstractmethod - def _set_future(self, future: AnyFuture) -> None: + def _put_work_item(self, work_item: WorkItem) -> None: """ - Assign a future to the supervisor. + Add a work item to the supervisor. Used by `send_call`. """ raise NotImplementedError() @@ -229,24 +271,20 @@ def __repr__(self) -> str: class SyncSupervisor(Supervisor[T]): - def __init__(self, submit_fn: Callable[..., concurrent.futures.Future]) -> None: - super().__init__(submit_fn=submit_fn) + def __init__( + self, + submit_fn: Callable[..., concurrent.futures.Future], + timeout: Optional[float] = None, + ) -> None: + super().__init__(submit_fn=submit_fn, timeout=timeout) self._queue: queue.Queue = queue.Queue() - self._future: Optional[concurrent.futures.Future] = None def _put_work_item(self, work_item: WorkItem): self._queue.put_nowait(work_item) - def _set_future(self, future: concurrent.futures.Future): - future.add_done_callback(lambda _: self._queue.put_nowait(None)) - self._future = future - - def result(self) -> T: - if not self._future: - raise ValueError("No future being supervised.") - + def _watch_for_work_items(self): + logger.debug("Watching for work sent to supervisor %r", self) while True: - logger.debug("Watching for work sent to %r", self) work_item: WorkItem = self._queue.get() if work_item is None: break @@ -254,48 +292,83 @@ def result(self) -> T: work_item.run() del work_item - logger.debug("Retrieving result for %r", self) + def result(self) -> T: + if not self._future: + raise ValueError("No future being supervised.") + + # Stop watching for work once the future is done + self._future.add_done_callback(lambda _: self._queue.put_nowait(None)) + + # Cancel work sent to the supervisor if the future exceeds its timeout + deadline = self._deadline_future.result() + try: + with cancel_sync_at(deadline) as ctx: + self._watch_for_work_items() + except TimeoutError: + # Timeouts will be generally be raised on future result retrieval but + # if its not our timeout it should be reraised + if not ctx.cancelled: + raise + + logger.debug("Supervisor %r retrieving result of future %r", self, self._future) return self._future.result() class AsyncSupervisor(Supervisor[T]): - def __init__(self, submit_fn: Callable[..., concurrent.futures.Future]) -> None: - super().__init__(submit_fn=submit_fn) + def __init__( + self, + submit_fn: Callable[..., concurrent.futures.Future], + timeout: Optional[float] = None, + ) -> None: + super().__init__(submit_fn=submit_fn, timeout=timeout) self._queue: asyncio.Queue = asyncio.Queue() self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() - self._future: Optional[asyncio.Future] = None - - def _set_future( - self, future: Union[asyncio.Future, concurrent.futures.Future] - ) -> None: - # Ensure we're working with an asyncio future internally - if isinstance(future, concurrent.futures.Future): - future = asyncio.wrap_future(future) - - future.add_done_callback( - lambda _: call_soon_in_loop(self._loop, self._queue.put_nowait, None) - ) - self._future = future def _put_work_item(self, work_item: WorkItem): # We must put items in the queue from the event loop that owns it call_soon_in_loop(self._loop, self._queue.put_nowait, work_item) - async def result(self) -> T: - if not self._future: - raise ValueError("No future being supervised.") - + async def _watch_for_work_items(self): + logger.debug("Watching for work sent to %r", self) while True: - logger.debug("Watching for work sent to %r", self) work_item: WorkItem = await self._queue.get() if work_item is None: break - result = work_item.run() - if inspect.isawaitable(result): - await result + # TODO: We could use `cancel_sync_after` to guard this call as a sync call + # could block a timeout here + retval = work_item.run() + + if inspect.isawaitable(retval): + await retval del work_item - logger.debug("Retrieving result for %r", self) - return await self._future + async def result(self) -> T: + if not self._future: + raise ValueError("No future being supervised.") + + future = ( + # Convert to an asyncio future if necessary for non-blocking wait + asyncio.wrap_future(self._future) + if isinstance(self._future, concurrent.futures.Future) + else self._future + ) + + # Stop watching for work once the future is done + future.add_done_callback( + lambda _: call_soon_in_loop(self._loop, self._queue.put_nowait, None) + ) + + # Cancel work sent to the supervisor if the future exceeds its timeout + deadline = await asyncio.wrap_future(self._deadline_future) + try: + with cancel_async_at(deadline) as ctx: + await self._watch_for_work_items() + except TimeoutError: + # Timeouts will be re-raised on future result retrieval + if not ctx.cancelled: + raise + + logger.debug("Supervisor %r retrieving result of future %r", self, future) + return await future diff --git a/src/prefect/_internal/concurrency/timeouts.py b/src/prefect/_internal/concurrency/timeouts.py new file mode 100644 index 000000000000..74a471db3269 --- /dev/null +++ b/src/prefect/_internal/concurrency/timeouts.py @@ -0,0 +1,254 @@ +import contextlib +import ctypes +import math +import signal +import sys +import threading +import time +from typing import Optional, Type + +import anyio + +from prefect.logging import get_logger + +# TODO: We should update the format for this logger to include the current thread +logger = get_logger("prefect._internal.concurrency.timeouts") + + +class CancelContext: + """ + Tracks if a cancel context manager was cancelled. + + The `cancelled` property is threadsafe. + """ + + def __init__(self, timeout: Optional[float]) -> None: + self._timeout = timeout + self._cancelled: bool = False + + @property + def timeout(self) -> Optional[float]: + return self._timeout + + @property + def cancelled(self): + return self._cancelled + + def mark_cancelled(self): + self._cancelled = True + + +@contextlib.contextmanager +def cancel_async_after(timeout: Optional[float]): + """ + Cancel any async calls within the context if it does not exit after the given + timeout. + + A timeout error will be raised on the next `await` when the timeout expires. + + Yields a `CancelContext`. + """ + ctx = CancelContext(timeout=timeout) + if timeout is None: + yield ctx + return + + try: + with anyio.fail_after(timeout) as cancel_scope: + logger.debug( + f"Entered asynchronous cancel context with %.2f timeout", timeout + ) + yield ctx + finally: + if cancel_scope.cancel_called: + ctx.mark_cancelled() + + +def get_deadline(timeout: Optional[float]): + """ + Compute an deadline given a timeout. + + Uses a monotonic clock. + """ + if timeout is None: + return None + + return time.monotonic() + timeout + + +@contextlib.contextmanager +def cancel_async_at(deadline: Optional[float]): + """ + Cancel any async calls within the context if it does not exit by the given deadline. + + Deadlines must be computed with the monotonic clock. See `get_deadline`. + + A timeout error will be raised on the next `await` when the timeout expires. + + Yields a `CancelContext`. + """ + if deadline is None: + yield CancelContext(timeout=None) + return + + timeout = max(0, deadline - time.monotonic()) + + ctx = CancelContext(timeout=timeout) + try: + with cancel_async_after(timeout) as inner_ctx: + yield ctx + finally: + if inner_ctx.cancelled: + ctx.mark_cancelled() + + +@contextlib.contextmanager +def cancel_sync_at(deadline: Optional[float]): + """ + Cancel any sync calls within the context if it does not exit by the given deadline. + + Deadlines must be computed with the monotonic clock. See `get_deadline`. + + The cancel method varies depending on if this is called in the main thread or not. + See `cancel_sync_after` for details + + Yields a `CancelContext`. + """ + if deadline is None: + yield CancelContext(timeout=None) + return + + timeout = max(0, deadline - time.monotonic()) + + ctx = CancelContext(timeout=timeout) + try: + with cancel_sync_after(timeout) as inner_ctx: + yield ctx + finally: + if inner_ctx.cancelled: + ctx.mark_cancelled() + + +@contextlib.contextmanager +def cancel_sync_after(timeout: Optional[float]): + """ + Cancel any sync calls within the context if it does not exit after the given + timeout. + + The timeout method varies depending on if this is called in the main thread or not. + See `_alarm_based_timeout` and `_watcher_thread_based_timeout` for details. + + Yields a `CancelContext`. + """ + ctx = CancelContext(timeout=timeout) + if timeout is None: + yield ctx + return + + if sys.platform.startswith("win"): + # Timeouts cannot be enforced on Windows + logger.warning( + f"Entered cancel context on Windows; %.2f timeout will not be enforced.", + timeout, + ) + yield ctx + return + + if threading.current_thread() is threading.main_thread(): + method = _alarm_based_timeout + method_name = "alarm" + else: + method = _watcher_thread_based_timeout + method_name = "watcher" + + try: + with method(timeout) as inner_ctx: + logger.debug( + f"Entered synchronous cancel context with %.2f %s based timeout", + timeout, + method_name, + ) + yield ctx + finally: + if inner_ctx.cancelled: + ctx.mark_cancelled() + + +@contextlib.contextmanager +def _alarm_based_timeout(timeout: float): + """ + Enforce a timeout using an alarm. + + Sets an alarm for `timeout` seconds, then raises a timeout error if the context is + not exited before the deadline. + + !!! Alarms cannot be floats, so the timeout is rounded up to the nearest integer. + + Alarms have the benefit of interrupt sys calls like `sleep`, but signals are always + raised in the main thread and this cannot be used elsewhere. + """ + current_thread = threading.current_thread() + if not current_thread is threading.main_thread(): + raise ValueError("Alarm based timeouts can only be used in the main thread.") + + ctx = CancelContext(timeout=timeout) + + def raise_alarm_as_timeout(signum, frame): + ctx.mark_cancelled() + logger.debug( + "Cancel fired for alarm based timeout of thread %r", current_thread.name + ) + raise TimeoutError() + + try: + signal.signal(signal.SIGALRM, raise_alarm_as_timeout) + signal.alarm(math.ceil(timeout)) # alarms do not support floats + yield ctx + finally: + signal.alarm(0) # Clear the alarm when the context exits + + +@contextlib.contextmanager +def _watcher_thread_based_timeout(timeout: float): + """ + Enforce a timeout using a watcher thread. + + Creates a thread that sleeps for `timeout` seconds, then sends a timeout error to + the supervised (current) thread if the context is not exited before the deadline. + + Note this will not interrupt sys calls like `sleep`. + """ + event = threading.Event() + supervised_thread = threading.current_thread() + ctx = CancelContext(timeout=timeout) + + def timeout_enforcer(): + time.sleep(timeout) + if not event.is_set(): + logger.debug( + "Cancel fired for watcher based timeout of thread %r", + supervised_thread.name, + ) + ctx.mark_cancelled() + _send_exception_to_thread(supervised_thread, TimeoutError) + + enforcer = threading.Thread(target=timeout_enforcer, daemon=True) + enforcer.start() + + try: + yield ctx + finally: + event.set() + + +def _send_exception_to_thread(thread: threading.Thread, exc_type: Type[BaseException]): + """ + Raise an exception in a thread. + + This will not interrupt long-running system calls like `sleep` or `wait`. + """ + ret = ctypes.pythonapi.PyThreadState_SetAsyncExc( + ctypes.c_long(thread.ident), ctypes.py_object(exc_type) + ) + if ret == 0: + raise ValueError("Thread not found.") diff --git a/tests/_internal/concurrency/test_supervisors.py b/tests/_internal/concurrency/test_supervisors.py index bdf6d6ae814a..5672ae524dd5 100644 --- a/tests/_internal/concurrency/test_supervisors.py +++ b/tests/_internal/concurrency/test_supervisors.py @@ -1,4 +1,6 @@ +import asyncio import concurrent.futures +import time import pytest @@ -7,7 +9,6 @@ def fake_submit_fn(__fn, *args, **kwargs): future = concurrent.futures.Future() - future.set_result(__fn(*args, **kwargs)) return future @@ -15,6 +16,11 @@ def fake_fn(*args, **kwargs): pass +def sleep_repeatedly(seconds: int): + for i in range(seconds * 10): + time.sleep(float(i) / 10) + + @pytest.mark.parametrize("cls", [AsyncSupervisor, SyncSupervisor]) async def test_supervisor_repr(cls): supervisor = cls(submit_fn=fake_submit_fn) @@ -28,3 +34,93 @@ async def test_supervisor_repr(cls): == f"<{cls.__name__} submit_fn='fake_submit_fn', submitted='fake_fn'," " owner='MainThread'>" ) + + +def test_sync_supervisor_timeout_in_worker_thread(): + """ + In this test, a timeout is raised due to a slow call that is occuring on the worker + thread. + """ + with concurrent.futures.ThreadPoolExecutor() as executor: + supervisor = SyncSupervisor(submit_fn=executor.submit, timeout=0.1) + future = supervisor.submit(sleep_repeatedly, 1) + + t0 = time.time() + with pytest.raises(TimeoutError): + supervisor.result() + t1 = time.time() + + with pytest.raises(TimeoutError): + future.result() + + assert t1 - t0 < 1 + + +def test_sync_supervisor_timeout_in_main_thread(): + """ + In this test, a timeout is raised due to a slow call that is sent back to the main + thread by the worker thread. + """ + with concurrent.futures.ThreadPoolExecutor() as executor: + supervisor = SyncSupervisor(submit_fn=executor.submit, timeout=0.1) + + def on_worker_thread(): + # Send sleep to the main thread + future = supervisor.send_call(time.sleep, 2) + return future + + supervisor.submit(on_worker_thread) + + t0 = time.time() + future = supervisor.result() + t1 = time.time() + + # The timeout error is not raised by `supervisor.result()` because the worker + # does not check the result of the future; however, the work that was sent + # to the main thread should have a timeout error + with pytest.raises(TimeoutError): + future.result() + + # main thread timeouts round up to the nearest second + assert t1 - t0 < 2 + + +async def test_async_supervisor_timeout_in_worker_thread(): + with concurrent.futures.ThreadPoolExecutor() as executor: + supervisor = AsyncSupervisor(submit_fn=executor.submit, timeout=0.1) + future = supervisor.submit(sleep_repeatedly, 1) + + t0 = time.time() + with pytest.raises(TimeoutError): + await supervisor.result() + t1 = time.time() + + assert t1 - t0 < 1 + + # The future has a timeout error too + with pytest.raises(TimeoutError): + future.result() + + +async def test_async_supervisor_timeout_in_main_thread(): + with concurrent.futures.ThreadPoolExecutor() as executor: + supervisor = AsyncSupervisor(submit_fn=executor.submit, timeout=0.1) + + def on_worker_thread(): + # Send sleep to the main thread + future = supervisor.send_call(asyncio.sleep, 1) + return future + + supervisor.submit(on_worker_thread) + + t0 = time.time() + future = await supervisor.result() + t1 = time.time() + + assert t1 - t0 < 1 + + # The timeout error is not raised by `supervisor.result()` because the worker + # does not check the result of the future; however, the work that was sent + # to the main thread should have a timeout error + with pytest.raises(asyncio.CancelledError): + future.result() diff --git a/tests/_internal/concurrency/test_timeouts.py b/tests/_internal/concurrency/test_timeouts.py new file mode 100644 index 000000000000..0b4ad0feb1fe --- /dev/null +++ b/tests/_internal/concurrency/test_timeouts.py @@ -0,0 +1,81 @@ +import asyncio +import concurrent.futures +import time + +import pytest + +from prefect._internal.concurrency.timeouts import ( + cancel_async_after, + cancel_async_at, + cancel_sync_after, + cancel_sync_at, + get_deadline, +) + + +async def test_cancel_async_after(): + t0 = time.perf_counter() + with pytest.raises(TimeoutError): + with cancel_async_after(0.1) as ctx: + await asyncio.sleep(1) + t1 = time.perf_counter() + + assert ctx.cancelled + assert t1 - t0 < 1 + + +def test_cancel_sync_after_in_main_thread(): + t0 = time.perf_counter() + with pytest.raises(TimeoutError): + with cancel_sync_after(0.1) as ctx: + # floats are not suppported by alarm timeouts so this will actually timeout + # after 1s + time.sleep(2) + t1 = time.perf_counter() + + assert ctx.cancelled + assert t1 - t0 < 2 + + +def test_cancel_sync_after_in_worker_thread(): + def on_worker_thread(): + t0 = time.perf_counter() + with pytest.raises(TimeoutError): + with cancel_sync_after(0.1) as ctx: + # this timeout method does not interrupt sleep calls, the timeout is + # raised on the next instruction + for _ in range(10): + time.sleep(0.1) + t1 = time.perf_counter() + return t1 - t0, ctx + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(on_worker_thread) + elapsed_time, ctx = future.result() + + assert elapsed_time < 1 + assert ctx.cancelled + + +async def test_cancel_async_at(): + t0 = time.perf_counter() + with pytest.raises(TimeoutError): + with cancel_async_at(get_deadline(timeout=0.1)) as ctx: + await asyncio.sleep(1) + t1 = time.perf_counter() + + assert ctx.cancelled + assert t1 - t0 < 1 + + +def test_cancel_sync_at(): + t0 = time.perf_counter() + with pytest.raises(TimeoutError): + with cancel_sync_at(get_deadline(timeout=0.1)) as ctx: + # floats are not suppported by alarm timeouts so this will actually timeout + # after 1s + time.sleep(2) + t1 = time.perf_counter() + + assert ctx.cancelled + assert t1 - t0 < 2