diff --git a/api/api/controllers/search_controller.py b/api/api/controllers/search_controller.py
index f7098256ea8..b7420ca02ef 100644
--- a/api/api/controllers/search_controller.py
+++ b/api/api/controllers/search_controller.py
@@ -98,9 +98,8 @@ def _post_process_results(
     results = list(search_results)
 
     if filter_dead:
-        to_validate = [res.url for res in search_results]
         query_hash = get_query_hash(s)
-        check_dead_links(query_hash, start, results, to_validate)
+        check_dead_links(query_hash, start, results)
 
         if len(results) == 0:
             # first page is all dead links
diff --git a/api/api/utils/check_dead_links/__init__.py b/api/api/utils/check_dead_links/__init__.py
index ed65cf03f15..f9abc9de4a0 100644
--- a/api/api/utils/check_dead_links/__init__.py
+++ b/api/api/utils/check_dead_links/__init__.py
@@ -25,25 +25,29 @@
 }
 
 
-def _get_cached_statuses(redis, image_urls):
+def _get_cached_statuses(redis, urls):
     try:
-        cached_statuses = redis.mget([CACHE_PREFIX + url for url in image_urls])
+        cached_statuses = redis.mget([CACHE_PREFIX + url for url in urls])
         return [
             int(b.decode("utf-8")) if b is not None else None for b in cached_statuses
         ]
     except ConnectionError:
         logger.warning("Redis connect failed, validating all URLs without cache.")
-        return [None] * len(image_urls)
+        return [None] * len(urls)
 
 
 def _get_expiry(status, default):
     return config(f"LINK_VALIDATION_CACHE_EXPIRY__{status}", default=default, cast=int)
 
 
-_timeout = aiohttp.ClientTimeout(total=2)
+_timeout = aiohttp.ClientTimeout(total=settings.LINK_VALIDATION_TIMEOUT_SECONDS)
+_TIMEOUT_STATUS = -2
+_ERROR_STATUS = -1
 
 
-async def _head(url: str, session: aiohttp.ClientSession) -> tuple[str, int]:
+async def _head(
+    url: str, session: aiohttp.ClientSession, provider: str
+) -> tuple[str, int]:
     start_time = time.perf_counter()
 
     try:
@@ -52,8 +56,11 @@ async def _head(url: str, session: aiohttp.ClientSession) -> tuple[str, int]:
         )
         status = response.status
     except (aiohttp.ClientError, asyncio.TimeoutError) as exception:
-        _log_validation_failure(exception)
-        status = -1
+        if not isinstance(exception, asyncio.TimeoutError):
+            logger.error("dead_link_validation_error", e=exception)
+            status = _ERROR_STATUS
+        else:
+            status = _TIMEOUT_STATUS
 
     end_time = time.perf_counter()
     logger.info(
@@ -61,6 +68,7 @@
         url=url,
         status=status,
         time=end_time - start_time,
+        provider=provider,
     )
 
     return url, status
@@ -68,17 +76,28 @@ async def _head(url: str, session: aiohttp.ClientSession) -> tuple[str, int]:
 
 
 # https://stackoverflow.com/q/55259755
 @async_to_sync
-async def _make_head_requests(urls: list[str]) -> list[tuple[str, int]]:
+async def _make_head_requests(
+    urls: dict[str, int], results: list[Hit]
+) -> list[tuple[str, int]]:
+    """
+    Concurrently HEAD request the urls.
+
+    ``urls`` must map to the index of the corresponding result in ``results``.
+
+    :param urls: A dictionary with keys of the URLs to request, mapped to the index of that url in ``results``
+    :param results: The ordered list of results, including ones not being validated.
+    """
     session = await get_aiohttp_session()
-    tasks = [asyncio.ensure_future(_head(url, session)) for url in urls]
+    tasks = [
+        asyncio.ensure_future(_head(url, session, results[idx].provider))
+        for url, idx in urls.items()
+    ]
     responses = asyncio.gather(*tasks)
     await responses
     return responses.result()
 
 
