Skip to content

Commit

Permalink
Simplify search query (#3261)
Browse files Browse the repository at this point in the history
* Simplify search query

* Add unit tests

* Update api/api/controllers/search_controller.py

Co-authored-by: sarayourfriend <[email protected]>

* Refactor excluded providers function

Signed-off-by: Olga Bulat <[email protected]>

* Add documentation about filters

Signed-off-by: Olga Bulat <[email protected]>

* Raises exception in serializer

Signed-off-by: Olga Bulat <[email protected]>

---------

Signed-off-by: Olga Bulat <[email protected]>
Co-authored-by: sarayourfriend <[email protected]>
  • Loading branch information
obulat and sarayourfriend authored Nov 1, 2023
1 parent 02a4f68 commit 535ded0
Show file tree
Hide file tree
Showing 3 changed files with 393 additions and 145 deletions.
278 changes: 142 additions & 136 deletions api/api/controllers/search_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from elasticsearch.exceptions import BadRequestError, NotFoundError
from elasticsearch_dsl import Q, Search
from elasticsearch_dsl.query import EMPTY_QUERY, Match, Query, SimpleQueryString, Term
from elasticsearch_dsl.query import EMPTY_QUERY, Match, SimpleQueryString, Term
from elasticsearch_dsl.response import Hit, Response

import api.models as models
Expand All @@ -34,10 +34,7 @@
DEEP_PAGINATION_ERROR = "Deep pagination is not allowed."
QUERY_SPECIAL_CHARACTER_ERROR = "Unescaped special characters are not allowed."
DEFAULT_BOOST = 10000


class RankFeature(Query):
name = "rank_feature"
DEFAULT_SEARCH_FIELDS = ["title", "description", "tags.name"]


def _unmasked_query_end(page_size, page):
Expand Down Expand Up @@ -235,44 +232,13 @@ def _post_process_results(
return results[:page_size]


def _apply_filter(
s: Search,
search_params: media_serializers.MediaSearchRequestSerializer,
serializer_field: str,
es_field: str | None = None,
behaviour: Literal["filter", "exclude"] = "filter",
):
def get_excluded_providers_query() -> Q | None:
"""
Parse and apply a filter from the search parameters serializer.
The parameter key is assumed to have the same name as the corresponding
Elasticsearch property. Each parameter value is assumed to be a comma
separated list encoded as a string.
:param s: The ``Search`` instance to apply the filter to
:param search_params: the serializer instance containing user input
:param serializer_field: the name of the parameter field in ``search_params``
:param es_field: the corresponding parameter name in Elasticsearch
:param behaviour: whether to accept (``filter``) or reject (``exclude``) the hit
:return: the input ``Search`` object with the filters applied
Hide data sources from the catalog dynamically.
To exclude a provider, set ``filter_content`` to ``True`` in the
``ContentProvider`` model in Django admin.
"""

if serializer_field in search_params.data:
arguments = search_params.data.get(serializer_field)
if arguments is None:
return s
arguments = arguments.split(",")
parameter = es_field or serializer_field
query = Q("terms", **{parameter: arguments})
method = getattr(s, behaviour)
return method("bool", should=query)

return s


def _exclude_filtered(s: Search):
"""Hide data sources from the catalog dynamically."""

filter_cache_key = "filtered_providers"
filtered_providers = cache.get(key=filter_cache_key)
if not filtered_providers:
Expand All @@ -282,16 +248,9 @@ def _exclude_filtered(s: Search):
cache.set(
key=filter_cache_key, timeout=FILTER_CACHE_TIMEOUT, value=filtered_providers
)
to_exclude = [f["provider_identifier"] for f in filtered_providers]
if to_exclude:
s = s.exclude("terms", provider=to_exclude)
return s


def _exclude_sensitive_by_param(s: Search, search_params):
if not search_params.validated_data["include_sensitive_results"]:
s = s.exclude("term", mature=True)
return s
if provider_list := [f["provider_identifier"] for f in filtered_providers]:
return Q("terms", provider=provider_list)
return None


def _resolve_index(
Expand All @@ -310,88 +269,91 @@ def _resolve_index(
return index


def search(
def create_search_filter_queries(
search_params: media_serializers.MediaSearchRequestSerializer,
origin_index: OriginIndex,
exact_index: bool,
page_size: int,
ip: int,
filter_dead: bool,
page: int = 1,
) -> tuple[list[Hit], int, int, dict]:
) -> dict[str, list[Q]]:
"""
Perform a ranked paginated search from the set of keywords and, optionally, filters.
:param search_params: Search parameters. See
:class: `ImageSearchQueryStringSerializer`.
:param origin_index: The Elasticsearch index to search (e.g. 'image')
:param exact_index: whether to skip all modifications to the index name
:param page_size: The number of results to return per page.
:param ip: The user's hashed IP. Hashed IPs are used to anonymously but
uniquely identify users exclusively for ensuring query consistency across
Elasticsearch shards.
:param filter_dead: Whether dead links should be removed.
:param page: The results page number.
:return: Tuple with a List of Hits from elasticsearch, the total count of
pages, the number of results, and the ``SearchContext`` as a dict.
Create a list of Elasticsearch queries for filtering search results.
The filter values are given in the request query string.
We use ES filters (`filter`, `must_not`) because we don't need to
compute the relevance score and the queries are cached for better
performance.
"""
if not exact_index:
index = _resolve_index(origin_index, search_params)
else:
index = origin_index

search_client = Search(index=index)

s = search_client
queries = {"filter": [], "must_not": []}
# Apply term filters. Each tuple pairs a filter's parameter name in the API
# with its corresponding field in Elasticsearch. "None" means that the
# names are identical.
filters = [
("extension", None),
("category", None),
("categories", "category"),
("source", None),
("license", None),
("license_type", "license"),
# Audio-specific filters
("length", None),
# Image-specific filters
("aspect_ratio", None),
("size", None),
]
for serializer_field, es_field in filters:
if serializer_field in search_params.data:
s = _apply_filter(s, search_params, serializer_field, es_field)

exclude = [
("excluded_source", "source"),
]
for serializer_field, es_field in exclude:
if serializer_field in search_params.data:
s = _apply_filter(s, search_params, serializer_field, es_field, "exclude")

# Exclude mature content and disabled sources
s = _exclude_sensitive_by_param(s, search_params)
s = _exclude_filtered(s)
query_filters = {
"filter": [
("extension", None),
("category", None),
("source", None),
("license", None),
("license_type", "license"),
# Audio-specific filters
("length", None),
# Image-specific filters
("aspect_ratio", None),
("size", None),
],
"must_not": [
("excluded_source", "source"),
],
}
for behaviour, filters in query_filters.items():
for serializer_field, es_field in filters:
if not (arguments := search_params.data.get(serializer_field)):
continue
arguments = arguments.split(",")
parameter = es_field or serializer_field
queries[behaviour].append(Q("terms", **{parameter: arguments}))
return queries


def create_ranking_queries(
search_params: media_serializers.MediaSearchRequestSerializer,
) -> list[Q]:
queries = [Q("rank_feature", field="standardized_popularity", boost=DEFAULT_BOOST)]
if search_params.data["unstable__authority"]:
boost = int(search_params.data["unstable__authority_boost"] * DEFAULT_BOOST)
authority_query = Q("rank_feature", field="authority_boost", boost=boost)
queries.append(authority_query)
return queries


def create_search_query(
search_params: media_serializers.MediaSearchRequestSerializer,
) -> Q:
# Apply filters from the url query search parameters.
url_queries = create_search_filter_queries(search_params)
search_queries = {
"filter": url_queries["filter"],
"must_not": url_queries["must_not"],
"must": [],
"should": [],
}

# Exclude mature content
if not search_params.validated_data["include_sensitive_results"]:
search_queries["must_not"].append(Q("term", mature=True))
# Exclude dynamically disabled sources (see Redis cache)
if excluded_providers_query := get_excluded_providers_query():
search_queries["must_not"].append(excluded_providers_query)

# Search either by generic multimatch or by "advanced search" with
# individual field-level queries specified.
search_fields = ["tags.name", "title", "description"]
if "q" in search_params.data:
query = _quote_escape(search_params.data["q"])
base_query_kwargs = {
"query": query,
"fields": search_fields,
"fields": DEFAULT_SEARCH_FIELDS,
"default_operator": "AND",
}

if '"' in query:
base_query_kwargs["quote_field_suffix"] = ".exact"

s = s.query(
"simple_query_string",
**base_query_kwargs,
)
search_queries["must"].append(Q("simple_query_string", **base_query_kwargs))
# Boost exact matches on the title
quotes_stripped = query.replace('"', "")
exact_match_boost = Q(
Expand All @@ -400,35 +362,78 @@ def search(
query=f"{quotes_stripped}",
boost=10000,
)
s = search_client.query(Q("bool", must=s.query, should=exact_match_boost))
search_queries["should"].append(exact_match_boost)
else:
if "creator" in search_params.data:
creator = _quote_escape(search_params.data["creator"])
s = s.query("simple_query_string", query=creator, fields=["creator"])
if "title" in search_params.data:
title = _quote_escape(search_params.data["title"])
s = s.query("simple_query_string", query=title, fields=["title"])
if "tags" in search_params.data:
tags = _quote_escape(search_params.data["tags"])
s = s.query("simple_query_string", fields=["tags.name"], query=tags)
for field, field_name in [
("creator", "creator"),
("title", "title"),
("tags", "tags.name"),
]:
if field_value := search_params.data.get(field):
search_queries["must"].append(
Q(
"simple_query_string",
query=_quote_escape(field_value),
fields=[field_name],
)
)

if settings.USE_RANK_FEATURES:
feature_boost = {"standardized_popularity": DEFAULT_BOOST}
if search_params.data["unstable__authority"]:
feature_boost["authority_boost"] = (
search_params.data["unstable__authority_boost"] * DEFAULT_BOOST
)
search_queries["should"].extend(create_ranking_queries(search_params))

# If there are no `must` query clauses, only the results that match
# the `should` clause are returned. To avoid this, we add an empty
# query clause to the `must` list.
if not search_queries["must"]:
search_queries["must"].append(EMPTY_QUERY)

return Q(
"bool",
filter=search_queries["filter"],
must_not=search_queries["must_not"],
must=search_queries["must"],
should=search_queries["should"],
)

rank_queries = []
for field, boost in feature_boost.items():
rank_queries.append(Q("rank_feature", field=field, boost=boost))
s = search_client.query(
Q("bool", must=s.query or EMPTY_QUERY, should=rank_queries)
)

def search(
search_params: media_serializers.MediaSearchRequestSerializer,
origin_index: OriginIndex,
exact_index: bool,
page_size: int,
ip: int,
filter_dead: bool,
page: int = 1,
) -> tuple[list[Hit], int, int, dict]:
"""
Perform a ranked paginated search from the set of keywords and, optionally, filters.
:param search_params: Search parameters. See
:class: `ImageSearchQueryStringSerializer`.
:param origin_index: The Elasticsearch index to search (e.g. 'image')
:param exact_index: whether to skip all modifications to the index name
:param page_size: The number of results to return per page.
:param ip: The user's hashed IP. Hashed IPs are used to anonymously but
uniquely identify users exclusively for ensuring query consistency across
Elasticsearch shards.
:param filter_dead: Whether dead links should be removed.
:param page: The results page number.
:return: Tuple with a List of Hits from elasticsearch, the total count of
pages, the number of results, and the ``SearchContext`` as a dict.
"""
if not exact_index:
index = _resolve_index(origin_index, search_params)
else:
index = origin_index

s = Search(index=index)

search_query = create_search_query(search_params)
s = s.query(search_query)

# Use highlighting to determine which fields contribute to the selection of
# top results.
s = s.highlight(*search_fields)
s = s.highlight(*DEFAULT_SEARCH_FIELDS)
s = s.highlight_options(order="score")
s.extra(track_scores=True)
# Route users to the same Elasticsearch worker node to reduce
Expand Down Expand Up @@ -541,7 +546,8 @@ def related_media(uuid: str, index: str, filter_dead: bool) -> list[Hit]:
# Exclude the current item and mature content.
s = s.query(related_query & ~Term(identifier=uuid) & ~Term(mature=True))
# Exclude the dynamically disabled sources.
s = _exclude_filtered(s)
if excluded_providers_query := get_excluded_providers_query():
s = s.exclude(excluded_providers_query)

page, page_size = 1, 10
start, end = _get_query_slice(s, page_size, page, filter_dead)
Expand Down
Loading

0 comments on commit 535ded0

Please sign in to comment.