Skip to content

Commit

Permalink
Use more modern asyncio apis
Browse files Browse the repository at this point in the history
This utilises the taskgroup backport to allow usage of these apis in
pre Python 3.11 code.

This should mean the code is more reliable, and robust. It definetly
means it is clearer.

Note I've switched from WeakSet to Set and the after_done_callback as
the Python TaskGroup code uses Set. I hope this could fix and explain
a possible memory leak reported by some users.
  • Loading branch information
pgjones committed Oct 28, 2023
1 parent 3dc7908 commit 8133958
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 107 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@ documentation = "https://hypercorn.readthedocs.io"
[tool.poetry.dependencies]
python = ">=3.7"
aioquic = { version = ">= 0.9.0, < 1.0", optional = true }
exceptiongroup = { version = ">= 1.1.0", python = "<3.11", optional = true }
exceptiongroup = ">= 1.1.0"
h11 = "*"
h2 = ">=3.1.0"
priority = "*"
pydata_sphinx_theme = { version = "*", optional = true }
sphinxcontrib_mermaid = { version = "*", optional = true }
taskgroup = { version = "*", python = "<3.11", allow-prereleases = true }
tomli = { version = "*", python = "<3.11" }
trio = { version = ">=0.22.0", optional = true }
uvloop = { version = "*", markers = "platform_system != 'Windows'", optional = true }
Expand Down
87 changes: 18 additions & 69 deletions src/hypercorn/asyncio/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from multiprocessing.synchronize import Event as EventType
from os import getpid
from socket import socket
from typing import Any, Awaitable, Callable, Optional
from weakref import WeakSet
from typing import Any, Awaitable, Callable, Optional, Set

from .lifespan import Lifespan
from .statsd import StatsdLogger
Expand All @@ -26,13 +25,10 @@
ShutdownError,
)


async def _windows_signal_support() -> None:
# See https://bugs.python.org/issue23057, to catch signals on
# Windows it is necessary for an IO event to happen periodically.
# Fixed by Python 3.8
while True:
await asyncio.sleep(1)
try:
from asyncio import Runner
except ImportError:
from taskgroup import Runner # type: ignore


def _share_socket(sock: socket) -> socket:
Expand Down Expand Up @@ -89,10 +85,14 @@ def _signal_handler(*_: Any) -> None: # noqa: N803
ssl_handshake_timeout = config.ssl_handshake_timeout

context = WorkerContext()
server_tasks: WeakSet = WeakSet()
server_tasks: Set[asyncio.Task] = set()

async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
server_tasks.add(asyncio.current_task(loop))
nonlocal server_tasks

task = asyncio.current_task(loop)
server_tasks.add(task)
task.add_done_callback(server_tasks.discard)
await TCPServer(app, loop, config, context, reader, writer)

servers = []
Expand Down Expand Up @@ -129,22 +129,14 @@ async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamW
_, protocol = await loop.create_datagram_endpoint(
lambda: UDPServer(app, loop, config, context), sock=sock
)
server_tasks.add(loop.create_task(protocol.run()))
task = loop.create_task(protocol.run())
server_tasks.add(task)
task.add_done_callback(server_tasks.discard)
bind = repr_socket_addr(sock.family, sock.getsockname())
await config.log.info(f"Running on https://{bind} (QUIC) (CTRL + C to quit)")

tasks = []
if platform.system() == "Windows":
tasks.append(loop.create_task(_windows_signal_support()))

tasks.append(loop.create_task(raise_shutdown(shutdown_trigger)))

try:
if len(tasks):
gathered_tasks = asyncio.gather(*tasks)
await gathered_tasks
else:
loop.run_forever()
await raise_shutdown(shutdown_trigger)
except (ShutdownError, KeyboardInterrupt):
pass
finally:
Expand All @@ -154,10 +146,6 @@ async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamW
server.close()
await server.wait_closed()

# Retrieve the Gathered Tasks Cancelled Exception, to
# prevent a warning that this hasn't been done.
gathered_tasks.exception()