-def check_dead_links(
-    query_hash: str, start_slice: int, results: list[Hit], image_urls: list[str]
-) -> None:
+def check_dead_links(query_hash: str, start_slice: int, results: list[Hit]) -> None:
     """
     Make sure images exist before we display them.
@@ -88,26 +107,28 @@ def check_dead_links(
     Results are cached in redis and shared amongst all API servers in the
     cluster.
     """
-    if not image_urls:
-        logger.info("no image urls to validate")
+    if not results:
+        logger.info("link_validation_empty_results")
         return
 
+    urls = [result.url for result in results]
+
     logger.debug("starting validation")
     start_time = time.time()
 
     # Pull matching images from the cache.
     redis = django_redis.get_redis_connection("default")
-    cached_statuses = _get_cached_statuses(redis, image_urls)
+    cached_statuses = _get_cached_statuses(redis, urls)
     logger.debug(f"len(cached_statuses)={len(cached_statuses)}")
 
     # Anything that isn't in the cache needs to be validated via HEAD request.
     to_verify = {}
-    for idx, url in enumerate(image_urls):
+    for idx, url in enumerate(urls):
         if cached_statuses[idx] is None:
             to_verify[url] = idx
     logger.debug(f"len(to_verify)={len(to_verify)}")
 
-    verified = _make_head_requests(to_verify.keys())
+    verified = _make_head_requests(to_verify, results)
 
     # Cache newly verified image statuses.
     to_cache = {CACHE_PREFIX + url: status for url, status in verified}
@@ -119,7 +140,7 @@
     for key, status in to_cache.items():
         if status == 200:
             logger.debug(f"healthy link key={key}")
-        elif status == -1:
+        elif status == _TIMEOUT_STATUS or status == _ERROR_STATUS:
             logger.debug(f"no response from provider key={key}")
         else:
             logger.debug(f"broken link key={key}")
@@ -152,7 +173,7 @@
         if status in status_mapping.unknown:
             logger.warning(
                 "Image validation failed due to rate limiting or blocking. "
-                f"url={image_urls[idx]} "
+                f"url={urls[idx]} "
                 f"status={status} "
                 f"provider={provider} "
             )
@@ -184,7 +205,3 @@
         f"start_time={start_time} "
         f"delta={end_time - start_time} "
     )
-
-
-def _log_validation_failure(exception):
-    logger.warning(f"Failed to validate image! Reason: {exception}")
diff --git a/api/conf/settings/__init__.py b/api/conf/settings/__init__.py
index f52fd936035..e4848a53d6f 100644
--- a/api/conf/settings/__init__.py
+++ b/api/conf/settings/__init__.py
@@ -42,7 +42,7 @@
     "spectacular.py",
     "thumbnails.py",
     # Openverse-specific settings
-    "link_validation_cache.py",
+    "link_validation.py",
     "misc.py",
     "openverse.py",
 )
diff --git a/api/conf/settings/link_validation_cache.py b/api/conf/settings/link_validation.py
similarity index 95%
rename from api/conf/settings/link_validation_cache.py
rename to api/conf/settings/link_validation.py
index c872c220b1b..e440e0643c1 100644
--- a/api/conf/settings/link_validation_cache.py
+++ b/api/conf/settings/link_validation.py
@@ -12,6 +12,11 @@
 logger = structlog.get_logger(__name__)
 
 
+LINK_VALIDATION_TIMEOUT_SECONDS = config(
+    "LINK_VALIDATION_TIMEOUT_SECONDS", default=0.8, cast=float
+)
+
+
 class LinkValidationCacheExpiryConfiguration(defaultdict):
     """Link validation cache expiry configuration."""
 
