Skip to content

Commit

Permalink
Simplify related query to remove nesting
Browse files Browse the repository at this point in the history
Signed-off-by: Olga Bulat <[email protected]>
  • Loading branch information
obulat committed Nov 3, 2023
1 parent 374a93e commit a5bb1a5
Show file tree
Hide file tree
Showing 7 changed files with 315 additions and 280 deletions.
Empty file.
222 changes: 222 additions & 0 deletions api/api/controllers/elasticsearch/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,25 @@
import logging as log
import pprint
import time
from itertools import accumulate
from math import ceil

from django.conf import settings
from django.core.cache import cache

from elasticsearch import BadRequestError, NotFoundError
from elasticsearch_dsl import Q, Search
from elasticsearch_dsl.response import Hit

from api import models as models
from api.utils.check_dead_links import check_dead_links
from api.utils.dead_link_mask import get_query_hash, get_query_mask


FILTER_CACHE_TIMEOUT = 30
ELASTICSEARCH_MAX_RESULT_WINDOW = 10000
DEAD_LINK_RATIO = 1 / 2
DEEP_PAGINATION_ERROR = "Deep pagination is not allowed."


def log_timing_info(func):
Expand Down Expand Up @@ -55,3 +70,210 @@ def get_es_response(s, search_query=None):
@log_timing_info
def get_raw_es_response(index, body, search_query=None, **kwargs):
    # Thin pass-through to the low-level Elasticsearch client configured in
    # Django settings. ``search_query`` is not used here; presumably it is
    # consumed by the ``log_timing_info`` decorator for log labelling —
    # TODO confirm, the decorator body is not visible in this view.
    return settings.ES.search(index=index, body=body, **kwargs)


def _unmasked_query_end(page_size, page):
    """
    Return the exclusive upper index of results to request from Elasticsearch.

    Applied whenever the dead-link query mask cannot be used:
    1. No query mask exists yet.
    2. The lower index lies beyond the end of the existing mask.
    3. The lower index is covered by the mask but the upper index is not.

    The requested window is inflated by the expected dead-link ratio so that,
    after dead results are filtered out, enough live results remain to fill
    the requested page.
    """
    requested = page_size * page
    live_fraction = 1 - DEAD_LINK_RATIO
    return ceil(requested / live_fraction)


def _paginate_with_dead_link_mask(
    s: Search, page_size: int, page: int
) -> tuple[int, int]:
    """
    Return the start and end of the results slice, given the query, page and page size.
    The returned indices refer to positions in the raw Elasticsearch result
    list (before dead links are filtered out).
    In almost all cases the ``DEAD_LINK_RATIO`` will effectively double
    the page size (given the current configuration of 0.5).
    The "branch X" labels are for cross-referencing with the tests.
    :param s: The elasticsearch Search object
    :param page_size: How big the page should be.
    :param page: The page number.
    :return: Tuple of start and end.
    """
    # Cached 0/1 mask of dead (0) / live (1) results for this exact query,
    # if one was recorded by a previous request.
    query_hash = get_query_hash(s)
    query_mask = get_query_mask(query_hash)
    if not query_mask:  # branch 1
        start = 0
        end = _unmasked_query_end(page_size, page)
    elif page_size * (page - 1) > sum(query_mask):  # branch 2
        # The mask does not contain enough live results to even reach the
        # start of the requested page; resume past the mask and over-fetch.
        start = len(query_mask)
        end = _unmasked_query_end(page_size, page)
    else:  # branch 3
        # query_mask is a list of 0 and 1 where 0 indicates the result position
        # for the given query will be an invalid link. If we accumulate a query
        # mask you end up, at each index, with the number of live results you
        # will get back when you query that deeply.
        # We then query for the start and end index _of the results_ in ES based
        # on the number of results that we think will be valid based on the query mask.
        # If we're requesting `page=2 page_size=3` and the mask is [0, 1, 0, 1, 0, 1],
        # then we know that we have to _start_ with at least the sixth result of the
        # overall query to skip the first page of 3 valid results. The "end" of the
        # query will then follow the same pattern to reach the number of valid results
        # required to fill the requested page. If the mask is not deep enough to
        # account for the entire range, then we follow the typical assumption when
        # a mask is not available that the end should be `page * page_size / 0.5`
        # (i.e., double the page size)
        accu_query_mask = list(accumulate(query_mask))
        start = 0
        if page > 1:
            try:  # branch 3_start_A
                # find the index at which we can skip N valid results where N = all
                # the results that would be skipped to arrive at the start of the
                # requested page
                # This will effectively be the index at which we have the number of
                # previous valid results + 1 because we don't want to include the
                # last valid result from the previous page
                start = accu_query_mask.index(page_size * (page - 1) + 1)
            except ValueError:  # branch 3_start_B
                # Cannot fail because of the check on branch 2 which verifies that
                # the query mask already includes at least enough masked valid
                # results to fulfill the requested page size
                start = accu_query_mask.index(page_size * (page - 1)) + 1
        # else: branch 3_start_C
        # Always start page=1 queries at 0

        if page_size * page > sum(query_mask):  # branch 3_end_A
            # Mask too shallow to locate the end of the page; over-fetch.
            end = _unmasked_query_end(page_size, page)
        else:  # branch 3_end_B
            # First index at which the accumulated live count fills the page;
            # +1 converts from index to exclusive slice end.
            end = accu_query_mask.index(page_size * page) + 1
    return start, end


