Add API routes and controllers for additional search views (#2853)
* Add routes for the collections

Signed-off-by: Olga Bulat <[email protected]>

* Add and update tests

Signed-off-by: Olga Bulat <[email protected]>

* Remove double plurals in field_names

* Add docs to media_serializers
  Add notes on fuzzy matching for query params and document the maximum page_size.

* Improve collection documentation

Co-authored-by: Staci Mullins <[email protected]>

* Add integration tests

Signed-off-by: Olga Bulat <[email protected]>

* Update docs

Signed-off-by: Olga Bulat <[email protected]>

* Update api/api/controllers/search_controller.py

Co-authored-by: sarayourfriend <[email protected]>

* Update api/api/controllers/search_controller.py

Co-authored-by: sarayourfriend <[email protected]>

* Update test

Signed-off-by: Olga Bulat <[email protected]>

* Combine TYPE_CHECKING clauses

Signed-off-by: Olga Bulat <[email protected]>

* Extract build_query to clean up query_media

Signed-off-by: Olga Bulat <[email protected]>

* Use the keyword creator field

Signed-off-by: Olga Bulat <[email protected]>

* Fix related test

Signed-off-by: Olga Bulat <[email protected]>

* Fix creator collection test

Signed-off-by: Olga Bulat <[email protected]>

* Update api/api/docs/base_docs.py

Co-authored-by: sarayourfriend <[email protected]>
Signed-off-by: Olga Bulat <[email protected]>

---------

Signed-off-by: Olga Bulat <[email protected]>
Co-authored-by: Staci Mullins <[email protected]>
Co-authored-by: sarayourfriend <[email protected]>
3 people authored Nov 16, 2023
1 parent cdee828 commit 3986e71
Showing 20 changed files with 775 additions and 155 deletions.
4 changes: 3 additions & 1 deletion api/api/constants/media_types.py
@@ -6,7 +6,9 @@
IMAGE_TYPE = "image"

MEDIA_TYPES = [AUDIO_TYPE, IMAGE_TYPE]
MediaType = Literal["audio", "image"]

MEDIA_TYPE_CHOICES = [(AUDIO_TYPE, "Audio"), (IMAGE_TYPE, "Image")]

OriginIndex = Literal["image", "audio"]
OriginIndex = MediaType
SearchIndex = Literal["image", "image-filtered", "audio", "audio-filtered"]
5 changes: 5 additions & 0 deletions api/api/constants/search.py
@@ -0,0 +1,5 @@
from typing import Literal


SEARCH_STRATEGIES = ["search", "collection"]
SearchStrategy = Literal["search", "collection"]
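
The new constants module gives both a runtime list and a static type for the two supported query modes. A minimal, hypothetical sketch of how calling code might branch on the strategy (the `describe_strategy` helper is illustrative and not part of this commit):

    from api.constants.search import SEARCH_STRATEGIES, SearchStrategy

    def describe_strategy(strategy: SearchStrategy) -> str:
        # Illustrative only: validate against the runtime list, then branch.
        if strategy not in SEARCH_STRATEGIES:
            raise ValueError(f"Unknown search strategy: {strategy}")
        return "ranked keyword search" if strategy == "search" else "collection listing"
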
212 changes: 155 additions & 57 deletions api/api/controllers/search_controller.py
@@ -3,7 +3,7 @@
import logging
import logging as log
from math import ceil
from typing import Literal
from typing import TYPE_CHECKING

from django.conf import settings
from django.core.cache import cache
@@ -15,21 +15,35 @@
from elasticsearch_dsl.response import Hit, Response

import api.models as models
from api.constants.media_types import OriginIndex
from api.constants.media_types import OriginIndex, SearchIndex
from api.constants.search import SearchStrategy
from api.constants.sorting import INDEXED_ON
from api.controllers.elasticsearch.helpers import (
ELASTICSEARCH_MAX_RESULT_WINDOW,
get_es_response,
get_query_slice,
get_raw_es_response,
)
from api.serializers import media_serializers
from api.utils import tallies
from api.utils.check_dead_links import check_dead_links
from api.utils.dead_link_mask import get_query_hash
from api.utils.search_context import SearchContext


# Using TYPE_CHECKING to avoid circular imports when importing types
if TYPE_CHECKING:
from api.serializers.audio_serializers import AudioCollectionRequestSerializer
from api.serializers.media_serializers import (
MediaSearchRequestSerializer,
PaginatedRequestSerializer,
)

MediaListRequestSerializer = (
AudioCollectionRequestSerializer
| MediaSearchRequestSerializer
| PaginatedRequestSerializer
)
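
Because `MediaListRequestSerializer` only exists when `TYPE_CHECKING` is true, annotations that reference it have to be evaluated lazily, either as string annotations or via `from __future__ import annotations` (presumably present in the collapsed top of the module). A minimal sketch of the pattern, using a hypothetical helper:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Imported only for type checkers, so no circular import at runtime.
        from api.serializers.media_serializers import PaginatedRequestSerializer

    def page_size_of(search_params: "PaginatedRequestSerializer") -> int:
        # The string annotation keeps this valid at runtime even though the
        # import above never executes outside type checking.
        return search_params.validated_data["page_size"]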

module_logger = logging.getLogger(__name__)


@@ -172,24 +186,24 @@ def get_excluded_providers_query() -> Q | None:
return None


def _resolve_index(
index: Literal["image", "audio"],
search_params: media_serializers.MediaSearchRequestSerializer,
) -> Literal["image", "image-filtered", "audio", "audio-filtered"]:
use_filtered_index = all(
(
settings.ENABLE_FILTERED_INDEX_QUERIES,
not search_params.validated_data["include_sensitive_results"],
)
)
if use_filtered_index:
return f"{index}-filtered"
def get_index(
exact_index: bool,
origin_index: OriginIndex,
search_params: MediaListRequestSerializer,
) -> SearchIndex:
if exact_index:
return origin_index

return index
include_sensitive_results = search_params.validated_data.get(
"include_sensitive_results", False
)
if settings.ENABLE_FILTERED_INDEX_QUERIES and not include_sensitive_results:
return f"{origin_index}-filtered"
return origin_index
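
As a rough illustration of the consolidated `get_index` behaviour above (the `params` object stands in for a validated request serializer and is a hypothetical placeholder):

    # exact_index short-circuits every other consideration.
    get_index(exact_index=True, origin_index="image", search_params=params)   # -> "image"

    # Otherwise, with include_sensitive_results=False in the validated data and
    # settings.ENABLE_FILTERED_INDEX_QUERIES enabled, the filtered index is chosen.
    get_index(exact_index=False, origin_index="audio", search_params=params)  # -> "audio-filtered"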


def create_search_filter_queries(
search_params: media_serializers.MediaSearchRequestSerializer,
search_params: MediaListRequestSerializer,
) -> dict[str, list[Q]]:
"""
Create a list of Elasticsearch queries for filtering search results.
@@ -230,7 +244,7 @@ def create_search_filter_queries(


def create_ranking_queries(
search_params: media_serializers.MediaSearchRequestSerializer,
search_params: MediaListRequestSerializer,
) -> list[Q]:
queries = [Q("rank_feature", field="standardized_popularity", boost=DEFAULT_BOOST)]
if search_params.data["unstable__authority"]:
@@ -240,8 +254,8 @@ def create_ranking_queries(
return queries


def create_search_query(
search_params: media_serializers.MediaSearchRequestSerializer,
def build_search_query(
search_params: MediaListRequestSerializer,
) -> Q:
# Apply filters from the url query search parameters.
url_queries = create_search_filter_queries(search_params)
@@ -315,8 +329,60 @@ def create_search_query(
)


def search(
search_params: media_serializers.MediaSearchRequestSerializer,
def build_collection_query(
search_params: MediaListRequestSerializer,
collection_params: dict[str, str],
):
"""
Build the query to retrieve items in a collection.
:param collection_params: `tag`, `source` and/or `creator` values from the path.
:param search_params: the validated search parameters.
:return: the compound `bool` query for the collection.
"""
search_query = {"filter": [], "must": [], "should": [], "must_not": []}
# Apply the term filters. Each tuple pairs a filter's parameter name in the API
# with its corresponding field in Elasticsearch. "None" means that the
# names are identical.
filters = [
# Collection filters allow a single value.
("tag", "tags.name.keyword"),
("source", None),
("creator", "creator.keyword"),
]
for serializer_field, es_field in filters:
if serializer_field in collection_params:
if not (argument := collection_params.get(serializer_field)):
continue
parameter = es_field or serializer_field
search_query["filter"].append({"term": {parameter: argument}})

# Exclude mature content and disabled sources
include_sensitive_by_params = search_params.validated_data.get(
"include_sensitive_results", False
)
if not include_sensitive_by_params:
search_query["must_not"].append({"term": {"mature": True}})

if excluded_providers_query := get_excluded_providers_query():
search_query["must_not"].append(excluded_providers_query)

return Q("bool", **search_query)


def build_query(
strategy: SearchStrategy,
search_params: MediaListRequestSerializer,
collection_params: dict[str, str] | None,
) -> Q:
if strategy == "collection":
return build_collection_query(search_params, collection_params)
return build_search_query(search_params)


def query_media(
strategy: SearchStrategy,
search_params: MediaListRequestSerializer,
collection_params: dict[str, str] | None,
origin_index: OriginIndex,
exact_index: bool,
page_size: int,
@@ -325,10 +391,17 @@ def search(
page: int = 1,
) -> tuple[list[Hit], int, int, dict]:
"""
Perform a ranked paginated search from the set of keywords and, optionally, filters.
:param search_params: Search parameters. See
:class: `ImageSearchQueryStringSerializer`.
If ``strategy`` is ``search``, perform a ranked paginated search
from the set of keywords and, optionally, filters.
If `strategy` is `collection`, perform a paginated search
for the `tag`, `source` or `source` and `creator` combination.
:param collection_params: The path parameters for collection search, if
strategy is `collection`.
:param strategy: Whether to perform a default search or retrieve a collection.
:param search_params: If `strategy` is `collection`, `PaginatedRequestSerializer`
or `AudioCollectionRequestSerializer`. If `strategy` is `search`, search
query params, see :class: `MediaRequestSerializer`.
:param origin_index: The Elasticsearch index to search (e.g. 'image')
:param exact_index: whether to skip all modifications to the index name
:param page_size: The number of results to return per page.
@@ -337,46 +410,51 @@ def search(
Elasticsearch shards.
:param filter_dead: Whether dead links should be removed.
:param page: The results page number.
:return: Tuple with a List of Hits from elasticsearch, the total count of
:return: Tuple with a list of Hits from elasticsearch, the total count of
pages, the number of results, and the ``SearchContext`` as a dict.
"""
if not exact_index:
index = _resolve_index(origin_index, search_params)
else:
index = origin_index
index = get_index(exact_index, origin_index, search_params)

query = build_query(strategy, search_params, collection_params)

s = Search(index=index)
s = Search(index=index).query(query)

search_query = create_search_query(search_params)
s = s.query(search_query)
if strategy == "search":
# Use highlighting to determine which fields contribute to the selection of
# top results.
s = s.highlight(*DEFAULT_SEARCH_FIELDS)
s = s.highlight_options(order="score")
s.extra(track_scores=True)

# Use highlighting to determine which fields contribute to the selection of
# top results.
s = s.highlight(*DEFAULT_SEARCH_FIELDS)
s = s.highlight_options(order="score")
s.extra(track_scores=True)
# Route users to the same Elasticsearch worker node to reduce
# pagination inconsistencies and increase cache hits.
# TODO: Re-add 7s request_timeout when ES stability is restored
s = s.params(preference=str(ip))

# Sort by new
if search_params.validated_data["sort_by"] == INDEXED_ON:
s = s.sort({"created_on": {"order": search_params.validated_data["sort_dir"]}})
# Sort by `created_on` if the parameter is set or if `strategy` is `collection`.
sort_by = search_params.validated_data.get("sort_by")
if strategy == "collection" or sort_by == INDEXED_ON:
sort_dir = search_params.validated_data.get("sort_dir", "desc")
s = s.sort({"created_on": {"order": sort_dir}})

# Paginate
start, end = get_query_slice(s, page_size, page, filter_dead)
s = s[start:end]
search_response = get_es_response(s, es_query="search")

results = _post_process_results(
s, start, end, page_size, search_response, filter_dead
# Execute paginated search and tally results
page_count, result_count, results = execute_search(
s, page, page_size, filter_dead, index, es_query=strategy
)

result_count, page_count = _get_result_and_page_count(
search_response, results, page_size, page
)
result_ids = [result.identifier for result in results]
search_context = SearchContext.build(result_ids, origin_index)

return results, page_count, result_count, search_context.asdict()


def tally_results(
index: SearchIndex, results: list[Hit] | None, page: int, page_size: int
) -> None:
"""
Tally the number of results from each provider in the results of the
search query.
"""
results_to_tally = results or []
max_result_depth = page * page_size
if max_result_depth <= 80:
@@ -405,13 +483,33 @@ def search(
# check things like provider density for a set of queries.
tallies.count_provider_occurrences(results_to_tally, index)

if not results:
results = []

result_ids = [result.identifier for result in results]
search_context = SearchContext.build(result_ids, origin_index)
def execute_search(
s: Search,
page: int,
page_size: int,
filter_dead: bool,
index: SearchIndex,
es_query: str,
) -> tuple[int, int, list[Hit]]:
"""
Execute the search for the given query slice, post-process the results,
and return the results along with the result and page counts.
"""
start, end = get_query_slice(s, page_size, page, filter_dead)
s = s[start:end]

return results, page_count, result_count, search_context.asdict()
search_response = get_es_response(s, es_query=es_query)

results: list[Hit] = (
_post_process_results(s, start, end, page_size, search_response, filter_dead)
or []
)
result_count, page_count = _get_result_and_page_count(
search_response, results, page_size, page
)
tally_results(index, results, page, page_size)
return page_count, result_count, results


def get_sources(index):
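
Taken together, these changes turn the old single-purpose `search` entry point into the strategy-aware `query_media`. A rough sketch of a collection call (the serializer instance and all argument values are illustrative placeholders, and keyword arguments are used because part of the parameter list is collapsed in this excerpt):

    # Hypothetical call: list a creator's works from one source via the
    # "collection" strategy added in this commit.
    results, page_count, result_count, search_context = query_media(
        strategy="collection",
        search_params=serializer,  # a validated PaginatedRequestSerializer
        collection_params={"source": "flickr", "creator": "example-creator"},
        origin_index="image",
        exact_index=False,
        page_size=20,
        ip=0,  # hashed IP, used only as an Elasticsearch routing preference
        filter_dead=True,
        page=1,
    )
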
17 changes: 15 additions & 2 deletions api/api/docs/audio_docs.py
@@ -1,6 +1,6 @@
from drf_spectacular.utils import OpenApiResponse, extend_schema

from api.docs.base_docs import custom_extend_schema, fields_to_md
from api.docs.base_docs import collection_schema, custom_extend_schema, fields_to_md
from api.examples import (
audio_complain_201_example,
audio_complain_curl,
@@ -39,7 +39,7 @@
By using this endpoint, you can obtain search results based on specified
query and optionally filter results by
{fields_to_md(AudioSearchRequestSerializer.fields_names)}.
{fields_to_md(AudioSearchRequestSerializer.field_names)}.
Results are ranked in order of relevance and paginated on the basis of the
`page` param. The `page_size` param controls the total number of pages.
@@ -116,3 +116,16 @@
},
eg=[audio_waveform_curl],
)

source_collection = collection_schema(
media_type="audio",
collection="source",
)
creator_collection = collection_schema(
media_type="audio",
collection="creator",
)
tag_collection = collection_schema(
media_type="audio",
collection="tag",
)
(diffs for the remaining 16 changed files not shown)
