This repository has been archived by the owner on Feb 22, 2023. It is now read-only.

Refactor search controller for consistency and clarity (#699)
(cherry picked from commit 0e442a4)
dhruvkb committed Jun 8, 2022
1 parent 29fd019 commit 4b98cc4
Showing 14 changed files with 521 additions and 525 deletions.
40 changes: 40 additions & 0 deletions api/catalog/api/controllers/elasticsearch/related.py
@@ -0,0 +1,40 @@
from elasticsearch_dsl import Search

from catalog.api.controllers.elasticsearch.utils import (
    exclude_filtered_providers,
    get_query_slice,
    get_result_and_page_count,
    post_process_results,
)


def related_media(uuid, index, filter_dead):
    """
    Given a UUID, find related search results.
    """
    search_client = Search(using="default", index=index)

    # Convert UUID to sequential ID.
    item = search_client.query("match", identifier=uuid)
    _id = item.execute().hits[0].id

    s = search_client.query(
        "more_like_this",
        fields=["tags.name", "title", "creator"],
        like={"_index": index, "_id": _id},
        min_term_freq=1,
        max_query_terms=50,
    )
    # Never show mature content in recommendations.
    s = s.exclude("term", mature=True)
    s = exclude_filtered_providers(s)
    page_size = 10
    page = 1
    start, end = get_query_slice(s, page_size, page, filter_dead)
    s = s[start:end]
    response = s.execute()
    results = post_process_results(s, start, end, page_size, response, filter_dead)

    result_count, _ = get_result_and_page_count(response, results, page_size)

    return results, result_count
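
For orientation, a minimal sketch of how a media view might call this new controller; the view wiring and the UUID below are illustrative assumptions, not part of this commit:

# Hypothetical caller, for illustration only.
results, result_count = related_media(
    uuid="01234567-89ab-cdef-0123-456789abcdef",  # placeholder identifier
    index="image",
    filter_dead=True,  # drop results whose URLs no longer resolve
)
# ``results`` is a list of Elasticsearch hits; ``result_count`` is the total.
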
198 changes: 198 additions & 0 deletions api/catalog/api/controllers/elasticsearch/search.py
@@ -0,0 +1,198 @@
from __future__ import annotations

import json
import logging as log
import pprint
from typing import List, Literal, Tuple, Union

from django.conf import settings

from elasticsearch.exceptions import RequestError
from elasticsearch_dsl import Q, Search
from elasticsearch_dsl.response import Hit

from catalog.api.controllers.elasticsearch.utils import (
    exclude_filtered_providers,
    get_query_slice,
    get_result_and_page_count,
    post_process_results,
)
from catalog.api.serializers.media_serializers import MediaSearchRequestSerializer


def _quote_escape(query_string: str) -> str:
    """
    If there are any unmatched quotes in the query supplied by the user,
    escape them so that they are treated literally.

    :param query_string: the string in which to escape unbalanced quotes
    :return: the given string if the quotes are balanced, or the escaped string otherwise
    """

    num_quotes = query_string.count('"')
    if num_quotes % 2 == 1:
        return query_string.replace('"', '\\"')
    else:
        return query_string
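
# A quick illustration of the escaping rule above (inputs invented for the
# example):
#
#     _quote_escape('5" floppy')       # -> '5\\" floppy' (odd count: escaped)
#     _quote_escape('"exact phrase"')  # -> '"exact phrase"' (balanced: unchanged)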


def _apply_filter(
    s: Search,
    query_serializer: MediaSearchRequestSerializer,
    basis: Union[str, tuple[str, str]],
    behaviour: Literal["filter", "exclude"] = "filter",
) -> Search:
    """
    Parse and apply a filter from the search parameters serializer. The
    parameter key is assumed to have the same name as the corresponding
    Elasticsearch property. Each parameter value is assumed to be a
    comma-separated list encoded as a string.

    :param s: the search query to issue to Elasticsearch
    :param query_serializer: the ``MediaSearchRequestSerializer`` object with the query
    :param basis: the name of the field in the serializer and Elasticsearch
    :param behaviour: whether to accept (``filter``) or reject (``exclude``) the hit
    :return: the modified search query
    """

    search_params = query_serializer.data
    if isinstance(basis, tuple):
        serializer_field, es_field = basis
    else:
        serializer_field = es_field = basis
    if serializer_field in search_params:
        filters = []
        for arg in search_params[serializer_field].split(","):
            filters.append(Q("term", **{es_field: arg}))
        method = getattr(s, behaviour)  # can be ``s.filter`` or ``s.exclude``
        return method("bool", should=filters)
    else:
        return s
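
# To make the translation concrete: a request parameter such as
# ``extension=jpg,png`` becomes one ``term`` filter per value, ORed together
# in a ``bool``/``should`` clause. Roughly, the query gains this fragment
# (shape inferred from the code above, shown for illustration only):
#
#     {"bool": {"should": [
#         {"term": {"extension": "jpg"}},
#         {"term": {"extension": "png"}},
#     ]}}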


def perform_search(
    query_serializer: MediaSearchRequestSerializer,
    index: Literal["image", "audio"],
    ip: int,
) -> Tuple[List[Hit], int, int]:
    """
    Perform a ranked, paginated search based on the query and filters given
    in the search request.

    :param query_serializer: the ``MediaSearchRequestSerializer`` object with the query
    :param index: the Elasticsearch index to search (e.g. 'image')
    :param ip: the user's hashed IP, to consistently route to the same ES shard
    :return: the list of search results with the page and result count
    """

    s = Search(using="default", index=index)
    search_params = query_serializer.data

    rules: dict[Literal["filter", "exclude"], list[Union[str, tuple[str, str]]]] = {
        "filter": [
            "extension",
            "category",
            ("categories", "category"),
            "aspect_ratio",
            "size",
            "length",
            "source",
            ("license", "license.keyword"),
            ("license_type", "license.keyword"),
        ],
        "exclude": [
            ("excluded_source", "source"),
        ],
    }
    for behaviour, bases in rules.items():
        for basis in bases:
            s = _apply_filter(s, query_serializer, basis, behaviour)

    # Exclude mature content.
    if not search_params["mature"]:
        s = s.exclude("term", mature=True)
    # Exclude sources with ``filter_content`` enabled.
    s = exclude_filtered_providers(s)

    # Search either by generic multimatch or by "advanced search" with
    # individual field-level queries specified.
    search_fields = ["tags.name", "title", "description"]
    if "q" in search_params:
        query = _quote_escape(search_params["q"])
        s = s.query(
            "simple_query_string",
            query=query,
            fields=search_fields,
            default_operator="AND",
        )
        # Boost exact matches.
        quotes_stripped = query.replace('"', "")
        exact_match_boost = Q(
            "simple_query_string",
            fields=["title"],
            query=f'"{quotes_stripped}"',
            boost=10000,
        )
        s.query = Q("bool", must=s.query, should=exact_match_boost)
    else:
        query_bases = ["creator", "title", ("tags", "tags.name")]
        for query_basis in query_bases:
            if isinstance(query_basis, tuple):
                serializer_field, es_field = query_basis
            else:
                serializer_field = es_field = query_basis
            if serializer_field in search_params:
                value = _quote_escape(search_params[serializer_field])
                s = s.query("simple_query_string", fields=[es_field], query=value)

    if settings.USE_RANK_FEATURES:
        feature_boost = {"standardized_popularity": 10000}
        rank_queries = []
        for field, boost in feature_boost.items():
            rank_queries.append(Q("rank_feature", field=field, boost=boost))
        s.query = Q("bool", must=s.query, should=rank_queries)

    # Use highlighting to determine which fields contribute to the selection
    # of top results.
    s = s.highlight(*search_fields)
    s = s.highlight_options(order="score")

    # Route users to the same Elasticsearch worker node to reduce pagination
    # inconsistencies and increase cache hits.
    s = s.params(preference=str(ip), request_timeout=7)

    # Paginate.
    start, end = get_query_slice(
        s,
        search_params["page_size"],
        search_params["page"],
        search_params["filter_dead"],
    )
    s = s[start:end]

    try:
        if settings.VERBOSE_ES_RESPONSE:
            log.info(pprint.pformat(s.to_dict()))
        search_response = s.execute()
        log.info(
            f"query={json.dumps(s.to_dict())}, es_took_ms={search_response.took}"
        )
        if settings.VERBOSE_ES_RESPONSE:
            log.info(pprint.pformat(search_response.to_dict()))
    except RequestError as e:
        raise ValueError(e)

    results = post_process_results(
        s,
        start,
        end,
        search_params["page_size"],
        search_response,
        search_params["filter_dead"],
    )

    result_count, page_count = get_result_and_page_count(
        search_response, results, search_params["page_size"]
    )
    return results, page_count, result_count
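
A minimal sketch of how a search view might drive ``perform_search``; the serializer construction and the hashed-IP value are assumptions for illustration, not part of this commit:

# Hypothetical caller, for illustration only.
serializer = MediaSearchRequestSerializer(data={"q": "sunset beach"})
serializer.is_valid(raise_exception=True)  # populates ``serializer.data``
hits, page_count, result_count = perform_search(
    query_serializer=serializer,
    index="image",
    ip=1234567890,  # stand-in for the user's hashed IP (an int, per the signature)
)
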
51 changes: 51 additions & 0 deletions api/catalog/api/controllers/elasticsearch/stats.py
@@ -0,0 +1,51 @@
import logging as log
from typing import Literal

from django.core.cache import cache

from elasticsearch.exceptions import NotFoundError
from elasticsearch_dsl import Search


SOURCE_CACHE_TIMEOUT = 60 * 20  # seconds


def get_stats(index: Literal["image", "audio"]):
    """
    Given an index, find all available data sources and return their counts.
    This data is cached in Redis. See ``load_sample_data.sh`` for an example
    of clearing the cache.

    :param index: the Elasticsearch index name
    :return: a dictionary mapping sources to the count of their media items
    """

    source_cache_name = "sources-" + index
    try:
        sources = cache.get(key=source_cache_name)
        if sources is not None:
            return sources
    except ValueError:
        log.warning("Source cache fetch failed")

    # Don't increase `size` without reading this issue first:
    # https://github.com/elastic/elasticsearch/issues/18838
    size = 100
    try:
        s = Search(using="default", index=index)
        s.aggs.bucket(
            "unique_sources",
            "terms",
            field="source.keyword",
            size=size,
            order={"_key": "desc"},
        )
        results = s.execute()
        buckets = results["aggregations"]["unique_sources"]["buckets"]
        sources = {result["key"]: result["doc_count"] for result in buckets}
    except NotFoundError:
        sources = {}

    if sources:
        cache.set(key=source_cache_name, timeout=SOURCE_CACHE_TIMEOUT, value=sources)

    return sources
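
For context, the return value is a plain mapping from source name to document count; the names and numbers below are invented for illustration:

stats = get_stats("image")
# e.g. {"flickr": 480000000, "met": 200000}  (illustrative values only)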