def get_query_slice(
    s: Search, page_size: int, page: int, filter_dead: bool | None = False
) -> tuple[int, int]:
    """
    Compute the start and end indices of the search results for this query.

    When ``filter_dead`` is set, the slice is widened via the dead-link mask
    so that enough live results remain to fill the page; otherwise plain
    fixed-size pagination is used.

    :raises ValueError: if the slice would page past
        ``ELASTICSEARCH_MAX_RESULT_WINDOW``.
    """
    if not filter_dead:
        # Plain pagination: fixed-size, contiguous windows.
        bounds = (page_size * (page - 1), page_size * page)
    else:
        bounds = _paginate_with_dead_link_mask(s, page_size, page)

    start_slice, end_slice = bounds
    # Refuse deep pagination beyond what Elasticsearch can serve.
    if start_slice + end_slice > ELASTICSEARCH_MAX_RESULT_WINDOW:
        raise ValueError(DEEP_PAGINATION_ERROR)
    return start_slice, end_slice


def post_process_results(
    s, start, end, page_size, search_results, filter_dead
) -> list[Hit] | None:
    """
    Perform some steps on results fetched from the backend.
    After fetching the search results from the back end, iterate through the
    results, perform image validation, and route certain thumbnails through our
    proxy.
    Keeps making new query requests until it is able to fill the page size.
    Recurses with a widened ``end`` until the page is full, the ES result
    window is exhausted, or no more hits are available.
    :param s: The Elasticsearch Search object.
    :param start: The start of the result slice.
    :param end: The end of the result slice.
    :param page_size: The number of results to return in a page.
    :param search_results: The Elasticsearch response object containing search
    results.
    :param filter_dead: Whether images should be validated.
    :return: List of results.
    """
    results = []
    to_validate = []
    for res in search_results:
        if hasattr(res.meta, "highlight"):
            # Record which fields produced the match highlight — presumably
            # the attribute names of the highlight object; TODO confirm.
            res.fields_matched = dir(res.meta.highlight)
        to_validate.append(res.url)
        results.append(res)

    if filter_dead:
        # Mutates ``results`` in place, removing entries whose URLs are dead
        # and recording the outcome under the query's mask hash.
        query_hash = get_query_hash(s)
        check_dead_links(query_hash, start, results, to_validate)

        if len(results) == 0:
            # first page is all dead links
            return None

        if len(results) < page_size:
            """
            The variables in this function get updated in an interesting way.
            Here is an example of that for a typical query. Note that ``end``
            increases but start stays the same. This has the effect of slowly
            increasing the size of the query we send to Elasticsearch with the
            goal of backfilling the results until we have enough valid (live)
            results to fulfill the requested page size.
            ```
            page_size: 20
            page: 1
            start: 0
            end: 40 (DEAD_LINK_RATIO applied)
            end gets updated to end + end/2 = 60
            end = 90
            end = 90 + 45
            ```
            """
            if end >= search_results.hits.total.value:
                # Total available hits already exhausted in previous iteration
                return results

            # Grow the window by 50% for the next backfill attempt.
            end += int(end / 2)
            query_size = start + end
            if query_size > ELASTICSEARCH_MAX_RESULT_WINDOW:
                return results

            # subtract start to account for the records skipped
            # and which should not count towards the total
            # available hits for the query
            total_available_hits = search_results.hits.total.value - start
            if query_size > total_available_hits:
                # Clamp the query size to last available hit. On the next
                # iteration, if results are still insufficient, the check
                # to compare previous_query_size and total_available_hits
                # will prevent further query attempts
                end = search_results.hits.total.value

            # Re-query the widened slice and recurse to validate it.
            s = s[start:end]
            search_response = get_es_response(s, "postprocess_search")

            return post_process_results(
                s, start, end, page_size, search_response, filter_dead
            )

    return results[:page_size]


