Skip to content

Commit

Permalink
Add aiohttp client sharing
Browse files Browse the repository at this point in the history
  • Loading branch information
sarayourfriend committed Sep 14, 2023
1 parent 2d9485f commit 152040a
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 11 deletions.
62 changes: 62 additions & 0 deletions api/api/utils/aiohttp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import asyncio
import logging
import weakref

import aiohttp

from conf.asgi import application


logger = logging.getLogger(__name__)


_SESSIONS: weakref.WeakKeyDictionary[
asyncio.AbstractEventLoop, aiohttp.ClientSession
] = weakref.WeakKeyDictionary()

_LOCKS: weakref.WeakKeyDictionary[
asyncio.AbstractEventLoop, asyncio.Lock
] = weakref.WeakKeyDictionary()


async def get_aiohttp_session() -> aiohttp.ClientSession:
"""
Safely retrieve a shared aiohttp session for the current event loop.
If the loop already has an aiohttp session associated, it will be reused.
If the loop has not yet had an aiohttp session created for it, a new one
will be created and returned.
While the main application will always run in the same loop, and while
that covers 99% of our use cases, it is still possible for `async_to_sync`
to cause a new loop to be created if, for example, `force_new_loop` is
passed. In order to prevent surprises should that ever be the case, this
function assumes that it's possible for multiple loops to be present in
the lifetime of the application and therefore we need to verify that each
loop gets its own session.
"""

loop = asyncio.get_running_loop()

if loop not in _LOCKS:
_LOCKS[loop] = asyncio.Lock()

async with _LOCKS[loop]:
if loop not in _SESSIONS:
create_session = True
msg = "No session for loop. Creating new session."
elif _SESSIONS[loop].closed:
create_session = True
msg = "Loop's previous session closed. Creating new session."
else:
create_session = False
msg = "Reusing existing session for loop."

logger.info(msg)

if create_session:
session = aiohttp.ClientSession()
application.register_shutdown_handler(session.close)
_SESSIONS[loop] = session

return _SESSIONS[loop]
18 changes: 12 additions & 6 deletions api/api/utils/check_dead_links/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from decouple import config
from elasticsearch_dsl.response import Hit

from api.utils.aiohttp import get_aiohttp_session
from api.utils.check_dead_links.provider_status_mappings import provider_status_mappings
from api.utils.dead_link_mask import get_query_mask, save_query_mask

Expand All @@ -32,9 +33,15 @@ def _get_expiry(status, default):
return config(f"LINK_VALIDATION_CACHE_EXPIRY__{status}", default=default, cast=int)


_timeout = aiohttp.ClientTimeout(total=2)


async def _head(url: str, session: aiohttp.ClientSession) -> tuple[str, int]:
try:
async with session.head(url, allow_redirects=False) as response:
request = session.head(
url, allow_redirects=False, headers=HEADERS, timeout=_timeout
)
async with request as response:
return url, response.status
except (aiohttp.ClientError, asyncio.TimeoutError) as exception:
_log_validation_failure(exception)
Expand All @@ -45,11 +52,10 @@ async def _head(url: str, session: aiohttp.ClientSession) -> tuple[str, int]:
@async_to_sync
async def _make_head_requests(urls: list[str]) -> list[tuple[str, int]]:
tasks = []
timeout = aiohttp.ClientTimeout(total=2)
async with aiohttp.ClientSession(headers=HEADERS, timeout=timeout) as session:
tasks = [asyncio.ensure_future(_head(url, session)) for url in urls]
responses = asyncio.gather(*tasks)
await responses
session = await get_aiohttp_session()
tasks = [asyncio.ensure_future(_head(url, session)) for url in urls]
responses = asyncio.gather(*tasks)
await responses
return responses.result()


Expand Down
5 changes: 5 additions & 0 deletions api/conf/asgi_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,12 @@ class OpenverseASGIHandler(ASGIHandler):
def __init__(self):
super().__init__()
self._on_shutdown: list[weakref.WeakMethod | weakref.ref] = []
self.has_shutdown = False

def _clean_ref(self, ref):
if self.has_shutdown:
return

self.logger.info("Cleaning up a ref")
self._on_shutdown.remove(ref)

Expand Down Expand Up @@ -73,3 +77,4 @@ async def shutdown(self):
handler()

self.logger.info(f"Executed {live_handlers} handler(s) before shutdown.")
self.has_shutdown = True
19 changes: 19 additions & 0 deletions api/test/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import pytest
from asgiref.sync import async_to_sync

from conf.asgi import application


@pytest.fixture(scope="session", autouse=True)
def call_application_shutdown():
"""
Call application shutdown during test session teardown.
This cannot be an async fixture because the scope is session
and pytest-asynio's `event_loop` fixture, which is auto-used
for async tests and fixtures, is function scoped, which is
incomatible with session scoped fixtures. `async_to_sync` works
fine here, so it's not a problem.
"""
yield
async_to_sync(application.shutdown)()
7 changes: 2 additions & 5 deletions api/test/unit/utils/test_check_dead_links.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
import asyncio
from unittest import mock

import aiohttp
import pook
import pytest

from api.utils.check_dead_links import HEADERS, check_dead_links


@mock.patch.object(aiohttp, "ClientSession", wraps=aiohttp.ClientSession)
@pook.on
def test_sends_user_agent(wrapped_client_session: mock.AsyncMock):
def test_sends_user_agent():
query_hash = "test_sends_user_agent"
results = [{"provider": "best_provider_ever"} for _ in range(40)]
image_urls = [f"https://example.org/{i}" for i in range(len(results))]
start_slice = 0

head_mock = (
pook.head(pook.regex(r"https://example.org/\d"))
.headers(HEADERS)
.times(len(results))
.reply(200)
.mock
Expand All @@ -30,8 +29,6 @@ def test_sends_user_agent(wrapped_client_session: mock.AsyncMock):
for url in image_urls:
assert url in requested_urls

wrapped_client_session.assert_called_once_with(headers=HEADERS, timeout=mock.ANY)


def test_handles_timeout():
"""
Expand Down

0 comments on commit 152040a

Please sign in to comment.