From 4b98cc4140c55966e83c3ce791806c03c7c8ebb7 Mon Sep 17 00:00:00 2001 From: Dhruv Bhanushali Date: Wed, 25 May 2022 19:07:40 +0400 Subject: [PATCH] Refactor search controller for consistency and clarity (#699) (cherry picked from commit 0e442a4ad09410448d5594faff1c0c83abb323c1) --- .../api/controllers/elasticsearch/__init__.py | 0 .../api/controllers/elasticsearch/related.py | 40 ++ .../api/controllers/elasticsearch/search.py | 198 ++++++++ .../api/controllers/elasticsearch/stats.py | 51 ++ .../api/controllers/elasticsearch/utils.py | 160 ++++++ .../api/controllers/search_controller.py | 466 ------------------ api/catalog/api/models/audio.py | 4 +- api/catalog/api/models/image.py | 4 +- api/catalog/api/models/media.py | 4 +- .../api/serializers/media_serializers.py | 19 +- api/catalog/api/utils/pagination.py | 37 +- api/catalog/api/views/media_views.py | 24 +- api/catalog/settings.py | 38 ++ load_sample_data.sh | 1 + 14 files changed, 521 insertions(+), 525 deletions(-) create mode 100644 api/catalog/api/controllers/elasticsearch/__init__.py create mode 100644 api/catalog/api/controllers/elasticsearch/related.py create mode 100644 api/catalog/api/controllers/elasticsearch/search.py create mode 100644 api/catalog/api/controllers/elasticsearch/stats.py create mode 100644 api/catalog/api/controllers/elasticsearch/utils.py delete mode 100644 api/catalog/api/controllers/search_controller.py diff --git a/api/catalog/api/controllers/elasticsearch/__init__.py b/api/catalog/api/controllers/elasticsearch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/catalog/api/controllers/elasticsearch/related.py b/api/catalog/api/controllers/elasticsearch/related.py new file mode 100644 index 000000000..df28513b6 --- /dev/null +++ b/api/catalog/api/controllers/elasticsearch/related.py @@ -0,0 +1,40 @@ +from elasticsearch_dsl import Search + +from catalog.api.controllers.elasticsearch.utils import ( + exclude_filtered_providers, + get_query_slice, + get_result_and_page_count, + post_process_results, +) + + +def related_media(uuid, index, filter_dead): + """ + Given a UUID, find related search results. + """ + search_client = Search(using="default", index=index) + + # Convert UUID to sequential ID. + item = search_client.query("match", identifier=uuid) + _id = item.execute().hits[0].id + + s = search_client.query( + "more_like_this", + fields=["tags.name", "title", "creator"], + like={"_index": index, "_id": _id}, + min_term_freq=1, + max_query_terms=50, + ) + # Never show mature content in recommendations. 
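+ # For reference, ``exclude`` compiles to a negated ``term`` clause -- roughly
+ # ``{"bool": {"must_not": [{"term": {"mature": true}}]}}`` -- so mature records are
+ # dropped from the candidate set entirely rather than merely down-ranked.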
+ s = s.exclude("term", mature=True) + s = exclude_filtered_providers(s) + page_size = 10 + page = 1 + start, end = get_query_slice(s, page_size, page, filter_dead) + s = s[start:end] + response = s.execute() + results = post_process_results(s, start, end, page_size, response, filter_dead) + + result_count, _ = get_result_and_page_count(response, results, page_size) + + return results, result_count diff --git a/api/catalog/api/controllers/elasticsearch/search.py b/api/catalog/api/controllers/elasticsearch/search.py new file mode 100644 index 000000000..2c2b7cebf --- /dev/null +++ b/api/catalog/api/controllers/elasticsearch/search.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +import json +import logging as log +import pprint +from typing import List, Literal, Tuple, Union + +from django.conf import settings + +from elasticsearch.exceptions import RequestError +from elasticsearch_dsl import Q, Search +from elasticsearch_dsl.response import Hit + +from catalog.api.controllers.elasticsearch.utils import ( + exclude_filtered_providers, + get_query_slice, + get_result_and_page_count, + post_process_results, +) +from catalog.api.serializers.media_serializers import MediaSearchRequestSerializer + + +def _quote_escape(query_string: str) -> str: + """ + If there are any unmatched quotes in the query supplied by the user, ignore + them by escaping. + + :param query_string: the string in which to escape unbalanced quotes + :return: the given string, if the quotes are balanced, the escaped string otherwise + """ + + num_quotes = query_string.count('"') + if num_quotes % 2 == 1: + return query_string.replace('"', '\\"') + else: + return query_string + + +def _apply_filter( + s: Search, + query_serializer: MediaSearchRequestSerializer, + basis: Union[str, tuple[str, str]], + behaviour: Literal["filter", "exclude"] = "filter", +) -> Search: + """ + Parse and apply a filter from the search parameters serializer. The + parameter key is assumed to have the same name as the corresponding + Elasticsearch property. Each parameter value is assumed to be a comma + separated list encoded as a string. + + :param s: the search query to issue to Elasticsearch + :param query_serializer: the ``MediaSearchRequestSerializer`` object with the query + :param basis: the name of the field in the serializer and Elasticsearch + :param behaviour: whether to accept (``filter``) or reject (``exclude``) the hit + :return: the modified search query + """ + + search_params = query_serializer.data + if isinstance(basis, tuple): + serializer_field, es_field = basis + else: + serializer_field = es_field = basis + if serializer_field in search_params: + filters = [] + for arg in search_params[serializer_field].split(","): + filters.append(Q("term", **{es_field: arg})) + method = getattr(s, behaviour) # can be ``s.filter`` or ``s.exclude`` + return method("bool", should=filters) + else: + return s + + +def perform_search( + query_serializer: MediaSearchRequestSerializer, + index: Literal["image", "audio"], + ip: int, +) -> Tuple[List[Hit], int, int]: + """ + Perform a ranked, paginated search based on the query and filters given in the + search request. + + :param query_serializer: the ``MediaSearchRequestSerializer`` object with the query + :param index: The Elasticsearch index to search (e.g. 
'image') + :param ip: the users' hashed IP to consistently route to the same ES shard + :return: the list of search results with the page and result count + """ + + s = Search(using="default", index=index) + search_params = query_serializer.data + + rules: dict[Literal["filter", "exclude"], list[Union[str, tuple[str, str]]]] = { + "filter": [ + "extension", + "category", + ("categories", "category"), + "aspect_ratio", + "size", + "length", + "source", + ("license", "license.keyword"), + ("license_type", "license.keyword"), + ], + "exclude": [ + ("excluded_source", "source"), + ], + } + for behaviour, bases in rules.items(): + for basis in bases: + s = _apply_filter(s, query_serializer, basis, behaviour) + + # Exclude mature content + if not search_params["mature"]: + s = s.exclude("term", mature=True) + # Exclude sources with ``filter_content`` enabled + s = exclude_filtered_providers(s) + + # Search either by generic multimatch or by "advanced search" with + # individual field-level queries specified. + + search_fields = ["tags.name", "title", "description"] + if "q" in search_params: + query = _quote_escape(search_params["q"]) + s = s.query( + "simple_query_string", + query=query, + fields=search_fields, + default_operator="AND", + ) + # Boost exact matches + quotes_stripped = query.replace('"', "") + exact_match_boost = Q( + "simple_query_string", + fields=["title"], + query=f'"{quotes_stripped}"', + boost=10000, + ) + s.query = Q("bool", must=s.query, should=exact_match_boost) + else: + query_bases = ["creator", "title", ("tags", "tags.name")] + for query_basis in query_bases: + if isinstance(query_basis, tuple): + serializer_field, es_field = query_basis + else: + serializer_field = es_field = query_basis + if serializer_field in search_params: + value = _quote_escape(search_params[serializer_field]) + s = s.query("simple_query_string", fields=[es_field], query=value) + + if settings.USE_RANK_FEATURES: + feature_boost = {"standardized_popularity": 10000} + rank_queries = [] + for field, boost in feature_boost.items(): + rank_queries.append(Q("rank_feature", field=field, boost=boost)) + s.query = Q("bool", must=s.query, should=rank_queries) + + # Use highlighting to determine which fields contribute to the selection of + # top results. + s = s.highlight(*search_fields) + s = s.highlight_options(order="score") + + # Route users to the same Elasticsearch worker node to reduce + # pagination inconsistencies and increase cache hits. 
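+ # ``preference`` is a standard Elasticsearch request parameter: sending the same
+ # opaque string (here the hashed client IP) with every request makes ES pick the
+ # same shard copies each time, so scores and ordering stay stable across pages.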
+ s = s.params(preference=str(ip), request_timeout=7) + + # Paginate + start, end = get_query_slice( + s, + search_params["page_size"], + search_params["page"], + search_params["filter_dead"], + ) + s = s[start:end] + + try: + if settings.VERBOSE_ES_RESPONSE: + log.info(pprint.pprint(s.to_dict())) + search_response = s.execute() + log.info( + f"query={json.dumps(s.to_dict())}," f" es_took_ms={search_response.took}" + ) + if settings.VERBOSE_ES_RESPONSE: + log.info(pprint.pprint(search_response.to_dict())) + except RequestError as e: + raise ValueError(e) + + results = post_process_results( + s, + start, + end, + search_params["page_size"], + search_response, + search_params["filter_dead"], + ) + + result_count, page_count = get_result_and_page_count( + search_response, results, search_params["page_size"] + ) + return results, page_count, result_count diff --git a/api/catalog/api/controllers/elasticsearch/stats.py b/api/catalog/api/controllers/elasticsearch/stats.py new file mode 100644 index 000000000..f37e54812 --- /dev/null +++ b/api/catalog/api/controllers/elasticsearch/stats.py @@ -0,0 +1,51 @@ +import logging as log +from typing import Literal + +from django.core.cache import cache + +from elasticsearch.exceptions import NotFoundError +from elasticsearch_dsl import Search + + +SOURCE_CACHE_TIMEOUT = 60 * 20 # seconds + + +def get_stats(index: Literal["image", "audio"]): + """ + Given an index, find all available data sources and return their counts. This data + is cached in Redis. See ``load_sample_data.sh`` for example of clearing the cache. + + :param index: the Elasticsearch index name + :return: a dictionary mapping sources to the count of their media items + """ + + source_cache_name = "sources-" + index + try: + sources = cache.get(key=source_cache_name) + if sources is not None: + return sources + except ValueError: + log.warning("Source cache fetch failed") + + # Don't increase `size` without reading this issue first: + # https://github.com/elastic/elasticsearch/issues/18838 + size = 100 + try: + s = Search(using="default", index=index) + s.aggs.bucket( + "unique_sources", + "terms", + field="source.keyword", + size=size, + order={"_key": "desc"}, + ) + results = s.execute() + buckets = results["aggregations"]["unique_sources"]["buckets"] + sources = {result["key"]: result["doc_count"] for result in buckets} + except NotFoundError: + sources = {} + + if sources: + cache.set(key=source_cache_name, timeout=SOURCE_CACHE_TIMEOUT, value=sources) + + return sources diff --git a/api/catalog/api/controllers/elasticsearch/utils.py b/api/catalog/api/controllers/elasticsearch/utils.py new file mode 100644 index 000000000..cb0bb6e12 --- /dev/null +++ b/api/catalog/api/controllers/elasticsearch/utils.py @@ -0,0 +1,160 @@ +from itertools import accumulate +from math import ceil +from typing import List, Optional, Tuple + +from django.core.cache import cache + +from elasticsearch_dsl import Search +from elasticsearch_dsl.response import Hit, Response + +from catalog.api.models import ContentProvider +from catalog.api.utils.dead_link_mask import get_query_hash, get_query_mask +from catalog.api.utils.validate_images import validate_images + + +FILTER_CACHE_TIMEOUT = 30 +DEAD_LINK_RATIO = 1 / 2 +ELASTICSEARCH_MAX_RESULT_WINDOW = 10000 + + +def exclude_filtered_providers(s: Search) -> Search: + """ + Hide data sources from the catalog dynamically. This excludes providers with + ``filter_content`` enabled from the search results. 
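+ The provider identifiers come from the ``ContentProvider`` table and are cached
+ for ``FILTER_CACHE_TIMEOUT`` (30) seconds, so a newly filtered provider may keep
+ appearing in results for up to half a minute.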
+ + :param s: the search query to issue to Elasticsearch + :return: the modified search query + """ + + filter_cache_key = "filtered_providers" + filtered_providers = cache.get(key=filter_cache_key) + if filtered_providers is None: + filtered_providers = ContentProvider.objects.filter(filter_content=True).values( + "provider_identifier" + ) + cache.set( + key=filter_cache_key, + timeout=FILTER_CACHE_TIMEOUT, + value=filtered_providers, + ) + if len(filtered_providers) != 0: + to_exclude = [f["provider_identifier"] for f in filtered_providers] + s = s.exclude("terms", provider=to_exclude) + return s + + +def paginate_with_dead_link_mask( + s: Search, page_size: int, page: int +) -> Tuple[int, int]: + """ + Given a query, a page and page_size, return the start and end + of the slice of results. + + :param s: The elasticsearch Search object + :param page_size: How big the page should be. + :param page: The page number. + :return: Tuple of start and end. + """ + query_hash = get_query_hash(s) + query_mask = get_query_mask(query_hash) + if not query_mask: + start = 0 + end = ceil(page_size * page / (1 - DEAD_LINK_RATIO)) + elif page_size * (page - 1) > sum(query_mask): + start = len(query_mask) + end = ceil(page_size * page / (1 - DEAD_LINK_RATIO)) + else: + accu_query_mask = list(accumulate(query_mask)) + start = 0 + if page > 1: + try: + start = accu_query_mask.index(page_size * (page - 1) + 1) + except ValueError: + start = accu_query_mask.index(page_size * (page - 1)) + 1 + if page_size * page > sum(query_mask): + end = ceil(page_size * page / (1 - DEAD_LINK_RATIO)) + else: + end = accu_query_mask.index(page_size * page) + 1 + return start, end + + +def get_query_slice( + s: Search, page_size: int, page: int, filter_dead: Optional[bool] = False +) -> Tuple[int, int]: + """ + Select the start and end of the search results for this query. + """ + if filter_dead: + start_slice, end_slice = paginate_with_dead_link_mask(s, page_size, page) + else: + # Paginate search query. + start_slice = page_size * (page - 1) + end_slice = page_size * page + if start_slice + end_slice > ELASTICSEARCH_MAX_RESULT_WINDOW: + raise ValueError("Deep pagination is not allowed.") + return start_slice, end_slice + + +def post_process_results( + s, start, end, page_size, search_results, filter_dead +) -> List[Hit]: + """ + After fetching the search results from the back end, iterate through the + results, perform image validation, and route certain thumbnails through our + proxy. + + :param s: The Elasticsearch Search object. + :param start: The start of the result slice. + :param end: The end of the result slice. + :param search_results: The Elasticsearch response object containing search + results. + :param filter_dead: Whether images should be validated. + :return: List of results. 
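+
+ Note: when ``filter_dead`` prunes hits and fewer than ``page_size`` results remain,
+ the slice end is extended by half and the query re-executed recursively, stopping
+ once a full page is collected or the window would pass
+ ``ELASTICSEARCH_MAX_RESULT_WINDOW``.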
+ """ + results = [] + to_validate = [] + for res in search_results: + if hasattr(res.meta, "highlight"): + res.fields_matched = dir(res.meta.highlight) + to_validate.append(res.url) + results.append(res) + + if filter_dead: + query_hash = get_query_hash(s) + validate_images(query_hash, start, results, to_validate) + + if len(results) < page_size: + end += int(end / 2) + if start + end > ELASTICSEARCH_MAX_RESULT_WINDOW: + return results + + s = s[start:end] + search_response = s.execute() + + return post_process_results( + s, start, end, page_size, search_response, filter_dead + ) + return results[:page_size] + + +def get_result_and_page_count( + response_obj: Response, results: List[Hit], page_size: int +) -> Tuple[int, int]: + """ + Elasticsearch does not allow deep pagination of ranked queries. + Adjust returned page count to reflect this. + + :param response_obj: The original Elasticsearch response object. + :param results: The list of filtered result Hits. + :return: Result and page count. + """ + result_count = response_obj.hits.total.value + natural_page_count = int(result_count / page_size) + if natural_page_count % page_size != 0: + natural_page_count += 1 + last_allowed_page = int((5000 + page_size / 2) / page_size) + page_count = min(natural_page_count, last_allowed_page) + if len(results) < page_size and page_count == 0: + result_count = len(results) + + return result_count, page_count diff --git a/api/catalog/api/controllers/search_controller.py b/api/catalog/api/controllers/search_controller.py deleted file mode 100644 index 0a9fb9801..000000000 --- a/api/catalog/api/controllers/search_controller.py +++ /dev/null @@ -1,466 +0,0 @@ -from __future__ import annotations - -import json -import logging as log -import pprint -from itertools import accumulate -from math import ceil -from typing import List, Literal, Optional, Tuple - -from django.conf import settings -from django.core.cache import cache -from rest_framework.request import Request - -from aws_requests_auth.aws_auth import AWSRequestsAuth -from elasticsearch import Elasticsearch, RequestsHttpConnection -from elasticsearch.exceptions import NotFoundError, RequestError -from elasticsearch_dsl import Q, Search, connections -from elasticsearch_dsl.query import Query -from elasticsearch_dsl.response import Hit, Response - -import catalog.api.models as models -from catalog.api.serializers.media_serializers import MediaSearchRequestSerializer -from catalog.api.utils.dead_link_mask import get_query_hash, get_query_mask -from catalog.api.utils.validate_images import validate_images - - -ELASTICSEARCH_MAX_RESULT_WINDOW = 10000 -SOURCE_CACHE_TIMEOUT = 60 * 20 -FILTER_CACHE_TIMEOUT = 30 -DEAD_LINK_RATIO = 1 / 2 -THUMBNAIL = "thumbnail" -URL = "url" -PROVIDER = "provider" -DEEP_PAGINATION_ERROR = "Deep pagination is not allowed." -QUERY_SPECIAL_CHARACTER_ERROR = "Unescaped special characters are not allowed." - - -class RankFeature(Query): - name = "rank_feature" - - -def _paginate_with_dead_link_mask( - s: Search, page_size: int, page: int -) -> Tuple[int, int]: - """ - Given a query, a page and page_size, return the start and end - of the slice of results. - - :param s: The elasticsearch Search object - :param page_size: How big the page should be. - :param page: The page number. - :return: Tuple of start and end. 
- """ - query_hash = get_query_hash(s) - query_mask = get_query_mask(query_hash) - if not query_mask: - start = 0 - end = ceil(page_size * page / (1 - DEAD_LINK_RATIO)) - elif page_size * (page - 1) > sum(query_mask): - start = len(query_mask) - end = ceil(page_size * page / (1 - DEAD_LINK_RATIO)) - else: - accu_query_mask = list(accumulate(query_mask)) - start = 0 - if page > 1: - try: - start = accu_query_mask.index(page_size * (page - 1) + 1) - except ValueError: - start = accu_query_mask.index(page_size * (page - 1)) + 1 - if page_size * page > sum(query_mask): - end = ceil(page_size * page / (1 - DEAD_LINK_RATIO)) - else: - end = accu_query_mask.index(page_size * page) + 1 - return start, end - - -def _get_query_slice( - s: Search, page_size: int, page: int, filter_dead: Optional[bool] = False -) -> Tuple[int, int]: - """ - Select the start and end of the search results for this query. - """ - if filter_dead: - start_slice, end_slice = _paginate_with_dead_link_mask(s, page_size, page) - else: - # Paginate search query. - start_slice = page_size * (page - 1) - end_slice = page_size * page - if start_slice + end_slice > ELASTICSEARCH_MAX_RESULT_WINDOW: - raise ValueError(DEEP_PAGINATION_ERROR) - return start_slice, end_slice - - -def _quote_escape(query_string): - """ - If there are any unmatched quotes in the query supplied by the user, ignore - them. - """ - num_quotes = query_string.count('"') - if num_quotes % 2 == 1: - return query_string.replace('"', '\\"') - else: - return query_string - - -def _post_process_results( - s, start, end, page_size, search_results, request, filter_dead -) -> List[Hit]: - """ - After fetching the search results from the back end, iterate through the - results, perform image validation, and route certain thumbnails through our - proxy. - - :param s: The Elasticsearch Search object. - :param start: The start of the result slice. - :param end: The end of the result slice. - :param search_results: The Elasticsearch response object containing search - results. - :param request: The Django request object, used to build a "reversed" URL - to detail pages. - :param filter_dead: Whether images should be validated. - :return: List of results. - """ - results = [] - to_validate = [] - for res in search_results: - if hasattr(res.meta, "highlight"): - res.fields_matched = dir(res.meta.highlight) - to_validate.append(res.url) - results.append(res) - - if filter_dead: - query_hash = get_query_hash(s) - validate_images(query_hash, start, results, to_validate) - - if len(results) < page_size: - end += int(end / 2) - if start + end > ELASTICSEARCH_MAX_RESULT_WINDOW: - return results - - s = s[start:end] - search_response = s.execute() - - return _post_process_results( - s, start, end, page_size, search_response, request, filter_dead - ) - return results[:page_size] - - -def _apply_filter( - s: Search, - search_params: MediaSearchRequestSerializer, - serializer_field: str, - es_field: Optional[str] = None, - behaviour: Literal["filter", "exclude"] = "filter", -): - """ - Parse and apply a filter from the search parameters serializer. The - parameter key is assumed to have the same name as the corresponding - Elasticsearch property. Each parameter value is assumed to be a comma - separated list encoded as a string. 
- - :param s: The ``Search`` instance to apply the filter to - :param search_params: the serializer instance containing user input - :param serializer_field: the name of the parameter field in ``search_params`` - :param es_field: the corresponding parameter name in Elasticsearch - :param behaviour: whether to accept (``filter``) or reject (``exclude``) the hit - :return: the input ``Search`` object with the filters applied - """ - - if serializer_field in search_params.data: - filters = [] - for arg in search_params.data[serializer_field].split(","): - _param = es_field or serializer_field - args = {"name_or_query": "term", _param: arg} - filters.append(Q(**args)) - method = getattr(s, behaviour) - return method("bool", should=filters) - else: - return s - - -def _exclude_filtered(s: Search): - """ - Hide data sources from the catalog dynamically. - """ - filter_cache_key = "filtered_providers" - filtered_providers = cache.get(key=filter_cache_key) - if not filtered_providers: - filtered_providers = models.ContentProvider.objects.filter( - filter_content=True - ).values("provider_identifier") - cache.set( - key=filter_cache_key, timeout=FILTER_CACHE_TIMEOUT, value=filtered_providers - ) - to_exclude = [f["provider_identifier"] for f in filtered_providers] - s = s.exclude("terms", provider=to_exclude) - return s - - -def _exclude_mature_by_param(s: Search, search_params): - if not search_params.data["mature"]: - s = s.exclude("term", mature=True) - return s - - -def search( - search_params: MediaSearchRequestSerializer, - index: Literal["image", "audio"], - page_size: int, - ip: int, - request: Request, - filter_dead: bool, - page: int = 1, -) -> Tuple[List[Hit], int, int]: - """ - Given a set of keywords and an optional set of filters, perform a ranked - paginated search. - - :param search_params: Search parameters. See - :class: `ImageSearchQueryStringSerializer`. - :param index: The Elasticsearch index to search (e.g. 'image') - :param page_size: The number of results to return per page. - :param ip: The user's hashed IP. Hashed IPs are used to anonymously but - uniquely identify users exclusively for ensuring query consistency across - Elasticsearch shards. - :param request: Django's request object. - :param filter_dead: Whether dead links should be removed. - :param page: The results page number. - :return: Tuple with a List of Hits from elasticsearch, the total count of - pages, and number of results. - """ - search_client = Search(index=index) - - s = search_client - # Apply term filters. Each tuple pairs a filter's parameter name in the API - # with its corresponding field in Elasticsearch. "None" means that the - # names are identical. 
- filters = [ - ("extension", None), - ("category", None), - ("categories", "category"), - ("length", None), - ("aspect_ratio", None), - ("size", None), - ("source", None), - ("license", "license__keyword"), - ("license_type", "license__keyword"), - ] - for serializer_field, es_field in filters: - if serializer_field in search_params.data: - s = _apply_filter(s, search_params, serializer_field, es_field) - - exclude = [ - ("excluded_source", "source"), - ] - for serializer_field, es_field in exclude: - if serializer_field in search_params.data: - s = _apply_filter(s, search_params, serializer_field, es_field, "exclude") - - # Exclude mature content and disabled sources - s = _exclude_mature_by_param(s, search_params) - s = _exclude_filtered(s) - - # Search either by generic multimatch or by "advanced search" with - # individual field-level queries specified. - search_fields = ["tags.name", "title", "description"] - if "q" in search_params.data: - query = _quote_escape(search_params.data["q"]) - s = s.query( - "simple_query_string", - query=query, - fields=search_fields, - default_operator="AND", - ) - # Boost exact matches - quotes_stripped = query.replace('"', "") - exact_match_boost = Q( - "simple_query_string", - fields=["title"], - query=f'"{quotes_stripped}"', - boost=10000, - ) - s = search_client.query(Q("bool", must=s.query, should=exact_match_boost)) - else: - if "creator" in search_params.data: - creator = _quote_escape(search_params.data["creator"]) - s = s.query("simple_query_string", query=creator, fields=["creator"]) - if "title" in search_params.data: - title = _quote_escape(search_params.data["title"]) - s = s.query("simple_query_string", query=title, fields=["title"]) - if "tags" in search_params.data: - tags = _quote_escape(search_params.data["tags"]) - s = s.query("simple_query_string", fields=["tags.name"], query=tags) - - if settings.USE_RANK_FEATURES: - feature_boost = {"standardized_popularity": 10000} - rank_queries = [] - for field, boost in feature_boost.items(): - rank_queries.append(Q("rank_feature", field=field, boost=boost)) - s = search_client.query(Q("bool", must=s.query, should=rank_queries)) - - # Use highlighting to determine which fields contribute to the selection of - # top results. - s = s.highlight(*search_fields) - s = s.highlight_options(order="score") - s.extra(track_scores=True) - # Route users to the same Elasticsearch worker node to reduce - # pagination inconsistencies and increase cache hits. - s = s.params(preference=str(ip), request_timeout=7) - # Paginate - start, end = _get_query_slice(s, page_size, page, filter_dead) - s = s[start:end] - try: - if settings.VERBOSE_ES_RESPONSE: - log.info(pprint.pprint(s.to_dict())) - search_response = s.execute() - log.info( - f"query={json.dumps(s.to_dict())}," f" es_took_ms={search_response.took}" - ) - if settings.VERBOSE_ES_RESPONSE: - log.info(pprint.pprint(search_response.to_dict())) - except RequestError as e: - raise ValueError(e) - results = _post_process_results( - s, start, end, page_size, search_response, request, filter_dead - ) - - result_count, page_count = _get_result_and_page_count( - search_response, results, page_size - ) - return results, page_count, result_count - - -def related_media(uuid, index, request, filter_dead): - """ - Given a UUID, find related search results. - """ - search_client = Search(index=index) - - # Convert UUID to sequential ID. 
- item = search_client - item = item.query("match", identifier=uuid) - _id = item.execute().hits[0].id - - s = search_client - s = s.query( - "more_like_this", - fields=["tags.name", "title", "creator"], - like={"_index": index, "_id": _id}, - min_term_freq=1, - max_query_terms=50, - ) - # Never show mature content in recommendations. - s = s.exclude("term", mature=True) - s = _exclude_filtered(s) - page_size = 10 - page = 1 - start, end = _get_query_slice(s, page_size, page, filter_dead) - s = s[start:end] - response = s.execute() - results = _post_process_results( - s, start, end, page_size, response, request, filter_dead - ) - - result_count, _ = _get_result_and_page_count(response, results, page_size) - - return results, result_count - - -def get_sources(index): - """ - Given an index, find all available data sources and return their counts. - - :param index: An Elasticsearch index, such as `'image'`. - :return: A dictionary mapping sources to the count of their images.` - """ - source_cache_name = "sources-" + index - cache_fetch_failed = False - try: - sources = cache.get(key=source_cache_name) - except ValueError: - cache_fetch_failed = True - sources = None - log.warning("Source cache fetch failed due to corruption") - if type(sources) == list or cache_fetch_failed: - # Invalidate old provider format. - cache.delete(key=source_cache_name) - if not sources: - # Don't increase `size` without reading this issue first: - # https://github.com/elastic/elasticsearch/issues/18838 - size = 100 - agg_body = { - "aggs": { - "unique_sources": { - "terms": { - "field": "source.keyword", - "size": size, - "order": {"_key": "desc"}, - } - } - } - } - try: - results = es.search(index=index, body=agg_body, request_cache=True) - buckets = results["aggregations"]["unique_sources"]["buckets"] - except NotFoundError: - buckets = [{"key": "none_found", "doc_count": 0}] - sources = {result["key"]: result["doc_count"] for result in buckets} - cache.set(key=source_cache_name, timeout=SOURCE_CACHE_TIMEOUT, value=sources) - return sources - - -def _elasticsearch_connect(): - """ - Connect to configured Elasticsearch domain. - - :return: An Elasticsearch connection object. - """ - auth = AWSRequestsAuth( - aws_access_key=settings.AWS_ACCESS_KEY_ID, - aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY, - aws_host=settings.ELASTICSEARCH_URL, - aws_region=settings.ELASTICSEARCH_AWS_REGION, - aws_service="es", - ) - auth.encode = lambda x: bytes(x.encode("utf-8")) - _es = Elasticsearch( - host=settings.ELASTICSEARCH_URL, - port=settings.ELASTICSEARCH_PORT, - connection_class=RequestsHttpConnection, - timeout=10, - max_retries=1, - retry_on_timeout=True, - http_auth=auth, - wait_for_status="yellow", - ) - _es.info() - return _es - - -es = _elasticsearch_connect() -connections.connections.add_connection("default", es) - - -def _get_result_and_page_count( - response_obj: Response, results: List[Hit], page_size: int -) -> Tuple[int, int]: - """ - Elasticsearch does not allow deep pagination of ranked queries. - Adjust returned page count to reflect this. - - :param response_obj: The original Elasticsearch response object. - :param results: The list of filtered result Hits. - :return: Result and page count. 
- """ - result_count = response_obj.hits.total.value - natural_page_count = int(result_count / page_size) - if natural_page_count % page_size != 0: - natural_page_count += 1 - last_allowed_page = int((5000 + page_size / 2) / page_size) - page_count = min(natural_page_count, last_allowed_page) - if len(results) < page_size and page_count == 0: - result_count = len(results) - - return result_count, page_count diff --git a/api/catalog/api/models/audio.py b/api/catalog/api/models/audio.py index 9e8e269d1..77171f74e 100644 --- a/api/catalog/api/models/audio.py +++ b/api/catalog/api/models/audio.py @@ -1,9 +1,9 @@ +from django.conf import settings from django.contrib.postgres.fields import ArrayField from django.db import models from uuslug import uuslug -import catalog.api.controllers.search_controller as search_controller from catalog.api.models import OpenLedgerModel from catalog.api.models.media import ( AbstractAltFile, @@ -250,7 +250,7 @@ class MatureAudio(AbstractMatureMedia): """Stores all audios that have been flagged as 'mature'.""" def delete(self, *args, **kwargs): - es = search_controller.es + es = settings.ES aud = Audio.objects.get(identifier=self.identifier) es_id = aud.id es.update(index="audio", id=es_id, body={"doc": {"mature": False}}) diff --git a/api/catalog/api/models/image.py b/api/catalog/api/models/image.py index 009fde502..a848a9470 100644 --- a/api/catalog/api/models/image.py +++ b/api/catalog/api/models/image.py @@ -1,8 +1,8 @@ +from django.conf import settings from django.db import models from uuslug import uuslug -import catalog.api.controllers.search_controller as search_controller from catalog.api.models.media import ( AbstractDeletedMedia, AbstractMatureMedia, @@ -77,7 +77,7 @@ class MatureImage(AbstractMatureMedia): """Stores all images that have been flagged as 'mature'.""" def delete(self, *args, **kwargs): - es = search_controller.es + es = settings.ES img = Image.objects.get(identifier=self.identifier) es_id = img.id es.update(index="image", id=es_id, body={"doc": {"mature": False}}) diff --git a/api/catalog/api/models/media.py b/api/catalog/api/models/media.py index 7ca538b36..270e0dc46 100644 --- a/api/catalog/api/models/media.py +++ b/api/catalog/api/models/media.py @@ -1,10 +1,10 @@ import mimetypes +from django.conf import settings from django.contrib.postgres.fields import ArrayField from django.db import models from django.utils.html import format_html -import catalog.api.controllers.search_controller as search_controller from catalog.api.models.base import OpenLedgerModel from catalog.api.models.mixins import ( ForeignIdentifierMixin, @@ -189,7 +189,7 @@ def save(self, *args, **kwargs): update_required = {MATURE_FILTERED, DEINDEXED} # ES needs updating if self.status in update_required: - es = search_controller.es + es = settings.ES try: media = media_class.objects.get(identifier=self.identifier) except media_class.DoesNotExist: diff --git a/api/catalog/api/serializers/media_serializers.py b/api/catalog/api/serializers/media_serializers.py index 4d10c0583..875f0ca35 100644 --- a/api/catalog/api/serializers/media_serializers.py +++ b/api/catalog/api/serializers/media_serializers.py @@ -3,7 +3,7 @@ from rest_framework import serializers from catalog.api.constants.licenses import LICENSE_GROUPS -from catalog.api.controllers import search_controller +from catalog.api.controllers.elasticsearch.stats import get_stats from catalog.api.models.media import AbstractMedia from catalog.api.serializers.base import BaseModelSerializer from 
catalog.api.serializers.fields import SchemableHyperlinkedIdentityField @@ -104,11 +104,22 @@ class MediaSearchRequestSerializer(serializers.Serializer): required=False, default=False, ) + page = serializers.IntegerField( + min_value=1, + default=1, + help_text="The index of the page of the results to show.", + ) + page_size = serializers.IntegerField( + min_value=1, + max_value=500, + default=20, + help_text="The number of results to show in one page.", + ) @staticmethod def _truncate(value): max_length = 200 - return value if len(value) <= max_length else value[:max_length] + return value[:max_length] def validate_q(self, value): return self._truncate(value) @@ -350,7 +361,7 @@ class MediaSearchRequestSourceSerializer(serializers.Serializer): _field_attrs = { "help_text": make_comma_separated_help_text( - search_controller.get_sources(media_type).keys(), "data sources" + get_stats(media_type).keys(), "data sources" ), "required": False, } @@ -368,7 +379,7 @@ class MediaSearchRequestSourceSerializer(serializers.Serializer): def validate_source_field(value): """Checks whether source is a valid source.""" - allowed_sources = list(search_controller.get_sources(media_type).keys()) + allowed_sources = list(get_stats(media_type).keys()) sources = value.lower().split(",") sources = [source for source in sources if source in allowed_sources] value = ",".join(sources) diff --git a/api/catalog/api/utils/pagination.py b/api/catalog/api/utils/pagination.py index 484e455fe..cc813bbd5 100644 --- a/api/catalog/api/utils/pagination.py +++ b/api/catalog/api/utils/pagination.py @@ -1,48 +1,17 @@ from rest_framework.pagination import PageNumberPagination from rest_framework.response import Response -from catalog.api.utils.exceptions import get_api_exception - class StandardPagination(PageNumberPagination): - page_size_query_param = "page_size" - page_query_param = "page" + page_query_param = None def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.result_count = None # populated later self.page_count = None # populated later - self._page_size = 20 - self._page = None - - @property - def page_size(self): - """the number of results to show in one page""" - return self._page_size - - @page_size.setter - def page_size(self, value): - if value is None or not str(value).isnumeric(): - return - value = int(value) # convert str params to int - if value <= 0 or value > 500: - raise get_api_exception("Page size must be between 0 & 500.", 400) - self._page_size = value - - @property - def page(self): - """the current page number being served""" - return self._page - - @page.setter - def page(self, value): - if value is None or not str(value).isnumeric(): - value = 1 - value = int(value) # convert str params to int - if value <= 0: - raise get_api_exception("Page must be greater than 0.", 400) - self._page = value + self.page_size = 20 + self.page = 1 def get_paginated_response(self, data): return Response( diff --git a/api/catalog/api/views/media_views.py b/api/catalog/api/views/media_views.py index aefcd15c5..151551d48 100644 --- a/api/catalog/api/views/media_views.py +++ b/api/catalog/api/views/media_views.py @@ -11,7 +11,9 @@ from rest_framework.response import Response from rest_framework.viewsets import ReadOnlyModelViewSet -from catalog.api.controllers import search_controller +from catalog.api.controllers.elasticsearch.related import related_media +from catalog.api.controllers.elasticsearch.search import perform_search +from catalog.api.controllers.elasticsearch.stats import get_stats 
from catalog.api.models import ContentProvider from catalog.api.serializers.provider_serializers import ProviderSerializer from catalog.api.utils.exceptions import get_api_exception @@ -52,28 +54,21 @@ def get_queryset(self): # Standard actions def list(self, request, *_, **__): - self.paginator.page_size = request.query_params.get("page_size") - page_size = self.paginator.page_size - self.paginator.page = request.query_params.get("page") - page = self.paginator.page - params = self.query_serializer_class(data=request.query_params) params.is_valid(raise_exception=True) + self.paginator.page_size = params.validated_data["page_size"] + self.paginator.page = params.validated_data["page"] + hashed_ip = hash(self._get_user_ip(request)) qa = params.validated_data["qa"] - filter_dead = params.validated_data["filter_dead"] search_index = self.qa_index if qa else self.default_index try: - results, num_pages, num_results = search_controller.search( + results, num_pages, num_results = perform_search( params, search_index, - page_size, hashed_ip, - request, - filter_dead, - page, ) self.paginator.page_count = num_pages self.paginator.result_count = num_results @@ -87,7 +82,7 @@ def list(self, request, *_, **__): @action(detail=False, serializer_class=ProviderSerializer, pagination_class=None) def stats(self, *_, **__): - source_counts = search_controller.get_sources(self.default_index) + source_counts = get_stats(self.default_index) context = self.get_serializer_context() | { "source_counts": source_counts, } @@ -101,10 +96,9 @@ def stats(self, *_, **__): @action(detail=True) def related(self, request, identifier=None, *_, **__): try: - results, num_results = search_controller.related_media( + results, num_results = related_media( uuid=identifier, index=self.default_index, - request=request, filter_dead=True, ) self.paginator.result_count = num_results diff --git a/api/catalog/settings.py b/api/catalog/settings.py index 802027561..9a70b339e 100644 --- a/api/catalog/settings.py +++ b/api/catalog/settings.py @@ -14,7 +14,10 @@ from socket import gethostbyname, gethostname import sentry_sdk +from aws_requests_auth.aws_auth import AWSRequestsAuth from decouple import config +from elasticsearch import Elasticsearch, RequestsHttpConnection +from elasticsearch_dsl import connections from sentry_sdk.integrations.django import DjangoIntegration from catalog.logger import LOGGING as LOGGING_CONF @@ -340,3 +343,38 @@ send_default_pii=False, environment=ENVIRONMENT, ) + + +# Elasticsearch connection + + +def _elasticsearch_connect(): + """ + Connect to configured Elasticsearch domain. + + :return: An Elasticsearch connection object. 
+ """ + auth = AWSRequestsAuth( + aws_access_key=AWS_ACCESS_KEY_ID, + aws_secret_access_key=AWS_SECRET_ACCESS_KEY, + aws_host=ELASTICSEARCH_URL, + aws_region=ELASTICSEARCH_AWS_REGION, + aws_service="es", + ) + auth.encode = lambda x: bytes(x.encode("utf-8")) + _es = Elasticsearch( + host=ELASTICSEARCH_URL, + port=ELASTICSEARCH_PORT, + connection_class=RequestsHttpConnection, + timeout=10, + max_retries=1, + retry_on_timeout=True, + http_auth=auth, + wait_for_status="yellow", + ) + _es.info() + return _es + + +ES = _elasticsearch_connect() +connections.add_connection("default", ES) diff --git a/load_sample_data.sh b/load_sample_data.sh index b22a15c1b..751821d3c 100755 --- a/load_sample_data.sh +++ b/load_sample_data.sh @@ -109,5 +109,6 @@ just ingest-upstream "image" just wait-for-index "image" # Clear source cache since it's out of date after data has been loaded +# See `api/catalog/api/controllers/elasticsearch/stats.py` docker-compose exec -T "$CACHE_SERVICE_NAME" /bin/bash -c "echo \"del :1:sources-image\" | redis-cli" docker-compose exec -T "$CACHE_SERVICE_NAME" /bin/bash -c "echo \"del :1:sources-audio\" | redis-cli"