diff --git a/api/test/factory/es_http.py b/api/test/factory/es_http.py
index d784271db9a..8f6edb1b959 100644
--- a/api/test/factory/es_http.py
+++ b/api/test/factory/es_http.py
@@ -6,7 +6,11 @@
 
 
 def create_mock_es_http_image_hit(
-    _id: str, index: str, live: bool = True, identifier: str | None = None
+    _id: str,
+    index: str,
+    live: bool = True,
+    identifier: str | None = None,
+    **additional_fields,
 ):
     return {
         "_index": index,
@@ -39,7 +43,8 @@ def create_mock_es_http_image_hit(
             "created_on": "2022-02-26T08:48:33+00:00",
             "tags": [{"name": "bird"}],
             "mature": False,
-        },
+        }
+        | additional_fields,
     }
 
 
diff --git a/api/test/integration/test_dead_link_filter.py b/api/test/integration/test_dead_link_filter.py
index 1b7e7a6ca3e..ec7a36efa9c 100644
--- a/api/test/integration/test_dead_link_filter.py
+++ b/api/test/integration/test_dead_link_filter.py
@@ -32,7 +32,7 @@ def get_empty_cached_statuses(_, image_urls):
 
 
 def _patch_make_head_requests():
-    def _make_head_requests(urls):
+    def _make_head_requests(urls, *args, **kwargs):
         responses = []
         for idx, url in enumerate(urls):
             status_code = 200 if idx % 10 != 0 else 404
@@ -45,7 +45,7 @@ def _make_head_requests(urls):
 def patch_link_validation_dead_for_count(count):
     total_res_count = 0
 
-    def _make_head_requests(urls):
+    def _make_head_requests(urls, *args, **kwargs):
         nonlocal total_res_count
         responses = []
         for idx, url in enumerate(urls):
diff --git a/api/test/unit/configuration/test_link_validation_cache.py b/api/test/unit/configuration/test_link_validation_cache.py
index 8735f46f116..8ca1e76571a 100644
--- a/api/test/unit/configuration/test_link_validation_cache.py
+++ b/api/test/unit/configuration/test_link_validation_cache.py
@@ -4,7 +4,7 @@
 
 import pytest
 
