-
Notifications
You must be signed in to change notification settings - Fork 47
/
Copy pathtaskmanager.py
303 lines (250 loc) · 12.1 KB
/
taskmanager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
from __future__ import annotations

import logging
import time
import traceback
import types
from asyncio import CancelledError, Future, Task, ensure_future, gather, get_running_loop, iscoroutinefunction, sleep
from contextlib import suppress
from functools import partial, wraps
from threading import RLock
from typing import TYPE_CHECKING, Any, Callable
from weakref import WeakValueDictionary

from .util import coroutine, succeed

if TYPE_CHECKING:
    from collections.abc import Coroutine, Hashable, Sequence
    from concurrent.futures import ThreadPoolExecutor
# Age threshold in seconds (ten minutes). TaskManager._check_tasks warns about non-interval tasks
# older than this, and the same value is used as the interval of that check.
MAX_TASK_AGE = 10 * 60
async def interval_runner(delay: float, interval: float, task: Callable,
                          *args: Any) -> None:  # noqa: ANN401
    """
    Run ``task(*args)`` forever: first after ``delay`` seconds, then ``interval`` seconds after each
    invocation has completed.

    The coroutine returned by ``task`` is awaited every round. The loop only ends when ``task``
    raises or the surrounding asyncio Task is cancelled.
    """
    pause = delay
    while True:
        await sleep(pause)
        await task(*args)
        # Every wait after the first one uses the regular interval.
        pause = interval
async def delay_runner(delay: float, task: Callable, *args: Any) -> None:  # noqa: ANN401
    """
    Wait ``delay`` seconds, then run ``task(*args)`` exactly once.

    The coroutine returned by ``task`` is awaited and its result is discarded.
    """
    await sleep(delay)
    pending = task(*args)
    await pending
def task(func: Callable) -> Callable:
    """
    Decorate a TaskManager coroutine method so that calling it schedules it as an anonymous task.

    The decorated method returns the Future produced by ``register_anonymous_task`` (named after
    ``func``), so callers may await it. ``Exception`` is passed as the ``ignore`` tuple, so the manager
    logs such failures. Callers that do await the Future still have to handle those exceptions.

    :raises TypeError: if ``func`` is not a coroutine function.
    """
    if iscoroutinefunction(func):
        @wraps(func)
        def wrapper(self: TaskManager, *args: Any, **kwargs: Any) -> Future:  # noqa: ANN401
            scheduled = ensure_future(func(self, *args, **kwargs))
            return self.register_anonymous_task(func.__name__, scheduled, ignore=(Exception,))

        return wrapper

    msg = "Task decorator should be used with coroutine functions only!"
    raise TypeError(msg)
def set_name(self: Future, __value: object) -> None:
    """
    Stand-in for ``Task.set_name`` on objects that are not Tasks (for example plain Futures).

    The value is stored on the instance's ``name`` attribute, which :func:`get_name` reads back.

    :param self: The ``Future`` instance being named
    :param __value: The name to store
    .. seealso:: :class:`asyncio.Task`
    .. seealso:: :func:`asyncio.Task.set_name`
    """
    self.name = __value  # type: ignore[attr-defined]
def get_name(self: Future) -> object:
    """
    Stand-in for ``Task.get_name`` on objects that are not Tasks (for example plain Futures).

    Returns the ``name`` attribute previously stored by :func:`set_name`. Raises ``AttributeError``
    if no name was set.

    :param self: The ``Future`` instance whose name is read
    .. seealso:: :class:`asyncio.Task`
    .. seealso:: :func:`asyncio.Task.get_name`
    """
    stored = self.name  # type: ignore[attr-defined]
    return stored
