diff --git a/api/api/serializers/media_serializers.py b/api/api/serializers/media_serializers.py index c47219b693b..9bc3988cf0b 100644 --- a/api/api/serializers/media_serializers.py +++ b/api/api/serializers/media_serializers.py @@ -1,17 +1,21 @@ import logging from collections import namedtuple +from typing import TypedDict from django.conf import settings from django.core.exceptions import ValidationError as DjangoValidationError from django.core.validators import MaxValueValidator +from django.urls import reverse from rest_framework import serializers from rest_framework.exceptions import NotAuthenticated, ValidationError +from rest_framework.request import Request from drf_spectacular.utils import extend_schema_serializer from elasticsearch_dsl.response import Hit from api.constants import sensitivity from api.constants.licenses import LICENSE_GROUPS +from api.constants.media_types import MediaType from api.constants.parameters import COLLECTION, TAG from api.constants.sorting import DESCENDING, RELEVANCE, SORT_DIRECTIONS, SORT_FIELDS from api.controllers import search_controller @@ -294,8 +298,16 @@ class MediaSearchRequestSerializer(PaginatedRequestSerializer): required=False, ) + class Context(TypedDict, total=True): + warnings: list[dict] + media_type: MediaType + request: Request + + context: Context + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.context["warnings"] = [] self.media_type = self.context.get("media_type") if not self.media_type: raise ValueError( @@ -398,15 +410,34 @@ def validate_source(self, value): ) return value else: - sources = value.lower().split(",") - valid_sources = set( - [source for source in sources if source in allowed_sources] - ) - if len(sources) > len(valid_sources): - invalid_sources = set(sources).difference(valid_sources) - logger.warning( - f"Invalid sources in search query: {invalid_sources}; sources query: '{value}'" + sources = set(value.lower().split(",")) + valid_sources = {source 
for source in sources if source in allowed_sources} + if not valid_sources: + # Raise only if there are _no_ valid sources selected + # If the requester passed only `mispelled_museum_name1,mispelled_musesum_name2` + # the request cannot move forward, as all the top responses will likely be from Flickr + # which provides radically different responses than most other providers. + # If even one source is valid, it won't be a problem, in which case we'll issue a warning + raise serializers.ValidationError( + f"Invalid source parameter '{value}'. No valid sources selected. " + f"Refer to the source list for valid options: {sources_list}." ) + elif invalid_sources := (sources - valid_sources): + available_sources_uri = self.context["request"].build_absolute_uri( + reverse(f"{self.media_type}-stats") + ) + self.context["warnings"].append( + { + "code": "partially invalid source parameter", + "message": ( + "The source parameter included non-existent sources. " + f"For a list of available sources, see {available_sources_uri}" + ), + "invalid_sources": invalid_sources, + "valid_sources": valid_sources, + } + ) + return ",".join(valid_sources) def validate_excluded_source(self, input_sources): diff --git a/api/api/utils/pagination.py b/api/api/utils/pagination.py index 85e2cf24b3a..8c3534504ec 100644 --- a/api/api/utils/pagination.py +++ b/api/api/utils/pagination.py @@ -7,21 +7,36 @@ class StandardPagination(PageNumberPagination): page_size_query_param = "page_size" page_query_param = "page" + result_count: int | None + page_count: int | None + page: int + warnings: list[dict] + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.result_count = None # populated later self.page_count = None # populated later self.page = 1 # default, gets updated when necessary + self.warnings = [] # populated later as needed def get_paginated_response(self, data): + response = { + "result_count": self.result_count, + "page_count": min(settings.MAX_PAGINATION_DEPTH, 
self.page_count), + "page_size": self.page_size, + "page": self.page, + "results": data, + } return Response( - { - "result_count": self.result_count, - "page_count": min(settings.MAX_PAGINATION_DEPTH, self.page_count), - "page_size": self.page_size, - "page": self.page, - "results": data, - } + ( + { + # Put ``warnings`` first so it is as visible as possible. + "warnings": list(self.warnings), + } + if self.warnings + else {} + ) + | response ) def get_paginated_response_schema(self, schema): @@ -39,15 +54,45 @@ def get_paginated_response_schema(self, schema): "page_size": ("The number of items per page.", 20), "page": ("The current page number returned in the response.", 1), } + + properties = { + field: { + "type": "integer", + "description": description, + "example": example, + } + for field, (description, example) in field_descriptions.items() + } | { + "results": schema, + "warnings": { + "type": "array", + "items": { + "type": "object", + }, + "description": ( + "Warnings pertinent to the request. " + "If there are no warnings, this property will not be present on the response. " + "Warnings are non-critical problems with the request. " + "Responses with warnings should be treated as unstable. " + "Warning descriptions must not be treated as machine readable " + "and their schema can change at any time." + ), + "example": [ + { + "code": "partially invalid request parameter", + "message": ( + "Some of the request parameters were bad, " + "but we processed the request anyway. " + "Here's some information that might help you " + "fix the problem for future requests." 
+ ), + } + ], + }, + } + return { "type": "object", - "properties": { - field: { - "type": "integer", - "description": description, - "example": example, - } - for field, (description, example) in field_descriptions.items() - } - | {"results": schema}, + "properties": properties, + "required": list(set(properties.keys()) - {"warnings"}), } diff --git a/api/api/views/media_views.py b/api/api/views/media_views.py index b242eceb833..1018ef94c78 100644 --- a/api/api/views/media_views.py +++ b/api/api/views/media_views.py @@ -160,6 +160,7 @@ def get_media_results( ): page_size = self.paginator.page_size = params.data["page_size"] page = self.paginator.page = params.data["page"] + self.paginator.warnings = params.context["warnings"] hashed_ip = hash(self._get_user_ip(request)) filter_dead = params.validated_data.get("filter_dead", True) diff --git a/api/test/fixtures/rest_framework.py b/api/test/fixtures/rest_framework.py index 3359b0a81df..f96eeb0aacd 100644 --- a/api/test/fixtures/rest_framework.py +++ b/api/test/fixtures/rest_framework.py @@ -4,7 +4,7 @@ @pytest.fixture -def api_client(): +def api_client() -> APIClient: return APIClient() diff --git a/api/test/integration/test_media_integration.py b/api/test/integration/test_media_integration.py index c933e57cc5e..68bb1e8d797 100644 --- a/api/test/integration/test_media_integration.py +++ b/api/test/integration/test_media_integration.py @@ -329,6 +329,50 @@ def test_detail_view_for_invalid_uuids_returns_not_found( assert res.status_code == 404 +def test_search_with_only_valid_sources_produces_no_warning(media_type, api_client): + search = api_client.get( + f"/v1/{media_type.path}/", + {"source": ",".join(media_type.providers)}, + ) + assert search.status_code == 200 + assert "warnings" not in search.json() + + +def test_search_with_partially_invalid_sources_produces_warning_but_still_succeeds( + media_type: MediaType, api_client +): + invalid_sources = [ + "surely_neither_this_one", + 
"this_is_sure_not_to_ever_be_a_real_source_name", + ] + + search = api_client.get( + f"/v1/{media_type.path}/", + {"source": ",".join([media_type.providers[0]] + invalid_sources)}, + ) + assert search.status_code == 200 + result = search.json() + + assert {w["code"] for w in result["warnings"]} == { + "partially invalid source parameter" + } + warning = result["warnings"][0] + assert set(warning["invalid_sources"]) == set(invalid_sources) + assert warning["valid_sources"] == [media_type.providers[0]] + assert f"v1/{media_type.path}/stats/" in warning["message"] + + +def test_search_with_all_invalid_sources_fails(media_type, api_client): + invalid_sources = [ + "this_is_sure_not_to_ever_be_a_real_source_name", + "surely_neither_this_one", + ] + search = api_client.get( + f"/v1/{media_type.path}/", {"source": ",".join(invalid_sources)} + ) + assert search.status_code == 400 + + def test_detail_view_returns_ok(single_result, api_client): media_type, item = single_result res = api_client.get(f"/v1/{media_type.path}/{item['id']}/") diff --git a/api/test/unit/serializers/test_media_serializers.py b/api/test/unit/serializers/test_media_serializers.py index ea2500a2681..79059210ee4 100644 --- a/api/test/unit/serializers/test_media_serializers.py +++ b/api/test/unit/serializers/test_media_serializers.py @@ -99,25 +99,23 @@ def test_media_serializer_adds_license_url_if_missing( assert repr["license_url"] == "https://creativecommons.org/publicdomain/zero/1.0/" -def test_media_serializer_logs_when_invalid_or_duplicate_source(media_type_config): +def test_media_serializer_recovers_invalid_or_duplicate_source( + media_type_config, request_factory +): sources = { "image": ("flickr,flickr,invalid", "flickr"), "audio": ("freesound,freesound,invalid", "freesound"), } - with patch("api.serializers.media_serializers.logger.warning") as mock_logger: - serializer_class = media_type_config.search_request_serializer( - context={"media_type": media_type_config.media_type}, - 
data={"source": sources[media_type_config.media_type][0]}, - ) - assert serializer_class.is_valid() - assert ( - serializer_class.validated_data["source"] - == sources[media_type_config.media_type][1] - ) - mock_logger.assert_called_with( - f"Invalid sources in search query: {{'invalid'}}; " - f"sources query: '{sources[media_type_config.media_type][0]}'" - ) + request = request_factory.get("/v1/images/") + serializer_class = media_type_config.search_request_serializer( + context={"media_type": media_type_config.media_type, "request": request}, + data={"source": sources[media_type_config.media_type][0]}, + ) + assert serializer_class.is_valid() + assert ( + serializer_class.validated_data["source"] + == sources[media_type_config.media_type][1] + ) @pytest.mark.parametrize(