Add API routes and controllers for additional search views #2853

Merged · 16 commits · Nov 16, 2023
4 changes: 3 additions & 1 deletion api/api/constants/media_types.py
@@ -6,7 +6,9 @@
IMAGE_TYPE = "image"

MEDIA_TYPES = [AUDIO_TYPE, IMAGE_TYPE]
MediaType = Literal["audio", "image"]

MEDIA_TYPE_CHOICES = [(AUDIO_TYPE, "Audio"), (IMAGE_TYPE, "Image")]

OriginIndex = Literal["image", "audio"]
OriginIndex = MediaType
SearchIndex = Literal["image", "image-filtered", "audio", "audio-filtered"]
5 changes: 5 additions & 0 deletions api/api/constants/search.py
@@ -0,0 +1,5 @@
from typing import Literal


SEARCH_STRATEGIES = ["search", "collection"]
SearchStrategy = Literal["search", "collection"]
212 changes: 155 additions & 57 deletions api/api/controllers/search_controller.py
@@ -3,7 +3,7 @@
import logging
import logging as log
from math import ceil
from typing import Literal
from typing import TYPE_CHECKING

from django.conf import settings
from django.core.cache import cache
@@ -15,21 +15,35 @@
from elasticsearch_dsl.response import Hit, Response

import api.models as models
from api.constants.media_types import OriginIndex
from api.constants.media_types import OriginIndex, SearchIndex
from api.constants.search import SearchStrategy
from api.constants.sorting import INDEXED_ON
from api.controllers.elasticsearch.helpers import (
ELASTICSEARCH_MAX_RESULT_WINDOW,
get_es_response,
get_query_slice,
get_raw_es_response,
)
from api.serializers import media_serializers
from api.utils import tallies
from api.utils.check_dead_links import check_dead_links
from api.utils.dead_link_mask import get_query_hash
from api.utils.search_context import SearchContext


# Using TYPE_CHECKING to avoid circular imports when importing types
if TYPE_CHECKING:
from api.serializers.audio_serializers import AudioCollectionRequestSerializer
from api.serializers.media_serializers import (
MediaSearchRequestSerializer,
PaginatedRequestSerializer,
)

MediaListRequestSerializer = (
AudioCollectionRequestSerializer
| MediaSearchRequestSerializer
| PaginatedRequestSerializer
)

module_logger = logging.getLogger(__name__)


@@ -172,24 +186,24 @@ def get_excluded_providers_query() -> Q | None:
return None


def _resolve_index(
index: Literal["image", "audio"],
search_params: media_serializers.MediaSearchRequestSerializer,
) -> Literal["image", "image-filtered", "audio", "audio-filtered"]:
use_filtered_index = all(
(
settings.ENABLE_FILTERED_INDEX_QUERIES,
not search_params.validated_data["include_sensitive_results"],
)
)
if use_filtered_index:
return f"{index}-filtered"
def get_index(
exact_index: bool,
origin_index: OriginIndex,
search_params: MediaListRequestSerializer,
) -> SearchIndex:
if exact_index:
return origin_index

return index
include_sensitive_results = search_params.validated_data.get(
"include_sensitive_results", False
)
if settings.ENABLE_FILTERED_INDEX_QUERIES and not include_sensitive_results:
return f"{origin_index}-filtered"
return origin_index


def create_search_filter_queries(
search_params: media_serializers.MediaSearchRequestSerializer,
search_params: MediaListRequestSerializer,
) -> dict[str, list[Q]]:
"""
Create a list of Elasticsearch queries for filtering search results.
@@ -230,7 +244,7 @@ def create_search_filter_queries(


def create_ranking_queries(
search_params: media_serializers.MediaSearchRequestSerializer,
search_params: MediaListRequestSerializer,
) -> list[Q]:
queries = [Q("rank_feature", field="standardized_popularity", boost=DEFAULT_BOOST)]
if search_params.data["unstable__authority"]:
@@ -240,8 +254,8 @@ def create_ranking_queries(
return queries


def create_search_query(
search_params: media_serializers.MediaSearchRequestSerializer,
def build_search_query(
search_params: MediaListRequestSerializer,
) -> Q:
# Apply filters from the url query search parameters.
url_queries = create_search_filter_queries(search_params)
@@ -315,8 +329,60 @@ def create_search_query(
)


def search(
search_params: media_serializers.MediaSearchRequestSerializer,
def build_collection_query(
search_params: MediaListRequestSerializer,
collection_params: dict[str, str],
):
"""
Build the query to retrieve items in a collection.
:param collection_params: `tag`, `source` and/or `creator` values from the path.
:param search_params: the validated search parameters.
:return: the Elasticsearch query for the collection.
"""
search_query = {"filter": [], "must": [], "should": [], "must_not": []}
# Apply the term filters. Each tuple pairs a filter's parameter name in the API
# with its corresponding field in Elasticsearch. "None" means that the
# names are identical.
filters = [
# Collection filters allow a single value.
("tag", "tags.name.keyword"),
("source", None),
("creator", "creator.keyword"),
]
for serializer_field, es_field in filters:
if serializer_field in collection_params:
if not (argument := collection_params.get(serializer_field)):
continue
parameter = es_field or serializer_field
search_query["filter"].append({"term": {parameter: argument}})

# Exclude mature content and disabled sources
include_sensitive_by_params = search_params.validated_data.get(
"include_sensitive_results", False
)
if not include_sensitive_by_params:
search_query["must_not"].append({"term": {"mature": True}})

if excluded_providers_query := get_excluded_providers_query():
search_query["must_not"].append(excluded_providers_query)

return Q("bool", **search_query)


def build_query(
strategy: SearchStrategy,
search_params: MediaListRequestSerializer,
collection_params: dict[str, str] | None,
) -> Q:
if strategy == "collection":
return build_collection_query(search_params, collection_params)
return build_search_query(search_params)


def query_media(
strategy: SearchStrategy,
search_params: MediaListRequestSerializer,
collection_params: dict[str, str] | None,
origin_index: OriginIndex,
exact_index: bool,
page_size: int,
@@ -325,10 +391,17 @@ def search(
page: int = 1,
) -> tuple[list[Hit], int, int, dict]:
"""
Perform a ranked paginated search from the set of keywords and, optionally, filters.

:param search_params: Search parameters. See
:class: `ImageSearchQueryStringSerializer`.
If ``strategy`` is ``search``, perform a ranked paginated search
from the set of keywords and, optionally, filters.
If `strategy` is `collection`, perform a paginated search
for the `tag`, `source` or `source` and `creator` combination.

:param collection_params: The path parameters for collection search, if
strategy is `collection`.
:param strategy: Whether to perform a default search or retrieve a collection.
:param search_params: If `strategy` is `collection`, `PaginatedRequestSerializer`
or `AudioCollectionRequestSerializer`. If `strategy` is `search`, search
query params, see :class: `MediaRequestSerializer`.
:param origin_index: The Elasticsearch index to search (e.g. 'image')
:param exact_index: whether to skip all modifications to the index name
:param page_size: The number of results to return per page.
@@ -337,46 +410,51 @@ def search(
Elasticsearch shards.
:param filter_dead: Whether dead links should be removed.
:param page: The results page number.
:return: Tuple with a List of Hits from elasticsearch, the total count of
:return: Tuple with a list of Hits from elasticsearch, the total count of
pages, the number of results, and the ``SearchContext`` as a dict.
"""
if not exact_index:
index = _resolve_index(origin_index, search_params)
else:
index = origin_index
index = get_index(exact_index, origin_index, search_params)

query = build_query(strategy, search_params, collection_params)

s = Search(index=index)
s = Search(index=index).query(query)

search_query = create_search_query(search_params)
s = s.query(search_query)
if strategy == "search":
Review comment (Collaborator): (Just a gripe/wondering what other approaches exist, no real suggestion here or request for changes) I wish it was possible to isolate the strategy-specific and generic stuff in this method. The only way that comes to mind is to create the Search object before the strategy check and then pass it for mutation into the build_*_query functions, but maybe that's harder to follow. Or going full OOP strategy pattern to isolate things, but that doesn't seem worth it either, unless the strategies got a lot more complex.

Reply (Contributor, Author): I've extracted build_query to make this part a bit clearer.

# Use highlighting to determine which fields contribute to the selection of
# top results.
s = s.highlight(*DEFAULT_SEARCH_FIELDS)
s = s.highlight_options(order="score")
s.extra(track_scores=True)

# Use highlighting to determine which fields contribute to the selection of
# top results.
s = s.highlight(*DEFAULT_SEARCH_FIELDS)
s = s.highlight_options(order="score")
s.extra(track_scores=True)
# Route users to the same Elasticsearch worker node to reduce
# pagination inconsistencies and increase cache hits.
# TODO: Re-add 7s request_timeout when ES stability is restored
s = s.params(preference=str(ip))

# Sort by new
if search_params.validated_data["sort_by"] == INDEXED_ON:
s = s.sort({"created_on": {"order": search_params.validated_data["sort_dir"]}})
# Sort by `created_on` if the parameter is set or if `strategy` is `collection`.
sort_by = search_params.validated_data.get("sort_by")
if strategy == "collection" or sort_by == INDEXED_ON:
sort_dir = search_params.validated_data.get("sort_dir", "desc")
s = s.sort({"created_on": {"order": sort_dir}})

# Paginate
start, end = get_query_slice(s, page_size, page, filter_dead)
s = s[start:end]
search_response = get_es_response(s, es_query="search")

results = _post_process_results(
s, start, end, page_size, search_response, filter_dead
# Execute paginated search and tally results
page_count, result_count, results = execute_search(
s, page, page_size, filter_dead, index, es_query=strategy
)

result_count, page_count = _get_result_and_page_count(
search_response, results, page_size, page
)
result_ids = [result.identifier for result in results]
search_context = SearchContext.build(result_ids, origin_index)

return results, page_count, result_count, search_context.asdict()
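
To make the refactored flow above easier to trace end to end, here is a hedged standalone sketch (illustrative only; the index name, highlight fields, and values are assumptions, and only elasticsearch_dsl is required) of the request body that query_media assembles before handing it to execute_search:

from elasticsearch_dsl import Q, Search

strategy = "collection"  # or "search"
page, page_size = 1, 20

# A filter-only query, as built by build_collection_query above.
query = Q(
    "bool",
    filter=[{"term": {"source": "flickr"}}],
    must_not=[{"term": {"mature": True}}],
)

s = Search(index="image-filtered").query(query)

if strategy == "search":
    # Highlighting only applies to ranked keyword searches.
    s = s.highlight("title", "description", "tags.name")
    s = s.highlight_options(order="score")

# Pin the request to one worker node to reduce pagination
# inconsistencies and increase cache hits.
s = s.params(preference="203.0.113.7")

# Collections are always sorted by `created_on`; searches only when requested.
s = s.sort({"created_on": {"order": "desc"}})

# Paginate the slice that execute_search would fetch.
s = s[(page - 1) * page_size : page * page_size]

print(s.to_dict())  # the final Elasticsearch request body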


def tally_results(
index: SearchIndex, results: list[Hit] | None, page: int, page_size: int
) -> None:
"""
Tally the number of results from each provider for the given search query.
"""
results_to_tally = results or []
max_result_depth = page * page_size
if max_result_depth <= 80:
@@ -405,13 +483,33 @@ def search(
# check things like provider density for a set of queries.
tallies.count_provider_occurrences(results_to_tally, index)

if not results:
results = []

result_ids = [result.identifier for result in results]
search_context = SearchContext.build(result_ids, origin_index)
def execute_search(
s: Search,
page: int,
page_size: int,
filter_dead: bool,
index: SearchIndex,
es_query: str,
) -> tuple[int, int, list[Hit]]:
"""
Execute the search for the given query slice, post-process the results,
and return the results along with the result and page counts.
"""
start, end = get_query_slice(s, page_size, page, filter_dead)
s = s[start:end]

return results, page_count, result_count, search_context.asdict()
search_response = get_es_response(s, es_query=es_query)

results: list[Hit] = (
_post_process_results(s, start, end, page_size, search_response, filter_dead)
or []
)
result_count, page_count = _get_result_and_page_count(
search_response, results, page_size, page
)
tally_results(index, results, page, page_size)
return page_count, result_count, results


def get_sources(index):
17 changes: 15 additions & 2 deletions api/api/docs/audio_docs.py
@@ -1,6 +1,6 @@
from drf_spectacular.utils import OpenApiResponse, extend_schema

from api.docs.base_docs import custom_extend_schema, fields_to_md
from api.docs.base_docs import collection_schema, custom_extend_schema, fields_to_md
from api.examples import (
audio_complain_201_example,
audio_complain_curl,
@@ -39,7 +39,7 @@

By using this endpoint, you can obtain search results based on specified
query and optionally filter results by
{fields_to_md(AudioSearchRequestSerializer.fields_names)}.
{fields_to_md(AudioSearchRequestSerializer.field_names)}.

Results are ranked in order of relevance and paginated on the basis of the
`page` param. The `page_size` param controls the total number of pages.
@@ -116,3 +116,16 @@
},
eg=[audio_waveform_curl],
)

source_collection = collection_schema(
media_type="audio",
collection="source",
)
creator_collection = collection_schema(
media_type="audio",
collection="creator",
)
tag_collection = collection_schema(
media_type="audio",
collection="tag",
)
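
As a usage illustration of the documented audio search endpoint (a hedged sketch: the URL and response fields follow the public Openverse API and are not taken from this diff; the collection routes added above are omitted because their final paths are defined elsewhere in the PR):

import requests

# Keyword search with pagination, as described in the endpoint docs above.
response = requests.get(
    "https://api.openverse.org/v1/audio/",
    params={"q": "birdsong", "page": 1, "page_size": 20},
    timeout=10,
)
response.raise_for_status()

for result in response.json()["results"]:
    print(result["title"], result["url"])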