def get_excluded_providers_query() -> Q | None:
    """
    Build a query excluding data sources hidden from the catalog.

    Providers whose ``ContentProvider`` row has ``filter_content=True``
    (set via Django admin) are dynamically excluded. The flagged set is
    cached briefly to avoid hitting the database on every request.

    :return: a ``terms`` query over the excluded provider identifiers, or
        ``None`` when no provider is currently filtered.
    """

    filter_cache_key = "filtered_providers"
    filtered_providers = cache.get(key=filter_cache_key)
    if not filtered_providers:
        # Cache miss (or previously-empty result): fetch and re-cache.
        filtered_providers = models.ContentProvider.objects.filter(
            filter_content=True
        ).values("provider_identifier")
        cache.set(
            key=filter_cache_key, timeout=FILTER_CACHE_TIMEOUT, value=filtered_providers
        )
    provider_list = [f["provider_identifier"] for f in filtered_providers]
    return Q("terms", provider=provider_list) if provider_list else None
74 changes: 74 additions & 0 deletions api/api/controllers/elasticsearch/related.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from __future__ import annotations

from elasticsearch_dsl import Search
from elasticsearch_dsl.query import Match, Q, SimpleQueryString, Term
from elasticsearch_dsl.response import Hit

from api.controllers.elasticsearch.helpers import (
get_es_response,
get_excluded_providers_query,
get_query_slice,
post_process_results,
)


def related_media(uuid: str, index: str, filter_dead: bool) -> list[Hit]:
    """
    Find up to 10 results related to the item with the given UUID.

    Relatedness is matched on the item's title (``Match`` query) and/or its
    first ten tags (``SimpleQueryString``). When the item has neither a
    title nor tags, items by the same creator are returned instead; with no
    creator either, the result is an empty list. Mature content, the item
    itself, and dynamically disabled providers are always excluded.

    :param uuid: The UUID of the item to find related results for.
    :param index: The Elasticsearch index to search (e.g. 'image')
    :param filter_dead: Whether dead links should be removed.
    :return: List of related results.
    """

    # Look the item up in the default (unfiltered) index, as it might be
    # sensitive and absent from the filtered index.
    item_search = Search(index=index)
    item_hit = item_search.query(Term(identifier=uuid)).execute().hits[0]

    title = getattr(item_hit, "title", None)
    tags = getattr(item_hit, "tags", None)
    creator = getattr(item_hit, "creator", None)

    related_query = {"must_not": [], "must": [], "should": []}

    if title or tags:
        if title:
            related_query["should"].append(Match(title=title))
        if tags:
            # Only the first 10 tags participate in the match.
            tag_names = " | ".join(tag.name for tag in tags[:10])
            related_query["should"].append(
                SimpleQueryString(fields=["tags.name"], query=tag_names)
            )
    elif creator:
        # Fall back to the creator only when there are no title and tags.
        related_query["should"].append(Term(creator=creator))
    else:
        # Nothing usable to match on.
        return []

    # Exclude the dynamically disabled sources, if any.
    excluded_providers_query = get_excluded_providers_query()
    if excluded_providers_query:
        related_query["must_not"].append(excluded_providers_query)
    # Never surface mature content or the item itself.
    related_query["must_not"] += [Q("term", mature=True), Q("term", identifier=uuid)]

    # Search the filtered index for related items.
    related_search = Search(index=f"{index}-filtered").query("bool", **related_query)

    page_size = 10
    page = 1
    start, end = get_query_slice(related_search, page_size, page, filter_dead)

    response = get_es_response(related_search, "related_media")
    results = post_process_results(
        related_search, start, end, page_size, response, filter_dead
    )
    return results if results else []
Loading

0 comments on commit a5bb1a5

Please sign in to comment.