-from conf.settings.link_validation_cache import LinkValidationCacheExpiryConfiguration
+from conf.settings.link_validation import LinkValidationCacheExpiryConfiguration
 
 
 @pytest.mark.parametrize(
diff --git a/api/test/unit/controllers/test_search_controller.py b/api/test/unit/controllers/test_search_controller.py
index 2e4b9d650f8..d8ff1da43ed 100644
--- a/api/test/unit/controllers/test_search_controller.py
+++ b/api/test/unit/controllers/test_search_controller.py
@@ -810,7 +810,8 @@ def test_excessive_recursion_in_post_process(
     redis,
     caplog,
 ):
-    def _delete_all_results_but_first(_, __, results, ___):
+    def _delete_all_results_but_first(*args):
+        results = args[2]
         results[1:] = []
 
     mock_check_dead_links.side_effect = _delete_all_results_but_first
diff --git a/api/test/unit/utils/test_check_dead_links.py b/api/test/unit/utils/test_check_dead_links.py
index 2c99238d35d..344bb94e9b2 100644
--- a/api/test/unit/utils/test_check_dead_links.py
+++ b/api/test/unit/utils/test_check_dead_links.py
@@ -1,34 +1,46 @@
 import asyncio
+from collections.abc import Callable
+from typing import Any
 
 import pook
 import pytest
 from aiohttp.client import ClientSession
+from elasticsearch_dsl.response import Hit
 from structlog.testing import capture_logs
 
 from api.utils.check_dead_links import HEADERS, check_dead_links
+from test.factory.es_http import create_mock_es_http_image_hit
+
+
+def _make_hits(
+    count: int, gen_fields: Callable[[int], dict[str, Any]] = lambda _: dict()
+):
+    return [
+        Hit(create_mock_es_http_image_hit(_id, "image", live=True, **gen_fields(_id)))
+        for _id in range(40)
+    ]
 
 
 @pook.on
 def test_sends_user_agent():
     query_hash = "test_sends_user_agent"
-    results = [{"provider": "best_provider_ever"} for _ in range(40)]
-    image_urls = [f"https://example.org/{i}" for i in range(len(results))]
+    results = _make_hits(40)
     start_slice = 0
 
     head_mock = (
-        pook.head(pook.regex(r"https://example.org/\d"))
+        pook.head(pook.regex(r"https://example.com/openverse-live-image-result-url/\d"))
         .headers(HEADERS)
         .times(len(results))
         .reply(200)
         .mock
     )
 
-    check_dead_links(query_hash, start_slice, results, image_urls)
+    check_dead_links(query_hash, start_slice, results)
 
     assert head_mock.calls == len(results)
     requested_urls = [req.rawurl for req in head_mock.matches]
-    for url in image_urls:
-        assert url in requested_urls
+    for result in results:
+        assert result.url in requested_urls
 
 
 def test_handles_timeout(monkeypatch):
@@ -39,15 +51,14 @@
     3 seconds.
     """
     query_hash = "test_handles_timeout"
-    results = [{"identifier": i, "provider": "best_provider_ever"} for i in range(1)]
-    image_urls = [f"https://example.org/{i}" for i in range(len(results))]
+    results = _make_hits(1)
     start_slice = 0
 
     async def raise_timeout_error(*args, **kwargs):
         raise asyncio.TimeoutError()
 
     monkeypatch.setattr(ClientSession, "_request", raise_timeout_error)
-    check_dead_links(query_hash, start_slice, results, image_urls)
+    check_dead_links(query_hash, start_slice, results)
 
     # `check_dead_links` directly modifies the results list
     # if the results are timing out then they're considered dead and discarded
@@ -60,22 +71,20 @@ async def raise_timeout_error(*args, **kwargs):
 
 def test_403_considered_dead(provider):
     query_hash = f"test_{provider}_403_considered_dead"
     other_provider = "fake_other_provider"
-    results = [
-        {"identifier": i, "provider": provider if i % 2 else other_provider}
-        for i in range(4)
-    ]
+    results = _make_hits(
+        4, lambda i: {"provider": provider if i % 2 else other_provider}
+    )
     len_results = len(results)
-    image_urls = [f"https://example.org/{i}" for i in range(len(results))]
     start_slice = 0
 
     head_mock = (
-        pook.head(pook.regex(r"https://example.org/\d"))
+        pook.head(pook.regex(r"https://example.com/openverse-live-image-result-url/\d"))
         .times(len(results))
         .reply(403)
         .mock
     )
 
-    check_dead_links(query_hash, start_slice, results, image_urls)
+    check_dead_links(query_hash, start_slice, results)
 
     assert head_mock.calls == len_results
@@ -92,25 +101,24 @@ def test_mset_and_expire_for_responses(is_cache_reachable, cache_name, request):
     cache = request.getfixturevalue(cache_name)
 
     query_hash = "test_mset_and_expiry_for_responses"
-    results = [{"identifier": i, "provider": "best_provider_ever"} for i in range(40)]
-    image_urls = [f"https://example.org/{i}" for i in range(len(results))]
+    results = _make_hits(40)
     start_slice = 0
 
     (
-        pook.head(pook.regex(r"https://example.org/\d"))
+        pook.head(pook.regex(r"https://example.com/openverse-live-image-result-url/\d"))
         .headers(HEADERS)
        .times(len(results))
         .reply(200)
     )
 
     with capture_logs() as cap_logs:
-        check_dead_links(query_hash, start_slice, results, image_urls)
+        check_dead_links(query_hash, start_slice, results)
 
     if is_cache_reachable:
-        for i in range(len(results)):
-            assert cache.get(f"valid:https://example.org/{i}") == b"200"
+        for result in results:
+            assert cache.get(f"valid:{result.url}") == b"200"
             # TTL is 30 days for 2xx responses
-            assert cache.ttl(f"valid:https://example.org/{i}") == 2592000
+            assert cache.ttl(f"valid:{result.url}") == 2592000
     else:
         messages = [record["event"] for record in cap_logs]
         assert all(