Decrease link validation timeout (#4598)
* Lower link validation timeout to 0.8 seconds

Also make it configurable, and rename the existing link validation cache settings module to the more generic link_validation name

* Add provider to link validation timing log

This will help us evaluate whether per-provider timeouts are necessary
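
The timeout is no longer hard-coded into the aiohttp client; it is read from Django settings so deployments can tune it without a code change. A minimal sketch of the new knob, assuming the `config` helper in `conf/settings/link_validation.py` is python-decouple's (its import sits outside the hunk shown in the diff below):

```python
# Sketch of the new setting in api/conf/settings/link_validation.py.
# The decouple import is an assumption; it is not visible in the diff.
from decouple import config

# Defaults to 0.8 seconds; override per environment, e.g. by adding
#   LINK_VALIDATION_TIMEOUT_SECONDS=1.5
# to the API's environment or .env file.
LINK_VALIDATION_TIMEOUT_SECONDS = config(
    "LINK_VALIDATION_TIMEOUT_SECONDS", default=0.8, cast=float
)
```

`check_dead_links` then builds its client timeout from this value via `aiohttp.ClientTimeout(total=settings.LINK_VALIDATION_TIMEOUT_SECONDS)`, as shown in the diff.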
sarayourfriend authored Jul 15, 2024
1 parent 5e5079e commit c4b15ae
Showing 9 changed files with 90 additions and 55 deletions.
3 changes: 1 addition & 2 deletions api/api/controllers/search_controller.py
@@ -98,9 +98,8 @@ def _post_process_results(
results = list(search_results)

if filter_dead:
to_validate = [res.url for res in search_results]
query_hash = get_query_hash(s)
check_dead_links(query_hash, start, results, to_validate)
check_dead_links(query_hash, start, results)

if len(results) == 0:
# first page is all dead links
63 changes: 40 additions & 23 deletions api/api/utils/check_dead_links/__init__.py
@@ -25,25 +25,29 @@
}


def _get_cached_statuses(redis, image_urls):
def _get_cached_statuses(redis, urls):
try:
cached_statuses = redis.mget([CACHE_PREFIX + url for url in image_urls])
cached_statuses = redis.mget([CACHE_PREFIX + url for url in urls])
return [
int(b.decode("utf-8")) if b is not None else None for b in cached_statuses
]
except ConnectionError:
logger.warning("Redis connect failed, validating all URLs without cache.")
return [None] * len(image_urls)
return [None] * len(urls)


def _get_expiry(status, default):
return config(f"LINK_VALIDATION_CACHE_EXPIRY__{status}", default=default, cast=int)


_timeout = aiohttp.ClientTimeout(total=2)
_timeout = aiohttp.ClientTimeout(total=settings.LINK_VALIDATION_TIMEOUT_SECONDS)
_TIMEOUT_STATUS = -2
_ERROR_STATUS = -1


async def _head(url: str, session: aiohttp.ClientSession) -> tuple[str, int]:
async def _head(
url: str, session: aiohttp.ClientSession, provider: str
) -> tuple[str, int]:
start_time = time.perf_counter()

try:
@@ -52,33 +52,48 @@ async def _head(url: str, session: aiohttp.ClientSession) -> tuple[str, int]:
)
status = response.status
except (aiohttp.ClientError, asyncio.TimeoutError) as exception:
_log_validation_failure(exception)
status = -1
if not isinstance(exception, asyncio.TimeoutError):
logger.error("dead_link_validation_error", e=exception)
status = _ERROR_STATUS
else:
status = _TIMEOUT_STATUS

end_time = time.perf_counter()
logger.info(
"dead_link_validation_timing",
url=url,
status=status,
time=end_time - start_time,
provider=provider,
)

return url, status


# https://stackoverflow.com/q/55259755
@async_to_sync
async def _make_head_requests(urls: list[str]) -> list[tuple[str, int]]:
async def _make_head_requests(
urls: dict[str, int], results: list[Hit]
) -> list[tuple[str, int]]:
"""
Concurrently HEAD request the urls.
``urls`` must map to the index of the corresponding result in ``results``.
:param urls: A dictionary with keys of the URLs to request, mapped to the index of that url in ``results``
:param results: The ordered list of results, including ones not being validated.
"""
session = await get_aiohttp_session()
tasks = [asyncio.ensure_future(_head(url, session)) for url in urls]
tasks = [
asyncio.ensure_future(_head(url, session, results[idx].provider))
for url, idx in urls.items()
]
responses = asyncio.gather(*tasks)
await responses
return responses.result()


def check_dead_links(
query_hash: str, start_slice: int, results: list[Hit], image_urls: list[str]
) -> None:
def check_dead_links(query_hash: str, start_slice: int, results: list[Hit]) -> None:
"""
Make sure images exist before we display them.
@@ -88,26 +88,28 @@ def check_dead_links(
Results are cached in redis and shared amongst all API servers in the
cluster.
"""
if not image_urls:
logger.info("no image urls to validate")
if not results:
logger.info("link_validation_empty_results")
return

urls = [result.url for result in results]

logger.debug("starting validation")
start_time = time.time()

# Pull matching images from the cache.
redis = django_redis.get_redis_connection("default")
cached_statuses = _get_cached_statuses(redis, image_urls)
cached_statuses = _get_cached_statuses(redis, urls)
logger.debug(f"len(cached_statuses)={len(cached_statuses)}")

# Anything that isn't in the cache needs to be validated via HEAD request.
to_verify = {}
for idx, url in enumerate(image_urls):
for idx, url in enumerate(urls):
if cached_statuses[idx] is None:
to_verify[url] = idx
logger.debug(f"len(to_verify)={len(to_verify)}")

verified = _make_head_requests(to_verify.keys())
verified = _make_head_requests(to_verify, results)

# Cache newly verified image statuses.
to_cache = {CACHE_PREFIX + url: status for url, status in verified}
@@ -119,7 +140,7 @@ def check_dead_links(
for key, status in to_cache.items():
if status == 200:
logger.debug(f"healthy link key={key}")
elif status == -1:
elif status == _TIMEOUT_STATUS or status == _ERROR_STATUS:
logger.debug(f"no response from provider key={key}")
else:
logger.debug(f"broken link key={key}")
@@ -152,7 +173,7 @@ def check_dead_links(
if status in status_mapping.unknown:
logger.warning(
"Image validation failed due to rate limiting or blocking. "
f"url={image_urls[idx]} "
f"url={urls[idx]} "
f"status={status} "
f"provider={provider} "
)
@@ -184,7 +205,3 @@ def check_dead_links(
f"start_time={start_time} "
f"delta={end_time - start_time} "
)


def _log_validation_failure(exception):
logger.warning(f"Failed to validate image! Reason: {exception}")
2 changes: 1 addition & 1 deletion api/conf/settings/__init__.py
@@ -42,7 +42,7 @@
"spectacular.py",
"thumbnails.py",
# Openverse-specific settings
"link_validation_cache.py",
"link_validation.py",
"misc.py",
"openverse.py",
)
api/conf/settings/{link_validation_cache.py → link_validation.py}
@@ -12,6 +12,11 @@
logger = structlog.get_logger(__name__)


LINK_VALIDATION_TIMEOUT_SECONDS = config(
"LINK_VALIDATION_TIMEOUT_SECONDS", default=0.8, cast=float
)


class LinkValidationCacheExpiryConfiguration(defaultdict):
"""Link validation cache expiry configuration."""

9 changes: 7 additions & 2 deletions api/test/factory/es_http.py
@@ -6,7 +6,11 @@


def create_mock_es_http_image_hit(
_id: str, index: str, live: bool = True, identifier: str | None = None
_id: str,
index: str,
live: bool = True,
identifier: str | None = None,
**additional_fields,
):
return {
"_index": index,
@@ -39,7 +43,8 @@ def create_mock_es_http_image_hit(
"created_on": "2022-02-26T08:48:33+00:00",
"tags": [{"name": "bird"}],
"mature": False,
},
}
| additional_fields,
}


4 changes: 2 additions & 2 deletions api/test/integration/test_dead_link_filter.py
@@ -32,7 +32,7 @@ def get_empty_cached_statuses(_, image_urls):


def _patch_make_head_requests():
def _make_head_requests(urls):
def _make_head_requests(urls, *args, **kwargs):
responses = []
for idx, url in enumerate(urls):
status_code = 200 if idx % 10 != 0 else 404
@@ -45,7 +45,7 @@ def _make_head_requests(urls):
def patch_link_validation_dead_for_count(count):
total_res_count = 0

def _make_head_requests(urls):
def _make_head_requests(urls, *args, **kwargs):
nonlocal total_res_count
responses = []
for idx, url in enumerate(urls):
2 changes: 1 addition & 1 deletion api/test/unit/configuration/test_link_validation_cache.py
@@ -4,7 +4,7 @@

import pytest

from conf.settings.link_validation_cache import LinkValidationCacheExpiryConfiguration
from conf.settings.link_validation import LinkValidationCacheExpiryConfiguration


@pytest.mark.parametrize(
3 changes: 2 additions & 1 deletion api/test/unit/controllers/test_search_controller.py
@@ -810,7 +810,8 @@ def test_excessive_recursion_in_post_process(
redis,
caplog,
):
def _delete_all_results_but_first(_, __, results, ___):
def _delete_all_results_but_first(*args):
results = args[2]
results[1:] = []

mock_check_dead_links.side_effect = _delete_all_results_but_first
54 changes: 31 additions & 23 deletions api/test/unit/utils/test_check_dead_links.py
@@ -1,34 +1,46 @@
import asyncio
from collections.abc import Callable
from typing import Any

import pook
import pytest
from aiohttp.client import ClientSession
from elasticsearch_dsl.response import Hit
from structlog.testing import capture_logs

from api.utils.check_dead_links import HEADERS, check_dead_links
from test.factory.es_http import create_mock_es_http_image_hit


def _make_hits(
count: int, gen_fields: Callable[[int], dict[str, Any]] = lambda _: dict()
):
return [
Hit(create_mock_es_http_image_hit(_id, "image", live=True, **gen_fields(_id)))
for _id in range(count)
]


@pook.on
def test_sends_user_agent():
query_hash = "test_sends_user_agent"
results = [{"provider": "best_provider_ever"} for _ in range(40)]
image_urls = [f"https://example.org/{i}" for i in range(len(results))]
results = _make_hits(40)
start_slice = 0

head_mock = (
pook.head(pook.regex(r"https://example.org/\d"))
pook.head(pook.regex(r"https://example.com/openverse-live-image-result-url/\d"))
.headers(HEADERS)
.times(len(results))
.reply(200)
.mock
)

check_dead_links(query_hash, start_slice, results, image_urls)
check_dead_links(query_hash, start_slice, results)

assert head_mock.calls == len(results)
requested_urls = [req.rawurl for req in head_mock.matches]
for url in image_urls:
assert url in requested_urls
for result in results:
assert result.url in requested_urls


def test_handles_timeout(monkeypatch):
@@ -39,15 +51,14 @@ def test_handles_timeout(monkeypatch):
3 seconds.
"""
query_hash = "test_handles_timeout"
results = [{"identifier": i, "provider": "best_provider_ever"} for i in range(1)]
image_urls = [f"https://example.org/{i}" for i in range(len(results))]
results = _make_hits(1)
start_slice = 0

async def raise_timeout_error(*args, **kwargs):
raise asyncio.TimeoutError()

monkeypatch.setattr(ClientSession, "_request", raise_timeout_error)
check_dead_links(query_hash, start_slice, results, image_urls)
check_dead_links(query_hash, start_slice, results)

# `check_dead_links` directly modifies the results list
# if the results are timing out then they're considered dead and discarded
@@ -60,22 +71,20 @@ async def raise_timeout_error(*args, **kwargs):
def test_403_considered_dead(provider):
query_hash = f"test_{provider}_403_considered_dead"
other_provider = "fake_other_provider"
results = [
{"identifier": i, "provider": provider if i % 2 else other_provider}
for i in range(4)
]
results = _make_hits(
4, lambda i: {"provider": provider if i % 2 else other_provider}
)
len_results = len(results)
image_urls = [f"https://example.org/{i}" for i in range(len(results))]
start_slice = 0

head_mock = (
pook.head(pook.regex(r"https://example.org/\d"))
pook.head(pook.regex(r"https://example.com/openverse-live-image-result-url/\d"))
.times(len(results))
.reply(403)
.mock
)

check_dead_links(query_hash, start_slice, results, image_urls)
check_dead_links(query_hash, start_slice, results)

assert head_mock.calls == len_results

@@ -92,25 +101,24 @@ def test_mset_and_expire_for_responses(is_cache_reachable, cache_name, request):
cache = request.getfixturevalue(cache_name)

query_hash = "test_mset_and_expiry_for_responses"
results = [{"identifier": i, "provider": "best_provider_ever"} for i in range(40)]
image_urls = [f"https://example.org/{i}" for i in range(len(results))]
results = _make_hits(40)
start_slice = 0

(
pook.head(pook.regex(r"https://example.org/\d"))
pook.head(pook.regex(r"https://example.com/openverse-live-image-result-url/\d"))
.headers(HEADERS)
.times(len(results))
.reply(200)
)

with capture_logs() as cap_logs:
check_dead_links(query_hash, start_slice, results, image_urls)
check_dead_links(query_hash, start_slice, results)

if is_cache_reachable:
for i in range(len(results)):
assert cache.get(f"valid:https://example.org/{i}") == b"200"
for result in results:
assert cache.get(f"valid:{result.url}") == b"200"
# TTL is 30 days for 2xx responses
assert cache.ttl(f"valid:https://example.org/{i}") == 2592000
assert cache.ttl(f"valid:{result.url}") == 2592000
else:
messages = [record["event"] for record in cap_logs]
assert all(
