diff --git a/api/catalog/api/controllers/elasticsearch/__init__.py b/api/catalog/api/controllers/elasticsearch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/api/catalog/api/controllers/elasticsearch/related.py b/api/catalog/api/controllers/elasticsearch/related.py deleted file mode 100644 index df28513b6..000000000 --- a/api/catalog/api/controllers/elasticsearch/related.py +++ /dev/null @@ -1,40 +0,0 @@ -from elasticsearch_dsl import Search - -from catalog.api.controllers.elasticsearch.utils import ( - exclude_filtered_providers, - get_query_slice, - get_result_and_page_count, - post_process_results, -) - - -def related_media(uuid, index, filter_dead): - """ - Given a UUID, find related search results. - """ - search_client = Search(using="default", index=index) - - # Convert UUID to sequential ID. - item = search_client.query("match", identifier=uuid) - _id = item.execute().hits[0].id - - s = search_client.query( - "more_like_this", - fields=["tags.name", "title", "creator"], - like={"_index": index, "_id": _id}, - min_term_freq=1, - max_query_terms=50, - ) - # Never show mature content in recommendations. - s = s.exclude("term", mature=True) - s = exclude_filtered_providers(s) - page_size = 10 - page = 1 - start, end = get_query_slice(s, page_size, page, filter_dead) - s = s[start:end] - response = s.execute() - results = post_process_results(s, start, end, page_size, response, filter_dead) - - result_count, _ = get_result_and_page_count(response, results, page_size) - - return results, result_count diff --git a/api/catalog/api/controllers/elasticsearch/search.py b/api/catalog/api/controllers/elasticsearch/search.py deleted file mode 100644 index 2c2b7cebf..000000000 --- a/api/catalog/api/controllers/elasticsearch/search.py +++ /dev/null @@ -1,198 +0,0 @@ -from __future__ import annotations - -import json -import logging as log -import pprint -from typing import List, Literal, Tuple, Union - -from django.conf import settings - -from elasticsearch.exceptions import RequestError -from elasticsearch_dsl import Q, Search -from elasticsearch_dsl.response import Hit - -from catalog.api.controllers.elasticsearch.utils import ( - exclude_filtered_providers, - get_query_slice, - get_result_and_page_count, - post_process_results, -) -from catalog.api.serializers.media_serializers import MediaSearchRequestSerializer - - -def _quote_escape(query_string: str) -> str: - """ - If there are any unmatched quotes in the query supplied by the user, ignore - them by escaping. - - :param query_string: the string in which to escape unbalanced quotes - :return: the given string, if the quotes are balanced, the escaped string otherwise - """ - - num_quotes = query_string.count('"') - if num_quotes % 2 == 1: - return query_string.replace('"', '\\"') - else: - return query_string - - -def _apply_filter( - s: Search, - query_serializer: MediaSearchRequestSerializer, - basis: Union[str, tuple[str, str]], - behaviour: Literal["filter", "exclude"] = "filter", -) -> Search: - """ - Parse and apply a filter from the search parameters serializer. The - parameter key is assumed to have the same name as the corresponding - Elasticsearch property. Each parameter value is assumed to be a comma - separated list encoded as a string. - - :param s: the search query to issue to Elasticsearch - :param query_serializer: the ``MediaSearchRequestSerializer`` object with the query - :param basis: the name of the field in the serializer and Elasticsearch - :param behaviour: whether to accept (``filter``) or reject (``exclude``) the hit - :return: the modified search query - """ - - search_params = query_serializer.data - if isinstance(basis, tuple): - serializer_field, es_field = basis - else: - serializer_field = es_field = basis - if serializer_field in search_params: - filters = [] - for arg in search_params[serializer_field].split(","): - filters.append(Q("term", **{es_field: arg})) - method = getattr(s, behaviour) # can be ``s.filter`` or ``s.exclude`` - return method("bool", should=filters) - else: - return s - - -def perform_search( - query_serializer: MediaSearchRequestSerializer, - index: Literal["image", "audio"], - ip: int, -) -> Tuple[List[Hit], int, int]: - """ - Perform a ranked, paginated search based on the query and filters given in the - search request. - - :param query_serializer: the ``MediaSearchRequestSerializer`` object with the query - :param index: The Elasticsearch index to search (e.g. 'image') - :param ip: the users' hashed IP to consistently route to the same ES shard - :return: the list of search results with the page and result count - """ - - s = Search(using="default", index=index) - search_params = query_serializer.data - - rules: dict[Literal["filter", "exclude"], list[Union[str, tuple[str, str]]]] = { - "filter": [ - "extension", - "category", - ("categories", "category"), - "aspect_ratio", - "size", - "length", - "source", - ("license", "license.keyword"), - ("license_type", "license.keyword"), - ], - "exclude": [ - ("excluded_source", "source"), - ], - } - for behaviour, bases in rules.items(): - for basis in bases: - s = _apply_filter(s, query_serializer, basis, behaviour) - - # Exclude mature content - if not search_params["mature"]: - s = s.exclude("term", mature=True) - # Exclude sources with ``filter_content`` enabled - s = exclude_filtered_providers(s) - - # Search either by generic multimatch or by "advanced search" with - # individual field-level queries specified. - - search_fields = ["tags.name", "title", "description"] - if "q" in search_params: - query = _quote_escape(search_params["q"]) - s = s.query( - "simple_query_string", - query=query, - fields=search_fields, - default_operator="AND", - ) - # Boost exact matches - quotes_stripped = query.replace('"', "") - exact_match_boost = Q( - "simple_query_string", - fields=["title"], - query=f'"{quotes_stripped}"', - boost=10000, - ) - s.query = Q("bool", must=s.query, should=exact_match_boost) - else: - query_bases = ["creator", "title", ("tags", "tags.name")] - for query_basis in query_bases: - if isinstance(query_basis, tuple): - serializer_field, es_field = query_basis - else: - serializer_field = es_field = query_basis - if serializer_field in search_params: - value = _quote_escape(search_params[serializer_field]) - s = s.query("simple_query_string", fields=[es_field], query=value) - - if settings.USE_RANK_FEATURES: - feature_boost = {"standardized_popularity": 10000} - rank_queries = [] - for field, boost in feature_boost.items(): - rank_queries.append(Q("rank_feature", field=field, boost=boost)) - s.query = Q("bool", must=s.query, should=rank_queries) - - # Use highlighting to determine which fields contribute to the selection of - # top results. - s = s.highlight(*search_fields) - s = s.highlight_options(order="score") - - # Route users to the same Elasticsearch worker node to reduce - # pagination inconsistencies and increase cache hits. - s = s.params(preference=str(ip), request_timeout=7) - - # Paginate - start, end = get_query_slice( - s, - search_params["page_size"], - search_params["page"], - search_params["filter_dead"], - ) - s = s[start:end] - - try: - if settings.VERBOSE_ES_RESPONSE: - log.info(pprint.pprint(s.to_dict())) - search_response = s.execute() - log.info( - f"query={json.dumps(s.to_dict())}," f" es_took_ms={search_response.took}" - ) - if settings.VERBOSE_ES_RESPONSE: - log.info(pprint.pprint(search_response.to_dict())) - except RequestError as e: - raise ValueError(e) - - results = post_process_results( - s, - start, - end, - search_params["page_size"], - search_response, - search_params["filter_dead"], - ) - - result_count, page_count = get_result_and_page_count( - search_response, results, search_params["page_size"] - ) - return results, page_count, result_count diff --git a/api/catalog/api/controllers/elasticsearch/stats.py b/api/catalog/api/controllers/elasticsearch/stats.py deleted file mode 100644 index f37e54812..000000000 --- a/api/catalog/api/controllers/elasticsearch/stats.py +++ /dev/null @@ -1,51 +0,0 @@ -import logging as log -from typing import Literal - -from django.core.cache import cache - -from elasticsearch.exceptions import NotFoundError -from elasticsearch_dsl import Search - - -SOURCE_CACHE_TIMEOUT = 60 * 20 # seconds - - -def get_stats(index: Literal["image", "audio"]): - """ - Given an index, find all available data sources and return their counts. This data - is cached in Redis. See ``load_sample_data.sh`` for example of clearing the cache. - - :param index: the Elasticsearch index name - :return: a dictionary mapping sources to the count of their media items - """ - - source_cache_name = "sources-" + index - try: - sources = cache.get(key=source_cache_name) - if sources is not None: - return sources - except ValueError: - log.warning("Source cache fetch failed") - - # Don't increase `size` without reading this issue first: - # https://github.com/elastic/elasticsearch/issues/18838 - size = 100 - try: - s = Search(using="default", index=index) - s.aggs.bucket( - "unique_sources", - "terms", - field="source.keyword", - size=size, - order={"_key": "desc"}, - ) - results = s.execute() - buckets = results["aggregations"]["unique_sources"]["buckets"] - sources = {result["key"]: result["doc_count"] for result in buckets} - except NotFoundError: - sources = {} - - if sources: - cache.set(key=source_cache_name, timeout=SOURCE_CACHE_TIMEOUT, value=sources) - - return sources diff --git a/api/catalog/api/controllers/elasticsearch/utils.py b/api/catalog/api/controllers/elasticsearch/utils.py deleted file mode 100644 index cb0bb6e12..000000000 --- a/api/catalog/api/controllers/elasticsearch/utils.py +++ /dev/null @@ -1,160 +0,0 @@ -from itertools import accumulate -from math import ceil -from typing import List, Optional, Tuple - -from django.core.cache import cache - -from elasticsearch_dsl import Search -from elasticsearch_dsl.response import Hit, Response - -from catalog.api.models import ContentProvider -from catalog.api.utils.dead_link_mask import get_query_hash, get_query_mask -from catalog.api.utils.validate_images import validate_images - - -FILTER_CACHE_TIMEOUT = 30 -DEAD_LINK_RATIO = 1 / 2 -ELASTICSEARCH_MAX_RESULT_WINDOW = 10000 - - -def exclude_filtered_providers(s: Search) -> Search: - """ - Hide data sources from the catalog dynamically. This excludes providers with - ``filter_content`` enabled from the search results. - - :param s: the search query to issue to Elasticsearch - :return: the modified search query - """ - - filter_cache_key = "filtered_providers" - filtered_providers = cache.get(key=filter_cache_key) - if filtered_providers is None: - filtered_providers = ContentProvider.objects.filter(filter_content=True).values( - "provider_identifier" - ) - cache.set( - key=filter_cache_key, - timeout=FILTER_CACHE_TIMEOUT, - value=filtered_providers, - ) - if len(filtered_providers) != 0: - to_exclude = [f["provider_identifier"] for f in filtered_providers] - s = s.exclude("terms", provider=to_exclude) - return s - - -def paginate_with_dead_link_mask( - s: Search, page_size: int, page: int -) -> Tuple[int, int]: - """ - Given a query, a page and page_size, return the start and end - of the slice of results. - - :param s: The elasticsearch Search object - :param page_size: How big the page should be. - :param page: The page number. - :return: Tuple of start and end. - """ - query_hash = get_query_hash(s) - query_mask = get_query_mask(query_hash) - if not query_mask: - start = 0 - end = ceil(page_size * page / (1 - DEAD_LINK_RATIO)) - elif page_size * (page - 1) > sum(query_mask): - start = len(query_mask) - end = ceil(page_size * page / (1 - DEAD_LINK_RATIO)) - else: - accu_query_mask = list(accumulate(query_mask)) - start = 0 - if page > 1: - try: - start = accu_query_mask.index(page_size * (page - 1) + 1) - except ValueError: - start = accu_query_mask.index(page_size * (page - 1)) + 1 - if page_size * page > sum(query_mask): - end = ceil(page_size * page / (1 - DEAD_LINK_RATIO)) - else: - end = accu_query_mask.index(page_size * page) + 1 - return start, end - - -def get_query_slice( - s: Search, page_size: int, page: int, filter_dead: Optional[bool] = False -) -> Tuple[int, int]: - """ - Select the start and end of the search results for this query. - """ - if filter_dead: - start_slice, end_slice = paginate_with_dead_link_mask(s, page_size, page) - else: - # Paginate search query. - start_slice = page_size * (page - 1) - end_slice = page_size * page - if start_slice + end_slice > ELASTICSEARCH_MAX_RESULT_WINDOW: - raise ValueError("Deep pagination is not allowed.") - return start_slice, end_slice - - -def post_process_results( - s, start, end, page_size, search_results, filter_dead -) -> List[Hit]: - """ - After fetching the search results from the back end, iterate through the - results, perform image validation, and route certain thumbnails through our - proxy. - - :param s: The Elasticsearch Search object. - :param start: The start of the result slice. - :param end: The end of the result slice. - :param search_results: The Elasticsearch response object containing search - results. - :param filter_dead: Whether images should be validated. - :return: List of results. - """ - results = [] - to_validate = [] - for res in search_results: - if hasattr(res.meta, "highlight"): - res.fields_matched = dir(res.meta.highlight) - to_validate.append(res.url) - results.append(res) - - if filter_dead: - query_hash = get_query_hash(s) - validate_images(query_hash, start, results, to_validate) - - if len(results) < page_size: - end += int(end / 2) - if start + end > ELASTICSEARCH_MAX_RESULT_WINDOW: - return results - - s = s[start:end] - search_response = s.execute() - - return post_process_results( - s, start, end, page_size, search_response, filter_dead - ) - return results[:page_size] - - -def get_result_and_page_count( - response_obj: Response, results: List[Hit], page_size: int -) -> Tuple[int, int]: - """ - Elasticsearch does not allow deep pagination of ranked queries. - Adjust returned page count to reflect this. - - :param response_obj: The original Elasticsearch response object. - :param results: The list of filtered result Hits. - :return: Result and page count. - """ - result_count = response_obj.hits.total.value - natural_page_count = int(result_count / page_size) - if natural_page_count % page_size != 0: - natural_page_count += 1 - last_allowed_page = int((5000 + page_size / 2) / page_size) - page_count = min(natural_page_count, last_allowed_page) - if len(results) < page_size and page_count == 0: - result_count = len(results) - - return result_count, page_count diff --git a/api/catalog/api/controllers/search_controller.py b/api/catalog/api/controllers/search_controller.py new file mode 100644 index 000000000..0a9fb9801 --- /dev/null +++ b/api/catalog/api/controllers/search_controller.py @@ -0,0 +1,466 @@ +from __future__ import annotations + +import json +import logging as log +import pprint +from itertools import accumulate +from math import ceil +from typing import List, Literal, Optional, Tuple + +from django.conf import settings +from django.core.cache import cache +from rest_framework.request import Request + +from aws_requests_auth.aws_auth import AWSRequestsAuth +from elasticsearch import Elasticsearch, RequestsHttpConnection +from elasticsearch.exceptions import NotFoundError, RequestError +from elasticsearch_dsl import Q, Search, connections +from elasticsearch_dsl.query import Query +from elasticsearch_dsl.response import Hit, Response + +import catalog.api.models as models +from catalog.api.serializers.media_serializers import MediaSearchRequestSerializer +from catalog.api.utils.dead_link_mask import get_query_hash, get_query_mask +from catalog.api.utils.validate_images import validate_images + + +ELASTICSEARCH_MAX_RESULT_WINDOW = 10000 +SOURCE_CACHE_TIMEOUT = 60 * 20 +FILTER_CACHE_TIMEOUT = 30 +DEAD_LINK_RATIO = 1 / 2 +THUMBNAIL = "thumbnail" +URL = "url" +PROVIDER = "provider" +DEEP_PAGINATION_ERROR = "Deep pagination is not allowed." +QUERY_SPECIAL_CHARACTER_ERROR = "Unescaped special characters are not allowed." + + +class RankFeature(Query): + name = "rank_feature" + + +def _paginate_with_dead_link_mask( + s: Search, page_size: int, page: int +) -> Tuple[int, int]: + """ + Given a query, a page and page_size, return the start and end + of the slice of results. + + :param s: The elasticsearch Search object + :param page_size: How big the page should be. + :param page: The page number. + :return: Tuple of start and end. + """ + query_hash = get_query_hash(s) + query_mask = get_query_mask(query_hash) + if not query_mask: + start = 0 + end = ceil(page_size * page / (1 - DEAD_LINK_RATIO)) + elif page_size * (page - 1) > sum(query_mask): + start = len(query_mask) + end = ceil(page_size * page / (1 - DEAD_LINK_RATIO)) + else: + accu_query_mask = list(accumulate(query_mask)) + start = 0 + if page > 1: + try: + start = accu_query_mask.index(page_size * (page - 1) + 1) + except ValueError: + start = accu_query_mask.index(page_size * (page - 1)) + 1 + if page_size * page > sum(query_mask): + end = ceil(page_size * page / (1 - DEAD_LINK_RATIO)) + else: + end = accu_query_mask.index(page_size * page) + 1 + return start, end + + +def _get_query_slice( + s: Search, page_size: int, page: int, filter_dead: Optional[bool] = False +) -> Tuple[int, int]: + """ + Select the start and end of the search results for this query. + """ + if filter_dead: + start_slice, end_slice = _paginate_with_dead_link_mask(s, page_size, page) + else: + # Paginate search query. + start_slice = page_size * (page - 1) + end_slice = page_size * page + if start_slice + end_slice > ELASTICSEARCH_MAX_RESULT_WINDOW: + raise ValueError(DEEP_PAGINATION_ERROR) + return start_slice, end_slice + + +def _quote_escape(query_string): + """ + If there are any unmatched quotes in the query supplied by the user, ignore + them. + """ + num_quotes = query_string.count('"') + if num_quotes % 2 == 1: + return query_string.replace('"', '\\"') + else: + return query_string + + +def _post_process_results( + s, start, end, page_size, search_results, request, filter_dead +) -> List[Hit]: + """ + After fetching the search results from the back end, iterate through the + results, perform image validation, and route certain thumbnails through our + proxy. + + :param s: The Elasticsearch Search object. + :param start: The start of the result slice. + :param end: The end of the result slice. + :param search_results: The Elasticsearch response object containing search + results. + :param request: The Django request object, used to build a "reversed" URL + to detail pages. + :param filter_dead: Whether images should be validated. + :return: List of results. + """ + results = [] + to_validate = [] + for res in search_results: + if hasattr(res.meta, "highlight"): + res.fields_matched = dir(res.meta.highlight) + to_validate.append(res.url) + results.append(res) + + if filter_dead: + query_hash = get_query_hash(s) + validate_images(query_hash, start, results, to_validate) + + if len(results) < page_size: + end += int(end / 2) + if start + end > ELASTICSEARCH_MAX_RESULT_WINDOW: + return results + + s = s[start:end] + search_response = s.execute() + + return _post_process_results( + s, start, end, page_size, search_response, request, filter_dead + ) + return results[:page_size] + + +def _apply_filter( + s: Search, + search_params: MediaSearchRequestSerializer, + serializer_field: str, + es_field: Optional[str] = None, + behaviour: Literal["filter", "exclude"] = "filter", +): + """ + Parse and apply a filter from the search parameters serializer. The + parameter key is assumed to have the same name as the corresponding + Elasticsearch property. Each parameter value is assumed to be a comma + separated list encoded as a string. + + :param s: The ``Search`` instance to apply the filter to + :param search_params: the serializer instance containing user input + :param serializer_field: the name of the parameter field in ``search_params`` + :param es_field: the corresponding parameter name in Elasticsearch + :param behaviour: whether to accept (``filter``) or reject (``exclude``) the hit + :return: the input ``Search`` object with the filters applied + """ + + if serializer_field in search_params.data: + filters = [] + for arg in search_params.data[serializer_field].split(","): + _param = es_field or serializer_field + args = {"name_or_query": "term", _param: arg} + filters.append(Q(**args)) + method = getattr(s, behaviour) + return method("bool", should=filters) + else: + return s + + +def _exclude_filtered(s: Search): + """ + Hide data sources from the catalog dynamically. + """ + filter_cache_key = "filtered_providers" + filtered_providers = cache.get(key=filter_cache_key) + if not filtered_providers: + filtered_providers = models.ContentProvider.objects.filter( + filter_content=True + ).values("provider_identifier") + cache.set( + key=filter_cache_key, timeout=FILTER_CACHE_TIMEOUT, value=filtered_providers + ) + to_exclude = [f["provider_identifier"] for f in filtered_providers] + s = s.exclude("terms", provider=to_exclude) + return s + + +def _exclude_mature_by_param(s: Search, search_params): + if not search_params.data["mature"]: + s = s.exclude("term", mature=True) + return s + + +def search( + search_params: MediaSearchRequestSerializer, + index: Literal["image", "audio"], + page_size: int, + ip: int, + request: Request, + filter_dead: bool, + page: int = 1, +) -> Tuple[List[Hit], int, int]: + """ + Given a set of keywords and an optional set of filters, perform a ranked + paginated search. + + :param search_params: Search parameters. See + :class: `ImageSearchQueryStringSerializer`. + :param index: The Elasticsearch index to search (e.g. 'image') + :param page_size: The number of results to return per page. + :param ip: The user's hashed IP. Hashed IPs are used to anonymously but + uniquely identify users exclusively for ensuring query consistency across + Elasticsearch shards. + :param request: Django's request object. + :param filter_dead: Whether dead links should be removed. + :param page: The results page number. + :return: Tuple with a List of Hits from elasticsearch, the total count of + pages, and number of results. + """ + search_client = Search(index=index) + + s = search_client + # Apply term filters. Each tuple pairs a filter's parameter name in the API + # with its corresponding field in Elasticsearch. "None" means that the + # names are identical. + filters = [ + ("extension", None), + ("category", None), + ("categories", "category"), + ("length", None), + ("aspect_ratio", None), + ("size", None), + ("source", None), + ("license", "license__keyword"), + ("license_type", "license__keyword"), + ] + for serializer_field, es_field in filters: + if serializer_field in search_params.data: + s = _apply_filter(s, search_params, serializer_field, es_field) + + exclude = [ + ("excluded_source", "source"), + ] + for serializer_field, es_field in exclude: + if serializer_field in search_params.data: + s = _apply_filter(s, search_params, serializer_field, es_field, "exclude") + + # Exclude mature content and disabled sources + s = _exclude_mature_by_param(s, search_params) + s = _exclude_filtered(s) + + # Search either by generic multimatch or by "advanced search" with + # individual field-level queries specified. + search_fields = ["tags.name", "title", "description"] + if "q" in search_params.data: + query = _quote_escape(search_params.data["q"]) + s = s.query( + "simple_query_string", + query=query, + fields=search_fields, + default_operator="AND", + ) + # Boost exact matches + quotes_stripped = query.replace('"', "") + exact_match_boost = Q( + "simple_query_string", + fields=["title"], + query=f'"{quotes_stripped}"', + boost=10000, + ) + s = search_client.query(Q("bool", must=s.query, should=exact_match_boost)) + else: + if "creator" in search_params.data: + creator = _quote_escape(search_params.data["creator"]) + s = s.query("simple_query_string", query=creator, fields=["creator"]) + if "title" in search_params.data: + title = _quote_escape(search_params.data["title"]) + s = s.query("simple_query_string", query=title, fields=["title"]) + if "tags" in search_params.data: + tags = _quote_escape(search_params.data["tags"]) + s = s.query("simple_query_string", fields=["tags.name"], query=tags) + + if settings.USE_RANK_FEATURES: + feature_boost = {"standardized_popularity": 10000} + rank_queries = [] + for field, boost in feature_boost.items(): + rank_queries.append(Q("rank_feature", field=field, boost=boost)) + s = search_client.query(Q("bool", must=s.query, should=rank_queries)) + + # Use highlighting to determine which fields contribute to the selection of + # top results. + s = s.highlight(*search_fields) + s = s.highlight_options(order="score") + s.extra(track_scores=True) + # Route users to the same Elasticsearch worker node to reduce + # pagination inconsistencies and increase cache hits. + s = s.params(preference=str(ip), request_timeout=7) + # Paginate + start, end = _get_query_slice(s, page_size, page, filter_dead) + s = s[start:end] + try: + if settings.VERBOSE_ES_RESPONSE: + log.info(pprint.pprint(s.to_dict())) + search_response = s.execute() + log.info( + f"query={json.dumps(s.to_dict())}," f" es_took_ms={search_response.took}" + ) + if settings.VERBOSE_ES_RESPONSE: + log.info(pprint.pprint(search_response.to_dict())) + except RequestError as e: + raise ValueError(e) + results = _post_process_results( + s, start, end, page_size, search_response, request, filter_dead + ) + + result_count, page_count = _get_result_and_page_count( + search_response, results, page_size + ) + return results, page_count, result_count + + +def related_media(uuid, index, request, filter_dead): + """ + Given a UUID, find related search results. + """ + search_client = Search(index=index) + + # Convert UUID to sequential ID. + item = search_client + item = item.query("match", identifier=uuid) + _id = item.execute().hits[0].id + + s = search_client + s = s.query( + "more_like_this", + fields=["tags.name", "title", "creator"], + like={"_index": index, "_id": _id}, + min_term_freq=1, + max_query_terms=50, + ) + # Never show mature content in recommendations. + s = s.exclude("term", mature=True) + s = _exclude_filtered(s) + page_size = 10 + page = 1 + start, end = _get_query_slice(s, page_size, page, filter_dead) + s = s[start:end] + response = s.execute() + results = _post_process_results( + s, start, end, page_size, response, request, filter_dead + ) + + result_count, _ = _get_result_and_page_count(response, results, page_size) + + return results, result_count + + +def get_sources(index): + """ + Given an index, find all available data sources and return their counts. + + :param index: An Elasticsearch index, such as `'image'`. + :return: A dictionary mapping sources to the count of their images.` + """ + source_cache_name = "sources-" + index + cache_fetch_failed = False + try: + sources = cache.get(key=source_cache_name) + except ValueError: + cache_fetch_failed = True + sources = None + log.warning("Source cache fetch failed due to corruption") + if type(sources) == list or cache_fetch_failed: + # Invalidate old provider format. + cache.delete(key=source_cache_name) + if not sources: + # Don't increase `size` without reading this issue first: + # https://github.com/elastic/elasticsearch/issues/18838 + size = 100 + agg_body = { + "aggs": { + "unique_sources": { + "terms": { + "field": "source.keyword", + "size": size, + "order": {"_key": "desc"}, + } + } + } + } + try: + results = es.search(index=index, body=agg_body, request_cache=True) + buckets = results["aggregations"]["unique_sources"]["buckets"] + except NotFoundError: + buckets = [{"key": "none_found", "doc_count": 0}] + sources = {result["key"]: result["doc_count"] for result in buckets} + cache.set(key=source_cache_name, timeout=SOURCE_CACHE_TIMEOUT, value=sources) + return sources + + +def _elasticsearch_connect(): + """ + Connect to configured Elasticsearch domain. + + :return: An Elasticsearch connection object. + """ + auth = AWSRequestsAuth( + aws_access_key=settings.AWS_ACCESS_KEY_ID, + aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY, + aws_host=settings.ELASTICSEARCH_URL, + aws_region=settings.ELASTICSEARCH_AWS_REGION, + aws_service="es", + ) + auth.encode = lambda x: bytes(x.encode("utf-8")) + _es = Elasticsearch( + host=settings.ELASTICSEARCH_URL, + port=settings.ELASTICSEARCH_PORT, + connection_class=RequestsHttpConnection, + timeout=10, + max_retries=1, + retry_on_timeout=True, + http_auth=auth, + wait_for_status="yellow", + ) + _es.info() + return _es + + +es = _elasticsearch_connect() +connections.connections.add_connection("default", es) + + +def _get_result_and_page_count( + response_obj: Response, results: List[Hit], page_size: int +) -> Tuple[int, int]: + """ + Elasticsearch does not allow deep pagination of ranked queries. + Adjust returned page count to reflect this. + + :param response_obj: The original Elasticsearch response object. + :param results: The list of filtered result Hits. + :return: Result and page count. + """ + result_count = response_obj.hits.total.value + natural_page_count = int(result_count / page_size) + if natural_page_count % page_size != 0: + natural_page_count += 1 + last_allowed_page = int((5000 + page_size / 2) / page_size) + page_count = min(natural_page_count, last_allowed_page) + if len(results) < page_size and page_count == 0: + result_count = len(results) + + return result_count, page_count diff --git a/api/catalog/api/models/audio.py b/api/catalog/api/models/audio.py index 77171f74e..9e8e269d1 100644 --- a/api/catalog/api/models/audio.py +++ b/api/catalog/api/models/audio.py @@ -1,9 +1,9 @@ -from django.conf import settings from django.contrib.postgres.fields import ArrayField from django.db import models from uuslug import uuslug +import catalog.api.controllers.search_controller as search_controller from catalog.api.models import OpenLedgerModel from catalog.api.models.media import ( AbstractAltFile, @@ -250,7 +250,7 @@ class MatureAudio(AbstractMatureMedia): """Stores all audios that have been flagged as 'mature'.""" def delete(self, *args, **kwargs): - es = settings.ES + es = search_controller.es aud = Audio.objects.get(identifier=self.identifier) es_id = aud.id es.update(index="audio", id=es_id, body={"doc": {"mature": False}}) diff --git a/api/catalog/api/models/image.py b/api/catalog/api/models/image.py index a848a9470..009fde502 100644 --- a/api/catalog/api/models/image.py +++ b/api/catalog/api/models/image.py @@ -1,8 +1,8 @@ -from django.conf import settings from django.db import models from uuslug import uuslug +import catalog.api.controllers.search_controller as search_controller from catalog.api.models.media import ( AbstractDeletedMedia, AbstractMatureMedia, @@ -77,7 +77,7 @@ class MatureImage(AbstractMatureMedia): """Stores all images that have been flagged as 'mature'.""" def delete(self, *args, **kwargs): - es = settings.ES + es = search_controller.es img = Image.objects.get(identifier=self.identifier) es_id = img.id es.update(index="image", id=es_id, body={"doc": {"mature": False}}) diff --git a/api/catalog/api/models/media.py b/api/catalog/api/models/media.py index 270e0dc46..7ca538b36 100644 --- a/api/catalog/api/models/media.py +++ b/api/catalog/api/models/media.py @@ -1,10 +1,10 @@ import mimetypes -from django.conf import settings from django.contrib.postgres.fields import ArrayField from django.db import models from django.utils.html import format_html +import catalog.api.controllers.search_controller as search_controller from catalog.api.models.base import OpenLedgerModel from catalog.api.models.mixins import ( ForeignIdentifierMixin, @@ -189,7 +189,7 @@ def save(self, *args, **kwargs): update_required = {MATURE_FILTERED, DEINDEXED} # ES needs updating if self.status in update_required: - es = settings.ES + es = search_controller.es try: media = media_class.objects.get(identifier=self.identifier) except media_class.DoesNotExist: diff --git a/api/catalog/api/serializers/media_serializers.py b/api/catalog/api/serializers/media_serializers.py index 875f0ca35..4d10c0583 100644 --- a/api/catalog/api/serializers/media_serializers.py +++ b/api/catalog/api/serializers/media_serializers.py @@ -3,7 +3,7 @@ from rest_framework import serializers from catalog.api.constants.licenses import LICENSE_GROUPS -from catalog.api.controllers.elasticsearch.stats import get_stats +from catalog.api.controllers import search_controller from catalog.api.models.media import AbstractMedia from catalog.api.serializers.base import BaseModelSerializer from catalog.api.serializers.fields import SchemableHyperlinkedIdentityField @@ -104,22 +104,11 @@ class MediaSearchRequestSerializer(serializers.Serializer): required=False, default=False, ) - page = serializers.IntegerField( - min_value=1, - default=1, - help_text="The index of the page of the results to show.", - ) - page_size = serializers.IntegerField( - min_value=1, - max_value=500, - default=20, - help_text="The number of results to show in one page.", - ) @staticmethod def _truncate(value): max_length = 200 - return value[:max_length] + return value if len(value) <= max_length else value[:max_length] def validate_q(self, value): return self._truncate(value) @@ -361,7 +350,7 @@ class MediaSearchRequestSourceSerializer(serializers.Serializer): _field_attrs = { "help_text": make_comma_separated_help_text( - get_stats(media_type).keys(), "data sources" + search_controller.get_sources(media_type).keys(), "data sources" ), "required": False, } @@ -379,7 +368,7 @@ class MediaSearchRequestSourceSerializer(serializers.Serializer): def validate_source_field(value): """Checks whether source is a valid source.""" - allowed_sources = list(get_stats(media_type).keys()) + allowed_sources = list(search_controller.get_sources(media_type).keys()) sources = value.lower().split(",") sources = [source for source in sources if source in allowed_sources] value = ",".join(sources) diff --git a/api/catalog/api/utils/pagination.py b/api/catalog/api/utils/pagination.py index cc813bbd5..484e455fe 100644 --- a/api/catalog/api/utils/pagination.py +++ b/api/catalog/api/utils/pagination.py @@ -1,17 +1,48 @@ from rest_framework.pagination import PageNumberPagination from rest_framework.response import Response +from catalog.api.utils.exceptions import get_api_exception + class StandardPagination(PageNumberPagination): - page_query_param = None + page_size_query_param = "page_size" + page_query_param = "page" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.result_count = None # populated later self.page_count = None # populated later - self.page_size = 20 - self.page = 1 + self._page_size = 20 + self._page = None + + @property + def page_size(self): + """the number of results to show in one page""" + return self._page_size + + @page_size.setter + def page_size(self, value): + if value is None or not str(value).isnumeric(): + return + value = int(value) # convert str params to int + if value <= 0 or value > 500: + raise get_api_exception("Page size must be between 0 & 500.", 400) + self._page_size = value + + @property + def page(self): + """the current page number being served""" + return self._page + + @page.setter + def page(self, value): + if value is None or not str(value).isnumeric(): + value = 1 + value = int(value) # convert str params to int + if value <= 0: + raise get_api_exception("Page must be greater than 0.", 400) + self._page = value def get_paginated_response(self, data): return Response( diff --git a/api/catalog/api/views/media_views.py b/api/catalog/api/views/media_views.py index 151551d48..aefcd15c5 100644 --- a/api/catalog/api/views/media_views.py +++ b/api/catalog/api/views/media_views.py @@ -11,9 +11,7 @@ from rest_framework.response import Response from rest_framework.viewsets import ReadOnlyModelViewSet -from catalog.api.controllers.elasticsearch.related import related_media -from catalog.api.controllers.elasticsearch.search import perform_search -from catalog.api.controllers.elasticsearch.stats import get_stats +from catalog.api.controllers import search_controller from catalog.api.models import ContentProvider from catalog.api.serializers.provider_serializers import ProviderSerializer from catalog.api.utils.exceptions import get_api_exception @@ -54,21 +52,28 @@ def get_queryset(self): # Standard actions def list(self, request, *_, **__): + self.paginator.page_size = request.query_params.get("page_size") + page_size = self.paginator.page_size + self.paginator.page = request.query_params.get("page") + page = self.paginator.page + params = self.query_serializer_class(data=request.query_params) params.is_valid(raise_exception=True) - self.paginator.page_size = params.validated_data["page_size"] - self.paginator.page = params.validated_data["page"] - hashed_ip = hash(self._get_user_ip(request)) qa = params.validated_data["qa"] + filter_dead = params.validated_data["filter_dead"] search_index = self.qa_index if qa else self.default_index try: - results, num_pages, num_results = perform_search( + results, num_pages, num_results = search_controller.search( params, search_index, + page_size, hashed_ip, + request, + filter_dead, + page, ) self.paginator.page_count = num_pages self.paginator.result_count = num_results @@ -82,7 +87,7 @@ def list(self, request, *_, **__): @action(detail=False, serializer_class=ProviderSerializer, pagination_class=None) def stats(self, *_, **__): - source_counts = get_stats(self.default_index) + source_counts = search_controller.get_sources(self.default_index) context = self.get_serializer_context() | { "source_counts": source_counts, } @@ -96,9 +101,10 @@ def stats(self, *_, **__): @action(detail=True) def related(self, request, identifier=None, *_, **__): try: - results, num_results = related_media( + results, num_results = search_controller.related_media( uuid=identifier, index=self.default_index, + request=request, filter_dead=True, ) self.paginator.result_count = num_results diff --git a/api/catalog/settings.py b/api/catalog/settings.py index 9a70b339e..802027561 100644 --- a/api/catalog/settings.py +++ b/api/catalog/settings.py @@ -14,10 +14,7 @@ from socket import gethostbyname, gethostname import sentry_sdk -from aws_requests_auth.aws_auth import AWSRequestsAuth from decouple import config -from elasticsearch import Elasticsearch, RequestsHttpConnection -from elasticsearch_dsl import connections from sentry_sdk.integrations.django import DjangoIntegration from catalog.logger import LOGGING as LOGGING_CONF @@ -343,38 +340,3 @@ send_default_pii=False, environment=ENVIRONMENT, ) - - -# Elasticsearch connection - - -def _elasticsearch_connect(): - """ - Connect to configured Elasticsearch domain. - - :return: An Elasticsearch connection object. - """ - auth = AWSRequestsAuth( - aws_access_key=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY, - aws_host=ELASTICSEARCH_URL, - aws_region=ELASTICSEARCH_AWS_REGION, - aws_service="es", - ) - auth.encode = lambda x: bytes(x.encode("utf-8")) - _es = Elasticsearch( - host=ELASTICSEARCH_URL, - port=ELASTICSEARCH_PORT, - connection_class=RequestsHttpConnection, - timeout=10, - max_retries=1, - retry_on_timeout=True, - http_auth=auth, - wait_for_status="yellow", - ) - _es.info() - return _es - - -ES = _elasticsearch_connect() -connections.add_connection("default", ES) diff --git a/load_sample_data.sh b/load_sample_data.sh index 751821d3c..b22a15c1b 100755 --- a/load_sample_data.sh +++ b/load_sample_data.sh @@ -109,6 +109,5 @@ just ingest-upstream "image" just wait-for-index "image" # Clear source cache since it's out of date after data has been loaded -# See `api/catalog/api/controllers/elasticsearch/stats.py` docker-compose exec -T "$CACHE_SERVICE_NAME" /bin/bash -c "echo \"del :1:sources-image\" | redis-cli" docker-compose exec -T "$CACHE_SERVICE_NAME" /bin/bash -c "echo \"del :1:sources-audio\" | redis-cli"