From 535ded087b2f0188370f8ef282c978b966e64104 Mon Sep 17 00:00:00 2001 From: Olga Bulat Date: Wed, 1 Nov 2023 14:24:39 +0300 Subject: [PATCH] Simplify search query (#3261) * Simplify search query * Add unit tests * Update api/api/controllers/search_controller.py Co-authored-by: sarayourfriend <24264157+sarayourfriend@users.noreply.github.com> * Refactor excluded providers function Signed-off-by: Olga Bulat * Add documentation about filters Signed-off-by: Olga Bulat * Raises exception in serializer Signed-off-by: Olga Bulat --------- Signed-off-by: Olga Bulat Co-authored-by: sarayourfriend <24264157+sarayourfriend@users.noreply.github.com> --- api/api/controllers/search_controller.py | 278 +++++++++--------- .../test_search_controller_search_query.py | 240 +++++++++++++++ .../api/reference/search_algorithm.md | 20 +- 3 files changed, 393 insertions(+), 145 deletions(-) create mode 100644 api/test/unit/controllers/test_search_controller_search_query.py diff --git a/api/api/controllers/search_controller.py b/api/api/controllers/search_controller.py index 045cf593775..8d5c9065dda 100644 --- a/api/api/controllers/search_controller.py +++ b/api/api/controllers/search_controller.py @@ -11,7 +11,7 @@ from elasticsearch.exceptions import BadRequestError, NotFoundError from elasticsearch_dsl import Q, Search -from elasticsearch_dsl.query import EMPTY_QUERY, Match, Query, SimpleQueryString, Term +from elasticsearch_dsl.query import EMPTY_QUERY, Match, SimpleQueryString, Term from elasticsearch_dsl.response import Hit, Response import api.models as models @@ -34,10 +34,7 @@ DEEP_PAGINATION_ERROR = "Deep pagination is not allowed." QUERY_SPECIAL_CHARACTER_ERROR = "Unescaped special characters are not allowed." 
DEFAULT_BOOST = 10000 - - -class RankFeature(Query): - name = "rank_feature" +DEFAULT_SEARCH_FIELDS = ["title", "description", "tags.name"] def _unmasked_query_end(page_size, page): @@ -235,44 +232,13 @@ def _post_process_results( return results[:page_size] -def _apply_filter( - s: Search, - search_params: media_serializers.MediaSearchRequestSerializer, - serializer_field: str, - es_field: str | None = None, - behaviour: Literal["filter", "exclude"] = "filter", -): +def get_excluded_providers_query() -> Q | None: """ - Parse and apply a filter from the search parameters serializer. - - The parameter key is assumed to have the same name as the corresponding - Elasticsearch property. Each parameter value is assumed to be a comma - separated list encoded as a string. - - :param s: The ``Search`` instance to apply the filter to - :param search_params: the serializer instance containing user input - :param serializer_field: the name of the parameter field in ``search_params`` - :param es_field: the corresponding parameter name in Elasticsearch - :param behaviour: whether to accept (``filter``) or reject (``exclude``) the hit - :return: the input ``Search`` object with the filters applied + Hide data sources from the catalog dynamically. + To exclude a provider, set ``filter_content`` to ``True`` in the + ``ContentProvider`` model in Django admin. 
""" - if serializer_field in search_params.data: - arguments = search_params.data.get(serializer_field) - if arguments is None: - return s - arguments = arguments.split(",") - parameter = es_field or serializer_field - query = Q("terms", **{parameter: arguments}) - method = getattr(s, behaviour) - return method("bool", should=query) - - return s - - -def _exclude_filtered(s: Search): - """Hide data sources from the catalog dynamically.""" - filter_cache_key = "filtered_providers" filtered_providers = cache.get(key=filter_cache_key) if not filtered_providers: @@ -282,16 +248,9 @@ def _exclude_filtered(s: Search): cache.set( key=filter_cache_key, timeout=FILTER_CACHE_TIMEOUT, value=filtered_providers ) - to_exclude = [f["provider_identifier"] for f in filtered_providers] - if to_exclude: - s = s.exclude("terms", provider=to_exclude) - return s - - -def _exclude_sensitive_by_param(s: Search, search_params): - if not search_params.validated_data["include_sensitive_results"]: - s = s.exclude("term", mature=True) - return s + if provider_list := [f["provider_identifier"] for f in filtered_providers]: + return Q("terms", provider=provider_list) + return None def _resolve_index( @@ -310,88 +269,91 @@ def _resolve_index( return index -def search( +def create_search_filter_queries( search_params: media_serializers.MediaSearchRequestSerializer, - origin_index: OriginIndex, - exact_index: bool, - page_size: int, - ip: int, - filter_dead: bool, - page: int = 1, -) -> tuple[list[Hit], int, int, dict]: +) -> dict[str, list[Q]]: """ - Perform a ranked paginated search from the set of keywords and, optionally, filters. - - :param search_params: Search parameters. See - :class: `ImageSearchQueryStringSerializer`. - :param origin_index: The Elasticsearch index to search (e.g. 'image') - :param exact_index: whether to skip all modifications to the index name - :param page_size: The number of results to return per page. - :param ip: The user's hashed IP. 
Hashed IPs are used to anonymously but - uniquely identify users exclusively for ensuring query consistency across - Elasticsearch shards. - :param filter_dead: Whether dead links should be removed. - :param page: The results page number. - :return: Tuple with a List of Hits from elasticsearch, the total count of - pages, the number of results, and the ``SearchContext`` as a dict. + Create a list of Elasticsearch queries for filtering search results. + The filter values are given in the request query string. + We use ES filters (`filter`, `must_not`) because we don't need to + compute the relevance score and the queries are cached for better + performance. """ - if not exact_index: - index = _resolve_index(origin_index, search_params) - else: - index = origin_index - - search_client = Search(index=index) - - s = search_client + queries = {"filter": [], "must_not": []} # Apply term filters. Each tuple pairs a filter's parameter name in the API # with its corresponding field in Elasticsearch. "None" means that the # names are identical. 
- filters = [ - ("extension", None), - ("category", None), - ("categories", "category"), - ("source", None), - ("license", None), - ("license_type", "license"), - # Audio-specific filters - ("length", None), - # Image-specific filters - ("aspect_ratio", None), - ("size", None), - ] - for serializer_field, es_field in filters: - if serializer_field in search_params.data: - s = _apply_filter(s, search_params, serializer_field, es_field) - - exclude = [ - ("excluded_source", "source"), - ] - for serializer_field, es_field in exclude: - if serializer_field in search_params.data: - s = _apply_filter(s, search_params, serializer_field, es_field, "exclude") - - # Exclude mature content and disabled sources - s = _exclude_sensitive_by_param(s, search_params) - s = _exclude_filtered(s) + query_filters = { + "filter": [ + ("extension", None), + ("category", None), + ("source", None), + ("license", None), + ("license_type", "license"), + # Audio-specific filters + ("length", None), + # Image-specific filters + ("aspect_ratio", None), + ("size", None), + ], + "must_not": [ + ("excluded_source", "source"), + ], + } + for behaviour, filters in query_filters.items(): + for serializer_field, es_field in filters: + if not (arguments := search_params.data.get(serializer_field)): + continue + arguments = arguments.split(",") + parameter = es_field or serializer_field + queries[behaviour].append(Q("terms", **{parameter: arguments})) + return queries + + +def create_ranking_queries( + search_params: media_serializers.MediaSearchRequestSerializer, +) -> list[Q]: + queries = [Q("rank_feature", field="standardized_popularity", boost=DEFAULT_BOOST)] + if search_params.data["unstable__authority"]: + boost = int(search_params.data["unstable__authority_boost"] * DEFAULT_BOOST) + authority_query = Q("rank_feature", field="authority_boost", boost=boost) + queries.append(authority_query) + return queries + + +def create_search_query( + search_params: 
media_serializers.MediaSearchRequestSerializer, +) -> Q: + # Apply filters from the url query search parameters. + url_queries = create_search_filter_queries(search_params) + search_queries = { + "filter": url_queries["filter"], + "must_not": url_queries["must_not"], + "must": [], + "should": [], + } + + # Exclude mature content + if not search_params.validated_data["include_sensitive_results"]: + search_queries["must_not"].append(Q("term", mature=True)) + # Exclude dynamically disabled sources (see Redis cache) + if excluded_providers_query := get_excluded_providers_query(): + search_queries["must_not"].append(excluded_providers_query) # Search either by generic multimatch or by "advanced search" with # individual field-level queries specified. - search_fields = ["tags.name", "title", "description"] if "q" in search_params.data: query = _quote_escape(search_params.data["q"]) base_query_kwargs = { "query": query, - "fields": search_fields, + "fields": DEFAULT_SEARCH_FIELDS, "default_operator": "AND", } if '"' in query: base_query_kwargs["quote_field_suffix"] = ".exact" - s = s.query( - "simple_query_string", - **base_query_kwargs, - ) + search_queries["must"].append(Q("simple_query_string", **base_query_kwargs)) # Boost exact matches on the title quotes_stripped = query.replace('"', "") exact_match_boost = Q( @@ -400,35 +362,78 @@ def search( query=f"{quotes_stripped}", boost=10000, ) - s = search_client.query(Q("bool", must=s.query, should=exact_match_boost)) + search_queries["should"].append(exact_match_boost) else: - if "creator" in search_params.data: - creator = _quote_escape(search_params.data["creator"]) - s = s.query("simple_query_string", query=creator, fields=["creator"]) - if "title" in search_params.data: - title = _quote_escape(search_params.data["title"]) - s = s.query("simple_query_string", query=title, fields=["title"]) - if "tags" in search_params.data: - tags = _quote_escape(search_params.data["tags"]) - s = s.query("simple_query_string", 
fields=["tags.name"], query=tags) + for field, field_name in [ + ("creator", "creator"), + ("title", "title"), + ("tags", "tags.name"), + ]: + if field_value := search_params.data.get(field): + search_queries["must"].append( + Q( + "simple_query_string", + query=_quote_escape(field_value), + fields=[field_name], + ) + ) if settings.USE_RANK_FEATURES: - feature_boost = {"standardized_popularity": DEFAULT_BOOST} - if search_params.data["unstable__authority"]: - feature_boost["authority_boost"] = ( - search_params.data["unstable__authority_boost"] * DEFAULT_BOOST - ) + search_queries["should"].extend(create_ranking_queries(search_params)) + + # If there are no `must` query clauses, only the results that match + # the `should` clause are returned. To avoid this, we add an empty + # query clause to the `must` list. + if not search_queries["must"]: + search_queries["must"].append(EMPTY_QUERY) + + return Q( + "bool", + filter=search_queries["filter"], + must_not=search_queries["must_not"], + must=search_queries["must"], + should=search_queries["should"], + ) - rank_queries = [] - for field, boost in feature_boost.items(): - rank_queries.append(Q("rank_feature", field=field, boost=boost)) - s = search_client.query( - Q("bool", must=s.query or EMPTY_QUERY, should=rank_queries) - ) + +def search( + search_params: media_serializers.MediaSearchRequestSerializer, + origin_index: OriginIndex, + exact_index: bool, + page_size: int, + ip: int, + filter_dead: bool, + page: int = 1, +) -> tuple[list[Hit], int, int, dict]: + """ + Perform a ranked paginated search from the set of keywords and, optionally, filters. + + :param search_params: Search parameters. See + :class: `ImageSearchQueryStringSerializer`. + :param origin_index: The Elasticsearch index to search (e.g. 'image') + :param exact_index: whether to skip all modifications to the index name + :param page_size: The number of results to return per page. + :param ip: The user's hashed IP. 
Hashed IPs are used to anonymously but + uniquely identify users exclusively for ensuring query consistency across + Elasticsearch shards. + :param filter_dead: Whether dead links should be removed. + :param page: The results page number. + :return: Tuple with a List of Hits from elasticsearch, the total count of + pages, the number of results, and the ``SearchContext`` as a dict. + """ + if not exact_index: + index = _resolve_index(origin_index, search_params) + else: + index = origin_index + + s = Search(index=index) + + search_query = create_search_query(search_params) + s = s.query(search_query) # Use highlighting to determine which fields contribute to the selection of # top results. - s = s.highlight(*search_fields) + s = s.highlight(*DEFAULT_SEARCH_FIELDS) s = s.highlight_options(order="score") s.extra(track_scores=True) # Route users to the same Elasticsearch worker node to reduce @@ -541,7 +546,8 @@ def related_media(uuid: str, index: str, filter_dead: bool) -> list[Hit]: # Exclude the current item and mature content. s = s.query(related_query & ~Term(identifier=uuid) & ~Term(mature=True)) # Exclude the dynamically disabled sources. 
- s = _exclude_filtered(s) + if excluded_providers_query := get_excluded_providers_query(): + s = s.exclude(excluded_providers_query) page, page_size = 1, 10 start, end = _get_query_slice(s, page_size, page, filter_dead) diff --git a/api/test/unit/controllers/test_search_controller_search_query.py b/api/test/unit/controllers/test_search_controller_search_query.py new file mode 100644 index 00000000000..755e4c6ac76 --- /dev/null +++ b/api/test/unit/controllers/test_search_controller_search_query.py @@ -0,0 +1,240 @@ +from django.core.cache import cache + +import pytest + +from api.controllers import search_controller + + +pytestmark = pytest.mark.django_db + + +def test_create_search_query_empty(media_type_config): + serializer = media_type_config.search_request_serializer(data={}) + serializer.is_valid(raise_exception=True) + search_query = search_controller.create_search_query(serializer) + actual_query_clauses = search_query.to_dict()["bool"] + + assert actual_query_clauses == { + "must_not": [{"term": {"mature": True}}], + "must": [{"match_all": {}}], + "should": [ + {"rank_feature": {"boost": 10000, "field": "standardized_popularity"}} + ], + } + + +def test_create_search_query_empty_no_ranking(media_type_config, settings): + settings.USE_RANK_FEATURES = False + serializer = media_type_config.search_request_serializer(data={}) + serializer.is_valid(raise_exception=True) + search_query = search_controller.create_search_query(serializer) + actual_query_clauses = search_query.to_dict()["bool"] + + assert actual_query_clauses == { + "must_not": [{"term": {"mature": True}}], + "must": [{"match_all": {}}], + } + + +def test_create_search_query_q_search_no_filters(media_type_config): + serializer = media_type_config.search_request_serializer(data={"q": "cat"}) + serializer.is_valid(raise_exception=True) + search_query = search_controller.create_search_query(serializer) + actual_query_clauses = search_query.to_dict()["bool"] + + assert actual_query_clauses == { + 
"must_not": [{"term": {"mature": True}}], + "must": [ + { + "simple_query_string": { + "default_operator": "AND", + "fields": ["title", "description", "tags.name"], + "query": "cat", + } + } + ], + "should": [ + { + "simple_query_string": { + "boost": 10000, + "fields": ["title"], + "query": "cat", + } + }, + {"rank_feature": {"boost": 10000, "field": "standardized_popularity"}}, + ], + } + + +def test_create_search_query_q_search_with_quotes_adds_exact_suffix(media_type_config): + serializer = media_type_config.search_request_serializer( + data={"q": '"The cutest cat"'} + ) + serializer.is_valid(raise_exception=True) + search_query = search_controller.create_search_query(serializer) + actual_query_clauses = search_query.to_dict()["bool"] + + assert actual_query_clauses == { + "must_not": [{"term": {"mature": True}}], + "must": [ + { + "simple_query_string": { + "default_operator": "AND", + "fields": ["title", "description", "tags.name"], + "query": '"The cutest cat"', + "quote_field_suffix": ".exact", + } + } + ], + "should": [ + { + "simple_query_string": { + "boost": 10000, + "fields": ["title"], + "query": "The cutest cat", + } + }, + {"rank_feature": {"boost": 10000, "field": "standardized_popularity"}}, + ], + } + + +def test_create_search_query_q_search_with_filters(image_media_type_config): + serializer = image_media_type_config.search_request_serializer( + data={ + "q": "cat", + "license": "by-nc", + "aspect_ratio": "wide", + # this is a deprecated param, and it doesn't work because it doesn't exist in the serializer + "categories": "digitized_artwork", + "category": "illustration", + "excluded_source": "flickr", + "unstable__authority": True, + "unstable__authority_boost": "2.5", + "unstable__include_sensitive_results": True, + } + ) + serializer.is_valid(raise_exception=True) + search_query = search_controller.create_search_query(serializer) + actual_query_clauses = search_query.to_dict()["bool"] + + assert actual_query_clauses == { + "filter": [ + 
{"terms": {"category": ["illustration"]}}, + {"terms": {"license": ["by-nc"]}}, + {"terms": {"aspect_ratio": ["wide"]}}, + ], + "must_not": [{"terms": {"source": ["flickr"]}}], + "must": [ + { + "simple_query_string": { + "default_operator": "AND", + "fields": ["title", "description", "tags.name"], + "query": "cat", + } + } + ], + "should": [ + { + "simple_query_string": { + "boost": 10000, + "fields": ["title"], + "query": "cat", + } + }, + {"rank_feature": {"boost": 10000, "field": "standardized_popularity"}}, + {"rank_feature": {"boost": 25000, "field": "authority_boost"}}, + ], + } + + +def test_create_search_query_non_q_query(image_media_type_config): + serializer = image_media_type_config.search_request_serializer( + data={ + "creator": "Artist From Openverse", + "title": "kitten🐱", + "tags": "cute", + } + ) + serializer.is_valid(raise_exception=True) + search_query = search_controller.create_search_query(serializer) + actual_query_clauses = search_query.to_dict()["bool"] + + assert actual_query_clauses == { + "must_not": [{"term": {"mature": True}}], + "must": [ + { + "simple_query_string": { + "fields": ["creator"], + "query": "Artist From Openverse", + } + }, + {"simple_query_string": {"fields": ["title"], "query": "kitten🐱"}}, + {"simple_query_string": {"fields": ["tags.name"], "query": "cute"}}, + ], + "should": [ + {"rank_feature": {"boost": 10000, "field": "standardized_popularity"}}, + ], + } + + +def test_create_search_query_q_search_license_license_type_creates_2_terms_filters( + image_media_type_config, +): + serializer = image_media_type_config.search_request_serializer( + data={ + "license": "by-nc", + "license_type": "commercial", + } + ) + serializer.is_valid(raise_exception=True) + search_query = search_controller.create_search_query(serializer) + actual_query_clauses = search_query.to_dict()["bool"] + + first_license_terms_filter = actual_query_clauses["filter"][0] + second_license_terms_filter_licenses = sorted( + 
actual_query_clauses["filter"][1]["terms"]["license"] + ) + # Extracting these to make comparisons not dependent on list order. + assert first_license_terms_filter == {"terms": {"license": ["by-nc"]}} + assert second_license_terms_filter_licenses == [ + "by", + "by-nd", + "by-sa", + "cc0", + "pdm", + "sampling+", + ] + actual_query_clauses.pop("filter", None) + + assert actual_query_clauses == { + "must_not": [{"term": {"mature": True}}], + "must": [{"match_all": {}}], + "should": [ + {"rank_feature": {"boost": 10000, "field": "standardized_popularity"}}, + ], + } + + +def test_create_search_query_empty_with_dynamically_excluded_providers( + image_media_type_config, +): + excluded = {"provider_identifier": "flickr"} + cache.set(key="filtered_providers", timeout=1, value=[excluded]) + + serializer = image_media_type_config.search_request_serializer(data={}) + serializer.is_valid(raise_exception=True) + + search_query = search_controller.create_search_query(serializer) + + actual_query_clauses = search_query.to_dict()["bool"] + assert actual_query_clauses == { + "must_not": [ + {"term": {"mature": True}}, + {"terms": {"provider": [excluded["provider_identifier"]]}}, + ], + "must": [{"match_all": {}}], + "should": [ + {"rank_feature": {"boost": 10000, "field": "standardized_popularity"}} + ], + } diff --git a/documentation/api/reference/search_algorithm.md b/documentation/api/reference/search_algorithm.md index a0dcd0c477f..516deeb601c 100644 --- a/documentation/api/reference/search_algorithm.md +++ b/documentation/api/reference/search_algorithm.md @@ -136,11 +136,12 @@ following fields: - Extension - Category -- Length -- Aspect ratio -- Size - Source - License +- License type +- Length (audio only) +- Aspect ratio (image only) +- Size (image only) Source is the only field for which you can currently also specify exclusions. 
@@ -152,12 +153,13 @@ field: - [Audio search](https://api.openverse.engineering/v1/#operation/audio_search) - [Image search](https://api.openverse.engineering/v1/#operation/image_search) -Each of these fields are searched relatively strictly, primarily because the -search domain in each is very small and "keyword" like. That is, there is a -limited and specific set of terms that appear for the relevant document fields -for each of these query parameters. All of them are validated to only allow -specific options (documented in the API documentation links above), which -enforces the "keyword" like nature of their usage. +For each of these query parameters, the corresponding document field contains +only a limited and specific set of terms. These fields are matched exactly, +using Elasticsearch filter-context queries (`filter` or `must_not`), which do +not compute a relevance score. Filter-context queries can be cached by +Elasticsearch, which improves their performance. All of these filters except +for `extension` are validated to only allow specific options (documented in +the API documentation links above). ### General "query" searching