diff --git a/api/api/constants/media_types.py b/api/api/constants/media_types.py index 4a0b502ad3c..ae97bbd3e6d 100644 --- a/api/api/constants/media_types.py +++ b/api/api/constants/media_types.py @@ -6,7 +6,9 @@ IMAGE_TYPE = "image" MEDIA_TYPES = [AUDIO_TYPE, IMAGE_TYPE] +MediaType = Literal["audio", "image"] MEDIA_TYPE_CHOICES = [(AUDIO_TYPE, "Audio"), (IMAGE_TYPE, "Image")] -OriginIndex = Literal["image", "audio"] +OriginIndex = MediaType +SearchIndex = Literal["image", "image-filtered", "audio", "audio-filtered"] diff --git a/api/api/constants/search.py b/api/api/constants/search.py new file mode 100644 index 00000000000..4325c3a96f0 --- /dev/null +++ b/api/api/constants/search.py @@ -0,0 +1,5 @@ +from typing import Literal + + +SEARCH_STRATEGIES = ["search", "collection"] +SearchStrategy = Literal["search", "collection"] diff --git a/api/api/controllers/search_controller.py b/api/api/controllers/search_controller.py index 5da859f102e..0740587ea06 100644 --- a/api/api/controllers/search_controller.py +++ b/api/api/controllers/search_controller.py @@ -3,7 +3,7 @@ import logging import logging as log from math import ceil -from typing import Literal +from typing import TYPE_CHECKING from django.conf import settings from django.core.cache import cache @@ -15,7 +15,8 @@ from elasticsearch_dsl.response import Hit, Response import api.models as models -from api.constants.media_types import OriginIndex +from api.constants.media_types import OriginIndex, SearchIndex +from api.constants.search import SearchStrategy from api.constants.sorting import INDEXED_ON from api.controllers.elasticsearch.helpers import ( ELASTICSEARCH_MAX_RESULT_WINDOW, @@ -23,13 +24,26 @@ get_query_slice, get_raw_es_response, ) -from api.serializers import media_serializers from api.utils import tallies from api.utils.check_dead_links import check_dead_links from api.utils.dead_link_mask import get_query_hash from api.utils.search_context import SearchContext +# Using TYPE_CHECKING to avoid circular imports when importing types +if TYPE_CHECKING: + from api.serializers.audio_serializers import AudioCollectionRequestSerializer + from api.serializers.media_serializers import ( + MediaSearchRequestSerializer, + PaginatedRequestSerializer, + ) + + MediaListRequestSerializer = ( + AudioCollectionRequestSerializer + | MediaSearchRequestSerializer + | PaginatedRequestSerializer + ) + module_logger = logging.getLogger(__name__) @@ -172,24 +186,24 @@ def get_excluded_providers_query() -> Q | None: return None -def _resolve_index( - index: Literal["image", "audio"], - search_params: media_serializers.MediaSearchRequestSerializer, -) -> Literal["image", "image-filtered", "audio", "audio-filtered"]: - use_filtered_index = all( - ( - settings.ENABLE_FILTERED_INDEX_QUERIES, - not search_params.validated_data["include_sensitive_results"], - ) - ) - if use_filtered_index: - return f"{index}-filtered" +def get_index( + exact_index: bool, + origin_index: OriginIndex, + search_params: MediaListRequestSerializer, +) -> SearchIndex: + if exact_index: + return origin_index - return index + include_sensitive_results = search_params.validated_data.get( + "include_sensitive_results", False + ) + if settings.ENABLE_FILTERED_INDEX_QUERIES and not include_sensitive_results: + return f"{origin_index}-filtered" + return origin_index def create_search_filter_queries( - search_params: media_serializers.MediaSearchRequestSerializer, + search_params: MediaListRequestSerializer, ) -> dict[str, list[Q]]: """ Create a list of Elasticsearch queries for 
filtering search results. @@ -230,7 +244,7 @@ def create_search_filter_queries( def create_ranking_queries( - search_params: media_serializers.MediaSearchRequestSerializer, + search_params: MediaListRequestSerializer, ) -> list[Q]: queries = [Q("rank_feature", field="standardized_popularity", boost=DEFAULT_BOOST)] if search_params.data["unstable__authority"]: @@ -240,8 +254,8 @@ def create_ranking_queries( return queries -def create_search_query( - search_params: media_serializers.MediaSearchRequestSerializer, +def build_search_query( + search_params: MediaListRequestSerializer, ) -> Q: # Apply filters from the url query search parameters. url_queries = create_search_filter_queries(search_params) @@ -315,8 +329,60 @@ def create_search_query( ) -def search( - search_params: media_serializers.MediaSearchRequestSerializer, +def build_collection_query( + search_params: MediaListRequestSerializer, + collection_params: dict[str, str], +): + """ + Build the query to retrieve items in a collection. + :param collection_params: `tag`, `source` and/or `creator` values from the path. + :param search_params: the validated search parameters. + :return: the search client with the query applied. + """ + search_query = {"filter": [], "must": [], "should": [], "must_not": []} + # Apply the term filters. Each tuple pairs a filter's parameter name in the API + # with its corresponding field in Elasticsearch. "None" means that the + # names are identical. + filters = [ + # Collection filters allow a single value. + ("tag", "tags.name.keyword"), + ("source", None), + ("creator", "creator.keyword"), + ] + for serializer_field, es_field in filters: + if serializer_field in collection_params: + if not (argument := collection_params.get(serializer_field)): + continue + parameter = es_field or serializer_field + search_query["filter"].append({"term": {parameter: argument}}) + + # Exclude mature content and disabled sources + include_sensitive_by_params = search_params.validated_data.get( + "include_sensitive_results", False + ) + if not include_sensitive_by_params: + search_query["must_not"].append({"term": {"mature": True}}) + + if excluded_providers_query := get_excluded_providers_query(): + search_query["must_not"].append(excluded_providers_query) + + return Q("bool", **search_query) + + +def build_query( + strategy: SearchStrategy, + search_params: MediaListRequestSerializer, + collection_params: dict[str, str] | None, +) -> Q: + if strategy == "collection": + return build_collection_query(search_params, collection_params) + return build_search_query(search_params) + + +def query_media( + strategy: SearchStrategy, + search_params: MediaListRequestSerializer, + collection_params: dict[str, str] | None, origin_index: OriginIndex, exact_index: bool, page_size: int, @@ -325,10 +391,17 @@ def search( page: int = 1, ) -> tuple[list[Hit], int, int, dict]: """ - Perform a ranked paginated search from the set of keywords and, optionally, filters. - - :param search_params: Search parameters. See - :class: `ImageSearchQueryStringSerializer`. + If ``strategy`` is ``search``, perform a ranked paginated search + from the set of keywords and, optionally, filters. + If `strategy` is `collection`, perform a paginated search + for the `tag`, `source` or `source` and `creator` combination. + + :param collection_params: The path parameters for collection search, if + strategy is `collection`. + :param strategy: Whether to perform a default search or retrieve a collection. 
+ :param search_params: If `strategy` is `collection`, `PaginatedRequestSerializer` + or `AudioCollectionRequestSerializer`. If `strategy` is `search`, search + query params, see :class: `MediaRequestSerializer`. :param origin_index: The Elasticsearch index to search (e.g. 'image') :param exact_index: whether to skip all modifications to the index name :param page_size: The number of results to return per page. @@ -337,46 +410,51 @@ def search( Elasticsearch shards. :param filter_dead: Whether dead links should be removed. :param page: The results page number. - :return: Tuple with a List of Hits from elasticsearch, the total count of + :return: Tuple with a list of Hits from elasticsearch, the total count of pages, the number of results, and the ``SearchContext`` as a dict. """ - if not exact_index: - index = _resolve_index(origin_index, search_params) - else: - index = origin_index + index = get_index(exact_index, origin_index, search_params) + + query = build_query(strategy, search_params, collection_params) - s = Search(index=index) + s = Search(index=index).query(query) - search_query = create_search_query(search_params) - s = s.query(search_query) + if strategy == "search": + # Use highlighting to determine which fields contribute to the selection of + # top results. + s = s.highlight(*DEFAULT_SEARCH_FIELDS) + s = s.highlight_options(order="score") + s.extra(track_scores=True) - # Use highlighting to determine which fields contribute to the selection of - # top results. - s = s.highlight(*DEFAULT_SEARCH_FIELDS) - s = s.highlight_options(order="score") - s.extra(track_scores=True) # Route users to the same Elasticsearch worker node to reduce # pagination inconsistencies and increase cache hits. # TODO: Re-add 7s request_timeout when ES stability is restored s = s.params(preference=str(ip)) - # Sort by new - if search_params.validated_data["sort_by"] == INDEXED_ON: - s = s.sort({"created_on": {"order": search_params.validated_data["sort_dir"]}}) + # Sort by `created_on` if the parameter is set or if `strategy` is `collection`. + sort_by = search_params.validated_data.get("sort_by") + if strategy == "collection" or sort_by == INDEXED_ON: + sort_dir = search_params.validated_data.get("sort_dir", "desc") + s = s.sort({"created_on": {"order": sort_dir}}) - # Paginate - start, end = get_query_slice(s, page_size, page, filter_dead) - s = s[start:end] - search_response = get_es_response(s, es_query="search") - - results = _post_process_results( - s, start, end, page_size, search_response, filter_dead + # Execute paginated search and tally results + page_count, result_count, results = execute_search( + s, page, page_size, filter_dead, index, es_query=strategy ) - result_count, page_count = _get_result_and_page_count( - search_response, results, page_size, page - ) + result_ids = [result.identifier for result in results] + search_context = SearchContext.build(result_ids, origin_index) + + return results, page_count, result_count, search_context.asdict() + +def tally_results( + index: SearchIndex, results: list[Hit] | None, page: int, page_size: int +) -> None: + """ + Tally the number of the results from each provider in the results + for the search query. + """ results_to_tally = results or [] max_result_depth = page * page_size if max_result_depth <= 80: @@ -405,13 +483,33 @@ def search( # check things like provider density for a set of queries. 
tallies.count_provider_occurrences(results_to_tally, index) - if not results: - results = [] - result_ids = [result.identifier for result in results] - search_context = SearchContext.build(result_ids, origin_index) +def execute_search( + s: Search, + page: int, + page_size: int, + filter_dead: bool, + index: SearchIndex, + es_query: str, +) -> tuple[int, int, list[Hit]]: + """ + Execute search for the given query slice, post-processes the results, + and returns the results and result and page counts. + """ + start, end = get_query_slice(s, page_size, page, filter_dead) + s = s[start:end] - return results, page_count, result_count, search_context.asdict() + search_response = get_es_response(s, es_query=es_query) + + results: list[Hit] = ( + _post_process_results(s, start, end, page_size, search_response, filter_dead) + or [] + ) + result_count, page_count = _get_result_and_page_count( + search_response, results, page_size, page + ) + tally_results(index, results, page, page_size) + return page_count, result_count, results def get_sources(index): diff --git a/api/api/docs/audio_docs.py b/api/api/docs/audio_docs.py index 5c50d5481a6..0b11489e3bc 100644 --- a/api/api/docs/audio_docs.py +++ b/api/api/docs/audio_docs.py @@ -1,6 +1,6 @@ from drf_spectacular.utils import OpenApiResponse, extend_schema -from api.docs.base_docs import custom_extend_schema, fields_to_md +from api.docs.base_docs import collection_schema, custom_extend_schema, fields_to_md from api.examples import ( audio_complain_201_example, audio_complain_curl, @@ -39,7 +39,7 @@ By using this endpoint, you can obtain search results based on specified query and optionally filter results by - {fields_to_md(AudioSearchRequestSerializer.fields_names)}. + {fields_to_md(AudioSearchRequestSerializer.field_names)}. Results are ranked in order of relevance and paginated on the basis of the `page` param. The `page_size` param controls the total number of pages. @@ -116,3 +116,16 @@ }, eg=[audio_waveform_curl], ) + +source_collection = collection_schema( + media_type="audio", + collection="source", +) +creator_collection = collection_schema( + media_type="audio", + collection="creator", +) +tag_collection = collection_schema( + media_type="audio", + collection="tag", +) diff --git a/api/api/docs/base_docs.py b/api/api/docs/base_docs.py index dfbfb9fb3fd..51069662056 100644 --- a/api/api/docs/base_docs.py +++ b/api/api/docs/base_docs.py @@ -1,8 +1,23 @@ from http.client import responses as http_responses from textwrap import dedent - -from drf_spectacular.openapi import AutoSchema, OpenApiResponse -from drf_spectacular.utils import OpenApiExample, extend_schema +from typing import Literal + +from drf_spectacular.openapi import AutoSchema +from drf_spectacular.utils import ( + OpenApiExample, + OpenApiParameter, + OpenApiResponse, + extend_schema, +) + +from api.constants.media_types import MediaType +from api.serializers.audio_serializers import ( + AudioCollectionRequestSerializer, + AudioSerializer, +) +from api.serializers.error_serializers import NotFoundErrorSerializer +from api.serializers.image_serializers import ImageSerializer +from api.serializers.media_serializers import PaginatedRequestSerializer def fields_to_md(field_names): @@ -77,3 +92,138 @@ def get_operation_id(self) -> str: else: operation_tokens.append("detail") return "_".join(operation_tokens) + + +source_404_message = "Invalid source 'name'. Valid sources are ..." 
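The search_controller changes above separate query building from execution: `build_query` dispatches on the strategy, while `query_media` resolves the index via `get_index`, applies the query, and delegates pagination and tallying to `execute_search`. A minimal sketch of a collection-style call follows; it is illustrative only, not part of this patch, and the serializer setup and the source/creator values are assumptions.

    from api.controllers import search_controller
    from api.serializers.media_serializers import PaginatedRequestSerializer

    # Collection requests carry no `q`; only pagination parameters are validated.
    params = PaginatedRequestSerializer(data={"page": 1, "page_size": 20})
    params.is_valid(raise_exception=True)

    # strategy="collection" routes through build_collection_query: plain term
    # filters on the path values, with no ranking or highlighting applied.
    results, page_count, result_count, search_context = search_controller.query_media(
        strategy="collection",
        search_params=params,
        collection_params={"source": "flickr", "creator": "nasa"},
        origin_index="image",
        exact_index=False,
        page_size=20,
        ip=0,
        filter_dead=True,
        page=1,
    )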
+source_404_response = OpenApiResponse(
+    NotFoundErrorSerializer,
+    examples=[
+        OpenApiExample(
+            name="404",
+            value={"detail": source_404_message},
+        )
+    ],
+)
+
+
+def build_source_path_parameter(media_type: MediaType):
+    valid_description = (
+        f"Valid values are source_names from the stats endpoint: "
+        f"https://api.openverse.engineering/v1/{media_type}/stats/."
+    )
+
+    return OpenApiParameter(
+        name="source",
+        type=str,
+        location=OpenApiParameter.PATH,
+        description=f"The source of {media_type}. {valid_description}",
+    )
+
+
+creator_path_parameter = OpenApiParameter(
+    name="creator",
+    type=str,
+    location=OpenApiParameter.PATH,
+    description="The name of the media creator. This parameter "
+    "is case-sensitive, and matches exactly.",
+)
+tag_path_parameter = OpenApiParameter(
+    name="tag",
+    type=str,
+    location=OpenApiParameter.PATH,
+    description="The tag of the media. Not case-sensitive, matches exactly.",
+)
+
+
+def get_collection_description(media_type, collection):
+    if collection == "tag":
+        return f"""
+Get a collection of {media_type} with a specific tag.
+
+This endpoint matches a single tag, exactly and entirely.
+
+Differences that will cause tags to not match are:
+- upper and lower case letters
+- diacritical marks
+- hyphenation
+- spacing
+- multi-word tags where the query is only one of the words in the tag
+- multi-word tags where the words are in a different order
+
+Examples of tags that **do not** match:
+- "Low-Quality" and "low-quality"
+- "jalapeño" and "jalapeno"
+- "Saint Pierre des Champs" and "Saint-Pierre-des-Champs"
+- "dog walking" and "dog  walking" (where the latter has two spaces between the
+last two words, as in a typographical error)
+- "runner" and "marathon runner"
+- "exclaiming loudly" and "loudly exclaiming"
+
+For non-exact or multi-tag matching, use the `search` endpoint's `tags` query
+parameter.
+
+The returned results are ordered based on the time when they were added to Openverse.
+    """
+    elif collection == "source":
+        return f"""
+Get a collection of {media_type} from a specific source.
+
+This endpoint returns only the exact matches. To search within the source value,
+use the `search` endpoint with the `source` query parameter.
+
+The results in the collection will be sorted by the order in which they
+were added to Openverse.
+    """
+    elif collection == "creator":
+        return f"""
+Get a collection of {media_type} by a specific creator from the specified source.
+
+This endpoint returns only the exact matches both on the creator and the source.
+Notice that a single creator's media items can be found on several sources, but
+this endpoint only returns the items from the specified source. To search within
+the creator value, use the `search` endpoint with the `creator` query parameter
+instead of `q`.
+
+The items will be sorted by the date when they were added to Openverse.
+ """ + + +COLLECTION_TO_OPERATION_ID = { + ("images", "source"): "images_by_source", + ("images", "creator"): "images_by_source_and_creator", + ("images", "tag"): "images_by_tag", + ("audio", "source"): "audio_by_source", + ("audio", "creator"): "audio_by_source_and_creator", + ("audio", "tag"): "audio_by_tag", +} + + +def collection_schema( + media_type: Literal["images", "audio"], + collection: Literal["source", "creator", "tag"], +): + if media_type == "images": + request_serializer = PaginatedRequestSerializer + serializer = ImageSerializer + else: + request_serializer = AudioCollectionRequestSerializer + serializer = AudioSerializer + + if collection == "tag": + responses = {200: serializer(many=True)} + path_parameters = [tag_path_parameter] + else: + responses = {200: serializer(many=True), 404: source_404_response} + path_parameters = [build_source_path_parameter(media_type)] + if collection == "creator": + path_parameters.append(creator_path_parameter) + operation_id = COLLECTION_TO_OPERATION_ID[(media_type, collection)] + description = get_collection_description(media_type, collection) + return extend_schema( + operation_id=operation_id, + summary=operation_id, + auth=[], + description=description, + responses=responses, + parameters=[request_serializer, *path_parameters], + ) diff --git a/api/api/docs/image_docs.py b/api/api/docs/image_docs.py index 58b89d883cd..ae603224f7f 100644 --- a/api/api/docs/image_docs.py +++ b/api/api/docs/image_docs.py @@ -1,6 +1,6 @@ from drf_spectacular.utils import OpenApiResponse, extend_schema -from api.docs.base_docs import custom_extend_schema, fields_to_md +from api.docs.base_docs import collection_schema, custom_extend_schema, fields_to_md from api.examples import ( image_complain_201_example, image_complain_curl, @@ -40,7 +40,7 @@ By using this endpoint, you can obtain search results based on specified query and optionally filter results by - {fields_to_md(ImageSearchRequestSerializer.fields_names)}. + {fields_to_md(ImageSearchRequestSerializer.field_names)}. Results are ranked in order of relevance and paginated on the basis of the `page` param. The `page_size` param controls the total number of pages. 
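The descriptions above map onto three new routes per media type. A usage sketch for the image endpoints follows; the base URL mirrors the stats URL quoted in `build_source_path_parameter`, and the tag, source, and creator values are assumptions (the integration tests later in this diff exercise the same paths against the local test API).

    import requests

    API_URL = "https://api.openverse.engineering"  # assumed host

    # Exact, whole-tag match; results are ordered by time added to Openverse.
    by_tag = requests.get(f"{API_URL}/v1/images/tag/cat")
    # All images from a single source.
    by_source = requests.get(f"{API_URL}/v1/images/source/flickr")
    # One creator's images within that source; creator matching is case-sensitive.
    by_creator = requests.get(f"{API_URL}/v1/images/source/flickr/creator/nasa")

    for response in (by_tag, by_source, by_creator):
        assert response.status_code == 200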
@@ -122,3 +122,16 @@ watermark = custom_extend_schema( deprecated=True, ) + +source_collection = collection_schema( + media_type="images", + collection="source", +) +creator_collection = collection_schema( + media_type="images", + collection="creator", +) +tag_collection = collection_schema( + media_type="images", + collection="tag", +) diff --git a/api/api/serializers/audio_serializers.py b/api/api/serializers/audio_serializers.py index bb6718c5774..0da59296267 100644 --- a/api/api/serializers/audio_serializers.py +++ b/api/api/serializers/audio_serializers.py @@ -11,6 +11,7 @@ MediaReportRequestSerializer, MediaSearchRequestSerializer, MediaSerializer, + PaginatedRequestSerializer, get_hyperlinks_serializer, get_search_request_source_serializer, ) @@ -24,14 +25,31 @@ AudioSearchRequestSourceSerializer = get_search_request_source_serializer("audio") +class AudioCollectionRequestSerializer(PaginatedRequestSerializer): + field_names = [ + *PaginatedRequestSerializer.field_names, + "peaks", + ] + + peaks = serializers.BooleanField( + help_text="Whether to include the waveform peaks or not", + required=False, + default=False, + ) + + @property + def needs_db(self) -> bool: + return super().needs_db or self.data["peaks"] + + class AudioSearchRequestSerializer( AudioSearchRequestSourceSerializer, MediaSearchRequestSerializer, ): """Parse and validate search query string parameters.""" - fields_names = [ - *MediaSearchRequestSerializer.fields_names, + field_names = [ + *MediaSearchRequestSerializer.field_names, *AudioSearchRequestSourceSerializer.field_names, "category", "length", @@ -62,8 +80,7 @@ def needs_db(self) -> bool: return super().needs_db or self.data["peaks"] def validate_internal__index(self, value): - index = super().validate_internal__index(value) - if index is None: + if not (index := super().validate_internal__index(value)): return None if not index.startswith(AUDIO_TYPE): raise serializers.ValidationError(f"Invalid index name `{value}`.") diff --git a/api/api/serializers/image_serializers.py b/api/api/serializers/image_serializers.py index 5824198936c..d483a5d5e2e 100644 --- a/api/api/serializers/image_serializers.py +++ b/api/api/serializers/image_serializers.py @@ -33,8 +33,8 @@ class ImageSearchRequestSerializer( ): """Parse and validate search query string parameters.""" - fields_names = [ - *MediaSearchRequestSerializer.fields_names, + field_names = [ + *MediaSearchRequestSerializer.field_names, *ImageSearchRequestSourceSerializer.field_names, "category", "aspect_ratio", diff --git a/api/api/serializers/media_serializers.py b/api/api/serializers/media_serializers.py index a1e2b8de8a7..d8a52d4c1f8 100644 --- a/api/api/serializers/media_serializers.py +++ b/api/api/serializers/media_serializers.py @@ -26,6 +26,66 @@ ####################### +class PaginatedRequestSerializer(serializers.Serializer): + """This serializer passes pagination parameters from the query string.""" + + field_names = [ + "page_size", + "page", + ] + page_size = serializers.IntegerField( + label="page_size", + help_text=f"Number of results to return per page. 
" + f"Maximum is {settings.MAX_AUTHED_PAGE_SIZE} for authenticated " + f"requests, and {settings.MAX_ANONYMOUS_PAGE_SIZE} for " + f"unauthenticated requests.", + required=False, + default=settings.MAX_ANONYMOUS_PAGE_SIZE, + min_value=1, + ) + page = serializers.IntegerField( + label="page", + help_text="The page of results to retrieve.", + required=False, + default=1, + max_value=settings.MAX_PAGINATION_DEPTH, + min_value=1, + ) + + def validate_page_size(self, value): + request = self.context.get("request") + is_anonymous = bool(request and request.user and request.user.is_anonymous) + max_value = ( + settings.MAX_ANONYMOUS_PAGE_SIZE + if is_anonymous + else settings.MAX_AUTHED_PAGE_SIZE + ) + + validator = MaxValueValidator( + max_value, + message=serializers.IntegerField.default_error_messages["max_value"].format( + max_value=max_value + ), + ) + + if is_anonymous: + try: + validator(value) + except ValidationError as e: + raise NotAuthenticated( + detail=e.message, + code=e.code, + ) + else: + validator(value) + + return value + + @property + def needs_db(self) -> bool: + return False + + @extend_schema_serializer( # Hide unstable and internal fields from documentation. # Also see `field_names` below. @@ -38,7 +98,7 @@ "internal__index", ], ) -class MediaSearchRequestSerializer(serializers.Serializer): +class MediaSearchRequestSerializer(PaginatedRequestSerializer): """This serializer parses and validates search query string parameters.""" DeprecatedParam = namedtuple("DeprecatedParam", ["original", "successor"]) @@ -48,7 +108,7 @@ class MediaSearchRequestSerializer(serializers.Serializer): DeprecatedParam("pagesize", "page_size"), DeprecatedParam("provider", "source"), ] - fields_names = [ + field_names = [ "q", "license", "license_type", @@ -64,8 +124,7 @@ class MediaSearchRequestSerializer(serializers.Serializer): # "unstable__authority", # "unstable__authority_boost", # "unstable__include_sensitive_results", - "page_size", - "page", + *PaginatedRequestSerializer.field_names, ] """ Keep the fields names in sync with the actual fields below as this list is @@ -91,19 +150,30 @@ class MediaSearchRequestSerializer(serializers.Serializer): ) creator = serializers.CharField( label="creator", - help_text="Search by creator only. Cannot be used with `q`.", + help_text="Search by creator only. Cannot be used with `q`. The search " + "is fuzzy, so `creator=john` will match any value that includes the " + "word `john`. If the value contains space, items that contain any of " + "the words in the value will match. To search for several values, " + "join them with a comma.", required=False, max_length=200, ) tags = serializers.CharField( label="tags", - help_text="Search by tag only. Cannot be used with `q`.", + help_text="Search by tag only. Cannot be used with `q`. The search " + "is fuzzy, so `tags=cat` will match any value that includes the word " + "`cat`. If the value contains space, items that contain any of the " + "words in the value will match. To search for several values, join " + "them with a comma.", required=False, max_length=200, ) title = serializers.CharField( label="title", - help_text="Search by title only. Cannot be used with `q`.", + help_text="Search by title only. Cannot be used with `q`. The search is fuzzy," + " so `title=photo` will match any value that includes the word `photo`. " + "If the value contains space, items that contain any of the words in the " + "value will match. 
To search for several values, join them with a comma.",
        required=False,
        max_length=200,
    )
@@ -179,22 +249,6 @@ class MediaSearchRequestSerializer(serializers.Serializer):
         required=False,
     )
 
-    page_size = serializers.IntegerField(
-        label="page_size",
-        help_text="Number of results to return per page.",
-        required=False,
-        default=settings.MAX_ANONYMOUS_PAGE_SIZE,
-        min_value=1,
-    )
-    page = serializers.IntegerField(
-        label="page",
-        help_text="The page of results to retrieve.",
-        required=False,
-        default=1,
-        max_value=settings.MAX_PAGINATION_DEPTH,
-        min_value=1,
-    )
-
     def is_request_anonymous(self):
         request = self.context.get("request")
         return bool(request and request.user and request.user.is_anonymous)
@@ -283,34 +337,6 @@ def validate_internal__index(self, value):
             raise serializers.ValidationError(f"Invalid index name `{value}`.")
         return value
 
-    def validate_page_size(self, value):
-        is_anonymous = self.is_request_anonymous()
-        max_value = (
-            settings.MAX_ANONYMOUS_PAGE_SIZE
-            if is_anonymous
-            else settings.MAX_AUTHED_PAGE_SIZE
-        )
-
-        validator = MaxValueValidator(
-            max_value,
-            message=serializers.IntegerField.default_error_messages["max_value"].format(
-                max_value=max_value
-            ),
-        )
-
-        if is_anonymous:
-            try:
-                validator(value)
-            except ValidationError as e:
-                raise NotAuthenticated(
-                    detail=e.message,
-                    code=e.code,
-                )
-        else:
-            validator(value)
-
-        return value
-
     @staticmethod
     def validate_extension(value):
         return value.lower()
@@ -329,10 +355,6 @@ def validate(self, data):
 
         return data
 
-    @property
-    def needs_db(self) -> bool:
-        return False
-
 
 class MediaThumbnailRequestSerializer(serializers.Serializer):
     """This serializer parses and validates thumbnail query string parameters."""
diff --git a/api/api/views/audio_views.py b/api/api/views/audio_views.py
index c26fe8055b8..6457b233a09 100644
--- a/api/api/views/audio_views.py
+++ b/api/api/views/audio_views.py
@@ -7,16 +7,20 @@
 from api.constants.media_types import AUDIO_TYPE
 from api.docs.audio_docs import (
+    creator_collection,
     detail,
     related,
     report,
     search,
+    source_collection,
     stats,
+    tag_collection,
     thumbnail,
     waveform,
 )
 from api.models import Audio
 from api.serializers.audio_serializers import (
+    AudioCollectionRequestSerializer,
     AudioReportRequestSerializer,
     AudioSearchRequestSerializer,
     AudioSerializer,
@@ -38,15 +42,43 @@ class AudioViewSet(MediaViewSet):
     """Viewset for all endpoints pertaining to audio."""
 
     model_class = Audio
+    media_type = AUDIO_TYPE
     query_serializer_class = AudioSearchRequestSerializer
     default_index = settings.MEDIA_INDEX_MAPPING[AUDIO_TYPE]
     serializer_class = AudioSerializer
+    collection_serializer_class = AudioCollectionRequestSerializer
 
     def get_queryset(self):
         return super().get_queryset().select_related("mature_audio", "audioset")
 
     # Extra actions
+    @creator_collection
+    @action(
+        detail=False,
+        methods=["get"],
+        url_path="source/(?P<source>[^/.]+)/creator/(?P<creator>.+)",
+    )
+    def creator_collection(self, request, source, creator):
+        return super().creator_collection(request, source, creator)
+
+    @source_collection
+    @action(
+        detail=False,
+        methods=["get"],
+        url_path="source/(?P<source>[^/.]+)",
+    )
+    def source_collection(self, request, source):
+        return super().source_collection(request, source)
+
+    @tag_collection
+    @action(
+        detail=False,
+        methods=["get"],
+        url_path="tag/(?P<tag>[^/.]+)",
+    )
+    def tag_collection(self, request, tag, *_, **__):
+        return super().tag_collection(request, tag, *_, **__)
 
     @thumbnail
     @action(
diff --git a/api/api/views/image_views.py b/api/api/views/image_views.py
index 6f138908db8..641a230f384 100644
--- a/api/api/views/image_views.py
+++ b/api/api/views/image_views.py
@@ -13,12 +13,15 @@
 from api.constants.media_types import IMAGE_TYPE
 from api.docs.image_docs import (
+    creator_collection,
     detail,
     oembed,
     related,
     report,
     search,
+    source_collection,
     stats,
+    tag_collection,
     thumbnail,
 )
 from api.docs.image_docs import watermark as watermark_doc
@@ -31,7 +34,10 @@
     OembedSerializer,
     WatermarkRequestSerializer,
 )
-from api.serializers.media_serializers import MediaThumbnailRequestSerializer
+from api.serializers.media_serializers import (
+    MediaThumbnailRequestSerializer,
+    PaginatedRequestSerializer,
+)
 from api.utils.throttle import AnonThumbnailRateThrottle, OAuth2IdThumbnailRateThrottle
 from api.utils.watermark import watermark
 from api.views.media_views import MediaViewSet
@@ -48,10 +54,12 @@ class ImageViewSet(MediaViewSet):
     """Viewset for all endpoints pertaining to images."""
 
     model_class = Image
+    media_type = IMAGE_TYPE
     query_serializer_class = ImageSearchRequestSerializer
     default_index = settings.MEDIA_INDEX_MAPPING[IMAGE_TYPE]
     serializer_class = ImageSerializer
+    collection_serializer_class = PaginatedRequestSerializer
 
     OEMBED_HEADERS = {
         "User-Agent": settings.OUTBOUND_USER_AGENT_TEMPLATE.format(purpose="OEmbed"),
@@ -61,6 +69,32 @@ def get_queryset(self):
         return super().get_queryset().select_related("mature_image")
 
     # Extra actions
+    @creator_collection
+    @action(
+        detail=False,
+        methods=["get"],
+        url_path="source/(?P<source>[^/.]+)/creator/(?P<creator>.+)",
+    )
+    def creator_collection(self, request, source, creator):
+        return super().creator_collection(request, source, creator)
+
+    @source_collection
+    @action(
+        detail=False,
+        methods=["get"],
+        url_path="source/(?P<source>[^/.]+)",
+    )
+    def source_collection(self, request, source, *_, **__):
+        return super().source_collection(request, source)
+
+    @tag_collection
+    @action(
+        detail=False,
+        methods=["get"],
+        url_path="tag/(?P<tag>[^/.]+)",
+    )
+    def tag_collection(self, request, tag, *_, **__):
+        return super().tag_collection(request, tag, *_, **__)
 
     @oembed
     @action(
diff --git a/api/api/views/media_views.py b/api/api/views/media_views.py
index a0f28f68d55..ecf6005f410 100644
--- a/api/api/views/media_views.py
+++ b/api/api/views/media_views.py
@@ -1,4 +1,5 @@
 import logging
+from typing import Union
 
 from rest_framework import status
 from rest_framework.decorators import action
@@ -6,10 +7,13 @@
 from rest_framework.response import Response
 from rest_framework.viewsets import ReadOnlyModelViewSet
 
+from api.constants.media_types import MediaType
+from api.constants.search import SearchStrategy
 from api.controllers import search_controller
 from api.controllers.elasticsearch.related import related_media
 from api.models import ContentProvider
 from api.models.media import AbstractMedia
+from api.serializers import audio_serializers, media_serializers
 from api.serializers.provider_serializers import ProviderSerializer
 from api.utils import image_proxy
 from api.utils.pagination import StandardPagination
@@ -18,6 +22,18 @@
 logger = logging.getLogger(__name__)
 
 
+MediaListRequestSerializer = Union[
+    audio_serializers.AudioCollectionRequestSerializer,
+    media_serializers.PaginatedRequestSerializer,
+    media_serializers.MediaSearchRequestSerializer,
+]
+
+
+class InvalidSource(APIException):
+    status_code = 400
+    default_detail = "Invalid source."
+    default_code = "invalid_source"
+
 
 class MediaViewSet(ReadOnlyModelViewSet):
     lookup_field = "identifier"
@@ -30,13 +46,16 @@ class MediaViewSet(ReadOnlyModelViewSet):
 
     # Populate these in the corresponding subclass
     model_class: type[AbstractMedia] = None
+    media_type: MediaType | None = None
     query_serializer_class = None
+    collection_serializer_class = None
     default_index = None
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         required_fields = [
             self.model_class,
+            self.media_type,
             self.query_serializer_class,
             self.default_index,
         ]
@@ -101,12 +120,64 @@ def retrieve(self, request, *_, **__):
 
     def list(self, request, *_, **__):
         params = self._get_request_serializer(request)
+        return self.get_media_results(request, "search", params)
+
+    def _validate_source(self, source):
+        valid_sources = search_controller.get_sources(self.media_type)
+        if source not in valid_sources:
+            valid_string = ", ".join([f"'{k}'" for k in valid_sources.keys()])
+            raise InvalidSource(
+                detail=f"Invalid source '{source}'. Valid sources are: {valid_string}.",
+            )
+
+    def collection(self, request, tag, source, creator, *_, **__):
+        if tag:
+            collection_params = {"tag": tag}
+        elif creator:
+            collection_params = {"creator": creator, "source": source}
+        else:
+            collection_params = {"source": source}
+
+        if source:
+            self._validate_source(source)
+
+        params = self.collection_serializer_class(
+            data=request.query_params, context={"request": request}
+        )
+        params.is_valid(raise_exception=True)
+
+        return self.get_media_results(request, "collection", params, collection_params)
+
+    @action(detail=False, methods=["get"], url_path="tag/(?P<tag>[^/.]+)")
+    def tag_collection(self, request, tag, *_, **__):
+        tag_lower = tag.lower()
+        return self.collection(request, tag_lower, None, None)
+
+    @action(detail=False, methods=["get"], url_path="source/(?P<source>[^/.]+)")
+    def source_collection(self, request, source, *_, **__):
+        return self.collection(request, None, source, None)
+
+    @action(
+        detail=False,
+        methods=["get"],
+        url_path="source/(?P<source>[^/.]+)/creator/(?P<creator>.+)",
+    )
+    def creator_collection(self, request, source, creator):
+        return self.collection(request, None, source, creator)
+
+    # Common functionality for search and collection views
+
+    def get_media_results(
+        self,
+        request,
+        strategy: SearchStrategy,
+        params: MediaListRequestSerializer,
+        collection_params: dict[str, str] | None = None,
+    ):
         page_size = self.paginator.page_size = params.data["page_size"]
         page = self.paginator.page = params.data["page"]
 
         hashed_ip = hash(self._get_user_ip(request))
-        filter_dead = params.validated_data["filter_dead"]
+        filter_dead = params.validated_data.get("filter_dead", True)
 
         if pref_index := params.validated_data.get("index"):
             logger.info(f"Using preferred index {pref_index} for media.")
@@ -118,8 +189,15 @@ def list(self, request, *_, **__):
             exact_index = False
 
         try:
-            results, num_pages, num_results, search_context = search_controller.search(
+            (
+                results,
+                num_pages,
+                num_results,
+                search_context,
+            ) = search_controller.query_media(
+                strategy,
                 params,
+                collection_params,
                 search_index,
                 exact_index,
                 page_size,
diff --git a/api/test/media_integration.py b/api/test/media_integration.py
index c1704f12047..3c4a30f8e73 100644
--- a/api/test/media_integration.py
+++ b/api/test/media_integration.py
@@ -5,6 +5,7 @@
 """
 import json
+import re
 from test.constants import API_URL
 
 import requests
@@ -26,6 +27,46 @@ def search_by_category(media_path, category, fixture):
     assert all(audio_item["category"] == category for
audio_item in results) +def tag_collection(media_path): + response = requests.get(f"{API_URL}/v1/{media_path}/tag/cat") + assert response.status_code == 200 + + results = response.json()["results"] + for r in results: + tag_names = [tag["name"] for tag in r["tags"]] + assert "cat" in tag_names + + +def source_collection(media_path): + source = requests.get(f"{API_URL}/v1/{media_path}/stats").json()[0]["source_name"] + + response = requests.get(f"{API_URL}/v1/{media_path}/source/{source}") + assert response.status_code == 200 + + results = response.json()["results"] + assert all(result["source"] == source for result in results) + + +def creator_collection(media_path): + source = requests.get(f"{API_URL}/v1/{media_path}/stats").json()[0]["source_name"] + + first_res = requests.get(f"{API_URL}/v1/{media_path}/source/{source}").json()[ + "results" + ][0] + if not (creator := first_res.get("creator")): + raise AttributeError(f"No creator in {first_res}") + + response = requests.get( + f"{API_URL}/v1/{media_path}/source/{source}/creator/{creator}" + ) + assert response.status_code == 200 + + results = response.json()["results"] + for result in results: + assert result["source"] == source, f"{result['source']} != {source}" + assert result["creator"] == creator, f"{result['creator']} != {creator}" + + def search_all_excluded(media_path, excluded_source): response = requests.get( f"{API_URL}/v1/{media_path}?q=test&excluded_source={','.join(excluded_source)}" @@ -159,7 +200,9 @@ def related(fixture): assert response["page_count"] == 1 def get_terms_set(res): - return set([t["name"] for t in res["tags"]] + res["title"].split(" ")) + # The title is analyzed in ES, we try to mimic it here. + terms = [t["name"] for t in res["tags"]] + re.split(" |-", res["title"]) + return {t.lower() for t in terms} terms_set = get_terms_set(item) # Make sure each result has at least one word in common with the original item, @@ -168,7 +211,7 @@ def get_terms_set(res): assert ( len(terms_set.intersection(get_terms_set(result))) > 0 or result["creator"] == item["creator"] - ) + ), f"{terms_set} {get_terms_set(result)}/{result['creator']}-{item['creator']}" def sensitive_search_and_detail(media_type): diff --git a/api/test/test_audio_integration.py b/api/test/test_audio_integration.py index 9c4aa0210da..c9e355739af 100644 --- a/api/test/test_audio_integration.py +++ b/api/test/test_audio_integration.py @@ -8,6 +8,7 @@ import json from test.constants import API_URL from test.media_integration import ( + creator_collection, detail, license_filter_case_insensitivity, related, @@ -21,7 +22,9 @@ search_source_and_excluded, search_special_chars, sensitive_search_and_detail, + source_collection, stats, + tag_collection, uuid_validation, ) @@ -160,5 +163,17 @@ def test_audio_related(audio_fixture): related(audio_fixture) +def test_audio_tag_collection(): + tag_collection("audio") + + +def test_audio_source_collection(): + source_collection("audio") + + +def test_audio_creator_collection(): + creator_collection("audio") + + def test_audio_sensitive_search_and_detail(): sensitive_search_and_detail("audio") diff --git a/api/test/test_image_integration.py b/api/test/test_image_integration.py index 36fe7a33f67..f87d3fc16c9 100644 --- a/api/test/test_image_integration.py +++ b/api/test/test_image_integration.py @@ -8,6 +8,7 @@ import json from test.constants import API_URL from test.media_integration import ( + creator_collection, detail, license_filter_case_insensitivity, related, @@ -20,7 +21,9 @@ search_source_and_excluded, 
search_special_chars, sensitive_search_and_detail, + source_collection, stats, + tag_collection, uuid_validation, ) from urllib.parse import urlencode @@ -145,6 +148,18 @@ def test_image_uuid_validation(): uuid_validation("images", "abcd") +def test_image_tag_collection(): + tag_collection("images") + + +def test_image_source_collection(): + source_collection("images") + + +def test_image_creator_collection(): + creator_collection("images") + + def test_image_related(image_fixture): related(image_fixture) diff --git a/api/test/unit/controllers/test_search_controller.py b/api/test/unit/controllers/test_search_controller.py index b344a1074d6..764fffca21f 100644 --- a/api/test/unit/controllers/test_search_controller.py +++ b/api/test/unit/controllers/test_search_controller.py @@ -456,8 +456,10 @@ def test_search_tallies_pages_less_than_5( ) serializer.is_valid() - search_controller.search( + search_controller.query_media( + strategy="search", search_params=serializer, + collection_params=None, ip=0, origin_index=media_type_config.origin_index, exact_index=False, @@ -495,8 +497,10 @@ def test_search_tallies_handles_empty_page( serializer = media_type_config.search_request_serializer(data={"q": "dogs"}) serializer.is_valid() - search_controller.search( + search_controller.query_media( + strategy="search", search_params=serializer, + collection_params=None, ip=0, origin_index=media_type_config.origin_index, exact_index=False, @@ -538,8 +542,10 @@ def test_resolves_index( ) serializer.is_valid() - search_controller.search( + search_controller.query_media( + strategy="search", search_params=serializer, + collection_params=None, ip=0, origin_index=origin_index, exact_index=False, @@ -605,8 +611,10 @@ def test_no_post_process_results_recursion( data={"q": "bird perched"} ) serializer.is_valid() - results, _, _, _ = search_controller.search( + results, _, _, _ = search_controller.query_media( + strategy="search", search_params=serializer, + collection_params=None, ip=0, origin_index=image_media_type_config.origin_index, exact_index=True, @@ -743,8 +751,10 @@ def test_post_process_results_recurses_as_needed( data={"q": "bird perched"} ) serializer.is_valid() - results, _, _, _ = search_controller.search( + results, _, _, _ = search_controller.query_media( + strategy="search", search_params=serializer, + collection_params=None, ip=0, origin_index=image_media_type_config.origin_index, exact_index=True, @@ -785,8 +795,10 @@ def _delete_all_results_but_first(_, __, results, ___): serializer.is_valid() with caplog.at_level(logging.INFO): - results, _, _, _ = search_controller.search( + results, _, _, _ = search_controller.query_media( + strategy="search", search_params=serializer, + collection_params=None, ip=0, origin_index=image_media_type_config.origin_index, exact_index=True, diff --git a/api/test/unit/controllers/test_search_controller_search_query.py b/api/test/unit/controllers/test_search_controller_search_query.py index 005c1bbbe10..c61f12cc4de 100644 --- a/api/test/unit/controllers/test_search_controller_search_query.py +++ b/api/test/unit/controllers/test_search_controller_search_query.py @@ -1,6 +1,7 @@ from django.core.cache import cache import pytest +from elasticsearch_dsl import Q from api.controllers import search_controller @@ -23,7 +24,7 @@ def excluded_providers_cache(): def test_create_search_query_empty(media_type_config): serializer = media_type_config.search_request_serializer(data={}) serializer.is_valid(raise_exception=True) - search_query = 
search_controller.create_search_query(serializer) + search_query = search_controller.build_search_query(serializer) actual_query_clauses = search_query.to_dict()["bool"] assert actual_query_clauses == { @@ -39,7 +40,7 @@ def test_create_search_query_empty_no_ranking(media_type_config, settings): settings.USE_RANK_FEATURES = False serializer = media_type_config.search_request_serializer(data={}) serializer.is_valid(raise_exception=True) - search_query = search_controller.create_search_query(serializer) + search_query = search_controller.build_search_query(serializer) actual_query_clauses = search_query.to_dict()["bool"] assert actual_query_clauses == { @@ -51,7 +52,7 @@ def test_create_search_query_empty_no_ranking(media_type_config, settings): def test_create_search_query_q_search_no_filters(media_type_config): serializer = media_type_config.search_request_serializer(data={"q": "cat"}) serializer.is_valid(raise_exception=True) - search_query = search_controller.create_search_query(serializer) + search_query = search_controller.build_search_query(serializer) actual_query_clauses = search_query.to_dict()["bool"] assert actual_query_clauses == { @@ -83,7 +84,7 @@ def test_create_search_query_q_search_with_quotes_adds_exact_suffix(media_type_c data={"q": '"The cutest cat"'} ) serializer.is_valid(raise_exception=True) - search_query = search_controller.create_search_query(serializer) + search_query = search_controller.build_search_query(serializer) actual_query_clauses = search_query.to_dict()["bool"] assert actual_query_clauses == { @@ -127,7 +128,7 @@ def test_create_search_query_q_search_with_filters(image_media_type_config): } ) serializer.is_valid(raise_exception=True) - search_query = search_controller.create_search_query(serializer) + search_query = search_controller.build_search_query(serializer) actual_query_clauses = search_query.to_dict()["bool"] assert actual_query_clauses == { @@ -169,7 +170,7 @@ def test_create_search_query_non_q_query(image_media_type_config): } ) serializer.is_valid(raise_exception=True) - search_query = search_controller.create_search_query(serializer) + search_query = search_controller.build_search_query(serializer) actual_query_clauses = search_query.to_dict()["bool"] assert actual_query_clauses == { @@ -200,7 +201,7 @@ def test_create_search_query_q_search_license_license_type_creates_2_terms_filte } ) serializer.is_valid(raise_exception=True) - search_query = search_controller.create_search_query(serializer) + search_query = search_controller.build_search_query(serializer) actual_query_clauses = search_query.to_dict()["bool"] first_license_terms_filter = actual_query_clauses["filter"][0] @@ -235,7 +236,7 @@ def test_create_search_query_empty_with_dynamically_excluded_providers( serializer = image_media_type_config.search_request_serializer(data={}) serializer.is_valid(raise_exception=True) - search_query = search_controller.create_search_query(serializer) + search_query = search_controller.build_search_query(serializer) actual_query_clauses = search_query.to_dict()["bool"] assert actual_query_clauses == { @@ -248,3 +249,44 @@ def test_create_search_query_empty_with_dynamically_excluded_providers( {"rank_feature": {"boost": 10000, "field": "standardized_popularity"}} ], } + + +@pytest.mark.parametrize( + ("data", "expected_query_filter"), + [ + pytest.param( + {"tag": "art"}, + [{"term": {"tags.name.keyword": "art"}}], + id="filter_by_tag", + ), + pytest.param( + {"tag": "art, photography"}, + [{"term": {"tags.name.keyword": "art, photography"}}], + 
id="filter_by_tag_treats_punctuation_as_part_of_tag", + ), + pytest.param( + {"source": "flickr"}, + [{"term": {"source": "flickr"}}], + id="filter_by_source", + ), + pytest.param( + {"source": "flickr", "creator": "nasa"}, + [ + {"term": {"source": "flickr"}}, + {"term": {"creator.keyword": "nasa"}}, + ], + id="filter_by_creator", + ), + ], +) +def test_build_collection_query(image_media_type_config, data, expected_query_filter): + serializer = image_media_type_config.search_request_serializer(data={}) + serializer.is_valid(raise_exception=True) + actual_query = search_controller.build_collection_query(serializer, data) + expected_query = Q( + "bool", + filter=expected_query_filter, + must_not=[{"term": {"mature": True}}], + ) + + assert actual_query == expected_query diff --git a/api/test/unit/models/test_media_report.py b/api/test/unit/models/test_media_report.py index 0433647015a..e936c47d356 100644 --- a/api/test/unit/models/test_media_report.py +++ b/api/test/unit/models/test_media_report.py @@ -1,5 +1,4 @@ import uuid -from typing import Literal, Union from django.core.exceptions import ObjectDoesNotExist @@ -22,8 +21,6 @@ pytestmark = pytest.mark.django_db -MediaType = Union[Literal["audio"], Literal["image"]] - reason_params = pytest.mark.parametrize("reason", [DMCA, MATURE, OTHER]) diff --git a/api/test/unit/serializers/test_media_serializers.py b/api/test/unit/serializers/test_media_serializers.py index 46edaedea8f..d48798c40cb 100644 --- a/api/test/unit/serializers/test_media_serializers.py +++ b/api/test/unit/serializers/test_media_serializers.py @@ -195,17 +195,17 @@ def test_index_is_only_set_if_authenticated( mock_es.indices.exists.return_value = True request = authed_request if authenticated else anon_request - serializer = MediaSearchRequestSerializer( - data={"internal__index": "some-index"}, context={"request": request} + serializer = ImageSearchRequestSerializer( + data={"internal__index": "image-some-index"}, context={"request": request} ) assert serializer.is_valid() assert serializer.validated_data.get("index") == ( - "some-index" if authenticated else None + "image-some-index" if authenticated else None ) if authenticated: # If authenticated, we should have checked that the index exists. - mock_es.indices.exists.assert_called_with("some-index") + mock_es.indices.exists.assert_called_with("image-some-index") else: # If not authenticated, the validator quickly returns ``None``. 
mock_es.indices.exists.assert_not_called() @@ -215,12 +215,12 @@ def test_index_is_only_set_if_authenticated( @patch("django.conf.settings.ES") @pytest.mark.parametrize( "index, is_valid", - (("index-that-exists", True), ("index-that-does-not-exist", False)), + (("image-index-that-exists", True), ("image-index-that-does-not-exist", False)), ) def test_index_is_only_set_if_valid(mock_es, index, is_valid, authed_request): mock_es.indices.exists = lambda index: "exists" in index - serializer = MediaSearchRequestSerializer( + serializer = ImageSearchRequestSerializer( data={"internal__index": index}, context={"request": authed_request} ) assert serializer.is_valid() == is_valid diff --git a/api/test/unit/views/test_media_views.py b/api/test/unit/views/test_media_views.py index b94e309d168..e95c320ce97 100644 --- a/api/test/unit/views/test_media_views.py +++ b/api/test/unit/views/test_media_views.py @@ -2,6 +2,8 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 +from rest_framework.response import Response + import pytest import pytest_django.asserts @@ -26,7 +28,7 @@ def test_list_query_count(api_client, media_type_config): ) with patch( "api.views.media_views.search_controller", - search=MagicMock(return_value=controller_ret), + query_media=MagicMock(return_value=controller_ret), ), patch( "api.serializers.media_serializers.search_controller", get_sources=MagicMock(return_value={}), @@ -49,6 +51,36 @@ def test_retrieve_query_count(api_client, media_type_config): assert res.status_code == 200 +@pytest.mark.django_db +@pytest.mark.parametrize( + "path, expected_params", + [ + pytest.param("tag/cat/", {"tag": "cat"}, id="tag"), + pytest.param("source/flickr/", {"source": "flickr"}, id="source"), + pytest.param( + "source/flickr/creator/cat/", + {"source": "flickr", "creator": "cat"}, + id="source_creator", + ), + ], +) +def test_collection_parameters(path, expected_params, api_client): + mock_get_media_results = MagicMock(return_value=Response()) + + with patch( + "api.views.media_views.MediaViewSet.get_media_results", + new_callable=lambda: mock_get_media_results, + ) as mock_get_media_results: + api_client.get(f"/v1/images/{path}") + + actual_params = mock_get_media_results.call_args[0][3] + request_kind = mock_get_media_results.call_args[0][1] + + assert mock_get_media_results.called + assert actual_params == expected_params + assert request_kind == "collection" + + @pytest.mark.parametrize( "filter_content", (True, False), ids=lambda x: "filtered" if x else "not_filtered" )
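Finally, a short sketch of the new `AudioCollectionRequestSerializer` from api/api/serializers/audio_serializers.py above, assuming it is exercised inside the API's Django test context; the query values are illustrative.

    from api.serializers.audio_serializers import AudioCollectionRequestSerializer

    # Collection requests accept only pagination parameters plus `peaks`.
    params = AudioCollectionRequestSerializer(
        data={"page": 2, "page_size": 20, "peaks": "true"}
    )
    assert params.is_valid()
    assert params.validated_data["peaks"] is True
    # Asking for waveform peaks flips `needs_db`, forcing a database lookup.
    assert params.needs_db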