class TaskManager:
"""
Provides a set of tools to maintain a list of asyncio Tasks that are to be
executed during the lifetime of an arbitrary object, usually getting killed with it.
"""
def __init__(self) -> None:
"""
Create a new TaskManager and start the introspection loop.
"""
self._pending_tasks: WeakValueDictionary[Hashable, Future] = WeakValueDictionary()
self._shutdown_tasks: list[tuple[Callable | Coroutine, tuple[Any, ...], dict[str, Any]]] = []
self._task_lock = RLock()
self._shutdown = False
self._counter = 0
self._logger = logging.getLogger(self.__class__.__name__)
self._checker = self.register_task("_check_tasks", self._check_tasks,
interval=MAX_TASK_AGE, delay=MAX_TASK_AGE * 1.5)
def _check_tasks(self) -> None:
now = time.time()
for name, task in self._pending_tasks.items():
if not task.interval and now - task.start_time > MAX_TASK_AGE: # type: ignore[attr-defined]
self._logger.warning('Non-interval task "%s" has been running for %.2f!',
name, now - task.start_time) # type: ignore[attr-defined]
def replace_task(self, name: Hashable, *args: Any, **kwargs) -> Future: # noqa: ANN401
"""
Replace named task with the new one, cancelling the old one in the process.
"""
new_task: Future = Future()
def cancel_cb(_: Any) -> None: # noqa: ANN401
try:
new_task.set_result(self.register_task(name, *args, **kwargs))
except Exception as e:
new_task.set_exception(e)
old_task = self.cancel_pending_task(name)
old_task.add_done_callback(cancel_cb)
return new_task
def register_task(self, name: Hashable, task: Callable | Coroutine | Future, # noqa: C901
*args: Any, delay: float | None = None, # noqa: ANN401
interval: float | None = None, ignore: Sequence[type | BaseException] = ()) -> Future:
"""
Register a Task/(coroutine)function so it can be canceled at shutdown time or by name.
"""
if not isinstance(task, Task) and not callable(task) and not isinstance(task, Future):
msg = "Register_task takes a Task/(coroutine)function/Future as a parameter"
raise TypeError(msg)
if (interval or delay) and not callable(task):
msg = "Cannot run non-callable at an interval or with a delay"
raise ValueError(msg)
if not isinstance(ignore, tuple) or not all(issubclass(e, Exception) for e in ignore):
msg = "Ignore should be a tuple of Exceptions or an empty tuple"
raise ValueError(msg)
with self._task_lock:
if self._shutdown:
self._logger.warning("Not adding task %s due to shutdown!", str(task))
if isinstance(task, (Task, Future)) and not task.done():
task.cancel()
# We need to return an awaitable in case the caller awaits the output of register_task.
return succeed(None)
if self.is_pending_task_active(name):
msg = f"Task already exists: '{name}'"
raise RuntimeError(msg)
if callable(task):
task = task if iscoroutinefunction(task) else coroutine(task)
if interval:
# The default delay for looping calls is the same as the interval
delay = interval if delay is None else delay
task = ensure_future(interval_runner(delay, interval, task, *args))
elif delay:
task = ensure_future(delay_runner(delay, task, *args))
else:
task = ensure_future(task(*args))
# Since weak references to list/tuple are not allowed, we're not storing start_time/interval
# in _pending_tasks. Instead, we add them as attributes to the task.
task.start_time = time.time() # type: ignore[attr-defined]
task.interval = interval # type: ignore[attr-defined]
if not hasattr(task, "set_name"):
task.set_name = types.MethodType(set_name, task) # type: ignore[attr-defined]
task.get_name = types.MethodType(get_name, task) # type: ignore[attr-defined]
task.set_name(f"{self.__class__.__name__}:{name}") # type: ignore[attr-defined]
assert isinstance(task, (Task, Future))
def done_cb(future: Future) -> None:
self._pending_tasks.pop(name, None)
try:
future.result()
except CancelledError:
pass
except ignore as e: # type: ignore[misc]
self._logger.exception("Task resulted in error: %s\n%s", e, "".join(traceback.format_exc()))
self._pending_tasks[name] = task
task.add_done_callback(done_cb)
return task
def register_anonymous_task(self, basename: str, task: Callable | Coroutine | Future,
*args: Any, **kwargs) -> Future: # noqa: ANN401
"""
Wrapper for register_task to derive a unique name from the basename.
"""
self._counter += 1
return self.register_task(basename + " " + str(self._counter), task, *args, **kwargs)
def register_executor_task(self, name: str, func: Callable, *args: Any, # noqa: ANN401
executor: ThreadPoolExecutor | None = None, anon: bool = False, **kwargs) -> Future:
"""
Run a synchronous function on the Asyncio threadpool. This function does not work with async functions.
"""
if not callable(func) or iscoroutinefunction(func):
msg = "Expected a non-async function as a parameter"
raise TypeError(msg)
future = get_running_loop().run_in_executor(executor, func, *args, **kwargs)
if anon:
return self.register_anonymous_task(name, future)
return self.register_task(name, future)
def register_shutdown_task(self, task: Callable | Coroutine, *args: Any, **kwargs) -> None: # noqa: ANN401
"""
Register a task to be run when this manager is shut down.
"""
self._shutdown_tasks.append((task, args, kwargs))
def cancel_pending_task(self, name: Hashable) -> Future:
"""
Cancels the named task.
"""
with self._task_lock:
task = self._pending_tasks.get(name, None)
if not task:
return succeed(None)
if not task.done():
task.cancel()
self._pending_tasks.pop(name, None)
return task
def cancel_all_pending_tasks(self) -> list[Future]:
"""
Cancels all the registered tasks.
This usually should be called when stopping or destroying the object so no tasks are left floating around.
"""
with self._task_lock:
assert all(isinstance(t, (Task, Future)) for t in self._pending_tasks.values()), self._pending_tasks
return [self.cancel_pending_task(name) for name in list(self._pending_tasks.keys())]
def is_pending_task_active(self, name: Hashable) -> bool:
"""
Return a boolean determining if a task is active.
"""
with self._task_lock:
task = self._pending_tasks.get(name, None)
return not task.done() if task else False
def get_task(self, name: Hashable) -> Future | None:
"""
Return a task if it exists. Otherwise, return None.
"""
with self._task_lock:
return self._pending_tasks.get(name, None)
def get_tasks(self) -> list[Future]:
"""
Returns a list of all registered tasks, excluding tasks the are created by the TaskManager itself.
"""
with self._task_lock:
return [t for t in self._pending_tasks.values() if t != self._checker]
def get_anonymous_tasks(self, base_name: str) -> list[Future]:
"""
Return all tasks with a given base name.
Note that this method will return ALL tasks that start with the given base name, including non-anonymous ones.
"""
with self._task_lock:
return [t[1] for t in self._pending_tasks.items() if isinstance(t[0], str) and t[0].startswith(base_name)]
async def wait_for_tasks(self) -> None:
"""
Waits until all registered tasks are done.
"""
tasks = self.get_tasks()
if tasks:
await gather(*tasks, return_exceptions=True)
async def shutdown_task_manager(self) -> None:
"""
Clear the task manager, cancel all pending tasks and disallow new tasks being added.
"""
if self._shutdown:
return
with self._task_lock:
self._shutdown = True
tasks = self.cancel_all_pending_tasks()
if tasks:
with suppress(CancelledError):
await gather(*tasks)
for post_shutdown_task, args, kwargs in self._shutdown_tasks:
if iscoroutinefunction(post_shutdown_task):
await post_shutdown_task(*args, **kwargs)
elif callable(post_shutdown_task): # This is not necessary, but Mypy wants it here
post_shutdown_task(*args, **kwargs)
self._shutdown_tasks = []
# Public API of this module: the manager class and the ``task`` decorator.
__all__ = [
    "TaskManager",
    "task",
]