From c34979e568953016ad46392c9f34bba19a034b9d Mon Sep 17 00:00:00 2001 From: Quinten Stokkink Date: Mon, 16 Sep 2024 11:25:20 +0200 Subject: [PATCH] Fixed InvalidStateError on shutdown --- src/tribler/core/restapi/events_endpoint.py | 8 ++++- .../core/restapi/test_events_endpoint.py | 29 ++++++++++++++++++- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/src/tribler/core/restapi/events_endpoint.py b/src/tribler/core/restapi/events_endpoint.py index 675a7136a4..bee86db43f 100644 --- a/src/tribler/core/restapi/events_endpoint.py +++ b/src/tribler/core/restapi/events_endpoint.py @@ -2,7 +2,7 @@ import json import time -from asyncio import CancelledError, Event, Queue +from asyncio import CancelledError, Event, Future, Queue from contextlib import suppress from traceback import format_exception from typing import TYPE_CHECKING, TypedDict @@ -251,6 +251,12 @@ async def get_events(self, request: Request) -> web.StreamResponse: else: self._logger.info("Event stream was closed due to shutdown") + # A ``shutdown()`` on our parent may have cancelled ``_handler_waiter`` before this method returns. + # If we leave this be, an error will be raised if the ``Future`` result is set after this method returns. + # See: https://github.com/Tribler/tribler/issues/8156 + if request.protocol._handler_waiter and request.protocol._handler_waiter.cancelled(): # noqa: SLF001 + request.protocol._handler_waiter = Future() # noqa: SLF001 + # See: https://github.com/Tribler/tribler/pull/7906 with suppress(ValueError): self.events_responses.remove(response) diff --git a/src/tribler/test_unit/core/restapi/test_events_endpoint.py b/src/tribler/test_unit/core/restapi/test_events_endpoint.py index 2ed08dea99..d30d740d8e 100644 --- a/src/tribler/test_unit/core/restapi/test_events_endpoint.py +++ b/src/tribler/test_unit/core/restapi/test_events_endpoint.py @@ -1,4 +1,4 @@ -from asyncio import ensure_future, sleep +from asyncio import Future, ensure_future, sleep from aiohttp.abc import AbstractStreamWriter from ipv8.test.base import TestBase @@ -19,8 +19,21 @@ def __init__(self, endpoint: EventsEndpoint, count: int = 1) -> None: Create a new GetEventsRequest. """ self.payload_writer = MockStreamWriter(endpoint, count=count) + self._handler_waiter = Future() super().__init__({}, "GET", "/api/events", payload_writer=self.payload_writer) + def shutdown(self) -> None: + """ + Mimic a shutdown. + """ + self._handler_waiter.cancel() + + def finish_handler(self) -> None: + """ + Mimic finishing a handler. + """ + self._handler_waiter.set_result(None) + class MockStreamWriter(AbstractStreamWriter): """ @@ -214,3 +227,17 @@ async def test_no_forward_illegal_notification(self) -> None: self.assertEqual((b'event: tribler_new_version\n' b'data: {"version": "super cool version"}' b'\n\n'), request.payload_writer.captured[1]) + + async def test_shutdown_parent_before_event(self) -> None: + """ + Test if a parent shutdown does not cause errors after handling a child. + """ + request = GetEventsRequest(self.endpoint, count=3) # Blocks until shutdown + response_future = ensure_future(self.endpoint.get_events(request)) + + request.shutdown() # 1. The parent protocol is shut down + self.endpoint.shutdown_event.set() # 2. Tribler signals shutdown to the events endpoint + response = await response_future + request.finish_handler() # 3. aiohttp behavior: finish the request handling + + self.assertEqual(200, response.status)