From 5222794a3a1d0863053a8190182d64918afab626 Mon Sep 17 00:00:00 2001 From: Olga Bulat Date: Thu, 26 Oct 2023 18:17:44 +0300 Subject: [PATCH] Simplify search query --- api/api/controllers/search_controller.py | 144 ++++++++++++++--------- 1 file changed, 90 insertions(+), 54 deletions(-) diff --git a/api/api/controllers/search_controller.py b/api/api/controllers/search_controller.py index 045cf593775..e2ad5d72681 100644 --- a/api/api/controllers/search_controller.py +++ b/api/api/controllers/search_controller.py @@ -270,7 +270,7 @@ def _apply_filter( return s -def _exclude_filtered(s: Search): +def get_dynamically_excluded_providers(): """Hide data sources from the catalog dynamically.""" filter_cache_key = "filtered_providers" @@ -282,10 +282,7 @@ def _exclude_filtered(s: Search): cache.set( key=filter_cache_key, timeout=FILTER_CACHE_TIMEOUT, value=filtered_providers ) - to_exclude = [f["provider_identifier"] for f in filtered_providers] - if to_exclude: - s = s.exclude("terms", provider=to_exclude) - return s + return [f["provider_identifier"] for f in filtered_providers] def _exclude_sensitive_by_param(s: Search, search_params): @@ -310,6 +307,41 @@ def _resolve_index( return index +def create_search_filter_queries( + search_params: media_serializers.MediaSearchRequestSerializer, +) -> dict[str, list[Q]]: + queries = {"filter": [], "must_not": []} + # Apply term filters. Each tuple pairs a filter's parameter name in the API + # with its corresponding field in Elasticsearch. "None" means that the + # names are identical. 
+ query_filters = { + "filter": [ + ("extension", None), + ("category", None), + ("categories", "category"), + ("source", None), + ("license", None), + ("license_type", "license"), + # Audio-specific filters + ("length", None), + # Image-specific filters + ("aspect_ratio", None), + ("size", None), + ], + "must_not": [ + ("excluded_source", "source"), + ], + } + for behaviour, filters in query_filters.items(): + for serializer_field, es_field in filters: + if not (arguments := search_params.data.get(serializer_field)): + continue + arguments = arguments.split(",") + parameter = es_field or serializer_field + queries[behaviour].append(Q("terms", **{parameter: arguments})) + return queries + + def search( search_params: media_serializers.MediaSearchRequestSerializer, origin_index: OriginIndex, @@ -341,38 +373,26 @@ def search( index = origin_index search_client = Search(index=index) - s = search_client - # Apply term filters. Each tuple pairs a filter's parameter name in the API - # with its corresponding field in Elasticsearch. "None" means that the - # names are identical. 
- filters = [ - ("extension", None), - ("category", None), - ("categories", "category"), - ("source", None), - ("license", None), - ("license_type", "license"), - # Audio-specific filters - ("length", None), - # Image-specific filters - ("aspect_ratio", None), - ("size", None), - ] - for serializer_field, es_field in filters: - if serializer_field in search_params.data: - s = _apply_filter(s, search_params, serializer_field, es_field) - - exclude = [ - ("excluded_source", "source"), - ] - for serializer_field, es_field in exclude: - if serializer_field in search_params.data: - s = _apply_filter(s, search_params, serializer_field, es_field, "exclude") - - # Exclude mature content and disabled sources - s = _exclude_sensitive_by_param(s, search_params) - s = _exclude_filtered(s) + + search_queries = { + "filter": [], + "must_not": [], + "must": [], + "should": [], + } + + # Apply filters from the url query search parameters. + url_queries = create_search_filter_queries(search_params) + search_queries["filter"].extend(url_queries["filter"]) + search_queries["must_not"].extend(url_queries["must_not"]) + + # Exclude mature content + if not search_params.validated_data["include_sensitive_results"]: + search_queries["must_not"].append(Q("term", mature=True)) + # Exclude dynamically disabled sources (see Redis cache) + if excluded_providers := get_dynamically_excluded_providers(): + search_queries["must_not"].extend(Q("terms", provider=excluded_providers)) # Search either by generic multimatch or by "advanced search" with # individual field-level queries specified. 
@@ -388,10 +408,7 @@ def search( if '"' in query: base_query_kwargs["quote_field_suffix"] = ".exact" - s = s.query( - "simple_query_string", - **base_query_kwargs, - ) + search_queries["must"].append(Q("simple_query_string", **base_query_kwargs)) # Boost exact matches on the title quotes_stripped = query.replace('"', "") exact_match_boost = Q( @@ -400,17 +417,23 @@ def search( query=f"{quotes_stripped}", boost=10000, ) - s = search_client.query(Q("bool", must=s.query, should=exact_match_boost)) + search_queries["should"].append(exact_match_boost) else: - if "creator" in search_params.data: - creator = _quote_escape(search_params.data["creator"]) - s = s.query("simple_query_string", query=creator, fields=["creator"]) - if "title" in search_params.data: - title = _quote_escape(search_params.data["title"]) - s = s.query("simple_query_string", query=title, fields=["title"]) - if "tags" in search_params.data: - tags = _quote_escape(search_params.data["tags"]) - s = s.query("simple_query_string", fields=["tags.name"], query=tags) + if creator := search_params.data.get("creator"): + creator_query = Q( + "simple_query_string", query=_quote_escape(creator), fields=["creator"] + ) + search_queries["must"].append(creator_query) + if title := search_params.data.get("title"): + title_query = Q( + "simple_query_string", query=_quote_escape(title), fields=["title"] + ) + search_queries["must"].append(title_query) + if tags := search_params.data.get("tags"): + tags_query = Q( + "simple_query_string", query=_quote_escape(tags), fields=["tags.name"] + ) + search_queries["must"].append(tags_query) if settings.USE_RANK_FEATURES: feature_boost = {"standardized_popularity": DEFAULT_BOOST} @@ -422,9 +445,22 @@ def search( rank_queries = [] for field, boost in feature_boost.items(): rank_queries.append(Q("rank_feature", field=field, boost=boost)) - s = search_client.query( - Q("bool", must=s.query or EMPTY_QUERY, should=rank_queries) - ) + search_queries["should"].extend(rank_queries) + + 
# If there are no must query clauses, only the results that match + # the `should` clause are returned. To avoid this, we add an empty + # query clause to the `must` list. + if not search_queries["must"]: + search_queries["must"].append(EMPTY_QUERY) + + final_query = Q( + "bool", + filter=search_queries["filter"], + must_not=search_queries["must_not"], + must=search_queries["must"], + should=search_queries["should"], + ) + s = s.query(final_query) # Use highlighting to determine which fields contribute to the selection of # top results. @@ -541,7 +577,7 @@ def related_media(uuid: str, index: str, filter_dead: bool) -> list[Hit]: # Exclude the current item and mature content. s = s.query(related_query & ~Term(identifier=uuid) & ~Term(mature=True)) # Exclude the dynamically disabled sources. - s = _exclude_filtered(s) + s = get_dynamically_excluded_providers(s) page, page_size = 1, 10 start, end = _get_query_slice(s, page_size, page, filter_dead)