diff --git a/api/api/utils/aiohttp.py b/api/api/utils/aiohttp.py new file mode 100644 index 00000000000..ef2e3b65f53 --- /dev/null +++ b/api/api/utils/aiohttp.py @@ -0,0 +1,62 @@ +import asyncio +import logging +import weakref + +import aiohttp + +from conf.asgi import application + + +logger = logging.getLogger(__name__) + + +_SESSIONS: weakref.WeakKeyDictionary[ + asyncio.AbstractEventLoop, aiohttp.ClientSession +] = weakref.WeakKeyDictionary() + +_LOCKS: weakref.WeakKeyDictionary[ + asyncio.AbstractEventLoop, asyncio.Lock +] = weakref.WeakKeyDictionary() + + +async def get_aiohttp_session() -> aiohttp.ClientSession: + """ + Safely retrieve a shared aiohttp session for the current event loop. + + If the loop already has an aiohttp session associated, it will be reused. + If the loop has not yet had an aiohttp session created for it, a new one + will be created and returned. + + While the main application will always run in the same loop, and while + that covers 99% of our use cases, it is still possible for `async_to_sync` + to cause a new loop to be created if, for example, `force_new_loop` is + passed. In order to prevent surprises should that ever be the case, this + function assumes that it's possible for multiple loops to be present in + the lifetime of the application and therefore we need to verify that each + loop gets its own session. + """ + + loop = asyncio.get_running_loop() + + if loop not in _LOCKS: + _LOCKS[loop] = asyncio.Lock() + + async with _LOCKS[loop]: + if loop not in _SESSIONS: + create_session = True + msg = "No session for loop. Creating new session." + elif _SESSIONS[loop].closed: + create_session = True + msg = "Loop's previous session closed. Creating new session." + else: + create_session = False + msg = "Reusing existing session for loop." + + logger.info(msg) + + if create_session: + session = aiohttp.ClientSession() + application.register_shutdown_handler(session.close) + _SESSIONS[loop] = session + + return _SESSIONS[loop] diff --git a/api/api/utils/check_dead_links/__init__.py b/api/api/utils/check_dead_links/__init__.py index 0e586f95fd0..c1a0f721467 100644 --- a/api/api/utils/check_dead_links/__init__.py +++ b/api/api/utils/check_dead_links/__init__.py @@ -10,6 +10,7 @@ from decouple import config from elasticsearch_dsl.response import Hit +from api.utils.aiohttp import get_aiohttp_session from api.utils.check_dead_links.provider_status_mappings import provider_status_mappings from api.utils.dead_link_mask import get_query_mask, save_query_mask @@ -32,9 +33,15 @@ def _get_expiry(status, default): return config(f"LINK_VALIDATION_CACHE_EXPIRY__{status}", default=default, cast=int) +_timeout = aiohttp.ClientTimeout(total=2) + + async def _head(url: str, session: aiohttp.ClientSession) -> tuple[str, int]: try: - async with session.head(url, allow_redirects=False) as response: + request = session.head( + url, allow_redirects=False, headers=HEADERS, timeout=_timeout + ) + async with request as response: return url, response.status except (aiohttp.ClientError, asyncio.TimeoutError) as exception: _log_validation_failure(exception) @@ -45,11 +52,10 @@ async def _head(url: str, session: aiohttp.ClientSession) -> tuple[str, int]: @async_to_sync async def _make_head_requests(urls: list[str]) -> list[tuple[str, int]]: tasks = [] - timeout = aiohttp.ClientTimeout(total=2) - async with aiohttp.ClientSession(headers=HEADERS, timeout=timeout) as session: - tasks = [asyncio.ensure_future(_head(url, session)) for url in urls] - responses = asyncio.gather(*tasks) - await responses + session = await get_aiohttp_session() + tasks = [asyncio.ensure_future(_head(url, session)) for url in urls] + responses = asyncio.gather(*tasks) + await responses return responses.result() diff --git a/api/conf/asgi_handler.py b/api/conf/asgi_handler.py index 7f397c7d363..ed68f94e3f6 100644 --- a/api/conf/asgi_handler.py +++ b/api/conf/asgi_handler.py @@ -29,8 +29,12 @@ class OpenverseASGIHandler(ASGIHandler): def __init__(self): super().__init__() self._on_shutdown: list[weakref.WeakMethod | weakref.ref] = [] + self.has_shutdown = False def _clean_ref(self, ref): + if self.has_shutdown: + return + self.logger.info("Cleaning up a ref") self._on_shutdown.remove(ref) @@ -73,3 +77,4 @@ async def shutdown(self): handler() self.logger.info(f"Executed {live_handlers} handler(s) before shutdown.") + self.has_shutdown = True diff --git a/api/test/conftest.py b/api/test/conftest.py new file mode 100644 index 00000000000..01f4e1a4f31 --- /dev/null +++ b/api/test/conftest.py @@ -0,0 +1,19 @@ +import pytest +from asgiref.sync import async_to_sync + +from conf.asgi import application + + +@pytest.fixture(scope="session", autouse=True) +def call_application_shutdown(): + """ + Call application shutdown during test session teardown. + + This cannot be an async fixture because the scope is session + and pytest-asynio's `event_loop` fixture, which is auto-used + for async tests and fixtures, is function scoped, which is + incomatible with session scoped fixtures. `async_to_sync` works + fine here, so it's not a problem. + """ + yield + async_to_sync(application.shutdown)() diff --git a/api/test/unit/utils/test_check_dead_links.py b/api/test/unit/utils/test_check_dead_links.py index 915c843f28d..d047c8ab50d 100644 --- a/api/test/unit/utils/test_check_dead_links.py +++ b/api/test/unit/utils/test_check_dead_links.py @@ -1,16 +1,14 @@ import asyncio from unittest import mock -import aiohttp import pook import pytest from api.utils.check_dead_links import HEADERS, check_dead_links -@mock.patch.object(aiohttp, "ClientSession", wraps=aiohttp.ClientSession) @pook.on -def test_sends_user_agent(wrapped_client_session: mock.AsyncMock): +def test_sends_user_agent(): query_hash = "test_sends_user_agent" results = [{"provider": "best_provider_ever"} for _ in range(40)] image_urls = [f"https://example.org/{i}" for i in range(len(results))] @@ -18,6 +16,7 @@ def test_sends_user_agent(wrapped_client_session: mock.AsyncMock): head_mock = ( pook.head(pook.regex(r"https://example.org/\d")) + .headers(HEADERS) .times(len(results)) .reply(200) .mock @@ -30,8 +29,6 @@ def test_sends_user_agent(wrapped_client_session: mock.AsyncMock): for url in image_urls: assert url in requested_urls - wrapped_client_session.assert_called_once_with(headers=HEADERS, timeout=mock.ANY) - def test_handles_timeout(): """