From 22ba9315bd123c73c02f82b957bfee91ae8d4208 Mon Sep 17 00:00:00 2001 From: Olga Bulat Date: Sun, 20 Aug 2023 21:02:33 +0300 Subject: [PATCH 01/16] Add routes for the collections Signed-off-by: Olga Bulat --- api/api/constants/media_types.py | 4 +- api/api/constants/search.py | 5 + api/api/controllers/search_controller.py | 194 +++++++++++++----- api/api/docs/base_docs.py | 141 ++++++++++++- api/api/serializers/audio_serializers.py | 21 +- api/api/serializers/media_serializers.py | 111 +++++----- api/api/views/audio_views.py | 97 ++++++++- api/api/views/image_views.py | 100 ++++++++- api/api/views/media_views.py | 82 +++++++- .../controllers/test_search_controller.py | 85 +++++++- api/test/unit/models/test_media_report.py | 3 - 11 files changed, 715 insertions(+), 128 deletions(-) create mode 100644 api/api/constants/search.py diff --git a/api/api/constants/media_types.py b/api/api/constants/media_types.py index 4a0b502ad3c..ae97bbd3e6d 100644 --- a/api/api/constants/media_types.py +++ b/api/api/constants/media_types.py @@ -6,7 +6,9 @@ IMAGE_TYPE = "image" MEDIA_TYPES = [AUDIO_TYPE, IMAGE_TYPE] +MediaType = Literal["audio", "image"] MEDIA_TYPE_CHOICES = [(AUDIO_TYPE, "Audio"), (IMAGE_TYPE, "Image")] -OriginIndex = Literal["image", "audio"] +OriginIndex = MediaType +SearchIndex = Literal["image", "image-filtered", "audio", "audio-filtered"] diff --git a/api/api/constants/search.py b/api/api/constants/search.py new file mode 100644 index 00000000000..4325c3a96f0 --- /dev/null +++ b/api/api/constants/search.py @@ -0,0 +1,5 @@ +from typing import Literal + + +SEARCH_STRATEGIES = ["search", "collection"] +SearchStrategy = Literal["search", "collection"] diff --git a/api/api/controllers/search_controller.py b/api/api/controllers/search_controller.py index 5da859f102e..e6eb9844580 100644 --- a/api/api/controllers/search_controller.py +++ b/api/api/controllers/search_controller.py @@ -3,7 +3,7 @@ import logging import logging as log from math import ceil -from typing 
import Literal +from typing import TYPE_CHECKING from django.conf import settings from django.core.cache import cache @@ -15,7 +15,8 @@ from elasticsearch_dsl.response import Hit, Response import api.models as models -from api.constants.media_types import OriginIndex +from api.constants.media_types import OriginIndex, SearchIndex +from api.constants.search import SearchStrategy from api.constants.sorting import INDEXED_ON from api.controllers.elasticsearch.helpers import ( ELASTICSEARCH_MAX_RESULT_WINDOW, @@ -23,12 +24,17 @@ get_query_slice, get_raw_es_response, ) -from api.serializers import media_serializers from api.utils import tallies from api.utils.check_dead_links import check_dead_links from api.utils.dead_link_mask import get_query_hash from api.utils.search_context import SearchContext +if TYPE_CHECKING: + from api.serializers.audio_serializers import AudioCollectionRequestSerializer + from api.serializers.media_serializers import ( + MediaSearchRequestSerializer, + PaginatedRequestSerializer, + ) module_logger = logging.getLogger(__name__) @@ -172,24 +178,24 @@ def get_excluded_providers_query() -> Q | None: return None -def _resolve_index( - index: Literal["image", "audio"], - search_params: media_serializers.MediaSearchRequestSerializer, -) -> Literal["image", "image-filtered", "audio", "audio-filtered"]: - use_filtered_index = all( - ( - settings.ENABLE_FILTERED_INDEX_QUERIES, - not search_params.validated_data["include_sensitive_results"], - ) - ) - if use_filtered_index: - return f"{index}-filtered" +def get_index( + exact_index: bool, + origin_index: OriginIndex, + search_params: MediaListRequestSerializer, +) -> SearchIndex: + if exact_index: + return origin_index - return index + include_sensitive_results = search_params.validated_data.get( + "include_sensitive_results", False + ) + if settings.ENABLE_FILTERED_INDEX_QUERIES and not include_sensitive_results: + return f"{origin_index}-filtered" + return origin_index def 
create_search_filter_queries( - search_params: media_serializers.MediaSearchRequestSerializer, + search_params: MediaListRequestSerializer, ) -> dict[str, list[Q]]: """ Create a list of Elasticsearch queries for filtering search results. @@ -230,7 +236,7 @@ def create_search_filter_queries( def create_ranking_queries( - search_params: media_serializers.MediaSearchRequestSerializer, + search_params: MediaListRequestSerializer, ) -> list[Q]: queries = [Q("rank_feature", field="standardized_popularity", boost=DEFAULT_BOOST)] if search_params.data["unstable__authority"]: @@ -240,8 +246,8 @@ def create_ranking_queries( return queries -def create_search_query( - search_params: media_serializers.MediaSearchRequestSerializer, +def build_search_query( + search_params: MediaListRequestSerializer, ) -> Q: # Apply filters from the url query search parameters. url_queries = create_search_filter_queries(search_params) @@ -315,8 +321,49 @@ def create_search_query( ) -def search( - search_params: media_serializers.MediaSearchRequestSerializer, +def build_collection_query( + search_params: MediaListRequestSerializer, + collection_params: dict[str, str], +): + """ + Build the query to retrieve items in a collection. + :param collection_params: `tag`, `source` and/or `creator` values from the path. + :param search_params: the validated search parameters. + :return: the search client with the query applied. + """ + search_query = {"filter": [], "must": [], "should": [], "must_not": []} + # Apply the term filters. Each tuple pairs a filter's parameter name in the API + # with its corresponding field in Elasticsearch. "None" means that the + # names are identical. + filters = [ + # Collection filters allow a single value. 
+ ("tag", "tags.name"), + ("source", None), + ("creator", None), + ] + for serializer_field, es_field in filters: + if serializer_field in collection_params: + if not (argument := collection_params.get(serializer_field)): + continue + parameter = es_field or serializer_field + search_query["filter"].append({"term": {parameter: argument}}) + + # Exclude mature content and disabled sources + include_sensitive_by_params = search_params.validated_data.get( + "include_sensitive_results", False + ) + if not include_sensitive_by_params: + search_query["must_not"].append({"term": {"mature": True}}) + if excluded_providers_query := get_excluded_providers_query(): + search_query["must_not"].append(excluded_providers_query) + + return Q("bool", **search_query) + + +def query_media( + strategy: SearchStrategy, + search_params: MediaListRequestSerializer, + collection_params: dict[str, str] | None, origin_index: OriginIndex, exact_index: bool, page_size: int, @@ -325,10 +372,17 @@ def search( page: int = 1, ) -> tuple[list[Hit], int, int, dict]: """ - Perform a ranked paginated search from the set of keywords and, optionally, filters. - - :param search_params: Search parameters. See - :class: `ImageSearchQueryStringSerializer`. + If ``strategy`` is ``search``, perform a ranked paginated search + from the set of keywords and, optionally, filters. + If `strategy` is `collection`, perform a paginated search + for the `tag`, `source` or `source` and `creator` combination. + + :param collection_params: The path parameters for collection search, if + strategy is `collection`. + :param strategy: Whether to perform a default search or retrieve a collection. + :param search_params: If `strategy` is `collection`, `PaginatedRequestSerializer` + or `AudioCollectionRequestSerializer`. If `strategy` is `search`, search + query params, see :class: `MediaRequestSerializer`. :param origin_index: The Elasticsearch index to search (e.g. 
'image') :param exact_index: whether to skip all modifications to the index name :param page_size: The number of results to return per page. @@ -337,46 +391,54 @@ def search( Elasticsearch shards. :param filter_dead: Whether dead links should be removed. :param page: The results page number. - :return: Tuple with a List of Hits from elasticsearch, the total count of + :return: Tuple with a list of Hits from elasticsearch, the total count of pages, the number of results, and the ``SearchContext`` as a dict. """ - if not exact_index: - index = _resolve_index(origin_index, search_params) + index = get_index(exact_index, origin_index, search_params) + + if strategy == "collection": + query = build_collection_query(search_params, collection_params) else: - index = origin_index + query = build_search_query(search_params) - s = Search(index=index) + s = Search(index=index).query(query) - search_query = create_search_query(search_params) - s = s.query(search_query) + if strategy == "search": + # Use highlighting to determine which fields contribute to the selection of + # top results. + s = s.highlight(*DEFAULT_SEARCH_FIELDS) + s = s.highlight_options(order="score") + s.extra(track_scores=True) - # Use highlighting to determine which fields contribute to the selection of - # top results. - s = s.highlight(*DEFAULT_SEARCH_FIELDS) - s = s.highlight_options(order="score") - s.extra(track_scores=True) # Route users to the same Elasticsearch worker node to reduce # pagination inconsistencies and increase cache hits. 
# TODO: Re-add 7s request_timeout when ES stability is restored s = s.params(preference=str(ip)) - # Sort by new - if search_params.validated_data["sort_by"] == INDEXED_ON: - s = s.sort({"created_on": {"order": search_params.validated_data["sort_dir"]}}) - - # Paginate - start, end = get_query_slice(s, page_size, page, filter_dead) - s = s[start:end] - search_response = get_es_response(s, es_query="search") + # Sort by `created_on` if the parameter is set or if `strategy` is `collection`. + sort_by = search_params.validated_data.get("sort_by") + sort_dir = search_params.validated_data.get("sort_dir", "desc") + if strategy == "collection" or sort_by == INDEXED_ON: + s = s.sort({"created_on": {"order": sort_dir}}) - results = _post_process_results( - s, start, end, page_size, search_response, filter_dead + # Execute paginated search and tally results + page_count, result_count, results = execute_search( + s, page, page_size, filter_dead, index, es_query=strategy ) - result_count, page_count = _get_result_and_page_count( - search_response, results, page_size, page - ) + result_ids = [result.identifier for result in results] + search_context = SearchContext.build(result_ids, origin_index) + + return results, page_count, result_count, search_context.asdict() + +def tally_results( + index: SearchIndex, results: list[Hit] | None, page: int, page_size: int +) -> None: + """ + Tally the number of the results from each provider in the results + for the search query. + """ results_to_tally = results or [] max_result_depth = page * page_size if max_result_depth <= 80: @@ -405,13 +467,33 @@ def search( # check things like provider density for a set of queries. 
tallies.count_provider_occurrences(results_to_tally, index) - if not results: - results = [] - result_ids = [result.identifier for result in results] - search_context = SearchContext.build(result_ids, origin_index) +def execute_search( + s: Search, + page: int, + page_size: int, + filter_dead: bool, + index: SearchIndex, + es_query: str, +) -> tuple[int, int, list[Hit]]: + """ + Execute search for the given query slice, post-processes the results, + and returns the results and result and page counts. + """ + start, end = get_query_slice(s, page_size, page, filter_dead) + s = s[start:end] - return results, page_count, result_count, search_context.asdict() + search_response = get_es_response(s, es_query=es_query) + + results: list[Hit] = ( + _post_process_results(s, start, end, page_size, search_response, filter_dead) + or [] + ) + result_count, page_count = _get_result_and_page_count( + search_response, results, page_size, page + ) + tally_results(index, results, page, page_size) + return page_count, result_count, results def get_sources(index): diff --git a/api/api/docs/base_docs.py b/api/api/docs/base_docs.py index dfbfb9fb3fd..41caf7e8236 100644 --- a/api/api/docs/base_docs.py +++ b/api/api/docs/base_docs.py @@ -1,8 +1,23 @@ from http.client import responses as http_responses from textwrap import dedent - -from drf_spectacular.openapi import AutoSchema, OpenApiResponse -from drf_spectacular.utils import OpenApiExample, extend_schema +from typing import Literal + +from drf_spectacular.openapi import AutoSchema +from drf_spectacular.utils import ( + OpenApiExample, + OpenApiParameter, + OpenApiResponse, + extend_schema, +) + +from api.constants.media_types import MediaType +from api.serializers.audio_serializers import ( + AudioCollectionRequestSerializer, + AudioSerializer, +) +from api.serializers.error_serializers import NotFoundErrorSerializer +from api.serializers.image_serializers import ImageSerializer +from api.serializers.media_serializers import 
PaginatedRequestSerializer def fields_to_md(field_names): @@ -77,3 +92,123 @@ def get_operation_id(self) -> str: else: operation_tokens.append("detail") return "_".join(operation_tokens) + + +source_404_message = "Invalid source 'name'. Valid sources are ..." +source_404_response = OpenApiResponse( + NotFoundErrorSerializer, + examples=[ + OpenApiExample( + name="404", + value={"detail": source_404_message}, + ) + ], +) + + +def build_source_path_parameter(media_type: MediaType): + valid_description = ( + f"Valid values are source_names from the stats endpoint: " + f"https://api.openverse.engineering/v1/{media_type}/stats/." + ) + + return OpenApiParameter( + name="source", + type=str, + location=OpenApiParameter.PATH, + description=f"The source of {media_type}. {valid_description}", + ) + + +creator_path_parameter = OpenApiParameter( + name="creator", + type=str, + location=OpenApiParameter.PATH, + description="The name of the media creator. This parameter " + "is case-sensitive, and matches exactly.", +) +tag_path_parameter = OpenApiParameter( + name="tag", + type=str, + location=OpenApiParameter.PATH, + description="The tag of the media. Not case-sensitive, matches exactly.", +) + + +def get_collection_description(media_type, collection): + if collection == "tag": + return f""" +Get a collection of {media_type} with a specific tag. + +This endpoint returns only the exact matches. To search within the +tag values, or to match several tags, use the `search` endpoint +with `tags` query parameter instead of `q` parameter. + +The returned results are ordered primarily based on their popularity +and authority. However, note that the exact order may vary over time +or across requests. + """ + elif collection == "source": + return f""" +Get a collection of {media_type} from a specific source. + +This endpoint returns only the exact matches. To search within the source value, +use the `search` endpoint with `source` query parameter. 
+ +The results in the collection will be sorted by the order in which they +were added to Openverse. + """ + elif collection == "creator": + return f""" +Get a collection of {media_type} by a specific creator from the specified source. + +This endpoint returns only the exact matches both on the creator and the source. +Notice that a single creator's media items can be found on several sources, but +this endpoint only returns the items from the specified source. To search within +the creator value, use the `search` endpoint with `source` query parameter +instead of `q`. + +The order in the results is not guaranteed to stay the same. Most likely, the images +in the collection will be sorted by the order in which they were added to Openverse. + """ + + +COLLECTION_TO_OPERATION_ID = { + ("images", "source"): "images_by_source", + ("images", "creator"): "images_by_source_and_creator", + ("images", "tag"): "images_by_tag", + ("audio", "source"): "audio_by_source", + ("audio", "creator"): "audio_by_source_and_creator", + ("audio", "tag"): "audio_by_tag", +} + + +def collection_schema( + media_type: Literal["images", "audio"], + collection: Literal["source", "creator", "tag"], +): + if media_type == "images": + request_serializer = PaginatedRequestSerializer + serializer = ImageSerializer + else: + request_serializer = AudioCollectionRequestSerializer + serializer = AudioSerializer + + if collection == "tag": + responses = {200: serializer(many=True)} + path_parameters = [tag_path_parameter] + else: + responses = {200: serializer(many=True), 404: source_404_response} + path_parameters = [build_source_path_parameter(media_type)] + if collection == "creator": + path_parameters.append(creator_path_parameter) + operation_id = COLLECTION_TO_OPERATION_ID[(media_type, collection)] + description = get_collection_description(media_type, collection) + return extend_schema( + operation_id=operation_id, + summary=operation_id, + auth=[], + description=description, + 
responses=responses, + parameters=[request_serializer, *path_parameters], + ) diff --git a/api/api/serializers/audio_serializers.py b/api/api/serializers/audio_serializers.py index bb6718c5774..53236feb0d4 100644 --- a/api/api/serializers/audio_serializers.py +++ b/api/api/serializers/audio_serializers.py @@ -11,6 +11,7 @@ MediaReportRequestSerializer, MediaSearchRequestSerializer, MediaSerializer, + PaginatedRequestSerializer, get_hyperlinks_serializer, get_search_request_source_serializer, ) @@ -24,6 +25,23 @@ AudioSearchRequestSourceSerializer = get_search_request_source_serializer("audio") +class AudioCollectionRequestSerializer(PaginatedRequestSerializer): + field_names = [ + *PaginatedRequestSerializer.field_names, + "peaks", + ] + + peaks = serializers.BooleanField( + help_text="Whether to include the waveform peaks or not", + required=False, + default=False, + ) + + @property + def needs_db(self) -> bool: + return super().needs_db or self.data["peaks"] + + class AudioSearchRequestSerializer( AudioSearchRequestSourceSerializer, MediaSearchRequestSerializer, @@ -62,8 +80,7 @@ def needs_db(self) -> bool: return super().needs_db or self.data["peaks"] def validate_internal__index(self, value): - index = super().validate_internal__index(value) - if index is None: + if not (index := super().validate_internal__index(value)): return None if not index.startswith(AUDIO_TYPE): raise serializers.ValidationError(f"Invalid index name `{value}`.") diff --git a/api/api/serializers/media_serializers.py b/api/api/serializers/media_serializers.py index a1e2b8de8a7..b022182f69d 100644 --- a/api/api/serializers/media_serializers.py +++ b/api/api/serializers/media_serializers.py @@ -26,6 +26,64 @@ ####################### +class PaginatedRequestSerializer(serializers.Serializer): + """This serializer passes pagination parameters from the query string.""" + + field_names = [ + "page_size", + "page", + ] + page_size = serializers.IntegerField( + label="page_size", + 
help_text=f"Number of results to return per page. " + f"Maximum for unauthenticated requests is {settings.MAX_ANONYMOUS_PAGE_SIZE}.", + required=False, + default=settings.MAX_ANONYMOUS_PAGE_SIZE, + min_value=1, + ) + page = serializers.IntegerField( + label="page", + help_text="The page of results to retrieve.", + required=False, + default=1, + max_value=settings.MAX_PAGINATION_DEPTH, + min_value=1, + ) + + def validate_page_size(self, value): + request = self.context.get("request") + is_anonymous = bool(request and request.user and request.user.is_anonymous) + max_value = ( + settings.MAX_ANONYMOUS_PAGE_SIZE + if is_anonymous + else settings.MAX_AUTHED_PAGE_SIZE + ) + + validator = MaxValueValidator( + max_value, + message=serializers.IntegerField.default_error_messages["max_value"].format( + max_value=max_value + ), + ) + + if is_anonymous: + try: + validator(value) + except ValidationError as e: + raise NotAuthenticated( + detail=e.message, + code=e.code, + ) + else: + validator(value) + + return value + + @property + def needs_db(self) -> bool: + return False + + @extend_schema_serializer( # Hide unstable and internal fields from documentation. # Also see `field_names` below. 
@@ -38,7 +96,7 @@ "internal__index", ], ) -class MediaSearchRequestSerializer(serializers.Serializer): +class MediaSearchRequestSerializer(PaginatedRequestSerializer): """This serializer parses and validates search query string parameters.""" DeprecatedParam = namedtuple("DeprecatedParam", ["original", "successor"]) @@ -64,8 +122,7 @@ class MediaSearchRequestSerializer(serializers.Serializer): # "unstable__authority", # "unstable__authority_boost", # "unstable__include_sensitive_results", - "page_size", - "page", + *PaginatedRequestSerializer.field_names, ] """ Keep the fields names in sync with the actual fields below as this list is @@ -179,22 +236,6 @@ class MediaSearchRequestSerializer(serializers.Serializer): required=False, ) - page_size = serializers.IntegerField( - label="page_size", - help_text="Number of results to return per page.", - required=False, - default=settings.MAX_ANONYMOUS_PAGE_SIZE, - min_value=1, - ) - page = serializers.IntegerField( - label="page", - help_text="The page of results to retrieve.", - required=False, - default=1, - max_value=settings.MAX_PAGINATION_DEPTH, - min_value=1, - ) - def is_request_anonymous(self): request = self.context.get("request") return bool(request and request.user and request.user.is_anonymous) @@ -283,34 +324,6 @@ def validate_internal__index(self, value): raise serializers.ValidationError(f"Invalid index name `{value}`.") return value - def validate_page_size(self, value): - is_anonymous = self.is_request_anonymous() - max_value = ( - settings.MAX_ANONYMOUS_PAGE_SIZE - if is_anonymous - else settings.MAX_AUTHED_PAGE_SIZE - ) - - validator = MaxValueValidator( - max_value, - message=serializers.IntegerField.default_error_messages["max_value"].format( - max_value=max_value - ), - ) - - if is_anonymous: - try: - validator(value) - except ValidationError as e: - raise NotAuthenticated( - detail=e.message, - code=e.code, - ) - else: - validator(value) - - return value - @staticmethod def validate_extension(value): 
return value.lower() @@ -329,10 +342,6 @@ def validate(self, data): return data - @property - def needs_db(self) -> bool: - return False - class MediaThumbnailRequestSerializer(serializers.Serializer): """This serializer parses and validates thumbnail query string parameters.""" diff --git a/api/api/views/audio_views.py b/api/api/views/audio_views.py index c26fe8055b8..66f242bba34 100644 --- a/api/api/views/audio_views.py +++ b/api/api/views/audio_views.py @@ -3,7 +3,12 @@ from rest_framework.exceptions import APIException, NotFound from rest_framework.response import Response -from drf_spectacular.utils import extend_schema, extend_schema_view +from drf_spectacular.utils import ( + OpenApiExample, + OpenApiResponse, + extend_schema, + extend_schema_view, +) from api.constants.media_types import AUDIO_TYPE from api.docs.audio_docs import ( @@ -17,11 +22,13 @@ ) from api.models import Audio from api.serializers.audio_serializers import ( + AudioCollectionRequestSerializer, AudioReportRequestSerializer, AudioSearchRequestSerializer, AudioSerializer, AudioWaveformSerializer, ) +from api.serializers.error_serializers import NotFoundErrorSerializer from api.serializers.media_serializers import MediaThumbnailRequestSerializer from api.utils.throttle import AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle from api.views.media_views import MediaViewSet @@ -38,15 +45,103 @@ class AudioViewSet(MediaViewSet): """Viewset for all endpoints pertaining to audio.""" model_class = Audio + media_type = AUDIO_TYPE query_serializer_class = AudioSearchRequestSerializer default_index = settings.MEDIA_INDEX_MAPPING[AUDIO_TYPE] serializer_class = AudioSerializer + collection_serializer_class = AudioCollectionRequestSerializer def get_queryset(self): return super().get_queryset().select_related("mature_audio", "audioset") # Extra actions + @extend_schema( + operation_id="audio_by_creator", + summary="audio_by_creator_at_source", + responses={ + 200: AudioSerializer(many=True), + 
404: OpenApiResponse( + NotFoundErrorSerializer, + examples=[ + OpenApiExample( + name="404", + value={ + "detail": "Invalid source 'source_name'. " + "Valid sources are ..." + }, + ) + ], + ), + }, + parameters=[AudioCollectionRequestSerializer], + ) + @action( + detail=False, + methods=["get"], + url_path="source/(?P[^/.]+)/creator/(?P.+)", + ) + def creator_collection(self, request, source, creator): + """ + Get a collection of audio items by a specific creator from the specified source. + + The items in the collection will be sorted by the order in which they were + added to Openverse. + """ + return super().creator_collection(request, source, creator) + + @extend_schema( + operation_id="audio_by_source", + summary="audio_by_source", + responses={ + 200: AudioSerializer(many=True), + 404: OpenApiResponse( + NotFoundErrorSerializer, + examples=[ + OpenApiExample( + name="404", + value={ + "detail": "Invalid source 'source_name'. " + "Valid sources are ..." + }, + ) + ], + ), + }, + parameters=[AudioCollectionRequestSerializer], + ) + @action( + detail=False, + methods=["get"], + url_path="source/(?P[^/.]+)", + ) + def source_collection(self, request, source): + """ + Get a collection of audio items from a specific source. + + The items in the collection will be sorted by the order in which they were + added to Openverse. + """ + return super().source_collection(request, source) + + @extend_schema( + operation_id="audio_by_tag", + summary="audio_by_tag", + responses={200: AudioSerializer(many=True)}, + parameters=[AudioCollectionRequestSerializer], + ) + @action( + detail=False, + methods=["get"], + url_path="tag/(?P[^/.]+)", + ) + def tag_collection(self, request, tag, *_, **__): + """ + Get a collection of audio items with a specific tag. + + The items will be ranked by their popularity and authority. 
+ """ + return super().tag_collection(request, tag, *_, **__) @thumbnail @action( diff --git a/api/api/views/image_views.py b/api/api/views/image_views.py index 6f138908db8..9a9c3f55f70 100644 --- a/api/api/views/image_views.py +++ b/api/api/views/image_views.py @@ -8,7 +8,12 @@ from rest_framework.response import Response import requests -from drf_spectacular.utils import extend_schema, extend_schema_view +from drf_spectacular.utils import ( + OpenApiExample, + OpenApiResponse, + extend_schema, + extend_schema_view, +) from PIL import Image as PILImage from api.constants.media_types import IMAGE_TYPE @@ -23,6 +28,7 @@ ) from api.docs.image_docs import watermark as watermark_doc from api.models import Image +from api.serializers.error_serializers import NotFoundErrorSerializer from api.serializers.image_serializers import ( ImageReportRequestSerializer, ImageSearchRequestSerializer, @@ -31,7 +37,10 @@ OembedSerializer, WatermarkRequestSerializer, ) -from api.serializers.media_serializers import MediaThumbnailRequestSerializer +from api.serializers.media_serializers import ( + MediaThumbnailRequestSerializer, + PaginatedRequestSerializer, +) from api.utils.throttle import AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle from api.utils.watermark import watermark from api.views.media_views import MediaViewSet @@ -48,6 +57,7 @@ class ImageViewSet(MediaViewSet): """Viewset for all endpoints pertaining to images.""" model_class = Image + media_type = IMAGE_TYPE query_serializer_class = ImageSearchRequestSerializer default_index = settings.MEDIA_INDEX_MAPPING[IMAGE_TYPE] @@ -61,6 +71,92 @@ def get_queryset(self): return super().get_queryset().select_related("mature_image") # Extra actions + @extend_schema( + operation_id="image_by_creator_at_source", + summary="image_by_creator_at_source", + responses={ + 200: ImageSerializer(many=True), + 404: OpenApiResponse( + NotFoundErrorSerializer, + examples=[ + OpenApiExample( + name="404", + value={ + "detail": "Invalid 
source 'source_name'. " + "Valid sources are ..." + }, + ) + ], + ), + }, + parameters=[PaginatedRequestSerializer], + ) + @action( + detail=False, + methods=["get"], + url_path="source/(?P[^/.]+)/creator/(?P.+)", + ) + def creator_collection(self, request, source, creator): + """ + Get a collection of images by a specific creator from the specified source. + + The images in the collection will be sorted by the order in which they were + added to Openverse. + """ + return super().creator_collection(request, source, creator) + + @extend_schema( + operation_id="image_by_source", + summary="image_by_source", + responses={ + 200: ImageSerializer(many=True), + 404: OpenApiResponse( + NotFoundErrorSerializer, + examples=[ + OpenApiExample( + name="404", + value={ + "detail": "Invalid source 'source_name'. " + "Valid sources are ..." + }, + ) + ], + ), + }, + parameters=[PaginatedRequestSerializer], + ) + @action( + detail=False, + methods=["get"], + url_path="source/(?P[^/.]+)", + ) + def source_collection(self, request, source, *_, **__): + """ + Get a collection of images from a specific source. + + The images in the collection will be sorted by the order in which they were + added to Openverse. + """ + return super().source_collection(request, source) + + @extend_schema( + operation_id="images_by_tag", + summary="images_by_tag", + responses={200: ImageSerializer(many=True)}, + parameters=[PaginatedRequestSerializer], + ) + @action( + detail=False, + methods=["get"], + url_path="tag/(?P[^/.]+)", + ) + def tag_collection(self, request, tag, *_, **__): + """ + Get a collection of images with a specific tag. + + The images in the collection will be ranked by their popularity and authority. 
+ """ + return super().tag_collection(request, tag, *_, **__) @oembed @action( diff --git a/api/api/views/media_views.py b/api/api/views/media_views.py index a0f28f68d55..ecf6005f410 100644 --- a/api/api/views/media_views.py +++ b/api/api/views/media_views.py @@ -1,4 +1,5 @@ import logging +from typing import Union from rest_framework import status from rest_framework.decorators import action @@ -6,10 +7,13 @@ from rest_framework.response import Response from rest_framework.viewsets import ReadOnlyModelViewSet +from api.constants.media_types import MediaType +from api.constants.search import SearchStrategy from api.controllers import search_controller from api.controllers.elasticsearch.related import related_media from api.models import ContentProvider from api.models.media import AbstractMedia +from api.serializers import audio_serializers, media_serializers from api.serializers.provider_serializers import ProviderSerializer from api.utils import image_proxy from api.utils.pagination import StandardPagination @@ -18,6 +22,18 @@ logger = logging.getLogger(__name__) +MediaListRequestSerializer = Union[ + audio_serializers.AudioCollectionRequestSerializer, + media_serializers.PaginatedRequestSerializer, + media_serializers.MediaSearchRequestSerializer, +] + + +class InvalidSource(APIException): + status_code = 400 + default_detail = "Invalid source." 
+ default_code = "invalid_source" + class MediaViewSet(ReadOnlyModelViewSet): lookup_field = "identifier" @@ -30,13 +46,16 @@ class MediaViewSet(ReadOnlyModelViewSet): # Populate these in the corresponding subclass model_class: type[AbstractMedia] = None + media_type: MediaType | None = None query_serializer_class = None + collection_serializer_class = None default_index = None def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) required_fields = [ self.model_class, + self.media_type, self.query_serializer_class, self.default_index, ] @@ -101,12 +120,64 @@ def retrieve(self, request, *_, **__): def list(self, request, *_, **__): params = self._get_request_serializer(request) + return self.get_media_results(request, "search", params) + + def _validate_source(self, source): + valid_sources = search_controller.get_sources(self.media_type) + if source not in valid_sources: + valid_string = ", ".join([f"'{k}'" for k in valid_sources.keys()]) + raise InvalidSource( + detail=f"Invalid source '{source}'. 
Valid sources are: {valid_string}.", + ) + def collection(self, request, tag, source, creator, *_, **__): + if tag: + collection_params = {"tag": tag} + elif creator: + collection_params = {"creator": creator, "source": source} + else: + collection_params = {"source": source} + if source: + self._validate_source(source) + + params = self.collection_serializer_class( + data=request.query_params, context={"request": request} + ) + params.is_valid(raise_exception=True) + + return self.get_media_results(request, "collection", params, collection_params) + + @action(detail=False, methods=["get"], url_path="tag/(?P[^/.]+)") + def tag_collection(self, request, tag, *_, **__): + tag_lower = tag.lower() + return self.collection(request, tag_lower, None, None) + + @action(detail=False, methods=["get"], url_path="source/(?P[^/.]+)") + def source_collection(self, request, source, *_, **__): + return self.collection(request, None, source, None) + + @action( + detail=False, + methods=["get"], + url_path="source/(?P[^/.]+)/creator/(?P.+)", + ) + def creator_collection(self, request, source, creator): + return self.collection(request, None, source, creator) + + # Common functionality for search and collection views + + def get_media_results( + self, + request, + strategy: SearchStrategy, + params: MediaListRequestSerializer, + collection_params: dict[str, str] | None = None, + ): page_size = self.paginator.page_size = params.data["page_size"] page = self.paginator.page = params.data["page"] hashed_ip = hash(self._get_user_ip(request)) - filter_dead = params.validated_data["filter_dead"] + filter_dead = params.validated_data.get("filter_dead", True) if pref_index := params.validated_data.get("index"): logger.info(f"Using preferred index {pref_index} for media.") @@ -118,8 +189,15 @@ def list(self, request, *_, **__): exact_index = False try: - results, num_pages, num_results, search_context = search_controller.search( + ( + results, + num_pages, + num_results, + search_context, + ) 
= search_controller.query_media( + strategy, params, + collection_params, search_index, exact_index, page_size, diff --git a/api/test/unit/controllers/test_search_controller.py b/api/test/unit/controllers/test_search_controller.py index b344a1074d6..9b7c04f57b6 100644 --- a/api/test/unit/controllers/test_search_controller.py +++ b/api/test/unit/controllers/test_search_controller.py @@ -3,6 +3,8 @@ import re from collections.abc import Callable from enum import Enum, auto + +from api.serializers.media_serializers import PaginatedRequestSerializer from test.factory.es_http import ( MOCK_DEAD_RESULT_URL_PREFIX, MOCK_LIVE_RESULT_URL_PREFIX, @@ -14,10 +16,12 @@ import pook import pytest from django_redis import get_redis_connection -from elasticsearch_dsl import Search +from elasticsearch_dsl import Q, Search from api.controllers import search_controller from api.controllers.elasticsearch import helpers as es_helpers +from api.controllers.search_controller import build_collection_query +from api.serializers import image_serializers from api.utils import tallies from api.utils.dead_link_mask import get_query_hash, save_query_mask from api.utils.search_context import SearchContext @@ -456,7 +460,8 @@ def test_search_tallies_pages_less_than_5( ) serializer.is_valid() - search_controller.search( + search_controller.query_media( + strategy="search", search_params=serializer, ip=0, origin_index=media_type_config.origin_index, @@ -495,7 +500,8 @@ def test_search_tallies_handles_empty_page( serializer = media_type_config.search_request_serializer(data={"q": "dogs"}) serializer.is_valid() - search_controller.search( + search_controller.query_media( + strategy="search", search_params=serializer, ip=0, origin_index=media_type_config.origin_index, @@ -538,7 +544,8 @@ def test_resolves_index( ) serializer.is_valid() - search_controller.search( + search_controller.query_media( + strategy="search", search_params=serializer, ip=0, origin_index=origin_index, @@ -605,7 +612,8 @@ def 
test_no_post_process_results_recursion( data={"q": "bird perched"} ) serializer.is_valid() - results, _, _, _ = search_controller.search( + results, _, _, _ = search_controller.query_media( + strategy="search", search_params=serializer, ip=0, origin_index=image_media_type_config.origin_index, @@ -743,7 +751,8 @@ def test_post_process_results_recurses_as_needed( data={"q": "bird perched"} ) serializer.is_valid() - results, _, _, _ = search_controller.search( + results, _, _, _ = search_controller.query_media( + strategy="search", search_params=serializer, ip=0, origin_index=image_media_type_config.origin_index, @@ -785,8 +794,10 @@ def _delete_all_results_but_first(_, __, results, ___): serializer.is_valid() with caplog.at_level(logging.INFO): - results, _, _, _ = search_controller.search( + results, _, _, _ = search_controller.query_media( + strategy="search", search_params=serializer, + collection_params=None, ip=0, origin_index=image_media_type_config.origin_index, exact_index=True, @@ -795,3 +806,63 @@ def _delete_all_results_but_first(_, __, results, ___): filter_dead=True, ) assert "Nesting threshold breached" in caplog.text + + +@pytest.mark.parametrize( + ("data", "expected_query"), + [ + pytest.param( + {"unstable__include_sensitive_results": False, "tag": "art"}, + Q( + "bool", + filter=[{"terms": {"tags.name.keyword": ["art"]}}], + must_not=[{"term": {"mature": True}}], + ), + id="filter_by_tag_without_sensitive", + ), + pytest.param( + {"unstable__include_sensitive_results": True, "tag": "art"}, + Q( + "bool", + filter=[{"terms": {"tags.name.keyword": ["art"]}}], + ), + id="filter_by_tag_with_sensitive", + ), + pytest.param( + {"unstable__include_sensitive_results": False, "source": "flickr"}, + Q( + "bool", + filter=[{"terms": {"source.keyword": ["flickr"]}}], + must_not=[{"term": {"mature": True}}], + ), + id="filter_by_source_without_sensitive", + ), + pytest.param( + { + "unstable__include_sensitive_results": False, + "source": "flickr", + "creator": 
"nasa", + }, + Q( + "bool", + filter=[ + {"terms": {"source.keyword": ["flickr"]}}, + {"terms": {"creator.keyword": ["nasa"]}}, + ], + must_not=[{"term": {"mature": True}}], + ), + id="filter_by_creator_without_sensitive", + ), + ], +) +@mock.patch("api.controllers.search_controller.Search", wraps=Search) +def test_build_collection_query(mock_search_class, data, expected_query): + # Setup + mock_search = mock_search_class.return_value + + # Action + build_collection_query(mock_search, data) + actual_query = mock_search.query.call_args[0][0] + + # Validate + assert actual_query == expected_query diff --git a/api/test/unit/models/test_media_report.py b/api/test/unit/models/test_media_report.py index 0433647015a..e936c47d356 100644 --- a/api/test/unit/models/test_media_report.py +++ b/api/test/unit/models/test_media_report.py @@ -1,5 +1,4 @@ import uuid -from typing import Literal, Union from django.core.exceptions import ObjectDoesNotExist @@ -22,8 +21,6 @@ pytestmark = pytest.mark.django_db -MediaType = Union[Literal["audio"], Literal["image"]] - reason_params = pytest.mark.parametrize("reason", [DMCA, MATURE, OTHER]) From 56139fa76edc66bddb9d5a1d6a6da8a3e79ecc63 Mon Sep 17 00:00:00 2001 From: Olga Bulat Date: Mon, 21 Aug 2023 19:06:37 +0300 Subject: [PATCH 02/16] Add and update tests Signed-off-by: Olga Bulat --- .../controllers/test_search_controller.py | 9 ++- .../test_search_controller_search_query.py | 58 ++++++++++++++++--- .../serializers/test_media_serializers.py | 12 ++-- api/test/unit/views/test_media_views.py | 34 ++++++++++- 4 files changed, 95 insertions(+), 18 deletions(-) diff --git a/api/test/unit/controllers/test_search_controller.py b/api/test/unit/controllers/test_search_controller.py index 9b7c04f57b6..82602a2c6f6 100644 --- a/api/test/unit/controllers/test_search_controller.py +++ b/api/test/unit/controllers/test_search_controller.py @@ -16,12 +16,10 @@ import pook import pytest from django_redis import get_redis_connection -from 
elasticsearch_dsl import Q, Search +from elasticsearch_dsl import Search from api.controllers import search_controller from api.controllers.elasticsearch import helpers as es_helpers -from api.controllers.search_controller import build_collection_query -from api.serializers import image_serializers from api.utils import tallies from api.utils.dead_link_mask import get_query_hash, save_query_mask from api.utils.search_context import SearchContext @@ -463,6 +461,7 @@ def test_search_tallies_pages_less_than_5( search_controller.query_media( strategy="search", search_params=serializer, + collection_params=None, ip=0, origin_index=media_type_config.origin_index, exact_index=False, @@ -503,6 +502,7 @@ def test_search_tallies_handles_empty_page( search_controller.query_media( strategy="search", search_params=serializer, + collection_params=None, ip=0, origin_index=media_type_config.origin_index, exact_index=False, @@ -547,6 +547,7 @@ def test_resolves_index( search_controller.query_media( strategy="search", search_params=serializer, + collection_params=None, ip=0, origin_index=origin_index, exact_index=False, @@ -615,6 +616,7 @@ def test_no_post_process_results_recursion( results, _, _, _ = search_controller.query_media( strategy="search", search_params=serializer, + collection_params=None, ip=0, origin_index=image_media_type_config.origin_index, exact_index=True, @@ -754,6 +756,7 @@ def test_post_process_results_recurses_as_needed( results, _, _, _ = search_controller.query_media( strategy="search", search_params=serializer, + collection_params=None, ip=0, origin_index=image_media_type_config.origin_index, exact_index=True, diff --git a/api/test/unit/controllers/test_search_controller_search_query.py b/api/test/unit/controllers/test_search_controller_search_query.py index 005c1bbbe10..566e718334c 100644 --- a/api/test/unit/controllers/test_search_controller_search_query.py +++ b/api/test/unit/controllers/test_search_controller_search_query.py @@ -1,6 +1,7 @@ from 
django.core.cache import cache import pytest +from elasticsearch_dsl import Q from api.controllers import search_controller @@ -23,7 +24,7 @@ def excluded_providers_cache(): def test_create_search_query_empty(media_type_config): serializer = media_type_config.search_request_serializer(data={}) serializer.is_valid(raise_exception=True) - search_query = search_controller.create_search_query(serializer) + search_query = search_controller.build_search_query(serializer) actual_query_clauses = search_query.to_dict()["bool"] assert actual_query_clauses == { @@ -39,7 +40,7 @@ def test_create_search_query_empty_no_ranking(media_type_config, settings): settings.USE_RANK_FEATURES = False serializer = media_type_config.search_request_serializer(data={}) serializer.is_valid(raise_exception=True) - search_query = search_controller.create_search_query(serializer) + search_query = search_controller.build_search_query(serializer) actual_query_clauses = search_query.to_dict()["bool"] assert actual_query_clauses == { @@ -51,7 +52,7 @@ def test_create_search_query_empty_no_ranking(media_type_config, settings): def test_create_search_query_q_search_no_filters(media_type_config): serializer = media_type_config.search_request_serializer(data={"q": "cat"}) serializer.is_valid(raise_exception=True) - search_query = search_controller.create_search_query(serializer) + search_query = search_controller.build_search_query(serializer) actual_query_clauses = search_query.to_dict()["bool"] assert actual_query_clauses == { @@ -83,7 +84,7 @@ def test_create_search_query_q_search_with_quotes_adds_exact_suffix(media_type_c data={"q": '"The cutest cat"'} ) serializer.is_valid(raise_exception=True) - search_query = search_controller.create_search_query(serializer) + search_query = search_controller.build_search_query(serializer) actual_query_clauses = search_query.to_dict()["bool"] assert actual_query_clauses == { @@ -127,7 +128,7 @@ def 
test_create_search_query_q_search_with_filters(image_media_type_config): } ) serializer.is_valid(raise_exception=True) - search_query = search_controller.create_search_query(serializer) + search_query = search_controller.build_search_query(serializer) actual_query_clauses = search_query.to_dict()["bool"] assert actual_query_clauses == { @@ -169,7 +170,7 @@ def test_create_search_query_non_q_query(image_media_type_config): } ) serializer.is_valid(raise_exception=True) - search_query = search_controller.create_search_query(serializer) + search_query = search_controller.build_search_query(serializer) actual_query_clauses = search_query.to_dict()["bool"] assert actual_query_clauses == { @@ -200,7 +201,7 @@ def test_create_search_query_q_search_license_license_type_creates_2_terms_filte } ) serializer.is_valid(raise_exception=True) - search_query = search_controller.create_search_query(serializer) + search_query = search_controller.build_search_query(serializer) actual_query_clauses = search_query.to_dict()["bool"] first_license_terms_filter = actual_query_clauses["filter"][0] @@ -235,7 +236,7 @@ def test_create_search_query_empty_with_dynamically_excluded_providers( serializer = image_media_type_config.search_request_serializer(data={}) serializer.is_valid(raise_exception=True) - search_query = search_controller.create_search_query(serializer) + search_query = search_controller.build_search_query(serializer) actual_query_clauses = search_query.to_dict()["bool"] assert actual_query_clauses == { @@ -248,3 +249,44 @@ def test_create_search_query_empty_with_dynamically_excluded_providers( {"rank_feature": {"boost": 10000, "field": "standardized_popularity"}} ], } + + +@pytest.mark.parametrize( + ("data", "expected_query_filter"), + [ + pytest.param( + {"tag": "art"}, + [{"term": {"tags.name": "art"}}], + id="filter_by_tag", + ), + pytest.param( + {"tag": "art, photography"}, + [{"term": {"tags.name": "art, photography"}}], + 
id="filter_by_tag_treats_punctuation_as_part_of_tag", + ), + pytest.param( + {"source": "flickr"}, + [{"term": {"source": "flickr"}}], + id="filter_by_source", + ), + pytest.param( + {"source": "flickr", "creator": "nasa"}, + [ + {"term": {"source": "flickr"}}, + {"term": {"creator": "nasa"}}, + ], + id="filter_by_creator", + ), + ], +) +def test_build_collection_query(image_media_type_config, data, expected_query_filter): + serializer = image_media_type_config.search_request_serializer(data={}) + serializer.is_valid(raise_exception=True) + actual_query = search_controller.build_collection_query(serializer, data) + expected_query = Q( + "bool", + filter=expected_query_filter, + must_not=[{"term": {"mature": True}}], + ) + + assert actual_query == expected_query diff --git a/api/test/unit/serializers/test_media_serializers.py b/api/test/unit/serializers/test_media_serializers.py index 46edaedea8f..d48798c40cb 100644 --- a/api/test/unit/serializers/test_media_serializers.py +++ b/api/test/unit/serializers/test_media_serializers.py @@ -195,17 +195,17 @@ def test_index_is_only_set_if_authenticated( mock_es.indices.exists.return_value = True request = authed_request if authenticated else anon_request - serializer = MediaSearchRequestSerializer( - data={"internal__index": "some-index"}, context={"request": request} + serializer = ImageSearchRequestSerializer( + data={"internal__index": "image-some-index"}, context={"request": request} ) assert serializer.is_valid() assert serializer.validated_data.get("index") == ( - "some-index" if authenticated else None + "image-some-index" if authenticated else None ) if authenticated: # If authenticated, we should have checked that the index exists. - mock_es.indices.exists.assert_called_with("some-index") + mock_es.indices.exists.assert_called_with("image-some-index") else: # If not authenticated, the validator quickly returns ``None``. 
mock_es.indices.exists.assert_not_called() @@ -215,12 +215,12 @@ def test_index_is_only_set_if_authenticated( @patch("django.conf.settings.ES") @pytest.mark.parametrize( "index, is_valid", - (("index-that-exists", True), ("index-that-does-not-exist", False)), + (("image-index-that-exists", True), ("image-index-that-does-not-exist", False)), ) def test_index_is_only_set_if_valid(mock_es, index, is_valid, authed_request): mock_es.indices.exists = lambda index: "exists" in index - serializer = MediaSearchRequestSerializer( + serializer = ImageSearchRequestSerializer( data={"internal__index": index}, context={"request": authed_request} ) assert serializer.is_valid() == is_valid diff --git a/api/test/unit/views/test_media_views.py b/api/test/unit/views/test_media_views.py index b94e309d168..e95c320ce97 100644 --- a/api/test/unit/views/test_media_views.py +++ b/api/test/unit/views/test_media_views.py @@ -2,6 +2,8 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 +from rest_framework.response import Response + import pytest import pytest_django.asserts @@ -26,7 +28,7 @@ def test_list_query_count(api_client, media_type_config): ) with patch( "api.views.media_views.search_controller", - search=MagicMock(return_value=controller_ret), + query_media=MagicMock(return_value=controller_ret), ), patch( "api.serializers.media_serializers.search_controller", get_sources=MagicMock(return_value={}), @@ -49,6 +51,36 @@ def test_retrieve_query_count(api_client, media_type_config): assert res.status_code == 200 +@pytest.mark.django_db +@pytest.mark.parametrize( + "path, expected_params", + [ + pytest.param("tag/cat/", {"tag": "cat"}, id="tag"), + pytest.param("source/flickr/", {"source": "flickr"}, id="source"), + pytest.param( + "source/flickr/creator/cat/", + {"source": "flickr", "creator": "cat"}, + id="source_creator", + ), + ], +) +def test_collection_parameters(path, expected_params, api_client): + mock_get_media_results = MagicMock(return_value=Response()) + + 
with patch( + "api.views.media_views.MediaViewSet.get_media_results", + new_callable=lambda: mock_get_media_results, + ) as mock_get_media_results: + api_client.get(f"/v1/images/{path}") + + actual_params = mock_get_media_results.call_args[0][3] + request_kind = mock_get_media_results.call_args[0][1] + + assert mock_get_media_results.called + assert actual_params == expected_params + assert request_kind == "collection" + + @pytest.mark.parametrize( "filter_content", (True, False), ids=lambda x: "filtered" if x else "not_filtered" ) From 09f245caa7479db5ab5ad73ebf0039a324de6804 Mon Sep 17 00:00:00 2001 From: Olga Bulat Date: Thu, 28 Sep 2023 16:42:01 +0300 Subject: [PATCH 03/16] Remove double plurals in field_names --- api/api/docs/audio_docs.py | 2 +- api/api/docs/image_docs.py | 2 +- api/api/serializers/audio_serializers.py | 4 ++-- api/api/serializers/image_serializers.py | 4 ++-- api/api/serializers/media_serializers.py | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/api/api/docs/audio_docs.py b/api/api/docs/audio_docs.py index 5c50d5481a6..39a3bf99252 100644 --- a/api/api/docs/audio_docs.py +++ b/api/api/docs/audio_docs.py @@ -39,7 +39,7 @@ By using this endpoint, you can obtain search results based on specified query and optionally filter results by - {fields_to_md(AudioSearchRequestSerializer.fields_names)}. + {fields_to_md(AudioSearchRequestSerializer.field_names)}. Results are ranked in order of relevance and paginated on the basis of the `page` param. The `page_size` param controls the total number of pages. diff --git a/api/api/docs/image_docs.py b/api/api/docs/image_docs.py index 58b89d883cd..bc4ef7c59a3 100644 --- a/api/api/docs/image_docs.py +++ b/api/api/docs/image_docs.py @@ -40,7 +40,7 @@ By using this endpoint, you can obtain search results based on specified query and optionally filter results by - {fields_to_md(ImageSearchRequestSerializer.fields_names)}. + {fields_to_md(ImageSearchRequestSerializer.field_names)}. 
Results are ranked in order of relevance and paginated on the basis of the `page` param. The `page_size` param controls the total number of pages. diff --git a/api/api/serializers/audio_serializers.py b/api/api/serializers/audio_serializers.py index 53236feb0d4..0da59296267 100644 --- a/api/api/serializers/audio_serializers.py +++ b/api/api/serializers/audio_serializers.py @@ -48,8 +48,8 @@ class AudioSearchRequestSerializer( ): """Parse and validate search query string parameters.""" - fields_names = [ - *MediaSearchRequestSerializer.fields_names, + field_names = [ + *MediaSearchRequestSerializer.field_names, *AudioSearchRequestSourceSerializer.field_names, "category", "length", diff --git a/api/api/serializers/image_serializers.py b/api/api/serializers/image_serializers.py index 5824198936c..d483a5d5e2e 100644 --- a/api/api/serializers/image_serializers.py +++ b/api/api/serializers/image_serializers.py @@ -33,8 +33,8 @@ class ImageSearchRequestSerializer( ): """Parse and validate search query string parameters.""" - fields_names = [ - *MediaSearchRequestSerializer.fields_names, + field_names = [ + *MediaSearchRequestSerializer.field_names, *ImageSearchRequestSourceSerializer.field_names, "category", "aspect_ratio", diff --git a/api/api/serializers/media_serializers.py b/api/api/serializers/media_serializers.py index b022182f69d..44564016f08 100644 --- a/api/api/serializers/media_serializers.py +++ b/api/api/serializers/media_serializers.py @@ -106,7 +106,7 @@ class MediaSearchRequestSerializer(PaginatedRequestSerializer): DeprecatedParam("pagesize", "page_size"), DeprecatedParam("provider", "source"), ] - fields_names = [ + field_names = [ "q", "license", "license_type", From 383bf3c05883cdec6892170272603039725f7520 Mon Sep 17 00:00:00 2001 From: Olga Bulat Date: Thu, 28 Sep 2023 16:54:11 +0300 Subject: [PATCH 04/16] Add docs to media_serializers Notes on fuzzy matching for query params and maximum page_size documentation --- 
api/api/serializers/media_serializers.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/api/api/serializers/media_serializers.py b/api/api/serializers/media_serializers.py index 44564016f08..d8a52d4c1f8 100644 --- a/api/api/serializers/media_serializers.py +++ b/api/api/serializers/media_serializers.py @@ -36,7 +36,9 @@ class PaginatedRequestSerializer(serializers.Serializer): page_size = serializers.IntegerField( label="page_size", help_text=f"Number of results to return per page. " - f"Maximum for unauthenticated requests is {settings.MAX_ANONYMOUS_PAGE_SIZE}.", + f"Maximum is {settings.MAX_AUTHED_PAGE_SIZE} for authenticated " + f"requests, and {settings.MAX_ANONYMOUS_PAGE_SIZE} for " + f"unauthenticated requests.", required=False, default=settings.MAX_ANONYMOUS_PAGE_SIZE, min_value=1, @@ -148,19 +150,30 @@ class MediaSearchRequestSerializer(PaginatedRequestSerializer): ) creator = serializers.CharField( label="creator", - help_text="Search by creator only. Cannot be used with `q`.", + help_text="Search by creator only. Cannot be used with `q`. The search " + "is fuzzy, so `creator=john` will match any value that includes the " + "word `john`. If the value contains space, items that contain any of " + "the words in the value will match. To search for several values, " + "join them with a comma.", required=False, max_length=200, ) tags = serializers.CharField( label="tags", - help_text="Search by tag only. Cannot be used with `q`.", + help_text="Search by tag only. Cannot be used with `q`. The search " + "is fuzzy, so `tags=cat` will match any value that includes the word " + "`cat`. If the value contains space, items that contain any of the " + "words in the value will match. To search for several values, join " + "them with a comma.", required=False, max_length=200, ) title = serializers.CharField( label="title", - help_text="Search by title only. Cannot be used with `q`.", + help_text="Search by title only. 
Cannot be used with `q`. The search is fuzzy," + " so `title=photo` will match any value that includes the word `photo`. " + "If the value contains space, items that contain any of the words in the " + "value will match. To search for several values, join them with a comma.", required=False, max_length=200, ) From 00523508c728821b27ca04cf628b01bf4af70c0f Mon Sep 17 00:00:00 2001 From: Olga Bulat Date: Thu, 28 Sep 2023 17:07:11 +0300 Subject: [PATCH 05/16] Improve collection documentation Co-authored-by: Staci Mullins <63313398+stacimc@users.noreply.github.com> --- api/api/docs/audio_docs.py | 15 ++++++- api/api/docs/image_docs.py | 15 ++++++- api/api/views/audio_views.py | 77 ++++------------------------------- api/api/views/image_views.py | 78 ++++-------------------------------- 4 files changed, 43 insertions(+), 142 deletions(-) diff --git a/api/api/docs/audio_docs.py b/api/api/docs/audio_docs.py index 39a3bf99252..0b11489e3bc 100644 --- a/api/api/docs/audio_docs.py +++ b/api/api/docs/audio_docs.py @@ -1,6 +1,6 @@ from drf_spectacular.utils import OpenApiResponse, extend_schema -from api.docs.base_docs import custom_extend_schema, fields_to_md +from api.docs.base_docs import collection_schema, custom_extend_schema, fields_to_md from api.examples import ( audio_complain_201_example, audio_complain_curl, @@ -116,3 +116,16 @@ }, eg=[audio_waveform_curl], ) + +source_collection = collection_schema( + media_type="audio", + collection="source", +) +creator_collection = collection_schema( + media_type="audio", + collection="creator", +) +tag_collection = collection_schema( + media_type="audio", + collection="tag", +) diff --git a/api/api/docs/image_docs.py b/api/api/docs/image_docs.py index bc4ef7c59a3..ae603224f7f 100644 --- a/api/api/docs/image_docs.py +++ b/api/api/docs/image_docs.py @@ -1,6 +1,6 @@ from drf_spectacular.utils import OpenApiResponse, extend_schema -from api.docs.base_docs import custom_extend_schema, fields_to_md +from api.docs.base_docs import 
collection_schema, custom_extend_schema, fields_to_md from api.examples import ( image_complain_201_example, image_complain_curl, @@ -122,3 +122,16 @@ watermark = custom_extend_schema( deprecated=True, ) + +source_collection = collection_schema( + media_type="images", + collection="source", +) +creator_collection = collection_schema( + media_type="images", + collection="creator", +) +tag_collection = collection_schema( + media_type="images", + collection="tag", +) diff --git a/api/api/views/audio_views.py b/api/api/views/audio_views.py index 66f242bba34..6457b233a09 100644 --- a/api/api/views/audio_views.py +++ b/api/api/views/audio_views.py @@ -3,20 +3,18 @@ from rest_framework.exceptions import APIException, NotFound from rest_framework.response import Response -from drf_spectacular.utils import ( - OpenApiExample, - OpenApiResponse, - extend_schema, - extend_schema_view, -) +from drf_spectacular.utils import extend_schema, extend_schema_view from api.constants.media_types import AUDIO_TYPE from api.docs.audio_docs import ( + creator_collection, detail, related, report, search, + source_collection, stats, + tag_collection, thumbnail, waveform, ) @@ -28,7 +26,6 @@ AudioSerializer, AudioWaveformSerializer, ) -from api.serializers.error_serializers import NotFoundErrorSerializer from api.serializers.media_serializers import MediaThumbnailRequestSerializer from api.utils.throttle import AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle from api.views.media_views import MediaViewSet @@ -56,91 +53,31 @@ def get_queryset(self): return super().get_queryset().select_related("mature_audio", "audioset") # Extra actions - @extend_schema( - operation_id="audio_by_creator", - summary="audio_by_creator_at_source", - responses={ - 200: AudioSerializer(many=True), - 404: OpenApiResponse( - NotFoundErrorSerializer, - examples=[ - OpenApiExample( - name="404", - value={ - "detail": "Invalid source 'source_name'. " - "Valid sources are ..." 
- }, - ) - ], - ), - }, - parameters=[AudioCollectionRequestSerializer], - ) + @creator_collection @action( detail=False, methods=["get"], url_path="source/(?P<source>[^/.]+)/creator/(?P<creator>.+)", ) def creator_collection(self, request, source, creator): - """ - Get a collection of audio items by a specific creator from the specified source. - - The items in the collection will be sorted by the order in which they were - added to Openverse. - """ return super().creator_collection(request, source, creator) - @extend_schema( - operation_id="audio_by_source", - summary="audio_by_source", - responses={ - 200: AudioSerializer(many=True), - 404: OpenApiResponse( - NotFoundErrorSerializer, - examples=[ - OpenApiExample( - name="404", - value={ - "detail": "Invalid source 'source_name'. " - "Valid sources are ..." - }, - ) - ], - ), - }, - parameters=[AudioCollectionRequestSerializer], - ) + @source_collection @action( detail=False, methods=["get"], url_path="source/(?P<source>[^/.]+)", ) def source_collection(self, request, source): - """ - Get a collection of audio items from a specific source. - - The items in the collection will be sorted by the order in which they were - added to Openverse. - """ return super().source_collection(request, source) - @extend_schema( - operation_id="audio_by_tag", - summary="audio_by_tag", - responses={200: AudioSerializer(many=True)}, - parameters=[AudioCollectionRequestSerializer], - ) + @tag_collection @action( detail=False, methods=["get"], url_path="tag/(?P<tag>[^/.]+)", ) def tag_collection(self, request, tag, *_, **__): - """ - Get a collection of audio items with a specific tag. - - The items will be ranked by their popularity and authority.
- """ return super().tag_collection(request, tag, *_, **__) @thumbnail diff --git a/api/api/views/image_views.py b/api/api/views/image_views.py index 9a9c3f55f70..641a230f384 100644 --- a/api/api/views/image_views.py +++ b/api/api/views/image_views.py @@ -8,27 +8,24 @@ from rest_framework.response import Response import requests -from drf_spectacular.utils import ( - OpenApiExample, - OpenApiResponse, - extend_schema, - extend_schema_view, -) +from drf_spectacular.utils import extend_schema, extend_schema_view from PIL import Image as PILImage from api.constants.media_types import IMAGE_TYPE from api.docs.image_docs import ( + creator_collection, detail, oembed, related, report, search, + source_collection, stats, + tag_collection, thumbnail, ) from api.docs.image_docs import watermark as watermark_doc from api.models import Image -from api.serializers.error_serializers import NotFoundErrorSerializer from api.serializers.image_serializers import ( ImageReportRequestSerializer, ImageSearchRequestSerializer, @@ -62,6 +59,7 @@ class ImageViewSet(MediaViewSet): default_index = settings.MEDIA_INDEX_MAPPING[IMAGE_TYPE] serializer_class = ImageSerializer + collection_serializer_class = PaginatedRequestSerializer OEMBED_HEADERS = { "User-Agent": settings.OUTBOUND_USER_AGENT_TEMPLATE.format(purpose="OEmbed"), @@ -71,91 +69,31 @@ def get_queryset(self): return super().get_queryset().select_related("mature_image") # Extra actions - @extend_schema( - operation_id="image_by_creator_at_source", - summary="image_by_creator_at_source", - responses={ - 200: ImageSerializer(many=True), - 404: OpenApiResponse( - NotFoundErrorSerializer, - examples=[ - OpenApiExample( - name="404", - value={ - "detail": "Invalid source 'source_name'. " - "Valid sources are ..." 
- }, - ) - ], - ), - }, - parameters=[PaginatedRequestSerializer], - ) + @creator_collection @action( detail=False, methods=["get"], url_path="source/(?P<source>[^/.]+)/creator/(?P<creator>.+)", ) def creator_collection(self, request, source, creator): - """ - Get a collection of images by a specific creator from the specified source. - - The images in the collection will be sorted by the order in which they were - added to Openverse. - """ return super().creator_collection(request, source, creator) - @extend_schema( - operation_id="image_by_source", - summary="image_by_source", - responses={ - 200: ImageSerializer(many=True), - 404: OpenApiResponse( - NotFoundErrorSerializer, - examples=[ - OpenApiExample( - name="404", - value={ - "detail": "Invalid source 'source_name'. " - "Valid sources are ..." - }, - ) - ], - ), - }, - parameters=[PaginatedRequestSerializer], - ) + @source_collection @action( detail=False, methods=["get"], url_path="source/(?P<source>[^/.]+)", ) def source_collection(self, request, source, *_, **__): - """ - Get a collection of images from a specific source. - - The images in the collection will be sorted by the order in which they were - added to Openverse. - """ return super().source_collection(request, source) - @extend_schema( - operation_id="images_by_tag", - summary="images_by_tag", - responses={200: ImageSerializer(many=True)}, - parameters=[PaginatedRequestSerializer], - ) + @tag_collection @action( detail=False, methods=["get"], url_path="tag/(?P<tag>[^/.]+)", ) def tag_collection(self, request, tag, *_, **__): - """ - Get a collection of images with a specific tag. - - The images in the collection will be ranked by their popularity and authority.
- """ return super().tag_collection(request, tag, *_, **__) @oembed From f6713df0cc6c0385cb8c537cd9732594d3ff5954 Mon Sep 17 00:00:00 2001 From: Olga Bulat Date: Fri, 10 Nov 2023 09:24:38 +0300 Subject: [PATCH 06/16] Add integration tests Signed-off-by: Olga Bulat --- api/test/media_integration.py | 40 ++++++++++++++++++++++++++++++ api/test/test_audio_integration.py | 15 +++++++++++ api/test/test_image_integration.py | 15 +++++++++++ 3 files changed, 70 insertions(+) diff --git a/api/test/media_integration.py b/api/test/media_integration.py index c1704f12047..d78ac60edfc 100644 --- a/api/test/media_integration.py +++ b/api/test/media_integration.py @@ -26,6 +26,46 @@ def search_by_category(media_path, category, fixture): assert all(audio_item["category"] == category for audio_item in results) +def tag_collection(media_path): + response = requests.get(f"{API_URL}/v1/{media_path}/tag/cat") + assert response.status_code == 200 + + results = response.json()["results"] + for r in results: + tag_names = [tag["name"] for tag in r["tags"]] + assert "cat" in tag_names + + +def source_collection(media_path): + source = requests.get(f"{API_URL}/v1/{media_path}/stats").json()[0]["source_name"] + + response = requests.get(f"{API_URL}/v1/{media_path}/source/{source}") + assert response.status_code == 200 + + results = response.json()["results"] + assert all(result["source"] == source for result in results) + + +def creator_collection(media_path): + source = requests.get(f"{API_URL}/v1/{media_path}/stats").json()[0]["source_name"] + + first_res = requests.get(f"{API_URL}/v1/{media_path}/source/{source}").json()[ + "results" + ][0] + if not (creator := first_res.get("creator")): + raise AttributeError(f"No creator in {first_res}") + + response = requests.get( + f"{API_URL}/v1/{media_path}/source/{source}/creator/{creator}" + ) + assert response.status_code == 200 + + results = response.json()["results"] + assert all( + r["creator"] == creator and r["source"] == source for
r in results + ) + + def search_all_excluded(media_path, excluded_source): response = requests.get( f"{API_URL}/v1/{media_path}?q=test&excluded_source={','.join(excluded_source)}" diff --git a/api/test/test_audio_integration.py b/api/test/test_audio_integration.py index 9c4aa0210da..c9e355739af 100644 --- a/api/test/test_audio_integration.py +++ b/api/test/test_audio_integration.py @@ -8,6 +8,7 @@ import json from test.constants import API_URL from test.media_integration import ( + creator_collection, detail, license_filter_case_insensitivity, related, @@ -21,7 +22,9 @@ search_source_and_excluded, search_special_chars, sensitive_search_and_detail, + source_collection, stats, + tag_collection, uuid_validation, ) @@ -160,5 +163,17 @@ def test_audio_related(audio_fixture): related(audio_fixture) +def test_audio_tag_collection(): + tag_collection("audio") + + +def test_audio_source_collection(): + source_collection("audio") + + +def test_audio_creator_collection(): + creator_collection("audio") + + def test_audio_sensitive_search_and_detail(): sensitive_search_and_detail("audio") diff --git a/api/test/test_image_integration.py b/api/test/test_image_integration.py index 36fe7a33f67..f87d3fc16c9 100644 --- a/api/test/test_image_integration.py +++ b/api/test/test_image_integration.py @@ -8,6 +8,7 @@ import json from test.constants import API_URL from test.media_integration import ( + creator_collection, detail, license_filter_case_insensitivity, related, @@ -20,7 +21,9 @@ search_source_and_excluded, search_special_chars, sensitive_search_and_detail, + source_collection, stats, + tag_collection, uuid_validation, ) from urllib.parse import urlencode @@ -145,6 +148,18 @@ def test_image_uuid_validation(): uuid_validation("images", "abcd") +def test_image_tag_collection(): + tag_collection("images") + + +def test_image_source_collection(): + source_collection("images") + + +def test_image_creator_collection(): + creator_collection("images") + + def 
test_image_related(image_fixture): related(image_fixture) From 419c0df64e9d1a3f1056237a7cae285c22d5e20a Mon Sep 17 00:00:00 2001 From: Olga Bulat Date: Fri, 10 Nov 2023 09:38:47 +0300 Subject: [PATCH 07/16] Update docs Signed-off-by: Olga Bulat --- api/api/docs/base_docs.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/api/api/docs/base_docs.py b/api/api/docs/base_docs.py index 41caf7e8236..92a3f9637e9 100644 --- a/api/api/docs/base_docs.py +++ b/api/api/docs/base_docs.py @@ -140,13 +140,13 @@ def get_collection_description(media_type, collection): return f""" Get a collection of {media_type} with a specific tag. -This endpoint returns only the exact matches. To search within the -tag values, or to match several tags, use the `search` endpoint -with `tags` query parameter instead of `q` parameter. +This endpoint returns only the exact matches, case-insensitive matches for the +specified tag. For example, 'birds' and 'birding' are not matches for 'bird'. +To search within the tag values, or to match several tags, use the `search` endpoint +with `tags` query parameter instead of `q` parameter. In this case, the matches will + not be exact, so 'cat' would match both 'cat' and 'cats'. -The returned results are ordered primarily based on their popularity -and authority. However, note that the exact order may vary over time -or across requests. +The returned results are ordered based on the time when they were added to Openverse. """ elif collection == "source": return f""" @@ -168,8 +168,7 @@ def get_collection_description(media_type, collection): the creator value, use the `search` endpoint with `source` query parameter instead of `q`. -The order in the results is not guaranteed to stay the same. Most likely, the images -in the collection will be sorted by the order in which they were added to Openverse. +The items will be sorted by the date when they were added to Openverse. 
""" From 79bf445a934964cb96e7c101575334176449b43e Mon Sep 17 00:00:00 2001 From: Olga Bulat Date: Mon, 13 Nov 2023 05:57:08 +0300 Subject: [PATCH 08/16] Update api/api/controllers/search_controller.py Co-authored-by: sarayourfriend <24264157+sarayourfriend@users.noreply.github.com> --- api/api/controllers/search_controller.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/api/controllers/search_controller.py b/api/api/controllers/search_controller.py index e6eb9844580..ba58e0c5b01 100644 --- a/api/api/controllers/search_controller.py +++ b/api/api/controllers/search_controller.py @@ -337,7 +337,7 @@ def build_collection_query( # names are identical. filters = [ # Collection filters allow a single value. - ("tag", "tags.name"), + ("tag", "tags.name.keyword"), ("source", None), ("creator", None), ] From 545d45e4b3ce8337af419023a483d8e4502dbb3a Mon Sep 17 00:00:00 2001 From: Olga Bulat Date: Mon, 13 Nov 2023 05:57:27 +0300 Subject: [PATCH 09/16] Update api/api/controllers/search_controller.py Co-authored-by: sarayourfriend <24264157+sarayourfriend@users.noreply.github.com> --- api/api/controllers/search_controller.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/api/controllers/search_controller.py b/api/api/controllers/search_controller.py index ba58e0c5b01..e04534fff75 100644 --- a/api/api/controllers/search_controller.py +++ b/api/api/controllers/search_controller.py @@ -354,6 +354,7 @@ def build_collection_query( ) if not include_sensitive_by_params: search_query["must_not"].append({"term": {"mature": True}}) + if excluded_providers_query := get_excluded_providers_query(): search_query["must_not"].append(excluded_providers_query) From 7317ca79d74843d3099de1d904c6adbcf7a5f418 Mon Sep 17 00:00:00 2001 From: Olga Bulat Date: Mon, 13 Nov 2023 06:28:43 +0300 Subject: [PATCH 10/16] Update test Signed-off-by: Olga Bulat --- .../controllers/test_search_controller.py | 62 ------------------- .../test_search_controller_search_query.py | 4 +- 2 
files changed, 2 insertions(+), 64 deletions(-) diff --git a/api/test/unit/controllers/test_search_controller.py b/api/test/unit/controllers/test_search_controller.py index 82602a2c6f6..764fffca21f 100644 --- a/api/test/unit/controllers/test_search_controller.py +++ b/api/test/unit/controllers/test_search_controller.py @@ -3,8 +3,6 @@ import re from collections.abc import Callable from enum import Enum, auto - -from api.serializers.media_serializers import PaginatedRequestSerializer from test.factory.es_http import ( MOCK_DEAD_RESULT_URL_PREFIX, MOCK_LIVE_RESULT_URL_PREFIX, @@ -809,63 +807,3 @@ def _delete_all_results_but_first(_, __, results, ___): filter_dead=True, ) assert "Nesting threshold breached" in caplog.text - - -@pytest.mark.parametrize( - ("data", "expected_query"), - [ - pytest.param( - {"unstable__include_sensitive_results": False, "tag": "art"}, - Q( - "bool", - filter=[{"terms": {"tags.name.keyword": ["art"]}}], - must_not=[{"term": {"mature": True}}], - ), - id="filter_by_tag_without_sensitive", - ), - pytest.param( - {"unstable__include_sensitive_results": True, "tag": "art"}, - Q( - "bool", - filter=[{"terms": {"tags.name.keyword": ["art"]}}], - ), - id="filter_by_tag_with_sensitive", - ), - pytest.param( - {"unstable__include_sensitive_results": False, "source": "flickr"}, - Q( - "bool", - filter=[{"terms": {"source.keyword": ["flickr"]}}], - must_not=[{"term": {"mature": True}}], - ), - id="filter_by_source_without_sensitive", - ), - pytest.param( - { - "unstable__include_sensitive_results": False, - "source": "flickr", - "creator": "nasa", - }, - Q( - "bool", - filter=[ - {"terms": {"source.keyword": ["flickr"]}}, - {"terms": {"creator.keyword": ["nasa"]}}, - ], - must_not=[{"term": {"mature": True}}], - ), - id="filter_by_creator_without_sensitive", - ), - ], -) -@mock.patch("api.controllers.search_controller.Search", wraps=Search) -def test_build_collection_query(mock_search_class, data, expected_query): - # Setup - mock_search = 
mock_search_class.return_value - - # Action - build_collection_query(mock_search, data) - actual_query = mock_search.query.call_args[0][0] - - # Validate - assert actual_query == expected_query diff --git a/api/test/unit/controllers/test_search_controller_search_query.py b/api/test/unit/controllers/test_search_controller_search_query.py index 566e718334c..4ebee63f435 100644 --- a/api/test/unit/controllers/test_search_controller_search_query.py +++ b/api/test/unit/controllers/test_search_controller_search_query.py @@ -256,12 +256,12 @@ def test_create_search_query_empty_with_dynamically_excluded_providers( [ pytest.param( {"tag": "art"}, - [{"term": {"tags.name": "art"}}], + [{"term": {"tags.name.keyword": "art"}}], id="filter_by_tag", ), pytest.param( {"tag": "art, photography"}, - [{"term": {"tags.name": "art, photography"}}], + [{"term": {"tags.name.keyword": "art, photography"}}], id="filter_by_tag_treats_punctuation_as_part_of_tag", ), pytest.param( From 03f79fb1347962ad403f5f5816c738c641d9f1e3 Mon Sep 17 00:00:00 2001 From: Olga Bulat Date: Mon, 13 Nov 2023 06:45:08 +0300 Subject: [PATCH 11/16] Combine TYPE_CHECKING clauses Signed-off-by: Olga Bulat --- api/api/controllers/search_controller.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/api/api/controllers/search_controller.py b/api/api/controllers/search_controller.py index e04534fff75..df9a9d51646 100644 --- a/api/api/controllers/search_controller.py +++ b/api/api/controllers/search_controller.py @@ -29,6 +29,8 @@ from api.utils.dead_link_mask import get_query_hash from api.utils.search_context import SearchContext + +# Using TYPE_CHECKING to avoid circular imports when importing types if TYPE_CHECKING: from api.serializers.audio_serializers import AudioCollectionRequestSerializer from api.serializers.media_serializers import ( @@ -36,6 +38,12 @@ PaginatedRequestSerializer, ) + MediaListRequestSerializer = ( + AudioCollectionRequestSerializer + | MediaSearchRequestSerializer + | 
PaginatedRequestSerializer + ) + module_logger = logging.getLogger(__name__) From e4fc8e80dab2e41cae2feca186f2d8312a101858 Mon Sep 17 00:00:00 2001 From: Olga Bulat Date: Mon, 13 Nov 2023 06:56:00 +0300 Subject: [PATCH 12/16] Extract build_query to clean up query_media Signed-off-by: Olga Bulat --- api/api/controllers/search_controller.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/api/api/controllers/search_controller.py b/api/api/controllers/search_controller.py index df9a9d51646..49fd2da4957 100644 --- a/api/api/controllers/search_controller.py +++ b/api/api/controllers/search_controller.py @@ -369,6 +369,16 @@ def build_collection_query( return Q("bool", **search_query) +def build_query( + strategy: SearchStrategy, + search_params: MediaListRequestSerializer, + collection_params: dict[str, str] | None, +) -> Q: + if strategy == "collection": + return build_collection_query(search_params, collection_params) + return build_search_query(search_params) + + def query_media( strategy: SearchStrategy, search_params: MediaListRequestSerializer, @@ -405,10 +415,7 @@ def query_media( """ index = get_index(exact_index, origin_index, search_params) - if strategy == "collection": - query = build_collection_query(search_params, collection_params) - else: - query = build_search_query(search_params) + query = build_query(strategy, search_params, collection_params) s = Search(index=index).query(query) @@ -426,8 +433,8 @@ def query_media( # Sort by `created_on` if the parameter is set or if `strategy` is `collection`. 
sort_by = search_params.validated_data.get("sort_by") - sort_dir = search_params.validated_data.get("sort_dir", "desc") if strategy == "collection" or sort_by == INDEXED_ON: + sort_dir = search_params.validated_data.get("sort_dir", "desc") s = s.sort({"created_on": {"order": sort_dir}}) # Execute paginated search and tally results From 9666f0bd7f249f946cb30b7f0daeb1a9bad616f4 Mon Sep 17 00:00:00 2001 From: Olga Bulat Date: Tue, 14 Nov 2023 07:31:22 +0300 Subject: [PATCH 13/16] Use the keyword creator field Signed-off-by: Olga Bulat --- api/api/controllers/search_controller.py | 2 +- .../unit/controllers/test_search_controller_search_query.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/api/api/controllers/search_controller.py b/api/api/controllers/search_controller.py index 49fd2da4957..0740587ea06 100644 --- a/api/api/controllers/search_controller.py +++ b/api/api/controllers/search_controller.py @@ -347,7 +347,7 @@ def build_collection_query( # Collection filters allow a single value. 
("tag", "tags.name.keyword"), ("source", None), - ("creator", None), + ("creator", "creator.keyword"), ] for serializer_field, es_field in filters: if serializer_field in collection_params: diff --git a/api/test/unit/controllers/test_search_controller_search_query.py b/api/test/unit/controllers/test_search_controller_search_query.py index 4ebee63f435..c61f12cc4de 100644 --- a/api/test/unit/controllers/test_search_controller_search_query.py +++ b/api/test/unit/controllers/test_search_controller_search_query.py @@ -273,7 +273,7 @@ def test_create_search_query_empty_with_dynamically_excluded_providers( {"source": "flickr", "creator": "nasa"}, [ {"term": {"source": "flickr"}}, - {"term": {"creator": "nasa"}}, + {"term": {"creator.keyword": "nasa"}}, ], id="filter_by_creator", ), From 8a72ad30675d4aea23581cf4a85e44d0ab90b830 Mon Sep 17 00:00:00 2001 From: Olga Bulat Date: Tue, 14 Nov 2023 18:26:25 +0300 Subject: [PATCH 14/16] Fix related test Signed-off-by: Olga Bulat --- api/test/media_integration.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/api/test/media_integration.py b/api/test/media_integration.py index d78ac60edfc..fd15ea7c70f 100644 --- a/api/test/media_integration.py +++ b/api/test/media_integration.py @@ -5,6 +5,7 @@ """ import json +import re from test.constants import API_URL import requests @@ -199,7 +200,9 @@ def related(fixture): assert response["page_count"] == 1 def get_terms_set(res): - return set([t["name"] for t in res["tags"]] + res["title"].split(" ")) + # The title is analyzed in ES, we try to mimic it here. 
+ terms = [t["name"] for t in res["tags"]] + re.split(" |-", res["title"]) + return {t.lower() for t in terms} terms_set = get_terms_set(item) # Make sure each result has at least one word in common with the original item, @@ -208,7 +211,7 @@ def get_terms_set(res): assert ( len(terms_set.intersection(get_terms_set(result))) > 0 or result["creator"] == item["creator"] - ) + ), f"{terms_set} {get_terms_set(result)}/{result['creator']}-{item['creator']}" def sensitive_search_and_detail(media_type): From 2b30643b8674fc28a386a699d01086e76e93e76b Mon Sep 17 00:00:00 2001 From: Olga Bulat Date: Tue, 14 Nov 2023 18:51:18 +0300 Subject: [PATCH 15/16] Fix creator collection test Signed-off-by: Olga Bulat --- api/test/media_integration.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/api/test/media_integration.py b/api/test/media_integration.py index fd15ea7c70f..3c4a30f8e73 100644 --- a/api/test/media_integration.py +++ b/api/test/media_integration.py @@ -62,9 +62,9 @@ def creator_collection(media_path): assert response.status_code == 200 results = response.json()["results"] - assert all( - r["creator"] == "creator" and results["source"] == source for r in results - ) + for result in results: + assert result["source"] == source, f"{result['source']} != {source}" + assert result["creator"] == creator, f"{result['creator']} != {creator}" def search_all_excluded(media_path, excluded_source): From 3248289d4179470dc106dc8fd8535d889716e64c Mon Sep 17 00:00:00 2001 From: Olga Bulat Date: Thu, 16 Nov 2023 19:24:32 +0300 Subject: [PATCH 16/16] Update api/api/docs/base_docs.py Co-authored-by: sarayourfriend <24264157+sarayourfriend@users.noreply.github.com> Signed-off-by: Olga Bulat --- api/api/docs/base_docs.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/api/api/docs/base_docs.py b/api/api/docs/base_docs.py index 92a3f9637e9..51069662056 100644 --- a/api/api/docs/base_docs.py +++ b/api/api/docs/base_docs.py 
@@ -140,11 +140,27 @@ def get_collection_description(media_type, collection): return f""" Get a collection of {media_type} with a specific tag. -This endpoint returns only the exact matches, case-insensitive matches for the -specified tag. For example, 'birds' and 'birding' are not matches for 'bird'. -To search within the tag values, or to match several tags, use the `search` endpoint -with `tags` query parameter instead of `q` parameter. In this case, the matches will - not be exact, so 'cat' would match both 'cat' and 'cats'. +This endpoint matches a single tag, exactly and entirely. + +Differences that will cause tags to not match are: +- upper and lower case letters +- diacritical marks +- hyphenation +- spacing +- multi-word tags where the query is only one of the words in the tag +- multi-word tags where the words are in a different order + +Examples of tags that **do not** match: +- "Low-Quality" and "low-quality" +- "jalapeño" and "jalapeno" +- "Saint Pierre des Champs" and "Saint-Pierre-des-Champs" +- "dog walking" and "dog  walking" (where the latter has two spaces between the +last two words, as in a typographical error) +- "runner" and "marathon runner" +- "exclaiming loudly" and "loudly exclaiming" + +For non-exact or multi-tag matching, use the `search` endpoint's `tags` query +parameter. The returned results are ordered based on the time when they were added to Openverse. """