try:
gathered_server_tasks = asyncio.gather(*server_tasks)
await asyncio.wait_for(gathered_server_tasks, config.graceful_timeout)
Expand Down Expand Up @@ -221,48 +209,9 @@ def _run(
debug: bool = False,
shutdown_trigger: Optional[Callable[..., Awaitable[None]]] = None,
) -> None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.set_debug(debug)
loop.set_exception_handler(_exception_handler)

try:
loop.run_until_complete(main(shutdown_trigger=shutdown_trigger))
except KeyboardInterrupt:
pass
finally:
try:
_cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())

try:
loop.run_until_complete(loop.shutdown_default_executor())
except AttributeError:
pass # shutdown_default_executor is new to Python 3.9

finally:
asyncio.set_event_loop(None)
loop.close()


def _cancel_all_tasks(loop: asyncio.AbstractEventLoop) -> None:
tasks = [task for task in asyncio.all_tasks(loop) if not task.done()]
if not tasks:
return

for task in tasks:
task.cancel()
loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))

for task in tasks:
if not task.cancelled() and task.exception() is not None:
loop.call_exception_handler(
{
"message": "unhandled exception during shutdown",
"exception": task.exception(),
"task": task,
}
)
with Runner(debug=debug) as runner:
runner.get_loop().set_exception_handler(_exception_handler)
runner.run(main(shutdown_trigger=shutdown_trigger))


def _exception_handler(loop: asyncio.AbstractEventLoop, context: dict) -> None:
Expand Down
32 changes: 9 additions & 23 deletions src/hypercorn/asyncio/task_group.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from __future__ import annotations

import asyncio
import weakref
from functools import partial
from types import TracebackType
from typing import Any, Awaitable, Callable, Optional

from ..config import Config
from ..typing import AppWrapper, ASGIReceiveCallable, ASGIReceiveEvent, ASGISendEvent, Scope

try:
from asyncio import TaskGroup as AsyncioTaskGroup
except ImportError:
from taskgroup import TaskGroup as AsyncioTaskGroup # type: ignore


async def _handle(
app: AppWrapper,
Expand All @@ -32,8 +36,7 @@ async def _handle(
class TaskGroup:
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._tasks: weakref.WeakSet = weakref.WeakSet()
self._exiting = False
self._task_group = AsyncioTaskGroup()

async def spawn_app(
self,
Expand Down Expand Up @@ -61,28 +64,11 @@ def _call_soon(func: Callable, *args: Any) -> Any:
return app_queue.put

def spawn(self, func: Callable, *args: Any) -> None:
if self._exiting:
raise RuntimeError("Spawning whilst exiting")
self._tasks.add(self._loop.create_task(func(*args)))
self._task_group.create_task(func(*args))

async def __aenter__(self) -> "TaskGroup":
await self._task_group.__aenter__()
return self

async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None:
self._exiting = True
if exc_type is not None:
self._cancel_tasks()

try:
task = asyncio.gather(*self._tasks)
await task
finally:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass

def _cancel_tasks(self) -> None:
for task in self._tasks:
task.cancel()
await self._task_group.__aexit__(exc_type, exc_value, tb)
14 changes: 0 additions & 14 deletions tests/asyncio/test_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,3 @@ async def _error_app(scope: Scope, receive: Callable, send: Callable) -> None:
async with TaskGroup(event_loop) as task_group:
await task_group.spawn_app(ASGIWrapper(_error_app), Config(), http_scope, app_queue.put)
assert (await app_queue.get()) is None


@pytest.mark.asyncio
async def test_spawn_app_cancelled(
event_loop: asyncio.AbstractEventLoop, http_scope: HTTPScope
) -> None:
async def _error_app(scope: Scope, receive: Callable, send: Callable) -> None:
raise asyncio.CancelledError()

app_queue: asyncio.Queue = asyncio.Queue()
with pytest.raises(asyncio.CancelledError):
async with TaskGroup(event_loop) as task_group:
await task_group.spawn_app(ASGIWrapper(_error_app), Config(), http_scope, app_queue.put)
assert (await app_queue.get()) is None

0 comments on commit 8133958

Please sign in to comment.