This repository has been archived by the owner on Feb 22, 2023. It is now read-only.

Refactor search controller for consistency and clarity #778

Merged (2 commits) on Jul 5, 2022
40 changes: 40 additions & 0 deletions api/catalog/api/controllers/elasticsearch/related.py
@@ -0,0 +1,40 @@
from elasticsearch_dsl import Search

from catalog.api.controllers.elasticsearch.utils import (
    exclude_filtered_providers,
    get_query_slice,
    get_result_and_page_count,
    post_process_results,
)


def related_media(uuid, index, filter_dead):
    """
    Given a UUID, find related search results.
    """
    search_client = Search(using="default", index=index)

    # Convert the UUID to the document's sequential ID.
    item = search_client.query("match", identifier=uuid)
    _id = item.execute().hits[0].id

    s = search_client.query(
        "more_like_this",
        fields=["tags.name", "title", "creator"],
        like={"_index": index, "_id": _id},
        min_term_freq=1,
        max_query_terms=50,
    )
    # Never show mature content in recommendations.
    s = s.exclude("term", mature=True)
    s = exclude_filtered_providers(s)

    page_size = 10
    page = 1
    start, end = get_query_slice(s, page_size, page, filter_dead)
    s = s[start:end]
    response = s.execute()
    results = post_process_results(s, start, end, page_size, response, filter_dead)

    result_count, _ = get_result_and_page_count(response, results, page_size)

    return results, result_count
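
For orientation, a minimal usage sketch (the UUID and call site are hypothetical, not part of this PR). Note that the lookup assumes the UUID exists in the index, since ``hits[0]`` would raise an ``IndexError`` otherwise:

    from catalog.api.controllers.elasticsearch.related import related_media

    # Hypothetical call from a media detail endpoint:
    results, result_count = related_media(
        uuid="0dea3af1-27a4-4635-bab6-4b9fb76a59f5",  # illustrative identifier
        index="image",
        filter_dead=True,  # presumably filters out results with dead links
    )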
208 changes: 208 additions & 0 deletions api/catalog/api/controllers/elasticsearch/search.py
@@ -0,0 +1,208 @@
from __future__ import annotations

import json
import logging as log
import pprint
from typing import List, Literal, Tuple

from django.conf import settings

from elasticsearch.exceptions import RequestError
from elasticsearch_dsl import Q, Search
from elasticsearch_dsl.response import Hit

from catalog.api.controllers.elasticsearch.utils import (
    exclude_filtered_providers,
    get_query_slice,
    get_result_and_page_count,
    post_process_results,
)
from catalog.api.serializers.media_serializers import MediaSearchRequestSerializer


class FieldMapping:
    """
    Establishes a mapping between a field in ``MediaSearchRequestSerializer`` and a
    field in the Elasticsearch index for a media type.
    """

    def __init__(self, serializer_field: str, es_field: str = None):
        self.serializer_field: str = serializer_field
        """the name of the field in ``MediaSearchRequestSerializer``"""

        self.es_field: str = es_field or serializer_field
        """the name of the field in the Elasticsearch index"""


def _quote_escape(query_string: str) -> str:
    """
    If the query supplied by the user contains unbalanced quotes, escape all
    quotes so that Elasticsearch treats them literally rather than as phrase
    delimiters.

    :param query_string: the string in which to escape unbalanced quotes
    :return: the given string if the quotes are balanced, the escaped string otherwise
    """

    num_quotes = query_string.count('"')
    if num_quotes % 2 == 1:
        return query_string.replace('"', '\\"')
    else:
        return query_string
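
To make the parity rule concrete, a small illustrative check (not part of the diff):

    # An even number of quotes is considered balanced and passes through;
    # an odd count means at least one quote is unmatched, so all are escaped.
    assert _quote_escape('"exact phrase"') == '"exact phrase"'  # two quotes: unchanged
    assert _quote_escape('rock "n roll') == 'rock \\"n roll'    # one quote: escaped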


def _apply_filter(
s: Search,
query_serializer: MediaSearchRequestSerializer,
mapping: FieldMapping,
behaviour: Literal["filter", "exclude"] = "filter",
) -> Search:
"""
Parse and apply a filter from the search parameters serializer. The
parameter key is assumed to have the same name as the corresponding
Elasticsearch property. Each parameter value is assumed to be a comma
separated list encoded as a string.

:param s: the search query to issue to Elasticsearch
:param query_serializer: the ``MediaSearchRequestSerializer`` object with the query
:param mapping: the name of the field in the serializer and Elasticsearch
:param behaviour: whether to accept (``filter``) or reject (``exclude``) the hit
:return: the modified search query
"""

search_params = query_serializer.data
if mapping.serializer_field in search_params:
filters = []
for arg in search_params[mapping.serializer_field].split(","):
filters.append(Q("term", **{mapping.es_field: arg}))
method = getattr(s, behaviour) # can be ``s.filter`` or ``s.exclude``
return method("bool", should=filters)
else:
return s
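
As an example of the comma-separated convention: assuming a request carrying ``license=cc0,pdm``, this helper would apply roughly the following hand-written equivalent, where the ``bool.should`` clause matches hits satisfying at least one of the terms:

    # Hypothetical hand-written equivalent of _apply_filter for ``license=cc0,pdm``:
    s = s.filter(
        "bool",
        should=[
            Q("term", **{"license.keyword": "cc0"}),
            Q("term", **{"license.keyword": "pdm"}),
        ],
    )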


def perform_search(
query_serializer: MediaSearchRequestSerializer,
index: Literal["image", "audio"],
ip: int,
) -> Tuple[List[Hit], int, int]:
"""
Perform a ranked, paginated search based on the query and filters given in the
search request.

:param query_serializer: the ``MediaSearchRequestSerializer`` object with the query
:param index: The Elasticsearch index to search (e.g. 'image')
:param ip: the users' hashed IP to consistently route to the same ES shard

Review thread on the ``ip`` parameter:

Contributor:
    This is interesting, I didn't know we did this. I wonder if circumventing
    ES's built-in load balancing could be an issue for us.

Member Author:
    I don't know enough to say whether it can be a problem, but I can see how
    redirecting all requests from a particularly active consumer to a single
    shard could become a bottleneck. I haven't modified it in this PR, though;
    it has been that way from before.

    :return: the list of search results with the page and result count
    """

    s = Search(using="default", index=index)
    search_params = query_serializer.data

    rules: dict[Literal["filter", "exclude"], list[FieldMapping]] = {
        "filter": [
            FieldMapping("extension"),
            FieldMapping("category"),
            FieldMapping("categories", "category"),
            FieldMapping("aspect_ratio"),
            FieldMapping("size"),
            FieldMapping("length"),
            FieldMapping("source"),
            FieldMapping("license", "license.keyword"),
            FieldMapping("license_type", "license.keyword"),
        ],
        "exclude": [
            FieldMapping("excluded_source", "source"),
        ],
    }
    for behaviour, mappings in rules.items():
        for mapping in mappings:
            s = _apply_filter(s, query_serializer, mapping, behaviour)

    # Exclude mature content.
    if not search_params["mature"]:
        s = s.exclude("term", mature=True)
    # Exclude sources with ``filter_content`` enabled.
    s = exclude_filtered_providers(s)

    # Search either by a generic multi-field match or by an "advanced search"
    # with individual field-level queries.
    search_fields = ["tags.name", "title", "description"]
    if "q" in search_params:
        query = _quote_escape(search_params["q"])
        s = s.query(
            "simple_query_string",
            query=query,
            fields=search_fields,
            default_operator="AND",
        )
        # Boost exact matches on the title.
        quotes_stripped = query.replace('"', "")
        exact_match_boost = Q(
            "simple_query_string",
            fields=["title"],
            query=f'"{quotes_stripped}"',
            boost=10000,
        )
        s.query = Q("bool", must=s.query, should=exact_match_boost)
    else:
        query_bases = ["creator", "title", ("tags", "tags.name")]
        for query_basis in query_bases:
            if isinstance(query_basis, tuple):
                serializer_field, es_field = query_basis
            else:
                serializer_field = es_field = query_basis
            if serializer_field in search_params:
                value = _quote_escape(search_params[serializer_field])
                s = s.query("simple_query_string", fields=[es_field], query=value)

    if settings.USE_RANK_FEATURES:
        feature_boost = {"standardized_popularity": 10000}
        rank_queries = []
        for field, boost in feature_boost.items():
            rank_queries.append(Q("rank_feature", field=field, boost=boost))
        s.query = Q("bool", must=s.query, should=rank_queries)

    # Use highlighting to determine which fields contribute to the selection of
    # top results.
    s = s.highlight(*search_fields)
    s = s.highlight_options(order="score")

    # Route users to the same Elasticsearch worker node to reduce pagination
    # inconsistencies and increase cache hits.
    s = s.params(preference=str(ip), request_timeout=7)

    # Paginate.
    start, end = get_query_slice(
        s,
        search_params["page_size"],
        search_params["page"],
        search_params["filter_dead"],
    )
    s = s[start:end]

    try:
        if settings.VERBOSE_ES_RESPONSE:
            # ``pprint.pformat`` returns the formatted string; ``pprint.pprint``
            # prints to stdout and returns None, which would log nothing useful.
            log.info(pprint.pformat(s.to_dict()))
        search_response = s.execute()
        log.info(
            f"query={json.dumps(s.to_dict())}, es_took_ms={search_response.took}"
        )
        if settings.VERBOSE_ES_RESPONSE:
            log.info(pprint.pformat(search_response.to_dict()))
    except RequestError as e:
        raise ValueError(e) from e

    results = post_process_results(
        s,
        start,
        end,
        search_params["page_size"],
        search_response,
        search_params["filter_dead"],
    )

    result_count, page_count = get_result_and_page_count(
        search_response, results, search_params["page_size"]
    )
    return results, page_count, result_count
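
For orientation, a minimal sketch of how a view might drive this controller, assuming a validated serializer; the request data and hashed IP are illustrative, and the serializer is expected to supply defaults for fields like ``mature`` and ``filter_dead`` referenced above:

    # Hypothetical invocation; field names mirror those used in the mappings above.
    serializer = MediaSearchRequestSerializer(
        data={"q": "rain", "license_type": "commercial", "page": 1, "page_size": 20}
    )
    serializer.is_valid(raise_exception=True)
    hits, page_count, result_count = perform_search(
        serializer,
        index="image",
        ip=123456789,  # the user's hashed IP; the hashing scheme is not shown in this diff
    )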
51 changes: 51 additions & 0 deletions api/catalog/api/controllers/elasticsearch/stats.py
@@ -0,0 +1,51 @@
import logging as log
from typing import Literal

from django.core.cache import cache

from elasticsearch.exceptions import NotFoundError
from elasticsearch_dsl import Search


SOURCE_CACHE_TIMEOUT = 60 * 20  # seconds


def get_stats(index: Literal["image", "audio"]):
    """
    Given an index, find all available data sources and return their counts. This
    data is cached in Redis. See ``load_sample_data.sh`` for an example of clearing
    the cache.

    :param index: the Elasticsearch index name
    :return: a dictionary mapping sources to the count of their media items
    """

    source_cache_name = "sources-" + index
    try:
        sources = cache.get(key=source_cache_name)
        if sources is not None:
            return sources
    except ValueError:
        log.warning("Source cache fetch failed")

    # Don't increase ``size`` without reading this issue first:
    # https://github.com/elastic/elasticsearch/issues/18838
    size = 100
    try:
        s = Search(using="default", index=index)
        s.aggs.bucket(
            "unique_sources",
            "terms",
            field="source.keyword",
            size=size,
            order={"_key": "desc"},
        )
        results = s.execute()
        buckets = results["aggregations"]["unique_sources"]["buckets"]
        sources = {result["key"]: result["doc_count"] for result in buckets}
    except NotFoundError:
        sources = {}

    if sources:
        cache.set(key=source_cache_name, timeout=SOURCE_CACHE_TIMEOUT, value=sources)

    return sources
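
A short sketch of the expected shape of the return value (source names and counts are illustrative, not real data):

    stats = get_stats("image")
    # e.g. {"flickr": 500_000_000, "met": 200_000}
    total = sum(stats.values())  # total media items across all sources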